[
  {
    "path": ".editorconfig",
    "content": "# https://EditorConfig.org\n\n# Top-most EditorConfig file\nroot = true\n\n# Unix-style newlines with a newline ending every file, utf-8 charset\n[*]\nend_of_line = lf\ninsert_final_newline = true\ntrim_trailing_whitespace = true\ncharset = utf-8\nindent_style = space\nindent_size = 4\n\n[*.md]\nindent_size = 2\n\n[Makefile]\nindent_style = tab\n\n[prompts/*.txt]\ninsert_final_newline = unset\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "*For changes to the core `ggml` library (including to the CMake build system), please open a PR in https://github.com/ggml-org/llama.cpp. Doing so will make your PR more visible, better tested and more likely to be reviewed.*\n"
  },
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: CI\n\non:\n  push:\n    branches: [ master ]\n  pull_request:\n    branches: [ master ]\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  build:\n    strategy:\n      matrix:\n        os: [ubuntu-latest, macos-latest, windows-latest]\n        libraries: [shared, static]\n\n    runs-on: ${{ matrix.os }}\n\n    steps:\n    - name: Clone\n      uses: actions/checkout@v6\n\n    - name: Dependencies for Ubuntu\n      if: matrix.os == 'ubuntu-latest'\n      run: |\n        sudo apt-get update\n        sudo apt-get install llvm\n\n    - name: Add msbuild to PATH\n      if: matrix.os == 'windows-latest'\n      uses: microsoft/setup-msbuild@v2\n\n    - name: Create Build Environment\n      run: mkdir build\n\n    - name: Configure CMake\n      working-directory: ./build\n      run: cmake ..\n        ${{ contains(matrix.os, 'windows') && '-A x64' || '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++' }}\n        ${{ matrix.libraries == 'static' && '-DBUILD_SHARED_LIBS=OFF' || '-DBUILD_SHARED_LIBS=ON' }}\n        -DCMAKE_INSTALL_PREFIX=${{ github.workspace }}/installed\n        -DGGML_METAL=OFF\n\n    - name: Build\n      working-directory: ./build\n      run: cmake --build . ${{ contains(matrix.os, 'windows') && '--config Release' || '' }}\n\n    - name: Test\n      working-directory: ./build\n      run: ctest --verbose --timeout 900 ${{ contains(matrix.os, 'windows') && '--build-config Release' || '' }}\n\n    - name: Install\n      working-directory: ./build\n      run: cmake --build . 
--target install ${{ contains(matrix.os, 'windows') && '--config Release' || '' }}\n\n    - name: Test CMake config\n      run: |\n        mkdir test-cmake\n        cmake -S examples/test-cmake -B test-cmake -DCMAKE_PREFIX_PATH=${{ github.workspace }}/installed ${{ contains(matrix.os, 'windows') && '-A x64' || '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++' }}\n        cmake --build test-cmake ${{ contains(matrix.os, 'windows') && '--config Release' || '' }}\n\n# TODO: simplify the following workflows using a matrix\n  ggml-ci-x64-cpu-low-perf:\n    runs-on: ubuntu-22.04\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: ccache\n        uses: ggml-org/ccache-action@v1.2.16\n        with:\n          key: ggml-ci-x64-cpu-low-perf\n          evict-old-files: 1d\n\n      - name: Dependencies\n        id: depends\n        run: |\n          sudo apt-get update\n          sudo apt-get install build-essential libcurl4-openssl-dev\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt\n\n  ggml-ci-arm64-cpu-low-perf:\n    runs-on: ubuntu-22.04-arm\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: ccache\n        uses: ggml-org/ccache-action@v1.2.16\n        with:\n          key: ggml-ci-arm64-cpu-low-perf\n          evict-old-files: 1d\n\n      - name: Dependencies\n        id: depends\n        run: |\n          sudo apt-get update\n          sudo apt-get install build-essential libcurl4-openssl-dev\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt\n\n  ggml-ci-x64-cpu-high-perf:\n    runs-on: ubuntu-22.04\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: ccache\n        uses: 
ggml-org/ccache-action@v1.2.16\n        with:\n          key: ggml-ci-x64-cpu-high-perf\n          evict-old-files: 1d\n\n      - name: Dependencies\n        id: depends\n        run: |\n          sudo apt-get update\n          sudo apt-get install build-essential libcurl4-openssl-dev\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt\n\n  ggml-ci-arm64-cpu-high-perf:\n    runs-on: ubuntu-22.04-arm\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: ccache\n        uses: ggml-org/ccache-action@v1.2.16\n        with:\n          key: ggml-ci-arm64-cpu-high-perf\n          evict-old-files: 1d\n\n      - name: Dependencies\n        id: depends\n        run: |\n          sudo apt-get update\n          sudo apt-get install build-essential libcurl4-openssl-dev\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt\n\n  ggml-ci-arm64-cpu-high-perf-sve:\n    runs-on: ubuntu-22.04-arm\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: ccache\n        uses: ggml-org/ccache-action@v1.2.16\n        with:\n          key: ggml-ci-arm64-cpu-high-perf-sve\n          evict-old-files: 1d\n\n      - name: Dependencies\n        id: depends\n        run: |\n          sudo apt-get update\n          sudo apt-get install build-essential libcurl4-openssl-dev\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt\n\n  ggml-ci-x64-nvidia-cuda:\n    runs-on: [self-hosted, Linux, X64, NVIDIA]\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: Test\n        id: ggml-ci\n    
    run: |\n          nvidia-smi\n          GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/ggml /mnt/ggml\n\n  ggml-ci-x64-nvidia-vulkan-cm:\n    runs-on: [self-hosted, Linux, X64, NVIDIA]\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          vulkaninfo --summary\n          GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/ggml /mnt/ggml\n\n  ggml-ci-x64-nvidia-vulkan-cm2:\n    runs-on: [self-hosted, Linux, X64, NVIDIA, COOPMAT2]\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          vulkaninfo --summary\n          GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/ggml /mnt/ggml\n\n  # TODO: provision AMX-compatible machine\n  #ggml-ci-x64-cpu-amx:\n  #  runs-on: [self-hosted, Linux, X64, CPU, AMX]\n\n  #  steps:\n  #    - name: Clone\n  #      id: checkout\n  #      uses: actions/checkout@v6\n\n  #    - name: Test\n  #      id: ggml-ci\n  #      run: |\n  #        bash ./ci/run.sh ~/results/ggml /mnt/ggml\n\n  ggml-ci-mac-metal:\n    runs-on: [self-hosted, macOS, ARM64]\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/ggml ~/mnt/ggml\n\n  ggml-ci-mac-vulkan:\n    runs-on: [self-hosted, macOS, ARM64]\n\n    steps:\n      - name: Clone\n        id: checkout\n        uses: actions/checkout@v6\n\n      - name: Test\n        id: ggml-ci\n        run: |\n          vulkaninfo --summary\n          GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/ggml ~/mnt/ggml\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: Release\n\non:\n  push:\n    tags:\n      - 'v*'\n\njobs:\n  release:\n    runs-on: ubuntu-latest\n    permissions:\n      contents: write\n\n    steps:\n    - name: Checkout code\n      uses: actions/checkout@v6\n\n    - name: Create Release\n      id: create_release\n      uses: ggml-org/action-create-release@v1\n      env:\n        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n      with:\n          tag_name: ${{ github.ref_name }}\n          release_name: ${{ github.ref }}\n          draft: false\n          prerelease: false\n"
  },
  {
    "path": ".gitignore",
    "content": "build/\nbuild-*/\nout/\ntmp/\nmodels/\nmodels-mnt\n\ncompile_commands.json\nCMakeSettings.json\n.vs/\n.vscode/\n.idea/\n.clangd\n\n.venv/\nggml_env/\n.exrc\n.cache\n.DS_Store\n.stablelm\n.gpt-2\n\nsrc/arm_neon.h\ntests/arm_neon.h\n\nzig-out/\nzig-cache/\n\n*.o\n*.d\n*.dot\n\n*.sw?\n\n__pycache__/\n\n# Model files\nggml-model-f16.bin\n*.bat\n"
  },
  {
    "path": ".gitmodules",
    "content": ""
  },
  {
    "path": "AUTHORS",
    "content": "# date: Tue Feb  4 13:03:51 EET 2025\n# this file is auto-generated by scripts/gen-authors.sh\n\n0cc4m <picard12@live.de>\n65a <10104049+65a@users.noreply.github.com>\nAT <manyoso@users.noreply.github.com>\nAbhilash Majumder <30946547+abhilash1910@users.noreply.github.com>\nAdam Tazi <52357206+ad1tazi@users.noreply.github.com>\nAdrien Gallouët <adrien@gallouet.fr>\nAdrien Gallouët <angt@huggingface.co>\nAhmad Tameem <113388789+Tameem-10xE@users.noreply.github.com>\nAidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com>\nAidanBeltonS <aidan.belton@codeplay.com>\nAkarshan Biswas <akarshan.biswas@gmail.com>\nAkarshan Biswas <akarshanbiswas@fedoraproject.org>\nAlbert Jin <albert.jin@gmail.com>\nAlberto Cabrera Pérez <alberto.cabrera@codeplay.com>\nAlberto Cabrera Pérez <alberto.cabrera@intel.com>\nAlex Azarov <alex@azarov.by>\nAlex O'Connell <35843486+acon96@users.noreply.github.com>\nAlex von Gluck IV <kallisti5@unixzen.com>\nAmbientL <107641468+AmbientL@users.noreply.github.com>\nAmirAli Mirian <37371367+amiralimi@users.noreply.github.com>\nAnanta Bastola <anantarajbastola@gmail.com>\nAndreas (Andi) Kunar <andreask@msn.com>\nAndreas Kieslinger <47689530+aendk@users.noreply.github.com>\nAndrei <abetlen@gmail.com>\nAndrew Minh Nguyen <40281306+amqdn@users.noreply.github.com>\nAndrii Ryzhkov <andriiryzhkov@users.noreply.github.com>\nArjun <ccldarjun@icloud.com>\nAshraful Islam <ashraful.meche@gmail.com>\nAstariul <43774355+astariul@users.noreply.github.com>\nAsukaMinato <asukaminato@nyan.eu.org>\nAvi Lumelsky <avilume@gmail.com>\nBart Pelle <3662930+Velocity-@users.noreply.github.com>\nBen Ashbaugh <ben.ashbaugh@intel.com>\nBernhard M. 
Wiedemann <githubbmwprimary@lsmod.de>\nBorislav Stanimirov <b.stanimirov@abv.bg>\nBrad Ito <phlogisticfugu@users.noreply.github.com>\nBrad Murray <59848399+bradmurray-dt@users.noreply.github.com>\nBrian <mofosyne@gmail.com>\nBryan Lozano <b.lozano.havoc@gmail.com>\nCarolinabanana <140120812+Carolinabanana@users.noreply.github.com>\nCarterLi999 <664681047@qq.com>\nCebtenzzre <cebtenzzre@gmail.com>\nChangyeon Kim <cyzero.kim@samsung.com>\nCharles Xu <63788048+chaxu01@users.noreply.github.com>\nCharles Xu <charles.xu@arm.com>\nChen Xi <xi2.chen@intel.com>\nChen Xi <xixichen08@foxmail.com>\nChenguang Li <87689256+noemotiovon@users.noreply.github.com>\nChris Elrod <elrodc@gmail.com>\nChristian Kastner <ckk@kvr.at>\nClint Herron <hanclinto@gmail.com>\nConrad Kramer <conrad@conradkramer.com>\nCordeiro <1471463+ocordeiro@users.noreply.github.com>\nCristiano Calcagno <cristianoc@users.noreply.github.com>\nDAN™ <dranger003@gmail.com>\nDan Forbes <dan@danforbes.dev>\nDan Johansson <164997844+eddnjjn@users.noreply.github.com>\nDan Johansson <dan.johansson@arm.com>\nDaniel Bevenius <daniel.bevenius@gmail.com>\nDaniel Ziegenberg <daniel@ziegenberg.at>\nDaniele <57776841+daniandtheweb@users.noreply.github.com>\nDaulet Zhanguzin <daulet@users.noreply.github.com>\nDave <dave-fl@users.noreply.github.com>\nDave Airlie <airlied@gmail.com>\nDave Airlie <airlied@redhat.com>\nDavid Miller <david@patagona.ca>\nDavidKorczynski <david@adalogics.com>\nDavidson Francis <davidsondfgl@gmail.com>\nDibakar Gope <dibakar.gope@arm.com>\nDidzis Gosko <didzis@users.noreply.github.com>\nDiego Devesa <slarengh@gmail.com>\nDiogo <dgcruz983@gmail.com>\nDjip007 <3705339+Djip007@users.noreply.github.com>\nDjip007 <djip.perois@free.fr>\nDou Xinpeng <15529241576@163.com>\nDou Xinpeng <81913537+Dou-Git@users.noreply.github.com>\nDr. 
Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com>\nEbey Abraham <ebey97@gmail.com>\nEldar Yusupov <eyusupov@gmail.com>\nEmmanuel Durand <emmanueldurand@protonmail.com>\nEngininja2 <139037756+Engininja2@users.noreply.github.com>\nEric Zhang <34133756+EZForever@users.noreply.github.com>\nErik Scholz <Green-Sky@users.noreply.github.com>\nEttore Di Giacinto <mudler@users.noreply.github.com>\nEve <139727413+netrunnereve@users.noreply.github.com>\nF1L1P <78918286+F1L1Pv2@users.noreply.github.com>\nFaisal Zaghloul <quic_fzaghlou@quicinc.com>\nFantasyGmm <16450052+FantasyGmm@users.noreply.github.com>\nFelix <stenbackfelix@gmail.com>\nFinn Voorhees <finnvoorhees@gmail.com>\nFirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com>\nFrankie Robertson <frankier@users.noreply.github.com>\nGainLee <perfecter.gen@gmail.com>\nGeorge Hindle <george@georgehindle.com>\nGeorgi Gerganov <ggerganov@gmail.com>\nGilad S <7817232+giladgd@users.noreply.github.com>\nGilad S <giladgd@users.noreply.github.com>\nGilad S. 
<7817232+giladgd@users.noreply.github.com>\nGuillaume Wenzek <gwenzek@users.noreply.github.com>\nHalalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com>\nHaus1 <haus.xda@gmail.com>\nHerman Semenov <GermanAizek@yandex.ru>\nHimariO <dsfhe49854@gmail.com>\nHirochika Matsumoto <git@hkmatsumoto.com>\nHong Bo PENG <penghb@cn.ibm.com>\nHugo Rosenkranz-Costa <hugo.rosenkranz@gmail.com>\nHyunsung Lee <ita9naiwa@gmail.com>\nIGUILIZ Salah-Eddine <76955987+salahiguiliz@users.noreply.github.com>\nIan Bull <irbull@eclipsesource.com>\nIhar Hrachyshka <ihrachys@redhat.com>\nIkko Eltociear Ashimine <eltociear@gmail.com>\nIvan <nekotekina@gmail.com>\nIvan Filipov <159561759+vanaka11@users.noreply.github.com>\nIvan Stepanov <ivanstepanovftw@gmail.com>\nIvan Zdane <accounts@ivanzdane.com>\nJack Mousseau <jmousseau@users.noreply.github.com>\nJack Vial <vialjack@gmail.com>\nJacobLinCool <jacoblincool@gmail.com>\nJakob Frick <jakob.maria.frick@gmail.com>\nJan Ploski <jpl@plosquare.com>\nJared Van Bortel <jared@nomic.ai>\nJeff Bolz <jbolz@nvidia.com>\nJeffrey Quesnelle <jquesnelle@gmail.com>\nJeroen Mostert <jeroen.mostert@cm.com>\nJiahao Li <liplus17@163.com>\nJidongZhang-THU <1119708529@qq.com>\nJiří Podivín <66251151+jpodivin@users.noreply.github.com>\nJo Liss <joliss42@gmail.com>\nJoe Todd <joe.todd@codeplay.com>\nJohannes Gäßler <johannesg@5d6.de>\nJohn Balis <phobossystems@gmail.com>\nJosh Bleecher Snyder <josharian@gmail.com>\nJudd <foldl@users.noreply.github.com>\nJun Hee Yoo <contact.jhyoo@gmail.com>\nJunil Kim <logyourself@gmail.com>\nJustina Cho <justcho5@gmail.com>\nJustine Tunney <jtunney@gmail.com>\nJustine Tunney <jtunney@mozilla.com>\nKarol Kontny <82021046+kkontny@users.noreply.github.com>\nKawrakow <48489457+ikawrakow@users.noreply.github.com>\nKevin Gibbons <bakkot@gmail.com>\nKonstantin Zhuravlyov <konstantin.zhuravlyov@amd.com>\nKylin <56434533+KyL0N@users.noreply.github.com>\nLoganDark <git@logandark.mozmail.com>\nLoganDark 
<github@logandark.mozmail.com>\nLostRuins <39025047+LostRuins@users.noreply.github.com>\nLukas Möller <mail@lukas-moeller.ch>\nM Refi D.A <24388107+refinism@users.noreply.github.com>\nM. Yusuf Sarıgöz <yusufsarigoz@gmail.com>\nMa Mingfei <mingfei.ma@intel.com>\nMahesh Madhav <67384846+heshpdx@users.noreply.github.com>\nMaiHD <maihd.dev@gmail.com>\nMark Zhuang <zhuangqiubin@gmail.com>\nMarkus Tavenrath <mtavenrath@users.noreply.github.com>\nMasaya, Kato <62578291+msy-kato@users.noreply.github.com>\nMathieu Baudier <mbaudier@argeo.org>\nMathijs de Bruin <mathijs@mathijsfietst.nl>\nMatt Stephenson <mstephenson6@users.noreply.github.com>\nMax Krasnyansky <max.krasnyansky@gmail.com>\nMax Krasnyansky <quic_maxk@quicinc.com>\nMayank Kumar Pal <mynkpl1998@gmail.com>\nMeng, Hengyu <hengyu.meng@intel.com>\nMengqing Cao <cmq0113@163.com>\nMetal Whale <45712559+metalwhale@users.noreply.github.com>\nMichael Klimenko <mklimenko29@gmail.com>\nMichael Podvitskiy <podvitskiymichael@gmail.com>\nMichael Verrilli <msv@pobox.com>\nMolly Sophia <mollysophia379@gmail.com>\nNatsu <chino@hotococoa.moe>\nNeo Zhang <14088817+arthw@users.noreply.github.com>\nNeo Zhang Jianyu <jianyu.zhang@intel.com>\nNeuman Vong <neuman.vong@gmail.com>\nNevin <nevinpuri1901@gmail.com>\nNicholai Tukanov <nicholaitukanov@gmail.com>\nNico Bosshard <nico@bosshome.ch>\nNicolò Scipione <nicolo.scipione@codeplay.com>\nNikita Sarychev <42014488+sARY77@users.noreply.github.com>\nNouamane Tazi <nouamane98@gmail.com>\nOlivier Chafik <ochafik@google.com>\nOlivier Chafik <ochafik@users.noreply.github.com>\nOndřej Čertík <ondrej@certik.us>\nOuadie EL FAROUKI <ouadie.elfarouki@codeplay.com>\nPAB <pierreantoine.bannier@gmail.com>\nPaul Tsochantaris <ptsochantaris@icloud.com>\nPeter <peter277@users.noreply.github.com>\nPhilpax <me@philpax.me>\nPierre Alexandre SCHEMBRI <pa.schembri@gmail.com>\nPlamen Minev <pacominev@gmail.com>\nPlaydev <josang1204@gmail.com>\nPrashant Vithule 
<119530321+Vithulep@users.noreply.github.com>\nPrzemysław Pawełczyk <przemoc@gmail.com>\nR0CKSTAR <xiaodong.ye@mthreads.com>\nR0CKSTAR <yeahdongcn@gmail.com>\nRadoslav Gerganov <rgerganov@gmail.com>\nRadosław Gryta <radek.gryta@gmail.com>\nRavindra Marella <marella@users.noreply.github.com>\nRay Cromwell <cromwellian@gmail.com>\nReinforce-II <fate@eastal.com>\nRémy Oudompheng <oudomphe@phare.normalesup.org>\nReza Rezvan <reza@rezvan.xyz>\nRick G <26732651+TheFlipbook@users.noreply.github.com>\nRiverZhou <riverzhou2000@gmail.com>\nRobert Ormandi <52251610+ormandi@users.noreply.github.com>\nRomain Biessy <romain.biessy@codeplay.com>\nRonsor <ronsor@ronsor.pw>\nRotem Dan <rotemdan@gmail.com>\nRyan Hitchman <hitchmanr@gmail.com>\nSRHMorris <69468379+SRHMorris@users.noreply.github.com>\nSXX <sxx1136965276@gmail.com>\nSalvatore Mesoraca <s.mesoraca16@gmail.com>\nSam Spilsbury <smspillaz@gmail.com>\nSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>\nSanttu Keskinen <santtu.keskinen@gmail.com>\nSergio López <slp@redhat.com>\nSergio López <slp@sinrega.org>\nShanshan Shen <467638484@qq.com>\nShijie <821898965@qq.com>\nShupei Fan <dymarkfan@outlook.com>\nSiddharth Ramakrishnan <srr2141@columbia.edu>\nSigbjørn Skjæret <sigbjorn.skjaeret@scala.com>\nSkyler Celestinian-Sterling <80314197+Celestinian@users.noreply.github.com>\nSlava Primenko <primenko.s@gmail.com>\nSrihari-mcw <96763064+Srihari-mcw@users.noreply.github.com>\nSteward Garcia <57494570+FSSRepo@users.noreply.github.com>\nSupreet Sethi <supreet.sethi@gmail.com>\nTakuya Takeuchi <takuya.takeuchi.dev@gmail.com>\nTamotsu Takahashi <ttakah+github@gmail.com>\nTanmay <tnmysachan@gmail.com>\nTanmay Sachan <tnmysachan@gmail.com>\nTimothy Cronin <40186632+4imothy@users.noreply.github.com>\nTom Bailey <tombailey@users.noreply.github.com>\nTom Jobbins <784313+TheBloke@users.noreply.github.com>\nTony Wasserka <4840017+neobrain@users.noreply.github.com>\nTristan Druyen <tristan@vault81.mozmail.com>\nTyé singwa 
<92231658+tye-singwa@users.noreply.github.com>\nUEXTM.com <84163508+uextm@users.noreply.github.com>\nWillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com>\nWilliam Tambellini <william.tambellini@gmail.com>\nWilliam Tambellini <wtambellini@sdl.com>\nXiaotaoChen <chenxiaotao1234@gmail.com>\nXinpeng Dou <81913537+Dou-Git@users.noreply.github.com>\nXuan Son Nguyen <thichthat@gmail.com>\nYavor Ivanov <yivanov@viewray.com>\nYavorGIvanov <yivanov@viewray.com>\nYilong Guo <vfirst218@gmail.com>\nYilong Guo <yilong.guo@intel.com>\nYuri Khrustalev <ykhrustalev@users.noreply.github.com>\nZhenwei Jin <109658203+kylo5aby@users.noreply.github.com>\nZhiyuan Li <lizhiyuan@uniartisan.com>\nZhiyuan Li <uniartisan2017@gmail.com>\na3sh <38979186+A3shTnT@users.noreply.github.com>\nag2s20150909 <19373730+ag2s20150909@users.noreply.github.com>\nagray3 <agray3@users.noreply.github.com>\namd-dwang <dong.wang@amd.com>\namritahs-ibm <amritahs@linux.vnet.ibm.com>\napcameron <37645737+apcameron@users.noreply.github.com>\nappvoid <78444142+appvoid@users.noreply.github.com>\nariez-xyz <41232910+ariez-xyz@users.noreply.github.com>\nautomaticcat <daogiatuank54@gmail.com>\nbandoti <141645996+bandoti@users.noreply.github.com>\nbmwl <brian.marshall@tolko.com>\nbobqianic <129547291+bobqianic@users.noreply.github.com>\nbssrdf <merlintiger@hotmail.com>\nchengchi <davesjoewang@gmail.com>\ncompilade <113953597+compilade@users.noreply.github.com>\ncompilade <git@compilade.net>\nddpasa <112642920+ddpasa@users.noreply.github.com>\ndenersc <denerstassun@gmail.com>\ndscripka <dscripka@users.noreply.github.com>\nfitzsim <fitzsim@fitzsim.org>\nfj-y-saito <85871716+fj-y-saito@users.noreply.github.com>\nfraxy-v <65565042+fraxy-v@users.noreply.github.com>\ngn64 <yukikaze.jp@gmail.com>\ngoerch <jhr.walter@t-online.de>\ngoldwaving <77494627+goldwaving@users.noreply.github.com>\nhaopeng <657407891@qq.com>\nhidenorly <hidenorly@users.noreply.github.com>\nhipudding <huafengchun@gmail.com>\nhydai 
<z54981220@gmail.com>\nissixx <46835150+issixx@users.noreply.github.com>\njaeminSon <woalsdnd@gmail.com>\njdomke <28772296+jdomke@users.noreply.github.com>\njiez <373447296@qq.com>\njohnson442 <56517414+johnson442@users.noreply.github.com>\njunchao-loongson <68935141+junchao-loongson@users.noreply.github.com>\nk.h.lai <adrian.k.h.lai@outlook.com>\nkatsu560 <118887472+katsu560@users.noreply.github.com>\nklosax <131523366+klosax@users.noreply.github.com>\nkunnis <kunnis@users.noreply.github.com>\nl3utterfly <gc.pthzfoldr@gmail.com>\nle.chang <cljs118@126.com>\nleejet <31925346+leejet@users.noreply.github.com>\nleejet <leejet714@gmail.com>\nleo-pony <nengjunma@outlook.com>\nlhez <quic_lih@quicinc.com>\nliuwei-git <14815172+liuwei-git@users.noreply.github.com>\nluoyu-intel <yu.luo@intel.com>\nmagicse <magicse@users.noreply.github.com>\nmahorozte <41834471+mahorozte@users.noreply.github.com>\nmashizora <30516315+mashizora@users.noreply.github.com>\nmatt23654 <matthew.webber@protonmail.com>\nmatteo <matteogeniaccio@yahoo.it>\nochafik <ochafik@google.com>\notaGran <ujt2h8@gmail.com>\npengxin99 <pengxin.yuan@intel.com>\npikalover6 <49179590+pikalover6@users.noreply.github.com>\npostmasters <namnguyen@google.com>\nsjinzh <sjinzh@gmail.com>\nskirodev <57715494+skirodev@users.noreply.github.com>\nslaren <slarengh@gmail.com>\nsnadampal <87143774+snadampal@users.noreply.github.com>\nsomeone13574 <81528246+someone13574@users.noreply.github.com>\nstduhpf <stephduh@live.fr>\ntaher <8665427+nullhook@users.noreply.github.com>\ntexmex76 <40733439+texmex76@users.noreply.github.com>\nthe-crypt-keeper <84680712+the-crypt-keeper@users.noreply.github.com>\nthewh1teagle <61390950+thewh1teagle@users.noreply.github.com>\nucag.li <ucag@qq.com>\nulatekh <ulatekh@yahoo.com>\nuvos <devnull@uvos.xyz>\nuvos <philipp@uvos.xyz>\nwangshuai09 <391746016@qq.com>\nwoachk <24752637+woachk@users.noreply.github.com>\nxctan <axunlei@gmail.com>\nyangyaofei <yangyaofei@gmail.com>\nyuri@FreeBSD 
<yuri@FreeBSD>\nzhentaoyu <zhentao.yu@intel.com>\nzhouwg <6889919+zhouwg@users.noreply.github.com>\nzhouwg <zhouwg2000@gmail.com>\n谢乃闻 <sienaiwun@users.noreply.github.com>\n布客飞龙 <562826179@qq.com>\n旺旺碎冰冰 <38837039+Cyberhan123@users.noreply.github.com>\n"
  },
  {
    "path": "CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.\nproject(\"ggml\" C CXX ASM)\n\n### GGML Version\nset(GGML_VERSION_MAJOR 0)\nset(GGML_VERSION_MINOR 9)\nset(GGML_VERSION_PATCH 8)\nset(GGML_VERSION_BASE \"${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}\")\n\nfind_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)\nif(GIT_EXE)\n    # Get current git commit hash\n    execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD\n        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}\n        OUTPUT_VARIABLE GGML_BUILD_COMMIT\n        OUTPUT_STRIP_TRAILING_WHITESPACE\n        ERROR_QUIET\n    )\n\n    # Check if the working directory is dirty (i.e., has uncommitted changes)\n    execute_process(COMMAND ${GIT_EXE} diff-index --quiet HEAD -- .\n        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}\n        RESULT_VARIABLE GGML_GIT_DIRTY\n        ERROR_QUIET\n    )\nendif()\n\nset(GGML_VERSION \"${GGML_VERSION_BASE}\")\n\nif(NOT GGML_BUILD_COMMIT)\n    set(GGML_BUILD_COMMIT \"unknown\")\nendif()\n\n# Build the commit string with optional dirty flag\nif(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1)\n    set(GGML_BUILD_COMMIT \"${GGML_BUILD_COMMIT}-dirty\")\nendif()\n\ninclude(CheckIncludeFileCXX)\n\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON)\n\nif (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)\n    set(CMAKE_BUILD_TYPE Release CACHE STRING \"Build type\" FORCE)\n    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS \"Debug\" \"Release\" \"MinSizeRel\" \"RelWithDebInfo\")\nendif()\n\nif (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)\n    set(GGML_STANDALONE ON)\n\n    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)\n\n    # configure project version\n    # TODO\nelse()\n    set(GGML_STANDALONE OFF)\n\n    if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY)\n        set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)\n    endif()\nendif()\n\nif (EMSCRIPTEN)\n    
set(BUILD_SHARED_LIBS_DEFAULT OFF)\n\n    option(GGML_WASM_SINGLE_FILE \"ggml: embed WASM inside the generated ggml.js\" ON)\nelse()\n    if (MINGW)\n        set(BUILD_SHARED_LIBS_DEFAULT OFF)\n    else()\n        set(BUILD_SHARED_LIBS_DEFAULT ON)\n    endif()\nendif()\n\n# remove the lib prefix on win32 mingw\nif (WIN32)\n    set(CMAKE_STATIC_LIBRARY_PREFIX \"\")\n    set(CMAKE_SHARED_LIBRARY_PREFIX \"\")\n    set(CMAKE_SHARED_MODULE_PREFIX  \"\")\nendif()\n\noption(BUILD_SHARED_LIBS           \"ggml: build shared libraries\" ${BUILD_SHARED_LIBS_DEFAULT})\noption(GGML_BACKEND_DL             \"ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)\" OFF)\nset(GGML_BACKEND_DIR \"\" CACHE PATH \"ggml: directory to load dynamic backends from (requires GGML_BACKEND_DL)\")\n\n#\n# option list\n#\n\n# TODO: mark all options as advanced when not GGML_STANDALONE\n\nif (APPLE)\n    set(GGML_METAL_DEFAULT ON)\n    set(GGML_BLAS_DEFAULT ON)\n    set(GGML_BLAS_VENDOR_DEFAULT \"Apple\")\nelse()\n    set(GGML_METAL_DEFAULT OFF)\n    set(GGML_BLAS_DEFAULT OFF)\n    set(GGML_BLAS_VENDOR_DEFAULT \"Generic\")\nendif()\n\nif (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH})\n    message(STATUS \"Setting GGML_NATIVE_DEFAULT to OFF\")\n    set(GGML_NATIVE_DEFAULT OFF)\nelse()\n    set(GGML_NATIVE_DEFAULT ON)\nendif()\n\n# defaults\nif (NOT GGML_LLAMAFILE_DEFAULT)\n    set(GGML_LLAMAFILE_DEFAULT OFF)\nendif()\n\nif (NOT GGML_CUDA_GRAPHS_DEFAULT)\n    set(GGML_CUDA_GRAPHS_DEFAULT OFF)\nendif()\n\n# general\noption(GGML_STATIC \"ggml: static link libraries\"                     OFF)\noption(GGML_NATIVE \"ggml: optimize the build for the current system\" ${GGML_NATIVE_DEFAULT})\noption(GGML_LTO    \"ggml: enable link time optimization\"             OFF)\noption(GGML_CCACHE \"ggml: use ccache if available\"                   ON)\n\n# debug\noption(GGML_ALL_WARNINGS           \"ggml: enable all compiler warnings\"                   
ON)\noption(GGML_ALL_WARNINGS_3RD_PARTY \"ggml: enable all compiler warnings in 3rd party libs\" OFF)\noption(GGML_GPROF                  \"ggml: enable gprof\"                                   OFF)\n\n# build\noption(GGML_FATAL_WARNINGS    \"ggml: enable -Werror flag\"    OFF)\n\n# sanitizers\noption(GGML_SANITIZE_THREAD    \"ggml: enable thread sanitizer\"    OFF)\noption(GGML_SANITIZE_ADDRESS   \"ggml: enable address sanitizer\"   OFF)\noption(GGML_SANITIZE_UNDEFINED \"ggml: enable undefined sanitizer\" OFF)\n\n# instruction set specific\nif (GGML_NATIVE OR NOT GGML_NATIVE_DEFAULT)\n    set(INS_ENB OFF)\nelse()\n    set(INS_ENB ON)\nendif()\n\nmessage(DEBUG \"GGML_NATIVE         : ${GGML_NATIVE}\")\nmessage(DEBUG \"GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}\")\nmessage(DEBUG \"INS_ENB             : ${INS_ENB}\")\n\noption(GGML_CPU_HBM          \"ggml: use memkind for CPU HBM\" OFF)\noption(GGML_CPU_REPACK       \"ggml: use runtime weight conversion of Q4_0 to Q4_X_X\" ON)\noption(GGML_CPU_KLEIDIAI     \"ggml: use KleidiAI optimized kernels if applicable\" OFF)\noption(GGML_SSE42            \"ggml: enable SSE 4.2\"          ${INS_ENB})\noption(GGML_AVX              \"ggml: enable AVX\"              ${INS_ENB})\noption(GGML_AVX_VNNI         \"ggml: enable AVX-VNNI\"         OFF)\noption(GGML_AVX2             \"ggml: enable AVX2\"             ${INS_ENB})\noption(GGML_BMI2             \"ggml: enable BMI2\"             ${INS_ENB})\noption(GGML_AVX512           \"ggml: enable AVX512F\"          OFF)\noption(GGML_AVX512_VBMI      \"ggml: enable AVX512-VBMI\"      OFF)\noption(GGML_AVX512_VNNI      \"ggml: enable AVX512-VNNI\"      OFF)\noption(GGML_AVX512_BF16      \"ggml: enable AVX512-BF16\"      OFF)\nif (NOT MSVC)\n    # in MSVC F16C and FMA is implied with AVX2/AVX512\n    option(GGML_FMA          \"ggml: enable FMA\"              ${INS_ENB})\n    option(GGML_F16C         \"ggml: enable F16C\"             ${INS_ENB})\n    # MSVC does not seem to support AMX\n   
 option(GGML_AMX_TILE     \"ggml: enable AMX-TILE\"         OFF)\n    option(GGML_AMX_INT8     \"ggml: enable AMX-INT8\"         OFF)\n    option(GGML_AMX_BF16     \"ggml: enable AMX-BF16\"         OFF)\nendif()\noption(GGML_LASX             \"ggml: enable lasx\"             ON)\noption(GGML_LSX              \"ggml: enable lsx\"              ON)\noption(GGML_RVV              \"ggml: enable rvv\"              ON)\noption(GGML_RV_ZFH           \"ggml: enable riscv zfh\"        ON)\noption(GGML_RV_ZVFH          \"ggml: enable riscv zvfh\"       ON)\noption(GGML_RV_ZICBOP        \"ggml: enable riscv zicbop\"     ON)\noption(GGML_RV_ZIHINTPAUSE   \"ggml: enable riscv zihintpause \"  ON)\noption(GGML_XTHEADVECTOR     \"ggml: enable xtheadvector\"     OFF)\noption(GGML_VXE              \"ggml: enable vxe\"              ${GGML_NATIVE})\n\noption(GGML_CPU_ALL_VARIANTS \"ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)\" OFF)\nset(GGML_CPU_ARM_ARCH        \"\" CACHE STRING \"ggml: CPU architecture for ARM\")\nset(GGML_CPU_POWERPC_CPUTYPE \"\" CACHE STRING \"ggml: CPU type for PowerPC\")\n\n# ggml core\nset(GGML_SCHED_MAX_COPIES  \"4\" CACHE STRING \"ggml: max input copies for pipeline parallelism\")\noption(GGML_CPU                             \"ggml: enable CPU backend\"                        ON)\noption(GGML_SCHED_NO_REALLOC                \"ggml: disallow reallocations in ggml-alloc (for debugging)\" OFF)\n\n# 3rd party libs / backends\noption(GGML_ACCELERATE                      \"ggml: enable Accelerate framework\"               ON)\noption(GGML_BLAS                            \"ggml: use BLAS\"                                  ${GGML_BLAS_DEFAULT})\nset(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING\n                                            \"ggml: BLAS library vendor\")\noption(GGML_LLAMAFILE                       \"ggml: use LLAMAFILE\"                             ${GGML_LLAMAFILE_DEFAULT})\n\noption(GGML_CUDA                      
      \"ggml: use CUDA\"                                  OFF)\noption(GGML_MUSA                            \"ggml: use MUSA\"                                  OFF)\noption(GGML_CUDA_FORCE_MMQ                  \"ggml: use mmq kernels instead of cuBLAS\"         OFF)\noption(GGML_CUDA_FORCE_CUBLAS               \"ggml: always use cuBLAS instead of mmq kernels\"  OFF)\nset   (GGML_CUDA_PEER_MAX_BATCH_SIZE \"128\" CACHE STRING\n                                            \"ggml: max. batch size for using peer access\")\noption(GGML_CUDA_NO_PEER_COPY               \"ggml: do not use peer to peer copies\"            OFF)\noption(GGML_CUDA_NO_VMM                     \"ggml: do not try to use CUDA VMM\"                OFF)\noption(GGML_CUDA_FA                         \"ggml: compile ggml FlashAttention CUDA kernels\"  ON)\noption(GGML_CUDA_FA_ALL_QUANTS              \"ggml: compile all quants for FlashAttention\"     OFF)\noption(GGML_CUDA_GRAPHS                     \"ggml: use CUDA graphs (llama.cpp only)\"          ${GGML_CUDA_GRAPHS_DEFAULT})\nset   (GGML_CUDA_COMPRESSION_MODE \"size\" CACHE STRING\n                                            \"ggml: cuda link binary compression mode; requires cuda 12.8+\")\nset_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS \"none;speed;balance;size\")\n\noption(GGML_HIP                             \"ggml: use HIP\"                                   OFF)\noption(GGML_HIP_GRAPHS                      \"ggml: use HIP graph, experimental, slow\"         OFF)\noption(GGML_HIP_NO_VMM                      \"ggml: do not try to use HIP VMM\"                 ON)\noption(GGML_HIP_ROCWMMA_FATTN               \"ggml: enable rocWMMA for FlashAttention\"         OFF)\noption(GGML_HIP_MMQ_MFMA                    \"ggml: enable MFMA MMA for CDNA in MMQ\"           ON)\noption(GGML_HIP_EXPORT_METRICS              \"ggml: enable kernel perf metrics output\"         OFF)\noption(GGML_MUSA_GRAPHS                     \"ggml: use MUSA graph, 
experimental, unstable\"    OFF)\noption(GGML_MUSA_MUDNN_COPY                 \"ggml: enable muDNN for accelerated copy\"         OFF)\noption(GGML_VULKAN                          \"ggml: use Vulkan\"                                OFF)\noption(GGML_VULKAN_CHECK_RESULTS            \"ggml: run Vulkan op checks\"                      OFF)\noption(GGML_VULKAN_DEBUG                    \"ggml: enable Vulkan debug output\"                OFF)\noption(GGML_VULKAN_MEMORY_DEBUG             \"ggml: enable Vulkan memory debug output\"         OFF)\noption(GGML_VULKAN_SHADER_DEBUG_INFO        \"ggml: enable Vulkan shader debug info\"           OFF)\noption(GGML_VULKAN_VALIDATE                 \"ggml: enable Vulkan validation\"                  OFF)\noption(GGML_VULKAN_RUN_TESTS                \"ggml: run Vulkan tests\"                          OFF)\noption(GGML_WEBGPU                          \"ggml: use WebGPU\"                                OFF)\noption(GGML_WEBGPU_DEBUG                    \"ggml: enable WebGPU debug output\"                OFF)\noption(GGML_WEBGPU_CPU_PROFILE              \"ggml: enable WebGPU profiling (CPU)\"             OFF)\noption(GGML_WEBGPU_GPU_PROFILE              \"ggml: enable WebGPU profiling (GPU)\"             OFF)\noption(GGML_WEBGPU_JSPI                     \"ggml: use JSPI for WebGPU\"                       ON)\noption(GGML_ZDNN                            \"ggml: use zDNN\"                                  OFF)\noption(GGML_VIRTGPU                         \"ggml: use the VirtGPU/Virglrenderer API Remoting frontend\"     OFF)\noption(GGML_VIRTGPU_BACKEND                 \"ggml: build the VirtGPU/Virglrenderer API Remoting backend\"    OFF)\noption(GGML_METAL                           \"ggml: use Metal\"                                 ${GGML_METAL_DEFAULT})\noption(GGML_METAL_NDEBUG                    \"ggml: disable Metal debugging\"                   OFF)\noption(GGML_METAL_SHADER_DEBUG              \"ggml: compile Metal with 
-fno-fast-math\"         OFF)\noption(GGML_METAL_EMBED_LIBRARY             \"ggml: embed Metal library\"                       ${GGML_METAL})\nset   (GGML_METAL_MACOSX_VERSION_MIN \"\" CACHE STRING\n                                            \"ggml: metal minimum macOS version\")\nset   (GGML_METAL_STD \"\" CACHE STRING       \"ggml: metal standard version (-std flag)\")\noption(GGML_OPENMP                          \"ggml: use OpenMP\"                                ON)\noption(GGML_RPC                             \"ggml: use RPC\"                                   OFF)\noption(GGML_SYCL                            \"ggml: use SYCL\"                                  OFF)\noption(GGML_SYCL_F16                        \"ggml: use 16 bit floats for sycl calculations\"   OFF)\noption(GGML_SYCL_GRAPH                      \"ggml: enable graphs in the SYCL backend\"         ON)\noption(GGML_SYCL_DNN                        \"ggml: enable oneDNN in the SYCL backend\"         ON)\nset   (GGML_SYCL_TARGET \"INTEL\" CACHE STRING\n                                            \"ggml: sycl target device\")\nset   (GGML_SYCL_DEVICE_ARCH \"\" CACHE STRING\n                                            \"ggml: sycl device architecture\")\n\noption(GGML_OPENVINO                        \"ggml: use OPENVINO\"                              OFF)\n\noption(GGML_OPENCL                          \"ggml: use OpenCL\"                                OFF)\noption(GGML_OPENCL_PROFILING                \"ggml: use OpenCL profiling (increases overhead)\" OFF)\noption(GGML_OPENCL_EMBED_KERNELS            \"ggml: embed kernels\"                             ON)\noption(GGML_OPENCL_USE_ADRENO_KERNELS       \"ggml: use optimized kernels for Adreno\"          ON)\nset   (GGML_OPENCL_TARGET_VERSION \"300\" CACHE STRING\n                                            \"ggml: OpenCL API version to target\")\n\noption(GGML_HEXAGON                         \"ggml: enable Hexagon backend\"                    
OFF)\nset(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING \"ggml: quantize group size (32, 64, or 128)\")\n\n# toolchain for vulkan-shaders-gen\nset   (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN \"\" CACHE FILEPATH \"ggml: toolchain file for vulkan-shaders-gen\")\n\noption(GGML_ZENDNN                          \"ggml: use ZenDNN\"                                OFF)\noption(ZENDNN_ROOT                          \"ggml: path to ZenDNN installation\"               \"\")\n\n# extra artifacts\noption(GGML_BUILD_TESTS    \"ggml: build tests\"    ${GGML_STANDALONE})\noption(GGML_BUILD_EXAMPLES \"ggml: build examples\" ${GGML_STANDALONE})\n\n#\n# dependencies\n#\n\nset(CMAKE_C_STANDARD 11)\nset(CMAKE_C_STANDARD_REQUIRED true)\n\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_REQUIRED true)\n\nset(THREADS_PREFER_PTHREAD_FLAG ON)\n\nfind_package(Threads REQUIRED)\n\ninclude(GNUInstallDirs)\n\n#\n# build the library\n#\n\nadd_subdirectory(src)\n\n#\n# tests and examples\n#\n\nif (GGML_BUILD_TESTS)\n    enable_testing()\n    add_subdirectory(tests)\nendif ()\n\nif (GGML_BUILD_EXAMPLES)\n    add_subdirectory(examples)\nendif ()\n\n#\n# install\n#\n\ninclude(CMakePackageConfigHelpers)\n\n# all public headers\nset(GGML_PUBLIC_HEADERS\n    include/ggml.h\n    include/ggml-cpu.h\n    include/ggml-alloc.h\n    include/ggml-backend.h\n    include/ggml-blas.h\n    include/ggml-cann.h\n    include/ggml-cpp.h\n    include/ggml-cuda.h\n    include/ggml-opt.h\n    include/ggml-metal.h\n    include/ggml-rpc.h\n    include/ggml-virtgpu.h\n    include/ggml-sycl.h\n    include/ggml-vulkan.h\n    include/ggml-webgpu.h\n    include/ggml-zendnn.h\n    include/ggml-openvino.h\n    include/gguf.h)\n\nset_target_properties(ggml PROPERTIES PUBLIC_HEADER \"${GGML_PUBLIC_HEADERS}\")\n#if (GGML_METAL)\n#    set_target_properties(ggml PROPERTIES RESOURCE \"${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal\")\n#endif()\ninstall(TARGETS ggml LIBRARY PUBLIC_HEADER)\ninstall(TARGETS ggml-base LIBRARY)\n\nif 
(GGML_STANDALONE)\n    configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in\n        ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc\n        @ONLY)\n\n    install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc\n        DESTINATION share/pkgconfig)\nendif()\n\n#\n# Create CMake package\n#\n\n\n\n# Capture variables prefixed with GGML_.\n\nset(variable_set_statements\n\"\n####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() #######\n####### Any changes to this file will be overwritten by the next CMake run        #######\n\n\")\n\nset(GGML_SHARED_LIB ${BUILD_SHARED_LIBS})\n\nget_cmake_property(all_variables VARIABLES)\nforeach(variable_name IN LISTS all_variables)\n    if(variable_name MATCHES \"^GGML_\")\n        string(REPLACE \";\" \"\\\\;\"\n               variable_value \"${${variable_name}}\")\n\n        set(variable_set_statements\n            \"${variable_set_statements}set(${variable_name} \\\"${variable_value}\\\")\\n\")\n    endif()\nendforeach()\n\nset(GGML_VARIABLES_EXPANDED ${variable_set_statements})\n\n# Create the CMake package and set install location.\n\nset(GGML_INSTALL_VERSION ${GGML_VERSION})\nset(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH \"Location of header  files\")\nset(GGML_LIB_INSTALL_DIR     ${CMAKE_INSTALL_LIBDIR}     CACHE PATH \"Location of library files\")\nset(GGML_BIN_INSTALL_DIR     ${CMAKE_INSTALL_BINDIR}     CACHE PATH \"Location of binary  files\")\n\nconfigure_package_config_file(\n        ${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in\n        ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake\n    INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml\n    PATH_VARS GGML_INCLUDE_INSTALL_DIR\n              GGML_LIB_INSTALL_DIR\n              GGML_BIN_INSTALL_DIR)\n\nwrite_basic_package_version_file(\n        ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake\n    VERSION ${GGML_INSTALL_VERSION}\n    COMPATIBILITY SameMajorVersion)\n\ntarget_compile_definitions(ggml-base PRIVATE\n    
GGML_VERSION=\"${GGML_INSTALL_VERSION}\"\n    GGML_COMMIT=\"${GGML_BUILD_COMMIT}\"\n)\nmessage(STATUS \"ggml version: ${GGML_INSTALL_VERSION}\")\nmessage(STATUS \"ggml commit:  ${GGML_BUILD_COMMIT}\")\n\ninstall(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake\n              ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake\n        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)\n\nif (MSVC)\n    set(MSVC_WARNING_FLAGS\n        /wd4005  # Macro redefinition\n        /wd4244  # Conversion from one type to another type, possible loss of data\n        /wd4267  # Conversion from 'size_t' to a smaller type, possible loss of data\n        /wd4305  # Conversion from 'type1' to 'type2', possible loss of data\n        /wd4566  # Conversion from 'char' to 'wchar_t', possible loss of data\n        /wd4996  # Disable POSIX deprecation warnings\n        /wd4702  # Unreachable code warnings\n    )\n    set(MSVC_COMPILE_OPTIONS\n        \"$<$<COMPILE_LANGUAGE:C>:/utf-8>\"\n        \"$<$<COMPILE_LANGUAGE:CXX>:/utf-8>\"\n    )\n    function(configure_msvc_target target_name)\n        if(TARGET ${target_name})\n            target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})\n            target_compile_options(${target_name} PRIVATE ${MSVC_COMPILE_OPTIONS})\n        endif()\n    endfunction()\n\n    configure_msvc_target(ggml-base)\n    configure_msvc_target(ggml)\n    configure_msvc_target(ggml-cpu)\n    configure_msvc_target(ggml-cpu-x64)\n    configure_msvc_target(ggml-cpu-sse42)\n    configure_msvc_target(ggml-cpu-sandybridge)\n    # __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512\n    # skipping            ggml-cpu-ivybridge\n    # skipping            ggml-cpu-piledriver\n    configure_msvc_target(ggml-cpu-haswell)\n    configure_msvc_target(ggml-cpu-skylakex)\n    configure_msvc_target(ggml-cpu-cannonlake)\n    configure_msvc_target(ggml-cpu-cascadelake)\n    configure_msvc_target(ggml-cpu-icelake)\n    # MSVC 2022 
doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!\n    # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170\n    # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170\n    # skipping            ggml-cpu-cooperlake\n    # skipping            ggml-cpu-zen4\n    configure_msvc_target(ggml-cpu-alderlake)\n    # MSVC doesn't support AMX\n    # skipping            ggml-cpu-sapphirerapids\n\n    if (GGML_BUILD_EXAMPLES)\n        configure_msvc_target(common-ggml)\n        configure_msvc_target(common)\n\n        configure_msvc_target(mnist-common)\n        configure_msvc_target(mnist-eval)\n        configure_msvc_target(mnist-train)\n\n        configure_msvc_target(gpt-2-ctx)\n        configure_msvc_target(gpt-2-alloc)\n        configure_msvc_target(gpt-2-backend)\n        configure_msvc_target(gpt-2-sched)\n        configure_msvc_target(gpt-2-quantize)\n        configure_msvc_target(gpt-2-batched)\n\n        configure_msvc_target(gpt-j)\n        configure_msvc_target(gpt-j-quantize)\n\n        configure_msvc_target(magika)\n        configure_msvc_target(yolov3-tiny)\n        configure_msvc_target(sam)\n\n        configure_msvc_target(simple-ctx)\n        configure_msvc_target(simple-backend)\n    endif()\n\n    if (GGML_BUILD_TESTS)\n        configure_msvc_target(test-mul-mat)\n        configure_msvc_target(test-arange)\n        configure_msvc_target(test-backend-ops)\n        configure_msvc_target(test-cont)\n        configure_msvc_target(test-conv-transpose)\n        configure_msvc_target(test-conv-transpose-1d)\n        configure_msvc_target(test-conv1d)\n        configure_msvc_target(test-conv2d)\n        configure_msvc_target(test-conv2d-dw)\n        configure_msvc_target(test-customop)\n        configure_msvc_target(test-dup)\n        configure_msvc_target(test-opt)\n        configure_msvc_target(test-pool)\n    endif ()\nendif()\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "Please use [llama.cpp's contribution guidelines](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md) for this project.\n\n*For changes to the core `ggml` library (including to the CMake build system), please open a PR in https://github.com/ggml-org/llama.cpp. Doing so will make your PR more visible, better tested and more likely to be reviewed.*\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023-2026 The ggml authors\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# ggml\n\n[Manifesto](https://github.com/ggerganov/llama.cpp/discussions/205)\n\nTensor library for machine learning\n\n***Note that this project is under active development. \\\nSome of the development is currently happening in the [llama.cpp](https://github.com/ggerganov/llama.cpp) and [whisper.cpp](https://github.com/ggerganov/whisper.cpp) repos***\n\n## Features\n\n- Low-level cross-platform implementation\n- Integer quantization support\n- Broad hardware support\n- Automatic differentiation\n- ADAM and L-BFGS optimizers\n- No third-party dependencies\n- Zero memory allocations during runtime\n\n## Build\n\n```bash\ngit clone https://github.com/ggml-org/ggml\ncd ggml\n\n# install python dependencies in a virtual environment\npython3.10 -m venv .venv\nsource .venv/bin/activate\npip install -r requirements.txt\n\n# build the examples\nmkdir build && cd build\ncmake ..\ncmake --build . --config Release -j 8\n```\n\n## GPT inference (example)\n\n```bash\n# run the GPT-2 small 117M model\n../examples/gpt-2/download-ggml-model.sh 117M\n./bin/gpt-2-backend -m models/gpt-2-117M/ggml-model.bin -p \"This is an example\"\n```\n\nFor more information, checkout the corresponding programs in the [examples](examples) folder.\n\n## Resources\n\n- [Introduction to ggml](https://huggingface.co/blog/introduction-to-ggml)\n- [The GGUF file format](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md)\n"
  },
  {
    "path": "ci/run.sh",
    "content": "#/bin/bash\n#\n# sample usage:\n#\n# mkdir tmp\n#\n# # CPU-only build\n# bash ./ci/run.sh ./tmp/results ./tmp/mnt\n#\n# # with CUDA support\n# GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt\n#\n# # With SYCL support\n# GG_BUILD_SYCL=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt\n#\n\nif [ -z \"$2\" ]; then\n    echo \"usage: $0 <output-dir> <mnt-dir>\"\n    exit 1\nfi\n\nmkdir -p \"$1\"\nmkdir -p \"$2\"\n\nOUT=$(realpath \"$1\")\nMNT=$(realpath \"$2\")\n\nrm -v $OUT/*.log\nrm -v $OUT/*.exit\nrm -v $OUT/*.md\n\nsd=`dirname $0`\ncd $sd/../\nSRC=`pwd`\n\nCMAKE_EXTRA=\"\"\nCTEST_EXTRA=\"\"\n\nif [ ! -z ${GG_BUILD_METAL} ]; then\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_METAL=ON\"\nfi\n\nif [ ! -z ${GG_BUILD_CUDA} ]; then\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_CUDA=ON\"\n\n    if command -v nvidia-smi >/dev/null 2>&1; then\n        CUDA_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '.')\n        if [[ -n \"$CUDA_ARCH\" && \"$CUDA_ARCH\" =~ ^[0-9]+$ ]]; then\n            CMAKE_EXTRA=\"${CMAKE_EXTRA} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH}\"\n        else\n            echo \"Warning: Using fallback CUDA architectures\"\n            CMAKE_EXTRA=\"${CMAKE_EXTRA} -DCMAKE_CUDA_ARCHITECTURES=61;70;75;80;86;89\"\n        fi\n    else\n        echo \"Error: nvidia-smi not found, cannot build with CUDA\"\n        exit 1\n    fi\nfi\n\nif [ ! -z ${GG_BUILD_ROCM} ]; then\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_HIP=ON\"\n    if [ -z ${GG_BUILD_AMDGPU_TARGETS} ]; then\n        echo \"Missing GG_BUILD_AMDGPU_TARGETS, please set it to your GPU architecture (e.g. gfx90a, gfx1100, etc.)\"\n        exit 1\n    fi\n\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DAMDGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}\"\nfi\n\nif [ ! 
-z ${GG_BUILD_SYCL} ]; then\n    if [ -z ${ONEAPI_ROOT} ]; then\n        echo \"Not detected ONEAPI_ROOT, please install oneAPI base toolkit and enable it by:\"\n        echo \"source /opt/intel/oneapi/setvars.sh\"\n        exit 1\n    fi\n    # Use only main GPU\n    export ONEAPI_DEVICE_SELECTOR=\"level_zero:0\"\n    # Enable sysman for correct memory reporting\n    export ZES_ENABLE_SYSMAN=1\n    # to circumvent precision issues on CPY operations\n    export SYCL_PROGRAM_COMPILE_OPTIONS=\"-cl-fp32-correctly-rounded-divide-sqrt\"\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON\"\nfi\n\nif [ ! -z ${GG_BUILD_VULKAN} ]; then\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_VULKAN=1\"\n\n    # if on Mac, disable METAL\n    if [[ \"$OSTYPE\" == \"darwin\"* ]]; then\n        CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_METAL=OFF -DGGML_BLAS=OFF\"\n    fi\n\nfi\n\nif [ ! -z ${GG_BUILD_WEBGPU} ]; then\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_WEBGPU=1\"\nfi\n\nif [ ! -z ${GG_BUILD_MUSA} ]; then\n    # Use qy1 by default (MTT S80)\n    MUSA_ARCH=${MUSA_ARCH:-21}\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}\"\nfi\n\nif [ ! 
-z ${GG_BUILD_NO_SVE} ]; then\n    # arm 9 and newer enables sve by default, adjust these flags depending on the cpu used\n    CMAKE_EXTRA=\"${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm\"\nfi\n\n## helpers\n\n# download a file if it does not exist or if it is outdated\nfunction gg_wget {\n    local out=$1\n    local url=$2\n\n    local cwd=`pwd`\n\n    mkdir -p $out\n    cd $out\n\n    # should not re-download if file is the same\n    wget -nv -N $url\n\n    cd $cwd\n}\n\nfunction gg_printf {\n    printf -- \"$@\" >> $OUT/README.md\n}\n\nfunction gg_run {\n    ci=$1\n\n    set -o pipefail\n    set -x\n\n    gg_run_$ci | tee $OUT/$ci.log\n    cur=$?\n    echo \"$cur\" > $OUT/$ci.exit\n\n    set +x\n    set +o pipefail\n\n    gg_sum_$ci\n\n    ret=$((ret | cur))\n}\n\n## ci\n\n# ctest_debug\n\nfunction gg_run_ctest_debug {\n    cd ${SRC}\n\n    rm -rf build-ci-debug && mkdir build-ci-debug && cd build-ci-debug\n\n    set -e\n\n    (time cmake -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} ..     ) 2>&1 | tee -a $OUT/${ci}-cmake.log\n    (time make -j$(nproc)                                      ) 2>&1 | tee -a $OUT/${ci}-make.log\n\n    (time ctest ${CTEST_EXTRA} --output-on-failure -E \"test-opt|test-backend-ops\" ) 2>&1 | tee -a $OUT/${ci}-ctest.log\n\n    set +e\n}\n\nfunction gg_sum_ctest_debug {\n    gg_printf '### %s\\n\\n' \"${ci}\"\n\n    gg_printf 'Runs ctest in debug mode\\n'\n    gg_printf '- status: %s\\n' \"$(cat $OUT/${ci}.exit)\"\n    gg_printf '```\\n'\n    gg_printf '%s\\n' \"$(cat $OUT/${ci}-ctest.log)\"\n    gg_printf '```\\n'\n    gg_printf '\\n'\n}\n\n# ctest_release\n\nfunction gg_run_ctest_release {\n    cd ${SRC}\n\n    rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release\n\n    set -e\n\n    (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} ..   
) 2>&1 | tee -a $OUT/${ci}-cmake.log\n    (time make -j$(nproc)                                      ) 2>&1 | tee -a $OUT/${ci}-make.log\n\n    if [ -z $GG_BUILD_LOW_PERF ]; then\n        (time ctest ${CTEST_EXTRA} --output-on-failure ) 2>&1 | tee -a $OUT/${ci}-ctest.log\n    else\n        (time ctest ${CTEST_EXTRA} --output-on-failure -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log\n    fi\n\n    set +e\n}\n\nfunction gg_sum_ctest_release {\n    gg_printf '### %s\\n\\n' \"${ci}\"\n\n    gg_printf 'Runs ctest in release mode\\n'\n    gg_printf '- status: %s\\n' \"$(cat $OUT/${ci}.exit)\"\n    gg_printf '```\\n'\n    gg_printf '%s\\n' \"$(cat $OUT/${ci}-ctest.log)\"\n    gg_printf '```\\n'\n}\n\n# gpt_2\n\nfunction gg_run_gpt_2 {\n    cd ${SRC}\n\n    gg_wget models-mnt/gpt-2 https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin\n\n    cd build-ci-release\n\n    set -e\n\n    model=\"../models-mnt/gpt-2/ggml-model-gpt-2-117M.bin\"\n    prompts=\"../examples/prompts/gpt-2.txt\"\n\n    (time ./bin/gpt-2-backend --model ${model} -s 1234 -n 64 -tt ${prompts}                       ) 2>&1 | tee -a $OUT/${ci}-tg.log\n    (time ./bin/gpt-2-backend --model ${model} -s 1234 -n 64 -p \"I believe the meaning of life is\") 2>&1 | tee -a $OUT/${ci}-tg.log\n    (time ./bin/gpt-2-sched   --model ${model} -s 1234 -n 64 -p \"I believe the meaning of life is\") 2>&1 | tee -a $OUT/${ci}-tg.log\n\n    (time ./bin/gpt-2-batched --model ${model} -s 1234 -n 64 -np 8 -p \"I believe the meaning of life is\") 2>&1 | tee -a $OUT/${ci}-tg.log\n\n    set +e\n}\n\nfunction gg_sum_gpt_2 {\n    gg_printf '### %s\\n\\n' \"${ci}\"\n\n    gg_printf 'Runs short GPT-2 text generation\\n'\n    gg_printf '- status: %s\\n' \"$(cat $OUT/${ci}.exit)\"\n    gg_printf '```\\n'\n    gg_printf '%s\\n' \"$(cat $OUT/${ci}-tg.log)\"\n    gg_printf '```\\n'\n}\n\n# TODO: update\n## mnist\n#\n#function gg_run_mnist {\n#    cd ${SRC}\n#\n#    cd build-ci-release\n#\n#    set -e\n#\n#    mkdir 
-p models/mnist\n#    python3 ../examples/mnist/convert-h5-to-ggml.py ../examples/mnist/models/mnist/mnist_model.state_dict\n#\n#    model_f32=\"./models/mnist/ggml-model-f32.bin\"\n#    samples=\"../examples/mnist/models/mnist/t10k-images.idx3-ubyte\"\n#\n#    # first command runs and exports \"mnist.ggml\", the second command runs the exported model\n#\n#    (time ./bin/mnist     ${model_f32} ${samples} ) 2>&1 | tee -a $OUT/${ci}-mnist.log\n#    (time ./bin/mnist-cpu ./mnist.ggml ${samples} ) 2>&1 | tee -a $OUT/${ci}-mnist.log\n#\n#    set +e\n#}\n#\n#function gg_sum_mnist {\n#    gg_printf '### %s\\n\\n' \"${ci}\"\n#\n#    gg_printf 'MNIST\\n'\n#    gg_printf '- status: %s\\n' \"$(cat $OUT/${ci}.exit)\"\n#    gg_printf '```\\n'\n#    gg_printf '%s\\n' \"$(cat $OUT/${ci}-mnist.log)\"\n#    gg_printf '```\\n'\n#}\n\n# sam\n\nfunction gg_run_sam {\n    cd ${SRC}\n\n    gg_wget models-mnt/sam/ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\n    gg_wget models-mnt/sam/ https://raw.githubusercontent.com/YavorGIvanov/sam.cpp/ceafb7467bff7ec98e0c4f952e58a9eb8fd0238b/img.jpg\n\n    cd build-ci-release\n\n    set -e\n\n    path_models=\"../models-mnt/sam/\"\n    model_f16=\"${path_models}/ggml-model-f16.bin\"\n    img_0=\"${path_models}/img.jpg\"\n\n    python3 ../examples/sam/convert-pth-to-ggml.py ${path_models}/sam_vit_b_01ec64.pth ${path_models}/ 1\n\n    # Test default parameters\n    (time ./bin/sam -m ${model_f16} -i ${img_0} -st 0.925 ) 2>&1 | tee -a $OUT/${ci}-main.log\n    grep -q \"point prompt\" $OUT/${ci}-main.log\n    grep -q \"bbox (371, 436), (144, 168)\" $OUT/${ci}-main.log ||\n    grep -q \"bbox (370, 439), (144, 168)\" $OUT/${ci}-main.log\n\n    # Test box prompt and single mask output\n    (time ./bin/sam -m ${model_f16} -i ${img_0} -st 0.925 -b 368,144,441,173 -sm) 2>&1 | tee -a $OUT/${ci}-main.log\n    grep -q \"box prompt\" $OUT/${ci}-main.log\n    grep -q \"bbox (370, 439), (144, 169)\" $OUT/${ci}-main.log ||\n    grep -q 
\"bbox (370, 439), (144, 168)\" $OUT/${ci}-main.log\n\n    set +e\n}\n\nfunction gg_sum_sam {\n    gg_printf '### %s\\n\\n' \"${ci}\"\n\n    gg_printf 'Run SAM\\n'\n    gg_printf '- status: %s\\n' \"$(cat $OUT/${ci}.exit)\"\n    gg_printf '```\\n'\n    gg_printf '%s\\n' \"$(cat $OUT/${ci}-main.log)\"\n    gg_printf '```\\n'\n}\n\n# yolo\n\nfunction gg_run_yolo {\n    cd ${SRC}\n\n    gg_wget models-mnt/yolo/ https://huggingface.co/ggml-org/models/resolve/main/yolo/yolov3-tiny.weights\n    gg_wget models-mnt/yolo/ https://huggingface.co/ggml-org/models/resolve/main/yolo/dog.jpg\n\n    cd build-ci-release\n    cp -r ../examples/yolo/data .\n\n    set -e\n\n    path_models=\"../models-mnt/yolo/\"\n\n    python3 ../examples/yolo/convert-yolov3-tiny.py ${path_models}/yolov3-tiny.weights\n\n    (time ./bin/yolov3-tiny -m yolov3-tiny.gguf -i ${path_models}/dog.jpg ) 2>&1 | tee -a $OUT/${ci}-main.log\n\n    grep -qE \"dog: (55|56|57|58|59)%\" $OUT/${ci}-main.log\n    grep -qE \"car: (50|51|52|53|54)%\" $OUT/${ci}-main.log\n    grep -qE \"truck: (54|55|56|57|58)%\" $OUT/${ci}-main.log\n    grep -qE \"bicycle: (57|58|59|60|61)%\" $OUT/${ci}-main.log\n\n    set +e\n}\n\nfunction gg_sum_yolo {\n    gg_printf '### %s\\n\\n' \"${ci}\"\n\n    gg_printf 'Run YOLO\\n'\n    gg_printf '- status: %s\\n' \"$(cat $OUT/${ci}.exit)\"\n    gg_printf '```\\n'\n    gg_printf '%s\\n' \"$(cat $OUT/${ci}-main.log)\"\n    gg_printf '```\\n'\n}\n\n## main\n\nif true ; then\n    # Create symlink: ./ggml/models-mnt -> $MNT/models/models-mnt\n    rm -rf ${SRC}/models-mnt\n    mnt_models=${MNT}/models\n    mkdir -p ${mnt_models}\n    ln -sfn ${mnt_models} ${SRC}/models-mnt\n\n    # Create a fresh python3 venv and enter it\n    if ! 
python3 -m venv \"$MNT/venv\"; then\n        echo \"Error: Failed to create Python virtual environment at $MNT/venv.\"\n        exit 1\n    fi\n    source \"$MNT/venv/bin/activate\"\n\n    pip install -r ${SRC}/requirements.txt --disable-pip-version-check\nfi\n\n\nret=0\n\ntest $ret -eq 0 && gg_run ctest_debug\ntest $ret -eq 0 && gg_run ctest_release\n\ntest $ret -eq 0 && gg_run gpt_2\n#test $ret -eq 0 && gg_run mnist\ntest $ret -eq 0 && gg_run sam\ntest $ret -eq 0 && gg_run yolo\n\nif [ -z $GG_BUILD_LOW_PERF ]; then\n    # extra steps for runners that are NOT low-perf (GG_BUILD_LOW_PERF unset); currently just prints the date\n    date\nfi\n\ncat $OUT/README.md\n\nexit $ret\n"
  },
  {
    "path": "cmake/GitVars.cmake",
    "content": "find_package(Git)\n\n# the commit's SHA1\nexecute_process(COMMAND\n    \"${GIT_EXECUTABLE}\" describe --match=NeVeRmAtCh --always --abbrev=8\n    WORKING_DIRECTORY \"${CMAKE_SOURCE_DIR}\"\n    OUTPUT_VARIABLE GIT_SHA1\n    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)\n\n# the date of the commit\nexecute_process(COMMAND\n    \"${GIT_EXECUTABLE}\" log -1 --format=%ad --date=local\n    WORKING_DIRECTORY \"${CMAKE_SOURCE_DIR}\"\n    OUTPUT_VARIABLE GIT_DATE\n    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)\n\n# the subject of the commit\nexecute_process(COMMAND\n    \"${GIT_EXECUTABLE}\" log -1 --format=%s\n    WORKING_DIRECTORY \"${CMAKE_SOURCE_DIR}\"\n    OUTPUT_VARIABLE GIT_COMMIT_SUBJECT\n    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)\n"
  },
  {
    "path": "cmake/common.cmake",
    "content": "function(ggml_get_flags CCID CCVER)\n    set(C_FLAGS \"\")\n    set(CXX_FLAGS \"\")\n\n    if (CCID MATCHES \"Clang\")\n        set(C_FLAGS   -Wunreachable-code-break -Wunreachable-code-return)\n        set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)\n\n        if (\n            (CCID STREQUAL \"Clang\"      AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR\n            (CCID STREQUAL \"AppleClang\" AND CCVER VERSION_GREATER_EQUAL 7.3.0)\n        )\n            list(APPEND C_FLAGS -Wdouble-promotion)\n        endif()\n    elseif (CCID STREQUAL \"GNU\")\n        set(C_FLAGS   -Wdouble-promotion)\n        set(CXX_FLAGS -Wno-array-bounds)\n\n        if (CCVER VERSION_GREATER_EQUAL 8.1.0)\n            list(APPEND CXX_FLAGS -Wextra-semi)\n        endif()\n    endif()\n\n    set(GF_C_FLAGS   ${C_FLAGS}   PARENT_SCOPE)\n    set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)\nendfunction()\n\nfunction(ggml_get_system_arch)\n    if (CMAKE_OSX_ARCHITECTURES      STREQUAL \"arm64\" OR\n        CMAKE_GENERATOR_PLATFORM_LWR STREQUAL \"arm64\" OR\n        (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND\n            CMAKE_SYSTEM_PROCESSOR MATCHES \"^(aarch64|arm.*|ARM64)$\"))\n        set(GGML_SYSTEM_ARCH \"ARM\" PARENT_SCOPE)\n    elseif (CMAKE_OSX_ARCHITECTURES STREQUAL \"x86_64\" OR\n            CMAKE_GENERATOR_PLATFORM_LWR MATCHES \"^(x86_64|i686|amd64|x64|win32)$\" OR\n            (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND\n            CMAKE_SYSTEM_PROCESSOR MATCHES \"^(x86_64|i686|AMD64|amd64)$\"))\n        set(GGML_SYSTEM_ARCH \"x86\" PARENT_SCOPE)\n    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc|power\")\n        set(GGML_SYSTEM_ARCH \"PowerPC\" PARENT_SCOPE)\n    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"loongarch64\")\n        set(GGML_SYSTEM_ARCH \"loongarch64\"  PARENT_SCOPE)\n    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"riscv64\")\n        
set(GGML_SYSTEM_ARCH \"riscv64\" PARENT_SCOPE)\n    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"s390x\")\n        set(GGML_SYSTEM_ARCH \"s390x\" PARENT_SCOPE)\n    else()\n        set(GGML_SYSTEM_ARCH \"UNKNOWN\" PARENT_SCOPE)\n    endif()\nendfunction()\n"
  },
  {
    "path": "cmake/ggml-config.cmake.in",
    "content": "@PACKAGE_INIT@\n\n@GGML_VARIABLES_EXPANDED@\n\n# Find all dependencies before creating any target.\ninclude(CMakeFindDependencyMacro)\nfind_dependency(Threads)\nif (NOT GGML_SHARED_LIB)\n    set(GGML_CPU_INTERFACE_LINK_LIBRARIES \"\")\n    set(GGML_CPU_INTERFACE_LINK_OPTIONS   \"\")\n\n    if (APPLE AND GGML_ACCELERATE)\n        find_library(ACCELERATE_FRAMEWORK Accelerate)\n        if(NOT ACCELERATE_FRAMEWORK)\n            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)\n            return()\n        endif()\n        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})\n    endif()\n\n    if (GGML_OPENMP_ENABLED)\n        find_dependency(OpenMP)\n        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)\n    endif()\n\n    if (GGML_CPU_HBM)\n        find_library(memkind memkind)\n        if(NOT memkind)\n            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)\n            return()\n        endif()\n        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)\n    endif()\n\n    if (GGML_BLAS)\n        find_dependency(BLAS)\n        list(APPEND GGML_BLAS_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})\n        list(APPEND GGML_BLAS_INTERFACE_LINK_OPTIONS   ${BLAS_LINKER_FLAGS})\n    endif()\n\n    if (GGML_CUDA)\n        set(GGML_CUDA_INTERFACE_LINK_LIBRARIES \"\")\n        find_dependency(CUDAToolkit)\n        if (GGML_STATIC)\n            list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cudart_static>)\n            if (WIN32)\n                list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas> $<LINK_ONLY:CUDA::cublasLt>)\n            else()\n                list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas_static> $<LINK_ONLY:CUDA::cublasLt_static>)\n            endif()\n        endif()\n        if (NOT GGML_CUDA_NO_VMM)\n            list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cuda_driver>)\n        endif()\n    endif()\n\n    if 
(GGML_METAL)\n        find_library(FOUNDATION_LIBRARY Foundation)\n        find_library(METAL_FRAMEWORK    Metal)\n        find_library(METALKIT_FRAMEWORK MetalKit)\n        if(NOT FOUNDATION_LIBRARY OR NOT METAL_FRAMEWORK OR NOT METALKIT_FRAMEWORK)\n            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)\n            return()\n        endif()\n        set(GGML_METAL_INTERFACE_LINK_LIBRARIES\n            ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})\n    endif()\n\n    if (GGML_OPENCL)\n        find_dependency(OpenCL)\n        set(GGML_OPENCL_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:OpenCL::OpenCL>)\n    endif()\n\n    if (GGML_VULKAN)\n        find_dependency(Vulkan)\n        set(GGML_VULKAN_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:Vulkan::Vulkan>)\n    endif()\n\n    if (GGML_HIP)\n        find_dependency(hip)\n        find_dependency(hipblas)\n        find_dependency(rocblas)\n        set(GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)\n    endif()\n\n    if (GGML_SYCL)\n        set(GGML_SYCL_INTERFACE_LINK_LIBRARIES \"\")\n        find_package(DNNL)\n        if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL \"INTEL\")\n            list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)\n        endif()\n        if (WIN32)\n            find_dependency(IntelSYCL)\n            find_dependency(MKL)\n            list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)\n        endif()\n    endif()\nendif()\n\nset_and_check(GGML_INCLUDE_DIR \"@PACKAGE_GGML_INCLUDE_INSTALL_DIR@\")\nset_and_check(GGML_LIB_DIR \"@PACKAGE_GGML_LIB_INSTALL_DIR@\")\n#set_and_check(GGML_BIN_DIR \"@PACKAGE_GGML_BIN_INSTALL_DIR@\")\n\nif(NOT TARGET ggml::ggml)\n    find_package(Threads REQUIRED)\n\n    find_library(GGML_LIBRARY ggml\n        REQUIRED\n        HINTS ${GGML_LIB_DIR}\n        NO_CMAKE_FIND_ROOT_PATH)\n\n    add_library(ggml::ggml UNKNOWN IMPORTED)\n    set_target_properties(ggml::ggml\n        PROPERTIES\n           
 IMPORTED_LOCATION \"${GGML_LIBRARY}\")\n\n    find_library(GGML_BASE_LIBRARY ggml-base\n        REQUIRED\n        HINTS ${GGML_LIB_DIR}\n        NO_CMAKE_FIND_ROOT_PATH)\n\n    add_library(ggml::ggml-base UNKNOWN IMPORTED)\n    set_target_properties(ggml::ggml-base\n        PROPERTIES\n            IMPORTED_LOCATION \"${GGML_BASE_LIBRARY}\")\n\n    set(_ggml_all_targets \"\")\n    if (NOT GGML_BACKEND_DL)\n        foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})\n            string(REPLACE \"-\" \"_\" _ggml_backend_pfx \"${_ggml_backend}\")\n            string(TOUPPER \"${_ggml_backend_pfx}\" _ggml_backend_pfx)\n\n            find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}\n                REQUIRED\n                HINTS ${GGML_LIB_DIR}\n                NO_CMAKE_FIND_ROOT_PATH)\n\n            message(STATUS \"Found ${${_ggml_backend_pfx}_LIBRARY}\")\n\n            add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)\n            set_target_properties(ggml::${_ggml_backend}\n                PROPERTIES\n                    INTERFACE_INCLUDE_DIRECTORIES \"${GGML_INCLUDE_DIR}\"\n                    IMPORTED_LINK_INTERFACE_LANGUAGES \"CXX\"\n                    IMPORTED_LOCATION \"${${_ggml_backend_pfx}_LIBRARY}\"\n                    INTERFACE_COMPILE_FEATURES c_std_90\n                    POSITION_INDEPENDENT_CODE ON)\n\n            string(REGEX MATCH \"^ggml-cpu\" is_cpu_variant \"${_ggml_backend}\")\n            if(is_cpu_variant)\n                list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES \"ggml::ggml-base\")\n                set_target_properties(ggml::${_ggml_backend}\n                PROPERTIES\n                    INTERFACE_LINK_LIBRARIES \"${GGML_CPU_INTERFACE_LINK_LIBRARIES}\")\n\n                if(GGML_CPU_INTERFACE_LINK_OPTIONS)\n                    set_target_properties(ggml::${_ggml_backend}\n                        PROPERTIES\n                            INTERFACE_LINK_OPTIONS \"${GGML_CPU_INTERFACE_LINK_OPTIONS}\")\n                
endif()\n\n            else()\n                list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES \"ggml::ggml-base\")\n                set_target_properties(ggml::${_ggml_backend}\n                    PROPERTIES\n                        INTERFACE_LINK_LIBRARIES \"${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}\")\n\n                if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)\n                    set_target_properties(ggml::${_ggml_backend}\n                        PROPERTIES\n                            INTERFACE_LINK_OPTIONS \"${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}\")\n                endif()\n            endif()\n\n            list(APPEND _ggml_all_targets ggml::${_ggml_backend})\n        endforeach()\n    endif()\n\n    list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base \"${_ggml_all_targets}\")\n    set_target_properties(ggml::ggml\n        PROPERTIES\n            INTERFACE_LINK_LIBRARIES \"${GGML_INTERFACE_LINK_LIBRARIES}\")\n\n    add_library(ggml::all INTERFACE IMPORTED)\n    set_target_properties(ggml::all\n        PROPERTIES\n            INTERFACE_LINK_LIBRARIES \"${_ggml_all_targets}\")\n\nendif()\n\ncheck_required_components(ggml)\n"
  },
  {
    "path": "docs/gguf.md",
    "content": "# GGUF\n\nGGUF is a file format for storing models for inference with GGML and executors based on GGML. GGUF is a binary format that is designed for fast loading and saving of models, and for ease of reading. Models are traditionally developed using PyTorch or another framework, and then converted to GGUF for use in GGML.\n\nIt is a successor file format to GGML, GGMF and GGJT, and is designed to be unambiguous by containing all the information needed to load a model. It is also designed to be extensible, so that new information can be added to models without breaking compatibility.\n\nFor more information about the motivation behind GGUF, see [Historical State of Affairs](#historical-state-of-affairs).\n\n## Specification\n\nGGUF is a format based on the existing GGJT, but makes a few changes to the format to make it more extensible and easier to use. The following features are desired:\n\n- Single-file deployment: they can be easily distributed and loaded, and do not require any external files for additional information.\n- Extensible: new features can be added to GGML-based executors/new information can be added to GGUF models without breaking compatibility with existing models.\n- `mmap` compatibility: models can be loaded using `mmap` for fast loading and saving.\n- Easy to use: models can be easily loaded and saved using a small amount of code, with no need for external libraries, regardless of the language used.\n- Full information: all information needed to load a model is contained in the model file, and no additional information needs to be provided by the user.\n\nThe key difference between GGJT and GGUF is the use of a key-value structure for the hyperparameters (now referred to as metadata), rather than a list of untyped values. 
This allows for new metadata to be added without breaking compatibility with existing models, and to annotate the model with additional information that may be useful for inference or for identifying the model.\n\n### GGUF Naming Convention\n\nGGUF follows a naming convention of `<BaseName><SizeLabel><FineTune><Version><Encoding><Type><Shard>.gguf` where each component is delimited by a `-` if present. Ultimately this is intended to make it easier for humans to get the most important details of a model at a glance. It is not intended to be perfectly parsable in the field due to the diversity of existing gguf filenames.\n\nThe components are:\n1. **BaseName**: A descriptive name for the model base type or architecture.\n    - This can be derived from gguf metadata `general.basename`, replacing spaces with dashes.\n1. **SizeLabel**: Parameter weight class (useful for leaderboards) represented as `<expertCount>x<count><scale-prefix>`\n    - This can be derived from gguf metadata `general.size_label` if available or calculated if missing.\n    - The count may be a rounded decimal, followed by a single-letter scale prefix that stands in for the floating point exponent, as shown below:\n      - `Q`: Quadrillion parameters.\n      - `T`: Trillion parameters.\n      - `B`: Billion parameters.\n      - `M`: Million parameters.\n      - `K`: Thousand parameters.\n    - Additional `-<attributes><count><scale-prefix>` can be appended as needed to indicate other attributes of interest\n1. **FineTune**: A descriptive name for the model fine tuning goal (e.g. Chat, Instruct, etc...)\n    - This can be derived from gguf metadata `general.finetune`, replacing spaces with dashes.\n1. **Version**: (Optional) Denotes the model version number, formatted as `v<Major>.<Minor>`\n    - If model is missing a version number then assume `v1.0` (First Public Release)\n    - This can be derived from gguf metadata `general.version`\n1. 
**Encoding**: Indicates the weights encoding scheme that was applied to the model. Content, type mixture and arrangement however are determined by user code and can vary depending on project needs.\n1. **Type**: Indicates the kind of gguf file and the intended purpose for it\n    - If missing, then file is by default a typical gguf tensor model file\n    - `LoRA` : GGUF file is a LoRA adapter\n    - `vocab` : GGUF file with only vocab data and metadata\n1. **Shard**: (Optional) Indicates that the model has been split into multiple shards, formatted as `<ShardNum>-of-<ShardTotal>`.\n    - *ShardNum* : Shard position in this model. Must be 5 digits padded by zeros.\n      - Shard number always starts from `00001` onwards (e.g. the first shard is always `00001-of-XXXXX` rather than `00000-of-XXXXX`).\n    - *ShardTotal* : Total number of shards in this model. Must be 5 digits padded by zeros.\n\n\n#### Validating Above Naming Convention\n\nAt a minimum, all model files should have BaseName, SizeLabel, and Version, in order to be easily validated as a file that is in keeping with the GGUF Naming Convention. 
An example of this issue is that it is easy for Encoding to be mistaken as a FineTune if Version is omitted.\n\nTo validate you can use this regular expression `^(?<BaseName>[A-Za-z0-9\\s]*(?:(?:-(?:(?:[A-Za-z\\s][A-Za-z0-9\\s]*)|(?:[0-9\\s]*)))*))-(?:(?<SizeLabel>(?:\\d+x)?(?:\\d+\\.)?\\d+[A-Za-z](?:-[A-Za-z]+(\\d+\\.)?\\d+[A-Za-z]+)?)(?:-(?<FineTune>[A-Za-z0-9\\s-]+))?)?-(?:(?<Version>v\\d+(?:\\.\\d+)*))(?:-(?<Encoding>(?!LoRA|vocab)[\\w_]+))?(?:-(?<Type>LoRA|vocab))?(?:-(?<Shard>\\d{5}-of-\\d{5}))?\\.gguf$` which will check that you got the minimum BaseName, SizeLabel and Version present in the correct order.\n\nFor example:\n\n  * `Mixtral-8x7B-v0.1-KQ2.gguf`:\n    - Model Name: Mixtral\n    - Expert Count: 8\n    - Parameter Count: 7B\n    - Version Number: v0.1\n    - Weight Encoding Scheme: KQ2\n\n  * `Hermes-2-Pro-Llama-3-8B-F16.gguf`:\n    - Model Name: Hermes 2 Pro Llama 3\n    - Expert Count: 0\n    - Parameter Count: 8B\n    - Version Number: v1.0\n    - Weight Encoding Scheme: F16\n    - Shard: N/A\n\n  * `Grok-100B-v1.0-Q4_0-00003-of-00009.gguf`\n    - Model Name: Grok\n    - Expert Count: 0\n    - Parameter Count: 100B\n    - Version Number: v1.0\n    - Weight Encoding Scheme: Q4_0\n    - Shard: 3 out of 9 total shards\n\n\n<details><summary>Example Node.js Regex Function</summary>\n\n```js\n#!/usr/bin/env node\nconst ggufRegex = /^(?<BaseName>[A-Za-z0-9\\s]*(?:(?:-(?:(?:[A-Za-z\\s][A-Za-z0-9\\s]*)|(?:[0-9\\s]*)))*))-(?:(?<SizeLabel>(?:\\d+x)?(?:\\d+\\.)?\\d+[A-Za-z](?:-[A-Za-z]+(\\d+\\.)?\\d+[A-Za-z]+)?)(?:-(?<FineTune>[A-Za-z0-9\\s-]+))?)?-(?:(?<Version>v\\d+(?:\\.\\d+)*))(?:-(?<Encoding>(?!LoRA|vocab)[\\w_]+))?(?:-(?<Type>LoRA|vocab))?(?:-(?<Shard>\\d{5}-of-\\d{5}))?\\.gguf$/;\n\nfunction parseGGUFFilename(filename) {\n  const match = ggufRegex.exec(filename);\n  if (!match)\n    return null;\n  const {BaseName = null, SizeLabel = null, FineTune = null, Version = \"v1.0\", Encoding = null, Type = null, Shard = null} = match.groups;\n  return 
{BaseName: BaseName, SizeLabel: SizeLabel, FineTune: FineTune, Version: Version, Encoding: Encoding, Type: Type, Shard: Shard};\n}\n\nconst testCases = [\n  {filename: 'Mixtral-8x7B-v0.1-KQ2.gguf',                         expected: { BaseName: 'Mixtral',              SizeLabel: '8x7B',     FineTune: null, Version: 'v0.1',   Encoding: 'KQ2',  Type: null, Shard: null}},\n  {filename: 'Grok-100B-v1.0-Q4_0-00003-of-00009.gguf',            expected: { BaseName: 'Grok',                 SizeLabel: '100B',     FineTune: null, Version: 'v1.0',   Encoding: 'Q4_0', Type: null, Shard: \"00003-of-00009\"}},\n  {filename: 'Hermes-2-Pro-Llama-3-8B-v1.0-F16.gguf',              expected: { BaseName: 'Hermes-2-Pro-Llama-3', SizeLabel: '8B', FineTune: null, Version: 'v1.0',   Encoding: 'F16',  Type: null, Shard: null}},\n  {filename: 'Phi-3-mini-3.8B-ContextLength4k-instruct-v1.0.gguf', expected: { BaseName: 'Phi-3-mini',   SizeLabel: '3.8B-ContextLength4k', FineTune: 'instruct', Version: 'v1.0',   Encoding: null,  Type: null, Shard: null}},\n  {filename: 'not-a-known-arrangement.gguf',                       expected: null},\n];\n\ntestCases.forEach(({ filename, expected }) => {\n  const result = parseGGUFFilename(filename);\n  const passed = JSON.stringify(result) === JSON.stringify(expected);\n  console.log(`${filename}: ${passed ? \"PASS\" : \"FAIL\"}`);\n  if (!passed) {\n      console.log(result);\n      console.log(expected);\n  }\n});\n```\n\n</details>\n\n\n### File Structure\n\n![image](https://github.com/ggerganov/ggml/assets/1991296/c3623641-3a1d-408e-bfaf-1b7c4e16aa63)\n*diagram by [@mishig25](https://github.com/mishig25) (GGUF v3)*\n\nGGUF files are structured as follows. They use a global alignment specified in the `general.alignment` metadata field, referred to as `ALIGNMENT` below. 
Where required, the file is padded with `0x00` bytes to the next multiple of `general.alignment`.\n\nFields, including arrays, are written sequentially without alignment unless otherwise specified.\n\nModels are little-endian by default. They can also come in big-endian for use with big-endian computers; in this case, all values (including metadata values and tensors) will also be big-endian. At the time of writing, there is no way to determine if a model is big-endian; this may be rectified in future versions. If no additional information is provided, assume the model is little-endian.\n\n```c\nenum ggml_type: uint32_t {\n    GGML_TYPE_F32     = 0,\n    GGML_TYPE_F16     = 1,\n    GGML_TYPE_Q4_0    = 2,\n    GGML_TYPE_Q4_1    = 3,\n    // GGML_TYPE_Q4_2 = 4, support has been removed\n    // GGML_TYPE_Q4_3 = 5, support has been removed\n    GGML_TYPE_Q5_0    = 6,\n    GGML_TYPE_Q5_1    = 7,\n    GGML_TYPE_Q8_0    = 8,\n    GGML_TYPE_Q8_1    = 9,\n    GGML_TYPE_Q2_K    = 10,\n    GGML_TYPE_Q3_K    = 11,\n    GGML_TYPE_Q4_K    = 12,\n    GGML_TYPE_Q5_K    = 13,\n    GGML_TYPE_Q6_K    = 14,\n    GGML_TYPE_Q8_K    = 15,\n    GGML_TYPE_IQ2_XXS = 16,\n    GGML_TYPE_IQ2_XS  = 17,\n    GGML_TYPE_IQ3_XXS = 18,\n    GGML_TYPE_IQ1_S   = 19,\n    GGML_TYPE_IQ4_NL  = 20,\n    GGML_TYPE_IQ3_S   = 21,\n    GGML_TYPE_IQ2_S   = 22,\n    GGML_TYPE_IQ4_XS  = 23,\n    GGML_TYPE_I8      = 24,\n    GGML_TYPE_I16     = 25,\n    GGML_TYPE_I32     = 26,\n    GGML_TYPE_I64     = 27,\n    GGML_TYPE_F64     = 28,\n    GGML_TYPE_IQ1_M   = 29,\n    GGML_TYPE_BF16    = 30,\n    // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files\n    // GGML_TYPE_Q4_0_4_8 = 32,\n    // GGML_TYPE_Q4_0_8_8 = 33,\n    GGML_TYPE_TQ1_0   = 34,\n    GGML_TYPE_TQ2_0   = 35,\n    // GGML_TYPE_IQ4_NL_4_4 = 36,\n    // GGML_TYPE_IQ4_NL_4_8 = 37,\n    // GGML_TYPE_IQ4_NL_8_8 = 38,\n    GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)\n    GGML_TYPE_COUNT   = 40,\n};\n\nenum gguf_metadata_value_type: uint32_t {\n 
   // The value is an 8-bit unsigned integer.\n    GGUF_METADATA_VALUE_TYPE_UINT8 = 0,\n    // The value is an 8-bit signed integer.\n    GGUF_METADATA_VALUE_TYPE_INT8 = 1,\n    // The value is a 16-bit unsigned little-endian integer.\n    GGUF_METADATA_VALUE_TYPE_UINT16 = 2,\n    // The value is a 16-bit signed little-endian integer.\n    GGUF_METADATA_VALUE_TYPE_INT16 = 3,\n    // The value is a 32-bit unsigned little-endian integer.\n    GGUF_METADATA_VALUE_TYPE_UINT32 = 4,\n    // The value is a 32-bit signed little-endian integer.\n    GGUF_METADATA_VALUE_TYPE_INT32 = 5,\n    // The value is a 32-bit IEEE754 floating point number.\n    GGUF_METADATA_VALUE_TYPE_FLOAT32 = 6,\n    // The value is a boolean.\n    // 1-byte value where 0 is false and 1 is true.\n    // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.\n    GGUF_METADATA_VALUE_TYPE_BOOL = 7,\n    // The value is a UTF-8 non-null-terminated string, with length prepended.\n    GGUF_METADATA_VALUE_TYPE_STRING = 8,\n    // The value is an array of other values, with the length and type prepended.\n    //\n    // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.\n    GGUF_METADATA_VALUE_TYPE_ARRAY = 9,\n    // The value is a 64-bit unsigned little-endian integer.\n    GGUF_METADATA_VALUE_TYPE_UINT64 = 10,\n    // The value is a 64-bit signed little-endian integer.\n    GGUF_METADATA_VALUE_TYPE_INT64 = 11,\n    // The value is a 64-bit IEEE754 floating point number.\n    GGUF_METADATA_VALUE_TYPE_FLOAT64 = 12,\n};\n\n// A string in GGUF.\nstruct gguf_string_t {\n    // The length of the string, in bytes.\n    uint64_t len;\n    // The string as a UTF-8 non-null-terminated string.\n    char string[len];\n};\n\nunion gguf_metadata_value_t {\n    uint8_t uint8;\n    int8_t int8;\n    uint16_t uint16;\n    int16_t int16;\n    uint32_t uint32;\n    int32_t int32;\n    float float32;\n    
uint64_t uint64;\n    int64_t int64;\n    double float64;\n    bool bool_;\n    gguf_string_t string;\n    struct {\n        // Any value type is valid, including arrays.\n        gguf_metadata_value_type type;\n        // Number of elements, not bytes\n        uint64_t len;\n        // The array of values.\n        gguf_metadata_value_t array[len];\n    } array;\n};\n\nstruct gguf_metadata_kv_t {\n    // The key of the metadata. It is a standard GGUF string, with the following caveats:\n    // - It must be a valid ASCII string.\n    // - It must be a hierarchical key, where each segment is `lower_snake_case` and separated by a `.`.\n    // - It must be at most 2^16-1 (65535) bytes long.\n    // Any keys that do not follow these rules are invalid.\n    gguf_string_t key;\n\n    // The type of the value.\n    // Must be one of the `gguf_metadata_value_type` values.\n    gguf_metadata_value_type value_type;\n    // The value.\n    gguf_metadata_value_t value;\n};\n\nstruct gguf_header_t {\n    // Magic number to announce that this is a GGUF file.\n    // Must be `GGUF` at the byte level: `0x47` `0x47` `0x55` `0x46`.\n    // Your executor might do little-endian byte order, so it might be\n    // checking for 0x46554747 and letting the endianness cancel out.\n    // Consider being *very* explicit about the byte order here.\n    uint32_t magic;\n    // The version of the format implemented.\n    // Must be `3` for the version described in this spec, which introduces big-endian support.\n    //\n    // This version should only be increased for structural changes to the format.\n    // Changes that do not affect the structure of the file should instead update the metadata\n    // to signify the change.\n    uint32_t version;\n    // The number of tensors in the file.\n    // This is explicit, instead of being included in the metadata, to ensure it is always present\n    // for loading the tensors.\n    uint64_t tensor_count;\n    // The number of metadata key-value pairs.\n    
uint64_t metadata_kv_count;\n    // The metadata key-value pairs.\n    gguf_metadata_kv_t metadata_kv[metadata_kv_count];\n};\n\nuint64_t align_offset(uint64_t offset) {\n    return offset + (ALIGNMENT - (offset % ALIGNMENT)) % ALIGNMENT;\n}\n\nstruct gguf_tensor_info_t {\n    // The name of the tensor. It is a standard GGUF string, with the caveat that\n    // it must be at most 64 bytes long.\n    gguf_string_t name;\n    // The number of dimensions in the tensor.\n    // Currently at most 4, but this may change in the future.\n    uint32_t n_dimensions;\n    // The dimensions of the tensor.\n    uint64_t dimensions[n_dimensions];\n    // The type of the tensor.\n    ggml_type type;\n    // The offset of the tensor's data in this file in bytes.\n    //\n    // This offset is relative to `tensor_data`, not to the start\n    // of the file, to make it easier for writers to write the file.\n    // Readers should consider exposing this offset relative to the\n    // file to make it easier to read the data.\n    //\n    // Must be a multiple of `ALIGNMENT`. That is, `align_offset(offset) == offset`.\n    uint64_t offset;\n};\n\nstruct gguf_file_t {\n    // The header of the file.\n    gguf_header_t header;\n\n    // Tensor infos, which can be used to locate the tensor data.\n    gguf_tensor_info_t tensor_infos[header.tensor_count];\n\n    // Padding to the nearest multiple of `ALIGNMENT`.\n    //\n    // That is, if `sizeof(header) + sizeof(tensor_infos)` is not a multiple of `ALIGNMENT`,\n    // this padding is added to make it so.\n    //\n    // This can be calculated as `align_offset(position) - position`, where `position` is\n    // the position of the end of `tensor_infos` (i.e. `sizeof(header) + sizeof(tensor_infos)`).\n    uint8_t _padding[];\n\n    // Tensor data.\n    //\n    // This is arbitrary binary data corresponding to the weights of the model. 
This data should be close\n    // or identical to the data in the original model file, but may be different due to quantization or\n    // other optimizations for inference. Any such deviations should be recorded in the metadata or as\n    // part of the architecture definition.\n    //\n    // Each tensor's data must be stored within this array, and located through its `tensor_infos` entry.\n    // The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors\n    // should be padded to `ALIGNMENT` bytes.\n    uint8_t tensor_data[];\n};\n```\n\n## Standardized key-value pairs\n\nThe following key-value pairs are standardized. This list may grow in the future as more use cases are discovered. Where possible, names are shared with the original model definitions to make it easier to map between the two.\n\nNot all of these are required, but they are all recommended. Keys that are required are bolded. For omitted pairs, the reader should assume that the value is unknown and either default or error as appropriate.\n\nThe community can develop their own key-value pairs to carry additional data. However, these should be namespaced with the relevant community name to avoid collisions. For example, the `rustformers` community might use `rustformers.` as a prefix for all of their keys.\n\nIf a particular community key is widely used, it may be promoted to a standardized key.\n\nBy convention, most counts/lengths/etc are `uint64` unless otherwise specified. This is to allow for larger models to be supported in the future. Some models may use `uint32` for their values; it is recommended that readers support both.\n\n### General\n\n#### Required\n\n- **`general.architecture: string`**: describes what architecture this model implements. All lowercase ASCII, with only `[a-z0-9]+` characters allowed. 
Known values include:\n  - `llama`\n  - `mpt`\n  - `gptneox`\n  - `gptj`\n  - `gpt2`\n  - `bloom`\n  - `falcon`\n  - `mamba`\n  - `rwkv`\n- **`general.quantization_version: uint32`**: The version of the quantization format. Not required if the model is not quantized (i.e. no tensors are quantized). If any tensors are quantized, this _must_ be present. This is separate from the quantization scheme of the tensors themselves; the quantization version may change without changing the scheme's name (e.g. the quantization scheme is Q5_K, and the quantization version is 4).\n- **`general.alignment: uint32`**: the global alignment to use, as described above. This can vary to allow for different alignment schemes, but it must be a multiple of 8. Some writers may not write the alignment. If the alignment is **not** specified, assume it is `32`.\n\n#### General metadata\n\n- `general.name: string`: The name of the model. This should be a human-readable name that can be used to identify the model. It should be unique within the community that the model is defined in.\n- `general.author: string`: The author of the model.\n- `general.version: string`: The version of the model.\n- `general.organization: string`: The organization of the model.\n- `general.basename: string`: The base model name / architecture of the model\n- `general.finetune: string`: What the base model has been optimized toward.\n- `general.description: string`: free-form description of the model including anything that isn't covered by the other fields\n- `general.quantized_by: string`: The name of the individual who quantized the model\n- `general.size_label: string`: Size class of the model, such as number of weights and experts. (Useful for leaderboards)\n- `general.license: string`: License of the model, expressed as an [SPDX license expression](https://spdx.github.io/spdx-spec/v2-draft/SPDX-license-expressions/) (e.g. `\"MIT OR Apache-2.0\"`). 
Do not include any other information, such as the license text or the URL to the license.\n- `general.license.name: string`: Human-friendly license name\n- `general.license.link: string`: URL to the license.\n- `general.url: string`: URL to the model's homepage. This can be a GitHub repo, a paper, etc.\n- `general.doi: string`: Digital Object Identifier (DOI) https://www.doi.org/\n- `general.uuid: string`: [Universally unique identifier](https://en.wikipedia.org/wiki/Universally_unique_identifier)\n- `general.repo_url: string`: URL to the model's repository such as a GitHub repo or HuggingFace repo\n- `general.tags: string[]`: List of tags that can be used as search terms for a search engine or social media\n- `general.languages: string[]`: Languages the model can speak, encoded as [ISO 639](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) two-letter codes\n- `general.datasets: string[]`: Links or references to datasets that the model was trained upon\n- `general.file_type: uint32`: An enumerated value describing the type of the majority of the tensors in the file. Optional; can be inferred from the tensor types.\n  - `ALL_F32 = 0`\n  - `MOSTLY_F16 = 1`\n  - `MOSTLY_Q4_0 = 2`\n  - `MOSTLY_Q4_1 = 3`\n  - `MOSTLY_Q4_1_SOME_F16 = 4`\n  - `MOSTLY_Q4_2 = 5` (support removed)\n  - `MOSTLY_Q4_3 = 6` (support removed)\n  - `MOSTLY_Q8_0 = 7`\n  - `MOSTLY_Q5_0 = 8`\n  - `MOSTLY_Q5_1 = 9`\n  - `MOSTLY_Q2_K = 10`\n  - `MOSTLY_Q3_K_S = 11`\n  - `MOSTLY_Q3_K_M = 12`\n  - `MOSTLY_Q3_K_L = 13`\n  - `MOSTLY_Q4_K_S = 14`\n  - `MOSTLY_Q4_K_M = 15`\n  - `MOSTLY_Q5_K_S = 16`\n  - `MOSTLY_Q5_K_M = 17`\n  - `MOSTLY_Q6_K = 18`\n\n#### Source metadata\n\nInformation about where this model came from. This is useful for tracking the provenance of the model, and for finding the original source if the model is modified. 
For a model that was converted from GGML, for example, these keys would point to the model that was converted from.\n\n- `general.source.url: string`: URL to the source of the model's homepage. This can be a GitHub repo, a paper, etc.\n- `general.source.doi: string`: Source Digital Object Identifier (DOI) https://www.doi.org/\n- `general.source.uuid: string`: Source [Universally unique identifier](https://en.wikipedia.org/wiki/Universally_unique_identifier)\n- `general.source.repo_url: string`: URL to the source of the model's repository such as a GitHub repo or HuggingFace repo\n\n- `general.base_model.count: uint32`: Number of parent models\n- `general.base_model.{id}.name: string`: The name of the parent model.\n- `general.base_model.{id}.author: string`: The author of the parent model.\n- `general.base_model.{id}.version: string`: The version of the parent model.\n- `general.base_model.{id}.organization: string`: The organization of the parent model.\n- `general.base_model.{id}.url: string`: URL to the source of the parent model's homepage. This can be a GitHub repo, a paper, etc.\n- `general.base_model.{id}.doi: string`: Parent Digital Object Identifier (DOI) https://www.doi.org/\n- `general.base_model.{id}.uuid: string`: Parent [Universally unique identifier](https://en.wikipedia.org/wiki/Universally_unique_identifier)\n- `general.base_model.{id}.repo_url: string`: URL to the source of the parent model's repository such as a GitHub repo or HuggingFace repo\n\n### LLM\n\nIn the following, `[llm]` is used to fill in for the name of a specific LLM architecture. For example, `llama` for LLaMA, `mpt` for MPT, etc. If mentioned in an architecture's section, it is required for that architecture, but not all keys are required for all architectures. Consult the relevant section for more information.\n\n- `[llm].context_length: uint64`: Also known as `n_ctx`. length of the context (in tokens) that the model was trained on. 
For most architectures, this is the hard limit on the length of the input. Architectures, like RWKV, that are not reliant on transformer-style attention may be able to handle larger inputs, but this is not guaranteed.\n- `[llm].embedding_length: uint64`: Also known as `n_embd`. Embedding layer size.\n- `[llm].block_count: uint64`: The number of blocks of attention+feed-forward layers (i.e. the bulk of the LLM). Does not include the input or embedding layers.\n- `[llm].feed_forward_length: uint64`: Also known as `n_ff`. The length of the feed-forward layer.\n- `[llm].use_parallel_residual: bool`: Whether or not the parallel residual logic should be used.\n- `[llm].tensor_data_layout: string`: When a model is converted to GGUF, tensors may be rearranged to improve performance. This key describes the layout of the tensor data. This is not required; if not present, it is assumed to be `reference`.\n  - `reference`: tensors are laid out in the same order as the original model\n  - further options can be found for each architecture in their respective sections\n- `[llm].expert_count: uint32`: Number of experts in MoE models (optional for non-MoE arches).\n- `[llm].expert_used_count: uint32`: Number of experts used during each token evaluation (optional for non-MoE arches).\n\n#### Attention\n\n- `[llm].attention.head_count: uint64`: Also known as `n_head`. Number of attention heads.\n- `[llm].attention.head_count_kv: uint64`: The number of heads per group used in Grouped-Query-Attention. 
If not present, or if present and equal to `[llm].attention.head_count`, the model does not use GQA.\n- `[llm].attention.max_alibi_bias: float32`: The maximum bias to use for ALiBi.\n- `[llm].attention.clamp_kqv: float32`: Value (`C`) used to clamp the values of the `Q`, `K`, and `V` tensors to the range `[-C, C]`.\n- `[llm].attention.layer_norm_epsilon: float32`: Layer normalization epsilon.\n- `[llm].attention.layer_norm_rms_epsilon: float32`: Layer RMS normalization epsilon.\n- `[llm].attention.key_length: uint32`: The optional size of a key head, $d_k$. If not specified, it will be `n_embd / n_head`.\n- `[llm].attention.value_length: uint32`: The optional size of a value head, $d_v$. If not specified, it will be `n_embd / n_head`.\n\n#### RoPE\n\n- `[llm].rope.dimension_count: uint64`: The number of rotary dimensions for RoPE.\n- `[llm].rope.freq_base: float32`: The base frequency for RoPE.\n\n##### Scaling\n\nThe following keys describe RoPE scaling parameters:\n\n- `[llm].rope.scaling.type: string`: Can be `none`, `linear`, or `yarn`.\n- `[llm].rope.scaling.factor: float32`: A scale factor for RoPE to adjust the context length.\n- `[llm].rope.scaling.original_context_length: uint32`: The original context length of the base model.\n- `[llm].rope.scaling.finetuned: bool`: True if model has been finetuned with RoPE scaling.\n\nNote that older models may not have these keys, and may instead use the following key:\n\n- `[llm].rope.scale_linear: float32`: A linear scale factor for RoPE to adjust the context length.\n\nIt is recommended that models use the newer keys if possible, as they are more flexible and allow for more complex scaling schemes. 
Executors will need to support both indefinitely.\n\n#### SSM\n\n- `[llm].ssm.conv_kernel: uint32`: The size of the rolling/shift state.\n- `[llm].ssm.inner_size: uint32`: The embedding size of the states.\n- `[llm].ssm.state_size: uint32`: The size of the recurrent state.\n- `[llm].ssm.time_step_rank: uint32`: The rank of time steps.\n\n#### Models\n\nThe following sections describe the metadata for each model architecture. Each key specified _must_ be present.\n\n##### LLaMA\n\n- `llama.context_length`\n- `llama.embedding_length`\n- `llama.block_count`\n- `llama.feed_forward_length`\n- `llama.rope.dimension_count`\n- `llama.attention.head_count`\n- `llama.attention.layer_norm_rms_epsilon`\n\n###### Optional\n\n- `llama.rope.scale`\n- `llama.attention.head_count_kv`\n- `llama.tensor_data_layout`:\n  - `Meta AI original pth`:\n    ```python\n    def permute(weights: NDArray, n_head: int) -> NDArray:\n        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])\n                    .swapaxes(1, 2)\n                    .reshape(weights.shape))\n    ```\n- `llama.expert_count`\n- `llama.expert_used_count`\n\n##### MPT\n\n- `mpt.context_length`\n- `mpt.embedding_length`\n- `mpt.block_count`\n- `mpt.attention.head_count`\n- `mpt.attention.alibi_bias_max`\n- `mpt.attention.clip_kqv`\n- `mpt.attention.layer_norm_epsilon`\n\n##### GPT-NeoX\n\n- `gptneox.context_length`\n- `gptneox.embedding_length`\n- `gptneox.block_count`\n- `gptneox.use_parallel_residual`\n- `gptneox.rope.dimension_count`\n- `gptneox.attention.head_count`\n- `gptneox.attention.layer_norm_epsilon`\n\n###### Optional\n\n- `gptneox.rope.scale`\n\n##### GPT-J\n\n- `gptj.context_length`\n- `gptj.embedding_length`\n- `gptj.block_count`\n- `gptj.rope.dimension_count`\n- `gptj.attention.head_count`\n- `gptj.attention.layer_norm_epsilon`\n\n###### Optional\n\n- `gptj.rope.scale`\n\n##### GPT-2\n\n- `gpt2.context_length`\n- `gpt2.embedding_length`\n- `gpt2.block_count`\n- 
`gpt2.attention.head_count`\n- `gpt2.attention.layer_norm_epsilon`\n\n##### BLOOM\n\n- `bloom.context_length`\n- `bloom.embedding_length`\n- `bloom.block_count`\n- `bloom.feed_forward_length`\n- `bloom.attention.head_count`\n- `bloom.attention.layer_norm_epsilon`\n\n##### Falcon\n\n- `falcon.context_length`\n- `falcon.embedding_length`\n- `falcon.block_count`\n- `falcon.attention.head_count`\n- `falcon.attention.head_count_kv`\n- `falcon.attention.use_norm`\n- `falcon.attention.layer_norm_epsilon`\n\n###### Optional\n\n- `falcon.tensor_data_layout`:\n\n  - `jploski` (author of the original GGML implementation of Falcon):\n\n    ```python\n    # The original query_key_value tensor contains n_head_kv \"kv groups\",\n    # each consisting of n_head/n_head_kv query weights followed by one key\n    # and one value weight (shared by all query heads in the kv group).\n    # This layout makes it a big pain to work with in GGML.\n    # So we rearrange them here, so that we have n_head query weights\n    # followed by n_head_kv key weights followed by n_head_kv value weights,\n    # in contiguous fashion.\n\n    if \"query_key_value\" in src:\n        qkv = model[src].view(\n            n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)\n\n        q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)\n        k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)\n        v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)\n\n        model[src] = torch.cat((q,k,v)).reshape_as(model[src])\n    ```\n\n##### Mamba\n\n- `mamba.context_length`\n- `mamba.embedding_length`\n- `mamba.block_count`\n- `mamba.ssm.conv_kernel`\n- `mamba.ssm.inner_size`\n- `mamba.ssm.state_size`\n- `mamba.ssm.time_step_rank`\n- `mamba.attention.layer_norm_rms_epsilon`\n\n##### RWKV\n\nThe vocabulary size is the same as the number of rows in the `head` matrix.\n\n- `rwkv.architecture_version: uint32`: The only allowed value currently is 4. 
Version 5 is expected to appear some time in the future.\n- `rwkv.context_length: uint64`: Length of the context used during training or fine-tuning. RWKV is able to handle larger context than this limit, but the output quality may suffer.\n- `rwkv.block_count: uint64`\n- `rwkv.embedding_length: uint64`\n- `rwkv.feed_forward_length: uint64`\n\n##### Whisper\n\nKeys that do not have types defined should be assumed to share definitions with `llm.` keys.\n(For example, `whisper.context_length` is equivalent to `llm.context_length`.)\nThis is because they are both transformer models.\n\n- `whisper.encoder.context_length`\n- `whisper.encoder.embedding_length`\n- `whisper.encoder.block_count`\n- `whisper.encoder.mels_count: uint64`\n- `whisper.encoder.attention.head_count`\n\n- `whisper.decoder.context_length`\n- `whisper.decoder.embedding_length`\n- `whisper.decoder.block_count`\n- `whisper.decoder.attention.head_count`\n\n#### Prompting\n\n**TODO**: Include prompt format, and/or metadata about how it should be used (instruction, conversation, autocomplete, etc).\n\n### LoRA\n\n**TODO**: Figure out what metadata is needed for LoRA. Probably desired features:\n\n- match an existing model exactly, so that it can't be misapplied\n- be marked as a LoRA so executors won't try to run it by itself\n\nShould this be an architecture, or should it share the details of the original model with additional fields to mark it as a LoRA?\n\n### Tokenizer\n\nThe following keys are used to describe the tokenizer of the model. It is recommended that model authors support as many of these as possible, as it will allow for better tokenization quality with supported executors.\n\n#### GGML\n\nGGML supports an embedded vocabulary that enables inference of the model, but implementations of tokenization using this vocabulary (i.e. `llama.cpp`'s tokenizer) may have lower accuracy than the original tokenizer used for the model. 
When a more accurate tokenizer is available and supported, it should be used instead.\n\nIt is not guaranteed to be standardized across models, and may change in the future. It is recommended that model authors use a more standardized tokenizer if possible.\n\n- `tokenizer.ggml.model: string`: The name of the tokenizer model.\n  - `llama`: Llama style SentencePiece (tokens and scores extracted from HF `tokenizer.model`)\n  - `replit`: Replit style SentencePiece (tokens and scores extracted from HF `spiece.model`)\n  - `gpt2`: GPT-2 / GPT-NeoX style BPE (tokens extracted from HF `tokenizer.json`)\n  - `rwkv`: RWKV tokenizer\n- `tokenizer.ggml.tokens: array[string]`: A list of tokens indexed by the token ID used by the model.\n- `tokenizer.ggml.scores: array[float32]`: If present, the score/probability of each token. If not present, all tokens are assumed to have equal probability. If present, it must have the same length and index as `tokens`.\n- `tokenizer.ggml.token_type: array[int32]`: The token type (1=normal, 2=unknown, 3=control, 4=user defined, 5=unused, 6=byte). If present, it must have the same length and index as `tokens`.\n- `tokenizer.ggml.merges: array[string]`: If present, the merges of the tokenizer. If not present, the tokens are assumed to be atomic.\n- `tokenizer.ggml.added_tokens: array[string]`: If present, tokens that were added after training.\n\n##### Special tokens\n\n- `tokenizer.ggml.bos_token_id: uint32`: Beginning of sequence marker\n- `tokenizer.ggml.eos_token_id: uint32`: End of sequence marker\n- `tokenizer.ggml.unknown_token_id: uint32`: Unknown token\n- `tokenizer.ggml.separator_token_id: uint32`: Separator token\n- `tokenizer.ggml.padding_token_id: uint32`: Padding token\n\n#### Hugging Face\n\nHugging Face maintains their own `tokenizers` library that supports a wide variety of tokenizers. 
If your executor uses this library, it may be able to use the model's tokenizer directly.\n\n- `tokenizer.huggingface.json: string`: the entirety of the HF `tokenizer.json` for a given model (e.g. <https://huggingface.co/mosaicml/mpt-7b-instruct/blob/main/tokenizer.json>). Included for compatibility with executors that support HF tokenizers directly.\n\n#### Other\n\nOther tokenizers may be used, but are not necessarily standardized. They may be executor-specific. They will be documented here as they are discovered/further developed.\n\n- `tokenizer.rwkv.world: string`: an RWKV World tokenizer, like [this](https://github.com/BlinkDL/ChatRWKV/blob/main/tokenizer/rwkv_vocab_v20230424.txt). This text file should be included verbatim.\n- `tokenizer.chat_template: string`: a Jinja template that specifies the input format expected by the model. For more details see: <https://huggingface.co/docs/transformers/main/en/chat_templating>\n\n### Computation graph\n\nThis is a future extension and still needs to be discussed, and may necessitate a new GGUF version. At the time of writing, the primary blocker is the stabilization of the computation graph format.\n\nA sample computation graph of GGML nodes could be included in the model itself, allowing an executor to run the model without providing its own implementation of the architecture. 
This would allow for a more consistent experience across executors, and would allow for more complex architectures to be supported without requiring the executor to implement them.\n\n## Standardized tensor names\n\nTo minimize complexity and maximize compatibility, it is recommended that models using the transformer architecture use the following naming convention for their tensors:\n\n### Base layers\n\n`AA.weight` `AA.bias`\n\nwhere `AA` can be:\n\n- `token_embd`: Token embedding layer\n- `pos_embd`: Position embedding layer\n- `output_norm`: Output normalization layer\n- `output`: Output layer\n\n### Attention and feed-forward layer blocks\n\n`blk.N.BB.weight` `blk.N.BB.bias`\n\nwhere N signifies the block number a layer belongs to, and where `BB` could be:\n\n- `attn_norm`: Attention normalization layer\n- `attn_norm_2`: Attention normalization layer\n- `attn_qkv`: Attention query-key-value layer\n- `attn_q`: Attention query layer\n- `attn_k`: Attention key layer\n- `attn_v`: Attention value layer\n- `attn_output`: Attention output layer\n\n- `ffn_norm`: Feed-forward network normalization layer\n- `ffn_up`: Feed-forward network \"up\" layer\n- `ffn_gate`: Feed-forward network \"gate\" layer\n- `ffn_down`: Feed-forward network \"down\" layer\n- `ffn_gate_inp`: Expert-routing layer for the Feed-forward network in MoE models\n- `ffn_gate_exp`: Feed-forward network \"gate\" layer per expert in MoE models\n- `ffn_down_exp`: Feed-forward network \"down\" layer per expert in MoE models\n- `ffn_up_exp`: Feed-forward network \"up\" layer per expert in MoE models\n\n- `ssm_in`: State space model input projections layer\n- `ssm_conv1d`: State space model rolling/shift layer\n- `ssm_x`: State space model selective parametrization layer\n- `ssm_a`: State space model state compression layer\n- `ssm_d`: State space model skip connection layer\n- `ssm_dt`: State space model time step layer\n- `ssm_out`: State space model output projection layer\n\n## Version History\n\nThis 
document is actively updated to describe the current state of the metadata, and these changes are not tracked outside of the commits.\n\nHowever, the format _itself_ has changed. The following sections describe the changes to the format itself.\n\n### v3\n\nAdds big-endian support.\n\n### v2\n\nMost countable values (lengths, etc.) were changed from `uint32` to `uint64` to allow for larger models to be supported in the future.\n\n### v1\n\nInitial version.\n\n## Historical State of Affairs\n\nThe following information is provided for context, but is not necessary to understand the rest of this document.\n\n### Overview\n\nAt present, there are three GGML file formats floating around for LLMs:\n\n- **GGML** (unversioned): baseline format, with no versioning or alignment.\n- **GGMF** (versioned): the same as GGML, but with versioning. Only one version exists.\n- **GGJT**: Aligns the tensors to allow for use with `mmap`, which requires alignment. v1, v2 and v3 are structurally identical, but the later versions use a different quantization scheme that is incompatible with previous versions.\n\nGGML is primarily used by the examples in `ggml`, while GGJT is used by `llama.cpp` models. Other executors may use any of the three formats, but this is not 'officially' supported.\n\nThese formats share the same fundamental structure:\n\n- a magic number with an optional version number\n- model-specific hyperparameters, including\n  - metadata about the model, such as the number of layers, the number of heads, etc.\n  - a `ftype` that describes the type of the majority of the tensors,\n    - for GGML files, the quantization version is encoded in the `ftype` divided by 1000\n- an embedded vocabulary, which is a list of strings with length prepended. 
The GGMF/GGJT formats embed a float32 score next to the strings.\n- finally, a list of tensors with their length-prepended name, type, and (aligned, in the case of GGJT) tensor data\n\nNotably, this structure does not identify what model architecture the model belongs to, nor does it offer any flexibility for changing the structure of the hyperparameters. This means that the only way to add new hyperparameters is to add them to the end of the list, which is a breaking change for existing models.\n\n### Drawbacks\n\nUnfortunately, over the last few months, a few issues have become apparent with the existing models:\n\n- There's no way to identify which model architecture a given model is for, because that information isn't present\n  - Similarly, existing programs cannot intelligently fail upon encountering new architectures\n- Adding new hyperparameters or removing existing ones is a breaking change, which is impossible for a reader to detect without using heuristics\n- Each model architecture requires its own conversion script to its architecture's variant of GGML\n- Maintaining backwards compatibility without breaking the structure of the format requires clever tricks, like packing the quantization version into the ftype, which are not guaranteed to be picked up by readers/writers, and are not consistent between the two formats\n\n### Why not other formats?\n\nThere are a few other formats that could be used, but issues include:\n\n- requiring additional dependencies to load or save the model, which is complicated in a C environment\n- limited or no support for 4-bit quantization\n- existing cultural expectations (e.g. 
whether or not the model is a directory or a file)\n- lack of support for embedded vocabularies\n- lack of control over direction of future development\n\nUltimately, it is likely that GGUF will remain necessary for the foreseeable future, and it is better to have a single format that is well-documented and supported by all executors than to contort an existing format to fit the needs of GGML.\n"
  },
  {
    "path": "examples/CMakeLists.txt",
    "content": "if (GGML_ALL_WARNINGS)\n  if (NOT MSVC)\n      set(cxx_flags\n          # TODO(marella): Add other warnings.\n          -Wpedantic\n          -Wunused-variable\n          -Wno-unused-function\n          -Wno-multichar\n      )\n      add_compile_options(\"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>\")\n  endif()\nendif()\n\nadd_library(common STATIC common.cpp)\ntarget_include_directories(common PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})\n\nadd_library(common-ggml STATIC common-ggml.cpp)\ntarget_link_libraries(common-ggml PRIVATE ggml)\ntarget_include_directories(common-ggml PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})\n\nadd_subdirectory(yolo)\n\nif (NOT GGML_BACKEND_DL)\n    add_subdirectory(gpt-2)\n    add_subdirectory(gpt-j)\n    add_subdirectory(mnist)\n    add_subdirectory(sam)\n    add_subdirectory(simple)\n    add_subdirectory(magika)\nendif()\n\nif (GGML_METAL)\n    add_subdirectory(perf-metal)\nendif()\n"
  },
  {
    "path": "examples/common-ggml.cpp",
    "content": "#include \"common-ggml.h\"\n\n#include <regex>\n#include <map>\n\nstatic const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {\n    {\"q4_0\", GGML_FTYPE_MOSTLY_Q4_0},\n    {\"q4_1\", GGML_FTYPE_MOSTLY_Q4_1},\n    {\"q5_0\", GGML_FTYPE_MOSTLY_Q5_0},\n    {\"q5_1\", GGML_FTYPE_MOSTLY_Q5_1},\n    {\"q8_0\", GGML_FTYPE_MOSTLY_Q8_0},\n    {\"q2_k\", GGML_FTYPE_MOSTLY_Q2_K},\n    {\"q3_k\", GGML_FTYPE_MOSTLY_Q3_K},\n    {\"q4_k\", GGML_FTYPE_MOSTLY_Q4_K},\n    {\"q5_k\", GGML_FTYPE_MOSTLY_Q5_K},\n    {\"q6_k\", GGML_FTYPE_MOSTLY_Q6_K},\n};\n\nvoid ggml_print_ftypes(FILE * fp) {\n    for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) {\n        fprintf(fp, \"  type = \\\"%s\\\" or %d\\n\", it->first.c_str(), it->second);\n    }\n}\n\nenum ggml_ftype ggml_parse_ftype(const char * str) {\n    enum ggml_ftype ftype;\n    if (str[0] == 'q') {\n        const auto it = GGML_FTYPE_MAP.find(str);\n        if (it == GGML_FTYPE_MAP.end()) {\n            fprintf(stderr, \"%s: unknown ftype '%s'\\n\", __func__, str);\n            return GGML_FTYPE_UNKNOWN;\n        }\n        ftype = it->second;\n    } else {\n        ftype = (enum ggml_ftype) atoi(str);\n    }\n\n    return ftype;\n}\n\nbool ggml_common_quantize_0(\n        std::ifstream & finp,\n        std::ofstream & fout,\n        const ggml_ftype ftype,\n        const std::vector<std::string> & to_quant,\n        const std::vector<std::string> & to_skip) {\n\n    ggml_type qtype = GGML_TYPE_F32;\n\n    switch (ftype) {\n        case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break;\n        case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break;\n        case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;\n        case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;\n        case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;\n        case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break;\n        case GGML_FTYPE_MOSTLY_Q3_K: qtype = 
GGML_TYPE_Q3_K; break;\n        case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break;\n        case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break;\n        case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break;\n        case GGML_FTYPE_UNKNOWN:\n        case GGML_FTYPE_ALL_F32:\n        case GGML_FTYPE_MOSTLY_F16:\n        case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:\n        case GGML_FTYPE_MOSTLY_IQ2_XXS:\n        case GGML_FTYPE_MOSTLY_IQ2_XS:\n        case GGML_FTYPE_MOSTLY_IQ2_S:\n        case GGML_FTYPE_MOSTLY_IQ3_XXS:\n        case GGML_FTYPE_MOSTLY_IQ3_S:\n        case GGML_FTYPE_MOSTLY_IQ1_S:\n        case GGML_FTYPE_MOSTLY_IQ4_NL:\n        case GGML_FTYPE_MOSTLY_IQ4_XS:\n        case GGML_FTYPE_MOSTLY_IQ1_M:\n        case GGML_FTYPE_MOSTLY_BF16:\n        case GGML_FTYPE_MOSTLY_MXFP4:\n        case GGML_FTYPE_MOSTLY_NVFP4:\n                {\n                    fprintf(stderr, \"%s: invalid model type %d\\n\", __func__, ftype);\n                    return false;\n                }\n    };\n\n    if (!ggml_is_quantized(qtype)) {\n        fprintf(stderr, \"%s: invalid quantization type %d (%s)\\n\", __func__, qtype, ggml_type_name(qtype));\n        return false;\n    }\n\n    size_t total_size_org = 0;\n    size_t total_size_new = 0;\n\n    std::vector<float> work;\n\n    std::vector<uint8_t>     data_u8;\n    std::vector<ggml_fp16_t> data_f16;\n    std::vector<float>       data_f32;\n\n    while (true) {\n        int32_t n_dims;\n        int32_t length;\n        int32_t ttype;\n\n        finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));\n        finp.read(reinterpret_cast<char *>(&length), sizeof(length));\n        finp.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));\n\n        if (finp.eof()) {\n            break;\n        }\n\n        int32_t nelements = 1;\n        int32_t ne[4] = { 1, 1, 1, 1 };\n        for (int i = 0; i < n_dims; ++i) {\n            finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));\n         
   nelements *= ne[i];\n        }\n\n        std::string name(length, 0);\n        finp.read (&name[0], length);\n\n        printf(\"%64s - [%5d, %5d, %5d], type = %6s \", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype));\n\n        bool quantize = false;\n\n        // check if we should quantize this tensor\n        for (const auto & s : to_quant) {\n            if (std::regex_match(name, std::regex(s))) {\n                quantize = true;\n                break;\n            }\n        }\n\n        // check if we should skip this tensor\n        for (const auto & s : to_skip) {\n            if (std::regex_match(name, std::regex(s))) {\n                quantize = false;\n                break;\n            }\n        }\n\n        // quantize only 2D tensors\n        quantize &= (n_dims == 2);\n\n        if (quantize) {\n            if (ttype != GGML_TYPE_F32 && ttype != GGML_TYPE_F16) {\n                fprintf(stderr, \"%s: unsupported ttype %d (%s) for integer quantization\\n\", __func__, ttype, ggml_type_name((ggml_type) ttype));\n                return false;\n            }\n\n            if (ttype == GGML_TYPE_F16) {\n                data_f16.resize(nelements);\n                finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));\n                data_f32.resize(nelements);\n                for (int i = 0; i < nelements; ++i) {\n                    data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);\n                }\n            } else {\n                data_f32.resize(nelements);\n                finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));\n            }\n\n            ttype = qtype;\n        } else {\n            const int bpe = (ttype == 0) ? 
sizeof(float) : sizeof(uint16_t);\n\n            data_u8.resize(nelements*bpe);\n            finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);\n        }\n\n        fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));\n        fout.write(reinterpret_cast<char *>(&length), sizeof(length));\n        fout.write(reinterpret_cast<char *>(&ttype),  sizeof(ttype));\n        for (int i = 0; i < n_dims; ++i) {\n            fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));\n        }\n        fout.write(&name[0], length);\n\n        if (quantize) {\n            work.resize(nelements); // for quantization\n\n            size_t cur_size = 0;\n            switch ((ggml_type) ttype) {\n                case GGML_TYPE_Q4_0:\n                case GGML_TYPE_Q4_1:\n                case GGML_TYPE_Q5_0:\n                case GGML_TYPE_Q5_1:\n                case GGML_TYPE_Q8_0:\n                case GGML_TYPE_Q2_K:\n                case GGML_TYPE_Q3_K:\n                case GGML_TYPE_Q4_K:\n                case GGML_TYPE_Q5_K:\n                case GGML_TYPE_Q6_K:\n                    {\n                        cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements/ne[0], ne[0], nullptr);\n                    } break;\n                case GGML_TYPE_F32:\n                case GGML_TYPE_F16:\n                case GGML_TYPE_I8:\n                case GGML_TYPE_I16:\n                case GGML_TYPE_I32:\n                case GGML_TYPE_I64:\n                case GGML_TYPE_F64:\n                case GGML_TYPE_Q8_1:\n                case GGML_TYPE_Q8_K:\n                case GGML_TYPE_IQ2_XXS:\n                case GGML_TYPE_IQ2_XS:\n                case GGML_TYPE_IQ2_S:\n                case GGML_TYPE_IQ3_XXS:\n                case GGML_TYPE_IQ3_S:\n                case GGML_TYPE_IQ1_S:\n                case GGML_TYPE_IQ4_NL:\n                case GGML_TYPE_IQ4_XS:\n                case GGML_TYPE_IQ1_M:\n           
     case GGML_TYPE_BF16:\n                case GGML_TYPE_TQ1_0:\n                case GGML_TYPE_TQ2_0:\n                case GGML_TYPE_MXFP4:\n                case GGML_TYPE_NVFP4:\n                case GGML_TYPE_COUNT:\n                    {\n                        fprintf(stderr, \"%s: unsupported quantization type %d (%s)\\n\", __func__, ttype, ggml_type_name((ggml_type) ttype));\n                        return false;\n                    }\n            }\n\n            fout.write(reinterpret_cast<char *>(work.data()), cur_size);\n            total_size_new += cur_size;\n\n            printf(\"size = %8.2f MB -> %8.2f MB\\n\", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);\n        } else {\n            printf(\"size = %8.3f MB\\n\", data_u8.size()/1024.0/1024.0);\n            fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());\n            total_size_new += data_u8.size();\n        }\n\n        total_size_org += nelements * sizeof(float);\n    }\n\n    printf(\"%s: model size  = %8.2f MB\\n\", __func__, total_size_org/1024.0/1024.0);\n    printf(\"%s: quant size  = %8.2f MB | ftype = %d (%s)\\n\", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(qtype));\n\n    return true;\n}\n"
  },
  {
    "path": "examples/common-ggml.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n\n#include <fstream>\n#include <vector>\n#include <string>\n\nenum ggml_ftype ggml_parse_ftype(const char * str);\n\nvoid ggml_print_ftypes(FILE * fp = stderr);\n\nbool ggml_common_quantize_0(\n        std::ifstream & finp,\n        std::ofstream & fout,\n        const ggml_ftype ftype,\n        const std::vector<std::string> & to_quant,\n        const std::vector<std::string> & to_skip);\n"
  },
  {
    "path": "examples/common.cpp",
    "content": "#define _USE_MATH_DEFINES // for M_PI\n\n#include \"common.h\"\n\n#include <cmath>\n#include <codecvt>\n#include <cstring>\n#include <fstream>\n#include <locale>\n#include <regex>\n#include <sstream>\n\n// Function to check if the next argument exists\nstatic std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {\n    if (i + 1 < argc && argv[i + 1][0] != '-') {\n        return argv[++i];\n    } else {\n        fprintf(stderr, \"error: %s requires one argument.\\n\", flag.c_str());\n        gpt_print_usage(argc, argv, params);\n        exit(0);\n    }\n}\n\nbool gpt_params_parse(int argc, char ** argv, gpt_params & params) {\n    for (int i = 1; i < argc; i++) {\n        std::string arg = argv[i];\n\n        if (arg == \"-s\" || arg == \"--seed\") {\n            params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"-t\" || arg == \"--threads\") {\n            params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"-p\" || arg == \"--prompt\") {\n            params.prompt = get_next_arg(i, argc, argv, arg, params);\n        } else if (arg == \"-n\" || arg == \"--n_predict\") {\n            params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"-np\" || arg == \"--n_parallel\") {\n            params.n_parallel = std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"--top_k\") {\n            params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"--top_p\") {\n            params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"--temp\") {\n            params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"--repeat-last-n\") {\n            params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } 
else if (arg == \"--repeat-penalty\") {\n            params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"-b\" || arg == \"--batch_size\") {\n            params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"-c\" || arg == \"--context\") {\n            params.n_ctx= std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"-ngl\" || arg == \"--gpu-layers\" || arg == \"--n-gpu-layers\") {\n            params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"--ignore-eos\") {\n            params.ignore_eos = true;\n        } else if (arg == \"-m\" || arg == \"--model\") {\n            params.model = get_next_arg(i, argc, argv, arg, params);\n        } else if (arg == \"-i\" || arg == \"--interactive\") {\n            params.interactive = true;\n        } else if (arg == \"-ip\" || arg == \"--interactive-port\") {\n            params.interactive = true;\n            params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));\n        } else if (arg == \"-h\" || arg == \"--help\") {\n            gpt_print_usage(argc, argv, params);\n            exit(0);\n        } else if (arg == \"-f\" || arg == \"--file\") {\n            get_next_arg(i, argc, argv, arg, params);\n            std::ifstream file(argv[i]);\n            if (!file) {\n                fprintf(stderr, \"error: failed to open file '%s'\\n\", argv[i]);\n                break;\n            }\n            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));\n            if (params.prompt.back() == '\\n') {\n                params.prompt.pop_back();\n            }\n        } else if (arg == \"-tt\" || arg == \"--token_test\") {\n            params.token_test = get_next_arg(i, argc, argv, arg, params);\n        }\n        else {\n            fprintf(stderr, \"error: 
unknown argument: %s\\n\", arg.c_str());\n            gpt_print_usage(argc, argv, params);\n            exit(0);\n        }\n    }\n\n    return true;\n}\n\nvoid gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {\n    fprintf(stderr, \"usage: %s [options]\\n\", argv[0]);\n    fprintf(stderr, \"\\n\");\n    fprintf(stderr, \"options:\\n\");\n    fprintf(stderr, \"  -h, --help            show this help message and exit\\n\");\n    fprintf(stderr, \"  -s SEED, --seed SEED  RNG seed (default: -1)\\n\");\n    fprintf(stderr, \"  -t N, --threads N     number of threads to use during computation (default: %d)\\n\", params.n_threads);\n    fprintf(stderr, \"  -p PROMPT, --prompt PROMPT\\n\");\n    fprintf(stderr, \"                        prompt to start generation with (default: random)\\n\");\n    fprintf(stderr, \"  -f FNAME, --file FNAME\\n\");\n    fprintf(stderr, \"                        load prompt from a file\\n\");\n    fprintf(stderr, \"  -tt TOKEN_TEST, --token_test TOKEN_TEST\\n\");\n    fprintf(stderr, \"                        test tokenization\\n\");\n    fprintf(stderr, \"  -n N, --n_predict N   number of tokens to predict (default: %d)\\n\", params.n_predict);\n    fprintf(stderr, \"  --top_k N             top-k sampling (default: %d)\\n\", params.top_k);\n    fprintf(stderr, \"  --top_p N             top-p sampling (default: %.1f)\\n\", params.top_p);\n    fprintf(stderr, \"  --temp N              temperature (default: %.1f)\\n\", params.temp);\n    fprintf(stderr, \"  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled)\\n\", params.repeat_last_n);\n    fprintf(stderr, \"  --repeat-penalty N    penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\\n\", (double)params.repeat_penalty);\n    fprintf(stderr, \"  -b N, --batch_size N  batch size for prompt processing (default: %d)\\n\", params.n_batch);\n    fprintf(stderr, \"  -c N, --context N     context / KV cache size (default: 
%d)\\n\", params.n_ctx);\n    fprintf(stderr, \"  --ignore-eos          ignore EOS token during generation\\n\");\n    fprintf(stderr, \"  -ngl N, --gpu-layers N  number of layers to offload to GPU on supported models (default: %d)\\n\", params.n_gpu_layers);\n    fprintf(stderr, \"  -m FNAME, --model FNAME\\n\");\n    fprintf(stderr, \"                        model path (default: %s)\\n\", params.model.c_str());\n    fprintf(stderr, \"\\n\");\n}\n\nstd::string gpt_random_prompt(std::mt19937 & rng) {\n    const int r = rng() % 10;\n    switch (r) {\n        case 0: return \"So\";\n        case 1: return \"Once upon a time\";\n        case 2: return \"When\";\n        case 3: return \"The\";\n        case 4: return \"After\";\n        case 5: return \"If\";\n        case 6: return \"import\";\n        case 7: return \"He\";\n        case 8: return \"She\";\n        case 9: return \"They\";\n    }\n\n    return \"The\";\n}\n\nstd::string trim(const std::string & s) {\n    std::regex e(\"^\\\\s+|\\\\s+$\");\n    return std::regex_replace(s, e, \"\");\n}\n\nstd::string replace(const std::string & s, const std::string & from, const std::string & to) {\n    std::string result = s;\n    size_t pos = 0;\n    while ((pos = result.find(from, pos)) != std::string::npos) {\n        result.replace(pos, from.length(), to);\n        pos += to.length();\n    }\n    return result;\n}\n\nvoid gpt_vocab::add_special_token(const std::string & token) {\n    special_tokens.push_back(token);\n}\n\nstd::map<std::string, int32_t> json_parse(const std::string & fname) {\n    std::map<std::string, int32_t> result;\n\n    // read file into string\n    std::string json;\n    {\n        std::ifstream ifs(fname);\n        if (!ifs) {\n            fprintf(stderr, \"Failed to open %s\\n\", fname.c_str());\n            exit(1);\n        }\n\n        json = std::string((std::istreambuf_iterator<char>(ifs)),\n                (std::istreambuf_iterator<char>()));\n    }\n\n    if (json[0] != '{') {\n   
     return result;\n    }\n\n    // parse json\n    {\n        bool has_key  = false;\n        bool in_token = false;\n\n        std::string str_key = \"\";\n        std::string str_val = \"\";\n\n        int n = json.size();\n        for (int i = 1; i < n; ++i) {\n            if (!in_token) {\n                if (json[i] == ' ') continue;\n                if (json[i] == '\"') {\n                    in_token = true;\n                    continue;\n                }\n            } else {\n                if (json[i] == '\\\\' && i+1 < n) {\n                    if (has_key == false) {\n                        str_key += json[i];\n                    } else {\n                        str_val += json[i];\n                    }\n                    ++i;\n                } else if (json[i] == '\"') {\n                    if (has_key == false) {\n                        has_key = true;\n                        ++i;\n                        while (json[i] == ' ') ++i;\n                        ++i; // :\n                        while (json[i] == ' ') ++i;\n                        if (json[i] != '\\\"') {\n                            while (json[i] != ',' && json[i] != '}') {\n                                str_val += json[i++];\n                            }\n                            has_key = false;\n                        } else {\n                            in_token = true;\n                            continue;\n                        }\n                    } else {\n                        has_key = false;\n                    }\n\n                    str_key = ::replace(str_key, \"\\\\u0120\", \" \" ); // \\u0120 -> space\n                    str_key = ::replace(str_key, \"\\\\u010a\", \"\\n\"); // \\u010a -> new line\n                    str_key = ::replace(str_key, \"\\\\\\\"\",    \"\\\"\"); // \\\\\\\"   -> \"\n\n                    try {\n                        result[str_key] = std::stoi(str_val);\n                    } catch (...) 
{\n                        //fprintf(stderr, \"%s: ignoring key '%s' with value '%s'\\n\", fname.c_str(), str_key.c_str(), str_val.c_str());\n\n                    }\n                    str_key = \"\";\n                    str_val = \"\";\n                    in_token = false;\n                    continue;\n                }\n                if (has_key == false) {\n                    str_key += json[i];\n                } else {\n                    str_val += json[i];\n                }\n            }\n        }\n    }\n\n    return result;\n}\n\nvoid gpt_split_words(std::string str, std::vector<std::string>& words) {\n    const std::string pattern = R\"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\\s[:alpha:][:digit:]]+|\\s+(?!\\S)|\\s+)\";\n    const std::regex re(pattern);\n    std::smatch m;\n\n    while (std::regex_search(str, m, re)) {\n        for (auto x : m) {\n            words.push_back(x);\n        }\n        str = m.suffix();\n    }\n}\n\nstd::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {\n    std::vector<std::string> words;\n\n    // first split the text into words\n    {\n        std::string str = text;\n\n        // Generate the subpattern from the special_tokens vector if it's not empty\n        if (!vocab.special_tokens.empty()) {\n            const std::regex escape(R\"([\\[\\\\\\^\\$\\.\\|\\?\\*\\+\\(\\)\\{\\}])\");\n            std::string special_tokens_subpattern;\n            for (const auto & token : vocab.special_tokens) {\n                if (!special_tokens_subpattern.empty()) {\n                    special_tokens_subpattern += \"|\";\n                }\n                special_tokens_subpattern += std::regex_replace(token, escape, R\"(\\$&)\");\n            }\n\n            std::regex re(special_tokens_subpattern);\n            std::smatch m;\n            // Split the text by special tokens.\n            while (std::regex_search(str, m, re)) {\n                // Split the 
substrings in-between special tokens into words.\n                gpt_split_words(m.prefix(), words);\n                // Add matched special tokens as words.\n                for (auto x : m) {\n                    words.push_back(x);\n                }\n                str = m.suffix();\n            }\n            // Remaining text without special tokens will be handled below.\n        }\n\n        gpt_split_words(str, words);\n    }\n\n    // find the longest token that forms each word in words:\n    std::vector<gpt_vocab::id> tokens;\n    for (const auto & word : words) {\n        for (int i = 0; i < (int) word.size(); ){\n            for (int j = word.size() - 1; j >= i; j--){\n                auto cand = word.substr(i, j-i+1);\n                auto it = vocab.token_to_id.find(cand);\n                if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab\n                    tokens.push_back(it->second);\n                    i = j + 1;\n                    break;\n                }\n                else if (j == i){ // word.substr(i, 1) has no matching\n                    fprintf(stderr, \"%s: unknown token '%s'\\n\", __func__, word.substr(i, 1).data());\n                    i++;\n                }\n            }\n        }\n    }\n\n    return tokens;\n}\n\nstatic std::vector<gpt_vocab::id> parse_tokens_from_string(const std::string& input, char delimiter) {\n    std::vector<gpt_vocab::id> output;\n    std::stringstream ss(input);\n    std::string token;\n\n    while (std::getline(ss, token, delimiter)) {\n        output.push_back(std::stoi(token));\n    }\n\n    return output;\n}\n\nstatic std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file(const std::string & fpath_test){\n    if (fpath_test.empty()){\n        fprintf(stderr, \"%s : No test file found.\\n\", __func__);\n        return std::map<std::string, std::vector<gpt_vocab::id>>();\n    }\n\n    std::map<std::string, std::vector<gpt_vocab::id>> tests;\n\n    auto 
fin = std::ifstream(fpath_test, std::ios_base::in);\n    const char * delimeter = \" => \";\n    const char del_tok = ',';\n    std::string line;\n    while (std::getline(fin, line)) {\n        size_t delimiterPos = line.find(delimeter);\n        if (delimiterPos != std::string::npos) {\n            std::string text = line.substr(0, delimiterPos);\n            std::string s_tokens = line.substr(delimiterPos + std::strlen(delimeter));\n            tests[text] = parse_tokens_from_string(s_tokens, del_tok);\n        }\n    }\n    return tests;\n}\n\nvoid test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test){\n    std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file(fpath_test);\n\n    size_t n_fails = 0;\n\n    for (const auto & test : tests) {\n        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, test.first);\n\n        if (tokens != test.second){\n            n_fails++;\n\n            // print out failure cases\n            fprintf(stderr, \"%s : failed test: '%s'\\n\", __func__, test.first.c_str());\n            fprintf(stderr, \"%s : tokens in hf:   \", __func__);\n            for (const auto & t : test.second) {\n                fprintf(stderr, \"%s(%d), \", vocab.id_to_token[t].c_str(), t);\n            }\n            fprintf(stderr, \"\\n\");\n            fprintf(stderr, \"%s : tokens in ggml: \", __func__);\n            for (const auto & t : tokens) {\n                fprintf(stderr, \"%s(%d), \", vocab.id_to_token[t].c_str(), t);\n            }\n            fprintf(stderr, \"\\n\");\n        }\n    }\n\n    fprintf(stderr, \"%s : %zu tests failed out of %zu tests.\\n\", __func__, n_fails, tests.size());\n}\n\nbool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {\n    printf(\"%s: loading vocab from '%s'\\n\", __func__, fname.c_str());\n\n    vocab.token_to_id = ::json_parse(fname);\n\n    for (const auto & kv : vocab.token_to_id) {\n        vocab.id_to_token[kv.second] = kv.first;\n    
}\n\n    printf(\"%s: vocab size = %d\\n\", __func__, (int) vocab.token_to_id.size());\n\n    // print the vocabulary\n    //for (auto kv : vocab.token_to_id) {\n    //    printf(\"'%s' -> %d\\n\", kv.first.data(), kv.second);\n    //}\n\n    return true;\n}\n\ngpt_vocab::id gpt_sample_top_k_top_p(\n        const gpt_vocab & vocab,\n        const float * logits,\n        int    top_k,\n        double top_p,\n        double temp,\n        std::mt19937 & rng) {\n    int n_logits = vocab.id_to_token.size();\n\n    std::vector<std::pair<double, gpt_vocab::id>> logits_id;\n    logits_id.reserve(n_logits);\n\n    {\n        const double scale = 1.0/temp;\n        for (int i = 0; i < n_logits; ++i) {\n            logits_id.push_back(std::make_pair(logits[i]*scale, i));\n        }\n    }\n\n    // find the top K tokens\n    std::partial_sort(\n            logits_id.begin(),\n            logits_id.begin() + top_k, logits_id.end(),\n            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {\n        return a.first > b.first;\n    });\n\n    logits_id.resize(top_k);\n\n    double maxl = -INFINITY;\n    for (const auto & kv : logits_id) {\n        maxl = std::max(maxl, kv.first);\n    }\n\n    // compute probs for the top K tokens\n    std::vector<double> probs;\n    probs.reserve(logits_id.size());\n\n    double sum = 0.0;\n    for (const auto & kv : logits_id) {\n        double p = exp(kv.first - maxl);\n        probs.push_back(p);\n        sum += p;\n    }\n\n    // normalize the probs\n    for (auto & p : probs) {\n        p /= sum;\n    }\n\n    if (top_p < 1.0f) {\n        double cumsum = 0.0f;\n        for (int i = 0; i < top_k; i++) {\n            cumsum += probs[i];\n            if (cumsum >= top_p) {\n                top_k = i + 1;\n                probs.resize(top_k);\n                logits_id.resize(top_k);\n                break;\n            }\n        }\n\n        cumsum = 1.0/cumsum;\n        for (int i = 0; i < 
(int) probs.size(); i++) {\n            probs[i] *= cumsum;\n        }\n    }\n\n    //printf(\"\\n\");\n    //for (int i = 0; i < (int) probs.size(); i++) {\n    //    printf(\"%d: '%s' %f\\n\", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);\n    //}\n    //exit(0);\n\n    std::discrete_distribution<> dist(probs.begin(), probs.end());\n    int idx = dist(rng);\n\n    return logits_id[idx].second;\n}\n\ngpt_vocab::id gpt_sample_top_k_top_p_repeat(\n        const gpt_vocab & vocab,\n        const float * logits,\n        const int32_t * last_n_tokens_data,\n        size_t last_n_tokens_data_size,\n        int    top_k,\n        double top_p,\n        double temp,\n        int repeat_last_n,\n        float repeat_penalty,\n        std::mt19937 & rng) {\n\n    int n_logits = vocab.id_to_token.size();\n\n    const auto * plogits = logits;\n\n    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);\n\n    if (temp <= 0) {\n        // select the token with the highest logit directly\n        float max_logit = plogits[0];\n        gpt_vocab::id max_id = 0;\n\n        for (int i = 1; i < n_logits; ++i) {\n            if (plogits[i] > max_logit) {\n                max_logit = plogits[i];\n                max_id = i;\n            }\n        }\n        return max_id;\n    }\n\n\n    std::vector<std::pair<double, gpt_vocab::id>> logits_id;\n    logits_id.reserve(n_logits);\n\n    {\n        const float scale = 1.0f/temp;\n        for (int i = 0; i < n_logits; ++i) {\n            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)\n            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main\n            if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {\n                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability\n       
         if (plogits[i] < 0.0f) {\n                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));\n                } else {\n                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));\n                }\n            } else {\n                logits_id.push_back(std::make_pair(plogits[i]*scale, i));\n            }\n        }\n    }\n\n    // find the top K tokens\n    std::partial_sort(\n            logits_id.begin(),\n            logits_id.begin() + top_k, logits_id.end(),\n            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {\n        return a.first > b.first;\n    });\n\n    logits_id.resize(top_k);\n\n    double maxl = -INFINITY;\n    for (const auto & kv : logits_id) {\n        maxl = std::max(maxl, kv.first);\n    }\n\n    // compute probs for the top K tokens\n    std::vector<double> probs;\n    probs.reserve(logits_id.size());\n\n    double sum = 0.0;\n    for (const auto & kv : logits_id) {\n        double p = exp(kv.first - maxl);\n        probs.push_back(p);\n        sum += p;\n    }\n\n    // normalize the probs\n    for (auto & p : probs) {\n        p /= sum;\n    }\n\n    if (top_p < 1.0f) {\n        double cumsum = 0.0f;\n        for (int i = 0; i < top_k; i++) {\n            cumsum += probs[i];\n            if (cumsum >= top_p) {\n                top_k = i + 1;\n                probs.resize(top_k);\n                logits_id.resize(top_k);\n                break;\n            }\n        }\n\n        cumsum = 1.0/cumsum;\n        for (int i = 0; i < (int) probs.size(); i++) {\n            probs[i] *= cumsum;\n        }\n    }\n\n//    printf(\"\\n\");\n//    for (int i = 0; i < (int) probs.size(); i++) {\n//    for (int i = 0; i < 10; i++) {\n//        printf(\"%d: '%s' %f\\n\", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);\n//    }\n\n    std::discrete_distribution<> dist(probs.begin(), probs.end());\n    int idx = 
dist(rng);\n\n    return logits_id[idx].second;\n\n}\n\nvoid high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {\n    const float rc = 1.0f / (2.0f * M_PI * cutoff);\n    const float dt = 1.0f / sample_rate;\n    const float alpha = dt / (rc + dt);\n\n    float y = data[0];\n\n    for (size_t i = 1; i < data.size(); i++) {\n        y = alpha * (y + data[i] - data[i - 1]);\n        data[i] = y;\n    }\n}\n\nbool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {\n    const int n_samples      = pcmf32.size();\n    const int n_samples_last = (sample_rate * last_ms) / 1000;\n\n    if (n_samples_last >= n_samples) {\n        // not enough samples - assume no speech\n        return false;\n    }\n\n    if (freq_thold > 0.0f) {\n        high_pass_filter(pcmf32, freq_thold, sample_rate);\n    }\n\n    float energy_all  = 0.0f;\n    float energy_last = 0.0f;\n\n    for (int i = 0; i < n_samples; i++) {\n        energy_all += fabsf(pcmf32[i]);\n\n        if (i >= n_samples - n_samples_last) {\n            energy_last += fabsf(pcmf32[i]);\n        }\n    }\n\n    energy_all  /= n_samples;\n    energy_last /= n_samples_last;\n\n    if (verbose) {\n        fprintf(stderr, \"%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\\n\", __func__, energy_all, energy_last, vad_thold, freq_thold);\n    }\n\n    if (energy_last > vad_thold*energy_all) {\n        return false;\n    }\n\n    return true;\n}\n\nfloat similarity(const std::string & s0, const std::string & s1) {\n    const size_t len0 = s0.size() + 1;\n    const size_t len1 = s1.size() + 1;\n\n    std::vector<int> col(len1, 0);\n    std::vector<int> prevCol(len1, 0);\n\n    for (size_t i = 0; i < len1; i++) {\n        prevCol[i] = i;\n    }\n\n    for (size_t i = 0; i < len0; i++) {\n        col[0] = i;\n        for (size_t j = 1; j < len1; j++) {\n            col[j] = std::min(std::min(1 + col[j - 1], 1 + 
prevCol[j]), prevCol[j - 1] + (i > 0 && s0[i - 1] == s1[j - 1] ? 0 : 1));\n        }\n        col.swap(prevCol);\n    }\n\n    const float dist = prevCol[len1 - 1];\n\n    return 1.0f - (dist / std::max(s0.size(), s1.size()));\n}\n\nbool is_file_exist(const char * filename) {\n    std::ifstream infile(filename);\n    return infile.good();\n}\n"
  },
  {
    "path": "examples/common.h",
    "content": "// Various helper functions and utilities\n\n#pragma once\n\n#include <string>\n#include <map>\n#include <vector>\n#include <random>\n#include <thread>\n#include <ctime>\n#include <fstream>\n#include <sstream>\n\n//\n// GPT CLI argument parsing\n//\n\nstruct gpt_params {\n    int32_t seed         = -1;   // RNG seed\n    int32_t n_threads    = std::min(4, (int32_t) std::thread::hardware_concurrency());\n    int32_t n_predict    = 200;  // new tokens to predict\n    int32_t n_parallel   = 1;    // number of parallel streams\n    int32_t n_batch      = 32;   // batch size for prompt processing\n    int32_t n_ctx        = 2048; // context size (this is the KV cache max size)\n    int32_t n_gpu_layers = 0;    // number of layers to offload to the GPU\n\n    bool ignore_eos = false; // ignore EOS token when generating text\n\n    // sampling parameters\n    int32_t top_k          = 40;\n    float   top_p          = 0.9f;\n    float   temp           = 0.9f;\n    int32_t repeat_last_n  = 64;\n    float   repeat_penalty = 1.00f;\n\n    std::string model      = \"models/gpt-2-117M/ggml-model.bin\"; // model path\n    std::string prompt     = \"\";\n    std::string token_test = \"\";\n\n    bool    interactive      = false;\n    int32_t interactive_port = -1;\n};\n\nbool gpt_params_parse(int argc, char ** argv, gpt_params & params);\n\nvoid gpt_print_usage(int argc, char ** argv, const gpt_params & params);\n\nstd::string gpt_random_prompt(std::mt19937 & rng);\n\n//\n// Vocab utils\n//\n\nstd::string trim(const std::string & s);\n\nstd::string replace(\n        const std::string & s,\n        const std::string & from,\n        const std::string & to);\n\nstruct gpt_vocab {\n    using id    = int32_t;\n    using token = std::string;\n\n    std::map<token, id> token_to_id;\n    std::map<id, token> id_to_token;\n    std::vector<std::string> special_tokens;\n\n    void add_special_token(const std::string & token);\n};\n\n// poor-man's JSON 
parsing\nstd::map<std::string, int32_t> json_parse(const std::string & fname);\n\nstd::string convert_to_utf8(const std::wstring & input);\n\nstd::wstring convert_to_wstring(const std::string & input);\n\nvoid gpt_split_words(std::string str, std::vector<std::string>& words);\n\n// split text into tokens\n//\n// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53\n//\n// Regex (Python):\n// r\"\"\"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+\"\"\"\n//\n// Regex (C++):\n// R\"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\\s[:alpha:][:digit:]]+|\\s+(?!\\S)|\\s+)\"\n//\nstd::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);\n\n// test outputs of gpt_tokenize\n//\n//   - compare with tokens generated by the huggingface tokenizer\n//   - test cases are chosen based on the model's main language (under 'prompt' directory)\n//   - for each sentence that is tokenized differently, print the sentence, the huggingface tokens and the ggml tokens\n//   - finally, print how many tests failed out of the total number of tests\n//\nvoid test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test);\n\n// load the tokens from encoder.json\nbool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);\n\n// sample next token given the logits for each token in the vocabulary\n//\n//   - logits are divided by the temperature temp\n//   - consider only the top K tokens\n//   - if P < 1, keep only the most probable of them until their cumulative probability reaches P\n//\n// TODO: not sure if this implementation is correct\n//\ngpt_vocab::id gpt_sample_top_k_top_p(\n        const gpt_vocab & vocab,\n        const float * logits,\n        int    top_k,\n        double top_p,\n        double temp,\n        std::mt19937 & rng);\n\n// same as gpt_sample_top_k_top_p, with two additions:\n//\n//   - tokens found among the last repeat_last_n entries of last_n_tokens_data get their logit penalized by repeat_penalty\n//     (repeat_last_n must not exceed last_n_tokens_data_size)\n//   - if temp <= 0, the token with the highest logit is returned directly (greedy)\n//\ngpt_vocab::id gpt_sample_top_k_top_p_repeat(\n        const gpt_vocab & vocab,\n        const float * logits,\n        const int32_t * last_n_tokens_data,\n        size_t last_n_tokens_data_size,\n        
int    top_k,\n        double top_p,\n        double temp,\n        int repeat_last_n,\n        float repeat_penalty,\n        std::mt19937 & rng);\n\n//\n// Audio utils\n//\n\n// Write PCM data into WAV audio file\nclass wav_writer {\nprivate:\n    std::ofstream file;\n    uint32_t dataSize = 0;\n    std::string wav_filename;\n\n    bool write_header(const uint32_t sample_rate,\n                      const uint16_t bits_per_sample,\n                      const uint16_t channels) {\n\n        file.write(\"RIFF\", 4);\n        file.write(\"\\0\\0\\0\\0\", 4);    // Placeholder for file size\n        file.write(\"WAVE\", 4);\n        file.write(\"fmt \", 4);\n\n        const uint32_t sub_chunk_size = 16;\n        const uint16_t audio_format = 1;      // PCM format\n        const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8;\n        const uint16_t block_align = channels * bits_per_sample / 8;\n\n        file.write(reinterpret_cast<const char *>(&sub_chunk_size), 4);\n        file.write(reinterpret_cast<const char *>(&audio_format), 2);\n        file.write(reinterpret_cast<const char *>(&channels), 2);\n        file.write(reinterpret_cast<const char *>(&sample_rate), 4);\n        file.write(reinterpret_cast<const char *>(&byte_rate), 4);\n        file.write(reinterpret_cast<const char *>(&block_align), 2);\n        file.write(reinterpret_cast<const char *>(&bits_per_sample), 2);\n        file.write(\"data\", 4);\n        file.write(\"\\0\\0\\0\\0\", 4);    // Placeholder for data size\n\n        return true;\n    }\n\n    // It is assumed that PCM data is normalized to a range from -1 to 1\n    bool write_audio(const float * data, size_t length) {\n        for (size_t i = 0; i < length; ++i) {\n            const int16_t intSample = int16_t(data[i] * 32767);\n            file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));\n            dataSize += sizeof(int16_t);\n        }\n        if (file.is_open()) {\n            
file.seekp(4, std::ios::beg);\n            uint32_t fileSize = 36 + dataSize;\n            file.write(reinterpret_cast<char *>(&fileSize), 4);\n            file.seekp(40, std::ios::beg);\n            file.write(reinterpret_cast<char *>(&dataSize), 4);\n            file.seekp(0, std::ios::end);\n        }\n        return true;\n    }\n\n    bool open_wav(const std::string & filename) {\n        if (filename != wav_filename) {\n            if (file.is_open()) {\n                file.close();\n            }\n        }\n        if (!file.is_open()) {\n            file.open(filename, std::ios::binary);\n            wav_filename = filename;\n            dataSize = 0;\n        }\n        return file.is_open();\n    }\n\npublic:\n    bool open(const std::string & filename,\n              const    uint32_t   sample_rate,\n              const    uint16_t   bits_per_sample,\n              const    uint16_t   channels) {\n\n        if (open_wav(filename)) {\n            write_header(sample_rate, bits_per_sample, channels);\n        } else {\n            return false;\n        }\n\n        return true;\n    }\n\n    bool close() {\n        file.close();\n        return true;\n    }\n\n    bool write(const float * data, size_t length) {\n        return write_audio(data, length);\n    }\n\n    ~wav_writer() {\n        if (file.is_open()) {\n            file.close();\n        }\n    }\n};\n\n\n// Apply a high-pass frequency filter to PCM audio\n// Suppresses frequencies below cutoff Hz\nvoid high_pass_filter(\n        std::vector<float> & data,\n        float cutoff,\n        float sample_rate);\n\n// Basic voice activity detection (VAD) using audio energy adaptive threshold\nbool vad_simple(\n        std::vector<float> & pcmf32,\n        int   sample_rate,\n        int   last_ms,\n        float vad_thold,\n        float freq_thold,\n        bool  verbose);\n\n// compute similarity between two strings using Levenshtein distance\nfloat similarity(const std::string & s0, const 
std::string & s1);\n\n//\n// Terminal utils\n//\n\n#define SQR(X)    ((X) * (X))\n#define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40\n\n/**\n * Quantizes 24-bit RGB to xterm256 code range [16,256).\n */\nstatic int rgb2xterm256(int r, int g, int b) {\n    unsigned char cube[] = {0, 0137, 0207, 0257, 0327, 0377};\n    int av, ir, ig, ib, il, qr, qg, qb, ql;\n    av = r * .299 + g * .587 + b * .114 + .5;\n    ql = (il = av > 238 ? 23 : (av - 3) / 10) * 10 + 8;\n    qr = cube[(ir = UNCUBE(r))];\n    qg = cube[(ig = UNCUBE(g))];\n    qb = cube[(ib = UNCUBE(b))];\n    if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <=\n        SQR(ql - r) + SQR(ql - g) + SQR(ql - b))\n        return ir * 36 + ig * 6 + ib + 020;\n    return il + 0350;\n}\n\nstatic std::string set_xterm256_foreground(int r, int g, int b) {\n    int x = rgb2xterm256(r, g, b);\n    std::ostringstream oss;\n    oss << \"\\033[38;5;\" << x << \"m\";\n    return oss.str();\n}\n\n// Lowest is red, middle is yellow, highest is green. 
Color scheme from\n// Paul Tol; it is colorblind friendly https://sronpersonalpages.nl/~pault\nconst std::vector<std::string> k_colors = {\n    set_xterm256_foreground(220,   5,  12),\n    set_xterm256_foreground(232,  96,  28),\n    set_xterm256_foreground(241, 147,  45),\n    set_xterm256_foreground(246, 193,  65),\n    set_xterm256_foreground(247, 240,  86),\n    set_xterm256_foreground(144, 201, 135),\n    set_xterm256_foreground( 78, 178, 101),\n};\n\n// ANSI formatting codes\nstatic std::string set_inverse() {\n    return \"\\033[7m\";\n}\n\nstatic std::string set_underline() {\n    return \"\\033[4m\";\n}\n\nstatic std::string set_dim() {\n    return \"\\033[2m\";\n}\n\n// Style scheme for different confidence levels\nconst std::vector<std::string> k_styles = {\n    set_inverse(),   // Low confidence - inverse (highlighted)\n    set_underline(), // Medium confidence - underlined\n    set_dim(),       // High confidence - dim\n};\n\n//\n// Other utils\n//\n\n// check if file exists using ifstream\nbool is_file_exist(const char * filename);\n"
  },
  {
    "path": "examples/gpt-2/CMakeLists.txt",
    "content": "#\n# gpt-2\n\nset(TEST_TARGET gpt-2-ctx)\nadd_executable(${TEST_TARGET} main-ctx.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n\nset(TEST_TARGET gpt-2-alloc)\nadd_executable(${TEST_TARGET} main-alloc.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n\nset(TEST_TARGET gpt-2-backend)\nadd_executable(${TEST_TARGET} main-backend.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n\nset(TEST_TARGET gpt-2-sched)\nadd_executable(${TEST_TARGET} main-sched.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n\n#\n# gpt-2-quantize\n\nset(TEST_TARGET gpt-2-quantize)\nadd_executable(${TEST_TARGET} quantize.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n\n#\n# gpt-2-batched\n\nset(TEST_TARGET gpt-2-batched)\nadd_executable(${TEST_TARGET} main-batched.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n"
  },
  {
    "path": "examples/gpt-2/README.md",
    "content": "# gpt-2\n\nThis is a C++ example running GPT-2 inference using the [ggml](https://github.com/ggerganov/ggml) library.\n\nThe program runs on the CPU - no video card is required.\n\nThe [Cerebras-GPT](https://huggingface.co/cerebras) models are also supported.\n\nThe example supports the following GPT-2 models:\n\n| Model | Description  | Disk Size |\n| ---   | ---          | ---       |\n| 117M  | Small model  | 240 MB    |\n| 345M  | Medium model | 680 MB    |\n| 774M  | Large model  | 1.5 GB    |\n| 1558M | XL model     | 3.0 GB    |\n\nSample performance on MacBook M1 Pro:\n\n| Model | Size  | Time / Token |\n| ---   | ---   | ---    |\n| GPT-2 |  117M |   5 ms |\n| GPT-2 |  345M |  12 ms |\n| GPT-2 |  774M |  23 ms |\n| GPT-2 | 1558M |  42 ms |\n\n*TODO: add tables for Cerebras-GPT models*\n\nSample output:\n\n```bash\n$ ./bin/gpt-2 -h\nusage: ./bin/gpt-2 [options]\n\noptions:\n  -h, --help            show this help message and exit\n  -s SEED, --seed SEED  RNG seed (default: -1)\n  -t N, --threads N     number of threads to use during computation (default: 8)\n  -p PROMPT, --prompt PROMPT\n                        prompt to start generation with (default: random)\n  -n N, --n_predict N   number of tokens to predict (default: 200)\n  --top_k N             top-k sampling (default: 40)\n  --top_p N             top-p sampling (default: 0.9)\n  --temp N              temperature (default: 1.0)\n  -b N, --batch_size N  batch size for prompt processing (default: 8)\n  -m FNAME, --model FNAME\n                        model path (default: models/gpt-2-117M/ggml-model.bin)\n\n$ ./bin/gpt-2\ngpt2_model_load: loading model from 'models/gpt-2-117M/ggml-model.bin'\ngpt2_model_load: n_vocab = 50257\ngpt2_model_load: n_ctx   = 1024\ngpt2_model_load: n_embd  = 768\ngpt2_model_load: n_head  = 12\ngpt2_model_load: n_layer = 12\ngpt2_model_load: f16     = 1\ngpt2_model_load: ggml ctx size = 311.12 MB\ngpt2_model_load: memory size =    72.00 MB, n_mem = 
12288\ngpt2_model_load: model size  =   239.08 MB\nmain: number of tokens in prompt = 1\n\nSo this is going to be the end of the line for us.\n\nIf the Dolphins continue to do their business, it's possible that the team could make a bid to bring in new defensive coordinator Scott Linehan.\n\nLinehan's job is a little daunting, but he's a great coach and an excellent coach. I don't believe we're going to make the playoffs.\n\nWe're going to have to work hard to keep our heads down and get ready to go.<|endoftext|>\n\nmain: mem per token =  2048612 bytes\nmain:     load time =   106.32 ms\nmain:   sample time =     7.10 ms\nmain:  predict time =   506.40 ms / 5.06 ms per token\nmain:    total time =   629.84 ms\n```\n\n## Downloading and converting the original models (GPT-2)\n\nYou can download the original model files using the [download-model.sh](download-model.sh) Bash script. The models are\nin Tensorflow format, so in order to use them with ggml, you need to convert them to appropriate format. 
This is done\nvia the [convert-ckpt-to-ggml.py](convert-ckpt-to-ggml.py) python script.\n\nHere is the entire process for the GPT-2 117M model (download from official site + conversion):\n\n```bash\ncd ggml/build\n../examples/gpt-2/download-model.sh 117M\n\nDownloading model 117M ...\nmodels/gpt-2-117M/checkpoint                      100%[=============================>]      77  --.-KB/s    in 0s\nmodels/gpt-2-117M/encoder.json                    100%[=============================>]   1018K  1.20MB/s    in 0.8s\nmodels/gpt-2-117M/hparams.json                    100%[=============================>]      90  --.-KB/s    in 0s\nmodels/gpt-2-117M/model.ckpt.data-00000-of-00001  100%[=============================>] 474.70M  1.21MB/s    in 8m 39s\nmodels/gpt-2-117M/model.ckpt.index                100%[=============================>]   5.09K  --.-KB/s    in 0s\nmodels/gpt-2-117M/model.ckpt.meta                 100%[=============================>] 460.11K   806KB/s    in 0.6s\nmodels/gpt-2-117M/vocab.bpe                       100%[=============================>] 445.62K   799KB/s    in 0.6s\nDone! Model '117M' saved in 'models/gpt-2-117M/'\n\nRun the convert-ckpt-to-ggml.py script to convert the model to ggml format.\n\n  python /Users/john/ggml/examples/gpt-2/convert-ckpt-to-ggml.py models/gpt-2-117M/ 1\n\n```\n\nThis conversion requires that you have python and Tensorflow installed on your computer. 
Still, if you want to avoid\nthis, you can download the already converted ggml models as described below.\n\n## Downloading and converting the original models (Cerebras-GPT)\n\nClone the respective repository from here: https://huggingface.co/cerebras\n\nUse the [convert-cerebras-to-ggml.py](convert-cerebras-to-ggml.py) script to convert the model to `ggml` format:\n\n```bash\ncd ggml/build\ngit clone https://huggingface.co/cerebras/Cerebras-GPT-111M models/\npython ../examples/gpt-2/convert-cerebras-to-ggml.py models/Cerebras-GPT-111M/\n\n```\n\n## Downloading the ggml model directly (GPT-2)\n\nFor convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples. This\nway, you can directly download a single binary file and start using it. No python or Tensorflow is required.\n\nHere is how to get the 117M ggml model:\n\n```bash\ncd ggml/build\n../examples/gpt-2/download-ggml-model.sh 117M\n\nDownloading ggml model 117M ...\nmodels/gpt-2-117M/ggml-model.bin         100%[===============================>] 239.58M  8.52MB/s    in 28s\nDone! Model '117M' saved in 'models/gpt-2-117M/ggml-model.bin'\nYou can now use it like this:\n\n  $ ./bin/gpt-2 -m models/gpt-2-117M/ggml-model.bin -p \"This is an example\"\n\n```\n\nAt some point, I might decide to stop hosting these models. 
So in that case, simply revert to the manual process above.\n\n## Quantizing the models\n\nYou can also try to quantize the `ggml` models via 4-bit integer quantization.\nKeep in mind that for smaller models, this will render them completely useless.\nYou generally want to quantize larger models.\n\n```bash\n# quantize GPT-2 F16 to Q4_0 (faster but less precise)\n./bin/gpt-2-quantize models/gpt-2-1558M/ggml-model-f16.bin models/gpt-2-1558M/ggml-model-q4_0.bin 2\n./bin/gpt-2 -m models/gpt-2-1558M/ggml-model-q4_0.bin -p \"This is an example\"\n\n# quantize Cerebras F16 to Q4_1 (slower but more precise)\n./bin/gpt-2-quantize models/Cerebras-GPT-6.7B/ggml-model-f16.bin models/Cerebras-GPT-6.7B/ggml-model-q4_1.bin 3\n./bin/gpt-2 -m models/Cerebras-GPT-6.7B/ggml-model-q4_1.bin -p \"This is an example\"\n\n```\n\n## Batched generation example\n\nYou can try the batched generation from a given prompt using the gpt-2-batched binary.\n\nSample output:\n\n```bash\n$ gpt-2-batched -np 5 -m models/gpt-2-117M/ggml-model.bin -p \"Hello my name is\" -n 50\n\nmain: seed = 1697037431\ngpt2_model_load: loading model from 'models/gpt-2-117M/ggml-model.bin'\ngpt2_model_load: n_vocab = 50257\ngpt2_model_load: n_ctx   = 1024\ngpt2_model_load: n_embd  = 768\ngpt2_model_load: n_head  = 12\ngpt2_model_load: n_layer = 12\ngpt2_model_load: ftype   = 1\ngpt2_model_load: qntvr   = 0\ngpt2_model_load: ggml tensor size    = 320 bytes\ngpt2_model_load: backend buffer size = 312.72 MB\nggml_init_cublas: found 1 CUDA devices:\n  Device 0: NVIDIA GeForce GTX 1660, compute capability 7.5\ngpt2_model_load: using CPU backend\ngpt2_model_load: memory size =    72.00 MB, n_mem = 12288\ngpt2_model_load: model size  =   239.08 MB\nextract_tests_from_file : No test file found.\ntest_gpt_tokenizer : 0 tests failed out of 0 tests.\nmain: compute buffer size: 3.26 MB\n\n\nmain: generating 5 sequences ...\nmain: prompt: 'Hello my name is'\nmain: number of tokens in prompt = 4, first 8 tokens: 15496 616 1438 
318\n\n\nsequence 0:\n\nHello my name is John. You can call me any way you want, if you want, but for my very first date, I will be on the phone with you. We're both in our early 20s, but I feel like it's all\n\nsequence 1:\n\nHello my name is Robert, and I want to say that we're proud to have your company here on the world's largest platform for sharing your stories with us. This is a huge opportunity for our community. We have hundreds of people on this team and\n\nsequence 2:\n\nHello my name is Jack. I'm the one who created you.\n\nJack is a boy with a big smile and a big heart. He is a handsome guy. He loves the outdoors and loves the people he meets. He wants to be a\n\nsequence 3:\n\nHello my name is John. I am a Canadian citizen with a large number of family in Quebec and I am interested in studying. My aim is to take up a post in the Journal of the International Academy of Sciences of Canada which I am currently finishing.\n\nsequence 4:\n\nHello my name is Dan. I am an entrepreneur. I am a great father. I am a great husband. I am a great husband. I am a great dad. And I am a great husband.\n\nI love my life. I love\n\n\n\nmain:     load time =   880.80 ms\nmain:   sample time =    91.43 ms\nmain:  predict time =  2518.29 ms\nmain:    total time =  3544.32 ms\n```\n"
  },
  {
    "path": "examples/gpt-2/convert-cerebras-to-ggml.py",
    "content": "# Convert Cerebras models to ggml format\n#\n# ref: https://www.cerebras.net/blog/cerebras-gpt-a-family-of-open-compute-efficient-large-language-models/\n#\n\nimport sys\nimport struct\nimport json\nimport torch\nimport numpy as np\nimport re\n\nfrom transformers import AutoModelForCausalLM\n\n# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a signficant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8+n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\nif len(sys.argv) < 2:\n    print(\"Usage: convert-cerebras-to-ggml.py dir-model [use-f32]\\n\")\n    sys.exit(1)\n\n# output in the same directory as the model\ndir_model = sys.argv[1]\nfname_out = sys.argv[1] + \"/ggml-model-f16.bin\"\n\nwith open(dir_model + \"/vocab.json\", \"r\", encoding=\"utf-8\") as f:\n    encoder = json.load(f)\n\nwith open(dir_model + \"/config.json\", \"r\", encoding=\"utf-8\") as f:\n    hparams = json.load(f)\n\n# use 16-bit or 32-bit floats\nuse_f16 = True\nif len(sys.argv) > 2:\n    use_f16 = False\n    fname_out = sys.argv[1] + \"/ggml-model-f32.bin\"\n\nmodel = AutoModelForCausalLM.from_pretrained(dir_model, 
low_cpu_mem_usage=True)\n#print (model)\n\nlist_vars = model.state_dict()\n#print (list_vars)\n\nprint(hparams)\n\nfout = open(fname_out, \"wb\")\n\nfout.write(struct.pack(\"i\", 0x67676d6c)) # magic: ggml in hex\nfout.write(struct.pack(\"i\", hparams[\"vocab_size\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_positions\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_embd\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_head\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_layer\"]))\nfout.write(struct.pack(\"i\", use_f16))\n\nbyte_encoder = bytes_to_unicode()\nbyte_decoder = {v:k for k, v in byte_encoder.items()}\n\nfout.write(struct.pack(\"i\", len(encoder)))\n\nfor key in encoder:\n    text = bytearray([byte_decoder[c] for c in key])\n    fout.write(struct.pack(\"i\", len(text)))\n    fout.write(text)\n\nfor name in list_vars.keys():\n    data = list_vars[name].squeeze().numpy()\n    print(\"Processing variable: \" + name + \" with shape: \", data.shape)\n\n    # rename headers to keep compatibility\n    if name == \"transformer.ln_f.weight\":\n        name = \"model/ln_f/g\"\n    elif name == \"transformer.ln_f.bias\":\n        name = \"model/ln_f/b\"\n    elif name == \"transformer.wte.weight\":\n        name = \"model/wte\"\n    elif name == \"transformer.wpe.weight\":\n        name = \"model/wpe\"\n    elif name == \"lm_head.weight\":\n        name = \"model/lm_head\"\n    elif re.match(r\"transformer.h\\.\\d+\\.ln_1\\.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/ln_1/g\"\n    elif re.match(r\"transformer.h\\.\\d+\\.ln_1\\.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/ln_1/b\"\n    elif re.match(r\"transformer.h\\.\\d+\\.attn\\.c_attn\\.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/attn/c_attn/w\"\n    elif re.match(r\"transformer.h\\.\\d+\\.attn\\.c_attn\\.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = 
f\"model/h{i}/attn/c_attn/b\"\n    elif re.match(r\"transformer.h\\.\\d+\\.attn\\.c_proj\\.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/attn/c_proj/w\"\n    elif re.match(r\"transformer.h.\\d+.attn.c_proj.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/attn/c_proj/b\"\n    elif re.match(r\"transformer.h.\\d+.ln_2.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/ln_2/g\"\n    elif re.match(r\"transformer.h.\\d+.ln_2.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/ln_2/b\"\n    elif re.match(r\"transformer.h.\\d+.mlp.c_fc.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/mlp/c_fc/w\"\n    elif re.match(r\"transformer.h.\\d+.mlp.c_fc.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/mlp/c_fc/b\"\n    elif re.match(r\"transformer.h.\\d+.mlp.c_proj.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/mlp/c_proj/w\"\n    elif re.match(r\"transformer.h.\\d+.mlp.c_proj.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/mlp/c_proj/b\"\n    else:\n        print(\"Unrecognized variable name. 
%s\", name)\n\n    # we don't need these\n    if name.endswith(\"attn.masked_bias\") or name.endswith(\".attn.bias\"):\n        print(\"  Skipping variable: \" + name)\n        continue\n\n    n_dims = len(data.shape);\n\n    # ftype == 0 -> float32, ftype == 1 -> float16\n    ftype = 0;\n    if use_f16:\n        if (name == \"model/wte\" or name == \"model/lm_head\" or name[-2:] == \"/g\" or name[-2:] == \"/w\") and n_dims == 2:\n            print(\"  Converting to float16\")\n            data = data.astype(np.float16)\n            ftype = 1\n        else:\n            print(\"  Converting to float32\")\n            data = data.astype(np.float32)\n            ftype = 0\n\n    # for efficiency - transpose the projection matrices\n    # \"model/h.*/attn/c_attn/w\"\n    # \"model/h.*/attn/c_proj/w\"\n    # \"model/h.*/mlp/c_fc/w\"\n    # \"model/h.*/mlp/c_proj/w\"\n    if name[-14:] == \"/attn/c_attn/w\" or \\\n       name[-14:] == \"/attn/c_proj/w\" or \\\n       name[-11:] == \"/mlp/c_fc/w\" or \\\n       name[-13:] == \"/mlp/c_proj/w\":\n        print(\"  Transposing\")\n        data = data.transpose()\n\n    # header\n    str = name.encode('utf-8')\n    fout.write(struct.pack(\"iii\", n_dims, len(str), ftype))\n    for i in range(n_dims):\n        fout.write(struct.pack(\"i\", data.shape[n_dims - 1 - i]))\n    fout.write(str);\n\n    # data\n    data.tofile(fout)\n\nfout.close()\n\nprint(\"Done. Output file: \" + fname_out)\nprint(\"\")\n"
  },
  {
    "path": "examples/gpt-2/convert-ckpt-to-ggml.py",
    "content": "# Convert a model checkpoint to a ggml compatible file\n#\n# Load the model using TensorFlow.\n# Iterate over all variables and write them to a binary file.\n#\n# For each variable, write the following:\n#   - Number of dimensions (int)\n#   - Name length (int)\n#   - Float type (int: 0 = float32, 1 = float16)\n#   - Dimensions (int[n_dims], in reverse order of the numpy shape)\n#   - Name (char[name_length])\n#   - Data (raw tensor values, stored in the float type above)\n#\n# The output float type is chosen with the required \"ftype\" CLI argument\n# (0 = float32, 1 = float16). With ftype 1, \"model/wte\" and the \"/w\"\n# weight matrices are stored as 16-bit floats; all other tensors as 32-bit floats.\n#\n# At the start of the ggml file we write the model parameters\n# and vocabulary.\n#\n\nimport sys\nimport json\nimport struct\nimport numpy as np\nimport tensorflow as tf\n\n# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a signficant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8+n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\n# helper method to convert a numpy array to different float types\ndef convert_to_ftype(data, ftype):\n    # fp16\n    if ftype == 1:\n        return data.astype(np.float16)\n\n    assert False, \"Invalid ftype: \" + str(ftype)\n\nif len(sys.argv) < 3:\n    print(\"Usage: 
convert-ckpt-to-ggml.py dir-model ftype\\n\")\n    print(\"  ftype == 0 -> float32\")\n    print(\"  ftype == 1 -> float16\")\n    sys.exit(1)\n\n# output in the same directory as the model\ndir_model = sys.argv[1]\nfname_out = sys.argv[1] + \"/ggml-model.bin\"\n\nwith open(dir_model + \"/encoder.json\", \"r\", encoding=\"utf-8\") as f:\n    encoder = json.load(f)\n\nwith open(dir_model + \"/hparams.json\", \"r\", encoding=\"utf-8\") as f:\n    hparams = json.load(f)\n\n# possible data types\n#   ftype == 0 -> float32\n#   ftype == 1 -> float16\n#\n# map from ftype to string\nftype_str = [\"f32\", \"f16\"]\n\nftype = 1\nif len(sys.argv) > 2:\n    ftype = int(sys.argv[2])\n    if ftype < 0 or ftype > 1:\n        print(\"Invalid ftype: \" + str(ftype))\n        sys.exit(1)\n    fname_out = sys.argv[1] + \"/ggml-model-\" + ftype_str[ftype] + \".bin\"\n\nlist_vars = tf.train.list_variables(dir_model)\n\nfout = open(fname_out, \"wb\")\n\nfout.write(struct.pack(\"i\", 0x67676d6c)) # magic: ggml in hex\nfout.write(struct.pack(\"i\", hparams[\"n_vocab\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_ctx\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_embd\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_head\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_layer\"]))\nfout.write(struct.pack(\"i\", ftype))\n\nbyte_encoder = bytes_to_unicode()\nbyte_decoder = {v:k for k, v in byte_encoder.items()}\n\nfout.write(struct.pack(\"i\", len(encoder)))\n\nfor key in encoder:\n    text = bytearray([byte_decoder[c] for c in key])\n    fout.write(struct.pack(\"i\", len(text)))\n    fout.write(text)\n\nfor name, shape in list_vars:\n    print(\"Processing variable: \" + name + \" with shape: \", shape)\n\n    data = tf.train.load_variable(dir_model, name).squeeze()\n    n_dims = len(data.shape);\n\n    # for efficiency - transpose the projection matrices\n    # \"model/h.*/attn/c_attn/w\"\n    # \"model/h.*/attn/c_proj/w\"\n    # \"model/h.*/mlp/c_fc/w\"\n    # \"model/h.*/mlp/c_proj/w\"\n 
   if name[-14:] == \"/attn/c_attn/w\" or \\\n       name[-14:] == \"/attn/c_proj/w\" or \\\n       name[-11:] == \"/mlp/c_fc/w\" or \\\n       name[-13:] == \"/mlp/c_proj/w\":\n        print(\"  Transposing\")\n        data = data.transpose()\n\n    dshape = data.shape\n\n    ftype_cur = 0\n    if ftype != 0:\n        # match name:\n        #  \"model/wte\"\n        #  \"model/h.*/attn/c_attn/w\"\n        #  \"model/h.*/attn/c_proj/w\"\n        #  \"model/h.*/mlp/c_fc/w\"\n        #  \"model/h.*/mlp/c_proj/w\"\n        if name == \"model/wte\" or name[-2:] == \"/w\":\n            print(\"  Converting to \" + ftype_str[ftype])\n            data = convert_to_ftype(data, ftype)\n            ftype_cur = ftype\n        else:\n            print(\"  Converting to float32\")\n            data = data.astype(np.float32)\n            ftype_cur = 0\n\n    # header\n    str = name.encode('utf-8')\n    fout.write(struct.pack(\"iii\", n_dims, len(str), ftype_cur))\n    for i in range(n_dims):\n        fout.write(struct.pack(\"i\", dshape[n_dims - 1 - i]))\n    fout.write(str);\n\n    # data\n    data.tofile(fout)\n\nfout.close()\n\nprint(\"Done. Output file: \" + fname_out)\nprint(\"\")\n"
  },
  {
    "path": "examples/gpt-2/convert-h5-to-ggml.py",
    "content": "# Convert GPT-2 h5 transformer model to ggml format\n#\n# Load the model using GPT2Model.\n# Iterate over all variables and write them to a binary file.\n#\n# For each variable, write the following:\n#   - Number of dimensions (int)\n#   - Name length (int)\n#   - Float type (int: 0 = float32, 1 = float16)\n#   - Dimensions (int[n_dims], in reverse order of the numpy shape)\n#   - Name (char[name_length])\n#   - Data (raw tensor values, stored in the float type above)\n#\n# By default, the 2D \".weight\" matrices are converted to 16-bit floats\n# and all other tensors to 32-bit floats. This can be disabled by passing\n# any extra CLI argument (e.g. \"use-f32\").\n#\n# At the start of the ggml file we write the model parameters\n# and vocabulary.\n#\n\nimport sys\nimport struct\nimport json\nimport numpy as np\nimport re\n\nfrom transformers import GPT2Model\n\n# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a signficant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8+n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\nif len(sys.argv) < 2:\n    print(\"Usage: convert-h5-to-ggml.py dir-model [use-f32]\\n\")\n    sys.exit(1)\n\n# output in the same directory as the model\ndir_model = sys.argv[1]\nfname_out = sys.argv[1] + \"/ggml-model.bin\"\n\nwith open(dir_model + \"/vocab.json\", 
\"r\", encoding=\"utf-8\") as f:\n    encoder = json.load(f)\n\nwith open(dir_model + \"/added_tokens.json\", \"r\", encoding=\"utf-8\") as f:\n    encoder_added = json.load(f)\n\nwith open(dir_model + \"/config.json\", \"r\", encoding=\"utf-8\") as f:\n    hparams = json.load(f)\n\n# use 16-bit or 32-bit floats\nuse_f16 = True\nif len(sys.argv) > 2:\n    use_f16 = False\n    fname_out = sys.argv[1] + \"/ggml-model-f32.bin\"\n\nmodel = GPT2Model.from_pretrained(dir_model, low_cpu_mem_usage=True)\n#print (model)\n\nlist_vars = model.state_dict()\n#print (list_vars)\n\nfout = open(fname_out, \"wb\")\n\nfout.write(struct.pack(\"i\", 0x67676d6c)) # magic: ggml in hex\nfout.write(struct.pack(\"i\", hparams[\"vocab_size\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_positions\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_embd\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_head\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_layer\"]))\n#fout.write(struct.pack(\"i\", hparams[\"rotary_dim\"]))\nfout.write(struct.pack(\"i\", use_f16))\n\nbyte_encoder = bytes_to_unicode()\nbyte_decoder = {v:k for k, v in byte_encoder.items()}\n\nfout.write(struct.pack(\"i\", len(encoder) + len(encoder_added)))\n\nfor key in encoder:\n    text = bytearray([byte_decoder[c] for c in key])\n    fout.write(struct.pack(\"i\", len(text)))\n    fout.write(text)\n\nfor key in encoder_added:\n    text = bytearray([byte_decoder[c] for c in key])\n    fout.write(struct.pack(\"i\", len(text)))\n    fout.write(text)\n\nfor name in list_vars.keys():\n    data = list_vars[name].squeeze().numpy()\n    print(\"Processing variable: \" + name + \" with shape: \", data.shape)\n\n    # we don't need these\n    if name.endswith(\"attn.masked_bias\") or name.endswith(\".attn.bias\"):\n        print(\"  Skipping variable: \" + name)\n        continue\n\n    n_dims = len(data.shape);\n\n    # ftype == 0 -> float32, ftype == 1 -> float16\n    ftype = 0;\n    if use_f16:\n        if name[-7:] == \".weight\" and 
n_dims == 2:\n            print(\"  Converting to float16\")\n            data = data.astype(np.float16)\n            ftype = 1\n        else:\n            print(\"  Converting to float32\")\n            data = data.astype(np.float32)\n            ftype = 0\n\n    # for efficiency - transpose these matrices:\n    #  \"transformer.h.*.mlp.c_proj.weight\n    if name.endswith(\".mlp.c_proj.weight\"):\n        print(\"  Transposing\")\n        data = data.transpose()\n\n    # rename headers to keep compatibility\n    if name == \"ln_f.weight\":\n        name = \"model/ln_f/g\"\n    elif name == \"ln_f.bias\":\n        name = \"model/ln_f/b\"\n    elif name == \"wte.weight\":\n        name = \"model/wte\"\n    elif name == \"wpe.weight\":\n        name = \"model/wpe\"\n    elif re.match(r\"h\\.\\d+\\.ln_1\\.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/ln_1/g\"\n    elif re.match(r\"h\\.\\d+\\.ln_1\\.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/ln_1/b\"\n    elif re.match(r\"h\\.\\d+\\.attn\\.c_attn\\.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/attn/c_attn/w\"\n    elif re.match(r\"h\\.\\d+\\.attn\\.c_attn\\.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/attn/c_attn/b\"\n    elif re.match(r\"h\\.\\d+\\.attn\\.c_proj\\.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/attn/c_proj/w\"\n    elif re.match(r\"h.\\d+.attn.c_proj.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/attn/c_proj/b\"\n    elif re.match(r\"h.\\d+.ln_2.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/ln_2/g\"\n    elif re.match(r\"h.\\d+.ln_2.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/ln_2/b\"\n    elif re.match(r\"h.\\d+.mlp.c_fc.weight\", name):\n        i = 
re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/mlp/c_fc/w\"\n    elif re.match(r\"h.\\d+.mlp.c_fc.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/mlp/c_fc/b\"\n    elif re.match(r\"h.\\d+.mlp.c_proj.weight\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/mlp/c_proj/w\"\n    elif re.match(r\"h.\\d+.mlp.c_proj.bias\", name):\n        i = re.findall(\"\\d+\", name)[0]\n        name = f\"model/h{i}/mlp/c_proj/b\"\n    else:\n        print(\"Unrecognized variable name. %s\", name)\n\n    str = name.encode('utf-8')\n\n    fout.write(struct.pack(\"iii\", n_dims, len(str), ftype))\n    for i in range(n_dims):\n        fout.write(struct.pack(\"i\", data.shape[n_dims - 1 - i]))\n    fout.write(str);\n\n    # data\n    data.tofile(fout)\n\nfout.close()\n\nprint(\"Done. Output file: \" + fname_out)\nprint(\"\")\n"
  },
  {
    "path": "examples/gpt-2/download-ggml-model.sh",
    "content": "#!/bin/bash\n\n# This script downloads GPT-2 model files that have already been converted to ggml format.\n# This way you don't have to convert them yourself.\n#\n# If you want to download the original GPT-2 model files, use the \"download-model.sh\" script instead.\n\n#src=\"https://ggml.ggerganov.com\"\n#pfx=\"ggml-model-gpt-2\"\n\nsrc=\"https://huggingface.co/ggerganov/ggml\"\npfx=\"resolve/main/ggml-model-gpt-2\"\n\nggml_path=$(dirname $(realpath $0))\n\n# GPT-2 models\nmodels=( \"117M\" \"345M\" \"774M\" \"1558M\" )\n\n# list available models\nfunction list_models {\n    printf \"\\n\"\n    printf \"  Available models:\"\n    for model in \"${models[@]}\"; do\n        printf \" $model\"\n    done\n    printf \"\\n\\n\"\n}\n\nif [ \"$#\" -ne 1 ]; then\n    printf \"Usage: $0 <model>\\n\"\n    list_models\n\n    exit 1\nfi\n\nmodel=$1\n\nif [[ ! \" ${models[@]} \" =~ \" ${model} \" ]]; then\n    printf \"Invalid model: $model\\n\"\n    list_models\n\n    exit 1\nfi\n\n# download ggml model\n\nprintf \"Downloading ggml model $model ...\\n\"\n\nmkdir -p models/gpt-2-$model\n\nif [ -x \"$(command -v wget)\" ]; then\n    wget --quiet --show-progress -O models/gpt-2-$model/ggml-model.bin $src/$pfx-$model.bin\nelif [ -x \"$(command -v curl)\" ]; then\n    curl -L --output models/gpt-2-$model/ggml-model.bin $src/$pfx-$model.bin\nelse\n    printf \"Either wget or curl is required to download models.\\n\"\n    exit 1\nfi\n\nif [ $? -ne 0 ]; then\n    printf \"Failed to download ggml model $model \\n\"\n    printf \"Please try again later or download the original GPT-2 model files and convert them yourself.\\n\"\n    exit 1\nfi\n\nprintf \"Done! Model '$model' saved in 'models/gpt-2-$model/ggml-model.bin'\\n\"\nprintf \"You can now use it like this:\\n\\n\"\nprintf \"  $ ./bin/gpt-2 -m models/gpt-2-$model/ggml-model.bin -p \\\"This is an example\\\"\\n\"\nprintf \"\\n\"\n"
  },
  {
    "path": "examples/gpt-2/download-model.sh",
    "content": "#!/bin/bash\n\nggml_path=$(dirname $(realpath $0))\n\n# GPT-2 models\nmodels=( \"117M\" \"345M\" \"774M\" \"1558M\" )\n\n# list available models\nfunction list_models {\n    printf \"\\n\"\n    printf \"  Available models:\"\n    for model in \"${models[@]}\"; do\n        printf \" $model\"\n    done\n    printf \"\\n\\n\"\n}\n\nif [ \"$#\" -ne 1 ]; then\n    printf \"Usage: $0 <model>\\n\"\n    list_models\n\n    exit 1\nfi\n\nmodel=$1\n\nif [[ ! \" ${models[@]} \" =~ \" ${model} \" ]]; then\n    printf \"Invalid model: $model\\n\"\n    list_models\n\n    exit 1\nfi\n\n# download model\n\nprintf \"Downloading model $model ...\\n\"\n\nmkdir -p models/gpt-2-$model\n\nfor file in checkpoint encoder.json hparams.json model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta vocab.bpe; do\n    wget --quiet --show-progress -O models/gpt-2-$model/$file https://openaipublic.blob.core.windows.net/gpt-2/models/$model/$file\ndone\n\nprintf \"Done! Model '$model' saved in 'models/gpt-2-$model/'\\n\\n\"\nprintf \"Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.\\n\"\nprintf \"\\n\"\nprintf \"  python $ggml_path/convert-ckpt-to-ggml.py models/gpt-2-$model/\\n\"\nprintf \"\\n\"\n"
  },
  {
    "path": "examples/gpt-2/main-alloc.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#include \"common.h\"\n#include \"common-ggml.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\n// default hparams (GPT-2 117M)\nstruct gpt2_hparams {\n    int32_t n_vocab = 50257;\n    int32_t n_ctx   = 1024;\n    int32_t n_embd  = 768;\n    int32_t n_head  = 12;\n    int32_t n_layer = 12;\n    int32_t ftype   = 1;\n    float   eps     = 1e-5f;\n};\n\nstruct gpt2_layer {\n    // normalization\n    struct ggml_tensor * ln_1_g;\n    struct ggml_tensor * ln_1_b;\n\n    struct ggml_tensor * ln_2_g;\n    struct ggml_tensor * ln_2_b;\n\n    // attention\n    struct ggml_tensor * c_attn_attn_w;\n    struct ggml_tensor * c_attn_attn_b;\n\n    struct ggml_tensor * c_attn_proj_w;\n    struct ggml_tensor * c_attn_proj_b;\n\n    // mlp\n    struct ggml_tensor * c_mlp_fc_w;\n    struct ggml_tensor * c_mlp_fc_b;\n\n    struct ggml_tensor * c_mlp_proj_w;\n    struct ggml_tensor * c_mlp_proj_b;\n};\n\nstruct gpt2_model {\n    gpt2_hparams hparams;\n\n    // normalization\n    struct ggml_tensor * ln_f_g;\n    struct ggml_tensor * ln_f_b;\n\n    struct ggml_tensor * wte;     //    token embedding\n    struct ggml_tensor * wpe;     // position embedding\n    struct ggml_tensor * lm_head; // language model head\n\n    std::vector<gpt2_layer> layers;\n\n    // key + value memory\n    struct ggml_tensor * memory_k;\n    struct ggml_tensor * memory_v;\n\n    //\n    struct ggml_context * ctx_w;\n    std::map<std::string, struct ggml_tensor *> tensors;\n};\n\n// load the model's weights from a file\nbool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {\n    printf(\"%s: loading model from '%s'\\n\", __func__, fname.c_str());\n\n    
auto fin = std::ifstream(fname, std::ios::binary);\n    if (!fin) {\n        fprintf(stderr, \"%s: failed to open '%s'\\n\", __func__, fname.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        fin.read((char *) &magic, sizeof(magic));\n        if (magic != GGML_FILE_MAGIC) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, fname.c_str());\n            return false;\n        }\n    }\n\n    // load hparams\n    {\n        auto & hparams = model.hparams;\n\n        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));\n\n        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;\n\n        printf(\"%s: n_vocab = %d\\n\", __func__, hparams.n_vocab);\n        printf(\"%s: n_ctx   = %d\\n\", __func__, hparams.n_ctx);\n        printf(\"%s: n_embd  = %d\\n\", __func__, hparams.n_embd);\n        printf(\"%s: n_head  = %d\\n\", __func__, hparams.n_head);\n        printf(\"%s: n_layer = %d\\n\", __func__, hparams.n_layer);\n        printf(\"%s: ftype   = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: qntvr   = %d\\n\", __func__, qntvr);\n\n        hparams.ftype %= GGML_QNT_VERSION_FACTOR;\n    }\n\n    // load vocab\n    {\n        int32_t n_vocab = 0;\n        fin.read((char *) &n_vocab, sizeof(n_vocab));\n\n        if (n_vocab != model.hparams.n_vocab) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad vocab size %d != %d)\\n\",\n                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);\n            return false;\n        }\n\n        std::string word;\n        std::vector<char> 
buf(128);\n\n        for (int i = 0; i < n_vocab; i++) {\n            uint32_t len;\n            fin.read((char *) &len, sizeof(len));\n\n            buf.resize(len);\n            fin.read((char *) buf.data(), len);\n            word.assign(buf.data(), len);\n\n            vocab.token_to_id[word] = i;\n            vocab.id_to_token[i] = word;\n        }\n    }\n\n    // for the big tensors, we have the option to store the data in 16-bit floats or quantized\n    // in order to save memory and also to speed up the computation\n    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));\n    if (wtype == GGML_TYPE_COUNT) {\n        fprintf(stderr, \"%s: invalid model file '%s' (bad ftype value %d)\\n\",\n                __func__, fname.c_str(), model.hparams.ftype);\n        return false;\n    }\n\n    auto & ctx = model.ctx_w;\n\n    size_t ctx_size = 0;\n\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g\n        ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b\n\n        ctx_size += ggml_row_size(wtype,         n_vocab*n_embd); // wte\n        ctx_size += ggml_row_size(GGML_TYPE_F32  , n_ctx*n_embd); // wpe\n        ctx_size += ggml_row_size(wtype,         n_vocab*n_embd); // lm_head\n\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b\n\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         3*n_embd*n_embd)); // c_attn_attn_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd));        // 
c_attn_attn_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         n_embd*n_embd));   // c_attn_proj_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd));          // c_attn_proj_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_fc_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_fc_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_proj_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_proj_b\n\n        ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k\n        ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v\n\n        ctx_size += (6 + 12*n_layer)*512; // object overhead\n\n        printf(\"%s: ggml tensor size = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n        printf(\"%s: ggml ctx size = %6.2f MB\\n\", __func__, ctx_size/(1024.0*1024.0));\n    }\n\n    // create the ggml context\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ ctx_size,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ false,\n        };\n\n        model.ctx_w = ggml_init(params);\n        if (!model.ctx_w) {\n            fprintf(stderr, \"%s: ggml_init() failed\\n\", __func__);\n            return false;\n        }\n    }\n\n    // prepare memory for the weights\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        model.layers.resize(n_layer);\n\n        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n\n        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n        model.wpe     = 
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);\n        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n\n        // map by name\n        model.tensors[\"model/ln_f/g\"] = model.ln_f_g;\n        model.tensors[\"model/ln_f/b\"] = model.ln_f_b;\n\n        model.tensors[\"model/wte\"]     = model.wte;\n        model.tensors[\"model/wpe\"]     = model.wpe;\n        model.tensors[\"model/lm_head\"] = model.lm_head;\n\n        for (int i = 0; i < n_layer; ++i) {\n            auto & layer = model.layers[i];\n\n            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);\n            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);\n\n            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);\n            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);\n            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);\n\n            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);\n            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            // map by name\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/g\"]        = layer.ln_1_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/b\"]        = layer.ln_1_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/g\"]        = layer.ln_2_g;\n            model.tensors[\"model/h\" + std::to_string(i) + 
\"/ln_2/b\"]        = layer.ln_2_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/w\"] = layer.c_attn_attn_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/b\"] = layer.c_attn_attn_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/w\"] = layer.c_attn_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/b\"] = layer.c_attn_proj_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/w\"]    = layer.c_mlp_fc_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/b\"]    = layer.c_mlp_fc_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/w\"]  = layer.c_mlp_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/b\"]  = layer.c_mlp_proj_b;\n        }\n    }\n\n    // key + value memory\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n\n        const int n_mem      = n_layer*n_ctx;\n        const int n_elements = n_embd*n_mem;\n\n        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n\n        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);\n\n        printf(\"%s: memory size = %8.2f MB, n_mem = %d\\n\", __func__, memory_size/1024.0/1024.0, n_mem);\n    }\n\n    // load weights\n    {\n        size_t total_size = 0;\n\n        bool has_lm_head = false;\n\n        while (true) {\n            int32_t n_dims;\n            int32_t length;\n            int32_t ttype;\n\n            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));\n            fin.read(reinterpret_cast<char *>(&length), sizeof(length));\n            fin.read(reinterpret_cast<char *>(&ttype),  
sizeof(ttype));\n\n            if (fin.eof()) {\n                break;\n            }\n\n            int32_t nelements = 1;\n            int32_t ne[2] = { 1, 1 };\n            for (int i = 0; i < n_dims; ++i) {\n                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));\n                nelements *= ne[i];\n            }\n\n            std::string name(length, 0);\n            fin.read(&name[0], length);\n\n            if (model.tensors.find(name) == model.tensors.end()) {\n                fprintf(stderr, \"%s: unknown tensor '%s' in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            auto tensor = model.tensors[name];\n            if (ggml_nelements(tensor) != nelements) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\\n\",\n                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);\n                return false;\n            }\n\n            // for debugging\n            if (0) {\n                printf(\"%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\\n\", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));\n            }\n\n            const size_t bpe = ggml_type_size(ggml_type(ttype));\n\n            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\\n\",\n                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);\n                return false;\n            }\n\n            fin.read(reinterpret_cast<char *>(tensor->data), 
ggml_nbytes(tensor));\n\n            // GPT-2 models share the WTE tensor as the LM head\n            if (name == \"model/wte\" && has_lm_head == false) {\n                memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));\n            }\n\n            if (name == \"model/lm_head\") {\n                has_lm_head = true;\n            }\n\n            total_size += ggml_nbytes(tensor);\n        }\n\n        printf(\"%s: model size  = %8.2f MB\\n\", __func__, total_size/1024.0/1024.0);\n    }\n\n    fin.close();\n\n    return true;\n}\n\n// build the computation graph\nstruct ggml_cgraph * gpt2_graph(\n        const gpt2_model & model,\n        const int n_past,\n        const int n_tokens) {\n    const int N = n_tokens;\n\n    const auto & hparams = model.hparams;\n\n    const int n_embd  = hparams.n_embd;\n    const int n_layer = hparams.n_layer;\n    const int n_ctx   = hparams.n_ctx;\n    const int n_head  = hparams.n_head;\n\n    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data\n    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    struct ggml_context * ctx = ggml_init(params);\n\n    struct ggml_cgraph  * gf = ggml_new_graph(ctx);\n\n    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);\n    // at this point, the tensor data is not allocated yet and cannot be set\n    // we will find the tensor after the graph is allocated by its name, and set the data then\n    ggml_set_name(embd, \"embd\");\n    // setting a tensor as an input will ensure that it is allocated at the beginning of the graph\n    // this is 
important to ensure that the input tensors are not overwritten before they are used\n    ggml_set_input(embd);\n\n    struct ggml_tensor * position = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);\n    ggml_set_name(position, \"position\");\n    ggml_set_input(position);\n\n    // wte + wpe\n    struct ggml_tensor * inpL =\n        ggml_add(ctx,\n                ggml_get_rows(ctx, model.wte, embd),\n                ggml_get_rows(ctx, model.wpe, position));\n\n    for (int il = 0; il < n_layer; ++il) {\n        struct ggml_tensor * cur;\n\n        // norm\n        {\n            // [ 768, N]\n            cur = ggml_norm(ctx, inpL, hparams.eps);\n\n            // cur = ln_1_g*cur + ln_1_b\n            // [ 768, N]\n            cur = ggml_add(ctx,\n                    ggml_mul(ctx,\n                        ggml_repeat(ctx, model.layers[il].ln_1_g, cur),\n                        cur),\n                    ggml_repeat(ctx, model.layers[il].ln_1_b, cur));\n        }\n\n        // attn\n        // [2304, 768] - model.layers[il].c_attn_attn_w\n        // [2304,   1] - model.layers[il].c_attn_attn_b\n        // [ 768,   N] - cur (in)\n        // [2304,   N] - cur (out)\n        //\n        // cur = attn_w*cur + attn_b\n        // [2304, N]\n        {\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_attn_attn_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    ggml_repeat(ctx, model.layers[il].c_attn_attn_b, cur),\n                    cur);\n        }\n\n        // self-attention\n        {\n            struct ggml_tensor * Qcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);\n            struct ggml_tensor * Kcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);\n            struct ggml_tensor * Vcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);\n\n            // store key and value to memory\n            if (N >= 1) {\n                struct 
ggml_tensor * k = ggml_view_1d(ctx, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));\n                struct ggml_tensor * v = ggml_view_1d(ctx, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));\n\n                ggml_build_forward_expand(gf, ggml_cpy(ctx, Kcur, k));\n                ggml_build_forward_expand(gf, ggml_cpy(ctx, Vcur, v));\n            }\n\n            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)\n            // [64, N, 12]\n            struct ggml_tensor * Q =\n                ggml_permute(ctx,\n                        ggml_cont_3d(ctx, Qcur, n_embd/n_head, n_head, N),\n                        0, 2, 1, 3);\n\n            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)\n            // [64, n_past + N, 12]\n            struct ggml_tensor * K =\n                ggml_permute(ctx,\n                        ggml_reshape_3d(ctx,\n                            ggml_view_1d(ctx, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),\n                            n_embd/n_head, n_head, n_past + N),\n                        0, 2, 1, 3);\n\n            // GG: flash attention\n            //struct ggml_tensor * V =\n            //    ggml_cpy(ctx0,\n            //            ggml_permute(ctx0,\n            //                ggml_reshape_3d(ctx0,\n            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),\n            //                    n_embd/n_head, n_head, n_past + N),\n            //                1, 2, 0, 3),\n            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));\n\n            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);\n\n            // K * Q\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ = ggml_mul_mat(ctx, 
K, Q);\n\n            // KQ_scaled = KQ / sqrt(n_embd/n_head)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_scaled =\n                ggml_scale(ctx,\n                        KQ,\n                        1.0f/sqrtf(float(n_embd)/n_head));\n\n            // KQ_masked = mask_past(KQ_scaled)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx, KQ_scaled, n_past);\n\n            // KQ = soft_max(KQ_masked)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx, KQ_masked);\n\n            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()\n            // [n_past + N, 64, 12]\n            struct ggml_tensor * V_trans =\n                ggml_cont_3d(ctx,\n                        ggml_permute(ctx,\n                            ggml_reshape_3d(ctx,\n                                ggml_view_1d(ctx, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),\n                                n_embd/n_head, n_head, n_past + N),\n                            1, 2, 0, 3),\n                        n_past + N, n_embd/n_head, n_head);\n\n            // KQV = transpose(V) * KQ_soft_max\n            // [64, N, 12]\n            struct ggml_tensor * KQV = ggml_mul_mat(ctx, V_trans, KQ_soft_max);\n\n            // KQV_merged = KQV.permute(0, 2, 1, 3)\n            // [64, 12, N]\n            struct ggml_tensor * KQV_merged = ggml_permute(ctx, KQV, 0, 2, 1, 3);\n\n            // cur = KQV_merged.contiguous().view(n_embd, N)\n            // [768, N]\n            cur = ggml_cont_2d(ctx, KQV_merged, n_embd, N);\n        }\n\n        // projection\n        // [ 768, 768] - model.layers[il].c_attn_proj_w\n        // [ 768,   1] - model.layers[il].c_attn_proj_b\n        // [ 768,   N] - cur (in)\n        // [ 768,   N] - cur (out)\n        //\n        // cur = proj_w*cur + proj_b\n        // [768, N]\n   
     {\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_attn_proj_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    ggml_repeat(ctx, model.layers[il].c_attn_proj_b, cur),\n                    cur);\n        }\n\n        // add the input\n        cur = ggml_add(ctx, cur, inpL);\n\n        struct ggml_tensor * inpFF = cur;\n\n        // feed-forward network\n        {\n            // norm\n            {\n                cur = ggml_norm(ctx, inpFF, hparams.eps);\n\n                // cur = ln_2_g*cur + ln_2_b\n                // [ 768, N]\n                cur = ggml_add(ctx,\n                        ggml_mul(ctx,\n                            ggml_repeat(ctx, model.layers[il].ln_2_g, cur),\n                            cur),\n                        ggml_repeat(ctx, model.layers[il].ln_2_b, cur));\n            }\n\n            // fully connected\n            // [3072, 768] - model.layers[il].c_mlp_fc_w\n            // [3072,   1] - model.layers[il].c_mlp_fc_b\n            // [ 768,   N] - cur (in)\n            // [3072,   N] - cur (out)\n            //\n            // cur = fc_w*cur + fc_b\n            // [3072, N]\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_mlp_fc_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    ggml_repeat(ctx, model.layers[il].c_mlp_fc_b, cur),\n                    cur);\n\n            // GELU activation\n            // [3072, N]\n            cur = ggml_gelu(ctx, cur);\n\n            // projection\n            // [ 768, 3072] - model.layers[il].c_mlp_proj_w\n            // [ 768,    1] - model.layers[il].c_mlp_proj_b\n            // [3072,    N] - cur (in)\n            // [ 768,    N] - cur (out)\n            //\n            // cur = proj_w*cur + proj_b\n            // [768, N]\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_mlp_proj_w,\n                    cur);\n\n            cur = 
ggml_add(ctx,\n                    ggml_repeat(ctx, model.layers[il].c_mlp_proj_b, cur),\n                    cur);\n        }\n\n        // input for next layer\n        inpL = ggml_add(ctx, cur, inpFF);\n    }\n\n    // norm\n    {\n        // [ 768, N]\n        inpL = ggml_norm(ctx, inpL, hparams.eps);\n\n        // inpL = ln_f_g*inpL + ln_f_b\n        // [ 768, N]\n        inpL = ggml_add(ctx,\n                ggml_mul(ctx,\n                    ggml_repeat(ctx, model.ln_f_g, inpL),\n                    inpL),\n                ggml_repeat(ctx, model.ln_f_b, inpL));\n    }\n\n    // inpL = WTE * inpL\n    // [ 768, 50257] - model.lm_head\n    // [ 768, N]     - inpL\n    inpL = ggml_mul_mat(ctx, model.lm_head, inpL);\n    ggml_set_name(inpL, \"logits\");\n    // setting a tensor as the output will ensure that it is not overwritten by subsequent operations\n    ggml_set_output(inpL);\n\n    // logits -> probs\n    //inpL = ggml_soft_max(ctx0, inpL);\n\n    ggml_build_forward_expand(gf, inpL);\n\n    ggml_free(ctx);\n\n    return gf;\n}\n\n// evaluate the transformer\n//\n//   - model:     the model\n//   - allocr:    ggml_gallocr to use to allocate the compute buffer\n//   - n_threads: number of threads to use\n//   - n_past:    the context size so far\n//   - embd_inp:  the embeddings of the tokens in the context\n//   - embd_w:    the predicted logits for the next token\n//\nbool gpt2_eval(\n        const gpt2_model & model,\n        ggml_gallocr_t allocr,\n        const int n_threads,\n        const int n_past,\n        const std::vector<gpt_vocab::id> & embd_inp,\n              std::vector<float>         & embd_w) {\n    const int N = embd_inp.size();\n\n    const auto & hparams = model.hparams;\n\n    const int n_vocab = hparams.n_vocab;\n\n    struct ggml_cgraph * gf = gpt2_graph(model, n_past, embd_inp.size());\n\n    // allocate the graph tensors\n    ggml_gallocr_alloc_graph(allocr, gf);\n\n    // set the graph inputs\n    struct ggml_tensor * embd = 
ggml_graph_get_tensor(gf, \"embd\");\n    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));\n\n    struct ggml_tensor * position = ggml_graph_get_tensor(gf, \"position\");\n    for (int i = 0; i < N; ++i) {\n        ((int32_t *) position->data)[i] = n_past + i;\n    }\n\n    // run the computation\n    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads, nullptr);\n    static std::vector<uint8_t> work_buffer;\n    work_buffer.resize(plan.work_size);\n    plan.work_data = work_buffer.data();\n    ggml_graph_compute(gf, &plan);\n\n    //if (n_past%100 == 0) {\n    //    ggml_graph_print   (&gf);\n    //    ggml_graph_dump_dot(&gf, NULL, \"gpt-2.dot\");\n    //}\n\n    // get the graph outputs\n    struct ggml_tensor * logits = ggml_graph_get_tensor(gf, \"logits\");\n\n    //embd_w.resize(n_vocab*N);\n    //memcpy(embd_w.data(), ggml_get_data(logits), sizeof(float)*n_vocab*N);\n\n    // return result just for the last token\n    embd_w.resize(n_vocab);\n    memcpy(embd_w.data(), (float *) ggml_get_data(logits) + (n_vocab*(N-1)), sizeof(float)*n_vocab);\n\n    return true;\n}\n\nint main(int argc, char ** argv) {\n    ggml_time_init();\n\n    const int64_t t_main_start_us = ggml_time_us();\n\n    gpt_params params;\n    params.model = \"models/gpt-2-117M/ggml-model.bin\";\n\n    if (gpt_params_parse(argc, argv, params) == false) {\n        return 1;\n    }\n\n    if (params.seed < 0) {\n        params.seed = time(NULL);\n    }\n\n    printf(\"%s: seed = %d\\n\", __func__, params.seed);\n\n    std::mt19937 rng(params.seed);\n    if (params.prompt.empty()) {\n        params.prompt = gpt_random_prompt(rng);\n    }\n\n    int64_t t_load_us = 0;\n\n    gpt_vocab vocab;\n    gpt2_model model;\n\n    // load the model\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!gpt2_model_load(params.model, model, vocab)) {\n            fprintf(stderr, \"%s: failed to load model from '%s'\\n\", __func__, params.model.c_str());\n            
return 1;\n        }\n\n        t_load_us = ggml_time_us() - t_start_us;\n\n        test_gpt_tokenizer(vocab, params.token_test);\n    }\n\n    ggml_gallocr_t allocr = NULL;\n    // allocate the compute buffer\n    {\n        allocr = ggml_gallocr_new(ggml_backend_cpu_buffer_type());\n\n        // create the worst case graph for memory usage estimation\n        int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);\n        int n_past = model.hparams.n_ctx - n_tokens;\n        struct ggml_cgraph * gf = gpt2_graph(model, n_past, n_tokens);\n\n        // pre-allocate the compute buffer for the worst case (optional)\n        ggml_gallocr_reserve(allocr, gf);\n        size_t mem_size =  ggml_gallocr_get_buffer_size(allocr, 0);\n        fprintf(stderr, \"%s: compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0/1024.0);\n    }\n\n    int n_past = 0;\n\n    int64_t t_sample_us  = 0;\n    int64_t t_predict_us = 0;\n\n    std::vector<float> logits;\n\n    // tokenize the prompt\n    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);\n\n    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());\n\n    printf(\"%s: prompt: '%s'\\n\", __func__, params.prompt.c_str());\n    printf(\"%s: number of tokens in prompt = %zu, first 8 tokens: \", __func__, embd_inp.size());\n    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {\n        printf(\"%d \", embd_inp[i]);\n    }\n    printf(\"\\n\\n\");\n\n    // submit the input prompt token-by-token\n    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning\n    std::vector<gpt_vocab::id> embd;\n\n    for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {\n        // predict\n        if (embd.size() > 0) {\n            const int64_t t_start_us = ggml_time_us();\n\n            if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) {\n                printf(\"Failed to 
predict\\n\");\n                return 1;\n            }\n\n            t_predict_us += ggml_time_us() - t_start_us;\n        }\n\n        n_past += embd.size();\n        embd.clear();\n\n        if (i >= embd_inp.size()) {\n            // sample next token\n            const int   top_k = params.top_k;\n            const float top_p = params.top_p;\n            const float temp  = params.temp;\n\n            const int n_vocab = model.hparams.n_vocab;\n\n            gpt_vocab::id id = 0;\n\n            {\n                const int64_t t_start_sample_us = ggml_time_us();\n\n                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);\n\n                t_sample_us += ggml_time_us() - t_start_sample_us;\n            }\n\n            // add it to the context\n            embd.push_back(id);\n        } else {\n            // if here, it means we are still processing the input prompt\n            for (size_t k = i; k < embd_inp.size(); k++) {\n                embd.push_back(embd_inp[k]);\n                if (int32_t(embd.size()) >= params.n_batch) {\n                    break;\n                }\n            }\n            i += embd.size() - 1;\n        }\n\n        // display text\n        for (auto id : embd) {\n            printf(\"%s\", vocab.id_to_token[id].c_str());\n        }\n        fflush(stdout);\n\n        // end of text token\n        if (embd.back() == 50256) {\n            break;\n        }\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = ggml_time_us();\n\n        printf(\"\\n\\n\");\n        printf(\"%s:     load time = %8.2f ms\\n\", __func__, t_load_us/1000.0f);\n        printf(\"%s:   sample time = %8.2f ms\\n\", __func__, t_sample_us/1000.0f);\n        printf(\"%s:  predict time = %8.2f ms / %.2f ms per token\\n\", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);\n        printf(\"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - 
t_main_start_us)/1000.0f);\n    }\n\n    ggml_free(model.ctx_w);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/gpt-2/main-backend.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include \"common.h\"\n#include \"common-ggml.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\n#define GPT2_MAX_NODES 4096\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\n// default hparams (GPT-2 117M)\nstruct gpt2_hparams {\n    int32_t n_vocab = 50257;\n    int32_t n_ctx   = 1024;\n    int32_t n_embd  = 768;\n    int32_t n_head  = 12;\n    int32_t n_layer = 12;\n    int32_t ftype   = 1;\n    float   eps     = 1e-5f;\n};\n\nstruct gpt2_layer {\n    // normalization\n    struct ggml_tensor * ln_1_g;\n    struct ggml_tensor * ln_1_b;\n\n    struct ggml_tensor * ln_2_g;\n    struct ggml_tensor * ln_2_b;\n\n    // attention\n    struct ggml_tensor * c_attn_attn_w;\n    struct ggml_tensor * c_attn_attn_b;\n\n    struct ggml_tensor * c_attn_proj_w;\n    struct ggml_tensor * c_attn_proj_b;\n\n    // mlp\n    struct ggml_tensor * c_mlp_fc_w;\n    struct ggml_tensor * c_mlp_fc_b;\n\n    struct ggml_tensor * c_mlp_proj_w;\n    struct ggml_tensor * c_mlp_proj_b;\n};\n\nstruct gpt2_model {\n    gpt2_hparams hparams;\n\n    // normalization\n    struct ggml_tensor * ln_f_g;\n    struct ggml_tensor * ln_f_b;\n\n    struct ggml_tensor * wte;     //    token embedding\n    struct ggml_tensor * wpe;     // position embedding\n    struct ggml_tensor * lm_head; // language model head\n\n    std::vector<gpt2_layer> layers;\n\n    // key + value memory\n    struct ggml_tensor * memory_k;\n    struct 
ggml_tensor * memory_v;\n\n    //\n    struct ggml_context * ctx_w;\n    struct ggml_context * ctx_kv;\n\n    ggml_backend_t backend = NULL;\n\n    ggml_backend_buffer_t buffer_w;\n    ggml_backend_buffer_t buffer_kv;\n\n    std::map<std::string, struct ggml_tensor *> tensors;\n};\n\n// load the model's weights from a file\nbool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {\n    printf(\"%s: loading model from '%s'\\n\", __func__, fname.c_str());\n\n    auto fin = std::ifstream(fname, std::ios::binary);\n    if (!fin) {\n        fprintf(stderr, \"%s: failed to open '%s'\\n\", __func__, fname.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        fin.read((char *) &magic, sizeof(magic));\n        if (magic != GGML_FILE_MAGIC) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, fname.c_str());\n            return false;\n        }\n    }\n\n    // load hparams\n    {\n        auto & hparams = model.hparams;\n\n        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));\n\n        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;\n\n        printf(\"%s: n_vocab = %d\\n\", __func__, hparams.n_vocab);\n        printf(\"%s: n_ctx   = %d\\n\", __func__, hparams.n_ctx);\n        printf(\"%s: n_embd  = %d\\n\", __func__, hparams.n_embd);\n        printf(\"%s: n_head  = %d\\n\", __func__, hparams.n_head);\n        printf(\"%s: n_layer = %d\\n\", __func__, hparams.n_layer);\n        printf(\"%s: ftype   = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: 
qntvr   = %d\\n\", __func__, qntvr);\n\n        hparams.ftype %= GGML_QNT_VERSION_FACTOR;\n    }\n\n    // load vocab\n    {\n        int32_t n_vocab = 0;\n        fin.read((char *) &n_vocab, sizeof(n_vocab));\n\n        if (n_vocab != model.hparams.n_vocab) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad vocab size %d != %d)\\n\",\n                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);\n            return false;\n        }\n\n        std::string word;\n        std::vector<char> buf(128);\n\n        for (int i = 0; i < n_vocab; i++) {\n            uint32_t len;\n            fin.read((char *) &len, sizeof(len));\n\n            buf.resize(len);\n            fin.read((char *) buf.data(), len);\n            word.assign(buf.data(), len);\n\n            vocab.token_to_id[word] = i;\n            vocab.id_to_token[i] = word;\n        }\n    }\n\n    // for the big tensors, we have the option to store the data in 16-bit floats or quantized\n    // in order to save memory and also to speed up the computation\n    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));\n    if (wtype == GGML_TYPE_COUNT) {\n        fprintf(stderr, \"%s: invalid model file '%s' (bad ftype value %d)\\n\",\n                __func__, fname.c_str(), model.hparams.ftype);\n        return false;\n    }\n\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    auto & ctx = model.ctx_w;\n\n    // create the ggml context\n    {\n        size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer;\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n        };\n\n        ctx = ggml_init(params);\n        if (!ctx) {\n            fprintf(stderr, \"%s: ggml_init() failed\\n\", __func__);\n            return false;\n        }\n    }\n\n    // initialize the backend\n#ifdef GGML_USE_CUDA\n    if (n_gpu_layers > 0) 
{\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        model.backend = ggml_backend_cuda_init(0);\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (n_gpu_layers > 0) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        model.backend = ggml_backend_metal_init();\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n    if (!model.backend) {\n        // fallback to CPU backend\n        fprintf(stderr, \"%s: using CPU backend\\n\", __func__);\n        model.backend = ggml_backend_cpu_init();\n    }\n\n    if (!model.backend) {\n        fprintf(stderr, \"%s: ggml_backend_cpu_init() failed\\n\", __func__);\n        return false;\n    }\n\n    // create the tensors for the model\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        model.layers.resize(n_layer);\n\n        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n\n        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);\n        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n\n        // map by name\n        model.tensors[\"model/ln_f/g\"] = model.ln_f_g;\n        model.tensors[\"model/ln_f/b\"] = model.ln_f_b;\n\n        model.tensors[\"model/wte\"]     = model.wte;\n        model.tensors[\"model/wpe\"]     = model.wpe;\n        model.tensors[\"model/lm_head\"] = model.lm_head;\n\n        for (int i = 0; i < n_layer; ++i) {\n            auto & 
layer = model.layers[i];\n\n            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);\n            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);\n\n            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);\n            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);\n            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);\n\n            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);\n            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            // map by name\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/g\"]        = layer.ln_1_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/b\"]        = layer.ln_1_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/g\"]        = layer.ln_2_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/b\"]        = layer.ln_2_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/w\"] = layer.c_attn_attn_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/b\"] = layer.c_attn_attn_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/w\"] = layer.c_attn_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/b\"] = layer.c_attn_proj_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + 
\"/mlp/c_fc/w\"]    = layer.c_mlp_fc_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/b\"]    = layer.c_mlp_fc_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/w\"]  = layer.c_mlp_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/b\"]  = layer.c_mlp_proj_b;\n        }\n    }\n\n    // allocate the model tensors in a backend buffer\n    model.buffer_w = ggml_backend_alloc_ctx_tensors(ctx, model.backend);\n\n    printf(\"%s: ggml tensor size    = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n    printf(\"%s: backend buffer size = %6.2f MB\\n\", __func__, ggml_backend_buffer_get_size(model.buffer_w)/(1024.0*1024.0));\n\n    // override the default training context with the user-provided\n    model.hparams.n_ctx = n_ctx;\n\n    // key + value memory\n    {\n        auto * ctx = model.ctx_kv;\n\n        // create the ggml context\n        {\n            size_t n_tensors = 2;\n            struct ggml_init_params params = {\n                /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors,\n                /*.mem_buffer =*/ NULL,\n                /*.no_alloc   =*/ true,\n            };\n\n            ctx = ggml_init(params);\n            if (!ctx) {\n                fprintf(stderr, \"%s: ggml_init() failed\\n\", __func__);\n                return false;\n            }\n        }\n\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n\n        const int n_mem      = n_layer*n_ctx;\n        const int n_elements = n_embd*n_mem;\n\n        // k and v here can also be GGML_TYPE_F16 to save memory and speed up the computation\n        // if backend supports it\n        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n\n        // allocate the KV 
memory in a backend buffer\n        model.buffer_kv = ggml_backend_alloc_ctx_tensors(ctx, model.backend);\n\n        const size_t memory_size = ggml_backend_buffer_get_size(model.buffer_kv);\n        printf(\"%s: memory size = %8.2f MB, n_mem = %d\\n\", __func__, memory_size/1024.0/1024.0, n_mem);\n    }\n\n    // load weights\n    {\n        size_t total_size = 0;\n\n        bool has_lm_head = false;\n\n        std::vector<char> read_buf;\n\n        while (true) {\n            int32_t n_dims;\n            int32_t length;\n            int32_t ttype;\n\n            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));\n            fin.read(reinterpret_cast<char *>(&length), sizeof(length));\n            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));\n\n            if (fin.eof()) {\n                break;\n            }\n\n            int32_t nelements = 1;\n            int32_t ne[2] = { 1, 1 };\n            for (int i = 0; i < n_dims; ++i) {\n                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));\n                nelements *= ne[i];\n            }\n\n            std::string name(length, 0);\n            fin.read(&name[0], length);\n\n            if (model.tensors.find(name) == model.tensors.end()) {\n                fprintf(stderr, \"%s: unknown tensor '%s' in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            auto tensor = model.tensors[name];\n            ggml_set_name(tensor, name.c_str());\n            if (ggml_nelements(tensor) != nelements) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\\n\",\n                        __func__, name.c_str(), (int) tensor->ne[0], (int) 
tensor->ne[1], ne[0], ne[1]);\n                return false;\n            }\n\n            // for debugging\n            if (0) {\n                printf(\"%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\\n\", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));\n            }\n\n            const size_t bpe = ggml_type_size(ggml_type(ttype));\n\n            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\\n\",\n                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);\n                return false;\n            }\n\n            if (ggml_backend_buffer_is_host(model.buffer_w)) {\n                // for some backends such as CPU and Metal, the tensor data is in system memory and we can read directly into it\n                fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));\n            } else {\n                // read into a temporary buffer first, then copy to device memory\n                read_buf.resize(ggml_nbytes(tensor));\n                fin.read(read_buf.data(), ggml_nbytes(tensor));\n                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));\n            }\n\n            // GPT-2 models share the WTE tensor as the LM head\n            if (name == \"model/wte\" && has_lm_head == false) {\n                //ggml_backend_tensor_copy(tensor, model.lm_head);\n                model.lm_head = tensor;\n            }\n\n            if (name == \"model/lm_head\") {\n                has_lm_head = true;\n            }\n\n            total_size += ggml_nbytes(tensor);\n        }\n\n        printf(\"%s: model size  = %8.2f MB\\n\", __func__, total_size/1024.0/1024.0);\n    }\n\n    fin.close();\n\n    return true;\n}\n\n// build the computation graph\nstruct ggml_cgraph * gpt2_graph(\n        
const gpt2_model & model,\n        const int n_past,\n        const int n_tokens) {\n    const int N = n_tokens;\n\n    const auto & hparams = model.hparams;\n\n    const int n_embd  = hparams.n_embd;\n    const int n_layer = hparams.n_layer;\n    const int n_ctx   = hparams.n_ctx;\n    const int n_head  = hparams.n_head;\n\n    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data\n    static size_t buf_size = ggml_tensor_overhead()*GPT2_MAX_NODES + ggml_graph_overhead_custom(GPT2_MAX_NODES, false);\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    struct ggml_context * ctx = ggml_init(params);\n\n    struct ggml_cgraph  * gf = ggml_new_graph_custom(ctx, GPT2_MAX_NODES, false);\n\n    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);\n    // at this point, the tensor data is not allocated yet and cannot be set\n    // we will find the tensor after the graph is allocated by its name, and set the data then\n    ggml_set_name(embd, \"embd\");\n    // setting a tensor as an input will ensure that it is allocated at the beginning of the graph\n    // this is important to ensure that the input tensors are not overwritten before they are used\n    ggml_set_input(embd);\n\n    struct ggml_tensor * position = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);\n    ggml_set_name(position, \"position\");\n    ggml_set_input(position);\n\n    // wte + wpe\n    struct ggml_tensor * inpL =\n        ggml_add(ctx,\n                ggml_get_rows(ctx, model.wte, embd),\n                ggml_get_rows(ctx, model.wpe, position));\n\n    for (int il = 0; il < n_layer; ++il) {\n        struct ggml_tensor * cur;\n\n        // norm\n        {\n            // [ 
768, N]\n            cur = ggml_norm(ctx, inpL, hparams.eps);\n\n            // cur = ln_1_g*cur + ln_1_b\n            // [ 768, N]\n            cur = ggml_add(ctx,\n                    ggml_mul(ctx,\n                        cur,\n                        model.layers[il].ln_1_g),\n                    model.layers[il].ln_1_b);\n        }\n\n        // attn\n        // [2304, 768] - model.layers[il].c_attn_attn_w\n        // [2304,   1] - model.layers[il].c_attn_attn_b\n        // [ 768,   N] - cur (in)\n        // [2304,   N] - cur (out)\n        //\n        // cur = attn_w*cur + attn_b\n        // [2304, N]\n        {\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_attn_attn_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_attn_attn_b);\n        }\n\n        // self-attention\n        {\n            struct ggml_tensor * Qcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);\n            struct ggml_tensor * Kcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);\n            struct ggml_tensor * Vcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);\n\n            // store key and value to memory\n            if (N >= 1) {\n                struct ggml_tensor * k = ggml_view_1d(ctx, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));\n                struct ggml_tensor * v = ggml_view_1d(ctx, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));\n\n                ggml_build_forward_expand(gf, ggml_cpy(ctx, Kcur, k));\n                ggml_build_forward_expand(gf, ggml_cpy(ctx, Vcur, v));\n            }\n\n            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)\n            // [64, N, 12]\n            struct ggml_tensor * Q =\n                ggml_permute(ctx,\n                        
ggml_cont_3d(ctx, Qcur, n_embd/n_head, n_head, N),\n                        0, 2, 1, 3);\n\n            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)\n            // [64, n_past + N, 12]\n            struct ggml_tensor * K =\n                ggml_permute(ctx,\n                        ggml_reshape_3d(ctx,\n                            ggml_view_1d(ctx, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),\n                            n_embd/n_head, n_head, n_past + N),\n                        0, 2, 1, 3);\n\n            // GG: flash attention\n            //struct ggml_tensor * V =\n            //    ggml_cpy(ctx0,\n            //            ggml_permute(ctx0,\n            //                ggml_reshape_3d(ctx0,\n            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),\n            //                    n_embd/n_head, n_head, n_past + N),\n            //                1, 2, 0, 3),\n            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));\n\n            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);\n\n            // K * Q\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ = ggml_mul_mat(ctx, K, Q);\n\n            // KQ_scaled = KQ / sqrt(n_embd/n_head)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_scaled =\n                ggml_scale(ctx,\n                        KQ,\n                        1.0f/sqrtf(float(n_embd)/n_head));\n\n            // KQ_masked = mask_past(KQ_scaled)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx, KQ_scaled, n_past);\n\n            // KQ = soft_max(KQ_masked)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx, KQ_masked);\n\n            // V_trans = Vmem.view(n_embd/n_head, 
n_head, n_past + N).permute(1, 2, 0, 3).contiguous()\n            // [n_past + N, 64, 12]\n            struct ggml_tensor * V_trans =\n                ggml_cont_3d(ctx,\n                        ggml_permute(ctx,\n                            ggml_reshape_3d(ctx,\n                                ggml_view_1d(ctx, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),\n                                n_embd/n_head, n_head, n_past + N),\n                            1, 2, 0, 3),\n                        n_past + N, n_embd/n_head, n_head);\n\n            // KQV = transpose(V) * KQ_soft_max\n            // [64, N, 12]\n            struct ggml_tensor * KQV = ggml_mul_mat(ctx, V_trans, KQ_soft_max);\n\n            // KQV_merged = KQV.permute(0, 2, 1, 3)\n            // [64, 12, N]\n            struct ggml_tensor * KQV_merged = ggml_permute(ctx, KQV, 0, 2, 1, 3);\n\n            // cur = KQV_merged.contiguous().view(n_embd, N)\n            // [768, N]\n            cur = ggml_cont_2d(ctx, KQV_merged, n_embd, N);\n        }\n\n        // projection\n        // [ 768, 768] - model.layers[il].c_attn_proj_w\n        // [ 768,   1] - model.layers[il].c_attn_proj_b\n        // [ 768,   N] - cur (in)\n        // [ 768,   N] - cur (out)\n        //\n        // cur = proj_w*cur + proj_b\n        // [768, N]\n        {\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_attn_proj_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_attn_proj_b);\n        }\n\n        // add the input\n        cur = ggml_add(ctx, cur, inpL);\n\n        struct ggml_tensor * inpFF = cur;\n\n        // feed-forward network\n        {\n            // norm\n            {\n                cur = ggml_norm(ctx, inpFF, hparams.eps);\n\n                // cur = ln_2_g*cur + ln_2_b\n                // [ 768, N]\n                cur = ggml_add(ctx,\n                        
ggml_mul(ctx,\n                            cur,\n                            model.layers[il].ln_2_g),\n                        model.layers[il].ln_2_b);\n            }\n\n            // fully connected\n            // [3072, 768] - model.layers[il].c_mlp_fc_w\n            // [3072,   1] - model.layers[il].c_mlp_fc_b\n            // [ 768,   N] - cur (in)\n            // [3072,   N] - cur (out)\n            //\n            // cur = fc_w*cur + fc_b\n            // [3072, N]\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_mlp_fc_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_mlp_fc_b);\n\n            // GELU activation\n            // [3072, N]\n            cur = ggml_gelu(ctx, cur);\n\n            // projection\n            // [ 768, 3072] - model.layers[il].c_mlp_proj_w\n            // [ 768,    1] - model.layers[il].c_mlp_proj_b\n            // [3072,    N] - cur (in)\n            // [ 768,    N] - cur (out)\n            //\n            // cur = proj_w*cur + proj_b\n            // [768, N]\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_mlp_proj_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_mlp_proj_b);\n        }\n\n        // input for next layer\n        inpL = ggml_add(ctx, cur, inpFF);\n    }\n\n    // norm\n    {\n        // [ 768, N]\n        inpL = ggml_norm(ctx, inpL, hparams.eps);\n\n        // inpL = ln_f_g*inpL + ln_f_b\n        // [ 768, N]\n        inpL = ggml_add(ctx,\n                ggml_mul(ctx,\n                    inpL,\n                    model.ln_f_g),\n                model.ln_f_b);\n    }\n\n    // inpL = WTE * inpL\n    // [ 768, 50257] - model.lm_head\n    // [ 768, N]     - inpL\n    inpL = ggml_mul_mat(ctx, model.lm_head, inpL);\n    ggml_set_name(inpL, \"logits\");\n    // setting a tensor as the output will 
ensure that it is not overwritten by subsequent operations\n    ggml_set_output(inpL);\n\n    // logits -> probs\n    //inpL = ggml_soft_max(ctx0, inpL);\n\n    ggml_build_forward_expand(gf, inpL);\n\n    ggml_free(ctx);\n\n    return gf;\n}\n\n// evaluate the transformer\n//\n//   - model:     the model\n//   - allocr:    ggml_gallocr to use to allocate the compute buffer\n//   - n_threads: number of threads to use\n//   - n_past:    the context size so far\n//   - embd_inp:  the embeddings of the tokens in the context\n//   - embd_w:    the predicted logits for the next token\n//\nbool gpt2_eval(\n        const gpt2_model & model,\n        ggml_gallocr_t allocr,\n        const int n_threads,\n        const int n_past,\n        const std::vector<gpt_vocab::id> & embd_inp,\n              std::vector<float>         & embd_w) {\n    const int N = embd_inp.size();\n\n    const auto & hparams = model.hparams;\n\n    const int n_vocab = hparams.n_vocab;\n\n    struct ggml_cgraph * gf = gpt2_graph(model, n_past, embd_inp.size());\n\n    // allocate the graph tensors\n    ggml_gallocr_alloc_graph(allocr, gf);\n\n    // set the graph inputs\n    struct ggml_tensor * embd = ggml_graph_get_tensor(gf, \"embd\");\n    ggml_backend_tensor_set(embd, embd_inp.data(), 0, N*ggml_element_size(embd));\n\n    struct ggml_tensor * position = ggml_graph_get_tensor(gf, \"position\");\n    for (int i = 0; i < N; ++i) {\n        int32_t v = n_past + i;\n        ggml_backend_tensor_set(position, &v, i*sizeof(int32_t), sizeof(v));\n    }\n\n    // set backend options\n    if (ggml_backend_is_cpu(model.backend)) {\n        ggml_backend_cpu_set_n_threads(model.backend, n_threads);\n    }\n\n    // run the computation\n    ggml_backend_graph_compute(model.backend, gf);\n\n    //if (n_past%100 == 0) {\n    //    ggml_graph_print   (&gf);\n    //    ggml_graph_dump_dot(&gf, NULL, \"gpt-2.dot\");\n    //}\n\n    // get the graph outputs\n    struct ggml_tensor * logits = ggml_graph_get_tensor(gf, 
\"logits\");\n\n    //embd_w.resize(n_vocab*N);\n    //ggml_backend_tensor_get(logits, embd_w.data(), 0, sizeof(float)*n_vocab*N);\n\n    // return result just for the last token\n    embd_w.resize(n_vocab);\n    ggml_backend_tensor_get(logits, embd_w.data(), (n_vocab*(N-1))*sizeof(float), sizeof(float)*n_vocab);\n\n    return true;\n}\n\nint main(int argc, char ** argv) {\n    ggml_time_init();\n\n    const int64_t t_main_start_us = ggml_time_us();\n\n    gpt_params params;\n    params.model = \"models/gpt-2-117M/ggml-model.bin\";\n\n    if (gpt_params_parse(argc, argv, params) == false) {\n        return 1;\n    }\n\n    if (params.seed < 0) {\n        params.seed = time(NULL);\n    }\n\n    printf(\"%s: seed = %d\\n\", __func__, params.seed);\n\n    std::mt19937 rng(params.seed);\n    if (params.prompt.empty()) {\n        params.prompt = gpt_random_prompt(rng);\n    }\n\n    int64_t t_load_us = 0;\n\n    gpt_vocab vocab;\n    gpt2_model model;\n\n    // load the model\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {\n            fprintf(stderr, \"%s: failed to load model from '%s'\\n\", __func__, params.model.c_str());\n            return 1;\n        }\n\n        t_load_us = ggml_time_us() - t_start_us;\n\n        test_gpt_tokenizer(vocab, params.token_test);\n    }\n\n    ggml_gallocr_t allocr = NULL;\n    // allocate the compute buffer\n    {\n        // create a graph allocator with the backend's default buffer type\n        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n\n        // create the worst case graph for memory usage estimation\n        int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);\n        int n_past = model.hparams.n_ctx - n_tokens;\n        struct ggml_cgraph * gf = gpt2_graph(model, n_past, n_tokens);\n\n        // pre-allocate the compute buffer for the worst case (optional)\n        
ggml_gallocr_reserve(allocr, gf);\n        size_t mem_size =  ggml_gallocr_get_buffer_size(allocr, 0);\n        fprintf(stderr, \"%s: compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0/1024.0);\n    }\n\n    int n_past = 0;\n\n    int64_t t_sample_us  = 0;\n    int64_t t_predict_us = 0;\n\n    std::vector<float> logits;\n\n    // tokenize the prompt\n    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);\n\n    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());\n\n    printf(\"%s: prompt: '%s'\\n\", __func__, params.prompt.c_str());\n    printf(\"%s: number of tokens in prompt = %zu, first 8 tokens: \", __func__, embd_inp.size());\n    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {\n        printf(\"%d \", embd_inp[i]);\n    }\n    printf(\"\\n\\n\");\n\n    // submit the input prompt token-by-token\n    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning\n    std::vector<gpt_vocab::id> embd;\n\n    for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {\n        // predict\n        if (embd.size() > 0) {\n            const int64_t t_start_us = ggml_time_us();\n\n            if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) {\n                printf(\"Failed to predict\\n\");\n                return 1;\n            }\n\n            t_predict_us += ggml_time_us() - t_start_us;\n        }\n\n        n_past += embd.size();\n        embd.clear();\n\n        if (i >= embd_inp.size()) {\n            // sample next token\n            const int   top_k = params.top_k;\n            const float top_p = params.top_p;\n            const float temp  = params.temp;\n\n            const int n_vocab = model.hparams.n_vocab;\n\n            gpt_vocab::id id = 0;\n\n            {\n                const int64_t t_start_sample_us = ggml_time_us();\n\n                id = gpt_sample_top_k_top_p(vocab, 
logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);\n\n                t_sample_us += ggml_time_us() - t_start_sample_us;\n            }\n\n            // add it to the context\n            embd.push_back(id);\n        } else {\n            // if here, it means we are still processing the input prompt\n            for (size_t k = i; k < embd_inp.size(); k++) {\n                embd.push_back(embd_inp[k]);\n                if (int32_t(embd.size()) >= params.n_batch) {\n                    break;\n                }\n            }\n            i += embd.size() - 1;\n        }\n\n        // display text\n        for (auto id : embd) {\n            printf(\"%s\", vocab.id_to_token[id].c_str());\n        }\n        fflush(stdout);\n\n        // end of text token\n        if (!params.ignore_eos && embd.back() == 50256) {\n            break;\n        }\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = ggml_time_us();\n\n        printf(\"\\n\\n\");\n        printf(\"%s:     load time = %8.2f ms\\n\", __func__, t_load_us/1000.0f);\n        printf(\"%s:   sample time = %8.2f ms\\n\", __func__, t_sample_us/1000.0f);\n        printf(\"%s:  predict time = %8.2f ms / %.2f ms per token\\n\", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);\n        printf(\"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - t_main_start_us)/1000.0f);\n    }\n\n    ggml_free(model.ctx_w);\n\n    ggml_gallocr_free(allocr);\n    ggml_backend_buffer_free(model.buffer_w);\n    ggml_backend_buffer_free(model.buffer_kv);\n    ggml_backend_free(model.backend);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/gpt-2/main-batched.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include \"common.h\"\n#include \"common-ggml.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <set>\n#include <string>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\n#define GPT2_MAX_NODES 4096\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\ntypedef int32_t gpt2_pos;\ntypedef int32_t gpt2_seq_id;\n\n// default hparams (GPT-2 117M)\nstruct gpt2_hparams {\n    int32_t n_vocab = 50257;\n    int32_t n_ctx   = 1024;\n    int32_t n_embd  = 768;\n    int32_t n_head  = 12;\n    int32_t n_layer = 12;\n    int32_t ftype   = 1;\n    float   eps     = 1e-5f;\n};\n\nstruct gpt2_layer {\n    // normalization\n    struct ggml_tensor * ln_1_g;\n    struct ggml_tensor * ln_1_b;\n\n    struct ggml_tensor * ln_2_g;\n    struct ggml_tensor * ln_2_b;\n\n    // attention\n    struct ggml_tensor * c_attn_attn_w;\n    struct ggml_tensor * c_attn_attn_b;\n\n    struct ggml_tensor * c_attn_proj_w;\n    struct ggml_tensor * c_attn_proj_b;\n\n    // mlp\n    struct ggml_tensor * c_mlp_fc_w;\n    struct ggml_tensor * c_mlp_fc_b;\n\n    struct ggml_tensor * c_mlp_proj_w;\n    struct ggml_tensor * c_mlp_proj_b;\n};\n\nstruct gpt2_kv_cell {\n    gpt2_pos pos   = -1;\n    gpt2_pos delta = 0;\n\n    std::set<gpt2_seq_id> seq_id;\n\n    bool has_seq_id(const gpt2_seq_id & id) const {\n        return seq_id.find(id) != seq_id.end();\n    }\n};\n\nstruct gpt2_kv_cache {\n    // key + value memory\n    struct ggml_tensor * k;\n    struct ggml_tensor * v;\n    //\n\n    
uint32_t head = 0;\n    uint32_t size = 0;\n\n    // computed before each graph build\n    uint32_t n = 0;\n\n    std::vector<gpt2_kv_cell> cells;\n\n    ggml_backend_buffer_t buffer;\n};\n\nstruct gpt2_model {\n    gpt2_hparams hparams;\n\n    // normalization\n    struct ggml_tensor * ln_f_g;\n    struct ggml_tensor * ln_f_b;\n\n    struct ggml_tensor * wte;     //    token embedding\n    struct ggml_tensor * wpe;     // position embedding\n    struct ggml_tensor * lm_head; // language model head\n\n    std::vector<gpt2_layer> layers;\n\n    gpt2_kv_cache kv_cache;\n\n    struct ggml_context * ctx_w;\n\n    ggml_backend_t backend = NULL;\n\n    ggml_backend_buffer_t buffer_w;\n\n    std::map<std::string, struct ggml_tensor *> tensors;\n};\n\n// Input data for gpt2_decode\n// A gpt2_batch object can contain input about one or many sequences\n// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens\n//\n// - token  : the token ids of the input (used when embd is NULL)\n// - embd   : token embeddings (i.e. 
float vector of size n_embd) (used when token is NULL)\n// - pos    : the positions of the respective token in the sequence\n// - seq_id : the sequence to which the respective token belongs\n// - logits : if zero, the logits for the respective token will not be output\n//\nstruct gpt2_batch {\n    int32_t n_tokens = -1;\n\n    gpt_vocab::id  * token  = {};\n    float          * embd   = {};\n    gpt2_pos       * pos    = {};\n    gpt2_seq_id    * seq_id = {};\n    int8_t         * logits = {};\n};\n\n// load the model's weights from a file\nbool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {\n    printf(\"%s: loading model from '%s'\\n\", __func__, fname.c_str());\n\n    auto fin = std::ifstream(fname, std::ios::binary);\n    if (!fin) {\n        fprintf(stderr, \"%s: failed to open '%s'\\n\", __func__, fname.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        fin.read((char *) &magic, sizeof(magic));\n        if (magic != GGML_FILE_MAGIC) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, fname.c_str());\n            return false;\n        }\n    }\n\n    // load hparams\n    {\n        auto & hparams = model.hparams;\n\n        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));\n\n        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;\n\n        printf(\"%s: n_vocab = %d\\n\", __func__, hparams.n_vocab);\n        printf(\"%s: n_ctx   = %d\\n\", __func__, hparams.n_ctx);\n        printf(\"%s: n_embd  = %d\\n\", __func__, hparams.n_embd);\n    
    printf(\"%s: n_head  = %d\\n\", __func__, hparams.n_head);\n        printf(\"%s: n_layer = %d\\n\", __func__, hparams.n_layer);\n        printf(\"%s: ftype   = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: qntvr   = %d\\n\", __func__, qntvr);\n\n        hparams.ftype %= GGML_QNT_VERSION_FACTOR;\n    }\n\n    // load vocab\n    {\n        int32_t n_vocab = 0;\n        fin.read((char *) &n_vocab, sizeof(n_vocab));\n\n        if (n_vocab != model.hparams.n_vocab) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad vocab size %d != %d)\\n\",\n                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);\n            return false;\n        }\n\n        std::string word;\n        std::vector<char> buf(128);\n\n        for (int i = 0; i < n_vocab; i++) {\n            uint32_t len;\n            fin.read((char *) &len, sizeof(len));\n\n            buf.resize(len);\n            fin.read((char *) buf.data(), len);\n            word.assign(buf.data(), len);\n\n            vocab.token_to_id[word] = i;\n            vocab.id_to_token[i] = word;\n        }\n    }\n\n    // for the big tensors, we have the option to store the data in 16-bit floats or quantized\n    // in order to save memory and also to speed up the computation\n    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));\n    if (wtype == GGML_TYPE_COUNT) {\n        fprintf(stderr, \"%s: invalid model file '%s' (bad ftype value %d)\\n\",\n                __func__, fname.c_str(), model.hparams.ftype);\n        return false;\n    }\n\n    auto & ctx = model.ctx_w;\n\n    size_t buffer_size = 0;\n\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        buffer_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g\n        buffer_size += 
ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b\n\n        buffer_size += ggml_row_size(wtype,         n_vocab*n_embd); // wte\n        buffer_size += ggml_row_size(GGML_TYPE_F32,   n_ctx*n_embd); // wpe\n        buffer_size += ggml_row_size(wtype,         n_vocab*n_embd); // lm_head\n\n        buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g\n        buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b\n\n        buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g\n        buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b\n\n        buffer_size += n_layer*(ggml_row_size(wtype,         3*n_embd*n_embd)); // c_attn_attn_w\n        buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd));        // c_attn_attn_b\n\n        buffer_size += n_layer*(ggml_row_size(wtype,         n_embd*n_embd));   // c_attn_proj_w\n        buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd));          // c_attn_proj_b\n\n        buffer_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_fc_w\n        buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_fc_b\n\n        buffer_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_proj_w\n        buffer_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_proj_b\n\n        buffer_size += (6 + 12*n_layer)*128; // alignment overhead\n\n        printf(\"%s: ggml tensor size    = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n        printf(\"%s: backend buffer size = %6.2f MB\\n\", __func__, buffer_size/(1024.0*1024.0));\n    }\n\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    // create the ggml context\n    {\n        size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer;\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors,\n            /*.mem_buffer =*/ NULL,\n            
/*.no_alloc   =*/ true,\n        };\n\n        model.ctx_w = ggml_init(params);\n        if (!model.ctx_w) {\n            fprintf(stderr, \"%s: ggml_init() failed\\n\", __func__);\n            return false;\n        }\n    }\n\n    // initialize the backend\n#ifdef GGML_USE_CUDA\n    if (n_gpu_layers > 0) {\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        model.backend = ggml_backend_cuda_init(0);\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (n_gpu_layers > 0) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        model.backend = ggml_backend_metal_init();\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n    if (!model.backend) {\n        // fallback to CPU backend\n        fprintf(stderr, \"%s: using CPU backend\\n\", __func__);\n        model.backend = ggml_backend_cpu_init();\n    }\n\n    if (!model.backend) {\n        fprintf(stderr, \"%s: ggml_backend_cpu_init() failed\\n\", __func__);\n        return false;\n    }\n\n    // allocate weights buffer\n    model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);\n\n    // prepare memory for the weights\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        model.layers.resize(n_layer);\n\n        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n\n        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);\n        model.lm_head = ggml_new_tensor_2d(ctx, 
wtype,         n_embd, n_vocab);\n\n        // map by name\n        model.tensors[\"model/ln_f/g\"] = model.ln_f_g;\n        model.tensors[\"model/ln_f/b\"] = model.ln_f_b;\n\n        model.tensors[\"model/wte\"]     = model.wte;\n        model.tensors[\"model/wpe\"]     = model.wpe;\n        model.tensors[\"model/lm_head\"] = model.lm_head;\n\n        for (int i = 0; i < n_layer; ++i) {\n            auto & layer = model.layers[i];\n\n            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);\n            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);\n\n            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);\n            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);\n            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);\n\n            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);\n            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            // map by name\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/g\"]        = layer.ln_1_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/b\"]        = layer.ln_1_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/g\"]        = layer.ln_2_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/b\"]        = layer.ln_2_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + 
\"/attn/c_attn/w\"] = layer.c_attn_attn_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/b\"] = layer.c_attn_attn_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/w\"] = layer.c_attn_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/b\"] = layer.c_attn_proj_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/w\"]    = layer.c_mlp_fc_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/b\"]    = layer.c_mlp_fc_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/w\"]  = layer.c_mlp_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/b\"]  = layer.c_mlp_proj_b;\n        }\n    }\n\n    // override the default training context with the user-provided\n    model.hparams.n_ctx = n_ctx;\n\n    // key + value memory\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n\n        const int n_mem      = n_layer*n_ctx;\n        const int n_elements = n_embd*n_mem;\n\n        model.kv_cache.k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n        model.kv_cache.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n\n        model.kv_cache.head      = 0;\n        model.kv_cache.size      = n_ctx;\n\n        model.kv_cache.cells.resize(n_ctx);\n\n        const size_t memory_size = ggml_nbytes(model.kv_cache.k) + ggml_nbytes(model.kv_cache.v);\n\n        printf(\"%s: memory size = %8.2f MB, n_mem = %d\\n\", __func__, memory_size/1024.0/1024.0, n_mem);\n\n        // create a backend buffer (can be in host or device memory)\n        model.kv_cache.buffer = ggml_backend_alloc_buffer(model.backend, memory_size + 256);\n\n        // allocate the tensors into the backend buffer\n        {\n            ggml_tallocr alloc = 
ggml_tallocr_new(model.kv_cache.buffer);\n\n            // this updates the pointers in the tensors to point to the correct location in the buffer\n            // this is necessary since the ggml_context is .no_alloc == true\n            // note that the buffer can actually be a device buffer, depending on the backend\n            ggml_tallocr_alloc(&alloc, model.kv_cache.k);\n            ggml_tallocr_alloc(&alloc, model.kv_cache.v);\n        }\n    }\n\n    // load weights\n    {\n        ggml_tallocr alloc = ggml_tallocr_new(model.buffer_w);\n\n        size_t total_size = 0;\n\n        bool has_lm_head = false;\n\n        std::vector<char> read_buf;\n\n        while (true) {\n            int32_t n_dims;\n            int32_t length;\n            int32_t ttype;\n\n            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));\n            fin.read(reinterpret_cast<char *>(&length), sizeof(length));\n            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));\n\n            if (fin.eof()) {\n                break;\n            }\n\n            int32_t nelements = 1;\n            int32_t ne[2] = { 1, 1 };\n            for (int i = 0; i < n_dims; ++i) {\n                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));\n                nelements *= ne[i];\n            }\n\n            std::string name(length, 0);\n            fin.read(&name[0], length);\n\n            if (model.tensors.find(name) == model.tensors.end()) {\n                fprintf(stderr, \"%s: unknown tensor '%s' in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            auto tensor = model.tensors[name];\n            ggml_set_name(tensor, name.c_str());\n            if (ggml_nelements(tensor) != nelements) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            if (tensor->ne[0] != ne[0] || tensor->ne[1] != 
ne[1]) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\\n\",\n                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);\n                return false;\n            }\n\n            // for debugging\n            if (0) {\n                printf(\"%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\\n\", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));\n            }\n\n            const size_t bpe = ggml_type_size(ggml_type(ttype));\n\n            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\\n\",\n                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);\n                return false;\n            }\n\n            ggml_tallocr_alloc(&alloc, tensor);\n\n            if (ggml_backend_is_cpu  (model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n                ) {\n                // for the CPU and Metal backend, we can read directly into the tensor\n                fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));\n            } else {\n                // read into a temporary buffer first, then copy to device memory\n                read_buf.resize(ggml_nbytes(tensor));\n                fin.read(read_buf.data(), ggml_nbytes(tensor));\n                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));\n            }\n\n            // GPT-2 models share the WTE tensor as the LM head\n            if (name == \"model/wte\" && has_lm_head == false) {\n                //ggml_tallocr_alloc(alloc, model.lm_head);\n                //ggml_backend_tensor_copy(tensor, model.lm_head);\n                model.lm_head = tensor;\n            }\n\n    
        if (name == \"model/lm_head\") {\n                has_lm_head = true;\n            }\n\n            total_size += ggml_nbytes(tensor);\n        }\n\n        printf(\"%s: model size  = %8.2f MB\\n\", __func__, total_size/1024.0/1024.0);\n    }\n\n    fin.close();\n\n    return true;\n}\n\n// build the computation graph\nstruct ggml_cgraph * gpt2_graph(\n        const  gpt2_model  & model,\n        const  gpt2_batch  & batch,\n                     bool    measure) {\n    const auto & hparams = model.hparams;\n\n    const int n_embd  = hparams.n_embd;\n    const int n_layer = hparams.n_layer;\n    const int n_ctx   = hparams.n_ctx;\n    const int n_head  = hparams.n_head;\n\n    const auto & kv_cache = model.kv_cache;\n\n    const int32_t n_tokens = batch.n_tokens;\n    const int32_t n_kv     = measure ? n_ctx            : kv_cache.n;\n    const int32_t kv_head  = measure ? n_ctx - n_tokens : kv_cache.head;\n\n    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data\n    static size_t buf_size = ggml_tensor_overhead()*GPT2_MAX_NODES + ggml_graph_overhead_custom(GPT2_MAX_NODES, false);\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    struct ggml_context * ctx = ggml_init(params);\n\n    struct ggml_cgraph  * gf = ggml_new_graph_custom(ctx, GPT2_MAX_NODES, false);\n\n    struct ggml_tensor * inpL;\n    if (batch.token) {\n        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);\n        ggml_set_name(inp_tokens, \"inp_tokens\");\n        ggml_set_input(inp_tokens);\n\n        struct ggml_tensor * position = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);\n        ggml_set_name(position, \"position\");\n      
  ggml_set_input(position);\n\n        // wte + wpe\n        inpL =\n            ggml_add(ctx,\n                    ggml_get_rows(ctx, model.wte, inp_tokens),\n                    ggml_get_rows(ctx, model.wpe, position));\n    } else {\n        GGML_ASSERT(batch.embd);\n\n        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);\n        ggml_set_name(inpL, \"embd\");\n        ggml_set_input(inpL);\n    }\n\n    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)\n    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_kv, n_tokens, 1);\n    ggml_set_name(KQ_mask, \"KQ_mask\");\n    ggml_set_input(KQ_mask);\n\n\n    for (int il = 0; il < n_layer; ++il) {\n        struct ggml_tensor * cur;\n\n        // norm\n        {\n            // [ 768, N]\n            cur = ggml_norm(ctx, inpL, hparams.eps);\n\n            // cur = ln_1_g*cur + ln_1_b\n            // [ 768, N]\n            cur = ggml_add(ctx,\n                    ggml_mul(ctx,\n                        cur,\n                        model.layers[il].ln_1_g),\n                    model.layers[il].ln_1_b);\n        }\n\n        // attn\n        // [2304,        768] - model.layers[il].c_attn_attn_w\n        // [2304,          1] - model.layers[il].c_attn_attn_b\n        // [ 768,   n_tokens] - cur (in)\n        // [2304,   n_tokens] - cur (out)\n        //\n        // cur = attn_w*cur + attn_b\n        // [2304, n_tokens]\n        {\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_attn_attn_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_attn_attn_b);\n        }\n\n        // self-attention\n        {\n            struct ggml_tensor * Qcur = ggml_view_2d(ctx, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd);\n            struct ggml_tensor * Kcur = ggml_view_2d(ctx, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*n_embd);\n            struct 
ggml_tensor * Vcur = ggml_view_2d(ctx, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*n_embd);\n\n            // store key and value to memory\n            if (n_tokens >= 1) {\n                struct ggml_tensor * k = ggml_view_1d(ctx, model.kv_cache.k, n_tokens*n_embd, (ggml_element_size(model.kv_cache.k)*n_embd)*(il*n_ctx + kv_head));\n                struct ggml_tensor * v = ggml_view_1d(ctx, model.kv_cache.v, n_tokens*n_embd, (ggml_element_size(model.kv_cache.v)*n_embd)*(il*n_ctx + kv_head));\n\n                ggml_build_forward_expand(gf, ggml_cpy(ctx, Kcur, k));\n                ggml_build_forward_expand(gf, ggml_cpy(ctx, Vcur, v));\n            }\n\n            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)\n            // [64, N, 12]\n            struct ggml_tensor * Q =\n                ggml_permute(ctx,\n                        ggml_cont_3d(ctx,\n                            Qcur,\n                            n_embd/n_head, n_head, n_tokens),\n                        0, 2, 1, 3);\n\n            // K = Kmem.view(n_embd/n_head, n_head, n_kv).permute(0, 2, 1, 3)\n            // [64, n_kv, 12]\n            struct ggml_tensor * K =\n                ggml_permute(ctx,\n                        ggml_reshape_3d(ctx,\n                            ggml_view_1d(ctx, model.kv_cache.k, n_kv*n_embd, il*n_ctx*ggml_element_size(model.kv_cache.k)*n_embd),\n                            n_embd/n_head, n_head, n_kv),\n                        0, 2, 1, 3);\n\n            // GG: flash attention\n            //struct ggml_tensor * V =\n            //    ggml_cpy(ctx0,\n            //            ggml_permute(ctx0,\n            //                ggml_reshape_3d(ctx0,\n            //                    ggml_view_1d(ctx0, model.kv_cache.v, n_kv*n_embd, il*n_ctx*ggml_element_size(model.kv_cache.v)*n_embd),\n            //                    n_embd/n_head, n_head, n_kv),\n            //                1, 2, 0, 3),\n            //            
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_embd/n_head, n_head));\n\n            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);\n\n            // K * Q\n            // [n_kv, n_tokens, 12]\n            struct ggml_tensor * KQ = ggml_mul_mat(ctx, K, Q);\n\n            // KQ_scaled = KQ / sqrt(n_embd/n_head)\n            // [n_kv, n_tokens, 12]\n            struct ggml_tensor * KQ_scaled =\n                ggml_scale(ctx,\n                        KQ,\n                        1.0f/sqrtf(float(n_embd)/n_head));\n\n            // KQ_masked = mask_past(KQ_scaled)\n            // [n_kv, n_tokens, 12]\n            struct ggml_tensor * KQ_masked = ggml_add(ctx, KQ_scaled, KQ_mask);\n\n            // KQ = soft_max(KQ_masked)\n            // [n_kv, N, 12]\n            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx, KQ_masked);\n\n            // V_trans = Vmem.view(n_embd/n_head, n_head, n_kv).permute(1, 2, 0, 3).contiguous()\n            // [n_kv, 64, 12]\n            struct ggml_tensor * V_trans =\n                ggml_cont_3d(ctx,\n                        ggml_permute(ctx,\n                            ggml_reshape_3d(ctx,\n                                ggml_view_1d(ctx, model.kv_cache.v, n_kv*n_embd, il*n_ctx*ggml_element_size(model.kv_cache.v)*n_embd),\n                                n_embd/n_head, n_head, n_kv),\n                            1, 2, 0, 3),\n                        n_kv, n_embd/n_head, n_head);\n\n            // KQV = transpose(V) * KQ_soft_max\n            // [64, n_tokens, 12]\n            struct ggml_tensor * KQV = ggml_mul_mat(ctx, V_trans, KQ_soft_max);\n\n            // KQV_merged = KQV.permute(0, 2, 1, 3)\n            // [64, 12, n_tokens]\n            struct ggml_tensor * KQV_merged = ggml_permute(ctx, KQV, 0, 2, 1, 3);\n\n            // cur = KQV_merged.contiguous().view(n_embd, N)\n            // [768, n_tokens]\n            cur = ggml_cont_2d(ctx, KQV_merged, n_embd, n_tokens);\n        }\n\n        // projection\n 
       // [ 768, 768] - model.layers[il].c_attn_proj_w\n        // [ 768,   1] - model.layers[il].c_attn_proj_b\n        // [ 768,   N] - cur (in)\n        // [ 768,   N] - cur (out)\n        //\n        // cur = proj_w*cur + proj_b\n        // [768, N]\n        {\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_attn_proj_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_attn_proj_b);\n        }\n\n        // add the input\n        cur = ggml_add(ctx, cur, inpL);\n\n        struct ggml_tensor * inpFF = cur;\n\n        // feed-forward network\n        {\n            // norm\n            {\n                cur = ggml_norm(ctx, inpFF, hparams.eps);\n\n                // cur = ln_2_g*cur + ln_2_b\n                // [ 768, N]\n                cur = ggml_add(ctx,\n                        ggml_mul(ctx,\n                            cur,\n                            model.layers[il].ln_2_g),\n                        model.layers[il].ln_2_b);\n            }\n\n            // fully connected\n            // [3072, 768] - model.layers[il].c_mlp_fc_w\n            // [3072,   1] - model.layers[il].c_mlp_fc_b\n            // [ 768,   N] - cur (in)\n            // [3072,   N] - cur (out)\n            //\n            // cur = fc_w*cur + fc_b\n            // [3072, N]\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_mlp_fc_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_mlp_fc_b);\n\n            // GELU activation\n            // [3072, N]\n            cur = ggml_gelu(ctx, cur);\n\n            // projection\n            // [ 768, 3072] - model.layers[il].c_mlp_proj_w\n            // [ 768,    1] - model.layers[il].c_mlp_proj_b\n            // [3072,    N] - cur (in)\n            // [ 768,    N] - cur (out)\n            //\n            // cur = proj_w*cur + proj_b\n  
          // [768, N]\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_mlp_proj_w,\n                    cur);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_mlp_proj_b);\n        }\n\n        // input for next layer\n        inpL = ggml_add(ctx, cur, inpFF);\n    }\n\n    // norm\n    {\n        // [ 768, N]\n        inpL = ggml_norm(ctx, inpL, hparams.eps);\n\n        // inpL = ln_f_g*inpL + ln_f_b\n        // [ 768, N]\n        inpL = ggml_add(ctx,\n                ggml_mul(ctx,\n                    inpL,\n                    model.ln_f_g),\n                model.ln_f_b);\n    }\n\n    // inpL = WTE * inpL\n    // [ 768, 50257] - model.lm_head\n    // [ 768, N]     - inpL\n    inpL = ggml_mul_mat(ctx, model.lm_head, inpL);\n\n    // logits -> probs\n    //inpL = ggml_soft_max(ctx0, inpL);\n\n    ggml_build_forward_expand(gf, inpL);\n\n    ggml_free(ctx);\n\n    return gf;\n}\n\nstatic void gpt2_kv_cache_seq_cp(\n        struct gpt2_kv_cache & cache,\n                 gpt2_seq_id   seq_id_src,\n                 gpt2_seq_id   seq_id_dst,\n                    gpt2_pos   p0,\n                    gpt2_pos   p1) {\n    if (p0 < 0) p0 = 0;\n    if (p1 < 0) p1 = std::numeric_limits<gpt2_pos>::max();\n\n    for (uint32_t i = 0; i < cache.size; ++i) {\n        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {\n            cache.cells[i].seq_id.insert(seq_id_dst);\n        }\n    }\n}\n\nstruct gpt2_batch gpt2_batch_init(int32_t n_tokens, int32_t embd) {\n    gpt2_batch batch;\n\n    if (embd) {\n        batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);\n    } else {\n        batch.token = (gpt_vocab::id *) malloc(sizeof(gpt_vocab::id) * n_tokens);\n    }\n\n    batch.pos    = (gpt2_pos *)    malloc(sizeof(gpt2_pos)    * n_tokens);\n    batch.seq_id = (gpt2_seq_id *) malloc(sizeof(gpt2_seq_id) * n_tokens);\n    batch.logits = 
(int8_t *)      malloc(sizeof(int8_t)      * n_tokens);\n\n    return batch;\n}\n\nvoid gpt2_batch_free(struct gpt2_batch batch) {\n    if (batch.token)  free(batch.token);\n    if (batch.embd)   free(batch.embd);\n    if (batch.pos)    free(batch.pos);\n    if (batch.seq_id) free(batch.seq_id);\n    if (batch.logits) free(batch.logits);\n}\n\n// Positive return values does not mean a fatal error, but rather a warning.\n//   0 - success\n// < 0 - error\nint gpt2_decode(\n        struct gpt2_model &  model,\n        ggml_gallocr_t       allocr,\n        struct gpt2_batch    batch,\n        int                  n_threads,\n        std::vector<float> & logits) {\n    const int32_t n_tokens = batch.n_tokens;\n    const auto &  hparams  = model.hparams;\n    const int     n_vocab  = hparams.n_vocab;\n\n    if (n_tokens == 0) {\n        printf(\"%s: n_tokens == 0\", __func__);\n        return -1;\n    }\n\n    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd));\n\n    auto & cache = model.kv_cache;\n\n    for (int i = 0; i < n_tokens; i++) {\n        cache.cells[cache.head + i].pos = batch.pos[i];\n        cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);\n    }\n\n    cache.n = cache.head + n_tokens;\n\n    struct ggml_cgraph * gf = gpt2_graph(model, batch, false);\n\n    // allocate tensors\n    ggml_gallocr_alloc_graph(allocr, gf);\n\n    // set the graph inputs\n    if (batch.token) {\n        struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, \"inp_tokens\");\n        ggml_backend_tensor_set(inp_tokens, batch.token, 0, n_tokens*ggml_element_size(inp_tokens));\n\n        struct ggml_tensor * position = ggml_graph_get_tensor(gf, \"position\");\n        for (int i = 0; i < n_tokens; ++i) {\n            int32_t v = batch.pos[i];\n            ggml_backend_tensor_set(position, &v, i*sizeof(int32_t), sizeof(v));\n        }\n    } else {\n        struct ggml_tensor * embd = ggml_graph_get_tensor(gf, \"embd\");\n        
ggml_backend_tensor_set(embd, batch.embd, 0, n_tokens * hparams.n_embd * ggml_element_size(embd));\n    }\n\n    {\n        struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, \"KQ_mask\");\n        const auto & kv_cache = model.kv_cache;\n        const int32_t n_tokens = batch.n_tokens;\n        const int32_t n_kv     = kv_cache.n;\n\n        std::vector<float> data_buf(n_kv*n_tokens);\n        const float neg_inf_v = -INFINITY;\n\n        for (int h = 0; h < 1; ++h) {\n            int h_offset = h*(n_kv*n_tokens);\n            for (int j = 0; j < n_tokens; ++j) {\n                const gpt2_pos    pos    = batch.pos[j];\n                const gpt2_seq_id seq_id = batch.seq_id[j];\n\n                for (int i = 0; i < n_kv; ++i) {\n                    if (!kv_cache.cells[i].has_seq_id(seq_id) || kv_cache.cells[i].pos > pos) {\n                        data_buf[h_offset + j*n_kv + i] = neg_inf_v;\n                    }\n                }\n            }\n        }\n\n        ggml_backend_tensor_set(KQ_mask, data_buf.data(), 0, data_buf.size() * sizeof(float));\n    }\n\n    // run the computation\n    if (ggml_backend_is_cpu(model.backend)) {\n        ggml_backend_cpu_set_n_threads(model.backend, n_threads);\n    }\n    ggml_backend_graph_compute(model.backend, gf);\n\n    //if (n_past%100 == 0) {\n    //    ggml_graph_print   (&gf);\n    //    ggml_graph_dump_dot(&gf, NULL, \"gpt-2.dot\");\n    //}\n\n    // in this case, the output tensor is the last one in the graph\n    struct ggml_tensor * inpL = ggml_graph_node(gf, -1);\n\n    if (batch.logits) {\n        // return logits for all tokens\n        logits.resize(n_vocab*n_tokens);\n        for (int32_t i = 0; i < n_tokens; i++) {\n            if (batch.logits[i] == 0) {\n                continue;\n            }\n            ggml_backend_tensor_get(inpL, logits.data() + n_vocab*i, n_vocab*i*sizeof(float), sizeof(float)*n_vocab);\n        }\n    } else {\n        // return result just for the last token\n      
  logits.resize(n_vocab);\n        ggml_backend_tensor_get(inpL, logits.data(), (n_vocab*(n_tokens-1))*sizeof(float), sizeof(float)*n_vocab);\n    }\n\n    // update the kv ring buffer\n    cache.head += n_tokens;\n\n    // ensure kv cache head points to a valid index.\n    if (cache.head >= cache.size) {\n        printf(\"%s: cache.head >= cache.size\\n\", __func__);\n        return -2;\n    }\n\n    return 0;\n}\n\nint main(int argc, char ** argv) {\n    ggml_time_init();\n\n    const int64_t t_main_start_us = ggml_time_us();\n\n    gpt_params params;\n\n    if (gpt_params_parse(argc, argv, params) == false) {\n        return 1;\n    }\n\n    if (params.seed < 0) {\n        params.seed = time(NULL);\n    }\n\n    printf(\"%s: seed = %d\\n\", __func__, params.seed);\n\n    std::mt19937 rng(params.seed);\n    if (params.prompt.empty()) {\n        params.prompt = gpt_random_prompt(rng);\n    }\n\n    int64_t t_load_us = 0;\n\n    gpt_vocab vocab;\n    gpt2_model model;\n\n    // load the model\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {\n            fprintf(stderr, \"%s: failed to load model from '%s'\\n\", __func__, params.model.c_str());\n            return 1;\n        }\n\n        t_load_us = ggml_time_us() - t_start_us;\n\n        test_gpt_tokenizer(vocab, params.token_test);\n    }\n\n    // tokenize the prompt\n    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);\n\n    const int n_parallel = params.n_parallel;\n    const int n_batch_max = std::max(embd_inp.size(), (size_t)n_parallel);\n\n    // create a gpt2_batch\n    // we use this object to submit token data for decoding\n    gpt2_batch batch = gpt2_batch_init(n_batch_max, 0);\n\n    // prepare required memory and allocate the compute buffer\n    ggml_gallocr_t allocr = NULL;\n    {\n        // create an allocator to measure the memory usage\n        allocr = 
ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n\n        // create the worst case graph for memory usage estimation\n        batch.n_tokens = n_batch_max;\n        struct ggml_cgraph * gf = gpt2_graph(model, batch, true);\n\n        // pre-allocate the compute buffer for the worst case (optional)\n        ggml_gallocr_reserve(allocr, gf);\n        size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);\n        fprintf(stderr, \"%s: compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0/1024.0);\n    }\n\n    int64_t t_sample_us  = 0;\n    int64_t t_predict_us = 0;\n\n    std::vector<float> logits;\n\n    // evaluate the initial prompt\n    batch.n_tokens = embd_inp.size();\n\n    for (int32_t i = 0; i < batch.n_tokens; i++) {\n        batch.token[i]  = embd_inp[i];\n        batch.pos[i]    = i;\n        batch.seq_id[i] = 0;\n        batch.logits[i] = false;\n    }\n\n    // gpt2_decode will output logits only for the last token of the prompt\n    batch.logits[batch.n_tokens - 1] = true;\n\n    if (gpt2_decode(model, allocr, batch, params.n_threads, logits) != 0) {\n        printf(\"%s: gpt2_decode() failed\\n\", __func__);\n        return 1;\n    }\n\n    // assign the system KV cache to all parallel sequences\n    // this way, the parallel sequences will \"reuse\" the prompt tokens without having to copy them\n    for (int32_t i = 1; i < n_parallel; ++i) {\n        gpt2_kv_cache_seq_cp(model.kv_cache, 0, i, 0, batch.n_tokens);\n    }\n\n    if (n_parallel > 1) {\n        printf(\"\\n\\n%s: generating %d sequences ...\\n\", __func__, n_parallel);\n    }\n\n    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());\n\n    printf(\"%s: prompt: '%s'\\n\", __func__, params.prompt.c_str());\n    printf(\"%s: number of tokens in prompt = %zu, first 8 tokens: \", __func__, embd_inp.size());\n    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {\n        printf(\"%d \", embd_inp[i]);\n    
}\n    printf(\"\\n\\n\");\n\n    std::vector<gpt_vocab::token> streams(n_parallel);\n\n    // remember the batch index of the last token for each parallel sequence\n    // we need this to determine which logits to sample from\n    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);\n\n    int n_cur     = batch.n_tokens;\n    int n_len     = batch.n_tokens + params.n_predict;\n    int n_decoded = 0;\n\n    const int   n_vocab = model.hparams.n_vocab;\n    const int   top_k = params.top_k;\n    const float top_p = params.top_p;\n    const float temp  = params.temp;\n\n    while (n_cur < n_len) {\n        batch.n_tokens = 0;\n\n        for (int32_t i = 0; i < n_parallel; ++i) {\n            if (i_batch[i] < 0) {\n                // the stream has already finished\n                continue;\n            }\n\n            auto * logits_i = logits.data() + i_batch[i]*n_vocab;\n\n            gpt_vocab::id id = 0;\n            {\n                const int64_t t_start_sample_us = ggml_time_us();\n\n                id = gpt_sample_top_k_top_p(vocab, logits_i, top_k, top_p, temp, rng);\n\n                t_sample_us += ggml_time_us() - t_start_sample_us;\n            }\n\n            // is it an end of stream? 
-> mark the stream as finished\n            if ((!params.ignore_eos && id == 50256) || n_cur == n_len - 1) {\n                i_batch[i] = -1;\n                printf(\"\\n\");\n                if (n_parallel > 1) {\n                    printf(\"%s: stream %d finished at n_cur = %d\", __func__, i, n_cur);\n                }\n\n                continue;\n            }\n\n            auto& token = vocab.id_to_token[id];\n            if (n_parallel == 1) {\n                printf(\"%s\", token.c_str());\n                fflush(stdout);\n            }\n\n            streams[i] += token;\n\n            // push this new token for next evaluation\n            batch.token [batch.n_tokens] = id;\n            batch.pos   [batch.n_tokens] = n_cur;\n            batch.seq_id[batch.n_tokens] = i;\n            batch.logits[batch.n_tokens] = true;\n\n            i_batch[i] = batch.n_tokens;\n\n            batch.n_tokens += 1;\n\n            n_decoded += 1;\n        }\n\n        // all streams are finished\n        if (batch.n_tokens == 0) {\n            break;\n        }\n\n        n_cur += 1;\n\n        {\n            const int64_t t_start_us = ggml_time_us();\n\n            // evaluate the current batch with the transformer model\n            int ret_code = gpt2_decode(model, allocr, batch, params.n_threads, logits);\n            if (ret_code != 0) {\n                fprintf(stderr, \"%s : failed to eval, return code %d\\n\", __func__, ret_code);\n                return 1;\n            }\n\n            t_predict_us += ggml_time_us() - t_start_us;\n        }\n    }\n\n    if (n_parallel > 1) {\n        printf(\"\\n\");\n\n        for (int32_t i = 0; i < n_parallel; ++i) {\n            printf(\"sequence %d:\\n\\n%s%s\\n\\n\", i, params.prompt.c_str(), streams[i].c_str());\n        }\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = ggml_time_us();\n\n        printf(\"\\n\\n\");\n        printf(\"%s:     n_decoded = %8d\\n\",      __func__, n_decoded);\n  
      printf(\"%s:     load time = %8.2f ms\\n\", __func__, t_load_us/1000.0f);\n        printf(\"%s:   sample time = %8.2f ms\\n\", __func__, t_sample_us/1000.0f);\n        printf(\"%s:  predict time = %8.2f ms\\n\", __func__, t_predict_us/1000.0f);\n        printf(\"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - t_main_start_us)/1000.0f);\n    }\n\n    gpt2_batch_free(batch);\n    ggml_free(model.ctx_w);\n\n    ggml_gallocr_free(allocr);\n    ggml_backend_buffer_free(model.buffer_w);\n    ggml_backend_buffer_free(model.kv_cache.buffer);\n    ggml_backend_free(model.backend);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/gpt-2/main-ctx.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"common.h\"\n#include \"common-ggml.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\n// default hparams (GPT-2 117M)\nstruct gpt2_hparams {\n    int32_t n_vocab = 50257;\n    int32_t n_ctx   = 1024;\n    int32_t n_embd  = 768;\n    int32_t n_head  = 12;\n    int32_t n_layer = 12;\n    int32_t ftype   = 1;\n    float   eps     = 1e-5f;\n};\n\nstruct gpt2_layer {\n    // normalization\n    struct ggml_tensor * ln_1_g;\n    struct ggml_tensor * ln_1_b;\n\n    struct ggml_tensor * ln_2_g;\n    struct ggml_tensor * ln_2_b;\n\n    // attention\n    struct ggml_tensor * c_attn_attn_w;\n    struct ggml_tensor * c_attn_attn_b;\n\n    struct ggml_tensor * c_attn_proj_w;\n    struct ggml_tensor * c_attn_proj_b;\n\n    // mlp\n    struct ggml_tensor * c_mlp_fc_w;\n    struct ggml_tensor * c_mlp_fc_b;\n\n    struct ggml_tensor * c_mlp_proj_w;\n    struct ggml_tensor * c_mlp_proj_b;\n};\n\nstruct gpt2_model {\n    gpt2_hparams hparams;\n\n    // normalization\n    struct ggml_tensor * ln_f_g;\n    struct ggml_tensor * ln_f_b;\n\n    struct ggml_tensor * wte;     //    token embedding\n    struct ggml_tensor * wpe;     // position embedding\n    struct ggml_tensor * lm_head; // language model head\n\n    std::vector<gpt2_layer> layers;\n\n    // key + value memory\n    struct ggml_tensor * memory_k;\n    struct ggml_tensor * memory_v;\n\n    //\n    struct ggml_context * ctx_w;\n    std::map<std::string, struct ggml_tensor *> tensors;\n};\n\n// load the model's weights from a file\nbool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {\n    printf(\"%s: loading model from '%s'\\n\", __func__, fname.c_str());\n\n    auto fin = std::ifstream(fname, std::ios::binary);\n    if 
(!fin) {\n        fprintf(stderr, \"%s: failed to open '%s'\\n\", __func__, fname.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        fin.read((char *) &magic, sizeof(magic));\n        if (magic != GGML_FILE_MAGIC) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, fname.c_str());\n            return false;\n        }\n    }\n\n    // load hparams\n    {\n        auto & hparams = model.hparams;\n\n        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));\n\n        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;\n\n        printf(\"%s: n_vocab = %d\\n\", __func__, hparams.n_vocab);\n        printf(\"%s: n_ctx   = %d\\n\", __func__, hparams.n_ctx);\n        printf(\"%s: n_embd  = %d\\n\", __func__, hparams.n_embd);\n        printf(\"%s: n_head  = %d\\n\", __func__, hparams.n_head);\n        printf(\"%s: n_layer = %d\\n\", __func__, hparams.n_layer);\n        printf(\"%s: ftype   = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: qntvr   = %d\\n\", __func__, qntvr);\n\n        hparams.ftype %= GGML_QNT_VERSION_FACTOR;\n    }\n\n    // load vocab\n    {\n        int32_t n_vocab = 0;\n        fin.read((char *) &n_vocab, sizeof(n_vocab));\n\n        if (n_vocab != model.hparams.n_vocab) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad vocab size %d != %d)\\n\",\n                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);\n            return false;\n        }\n\n        std::string word;\n        std::vector<char> buf(128);\n\n        for (int i = 0; i < n_vocab; i++) {\n 
           uint32_t len;\n            fin.read((char *) &len, sizeof(len));\n\n            buf.resize(len);\n            fin.read((char *) buf.data(), len);\n            word.assign(buf.data(), len);\n\n            vocab.token_to_id[word] = i;\n            vocab.id_to_token[i] = word;\n        }\n    }\n\n    // for the big tensors, we have the option to store the data in 16-bit floats or quantized\n    // in order to save memory and also to speed up the computation\n    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));\n    if (wtype == GGML_TYPE_COUNT) {\n        fprintf(stderr, \"%s: invalid model file '%s' (bad ftype value %d)\\n\",\n                __func__, fname.c_str(), model.hparams.ftype);\n        return false;\n    }\n\n    auto & ctx = model.ctx_w;\n\n    size_t ctx_size = 0;\n\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g\n        ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b\n\n        ctx_size += ggml_row_size(wtype,         n_vocab*n_embd); // wte\n        ctx_size += ggml_row_size(GGML_TYPE_F32,   n_ctx*n_embd); // wpe\n        ctx_size += ggml_row_size(wtype,         n_vocab*n_embd); // lm_head\n\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b\n\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         3*n_embd*n_embd)); // c_attn_attn_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd));        // c_attn_attn_b\n\n        ctx_size += 
n_layer*(ggml_row_size(wtype,         n_embd*n_embd));   // c_attn_proj_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd));          // c_attn_proj_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_fc_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_fc_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_proj_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_proj_b\n\n        ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k\n        ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v\n\n        ctx_size += (6 + 12*n_layer)*512; // object overhead\n\n        printf(\"%s: ggml tensor size = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n        printf(\"%s: ggml ctx size = %6.2f MB\\n\", __func__, ctx_size/(1024.0*1024.0));\n    }\n\n    // create the ggml context\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ ctx_size,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ false,\n        };\n\n        model.ctx_w = ggml_init(params);\n        if (!model.ctx_w) {\n            fprintf(stderr, \"%s: ggml_init() failed\\n\", __func__);\n            return false;\n        }\n    }\n\n    // prepare memory for the weights\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        model.layers.resize(n_layer);\n\n        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n\n        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 
n_ctx);\n        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n\n        // map by name\n        model.tensors[\"model/ln_f/g\"] = model.ln_f_g;\n        model.tensors[\"model/ln_f/b\"] = model.ln_f_b;\n\n        model.tensors[\"model/wte\"]     = model.wte;\n        model.tensors[\"model/wpe\"]     = model.wpe;\n        model.tensors[\"model/lm_head\"] = model.lm_head;\n\n        for (int i = 0; i < n_layer; ++i) {\n            auto & layer = model.layers[i];\n\n            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);\n            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);\n\n            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);\n            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);\n            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);\n\n            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);\n            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            // map by name\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/g\"]        = layer.ln_1_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/b\"]        = layer.ln_1_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/g\"]        = layer.ln_2_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/b\"]        = layer.ln_2_b;\n\n            
model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/w\"] = layer.c_attn_attn_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/b\"] = layer.c_attn_attn_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/w\"] = layer.c_attn_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/b\"] = layer.c_attn_proj_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/w\"]    = layer.c_mlp_fc_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/b\"]    = layer.c_mlp_fc_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/w\"]  = layer.c_mlp_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/b\"]  = layer.c_mlp_proj_b;\n        }\n    }\n\n    // key + value memory\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n\n        const int n_mem      = n_layer*n_ctx;\n        const int n_elements = n_embd*n_mem;\n\n        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n\n        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);\n\n        printf(\"%s: memory size = %8.2f MB, n_mem = %d\\n\", __func__, memory_size/1024.0/1024.0, n_mem);\n    }\n\n    // load weights\n    {\n        size_t total_size = 0;\n\n        bool has_lm_head = false;\n\n        while (true) {\n            int32_t n_dims;\n            int32_t length;\n            int32_t ttype;\n\n            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));\n            fin.read(reinterpret_cast<char *>(&length), sizeof(length));\n            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));\n\n            if (fin.eof()) {\n           
     break;\n            }\n\n            int32_t nelements = 1;\n            int32_t ne[2] = { 1, 1 };\n            for (int i = 0; i < n_dims; ++i) {\n                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));\n                nelements *= ne[i];\n            }\n\n            std::string name(length, 0);\n            fin.read(&name[0], length);\n\n            if (model.tensors.find(name) == model.tensors.end()) {\n                fprintf(stderr, \"%s: unknown tensor '%s' in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            auto tensor = model.tensors[name];\n            if (ggml_nelements(tensor) != nelements) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\\n\",\n                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);\n                return false;\n            }\n\n            // for debugging\n            if (0) {\n                printf(\"%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\\n\", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));\n            }\n\n            const size_t bpe = ggml_type_size(ggml_type(ttype));\n\n            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\\n\",\n                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);\n                return false;\n            }\n\n            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));\n\n            // GPT-2 models share the WTE tensor 
as the LM head\n            if (name == \"model/wte\" && has_lm_head == false) {\n                memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));\n            }\n\n            if (name == \"model/lm_head\") {\n                has_lm_head = true;\n            }\n\n            total_size += ggml_nbytes(tensor);\n        }\n\n        printf(\"%s: model size  = %8.2f MB\\n\", __func__, total_size/1024.0/1024.0);\n    }\n\n    fin.close();\n\n    return true;\n}\n\n// evaluate the transformer\n//\n//   - model:     the model\n//   - n_threads: number of threads to use\n//   - n_past:    the context size so far\n//   - embd_inp:  the embeddings of the tokens in the context\n//   - embd_w:    the predicted logits for the next token\n//\nbool gpt2_eval(\n        const gpt2_model & model,\n        const int n_threads,\n        const int n_past,\n        const std::vector<gpt_vocab::id> & embd_inp,\n              std::vector<float>         & embd_w,\n              size_t                     & mem_per_token) {\n    const int N = embd_inp.size();\n\n    const auto & hparams = model.hparams;\n\n    const int n_embd  = hparams.n_embd;\n    const int n_layer = hparams.n_layer;\n    const int n_ctx   = hparams.n_ctx;\n    const int n_head  = hparams.n_head;\n    const int n_vocab = hparams.n_vocab;\n\n    static size_t buf_size = 256u*1024*1024;\n    static void * buf = malloc(buf_size);\n\n    if (mem_per_token > 0 && mem_per_token*N > buf_size) {\n        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead\n        //printf(\"\\n%s: reallocating buffer from %zu to %zu bytes\\n\", __func__, buf_size, buf_size_new);\n\n        // reallocate\n        buf_size = buf_size_new;\n        buf = realloc(buf, buf_size);\n        if (buf == nullptr) {\n            fprintf(stderr, \"%s: failed to allocate %zu bytes\\n\", __func__, buf_size);\n            return false;\n        }\n    }\n\n    struct ggml_init_params params = {\n  
      /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf,\n        /*.no_alloc   =*/ false,\n    };\n\n    struct ggml_context * ctx0 = ggml_init(params);\n    struct ggml_cgraph * gf = ggml_new_graph(ctx0);\n\n    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);\n    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));\n\n    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);\n    for (int i = 0; i < N; ++i) {\n        ((int32_t *) position->data)[i] = n_past + i;\n    }\n\n    // wte + wpe\n    struct ggml_tensor * inpL =\n        ggml_add(ctx0,\n                ggml_get_rows(ctx0, model.wte, embd),\n                ggml_get_rows(ctx0, model.wpe, position));\n\n    for (int il = 0; il < n_layer; ++il) {\n        struct ggml_tensor * cur;\n\n        // norm\n        {\n            // [ 768, N]\n            cur = ggml_norm(ctx0, inpL, hparams.eps);\n\n            // cur = ln_1_g*cur + ln_1_b\n            // [ 768, N]\n            cur = ggml_add(ctx0,\n                    ggml_mul(ctx0,\n                        ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),\n                        cur),\n                    ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));\n        }\n\n        // attn\n        // [2304, 768] - model.layers[il].c_attn_attn_w\n        // [2304,   1] - model.layers[il].c_attn_attn_b\n        // [ 768,   N] - cur (in)\n        // [2304,   N] - cur (out)\n        //\n        // cur = attn_w*cur + attn_b\n        // [2304, N]\n        {\n            cur = ggml_mul_mat(ctx0,\n                    model.layers[il].c_attn_attn_w,\n                    cur);\n\n            cur = ggml_add(ctx0,\n                    ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),\n                    cur);\n        }\n\n        // self-attention\n        {\n            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);\n            struct ggml_tensor * 
Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);\n            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);\n\n            // store key and value to memory\n            if (N >= 1) {\n                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));\n                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));\n\n                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));\n                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));\n            }\n\n            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)\n            // [64, N, 12]\n            struct ggml_tensor * Q =\n                ggml_permute(ctx0,\n                        ggml_cpy(ctx0,\n                            Qcur,\n                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),\n                        0, 2, 1, 3);\n\n            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)\n            // [64, n_past + N, 12]\n            struct ggml_tensor * K =\n                ggml_permute(ctx0,\n                        ggml_reshape_3d(ctx0,\n                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),\n                            n_embd/n_head, n_head, n_past + N),\n                        0, 2, 1, 3);\n\n            // GG: flash attention\n            //struct ggml_tensor * V =\n            //    ggml_cpy(ctx0,\n            //            ggml_permute(ctx0,\n            //                ggml_reshape_3d(ctx0,\n            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),\n            //         
           n_embd/n_head, n_head, n_past + N),\n            //                1, 2, 0, 3),\n            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));\n\n            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);\n\n            // K * Q\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);\n\n            // KQ_scaled = KQ / sqrt(n_embd/n_head)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f/sqrt(float(n_embd)/n_head));\n\n            // KQ_masked = mask_past(KQ_scaled)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);\n\n            // KQ = soft_max(KQ_masked)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);\n\n            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()\n            // [n_past + N, 64, 12]\n            struct ggml_tensor * V_trans =\n                ggml_cpy(ctx0,\n                        ggml_permute(ctx0,\n                            ggml_reshape_3d(ctx0,\n                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),\n                                n_embd/n_head, n_head, n_past + N),\n                            1, 2, 0, 3),\n                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));\n\n            // KQV = transpose(V) * KQ_soft_max\n            // [64, N, 12]\n            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);\n\n            // KQV_merged = KQV.permute(0, 2, 1, 3)\n            // [64, 12, N]\n            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);\n\n            // cur = 
KQV_merged.contiguous().view(n_embd, N)\n            // [768, N]\n            cur = ggml_cpy(ctx0,\n                    KQV_merged,\n                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));\n        }\n\n        // projection\n        // [ 768, 768] - model.layers[il].c_attn_proj_w\n        // [ 768,   1] - model.layers[il].c_attn_proj_b\n        // [ 768,   N] - cur (in)\n        // [ 768,   N] - cur (out)\n        //\n        // cur = proj_w*cur + proj_b\n        // [768, N]\n        {\n            cur = ggml_mul_mat(ctx0,\n                    model.layers[il].c_attn_proj_w,\n                    cur);\n\n            cur = ggml_add(ctx0,\n                    ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),\n                    cur);\n        }\n\n        // add the input\n        cur = ggml_add(ctx0, cur, inpL);\n\n        struct ggml_tensor * inpFF = cur;\n\n        // feed-forward network\n        {\n            // norm\n            {\n                cur = ggml_norm(ctx0, inpFF, hparams.eps);\n\n                // cur = ln_2_g*cur + ln_2_b\n                // [ 768, N]\n                cur = ggml_add(ctx0,\n                        ggml_mul(ctx0,\n                            ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),\n                            cur),\n                        ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));\n            }\n\n            // fully connected\n            // [3072, 768] - model.layers[il].c_mlp_fc_w\n            // [3072,   1] - model.layers[il].c_mlp_fc_b\n            // [ 768,   N] - cur (in)\n            // [3072,   N] - cur (out)\n            //\n            // cur = fc_w*cur + fc_b\n            // [3072, N]\n            cur = ggml_mul_mat(ctx0,\n                    model.layers[il].c_mlp_fc_w,\n                    cur);\n\n            cur = ggml_add(ctx0,\n                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),\n                    cur);\n\n            // GELU activation\n          
  // [3072, N]\n            cur = ggml_gelu(ctx0, cur);\n\n            // projection\n            // [ 768, 3072] - model.layers[il].c_mlp_proj_w\n            // [ 768,    1] - model.layers[il].c_mlp_proj_b\n            // [3072,    N] - cur (in)\n            // [ 768,    N] - cur (out)\n            //\n            // cur = proj_w*cur + proj_b\n            // [768, N]\n            cur = ggml_mul_mat(ctx0,\n                    model.layers[il].c_mlp_proj_w,\n                    cur);\n\n            cur = ggml_add(ctx0,\n                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),\n                    cur);\n        }\n\n        // input for next layer\n        inpL = ggml_add(ctx0, cur, inpFF);\n    }\n\n    // norm\n    {\n        // [ 768, N]\n        inpL = ggml_norm(ctx0, inpL, hparams.eps);\n\n        // inpL = ln_f_g*inpL + ln_f_b\n        // [ 768, N]\n        inpL = ggml_add(ctx0,\n                ggml_mul(ctx0,\n                    ggml_repeat(ctx0, model.ln_f_g, inpL),\n                    inpL),\n                ggml_repeat(ctx0, model.ln_f_b, inpL));\n    }\n\n    // inpL = WTE * inpL\n    // [ 768, 50257] - model.lm_head\n    // [ 768, N]     - inpL\n    inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);\n\n    // logits -> probs\n    //inpL = ggml_soft_max_inplace(ctx0, inpL);\n\n    // run the computation\n    ggml_build_forward_expand(gf, inpL);\n    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);\n\n    //if (n_past%100 == 0) {\n    //    ggml_graph_print   (&gf);\n    //    ggml_graph_dump_dot(&gf, NULL, \"gpt-2.dot\");\n    //}\n\n    //embd_w.resize(n_vocab*N);\n    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);\n\n    // return result just for the last token\n    embd_w.resize(n_vocab);\n    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);\n\n    if (mem_per_token == 0) {\n        mem_per_token = ggml_used_mem(ctx0)/N;\n    }\n    //printf(\"used_mem = %zu\\n\", 
ggml_used_mem(ctx0));\n\n    ggml_free(ctx0);\n\n    return true;\n}\n\nint main(int argc, char ** argv) {\n    ggml_time_init();\n\n    const int64_t t_main_start_us = ggml_time_us();\n\n    gpt_params params;\n    params.model = \"models/gpt-2-117M/ggml-model.bin\";\n\n    if (gpt_params_parse(argc, argv, params) == false) {\n        return 1;\n    }\n\n    if (params.seed < 0) {\n        params.seed = time(NULL);\n    }\n\n    printf(\"%s: seed = %d\\n\", __func__, params.seed);\n\n    std::mt19937 rng(params.seed);\n    if (params.prompt.empty()) {\n        params.prompt = gpt_random_prompt(rng);\n    }\n\n    int64_t t_load_us = 0;\n\n    gpt_vocab vocab;\n    gpt2_model model;\n\n    // load the model\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!gpt2_model_load(params.model, model, vocab)) {\n            fprintf(stderr, \"%s: failed to load model from '%s'\\n\", __func__, params.model.c_str());\n            return 1;\n        }\n\n        t_load_us = ggml_time_us() - t_start_us;\n\n        test_gpt_tokenizer(vocab, params.token_test);\n    }\n\n    int n_past = 0;\n\n    int64_t t_sample_us  = 0;\n    int64_t t_predict_us = 0;\n\n    std::vector<float> logits;\n\n    // tokenize the prompt\n    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);\n\n    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());\n\n    printf(\"%s: prompt: '%s'\\n\", __func__, params.prompt.c_str());\n    printf(\"%s: number of tokens in prompt = %zu, first 8 tokens: \", __func__, embd_inp.size());\n    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {\n        printf(\"%d \", embd_inp[i]);\n    }\n    printf(\"\\n\\n\");\n\n    // submit the input prompt token-by-token\n    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning\n    std::vector<gpt_vocab::id> embd;\n\n    // determine the required inference memory per token:\n    size_t 
mem_per_token = 0;\n    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);\n\n    for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {\n        // predict\n        if (embd.size() > 0) {\n            const int64_t t_start_us = ggml_time_us();\n\n            if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {\n                printf(\"Failed to predict\\n\");\n                return 1;\n            }\n\n            t_predict_us += ggml_time_us() - t_start_us;\n        }\n\n        n_past += embd.size();\n        embd.clear();\n\n        if (i >= embd_inp.size()) {\n            // sample next token\n            const int   top_k = params.top_k;\n            const float top_p = params.top_p;\n            const float temp  = params.temp;\n\n            const int n_vocab = model.hparams.n_vocab;\n\n            gpt_vocab::id id = 0;\n\n            {\n                const int64_t t_start_sample_us = ggml_time_us();\n\n                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);\n\n                t_sample_us += ggml_time_us() - t_start_sample_us;\n            }\n\n            // add it to the context\n            embd.push_back(id);\n        } else {\n            // if here, it means we are still processing the input prompt\n            for (size_t k = i; k < embd_inp.size(); k++) {\n                embd.push_back(embd_inp[k]);\n                if (int32_t(embd.size()) >= params.n_batch) {\n                    break;\n                }\n            }\n            i += embd.size() - 1;\n        }\n\n        // display text\n        for (auto id : embd) {\n            printf(\"%s\", vocab.id_to_token[id].c_str());\n        }\n        fflush(stdout);\n\n        // end of text token\n        if (embd.back() == 50256) {\n            break;\n        }\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = 
ggml_time_us();\n\n        printf(\"\\n\\n\");\n        printf(\"%s: mem per token = %8zu bytes\\n\", __func__, mem_per_token);\n        printf(\"%s:     load time = %8.2f ms\\n\", __func__, t_load_us/1000.0f);\n        printf(\"%s:   sample time = %8.2f ms\\n\", __func__, t_sample_us/1000.0f);\n        printf(\"%s:  predict time = %8.2f ms / %.2f ms per token\\n\", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);\n        printf(\"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - t_main_start_us)/1000.0f);\n    }\n\n    ggml_free(model.ctx_w);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/gpt-2/main-sched.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#ifdef GGML_USE_BLAS\n#include \"ggml-blas.h\"\n#endif\n\n#include \"common.h\"\n#include \"common-ggml.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\n#define GPT2_MAX_NODES 4096\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\n// default hparams (GPT-2 117M)\nstruct gpt2_hparams {\n    int32_t n_vocab = 50257;\n    int32_t n_ctx   = 1024;\n    int32_t n_embd  = 768;\n    int32_t n_head  = 12;\n    int32_t n_layer = 12;\n    int32_t ftype   = 1;\n    float   eps     = 1e-5f;\n};\n\nstruct gpt2_layer {\n    // normalization\n    struct ggml_tensor * ln_1_g;\n    struct ggml_tensor * ln_1_b;\n\n    struct ggml_tensor * ln_2_g;\n    struct ggml_tensor * ln_2_b;\n\n    // attention\n    struct ggml_tensor * c_attn_attn_w;\n    struct ggml_tensor * c_attn_attn_b;\n\n    struct ggml_tensor * c_attn_proj_w;\n    struct ggml_tensor * c_attn_proj_b;\n\n    // mlp\n    struct ggml_tensor * c_mlp_fc_w;\n    struct ggml_tensor * c_mlp_fc_b;\n\n    struct ggml_tensor * c_mlp_proj_w;\n    struct ggml_tensor * c_mlp_proj_b;\n};\n\nstruct gpt2_model {\n    gpt2_hparams hparams;\n\n    // normalization\n    struct ggml_tensor * ln_f_g;\n    struct ggml_tensor * ln_f_b;\n\n    struct ggml_tensor * wte;     //    tkoen embedding\n    struct ggml_tensor * wpe;     // position embedding\n    struct ggml_tensor * lm_head; // language model head\n\n    std::vector<gpt2_layer> layers;\n\n    // key + value 
memory\n    struct ggml_tensor * memory_k;\n    struct ggml_tensor * memory_v;\n\n    //\n    struct ggml_context * ctx_w;\n\n    std::vector<ggml_backend_t> backends;\n    std::vector<ggml_backend_buffer_t> buffers_w;\n    ggml_backend_buffer_t buffer_kv;\n    ggml_backend_buffer_t buffer_input;\n\n    std::map<std::string, struct ggml_tensor *> tensors;\n\n    // inputs/constants\n    struct ggml_tensor * embd;\n    struct ggml_tensor * position;\n};\n\nvoid init_backends(gpt2_model & model, const gpt_params & params) {\n    ggml_backend_t gpu_backend = NULL;\n\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    // initialize the backends\n#ifdef GGML_USE_CUDA\n    if (params.n_gpu_layers > 0) {\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        gpu_backend = ggml_backend_cuda_init(0);\n        if (!gpu_backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (params.n_gpu_layers > 0) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        gpu_backend = ggml_backend_metal_init();\n        if (!gpu_backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n    if (gpu_backend) {\n        model.backends.push_back(gpu_backend);\n    }\n\n#ifdef GGML_USE_BLAS\n    ggml_backend_t blas_backend = ggml_backend_blas_init();\n    if (!blas_backend) {\n        fprintf(stderr, \"%s: failed to initialize BLAS backend\\n\", __func__);\n    } else {\n        ggml_backend_blas_set_n_threads(blas_backend, params.n_threads);\n        model.backends.push_back(blas_backend);\n    }\n#endif\n\n    // always add the CPU backend as a fallback\n    ggml_backend_t cpu_backend = ggml_backend_cpu_init();\n    ggml_backend_cpu_set_n_threads(cpu_backend, params.n_threads);\n    model.backends.push_back(cpu_backend);\n}\n\n// load the model's weights from a file\nbool 
gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, const gpt_params & params) {\n    printf(\"%s: loading model from '%s'\\n\", __func__, fname.c_str());\n\n    auto fin = std::ifstream(fname, std::ios::binary);\n    if (!fin) {\n        fprintf(stderr, \"%s: failed to open '%s'\\n\", __func__, fname.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        fin.read((char *) &magic, sizeof(magic));\n        if (magic != GGML_FILE_MAGIC) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, fname.c_str());\n            return false;\n        }\n    }\n\n    // load hparams\n    {\n        auto & hparams = model.hparams;\n\n        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));\n\n        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;\n\n        printf(\"%s: n_vocab = %d\\n\", __func__, hparams.n_vocab);\n        printf(\"%s: n_ctx   = %d\\n\", __func__, hparams.n_ctx);\n        printf(\"%s: n_embd  = %d\\n\", __func__, hparams.n_embd);\n        printf(\"%s: n_head  = %d\\n\", __func__, hparams.n_head);\n        printf(\"%s: n_layer = %d\\n\", __func__, hparams.n_layer);\n        printf(\"%s: ftype   = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: qntvr   = %d\\n\", __func__, qntvr);\n\n        hparams.ftype %= GGML_QNT_VERSION_FACTOR;\n    }\n\n    // load vocab\n    {\n        int32_t n_vocab = 0;\n        fin.read((char *) &n_vocab, sizeof(n_vocab));\n\n        if (n_vocab != model.hparams.n_vocab) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad vocab 
size %d != %d)\\n\",\n                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);\n            return false;\n        }\n\n        std::string word;\n        std::vector<char> buf(128);\n\n        for (int i = 0; i < n_vocab; i++) {\n            uint32_t len;\n            fin.read((char *) &len, sizeof(len));\n\n            buf.resize(len);\n            fin.read((char *) buf.data(), len);\n            word.assign(buf.data(), len);\n\n            vocab.token_to_id[word] = i;\n            vocab.id_to_token[i] = word;\n        }\n    }\n\n    // for the big tensors, we have the option to store the data in 16-bit floats or quantized\n    // in order to save memory and also to speed up the computation\n    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));\n    if (wtype == GGML_TYPE_COUNT) {\n        fprintf(stderr, \"%s: invalid model file '%s' (bad ftype value %d)\\n\",\n                __func__, fname.c_str(), model.hparams.ftype);\n        return false;\n    }\n\n    auto & ctx = model.ctx_w;\n\n    // create the ggml context\n    {\n        size_t n_tensors = 3 /* input */ + 2 /* kv */ + 6 + 12*model.hparams.n_layer;\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n        };\n\n        model.ctx_w = ggml_init(params);\n        if (!model.ctx_w) {\n            fprintf(stderr, \"%s: ggml_init() failed\\n\", __func__);\n            return false;\n        }\n    }\n\n    // create tensors for the weights\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        model.layers.resize(n_layer);\n\n        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n        model.ln_f_b = 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n\n        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);\n        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n\n        // map by name\n        model.tensors[\"model/ln_f/g\"] = model.ln_f_g;\n        model.tensors[\"model/ln_f/b\"] = model.ln_f_b;\n\n        model.tensors[\"model/wte\"]     = model.wte;\n        model.tensors[\"model/wpe\"]     = model.wpe;\n        model.tensors[\"model/lm_head\"] = model.lm_head;\n\n        for (int i = 0; i < n_layer; ++i) {\n            auto & layer = model.layers[i];\n\n            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);\n            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);\n\n            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);\n            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);\n            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);\n\n            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);\n            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            // map by name\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/g\"]        = layer.ln_1_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_1/b\"]        = layer.ln_1_b;\n\n        
    model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/g\"]        = layer.ln_2_g;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/ln_2/b\"]        = layer.ln_2_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/w\"] = layer.c_attn_attn_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_attn/b\"] = layer.c_attn_attn_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/w\"] = layer.c_attn_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/attn/c_proj/b\"] = layer.c_attn_proj_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/w\"]    = layer.c_mlp_fc_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_fc/b\"]    = layer.c_mlp_fc_b;\n\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/w\"]  = layer.c_mlp_proj_w;\n            model.tensors[\"model/h\" + std::to_string(i) + \"/mlp/c_proj/b\"]  = layer.c_mlp_proj_b;\n        }\n    }\n\n    // assign tensors to backends\n    init_backends(model, params);\n    ggml_backend_t backend_gpu = model.backends.front();\n    ggml_backend_t backend_cpu = model.backends.back();\n    std::map<std::string, ggml_backend_t> tensor_backends;\n    {\n        const int i_gpu_first_layer = model.hparams.n_layer - params.n_gpu_layers;\n        for (auto it : model.tensors) {\n            const std::string & name = it.first;\n            // input tensors\n            if (name == \"model/wte\" || name == \"model/wpe\") {\n                if (params.n_gpu_layers > model.hparams.n_layer) {\n                    tensor_backends[name] = backend_gpu;\n                } else {\n                    tensor_backends[name] = backend_cpu;\n                }\n            }\n            // output tensors\n            if (name == \"model/ln_f/g\" || name == \"model/ln_f/b\" || name == \"model/lm_head\") {\n                if (params.n_gpu_layers > 0) {\n             
       tensor_backends[name] = backend_gpu;\n                } else {\n                    tensor_backends[name] = backend_cpu;\n                }\n            }\n            // layer tensors\n            if (name.substr(0, 7) == \"model/h\") {\n                // parse layer number\n                int layer = std::stoi(name.substr(7, 2));\n                if (layer >= i_gpu_first_layer) {\n                    tensor_backends[name] = backend_gpu;\n                } else {\n                    tensor_backends[name] = backend_cpu;\n                }\n            }\n        }\n    }\n\n    // allocate buffers\n    std::map<ggml_backend_t, ggml_tallocr> backend_buffers;\n    for (auto backend : model.backends) {\n        // compute the size of the buffer\n        size_t size = 0;\n        for (auto it : model.tensors) {\n            if (tensor_backends[it.first] == backend) {\n                size += ggml_nbytes(it.second) + 512;\n            }\n        }\n        if (size > 0) {\n            printf(\"%s: %8s buffer size = %8.2f MB\\n\", __func__, ggml_backend_name(backend), size/1024.0/1024.0);\n            // allocate the buffer\n            ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size);\n            ggml_backend_buffer_set_usage(buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);\n            model.buffers_w.push_back(buffer);\n\n            // create an allocator for the buffer to allocate the tensors\n            auto alloc = ggml_tallocr_new(buffer);\n            backend_buffers.insert(std::make_pair(backend, std::move(alloc)));\n        } else {\n            model.buffers_w.push_back(NULL);\n        }\n    }\n\n    // allocate key + value memory\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n\n        const int n_mem      = n_layer*n_ctx;\n        const int n_elements = n_embd*n_mem;\n\n        
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);\n\n        ggml_set_name(model.memory_k, \"model/memory_k\");\n        ggml_set_name(model.memory_v, \"model/memory_v\");\n\n        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);\n\n        printf(\"%s: memory size = %8.2f MB, n_mem = %d\\n\", __func__, memory_size/1024.0/1024.0, n_mem);\n\n        // create a backend buffer (can be in host or device memory)\n        ggml_backend_t backend_kv = params.n_gpu_layers >= hparams.n_layer/2 ? backend_gpu : backend_cpu;\n        printf(\"%s: backend_kv = %s\\n\", __func__, ggml_backend_name(backend_kv));\n        model.buffer_kv = ggml_backend_alloc_buffer(backend_kv, memory_size + 512*2);\n\n        // allocate the tensors into the backend buffer\n        {\n            ggml_tallocr alloc = ggml_tallocr_new(model.buffer_kv);\n\n            // this updates the pointers in the tensors to point to the correct location in the buffer\n            // this is necessary since the ggml_context is .no_alloc == true\n            // note that the buffer can actually be a device buffer, depending on the backend\n            ggml_tallocr_alloc(&alloc, model.memory_k);\n            ggml_tallocr_alloc(&alloc, model.memory_v);\n        }\n    }\n\n    // load weights\n    {\n        size_t total_size = 0;\n\n        bool has_lm_head = false;\n\n        std::vector<char> read_buf;\n\n        while (true) {\n            int32_t n_dims;\n            int32_t length;\n            int32_t ttype;\n\n            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));\n            fin.read(reinterpret_cast<char *>(&length), sizeof(length));\n            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));\n\n            if (fin.eof()) {\n                break;\n            }\n\n            int32_t nelements = 1;\n            int32_t ne[2] = { 1, 1 
};\n            for (int i = 0; i < n_dims; ++i) {\n                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));\n                nelements *= ne[i];\n            }\n\n            std::string name(length, 0);\n            fin.read(&name[0], length);\n\n            if (model.tensors.find(name) == model.tensors.end()) {\n                fprintf(stderr, \"%s: unknown tensor '%s' in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            auto tensor = model.tensors[name];\n            ggml_set_name(tensor, name.c_str());\n            if (ggml_nelements(tensor) != nelements) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\\n\",\n                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);\n                return false;\n            }\n\n            // for debugging\n            if (0) {\n                printf(\"%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\\n\", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));\n            }\n\n            const size_t bpe = ggml_type_size(ggml_type(ttype));\n\n            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\\n\",\n                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);\n                return false;\n            }\n\n            // allocate the tensor\n            ggml_backend_t backend = tensor_backends[name];\n            ggml_tallocr * alloc = &backend_buffers.find(backend)->second;\n            
ggml_tallocr_alloc(alloc, tensor);\n            //printf(\"%s: [%5.5s] %s\\n\", __func__, ggml_backend_name(backend), name.c_str());\n\n            if (ggml_backend_is_cpu(backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(backend)\n#endif\n                ) {\n                // for the CPU and Metal backend, we can read directly into the tensor\n                fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));\n            } else {\n                // read into a temporary buffer first, then copy to device memory\n                read_buf.resize(ggml_nbytes(tensor));\n                fin.read(read_buf.data(), ggml_nbytes(tensor));\n                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));\n            }\n\n            // GPT-2 models share the WTE tensor as the LM head\n            if (name == \"model/wte\" && has_lm_head == false) {\n                ggml_tallocr * alloc_head = &backend_buffers.find(tensor_backends[\"model/lm_head\"])->second;\n                ggml_tallocr_alloc(alloc_head, model.lm_head);\n                //printf(\"%s: [%5.5s] %s (copied)\\n\", __func__, ggml_backend_name(tensor_backends[\"model/lm_head\"]), \"model/lm_head\");\n                ggml_backend_tensor_copy(tensor, model.lm_head);\n                total_size += ggml_nbytes(model.lm_head);\n            }\n\n            if (name == \"model/lm_head\") {\n                has_lm_head = true;\n            }\n\n            total_size += ggml_nbytes(tensor);\n        }\n        printf(\"%s: model size  = %8.2f MB\\n\", __func__, total_size/1024.0/1024.0);\n    }\n\n    fin.close();\n\n    // allocate input tensors\n    {\n        model.embd = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, model.hparams.n_ctx);\n        model.position = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, model.hparams.n_ctx);\n\n        ggml_set_name(model.embd, \"in/embd\");\n        ggml_set_name(model.position, \"in/position\");\n\n        // add 
input tensors to cpu backend\n        size_t input_size = ggml_nbytes(model.embd) + ggml_nbytes(model.position);\n\n        // FIXME: use cpu backend after sched impl\n        ggml_backend_t backend_input = params.n_gpu_layers >= model.hparams.n_layer ? backend_gpu : backend_cpu;\n        model.buffer_input = ggml_backend_alloc_buffer(backend_input, input_size + 512*3);\n        printf(\"%s: backend_in = %s (%zu bytes)\\n\", __func__, ggml_backend_name(backend_input), input_size);\n\n        // allocate the tensors into the backend buffer\n        ggml_tallocr alloc = ggml_tallocr_new(model.buffer_input);\n        ggml_tallocr_alloc(&alloc, model.embd);\n        ggml_tallocr_alloc(&alloc, model.position);\n    }\n\n    return true;\n}\n\n// build the computation graph\nstruct ggml_cgraph * gpt2_graph(\n        const gpt2_model & model,\n        const int n_past,\n        const std::vector<gpt_vocab::id> & embd_inp) {\n    const int N = embd_inp.size();\n\n    const auto & hparams = model.hparams;\n\n    const int n_embd  = hparams.n_embd;\n    const int n_layer = hparams.n_layer;\n    const int n_ctx   = hparams.n_ctx;\n    const int n_head  = hparams.n_head;\n\n    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data\n    static size_t buf_size = ggml_tensor_overhead()*GPT2_MAX_NODES + ggml_graph_overhead_custom(GPT2_MAX_NODES, false);\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    struct ggml_context * ctx = ggml_init(params);\n\n    struct ggml_cgraph  * gf = ggml_new_graph_custom(ctx, GPT2_MAX_NODES, false);\n\n    struct ggml_tensor * embd = ggml_view_1d(ctx, model.embd, N, 0);\n\n    // set inputs\n    // TODO: move to gpt2_eval\n    
ggml_backend_tensor_set(model.embd, embd_inp.data(), 0, N*ggml_element_size(embd));\n\n    struct ggml_tensor * position = ggml_view_1d(ctx, model.position, N, 0);\n    for (int i = 0; i < N; ++i) {\n        int32_t v = n_past + i;\n        ggml_backend_tensor_set(model.position, &v, i*sizeof(int32_t), sizeof(v));\n    }\n\n    const float KQ_scale = 1.0f/sqrtf(float(model.hparams.n_embd)/model.hparams.n_head);\n\n    // wte + wpe\n    struct ggml_tensor * inpL =\n        ggml_add(ctx,\n                ggml_get_rows(ctx, model.wte, embd),\n                ggml_get_rows(ctx, model.wpe, position));\n    ggml_set_name(inpL, \"inpL\");\n    ggml_set_name(inpL->src[0], \"wte\");\n    ggml_set_name(inpL->src[1], \"wpe\");\n\n    for (int il = 0; il < n_layer; ++il) {\n        struct ggml_tensor * cur;\n\n        // norm\n        {\n            // [ 768, N]\n            cur = ggml_norm(ctx, inpL, hparams.eps);\n            ggml_format_name(cur, \"l%d.norm\", il);\n\n            // cur = ln_1_g*cur + ln_1_b\n            // [ 768, N]\n            cur = ggml_add(ctx,\n                    ggml_mul(ctx,\n                        cur,\n                        model.layers[il].ln_1_g),\n                    model.layers[il].ln_1_b);\n            ggml_format_name(cur, \"l%d.ln_1_b\", il);\n            ggml_format_name(cur->src[0], \"l%d.ln_1_g\", il);\n        }\n\n        // attn\n        // [2304, 768] - model.layers[il].c_attn_attn_w\n        // [2304,   1] - model.layers[il].c_attn_attn_b\n        // [ 768,   N] - cur (in)\n        // [2304,   N] - cur (out)\n        //\n        // cur = attn_w*cur + attn_b\n        // [2304, N]\n        {\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_attn_attn_w,\n                    cur);\n            ggml_format_name(cur, \"l%d.attn_w\", il);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_attn_attn_b);\n            ggml_format_name(cur, \"l%d.attn_b\", 
il);\n        }\n\n        // self-attention\n        {\n            struct ggml_tensor * Qcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);\n            struct ggml_tensor * Kcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);\n            struct ggml_tensor * Vcur = ggml_view_2d(ctx, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);\n\n            ggml_format_name(Qcur, \"l%d.Qcur\", il);\n            ggml_format_name(Kcur, \"l%d.Kcur\", il);\n            ggml_format_name(Vcur, \"l%d.Vcur\", il);\n\n            // store key and value to memory\n            if (N >= 1) {\n                struct ggml_tensor * k = ggml_view_1d(ctx, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));\n                struct ggml_tensor * v = ggml_view_1d(ctx, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));\n\n                ggml_build_forward_expand(gf, ggml_cpy(ctx, Kcur, k));\n                ggml_build_forward_expand(gf, ggml_cpy(ctx, Vcur, v));\n            }\n\n            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)\n            // [64, N, 12]\n            struct ggml_tensor * Q =\n                ggml_permute(ctx,\n                        ggml_cont_3d(ctx, Qcur, n_embd/n_head, n_head, N),\n                        0, 2, 1, 3);\n            ggml_format_name(Q, \"l%d.Q\", il);\n\n            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)\n            // [64, n_past + N, 12]\n            struct ggml_tensor * K =\n                ggml_permute(ctx,\n                        ggml_reshape_3d(ctx,\n                            ggml_view_1d(ctx, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),\n                            n_embd/n_head, n_head, n_past + N),\n                        0, 2, 1, 3);\n            ggml_format_name(K, \"l%d.K\", il);\n\n            
// GG: flash attention\n            //struct ggml_tensor * V =\n            //    ggml_cpy(ctx0,\n            //            ggml_permute(ctx0,\n            //                ggml_reshape_3d(ctx0,\n            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),\n            //                    n_embd/n_head, n_head, n_past + N),\n            //                1, 2, 0, 3),\n            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));\n\n            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);\n\n            // K * Q\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ = ggml_mul_mat(ctx, K, Q);\n            ggml_format_name(KQ, \"l%d.KQ\", il);\n\n            // KQ_scaled = KQ / sqrt(n_embd/n_head)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_scaled = ggml_scale(ctx, KQ, KQ_scale);\n            ggml_format_name(KQ_scaled, \"l%d.KQ_scaled\", il);\n\n            // KQ_masked = mask_past(KQ_scaled)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx, KQ_scaled, n_past);\n            ggml_format_name(KQ_masked, \"l%d.KQ_masked\", il);\n\n            // KQ = soft_max(KQ_masked)\n            // [n_past + N, N, 12]\n            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx, KQ_masked);\n            ggml_format_name(KQ_soft_max, \"l%d.KQ_soft_max\", il);\n\n            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()\n            // [n_past + N, 64, 12]\n            struct ggml_tensor * V_trans =\n                ggml_cont_3d(ctx,\n                        ggml_permute(ctx,\n                            ggml_reshape_3d(ctx,\n                                ggml_view_1d(ctx, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),\n                           
     n_embd/n_head, n_head, n_past + N),\n                            1, 2, 0, 3),\n                        n_past + N, n_embd/n_head, n_head);\n\n            // KQV = transpose(V) * KQ_soft_max\n            // [64, N, 12]\n            struct ggml_tensor * KQV = ggml_mul_mat(ctx, V_trans, KQ_soft_max);\n            ggml_format_name(KQV, \"l%d.KQV\", il);\n\n            // KQV_merged = KQV.permute(0, 2, 1, 3)\n            // [64, 12, N]\n            struct ggml_tensor * KQV_merged = ggml_permute(ctx, KQV, 0, 2, 1, 3);\n            ggml_format_name(KQV_merged, \"l%d.KQV_merged\", il);\n\n            // cur = KQV_merged.contiguous().view(n_embd, N)\n            // [768, N]\n            cur = ggml_cont_2d(ctx, KQV_merged, n_embd, N);\n            ggml_format_name(cur, \"l%d.KQV_merged_contiguous\", il);\n        }\n\n        // projection\n        // [ 768, 768] - model.layers[il].c_attn_proj_w\n        // [ 768,   1] - model.layers[il].c_attn_proj_b\n        // [ 768,   N] - cur (in)\n        // [ 768,   N] - cur (out)\n        //\n        // cur = proj_w*cur + proj_b\n        // [768, N]\n        {\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_attn_proj_w,\n                    cur);\n            ggml_format_name(cur, \"l%d.attn_proj_w\", il);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_attn_proj_b);\n            ggml_format_name(cur, \"l%d.attn_proj_b\", il);\n        }\n\n        // add the input\n        cur = ggml_add(ctx, cur, inpL);\n        ggml_format_name(cur, \"l%d.add\", il);\n\n        struct ggml_tensor * inpFF = cur;\n\n        // feed-forward network\n        {\n            // norm\n            {\n                cur = ggml_norm(ctx, inpFF, hparams.eps);\n                ggml_format_name(cur, \"l%d.FFnorm\", il);\n\n                // cur = ln_2_g*cur + ln_2_b\n                // [ 768, N]\n                cur = ggml_add(ctx,\n                        ggml_mul(ctx,\n  
                          cur,\n                            model.layers[il].ln_2_g),\n                        model.layers[il].ln_2_b);\n                ggml_format_name(cur, \"l%d.ln_2_b\", il);\n                ggml_format_name(cur->src[0], \"l%d.ln_2_g\", il);\n            }\n\n            // fully connected\n            // [3072, 768] - model.layers[il].c_mlp_fc_w\n            // [3072,   1] - model.layers[il].c_mlp_fc_b\n            // [ 768,   N] - cur (in)\n            // [3072,   N] - cur (out)\n            //\n            // cur = fc_w*cur + fc_b\n            // [3072, N]\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_mlp_fc_w,\n                    cur);\n            ggml_format_name(cur, \"l%d.mlp_fc_w\", il);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_mlp_fc_b);\n            ggml_format_name(cur, \"l%d.mlp_fc_b\", il);\n\n            // GELU activation\n            // [3072, N]\n            cur = ggml_gelu(ctx, cur);\n            ggml_format_name(cur, \"l%d.gelu\", il);\n\n            // projection\n            // [ 768, 3072] - model.layers[il].c_mlp_proj_w\n            // [ 768,    1] - model.layers[il].c_mlp_proj_b\n            // [3072,    N] - cur (in)\n            // [ 768,    N] - cur (out)\n            //\n            // cur = proj_w*cur + proj_b\n            // [768, N]\n            cur = ggml_mul_mat(ctx,\n                    model.layers[il].c_mlp_proj_w,\n                    cur);\n            ggml_format_name(cur, \"l%d.mlp_proj_w\", il);\n\n            cur = ggml_add(ctx,\n                    cur,\n                    model.layers[il].c_mlp_proj_b);\n            ggml_format_name(cur, \"l%d.mlp_proj_b\", il);\n        }\n\n        // input for next layer\n        inpL = ggml_add(ctx, cur, inpFF);\n        ggml_format_name(inpL, \"l%d.add2\", il);\n    }\n\n    // norm\n    {\n        // [ 768, N]\n        inpL = ggml_norm(ctx, inpL, hparams.eps);\n       
 ggml_format_name(inpL, \"out_norm\");\n\n        // inpL = ln_f_g*inpL + ln_f_b\n        // [ 768, N]\n        inpL = ggml_add(ctx,\n                ggml_mul(ctx,\n                    inpL,\n                    model.ln_f_g),\n                model.ln_f_b);\n        ggml_format_name(inpL, \"out_ln_f_b\");\n        ggml_format_name(inpL->src[0], \"out_ln_f_g\");\n    }\n\n    // inpL = WTE * inpL\n    // [ 768, 50257] - model.lm_head\n    // [ 768, N]     - inpL\n    inpL = ggml_mul_mat(ctx, model.lm_head, inpL);\n    ggml_format_name(inpL, \"out_lm_head\");\n\n    // logits -> probs\n    //inpL = ggml_soft_max(ctx0, inpL);\n\n    ggml_build_forward_expand(gf, inpL);\n\n    ggml_free(ctx);\n\n    return gf;\n}\n\n// evaluate the transformer\n//\n//   - model:     the model\n//   - sched:     the backend scheduler\n//   - n_past:    the context size so far\n//   - embd_inp:  the embeddings of the tokens in the context\n//   - embd_w:    the predicted logits for the next token\n//\nbool gpt2_eval(\n        const gpt2_model & model,\n        ggml_backend_sched_t sched,\n        const int n_past,\n        const std::vector<gpt_vocab::id> & embd_inp,\n              std::vector<float>         & embd_w) {\n    const int N = embd_inp.size();\n\n    const auto & hparams = model.hparams;\n\n    const int n_vocab = hparams.n_vocab;\n\n    struct ggml_cgraph * gf = gpt2_graph(model, n_past, embd_inp);\n\n    // run the computation\n    ggml_backend_sched_reset(sched);\n    ggml_backend_sched_graph_compute(sched, gf);\n\n    //if (n_past%100 == 0) {\n    //    ggml_graph_print   (&gf);\n    //    ggml_graph_dump_dot(&gf, NULL, \"gpt-2.dot\");\n    //}\n\n    // in this case, the output tensor is the last one in the graph\n    struct ggml_tensor * inpL = ggml_graph_node(gf, -1);\n\n    //embd_w.resize(n_vocab*N);\n    //ggml_backend_tensor_get(inpL, embd_w.data(), 0, sizeof(float)*n_vocab*N);\n\n    // return result just for the last token\n    embd_w.resize(n_vocab);\n    
ggml_backend_tensor_get(inpL, embd_w.data(), (n_vocab*(N-1))*sizeof(float), sizeof(float)*n_vocab);\n\n    return true;\n}\n\nint main(int argc, char ** argv) {\n    ggml_time_init();\n\n    const int64_t t_main_start_us = ggml_time_us();\n\n    gpt_params params;\n    params.model = \"models/gpt-2-117M/ggml-model.bin\";\n\n    if (gpt_params_parse(argc, argv, params) == false) {\n        return 1;\n    }\n\n    if (params.seed < 0) {\n        params.seed = time(NULL);\n    }\n\n    printf(\"%s: seed = %d\\n\", __func__, params.seed);\n\n    std::mt19937 rng(params.seed);\n    if (params.prompt.empty()) {\n        params.prompt = gpt_random_prompt(rng);\n    }\n\n    int64_t t_load_us = 0;\n\n    gpt_vocab vocab;\n    gpt2_model model;\n\n    // load the model\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!gpt2_model_load(params.model, model, vocab, params)) {\n            fprintf(stderr, \"%s: failed to load model from '%s'\\n\", __func__, params.model.c_str());\n            return 1;\n        }\n\n        t_load_us = ggml_time_us() - t_start_us;\n\n        test_gpt_tokenizer(vocab, params.token_test);\n    }\n\n    // create the backend scheduler\n    // the scheduler handles the allocation of the compute buffers and the scheduling of the computation between the different backends\n    ggml_backend_sched_t sched;\n    {\n        // initialize the scheduler\n        sched = ggml_backend_sched_new(model.backends.data(), NULL, model.backends.size(), GPT2_MAX_NODES, false, true);\n\n        // create the worst case graph for memory usage estimation\n        int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);\n        int n_past = model.hparams.n_ctx - n_tokens;\n        struct ggml_cgraph * gf = gpt2_graph(model, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));\n\n        ggml_backend_sched_reserve(sched, gf);\n\n\n        // compute the required memory\n        size_t mem_size = 0;\n        for (size_t i = 0; i < 
model.backends.size(); i++) {\n            size_t size = ggml_backend_sched_get_buffer_size(sched, model.backends[i]);\n            if (size > 0) {\n                mem_size += size;\n                printf(\"%s: %8s compute buffer size = %8.2f MB\\n\", __func__, ggml_backend_name(model.backends[i]), size/1024.0/1024.0);\n                //printf(\"%s: %8s compute buffer size = %zu bytes\\n\", __func__, ggml_backend_name(model.backends[i]), size);\n            }\n        }\n\n        printf(\"%s: total compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0/1024.0);\n    }\n\n    int n_past = 0;\n\n    int64_t t_sample_us  = 0;\n    int64_t t_predict_us = 0;\n\n    std::vector<float> logits;\n\n    // tokenize the prompt\n    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);\n\n    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());\n\n    printf(\"%s: prompt: '%s'\\n\", __func__, params.prompt.c_str());\n    printf(\"%s: number of tokens in prompt = %zu, first 8 tokens: \", __func__, embd_inp.size());\n    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {\n        printf(\"%d \", embd_inp[i]);\n    }\n    printf(\"\\n\\n\");\n\n    // submit the input prompt token-by-token\n    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning\n    std::vector<gpt_vocab::id> embd;\n\n    for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {\n        // predict\n        if (embd.size() > 0) {\n            const int64_t t_start_us = ggml_time_us();\n\n            if (!gpt2_eval(model, sched, n_past, embd, logits)) {\n                printf(\"Failed to predict\\n\");\n                return 1;\n            }\n\n            t_predict_us += ggml_time_us() - t_start_us;\n        }\n\n        n_past += embd.size();\n        embd.clear();\n\n        if (i >= embd_inp.size()) {\n            // sample next token\n            const int   
top_k = params.top_k;\n            const float top_p = params.top_p;\n            const float temp  = params.temp;\n\n            const int n_vocab = model.hparams.n_vocab;\n\n            gpt_vocab::id id = 0;\n\n            {\n                const int64_t t_start_sample_us = ggml_time_us();\n\n                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);\n\n                t_sample_us += ggml_time_us() - t_start_sample_us;\n            }\n\n            // add it to the context\n            embd.push_back(id);\n        } else {\n            // if here, it means we are still processing the input prompt\n            for (size_t k = i; k < embd_inp.size(); k++) {\n                embd.push_back(embd_inp[k]);\n                if (int32_t(embd.size()) >= params.n_batch) {\n                    break;\n                }\n            }\n            i += embd.size() - 1;\n        }\n\n        // display text\n        for (auto id : embd) {\n            printf(\"%s\", vocab.id_to_token[id].c_str());\n        }\n        fflush(stdout);\n\n        // end of text token\n        if (embd.back() == 50256) {\n            break;\n        }\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = ggml_time_us();\n\n        printf(\"\\n\\n\");\n        printf(\"%s:     load time = %8.2f ms\\n\", __func__, t_load_us/1000.0f);\n        printf(\"%s:   sample time = %8.2f ms\\n\", __func__, t_sample_us/1000.0f);\n        printf(\"%s:  predict time = %8.2f ms / %.2f ms per token\\n\", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);\n        printf(\"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - t_main_start_us)/1000.0f);\n    }\n\n    ggml_free(model.ctx_w);\n\n    ggml_backend_sched_free(sched);\n    ggml_backend_buffer_free(model.buffer_kv);\n    for (auto buf : model.buffers_w) {\n        ggml_backend_buffer_free(buf);\n    }\n    for (auto backend : model.backends) {\n        
ggml_backend_free(backend);\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/gpt-2/quantize.cpp",
    "content": "#include \"ggml.h\"\n\n#include \"common.h\"\n#include \"common-ggml.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n#include <regex>\n\n// default hparams (GPT-2 117M)\nstruct gpt2_hparams {\n    int32_t n_vocab = 50257;\n    int32_t n_ctx   = 1024;\n    int32_t n_embd  = 768;\n    int32_t n_head  = 12;\n    int32_t n_layer = 12;\n    int32_t ftype   = 1;\n};\n\n// quantize a model\nbool gpt2_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {\n    gpt_vocab vocab;\n\n    printf(\"%s: loading model from '%s'\\n\", __func__, fname_inp.c_str());\n\n    auto finp = std::ifstream(fname_inp, std::ios::binary);\n    if (!finp) {\n        fprintf(stderr, \"%s: failed to open '%s' for reading\\n\", __func__, fname_inp.c_str());\n        return false;\n    }\n\n    auto fout = std::ofstream(fname_out, std::ios::binary);\n    if (!fout) {\n        fprintf(stderr, \"%s: failed to open '%s' for writing\\n\", __func__, fname_out.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        finp.read((char *) &magic, sizeof(magic));\n        if (magic != GGML_FILE_MAGIC) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, fname_inp.c_str());\n            return false;\n        }\n\n        fout.write((char *) &magic, sizeof(magic));\n    }\n\n    gpt2_hparams hparams;\n\n    // load hparams\n    {\n        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        finp.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        finp.read((char *) &hparams.ftype,   sizeof(hparams.ftype));\n\n        const 
int32_t qntvr_src =    hparams.ftype / GGML_QNT_VERSION_FACTOR;\n        const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;\n\n        printf(\"%s: n_vocab     = %d\\n\", __func__, hparams.n_vocab);\n        printf(\"%s: n_ctx       = %d\\n\", __func__, hparams.n_ctx);\n        printf(\"%s: n_embd      = %d\\n\", __func__, hparams.n_embd);\n        printf(\"%s: n_head      = %d\\n\", __func__, hparams.n_head);\n        printf(\"%s: n_layer     = %d\\n\", __func__, hparams.n_layer);\n        printf(\"%s: ftype (src) = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: qntvr (src) = %d\\n\", __func__, qntvr_src);\n        printf(\"%s: ftype (dst) = %d\\n\", __func__, ftype_dst);\n        printf(\"%s: qntvr (dst) = %d\\n\", __func__, GGML_QNT_VERSION);\n\n        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        fout.write((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        fout.write((char *) &ftype_dst,       sizeof(ftype_dst));\n    }\n\n    // load vocab\n    {\n        int32_t n_vocab = 0;\n        finp.read ((char *) &n_vocab, sizeof(n_vocab));\n        fout.write((char *) &n_vocab, sizeof(n_vocab));\n\n        if (n_vocab != hparams.n_vocab) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad vocab size %d != %d)\\n\",\n                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);\n            return false;\n        }\n\n        std::string word;\n        for (int i = 0; i < n_vocab; i++) {\n            uint32_t len;\n            finp.read ((char *) &len, sizeof(len));\n            fout.write((char *) &len, sizeof(len));\n\n            word.resize(len);\n            finp.read ((char *) word.data(), len);\n            fout.write((char *) word.data(), 
len);\n\n            vocab.token_to_id[word] = i;\n            vocab.id_to_token[i] = word;\n        }\n    }\n\n    // regexes of tensor names to be quantized\n    const std::vector<std::string> to_quant = {\n        \"model/wte\",\n        \"model/lm_head\",\n        \"model/h.*/attn/c_attn/w\",\n        \"model/h.*/attn/c_proj/w\",\n        \"model/h.*/mlp/c_fc/w\",\n        \"model/h.*/mlp/c_proj/w\",\n    };\n\n    if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) {\n        fprintf(stderr, \"%s: failed to quantize model '%s'\\n\", __func__, fname_inp.c_str());\n        return false;\n    }\n\n    finp.close();\n    fout.close();\n\n    return true;\n}\n\n// usage:\n//  ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type\n//\nint main(int argc, char ** argv) {\n    if (argc != 4) {\n        fprintf(stderr, \"usage: %s model-f32.bin model-quant.bin type\\n\", argv[0]);\n        ggml_print_ftypes(stderr);\n        return 1;\n    }\n\n    // needed to initialize f16 tables\n    {\n        struct ggml_init_params params = { 0, NULL, false };\n        struct ggml_context * ctx = ggml_init(params);\n        ggml_free(ctx);\n    }\n\n    const std::string fname_inp = argv[1];\n    const std::string fname_out = argv[2];\n\n    const ggml_ftype ftype = ggml_parse_ftype(argv[3]);\n\n    const int64_t t_main_start_us = ggml_time_us();\n\n    int64_t t_quantize_us = 0;\n\n    // load the model\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!gpt2_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {\n            fprintf(stderr, \"%s: failed to quantize model from '%s'\\n\", __func__, fname_inp.c_str());\n            return 1;\n        }\n\n        t_quantize_us = ggml_time_us() - t_start_us;\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = ggml_time_us();\n\n        printf(\"\\n\");\n        printf(\"%s: quantize time = %8.2f ms\\n\", __func__, 
t_quantize_us/1000.0f);\n        printf(\"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - t_main_start_us)/1000.0f);\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/gpt-j/CMakeLists.txt",
    "content": "#\n# gpt-j\n\nset(TEST_TARGET gpt-j)\nadd_executable(${TEST_TARGET} main.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n\n#\n# gpt-j-quantize\n\nset(TEST_TARGET gpt-j-quantize)\nadd_executable(${TEST_TARGET} quantize.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n"
  },
  {
    "path": "examples/gpt-j/README.md",
    "content": "# gpt-j\n\nLocal GPT-J inference on your computer using C/C++\n\nNo video card required. You just need to have 16 GB of RAM.\n\n## Motivation\n\nThe GPT-J 6B model is an open-source alternative to OpenAI's GPT-3. It's basically a neural network that allows you to\ngenerate coherent, human-like text given a certain context (prompt).\n\nThe GPT-J model is quite big - the compact version of the model uses a 16-bit floating-point representation of the weights\nand is still 12 GB in size. This means that in order to run inference on your computer, you would need to have a video card\nwith at least 12 GB of video RAM. Alternatively, you can try to run the Python implementations on the CPU, but that\nwould probably not be very efficient as they are primarily optimized for running on a GPU (or at least this is my guess -\nI don't have much experience with Python).\n\nI wanted to try and run the model on my MacBook, so I decided to implement the model inference from scratch using my own\ncustom-built tensor library. The tensor library (called [ggml](https://github.com/ggerganov/ggml), written in C) is in\nan early stage of development, but it already allows me to run the GPT-J model.\n\nOn my 32GB MacBook M1 Pro, I achieve an inference speed of about `125 ms/token` or about 6 words per second (1 word\ntypically consists of 1 or 2 tokens).\n\nHere is a sample run with the prompt `int main(int argc, char ** argv) {`:\n\n```bash\n$ time ./bin/gpt-j -p \"int main(int argc, char ** argv) {\"\n\ngptj_model_load: loading model from 'models/gpt-j-6B/ggml-model.bin' - please wait ...\ngptj_model_load: n_vocab = 50400\ngptj_model_load: n_ctx   = 2048\ngptj_model_load: n_embd  = 4096\ngptj_model_load: n_head  = 16\ngptj_model_load: n_layer = 28\ngptj_model_load: n_rot   = 64\ngptj_model_load: f16     = 1\ngptj_model_load: ggml ctx size = 13334.86 MB\ngptj_model_load: memory_size =  1792.00 MB, n_mem = 57344\ngptj_model_load: ................................... 
done\ngptj_model_load: model size = 11542.79 MB / num tensors = 285\nmain: number of tokens in prompt = 13\n\nint main(int argc, char ** argv) {\n    (void)argc;\n    (void)argv;\n\n    {\n        struct sockaddr_in addr;\n        int addrlen;\n        char * ip = \"192.168.1.4\";\n        int i;\n\n        if ( (addrlen = sizeof(addr)) == -1 )\n            return -1;\n\n        for (i = 0; i < 10; ++i) {\n            addr.sin_family = AF_INET;\n            addr.sin_addr.s_addr = inet_addr(ip);\n\nmain: mem per token = 16430420 bytes\nmain:     load time =  6211.48 ms\nmain:   sample time =    13.74 ms\nmain:  predict time = 26420.34 ms / 124.62 ms per token\nmain:    total time = 33035.37 ms\n\nreal\t0m33.171s\nuser\t3m32.269s\nsys\t0m3.686s\n\n$\n```\n\nIt took ~6.2 seconds to load the model into memory. After that, it took ~26.4 seconds to generate 200 tokens of what\nlooks like the beginning of a networking program in C. Pretty cool!\n\nHere is another run, just for fun:\n\n```bash\ntime ./bin/gpt-j -n 500 -t 8 -p \"Ask HN: Inherited the worst code and tech team I have ever seen. How to fix it?\n\"\n\ngptj_model_load: loading model from 'models/gpt-j-6B/ggml-model.bin' - please wait ...\ngptj_model_load: n_vocab = 50400\ngptj_model_load: n_ctx   = 2048\ngptj_model_load: n_embd  = 4096\ngptj_model_load: n_head  = 16\ngptj_model_load: n_layer = 28\ngptj_model_load: n_rot   = 64\ngptj_model_load: f16     = 1\ngptj_model_load: ggml ctx size = 13334.86 MB\ngptj_model_load: memory_size =  1792.00 MB, n_mem = 57344\ngptj_model_load: ................................... done\ngptj_model_load: model size = 11542.79 MB / num tensors = 285\nmain: number of tokens in prompt = 24\n\nAsk HN: Inherited the worst code and tech team I have ever seen. 
How to fix it?\n\nI've inherited a team with some very strange and un-documented practices, one of them is that they use an old custom\napplication with a very slow tech stack written in Python that the team doesn't want to touch but also doesn't want to\nthrow away as it has some \"legacy\" code in it.\n\nThe problem is, the tech stack is very very slow.\n\nThey have a single web server on a VM that is slow.\nThe server is a little bit busy (not very busy though) and they have a lot of processes (30+ that are constantly being\nspawned by the application)\nThey have an application that is single threaded and was written in Python and the team don't want to touch this, and\nthe application is very slow.\n\nMy task as a new member of the team is to fix this.\n\nI'm a senior dev on the team (3 years on the project) and have been told that I will take the lead on this task. I know\nnext to nothing about Python. So here is what I have so far.\n\nWhat I have done is I've been trying to debug the processes with the \"ps\" command. This way I can see what is running\nand where. From what I see, the application spawns 10 processes a minute and some of them are used for nothing.\n\nI have also started to look for the code. The application source is not in GitHub or any other repository, it is only on\nour internal GitLab.\n\nWhat I've found so far:\n\nThe application uses a custom SQLAlchemy implementation to interact with the data. I've looked at the source, it looks\nlike an object cache or something like that. But from what I've seen, the cache gets full every 20 minutes and then gets\ncleared with a special command.\n\nAnother strange thing is that the application creates a file for every entry in the database (even if the entry already\nexists). I've looked at the file to see if it contains something, but it seems to be a JSON file with lots of records.\n\nThe other strange thing is that I can only find the database tables in the GitLab repository and not the code. 
So I\ncan't really understand how the application is supposed to interact with the database.\n\nI also found a \"log\" directory, but the code is encrypted with AES. From what I've found, it is in\n\nmain: mem per token = 16430420 bytes\nmain:     load time =  3900.10 ms\nmain:   sample time =    32.58 ms\nmain:  predict time = 68049.91 ms / 130.11 ms per token\nmain:    total time = 73020.05 ms\n\nreal\t1m13.156s\nuser\t9m1.328s\nsys\t0m7.103s\n```\n\n## Implementation details\n\nThe high-level implementation of the model is contained in the [main.cpp](main.cpp) file. The core computations are\nperformed by the [ggml](https://github.com/ggerganov/ggml/blob/master/include/ggml.h) library.\n\n\n### Matrix multiplication\n\nThe most performance-critical part of the implementation is of course the matrix multiplication routine. 99% of the time\nis spent here, so it was important to optimize this as much as possible.\n\nOn Arm64, I utilize the 128-bit NEON intrinsics for 16-bit floating point operations:\n\nhttps://github.com/ggerganov/ggml/blob/fb558f78d905f85c54813602649ddd628ffe0f3a/src/ggml.c#L187-L243\n\nThese instructions allow each core to operate simultaneously on 64 16-bit floats. I'm no expert in SIMD, but after quite\nsome trials this was the most efficient code for the dot product of a row and a column that I could come up with. Combined\nwith the parallel computation on 8 CPU threads, I believe I'm close to the maximum performance that one could possibly\nget on the M1 CPU. Still, I'm curious to know if there is a more efficient way to implement this.\n\n\n### Attempt to use the M1 GPU\n\nOne interesting property of the GPT-J transformer architecture is that it allows you to perform part of the inference in\nparallel - i.e. 
the feed-forward network can be computed in parallel with the self-attention layer:\n\nhttps://github.com/ggerganov/ggml/blob/fb558f78d905f85c54813602649ddd628ffe0f3a/examples/gpt-j/main.cpp#L507-L531\n\nSo I thought why not try and bring in the M1 GPU to compute half of the neural network in parallel with the CPU and\npotentially gain some extra performance. Thanks to the M1's shared memory model, it was relatively easy to offload part\nof the computation to the GPU using Apple's [Metal Performance\nShaders](https://developer.apple.com/documentation/metalperformanceshaders). The GPU shares the host memory, so there is\nno need to copy the data back and forth as you would normally do with CUDA or OpenCL. The weight matrices are directly\navailable to be used by the GPU.\n\nHowever, to my surprise, using MPS together with the CPU did not lead to any performance improvement at all. My\nconclusion was that the 8-thread NEON CPU computation is already saturating the memory bandwidth of the M1, and since\nthe CPU and the GPU on the MacBook are sharing that bandwidth, it does not help to offload the computation to the GPU.\nAnother observation was that the MPS GPU matrix multiplication using 16-bit floats had the same performance as the\n8-thread NEON CPU implementation. Again, I explain this with a saturated memory channel. But of course, my explanation\ncould be totally wrong and somehow the implementation wasn't utilizing the resources correctly.\n\nIn the end, I decided not to use MPS or the GPU altogether.\n\n### Zero memory allocations\n\nAnother property of my implementation is that it does not perform any memory allocations once the model is loaded into\nmemory. 
All required memory is allocated at the start of the program with a single `malloc` (technically 2 calls, but\nthat is not important).\n\n## Usage\n\nIf you want to give this a try and you are on Linux or macOS, simply follow these instructions:\n\n```bash\n# Download the ggml-compatible GPT-J 6B model (requires 12GB disk space)\n../examples/gpt-j/download-ggml-model.sh 6B\n\n# Run the inference (requires 16GB of CPU RAM)\n./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin -p \"This is an example\"\n\n# Input prompt through pipe and run the inference.\necho \"This is an example\" > prompt.txt\ncat prompt.txt | ./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin\n```\n\nTo run the `gpt-j` tool, you need the 12GB `ggml-model.bin` file which contains the GPT-J model in\n[ggml](https://github.com/ggerganov/ggml)-compatible format. In the instructions above, the binary file\nis downloaded from my repository on Hugging Face using the [download-ggml-model.sh](download-ggml-model.sh) script.\nYou can also download the file manually from this link:\n\nhttps://huggingface.co/ggerganov/ggml/tree/main\n\n---\n\nAlternatively, if you don't want to download the 12GB ggml model file, you can perform the conversion yourself using\nPython.\n\nFirst, you need to download the full GPT-J model from here: https://huggingface.co/EleutherAI/gpt-j-6B\n\nNote that the full model is quite big - about 72 GB. After you download it, you need to convert it to ggml format using\nthe [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script. This will generate the `ggml-model.bin` file, which you can\nthen use with the `gpt-j` program.\n\n\n## GPT-2\n\nI also implemented a tool for CPU inference using the smaller GPT-2 models. 
They have worse quality compared to GPT-J,\nbut are much faster to execute.\n\nFor example, the Small GPT-2 model is only 240 MB in size and the inference speed on my MacBook is about 200 tokens/sec.\n\nFor more details, check out the GPT-2 example here: [gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)\n"
  },
  {
    "path": "examples/gpt-j/convert-h5-to-ggml.py",
    "content": "# Convert GPT-J-6B h5 transformer model to ggml format\n#\n# Load the model using GPTJForCausalLM.\n# Iterate over all variables and write them to a binary file.\n#\n# For each variable, write the following:\n#   - Number of dimensions (int)\n#   - Name length (int)\n#   - Data type (int: 0 = float32, 1 = float16)\n#   - Dimensions (int[n_dims], in reverse order - last axis first)\n#   - Name (char[name_length])\n#   - Data (float32 or float16 values, per the data type field)\n#\n# With ftype = 1, 2-D \".weight\" matrices are converted to 16-bit floats and all other tensors to 32-bit floats.\n# With ftype = 0 (passed as the second CLI argument), every tensor is stored as 32-bit floats.\n#\n# At the start of the ggml file we write the model parameters\n# and vocabulary.\n#\n\nimport sys\nimport struct\nimport json\nimport torch\nimport numpy as np\n\nfrom transformers import GPTJForCausalLM\n\n# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a significant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8+n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\nif len(sys.argv) < 3:\n    print(\"Usage: convert-h5-to-ggml.py dir-model [use-f32]\\n\")\n    print(\"  ftype == 0 -> float32\")\n    print(\"  ftype == 1 -> float16\")\n    sys.exit(1)\n\n# output in the same directory as the model\ndir_model = 
sys.argv[1]\nfname_out = sys.argv[1] + \"/ggml-model.bin\"\n\nwith open(dir_model + \"/vocab.json\", \"r\", encoding=\"utf-8\") as f:\n    encoder = json.load(f)\n\nwith open(dir_model + \"/added_tokens.json\", \"r\", encoding=\"utf-8\") as f:\n    encoder_added = json.load(f)\n\nwith open(dir_model + \"/config.json\", \"r\", encoding=\"utf-8\") as f:\n    hparams = json.load(f)\n\n# possible data types\n#   ftype == 0 -> float32\n#   ftype == 1 -> float16\n#\n# map from ftype to string\nftype_str = [\"f32\", \"f16\"]\n\nftype = 1\nif len(sys.argv) > 2:\n    ftype = int(sys.argv[2])\n    if ftype < 0 or ftype > 1:\n        print(\"Invalid ftype: \" + str(ftype))\n        sys.exit(1)\n    fname_out = sys.argv[1] + \"/ggml-model-\" + ftype_str[ftype] + \".bin\"\n\n\nmodel = GPTJForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True)\n#print (model)\n\nlist_vars = model.state_dict()\n#print (list_vars)\n\nfout = open(fname_out, \"wb\")\n\nfout.write(struct.pack(\"i\", 0x67676d6c)) # magic: ggml in hex\nfout.write(struct.pack(\"i\", hparams[\"vocab_size\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_positions\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_embd\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_head\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_layer\"]))\nfout.write(struct.pack(\"i\", hparams[\"rotary_dim\"]))\nfout.write(struct.pack(\"i\", ftype))\n\nbyte_encoder = bytes_to_unicode()\nbyte_decoder = {v:k for k, v in byte_encoder.items()}\n\nfout.write(struct.pack(\"i\", len(encoder) + len(encoder_added)))\n\nfor key in encoder:\n    text = bytearray([byte_decoder[c] for c in key])\n    fout.write(struct.pack(\"i\", len(text)))\n    fout.write(text)\n\nfor key in encoder_added:\n    text = bytearray([byte_decoder[c] for c in key])\n    fout.write(struct.pack(\"i\", len(text)))\n    fout.write(text)\n\nfor name in list_vars.keys():\n    data = list_vars[name].squeeze().numpy()\n    print(\"Processing variable: \" + name + \" with shape: \", 
data.shape)\n\n    # we don't need these\n    if name.endswith(\"attn.masked_bias\") or name.endswith(\".attn.bias\"):\n        print(\"  Skipping variable: \" + name)\n        continue\n\n    n_dims = len(data.shape);\n\n    # ftype == 0 -> float32, ftype == 1 -> float16\n    ftype_cur = 0;\n    if ftype != 0:\n        if name[-7:] == \".weight\" and n_dims == 2:\n            print(\"  Converting to float16\")\n            data = data.astype(np.float16)\n            ftype_cur = 1\n        else:\n            print(\"  Converting to float32\")\n            data = data.astype(np.float32)\n            ftype_cur = 0\n    else:\n        if data.dtype != np.float32:\n            print(\"  Converting to float32\")\n            data = data.astype(np.float32)\n            ftype_cur = 0\n\n    # for efficiency - transpose these matrices:\n    # (note - with latest ggml this is no longer more efficient, so disabling it)\n    #  \"transformer.h.*.mlp.fc_in.weight\"\n    #  \"transformer.h.*.attn.out_proj.weight\"\n    #  \"transformer.h.*.attn.q_proj.weight\"\n    #  \"transformer.h.*.attn.k_proj.weight\"\n    #  \"transformer.h.*.attn.v_proj.weight\"\n    #if name.endswith(\".mlp.fc_in.weight\")     or \\\n    #   name.endswith(\".attn.out_proj.weight\") or \\\n    #   name.endswith(\".attn.q_proj.weight\")   or \\\n    #   name.endswith(\".attn.k_proj.weight\")   or \\\n    #   name.endswith(\".attn.v_proj.weight\"):\n    #    print(\"  Transposing\")\n    #    data = data.transpose()\n\n    # header\n    str = name.encode('utf-8')\n    fout.write(struct.pack(\"iii\", n_dims, len(str), ftype_cur))\n    for i in range(n_dims):\n        fout.write(struct.pack(\"i\", data.shape[n_dims - 1 - i]))\n    fout.write(str);\n\n    # data\n    data.tofile(fout)\n\nfout.close()\n\nprint(\"Done. Output file: \" + fname_out)\nprint(\"\")\n"
  },
  {
    "path": "examples/gpt-j/download-ggml-model.sh",
    "content": "#!/bin/bash\n\n# This script downloads GPT-J model files that have already been converted to ggml format.\n# This way you don't have to convert them yourself.\n#\n# If you want to download the original GPT-J model files, use the \"download-model.sh\" script instead.\n\n#src=\"https://ggml.ggerganov.com\"\n#pfx=\"ggml-model-gpt-j\"\n\nsrc=\"https://huggingface.co/ggerganov/ggml\"\npfx=\"resolve/main/ggml-model-gpt-j\"\n\nggml_path=$(dirname $(realpath $0))\n\n# GPT-J models\nmodels=( \"6B\" )\n\n# list available models\nfunction list_models {\n    printf \"\\n\"\n    printf \"  Available models:\"\n    for model in \"${models[@]}\"; do\n        printf \" $model\"\n    done\n    printf \"\\n\\n\"\n}\n\nif [ \"$#\" -ne 1 ]; then\n    printf \"Usage: $0 <model>\\n\"\n    list_models\n\n    exit 1\nfi\n\nmodel=$1\n\nif [[ ! \" ${models[@]} \" =~ \" ${model} \" ]]; then\n    printf \"Invalid model: $model\\n\"\n    list_models\n\n    exit 1\nfi\n\n# download ggml model\n\nprintf \"Downloading ggml model $model ...\\n\"\n\nmkdir -p models/gpt-j-$model\n\nif [ -x \"$(command -v wget)\" ]; then\n    wget --quiet --show-progress -O models/gpt-j-$model/ggml-model.bin $src/$pfx-$model.bin\nelif [ -x \"$(command -v curl)\" ]; then\n    curl -L --output models/gpt-j-$model/ggml-model.bin $src/$pfx-$model.bin\nelse\n    printf \"Either wget or curl is required to download models.\\n\"\n    exit 1\nfi\n\nif [ $? -ne 0 ]; then\n    printf \"Failed to download ggml model $model \\n\"\n    printf \"Please try again later or download the original GPT-J model files and convert them yourself.\\n\"\n    exit 1\nfi\n\nprintf \"Done! Model '$model' saved in 'models/gpt-j-$model/ggml-model.bin'\\n\"\nprintf \"You can now use it like this:\\n\\n\"\nprintf \"  $ ./bin/gpt-j -m models/gpt-j-$model/ggml-model.bin -p \\\"This is an example\\\"\\n\"\nprintf \"\\n\"\n"
  },
  {
    "path": "examples/gpt-j/download-model.sh",
    "content": "#!/bin/bash\n\nprintf \"To obtain the GPT-J 6B model files, please visit: https://huggingface.co/EleutherAI/gpt-j-6B\\n\\n\"\n\nprintf \"The model is very big. For example, the repository above is 72GB in size.\\n\"\nprintf \"If you are sure that you want to clone it, simply run the following command:\\n\\n\"\n\nprintf \" $ git clone https://huggingface.co/EleutherAI/gpt-j-6B models/gpt-j-6B\\n\\n\"\n\nprintf \"Alternatively, use the 'download-ggml-model.sh' script to download a 12GB ggml version of the model.\\n\"\nprintf \"This version is enough to run inference using the ggml library.\\n\\n\"\n"
  },
  {
    "path": "examples/gpt-j/main.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#include \"common.h\"\n#include \"common-ggml.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\n\n// default hparams (GPT-J 6B)\nstruct gptj_hparams {\n    int32_t n_vocab = 50400;\n    int32_t n_ctx   = 2048;\n    int32_t n_embd  = 4096;\n    int32_t n_head  = 16;\n    int32_t n_layer = 28;\n    int32_t n_rot   = 64;\n    int32_t ftype   = 1;\n    float   eps     = 1e-5f;\n};\n\nstruct gptj_layer {\n    // normalization\n    struct ggml_tensor * ln_1_g;\n    struct ggml_tensor * ln_1_b;\n\n    // attention\n    struct ggml_tensor * c_attn_q_proj_w;\n    struct ggml_tensor * c_attn_k_proj_w;\n    struct ggml_tensor * c_attn_v_proj_w;\n\n    struct ggml_tensor * c_attn_proj_w;\n\n    // ff\n    struct ggml_tensor * c_mlp_fc_w;\n    struct ggml_tensor * c_mlp_fc_b;\n\n    struct ggml_tensor * c_mlp_proj_w;\n    struct ggml_tensor * c_mlp_proj_b;\n};\n\nstruct gptj_model {\n    gptj_hparams hparams;\n\n    // normalization\n    struct ggml_tensor * ln_f_g;\n    struct ggml_tensor * ln_f_b;\n\n    struct ggml_tensor * wte; // token embedding\n\n    struct ggml_tensor * lmh_g; // language model head\n    struct ggml_tensor * lmh_b; // language model bias\n\n    std::vector<gptj_layer> layers;\n\n    // key + value memory\n    struct ggml_tensor * memory_k;\n    struct ggml_tensor * memory_v;\n\n    //\n    struct ggml_context * ctx;\n    std::map<std::string, struct ggml_tensor *> tensors;\n};\n\n// load the model's weights from a file\nbool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & vocab) {\n    printf(\"%s: loading model from '%s' - please wait ...\\n\", __func__, fname.c_str());\n\n    auto fin = std::ifstream(fname, std::ios::binary);\n    if (!fin) {\n        
fprintf(stderr, \"%s: failed to open '%s'\\n\", __func__, fname.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        fin.read((char *) &magic, sizeof(magic));\n        if (magic != GGML_FILE_MAGIC) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, fname.c_str());\n            return false;\n        }\n    }\n\n    // load hparams\n    {\n        auto & hparams = model.hparams;\n\n        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        fin.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));\n        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));\n\n        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;\n\n        printf(\"%s: n_vocab = %d\\n\", __func__, hparams.n_vocab);\n        printf(\"%s: n_ctx   = %d\\n\", __func__, hparams.n_ctx);\n        printf(\"%s: n_embd  = %d\\n\", __func__, hparams.n_embd);\n        printf(\"%s: n_head  = %d\\n\", __func__, hparams.n_head);\n        printf(\"%s: n_layer = %d\\n\", __func__, hparams.n_layer);\n        printf(\"%s: n_rot   = %d\\n\", __func__, hparams.n_rot);\n        printf(\"%s: ftype   = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: qntvr   = %d\\n\", __func__, qntvr);\n\n        hparams.ftype %= GGML_QNT_VERSION_FACTOR;\n    }\n\n    // load vocab\n    {\n        int32_t n_vocab = 0;\n        fin.read((char *) &n_vocab, sizeof(n_vocab));\n\n        if (n_vocab != model.hparams.n_vocab) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad vocab size %d != %d)\\n\",\n                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);\n            return false;\n       
 }\n\n        std::string word;\n        std::vector<char> buf(128);\n\n        for (int i = 0; i < n_vocab; i++) {\n            uint32_t len;\n            fin.read((char *) &len, sizeof(len));\n\n            buf.resize(len);\n            fin.read((char *) buf.data(), len);\n            word.assign(buf.data(), len);\n\n            vocab.token_to_id[word] = i;\n            vocab.id_to_token[i] = word;\n        }\n    }\n\n    // for the big tensors, we have the option to store the data in 16-bit floats or quantized\n    // in order to save memory and also to speed up the computation\n    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));\n    if (wtype == GGML_TYPE_COUNT) {\n        fprintf(stderr, \"%s: invalid model file '%s' (bad ftype value %d)\\n\",\n                __func__, fname.c_str(), model.hparams.ftype);\n        return false;\n    }\n\n    auto & ctx = model.ctx;\n\n    size_t ctx_size = 0;\n\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n        const int n_vocab = hparams.n_vocab;\n\n        ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g\n        ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b\n\n        ctx_size += ggml_row_size(wtype, n_embd*n_vocab); // wte\n\n        ctx_size += ggml_row_size(wtype,         n_embd*n_vocab); // lmh_g\n        ctx_size += ggml_row_size(GGML_TYPE_F32,        n_vocab); // lmh_b\n\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_q_proj_w\n        ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_k_proj_w\n        ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_v_proj_w\n\n        ctx_size += 
n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_fc_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_fc_b\n\n        ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_proj_w\n        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32,   n_embd));        // c_mlp_proj_b\n\n        ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F16, n_embd); // memory_k\n        ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F16, n_embd); // memory_v\n\n        ctx_size += (5 + 10*n_layer)*512; // object overhead\n\n        printf(\"%s: ggml ctx size = %6.2f MB\\n\", __func__, ctx_size/(1024.0*1024.0));\n    }\n\n    // create the ggml context\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ ctx_size,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ false,\n        };\n\n        model.ctx = ggml_init(params);\n        if (!model.ctx) {\n            fprintf(stderr, \"%s: ggml_init() failed\\n\", __func__);\n            return false;\n        }\n    }\n\n    // prepare memory for the weights\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_vocab = hparams.n_vocab;\n\n        model.layers.resize(n_layer);\n\n        model.wte    = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n\n        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);\n\n        model.lmh_g  = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);\n        model.lmh_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);\n\n        // map by name\n        model.tensors[\"transformer.wte.weight\"] = model.wte;\n\n        model.tensors[\"transformer.ln_f.weight\"] = model.ln_f_g;\n      
  model.tensors[\"transformer.ln_f.bias\"]   = model.ln_f_b;\n\n        model.tensors[\"lm_head.weight\"] = model.lmh_g;\n        model.tensors[\"lm_head.bias\"]   = model.lmh_b;\n\n        for (int i = 0; i < n_layer; ++i) {\n            auto & layer = model.layers[i];\n\n            layer.ln_1_g          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n            layer.ln_1_b          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            layer.c_attn_q_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);\n            layer.c_attn_k_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);\n            layer.c_attn_v_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);\n\n            layer.c_attn_proj_w   = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);\n\n            layer.c_mlp_fc_w      = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);\n            layer.c_mlp_fc_b      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);\n\n            layer.c_mlp_proj_w    = ggml_new_tensor_2d(ctx, wtype,         4*n_embd,   n_embd);\n            layer.c_mlp_proj_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);\n\n            // map by name\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".ln_1.weight\"]          = layer.ln_1_g;\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".ln_1.bias\"]            = layer.ln_1_b;\n\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".attn.q_proj.weight\"]   = layer.c_attn_q_proj_w;\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".attn.k_proj.weight\"]   = layer.c_attn_k_proj_w;\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".attn.v_proj.weight\"]   = layer.c_attn_v_proj_w;\n\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".attn.out_proj.weight\"] = layer.c_attn_proj_w;\n\n            model.tensors[\"transformer.h.\" + 
std::to_string(i) + \".mlp.fc_in.weight\"]     = layer.c_mlp_fc_w;\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".mlp.fc_in.bias\"]       = layer.c_mlp_fc_b;\n\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".mlp.fc_out.weight\"]    = layer.c_mlp_proj_w;\n            model.tensors[\"transformer.h.\" + std::to_string(i) + \".mlp.fc_out.bias\"]      = layer.c_mlp_proj_b;\n        }\n    }\n\n    // key + value memory\n    {\n        const auto & hparams = model.hparams;\n\n        const int n_embd  = hparams.n_embd;\n        const int n_layer = hparams.n_layer;\n        const int n_ctx   = hparams.n_ctx;\n\n        const int n_mem      = n_layer*n_ctx;\n        const int n_elements = n_embd*n_mem;\n\n        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);\n        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);\n\n        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);\n\n        printf(\"%s: memory_size = %8.2f MB, n_mem = %d\\n\", __func__, memory_size/1024.0/1024.0, n_mem);\n    }\n\n    // load weights\n    {\n        int n_tensors = 0;\n        size_t total_size = 0;\n\n        printf(\"%s: \", __func__);\n\n        while (true) {\n            int32_t n_dims;\n            int32_t length;\n            int32_t ttype;\n\n            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));\n            fin.read(reinterpret_cast<char *>(&length), sizeof(length));\n            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));\n\n            if (fin.eof()) {\n                break;\n            }\n\n            int32_t nelements = 1;\n            int32_t ne[2] = { 1, 1 };\n            for (int i = 0; i < n_dims; ++i) {\n                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));\n                nelements *= ne[i];\n            }\n\n            std::string name(length, 0);\n            fin.read(&name[0], length);\n\n 
           if (model.tensors.find(name) == model.tensors.end()) {\n                fprintf(stderr, \"%s: unknown tensor '%s' in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            auto tensor = model.tensors[name];\n            if (ggml_nelements(tensor) != nelements) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file\\n\", __func__, name.c_str());\n                return false;\n            }\n\n            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\\n\",\n                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);\n                return false;\n            }\n\n            // for debugging\n            if (0) {\n                printf(\"%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\\n\", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));\n            }\n\n            const size_t bpe = ggml_type_size(ggml_type(ttype));\n\n            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\\n\",\n                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);\n                return false;\n            }\n\n            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));\n\n            //printf(\"%42s - [%5d, %5d], type = %6s, %6.2f MB\\n\", name.c_str(), ne[0], ne[1], ttype == 0 ? 
\"float\" : \"f16\", ggml_nbytes(tensor)/1024.0/1024.0);\n            total_size += ggml_nbytes(tensor);\n            if (++n_tensors % 8 == 0) {\n                printf(\".\");\n                fflush(stdout);\n            }\n        }\n\n        printf(\" done\\n\");\n\n        printf(\"%s: model size = %8.2f MB / num tensors = %d\\n\", __func__, total_size/1024.0/1024.0, n_tensors);\n    }\n\n    fin.close();\n\n    return true;\n}\n\n// evaluate the transformer\n//\n//   - model:     the model\n//   - n_threads: number of threads to use\n//   - n_past:    the context size so far\n//   - embd_inp:  the embeddings of the tokens in the context\n//   - embd_w:    the predicted logits for the next token\n//\n// The GPT-J model requires about 16MB of memory per input token.\n//\nbool gptj_eval(\n        const gptj_model & model,\n        const int n_threads,\n        const int n_past,\n        const std::vector<gpt_vocab::id> & embd_inp,\n              std::vector<float>         & embd_w,\n              size_t                     & mem_per_token) {\n    const int N = embd_inp.size();\n\n    const auto & hparams = model.hparams;\n\n    const int n_embd  = hparams.n_embd;\n    const int n_layer = hparams.n_layer;\n    const int n_ctx   = hparams.n_ctx;\n    const int n_head  = hparams.n_head;\n    const int n_vocab = hparams.n_vocab;\n    const int n_rot   = hparams.n_rot;\n\n    static size_t buf_size = 256u*1024*1024;\n    static void * buf = malloc(buf_size);\n\n    if (mem_per_token > 0 && mem_per_token*N > buf_size) {\n        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead\n        //printf(\"\\n%s: reallocating buffer from %zu to %zu bytes\\n\", __func__, buf_size, buf_size_new);\n\n        // reallocate\n        buf_size = buf_size_new;\n        buf = realloc(buf, buf_size);\n        if (buf == nullptr) {\n            fprintf(stderr, \"%s: failed to allocate %zu bytes\\n\", __func__, buf_size);\n            
return false;\n        }\n    }\n\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf,\n        /*.no_alloc   =*/ false,\n    };\n\n    struct ggml_context * ctx0 = ggml_init(params);\n    struct ggml_cgraph * gf = ggml_new_graph(ctx0);\n\n    // KQ_pos - contains the positions\n    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);\n    int * data = (int *) KQ_pos->data;\n    for (int i = 0; i < N; ++i) {\n        data[i] = n_past + i;\n    }\n\n    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);\n    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));\n\n    // wte\n    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd);\n\n    for (int il = 0; il < n_layer; ++il) {\n        struct ggml_tensor * cur;\n\n        // norm\n        {\n            cur = ggml_norm(ctx0, inpL, hparams.eps);\n\n            // cur = ln_1_g*cur + ln_1_b\n            cur = ggml_add(ctx0,\n                    ggml_mul(ctx0,\n                        ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),\n                        cur),\n                    ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));\n        }\n\n        struct ggml_tensor * inpSA = cur;\n\n        // self-attention\n        {\n            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0);\n            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0);\n\n            // store key and value to memory\n            {\n                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur));\n\n                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, 
(ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));\n                struct ggml_tensor * v = ggml_view_2d(ctx0, model.memory_v, N, n_embd,\n                        (   n_ctx)*ggml_element_size(model.memory_v),\n                        (il*n_ctx)*ggml_element_size(model.memory_v)*n_embd + n_past*ggml_element_size(model.memory_v));\n\n                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));\n                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));\n            }\n\n            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)\n            struct ggml_tensor * Q =\n                ggml_permute(ctx0,\n                        Qcur,\n                        0, 2, 1, 3);\n\n            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)\n            struct ggml_tensor * K =\n                ggml_permute(ctx0,\n                        ggml_reshape_3d(ctx0,\n                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),\n                            n_embd/n_head, n_head, n_past + N),\n                        0, 2, 1, 3);\n\n            // K * Q\n            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);\n\n            // KQ_scaled = KQ / sqrt(n_embd/n_head)\n            struct ggml_tensor * KQ_scaled =\n                ggml_scale_inplace(ctx0,\n                        KQ,\n                        1.0f/sqrt(float(n_embd)/n_head));\n\n            // KQ_masked = mask_past(KQ_scaled)\n            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);\n\n            // KQ = soft_max(KQ_masked)\n            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);\n\n            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()\n            struct ggml_tensor * V =\n                ggml_view_3d(ctx0, model.memory_v,\n           
             n_past + N, n_embd/n_head, n_head,\n                        n_ctx*ggml_element_size(model.memory_v),\n                        n_ctx*ggml_element_size(model.memory_v)*n_embd/n_head,\n                        il*n_ctx*ggml_element_size(model.memory_v)*n_embd);\n\n            // KQV = transpose(V) * KQ_soft_max\n            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);\n\n            // KQV_merged = KQV.permute(0, 2, 1, 3)\n            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);\n\n            // cur = KQV_merged.contiguous().view(n_embd, N)\n            cur = ggml_cpy(ctx0,\n                    KQV_merged,\n                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));\n\n            // projection (no bias)\n            cur = ggml_mul_mat(ctx0,\n                    model.layers[il].c_attn_proj_w,\n                    cur);\n        }\n\n        struct ggml_tensor * inpFF = cur;\n\n        // feed-forward network\n        // this is independent of the self-attention result, so it could be done in parallel to the self-attention\n        {\n            // note here we pass inpSA instead of cur\n            cur = ggml_mul_mat(ctx0,\n                    model.layers[il].c_mlp_fc_w,\n                    inpSA);\n\n            cur = ggml_add(ctx0,\n                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),\n                    cur);\n\n            // GELU activation\n            cur = ggml_gelu(ctx0, cur);\n\n            // projection\n            // cur = proj_w*cur + proj_b\n            cur = ggml_mul_mat(ctx0,\n                    model.layers[il].c_mlp_proj_w,\n                    cur);\n\n            cur = ggml_add(ctx0,\n                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),\n                    cur);\n        }\n\n        // self-attention + FF\n        cur  = ggml_add(ctx0, cur, inpFF);\n\n        // input for next layer\n        inpL = ggml_add(ctx0, cur, inpL);\n    
}\n\n    // norm\n    {\n        inpL = ggml_norm(ctx0, inpL, hparams.eps);\n\n        // inpL = ln_f_g*inpL + ln_f_b\n        inpL = ggml_add(ctx0,\n                ggml_mul(ctx0,\n                    ggml_repeat(ctx0, model.ln_f_g, inpL),\n                    inpL),\n                ggml_repeat(ctx0, model.ln_f_b, inpL));\n    }\n\n    // lm_head\n    {\n        inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL);\n\n        inpL = ggml_add(ctx0,\n                ggml_repeat(ctx0, model.lmh_b, inpL),\n                inpL);\n    }\n\n    // logits -> probs\n    //inpL = ggml_soft_max_inplace(ctx0, inpL);\n\n    // run the computation\n    ggml_build_forward_expand(gf, inpL);\n    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);\n\n    //if (n_past%100 == 0) {\n    //    ggml_graph_print   (&gf);\n    //    ggml_graph_dump_dot(&gf, NULL, \"gpt-j.dot\");\n    //}\n\n    //embd_w.resize(n_vocab*N);\n    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);\n\n    // return result for just the last token\n    embd_w.resize(n_vocab);\n    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);\n\n    if (mem_per_token == 0) {\n        mem_per_token = ggml_used_mem(ctx0)/N;\n    }\n    //printf(\"used_mem = %zu\\n\", ggml_used_mem(ctx0));\n\n    ggml_free(ctx0);\n\n    return true;\n}\n\nint main(int argc, char ** argv) {\n    ggml_time_init();\n\n    const int64_t t_main_start_us = ggml_time_us();\n\n    gpt_params params;\n    params.model = \"models/gpt-j-6B/ggml-model.bin\";\n\n    if (gpt_params_parse(argc, argv, params) == false) {\n        return 1;\n    }\n\n    if (params.seed < 0) {\n        params.seed = time(NULL);\n    }\n\n    printf(\"%s: seed = %d\\n\", __func__, params.seed);\n\n    std::mt19937 rng(params.seed);\n    if (params.prompt.empty()) {\n        params.prompt = gpt_random_prompt(rng);\n    }\n\n    int64_t t_load_us = 0;\n\n    gpt_vocab vocab;\n    gptj_model model;\n\n    // load the model\n 
   {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!gptj_model_load(params.model, model, vocab)) {\n            fprintf(stderr, \"%s: failed to load model from '%s'\\n\", __func__, params.model.c_str());\n            return 1;\n        }\n\n        t_load_us = ggml_time_us() - t_start_us;\n\n        test_gpt_tokenizer(vocab, params.token_test);\n    }\n\n    int n_past = 0;\n\n    int64_t t_sample_us  = 0;\n    int64_t t_predict_us = 0;\n\n    std::vector<float> logits;\n\n    // tokenize the prompt\n    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);\n\n    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());\n\n    printf(\"%s: number of tokens in prompt = %zu\\n\", __func__, embd_inp.size());\n    printf(\"\\n\");\n\n    std::vector<gpt_vocab::id> embd;\n\n    // determine the required inference memory per token:\n    size_t mem_per_token = 0;\n    gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);\n\n    for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {\n        // predict\n        if (embd.size() > 0) {\n            const int64_t t_start_us = ggml_time_us();\n\n            if (!gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {\n                printf(\"Failed to predict\\n\");\n                return 1;\n            }\n\n            t_predict_us += ggml_time_us() - t_start_us;\n        }\n\n        n_past += embd.size();\n        embd.clear();\n\n        if (i >= embd_inp.size()) {\n            // sample next token\n            const int   top_k = params.top_k;\n            const float top_p = params.top_p;\n            const float temp  = params.temp;\n\n            const int n_vocab = model.hparams.n_vocab;\n\n            gpt_vocab::id id = 0;\n\n            {\n                const int64_t t_start_sample_us = ggml_time_us();\n\n                id = gpt_sample_top_k_top_p(vocab, logits.data() + 
(logits.size() - n_vocab), top_k, top_p, temp, rng);\n\n                t_sample_us += ggml_time_us() - t_start_sample_us;\n            }\n\n            // add it to the context\n            embd.push_back(id);\n        } else {\n            // if here, it means we are still processing the input prompt\n            for (size_t k = i; k < embd_inp.size(); k++) {\n                embd.push_back(embd_inp[k]);\n                if (int32_t(embd.size()) > params.n_batch) {\n                    break;\n                }\n            }\n            i += embd.size() - 1;\n        }\n\n        // display text\n        for (auto id : embd) {\n            printf(\"%s\", vocab.id_to_token[id].c_str());\n        }\n        fflush(stdout);\n\n        // end of text token\n        if (embd.back() == 50256) {\n            break;\n        }\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = ggml_time_us();\n\n        printf(\"\\n\\n\");\n        printf(\"%s: mem per token = %8zu bytes\\n\", __func__, mem_per_token);\n        printf(\"%s:     load time = %8.2f ms\\n\", __func__, t_load_us/1000.0f);\n        printf(\"%s:   sample time = %8.2f ms\\n\", __func__, t_sample_us/1000.0f);\n        printf(\"%s:  predict time = %8.2f ms / %.2f ms per token\\n\", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);\n        printf(\"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - t_main_start_us)/1000.0f);\n    }\n\n    ggml_free(model.ctx);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/gpt-j/quantize.cpp",
    "content": "#include \"ggml.h\"\n\n#include \"common.h\"\n#include \"common-ggml.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n#include <regex>\n\n// default hparams (GPT-J 6B)\nstruct gptj_hparams {\n    int32_t n_vocab = 50400;\n    int32_t n_ctx   = 2048;\n    int32_t n_embd  = 4096;\n    int32_t n_head  = 16;\n    int32_t n_layer = 28;\n    int32_t n_rot   = 64;\n    int32_t ftype   = 1;\n};\n\n// quantize a model\nbool gptj_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {\n    gpt_vocab vocab;\n\n    printf(\"%s: loading model from '%s'\\n\", __func__, fname_inp.c_str());\n\n    auto finp = std::ifstream(fname_inp, std::ios::binary);\n    if (!finp) {\n        fprintf(stderr, \"%s: failed to open '%s' for reading\\n\", __func__, fname_inp.c_str());\n        return false;\n    }\n\n    auto fout = std::ofstream(fname_out, std::ios::binary);\n    if (!fout) {\n        fprintf(stderr, \"%s: failed to open '%s' for writing\\n\", __func__, fname_out.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        finp.read((char *) &magic, sizeof(magic));\n        if (magic != GGML_FILE_MAGIC) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, fname_inp.c_str());\n            return false;\n        }\n\n        fout.write((char *) &magic, sizeof(magic));\n    }\n\n    gptj_hparams hparams;\n\n    // load hparams\n    {\n        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        finp.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        finp.read((char *) &hparams.n_rot,   
sizeof(hparams.n_rot));\n        finp.read((char *) &hparams.ftype,   sizeof(hparams.ftype));\n\n        const int32_t qntvr_src =    hparams.ftype / GGML_QNT_VERSION_FACTOR;\n        const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;\n\n        printf(\"%s: n_vocab     = %d\\n\", __func__, hparams.n_vocab);\n        printf(\"%s: n_ctx       = %d\\n\", __func__, hparams.n_ctx);\n        printf(\"%s: n_embd      = %d\\n\", __func__, hparams.n_embd);\n        printf(\"%s: n_head      = %d\\n\", __func__, hparams.n_head);\n        printf(\"%s: n_layer     = %d\\n\", __func__, hparams.n_layer);\n        printf(\"%s: ftype (src) = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: qntvr (src) = %d\\n\", __func__, qntvr_src);\n        printf(\"%s: ftype (dst) = %d\\n\", __func__, ftype_dst);\n        printf(\"%s: qntvr (dst) = %d\\n\", __func__, GGML_QNT_VERSION);\n\n        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));\n        fout.write((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));\n        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));\n        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));\n        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));\n        fout.write((char *) &hparams.n_rot,   sizeof(hparams.n_rot));\n        fout.write((char *) &ftype_dst,       sizeof(ftype_dst));\n    }\n\n    // load vocab\n    {\n        int32_t n_vocab = 0;\n        finp.read ((char *) &n_vocab, sizeof(n_vocab));\n        fout.write((char *) &n_vocab, sizeof(n_vocab));\n\n        if (n_vocab != hparams.n_vocab) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad vocab size %d != %d)\\n\",\n                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);\n            return false;\n        }\n\n        std::string word;\n        for (int i = 0; i < n_vocab; i++) {\n            uint32_t len;\n            finp.read ((char *) &len, sizeof(len));\n          
  fout.write((char *) &len, sizeof(len));\n\n            word.resize(len);\n            finp.read ((char *) word.data(), len);\n            fout.write((char *) word.data(), len);\n\n            vocab.token_to_id[word] = i;\n            vocab.id_to_token[i] = word;\n        }\n    }\n\n    // regexes of tensor names to be quantized\n    const std::vector<std::string> to_quant = {\n        \".*weight\",\n    };\n\n    if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) {\n        fprintf(stderr, \"%s: failed to quantize model '%s'\\n\", __func__, fname_inp.c_str());\n        return false;\n    }\n\n    finp.close();\n    fout.close();\n\n    return true;\n}\n\n// usage:\n//  ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type\n//\nint main(int argc, char ** argv) {\n    if (argc != 4) {\n        fprintf(stderr, \"usage: %s model-f32.bin model-quant.bin type\\n\", argv[0]);\n        ggml_print_ftypes(stderr);\n        return 1;\n    }\n\n    // needed to initialize f16 tables\n    {\n        struct ggml_init_params params = { 0, NULL, false };\n        struct ggml_context * ctx = ggml_init(params);\n        ggml_free(ctx);\n    }\n\n    const std::string fname_inp = argv[1];\n    const std::string fname_out = argv[2];\n\n    const ggml_ftype ftype = ggml_parse_ftype(argv[3]);\n\n    const int64_t t_main_start_us = ggml_time_us();\n\n    int64_t t_quantize_us = 0;\n\n    // load the model\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!gptj_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {\n            fprintf(stderr, \"%s: failed to quantize model from '%s'\\n\", __func__, fname_inp.c_str());\n            return 1;\n        }\n\n        t_quantize_us = ggml_time_us() - t_start_us;\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = ggml_time_us();\n\n        printf(\"\\n\");\n        printf(\"%s: quantize time = %8.2f ms\\n\", __func__, 
t_quantize_us/1000.0f);\n        printf(\"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - t_main_start_us)/1000.0f);\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/magika/CMakeLists.txt",
    "content": "#\n# magika\n\nset(TEST_TARGET magika)\nadd_executable(${TEST_TARGET} main.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)\n\n#\n# For GPU offloading\n\nif (GGML_CUDA)\n    add_compile_definitions(GGML_USE_CUDA)\nendif()\n\nif (GGML_METAL)\n    add_compile_definitions(GGML_USE_METAL)\nendif()\n"
  },
  {
    "path": "examples/magika/README.md",
    "content": "# Google Magika inference\n\nSimple example that shows how to use GGML for inference with the [Google Magika](https://github.com/google/magika) file type detection model.\n\n### Usage\n\n- Obtain the Magika model in H5 format\n  - Pinned version: https://github.com/google/magika/blob/4460acb5d3f86807c3b53223229dee2afa50c025/assets_generation/models/standard_v1/model.h5\n- Use `convert.py` to convert the model to gguf format:\n```bash\n  $ python examples/magika/convert.py /path/to/model.h5\n```\n- Invoke the program with the model file and a list of files to identify:\n```bash\n  $ build/bin/magika model.h5.gguf examples/sam/example.jpg examples/magika/convert.py README.md src/ggml.c /bin/gcc write.exe jfk.wav\n  examples/sam/example.jpg      : jpeg (100.00%) pptx (0.00%) smali (0.00%) shell (0.00%) sevenzip (0.00%)\n  examples/magika/convert.py    : python (99.99%) javascript (0.00%) txt (0.00%) asm (0.00%) scala (0.00%)\n  README.md                     : markdown (100.00%) txt (0.00%) yaml (0.00%) ppt (0.00%) shell (0.00%)\n  src/ggml.c                    : c (99.95%) txt (0.04%) asm (0.01%) yaml (0.00%) html (0.00%)\n  /bin/gcc                      : elf (99.98%) odex (0.02%) pptx (0.00%) smali (0.00%) shell (0.00%)\n  write.exe                     : pebin (100.00%) ppt (0.00%) smali (0.00%) shell (0.00%) sevenzip (0.00%)\n  jfk.wav                       : wav (100.00%) ppt (0.00%) shell (0.00%) sevenzip (0.00%) scala (0.00%)\n```\n"
  },
  {
    "path": "examples/magika/convert.py",
    "content": "import sys\nfrom tensorflow import keras\nimport gguf\n\ndef convert(model_name):\n    model = keras.models.load_model(model_name, compile=False)\n    gguf_model_name = model_name + \".gguf\"\n    gguf_writer = gguf.GGUFWriter(gguf_model_name, \"magika\")\n\n    for layer in model.layers:\n        # export layers with weights\n        if layer.weights:\n            for weight in layer.weights:\n                print(f\"  [{weight.name}] {weight.shape} {weight.dtype}\")\n                weight_data = weight.numpy()\n                gguf_writer.add_tensor(weight.name, weight_data.T)\n\n\n    gguf_writer.write_header_to_file()\n    gguf_writer.write_kv_data_to_file()\n    gguf_writer.write_tensors_to_file()\n    gguf_writer.close()\n    print(\"Model converted and saved to '{}'\".format(gguf_model_name))\n\n\nif __name__ == '__main__':\n    if len(sys.argv) > 1:\n        model_file = sys.argv[1]\n    else:\n        model_file = \"model.h5\"\n\n    convert(model_file)\n"
  },
  {
    "path": "examples/magika/main.cpp",
    "content": "#include \"ggml.h\"\n#include \"gguf.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#include <algorithm>\n#include <cmath>\n#include <numeric>\n#include <stdexcept>\n#include <string>\n#include <vector>\n\nstatic const char * magika_labels[] = {\n    \"ai\",                 \"apk\",                \"appleplist\",         \"asm\",                \"asp\",\n    \"batch\",              \"bmp\",                \"bzip\",               \"c\",                  \"cab\",\n    \"cat\",                \"chm\",                \"coff\",               \"crx\",                \"cs\",\n    \"css\",                \"csv\",                \"deb\",                \"dex\",                \"dmg\",\n    \"doc\",                \"docx\",               \"elf\",                \"emf\",                \"eml\",\n    \"epub\",               \"flac\",               \"gif\",                \"go\",                 \"gzip\",\n    \"hlp\",                \"html\",               \"ico\",                \"ini\",                \"internetshortcut\",\n    \"iso\",                \"jar\",                \"java\",               \"javabytecode\",       \"javascript\",\n    \"jpeg\",               \"json\",               \"latex\",              \"lisp\",               \"lnk\",\n    \"m3u\",                \"macho\",              \"makefile\",           \"markdown\",           \"mht\",\n    \"mp3\",                \"mp4\",                \"mscompress\",         \"msi\",                \"mum\",\n    \"odex\",               \"odp\",                \"ods\",                \"odt\",                \"ogg\",\n    \"outlook\",            \"pcap\",               \"pdf\",                \"pebin\",              \"pem\",\n    \"perl\",               \"php\",                \"png\",                \"postscript\",         \"powershell\",\n    \"ppt\",                \"pptx\",               \"python\",             \"pythonbytecode\",     \"rar\",\n    \"rdf\",       
         \"rpm\",                \"rst\",                \"rtf\",                \"ruby\",\n    \"rust\",               \"scala\",              \"sevenzip\",           \"shell\",              \"smali\",\n    \"sql\",                \"squashfs\",           \"svg\",                \"swf\",                \"symlinktext\",\n    \"tar\",                \"tga\",                \"tiff\",               \"torrent\",            \"ttf\",\n    \"txt\",                \"unknown\",            \"vba\",                \"wav\",                \"webm\",\n    \"webp\",               \"winregistry\",        \"wmf\",                \"xar\",                \"xls\",\n    \"xlsb\",               \"xlsx\",               \"xml\",                \"xpi\",                \"xz\",\n    \"yaml\",               \"zip\",                \"zlibstream\"\n};\n\nstruct magika_hparams {\n    const int block_size = 4096;\n    const int beg_size = 512;\n    const int mid_size = 512;\n    const int end_size = 512;\n    const int min_file_size_for_dl = 16;\n    const int n_label = 113;\n    const float f_norm_eps = 0.001f;\n    const int padding_token = 256;\n};\n\nstruct magika_model {\n    ~magika_model() {\n        ggml_backend_buffer_free(buf_w);\n        ggml_backend_free(backend);\n        ggml_free(ctx_w);\n    }\n\n    magika_hparams hparams;\n\n    struct ggml_tensor * dense_w;\n    struct ggml_tensor * dense_b;\n\n    struct ggml_tensor * layer_norm_gamma;\n    struct ggml_tensor * layer_norm_beta;\n\n    struct ggml_tensor * dense_1_w;\n    struct ggml_tensor * dense_1_b;\n\n    struct ggml_tensor * dense_2_w;\n    struct ggml_tensor * dense_2_b;\n\n    struct ggml_tensor * layer_norm_1_gamma;\n    struct ggml_tensor * layer_norm_1_beta;\n\n    struct ggml_tensor * target_label_w;\n    struct ggml_tensor * target_label_b;\n\n    ggml_backend_t backend = ggml_backend_cpu_init();\n    ggml_backend_buffer_t buf_w = nullptr;\n    struct ggml_context * ctx_w = nullptr;\n};\n\nstruct ggml_tensor * 
checked_get_tensor(struct ggml_context * ctx, const char * name) {\n    struct ggml_tensor * tensor = ggml_get_tensor(ctx, name);\n    if (!tensor) {\n        fprintf(stderr, \"%s: tensor '%s' not found\\n\", __func__, name);\n        throw std::runtime_error(\"ggml_get_tensor() failed\");\n    }\n    return tensor;\n}\n\nbool magika_model_load(const std::string & fname, magika_model & model) {\n    auto & ctx = model.ctx_w;\n\n    struct gguf_init_params params = {\n        /*.no_alloc   =*/ true,\n        /*.ctx        =*/ &ctx,\n    };\n\n    struct gguf_context * ctx_gguf = gguf_init_from_file(fname.c_str(), params);\n    if (!ctx_gguf) {\n        fprintf(stderr, \"%s: gguf_init_from_file() failed\\n\", __func__);\n        return false;\n    }\n\n    model.buf_w = ggml_backend_alloc_ctx_tensors(ctx, model.backend);\n    if (!model.buf_w) {\n        fprintf(stderr, \"%s: ggml_backend_alloc_ctx_tensors() failed\\n\", __func__);\n        gguf_free(ctx_gguf);\n        return false;\n    }\n\n    try {\n        model.dense_w = checked_get_tensor(ctx, \"dense/kernel:0\");\n        model.dense_b = checked_get_tensor(ctx, \"dense/bias:0\");\n\n        model.layer_norm_gamma = checked_get_tensor(ctx, \"layer_normalization/gamma:0\");\n        model.layer_norm_beta  = checked_get_tensor(ctx, \"layer_normalization/beta:0\");\n\n        model.dense_1_w = checked_get_tensor(ctx, \"dense_1/kernel:0\");\n        model.dense_1_b = checked_get_tensor(ctx, \"dense_1/bias:0\");\n\n        model.dense_2_w = checked_get_tensor(ctx, \"dense_2/kernel:0\");\n        model.dense_2_b = checked_get_tensor(ctx, \"dense_2/bias:0\");\n\n        model.layer_norm_1_gamma = checked_get_tensor(ctx, \"layer_normalization_1/gamma:0\");\n        model.layer_norm_1_beta  = checked_get_tensor(ctx, \"layer_normalization_1/beta:0\");\n\n        model.target_label_w = checked_get_tensor(ctx, \"target_label/kernel:0\");\n        model.target_label_b = checked_get_tensor(ctx, \"target_label/bias:0\");\n  
  } catch (const std::exception & e) {\n        fprintf(stderr, \"%s: %s\\n\", __func__, e.what());\n        gguf_free(ctx_gguf);\n        return false;\n    }\n\n    FILE * f = fopen(fname.c_str(), \"rb\");\n    if (!f) {\n        fprintf(stderr, \"%s: fopen() failed\\n\", __func__);\n        gguf_free(ctx_gguf);\n        return false;\n    }\n\n    const int n_tensors = gguf_get_n_tensors(ctx_gguf);\n\n    for (int i = 0; i < n_tensors; i++) {\n        const char * name = gguf_get_tensor_name(ctx_gguf, i);\n        struct ggml_tensor * tensor = ggml_get_tensor(ctx, name);\n        size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);\n\n        //printf(\"%-30s: [%3ld, %3ld, %3ld, %3ld] %s\\n\",\n        //    name,\n        //    tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],\n        //    ggml_type_name(tensor->type));\n\n        std::vector<uint8_t> buf(ggml_nbytes(tensor));\n        if (fseek(f, offs, SEEK_SET) != 0) {\n            fprintf(stderr, \"%s: fseek() failed\\n\", __func__);\n            gguf_free(ctx_gguf);\n            fclose(f);\n            return false;\n        }\n\n        if (fread(buf.data(), 1, buf.size(), f) != buf.size()) {\n            fprintf(stderr, \"%s: fread() failed\\n\", __func__);\n            gguf_free(ctx_gguf);\n            fclose(f);\n            return false;\n        }\n\n        ggml_backend_tensor_set(tensor, buf.data(), 0, buf.size());\n    }\n\n    fclose(f);\n\n    gguf_free(ctx_gguf);\n\n    return true;\n}\n\nstruct ggml_cgraph * magika_graph(\n    const magika_model & model,\n    const int n_files) {\n\n    const auto & hparams = model.hparams;\n\n    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true,\n    };\n\n    struct 
ggml_context * ctx = ggml_init(params);\n\n    struct ggml_cgraph * gf = ggml_new_graph(ctx);\n\n    struct ggml_tensor * input = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 257, 1536, n_files); // one-hot\n    ggml_set_name(input, \"input\");\n    ggml_set_input(input);\n\n    struct ggml_tensor * cur;\n\n    // dense\n    cur = ggml_mul_mat(ctx, model.dense_w, input);\n    cur = ggml_add(ctx, cur, model.dense_b); // [128, 1536, n_files]\n    cur = ggml_gelu(ctx, cur);\n\n    // reshape\n    cur = ggml_reshape_3d(ctx, cur, 512, 384, n_files); // [384, 512, n_files]\n    cur = ggml_cont(ctx, ggml_transpose(ctx, cur));\n\n    // layer normalization\n    cur = ggml_norm(ctx, cur, hparams.f_norm_eps);\n    cur = ggml_mul(ctx, cur, model.layer_norm_gamma); // [384, 512, n_files]\n    cur = ggml_add(ctx, cur, model.layer_norm_beta);  // [384, 512, n_files]\n\n    // dense_1\n    cur = ggml_cont(ctx, ggml_transpose(ctx, cur));\n    cur = ggml_mul_mat(ctx, model.dense_1_w, cur);\n    cur = ggml_add(ctx, cur, model.dense_1_b); // [256, 384, n_files]\n    cur = ggml_gelu(ctx, cur);\n\n    // dense_2\n    cur = ggml_mul_mat(ctx, model.dense_2_w, cur);\n    cur = ggml_add(ctx, cur, model.dense_2_b); // [256, 384, n_files]\n    cur = ggml_gelu(ctx, cur);\n\n    // global_max_pooling1d\n    cur = ggml_cont(ctx, ggml_transpose(ctx, cur)); // [384, 256, n_files]\n    cur = ggml_pool_1d(ctx, cur, GGML_OP_POOL_MAX, 384, 384, 0); // [1, 256, n_files]\n    cur = ggml_reshape_2d(ctx, cur, 256, n_files); // [256, n_files]\n\n    // layer normalization 1\n    cur = ggml_norm(ctx, cur, hparams.f_norm_eps);\n    cur = ggml_mul(ctx, cur, model.layer_norm_1_gamma); // [256, n_files]\n    cur = ggml_add(ctx, cur, model.layer_norm_1_beta);  // [256, n_files]\n\n    // target_label\n    cur = ggml_mul_mat(ctx, model.target_label_w, cur);\n    cur = ggml_add(ctx, cur, model.target_label_b); // [n_label, n_files]\n    cur = ggml_soft_max(ctx, cur); // [n_label, n_files]\n    ggml_set_name(cur, 
\"target_label_probs\");\n    ggml_set_output(cur);\n\n    ggml_build_forward_expand(gf, cur);\n\n    return gf;\n}\n\nbool magika_eval(\n    struct magika_model & model,\n    const std::vector<std::string> & fnames) {\n\n    const auto & hparams = model.hparams;\n\n    static ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n\n    struct ggml_cgraph * gf = magika_graph(model, fnames.size());\n\n    if (!ggml_gallocr_alloc_graph(alloc, gf)) {\n        fprintf(stderr, \"%s: ggml_gallocr_alloc_graph() failed\\n\", __func__);\n        return false;\n    }\n\n    struct ggml_tensor * input = ggml_graph_get_tensor(gf, \"input\");\n\n    for (size_t i = 0; i < fnames.size(); i++) {\n        FILE * f = fopen(fnames[i].c_str(), \"rb\");\n        if (!f) {\n            fprintf(stderr, \"%s: fopen() failed\\n\", __func__);\n            return false;\n        }\n        fseek(f, 0, SEEK_END);\n        long fsize = ftell(f);\n\n        // the buffer is padded with the padding_token if the file is smaller than the block size\n        std::vector<int> buf(1536, hparams.padding_token);\n        std::vector<uint8_t> read_buf(std::max(hparams.beg_size, std::max(hparams.mid_size, hparams.end_size)));\n\n        // read beg\n        fseek(f, 0, SEEK_SET);\n        int n_read = fread(read_buf.data(), 1, hparams.beg_size, f);\n        for (int j = 0; j < n_read; j++) {\n            // pad at the end\n            buf[j] = read_buf[j];\n        }\n\n        // read mid\n        long mid_offs = std::max(0L, (fsize - hparams.mid_size) / 2);\n        fseek(f, mid_offs, SEEK_SET);\n        n_read = fread(read_buf.data(), 1, hparams.mid_size, f);\n        for (int j = 0; j < n_read; j++) {\n            // pad at both ends\n            long mid_idx = hparams.beg_size + (hparams.mid_size / 2) - n_read / 2 + j;\n            buf[mid_idx] = read_buf[j];\n        }\n\n        // read end\n        long end_offs = std::max(0L, fsize - hparams.end_size);\n  
      fseek(f, end_offs, SEEK_SET);\n        n_read = fread(read_buf.data(), 1, hparams.end_size, f);\n        for (int j = 0; j < n_read; j++) {\n            // pad at the beginning\n            int end_idx = hparams.beg_size + hparams.mid_size + hparams.end_size - n_read + j;\n            buf[end_idx] = read_buf[j];\n        }\n\n        fclose(f);\n\n        const size_t inp_bytes = hparams.beg_size + hparams.mid_size + hparams.end_size;\n\n        // convert to one-hot\n        std::vector<float> one_hot(257*inp_bytes);\n        for (size_t j = 0; j < inp_bytes; j++) {\n            one_hot[257*j + buf[j]] = 1.0f;\n        }\n\n        ggml_backend_tensor_set(input, one_hot.data(), 257*inp_bytes*i*sizeof(float), 257*inp_bytes*sizeof(float));\n    }\n\n    if (ggml_backend_graph_compute(model.backend, gf) != GGML_STATUS_SUCCESS) {\n        fprintf(stderr, \"%s: ggml_backend_graph_compute() failed\\n\", __func__);\n        return false;\n    }\n\n    struct ggml_tensor * target_label_probs = ggml_graph_get_tensor(gf, \"target_label_probs\");\n\n    // print probabilities for the top labels of each file\n    for (size_t i = 0; i < fnames.size(); i++) {\n        std::vector<float> probs(hparams.n_label);\n        ggml_backend_tensor_get(target_label_probs, probs.data(), hparams.n_label*i*sizeof(float), hparams.n_label*sizeof(float));\n\n        // sort the probabilities\n        std::vector<int> idx(hparams.n_label);\n        std::iota(idx.begin(), idx.end(), 0);\n        std::sort(idx.begin(), idx.end(), [&probs](int i1, int i2) { return probs[i1] > probs[i2]; });\n\n        // print the top labels\n        const int top_n = 5;\n        printf(\"%-30s: \", fnames[i].c_str());\n        for (int j = 0; j < top_n; j++) {\n            printf(\"%s (%.2f%%) \", magika_labels[idx[j]], probs[idx[j]]*100);\n        }\n        printf(\"\\n\");\n    }\n\n    return true;\n}\n\nint main(int argc, const char ** argv) {\n    if (argc < 3) {\n        fprintf(stderr, \"usage: %s 
<model> <file1> [<file2> ...]\\n\", argv[0]);\n        return 1;\n    }\n\n    const char * model_fname = argv[1];\n    std::vector<std::string> fnames;\n    for (int i = 2; i < argc; i++) {\n        fnames.push_back(argv[i]);\n    }\n\n    magika_model model;\n    if (!magika_model_load(model_fname, model)) {\n        fprintf(stderr, \"magika_model_load() failed\\n\");\n        return 1;\n    }\n\n    magika_eval(model, fnames);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/mnist/.gitignore",
    "content": "data/\n*.gguf\n*.ggml\n"
  },
  {
    "path": "examples/mnist/CMakeLists.txt",
    "content": "#\n# mnist-common\n\nset(TEST_TARGET mnist-common)\nadd_library(${TEST_TARGET} STATIC mnist-common.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common)\n\n#\n# mnist-eval\n\nset(TEST_TARGET mnist-eval)\nadd_executable(${TEST_TARGET} mnist-eval.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common)\n\n#\n# mnist-train\n\nset(TEST_TARGET mnist-train)\nadd_executable(${TEST_TARGET} mnist-train.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common)\n\n\n#\n# mnist-wasm\nif (EMSCRIPTEN)\n    set(TARGET mnist)\n\n    add_executable(${TARGET} mnist-common.cpp)\n    target_link_libraries(${TARGET} PRIVATE ggml ggml-cpu)\n\n    set_target_properties(${TARGET} PROPERTIES LINK_FLAGS \" \\\n        --bind \\\n        -s FORCE_FILESYSTEM=1 \\\n        -s USE_PTHREADS=1 \\\n        -s PTHREAD_POOL_SIZE=10 \\\n        -s ASSERTIONS=1 \\\n        -s WASM=1 \\\n        -s EXPORTED_RUNTIME_METHODS=\\\"['ccall', 'cwrap', 'setValue', 'getValue']\\\" \\\n        -s EXPORTED_FUNCTIONS=\\\"['_wasm_eval','_wasm_random_digit','_malloc','_free']\\\" \\\n        -s ALLOW_MEMORY_GROWTH=1 \\\n        --preload-file ${CMAKE_CURRENT_SOURCE_DIR}/mnist-f32.gguf@/ \\\n        --preload-file ${CMAKE_CURRENT_SOURCE_DIR}/t10k-images-idx3-ubyte@/ \\\n        \")\n\n    # Copy output to web directory\n    add_custom_command(\n        TARGET ${TARGET} POST_BUILD\n        COMMAND ${CMAKE_COMMAND} -E copy\n            ${CMAKE_BINARY_DIR}/bin/mnist.js\n            ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.js\n        COMMAND ${CMAKE_COMMAND} -E copy\n            ${CMAKE_BINARY_DIR}/bin/mnist.wasm\n            ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.wasm\n        COMMAND ${CMAKE_COMMAND} -E copy\n            ${CMAKE_BINARY_DIR}/bin/mnist.worker.js\n            ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.worker.js\n        )\nendif()\n"
  },
  {
    "path": "examples/mnist/README.md",
    "content": "# MNIST Examples for GGML\n\nThis directory contains simple examples of how to use GGML for training and inference using the [MNIST dataset](https://yann.lecun.com/exdb/mnist/).\nAll commands listed in this README assume that the working directory is `examples/mnist`.\nPlease note that training in GGML is a work in progress and not production-ready.\n\n## Obtaining the data\n\nA description of the dataset can be found on [Yann LeCun's website](https://yann.lecun.com/exdb/mnist/).\nWhile it is in principle also possible to download the dataset from this website, these downloads are frequently throttled, and\nit is recommended to use [Hugging Face](https://huggingface.co/datasets/ylecun/mnist) instead.\nThe dataset will be downloaded automatically when running `mnist-train-fc.py`.\n\n## Fully connected network\n\nFor our first example, we will train a fully connected network.\nTo train a fully connected model in PyTorch and save it as a GGUF file, run:\n\n```bash\n$ python3 mnist-train-fc.py mnist-fc-f32.gguf\n\n...\n\nTest loss: 0.066377+-0.010468, Test accuracy: 97.94+-0.14%\n\nModel tensors saved to mnist-fc-f32.gguf:\nfc1.weight       (500, 784)\nfc1.bias         (500,)\nfc2.weight       (10, 500)\nfc2.bias         (10,)\n```\n\nThe training script includes an evaluation of the model on the test set.\nTo evaluate the model on the CPU using GGML, run:\n\n```bash\n$ ../../build/bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte 
data/MNIST/raw/t10k-labels-idx1-ubyte\n\n________________________________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\n__________________________________####__________________\n______________________________########__________________\n__________________________##########____________________\n______________________##############____________________\n____________________######________####__________________\n__________________________________####__________________\n__________________________________####__________________\n________________________________####____________________\n______________________________####______________________\n________________________##########______________________\n______________________########__####____________________\n________________________##__________##__________________\n____________________________________##__________________\n__________________________________##____________________\n__________________________________##____________________\n________________________________##______________________\n____________________________####________________________\n__________##____________######__________________________\n__________##############________________________________\n________________####____________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\nggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no\nggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no\nggml_cuda_init: found 1 CUDA devices:\n  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes\nmnist_model: using CUDA0 (NVIDIA GeForce RTX 3090) as primary backend\nmnist_model: unsupported operations will be executed on the 
following fallback backends (in order of priority):\nmnist_model:  - CPU (AMD Ryzen 9 5950X 16-Core Processor)\nmnist_model_init_from_file: loading model weights from 'mnist-fc-f32.gguf'\nmnist_model_init_from_file: model arch is mnist-fc\nmnist_model_init_from_file: successfully loaded weights from mnist-fc-f32.gguf\nmain: loaded model in 109.44 ms\nmnist_model_eval: model evaluation on 10000 images took 76.92 ms, 7.69 us/image\nmain: predicted digit is 3\nmain: test_loss=0.066379+-0.009101\nmain: test_acc=97.94+-0.14%\n```\n\nIn addition to the evaluation on the test set, the GGML evaluation also prints a random image from the test set as well as the model prediction for said image.\nTo train a fully connected model on the CPU using GGML, run:\n\n```bash\n$ ../../build/bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte\n```\n\nIt can then be evaluated with the same binary as above.\n\n## Convolutional network\n\nTo train a convolutional network using TensorFlow, run:\n\n```bash\n$ python3 mnist-train-cnn.py mnist-cnn-f32.gguf\n\n...\n\nTest loss: 0.047947\nTest accuracy: 98.46%\nGGUF model saved to 'mnist-cnn-f32.gguf'\n```\n\nThe saved model can be evaluated on the CPU using the `mnist-eval` binary:\n\n```bash\n$ ../../build/bin/mnist-eval mnist-cnn-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte 
data/MNIST/raw/t10k-labels-idx1-ubyte\n\n________________________________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\n______________________________________##________________\n______________________________________##________________\n______________________________________##________________\n____________________________________##__________________\n__________________________________####__________________\n__________________________________##____________________\n________________________________##______________________\n______________________________##________________________\n____________________________####________________________\n____________________________##__________________________\n__________________________##____________________________\n________________________##______________________________\n______________________##________________________________\n____________________####________________________________\n____________________##__________________________________\n__________________##____________________________________\n________________##______________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\n________________________________________________________\nggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no\nggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no\nggml_cuda_init: found 1 CUDA devices:\n  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes\nmnist_model: using CUDA0 (NVIDIA GeForce RTX 3090) as primary backend\nmnist_model: unsupported operations will be executed on the 
following fallback backends (in order of priority):\nmnist_model:  - CPU (AMD Ryzen 9 5950X 16-Core Processor)\nmnist_model_init_from_file: loading model weights from 'mnist-cnn-f32.gguf'\nmnist_model_init_from_file: model arch is mnist-cnn\nmnist_model_init_from_file: successfully loaded weights from mnist-cnn-f32.gguf\nmain: loaded model in 91.99 ms\nmnist_model_eval: model evaluation on 10000 images took 267.61 ms, 26.76 us/image\nmain: predicted digit is 1\nmain: test_loss=0.047955+-0.007029\nmain: test_acc=98.46+-0.12%\n```\n\nAs with the fully connected network, the convolutional network can also be trained using GGML:\n\n``` bash\n$ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte\n```\n\nAs always, the evaluation is done using `mnist-eval`, and as with the fully connected network, the GGML graph is exported to `mnist-cnn-f32.ggml`.\n\n## Hardware Acceleration\n\nBoth the training and the evaluation code are hardware-agnostic as long as the corresponding GGML backend implements the necessary operations.\nA specific backend can be selected by appending a backend name (e.g. `CPU` or `CUDA0`) to the above commands.\nThe compute graphs then schedule the operations to preferentially use the specified backend.\nNote that if a backend does not implement some of the necessary operations, those operations are executed on a fallback backend such as the CPU instead, which may result in poor performance.\n\n## Web demo\n\nThe evaluation code can be compiled to WebAssembly using [Emscripten](https://emscripten.org/) (you may need to log in again to update `$PATH` after installation).\nFirst, copy the GGUF file of either of the trained models to `examples/mnist` and name it `mnist-f32.gguf`.\nCopy the test set to `examples/mnist` and name it `t10k-images-idx3-ubyte`.\nSymlinking these files will *not* work!\nCompile the code like so:\n\n```bash\n$ cd ../../\n$ mkdir -p build-em\n$ cd build-em\n$ emcmake cmake .. 
-DGGML_BUILD_EXAMPLES=ON \\\n    -DCMAKE_C_FLAGS=\"-pthread -matomics -mbulk-memory\" \\\n    -DCMAKE_CXX_FLAGS=\"-pthread -matomics -mbulk-memory\"\n$ make mnist\n$ cd ..\n```\n\nThe compilation output is copied into `examples/mnist/web`.\nTo run it, you need an HTTP server.\nFor example, from the repository root:\n\n``` bash\n$ python3 examples/mnist/server.py\n\nServing directory '/home/danbev/work/ai/ggml/examples/mnist/web' at http://localhost:8000\nApplication context root: http://localhost:8000/\n```\n\nThe web demo can then be accessed via the link printed on the console.\nSimply draw a digit on the canvas and the model will try to predict what it's supposed to be.\nAlternatively, click the \"Random\" button to retrieve a random digit from the test set.\nBe aware that, like all neural networks, the one we trained is susceptible to distributional shift:\nif the numbers you draw look different from the ones in the training set\n(e.g. because they're not centered), the model will perform comparatively worse.\nAn online demo can be accessed [here](https://mnist.ggerganov.com).\n"
  },
  {
    "path": "examples/mnist/mnist-common.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-opt.h\"\n\n#include \"mnist-common.h\"\n\n#include <algorithm>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <cstdint>\n#include <fstream>\n#include <random>\n#include <string>\n#include <utility>\n\nbool mnist_image_load(const std::string & fname, ggml_opt_dataset_t dataset) {\n    auto fin = std::ifstream(fname, std::ios::binary);\n    if (!fin) {\n        fprintf(stderr, \"failed to open images file %s\\n\", fname.c_str());\n        return false;\n    }\n    fin.seekg(16);\n\n    uint8_t image[MNIST_NINPUT];\n    struct ggml_tensor * images = ggml_opt_dataset_data(dataset);\n    float * buf = ggml_get_data_f32(images);\n\n    GGML_ASSERT(images->ne[0] == MNIST_NINPUT);\n    for (int64_t iex = 0; iex < images->ne[1]; ++iex) {\n        fin.read((char *) image, sizeof(image));\n\n        for (int64_t i = 0; i < MNIST_NINPUT; ++i) {\n            buf[iex*MNIST_NINPUT + i] = image[i] / 255.0f; // Normalize to [0, 1]\n        }\n    }\n\n    return true;\n}\n\nvoid mnist_image_print(FILE * stream, ggml_opt_dataset_t dataset, const int iex) {\n    struct ggml_tensor * images = ggml_opt_dataset_data(dataset);\n    GGML_ASSERT(images->ne[0] == MNIST_NINPUT);\n    GGML_ASSERT(iex < images->ne[1]);\n    const float * image = ggml_get_data_f32(images) + iex*MNIST_NINPUT;\n\n    for (int64_t row = 0; row < MNIST_HW; row++) {\n        for (int64_t col = 0; col < MNIST_HW; col++) {\n            const int rgb = roundf(255.0f * image[row*MNIST_HW + col]);\n#ifdef _WIN32\n            fprintf(stream, \"%s\", rgb >= 220 ? 
\"##\" : \"__\");                // Represented via text.\n#else\n            fprintf(stream, \"\\033[48;2;%d;%d;%dm  \\033[0m\", rgb, rgb, rgb); // Represented via colored blocks.\n#endif // _WIN32\n        }\n        fprintf(stream, \"\\n\");\n    }\n}\n\nbool mnist_label_load(const std::string & fname, ggml_opt_dataset_t dataset) {\n    auto fin = std::ifstream(fname, std::ios::binary);\n    if (!fin) {\n        fprintf(stderr, \"failed to open labels file %s\\n\", fname.c_str());\n        return 0;\n    }\n    fin.seekg(8);\n\n    uint8_t label;\n    struct ggml_tensor * labels = ggml_opt_dataset_labels(dataset);\n    float * buf = ggml_get_data_f32(labels);\n\n    GGML_ASSERT(labels->ne[0] == MNIST_NCLASSES);\n    for (int64_t iex = 0; iex < labels->ne[1]; ++iex) {\n        fin.read((char *) &label, sizeof(label));\n\n        for (int64_t i = 0; i < MNIST_NCLASSES; ++i) {\n            buf[iex*MNIST_NCLASSES + i] = i == label ? 1.0f : 0.0f;\n        }\n    }\n\n    return true;\n}\n\n// Temporary util function for loading data from GGUF to a backend != CPU until GGML itself provides this functionality:\nbool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf) {\n    FILE * f = ggml_fopen(fname, \"rb\");\n    if (!f) {\n        return false;\n    }\n\n    const size_t buf_size = 4*1024*1024;\n    void * buf = malloc(buf_size);\n\n    const int n_tensors = gguf_get_n_tensors(ctx_gguf);\n    for (int i = 0; i < n_tensors; i++) {\n        const char * name = gguf_get_tensor_name(ctx_gguf, i);\n\n        struct ggml_tensor * tensor = ggml_get_tensor(ctx_ggml, name);\n        if (!tensor) {\n            continue;\n        }\n\n        const size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);\n\n        if (fseek(f, offs, SEEK_SET) != 0) {\n            fclose(f);\n            free(buf);\n            return false;\n        }\n\n        const size_t nbytes = ggml_nbytes(tensor);\n        for 
(size_t pos = 0; pos < nbytes; pos += buf_size) {\n            const size_t nbytes_cpy = buf_size < nbytes - pos ? buf_size : nbytes - pos;\n\n            if (fread(buf, 1, nbytes_cpy, f) != nbytes_cpy) {\n                fclose(f);\n                free(buf);\n                return false;\n            }\n\n            ggml_backend_tensor_set(tensor, buf, pos, nbytes_cpy);\n        }\n    }\n\n    fclose(f);\n    free(buf);\n    return true;\n}\n\nmnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend, const int nbatch_logical, const int nbatch_physical) {\n    mnist_model model(backend, nbatch_logical, nbatch_physical);\n    fprintf(stderr, \"%s: loading model weights from '%s'\\n\", __func__, fname.c_str());\n\n    struct gguf_context * ctx;\n    {\n        struct gguf_init_params params = {\n            /*.no_alloc   =*/ true,\n            /*.ctx        =*/ &model.ctx_gguf,\n        };\n        ctx = gguf_init_from_file(fname.c_str(), params);\n        if (!ctx) {\n            fprintf(stderr, \"%s: gguf_init_from_file() failed\\n\", __func__);\n            exit(1);\n        }\n    }\n    model.arch = gguf_get_val_str(ctx, gguf_find_key(ctx, \"general.architecture\"));\n    fprintf(stderr, \"%s: model arch is %s\\n\", __func__, model.arch.c_str());\n\n    if (model.arch == \"mnist-fc\") {\n        model.fc1_weight = ggml_get_tensor(model.ctx_gguf, \"fc1.weight\");\n        GGML_ASSERT(model.fc1_weight->ne[0] == MNIST_NINPUT);\n        GGML_ASSERT(model.fc1_weight->ne[1] == MNIST_NHIDDEN);\n        GGML_ASSERT(model.fc1_weight->ne[2] == 1);\n        GGML_ASSERT(model.fc1_weight->ne[3] == 1);\n\n        model.fc1_bias = ggml_get_tensor(model.ctx_gguf, \"fc1.bias\");\n        GGML_ASSERT(model.fc1_bias->ne[0] == MNIST_NHIDDEN);\n        GGML_ASSERT(model.fc1_bias->ne[1] == 1);\n        GGML_ASSERT(model.fc1_bias->ne[2] == 1);\n        GGML_ASSERT(model.fc1_bias->ne[3] == 1);\n\n        model.fc2_weight = 
ggml_get_tensor(model.ctx_gguf, \"fc2.weight\");\n        GGML_ASSERT(model.fc2_weight->ne[0] == MNIST_NHIDDEN);\n        GGML_ASSERT(model.fc2_weight->ne[1] == MNIST_NCLASSES);\n        GGML_ASSERT(model.fc2_weight->ne[2] == 1);\n        GGML_ASSERT(model.fc2_weight->ne[3] == 1);\n\n        model.fc2_bias = ggml_get_tensor(model.ctx_gguf, \"fc2.bias\");\n        GGML_ASSERT(model.fc2_bias->ne[0] == MNIST_NCLASSES);\n        GGML_ASSERT(model.fc2_bias->ne[1] == 1);\n        GGML_ASSERT(model.fc2_bias->ne[2] == 1);\n        GGML_ASSERT(model.fc2_bias->ne[3] == 1);\n    } else if (model.arch == \"mnist-cnn\") {\n        model.conv1_kernel = ggml_get_tensor(model.ctx_gguf, \"conv1.kernel\");\n        GGML_ASSERT(model.conv1_kernel->type == GGML_TYPE_F32);\n        GGML_ASSERT(model.conv1_kernel->ne[0] == 3);\n        GGML_ASSERT(model.conv1_kernel->ne[1] == 3);\n        GGML_ASSERT(model.conv1_kernel->ne[2] == 1);\n        GGML_ASSERT(model.conv1_kernel->ne[3] == MNIST_CNN_NCB);\n\n        model.conv1_bias = ggml_get_tensor(model.ctx_gguf, \"conv1.bias\");\n        GGML_ASSERT(model.conv1_bias->type == GGML_TYPE_F32);\n        GGML_ASSERT(model.conv1_bias->ne[0] == 1);\n        GGML_ASSERT(model.conv1_bias->ne[1] == 1);\n        GGML_ASSERT(model.conv1_bias->ne[2] == MNIST_CNN_NCB);\n        GGML_ASSERT(model.conv1_bias->ne[3] == 1);\n\n        model.conv2_kernel = ggml_get_tensor(model.ctx_gguf, \"conv2.kernel\");\n        GGML_ASSERT(model.conv2_kernel->type == GGML_TYPE_F32);\n        GGML_ASSERT(model.conv2_kernel->ne[0] == 3);\n        GGML_ASSERT(model.conv2_kernel->ne[1] == 3);\n        GGML_ASSERT(model.conv2_kernel->ne[2] == MNIST_CNN_NCB);\n        GGML_ASSERT(model.conv2_kernel->ne[3] == MNIST_CNN_NCB*2);\n\n        model.conv2_bias = ggml_get_tensor(model.ctx_gguf, \"conv2.bias\");\n        GGML_ASSERT(model.conv2_bias->type == GGML_TYPE_F32);\n        GGML_ASSERT(model.conv2_bias->ne[0] == 1);\n        GGML_ASSERT(model.conv2_bias->ne[1] == 1);\n        
GGML_ASSERT(model.conv2_bias->ne[2] == MNIST_CNN_NCB*2);\n        GGML_ASSERT(model.conv2_bias->ne[3] == 1);\n\n        model.dense_weight = ggml_get_tensor(model.ctx_gguf, \"dense.weight\");\n        GGML_ASSERT(model.dense_weight->type == GGML_TYPE_F32);\n        GGML_ASSERT(model.dense_weight->ne[0] == (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2));\n        GGML_ASSERT(model.dense_weight->ne[1] == MNIST_NCLASSES);\n        GGML_ASSERT(model.dense_weight->ne[2] == 1);\n        GGML_ASSERT(model.dense_weight->ne[3] == 1);\n\n        model.dense_bias = ggml_get_tensor(model.ctx_gguf, \"dense.bias\");\n        GGML_ASSERT(model.dense_bias->type == GGML_TYPE_F32);\n        GGML_ASSERT(model.dense_bias->ne[0] == MNIST_NCLASSES);\n        GGML_ASSERT(model.dense_bias->ne[1] == 1);\n        GGML_ASSERT(model.dense_bias->ne[2] == 1);\n        GGML_ASSERT(model.dense_bias->ne[3] == 1);\n    } else {\n        fprintf(stderr, \"%s: unknown model arch: %s\\n\", __func__, model.arch.c_str());\n    }\n\n    model.buf_gguf = ggml_backend_alloc_ctx_tensors(model.ctx_gguf, model.backends[0]);\n\n    if(!load_from_gguf(fname.c_str(), model.ctx_gguf, ctx)) {\n        fprintf(stderr, \"%s: loading weights from %s failed\\n\", __func__, fname.c_str());\n        exit(1);\n    }\n\n    // The space in ctx_gguf exactly fits the model weights,\n    // the images (which also need to be statically allocated) need to be put in a different context.\n\n    model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, nbatch_physical);\n\n    ggml_set_name(model.images, \"images\");\n    ggml_set_input(model.images);\n\n    model.buf_static = ggml_backend_alloc_ctx_tensors(model.ctx_static, model.backends[0]);\n\n    fprintf(stderr, \"%s: successfully loaded weights from %s\\n\", __func__, fname.c_str());\n    return model;\n}\n\nmnist_model mnist_model_init_random(const std::string & arch, const std::string & backend, const int nbatch_logical, const int nbatch_physical) {\n  
  mnist_model model(backend, nbatch_logical, nbatch_physical);\n    model.arch = arch;\n\n    std::random_device rd{};\n    std::mt19937 gen{rd()};\n    std::normal_distribution<float> nd{0.0f, 1e-2f};\n    std::vector<ggml_tensor *> init_tensors;\n\n    if (model.arch == \"mnist-fc\") {\n        fprintf(stderr, \"%s: initializing random weights for a fully connected model\\n\", __func__);\n\n        model.fc1_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT,  MNIST_NHIDDEN);\n        model.fc1_bias   = ggml_new_tensor_1d(model.ctx_static, GGML_TYPE_F32,                MNIST_NHIDDEN);\n        model.fc2_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NHIDDEN, MNIST_NCLASSES);\n        model.fc2_bias   = ggml_new_tensor_1d(model.ctx_static, GGML_TYPE_F32,                MNIST_NCLASSES);\n\n        ggml_set_name(model.fc1_weight, \"fc1.weight\");\n        ggml_set_name(model.fc1_bias,   \"fc1.bias\");\n        ggml_set_name(model.fc2_weight, \"fc2.weight\");\n        ggml_set_name(model.fc2_bias,   \"fc2.bias\");\n\n        init_tensors.push_back(model.fc1_weight);\n        init_tensors.push_back(model.fc1_bias);\n        init_tensors.push_back(model.fc2_weight);\n        init_tensors.push_back(model.fc2_bias);\n    } else if (model.arch == \"mnist-cnn\") {\n        model.conv1_kernel = ggml_new_tensor_4d(model.ctx_static, GGML_TYPE_F32, 3, 3, 1, MNIST_CNN_NCB);\n        model.conv1_bias   = ggml_new_tensor_3d(model.ctx_static, GGML_TYPE_F32, 1, 1,    MNIST_CNN_NCB);\n        model.conv2_kernel = ggml_new_tensor_4d(model.ctx_static, GGML_TYPE_F32, 3, 3, MNIST_CNN_NCB, MNIST_CNN_NCB*2);\n        model.conv2_bias   = ggml_new_tensor_3d(model.ctx_static, GGML_TYPE_F32, 1, 1,                MNIST_CNN_NCB*2);\n        model.dense_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), MNIST_NCLASSES);\n        model.dense_bias   = ggml_new_tensor_1d(model.ctx_static, 
GGML_TYPE_F32, MNIST_NCLASSES);\n\n        ggml_set_name(model.conv1_kernel, \"conv1.kernel\");\n        ggml_set_name(model.conv1_bias,   \"conv1.bias\");\n        ggml_set_name(model.conv2_kernel, \"conv2.kernel\");\n        ggml_set_name(model.conv2_bias,   \"conv2.bias\");\n        ggml_set_name(model.dense_weight, \"dense.weight\");\n        ggml_set_name(model.dense_bias,   \"dense.bias\");\n\n        init_tensors.push_back(model.conv1_kernel);\n        init_tensors.push_back(model.conv1_bias);\n        init_tensors.push_back(model.conv2_kernel);\n        init_tensors.push_back(model.conv2_bias);\n        init_tensors.push_back(model.dense_weight);\n        init_tensors.push_back(model.dense_bias);\n    } else {\n        fprintf(stderr, \"%s: unknown model arch: %s\\n\", __func__, model.arch.c_str());\n    }\n\n    model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH_PHYSICAL);\n    ggml_set_name(model.images, \"images\");\n    ggml_set_input(model.images);\n\n    model.buf_static = ggml_backend_alloc_ctx_tensors(model.ctx_static, model.backends[0]);\n\n    for (ggml_tensor * t : init_tensors) {\n        GGML_ASSERT(t->type == GGML_TYPE_F32);\n        const int64_t ne = ggml_nelements(t);\n        std::vector<float> tmp(ne);\n\n        for (int64_t i = 0; i < ne; ++i) {\n            tmp[i] = nd(gen);\n        }\n        ggml_backend_tensor_set(t, tmp.data(), 0, ggml_nbytes(t));\n    }\n\n    return model;\n}\n\nvoid mnist_model_build(mnist_model & model) {\n    if (model.arch == \"mnist-fc\") {\n        ggml_set_param(model.fc1_weight);\n        ggml_set_param(model.fc1_bias);\n        ggml_set_param(model.fc2_weight);\n        ggml_set_param(model.fc2_bias);\n\n        ggml_tensor * fc1 = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,\n            ggml_mul_mat(model.ctx_compute, model.fc1_weight, model.images),\n            model.fc1_bias));\n        model.logits = ggml_add(model.ctx_compute,\n            
ggml_mul_mat(model.ctx_compute, model.fc2_weight, fc1),\n            model.fc2_bias);\n    } else if (model.arch == \"mnist-cnn\") {\n        ggml_set_param(model.conv1_kernel);\n        ggml_set_param(model.conv1_bias);\n        ggml_set_param(model.conv2_kernel);\n        ggml_set_param(model.conv2_bias);\n        ggml_set_param(model.dense_weight);\n        ggml_set_param(model.dense_bias);\n\n        struct ggml_tensor * images_2D = ggml_reshape_4d(model.ctx_compute, model.images, MNIST_HW, MNIST_HW, 1, model.images->ne[1]);\n\n        struct ggml_tensor * conv1_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,\n            ggml_conv_2d(model.ctx_compute, model.conv1_kernel, images_2D, 1, 1, 1, 1, 1, 1),\n            model.conv1_bias));\n        GGML_ASSERT(conv1_out->ne[0] == MNIST_HW);\n        GGML_ASSERT(conv1_out->ne[1] == MNIST_HW);\n        GGML_ASSERT(conv1_out->ne[2] == MNIST_CNN_NCB);\n        GGML_ASSERT(conv1_out->ne[3] == model.nbatch_physical);\n\n        struct ggml_tensor * conv2_in = ggml_pool_2d(model.ctx_compute, conv1_out, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);\n        GGML_ASSERT(conv2_in->ne[0] == MNIST_HW/2);\n        GGML_ASSERT(conv2_in->ne[1] == MNIST_HW/2);\n        GGML_ASSERT(conv2_in->ne[2] == MNIST_CNN_NCB);\n        GGML_ASSERT(conv2_in->ne[3] == model.nbatch_physical);\n\n        struct ggml_tensor * conv2_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,\n            ggml_conv_2d(model.ctx_compute, model.conv2_kernel, conv2_in, 1, 1, 1, 1, 1, 1),\n            model.conv2_bias));\n        GGML_ASSERT(conv2_out->ne[0] == MNIST_HW/2);\n        GGML_ASSERT(conv2_out->ne[1] == MNIST_HW/2);\n        GGML_ASSERT(conv2_out->ne[2] == MNIST_CNN_NCB*2);\n        GGML_ASSERT(conv2_out->ne[3] == model.nbatch_physical);\n\n        struct ggml_tensor * dense_in = ggml_pool_2d(model.ctx_compute, conv2_out, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);\n        GGML_ASSERT(dense_in->ne[0] == MNIST_HW/4);\n        
GGML_ASSERT(dense_in->ne[1] == MNIST_HW/4);\n        GGML_ASSERT(dense_in->ne[2] == MNIST_CNN_NCB*2);\n        GGML_ASSERT(dense_in->ne[3] == model.nbatch_physical);\n\n        dense_in = ggml_reshape_2d(model.ctx_compute,\n            ggml_cont(model.ctx_compute, ggml_permute(model.ctx_compute, dense_in, 1, 2, 0, 3)),\n            (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), model.nbatch_physical);\n        GGML_ASSERT(dense_in->ne[0] == (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2));\n        GGML_ASSERT(dense_in->ne[1] == model.nbatch_physical);\n        GGML_ASSERT(dense_in->ne[2] == 1);\n        GGML_ASSERT(dense_in->ne[3] == 1);\n\n        model.logits = ggml_add(model.ctx_compute, ggml_mul_mat(model.ctx_compute, model.dense_weight, dense_in), model.dense_bias);\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    ggml_set_name(model.logits, \"logits\");\n    ggml_set_output(model.logits);\n    GGML_ASSERT(model.logits->type == GGML_TYPE_F32);\n    GGML_ASSERT(model.logits->ne[0] == MNIST_NCLASSES);\n    GGML_ASSERT(model.logits->ne[1] == model.nbatch_physical);\n    GGML_ASSERT(model.logits->ne[2] == 1);\n    GGML_ASSERT(model.logits->ne[3] == 1);\n}\n\nggml_opt_result_t mnist_model_eval(mnist_model & model, ggml_opt_dataset_t dataset) {\n    ggml_opt_result_t result = ggml_opt_result_init();\n\n    ggml_opt_params params = ggml_opt_default_params(model.backend_sched, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);\n    params.ctx_compute = model.ctx_compute;\n    params.inputs      = model.images;\n    params.outputs     = model.logits;\n    params.build_type  = GGML_OPT_BUILD_TYPE_FORWARD;\n    ggml_opt_context_t opt_ctx = ggml_opt_init(params);\n\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        ggml_opt_epoch(opt_ctx, dataset, nullptr, result, /*idata_split =*/ 0, nullptr, nullptr);\n\n        const int64_t t_total_us = ggml_time_us() - t_start_us;\n        const double t_total_ms = 1e-3*t_total_us;\n        const int nex = 
ggml_opt_dataset_data(dataset)->ne[1];\n        fprintf(stderr, \"%s: model evaluation on %d images took %.2lf ms, %.2lf us/image\\n\",\n                __func__, nex, t_total_ms, (double) t_total_us/nex);\n    }\n\n    ggml_opt_free(opt_ctx);\n\n    return result;\n}\n\nvoid mnist_model_train(mnist_model & model, ggml_opt_dataset_t dataset, const int nepoch, const float val_split) {\n    ggml_opt_fit(model.backend_sched, model.ctx_compute, model.images, model.logits, dataset,\n        GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, GGML_OPT_OPTIMIZER_TYPE_ADAMW, ggml_opt_get_default_optimizer_params, nepoch, model.nbatch_logical, val_split, false);\n}\n\nvoid mnist_model_save(mnist_model & model, const std::string & fname) {\n    printf(\"%s: saving model to '%s'\\n\", __func__, fname.c_str());\n\n    struct ggml_context * ggml_ctx;\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ 100 * 1024*1024,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ false,\n        };\n        ggml_ctx = ggml_init(params);\n    }\n\n    gguf_context * gguf_ctx = gguf_init_empty();\n    gguf_set_val_str(gguf_ctx, \"general.architecture\", model.arch.c_str());\n\n    std::vector<struct ggml_tensor *> weights;\n    if (model.arch == \"mnist-fc\") {\n        weights = {model.fc1_weight, model.fc1_bias, model.fc2_weight, model.fc2_bias};\n    } else if (model.arch == \"mnist-cnn\") {\n        weights = {model.conv1_kernel, model.conv1_bias, model.conv2_kernel, model.conv2_bias, model.dense_weight, model.dense_bias};\n    } else {\n        GGML_ASSERT(false);\n    }\n    for (struct ggml_tensor * t : weights) {\n        struct ggml_tensor * copy = ggml_dup_tensor(ggml_ctx, t);\n        ggml_set_name(copy, t->name);\n        ggml_backend_tensor_get(t, copy->data, 0, ggml_nbytes(t));\n        gguf_add_tensor(gguf_ctx, copy);\n    }\n    gguf_write_to_file(gguf_ctx, fname.c_str(), false);\n\n    ggml_free(ggml_ctx);\n    gguf_free(gguf_ctx);\n}\n\n#ifdef 
__cplusplus\nextern \"C\" {\n#endif\n\nint wasm_eval(uint8_t * digitPtr) {\n    std::vector<float> digit(digitPtr, digitPtr + MNIST_NINPUT);\n\n    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(GGML_TYPE_F32, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NCLASSES, 1, 1);\n    struct ggml_tensor * data = ggml_opt_dataset_data(dataset);\n\n    float * buf = ggml_get_data_f32(data);\n    for (int i = 0; i < MNIST_NINPUT; ++i) {\n        buf[i] = digitPtr[i] / 255.0f;\n    }\n    ggml_set_zero(ggml_opt_dataset_labels(dataset)); // The labels are not needed.\n\n    mnist_model model = mnist_model_init_from_file(\"mnist-f32.gguf\", \"CPU\", /*nbatch_logical =*/ 1, /*nbatch_physical =*/ 1);\n    mnist_model_build(model);\n    ggml_opt_result_t result = mnist_model_eval(model, dataset);\n\n    int32_t pred;\n    ggml_opt_result_pred(result, &pred);\n\n    return pred;\n}\n\nint wasm_random_digit(char * digitPtr) {\n    auto fin = std::ifstream(\"t10k-images-idx3-ubyte\", std::ios::binary);\n    if (!fin) {\n        fprintf(stderr, \"failed to open digits file\\n\");\n        return 0;\n    }\n    srand(time(NULL));\n\n    // Seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)\n    fin.seekg(16 + MNIST_NINPUT * (rand() % MNIST_NTEST));\n    fin.read(digitPtr, MNIST_NINPUT);\n\n    return 1;\n}\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "examples/mnist/mnist-common.h",
    "content": "#include <algorithm>\n#include <cstdint>\n#include <random>\n#include <string>\n#include <thread>\n#include <vector>\n\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n#include \"ggml.h\"\n#include \"gguf.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-opt.h\"\n\n#define MNIST_NTRAIN 60000\n#define MNIST_NTEST  10000\n\n// Gradient accumulation can be achieved by setting the logical batch size to a multiple of the physical one.\n// The logical batch size determines how many datapoints are used for a gradient update.\n// The physical batch size determines how many datapoints are processed in parallel, larger values utilize compute better but need more memory.\n#define MNIST_NBATCH_LOGICAL  1000\n#define MNIST_NBATCH_PHYSICAL  500\n\nstatic_assert(MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL == 0, \"MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL != 0\");\nstatic_assert(MNIST_NTRAIN % MNIST_NBATCH_LOGICAL == 0, \"MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0\");\nstatic_assert(MNIST_NTEST  % MNIST_NBATCH_LOGICAL == 0, \"MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0\");\n\n#define MNIST_HW       28\n#define MNIST_NINPUT   (MNIST_HW*MNIST_HW)\n#define MNIST_NCLASSES 10\n\n#define MNIST_NHIDDEN  500\n\n// NCB = number of channels base\n#define MNIST_CNN_NCB 8\n\nstruct mnist_model {\n    std::string arch;\n    ggml_backend_sched_t backend_sched;\n    std::vector<ggml_backend_t> backends;\n    const int nbatch_logical;\n    const int nbatch_physical;\n\n    struct ggml_tensor * images     = nullptr;\n    struct ggml_tensor * logits     = nullptr;\n\n    struct ggml_tensor * fc1_weight = nullptr;\n    struct ggml_tensor * fc1_bias   = nullptr;\n    struct ggml_tensor * fc2_weight = nullptr;\n    struct ggml_tensor * fc2_bias   = nullptr;\n\n    struct ggml_tensor * conv1_kernel = nullptr;\n    struct ggml_tensor * conv1_bias   = nullptr;\n    struct ggml_tensor * conv2_kernel = nullptr;\n    struct ggml_tensor * conv2_bias   = nullptr;\n    struct ggml_tensor * 
dense_weight = nullptr;\n    struct ggml_tensor * dense_bias   = nullptr;\n\n    struct ggml_context * ctx_gguf    = nullptr;\n    struct ggml_context * ctx_static  = nullptr;\n    struct ggml_context * ctx_compute = nullptr;\n    ggml_backend_buffer_t buf_gguf    = nullptr;\n    ggml_backend_buffer_t buf_static  = nullptr;\n\n    mnist_model(const std::string & backend_name, const int nbatch_logical, const int nbatch_physical)\n            : nbatch_logical(nbatch_logical), nbatch_physical(nbatch_physical) {\n        std::vector<ggml_backend_dev_t> devices;\n        const int ncores_logical = std::thread::hardware_concurrency();\n        const int nthreads = std::min(ncores_logical, (ncores_logical + 4) / 2);\n\n        // Add primary backend:\n        if (!backend_name.empty()) {\n            ggml_backend_dev_t dev = ggml_backend_dev_by_name(backend_name.c_str());\n            if (dev == nullptr) {\n                fprintf(stderr, \"%s: ERROR: backend %s not found, available:\\n\", __func__, backend_name.c_str());\n                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {\n                    ggml_backend_dev_t dev_i = ggml_backend_dev_get(i);\n                    fprintf(stderr, \"  - %s (%s)\\n\", ggml_backend_dev_name(dev_i), ggml_backend_dev_description(dev_i));\n                }\n                exit(1);\n            }\n\n            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);\n            GGML_ASSERT(backend);\n\n            if (ggml_backend_is_cpu(backend)) {\n                ggml_backend_cpu_set_n_threads(backend, nthreads);\n            }\n\n            backends.push_back(backend);\n            devices.push_back(dev);\n        }\n\n        // Add all available backends as fallback.\n        // A \"backend\" is a stream on a physical device so there is no problem with adding multiple backends for the same device.\n        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {\n            ggml_backend_dev_t dev = 
ggml_backend_dev_get(i);\n\n            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);\n            GGML_ASSERT(backend);\n\n            if (ggml_backend_is_cpu(backend)) {\n                ggml_backend_cpu_set_n_threads(backend, nthreads);\n            }\n\n            backends.push_back(backend);\n            devices.push_back(dev);\n        }\n\n        // The order of the backends passed to ggml_backend_sched_new determines which backend is given priority.\n        backend_sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);\n        fprintf(stderr, \"%s: using %s (%s) as primary backend\\n\",\n                __func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0]));\n        if (backends.size() >= 2) {\n            fprintf(stderr, \"%s: unsupported operations will be executed on the following fallback backends (in order of priority):\\n\", __func__);\n            for (size_t i = 1; i < backends.size(); ++i) {\n                fprintf(stderr, \"%s:  - %s (%s)\\n\", __func__, ggml_backend_name(backends[i]), ggml_backend_dev_description(devices[i]));\n            }\n        }\n\n        {\n            const size_t size_meta = 1024*ggml_tensor_overhead();\n            struct ggml_init_params params = {\n                /*.mem_size   =*/ size_meta,\n                /*.mem_buffer =*/ nullptr,\n                /*.no_alloc   =*/ true,\n            };\n            ctx_static = ggml_init(params);\n        }\n\n        {\n            // The compute context needs a total of 3 compute graphs: forward pass + backwards pass (with/without optimizer step).\n            const size_t size_meta = GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead();\n            struct ggml_init_params params = {\n                /*.mem_size   =*/ size_meta,\n                /*.mem_buffer =*/ nullptr,\n                /*.no_alloc   =*/ true,\n            };\n            
ctx_compute = ggml_init(params);\n        }\n    }\n\n    ~mnist_model() {\n        ggml_free(ctx_gguf);\n        ggml_free(ctx_static);\n        ggml_free(ctx_compute);\n\n        ggml_backend_buffer_free(buf_gguf);\n        ggml_backend_buffer_free(buf_static);\n        ggml_backend_sched_free(backend_sched);\n        for (ggml_backend_t backend : backends) {\n            ggml_backend_free(backend);\n        }\n    }\n};\n\nbool mnist_image_load(const std::string & fname, ggml_opt_dataset_t dataset);\nvoid mnist_image_print(FILE * f, ggml_opt_dataset_t dataset, const int iex);\nbool mnist_label_load(const std::string & fname, ggml_opt_dataset_t dataset);\n\nmnist_model       mnist_model_init_from_file(const std::string & fname, const std::string & backend, const int nbatch_logical, const int nbatch_physical);\nmnist_model       mnist_model_init_random(const std::string & arch, const std::string & backend, const int nbatch_logical, const int nbatch_physical);\nvoid              mnist_model_build(mnist_model & model);\nggml_opt_result_t mnist_model_eval(mnist_model & model, ggml_opt_dataset_t dataset);\nvoid              mnist_model_train(mnist_model & model, ggml_opt_dataset_t dataset, const int nepoch, const float val_split);\nvoid              mnist_model_save(mnist_model & model, const std::string & fname);\n"
  },
  {
    "path": "examples/mnist/mnist-eval.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-opt.h\"\n\n#include \"mnist-common.h\"\n\n#include <cmath>\n#include <cstdint>\n#include <cstdio>\n#include <cstring>\n#include <ctime>\n#include <string>\n#include <thread>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\nint main(int argc, char ** argv) {\n    srand(time(NULL));\n    ggml_time_init();\n\n    if (argc != 4 && argc != 5) {\n        fprintf(stderr, \"Usage: %s mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte [CPU/CUDA0]\\n\", argv[0]);\n        exit(1);\n    }\n\n    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(GGML_TYPE_F32, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NCLASSES, MNIST_NTEST, MNIST_NBATCH_PHYSICAL);\n\n    if (!mnist_image_load(argv[2], dataset)) {\n        return 1;\n    }\n    if (!mnist_label_load(argv[3], dataset)) {\n        return 1;\n    }\n\n    const int iex = rand() % MNIST_NTEST;\n    mnist_image_print(stdout, dataset, iex);\n\n    const std::string backend = argc >= 5 ? 
argv[4] : \"\";\n\n    const int64_t t_start_us = ggml_time_us();\n    mnist_model model = mnist_model_init_from_file(argv[1], backend, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);\n    mnist_model_build(model);\n    const int64_t t_load_us = ggml_time_us() - t_start_us;\n    fprintf(stdout, \"%s: loaded model in %.2lf ms\\n\", __func__, t_load_us / 1000.0);\n\n    ggml_opt_result_t result_eval = mnist_model_eval(model, dataset);\n\n    std::vector<int32_t> pred(MNIST_NTEST);\n    ggml_opt_result_pred(result_eval, pred.data());\n    fprintf(stdout, \"%s: predicted digit is %d\\n\", __func__, pred[iex]);\n\n    double loss;\n    double loss_unc;\n    ggml_opt_result_loss(result_eval, &loss, &loss_unc);\n    fprintf(stdout, \"%s: test_loss=%.6lf+-%.6lf\\n\", __func__, loss, loss_unc);\n\n    double accuracy;\n    double accuracy_unc;\n    ggml_opt_result_accuracy(result_eval, &accuracy, &accuracy_unc);\n    fprintf(stdout, \"%s: test_acc=%.2lf+-%.2lf%%\\n\", __func__, 100.0*accuracy, 100.0*accuracy_unc);\n\n    ggml_opt_result_free(result_eval);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/mnist/mnist-train-cnn.py",
    "content": "#!/usr/bin/env python3\nimport sys\nfrom time import time\nimport gguf\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow import keras\nfrom tensorflow.keras import layers\n\n\ndef train(model_path):\n    # Model / data parameters\n    num_classes = 10\n    input_shape = (28, 28, 1)\n\n    # Load the data and split it between train and test sets\n    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n    # Scale images to the [0, 1] range\n    x_train = x_train.astype(\"float32\") / 255\n    x_test = x_test.astype(\"float32\") / 255\n    x_train = np.expand_dims(x_train, -1)\n    x_test = np.expand_dims(x_test, -1)\n    print(\"x_train shape:\", x_train.shape)\n    print(x_train.shape[0], \"train samples\")\n    print(x_test.shape[0], \"test samples\")\n\n    # convert class vectors to binary class matrices\n    y_train = keras.utils.to_categorical(y_train, num_classes)\n    y_test = keras.utils.to_categorical(y_test, num_classes)\n\n    model = keras.Sequential(\n        [\n            keras.Input(shape=input_shape, dtype=tf.float32),\n            layers.Conv2D(8, kernel_size=(3, 3), padding=\"same\", activation=\"relu\", dtype=tf.float32),\n            layers.MaxPooling2D(pool_size=(2, 2)),\n            layers.Conv2D(16, kernel_size=(3, 3), padding=\"same\", activation=\"relu\", dtype=tf.float32),\n            layers.MaxPooling2D(pool_size=(2, 2)),\n            layers.Flatten(),\n            layers.Dense(num_classes, activation=\"softmax\", dtype=tf.float32),\n        ]\n    )\n\n    model.summary()\n    batch_size = 1000\n    epochs = 30\n    model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n\n    t_start = time()\n    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)\n    print(f\"Training took {time()-t_start:.2f}s\")\n\n    score = model.evaluate(x_test, y_test, verbose=0)\n    print(f\"Test loss: {score[0]:.6f}\")\n    
print(f\"Test accuracy: {100*score[1]:.2f}%\")\n\n    gguf_writer = gguf.GGUFWriter(model_path, \"mnist-cnn\")\n\n    conv1_kernel = model.layers[0].weights[0].numpy()\n    conv1_kernel = np.moveaxis(conv1_kernel, [2, 3], [0, 1])\n    gguf_writer.add_tensor(\"conv1.kernel\", conv1_kernel, raw_shape=(8, 1, 3, 3))\n\n    conv1_bias = model.layers[0].weights[1].numpy()\n    gguf_writer.add_tensor(\"conv1.bias\", conv1_bias, raw_shape=(1, 8, 1, 1))\n\n    conv2_kernel = model.layers[2].weights[0].numpy()\n    conv2_kernel = np.moveaxis(conv2_kernel, [0, 1, 2, 3], [2, 3, 1, 0])\n    gguf_writer.add_tensor(\"conv2.kernel\", conv2_kernel, raw_shape=(16, 8, 3, 3))\n\n    conv2_bias = model.layers[2].weights[1].numpy()\n    gguf_writer.add_tensor(\"conv2.bias\", conv2_bias, raw_shape=(1, 16, 1, 1))\n\n    dense_weight = model.layers[-1].weights[0].numpy()\n    dense_weight = dense_weight.transpose()\n    gguf_writer.add_tensor(\"dense.weight\", dense_weight, raw_shape=(10, 7*7*16))\n\n    dense_bias = model.layers[-1].weights[1].numpy()\n    gguf_writer.add_tensor(\"dense.bias\", dense_bias)\n\n    gguf_writer.write_header_to_file()\n    gguf_writer.write_kv_data_to_file()\n    gguf_writer.write_tensors_to_file()\n    gguf_writer.close()\n    print(f\"GGUF model saved to '{model_path}'\")\n\n\nif __name__ == '__main__':\n    if len(sys.argv) != 2:\n        print(f\"Usage: {sys.argv[0]} <model_path>\")\n        sys.exit(1)\n    train(sys.argv[1])\n"
  },
  {
    "path": "examples/mnist/mnist-train-fc.py",
    "content": "import gguf\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torchvision.datasets as dsets\nimport torchvision.transforms as transforms\nfrom torch.autograd import Variable\n\nimport sys\nfrom time import time\n\ninput_size  = 784  # img_size = (28,28) ---> 28*28=784 in total\nhidden_size = 500  # number of nodes at hidden layer\nnum_classes = 10   # number of output classes discrete range [0,9]\nnum_epochs  = 30   # number of times which the entire dataset is passed throughout the model\nbatch_size  = 1000 # the size of input data used for one iteration\nlr          = 1e-3 # size of step\n\n\nclass Net(nn.Module):\n    def __init__(self, input_size, hidden_size, num_classes):\n        super(Net, self).__init__()\n        self.fc1 = nn.Linear(input_size, hidden_size)\n        self.relu = nn.ReLU()\n        self.fc2 = nn.Linear(hidden_size, num_classes)\n\n    def forward(self, x):\n        out = self.fc1(x)\n        out = self.relu(out)\n        out = self.fc2(out)\n        return out\n\n\ndef train(model_path):\n    train_data = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n    test_data  = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n\n    assert len(train_data) == 60000\n    assert len(test_data)  == 10000\n\n    kwargs_train_test = dict(batch_size=batch_size, num_workers=4, pin_memory=True)\n    train_gen = torch.utils.data.DataLoader(dataset=train_data, shuffle=True,  **kwargs_train_test)\n    test_gen  = torch.utils.data.DataLoader(dataset=test_data,  shuffle=False, **kwargs_train_test)\n\n    net = Net(input_size, hidden_size, num_classes)\n\n    if torch.cuda.is_available():\n        net.cuda()\n\n    loss_function = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n\n    t_start = time()\n    for epoch in range(num_epochs):\n        loss_history = []\n        ncorrect = 0\n\n        for i, (images, labels) in 
enumerate(train_gen):\n            images = Variable(images.view(-1, 28*28))\n            labels = Variable(labels)\n\n            if torch.cuda.is_available():\n                images = images.cuda()\n                labels = labels.cuda()\n\n            optimizer.zero_grad()\n            outputs = net(images)\n            loss = loss_function(outputs, labels)\n\n            loss_history.append(loss.cpu().data)\n            _, predictions = torch.max(outputs, 1)\n            ncorrect += (predictions == labels).sum()\n\n            loss.backward()\n            optimizer.step()\n\n            if (i + 1)*batch_size % 10000 == 0:\n                loss_mean = np.mean(loss_history)\n                accuracy = ncorrect / ((i + 1) * batch_size)\n                print(\n                    f\"Epoch [{epoch+1:02d}/{num_epochs}], \"\n                    f\"Step [{(i+1)*batch_size:05d}/{len(train_data)}], \"\n                    f\"Loss: {loss_mean:.4f}, Accuracy: {100*accuracy:.2f}%\")\n    print()\n    print(f\"Training took {time()-t_start:.2f}s\")\n\n    loss_history = []\n    ncorrect = 0\n\n    for i, (images, labels) in enumerate(test_gen):\n        images = Variable(images.view(-1, 28*28))\n        labels = Variable(labels)\n\n        if torch.cuda.is_available():\n            images = images.cuda()\n            labels = labels.cuda()\n\n        outputs = net(images)\n        loss = loss_function(outputs, labels)\n\n        loss_history.append(loss.cpu().data)\n        _, predictions = torch.max(outputs, 1)\n        ncorrect += (predictions == labels).sum().cpu().numpy()\n\n    loss_mean            = np.mean(loss_history)\n    loss_uncertainty     = np.std(loss_history) / np.sqrt(len(loss_history) - 1)\n    accuracy_mean        = ncorrect / (len(test_gen) * batch_size)\n    accuracy_uncertainty = np.sqrt(accuracy_mean * (1.0 - accuracy_mean) / (len(test_gen) * batch_size))\n    print()\n    print(f\"Test loss: {loss_mean:.6f}+-{loss_uncertainty:.6f}, Test accuracy: 
{100*accuracy_mean:.2f}+-{100*accuracy_uncertainty:.2f}%\")\n\n    gguf_writer = gguf.GGUFWriter(model_path, \"mnist-fc\")\n\n    print()\n    print(f\"Model tensors saved to {model_path}:\")\n    for tensor_name in net.state_dict().keys():\n        data = net.state_dict()[tensor_name].squeeze().cpu().numpy()\n        print(tensor_name, \"\\t\", data.shape)\n        gguf_writer.add_tensor(tensor_name, data)\n\n    gguf_writer.write_header_to_file()\n    gguf_writer.write_kv_data_to_file()\n    gguf_writer.write_tensors_to_file()\n    gguf_writer.close()\n\n\nif __name__ == '__main__':\n    if len(sys.argv) != 2:\n        print(f\"Usage: {sys.argv[0]} <model_path>\")\n        sys.exit(1)\n    train(sys.argv[1])\n"
  },
  {
    "path": "examples/mnist/mnist-train.cpp",
    "content": "#include \"ggml-opt.h\"\n#include \"mnist-common.h\"\n\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <ctime>\n#include <string>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\nint main(int argc, char ** argv) {\n    if (argc != 5 && argc != 6) {\n        fprintf(stderr, \"Usage: %s mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte [CPU/CUDA0]\\n\", argv[0]);\n        exit(0);\n    }\n\n    // The MNIST model is so small that the overhead from data shuffling is non-negligible, especially with CUDA.\n    // With a shard size of 10 this overhead is greatly reduced at the cost of less shuffling (does not seem to have a significant impact).\n    // A batch of 500 images then consists of 50 random shards of size 10 instead of 500 random shards of size 1.\n    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(GGML_TYPE_F32, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NCLASSES, MNIST_NTRAIN, /*ndata_shard =*/ 10);\n\n    if (!mnist_image_load(argv[3], dataset)) {\n        return 1;\n    }\n    if (!mnist_label_load(argv[4], dataset)) {\n        return 1;\n    }\n\n    mnist_model model = mnist_model_init_random(argv[1], argc >= 6 ? argv[5] : \"\", MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);\n\n    mnist_model_build(model);\n\n    mnist_model_train(model, dataset, /*nepoch =*/ 30, /*val_split =*/ 0.05f);\n\n    mnist_model_save(model, argv[2]);\n}\n"
  },
  {
    "path": "examples/mnist/server.py",
    "content": "import http.server\nimport socketserver\nimport os\nimport sys\n\nDIRECTORY = os.path.abspath(os.path.join(os.path.dirname(__file__), 'web'))\nPORT = 8000\n\nclass CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, directory=DIRECTORY, **kwargs)\n\n    def end_headers(self):\n        # Add required headers for SharedArrayBuffer\n        self.send_header(\"Cross-Origin-Opener-Policy\", \"same-origin\")\n        self.send_header(\"Cross-Origin-Embedder-Policy\", \"require-corp\")\n        self.send_header(\"Access-Control-Allow-Origin\", \"*\")\n        super().end_headers()\n\n# Enable address reuse\nclass CustomServer(socketserver.TCPServer):\n    allow_reuse_address = True\n\ntry:\n    with CustomServer((\"\", PORT), CustomHTTPRequestHandler) as httpd:\n        print(f\"Serving directory '{DIRECTORY}' at http://localhost:{PORT}\")\n        print(f\"Application context root: http://localhost:{PORT}/\")\n        try:\n            httpd.serve_forever()\n        except KeyboardInterrupt:\n            print(\"\\nServer stopped.\")\n            # Force complete exit\n            sys.exit(0)\nexcept OSError as e:\n    print(f\"Error: {e}\")\n    sys.exit(1)\n"
  },
  {
    "path": "examples/perf-metal/CMakeLists.txt",
    "content": "#\n# perf-metal\n\nset(TEST_TARGET perf-metal)\nadd_executable(${TEST_TARGET} perf-metal.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml)\n\n"
  },
  {
    "path": "examples/perf-metal/perf-metal.cpp",
    "content": "// basic tool to experiment with the Metal backend\n//\n// 1. Get GPU trace of a dummy graph:\n//\n//   rm -rf /tmp/perf-metal.gputrace\n//   make -j perf-metal && METAL_CAPTURE_ENABLED=1 ./bin/perf-metal\n//   open /tmp/perf-metal.gputrace\n//\n//   https://github.com/ggerganov/llama.cpp/issues/9507\n//\n\n#include \"ggml.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-metal.h\"\n\n#include <cstdio>\n#include <vector>\n#include <thread>\n\nint main(int argc, char ** argv) {\n    int n_op = 1024;\n    int n_iter = 128;\n\n    if (argc > 1) {\n        n_op = std::atoi(argv[1]);\n    }\n\n    if (argc > 2) {\n        n_iter = std::atoi(argv[2]);\n    }\n\n    printf(\"%s: n_op = %d, n_iter = %d\\n\", __func__, n_op, n_iter);\n\n    const int ne00 = 8;\n    const int ne01 = 8;\n    const int ne11 = 8;\n\n    std::vector<float> data0(ne00*ne01, 1.0f);\n    std::vector<float> data1(ne00*ne01, 1.0f/ne00);\n\n    ggml_backend_t backend = ggml_backend_metal_init();\n    if (!backend) {\n        fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        return 1;\n    }\n\n    const size_t ctx_size = 2 * ggml_tensor_overhead();\n\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ ctx_size,\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n    struct ggml_context * ctx = ggml_init(params);\n\n    struct ggml_tensor * t0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne00, ne01);\n    struct ggml_tensor * t1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne00, ne11);\n\n    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);\n\n    ggml_backend_tensor_set(t0, data0.data(), 0, ggml_nbytes(t0));\n    ggml_backend_tensor_set(t1, data1.data(), 0, ggml_nbytes(t1));\n\n    struct ggml_cgraph * gf = NULL;\n\n    struct ggml_context * ctx_cgraph = NULL;\n\n    // create a dummy compute graph:\n    //\n    // x = mul_mat(t0, t1)\n    // x = x * 1.0f\n    // x = 
mul_mat(x, t1)\n    // x = x * 1.0f\n    // ... repeat n_op times ...\n    //\n    {\n        struct ggml_init_params params0 = {\n            /*.mem_size   =*/ 4*n_op*ggml_tensor_overhead() + ggml_graph_overhead(),\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n        };\n        ctx_cgraph = ggml_init(params0);\n\n        gf = ggml_new_graph_custom(ctx_cgraph, 4*n_op, false);\n\n        struct ggml_tensor * cur = ggml_mul_mat(ctx_cgraph, t0, t1);\n        cur = ggml_scale(ctx_cgraph, cur, 1.0f);\n\n        for (int i = 0; i < n_op - 1; i++) {\n            cur = ggml_mul_mat(ctx_cgraph, cur, t1);\n            cur = ggml_scale(ctx_cgraph, cur, 1.0f);\n        }\n\n        cur = ggml_scale(ctx_cgraph, cur, 42.0f);\n\n        ggml_build_forward_expand(gf, cur);\n    }\n\n    printf(\"%s: graph nodes = %d\\n\", __func__, ggml_graph_n_nodes(gf));\n\n    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));\n    ggml_gallocr_alloc_graph(allocr, gf);\n\n    {\n        // warm-up\n        ggml_backend_graph_compute(backend, gf);\n\n        const int64_t t_start = ggml_time_us();\n\n        for (int iter = 0; iter < n_iter; iter++) {\n            ggml_backend_graph_compute(backend, gf);\n        }\n\n        const int64_t t_end = ggml_time_us();\n\n        // actual trace\n        ggml_backend_metal_capture_next_compute(backend);\n        ggml_backend_graph_compute(backend, gf);\n        //std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // NOTE: these intervals do not appear in the XCode trace!\n        ggml_backend_metal_capture_next_compute(backend);\n        ggml_backend_graph_compute(backend, gf);\n        //std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // NOTE: these intervals do not appear in the XCode trace!\n        ggml_backend_metal_capture_next_compute(backend);\n        ggml_backend_graph_compute(backend, gf);\n\n        printf(\"%s: time = %f ms\\n\", __func__, (t_end 
- t_start) / 1000.0 / n_iter);\n    }\n\n    {\n        struct ggml_tensor * res = ggml_graph_node(gf, -1);\n\n        std::vector<float> data(res->ne[0] * res->ne[1], 0.0f);\n\n        ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res));\n\n        for (int i1 = 0; i1 < res->ne[1]; i1++) {\n            for (int i0 = 0; i0 < res->ne[0]; i0++) {\n                printf(\"%f \", data[i1*res->ne[0] + i0]);\n            }\n            printf(\"\\n\");\n        }\n    }\n\n    ggml_free(ctx_cgraph);\n    ggml_gallocr_free(allocr);\n    ggml_free(ctx);\n    ggml_backend_buffer_free(buffer);\n    ggml_backend_free(backend);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/prompts/dolly-v2.txt",
    "content": "Hello World! => 12092,3645,2\nI can't believe it's already Friday!\" => 42,476,626,2868,352,434,2168,6794,1476\nThe URL for the website is https://www.example.com.\" => 510,10611,323,253,4422,310,5987,1358,2700,15,11667,15,681,449\n\"She said, 'I love to travel.'\" => 3,2993,753,13,686,42,2389,281,4288,18574\n'The temperature is 25.5°C.' => 8,510,3276,310,2030,15,22,3272,36,2464\n\"Let's meet at 2:30 p.m. in the park.\" => 3,1466,434,2525,387,374,27,1229,268,15,78,15,275,253,5603,449\nThe book costs $19.99 => 510,1984,4815,370,746,15,1525\n\"John's favorite color is blue.\" => 3,8732,434,7583,3295,310,4797,449\nTh@nk y0u f0r y0ur h3lp! => 1044,33,30664,340,17,86,269,17,83,340,17,321,288,20,24343,2\nC@n I g3t a c0ffee, pl3@se? => 36,33,79,309,305,20,85,247,260,17,71,6851,13,499,20,33,339,32\nW0w! Th@t's @m@zing! => 56,17,88,2,596,33,85,434,1214,78,33,8537,2\nH0w 4re y0u t0d@y? => 41,17,88,577,250,340,17,86,246,17,69,33,90,32\nI l0ve t0 tr@vel @r0und the w0rld. => 42,298,17,306,246,17,492,33,652,1214,83,17,1504,253,259,17,83,392,15\nWh@t's y0ur f@v0rite m0vie? => 3152,33,85,434,340,17,321,269,33,87,17,3852,278,17,25858,32\nThe cat is sleeping on the mat. => 510,5798,310,14343,327,253,1111,15\nI need to buy some groceries for dinner. => 42,878,281,4489,690,45160,447,323,8955,15\nThe sun is shining brightly in the sky. => 510,5101,310,28115,43925,275,253,8467,15\nShe is reading a book in the park. => 2993,310,4361,247,1984,275,253,5603,15\nWe went for a walk on the beach yesterday. => 1231,2427,323,247,2940,327,253,11600,11066,15\nHe plays the guitar like a pro. => 1328,7120,253,12609,751,247,354,15\nThey are going to the movies tonight. => 3726,403,1469,281,253,11321,11608,15\nThe flowers are blooming in the garden. => 510,12405,403,30601,272,275,253,10329,15\nI enjoy listening to classical music. => 42,4264,11298,281,8946,3440,15\nWe need to buy groceries for the week. 
=> 1231,878,281,4489,45160,447,323,253,2129,15\nThe dog is chasing its tail in circles. => 510,4370,310,31702,697,8105,275,14240,15\nShe is wearing a beautiful red dress. => 2993,310,9398,247,5389,2502,7619,15\nHe is a talented actor in Hollywood. => 1328,310,247,21220,12353,275,14759,15\nThe children are playing in the playground. => 510,2151,403,4882,275,253,41008,15\nI'm going to visit my grandparents this weekend. => 42,1353,1469,281,4143,619,37186,436,8849,15\nThe coffee tastes bitter without sugar. => 510,8574,27491,17123,1293,8618,15\nThey are planning a surprise party for her. => 3726,403,7219,247,9326,3128,323,617,15\nShe sings like an angel on stage. => 2993,44718,751,271,23087,327,3924,15\nWe should take a vacation to relax. => 1231,943,1379,247,18125,281,7921,15\nHe is studying medicine at the university. => 1328,310,12392,9921,387,253,9835,15\nThe rain is pouring heavily outside. => 510,9313,310,31226,11306,3345,15\nI enjoy watching romantic movies. => 42,4264,7487,18109,11321,15\nThey are celebrating their anniversary today. => 3726,403,28765,616,19054,3063,15\nShe dances gracefully to the music. => 2993,47078,14426,2920,281,253,3440,15\nHe is an excellent basketball player. => 1328,310,271,7126,14648,4760,15\nThe baby is sleeping soundly in the crib. => 510,6858,310,14343,3590,314,275,253,260,725,15\nI need to finish my homework before dinner. => 42,878,281,8416,619,32110,1078,8955,15\nThey are organizing a charity event next month. => 3726,403,26169,247,19489,2362,1735,1770,15\nShe is cooking a delicious meal for us. => 2993,310,12398,247,17319,11484,323,441,15\nWe should go hiking in the mountains. => 1231,943,564,33061,275,253,14700,15\nThe car broke down on the way to work. => 510,1113,9377,1066,327,253,1039,281,789,15\nHe loves playing video games in his free time. => 1328,14528,4882,3492,3958,275,521,1959,673,15\nThe birds are chirping in the trees. => 510,11260,403,36494,14650,275,253,7139,15\nI want to learn how to play the piano. 
=> 42,971,281,3037,849,281,1132,253,18542,15\nThey are building a new shopping mall in the city. => 3726,403,3652,247,747,12701,28974,275,253,2846,15\nShe is writing a novel in her spare time. => 2993,310,4028,247,4460,275,617,18345,673,15\nWe are going to the zoo this Saturday. => 1231,403,1469,281,253,41089,436,7814,15\nThe cake looks delicious with chocolate frosting. => 510,15221,4453,17319,342,14354,34724,272,15\nHe is a talented painter who sells his artwork. => 1328,310,247,21220,27343,665,27924,521,28227,15\nThe students are studying for their exams. => 510,3484,403,12392,323,616,34666,15\nI enjoy swimming in the ocean. => 42,4264,17120,275,253,12927,15\nThey are renovating their house. => 3726,403,30074,839,616,2419,15\nShe is practicing yoga to stay healthy. => 2993,310,25815,25551,281,3297,5875,15\nWe should plant flowers in the garden. => 1231,943,4444,12405,275,253,10329,15\nThe traffic is heavy during rush hour. => 510,7137,310,5536,1309,16949,4964,15\nHe is a skilled chef who creates amazing dishes. => 1328,310,247,18024,26540,665,10513,8644,17114,15\nThe baby is crawling on the floor. => 510,6858,310,44922,327,253,5254,15\nI need to buy a new pair of shoes. => 42,878,281,4489,247,747,4667,273,12682,15\nThey are going on a road trip across the country. => 3726,403,1469,327,247,3971,7408,2439,253,2586,15\nShe is playing the piano beautifully. => 2993,310,4882,253,18542,27839,15\nWe are going to a concert tomorrow night. => 1231,403,1469,281,247,12699,10873,2360,15\nThe cake tastes delicious with vanilla frosting. => 510,15221,27491,17319,342,26724,34724,272,15\nHe is a dedicated teacher who inspires his students. => 1328,310,247,9940,9732,665,6381,2731,521,3484,15\nThe students are participating in a science fair. => 510,3484,403,15299,275,247,5859,4344,15\nI enjoy hiking in the mountains. => 42,4264,33061,275,253,14700,15\nThey are organizing a beach cleanup next weekend. 
=> 3726,403,26169,247,11600,34709,1735,8849,15\nShe is taking photographs of nature. => 2993,310,3192,15928,273,3753,15\nWe should try a new restaurant in town. => 1231,943,1611,247,747,10301,275,3874,15\nThe traffic is moving slowly on the highway. => 510,7137,310,4886,7808,327,253,17657,15\nHe is a talented singer with a beautiful voice. => 1328,310,247,21220,16057,342,247,5389,4318,15\nThe baby is laughing and giggling. => 510,6858,310,17053,285,41542,1981,15\nI need to do laundry and wash my clothes. => 42,878,281,513,29023,285,14841,619,10015,15\nThey are planning a trip to Europe. => 3726,403,7219,247,7408,281,3060,15\nShe is learning how to play the guitar. => 2993,310,4715,849,281,1132,253,12609,15\nWe are going to a museum this Sunday. => 1231,403,1469,281,247,16064,436,6926,15\nThe coffee smells amazing in the morning. => 510,8574,34247,8644,275,253,4131,15\nHe is a hardworking farmer who grows crops. => 1328,310,247,1892,21107,24718,665,17202,19492,15\nThe students are presenting their research projects. => 510,3484,403,15250,616,2561,6493,15\nI enjoy playing soccer with my friends. => 42,4264,4882,20391,342,619,3858,15\nThey are volunteering at a local shelter. => 3726,403,10057,2158,387,247,1980,17824,15\nShe is practicing martial arts for self-defense. => 2993,310,25815,29731,14635,323,1881,14,29337,15\nWe should try a new recipe for dinner. => 1231,943,1611,247,747,13612,323,8955,15\nThe traffic is congest => 510,7137,310,25801\nThe sun is shining brightly today. => 510,5101,310,28115,43925,3063,15\nI enjoy reading books in my free time. => 42,4264,4361,5098,275,619,1959,673,15\nShe plays the piano beautifully. => 2993,7120,253,18542,27839,15\nThe cat chased the mouse around the room. => 510,5798,40754,253,6521,1475,253,2316,15\nI love eating pizza with extra cheese. => 42,2389,9123,22534,342,4465,12173,15\nHe always wears a hat wherever he goes. => 1328,1900,31394,247,7856,20312,344,4566,15\nThe flowers in the garden are blooming. 
=> 510,12405,275,253,10329,403,30601,272,15\nShe danced gracefully on the stage. => 2993,39860,14426,2920,327,253,3924,15\nThe dog barked loudly in the park. => 510,4370,21939,264,31311,275,253,5603,15\nWe went swimming in the ocean yesterday. => 1231,2427,17120,275,253,12927,11066,15\nHe speaks fluent French and Spanish. => 1328,16544,2938,290,5112,285,9883,15\nThe train arrived at the station on time. => 510,6194,7244,387,253,4660,327,673,15\nShe cooked a delicious meal for her family. => 2993,18621,247,17319,11484,323,617,2021,15\n"
  },
  {
    "path": "examples/prompts/gpt-2-chinese.txt",
    "content": "请问洗手间在哪里？ => 6435,7309,3819,2797,7313,1762,1525,7027,8043\n"
  },
  {
    "path": "examples/prompts/gpt-2.txt",
    "content": "Hello World! => 15496,2159,0\nI can't believe it's already Friday!\" => 40,460,470,1975,340,338,1541,3217,2474\nThe URL for the website is https://www.example.com.\" => 464,10289,329,262,3052,318,3740,1378,2503,13,20688,13,785,526\n\"She said, 'I love to travel.'\" => 1,3347,531,11,705,40,1842,284,3067,11496\n'The temperature is 25.5°C.' => 6,464,5951,318,1679,13,20,7200,34,2637\n\"Let's meet at 2:30 p.m. in the park.\" => 1,5756,338,1826,379,362,25,1270,279,13,76,13,287,262,3952,526\nThe book costs $19.99 => 464,1492,3484,720,1129,13,2079\n\"John's favorite color is blue.\" => 1,7554,338,4004,3124,318,4171,526\nTh@nk y0u f0r y0ur h3lp! => 817,31,77,74,331,15,84,277,15,81,331,15,333,289,18,34431,0\nC@n I g3t a c0ffee, pl3@se? => 34,31,77,314,308,18,83,257,269,15,5853,11,458,18,31,325,30\nW0w! Th@t's @m@zing! => 54,15,86,0,536,31,83,338,2488,76,31,9510,0\nH0w 4re y0u t0d@y? => 39,15,86,604,260,331,15,84,256,15,67,31,88,30\nI l0ve t0 tr@vel @r0und the w0rld. => 40,300,15,303,256,15,491,31,626,2488,81,15,917,262,266,15,81,335,13\nWh@t's y0ur f@v0rite m0vie? => 1199,31,83,338,331,15,333,277,31,85,15,6525,285,15,85,494,30\nThe cat is sleeping on the mat. => 464,3797,318,11029,319,262,2603,13\nI need to buy some groceries for dinner. => 40,761,284,2822,617,38464,329,8073,13\nThe sun is shining brightly in the sky. => 464,4252,318,22751,35254,287,262,6766,13\nShe is reading a book in the park. => 3347,318,3555,257,1492,287,262,3952,13\nWe went for a walk on the beach yesterday. => 1135,1816,329,257,2513,319,262,10481,7415,13\nHe plays the guitar like a pro. => 1544,5341,262,10047,588,257,386,13\nThey are going to the movies tonight. => 2990,389,1016,284,262,6918,9975,13\nThe flowers are blooming in the garden. => 464,12734,389,24924,3383,287,262,11376,13\nI enjoy listening to classical music. => 40,2883,8680,284,15993,2647,13\nWe need to buy groceries for the week. => 1135,761,284,2822,38464,329,262,1285,13\nThe dog is chasing its tail in circles. 
=> 464,3290,318,20023,663,7894,287,13332,13\nShe is wearing a beautiful red dress. => 3347,318,5762,257,4950,2266,6576,13\nHe is a talented actor in Hollywood. => 1544,318,257,12356,8674,287,8502,13\nThe children are playing in the playground. => 464,1751,389,2712,287,262,24817,13\nI'm going to visit my grandparents this weekend. => 40,1101,1016,284,3187,616,28571,428,5041,13\nThe coffee tastes bitter without sugar. => 464,6891,18221,12922,1231,7543,13\nThey are planning a surprise party for her. => 2990,389,5410,257,5975,2151,329,607,13\nShe sings like an angel on stage. => 3347,33041,588,281,18304,319,3800,13\nWe should take a vacation to relax. => 1135,815,1011,257,14600,284,8960,13\nHe is studying medicine at the university. => 1544,318,11065,9007,379,262,6403,13\nThe rain is pouring heavily outside. => 464,6290,318,23147,7272,2354,13\nI enjoy watching romantic movies. => 40,2883,4964,14348,6918,13\nThey are celebrating their anniversary today. => 2990,389,17499,511,11162,1909,13\nShe dances gracefully to the music. => 3347,38207,11542,2759,284,262,2647,13\nHe is an excellent basketball player. => 1544,318,281,6275,9669,2137,13\nThe baby is sleeping soundly in the crib. => 464,5156,318,11029,2128,306,287,262,48083,13\nI need to finish my homework before dinner. => 40,761,284,5461,616,26131,878,8073,13\nThey are organizing a charity event next month. => 2990,389,16924,257,11016,1785,1306,1227,13\nShe is cooking a delicious meal for us. => 3347,318,10801,257,12625,9799,329,514,13\nWe should go hiking in the mountains. => 1135,815,467,24522,287,262,12269,13\nThe car broke down on the way to work. => 464,1097,6265,866,319,262,835,284,670,13\nHe loves playing video games in his free time. => 1544,10408,2712,2008,1830,287,465,1479,640,13\nThe birds are chirping in the trees. => 464,10087,389,442,343,13886,287,262,7150,13\nI want to learn how to play the piano. => 40,765,284,2193,703,284,711,262,19132,13\nThey are building a new shopping mall in the city. 
=> 2990,389,2615,257,649,9735,17374,287,262,1748,13\nShe is writing a novel in her spare time. => 3347,318,3597,257,5337,287,607,13952,640,13\nWe are going to the zoo this Saturday. => 1135,389,1016,284,262,26626,428,3909,13\nThe cake looks delicious with chocolate frosting. => 464,12187,3073,12625,351,11311,21682,278,13\nHe is a talented painter who sells his artwork. => 1544,318,257,12356,34537,508,16015,465,16257,13\nThe students are studying for their exams. => 464,2444,389,11065,329,511,26420,13\nI enjoy swimming in the ocean. => 40,2883,14899,287,262,9151,13\nThey are renovating their house. => 2990,389,24317,803,511,2156,13\nShe is practicing yoga to stay healthy. => 3347,318,18207,20351,284,2652,5448,13\nWe should plant flowers in the garden. => 1135,815,4618,12734,287,262,11376,13\nThe traffic is heavy during rush hour. => 464,4979,318,4334,1141,10484,1711,13\nHe is a skilled chef who creates amazing dishes. => 1544,318,257,14297,21221,508,8075,4998,16759,13\nThe baby is crawling on the floor. => 464,5156,318,34499,319,262,4314,13\nI need to buy a new pair of shoes. => 40,761,284,2822,257,649,5166,286,10012,13\nThey are going on a road trip across the country. => 2990,389,1016,319,257,2975,5296,1973,262,1499,13\nShe is playing the piano beautifully. => 3347,318,2712,262,19132,21104,13\nWe are going to a concert tomorrow night. => 1135,389,1016,284,257,10010,9439,1755,13\nThe cake tastes delicious with vanilla frosting. => 464,12187,18221,12625,351,16858,21682,278,13\nHe is a dedicated teacher who inspires his students. => 1544,318,257,7256,4701,508,38934,465,2444,13\nThe students are participating in a science fair. => 464,2444,389,11983,287,257,3783,3148,13\nI enjoy hiking in the mountains. => 40,2883,24522,287,262,12269,13\nThey are organizing a beach cleanup next weekend. => 2990,389,16924,257,10481,27425,1306,5041,13\nShe is taking photographs of nature. => 3347,318,2263,12566,286,3450,13\nWe should try a new restaurant in town. 
=> 1135,815,1949,257,649,7072,287,3240,13\nThe traffic is moving slowly on the highway. => 464,4979,318,3867,6364,319,262,12763,13\nHe is a talented singer with a beautiful voice. => 1544,318,257,12356,14015,351,257,4950,3809,13\nThe baby is laughing and giggling. => 464,5156,318,14376,290,30442,1359,13\nI need to do laundry and wash my clothes. => 40,761,284,466,25724,290,13502,616,8242,13\nThey are planning a trip to Europe. => 2990,389,5410,257,5296,284,2031,13\nShe is learning how to play the guitar. => 3347,318,4673,703,284,711,262,10047,13\nWe are going to a museum this Sunday. => 1135,389,1016,284,257,13257,428,3502,13\nThe coffee smells amazing in the morning. => 464,6891,25760,4998,287,262,3329,13\nHe is a hardworking farmer who grows crops. => 1544,318,257,1327,16090,18739,508,13676,14450,13\nThe students are presenting their research projects. => 464,2444,389,17728,511,2267,4493,13\nI enjoy playing soccer with my friends. => 40,2883,2712,11783,351,616,2460,13\nThey are volunteering at a local shelter. => 2990,389,41434,379,257,1957,11772,13\nShe is practicing martial arts for self-defense. => 3347,318,18207,15618,10848,329,2116,12,19774,13\nWe should try a new recipe for dinner. => 1135,815,1949,257,649,8364,329,8073,13\nThe traffic is congest => 464,4979,318,22791\nThe sun is shining brightly today. => 464,4252,318,22751,35254,1909,13\nI enjoy reading books in my free time. => 40,2883,3555,3835,287,616,1479,640,13\nShe plays the piano beautifully. => 3347,5341,262,19132,21104,13\nThe cat chased the mouse around the room. => 464,3797,26172,262,10211,1088,262,2119,13\nI love eating pizza with extra cheese. => 40,1842,6600,14256,351,3131,9891,13\nHe always wears a hat wherever he goes. => 1544,1464,17326,257,6877,14530,339,2925,13\nThe flowers in the garden are blooming. => 464,12734,287,262,11376,389,24924,3383,13\nShe danced gracefully on the stage. => 3347,39480,11542,2759,319,262,3800,13\nThe dog barked loudly in the park. 
=> 464,3290,21405,276,23112,287,262,3952,13\nWe went swimming in the ocean yesterday. => 1135,1816,14899,287,262,9151,7415,13\nHe speaks fluent French and Spanish. => 1544,9209,43472,4141,290,7897,13\nThe train arrived at the station on time. => 464,4512,5284,379,262,4429,319,640,13\nShe cooked a delicious meal for her family. => 3347,15847,257,12625,9799,329,607,1641,13\n"
  },
  {
    "path": "examples/prompts/gpt-j.txt",
    "content": "Hello World! => 15496,2159,0\nI can't believe it's already Friday!\" => 40,460,470,1975,340,338,1541,3217,2474\nThe URL for the website is https://www.example.com.\" => 464,10289,329,262,3052,318,3740,1378,2503,13,20688,13,785,526\n\"She said, 'I love to travel.'\" => 1,3347,531,11,705,40,1842,284,3067,11496\n'The temperature is 25.5°C.' => 6,464,5951,318,1679,13,20,7200,34,2637\n\"Let's meet at 2:30 p.m. in the park.\" => 1,5756,338,1826,379,362,25,1270,279,13,76,13,287,262,3952,526\nThe book costs $19.99 => 464,1492,3484,720,1129,13,2079\n\"John's favorite color is blue.\" => 1,7554,338,4004,3124,318,4171,526\nTh@nk y0u f0r y0ur h3lp! => 817,31,77,74,331,15,84,277,15,81,331,15,333,289,18,34431,0\nC@n I g3t a c0ffee, pl3@se? => 34,31,77,314,308,18,83,257,269,15,5853,11,458,18,31,325,30\nW0w! Th@t's @m@zing! => 54,15,86,0,536,31,83,338,2488,76,31,9510,0\nH0w 4re y0u t0d@y? => 39,15,86,604,260,331,15,84,256,15,67,31,88,30\nI l0ve t0 tr@vel @r0und the w0rld. => 40,300,15,303,256,15,491,31,626,2488,81,15,917,262,266,15,81,335,13\nWh@t's y0ur f@v0rite m0vie? => 1199,31,83,338,331,15,333,277,31,85,15,6525,285,15,85,494,30\nThe cat is sleeping on the mat. => 464,3797,318,11029,319,262,2603,13\nI need to buy some groceries for dinner. => 40,761,284,2822,617,38464,329,8073,13\nThe sun is shining brightly in the sky. => 464,4252,318,22751,35254,287,262,6766,13\nShe is reading a book in the park. => 3347,318,3555,257,1492,287,262,3952,13\nWe went for a walk on the beach yesterday. => 1135,1816,329,257,2513,319,262,10481,7415,13\nHe plays the guitar like a pro. => 1544,5341,262,10047,588,257,386,13\nThey are going to the movies tonight. => 2990,389,1016,284,262,6918,9975,13\nThe flowers are blooming in the garden. => 464,12734,389,24924,3383,287,262,11376,13\nI enjoy listening to classical music. => 40,2883,8680,284,15993,2647,13\nWe need to buy groceries for the week. => 1135,761,284,2822,38464,329,262,1285,13\nThe dog is chasing its tail in circles. 
=> 464,3290,318,20023,663,7894,287,13332,13\nShe is wearing a beautiful red dress. => 3347,318,5762,257,4950,2266,6576,13\nHe is a talented actor in Hollywood. => 1544,318,257,12356,8674,287,8502,13\nThe children are playing in the playground. => 464,1751,389,2712,287,262,24817,13\nI'm going to visit my grandparents this weekend. => 40,1101,1016,284,3187,616,28571,428,5041,13\nThe coffee tastes bitter without sugar. => 464,6891,18221,12922,1231,7543,13\nThey are planning a surprise party for her. => 2990,389,5410,257,5975,2151,329,607,13\nShe sings like an angel on stage. => 3347,33041,588,281,18304,319,3800,13\nWe should take a vacation to relax. => 1135,815,1011,257,14600,284,8960,13\nHe is studying medicine at the university. => 1544,318,11065,9007,379,262,6403,13\nThe rain is pouring heavily outside. => 464,6290,318,23147,7272,2354,13\nI enjoy watching romantic movies. => 40,2883,4964,14348,6918,13\nThey are celebrating their anniversary today. => 2990,389,17499,511,11162,1909,13\nShe dances gracefully to the music. => 3347,38207,11542,2759,284,262,2647,13\nHe is an excellent basketball player. => 1544,318,281,6275,9669,2137,13\nThe baby is sleeping soundly in the crib. => 464,5156,318,11029,2128,306,287,262,48083,13\nI need to finish my homework before dinner. => 40,761,284,5461,616,26131,878,8073,13\nThey are organizing a charity event next month. => 2990,389,16924,257,11016,1785,1306,1227,13\nShe is cooking a delicious meal for us. => 3347,318,10801,257,12625,9799,329,514,13\nWe should go hiking in the mountains. => 1135,815,467,24522,287,262,12269,13\nThe car broke down on the way to work. => 464,1097,6265,866,319,262,835,284,670,13\nHe loves playing video games in his free time. => 1544,10408,2712,2008,1830,287,465,1479,640,13\nThe birds are chirping in the trees. => 464,10087,389,442,343,13886,287,262,7150,13\nI want to learn how to play the piano. => 40,765,284,2193,703,284,711,262,19132,13\nThey are building a new shopping mall in the city. 
=> 2990,389,2615,257,649,9735,17374,287,262,1748,13\nShe is writing a novel in her spare time. => 3347,318,3597,257,5337,287,607,13952,640,13\nWe are going to the zoo this Saturday. => 1135,389,1016,284,262,26626,428,3909,13\nThe cake looks delicious with chocolate frosting. => 464,12187,3073,12625,351,11311,21682,278,13\nHe is a talented painter who sells his artwork. => 1544,318,257,12356,34537,508,16015,465,16257,13\nThe students are studying for their exams. => 464,2444,389,11065,329,511,26420,13\nI enjoy swimming in the ocean. => 40,2883,14899,287,262,9151,13\nThey are renovating their house. => 2990,389,24317,803,511,2156,13\nShe is practicing yoga to stay healthy. => 3347,318,18207,20351,284,2652,5448,13\nWe should plant flowers in the garden. => 1135,815,4618,12734,287,262,11376,13\nThe traffic is heavy during rush hour. => 464,4979,318,4334,1141,10484,1711,13\nHe is a skilled chef who creates amazing dishes. => 1544,318,257,14297,21221,508,8075,4998,16759,13\nThe baby is crawling on the floor. => 464,5156,318,34499,319,262,4314,13\nI need to buy a new pair of shoes. => 40,761,284,2822,257,649,5166,286,10012,13\nThey are going on a road trip across the country. => 2990,389,1016,319,257,2975,5296,1973,262,1499,13\nShe is playing the piano beautifully. => 3347,318,2712,262,19132,21104,13\nWe are going to a concert tomorrow night. => 1135,389,1016,284,257,10010,9439,1755,13\nThe cake tastes delicious with vanilla frosting. => 464,12187,18221,12625,351,16858,21682,278,13\nHe is a dedicated teacher who inspires his students. => 1544,318,257,7256,4701,508,38934,465,2444,13\nThe students are participating in a science fair. => 464,2444,389,11983,287,257,3783,3148,13\nI enjoy hiking in the mountains. => 40,2883,24522,287,262,12269,13\nThey are organizing a beach cleanup next weekend. => 2990,389,16924,257,10481,27425,1306,5041,13\nShe is taking photographs of nature. => 3347,318,2263,12566,286,3450,13\nWe should try a new restaurant in town. 
=> 1135,815,1949,257,649,7072,287,3240,13\nThe traffic is moving slowly on the highway. => 464,4979,318,3867,6364,319,262,12763,13\nHe is a talented singer with a beautiful voice. => 1544,318,257,12356,14015,351,257,4950,3809,13\nThe baby is laughing and giggling. => 464,5156,318,14376,290,30442,1359,13\nI need to do laundry and wash my clothes. => 40,761,284,466,25724,290,13502,616,8242,13\nThey are planning a trip to Europe. => 2990,389,5410,257,5296,284,2031,13\nShe is learning how to play the guitar. => 3347,318,4673,703,284,711,262,10047,13\nWe are going to a museum this Sunday. => 1135,389,1016,284,257,13257,428,3502,13\nThe coffee smells amazing in the morning. => 464,6891,25760,4998,287,262,3329,13\nHe is a hardworking farmer who grows crops. => 1544,318,257,1327,16090,18739,508,13676,14450,13\nThe students are presenting their research projects. => 464,2444,389,17728,511,2267,4493,13\nI enjoy playing soccer with my friends. => 40,2883,2712,11783,351,616,2460,13\nThey are volunteering at a local shelter. => 2990,389,41434,379,257,1957,11772,13\nShe is practicing martial arts for self-defense. => 3347,318,18207,15618,10848,329,2116,12,19774,13\nWe should try a new recipe for dinner. => 1135,815,1949,257,649,8364,329,8073,13\nThe traffic is congest => 464,4979,318,22791\nThe sun is shining brightly today. => 464,4252,318,22751,35254,1909,13\nI enjoy reading books in my free time. => 40,2883,3555,3835,287,616,1479,640,13\nShe plays the piano beautifully. => 3347,5341,262,19132,21104,13\nThe cat chased the mouse around the room. => 464,3797,26172,262,10211,1088,262,2119,13\nI love eating pizza with extra cheese. => 40,1842,6600,14256,351,3131,9891,13\nHe always wears a hat wherever he goes. => 1544,1464,17326,257,6877,14530,339,2925,13\nThe flowers in the garden are blooming. => 464,12734,287,262,11376,389,24924,3383,13\nShe danced gracefully on the stage. => 3347,39480,11542,2759,319,262,3800,13\nThe dog barked loudly in the park. 
=> 464,3290,21405,276,23112,287,262,3952,13\nWe went swimming in the ocean yesterday. => 1135,1816,14899,287,262,9151,7415,13\nHe speaks fluent French and Spanish. => 1544,9209,43472,4141,290,7897,13\nThe train arrived at the station on time. => 464,4512,5284,379,262,4429,319,640,13\nShe cooked a delicious meal for her family. => 3347,15847,257,12625,9799,329,607,1641,13\n"
  },
  {
    "path": "examples/prompts/gpt-neox-japanese.txt",
    "content": "明日の天気はどうですか。 => 263,7353,268,18461,271,1722,18405,265\n"
  },
  {
    "path": "examples/prompts/gpt-neox.txt",
    "content": "Hello World! => 12092,3645,2\nI can't believe it's already Friday!\" => 42,476,626,2868,352,434,2168,6794,1476\nThe URL for the website is https://www.example.com.\" => 510,10611,323,253,4422,310,5987,1358,2700,15,11667,15,681,449\n\"She said, 'I love to travel.'\" => 3,2993,753,13,686,42,2389,281,4288,18574\n'The temperature is 25.5°C.' => 8,510,3276,310,2030,15,22,3272,36,2464\n\"Let's meet at 2:30 p.m. in the park.\" => 3,1466,434,2525,387,374,27,1229,268,15,78,15,275,253,5603,449\nThe book costs $19.99 => 510,1984,4815,370,746,15,1525\n\"John's favorite color is blue.\" => 3,8732,434,7583,3295,310,4797,449\nTh@nk y0u f0r y0ur h3lp! => 1044,33,30664,340,17,86,269,17,83,340,17,321,288,20,24343,2\nC@n I g3t a c0ffee, pl3@se? => 36,33,79,309,305,20,85,247,260,17,71,6851,13,499,20,33,339,32\nW0w! Th@t's @m@zing! => 56,17,88,2,596,33,85,434,1214,78,33,8537,2\nH0w 4re y0u t0d@y? => 41,17,88,577,250,340,17,86,246,17,69,33,90,32\nI l0ve t0 tr@vel @r0und the w0rld. => 42,298,17,306,246,17,492,33,652,1214,83,17,1504,253,259,17,83,392,15\nWh@t's y0ur f@v0rite m0vie? => 3152,33,85,434,340,17,321,269,33,87,17,3852,278,17,25858,32\nThe cat is sleeping on the mat. => 510,5798,310,14343,327,253,1111,15\nI need to buy some groceries for dinner. => 42,878,281,4489,690,45160,447,323,8955,15\nThe sun is shining brightly in the sky. => 510,5101,310,28115,43925,275,253,8467,15\nShe is reading a book in the park. => 2993,310,4361,247,1984,275,253,5603,15\nWe went for a walk on the beach yesterday. => 1231,2427,323,247,2940,327,253,11600,11066,15\nHe plays the guitar like a pro. => 1328,7120,253,12609,751,247,354,15\nThey are going to the movies tonight. => 3726,403,1469,281,253,11321,11608,15\nThe flowers are blooming in the garden. => 510,12405,403,30601,272,275,253,10329,15\nI enjoy listening to classical music. => 42,4264,11298,281,8946,3440,15\nWe need to buy groceries for the week. 
=> 1231,878,281,4489,45160,447,323,253,2129,15\nThe dog is chasing its tail in circles. => 510,4370,310,31702,697,8105,275,14240,15\nShe is wearing a beautiful red dress. => 2993,310,9398,247,5389,2502,7619,15\nHe is a talented actor in Hollywood. => 1328,310,247,21220,12353,275,14759,15\nThe children are playing in the playground. => 510,2151,403,4882,275,253,41008,15\nI'm going to visit my grandparents this weekend. => 42,1353,1469,281,4143,619,37186,436,8849,15\nThe coffee tastes bitter without sugar. => 510,8574,27491,17123,1293,8618,15\nThey are planning a surprise party for her. => 3726,403,7219,247,9326,3128,323,617,15\nShe sings like an angel on stage. => 2993,44718,751,271,23087,327,3924,15\nWe should take a vacation to relax. => 1231,943,1379,247,18125,281,7921,15\nHe is studying medicine at the university. => 1328,310,12392,9921,387,253,9835,15\nThe rain is pouring heavily outside. => 510,9313,310,31226,11306,3345,15\nI enjoy watching romantic movies. => 42,4264,7487,18109,11321,15\nThey are celebrating their anniversary today. => 3726,403,28765,616,19054,3063,15\nShe dances gracefully to the music. => 2993,47078,14426,2920,281,253,3440,15\nHe is an excellent basketball player. => 1328,310,271,7126,14648,4760,15\nThe baby is sleeping soundly in the crib. => 510,6858,310,14343,3590,314,275,253,260,725,15\nI need to finish my homework before dinner. => 42,878,281,8416,619,32110,1078,8955,15\nThey are organizing a charity event next month. => 3726,403,26169,247,19489,2362,1735,1770,15\nShe is cooking a delicious meal for us. => 2993,310,12398,247,17319,11484,323,441,15\nWe should go hiking in the mountains. => 1231,943,564,33061,275,253,14700,15\nThe car broke down on the way to work. => 510,1113,9377,1066,327,253,1039,281,789,15\nHe loves playing video games in his free time. => 1328,14528,4882,3492,3958,275,521,1959,673,15\nThe birds are chirping in the trees. => 510,11260,403,36494,14650,275,253,7139,15\nI want to learn how to play the piano. 
=> 42,971,281,3037,849,281,1132,253,18542,15\nThey are building a new shopping mall in the city. => 3726,403,3652,247,747,12701,28974,275,253,2846,15\nShe is writing a novel in her spare time. => 2993,310,4028,247,4460,275,617,18345,673,15\nWe are going to the zoo this Saturday. => 1231,403,1469,281,253,41089,436,7814,15\nThe cake looks delicious with chocolate frosting. => 510,15221,4453,17319,342,14354,34724,272,15\nHe is a talented painter who sells his artwork. => 1328,310,247,21220,27343,665,27924,521,28227,15\nThe students are studying for their exams. => 510,3484,403,12392,323,616,34666,15\nI enjoy swimming in the ocean. => 42,4264,17120,275,253,12927,15\nThey are renovating their house. => 3726,403,30074,839,616,2419,15\nShe is practicing yoga to stay healthy. => 2993,310,25815,25551,281,3297,5875,15\nWe should plant flowers in the garden. => 1231,943,4444,12405,275,253,10329,15\nThe traffic is heavy during rush hour. => 510,7137,310,5536,1309,16949,4964,15\nHe is a skilled chef who creates amazing dishes. => 1328,310,247,18024,26540,665,10513,8644,17114,15\nThe baby is crawling on the floor. => 510,6858,310,44922,327,253,5254,15\nI need to buy a new pair of shoes. => 42,878,281,4489,247,747,4667,273,12682,15\nThey are going on a road trip across the country. => 3726,403,1469,327,247,3971,7408,2439,253,2586,15\nShe is playing the piano beautifully. => 2993,310,4882,253,18542,27839,15\nWe are going to a concert tomorrow night. => 1231,403,1469,281,247,12699,10873,2360,15\nThe cake tastes delicious with vanilla frosting. => 510,15221,27491,17319,342,26724,34724,272,15\nHe is a dedicated teacher who inspires his students. => 1328,310,247,9940,9732,665,6381,2731,521,3484,15\nThe students are participating in a science fair. => 510,3484,403,15299,275,247,5859,4344,15\nI enjoy hiking in the mountains. => 42,4264,33061,275,253,14700,15\nThey are organizing a beach cleanup next weekend. 
=> 3726,403,26169,247,11600,34709,1735,8849,15\nShe is taking photographs of nature. => 2993,310,3192,15928,273,3753,15\nWe should try a new restaurant in town. => 1231,943,1611,247,747,10301,275,3874,15\nThe traffic is moving slowly on the highway. => 510,7137,310,4886,7808,327,253,17657,15\nHe is a talented singer with a beautiful voice. => 1328,310,247,21220,16057,342,247,5389,4318,15\nThe baby is laughing and giggling. => 510,6858,310,17053,285,41542,1981,15\nI need to do laundry and wash my clothes. => 42,878,281,513,29023,285,14841,619,10015,15\nThey are planning a trip to Europe. => 3726,403,7219,247,7408,281,3060,15\nShe is learning how to play the guitar. => 2993,310,4715,849,281,1132,253,12609,15\nWe are going to a museum this Sunday. => 1231,403,1469,281,247,16064,436,6926,15\nThe coffee smells amazing in the morning. => 510,8574,34247,8644,275,253,4131,15\nHe is a hardworking farmer who grows crops. => 1328,310,247,1892,21107,24718,665,17202,19492,15\nThe students are presenting their research projects. => 510,3484,403,15250,616,2561,6493,15\nI enjoy playing soccer with my friends. => 42,4264,4882,20391,342,619,3858,15\nThey are volunteering at a local shelter. => 3726,403,10057,2158,387,247,1980,17824,15\nShe is practicing martial arts for self-defense. => 2993,310,25815,29731,14635,323,1881,14,29337,15\nWe should try a new recipe for dinner. => 1231,943,1611,247,747,13612,323,8955,15\nThe traffic is congest => 510,7137,310,25801\nThe sun is shining brightly today. => 510,5101,310,28115,43925,3063,15\nI enjoy reading books in my free time. => 42,4264,4361,5098,275,619,1959,673,15\nShe plays the piano beautifully. => 2993,7120,253,18542,27839,15\nThe cat chased the mouse around the room. => 510,5798,40754,253,6521,1475,253,2316,15\nI love eating pizza with extra cheese. => 42,2389,9123,22534,342,4465,12173,15\nHe always wears a hat wherever he goes. => 1328,1900,31394,247,7856,20312,344,4566,15\nThe flowers in the garden are blooming. 
=> 510,12405,275,253,10329,403,30601,272,15\nShe danced gracefully on the stage. => 2993,39860,14426,2920,327,253,3924,15\nThe dog barked loudly in the park. => 510,4370,21939,264,31311,275,253,5603,15\nWe went swimming in the ocean yesterday. => 1231,2427,17120,275,253,12927,11066,15\nHe speaks fluent French and Spanish. => 1328,16544,2938,290,5112,285,9883,15\nThe train arrived at the station on time. => 510,6194,7244,387,253,4660,327,673,15\nShe cooked a delicious meal for her family. => 2993,18621,247,17319,11484,323,617,2021,15\n"
  },
  {
    "path": "examples/prompts/polyglot-ko.txt",
    "content": "이것은 테스트 이다. => 12271,296,6474,28037,17\n걱정할 필요 없다. => 18311,482,1062,550,267,17\n버그는 언젠가 고쳐진다. => 6904,272,8575,10381,1765,17\n"
  },
  {
    "path": "examples/prompts/replit.txt",
    "content": "Hello World! => 6466,147,2317,350\nI can't believe it's already Friday!\" => 286,512,172,185,13392,393,172,155,3239,147,29249,8537\nThe URL for the website is https://www.example.com.\" => 505,5635,250,170,11745,235,147,303,262,552,148,811,148,241,148,161\n\"She said, 'I love to travel.'\" => 161,10386,4089,150,206,286,8440,194,147,12363,148,172,161\n'The temperature is 25.5°C.' => 172,505,147,9502,235,147,20022,8516,228,148,172\n\"Let's meet at 2:30 p.m. in the park.\" => 161,8997,172,155,17120,536,147,162,5245,147,207,148,204,148,219,170,147,17664,148,161\nThe book costs $19.99 => 505,147,2277,17494,236,166,11824\n\"John's favorite color is blue.\" => 161,7475,172,155,147,11105,147,349,235,17046,148,161\nTh@nk y0u f0r y0ur h3lp! => 6309,240,9019,147,237,159,247,147,202,159,223,147,237,159,2458,147,226,171,3899,350\nC@n I g3t a c0ffee, pl3@se? => 228,240,211,398,147,267,171,185,216,147,196,159,13360,163,150,147,1287,171,240,155,163,272\nW0w! Th@t's @m@zing! => 450,159,274,350,147,6309,240,185,172,155,268,204,240,301,248,350\nH0w 4re y0u t0d@y? => 304,159,274,320,440,147,237,159,247,147,185,159,182,240,237,272\nI l0ve t0 tr@vel @r0und the w0rld. => 286,997,159,1290,147,185,159,147,490,240,3893,268,223,159,3981,170,147,274,159,223,2833,148\nWh@t's y0ur f@v0rite m0vie? => 450,226,240,185,172,155,147,237,159,2458,147,202,240,252,159,5961,163,147,204,159,24373,272\nThe cat is sleeping on the mat. => 505,147,1604,235,147,3987,248,347,170,147,1297,148\nI need to buy some groceries for dinner. => 286,1645,194,147,8068,1499,147,10022,1037,10023,250,147,182,2749,148\nThe sun is shining brightly in the sky. => 505,147,5852,235,147,7304,2967,147,215,649,391,219,170,147,7310,148\nShe is reading a book in the park. => 10386,235,9838,216,147,2277,219,170,147,17664,148\nWe went for a walk on the beach yesterday. => 3250,10825,250,216,147,8156,347,170,294,5371,147,28830,148\nHe plays the guitar like a pro. 
=> 5301,7084,155,170,147,4604,2214,1425,216,3474,148\nThey are going to the movies tonight. => 18815,429,6552,194,170,147,15877,194,7907,148\nThe flowers are blooming in the garden. => 505,147,22953,155,429,147,10411,2799,248,219,170,147,22140,148\nI enjoy listening to classical music. => 286,23162,15876,248,194,239,4251,147,7395,148\nWe need to buy groceries for the week. => 3250,1645,194,147,8068,147,10022,1037,10023,250,170,9238,148\nThe dog is chasing its tail in circles. => 505,147,6540,235,147,196,916,248,1602,147,5129,219,147,4095,155,148\nShe is wearing a beautiful red dress. => 10386,235,147,16427,248,216,147,23447,147,1160,147,14592,148\nHe is a talented actor in Hollywood. => 5301,235,216,147,29750,246,147,5112,219,147,16924,391,10477,148\nThe children are playing in the playground. => 505,7934,429,7084,248,219,170,7084,12055,148\nI'm going to visit my grandparents this weekend. => 286,172,204,6552,194,9939,1247,147,11806,12019,291,9238,314,148\nThe coffee tastes bitter without sugar. => 505,147,21526,147,20931,155,5145,1430,1988,147,28759,148\nThey are planning a surprise party for her. => 18815,429,147,23661,216,147,29240,147,7344,250,1869,148\nShe sings like an angel on stage. => 10386,147,155,6502,1425,426,147,26028,347,12685,148\nWe should take a vacation to relax. => 3250,936,4654,216,147,15388,946,194,1998,2744,148\nHe is studying medicine at the university. => 5301,235,7959,248,147,20742,1668,536,170,147,8025,148\nThe rain is pouring heavily outside. => 505,147,6885,235,5306,248,1189,5451,391,8096,148\nI enjoy watching romantic movies. => 286,23162,147,3355,248,147,26080,4140,147,15877,148\nThey are celebrating their anniversary today. => 18815,429,147,30000,5841,1669,147,24734,5464,1770,13386,148\nShe dances gracefully to the music. => 10386,147,182,1626,155,147,267,8771,8001,194,170,147,7395,148\nHe is an excellent basketball player. => 5301,235,426,147,12300,675,185,147,26646,5132,6294,148\nThe baby is sleeping soundly in the crib. 
=> 505,147,23597,235,147,3987,248,12642,391,219,170,147,7696,215,148\nI need to finish my homework before dinner. => 286,1645,194,147,6717,1247,147,1071,2722,2643,147,182,2749,148\nThey are organizing a charity event next month. => 18815,429,147,16442,248,216,1054,1511,1663,2399,12821,148\nShe is cooking a delicious meal for us. => 10386,235,147,20453,248,216,3936,23455,147,26658,250,147,539,148\nWe should go hiking in the mountains. => 3250,936,4242,147,2254,5357,219,170,147,204,18028,155,148\nThe car broke down on the way to work. => 505,7553,147,510,10036,4288,347,170,3699,194,1916,148\nHe loves playing video games in his free time. => 5301,8440,155,7084,248,8722,147,11281,219,1439,4002,801,148\nThe birds are chirping in the trees. => 505,147,13043,155,429,147,3904,223,4639,219,170,5311,155,148\nI want to learn how to play the piano. => 286,1857,194,14167,2496,194,7084,170,147,207,23635,148\nThey are building a new shopping mall in the city. => 18815,429,11038,216,277,147,22184,147,204,609,219,170,147,2416,148\nShe is writing a novel in her spare time. => 10386,235,3242,216,147,25814,219,1869,6772,2382,801,148\nWe are going to the zoo this Saturday. => 3250,429,6552,194,170,147,25101,291,147,31426,148\nThe cake looks delicious with chocolate frosting. => 505,147,24422,16303,3936,23455,312,147,5619,533,2239,147,202,3973,3431,148\nHe is a talented painter who sells his artwork. => 5301,235,216,147,29750,246,147,9226,279,2888,13004,155,1439,12234,2722,148\nThe students are studying for their exams. => 505,15707,429,7959,248,250,1669,147,12398,155,148\nI enjoy swimming in the ocean. => 286,23162,147,4729,8528,248,219,170,147,26193,148\nThey are renovating their house. => 18815,429,991,10724,3643,1669,13788,148\nShe is practicing yoga to stay healthy. => 10386,235,147,18453,248,147,5063,1186,194,15344,147,28550,148\nWe should plant flowers in the garden. => 3250,936,147,9212,147,22953,155,219,170,147,22140,148\nThe traffic is heavy during rush hour. 
=> 505,147,11097,235,147,22232,4340,147,22319,147,5686,148\nHe is a skilled chef who creates amazing dishes. => 5301,235,216,147,8891,246,9784,202,2888,13720,147,28880,147,23852,383,148\nThe baby is crawling on the floor. => 505,147,23597,235,147,22120,248,347,170,147,5895,148\nI need to buy a new pair of shoes. => 286,1645,194,147,8068,216,277,12632,210,147,155,21953,155,148\nThey are going on a road trip across the country. => 18815,429,6552,347,216,147,6362,147,11395,9762,170,11305,148\nShe is playing the piano beautifully. => 10386,235,7084,248,170,147,207,23635,147,23447,391,148\nWe are going to a concert tomorrow night. => 3250,429,6552,194,216,1710,4391,29524,12716,148\nThe cake tastes delicious with vanilla frosting. => 505,147,24422,147,20931,155,3936,23455,312,5535,7476,147,202,3973,3431,148\nHe is a dedicated teacher who inspires his students. => 5301,235,216,326,8298,3460,147,9675,2888,147,28801,155,1439,15707,148\nThe students are participating in a science fair. => 505,15707,429,147,30961,3643,219,216,147,10587,147,7636,148\nI enjoy hiking in the mountains. => 286,23162,147,2254,5357,219,170,147,204,18028,155,148\nThey are organizing a beach cleanup next weekend. => 18815,429,147,16442,248,216,294,5371,147,10401,2399,9238,314,148\nShe is taking photographs of nature. => 10386,235,147,12345,147,4709,1547,155,210,147,211,8603,148\nWe should try a new restaurant in town. => 3250,936,147,746,216,277,147,11007,219,147,10200,148\nThe traffic is moving slowly on the highway. => 505,147,11097,235,147,8601,147,9880,391,347,170,5976,3330,148\nHe is a talented singer with a beautiful voice. => 5301,235,216,147,29750,246,147,155,248,279,312,216,147,23447,147,9316,148\nThe baby is laughing and giggling. => 505,147,23597,235,147,23066,248,221,147,2341,3631,2869,148\nI need to do laundry and wash my clothes. => 286,1645,194,543,960,3981,2154,221,147,27589,1247,147,22141,383,148\nThey are planning a trip to Europe. 
=> 18815,429,147,23661,216,147,11395,194,13131,148\nShe is learning how to play the guitar. => 10386,235,11754,2496,194,7084,170,147,4604,2214,148\nWe are going to a museum this Sunday. => 3250,429,6552,194,216,147,204,433,1177,291,147,29111,148\nThe coffee smells amazing in the morning. => 505,147,21526,31454,155,147,28880,219,170,20701,148\nHe is a hardworking farmer who grows crops. => 5301,235,216,8524,14992,147,16679,279,2888,147,6044,155,147,8650,155,148\nThe students are presenting their research projects. => 505,15707,429,5130,248,1669,13217,14235,148\nI enjoy playing soccer with my friends. => 286,23162,7084,248,147,9351,5318,312,1247,147,5347,155,148\nThey are volunteering at a local shelter. => 18815,429,147,5238,7478,163,12798,536,216,2491,2905,1359,279,148\nShe is practicing martial arts for self-defense. => 10386,235,147,18453,248,147,3261,185,4381,12234,155,250,623,153,29896,148\nWe should try a new recipe for dinner. => 3250,936,147,746,216,277,147,9851,250,147,182,2749,148\nThe traffic is congest => 505,147,11097,235,1710,14169\nThe sun is shining brightly today. => 505,147,5852,235,147,7304,2967,147,215,649,391,13386,148\nI enjoy reading books in my free time. => 286,23162,9838,147,9670,219,1247,4002,801,148\nShe plays the piano beautifully. => 10386,7084,155,170,147,207,23635,147,23447,391,148\nThe cat chased the mouse around the room. => 505,147,1604,147,196,916,246,170,12551,6890,170,9654,148\nI love eating pizza with extra cheese. => 286,8440,147,163,3643,147,207,8403,312,8230,9784,383,163,148\nHe always wears a hat wherever he goes. => 5301,5418,147,16427,155,216,147,4879,2171,2433,1189,16177,148\nThe flowers in the garden are blooming. => 505,147,22953,155,219,170,147,22140,429,147,10411,2799,248,148\nShe danced gracefully on the stage. => 10386,13378,12408,147,267,8771,8001,347,170,12685,148\nThe dog barked loudly in the park. => 505,147,6540,147,973,293,246,147,30182,391,219,170,147,17664,148\nWe went swimming in the ocean yesterday. 
=> 3250,10825,147,4729,8528,248,219,170,147,26193,147,28830,148\nHe speaks fluent French and Spanish. => 5301,147,13285,155,147,21677,147,254,17590,221,147,31519,148\nThe train arrived at the station on time. => 505,147,872,147,20712,182,536,170,147,7184,347,801,148\nShe cooked a delicious meal for her family. => 10386,147,20453,246,216,3936,23455,147,26658,250,1869,147,2002,148\n"
  },
  {
    "path": "examples/prompts/starcoder.txt",
    "content": "Hello World! => 8279,10896,19\nI can't believe it's already Friday!\" => 59,883,1330,13710,561,1182,3425,506,25674,11555\nThe URL for the website is https://www.example.com.\" => 1318,3834,436,322,9575,438,1678,555,1499,32,2763,32,508,3107\n\"She said, 'I love to travel.'\" => 20,25387,9884,30,330,59,14290,372,25283,29329\n'The temperature is 25.5°C.' => 25,1318,13587,438,225,36,39,32,39,23767,53,4564\n\"Let's meet at 2:30 p.m. in the park.\" => 20,9809,1182,18450,821,225,36,44,37,34,298,32,95,32,328,322,880,93,3107\nThe book costs $19.99 => 1318,7618,25950,398,35,43,32,43,43\n\"John's favorite color is blue.\" => 20,19693,1182,27448,1963,438,10087,3107\nTh@nk y0u f0r y0ur h3lp! => 1027,50,19877,533,34,103,296,34,100,533,34,305,420,37,1915,19\nC@n I g3t a c0ffee, pl3@se? => 53,50,96,439,485,37,102,312,281,34,21298,30,1278,37,50,277,49\nW0w! Th@t's @m@zing! => 73,34,105,19,947,50,102,1182,477,95,50,26768,19\nH0w 4re y0u t0d@y? => 58,34,105,225,38,268,533,34,103,273,34,86,50,107,49\nI l0ve t0 tr@vel @r0und the w0rld. => 59,456,34,587,273,34,554,50,1203,477,100,34,642,322,341,34,100,1381,32\nWh@t's y0ur f@v0rite m0vie? => 2444,50,102,1182,533,34,305,296,50,104,34,1049,345,34,104,1075,49\nThe cat is sleeping on the mat. => 1318,10501,438,9368,299,544,322,2491,32\nI need to buy some groceries for dinner. => 59,1849,372,16968,1629,20234,85,6958,436,343,3369,32\nThe sun is shining brightly in the sky. => 1318,15323,438,787,19068,38231,631,328,322,26718,32\nShe is reading a book in the park. => 25387,438,9175,312,7618,328,322,880,93,32\nWe went for a walk on the beach yesterday. => 3122,14236,436,312,13503,544,322,526,867,39485,32\nHe plays the guitar like a pro. => 1331,41271,322,3932,19931,2124,312,534,32\nThey are going to the movies tonight. => 31805,884,6783,372,322,27889,26076,694,32\nThe flowers are blooming in the garden. => 1318,7290,483,884,323,18466,299,328,322,485,22461,32\nI enjoy listening to classical music. 
=> 59,31567,20498,372,443,1578,17522,32\nWe need to buy groceries for the week. => 3122,1849,372,16968,20234,85,6958,436,322,8209,32\nThe dog is chasing its tail in circles. => 1318,27435,438,663,9949,2819,13203,328,46428,32\nShe is wearing a beautiful red dress. => 25387,438,996,6992,312,36493,3346,343,714,32\nHe is a talented actor in Hollywood. => 1331,438,312,273,9556,318,16038,328,48228,631,21118,32\nThe children are playing in the playground. => 1318,5713,884,19788,328,322,4654,1749,32\nI'm going to visit my grandparents this weekend. => 59,3464,6783,372,7725,1672,33162,19277,458,40618,32\nThe coffee tastes bitter without sugar. => 1318,36917,273,633,307,3493,391,2876,309,18628,32\nThey are planning a surprise party for her. => 31805,884,26116,312,6178,9251,15270,436,7791,32\nShe sings like an angel on stage. => 25387,309,2052,2124,600,600,17691,544,10019,32\nWe should take a vacation to relax. => 3122,1395,4818,312,29164,367,372,41972,32\nHe is studying medicine at the university. => 1331,438,14866,299,32388,482,821,322,707,9190,32\nThe rain is pouring heavily outside. => 1318,36987,438,9202,299,46003,2801,11127,32\nI enjoy watching romantic movies. => 59,31567,37652,26045,7268,27889,32\nThey are celebrating their anniversary today. => 31805,884,48278,839,1741,3623,23921,5810,672,11610,32\nShe dances gracefully to the music. => 25387,343,3151,31376,4938,372,322,17522,32\nHe is an excellent basketball player. => 1331,438,600,39203,48400,11653,4362,32\nThe baby is sleeping soundly in the crib. => 1318,323,17156,438,9368,299,9934,631,328,322,281,7972,32\nI need to finish my homework before dinner. => 59,1849,372,11361,1672,6765,1007,2670,343,3369,32\nThey are organizing a charity event next month. => 31805,884,10558,6183,312,1351,543,1692,2354,6811,32\nShe is cooking a delicious meal for us. => 25387,438,23682,299,312,409,406,2406,597,279,436,1770,32\nWe should go hiking in the mountains. 
=> 3122,1395,1983,420,1546,299,328,322,10874,1907,32\nThe car broke down on the way to work. => 1318,6346,43289,2835,544,322,3352,372,1389,32\nHe loves playing video games in his free time. => 1331,598,4954,19788,6027,19705,328,6697,3741,1133,32\nThe birds are chirping in the trees. => 1318,8424,3210,884,663,476,7075,328,322,23453,32\nI want to learn how to play the piano. => 59,2637,372,7350,2624,372,4654,322,298,25757,32\nThey are building a new shopping mall in the city. => 31805,884,9038,312,537,40692,345,464,328,322,11297,32\nShe is writing a novel in her spare time. => 25387,438,4127,312,32913,328,7791,1869,586,1133,32\nWe are going to the zoo this Saturday. => 3122,884,6783,372,322,1288,604,458,358,30288,32\nThe cake looks delicious with chocolate frosting. => 1318,281,1062,7780,409,406,2406,623,10408,27589,296,20932,299,32\nHe is a talented painter who sells his artwork. => 1331,438,312,273,9556,318,42300,6560,10800,101,6697,5549,1007,32\nThe students are studying for their exams. => 1318,16512,884,14866,299,436,3623,538,1462,32\nI enjoy swimming in the ocean. => 59,31567,2535,449,6714,328,322,337,18857,32\nThey are renovating their house. => 31805,884,316,15007,1741,3623,17075,32\nShe is practicing yoga to stay healthy. => 25387,438,11808,11636,533,40067,372,20005,44538,32\nWe should plant flowers in the garden. => 3122,1395,26795,7290,483,328,322,485,22461,32\nThe traffic is heavy during rush hour. => 1318,16391,438,32389,5929,540,1372,12021,32\nHe is a skilled chef who creates amazing dishes. => 1331,438,312,3001,12088,44051,6560,9585,36986,1214,4279,32\nThe baby is crawling on the floor. => 1318,323,17156,438,281,1294,2920,544,322,17648,32\nI need to buy a new pair of shoes. => 59,1849,372,16968,312,537,6092,432,787,37764,32\nThey are going on a road trip across the country. => 31805,884,6783,544,312,24122,19337,10160,322,10769,32\nShe is playing the piano beautifully. 
=> 25387,438,19788,322,298,25757,526,4846,325,514,107,32\nWe are going to a concert tomorrow night. => 3122,884,6783,372,312,457,6989,31841,19212,32\nThe cake tastes delicious with vanilla frosting. => 1318,281,1062,273,633,307,409,406,2406,623,44653,296,20932,299,32\nHe is a dedicated teacher who inspires his students. => 1331,438,312,23112,30877,6560,26194,8017,6697,16512,32\nThe students are participating in a science fair. => 1318,16512,884,24623,1741,328,312,27536,19375,32\nI enjoy hiking in the mountains. => 59,31567,420,1546,299,328,322,10874,1907,32\nThey are organizing a beach cleanup next weekend. => 31805,884,10558,6183,312,526,867,13144,2354,40618,32\nShe is taking photographs of nature. => 25387,438,15137,15110,23626,432,24406,32\nWe should try a new restaurant in town. => 3122,1395,1596,312,537,43719,328,38212,32\nThe traffic is moving slowly on the highway. => 1318,16391,438,14089,12899,631,544,322,3857,3073,32\nHe is a talented singer with a beautiful voice. => 1331,438,312,273,9556,318,309,10118,623,312,36493,20309,32\nThe baby is laughing and giggling. => 1318,323,17156,438,2317,2943,299,461,485,365,36088,32\nI need to do laundry and wash my clothes. => 59,1849,372,745,2317,642,994,461,341,917,1672,7375,46948,32\nThey are planning a trip to Europe. => 31805,884,26116,312,19337,372,27268,32\nShe is learning how to play the guitar. => 25387,438,9608,2624,372,4654,322,3932,19931,32\nWe are going to a museum this Sunday. => 3122,884,6783,372,312,345,539,378,458,358,28036,32\nThe coffee smells amazing in the morning. => 1318,36917,309,42153,101,36986,328,322,33768,32\nHe is a hardworking farmer who grows crops. => 1331,438,312,6784,13578,9019,2302,6560,485,2138,25170,1069,32\nThe students are presenting their research projects. => 1318,16512,884,5024,299,3623,13234,8528,32\nI enjoy playing soccer with my friends. => 59,31567,19788,22682,10035,623,1672,22523,32\nThey are volunteering at a local shelter. 
=> 31805,884,3920,45585,8637,821,312,2196,309,2542,391,32\nShe is practicing martial arts for self-defense. => 25387,438,11808,11636,345,502,564,5549,101,436,630,31,43694,32\nWe should try a new recipe for dinner. => 3122,1395,1596,312,537,15233,436,343,3369,32\nThe traffic is congest => 1318,16391,438,457,2776\nThe sun is shining brightly today. => 1318,15323,438,787,19068,38231,631,11610,32\nI enjoy reading books in my free time. => 59,31567,9175,21739,328,1672,3741,1133,32\nShe plays the piano beautifully. => 25387,41271,322,298,25757,526,4846,325,514,107,32\nThe cat chased the mouse around the room. => 1318,10501,663,16109,322,8459,6835,322,8355,32\nI love eating pizza with extra cheese. => 59,14290,484,1741,47630,623,6717,8277,30315,32\nHe always wears a hat wherever he goes. => 1331,5182,996,4177,312,25793,2154,424,938,13107,32\nThe flowers in the garden are blooming. => 1318,7290,483,328,322,485,22461,884,323,18466,299,32\nShe danced gracefully on the stage. => 25387,343,6087,31376,4938,544,322,10019,32\nThe dog barked loudly in the park. => 1318,27435,323,1087,318,598,836,631,328,322,880,93,32\nWe went swimming in the ocean yesterday. => 3122,14236,2535,449,6714,328,322,337,18857,39485,32\nHe speaks fluent French and Spanish. => 1331,24498,101,38055,43652,461,14911,1708,32\nThe train arrived at the station on time. => 1318,5683,2099,32114,821,322,18662,544,1133,32\nShe cooked a delicious meal for her family. => 25387,23682,318,312,409,406,2406,597,279,436,7791,13872,32\n"
  },
  {
    "path": "examples/prompts/test-cases.txt",
    "content": "# test case format\n# <language>: <sentence>\n\nEnglish: Hello World!\nEnglish: I can't believe it's already Friday!\"\nEnglish: The URL for the website is https://www.example.com.\"\nEnglish: \"She said, 'I love to travel.'\"\nEnglish: 'The temperature is 25.5°C.'\nEnglish: \"Let's meet at 2:30 p.m. in the park.\"\nEnglish: The book costs $19.99\nEnglish: \"John's favorite color is blue.\"\nEnglish: Th@nk y0u f0r y0ur h3lp!\nEnglish: C@n I g3t a c0ffee, pl3@se?\nEnglish: W0w! Th@t's @m@zing!\nEnglish: H0w 4re y0u t0d@y?\nEnglish: I l0ve t0 tr@vel @r0und the w0rld.\nEnglish: Wh@t's y0ur f@v0rite m0vie?\nEnglish: The cat is sleeping on the mat.\nEnglish: I need to buy some groceries for dinner.\nEnglish: The sun is shining brightly in the sky.\nEnglish: She is reading a book in the park.\nEnglish: We went for a walk on the beach yesterday.\nEnglish: He plays the guitar like a pro.\nEnglish: They are going to the movies tonight.\nEnglish: The flowers are blooming in the garden.\nEnglish: I enjoy listening to classical music.\nEnglish: We need to buy groceries for the week.\nEnglish: The dog is chasing its tail in circles.\nEnglish: She is wearing a beautiful red dress.\nEnglish: He is a talented actor in Hollywood.\nEnglish: The children are playing in the playground.\nEnglish: I'm going to visit my grandparents this weekend.\nEnglish: The coffee tastes bitter without sugar.\nEnglish: They are planning a surprise party for her.\nEnglish: She sings like an angel on stage.\nEnglish: We should take a vacation to relax.\nEnglish: He is studying medicine at the university.\nEnglish: The rain is pouring heavily outside.\nEnglish: I enjoy watching romantic movies.\nEnglish: They are celebrating their anniversary today.\nEnglish: She dances gracefully to the music.\nEnglish: He is an excellent basketball player.\nEnglish: The baby is sleeping soundly in the crib.\nEnglish: I need to finish my homework before dinner.\nEnglish: They are organizing a charity 
event next month.\nEnglish: She is cooking a delicious meal for us.\nEnglish: We should go hiking in the mountains.\nEnglish: The car broke down on the way to work.\nEnglish: He loves playing video games in his free time.\nEnglish: The birds are chirping in the trees.\nEnglish: I want to learn how to play the piano.\nEnglish: They are building a new shopping mall in the city.\nEnglish: She is writing a novel in her spare time.\nEnglish: We are going to the zoo this Saturday.\nEnglish: The cake looks delicious with chocolate frosting.\nEnglish: He is a talented painter who sells his artwork.\nEnglish: The students are studying for their exams.\nEnglish: I enjoy swimming in the ocean.\nEnglish: They are renovating their house.\nEnglish: She is practicing yoga to stay healthy.\nEnglish: We should plant flowers in the garden.\nEnglish: The traffic is heavy during rush hour.\nEnglish: He is a skilled chef who creates amazing dishes.\nEnglish: The baby is crawling on the floor.\nEnglish: I need to buy a new pair of shoes.\nEnglish: They are going on a road trip across the country.\nEnglish: She is playing the piano beautifully.\nEnglish: We are going to a concert tomorrow night.\nEnglish: The cake tastes delicious with vanilla frosting.\nEnglish: He is a dedicated teacher who inspires his students.\nEnglish: The students are participating in a science fair.\nEnglish: I enjoy hiking in the mountains.\nEnglish: They are organizing a beach cleanup next weekend.\nEnglish: She is taking photographs of nature.\nEnglish: We should try a new restaurant in town.\nEnglish: The traffic is moving slowly on the highway.\nEnglish: He is a talented singer with a beautiful voice.\nEnglish: The baby is laughing and giggling.\nEnglish: I need to do laundry and wash my clothes.\nEnglish: They are planning a trip to Europe.\nEnglish: She is learning how to play the guitar.\nEnglish: We are going to a museum this Sunday.\nEnglish: The coffee smells amazing in the morning.\nEnglish: He is a 
hardworking farmer who grows crops.\nEnglish: The students are presenting their research projects.\nEnglish: I enjoy playing soccer with my friends.\nEnglish: They are volunteering at a local shelter.\nEnglish: She is practicing martial arts for self-defense.\nEnglish: We should try a new recipe for dinner.\nEnglish: The traffic is congest\nEnglish: The sun is shining brightly today.\nEnglish: I enjoy reading books in my free time.\nEnglish: She plays the piano beautifully.\nEnglish: The cat chased the mouse around the room.\nEnglish: I love eating pizza with extra cheese.\nEnglish: He always wears a hat wherever he goes.\nEnglish: The flowers in the garden are blooming.\nEnglish: She danced gracefully on the stage.\nEnglish: The dog barked loudly in the park.\nEnglish: We went swimming in the ocean yesterday.\nEnglish: He speaks fluent French and Spanish.\nEnglish: The train arrived at the station on time.\nEnglish: She cooked a delicious meal for her family.\nKorean: 이것은 테스트 이다.\nKorean: 걱정할 필요 없다.\nKorean: 버그는 언젠가 고쳐진다.\nJapanese: 明日の天気はどうですか。\nChinese: 请问洗手间在哪里？\nEmoji: I'm feeling 😄 today!\nUnicode: ◑ ▢ ▣ ◱"
  },
  {
    "path": "examples/prompts/tokenize_huggingface.py",
    "content": "import os\nfrom transformers import AutoTokenizer\n\nos.environ['TOKENIZERS_PARALLELISM'] = \"false\"\n\nlist_repo_hf  = [\"databricks/dolly-v2-3b\",           # dolly-v2 (3b, 7b, 12b models share the same tokenizer)\n                 \"gpt2\",                             # gpt-2 (gpt2-xl, gpt2-large share the same tokenizer)\n                 \"uer/gpt2-chinese-cluecorpussmall\", # gpt-2-chinese\n                 \"EleutherAI/gpt-j-6b\",              # gpt-j\n                 \"EleutherAI/gpt-neox-20b\",          # gpt-neox\n                 \"EleutherAI/polyglot-ko-1.3b\",      # gpt-neox (polyglot-ko 5.8b and 12.8b share the same tokenizer\")\n                 \"rinna/japanese-gpt-neox-3.6b\",     # gpt-neox\n                 # mpt-7b (uses gpt-neox-20b tokenizer)\n                 \"replit/replit-code-v1-3b\",         # replit\n                 \"bigcode/starcoder\",                # starcoder (huggingface-cli login required)\n                 \"openai/whisper-tiny\"               # whisper (base, large, large-v2 share the same tokenizer)\n                 ]\n\nrepo2ggml     = {\"databricks/dolly-v2-3b\"           : \"dolly-v2\",\n                 \"gpt2\"                             : \"gpt-2\",\n                 \"uer/gpt2-chinese-cluecorpussmall\" : \"gpt-2-chinese\",\n                 \"EleutherAI/gpt-j-6b\"              : \"gpt-j\",\n                 \"EleutherAI/gpt-neox-20b\"          : \"gpt-neox\",\n                 \"EleutherAI/polyglot-ko-1.3b\"      : \"polyglot-ko\",\n                 \"rinna/japanese-gpt-neox-3.6b\"     : \"gpt-neox-japanese\",\n                 \"replit/replit-code-v1-3b\"         : \"replit\",\n                 \"bigcode/starcoder\"                : \"starcoder\",\n                 \"openai/whisper-tiny\"              : \"whisper\"}\n\nrepo2language = {\"databricks/dolly-v2-3b\"           : \"english\",\n                 \"gpt2\"                             : \"english\",\n                 
\"uer/gpt2-chinese-cluecorpussmall\" : \"chinese\",\n                 \"EleutherAI/gpt-j-6b\"              : \"english\",\n                 \"EleutherAI/gpt-neox-20b\"          : \"english\",\n                 \"EleutherAI/polyglot-ko-1.3b\"      : \"korean\",\n                 \"rinna/japanese-gpt-neox-3.6b\"     : \"japanese\",\n                 \"replit/replit-code-v1-3b\"         : \"english\",\n                 \"bigcode/starcoder\"                : \"english\",\n                 \"openai/whisper-tiny\"              : \"english\"}\n\ndelimeter = \": \"\ntest_sentences = []\nwith open(\"test-cases.txt\", \"r\") as f:\n    lines = [l.rstrip() for l in f.readlines()]\n    for l in lines:\n        if delimeter in l:\n            language = l[:l.index(delimeter)]\n            sentence = l[l.index(delimeter) + len(delimeter):]\n            test_sentences.append((language.lower(), sentence))\n\nfor repo in list_repo_hf:\n\n    target_language = repo2language[repo]\n\n    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)\n\n    tokens_hf = []\n    for language, sentence in test_sentences:\n        if language == target_language:\n            tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))\n            tokens_hf.append((sentence, tokens))\n\n    save_txt = repo2ggml[repo] + \".txt\"\n    with open(save_txt, \"w\") as f:\n        f.writelines([sentence + \" => \" + \",\".join(str(t) for t in tokens) + \"\\n\" for sentence, tokens in tokens_hf])\n"
  },
  {
    "path": "examples/prompts/whisper.txt",
    "content": "Hello World! => 15947,3937,0\nI can't believe it's already Friday!\" => 40,393,380,1697,309,311,1217,6984,2963\nThe URL for the website is https://www.example.com.\" => 2278,12905,337,220,3322,3144,307,34426,21492,17919,13,3121,335,781,13,1112,889\n\"She said, 'I love to travel.'\" => 1,9526,848,11,922,40,959,220,1353,220,17227,779,28763\n'The temperature is 25.5°C.' => 6,2278,220,18275,610,1503,307,3552,13,20,11782,34,4443\n\"Let's meet at 2:30 p.m. in the park.\" => 1,8373,311,1677,412,568,25,3446,280,13,76,13,294,220,3322,3884,889\nThe book costs $19.99 => 2278,1446,5497,1848,3405,13,8494\n\"John's favorite color is blue.\" => 1,16938,311,2954,2017,307,3344,889\nTh@nk y0u f0r y0ur h3lp! => 2434,31,77,74,288,15,84,283,15,81,288,15,374,276,18,75,79,0\nC@n I g3t a c0ffee, pl3@se? => 34,31,77,286,290,18,83,257,269,15,4617,11,499,18,31,405,30\nW0w! Th@t's @m@zing! => 54,15,86,0,334,31,83,311,10428,76,31,8781,0\nH0w 4re y0u t0d@y? => 39,15,86,1017,265,288,15,84,220,83,15,67,31,88,30\nI l0ve t0 tr@vel @r0und the w0rld. => 40,287,15,303,220,83,15,220,6903,31,779,10428,81,15,997,220,3322,261,15,81,348,13\nWh@t's y0ur f@v0rite m0vie? => 2471,31,83,311,288,15,374,283,31,85,15,35002,275,15,12702,30\nThe cat is sleeping on the mat. => 2278,3857,307,8296,322,220,3322,3803,13\nI need to buy some groceries for dinner. => 40,643,220,1353,2256,512,31391,337,6148,13\nThe sun is shining brightly in the sky. => 2278,3295,307,18269,47418,294,220,3322,5443,13\nShe is reading a book in the park. => 9526,307,3760,257,1446,294,220,3322,3884,13\nWe went for a walk on the beach yesterday. => 4360,1437,337,257,1792,322,220,3322,7534,5186,13\nHe plays the guitar like a pro. => 5205,5749,220,3322,7531,411,257,447,13\nThey are going to the movies tonight. => 8829,366,516,220,1353,220,3322,6233,220,1756,397,13\nThe flowers are blooming in the garden. => 2278,8085,366,45294,294,220,3322,7431,13\nI enjoy listening to classical music. 
=> 40,2103,4764,220,1353,13735,1318,13\nWe need to buy groceries for the week. => 4360,643,220,1353,2256,31391,337,220,3322,1243,13\nThe dog is chasing its tail in circles. => 2278,3000,307,17876,1080,220,14430,294,13040,13\nShe is wearing a beautiful red dress. => 9526,307,4769,257,2238,2182,5231,13\nHe is a talented actor in Hollywood. => 5205,307,257,220,32831,6003,8747,294,11628,13\nThe children are playing in the playground. => 2278,2227,366,2433,294,220,3322,24646,13\nI'm going to visit my grandparents this weekend. => 40,478,516,220,1353,3441,452,21876,220,11176,6711,13\nThe coffee tastes bitter without sugar. => 2278,4982,220,83,40246,13871,1553,5076,13\nThey are planning a surprise party for her. => 8829,366,5038,257,6365,3595,337,720,13\nShe sings like an angel on stage. => 9526,23250,411,364,14250,322,3233,13\nWe should take a vacation to relax. => 4360,820,220,27612,257,12830,220,1353,5789,13\nHe is studying medicine at the university. => 5205,307,7601,7195,412,220,3322,5454,13\nThe rain is pouring heavily outside. => 2278,4830,307,20450,10950,2380,13\nI enjoy watching romantic movies. => 40,2103,1976,13590,6233,13\nThey are celebrating their anniversary today. => 8829,366,15252,220,3322,347,12962,220,83,378,320,13\nShe dances gracefully to the music. => 9526,28322,10042,2277,220,1353,220,3322,1318,13\nHe is an excellent basketball player. => 5205,307,364,7103,11767,4256,13\nThe baby is sleeping soundly in the crib. => 2278,3186,307,8296,1626,356,294,220,3322,47163,13\nI need to finish my homework before dinner. => 40,643,220,1353,2413,452,14578,949,6148,13\nThey are organizing a charity event next month. => 8829,366,17608,257,16863,2280,958,1618,13\nShe is cooking a delicious meal for us. => 9526,307,6361,257,4809,6791,337,505,13\nWe should go hiking in the mountains. => 4360,820,352,23784,294,220,3322,10233,13\nThe car broke down on the way to work. => 2278,1032,6902,760,322,220,3322,636,220,1353,589,13\nHe loves playing video games in his free time. 
=> 5205,6752,2433,960,2813,294,702,1737,220,3766,13\nThe birds are chirping in the trees. => 2278,9009,366,36682,294,220,3322,220,3599,279,13\nI want to learn how to play the piano. => 40,528,220,1353,1466,577,220,1353,862,220,3322,9211,13\nThey are building a new shopping mall in the city. => 8829,366,2390,257,777,8688,16026,294,220,3322,2307,13\nShe is writing a novel in her spare time. => 9526,307,3579,257,7613,294,720,13798,220,3766,13\nWe are going to the zoo this Saturday. => 4360,366,516,220,1353,220,3322,25347,220,11176,8803,13\nThe cake looks delicious with chocolate frosting. => 2278,5908,1542,4809,365,6215,37048,13\nHe is a talented painter who sells his artwork. => 5205,307,257,220,32831,6003,26619,567,20897,702,15829,13\nThe students are studying for their exams. => 2278,1731,366,7601,337,220,3322,347,20514,13\nI enjoy swimming in the ocean. => 40,2103,11989,294,220,3322,7810,13\nThey are renovating their house. => 8829,366,18845,990,220,3322,347,1782,13\nShe is practicing yoga to stay healthy. => 9526,307,11350,15128,220,1353,1754,4627,13\nWe should plant flowers in the garden. => 4360,820,3709,8085,294,220,3322,7431,13\nThe traffic is heavy during rush hour. => 2278,220,17227,3341,307,4676,1830,9300,1773,13\nHe is a skilled chef who creates amazing dishes. => 5205,307,257,19690,10530,567,7829,2243,10814,13\nThe baby is crawling on the floor. => 2278,3186,307,32979,322,220,3322,4123,13\nI need to buy a new pair of shoes. => 40,643,220,1353,2256,257,777,6119,295,6654,13\nThey are going on a road trip across the country. => 8829,366,516,322,257,3060,220,83,8400,2108,220,3322,1941,13\nShe is playing the piano beautifully. => 9526,307,2433,220,3322,9211,16525,13\nWe are going to a concert tomorrow night. => 4360,366,516,220,1353,257,8543,220,83,298,3162,1818,13\nThe cake tastes delicious with vanilla frosting. => 2278,5908,220,83,40246,4809,365,17528,37048,13\nHe is a dedicated teacher who inspires his students. 
=> 5205,307,257,8374,220,975,4062,567,32566,702,1731,13\nThe students are participating in a science fair. => 2278,1731,366,13950,294,257,3497,3143,13\nI enjoy hiking in the mountains. => 40,2103,23784,294,220,3322,10233,13\nThey are organizing a beach cleanup next weekend. => 8829,366,17608,257,7534,40991,958,6711,13\nShe is taking photographs of nature. => 9526,307,220,48625,17649,295,3687,13\nWe should try a new restaurant in town. => 4360,820,220,83,627,257,777,6383,294,220,30401,13\nThe traffic is moving slowly on the highway. => 2278,220,17227,3341,307,2684,5692,322,220,3322,17205,13\nHe is a talented singer with a beautiful voice. => 5205,307,257,220,32831,6003,11564,365,257,2238,3177,13\nThe baby is laughing and giggling. => 2278,3186,307,5059,293,290,24542,13\nI need to do laundry and wash my clothes. => 40,643,220,1353,360,19811,293,5675,452,5534,13\nThey are planning a trip to Europe. => 8829,366,5038,257,220,83,8400,220,1353,3315,13\nShe is learning how to play the guitar. => 9526,307,2539,577,220,1353,862,220,3322,7531,13\nWe are going to a museum this Sunday. => 4360,366,516,220,1353,257,8441,220,11176,7776,13\nThe coffee smells amazing in the morning. => 2278,4982,10036,2243,294,220,3322,2446,13\nHe is a hardworking farmer who grows crops. => 5205,307,257,1152,22475,17891,567,13156,16829,13\nThe students are presenting their research projects. => 2278,1731,366,15578,220,3322,347,2132,4455,13\nI enjoy playing soccer with my friends. => 40,2103,2433,15469,365,452,1855,13\nThey are volunteering at a local shelter. => 8829,366,33237,412,257,2654,13341,13\nShe is practicing martial arts for self-defense. => 9526,307,11350,20755,8609,337,2698,12,49268,13\nWe should try a new recipe for dinner. => 4360,820,220,83,627,257,777,6782,337,6148,13\nThe traffic is congest => 2278,220,17227,3341,307,31871\nThe sun is shining brightly today. => 2278,3295,307,18269,47418,220,83,378,320,13\nI enjoy reading books in my free time. 
=> 40,2103,3760,3642,294,452,1737,220,3766,13\nShe plays the piano beautifully. => 9526,5749,220,3322,9211,16525,13\nThe cat chased the mouse around the room. => 2278,3857,33091,220,3322,9719,926,220,3322,1808,13\nI love eating pizza with extra cheese. => 40,959,3936,8298,365,2857,5399,13\nHe always wears a hat wherever he goes. => 5205,1009,20877,257,2385,8660,415,1709,13\nThe flowers in the garden are blooming. => 2278,8085,294,220,3322,7431,366,45294,13\nShe danced gracefully on the stage. => 9526,32909,10042,2277,322,220,3322,3233,13\nThe dog barked loudly in the park. => 2278,3000,16202,292,22958,294,220,3322,3884,13\nWe went swimming in the ocean yesterday. => 4360,1437,11989,294,220,3322,7810,5186,13\nHe speaks fluent French and Spanish. => 5205,10789,40799,5522,293,8058,13\nThe train arrived at the station on time. => 2278,220,83,7146,6678,412,220,3322,5214,322,220,3766,13\nShe cooked a delicious meal for her family. => 9526,9267,257,4809,6791,337,720,1605,13\n"
  },
  {
    "path": "examples/python/README.md",
    "content": "# Simple autogenerated Python bindings for ggml\n\nThis folder contains:\n\n- Scripts to generate full Python bindings from ggml headers (+ stubs for autocompletion in IDEs)\n- Some barebones utils (see [ggml/utils.py](./ggml/utils.py)):\n  - `ggml.utils.init` builds a context that's freed automatically when the pointer gets GC'd\n  - `ggml.utils.copy` **copies between same-shaped tensors (numpy or ggml), w/ automatic (de/re)quantization**\n  - `ggml.utils.numpy` returns a numpy view over a ggml tensor; if it's quantized, it returns a copy (requires `allow_copy=True`)\n- Very basic examples (anyone wants to port [llama2.c](https://github.com/karpathy/llama2.c)?)\n\nProvided you set `GGML_LIBRARY=.../path/to/libggml_shared.so` (see instructions below), it's trivial to do some operations on quantized tensors:\n\n```python\n# Make sure libllama.so is in your [DY]LD_LIBRARY_PATH, or set GGML_LIBRARY=.../libggml_shared.so\n\nfrom ggml import lib, ffi\nfrom ggml.utils import init, copy, numpy\nimport numpy as np\n\nctx = init(mem_size=12*1024*1024)\nn = 256\nn_threads = 4\n\na = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)\nb = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n) # Can't both be quantized\nsum = lib.ggml_add(ctx, a, b) # all zeroes for now. Will be quantized too!\n\ngf = ffi.new('struct ggml_cgraph*')\nlib.ggml_build_forward_expand(gf, sum)\n\ncopy(np.array([i for i in range(n)], np.float32), a)\ncopy(np.array([i*100 for i in range(n)], np.float32), b)\n\nlib.ggml_graph_compute_with_ctx(ctx, gf, n_threads)\n\nprint(numpy(a, allow_copy=True))\n#  0.    1.0439453   2.0878906   3.131836    4.1757812   5.2197266. ...\nprint(numpy(b))\n#  0.  100.        200.        300.        400.        500.         ...\nprint(numpy(sum, allow_copy=True))\n#  0.  
105.4375    210.875     316.3125    421.75      527.1875     ...\n```\n\n### Prerequisites\n\nYou'll need a shared library of ggml to use the bindings.\n\n#### Build libggml_shared.so or libllama.so\n\nAs of this writing, the best option is to use [ggerganov/llama.cpp](https://github.com/ggerganov/llama.cpp)'s generated `libggml_shared.so` or `libllama.so`, which you can build as follows:\n\n```bash\ngit clone https://github.com/ggerganov/llama.cpp\n# On a CUDA-enabled system add -DLLAMA_CUDA=1\n# On a Mac add -DLLAMA_METAL=1\ncmake llama.cpp \\\n  -B llama_build \\\n  -DCMAKE_C_FLAGS=-Ofast \\\n  -DLLAMA_NATIVE=1 \\\n  -DLLAMA_LTO=1 \\\n  -DBUILD_SHARED_LIBS=1 \\\n  -DLLAMA_MPI=1 \\\n  -DLLAMA_BUILD_TESTS=0 \\\n  -DLLAMA_BUILD_EXAMPLES=0\n( cd llama_build && make -j )\n\n# On Mac, this will be libggml_shared.dylib instead\nexport GGML_LIBRARY=$PWD/llama_build/libggml_shared.so\n# Alternatively, you can just copy it to your system's lib dir, e.g. /usr/local/lib\n```\n\n#### (Optional) Regenerate the bindings and stubs\n\nIf you added or changed any signatures of the C API, you'll want to regenerate the bindings ([ggml/cffi.py](./ggml/cffi.py)) and stubs ([ggml/__init__.pyi](./ggml/__init__.pyi)).\n\nLuckily it's a one-liner using [regenerate.py](./regenerate.py):\n\n```bash\npip install -q cffi\n\npython regenerate.py\n```\n\nBy default it assumes `llama.cpp` was cloned into ../../../llama.cpp (alongside the ggml folder). 
You can override this with:\n\n```bash\nC_INCLUDE_DIR=$LLAMA_CPP_DIR python regenerate.py\n```\n\nYou can also edit [api.h](./api.h) to control which files should be included in the generated bindings (defaults to `llama.cpp/ggml*.h`).\n\nIn fact, if you wanted to generate bindings only for the current version of the `ggml` repo itself (instead of `llama.cpp`; you'd lose support for k-quants), you could run:\n\n```bash\nAPI=../../include/ggml.h python regenerate.py\n```\n\n## Develop\n\nRun tests:\n\n```bash\npytest\n```\n\n### Alternatives\n\nThis example's goal is to showcase [cffi](https://cffi.readthedocs.io/)-generated bindings that are trivial to use and update, but there are already alternatives in the wild:\n\n- https://github.com/abetlen/ggml-python: these bindings seem to be hand-written and use [ctypes](https://docs.python.org/3/library/ctypes.html). It has [high-quality API reference docs](https://ggml-python.readthedocs.io/en/latest/api-reference/#ggml.ggml) that can be used with these bindings too, but it doesn't expose Metal, CUDA, MPI or OpenCL calls, doesn't support transparent (de/re)quantization like this example does (see [ggml.utils](./ggml/utils.py) module), and won't pick up your local changes.\n\n- https://github.com/abetlen/llama-cpp-python: these expose the C++ `llama.cpp` interface, which this example cannot easily be extended to support (`cffi` only generates bindings of C libraries).\n\n- [pybind11](https://github.com/pybind/pybind11) and [nanobind](https://github.com/wjakob/nanobind) are two alternatives to cffi that support binding C++ libraries, but neither of them seems to have an automatic generator (writing bindings is rather time-consuming).\n"
  },
  {
    "path": "examples/python/api.h",
    "content": "/*\n  List here all the headers you want to expose in the Python bindings,\n  then run `python regenerate.py` (see details in README.md)\n*/\n\n#include \"ggml.h\"\n#include \"ggml-metal.h\"\n#include \"ggml-opencl.h\"\n\n// Headers below are currently only present in the llama.cpp repository, comment them out if you don't have them.\n#include \"k_quants.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-cuda.h\"\n#include \"ggml-mpi.h\""
  },
  {
    "path": "examples/python/example_add_quant.py",
    "content": "from ggml import lib, ffi\nfrom ggml.utils import init, copy, numpy\nimport numpy as np\n\nctx = init(mem_size=12*1024*1024) # automatically freed when pointer is GC'd\nn = 256\nn_threads = 4\n\na = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)\nb = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n) # can't both be quantized\nsum = lib.ggml_add(ctx, a, b) # all zeroes for now. Will be quantized too!\n\n# See cffi's doc on how to allocate native memory: it's very simple!\n# https://cffi.readthedocs.io/en/latest/ref.html#ffi-interface\ngf = ffi.new('struct ggml_cgraph*')\nlib.ggml_build_forward_expand(gf, sum)\n\ncopy(np.array([i for i in range(n)], np.float32), a)\ncopy(np.array([i*100 for i in range(n)], np.float32), b)\n\nlib.ggml_graph_compute_with_ctx(ctx, gf, n_threads)\n\nprint(numpy(a, allow_copy=True))\nprint(numpy(b))\nprint(numpy(sum, allow_copy=True))"
  },
  {
    "path": "examples/python/example_test_all_quants.py",
    "content": "from ggml import ffi, lib\nfrom ggml.utils import init, numpy, copy\nimport numpy as np\nfrom math import pi, cos, sin, ceil\n\nimport matplotlib.pyplot as plt\n\nctx = init(mem_size=100*1024*1024) # Will be auto-GC'd\nn = 256\n\norig = np.array([\n    [\n        cos(j * 2 * pi / n) * (sin(i * 2 * pi / n))\n        for j in range(n)\n    ]\n    for i in range(n)\n], np.float32)\norig_tensor = lib.ggml_new_tensor_2d(ctx, lib.GGML_TYPE_F32, n, n)\ncopy(orig, orig_tensor)\n\nquants = [\n    type for type in range(lib.GGML_TYPE_COUNT)\n    if lib.ggml_is_quantized(type) and\n       type not in [lib.GGML_TYPE_Q8_1, lib.GGML_TYPE_Q8_K] # Apparently not supported\n]\n# quants = [lib.GGML_TYPE_Q2_K] # Test a single one\n\ndef get_name(type):\n    name = lib.ggml_type_name(type)\n    return ffi.string(name).decode('utf-8') if name else '?'\n\nquants.sort(key=get_name)\nquants.insert(0, None)\nprint(quants)\n\nncols=4\nnrows = ceil(len(quants) / ncols)\n\nplt.figure(figsize=(ncols * 5, nrows * 5), layout='tight')\n\nfor i, type in enumerate(quants):\n    plt.subplot(nrows, ncols, i + 1)\n    try:\n        if type == None:\n            plt.title('Original')\n            plt.imshow(orig)\n        else:\n            quantized_tensor = lib.ggml_new_tensor_2d(ctx, type, n, n)\n            copy(orig_tensor, quantized_tensor)\n            quantized = numpy(quantized_tensor, allow_copy=True)\n            d = quantized - orig\n            results = {\n                \"l2\": np.linalg.norm(d, 2),\n                \"linf\": np.linalg.norm(d, np.inf),\n                \"compression\":\n                    round(lib.ggml_nbytes(orig_tensor) /\n                          lib.ggml_nbytes(quantized_tensor), 1)\n            }\n            name = get_name(type)\n            print(f'{name}: {results}')\n\n            plt.title(f'{name} ({results[\"compression\"]}x smaller)')\n            plt.imshow(quantized, interpolation='nearest')\n        \n    except Exception as e:\n      
  print(f'Error: {e}')\n\nplt.show()"
  },
  {
    "path": "examples/python/ggml/__init__.py",
    "content": "\"\"\"\n  Python bindings for the ggml library.\n\n  Usage example:\n\n      from ggml import lib, ffi\n      from ggml.utils import init, copy, numpy\n      import numpy as np\n\n      ctx = init(mem_size=10*1024*1024)\n      n = 1024\n      n_threads = 4\n\n      a = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)\n      b = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n)\n      sum = lib.ggml_add(ctx, a, b)\n\n      gf = ffi.new('struct ggml_cgraph*')\n      lib.ggml_build_forward_expand(gf, sum)\n\n      copy(np.array([i for i in range(n)], np.float32), a)\n      copy(np.array([i*100 for i in range(n)], np.float32), b)\n      lib.ggml_graph_compute_with_ctx(ctx, gf, n_threads)\n\n      print(numpy(sum, allow_copy=True))\n\n  See https://cffi.readthedocs.io/en/latest/cdef.html for more on cffi.\n\"\"\"\n\ntry:\n    from ggml.cffi import ffi as ffi\nexcept ImportError as e:\n    raise ImportError(f\"Couldn't find ggml bindings ({e}). Run `python regenerate.py` or check your PYTHONPATH.\")\n\nimport os, platform\n\n__exact_library = os.environ.get(\"GGML_LIBRARY\")\nif __exact_library:\n    __candidates = [__exact_library]\nelif platform.system() == \"Windows\":\n    __candidates = [\"ggml_shared.dll\", \"llama.dll\"]\nelse:\n    __candidates = [\"libggml_shared.so\", \"libllama.so\"]\n    if platform.system() == \"Darwin\":\n        __candidates += [\"libggml_shared.dylib\", \"libllama.dylib\"]\n\nfor i, name in enumerate(__candidates):\n    try:\n        # This is where all the functions, enums and constants are defined\n        lib = ffi.dlopen(name)\n    except OSError:\n        if i < len(__candidates) - 1:\n            continue\n        raise OSError(f\"Couldn't find ggml's shared library (tried names: {__candidates}). 
Add its directory to DYLD_LIBRARY_PATH (on Mac) or LD_LIBRARY_PATH, or define GGML_LIBRARY.\")\n\n# This contains the cffi helpers such as new, cast, string, etc.\n# https://cffi.readthedocs.io/en/latest/ref.html#ffi-interface\nffi = ffi\n"
  },
  {
    "path": "examples/python/ggml/__init__.pyi",
    "content": "# auto-generated file\nimport ggml.ffi as ffi\nimport numpy as np\nclass lib:\n  @property\n  def GGML_BACKEND_CPU(self) -> int: ...\n  @property\n  def GGML_BACKEND_GPU(self) -> int: ...\n  @property\n  def GGML_BACKEND_GPU_SPLIT(self) -> int: ...\n  @property\n  def GGML_FTYPE_ALL_F32(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_F16(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q2_K(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q3_K(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q4_0(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q4_1(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q4_1_SOME_F16(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q4_K(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q5_0(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q5_1(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q5_K(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q6_K(self) -> int: ...\n  @property\n  def GGML_FTYPE_MOSTLY_Q8_0(self) -> int: ...\n  @property\n  def GGML_FTYPE_UNKNOWN(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_BACKTRACKING_ARMIJO(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_BACKTRACKING_WOLFE(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_DEFAULT(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_FAIL(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_INVALID_PARAMETERS(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_MAXIMUM_ITERATIONS(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_MAXIMUM_STEP(self) -> int: ...\n  @property\n  def GGML_LINESEARCH_MINIMUM_STEP(self) -> int: ...\n  @property\n  def GGML_OBJECT_GRAPH(self) -> int: ...\n  @property\n  def GGML_OBJECT_TENSOR(self) -> int: ...\n  @property\n  def GGML_OBJECT_WORK_BUFFER(self) -> int: ...\n  @property\n  def GGML_OPT_ADAM(self) -> int: ...\n  @property\n  def 
GGML_OPT_DID_NOT_CONVERGE(self) -> int: ...\n  @property\n  def GGML_OPT_FAIL(self) -> int: ...\n  @property\n  def GGML_OPT_INVALID_WOLFE(self) -> int: ...\n  @property\n  def GGML_OPT_LBFGS(self) -> int: ...\n  @property\n  def GGML_OPT_NO_CONTEXT(self) -> int: ...\n  @property\n  def GGML_OPT_OK(self) -> int: ...\n  @property\n  def GGML_OP_ACC(self) -> int: ...\n  @property\n  def GGML_OP_ADD(self) -> int: ...\n  @property\n  def GGML_OP_ADD1(self) -> int: ...\n  @property\n  def GGML_OP_ALIBI(self) -> int: ...\n  @property\n  def GGML_OP_ARGMAX(self) -> int: ...\n  @property\n  def GGML_OP_CLAMP(self) -> int: ...\n  @property\n  def GGML_OP_CONT(self) -> int: ...\n  @property\n  def GGML_OP_CONV_1D(self) -> int: ...\n  @property\n  def GGML_OP_CONV_2D(self) -> int: ...\n  @property\n  def GGML_OP_COUNT(self) -> int: ...\n  @property\n  def GGML_OP_CPY(self) -> int: ...\n  @property\n  def GGML_OP_CROSS_ENTROPY_LOSS(self) -> int: ...\n  @property\n  def GGML_OP_CROSS_ENTROPY_LOSS_BACK(self) -> int: ...\n  @property\n  def GGML_OP_DIAG(self) -> int: ...\n  @property\n  def GGML_OP_DIAG_MASK_INF(self) -> int: ...\n  @property\n  def GGML_OP_DIAG_MASK_ZERO(self) -> int: ...\n  @property\n  def GGML_OP_DIV(self) -> int: ...\n  @property\n  def GGML_OP_DUP(self) -> int: ...\n  @property\n  def GGML_OP_FLASH_ATTN(self) -> int: ...\n  @property\n  def GGML_OP_FLASH_ATTN_BACK(self) -> int: ...\n  @property\n  def GGML_OP_FLASH_FF(self) -> int: ...\n  @property\n  def GGML_OP_GET_ROWS(self) -> int: ...\n  @property\n  def GGML_OP_GET_ROWS_BACK(self) -> int: ...\n  @property\n  def GGML_OP_LOG(self) -> int: ...\n  @property\n  def GGML_OP_MAP_BINARY(self) -> int: ...\n  @property\n  def GGML_OP_MAP_CUSTOM1(self) -> int: ...\n  @property\n  def GGML_OP_MAP_CUSTOM1_F32(self) -> int: ...\n  @property\n  def GGML_OP_MAP_CUSTOM2(self) -> int: ...\n  @property\n  def GGML_OP_MAP_CUSTOM2_F32(self) -> int: ...\n  @property\n  def GGML_OP_MAP_CUSTOM3(self) -> int: ...\n  
@property\n  def GGML_OP_MAP_CUSTOM3_F32(self) -> int: ...\n  @property\n  def GGML_OP_MAP_UNARY(self) -> int: ...\n  @property\n  def GGML_OP_MEAN(self) -> int: ...\n  @property\n  def GGML_OP_MUL(self) -> int: ...\n  @property\n  def GGML_OP_MUL_MAT(self) -> int: ...\n  @property\n  def GGML_OP_NONE(self) -> int: ...\n  @property\n  def GGML_OP_NORM(self) -> int: ...\n  @property\n  def GGML_OP_OUT_PROD(self) -> int: ...\n  @property\n  def GGML_OP_PERMUTE(self) -> int: ...\n  @property\n  def GGML_OP_POOL_1D(self) -> int: ...\n  @property\n  def GGML_OP_POOL_2D(self) -> int: ...\n  @property\n  def GGML_OP_POOL_AVG(self) -> int: ...\n  @property\n  def GGML_OP_POOL_COUNT(self) -> int: ...\n  @property\n  def GGML_OP_POOL_MAX(self) -> int: ...\n  @property\n  def GGML_OP_REPEAT(self) -> int: ...\n  @property\n  def GGML_OP_REPEAT_BACK(self) -> int: ...\n  @property\n  def GGML_OP_RESHAPE(self) -> int: ...\n  @property\n  def GGML_OP_RMS_NORM(self) -> int: ...\n  @property\n  def GGML_OP_RMS_NORM_BACK(self) -> int: ...\n  @property\n  def GGML_OP_ROPE(self) -> int: ...\n  @property\n  def GGML_OP_ROPE_BACK(self) -> int: ...\n  @property\n  def GGML_OP_SCALE(self) -> int: ...\n  @property\n  def GGML_OP_SET(self) -> int: ...\n  @property\n  def GGML_OP_SILU_BACK(self) -> int: ...\n  @property\n  def GGML_OP_SOFT_MAX(self) -> int: ...\n  @property\n  def GGML_OP_SOFT_MAX_BACK(self) -> int: ...\n  @property\n  def GGML_OP_SQR(self) -> int: ...\n  @property\n  def GGML_OP_SQRT(self) -> int: ...\n  @property\n  def GGML_OP_SUB(self) -> int: ...\n  @property\n  def GGML_OP_SUM(self) -> int: ...\n  @property\n  def GGML_OP_SUM_ROWS(self) -> int: ...\n  @property\n  def GGML_OP_TRANSPOSE(self) -> int: ...\n  @property\n  def GGML_OP_UNARY(self) -> int: ...\n  @property\n  def GGML_OP_VIEW(self) -> int: ...\n  @property\n  def GGML_OP_WIN_PART(self) -> int: ...\n  @property\n  def GGML_OP_WIN_UNPART(self) -> int: ...\n  @property\n  def GGML_TASK_COMPUTE(self) -> int: 
...\n  @property\n  def GGML_TASK_FINALIZE(self) -> int: ...\n  @property\n  def GGML_TASK_INIT(self) -> int: ...\n  @property\n  def GGML_TYPE_COUNT(self) -> int: ...\n  @property\n  def GGML_TYPE_F16(self) -> int: ...\n  @property\n  def GGML_TYPE_F32(self) -> int: ...\n  @property\n  def GGML_TYPE_I16(self) -> int: ...\n  @property\n  def GGML_TYPE_I32(self) -> int: ...\n  @property\n  def GGML_TYPE_I8(self) -> int: ...\n  @property\n  def GGML_TYPE_Q2_K(self) -> int: ...\n  @property\n  def GGML_TYPE_Q3_K(self) -> int: ...\n  @property\n  def GGML_TYPE_Q4_0(self) -> int: ...\n  @property\n  def GGML_TYPE_Q4_1(self) -> int: ...\n  @property\n  def GGML_TYPE_Q4_K(self) -> int: ...\n  @property\n  def GGML_TYPE_Q5_0(self) -> int: ...\n  @property\n  def GGML_TYPE_Q5_1(self) -> int: ...\n  @property\n  def GGML_TYPE_Q5_K(self) -> int: ...\n  @property\n  def GGML_TYPE_Q6_K(self) -> int: ...\n  @property\n  def GGML_TYPE_Q8_0(self) -> int: ...\n  @property\n  def GGML_TYPE_Q8_1(self) -> int: ...\n  @property\n  def GGML_TYPE_Q8_K(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_ABS(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_ELU(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_GELU(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_GELU_QUICK(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_NEG(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_RELU(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_SGN(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_SILU(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_STEP(self) -> int: ...\n  @property\n  def GGML_UNARY_OP_TANH(self) -> int: ...\n  @property\n  def GGUF_TYPE_ARRAY(self) -> int: ...\n  @property\n  def GGUF_TYPE_BOOL(self) -> int: ...\n  @property\n  def GGUF_TYPE_COUNT(self) -> int: ...\n  @property\n  def GGUF_TYPE_FLOAT32(self) -> int: ...\n  @property\n  def GGUF_TYPE_INT16(self) -> int: ...\n  @property\n  def GGUF_TYPE_INT32(self) -> int: ...\n  @property\n  def 
GGUF_TYPE_INT8(self) -> int: ...\n  @property\n  def GGUF_TYPE_STRING(self) -> int: ...\n  @property\n  def GGUF_TYPE_UINT16(self) -> int: ...\n  @property\n  def GGUF_TYPE_UINT32(self) -> int: ...\n  @property\n  def GGUF_TYPE_UINT8(self) -> int: ...\n  def abort_callback(data: ffi.CData) -> bool:\n    \"\"\"\n    abort ggml_graph_compute when true\n\n            bool (*abort_callback)(void * data);\n    \"\"\"\n    ...\n  def dequantize_row_q2_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"\n    Dequantization\n\n    void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);\n    \"\"\"\n    ...\n  def dequantize_row_q3_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);\"\"\"\n    ...\n  def dequantize_row_q4_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);\"\"\"\n    ...\n  def dequantize_row_q5_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);\"\"\"\n    ...\n  def dequantize_row_q6_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);\"\"\"\n    ...\n  def dequantize_row_q8_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);\"\"\"\n    ...\n  def ggml_abs(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_abs(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_abs_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_abs_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * 
a);\n    \"\"\"\n    ...\n  def ggml_acc(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, nb1: int, nb2: int, nb3: int, offset: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_acc(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                size_t                nb1,\n                size_t                nb2,\n                size_t                nb3,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_acc_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, nb1: int, nb2: int, nb3: int, offset: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_acc_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                size_t                nb1,\n                size_t                nb2,\n                size_t                nb3,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_add(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_add(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_add1(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_add1(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_add1_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_add1_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_add_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) 
-> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_add_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_alibi(ctx: ffi.CData, a: ffi.CData, n_past: int, n_head: int, bias_max: float) -> ffi.CData:\n    \"\"\"\n    alibi position embedding\n    in-place, returns view(a)\n\n        struct ggml_tensor * ggml_alibi(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past,\n                int                   n_head,\n                float                 bias_max);\n    \"\"\"\n    ...\n  def ggml_allocr_alloc(alloc: ffi.CData, tensor: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_allocr_alloc_graph(alloc: ffi.CData, graph: ffi.CData) -> int:\n    \"\"\"GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);\"\"\"\n    ...\n  def ggml_allocr_free(alloc: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);\"\"\"\n    ...\n  def ggml_allocr_is_measure(alloc: ffi.CData) -> bool:\n    \"\"\"GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);\"\"\"\n    ...\n  def ggml_allocr_new(data: ffi.CData, size: int, alignment: int) -> ffi.CData:\n    \"\"\"GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);\"\"\"\n    ...\n  def ggml_allocr_new_measure(alignment: int) -> ffi.CData:\n    \"\"\"GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);\"\"\"\n    ...\n  def ggml_allocr_reset(alloc: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_allocr_reset(struct ggml_allocr * alloc);\"\"\"\n    ...\n  def ggml_allocr_set_parse_seq(alloc: ffi.CData, list: ffi.CData, n: int) -> None:\n    \"\"\"\n    tell the 
allocator to parse nodes following the order described in the list\n    you should call this if your graph are optimized to execute out-of-order\n\n    GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);\n    \"\"\"\n    ...\n  def ggml_are_same_shape(t0: ffi.CData, t1: ffi.CData) -> bool:\n    \"\"\"    GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);\"\"\"\n    ...\n  def ggml_argmax(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    argmax along rows\n\n        GGML_API struct ggml_tensor * ggml_argmax(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_blck_size(type: int) -> int:\n    \"\"\"    GGML_API int     ggml_blck_size (enum ggml_type type);\"\"\"\n    ...\n  def ggml_build_backward(ctx: ffi.CData, gf: ffi.CData, keep: bool) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);\"\"\"\n    ...\n  def ggml_build_forward(tensor: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_build_forward_ctx(ctx: ffi.CData, tensor: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_build_forward_expand(cgraph: ffi.CData, tensor: ffi.CData) -> None:\n    \"\"\"    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_cl_can_mul_mat(src0: ffi.CData, src1: ffi.CData, dst: ffi.CData) -> bool:\n    \"\"\"bool   ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);\"\"\"\n    ...\n  def ggml_cl_free_data(tensor: ffi.CData) -> None:\n    \"\"\"void ggml_cl_free_data(const struct 
ggml_tensor* tensor);\"\"\"\n    ...\n  def ggml_cl_host_free(ptr: ffi.CData) -> None:\n    \"\"\"void   ggml_cl_host_free(void * ptr);\"\"\"\n    ...\n  def ggml_cl_host_malloc(size: int) -> ffi.CData:\n    \"\"\"void * ggml_cl_host_malloc(size_t size);\"\"\"\n    ...\n  def ggml_cl_init() -> None:\n    \"\"\"void ggml_cl_init(void);\"\"\"\n    ...\n  def ggml_cl_mul(src0: ffi.CData, src1: ffi.CData, dst: ffi.CData) -> None:\n    \"\"\"void   ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);\"\"\"\n    ...\n  def ggml_cl_mul_mat(src0: ffi.CData, src1: ffi.CData, dst: ffi.CData, wdata: ffi.CData, wsize: int) -> None:\n    \"\"\"void   ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);\"\"\"\n    ...\n  def ggml_cl_mul_mat_get_wsize(src0: ffi.CData, src1: ffi.CData, dst: ffi.CData) -> int:\n    \"\"\"size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);\"\"\"\n    ...\n  def ggml_cl_transform_tensor(data: ffi.CData, tensor: ffi.CData) -> None:\n    \"\"\"void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_clamp(ctx: ffi.CData, a: ffi.CData, min: float, max: float) -> ffi.CData:\n    \"\"\"\n    clamp\n    in-place, returns view(a)\n\n        struct ggml_tensor * ggml_clamp(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                float                 min,\n                float                 max);\n    \"\"\"\n    ...\n  def ggml_cont(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    make contiguous\n\n        GGML_API struct ggml_tensor * ggml_cont(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_conv_1d(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, s0: int, p0: int, d0: int) -> 
ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_conv_1d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                int                   s0,  // stride\n                int                   p0,  // padding\n                int                   d0); // dilation\n    \"\"\"\n    ...\n  def ggml_conv_1d_ph(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, s: int, d: int) -> ffi.CData:\n    \"\"\"\n    conv_1d with padding = half\n    alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)\n\n        GGML_API struct ggml_tensor * ggml_conv_1d_ph(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                int                   s,\n                int                   d);\n    \"\"\"\n    ...\n  def ggml_conv_2d(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, s0: int, s1: int, p0: int, p1: int, d0: int, d1: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_conv_2d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                int                   s0,\n                int                   s1,\n                int                   p0,\n                int                   p1,\n                int                   d0,\n                int                   d1);\n    \"\"\"\n    ...\n  def ggml_cpu_has_arm_fma() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_arm_fma    (void);\"\"\"\n    ...\n  def ggml_cpu_has_avx() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_avx        (void);\"\"\"\n    ...\n  def ggml_cpu_has_avx2() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_avx2       (void);\"\"\"\n    ...\n  def ggml_cpu_has_avx512() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_avx512     (void);\"\"\"\n    ...\n  def ggml_cpu_has_avx512_vbmi() -> int:\n    \"\"\"    GGML_API int 
ggml_cpu_has_avx512_vbmi(void);\"\"\"\n    ...\n  def ggml_cpu_has_avx512_vnni() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_avx512_vnni(void);\"\"\"\n    ...\n  def ggml_cpu_has_blas() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_blas       (void);\"\"\"\n    ...\n  def ggml_cpu_has_clblast() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_clblast    (void);\"\"\"\n    ...\n  def ggml_cpu_has_cuda() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_cuda       (void);\"\"\"\n    ...\n  def ggml_cpu_has_f16c() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_f16c       (void);\"\"\"\n    ...\n  def ggml_cpu_has_fma() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_fma        (void);\"\"\"\n    ...\n  def ggml_cpu_has_fp16_va() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_fp16_va    (void);\"\"\"\n    ...\n  def ggml_cpu_has_gpublas() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_gpublas    (void);\"\"\"\n    ...\n  def ggml_cpu_has_neon() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_neon       (void);\"\"\"\n    ...\n  def ggml_cpu_has_sse3() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_sse3       (void);\"\"\"\n    ...\n  def ggml_cpu_has_vsx() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_vsx        (void);\"\"\"\n    ...\n  def ggml_cpu_has_wasm_simd() -> int:\n    \"\"\"    GGML_API int ggml_cpu_has_wasm_simd  (void);\"\"\"\n    ...\n  def ggml_cpy(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    a -> b, return view(b)\n\n        GGML_API struct ggml_tensor * ggml_cpy(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_cross_entropy_loss(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_cross_entropy_loss(\n                struct ggml_context         * ctx,\n                struct ggml_tensor          * a,\n                struct 
ggml_tensor          * b);\n    \"\"\"\n    ...\n  def ggml_cross_entropy_loss_back(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, c: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(\n                struct ggml_context         * ctx,\n                struct ggml_tensor          * a,\n                struct ggml_tensor          * b,\n                struct ggml_tensor          * c);\n    \"\"\"\n    ...\n  def ggml_cuda_assign_buffers(tensor: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_cuda_assign_buffers_force_inplace(tensor: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_cuda_assign_buffers_no_scratch(tensor: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_cuda_can_mul_mat(src0: ffi.CData, src1: ffi.CData, dst: ffi.CData) -> bool:\n    \"\"\"GGML_API bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);\"\"\"\n    ...\n  def ggml_cuda_compute_forward(params: ffi.CData, tensor: ffi.CData) -> bool:\n    \"\"\"GGML_API bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_cuda_free_data(tensor: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_cuda_free_data(struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_cuda_free_scratch() -> None:\n    \"\"\"GGML_API void   ggml_cuda_free_scratch(void);\"\"\"\n    ...\n  def ggml_cuda_get_device_count() -> int:\n    \"\"\"GGML_API int    ggml_cuda_get_device_count(void);\"\"\"\n    ...\n  def ggml_cuda_get_device_description(device: int, description: ffi.CData, description_size: int) -> None:\n    \"\"\"GGML_API void   ggml_cuda_get_device_description(int 
device, char * description, size_t description_size);\"\"\"\n    ...\n  def ggml_cuda_host_free(ptr: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_cuda_host_free(void * ptr);\"\"\"\n    ...\n  def ggml_cuda_host_malloc(size: int) -> ffi.CData:\n    \"\"\"GGML_API void * ggml_cuda_host_malloc(size_t size);\"\"\"\n    ...\n  def ggml_cuda_set_main_device(main_device: int) -> None:\n    \"\"\"GGML_API void   ggml_cuda_set_main_device(int main_device);\"\"\"\n    ...\n  def ggml_cuda_set_mul_mat_q(mul_mat_q: bool) -> None:\n    \"\"\"GGML_API void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);\"\"\"\n    ...\n  def ggml_cuda_set_scratch_size(scratch_size: int) -> None:\n    \"\"\"GGML_API void   ggml_cuda_set_scratch_size(size_t scratch_size);\"\"\"\n    ...\n  def ggml_cuda_set_tensor_split(tensor_split: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_cuda_set_tensor_split(const float * tensor_split);\"\"\"\n    ...\n  def ggml_cuda_transform_tensor(data: ffi.CData, tensor: ffi.CData) -> None:\n    \"\"\"GGML_API void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_cycles() -> int:\n    \"\"\"    GGML_API int64_t ggml_cycles(void);\"\"\"\n    ...\n  def ggml_cycles_per_ms() -> int:\n    \"\"\"    GGML_API int64_t ggml_cycles_per_ms(void);\"\"\"\n    ...\n  def ggml_diag(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_diag(\n            struct ggml_context     * ctx,\n            struct ggml_tensor      * a);\n    \"\"\"\n    ...\n  def ggml_diag_mask_inf(ctx: ffi.CData, a: ffi.CData, n_past: int) -> ffi.CData:\n    \"\"\"\n    set elements above the diagonal to -INF\n\n        GGML_API struct ggml_tensor * ggml_diag_mask_inf(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past);\n    \"\"\"\n    ...\n  def ggml_diag_mask_inf_inplace(ctx: ffi.CData, a: ffi.CData, n_past: int) -> 
ffi.CData:\n    \"\"\"\n    in-place, returns view(a)\n\n        GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past);\n    \"\"\"\n    ...\n  def ggml_diag_mask_zero(ctx: ffi.CData, a: ffi.CData, n_past: int) -> ffi.CData:\n    \"\"\"\n    set elements above the diagonal to 0\n\n        GGML_API struct ggml_tensor * ggml_diag_mask_zero(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past);\n    \"\"\"\n    ...\n  def ggml_diag_mask_zero_inplace(ctx: ffi.CData, a: ffi.CData, n_past: int) -> ffi.CData:\n    \"\"\"\n    in-place, returns view(a)\n\n        GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past);\n    \"\"\"\n    ...\n  def ggml_div(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_div(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_div_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_div_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_dup(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_dup(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_dup_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    in-place, returns view(a)\n\n        GGML_API struct ggml_tensor * ggml_dup_inplace(\n                
struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_dup_tensor(ctx: ffi.CData, src: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);\"\"\"\n    ...\n  def ggml_element_size(tensor: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_elu(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_elu(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_elu_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_elu_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_flash_attn(ctx: ffi.CData, q: ffi.CData, k: ffi.CData, v: ffi.CData, masked: bool) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_flash_attn(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * q,\n                struct ggml_tensor  * k,\n                struct ggml_tensor  * v,\n                bool                  masked);\n    \"\"\"\n    ...\n  def ggml_flash_attn_back(ctx: ffi.CData, q: ffi.CData, k: ffi.CData, v: ffi.CData, d: ffi.CData, masked: bool) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_flash_attn_back(\n               struct ggml_context * ctx,\n               struct ggml_tensor  * q,\n               struct ggml_tensor  * k,\n               struct ggml_tensor  * v,\n               struct ggml_tensor  * d,\n               bool                  masked);\n    \"\"\"\n    ...\n  def ggml_flash_ff(ctx: ffi.CData, a: ffi.CData, b0: ffi.CData, b1: ffi.CData, c0: ffi.CData, c1: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * 
ggml_flash_ff(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b0,\n                struct ggml_tensor  * b1,\n                struct ggml_tensor  * c0,\n                struct ggml_tensor  * c1);\n    \"\"\"\n    ...\n  def ggml_format_name(tensor: ffi.CData, fmt: ffi.CData, *args2) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);\"\"\"\n    ...\n  def ggml_fp16_to_fp32(x: np.float16) -> float:\n    \"\"\"\n    convert FP16 <-> FP32\n\n        GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);\n    \"\"\"\n    ...\n  def ggml_fp16_to_fp32_row(x: ffi.CData, y: ffi.CData, n: int) -> None:\n    \"\"\"    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);\"\"\"\n    ...\n  def ggml_fp32_to_fp16(x: float) -> np.float16:\n    \"\"\"    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);\"\"\"\n    ...\n  def ggml_fp32_to_fp16_row(x: ffi.CData, y: ffi.CData, n: int) -> None:\n    \"\"\"    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);\"\"\"\n    ...\n  def ggml_free(ctx: ffi.CData) -> None:\n    \"\"\"    GGML_API void                  ggml_free(struct ggml_context * ctx);\"\"\"\n    ...\n  def ggml_ftype_to_ggml_type(ftype: int) -> int:\n    \"\"\"\n    TODO: temporary until model loading of ggml examples is refactored\n\n        GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);\n    \"\"\"\n    ...\n  def ggml_gelu(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    TODO: double-check this computation is correct\n\n        GGML_API struct ggml_tensor * ggml_gelu(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_gelu_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_gelu_inplace(\n              
  struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_gelu_quick(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_gelu_quick(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_gelu_quick_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_get_data(tensor: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_get_data_f32(tensor: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_get_f32_1d(tensor: ffi.CData, i: int) -> float:\n    \"\"\"    GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);\"\"\"\n    ...\n  def ggml_get_i32_1d(tensor: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);\"\"\"\n    ...\n  def ggml_get_max_tensor_size(ctx: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);\"\"\"\n    ...\n  def ggml_get_mem_buffer(ctx: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);\"\"\"\n    ...\n  def ggml_get_mem_size(ctx: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);\"\"\"\n    ...\n  def ggml_get_name(tensor: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_get_no_alloc(ctx: ffi.CData) -> bool:\n    \"\"\"    GGML_API bool    ggml_get_no_alloc(struct 
ggml_context * ctx);\"\"\"\n    ...\n  def ggml_get_rows(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_get_rows(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_get_rows_back(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, c: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_get_rows_back(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                struct ggml_tensor  * c);\n    \"\"\"\n    ...\n  def ggml_get_tensor(ctx: ffi.CData, name: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);\"\"\"\n    ...\n  def ggml_get_unary_op(tensor: ffi.CData) -> int:\n    \"\"\"    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_graph_compute(cgraph: ffi.CData, cplan: ffi.CData) -> int:\n    \"\"\"    GGML_API               int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);\"\"\"\n    ...\n  def ggml_graph_compute_with_ctx(ctx: ffi.CData, cgraph: ffi.CData, n_threads: int) -> None:\n    \"\"\"\n    same as ggml_graph_compute() but the work data is allocated as a part of the context\n    note: the drawback of this API is that you must have ensured that the context has enough memory for the work data\n\n        GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);\n    \"\"\"\n    ...\n  def ggml_graph_dump_dot(gb: ffi.CData, gf: ffi.CData, filename: ffi.CData) -> None:\n    \"\"\"\n    dump the graph into a file using the dot format\n\n        GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);\n 
   \"\"\"\n    ...\n  def ggml_graph_get_tensor(cgraph: ffi.CData, name: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);\"\"\"\n    ...\n  def ggml_graph_overhead() -> int:\n    \"\"\"    GGML_API size_t ggml_graph_overhead(void);\"\"\"\n    ...\n  def ggml_graph_plan(cgraph: ffi.CData, n_threads: int) -> ffi.CData:\n    \"\"\"\n    ggml_graph_plan() has to be called before ggml_graph_compute()\n    when plan.work_size > 0, caller must allocate memory for plan.work_data\n\n        GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);\n    \"\"\"\n    ...\n  def ggml_graph_print(cgraph: ffi.CData) -> None:\n    \"\"\"\n    print info and performance information for the graph\n\n        GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);\n    \"\"\"\n    ...\n  def ggml_graph_reset(cgraph: ffi.CData) -> None:\n    \"\"\"    GGML_API              void ggml_graph_reset  (struct ggml_cgraph * cgraph);\"\"\"\n    ...\n  def ggml_init(params: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);\"\"\"\n    ...\n  def ggml_init_cuda() -> None:\n    \"\"\"GGML_API void   ggml_init_cuda(void);\"\"\"\n    ...\n  def ggml_internal_get_type_traits(type: int) -> ffi.CData:\n    \"\"\"    ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);\"\"\"\n    ...\n  def ggml_is_contiguous(tensor: ffi.CData) -> bool:\n    \"\"\"    GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_is_numa() -> bool:\n    \"\"\"    GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node\"\"\"\n    ...\n  def ggml_is_permuted(tensor: ffi.CData) -> bool:\n    \"\"\"    GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_is_quantized(type: 
int) -> bool:\n    \"\"\"    GGML_API bool    ggml_is_quantized(enum ggml_type type);\"\"\"\n    ...\n  def ggml_is_transposed(tensor: ffi.CData) -> bool:\n    \"\"\"    GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_log(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_log(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_log_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_log_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_map_binary_f32(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(\n                struct ggml_context         * ctx,\n                struct ggml_tensor          * a,\n                struct ggml_tensor          * b,\n                       ggml_binary_op_f32_t   fun),\n            \"use ggml_map_custom2 instead\");\n    \"\"\"\n    ...\n  def ggml_map_binary_inplace_f32(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(\n                struct ggml_context         * ctx,\n                struct ggml_tensor          * a,\n                struct ggml_tensor          * b,\n                       ggml_binary_op_f32_t   fun),\n            \"use ggml_map_custom2_inplace instead\");\n    \"\"\"\n    ...\n  def ggml_map_custom1(ctx: ffi.CData, a: ffi.CData, fun: ffi.CData, n_tasks: int, userdata: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_map_custom1(\n                struct ggml_context   * ctx,\n                struct ggml_tensor    * a,\n                ggml_custom1_op_t       
fun,\n                int                     n_tasks,\n                void                  * userdata);\n    \"\"\"\n    ...\n  def ggml_map_custom1_f32(ctx: ffi.CData, a: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(\n                struct ggml_context          * ctx,\n                struct ggml_tensor           * a,\n                       ggml_custom1_op_f32_t   fun),\n            \"use ggml_map_custom1 instead\");\n    \"\"\"\n    ...\n  def ggml_map_custom1_inplace(ctx: ffi.CData, a: ffi.CData, fun: ffi.CData, n_tasks: int, userdata: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_map_custom1_inplace(\n                struct ggml_context   * ctx,\n                struct ggml_tensor    * a,\n                ggml_custom1_op_t       fun,\n                int                     n_tasks,\n                void                  * userdata);\n    \"\"\"\n    ...\n  def ggml_map_custom1_inplace_f32(ctx: ffi.CData, a: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(\n                struct ggml_context          * ctx,\n                struct ggml_tensor           * a,\n                       ggml_custom1_op_f32_t   fun),\n            \"use ggml_map_custom1_inplace instead\");\n    \"\"\"\n    ...\n  def ggml_map_custom2(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, fun: ffi.CData, n_tasks: int, userdata: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_map_custom2(\n                struct ggml_context   * ctx,\n                struct ggml_tensor    * a,\n                struct ggml_tensor    * b,\n                ggml_custom2_op_t       fun,\n                int                     n_tasks,\n                void                  * userdata);\n    \"\"\"\n    ...\n  def ggml_map_custom2_f32(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, 
fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(\n                struct ggml_context          * ctx,\n                struct ggml_tensor           * a,\n                struct ggml_tensor           * b,\n                       ggml_custom2_op_f32_t   fun),\n            \"use ggml_map_custom2 instead\");\n    \"\"\"\n    ...\n  def ggml_map_custom2_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, fun: ffi.CData, n_tasks: int, userdata: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_map_custom2_inplace(\n                struct ggml_context   * ctx,\n                struct ggml_tensor    * a,\n                struct ggml_tensor    * b,\n                ggml_custom2_op_t       fun,\n                int                     n_tasks,\n                void                  * userdata);\n    \"\"\"\n    ...\n  def ggml_map_custom2_inplace_f32(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(\n                struct ggml_context          * ctx,\n                struct ggml_tensor           * a,\n                struct ggml_tensor           * b,\n                       ggml_custom2_op_f32_t   fun),\n            \"use ggml_map_custom2_inplace instead\");\n    \"\"\"\n    ...\n  def ggml_map_custom3(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, c: ffi.CData, fun: ffi.CData, n_tasks: int, userdata: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_map_custom3(\n                struct ggml_context   * ctx,\n                struct ggml_tensor    * a,\n                struct ggml_tensor    * b,\n                struct ggml_tensor    * c,\n                ggml_custom3_op_t       fun,\n                int                     n_tasks,\n                void                  * userdata);\n    \"\"\"\n    ...\n  def 
ggml_map_custom3_f32(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, c: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(\n                struct ggml_context          * ctx,\n                struct ggml_tensor           * a,\n                struct ggml_tensor           * b,\n                struct ggml_tensor           * c,\n                       ggml_custom3_op_f32_t   fun),\n            \"use ggml_map_custom3 instead\");\n    \"\"\"\n    ...\n  def ggml_map_custom3_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, c: ffi.CData, fun: ffi.CData, n_tasks: int, userdata: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_map_custom3_inplace(\n                struct ggml_context   * ctx,\n                struct ggml_tensor    * a,\n                struct ggml_tensor    * b,\n                struct ggml_tensor    * c,\n                ggml_custom3_op_t       fun,\n                int                     n_tasks,\n                void                  * userdata);\n    \"\"\"\n    ...\n  def ggml_map_custom3_inplace_f32(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, c: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(\n                struct ggml_context          * ctx,\n                struct ggml_tensor           * a,\n                struct ggml_tensor           * b,\n                struct ggml_tensor           * c,\n                       ggml_custom3_op_f32_t   fun),\n            \"use ggml_map_custom3_inplace instead\");\n    \"\"\"\n    ...\n  def ggml_map_unary_f32(ctx: ffi.CData, a: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(\n                struct ggml_context        * ctx,\n                struct ggml_tensor         * a,\n                       ggml_unary_op_f32_t   fun),\n          
  \"use ggml_map_custom1 instead\");\n    \"\"\"\n    ...\n  def ggml_map_unary_inplace_f32(ctx: ffi.CData, a: ffi.CData, fun: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(\n                struct ggml_context        * ctx,\n                struct ggml_tensor         * a,\n                       ggml_unary_op_f32_t   fun),\n            \"use ggml_map_custom1_inplace instead\");\n    \"\"\"\n    ...\n  def ggml_mean(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    mean along rows\n\n        GGML_API struct ggml_tensor * ggml_mean(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_metal_add_buffer(ctx: ffi.CData, name: ffi.CData, data: ffi.CData, size: int, max_size: int) -> bool:\n    \"\"\"\n    creates a mapping between a host memory buffer and a device memory buffer\n    - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute\n    - the mapping is used during computation to determine the arguments of the compute kernels\n    - you don't need to keep the host memory buffer allocated as it is never accessed by Metal\n    - max_size specifies the maximum size of a tensor and is used to create shared views such\n    that it is guaranteed that the tensor will fit in at least one of the views\n\n\n    bool ggml_metal_add_buffer(\n            struct ggml_metal_context * ctx,\n                           const char * name,\n                                 void * data,\n                               size_t   size,\n                               size_t   max_size);\n    \"\"\"\n    ...\n  def ggml_metal_free(ctx: ffi.CData) -> None:\n    \"\"\"void ggml_metal_free(struct ggml_metal_context * ctx);\"\"\"\n    ...\n  def ggml_metal_get_concur_list(ctx: ffi.CData) -> ffi.CData:\n    \"\"\"\n    output the concur_list for ggml_alloc\n\n    int * ggml_metal_get_concur_list(struct 
ggml_metal_context * ctx);\n    \"\"\"\n    ...\n  def ggml_metal_get_tensor(ctx: ffi.CData, t: ffi.CData) -> None:\n    \"\"\"\n    get data from the device into host memory\n\n    void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);\n    \"\"\"\n    ...\n  def ggml_metal_graph_compute(ctx: ffi.CData, gf: ffi.CData) -> None:\n    \"\"\"\n    same as ggml_graph_compute but uses Metal\n    creates gf->n_threads command buffers in parallel\n\n    void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);\n    \"\"\"\n    ...\n  def ggml_metal_graph_find_concurrency(ctx: ffi.CData, gf: ffi.CData, check_mem: bool) -> None:\n    \"\"\"\n    try to find operations that can be run concurrently in the graph\n    you should run it again if the topology of your graph changes\n\n    void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);\n    \"\"\"\n    ...\n  def ggml_metal_host_free(data: ffi.CData) -> None:\n    \"\"\"void   ggml_metal_host_free  (void * data);\"\"\"\n    ...\n  def ggml_metal_host_malloc(n: int) -> ffi.CData:\n    \"\"\"void * ggml_metal_host_malloc(size_t n);\"\"\"\n    ...\n  def ggml_metal_if_optimized(ctx: ffi.CData) -> int:\n    \"\"\"\n    if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized\n\n    int ggml_metal_if_optimized(struct ggml_metal_context * ctx);\n    \"\"\"\n    ...\n  def ggml_metal_init(n_cb: int) -> ffi.CData:\n    \"\"\"\n    number of command buffers to use\n\n    struct ggml_metal_context * ggml_metal_init(int n_cb);\n    \"\"\"\n    ...\n  def ggml_metal_set_n_cb(ctx: ffi.CData, n_cb: int) -> None:\n    \"\"\"\n    set the number of command buffers to use\n\n    void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);\n    \"\"\"\n    ...\n  def ggml_metal_set_tensor(ctx: ffi.CData, t: ffi.CData) -> None:\n    \"\"\"\n    set data from host memory 
into the device\n\n    void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);\n    \"\"\"\n    ...\n  def ggml_mpi_backend_free() -> None:\n    \"\"\"void ggml_mpi_backend_free(void);\"\"\"\n    ...\n  def ggml_mpi_backend_init() -> None:\n    \"\"\"void ggml_mpi_backend_init(void);\"\"\"\n    ...\n  def ggml_mpi_eval_init(ctx_mpi: ffi.CData, n_tokens: ffi.CData, n_past: ffi.CData, n_threads: ffi.CData) -> None:\n    \"\"\"\n    void ggml_mpi_eval_init(\n            struct ggml_mpi_context * ctx_mpi,\n                                int * n_tokens,\n                                int * n_past,\n                                int * n_threads);\n    \"\"\"\n    ...\n  def ggml_mpi_free(ctx: ffi.CData) -> None:\n    \"\"\"void ggml_mpi_free(struct ggml_mpi_context * ctx);\"\"\"\n    ...\n  def ggml_mpi_graph_compute_post(ctx_mpi: ffi.CData, gf: ffi.CData, n_layers: int) -> None:\n    \"\"\"\n    void ggml_mpi_graph_compute_post(\n            struct ggml_mpi_context * ctx_mpi,\n                 struct ggml_cgraph * gf,\n                                int   n_layers);\n    \"\"\"\n    ...\n  def ggml_mpi_graph_compute_pre(ctx_mpi: ffi.CData, gf: ffi.CData, n_layers: int) -> None:\n    \"\"\"\n    void ggml_mpi_graph_compute_pre(\n            struct ggml_mpi_context * ctx_mpi,\n                 struct ggml_cgraph * gf,\n                                int   n_layers);\n    \"\"\"\n    ...\n  def ggml_mpi_init() -> ffi.CData:\n    \"\"\"struct ggml_mpi_context * ggml_mpi_init(void);\"\"\"\n    ...\n  def ggml_mpi_rank(ctx: ffi.CData) -> int:\n    \"\"\"int ggml_mpi_rank(struct ggml_mpi_context * ctx);\"\"\"\n    ...\n  def ggml_mul(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_mul(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_mul_inplace(ctx: ffi.CData, a: 
ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_mul_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_mul_mat(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    A: n columns, m rows\n    B: n columns, p rows  (i.e. we transpose it internally)\n    result is m columns, p rows\n\n        GGML_API struct ggml_tensor * ggml_mul_mat(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_nbytes(tensor: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t  ggml_nbytes      (const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_nbytes_pad(tensor: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t  ggml_nbytes_pad  (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN\"\"\"\n    ...\n  def ggml_nbytes_split(tensor: ffi.CData, nrows_split: int) -> int:\n    \"\"\"    GGML_API size_t  ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);\"\"\"\n    ...\n  def ggml_neg(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_neg(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_neg_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_neg_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_nelements(tensor: ffi.CData) -> int:\n    \"\"\"    GGML_API int64_t ggml_nelements   (const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_new_f32(ctx: ffi.CData, value: float) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float 
value);\"\"\"\n    ...\n  def ggml_new_graph(ctx: ffi.CData) -> ffi.CData:\n    \"\"\"\n    graph allocation in a context\n\n        GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx);\n    \"\"\"\n    ...\n  def ggml_new_i32(ctx: ffi.CData, value: int) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);\"\"\"\n    ...\n  def ggml_new_tensor(ctx: ffi.CData, type: int, n_dims: int, ne: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_new_tensor(\n                struct ggml_context * ctx,\n                enum   ggml_type type,\n                int    n_dims,\n                const int64_t *ne);\n    \"\"\"\n    ...\n  def ggml_new_tensor_1d(ctx: ffi.CData, type: int, ne0: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_new_tensor_1d(\n                struct ggml_context * ctx,\n                enum   ggml_type type,\n                int64_t ne0);\n    \"\"\"\n    ...\n  def ggml_new_tensor_2d(ctx: ffi.CData, type: int, ne0: int, ne1: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_new_tensor_2d(\n                struct ggml_context * ctx,\n                enum   ggml_type type,\n                int64_t ne0,\n                int64_t ne1);\n    \"\"\"\n    ...\n  def ggml_new_tensor_3d(ctx: ffi.CData, type: int, ne0: int, ne1: int, ne2: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_new_tensor_3d(\n                struct ggml_context * ctx,\n                enum   ggml_type type,\n                int64_t ne0,\n                int64_t ne1,\n                int64_t ne2);\n    \"\"\"\n    ...\n  def ggml_new_tensor_4d(ctx: ffi.CData, type: int, ne0: int, ne1: int, ne2: int, ne3: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_new_tensor_4d(\n                struct ggml_context * ctx,\n                enum   ggml_type type,\n                
int64_t ne0,\n                int64_t ne1,\n                int64_t ne2,\n                int64_t ne3);\n    \"\"\"\n    ...\n  def ggml_norm(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    normalize along rows\n    TODO: eps is hardcoded to 1e-5 for now\n\n        GGML_API struct ggml_tensor * ggml_norm(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_norm_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_norm_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_nrows(tensor: ffi.CData) -> int:\n    \"\"\"    GGML_API int64_t ggml_nrows       (const struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_numa_init() -> None:\n    \"\"\"    GGML_API void    ggml_numa_init(void); // call once for better performance on NUMA systems\"\"\"\n    ...\n  def ggml_op_name(op: int) -> ffi.CData:\n    \"\"\"    GGML_API const char * ggml_op_name  (enum ggml_op   op);\"\"\"\n    ...\n  def ggml_op_symbol(op: int) -> ffi.CData:\n    \"\"\"    GGML_API const char * ggml_op_symbol(enum ggml_op   op);\"\"\"\n    ...\n  def ggml_opt(ctx: ffi.CData, params: ffi.CData, f: ffi.CData) -> int:\n    \"\"\"\n    optimize the function defined by the tensor f\n\n        GGML_API enum ggml_opt_result ggml_opt(\n                struct ggml_context * ctx,\n                struct ggml_opt_params params,\n                struct ggml_tensor * f);\n    \"\"\"\n    ...\n  def ggml_opt_default_params(type: int) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);\"\"\"\n    ...\n  def ggml_opt_init(ctx: ffi.CData, opt: ffi.CData, params: ffi.CData, nx: int) -> None:\n    \"\"\"\n    initialize optimizer context\n\n        GGML_API void ggml_opt_init(\n                struct ggml_context * ctx,\n                struct 
ggml_opt_context * opt,\n                struct ggml_opt_params params,\n                int64_t nx);\n    \"\"\"\n    ...\n  def ggml_opt_resume(ctx: ffi.CData, opt: ffi.CData, f: ffi.CData) -> int:\n    \"\"\"\n    continue optimizing the function defined by the tensor f\n\n        GGML_API enum ggml_opt_result ggml_opt_resume(\n                struct ggml_context * ctx,\n                struct ggml_opt_context * opt,\n                struct ggml_tensor * f);\n    \"\"\"\n    ...\n  def ggml_opt_resume_g(ctx: ffi.CData, opt: ffi.CData, f: ffi.CData, gf: ffi.CData, gb: ffi.CData) -> int:\n    \"\"\"\n    continue optimizing the function defined by the tensor f\n\n        GGML_API enum ggml_opt_result ggml_opt_resume_g(\n                struct ggml_context * ctx,\n                struct ggml_opt_context * opt,\n                struct ggml_tensor * f,\n                struct ggml_cgraph * gf,\n                struct ggml_cgraph * gb);\n    \"\"\"\n    ...\n  def ggml_out_prod(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    A: m columns, n rows,\n    B: p columns, n rows,\n    result is m columns, p rows\n\n        GGML_API struct ggml_tensor * ggml_out_prod(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_permute(ctx: ffi.CData, a: ffi.CData, axis0: int, axis1: int, axis2: int, axis3: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_permute(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   axis0,\n                int                   axis1,\n                int                   axis2,\n                int                   axis3);\n    \"\"\"\n    ...\n  def ggml_pool_1d(ctx: ffi.CData, a: ffi.CData, op: int, k0: int, s0: int, p0: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_pool_1d(\n            
    struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                enum ggml_op_pool     op,\n                int                   k0, // kernel size\n                int                   s0, // stride\n                int                   p0); // padding\n    \"\"\"\n    ...\n  def ggml_pool_2d(ctx: ffi.CData, a: ffi.CData, op: int, k0: int, k1: int, s0: int, s1: int, p0: int, p1: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_pool_2d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                enum ggml_op_pool     op,\n                int                   k0,\n                int                   k1,\n                int                   s0,\n                int                   s1,\n                int                   p0,\n                int                   p1);\n    \"\"\"\n    ...\n  def ggml_print_object(obj: ffi.CData) -> None:\n    \"\"\"    GGML_API void    ggml_print_object (const struct ggml_object * obj);\"\"\"\n    ...\n  def ggml_print_objects(ctx: ffi.CData) -> None:\n    \"\"\"    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);\"\"\"\n    ...\n  def ggml_quantize_chunk(type: int, src: ffi.CData, dst: ffi.CData, start: int, n: int, hist: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q2_K(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"\n    Quantization with histogram collection\n\n    size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);\n    \"\"\"\n    ...\n  def ggml_quantize_q3_K(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q4_0(src: ffi.CData, 
dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q4_1(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q4_K(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q5_0(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q5_1(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q5_K(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q6_K(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_quantize_q8_0(src: ffi.CData, dst: ffi.CData, n: int, k: int, hist: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);\"\"\"\n    ...\n  def ggml_relu(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_relu(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_relu_inplace(ctx: ffi.CData, 
a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_relu_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_repeat(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    if a is the same shape as b, and a is not parameter, return a\n    otherwise, return a new tensor: repeat(a) to fit in b\n\n        GGML_API struct ggml_tensor * ggml_repeat(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_repeat_back(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_repeat_back(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_reshape(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    return view(a), b specifies the new shape\n    TODO: when we start computing gradient, make a copy instead of view\n\n        GGML_API struct ggml_tensor * ggml_reshape(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_reshape_1d(ctx: ffi.CData, a: ffi.CData, ne0: int) -> ffi.CData:\n    \"\"\"\n    return view(a)\n    TODO: when we start computing gradient, make a copy instead of view\n\n        GGML_API struct ggml_tensor * ggml_reshape_1d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int64_t               ne0);\n    \"\"\"\n    ...\n  def ggml_reshape_2d(ctx: ffi.CData, a: ffi.CData, ne0: int, ne1: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_reshape_2d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int64_t  
             ne0,\n                int64_t               ne1);\n    \"\"\"\n    ...\n  def ggml_reshape_3d(ctx: ffi.CData, a: ffi.CData, ne0: int, ne1: int, ne2: int) -> ffi.CData:\n    \"\"\"\n    return view(a)\n    TODO: when we start computing gradient, make a copy instead of view\n\n        GGML_API struct ggml_tensor * ggml_reshape_3d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int64_t               ne0,\n                int64_t               ne1,\n                int64_t               ne2);\n    \"\"\"\n    ...\n  def ggml_reshape_4d(ctx: ffi.CData, a: ffi.CData, ne0: int, ne1: int, ne2: int, ne3: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_reshape_4d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int64_t               ne0,\n                int64_t               ne1,\n                int64_t               ne2,\n                int64_t               ne3);\n    \"\"\"\n    ...\n  def ggml_rms_norm(ctx: ffi.CData, a: ffi.CData, eps: float) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_rms_norm(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                float                 eps);\n    \"\"\"\n    ...\n  def ggml_rms_norm_back(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    a - x\n    b - dy\n    TODO: update with configurable eps\n\n        GGML_API struct ggml_tensor * ggml_rms_norm_back(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_rms_norm_inplace(ctx: ffi.CData, a: ffi.CData, eps: float) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_rms_norm_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                float                 eps);\n 
   \"\"\"\n    ...\n  def ggml_rope(ctx: ffi.CData, a: ffi.CData, n_past: int, n_dims: int, mode: int, n_ctx: int) -> ffi.CData:\n    \"\"\"\n    rotary position embedding\n    if (mode & 1) != 0, skip n_past elements\n    if (mode & 2) != 0, GPT-NeoX style\n    if (mode & 4) != 0, ChatGLM style\n    TODO: avoid creating a new tensor every time\n\n        GGML_API struct ggml_tensor * ggml_rope(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past,\n                int                   n_dims,\n                int                   mode,\n                int                   n_ctx);\n    \"\"\"\n    ...\n  def ggml_rope_back(ctx: ffi.CData, a: ffi.CData, n_past: int, n_dims: int, mode: int, n_ctx: int) -> ffi.CData:\n    \"\"\"\n    rotary position embedding backward, i.e. compute dx from dy\n    a - dy\n\n        GGML_API struct ggml_tensor * ggml_rope_back(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past,\n                int                   n_dims,\n                int                   mode,\n                int                   n_ctx);\n    \"\"\"\n    ...\n  def ggml_rope_custom(ctx: ffi.CData, a: ffi.CData, n_past: int, n_dims: int, mode: int, n_ctx: int, freq_base: float, freq_scale: float) -> ffi.CData:\n    \"\"\"\n    custom RoPE\n\n        GGML_API struct ggml_tensor * ggml_rope_custom(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past,\n                int                   n_dims,\n                int                   mode,\n                int                   n_ctx,\n                float                 freq_base,\n                float                 freq_scale);\n    \"\"\"\n    ...\n  def ggml_rope_custom_inplace(ctx: ffi.CData, a: ffi.CData, n_past: int, n_dims: int, mode: int, n_ctx: int, freq_base: 
float, freq_scale: float) -> ffi.CData:\n    \"\"\"\n    in-place, returns view(a)\n\n        GGML_API struct ggml_tensor * ggml_rope_custom_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past,\n                int                   n_dims,\n                int                   mode,\n                int                   n_ctx,\n                float                 freq_base,\n                float                 freq_scale);\n    \"\"\"\n    ...\n  def ggml_rope_inplace(ctx: ffi.CData, a: ffi.CData, n_past: int, n_dims: int, mode: int, n_ctx: int) -> ffi.CData:\n    \"\"\"\n    in-place, returns view(a)\n\n        GGML_API struct ggml_tensor * ggml_rope_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   n_past,\n                int                   n_dims,\n                int                   mode,\n                int                   n_ctx);\n    \"\"\"\n    ...\n  def ggml_scale(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_scale(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_scale_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    in-place, returns view(a)\n\n        GGML_API struct ggml_tensor * ggml_scale_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_set(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, nb1: int, nb2: int, nb3: int, offset: int) -> ffi.CData:\n    \"\"\"\n    b -> view(a,offset,nb1,nb2,3), return modified a\n\n        GGML_API struct ggml_tensor * ggml_set(\n                struct ggml_context * ctx,\n                struct ggml_tensor  
* a,\n                struct ggml_tensor  * b,\n                size_t                nb1,\n                size_t                nb2,\n                size_t                nb3,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_set_1d(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, offset: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_set_1d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_set_1d_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, offset: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_set_1d_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_set_2d(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, nb1: int, offset: int) -> ffi.CData:\n    \"\"\"\n    b -> view(a,offset,nb1,nb2,3), return modified a\n\n        GGML_API struct ggml_tensor * ggml_set_2d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                size_t                nb1,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_set_2d_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, nb1: int, offset: int) -> ffi.CData:\n    \"\"\"\n    b -> view(a,offset,nb1,nb2,3), return view(a)\n\n        GGML_API struct ggml_tensor * ggml_set_2d_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                size_t                nb1,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_set_f32(tensor: ffi.CData, value: float) -> ffi.CData:\n    \"\"\"    GGML_API struct 
ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);\"\"\"\n    ...\n  def ggml_set_f32_1d(tensor: ffi.CData, i: int, value: float) -> None:\n    \"\"\"    GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);\"\"\"\n    ...\n  def ggml_set_i32(tensor: ffi.CData, value: int) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);\"\"\"\n    ...\n  def ggml_set_i32_1d(tensor: ffi.CData, i: int, value: int) -> None:\n    \"\"\"    GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);\"\"\"\n    ...\n  def ggml_set_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData, nb1: int, nb2: int, nb3: int, offset: int) -> ffi.CData:\n    \"\"\"\n    b -> view(a,offset,nb1,nb2,3), return view(a)\n\n        GGML_API struct ggml_tensor * ggml_set_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                size_t                nb1,\n                size_t                nb2,\n                size_t                nb3,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_set_name(tensor: ffi.CData, name: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);\"\"\"\n    ...\n  def ggml_set_no_alloc(ctx: ffi.CData, no_alloc: bool) -> None:\n    \"\"\"    GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);\"\"\"\n    ...\n  def ggml_set_param(ctx: ffi.CData, tensor: ffi.CData) -> None:\n    \"\"\"\n        GGML_API void ggml_set_param(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * tensor);\n    \"\"\"\n    ...\n  def ggml_set_scratch(ctx: ffi.CData, scratch: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct 
ggml_scratch scratch);\"\"\"\n    ...\n  def ggml_set_zero(tensor: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);\"\"\"\n    ...\n  def ggml_sgn(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_sgn(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_sgn_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_sgn_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_silu(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_silu(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_silu_back(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    a - x\n    b - dy\n\n        GGML_API struct ggml_tensor * ggml_silu_back(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_silu_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_silu_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_soft_max(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_soft_max(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_soft_max_back(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_soft_max_back(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    
\"\"\"\n    ...\n  def ggml_soft_max_back_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n    in-place, returns view(a)\n\n        GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_soft_max_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    in-place, returns view(a)\n\n        GGML_API struct ggml_tensor * ggml_soft_max_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_sqr(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_sqr(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_sqr_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_sqr_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_sqrt(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_sqrt(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_sqrt_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_sqrt_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_step(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_step(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_step_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_step_inplace(\n                
struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_sub(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_sub(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_sub_inplace(ctx: ffi.CData, a: ffi.CData, b: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_sub_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b);\n    \"\"\"\n    ...\n  def ggml_sum(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    return scalar\n\n        GGML_API struct ggml_tensor * ggml_sum(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_sum_rows(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]\n\n        GGML_API struct ggml_tensor * ggml_sum_rows(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_tanh(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_tanh(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_tanh_inplace(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_tanh_inplace(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_tensor_overhead() -> int:\n    \"\"\"\n    use this to compute the memory overhead of a tensor\n\n        GGML_API size_t ggml_tensor_overhead(void);\n    \"\"\"\n    ...\n  def ggml_time_init() -> None:\n    \"\"\"    GGML_API void    ggml_time_init(void); // call 
this once at the beginning of the program\"\"\"\n    ...\n  def ggml_time_ms() -> int:\n    \"\"\"    GGML_API int64_t ggml_time_ms(void);\"\"\"\n    ...\n  def ggml_time_us() -> int:\n    \"\"\"    GGML_API int64_t ggml_time_us(void);\"\"\"\n    ...\n  def ggml_transpose(ctx: ffi.CData, a: ffi.CData) -> ffi.CData:\n    \"\"\"\n    alias for ggml_permute(ctx, a, 1, 0, 2, 3)\n\n        GGML_API struct ggml_tensor * ggml_transpose(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a);\n    \"\"\"\n    ...\n  def ggml_type_name(type: int) -> ffi.CData:\n    \"\"\"    GGML_API const char * ggml_type_name(enum ggml_type type);\"\"\"\n    ...\n  def ggml_type_size(type: int) -> int:\n    \"\"\"    GGML_API size_t  ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block\"\"\"\n    ...\n  def ggml_type_sizef(type: int) -> float:\n    \"\"\"    GGML_API float   ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float\"\"\"\n    ...\n  def ggml_unary(ctx: ffi.CData, a: ffi.CData, op: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_unary(\n                struct ggml_context * ctx,\n                 struct ggml_tensor * a,\n                 enum ggml_unary_op op);\n    \"\"\"\n    ...\n  def ggml_unary_inplace(ctx: ffi.CData, a: ffi.CData, op: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_unary_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            enum ggml_unary_op op);\n    \"\"\"\n    ...\n  def ggml_used_mem(ctx: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);\"\"\"\n    ...\n  def ggml_vec_dot_q2_K_q8_K(n: int, s: ffi.CData, vx: ffi.CData, vy: ffi.CData) -> None:\n    \"\"\"\n    Dot product\n\n    void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);\n    \"\"\"\n    ...\n  def 
ggml_vec_dot_q3_K_q8_K(n: int, s: ffi.CData, vx: ffi.CData, vy: ffi.CData) -> None:\n    \"\"\"void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);\"\"\"\n    ...\n  def ggml_vec_dot_q4_K_q8_K(n: int, s: ffi.CData, vx: ffi.CData, vy: ffi.CData) -> None:\n    \"\"\"void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);\"\"\"\n    ...\n  def ggml_vec_dot_q5_K_q8_K(n: int, s: ffi.CData, vx: ffi.CData, vy: ffi.CData) -> None:\n    \"\"\"void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);\"\"\"\n    ...\n  def ggml_vec_dot_q6_K_q8_K(n: int, s: ffi.CData, vx: ffi.CData, vy: ffi.CData) -> None:\n    \"\"\"void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);\"\"\"\n    ...\n  def ggml_view_1d(ctx: ffi.CData, a: ffi.CData, ne0: int, offset: int) -> ffi.CData:\n    \"\"\"\n    offset in bytes\n\n        GGML_API struct ggml_tensor * ggml_view_1d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int64_t               ne0,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_view_2d(ctx: ffi.CData, a: ffi.CData, ne0: int, ne1: int, nb1: int, offset: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_view_2d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int64_t               ne0,\n                int64_t               ne1,\n                size_t                nb1, // row stride in bytes\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_view_3d(ctx: ffi.CData, a: ffi.CData, ne0: int, ne1: int, ne2: int, nb1: int, nb2: int, offset: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_view_3d(\n                struct ggml_context * ctx,\n              
  struct ggml_tensor  * a,\n                int64_t               ne0,\n                int64_t               ne1,\n                int64_t               ne2,\n                size_t                nb1, // row   stride in bytes\n                size_t                nb2, // slice stride in bytes\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_view_4d(ctx: ffi.CData, a: ffi.CData, ne0: int, ne1: int, ne2: int, ne3: int, nb1: int, nb2: int, nb3: int, offset: int) -> ffi.CData:\n    \"\"\"\n        GGML_API struct ggml_tensor * ggml_view_4d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int64_t               ne0,\n                int64_t               ne1,\n                int64_t               ne2,\n                int64_t               ne3,\n                size_t                nb1, // row   stride in bytes\n                size_t                nb2, // slice stride in bytes\n                size_t                nb3,\n                size_t                offset);\n    \"\"\"\n    ...\n  def ggml_view_tensor(ctx: ffi.CData, src: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);\"\"\"\n    ...\n  def ggml_win_part(ctx: ffi.CData, a: ffi.CData, w: int) -> ffi.CData:\n    \"\"\"\n    partition into non-overlapping windows with padding if needed\n    example:\n    a:   768   64   64    1\n    w:    14\n    res: 768   14   14    25\n    used in sam\n\n        GGML_API struct ggml_tensor * ggml_win_part(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   w);\n    \"\"\"\n    ...\n  def ggml_win_unpart(ctx: ffi.CData, a: ffi.CData, w0: int, h0: int, w: int) -> ffi.CData:\n    \"\"\"\n    reverse of ggml_win_part\n    used in sam\n\n        GGML_API struct ggml_tensor * ggml_win_unpart(\n                struct 
ggml_context * ctx,\n                struct ggml_tensor  * a,\n                int                   w0,\n                int                   h0,\n                int                   w);\n    \"\"\"\n    ...\n  def gguf_add_tensor(ctx: ffi.CData, tensor: ffi.CData) -> None:\n    \"\"\"\n    manage tensor info\n\n        GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);\n    \"\"\"\n    ...\n  def gguf_find_key(ctx: ffi.CData, key: ffi.CData) -> int:\n    \"\"\"    GGML_API int          gguf_find_key(struct gguf_context * ctx, const char * key);\"\"\"\n    ...\n  def gguf_find_tensor(ctx: ffi.CData, name: ffi.CData) -> int:\n    \"\"\"    GGML_API int    gguf_find_tensor      (struct gguf_context * ctx, const char * name);\"\"\"\n    ...\n  def gguf_free(ctx: ffi.CData) -> None:\n    \"\"\"    GGML_API void gguf_free(struct gguf_context * ctx);\"\"\"\n    ...\n  def gguf_get_alignment(ctx: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t gguf_get_alignment  (struct gguf_context * ctx);\"\"\"\n    ...\n  def gguf_get_arr_data(ctx: ffi.CData, i: int) -> ffi.CData:\n    \"\"\"    GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_arr_n(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API int          gguf_get_arr_n   (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_arr_str(ctx: ffi.CData, key_id: int, i: int) -> ffi.CData:\n    \"\"\"    GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i);\"\"\"\n    ...\n  def gguf_get_arr_type(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_data(ctx: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API void * gguf_get_data       (struct gguf_context * ctx);\"\"\"\n    ...\n  def gguf_get_data_offset(ctx: ffi.CData) -> int:\n    \"\"\"    GGML_API size_t 
gguf_get_data_offset(struct gguf_context * ctx);\"\"\"\n    ...\n  def gguf_get_key(ctx: ffi.CData, i: int) -> ffi.CData:\n    \"\"\"    GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_kv_type(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_meta_data(ctx: ffi.CData, data: ffi.CData) -> None:\n    \"\"\"    GGML_API void   gguf_get_meta_data(struct gguf_context * ctx, void * data);\"\"\"\n    ...\n  def gguf_get_meta_size(ctx: ffi.CData) -> int:\n    \"\"\"\n    get the size in bytes of the meta data (header, kv pairs, tensor info) including padding\n\n        GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx);\n    \"\"\"\n    ...\n  def gguf_get_n_kv(ctx: ffi.CData) -> int:\n    \"\"\"    GGML_API int          gguf_get_n_kv(struct gguf_context * ctx);\"\"\"\n    ...\n  def gguf_get_n_tensors(ctx: ffi.CData) -> int:\n    \"\"\"    GGML_API int    gguf_get_n_tensors    (struct gguf_context * ctx);\"\"\"\n    ...\n  def gguf_get_tensor_name(ctx: ffi.CData, i: int) -> ffi.CData:\n    \"\"\"    GGML_API char * gguf_get_tensor_name  (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_tensor_offset(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_bool(ctx: ffi.CData, i: int) -> bool:\n    \"\"\"    GGML_API bool         gguf_get_val_bool(struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_f32(ctx: ffi.CData, i: int) -> float:\n    \"\"\"    GGML_API float        gguf_get_val_f32 (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_i16(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API int16_t      gguf_get_val_i16 (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_i32(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API 
int32_t      gguf_get_val_i32 (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_i8(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API int8_t       gguf_get_val_i8  (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_str(ctx: ffi.CData, i: int) -> ffi.CData:\n    \"\"\"    GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_u16(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API uint16_t     gguf_get_val_u16 (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_u32(ctx: ffi.CData, i: int) -> int:\n    \"\"\"    GGML_API uint32_t     gguf_get_val_u32 (struct gguf_context * ctx, int i);\"\"\"\n    ...\n  def gguf_get_val_u8(ctx: ffi.CData, i: int) -> int:\n    \"\"\"\n    results are undefined if the wrong type is used for the key\n\n        GGML_API uint8_t      gguf_get_val_u8  (struct gguf_context * ctx, int i);\n    \"\"\"\n    ...\n  def gguf_get_version(ctx: ffi.CData) -> int:\n    \"\"\"    GGML_API int    gguf_get_version    (struct gguf_context * ctx);\"\"\"\n    ...\n  def gguf_init_empty() -> ffi.CData:\n    \"\"\"    GGML_API struct gguf_context * gguf_init_empty(void);\"\"\"\n    ...\n  def gguf_init_from_file(fname: ffi.CData, params: ffi.CData) -> ffi.CData:\n    \"\"\"    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);\"\"\"\n    ...\n  def gguf_set_arr_data(ctx: ffi.CData, key: ffi.CData, type: int, data: ffi.CData, n: int) -> None:\n    \"\"\"    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);\"\"\"\n    ...\n  def gguf_set_arr_str(ctx: ffi.CData, key: ffi.CData, data: ffi.CData, n: int) -> None:\n    \"\"\"    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);\"\"\"\n    ...\n  def gguf_set_kv(ctx: ffi.CData, src: ffi.CData) -> None:\n    
\"\"\"\n    set or add KV pairs from another context\n\n        GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);\n    \"\"\"\n    ...\n  def gguf_set_tensor_data(ctx: ffi.CData, name: ffi.CData, data: ffi.CData, size: int) -> None:\n    \"\"\"    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);\"\"\"\n    ...\n  def gguf_set_tensor_type(ctx: ffi.CData, name: ffi.CData, type: int) -> None:\n    \"\"\"    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);\"\"\"\n    ...\n  def gguf_set_val_bool(ctx: ffi.CData, key: ffi.CData, val: bool) -> None:\n    \"\"\"    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);\"\"\"\n    ...\n  def gguf_set_val_f32(ctx: ffi.CData, key: ffi.CData, val: float) -> None:\n    \"\"\"    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);\"\"\"\n    ...\n  def gguf_set_val_i16(ctx: ffi.CData, key: ffi.CData, val: int) -> None:\n    \"\"\"    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);\"\"\"\n    ...\n  def gguf_set_val_i32(ctx: ffi.CData, key: ffi.CData, val: int) -> None:\n    \"\"\"    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);\"\"\"\n    ...\n  def gguf_set_val_i8(ctx: ffi.CData, key: ffi.CData, val: int) -> None:\n    \"\"\"    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);\"\"\"\n    ...\n  def gguf_set_val_str(ctx: ffi.CData, key: ffi.CData, val: ffi.CData) -> None:\n    \"\"\"    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);\"\"\"\n    ...\n  def gguf_set_val_u16(ctx: ffi.CData, key: ffi.CData, val: int) -> None:\n    \"\"\"    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t 
val);\"\"\"\n    ...\n  def gguf_set_val_u32(ctx: ffi.CData, key: ffi.CData, val: int) -> None:\n    \"\"\"    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);\"\"\"\n    ...\n  def gguf_set_val_u8(ctx: ffi.CData, key: ffi.CData, val: int) -> None:\n    \"\"\"\n    overrides existing values or adds a new one\n\n        GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);\n    \"\"\"\n    ...\n  def gguf_type_name(type: int) -> ffi.CData:\n    \"\"\"    GGML_API const char * gguf_type_name(enum gguf_type type);\"\"\"\n    ...\n  def gguf_write_to_file(ctx: ffi.CData, fname: ffi.CData, only_meta: bool) -> None:\n    \"\"\"\n    write the entire context to a binary file\n\n        GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta);\n    \"\"\"\n    ...\n  def quantize_row_q2_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q2_K_reference(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"\n    Quantization\n\n    void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);\n    \"\"\"\n    ...\n  def quantize_row_q3_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q3_K_reference(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q4_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q4_K_reference(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * 
restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q5_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q5_K_reference(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q6_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q6_K_reference(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q8_K(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);\"\"\"\n    ...\n  def quantize_row_q8_K_reference(x: ffi.CData, y: ffi.CData, k: int) -> None:\n    \"\"\"void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);\"\"\"\n    ...\n"
  },
  {
    "path": "examples/python/ggml/cffi.py",
    "content": "# auto-generated file\nimport _cffi_backend\n\nffi = _cffi_backend.FFI('ggml.cffi',\n    _version = 0x2601,\n    _types = b'\\x00\\x00\\xB6\\x0D\\x00\\x00\\x09\\x0B\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x04\\x2F\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x04\\x31\\x03\\x00\\x04\\x3D\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x04\\x32\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x04\\x34\\x03\\x00\\x03\\xFE\\x03\\x00\\x04\\x53\\x03\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x04\\x3D\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x04\\x3E\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\xB6\\x0D\\x00\\x00\\x00\\x0F\\x00\\x02\\xD0\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x0F\\x0D\\x00\\x00\\x04\\x0B\\x00\\x00\\x00\\x0F\\x00\\x00\\x0F\\x0D\\x00\\x00\\x01\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x0F\\x0D\\x00\\x00\\x0B\\x0B\\x00\\x00\\x00\\x0F\\x00\\x00\\x0F\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x0F\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x0F\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x16\\x0D\\x00\\x00\\x0B\\x11\\x00\\x04\\x38\\x03\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x16\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x44\\x11\\x00\\x00\\x08\\x11\\x00\\x04\\x30\\x03\\x00\\x00\\x4B\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x16\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x20\\x09\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x01\\x0D\\x00\\x00\\x01\\x0B\\x00\\x00\\x00\\x0F\\x00\\x01\\x14\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x34\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F
\\x00\\x02\\x7E\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\xF4\\x0D\\x00\\x00\\x01\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\xF4\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\xF4\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\xF4\\x0D\\x00\\x00\\x06\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x18\\x0D\\x00\\x00\\x01\\x11\\x00\\x00\\x00\\x0F\\x00\\x02\\xE9\\x0D\\x00\\x00\\x0E\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x00\\x01\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x00\\x4B\\x11\\x00\\x04\\x33\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x00\\x0E\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x04\\x35\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x22\\x0D\\x00\\x00\\x00\\x0F\\x00\\x00\\xDB\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\xDB\\x0D\\x00\\x00\\x00\\x0F\\x00\\x03\\xB0\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x03\\xB5\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x04\\x0D\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x04\\x0D\\x00\\x00\\x10\\x11\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x4B\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x4B\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x30\\x0D\\x00\\x00\\x0F\\x11\\x00\\x00\\x0B\\x03\\x00\\x00\\xB0\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x30\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x4B\\x11\\x00\\x00\\x01\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x30\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x0B\\x0D\\x00\\x00\\x1B\\x09\\x00\\x00\\x00\\x0F\\x00\\x04\\x33\\x0D\\x00\\x00\\x4B\\x11\\x00\\x00\\x07\\x01
\\x00\\x00\\x00\\x0F\\x00\\x00\\x0E\\x0D\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x7F\\x0D\\x00\\x00\\x00\\x0F\\x00\\x00\\x50\\x0D\\x00\\x00\\x07\\x0B\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x4B\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x01\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\xDB\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x01\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x01\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x01\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x01\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x05\\x0B\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x01\\x01\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0A\\x0B\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0D\\x01\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01
\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x0D\\x01\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0B\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F
\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x01\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x01\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x03\\x5C\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x03\\x62\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x02\\xD8\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D
\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x03\\x4F\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x08\\x11\\x00\\x03\\x54\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x02\\xD3\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x03\\x44\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x03\\x48\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x01\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x08\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x21\\x0D\\x00\\x00\\x0F\\x11\\x00\\x00\\x24\\x09\\x00\\x00\\x00\\x0F\\x00\\x00\\x21\\x0D\\x00\\x00\\x00\\x0F\\x00\\x03\\xBA\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x03\\xBF\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x01\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x01\\x11\\x00\\x00\\xF4\\x03\\x00\\x00\\x10\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\xDB\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x02\\x35\\x11\\x00\\x00\\x10\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x02\\x39\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x04\\x11\\x00\\x00\\x4B\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x21\\x09\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x04\\x32\\x03\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D
\\x00\\x00\\x15\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x11\\x0D\\x00\\x00\\x00\\x0F\\x00\\x00\\x6C\\x0D\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x6C\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x00\\x10\\x0D\\x00\\x02\\x4B\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x10\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x10\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x00\\x0F\\x00\\x00\\x10\\x0D\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x02\\xE1\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x01\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x03\\xF8\\x03\\x00\\x00\\xF4\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x03\\xF9\\x03\\x00\\x02\\x7E\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x03\\xFA\\x03\\x00\\x02\\x7E\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x03\\xFB\\x03\\x00\\x02\\x7E\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x03\\xFC\\x03\\x00\\x02\\x7E\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x03\\xFD\\x03\\x00\\x02\\x7E\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0F\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x0F\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x03\\xF8\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x03\\xF9\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x03\\xFA\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x03\\xFB\\x03
\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x03\\xFC\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x03\\xFD\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x00\\x6C\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\x35\\x11\\x00\\x00\\x10\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x07\\x01\\x00\\x03\\xFE\\x03\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x07\\x01\\x00\\x02\\x7E\\x11\\x00\\x02\\x35\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x07\\x01\\x00\\x02\\x7E\\x11\\x00\\x02\\x35\\x11\\x00\\x02\\x35\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x07\\x01\\x00\\x02\\x7E\\x11\\x00\\x04\\x53\\x03\\x00\\x02\\xE1\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x04\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x04\\x11\\x00\\x00\\x22\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x04\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x4B\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x4B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x04\\x30\\x03\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\xF8\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\xF8\\x11\\x00\\x02\\xF8\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x01\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x4B\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x44\\x11\\x00\\x00\\x50\\x11\\x00\\x00\\x0B\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0B\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F
\\x00\\x04\\x53\\x0D\\x00\\x02\\x4B\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0E\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0E\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0E\\x11\\x00\\x00\\x4B\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0E\\x11\\x00\\x00\\x4B\\x11\\x00\\x00\\x01\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0E\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x7F\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x7F\\x11\\x00\\x02\\xE9\\x11\\x00\\x02\\xE9\\x11\\x00\\x02\\xE9\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x7F\\x11\\x00\\x00\\x4B\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x04\\x37\\x03\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x08\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F
\\x00\\x04\\x53\\x0D\\x00\\x00\\x15\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x10\\x11\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x01\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x0F\\x03\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x01\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x34\\x11\\x00\\x02\\xE1\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x0D\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x05\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x03\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x04\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x08\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x00\\x06\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x0F\\x11\\x00\\x02\\xE1\\x11\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x15\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x21\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x21\\x11\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x0A\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x6C\\x03\\x00\\x02\\x7E\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D
\\x00\\x00\\x10\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x10\\x11\\x00\\x00\\x08\\x11\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x02\\xE1\\x11\\x00\\x02\\x7E\\x11\\x00\\x00\\x07\\x01\\x00\\x00\\x00\\x0F\\x00\\x04\\x53\\x0D\\x00\\x00\\x00\\x0F\\x00\\x00\\x24\\x03\\x00\\x00\\x0D\\x09\\x00\\x00\\x0E\\x09\\x00\\x00\\x0F\\x09\\x00\\x00\\x10\\x09\\x00\\x00\\x11\\x09\\x00\\x00\\x12\\x09\\x00\\x00\\x13\\x09\\x00\\x00\\x14\\x09\\x00\\x00\\x04\\x09\\x00\\x00\\x05\\x09\\x00\\x00\\x06\\x09\\x00\\x00\\x07\\x09\\x00\\x00\\x08\\x09\\x00\\x00\\x09\\x09\\x00\\x00\\x0A\\x09\\x00\\x00\\x02\\x01\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\x80\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\x10\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\xC0\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\x25\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\x28\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\x04\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\x38\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\x40\\x00\\x03\\xFE\\x05\\x00\\x00\\x1F\\xF0\\x00\\x03\\xFE\\x05\\x00\\x00\\x00\\x08\\x00\\x00\\x00\\x0B\\x00\\x00\\x02\\x0B\\x00\\x00\\x03\\x0B\\x00\\x00\\x06\\x0B\\x00\\x00\\x08\\x0B\\x00\\x00\\x0B\\x09\\x00\\x00\\x22\\x05\\x00\\x00\\x10\\x00\\x00\\x00\\x22\\x05\\x00\\x00\\x00\\x08\\x00\\x00\\x0F\\x01\\x00\\x00\\xDB\\x05\\x00\\x00\\x00\\x04\\x00\\x00\\x09\\x01\\x00\\x03\\xB0\\x05\\x00\\x00\\x00\\x10\\x00\\x03\\xB5\\x05\\x00\\x00\\x00\\x10\\x00\\x03\\xB5\\x05\\x00\\x00\\x01\\x00\\x00\\x00\\x00\\x09\\x00\\x00\\x01\\x09\\x00\\x00\\x02\\x09\\x00\\x00\\x03\\x09\\x00\\x04\\x2C\\x03\\x00\\x00\\x0C\\x09\\x00\\x04\\x2E\\x03\\x00\\x00\\x15\\x09\\x00\\x00\\x16\\x09\\x00\\x00\\x17\\x09\\x00\\x00\\x18\\x09\\x00\\x00\\x19\\x09\\x00\\x00\\x1A\\x09\\x00\\x00\\x1C\\x09\\x00\\x00\\x1D\\x09\\x00\\x04\\x37\\x03\\x00\\x00\\x1E\\x09\\x00\\x00\\x1F\\x09\\x00\\x00\\x08\\x05\\x00\\x00\\x10\\x00\\x00\\x00\\x08\\x05\\x00\\x00\\x00\\x06\\x00\\x00\\x22\\x09\\x00\\x00\\x23\\x09\\x00\\x03\\xBA\\x03\\x00\\x03\\xBA\\x05\\x00\\x00\\x00\\x80\\x00\\x03\\xBA\\x05\\x00\\x00\\x00\\x0C\\x00\\x03\\xBA\\x05
\\x00\\x00\\x00\\x10\\x00\\x03\\xBA\\x05\\x00\\x00\\x00\\x20\\x00\\x03\\xBA\\x05\\x00\\x00\\x00\\x40\\x00\\x00\\x0C\\x01\\x00\\x00\\x11\\x05\\x00\\x00\\x00\\x04\\x00\\x00\\x10\\x05\\x00\\x00\\x20\\x51\\x00\\x02\\xC6\\x03\\x00\\x02\\xDE\\x03\\x00\\x03\\xE0\\x03\\x00\\x03\\xE7\\x03\\x00\\x00\\x00\\x01',\n    _globals = (b'\\xFF\\xFF\\xFF\\x0BGGML_BACKEND_CPU',0,b'\\xFF\\xFF\\xFF\\x0BGGML_BACKEND_GPU',10,b'\\xFF\\xFF\\xFF\\x0BGGML_BACKEND_GPU_SPLIT',20,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_ALL_F32',0,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_F16',1,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q2_K',10,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q3_K',11,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q4_0',2,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q4_1',3,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q4_1_SOME_F16',4,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q4_K',12,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q5_0',8,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q5_1',9,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q5_K',13,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q6_K',14,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_MOSTLY_Q8_0',7,b'\\xFF\\xFF\\xFF\\x0BGGML_FTYPE_UNKNOWN',-1,b'\\xFF\\xFF\\xFF\\x1FGGML_GRAPH_SIZE',164520,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_BACKTRACKING_ARMIJO',0,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE',2,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_BACKTRACKING_WOLFE',1,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_DEFAULT',1,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_FAIL',-128,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_INVALID_PARAMETERS',-124,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_MAXIMUM_ITERATIONS',-125,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_MAXIMUM_STEP',-126,b'\\xFF\\xFF\\xFF\\x0BGGML_LINESEARCH_MINIMUM_STEP',-127,b'\\xFF\\xFF\\xFF\\x0BGGML_OBJECT_GRAPH',1,b'\\xFF\\xFF\\xFF\\x1FGGML_OBJECT_SIZE',32,b'\\xFF\\xFF\\xFF\\x0BGGML_OBJECT_TENSOR',0,b'\\xFF\\xFF\\xFF\\x0BGGML_OBJECT_WORK_BUFFER',2,b'\\xFF\\xFF\\xFF\\x0BGGML_OPT_ADAM',0,b'\\xFF\\xFF\\xFF\\x0BGGML_OPT_DID_NOT_CONVERGE',1,b'\\xFF\\xFF\\xFF\\x0BGGML_
OPT_FAIL',4,b'\\xFF\\xFF\\xFF\\x0BGGML_OPT_INVALID_WOLFE',3,b'\\xFF\\xFF\\xFF\\x0BGGML_OPT_LBFGS',1,b'\\xFF\\xFF\\xFF\\x0BGGML_OPT_NO_CONTEXT',2,b'\\xFF\\xFF\\xFF\\x0BGGML_OPT_OK',0,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_ACC',4,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_ADD',2,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_ADD1',3,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_ALIBI',40,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_ARGMAX',14,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_CLAMP',41,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_CONT',26,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_CONV_1D',42,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_CONV_2D',43,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_COUNT',62,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_CPY',25,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_CROSS_ENTROPY_LOSS',60,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_CROSS_ENTROPY_LOSS_BACK',61,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_DIAG',33,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_DIAG_MASK_INF',34,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_DIAG_MASK_ZERO',35,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_DIV',7,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_DUP',1,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_FLASH_ATTN',46,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_FLASH_ATTN_BACK',48,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_FLASH_FF',47,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_GET_ROWS',31,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_GET_ROWS_BACK',32,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_LOG',10,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MAP_BINARY',53,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MAP_CUSTOM1',57,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MAP_CUSTOM1_F32',54,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MAP_CUSTOM2',58,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MAP_CUSTOM2_F32',55,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MAP_CUSTOM3',59,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MAP_CUSTOM3_F32',56,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MAP_UNARY',52,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MEAN',13,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MUL',6,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_MUL_MAT',21,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_NONE',0,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_NORM',18,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_OUT_PROD',22,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_PERMUTE',29,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_POOL_1D',44,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_POOL_2D',45,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_POOL_AVG',1,b'\\xFF\\xFF\
\xFF\\x0BGGML_OP_POOL_COUNT',2,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_POOL_MAX',0,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_REPEAT',15,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_REPEAT_BACK',16,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_RESHAPE',27,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_RMS_NORM',19,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_RMS_NORM_BACK',20,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_ROPE',38,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_ROPE_BACK',39,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SCALE',23,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SET',24,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SILU_BACK',17,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SOFT_MAX',36,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SOFT_MAX_BACK',37,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SQR',8,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SQRT',9,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SUB',5,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SUM',11,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_SUM_ROWS',12,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_TRANSPOSE',30,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_UNARY',51,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_VIEW',28,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_WIN_PART',49,b'\\xFF\\xFF\\xFF\\x0BGGML_OP_WIN_UNPART',50,b'\\xFF\\xFF\\xFF\\x0BGGML_TASK_COMPUTE',1,b'\\xFF\\xFF\\xFF\\x0BGGML_TASK_FINALIZE',2,b'\\xFF\\xFF\\xFF\\x0BGGML_TASK_INIT',0,b'\\xFF\\xFF\\xFF\\x1FGGML_TENSOR_SIZE',288,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_COUNT',19,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_F16',1,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_F32',0,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_I16',17,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_I32',18,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_I8',16,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q2_K',10,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q3_K',11,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q4_0',2,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q4_1',3,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q4_K',12,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q5_0',6,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q5_1',7,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q5_K',13,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q6_K',14,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q8_0',8,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q8_1',9,b'\\xFF\\xFF\\xFF\\x0BGGML_TYPE_Q8_K',15,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_ABS',0,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_ELU',5,b'\\xFF\\xFF\\xFF\\x0BGGM
L_UNARY_OP_GELU',7,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_GELU_QUICK',8,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_NEG',2,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_RELU',6,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_SGN',1,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_SILU',9,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_STEP',3,b'\\xFF\\xFF\\xFF\\x0BGGML_UNARY_OP_TANH',4,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_ARRAY',9,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_BOOL',7,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_COUNT',10,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_FLOAT32',6,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_INT16',3,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_INT32',5,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_INT8',1,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_STRING',8,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_UINT16',2,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_UINT32',4,b'\\xFF\\xFF\\xFF\\x0BGGUF_TYPE_UINT8',0,b'\\x00\\x02\\x9A\\x23__assert_rtn',0,b'\\x00\\x02\\x7C\\x23dequantize_row_q2_K',0,b'\\x00\\x02\\x81\\x23dequantize_row_q3_K',0,b'\\x00\\x02\\x86\\x23dequantize_row_q4_K',0,b'\\x00\\x02\\x8B\\x23dequantize_row_q5_K',0,b'\\x00\\x02\\x90\\x23dequantize_row_q6_K',0,b'\\x00\\x02\\x95\\x23dequantize_row_q8_K',0,b'\\x00\\x00\\xFA\\x23ggml_abs',0,b'\\x00\\x00\\xFA\\x23ggml_abs_inplace',0,b'\\x00\\x01\\xDD\\x23ggml_acc',0,b'\\x00\\x01\\xDD\\x23ggml_acc_inplace',0,b'\\x00\\x01\\x84\\x23ggml_add',0,b'\\x00\\x01\\x84\\x23ggml_add1',0,b'\\x00\\x01\\x84\\x23ggml_add1_inplace',0,b'\\x00\\x01\\x84\\x23ggml_add_inplace',0,b'\\x00\\x01\\x26\\x23ggml_alibi',0,b'\\x00\\x02\\xEC\\x23ggml_allocr_alloc',0,b'\\x00\\x02\\x42\\x23ggml_allocr_alloc_graph',0,b'\\x00\\x02\\xE4\\x23ggml_allocr_free',0,b'\\x00\\x00\\x03\\x23ggml_allocr_is_measure',0,b'\\x00\\x00\\xA2\\x23ggml_allocr_new',0,b'\\x00\\x00\\x9F\\x23ggml_allocr_new_measure',0,b'\\x00\\x02\\xE4\\x23ggml_allocr_reset',0,b'\\x00\\x02\\xE7\\x23ggml_allocr_set_parse_seq',0,b'\\x00\\x00\\x17\\x23ggml_are_same_shape',0,b'\\x00\\x00\\xFA\\x23ggml_argmax',0,b'\\x00\\x00\\x74\\x23ggml_blck_size',0,b'\\x00\\x00\\xB3\\x23ggml_build_backward',0,b'\\x00\\x00\\xB8\\x23ggml_buil
d_forward',0,b'\\x00\\x00\\xAA\\x23ggml_build_forward_ctx',0,b'\\x00\\x02\\xF3\\x23ggml_build_forward_expand',0,b'\\x00\\x00\\x1B\\x23ggml_cl_can_mul_mat',0,b'\\x00\\x03\\x6B\\x23ggml_cl_free_data',0,b'\\x00\\x03\\xE0\\x23ggml_cl_host_free',0,b'\\x00\\x02\\x72\\x23ggml_cl_host_malloc',0,b'\\x00\\x03\\xEC\\x23ggml_cl_init',0,b'\\x00\\x03\\x78\\x23ggml_cl_mul',0,b'\\x00\\x03\\x7D\\x23ggml_cl_mul_mat',0,b'\\x00\\x02\\x54\\x23ggml_cl_mul_mat_get_wsize',0,b'\\x00\\x03\\xE3\\x23ggml_cl_transform_tensor',0,b'\\x00\\x01\\x1B\\x23ggml_clamp',0,b'\\x00\\x00\\xFA\\x23ggml_cont',0,b'\\x00\\x01\\x90\\x23ggml_conv_1d',0,b'\\x00\\x01\\x89\\x23ggml_conv_1d_ph',0,b'\\x00\\x01\\x98\\x23ggml_conv_2d',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_arm_fma',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_avx',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_avx2',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_avx512',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_avx512_vbmi',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_avx512_vnni',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_blas',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_clblast',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_cublas',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_f16c',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_fma',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_fp16_va',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_gpublas',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_neon',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_sse3',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_vsx',0,b'\\x00\\x00\\x90\\x23ggml_cpu_has_wasm_simd',0,b'\\x00\\x01\\x84\\x23ggml_cpy',0,b'\\x00\\x01\\x84\\x23ggml_cross_entropy_loss',0,b'\\x00\\x01\\xA3\\x23ggml_cross_entropy_loss_back',0,b'\\x00\\x03\\x41\\x23ggml_cuda_assign_buffers',0,b'\\x00\\x03\\x41\\x23ggml_cuda_assign_buffers_force_inplace',0,b'\\x00\\x03\\x41\\x23ggml_cuda_assign_buffers_no_scratch',0,b'\\x00\\x00\\x1B\\x23ggml_cuda_can_mul_mat',0,b'\\x00\\x00\\x06\\x23ggml_cuda_compute_forward',0,b'\\x00\\x03\\x41\\x23ggml_cuda_free_data',0,b'\\x00\\x03\\xEC\\x23ggml_cuda_free_scratch',0,b'\\x00\\x00\\x90\\x23ggml_cuda_get_devic
e_count',0,b'\\x00\\x02\\xCE\\x23ggml_cuda_get_device_description',0,b'\\x00\\x03\\xE0\\x23ggml_cuda_host_free',0,b'\\x00\\x02\\x72\\x23ggml_cuda_host_malloc',0,b'\\x00\\x02\\xCB\\x23ggml_cuda_set_main_device',0,b'\\x00\\x02\\x79\\x23ggml_cuda_set_mul_mat_q',0,b'\\x00\\x03\\xD8\\x23ggml_cuda_set_scratch_size',0,b'\\x00\\x02\\xA0\\x23ggml_cuda_set_tensor_split',0,b'\\x00\\x03\\xE3\\x23ggml_cuda_transform_tensor',0,b'\\x00\\x00\\x95\\x23ggml_cycles',0,b'\\x00\\x00\\x95\\x23ggml_cycles_per_ms',0,b'\\x00\\x00\\xFA\\x23ggml_diag',0,b'\\x00\\x01\\x21\\x23ggml_diag_mask_inf',0,b'\\x00\\x01\\x21\\x23ggml_diag_mask_inf_inplace',0,b'\\x00\\x01\\x21\\x23ggml_diag_mask_zero',0,b'\\x00\\x01\\x21\\x23ggml_diag_mask_zero_inplace',0,b'\\x00\\x01\\x84\\x23ggml_div',0,b'\\x00\\x01\\x84\\x23ggml_div_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_dup',0,b'\\x00\\x00\\xFA\\x23ggml_dup_inplace',0,b'\\x00\\x02\\x0B\\x23ggml_dup_tensor',0,b'\\x00\\x02\\x4D\\x23ggml_element_size',0,b'\\x00\\x00\\xFA\\x23ggml_elu',0,b'\\x00\\x00\\xFA\\x23ggml_elu_inplace',0,b'\\x00\\x01\\xA9\\x23ggml_flash_attn',0,b'\\x00\\x01\\xB0\\x23ggml_flash_attn_back',0,b'\\x00\\x01\\xB8\\x23ggml_flash_ff',0,b'\\x00\\x02\\x16\\x23ggml_format_name',0,b'\\x00\\x00\\x6B\\x23ggml_fp16_to_fp32',0,b'\\x00\\x03\\xDB\\x23ggml_fp16_to_fp32_row',0,b'\\x00\\x02\\x62\\x23ggml_fp32_to_fp16',0,b'\\x00\\x02\\xC1\\x23ggml_fp32_to_fp16_row',0,b'\\x00\\x03\\x03\\x23ggml_free',0,b'\\x00\\x00\\x53\\x23ggml_ftype_to_ggml_type',0,b'\\x00\\x00\\xFA\\x23ggml_gelu',0,b'\\x00\\x00\\xFA\\x23ggml_gelu_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_gelu_quick',0,b'\\x00\\x00\\xFA\\x23ggml_gelu_quick_inplace',0,b'\\x00\\x02\\x6C\\x23ggml_get_data',0,b'\\x00\\x00\\x5D\\x23ggml_get_data_f32',0,b'\\x00\\x00\\x63\\x23ggml_get_f32_1d',0,b'\\x00\\x00\\x81\\x23ggml_get_i32_1d',0,b'\\x00\\x02\\x4A\\x23ggml_get_max_tensor_size',0,b'\\x00\\x02\\x69\\x23ggml_get_mem_buffer',0,b'\\x00\\x02\\x4A\\x23ggml_get_mem_size',0,b'\\x00\\x00\\x36\\x23ggml_get_name',0,b'\\x00\\x00\\x0A\\x
23ggml_get_no_alloc',0,b'\\x00\\x01\\x84\\x23ggml_get_rows',0,b'\\x00\\x01\\xA3\\x23ggml_get_rows_back',0,b'\\x00\\x00\\xCE\\x23ggml_get_tensor',0,b'\\x00\\x00\\x56\\x23ggml_get_unary_op',0,b'\\x00\\x00\\x77\\x23ggml_graph_compute',0,b'\\x00\\x03\\x0A\\x23ggml_graph_compute_with_ctx',0,b'\\x00\\x02\\xFE\\x23ggml_graph_dump_dot',0,b'\\x00\\x02\\xFA\\x23ggml_graph_export',0,b'\\x00\\x00\\xCA\\x23ggml_graph_get_tensor',0,b'\\x00\\x00\\xAE\\x23ggml_graph_import',0,b'\\x00\\x02\\x60\\x23ggml_graph_overhead',0,b'\\x00\\x00\\xBE\\x23ggml_graph_plan',0,b'\\x00\\x02\\xF7\\x23ggml_graph_print',0,b'\\x00\\x02\\xF0\\x23ggml_graph_reset',0,b'\\x00\\x00\\xBB\\x23ggml_init',0,b'\\x00\\x03\\xEC\\x23ggml_init_cublas',0,b'\\x00\\x00\\x6E\\x23ggml_internal_get_type_traits',0,b'\\x00\\x00\\x14\\x23ggml_is_contiguous',0,b'\\x00\\x00\\x27\\x23ggml_is_numa',0,b'\\x00\\x00\\x14\\x23ggml_is_permuted',0,b'\\x00\\x00\\x00\\x23ggml_is_quantized',0,b'\\x00\\x00\\x14\\x23ggml_is_transposed',0,b'\\x00\\x00\\xFA\\x23ggml_log',0,b'\\x00\\x00\\xFA\\x23ggml_log_inplace',0,b'\\x00\\x01\\xE6\\x23ggml_map_binary_f32',0,b'\\x00\\x01\\xE6\\x23ggml_map_binary_inplace_f32',0,b'\\x00\\x02\\x04\\x23ggml_map_custom1',0,b'\\x00\\x01\\xFF\\x23ggml_map_custom1_f32',0,b'\\x00\\x02\\x04\\x23ggml_map_custom1_inplace',0,b'\\x00\\x01\\xFF\\x23ggml_map_custom1_inplace_f32',0,b'\\x00\\x01\\xF2\\x23ggml_map_custom2',0,b'\\x00\\x01\\xEC\\x23ggml_map_custom2_f32',0,b'\\x00\\x01\\xF2\\x23ggml_map_custom2_inplace',0,b'\\x00\\x01\\xEC\\x23ggml_map_custom2_inplace_f32',0,b'\\x00\\x01\\xC7\\x23ggml_map_custom3',0,b'\\x00\\x01\\xC0\\x23ggml_map_custom3_f32',0,b'\\x00\\x01\\xC7\\x23ggml_map_custom3_inplace',0,b'\\x00\\x01\\xC0\\x23ggml_map_custom3_inplace_f32',0,b'\\x00\\x01\\xFA\\x23ggml_map_unary_f32',0,b'\\x00\\x01\\xFA\\x23ggml_map_unary_inplace_f32',0,b'\\x00\\x00\\xFA\\x23ggml_mean',0,b'\\x00\\x00\\x0D\\x23ggml_metal_add_buffer',0,b'\\x00\\x03\\x1C\\x23ggml_metal_free',0,b'\\x00\\x00\\x71\\x23ggml_metal_get_concur_list',0,b
'\\x00\\x03\\x2C\\x23ggml_metal_get_tensor',0,b'\\x00\\x03\\x23\\x23ggml_metal_graph_compute',0,b'\\x00\\x03\\x27\\x23ggml_metal_graph_find_concurrency',0,b'\\x00\\x03\\xE0\\x23ggml_metal_host_free',0,b'\\x00\\x02\\x72\\x23ggml_metal_host_malloc',0,b'\\x00\\x00\\x7B\\x23ggml_metal_if_optimized',0,b'\\x00\\x00\\xC2\\x23ggml_metal_init',0,b'\\x00\\x03\\x1F\\x23ggml_metal_set_n_cb',0,b'\\x00\\x03\\x2C\\x23ggml_metal_set_tensor',0,b'\\x00\\x03\\xEC\\x23ggml_mpi_backend_free',0,b'\\x00\\x03\\xEC\\x23ggml_mpi_backend_init',0,b'\\x00\\x03\\x33\\x23ggml_mpi_eval_init',0,b'\\x00\\x03\\x30\\x23ggml_mpi_free',0,b'\\x00\\x03\\x39\\x23ggml_mpi_graph_compute_post',0,b'\\x00\\x03\\x39\\x23ggml_mpi_graph_compute_pre',0,b'\\x00\\x00\\xC5\\x23ggml_mpi_init',0,b'\\x00\\x00\\x7E\\x23ggml_mpi_rank',0,b'\\x00\\x01\\x84\\x23ggml_mul',0,b'\\x00\\x01\\x84\\x23ggml_mul_inplace',0,b'\\x00\\x01\\x84\\x23ggml_mul_mat',0,b'\\x00\\x02\\x4D\\x23ggml_nbytes',0,b'\\x00\\x02\\x4D\\x23ggml_nbytes_pad',0,b'\\x00\\x02\\x50\\x23ggml_nbytes_split',0,b'\\x00\\x00\\xFA\\x23ggml_neg',0,b'\\x00\\x00\\xFA\\x23ggml_neg_inplace',0,b'\\x00\\x00\\x92\\x23ggml_nelements',0,b'\\x00\\x00\\xF2\\x23ggml_new_f32',0,b'\\x00\\x00\\xA7\\x23ggml_new_graph',0,b'\\x00\\x00\\xF6\\x23ggml_new_i32',0,b'\\x00\\x00\\xD2\\x23ggml_new_tensor',0,b'\\x00\\x00\\xD8\\x23ggml_new_tensor_1d',0,b'\\x00\\x00\\xDD\\x23ggml_new_tensor_2d',0,b'\\x00\\x00\\xE3\\x23ggml_new_tensor_3d',0,b'\\x00\\x00\\xEA\\x23ggml_new_tensor_4d',0,b'\\x00\\x00\\xFA\\x23ggml_norm',0,b'\\x00\\x00\\xFA\\x23ggml_norm_inplace',0,b'\\x00\\x00\\x92\\x23ggml_nrows',0,b'\\x00\\x03\\xEC\\x23ggml_numa_init',0,b'\\x00\\x00\\x2D\\x23ggml_op_name',0,b'\\x00\\x00\\x2D\\x23ggml_op_symbol',0,b'\\x00\\x00\\x4E\\x23ggml_opt',0,b'\\x00\\x00\\xC7\\x23ggml_opt_default_params',0,b'\\x00\\x03\\x0F\\x23ggml_opt_init',0,b'\\x00\\x00\\x42\\x23ggml_opt_resume',0,b'\\x00\\x00\\x47\\x23ggml_opt_resume_g',0,b'\\x00\\x01\\x84\\x23ggml_out_prod',0,b'\\x00\\x01\\x34\\x23ggml_permute',0,b'\\x00\\x
00\\xFE\\x23ggml_pool_1d',0,b'\\x00\\x01\\x06\\x23ggml_pool_2d',0,b'\\x00\\x03\\x3E\\x23ggml_print_object',0,b'\\x00\\x03\\x19\\x23ggml_print_objects',0,b'\\x00\\x02\\x33\\x23ggml_quantize_chunk',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q2_K',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q3_K',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q4_0',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q4_1',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q4_K',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q5_0',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q5_1',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q5_K',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q6_K',0,b'\\x00\\x02\\x3B\\x23ggml_quantize_q8_0',0,b'\\x00\\x00\\xFA\\x23ggml_relu',0,b'\\x00\\x00\\xFA\\x23ggml_relu_inplace',0,b'\\x00\\x01\\x84\\x23ggml_repeat',0,b'\\x00\\x01\\x84\\x23ggml_repeat_back',0,b'\\x00\\x01\\x84\\x23ggml_reshape',0,b'\\x00\\x01\\x46\\x23ggml_reshape_1d',0,b'\\x00\\x01\\x4B\\x23ggml_reshape_2d',0,b'\\x00\\x01\\x51\\x23ggml_reshape_3d',0,b'\\x00\\x01\\x58\\x23ggml_reshape_4d',0,b'\\x00\\x01\\x16\\x23ggml_rms_norm',0,b'\\x00\\x01\\x84\\x23ggml_rms_norm_back',0,b'\\x00\\x01\\x16\\x23ggml_rms_norm_inplace',0,b'\\x00\\x01\\x34\\x23ggml_rope',0,b'\\x00\\x01\\x34\\x23ggml_rope_back',0,b'\\x00\\x01\\x3C\\x23ggml_rope_custom',0,b'\\x00\\x01\\x3C\\x23ggml_rope_custom_inplace',0,b'\\x00\\x01\\x34\\x23ggml_rope_inplace',0,b'\\x00\\x01\\x84\\x23ggml_scale',0,b'\\x00\\x01\\x84\\x23ggml_scale_inplace',0,b'\\x00\\x01\\xDD\\x23ggml_set',0,b'\\x00\\x01\\xD0\\x23ggml_set_1d',0,b'\\x00\\x01\\xD0\\x23ggml_set_1d_inplace',0,b'\\x00\\x01\\xD6\\x23ggml_set_2d',0,b'\\x00\\x01\\xD6\\x23ggml_set_2d_inplace',0,b'\\x00\\x02\\x1A\\x23ggml_set_f32',0,b'\\x00\\x03\\x6E\\x23ggml_set_f32_1d',0,b'\\x00\\x02\\x1E\\x23ggml_set_i32',0,b'\\x00\\x03\\x73\\x23ggml_set_i32_1d',0,b'\\x00\\x01\\xDD\\x23ggml_set_inplace',0,b'\\x00\\x02\\x12\\x23ggml_set_name',0,b'\\x00\\x03\\x06\\x23ggml_set_no_alloc',0,b'\\x00\\x03\\x15\\x23ggml_set_param',0,b'\\x00\\x02\\x46\\x23ggml_set_scratch',0,b'\\x00\\x02\\x0F\\x2
3ggml_set_zero',0,b'\\x00\\x00\\xFA\\x23ggml_sgn',0,b'\\x00\\x00\\xFA\\x23ggml_sgn_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_silu',0,b'\\x00\\x01\\x84\\x23ggml_silu_back',0,b'\\x00\\x00\\xFA\\x23ggml_silu_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_soft_max',0,b'\\x00\\x01\\x84\\x23ggml_soft_max_back',0,b'\\x00\\x01\\x84\\x23ggml_soft_max_back_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_soft_max_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_sqr',0,b'\\x00\\x00\\xFA\\x23ggml_sqr_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_sqrt',0,b'\\x00\\x00\\xFA\\x23ggml_sqrt_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_step',0,b'\\x00\\x00\\xFA\\x23ggml_step_inplace',0,b'\\x00\\x01\\x84\\x23ggml_sub',0,b'\\x00\\x01\\x84\\x23ggml_sub_inplace',0,b'\\x00\\x00\\xFA\\x23ggml_sum',0,b'\\x00\\x00\\xFA\\x23ggml_sum_rows',0,b'\\x00\\x00\\xFA\\x23ggml_tanh',0,b'\\x00\\x00\\xFA\\x23ggml_tanh_inplace',0,b'\\x00\\x02\\x60\\x23ggml_tensor_overhead',0,b'\\x00\\x03\\xEC\\x23ggml_time_init',0,b'\\x00\\x00\\x95\\x23ggml_time_ms',0,b'\\x00\\x00\\x95\\x23ggml_time_us',0,b'\\x00\\x00\\xFA\\x23ggml_transpose',0,b'\\x00\\x00\\x30\\x23ggml_type_name',0,b'\\x00\\x02\\x30\\x23ggml_type_size',0,b'\\x00\\x00\\x60\\x23ggml_type_sizef',0,b'\\x00\\x01\\x11\\x23ggml_unary',0,b'\\x00\\x01\\x11\\x23ggml_unary_inplace',0,b'\\x00\\x02\\x4A\\x23ggml_used_mem',0,b'\\x00\\x02\\xDE\\x23ggml_vec_dot_q2_K_q8_K',0,b'\\x00\\x02\\xDE\\x23ggml_vec_dot_q3_K_q8_K',0,b'\\x00\\x02\\xDE\\x23ggml_vec_dot_q4_K_q8_K',0,b'\\x00\\x02\\xDE\\x23ggml_vec_dot_q5_K_q8_K',0,b'\\x00\\x02\\xDE\\x23ggml_vec_dot_q6_K_q8_K',0,b'\\x00\\x01\\x7E\\x23ggml_view_1d',0,b'\\x00\\x01\\x76\\x23ggml_view_2d',0,b'\\x00\\x01\\x6C\\x23ggml_view_3d',0,b'\\x00\\x01\\x60\\x23ggml_view_4d',0,b'\\x00\\x02\\x0B\\x23ggml_view_tensor',0,b'\\x00\\x01\\x21\\x23ggml_win_part',0,b'\\x00\\x01\\x2D\\x23ggml_win_unpart',0,b'\\x00\\x03\\xCC\\x23gguf_add_tensor',0,b'\\x00\\x00\\x88\\x23gguf_find_key',0,b'\\x00\\x00\\x88\\x23gguf_find_tensor',0,b'\\x00\\x03\\x84\\x23gguf_free',0,b'\\x00\\x02\\x59\\x23gguf_get_ali
gnment',0,b'\\x00\\x02\\x75\\x23gguf_get_arr_data',0,b'\\x00\\x00\\x8C\\x23gguf_get_arr_n',0,b'\\x00\\x00\\x3D\\x23gguf_get_arr_str',0,b'\\x00\\x00\\x59\\x23gguf_get_arr_type',0,b'\\x00\\x02\\x6F\\x23gguf_get_data',0,b'\\x00\\x02\\x59\\x23gguf_get_data_offset',0,b'\\x00\\x00\\x39\\x23gguf_get_key',0,b'\\x00\\x00\\x59\\x23gguf_get_kv_type',0,b'\\x00\\x03\\xD4\\x23gguf_get_meta_data',0,b'\\x00\\x02\\x59\\x23gguf_get_meta_size',0,b'\\x00\\x00\\x85\\x23gguf_get_n_kv',0,b'\\x00\\x00\\x85\\x23gguf_get_n_tensors',0,b'\\x00\\x00\\x29\\x23gguf_get_tensor_name',0,b'\\x00\\x02\\x5C\\x23gguf_get_tensor_offset',0,b'\\x00\\x00\\x20\\x23gguf_get_val_bool',0,b'\\x00\\x00\\x67\\x23gguf_get_val_f32',0,b'\\x00\\x00\\x97\\x23gguf_get_val_i16',0,b'\\x00\\x00\\x8C\\x23gguf_get_val_i32',0,b'\\x00\\x00\\x9B\\x23gguf_get_val_i8',0,b'\\x00\\x00\\x39\\x23gguf_get_val_str',0,b'\\x00\\x02\\x65\\x23gguf_get_val_u16',0,b'\\x00\\x02\\x2C\\x23gguf_get_val_u32',0,b'\\x00\\x02\\x28\\x23gguf_get_val_u8',0,b'\\x00\\x00\\x85\\x23gguf_get_version',0,b'\\x00\\x02\\x26\\x23gguf_init_empty',0,b'\\x00\\x02\\x22\\x23gguf_init_from_file',0,b'\\x00\\x03\\x9C\\x23gguf_set_arr_data',0,b'\\x00\\x03\\x8C\\x23gguf_set_arr_str',0,b'\\x00\\x03\\xD0\\x23gguf_set_kv',0,b'\\x00\\x03\\xC6\\x23gguf_set_tensor_data',0,b'\\x00\\x03\\x97\\x23gguf_set_tensor_type',0,b'\\x00\\x03\\x87\\x23gguf_set_val_bool',0,b'\\x00\\x03\\xA3\\x23gguf_set_val_f32',0,b'\\x00\\x03\\xAD\\x23gguf_set_val_i16',0,b'\\x00\\x03\\xA8\\x23gguf_set_val_i32',0,b'\\x00\\x03\\xB2\\x23gguf_set_val_i8',0,b'\\x00\\x03\\x92\\x23gguf_set_val_str',0,b'\\x00\\x03\\xC1\\x23gguf_set_val_u16',0,b'\\x00\\x03\\xBC\\x23gguf_set_val_u32',0,b'\\x00\\x03\\xB7\\x23gguf_set_val_u8',0,b'\\x00\\x00\\x33\\x23gguf_type_name',0,b'\\x00\\x03\\x87\\x23gguf_write_to_file',0,b'\\x00\\x02\\xC6\\x23quantize_row_q2_K',0,b'\\x00\\x02\\xA3\\x23quantize_row_q2_K_reference',0,b'\\x00\\x02\\xC6\\x23quantize_row_q3_K',0,b'\\x00\\x02\\xA8\\x23quantize_row_q3_K_reference',0,b'\\x00\\x02\\xC6\\x
23quantize_row_q4_K',0,b'\\x00\\x02\\xAD\\x23quantize_row_q4_K_reference',0,b'\\x00\\x02\\xC6\\x23quantize_row_q5_K',0,b'\\x00\\x02\\xB2\\x23quantize_row_q5_K_reference',0,b'\\x00\\x02\\xC6\\x23quantize_row_q6_K',0,b'\\x00\\x02\\xB7\\x23quantize_row_q6_K_reference',0,b'\\x00\\x02\\xC6\\x23quantize_row_q8_K',0,b'\\x00\\x02\\xBC\\x23quantize_row_q8_K_reference',0),\n    _struct_unions = ((b'\\x00\\x00\\x04\\x27\\x00\\x00\\x00\\x02$1',b'\\x00\\x00\\x22\\x11n_iter',b'\\x00\\x00\\xF4\\x11sched',b'\\x00\\x00\\xF4\\x11decay',b'\\x00\\x00\\xF4\\x11alpha',b'\\x00\\x00\\xF4\\x11beta1',b'\\x00\\x00\\xF4\\x11beta2',b'\\x00\\x00\\xF4\\x11eps',b'\\x00\\x00\\xF4\\x11eps_f',b'\\x00\\x00\\xF4\\x11eps_g'),(b'\\x00\\x00\\x04\\x28\\x00\\x00\\x00\\x02$2',b'\\x00\\x00\\x22\\x11m',b'\\x00\\x00\\x22\\x11n_iter',b'\\x00\\x00\\x22\\x11max_linesearch',b'\\x00\\x00\\xF4\\x11eps',b'\\x00\\x00\\xF4\\x11ftol',b'\\x00\\x00\\xF4\\x11wolfe',b'\\x00\\x00\\xF4\\x11min_step',b'\\x00\\x00\\xF4\\x11max_step',b'\\x00\\x04\\x14\\x11linesearch'),(b'\\x00\\x00\\x04\\x29\\x00\\x00\\x00\\x02$3',b'\\x00\\x00\\x08\\x11x',b'\\x00\\x00\\x08\\x11g1',b'\\x00\\x00\\x08\\x11g2',b'\\x00\\x00\\x08\\x11m',b'\\x00\\x00\\x08\\x11v',b'\\x00\\x00\\x08\\x11mh',b'\\x00\\x00\\x08\\x11vh',b'\\x00\\x00\\x08\\x11pf',b'\\x00\\x00\\xF4\\x11fx_best',b'\\x00\\x00\\xF4\\x11fx_prev',b'\\x00\\x00\\x22\\x11n_no_improvement'),(b'\\x00\\x00\\x04\\x2A\\x00\\x00\\x00\\x02$4',b'\\x00\\x00\\x08\\x11x',b'\\x00\\x00\\x08\\x11xp',b'\\x00\\x00\\x08\\x11g',b'\\x00\\x00\\x08\\x11gp',b'\\x00\\x00\\x08\\x11d',b'\\x00\\x00\\x08\\x11pf',b'\\x00\\x00\\x08\\x11lmal',b'\\x00\\x00\\x08\\x11lmys',b'\\x00\\x00\\x08\\x11lms',b'\\x00\\x00\\x08\\x11lmy',b'\\x00\\x00\\xF4\\x11fx_best',b'\\x00\\x00\\xF4\\x11step',b'\\x00\\x00\\x22\\x11j',b'\\x00\\x00\\x22\\x11k',b'\\x00\\x00\\x22\\x11end',b'\\x00\\x00\\x22\\x11n_no_improvement'),(b'\\x00\\x00\\x03\\xF7\\x00\\x00\\x00\\x03$__mbstate_t',b'\\x00\\x03\\xFF\\x11__mbstate8',b'\\x00\\x00\\xDB\\x11_mbstateL'),(b'\\x00\\x00
\\x03\\xF8\\x00\\x00\\x00\\x02$block_q2_K',b'\\x00\\x04\\x44\\x11scales',b'\\x00\\x04\\x48\\x11qs',b'\\x00\\x00\\x6C\\x11d',b'\\x00\\x00\\x6C\\x11dmin'),(b'\\x00\\x00\\x03\\xF9\\x00\\x00\\x00\\x02$block_q3_K',b'\\x00\\x04\\x46\\x11hmask',b'\\x00\\x04\\x48\\x11qs',b'\\x00\\x04\\x42\\x11scales',b'\\x00\\x00\\x6C\\x11d'),(b'\\x00\\x00\\x03\\xFA\\x00\\x00\\x00\\x02$block_q4_K',b'\\x00\\x00\\x6C\\x11d',b'\\x00\\x00\\x6C\\x11dmin',b'\\x00\\x04\\x42\\x11scales',b'\\x00\\x04\\x40\\x11qs'),(b'\\x00\\x00\\x03\\xFB\\x00\\x00\\x00\\x02$block_q5_K',b'\\x00\\x00\\x6C\\x11d',b'\\x00\\x00\\x6C\\x11dmin',b'\\x00\\x04\\x42\\x11scales',b'\\x00\\x04\\x46\\x11qh',b'\\x00\\x04\\x40\\x11qs'),(b'\\x00\\x00\\x03\\xFC\\x00\\x00\\x00\\x02$block_q6_K',b'\\x00\\x04\\x40\\x11ql',b'\\x00\\x04\\x48\\x11qh',b'\\x00\\x04\\x23\\x11scales',b'\\x00\\x00\\x6C\\x11d'),(b'\\x00\\x00\\x03\\xFD\\x00\\x00\\x00\\x02$block_q8_K',b'\\x00\\x00\\xF4\\x11d',b'\\x00\\x04\\x25\\x11qs',b'\\x00\\x04\\x21\\x11bsums'),(b'\\x00\\x00\\x04\\x18\\x00\\x00\\x00\\x02$ggml_type_traits_t',b'\\x00\\x00\\x0F\\x11type_name',b'\\x00\\x00\\x22\\x11blck_size',b'\\x00\\x00\\x11\\x11type_size',b'\\x00\\x00\\xB6\\x11is_quantized',b'\\x00\\x04\\x52\\x11to_float',b'\\x00\\x04\\x4F\\x11from_float',b'\\x00\\x04\\x4F\\x11from_float_reference',b'\\x00\\x04\\x50\\x11vec_dot',b'\\x00\\x00\\x01\\x11vec_dot_type'),(b'\\x00\\x00\\x04\\x2C\\x00\\x00\\x00\\x02__darwin_pthread_handler_rec',b'\\x00\\x04\\x51\\x11__routine',b'\\x00\\x00\\x10\\x11__arg',b'\\x00\\x04\\x2B\\x11__next'),(b'\\x00\\x00\\x03\\xEF\\x00\\x00\\x00\\x02_opaque_pthread_attr_t',b'\\x00\\x04\\x20\\x11__sig',b'\\x00\\x04\\x0B\\x11__opaque'),(b'\\x00\\x00\\x03\\xF0\\x00\\x00\\x00\\x02_opaque_pthread_cond_t',b'\\x00\\x04\\x20\\x11__sig',b'\\x00\\x04\\x07\\x11__opaque'),(b'\\x00\\x00\\x03\\xF1\\x00\\x00\\x00\\x02_opaque_pthread_condattr_t',b'\\x00\\x04\\x20\\x11__sig',b'\\x00\\x04\\x11\\x11__opaque'),(b'\\x00\\x00\\x03\\xF2\\x00\\x00\\x00\\x02_opaque_pthread_mutex_t',b'\\x00\\x04\\x20\\
x11__sig',b'\\x00\\x04\\x0B\\x11__opaque'),(b'\\x00\\x00\\x03\\xF3\\x00\\x00\\x00\\x02_opaque_pthread_mutexattr_t',b'\\x00\\x04\\x20\\x11__sig',b'\\x00\\x04\\x11\\x11__opaque'),(b'\\x00\\x00\\x03\\xF4\\x00\\x00\\x00\\x02_opaque_pthread_once_t',b'\\x00\\x04\\x20\\x11__sig',b'\\x00\\x04\\x11\\x11__opaque'),(b'\\x00\\x00\\x03\\xF5\\x00\\x00\\x00\\x02_opaque_pthread_rwlock_t',b'\\x00\\x04\\x20\\x11__sig',b'\\x00\\x04\\x03\\x11__opaque'),(b'\\x00\\x00\\x03\\xF6\\x00\\x00\\x00\\x02_opaque_pthread_rwlockattr_t',b'\\x00\\x04\\x20\\x11__sig',b'\\x00\\x04\\x01\\x11__opaque'),(b'\\x00\\x00\\x04\\x2E\\x00\\x00\\x00\\x02_opaque_pthread_t',b'\\x00\\x04\\x20\\x11__sig',b'\\x00\\x04\\x2B\\x11__cleanup_stack',b'\\x00\\x04\\x0F\\x11__opaque'),(b'\\x00\\x00\\x04\\x2F\\x00\\x00\\x00\\x10ggml_allocr',),(b'\\x00\\x00\\x04\\x30\\x00\\x00\\x00\\x02ggml_cgraph',b'\\x00\\x00\\x22\\x11n_nodes',b'\\x00\\x00\\x22\\x11n_leafs',b'\\x00\\x04\\x39\\x11nodes',b'\\x00\\x04\\x39\\x11grads',b'\\x00\\x04\\x39\\x11leafs',b'\\x00\\x04\\x4D\\x11visited_hash_table',b'\\x00\\x00\\x22\\x11perf_runs',b'\\x00\\x00\\xDB\\x11perf_cycles',b'\\x00\\x00\\xDB\\x11perf_time_us'),(b'\\x00\\x00\\x04\\x31\\x00\\x00\\x00\\x02ggml_compute_params',b'\\x00\\x04\\x17\\x11type',b'\\x00\\x00\\x22\\x11ith',b'\\x00\\x00\\x22\\x11nth',b'\\x00\\x00\\x11\\x11wsize',b'\\x00\\x00\\x10\\x11wdata'),(b'\\x00\\x00\\x04\\x32\\x00\\x00\\x00\\x10ggml_context',),(b'\\x00\\x00\\x04\\x33\\x00\\x00\\x00\\x02ggml_cplan',b'\\x00\\x00\\x11\\x11work_size',b'\\x00\\x04\\x3F\\x11work_data',b'\\x00\\x00\\x22\\x11n_threads',b'\\x00\\x04\\x19\\x11n_tasks',b'\\x00\\x03\\xEE\\x11abort_callback',b'\\x00\\x00\\x10\\x11abort_callback_data'),(b'\\x00\\x00\\x00\\xBC\\x00\\x00\\x00\\x02ggml_init_params',b'\\x00\\x00\\x11\\x11mem_size',b'\\x00\\x00\\x10\\x11mem_buffer',b'\\x00\\x00\\xB6\\x11no_alloc'),(b'\\x00\\x00\\x04\\x34\\x00\\x00\\x00\\x10ggml_metal_context',),(b'\\x00\\x00\\x04\\x35\\x00\\x00\\x00\\x10ggml_mpi_context',),(b'\\x00\\x00\\x04\\x37\\x00\\x00\\x
00\\x02ggml_object',b'\\x00\\x00\\x11\\x11offs',b'\\x00\\x00\\x11\\x11size',b'\\x00\\x04\\x36\\x11next',b'\\x00\\x04\\x15\\x11type',b'\\x00\\x04\\x09\\x11padding'),(b'\\x00\\x00\\x04\\x38\\x00\\x00\\x00\\x02ggml_opt_context',b'\\x00\\x00\\x0B\\x11ctx',b'\\x00\\x00\\x50\\x11params',b'\\x00\\x00\\x22\\x11iter',b'\\x00\\x00\\xDB\\x11nx',b'\\x00\\x00\\xB6\\x11just_initialized',b'\\x00\\x04\\x29\\x11adam',b'\\x00\\x04\\x2A\\x11lbfgs'),(b'\\x00\\x00\\x00\\x50\\x00\\x00\\x00\\x02ggml_opt_params',b'\\x00\\x00\\xC8\\x11type',b'\\x00\\x00\\x22\\x11n_threads',b'\\x00\\x00\\x22\\x11past',b'\\x00\\x00\\xF4\\x11delta',b'\\x00\\x00\\x22\\x11max_no_improvement',b'\\x00\\x00\\xB6\\x11print_forward_graph',b'\\x00\\x00\\xB6\\x11print_backward_graph',b'\\x00\\x04\\x27\\x11adam',b'\\x00\\x04\\x28\\x11lbfgs'),(b'\\x00\\x00\\x02\\x48\\x00\\x00\\x00\\x02ggml_scratch',b'\\x00\\x00\\x11\\x11offs',b'\\x00\\x00\\x11\\x11size',b'\\x00\\x00\\x10\\x11data'),(b'\\x00\\x00\\x04\\x3D\\x00\\x00\\x00\\x02ggml_tensor',b'\\x00\\x00\\x01\\x11type',b'\\x00\\x04\\x13\\x11backend',b'\\x00\\x00\\x22\\x11n_dims',b'\\x00\\x04\\x1E\\x11ne',b'\\x00\\x04\\x4B\\x11nb',b'\\x00\\x00\\x2E\\x11op',b'\\x00\\x04\\x1B\\x11op_params',b'\\x00\\x00\\xB6\\x11is_param',b'\\x00\\x00\\x08\\x11grad',b'\\x00\\x04\\x3B\\x11src',b'\\x00\\x00\\x22\\x11perf_runs',b'\\x00\\x00\\xDB\\x11perf_cycles',b'\\x00\\x00\\xDB\\x11perf_time_us',b'\\x00\\x00\\x10\\x11data',b'\\x00\\x04\\x0D\\x11name',b'\\x00\\x00\\x10\\x11extra',b'\\x00\\x04\\x09\\x11padding'),(b'\\x00\\x00\\x04\\x3E\\x00\\x00\\x00\\x10gguf_context',),(b'\\x00\\x00\\x02\\x24\\x00\\x00\\x00\\x02gguf_init_params',b'\\x00\\x00\\xB6\\x11no_alloc',b'\\x00\\x00\\xB0\\x11ctx')),\n    _enums = 
(b'\\x00\\x00\\x04\\x13\\x00\\x00\\x00\\x16ggml_backend\\x00GGML_BACKEND_CPU,GGML_BACKEND_GPU,GGML_BACKEND_GPU_SPLIT',b'\\x00\\x00\\x00\\x54\\x00\\x00\\x00\\x15ggml_ftype\\x00GGML_FTYPE_UNKNOWN,GGML_FTYPE_ALL_F32,GGML_FTYPE_MOSTLY_F16,GGML_FTYPE_MOSTLY_Q4_0,GGML_FTYPE_MOSTLY_Q4_1,GGML_FTYPE_MOSTLY_Q4_1_SOME_F16,GGML_FTYPE_MOSTLY_Q8_0,GGML_FTYPE_MOSTLY_Q5_0,GGML_FTYPE_MOSTLY_Q5_1,GGML_FTYPE_MOSTLY_Q2_K,GGML_FTYPE_MOSTLY_Q3_K,GGML_FTYPE_MOSTLY_Q4_K,GGML_FTYPE_MOSTLY_Q5_K,GGML_FTYPE_MOSTLY_Q6_K',b'\\x00\\x00\\x04\\x14\\x00\\x00\\x00\\x16ggml_linesearch\\x00GGML_LINESEARCH_DEFAULT,GGML_LINESEARCH_BACKTRACKING_ARMIJO,GGML_LINESEARCH_BACKTRACKING_WOLFE,GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE',b'\\x00\\x00\\x04\\x15\\x00\\x00\\x00\\x16ggml_object_type\\x00GGML_OBJECT_TENSOR,GGML_OBJECT_GRAPH,GGML_OBJECT_WORK_BUFFER',b'\\x00\\x00\\x00\\x2E\\x00\\x00\\x00\\x16ggml_op\\x00GGML_OP_NONE,GGML_OP_DUP,GGML_OP_ADD,GGML_OP_ADD1,GGML_OP_ACC,GGML_OP_SUB,GGML_OP_MUL,GGML_OP_DIV,GGML_OP_SQR,GGML_OP_SQRT,GGML_OP_LOG,GGML_OP_SUM,GGML_OP_SUM_ROWS,GGML_OP_MEAN,GGML_OP_ARGMAX,GGML_OP_REPEAT,GGML_OP_REPEAT_BACK,GGML_OP_SILU_BACK,GGML_OP_NORM,GGML_OP_RMS_NORM,GGML_OP_RMS_NORM_BACK,GGML_OP_MUL_MAT,GGML_OP_OUT_PROD,GGML_OP_SCALE,GGML_OP_SET,GGML_OP_CPY,GGML_OP_CONT,GGML_OP_RESHAPE,GGML_OP_VIEW,GGML_OP_PERMUTE,GGML_OP_TRANSPOSE,GGML_OP_GET_ROWS,GGML_OP_GET_ROWS_BACK,GGML_OP_DIAG,GGML_OP_DIAG_MASK_INF,GGML_OP_DIAG_MASK_ZERO,GGML_OP_SOFT_MAX,GGML_OP_SOFT_MAX_BACK,GGML_OP_ROPE,GGML_OP_ROPE_BACK,GGML_OP_ALIBI,GGML_OP_CLAMP,GGML_OP_CONV_1D,GGML_OP_CONV_2D,GGML_OP_POOL_1D,GGML_OP_POOL_2D,GGML_OP_FLASH_ATTN,GGML_OP_FLASH_FF,GGML_OP_FLASH_ATTN_BACK,GGML_OP_WIN_PART,GGML_OP_WIN_UNPART,GGML_OP_UNARY,GGML_OP_MAP_UNARY,GGML_OP_MAP_BINARY,GGML_OP_MAP_CUSTOM1_F32,GGML_OP_MAP_CUSTOM2_F32,GGML_OP_MAP_CUSTOM3_F32,GGML_OP_MAP_CUSTOM1,GGML_OP_MAP_CUSTOM2,GGML_OP_MAP_CUSTOM3,GGML_OP_CROSS_ENTROPY_LOSS,GGML_OP_CROSS_ENTROPY_LOSS_BACK,GGML_OP_COUNT',b'\\x00\\x00\\x01\\x01\\x00\\x00\\x00\\x16ggml_op_pool\\x00GGML_OP_
POOL_MAX,GGML_OP_POOL_AVG,GGML_OP_POOL_COUNT',b'\\x00\\x00\\x04\\x16\\x00\\x00\\x00\\x15ggml_opt_result\\x00GGML_OPT_OK,GGML_OPT_DID_NOT_CONVERGE,GGML_OPT_NO_CONTEXT,GGML_OPT_INVALID_WOLFE,GGML_OPT_FAIL,GGML_LINESEARCH_FAIL,GGML_LINESEARCH_MINIMUM_STEP,GGML_LINESEARCH_MAXIMUM_STEP,GGML_LINESEARCH_MAXIMUM_ITERATIONS,GGML_LINESEARCH_INVALID_PARAMETERS',b'\\x00\\x00\\x00\\xC8\\x00\\x00\\x00\\x16ggml_opt_type\\x00GGML_OPT_ADAM,GGML_OPT_LBFGS',b'\\x00\\x00\\x04\\x17\\x00\\x00\\x00\\x16ggml_task_type\\x00GGML_TASK_INIT,GGML_TASK_COMPUTE,GGML_TASK_FINALIZE',b'\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x16ggml_type\\x00GGML_TYPE_F32,GGML_TYPE_F16,GGML_TYPE_Q4_0,GGML_TYPE_Q4_1,GGML_TYPE_Q5_0,GGML_TYPE_Q5_1,GGML_TYPE_Q8_0,GGML_TYPE_Q8_1,GGML_TYPE_Q2_K,GGML_TYPE_Q3_K,GGML_TYPE_Q4_K,GGML_TYPE_Q5_K,GGML_TYPE_Q6_K,GGML_TYPE_Q8_K,GGML_TYPE_I8,GGML_TYPE_I16,GGML_TYPE_I32,GGML_TYPE_COUNT',b'\\x00\\x00\\x01\\x14\\x00\\x00\\x00\\x16ggml_unary_op\\x00GGML_UNARY_OP_ABS,GGML_UNARY_OP_SGN,GGML_UNARY_OP_NEG,GGML_UNARY_OP_STEP,GGML_UNARY_OP_TANH,GGML_UNARY_OP_ELU,GGML_UNARY_OP_RELU,GGML_UNARY_OP_GELU,GGML_UNARY_OP_GELU_QUICK,GGML_UNARY_OP_SILU',b'\\x00\\x00\\x00\\x34\\x00\\x00\\x00\\x16gguf_type\\x00GGUF_TYPE_UINT8,GGUF_TYPE_INT8,GGUF_TYPE_UINT16,GGUF_TYPE_INT16,GGUF_TYPE_UINT32,GGUF_TYPE_INT32,GGUF_TYPE_FLOAT32,GGUF_TYPE_BOOL,GGUF_TYPE_STRING,GGUF_TYPE_ARRAY,GGUF_TYPE_COUNT'),\n    _typenames = 
(b'\\x00\\x00\\x00\\xDB__darwin_blkcnt_t',b'\\x00\\x00\\x00\\x22__darwin_blksize_t',b'\\x00\\x00\\x00\\x11__darwin_clock_t',b'\\x00\\x00\\x00\\x22__darwin_ct_rune_t',b'\\x00\\x00\\x00\\x22__darwin_dev_t',b'\\x00\\x00\\x03\\xBF__darwin_fsblkcnt_t',b'\\x00\\x00\\x03\\xBF__darwin_fsfilcnt_t',b'\\x00\\x00\\x03\\xBF__darwin_gid_t',b'\\x00\\x00\\x03\\xBF__darwin_id_t',b'\\x00\\x00\\x04\\x4A__darwin_ino64_t',b'\\x00\\x00\\x04\\x4A__darwin_ino_t',b'\\x00\\x00\\x04\\x20__darwin_intptr_t',b'\\x00\\x00\\x03\\xBF__darwin_mach_port_name_t',b'\\x00\\x00\\x03\\xBF__darwin_mach_port_t',b'\\x00\\x00\\x03\\xF7__darwin_mbstate_t',b'\\x00\\x00\\x00\\x6C__darwin_mode_t',b'\\x00\\x00\\x03\\xBF__darwin_natural_t',b'\\x00\\x00\\x00\\xDB__darwin_off_t',b'\\x00\\x00\\x00\\x22__darwin_pid_t',b'\\x00\\x00\\x03\\xEF__darwin_pthread_attr_t',b'\\x00\\x00\\x03\\xF0__darwin_pthread_cond_t',b'\\x00\\x00\\x03\\xF1__darwin_pthread_condattr_t',b'\\x00\\x00\\x00\\x11__darwin_pthread_key_t',b'\\x00\\x00\\x03\\xF2__darwin_pthread_mutex_t',b'\\x00\\x00\\x03\\xF3__darwin_pthread_mutexattr_t',b'\\x00\\x00\\x03\\xF4__darwin_pthread_once_t',b'\\x00\\x00\\x03\\xF5__darwin_pthread_rwlock_t',b'\\x00\\x00\\x03\\xF6__darwin_pthread_rwlockattr_t',b'\\x00\\x00\\x04\\x2D__darwin_pthread_t',b'\\x00\\x00\\x04\\x20__darwin_ptrdiff_t',b'\\x00\\x00\\x00\\x22__darwin_rune_t',b'\\x00\\x00\\x03\\xBF__darwin_sigset_t',b'\\x00\\x00\\x00\\x11__darwin_size_t',b'\\x00\\x00\\x03\\xBF__darwin_socklen_t',b'\\x00\\x00\\x04\\x20__darwin_ssize_t',b'\\x00\\x00\\x00\\x22__darwin_suseconds_t',b'\\x00\\x00\\x04\\x20__darwin_time_t',b'\\x00\\x00\\x03\\xBF__darwin_uid_t',b'\\x00\\x00\\x03\\xBF__darwin_useconds_t',b'\\x00\\x00\\x04\\x05__darwin_uuid_string_t',b'\\x00\\x00\\x04\\x44__darwin_uuid_t',b'\\x00\\x00\\x00\\x22__darwin_wchar_t',b'\\x00\\x00\\x00\\x22__darwin_wint_t',b'\\x00\\x00\\x03\\xB0__int16_t',b'\\x00\\x00\\x00\\x22__int32_t',b'\\x00\\x00\\x00\\xDB__int64_t',b'\\x00\\x00\\x03\\xB5__int8_t',b'\\x00\\x00\\x03\\xF7__mbstate_t',b'\\x
00\\x00\\x00\\x6C__uint16_t',b'\\x00\\x00\\x03\\xBF__uint32_t',b'\\x00\\x00\\x04\\x4A__uint64_t',b'\\x00\\x00\\x03\\xBA__uint8_t',b'\\x00\\x00\\x03\\xF8block_q2_K',b'\\x00\\x00\\x03\\xF9block_q3_K',b'\\x00\\x00\\x03\\xFAblock_q4_K',b'\\x00\\x00\\x03\\xFBblock_q5_K',b'\\x00\\x00\\x03\\xFCblock_q6_K',b'\\x00\\x00\\x03\\xFDblock_q8_K',b'\\x00\\x00\\x01\\xEAggml_binary_op_f32_t',b'\\x00\\x00\\x02\\x02ggml_custom1_op_f32_t',b'\\x00\\x00\\x02\\x07ggml_custom1_op_t',b'\\x00\\x00\\x01\\xF0ggml_custom2_op_f32_t',b'\\x00\\x00\\x01\\xF6ggml_custom2_op_t',b'\\x00\\x00\\x01\\xC5ggml_custom3_op_f32_t',b'\\x00\\x00\\x01\\xCCggml_custom3_op_t',b'\\x00\\x00\\x00\\x6Cggml_fp16_t',b'\\x00\\x00\\x04\\x4Fggml_from_float_t',b'\\x00\\x00\\x04\\x52ggml_to_float_t',b'\\x00\\x00\\x04\\x18ggml_type_traits_t',b'\\x00\\x00\\x01\\xFDggml_unary_op_f32_t',b'\\x00\\x00\\x04\\x50ggml_vec_dot_t',b'\\x00\\x00\\x03\\xB0int16_t',b'\\x00\\x00\\x00\\x22int32_t',b'\\x00\\x00\\x00\\xDBint64_t',b'\\x00\\x00\\x03\\xB5int8_t',b'\\x00\\x00\\x03\\xB0int_fast16_t',b'\\x00\\x00\\x00\\x22int_fast32_t',b'\\x00\\x00\\x00\\xDBint_fast64_t',b'\\x00\\x00\\x03\\xB5int_fast8_t',b'\\x00\\x00\\x03\\xB0int_least16_t',b'\\x00\\x00\\x00\\x22int_least32_t',b'\\x00\\x00\\x00\\xDBint_least64_t',b'\\x00\\x00\\x03\\xB5int_least8_t',b'\\x00\\x00\\x04\\x20intmax_t',b'\\x00\\x00\\x04\\x20intptr_t',b'\\x00\\x00\\x04\\x1Dmax_align_t',b'\\x00\\x00\\x04\\x20ptrdiff_t',b'\\x00\\x00\\x00\\xDBregister_t',b'\\x00\\x00\\x00\\x11rsize_t',b'\\x00\\x00\\x00\\x11size_t',b'\\x00\\x00\\x04\\x4Asyscall_arg_t',b'\\x00\\x00\\x00\\x6Cu_int16_t',b'\\x00\\x00\\x03\\xBFu_int32_t',b'\\x00\\x00\\x04\\x4Au_int64_t',b'\\x00\\x00\\x03\\xBAu_int8_t',b'\\x00\\x00\\x00\\x6Cuint16_t',b'\\x00\\x00\\x03\\xBFuint32_t',b'\\x00\\x00\\x04\\x4Auint64_t',b'\\x00\\x00\\x03\\xBAuint8_t',b'\\x00\\x00\\x00\\x6Cuint_fast16_t',b'\\x00\\x00\\x03\\xBFuint_fast32_t',b'\\x00\\x00\\x04\\x4Auint_fast64_t',b'\\x00\\x00\\x03\\xBAuint_fast8_t',b'\\x00\\x00\\x00\\x6Cuint_least16_t',b'\\x0
0\\x00\\x03\\xBFuint_least32_t',b'\\x00\\x00\\x04\\x4Auint_least64_t',b'\\x00\\x00\\x03\\xBAuint_least8_t',b'\\x00\\x00\\x00\\x11uintmax_t',b'\\x00\\x00\\x00\\x11uintptr_t',b'\\x00\\x00\\x04\\x4Auser_addr_t',b'\\x00\\x00\\x00\\xDBuser_long_t',b'\\x00\\x00\\x00\\xDBuser_off_t',b'\\x00\\x00\\x04\\x4Auser_size_t',b'\\x00\\x00\\x00\\xDBuser_ssize_t',b'\\x00\\x00\\x00\\xDBuser_time_t',b'\\x00\\x00\\x04\\x4Auser_ulong_t',b'\\x00\\x00\\x00\\x22wchar_t'),\n)\n"
  },
  {
    "path": "examples/python/ggml/ffi/__init__.pyi",
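    "content": "# Phony stubs: empty placeholder classes so that annotations such as `ffi.CData` and\n# `ffi.CType` (used in ggml/utils.py) resolve for type checkers; they carry no real type information.\n\nclass CData:\n    pass\n\nclass CType:\n    pass\n"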
  },
  {
    "path": "examples/python/ggml/utils.py",
    "content": "\"\"\"\n  Common helpers for working with ggml + numpy\n\"\"\"\nfrom ggml import ffi, lib\nfrom typing import Union, Optional\nimport numpy as np\n\ndef init(mem_size: int, mem_buffer: ffi.CData = ffi.NULL, no_alloc: bool = False) -> ffi.CData:\n    \"\"\"\n      Initialize a ggml context, which will be freed automatically when the pointer is garbage collected.\n    \"\"\"\n    params = ffi.new('struct ggml_init_params*')\n    params.mem_size = mem_size\n    params.mem_buffer = mem_buffer\n    params.no_alloc = no_alloc\n    return ffi.gc(lib.ggml_init(params[0]), lib.ggml_free)\n \nTensorLike = Union[ffi.CData, np.ndarray]\n\ndef copy(from_tensor: TensorLike, to_tensor: TensorLike, allow_requantize: bool = True):\n    \"\"\"\n      Copy the contents of one tensor to another, doing any necessary (de/re)quantization transparently.\n      Works across numpy & ggml tensors, but they must have the same shape (and be contiguous).\n\n      Parameters\n      ----------\n      from_tensor : TensorLike\n          The tensor to copy from (a numpy array or possibly-quantized ggml tensor)\n      to_tensor : TensorLike\n          The tensor to copy to (a numpy array or possibly-quantized ggml tensor)\n      allow_requantize : bool\n          If False, will throw an error if requantization is required (i.e. 
both from_tensor\n          and to_tensor are quantized with different quantization types)\n    \"\"\"\n    if id(from_tensor) == id(to_tensor):\n        return\n \n    __expect_same_layout(\"source\", from_tensor, \"destination\", to_tensor)\n    __check_shape_consistent_with_type(from_tensor)\n    __check_shape_consistent_with_type(to_tensor)\n\n    from_type = __get_type(from_tensor)\n    to_type = __get_type(to_tensor)\n\n    if from_type == to_type:\n        ffi.memmove(__get_data(to_tensor), __get_data(from_tensor), __get_nbytes(from_tensor))\n    else:\n        assert allow_requantize or not lib.ggml_is_quantized(from_type) or not lib.ggml_is_quantized(to_type), \\\n            f\"Requantizing from {__type_name(from_type)} to {__type_name(to_type)} is disabled. Force with allow_requantize=True\"\n \n        __set_floats(to_tensor, __get_floats(from_tensor))\n\ndef numpy(tensor: ffi.CData, allow_copy: Union[bool, np.ndarray] = False, allow_requantize=False) -> np.ndarray:\n    \"\"\"\n      Convert a ggml tensor to a numpy array.\n      If the tensor isn't quantized, the returned numpy array will be a view over its data.\n \n      If it is quantized (and allow_copy is True), the copy will involve dequantization and the returned array will\n      be a copy of the original tensor (any changes to the numpy array won't then be reflected back to the tensor).\n\n      Parameters\n      ----------\n      tensor : ffi.CData\n          The tensor to convert to a numpy array\n      allow_copy : bool or np.ndarray\n          If False, will throw an error if the tensor is quantized (since dequantization requires extra memory).\n          If True, will dequantize the tensor and return a copy of the data in a new float32 numpy array.\n          If an np.ndarray, will copy the data into the given array (which must be the same shape as the tensor) when dequantization is needed\n      allow_requantize : bool\n          If allow_copy is a tensor with a different quantization 
type than the source tensor, will throw an error unless allow_requantize is True.\n    \"\"\"\n    shape = __get_shape(tensor)\n\n    if lib.ggml_is_quantized(tensor.type):\n        if allow_copy == False:\n            raise ValueError(f\"{__describe(tensor)} is quantized, conversion to numpy requires a copy (pass allow_copy=True; changes to the numpy array won't affect the original).\")\n        elif isinstance(allow_copy, np.ndarray):\n            __expect_same_layout(\"source tensor\", tensor, \"dequantization output tensor\", allow_copy)\n            destination = allow_copy\n        else:\n            destination = np.empty(shape, dtype=np.float32)\n\n        copy(tensor, destination, allow_requantize=allow_requantize)\n        return destination\n    else:\n        dtype = __type_to_dtype(tensor.type)\n        if not dtype:\n            raise NotImplementedError(f'Cannot convert {__describe(tensor)} to numpy')\n\n        assert __is_contiguous(tensor), f\"Cannot convert {__describe(tensor)} to numpy (support contiguous tensors only)\"\n        nbytes = lib.ggml_nelements(tensor) * lib.ggml_type_size(tensor.type)\n        array = np.frombuffer(ffi.buffer(lib.ggml_get_data(tensor), nbytes), dtype=dtype)\n        array.shape = shape\n        return array\n\ndef __type_name(type: int) -> str:\n    name = lib.ggml_type_name(type)\n    return ffi.string(name).decode('utf-8') if name else None\n\n__k_quant_types = set([\n  lib.GGML_TYPE_Q2_K,\n  lib.GGML_TYPE_Q3_K,\n  lib.GGML_TYPE_Q4_K,\n  lib.GGML_TYPE_Q5_K,\n  lib.GGML_TYPE_Q6_K,\n  lib.GGML_TYPE_Q8_K,\n])\n\n__type_to_dtype_dict = {\n  lib.GGML_TYPE_I8: np.int8,\n  lib.GGML_TYPE_I16: np.int16,\n  lib.GGML_TYPE_I32: np.int32,\n  lib.GGML_TYPE_F16: np.float16,\n  lib.GGML_TYPE_F32: np.float32,\n}\n\ndef __type_to_dtype(type: int) -> Optional[np.dtype]: return __type_to_dtype_dict.get(type)\ndef __dtype_to_type(dtype: np.dtype):\n    if dtype == np.float32: return lib.GGML_TYPE_F32\n    elif dtype == np.float16: 
return lib.GGML_TYPE_F16\n    elif dtype == np.int32: return lib.GGML_TYPE_I32\n    elif dtype == np.int16: return lib.GGML_TYPE_I16\n    elif dtype == np.int8: return lib.GGML_TYPE_I8\n    else: raise ValueError(f\"Unsupported dtype: {dtype}\")\n\ndef __describe(tensor: ffi.CType): return f'Tensor[{__type_name(__get_type(tensor))}, {__get_shape(tensor)}]'\ndef __get_type(tensor: TensorLike): return __dtype_to_type(tensor.dtype) if isinstance(tensor, np.ndarray) else tensor.type\ndef __get_shape(x: TensorLike): return x.shape if isinstance(x, np.ndarray) else tuple([x.ne[i] for i in range(x.n_dims)])\ndef __get_strides(x: TensorLike): return x.strides if isinstance(x, np.ndarray) else tuple([x.nb[i] for i in range(x.n_dims)])\ndef __get_data(x: TensorLike) -> ffi.CData: return ffi.from_buffer(x) if isinstance(x, np.ndarray) else lib.ggml_get_data(x)\ndef __get_nbytes(tensor: TensorLike): return tensor.nbytes if isinstance(tensor, np.ndarray) else lib.ggml_nbytes(tensor)\ndef __get_nelements(tensor: TensorLike): return tensor.size if isinstance(tensor, np.ndarray) else lib.ggml_nelements(tensor)\ndef __is_contiguous(tensor: TensorLike): return tensor.flags['C_CONTIGUOUS'] if isinstance(tensor, np.ndarray) else lib.ggml_is_contiguous(tensor)\n\ndef __get_floats(tensor: TensorLike) -> ffi.CData:\n    data, type = __get_data(tensor), __get_type(tensor)\n    if type == lib.GGML_TYPE_F32:\n        return ffi.cast('float*', data)\n    else:\n      nelements = __get_nelements(tensor)\n      floats = ffi.new('float[]', nelements)\n      if type == lib.GGML_TYPE_F16:\n          lib.ggml_fp16_to_fp32_row(ffi.cast('uint16_t*', data), floats, nelements)\n      elif lib.ggml_is_quantized(type):\n          qtype = lib.ggml_internal_get_type_traits(type)\n          assert qtype.to_float, f\"Type {__type_name(type)} is not supported by ggml\"\n          qtype.to_float(data, floats, nelements)\n      else:\n          raise NotImplementedError(f'Cannot read floats from 
{__describe(tensor)}')\n      return floats\n\ndef __set_floats(tensor: TensorLike, f32_data: ffi.CData) -> None:\n    data, type, nbytes = __get_data(tensor), __get_type(tensor), __get_nbytes(tensor)\n    if type == lib.GGML_TYPE_F32:\n        ffi.memmove(data, f32_data, nbytes)\n    else:\n      nelements = __get_nelements(tensor)\n      if type == lib.GGML_TYPE_F16:\n          lib.ggml_fp32_to_fp16_row(f32_data, ffi.cast('uint16_t*', data), nelements)\n      elif lib.ggml_is_quantized(type):\n          qtype = lib.ggml_internal_get_type_traits(type)\n          assert qtype.from_float, f\"Type {__type_name(type)} is not supported by ggml\"\n          qtype.from_float(f32_data, data, nelements)\n      else:\n          raise NotImplementedError(f'Cannot write floats to {__describe(tensor)}')\n\ndef __expect_same_layout(name1: str, tensor1: TensorLike, name2: str, tensor2: TensorLike):\n    shape1, shape2 = __get_shape(tensor1), __get_shape(tensor2)\n    assert shape1 == shape2, f\"Shape mismatch: {name1} has {shape1} but {name2} has {shape2}\"\n    assert __is_contiguous(tensor1) and __is_contiguous(tensor2), f\"Only contiguous tensors are supported (got {name1} with strides {__get_strides(tensor1)} and {name2} with strides {__get_strides(tensor2)})\"\n\ndef __check_shape_consistent_with_type(tensor: TensorLike):\n    type = __get_type(tensor)\n    if not lib.ggml_is_quantized(type):\n        return\n    shape = __get_shape(tensor)\n\n    block_size = lib.ggml_blck_size(type)\n    assert not (block_size == 0 and type in __k_quant_types), f\"Can't quantize, native library was not compiled with USE_K_QUANTS!\"\n    assert block_size > 0, f\"Invalid block size {block_size} for type {__type_name(type)}\"\n    for i, d in enumerate(shape):\n        assert d % block_size == 0, f\"Dimension {i} of {__describe(tensor)} is not divisible by {block_size}, required for quantization.\"\n"
  },
  {
    "path": "examples/python/regenerate.py",
    "content": "# Generates bindings for the ggml library.\n#\n# cffi requires prior C preprocessing of the headers, and it uses pycparser which chokes on a couple of things\n# so we help it a bit (e.g. replace sizeof expressions with their value, remove exotic syntax found in Darwin headers).\nimport os, sys, re, subprocess\nimport cffi\nfrom stubs import generate_stubs\n\nAPI = os.environ.get('API', 'api.h')\nCC = os.environ.get('CC') or 'gcc'\nC_INCLUDE_DIR = os.environ.get('C_INCLUDE_DIR', '../../../llama.cpp')\nCPPFLAGS = [\n    \"-I\", C_INCLUDE_DIR,\n    '-D__fp16=uint16_t',  # pycparser doesn't support __fp16\n    '-D__attribute__(x)=',\n    '-D_Static_assert(x, m)=',\n] + [x for x in os.environ.get('CPPFLAGS', '').split(' ') if x != '']\n\ntry: header = subprocess.run([CC, \"-E\", *CPPFLAGS, API], capture_output=True, text=True, check=True).stdout\nexcept subprocess.CalledProcessError as e: print(f'{e.stderr}\\n{e}', file=sys.stderr); raise\n\nheader = '\\n'.join([l for l in header.split('\\n') if '__darwin_va_list' not in l]) # pycparser hates this\n\n# Replace constant size expressions w/ their value (compile & run a mini exe for each, because why not).\n# First, extract anyting *inside* square brackets and anything that looks like a sizeof call.\nfor expr in set(re.findall(f'(?<=\\\\[)[^\\\\]]+(?=])|sizeof\\\\s*\\\\([^()]+\\\\)', header)):\n    if re.match(r'^(\\d+|\\s*)$', expr): continue # skip constants and empty bracket contents\n    subprocess.run([CC, \"-o\", \"eval_size_expr\", *CPPFLAGS, \"-x\", \"c\", \"-\"], text=True, check=True,\n                   input=f'''#include <stdio.h>\n                             #include \"{API}\"\n                             int main() {{ printf(\"%lu\", (size_t)({expr})); }}''')\n    size = subprocess.run([\"./eval_size_expr\"], capture_output=True, text=True, check=True).stdout\n    print(f'Computed constexpr {expr} = {size}')\n    header = header.replace(expr, size)\n\nffibuilder = 
cffi.FFI()\nffibuilder.cdef(header)\nffibuilder.set_source(f'ggml.cffi', None) # we're not compiling a native extension, as this quickly gets hairy\nffibuilder.compile(verbose=True)\n\nwith open(\"ggml/__init__.pyi\", \"wt\") as f:\n    f.write(generate_stubs(header))"
  },
  {
    "path": "examples/python/stubs.py",
    "content": "\"\"\"\n  This generates .pyi stubs for the cffi Python bindings generated by regenerate.py\n\"\"\"\nimport sys, re, itertools\nsys.path.extend(['.', '..']) # for pycparser\n\nfrom pycparser import c_ast, parse_file, CParser\nimport pycparser.plyparser\nfrom pycparser.c_ast import PtrDecl, TypeDecl, FuncDecl, EllipsisParam, IdentifierType, Struct, Enum, Typedef\nfrom typing import Tuple\n\n__c_type_to_python_type = {\n    'void': 'None', '_Bool': 'bool',\n    'char': 'int', 'short': 'int', 'int': 'int', 'long': 'int',\n    'ptrdiff_t': 'int', 'size_t': 'int',\n    'int8_t': 'int', 'uint8_t': 'int',\n    'int16_t': 'int', 'uint16_t': 'int',\n    'int32_t': 'int', 'uint32_t': 'int',\n    'int64_t': 'int', 'uint64_t': 'int',\n    'float': 'float', 'double': 'float',\n    'ggml_fp16_t': 'np.float16',\n}\n\ndef format_type(t: TypeDecl):\n    if isinstance(t, PtrDecl) or isinstance(t, Struct):\n        return 'ffi.CData'\n    if isinstance(t, Enum):\n        return 'int'\n    if isinstance(t, TypeDecl):\n        return format_type(t.type)\n    if isinstance(t, IdentifierType):\n        assert len(t.names) == 1, f'Expected a single name, got {t.names}'\n        return __c_type_to_python_type.get(t.names[0]) or 'ffi.CData'\n    return t.name\n\nclass PythonStubFuncDeclVisitor(c_ast.NodeVisitor):\n    def __init__(self):\n        self.sigs = {}\n        self.sources = {}\n\n    def get_source_snippet_lines(self, coord: pycparser.plyparser.Coord) -> Tuple[list[str], list[str]]:\n        if coord.file not in self.sources:\n            with open(coord.file, 'rt') as f:\n                self.sources[coord.file] = f.readlines()\n        source_lines = self.sources[coord.file]\n        ncomment_lines = len(list(itertools.takewhile(lambda i: re.search(r'^\\s*(//|/\\*)', source_lines[i]), range(coord.line - 2, -1, -1))))\n        comment_lines = [l.strip() for l in source_lines[coord.line - 1 - ncomment_lines:coord.line - 1]]\n        decl_lines = []\n        for 
line in source_lines[coord.line - 1:]:\n            decl_lines.append(line.rstrip())\n            if (';' in line) or ('{' in line): break\n        return (comment_lines, decl_lines)\n\n    def visit_Enum(self, node: Enum):\n        if node.values is not None:\n          for e in node.values.enumerators:\n              self.sigs[e.name] = f'  @property\\n  def {e.name}(self) -> int: ...'\n\n    def visit_Typedef(self, node: Typedef):\n        pass\n\n    def visit_FuncDecl(self, node: FuncDecl):\n        ret_type = node.type\n        is_ptr = False\n        while isinstance(ret_type, PtrDecl):\n            ret_type = ret_type.type\n            is_ptr = True\n\n        fun_name = ret_type.declname\n        if fun_name.startswith('__'):\n            return\n\n        args = []\n        argnames = []\n        def gen_name(stem):\n            i = 1\n            while True:\n                new_name = stem if i == 1 else f'{stem}{i}'\n                if new_name not in argnames: return new_name\n                i += 1\n\n        for a in node.args.params:\n            if isinstance(a, EllipsisParam):\n                arg_name = gen_name('args')\n                argnames.append(arg_name)\n                args.append('*' + gen_name('args'))\n            elif format_type(a.type) == 'None':\n                continue\n            else:\n                arg_name = a.name or gen_name('arg')\n                argnames.append(arg_name)\n                args.append(f'{arg_name}: {format_type(a.type)}')\n\n        ret = format_type(ret_type if not is_ptr else node.type)\n\n        comment_lines, decl_lines = self.get_source_snippet_lines(node.coord)\n\n        lines = [f'  def {fun_name}({\", \".join(args)}) -> {ret}:']\n        if len(comment_lines) == 0 and len(decl_lines) == 1:\n            lines += [f'    \"\"\"{decl_lines[0]}\"\"\"']\n        else:\n            lines += ['    \"\"\"']\n            lines += [f'    {c.lstrip(\"/* \")}' for c in comment_lines]\n            if 
len(comment_lines) > 0:\n                lines += ['']\n            lines += [f'    {d}' for d in decl_lines]\n            lines += ['    \"\"\"']\n        lines += ['    ...']\n        self.sigs[fun_name] = '\\n'.join(lines)\n\ndef generate_stubs(header: str):\n    \"\"\"\n      Generates a .pyi Python stub file for the GGML API using C header files.\n    \"\"\"\n\n    v = PythonStubFuncDeclVisitor()\n    v.visit(CParser().parse(header, \"<input>\"))\n\n    keys = list(v.sigs.keys())\n    keys.sort()\n\n    return '\\n'.join([\n        '# auto-generated file',\n        'import ggml.ffi as ffi',\n        'import numpy as np',\n        'class lib:',\n        *[v.sigs[k] for k in keys]\n    ])\n"
  },
  {
    "path": "examples/python/test_tensor.py",
    "content": "import pytest\nfrom pytest import raises\n\nfrom ggml import lib, ffi\nfrom ggml.utils import init, copy, numpy\nimport numpy as np\nimport numpy.testing as npt\n\n@pytest.fixture()\ndef ctx():\n    print(\"setup\")\n    yield init(mem_size=10*1024*1024)\n    print(\"teardown\")\n\nclass TestNumPy:\n    \n    # Single element\n\n    def test_set_get_single_i32(self, ctx):\n        i = lib.ggml_new_i32(ctx, 42)\n        assert lib.ggml_get_i32_1d(i, 0) == 42\n        assert numpy(i) == np.array([42], dtype=np.int32)\n\n    def test_set_get_single_f32(self, ctx):\n        i = lib.ggml_new_f32(ctx, 4.2)\n        \n        epsilon = 0.000001 # Not sure why so large a difference??\n        pytest.approx(lib.ggml_get_f32_1d(i, 0), 4.2, epsilon)\n        pytest.approx(numpy(i), np.array([4.2], dtype=np.float32), epsilon)\n\n    def _test_copy_np_to_ggml(self, a: np.ndarray, t: ffi.CData):\n        a2 = a.copy() # Clone original\n        copy(a, t)\n        npt.assert_array_equal(numpy(t), a2)\n\n    # I32\n\n    def test_copy_np_to_ggml_1d_i32(self, ctx):\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_I32, 10)\n        a = np.arange(10, dtype=np.int32)\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_2d_i32(self, ctx):\n        t = lib.ggml_new_tensor_2d(ctx, lib.GGML_TYPE_I32, 2, 3)\n        a = np.arange(2 * 3, dtype=np.int32).reshape((2, 3))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_3d_i32(self, ctx):\n        t = lib.ggml_new_tensor_3d(ctx, lib.GGML_TYPE_I32, 2, 3, 4)\n        a = np.arange(2 * 3 * 4, dtype=np.int32).reshape((2, 3, 4))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_4d_i32(self, ctx):\n        t = lib.ggml_new_tensor_4d(ctx, lib.GGML_TYPE_I32, 2, 3, 4, 5)\n        a = np.arange(2 * 3 * 4 * 5, dtype=np.int32).reshape((2, 3, 4, 5))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_4d_n_i32(self, ctx):\n        dims = [2, 3, 
4, 5] # GGML_MAX_DIMS is 4, going beyond would crash\n        pdims = ffi.new('int64_t[]', len(dims))\n        for i, d in enumerate(dims): pdims[i] = d\n        t = lib.ggml_new_tensor(ctx, lib.GGML_TYPE_I32, len(dims), pdims)\n        a = np.arange(np.prod(dims), dtype=np.int32).reshape(tuple(pdims))\n        self._test_copy_np_to_ggml(a, t)\n\n    # F32\n\n    def test_copy_np_to_ggml_1d_f32(self, ctx):\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, 10)\n        a = np.arange(10, dtype=np.float32)\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_2d_f32(self, ctx):\n        t = lib.ggml_new_tensor_2d(ctx, lib.GGML_TYPE_F32, 2, 3)\n        a = np.arange(2 * 3, dtype=np.float32).reshape((2, 3))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_3d_f32(self, ctx):\n        t = lib.ggml_new_tensor_3d(ctx, lib.GGML_TYPE_F32, 2, 3, 4)\n        a = np.arange(2 * 3 * 4, dtype=np.float32).reshape((2, 3, 4))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_4d_f32(self, ctx):\n        t = lib.ggml_new_tensor_4d(ctx, lib.GGML_TYPE_F32, 2, 3, 4, 5)\n        a = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape((2, 3, 4, 5))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_4d_n_f32(self, ctx):\n        dims = [2, 3, 4, 5] # GGML_MAX_DIMS is 4, going beyond would crash\n        pdims = ffi.new('int64_t[]', len(dims))\n        for i, d in enumerate(dims): pdims[i] = d\n        t = lib.ggml_new_tensor(ctx, lib.GGML_TYPE_F32, len(dims), pdims)\n        a = np.arange(np.prod(dims), dtype=np.float32).reshape(tuple(pdims))\n        self._test_copy_np_to_ggml(a, t)\n\n    # F16\n\n    def test_copy_np_to_ggml_1d_f16(self, ctx):\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F16, 10)\n        a = np.arange(10, dtype=np.float16)\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_2d_f16(self, ctx):\n        t = lib.ggml_new_tensor_2d(ctx, 
lib.GGML_TYPE_F16, 2, 3)\n        a = np.arange(2 * 3, dtype=np.float16).reshape((2, 3))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_3d_f16(self, ctx):\n        t = lib.ggml_new_tensor_3d(ctx, lib.GGML_TYPE_F16, 2, 3, 4)\n        a = np.arange(2 * 3 * 4, dtype=np.float16).reshape((2, 3, 4))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_4d_f16(self, ctx):\n        t = lib.ggml_new_tensor_4d(ctx, lib.GGML_TYPE_F16, 2, 3, 4, 5)\n        a = np.arange(2 * 3 * 4 * 5, dtype=np.float16).reshape((2, 3, 4, 5))\n        self._test_copy_np_to_ggml(a, t)\n\n    def test_copy_np_to_ggml_4d_n_f16(self, ctx):\n        dims = [2, 3, 4, 5] # GGML_MAX_DIMS is 4, going beyond would crash\n        pdims = ffi.new('int64_t[]', len(dims))\n        for i, d in enumerate(dims): pdims[i] = d\n        t = lib.ggml_new_tensor(ctx, lib.GGML_TYPE_F16, len(dims), pdims)\n        a = np.arange(np.prod(dims), dtype=np.float16).reshape(tuple(pdims))\n        self._test_copy_np_to_ggml(a, t)\n\n    # Mismatching shapes\n\n    def test_copy_mismatching_shapes_1d(self, ctx):\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, 10)\n        a = np.arange(10, dtype=np.float32)\n        copy(a, t) # OK\n        \n        a = a.reshape((5, 2))\n        with raises(AssertionError): copy(a, t)\n        with raises(AssertionError): copy(t, a)\n            \n    def test_copy_mismatching_shapes_2d(self, ctx):\n        t = lib.ggml_new_tensor_2d(ctx, lib.GGML_TYPE_F32, 2, 3)\n        a = np.arange(6, dtype=np.float32)\n        copy(a.reshape((2, 3)), t) # OK\n        \n        a = a.reshape((3, 2))\n        with raises(AssertionError): copy(a, t)\n        with raises(AssertionError): copy(t, a)\n\n    def test_copy_mismatching_shapes_3d(self, ctx):\n        t = lib.ggml_new_tensor_3d(ctx, lib.GGML_TYPE_F32, 2, 3, 4)\n        a = np.arange(24, dtype=np.float32)\n        copy(a.reshape((2, 3, 4)), t) # OK\n        \n        a = a.reshape((2, 4, 
3))\n        with raises(AssertionError): copy(a, t)\n        with raises(AssertionError): copy(t, a)\n\n    def test_copy_mismatching_shapes_4d(self, ctx):\n        t = lib.ggml_new_tensor_4d(ctx, lib.GGML_TYPE_F32, 2, 3, 4, 5)\n        a = np.arange(24*5, dtype=np.float32)\n        copy(a.reshape((2, 3, 4, 5)), t) # OK\n        \n        a = a.reshape((2, 3, 5, 4))\n        with raises(AssertionError): copy(a, t)\n        with raises(AssertionError): copy(t, a)\n\n    def test_copy_f16_to_f32(self, ctx):\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, 1)\n        a = np.array([123.45], dtype=np.float16)\n        copy(a, t)\n        np.testing.assert_allclose(lib.ggml_get_f32_1d(t, 0), 123.45, rtol=1e-3)\n\n    def test_copy_f32_to_f16(self, ctx):\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F16, 1)\n        a = np.array([123.45], dtype=np.float32)\n        copy(a, t)\n        np.testing.assert_allclose(lib.ggml_get_f32_1d(t, 0), 123.45, rtol=1e-3)\n\n    def test_copy_f16_to_Q5_K(self, ctx):\n        n = 256\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)\n        a = np.arange(n, dtype=np.float16)\n        copy(a, t)\n        np.testing.assert_allclose(a, numpy(t, allow_copy=True), rtol=0.05)\n\n    def test_copy_Q5_K_to_f16(self, ctx):\n        n = 256\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)\n        copy(np.arange(n, dtype=np.float32), t)\n        a = np.arange(n, dtype=np.float16)\n        copy(t, a)\n        np.testing.assert_allclose(a, numpy(t, allow_copy=True), rtol=0.05)\n\n    def test_copy_i16_f32_mismatching_types(self, ctx):\n        t = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, 1)\n        a = np.arange(1, dtype=np.int16)\n        with raises(NotImplementedError): copy(a, t)\n        with raises(NotImplementedError): copy(t, a)\n\nclass TestTensorCopy:\n\n    def test_copy_self(self, ctx):\n        t = lib.ggml_new_i32(ctx, 42)\n        copy(t, t)\n        assert 
lib.ggml_get_i32_1d(t, 0) == 42\n\n    def test_copy_1d(self, ctx):\n        t1 = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, 10)\n        t2 = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, 10)\n        a = np.arange(10, dtype=np.float32)\n        copy(a, t1)\n        copy(t1, t2)\n        assert np.allclose(a, numpy(t2))\n        assert np.allclose(numpy(t1), numpy(t2))\n\nclass TestGraph:\n\n    def test_add(self, ctx):\n        n = 256\n        ta = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n)\n        tb = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n)\n        tsum = lib.ggml_add(ctx, ta, tb)\n        assert tsum.type == lib.GGML_TYPE_F32\n\n        gf = ffi.new('struct ggml_cgraph*')\n        lib.ggml_build_forward_expand(gf, tsum)\n\n        a = np.arange(0, n, dtype=np.float32)\n        b = np.arange(n, 0, -1, dtype=np.float32)\n        copy(a, ta)\n        copy(b, tb)\n\n        lib.ggml_graph_compute_with_ctx(ctx, gf, 1)\n\n        assert np.allclose(numpy(tsum, allow_copy=True), a + b)\n\nclass TestQuantization:\n\n    def test_quantized_add(self, ctx):\n        n = 256\n        ta = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_Q5_K, n)\n        tb = lib.ggml_new_tensor_1d(ctx, lib.GGML_TYPE_F32, n)\n        tsum = lib.ggml_add(ctx, ta, tb)\n        assert tsum.type == lib.GGML_TYPE_Q5_K\n\n        gf = ffi.new('struct ggml_cgraph*')\n        lib.ggml_build_forward_expand(gf, tsum)\n\n        a = np.arange(0, n, dtype=np.float32)\n        b = np.arange(n, 0, -1, dtype=np.float32)\n        copy(a, ta)\n        copy(b, tb)\n\n        lib.ggml_graph_compute_with_ctx(ctx, gf, 1)\n\n        unquantized_sum = a + b\n        sum = numpy(tsum, allow_copy=True)\n\n        diff = np.linalg.norm(unquantized_sum - sum, np.inf)\n        assert diff > 4\n        assert diff < 5\n"
  },
  {
    "path": "examples/sam/CMakeLists.txt",
    "content": "#\n# sam\n\nset(TEST_TARGET sam)\nadd_executable(${TEST_TARGET} sam.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common)\n\n#\n# sam-quantize\n\n#set(TEST_TARGET sam-quantize)\n#add_executable(${TEST_TARGET} quantize.cpp)\n#target_link_libraries(${TEST_TARGET} PRIVATE ggml common)\n"
  },
  {
    "path": "examples/sam/README.md",
    "content": "# SAM.cpp\n\nInference of Meta's [Segment Anything Model](https://github.com/facebookresearch/segment-anything/) in pure C/C++\n\n## Description\n\nThe example currently supports only the [ViT-B SAM model checkpoint](https://huggingface.co/facebook/sam-vit-base).\n\n## Next steps\n\n- [X] Reduce memory usage by utilizing the new ggml-alloc\n- [X] Remove redundant graph nodes\n- [ ] Make inference faster\n- [X] Fix the difference in output masks compared to the PyTorch implementation\n- [X] Filter masks based on stability score\n- [ ] Add support for user input\n- [ ] Support F16 for heavy F32 ops\n- [ ] Test quantization\n- [X] Support bigger model checkpoints\n- [ ] GPU support\n\n## Quick start\nSetup Python and build examples according to main README.\n\n```bash\n# Download PTH model\nwget -P examples/sam/ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\n\n# Convert PTH model to ggml\npython examples/sam/convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth examples/sam/ 1\n\n# run inference\n./bin/sam -t 16 -i ../examples/sam/example.jpg -m ../examples/sam/ggml-model-f16.bin\n```\n\n## Downloading and converting the model checkpoints\n\nYou can download a [model checkpoint](https://github.com/facebookresearch/segment-anything/tree/main#model-checkpoints) and convert it to `ggml` format using the script `convert-pth-to-ggml.py`:\n\n## Example output on M2 Ultra\n```\n $ ▶ make -j sam && time ./bin/sam -t 8 -i img.jpg\n[ 28%] Built target common\n[ 71%] Built target ggml\n[100%] Built target sam\nmain: seed = 1693224265\nmain: loaded image 'img.jpg' (680 x 453)\nsam_image_preprocess: scale = 0.664062\nmain: preprocessed image (1024 x 1024)\nsam_model_load: loading model from 'models/sam-vit-b/ggml-model-f16.bin' - please wait ...\nsam_model_load: n_enc_state      = 768\nsam_model_load: n_enc_layer      = 12\nsam_model_load: n_enc_head       = 12\nsam_model_load: n_enc_out_chans  = 256\nsam_model_load: n_pt_embd        = 
4\nsam_model_load: ftype            = 1\nsam_model_load: qntvr            = 0\noperator(): ggml ctx size = 202.32 MB\nsam_model_load: ...................................... done\nsam_model_load: model size =   185.05 MB / num tensors = 304\nembd_img\ndims: 64 64 256 1 f32\nFirst & Last 10 elements:\n-0.05117 -0.06408 -0.07154 -0.06991 -0.07212 -0.07690 -0.07508 -0.07281 -0.07383 -0.06779\n0.01589 0.01775 0.02250 0.01675 0.01766 0.01661 0.01811 0.02051 0.02103 0.03382\nsum:  12736.272313\n\nSkipping mask 0 with iou 0.705935 below threshold 0.880000\nSkipping mask 1 with iou 0.762136 below threshold 0.880000\nMask 2: iou = 0.947081, stability_score = 0.955437, bbox (371, 436), (144, 168)\n\n\nmain:     load time =    51.28 ms\nmain:    total time =  2047.49 ms\n\nreal\t0m2.068s\nuser\t0m16.343s\nsys\t0m0.214s\n```\n\nInput point is (414.375, 162.796875) (currently hardcoded)\n\nInput image:\n\n![llamas](https://user-images.githubusercontent.com/8558655/261301565-37b7bf4b-bf91-40cf-8ec1-1532316e1612.jpg)\n\nOutput mask (mask_out_2.png in build folder):\n\n![mask_glasses](https://user-images.githubusercontent.com/8558655/263706800-47eeea30-1457-4c87-938b-8f11536c5aa7.png)\n\n## References\n\n- [ggml](https://github.com/ggerganov/ggml)\n- [SAM](https://segment-anything.com/)\n- [SAM demo](https://segment-anything.com/demo)\n"
  },
  {
    "path": "examples/sam/convert-pth-to-ggml.py",
    "content": "# Convert a SAM model checkpoint to a ggml compatible file\n#\n\nimport sys\nimport torch\nimport struct\nimport numpy as np\n\nif len(sys.argv) < 3:\n    print(\"Usage: convert-pth-to-ggml.py file-model dir-output [ftype]\\n\")\n    print(\"  ftype == 0 -> float32\")\n    print(\"  ftype == 1 -> float16\")\n    sys.exit(1)\n\n# output in the same directory as the model\nfname_model = sys.argv[1]\ndir_out     = sys.argv[2]\nfname_out   = dir_out + \"/ggml-model.bin\"\n\n# possible data types\n#   ftype == 0 -> float32\n#   ftype == 1 -> float16\n#\n# map from ftype to string\nftype_str = [\"f32\", \"f16\"]\n\nftype = 1\nif len(sys.argv) > 3:\n    ftype = int(sys.argv[3])\n\nif ftype < 0 or ftype > 1:\n    print(\"Invalid ftype: \" + str(ftype))\n    sys.exit(1)\n\nfname_out = fname_out.replace(\".bin\", \"-\" + ftype_str[ftype] + \".bin\")\n\n# Default params are set to sam_vit_b checkpoint\nn_enc_state = 768\nn_enc_layers = 12\nn_enc_heads = 12\nn_enc_out_chans = 256\nn_pt_embd = 4\n\nmodel = torch.load(fname_model, map_location=\"cpu\")\nfor k, v in model.items():\n    print(k, v.shape)\n    if k == \"image_encoder.blocks.0.norm1.weight\":\n        n_enc_state = v.shape[0]\n\nif n_enc_state == 1024: # sam_vit_l\n    n_enc_layers = 24\n    n_enc_heads  = 16\nelif n_enc_state == 1280: # sam_vit_h\n    n_enc_layers = 32\n    n_enc_heads  = 16\n\nhparams = {\n    \"n_enc_state\":      n_enc_state,\n    \"n_enc_layers\":     n_enc_layers,\n    \"n_enc_heads\":      n_enc_heads,\n    \"n_enc_out_chans\":  n_enc_out_chans,\n    \"n_pt_embd\":        n_pt_embd,\n}\n\nprint(hparams)\n\nfor k, v in model.items():\n    print(k, v.shape)\n\n#exit()\n#code.interact(local=locals())\n\nfout = open(fname_out, \"wb\")\n\nfout.write(struct.pack(\"i\", 0x67676d6c)) # magic: ggml in hex\nfout.write(struct.pack(\"i\", hparams[\"n_enc_state\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_enc_layers\"]))\nfout.write(struct.pack(\"i\", 
hparams[\"n_enc_heads\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_enc_out_chans\"]))\nfout.write(struct.pack(\"i\", hparams[\"n_pt_embd\"]))\nfout.write(struct.pack(\"i\", ftype))\n\nfor k, v in model.items():\n    name = k\n    shape = v.shape\n\n    if name[:19] == \"prompt_encoder.mask\":\n        continue\n\n    print(\"Processing variable: \" + name + \" with shape: \", shape, \" and type: \", v.dtype)\n\n    #data = tf.train.load_variable(dir_model, name).squeeze()\n    #data = v.numpy().squeeze()\n    data = v.numpy()\n    n_dims = len(data.shape)\n\n    # for efficiency - transpose some matrices\n    # \"model/h.*/attn/c_attn/w\"\n    # \"model/h.*/attn/c_proj/w\"\n    # \"model/h.*/mlp/c_fc/w\"\n    # \"model/h.*/mlp/c_proj/w\"\n    #if name[-14:] == \"/attn/c_attn/w\" or \\\n    #   name[-14:] == \"/attn/c_proj/w\" or \\\n    #   name[-11:] == \"/mlp/c_fc/w\" or \\\n    #   name[-13:] == \"/mlp/c_proj/w\":\n    #    print(\"  Transposing\")\n    #    data = data.transpose()\n\n    dshape = data.shape\n\n    # default type is fp16\n    ftype_cur = 1\n    if ftype == 0 or n_dims == 1 or \\\n            name == \"image_encoder.pos_embed\" or \\\n            name.startswith(\"prompt_encoder\") or \\\n            name.startswith(\"mask_decoder.iou_token\") or \\\n            name.startswith(\"mask_decoder.mask_tokens\"):\n        print(\"  Converting to float32\")\n        data = data.astype(np.float32)\n        ftype_cur = 0\n    else:\n        print(\"  Converting to float16\")\n        data = data.astype(np.float16)\n\n    # reshape the 1D bias into a 4D tensor so we can use ggml_repeat\n    # keep it in F32 since the data is small\n    if name == \"image_encoder.patch_embed.proj.bias\":\n        data = data.reshape(1, data.shape[0], 1, 1)\n        n_dims = len(data.shape)\n        dshape = data.shape\n\n    print(\"  New shape: \", dshape)\n\n    # header\n    str = name.encode('utf-8')\n    fout.write(struct.pack(\"iii\", n_dims, len(str), 
ftype_cur))\n    for i in range(n_dims):\n        fout.write(struct.pack(\"i\", dshape[n_dims - 1 - i]))\n    fout.write(str)\n\n    # data\n    data.tofile(fout)\n\nfout.close()\n\nprint(\"Done. Output file: \" + fname_out)\nprint(\"\")\n"
  },
  {
    "path": "examples/sam/sam.cpp",
    "content": "#define _USE_MATH_DEFINES // for M_PI\n#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous \"unsafe\" warnigns on Windows\n\n#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n#define STB_IMAGE_IMPLEMENTATION\n#include \"stb_image.h\"\n#define STB_IMAGE_WRITE_IMPLEMENTATION\n#include \"stb_image_write.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstddef>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n#include <thread>\n#include <cinttypes>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\n// default hparams (ViT-B SAM)\nstruct sam_hparams {\n    int32_t n_enc_state               = 768;\n    int32_t n_enc_layer               = 12;\n    int32_t n_enc_head                = 12;\n    int32_t n_enc_out_chans           = 256;\n    int32_t n_pt_embd                 = 4;\n    int32_t n_dec_heads               = 8;\n    int32_t ftype                     = 1;\n    float   mask_threshold            = 0.f;\n    float   iou_threshold             = 0.88f;\n    float   stability_score_threshold = 0.95f;\n    float   stability_score_offset    = 1.0f;\n    float   eps                       = 1e-6f;\n    float   eps_decoder_transformer   = 1e-5f;\n\n    int32_t n_enc_head_dim() const { return n_enc_state / n_enc_head; }\n    int32_t n_img_size()     const { return 1024; }\n    int32_t n_window_size()  const { return 14; }\n    int32_t n_patch_size()   const { return 16; }\n    int32_t n_img_embd()     const { return n_img_size() / n_patch_size(); }\n\n    std::vector<int32_t> global_attn_indices() const {\n        switch (n_enc_state) {\n            case  768: return {  2,  5,  8, 11 };\n            case 1024: return {  5, 11, 17, 23 };\n            case 1280: return {  7, 15, 23, 31 };\n            default:\n                {\n                    fprintf(stderr, \"%s: unsupported 
n_enc_state = %d\\n\", __func__, n_enc_state);\n                } break;\n        };\n\n        return {};\n    }\n\n    bool is_global_attn(int32_t layer) const {\n        const auto indices = global_attn_indices();\n\n        for (const auto & idx : indices) {\n            if (layer == idx) {\n                return true;\n            }\n        }\n\n        return false;\n    }\n};\n\nstruct sam_layer_enc {\n    struct ggml_tensor * norm1_w;\n    struct ggml_tensor * norm1_b;\n\n    struct ggml_tensor * rel_pos_w;\n    struct ggml_tensor * rel_pos_h;\n\n    struct ggml_tensor * qkv_w;\n    struct ggml_tensor * qkv_b;\n\n    struct ggml_tensor * proj_w;\n    struct ggml_tensor * proj_b;\n\n    struct ggml_tensor * norm2_w;\n    struct ggml_tensor * norm2_b;\n\n    struct ggml_tensor * mlp_lin1_w;\n    struct ggml_tensor * mlp_lin1_b;\n\n    struct ggml_tensor * mlp_lin2_w;\n    struct ggml_tensor * mlp_lin2_b;\n};\n\nstruct sam_encoder_image {\n    struct ggml_tensor * pe;\n\n    struct ggml_tensor * proj_w;\n    struct ggml_tensor * proj_b;\n\n    struct ggml_tensor * neck_conv_0;\n    struct ggml_tensor * neck_norm_0_w;\n    struct ggml_tensor * neck_norm_0_b;\n    struct ggml_tensor * neck_conv_1;\n    struct ggml_tensor * neck_norm_1_w;\n    struct ggml_tensor * neck_norm_1_b;\n\n    std::vector<sam_layer_enc> layers;\n};\n\nstruct sam_encoder_prompt {\n    struct ggml_tensor * pe;\n\n    struct ggml_tensor * not_a_pt_embd_w;\n    std::vector<struct ggml_tensor *> pt_embd;\n\n    struct ggml_tensor * no_mask_embd_w;\n    //std::vector<struct ggml_tensor *> mask_down_w;\n    //std::vector<struct ggml_tensor *> mask_down_b;\n};\n\nstruct  sam_layer_dec_transformer_attn {\n    // q_proj\n    struct ggml_tensor * q_w;\n    struct ggml_tensor * q_b;\n\n    // k_proj\n    struct ggml_tensor * k_w;\n    struct ggml_tensor * k_b;\n\n    // v_proj\n    struct ggml_tensor * v_w;\n    struct ggml_tensor * v_b;\n\n    // out_proj\n    struct ggml_tensor * out_w;\n    
struct ggml_tensor * out_b;\n};\n\nstruct sam_layer_dec_transformer {\n    sam_layer_dec_transformer_attn self_attn;\n\n    // norm1\n    struct ggml_tensor * norm1_w;\n    struct ggml_tensor * norm1_b;\n\n    sam_layer_dec_transformer_attn cross_attn_token_to_img;\n\n    // norm2\n    struct ggml_tensor * norm2_w;\n    struct ggml_tensor * norm2_b;\n\n    // mlp.lin1\n    struct ggml_tensor * mlp_lin1_w;\n    struct ggml_tensor * mlp_lin1_b;\n\n    // mlp.lin2\n    struct ggml_tensor * mlp_lin2_w;\n    struct ggml_tensor * mlp_lin2_b;\n\n    // norm3\n    struct ggml_tensor * norm3_w;\n    struct ggml_tensor * norm3_b;\n\n    // norm4\n    struct ggml_tensor * norm4_w;\n    struct ggml_tensor * norm4_b;\n\n    sam_layer_dec_transformer_attn cross_attn_img_to_token;\n};\n\nstruct sam_layer_dec_output_hypernet_mlps {\n    // mlps_*.layers.0\n    struct ggml_tensor * w_0;\n    struct ggml_tensor * b_0;\n\n    // mlps_*.layers.1\n    struct ggml_tensor * w_1;\n    struct ggml_tensor * b_1;\n\n    // mlps_*.layers.2\n    struct ggml_tensor * w_2;\n    struct ggml_tensor * b_2;\n};\n\nstruct sam_decoder_mask {\n    std::vector<sam_layer_dec_transformer> transformer_layers;\n\n    // trasnformer.final_attn_token_to_image\n    sam_layer_dec_transformer_attn transformer_final_attn_token_to_img;\n\n    // transformer.norm_final\n    struct ggml_tensor * transformer_norm_final_w;\n    struct ggml_tensor * transformer_norm_final_b;\n\n    // output_upscaling.0\n    struct ggml_tensor * output_upscaling_0_w;\n    struct ggml_tensor * output_upscaling_0_b;\n\n    // output_upscaling.1\n    struct ggml_tensor * output_upscaling_1_w;\n    struct ggml_tensor * output_upscaling_1_b;\n\n    // output_upscaling.3\n    struct ggml_tensor * output_upscaling_3_w;\n    struct ggml_tensor * output_upscaling_3_b;\n\n    // output_hypernetworks_mlps\n    std::vector<sam_layer_dec_output_hypernet_mlps> output_hypernet_mlps;\n\n    // iou_prediction_head.0\n    struct ggml_tensor * 
iou_prediction_head_0_w;\n    struct ggml_tensor * iou_prediction_head_0_b;\n\n    // iou_prediction_head.1\n    struct ggml_tensor * iou_prediction_head_1_w;\n    struct ggml_tensor * iou_prediction_head_1_b;\n\n    // iou_prediction_head.2\n    struct ggml_tensor * iou_prediction_head_2_w;\n    struct ggml_tensor * iou_prediction_head_2_b;\n\n    // iou_token.weight\n    struct ggml_tensor * iou_token_w;\n\n    // mask_tokens.weight\n    struct ggml_tensor * mask_tokens_w;\n};\n\n\nstruct sam_state {\n    struct ggml_tensor * embd_img;\n\n    struct ggml_tensor * low_res_masks;\n    struct ggml_tensor * iou_predictions;\n\n    //struct ggml_tensor * tmp_save = {};\n\n    struct ggml_context * ctx;\n\n    // buffer for `ggml_graph_plan.work_data`\n    std::vector<uint8_t> work_buffer;\n    // buffers to evaluate the model\n    std::vector<uint8_t> buf_compute_img_enc;\n\n    std::vector<uint8_t> buf_compute_fast;\n\n    ggml_gallocr_t       allocr = {};\n};\n\n// void save_tensor(sam_state& state, struct ggml_tensor * t, struct ggml_cgraph * gf) {\n//     if (!state.tmp_save) {\n//         state.tmp_save = ggml_new_tensor(state.ctx, t->type, t->n_dims, t->ne);\n//     }\n//     struct ggml_tensor * tmp0 = ggml_cpy(state.ctx, t, state.tmp_save);\n//     ggml_build_forward_expand(gf, tmp0);\n// }\n\nstruct sam_model {\n    sam_hparams hparams;\n\n    sam_encoder_image  enc_img;\n    sam_encoder_prompt enc_prompt;\n    sam_decoder_mask   dec;\n\n    //\n    struct ggml_context * ctx;\n    std::map<std::string, struct ggml_tensor *> tensors;\n};\n\nstruct sam_point {\n    float x;\n    float y;\n};\n\nstruct sam_box {\n    float x1;\n    float y1;\n    float x2;\n    float y2;\n};\n\n// RGB uint8 image\nstruct sam_image_u8 {\n    int nx;\n    int ny;\n\n    std::vector<uint8_t> data;\n};\n\n// RGB float32 image\n// Memory layout: RGBRGBRGB...\nstruct sam_image_f32 {\n    int nx;\n    int ny;\n\n    std::vector<float> data;\n};\n\nenum sam_prompt_type {\n    
SAM_PROMPT_TYPE_POINT = 0,\n    SAM_PROMPT_TYPE_BOX = 1,\n};\n\nstruct sam_prompt {\n    sam_prompt_type prompt_type = SAM_PROMPT_TYPE_POINT;\n    sam_point pt = { 414.375f, 162.796875f, };\n    sam_box   box = { 368.0f, 144.0f, 441.0f, 173.0f };\n};\n\nstruct sam_params {\n    int32_t seed      = -1; // RNG seed\n    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());\n\n    std::string model     = \"models/sam-vit-b/ggml-model-f16.bin\"; // model path\n    std::string fname_inp = \"img.jpg\";\n    std::string fname_out = \"img.out\";\n    float   mask_threshold            = 0.f;\n    float   iou_threshold             = 0.88f;\n    float   stability_score_threshold = 0.95f;\n    float   stability_score_offset    = 1.0f;\n    float   eps                       = 1e-6f;\n    float   eps_decoder_transformer   = 1e-5f;\n\n    sam_prompt prompt;\n    bool multimask_output = true;\n};\n\nvoid print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {\n    printf(\"%s\\n\", title);\n    float * data = (float *)t->data;\n    printf(\"dims: % \" PRId64 \" % \" PRId64 \" % \" PRId64 \" % \" PRId64 \" f32\\n\", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);\n    printf(\"First & Last %d elements:\\n\", n);\n    for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {\n        printf(\"%.5f \", data[i]);\n        if (i != 0 && i % t->ne[0] == 0) {\n            printf(\"\\n\");\n        }\n    }\n    printf(\"\\n\");\n    for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {\n        printf(\"%.5f \", data[ggml_nelements(t) - n + i]);\n        if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {\n            printf(\"\\n\");\n        }\n    }\n    printf(\"\\n\");\n    double sum = 0.0;\n    for (int i = 0; i < ggml_nelements(t); i++) {\n        sum += data[i];\n    }\n    printf(\"sum:  %f\\n\\n\", sum);\n}\n\nstatic void ggml_disconnect_node_from_graph(ggml_tensor * t) {\n    t->op = GGML_OP_NONE;\n    for (int i = 0; i < 
GGML_MAX_SRC; i++) {\n        t->src[i] = NULL;\n    }\n}\n\nstatic void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {\n    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);\n\n    if (plan.work_size > 0) {\n        buf.resize(plan.work_size);\n        plan.work_data = buf.data();\n    }\n\n    ggml_graph_compute(graph, &plan);\n}\n\nstatic void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {\n    GGML_ASSERT(userdata == NULL);\n    GGML_ASSERT(ggml_are_same_shape(dst, src));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(src));\n\n    const float * src_data = ggml_get_data_f32(src);\n    float * dst_data = ggml_get_data_f32(dst);\n\n    const int ne = (int)ggml_nelements(dst);\n    const int dr = (ne + nth - 1) / nth;\n    const int ie0 = dr * ith;\n    const int ie1 = std::min(ie0 + dr, ne);\n\n    for (int i = ie0; i < ie1; ++i) {\n        dst_data[i] = sinf(src_data[i]);\n    }\n}\n\nstatic void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {\n    GGML_ASSERT(userdata == NULL);\n    GGML_ASSERT(ggml_are_same_shape(dst, src));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(src));\n\n    const float * src_data = ggml_get_data_f32(src);\n    float * dst_data = ggml_get_data_f32(dst);\n\n    const int ne = (int)ggml_nelements(dst);\n    const int dr = (ne + nth - 1) / nth;\n    const int ie0 = dr * ith;\n    const int ie1 = std::min(ie0 + dr, ne);\n\n    for (int i = ie0; i < ie1; ++i) {\n        dst_data[i] = cosf(src_data[i]);\n    }\n}\n\nbool sam_image_load_from_file(const std::string & fname, sam_image_u8 & img) {\n    int nx, ny, nc;\n    auto data = stbi_load(fname.c_str(), &nx, &ny, &nc, 3);\n    if (!data) {\n        fprintf(stderr, \"%s: failed to load '%s'\\n\", __func__, fname.c_str());\n        return 
false;\n    }\n\n    img.nx = nx;\n    img.ny = ny;\n    img.data.resize(nx * ny * 3);\n    memcpy(img.data.data(), data, nx * ny * 3);\n\n    stbi_image_free(data);\n\n    return true;\n}\n\n// ref: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/modeling/sam.py#L164\n// resize largest dimension to 1024\n// normalize: x = (x - mean) / std\n//     mean = [123.675, 116.28, 103.53]\n//     std  = [58.395, 57.12, 57.375]\n//     TODO: why are these hardcoded !?\n// pad to 1024x1024\n// TODO: for some reason, this is not numerically identical to pytorch's interpolation\nbool sam_image_preprocess(const sam_image_u8 & img, sam_image_f32 & res) {\n    const int nx = img.nx;\n    const int ny = img.ny;\n\n    const int nx2 = 1024;\n    const int ny2 = 1024;\n\n    res.nx = nx2;\n    res.ny = ny2;\n    res.data.resize(3*nx2*ny2);\n\n    const float scale = std::max(nx, ny) / 1024.0f;\n\n    fprintf(stderr, \"%s: scale = %f\\n\", __func__, scale);\n\n    const int nx3 = int(nx/scale + 0.5f);\n    const int ny3 = int(ny/scale + 0.5f);\n\n    const float m3[3] = { 123.675f, 116.280f, 103.530f };\n    const float s3[3] = {  58.395f,  57.120f,  57.375f };\n\n    for (int y = 0; y < ny3; y++) {\n        for (int x = 0; x < nx3; x++) {\n            for (int c = 0; c < 3; c++) {\n                // linear interpolation\n                const float sx = (x + 0.5f)*scale - 0.5f;\n                const float sy = (y + 0.5f)*scale - 0.5f;\n\n                const int x0 = std::max(0, (int) std::floor(sx));\n                const int y0 = std::max(0, (int) std::floor(sy));\n\n                const int x1 = std::min(x0 + 1, nx - 1);\n                const int y1 = std::min(y0 + 1, ny - 1);\n\n                const float dx = sx - x0;\n                const float dy = sy - y0;\n\n                const int j00 = 3*(y0*nx + x0) + c;\n                const int j01 = 3*(y0*nx + x1) + c;\n                const int j10 = 
3*(y1*nx + x0) + c;\n                const int j11 = 3*(y1*nx + x1) + c;\n\n                const float v00 = img.data[j00];\n                const float v01 = img.data[j01];\n                const float v10 = img.data[j10];\n                const float v11 = img.data[j11];\n\n                const float v0 = v00*(1.0f - dx) + v01*dx;\n                const float v1 = v10*(1.0f - dx) + v11*dx;\n\n                const float v = v0*(1.0f - dy) + v1*dy;\n\n                const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);\n\n                const int i = 3*(y*nx3 + x) + c;\n\n                res.data[i] = (float(v2) - m3[c]) / s3[c];\n            }\n        }\n    }\n\n    return true;\n}\n\n// load the model's weights from a file\nbool sam_model_load(const sam_params & params, sam_model & model) {\n    fprintf(stderr, \"%s: loading model from '%s' - please wait ...\\n\", __func__, params.model.c_str());\n\n    auto fin = std::ifstream(params.model, std::ios::binary);\n    if (!fin) {\n        fprintf(stderr, \"%s: failed to open '%s'\\n\", __func__, params.model.c_str());\n        return false;\n    }\n\n    // verify magic\n    {\n        uint32_t magic;\n        fin.read((char *) &magic, sizeof(magic));\n        if (magic != 0x67676d6c) {\n            fprintf(stderr, \"%s: invalid model file '%s' (bad magic)\\n\", __func__, params.model.c_str());\n            return false;\n        }\n    }\n\n    // load hparams\n    {\n        // Override defaults with user choices\n        model.hparams.mask_threshold            = params.mask_threshold;\n        model.hparams.iou_threshold             = params.iou_threshold;\n        model.hparams.stability_score_threshold = params.stability_score_threshold;\n        model.hparams.stability_score_offset    = params.stability_score_offset;\n        model.hparams.eps                       = params.eps;\n        model.hparams.eps_decoder_transformer   = params.eps_decoder_transformer;\n\n        auto & hparams = 
model.hparams;\n\n        fin.read((char *) &hparams.n_enc_state,     sizeof(hparams.n_enc_state));\n        fin.read((char *) &hparams.n_enc_layer,     sizeof(hparams.n_enc_layer));\n        fin.read((char *) &hparams.n_enc_head,      sizeof(hparams.n_enc_head));\n        fin.read((char *) &hparams.n_enc_out_chans, sizeof(hparams.n_enc_out_chans));\n        fin.read((char *) &hparams.n_pt_embd,       sizeof(hparams.n_pt_embd));\n        fin.read((char *) &hparams.ftype,           sizeof(hparams.ftype));\n\n        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;\n\n        printf(\"%s: n_enc_state      = %d\\n\", __func__, hparams.n_enc_state);\n        printf(\"%s: n_enc_layer      = %d\\n\", __func__, hparams.n_enc_layer);\n        printf(\"%s: n_enc_head       = %d\\n\", __func__, hparams.n_enc_head);\n        printf(\"%s: n_enc_out_chans  = %d\\n\", __func__, hparams.n_enc_out_chans);\n        printf(\"%s: n_pt_embd        = %d\\n\", __func__, hparams.n_pt_embd);\n        printf(\"%s: ftype            = %d\\n\", __func__, hparams.ftype);\n        printf(\"%s: qntvr            = %d\\n\", __func__, qntvr);\n\n        hparams.ftype %= GGML_QNT_VERSION_FACTOR;\n\n    }\n\n    // for the big tensors, we have the option to store the data in 16-bit floats or quantized\n    // in order to save memory and also to speed up the computation\n    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));\n    if (wtype == GGML_TYPE_COUNT) {\n        fprintf(stderr, \"%s: invalid model file '%s' (bad ftype value %d)\\n\",\n                __func__, params.model.c_str(), model.hparams.ftype);\n        return false;\n    }\n\n    auto & ctx = model.ctx;\n\n    const size_t ctx_size = [&]() {\n        size_t ctx_size = 0;\n\n        const auto & hparams = model.hparams;\n\n        const int32_t n_enc_state     = hparams.n_enc_state;\n        const int32_t n_enc_layer     = hparams.n_enc_layer;\n        const int32_t n_enc_head_dim  = 
hparams.n_enc_head_dim();\n        const int32_t n_enc_out_chans = hparams.n_enc_out_chans;\n        const int32_t n_pt_embd       = hparams.n_pt_embd;\n\n        const int32_t n_enc_layer_local  = hparams.global_attn_indices().size();\n        const int32_t n_enc_layer_global = n_enc_layer - n_enc_layer_local;\n\n        const int32_t n_img_embd    = hparams.n_img_embd();\n        const int32_t n_window_size = hparams.n_window_size();\n        const int32_t n_patch_size  = hparams.n_patch_size();\n\n        // image encoder\n        {\n            ctx_size += n_enc_state*n_img_embd*n_img_embd*ggml_type_size(GGML_TYPE_F32);\n\n            ctx_size += n_enc_state*3*n_patch_size*n_patch_size*ggml_type_size(GGML_TYPE_F16);\n            ctx_size += n_enc_state*ggml_type_size(GGML_TYPE_F32);\n\n            ctx_size +=     n_enc_state*n_enc_out_chans*1*1*ggml_type_size(GGML_TYPE_F16);\n            ctx_size += n_enc_out_chans*n_enc_out_chans*3*3*ggml_type_size(GGML_TYPE_F16);\n\n            ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32);\n            ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32);\n\n            ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32);\n            ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32);\n        }\n\n        // image encoder layers\n        {\n            ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32);\n            ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32);\n\n            ctx_size += n_enc_layer_global*n_enc_head_dim*(2*n_img_embd - 1)*ggml_type_size(GGML_TYPE_F16);\n            ctx_size += n_enc_layer_global*n_enc_head_dim*(2*n_img_embd - 1)*ggml_type_size(GGML_TYPE_F16);\n\n            ctx_size += n_enc_layer_local*n_enc_head_dim*(2*n_window_size - 1)*ggml_type_size(GGML_TYPE_F16);\n            ctx_size += n_enc_layer_local*n_enc_head_dim*(2*n_window_size - 1)*ggml_type_size(GGML_TYPE_F16);\n\n            ctx_size += 
n_enc_layer*3*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16);\n            ctx_size += n_enc_layer*3*n_enc_state*            ggml_type_size(GGML_TYPE_F32);\n\n            ctx_size += n_enc_layer*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16);\n            ctx_size += n_enc_layer*n_enc_state*            ggml_type_size(GGML_TYPE_F32);\n\n            ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32);\n            ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32);\n\n            ctx_size += n_enc_layer*4*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16);\n            ctx_size += n_enc_layer*4*n_enc_state*            ggml_type_size(GGML_TYPE_F32);\n\n            ctx_size += n_enc_layer*4*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16);\n            ctx_size += n_enc_layer*4*n_enc_state*            ggml_type_size(GGML_TYPE_F32);\n        }\n\n        ctx_size += (8 + 14*n_enc_layer)*ggml_tensor_overhead();\n\n        // prompt encoder\n        {\n            ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); // 2*(n_enc_out_chans/2)\n\n            ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32);\n            ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F32);\n        }\n\n        ctx_size += (2 + n_pt_embd)*ggml_tensor_overhead();\n\n        // mask decoder\n        {\n            //transformer\n            {\n                const int tfm_layers_count = 2;\n                const int qkv_count = 3;\n                const int norm_count = 4;\n                const int n_hypernet_mpls_count = 4;\n\n                // self_attn\n                ctx_size += tfm_layers_count*qkv_count*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += tfm_layers_count*qkv_count*n_enc_state*            ggml_type_size(GGML_TYPE_F32);\n                ctx_size += tfm_layers_count*n_enc_state*                      ggml_type_size(GGML_TYPE_F32);\n\n                // all 
norms\n                ctx_size += tfm_layers_count*norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32);\n                ctx_size += tfm_layers_count*norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32);\n\n                // cross_attn_token_to_img\n                ctx_size += tfm_layers_count*qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += tfm_layers_count*qkv_count*(n_enc_state/2)*            ggml_type_size(GGML_TYPE_F32);\n                ctx_size += tfm_layers_count*n_enc_state*                          ggml_type_size(GGML_TYPE_F32);\n\n                // mlp\n                ctx_size += tfm_layers_count*8*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += tfm_layers_count*8*n_enc_out_chans*                ggml_type_size(GGML_TYPE_F32);\n                ctx_size += tfm_layers_count*n_enc_out_chans*8*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += tfm_layers_count*n_enc_out_chans*                  ggml_type_size(GGML_TYPE_F32);\n\n                // cross_attn_img_to_token\n                ctx_size += tfm_layers_count*qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += tfm_layers_count*qkv_count*(n_enc_state/2)*            ggml_type_size(GGML_TYPE_F32);\n                ctx_size += tfm_layers_count*n_enc_state*                          ggml_type_size(GGML_TYPE_F32);\n\n                // transformer_final_attn_token_to_img\n                ctx_size += qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += qkv_count*(n_enc_state/2)*            ggml_type_size(GGML_TYPE_F32);\n                ctx_size += n_enc_state*                          ggml_type_size(GGML_TYPE_F32);\n\n                // transformer_norm_final\n                ctx_size += norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32);\n                ctx_size += 
norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32);\n\n                // output_upscaling\n                ctx_size += n_enc_out_chans*n_img_embd*2*2*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += 3*n_img_embd*                  ggml_type_size(GGML_TYPE_F32);\n                ctx_size += n_enc_out_chans*n_img_embd*(n_img_embd/2)*2*2*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += (n_img_embd/2)*                               ggml_type_size(GGML_TYPE_F32);\n\n                // output_hypernetworks_mlps\n                ctx_size += n_hypernet_mpls_count*2*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += n_hypernet_mpls_count*2*n_enc_out_chans*                ggml_type_size(GGML_TYPE_F32);\n                ctx_size += n_hypernet_mpls_count*n_enc_out_chans*(n_img_embd/2)*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += n_hypernet_mpls_count*(n_img_embd/2)*                ggml_type_size(GGML_TYPE_F32);\n\n                // iou_prediction_head\n                ctx_size += 2*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += 2*n_enc_out_chans*                ggml_type_size(GGML_TYPE_F32);\n                ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16);\n                ctx_size += n_pt_embd*                ggml_type_size(GGML_TYPE_F32);\n\n                // iou_token_w\n                ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32);\n\n                // mask_tokens_w\n                ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F32);\n            }\n        }\n        fprintf(stderr, \"%s: ggml ctx size = %6.2f MB\\n\", __func__, ctx_size/(1024.0*1024.0));\n\n        return ctx_size;\n    }();\n\n    // create the ggml context\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ ctx_size,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ false,\n        
};\n\n        ctx = ggml_init(params);\n        if (!ctx) {\n            fprintf(stderr, \"%s: ggml_init() failed\\n\", __func__);\n            return false;\n        }\n    }\n\n    // prepare memory for the weights\n    {\n        const auto & hparams = model.hparams;\n\n        const int32_t n_enc_state      = hparams.n_enc_state;\n        const int32_t n_enc_layer      = hparams.n_enc_layer;\n        const int32_t n_enc_head_dim   = hparams.n_enc_head_dim();\n        const int32_t n_enc_out_chans  = hparams.n_enc_out_chans;\n        const int32_t n_pt_embd        = hparams.n_pt_embd;\n\n        const int32_t n_img_embd    = hparams.n_img_embd();\n        const int32_t n_window_size = hparams.n_window_size();\n        const int32_t n_patch_size  = hparams.n_patch_size();\n\n        model.enc_img.layers.resize(n_enc_layer);\n\n        // image encoder\n        {\n            auto & enc = model.enc_img;\n\n            enc.pe = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_enc_state, n_img_embd, n_img_embd, 1);\n\n            enc.proj_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_patch_size, n_patch_size,           3, n_enc_state);\n            enc.proj_b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,            1,            1, n_enc_state);\n\n            enc.neck_conv_0 = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, n_enc_state,     n_enc_out_chans);\n            enc.neck_conv_1 = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, n_enc_out_chans, n_enc_out_chans);\n\n            enc.neck_norm_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n            enc.neck_norm_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n            enc.neck_norm_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n            enc.neck_norm_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n            model.tensors[\"image_encoder.pos_embed\"] = enc.pe;\n\n            model.tensors[\"image_encoder.patch_embed.proj.weight\"] = enc.proj_w;\n      
      model.tensors[\"image_encoder.patch_embed.proj.bias\"]   = enc.proj_b;\n\n            model.tensors[\"image_encoder.neck.0.weight\"] = enc.neck_conv_0;\n            model.tensors[\"image_encoder.neck.2.weight\"] = enc.neck_conv_1;\n\n            model.tensors[\"image_encoder.neck.1.weight\"] = enc.neck_norm_0_w;\n            model.tensors[\"image_encoder.neck.1.bias\"]   = enc.neck_norm_0_b;\n\n            model.tensors[\"image_encoder.neck.3.weight\"] = enc.neck_norm_1_w;\n            model.tensors[\"image_encoder.neck.3.bias\"]   = enc.neck_norm_1_b;\n\n            for (int i = 0; i < n_enc_layer; ++i) {\n                auto & layer = enc.layers[i];\n\n                layer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);\n                layer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);\n\n                if (hparams.is_global_attn(i)) {\n                    layer.rel_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_img_embd - 1);\n                    layer.rel_pos_h = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_img_embd - 1);\n                } else {\n                    layer.rel_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_window_size - 1);\n                    layer.rel_pos_h = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_window_size - 1);\n                }\n\n                layer.qkv_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,   n_enc_state, 3*n_enc_state);\n                layer.qkv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_enc_state);\n\n                layer.proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,  n_enc_state,   n_enc_state);\n                layer.proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  n_enc_state);\n\n                layer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);\n                layer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);\n\n                layer.mlp_lin1_w = 
ggml_new_tensor_2d(ctx, GGML_TYPE_F16,   n_enc_state, 4*n_enc_state);\n                layer.mlp_lin1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_enc_state);\n\n                layer.mlp_lin2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4*n_enc_state,   n_enc_state);\n                layer.mlp_lin2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_enc_state);\n\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".norm1.weight\"] = layer.norm1_w;\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".norm1.bias\"]   = layer.norm1_b;\n\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".attn.rel_pos_w\"] = layer.rel_pos_w;\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".attn.rel_pos_h\"] = layer.rel_pos_h;\n\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".attn.qkv.weight\"] = layer.qkv_w;\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".attn.qkv.bias\"]   = layer.qkv_b;\n\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".attn.proj.weight\"] = layer.proj_w;\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".attn.proj.bias\"]   = layer.proj_b;\n\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".norm2.weight\"] = layer.norm2_w;\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".norm2.bias\"]   = layer.norm2_b;\n\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".mlp.lin1.weight\"] = layer.mlp_lin1_w;\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".mlp.lin1.bias\"]   = layer.mlp_lin1_b;\n\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + \".mlp.lin2.weight\"] = layer.mlp_lin2_w;\n                model.tensors[\"image_encoder.blocks.\" + std::to_string(i) + 
\".mlp.lin2.bias\"]   = layer.mlp_lin2_b;\n            }\n        }\n\n        // prompt encoder\n        {\n            auto & enc = model.enc_prompt;\n\n            enc.pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans/2, 2);\n\n            enc.not_a_pt_embd_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n            enc.no_mask_embd_w  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n            model.tensors[\"prompt_encoder.pe_layer.positional_encoding_gaussian_matrix\"] = enc.pe;\n            model.tensors[\"prompt_encoder.not_a_point_embed.weight\"] = enc.not_a_pt_embd_w;\n            model.tensors[\"prompt_encoder.no_mask_embed.weight\"]     = enc.no_mask_embd_w;\n\n            enc.pt_embd.resize(n_pt_embd);\n            for (int i = 0; i < n_pt_embd; i++) {\n                enc.pt_embd[i] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                model.tensors[\"prompt_encoder.point_embeddings.\" + std::to_string(i) + \".weight\"] = enc.pt_embd[i];\n            }\n        }\n\n        // mask decoder\n        {\n            auto & dec = model.dec;\n            auto & tfm_layers = dec.transformer_layers;\n\n            const int tfm_layers_count = 2;\n            tfm_layers.resize(tfm_layers_count);\n            for (int i = 0; i < tfm_layers_count; ++i) {\n                auto& l = tfm_layers[i];\n                l.self_attn.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);\n                l.self_attn.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n                l.self_attn.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);\n                l.self_attn.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n                l.self_attn.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);\n                l.self_attn.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n             
   l.self_attn.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);\n                l.self_attn.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                l.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n                l.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                l.cross_attn_token_to_img.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n                l.cross_attn_token_to_img.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n                l.cross_attn_token_to_img.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n                l.cross_attn_token_to_img.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n                l.cross_attn_token_to_img.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n                l.cross_attn_token_to_img.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n                l.cross_attn_token_to_img.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans);\n                l.cross_attn_token_to_img.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                l.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n                l.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                l.mlp_lin1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, 8*n_enc_out_chans);\n                l.mlp_lin1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8*n_enc_out_chans);\n                l.mlp_lin2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 8*n_enc_out_chans, n_enc_out_chans);\n                l.mlp_lin2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                l.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n                l.norm3_b = 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                l.norm4_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n                l.norm4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                l.cross_attn_img_to_token.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n                l.cross_attn_img_to_token.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n                l.cross_attn_img_to_token.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n                l.cross_attn_img_to_token.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n                l.cross_attn_img_to_token.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n                l.cross_attn_img_to_token.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n                l.cross_attn_img_to_token.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans);\n                l.cross_attn_img_to_token.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n                const auto prefix = \"mask_decoder.transformer.layers.\" + std::to_string(i) + \".\";\n                model.tensors[prefix + \"self_attn.q_proj.weight\"] = l.self_attn.q_w;\n                model.tensors[prefix + \"self_attn.q_proj.bias\"]   = l.self_attn.q_b;\n                model.tensors[prefix + \"self_attn.k_proj.weight\"] = l.self_attn.k_w;\n                model.tensors[prefix + \"self_attn.k_proj.bias\"]   = l.self_attn.k_b;\n                model.tensors[prefix + \"self_attn.v_proj.weight\"] = l.self_attn.v_w;\n                model.tensors[prefix + \"self_attn.v_proj.bias\"]   = l.self_attn.v_b;\n                model.tensors[prefix + \"self_attn.out_proj.weight\"] = l.self_attn.out_w;\n                model.tensors[prefix + \"self_attn.out_proj.bias\"]   = l.self_attn.out_b;\n\n             
   model.tensors[prefix + \"norm1.weight\"] = l.norm1_w;\n                model.tensors[prefix + \"norm1.bias\"]   = l.norm1_b;\n\n                model.tensors[prefix + \"cross_attn_token_to_image.q_proj.weight\"] = l.cross_attn_token_to_img.q_w;\n                model.tensors[prefix + \"cross_attn_token_to_image.q_proj.bias\"]   = l.cross_attn_token_to_img.q_b;\n                model.tensors[prefix + \"cross_attn_token_to_image.k_proj.weight\"] = l.cross_attn_token_to_img.k_w;\n                model.tensors[prefix + \"cross_attn_token_to_image.k_proj.bias\"]   = l.cross_attn_token_to_img.k_b;\n                model.tensors[prefix + \"cross_attn_token_to_image.v_proj.weight\"] = l.cross_attn_token_to_img.v_w;\n                model.tensors[prefix + \"cross_attn_token_to_image.v_proj.bias\"]   = l.cross_attn_token_to_img.v_b;\n                model.tensors[prefix + \"cross_attn_token_to_image.out_proj.weight\"] = l.cross_attn_token_to_img.out_w;\n                model.tensors[prefix + \"cross_attn_token_to_image.out_proj.bias\"]   = l.cross_attn_token_to_img.out_b;\n\n                model.tensors[prefix + \"norm2.weight\"] = l.norm2_w;\n                model.tensors[prefix + \"norm2.bias\"]   = l.norm2_b;\n\n                model.tensors[prefix + \"mlp.lin1.weight\"] = l.mlp_lin1_w;\n                model.tensors[prefix + \"mlp.lin1.bias\"]   = l.mlp_lin1_b;\n                model.tensors[prefix + \"mlp.lin2.weight\"] = l.mlp_lin2_w;\n                model.tensors[prefix + \"mlp.lin2.bias\"]   = l.mlp_lin2_b;\n\n                model.tensors[prefix + \"norm3.weight\"] = l.norm3_w;\n                model.tensors[prefix + \"norm3.bias\"]   = l.norm3_b;\n                model.tensors[prefix + \"norm4.weight\"] = l.norm4_w;\n                model.tensors[prefix + \"norm4.bias\"]   = l.norm4_b;\n\n                model.tensors[prefix + \"cross_attn_image_to_token.q_proj.weight\"] = l.cross_attn_img_to_token.q_w;\n                model.tensors[prefix + 
\"cross_attn_image_to_token.q_proj.bias\"]   = l.cross_attn_img_to_token.q_b;\n                model.tensors[prefix + \"cross_attn_image_to_token.k_proj.weight\"] = l.cross_attn_img_to_token.k_w;\n                model.tensors[prefix + \"cross_attn_image_to_token.k_proj.bias\"]   = l.cross_attn_img_to_token.k_b;\n                model.tensors[prefix + \"cross_attn_image_to_token.v_proj.weight\"] = l.cross_attn_img_to_token.v_w;\n                model.tensors[prefix + \"cross_attn_image_to_token.v_proj.bias\"]   = l.cross_attn_img_to_token.v_b;\n                model.tensors[prefix + \"cross_attn_image_to_token.out_proj.weight\"] = l.cross_attn_img_to_token.out_w;\n                model.tensors[prefix + \"cross_attn_image_to_token.out_proj.bias\"]   = l.cross_attn_img_to_token.out_b;\n            }\n\n            dec.transformer_final_attn_token_to_img.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n            dec.transformer_final_attn_token_to_img.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n            dec.transformer_final_attn_token_to_img.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n            dec.transformer_final_attn_token_to_img.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n            dec.transformer_final_attn_token_to_img.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);\n            dec.transformer_final_attn_token_to_img.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);\n            dec.transformer_final_attn_token_to_img.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans);\n            dec.transformer_final_attn_token_to_img.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n            model.tensors[\"mask_decoder.transformer.final_attn_token_to_image.q_proj.weight\"] = dec.transformer_final_attn_token_to_img.q_w;\n            
model.tensors[\"mask_decoder.transformer.final_attn_token_to_image.q_proj.bias\"]   = dec.transformer_final_attn_token_to_img.q_b;\n            model.tensors[\"mask_decoder.transformer.final_attn_token_to_image.k_proj.weight\"] = dec.transformer_final_attn_token_to_img.k_w;\n            model.tensors[\"mask_decoder.transformer.final_attn_token_to_image.k_proj.bias\"]   = dec.transformer_final_attn_token_to_img.k_b;\n            model.tensors[\"mask_decoder.transformer.final_attn_token_to_image.v_proj.weight\"] = dec.transformer_final_attn_token_to_img.v_w;\n            model.tensors[\"mask_decoder.transformer.final_attn_token_to_image.v_proj.bias\"]   = dec.transformer_final_attn_token_to_img.v_b;\n            model.tensors[\"mask_decoder.transformer.final_attn_token_to_image.out_proj.weight\"] = dec.transformer_final_attn_token_to_img.out_w;\n            model.tensors[\"mask_decoder.transformer.final_attn_token_to_image.out_proj.bias\"]   = dec.transformer_final_attn_token_to_img.out_b;\n\n            dec.transformer_norm_final_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n            dec.transformer_norm_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n\n            model.tensors[\"mask_decoder.transformer.norm_final_attn.weight\"] = dec.transformer_norm_final_w;\n            model.tensors[\"mask_decoder.transformer.norm_final_attn.bias\"]   = dec.transformer_norm_final_b;\n\n            dec.output_upscaling_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, n_img_embd, n_enc_out_chans);\n            dec.output_upscaling_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd);\n            dec.output_upscaling_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd);\n            dec.output_upscaling_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd);\n            dec.output_upscaling_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,  2, 2, n_img_embd/2, n_img_embd);\n            dec.output_upscaling_3_b = 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd/2);\n\n            model.tensors[\"mask_decoder.output_upscaling.0.weight\"] = dec.output_upscaling_0_w;\n            model.tensors[\"mask_decoder.output_upscaling.0.bias\"]   = dec.output_upscaling_0_b;\n            model.tensors[\"mask_decoder.output_upscaling.1.weight\"] = dec.output_upscaling_1_w;\n            model.tensors[\"mask_decoder.output_upscaling.1.bias\"]   = dec.output_upscaling_1_b;\n            model.tensors[\"mask_decoder.output_upscaling.3.weight\"] = dec.output_upscaling_3_w;\n            model.tensors[\"mask_decoder.output_upscaling.3.bias\"]   = dec.output_upscaling_3_b;\n\n            const int n_hypernet_mpls_count = 4;\n            dec.output_hypernet_mlps.resize(n_hypernet_mpls_count);\n            for (int i = 0; i < n_hypernet_mpls_count; ++i) {\n                auto& mlp = dec.output_hypernet_mlps[i];\n\n                mlp.w_0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);\n                mlp.b_0 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n                mlp.w_1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);\n                mlp.b_1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n                mlp.w_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_img_embd/2);\n                mlp.b_2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd/2);\n\n                const auto prefix = \"mask_decoder.output_hypernetworks_mlps.\" + std::to_string(i) + \".\";\n                model.tensors[prefix + \"layers.0.weight\"] = mlp.w_0;\n                model.tensors[prefix + \"layers.0.bias\"]   = mlp.b_0;\n                model.tensors[prefix + \"layers.1.weight\"] = mlp.w_1;\n                model.tensors[prefix + \"layers.1.bias\"]   = mlp.b_1;\n                model.tensors[prefix + \"layers.2.weight\"] = mlp.w_2;\n                model.tensors[prefix + \"layers.2.bias\"]   = mlp.b_2;\n            
}\n\n            dec.iou_prediction_head_0_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);\n            dec.iou_prediction_head_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n            dec.iou_prediction_head_1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);\n            dec.iou_prediction_head_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);\n            dec.iou_prediction_head_2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_pt_embd);\n            dec.iou_prediction_head_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pt_embd);\n\n            dec.iou_token_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans, 1);\n            dec.mask_tokens_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans, n_pt_embd);\n\n            model.tensors[\"mask_decoder.iou_prediction_head.layers.0.weight\"] = dec.iou_prediction_head_0_w;\n            model.tensors[\"mask_decoder.iou_prediction_head.layers.0.bias\"]   = dec.iou_prediction_head_0_b;\n            model.tensors[\"mask_decoder.iou_prediction_head.layers.1.weight\"] = dec.iou_prediction_head_1_w;\n            model.tensors[\"mask_decoder.iou_prediction_head.layers.1.bias\"]   = dec.iou_prediction_head_1_b;\n            model.tensors[\"mask_decoder.iou_prediction_head.layers.2.weight\"] = dec.iou_prediction_head_2_w;\n            model.tensors[\"mask_decoder.iou_prediction_head.layers.2.bias\"]   = dec.iou_prediction_head_2_b;\n\n            model.tensors[\"mask_decoder.iou_token.weight\"] = dec.iou_token_w;\n            model.tensors[\"mask_decoder.mask_tokens.weight\"] = dec.mask_tokens_w;\n        }\n    }\n\n    // load weights\n    {\n        int n_tensors = 0;\n        size_t total_size = 0;\n\n        fprintf(stderr, \"%s: \", __func__);\n\n        while (true) {\n            int32_t n_dims;\n            int32_t length;\n            int32_t ftype;\n\n            fin.read(reinterpret_cast<char 
*>(&n_dims), sizeof(n_dims));\n            fin.read(reinterpret_cast<char *>(&length), sizeof(length));\n            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));\n\n            if (fin.eof()) {\n                break;\n            }\n\n            int64_t nelements = 1;\n            int64_t ne[4] = { 1, 1, 1, 1 };\n            for (int i = 0; i < n_dims; ++i) {\n                int32_t ne_cur;\n                fin.read(reinterpret_cast<char *>(&ne_cur), sizeof(ne_cur));\n                ne[i] = ne_cur;\n                nelements *= ne[i];\n            }\n\n            std::string name(length, 0);\n            fin.read(&name[0], length);\n\n            if (model.tensors.find(name.data()) == model.tensors.end()) {\n                fprintf(stderr, \"%s: unknown tensor '%s' in model file\\n\", __func__, name.data());\n                return false;\n            }\n\n            auto tensor = model.tensors[name.data()];\n            //printf(\"ne0 = %jd, ne1 = %jd, ne2 = %jd, ne3 = %jd\\n\", ne[0], ne[1], ne[2], ne[3]);\n\n            if (ggml_nelements(tensor) != nelements) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file: got %d, expected %d\\n\",\n                        __func__, name.data(), (int) nelements, (int) ggml_nelements(tensor));\n                return false;\n            }\n\n            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\\n\",\n                        __func__, name.data(),\n                        (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3],\n                        (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3]);\n                return false;\n            }\n\n            size_t bpe = 0;\n\n            switch (ftype) {\n                case 0: bpe = 
ggml_type_size(GGML_TYPE_F32);  break;\n                case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;\n                case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;\n                case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;\n                default:\n                        {\n                            fprintf(stderr, \"%s: unknown ftype %d in model file\\n\", __func__, ftype);\n                            return false;\n                        }\n            };\n\n            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {\n                fprintf(stderr, \"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\\n\",\n                        __func__, name.data(), ggml_nbytes(tensor), (size_t) nelements*bpe);\n                return false;\n            }\n\n            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));\n\n            total_size += ggml_nbytes(tensor);\n            if (++n_tensors % 8 == 0) {\n                fprintf(stderr, \".\");\n                fflush(stdout);\n            }\n        }\n\n        if (n_tensors != int(model.tensors.size())) {\n            fprintf(stderr, \"%s: model file has %d tensors, but %d tensors were expected\\n\", __func__, n_tensors, (int) model.tensors.size());\n            return false;\n        }\n\n        fprintf(stderr, \" done\\n\");\n\n        fprintf(stderr, \"%s: model size = %8.2f MB / num tensors = %d\\n\", __func__, total_size/1024.0/1024.0, n_tensors);\n    }\n\n    fin.close();\n\n    return true;\n}\n\nstruct ggml_tensor * sam_fill_dense_pe(\n            const sam_model   & model,\n          struct ggml_context * ctx0,\n          struct ggml_cgraph  * gf,\n                  sam_state   & state) {\n    const auto & hparams = model.hparams;\n    const auto & enc     = model.enc_prompt;\n\n\n    const int32_t n_img_embd = hparams.n_img_embd();\n    struct ggml_tensor * 
xy_embed_stacked = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2, n_img_embd, n_img_embd);\n    ggml_set_name(xy_embed_stacked, \"xy_embed_stacked\");\n    ggml_set_input(xy_embed_stacked);\n\n    struct ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), xy_embed_stacked);\n\n    cur = ggml_scale(ctx0, cur, float(2.0*M_PI));\n\n    // concat\n    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192\n    {\n        struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);\n        struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);\n\n        cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);\n\n        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0)));\n        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1])));\n    }\n\n    struct ggml_tensor * pe_img_dense = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));\n    ggml_build_forward_expand(gf, pe_img_dense);\n\n    return pe_img_dense;\n}\n\nstruct ggml_tensor* sam_layer_norm_2d(\n                    struct ggml_context * ctx0,\n                    struct ggml_tensor  * layer,\n                    int                   n_channels,\n                    struct ggml_tensor  * w,\n                    struct ggml_tensor  * b,\n                    float                 eps) {\n    // LayerNorm2d\n    // normalize along channel dimmension\n    // TODO: better implementation\n    layer = ggml_permute(ctx0,\n                ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps),\n                2, 0, 1, 3);\n\n    layer = ggml_add(ctx0,\n              ggml_mul(ctx0,\n     
             ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer),\n                  layer),\n              ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer));\n\n    return layer;\n}\n\nstruct ggml_cgraph  * sam_encode_image(\n            const sam_model & model,\n                  sam_state & state,\n        const sam_image_f32 & img) {\n\n    const auto & hparams = model.hparams;\n    const auto & enc     = model.enc_img;\n\n    const int32_t n_enc_state     = hparams.n_enc_state;\n    const int32_t n_enc_layer     = hparams.n_enc_layer;\n    const int32_t n_enc_head      = hparams.n_enc_head;\n    const int32_t n_enc_head_dim  = hparams.n_enc_head_dim();\n    const int32_t n_enc_out_chans = hparams.n_enc_out_chans;\n    const int32_t n_img_size    = hparams.n_img_size();\n    const int32_t n_window_size = hparams.n_window_size();\n\n    struct ggml_init_params ggml_params = {\n        /*.mem_size   =*/ state.buf_compute_img_enc.size(),\n        /*.mem_buffer =*/ state.buf_compute_img_enc.data(),\n        /*.no_alloc   =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements\n    };\n\n    struct ggml_context * ctx0   = ggml_init(ggml_params);\n    struct ggml_cgraph  * gf     = ggml_new_graph(ctx0);\n\n    struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_img_size, n_img_size, 3, 1);\n    ggml_set_name(inp, \"inp\");\n    ggml_set_input(inp);\n\n    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L392\n    struct ggml_tensor * cur = ggml_conv_2d_sk_p0(ctx0, enc.proj_w, inp);\n    cur = ggml_add_inplace(ctx0,\n            cur,\n            ggml_repeat(ctx0, enc.proj_b, cur));\n\n    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L394\n    // keep in F32\n    cur = ggml_cont(ctx0,\n            ggml_permute(ctx0, cur, 1, 2, 0, 3));\n\n    // convert 
to F16\n    //cur = ggml_cpy(ctx0,\n    //        ggml_permute(ctx0, cur, 1, 2, 0, 3),\n    //        ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_enc_state, n_img_embd, n_img_embd));\n\n    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L108-L109\n    cur = ggml_add_inplace(ctx0, cur, enc.pe);\n\n    struct ggml_tensor * inpL = cur;\n\n    for (int il = 0; il < n_enc_layer; ++il) {\n        const auto & layer = enc.layers[il];\n\n        // norm\n        // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168\n        {\n            cur = ggml_norm(ctx0, inpL, hparams.eps);\n\n            // cur = ln_0_w*cur + ln_0_b\n            cur = ggml_mul(ctx0, cur, layer.norm1_w);\n            cur = ggml_add_inplace(ctx0, cur, layer.norm1_b);\n        }\n\n        const int64_t w0 = cur->ne[1];\n        const int64_t h0 = cur->ne[2];\n\n        if (hparams.is_global_attn(il) == false) {\n            // local attention layer - apply window partition\n            // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172\n            cur = ggml_win_part(ctx0, cur, n_window_size);\n        }\n\n        const int64_t W = cur->ne[1];\n        const int64_t H = cur->ne[2];\n\n        // self-attention\n        {\n            cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);\n            cur = ggml_add_inplace(ctx0, cur, layer.qkv_b);\n\n            // split qkv into separate tensors\n            // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L225-L229\n            const int B = cur->ne[3];\n\n            cur = ggml_reshape_4d(ctx0, cur, n_enc_state, 3, W*H, B);\n            cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2));\n\n            struct ggml_tensor * Q;\n            struct ggml_tensor * K;\n            struct 
ggml_tensor * V;\n\n            Q = ggml_view_3d   (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 0*cur->nb[3]);\n            Q = ggml_reshape_4d(ctx0, Q,   n_enc_head_dim, n_enc_head, W*H, B);\n            Q = ggml_cont      (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));\n            Q = ggml_reshape_3d(ctx0, Q,   n_enc_head_dim, W*H, B*n_enc_head);\n\n            K = ggml_view_3d   (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 1*cur->nb[3]);\n            K = ggml_reshape_4d(ctx0, K,   n_enc_head_dim, n_enc_head, W*H, B);\n            K = ggml_cont      (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));\n            K = ggml_reshape_3d(ctx0, K,   n_enc_head_dim, W*H, B*n_enc_head);\n\n            V = ggml_view_3d   (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 2*cur->nb[3]);\n            V = ggml_reshape_4d(ctx0, V,   n_enc_head_dim, n_enc_head, W*H, B);\n            V = ggml_cont      (ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // transposed\n            V = ggml_reshape_3d(ctx0, V,   W*H, n_enc_head_dim, B*n_enc_head);\n\n            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);\n\n            struct ggml_tensor * KQ_scaled =\n                ggml_scale_inplace(ctx0,\n                        KQ,\n                        1.0f/sqrtf(n_enc_head_dim));\n\n            struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W);\n            struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H);\n\n            struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Q, n_enc_head_dim, W, H, B*n_enc_head);\n\n            struct ggml_tensor * rel_w = ggml_cont(ctx0, ggml_permute(ctx0,\n                        ggml_mul_mat(ctx0,\n                            rw,\n                            ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),\n                        0, 2, 1, 3));\n            struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);\n\n            struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, 
KQ_scaled, rel_w, rel_h);\n\n            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);\n\n            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);\n\n            cur =\n                ggml_reshape_4d(ctx0,\n                        ggml_cont(ctx0,\n                            ggml_permute(ctx0,\n                                ggml_reshape_4d(ctx0, KQV, n_enc_head_dim, W*H, n_enc_head, B),\n                                0, 2, 1, 3)),\n                        n_enc_state, W, H, B);\n\n            cur = ggml_mul_mat(ctx0, layer.proj_w, cur);\n            cur = ggml_add_inplace(ctx0, cur, layer.proj_b);\n        }\n\n        if (hparams.is_global_attn(il) == false) {\n            // local attention layer - reverse window partition\n            cur = ggml_win_unpart(ctx0, cur, w0, h0, n_window_size);\n        }\n\n        cur = ggml_add_inplace(ctx0, cur, inpL);\n\n        struct ggml_tensor * inpFF = cur;\n\n        // feed-forward network\n        {\n            // norm\n            {\n                cur = ggml_norm(ctx0, inpFF, hparams.eps);\n\n                // cur = mlp_ln_w*cur + mlp_ln_b\n                cur = ggml_mul(ctx0, cur, layer.norm2_w);\n                cur = ggml_add_inplace(ctx0, cur, layer.norm2_b);\n            }\n\n            // fully connected\n            cur = ggml_mul_mat(ctx0, layer.mlp_lin1_w, cur);\n            cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin1_b);\n\n            // GELU activation\n            cur = ggml_gelu(ctx0, cur);\n\n            // projection\n            cur = ggml_mul_mat(ctx0, layer.mlp_lin2_w, cur);\n            cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin2_b);\n        }\n\n        inpL = ggml_add(ctx0, cur, inpFF);\n    }\n\n    cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));\n\n    cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);\n\n    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b, hparams.eps);\n\n    
cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur);\n\n    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps);\n\n    cur = ggml_cpy(ctx0, cur, state.embd_img);\n\n    ggml_build_forward_expand(gf, cur);\n    ggml_disconnect_node_from_graph(state.embd_img);\n\n    //ggml_graph_print(&gf);\n\n    ggml_free(ctx0);\n\n    ggml_gallocr_alloc_graph(state.allocr, gf);\n\n    {\n        struct ggml_tensor * inp = ggml_graph_get_tensor(gf, \"inp\");\n        float * data = (float *) ggml_get_data(inp);\n\n        const int nx = img.nx;\n        const int ny = img.ny;\n        const int n  = nx*ny;\n\n        GGML_ASSERT(nx == n_img_size && ny == n_img_size);\n\n        for (int k = 0; k < 3; k++) {\n            for (int y = 0; y < ny; y++) {\n                for (int x = 0; x < nx; x++) {\n                    data[k*n + y*nx + x] = img.data[3*(y*nx + x) + k];\n                }\n            }\n        }\n    }\n\n    return gf;\n}\n\n\nstruct prompt_encoder_result {\n    struct ggml_tensor * embd_prompt_sparse = {};\n    struct ggml_tensor * embd_prompt_dense = {};\n};\n\nstruct ggml_tensor * sam_prompt_encode_pe_encoding(\n        const sam_encoder_prompt & enc,\n        struct ggml_context      * ctx0,\n        struct ggml_cgraph       * gf,\n        struct ggml_tensor       * coords) {\n\n    auto * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), coords);\n    cur = ggml_scale(ctx0, cur, float(2.0*M_PI));\n\n    // concat\n    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192\n    {\n        struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);\n        struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);\n\n        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]);\n\n        
ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], 0)));\n        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], t_sin->nb[1])));\n\n    }\n    return cur;\n\n}\n// encode a prompt\n//\n// - points\n// - boxes\n// - masks\n//\n// TODO: currently just encode a single point for simplicity\n//\nprompt_encoder_result sam_encode_prompt(\n        const sam_model     & model,\n        struct ggml_context * ctx0,\n        struct ggml_cgraph  * gf,\n                  sam_state & state,\n           const sam_prompt & prompt) {\n\n    const auto & hparams = model.hparams;\n    const auto & enc = model.enc_prompt;\n\n    struct ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, 2);\n    ggml_set_name(inp, \"prompt_input\");\n    ggml_set_input(inp);\n\n    auto * embd_prompt_sparse = [&]() -> struct ggml_tensor * {\n        switch (prompt.prompt_type) {\n        case SAM_PROMPT_TYPE_POINT: {\n            // PromptEncoder._embed_points\n            auto * pt_embd  = sam_prompt_encode_pe_encoding(enc, ctx0, gf, inp);\n\n            // overwrite label == -1 with not_a_point_embed.weight\n            // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L86\n            // TODO: extend for multiple points\n            auto * pt_embd_not = ggml_view_2d(ctx0, pt_embd, pt_embd->ne[0], 1, pt_embd->nb[1], pt_embd->nb[1]);\n            ggml_build_forward_expand(gf, ggml_cpy(ctx0, enc.not_a_pt_embd_w, pt_embd_not));\n\n            // add point_embeddings[1] to label == 1\n            // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L90\n            auto * pt_embd1 = ggml_view_2d(ctx0, pt_embd, pt_embd->ne[0], 1, pt_embd->nb[1], 0);\n            ggml_build_forward_expand(gf,  ggml_add_inplace(ctx0, pt_embd1, 
enc.pt_embd[1]));\n\n            return pt_embd;\n        } break;\n        case SAM_PROMPT_TYPE_BOX: {\n            // PromptEncoder._embed_boxes\n            auto * corner_embd = sam_prompt_encode_pe_encoding(enc, ctx0, gf, inp);\n\n            // corner_embd[:, 0, :] += self.point_embeddings[2].weight\n            // corner_embd[:, 1, :] += self.point_embeddings[3].weight\n            auto * corner0 = ggml_view_2d(\n                ctx0, corner_embd, corner_embd->ne[0], 1, corner_embd->nb[1], 0);\n            auto * corner1 = ggml_view_2d(\n                ctx0, corner_embd, corner_embd->ne[0], 1, corner_embd->nb[1], corner_embd->nb[1]);\n\n            ggml_build_forward_expand(gf, ggml_add_inplace(ctx0, corner0, enc.pt_embd[2]));\n            ggml_build_forward_expand(gf, ggml_add_inplace(ctx0, corner1, enc.pt_embd[3]));\n\n            return corner_embd;\n        } break;\n        default: {\n            fprintf(stderr, \"%s: unsupported prompt type %d\\n\", __func__, prompt.prompt_type);\n            return nullptr;\n        } break;\n        }\n    }();\n\n    ggml_build_forward_expand(gf, embd_prompt_sparse);\n\n    struct ggml_tensor * embd_prompt_dense = ggml_repeat(ctx0,\n            ggml_cont(ctx0,\n                ggml_view_3d(ctx0, enc.no_mask_embd_w,\n                    1, 1, enc.no_mask_embd_w->ne[0], enc.no_mask_embd_w->nb[0], enc.no_mask_embd_w->nb[0], 0)),\n            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hparams.n_img_embd(), hparams.n_img_embd(), hparams.n_enc_out_chans));\n\n    ggml_build_forward_expand(gf, embd_prompt_dense);\n\n    //printf(\"used_mem = %zu\\n\", ggml_used_mem(ctx0));\n\n    prompt_encoder_result res;\n    res.embd_prompt_sparse = embd_prompt_sparse;\n    res.embd_prompt_dense  = embd_prompt_dense;\n    return res;\n}\n\nstruct ggml_tensor* sam_decode_mask_transformer_attn(\n    const sam_layer_dec_transformer_attn & attn,\n                      struct ggml_tensor * queries,\n                      struct ggml_tensor * 
keys,\n                      struct ggml_tensor * values,\n                     struct ggml_context * ctx0,\n                         const sam_model & model) {\n    const auto & hparams = model.hparams;\n    const int n_head = hparams.n_dec_heads;\n\n    struct ggml_tensor * Qcur = {};\n    struct ggml_tensor * Kcur = {};\n    struct ggml_tensor * Vcur = {};\n\n    Qcur = ggml_mul_mat(ctx0, attn.q_w, queries);\n    Qcur = ggml_add_inplace(ctx0, Qcur, attn.q_b);\n\n    Kcur = ggml_mul_mat(ctx0, attn.k_w, keys);\n    Kcur = ggml_add_inplace(ctx0, Kcur, attn.k_b);\n\n    Vcur = ggml_mul_mat(ctx0, attn.v_w, values);\n    Vcur = ggml_add_inplace(ctx0, Vcur, attn.v_b);\n\n    struct ggml_tensor * Q = {};\n    struct ggml_tensor * K = {};\n    struct ggml_tensor * V = {};\n\n    Q = ggml_reshape_4d(ctx0, Qcur, Qcur->ne[0]/n_head, n_head, Qcur->ne[1], Qcur->ne[2]);\n    Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));\n\n    K = ggml_reshape_4d(ctx0, Kcur, Kcur->ne[0]/n_head, n_head, Kcur->ne[1], Kcur->ne[2]);\n    K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));\n\n    V = ggml_reshape_4d(ctx0, Vcur, Vcur->ne[0]/n_head, n_head, Vcur->ne[1], Vcur->ne[2]);\n    V = ggml_cont(ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3));\n\n    // Q * K\n    struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);\n\n    struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f/sqrt(float(Q->ne[0])));\n\n    struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_scaled);\n\n    struct ggml_tensor * KQV = ggml_mul_mat(ctx0, KQ_soft_max, ggml_cont(ctx0, ggml_transpose(ctx0, V)));\n\n    struct ggml_tensor * KQV_merged = ggml_cont(ctx0, ggml_transpose(ctx0, KQV));\n    KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3));\n    KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]);\n    KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged);\n    KQV_merged = ggml_add_inplace(ctx0, 
KQV_merged, attn.out_b);\n\n    return KQV_merged;\n}\n\nstruct ggml_tensor * sam_decode_mask_mlp_relu_3(\n     struct ggml_tensor * in,\n     struct ggml_tensor * w_0,\n     struct ggml_tensor * b_0,\n     struct ggml_tensor * w_1,\n     struct ggml_tensor * b_1,\n     struct ggml_tensor * w_2,\n     struct ggml_tensor * b_2,\n    struct ggml_context * ctx0) {\n\n    struct ggml_tensor * cur = {};\n    cur = ggml_mul_mat(ctx0, w_0, in);\n    cur = ggml_add_inplace(ctx0, cur, b_0);\n\n    cur = ggml_relu_inplace(ctx0, cur);\n\n    cur = ggml_mul_mat(ctx0, w_1, cur);\n    cur = ggml_add_inplace(ctx0, cur, b_1);\n\n    cur = ggml_relu_inplace(ctx0, cur);\n\n    cur = ggml_mul_mat(ctx0, w_2, cur);\n    cur = ggml_add_inplace(ctx0, cur, b_2);\n\n    return cur;\n}\n\nbool sam_decode_mask(\n                    const sam_model & model,\n        const prompt_encoder_result & prompt,\n                 struct ggml_tensor * pe_img,\n                struct ggml_context * ctx0,\n                struct ggml_cgraph  * gf,\n                          sam_state & state,\n                         const bool   multimask_output) {\n\n    const auto & hparams = model.hparams;\n    const auto & dec = model.dec;\n    const int n_img_embd = hparams.n_img_embd();\n\n    struct ggml_tensor * tokens = {};\n    {\n        // Concatenate output tokens\n        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L120\n        const auto& sparse = prompt.embd_prompt_sparse;\n\n        tokens = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, dec.iou_token_w->ne[0], dec.iou_token_w->ne[1] + dec.mask_tokens_w->ne[1] + sparse->ne[1], sparse->ne[2]);\n\n        const size_t offsets[3] = { 0, dec.iou_token_w->ne[1]*tokens->nb[1], dec.iou_token_w->ne[1]*tokens->nb[1] + dec.mask_tokens_w->ne[1]*tokens->nb[1] };\n        ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.iou_token_w,   ggml_view_2d(ctx0, tokens, 
tokens->ne[0], dec.iou_token_w->ne[1],   tokens->nb[1], offsets[0])));\n        ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.mask_tokens_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.mask_tokens_w->ne[1], tokens->nb[1], offsets[1])));\n        ggml_build_forward_expand(gf, ggml_cpy(ctx0, sparse,            ggml_view_2d(ctx0, tokens, tokens->ne[0], sparse->ne[1],            tokens->nb[1], offsets[2])));\n        // TODO: Sparse prompt embeddings can have more than one point\n    }\n\n\n    struct ggml_tensor * src = {};\n    struct ggml_tensor * pos_src = {};\n    int srcNE[4] = { 0, 0, 0, 0 };\n    {\n        // Expand per-image data in the batch direction to be per-mask\n        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L125\n        src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, state.embd_img->ne[0], state.embd_img->ne[1], state.embd_img->ne[2], tokens->ne[2]);\n\n        src = ggml_add(ctx0,\n            ggml_repeat(ctx0,\n                state.embd_img,\n                src),\n            prompt.embd_prompt_dense);\n\n        srcNE[0] = src->ne[0];\n        srcNE[1] = src->ne[1];\n        srcNE[2] = src->ne[2];\n        srcNE[3] = src->ne[3];\n\n        // flatten & permute\n        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L83\n        src = ggml_cont(ctx0, ggml_permute(ctx0,\n            ggml_view_3d(ctx0,\n                src,\n                src->ne[0]*src->ne[1],\n                src->ne[2],\n                src->ne[3],\n                src->nb[2],\n                src->nb[3],\n                0),\n            1, 0, 2, 3));\n\n        pos_src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, pe_img->ne[0], pe_img->ne[1], pe_img->ne[2], tokens->ne[2]);\n        pos_src = ggml_repeat(ctx0,\n            pe_img,\n            pos_src);\n\n        // 
flatten & permute\n        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L83\n        pos_src = ggml_cont(ctx0, ggml_permute(ctx0,\n            ggml_view_3d(ctx0,\n                pos_src,\n                pos_src->ne[0]*pos_src->ne[1],\n                pos_src->ne[2],\n                pos_src->ne[3],\n                pos_src->nb[2],\n                pos_src->nb[3],\n                0),\n            1, 0, 2, 3));\n    }\n\n    struct ggml_tensor * queries = tokens;\n    struct ggml_tensor * keys = src;\n    {\n        // Run the transformer\n        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L62\n        for (int i = 0; i < int(model.dec.transformer_layers.size()); ++i) {\n            const auto& tfm_layer = model.dec.transformer_layers[i];\n\n            // Self attention block\n            // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L154\n            const bool skip_first_layer_pe = i == 0;\n            if (skip_first_layer_pe) {\n                queries = sam_decode_mask_transformer_attn(tfm_layer.self_attn, queries, queries, queries, ctx0, model);\n            }\n            else {\n                struct ggml_tensor * q_0 = ggml_add(ctx0, queries, tokens);\n\n                struct ggml_tensor * self_attn = sam_decode_mask_transformer_attn(tfm_layer.self_attn, q_0, q_0, queries, ctx0, model);\n                queries = ggml_add(ctx0, queries, self_attn);\n            }\n\n            queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);\n            queries = ggml_add_inplace(ctx0,\n                    ggml_mul(ctx0, queries, tfm_layer.norm1_w),\n                    tfm_layer.norm1_b);\n\n            // Cross attention block, tokens 
attending to image embedding\n            // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L163\n            struct ggml_tensor * q_1 = ggml_add(ctx0, queries, tokens);\n            struct ggml_tensor * k_1 = ggml_add(ctx0, keys, pos_src);\n\n            struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model);\n\n            queries = ggml_add_inplace(ctx0, queries, cross_attn_token_to_img);\n            queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);\n            queries = ggml_add_inplace(ctx0,\n                    ggml_mul(ctx0, queries, tfm_layer.norm2_w),\n                    tfm_layer.norm2_b);\n\n            // MLP block\n            // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L170\n            struct ggml_tensor * mlp_out = ggml_mul_mat(ctx0,\n                tfm_layer.mlp_lin1_w,\n                queries);\n\n            mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin1_b);\n\n            // RELU activation\n            mlp_out = ggml_relu_inplace(ctx0, mlp_out);\n            mlp_out = ggml_mul_mat(ctx0, tfm_layer.mlp_lin2_w, mlp_out);\n\n            mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin2_b);\n\n            queries = ggml_add_inplace(ctx0, queries, mlp_out);\n            queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);\n            queries = ggml_add_inplace(ctx0,\n                    ggml_mul(ctx0, queries, tfm_layer.norm3_w),\n                    tfm_layer.norm3_b);\n\n            // Cross attention block, image embedding attending to tokens\n            // ref: 
https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L175\n            struct ggml_tensor * q_2 = ggml_add(ctx0, queries, tokens);\n            struct ggml_tensor * k_2 = ggml_add(ctx0, keys, pos_src);\n\n            struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model);\n            keys = ggml_add_inplace(ctx0, keys, cross_attn_img_to_token);\n            keys = ggml_norm_inplace(ctx0, keys, hparams.eps_decoder_transformer);\n            keys = ggml_add_inplace(ctx0,\n                    ggml_mul(ctx0, keys, tfm_layer.norm4_w),\n                    tfm_layer.norm4_b);\n        }\n\n        // Apply the final attention layer from the points to the image\n        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L99\n        struct ggml_tensor * q = ggml_add(ctx0, queries, tokens);\n        struct ggml_tensor * k = ggml_add(ctx0, keys, pos_src);\n\n        struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model);\n\n        queries = ggml_add_inplace(ctx0, queries, final_attn_token_to_img);\n        queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);\n        queries = ggml_add_inplace(ctx0,\n                ggml_mul(ctx0, queries, dec.transformer_norm_final_w),\n                dec.transformer_norm_final_b);\n    }\n\n\n    struct ggml_tensor * iou_pred = ggml_view_2d(ctx0, queries, queries->ne[0], queries->ne[2], queries->nb[2], 0);\n    const int num_mask_tokens = 4; // num_multimask_outputs + 1\n    struct ggml_tensor * mask_tokens_out = ggml_view_3d(ctx0, queries, queries->ne[0], num_mask_tokens, queries->ne[2], queries->nb[1], num_mask_tokens*queries->nb[1], 
queries->nb[1]);\n\n    // Upscale mask embeddings and predict masks using the mask tokens\n    // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L136\n    keys = ggml_cont(ctx0, ggml_transpose(ctx0, keys));\n    keys = ggml_view_4d(ctx0, keys, srcNE[0], srcNE[1], srcNE[2], srcNE[3], srcNE[0]*keys->nb[0], keys->nb[1], keys->nb[2], 0);\n    // ggml_build_forward_expand(gf, keys);\n    struct ggml_tensor * upscaled_embedding = {};\n    {\n        // ConvTranspose2d\n        keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);\n        keys = ggml_add_inplace(ctx0, keys, ggml_repeat(ctx0,\n                                     ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]),\n                                     keys));\n\n        keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps);\n\n        // GELU activation\n        keys = ggml_gelu_inplace(ctx0, keys);\n\n        // ConvTranspose2d\n        keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2);\n        keys = ggml_add_inplace(ctx0, ggml_repeat(ctx0,\n                                ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]),\n                                keys), keys);\n        // GELU activation\n        keys = ggml_gelu_inplace(ctx0, keys);\n        upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]);\n        upscaled_embedding = ggml_cont(ctx0, ggml_transpose(ctx0, upscaled_embedding)); // TODO: Shouldn't be needed\n    }\n\n    struct ggml_tensor * hyper_in = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_img_embd/2, num_mask_tokens, mask_tokens_out->ne[2]);\n\n    for (int i = 0; i < num_mask_tokens; ++i) {\n        const auto& mlp = dec.output_hypernet_mlps[i];\n        struct 
ggml_tensor * in = ggml_view_2d(ctx0, mask_tokens_out, mask_tokens_out->ne[0], mask_tokens_out->ne[2], mask_tokens_out->nb[1], i*mask_tokens_out->nb[1]);\n        struct ggml_tensor * out = sam_decode_mask_mlp_relu_3(in, mlp.w_0, mlp.b_0, mlp.w_1, mlp.b_1, mlp.w_2, mlp.b_2, ctx0);\n        ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, ggml_view_2d(ctx0, hyper_in, hyper_in->ne[0], hyper_in->ne[2], hyper_in->nb[1], i*hyper_in->nb[1])));\n    }\n\n    struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding);\n    masks = ggml_cont(ctx0, ggml_transpose(ctx0, masks)); // TODO: Shouldn't be needed\n    masks = ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]);\n\n    // Generate mask quality predictions\n    // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146\n    iou_pred = sam_decode_mask_mlp_relu_3(iou_pred, dec.iou_prediction_head_0_w, dec.iou_prediction_head_0_b, dec.iou_prediction_head_1_w, dec.iou_prediction_head_1_b, dec.iou_prediction_head_2_w, dec.iou_prediction_head_2_b, ctx0);\n\n    // Select the correct mask or masks for output\n    // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L101\n    if (multimask_output) {\n        iou_pred = ggml_cpy(state.ctx, ggml_view_1d(ctx0, iou_pred, iou_pred->ne[0] - 1, iou_pred->nb[0]), state.iou_predictions);\n        masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], masks->ne[2] - 1, masks->ne[3],\n                                        masks->nb[1], masks->nb[2], masks->nb[3], masks->nb[2] /* offset*/);\n        masks = ggml_cpy(state.ctx, masks, state.low_res_masks);\n    } else {\n        iou_pred = ggml_cpy(state.ctx, ggml_view_1d(ctx0, iou_pred, 1, 0), ggml_view_1d(ctx0, state.iou_predictions, 1, 0));\n        masks = ggml_view_4d(ctx0, masks, 
masks->ne[0], masks->ne[1], 1, masks->ne[3],\n                                        masks->nb[1], masks->nb[2], masks->nb[3], 0);\n        auto * low_res_mask = ggml_view_4d(ctx0, state.low_res_masks, masks->ne[0], masks->ne[1], 1, masks->ne[3],\n                                        masks->nb[1], masks->nb[2], masks->nb[3], 0);\n        masks = ggml_cpy(state.ctx, masks, low_res_mask);\n    }\n\n    ggml_build_forward_expand(gf, masks);\n    ggml_build_forward_expand(gf, iou_pred);\n\n    ggml_disconnect_node_from_graph(state.low_res_masks);\n    ggml_disconnect_node_from_graph(state.iou_predictions);\n\n    return true;\n}\n\nbool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state, const std::string & fname, const bool multimask_output) {\n    if (state.low_res_masks->ne[2] == 0) return true;\n    if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {\n        printf(\"Error: number of masks (%d) does not match number of iou predictions (%d)\\n\", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);\n        return false;\n    }\n\n    const int n_img_size = hparams.n_img_size();\n    const float mask_threshold = hparams.mask_threshold;\n    const float iou_threshold = hparams.iou_threshold;\n    const float stability_score_threshold = hparams.stability_score_threshold;\n    const float intersection_threshold = mask_threshold + hparams.stability_score_offset;\n    const float union_threshold = mask_threshold - hparams.stability_score_offset;\n\n    const int ne0 = state.low_res_masks->ne[0];\n    const int ne1 = state.low_res_masks->ne[1];\n    const int ne2 = multimask_output ? 
state.low_res_masks->ne[2] : 1;\n\n    // Remove padding and upscale masks to the original image size.\n    // ref: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/modeling/sam.py#L140\n\n    const float preprocess_scale = std::max(nx, ny) / float(n_img_size);\n    const int cropped_nx = int(nx / preprocess_scale + 0.5f);\n    const int cropped_ny = int(ny / preprocess_scale + 0.5f);\n\n    const float scale_x_1 = (float)ne0 / (float)n_img_size;\n    const float scale_y_1 = (float)ne1 / (float)n_img_size;\n\n    const float scale_x_2 = float(cropped_nx) / float(nx);\n    const float scale_y_2 = float(cropped_ny) / float(ny);\n\n    const auto iou_data = (float*)state.iou_predictions->data;\n\n    for (int i = 0; i < ne2; ++i) {\n        if (iou_threshold > 0.f && iou_data[i] < iou_threshold) {\n            printf(\"Skipping mask %d with iou %f below threshold %f\\n\", i, iou_data[i], iou_threshold);\n            continue; // Filtering masks with iou below the threshold\n        }\n\n        std::vector<float> mask_data(n_img_size*n_img_size);\n        {\n            const float* data = (float *) state.low_res_masks->data + i*ne0*ne1;\n\n            for (int iy = 0; iy < n_img_size; ++iy) {\n                for (int ix = 0; ix < n_img_size; ++ix) {\n                    const float sx = std::max(scale_x_1*(ix + 0.5f) - 0.5f, 0.0f);\n                    const float sy = std::max(scale_y_1*(iy + 0.5f) - 0.5f, 0.0f);\n\n                    const int x0 = std::max(0, (int)sx);\n                    const int y0 = std::max(0, (int)sy);\n\n                    const int x1 = std::min(x0 + 1, ne0 - 1);\n                    const int y1 = std::min(y0 + 1, ne1 - 1);\n\n                    const float dx = sx - x0;\n                    const float dy = sy - y0;\n\n                    const int j00 = y0*ne0 + x0;\n                    const int j01 = y0*ne0 + x1;\n                    const int j10 = y1*ne0 + x0;\n 
                   const int j11 = y1*ne0 + x1;\n\n                    const float v00 = data[j00];\n                    const float v01 = data[j01];\n                    const float v10 = data[j10];\n                    const float v11 = data[j11];\n\n                    const float v0 = (1-dx)*v00 + dx*v01;\n                    const float v1 = (1-dx)*v10 + dx*v11;\n\n                    const float v = (1-dy)*v0 + dy*v1;\n\n                    mask_data[iy*n_img_size + ix] = v;\n                }\n            }\n        }\n\n        int intersections = 0;\n        int unions = 0;\n        sam_image_u8 res;\n        int min_iy = ny;\n        int max_iy = 0;\n        int min_ix = nx;\n        int max_ix = 0;\n        {\n            const float* data = mask_data.data();\n\n            res.nx = nx;\n            res.ny = ny;\n            res.data.resize(nx*ny);\n\n            for (int iy = 0; iy < ny; ++iy) {\n                for (int ix = 0; ix < nx; ++ix) {\n                    const float sx = std::max(scale_x_2*(ix + 0.5f) - 0.5f, 0.0f);\n                    const float sy = std::max(scale_y_2*(iy + 0.5f) - 0.5f, 0.0f);\n\n                    const int x0 = std::max(0, (int)sx);\n                    const int y0 = std::max(0, (int)sy);\n\n                    const int x1 = std::min(x0 + 1, cropped_nx - 1);\n                    const int y1 = std::min(y0 + 1, cropped_ny - 1);\n\n                    const float dx = sx - x0;\n                    const float dy = sy - y0;\n\n                    const int j00 = y0*n_img_size + x0;\n                    const int j01 = y0*n_img_size + x1;\n                    const int j10 = y1*n_img_size + x0;\n                    const int j11 = y1*n_img_size + x1;\n\n                    const float v00 = data[j00];\n                    const float v01 = data[j01];\n                    const float v10 = data[j10];\n                    const float v11 = data[j11];\n\n                    const float v0 = (1-dx)*v00 + dx*v01;\n          
          const float v1 = (1-dx)*v10 + dx*v11;\n\n                    const float v = (1-dy)*v0 + dy*v1;\n\n                    if (v > intersection_threshold) {\n                        intersections++;\n                    }\n                    if (v > union_threshold) {\n                        unions++;\n                    }\n                    if (v > mask_threshold) {\n                        min_iy = std::min(min_iy, iy);\n                        max_iy = std::max(max_iy, iy);\n                        min_ix = std::min(min_ix, ix);\n                        max_ix = std::max(max_ix, ix);\n\n                        res.data[iy*nx + ix] = 255;\n                    }\n                }\n            }\n        }\n\n        const float stability_score = float(intersections) / float(unions);\n        if (stability_score_threshold > 0.f && stability_score < stability_score_threshold) {\n            printf(\"Skipping mask %d with stability score %f below threshold %f\\n\", i, stability_score, stability_score_threshold);\n            continue; // Filtering masks with stability score below the threshold\n        }\n\n        printf(\"Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\\n\",\n                i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy);\n\n        const std::string filename = multimask_output ? 
fname + std::to_string(i) + \".png\" : fname + \".png\";\n        if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) {\n            printf(\"%s: failed to write mask %s\\n\", __func__, filename.c_str());\n            return false;\n        }\n    }\n\n\n    return true;\n}\n\n\nstruct ggml_cgraph  * sam_build_fast_graph(\n        const sam_model  & model,\n               sam_state & state,\n               const int   nx,\n               const int   ny,\n        const sam_prompt & prompt,\n              const bool   multimask_output) {\n\n    struct ggml_init_params ggml_params = {\n        /*.mem_size   =*/ state.buf_compute_fast.size(),\n        /*.mem_buffer =*/ state.buf_compute_fast.data(),\n        /*.no_alloc   =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements\n    };\n\n    struct ggml_context * ctx0   = ggml_init(ggml_params);\n    struct ggml_cgraph  * gf     = ggml_new_graph(ctx0);\n\n    prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, prompt);\n    if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) {\n        fprintf(stderr, \"%s: failed to encode prompt\\n\", __func__);\n        return {};\n    }\n\n    struct ggml_tensor * pe_img_dense = sam_fill_dense_pe(model, ctx0, gf, state);\n    if (!pe_img_dense) {\n        fprintf(stderr, \"%s: failed to get dense positional encoding\\n\", __func__);\n        return {};\n    }\n\n    if (!sam_decode_mask(model, enc_res, pe_img_dense, ctx0, gf, state, multimask_output)) {\n         fprintf(stderr, \"%s: failed to decode mask\\n\", __func__);\n         return {};\n    }\n\n    ggml_free(ctx0);\n\n    ggml_gallocr_alloc_graph(state.allocr, gf);\n\n    struct ggml_tensor * inp = ggml_graph_get_tensor(gf, \"prompt_input\");\n    auto * data = (float *) inp->data;\n\n    // Transform prompt (point or box)\n    {\n        // 
https://github.com/facebookresearch/segment-anything/blob/dca509fe793f601edb92606367a655c15ac00fdf/segment_anything/utils/transforms.py#L33\n        // The point scaling here is greatly simplified but mathematically equivalent.\n        const auto scale = 1.0F / std::max(nx, ny);\n\n        switch (prompt.prompt_type) {\n        case SAM_PROMPT_TYPE_POINT: {\n            const auto & pt = prompt.pt;\n\n            // set the input by converting the [0, 1] coordinates to [-1, 1]\n            data[0] = 2.0f*pt.x*scale - 1.0f;\n            data[1] = 2.0f*pt.y*scale - 1.0f;\n\n            // padding\n            // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L81-L85\n            data[2] = 2.0f*(0.0f) - 1.0f;\n            data[3] = 2.0f*(0.0f) - 1.0f;\n        } break;\n        case SAM_PROMPT_TYPE_BOX: {\n            const auto & box = prompt.box;\n\n            data[0] = 2.0f*box.x1*scale - 1.0f;\n            data[1] = 2.0f*box.y1*scale - 1.0f;\n            data[2] = 2.0f*box.x2*scale - 1.0f;\n            data[3] = 2.0f*box.y2*scale - 1.0f;\n        } break;\n        }\n    }\n\n    // from sam_fill_dense_pe\n    {\n        struct ggml_tensor * xy_embed_stacked = ggml_graph_get_tensor(gf, \"xy_embed_stacked\");\n        const int32_t n_img_embd = model.hparams.n_img_embd();\n        const float n_img_embd_inv = 1.0f / n_img_embd;\n        float * data = (float *) ggml_get_data(xy_embed_stacked);\n        for (int i = 0; i < n_img_embd; ++i) {\n            const int row = 2*i*n_img_embd;\n            const float y_val = 2 * (i + 0.5f) * n_img_embd_inv - 1;\n            for (int j = 0; j < n_img_embd; ++j) {\n                const float x_val = 2 * (j + 0.5f) * n_img_embd_inv - 1;\n                data[row + 2*j + 0] = x_val;\n                data[row + 2*j + 1] = y_val;\n            }\n        }\n    }\n\n    return gf;\n}\n\nvoid sam_print_usage(int argc, char ** argv, const sam_params & params) {\n   
 fprintf(stderr, \"usage: %s [options]\\n\", argv[0]);\n    fprintf(stderr, \"\\n\");\n    fprintf(stderr, \"options:\\n\");\n    fprintf(stderr, \"  -h, --help            show this help message and exit\\n\");\n    fprintf(stderr, \"  -s SEED, --seed SEED  RNG seed (default: -1)\\n\");\n    fprintf(stderr, \"  -t N, --threads N     number of threads to use during computation (default: %d)\\n\", params.n_threads);\n    fprintf(stderr, \"  -m FNAME, --model FNAME\\n\");\n    fprintf(stderr, \"                        model path (default: %s)\\n\", params.model.c_str());\n    fprintf(stderr, \"  -i FNAME, --inp FNAME\\n\");\n    fprintf(stderr, \"                        input file (default: %s)\\n\", params.fname_inp.c_str());\n    fprintf(stderr, \"  -o FNAME, --out FNAME\\n\");\n    fprintf(stderr, \"                        mask file name prefix (default: %s)\\n\", params.fname_out.c_str());\n    fprintf(stderr, \"  -sm, --single-mask\\n\");\n    fprintf(stderr, \"                        single mask output (default multi mask output)\\n\");\n    fprintf(stderr, \"SAM hyperparameters:\\n\");\n    fprintf(stderr, \"  -mt FLOAT, --mask-threshold\\n\");\n    fprintf(stderr, \"                        mask threshold (default: %f)\\n\", params.mask_threshold);\n    fprintf(stderr, \"  -it FLOAT, --iou-threshold\\n\");\n    fprintf(stderr, \"                        iou threshold (default: %f)\\n\", params.iou_threshold);\n    fprintf(stderr, \"  -st FLOAT, --score-threshold\\n\");\n    fprintf(stderr, \"                        score threshold (default: %f)\\n\", params.stability_score_threshold);\n    fprintf(stderr, \"  -so FLOAT, --score-offset\\n\");\n    fprintf(stderr, \"                        score offset (default: %f)\\n\", params.stability_score_offset);\n    fprintf(stderr, \"  -e FLOAT, --epsilon\\n\");\n    fprintf(stderr, \"                        epsilon (default: %f)\\n\", params.eps);\n    fprintf(stderr, \"  -ed FLOAT, --epsilon-decoder-transformer\\n\");\n 
   fprintf(stderr, \"                        epsilon decoder transformer (default: %f)\\n\", params.eps_decoder_transformer);\n    fprintf(stderr, \"SAM prompt:\\n\");\n    fprintf(stderr, \"  -p [x,y], --point-prompt\\n\");\n    fprintf(stderr, \"                        point to be used as prompt for SAM (default: %f,%f). Must be in a format FLOAT,FLOAT \\n\", params.prompt.pt.x, params.prompt.pt.y);\n    fprintf(stderr, \"  -b [x1,y1,x2,y2], --box-prompt\\n\");\n    fprintf(stderr, \"                        box to be used as prompt for SAM (default: %f,%f,%f,%f). Must be in a format FLOAT,FLOAT,FLOAT,FLOAT \\n\",\n        params.prompt.box.x1, params.prompt.box.y1, params.prompt.box.x2, params.prompt.box.y2);\n    fprintf(stderr, \"\\n\");\n}\n\nbool sam_params_parse(int argc, char ** argv, sam_params & params) {\n\n    bool use_point_prompt = false;\n    bool use_box_prompt = false;\n\n    for (int i = 1; i < argc; i++) {\n        std::string arg = argv[i];\n\n        if (arg == \"-s\" || arg == \"--seed\") {\n            params.seed = std::stoi(argv[++i]);\n        } else if (arg == \"-t\" || arg == \"--threads\") {\n            params.n_threads = std::stoi(argv[++i]);\n        } else if (arg == \"-m\" || arg == \"--model\") {\n            params.model = argv[++i];\n        } else if (arg == \"-i\" || arg == \"--inp\") {\n            params.fname_inp = argv[++i];\n        } else if (arg == \"-o\" || arg == \"--out\") {\n            params.fname_out = argv[++i];\n        } else if (arg == \"-sm\" || arg == \"--single-mask\") {\n            params.multimask_output = false;\n        } else if (arg == \"-mt\" || arg == \"--mask-threshold\") {\n            params.mask_threshold = std::stof(argv[++i]);\n        } else if (arg == \"-it\" || arg == \"--iou-threshold\") {\n            params.iou_threshold = std::stof(argv[++i]);\n        } else if (arg == \"-st\" || arg == \"--score-threshold\") {\n            params.stability_score_threshold = std::stof(argv[++i]);\n   
     } else if (arg == \"-so\" || arg == \"--score-offset\") {\n            params.stability_score_offset = std::stof(argv[++i]);\n        } else if (arg == \"-e\" || arg == \"--epsilon\") {\n            params.eps = std::stof(argv[++i]);\n        } else if (arg == \"-ed\" || arg == \"--epsilon-decoder-transformer\") {\n            params.eps_decoder_transformer = std::stof(argv[++i]);\n        } else if (arg == \"-p\" || arg == \"--point-prompt\") {\n            // TODO multiple points per model invocation\n            use_point_prompt = true;\n            char* point = argv[++i];\n\n            char* coord = strtok(point, \",\");\n            if (!coord){\n                fprintf(stderr, \"Error while parsing prompt!\\n\");\n                exit(1);\n            }\n            params.prompt.pt.x = std::stof(coord);\n\n            coord = strtok(NULL, \",\");\n            if (!coord){\n                fprintf(stderr, \"Error while parsing prompt!\\n\");\n                exit(1);\n            }\n            params.prompt.pt.y = std::stof(coord);\n        } else if (arg == \"-b\" || arg == \"--box-prompt\") {\n            use_box_prompt = true;\n            char * box_prompt = argv[++i];\n            float box_vals[4];\n\n            char * val = strtok(box_prompt, \",\");\n            if (!val) {\n                fprintf(stderr, \"Error while parsing prompt!\\n\");\n                exit(1);\n            }\n            box_vals[0] = std::stof(val);\n\n            for (int j = 1; j < 4; ++j) {\n                char * val = strtok(NULL, \",\");\n                if (!val) {\n                    fprintf(stderr, \"Error while parsing prompt!\\n\");\n                    exit(1);\n                }\n                box_vals[j] = std::stof(val);\n            }\n\n            params.prompt.box.x1 = box_vals[0];\n            params.prompt.box.y1 = box_vals[1];\n            params.prompt.box.x2 = box_vals[2];\n            params.prompt.box.y2 = box_vals[3];\n        } else if 
(arg == \"-h\" || arg == \"--help\") {\n            sam_print_usage(argc, argv, params);\n            exit(0);\n        } else {\n            fprintf(stderr, \"error: unknown argument: %s\\n\", arg.c_str());\n            sam_print_usage(argc, argv, params);\n            exit(0);\n        }\n    }\n\n    if (use_box_prompt && use_point_prompt) {\n        fprintf(stderr, \"Error: use either point or box prompt, not both.\\n\");\n        exit(1);\n    }\n\n    params.prompt.prompt_type = SAM_PROMPT_TYPE_POINT;\n    if (use_box_prompt) {\n        params.prompt.prompt_type = SAM_PROMPT_TYPE_BOX;\n    }\n\n    return true;\n}\n\n\nint main(int argc, char ** argv) {\n    const int64_t t_main_start_us = ggml_time_us();\n\n    sam_params params;\n    params.model = \"models/sam-vit-b/ggml-model-f16.bin\";\n\n    sam_model model;\n    sam_state state;\n    int64_t t_load_us = 0;\n\n    if (sam_params_parse(argc, argv, params) == false) {\n        return 1;\n    }\n\n    if (params.seed < 0) {\n        params.seed = time(NULL);\n    }\n    fprintf(stderr, \"%s: seed = %d\\n\", __func__, params.seed);\n\n    // load the image\n    sam_image_u8 img0;\n    if (!sam_image_load_from_file(params.fname_inp, img0)) {\n        fprintf(stderr, \"%s: failed to load image from '%s'\\n\", __func__, params.fname_inp.c_str());\n        return 1;\n    }\n    fprintf(stderr, \"%s: loaded image '%s' (%d x %d)\\n\", __func__, params.fname_inp.c_str(), img0.nx, img0.ny);\n\n    // preprocess to f32\n    sam_image_f32 img1;\n    if (!sam_image_preprocess(img0, img1)) {\n        fprintf(stderr, \"%s: failed to preprocess image\\n\", __func__);\n        return 1;\n    }\n    fprintf(stderr, \"%s: preprocessed image (%d x %d)\\n\", __func__, img1.nx, img1.ny);\n\n\n    // load the model\n    {\n        const int64_t t_start_us = ggml_time_us();\n\n        if (!sam_model_load(params, model)) {\n            fprintf(stderr, \"%s: failed to load model from '%s'\\n\", __func__, params.model.c_str());\n   
         return 1;\n        }\n\n        t_load_us = ggml_time_us() - t_start_us;\n    }\n\n    {\n        static size_t buf_size = 256u*1024*1024;\n\n        struct ggml_init_params ggml_params = {\n            /*.mem_size   =*/ buf_size,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ false,\n        };\n\n        state.ctx = ggml_init(ggml_params);\n\n        state.embd_img = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,\n                model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);\n\n        state.low_res_masks = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,\n                model.hparams.n_enc_out_chans, model.hparams.n_enc_out_chans, 3);\n\n        state.iou_predictions = ggml_new_tensor_1d(state.ctx, GGML_TYPE_F32, 3);\n    }\n\n    // Encode image\n    {\n        state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());\n        state.allocr = ggml_gallocr_new(ggml_backend_cpu_buffer_type());\n\n        struct ggml_cgraph  * gf = sam_encode_image(model, state, img1);\n        if (!gf) {\n            fprintf(stderr, \"%s: failed to encode image\\n\", __func__);\n            return 1;\n        }\n\n        ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);\n\n        // print_t_f32(\"embd_img\", state.embd_img);\n\n        ggml_gallocr_free(state.allocr);\n        state.allocr = NULL;\n        state.work_buffer.clear();\n    }\n\n    // Encode prompt and decode mask\n    {\n        state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());\n        state.allocr = ggml_gallocr_new(ggml_backend_cpu_buffer_type());\n\n        switch (params.prompt.prompt_type) {\n        case SAM_PROMPT_TYPE_POINT:\n            fprintf(stderr, \"Using point prompt: (%f, %f)\\n\", params.prompt.pt.x, params.prompt.pt.y);\n            break;\n        case SAM_PROMPT_TYPE_BOX:\n            fprintf(stderr, \"Using 
box prompt: (%f, %f, %f, %f)\\n\",\n                params.prompt.box.x1,\n                params.prompt.box.y1,\n                params.prompt.box.x2,\n                params.prompt.box.y2);\n            break;\n        }\n\n        struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.prompt, params.multimask_output);\n        if (!gf) {\n            fprintf(stderr, \"%s: failed to build fast graph\\n\", __func__);\n            return 1;\n        }\n\n        ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);\n\n        //print_t_f32(\"iou_predictions\", state.iou_predictions);\n        //print_t_f32(\"low_res_masks\", state.low_res_masks);\n        ggml_gallocr_free(state.allocr);\n        state.allocr = NULL;\n    }\n\n    if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state, params.fname_out, params.multimask_output)) {\n        fprintf(stderr, \"%s: failed to write masks\\n\", __func__);\n        return 1;\n    }\n\n    // report timing\n    {\n        const int64_t t_main_end_us = ggml_time_us();\n\n        fprintf(stderr, \"\\n\\n\");\n        fprintf(stderr, \"%s:     load time = %8.2f ms\\n\", __func__, t_load_us/1000.0f);\n        fprintf(stderr, \"%s:    total time = %8.2f ms\\n\", __func__, (t_main_end_us - t_main_start_us)/1000.0f);\n    }\n\n    ggml_free(model.ctx);\n\n    return 0;\n}\n"
  },
  {
    "path": "examples/simple/CMakeLists.txt",
    "content": "#\n# simple-ctx\n\nset(TEST_TARGET simple-ctx)\nadd_executable(${TEST_TARGET} simple-ctx.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml)\n\n#\n# simple-backend\n\nset(TEST_TARGET simple-backend)\nadd_executable(${TEST_TARGET} simple-backend.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml)\n\nif (GGML_CUDA)\n    add_compile_definitions(GGML_USE_CUDA)\nendif()\n\nif (GGML_METAL)\n    add_compile_definitions(GGML_USE_METAL)\nendif()\n"
  },
  {
    "path": "examples/simple/README.md",
    "content": "## Simple\n\nThis example simply performs a matrix multiplication, solely for the purpose of demonstrating basic usage of ggml and backend handling. The code is commented to help understand what each part does.\n\nTraditional matrix multiplication goes like this (multiply row-by-column):\n\n$$\nA \\times B = C\n$$\n\n$$\n\\begin{bmatrix}\n2 & 8 \\\\\n5 & 1 \\\\\n4 & 2 \\\\\n8 & 6 \\\\\n\\end{bmatrix}\n\\times\n\\begin{bmatrix}\n10 & 9 & 5 \\\\\n5 & 9 & 4 \\\\\n\\end{bmatrix}\n\\=\n\\begin{bmatrix}\n60 & 90 & 42 \\\\\n55 & 54 & 29 \\\\\n50 &  54 & 28 \\\\\n110 & 126 & 64 \\\\\n\\end{bmatrix}\n$$\n\nIn `ggml`, we pass the matrix $B$ in transposed form and multiply row-by-row. The result $C$ is also transposed:\n\n$$\nggml\\\\_mul\\\\_mat(A, B^T) = C^T\n$$\n\n$$\nggml\\\\_mul\\\\_mat(\n\\begin{bmatrix}\n2 & 8 \\\\\n5 & 1 \\\\\n4 & 2 \\\\\n8 & 6 \\\\\n\\end{bmatrix}\n,\n\\begin{bmatrix}\n10 & 5 \\\\\n9 & 9 \\\\\n5 & 4 \\\\\n\\end{bmatrix}\n)\n\\=\n\\begin{bmatrix}\n60 & 55 & 50 & 110 \\\\\n90 & 54 & 54 & 126 \\\\\n42 & 29 & 28 & 64 \\\\\n\\end{bmatrix}\n$$\n\n`simple-ctx` runs on the CPU only and does not support GPU acceleration; `simple-backend` demonstrates how to use other backends such as CUDA and Metal.\n"
  },
  {
    "path": "examples/simple/simple-backend.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <vector>\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\n// This is a simple model with two tensors a and b\nstruct simple_model {\n    struct ggml_tensor * a {};\n    struct ggml_tensor * b {};\n\n    // the backend to perform the computation (CPU, CUDA, METAL)\n    ggml_backend_t backend {};\n    ggml_backend_t cpu_backend {};\n    ggml_backend_sched_t sched {};\n\n    // storage for the graph and the tensor metadata (tensor data is allocated later by the scheduler)\n    std::vector<uint8_t> buf;\n};\n\n// initialize data of matrices to perform matrix multiplication\nconst int rows_A = 4, cols_A = 2;\n\nfloat matrix_A[rows_A * cols_A] = {\n    2, 8,\n    5, 1,\n    4, 2,\n    8, 6\n};\n\nconst int rows_B = 3, cols_B = 2;\n/* Transpose([\n    10, 9, 5,\n    5, 9, 4\n]) 2 rows, 3 cols */\nfloat matrix_B[rows_B * cols_B] = {\n    10, 5,\n    9, 9,\n    5, 4\n};\n\n\n// initialize the backends (the best available backend plus the CPU backend) and the scheduler;\n// the tensors themselves are created later, in build_graph\nvoid init_model(simple_model & model) {\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    ggml_backend_load_all();\n\n    model.backend = ggml_backend_init_best();\n    model.cpu_backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);\n\n    ggml_backend_t backends[2] = { model.backend, model.cpu_backend };\n    model.sched = ggml_backend_sched_new(backends, nullptr, 2, GGML_DEFAULT_GRAPH_SIZE, false, true);\n}\n\n// build the compute graph to perform a matrix multiplication\nstruct ggml_cgraph * build_graph(simple_model& model) {\n    size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    model.buf.resize(buf_size);\n\n    struct ggml_init_params params0 = {\n        /*.mem_size   =*/ buf_size,\n        
/*.mem_buffer =*/ model.buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later\n    };\n\n    // create a context to build the graph\n    struct ggml_context * ctx = ggml_init(params0);\n\n    struct ggml_cgraph  * gf = ggml_new_graph(ctx);\n\n    // create tensors\n    model.a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols_A, rows_A);\n    model.b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols_B, rows_B);\n\n    // result = a*b^T\n    struct ggml_tensor * result = ggml_mul_mat(ctx, model.a, model.b);\n\n    // build the operation nodes\n    ggml_build_forward_expand(gf, result);\n\n    ggml_free(ctx);\n\n    return gf;\n}\n\n// compute with backend\nstruct ggml_tensor * compute(simple_model & model, struct ggml_cgraph * gf) {\n    ggml_backend_sched_reset(model.sched);\n    ggml_backend_sched_alloc_graph(model.sched, gf);\n\n    // load data from cpu memory to backend buffer\n    ggml_backend_tensor_set(model.a, matrix_A, 0, ggml_nbytes(model.a));\n    ggml_backend_tensor_set(model.b, matrix_B, 0, ggml_nbytes(model.b));\n\n    // compute the graph\n    ggml_backend_sched_graph_compute(model.sched, gf);\n\n    // in this case, the output tensor is the last one in the graph\n    return ggml_graph_node(gf, -1);\n}\n\nint main(void) {\n    ggml_time_init();\n\n    simple_model model;\n    init_model(model);\n\n    struct ggml_cgraph * gf = build_graph(model);\n\n    // perform computation\n    struct ggml_tensor * result = compute(model, gf);\n\n    // create an array to hold the result for printing\n    std::vector<float> out_data(ggml_nelements(result));\n\n    // bring the data from the backend memory\n    ggml_backend_tensor_get(result, out_data.data(), 0, ggml_nbytes(result));\n\n    // expected result:\n    // [ 60.00 55.00 50.00 110.00\n    //   90.00 54.00 54.00 126.00\n    //   42.00 29.00 28.00 64.00 ]\n\n    printf(\"mul mat (%d x %d) (transposed result):\\n[\", (int) result->ne[0], (int) result->ne[1]);\n    for (int j = 0; j < result->ne[1] /* rows 
*/; j++) {\n        if (j > 0) {\n            printf(\"\\n\");\n        }\n\n        for (int i = 0; i < result->ne[0] /* cols */; i++) {\n            printf(\" %.2f\", out_data[j * result->ne[0] + i]);\n        }\n    }\n    printf(\" ]\\n\");\n\n    // release backend memory and free backend\n    ggml_backend_sched_free(model.sched);\n    ggml_backend_free(model.backend);\n    ggml_backend_free(model.cpu_backend);\n    return 0;\n}\n"
  },
  {
    "path": "examples/simple/simple-ctx.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\n// This is a simple model with two tensors a and b\nstruct simple_model {\n    struct ggml_tensor * a;\n    struct ggml_tensor * b;\n\n    // the context to define the tensor information (dimensions, size, memory data)\n    struct ggml_context * ctx;\n};\n\n// initialize the tensors of the model: two F32 matrices, A (rows_A x cols_A) and B (rows_B x cols_B), copied from a and b\nvoid load_model(simple_model & model, float * a, float * b, int rows_A, int cols_A, int rows_B, int cols_B) {\n    size_t ctx_size = 0;\n    {\n        ctx_size += rows_A * cols_A * ggml_type_size(GGML_TYPE_F32); // tensor a\n        ctx_size += rows_B * cols_B * ggml_type_size(GGML_TYPE_F32); // tensor b\n        ctx_size += 2 * ggml_tensor_overhead(), // tensors\n        ctx_size += ggml_graph_overhead(); // compute graph\n        ctx_size += 1024; // some overhead\n    }\n\n    struct ggml_init_params params {\n            /*.mem_size   =*/ ctx_size,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ false, // NOTE: this should be false when using the legacy API\n    };\n\n    // create context\n    model.ctx = ggml_init(params);\n\n    // create tensors\n    model.a = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, cols_A, rows_A);\n    model.b = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, cols_B, rows_B);\n\n    memcpy(model.a->data, a, ggml_nbytes(model.a));\n    memcpy(model.b->data, b, ggml_nbytes(model.b));\n}\n\n// build the compute graph to perform a matrix multiplication\nstruct ggml_cgraph * build_graph(const simple_model& model) {\n    struct ggml_cgraph  * gf = ggml_new_graph(model.ctx);\n\n    // result = a*b^T\n    struct ggml_tensor * result = ggml_mul_mat(model.ctx, model.a, model.b);\n\n    ggml_build_forward_expand(gf, result);\n    return gf;\n}\n\n// compute the graph on the CPU, using the model's context (no backend involved)\nstruct ggml_tensor * 
compute(const simple_model & model) {\n    struct ggml_cgraph * gf = build_graph(model);\n\n    int n_threads = 1; // number of threads to perform some operations with multi-threading\n\n    ggml_graph_compute_with_ctx(model.ctx, gf, n_threads);\n\n    // in this case, the output tensor is the last one in the graph\n    return ggml_graph_node(gf, -1);\n}\n\nint main(void) {\n    ggml_time_init();\n\n    // initialize data of matrices to perform matrix multiplication\n    const int rows_A = 4, cols_A = 2;\n\n    float matrix_A[rows_A * cols_A] = {\n        2, 8,\n        5, 1,\n        4, 2,\n        8, 6\n    };\n\n    const int rows_B = 3, cols_B = 2;\n    /* Transpose([\n        10, 9, 5,\n        5, 9, 4\n    ]) 2 rows, 3 cols */\n    float matrix_B[rows_B * cols_B] = {\n        10, 5,\n        9, 9,\n        5, 4\n    };\n\n    simple_model model;\n    load_model(model, matrix_A, matrix_B, rows_A, cols_A, rows_B, cols_B);\n\n    // perform computation in cpu\n    struct ggml_tensor * result = compute(model);\n\n    // get the result data pointer as a float array to print\n    std::vector<float> out_data(ggml_nelements(result));\n    memcpy(out_data.data(), result->data, ggml_nbytes(result));\n\n    // expected result:\n    // [ 60.00 55.00 50.00 110.00\n    //   90.00 54.00 54.00 126.00\n    //   42.00 29.00 28.00 64.00 ]\n\n    printf(\"mul mat (%d x %d) (transposed result):\\n[\", (int) result->ne[0], (int) result->ne[1]);\n    for (int j = 0; j < result->ne[1] /* rows */; j++) {\n        if (j > 0) {\n            printf(\"\\n\");\n        }\n\n        for (int i = 0; i < result->ne[0] /* cols */; i++) {\n            printf(\" %.2f\", out_data[j * result->ne[0] + i]);\n        }\n    }\n    printf(\" ]\\n\");\n\n    // free memory\n    ggml_free(model.ctx);\n    return 0;\n}\n"
  },
  {
    "path": "examples/stb_image.h",
    "content": "/* stb_image - v2.28 - public domain image loader - http://nothings.org/stb\n                                  no warranty implied; use at your own risk\n\n   Do this:\n      #define STB_IMAGE_IMPLEMENTATION\n   before you include this file in *one* C or C++ file to create the implementation.\n\n   // i.e. it should look like this:\n   #include ...\n   #include ...\n   #include ...\n   #define STB_IMAGE_IMPLEMENTATION\n   #include \"stb_image.h\"\n\n   You can #define STBI_ASSERT(x) before the #include to avoid using assert.h.\n   And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free\n\n\n   QUICK NOTES:\n      Primarily of interest to game developers and other people who can\n          avoid problematic images and only need the trivial interface\n\n      JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib)\n      PNG 1/2/4/8/16-bit-per-channel\n\n      TGA (not sure what subset, if a subset)\n      BMP non-1bpp, non-RLE\n      PSD (composited view only, no extra channels, 8/16 bit-per-channel)\n\n      GIF (*comp always reports as 4-channel)\n      HDR (radiance rgbE format)\n      PIC (Softimage PIC)\n      PNM (PPM and PGM binary only)\n\n      Animated GIF still needs a proper API, but here's one way to do it:\n          http://gist.github.com/urraka/685d9a6340b26b830d49\n\n      - decode from memory or through FILE (define STBI_NO_STDIO to remove code)\n      - decode from arbitrary I/O callbacks\n      - SIMD acceleration on x86/x64 (SSE2) and ARM (NEON)\n\n   Full documentation under \"DOCUMENTATION\" below.\n\n\nLICENSE\n\n  See end of file for license information.\n\nRECENT REVISION HISTORY:\n\n      2.28  (2023-01-29) many error fixes, security errors, just tons of stuff\n      2.27  (2021-07-11) document stbi_info better, 16-bit PNM support, bug fixes\n      2.26  (2020-07-13) many minor fixes\n      2.25  (2020-02-02) fix warnings\n      2.24  (2020-02-02) fix warnings; 
thread-local failure_reason and flip_vertically\n      2.23  (2019-08-11) fix clang static analysis warning\n      2.22  (2019-03-04) gif fixes, fix warnings\n      2.21  (2019-02-25) fix typo in comment\n      2.20  (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs\n      2.19  (2018-02-11) fix warning\n      2.18  (2018-01-30) fix warnings\n      2.17  (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings\n      2.16  (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes\n      2.15  (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC\n      2.14  (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs\n      2.13  (2016-12-04) experimental 16-bit API, only for PNG so far; fixes\n      2.12  (2016-04-02) fix typo in 2.11 PSD fix that caused crashes\n      2.11  (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64\n                         RGB-format JPEG; remove white matting in PSD;\n                         allocate large structures on the stack;\n                         correct channel count for PNG & BMP\n      2.10  (2016-01-22) avoid warning introduced in 2.09\n      2.09  (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED\n\n   See end of file for full revision history.\n\n\n ============================    Contributors    =========================\n\n Image formats                          Extensions, features\n    Sean Barrett (jpeg, png, bmp)          Jetro Lauha (stbi_info)\n    Nicolas Schulz (hdr, psd)              Martin \"SpartanJ\" Golini (stbi_info)\n    Jonathan Dummer (tga)                  James \"moose2000\" Brown (iPhone PNG)\n    Jean-Marc Lienher (gif)                Ben \"Disch\" Wenger (io callbacks)\n    Tom Seddon (pic)                       Omar Cornut (1/2/4-bit PNG)\n    Thatcher Ulrich (psd)                  Nicolas Guillemot (vertical flip)\n    Ken Miller (pgm, ppm)                  Richard Mitton (16-bit PSD)\n    
github:urraka (animated gif)           Junggon Kim (PNM comments)\n    Christopher Forseth (animated gif)     Daniel Gibson (16-bit TGA)\n                                           socks-the-fox (16-bit PNG)\n                                           Jeremy Sawicki (handle all ImageNet JPGs)\n Optimizations & bugfixes                  Mikhail Morozov (1-bit BMP)\n    Fabian \"ryg\" Giesen                    Anael Seghezzi (is-16-bit query)\n    Arseny Kapoulkine                      Simon Breuss (16-bit PNM)\n    John-Mark Allen\n    Carmelo J Fdez-Aguera\n\n Bug & warning fixes\n    Marc LeBlanc            David Woo          Guillaume George     Martins Mozeiko\n    Christpher Lloyd        Jerry Jansson      Joseph Thomson       Blazej Dariusz Roszkowski\n    Phil Jordan                                Dave Moore           Roy Eltham\n    Hayaki Saito            Nathan Reed        Won Chun\n    Luke Graham             Johan Duparc       Nick Verigakis       the Horde3D community\n    Thomas Ruf              Ronny Chevalier                         github:rlyeh\n    Janez Zemva             John Bartholomew   Michal Cichon        github:romigrou\n    Jonathan Blow           Ken Hamada         Tero Hanninen        github:svdijk\n    Eugene Golushkov        Laurent Gomila     Cort Stratton        github:snagar\n    Aruelien Pocheville     Sergio Gonzalez    Thibault Reuille     github:Zelex\n    Cass Everitt            Ryamond Barbiero                        github:grim210\n    Paul Du Bois            Engin Manap        Aldo Culquicondor    github:sammyhw\n    Philipp Wiesemann       Dale Weiler        Oriol Ferrer Mesia   github:phprus\n    Josh Tobin              Neil Bickford      Matthew Gregan       github:poppolopoppo\n    Julian Raschke          Gregory Mullen     Christian Floisand   github:darealshinji\n    Baldur Karlsson         Kevin Schmidt      JR Smith             github:Michaelangel007\n                            Brad Weinberger    Matvey Cherevko      
github:mosra\n    Luca Sas                Alexander Veselov  Zack Middleton       [reserved]\n    Ryan C. Gordon          [reserved]                              [reserved]\n                     DO NOT ADD YOUR NAME HERE\n\n                     Jacko Dirks\n\n  To add your name to the credits, pick a random blank space in the middle and fill it.\n  80% of merge conflicts on stb PRs are due to people adding their name at the end\n  of the credits.\n*/\n\n#ifndef STBI_INCLUDE_STB_IMAGE_H\n#define STBI_INCLUDE_STB_IMAGE_H\n\n// DOCUMENTATION\n//\n// Limitations:\n//    - no 12-bit-per-channel JPEG\n//    - no JPEGs with arithmetic coding\n//    - GIF always returns *comp=4\n//\n// Basic usage (see HDR discussion below for HDR usage):\n//    int x,y,n;\n//    unsigned char *data = stbi_load(filename, &x, &y, &n, 0);\n//    // ... process data if not NULL ...\n//    // ... x = width, y = height, n = # 8-bit components per pixel ...\n//    // ... replace '0' with '1'..'4' to force that many components per pixel\n//    // ... but 'n' will always be the number that it would have been if you said 0\n//    stbi_image_free(data);\n//\n// Standard parameters:\n//    int *x                 -- outputs image width in pixels\n//    int *y                 -- outputs image height in pixels\n//    int *channels_in_file  -- outputs # of image components in image file\n//    int desired_channels   -- if non-zero, # of image components requested in result\n//\n// The return value from an image loader is an 'unsigned char *' which points\n// to the pixel data, or NULL on an allocation failure or if the image is\n// corrupt or invalid. The pixel data consists of *y scanlines of *x pixels,\n// with each pixel consisting of N interleaved 8-bit components; the first\n// pixel pointed to is top-left-most in the image. There is no padding between\n// image scanlines or between pixels, regardless of format. 
The number of\n// components N is 'desired_channels' if desired_channels is non-zero, or\n// *channels_in_file otherwise. If desired_channels is non-zero,\n// *channels_in_file has the number of components that _would_ have been\n// output otherwise. E.g. if you set desired_channels to 4, you will always\n// get RGBA output, but you can check *channels_in_file to see if it's trivially\n// opaque because e.g. there were only 3 channels in the source image.\n//\n// An output image with N components has the following components interleaved\n// in this order in each pixel:\n//\n//     N=#comp     components\n//       1           grey\n//       2           grey, alpha\n//       3           red, green, blue\n//       4           red, green, blue, alpha\n//\n// If image loading fails for any reason, the return value will be NULL,\n// and *x, *y, *channels_in_file will be unchanged. The function\n// stbi_failure_reason() can be queried for an extremely brief, end-user\n// unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS\n// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly\n// more user-friendly ones.\n//\n// Paletted PNG, BMP, GIF, and PIC images are automatically depalettized.\n//\n// To query the width, height and component count of an image without having to\n// decode the full file, you can use the stbi_info family of functions:\n//\n//   int x,y,n,ok;\n//   ok = stbi_info(filename, &x, &y, &n);\n//   // returns ok=1 and sets x, y, n if image is a supported format,\n//   // 0 otherwise.\n//\n// Note that stb_image pervasively uses ints in its public API for sizes,\n// including sizes of memory buffers. This is now part of the API and thus\n// hard to change without causing breakage. As a result, the various image\n// loaders all have certain limits on image size; these differ somewhat\n// by format but generally boil down to either just under 2GB or just under\n// 1GB. 
When the decoded image would be larger than this, stb_image decoding\n// will fail.\n//\n// Additionally, stb_image will reject image files that have any of their\n// dimensions set to a larger value than the configurable STBI_MAX_DIMENSIONS,\n// which defaults to 2**24 = 16777216 pixels. Due to the above memory limit,\n// the only way to have an image with such dimensions load correctly\n// is for it to have a rather extreme aspect ratio. Either way, the\n// assumption here is that such larger images are likely to be malformed\n// or malicious. If you do need to load an image with individual dimensions\n// larger than that, and it still fits in the overall size limit, you can\n// #define STBI_MAX_DIMENSIONS on your own to be something larger.\n//\n// ===========================================================================\n//\n// UNICODE:\n//\n//   If compiling for Windows and you wish to use Unicode filenames, compile\n//   with\n//       #define STBI_WINDOWS_UTF8\n//   and pass utf8-encoded filenames. Call stbi_convert_wchar_to_utf8 to convert\n//   Windows wchar_t filenames to utf8.\n//\n// ===========================================================================\n//\n// Philosophy\n//\n// stb libraries are designed with the following priorities:\n//\n//    1. easy to use\n//    2. easy to maintain\n//    3. good performance\n//\n// Sometimes I let \"good performance\" creep up in priority over \"easy to maintain\",\n// and for best performance I may provide less-easy-to-use APIs that give higher\n// performance, in addition to the easy-to-use ones. 
Nevertheless, it's important\n// to keep in mind that from the standpoint of you, a client of this library,\n// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all.\n//\n// Some secondary priorities arise directly from the first two, some of which\n// provide more explicit reasons why performance can't be emphasized.\n//\n//    - Portable (\"ease of use\")\n//    - Small source code footprint (\"easy to maintain\")\n//    - No dependencies (\"ease of use\")\n//\n// ===========================================================================\n//\n// I/O callbacks\n//\n// I/O callbacks allow you to read from arbitrary sources, like packaged\n// files or some other source. Data read from callbacks are processed\n// through a small internal buffer (currently 128 bytes) to try to reduce\n// overhead.\n//\n// The three functions you must define are \"read\" (reads some bytes of data),\n// \"skip\" (skips some bytes of data), \"eof\" (reports if the stream is at the end).\n//\n// ===========================================================================\n//\n// SIMD support\n//\n// The JPEG decoder will try to automatically use SIMD kernels on x86 when\n// supported by the compiler. For ARM Neon support, you must explicitly\n// request it.\n//\n// (The old do-it-yourself SIMD API is no longer supported in the current\n// code.)\n//\n// On x86, SSE2 will automatically be used when available based on a run-time\n// test; if not, the generic C versions are used as a fall-back. On ARM targets,\n// the typical path is to have separate builds for NEON and non-NEON devices\n// (at least this is true for iOS and Android). 
Therefore, the NEON support is\n// toggled by a build flag: define STBI_NEON to get NEON loops.\n//\n// If for some reason you do not want to use any of SIMD code, or if\n// you have issues compiling it, you can disable it entirely by\n// defining STBI_NO_SIMD.\n//\n// ===========================================================================\n//\n// HDR image support   (disable by defining STBI_NO_HDR)\n//\n// stb_image supports loading HDR images in general, and currently the Radiance\n// .HDR file format specifically. You can still load any file through the existing\n// interface; if you attempt to load an HDR file, it will be automatically remapped\n// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1;\n// both of these constants can be reconfigured through this interface:\n//\n//     stbi_hdr_to_ldr_gamma(2.2f);\n//     stbi_hdr_to_ldr_scale(1.0f);\n//\n// (note, do not use _inverse_ constants; stbi_image will invert them\n// appropriately).\n//\n// Additionally, there is a new, parallel interface for loading files as\n// (linear) floats to preserve the full dynamic range:\n//\n//    float *data = stbi_loadf(filename, &x, &y, &n, 0);\n//\n// If you load LDR images through this interface, those images will\n// be promoted to floating point values, run through the inverse of\n// constants corresponding to the above:\n//\n//     stbi_ldr_to_hdr_scale(1.0f);\n//     stbi_ldr_to_hdr_gamma(2.2f);\n//\n// Finally, given a filename (or an open file or memory block--see header\n// file for details) containing image data, you can query for the \"most\n// appropriate\" interface to use (that is, whether the image is HDR or\n// not), using:\n//\n//     stbi_is_hdr(char *filename);\n//\n// ===========================================================================\n//\n// iPhone PNG support:\n//\n// We optionally support converting iPhone-formatted PNGs (which store\n// premultiplied BGRA) back to RGB, even though they're internally encoded\n// 
differently. To enable this conversion, call\n// stbi_convert_iphone_png_to_rgb(1).\n//\n// Call stbi_set_unpremultiply_on_load(1) as well to force a divide per\n// pixel to remove any premultiplied alpha *only* if the image file explicitly\n// says there's premultiplied data (currently only happens in iPhone images,\n// and only if iPhone convert-to-rgb processing is on).\n//\n// ===========================================================================\n//\n// ADDITIONAL CONFIGURATION\n//\n//  - You can suppress implementation of any of the decoders to reduce\n//    your code footprint by #defining one or more of the following\n//    symbols before creating the implementation.\n//\n//        STBI_NO_JPEG\n//        STBI_NO_PNG\n//        STBI_NO_BMP\n//        STBI_NO_PSD\n//        STBI_NO_TGA\n//        STBI_NO_GIF\n//        STBI_NO_HDR\n//        STBI_NO_PIC\n//        STBI_NO_PNM   (.ppm and .pgm)\n//\n//  - You can request *only* certain decoders and suppress all other ones\n//    (this will be more forward-compatible, as addition of new decoders\n//    doesn't require you to disable them explicitly):\n//\n//        STBI_ONLY_JPEG\n//        STBI_ONLY_PNG\n//        STBI_ONLY_BMP\n//        STBI_ONLY_PSD\n//        STBI_ONLY_TGA\n//        STBI_ONLY_GIF\n//        STBI_ONLY_HDR\n//        STBI_ONLY_PIC\n//        STBI_ONLY_PNM   (.ppm and .pgm)\n//\n//   - If you use STBI_NO_PNG (or _ONLY_ without PNG), and you still\n//     want the zlib decoder to be available, #define STBI_SUPPORT_ZLIB\n//\n//  - If you define STBI_MAX_DIMENSIONS, stb_image will reject images greater\n//    than that size (in either width or height) without further processing.\n//    This is to let programs in the wild set an upper bound to prevent\n//    denial-of-service attacks on untrusted data, as one could generate a\n//    valid image of gigantic dimensions and force stb_image to allocate a\n//    huge block of memory and spend disproportionate time decoding it. 
By\n//    default this is set to (1 << 24), which is 16777216, but that's still\n//    very big.\n\n#ifndef STBI_NO_STDIO\n#include <stdio.h>\n#endif // STBI_NO_STDIO\n\n#define STBI_VERSION 1\n\nenum\n{\n   STBI_default = 0, // only used for desired_channels\n\n   STBI_grey       = 1,\n   STBI_grey_alpha = 2,\n   STBI_rgb        = 3,\n   STBI_rgb_alpha  = 4\n};\n\n#include <stdlib.h>\ntypedef unsigned char stbi_uc;\ntypedef unsigned short stbi_us;\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n#ifndef STBIDEF\n#ifdef STB_IMAGE_STATIC\n#define STBIDEF static\n#else\n#define STBIDEF extern\n#endif\n#endif\n\n//////////////////////////////////////////////////////////////////////////////\n//\n// PRIMARY API - works on images of any type\n//\n\n//\n// load image by filename, open file, or memory buffer\n//\n\ntypedef struct\n{\n   int      (*read)  (void *user,char *data,int size);   // fill 'data' with 'size' bytes.  return number of bytes actually read\n   void     (*skip)  (void *user,int n);                 // skip the next 'n' bytes, or 'unget' the last -n bytes if negative\n   int      (*eof)   (void *user);                       // returns nonzero if we are at end of file/data\n} stbi_io_callbacks;\n\n////////////////////////////////////\n//\n// 8-bits-per-channel interface\n//\n\nSTBIDEF stbi_uc *stbi_load_from_memory   (stbi_uc           const *buffer, int len   , int *x, int *y, int *channels_in_file, int desired_channels);\nSTBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk  , void *user, int *x, int *y, int *channels_in_file, int desired_channels);\n\n#ifndef STBI_NO_STDIO\nSTBIDEF stbi_uc *stbi_load            (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels);\nSTBIDEF stbi_uc *stbi_load_from_file  (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels);\n// for stbi_load_from_file, file pointer is left pointing immediately after image\n#endif\n\n#ifndef STBI_NO_GIF\nSTBIDEF stbi_uc 
*stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp);\n#endif\n\n#ifdef STBI_WINDOWS_UTF8\nSTBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);\n#endif\n\n////////////////////////////////////\n//\n// 16-bits-per-channel interface\n//\n\nSTBIDEF stbi_us *stbi_load_16_from_memory   (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels);\nSTBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels);\n\n#ifndef STBI_NO_STDIO\nSTBIDEF stbi_us *stbi_load_16          (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels);\nSTBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels);\n#endif\n\n////////////////////////////////////\n//\n// float-per-channel interface\n//\n#ifndef STBI_NO_LINEAR\n   STBIDEF float *stbi_loadf_from_memory     (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels);\n   STBIDEF float *stbi_loadf_from_callbacks  (stbi_io_callbacks const *clbk, void *user, int *x, int *y,  int *channels_in_file, int desired_channels);\n\n   #ifndef STBI_NO_STDIO\n   STBIDEF float *stbi_loadf            (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels);\n   STBIDEF float *stbi_loadf_from_file  (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels);\n   #endif\n#endif\n\n#ifndef STBI_NO_HDR\n   STBIDEF void   stbi_hdr_to_ldr_gamma(float gamma);\n   STBIDEF void   stbi_hdr_to_ldr_scale(float scale);\n#endif // STBI_NO_HDR\n\n#ifndef STBI_NO_LINEAR\n   STBIDEF void   stbi_ldr_to_hdr_gamma(float gamma);\n   STBIDEF void   stbi_ldr_to_hdr_scale(float scale);\n#endif // STBI_NO_LINEAR\n\n// stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR\nSTBIDEF int  
  stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user);\nSTBIDEF int    stbi_is_hdr_from_memory(stbi_uc const *buffer, int len);\n#ifndef STBI_NO_STDIO\nSTBIDEF int      stbi_is_hdr          (char const *filename);\nSTBIDEF int      stbi_is_hdr_from_file(FILE *f);\n#endif // STBI_NO_STDIO\n\n\n// get a VERY brief reason for failure\n// on most compilers (and ALL modern mainstream compilers) this is threadsafe\nSTBIDEF const char *stbi_failure_reason  (void);\n\n// free the loaded image -- this is just free()\nSTBIDEF void     stbi_image_free      (void *retval_from_stbi_load);\n\n// get image dimensions & components without fully decoding\nSTBIDEF int      stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp);\nSTBIDEF int      stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp);\nSTBIDEF int      stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len);\nSTBIDEF int      stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user);\n\n#ifndef STBI_NO_STDIO\nSTBIDEF int      stbi_info               (char const *filename,     int *x, int *y, int *comp);\nSTBIDEF int      stbi_info_from_file     (FILE *f,                  int *x, int *y, int *comp);\nSTBIDEF int      stbi_is_16_bit          (char const *filename);\nSTBIDEF int      stbi_is_16_bit_from_file(FILE *f);\n#endif\n\n\n\n// for image formats that explicitly notate that they have premultiplied alpha,\n// we just return the colors as stored in the file. set this flag to force\n// unpremultiplication. 
results are undefined if the unpremultiply overflow.\nSTBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply);\n\n// indicate whether we should process iphone images back to canonical format,\n// or just pass them through \"as-is\"\nSTBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert);\n\n// flip the image vertically, so the first pixel in the output array is the bottom left\nSTBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip);\n\n// as above, but only applies to images loaded on the thread that calls the function\n// this function is only available if your compiler supports thread-local variables;\n// calling it will fail to link if your compiler doesn't\nSTBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply);\nSTBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert);\nSTBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip);\n\n// ZLIB client - used by PNG, available for other purposes\n\nSTBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen);\nSTBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header);\nSTBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen);\nSTBIDEF int   stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen);\n\nSTBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen);\nSTBIDEF int   stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen);\n\n\n#ifdef __cplusplus\n}\n#endif\n\n//\n//\n////   end header file   /////////////////////////////////////////////////////\n#endif // STBI_INCLUDE_STB_IMAGE_H\n\n#ifdef STB_IMAGE_IMPLEMENTATION\n\n#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \\\n  || 
defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \\\n  || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \\\n  || defined(STBI_ONLY_ZLIB)\n   #ifndef STBI_ONLY_JPEG\n   #define STBI_NO_JPEG\n   #endif\n   #ifndef STBI_ONLY_PNG\n   #define STBI_NO_PNG\n   #endif\n   #ifndef STBI_ONLY_BMP\n   #define STBI_NO_BMP\n   #endif\n   #ifndef STBI_ONLY_PSD\n   #define STBI_NO_PSD\n   #endif\n   #ifndef STBI_ONLY_TGA\n   #define STBI_NO_TGA\n   #endif\n   #ifndef STBI_ONLY_GIF\n   #define STBI_NO_GIF\n   #endif\n   #ifndef STBI_ONLY_HDR\n   #define STBI_NO_HDR\n   #endif\n   #ifndef STBI_ONLY_PIC\n   #define STBI_NO_PIC\n   #endif\n   #ifndef STBI_ONLY_PNM\n   #define STBI_NO_PNM\n   #endif\n#endif\n\n#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB)\n#define STBI_NO_ZLIB\n#endif\n\n\n#include <stdarg.h>\n#include <stddef.h> // ptrdiff_t on osx\n#include <stdlib.h>\n#include <string.h>\n#include <limits.h>\n\n#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR)\n#include <math.h>  // ldexp, pow\n#endif\n\n#ifndef STBI_NO_STDIO\n#include <stdio.h>\n#endif\n\n#ifndef STBI_ASSERT\n#include <assert.h>\n#define STBI_ASSERT(x) assert(x)\n#endif\n\n#ifdef __cplusplus\n#define STBI_EXTERN extern \"C\"\n#else\n#define STBI_EXTERN extern\n#endif\n\n\n#ifndef _MSC_VER\n   #ifdef __cplusplus\n   #define stbi_inline inline\n   #else\n   #define stbi_inline\n   #endif\n#else\n   #define stbi_inline __forceinline\n#endif\n\n#ifndef STBI_NO_THREAD_LOCALS\n   #if defined(__cplusplus) &&  __cplusplus >= 201103L\n      #define STBI_THREAD_LOCAL       thread_local\n   #elif defined(__GNUC__) && __GNUC__ < 5\n      #define STBI_THREAD_LOCAL       __thread\n   #elif defined(_MSC_VER)\n      #define STBI_THREAD_LOCAL       __declspec(thread)\n   #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__)\n      #define STBI_THREAD_LOCAL       _Thread_local\n   #endif\n\n   
#ifndef STBI_THREAD_LOCAL\n      #if defined(__GNUC__)\n        #define STBI_THREAD_LOCAL       __thread\n      #endif\n   #endif\n#endif\n\n#if defined(_MSC_VER) || defined(__SYMBIAN32__)\ntypedef unsigned short stbi__uint16;\ntypedef   signed short stbi__int16;\ntypedef unsigned int   stbi__uint32;\ntypedef   signed int   stbi__int32;\n#else\n#include <stdint.h>\ntypedef uint16_t stbi__uint16;\ntypedef int16_t  stbi__int16;\ntypedef uint32_t stbi__uint32;\ntypedef int32_t  stbi__int32;\n#endif\n\n// should produce compiler error if size is wrong\ntypedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1];\n\n#ifdef _MSC_VER\n#define STBI_NOTUSED(v)  (void)(v)\n#else\n#define STBI_NOTUSED(v)  (void)sizeof(v)\n#endif\n\n#ifdef _MSC_VER\n#define STBI_HAS_LROTL\n#endif\n\n#ifdef STBI_HAS_LROTL\n   #define stbi_lrot(x,y)  _lrotl(x,y)\n#else\n   #define stbi_lrot(x,y)  (((x) << (y)) | ((x) >> (-(y) & 31)))\n#endif\n\n#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED))\n// ok\n#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED)\n// ok\n#else\n#error \"Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED).\"\n#endif\n\n#ifndef STBI_MALLOC\n#define STBI_MALLOC(sz)           malloc(sz)\n#define STBI_REALLOC(p,newsz)     realloc(p,newsz)\n#define STBI_FREE(p)              free(p)\n#endif\n\n#ifndef STBI_REALLOC_SIZED\n#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz)\n#endif\n\n// x86/x64 detection\n#if defined(__x86_64__) || defined(_M_X64)\n#define STBI__X64_TARGET\n#elif defined(__i386) || defined(_M_IX86)\n#define STBI__X86_TARGET\n#endif\n\n#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD)\n// gcc doesn't support sse2 intrinsics unless you compile with -msse2,\n// which in turn means it gets to use SSE2 everywhere. 
This is unfortunate,\n// but previous attempts to provide the SSE2 functions with runtime\n// detection caused numerous issues. The way architecture extensions are\n// exposed in GCC/Clang is, sadly, not really suited for one-file libs.\n// New behavior: if compiled with -msse2, we use SSE2 without any\n// detection; if not, we don't use it at all.\n#define STBI_NO_SIMD\n#endif\n\n#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD)\n// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET\n//\n// 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the\n// Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant.\n// As a result, enabling SSE2 on 32-bit MinGW is dangerous when not\n// simultaneously enabling \"-mstackrealign\".\n//\n// See https://github.com/nothings/stb/issues/81 for more information.\n//\n// So default to no SSE2 on 32-bit MinGW. If you've read this far and added\n// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2.\n#define STBI_NO_SIMD\n#endif\n\n#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET))\n#define STBI_SSE2\n#include <emmintrin.h>\n\n#ifdef _MSC_VER\n\n#if _MSC_VER >= 1400  // not VC6\n#include <intrin.h> // __cpuid\nstatic int stbi__cpuid3(void)\n{\n   int info[4];\n   __cpuid(info,1);\n   return info[3];\n}\n#else\nstatic int stbi__cpuid3(void)\n{\n   int res;\n   __asm {\n      mov  eax,1\n      cpuid\n      mov  res,edx\n   }\n   return res;\n}\n#endif\n\n#define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name\n\n#if !defined(STBI_NO_JPEG) && defined(STBI_SSE2)\nstatic int stbi__sse2_available(void)\n{\n   int info3 = stbi__cpuid3();\n   return ((info3 >> 26) & 1) != 0;\n}\n#endif\n\n#else // assume GCC-style if not VC++\n#define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16)))\n\n#if !defined(STBI_NO_JPEG) 
&& defined(STBI_SSE2)\nstatic int stbi__sse2_available(void)\n{\n   // If we're even attempting to compile this on GCC/Clang, that means\n   // -msse2 is on, which means the compiler is allowed to use SSE2\n   // instructions at will, and so are we.\n   return 1;\n}\n#endif\n\n#endif\n#endif\n\n// ARM NEON\n#if defined(STBI_NO_SIMD) && defined(STBI_NEON)\n#undef STBI_NEON\n#endif\n\n#ifdef STBI_NEON\n#include <arm_neon.h>\n#ifdef _MSC_VER\n#define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name\n#else\n#define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16)))\n#endif\n#endif\n\n#ifndef STBI_SIMD_ALIGN\n#define STBI_SIMD_ALIGN(type, name) type name\n#endif\n\n#ifndef STBI_MAX_DIMENSIONS\n#define STBI_MAX_DIMENSIONS (1 << 24)\n#endif\n\n///////////////////////////////////////////////\n//\n//  stbi__context struct and start_xxx functions\n\n// stbi__context structure is our basic context used by all images, so it\n// contains all the IO context, plus some basic image information\ntypedef struct\n{\n   stbi__uint32 img_x, img_y;\n   int img_n, img_out_n;\n\n   stbi_io_callbacks io;\n   void *io_user_data;\n\n   int read_from_callbacks;\n   int buflen;\n   stbi_uc buffer_start[128];\n   int callback_already_read;\n\n   stbi_uc *img_buffer, *img_buffer_end;\n   stbi_uc *img_buffer_original, *img_buffer_original_end;\n} stbi__context;\n\n\nstatic void stbi__refill_buffer(stbi__context *s);\n\n// initialize a memory-decode context\nstatic void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len)\n{\n   s->io.read = NULL;\n   s->read_from_callbacks = 0;\n   s->callback_already_read = 0;\n   s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer;\n   s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len;\n}\n\n// initialize a callback-based context\nstatic void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user)\n{\n   s->io = *c;\n   s->io_user_data = user;\n   s->buflen = 
sizeof(s->buffer_start);\n   s->read_from_callbacks = 1;\n   s->callback_already_read = 0;\n   s->img_buffer = s->img_buffer_original = s->buffer_start;\n   stbi__refill_buffer(s);\n   s->img_buffer_original_end = s->img_buffer_end;\n}\n\n#ifndef STBI_NO_STDIO\n\nstatic int stbi__stdio_read(void *user, char *data, int size)\n{\n   return (int) fread(data,1,size,(FILE*) user);\n}\n\nstatic void stbi__stdio_skip(void *user, int n)\n{\n   int ch;\n   fseek((FILE*) user, n, SEEK_CUR);\n   ch = fgetc((FILE*) user);  /* have to read a byte to reset feof()'s flag */\n   if (ch != EOF) {\n      ungetc(ch, (FILE *) user);  /* push byte back onto stream if valid. */\n   }\n}\n\nstatic int stbi__stdio_eof(void *user)\n{\n   return feof((FILE*) user) || ferror((FILE *) user);\n}\n\nstatic stbi_io_callbacks stbi__stdio_callbacks =\n{\n   stbi__stdio_read,\n   stbi__stdio_skip,\n   stbi__stdio_eof,\n};\n\nstatic void stbi__start_file(stbi__context *s, FILE *f)\n{\n   stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f);\n}\n\n//static void stop_file(stbi__context *s) { }\n\n#endif // !STBI_NO_STDIO\n\nstatic void stbi__rewind(stbi__context *s)\n{\n   // conceptually rewind SHOULD rewind to the beginning of the stream,\n   // but we just rewind to the beginning of the initial buffer, because\n   // we only use it after doing 'test', which only ever looks at at most 92 bytes\n   s->img_buffer = s->img_buffer_original;\n   s->img_buffer_end = s->img_buffer_original_end;\n}\n\nenum\n{\n   STBI_ORDER_RGB,\n   STBI_ORDER_BGR\n};\n\ntypedef struct\n{\n   int bits_per_channel;\n   int num_channels;\n   int channel_order;\n} stbi__result_info;\n\n#ifndef STBI_NO_JPEG\nstatic int      stbi__jpeg_test(stbi__context *s);\nstatic void    *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);\nstatic int      stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp);\n#endif\n\n#ifndef STBI_NO_PNG\nstatic int      
stbi__png_test(stbi__context *s);\nstatic void    *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);\nstatic int      stbi__png_info(stbi__context *s, int *x, int *y, int *comp);\nstatic int      stbi__png_is16(stbi__context *s);\n#endif\n\n#ifndef STBI_NO_BMP\nstatic int      stbi__bmp_test(stbi__context *s);\nstatic void    *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);\nstatic int      stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp);\n#endif\n\n#ifndef STBI_NO_TGA\nstatic int      stbi__tga_test(stbi__context *s);\nstatic void    *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);\nstatic int      stbi__tga_info(stbi__context *s, int *x, int *y, int *comp);\n#endif\n\n#ifndef STBI_NO_PSD\nstatic int      stbi__psd_test(stbi__context *s);\nstatic void    *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc);\nstatic int      stbi__psd_info(stbi__context *s, int *x, int *y, int *comp);\nstatic int      stbi__psd_is16(stbi__context *s);\n#endif\n\n#ifndef STBI_NO_HDR\nstatic int      stbi__hdr_test(stbi__context *s);\nstatic float   *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);\nstatic int      stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp);\n#endif\n\n#ifndef STBI_NO_PIC\nstatic int      stbi__pic_test(stbi__context *s);\nstatic void    *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);\nstatic int      stbi__pic_info(stbi__context *s, int *x, int *y, int *comp);\n#endif\n\n#ifndef STBI_NO_GIF\nstatic int      stbi__gif_test(stbi__context *s);\nstatic void    *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);\nstatic void    *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int 
*z, int *comp, int req_comp);\nstatic int      stbi__gif_info(stbi__context *s, int *x, int *y, int *comp);\n#endif\n\n#ifndef STBI_NO_PNM\nstatic int      stbi__pnm_test(stbi__context *s);\nstatic void    *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);\nstatic int      stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp);\nstatic int      stbi__pnm_is16(stbi__context *s);\n#endif\n\nstatic\n#ifdef STBI_THREAD_LOCAL\nSTBI_THREAD_LOCAL\n#endif\nconst char *stbi__g_failure_reason;\n\nSTBIDEF const char *stbi_failure_reason(void)\n{\n   return stbi__g_failure_reason;\n}\n\n#ifndef STBI_NO_FAILURE_STRINGS\nstatic int stbi__err(const char *str)\n{\n   stbi__g_failure_reason = str;\n   return 0;\n}\n#endif\n\nstatic void *stbi__malloc(size_t size)\n{\n    return STBI_MALLOC(size);\n}\n\n// stb_image uses ints pervasively, including for offset calculations.\n// therefore the largest decoded image size we can support with the\n// current code, even on 64-bit targets, is INT_MAX. this is not a\n// significant limitation for the intended use case.\n//\n// we do, however, need to make sure our size calculations don't\n// overflow. 
hence a few helper functions for size calculations that\n// multiply integers together, making sure that they're non-negative\n// and no overflow occurs.\n\n// return 1 if the sum is valid, 0 on overflow.\n// negative terms are considered invalid.\nstatic int stbi__addsizes_valid(int a, int b)\n{\n   if (b < 0) return 0;\n   // now 0 <= b <= INT_MAX, hence also\n   // 0 <= INT_MAX - b <= INTMAX.\n   // And \"a + b <= INT_MAX\" (which might overflow) is the\n   // same as a <= INT_MAX - b (no overflow)\n   return a <= INT_MAX - b;\n}\n\n// returns 1 if the product is valid, 0 on overflow.\n// negative factors are considered invalid.\nstatic int stbi__mul2sizes_valid(int a, int b)\n{\n   if (a < 0 || b < 0) return 0;\n   if (b == 0) return 1; // mul-by-0 is always safe\n   // portable way to check for no overflows in a*b\n   return a <= INT_MAX/b;\n}\n\n#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR)\n// returns 1 if \"a*b + add\" has no negative terms/factors and doesn't overflow\nstatic int stbi__mad2sizes_valid(int a, int b, int add)\n{\n   return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add);\n}\n#endif\n\n// returns 1 if \"a*b*c + add\" has no negative terms/factors and doesn't overflow\nstatic int stbi__mad3sizes_valid(int a, int b, int c, int add)\n{\n   return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) &&\n      stbi__addsizes_valid(a*b*c, add);\n}\n\n// returns 1 if \"a*b*c*d + add\" has no negative terms/factors and doesn't overflow\n#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM)\nstatic int stbi__mad4sizes_valid(int a, int b, int c, int d, int add)\n{\n   return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) &&\n      stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add);\n}\n#endif\n\n#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR)\n// mallocs with size 
overflow checking\nstatic void *stbi__malloc_mad2(int a, int b, int add)\n{\n   if (!stbi__mad2sizes_valid(a, b, add)) return NULL;\n   return stbi__malloc(a*b + add);\n}\n#endif\n\nstatic void *stbi__malloc_mad3(int a, int b, int c, int add)\n{\n   if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL;\n   return stbi__malloc(a*b*c + add);\n}\n\n#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM)\nstatic void *stbi__malloc_mad4(int a, int b, int c, int d, int add)\n{\n   if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL;\n   return stbi__malloc(a*b*c*d + add);\n}\n#endif\n\n// returns 1 if the sum of two signed ints is valid (between -2^31 and 2^31-1 inclusive), 0 on overflow.\nstatic int stbi__addints_valid(int a, int b)\n{\n   if ((a >= 0) != (b >= 0)) return 1; // a and b have different signs, so no overflow\n   if (a < 0 && b < 0) return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0.\n   return a <= INT_MAX - b;\n}\n\n// returns 1 if the product of two signed shorts is valid, 0 on overflow.\nstatic int stbi__mul2shorts_valid(short a, short b)\n{\n   if (b == 0 || b == -1) return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow\n   if ((a >= 0) == (b >= 0)) return a <= SHRT_MAX/b; // product is positive, so similar to mul2sizes_valid\n   if (b < 0) return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN\n   return a >= SHRT_MIN / b;\n}\n\n// stbi__err - error\n// stbi__errpf - error returning pointer to float\n// stbi__errpuc - error returning pointer to unsigned char\n\n#ifdef STBI_NO_FAILURE_STRINGS\n   #define stbi__err(x,y)  0\n#elif defined(STBI_FAILURE_USERMSG)\n   #define stbi__err(x,y)  stbi__err(y)\n#else\n   #define stbi__err(x,y)  stbi__err(x)\n#endif\n\n#define stbi__errpf(x,y)   ((float *)(size_t) (stbi__err(x,y)?NULL:NULL))\n#define stbi__errpuc(x,y)  ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL))\n\nSTBIDEF void stbi_image_free(void 
*retval_from_stbi_load)\n{\n   STBI_FREE(retval_from_stbi_load);\n}\n\n#ifndef STBI_NO_LINEAR\nstatic float   *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp);\n#endif\n\n#ifndef STBI_NO_HDR\nstatic stbi_uc *stbi__hdr_to_ldr(float   *data, int x, int y, int comp);\n#endif\n\nstatic int stbi__vertically_flip_on_load_global = 0;\n\nSTBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip)\n{\n   stbi__vertically_flip_on_load_global = flag_true_if_should_flip;\n}\n\n#ifndef STBI_THREAD_LOCAL\n#define stbi__vertically_flip_on_load  stbi__vertically_flip_on_load_global\n#else\nstatic STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set;\n\nSTBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip)\n{\n   stbi__vertically_flip_on_load_local = flag_true_if_should_flip;\n   stbi__vertically_flip_on_load_set = 1;\n}\n\n#define stbi__vertically_flip_on_load  (stbi__vertically_flip_on_load_set       \\\n                                         ? 
stbi__vertically_flip_on_load_local  \\\n                                         : stbi__vertically_flip_on_load_global)\n#endif // STBI_THREAD_LOCAL\n\nstatic void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc)\n{\n   memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields\n   ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed\n   ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order\n   ri->num_channels = 0;\n\n   // test the formats with a very explicit header first (at least a FOURCC\n   // or distinctive magic number first)\n   #ifndef STBI_NO_PNG\n   if (stbi__png_test(s))  return stbi__png_load(s,x,y,comp,req_comp, ri);\n   #endif\n   #ifndef STBI_NO_BMP\n   if (stbi__bmp_test(s))  return stbi__bmp_load(s,x,y,comp,req_comp, ri);\n   #endif\n   #ifndef STBI_NO_GIF\n   if (stbi__gif_test(s))  return stbi__gif_load(s,x,y,comp,req_comp, ri);\n   #endif\n   #ifndef STBI_NO_PSD\n   if (stbi__psd_test(s))  return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc);\n   #else\n   STBI_NOTUSED(bpc);\n   #endif\n   #ifndef STBI_NO_PIC\n   if (stbi__pic_test(s))  return stbi__pic_load(s,x,y,comp,req_comp, ri);\n   #endif\n\n   // then the formats that can end up attempting to load with just 1 or 2\n   // bytes matching expectations; these are prone to false positives, so\n   // try them later\n   #ifndef STBI_NO_JPEG\n   if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri);\n   #endif\n   #ifndef STBI_NO_PNM\n   if (stbi__pnm_test(s))  return stbi__pnm_load(s,x,y,comp,req_comp, ri);\n   #endif\n\n   #ifndef STBI_NO_HDR\n   if (stbi__hdr_test(s)) {\n      float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri);\n      return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? 
req_comp : *comp);\n   }\n   #endif\n\n   #ifndef STBI_NO_TGA\n   // test tga last because it's a crappy test!\n   if (stbi__tga_test(s))\n      return stbi__tga_load(s,x,y,comp,req_comp, ri);\n   #endif\n\n   return stbi__errpuc(\"unknown image type\", \"Image not of any known type, or corrupt\");\n}\n\nstatic stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels)\n{\n   int i;\n   int img_len = w * h * channels;\n   stbi_uc *reduced;\n\n   reduced = (stbi_uc *) stbi__malloc(img_len);\n   if (reduced == NULL) return stbi__errpuc(\"outofmem\", \"Out of memory\");\n\n   for (i = 0; i < img_len; ++i)\n      reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling\n\n   STBI_FREE(orig);\n   return reduced;\n}\n\nstatic stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels)\n{\n   int i;\n   int img_len = w * h * channels;\n   stbi__uint16 *enlarged;\n\n   enlarged = (stbi__uint16 *) stbi__malloc(img_len*2);\n   if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc(\"outofmem\", \"Out of memory\");\n\n   for (i = 0; i < img_len; ++i)\n      enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff\n\n   STBI_FREE(orig);\n   return enlarged;\n}\n\nstatic void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel)\n{\n   int row;\n   size_t bytes_per_row = (size_t)w * bytes_per_pixel;\n   stbi_uc temp[2048];\n   stbi_uc *bytes = (stbi_uc *)image;\n\n   for (row = 0; row < (h>>1); row++) {\n      stbi_uc *row0 = bytes + row*bytes_per_row;\n      stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row;\n      // swap row0 with row1\n      size_t bytes_left = bytes_per_row;\n      while (bytes_left) {\n         size_t bytes_copy = (bytes_left < sizeof(temp)) ? 
bytes_left : sizeof(temp);\n         memcpy(temp, row0, bytes_copy);\n         memcpy(row0, row1, bytes_copy);\n         memcpy(row1, temp, bytes_copy);\n         row0 += bytes_copy;\n         row1 += bytes_copy;\n         bytes_left -= bytes_copy;\n      }\n   }\n}\n\n#ifndef STBI_NO_GIF\nstatic void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel)\n{\n   int slice;\n   int slice_size = w * h * bytes_per_pixel;\n\n   stbi_uc *bytes = (stbi_uc *)image;\n   for (slice = 0; slice < z; ++slice) {\n      stbi__vertical_flip(bytes, w, h, bytes_per_pixel);\n      bytes += slice_size;\n   }\n}\n#endif\n\nstatic unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp)\n{\n   stbi__result_info ri;\n   void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8);\n\n   if (result == NULL)\n      return NULL;\n\n   // it is the responsibility of the loaders to make sure we get either 8 or 16 bit.\n   STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16);\n\n   if (ri.bits_per_channel != 8) {\n      result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp);\n      ri.bits_per_channel = 8;\n   }\n\n   // @TODO: move stbi__convert_format to here\n\n   if (stbi__vertically_flip_on_load) {\n      int channels = req_comp ? 
req_comp : *comp;\n      stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc));\n   }\n\n   return (unsigned char *) result;\n}\n\nstatic stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp)\n{\n   stbi__result_info ri;\n   void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16);\n\n   if (result == NULL)\n      return NULL;\n\n   // it is the responsibility of the loaders to make sure we get either 8 or 16 bit.\n   STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16);\n\n   if (ri.bits_per_channel != 16) {\n      result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp);\n      ri.bits_per_channel = 16;\n   }\n\n   // @TODO: move stbi__convert_format16 to here\n   // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision\n\n   if (stbi__vertically_flip_on_load) {\n      int channels = req_comp ? req_comp : *comp;\n      stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16));\n   }\n\n   return (stbi__uint16 *) result;\n}\n\n#if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR)\nstatic void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp)\n{\n   if (stbi__vertically_flip_on_load && result != NULL) {\n      int channels = req_comp ? 
req_comp : *comp;\n      stbi__vertical_flip(result, *x, *y, channels * sizeof(float));\n   }\n}\n#endif\n\n#ifndef STBI_NO_STDIO\n\n#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8)\nSTBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide);\nSTBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default);\n#endif\n\n#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8)\nSTBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input)\n{\n\treturn WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL);\n}\n#endif\n\nstatic FILE *stbi__fopen(char const *filename, char const *mode)\n{\n   FILE *f;\n#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8)\n   wchar_t wMode[64];\n   wchar_t wFilename[1024];\n\tif (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename)))\n      return 0;\n\n\tif (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode)))\n      return 0;\n\n#if defined(_MSC_VER) && _MSC_VER >= 1400\n\tif (0 != _wfopen_s(&f, wFilename, wMode))\n\t\tf = 0;\n#else\n   f = _wfopen(wFilename, wMode);\n#endif\n\n#elif defined(_MSC_VER) && _MSC_VER >= 1400\n   if (0 != fopen_s(&f, filename, mode))\n      f=0;\n#else\n   f = fopen(filename, mode);\n#endif\n   return f;\n}\n\n\nSTBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp)\n{\n   FILE *f = stbi__fopen(filename, \"rb\");\n   unsigned char *result;\n   if (!f) return stbi__errpuc(\"can't fopen\", \"Unable to open file\");\n   result = stbi_load_from_file(f,x,y,comp,req_comp);\n   fclose(f);\n   return result;\n}\n\nSTBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int 
req_comp)\n{\n   unsigned char *result;\n   stbi__context s;\n   stbi__start_file(&s,f);\n   result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp);\n   if (result) {\n      // need to 'unget' all the characters in the IO buffer\n      fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR);\n   }\n   return result;\n}\n\nSTBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp)\n{\n   stbi__uint16 *result;\n   stbi__context s;\n   stbi__start_file(&s,f);\n   result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp);\n   if (result) {\n      // need to 'unget' all the characters in the IO buffer\n      fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR);\n   }\n   return result;\n}\n\nSTBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp)\n{\n   FILE *f = stbi__fopen(filename, \"rb\");\n   stbi__uint16 *result;\n   if (!f) return (stbi_us *) stbi__errpuc(\"can't fopen\", \"Unable to open file\");\n   result = stbi_load_from_file_16(f,x,y,comp,req_comp);\n   fclose(f);\n   return result;\n}\n\n\n#endif //!STBI_NO_STDIO\n\nSTBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels)\n{\n   stbi__context s;\n   stbi__start_mem(&s,buffer,len);\n   return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels);\n}\n\nSTBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels)\n{\n   stbi__context s;\n   stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user);\n   return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels);\n}\n\nSTBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp)\n{\n   stbi__context s;\n   stbi__start_mem(&s,buffer,len);\n   return 
stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp);\n}\n\nSTBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp)\n{\n   stbi__context s;\n   stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user);\n   return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp);\n}\n\n#ifndef STBI_NO_GIF\nSTBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp)\n{\n   unsigned char *result;\n   stbi__context s;\n   stbi__start_mem(&s,buffer,len);\n\n   result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp);\n   if (stbi__vertically_flip_on_load) {\n      stbi__vertical_flip_slices( result, *x, *y, *z, *comp );\n   }\n\n   return result;\n}\n#endif\n\n#ifndef STBI_NO_LINEAR\nstatic float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp)\n{\n   unsigned char *data;\n   #ifndef STBI_NO_HDR\n   if (stbi__hdr_test(s)) {\n      stbi__result_info ri;\n      float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri);\n      if (hdr_data)\n         stbi__float_postprocess(hdr_data,x,y,comp,req_comp);\n      return hdr_data;\n   }\n   #endif\n   data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp);\n   if (data)\n      return stbi__ldr_to_hdr(data, *x, *y, req_comp ? 
req_comp : *comp);\n   return stbi__errpf(\"unknown image type\", \"Image not of any known type, or corrupt\");\n}\n\nSTBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp)\n{\n   stbi__context s;\n   stbi__start_mem(&s,buffer,len);\n   return stbi__loadf_main(&s,x,y,comp,req_comp);\n}\n\nSTBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp)\n{\n   stbi__context s;\n   stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user);\n   return stbi__loadf_main(&s,x,y,comp,req_comp);\n}\n\n#ifndef STBI_NO_STDIO\nSTBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp)\n{\n   float *result;\n   FILE *f = stbi__fopen(filename, \"rb\");\n   if (!f) return stbi__errpf(\"can't fopen\", \"Unable to open file\");\n   result = stbi_loadf_from_file(f,x,y,comp,req_comp);\n   fclose(f);\n   return result;\n}\n\nSTBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp)\n{\n   stbi__context s;\n   stbi__start_file(&s,f);\n   return stbi__loadf_main(&s,x,y,comp,req_comp);\n}\n#endif // !STBI_NO_STDIO\n\n#endif // !STBI_NO_LINEAR\n\n// these is-hdr-or-not is defined independent of whether STBI_NO_LINEAR is\n// defined, for API simplicity; if STBI_NO_LINEAR is defined, it always\n// reports false!\n\nSTBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len)\n{\n   #ifndef STBI_NO_HDR\n   stbi__context s;\n   stbi__start_mem(&s,buffer,len);\n   return stbi__hdr_test(&s);\n   #else\n   STBI_NOTUSED(buffer);\n   STBI_NOTUSED(len);\n   return 0;\n   #endif\n}\n\n#ifndef STBI_NO_STDIO\nSTBIDEF int      stbi_is_hdr          (char const *filename)\n{\n   FILE *f = stbi__fopen(filename, \"rb\");\n   int result=0;\n   if (f) {\n      result = stbi_is_hdr_from_file(f);\n      fclose(f);\n   }\n   return result;\n}\n\nSTBIDEF int stbi_is_hdr_from_file(FILE *f)\n{\n   #ifndef STBI_NO_HDR\n   long pos = 
ftell(f);\n   int res;\n   stbi__context s;\n   stbi__start_file(&s,f);\n   res = stbi__hdr_test(&s);\n   fseek(f, pos, SEEK_SET);\n   return res;\n   #else\n   STBI_NOTUSED(f);\n   return 0;\n   #endif\n}\n#endif // !STBI_NO_STDIO\n\nSTBIDEF int      stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user)\n{\n   #ifndef STBI_NO_HDR\n   stbi__context s;\n   stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user);\n   return stbi__hdr_test(&s);\n   #else\n   STBI_NOTUSED(clbk);\n   STBI_NOTUSED(user);\n   return 0;\n   #endif\n}\n\n#ifndef STBI_NO_LINEAR\nstatic float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f;\n\nSTBIDEF void   stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; }\nSTBIDEF void   stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; }\n#endif\n\nstatic float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f;\n\nSTBIDEF void   stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; }\nSTBIDEF void   stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; }\n\n\n//////////////////////////////////////////////////////////////////////////////\n//\n// Common code used by all image loaders\n//\n\nenum\n{\n   STBI__SCAN_load=0,\n   STBI__SCAN_type,\n   STBI__SCAN_header\n};\n\nstatic void stbi__refill_buffer(stbi__context *s)\n{\n   int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen);\n   s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original);\n   if (n == 0) {\n      // at end of file, treat same as if from memory, but need to handle case\n      // where s->img_buffer isn't pointing to safe memory, e.g. 
0-byte file\n      s->read_from_callbacks = 0;\n      s->img_buffer = s->buffer_start;\n      s->img_buffer_end = s->buffer_start+1;\n      *s->img_buffer = 0;\n   } else {\n      s->img_buffer = s->buffer_start;\n      s->img_buffer_end = s->buffer_start + n;\n   }\n}\n\nstbi_inline static stbi_uc stbi__get8(stbi__context *s)\n{\n   if (s->img_buffer < s->img_buffer_end)\n      return *s->img_buffer++;\n   if (s->read_from_callbacks) {\n      stbi__refill_buffer(s);\n      return *s->img_buffer++;\n   }\n   return 0;\n}\n\n#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM)\n// nothing\n#else\nstbi_inline static int stbi__at_eof(stbi__context *s)\n{\n   if (s->io.read) {\n      if (!(s->io.eof)(s->io_user_data)) return 0;\n      // if feof() is true, check if buffer = end\n      // special case: we've only got the special 0 character at the end\n      if (s->read_from_callbacks == 0) return 1;\n   }\n\n   return s->img_buffer >= s->img_buffer_end;\n}\n#endif\n\n#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC)\n// nothing\n#else\nstatic void stbi__skip(stbi__context *s, int n)\n{\n   if (n == 0) return;  // already there!\n   if (n < 0) {\n      s->img_buffer = s->img_buffer_end;\n      return;\n   }\n   if (s->io.read) {\n      int blen = (int) (s->img_buffer_end - s->img_buffer);\n      if (blen < n) {\n         s->img_buffer = s->img_buffer_end;\n         (s->io.skip)(s->io_user_data, n - blen);\n         return;\n      }\n   }\n   s->img_buffer += n;\n}\n#endif\n\n#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM)\n// nothing\n#else\nstatic int stbi__getn(stbi__context *s, stbi_uc *buffer, int n)\n{\n   if (s->io.read) {\n      int blen = (int) (s->img_buffer_end - s->img_buffer);\n      if (blen < n) {\n         int res, count;\n\n         
memcpy(buffer, s->img_buffer, blen);\n\n         count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen);\n         res = (count == (n-blen));\n         s->img_buffer = s->img_buffer_end;\n         return res;\n      }\n   }\n\n   if (s->img_buffer+n <= s->img_buffer_end) {\n      memcpy(buffer, s->img_buffer, n);\n      s->img_buffer += n;\n      return 1;\n   } else\n      return 0;\n}\n#endif\n\n#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC)\n// nothing\n#else\nstatic int stbi__get16be(stbi__context *s)\n{\n   int z = stbi__get8(s);\n   return (z << 8) + stbi__get8(s);\n}\n#endif\n\n#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC)\n// nothing\n#else\nstatic stbi__uint32 stbi__get32be(stbi__context *s)\n{\n   stbi__uint32 z = stbi__get16be(s);\n   return (z << 16) + stbi__get16be(s);\n}\n#endif\n\n#if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF)\n// nothing\n#else\nstatic int stbi__get16le(stbi__context *s)\n{\n   int z = stbi__get8(s);\n   return z + (stbi__get8(s) << 8);\n}\n#endif\n\n#ifndef STBI_NO_BMP\nstatic stbi__uint32 stbi__get32le(stbi__context *s)\n{\n   stbi__uint32 z = stbi__get16le(s);\n   z += (stbi__uint32)stbi__get16le(s) << 16;\n   return z;\n}\n#endif\n\n#define STBI__BYTECAST(x)  ((stbi_uc) ((x) & 255))  // truncate int to byte without warnings\n\n#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM)\n// nothing\n#else\n//////////////////////////////////////////////////////////////////////////////\n//\n//  generic converter from built-in img_n to req_comp\n//    individual types do this automatically as much as possible (e.g. jpeg\n//    does all cases internally since it needs to colorspace convert anyway,\n//    and it never has alpha, so very few cases ). 
png can automatically\n//    interleave an alpha=255 channel, but falls back to this for other cases\n//\n//  assume data buffer is malloced, so malloc a new one and free that one\n//  only failure mode is malloc failing\n\nstatic stbi_uc stbi__compute_y(int r, int g, int b)\n{\n   return (stbi_uc) (((r*77) + (g*150) +  (29*b)) >> 8);\n}\n#endif\n\n#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM)\n// nothing\n#else\nstatic unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y)\n{\n   int i,j;\n   unsigned char *good;\n\n   if (req_comp == img_n) return data;\n   STBI_ASSERT(req_comp >= 1 && req_comp <= 4);\n\n   good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0);\n   if (good == NULL) {\n      STBI_FREE(data);\n      return stbi__errpuc(\"outofmem\", \"Out of memory\");\n   }\n\n   for (j=0; j < (int) y; ++j) {\n      unsigned char *src  = data + j * x * img_n   ;\n      unsigned char *dest = good + j * x * req_comp;\n\n      #define STBI__COMBO(a,b)  ((a)*8+(b))\n      #define STBI__CASE(a,b)   case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b)\n      // convert source image with img_n components to one with req_comp components;\n      // avoid switch per pixel, so use switch per scanline and massive macros\n      switch (STBI__COMBO(img_n, req_comp)) {\n         STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255;                                     } break;\n         STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0];                                  } break;\n         STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255;                     } break;\n         STBI__CASE(2,1) { dest[0]=src[0];                                                  } break;\n         STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0];                                  } break;\n       
  STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1];                  } break;\n         STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255;        } break;\n         STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]);                   } break;\n         STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255;    } break;\n         STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]);                   } break;\n         STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break;\n         STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];                    } break;\n         default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc(\"unsupported\", \"Unsupported format conversion\");\n      }\n      #undef STBI__CASE\n   }\n\n   STBI_FREE(data);\n   return good;\n}\n#endif\n\n#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD)\n// nothing\n#else\nstatic stbi__uint16 stbi__compute_y_16(int r, int g, int b)\n{\n   return (stbi__uint16) (((r*77) + (g*150) +  (29*b)) >> 8);\n}\n#endif\n\n#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD)\n// nothing\n#else\nstatic stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y)\n{\n   int i,j;\n   stbi__uint16 *good;\n\n   if (req_comp == img_n) return data;\n   STBI_ASSERT(req_comp >= 1 && req_comp <= 4);\n\n   good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2);\n   if (good == NULL) {\n      STBI_FREE(data);\n      return (stbi__uint16 *) stbi__errpuc(\"outofmem\", \"Out of memory\");\n   }\n\n   for (j=0; j < (int) y; ++j) {\n      stbi__uint16 *src  = data + j * x * img_n   ;\n      stbi__uint16 *dest = good + j * x * req_comp;\n\n      #define STBI__COMBO(a,b)  ((a)*8+(b))\n      #define STBI__CASE(a,b)   case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b)\n      // convert source image with 
img_n components to one with req_comp components;\n      // avoid switch per pixel, so use switch per scanline and massive macros\n      switch (STBI__COMBO(img_n, req_comp)) {\n         STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff;                                     } break;\n         STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0];                                     } break;\n         STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff;                     } break;\n         STBI__CASE(2,1) { dest[0]=src[0];                                                     } break;\n         STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0];                                     } break;\n         STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1];                     } break;\n         STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff;        } break;\n         STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]);                   } break;\n         STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break;\n         STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]);                   } break;\n         STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break;\n         STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];                       } break;\n         default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc(\"unsupported\", \"Unsupported format conversion\");\n      }\n      #undef STBI__CASE\n   }\n\n   STBI_FREE(data);\n   return good;\n}\n#endif\n\n#ifndef STBI_NO_LINEAR\nstatic float   *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp)\n{\n   int i,k,n;\n   float *output;\n   if (!data) return NULL;\n   output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0);\n   if (output == NULL) { STBI_FREE(data); return stbi__errpf(\"outofmem\", 
\"Out of memory\"); }\n   // compute number of non-alpha components\n   if (comp & 1) n = comp; else n = comp-1;\n   for (i=0; i < x*y; ++i) {\n      for (k=0; k < n; ++k) {\n         output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale);\n      }\n   }\n   if (n < comp) {\n      for (i=0; i < x*y; ++i) {\n         output[i*comp + n] = data[i*comp + n]/255.0f;\n      }\n   }\n   STBI_FREE(data);\n   return output;\n}\n#endif\n\n#ifndef STBI_NO_HDR\n#define stbi__float2int(x)   ((int) (x))\nstatic stbi_uc *stbi__hdr_to_ldr(float   *data, int x, int y, int comp)\n{\n   int i,k,n;\n   stbi_uc *output;\n   if (!data) return NULL;\n   output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0);\n   if (output == NULL) { STBI_FREE(data); return stbi__errpuc(\"outofmem\", \"Out of memory\"); }\n   // compute number of non-alpha components\n   if (comp & 1) n = comp; else n = comp-1;\n   for (i=0; i < x*y; ++i) {\n      for (k=0; k < n; ++k) {\n         float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f;\n         if (z < 0) z = 0;\n         if (z > 255) z = 255;\n         output[i*comp + k] = (stbi_uc) stbi__float2int(z);\n      }\n      if (k < comp) {\n         float z = data[i*comp+k] * 255 + 0.5f;\n         if (z < 0) z = 0;\n         if (z > 255) z = 255;\n         output[i*comp + k] = (stbi_uc) stbi__float2int(z);\n      }\n   }\n   STBI_FREE(data);\n   return output;\n}\n#endif\n\n//////////////////////////////////////////////////////////////////////////////\n//\n//  \"baseline\" JPEG/JFIF decoder\n//\n//    simple implementation\n//      - doesn't support delayed output of y-dimension\n//      - simple interface (only one output format: 8-bit interleaved RGB)\n//      - doesn't try to recover corrupt jpegs\n//      - doesn't allow partial loading, loading multiple at once\n//      - still fast on x86 (copying globals into locals doesn't help x86)\n//      - allocates lots of intermediate memory 
(full size of all components)\n//        - non-interleaved case requires this anyway\n//        - allows good upsampling (see next)\n//    high-quality\n//      - upsampled channels are bilinearly interpolated, even across blocks\n//      - quality integer IDCT derived from IJG's 'slow'\n//    performance\n//      - fast huffman; reasonable integer IDCT\n//      - some SIMD kernels for common paths on targets with SSE2/NEON\n//      - uses a lot of intermediate memory, could cache poorly\n\n#ifndef STBI_NO_JPEG\n\n// huffman decoding acceleration\n#define FAST_BITS   9  // larger handles more cases; smaller stomps less cache\n\ntypedef struct\n{\n   stbi_uc  fast[1 << FAST_BITS];\n   // weirdly, repacking this into AoS is a 10% speed loss, instead of a win\n   stbi__uint16 code[256];\n   stbi_uc  values[256];\n   stbi_uc  size[257];\n   unsigned int maxcode[18];\n   int    delta[17];   // old 'firstsymbol' - old 'firstcode'\n} stbi__huffman;\n\ntypedef struct\n{\n   stbi__context *s;\n   stbi__huffman huff_dc[4];\n   stbi__huffman huff_ac[4];\n   stbi__uint16 dequant[4][64];\n   stbi__int16 fast_ac[4][1 << FAST_BITS];\n\n// sizes for components, interleaved MCUs\n   int img_h_max, img_v_max;\n   int img_mcu_x, img_mcu_y;\n   int img_mcu_w, img_mcu_h;\n\n// definition of jpeg image component\n   struct\n   {\n      int id;\n      int h,v;\n      int tq;\n      int hd,ha;\n      int dc_pred;\n\n      int x,y,w2,h2;\n      stbi_uc *data;\n      void *raw_data, *raw_coeff;\n      stbi_uc *linebuf;\n      short   *coeff;   // progressive only\n      int      coeff_w, coeff_h; // number of 8x8 coefficient blocks\n   } img_comp[4];\n\n   stbi__uint32   code_buffer; // jpeg entropy-coded buffer\n   int            code_bits;   // number of valid bits\n   unsigned char  marker;      // marker seen while filling entropy buffer\n   int            nomore;      // flag if we saw a marker so must stop\n\n   int            progressive;\n   int            spec_start;\n   int        
    spec_end;\n   int            succ_high;\n   int            succ_low;\n   int            eob_run;\n   int            jfif;\n   int            app14_color_transform; // Adobe APP14 tag\n   int            rgb;\n\n   int scan_n, order[4];\n   int restart_interval, todo;\n\n// kernels\n   void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]);\n   void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step);\n   stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs);\n} stbi__jpeg;\n\nstatic int stbi__build_huffman(stbi__huffman *h, int *count)\n{\n   int i,j,k=0;\n   unsigned int code;\n   // build size list for each symbol (from JPEG spec)\n   for (i=0; i < 16; ++i) {\n      for (j=0; j < count[i]; ++j) {\n         h->size[k++] = (stbi_uc) (i+1);\n         if(k >= 257) return stbi__err(\"bad size list\",\"Corrupt JPEG\");\n      }\n   }\n   h->size[k] = 0;\n\n   // compute actual symbols (from jpeg spec)\n   code = 0;\n   k = 0;\n   for(j=1; j <= 16; ++j) {\n      // compute delta to add to code to compute symbol id\n      h->delta[j] = k - code;\n      if (h->size[k] == j) {\n         while (h->size[k] == j)\n            h->code[k++] = (stbi__uint16) (code++);\n         if (code-1 >= (1u << j)) return stbi__err(\"bad code lengths\",\"Corrupt JPEG\");\n      }\n      // compute largest code + 1 for this size, preshifted as needed later\n      h->maxcode[j] = code << (16-j);\n      code <<= 1;\n   }\n   h->maxcode[j] = 0xffffffff;\n\n   // build non-spec acceleration table; 255 is flag for not-accelerated\n   memset(h->fast, 255, 1 << FAST_BITS);\n   for (i=0; i < k; ++i) {\n      int s = h->size[i];\n      if (s <= FAST_BITS) {\n         int c = h->code[i] << (FAST_BITS-s);\n         int m = 1 << (FAST_BITS-s);\n         for (j=0; j < m; ++j) {\n            h->fast[c+j] = (stbi_uc) i;\n         }\n      }\n   }\n   return 1;\n}\n\n// build a 
table that decodes both magnitude and value of small ACs in\n// one go.\nstatic void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h)\n{\n   int i;\n   for (i=0; i < (1 << FAST_BITS); ++i) {\n      stbi_uc fast = h->fast[i];\n      fast_ac[i] = 0;\n      if (fast < 255) {\n         int rs = h->values[fast];\n         int run = (rs >> 4) & 15;\n         int magbits = rs & 15;\n         int len = h->size[fast];\n\n         if (magbits && len + magbits <= FAST_BITS) {\n            // magnitude code followed by receive_extend code\n            int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits);\n            int m = 1 << (magbits - 1);\n            if (k < m) k += (~0U << magbits) + 1;\n            // if the result is small enough, we can fit it in fast_ac table\n            if (k >= -128 && k <= 127)\n               fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits));\n         }\n      }\n   }\n}\n\nstatic void stbi__grow_buffer_unsafe(stbi__jpeg *j)\n{\n   do {\n      unsigned int b = j->nomore ? 
0 : stbi__get8(j->s);\n      if (b == 0xff) {\n         int c = stbi__get8(j->s);\n         while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes\n         if (c != 0) {\n            j->marker = (unsigned char) c;\n            j->nomore = 1;\n            return;\n         }\n      }\n      j->code_buffer |= b << (24 - j->code_bits);\n      j->code_bits += 8;\n   } while (j->code_bits <= 24);\n}\n\n// (1 << n) - 1\nstatic const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535};\n\n// decode a jpeg huffman value from the bitstream\nstbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h)\n{\n   unsigned int temp;\n   int c,k;\n\n   if (j->code_bits < 16) stbi__grow_buffer_unsafe(j);\n\n   // look at the top FAST_BITS and determine what symbol ID it is,\n   // if the code is <= FAST_BITS\n   c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1);\n   k = h->fast[c];\n   if (k < 255) {\n      int s = h->size[k];\n      if (s > j->code_bits)\n         return -1;\n      j->code_buffer <<= s;\n      j->code_bits -= s;\n      return h->values[k];\n   }\n\n   // naive test is to shift the code_buffer down so k bits are\n   // valid, then test against maxcode. To speed this up, we've\n   // preshifted maxcode left so that it has (16-k) 0s at the\n   // end; in other words, regardless of the number of bits, it\n   // wants to be compared against something shifted to have 16;\n   // that way we don't need to shift inside the loop.\n   temp = j->code_buffer >> 16;\n   for (k=FAST_BITS+1 ; ; ++k)\n      if (temp < h->maxcode[k])\n         break;\n   if (k == 17) {\n      // error! 
code not found\n      j->code_bits -= 16;\n      return -1;\n   }\n\n   if (k > j->code_bits)\n      return -1;\n\n   // convert the huffman code to the symbol id\n   c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k];\n   if(c < 0 || c >= 256) // symbol id out of bounds!\n       return -1;\n   STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]);\n\n   // convert the id to a symbol\n   j->code_bits -= k;\n   j->code_buffer <<= k;\n   return h->values[c];\n}\n\n// bias[n] = (-1<<n) + 1\nstatic const int stbi__jbias[16] = {0,-1,-3,-7,-15,-31,-63,-127,-255,-511,-1023,-2047,-4095,-8191,-16383,-32767};\n\n// combined JPEG 'receive' and JPEG 'extend', since baseline\n// always extends everything it receives.\nstbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n)\n{\n   unsigned int k;\n   int sgn;\n   if (j->code_bits < n) stbi__grow_buffer_unsafe(j);\n   if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s intead of continuing\n\n   sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative)\n   k = stbi_lrot(j->code_buffer, n);\n   j->code_buffer = k & ~stbi__bmask[n];\n   k &= stbi__bmask[n];\n   j->code_bits -= n;\n   return k + (stbi__jbias[n] & (sgn - 1));\n}\n\n// get some unsigned bits\nstbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n)\n{\n   unsigned int k;\n   if (j->code_bits < n) stbi__grow_buffer_unsafe(j);\n   if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s intead of continuing\n   k = stbi_lrot(j->code_buffer, n);\n   j->code_buffer = k & ~stbi__bmask[n];\n   k &= stbi__bmask[n];\n   j->code_bits -= n;\n   return k;\n}\n\nstbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j)\n{\n   unsigned int k;\n   if (j->code_bits < 1) stbi__grow_buffer_unsafe(j);\n   if (j->code_bits < 1) return 0; // ran out of bits from stream, return 0s intead of continuing\n   k = j->code_buffer;\n   
j->code_buffer <<= 1;\n   --j->code_bits;\n   return k & 0x80000000;\n}\n\n// given a value that's at position X in the zigzag stream,\n// where does it appear in the 8x8 matrix coded as row-major?\nstatic const stbi_uc stbi__jpeg_dezigzag[64+15] =\n{\n    0,  1,  8, 16,  9,  2,  3, 10,\n   17, 24, 32, 25, 18, 11,  4,  5,\n   12, 19, 26, 33, 40, 48, 41, 34,\n   27, 20, 13,  6,  7, 14, 21, 28,\n   35, 42, 49, 56, 57, 50, 43, 36,\n   29, 22, 15, 23, 30, 37, 44, 51,\n   58, 59, 52, 45, 38, 31, 39, 46,\n   53, 60, 61, 54, 47, 55, 62, 63,\n   // let corrupt input sample past end\n   63, 63, 63, 63, 63, 63, 63, 63,\n   63, 63, 63, 63, 63, 63, 63\n};\n\n// decode one 64-entry block--\nstatic int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant)\n{\n   int diff,dc,k;\n   int t;\n\n   if (j->code_bits < 16) stbi__grow_buffer_unsafe(j);\n   t = stbi__jpeg_huff_decode(j, hdc);\n   if (t < 0 || t > 15) return stbi__err(\"bad huffman code\",\"Corrupt JPEG\");\n\n   // 0 all the ac values now so we can do it 32-bits at a time\n   memset(data,0,64*sizeof(data[0]));\n\n   diff = t ? 
stbi__extend_receive(j, t) : 0;\n   if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err(\"bad delta\",\"Corrupt JPEG\");\n   dc = j->img_comp[b].dc_pred + diff;\n   j->img_comp[b].dc_pred = dc;\n   if (!stbi__mul2shorts_valid(dc, dequant[0])) return stbi__err(\"can't merge dc and ac\", \"Corrupt JPEG\");\n   data[0] = (short) (dc * dequant[0]);\n\n   // decode AC components, see JPEG spec\n   k = 1;\n   do {\n      unsigned int zig;\n      int c,r,s;\n      if (j->code_bits < 16) stbi__grow_buffer_unsafe(j);\n      c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1);\n      r = fac[c];\n      if (r) { // fast-AC path\n         k += (r >> 4) & 15; // run\n         s = r & 15; // combined length\n         if (s > j->code_bits) return stbi__err(\"bad huffman code\", \"Combined length longer than code bits available\");\n         j->code_buffer <<= s;\n         j->code_bits -= s;\n         // decode into unzigzag'd location\n         zig = stbi__jpeg_dezigzag[k++];\n         data[zig] = (short) ((r >> 8) * dequant[zig]);\n      } else {\n         int rs = stbi__jpeg_huff_decode(j, hac);\n         if (rs < 0) return stbi__err(\"bad huffman code\",\"Corrupt JPEG\");\n         s = rs & 15;\n         r = rs >> 4;\n         if (s == 0) {\n            if (rs != 0xf0) break; // end block\n            k += 16;\n         } else {\n            k += r;\n            // decode into unzigzag'd location\n            zig = stbi__jpeg_dezigzag[k++];\n            data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]);\n         }\n      }\n   } while (k < 64);\n   return 1;\n}\n\nstatic int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b)\n{\n   int diff,dc;\n   int t;\n   if (j->spec_end != 0) return stbi__err(\"can't merge dc and ac\", \"Corrupt JPEG\");\n\n   if (j->code_bits < 16) stbi__grow_buffer_unsafe(j);\n\n   if (j->succ_high == 0) {\n      // first scan for DC coefficient, must be first\n      
memset(data,0,64*sizeof(data[0])); // 0 all the ac values now\n      t = stbi__jpeg_huff_decode(j, hdc);\n      if (t < 0 || t > 15) return stbi__err(\"can't merge dc and ac\", \"Corrupt JPEG\");\n      diff = t ? stbi__extend_receive(j, t) : 0;\n\n      if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err(\"bad delta\", \"Corrupt JPEG\");\n      dc = j->img_comp[b].dc_pred + diff;\n      j->img_comp[b].dc_pred = dc;\n      if (!stbi__mul2shorts_valid(dc, 1 << j->succ_low)) return stbi__err(\"can't merge dc and ac\", \"Corrupt JPEG\");\n      data[0] = (short) (dc * (1 << j->succ_low));\n   } else {\n      // refinement scan for DC coefficient\n      if (stbi__jpeg_get_bit(j))\n         data[0] += (short) (1 << j->succ_low);\n   }\n   return 1;\n}\n\n// @OPTIMIZE: store non-zigzagged during the decode passes,\n// and only de-zigzag when dequantizing\nstatic int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac)\n{\n   int k;\n   if (j->spec_start == 0) return stbi__err(\"can't merge dc and ac\", \"Corrupt JPEG\");\n\n   if (j->succ_high == 0) {\n      int shift = j->succ_low;\n\n      if (j->eob_run) {\n         --j->eob_run;\n         return 1;\n      }\n\n      k = j->spec_start;\n      do {\n         unsigned int zig;\n         int c,r,s;\n         if (j->code_bits < 16) stbi__grow_buffer_unsafe(j);\n         c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1);\n         r = fac[c];\n         if (r) { // fast-AC path\n            k += (r >> 4) & 15; // run\n            s = r & 15; // combined length\n            if (s > j->code_bits) return stbi__err(\"bad huffman code\", \"Combined length longer than code bits available\");\n            j->code_buffer <<= s;\n            j->code_bits -= s;\n            zig = stbi__jpeg_dezigzag[k++];\n            data[zig] = (short) ((r >> 8) * (1 << shift));\n         } else {\n            int rs = stbi__jpeg_huff_decode(j, hac);\n            if 
(rs < 0) return stbi__err(\"bad huffman code\",\"Corrupt JPEG\");\n            s = rs & 15;\n            r = rs >> 4;\n            if (s == 0) {\n               if (r < 15) {\n                  j->eob_run = (1 << r);\n                  if (r)\n                     j->eob_run += stbi__jpeg_get_bits(j, r);\n                  --j->eob_run;\n                  break;\n               }\n               k += 16;\n            } else {\n               k += r;\n               zig = stbi__jpeg_dezigzag[k++];\n               data[zig] = (short) (stbi__extend_receive(j,s) * (1 << shift));\n            }\n         }\n      } while (k <= j->spec_end);\n   } else {\n      // refinement scan for these AC coefficients\n\n      short bit = (short) (1 << j->succ_low);\n\n      if (j->eob_run) {\n         --j->eob_run;\n         for (k = j->spec_start; k <= j->spec_end; ++k) {\n            short *p = &data[stbi__jpeg_dezigzag[k]];\n            if (*p != 0)\n               if (stbi__jpeg_get_bit(j))\n                  if ((*p & bit)==0) {\n                     if (*p > 0)\n                        *p += bit;\n                     else\n                        *p -= bit;\n                  }\n         }\n      } else {\n         k = j->spec_start;\n         do {\n            int r,s;\n            int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh\n            if (rs < 0) return stbi__err(\"bad huffman code\",\"Corrupt JPEG\");\n            s = rs & 15;\n            r = rs >> 4;\n            if (s == 0) {\n               if (r < 15) {\n                  j->eob_run = (1 << r) - 1;\n                  if (r)\n                     j->eob_run += stbi__jpeg_get_bits(j, r);\n                  r = 64; // force end of block\n               } else {\n                  // r=15 s=0 should write 16 0s, so we just do\n                  // a run of 15 0s and then write s (which is 0),\n                  // so we don't have to do anything 
special here\n               }\n            } else {\n               if (s != 1) return stbi__err(\"bad huffman code\", \"Corrupt JPEG\");\n               // sign bit\n               if (stbi__jpeg_get_bit(j))\n                  s = bit;\n               else\n                  s = -bit;\n            }\n\n            // advance by r\n            while (k <= j->spec_end) {\n               short *p = &data[stbi__jpeg_dezigzag[k++]];\n               if (*p != 0) {\n                  if (stbi__jpeg_get_bit(j))\n                     if ((*p & bit)==0) {\n                        if (*p > 0)\n                           *p += bit;\n                        else\n                           *p -= bit;\n                     }\n               } else {\n                  if (r == 0) {\n                     *p = (short) s;\n                     break;\n                  }\n                  --r;\n               }\n            }\n         } while (k <= j->spec_end);\n      }\n   }\n   return 1;\n}\n\n// take a -128..127 value and stbi__clamp it and convert to 0..255\nstbi_inline static stbi_uc stbi__clamp(int x)\n{\n   // trick to use a single test to catch both cases\n   if ((unsigned int) x > 255) {\n      if (x < 0) return 0;\n      if (x > 255) return 255;\n   }\n   return (stbi_uc) x;\n}\n\n#define stbi__f2f(x)  ((int) (((x) * 4096 + 0.5)))\n#define stbi__fsh(x)  ((x) * 4096)\n\n// derived from jidctint -- DCT_ISLOW\n#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \\\n   int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \\\n   p2 = s2;                                    \\\n   p3 = s6;                                    \\\n   p1 = (p2+p3) * stbi__f2f(0.5411961f);       \\\n   t2 = p1 + p3*stbi__f2f(-1.847759065f);      \\\n   t3 = p1 + p2*stbi__f2f( 0.765366865f);      \\\n   p2 = s0;                                    \\\n   p3 = s4;                                    \\\n   t0 = stbi__fsh(p2+p3);                      \\\n   t1 = stbi__fsh(p2-p3);                      \\\n   x0 = 
t0+t3;                                 \\\n   x3 = t0-t3;                                 \\\n   x1 = t1+t2;                                 \\\n   x2 = t1-t2;                                 \\\n   t0 = s7;                                    \\\n   t1 = s5;                                    \\\n   t2 = s3;                                    \\\n   t3 = s1;                                    \\\n   p3 = t0+t2;                                 \\\n   p4 = t1+t3;                                 \\\n   p1 = t0+t3;                                 \\\n   p2 = t1+t2;                                 \\\n   p5 = (p3+p4)*stbi__f2f( 1.175875602f);      \\\n   t0 = t0*stbi__f2f( 0.298631336f);           \\\n   t1 = t1*stbi__f2f( 2.053119869f);           \\\n   t2 = t2*stbi__f2f( 3.072711026f);           \\\n   t3 = t3*stbi__f2f( 1.501321110f);           \\\n   p1 = p5 + p1*stbi__f2f(-0.899976223f);      \\\n   p2 = p5 + p2*stbi__f2f(-2.562915447f);      \\\n   p3 = p3*stbi__f2f(-1.961570560f);           \\\n   p4 = p4*stbi__f2f(-0.390180644f);           \\\n   t3 += p1+p4;                                \\\n   t2 += p2+p3;                                \\\n   t1 += p2+p4;                                \\\n   t0 += p1+p3;\n\nstatic void stbi__idct_block(stbi_uc *out, int out_stride, short data[64])\n{\n   int i,val[64],*v=val;\n   stbi_uc *o;\n   short *d = data;\n\n   // columns\n   for (i=0; i < 8; ++i,++d, ++v) {\n      // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing\n      if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0\n           && d[40]==0 && d[48]==0 && d[56]==0) {\n         //    no shortcut                 0     seconds\n         //    (1|2|3|4|5|6|7)==0          0     seconds\n         //    all separate               -0.047 seconds\n         //    1 && 2|3 && 4|5 && 6|7:    -0.047 seconds\n         int dcterm = d[0]*4;\n         v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm;\n      } else {\n         STBI__IDCT_1D(d[ 
0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56])\n         // constants scaled things up by 1<<12; let's bring them back\n         // down, but keep 2 extra bits of precision\n         x0 += 512; x1 += 512; x2 += 512; x3 += 512;\n         v[ 0] = (x0+t3) >> 10;\n         v[56] = (x0-t3) >> 10;\n         v[ 8] = (x1+t2) >> 10;\n         v[48] = (x1-t2) >> 10;\n         v[16] = (x2+t1) >> 10;\n         v[40] = (x2-t1) >> 10;\n         v[24] = (x3+t0) >> 10;\n         v[32] = (x3-t0) >> 10;\n      }\n   }\n\n   for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) {\n      // no fast case since the first 1D IDCT spread components out\n      STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7])\n      // constants scaled things up by 1<<12, plus we had 1<<2 from first\n      // loop, plus horizontal and vertical each scale by sqrt(8) so together\n      // we've got an extra 1<<3, so 1<<17 total we need to remove.\n      // so we want to round that, which means adding 0.5 * 1<<17,\n      // aka 65536. Also, we'll end up with -128 to 127 that we want\n      // to encode as 0..255 by adding 128, so we'll add that before the shift\n      x0 += 65536 + (128<<17);\n      x1 += 65536 + (128<<17);\n      x2 += 65536 + (128<<17);\n      x3 += 65536 + (128<<17);\n      // tried computing the shifts into temps, or'ing the temps to see\n      // if any were out of range, but that was slower\n      o[0] = stbi__clamp((x0+t3) >> 17);\n      o[7] = stbi__clamp((x0-t3) >> 17);\n      o[1] = stbi__clamp((x1+t2) >> 17);\n      o[6] = stbi__clamp((x1-t2) >> 17);\n      o[2] = stbi__clamp((x2+t1) >> 17);\n      o[5] = stbi__clamp((x2-t1) >> 17);\n      o[3] = stbi__clamp((x3+t0) >> 17);\n      o[4] = stbi__clamp((x3-t0) >> 17);\n   }\n}\n\n#ifdef STBI_SSE2\n// sse2 integer IDCT. 
not the fastest possible implementation but it\n// produces bit-identical results to the generic C version so it's\n// fully \"transparent\".\nstatic void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64])\n{\n   // This is constructed to match our regular (generic) integer IDCT exactly.\n   __m128i row0, row1, row2, row3, row4, row5, row6, row7;\n   __m128i tmp;\n\n   // dot product constant: even elems=x, odd elems=y\n   #define dct_const(x,y)  _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y))\n\n   // out(0) = c0[even]*x + c0[odd]*y   (c0, x, y 16-bit, out 32-bit)\n   // out(1) = c1[even]*x + c1[odd]*y\n   #define dct_rot(out0,out1, x,y,c0,c1) \\\n      __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \\\n      __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \\\n      __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \\\n      __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \\\n      __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \\\n      __m128i out1##_h = _mm_madd_epi16(c0##hi, c1)\n\n   // out = in << 12  (in 16-bit, out 32-bit)\n   #define dct_widen(out, in) \\\n      __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \\\n      __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4)\n\n   // wide add\n   #define dct_wadd(out, a, b) \\\n      __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \\\n      __m128i out##_h = _mm_add_epi32(a##_h, b##_h)\n\n   // wide sub\n   #define dct_wsub(out, a, b) \\\n      __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \\\n      __m128i out##_h = _mm_sub_epi32(a##_h, b##_h)\n\n   // butterfly a/b, add bias, then shift by \"s\" and pack\n   #define dct_bfly32o(out0, out1, a,b,bias,s) \\\n      { \\\n         __m128i abiased_l = _mm_add_epi32(a##_l, bias); \\\n         __m128i abiased_h = _mm_add_epi32(a##_h, bias); \\\n         dct_wadd(sum, abiased, b); \\\n         dct_wsub(dif, abiased, b); \\\n         out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), 
_mm_srai_epi32(sum_h, s)); \\\n         out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \\\n      }\n\n   // 8-bit interleave step (for transposes)\n   #define dct_interleave8(a, b) \\\n      tmp = a; \\\n      a = _mm_unpacklo_epi8(a, b); \\\n      b = _mm_unpackhi_epi8(tmp, b)\n\n   // 16-bit interleave step (for transposes)\n   #define dct_interleave16(a, b) \\\n      tmp = a; \\\n      a = _mm_unpacklo_epi16(a, b); \\\n      b = _mm_unpackhi_epi16(tmp, b)\n\n   #define dct_pass(bias,shift) \\\n      { \\\n         /* even part */ \\\n         dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \\\n         __m128i sum04 = _mm_add_epi16(row0, row4); \\\n         __m128i dif04 = _mm_sub_epi16(row0, row4); \\\n         dct_widen(t0e, sum04); \\\n         dct_widen(t1e, dif04); \\\n         dct_wadd(x0, t0e, t3e); \\\n         dct_wsub(x3, t0e, t3e); \\\n         dct_wadd(x1, t1e, t2e); \\\n         dct_wsub(x2, t1e, t2e); \\\n         /* odd part */ \\\n         dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \\\n         dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \\\n         __m128i sum17 = _mm_add_epi16(row1, row7); \\\n         __m128i sum35 = _mm_add_epi16(row3, row5); \\\n         dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \\\n         dct_wadd(x4, y0o, y4o); \\\n         dct_wadd(x5, y1o, y5o); \\\n         dct_wadd(x6, y2o, y5o); \\\n         dct_wadd(x7, y3o, y4o); \\\n         dct_bfly32o(row0,row7, x0,x7,bias,shift); \\\n         dct_bfly32o(row1,row6, x1,x6,bias,shift); \\\n         dct_bfly32o(row2,row5, x2,x5,bias,shift); \\\n         dct_bfly32o(row3,row4, x3,x4,bias,shift); \\\n      }\n\n   __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f));\n   __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f));\n   __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f));\n   __m128i rot1_1 = 
dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f));\n   __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), stbi__f2f(-1.961570560f));\n   __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f));\n   __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f));\n   __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f));\n\n   // rounding biases in column/row passes, see stbi__idct_block for explanation.\n   __m128i bias_0 = _mm_set1_epi32(512);\n   __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17));\n\n   // load\n   row0 = _mm_load_si128((const __m128i *) (data + 0*8));\n   row1 = _mm_load_si128((const __m128i *) (data + 1*8));\n   row2 = _mm_load_si128((const __m128i *) (data + 2*8));\n   row3 = _mm_load_si128((const __m128i *) (data + 3*8));\n   row4 = _mm_load_si128((const __m128i *) (data + 4*8));\n   row5 = _mm_load_si128((const __m128i *) (data + 5*8));\n   row6 = _mm_load_si128((const __m128i *) (data + 6*8));\n   row7 = _mm_load_si128((const __m128i *) (data + 7*8));\n\n   // column pass\n   dct_pass(bias_0, 10);\n\n   {\n      // 16bit 8x8 transpose pass 1\n      dct_interleave16(row0, row4);\n      dct_interleave16(row1, row5);\n      dct_interleave16(row2, row6);\n      dct_interleave16(row3, row7);\n\n      // transpose pass 2\n      dct_interleave16(row0, row2);\n      dct_interleave16(row1, row3);\n      dct_interleave16(row4, row6);\n      dct_interleave16(row5, row7);\n\n      // transpose pass 3\n      dct_interleave16(row0, row1);\n      dct_interleave16(row2, row3);\n      dct_interleave16(row4, row5);\n      dct_interleave16(row6, row7);\n   }\n\n   // row pass\n   dct_pass(bias_1, 17);\n\n   {\n      // pack\n      __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7\n      __m128i p1 = _mm_packus_epi16(row2, row3);\n 
     __m128i p2 = _mm_packus_epi16(row4, row5);\n      __m128i p3 = _mm_packus_epi16(row6, row7);\n\n      // 8bit 8x8 transpose pass 1\n      dct_interleave8(p0, p2); // a0e0a1e1...\n      dct_interleave8(p1, p3); // c0g0c1g1...\n\n      // transpose pass 2\n      dct_interleave8(p0, p1); // a0c0e0g0...\n      dct_interleave8(p2, p3); // b0d0f0h0...\n\n      // transpose pass 3\n      dct_interleave8(p0, p2); // a0b0c0d0...\n      dct_interleave8(p1, p3); // a4b4c4d4...\n\n      // store\n      _mm_storel_epi64((__m128i *) out, p0); out += out_stride;\n      _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride;\n      _mm_storel_epi64((__m128i *) out, p2); out += out_stride;\n      _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride;\n      _mm_storel_epi64((__m128i *) out, p1); out += out_stride;\n      _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride;\n      _mm_storel_epi64((__m128i *) out, p3); out += out_stride;\n      _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e));\n   }\n\n#undef dct_const\n#undef dct_rot\n#undef dct_widen\n#undef dct_wadd\n#undef dct_wsub\n#undef dct_bfly32o\n#undef dct_interleave8\n#undef dct_interleave16\n#undef dct_pass\n}\n\n#endif // STBI_SSE2\n\n#ifdef STBI_NEON\n\n// NEON integer IDCT. 
should produce bit-identical\n// results to the generic C version.\nstatic void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64])\n{\n   int16x8_t row0, row1, row2, row3, row4, row5, row6, row7;\n\n   int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f));\n   int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f));\n   int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f));\n   int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f));\n   int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f));\n   int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f));\n   int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f));\n   int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f));\n   int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f));\n   int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f));\n   int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f));\n   int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f));\n\n#define dct_long_mul(out, inq, coeff) \\\n   int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \\\n   int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff)\n\n#define dct_long_mac(out, acc, inq, coeff) \\\n   int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \\\n   int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff)\n\n#define dct_widen(out, inq) \\\n   int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \\\n   int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12)\n\n// wide add\n#define dct_wadd(out, a, b) \\\n   int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \\\n   int32x4_t out##_h = vaddq_s32(a##_h, b##_h)\n\n// wide sub\n#define dct_wsub(out, a, b) \\\n   int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \\\n   int32x4_t out##_h = vsubq_s32(a##_h, b##_h)\n\n// butterfly a/b, then shift using \"shiftop\" by \"s\" and pack\n#define dct_bfly32o(out0,out1, a,b,shiftop,s) \\\n   { \\\n      dct_wadd(sum, a, b); \\\n      dct_wsub(dif, a, b); \\\n      out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, 
s)); \\\n      out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \\\n   }\n\n#define dct_pass(shiftop, shift) \\\n   { \\\n      /* even part */ \\\n      int16x8_t sum26 = vaddq_s16(row2, row6); \\\n      dct_long_mul(p1e, sum26, rot0_0); \\\n      dct_long_mac(t2e, p1e, row6, rot0_1); \\\n      dct_long_mac(t3e, p1e, row2, rot0_2); \\\n      int16x8_t sum04 = vaddq_s16(row0, row4); \\\n      int16x8_t dif04 = vsubq_s16(row0, row4); \\\n      dct_widen(t0e, sum04); \\\n      dct_widen(t1e, dif04); \\\n      dct_wadd(x0, t0e, t3e); \\\n      dct_wsub(x3, t0e, t3e); \\\n      dct_wadd(x1, t1e, t2e); \\\n      dct_wsub(x2, t1e, t2e); \\\n      /* odd part */ \\\n      int16x8_t sum15 = vaddq_s16(row1, row5); \\\n      int16x8_t sum17 = vaddq_s16(row1, row7); \\\n      int16x8_t sum35 = vaddq_s16(row3, row5); \\\n      int16x8_t sum37 = vaddq_s16(row3, row7); \\\n      int16x8_t sumodd = vaddq_s16(sum17, sum35); \\\n      dct_long_mul(p5o, sumodd, rot1_0); \\\n      dct_long_mac(p1o, p5o, sum17, rot1_1); \\\n      dct_long_mac(p2o, p5o, sum35, rot1_2); \\\n      dct_long_mul(p3o, sum37, rot2_0); \\\n      dct_long_mul(p4o, sum15, rot2_1); \\\n      dct_wadd(sump13o, p1o, p3o); \\\n      dct_wadd(sump24o, p2o, p4o); \\\n      dct_wadd(sump23o, p2o, p3o); \\\n      dct_wadd(sump14o, p1o, p4o); \\\n      dct_long_mac(x4, sump13o, row7, rot3_0); \\\n      dct_long_mac(x5, sump24o, row5, rot3_1); \\\n      dct_long_mac(x6, sump23o, row3, rot3_2); \\\n      dct_long_mac(x7, sump14o, row1, rot3_3); \\\n      dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \\\n      dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \\\n      dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \\\n      dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \\\n   }\n\n   // load\n   row0 = vld1q_s16(data + 0*8);\n   row1 = vld1q_s16(data + 1*8);\n   row2 = vld1q_s16(data + 2*8);\n   row3 = vld1q_s16(data + 3*8);\n   row4 = vld1q_s16(data + 4*8);\n   row5 = vld1q_s16(data + 5*8);\n   row6 = vld1q_s16(data + 
6*8);\n   row7 = vld1q_s16(data + 7*8);\n\n   // add DC bias\n   row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0));\n\n   // column pass\n   dct_pass(vrshrn_n_s32, 10);\n\n   // 16bit 8x8 transpose\n   {\n// these three map to a single VTRN.16, VTRN.32, and VSWP, respectively.\n// whether compilers actually get this is another story, sadly.\n#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; }\n#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); }\n#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); }\n\n      // pass 1\n      dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6\n      dct_trn16(row2, row3);\n      dct_trn16(row4, row5);\n      dct_trn16(row6, row7);\n\n      // pass 2\n      dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4\n      dct_trn32(row1, row3);\n      dct_trn32(row4, row6);\n      dct_trn32(row5, row7);\n\n      // pass 3\n      dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0\n      dct_trn64(row1, row5);\n      dct_trn64(row2, row6);\n      dct_trn64(row3, row7);\n\n#undef dct_trn16\n#undef dct_trn32\n#undef dct_trn64\n   }\n\n   // row pass\n   // vrshrn_n_s32 only supports shifts up to 16, we need\n   // 17. 
so do a non-rounding shift of 16 first then follow\n   // up with a rounding shift by 1.\n   dct_pass(vshrn_n_s32, 16);\n\n   {\n      // pack and round\n      uint8x8_t p0 = vqrshrun_n_s16(row0, 1);\n      uint8x8_t p1 = vqrshrun_n_s16(row1, 1);\n      uint8x8_t p2 = vqrshrun_n_s16(row2, 1);\n      uint8x8_t p3 = vqrshrun_n_s16(row3, 1);\n      uint8x8_t p4 = vqrshrun_n_s16(row4, 1);\n      uint8x8_t p5 = vqrshrun_n_s16(row5, 1);\n      uint8x8_t p6 = vqrshrun_n_s16(row6, 1);\n      uint8x8_t p7 = vqrshrun_n_s16(row7, 1);\n\n      // again, these can translate into one instruction, but often don't.\n#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; }\n#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); }\n#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); }\n\n      // sadly can't use interleaved stores here since we only write\n      // 8 bytes to each scan line!\n\n      // 8x8 8-bit transpose pass 1\n      dct_trn8_8(p0, p1);\n      dct_trn8_8(p2, p3);\n      dct_trn8_8(p4, p5);\n      dct_trn8_8(p6, p7);\n\n      // pass 2\n      dct_trn8_16(p0, p2);\n      dct_trn8_16(p1, p3);\n      dct_trn8_16(p4, p6);\n      dct_trn8_16(p5, p7);\n\n      // pass 3\n      dct_trn8_32(p0, p4);\n      dct_trn8_32(p1, p5);\n      dct_trn8_32(p2, p6);\n      dct_trn8_32(p3, p7);\n\n      // store\n      vst1_u8(out, p0); out += out_stride;\n      vst1_u8(out, p1); out += out_stride;\n      vst1_u8(out, p2); out += out_stride;\n      vst1_u8(out, p3); out += out_stride;\n      vst1_u8(out, p4); out += out_stride;\n      vst1_u8(out, p5); out += out_stride;\n      vst1_u8(out, p6); out += out_stride;\n      vst1_u8(out, p7);\n\n#undef dct_trn8_8\n#undef dct_trn8_16\n#undef dct_trn8_32\n   }\n\n#undef 
dct_long_mul\n#undef dct_long_mac\n#undef dct_widen\n#undef dct_wadd\n#undef dct_wsub\n#undef dct_bfly32o\n#undef dct_pass\n}\n\n#endif // STBI_NEON\n\n#define STBI__MARKER_none  0xff\n// if there's a pending marker from the entropy stream, return that\n// otherwise, fetch from the stream and get a marker. if there's no\n// marker, return 0xff, which is never a valid marker value\nstatic stbi_uc stbi__get_marker(stbi__jpeg *j)\n{\n   stbi_uc x;\n   if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; }\n   x = stbi__get8(j->s);\n   if (x != 0xff) return STBI__MARKER_none;\n   while (x == 0xff)\n      x = stbi__get8(j->s); // consume repeated 0xff fill bytes\n   return x;\n}\n\n// in each scan, we'll have scan_n components, and the order\n// of the components is specified by order[]\n#define STBI__RESTART(x)     ((x) >= 0xd0 && (x) <= 0xd7)\n\n// after a restart interval, stbi__jpeg_reset the entropy decoder and\n// the dc prediction\nstatic void stbi__jpeg_reset(stbi__jpeg *j)\n{\n   j->code_bits = 0;\n   j->code_buffer = 0;\n   j->nomore = 0;\n   j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0;\n   j->marker = STBI__MARKER_none;\n   j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff;\n   j->eob_run = 0;\n   // no more than 1<<31 MCUs if no restart_interal? 
that's plenty safe,\n   // since we don't even allow 1<<30 pixels\n}\n\nstatic int stbi__parse_entropy_coded_data(stbi__jpeg *z)\n{\n   stbi__jpeg_reset(z);\n   if (!z->progressive) {\n      if (z->scan_n == 1) {\n         int i,j;\n         STBI_SIMD_ALIGN(short, data[64]);\n         int n = z->order[0];\n         // non-interleaved data, we just need to process one block at a time,\n         // in trivial scanline order\n         // number of blocks to do just depends on how many actual \"pixels\" this\n         // component has, independent of interleaved MCU blocking and such\n         int w = (z->img_comp[n].x+7) >> 3;\n         int h = (z->img_comp[n].y+7) >> 3;\n         for (j=0; j < h; ++j) {\n            for (i=0; i < w; ++i) {\n               int ha = z->img_comp[n].ha;\n               if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0;\n               z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data);\n               // every data block is an MCU, so countdown the restart interval\n               if (--z->todo <= 0) {\n                  if (z->code_bits < 24) stbi__grow_buffer_unsafe(z);\n                  // if it's NOT a restart, then just bail, so we get corrupt data\n                  // rather than no data\n                  if (!STBI__RESTART(z->marker)) return 1;\n                  stbi__jpeg_reset(z);\n               }\n            }\n         }\n         return 1;\n      } else { // interleaved\n         int i,j,k,x,y;\n         STBI_SIMD_ALIGN(short, data[64]);\n         for (j=0; j < z->img_mcu_y; ++j) {\n            for (i=0; i < z->img_mcu_x; ++i) {\n               // scan an interleaved mcu... 
process scan_n components in order\n               for (k=0; k < z->scan_n; ++k) {\n                  int n = z->order[k];\n                  // scan out an mcu's worth of this component; that's just determined\n                  // by the basic H and V specified for the component\n                  for (y=0; y < z->img_comp[n].v; ++y) {\n                     for (x=0; x < z->img_comp[n].h; ++x) {\n                        int x2 = (i*z->img_comp[n].h + x)*8;\n                        int y2 = (j*z->img_comp[n].v + y)*8;\n                        int ha = z->img_comp[n].ha;\n                        if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0;\n                        z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data);\n                     }\n                  }\n               }\n               // after all interleaved components, that's an interleaved MCU,\n               // so now count down the restart interval\n               if (--z->todo <= 0) {\n                  if (z->code_bits < 24) stbi__grow_buffer_unsafe(z);\n                  if (!STBI__RESTART(z->marker)) return 1;\n                  stbi__jpeg_reset(z);\n               }\n            }\n         }\n         return 1;\n      }\n   } else {\n      if (z->scan_n == 1) {\n         int i,j;\n         int n = z->order[0];\n         // non-interleaved data, we just need to process one block at a time,\n         // in trivial scanline order\n         // number of blocks to do just depends on how many actual \"pixels\" this\n         // component has, independent of interleaved MCU blocking and such\n         int w = (z->img_comp[n].x+7) >> 3;\n         int h = (z->img_comp[n].y+7) >> 3;\n         for (j=0; j < h; ++j) {\n            for (i=0; i < w; ++i) {\n               short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w);\n               if 
(z->spec_start == 0) {\n                  if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n))\n                     return 0;\n               } else {\n                  int ha = z->img_comp[n].ha;\n                  if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha]))\n                     return 0;\n               }\n               // every data block is an MCU, so countdown the restart interval\n               if (--z->todo <= 0) {\n                  if (z->code_bits < 24) stbi__grow_buffer_unsafe(z);\n                  if (!STBI__RESTART(z->marker)) return 1;\n                  stbi__jpeg_reset(z);\n               }\n            }\n         }\n         return 1;\n      } else { // interleaved\n         int i,j,k,x,y;\n         for (j=0; j < z->img_mcu_y; ++j) {\n            for (i=0; i < z->img_mcu_x; ++i) {\n               // scan an interleaved mcu... process scan_n components in order\n               for (k=0; k < z->scan_n; ++k) {\n                  int n = z->order[k];\n                  // scan out an mcu's worth of this component; that's just determined\n                  // by the basic H and V specified for the component\n                  for (y=0; y < z->img_comp[n].v; ++y) {\n                     for (x=0; x < z->img_comp[n].h; ++x) {\n                        int x2 = (i*z->img_comp[n].h + x);\n                        int y2 = (j*z->img_comp[n].v + y);\n                        short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w);\n                        if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n))\n                           return 0;\n                     }\n                  }\n               }\n               // after all interleaved components, that's an interleaved MCU,\n               // so now count down the restart interval\n               if (--z->todo <= 0) {\n                  if (z->code_bits < 24) 
stbi__grow_buffer_unsafe(z);\n                  if (!STBI__RESTART(z->marker)) return 1;\n                  stbi__jpeg_reset(z);\n               }\n            }\n         }\n         return 1;\n      }\n   }\n}\n\nstatic void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant)\n{\n   int i;\n   for (i=0; i < 64; ++i)\n      data[i] *= dequant[i];\n}\n\nstatic void stbi__jpeg_finish(stbi__jpeg *z)\n{\n   if (z->progressive) {\n      // dequantize and idct the data\n      int i,j,n;\n      for (n=0; n < z->s->img_n; ++n) {\n         int w = (z->img_comp[n].x+7) >> 3;\n         int h = (z->img_comp[n].y+7) >> 3;\n         for (j=0; j < h; ++j) {\n            for (i=0; i < w; ++i) {\n               short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w);\n               stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]);\n               z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data);\n            }\n         }\n      }\n   }\n}\n\nstatic int stbi__process_marker(stbi__jpeg *z, int m)\n{\n   int L;\n   switch (m) {\n      case STBI__MARKER_none: // no marker found\n         return stbi__err(\"expected marker\",\"Corrupt JPEG\");\n\n      case 0xDD: // DRI - specify restart interval\n         if (stbi__get16be(z->s) != 4) return stbi__err(\"bad DRI len\",\"Corrupt JPEG\");\n         z->restart_interval = stbi__get16be(z->s);\n         return 1;\n\n      case 0xDB: // DQT - define quantization table\n         L = stbi__get16be(z->s)-2;\n         while (L > 0) {\n            int q = stbi__get8(z->s);\n            int p = q >> 4, sixteen = (p != 0);\n            int t = q & 15,i;\n            if (p != 0 && p != 1) return stbi__err(\"bad DQT type\",\"Corrupt JPEG\");\n            if (t > 3) return stbi__err(\"bad DQT table\",\"Corrupt JPEG\");\n\n            for (i=0; i < 64; ++i)\n               z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? 
stbi__get16be(z->s) : stbi__get8(z->s));\n            L -= (sixteen ? 129 : 65);\n         }\n         return L==0;\n\n      case 0xC4: // DHT - define huffman table\n         L = stbi__get16be(z->s)-2;\n         while (L > 0) {\n            stbi_uc *v;\n            int sizes[16],i,n=0;\n            int q = stbi__get8(z->s);\n            int tc = q >> 4;\n            int th = q & 15;\n            if (tc > 1 || th > 3) return stbi__err(\"bad DHT header\",\"Corrupt JPEG\");\n            for (i=0; i < 16; ++i) {\n               sizes[i] = stbi__get8(z->s);\n               n += sizes[i];\n            }\n            if(n > 256) return stbi__err(\"bad DHT header\",\"Corrupt JPEG\"); // Loop over i < n would write past end of values!\n            L -= 17;\n            if (tc == 0) {\n               if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0;\n               v = z->huff_dc[th].values;\n            } else {\n               if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0;\n               v = z->huff_ac[th].values;\n            }\n            for (i=0; i < n; ++i)\n               v[i] = stbi__get8(z->s);\n            if (tc != 0)\n               stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th);\n            L -= n;\n         }\n         return L==0;\n   }\n\n   // check for comment block or APP blocks\n   if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) {\n      L = stbi__get16be(z->s);\n      if (L < 2) {\n         if (m == 0xFE)\n            return stbi__err(\"bad COM len\",\"Corrupt JPEG\");\n         else\n            return stbi__err(\"bad APP len\",\"Corrupt JPEG\");\n      }\n      L -= 2;\n\n      if (m == 0xE0 && L >= 5) { // JFIF APP0 segment\n         static const unsigned char tag[5] = {'J','F','I','F','\\0'};\n         int ok = 1;\n         int i;\n         for (i=0; i < 5; ++i)\n            if (stbi__get8(z->s) != tag[i])\n               ok = 0;\n         L -= 5;\n         if (ok)\n            z->jfif = 1;\n      } else if (m == 0xEE && L >= 
12) { // Adobe APP14 segment\n         static const unsigned char tag[6] = {'A','d','o','b','e','\\0'};\n         int ok = 1;\n         int i;\n         for (i=0; i < 6; ++i)\n            if (stbi__get8(z->s) != tag[i])\n               ok = 0;\n         L -= 6;\n         if (ok) {\n            stbi__get8(z->s); // version\n            stbi__get16be(z->s); // flags0\n            stbi__get16be(z->s); // flags1\n            z->app14_color_transform = stbi__get8(z->s); // color transform\n            L -= 6;\n         }\n      }\n\n      stbi__skip(z->s, L);\n      return 1;\n   }\n\n   return stbi__err(\"unknown marker\",\"Corrupt JPEG\");\n}\n\n// after we see SOS\nstatic int stbi__process_scan_header(stbi__jpeg *z)\n{\n   int i;\n   int Ls = stbi__get16be(z->s);\n   z->scan_n = stbi__get8(z->s);\n   if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err(\"bad SOS component count\",\"Corrupt JPEG\");\n   if (Ls != 6+2*z->scan_n) return stbi__err(\"bad SOS len\",\"Corrupt JPEG\");\n   for (i=0; i < z->scan_n; ++i) {\n      int id = stbi__get8(z->s), which;\n      int q = stbi__get8(z->s);\n      for (which = 0; which < z->s->img_n; ++which)\n         if (z->img_comp[which].id == id)\n            break;\n      if (which == z->s->img_n) return 0; // no match\n      z->img_comp[which].hd = q >> 4;   if (z->img_comp[which].hd > 3) return stbi__err(\"bad DC huff\",\"Corrupt JPEG\");\n      z->img_comp[which].ha = q & 15;   if (z->img_comp[which].ha > 3) return stbi__err(\"bad AC huff\",\"Corrupt JPEG\");\n      z->order[i] = which;\n   }\n\n   {\n      int aa;\n      z->spec_start = stbi__get8(z->s);\n      z->spec_end   = stbi__get8(z->s); // should be 63, but might be 0\n      aa = stbi__get8(z->s);\n      z->succ_high = (aa >> 4);\n      z->succ_low  = (aa & 15);\n      if (z->progressive) {\n         if (z->spec_start > 63 || z->spec_end > 63  || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13)\n            return 
stbi__err(\"bad SOS\", \"Corrupt JPEG\");\n      } else {\n         if (z->spec_start != 0) return stbi__err(\"bad SOS\",\"Corrupt JPEG\");\n         if (z->succ_high != 0 || z->succ_low != 0) return stbi__err(\"bad SOS\",\"Corrupt JPEG\");\n         z->spec_end = 63;\n      }\n   }\n\n   return 1;\n}\n\nstatic int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why)\n{\n   int i;\n   for (i=0; i < ncomp; ++i) {\n      if (z->img_comp[i].raw_data) {\n         STBI_FREE(z->img_comp[i].raw_data);\n         z->img_comp[i].raw_data = NULL;\n         z->img_comp[i].data = NULL;\n      }\n      if (z->img_comp[i].raw_coeff) {\n         STBI_FREE(z->img_comp[i].raw_coeff);\n         z->img_comp[i].raw_coeff = 0;\n         z->img_comp[i].coeff = 0;\n      }\n      if (z->img_comp[i].linebuf) {\n         STBI_FREE(z->img_comp[i].linebuf);\n         z->img_comp[i].linebuf = NULL;\n      }\n   }\n   return why;\n}\n\nstatic int stbi__process_frame_header(stbi__jpeg *z, int scan)\n{\n   stbi__context *s = z->s;\n   int Lf,p,i,q, h_max=1,v_max=1,c;\n   Lf = stbi__get16be(s);         if (Lf < 11) return stbi__err(\"bad SOF len\",\"Corrupt JPEG\"); // JPEG\n   p  = stbi__get8(s);            if (p != 8) return stbi__err(\"only 8-bit\",\"JPEG format not supported: 8-bit only\"); // JPEG baseline\n   s->img_y = stbi__get16be(s);   if (s->img_y == 0) return stbi__err(\"no header height\", \"JPEG format not supported: delayed height\"); // Legal, but we don't handle it--but neither does IJG\n   s->img_x = stbi__get16be(s);   if (s->img_x == 0) return stbi__err(\"0 width\",\"Corrupt JPEG\"); // JPEG requires\n   if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err(\"too large\",\"Very large image (corrupt?)\");\n   if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err(\"too large\",\"Very large image (corrupt?)\");\n   c = stbi__get8(s);\n   if (c != 3 && c != 1 && c != 4) return stbi__err(\"bad component count\",\"Corrupt JPEG\");\n   s->img_n = c;\n   for (i=0; i < c; ++i) 
{\n      z->img_comp[i].data = NULL;\n      z->img_comp[i].linebuf = NULL;\n   }\n\n   if (Lf != 8+3*s->img_n) return stbi__err(\"bad SOF len\",\"Corrupt JPEG\");\n\n   z->rgb = 0;\n   for (i=0; i < s->img_n; ++i) {\n      static const unsigned char rgb[3] = { 'R', 'G', 'B' };\n      z->img_comp[i].id = stbi__get8(s);\n      if (s->img_n == 3 && z->img_comp[i].id == rgb[i])\n         ++z->rgb;\n      q = stbi__get8(s);\n      z->img_comp[i].h = (q >> 4);  if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err(\"bad H\",\"Corrupt JPEG\");\n      z->img_comp[i].v = q & 15;    if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err(\"bad V\",\"Corrupt JPEG\");\n      z->img_comp[i].tq = stbi__get8(s);  if (z->img_comp[i].tq > 3) return stbi__err(\"bad TQ\",\"Corrupt JPEG\");\n   }\n\n   if (scan != STBI__SCAN_load) return 1;\n\n   if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err(\"too large\", \"Image too large to decode\");\n\n   for (i=0; i < s->img_n; ++i) {\n      if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h;\n      if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v;\n   }\n\n   // check that plane subsampling factors are integer ratios; our resamplers can't deal with fractional ratios\n   // and I've never seen a non-corrupted JPEG file actually use them\n   for (i=0; i < s->img_n; ++i) {\n      if (h_max % z->img_comp[i].h != 0) return stbi__err(\"bad H\",\"Corrupt JPEG\");\n      if (v_max % z->img_comp[i].v != 0) return stbi__err(\"bad V\",\"Corrupt JPEG\");\n   }\n\n   // compute interleaved mcu info\n   z->img_h_max = h_max;\n   z->img_v_max = v_max;\n   z->img_mcu_w = h_max * 8;\n   z->img_mcu_h = v_max * 8;\n   // these sizes can't be more than 17 bits\n   z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w;\n   z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h;\n\n   for (i=0; i < s->img_n; ++i) {\n      // number of effective pixels (e.g. 
for non-interleaved MCU)\n      z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max;\n      z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max;\n      // to simplify generation, we'll allocate enough memory to decode\n      // the bogus oversized data from using interleaved MCUs and their\n      // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't\n      // discard the extra data until colorspace conversion\n      //\n      // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier)\n      // so these muls can't overflow with 32-bit ints (which we require)\n      z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8;\n      z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8;\n      z->img_comp[i].coeff = 0;\n      z->img_comp[i].raw_coeff = 0;\n      z->img_comp[i].linebuf = NULL;\n      z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15);\n      if (z->img_comp[i].raw_data == NULL)\n         return stbi__free_jpeg_components(z, i+1, stbi__err(\"outofmem\", \"Out of memory\"));\n      // align blocks for idct using mmx/sse\n      z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15);\n      if (z->progressive) {\n         // w2, h2 are multiples of 8 (see above)\n         z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8;\n         z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8;\n         z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15);\n         if (z->img_comp[i].raw_coeff == NULL)\n            return stbi__free_jpeg_components(z, i+1, stbi__err(\"outofmem\", \"Out of memory\"));\n         z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15);\n      }\n   }\n\n   return 1;\n}\n\n// use comparisons since in some cases we handle more than one case (e.g. 
SOF)\n#define stbi__DNL(x)         ((x) == 0xdc)\n#define stbi__SOI(x)         ((x) == 0xd8)\n#define stbi__EOI(x)         ((x) == 0xd9)\n#define stbi__SOF(x)         ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2)\n#define stbi__SOS(x)         ((x) == 0xda)\n\n#define stbi__SOF_progressive(x)   ((x) == 0xc2)\n\nstatic int stbi__decode_jpeg_header(stbi__jpeg *z, int scan)\n{\n   int m;\n   z->jfif = 0;\n   z->app14_color_transform = -1; // valid values are 0,1,2\n   z->marker = STBI__MARKER_none; // initialize cached marker to empty\n   m = stbi__get_marker(z);\n   if (!stbi__SOI(m)) return stbi__err(\"no SOI\",\"Corrupt JPEG\");\n   if (scan == STBI__SCAN_type) return 1;\n   m = stbi__get_marker(z);\n   while (!stbi__SOF(m)) {\n      if (!stbi__process_marker(z,m)) return 0;\n      m = stbi__get_marker(z);\n      while (m == STBI__MARKER_none) {\n         // some files have extra padding after their blocks, so ok, we'll scan\n         if (stbi__at_eof(z->s)) return stbi__err(\"no SOF\", \"Corrupt JPEG\");\n         m = stbi__get_marker(z);\n      }\n   }\n   z->progressive = stbi__SOF_progressive(m);\n   if (!stbi__process_frame_header(z, scan)) return 0;\n   return 1;\n}\n\nstatic int stbi__skip_jpeg_junk_at_end(stbi__jpeg *j)\n{\n   // some JPEGs have junk at end, skip over it but if we find what looks\n   // like a valid marker, resume there\n   while (!stbi__at_eof(j->s)) {\n      int x = stbi__get8(j->s);\n      while (x == 255) { // might be a marker\n         if (stbi__at_eof(j->s)) return STBI__MARKER_none;\n         x = stbi__get8(j->s);\n         if (x != 0x00 && x != 0xff) {\n            // not a stuffed zero or lead-in to another marker, looks\n            // like an actual marker, return it\n            return x;\n         }\n         // stuffed zero has x=0 now which ends the loop, meaning we go\n         // back to regular scan loop.\n         // repeated 0xff keeps trying to read the next byte of the marker.\n      }\n   }\n   return 
STBI__MARKER_none;\n}\n\n// decode image to YCbCr format\nstatic int stbi__decode_jpeg_image(stbi__jpeg *j)\n{\n   int m;\n   for (m = 0; m < 4; m++) {\n      j->img_comp[m].raw_data = NULL;\n      j->img_comp[m].raw_coeff = NULL;\n   }\n   j->restart_interval = 0;\n   if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0;\n   m = stbi__get_marker(j);\n   while (!stbi__EOI(m)) {\n      if (stbi__SOS(m)) {\n         if (!stbi__process_scan_header(j)) return 0;\n         if (!stbi__parse_entropy_coded_data(j)) return 0;\n         if (j->marker == STBI__MARKER_none ) {\n         j->marker = stbi__skip_jpeg_junk_at_end(j);\n            // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0\n         }\n         m = stbi__get_marker(j);\n         if (STBI__RESTART(m))\n            m = stbi__get_marker(j);\n      } else if (stbi__DNL(m)) {\n         int Ld = stbi__get16be(j->s);\n         stbi__uint32 NL = stbi__get16be(j->s);\n         if (Ld != 4) return stbi__err(\"bad DNL len\", \"Corrupt JPEG\");\n         if (NL != j->s->img_y) return stbi__err(\"bad DNL height\", \"Corrupt JPEG\");\n         m = stbi__get_marker(j);\n      } else {\n         if (!stbi__process_marker(j, m)) return 1;\n         m = stbi__get_marker(j);\n      }\n   }\n   if (j->progressive)\n      stbi__jpeg_finish(j);\n   return 1;\n}\n\n// static jfif-centered resampling (across block boundaries)\n\ntypedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1,\n                                    int w, int hs);\n\n#define stbi__div4(x) ((stbi_uc) ((x) >> 2))\n\nstatic stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs)\n{\n   STBI_NOTUSED(out);\n   STBI_NOTUSED(in_far);\n   STBI_NOTUSED(w);\n   STBI_NOTUSED(hs);\n   return in_near;\n}\n\nstatic stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs)\n{\n   // need to generate two samples 
vertically for every one in input\n   int i;\n   STBI_NOTUSED(hs);\n   for (i=0; i < w; ++i)\n      out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2);\n   return out;\n}\n\nstatic stbi_uc*  stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs)\n{\n   // need to generate two samples horizontally for every one in input\n   int i;\n   stbi_uc *input = in_near;\n\n   if (w == 1) {\n      // if only one sample, can't do any interpolation\n      out[0] = out[1] = input[0];\n      return out;\n   }\n\n   out[0] = input[0];\n   out[1] = stbi__div4(input[0]*3 + input[1] + 2);\n   for (i=1; i < w-1; ++i) {\n      int n = 3*input[i]+2;\n      out[i*2+0] = stbi__div4(n+input[i-1]);\n      out[i*2+1] = stbi__div4(n+input[i+1]);\n   }\n   out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2);\n   out[i*2+1] = input[w-1];\n\n   STBI_NOTUSED(in_far);\n   STBI_NOTUSED(hs);\n\n   return out;\n}\n\n#define stbi__div16(x) ((stbi_uc) ((x) >> 4))\n\nstatic stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs)\n{\n   // need to generate 2x2 samples for every one in input\n   int i,t0,t1;\n   if (w == 1) {\n      out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2);\n      return out;\n   }\n\n   t1 = 3*in_near[0] + in_far[0];\n   out[0] = stbi__div4(t1+2);\n   for (i=1; i < w; ++i) {\n      t0 = t1;\n      t1 = 3*in_near[i]+in_far[i];\n      out[i*2-1] = stbi__div16(3*t0 + t1 + 8);\n      out[i*2  ] = stbi__div16(3*t1 + t0 + 8);\n   }\n   out[w*2-1] = stbi__div4(t1+2);\n\n   STBI_NOTUSED(hs);\n\n   return out;\n}\n\n#if defined(STBI_SSE2) || defined(STBI_NEON)\nstatic stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs)\n{\n   // need to generate 2x2 samples for every one in input\n   int i=0,t0,t1;\n\n   if (w == 1) {\n      out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2);\n      return out;\n   }\n\n   t1 = 3*in_near[0] + in_far[0];\n   // 
process groups of 8 pixels for as long as we can.\n   // note we can't handle the last pixel in a row in this loop\n   // because we need to handle the filter boundary conditions.\n   for (; i < ((w-1) & ~7); i += 8) {\n#if defined(STBI_SSE2)\n      // load and perform the vertical filtering pass\n      // this uses 3*x + y = 4*x + (y - x)\n      __m128i zero  = _mm_setzero_si128();\n      __m128i farb  = _mm_loadl_epi64((__m128i *) (in_far + i));\n      __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i));\n      __m128i farw  = _mm_unpacklo_epi8(farb, zero);\n      __m128i nearw = _mm_unpacklo_epi8(nearb, zero);\n      __m128i diff  = _mm_sub_epi16(farw, nearw);\n      __m128i nears = _mm_slli_epi16(nearw, 2);\n      __m128i curr  = _mm_add_epi16(nears, diff); // current row\n\n      // horizontal filter works the same based on shifted vers of current\n      // row. \"prev\" is current row shifted right by 1 pixel; we need to\n      // insert the previous pixel value (from t1).\n      // \"next\" is current row shifted left by 1 pixel, with first pixel\n      // of next block of 8 pixels added in.\n      __m128i prv0 = _mm_slli_si128(curr, 2);\n      __m128i nxt0 = _mm_srli_si128(curr, 2);\n      __m128i prev = _mm_insert_epi16(prv0, t1, 0);\n      __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7);\n\n      // horizontal filter, polyphase implementation since it's convenient:\n      // even pixels = 3*cur + prev = cur*4 + (prev - cur)\n      // odd  pixels = 3*cur + next = cur*4 + (next - cur)\n      // note the shared term.\n      __m128i bias  = _mm_set1_epi16(8);\n      __m128i curs = _mm_slli_epi16(curr, 2);\n      __m128i prvd = _mm_sub_epi16(prev, curr);\n      __m128i nxtd = _mm_sub_epi16(next, curr);\n      __m128i curb = _mm_add_epi16(curs, bias);\n      __m128i even = _mm_add_epi16(prvd, curb);\n      __m128i odd  = _mm_add_epi16(nxtd, curb);\n\n      // interleave even and odd pixels, then undo scaling.\n      __m128i int0 = 
_mm_unpacklo_epi16(even, odd);\n      __m128i int1 = _mm_unpackhi_epi16(even, odd);\n      __m128i de0  = _mm_srli_epi16(int0, 4);\n      __m128i de1  = _mm_srli_epi16(int1, 4);\n\n      // pack and write output\n      __m128i outv = _mm_packus_epi16(de0, de1);\n      _mm_storeu_si128((__m128i *) (out + i*2), outv);\n#elif defined(STBI_NEON)\n      // load and perform the vertical filtering pass\n      // this uses 3*x + y = 4*x + (y - x)\n      uint8x8_t farb  = vld1_u8(in_far + i);\n      uint8x8_t nearb = vld1_u8(in_near + i);\n      int16x8_t diff  = vreinterpretq_s16_u16(vsubl_u8(farb, nearb));\n      int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2));\n      int16x8_t curr  = vaddq_s16(nears, diff); // current row\n\n      // horizontal filter works the same based on shifted vers of current\n      // row. \"prev\" is current row shifted right by 1 pixel; we need to\n      // insert the previous pixel value (from t1).\n      // \"next\" is current row shifted left by 1 pixel, with first pixel\n      // of next block of 8 pixels added in.\n      int16x8_t prv0 = vextq_s16(curr, curr, 7);\n      int16x8_t nxt0 = vextq_s16(curr, curr, 1);\n      int16x8_t prev = vsetq_lane_s16(t1, prv0, 0);\n      int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7);\n\n      // horizontal filter, polyphase implementation since it's convenient:\n      // even pixels = 3*cur + prev = cur*4 + (prev - cur)\n      // odd  pixels = 3*cur + next = cur*4 + (next - cur)\n      // note the shared term.\n      int16x8_t curs = vshlq_n_s16(curr, 2);\n      int16x8_t prvd = vsubq_s16(prev, curr);\n      int16x8_t nxtd = vsubq_s16(next, curr);\n      int16x8_t even = vaddq_s16(curs, prvd);\n      int16x8_t odd  = vaddq_s16(curs, nxtd);\n\n      // undo scaling and round, then store with even/odd phases interleaved\n      uint8x8x2_t o;\n      o.val[0] = vqrshrun_n_s16(even, 4);\n      o.val[1] = vqrshrun_n_s16(odd,  4);\n      vst2_u8(out + i*2, o);\n#endif\n\n    
  // \"previous\" value for next iter\n      t1 = 3*in_near[i+7] + in_far[i+7];\n   }\n\n   t0 = t1;\n   t1 = 3*in_near[i] + in_far[i];\n   out[i*2] = stbi__div16(3*t1 + t0 + 8);\n\n   for (++i; i < w; ++i) {\n      t0 = t1;\n      t1 = 3*in_near[i]+in_far[i];\n      out[i*2-1] = stbi__div16(3*t0 + t1 + 8);\n      out[i*2  ] = stbi__div16(3*t1 + t0 + 8);\n   }\n   out[w*2-1] = stbi__div4(t1+2);\n\n   STBI_NOTUSED(hs);\n\n   return out;\n}\n#endif\n\nstatic stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs)\n{\n   // resample with nearest-neighbor\n   int i,j;\n   STBI_NOTUSED(in_far);\n   for (i=0; i < w; ++i)\n      for (j=0; j < hs; ++j)\n         out[i*hs+j] = in_near[i];\n   return out;\n}\n\n// this is a reduced-precision calculation of YCbCr-to-RGB introduced\n// to make sure the code produces the same results in both SIMD and scalar\n#define stbi__float2fixed(x)  (((int) ((x) * 4096.0f + 0.5f)) << 8)\nstatic void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step)\n{\n   int i;\n   for (i=0; i < count; ++i) {\n      int y_fixed = (y[i] << 20) + (1<<19); // rounding\n      int r,g,b;\n      int cr = pcr[i] - 128;\n      int cb = pcb[i] - 128;\n      r = y_fixed +  cr* stbi__float2fixed(1.40200f);\n      g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000);\n      b = y_fixed                                     +   cb* stbi__float2fixed(1.77200f);\n      r >>= 20;\n      g >>= 20;\n      b >>= 20;\n      if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; }\n      if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; }\n      if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; }\n      out[0] = (stbi_uc)r;\n      out[1] = (stbi_uc)g;\n      out[2] = (stbi_uc)b;\n      out[3] = 255;\n      out += step;\n   }\n}\n\n#if defined(STBI_SSE2) || defined(STBI_NEON)\nstatic void 
stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step)\n{\n   int i = 0;\n\n#ifdef STBI_SSE2\n   // step == 3 is pretty ugly on the final interleave, and i'm not convinced\n   // it's useful in practice (you wouldn't use it for textures, for example).\n   // so just accelerate step == 4 case.\n   if (step == 4) {\n      // this is a fairly straightforward implementation and not super-optimized.\n      __m128i signflip  = _mm_set1_epi8(-0x80);\n      __m128i cr_const0 = _mm_set1_epi16(   (short) ( 1.40200f*4096.0f+0.5f));\n      __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f));\n      __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f));\n      __m128i cb_const1 = _mm_set1_epi16(   (short) ( 1.77200f*4096.0f+0.5f));\n      __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128);\n      __m128i xw = _mm_set1_epi16(255); // alpha channel\n\n      for (; i+7 < count; i += 8) {\n         // load\n         __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i));\n         __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i));\n         __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i));\n         __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128\n         __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128\n\n         // unpack to short (and left-shift cr, cb by 8)\n         __m128i yw  = _mm_unpacklo_epi8(y_bias, y_bytes);\n         __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased);\n         __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased);\n\n         // color transform\n         __m128i yws = _mm_srli_epi16(yw, 4);\n         __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw);\n         __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw);\n         __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1);\n         __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1);\n         __m128i rws = _mm_add_epi16(cr0, yws);\n         __m128i 
gwt = _mm_add_epi16(cb0, yws);\n         __m128i bws = _mm_add_epi16(yws, cb1);\n         __m128i gws = _mm_add_epi16(gwt, cr1);\n\n         // descale\n         __m128i rw = _mm_srai_epi16(rws, 4);\n         __m128i bw = _mm_srai_epi16(bws, 4);\n         __m128i gw = _mm_srai_epi16(gws, 4);\n\n         // back to byte, set up for transpose\n         __m128i brb = _mm_packus_epi16(rw, bw);\n         __m128i gxb = _mm_packus_epi16(gw, xw);\n\n         // transpose to interleave channels\n         __m128i t0 = _mm_unpacklo_epi8(brb, gxb);\n         __m128i t1 = _mm_unpackhi_epi8(brb, gxb);\n         __m128i o0 = _mm_unpacklo_epi16(t0, t1);\n         __m128i o1 = _mm_unpackhi_epi16(t0, t1);\n\n         // store\n         _mm_storeu_si128((__m128i *) (out + 0), o0);\n         _mm_storeu_si128((__m128i *) (out + 16), o1);\n         out += 32;\n      }\n   }\n#endif\n\n#ifdef STBI_NEON\n   // in this version, step=3 support would be easy to add. but is there demand?\n   if (step == 4) {\n      // this is a fairly straightforward implementation and not super-optimized.\n      uint8x8_t signflip = vdup_n_u8(0x80);\n      int16x8_t cr_const0 = vdupq_n_s16(   (short) ( 1.40200f*4096.0f+0.5f));\n      int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f));\n      int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f));\n      int16x8_t cb_const1 = vdupq_n_s16(   (short) ( 1.77200f*4096.0f+0.5f));\n\n      for (; i+7 < count; i += 8) {\n         // load\n         uint8x8_t y_bytes  = vld1_u8(y + i);\n         uint8x8_t cr_bytes = vld1_u8(pcr + i);\n         uint8x8_t cb_bytes = vld1_u8(pcb + i);\n         int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip));\n         int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip));\n\n         // expand to s16\n         int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4));\n         int16x8_t crw = vshll_n_s8(cr_biased, 7);\n         int16x8_t cbw = 
vshll_n_s8(cb_biased, 7);\n\n         // color transform\n         int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0);\n         int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0);\n         int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1);\n         int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1);\n         int16x8_t rws = vaddq_s16(yws, cr0);\n         int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1);\n         int16x8_t bws = vaddq_s16(yws, cb1);\n\n         // undo scaling, round, convert to byte\n         uint8x8x4_t o;\n         o.val[0] = vqrshrun_n_s16(rws, 4);\n         o.val[1] = vqrshrun_n_s16(gws, 4);\n         o.val[2] = vqrshrun_n_s16(bws, 4);\n         o.val[3] = vdup_n_u8(255);\n\n         // store, interleaving r/g/b/a\n         vst4_u8(out, o);\n         out += 8*4;\n      }\n   }\n#endif\n\n   for (; i < count; ++i) {\n      int y_fixed = (y[i] << 20) + (1<<19); // rounding\n      int r,g,b;\n      int cr = pcr[i] - 128;\n      int cb = pcb[i] - 128;\n      r = y_fixed + cr* stbi__float2fixed(1.40200f);\n      g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000);\n      b = y_fixed                                   +   cb* stbi__float2fixed(1.77200f);\n      r >>= 20;\n      g >>= 20;\n      b >>= 20;\n      if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; }\n      if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; }\n      if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; }\n      out[0] = (stbi_uc)r;\n      out[1] = (stbi_uc)g;\n      out[2] = (stbi_uc)b;\n      out[3] = 255;\n      out += step;\n   }\n}\n#endif\n\n// set up the kernels\nstatic void stbi__setup_jpeg(stbi__jpeg *j)\n{\n   j->idct_block_kernel = stbi__idct_block;\n   j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row;\n   j->resample_row_hv_2_kernel = stbi__resample_row_hv_2;\n\n#ifdef STBI_SSE2\n   if (stbi__sse2_available()) {\n      j->idct_block_kernel = stbi__idct_simd;\n      j->YCbCr_to_RGB_kernel = 
stbi__YCbCr_to_RGB_simd;\n      j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd;\n   }\n#endif\n\n#ifdef STBI_NEON\n   j->idct_block_kernel = stbi__idct_simd;\n   j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd;\n   j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd;\n#endif\n}\n\n// clean up the temporary component buffers\nstatic void stbi__cleanup_jpeg(stbi__jpeg *j)\n{\n   stbi__free_jpeg_components(j, j->s->img_n, 0);\n}\n\ntypedef struct\n{\n   resample_row_func resample;\n   stbi_uc *line0,*line1;\n   int hs,vs;   // expansion factor in each axis\n   int w_lores; // horizontal pixels pre-expansion\n   int ystep;   // how far through vertical expansion we are\n   int ypos;    // which pre-expansion row we're on\n} stbi__resample;\n\n// fast 0..255 * 0..255 => 0..255 rounded multiplication\nstatic stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y)\n{\n   unsigned int t = x*y + 128;\n   return (stbi_uc) ((t + (t >>8)) >> 8);\n}\n\nstatic stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp)\n{\n   int n, decode_n, is_rgb;\n   z->s->img_n = 0; // make stbi__cleanup_jpeg safe\n\n   // validate req_comp\n   if (req_comp < 0 || req_comp > 4) return stbi__errpuc(\"bad req_comp\", \"Internal error\");\n\n   // load a jpeg image from whichever source, but leave in YCbCr format\n   if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; }\n\n   // determine actual number of components to generate\n   n = req_comp ? req_comp : z->s->img_n >= 3 ? 
3 : 1;\n\n   is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif));\n\n   if (z->s->img_n == 3 && n < 3 && !is_rgb)\n      decode_n = 1;\n   else\n      decode_n = z->s->img_n;\n\n   // nothing to do if no components requested; check this now to avoid\n   // accessing uninitialized coutput[0] later\n   if (decode_n <= 0) { stbi__cleanup_jpeg(z); return NULL; }\n\n   // resample and color-convert\n   {\n      int k;\n      unsigned int i,j;\n      stbi_uc *output;\n      stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL };\n\n      stbi__resample res_comp[4];\n\n      for (k=0; k < decode_n; ++k) {\n         stbi__resample *r = &res_comp[k];\n\n         // allocate line buffer big enough for upsampling off the edges\n         // with upsample factor of 4\n         z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3);\n         if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc(\"outofmem\", \"Out of memory\"); }\n\n         r->hs      = z->img_h_max / z->img_comp[k].h;\n         r->vs      = z->img_v_max / z->img_comp[k].v;\n         r->ystep   = r->vs >> 1;\n         r->w_lores = (z->s->img_x + r->hs-1) / r->hs;\n         r->ypos    = 0;\n         r->line0   = r->line1 = z->img_comp[k].data;\n\n         if      (r->hs == 1 && r->vs == 1) r->resample = resample_row_1;\n         else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2;\n         else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2;\n         else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel;\n         else                               r->resample = stbi__resample_row_generic;\n      }\n\n      // can't error after this so, this is safe\n      output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1);\n      if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc(\"outofmem\", \"Out of memory\"); }\n\n      // now go ahead and resample\n      for (j=0; j < 
z->s->img_y; ++j) {\n         stbi_uc *out = output + n * z->s->img_x * j;\n         for (k=0; k < decode_n; ++k) {\n            stbi__resample *r = &res_comp[k];\n            int y_bot = r->ystep >= (r->vs >> 1);\n            coutput[k] = r->resample(z->img_comp[k].linebuf,\n                                     y_bot ? r->line1 : r->line0,\n                                     y_bot ? r->line0 : r->line1,\n                                     r->w_lores, r->hs);\n            if (++r->ystep >= r->vs) {\n               r->ystep = 0;\n               r->line0 = r->line1;\n               if (++r->ypos < z->img_comp[k].y)\n                  r->line1 += z->img_comp[k].w2;\n            }\n         }\n         if (n >= 3) {\n            stbi_uc *y = coutput[0];\n            if (z->s->img_n == 3) {\n               if (is_rgb) {\n                  for (i=0; i < z->s->img_x; ++i) {\n                     out[0] = y[i];\n                     out[1] = coutput[1][i];\n                     out[2] = coutput[2][i];\n                     out[3] = 255;\n                     out += n;\n                  }\n               } else {\n                  z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n);\n               }\n            } else if (z->s->img_n == 4) {\n               if (z->app14_color_transform == 0) { // CMYK\n                  for (i=0; i < z->s->img_x; ++i) {\n                     stbi_uc m = coutput[3][i];\n                     out[0] = stbi__blinn_8x8(coutput[0][i], m);\n                     out[1] = stbi__blinn_8x8(coutput[1][i], m);\n                     out[2] = stbi__blinn_8x8(coutput[2][i], m);\n                     out[3] = 255;\n                     out += n;\n                  }\n               } else if (z->app14_color_transform == 2) { // YCCK\n                  z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n);\n                  for (i=0; i < z->s->img_x; ++i) {\n                     stbi_uc m = coutput[3][i];\n        
             out[0] = stbi__blinn_8x8(255 - out[0], m);\n                     out[1] = stbi__blinn_8x8(255 - out[1], m);\n                     out[2] = stbi__blinn_8x8(255 - out[2], m);\n                     out += n;\n                  }\n               } else { // YCbCr + alpha?  Ignore the fourth channel for now\n                  z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n);\n               }\n            } else\n               for (i=0; i < z->s->img_x; ++i) {\n                  out[0] = out[1] = out[2] = y[i];\n                  out[3] = 255; // not used if n==3\n                  out += n;\n               }\n         } else {\n            if (is_rgb) {\n               if (n == 1)\n                  for (i=0; i < z->s->img_x; ++i)\n                     *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]);\n               else {\n                  for (i=0; i < z->s->img_x; ++i, out += 2) {\n                     out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]);\n                     out[1] = 255;\n                  }\n               }\n            } else if (z->s->img_n == 4 && z->app14_color_transform == 0) {\n               for (i=0; i < z->s->img_x; ++i) {\n                  stbi_uc m = coutput[3][i];\n                  stbi_uc r = stbi__blinn_8x8(coutput[0][i], m);\n                  stbi_uc g = stbi__blinn_8x8(coutput[1][i], m);\n                  stbi_uc b = stbi__blinn_8x8(coutput[2][i], m);\n                  out[0] = stbi__compute_y(r, g, b);\n                  out[1] = 255;\n                  out += n;\n               }\n            } else if (z->s->img_n == 4 && z->app14_color_transform == 2) {\n               for (i=0; i < z->s->img_x; ++i) {\n                  out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]);\n                  out[1] = 255;\n                  out += n;\n               }\n            } else {\n               stbi_uc *y = coutput[0];\n               if (n 
== 1)\n                  for (i=0; i < z->s->img_x; ++i) out[i] = y[i];\n               else\n                  for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; }\n            }\n         }\n      }\n      stbi__cleanup_jpeg(z);\n      *out_x = z->s->img_x;\n      *out_y = z->s->img_y;\n      if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output\n      return output;\n   }\n}\n\nstatic void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)\n{\n   unsigned char* result;\n   stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg));\n   if (!j) return stbi__errpuc(\"outofmem\", \"Out of memory\");\n   memset(j, 0, sizeof(stbi__jpeg));\n   STBI_NOTUSED(ri);\n   j->s = s;\n   stbi__setup_jpeg(j);\n   result = load_jpeg_image(j, x,y,comp,req_comp);\n   STBI_FREE(j);\n   return result;\n}\n\nstatic int stbi__jpeg_test(stbi__context *s)\n{\n   int r;\n   stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg));\n   if (!j) return stbi__err(\"outofmem\", \"Out of memory\");\n   memset(j, 0, sizeof(stbi__jpeg));\n   j->s = s;\n   stbi__setup_jpeg(j);\n   r = stbi__decode_jpeg_header(j, STBI__SCAN_type);\n   stbi__rewind(s);\n   STBI_FREE(j);\n   return r;\n}\n\nstatic int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp)\n{\n   if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) {\n      stbi__rewind( j->s );\n      return 0;\n   }\n   if (x) *x = j->s->img_x;\n   if (y) *y = j->s->img_y;\n   if (comp) *comp = j->s->img_n >= 3 ? 
3 : 1;\n   return 1;\n}\n\nstatic int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp)\n{\n   int result;\n   stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg)));\n   if (!j) return stbi__err(\"outofmem\", \"Out of memory\");\n   memset(j, 0, sizeof(stbi__jpeg));\n   j->s = s;\n   result = stbi__jpeg_info_raw(j, x, y, comp);\n   STBI_FREE(j);\n   return result;\n}\n#endif\n\n// public domain zlib decode    v0.2  Sean Barrett 2006-11-18\n//    simple implementation\n//      - all input must be provided in an upfront buffer\n//      - all output is written to a single output buffer (can malloc/realloc)\n//    performance\n//      - fast huffman\n\n#ifndef STBI_NO_ZLIB\n\n// fast-way is faster to check than jpeg huffman, but slow way is slower\n#define STBI__ZFAST_BITS  9 // accelerate all cases in default tables\n#define STBI__ZFAST_MASK  ((1 << STBI__ZFAST_BITS) - 1)\n#define STBI__ZNSYMS 288 // number of symbols in literal/length alphabet\n\n// zlib-style huffman encoding\n// (jpegs packs from left, zlib from right, so can't share code)\ntypedef struct\n{\n   stbi__uint16 fast[1 << STBI__ZFAST_BITS];\n   stbi__uint16 firstcode[16];\n   int maxcode[17];\n   stbi__uint16 firstsymbol[16];\n   stbi_uc  size[STBI__ZNSYMS];\n   stbi__uint16 value[STBI__ZNSYMS];\n} stbi__zhuffman;\n\nstbi_inline static int stbi__bitreverse16(int n)\n{\n  n = ((n & 0xAAAA) >>  1) | ((n & 0x5555) << 1);\n  n = ((n & 0xCCCC) >>  2) | ((n & 0x3333) << 2);\n  n = ((n & 0xF0F0) >>  4) | ((n & 0x0F0F) << 4);\n  n = ((n & 0xFF00) >>  8) | ((n & 0x00FF) << 8);\n  return n;\n}\n\nstbi_inline static int stbi__bit_reverse(int v, int bits)\n{\n   STBI_ASSERT(bits <= 16);\n   // to bit reverse n bits, reverse 16 and shift\n   // e.g. 
11 bits, bit reverse and shift away 5\n   return stbi__bitreverse16(v) >> (16-bits);\n}\n\nstatic int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num)\n{\n   int i,k=0;\n   int code, next_code[16], sizes[17];\n\n   // DEFLATE spec for generating codes\n   memset(sizes, 0, sizeof(sizes));\n   memset(z->fast, 0, sizeof(z->fast));\n   for (i=0; i < num; ++i)\n      ++sizes[sizelist[i]];\n   sizes[0] = 0;\n   for (i=1; i < 16; ++i)\n      if (sizes[i] > (1 << i))\n         return stbi__err(\"bad sizes\", \"Corrupt PNG\");\n   code = 0;\n   for (i=1; i < 16; ++i) {\n      next_code[i] = code;\n      z->firstcode[i] = (stbi__uint16) code;\n      z->firstsymbol[i] = (stbi__uint16) k;\n      code = (code + sizes[i]);\n      if (sizes[i])\n         if (code-1 >= (1 << i)) return stbi__err(\"bad codelengths\",\"Corrupt PNG\");\n      z->maxcode[i] = code << (16-i); // preshift for inner loop\n      code <<= 1;\n      k += sizes[i];\n   }\n   z->maxcode[16] = 0x10000; // sentinel\n   for (i=0; i < num; ++i) {\n      int s = sizelist[i];\n      if (s) {\n         int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s];\n         stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i);\n         z->size [c] = (stbi_uc     ) s;\n         z->value[c] = (stbi__uint16) i;\n         if (s <= STBI__ZFAST_BITS) {\n            int j = stbi__bit_reverse(next_code[s],s);\n            while (j < (1 << STBI__ZFAST_BITS)) {\n               z->fast[j] = fastv;\n               j += (1 << s);\n            }\n         }\n         ++next_code[s];\n      }\n   }\n   return 1;\n}\n\n// zlib-from-memory implementation for PNG reading\n//    because PNG allows splitting the zlib stream arbitrarily,\n//    and it's annoying structurally to have PNG call ZLIB call PNG,\n//    we require PNG read all the IDATs and combine them into a single\n//    memory buffer\n\ntypedef struct\n{\n   stbi_uc *zbuffer, *zbuffer_end;\n   int num_bits;\n   stbi__uint32 code_buffer;\n\n   char 
*zout;\n   char *zout_start;\n   char *zout_end;\n   int   z_expandable;\n\n   stbi__zhuffman z_length, z_distance;\n} stbi__zbuf;\n\nstbi_inline static int stbi__zeof(stbi__zbuf *z)\n{\n   return (z->zbuffer >= z->zbuffer_end);\n}\n\nstbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z)\n{\n   return stbi__zeof(z) ? 0 : *z->zbuffer++;\n}\n\nstatic void stbi__fill_bits(stbi__zbuf *z)\n{\n   do {\n      if (z->code_buffer >= (1U << z->num_bits)) {\n        z->zbuffer = z->zbuffer_end;  /* treat this as EOF so we fail. */\n        return;\n      }\n      z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits;\n      z->num_bits += 8;\n   } while (z->num_bits <= 24);\n}\n\nstbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n)\n{\n   unsigned int k;\n   if (z->num_bits < n) stbi__fill_bits(z);\n   k = z->code_buffer & ((1 << n) - 1);\n   z->code_buffer >>= n;\n   z->num_bits -= n;\n   return k;\n}\n\nstatic int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z)\n{\n   int b,s,k;\n   // not resolved by fast table, so compute it the slow way\n   // use jpeg approach, which requires MSbits at top\n   k = stbi__bit_reverse(a->code_buffer, 16);\n   for (s=STBI__ZFAST_BITS+1; ; ++s)\n      if (k < z->maxcode[s])\n         break;\n   if (s >= 16) return -1; // invalid code!\n   // code size is s, so:\n   b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s];\n   if (b >= STBI__ZNSYMS) return -1; // some data was corrupt somewhere!\n   if (z->size[b] != s) return -1;  // was originally an assert, but report failure instead.\n   a->code_buffer >>= s;\n   a->num_bits -= s;\n   return z->value[b];\n}\n\nstbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z)\n{\n   int b,s;\n   if (a->num_bits < 16) {\n      if (stbi__zeof(a)) {\n         return -1;   /* report error for unexpected end of data. 
*/\n      }\n      stbi__fill_bits(a);\n   }\n   b = z->fast[a->code_buffer & STBI__ZFAST_MASK];\n   if (b) {\n      s = b >> 9;\n      a->code_buffer >>= s;\n      a->num_bits -= s;\n      return b & 511;\n   }\n   return stbi__zhuffman_decode_slowpath(a, z);\n}\n\nstatic int stbi__zexpand(stbi__zbuf *z, char *zout, int n)  // need to make room for n bytes\n{\n   char *q;\n   unsigned int cur, limit, old_limit;\n   z->zout = zout;\n   if (!z->z_expandable) return stbi__err(\"output buffer limit\",\"Corrupt PNG\");\n   cur   = (unsigned int) (z->zout - z->zout_start);\n   limit = old_limit = (unsigned) (z->zout_end - z->zout_start);\n   if (UINT_MAX - cur < (unsigned) n) return stbi__err(\"outofmem\", \"Out of memory\");\n   while (cur + n > limit) {\n      if(limit > UINT_MAX / 2) return stbi__err(\"outofmem\", \"Out of memory\");\n      limit *= 2;\n   }\n   q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit);\n   STBI_NOTUSED(old_limit);\n   if (q == NULL) return stbi__err(\"outofmem\", \"Out of memory\");\n   z->zout_start = q;\n   z->zout       = q + cur;\n   z->zout_end   = q + limit;\n   return 1;\n}\n\nstatic const int stbi__zlength_base[31] = {\n   3,4,5,6,7,8,9,10,11,13,\n   15,17,19,23,27,31,35,43,51,59,\n   67,83,99,115,131,163,195,227,258,0,0 };\n\nstatic const int stbi__zlength_extra[31]=\n{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 };\n\nstatic const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,\n257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0};\n\nstatic const int stbi__zdist_extra[32] =\n{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13};\n\nstatic int stbi__parse_huffman_block(stbi__zbuf *a)\n{\n   char *zout = a->zout;\n   for(;;) {\n      int z = stbi__zhuffman_decode(a, &a->z_length);\n      if (z < 256) {\n         if (z < 0) return stbi__err(\"bad huffman code\",\"Corrupt PNG\"); // error in huffman codes\n         if (zout >= a->zout_end) 
{\n            if (!stbi__zexpand(a, zout, 1)) return 0;\n            zout = a->zout;\n         }\n         *zout++ = (char) z;\n      } else {\n         stbi_uc *p;\n         int len,dist;\n         if (z == 256) {\n            a->zout = zout;\n            return 1;\n         }\n         if (z >= 286) return stbi__err(\"bad huffman code\",\"Corrupt PNG\"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data\n         z -= 257;\n         len = stbi__zlength_base[z];\n         if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]);\n         z = stbi__zhuffman_decode(a, &a->z_distance);\n         if (z < 0 || z >= 30) return stbi__err(\"bad huffman code\",\"Corrupt PNG\"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data\n         dist = stbi__zdist_base[z];\n         if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]);\n         if (zout - a->zout_start < dist) return stbi__err(\"bad dist\",\"Corrupt PNG\");\n         if (zout + len > a->zout_end) {\n            if (!stbi__zexpand(a, zout, len)) return 0;\n            zout = a->zout;\n         }\n         p = (stbi_uc *) (zout - dist);\n         if (dist == 1) { // run of one byte; common in images.\n            stbi_uc v = *p;\n            if (len) { do *zout++ = v; while (--len); }\n         } else {\n            if (len) { do *zout++ = *p++; while (--len); }\n         }\n      }\n   }\n}\n\nstatic int stbi__compute_huffman_codes(stbi__zbuf *a)\n{\n   static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 };\n   stbi__zhuffman z_codelength;\n   stbi_uc lencodes[286+32+137];//padding for maximum single op\n   stbi_uc codelength_sizes[19];\n   int i,n;\n\n   int hlit  = stbi__zreceive(a,5) + 257;\n   int hdist = stbi__zreceive(a,5) + 1;\n   int hclen = stbi__zreceive(a,4) + 4;\n   int ntot  = hlit + hdist;\n\n   memset(codelength_sizes, 0, sizeof(codelength_sizes));\n   for (i=0; i < 
hclen; ++i) {\n      int s = stbi__zreceive(a,3);\n      codelength_sizes[length_dezigzag[i]] = (stbi_uc) s;\n   }\n   if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0;\n\n   n = 0;\n   while (n < ntot) {\n      int c = stbi__zhuffman_decode(a, &z_codelength);\n      if (c < 0 || c >= 19) return stbi__err(\"bad codelengths\", \"Corrupt PNG\");\n      if (c < 16)\n         lencodes[n++] = (stbi_uc) c;\n      else {\n         stbi_uc fill = 0;\n         if (c == 16) {\n            c = stbi__zreceive(a,2)+3;\n            if (n == 0) return stbi__err(\"bad codelengths\", \"Corrupt PNG\");\n            fill = lencodes[n-1];\n         } else if (c == 17) {\n            c = stbi__zreceive(a,3)+3;\n         } else if (c == 18) {\n            c = stbi__zreceive(a,7)+11;\n         } else {\n            return stbi__err(\"bad codelengths\", \"Corrupt PNG\");\n         }\n         if (ntot - n < c) return stbi__err(\"bad codelengths\", \"Corrupt PNG\");\n         memset(lencodes+n, fill, c);\n         n += c;\n      }\n   }\n   if (n != ntot) return stbi__err(\"bad codelengths\",\"Corrupt PNG\");\n   if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0;\n   if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0;\n   return 1;\n}\n\nstatic int stbi__parse_uncompressed_block(stbi__zbuf *a)\n{\n   stbi_uc header[4];\n   int len,nlen,k;\n   if (a->num_bits & 7)\n      stbi__zreceive(a, a->num_bits & 7); // discard\n   // drain the bit-packed data into header\n   k = 0;\n   while (a->num_bits > 0) {\n      header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check\n      a->code_buffer >>= 8;\n      a->num_bits -= 8;\n   }\n   if (a->num_bits < 0) return stbi__err(\"zlib corrupt\",\"Corrupt PNG\");\n   // now fill header the normal way\n   while (k < 4)\n      header[k++] = stbi__zget8(a);\n   len  = header[1] * 256 + header[0];\n   nlen = header[3] * 256 + header[2];\n   if (nlen != (len ^ 0xffff)) return 
stbi__err(\"zlib corrupt\",\"Corrupt PNG\");\n   if (a->zbuffer + len > a->zbuffer_end) return stbi__err(\"read past buffer\",\"Corrupt PNG\");\n   if (a->zout + len > a->zout_end)\n      if (!stbi__zexpand(a, a->zout, len)) return 0;\n   memcpy(a->zout, a->zbuffer, len);\n   a->zbuffer += len;\n   a->zout += len;\n   return 1;\n}\n\nstatic int stbi__parse_zlib_header(stbi__zbuf *a)\n{\n   int cmf   = stbi__zget8(a);\n   int cm    = cmf & 15;\n   /* int cinfo = cmf >> 4; */\n   int flg   = stbi__zget8(a);\n   if (stbi__zeof(a)) return stbi__err(\"bad zlib header\",\"Corrupt PNG\"); // zlib spec\n   if ((cmf*256+flg) % 31 != 0) return stbi__err(\"bad zlib header\",\"Corrupt PNG\"); // zlib spec\n   if (flg & 32) return stbi__err(\"no preset dict\",\"Corrupt PNG\"); // preset dictionary not allowed in png\n   if (cm != 8) return stbi__err(\"bad compression\",\"Corrupt PNG\"); // DEFLATE required for png\n   // window = 1 << (8 + cinfo)... but who cares, we fully buffer output\n   return 1;\n}\n\nstatic const stbi_uc stbi__zdefault_length[STBI__ZNSYMS] =\n{\n   8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,\n   8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,\n   8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,\n   8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,\n   8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,\n   9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,\n   9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,\n   9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,\n   7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8\n};\nstatic const stbi_uc stbi__zdefault_distance[32] =\n{\n   5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5\n};\n/*\nInit algorithm:\n{\n   int i;   // use <= to match clearly with spec\n   for (i=0; i <= 143; ++i)     stbi__zdefault_length[i]   = 8;\n   for (   ; i <= 255; ++i)  
   stbi__zdefault_length[i]   = 9;\n   for (   ; i <= 279; ++i)     stbi__zdefault_length[i]   = 7;\n   for (   ; i <= 287; ++i)     stbi__zdefault_length[i]   = 8;\n\n   for (i=0; i <=  31; ++i)     stbi__zdefault_distance[i] = 5;\n}\n*/\n\nstatic int stbi__parse_zlib(stbi__zbuf *a, int parse_header)\n{\n   int final, type;\n   if (parse_header)\n      if (!stbi__parse_zlib_header(a)) return 0;\n   a->num_bits = 0;\n   a->code_buffer = 0;\n   do {\n      final = stbi__zreceive(a,1);\n      type = stbi__zreceive(a,2);\n      if (type == 0) {\n         if (!stbi__parse_uncompressed_block(a)) return 0;\n      } else if (type == 3) {\n         return 0;\n      } else {\n         if (type == 1) {\n            // use fixed code lengths\n            if (!stbi__zbuild_huffman(&a->z_length  , stbi__zdefault_length  , STBI__ZNSYMS)) return 0;\n            if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance,  32)) return 0;\n         } else {\n            if (!stbi__compute_huffman_codes(a)) return 0;\n         }\n         if (!stbi__parse_huffman_block(a)) return 0;\n      }\n   } while (!final);\n   return 1;\n}\n\nstatic int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header)\n{\n   a->zout_start = obuf;\n   a->zout       = obuf;\n   a->zout_end   = obuf + olen;\n   a->z_expandable = exp;\n\n   return stbi__parse_zlib(a, parse_header);\n}\n\nSTBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen)\n{\n   stbi__zbuf a;\n   char *p = (char *) stbi__malloc(initial_size);\n   if (p == NULL) return NULL;\n   a.zbuffer = (stbi_uc *) buffer;\n   a.zbuffer_end = (stbi_uc *) buffer + len;\n   if (stbi__do_zlib(&a, p, initial_size, 1, 1)) {\n      if (outlen) *outlen = (int) (a.zout - a.zout_start);\n      return a.zout_start;\n   } else {\n      STBI_FREE(a.zout_start);\n      return NULL;\n   }\n}\n\nSTBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen)\n{\n   
return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen);\n}\n\nSTBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header)\n{\n   stbi__zbuf a;\n   char *p = (char *) stbi__malloc(initial_size);\n   if (p == NULL) return NULL;\n   a.zbuffer = (stbi_uc *) buffer;\n   a.zbuffer_end = (stbi_uc *) buffer + len;\n   if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) {\n      if (outlen) *outlen = (int) (a.zout - a.zout_start);\n      return a.zout_start;\n   } else {\n      STBI_FREE(a.zout_start);\n      return NULL;\n   }\n}\n\nSTBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen)\n{\n   stbi__zbuf a;\n   a.zbuffer = (stbi_uc *) ibuffer;\n   a.zbuffer_end = (stbi_uc *) ibuffer + ilen;\n   if (stbi__do_zlib(&a, obuffer, olen, 0, 1))\n      return (int) (a.zout - a.zout_start);\n   else\n      return -1;\n}\n\nSTBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen)\n{\n   stbi__zbuf a;\n   char *p = (char *) stbi__malloc(16384);\n   if (p == NULL) return NULL;\n   a.zbuffer = (stbi_uc *) buffer;\n   a.zbuffer_end = (stbi_uc *) buffer+len;\n   if (stbi__do_zlib(&a, p, 16384, 1, 0)) {\n      if (outlen) *outlen = (int) (a.zout - a.zout_start);\n      return a.zout_start;\n   } else {\n      STBI_FREE(a.zout_start);\n      return NULL;\n   }\n}\n\nSTBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen)\n{\n   stbi__zbuf a;\n   a.zbuffer = (stbi_uc *) ibuffer;\n   a.zbuffer_end = (stbi_uc *) ibuffer + ilen;\n   if (stbi__do_zlib(&a, obuffer, olen, 0, 0))\n      return (int) (a.zout - a.zout_start);\n   else\n      return -1;\n}\n#endif\n\n// public domain \"baseline\" PNG decoder   v0.10  Sean Barrett 2006-11-18\n//    simple implementation\n//      - only 8-bit samples\n//      - no CRC checking\n//      - allocates lots of intermediate memory\n//        - avoids 
problem of streaming data between subsystems\n//        - avoids explicit window management\n//    performance\n//      - uses stb_zlib, a PD zlib implementation with fast huffman decoding\n\n#ifndef STBI_NO_PNG\ntypedef struct\n{\n   stbi__uint32 length;\n   stbi__uint32 type;\n} stbi__pngchunk;\n\nstatic stbi__pngchunk stbi__get_chunk_header(stbi__context *s)\n{\n   stbi__pngchunk c;\n   c.length = stbi__get32be(s);\n   c.type   = stbi__get32be(s);\n   return c;\n}\n\nstatic int stbi__check_png_header(stbi__context *s)\n{\n   static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 };\n   int i;\n   for (i=0; i < 8; ++i)\n      if (stbi__get8(s) != png_sig[i]) return stbi__err(\"bad png sig\",\"Not a PNG\");\n   return 1;\n}\n\ntypedef struct\n{\n   stbi__context *s;\n   stbi_uc *idata, *expanded, *out;\n   int depth;\n} stbi__png;\n\n\nenum {\n   STBI__F_none=0,\n   STBI__F_sub=1,\n   STBI__F_up=2,\n   STBI__F_avg=3,\n   STBI__F_paeth=4,\n   // synthetic filters used for first scanline to avoid needing a dummy row of 0s\n   STBI__F_avg_first,\n   STBI__F_paeth_first\n};\n\nstatic stbi_uc first_row_filter[5] =\n{\n   STBI__F_none,\n   STBI__F_sub,\n   STBI__F_none,\n   STBI__F_avg_first,\n   STBI__F_paeth_first\n};\n\nstatic int stbi__paeth(int a, int b, int c)\n{\n   int p = a + b - c;\n   int pa = abs(p-a);\n   int pb = abs(p-b);\n   int pc = abs(p-c);\n   if (pa <= pb && pa <= pc) return a;\n   if (pb <= pc) return b;\n   return c;\n}\n\nstatic const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 };\n\n// create the png data from post-deflated data\nstatic int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color)\n{\n   int bytes = (depth == 16? 
2 : 1);\n   stbi__context *s = a->s;\n   stbi__uint32 i,j,stride = x*out_n*bytes;\n   stbi__uint32 img_len, img_width_bytes;\n   int k;\n   int img_n = s->img_n; // copy it into a local for later\n\n   int output_bytes = out_n*bytes;\n   int filter_bytes = img_n*bytes;\n   int width = x;\n\n   STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1);\n   a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into\n   if (!a->out) return stbi__err(\"outofmem\", \"Out of memory\");\n\n   if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err(\"too large\", \"Corrupt PNG\");\n   img_width_bytes = (((img_n * x * depth) + 7) >> 3);\n   img_len = (img_width_bytes + 1) * y;\n\n   // we used to check for exact match between raw_len and img_len on non-interlaced PNGs,\n   // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros),\n   // so just check for raw_len < img_len always.\n   if (raw_len < img_len) return stbi__err(\"not enough pixels\",\"Corrupt PNG\");\n\n   for (j=0; j < y; ++j) {\n      stbi_uc *cur = a->out + stride*j;\n      stbi_uc *prior;\n      int filter = *raw++;\n\n      if (filter > 4)\n         return stbi__err(\"invalid filter\",\"Corrupt PNG\");\n\n      if (depth < 8) {\n         if (img_width_bytes > x) return stbi__err(\"invalid width\",\"Corrupt PNG\");\n         cur += x*out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place\n         filter_bytes = 1;\n         width = img_width_bytes;\n      }\n      prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above\n\n      // if first row, use special filter that doesn't sample previous row\n      if (j == 0) filter = first_row_filter[filter];\n\n      // handle first byte explicitly\n      for (k=0; k < filter_bytes; ++k) {\n         switch (filter) {\n            case STBI__F_none       : cur[k] = raw[k]; break;\n            case STBI__F_sub       
 : cur[k] = raw[k]; break;\n            case STBI__F_up         : cur[k] = STBI__BYTECAST(raw[k] + prior[k]); break;\n            case STBI__F_avg        : cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); break;\n            case STBI__F_paeth      : cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0,prior[k],0)); break;\n            case STBI__F_avg_first  : cur[k] = raw[k]; break;\n            case STBI__F_paeth_first: cur[k] = raw[k]; break;\n         }\n      }\n\n      if (depth == 8) {\n         if (img_n != out_n)\n            cur[img_n] = 255; // first pixel\n         raw += img_n;\n         cur += out_n;\n         prior += out_n;\n      } else if (depth == 16) {\n         if (img_n != out_n) {\n            cur[filter_bytes]   = 255; // first pixel top byte\n            cur[filter_bytes+1] = 255; // first pixel bottom byte\n         }\n         raw += filter_bytes;\n         cur += output_bytes;\n         prior += output_bytes;\n      } else {\n         raw += 1;\n         cur += 1;\n         prior += 1;\n      }\n\n      // this is a little gross, so that we don't switch per-pixel or per-component\n      if (depth < 8 || img_n == out_n) {\n         int nk = (width - 1)*filter_bytes;\n         #define STBI__CASE(f) \\\n             case f:     \\\n                for (k=0; k < nk; ++k)\n         switch (filter) {\n            // \"none\" filter turns into a memcpy here; make that explicit.\n            case STBI__F_none:         memcpy(cur, raw, nk); break;\n            STBI__CASE(STBI__F_sub)          { cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); } break;\n            STBI__CASE(STBI__F_up)           { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break;\n            STBI__CASE(STBI__F_avg)          { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); } break;\n            STBI__CASE(STBI__F_paeth)        { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],prior[k],prior[k-filter_bytes])); } break;\n            
STBI__CASE(STBI__F_avg_first)    { cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); } break;\n            STBI__CASE(STBI__F_paeth_first)  { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],0,0)); } break;\n         }\n         #undef STBI__CASE\n         raw += nk;\n      } else {\n         STBI_ASSERT(img_n+1 == out_n);\n         #define STBI__CASE(f) \\\n             case f:     \\\n                for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \\\n                   for (k=0; k < filter_bytes; ++k)\n         switch (filter) {\n            STBI__CASE(STBI__F_none)         { cur[k] = raw[k]; } break;\n            STBI__CASE(STBI__F_sub)          { cur[k] = STBI__BYTECAST(raw[k] + cur[k- output_bytes]); } break;\n            STBI__CASE(STBI__F_up)           { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break;\n            STBI__CASE(STBI__F_avg)          { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k- output_bytes])>>1)); } break;\n            STBI__CASE(STBI__F_paeth)        { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],prior[k],prior[k- output_bytes])); } break;\n            STBI__CASE(STBI__F_avg_first)    { cur[k] = STBI__BYTECAST(raw[k] + (cur[k- output_bytes] >> 1)); } break;\n            STBI__CASE(STBI__F_paeth_first)  { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],0,0)); } break;\n         }\n         #undef STBI__CASE\n\n         // the loop above sets the high byte of the pixels' alpha, but for\n         // 16 bit png files we also need the low byte set. 
we'll do that here.\n         if (depth == 16) {\n            cur = a->out + stride*j; // start at the beginning of the row again\n            for (i=0; i < x; ++i,cur+=output_bytes) {\n               cur[filter_bytes+1] = 255;\n            }\n         }\n      }\n   }\n\n   // we make a separate pass to expand bits to pixels; for performance,\n   // this could run two scanlines behind the above code, so it won't\n   // interfere with filtering but will still be in the cache.\n   if (depth < 8) {\n      for (j=0; j < y; ++j) {\n         stbi_uc *cur = a->out + stride*j;\n         stbi_uc *in  = a->out + stride*j + x*out_n - img_width_bytes;\n         // unpack 1/2/4-bit into an 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for 1/2/4-bit\n         // PNG guarantees byte alignment; if width is not a multiple of 8/4/2 we'll decode dummy trailing data that will be skipped in the later loop\n         stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range\n\n         // note that the final byte might overshoot and write more data than desired.\n         // we can allocate enough data that this never writes out of memory, but it\n         // could also overwrite the next scanline. can it overwrite non-empty data\n         // on the next scanline? 
yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel.\n         // so we need to explicitly clamp the final ones\n\n         if (depth == 4) {\n            for (k=x*img_n; k >= 2; k-=2, ++in) {\n               *cur++ = scale * ((*in >> 4)       );\n               *cur++ = scale * ((*in     ) & 0x0f);\n            }\n            if (k > 0) *cur++ = scale * ((*in >> 4)       );\n         } else if (depth == 2) {\n            for (k=x*img_n; k >= 4; k-=4, ++in) {\n               *cur++ = scale * ((*in >> 6)       );\n               *cur++ = scale * ((*in >> 4) & 0x03);\n               *cur++ = scale * ((*in >> 2) & 0x03);\n               *cur++ = scale * ((*in     ) & 0x03);\n            }\n            if (k > 0) *cur++ = scale * ((*in >> 6)       );\n            if (k > 1) *cur++ = scale * ((*in >> 4) & 0x03);\n            if (k > 2) *cur++ = scale * ((*in >> 2) & 0x03);\n         } else if (depth == 1) {\n            for (k=x*img_n; k >= 8; k-=8, ++in) {\n               *cur++ = scale * ((*in >> 7)       );\n               *cur++ = scale * ((*in >> 6) & 0x01);\n               *cur++ = scale * ((*in >> 5) & 0x01);\n               *cur++ = scale * ((*in >> 4) & 0x01);\n               *cur++ = scale * ((*in >> 3) & 0x01);\n               *cur++ = scale * ((*in >> 2) & 0x01);\n               *cur++ = scale * ((*in >> 1) & 0x01);\n               *cur++ = scale * ((*in     ) & 0x01);\n            }\n            if (k > 0) *cur++ = scale * ((*in >> 7)       );\n            if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01);\n            if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01);\n            if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01);\n            if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01);\n            if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01);\n            if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01);\n         }\n         if (img_n != out_n) {\n            int q;\n            // insert alpha = 255\n            cur = a->out + stride*j;\n        
    if (img_n == 1) {\n               for (q=x-1; q >= 0; --q) {\n                  cur[q*2+1] = 255;\n                  cur[q*2+0] = cur[q];\n               }\n            } else {\n               STBI_ASSERT(img_n == 3);\n               for (q=x-1; q >= 0; --q) {\n                  cur[q*4+3] = 255;\n                  cur[q*4+2] = cur[q*3+2];\n                  cur[q*4+1] = cur[q*3+1];\n                  cur[q*4+0] = cur[q*3+0];\n               }\n            }\n         }\n      }\n   } else if (depth == 16) {\n      // force the image data from big-endian to platform-native.\n      // this is done in a separate pass due to the decoding relying\n      // on the data being untouched, but could probably be done\n      // per-line during decode if care is taken.\n      stbi_uc *cur = a->out;\n      stbi__uint16 *cur16 = (stbi__uint16*)cur;\n\n      for(i=0; i < x*y*out_n; ++i,cur16++,cur+=2) {\n         *cur16 = (cur[0] << 8) | cur[1];\n      }\n   }\n\n   return 1;\n}\n\nstatic int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced)\n{\n   int bytes = (depth == 16 ? 
2 : 1);\n   int out_bytes = out_n * bytes;\n   stbi_uc *final;\n   int p;\n   if (!interlaced)\n      return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color);\n\n   // de-interlacing\n   final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0);\n   if (!final) return stbi__err(\"outofmem\", \"Out of memory\");\n   for (p=0; p < 7; ++p) {\n      int xorig[] = { 0,4,0,2,0,1,0 };\n      int yorig[] = { 0,0,4,0,2,0,1 };\n      int xspc[]  = { 8,8,4,4,2,2,1 };\n      int yspc[]  = { 8,8,8,4,4,2,2 };\n      int i,j,x,y;\n      // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1\n      x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p];\n      y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p];\n      if (x && y) {\n         stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y;\n         if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) {\n            STBI_FREE(final);\n            return 0;\n         }\n         for (j=0; j < y; ++j) {\n            for (i=0; i < x; ++i) {\n               int out_y = j*yspc[p]+yorig[p];\n               int out_x = i*xspc[p]+xorig[p];\n               memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes,\n                      a->out + (j*x+i)*out_bytes, out_bytes);\n            }\n         }\n         STBI_FREE(a->out);\n         image_data += img_len;\n         image_data_len -= img_len;\n      }\n   }\n   a->out = final;\n\n   return 1;\n}\n\nstatic int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n)\n{\n   stbi__context *s = z->s;\n   stbi__uint32 i, pixel_count = s->img_x * s->img_y;\n   stbi_uc *p = z->out;\n\n   // compute color-based transparency, assuming we've\n   // already got 255 as the alpha value in the output\n   STBI_ASSERT(out_n == 2 || out_n == 4);\n\n   if (out_n == 2) {\n      for (i=0; i < pixel_count; ++i) {\n         p[1] = (p[0] == tc[0] ? 
0 : 255);\n         p += 2;\n      }\n   } else {\n      for (i=0; i < pixel_count; ++i) {\n         if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2])\n            p[3] = 0;\n         p += 4;\n      }\n   }\n   return 1;\n}\n\nstatic int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n)\n{\n   stbi__context *s = z->s;\n   stbi__uint32 i, pixel_count = s->img_x * s->img_y;\n   stbi__uint16 *p = (stbi__uint16*) z->out;\n\n   // compute color-based transparency, assuming we've\n   // already got 65535 as the alpha value in the output\n   STBI_ASSERT(out_n == 2 || out_n == 4);\n\n   if (out_n == 2) {\n      for (i = 0; i < pixel_count; ++i) {\n         p[1] = (p[0] == tc[0] ? 0 : 65535);\n         p += 2;\n      }\n   } else {\n      for (i = 0; i < pixel_count; ++i) {\n         if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2])\n            p[3] = 0;\n         p += 4;\n      }\n   }\n   return 1;\n}\n\nstatic int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n)\n{\n   stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y;\n   stbi_uc *p, *temp_out, *orig = a->out;\n\n   p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0);\n   if (p == NULL) return stbi__err(\"outofmem\", \"Out of memory\");\n\n   // between here and STBI_FREE(a->out) below, exiting early would leak\n   temp_out = p;\n\n   if (pal_img_n == 3) {\n      for (i=0; i < pixel_count; ++i) {\n         int n = orig[i]*4;\n         p[0] = palette[n  ];\n         p[1] = palette[n+1];\n         p[2] = palette[n+2];\n         p += 3;\n      }\n   } else {\n      for (i=0; i < pixel_count; ++i) {\n         int n = orig[i]*4;\n         p[0] = palette[n  ];\n         p[1] = palette[n+1];\n         p[2] = palette[n+2];\n         p[3] = palette[n+3];\n         p += 4;\n      }\n   }\n   STBI_FREE(a->out);\n   a->out = temp_out;\n\n   STBI_NOTUSED(len);\n\n   return 1;\n}\n\nstatic int stbi__unpremultiply_on_load_global = 0;\nstatic int 
stbi__de_iphone_flag_global = 0;\n\nSTBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply)\n{\n   stbi__unpremultiply_on_load_global = flag_true_if_should_unpremultiply;\n}\n\nSTBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert)\n{\n   stbi__de_iphone_flag_global = flag_true_if_should_convert;\n}\n\n#ifndef STBI_THREAD_LOCAL\n#define stbi__unpremultiply_on_load  stbi__unpremultiply_on_load_global\n#define stbi__de_iphone_flag  stbi__de_iphone_flag_global\n#else\nstatic STBI_THREAD_LOCAL int stbi__unpremultiply_on_load_local, stbi__unpremultiply_on_load_set;\nstatic STBI_THREAD_LOCAL int stbi__de_iphone_flag_local, stbi__de_iphone_flag_set;\n\nSTBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply)\n{\n   stbi__unpremultiply_on_load_local = flag_true_if_should_unpremultiply;\n   stbi__unpremultiply_on_load_set = 1;\n}\n\nSTBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert)\n{\n   stbi__de_iphone_flag_local = flag_true_if_should_convert;\n   stbi__de_iphone_flag_set = 1;\n}\n\n#define stbi__unpremultiply_on_load  (stbi__unpremultiply_on_load_set           \\\n                                       ? stbi__unpremultiply_on_load_local      \\\n                                       : stbi__unpremultiply_on_load_global)\n#define stbi__de_iphone_flag  (stbi__de_iphone_flag_set                         \\\n                                ? 
stbi__de_iphone_flag_local                    \\\n                                : stbi__de_iphone_flag_global)\n#endif // STBI_THREAD_LOCAL\n\nstatic void stbi__de_iphone(stbi__png *z)\n{\n   stbi__context *s = z->s;\n   stbi__uint32 i, pixel_count = s->img_x * s->img_y;\n   stbi_uc *p = z->out;\n\n   if (s->img_out_n == 3) {  // convert bgr to rgb\n      for (i=0; i < pixel_count; ++i) {\n         stbi_uc t = p[0];\n         p[0] = p[2];\n         p[2] = t;\n         p += 3;\n      }\n   } else {\n      STBI_ASSERT(s->img_out_n == 4);\n      if (stbi__unpremultiply_on_load) {\n         // convert bgr to rgb and unpremultiply\n         for (i=0; i < pixel_count; ++i) {\n            stbi_uc a = p[3];\n            stbi_uc t = p[0];\n            if (a) {\n               stbi_uc half = a / 2;\n               p[0] = (p[2] * 255 + half) / a;\n               p[1] = (p[1] * 255 + half) / a;\n               p[2] = ( t   * 255 + half) / a;\n            } else {\n               p[0] = p[2];\n               p[2] = t;\n            }\n            p += 4;\n         }\n      } else {\n         // convert bgr to rgb\n         for (i=0; i < pixel_count; ++i) {\n            stbi_uc t = p[0];\n            p[0] = p[2];\n            p[2] = t;\n            p += 4;\n         }\n      }\n   }\n}\n\n#define STBI__PNG_TYPE(a,b,c,d)  (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d))\n\nstatic int stbi__parse_png_file(stbi__png *z, int scan, int req_comp)\n{\n   stbi_uc palette[1024], pal_img_n=0;\n   stbi_uc has_trans=0, tc[3]={0};\n   stbi__uint16 tc16[3];\n   stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0;\n   int first=1,k,interlace=0, color=0, is_iphone=0;\n   stbi__context *s = z->s;\n\n   z->expanded = NULL;\n   z->idata = NULL;\n   z->out = NULL;\n\n   if (!stbi__check_png_header(s)) return 0;\n\n   if (scan == STBI__SCAN_type) return 1;\n\n   for (;;) {\n      stbi__pngchunk c = stbi__get_chunk_header(s);\n      switch (c.type) {\n         
case STBI__PNG_TYPE('C','g','B','I'):\n            is_iphone = 1;\n            stbi__skip(s, c.length);\n            break;\n         case STBI__PNG_TYPE('I','H','D','R'): {\n            int comp,filter;\n            if (!first) return stbi__err(\"multiple IHDR\",\"Corrupt PNG\");\n            first = 0;\n            if (c.length != 13) return stbi__err(\"bad IHDR len\",\"Corrupt PNG\");\n            s->img_x = stbi__get32be(s);\n            s->img_y = stbi__get32be(s);\n            if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err(\"too large\",\"Very large image (corrupt?)\");\n            if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err(\"too large\",\"Very large image (corrupt?)\");\n            z->depth = stbi__get8(s);  if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16)  return stbi__err(\"1/2/4/8/16-bit only\",\"PNG not supported: 1/2/4/8/16-bit only\");\n            color = stbi__get8(s);  if (color > 6)         return stbi__err(\"bad ctype\",\"Corrupt PNG\");\n            if (color == 3 && z->depth == 16)                  return stbi__err(\"bad ctype\",\"Corrupt PNG\");\n            if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err(\"bad ctype\",\"Corrupt PNG\");\n            comp  = stbi__get8(s);  if (comp) return stbi__err(\"bad comp method\",\"Corrupt PNG\");\n            filter= stbi__get8(s);  if (filter) return stbi__err(\"bad filter method\",\"Corrupt PNG\");\n            interlace = stbi__get8(s); if (interlace>1) return stbi__err(\"bad interlace method\",\"Corrupt PNG\");\n            if (!s->img_x || !s->img_y) return stbi__err(\"0-pixel image\",\"Corrupt PNG\");\n            if (!pal_img_n) {\n               s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 
1 : 0);\n               if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err(\"too large\", \"Image too large to decode\");\n            } else {\n               // if paletted, then pal_n is our final components, and\n               // img_n is # components to decompress/filter.\n               s->img_n = 1;\n               if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err(\"too large\",\"Corrupt PNG\");\n            }\n            // even with SCAN_header, have to scan to see if we have a tRNS\n            break;\n         }\n\n         case STBI__PNG_TYPE('P','L','T','E'):  {\n            if (first) return stbi__err(\"first not IHDR\", \"Corrupt PNG\");\n            if (c.length > 256*3) return stbi__err(\"invalid PLTE\",\"Corrupt PNG\");\n            pal_len = c.length / 3;\n            if (pal_len * 3 != c.length) return stbi__err(\"invalid PLTE\",\"Corrupt PNG\");\n            for (i=0; i < pal_len; ++i) {\n               palette[i*4+0] = stbi__get8(s);\n               palette[i*4+1] = stbi__get8(s);\n               palette[i*4+2] = stbi__get8(s);\n               palette[i*4+3] = 255;\n            }\n            break;\n         }\n\n         case STBI__PNG_TYPE('t','R','N','S'): {\n            if (first) return stbi__err(\"first not IHDR\", \"Corrupt PNG\");\n            if (z->idata) return stbi__err(\"tRNS after IDAT\",\"Corrupt PNG\");\n            if (pal_img_n) {\n               if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; }\n               if (pal_len == 0) return stbi__err(\"tRNS before PLTE\",\"Corrupt PNG\");\n               if (c.length > pal_len) return stbi__err(\"bad tRNS len\",\"Corrupt PNG\");\n               pal_img_n = 4;\n               for (i=0; i < c.length; ++i)\n                  palette[i*4+3] = stbi__get8(s);\n            } else {\n               if (!(s->img_n & 1)) return stbi__err(\"tRNS with alpha\",\"Corrupt PNG\");\n               if (c.length != (stbi__uint32) s->img_n*2) return stbi__err(\"bad 
tRNS len\",\"Corrupt PNG\");\n               has_trans = 1;\n               // non-paletted with tRNS = constant alpha. if header-scanning, we can stop now.\n               if (scan == STBI__SCAN_header) { ++s->img_n; return 1; }\n               if (z->depth == 16) {\n                  for (k = 0; k < s->img_n; ++k) tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is\n               } else {\n                  for (k = 0; k < s->img_n; ++k) tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger\n               }\n            }\n            break;\n         }\n\n         case STBI__PNG_TYPE('I','D','A','T'): {\n            if (first) return stbi__err(\"first not IHDR\", \"Corrupt PNG\");\n            if (pal_img_n && !pal_len) return stbi__err(\"no PLTE\",\"Corrupt PNG\");\n            if (scan == STBI__SCAN_header) {\n               // header scan definitely stops at first IDAT\n               if (pal_img_n)\n                  s->img_n = pal_img_n;\n               return 1;\n            }\n            if (c.length > (1u << 30)) return stbi__err(\"IDAT size limit\", \"IDAT section larger than 2^30 bytes\");\n            if ((int)(ioff + c.length) < (int)ioff) return 0;\n            if (ioff + c.length > idata_limit) {\n               stbi__uint32 idata_limit_old = idata_limit;\n               stbi_uc *p;\n               if (idata_limit == 0) idata_limit = c.length > 4096 ? 
c.length : 4096;\n               while (ioff + c.length > idata_limit)\n                  idata_limit *= 2;\n               STBI_NOTUSED(idata_limit_old);\n               p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err(\"outofmem\", \"Out of memory\");\n               z->idata = p;\n            }\n            if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err(\"outofdata\",\"Corrupt PNG\");\n            ioff += c.length;\n            break;\n         }\n\n         case STBI__PNG_TYPE('I','E','N','D'): {\n            stbi__uint32 raw_len, bpl;\n            if (first) return stbi__err(\"first not IHDR\", \"Corrupt PNG\");\n            if (scan != STBI__SCAN_load) return 1;\n            if (z->idata == NULL) return stbi__err(\"no IDAT\",\"Corrupt PNG\");\n            // initial guess for decoded data size to avoid unnecessary reallocs\n            bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component\n            raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */;\n            z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone);\n            if (z->expanded == NULL) return 0; // zlib should set error\n            STBI_FREE(z->idata); z->idata = NULL;\n            if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans)\n               s->img_out_n = s->img_n+1;\n            else\n               s->img_out_n = s->img_n;\n            if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0;\n            if (has_trans) {\n               if (z->depth == 16) {\n                  if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0;\n               } else {\n                  if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0;\n               }\n            }\n            if 
(is_iphone && stbi__de_iphone_flag && s->img_out_n > 2)\n               stbi__de_iphone(z);\n            if (pal_img_n) {\n               // pal_img_n == 3 or 4\n               s->img_n = pal_img_n; // record the actual colors we had\n               s->img_out_n = pal_img_n;\n               if (req_comp >= 3) s->img_out_n = req_comp;\n               if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n))\n                  return 0;\n            } else if (has_trans) {\n               // non-paletted image with tRNS -> source image has (constant) alpha\n               ++s->img_n;\n            }\n            STBI_FREE(z->expanded); z->expanded = NULL;\n            // end of PNG chunk, read and skip CRC\n            stbi__get32be(s);\n            return 1;\n         }\n\n         default:\n            // if critical, fail\n            if (first) return stbi__err(\"first not IHDR\", \"Corrupt PNG\");\n            if ((c.type & (1 << 29)) == 0) {\n               #ifndef STBI_NO_FAILURE_STRINGS\n               // not threadsafe\n               static char invalid_chunk[] = \"XXXX PNG chunk not known\";\n               invalid_chunk[0] = STBI__BYTECAST(c.type >> 24);\n               invalid_chunk[1] = STBI__BYTECAST(c.type >> 16);\n               invalid_chunk[2] = STBI__BYTECAST(c.type >>  8);\n               invalid_chunk[3] = STBI__BYTECAST(c.type >>  0);\n               #endif\n               return stbi__err(invalid_chunk, \"PNG not supported: unknown PNG chunk type\");\n            }\n            stbi__skip(s, c.length);\n            break;\n      }\n      // end of PNG chunk, read and skip CRC\n      stbi__get32be(s);\n   }\n}\n\nstatic void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri)\n{\n   void *result=NULL;\n   if (req_comp < 0 || req_comp > 4) return stbi__errpuc(\"bad req_comp\", \"Internal error\");\n   if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) {\n      if (p->depth <= 8)\n         
ri->bits_per_channel = 8;\n      else if (p->depth == 16)\n         ri->bits_per_channel = 16;\n      else\n         return stbi__errpuc(\"bad bits_per_channel\", \"PNG not supported: unsupported color depth\");\n      result = p->out;\n      p->out = NULL;\n      if (req_comp && req_comp != p->s->img_out_n) {\n         if (ri->bits_per_channel == 8)\n            result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y);\n         else\n            result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y);\n         p->s->img_out_n = req_comp;\n         if (result == NULL) return result;\n      }\n      *x = p->s->img_x;\n      *y = p->s->img_y;\n      if (n) *n = p->s->img_n;\n   }\n   STBI_FREE(p->out);      p->out      = NULL;\n   STBI_FREE(p->expanded); p->expanded = NULL;\n   STBI_FREE(p->idata);    p->idata    = NULL;\n\n   return result;\n}\n\nstatic void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)\n{\n   stbi__png p;\n   p.s = s;\n   return stbi__do_png(&p, x,y,comp,req_comp, ri);\n}\n\nstatic int stbi__png_test(stbi__context *s)\n{\n   int r;\n   r = stbi__check_png_header(s);\n   stbi__rewind(s);\n   return r;\n}\n\nstatic int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp)\n{\n   if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) {\n      stbi__rewind( p->s );\n      return 0;\n   }\n   if (x) *x = p->s->img_x;\n   if (y) *y = p->s->img_y;\n   if (comp) *comp = p->s->img_n;\n   return 1;\n}\n\nstatic int stbi__png_info(stbi__context *s, int *x, int *y, int *comp)\n{\n   stbi__png p;\n   p.s = s;\n   return stbi__png_info_raw(&p, x, y, comp);\n}\n\nstatic int stbi__png_is16(stbi__context *s)\n{\n   stbi__png p;\n   p.s = s;\n   if (!stbi__png_info_raw(&p, NULL, NULL, NULL))\n\t   return 0;\n   if (p.depth != 16) {\n      stbi__rewind(p.s);\n      return 0;\n   }\n   return 
1;\n}\n#endif\n\n// Microsoft/Windows BMP image\n\n#ifndef STBI_NO_BMP\nstatic int stbi__bmp_test_raw(stbi__context *s)\n{\n   int r;\n   int sz;\n   if (stbi__get8(s) != 'B') return 0;\n   if (stbi__get8(s) != 'M') return 0;\n   stbi__get32le(s); // discard filesize\n   stbi__get16le(s); // discard reserved\n   stbi__get16le(s); // discard reserved\n   stbi__get32le(s); // discard data offset\n   sz = stbi__get32le(s);\n   r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124);\n   return r;\n}\n\nstatic int stbi__bmp_test(stbi__context *s)\n{\n   int r = stbi__bmp_test_raw(s);\n   stbi__rewind(s);\n   return r;\n}\n\n\n// returns 0..31 for the highest set bit\nstatic int stbi__high_bit(unsigned int z)\n{\n   int n=0;\n   if (z == 0) return -1;\n   if (z >= 0x10000) { n += 16; z >>= 16; }\n   if (z >= 0x00100) { n +=  8; z >>=  8; }\n   if (z >= 0x00010) { n +=  4; z >>=  4; }\n   if (z >= 0x00004) { n +=  2; z >>=  2; }\n   if (z >= 0x00002) { n +=  1;/* >>=  1;*/ }\n   return n;\n}\n\nstatic int stbi__bitcount(unsigned int a)\n{\n   a = (a & 0x55555555) + ((a >>  1) & 0x55555555); // max 2\n   a = (a & 0x33333333) + ((a >>  2) & 0x33333333); // max 4\n   a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits\n   a = (a + (a >> 8)); // max 16 per 8 bits\n   a = (a + (a >> 16)); // max 32 per 8 bits\n   return a & 0xff;\n}\n\n// extract an arbitrarily-aligned N-bit value (N=bits)\n// from v, and then make it 8-bits long and fractionally\n// extend it to full full range.\nstatic int stbi__shiftsigned(unsigned int v, int shift, int bits)\n{\n   static unsigned int mul_table[9] = {\n      0,\n      0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/,\n      0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/,\n   };\n   static unsigned int shift_table[9] = {\n      0, 0,0,1,0,2,4,6,0,\n   };\n   if (shift < 0)\n      v <<= -shift;\n   else\n      v >>= shift;\n   STBI_ASSERT(v < 256);\n   v >>= 
(8-bits);\n   STBI_ASSERT(bits >= 0 && bits <= 8);\n   return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits];\n}\n\ntypedef struct\n{\n   int bpp, offset, hsz;\n   unsigned int mr,mg,mb,ma, all_a;\n   int extra_read;\n} stbi__bmp_data;\n\nstatic int stbi__bmp_set_mask_defaults(stbi__bmp_data *info, int compress)\n{\n   // BI_BITFIELDS specifies masks explicitly, don't override\n   if (compress == 3)\n      return 1;\n\n   if (compress == 0) {\n      if (info->bpp == 16) {\n         info->mr = 31u << 10;\n         info->mg = 31u <<  5;\n         info->mb = 31u <<  0;\n      } else if (info->bpp == 32) {\n         info->mr = 0xffu << 16;\n         info->mg = 0xffu <<  8;\n         info->mb = 0xffu <<  0;\n         info->ma = 0xffu << 24;\n         info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0\n      } else {\n         // otherwise, use defaults, which is all-0\n         info->mr = info->mg = info->mb = info->ma = 0;\n      }\n      return 1;\n   }\n   return 0; // error\n}\n\nstatic void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info)\n{\n   int hsz;\n   if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc(\"not BMP\", \"Corrupt BMP\");\n   stbi__get32le(s); // discard filesize\n   stbi__get16le(s); // discard reserved\n   stbi__get16le(s); // discard reserved\n   info->offset = stbi__get32le(s);\n   info->hsz = hsz = stbi__get32le(s);\n   info->mr = info->mg = info->mb = info->ma = 0;\n   info->extra_read = 14;\n\n   if (info->offset < 0) return stbi__errpuc(\"bad BMP\", \"bad BMP\");\n\n   if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc(\"unknown BMP\", \"BMP type not supported: unknown\");\n   if (hsz == 12) {\n      s->img_x = stbi__get16le(s);\n      s->img_y = stbi__get16le(s);\n   } else {\n      s->img_x = stbi__get32le(s);\n      s->img_y = stbi__get32le(s);\n   }\n   if (stbi__get16le(s) != 1) return stbi__errpuc(\"bad BMP\", 
\"bad BMP\");\n   info->bpp = stbi__get16le(s);\n   if (hsz != 12) {\n      int compress = stbi__get32le(s);\n      if (compress == 1 || compress == 2) return stbi__errpuc(\"BMP RLE\", \"BMP type not supported: RLE\");\n      if (compress >= 4) return stbi__errpuc(\"BMP JPEG/PNG\", \"BMP type not supported: unsupported compression\"); // this includes PNG/JPEG modes\n      if (compress == 3 && info->bpp != 16 && info->bpp != 32) return stbi__errpuc(\"bad BMP\", \"bad BMP\"); // bitfields requires 16 or 32 bits/pixel\n      stbi__get32le(s); // discard sizeof\n      stbi__get32le(s); // discard hres\n      stbi__get32le(s); // discard vres\n      stbi__get32le(s); // discard colorsused\n      stbi__get32le(s); // discard max important\n      if (hsz == 40 || hsz == 56) {\n         if (hsz == 56) {\n            stbi__get32le(s);\n            stbi__get32le(s);\n            stbi__get32le(s);\n            stbi__get32le(s);\n         }\n         if (info->bpp == 16 || info->bpp == 32) {\n            if (compress == 0) {\n               stbi__bmp_set_mask_defaults(info, compress);\n            } else if (compress == 3) {\n               info->mr = stbi__get32le(s);\n               info->mg = stbi__get32le(s);\n               info->mb = stbi__get32le(s);\n               info->extra_read += 12;\n               // not documented, but generated by photoshop and handled by mspaint\n               if (info->mr == info->mg && info->mg == info->mb) {\n                  // ?!?!?\n                  return stbi__errpuc(\"bad BMP\", \"bad BMP\");\n               }\n            } else\n               return stbi__errpuc(\"bad BMP\", \"bad BMP\");\n         }\n      } else {\n         // V4/V5 header\n         int i;\n         if (hsz != 108 && hsz != 124)\n            return stbi__errpuc(\"bad BMP\", \"bad BMP\");\n         info->mr = stbi__get32le(s);\n         info->mg = stbi__get32le(s);\n         info->mb = stbi__get32le(s);\n         info->ma = stbi__get32le(s);\n         if 
(compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs\n            stbi__bmp_set_mask_defaults(info, compress);\n         stbi__get32le(s); // discard color space\n         for (i=0; i < 12; ++i)\n            stbi__get32le(s); // discard color space parameters\n         if (hsz == 124) {\n            stbi__get32le(s); // discard rendering intent\n            stbi__get32le(s); // discard offset of profile data\n            stbi__get32le(s); // discard size of profile data\n            stbi__get32le(s); // discard reserved\n         }\n      }\n   }\n   return (void *) 1;\n}\n\n\nstatic void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)\n{\n   stbi_uc *out;\n   unsigned int mr=0,mg=0,mb=0,ma=0, all_a;\n   stbi_uc pal[256][4];\n   int psize=0,i,j,width;\n   int flip_vertically, pad, target;\n   stbi__bmp_data info;\n   STBI_NOTUSED(ri);\n\n   info.all_a = 255;\n   if (stbi__bmp_parse_header(s, &info) == NULL)\n      return NULL; // error code already set\n\n   flip_vertically = ((int) s->img_y) > 0;\n   s->img_y = abs((int) s->img_y);\n\n   if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n   if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n\n   mr = info.mr;\n   mg = info.mg;\n   mb = info.mb;\n   ma = info.ma;\n   all_a = info.all_a;\n\n   if (info.hsz == 12) {\n      if (info.bpp < 24)\n         psize = (info.offset - info.extra_read - 24) / 3;\n   } else {\n      if (info.bpp < 16)\n         psize = (info.offset - info.extra_read - info.hsz) >> 2;\n   }\n   if (psize == 0) {\n      // accept some number of extra bytes after the header, but if the offset points either to before\n      // the header ends or implies a large amount of extra data, reject the file as malformed\n      int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original);\n      
int header_limit = 1024; // max we actually read is below 256 bytes currently.\n      int extra_data_limit = 256*4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size.\n      if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) {\n         return stbi__errpuc(\"bad header\", \"Corrupt BMP\");\n      }\n      // we established that bytes_read_so_far is positive and sensible.\n      // the first half of this test rejects offsets that are either too small positives, or\n      // negative, and guarantees that info.offset >= bytes_read_so_far > 0. this in turn\n      // ensures the number computed in the second half of the test can't overflow.\n      if (info.offset < bytes_read_so_far || info.offset - bytes_read_so_far > extra_data_limit) {\n         return stbi__errpuc(\"bad offset\", \"Corrupt BMP\");\n      } else {\n         stbi__skip(s, info.offset - bytes_read_so_far);\n      }\n   }\n\n   if (info.bpp == 24 && ma == 0xff000000)\n      s->img_n = 3;\n   else\n      s->img_n = ma ? 4 : 3;\n   if (req_comp && req_comp >= 3) // we can directly decode 3 or 4\n      target = req_comp;\n   else\n      target = s->img_n; // if they want monochrome, we'll post-convert\n\n   // sanity-check size\n   if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0))\n      return stbi__errpuc(\"too large\", \"Corrupt BMP\");\n\n   out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0);\n   if (!out) return stbi__errpuc(\"outofmem\", \"Out of memory\");\n   if (info.bpp < 16) {\n      int z=0;\n      if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc(\"invalid\", \"Corrupt BMP\"); }\n      for (i=0; i < psize; ++i) {\n         pal[i][2] = stbi__get8(s);\n         pal[i][1] = stbi__get8(s);\n         pal[i][0] = stbi__get8(s);\n         if (info.hsz != 12) stbi__get8(s);\n         pal[i][3] = 255;\n      }\n      stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 
3 : 4));\n      if (info.bpp == 1) width = (s->img_x + 7) >> 3;\n      else if (info.bpp == 4) width = (s->img_x + 1) >> 1;\n      else if (info.bpp == 8) width = s->img_x;\n      else { STBI_FREE(out); return stbi__errpuc(\"bad bpp\", \"Corrupt BMP\"); }\n      pad = (-width)&3;\n      if (info.bpp == 1) {\n         for (j=0; j < (int) s->img_y; ++j) {\n            int bit_offset = 7, v = stbi__get8(s);\n            for (i=0; i < (int) s->img_x; ++i) {\n               int color = (v>>bit_offset)&0x1;\n               out[z++] = pal[color][0];\n               out[z++] = pal[color][1];\n               out[z++] = pal[color][2];\n               if (target == 4) out[z++] = 255;\n               if (i+1 == (int) s->img_x) break;\n               if((--bit_offset) < 0) {\n                  bit_offset = 7;\n                  v = stbi__get8(s);\n               }\n            }\n            stbi__skip(s, pad);\n         }\n      } else {\n         for (j=0; j < (int) s->img_y; ++j) {\n            for (i=0; i < (int) s->img_x; i += 2) {\n               int v=stbi__get8(s),v2=0;\n               if (info.bpp == 4) {\n                  v2 = v & 15;\n                  v >>= 4;\n               }\n               out[z++] = pal[v][0];\n               out[z++] = pal[v][1];\n               out[z++] = pal[v][2];\n               if (target == 4) out[z++] = 255;\n               if (i+1 == (int) s->img_x) break;\n               v = (info.bpp == 8) ? 
stbi__get8(s) : v2;\n               out[z++] = pal[v][0];\n               out[z++] = pal[v][1];\n               out[z++] = pal[v][2];\n               if (target == 4) out[z++] = 255;\n            }\n            stbi__skip(s, pad);\n         }\n      }\n   } else {\n      int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0;\n      int z = 0;\n      int easy=0;\n      stbi__skip(s, info.offset - info.extra_read - info.hsz);\n      if (info.bpp == 24) width = 3 * s->img_x;\n      else if (info.bpp == 16) width = 2*s->img_x;\n      else /* bpp = 32 and pad = 0 */ width=0;\n      pad = (-width) & 3;\n      if (info.bpp == 24) {\n         easy = 1;\n      } else if (info.bpp == 32) {\n         if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000)\n            easy = 2;\n      }\n      if (!easy) {\n         if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc(\"bad masks\", \"Corrupt BMP\"); }\n         // right shift amt to put high bit in position #7\n         rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr);\n         gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg);\n         bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb);\n         ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma);\n         if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc(\"bad masks\", \"Corrupt BMP\"); }\n      }\n      for (j=0; j < (int) s->img_y; ++j) {\n         if (easy) {\n            for (i=0; i < (int) s->img_x; ++i) {\n               unsigned char a;\n               out[z+2] = stbi__get8(s);\n               out[z+1] = stbi__get8(s);\n               out[z+0] = stbi__get8(s);\n               z += 3;\n               a = (easy == 2 ? 
stbi__get8(s) : 255);\n               all_a |= a;\n               if (target == 4) out[z++] = a;\n            }\n         } else {\n            int bpp = info.bpp;\n            for (i=0; i < (int) s->img_x; ++i) {\n               stbi__uint32 v = (bpp == 16 ? (stbi__uint32) stbi__get16le(s) : stbi__get32le(s));\n               unsigned int a;\n               out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount));\n               out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount));\n               out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount));\n               a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255);\n               all_a |= a;\n               if (target == 4) out[z++] = STBI__BYTECAST(a);\n            }\n         }\n         stbi__skip(s, pad);\n      }\n   }\n\n   // if alpha channel is all 0s, replace with all 255s\n   if (target == 4 && all_a == 0)\n      for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4)\n         out[i] = 255;\n\n   if (flip_vertically) {\n      stbi_uc t;\n      for (j=0; j < (int) s->img_y>>1; ++j) {\n         stbi_uc *p1 = out +      j     *s->img_x*target;\n         stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target;\n         for (i=0; i < (int) s->img_x*target; ++i) {\n            t = p1[i]; p1[i] = p2[i]; p2[i] = t;\n         }\n      }\n   }\n\n   if (req_comp && req_comp != target) {\n      out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y);\n      if (out == NULL) return out; // stbi__convert_format frees input on failure\n   }\n\n   *x = s->img_x;\n   *y = s->img_y;\n   if (comp) *comp = s->img_n;\n   return out;\n}\n#endif\n\n// Targa Truevision - TGA\n// by Jonathan Dummer\n#ifndef STBI_NO_TGA\n// returns STBI_rgb or whatever, 0 on error\nstatic int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16)\n{\n   // only RGB or RGBA (incl. 
16bit) or grey allowed\n   if (is_rgb16) *is_rgb16 = 0;\n   switch(bits_per_pixel) {\n      case 8:  return STBI_grey;\n      case 16: if(is_grey) return STBI_grey_alpha;\n               // fallthrough\n      case 15: if(is_rgb16) *is_rgb16 = 1;\n               return STBI_rgb;\n      case 24: // fallthrough\n      case 32: return bits_per_pixel/8;\n      default: return 0;\n   }\n}\n\nstatic int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp)\n{\n    int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp;\n    int sz, tga_colormap_type;\n    stbi__get8(s);                   // discard Offset\n    tga_colormap_type = stbi__get8(s); // colormap type\n    if( tga_colormap_type > 1 ) {\n        stbi__rewind(s);\n        return 0;      // only RGB or indexed allowed\n    }\n    tga_image_type = stbi__get8(s); // image type\n    if ( tga_colormap_type == 1 ) { // colormapped (paletted) image\n        if (tga_image_type != 1 && tga_image_type != 9) {\n            stbi__rewind(s);\n            return 0;\n        }\n        stbi__skip(s,4);       // skip index of first colormap entry and number of entries\n        sz = stbi__get8(s);    //   check bits per palette color entry\n        if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) {\n            stbi__rewind(s);\n            return 0;\n        }\n        stbi__skip(s,4);       // skip image x and y origin\n        tga_colormap_bpp = sz;\n    } else { // \"normal\" image w/o colormap - only RGB or grey allowed, +/- RLE\n        if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) {\n            stbi__rewind(s);\n            return 0; // only RGB or grey allowed, +/- RLE\n        }\n        stbi__skip(s,9); // skip colormap specification and image x/y origin\n        tga_colormap_bpp = 0;\n    }\n    tga_w = stbi__get16le(s);\n    if( tga_w < 1 ) {\n        stbi__rewind(s);\n        return 0;   // test width\n 
   }\n    tga_h = stbi__get16le(s);\n    if( tga_h < 1 ) {\n        stbi__rewind(s);\n        return 0;   // test height\n    }\n    tga_bits_per_pixel = stbi__get8(s); // bits per pixel\n    stbi__get8(s); // ignore alpha bits\n    if (tga_colormap_bpp != 0) {\n        if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) {\n            // when using a colormap, tga_bits_per_pixel is the size of the indexes\n            // I don't think anything but 8 or 16bit indexes makes sense\n            stbi__rewind(s);\n            return 0;\n        }\n        tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL);\n    } else {\n        tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL);\n    }\n    if(!tga_comp) {\n      stbi__rewind(s);\n      return 0;\n    }\n    if (x) *x = tga_w;\n    if (y) *y = tga_h;\n    if (comp) *comp = tga_comp;\n    return 1;                   // seems to have passed everything\n}\n\nstatic int stbi__tga_test(stbi__context *s)\n{\n   int res = 0;\n   int sz, tga_color_type;\n   stbi__get8(s);      //   discard Offset\n   tga_color_type = stbi__get8(s);   //   color type\n   if ( tga_color_type > 1 ) goto errorEnd;   //   only RGB or indexed allowed\n   sz = stbi__get8(s);   //   image type\n   if ( tga_color_type == 1 ) { // colormapped (paletted) image\n      if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9\n      stbi__skip(s,4);       // skip index of first colormap entry and number of entries\n      sz = stbi__get8(s);    //   check bits per palette color entry\n      if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd;\n      stbi__skip(s,4);       // skip image x and y origin\n   } else { // \"normal\" image w/o colormap\n      if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE\n      stbi__skip(s,9); // skip colormap specification and image x/y 
origin\n   }\n   if ( stbi__get16le(s) < 1 ) goto errorEnd;      //   test width\n   if ( stbi__get16le(s) < 1 ) goto errorEnd;      //   test height\n   sz = stbi__get8(s);   //   bits per pixel\n   if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index\n   if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd;\n\n   res = 1; // if we got this far, everything's good and we can return 1 instead of 0\n\nerrorEnd:\n   stbi__rewind(s);\n   return res;\n}\n\n// read 16bit value and convert to 24bit RGB\nstatic void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out)\n{\n   stbi__uint16 px = (stbi__uint16)stbi__get16le(s);\n   stbi__uint16 fiveBitMask = 31;\n   // we have 3 channels with 5bits each\n   int r = (px >> 10) & fiveBitMask;\n   int g = (px >> 5) & fiveBitMask;\n   int b = px & fiveBitMask;\n   // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later\n   out[0] = (stbi_uc)((r * 255)/31);\n   out[1] = (stbi_uc)((g * 255)/31);\n   out[2] = (stbi_uc)((b * 255)/31);\n\n   // some people claim that the most significant bit might be used for alpha\n   // (possibly if an alpha-bit is set in the \"image descriptor byte\")\n   // but that only made 16bit test images completely translucent..\n   // so let's treat all 15 and 16bit TGAs as RGB with no alpha.\n}\n\nstatic void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)\n{\n   //   read in the TGA header stuff\n   int tga_offset = stbi__get8(s);\n   int tga_indexed = stbi__get8(s);\n   int tga_image_type = stbi__get8(s);\n   int tga_is_RLE = 0;\n   int tga_palette_start = stbi__get16le(s);\n   int tga_palette_len = stbi__get16le(s);\n   int tga_palette_bits = stbi__get8(s);\n   int tga_x_origin = stbi__get16le(s);\n   int tga_y_origin = stbi__get16le(s);\n   int tga_width = stbi__get16le(s);\n   int tga_height = stbi__get16le(s);\n   int 
tga_bits_per_pixel = stbi__get8(s);\n   int tga_comp, tga_rgb16=0;\n   int tga_inverted = stbi__get8(s);\n   // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?)\n   //   image data\n   unsigned char *tga_data;\n   unsigned char *tga_palette = NULL;\n   int i, j;\n   unsigned char raw_data[4] = {0};\n   int RLE_count = 0;\n   int RLE_repeating = 0;\n   int read_next_pixel = 1;\n   STBI_NOTUSED(ri);\n   STBI_NOTUSED(tga_x_origin); // @TODO\n   STBI_NOTUSED(tga_y_origin); // @TODO\n\n   if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n   if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n\n   //   do a tiny bit of precessing\n   if ( tga_image_type >= 8 )\n   {\n      tga_image_type -= 8;\n      tga_is_RLE = 1;\n   }\n   tga_inverted = 1 - ((tga_inverted >> 5) & 1);\n\n   //   If I'm paletted, then I'll use the number of bits from the palette\n   if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16);\n   else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16);\n\n   if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency\n      return stbi__errpuc(\"bad format\", \"Can't find out TGA pixelformat\");\n\n   //   tga info\n   *x = tga_width;\n   *y = tga_height;\n   if (comp) *comp = tga_comp;\n\n   if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0))\n      return stbi__errpuc(\"too large\", \"Corrupt TGA\");\n\n   tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0);\n   if (!tga_data) return stbi__errpuc(\"outofmem\", \"Out of memory\");\n\n   // skip to the data's starting position (offset usually = 0)\n   stbi__skip(s, tga_offset );\n\n   if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) {\n      for (i=0; i < tga_height; ++i) {\n         int row = tga_inverted ? 
tga_height -i - 1 : i;\n         stbi_uc *tga_row = tga_data + row*tga_width*tga_comp;\n         stbi__getn(s, tga_row, tga_width * tga_comp);\n      }\n   } else  {\n      //   do I need to load a palette?\n      if ( tga_indexed)\n      {\n         if (tga_palette_len == 0) {  /* you have to have at least one entry! */\n            STBI_FREE(tga_data);\n            return stbi__errpuc(\"bad palette\", \"Corrupt TGA\");\n         }\n\n         //   any data to skip? (offset usually = 0)\n         stbi__skip(s, tga_palette_start );\n         //   load the palette\n         tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0);\n         if (!tga_palette) {\n            STBI_FREE(tga_data);\n            return stbi__errpuc(\"outofmem\", \"Out of memory\");\n         }\n         if (tga_rgb16) {\n            stbi_uc *pal_entry = tga_palette;\n            STBI_ASSERT(tga_comp == STBI_rgb);\n            for (i=0; i < tga_palette_len; ++i) {\n               stbi__tga_read_rgb16(s, pal_entry);\n               pal_entry += tga_comp;\n            }\n         } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) {\n               STBI_FREE(tga_data);\n               STBI_FREE(tga_palette);\n               return stbi__errpuc(\"bad palette\", \"Corrupt TGA\");\n         }\n      }\n      //   load the data\n      for (i=0; i < tga_width * tga_height; ++i)\n      {\n         //   if I'm in RLE mode, do I need to get a RLE stbi__pngchunk?\n         if ( tga_is_RLE )\n         {\n            if ( RLE_count == 0 )\n            {\n               //   yep, get the next byte as a RLE command\n               int RLE_cmd = stbi__get8(s);\n               RLE_count = 1 + (RLE_cmd & 127);\n               RLE_repeating = RLE_cmd >> 7;\n               read_next_pixel = 1;\n            } else if ( !RLE_repeating )\n            {\n               read_next_pixel = 1;\n            }\n         } else\n         {\n            read_next_pixel = 1;\n         
}\n         //   OK, if I need to read a pixel, do it now\n         if ( read_next_pixel )\n         {\n            //   load however much data we did have\n            if ( tga_indexed )\n            {\n               // read in index, then perform the lookup\n               int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s);\n               if ( pal_idx >= tga_palette_len ) {\n                  // invalid index\n                  pal_idx = 0;\n               }\n               pal_idx *= tga_comp;\n               for (j = 0; j < tga_comp; ++j) {\n                  raw_data[j] = tga_palette[pal_idx+j];\n               }\n            } else if(tga_rgb16) {\n               STBI_ASSERT(tga_comp == STBI_rgb);\n               stbi__tga_read_rgb16(s, raw_data);\n            } else {\n               //   read in the data raw\n               for (j = 0; j < tga_comp; ++j) {\n                  raw_data[j] = stbi__get8(s);\n               }\n            }\n            //   clear the reading flag for the next pixel\n            read_next_pixel = 0;\n         } // end of reading a pixel\n\n         // copy data\n         for (j = 0; j < tga_comp; ++j)\n           tga_data[i*tga_comp+j] = raw_data[j];\n\n         //   in case we're in RLE mode, keep counting down\n         --RLE_count;\n      }\n      //   do I need to invert the image?\n      if ( tga_inverted )\n      {\n         for (j = 0; j*2 < tga_height; ++j)\n         {\n            int index1 = j * tga_width * tga_comp;\n            int index2 = (tga_height - 1 - j) * tga_width * tga_comp;\n            for (i = tga_width * tga_comp; i > 0; --i)\n            {\n               unsigned char temp = tga_data[index1];\n               tga_data[index1] = tga_data[index2];\n               tga_data[index2] = temp;\n               ++index1;\n               ++index2;\n            }\n         }\n      }\n      //   clear my palette, if I had one\n      if ( tga_palette != NULL )\n      {\n         
STBI_FREE( tga_palette );\n      }\n   }\n\n   // swap RGB - if the source data was RGB16, it already is in the right order\n   if (tga_comp >= 3 && !tga_rgb16)\n   {\n      unsigned char* tga_pixel = tga_data;\n      for (i=0; i < tga_width * tga_height; ++i)\n      {\n         unsigned char temp = tga_pixel[0];\n         tga_pixel[0] = tga_pixel[2];\n         tga_pixel[2] = temp;\n         tga_pixel += tga_comp;\n      }\n   }\n\n   // convert to target component count\n   if (req_comp && req_comp != tga_comp)\n      tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height);\n\n   //   the things I do to get rid of an error message, and yet keep\n   //   Microsoft's C compilers happy... [8^(\n   tga_palette_start = tga_palette_len = tga_palette_bits =\n         tga_x_origin = tga_y_origin = 0;\n   STBI_NOTUSED(tga_palette_start);\n   //   OK, done\n   return tga_data;\n}\n#endif\n\n// *************************************************************************************************\n// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB\n\n#ifndef STBI_NO_PSD\nstatic int stbi__psd_test(stbi__context *s)\n{\n   int r = (stbi__get32be(s) == 0x38425053);\n   stbi__rewind(s);\n   return r;\n}\n\nstatic int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount)\n{\n   int count, nleft, len;\n\n   count = 0;\n   while ((nleft = pixelCount - count) > 0) {\n      len = stbi__get8(s);\n      if (len == 128) {\n         // No-op.\n      } else if (len < 128) {\n         // Copy next len+1 bytes literally.\n         len++;\n         if (len > nleft) return 0; // corrupt data\n         count += len;\n         while (len) {\n            *p = stbi__get8(s);\n            p += 4;\n            len--;\n         }\n      } else if (len > 128) {\n         stbi_uc   val;\n         // Next -len+1 bytes in the dest are replicated from next source byte.\n         // (Interpret len as a negative 8-bit int.)\n  
       len = 257 - len;\n         if (len > nleft) return 0; // corrupt data\n         val = stbi__get8(s);\n         count += len;\n         while (len) {\n            *p = val;\n            p += 4;\n            len--;\n         }\n      }\n   }\n\n   return 1;\n}\n\nstatic void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc)\n{\n   int pixelCount;\n   int channelCount, compression;\n   int channel, i;\n   int bitdepth;\n   int w,h;\n   stbi_uc *out;\n   STBI_NOTUSED(ri);\n\n   // Check identifier\n   if (stbi__get32be(s) != 0x38425053)   // \"8BPS\"\n      return stbi__errpuc(\"not PSD\", \"Corrupt PSD image\");\n\n   // Check file type version.\n   if (stbi__get16be(s) != 1)\n      return stbi__errpuc(\"wrong version\", \"Unsupported version of PSD image\");\n\n   // Skip 6 reserved bytes.\n   stbi__skip(s, 6 );\n\n   // Read the number of channels (R, G, B, A, etc).\n   channelCount = stbi__get16be(s);\n   if (channelCount < 0 || channelCount > 16)\n      return stbi__errpuc(\"wrong channel count\", \"Unsupported number of channels in PSD image\");\n\n   // Read the rows and columns of the image.\n   h = stbi__get32be(s);\n   w = stbi__get32be(s);\n\n   if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n   if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n\n   // Make sure the depth is 8 bits.\n   bitdepth = stbi__get16be(s);\n   if (bitdepth != 8 && bitdepth != 16)\n      return stbi__errpuc(\"unsupported bit depth\", \"PSD bit depth is not 8 or 16 bit\");\n\n   // Make sure the color mode is RGB.\n   // Valid options are:\n   //   0: Bitmap\n   //   1: Grayscale\n   //   2: Indexed color\n   //   3: RGB color\n   //   4: CMYK color\n   //   7: Multichannel\n   //   8: Duotone\n   //   9: Lab color\n   if (stbi__get16be(s) != 3)\n      return stbi__errpuc(\"wrong color format\", \"PSD is not in RGB color 
format\");\n\n   // Skip the Mode Data.  (It's the palette for indexed color; other info for other modes.)\n   stbi__skip(s,stbi__get32be(s) );\n\n   // Skip the image resources.  (resolution, pen tool paths, etc)\n   stbi__skip(s, stbi__get32be(s) );\n\n   // Skip the reserved data.\n   stbi__skip(s, stbi__get32be(s) );\n\n   // Find out if the data is compressed.\n   // Known values:\n   //   0: no compression\n   //   1: RLE compressed\n   compression = stbi__get16be(s);\n   if (compression > 1)\n      return stbi__errpuc(\"bad compression\", \"PSD has an unknown compression format\");\n\n   // Check size\n   if (!stbi__mad3sizes_valid(4, w, h, 0))\n      return stbi__errpuc(\"too large\", \"Corrupt PSD\");\n\n   // Create the destination image.\n\n   if (!compression && bitdepth == 16 && bpc == 16) {\n      out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0);\n      ri->bits_per_channel = 16;\n   } else\n      out = (stbi_uc *) stbi__malloc(4 * w*h);\n\n   if (!out) return stbi__errpuc(\"outofmem\", \"Out of memory\");\n   pixelCount = w*h;\n\n   // Initialize the data to zero.\n   //memset( out, 0, pixelCount * 4 );\n\n   // Finally, the image data.\n   if (compression) {\n      // RLE as used by .PSD and .TIFF\n      // Loop until you get the number of unpacked bytes you are expecting:\n      //     Read the next source byte into n.\n      //     If n is between 0 and 127 inclusive, copy the next n+1 bytes literally.\n      //     Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times.\n      //     Else if n is 128, noop.\n      // Endloop\n\n      // The RLE-compressed data is preceded by a 2-byte data count for each row in the data,\n      // which we're going to just skip.\n      stbi__skip(s, h * channelCount * 2 );\n\n      // Read the RLE data by channel.\n      for (channel = 0; channel < 4; channel++) {\n         stbi_uc *p;\n\n         p = out+channel;\n         if (channel >= channelCount) {\n            // Fill this channel with 
default data.\n            for (i = 0; i < pixelCount; i++, p += 4)\n               *p = (channel == 3 ? 255 : 0);\n         } else {\n            // Read the RLE data.\n            if (!stbi__psd_decode_rle(s, p, pixelCount)) {\n               STBI_FREE(out);\n               return stbi__errpuc(\"corrupt\", \"bad RLE data\");\n            }\n         }\n      }\n\n   } else {\n      // We're at the raw image data.  It's each channel in order (Red, Green, Blue, Alpha, ...)\n      // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image.\n\n      // Read the data by channel.\n      for (channel = 0; channel < 4; channel++) {\n         if (channel >= channelCount) {\n            // Fill this channel with default data.\n            if (bitdepth == 16 && bpc == 16) {\n               stbi__uint16 *q = ((stbi__uint16 *) out) + channel;\n               stbi__uint16 val = channel == 3 ? 65535 : 0;\n               for (i = 0; i < pixelCount; i++, q += 4)\n                  *q = val;\n            } else {\n               stbi_uc *p = out+channel;\n               stbi_uc val = channel == 3 ? 
255 : 0;\n               for (i = 0; i < pixelCount; i++, p += 4)\n                  *p = val;\n            }\n         } else {\n            if (ri->bits_per_channel == 16) {    // output bpc\n               stbi__uint16 *q = ((stbi__uint16 *) out) + channel;\n               for (i = 0; i < pixelCount; i++, q += 4)\n                  *q = (stbi__uint16) stbi__get16be(s);\n            } else {\n               stbi_uc *p = out+channel;\n               if (bitdepth == 16) {  // input bpc\n                  for (i = 0; i < pixelCount; i++, p += 4)\n                     *p = (stbi_uc) (stbi__get16be(s) >> 8);\n               } else {\n                  for (i = 0; i < pixelCount; i++, p += 4)\n                     *p = stbi__get8(s);\n               }\n            }\n         }\n      }\n   }\n\n   // remove weird white matte from PSD\n   if (channelCount >= 4) {\n      if (ri->bits_per_channel == 16) {\n         for (i=0; i < w*h; ++i) {\n            stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i;\n            if (pixel[3] != 0 && pixel[3] != 65535) {\n               float a = pixel[3] / 65535.0f;\n               float ra = 1.0f / a;\n               float inv_a = 65535.0f * (1 - ra);\n               pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a);\n               pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a);\n               pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a);\n            }\n         }\n      } else {\n         for (i=0; i < w*h; ++i) {\n            unsigned char *pixel = out + 4*i;\n            if (pixel[3] != 0 && pixel[3] != 255) {\n               float a = pixel[3] / 255.0f;\n               float ra = 1.0f / a;\n               float inv_a = 255.0f * (1 - ra);\n               pixel[0] = (unsigned char) (pixel[0]*ra + inv_a);\n               pixel[1] = (unsigned char) (pixel[1]*ra + inv_a);\n               pixel[2] = (unsigned char) (pixel[2]*ra + inv_a);\n            }\n         }\n      }\n   }\n\n   // convert to desired output format\n   if 
(req_comp && req_comp != 4) {\n      if (ri->bits_per_channel == 16)\n         out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h);\n      else\n         out = stbi__convert_format(out, 4, req_comp, w, h);\n      if (out == NULL) return out; // stbi__convert_format frees input on failure\n   }\n\n   if (comp) *comp = 4;\n   *y = h;\n   *x = w;\n\n   return out;\n}\n#endif\n\n// *************************************************************************************************\n// Softimage PIC loader\n// by Tom Seddon\n//\n// See http://softimage.wiki.softimage.com/index.php/INFO:_PIC_file_format\n// See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/\n\n#ifndef STBI_NO_PIC\nstatic int stbi__pic_is4(stbi__context *s,const char *str)\n{\n   int i;\n   for (i=0; i<4; ++i)\n      if (stbi__get8(s) != (stbi_uc)str[i])\n         return 0;\n\n   return 1;\n}\n\nstatic int stbi__pic_test_core(stbi__context *s)\n{\n   int i;\n\n   if (!stbi__pic_is4(s,\"\\x53\\x80\\xF6\\x34\"))\n      return 0;\n\n   for(i=0;i<84;++i)\n      stbi__get8(s);\n\n   if (!stbi__pic_is4(s,\"PICT\"))\n      return 0;\n\n   return 1;\n}\n\ntypedef struct\n{\n   stbi_uc size,type,channel;\n} stbi__pic_packet;\n\nstatic stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest)\n{\n   int mask=0x80, i;\n\n   for (i=0; i<4; ++i, mask>>=1) {\n      if (channel & mask) {\n         if (stbi__at_eof(s)) return stbi__errpuc(\"bad file\",\"PIC file too short\");\n         dest[i]=stbi__get8(s);\n      }\n   }\n\n   return dest;\n}\n\nstatic void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src)\n{\n   int mask=0x80,i;\n\n   for (i=0;i<4; ++i, mask>>=1)\n      if (channel&mask)\n         dest[i]=src[i];\n}\n\nstatic stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result)\n{\n   int act_comp=0,num_packets=0,y,chained;\n   stbi__pic_packet packets[10];\n\n   // this will (should...) 
cater for even some bizarre stuff like having data\n    // for the same channel in multiple packets.\n   do {\n      stbi__pic_packet *packet;\n\n      if (num_packets==sizeof(packets)/sizeof(packets[0]))\n         return stbi__errpuc(\"bad format\",\"too many packets\");\n\n      packet = &packets[num_packets++];\n\n      chained = stbi__get8(s);\n      packet->size    = stbi__get8(s);\n      packet->type    = stbi__get8(s);\n      packet->channel = stbi__get8(s);\n\n      act_comp |= packet->channel;\n\n      if (stbi__at_eof(s))          return stbi__errpuc(\"bad file\",\"file too short (reading packets)\");\n      if (packet->size != 8)  return stbi__errpuc(\"bad format\",\"packet isn't 8bpp\");\n   } while (chained);\n\n   *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel?\n\n   for(y=0; y<height; ++y) {\n      int packet_idx;\n\n      for(packet_idx=0; packet_idx < num_packets; ++packet_idx) {\n         stbi__pic_packet *packet = &packets[packet_idx];\n         stbi_uc *dest = result+y*width*4;\n\n         switch (packet->type) {\n            default:\n               return stbi__errpuc(\"bad format\",\"packet has bad compression type\");\n\n            case 0: {//uncompressed\n               int x;\n\n               for(x=0;x<width;++x, dest+=4)\n                  if (!stbi__readval(s,packet->channel,dest))\n                     return 0;\n               break;\n            }\n\n            case 1://Pure RLE\n               {\n                  int left=width, i;\n\n                  while (left>0) {\n                     stbi_uc count,value[4];\n\n                     count=stbi__get8(s);\n                     if (stbi__at_eof(s))   return stbi__errpuc(\"bad file\",\"file too short (pure read count)\");\n\n                     if (count > left)\n                        count = (stbi_uc) left;\n\n                     if (!stbi__readval(s,packet->channel,value))  return 0;\n\n                     for(i=0; i<count; ++i,dest+=4)\n                        
stbi__copyval(packet->channel,dest,value);\n                     left -= count;\n                  }\n               }\n               break;\n\n            case 2: {//Mixed RLE\n               int left=width;\n               while (left>0) {\n                  int count = stbi__get8(s), i;\n                  if (stbi__at_eof(s))  return stbi__errpuc(\"bad file\",\"file too short (mixed read count)\");\n\n                  if (count >= 128) { // Repeated\n                     stbi_uc value[4];\n\n                     if (count==128)\n                        count = stbi__get16be(s);\n                     else\n                        count -= 127;\n                     if (count > left)\n                        return stbi__errpuc(\"bad file\",\"scanline overrun\");\n\n                     if (!stbi__readval(s,packet->channel,value))\n                        return 0;\n\n                     for(i=0;i<count;++i, dest += 4)\n                        stbi__copyval(packet->channel,dest,value);\n                  } else { // Raw\n                     ++count;\n                     if (count>left) return stbi__errpuc(\"bad file\",\"scanline overrun\");\n\n                     for(i=0;i<count;++i, dest+=4)\n                        if (!stbi__readval(s,packet->channel,dest))\n                           return 0;\n                  }\n                  left-=count;\n               }\n               break;\n            }\n         }\n      }\n   }\n\n   return result;\n}\n\nstatic void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri)\n{\n   stbi_uc *result;\n   int i, x,y, internal_comp;\n   STBI_NOTUSED(ri);\n\n   if (!comp) comp = &internal_comp;\n\n   for (i=0; i<92; ++i)\n      stbi__get8(s);\n\n   x = stbi__get16be(s);\n   y = stbi__get16be(s);\n\n   if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n   if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large 
image (corrupt?)\");\n\n   if (stbi__at_eof(s))  return stbi__errpuc(\"bad file\",\"file too short (pic header)\");\n   if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc(\"too large\", \"PIC image too large to decode\");\n\n   stbi__get32be(s); //skip `ratio'\n   stbi__get16be(s); //skip `fields'\n   stbi__get16be(s); //skip `pad'\n\n   // intermediate buffer is RGBA\n   result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0);\n   if (!result) return stbi__errpuc(\"outofmem\", \"Out of memory\");\n   memset(result, 0xff, x*y*4);\n\n   if (!stbi__pic_load_core(s,x,y,comp, result)) {\n      STBI_FREE(result);\n      result=0;\n   }\n   *px = x;\n   *py = y;\n   if (req_comp == 0) req_comp = *comp;\n   result=stbi__convert_format(result,4,req_comp,x,y);\n\n   return result;\n}\n\nstatic int stbi__pic_test(stbi__context *s)\n{\n   int r = stbi__pic_test_core(s);\n   stbi__rewind(s);\n   return r;\n}\n#endif\n\n// *************************************************************************************************\n// GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb\n\n#ifndef STBI_NO_GIF\ntypedef struct\n{\n   stbi__int16 prefix;\n   stbi_uc first;\n   stbi_uc suffix;\n} stbi__gif_lzw;\n\ntypedef struct\n{\n   int w,h;\n   stbi_uc *out;                 // output buffer (always 4 components)\n   stbi_uc *background;          // The current \"background\" as far as a gif is concerned\n   stbi_uc *history;\n   int flags, bgindex, ratio, transparent, eflags;\n   stbi_uc  pal[256][4];\n   stbi_uc lpal[256][4];\n   stbi__gif_lzw codes[8192];\n   stbi_uc *color_table;\n   int parse, step;\n   int lflags;\n   int start_x, start_y;\n   int max_x, max_y;\n   int cur_x, cur_y;\n   int line_size;\n   int delay;\n} stbi__gif;\n\nstatic int stbi__gif_test_raw(stbi__context *s)\n{\n   int sz;\n   if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0;\n   sz = stbi__get8(s);\n   if (sz != '9' && sz != 
'7') return 0;\n   if (stbi__get8(s) != 'a') return 0;\n   return 1;\n}\n\nstatic int stbi__gif_test(stbi__context *s)\n{\n   int r = stbi__gif_test_raw(s);\n   stbi__rewind(s);\n   return r;\n}\n\nstatic void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp)\n{\n   int i;\n   for (i=0; i < num_entries; ++i) {\n      pal[i][2] = stbi__get8(s);\n      pal[i][1] = stbi__get8(s);\n      pal[i][0] = stbi__get8(s);\n      pal[i][3] = transp == i ? 0 : 255;\n   }\n}\n\nstatic int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info)\n{\n   stbi_uc version;\n   if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8')\n      return stbi__err(\"not GIF\", \"Corrupt GIF\");\n\n   version = stbi__get8(s);\n   if (version != '7' && version != '9')    return stbi__err(\"not GIF\", \"Corrupt GIF\");\n   if (stbi__get8(s) != 'a')                return stbi__err(\"not GIF\", \"Corrupt GIF\");\n\n   stbi__g_failure_reason = \"\";\n   g->w = stbi__get16le(s);\n   g->h = stbi__get16le(s);\n   g->flags = stbi__get8(s);\n   g->bgindex = stbi__get8(s);\n   g->ratio = stbi__get8(s);\n   g->transparent = -1;\n\n   if (g->w > STBI_MAX_DIMENSIONS) return stbi__err(\"too large\",\"Very large image (corrupt?)\");\n   if (g->h > STBI_MAX_DIMENSIONS) return stbi__err(\"too large\",\"Very large image (corrupt?)\");\n\n   if (comp != 0) *comp = 4;  // can't actually tell whether it's 3 or 4 until we parse the comments\n\n   if (is_info) return 1;\n\n   if (g->flags & 0x80)\n      stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1);\n\n   return 1;\n}\n\nstatic int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp)\n{\n   stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif));\n   if (!g) return stbi__err(\"outofmem\", \"Out of memory\");\n   if (!stbi__gif_header(s, g, comp, 1)) {\n      STBI_FREE(g);\n      stbi__rewind( s );\n      return 0;\n   }\n   if (x) *x = 
g->w;\n   if (y) *y = g->h;\n   STBI_FREE(g);\n   return 1;\n}\n\nstatic void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code)\n{\n   stbi_uc *p, *c;\n   int idx;\n\n   // recurse to decode the prefixes, since the linked-list is backwards,\n   // and working backwards through an interleaved image would be nasty\n   if (g->codes[code].prefix >= 0)\n      stbi__out_gif_code(g, g->codes[code].prefix);\n\n   if (g->cur_y >= g->max_y) return;\n\n   idx = g->cur_x + g->cur_y;\n   p = &g->out[idx];\n   g->history[idx / 4] = 1;\n\n   c = &g->color_table[g->codes[code].suffix * 4];\n   if (c[3] > 128) { // don't render transparent pixels;\n      p[0] = c[2];\n      p[1] = c[1];\n      p[2] = c[0];\n      p[3] = c[3];\n   }\n   g->cur_x += 4;\n\n   if (g->cur_x >= g->max_x) {\n      g->cur_x = g->start_x;\n      g->cur_y += g->step;\n\n      while (g->cur_y >= g->max_y && g->parse > 0) {\n         g->step = (1 << g->parse) * g->line_size;\n         g->cur_y = g->start_y + (g->step >> 1);\n         --g->parse;\n      }\n   }\n}\n\nstatic stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g)\n{\n   stbi_uc lzw_cs;\n   stbi__int32 len, init_code;\n   stbi__uint32 first;\n   stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear;\n   stbi__gif_lzw *p;\n\n   lzw_cs = stbi__get8(s);\n   if (lzw_cs > 12) return NULL;\n   clear = 1 << lzw_cs;\n   first = 1;\n   codesize = lzw_cs + 1;\n   codemask = (1 << codesize) - 1;\n   bits = 0;\n   valid_bits = 0;\n   for (init_code = 0; init_code < clear; init_code++) {\n      g->codes[init_code].prefix = -1;\n      g->codes[init_code].first = (stbi_uc) init_code;\n      g->codes[init_code].suffix = (stbi_uc) init_code;\n   }\n\n   // support no starting clear code\n   avail = clear+2;\n   oldcode = -1;\n\n   len = 0;\n   for(;;) {\n      if (valid_bits < codesize) {\n         if (len == 0) {\n            len = stbi__get8(s); // start new block\n            if (len == 0)\n               return g->out;\n         
}\n         --len;\n         bits |= (stbi__int32) stbi__get8(s) << valid_bits;\n         valid_bits += 8;\n      } else {\n         stbi__int32 code = bits & codemask;\n         bits >>= codesize;\n         valid_bits -= codesize;\n         // @OPTIMIZE: is there some way we can accelerate the non-clear path?\n         if (code == clear) {  // clear code\n            codesize = lzw_cs + 1;\n            codemask = (1 << codesize) - 1;\n            avail = clear + 2;\n            oldcode = -1;\n            first = 0;\n         } else if (code == clear + 1) { // end of stream code\n            stbi__skip(s, len);\n            while ((len = stbi__get8(s)) > 0)\n               stbi__skip(s,len);\n            return g->out;\n         } else if (code <= avail) {\n            if (first) {\n               return stbi__errpuc(\"no clear code\", \"Corrupt GIF\");\n            }\n\n            if (oldcode >= 0) {\n               p = &g->codes[avail++];\n               if (avail > 8192) {\n                  return stbi__errpuc(\"too many codes\", \"Corrupt GIF\");\n               }\n\n               p->prefix = (stbi__int16) oldcode;\n               p->first = g->codes[oldcode].first;\n               p->suffix = (code == avail) ? 
p->first : g->codes[code].first;\n            } else if (code == avail)\n               return stbi__errpuc(\"illegal code in raster\", \"Corrupt GIF\");\n\n            stbi__out_gif_code(g, (stbi__uint16) code);\n\n            if ((avail & codemask) == 0 && avail <= 0x0FFF) {\n               codesize++;\n               codemask = (1 << codesize) - 1;\n            }\n\n            oldcode = code;\n         } else {\n            return stbi__errpuc(\"illegal code in raster\", \"Corrupt GIF\");\n         }\n      }\n   }\n}\n\n// this function is designed to support animated gifs, although stb_image doesn't support it\n// two back is the image from two frames ago, used for a very specific disposal format\nstatic stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back)\n{\n   int dispose;\n   int first_frame;\n   int pi;\n   int pcount;\n   STBI_NOTUSED(req_comp);\n\n   // on first frame, any non-written pixels get the background colour (non-transparent)\n   first_frame = 0;\n   if (g->out == 0) {\n      if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header\n      if (!stbi__mad3sizes_valid(4, g->w, g->h, 0))\n         return stbi__errpuc(\"too large\", \"GIF image is too large\");\n      pcount = g->w * g->h;\n      g->out = (stbi_uc *) stbi__malloc(4 * pcount);\n      g->background = (stbi_uc *) stbi__malloc(4 * pcount);\n      g->history = (stbi_uc *) stbi__malloc(pcount);\n      if (!g->out || !g->background || !g->history)\n         return stbi__errpuc(\"outofmem\", \"Out of memory\");\n\n      // image is treated as \"transparent\" at the start - ie, nothing overwrites the current background;\n      // background colour is only used for pixels that are not rendered first frame, after that \"background\"\n      // color refers to the color that was there the previous frame.\n      memset(g->out, 0x00, 4 * pcount);\n      memset(g->background, 0x00, 4 * pcount); // state of 
the background (starts transparent)\n      memset(g->history, 0x00, pcount);        // pixels that were affected previous frame\n      first_frame = 1;\n   } else {\n      // second frame - how do we dispose of the previous one?\n      dispose = (g->eflags & 0x1C) >> 2;\n      pcount = g->w * g->h;\n\n      if ((dispose == 3) && (two_back == 0)) {\n         dispose = 2; // if I don't have an image to revert back to, default to the old background\n      }\n\n      if (dispose == 3) { // use previous graphic\n         for (pi = 0; pi < pcount; ++pi) {\n            if (g->history[pi]) {\n               memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 );\n            }\n         }\n      } else if (dispose == 2) {\n         // restore what was changed last frame to background before that frame;\n         for (pi = 0; pi < pcount; ++pi) {\n            if (g->history[pi]) {\n               memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 );\n            }\n         }\n      } else {\n         // This is a non-disposal case eithe way, so just\n         // leave the pixels as is, and they will become the new background\n         // 1: do not dispose\n         // 0:  not specified.\n      }\n\n      // background is what out is after the undoing of the previou frame;\n      memcpy( g->background, g->out, 4 * g->w * g->h );\n   }\n\n   // clear my history;\n   memset( g->history, 0x00, g->w * g->h );        // pixels that were affected previous frame\n\n   for (;;) {\n      int tag = stbi__get8(s);\n      switch (tag) {\n         case 0x2C: /* Image Descriptor */\n         {\n            stbi__int32 x, y, w, h;\n            stbi_uc *o;\n\n            x = stbi__get16le(s);\n            y = stbi__get16le(s);\n            w = stbi__get16le(s);\n            h = stbi__get16le(s);\n            if (((x + w) > (g->w)) || ((y + h) > (g->h)))\n               return stbi__errpuc(\"bad Image Descriptor\", \"Corrupt GIF\");\n\n            g->line_size = g->w * 4;\n            g->start_x = 
x * 4;\n            g->start_y = y * g->line_size;\n            g->max_x   = g->start_x + w * 4;\n            g->max_y   = g->start_y + h * g->line_size;\n            g->cur_x   = g->start_x;\n            g->cur_y   = g->start_y;\n\n            // if the width of the specified rectangle is 0, that means\n            // we may not see *any* pixels or the image is malformed;\n            // to make sure this is caught, move the current y down to\n            // max_y (which is what out_gif_code checks).\n            if (w == 0)\n               g->cur_y = g->max_y;\n\n            g->lflags = stbi__get8(s);\n\n            if (g->lflags & 0x40) {\n               g->step = 8 * g->line_size; // first interlaced spacing\n               g->parse = 3;\n            } else {\n               g->step = g->line_size;\n               g->parse = 0;\n            }\n\n            if (g->lflags & 0x80) {\n               stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1);\n               g->color_table = (stbi_uc *) g->lpal;\n            } else if (g->flags & 0x80) {\n               g->color_table = (stbi_uc *) g->pal;\n            } else\n               return stbi__errpuc(\"missing color table\", \"Corrupt GIF\");\n\n            o = stbi__process_gif_raster(s, g);\n            if (!o) return NULL;\n\n            // if this was the first frame,\n            pcount = g->w * g->h;\n            if (first_frame && (g->bgindex > 0)) {\n               // if first frame, any pixel not drawn to gets the background color\n               for (pi = 0; pi < pcount; ++pi) {\n                  if (g->history[pi] == 0) {\n                     g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be;\n                     memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 );\n                  }\n               }\n            }\n\n            return o;\n         }\n\n         case 0x21: // 
Comment Extension.\n         {\n            int len;\n            int ext = stbi__get8(s);\n            if (ext == 0xF9) { // Graphic Control Extension.\n               len = stbi__get8(s);\n               if (len == 4) {\n                  g->eflags = stbi__get8(s);\n                  g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths.\n\n                  // unset old transparent\n                  if (g->transparent >= 0) {\n                     g->pal[g->transparent][3] = 255;\n                  }\n                  if (g->eflags & 0x01) {\n                     g->transparent = stbi__get8(s);\n                     if (g->transparent >= 0) {\n                        g->pal[g->transparent][3] = 0;\n                     }\n                  } else {\n                     // don't need transparent\n                     stbi__skip(s, 1);\n                     g->transparent = -1;\n                  }\n               } else {\n                  stbi__skip(s, len);\n                  break;\n               }\n            }\n            while ((len = stbi__get8(s)) != 0) {\n               stbi__skip(s, len);\n            }\n            break;\n         }\n\n         case 0x3B: // gif stream termination code\n            return (stbi_uc *) s; // using '1' causes warning on some compilers\n\n         default:\n            return stbi__errpuc(\"unknown code\", \"Corrupt GIF\");\n      }\n   }\n}\n\nstatic void *stbi__load_gif_main_outofmem(stbi__gif *g, stbi_uc *out, int **delays)\n{\n   STBI_FREE(g->out);\n   STBI_FREE(g->history);\n   STBI_FREE(g->background);\n\n   if (out) STBI_FREE(out);\n   if (delays && *delays) STBI_FREE(*delays);\n   return stbi__errpuc(\"outofmem\", \"Out of memory\");\n}\n\nstatic void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp)\n{\n   if (stbi__gif_test(s)) {\n      int layers = 0;\n      stbi_uc *u = 0;\n      stbi_uc *out = 0;\n      stbi_uc 
*two_back = 0;\n      stbi__gif g;\n      int stride;\n      int out_size = 0;\n      int delays_size = 0;\n\n      STBI_NOTUSED(out_size);\n      STBI_NOTUSED(delays_size);\n\n      memset(&g, 0, sizeof(g));\n      if (delays) {\n         *delays = 0;\n      }\n\n      do {\n         u = stbi__gif_load_next(s, &g, comp, req_comp, two_back);\n         if (u == (stbi_uc *) s) u = 0;  // end of animated gif marker\n\n         if (u) {\n            *x = g.w;\n            *y = g.h;\n            ++layers;\n            stride = g.w * g.h * 4;\n\n            if (out) {\n               void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride );\n               if (!tmp)\n                  return stbi__load_gif_main_outofmem(&g, out, delays);\n               else {\n                   out = (stbi_uc*) tmp;\n                   out_size = layers * stride;\n               }\n\n               if (delays) {\n                  int *new_delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers );\n                  if (!new_delays)\n                     return stbi__load_gif_main_outofmem(&g, out, delays);\n                  *delays = new_delays;\n                  delays_size = layers * sizeof(int);\n               }\n            } else {\n               out = (stbi_uc*)stbi__malloc( layers * stride );\n               if (!out)\n                  return stbi__load_gif_main_outofmem(&g, out, delays);\n               out_size = layers * stride;\n               if (delays) {\n                  *delays = (int*) stbi__malloc( layers * sizeof(int) );\n                  if (!*delays)\n                     return stbi__load_gif_main_outofmem(&g, out, delays);\n                  delays_size = layers * sizeof(int);\n               }\n            }\n            memcpy( out + ((layers - 1) * stride), u, stride );\n            if (layers >= 2) {\n               two_back = out - 2 * stride;\n            }\n\n            if (delays) {\n               
(*delays)[layers - 1U] = g.delay;\n            }\n         }\n      } while (u != 0);\n\n      // free temp buffer;\n      STBI_FREE(g.out);\n      STBI_FREE(g.history);\n      STBI_FREE(g.background);\n\n      // do the final conversion after loading everything;\n      if (req_comp && req_comp != 4)\n         out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h);\n\n      *z = layers;\n      return out;\n   } else {\n      return stbi__errpuc(\"not GIF\", \"Image was not as a gif type.\");\n   }\n}\n\nstatic void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)\n{\n   stbi_uc *u = 0;\n   stbi__gif g;\n   memset(&g, 0, sizeof(g));\n   STBI_NOTUSED(ri);\n\n   u = stbi__gif_load_next(s, &g, comp, req_comp, 0);\n   if (u == (stbi_uc *) s) u = 0;  // end of animated gif marker\n   if (u) {\n      *x = g.w;\n      *y = g.h;\n\n      // moved conversion to after successful load so that the same\n      // can be done for multiple frames.\n      if (req_comp && req_comp != 4)\n         u = stbi__convert_format(u, 4, req_comp, g.w, g.h);\n   } else if (g.out) {\n      // if there was an error and we allocated an image buffer, free it!\n      STBI_FREE(g.out);\n   }\n\n   // free buffers needed for multiple frame loading;\n   STBI_FREE(g.history);\n   STBI_FREE(g.background);\n\n   return u;\n}\n\nstatic int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp)\n{\n   return stbi__gif_info_raw(s,x,y,comp);\n}\n#endif\n\n// *************************************************************************************************\n// Radiance RGBE HDR loader\n// originally by Nicolas Schulz\n#ifndef STBI_NO_HDR\nstatic int stbi__hdr_test_core(stbi__context *s, const char *signature)\n{\n   int i;\n   for (i=0; signature[i]; ++i)\n      if (stbi__get8(s) != signature[i])\n          return 0;\n   stbi__rewind(s);\n   return 1;\n}\n\nstatic int stbi__hdr_test(stbi__context* s)\n{\n   int r = stbi__hdr_test_core(s, 
\"#?RADIANCE\\n\");\n   stbi__rewind(s);\n   if(!r) {\n       r = stbi__hdr_test_core(s, \"#?RGBE\\n\");\n       stbi__rewind(s);\n   }\n   return r;\n}\n\n#define STBI__HDR_BUFLEN  1024\nstatic char *stbi__hdr_gettoken(stbi__context *z, char *buffer)\n{\n   int len=0;\n   char c = '\\0';\n\n   c = (char) stbi__get8(z);\n\n   while (!stbi__at_eof(z) && c != '\\n') {\n      buffer[len++] = c;\n      if (len == STBI__HDR_BUFLEN-1) {\n         // flush to end of line\n         while (!stbi__at_eof(z) && stbi__get8(z) != '\\n')\n            ;\n         break;\n      }\n      c = (char) stbi__get8(z);\n   }\n\n   buffer[len] = 0;\n   return buffer;\n}\n\nstatic void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp)\n{\n   if ( input[3] != 0 ) {\n      float f1;\n      // Exponent\n      f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8));\n      if (req_comp <= 2)\n         output[0] = (input[0] + input[1] + input[2]) * f1 / 3;\n      else {\n         output[0] = input[0] * f1;\n         output[1] = input[1] * f1;\n         output[2] = input[2] * f1;\n      }\n      if (req_comp == 2) output[1] = 1;\n      if (req_comp == 4) output[3] = 1;\n   } else {\n      switch (req_comp) {\n         case 4: output[3] = 1; /* fallthrough */\n         case 3: output[0] = output[1] = output[2] = 0;\n                 break;\n         case 2: output[1] = 1; /* fallthrough */\n         case 1: output[0] = 0;\n                 break;\n      }\n   }\n}\n\nstatic float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)\n{\n   char buffer[STBI__HDR_BUFLEN];\n   char *token;\n   int valid = 0;\n   int width, height;\n   stbi_uc *scanline;\n   float *hdr_data;\n   int len;\n   unsigned char count, value;\n   int i, j, k, c1,c2, z;\n   const char *headerToken;\n   STBI_NOTUSED(ri);\n\n   // Check identifier\n   headerToken = stbi__hdr_gettoken(s,buffer);\n   if (strcmp(headerToken, \"#?RADIANCE\") != 0 && strcmp(headerToken, 
\"#?RGBE\") != 0)\n      return stbi__errpf(\"not HDR\", \"Corrupt HDR image\");\n\n   // Parse header\n   for(;;) {\n      token = stbi__hdr_gettoken(s,buffer);\n      if (token[0] == 0) break;\n      if (strcmp(token, \"FORMAT=32-bit_rle_rgbe\") == 0) valid = 1;\n   }\n\n   if (!valid)    return stbi__errpf(\"unsupported format\", \"Unsupported HDR format\");\n\n   // Parse width and height\n   // can't use sscanf() if we're not using stdio!\n   token = stbi__hdr_gettoken(s,buffer);\n   if (strncmp(token, \"-Y \", 3))  return stbi__errpf(\"unsupported data layout\", \"Unsupported HDR format\");\n   token += 3;\n   height = (int) strtol(token, &token, 10);\n   while (*token == ' ') ++token;\n   if (strncmp(token, \"+X \", 3))  return stbi__errpf(\"unsupported data layout\", \"Unsupported HDR format\");\n   token += 3;\n   width = (int) strtol(token, NULL, 10);\n\n   if (height > STBI_MAX_DIMENSIONS) return stbi__errpf(\"too large\",\"Very large image (corrupt?)\");\n   if (width > STBI_MAX_DIMENSIONS) return stbi__errpf(\"too large\",\"Very large image (corrupt?)\");\n\n   *x = width;\n   *y = height;\n\n   if (comp) *comp = 3;\n   if (req_comp == 0) req_comp = 3;\n\n   if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0))\n      return stbi__errpf(\"too large\", \"HDR image is too large\");\n\n   // Read data\n   hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0);\n   if (!hdr_data)\n      return stbi__errpf(\"outofmem\", \"Out of memory\");\n\n   // Load image data\n   // image data is stored as some number of sca\n   if ( width < 8 || width >= 32768) {\n      // Read flat data\n      for (j=0; j < height; ++j) {\n         for (i=0; i < width; ++i) {\n            stbi_uc rgbe[4];\n           main_decode_loop:\n            stbi__getn(s, rgbe, 4);\n            stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp);\n         }\n      }\n   } else {\n      // Read RLE-encoded data\n      
scanline = NULL;\n\n      for (j = 0; j < height; ++j) {\n         c1 = stbi__get8(s);\n         c2 = stbi__get8(s);\n         len = stbi__get8(s);\n         if (c1 != 2 || c2 != 2 || (len & 0x80)) {\n            // not run-length encoded, so we have to actually use THIS data as a decoded\n            // pixel (note this can't be a valid pixel--one of RGB must be >= 128)\n            stbi_uc rgbe[4];\n            rgbe[0] = (stbi_uc) c1;\n            rgbe[1] = (stbi_uc) c2;\n            rgbe[2] = (stbi_uc) len;\n            rgbe[3] = (stbi_uc) stbi__get8(s);\n            stbi__hdr_convert(hdr_data, rgbe, req_comp);\n            i = 1;\n            j = 0;\n            STBI_FREE(scanline);\n            goto main_decode_loop; // yes, this makes no sense\n         }\n         len <<= 8;\n         len |= stbi__get8(s);\n         if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf(\"invalid decoded scanline length\", \"corrupt HDR\"); }\n         if (scanline == NULL) {\n            scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0);\n            if (!scanline) {\n               STBI_FREE(hdr_data);\n               return stbi__errpf(\"outofmem\", \"Out of memory\");\n            }\n         }\n\n         for (k = 0; k < 4; ++k) {\n            int nleft;\n            i = 0;\n            while ((nleft = width - i) > 0) {\n               count = stbi__get8(s);\n               if (count > 128) {\n                  // Run\n                  value = stbi__get8(s);\n                  count -= 128;\n                  if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf(\"corrupt\", \"bad RLE data in HDR\"); }\n                  for (z = 0; z < count; ++z)\n                     scanline[i++ * 4 + k] = value;\n               } else {\n                  // Dump\n                  if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf(\"corrupt\", \"bad RLE 
data in HDR\"); }\n                  for (z = 0; z < count; ++z)\n                     scanline[i++ * 4 + k] = stbi__get8(s);\n               }\n            }\n         }\n         for (i=0; i < width; ++i)\n            stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp);\n      }\n      if (scanline)\n         STBI_FREE(scanline);\n   }\n\n   return hdr_data;\n}\n\nstatic int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp)\n{\n   char buffer[STBI__HDR_BUFLEN];\n   char *token;\n   int valid = 0;\n   int dummy;\n\n   if (!x) x = &dummy;\n   if (!y) y = &dummy;\n   if (!comp) comp = &dummy;\n\n   if (stbi__hdr_test(s) == 0) {\n       stbi__rewind( s );\n       return 0;\n   }\n\n   for(;;) {\n      token = stbi__hdr_gettoken(s,buffer);\n      if (token[0] == 0) break;\n      if (strcmp(token, \"FORMAT=32-bit_rle_rgbe\") == 0) valid = 1;\n   }\n\n   if (!valid) {\n       stbi__rewind( s );\n       return 0;\n   }\n   token = stbi__hdr_gettoken(s,buffer);\n   if (strncmp(token, \"-Y \", 3)) {\n       stbi__rewind( s );\n       return 0;\n   }\n   token += 3;\n   *y = (int) strtol(token, &token, 10);\n   while (*token == ' ') ++token;\n   if (strncmp(token, \"+X \", 3)) {\n       stbi__rewind( s );\n       return 0;\n   }\n   token += 3;\n   *x = (int) strtol(token, NULL, 10);\n   *comp = 3;\n   return 1;\n}\n#endif // STBI_NO_HDR\n\n#ifndef STBI_NO_BMP\nstatic int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp)\n{\n   void *p;\n   stbi__bmp_data info;\n\n   info.all_a = 255;\n   p = stbi__bmp_parse_header(s, &info);\n   if (p == NULL) {\n      stbi__rewind( s );\n      return 0;\n   }\n   if (x) *x = s->img_x;\n   if (y) *y = s->img_y;\n   if (comp) {\n      if (info.bpp == 24 && info.ma == 0xff000000)\n         *comp = 3;\n      else\n         *comp = info.ma ? 
4 : 3;\n   }\n   return 1;\n}\n#endif\n\n#ifndef STBI_NO_PSD\nstatic int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp)\n{\n   int channelCount, dummy, depth;\n   if (!x) x = &dummy;\n   if (!y) y = &dummy;\n   if (!comp) comp = &dummy;\n   if (stbi__get32be(s) != 0x38425053) {\n       stbi__rewind( s );\n       return 0;\n   }\n   if (stbi__get16be(s) != 1) {\n       stbi__rewind( s );\n       return 0;\n   }\n   stbi__skip(s, 6);\n   channelCount = stbi__get16be(s);\n   if (channelCount < 0 || channelCount > 16) {\n       stbi__rewind( s );\n       return 0;\n   }\n   *y = stbi__get32be(s);\n   *x = stbi__get32be(s);\n   depth = stbi__get16be(s);\n   if (depth != 8 && depth != 16) {\n       stbi__rewind( s );\n       return 0;\n   }\n   if (stbi__get16be(s) != 3) {\n       stbi__rewind( s );\n       return 0;\n   }\n   *comp = 4;\n   return 1;\n}\n\nstatic int stbi__psd_is16(stbi__context *s)\n{\n   int channelCount, depth;\n   if (stbi__get32be(s) != 0x38425053) {\n       stbi__rewind( s );\n       return 0;\n   }\n   if (stbi__get16be(s) != 1) {\n       stbi__rewind( s );\n       return 0;\n   }\n   stbi__skip(s, 6);\n   channelCount = stbi__get16be(s);\n   if (channelCount < 0 || channelCount > 16) {\n       stbi__rewind( s );\n       return 0;\n   }\n   STBI_NOTUSED(stbi__get32be(s));\n   STBI_NOTUSED(stbi__get32be(s));\n   depth = stbi__get16be(s);\n   if (depth != 16) {\n       stbi__rewind( s );\n       return 0;\n   }\n   return 1;\n}\n#endif\n\n#ifndef STBI_NO_PIC\nstatic int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp)\n{\n   int act_comp=0,num_packets=0,chained,dummy;\n   stbi__pic_packet packets[10];\n\n   if (!x) x = &dummy;\n   if (!y) y = &dummy;\n   if (!comp) comp = &dummy;\n\n   if (!stbi__pic_is4(s,\"\\x53\\x80\\xF6\\x34\")) {\n      stbi__rewind(s);\n      return 0;\n   }\n\n   stbi__skip(s, 88);\n\n   *x = stbi__get16be(s);\n   *y = stbi__get16be(s);\n   if (stbi__at_eof(s)) {\n      stbi__rewind( s);\n      return 
0;\n   }\n   if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) {\n      stbi__rewind( s );\n      return 0;\n   }\n\n   stbi__skip(s, 8);\n\n   do {\n      stbi__pic_packet *packet;\n\n      if (num_packets==sizeof(packets)/sizeof(packets[0]))\n         return 0;\n\n      packet = &packets[num_packets++];\n      chained = stbi__get8(s);\n      packet->size    = stbi__get8(s);\n      packet->type    = stbi__get8(s);\n      packet->channel = stbi__get8(s);\n      act_comp |= packet->channel;\n\n      if (stbi__at_eof(s)) {\n          stbi__rewind( s );\n          return 0;\n      }\n      if (packet->size != 8) {\n          stbi__rewind( s );\n          return 0;\n      }\n   } while (chained);\n\n   *comp = (act_comp & 0x10 ? 4 : 3);\n\n   return 1;\n}\n#endif\n\n// *************************************************************************************************\n// Portable Gray Map and Portable Pixel Map loader\n// by Ken Miller\n//\n// PGM: http://netpbm.sourceforge.net/doc/pgm.html\n// PPM: http://netpbm.sourceforge.net/doc/ppm.html\n//\n// Known limitations:\n//    Does not support comments in the header section\n//    Does not support ASCII image data (formats P2 and P3)\n\n#ifndef STBI_NO_PNM\n\nstatic int      stbi__pnm_test(stbi__context *s)\n{\n   char p, t;\n   p = (char) stbi__get8(s);\n   t = (char) stbi__get8(s);\n   if (p != 'P' || (t != '5' && t != '6')) {\n       stbi__rewind( s );\n       return 0;\n   }\n   return 1;\n}\n\nstatic void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)\n{\n   stbi_uc *out;\n   STBI_NOTUSED(ri);\n\n   ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n);\n   if (ri->bits_per_channel == 0)\n      return 0;\n\n   if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n   if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc(\"too large\",\"Very large image (corrupt?)\");\n\n   *x = 
s->img_x;\n   *y = s->img_y;\n   if (comp) *comp = s->img_n;\n\n   if (!stbi__mad4sizes_valid(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0))\n      return stbi__errpuc(\"too large\", \"PNM too large\");\n\n   out = (stbi_uc *) stbi__malloc_mad4(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0);\n   if (!out) return stbi__errpuc(\"outofmem\", \"Out of memory\");\n   if (!stbi__getn(s, out, s->img_n * s->img_x * s->img_y * (ri->bits_per_channel / 8))) {\n      STBI_FREE(out);\n      return stbi__errpuc(\"bad PNM\", \"PNM file truncated\");\n   }\n\n   if (req_comp && req_comp != s->img_n) {\n      if (ri->bits_per_channel == 16) {\n         out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, s->img_n, req_comp, s->img_x, s->img_y);\n      } else {\n         out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y);\n      }\n      if (out == NULL) return out; // stbi__convert_format frees input on failure\n   }\n   return out;\n}\n\nstatic int      stbi__pnm_isspace(char c)\n{\n   return c == ' ' || c == '\\t' || c == '\\n' || c == '\\v' || c == '\\f' || c == '\\r';\n}\n\nstatic void     stbi__pnm_skip_whitespace(stbi__context *s, char *c)\n{\n   for (;;) {\n      while (!stbi__at_eof(s) && stbi__pnm_isspace(*c))\n         *c = (char) stbi__get8(s);\n\n      if (stbi__at_eof(s) || *c != '#')\n         break;\n\n      while (!stbi__at_eof(s) && *c != '\\n' && *c != '\\r' )\n         *c = (char) stbi__get8(s);\n   }\n}\n\nstatic int      stbi__pnm_isdigit(char c)\n{\n   return c >= '0' && c <= '9';\n}\n\nstatic int      stbi__pnm_getinteger(stbi__context *s, char *c)\n{\n   int value = 0;\n\n   while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) {\n      value = value*10 + (*c - '0');\n      *c = (char) stbi__get8(s);\n      if((value > 214748364) || (value == 214748364 && *c > '7'))\n          return stbi__err(\"integer parse overflow\", \"Parsing an integer in the PPM header overflowed a 32-bit int\");\n   }\n\n   return 
value;\n}\n\nstatic int      stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp)\n{\n   int maxv, dummy;\n   char c, p, t;\n\n   if (!x) x = &dummy;\n   if (!y) y = &dummy;\n   if (!comp) comp = &dummy;\n\n   stbi__rewind(s);\n\n   // Get identifier\n   p = (char) stbi__get8(s);\n   t = (char) stbi__get8(s);\n   if (p != 'P' || (t != '5' && t != '6')) {\n       stbi__rewind(s);\n       return 0;\n   }\n\n   *comp = (t == '6') ? 3 : 1;  // '5' is 1-component .pgm; '6' is 3-component .ppm\n\n   c = (char) stbi__get8(s);\n   stbi__pnm_skip_whitespace(s, &c);\n\n   *x = stbi__pnm_getinteger(s, &c); // read width\n   if(*x == 0)\n       return stbi__err(\"invalid width\", \"PPM image header had zero or overflowing width\");\n   stbi__pnm_skip_whitespace(s, &c);\n\n   *y = stbi__pnm_getinteger(s, &c); // read height\n   if (*y == 0)\n       return stbi__err(\"invalid width\", \"PPM image header had zero or overflowing width\");\n   stbi__pnm_skip_whitespace(s, &c);\n\n   maxv = stbi__pnm_getinteger(s, &c);  // read max value\n   if (maxv > 65535)\n      return stbi__err(\"max value > 65535\", \"PPM image supports only 8-bit and 16-bit images\");\n   else if (maxv > 255)\n      return 16;\n   else\n      return 8;\n}\n\nstatic int stbi__pnm_is16(stbi__context *s)\n{\n   if (stbi__pnm_info(s, NULL, NULL, NULL) == 16)\n\t   return 1;\n   return 0;\n}\n#endif\n\nstatic int stbi__info_main(stbi__context *s, int *x, int *y, int *comp)\n{\n   #ifndef STBI_NO_JPEG\n   if (stbi__jpeg_info(s, x, y, comp)) return 1;\n   #endif\n\n   #ifndef STBI_NO_PNG\n   if (stbi__png_info(s, x, y, comp))  return 1;\n   #endif\n\n   #ifndef STBI_NO_GIF\n   if (stbi__gif_info(s, x, y, comp))  return 1;\n   #endif\n\n   #ifndef STBI_NO_BMP\n   if (stbi__bmp_info(s, x, y, comp))  return 1;\n   #endif\n\n   #ifndef STBI_NO_PSD\n   if (stbi__psd_info(s, x, y, comp))  return 1;\n   #endif\n\n   #ifndef STBI_NO_PIC\n   if (stbi__pic_info(s, x, y, comp))  return 1;\n   #endif\n\n   #ifndef 
STBI_NO_PNM\n   if (stbi__pnm_info(s, x, y, comp))  return 1;\n   #endif\n\n   #ifndef STBI_NO_HDR\n   if (stbi__hdr_info(s, x, y, comp))  return 1;\n   #endif\n\n   // test tga last because it's a crappy test!\n   #ifndef STBI_NO_TGA\n   if (stbi__tga_info(s, x, y, comp))\n       return 1;\n   #endif\n   return stbi__err(\"unknown image type\", \"Image not of any known type, or corrupt\");\n}\n\nstatic int stbi__is_16_main(stbi__context *s)\n{\n   #ifndef STBI_NO_PNG\n   if (stbi__png_is16(s))  return 1;\n   #endif\n\n   #ifndef STBI_NO_PSD\n   if (stbi__psd_is16(s))  return 1;\n   #endif\n\n   #ifndef STBI_NO_PNM\n   if (stbi__pnm_is16(s))  return 1;\n   #endif\n   return 0;\n}\n\n#ifndef STBI_NO_STDIO\nSTBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp)\n{\n    FILE *f = stbi__fopen(filename, \"rb\");\n    int result;\n    if (!f) return stbi__err(\"can't fopen\", \"Unable to open file\");\n    result = stbi_info_from_file(f, x, y, comp);\n    fclose(f);\n    return result;\n}\n\nSTBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp)\n{\n   int r;\n   stbi__context s;\n   long pos = ftell(f);\n   stbi__start_file(&s, f);\n   r = stbi__info_main(&s,x,y,comp);\n   fseek(f,pos,SEEK_SET);\n   return r;\n}\n\nSTBIDEF int stbi_is_16_bit(char const *filename)\n{\n    FILE *f = stbi__fopen(filename, \"rb\");\n    int result;\n    if (!f) return stbi__err(\"can't fopen\", \"Unable to open file\");\n    result = stbi_is_16_bit_from_file(f);\n    fclose(f);\n    return result;\n}\n\nSTBIDEF int stbi_is_16_bit_from_file(FILE *f)\n{\n   int r;\n   stbi__context s;\n   long pos = ftell(f);\n   stbi__start_file(&s, f);\n   r = stbi__is_16_main(&s);\n   fseek(f,pos,SEEK_SET);\n   return r;\n}\n#endif // !STBI_NO_STDIO\n\nSTBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp)\n{\n   stbi__context s;\n   stbi__start_mem(&s,buffer,len);\n   return stbi__info_main(&s,x,y,comp);\n}\n\nSTBIDEF int 
stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp)\n{\n   stbi__context s;\n   stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user);\n   return stbi__info_main(&s,x,y,comp);\n}\n\nSTBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len)\n{\n   stbi__context s;\n   stbi__start_mem(&s,buffer,len);\n   return stbi__is_16_main(&s);\n}\n\nSTBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user)\n{\n   stbi__context s;\n   stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user);\n   return stbi__is_16_main(&s);\n}\n\n#endif // STB_IMAGE_IMPLEMENTATION\n\n/*\n   revision history:\n      2.20  (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs\n      2.19  (2018-02-11) fix warning\n      2.18  (2018-01-30) fix warnings\n      2.17  (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug\n                         1-bit BMP\n                         *_is_16_bit api\n                         avoid warnings\n      2.16  (2017-07-23) all functions have 16-bit variants;\n                         STBI_NO_STDIO works again;\n                         compilation fixes;\n                         fix rounding in unpremultiply;\n                         optimize vertical flip;\n                         disable raw_len validation;\n                         documentation fixes\n      2.15  (2017-03-18) fix png-1,2,4 bug; now all Imagenet JPGs decode;\n                         warning fixes; disable run-time SSE detection on gcc;\n                         uniform handling of optional \"return\" values;\n                         thread-safe initialization of zlib tables\n      2.14  (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs\n      2.13  (2016-11-29) add 16-bit API, only supported for PNG right now\n      2.12  (2016-04-02) fix typo in 2.11 PSD fix that caused crashes\n      2.11  (2016-04-02) allocate large structures on the stack\n            
             remove white matting for transparent PSD\n                         fix reported channel count for PNG & BMP\n                         re-enable SSE2 in non-gcc 64-bit\n                         support RGB-formatted JPEG\n                         read 16-bit PNGs (only as 8-bit)\n      2.10  (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED\n      2.09  (2016-01-16) allow comments in PNM files\n                         16-bit-per-pixel TGA (not bit-per-component)\n                         info() for TGA could break due to .hdr handling\n                         info() for BMP to shares code instead of sloppy parse\n                         can use STBI_REALLOC_SIZED if allocator doesn't support realloc\n                         code cleanup\n      2.08  (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA\n      2.07  (2015-09-13) fix compiler warnings\n                         partial animated GIF support\n                         limited 16-bpc PSD support\n                         #ifdef unused functions\n                         bug with < 92 byte PIC,PNM,HDR,TGA\n      2.06  (2015-04-19) fix bug where PSD returns wrong '*comp' value\n      2.05  (2015-04-19) fix bug in progressive JPEG handling, fix warning\n      2.04  (2015-04-15) try to re-enable SIMD on MinGW 64-bit\n      2.03  (2015-04-12) extra corruption checking (mmozeiko)\n                         stbi_set_flip_vertically_on_load (nguillemot)\n                         fix NEON support; fix mingw support\n      2.02  (2015-01-19) fix incorrect assert, fix warning\n      2.01  (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2\n      2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG\n      2.00  (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg)\n                         progressive JPEG (stb)\n                         PGM/PPM support (Ken Miller)\n                         STBI_MALLOC,STBI_REALLOC,STBI_FREE\n                
         GIF bugfix -- seemingly never worked\n                         STBI_NO_*, STBI_ONLY_*\n      1.48  (2014-12-14) fix incorrectly-named assert()\n      1.47  (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb)\n                         optimize PNG (ryg)\n                         fix bug in interlaced PNG with user-specified channel count (stb)\n      1.46  (2014-08-26)\n              fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG\n      1.45  (2014-08-16)\n              fix MSVC-ARM internal compiler error by wrapping malloc\n      1.44  (2014-08-07)\n              various warning fixes from Ronny Chevalier\n      1.43  (2014-07-15)\n              fix MSVC-only compiler problem in code changed in 1.42\n      1.42  (2014-07-09)\n              don't define _CRT_SECURE_NO_WARNINGS (affects user code)\n              fixes to stbi__cleanup_jpeg path\n              added STBI_ASSERT to avoid requiring assert.h\n      1.41  (2014-06-25)\n              fix search&replace from 1.36 that messed up comments/error messages\n      1.40  (2014-06-22)\n              fix gcc struct-initialization warning\n      1.39  (2014-06-15)\n              fix to TGA optimization when req_comp != number of components in TGA;\n              fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite)\n              add support for BMP version 5 (more ignored fields)\n      1.38  (2014-06-06)\n              suppress MSVC warnings on integer casts truncating values\n              fix accidental rename of 'skip' field of I/O\n      1.37  (2014-06-04)\n              remove duplicate typedef\n      1.36  (2014-06-03)\n              convert to header file single-file library\n              if de-iphone isn't set, load iphone images color-swapped instead of returning NULL\n      1.35  (2014-05-27)\n              various warnings\n              fix broken STBI_SIMD path\n              fix bug where stbi_load_from_file 
no longer left file pointer in correct place\n              fix broken non-easy path for 32-bit BMP (possibly never used)\n              TGA optimization by Arseny Kapoulkine\n      1.34  (unknown)\n              use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case\n      1.33  (2011-07-14)\n              make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements\n      1.32  (2011-07-13)\n              support for \"info\" function for all supported filetypes (SpartanJ)\n      1.31  (2011-06-20)\n              a few more leak fixes, bug in PNG handling (SpartanJ)\n      1.30  (2011-06-11)\n              added ability to load files via callbacks to accomidate custom input streams (Ben Wenger)\n              removed deprecated format-specific test/load functions\n              removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway\n              error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha)\n              fix inefficiency in decoding 32-bit BMP (David Woo)\n      1.29  (2010-08-16)\n              various warning fixes from Aurelien Pocheville\n      1.28  (2010-08-01)\n              fix bug in GIF palette transparency (SpartanJ)\n      1.27  (2010-08-01)\n              cast-to-stbi_uc to fix warnings\n      1.26  (2010-07-24)\n              fix bug in file buffering for PNG reported by SpartanJ\n      1.25  (2010-07-17)\n              refix trans_data warning (Won Chun)\n      1.24  (2010-07-12)\n              perf improvements reading from files on platforms with lock-heavy fgetc()\n              minor perf improvements for jpeg\n              deprecated type-specific functions so we'll get feedback if they're needed\n              attempt to fix trans_data warning (Won Chun)\n      1.23    fixed bug in iPhone support\n      1.22  (2010-07-10)\n              removed image *writing* support\n              
stbi_info support from Jetro Lauha\n              GIF support from Jean-Marc Lienher\n              iPhone PNG-extensions from James Brown\n              warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva)\n      1.21    fix use of 'stbi_uc' in header (reported by jon blow)\n      1.20    added support for Softimage PIC, by Tom Seddon\n      1.19    bug in interlaced PNG corruption check (found by ryg)\n      1.18  (2008-08-02)\n              fix a threading bug (local mutable static)\n      1.17    support interlaced PNG\n      1.16    major bugfix - stbi__convert_format converted one too many pixels\n      1.15    initialize some fields for thread safety\n      1.14    fix threadsafe conversion bug\n              header-file-only version (#define STBI_HEADER_FILE_ONLY before including)\n      1.13    threadsafe\n      1.12    const qualifiers in the API\n      1.11    Support installable IDCT, colorspace conversion routines\n      1.10    Fixes for 64-bit (don't use \"unsigned long\")\n              optimized upsampling by Fabian \"ryg\" Giesen\n      1.09    Fix format-conversion for PSD code (bad global variables!)\n      1.08    Thatcher Ulrich's PSD code integrated by Nicolas Schulz\n      1.07    attempt to fix C++ warning/errors again\n      1.06    attempt to fix C++ warning/errors again\n      1.05    fix TGA loading to return correct *comp and use good luminance calc\n      1.04    default float alpha is 1, not 255; use 'void *' for stbi_image_free\n      1.03    bugfixes to STBI_NO_STDIO, STBI_NO_HDR\n      1.02    support for (subset of) HDR files, float interface for preferred access to them\n      1.01    fix bug: possible bug in handling right-side up bmps... 
not sure\n              fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all\n      1.00    interface to zlib that skips zlib header\n      0.99    correct handling of alpha in palette\n      0.98    TGA loader by lonesock; dynamically add loaders (untested)\n      0.97    jpeg errors on too large a file; also catch another malloc failure\n      0.96    fix detection of invalid v value - particleman@mollyrocket forum\n      0.95    during header scan, seek to markers in case of padding\n      0.94    STBI_NO_STDIO to disable stdio usage; rename all #defines the same\n      0.93    handle jpegtran output; verbose errors\n      0.92    read 4,8,16,24,32-bit BMP files of several formats\n      0.91    output 24-bit Windows 3.0 BMP files\n      0.90    fix a few more warnings; bump version number to approach 1.0\n      0.61    bugfixes due to Marc LeBlanc, Christopher Lloyd\n      0.60    fix compiling as c++\n      0.59    fix warnings: merge Dave Moore's -Wall fixes\n      0.58    fix bug: zlib uncompressed mode len/nlen was wrong endian\n      0.57    fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available\n      0.56    fix bug: zlib uncompressed mode len vs. 
nlen\n      0.55    fix bug: restart_interval not initialized to 0\n      0.54    allow NULL for 'int *comp'\n      0.53    fix bug in png 3->4; speedup png decoding\n      0.52    png handles req_comp=3,4 directly; minor cleanup; jpeg comments\n      0.51    obey req_comp requests, 1-component jpegs return as 1-component,\n              on 'test' only check type, not whether we support this variant\n      0.50  (2006-11-19)\n              first released version\n*/\n\n\n/*\n------------------------------------------------------------------------------\nThis software is available under 2 licenses -- choose whichever you prefer.\n------------------------------------------------------------------------------\nALTERNATIVE A - MIT License\nCopyright (c) 2017 Sean Barrett\nPermission is hereby granted, free of charge, to any person obtaining a copy of\nthis software and associated documentation files (the \"Software\"), to deal in\nthe Software without restriction, including without limitation the rights to\nuse, copy, modify, merge, publish, distribute, sublicense, and/or sell copies\nof the Software, and to permit persons to whom the Software is furnished to do\nso, subject to the following conditions:\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n------------------------------------------------------------------------------\nALTERNATIVE B - Public Domain (www.unlicense.org)\nThis is free and unencumbered software released into the public domain.\nAnyone is free to copy, modify, publish, use, compile, sell, or distribute this\nsoftware, either in source code form or as a compiled binary, for any purpose,\ncommercial or non-commercial, and by any means.\nIn jurisdictions that recognize copyright laws, the author or authors of this\nsoftware dedicate any and all copyright interest in the software to the public\ndomain. We make this dedication for the benefit of the public at large and to\nthe detriment of our heirs and successors. We intend this dedication to be an\novert act of relinquishment in perpetuity of all present and future rights to\nthis software under copyright law.\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN\nACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION\nWITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n------------------------------------------------------------------------------\n*/\n"
  },
  {
    "path": "examples/stb_image_write.h",
    "content": "/* stb_image_write - v1.16 - public domain - http://nothings.org/stb\n   writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015\n                                     no warranty implied; use at your own risk\n\n   Before #including,\n\n       #define STB_IMAGE_WRITE_IMPLEMENTATION\n\n   in the file that you want to have the implementation.\n\n   Will probably not work correctly with strict-aliasing optimizations.\n\nABOUT:\n\n   This header file is a library for writing images to C stdio or a callback.\n\n   The PNG output is not optimal; it is 20-50% larger than the file\n   written by a decent optimizing implementation; though providing a custom\n   zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that.\n   This library is designed for source code compactness and simplicity,\n   not optimal image file size or run-time performance.\n\nBUILDING:\n\n   You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h.\n   You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace\n   malloc,realloc,free.\n   You can #define STBIW_MEMMOVE() to replace memmove()\n   You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function\n   for PNG compression (instead of the builtin one), it must have the following signature:\n   unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality);\n   The returned data will be freed with STBIW_FREE() (free() by default),\n   so it must be heap allocated with STBIW_MALLOC() (malloc() by default),\n\nUNICODE:\n\n   If compiling for Windows and you wish to use Unicode filenames, compile\n   with\n       #define STBIW_WINDOWS_UTF8\n   and pass utf8-encoded filenames. 
Call stbiw_convert_wchar_to_utf8 to convert\n   Windows wchar_t filenames to utf8.\n\nUSAGE:\n\n   There are five functions, one for each image file format:\n\n     int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);\n     int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);\n     int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);\n     int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality);\n     int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);\n\n     void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically\n\n   There are also five equivalent functions that use an arbitrary write function. You are\n   expected to open/close your file-equivalent before and after calling these:\n\n     int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data, int stride_in_bytes);\n     int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data);\n     int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data);\n     int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);\n     int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);\n\n   where the callback is:\n      void stbi_write_func(void *context, void *data, int size);\n\n   You can configure it with these global variables:\n      int stbi_write_tga_with_rle;             // defaults to true; set to 0 to disable RLE\n      int stbi_write_png_compression_level;    // defaults to 8; set to higher for more compression\n      int stbi_write_force_png_filter;         // defaults to -1; set to 0..5 to force a filter mode\n\n\n   You can define 
STBI_WRITE_NO_STDIO to disable the file variant of these\n   functions, so the library will not use stdio.h at all. However, this will\n   also disable HDR writing, because it requires stdio for formatted output.\n\n   Each function returns 0 on failure and non-0 on success.\n\n   The functions create an image file defined by the parameters. The image\n   is a rectangle of pixels stored from left-to-right, top-to-bottom.\n   Each pixel contains 'comp' channels of data stored interleaved with 8-bits\n   per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is\n   monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall.\n   The *data pointer points to the first byte of the top-left-most pixel.\n   For PNG, \"stride_in_bytes\" is the distance in bytes from the first byte of\n   a row of pixels to the first byte of the next row of pixels.\n\n   PNG creates output files with the same number of components as the input.\n   The BMP format expands Y to RGB in the file format and does not\n   output alpha.\n\n   PNG supports writing rectangles of data even when the bytes storing rows of\n   data are not consecutive in memory (e.g. sub-rectangles of a larger image),\n   by supplying the stride between the beginning of adjacent rows. The other\n   formats do not. (Thus you cannot write a native-format BMP through the BMP\n   writer, both because it is in BGR order and because it may have padding\n   at the end of the line.)\n\n   PNG allows you to set the deflate compression level by setting the global\n   variable 'stbi_write_png_compression_level' (it defaults to 8).\n\n   HDR expects linear float data. Since the format is always 32-bit rgb(e)\n   data, alpha (if provided) is discarded, and for monochrome data it is\n   replicated across all three channels.\n\n   TGA supports RLE or non-RLE compressed data. 
To use non-RLE-compressed\n   data, set the global variable 'stbi_write_tga_with_rle' to 0.\n\n   JPEG does ignore alpha channels in input data; quality is between 1 and 100.\n   Higher quality looks better but results in a bigger image.\n   JPEG baseline (no JPEG progressive).\n\nCREDITS:\n\n\n   Sean Barrett           -    PNG/BMP/TGA\n   Baldur Karlsson        -    HDR\n   Jean-Sebastien Guay    -    TGA monochrome\n   Tim Kelsey             -    misc enhancements\n   Alan Hickman           -    TGA RLE\n   Emmanuel Julien        -    initial file IO callback implementation\n   Jon Olick              -    original jo_jpeg.cpp code\n   Daniel Gibson          -    integrate JPEG, allow external zlib\n   Aarni Koskela          -    allow choosing PNG filter\n\n   bugfixes:\n      github:Chribba\n      Guillaume Chereau\n      github:jry2\n      github:romigrou\n      Sergio Gonzalez\n      Jonas Karlsson\n      Filip Wasil\n      Thatcher Ulrich\n      github:poppolopoppo\n      Patrick Boettcher\n      github:xeekworx\n      Cap Petschulat\n      Simon Rodriguez\n      Ivan Tikhonov\n      github:ignotion\n      Adam Schackart\n      Andrew Kensler\n\nLICENSE\n\n  See end of file for license information.\n\n*/\n\n#ifndef INCLUDE_STB_IMAGE_WRITE_H\n#define INCLUDE_STB_IMAGE_WRITE_H\n\n#include <stdlib.h>\n\n// if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline'\n#ifndef STBIWDEF\n#ifdef STB_IMAGE_WRITE_STATIC\n#define STBIWDEF  static\n#else\n#ifdef __cplusplus\n#define STBIWDEF  extern \"C\"\n#else\n#define STBIWDEF  extern\n#endif\n#endif\n#endif\n\n#ifndef STB_IMAGE_WRITE_STATIC  // C++ forbids static forward declarations\nSTBIWDEF int stbi_write_tga_with_rle;\nSTBIWDEF int stbi_write_png_compression_level;\nSTBIWDEF int stbi_write_force_png_filter;\n#endif\n\n#ifndef STBI_WRITE_NO_STDIO\nSTBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void  *data, int stride_in_bytes);\nSTBIWDEF int 
stbi_write_bmp(char const *filename, int w, int h, int comp, const void  *data);\nSTBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void  *data);\nSTBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);\nSTBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void  *data, int quality);\n\n#ifdef STBIW_WINDOWS_UTF8\nSTBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);\n#endif\n#endif\n\ntypedef void stbi_write_func(void *context, void *data, int size);\n\nSTBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data, int stride_in_bytes);\nSTBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data);\nSTBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data);\nSTBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);\nSTBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void  *data, int quality);\n\nSTBIWDEF void stbi_flip_vertically_on_write(int flip_boolean);\n\n#endif//INCLUDE_STB_IMAGE_WRITE_H\n\n#ifdef STB_IMAGE_WRITE_IMPLEMENTATION\n\n#ifdef _WIN32\n   #ifndef _CRT_SECURE_NO_WARNINGS\n   #define _CRT_SECURE_NO_WARNINGS\n   #endif\n   #ifndef _CRT_NONSTDC_NO_DEPRECATE\n   #define _CRT_NONSTDC_NO_DEPRECATE\n   #endif\n#endif\n\n#ifndef STBI_WRITE_NO_STDIO\n#include <stdio.h>\n#endif // STBI_WRITE_NO_STDIO\n\n#include <stdarg.h>\n#include <stdlib.h>\n#include <string.h>\n#include <math.h>\n\n#if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED))\n// ok\n#elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED)\n// ok\n#else\n#error \"Must define all 
or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED).\"\n#endif\n\n#ifndef STBIW_MALLOC\n#define STBIW_MALLOC(sz)        malloc(sz)\n#define STBIW_REALLOC(p,newsz)  realloc(p,newsz)\n#define STBIW_FREE(p)           free(p)\n#endif\n\n#ifndef STBIW_REALLOC_SIZED\n#define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz)\n#endif\n\n\n#ifndef STBIW_MEMMOVE\n#define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz)\n#endif\n\n\n#ifndef STBIW_ASSERT\n#include <assert.h>\n#define STBIW_ASSERT(x) assert(x)\n#endif\n\n#define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff)\n\n#ifdef STB_IMAGE_WRITE_STATIC\nstatic int stbi_write_png_compression_level = 8;\nstatic int stbi_write_tga_with_rle = 1;\nstatic int stbi_write_force_png_filter = -1;\n#else\nint stbi_write_png_compression_level = 8;\nint stbi_write_tga_with_rle = 1;\nint stbi_write_force_png_filter = -1;\n#endif\n\nstatic int stbi__flip_vertically_on_write = 0;\n\nSTBIWDEF void stbi_flip_vertically_on_write(int flag)\n{\n   stbi__flip_vertically_on_write = flag;\n}\n\ntypedef struct\n{\n   stbi_write_func *func;\n   void *context;\n   unsigned char buffer[64];\n   int buf_used;\n} stbi__write_context;\n\n// initialize a callback-based context\nstatic void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context)\n{\n   s->func    = c;\n   s->context = context;\n}\n\n#ifndef STBI_WRITE_NO_STDIO\n\nstatic void stbi__stdio_write(void *context, void *data, int size)\n{\n   fwrite(data,1,size,(FILE*) context);\n}\n\n#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)\n#ifdef __cplusplus\n#define STBIW_EXTERN extern \"C\"\n#else\n#define STBIW_EXTERN extern\n#endif\nSTBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide);\nSTBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char 
*str, int cbmb, const char *defchar, int *used_default);\n\nSTBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input)\n{\n   return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL);\n}\n#endif\n\nstatic FILE *stbiw__fopen(char const *filename, char const *mode)\n{\n   FILE *f;\n#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)\n   wchar_t wMode[64];\n   wchar_t wFilename[1024];\n   if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename)))\n      return 0;\n\n   if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode)))\n      return 0;\n\n#if defined(_MSC_VER) && _MSC_VER >= 1400\n   if (0 != _wfopen_s(&f, wFilename, wMode))\n      f = 0;\n#else\n   f = _wfopen(wFilename, wMode);\n#endif\n\n#elif defined(_MSC_VER) && _MSC_VER >= 1400\n   if (0 != fopen_s(&f, filename, mode))\n      f=0;\n#else\n   f = fopen(filename, mode);\n#endif\n   return f;\n}\n\nstatic int stbi__start_write_file(stbi__write_context *s, const char *filename)\n{\n   FILE *f = stbiw__fopen(filename, \"wb\");\n   stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f);\n   return f != NULL;\n}\n\nstatic void stbi__end_write_file(stbi__write_context *s)\n{\n   fclose((FILE *)s->context);\n}\n\n#endif // !STBI_WRITE_NO_STDIO\n\ntypedef unsigned int stbiw_uint32;\ntypedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 
1 : -1];\n\nstatic void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v)\n{\n   while (*fmt) {\n      switch (*fmt++) {\n         case ' ': break;\n         case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int));\n                     s->func(s->context,&x,1);\n                     break; }\n         case '2': { int x = va_arg(v,int);\n                     unsigned char b[2];\n                     b[0] = STBIW_UCHAR(x);\n                     b[1] = STBIW_UCHAR(x>>8);\n                     s->func(s->context,b,2);\n                     break; }\n         case '4': { stbiw_uint32 x = va_arg(v,int);\n                     unsigned char b[4];\n                     b[0]=STBIW_UCHAR(x);\n                     b[1]=STBIW_UCHAR(x>>8);\n                     b[2]=STBIW_UCHAR(x>>16);\n                     b[3]=STBIW_UCHAR(x>>24);\n                     s->func(s->context,b,4);\n                     break; }\n         default:\n            STBIW_ASSERT(0);\n            return;\n      }\n   }\n}\n\nstatic void stbiw__writef(stbi__write_context *s, const char *fmt, ...)\n{\n   va_list v;\n   va_start(v, fmt);\n   stbiw__writefv(s, fmt, v);\n   va_end(v);\n}\n\nstatic void stbiw__write_flush(stbi__write_context *s)\n{\n   if (s->buf_used) {\n      s->func(s->context, &s->buffer, s->buf_used);\n      s->buf_used = 0;\n   }\n}\n\nstatic void stbiw__putc(stbi__write_context *s, unsigned char c)\n{\n   s->func(s->context, &c, 1);\n}\n\nstatic void stbiw__write1(stbi__write_context *s, unsigned char a)\n{\n   if ((size_t)s->buf_used + 1 > sizeof(s->buffer))\n      stbiw__write_flush(s);\n   s->buffer[s->buf_used++] = a;\n}\n\nstatic void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c)\n{\n   int n;\n   if ((size_t)s->buf_used + 3 > sizeof(s->buffer))\n      stbiw__write_flush(s);\n   n = s->buf_used;\n   s->buf_used = n+3;\n   s->buffer[n+0] = a;\n   s->buffer[n+1] = b;\n   s->buffer[n+2] = c;\n}\n\nstatic void 
stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d)\n{\n   unsigned char bg[3] = { 255, 0, 255}, px[3];\n   int k;\n\n   if (write_alpha < 0)\n      stbiw__write1(s, d[comp - 1]);\n\n   switch (comp) {\n      case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case\n      case 1:\n         if (expand_mono)\n            stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp\n         else\n            stbiw__write1(s, d[0]);  // monochrome TGA\n         break;\n      case 4:\n         if (!write_alpha) {\n            // composite against pink background\n            for (k = 0; k < 3; ++k)\n               px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255;\n            stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]);\n            break;\n         }\n         /* FALLTHROUGH */\n      case 3:\n         stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]);\n         break;\n   }\n   if (write_alpha > 0)\n      stbiw__write1(s, d[comp - 1]);\n}\n\nstatic void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono)\n{\n   stbiw_uint32 zero = 0;\n   int i,j, j_end;\n\n   if (y <= 0)\n      return;\n\n   if (stbi__flip_vertically_on_write)\n      vdir *= -1;\n\n   if (vdir < 0) {\n      j_end = -1; j = y-1;\n   } else {\n      j_end =  y; j = 0;\n   }\n\n   for (; j != j_end; j += vdir) {\n      for (i=0; i < x; ++i) {\n         unsigned char *d = (unsigned char *) data + (j*x+i)*comp;\n         stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d);\n      }\n      stbiw__write_flush(s);\n      s->func(s->context, &zero, scanline_pad);\n   }\n}\n\nstatic int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...)\n{\n   if (y < 0 || x < 0) {\n      return 0;\n  
 } else {\n      va_list v;\n      va_start(v, fmt);\n      stbiw__writefv(s, fmt, v);\n      va_end(v);\n      stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono);\n      return 1;\n   }\n}\n\nstatic int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data)\n{\n   if (comp != 4) {\n      // write RGB bitmap\n      int pad = (-x*3) & 3;\n      return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad,\n              \"11 4 22 4\" \"4 44 22 444444\",\n              'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40,  // file header\n               40, x,y, 1,24, 0,0,0,0,0,0);             // bitmap header\n   } else {\n      // RGBA bitmaps need a v4 header\n      // use BI_BITFIELDS mode with 32bpp and alpha mask\n      // (straight BI_RGB with alpha mask doesn't work in most readers)\n      return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0,\n         \"11 4 22 4\" \"4 44 22 444444 4444 4 444 444 444 444\",\n         'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header\n         108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header\n   }\n}\n\nSTBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)\n{\n   stbi__write_context s = { 0 };\n   stbi__start_write_callbacks(&s, func, context);\n   return stbi_write_bmp_core(&s, x, y, comp, data);\n}\n\n#ifndef STBI_WRITE_NO_STDIO\nSTBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data)\n{\n   stbi__write_context s = { 0 };\n   if (stbi__start_write_file(&s,filename)) {\n      int r = stbi_write_bmp_core(&s, x, y, comp, data);\n      stbi__end_write_file(&s);\n      return r;\n   } else\n      return 0;\n}\n#endif //!STBI_WRITE_NO_STDIO\n\nstatic int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data)\n{\n   int has_alpha = (comp == 2 || comp == 4);\n   int colorbytes = has_alpha ? 
comp-1 : comp;\n   int format = colorbytes < 2 ? 3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3\n\n   if (y < 0 || x < 0)\n      return 0;\n\n   if (!stbi_write_tga_with_rle) {\n      return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0,\n         \"111 221 2222 11\", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8);\n   } else {\n      int i,j,k;\n      int jend, jdir;\n\n      stbiw__writef(s, \"111 221 2222 11\", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8);\n\n      if (stbi__flip_vertically_on_write) {\n         j = 0;\n         jend = y;\n         jdir = 1;\n      } else {\n         j = y-1;\n         jend = -1;\n         jdir = -1;\n      }\n      for (; j != jend; j += jdir) {\n         unsigned char *row = (unsigned char *) data + j * x * comp;\n         int len;\n\n         for (i = 0; i < x; i += len) {\n            unsigned char *begin = row + i * comp;\n            int diff = 1;\n            len = 1;\n\n            if (i < x - 1) {\n               ++len;\n               diff = memcmp(begin, row + (i + 1) * comp, comp);\n               if (diff) {\n                  const unsigned char *prev = begin;\n                  for (k = i + 2; k < x && len < 128; ++k) {\n                     if (memcmp(prev, row + k * comp, comp)) {\n                        prev += comp;\n                        ++len;\n                     } else {\n                        --len;\n                        break;\n                     }\n                  }\n               } else {\n                  for (k = i + 2; k < x && len < 128; ++k) {\n                     if (!memcmp(begin, row + k * comp, comp)) {\n                        ++len;\n                     } else {\n                        break;\n                     }\n                  }\n               }\n            }\n\n            if (diff) {\n               unsigned char header = STBIW_UCHAR(len - 1);\n           
    stbiw__write1(s, header);\n               for (k = 0; k < len; ++k) {\n                  stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp);\n               }\n            } else {\n               unsigned char header = STBIW_UCHAR(len - 129);\n               stbiw__write1(s, header);\n               stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin);\n            }\n         }\n      }\n      stbiw__write_flush(s);\n   }\n   return 1;\n}\n\nSTBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)\n{\n   stbi__write_context s = { 0 };\n   stbi__start_write_callbacks(&s, func, context);\n   return stbi_write_tga_core(&s, x, y, comp, (void *) data);\n}\n\n#ifndef STBI_WRITE_NO_STDIO\nSTBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data)\n{\n   stbi__write_context s = { 0 };\n   if (stbi__start_write_file(&s,filename)) {\n      int r = stbi_write_tga_core(&s, x, y, comp, (void *) data);\n      stbi__end_write_file(&s);\n      return r;\n   } else\n      return 0;\n}\n#endif\n\n// *************************************************************************************************\n// Radiance RGBE HDR writer\n// by Baldur Karlsson\n\n#define stbiw__max(a, b)  ((a) > (b) ? 
(a) : (b))\n\n#ifndef STBI_WRITE_NO_STDIO\n\nstatic void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear)\n{\n   int exponent;\n   float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2]));\n\n   if (maxcomp < 1e-32f) {\n      rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0;\n   } else {\n      float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp;\n\n      rgbe[0] = (unsigned char)(linear[0] * normalize);\n      rgbe[1] = (unsigned char)(linear[1] * normalize);\n      rgbe[2] = (unsigned char)(linear[2] * normalize);\n      rgbe[3] = (unsigned char)(exponent + 128);\n   }\n}\n\nstatic void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte)\n{\n   unsigned char lengthbyte = STBIW_UCHAR(length+128);\n   STBIW_ASSERT(length+128 <= 255);\n   s->func(s->context, &lengthbyte, 1);\n   s->func(s->context, &databyte, 1);\n}\n\nstatic void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data)\n{\n   unsigned char lengthbyte = STBIW_UCHAR(length);\n   STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code\n   s->func(s->context, &lengthbyte, 1);\n   s->func(s->context, data, length);\n}\n\nstatic void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline)\n{\n   unsigned char scanlineheader[4] = { 2, 2, 0, 0 };\n   unsigned char rgbe[4];\n   float linear[3];\n   int x;\n\n   scanlineheader[2] = (width&0xff00)>>8;\n   scanlineheader[3] = (width&0x00ff);\n\n   /* skip RLE for images too small or large */\n   if (width < 8 || width >= 32768) {\n      for (x=0; x < width; x++) {\n         switch (ncomp) {\n            case 4: /* fallthrough */\n            case 3: linear[2] = scanline[x*ncomp + 2];\n                    linear[1] = scanline[x*ncomp + 1];\n                    linear[0] = scanline[x*ncomp + 0];\n                    break;\n            default:\n                    linear[0] = linear[1] = 
linear[2] = scanline[x*ncomp + 0];\n                    break;\n         }\n         stbiw__linear_to_rgbe(rgbe, linear);\n         s->func(s->context, rgbe, 4);\n      }\n   } else {\n      int c,r;\n      /* encode into scratch buffer */\n      for (x=0; x < width; x++) {\n         switch(ncomp) {\n            case 4: /* fallthrough */\n            case 3: linear[2] = scanline[x*ncomp + 2];\n                    linear[1] = scanline[x*ncomp + 1];\n                    linear[0] = scanline[x*ncomp + 0];\n                    break;\n            default:\n                    linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];\n                    break;\n         }\n         stbiw__linear_to_rgbe(rgbe, linear);\n         scratch[x + width*0] = rgbe[0];\n         scratch[x + width*1] = rgbe[1];\n         scratch[x + width*2] = rgbe[2];\n         scratch[x + width*3] = rgbe[3];\n      }\n\n      s->func(s->context, scanlineheader, 4);\n\n      /* RLE each component separately */\n      for (c=0; c < 4; c++) {\n         unsigned char *comp = &scratch[width*c];\n\n         x = 0;\n         while (x < width) {\n            // find first run\n            r = x;\n            while (r+2 < width) {\n               if (comp[r] == comp[r+1] && comp[r] == comp[r+2])\n                  break;\n               ++r;\n            }\n            if (r+2 >= width)\n               r = width;\n            // dump up to first run\n            while (x < r) {\n               int len = r-x;\n               if (len > 128) len = 128;\n               stbiw__write_dump_data(s, len, &comp[x]);\n               x += len;\n            }\n            // if there's a run, output it\n            if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd\n               // find next byte after run\n               while (r < width && comp[r] == comp[x])\n                  ++r;\n               // output run up to r\n               while (x < r) {\n            
      int len = r-x;\n                  if (len > 127) len = 127;\n                  stbiw__write_run_data(s, len, comp[x]);\n                  x += len;\n               }\n            }\n         }\n      }\n   }\n}\n\nstatic int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data)\n{\n   if (y <= 0 || x <= 0 || data == NULL)\n      return 0;\n   else {\n      // Each component is stored separately. Allocate scratch space for full output scanline.\n      unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4);\n      int i, len;\n      char buffer[128];\n      char header[] = \"#?RADIANCE\\n# Written by stb_image_write.h\\nFORMAT=32-bit_rle_rgbe\\n\";\n      s->func(s->context, header, sizeof(header)-1);\n\n#ifdef __STDC_LIB_EXT1__\n      len = sprintf_s(buffer, sizeof(buffer), \"EXPOSURE=          1.0000000000000\\n\\n-Y %d +X %d\\n\", y, x);\n#else\n      len = sprintf(buffer, \"EXPOSURE=          1.0000000000000\\n\\n-Y %d +X %d\\n\", y, x);\n#endif\n      s->func(s->context, buffer, len);\n\n      for(i=0; i < y; i++)\n         stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? 
y-1-i : i));\n      STBIW_FREE(scratch);\n      return 1;\n   }\n}\n\nSTBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data)\n{\n   stbi__write_context s = { 0 };\n   stbi__start_write_callbacks(&s, func, context);\n   return stbi_write_hdr_core(&s, x, y, comp, (float *) data);\n}\n\nSTBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data)\n{\n   stbi__write_context s = { 0 };\n   if (stbi__start_write_file(&s,filename)) {\n      int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data);\n      stbi__end_write_file(&s);\n      return r;\n   } else\n      return 0;\n}\n#endif // STBI_WRITE_NO_STDIO\n\n\n//////////////////////////////////////////////////////////////////////////////\n//\n// PNG writer\n//\n\n#ifndef STBIW_ZLIB_COMPRESS\n// stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size()\n#define stbiw__sbraw(a) ((int *) (void *) (a) - 2)\n#define stbiw__sbm(a)   stbiw__sbraw(a)[0]\n#define stbiw__sbn(a)   stbiw__sbraw(a)[1]\n\n#define stbiw__sbneedgrow(a,n)  ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a))\n#define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0)\n#define stbiw__sbgrow(a,n)  stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a)))\n\n#define stbiw__sbpush(a, v)      (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v))\n#define stbiw__sbcount(a)        ((a) ? stbiw__sbn(a) : 0)\n#define stbiw__sbfree(a)         ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0)\n\nstatic void *stbiw__sbgrowf(void **arr, int increment, int itemsize)\n{\n   int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1;\n   void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? 
(stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2);\n   STBIW_ASSERT(p);\n   if (p) {\n      if (!*arr) ((int *) p)[1] = 0;\n      *arr = (void *) ((int *) p + 2);\n      stbiw__sbm(*arr) = m;\n   }\n   return *arr;\n}\n\nstatic unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount)\n{\n   while (*bitcount >= 8) {\n      stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer));\n      *bitbuffer >>= 8;\n      *bitcount -= 8;\n   }\n   return data;\n}\n\nstatic int stbiw__zlib_bitrev(int code, int codebits)\n{\n   int res=0;\n   while (codebits--) {\n      res = (res << 1) | (code & 1);\n      code >>= 1;\n   }\n   return res;\n}\n\nstatic unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit)\n{\n   int i;\n   for (i=0; i < limit && i < 258; ++i)\n      if (a[i] != b[i]) break;\n   return i;\n}\n\nstatic unsigned int stbiw__zhash(unsigned char *data)\n{\n   stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16);\n   hash ^= hash << 3;\n   hash += hash >> 5;\n   hash ^= hash << 4;\n   hash += hash >> 17;\n   hash ^= hash << 25;\n   hash += hash >> 6;\n   return hash;\n}\n\n#define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount))\n#define stbiw__zlib_add(code,codebits) \\\n      (bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush())\n#define stbiw__zlib_huffa(b,c)  stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c)\n// default huffman tables\n#define stbiw__zlib_huff1(n)  stbiw__zlib_huffa(0x30 + (n), 8)\n#define stbiw__zlib_huff2(n)  stbiw__zlib_huffa(0x190 + (n)-144, 9)\n#define stbiw__zlib_huff3(n)  stbiw__zlib_huffa(0 + (n)-256,7)\n#define stbiw__zlib_huff4(n)  stbiw__zlib_huffa(0xc0 + (n)-280,8)\n#define stbiw__zlib_huff(n)  ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n))\n#define stbiw__zlib_huffb(n) ((n) <= 143 ? 
stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n))\n\n#define stbiw__ZHASH   16384\n\n#endif // STBIW_ZLIB_COMPRESS\n\nSTBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality)\n{\n#ifdef STBIW_ZLIB_COMPRESS\n   // user provided a zlib compress implementation, use that\n   return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality);\n#else // use builtin\n   static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 };\n   static unsigned char  lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4,  4,  5,  5,  5,  5,  0 };\n   static unsigned short distc[]   = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 };\n   static unsigned char  disteb[]  = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 };\n   unsigned int bitbuf=0;\n   int i,j, bitcount=0;\n   unsigned char *out = NULL;\n   unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**));\n   if (hash_table == NULL)\n      return NULL;\n   if (quality < 5) quality = 5;\n\n   stbiw__sbpush(out, 0x78);   // DEFLATE 32K window\n   stbiw__sbpush(out, 0x5e);   // FLEVEL = 1\n   stbiw__zlib_add(1,1);  // BFINAL = 1\n   stbiw__zlib_add(1,2);  // BTYPE = 1 -- fixed huffman\n\n   for (i=0; i < stbiw__ZHASH; ++i)\n      hash_table[i] = NULL;\n\n   i=0;\n   while (i < data_len-3) {\n      // hash next 3 bytes of data to be compressed\n      int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3;\n      unsigned char *bestloc = 0;\n      unsigned char **hlist = hash_table[h];\n      int n = stbiw__sbcount(hlist);\n      for (j=0; j < n; ++j) {\n         if (hlist[j]-data > i-32768) { // if entry lies within window\n            int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i);\n            if (d >= best) { best=d; bestloc=hlist[j]; }\n         }\n     
 }\n      // when hash table entry is too long, delete half the entries\n      if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) {\n         STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality);\n         stbiw__sbn(hash_table[h]) = quality;\n      }\n      stbiw__sbpush(hash_table[h],data+i);\n\n      if (bestloc) {\n         // \"lazy matching\" - check match at *next* byte, and if it's better, do cur byte as literal\n         h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1);\n         hlist = hash_table[h];\n         n = stbiw__sbcount(hlist);\n         for (j=0; j < n; ++j) {\n            if (hlist[j]-data > i-32767) {\n               int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1);\n               if (e > best) { // if next match is better, bail on current match\n                  bestloc = NULL;\n                  break;\n               }\n            }\n         }\n      }\n\n      if (bestloc) {\n         int d = (int) (data+i - bestloc); // distance back\n         STBIW_ASSERT(d <= 32767 && best <= 258);\n         for (j=0; best > lengthc[j+1]-1; ++j);\n         stbiw__zlib_huff(j+257);\n         if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]);\n         for (j=0; d > distc[j+1]-1; ++j);\n         stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5);\n         if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]);\n         i += best;\n      } else {\n         stbiw__zlib_huffb(data[i]);\n         ++i;\n      }\n   }\n   // write out final bytes\n   for (;i < data_len; ++i)\n      stbiw__zlib_huffb(data[i]);\n   stbiw__zlib_huff(256); // end of block\n   // pad with 0 bits to byte boundary\n   while (bitcount)\n      stbiw__zlib_add(0,1);\n\n   for (i=0; i < stbiw__ZHASH; ++i)\n      (void) stbiw__sbfree(hash_table[i]);\n   STBIW_FREE(hash_table);\n\n   // store uncompressed instead if compression was worse\n   if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) {\n      
stbiw__sbn(out) = 2;  // truncate to DEFLATE 32K window and FLEVEL = 1\n      for (j = 0; j < data_len;) {\n         int blocklen = data_len - j;\n         if (blocklen > 32767) blocklen = 32767;\n         stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression\n         stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN\n         stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8));\n         stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN\n         stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8));\n         memcpy(out+stbiw__sbn(out), data+j, blocklen);\n         stbiw__sbn(out) += blocklen;\n         j += blocklen;\n      }\n   }\n\n   {\n      // compute adler32 on input\n      unsigned int s1=1, s2=0;\n      int blocklen = (int) (data_len % 5552);\n      j=0;\n      while (j < data_len) {\n         for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; }\n         s1 %= 65521; s2 %= 65521;\n         j += blocklen;\n         blocklen = 5552;\n      }\n      stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8));\n      stbiw__sbpush(out, STBIW_UCHAR(s2));\n      stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8));\n      stbiw__sbpush(out, STBIW_UCHAR(s1));\n   }\n   *out_len = stbiw__sbn(out);\n   // make returned pointer freeable\n   STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len);\n   return (unsigned char *) stbiw__sbraw(out);\n#endif // STBIW_ZLIB_COMPRESS\n}\n\nstatic unsigned int stbiw__crc32(unsigned char *buffer, int len)\n{\n#ifdef STBIW_CRC32\n    return STBIW_CRC32(buffer, len);\n#else\n   static unsigned int crc_table[256] =\n   {\n      0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3,\n      0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91,\n      0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7,\n      0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5,\n      
0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B,\n      0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59,\n      0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F,\n      0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D,\n      0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433,\n      0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01,\n      0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457,\n      0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65,\n      0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB,\n      0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9,\n      0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F,\n      0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD,\n      0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683,\n      0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1,\n      0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7,\n      0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5,\n      0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B,\n      0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79,\n      0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F,\n      0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 
0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D,\n      0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713,\n      0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21,\n      0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777,\n      0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45,\n      0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB,\n      0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9,\n      0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF,\n      0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D\n   };\n\n   unsigned int crc = ~0u;\n   int i;\n   for (i=0; i < len; ++i)\n      crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)];\n   return ~crc;\n#endif\n}\n\n#define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4)\n#define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v));\n#define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3])\n\nstatic void stbiw__wpcrc(unsigned char **data, int len)\n{\n   unsigned int crc = stbiw__crc32(*data - len - 4, len+4);\n   stbiw__wp32(*data, crc);\n}\n\nstatic unsigned char stbiw__paeth(int a, int b, int c)\n{\n   int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c);\n   if (pa <= pb && pa <= pc) return STBIW_UCHAR(a);\n   if (pb <= pc) return STBIW_UCHAR(b);\n   return STBIW_UCHAR(c);\n}\n\n// @OPTIMIZE: provide an option that always forces left-predict or paeth predict\nstatic void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer)\n{\n   static int 
mapping[] = { 0,1,2,3,4 };\n   static int firstmap[] = { 0,1,0,5,6 };\n   int *mymap = (y != 0) ? mapping : firstmap;\n   int i;\n   int type = mymap[filter_type];\n   unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y);\n   int signed_stride = stbi__flip_vertically_on_write ? -stride_bytes : stride_bytes;\n\n   if (type==0) {\n      memcpy(line_buffer, z, width*n);\n      return;\n   }\n\n   // first loop isn't optimized since it's just one pixel\n   for (i = 0; i < n; ++i) {\n      switch (type) {\n         case 1: line_buffer[i] = z[i]; break;\n         case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break;\n         case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break;\n         case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break;\n         case 5: line_buffer[i] = z[i]; break;\n         case 6: line_buffer[i] = z[i]; break;\n      }\n   }\n   switch (type) {\n      case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break;\n      case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break;\n      case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break;\n      case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break;\n      case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break;\n      case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break;\n   }\n}\n\nSTBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len)\n{\n   int force_filter = stbi_write_force_png_filter;\n   int ctype[5] = { -1, 0, 4, 2, 6 };\n   unsigned char sig[8] = { 137,80,78,71,13,10,26,10 };\n   unsigned char *out,*o, *filt, *zlib;\n   signed char *line_buffer;\n   int j,zlen;\n\n   if (stride_bytes == 0)\n      
stride_bytes = x * n;\n\n   if (force_filter >= 5) {\n      force_filter = -1;\n   }\n\n   filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0;\n   line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; }\n   for (j=0; j < y; ++j) {\n      int filter_type;\n      if (force_filter > -1) {\n         filter_type = force_filter;\n         stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer);\n      } else { // Estimate the best filter by running through all of them:\n         int best_filter = 0, best_filter_val = 0x7fffffff, est, i;\n         for (filter_type = 0; filter_type < 5; filter_type++) {\n            stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer);\n\n            // Estimate the entropy of the line using this filter; the less, the better.\n            est = 0;\n            for (i = 0; i < x*n; ++i) {\n               est += abs((signed char) line_buffer[i]);\n            }\n            if (est < best_filter_val) {\n               best_filter_val = est;\n               best_filter = filter_type;\n            }\n         }\n         if (filter_type != best_filter) {  // If the last iteration already got us the best filter, don't redo it\n            stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer);\n            filter_type = best_filter;\n         }\n      }\n      // when we get here, filter_type contains the filter type, and line_buffer contains the data\n      filt[j*(x*n+1)] = (unsigned char) filter_type;\n      STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n);\n   }\n   STBIW_FREE(line_buffer);\n   zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level);\n   STBIW_FREE(filt);\n   if (!zlib) return 0;\n\n   // each tag requires 12 bytes of overhead\n   out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 
12);\n   if (!out) return 0;\n   *out_len = 8 + 12+13 + 12+zlen + 12;\n\n   o=out;\n   STBIW_MEMMOVE(o,sig,8); o+= 8;\n   stbiw__wp32(o, 13); // header length\n   stbiw__wptag(o, \"IHDR\");\n   stbiw__wp32(o, x);\n   stbiw__wp32(o, y);\n   *o++ = 8;\n   *o++ = STBIW_UCHAR(ctype[n]);\n   *o++ = 0;\n   *o++ = 0;\n   *o++ = 0;\n   stbiw__wpcrc(&o,13);\n\n   stbiw__wp32(o, zlen);\n   stbiw__wptag(o, \"IDAT\");\n   STBIW_MEMMOVE(o, zlib, zlen);\n   o += zlen;\n   STBIW_FREE(zlib);\n   stbiw__wpcrc(&o, zlen);\n\n   stbiw__wp32(o,0);\n   stbiw__wptag(o, \"IEND\");\n   stbiw__wpcrc(&o,0);\n\n   STBIW_ASSERT(o == out + *out_len);\n\n   return out;\n}\n\n#ifndef STBI_WRITE_NO_STDIO\nSTBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes)\n{\n   FILE *f;\n   int len;\n   unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);\n   if (png == NULL) return 0;\n\n   f = stbiw__fopen(filename, \"wb\");\n   if (!f) { STBIW_FREE(png); return 0; }\n   fwrite(png, 1, len, f);\n   fclose(f);\n   STBIW_FREE(png);\n   return 1;\n}\n#endif\n\nSTBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes)\n{\n   int len;\n   unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);\n   if (png == NULL) return 0;\n   func(context, png, len);\n   STBIW_FREE(png);\n   return 1;\n}\n\n\n/* ***************************************************************************\n *\n * JPEG writer\n *\n * This is based on Jon Olick's jo_jpeg.cpp:\n * public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html\n */\n\nstatic const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18,\n      24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 };\n\nstatic 
void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) {\n   int bitBuf = *bitBufP, bitCnt = *bitCntP;\n   bitCnt += bs[1];\n   bitBuf |= bs[0] << (24 - bitCnt);\n   while(bitCnt >= 8) {\n      unsigned char c = (bitBuf >> 16) & 255;\n      stbiw__putc(s, c);\n      if(c == 255) {\n         stbiw__putc(s, 0);\n      }\n      bitBuf <<= 8;\n      bitCnt -= 8;\n   }\n   *bitBufP = bitBuf;\n   *bitCntP = bitCnt;\n}\n\nstatic void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) {\n   float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = *d5p, d6 = *d6p, d7 = *d7p;\n   float z1, z2, z3, z4, z5, z11, z13;\n\n   float tmp0 = d0 + d7;\n   float tmp7 = d0 - d7;\n   float tmp1 = d1 + d6;\n   float tmp6 = d1 - d6;\n   float tmp2 = d2 + d5;\n   float tmp5 = d2 - d5;\n   float tmp3 = d3 + d4;\n   float tmp4 = d3 - d4;\n\n   // Even part\n   float tmp10 = tmp0 + tmp3;   // phase 2\n   float tmp13 = tmp0 - tmp3;\n   float tmp11 = tmp1 + tmp2;\n   float tmp12 = tmp1 - tmp2;\n\n   d0 = tmp10 + tmp11;       // phase 3\n   d4 = tmp10 - tmp11;\n\n   z1 = (tmp12 + tmp13) * 0.707106781f; // c4\n   d2 = tmp13 + z1;       // phase 5\n   d6 = tmp13 - z1;\n\n   // Odd part\n   tmp10 = tmp4 + tmp5;       // phase 2\n   tmp11 = tmp5 + tmp6;\n   tmp12 = tmp6 + tmp7;\n\n   // The rotator is modified from fig 4-8 to avoid extra negations.\n   z5 = (tmp10 - tmp12) * 0.382683433f; // c6\n   z2 = tmp10 * 0.541196100f + z5; // c2-c6\n   z4 = tmp12 * 1.306562965f + z5; // c2+c6\n   z3 = tmp11 * 0.707106781f; // c4\n\n   z11 = tmp7 + z3;      // phase 5\n   z13 = tmp7 - z3;\n\n   *d5p = z13 + z2;         // phase 6\n   *d3p = z13 - z2;\n   *d1p = z11 + z4;\n   *d7p = z11 - z4;\n\n   *d0p = d0;  *d2p = d2;  *d4p = d4;  *d6p = d6;\n}\n\nstatic void stbiw__jpg_calcBits(int val, unsigned short bits[2]) {\n   int tmp1 = val < 0 ? -val : val;\n   val = val < 0 ? 
val-1 : val;\n   bits[1] = 1;\n   while(tmp1 >>= 1) {\n      ++bits[1];\n   }\n   bits[0] = val & ((1<<bits[1])-1);\n}\n\nstatic int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt, float *CDU, int du_stride, float *fdtbl, int DC, const unsigned short HTDC[256][2], const unsigned short HTAC[256][2]) {\n   const unsigned short EOB[2] = { HTAC[0x00][0], HTAC[0x00][1] };\n   const unsigned short M16zeroes[2] = { HTAC[0xF0][0], HTAC[0xF0][1] };\n   int dataOff, i, j, n, diff, end0pos, x, y;\n   int DU[64];\n\n   // DCT rows\n   for(dataOff=0, n=du_stride*8; dataOff<n; dataOff+=du_stride) {\n      stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+1], &CDU[dataOff+2], &CDU[dataOff+3], &CDU[dataOff+4], &CDU[dataOff+5], &CDU[dataOff+6], &CDU[dataOff+7]);\n   }\n   // DCT columns\n   for(dataOff=0; dataOff<8; ++dataOff) {\n      stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+du_stride], &CDU[dataOff+du_stride*2], &CDU[dataOff+du_stride*3], &CDU[dataOff+du_stride*4],\n                     &CDU[dataOff+du_stride*5], &CDU[dataOff+du_stride*6], &CDU[dataOff+du_stride*7]);\n   }\n   // Quantize/descale/zigzag the coefficients\n   for(y = 0, j=0; y < 8; ++y) {\n      for(x = 0; x < 8; ++x,++j) {\n         float v;\n         i = y*du_stride+x;\n         v = CDU[i]*fdtbl[j];\n         // DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? ceilf(v - 0.5f) : floorf(v + 0.5f));\n         // ceilf() and floorf() are C99, not C89, but I /think/ they're not needed here anyway?\n         DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? 
v - 0.5f : v + 0.5f);\n      }\n   }\n\n   // Encode DC\n   diff = DU[0] - DC;\n   if (diff == 0) {\n      stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[0]);\n   } else {\n      unsigned short bits[2];\n      stbiw__jpg_calcBits(diff, bits);\n      stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[bits[1]]);\n      stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);\n   }\n   // Encode ACs\n   end0pos = 63;\n   for(; (end0pos>0)&&(DU[end0pos]==0); --end0pos) {\n   }\n   // end0pos = first element in reverse order !=0\n   if(end0pos == 0) {\n      stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);\n      return DU[0];\n   }\n   for(i = 1; i <= end0pos; ++i) {\n      int startpos = i;\n      int nrzeroes;\n      unsigned short bits[2];\n      for (; DU[i]==0 && i<=end0pos; ++i) {\n      }\n      nrzeroes = i-startpos;\n      if ( nrzeroes >= 16 ) {\n         int lng = nrzeroes>>4;\n         int nrmarker;\n         for (nrmarker=1; nrmarker <= lng; ++nrmarker)\n            stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes);\n         nrzeroes &= 15;\n      }\n      stbiw__jpg_calcBits(DU[i], bits);\n      stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]);\n      stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);\n   }\n   if(end0pos != 63) {\n      stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);\n   }\n   return DU[0];\n}\n\nstatic int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) {\n   // Constants that don't pollute global namespace\n   static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0};\n   static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};\n   static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d};\n   static const unsigned char std_ac_luminance_values[] = {\n      0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08,\n      
0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28,\n      0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59,\n      0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89,\n      0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6,\n      0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2,\n      0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa\n   };\n   static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0};\n   static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};\n   static const unsigned char std_ac_chrominance_nrcodes[] = {0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77};\n   static const unsigned char std_ac_chrominance_values[] = {\n      0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91,\n      0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26,\n      0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,\n      0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87,\n      0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,\n      0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,\n      0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa\n   };\n   // Huffman tables\n   static 
const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}};\n   static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}};\n   static const unsigned short YAC_HT[256][2] = {\n      {10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      
{1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}\n   };\n   static const unsigned short UVAC_HT[256][2] = {\n      {0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      
{249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0},\n      {1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}\n   };\n   static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22,\n                             37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99};\n   static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99,\n                              99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99};\n   static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f,\n                                 1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 
0.275899379f * 2.828427125f };\n\n   int row, col, i, k, subsample;\n   float fdtbl_Y[64], fdtbl_UV[64];\n   unsigned char YTable[64], UVTable[64];\n\n   if(!data || !width || !height || comp > 4 || comp < 1) {\n      return 0;\n   }\n\n   quality = quality ? quality : 90;\n   subsample = quality <= 90 ? 1 : 0;\n   quality = quality < 1 ? 1 : quality > 100 ? 100 : quality;\n   quality = quality < 50 ? 5000 / quality : 200 - quality * 2;\n\n   for(i = 0; i < 64; ++i) {\n      int uvti, yti = (YQT[i]*quality+50)/100;\n      YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti);\n      uvti = (UVQT[i]*quality+50)/100;\n      UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti);\n   }\n\n   for(row = 0, k = 0; row < 8; ++row) {\n      for(col = 0; col < 8; ++col, ++k) {\n         fdtbl_Y[k]  = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);\n         fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);\n      }\n   }\n\n   // Write Headers\n   {\n      static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 };\n      static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 };\n      const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width),\n                                      3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 };\n      s->func(s->context, (void*)head0, sizeof(head0));\n      s->func(s->context, (void*)YTable, sizeof(YTable));\n      stbiw__putc(s, 1);\n      s->func(s->context, UVTable, sizeof(UVTable));\n      s->func(s->context, (void*)head1, sizeof(head1));\n      s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1);\n      s->func(s->context, (void*)std_dc_luminance_values, 
sizeof(std_dc_luminance_values));\n      stbiw__putc(s, 0x10); // HTYACinfo\n      s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1);\n      s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values));\n      stbiw__putc(s, 1); // HTUDCinfo\n      s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1);\n      s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values));\n      stbiw__putc(s, 0x11); // HTUACinfo\n      s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1);\n      s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values));\n      s->func(s->context, (void*)head2, sizeof(head2));\n   }\n\n   // Encode 8x8 macroblocks\n   {\n      static const unsigned short fillBits[] = {0x7F, 7};\n      int DCY=0, DCU=0, DCV=0;\n      int bitBuf=0, bitCnt=0;\n      // comp == 2 is grey+alpha (alpha is ignored)\n      int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0;\n      const unsigned char *dataR = (const unsigned char *)data;\n      const unsigned char *dataG = dataR + ofsG;\n      const unsigned char *dataB = dataR + ofsB;\n      int x, y, pos;\n      if(subsample) {\n         for(y = 0; y < height; y += 16) {\n            for(x = 0; x < width; x += 16) {\n               float Y[256], U[256], V[256];\n               for(row = y, pos = 0; row < y+16; ++row) {\n                  // row >= height => use last input row\n                  int clamped_row = (row < height) ? row : height - 1;\n                  int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;\n                  for(col = x; col < x+16; ++col, ++pos) {\n                     // if col >= width => use pixel from last input column\n                     int p = base_p + ((col < width) ? 
col : (width-1))*comp;\n                     float r = dataR[p], g = dataG[p], b = dataB[p];\n                     Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;\n                     U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;\n                     V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;\n                  }\n               }\n               DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0,   16, fdtbl_Y, DCY, YDC_HT, YAC_HT);\n               DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8,   16, fdtbl_Y, DCY, YDC_HT, YAC_HT);\n               DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);\n               DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);\n\n               // subsample U,V\n               {\n                  float subU[64], subV[64];\n                  int yy, xx;\n                  for(yy = 0, pos = 0; yy < 8; ++yy) {\n                     for(xx = 0; xx < 8; ++xx, ++pos) {\n                        int j = yy*32+xx*2;\n                        subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f;\n                        subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f;\n                     }\n                  }\n                  DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);\n                  DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);\n               }\n            }\n         }\n      } else {\n         for(y = 0; y < height; y += 8) {\n            for(x = 0; x < width; x += 8) {\n               float Y[64], U[64], V[64];\n               for(row = y, pos = 0; row < y+8; ++row) {\n                  // row >= height => use last input row\n                  int clamped_row = (row < height) ? row : height - 1;\n                  int base_p = (stbi__flip_vertically_on_write ? 
(height-1-clamped_row) : clamped_row)*width*comp;\n                  for(col = x; col < x+8; ++col, ++pos) {\n                     // if col >= width => use pixel from last input column\n                     int p = base_p + ((col < width) ? col : (width-1))*comp;\n                     float r = dataR[p], g = dataG[p], b = dataB[p];\n                     Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;\n                     U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;\n                     V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;\n                  }\n               }\n\n               DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y,  DCY, YDC_HT, YAC_HT);\n               DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);\n               DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);\n            }\n         }\n      }\n\n      // Do the bit alignment of the EOI marker\n      stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits);\n   }\n\n   // EOI\n   stbiw__putc(s, 0xFF);\n   stbiw__putc(s, 0xD9);\n\n   return 1;\n}\n\nSTBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality)\n{\n   stbi__write_context s = { 0 };\n   stbi__start_write_callbacks(&s, func, context);\n   return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality);\n}\n\n\n#ifndef STBI_WRITE_NO_STDIO\nSTBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality)\n{\n   stbi__write_context s = { 0 };\n   if (stbi__start_write_file(&s,filename)) {\n      int r = stbi_write_jpg_core(&s, x, y, comp, data, quality);\n      stbi__end_write_file(&s);\n      return r;\n   } else\n      return 0;\n}\n#endif\n\n#endif // STB_IMAGE_WRITE_IMPLEMENTATION\n\n/* Revision history\n      1.16  (2021-07-11)\n             make Deflate code emit uncompressed blocks when it would otherwise expand\n     
        support writing BMPs with alpha channel\n      1.15  (2020-07-13) unknown\n      1.14  (2020-02-02) updated JPEG writer to downsample chroma channels\n      1.13\n      1.12\n      1.11  (2019-08-11)\n\n      1.10  (2019-02-07)\n             support utf8 filenames in Windows; fix warnings and platform ifdefs\n      1.09  (2018-02-11)\n             fix typo in zlib quality API, improve STB_I_W_STATIC in C++\n      1.08  (2018-01-29)\n             add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter\n      1.07  (2017-07-24)\n             doc fix\n      1.06 (2017-07-23)\n             writing JPEG (using Jon Olick's code)\n      1.05   ???\n      1.04 (2017-03-03)\n             monochrome BMP expansion\n      1.03   ???\n      1.02 (2016-04-02)\n             avoid allocating large structures on the stack\n      1.01 (2016-01-16)\n             STBIW_REALLOC_SIZED: support allocators with no realloc support\n             avoid race-condition in crc initialization\n             minor compile issues\n      1.00 (2015-09-14)\n             installable file IO function\n      0.99 (2015-09-13)\n             warning fixes; TGA rle support\n      0.98 (2015-04-08)\n             added STBIW_MALLOC, STBIW_ASSERT etc\n      0.97 (2015-01-18)\n             fixed HDR asserts, rewrote HDR rle logic\n      0.96 (2015-01-17)\n             add HDR output\n             fix monochrome BMP\n      0.95 (2014-08-17)\n             add monochrome TGA output\n      0.94 (2014-05-31)\n             rename private functions to avoid conflicts with stb_image.h\n      0.93 (2014-05-27)\n             warning fixes\n      0.92 (2010-08-01)\n             casts to unsigned char to fix warnings\n      0.91 (2010-07-17)\n             first public release\n      0.90   first internal release\n*/\n\n/*\n------------------------------------------------------------------------------\nThis software is available under 2 licenses -- choose whichever you 
prefer.\n------------------------------------------------------------------------------\nALTERNATIVE A - MIT License\nCopyright (c) 2017 Sean Barrett\nPermission is hereby granted, free of charge, to any person obtaining a copy of\nthis software and associated documentation files (the \"Software\"), to deal in\nthe Software without restriction, including without limitation the rights to\nuse, copy, modify, merge, publish, distribute, sublicense, and/or sell copies\nof the Software, and to permit persons to whom the Software is furnished to do\nso, subject to the following conditions:\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n------------------------------------------------------------------------------\nALTERNATIVE B - Public Domain (www.unlicense.org)\nThis is free and unencumbered software released into the public domain.\nAnyone is free to copy, modify, publish, use, compile, sell, or distribute this\nsoftware, either in source code form or as a compiled binary, for any purpose,\ncommercial or non-commercial, and by any means.\nIn jurisdictions that recognize copyright laws, the author or authors of this\nsoftware dedicate any and all copyright interest in the software to the public\ndomain. We make this dedication for the benefit of the public at large and to\nthe detriment of our heirs and successors. 
We intend this dedication to be an\novert act of relinquishment in perpetuity of all present and future rights to\nthis software under copyright law.\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN\nACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION\nWITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n------------------------------------------------------------------------------\n*/\n"
  },
  {
    "path": "examples/test-cmake/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.14)\nproject(ggml-simple)\n\nset(CMAKE_CXX_STANDARD 17)\n\nfind_package(ggml CONFIG REQUIRED)\n\nset(TEST_TARGET test-cmake)\nadd_executable(test-cmake test-cmake.cpp)\ntarget_link_libraries(test-cmake PRIVATE ggml::ggml)\n"
  },
  {
    "path": "examples/test-cmake/README.md",
    "content": "## cmake-test\n\nThis directory can be built as a separate project with an installed ggml.\n"
  },
  {
    "path": "examples/test-cmake/test-cmake.cpp",
    "content": "#include \"ggml-backend.h\"\n\nint main(void) {\n    ggml_backend_load_all();\n    return 0;\n}\n"
  },
  {
    "path": "examples/yolo/CMakeLists.txt",
    "content": "#\n# yolov3-tiny\n\nset(TEST_TARGET yolov3-tiny)\nadd_executable(${TEST_TARGET} yolov3-tiny.cpp yolo-image.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml common)\n"
  },
  {
    "path": "examples/yolo/README.md",
    "content": "This example shows how to implement YOLO object detection with ggml using pretrained model.\n\n# YOLOv3-tiny\n\nDownload the model weights:\n\n```bash\n$ wget https://pjreddie.com/media/files/yolov3-tiny.weights\n$ sha1sum yolov3-tiny.weights \n40f3c11883bef62fd850213bc14266632ed4414f  yolov3-tiny.weights\n```\n\nConvert the weights to GGUF format:\n\n```bash\n$ ./convert-yolov3-tiny.py yolov3-tiny.weights\nyolov3-tiny.weights converted to yolov3-tiny.gguf\n```\n\nAlternatively, you can download the converted model from [HuggingFace](https://huggingface.co/rgerganov/yolo-gguf/resolve/main/yolov3-tiny.gguf)\n\nObject detection:\n\n```bash\n$ wget https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg\n$ ./yolov3-tiny -m yolov3-tiny.gguf -i dog.jpg\nload_model: using CUDA backend\nggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no\nggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no\nggml_cuda_init: found 1 CUDA devices:\n  Device 0: NVIDIA T1200 Laptop GPU, compute capability 7.5, VMM: yes\nLayer  0 output shape:  416 x 416 x   16 x   1\nLayer  1 output shape:  208 x 208 x   16 x   1\nLayer  2 output shape:  208 x 208 x   32 x   1\nLayer  3 output shape:  104 x 104 x   32 x   1\nLayer  4 output shape:  104 x 104 x   64 x   1\nLayer  5 output shape:   52 x  52 x   64 x   1\nLayer  6 output shape:   52 x  52 x  128 x   1\nLayer  7 output shape:   26 x  26 x  128 x   1\nLayer  8 output shape:   26 x  26 x  256 x   1\nLayer  9 output shape:   13 x  13 x  256 x   1\nLayer 10 output shape:   13 x  13 x  512 x   1\nLayer 11 output shape:   13 x  13 x  512 x   1\nLayer 12 output shape:   13 x  13 x 1024 x   1\nLayer 13 output shape:   13 x  13 x  256 x   1\nLayer 14 output shape:   13 x  13 x  512 x   1\nLayer 15 output shape:   13 x  13 x  255 x   1\nLayer 18 output shape:   13 x  13 x  128 x   1\nLayer 19 output shape:   26 x  26 x  128 x   1\nLayer 20 output shape:   26 x  26 x  384 x   1\nLayer 21 output shape:   26 x  26 x  256 x   1\nLayer 22 output 
shape:   26 x  26 x  255 x   1\ndog: 57%\ncar: 52%\ntruck: 56%\ncar: 62%\nbicycle: 59%\nDetected objects saved in 'predictions.jpg' (time: 0.057000 sec.)\n```"
  },
  {
    "path": "examples/yolo/convert-yolov3-tiny.py",
    "content": "#!/usr/bin/env python3\nimport sys\nimport gguf\nimport numpy as np\n\ndef save_conv2d_layer(f, gguf_writer, prefix, inp_c, filters, size, batch_normalize=True):\n    biases = np.fromfile(f, dtype=np.float32, count=filters)\n    gguf_writer.add_tensor(prefix + \"_biases\", biases, raw_shape=(1, filters, 1, 1))\n\n    if batch_normalize:\n        scales = np.fromfile(f, dtype=np.float32, count=filters)\n        gguf_writer.add_tensor(prefix + \"_scales\", scales, raw_shape=(1, filters, 1, 1))\n        rolling_mean = np.fromfile(f, dtype=np.float32, count=filters)\n        gguf_writer.add_tensor(prefix + \"_rolling_mean\", rolling_mean, raw_shape=(1, filters, 1, 1))\n        rolling_variance = np.fromfile(f, dtype=np.float32, count=filters)\n        gguf_writer.add_tensor(prefix + \"_rolling_variance\", rolling_variance, raw_shape=(1, filters, 1, 1))\n\n    weights_count = filters * inp_c * size * size\n    l0_weights = np.fromfile(f, dtype=np.float32, count=weights_count)\n    ## ggml doesn't support f32 convolution yet, use f16 instead\n    l0_weights = l0_weights.astype(np.float16)\n    gguf_writer.add_tensor(prefix + \"_weights\", l0_weights, raw_shape=(filters, inp_c, size, size))\n\n\nif __name__ == '__main__':\n    if len(sys.argv) != 2:\n        print(\"Usage: %s <yolov3-tiny.weights>\" % sys.argv[0])\n        sys.exit(1)\n    outfile = 'yolov3-tiny.gguf'\n    gguf_writer = gguf.GGUFWriter(outfile, 'yolov3-tiny')\n\n    f = open(sys.argv[1], 'rb')\n    f.read(20) # skip header\n    save_conv2d_layer(f, gguf_writer, \"l0\", 3, 16, 3)\n    save_conv2d_layer(f, gguf_writer, \"l1\", 16, 32, 3)\n    save_conv2d_layer(f, gguf_writer, \"l2\", 32, 64, 3)\n    save_conv2d_layer(f, gguf_writer, \"l3\", 64, 128, 3)\n    save_conv2d_layer(f, gguf_writer, \"l4\", 128, 256, 3)\n    save_conv2d_layer(f, gguf_writer, \"l5\", 256, 512, 3)\n    save_conv2d_layer(f, gguf_writer, \"l6\", 512, 1024, 3)\n    save_conv2d_layer(f, gguf_writer, \"l7\", 1024, 256, 1)\n 
   save_conv2d_layer(f, gguf_writer, \"l8\", 256, 512, 3)\n    save_conv2d_layer(f, gguf_writer, \"l9\", 512, 255, 1, batch_normalize=False)\n    save_conv2d_layer(f, gguf_writer, \"l10\", 256, 128, 1)\n    save_conv2d_layer(f, gguf_writer, \"l11\", 384, 256, 3)\n    save_conv2d_layer(f, gguf_writer, \"l12\", 256, 255, 1, batch_normalize=False)\n    f.close()\n    \n    gguf_writer.write_header_to_file()\n    gguf_writer.write_kv_data_to_file()\n    gguf_writer.write_tensors_to_file()\n    gguf_writer.close()\n    print(\"{} converted to {}\".format(sys.argv[1], outfile))\n"
  },
  {
    "path": "examples/yolo/data/coco.names",
    "content": "person\nbicycle\ncar\nmotorbike\naeroplane\nbus\ntrain\ntruck\nboat\ntraffic light\nfire hydrant\nstop sign\nparking meter\nbench\nbird\ncat\ndog\nhorse\nsheep\ncow\nelephant\nbear\nzebra\ngiraffe\nbackpack\numbrella\nhandbag\ntie\nsuitcase\nfrisbee\nskis\nsnowboard\nsports ball\nkite\nbaseball bat\nbaseball glove\nskateboard\nsurfboard\ntennis racket\nbottle\nwine glass\ncup\nfork\nknife\nspoon\nbowl\nbanana\napple\nsandwich\norange\nbroccoli\ncarrot\nhot dog\npizza\ndonut\ncake\nchair\nsofa\npottedplant\nbed\ndiningtable\ntoilet\ntvmonitor\nlaptop\nmouse\nremote\nkeyboard\ncell phone\nmicrowave\noven\ntoaster\nsink\nrefrigerator\nbook\nclock\nvase\nscissors\nteddy bear\nhair drier\ntoothbrush\n"
  },
  {
    "path": "examples/yolo/yolo-image.cpp",
    "content": "#define STB_IMAGE_IMPLEMENTATION\n#include \"stb_image.h\"\n#define STB_IMAGE_WRITE_IMPLEMENTATION\n#include \"stb_image_write.h\"\n\n#include \"yolo-image.h\"\n\nstatic void draw_box(yolo_image & a, int x1, int y1, int x2, int y2, float r, float g, float b)\n{\n    if (x1 < 0) x1 = 0;\n    if (x1 >= a.w) x1 = a.w-1;\n    if (x2 < 0) x2 = 0;\n    if (x2 >= a.w) x2 = a.w-1;\n\n    if (y1 < 0) y1 = 0;\n    if (y1 >= a.h) y1 = a.h-1;\n    if (y2 < 0) y2 = 0;\n    if (y2 >= a.h) y2 = a.h-1;\n\n    for (int i = x1; i <= x2; ++i){\n        a.data[i + y1*a.w + 0*a.w*a.h] = r;\n        a.data[i + y2*a.w + 0*a.w*a.h] = r;\n\n        a.data[i + y1*a.w + 1*a.w*a.h] = g;\n        a.data[i + y2*a.w + 1*a.w*a.h] = g;\n\n        a.data[i + y1*a.w + 2*a.w*a.h] = b;\n        a.data[i + y2*a.w + 2*a.w*a.h] = b;\n    }\n    for (int i = y1; i <= y2; ++i){\n        a.data[x1 + i*a.w + 0*a.w*a.h] = r;\n        a.data[x2 + i*a.w + 0*a.w*a.h] = r;\n\n        a.data[x1 + i*a.w + 1*a.w*a.h] = g;\n        a.data[x2 + i*a.w + 1*a.w*a.h] = g;\n\n        a.data[x1 + i*a.w + 2*a.w*a.h] = b;\n        a.data[x2 + i*a.w + 2*a.w*a.h] = b;\n    }\n}\n\nvoid draw_box_width(yolo_image & a, int x1, int y1, int x2, int y2, int w, float r, float g, float b)\n{\n    for (int i = 0; i < w; ++i) {\n        draw_box(a, x1+i, y1+i, x2-i, y2-i, r, g, b);\n    }\n}\n\nbool save_image(const yolo_image & im, const char *name, int quality)\n{\n    uint8_t *data = (uint8_t*)calloc(im.w*im.h*im.c, sizeof(uint8_t));\n    for (int k = 0; k < im.c; ++k) {\n        for (int i = 0; i < im.w*im.h; ++i) {\n            data[i*im.c+k] = (uint8_t) (255*im.data[i + k*im.w*im.h]);\n        }\n    }\n    int success = stbi_write_jpg(name, im.w, im.h, im.c, data, quality);\n    free(data);\n    if (!success) {\n        fprintf(stderr, \"Failed to write image %s\\n\", name);\n        return false;\n    }\n    return true;\n}\n\nbool load_image(const char *fname, yolo_image & img)\n{\n    int w, h, c;\n    uint8_t * 
data = stbi_load(fname, &w, &h, &c, 3);\n    if (!data) {\n        return false;\n    }\n    c = 3;\n    img.w = w;\n    img.h = h;\n    img.c = c;\n    img.data.resize(w*h*c);\n    for (int k = 0; k < c; ++k){\n        for (int j = 0; j < h; ++j){\n            for (int i = 0; i < w; ++i){\n                int dst_index = i + w*j + w*h*k;\n                int src_index = k + c*i + c*w*j;\n                img.data[dst_index] = (float)data[src_index]/255.;\n            }\n        }\n    }\n    stbi_image_free(data);\n    return true;\n}\n\nstatic yolo_image resize_image(const yolo_image & im, int w, int h)\n{\n    yolo_image resized(w, h, im.c);\n    yolo_image part(w, im.h, im.c);\n    float w_scale = (float)(im.w - 1) / (w - 1);\n    float h_scale = (float)(im.h - 1) / (h - 1);\n    for (int k = 0; k < im.c; ++k){\n        for (int r = 0; r < im.h; ++r) {\n            for (int c = 0; c < w; ++c) {\n                float val = 0;\n                if (c == w-1 || im.w == 1){\n                    val = im.get_pixel(im.w-1, r, k);\n                } else {\n                    float sx = c*w_scale;\n                    int ix = (int) sx;\n                    float dx = sx - ix;\n                    val = (1 - dx) * im.get_pixel(ix, r, k) + dx * im.get_pixel(ix+1, r, k);\n                }\n                part.set_pixel(c, r, k, val);\n            }\n        }\n    }\n    for (int k = 0; k < im.c; ++k){\n        for (int r = 0; r < h; ++r){\n            float sy = r*h_scale;\n            int iy = (int) sy;\n            float dy = sy - iy;\n            for (int c = 0; c < w; ++c){\n                float val = (1-dy) * part.get_pixel(c, iy, k);\n                resized.set_pixel(c, r, k, val);\n            }\n            if (r == h-1 || im.h == 1) continue;\n            for (int c = 0; c < w; ++c){\n                float val = dy * part.get_pixel(c, iy+1, k);\n                resized.add_pixel(c, r, k, val);\n            }\n        }\n    }\n    return 
resized;\n}\n\nstatic void embed_image(const yolo_image & source, yolo_image & dest, int dx, int dy)\n{\n    for (int k = 0; k < source.c; ++k) {\n        for (int y = 0; y < source.h; ++y) {\n            for (int x = 0; x < source.w; ++x) {\n                float val = source.get_pixel(x, y, k);\n                dest.set_pixel(dx+x, dy+y, k, val);\n            }\n        }\n    }\n}\n\nyolo_image letterbox_image(const yolo_image & im, int w, int h)\n{\n    int new_w = im.w;\n    int new_h = im.h;\n    if (((float)w/im.w) < ((float)h/im.h)) {\n        new_w = w;\n        new_h = (im.h * w)/im.w;\n    } else {\n        new_h = h;\n        new_w = (im.w * h)/im.h;\n    }\n    yolo_image resized = resize_image(im, new_w, new_h);\n    yolo_image boxed(w, h, im.c);\n    boxed.fill(0.5);\n    embed_image(resized, boxed, (w-new_w)/2, (h-new_h)/2);\n    return boxed;\n}\n\nstatic yolo_image tile_images(const yolo_image & a, const yolo_image & b, int dx)\n{\n    if (a.w == 0) {\n        return b;\n    }\n    yolo_image c(a.w + b.w + dx, (a.h > b.h) ? 
a.h : b.h, a.c);\n    c.fill(1.0f);\n    embed_image(a, c, 0, 0);\n    embed_image(b, c, a.w + dx, 0);\n    return c;\n}\n\nstatic yolo_image border_image(const yolo_image & a, int border)\n{\n    yolo_image b(a.w + 2*border, a.h + 2*border, a.c);\n    b.fill(1.0f);\n    embed_image(a, b, border, border);\n    return b;\n}\n\nyolo_image get_label(const std::vector<yolo_image> & alphabet, const std::string & label, int size)\n{\n    size = size/10;\n    size = std::min(size, 7);\n    yolo_image result(0,0,0);\n    for (int i = 0; i < (int)label.size(); ++i) {\n        int ch = label[i];\n        yolo_image img = alphabet[size*128 + ch];\n        result = tile_images(result, img, -size - 1 + (size+1)/2);\n    }\n    return border_image(result, (int)(result.h*.25));\n}\n\nvoid draw_label(yolo_image & im, int row, int col, const yolo_image & label, const float * rgb)\n{\n    int w = label.w;\n    int h = label.h;\n    if (row - h >= 0) {\n        row = row - h;\n    }\n    for (int j = 0; j < h && j + row < im.h; j++) {\n        for (int i = 0; i < w && i + col < im.w; i++) {\n            for (int k = 0; k < label.c; k++) {\n                float val = label.get_pixel(i, j, k);\n                im.set_pixel(i + col, j + row, k, rgb[k] * val);\n            }\n        }\n    }\n}"
  },
  {
    "path": "examples/yolo/yolo-image.h",
    "content": "#pragma once\n\n#include <string>\n#include <vector>\n#include <cassert>\n\nstruct yolo_image {\n    int w, h, c;\n    std::vector<float> data;\n\n    yolo_image() : w(0), h(0), c(0) {}\n    yolo_image(int w, int h, int c) : w(w), h(h), c(c), data(w*h*c) {}\n\n    float get_pixel(int x, int y, int c) const {\n        assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c);\n        return data[c*w*h + y*w + x];\n    }\n\n    void set_pixel(int x, int y, int c, float val) {\n        assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c);\n        data[c*w*h + y*w + x] = val;\n    }\n\n    void add_pixel(int x, int y, int c, float val) {\n        assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c);\n        data[c*w*h + y*w + x] += val;\n    }\n\n    void fill(float val) {\n        std::fill(data.begin(), data.end(), val);\n    }\n};\n\nbool load_image(const char *fname, yolo_image & img);\nvoid draw_box_width(yolo_image & a, int x1, int y1, int x2, int y2, int w, float r, float g, float b);\nyolo_image letterbox_image(const yolo_image & im, int w, int h);\nbool save_image(const yolo_image & im, const char *name, int quality);\nyolo_image get_label(const std::vector<yolo_image> & alphabet, const std::string & label, int size);\nvoid draw_label(yolo_image & im, int row, int col, const yolo_image & label, const float * rgb);\n"
  },
  {
    "path": "examples/yolo/yolov3-tiny.cpp",
    "content": "#include \"ggml.h\"\n#include \"gguf.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#include \"yolo-image.h\"\n\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <ctime>\n#include <string>\n#include <vector>\n#include <algorithm>\n#include <fstream>\n#include <algorithm>\n#include <thread>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\nstruct conv2d_layer {\n    struct ggml_tensor * weights;\n    struct ggml_tensor * biases;\n    struct ggml_tensor * scales;\n    struct ggml_tensor * rolling_mean;\n    struct ggml_tensor * rolling_variance;\n    int padding = 1;\n    bool batch_normalize = true;\n    bool activate = true; // true for leaky relu, false for linear\n};\n\nstruct yolo_model {\n    int width = 416;\n    int height = 416;\n    std::vector<conv2d_layer> conv2d_layers;\n    ggml_backend_t backend;\n    ggml_backend_buffer_t buffer;\n    struct ggml_context * ctx;\n};\n\nstruct yolo_layer {\n    int classes = 80;\n    std::vector<int> mask;\n    std::vector<float> anchors;\n    std::vector<float> predictions;\n    int w;\n    int h;\n\n    yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * prev_layer)\n        : classes(classes), mask(mask), anchors(anchors)\n    {\n        w = prev_layer->ne[0];\n        h = prev_layer->ne[1];\n        predictions.resize(ggml_nbytes(prev_layer)/sizeof(float));\n        ggml_backend_tensor_get(prev_layer, predictions.data(), 0, ggml_nbytes(prev_layer));\n    }\n\n    int entry_index(int location, int entry) const {\n        int n = location / (w*h);\n        int loc = location % (w*h);\n        return n*w*h*(4+classes+1) + entry*w*h + loc;\n    }\n};\n\nstruct box {\n    float x, y, w, h;\n};\n\nstruct detection {\n    box bbox;\n    std::vector<float> prob;\n    float objectness;\n};\n\nstatic bool load_model(const std::string & fname, yolo_model & model) {\n    
struct ggml_context * tmp_ctx = nullptr;\n    struct gguf_init_params gguf_params = {\n        /*.no_alloc   =*/ false,\n        /*.ctx        =*/ &tmp_ctx,\n    };\n    gguf_context * gguf_ctx = gguf_init_from_file(fname.c_str(), gguf_params);\n    if (!gguf_ctx) {\n        fprintf(stderr, \"%s: gguf_init_from_file() failed\\n\", __func__);\n        return false;\n    }\n\n    int num_tensors = gguf_get_n_tensors(gguf_ctx);\n    struct ggml_init_params params {\n            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n    };\n    model.ctx = ggml_init(params);\n    for (int i = 0; i < num_tensors; i++) {\n        const char * name = gguf_get_tensor_name(gguf_ctx, i);\n        struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, name);\n        struct ggml_tensor * dst = ggml_dup_tensor(model.ctx, src);\n        ggml_set_name(dst, name);\n    }\n    model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, model.backend);\n    // copy tensors from main memory to backend\n    for (struct ggml_tensor * cur = ggml_get_first_tensor(model.ctx); cur != NULL; cur = ggml_get_next_tensor(model.ctx, cur)) {\n        struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, ggml_get_name(cur));\n        size_t n_size = ggml_nbytes(src);\n        ggml_backend_tensor_set(cur, ggml_get_data(src), 0, n_size);\n    }\n    gguf_free(gguf_ctx);\n    ggml_free(tmp_ctx);\n    \n    model.width  = 416;\n    model.height = 416;\n    model.conv2d_layers.resize(13);\n    model.conv2d_layers[7].padding = 0;\n    model.conv2d_layers[9].padding = 0;\n    model.conv2d_layers[9].batch_normalize = false;\n    model.conv2d_layers[9].activate = false;\n    model.conv2d_layers[10].padding = 0;\n    model.conv2d_layers[12].padding = 0;\n    model.conv2d_layers[12].batch_normalize = false;\n    model.conv2d_layers[12].activate = false;\n    for (int i = 0; i < (int)model.conv2d_layers.size(); i++) {\n        char 
name[256];\n        snprintf(name, sizeof(name), \"l%d_weights\", i);\n        model.conv2d_layers[i].weights = ggml_get_tensor(model.ctx, name);\n        snprintf(name, sizeof(name), \"l%d_biases\", i);\n        model.conv2d_layers[i].biases = ggml_get_tensor(model.ctx, name);\n        if (model.conv2d_layers[i].batch_normalize) {\n            snprintf(name, sizeof(name), \"l%d_scales\", i);\n            model.conv2d_layers[i].scales = ggml_get_tensor(model.ctx, name);\n            snprintf(name, sizeof(name), \"l%d_rolling_mean\", i);\n            model.conv2d_layers[i].rolling_mean = ggml_get_tensor(model.ctx, name);\n            snprintf(name, sizeof(name), \"l%d_rolling_variance\", i);\n            model.conv2d_layers[i].rolling_variance = ggml_get_tensor(model.ctx, name);\n        }\n    }\n    return true;\n}\n\nstatic bool load_labels(const char * filename, std::vector<std::string> & labels)\n{\n    std::ifstream file_in(filename);\n    if (!file_in) {\n        return false;\n    }\n    std::string line;\n    while (std::getline(file_in, line)) {\n        labels.push_back(line);\n    }\n    GGML_ASSERT(labels.size() == 80);\n    return true;\n}\n\nstatic bool load_alphabet(std::vector<yolo_image> & alphabet)\n{\n    alphabet.resize(8 * 128);\n    for (int j = 0; j < 8; j++) {\n        for (int i = 32; i < 127; i++) {\n            char fname[256];\n            snprintf(fname, sizeof(fname), \"data/labels/%d_%d.png\", i, j);\n            if (!load_image(fname, alphabet[j*128 + i])) {\n                fprintf(stderr, \"Cannot load '%s'\\n\", fname);\n                return false;\n            }\n        }\n    }\n    return true;\n}\n\nstatic ggml_tensor * apply_conv2d(ggml_context * ctx, ggml_tensor * input, const conv2d_layer & layer)\n{\n    struct ggml_tensor * result = ggml_conv_2d(ctx, layer.weights, input, 1, 1, layer.padding, layer.padding, 1, 1);\n    if (layer.batch_normalize) {\n        result = ggml_sub(ctx, result, ggml_repeat(ctx, 
layer.rolling_mean, result));\n        result = ggml_div(ctx, result, ggml_sqrt(ctx, ggml_repeat(ctx, layer.rolling_variance, result)));\n        result = ggml_mul(ctx, result, ggml_repeat(ctx, layer.scales, result));\n    }\n    result = ggml_add(ctx, result, ggml_repeat(ctx, layer.biases, result));\n    if (layer.activate) {\n        result = ggml_leaky_relu(ctx, result, 0.1f, true);\n    }\n    return result;\n}\n\nstatic void activate_array(float * x, const int n)\n{\n    // logistic activation\n    for (int i = 0; i < n; i++) {\n        x[i] = 1./(1. + exp(-x[i]));\n    }\n}\n\nstatic void apply_yolo(yolo_layer & layer)\n{\n    int w = layer.w;\n    int h = layer.h;\n    int N = layer.mask.size();\n    float * data = layer.predictions.data();\n    for (int n = 0; n < N; n++) {\n        int index = layer.entry_index(n*w*h, 0);\n        activate_array(data + index, 2*w*h);\n        index = layer.entry_index(n*w*h, 4);\n        activate_array(data + index, (1+layer.classes)*w*h);\n    }\n}\n\nstatic box get_yolo_box(const yolo_layer & layer, int n, int index, int i, int j, int lw, int lh, int w, int h, int stride)\n{\n    const float * predictions = layer.predictions.data();\n    box b;\n    b.x = (i + predictions[index + 0*stride]) / lw;\n    b.y = (j + predictions[index + 1*stride]) / lh;\n    b.w = exp(predictions[index + 2*stride]) * layer.anchors[2*n]   / w;\n    b.h = exp(predictions[index + 3*stride]) * layer.anchors[2*n+1] / h;\n    return b;\n}\n\nstatic void correct_yolo_box(box & b, int im_w, int im_h, int net_w, int net_h)\n{\n    int new_w = 0;\n    int new_h = 0;\n    if (((float)net_w/im_w) < ((float)net_h/im_h)) {\n        new_w = net_w;\n        new_h = (im_h * net_w)/im_w;\n    } else {\n        new_h = net_h;\n        new_w = (im_w * net_h)/im_h;\n    }\n    b.x = (b.x - (net_w - new_w)/2./net_w) / ((float)new_w/net_w);\n    b.y = (b.y - (net_h - new_h)/2./net_h) / ((float)new_h/net_h);\n    b.w *= (float)net_w/new_w;\n    b.h *= 
(float)net_h/new_h;\n}\n\nstatic void get_yolo_detections(const yolo_layer & layer, std::vector<detection> & detections, int im_w, int im_h, int netw, int neth, float thresh)\n{\n    int w = layer.w;\n    int h = layer.h;\n    int N = layer.mask.size();\n    const float * predictions = layer.predictions.data();\n    std::vector<detection> result;\n    for (int i = 0; i < w*h; i++) {\n        for (int n = 0; n < N; n++) {\n            int obj_index = layer.entry_index(n*w*h + i, 4);\n            float objectness = predictions[obj_index];\n            if (objectness <= thresh) {\n                continue;\n            }\n            detection det;\n            int box_index = layer.entry_index(n*w*h + i, 0);\n            int row = i / w;\n            int col = i % w;\n            det.bbox = get_yolo_box(layer, layer.mask[n], box_index, col, row, w, h, netw, neth, w*h);\n            correct_yolo_box(det.bbox, im_w, im_h, netw, neth);\n            det.objectness = objectness;\n            det.prob.resize(layer.classes);\n            for (int j = 0; j < layer.classes; j++) {\n                int class_index = layer.entry_index(n*w*h + i, 4 + 1 + j);\n                float prob = objectness*predictions[class_index];\n                det.prob[j] = (prob > thresh) ? prob : 0;\n            }\n            detections.push_back(det);\n        }\n    }\n}\n\nstatic float overlap(float x1, float w1, float x2, float w2)\n{\n    float l1 = x1 - w1/2;\n    float l2 = x2 - w2/2;\n    float left = l1 > l2 ? l1 : l2;\n    float r1 = x1 + w1/2;\n    float r2 = x2 + w2/2;\n    float right = r1 < r2 ? 
r1 : r2;\n    return right - left;\n}\n\nstatic float box_intersection(const box & a, const box & b)\n{\n    float w = overlap(a.x, a.w, b.x, b.w);\n    float h = overlap(a.y, a.h, b.y, b.h);\n    if (w < 0 || h < 0) return 0;\n    float area = w*h;\n    return area;\n}\n\nstatic float box_union(const box & a, const box & b)\n{\n    float i = box_intersection(a, b);\n    float u = a.w*a.h + b.w*b.h - i;\n    return u;\n}\n\nstatic float box_iou(const box & a, const box & b)\n{\n    return box_intersection(a, b)/box_union(a, b);\n}\n\nstatic void do_nms_sort(std::vector<detection> & dets, int classes, float thresh)\n{\n    int k = (int)dets.size()-1;\n    for (int i = 0; i <= k; ++i) {\n        if (dets[i].objectness == 0) {\n            std::swap(dets[i], dets[k]);\n            --k;\n            --i;\n        }\n    }\n    int total = k+1;\n    for (int k = 0; k < classes; ++k) {\n        std::sort(dets.begin(), dets.begin()+total, [=](const detection & a, const detection & b) {\n            return a.prob[k] > b.prob[k];\n        });\n        for (int i = 0; i < total; ++i) {\n            if (dets[i].prob[k] == 0) {\n                continue;\n            }\n            box a = dets[i].bbox;\n            for (int j = i+1; j < total; ++j){\n                box b = dets[j].bbox;\n                if (box_iou(a, b) > thresh) {\n                    dets[j].prob[k] = 0;\n                }\n            }\n        }\n    }\n}\n\nstatic float get_color(int c, int x, int max)\n{\n    float colors[6][3] = { {1,0,1}, {0,0,1}, {0,1,1}, {0,1,0}, {1,1,0}, {1,0,0} };\n    float ratio = ((float)x/max)*5;\n    int i = floor(ratio);\n    int j = ceil(ratio);\n    ratio -= i;\n    float r = (1-ratio) * colors[i][c] + ratio*colors[j][c];\n    return r;\n}\n\nstatic void draw_detections(yolo_image & im, const std::vector<detection> & dets, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)\n{\n    int classes = (int)labels.size();\n    for 
(int i = 0; i < (int)dets.size(); i++) {\n        std::string labelstr;\n        int cl = -1;\n        for (int j = 0; j < (int)dets[i].prob.size(); j++) {\n            if (dets[i].prob[j] > thresh) {\n                if (cl < 0) {\n                    labelstr = labels[j];\n                    cl = j;\n                } else {\n                    labelstr += \", \";\n                    labelstr += labels[j];\n                }\n                printf(\"%s: %.0f%%\\n\", labels[j].c_str(), dets[i].prob[j]*100);\n            }\n        }\n        if (cl >= 0) {\n            int width = im.h * .006;\n            int offset = cl*123457 % classes;\n            float red = get_color(2,offset,classes);\n            float green = get_color(1,offset,classes);\n            float blue = get_color(0,offset,classes);\n            float rgb[3];\n\n            rgb[0] = red;\n            rgb[1] = green;\n            rgb[2] = blue;\n            box b = dets[i].bbox;\n\n            int left  = (b.x-b.w/2.)*im.w;\n            int right = (b.x+b.w/2.)*im.w;\n            int top   = (b.y-b.h/2.)*im.h;\n            int bot   = (b.y+b.h/2.)*im.h;\n\n            if (left < 0) left = 0;\n            if (right > im.w-1) right = im.w-1;\n            if (top < 0) top = 0;\n            if (bot > im.h-1) bot = im.h-1;\n\n            draw_box_width(im, left, top, right, bot, width, red, green, blue);\n            yolo_image label = get_label(alphabet, labelstr, (im.h*.03));\n            draw_label(im, top + width, left, label, rgb);\n        }\n    }\n}\n\nstatic void print_shape(int layer, const ggml_tensor * t)\n{\n    printf(\"Layer %2d output shape:  %3d x %3d x %4d x %3d\\n\", layer, (int)t->ne[0], (int)t->ne[1], (int)t->ne[2], (int)t->ne[3]);\n}\n\nstatic struct ggml_cgraph * build_graph(struct ggml_context * ctx_cgraph, const yolo_model & model) {\n    struct ggml_cgraph * gf = ggml_new_graph(ctx_cgraph);\n\n    struct ggml_tensor * input = ggml_new_tensor_4d(ctx_cgraph, GGML_TYPE_F32, 
model.width, model.height, 3, 1);\n    ggml_set_name(input, \"input\");\n    struct ggml_tensor * result = apply_conv2d(ctx_cgraph, input, model.conv2d_layers[0]);\n    print_shape(0, result);\n    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);\n    print_shape(1, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[1]);\n    print_shape(2, result);\n    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);\n    print_shape(3, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[2]);\n    print_shape(4, result);\n    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);\n    print_shape(5, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[3]);\n    print_shape(6, result);\n    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);\n    print_shape(7, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[4]);\n    struct ggml_tensor * layer_8 = result;\n    print_shape(8, result);\n    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);\n    print_shape(9, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[5]);\n    print_shape(10, result);\n    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);\n    print_shape(11, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[6]);\n    print_shape(12, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[7]);\n    struct ggml_tensor * layer_13 = result;\n    print_shape(13, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[8]);\n    print_shape(14, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[9]);\n    struct ggml_tensor * layer_15 = result;\n    ggml_set_output(layer_15);\n    ggml_set_name(layer_15, \"layer_15\");\n\n    print_shape(15, 
result);\n    result = apply_conv2d(ctx_cgraph, layer_13, model.conv2d_layers[10]);\n    print_shape(18, result);\n    result = ggml_upscale(ctx_cgraph, result, 2, GGML_SCALE_MODE_NEAREST);\n    print_shape(19, result);\n    result = ggml_concat(ctx_cgraph, result, layer_8, 2);\n    print_shape(20, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[11]);\n    print_shape(21, result);\n    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[12]);\n    struct ggml_tensor * layer_22 = result;\n    ggml_set_output(layer_22);\n    ggml_set_name(layer_22, \"layer_22\");\n    print_shape(22, result);\n\n    ggml_build_forward_expand(gf, layer_15);\n    ggml_build_forward_expand(gf, layer_22);\n    return gf;\n}\n\nvoid detect(yolo_image & img, struct ggml_cgraph * gf, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)\n{\n    std::vector<detection> detections;\n    yolo_image sized = letterbox_image(img, model.width, model.height);\n    struct ggml_tensor * input = ggml_graph_get_tensor(gf, \"input\");\n    ggml_backend_tensor_set(input, sized.data.data(), 0, ggml_nbytes(input));\n\n    if (ggml_backend_graph_compute(model.backend, gf) != GGML_STATUS_SUCCESS) {\n        fprintf(stderr, \"%s: ggml_backend_graph_compute() failed\\n\", __func__);\n        return;\n    }\n\n    struct ggml_tensor * layer_15 = ggml_graph_get_tensor(gf, \"layer_15\");\n    yolo_layer yolo16{ 80, {3, 4, 5}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_15};\n    apply_yolo(yolo16);\n    get_yolo_detections(yolo16, detections, img.w, img.h, model.width, model.height, thresh);\n\n    struct ggml_tensor * layer_22 = ggml_graph_get_tensor(gf, \"layer_22\");\n    yolo_layer yolo23{ 80, {0, 1, 2}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_22};\n    apply_yolo(yolo23);\n    get_yolo_detections(yolo23, detections, img.w, img.h, model.width, model.height, thresh);\n\n   
 do_nms_sort(detections, yolo23.classes, .45);\n    draw_detections(img, detections, thresh, labels, alphabet);\n}\n\nstruct yolo_params {\n    float thresh          = 0.5;\n    std::string model     = \"yolov3-tiny.gguf\";\n    std::string fname_inp = \"input.jpg\";\n    std::string fname_out = \"predictions.jpg\";\n    int         n_threads  = std::max(1U, std::thread::hardware_concurrency()/2);\n    std::string device;\n};\n\nvoid yolo_print_usage(int argc, char ** argv, const yolo_params & params) {\n    fprintf(stderr, \"usage: %s [options]\\n\", argv[0]);\n    fprintf(stderr, \"\\n\");\n    fprintf(stderr, \"options:\\n\");\n    fprintf(stderr, \"  -h,  --help                show this help message and exit\\n\");\n    fprintf(stderr, \"  -d,  --device DEV          device to use\\n\");\n    fprintf(stderr, \"  -t,  --threads N           number of threads for the CPU backend (default: %d)\\n\", params.n_threads);\n    fprintf(stderr, \"  -th, --thresh T            detection threshold (default: %.2f)\\n\", params.thresh);\n    fprintf(stderr, \"  -m,  --model FNAME         model path (default: %s)\\n\", params.model.c_str());\n    fprintf(stderr, \"  -i,  --inp FNAME           input file (default: %s)\\n\", params.fname_inp.c_str());\n    fprintf(stderr, \"  -o,  --out FNAME           output file (default: %s)\\n\", params.fname_out.c_str());\n    fprintf(stderr, \"\\n\");\n}\n\nbool yolo_params_parse(int argc, char ** argv, yolo_params & params) {\n    for (int i = 1; i < argc; i++) {\n        std::string arg = argv[i];\n        if (arg == \"-th\" || arg == \"--thresh\") {\n            params.thresh = std::stof(argv[++i]);\n            if (params.thresh < 0 || params.thresh > 1) {\n                fprintf(stderr, \"error: invalid threshold: %.2f\\n\", params.thresh);\n                return false;\n            }\n        } else if (arg == \"-m\" || arg == \"--model\") {\n            params.model = argv[++i];\n        } else if (arg == \"-i\" || arg == 
\"--inp\") {\n            params.fname_inp = argv[++i];\n        } else if (arg == \"-o\" || arg == \"--out\") {\n            params.fname_out = argv[++i];\n        } else if (arg == \"-t\" || arg == \"--threads\") {\n            if (++i >= argc) {\n                return false;\n            }\n            params.n_threads = std::stoi(argv[i]);\n            if (params.n_threads <= 0) {\n                fprintf(stderr, \"error: invalid number of threads: %d\\n\", params.n_threads);\n                return false;\n            }\n        } else if (arg == \"-d\" || arg == \"--device\") {\n            if (++i >= argc) {\n                return false;\n            }\n            params.device = argv[i];\n            if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) {\n                fprintf(stderr, \"error: unknown device: %s\\n\", params.device.c_str());\n                fprintf(stderr, \"available devices:\\n\");\n                for (size_t i = 0; i < ggml_backend_dev_count(); i++) {\n                    auto * dev = ggml_backend_dev_get(i);\n                    size_t free, total;\n                    ggml_backend_dev_memory(dev, &free, &total);\n                    printf(\"  %s: %s (%zu MiB, %zu MiB free)\\n\", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);\n                }\n                return false;\n            }\n        } else if (arg == \"-h\" || arg == \"--help\") {\n            yolo_print_usage(argc, argv, params);\n            exit(0);\n        } else {\n            fprintf(stderr, \"error: unknown argument: %s\\n\", arg.c_str());\n            yolo_print_usage(argc, argv, params);\n            exit(0);\n        }\n    }\n\n    return true;\n}\n\nstatic ggml_backend_t create_backend(const yolo_params & params) {\n    ggml_backend_t backend = nullptr;\n\n    if (!params.device.empty()) {\n        ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str());\n        if 
(dev) {\n            backend = ggml_backend_dev_init(dev, nullptr);\n            if (!backend) {\n                fprintf(stderr, \"Failed to create backend for device %s\\n\", params.device.c_str());\n                return nullptr;\n            }\n        }\n    }\n\n    // try to initialize a GPU backend first\n    if (!backend) {\n        backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr);\n    }\n\n    // if there aren't GPU backends fallback to CPU backend\n    if (!backend) {\n        backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);\n    }\n\n    if (backend) {\n        fprintf(stderr, \"%s: using %s backend\\n\", __func__, ggml_backend_name(backend));\n\n        // set the number of threads\n        ggml_backend_dev_t dev = ggml_backend_get_device(backend);\n        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;\n        if (reg) {\n            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, \"ggml_backend_set_n_threads\");\n            if (ggml_backend_set_n_threads_fn) {\n                ggml_backend_set_n_threads_fn(backend, params.n_threads);\n            }\n        }\n    }\n\n    return backend;\n}\n\nint main(int argc, char *argv[])\n{\n    ggml_backend_load_all();\n    ggml_time_init();\n    yolo_model model;\n\n    yolo_params params;\n    if (!yolo_params_parse(argc, argv, params)) {\n        return 1;\n    }\n    model.backend = create_backend(params);\n    if (!model.backend) {\n        fprintf(stderr, \"Failed to create backend\\n\");\n        return 1;\n    }\n\n    if (!load_model(params.model, model)) {\n        fprintf(stderr, \"%s: failed to load model from '%s'\\n\", __func__, params.model.c_str());\n        return 1;\n    }\n    yolo_image img(0,0,0);\n    if (!load_image(params.fname_inp.c_str(), img)) {\n        fprintf(stderr, \"%s: failed to load image from '%s'\\n\", __func__, 
params.fname_inp.c_str());\n        return 1;\n    }\n    std::vector<std::string> labels;\n    if (!load_labels(\"data/coco.names\", labels)) {\n        fprintf(stderr, \"%s: failed to load labels from 'data/coco.names'\\n\", __func__);\n        return 1;\n    }\n    std::vector<yolo_image> alphabet;\n    if (!load_alphabet(alphabet)) {\n        fprintf(stderr, \"%s: failed to load alphabet\\n\", __func__);\n        return 1;\n    }\n\n    struct ggml_init_params params0 = {\n        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n    struct ggml_context * ctx_cgraph = ggml_init(params0);\n    struct ggml_cgraph * gf = build_graph(ctx_cgraph, model);\n\n    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n    ggml_gallocr_alloc_graph(allocr, gf);\n\n    const int64_t t_start_ms = ggml_time_ms();\n    detect(img, gf, model, params.thresh, labels, alphabet);\n    const int64_t t_detect_ms = ggml_time_ms() - t_start_ms;\n    if (!save_image(img, params.fname_out.c_str(), 80)) {\n        fprintf(stderr, \"%s: failed to save image to '%s'\\n\", __func__, params.fname_out.c_str());\n        return 1;\n    }\n    printf(\"Detected objects saved in '%s' (time: %f sec.)\\n\", params.fname_out.c_str(), t_detect_ms / 1000.0f);\n\n    ggml_free(ctx_cgraph);\n    ggml_gallocr_free(allocr);\n    ggml_free(model.ctx);\n    ggml_backend_buffer_free(model.buffer);\n    ggml_backend_free(model.backend);\n    return 0;\n}\n"
  },
  {
    "path": "ggml.pc.in",
    "content": "prefix=@CMAKE_INSTALL_PREFIX@\nexec_prefix=${prefix}\nincludedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@\nlibdir=${prefix}/@CMAKE_INSTALL_LIBDIR@\n\nName: ggml\nDescription: The GGML Tensor Library for Machine Learning\nVersion: @GGML_VERSION@\nCflags: -I${includedir}\nLibs: -L${libdir} -lggml\n"
  },
  {
    "path": "include/ggml-alloc.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\ntypedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;\ntypedef struct      ggml_backend_buffer * ggml_backend_buffer_t;\ntypedef struct             ggml_backend * ggml_backend_t;\n\n// Tensor allocator\nstruct ggml_tallocr {\n    ggml_backend_buffer_t buffer;\n    void * base;\n    size_t alignment;\n    size_t offset;\n};\n\nGGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);\nGGML_API enum ggml_status    ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);\n\n// Graph allocator\n/*\n  Example usage:\n    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());\n\n    // optional: create a worst-case graph and reserve the buffers to avoid reallocations\n    ggml_gallocr_reserve(galloc, build_graph(max_batch));\n\n    // allocate the graph\n    struct ggml_cgraph * graph = build_graph(batch);\n    ggml_gallocr_alloc_graph(galloc, graph);\n\n    printf(\"compute buffer size: %zu bytes\\n\", ggml_gallocr_get_buffer_size(galloc, 0));\n\n    // evaluate the graph\n    ggml_backend_graph_compute(backend, graph);\n*/\n\n// special tensor flags for use with the graph allocator:\n//   ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses\n//   ggml_set_output(): output tensors are never freed and never overwritten\n\ntypedef struct ggml_gallocr * ggml_gallocr_t;\n\nGGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);\nGGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);\nGGML_API void           ggml_gallocr_free(ggml_gallocr_t galloc);\n\n// pre-allocate buffers from a measure graph - does not allocate or modify the graph\n// call with a worst-case graph to avoid buffer reallocations\n// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate 
the buffers automatically if needed\n// returns false if the buffer allocation failed\n// ggml_gallocr_resrve_n_size writes the buffer sizes per galloc buffer that would be allocated by ggml_gallocr_reserve_n to sizes\nGGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);\nGGML_API void ggml_gallocr_reserve_n_size(\n    ggml_gallocr_t galloc,\n    struct ggml_cgraph * graph,\n    const int * node_buffer_ids,\n    const int * leaf_buffer_ids,\n    size_t * sizes);\nGGML_API bool ggml_gallocr_reserve_n(\n    ggml_gallocr_t galloc,\n    struct ggml_cgraph * graph,\n    const int * node_buffer_ids,\n    const int * leaf_buffer_ids);\n\n// automatic reallocation if the topology changes when using a single buffer\n// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)\nGGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);\n\nGGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);\n\n// Utils\n// Create a buffer and allocate all the tensors in a ggml_context\n// ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft\nGGML_API size_t                       ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);\nGGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);\nGGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-backend.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-alloc.h\"\n\n#ifdef GGML_BACKEND_SHARED\n#    if defined(_WIN32) && !defined(__MINGW32__)\n#        ifdef GGML_BACKEND_BUILD\n#            define GGML_BACKEND_API __declspec(dllexport) extern\n#        else\n#            define GGML_BACKEND_API __declspec(dllimport) extern\n#        endif\n#    else\n#        define GGML_BACKEND_API __attribute__ ((visibility (\"default\"))) extern\n#    endif\n#else\n#    define GGML_BACKEND_API extern\n#endif\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;\n    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;\n    typedef struct ggml_backend_event * ggml_backend_event_t;\n    typedef struct ggml_backend * ggml_backend_t;\n    typedef void * ggml_backend_graph_plan_t;\n    typedef struct ggml_backend_reg * ggml_backend_reg_t;\n    typedef struct ggml_backend_device * ggml_backend_dev_t;\n\n\n    //\n    // Backend buffer type\n    //\n\n    GGML_API const char *          ggml_backend_buft_name          (ggml_backend_buffer_type_t buft);\n    GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer  (ggml_backend_buffer_type_t buft, size_t size);\n    GGML_API size_t                ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);\n    GGML_API size_t                ggml_backend_buft_get_max_size  (ggml_backend_buffer_type_t buft);\n    GGML_API size_t                ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);\n    GGML_API bool                  ggml_backend_buft_is_host       (ggml_backend_buffer_type_t buft);\n    GGML_API ggml_backend_dev_t    ggml_backend_buft_get_device    (ggml_backend_buffer_type_t buft);\n\n    //\n    // Backend buffer\n    //\n\n    enum ggml_backend_buffer_usage {\n        GGML_BACKEND_BUFFER_USAGE_ANY = 0,\n        GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,\n        
GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,\n    };\n\n    GGML_API const char *                   ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);\n    GGML_API void                           ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);\n    GGML_API void *                         ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);\n    GGML_API size_t                         ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);\n    GGML_API enum ggml_status               ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);\n    GGML_API size_t                         ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);\n    GGML_API size_t                         ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);\n    GGML_API size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);\n    GGML_API void                           ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);\n    GGML_API bool                           ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);\n    GGML_API void                           ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);\n    GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage     (ggml_backend_buffer_t buffer);\n    GGML_API ggml_backend_buffer_type_t     ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);\n    GGML_API void                           ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);\n\n    // tensor copy between different backends\n    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);\n\n    //\n    // Backend (stream)\n    //\n\n    GGML_API ggml_guid_t  ggml_backend_guid(ggml_backend_t backend);\n    GGML_API 
const char * ggml_backend_name(ggml_backend_t backend);\n    GGML_API void         ggml_backend_free(ggml_backend_t backend);\n\n    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);\n    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);\n    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);\n    GGML_API size_t                     ggml_backend_get_max_size(ggml_backend_t backend);\n\n    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);\n    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);\n\n    // \"offset\" refers to the offset in tensor->data for setting/getting data\n    GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);\n    GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);\n    GGML_API void ggml_backend_tensor_memset(   struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);\n\n    GGML_API void ggml_backend_synchronize(ggml_backend_t backend);\n\n    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);\n    GGML_API void                      ggml_backend_graph_plan_free  (ggml_backend_t backend, ggml_backend_graph_plan_t plan);\n\n    GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);\n    GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);\n    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);\n\n    // NOTE: 
will be removed, use device version instead\n    GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);\n    GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);\n    GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);\n\n    // asynchronous copy\n    // the copy is performed after all the currently queued operations in backend_src\n    // backend_dst will wait for the copy to complete before performing other operations\n    // automatic fallback to sync copy if async is not supported\n    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);\n\n    GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);\n\n    //\n    // Events\n    //\n\n    GGML_API ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device);\n    GGML_API void                 ggml_backend_event_free(ggml_backend_event_t event);\n    GGML_API void                 ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend);\n    GGML_API void                 ggml_backend_event_synchronize(ggml_backend_event_t event);\n    GGML_API void                 ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event);\n\n    //\n    // Backend device\n    //\n\n    enum ggml_backend_dev_type {\n        // CPU device using system memory\n        GGML_BACKEND_DEVICE_TYPE_CPU,\n        // GPU device using dedicated memory\n        GGML_BACKEND_DEVICE_TYPE_GPU,\n        // integrated GPU device using host memory\n        GGML_BACKEND_DEVICE_TYPE_IGPU,\n        // accelerator devices intended to be used together with the CPU backend (e.g. 
BLAS or AMX)\n        GGML_BACKEND_DEVICE_TYPE_ACCEL\n    };\n\n    // functionality supported by the device\n    struct ggml_backend_dev_caps {\n        // asynchronous operations\n        bool async;\n        // pinned host buffer\n        bool host_buffer;\n        // creating buffers from host ptr\n        bool buffer_from_host_ptr;\n        // event synchronization\n        bool events;\n    };\n\n    // all the device properties\n    struct ggml_backend_dev_props {\n        // device name\n        const char * name;\n        // device description\n        const char * description;\n        // device free memory in bytes\n        size_t memory_free;\n        // device total memory in bytes\n        size_t memory_total;\n        // device type\n        enum ggml_backend_dev_type type;\n        // device id\n        //   for PCI devices, this should be the PCI bus id formatted as \"domain:bus:device.function\" (e.g. \"0000:01:00.0\")\n        //   if the id is unknown, this should be NULL\n        const char * device_id;\n        // device capabilities\n        struct ggml_backend_dev_caps caps;\n    };\n\n    GGML_API const char *                  ggml_backend_dev_name(ggml_backend_dev_t device);\n    GGML_API const char *                  ggml_backend_dev_description(ggml_backend_dev_t device);\n    GGML_API void                          ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total);\n    GGML_API enum ggml_backend_dev_type    ggml_backend_dev_type(ggml_backend_dev_t device);\n    GGML_API void                          ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);\n    GGML_API ggml_backend_reg_t            ggml_backend_dev_backend_reg(ggml_backend_dev_t device);\n    GGML_API ggml_backend_t                ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);\n    GGML_API ggml_backend_buffer_type_t    ggml_backend_dev_buffer_type(ggml_backend_dev_t device);\n  
  GGML_API ggml_backend_buffer_type_t    ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);\n    GGML_API ggml_backend_buffer_t         ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);\n\n    GGML_API bool                          ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op);\n    GGML_API bool                          ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft);\n    GGML_API bool                          ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op);\n\n    //\n    // Backend (reg)\n    //\n\n    GGML_API const char *       ggml_backend_reg_name(ggml_backend_reg_t reg);\n    GGML_API size_t             ggml_backend_reg_dev_count(ggml_backend_reg_t reg);\n    GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index);\n    GGML_API void *             ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name);\n\n    // Common functions that may be obtained using ggml_backend_reg_get_proc_address\n\n    // Split buffer type for tensor parallelism\n    typedef ggml_backend_buffer_type_t   (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);\n    // Set the number of threads for the backend\n    typedef void                         (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);\n    // Get additional buffer types provided by the device (returns a NULL-terminated array)\n    typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);\n    // Set the abort callback for the backend\n    typedef void                         (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data);\n    // Get a list of feature flags supported by the backend (returns a 
NULL-terminated array)\n    struct ggml_backend_feature {\n        const char * name;\n        const char * value;\n    };\n    typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg);\n\n    //\n    // Backend registry\n    //\n\n    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);\n\n    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);\n\n    // Backend (reg) enumeration\n    GGML_API size_t             ggml_backend_reg_count(void);\n    GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);\n    GGML_API ggml_backend_reg_t ggml_backend_reg_by_name(const char * name);\n\n    // Device enumeration\n    GGML_API size_t             ggml_backend_dev_count(void);\n    GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index);\n    GGML_API ggml_backend_dev_t ggml_backend_dev_by_name(const char * name);\n    GGML_API ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type);\n\n    // Direct backend (stream) initialization\n    // = ggml_backend_dev_init(ggml_backend_dev_by_name(name), params)\n    GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params);\n    // = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params)\n    GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params);\n    // = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)\n    GGML_API ggml_backend_t ggml_backend_init_best(void);\n\n    // Load a backend from a dynamic library and register it\n    GGML_API ggml_backend_reg_t ggml_backend_load(const char * path);\n    // Unload a backend if loaded dynamically and unregister it\n    GGML_API void               ggml_backend_unload(ggml_backend_reg_t reg);\n    // Load all known backends from dynamic libraries\n    GGML_API void               ggml_backend_load_all(void);\n    GGML_API void               
ggml_backend_load_all_from_path(const char * dir_path);\n\n    //\n    // Backend scheduler\n    //\n\n    // The backend scheduler allows for multiple backend devices to be used together\n    // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends\n    // The backends are selected based on:\n    // - the backend that supports the operation\n    // - the location of the pre-allocated tensors (e.g. the weights)\n    /*\n      Example usage:\n\n        // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned\n        // preferably to run on the same backend as the buffer\n        ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);\n\n        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true);\n\n        // initialize buffers from a max size graph (optional)\n        reserve_graph = build_graph(sched, max_batch_size);\n\n        // manually assign nodes to a backend (optional, should not be needed in most cases)\n        struct ggml_tensor * node = ggml_mul_mat(ctx, ...);\n        ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);\n\n        ggml_backend_sched_reserve(sched, reserve_graph);\n\n        // compute\n        graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation\n        for (int i = 0; i < 10; ++i) {\n            ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically\n        }\n\n        // if there are graph inputs:\n        graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called)\n        ggml_backend_sched_reset(sched); // clear the allocation of the previous graph\n        ggml_backend_sched_alloc_graph(sched, graph); // explicitly 
allocate the new graph but do not execute it\n        ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors\n        ggml_backend_sched_graph_compute(sched, graph); // execute the graph\n\n        // as an alternative to the above it is also possible to assign the inputs to a dedicated context and\n        // allocate them statically via ggml_backend_alloc_ctx_tensors\n    }\n    */\n\n    typedef struct ggml_backend_sched * ggml_backend_sched_t;\n\n    // Evaluation callback for each node in the graph (set with ggml_backend_sched_set_eval_callback)\n    // when ask == true, the scheduler wants to know if the user wants to observe this node\n    // this allows the scheduler to batch nodes together in order to evaluate them in a single call\n    //\n    // when ask == false, the scheduler is passing the node tensor to the user for observation\n    // if the user returns false, the scheduler will cancel the graph compute\n    //\n    typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);\n\n    // Initialize a backend scheduler, backends with low index are given priority over backends with high index\n    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);\n    GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);\n\n    // Initialize backend buffers from a measure graph\n    GGML_API void                 ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes);\n    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success\n\n    GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);\n    GGML_API ggml_backend_t       
ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);\n\n    // Get the number of splits of the last graph\n    GGML_API int                  ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);\n    GGML_API int                  ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);\n\n    GGML_API ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend);\n    GGML_API size_t                     ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);\n\n    GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);\n    GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);\n\n    // Split graph without allocating it\n    GGML_API void                 ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);\n\n    // Allocate and compute graph on the backend scheduler\n    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success\n    GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);\n    GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);\n    GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);\n\n    // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph.\n    // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers.\n    // The correct way to use this API is to discard the deallocated tensors and create new ones.\n    GGML_API void                 
ggml_backend_sched_reset(ggml_backend_sched_t sched);\n\n    // Set a callback to be called for each resulting node during graph compute\n    GGML_API void                 ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);\n\n    //\n    // Utils\n    //\n\n    struct ggml_backend_graph_copy {\n        ggml_backend_buffer_t buffer;\n        struct ggml_context * ctx_allocated;\n        struct ggml_context * ctx_unallocated;\n        struct ggml_cgraph * graph;\n    };\n\n    // Copy a graph to a different backend\n    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);\n    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);\n\n    typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);\n\n    // Compare the output of two backends\n    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);\n\n    // Tensor initialization\n    GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);\n    GGML_API enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor);\n\n    // CPU buffer types are always available\n    GGML_API ggml_backend_buffer_t      ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);\n    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-blas.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);\n\nGGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);\n\n// number of threads used for conversion to float\n// for openblas and blis, this will also set the number of threads used for blas operations\nGGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);\n\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-cann.h",
    "content": "/*\n * Copyright (c) 2023-2026 The ggml authors\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\n#pragma once\n\n#include \"ggml-backend.h\"\n#include \"ggml.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/**\n * @brief Maximum number of CANN devices supported.\n */\n#define GGML_CANN_MAX_DEVICES 16\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void);\n\n/**\n * @brief Initializes the CANN backend for a specified device.\n *\n * This function initializes the CANN backend for the given device.\n * It verifies the device index, allocates a context, and creates a backend\n * instance.\n *\n * @param device The index of the device to initialize.\n * @return A pointer to the initialized backend instance, or nullptr on failure.\n */\nGGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device);\n\n/**\n * @brief Checks if a given backend is a CANN backend.\n *\n * This function verifies if 
the provided backend is a CANN backend by comparing\n * its GUID with the CANN backend's GUID.\n *\n * @param backend The backend instance to check.\n * @return True if the backend is a CANN backend, false otherwise.\n */\nGGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend);\n\n/**\n * @brief Retrieves the CANN buffer type for a specified device.\n *\n * This function initializes and returns the buffer type interface associated\n * with the given device. It ensures thread-safe access using a mutex.\n *\n * @param device The device index for which to retrieve the buffer type.\n * @return A pointer to the buffer type interface for the specified device, or\n * nullptr if the device index is out of range.\n */\nGGML_BACKEND_API ggml_backend_buffer_type_t\nggml_backend_cann_buffer_type(int32_t device);\n\n/**\n * @brief Retrieves the number of CANN devices available.\n *\n * This function returns the number of CANN devices available based on\n * information obtained from `ggml_cann_info()`.\n *\n * @return The number of CANN devices available.\n */\nGGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void);\n\n/**\n * @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU.\n *\n * @return A pointer to the host buffer type interface.\n */\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);\n\n/**\n * @brief Retrieves the description of a specific CANN device.\n *\n * This function sets the specified device, retrieves the SoC name,\n * and writes it into the provided description buffer.\n *\n * @param device The device index to retrieve the description for.\n * @param description Pointer to a buffer where the description will be written.\n * @param description_size Size of the description buffer.\n */\nGGML_BACKEND_API void ggml_backend_cann_get_device_description(\n    int32_t device, char* description, size_t description_size);\n\n/**\n * @brief Retrieves the memory 
information of a specific CANN device.\n *\n * This function sets the specified device, retrieves the free and total\n * memory information of the specified type (ACL_HBM_MEM), and stores them\n * in the provided pointers.\n *\n * @param device The device index to retrieve memory information for.\n * @param free Pointer to a variable where the free memory size will be stored.\n * @param total Pointer to a variable where the total memory size will be\n * stored.\n */\nGGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device,\n                                                  size_t* free,\n                                                  size_t* total);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-cpp.h",
    "content": "#pragma once\n\n#ifndef __cplusplus\n#error \"This header is for C++ only\"\n#endif\n\n#include \"ggml.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n#include \"gguf.h\"\n#include <memory>\n\n// Smart pointers for ggml types\n\n// ggml\n\nstruct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } };\nstruct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } };\n\ntypedef std::unique_ptr<ggml_context, ggml_context_deleter> ggml_context_ptr;\ntypedef std::unique_ptr<gguf_context, gguf_context_deleter> gguf_context_ptr;\n\n// ggml-alloc\n\nstruct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } };\n\ntypedef std::unique_ptr<ggml_gallocr, ggml_gallocr_deleter> ggml_gallocr_ptr;\n\n// ggml-backend\n\nstruct ggml_backend_deleter        { void operator()(ggml_backend_t backend)       { ggml_backend_free(backend); } };\nstruct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } };\nstruct ggml_backend_event_deleter  { void operator()(ggml_backend_event_t event)   { ggml_backend_event_free(event); } };\nstruct ggml_backend_sched_deleter  { void operator()(ggml_backend_sched_t sched)   { ggml_backend_sched_free(sched); } };\n\ntypedef std::unique_ptr<ggml_backend,        ggml_backend_deleter>        ggml_backend_ptr;\ntypedef std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter> ggml_backend_buffer_ptr;\ntypedef std::unique_ptr<ggml_backend_event,  ggml_backend_event_deleter>  ggml_backend_event_ptr;\ntypedef std::unique_ptr<ggml_backend_sched,  ggml_backend_sched_deleter>  ggml_backend_sched_ptr;\n"
  },
  {
    "path": "include/ggml-cpu.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n    // the compute plan that needs to be prepared for ggml_graph_compute()\n    // since https://github.com/ggml-org/ggml/issues/287\n    struct ggml_cplan {\n        size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`\n        uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`\n\n        int n_threads;\n        struct ggml_threadpool * threadpool;\n\n        // abort ggml_graph_compute when true\n        ggml_abort_callback abort_callback;\n        void *              abort_callback_data;\n\n        // use only reference implementations\n        bool use_ref;\n    };\n\n    // numa strategies\n    enum ggml_numa_strategy {\n        GGML_NUMA_STRATEGY_DISABLED   = 0,\n        GGML_NUMA_STRATEGY_DISTRIBUTE = 1,\n        GGML_NUMA_STRATEGY_ISOLATE    = 2,\n        GGML_NUMA_STRATEGY_NUMACTL    = 3,\n        GGML_NUMA_STRATEGY_MIRROR     = 4,\n        GGML_NUMA_STRATEGY_COUNT\n    };\n\n    GGML_BACKEND_API void    ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems\n    GGML_BACKEND_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node\n\n    GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);\n    GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);\n\n    GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);\n    GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);\n\n    GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);\n    GGML_BACKEND_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);\n\n    GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct 
ggml_tensor * tensor, int i0, int i1, int i2, int i3);\n    GGML_BACKEND_API void    ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);\n\n    GGML_BACKEND_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);\n    GGML_BACKEND_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);\n\n    GGML_BACKEND_API float   ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);\n    GGML_BACKEND_API void    ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);\n\n    GGML_BACKEND_API struct ggml_threadpool *      ggml_threadpool_new           (struct ggml_threadpool_params  * params);\n    GGML_BACKEND_API void                          ggml_threadpool_free          (struct ggml_threadpool * threadpool);\n    GGML_BACKEND_API int                           ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool);\n    GGML_BACKEND_API void                          ggml_threadpool_pause         (struct ggml_threadpool * threadpool);\n    GGML_BACKEND_API void                          ggml_threadpool_resume        (struct ggml_threadpool * threadpool);\n\n    // ggml_graph_plan() has to be called before ggml_graph_compute()\n    // when plan.work_size > 0, caller must allocate memory for plan.work_data\n    GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(\n                  const struct ggml_cgraph * cgraph,\n                                       int   n_threads, /* = GGML_DEFAULT_N_THREADS */\n                    struct ggml_threadpool * threadpool /* = NULL */ );\n    GGML_BACKEND_API enum ggml_status  ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);\n\n    // same as ggml_graph_compute() but the work data is allocated as a part of the context\n    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data\n    
GGML_BACKEND_API enum ggml_status  ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);\n\n    //\n    // system info\n    //\n\n    // x86\n    GGML_BACKEND_API int ggml_cpu_has_sse3       (void);\n    GGML_BACKEND_API int ggml_cpu_has_ssse3      (void);\n    GGML_BACKEND_API int ggml_cpu_has_avx        (void);\n    GGML_BACKEND_API int ggml_cpu_has_avx_vnni   (void);\n    GGML_BACKEND_API int ggml_cpu_has_avx2       (void);\n    GGML_BACKEND_API int ggml_cpu_has_bmi2       (void);\n    GGML_BACKEND_API int ggml_cpu_has_f16c       (void);\n    GGML_BACKEND_API int ggml_cpu_has_fma        (void);\n    GGML_BACKEND_API int ggml_cpu_has_avx512     (void);\n    GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);\n    GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);\n    GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);\n    GGML_BACKEND_API int ggml_cpu_has_amx_int8   (void);\n    // ARM\n    GGML_BACKEND_API int ggml_cpu_has_neon       (void);\n    GGML_BACKEND_API int ggml_cpu_has_arm_fma    (void);\n    GGML_BACKEND_API int ggml_cpu_has_fp16_va    (void);\n    GGML_BACKEND_API int ggml_cpu_has_dotprod    (void);\n    GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);\n    GGML_BACKEND_API int ggml_cpu_has_sve        (void);\n    GGML_BACKEND_API int ggml_cpu_get_sve_cnt    (void);  // sve vector length in bytes\n    GGML_BACKEND_API int ggml_cpu_has_sme        (void);\n    // other\n    GGML_BACKEND_API int ggml_cpu_has_riscv_v    (void);\n    GGML_BACKEND_API int ggml_cpu_get_rvv_vlen   (void);  // risc-v vector length in bytes\n    GGML_BACKEND_API int ggml_cpu_has_vsx        (void);\n    GGML_BACKEND_API int ggml_cpu_has_vxe        (void);\n    GGML_BACKEND_API int ggml_cpu_has_wasm_simd  (void);\n    GGML_BACKEND_API int ggml_cpu_has_llamafile  (void);\n\n    // Internal types and functions exposed for tests and benchmarks\n\n    typedef void (*ggml_vec_dot_t)  (int n, float * GGML_RESTRICT s, size_t bs, 
const void * GGML_RESTRICT x, size_t bx,\n                                       const void * GGML_RESTRICT y, size_t by, int nrc);\n\n    struct ggml_type_traits_cpu {\n        ggml_from_float_t        from_float;\n        ggml_vec_dot_t           vec_dot;\n        enum ggml_type           vec_dot_type;\n        int64_t                  nrows; // number of rows to process simultaneously\n    };\n\n    GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);\n\n    GGML_BACKEND_API void ggml_cpu_init(void);\n\n    //\n    // CPU backend\n    //\n\n    GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);\n\n    GGML_BACKEND_API bool ggml_backend_is_cpu                (ggml_backend_t backend);\n    GGML_BACKEND_API void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);\n    GGML_BACKEND_API void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);\n    GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);\n\n    GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref);\n\n    GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);\n\n    GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *,       float *, int64_t);\n    GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *,     int32_t *, int64_t);\n    GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);\n    GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);\n    GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);\n    GGML_BACKEND_API void ggml_cpu_bf16_to_fp32(const ggml_bf16_t *, float *, int64_t);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-cuda.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n#ifdef GGML_USE_HIP\n#define GGML_CUDA_NAME \"ROCm\"\n#define GGML_CUBLAS_NAME \"hipBLAS\"\n#elif defined(GGML_USE_MUSA)\n#define GGML_CUDA_NAME \"MUSA\"\n#define GGML_CUBLAS_NAME \"muBLAS\"\n#else\n#define GGML_CUDA_NAME \"CUDA\"\n#define GGML_CUBLAS_NAME \"cuBLAS\"\n#endif\n#define GGML_CUDA_MAX_DEVICES       16\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);\n\nGGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);\n\n// device buffer\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);\n\n// split tensor buffer that splits matrices by rows across multiple devices\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);\n\n// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);\n\nGGML_BACKEND_API int  ggml_backend_cuda_get_device_count(void);\nGGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);\nGGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);\n\nGGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);\nGGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-hexagon.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);\n\nGGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-metal.h",
    "content": "// Note: this description is outdated\n//\n// An interface allowing to compute ggml_cgraph with Metal\n//\n// This is a fully functional interface that extends ggml with GPU support for Apple devices.\n// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)\n//\n// How it works?\n//\n// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this\n// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you\n// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)\n//\n// You only need to make sure that all memory buffers that you used during the graph creation\n// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is\n// used during the graph evaluation to determine the arguments of the compute kernels.\n//\n// Synchronization between device and host memory (for example for input and output tensors)\n// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.\n//\n\n#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#include <stddef.h>\n#include <stdbool.h>\n\nstruct ggml_tensor;\nstruct ggml_cgraph;\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n//\n// backend API\n// user-code should use only these functions\n//\n\n// TODO: remove in the future\nGGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);\n\nGGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);\n\nGGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);\n\n// helper to check if the device supports a specific family\n// ideally, the user code should be doing these checks\n// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf\nGGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);\n\n// capture all command buffers committed the next time 
`ggml_backend_graph_compute` is called\nGGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-opencl.h",
    "content": "#ifndef GGML_OPENCL_H\n#define GGML_OPENCL_H\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n//\n// backend API\n//\nGGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void);\nGGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend);\n\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void);\n\n#ifdef  __cplusplus\n}\n#endif\n\n#endif // GGML_OPENCL_H\n"
  },
  {
    "path": "include/ggml-openvino.h",
    "content": "#pragma once\n\n#include \"ggml-backend.h\"\n\n#include <cstring>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n#define GGML_OPENVINO_NAME \"OPENVINO\"\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device);\n\nGGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend);\n\nGGML_BACKEND_API bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer);\n\nGGML_BACKEND_API bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft);\n\nGGML_BACKEND_API bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft);\n\nGGML_BACKEND_API size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer);\n\n// device buffer\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device);\n\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device);\n\nGGML_BACKEND_API int ggml_backend_openvino_get_device_count(void);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-opt.h",
    "content": "// This file contains functionality for training models using GGML.\n// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets.\n// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code.\n//\n// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)\n\n#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#include <stdint.h>\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n    struct ggml_opt_dataset;\n    struct ggml_opt_context;\n    struct ggml_opt_result;\n\n    typedef struct ggml_opt_dataset * ggml_opt_dataset_t;\n    typedef struct ggml_opt_context * ggml_opt_context_t;\n    typedef struct ggml_opt_result  * ggml_opt_result_t;\n\n    // ====== Loss ======\n\n    // built-in loss types, i.e. the built-in quantities minimized by the optimizer\n    // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value\n    enum ggml_opt_loss_type {\n        GGML_OPT_LOSS_TYPE_MEAN,\n        GGML_OPT_LOSS_TYPE_SUM,\n        GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,\n        GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,\n    };\n\n    // ====== Dataset ======\n\n    GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(\n            enum ggml_type type_data,    // the type for the internal data tensor\n            enum ggml_type type_label,   // the type for the internal labels tensor\n            int64_t        ne_datapoint, // number of elements per datapoint\n            int64_t        ne_label,     // number of elements per label\n            int64_t        ndata,        // total number of datapoints/labels\n            int64_t        ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)\n    GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);\n\n    // get underlying tensors 
that store the data\n    GGML_API int64_t              ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);\n    GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]\n    GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [ne_label,     ndata]\n\n    // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative\n    GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata);\n\n    // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch\n    GGML_API void ggml_opt_dataset_get_batch(\n            ggml_opt_dataset_t   dataset,\n            struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]\n            struct ggml_tensor * labels_batch, // shape = [ne_label,     ndata_batch]\n            int64_t              ibatch);\n    GGML_API void ggml_opt_dataset_get_batch_host(\n            ggml_opt_dataset_t   dataset,\n            void               * data_batch,\n            size_t               nb_data_batch,\n            void               * labels_batch,\n            int64_t              ibatch);\n\n    // ====== Model / Context ======\n\n    enum ggml_opt_build_type {\n        GGML_OPT_BUILD_TYPE_FORWARD = 10,\n        GGML_OPT_BUILD_TYPE_GRAD    = 20,\n        GGML_OPT_BUILD_TYPE_OPT     = 30,\n    };\n\n    enum ggml_opt_optimizer_type {\n        GGML_OPT_OPTIMIZER_TYPE_ADAMW,\n        GGML_OPT_OPTIMIZER_TYPE_SGD,\n\n        GGML_OPT_OPTIMIZER_TYPE_COUNT\n    };\n\n    // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss\n    struct ggml_opt_optimizer_params {\n        struct {\n            float alpha; // learning rate\n            float beta1; // first AdamW momentum\n            float beta2; // second AdamW momentum\n            float eps;   // 
epsilon for numerical stability\n            float wd;    // weight decay - 0.0f to disable\n        } adamw;\n        struct {\n            float alpha; // learning rate\n            float wd;    // weight decay\n        } sgd;\n    };\n\n    // callback to calculate optimizer parameters prior to a backward pass\n    // userdata can be used to pass arbitrary data\n    typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);\n\n    // returns the default optimizer params (constant, hard-coded values)\n    // userdata is not used\n    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);\n\n    // casts userdata to ggml_opt_optimizer_params and returns it\n    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);\n\n    // parameters for initializing a new optimization context\n    struct ggml_opt_params {\n        ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs\n\n        // by default the forward graph needs to be reconstructed for each eval\n        // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically\n        struct ggml_context * ctx_compute;\n        struct ggml_tensor  * inputs;\n        struct ggml_tensor  * outputs;\n\n        enum ggml_opt_loss_type  loss_type;\n        enum ggml_opt_build_type build_type;\n\n        int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done\n\n        ggml_opt_get_optimizer_params get_opt_pars;    // callback for calculating optimizer parameters\n        void *                        get_opt_pars_ud; // userdata for calculating optimizer parameters\n\n        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor\n        enum ggml_opt_optimizer_type optimizer;\n    };\n\n    // get parameters for an optimization context with defaults set where 
possible\n    // parameters for which no sensible defaults exist are supplied as arguments to this function\n    GGML_API struct ggml_opt_params ggml_opt_default_params(\n            ggml_backend_sched_t    backend_sched,\n            enum ggml_opt_loss_type loss_type);\n\n    GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);\n    GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);\n\n    // set gradients to zero, initialize loss, and optionally reset the optimizer\n    GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);\n\n    GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated statically\n\n    // get underlying tensors that store data\n    // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc\n    GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor\n    GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor\n    GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against\n    GGML_API struct ggml_tensor * ggml_opt_loss(    ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss\n    GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs\n    GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels\n\n    // get the gradient accumulator for a node from the forward graph\n    GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);\n\n    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme\n\n    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);\n\n    // ====== Optimization 
Result ======\n\n    GGML_API ggml_opt_result_t ggml_opt_result_init(void);\n    GGML_API void ggml_opt_result_free(ggml_opt_result_t result);\n    GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);\n\n    // get data from result, uncertainties are optional and can be ignored by passing NULL\n    GGML_API void ggml_opt_result_ndata(   ggml_opt_result_t result, int64_t * ndata);                  // writes 1 value, number of datapoints\n    GGML_API void ggml_opt_result_loss(    ggml_opt_result_t result, double  * loss,     double * unc); // writes 1 value\n    GGML_API void ggml_opt_result_pred(    ggml_opt_result_t result, int32_t * pred);                   // writes ndata values\n    GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double  * accuracy, double * unc); // writes 1 value\n\n    // ====== Computation ======\n\n    // if not using static graphs, this function must be called prior to ggml_opt_alloc\n    GGML_API void ggml_opt_prepare_alloc(\n        ggml_opt_context_t    opt_ctx,\n        struct ggml_context * ctx_compute,\n        struct ggml_cgraph  * gf,\n        struct ggml_tensor  * inputs,\n        struct ggml_tensor  * outputs);\n\n    // allocate the next graph for evaluation, either forward or forward + backward\n    // must be called exactly once prior to calling ggml_opt_eval\n    GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);\n\n    // do forward pass, increment result if not NULL, do backward pass if allocated\n    GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);\n\n    // ############################################################################\n    // ## The high-level functions start here. They do not depend on any private ##\n    // ## functions or structs and can be copied to and adapted for user code.   ##\n    // ############################################################################\n\n    // ====== Intended Usage ======\n    //\n    // 1. 
Select the appropriate loss for your problem.\n    // 2. Create a dataset and set the data for the \"data\" tensor. Also set the \"labels\" tensor if your loss needs them.\n    //    Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster).\n    // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors.\n    //    The first context should contain the model parameters and inputs and be allocated statically in user code.\n    //    The second context should contain all other tensors and will be (re)allocated automatically.\n    //    Due to this automated allocation the data of the second context is not defined when accessed in user code.\n    //    Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.\n    // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.\n\n    // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation\n    typedef void (*ggml_opt_epoch_callback)(\n            bool               train,       // true after training evaluation, false after validation evaluation\n            ggml_opt_context_t opt_ctx,\n            ggml_opt_dataset_t dataset,\n            ggml_opt_result_t  result,      // result associated with the dataset subsection\n            int64_t            ibatch,      // number of batches that have been evaluated so far\n            int64_t            ibatch_max,  // total number of batches in this dataset subsection\n            int64_t            t_start_us); // time at which the evaluation on the dataset subsection was started\n\n    // do training on front of dataset, do evaluation only on back of dataset\n    GGML_API void ggml_opt_epoch(\n            ggml_opt_context_t      opt_ctx,\n            ggml_opt_dataset_t      dataset,\n            ggml_opt_result_t       result_train,   // result to 
increment during training, ignored if NULL\n            ggml_opt_result_t       result_eval,    // result to increment during evaluation, ignored if NULL\n            int64_t                 idata_split,    // data index at which to split training and evaluation\n            ggml_opt_epoch_callback callback_train,\n            ggml_opt_epoch_callback callback_eval);\n\n    // callback that prints a progress bar on stderr\n    GGML_API void ggml_opt_epoch_callback_progress_bar(\n            bool               train,\n            ggml_opt_context_t opt_ctx,\n            ggml_opt_dataset_t dataset,\n            ggml_opt_result_t  result,\n            int64_t            ibatch,\n            int64_t            ibatch_max,\n            int64_t            t_start_us);\n\n    // fit model defined by inputs and outputs to dataset\n    GGML_API void ggml_opt_fit(\n            ggml_backend_sched_t            backend_sched,  // backend scheduler for constructing the compute graphs\n            struct ggml_context           * ctx_compute,    // context with temporarily allocated tensors to calculate the outputs\n            struct ggml_tensor            * inputs,         // input tensor with shape [ne_datapoint, ndata_batch]\n            struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used\n            ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels\n            enum ggml_opt_loss_type         loss_type,      // loss to minimize\n            enum ggml_opt_optimizer_type    optimizer,      // sgd or adamw\n            ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)\n            int64_t                         nepoch,         // how many times the dataset should be iterated over\n            int64_t                         nbatch_logical, // datapoints per optimizer step, must be a 
multiple of ndata_batch in inputs/outputs\n            float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)\n            bool                            silent);        // whether or not info prints to stderr should be suppressed\n\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-rpc.h",
    "content": "#pragma once\n\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n#define RPC_PROTO_MAJOR_VERSION    3\n#define RPC_PROTO_MINOR_VERSION    6\n#define RPC_PROTO_PATCH_VERSION    1\n\n#ifdef  __cplusplus\nstatic_assert(GGML_OP_COUNT == 96, \"GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION\");\n#endif\n\n#define GGML_RPC_MAX_SERVERS       16\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);\nGGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);\n\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);\n\nGGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);\n\nGGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,\n                                                    size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-sycl.h",
    "content": "//\n//  MIT license\n//  Copyright (C) 2024 Intel Corporation\n//  SPDX-License-Identifier: MIT\n//\n\n#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#define GGML_SYCL_NAME \"SYCL\"\n#define GGML_SYCL_MAX_DEVICES 48\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device);\n\nGGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend);\n\n// device buffer\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);\n\n// split tensor buffer that splits matrices by rows across multiple devices\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);\n\n// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);\n\nGGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void);\nGGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);\nGGML_BACKEND_API void ggml_backend_sycl_get_device_description(int device,\n                                                       char *description,\n                                                       size_t description_size);\nGGML_BACKEND_API int  ggml_backend_sycl_get_device_count();\nGGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);\n\n// SYCL doesn't support registering host memory, keep here for reference\n// GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);\n// GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-virtgpu.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg();\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-vulkan.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n#define GGML_VK_NAME \"Vulkan\"\n#define GGML_VK_MAX_DEVICES 16\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);\n\nGGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend);\nGGML_BACKEND_API int  ggml_backend_vk_get_device_count(void);\nGGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);\nGGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);\n\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);\n// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-webgpu.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n#define GGML_WEBGPU_NAME \"WebGPU\"\n\n// Needed for examples in ggml\nGGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-zdnn.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-backend.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n// device buffer\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_type(void);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml-zendnn.h",
    "content": "#pragma once\n\n#include \"ggml-backend.h\"\n#include \"ggml.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_zendnn_init(void);\n\nGGML_BACKEND_API bool ggml_backend_is_zendnn(ggml_backend_t backend);\n\n// number of threads used for zendnn operations\nGGML_BACKEND_API void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads);\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_zendnn_reg(void);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/ggml.h",
    "content": "#pragma once\n\n//\n// GGML Tensor Library\n//\n// This documentation is still a work in progress.\n// If you wish some specific topics to be covered, feel free to drop a comment:\n//\n//   https://github.com/ggml-org/whisper.cpp/issues/40\n//\n// ## Overview\n//\n// This library implements:\n//\n//  - a set of tensor operations\n//  - automatic differentiation\n//  - basic optimization algorithms\n//\n// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,\n// but is not limited to, the following:\n//\n//  - linear regression\n//  - support vector machines\n//  - neural networks\n//\n// The library allows the user to define a certain function using the available tensor operations. This function\n// definition is represented internally via a computation graph. Each tensor operation in the function definition\n// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the\n// function's value and/or its gradient with respect to the input variables. 
Optionally, the function can be optimized\n// using one of the available optimization algorithms.\n//\n// For example, here we define the function: f(x) = a*x^2 + b\n//\n//   {\n//       struct ggml_init_params params = {\n//           .mem_size   = 16*1024*1024,\n//           .mem_buffer = NULL,\n//       };\n//\n//       // memory allocation happens here\n//       struct ggml_context * ctx = ggml_init(params);\n//\n//       struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n//\n//       ggml_set_param(ctx, x); // x is an input variable\n//\n//       struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n//       struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n//       struct ggml_tensor * x2 = ggml_mul(ctx, x, x);\n//       struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);\n//\n//       ...\n//   }\n//\n// Notice that the function definition above does not involve any actual computation. The computation is performed only\n// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:\n//\n//   {\n//       ...\n//\n//       struct ggml_cgraph * gf = ggml_new_graph(ctx);\n//       ggml_build_forward_expand(gf, f);\n//\n//       // set the input variable and parameter values\n//       ggml_set_f32(x, 2.0f);\n//       ggml_set_f32(a, 3.0f);\n//       ggml_set_f32(b, 4.0f);\n//\n//       ggml_graph_compute_with_ctx(ctx, gf, n_threads);\n//\n//       printf(\"f = %f\\n\", ggml_get_f32_1d(f, 0));\n//\n//       ...\n//   }\n//\n// The actual computation is performed in the ggml_graph_compute() function.\n//\n// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the\n// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know\n// in advance how much memory you need for your computation. 
Alternatively, you can allocate a large enough memory\n// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was\n// actually needed.\n//\n// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic\n// differentiation and optimization algorithms.\n//\n// The described approach allows to define the function graph once and then compute its forward or backward graphs\n// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way\n// the user can avoid the memory allocation overhead at runtime.\n//\n// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class\n// citizens, but in theory the library can be extended to support FP8 and integer data types.\n//\n// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary\n// and binary operations. Most of the available operations fall into one of these two categories. With time, it became\n// clear that the library needs to support more complex operations. The way to support these operations is not clear\n// yet, but a few examples are demonstrated in the following operations:\n//\n//   - ggml_permute()\n//   - ggml_conv_1d_1s()\n//   - ggml_conv_1d_2s()\n//\n// For each tensor operator, the library implements a forward and backward computation function. The forward function\n// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the\n// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a\n// calculus class, or watch the following video:\n//\n//   What is Automatic Differentiation?\n//   https://www.youtube.com/watch?v=wG_nF1awSSY\n//\n//\n// ## Tensor data (struct ggml_tensor)\n//\n// The tensors are stored in memory via the ggml_tensor struct. 
The structure provides information about the size of\n// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains\n// pointers to the \"source\" tensors - i.e. the tensors that were used to compute the current tensor. For example:\n//\n//   {\n//       struct ggml_tensor * c = ggml_add(ctx, a, b);\n//\n//       assert(c->src[0] == a);\n//       assert(c->src[1] == b);\n//   }\n//\n// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the\n// number of elements in each dimension (\"ne\") as well as the number of bytes (\"nb\", a.k.a. stride). This allows\n// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and\n// permutation. All tensor operations have to take the stride into account and not assume that the tensor is\n// contiguous in memory.\n//\n// The data of the tensor is accessed via the \"data\" pointer. For example:\n//\n//   {\n//       const int nx = 2;\n//       const int ny = 3;\n//\n//       struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);\n//\n//       for (int y = 0; y < ny; y++) {\n//           for (int x = 0; x < nx; x++) {\n//               *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;\n//           }\n//       }\n//\n//       ...\n//   }\n//\n// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.\n//\n// ## The matrix multiplication operator (ggml_mul_mat)\n//\n// TODO\n//\n//\n// ## Multi-threading\n//\n// TODO\n//\n//\n// ## Overview of ggml.c\n//\n// TODO\n//\n//\n// ## SIMD optimizations\n//\n// TODO\n//\n//\n// ## Debugging ggml\n//\n// TODO\n//\n//\n\n#ifdef GGML_SHARED\n#    if defined(_WIN32) && !defined(__MINGW32__)\n#        ifdef GGML_BUILD\n#            define GGML_API __declspec(dllexport) extern\n#        else\n#            define GGML_API __declspec(dllimport) extern\n#        
endif\n#    else\n#        define GGML_API __attribute__ ((visibility (\"default\"))) extern\n#    endif\n#else\n#    define GGML_API extern\n#endif\n\n// TODO: support for clang\n#ifdef __GNUC__\n#    define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))\n#elif defined(_MSC_VER)\n#    define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func\n#else\n#    define GGML_DEPRECATED(func, hint) func\n#endif\n\n#ifndef __GNUC__\n#    define GGML_ATTRIBUTE_FORMAT(...)\n#elif defined(__MINGW32__) && !defined(__clang__)\n#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))\n#else\n#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))\n#endif\n\n#if defined(_WIN32) && !defined(_WIN32_WINNT)\n#    define _WIN32_WINNT 0x0A00\n#endif\n\n#include <stdbool.h>\n#include <stddef.h>\n#include <stdint.h>\n#include <stdio.h>\n\n#define GGML_FILE_MAGIC   0x67676d6c // \"ggml\"\n#define GGML_FILE_VERSION 2\n\n#define GGML_QNT_VERSION        2    // bump this on quantization format changes\n#define GGML_QNT_VERSION_FACTOR 1000 // do not change this\n\n#define GGML_MAX_DIMS           4\n#define GGML_MAX_PARAMS         2048\n#define GGML_MAX_SRC            10\n#define GGML_MAX_N_THREADS      512\n#define GGML_MAX_OP_PARAMS      64\n\n#ifndef GGML_MAX_NAME\n#   define GGML_MAX_NAME        64\n#endif\n\n#define GGML_DEFAULT_N_THREADS  4\n#define GGML_DEFAULT_GRAPH_SIZE 2048\n\n#if UINTPTR_MAX == 0xFFFFFFFF\n    #define GGML_MEM_ALIGN 4\n#elif defined(__EMSCRIPTEN__)\n// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.\n// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)\n// ref: https://github.com/ggml-org/llama.cpp/pull/18628\n    #define GGML_MEM_ALIGN 8\n#else\n    #define GGML_MEM_ALIGN 16\n#endif\n\n#define GGML_EXIT_SUCCESS 0\n#define GGML_EXIT_ABORTED 1\n\n// TODO: convert to enum 
https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726\n#define GGML_ROPE_TYPE_NORMAL 0\n#define GGML_ROPE_TYPE_NEOX   2\n#define GGML_ROPE_TYPE_MROPE  8\n#define GGML_ROPE_TYPE_VISION 24\n#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000\n\n#define GGML_MROPE_SECTIONS   4\n\n#define GGML_UNUSED(x) (void)(x)\n#ifdef __CUDACC__\ntemplate<typename... Args>\n__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {}\n#define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__)\n#else\n#define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0)\n#endif // __CUDACC__\n\n#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))\n\n#ifndef NDEBUG\n#   define GGML_UNREACHABLE() do { fprintf(stderr, \"statement should be unreachable\\n\"); abort(); } while(0)\n#elif defined(__GNUC__)\n#   define GGML_UNREACHABLE() __builtin_unreachable()\n#elif defined(_MSC_VER)\n#   define GGML_UNREACHABLE() __assume(0)\n#else\n#   define GGML_UNREACHABLE() ((void) 0)\n#endif\n\n#ifdef __cplusplus\n#   define GGML_NORETURN [[noreturn]]\n#elif defined(_MSC_VER)\n#   define GGML_NORETURN __declspec(noreturn)\n#else\n#   define GGML_NORETURN _Noreturn\n#endif\n\n#define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__)\n#define GGML_ASSERT(x) if (!(x)) GGML_ABORT(\"GGML_ASSERT(%s) failed\", #x)\n\n// used to copy the number of elements and stride in bytes of tensors into local variables.\n// main purpose is to reduce code duplication and improve readability.\n//\n// example:\n//\n//    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);\n//    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);\n//\n#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \\\n    const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \\\n    GGML_UNUSED(prefix##0);\n#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \\\n    GGML_TENSOR_LOCALS_1    (type, prefix, pointer, array) \\\n    const type prefix##1 = (pointer) ? 
(pointer)->array[1] : 0; \\\n    GGML_UNUSED(prefix##1);\n#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \\\n    GGML_TENSOR_LOCALS_2    (type, prefix, pointer, array) \\\n    const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \\\n    GGML_UNUSED(prefix##2);\n#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \\\n    GGML_TENSOR_LOCALS_3  (type, prefix, pointer, array) \\\n    const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \\\n    GGML_UNUSED(prefix##3);\n\n#define GGML_TENSOR_UNARY_OP_LOCALS \\\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \\\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n#define GGML_TENSOR_BINARY_OP_LOCALS \\\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \\\n    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \\\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n#define GGML_TENSOR_TERNARY_OP_LOCALS \\\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \\\n    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \\\n    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb2, src2, nb) \\\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n#define GGML_TENSOR_BINARY_OP_LOCALS01 \\\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \\\n    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \\\n    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n    // Function type used in fatal error callbacks\n    typedef void (*ggml_abort_callback_t)(const char * error_message);\n\n    // Set the abort callback (passing 
null will restore original abort functionality: printing a message to stdout)\n    // Returns the old callback for chaining\n    GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback);\n\n    GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4)\n    GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...);\n\n    enum ggml_status {\n        GGML_STATUS_ALLOC_FAILED = -2,\n        GGML_STATUS_FAILED = -1,\n        GGML_STATUS_SUCCESS = 0,\n        GGML_STATUS_ABORTED = 1,\n    };\n\n    // get ggml_status name string\n    GGML_API const char * ggml_status_to_string(enum ggml_status status);\n\n    // ieee 754-2008 half-precision float16\n    // todo: make this not an integral type\n    typedef uint16_t ggml_fp16_t;\n    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t);\n    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);\n    GGML_API void        ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);\n    GGML_API void        ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);\n\n    // google brain half-precision bfloat16\n    typedef struct { uint16_t bits; } ggml_bf16_t;\n    GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);\n    GGML_API float       ggml_bf16_to_fp32(ggml_bf16_t);  // consider just doing << 16\n    GGML_API void        ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);\n    GGML_API void        ggml_fp32_to_bf16_row_ref(const float *, ggml_bf16_t *, int64_t);\n    GGML_API void        ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);\n\n    struct ggml_object;\n    struct ggml_context;\n    struct ggml_cgraph;\n\n    // NOTE: always add types at the end of the enum to keep backward compatibility\n    enum ggml_type {\n        GGML_TYPE_F32     = 0,\n        GGML_TYPE_F16     = 1,\n        GGML_TYPE_Q4_0    = 2,\n        GGML_TYPE_Q4_1    = 3,\n        // GGML_TYPE_Q4_2 = 4, support has been removed\n        // GGML_TYPE_Q4_3 = 5, support has been removed\n        
GGML_TYPE_Q5_0    = 6,\n        GGML_TYPE_Q5_1    = 7,\n        GGML_TYPE_Q8_0    = 8,\n        GGML_TYPE_Q8_1    = 9,\n        GGML_TYPE_Q2_K    = 10,\n        GGML_TYPE_Q3_K    = 11,\n        GGML_TYPE_Q4_K    = 12,\n        GGML_TYPE_Q5_K    = 13,\n        GGML_TYPE_Q6_K    = 14,\n        GGML_TYPE_Q8_K    = 15,\n        GGML_TYPE_IQ2_XXS = 16,\n        GGML_TYPE_IQ2_XS  = 17,\n        GGML_TYPE_IQ3_XXS = 18,\n        GGML_TYPE_IQ1_S   = 19,\n        GGML_TYPE_IQ4_NL  = 20,\n        GGML_TYPE_IQ3_S   = 21,\n        GGML_TYPE_IQ2_S   = 22,\n        GGML_TYPE_IQ4_XS  = 23,\n        GGML_TYPE_I8      = 24,\n        GGML_TYPE_I16     = 25,\n        GGML_TYPE_I32     = 26,\n        GGML_TYPE_I64     = 27,\n        GGML_TYPE_F64     = 28,\n        GGML_TYPE_IQ1_M   = 29,\n        GGML_TYPE_BF16    = 30,\n        // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files\n        // GGML_TYPE_Q4_0_4_8 = 32,\n        // GGML_TYPE_Q4_0_8_8 = 33,\n        GGML_TYPE_TQ1_0   = 34,\n        GGML_TYPE_TQ2_0   = 35,\n        // GGML_TYPE_IQ4_NL_4_4 = 36,\n        // GGML_TYPE_IQ4_NL_4_8 = 37,\n        // GGML_TYPE_IQ4_NL_8_8 = 38,\n        GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)\n        GGML_TYPE_NVFP4   = 40, // NVFP4 (4 blocks, E4M3 scale)\n        GGML_TYPE_COUNT   = 41,\n    };\n\n    // precision\n    enum ggml_prec {\n        GGML_PREC_DEFAULT =  0, // stored as ggml_tensor.op_params, 0 by default\n        GGML_PREC_F32     = 10,\n    };\n\n    // model file types\n    enum ggml_ftype {\n        GGML_FTYPE_UNKNOWN        = -1,\n        GGML_FTYPE_ALL_F32        = 0,\n        GGML_FTYPE_MOSTLY_F16     = 1,  // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q4_0    = 2,  // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q4_1    = 3,  // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16\n        GGML_FTYPE_MOSTLY_Q8_0    = 7,  // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q5_0    = 8,  // except 1d 
tensors\n        GGML_FTYPE_MOSTLY_Q5_1    = 9,  // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q2_K    = 10, // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q3_K    = 11, // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q4_K    = 12, // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q5_K    = 13, // except 1d tensors\n        GGML_FTYPE_MOSTLY_Q6_K    = 14, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors\n        GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors\n        GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors\n        GGML_FTYPE_MOSTLY_MXFP4   = 25, // except 1d tensors\n        GGML_FTYPE_MOSTLY_NVFP4   = 26, // except 1d tensors\n    };\n\n    // available tensor operations:\n    enum ggml_op {\n        GGML_OP_NONE = 0,\n\n        GGML_OP_DUP,\n        GGML_OP_ADD,\n        GGML_OP_ADD_ID,\n        GGML_OP_ADD1,\n        GGML_OP_ACC,\n        GGML_OP_SUB,\n        GGML_OP_MUL,\n        GGML_OP_DIV,\n        GGML_OP_SQR,\n        GGML_OP_SQRT,\n        GGML_OP_LOG,\n        GGML_OP_SIN,\n        GGML_OP_COS,\n        GGML_OP_SUM,\n        GGML_OP_SUM_ROWS,\n        GGML_OP_CUMSUM,\n        GGML_OP_MEAN,\n        GGML_OP_ARGMAX,\n        GGML_OP_COUNT_EQUAL,\n        GGML_OP_REPEAT,\n        GGML_OP_REPEAT_BACK,\n        GGML_OP_CONCAT,\n        GGML_OP_SILU_BACK,\n        GGML_OP_NORM, // normalize\n        GGML_OP_RMS_NORM,\n        GGML_OP_RMS_NORM_BACK,\n        GGML_OP_GROUP_NORM,\n        GGML_OP_L2_NORM,\n\n        GGML_OP_MUL_MAT,\n        GGML_OP_MUL_MAT_ID,\n        
GGML_OP_OUT_PROD,\n\n        GGML_OP_SCALE,\n        GGML_OP_SET,\n        GGML_OP_CPY,\n        GGML_OP_CONT,\n        GGML_OP_RESHAPE,\n        GGML_OP_VIEW,\n        GGML_OP_PERMUTE,\n        GGML_OP_TRANSPOSE,\n        GGML_OP_GET_ROWS,\n        GGML_OP_GET_ROWS_BACK,\n        GGML_OP_SET_ROWS,\n        GGML_OP_DIAG,\n        GGML_OP_DIAG_MASK_INF,\n        GGML_OP_DIAG_MASK_ZERO,\n        GGML_OP_SOFT_MAX,\n        GGML_OP_SOFT_MAX_BACK,\n        GGML_OP_ROPE,\n        GGML_OP_ROPE_BACK,\n        GGML_OP_CLAMP,\n        GGML_OP_CONV_TRANSPOSE_1D,\n        GGML_OP_IM2COL,\n        GGML_OP_IM2COL_BACK,\n        GGML_OP_IM2COL_3D,\n        GGML_OP_CONV_2D,\n        GGML_OP_CONV_3D,\n        GGML_OP_CONV_2D_DW,\n        GGML_OP_CONV_TRANSPOSE_2D,\n        GGML_OP_POOL_1D,\n        GGML_OP_POOL_2D,\n        GGML_OP_POOL_2D_BACK,\n        GGML_OP_UPSCALE,\n        GGML_OP_PAD,\n        GGML_OP_PAD_REFLECT_1D,\n        GGML_OP_ROLL,\n        GGML_OP_ARANGE,\n        GGML_OP_TIMESTEP_EMBEDDING,\n        GGML_OP_ARGSORT,\n        GGML_OP_TOP_K,\n        GGML_OP_LEAKY_RELU,\n        GGML_OP_TRI,\n        GGML_OP_FILL,\n\n        GGML_OP_FLASH_ATTN_EXT,\n        GGML_OP_FLASH_ATTN_BACK,\n        GGML_OP_SSM_CONV,\n        GGML_OP_SSM_SCAN,\n        GGML_OP_WIN_PART,\n        GGML_OP_WIN_UNPART,\n        GGML_OP_GET_REL_POS,\n        GGML_OP_ADD_REL_POS,\n        GGML_OP_RWKV_WKV6,\n        GGML_OP_GATED_LINEAR_ATTN,\n        GGML_OP_RWKV_WKV7,\n        GGML_OP_SOLVE_TRI,\n        GGML_OP_GATED_DELTA_NET,\n\n        GGML_OP_UNARY,\n\n        GGML_OP_MAP_CUSTOM1,\n        GGML_OP_MAP_CUSTOM2,\n        GGML_OP_MAP_CUSTOM3,\n\n        GGML_OP_CUSTOM,\n\n        GGML_OP_CROSS_ENTROPY_LOSS,\n        GGML_OP_CROSS_ENTROPY_LOSS_BACK,\n        GGML_OP_OPT_STEP_ADAMW,\n        GGML_OP_OPT_STEP_SGD,\n\n        GGML_OP_GLU,\n\n        GGML_OP_COUNT,\n    };\n\n    enum ggml_unary_op {\n        GGML_UNARY_OP_ABS,\n        GGML_UNARY_OP_SGN,\n        GGML_UNARY_OP_NEG,\n        
GGML_UNARY_OP_STEP,\n        GGML_UNARY_OP_TANH,\n        GGML_UNARY_OP_ELU,\n        GGML_UNARY_OP_RELU,\n        GGML_UNARY_OP_SIGMOID,\n        GGML_UNARY_OP_GELU,\n        GGML_UNARY_OP_GELU_QUICK,\n        GGML_UNARY_OP_SILU,\n        GGML_UNARY_OP_HARDSWISH,\n        GGML_UNARY_OP_HARDSIGMOID,\n        GGML_UNARY_OP_EXP,\n        GGML_UNARY_OP_EXPM1,\n        GGML_UNARY_OP_SOFTPLUS,\n        GGML_UNARY_OP_GELU_ERF,\n        GGML_UNARY_OP_XIELU,\n        GGML_UNARY_OP_FLOOR,\n        GGML_UNARY_OP_CEIL,\n        GGML_UNARY_OP_ROUND,\n        GGML_UNARY_OP_TRUNC,\n\n        GGML_UNARY_OP_COUNT,\n    };\n\n    enum ggml_glu_op {\n        GGML_GLU_OP_REGLU,\n        GGML_GLU_OP_GEGLU,\n        GGML_GLU_OP_SWIGLU,\n        GGML_GLU_OP_SWIGLU_OAI,\n        GGML_GLU_OP_GEGLU_ERF,\n        GGML_GLU_OP_GEGLU_QUICK,\n\n        GGML_GLU_OP_COUNT,\n    };\n\n    enum ggml_object_type {\n        GGML_OBJECT_TYPE_TENSOR,\n        GGML_OBJECT_TYPE_GRAPH,\n        GGML_OBJECT_TYPE_WORK_BUFFER\n    };\n\n    enum ggml_log_level {\n        GGML_LOG_LEVEL_NONE  = 0,\n        GGML_LOG_LEVEL_DEBUG = 1,\n        GGML_LOG_LEVEL_INFO  = 2,\n        GGML_LOG_LEVEL_WARN  = 3,\n        GGML_LOG_LEVEL_ERROR = 4,\n        GGML_LOG_LEVEL_CONT  = 5, // continue previous log\n    };\n\n    // this tensor...\n    enum ggml_tensor_flag {\n        GGML_TENSOR_FLAG_INPUT   =  1, // ...is an input for the GGML compute graph\n        GGML_TENSOR_FLAG_OUTPUT  =  2, // ...is an output for the GGML compute graph\n        GGML_TENSOR_FLAG_PARAM   =  4, // ...contains trainable parameters\n        GGML_TENSOR_FLAG_LOSS    =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)\n        GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed\n    };\n\n    enum ggml_tri_type {\n        GGML_TRI_TYPE_UPPER_DIAG = 0,\n        GGML_TRI_TYPE_UPPER      = 1,\n        GGML_TRI_TYPE_LOWER_DIAG = 2,\n        GGML_TRI_TYPE_LOWER      = 3\n    };\n\n    struct ggml_init_params {\n        
// memory pool\n        size_t mem_size;   // bytes\n        void * mem_buffer; // if NULL, memory will be allocated internally\n        bool   no_alloc;   // don't allocate memory for the tensor data\n    };\n\n    // n-dimensional tensor\n    struct ggml_tensor {\n        enum ggml_type type;\n\n        struct ggml_backend_buffer * buffer;\n\n        int64_t ne[GGML_MAX_DIMS]; // number of elements\n        size_t  nb[GGML_MAX_DIMS]; // stride in bytes:\n                                   // nb[0] = ggml_type_size(type)\n                                   // nb[1] = nb[0]   * (ne[0] / ggml_blck_size(type)) + padding\n                                   // nb[i] = nb[i-1] * ne[i-1]\n\n        // compute data\n        enum ggml_op op;\n\n        // op params - allocated as int32_t for alignment\n        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];\n\n        int32_t flags;\n\n        struct ggml_tensor * src[GGML_MAX_SRC];\n\n        // source tensor and offset for views\n        struct ggml_tensor * view_src;\n        size_t               view_offs;\n\n        void * data;\n\n        char name[GGML_MAX_NAME];\n\n        void * extra; // extra things e.g. 
for ggml-cuda.cu\n\n        char padding[8];\n    };\n\n    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);\n\n    // Abort callback\n    // If not NULL, called before ggml computation\n    // If it returns true, the computation is aborted\n    typedef bool (*ggml_abort_callback)(void * data);\n\n\n    //\n    // GUID\n    //\n\n    // GUID types\n    typedef uint8_t ggml_guid[16];\n    typedef ggml_guid * ggml_guid_t;\n\n    GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b);\n\n    // misc\n\n    GGML_API const char * ggml_version(void);\n    GGML_API const char * ggml_commit(void);\n\n    GGML_API void    ggml_time_init(void); // call this once at the beginning of the program\n    GGML_API int64_t ggml_time_ms(void);\n    GGML_API int64_t ggml_time_us(void);\n    GGML_API int64_t ggml_cycles(void);\n    GGML_API int64_t ggml_cycles_per_ms(void);\n\n    // accepts a UTF-8 path, even on Windows\n    GGML_API FILE *  ggml_fopen(const char * fname, const char * mode);\n\n    GGML_API void    ggml_print_object (const struct ggml_object * obj);\n    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);\n\n    GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);\n    GGML_API int64_t ggml_nrows     (const struct ggml_tensor * tensor);\n    GGML_API size_t  ggml_nbytes    (const struct ggml_tensor * tensor);\n    GGML_API size_t  ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN\n\n    GGML_API int64_t ggml_blck_size(enum ggml_type type);\n    GGML_API size_t  ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block\n    GGML_API size_t  ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row\n\n    GGML_DEPRECATED(\n    GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float\n    \"use ggml_row_size() instead\");\n\n    GGML_API 
const char * ggml_type_name(enum ggml_type type);\n    GGML_API const char * ggml_op_name  (enum ggml_op   op);\n    GGML_API const char * ggml_op_symbol(enum ggml_op   op);\n\n    GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);\n    GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);\n    GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name\n\n    GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);\n\n    GGML_API bool    ggml_is_quantized(enum ggml_type type);\n\n    // TODO: temporary until model loading of ggml examples is refactored\n    GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);\n\n    GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);\n    GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);\n    GGML_API bool ggml_is_empty     (const struct ggml_tensor * tensor);\n    GGML_API bool ggml_is_view      (const struct ggml_tensor * tensor);\n    GGML_API bool ggml_is_scalar    (const struct ggml_tensor * tensor);\n    GGML_API bool ggml_is_vector    (const struct ggml_tensor * tensor);\n    GGML_API bool ggml_is_matrix    (const struct ggml_tensor * tensor);\n    GGML_API bool ggml_is_3d        (const struct ggml_tensor * tensor);\n    GGML_API int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars\n\n    // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)\n    GGML_API bool ggml_is_contiguous  (const struct ggml_tensor * tensor);\n    GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()\n    GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1\n    GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2\n\n    // returns whether the tensor elements are allocated as one contiguous block of memory (no 
gaps, but permutation ok)\n    GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor);\n\n    // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN\n    GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);\n\n    // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements\n    GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);\n\n    GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);\n    GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);\n\n    GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1);\n\n    // use this to compute the memory overhead of a tensor\n    GGML_API size_t ggml_tensor_overhead(void);\n\n    GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);\n\n    // main\n\n    GGML_API struct ggml_context * ggml_init (struct ggml_init_params params);\n    GGML_API void                  ggml_reset(struct ggml_context * ctx);\n    GGML_API void                  ggml_free (struct ggml_context * ctx);\n\n    GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);\n\n    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);\n    GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);\n\n    GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);\n    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);\n    GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);\n\n    GGML_API struct ggml_tensor * ggml_new_tensor(\n            struct ggml_context * ctx,\n            enum   ggml_type type,\n            int    n_dims,\n            const int64_t *ne);\n\n    GGML_API struct ggml_tensor * ggml_new_tensor_1d(\n            struct ggml_context * ctx,\n       
     enum   ggml_type type,\n            int64_t ne0);\n\n    GGML_API struct ggml_tensor * ggml_new_tensor_2d(\n            struct ggml_context * ctx,\n            enum   ggml_type type,\n            int64_t ne0,\n            int64_t ne1);\n\n    GGML_API struct ggml_tensor * ggml_new_tensor_3d(\n            struct ggml_context * ctx,\n            enum   ggml_type type,\n            int64_t ne0,\n            int64_t ne1,\n            int64_t ne2);\n\n    GGML_API struct ggml_tensor * ggml_new_tensor_4d(\n            struct ggml_context * ctx,\n            enum   ggml_type type,\n            int64_t ne0,\n            int64_t ne1,\n            int64_t ne2,\n            int64_t ne3);\n\n    GGML_API void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes);\n\n    GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);\n    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);\n\n    // Context tensor enumeration and lookup\n    GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);\n    GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);\n    GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);\n\n    // Converts a flat index into coordinates\n    GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);\n\n    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);\n    GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);\n\n    GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);\n    GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);\n\n    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);\n    GGML_API struct ggml_tensor * 
ggml_set_name   (      struct ggml_tensor * tensor, const char * name);\n    GGML_ATTRIBUTE_FORMAT(2, 3)\n    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);\n\n    // Tensor flags\n    GGML_API void ggml_set_input(struct ggml_tensor * tensor);\n    GGML_API void ggml_set_output(struct ggml_tensor * tensor);\n    GGML_API void ggml_set_param(struct ggml_tensor * tensor);\n    GGML_API void ggml_set_loss(struct ggml_tensor * tensor);\n\n    //\n    // operations on tensors with backpropagation\n    //\n\n    GGML_API struct ggml_tensor * ggml_dup(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_dup_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_add(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_add_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_add_cast(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            enum   ggml_type      type);\n\n    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]\n    GGML_API struct ggml_tensor * ggml_add_id(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * ids);\n\n    GGML_API struct ggml_tensor * ggml_add1(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_add1_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct 
ggml_tensor  * b);\n\n    // dst = a\n    // view(dst, nb1, nb2, nb3, offset) += b\n    // return dst\n    GGML_API struct ggml_tensor * ggml_acc(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            size_t                nb1,\n            size_t                nb2,\n            size_t                nb3,\n            size_t                offset);\n\n    GGML_API struct ggml_tensor * ggml_acc_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            size_t                nb1,\n            size_t                nb2,\n            size_t                nb3,\n            size_t                offset);\n\n    GGML_API struct ggml_tensor * ggml_sub(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_sub_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_mul(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_mul_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_div(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_div_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_sqr(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_sqr_inplace(\n            struct ggml_context * ctx,\n        
    struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_sqrt(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_sqrt_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_log(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_log_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_expm1(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_expm1_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_softplus(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_softplus_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_sin(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_sin_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_cos(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_cos_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // return scalar\n    GGML_API struct ggml_tensor * ggml_sum(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]\n    GGML_API struct ggml_tensor * ggml_sum_rows(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * 
ggml_cumsum(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a);\n\n    // mean along rows\n    GGML_API struct ggml_tensor * ggml_mean(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // argmax along rows\n    GGML_API struct ggml_tensor * ggml_argmax(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // count number of equal elements in a and b\n    GGML_API struct ggml_tensor * ggml_count_equal(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    // if a is the same shape as b, and a is not parameter, return a\n    // otherwise, return a new tensor: repeat(a) to fit in b\n    GGML_API struct ggml_tensor * ggml_repeat(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    // repeat a to the specified shape\n    GGML_API struct ggml_tensor * ggml_repeat_4d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n                       int64_t    ne0,\n                       int64_t    ne1,\n                       int64_t    ne2,\n                       int64_t    ne3);\n\n    // sums repetitions in a into shape of b\n    GGML_API struct ggml_tensor * ggml_repeat_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride\n\n    // concat a and b along dim\n    // used in stable-diffusion\n    GGML_API struct ggml_tensor * ggml_concat(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            int                   dim);\n\n    GGML_API struct ggml_tensor * ggml_abs(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * 
ggml_abs_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_sgn(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_sgn_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_neg(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_neg_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_step(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_step_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_tanh(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_tanh_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_elu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_elu_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_relu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_leaky_relu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a, float negative_slope, bool inplace);\n\n    GGML_API struct ggml_tensor * ggml_relu_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_sigmoid(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * 
ggml_sigmoid_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_gelu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_gelu_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // GELU using erf (error function) when possible\n    // some backends may fallback to approximation based on Abramowitz and Stegun formula\n    GGML_API struct ggml_tensor * ggml_gelu_erf(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_gelu_quick(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_silu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_silu_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // a - x\n    // b - dy\n    GGML_API struct ggml_tensor * ggml_silu_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    // hardswish(x) = x * relu6(x + 3) / 6\n    GGML_API struct ggml_tensor * ggml_hardswish(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // hardsigmoid(x) = relu6(x + 3) / 6\n    GGML_API struct ggml_tensor * ggml_hardsigmoid(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_exp(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    
GGML_API struct ggml_tensor * ggml_exp_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_floor(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_floor_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_ceil(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_ceil_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_round(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_round_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n     /**\n     * Truncates the fractional part of each element in the tensor (towards zero).\n     * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0\n     * Similar to std::trunc in C/C++.\n     */\n\n    GGML_API struct ggml_tensor * ggml_trunc(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_trunc_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n\n\n    // xIELU activation function\n    // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)\n    // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions\n    // that constrain the positive and negative source alpha values respectively\n    GGML_API struct ggml_tensor * ggml_xielu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float alpha_n,\n            float alpha_p,\n            float beta,\n            float eps);\n\n    // gated linear unit ops\n    // A: n columns, r rows,\n    // result is n / 2 
columns, r rows,\n    // expects gate in second half of row, unless swapped is true\n    GGML_API struct ggml_tensor * ggml_glu(\n            struct ggml_context * ctx,\n             struct ggml_tensor * a,\n             enum ggml_glu_op     op,\n             bool                 swapped);\n\n    GGML_API struct ggml_tensor * ggml_reglu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_reglu_swapped(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_geglu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_geglu_swapped(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_swiglu(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_swiglu_swapped(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_geglu_erf(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_geglu_quick(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // A: n columns, r rows,\n    // B: n columns, r rows,\n    GGML_API struct ggml_tensor * ggml_glu_split(\n            struct ggml_context * ctx,\n             struct ggml_tensor * a,\n             struct ggml_tensor * b,\n             enum ggml_glu_op     op);\n\n    GGML_API struct ggml_tensor * ggml_reglu_split(\n            struct ggml_context * ctx,\n  
          struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_geglu_split(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_swiglu_split(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_geglu_erf_split(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_geglu_quick_split(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    GGML_API struct ggml_tensor * ggml_swiglu_oai(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            float                 alpha,\n            float                 limit);\n\n    // normalize along rows\n    GGML_API struct ggml_tensor * ggml_norm(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 eps);\n\n    GGML_API struct ggml_tensor * ggml_norm_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 eps);\n\n    GGML_API struct ggml_tensor * ggml_rms_norm(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 eps);\n\n    GGML_API struct ggml_tensor * ggml_rms_norm_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 eps);\n\n    // group normalize along ne0*ne1*n_groups\n    // used in stable-diffusion\n    GGML_API struct ggml_tensor * ggml_group_norm(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   
n_groups,\n            float                 eps);\n\n    GGML_API struct ggml_tensor * ggml_group_norm_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   n_groups,\n            float                 eps);\n\n    // l2 normalize along rows\n    // used in rwkv v7\n    GGML_API struct ggml_tensor * ggml_l2_norm(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 eps);\n\n    GGML_API struct ggml_tensor * ggml_l2_norm_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 eps);\n\n    // a - x\n    // b - dy\n    GGML_API struct ggml_tensor * ggml_rms_norm_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            float                 eps);\n\n    // A: k columns, n rows => [ne03, ne02, n, k]\n    // B: k columns, m rows  (i.e. 
we transpose it internally) => [ne03 * x, ne02 * y, m, k]\n    // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]\n    GGML_API struct ggml_tensor * ggml_mul_mat(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    // change the precision of a matrix multiplication\n    // set to GGML_PREC_F32 for higher precision (useful for phi-2)\n    GGML_API void ggml_mul_mat_set_prec(\n            struct ggml_tensor * a,\n            enum ggml_prec       prec);\n\n    // indirect matrix multiplication\n    GGML_API struct ggml_tensor * ggml_mul_mat_id(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * as,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * ids);\n\n    // A: m columns, n rows,\n    // B: p columns, n rows,\n    // result is m columns, p rows\n    GGML_API struct ggml_tensor * ggml_out_prod(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    //\n    // operations on tensors without backpropagation\n    //\n\n    GGML_API struct ggml_tensor * ggml_scale(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 s);\n\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_scale_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 s);\n\n    // x = s * a + b\n    GGML_API struct ggml_tensor * ggml_scale_bias(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 s,\n        float                 b);\n\n    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 s,\n        float                 b);\n\n    // b -> view(a,offset,nb1,nb2,3), return modified a\n    GGML_API struct 
ggml_tensor * ggml_set(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            size_t                nb1,\n            size_t                nb2,\n            size_t                nb3,\n            size_t                offset); // in bytes\n\n    // b -> view(a,offset,nb1,nb2,3), return view(a)\n    GGML_API struct ggml_tensor * ggml_set_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            size_t                nb1,\n            size_t                nb2,\n            size_t                nb3,\n            size_t                offset); // in bytes\n\n    GGML_API struct ggml_tensor * ggml_set_1d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            size_t                offset); // in bytes\n\n    GGML_API struct ggml_tensor * ggml_set_1d_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            size_t                offset); // in bytes\n\n    // b -> view(a,offset,nb1,nb2,3), return modified a\n    GGML_API struct ggml_tensor * ggml_set_2d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            size_t                nb1,\n            size_t                offset); // in bytes\n\n    // b -> view(a,offset,nb1,nb2,3), return view(a)\n    GGML_API struct ggml_tensor * ggml_set_2d_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            size_t                nb1,\n            size_t                offset); // in bytes\n\n    // a -> b, return view(b)\n    GGML_API struct ggml_tensor * ggml_cpy(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor 
 * b);\n\n    // note: casting from f32 to i32 will discard the fractional part\n    GGML_API struct ggml_tensor * ggml_cast(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            enum   ggml_type      type);\n\n    // make contiguous\n    GGML_API struct ggml_tensor * ggml_cont(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // make contiguous, with new shape\n    GGML_API struct ggml_tensor * ggml_cont_1d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0);\n\n    GGML_API struct ggml_tensor * ggml_cont_2d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1);\n\n    GGML_API struct ggml_tensor * ggml_cont_3d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1,\n            int64_t               ne2);\n\n    GGML_API struct ggml_tensor * ggml_cont_4d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1,\n            int64_t               ne2,\n            int64_t               ne3);\n\n    // return view(a), b specifies the new shape\n    // TODO: when we start computing gradient, make a copy instead of view\n    GGML_API struct ggml_tensor * ggml_reshape(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    // return view(a)\n    // TODO: when we start computing gradient, make a copy instead of view\n    GGML_API struct ggml_tensor * ggml_reshape_1d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0);\n\n    GGML_API struct ggml_tensor * ggml_reshape_2d(\n            struct 
ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1);\n\n    // return view(a)\n    // TODO: when we start computing gradient, make a copy instead of view\n    GGML_API struct ggml_tensor * ggml_reshape_3d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1,\n            int64_t               ne2);\n\n    GGML_API struct ggml_tensor * ggml_reshape_4d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1,\n            int64_t               ne2,\n            int64_t               ne3);\n\n    // offset in bytes\n    GGML_API struct ggml_tensor * ggml_view_1d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            size_t                offset);\n\n    GGML_API struct ggml_tensor * ggml_view_2d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1,\n            size_t                nb1, // row stride in bytes\n            size_t                offset);\n\n    GGML_API struct ggml_tensor * ggml_view_3d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1,\n            int64_t               ne2,\n            size_t                nb1, // row   stride in bytes\n            size_t                nb2, // slice stride in bytes\n            size_t                offset);\n\n    GGML_API struct ggml_tensor * ggml_view_4d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1,\n            int64_t           
    ne2,\n            int64_t               ne3,\n            size_t                nb1, // row   stride in bytes\n            size_t                nb2, // slice stride in bytes\n            size_t                nb3,\n            size_t                offset);\n\n    GGML_API struct ggml_tensor * ggml_permute(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   axis0,\n            int                   axis1,\n            int                   axis2,\n            int                   axis3);\n\n    // alias for ggml_permute(ctx, a, 1, 0, 2, 3)\n    GGML_API struct ggml_tensor * ggml_transpose(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // supports 4D a:\n    // a     [n_embd, ne1, ne2, ne3]\n    // b I32 [n_rows, ne2, ne3, 1]\n    //\n    // return [n_embd, n_rows, ne2, ne3]\n    GGML_API struct ggml_tensor * ggml_get_rows(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,  // data\n            struct ggml_tensor  * b); // row indices\n\n    GGML_API struct ggml_tensor * ggml_get_rows_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,  // gradients of ggml_get_rows result\n            struct ggml_tensor  * b,  // row indices\n            struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape\n\n    // a TD  [n_embd, ne1,    ne2,    ne3]\n    // b TS  [n_embd, n_rows, ne02,   ne03] | ne02 == ne2, ne03 == ne3\n    // c I64 [n_rows, ne11,   ne12,   1]    | c[i] in [0, ne1)\n    //\n    // undefined behavior if destination rows overlap\n    //\n    // broadcast:\n    //   ne2 % ne11 == 0\n    //   ne3 % ne12 == 0\n    //\n    // return view(a)\n    GGML_API struct ggml_tensor * ggml_set_rows(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,  // destination\n            struct ggml_tensor  * b,  // source\n            struct ggml_tensor  * c); // row 
indices\n\n    GGML_API struct ggml_tensor * ggml_diag(\n        struct ggml_context     * ctx,\n        struct ggml_tensor      * a);\n\n    // set elements above the diagonal to -INF\n    GGML_API struct ggml_tensor * ggml_diag_mask_inf(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   n_past);\n\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   n_past);\n\n    // set elements above the diagonal to 0\n    GGML_API struct ggml_tensor * ggml_diag_mask_zero(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   n_past);\n\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   n_past);\n\n    GGML_API struct ggml_tensor * ggml_soft_max(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_soft_max_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a);\n\n    // a    [ne0, ne01, ne02, ne03]\n    // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional\n    //\n    // broadcast:\n    //   ne02 % ne12 == 0\n    //   ne03 % ne13 == 0\n    //\n    // fused soft_max(a*scale + mask*(ALiBi slope))\n    // max_bias = 0.0f for no ALiBi\n    GGML_API struct ggml_tensor * ggml_soft_max_ext(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * mask,\n            float                 scale,\n            float                 max_bias);\n\n    GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(\n            struct ggml_context * ctx,\n        
    struct ggml_tensor  * a,\n            struct ggml_tensor  * mask,\n            float                 scale,\n            float                 max_bias);\n\n    GGML_API void ggml_soft_max_add_sinks(\n            struct ggml_tensor * a,\n            struct ggml_tensor * sinks);\n\n    GGML_API struct ggml_tensor * ggml_soft_max_ext_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            float                 scale,\n            float                 max_bias);\n\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            float                 scale,\n            float                 max_bias);\n\n    // rotary position embedding\n    // if (mode & 1) - skip n_past elements (NOT SUPPORTED)\n    // if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style\n    //\n    // b is an int32 vector with size a->ne[2], it contains the positions\n    GGML_API struct ggml_tensor * ggml_rope(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            int                   n_dims,\n            int                   mode);\n\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_rope_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            int                   n_dims,\n            int                   mode);\n\n    // custom RoPE\n    // c is freq factors (e.g. 
phi3-128k), (optional)\n    GGML_API struct ggml_tensor * ggml_rope_ext(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * c,\n            int                   n_dims,\n            int                   mode,\n            int                   n_ctx_orig,\n            float                 freq_base,\n            float                 freq_scale,\n            float                 ext_factor,\n            float                 attn_factor,\n            float                 beta_fast,\n            float                 beta_slow);\n\n    GGML_API struct ggml_tensor * ggml_rope_multi(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * c,\n            int                   n_dims,\n            int                   sections[GGML_MROPE_SECTIONS],\n            int                   mode,\n            int                   n_ctx_orig,\n            float                 freq_base,\n            float                 freq_scale,\n            float                 ext_factor,\n            float                 attn_factor,\n            float                 beta_fast,\n            float                 beta_slow);\n\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_rope_ext_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * c,\n            int                   n_dims,\n            int                   mode,\n            int                   n_ctx_orig,\n            float                 freq_base,\n            float                 freq_scale,\n            float                 ext_factor,\n            float                 attn_factor,\n            float                 beta_fast,\n            float                 beta_slow);\n\n    
GGML_API struct ggml_tensor * ggml_rope_multi_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * c,\n            int                   n_dims,\n            int                   sections[GGML_MROPE_SECTIONS],\n            int                   mode,\n            int                   n_ctx_orig,\n            float                 freq_base,\n            float                 freq_scale,\n            float                 ext_factor,\n            float                 attn_factor,\n            float                 beta_fast,\n            float                 beta_slow);\n\n    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            int                   n_dims,\n            int                   mode,\n            int                   n_ctx_orig,\n            float                 freq_base,\n            float                 freq_scale,\n            float                 ext_factor,\n            float                 attn_factor,\n            float                 beta_fast,\n            float                 beta_slow),\n        \"use ggml_rope_ext instead\");\n\n    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            int                   n_dims,\n            int                   mode,\n            int                   n_ctx_orig,\n            float                 freq_base,\n            float                 freq_scale,\n            float                 ext_factor,\n            float                 attn_factor,\n            float                 beta_fast,\n            float                 beta_slow),\n        \"use ggml_rope_ext_inplace instead\");\n\n    // 
compute correction dims for YaRN RoPE scaling\n    GGML_API void ggml_rope_yarn_corr_dims(\n        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);\n\n    // rotary position embedding backward, i.e compute dx from dy\n    // a - dy\n    GGML_API struct ggml_tensor * ggml_rope_ext_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a, // gradients of ggml_rope result\n            struct ggml_tensor  * b, // positions\n            struct ggml_tensor  * c, // freq factors\n            int                   n_dims,\n            int                   mode,\n            int                   n_ctx_orig,\n            float                 freq_base,\n            float                 freq_scale,\n            float                 ext_factor,\n            float                 attn_factor,\n            float                 beta_fast,\n            float                 beta_slow);\n\n    GGML_API struct ggml_tensor * ggml_rope_multi_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * c,\n            int                   n_dims,\n            int                   sections[4],\n            int                   mode,\n            int                   n_ctx_orig,\n            float                 freq_base,\n            float                 freq_scale,\n            float                 ext_factor,\n            float                 attn_factor,\n            float                 beta_fast,\n            float                 beta_slow);\n\n\n    // clamp\n    // in-place, returns view(a)\n    GGML_API struct ggml_tensor * ggml_clamp(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 min,\n            float                 max);\n\n    // im2col\n    // converts data into a format that effectively results in a convolution when 
combined with matrix multiplication\n    GGML_API struct ggml_tensor * ggml_im2col(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,  // convolution kernel\n            struct ggml_tensor  * b,  // data\n            int                   s0, // stride dimension 0\n            int                   s1, // stride dimension 1\n            int                   p0, // padding dimension 0\n            int                   p1, // padding dimension 1\n            int                   d0, // dilation dimension 0\n            int                   d1, // dilation dimension 1\n            bool                  is_2D,\n            enum ggml_type        dst_type);\n\n    GGML_API struct ggml_tensor * ggml_im2col_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,  // convolution kernel\n        struct ggml_tensor  * b,  // gradient of im2col output\n        int64_t             * ne, // shape of im2col input\n        int                   s0, // stride dimension 0\n        int                   s1, // stride dimension 1\n        int                   p0, // padding dimension 0\n        int                   p1, // padding dimension 1\n        int                   d0, // dilation dimension 0\n        int                   d1, // dilation dimension 1\n        bool                  is_2D);\n\n    GGML_API struct ggml_tensor * ggml_conv_1d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,   // convolution kernel\n            struct ggml_tensor  * b,   // data\n            int                   s0,  // stride\n            int                   p0,  // padding\n            int                   d0); // dilation\n\n    // conv_1d with padding = half\n    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)\n    GGML_API struct ggml_tensor* ggml_conv_1d_ph(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,  // convolution kernel\n            struct ggml_tensor  * b,  // data\n  
          int                   s,  // stride\n            int                   d); // dilation\n\n    // depthwise\n    // TODO: this is very likely wrong for some cases! - needs more testing\n    GGML_API struct ggml_tensor * ggml_conv_1d_dw(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,   // convolution kernel\n            struct ggml_tensor  * b,   // data\n            int                   s0,  // stride\n            int                   p0,  // padding\n            int                   d0); // dilation\n\n    GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,   // convolution kernel\n            struct ggml_tensor  * b,   // data\n            int                   s0,  // stride\n            int                   d0); // dilation\n\n    GGML_API struct ggml_tensor * ggml_conv_transpose_1d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,   // convolution kernel\n            struct ggml_tensor  * b,   // data\n            int                   s0,  // stride\n            int                   p0,  // padding\n            int                   d0); // dilation\n\n    GGML_API struct ggml_tensor * ggml_conv_2d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,   // convolution kernel\n            struct ggml_tensor  * b,   // data\n            int                   s0,  // stride dimension 0\n            int                   s1,  // stride dimension 1\n            int                   p0,  // padding dimension 0\n            int                   p1,  // padding dimension 1\n            int                   d0,  // dilation dimension 0\n            int                   d1); // dilation dimension 1\n\n    GGML_API struct ggml_tensor * ggml_im2col_3d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            int64_t          
     IC,\n            int                   s0, // stride width\n            int                   s1, // stride height\n            int                   s2, // stride depth\n            int                   p0, // padding width\n            int                   p1, // padding height\n            int                   p2, // padding depth\n            int                   d0, // dilation width\n            int                   d1, // dilation height\n            int                   d2, // dilation depth\n            enum ggml_type        dst_type);\n\n    // a: [OC*IC, KD, KH, KW]\n    // b: [N*IC, ID, IH, IW]\n    // result: [N*OC, OD, OH, OW]\n    GGML_API struct ggml_tensor * ggml_conv_3d(\n                struct ggml_context * ctx,\n                struct ggml_tensor  * a,\n                struct ggml_tensor  * b,\n                int64_t               IC,\n                int                   s0, // stride width\n                int                   s1, // stride height\n                int                   s2, // stride depth\n                int                   p0, // padding width\n                int                   p1, // padding height\n                int                   p2, // padding depth\n                int                   d0, // dilation width\n                int                   d1, // dilation height\n                int                   d2  // dilation depth\n        );\n\n    // kernel size is a->ne[0] x a->ne[1]\n    // stride is equal to kernel size\n    // padding is zero\n    // example:\n    // a:     16   16    3  768\n    // b:   1024 1024    3    1\n    // res:   64   64  768    1\n    // used in sam\n    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    // kernel size is a->ne[0] x a->ne[1]\n    // stride is 1\n    // padding is half\n    // example:\n    // a:      3    3    256  256\n    // 
b:     64   64    256    1\n    // res:   64   64    256    1\n    // used in sam\n    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b);\n\n    // depthwise (via im2col and mul_mat)\n    GGML_API struct ggml_tensor * ggml_conv_2d_dw(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,  // convolution kernel\n            struct ggml_tensor  * b,  // data\n            int                  s0,  // stride dimension 0\n            int                  s1,  // stride dimension 1\n            int                  p0,  // padding dimension 0\n            int                  p1,  // padding dimension 1\n            int                  d0,  // dilation dimension 0\n            int                  d1); // dilation dimension 1\n\n    // Depthwise 2D convolution\n    // may be faster than ggml_conv_2d_dw, but not available in all backends\n    // a:   KW    KH    1    C    convolution kernel\n    // b:   W     H     C    N    input data\n    // res: W_out H_out C    N\n    GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            int                   stride0,\n            int                   stride1,\n            int                   pad0,\n            int                   pad1,\n            int                   dilation0,\n            int                   dilation1);\n\n    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            int                   stride);\n\n    GGML_API struct ggml_tensor * ggml_conv_2d_direct(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,   // convolution kernel [KW, KH, IC, OC]\n            struct ggml_tensor  * b,   // 
input data [W, H, C, N]\n            int                   s0,  // stride dimension 0\n            int                   s1,  // stride dimension 1\n            int                   p0,  // padding dimension 0\n            int                   p1,  // padding dimension 1\n            int                   d0,  // dilation dimension 0\n            int                   d1); // dilation dimension 1\n\n    GGML_API struct ggml_tensor * ggml_conv_3d_direct(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,   // kernel [KW, KH, KD, IC * OC]\n            struct ggml_tensor  * b,   // input  [W, H, D, C * N]\n            int                   s0,  // stride\n            int                   s1,\n            int                   s2,\n            int                   p0,  // padding\n            int                   p1,\n            int                   p2,\n            int                   d0,  // dilation\n            int                   d1,\n            int                   d2,\n            int                   n_channels,\n            int                   n_batch,\n            int                   n_channels_out);\n\n    enum ggml_op_pool {\n        GGML_OP_POOL_MAX,\n        GGML_OP_POOL_AVG,\n        GGML_OP_POOL_COUNT,\n    };\n\n    GGML_API struct ggml_tensor * ggml_pool_1d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            enum ggml_op_pool     op,\n            int                   k0, // kernel size\n            int                   s0, // stride\n            int                   p0); // padding\n\n    // the result will have 2*p0 padding for the first dimension\n    // and 2*p1 padding for the second dimension\n    GGML_API struct ggml_tensor * ggml_pool_2d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            enum ggml_op_pool     op,\n            int                   k0,\n            int                   k1,\n            int             
      s0,\n            int                   s1,\n            float                 p0,\n            float                 p1);\n\n    GGML_API struct ggml_tensor * ggml_pool_2d_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * af, // \"a\"/input used in forward pass\n            enum ggml_op_pool     op,\n            int                   k0,\n            int                   k1,\n            int                   s0,\n            int                   s1,\n            float                 p0,\n            float                 p1);\n\n    enum ggml_scale_mode {\n        GGML_SCALE_MODE_NEAREST  = 0,\n        GGML_SCALE_MODE_BILINEAR = 1,\n        GGML_SCALE_MODE_BICUBIC  = 2,\n\n        GGML_SCALE_MODE_COUNT\n    };\n\n    enum ggml_scale_flag {\n        GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8),\n        GGML_SCALE_FLAG_ANTIALIAS     = (1 << 9),\n    };\n\n    // interpolate\n    // multiplies ne0 and ne1 by scale factor\n    GGML_API struct ggml_tensor * ggml_upscale(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   scale_factor,\n            enum ggml_scale_mode  mode);\n\n    // interpolate\n    // interpolate scale to specified dimensions\n    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_upscale_ext(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   ne0,\n            int                   ne1,\n            int                   ne2,\n            int                   ne3,\n            enum ggml_scale_mode  mode),\n        \"use ggml_interpolate instead\");\n\n    // Up- or downsamples the input to the specified size.\n    // 2D scale modes (eg. 
bilinear) are applied to the first two dimensions.\n    GGML_API struct ggml_tensor * ggml_interpolate(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int64_t               ne0,\n            int64_t               ne1,\n            int64_t               ne2,\n            int64_t               ne3,\n            uint32_t              mode); // ggml_scale_mode [ | ggml_scale_flag...]\n\n    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]\n    GGML_API struct ggml_tensor * ggml_pad(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                  p0,\n            int                  p1,\n            int                  p2,\n            int                  p3);\n\n    // pad each dimension with values on the other side of the torus (looping around)\n    GGML_API struct ggml_tensor * ggml_pad_circular(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   p0,\n            int                   p1,\n            int                   p2,\n            int                   p3);\n\n    GGML_API struct ggml_tensor * ggml_pad_ext(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                  lp0,\n            int                  rp0,\n            int                  lp1,\n            int                  rp1,\n            int                  lp2,\n            int                  rp2,\n            int                  lp3,\n            int                  rp3\n            );\n\n    // pad each dimension with values on the other side of the torus (looping around)\n    GGML_API struct ggml_tensor * ggml_pad_ext_circular(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   lp0,\n            int                   rp0,\n            int                   lp1,\n            int                   
rp1,\n            int                   lp2,\n            int                   rp2,\n            int                   lp3,\n            int                   rp3);\n\n    // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]\n    GGML_API struct ggml_tensor * ggml_pad_reflect_1d(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   p0,\n            int                   p1);\n\n    // Move tensor elements by an offset given for each dimension. Elements that\n    // are shifted beyond the last position are wrapped around to the beginning.\n    GGML_API struct ggml_tensor * ggml_roll(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   shift0,\n            int                   shift1,\n            int                   shift2,\n            int                   shift3);\n\n    // Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing\n    // zeroes everywhere outside the masked area\n    GGML_API struct ggml_tensor * ggml_tri(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            enum ggml_tri_type    type);\n\n    // Fill tensor a with constant c\n    GGML_API struct ggml_tensor * ggml_fill(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 c);\n\n    GGML_API struct ggml_tensor * ggml_fill_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            float                 c);\n\n    // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151\n    // timesteps: [N,]\n    // return: [N, dim]\n    GGML_API struct ggml_tensor * ggml_timestep_embedding(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * timesteps,\n            int                   dim,\n            int                   
max_period);\n\n    // sort rows\n    enum ggml_sort_order {\n        GGML_SORT_ORDER_ASC,\n        GGML_SORT_ORDER_DESC,\n    };\n\n    GGML_API struct ggml_tensor * ggml_argsort(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            enum ggml_sort_order  order);\n\n    // similar to ggml_top_k but implemented as `argsort` + `view`\n    GGML_API struct ggml_tensor * ggml_argsort_top_k(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   k);\n\n    // top k elements per row\n    // note: the resulting top k indices are in no particular order\n    GGML_API struct ggml_tensor * ggml_top_k(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   k);\n\n    GGML_API struct ggml_tensor * ggml_arange(\n            struct ggml_context * ctx,\n            float                 start,\n            float                 stop,\n            float                 step);\n\n    // q:    [n_embd_k, n_batch, n_head,    ne3 ]\n    // k:    [n_embd_k, n_kv,    n_head_kv, ne3 ]\n    // v:    [n_embd_v, n_kv,    n_head_kv, ne3 ] !! not transposed !!\n    // mask: [n_kv,     n_batch, ne32,      ne33]\n    // res:  [n_embd_v, n_head,  n_batch,   ne3 ] !! 
permuted !!\n    //\n    // broadcast:\n    //   n_head % n_head_kv == 0\n    //   n_head % ne32      == 0\n    //   ne3    % ne33      == 0\n    //\n    GGML_API struct ggml_tensor * ggml_flash_attn_ext(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * q,\n            struct ggml_tensor  * k,\n            struct ggml_tensor  * v,\n            struct ggml_tensor  * mask,\n            float                 scale,\n            float                 max_bias,\n            float                 logit_softcap);\n\n    GGML_API void ggml_flash_attn_ext_set_prec(\n            struct ggml_tensor * a,\n            enum ggml_prec       prec);\n\n    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(\n            const struct ggml_tensor * a);\n\n    GGML_API void ggml_flash_attn_ext_add_sinks(\n            struct ggml_tensor * a,\n            struct ggml_tensor * sinks);\n\n    // TODO: needs to be adapted to ggml_flash_attn_ext\n    GGML_API struct ggml_tensor * ggml_flash_attn_back(\n           struct ggml_context * ctx,\n           struct ggml_tensor  * q,\n           struct ggml_tensor  * k,\n           struct ggml_tensor  * v,\n           struct ggml_tensor  * d,\n           bool                  masked);\n\n    GGML_API struct ggml_tensor * ggml_ssm_conv(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * sx,\n            struct ggml_tensor  * c);\n\n    GGML_API struct ggml_tensor * ggml_ssm_scan(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * s,\n            struct ggml_tensor  * x,\n            struct ggml_tensor  * dt,\n            struct ggml_tensor  * A,\n            struct ggml_tensor  * B,\n            struct ggml_tensor  * C,\n            struct ggml_tensor  * ids);\n\n    // partition into non-overlapping windows with padding if needed\n    // example:\n    // a:   768   64   64    1\n    // w:    14\n    // res: 768   14   14    25\n    // used in sam\n    GGML_API struct 
ggml_tensor * ggml_win_part(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   w);\n\n    // reverse of ggml_win_part\n    // used in sam\n    GGML_API struct ggml_tensor * ggml_win_unpart(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   w0,\n            int                   h0,\n            int                   w);\n\n    GGML_API struct ggml_tensor * ggml_unary(\n            struct ggml_context * ctx,\n             struct ggml_tensor * a,\n             enum ggml_unary_op op);\n\n    GGML_API struct ggml_tensor * ggml_unary_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        enum ggml_unary_op op);\n\n    // used in sam\n    GGML_API struct ggml_tensor * ggml_get_rel_pos(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                   qh,\n            int                   kh);\n\n    // used in sam\n    GGML_API struct ggml_tensor * ggml_add_rel_pos(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * pw,\n            struct ggml_tensor  * ph);\n\n    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * pw,\n            struct ggml_tensor  * ph);\n\n    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * k,\n            struct ggml_tensor  * v,\n            struct ggml_tensor  * r,\n            struct ggml_tensor  * tf,\n            struct ggml_tensor  * td,\n            struct ggml_tensor  * state);\n\n    GGML_API struct ggml_tensor * ggml_gated_linear_attn(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * k,\n            struct ggml_tensor  * v,\n            struct ggml_tensor  * 
q,\n            struct ggml_tensor  * g,\n            struct ggml_tensor  * state,\n            float scale);\n\n    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * r,\n            struct ggml_tensor  * w,\n            struct ggml_tensor  * k,\n            struct ggml_tensor  * v,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * state);\n\n    /* Solves a specific equation of the form Ax=B, where A is a triangular matrix\n    *  without zeroes on the diagonal (i.e. invertible).\n    *  B can have any number of columns, but must have the same number of rows as A\n    *  If A is [n, n] and B is [n, m], then the result will be [n, m] as well\n    *  Has O(n^3) complexity (unlike most matrix ops out there), so use on cases\n    *  where n > 100 sparingly, pre-chunk if necessary.\n    *\n    *  If left = false, solves xA=B instead\n    *  If lower = false, assumes upper triangular instead\n    *  If uni = true, assumes diagonal of A to be all ones (will override actual values)\n    *\n    *  TODO: currently only lower, right, non-unitriangular variant is implemented\n    */\n    GGML_API struct ggml_tensor * ggml_solve_tri(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        bool                  left,\n        bool                  lower,\n        bool                  uni);\n\n    // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]\n    // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306\n    GGML_API struct ggml_tensor * ggml_gated_delta_net(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * q,\n            struct ggml_tensor  * k,\n            struct ggml_tensor  * v,\n            struct ggml_tensor  * g,\n            struct ggml_tensor  * 
beta,\n            struct ggml_tensor  * state);\n\n    // custom operators\n\n    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);\n    typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);\n    typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);\n\n#define GGML_N_TASKS_MAX (-1)\n    // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks\n\n    GGML_API struct ggml_tensor * ggml_map_custom1(\n            struct ggml_context   * ctx,\n            struct ggml_tensor    * a,\n            ggml_custom1_op_t       fun,\n            int                     n_tasks,\n            void                  * userdata);\n\n    GGML_API struct ggml_tensor * ggml_map_custom1_inplace(\n            struct ggml_context   * ctx,\n            struct ggml_tensor    * a,\n            ggml_custom1_op_t       fun,\n            int                     n_tasks,\n            void                  * userdata);\n\n    GGML_API struct ggml_tensor * ggml_map_custom2(\n            struct ggml_context   * ctx,\n            struct ggml_tensor    * a,\n            struct ggml_tensor    * b,\n            ggml_custom2_op_t       fun,\n            int                     n_tasks,\n            void                  * userdata);\n\n    GGML_API struct ggml_tensor * ggml_map_custom2_inplace(\n            struct ggml_context   * ctx,\n            struct ggml_tensor    * a,\n            struct ggml_tensor    * b,\n            ggml_custom2_op_t       fun,\n            int                     n_tasks,\n            void                  * userdata);\n\n    GGML_API struct ggml_tensor * ggml_map_custom3(\n            struct ggml_context   * ctx,\n            struct ggml_tensor    * a,\n        
    struct ggml_tensor    * b,\n            struct ggml_tensor    * c,\n            ggml_custom3_op_t       fun,\n            int                     n_tasks,\n            void                  * userdata);\n\n    GGML_API struct ggml_tensor * ggml_map_custom3_inplace(\n            struct ggml_context   * ctx,\n            struct ggml_tensor    * a,\n            struct ggml_tensor    * b,\n            struct ggml_tensor    * c,\n            ggml_custom3_op_t       fun,\n            int                     n_tasks,\n            void                  * userdata);\n\n    typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);\n\n    GGML_API struct ggml_tensor * ggml_custom_4d(\n            struct ggml_context * ctx,\n            enum ggml_type        type,\n            int64_t               ne0,\n            int64_t               ne1,\n            int64_t               ne2,\n            int64_t               ne3,\n            struct ggml_tensor ** args,\n            int                   n_args,\n            ggml_custom_op_t      fun,\n            int                   n_tasks,\n            void                * userdata);\n\n    GGML_API struct ggml_tensor * ggml_custom_inplace(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor ** args,\n            int                   n_args,\n            ggml_custom_op_t      fun,\n            int                   n_tasks,\n            void                * userdata);\n\n    // loss function\n\n    GGML_API struct ggml_tensor * ggml_cross_entropy_loss(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,  // logits\n            struct ggml_tensor  * b); // labels\n\n    GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,  // logits\n            struct ggml_tensor  * b,  // labels\n            struct ggml_tensor  * c); // 
gradients of cross_entropy_loss result\n\n    // AdamW optimizer step\n    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf\n    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html\n    GGML_API struct ggml_tensor * ggml_opt_step_adamw(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * grad,\n            struct ggml_tensor  * m,\n            struct ggml_tensor  * v,\n            struct ggml_tensor  * adamw_params); // parameters such as the learning rate\n\n    // stochastic gradient descent step (with weight decay)\n    GGML_API struct ggml_tensor * ggml_opt_step_sgd(\n        struct ggml_context * ctx,\n        struct ggml_tensor *  a,\n        struct ggml_tensor *  grad,\n        struct ggml_tensor *  sgd_params); // alpha, weight decay\n\n    // build forward multiple tensors and select one of them for computing\n    // this is useful for creating graphs that have constant topology but compute different things based on the input\n    // ref: https://github.com/ggml-org/llama.cpp/pull/18550\n    //\n    // nodes:\n    //   | - build forward into the graph but do not compute\n    //   c - build forward into the graph and compute\n    //\n    //    |  |  ...  c  ...  |\n    //    |  |  ...  c  ...  |\n    //    |  |  ...  c  ...  |\n    //   [0  1  ... idx ...  
n-1]        <-- ggml_build_forward_select(..., n, idx)\n    //               c\n    //               c\n    //\n    // example:\n    //   struct ggml_tensor * curs[3];\n    //\n    //   curs[0]  = compute0(...);\n    //   curs[1]  = compute1(...);\n    //   curs[2]  = compute2(...);\n    //\n    //   int idx = select_branch(some_input);\n    //\n    //   struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);\n    //\n    GGML_API struct ggml_tensor * ggml_build_forward_select(\n            struct ggml_cgraph  * cgraph,\n            struct ggml_tensor ** tensors,\n            int                   n_tensors,\n            int                   idx);\n\n    GGML_API void ggml_build_forward_expand(\n            struct ggml_cgraph * cgraph,\n            struct ggml_tensor * tensor);\n\n    GGML_API void ggml_build_backward_expand(\n        struct ggml_context *  ctx,        // context for gradient computation\n        struct ggml_cgraph  *  cgraph,\n        struct ggml_tensor  ** grad_accs);\n\n    // graph allocation in a context\n    GGML_API struct ggml_cgraph * ggml_new_graph       (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false\n    GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);\n    GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);\n    GGML_API void                 ggml_graph_cpy       (struct ggml_cgraph * src, struct ggml_cgraph * dst);\n    GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1\n    GGML_API void                 ggml_graph_clear     (struct ggml_cgraph * cgraph);\n\n    GGML_API int                   ggml_graph_size   (struct ggml_cgraph * cgraph);\n    GGML_API struct ggml_tensor *  ggml_graph_node   (struct ggml_cgraph * cgraph, int i); // if i < 0, returns 
nodes[n_nodes + i]\n    GGML_API struct ggml_tensor ** ggml_graph_nodes  (struct ggml_cgraph * cgraph);\n    GGML_API int                   ggml_graph_n_nodes(struct ggml_cgraph * cgraph);\n\n    GGML_API void   ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);\n\n    GGML_API size_t ggml_graph_overhead(void);\n    GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);\n\n    GGML_API struct ggml_tensor * ggml_graph_get_tensor  (const struct ggml_cgraph * cgraph, const char * name);\n    GGML_API struct ggml_tensor * ggml_graph_get_grad    (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);\n    GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);\n\n    // print info and performance information for the graph\n    GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);\n\n    // dump the graph into a file using the dot format\n    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);\n\n    // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?\n    typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);\n\n    // Set callback for all future logging events.\n    // If this is not called, or NULL is supplied, everything is output on stderr.\n    GGML_API void ggml_log_get(ggml_log_callback * log_callback, void ** user_data);\n    GGML_API void ggml_log_set(ggml_log_callback   log_callback, void *  user_data);\n\n    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);\n\n    //\n    // quantization\n    //\n\n    // - ggml_quantize_init can be called multiple times with the same type\n    //   it will only initialize the quantization tables for the first call or after ggml_quantize_free\n    //   automatically called by ggml_quantize_chunk 
for convenience\n    //\n    // - ggml_quantize_free will free any memory allocated by ggml_quantize_init\n    //   call this at the end of the program to avoid memory leaks\n    //\n    // note: these are thread-safe\n    //\n    GGML_API void ggml_quantize_init(enum ggml_type type);\n    GGML_API void ggml_quantize_free(void);\n\n    // some quantization type cannot be used without an importance matrix\n    GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);\n\n    // calls ggml_quantize_init internally (i.e. can allocate memory)\n    GGML_API size_t ggml_quantize_chunk(\n            enum ggml_type   type,\n               const float * src,\n                      void * dst,\n                   int64_t   start,\n                   int64_t   nrows,\n                   int64_t   n_per_row,\n               const float * imatrix);\n\n#ifdef __cplusplus\n    // restrict not standard in C++\n#    if defined(__GNUC__)\n#        define GGML_RESTRICT __restrict__\n#    elif defined(__clang__)\n#        define GGML_RESTRICT __restrict\n#    elif defined(_MSC_VER)\n#        define GGML_RESTRICT __restrict\n#    else\n#        define GGML_RESTRICT\n#    endif\n#else\n#    if defined (_MSC_VER) && (__STDC_VERSION__ < 201112L)\n#        define GGML_RESTRICT __restrict\n#    else\n#        define GGML_RESTRICT restrict\n#    endif\n#endif\n    typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\n    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);\n\n    struct ggml_type_traits {\n        const char             * type_name;\n        int64_t                  blck_size;\n        int64_t                  blck_size_interleave; // interleave elements in blocks\n        size_t                   type_size;\n        bool                     is_quantized;\n        ggml_to_float_t          to_float;\n        ggml_from_float_t        from_float_ref;\n    };\n\n    GGML_API 
const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type);\n\n    // ggml threadpool\n    // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend\n    // the goal should be to create an API that other backends can use and to move everything to the ggml base\n\n    // scheduling priorities\n    enum ggml_sched_priority {\n        GGML_SCHED_PRIO_LOW = -1,\n        GGML_SCHED_PRIO_NORMAL,\n        GGML_SCHED_PRIO_MEDIUM,\n        GGML_SCHED_PRIO_HIGH,\n        GGML_SCHED_PRIO_REALTIME\n    };\n\n    // threadpool params\n    // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults\n    struct ggml_threadpool_params {\n        bool                cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)\n        int                 n_threads;                   // number of threads\n        enum ggml_sched_priority prio;                   // thread priority\n        uint32_t            poll;                        // polling level (0 - no polling, 100 - aggressive polling)\n        bool                strict_cpu;                  // strict cpu placement\n        bool                paused;                      // start in paused state\n    };\n\n    struct ggml_threadpool;     // forward declaration, see ggml.c\n\n    typedef struct ggml_threadpool * ggml_threadpool_t;\n\n    GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);\n    GGML_API void                          ggml_threadpool_params_init   (struct ggml_threadpool_params * p, int n_threads);\n    GGML_API bool                          ggml_threadpool_params_match  (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "include/gguf.h",
    "content": "// This file contains functionality related to \"GGUF\" files, the binary file format used by ggml.\n// GGUF files have the following structure:\n//\n// 1. File magic \"GGUF\" (4 bytes).\n// 2. File version (uint32_t).\n// 3. Number of ggml tensors in file (int64_t).\n// 4. Number of key-value-pairs in file (int64_t).\n// 5. For each KV pair:\n//   1. The key (string).\n//   2. The value type (gguf_type).\n//   3a. If the value type is GGUF_TYPE_ARRAY:\n//     1. The type of the array (gguf_type).\n//     2. The number of elements in the array (uint64_t).\n//     3. The binary representation of each element in the array.\n//   3b. Otherwise:\n//     1. The binary representation of the value.\n// 6. For each ggml tensor:\n//   1. The tensor name (string).\n//   2. The number of dimensions of the tensor (uint32_t).\n//   3. For each dimension:\n//     1. The size of the tensor in the dimension (int64_t).\n//   4. The tensor data type (ggml_type).\n//   5. The tensor data offset in the tensor data binary blob (uint64_t).\n// 7. 
The tensor data binary blob (optional, aligned).\n//\n// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator.\n// All enums are stored as int32_t.\n// All bool values are stored as int8_t.\n// If the special key \"general.alignment\" (uint32_t) is defined it is used for alignment,\n//   otherwise GGUF_DEFAULT_ALIGNMENT is used.\n//\n// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)\n\n#pragma once\n\n#include \"ggml.h\"\n\n#include <stdbool.h>\n#include <stdint.h>\n\n#define GGUF_MAGIC   \"GGUF\"\n#define GGUF_VERSION 3\n\n#define GGUF_KEY_GENERAL_ALIGNMENT \"general.alignment\"\n\n#define GGUF_DEFAULT_ALIGNMENT 32\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n    // types that can be stored as GGUF KV data\n    enum gguf_type {\n        GGUF_TYPE_UINT8   = 0,\n        GGUF_TYPE_INT8    = 1,\n        GGUF_TYPE_UINT16  = 2,\n        GGUF_TYPE_INT16   = 3,\n        GGUF_TYPE_UINT32  = 4,\n        GGUF_TYPE_INT32   = 5,\n        GGUF_TYPE_FLOAT32 = 6,\n        GGUF_TYPE_BOOL    = 7,\n        GGUF_TYPE_STRING  = 8,\n        GGUF_TYPE_ARRAY   = 9,\n        GGUF_TYPE_UINT64  = 10,\n        GGUF_TYPE_INT64   = 11,\n        GGUF_TYPE_FLOAT64 = 12,\n        GGUF_TYPE_COUNT,       // marks the end of the enum\n    };\n\n    struct gguf_context;\n\n    struct gguf_init_params {\n        bool no_alloc;\n\n        // if not NULL, create a ggml_context and allocate the tensor data in it\n        struct ggml_context ** ctx;\n    };\n\n    GGML_API struct gguf_context * gguf_init_empty(void);\n    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);\n    //GGML_API struct gguf_context * gguf_init_from_buffer(..);\n\n    GGML_API void gguf_free(struct gguf_context * ctx);\n\n    GGML_API const char * gguf_type_name(enum gguf_type type);\n\n    GGML_API uint32_t gguf_get_version    (const struct gguf_context * ctx);\n    GGML_API size_t   
gguf_get_alignment  (const struct gguf_context * ctx);\n    GGML_API size_t   gguf_get_data_offset(const struct gguf_context * ctx);\n\n    GGML_API int64_t      gguf_get_n_kv(const struct gguf_context * ctx);\n    GGML_API int64_t      gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found\n    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id);\n\n    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id);\n\n    // will abort if the wrong type is used for the key\n    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id);\n    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id);\n    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id);\n    GGML_API size_t       gguf_get_arr_n   (const struct gguf_context * ctx, 
int64_t key_id);\n\n    // get raw pointer to the first element of the array with the given key_id\n    // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)\n    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);\n\n    // get ith C string from array with given key_id\n    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);\n\n    GGML_API int64_t        gguf_get_n_tensors    (const struct gguf_context * ctx);\n    GGML_API int64_t        gguf_find_tensor      (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found\n    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id);\n    GGML_API const char *   gguf_get_tensor_name  (const struct gguf_context * ctx, int64_t tensor_id);\n    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int64_t tensor_id);\n    GGML_API size_t         gguf_get_tensor_size  (const struct gguf_context * ctx, int64_t tensor_id);\n\n    // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist)\n    GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key);\n\n    // overrides an existing KV pair or adds a new one, the new KV pair is always at the back\n    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t      val);\n    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t       val);\n    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t     val);\n    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t      val);\n    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t     val);\n    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, 
int32_t      val);\n    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float        val);\n    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t     val);\n    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t      val);\n    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double       val);\n    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool         val);\n    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);\n\n    // creates a new array with n elements of the given type and copies the corresponding number of bytes from data\n    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n);\n\n    // creates a new array with n strings and copies the corresponding strings from data\n    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n);\n\n    // set or add KV pairs from another context\n    GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src);\n\n    // add tensor to GGUF context, tensor name must be unique\n    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);\n\n    // after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated\n    //   in such a way that the tensor data remains as one contiguous block (except for padding)\n    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);\n\n    // assumes that at least gguf_get_tensor_size bytes can be read from data\n    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data);\n\n    // writing gguf files can be done in 3 ways:\n    //\n    // - write the entire gguf_context 
to a binary file in a single pass:\n    //\n    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ false);\n    //\n    // - write only the meta data to a file, then re-open the file and append the tensor data:\n    //\n    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ true);\n    //   FILE * f = fopen(fname, \"ab\");\n    //   fwrite(..., f); // write tensor data\n    //   fclose(f);\n    //\n    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:\n    //\n    //   FILE * f = fopen(fname, \"wb\");\n    //   const size_t size_meta = gguf_get_meta_size(ctx);\n    //   fseek(f, size_meta, SEEK_SET);\n    //   fwrite(..., f); // write tensor data\n    //   void * data = malloc(size_meta);\n    //   gguf_get_meta_data(ctx, data);\n    //   rewind(f);\n    //   fwrite(data, 1, size_meta, f);\n    //   free(data);\n    //   fclose(f);\n    //\n\n    // write the entire context to a binary file\n    GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);\n\n    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding\n    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);\n\n    // writes the meta data to pointer \"data\"\n    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "requirements.txt",
    "content": "accelerate==0.19.0\nnumpy>=2.0.2\nsentencepiece~=0.1.98\ntorchvision>=0.15.2\ntransformers>=4.35.2,<5.0.0\ngguf>=0.1.0\nkeras==3.5.0\ntensorflow==2.18.0\n\n--extra-index-url https://download.pytorch.org/whl/cpu\ntorch~=2.5.1\n"
  },
  {
    "path": "scripts/gen-authors.sh",
    "content": "#!/usr/bin/env bash\n\nprintf \"# date: $(date)\\n\" > AUTHORS\nprintf \"# this file is auto-generated by scripts/gen-authors.sh\\n\\n\" >> AUTHORS\n\ngit log --format='%an <%ae>' --reverse --date=short master | awk '!seen[$0]++' | sort >> AUTHORS\n\n# if necessary, update your name here. for example: jdoe -> John Doe\nsed -i '' 's/^jdoe/John Doe/g' AUTHORS\n"
  },
  {
    "path": "scripts/release.sh",
    "content": "#!/bin/bash\n#\n# Automated release script for ggml.\n#\n# Note: Sync from llama.cpp should be done separately via PR process\n# prior to running this script.\n#\n# Usage:\n#   ./scripts/release.sh prepare [major|minor|patch] [--dry-run]\n#   ./scripts/release.sh finalize [--dry-run]\n#\n# Two-stage release process:\n#\n# Stage 1 - Prepare:\n# $ ./scripts/release.sh prepare minor\n# This creates a release candidate branch with version bump and removes -dev suffix.\n# The branch should then be manually pushed and a PR created, reviewed, and merged.\n#\n# Stage 2 - Finalize:\n# $ ./scripts/release.sh finalize\n# After the RC PR is merged, this reads the current version from CMakeLists.txt,\n# creates the release tag, and prepares the next development cycle.\n#\n# Prepare stage:\n# 1. Creates release candidate branch\n# 2. Updates version and removes -dev suffix\n# 3. Commits the version bump\n#\n# Finalize stage:\n# 1. Reads current release version from CMakeLists.txt\n# 2. Creates signed git tag on master\n# 3. Adds -dev suffix back for next development cycle\n# 4. Creates branch and commit for development version\n#\n\nset -e\n\nif [ ! -f \"CMakeLists.txt\" ] || [ ! 
-d \"scripts\" ]; then\n    echo \"Error: Must be run from ggml root directory\"\n    exit 1\nfi\n\n# Parse command line arguments\nCOMMAND=\"\"\nVERSION_TYPE=\"\"\nDRY_RUN=false\n\n# First argument should be the command\nif [ $# -eq 0 ]; then\n    echo \"Error: Missing command\"\n    echo \"Usage: $0 prepare [major|minor|patch] [--dry-run]\"\n    echo \"       $0 finalize [--dry-run]\"\n    exit 1\nfi\n\nCOMMAND=\"$1\"\nshift\n\n# Parse remaining arguments\nfor arg in \"$@\"; do\n    case $arg in\n        --dry-run)\n            DRY_RUN=true\n            ;;\n        major|minor|patch)\n            if [ \"$COMMAND\" = \"prepare\" ]; then\n                VERSION_TYPE=\"$arg\"\n            else\n                echo \"Error: Version type only valid for 'prepare' command\"\n                exit 1\n            fi\n            ;;\n        *)\n            echo \"Error: Unknown argument '$arg'\"\n            echo \"Usage: $0 prepare [major|minor|patch] [--dry-run]\"\n            echo \"       $0 finalize [--dry-run]\"\n            exit 1\n            ;;\n    esac\ndone\n\n# Validate command\nif [[ ! \"$COMMAND\" =~ ^(prepare|finalize)$ ]]; then\n    echo \"Error: Command must be 'prepare' or 'finalize'\"\n    echo \"Usage: $0 prepare [major|minor|patch] [--dry-run]\"\n    echo \"       $0 finalize [--dry-run]\"\n    exit 1\nfi\n\n# For prepare command, default to patch if no version type specified\nif [ \"$COMMAND\" = \"prepare\" ]; then\n    VERSION_TYPE=\"${VERSION_TYPE:-patch}\"\n    if [[ ! \"$VERSION_TYPE\" =~ ^(major|minor|patch)$ ]]; then\n        echo \"Error: Version type must be 'major', 'minor', or 'patch'\"\n        echo \"Usage: $0 prepare [major|minor|patch] [--dry-run]\"\n        exit 1\n    fi\nfi\n\n# Common validation functions\ncheck_git_status() {\n    # Check for uncommitted changes (skip in dry-run)\n    if [ \"$DRY_RUN\" = false ] && ! git diff-index --quiet HEAD --; then\n        echo \"Error: You have uncommitted changes. 
Please commit or stash them first.\"\n        exit 1\n    fi\n}\n\ncheck_master_branch() {\n    # Ensure we're on master branch\n    CURRENT_BRANCH=$(git branch --show-current)\n    if [ \"$CURRENT_BRANCH\" != \"master\" ]; then\n        if [ \"$DRY_RUN\" = true ]; then\n            echo \"[dry run] Warning: Not on master branch (currently on: $CURRENT_BRANCH). Continuing with dry-run...\"\n            echo \"\"\n        else\n            echo \"Error: Must be on master branch. Currently on: $CURRENT_BRANCH\"\n            exit 1\n        fi\n    fi\n}\n\ncheck_master_up_to_date() {\n    # Check if we have the latest from master (skip in dry-run)\n    if [ \"$DRY_RUN\" = false ]; then\n        echo \"Checking if local master is up-to-date with remote...\"\n        git fetch origin master\n        LOCAL=$(git rev-parse HEAD)\n        REMOTE=$(git rev-parse origin/master)\n\n        if [ \"$LOCAL\" != \"$REMOTE\" ]; then\n            echo \"Error: Your local master branch is not up-to-date with origin/master.\"\n            echo \"Please run 'git pull origin master' first.\"\n            exit 1\n        fi\n        echo \"✓ Local master is up-to-date with remote\"\n        echo \"\"\n    elif [ \"$(git branch --show-current)\" = \"master\" ]; then\n        echo \"[dry run] Warning: Dry-run mode - not checking if master is up-to-date with remote\"\n        echo \"\"\n    fi\n}\n\nprepare_release() {\n    if [ \"$DRY_RUN\" = true ]; then\n        echo \"[dry-run] Preparing release (no changes will be made)\"\n    else\n        echo \"Starting release preparation...\"\n    fi\n    echo \"\"\n\n    check_git_status\n    check_master_branch\n    check_master_up_to_date\n\n    # Extract current version from CMakeLists.txt\n    echo \"Step 1: Reading current version...\"\n    MAJOR=$(grep \"set(GGML_VERSION_MAJOR\" CMakeLists.txt | sed 's/.*MAJOR \\([0-9]*\\).*/\\1/')\n    MINOR=$(grep \"set(GGML_VERSION_MINOR\" CMakeLists.txt | sed 's/.*MINOR \\([0-9]*\\).*/\\1/')\n    
PATCH=$(grep \"set(GGML_VERSION_PATCH\" CMakeLists.txt | sed 's/.*PATCH \\([0-9]*\\).*/\\1/')\n\n    echo \"Current version: $MAJOR.$MINOR.$PATCH\"\n\n    # Calculate new version\n    case $VERSION_TYPE in\n        major)\n            NEW_MAJOR=$((MAJOR + 1))\n            NEW_MINOR=0\n            NEW_PATCH=0\n            ;;\n        minor)\n            NEW_MAJOR=$MAJOR\n            NEW_MINOR=$((MINOR + 1))\n            NEW_PATCH=0\n            ;;\n        patch)\n            NEW_MAJOR=$MAJOR\n            NEW_MINOR=$MINOR\n            NEW_PATCH=$((PATCH + 1))\n            ;;\n    esac\n\n    NEW_VERSION=\"$NEW_MAJOR.$NEW_MINOR.$NEW_PATCH\"\n    RC_BRANCH=\"ggml-rc-v$NEW_VERSION\"\n    echo \"New release version: $NEW_VERSION\"\n    echo \"Release candidate branch: $RC_BRANCH\"\n    echo \"\"\n\n    # Create release candidate branch\n    echo \"Step 2: Creating release candidate branch...\"\n    if [ \"$DRY_RUN\" = true ]; then\n        echo \"  [dry-run] Would create branch: $RC_BRANCH\"\n    else\n        git checkout -b \"$RC_BRANCH\"\n        echo \"✓ Created and switched to branch: $RC_BRANCH\"\n    fi\n    echo \"\"\n\n    # Update CMakeLists.txt for release\n    echo \"Step 3: Updating version in CMakeLists.txt...\"\n    if [ \"$DRY_RUN\" = true ]; then\n        echo \"  [dry-run] Would update GGML_VERSION_MAJOR to $NEW_MAJOR\"\n        echo \"  [dry-run] Would update GGML_VERSION_MINOR to $NEW_MINOR\"\n        echo \"  [dry-run] Would update GGML_VERSION_PATCH to $NEW_PATCH\"\n    else\n        sed -i'' -e \"s/set(GGML_VERSION_MAJOR [0-9]*)/set(GGML_VERSION_MAJOR $NEW_MAJOR)/\" CMakeLists.txt\n        sed -i'' -e \"s/set(GGML_VERSION_MINOR [0-9]*)/set(GGML_VERSION_MINOR $NEW_MINOR)/\" CMakeLists.txt\n        sed -i'' -e \"s/set(GGML_VERSION_PATCH [0-9]*)/set(GGML_VERSION_PATCH $NEW_PATCH)/\" CMakeLists.txt\n    fi\n    echo \"\"\n\n    # Commit version bump\n    echo \"Step 4: Committing version bump...\"\n    if [ \"$DRY_RUN\" = true ]; then\n        echo \" 
 [dry-run] Would commit: 'ggml : bump version to $NEW_VERSION'\"\n    else\n        git add CMakeLists.txt\n        git commit -m \"ggml : bump version to $NEW_VERSION\"\n    fi\n    echo \"\"\n\n    echo \"\"\n    if [ \"$DRY_RUN\" = true ]; then\n        echo \"[dry-run] Summary (no changes were made):\"\n        echo \"  • Would have created branch: $RC_BRANCH\"\n        echo \"  • Would have updated version to: $NEW_VERSION\"\n    else\n        echo \"Release preparation completed!\"\n        echo \"Summary:\"\n        echo \"  • Created branch: $RC_BRANCH\"\n        echo \"  • Updated version to: $NEW_VERSION\"\n        echo \"\"\n        echo \"Next steps:\"\n        echo \"  • Push branch to remote: git push origin $RC_BRANCH\"\n        echo \"  • Create a Pull Request from $RC_BRANCH to master\"\n        echo \"  • After PR is merged, run: ./scripts/release.sh finalize\"\n    fi\n}\n\nfinalize_release() {\n    if [ \"$DRY_RUN\" = true ]; then\n        echo \"[dry-run] Finalizing release (no changes will be made)\"\n    else\n        echo \"Starting release finalization...\"\n    fi\n    echo \"\"\n\n    check_git_status\n    check_master_branch\n    check_master_up_to_date\n\n    # Read current version from CMakeLists.txt\n    echo \"Step 1: Reading current release version...\"\n    MAJOR=$(grep \"set(GGML_VERSION_MAJOR\" CMakeLists.txt | sed 's/.*MAJOR \\([0-9]*\\).*/\\1/')\n    MINOR=$(grep \"set(GGML_VERSION_MINOR\" CMakeLists.txt | sed 's/.*MINOR \\([0-9]*\\).*/\\1/')\n    PATCH=$(grep \"set(GGML_VERSION_PATCH\" CMakeLists.txt | sed 's/.*PATCH \\([0-9]*\\).*/\\1/')\n\n    RELEASE_VERSION=\"$MAJOR.$MINOR.$PATCH\"\n    echo \"Release version: $RELEASE_VERSION\"\n    echo \"\"\n\n    # Create git tag\n    echo \"Step 2: Creating signed git tag...\"\n    if [ \"$DRY_RUN\" = true ]; then\n        echo \"  [dry-run] Would create signed tag: v$RELEASE_VERSION with message 'Release version $RELEASE_VERSION'\"\n    else\n        git tag -s \"v$RELEASE_VERSION\" 
-m \"Release version $RELEASE_VERSION\"\n        echo \"✓ Created signed tag: v$RELEASE_VERSION\"\n    fi\n    echo \"\"\n\n\n    echo \"\"\n    if [ \"$DRY_RUN\" = true ]; then\n        echo \"[dry-run] Summary (no changes were made):\"\n        echo \"  • Would have created tag: v$RELEASE_VERSION\"\n    else\n        echo \"Release finalization completed!\"\n        echo \"Summary:\"\n        echo \"  • Created signed tag: v$RELEASE_VERSION\"\n        echo \"\"\n        echo \"Next steps:\"\n        echo \"  • Push tag to remote: git push origin v$RELEASE_VERSION\"\n        echo \"  • The release is now complete!\"\n    fi\n}\n\n# Execute the appropriate command\ncase $COMMAND in\n    prepare)\n        prepare_release\n        ;;\n    finalize)\n        finalize_release\n        ;;\nesac\n"
  },
  {
    "path": "scripts/sync-llama-am.sh",
    "content": "#!/bin/bash\n#\n# Synchronize llama.cpp changes to ggml\n#\n# Usage:\n#\n#   $ cd /path/to/ggml\n#   $ ./scripts/sync-llama-am.sh -skip hash0,hash1,hash2... -C 3\n#\n\nset -e\n\nsd=$(dirname $0)\ncd $sd/../\n\nSRC_GGML=$(pwd)\nSRC_LLAMA=$(cd ../llama.cpp; pwd)\n\nif [ ! -d $SRC_LLAMA ]; then\n    echo \"llama.cpp not found at $SRC_LLAMA\"\n    exit 1\nfi\n\nlc=$(cat $SRC_GGML/scripts/sync-llama.last)\necho \"Syncing llama.cpp changes since commit $lc\"\n\nto_skip=\"\"\n\n# context for git patches in number of lines\nctx=\"8\"\n\nwhile [ \"$1\" != \"\" ]; do\n    case $1 in\n        -skip )\n            shift\n            to_skip=$1\n            ;;\n        -C )\n            shift\n            ctx=$1\n            ;;\n    esac\n    shift\ndone\n\ncd $SRC_LLAMA\n\ngit log --oneline $lc..HEAD\ngit log --oneline $lc..HEAD --reverse | grep -v \"(ggml/[0-9]*)\" | grep -v \"(whisper/[0-9]*)\" | cut -d' ' -f1 > $SRC_GGML/llama-commits\n\nif [ ! -s $SRC_GGML/llama-commits ]; then\n    rm -v $SRC_GGML/llama-commits\n    echo \"No new commits\"\n    exit 0\nfi\n\nif [ -f $SRC_GGML/llama-src.patch ]; then\n    rm -v $SRC_GGML/llama-src.patch\nfi\n\nwhile read c; do\n    if [ -n \"$to_skip\" ]; then\n        if [[ $to_skip == *\"$c\"* ]]; then\n            echo \"Skipping $c\"\n            continue\n        fi\n    fi\n\n    git format-patch -U${ctx} -k $c~1..$c --stdout -- \\\n        ggml/CMakeLists.txt \\\n        ggml/src/CMakeLists.txt \\\n        ggml/cmake/BuildTypes.cmake \\\n        ggml/cmake/GitVars.cmake \\\n        ggml/cmake/common.cmake \\\n        ggml/cmake/ggml-config.cmake.in \\\n        ggml/src/ggml-cpu/cmake/FindSIMD.cmake \\\n        ggml/src/ggml* \\\n        ggml/src/gguf* \\\n        ggml/include/ggml*.h \\\n        ggml/include/gguf*.h \\\n        tests/test-opt.cpp \\\n        tests/test-quantize-fns.cpp \\\n        tests/test-quantize-perf.cpp \\\n        tests/test-backend-ops.cpp \\\n        LICENSE \\\n        
scripts/gen-authors.sh \\\n        >> $SRC_GGML/llama-src.patch\ndone < $SRC_GGML/llama-commits\n\nrm -v $SRC_GGML/llama-commits\n\n# delete files if empty\nif [ ! -s $SRC_GGML/llama-src.patch ]; then\n    rm -v $SRC_GGML/llama-src.patch\nfi\n\ncd $SRC_GGML\n\nif [ -f $SRC_GGML/llama-src.patch ]; then\n    # replace PR numbers\n    #\n    # Subject: some text (#1234)\n    # Subject: some text (llama/1234)\n    cat llama-src.patch | sed -e 's/^Subject: \\(.*\\) (#\\([0-9]*\\))/Subject: \\1 (llama\\/\\2)/' > llama-src.patch.tmp\n    mv llama-src.patch.tmp llama-src.patch\n\n    cat llama-src.patch | sed -e 's/^\\(.*\\) (#\\([0-9]*\\))$/\\1 (llama\\/\\2)/' > llama-src.patch.tmp\n    mv llama-src.patch.tmp llama-src.patch\n\n    # replace filenames:\n    #\n    # ggml/CMakelists.txt       -> CMakeLists.txt\n    # ggml/src/CMakelists.txt   -> src/CMakeLists.txt\n    #\n    # ggml/cmake/BuildTypes.cmake            -> cmake/BuildTypes.cmake\n    # ggml/cmake/GitVars.cmake               -> cmake/GitVars.cmake\n    # ggml/cmake/common.cmake                -> cmake/common.cmake\n    # ggml/cmake/ggml-config.cmake.in        -> cmake/ggml-config.cmake.in\n    # ggml/src/ggml-cpu/cmake/FindSIMD.cmake -> src/ggml-cpu/cmake/FindSIMD.cmake\n    #\n    # ggml/src/ggml* -> src/ggml*\n    # ggml/src/gguf* -> src/gguf*\n    #\n    # ggml/include/ggml*.h -> include/ggml*.h\n    # ggml/include/gguf*.h -> include/gguf*.h\n    #\n    # tests/test-opt.cpp           -> tests/test-opt.cpp\n    # tests/test-quantize-fns.cpp  -> tests/test-quantize-fns.cpp\n    # tests/test-quantize-perf.cpp -> tests/test-quantize-perf.cpp\n    # tests/test-backend-ops.cpp   -> tests/test-backend-ops.cpp\n    #\n    # LICENSE                -> LICENSE\n    # scripts/gen-authors.sh -> scripts/gen-authors.sh\n\n    cat llama-src.patch | sed -E \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/CMakeLists\\.txt/\\1CMakeLists.txt/g' \\\n        -e 's/(^[[:space:]]| 
[ab]\\/)ggml\\/src\\/CMakeLists\\.txt/\\1src\\/CMakeLists.txt/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/cmake\\/BuildTypes\\.cmake/\\1cmake\\/BuildTypes\\.cmake/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/cmake\\/GitVars\\.cmake/\\1cmake\\/GitVars\\.cmake/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/cmake\\/common\\.cmake/\\1cmake\\/common\\.cmake/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/cmake\\/ggml-config\\.cmake\\.in/\\1cmake\\/ggml-config\\.cmake\\.in/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/src\\/ggml-cpu\\/cmake\\/FindSIMD\\.cmake/\\1src\\/ggml-cpu\\/cmake\\/FindSIMD\\.cmake/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/src\\/ggml(.*)/\\1src\\/ggml\\2/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/src\\/gguf(.*)/\\1src\\/gguf\\2/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/include\\/ggml(.*)\\.h/\\1include\\/ggml\\2.h/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)ggml\\/include\\/gguf(.*)\\.h/\\1include\\/gguf\\2.h/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)tests\\/test-opt\\.cpp/\\1tests\\/test-opt.cpp/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)tests\\/test-quantize-fns\\.cpp/\\1tests\\/test-quantize-fns.cpp/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)tests\\/test-quantize-perf\\.cpp/\\1tests\\/test-quantize-perf.cpp/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)tests\\/test-backend-ops\\.cpp/\\1tests\\/test-backend-ops.cpp/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)LICENSE/\\1LICENSE/g' \\\n        -e 's/(^[[:space:]]| [ab]\\/)scripts\\/gen-authors\\.sh/\\1scripts\\/gen-authors.sh/g' \\\n        > llama-src.patch.tmp\n    mv llama-src.patch.tmp llama-src.patch\n\n    git am -C${ctx} llama-src.patch\n\n    rm -v $SRC_GGML/llama-src.patch\nfi\n\n# update last commit\ncd $SRC_LLAMA\ngit log -1 --format=%H > $SRC_GGML/scripts/sync-llama.last\n\necho \"Done\"\n\nexit 0\n"
  },
  {
    "path": "scripts/sync-llama.last",
    "content": "ae40cd27c85aa30b9cd56033da1d6a954290f7ea\n"
  },
  {
    "path": "scripts/sync-llama.sh",
    "content": "#!/bin/bash\n\ncp -rpv ../llama.cpp/ggml/CMakeLists.txt       CMakeLists.txt\ncp -rpv ../llama.cpp/ggml/src/CMakeLists.txt   src/CMakeLists.txt\n\ncp -rpv ../llama.cpp/ggml/cmake/*              cmake/\ncp -rpv ../llama.cpp/ggml/src/ggml-cpu/cmake/* src/ggml-cpu/cmake/\n\ncp -rpv ../llama.cpp/ggml/src/ggml* src/\ncp -rpv ../llama.cpp/ggml/src/gguf* src/\n\ncp -rpv ../llama.cpp/ggml/include/ggml*.h include/\ncp -rpv ../llama.cpp/ggml/include/gguf*.h include/\n\ncp -rpv ../llama.cpp/tests/test-opt.cpp           tests/test-opt.cpp\ncp -rpv ../llama.cpp/tests/test-quantize-fns.cpp  tests/test-quantize-fns.cpp\ncp -rpv ../llama.cpp/tests/test-quantize-perf.cpp tests/test-quantize-perf.cpp\ncp -rpv ../llama.cpp/tests/test-backend-ops.cpp   tests/test-backend-ops.cpp\n\ncp -rpv ../llama.cpp/LICENSE                ./LICENSE\ncp -rpv ../llama.cpp/scripts/gen-authors.sh ./scripts/gen-authors.sh\n"
  },
  {
    "path": "scripts/sync-whisper-am.sh",
    "content": "#!/bin/bash\n#\n# Synchronize whisper.cpp changes to ggml\n#\n# Usage:\n#\n#   $ cd /path/to/ggml\n#   $ ./scripts/sync-whisper-am.sh -skip hash0,hash1,hash2...\n#\n\nset -e\n\nsd=$(dirname $0)\ncd $sd/../\n\nSRC_GGML=$(pwd)\nSRC_WHISPER=$(cd ../whisper.cpp; pwd)\n\nif [ ! -d $SRC_WHISPER ]; then\n    echo \"whisper.cpp not found at $SRC_WHISPER\"\n    exit 1\nfi\n\nlc=$(cat $SRC_GGML/scripts/sync-whisper.last)\necho \"Syncing whisper.cpp changes since commit $lc\"\n\nto_skip=\"\"\nif [ \"$1\" == \"-skip\" ]; then\n    to_skip=$2\nfi\n\ncd $SRC_WHISPER\n\ngit log --oneline $lc..HEAD\ngit log --oneline $lc..HEAD --reverse | grep -v \"(ggml/[0-9]*)\" | grep -v \"(llama/[0-9]*)\" | cut -d' ' -f1 > $SRC_GGML/whisper-commits\n\nif [ ! -s $SRC_GGML/whisper-commits ]; then\n    rm -v $SRC_GGML/whisper-commits\n    echo \"No new commits\"\n    exit 0\nfi\n\nif [ -f $SRC_GGML/whisper-src.patch ]; then\n    rm -v $SRC_GGML/whisper-src.patch\nfi\n\nwhile read c; do\n    if [ -n \"$to_skip\" ]; then\n        if [[ $to_skip == *\"$c\"* ]]; then\n            echo \"Skipping $c\"\n            continue\n        fi\n    fi\n\n    git format-patch -k $c~1..$c --stdout -- \\\n        ggml/CMakeLists.txt \\\n        ggml/src/CMakeLists.txt \\\n        ggml/cmake/FindSIMD.cmake \\\n        ggml/src/ggml* \\\n        ggml/src/gguf* \\\n        ggml/include/ggml*.h \\\n        ggml/include/gguf*.h \\\n        examples/common-ggml.h \\\n        examples/common-ggml.cpp \\\n        LICENSE \\\n        scripts/gen-authors.sh \\\n        >> $SRC_GGML/whisper-src.patch\ndone < $SRC_GGML/whisper-commits\n\nrm -v $SRC_GGML/whisper-commits\n\n# delete files if empty\nif [ ! 
-s $SRC_GGML/whisper-src.patch ]; then\n    rm -v $SRC_GGML/whisper-src.patch\nfi\n\ncd $SRC_GGML\n\nif [ -f $SRC_GGML/whisper-src.patch ]; then\n    # replace PR numbers\n    #\n    # Subject: some text (#1234)\n    # Subject: some text (whisper/1234)\n    cat whisper-src.patch | sed -e 's/^Subject: \\(.*\\) (#\\([0-9]*\\))/Subject: \\1 (whisper\\/\\2)/' > whisper-src.patch.tmp\n    mv whisper-src.patch.tmp whisper-src.patch\n\n    cat whisper-src.patch | sed -e 's/^\\(.*\\) (#\\([0-9]*\\))$/\\1 (whisper\\/\\2)/' > whisper-src.patch.tmp\n    mv whisper-src.patch.tmp whisper-src.patch\n\n    # replace filenames:\n    #\n    # ggml/CMakelists.txt       -> CMakeLists.txt\n    # ggml/src/CMakelists.txt   -> src/CMakeLists.txt\n    # ggml/cmake/FindSIMD.cmake -> cmake/FindSIMD.cmake\n    #\n    # ggml/src/ggml* -> src/ggml*\n    # ggml/src/gguf* -> src/gguf*\n    #\n    # ggml/include/ggml*.h -> include/ggml*.h\n    # ggml/include/gguf*.h -> include/gguf*.h\n    #\n    # examples/common.h        -> examples/common.h\n    # examples/common.cpp      -> examples/common.cpp\n    # examples/common-ggml.h   -> examples/common-ggml.h\n    # examples/common-ggml.cpp -> examples/common-ggml.cpp\n    #\n    # LICENSE                -> LICENSE\n    # scripts/gen-authors.sh -> scripts/gen-authors.sh\n\n    cat whisper-src.patch | sed -E \\\n        -e 's/\\/ggml\\/CMakeLists\\.txt/\\/CMakeLists.txt/g' \\\n        -e 's/\\/ggml\\/src\\/CMakeLists\\.txt/\\/src\\/CMakeLists.txt/g' \\\n        -e 's/\\/ggml\\/cmake\\/FindSIMD\\.cmake/\\/cmake\\/FindSIMD.cmake/g' \\\n        -e 's/\\/ggml\\/src\\/ggml(.*)/\\/src\\/ggml\\1/g' \\\n        -e 's/\\/ggml\\/src\\/gguf(.*)/\\/src\\/gguf\\1/g' \\\n        -e 's/\\/ggml\\/include\\/ggml(.*)\\.h/\\/include\\/ggml\\1.h/g' \\\n        -e 's/\\/ggml\\/include\\/gguf(.*)\\.h/\\/include\\/gguf\\1.h/g' \\\n        -e 's/\\/examples\\/common\\.h/\\/examples\\/common.h/g' \\\n        -e 's/\\/examples\\/common\\.cpp/\\/examples\\/common.cpp/g' \\\n     
   -e 's/\\/examples\\/common-ggml\\.h/\\/examples\\/common-ggml.h/g' \\\n        -e 's/\\/examples\\/common-ggml\\.cpp/\\/examples\\/common-ggml.cpp/g' \\\n        -e 's/\\/LICENSE/\\/LICENSE/g' \\\n        -e 's/\\/scripts\\/gen-authors\\.sh/\\/scripts\\/gen-authors.sh/g' \\\n        > whisper-src.patch.tmp\n    mv whisper-src.patch.tmp whisper-src.patch\n\n    git am whisper-src.patch\n\n    rm -v $SRC_GGML/whisper-src.patch\nfi\n\n# update last commit\ncd $SRC_WHISPER\ngit log -1 --format=%H > $SRC_GGML/scripts/sync-whisper.last\n\necho \"Done\"\n\nexit 0\n"
  },
  {
    "path": "scripts/sync-whisper.last",
    "content": "79218f51d02ffe70575ef7fba3496dfc7adda027\n"
  },
  {
    "path": "scripts/sync-whisper.sh",
    "content": "#!/bin/bash\n\ncp -rpv ../whisper.cpp/ggml/CMakeLists.txt       CMakeLists.txt\ncp -rpv ../whisper.cpp/ggml/src/CMakeLists.txt   src/CMakeLists.txt\ncp -rpv ../whisper.cpp/ggml/cmake/FindSIMD.cmake cmake/FindSIMD.cmake\n\ncp -rpv ../whisper.cpp/ggml/src/ggml* src/\ncp -rpv ../whisper.cpp/ggml/src/gguf* src/\n\ncp -rpv ../whisper.cpp/ggml/include/ggml*.h include/\ncp -rpv ../whisper.cpp/ggml/include/gguf*.h include/\n\ncp -rpv ../whisper.cpp/examples/common-ggml.h   examples/common-ggml.h\ncp -rpv ../whisper.cpp/examples/common-ggml.cpp examples/common-ggml.cpp\n\ncp -rpv ../whisper.cpp/LICENSE                ./LICENSE\ncp -rpv ../whisper.cpp/scripts/gen-authors.sh ./scripts/gen-authors.sh\n"
  },
  {
    "path": "src/CMakeLists.txt",
    "content": "include(CheckCXXCompilerFlag)\ninclude(\"../cmake/common.cmake\")\n\nadd_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES})\n\n# enable libstdc++ assertions for debug builds\nif (CMAKE_SYSTEM_NAME MATCHES \"Linux\")\n    add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)\nendif()\n\nif (NOT MSVC)\n    if (GGML_SANITIZE_THREAD)\n        add_compile_options(-fsanitize=thread)\n        link_libraries     (-fsanitize=thread)\n    endif()\n\n    if (GGML_SANITIZE_ADDRESS)\n        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)\n        link_libraries     (-fsanitize=address)\n    endif()\n\n    if (GGML_SANITIZE_UNDEFINED)\n        add_compile_options(-fsanitize=undefined)\n        link_libraries     (-fsanitize=undefined)\n    endif()\nendif()\n\nif (GGML_FATAL_WARNINGS)\n    if (CMAKE_CXX_COMPILER_ID MATCHES \"GNU\" OR CMAKE_CXX_COMPILER_ID MATCHES \"Clang\")\n        list(APPEND C_FLAGS   -Werror)\n        list(APPEND CXX_FLAGS -Werror)\n    elseif (CMAKE_CXX_COMPILER_ID STREQUAL \"MSVC\")\n        add_compile_options(/WX)\n    endif()\nendif()\n\nif (GGML_ALL_WARNINGS)\n    if (NOT MSVC)\n        list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)\n        list(APPEND C_FLAGS       -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes\n                                  -Werror=implicit-int -Werror=implicit-function-declaration)\n        list(APPEND CXX_FLAGS     -Wmissing-declarations -Wmissing-noreturn)\n\n        list(APPEND C_FLAGS   ${WARNING_FLAGS})\n        list(APPEND CXX_FLAGS ${WARNING_FLAGS})\n\n        ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})\n\n        add_compile_options(\"$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>\"\n                            \"$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>\")\n    else()\n        # todo : msvc\n        set(C_FLAGS   \"\")\n        set(CXX_FLAGS \"\")\n    
endif()\nendif()\n\nif (GGML_LTO)\n    include(CheckIPOSupported)\n    check_ipo_supported(RESULT result OUTPUT output)\n    if (result)\n        set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)\n    else()\n        message(WARNING \"IPO is not supported: ${output}\")\n    endif()\nendif()\n\nif (GGML_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER)\n    find_program(GGML_CCACHE_FOUND ccache)\n    find_program(GGML_SCCACHE_FOUND sccache)\n\n    if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND)\n        if(GGML_CCACHE_FOUND)\n            set(GGML_CCACHE_VARIANT ccache)\n        else()\n            set(GGML_CCACHE_VARIANT sccache)\n        endif()\n        # TODO: should not be set globally\n        if (GGML_SYCL AND GGML_CCACHE_FOUND AND WIN32)\n            set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE \"ccache compiler_type=icl\")\n        else ()\n            set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE \"${GGML_CCACHE_VARIANT}\")\n        endif ()\n        set(ENV{CCACHE_SLOPPINESS} time_macros)\n        message(STATUS \"${GGML_CCACHE_VARIANT} found, compilation results will be cached. 
Disable with GGML_CCACHE=OFF.\")\n    else()\n        message(STATUS \"Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF\")\n    endif ()\nendif()\n\n# this version of Apple ld64 is buggy\nexecute_process(\n    COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v\n    ERROR_VARIABLE output\n    OUTPUT_QUIET\n)\n\nif (output MATCHES \"dyld-1015\\.7\")\n    add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)\nendif()\n\n# architecture specific\n# TODO: probably these flags need to be tweaked on some architectures\n#       feel free to update the Makefile for your architecture and send a pull request or issue\nmessage(STATUS \"CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}\")\nif (MSVC)\n    string(TOLOWER \"${CMAKE_GENERATOR_PLATFORM}\" CMAKE_GENERATOR_PLATFORM_LWR)\n    message(STATUS \"CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}\")\nelse ()\n    set(CMAKE_GENERATOR_PLATFORM_LWR \"\")\nendif ()\nggml_get_system_arch()\nmessage(STATUS \"GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}\")\n\nif (NOT MSVC)\n    if (GGML_STATIC)\n        if (UNIX AND NOT APPLE)\n            set(CMAKE_FIND_LIBRARY_SUFFIXES \".a;.so\")\n        endif()\n        add_link_options(-static)\n        if (MINGW)\n            add_link_options(-static-libgcc -static-libstdc++)\n        endif()\n    endif()\n    if (GGML_GPROF)\n        add_compile_options(-pg)\n    endif()\nendif()\n\n#\n# POSIX conformance\n#\n\n# clock_gettime came in POSIX.1b (1993)\n# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional\n# posix_memalign came in POSIX.1-2001 / SUSv3\n# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)\n\n# Somehow in OpenBSD whenever POSIX conformance is specified\n# some string functions rely on locale_t availability,\n# which was introduced in POSIX.1-2008, forcing us to go higher\nif (CMAKE_SYSTEM_NAME MATCHES \"OpenBSD\")\n    add_compile_definitions(_XOPEN_SOURCE=700)\nelseif 
(CMAKE_SYSTEM_NAME MATCHES \"AIX\")\n    # Don't define _XOPEN_SOURCE.  We need _ALL_SOURCE, which is the default,\n    # in order to define _SC_PHYS_PAGES.\nelse()\n    add_compile_definitions(_XOPEN_SOURCE=600)\nendif()\n\n# Data types, macros and functions related to controlling CPU affinity and\n# some memory allocation are available on Linux through GNU extensions in libc\nif (CMAKE_SYSTEM_NAME MATCHES \"Linux\" OR CMAKE_SYSTEM_NAME MATCHES \"Android\")\n    add_compile_definitions(_GNU_SOURCE)\nendif()\n\n# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,\n# and on macOS its availability depends on enabling Darwin extensions\n# similarly on DragonFly, enabling BSD extensions is necessary\nif (\n    CMAKE_SYSTEM_NAME MATCHES \"Darwin\" OR\n    CMAKE_SYSTEM_NAME MATCHES \"iOS\"    OR\n    CMAKE_SYSTEM_NAME MATCHES \"tvOS\"   OR\n    CMAKE_SYSTEM_NAME MATCHES \"DragonFly\"\n)\n    add_compile_definitions(_DARWIN_C_SOURCE)\nendif()\n\n# alloca is a non-standard interface that is not visible on BSDs when\n# POSIX conformance is specified, but not all of them provide a clean way\n# to enable it in such cases\nif (CMAKE_SYSTEM_NAME MATCHES \"FreeBSD\")\n    add_compile_definitions(__BSD_VISIBLE)\nendif()\nif (CMAKE_SYSTEM_NAME MATCHES \"NetBSD\")\n    add_compile_definitions(_NETBSD_SOURCE)\nendif()\nif (CMAKE_SYSTEM_NAME MATCHES \"OpenBSD\")\n    add_compile_definitions(_BSD_SOURCE)\nendif()\n\nif (WIN32)\n    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)\nendif()\n\n# ggml\n\nif (GGML_BACKEND_DL AND NOT BUILD_SHARED_LIBS)\n    message(FATAL_ERROR \"GGML_BACKEND_DL requires BUILD_SHARED_LIBS\")\nendif()\n\nadd_library(ggml-base\n            ../include/ggml.h\n            ../include/ggml-alloc.h\n            ../include/ggml-backend.h\n            ../include/ggml-cpp.h\n            ../include/ggml-opt.h\n            ../include/gguf.h\n            ggml.c\n            ggml.cpp\n            ggml-alloc.c\n            ggml-backend.cpp\n            
ggml-opt.cpp\n            ggml-threading.cpp\n            ggml-threading.h\n            ggml-quants.c\n            ggml-quants.h\n            gguf.cpp)\n\nset_target_properties(ggml-base PROPERTIES\n    VERSION ${GGML_VERSION}\n    SOVERSION ${GGML_VERSION_MAJOR}\n)\n\ntarget_include_directories(ggml-base PRIVATE .)\nif (GGML_BACKEND_DL)\n    target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)\nendif()\n\nif (GGML_SCHED_NO_REALLOC)\n    target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)\nendif()\n\nadd_library(ggml\n            ggml-backend-dl.cpp\n            ggml-backend-reg.cpp)\nadd_library(ggml::ggml ALIAS ggml)\n\nset_target_properties(ggml PROPERTIES\n    VERSION ${GGML_VERSION}\n    SOVERSION ${GGML_VERSION_MAJOR}\n)\n\nif (GGML_BACKEND_DIR)\n    if (NOT GGML_BACKEND_DL)\n        message(FATAL_ERROR \"GGML_BACKEND_DIR requires GGML_BACKEND_DL\")\n    endif()\n    target_compile_definitions(ggml PUBLIC GGML_BACKEND_DIR=\"${GGML_BACKEND_DIR}\")\nendif()\n\ntarget_link_libraries(ggml PUBLIC ggml-base)\n\nif (CMAKE_SYSTEM_NAME MATCHES \"Linux\")\n    target_link_libraries(ggml PRIVATE dl)\nendif()\n\nfunction(ggml_add_backend_library backend)\n    if (GGML_BACKEND_DL)\n        add_library(${backend} MODULE ${ARGN})\n        # write the shared library to the output directory\n        set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})\n        target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)\n        add_dependencies(ggml ${backend})\n        if (GGML_BACKEND_DIR)\n            install(TARGETS ${backend} LIBRARY DESTINATION ${GGML_BACKEND_DIR})\n        else()\n            install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})\n        endif()\n    else()\n        add_library(${backend} ${ARGN})\n        target_link_libraries(ggml PUBLIC ${backend})\n        install(TARGETS ${backend} LIBRARY)\n    endif()\n\n    target_link_libraries(${backend} PRIVATE 
ggml-base)\n    target_include_directories(${backend} PRIVATE ..)\n\n    if (${BUILD_SHARED_LIBS})\n        target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)\n        target_compile_definitions(${backend} PUBLIC  GGML_BACKEND_SHARED)\n    endif()\n\n    # Set versioning properties for all backend libraries\n    # Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782)\n    if (NOT (APPLE AND GGML_BACKEND_DL))\n        set_target_properties(${backend} PROPERTIES\n            VERSION ${GGML_VERSION}\n            SOVERSION ${GGML_VERSION_MAJOR}\n        )\n    endif()\n\n    if(NOT GGML_AVAILABLE_BACKENDS)\n        set(GGML_AVAILABLE_BACKENDS \"${backend}\"\n            CACHE INTERNAL \"List of backends for cmake package\")\n    else()\n        list(FIND GGML_AVAILABLE_BACKENDS \"${backend}\" has_backend)\n        if(has_backend EQUAL -1)\n            set(GGML_AVAILABLE_BACKENDS \"${GGML_AVAILABLE_BACKENDS};${backend}\"\n                CACHE INTERNAL \"List of backends for cmake package\")\n        endif()\n    endif()\nendfunction()\n\nfunction(ggml_add_backend backend)\n    string(TOUPPER \"GGML_${backend}\" backend_id)\n    if (${backend_id})\n        string(TOLOWER \"ggml-${backend}\" backend_target)\n        add_subdirectory(${backend_target})\n        message(STATUS \"Including ${backend} backend\")\n        if (NOT GGML_BACKEND_DL)\n            string(TOUPPER \"GGML_USE_${backend}\" backend_use)\n            target_compile_definitions(ggml PUBLIC ${backend_use})\n        endif()\n    endif()\nendfunction()\n\nfunction(ggml_add_cpu_backend_variant tag_name)\n    set(GGML_CPU_TAG_NAME ${tag_name})\n    # other: OPENMP LLAMAFILE CPU_HBM\n    if (GGML_SYSTEM_ARCH STREQUAL \"x86\")\n        foreach (feat NATIVE\n                      SSE42\n                      AVX AVX2 BMI2 AVX_VNNI FMA F16C\n                      AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16\n                      
AMX_TILE AMX_INT8 AMX_BF16)\n            set(GGML_${feat} OFF)\n        endforeach()\n\n        foreach (feat ${ARGN})\n            set(GGML_${feat} ON)\n        endforeach()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"ARM\")\n        foreach (feat ${ARGN})\n            set(GGML_INTERNAL_${feat} ON)\n        endforeach()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"PowerPC\")\n        foreach (feat ${ARGN})\n            set(GGML_INTERNAL_${feat} ON)\n        endforeach()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"s390x\")\n        foreach (feat VXE2 NNPA)\n            set(GGML_INTERNAL_${feat} OFF)\n        endforeach()\n\n        foreach (feat ${ARGN})\n            set(GGML_INTERNAL_${feat} ON)\n        endforeach()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"riscv64\")\n        foreach (feat RVV)\n            set(GGML_INTERNAL_${feat} OFF)\n        endforeach()\n\n        foreach (feat ${ARGN})\n            set(GGML_INTERNAL_${feat} ON)\n        endforeach()\n    endif()\n\n    ggml_add_cpu_backend_variant_impl(${tag_name})\nendfunction()\n\nggml_add_backend(CPU)\n\nif (GGML_CPU_ALL_VARIANTS)\n    if (NOT GGML_BACKEND_DL)\n        message(FATAL_ERROR \"GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL\")\n    elseif (GGML_CPU_ARM_ARCH)\n        message(FATAL_ERROR \"Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS\")\n    endif()\n    if (GGML_SYSTEM_ARCH STREQUAL \"x86\")\n        ggml_add_cpu_backend_variant(x64)\n        ggml_add_cpu_backend_variant(sse42              SSE42)\n        ggml_add_cpu_backend_variant(sandybridge        SSE42 AVX)\n        if (NOT MSVC)\n            # __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512\n            ggml_add_cpu_backend_variant(ivybridge      SSE42 AVX F16C)\n            ggml_add_cpu_backend_variant(piledriver     SSE42 AVX F16C FMA)\n        endif()\n        ggml_add_cpu_backend_variant(haswell            SSE42 AVX F16C FMA AVX2 BMI2)\n        ggml_add_cpu_backend_variant(skylakex     
      SSE42 AVX F16C FMA AVX2 BMI2 AVX512)\n        ggml_add_cpu_backend_variant(cannonlake         SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI)\n        ggml_add_cpu_backend_variant(cascadelake        SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI)\n        ggml_add_cpu_backend_variant(icelake            SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI)\n        if (NOT MSVC)\n            # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!\n            # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170\n            # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170\n            ggml_add_cpu_backend_variant(cooperlake     SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16)\n            ggml_add_cpu_backend_variant(zen4           SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)\n        endif()\n        ggml_add_cpu_backend_variant(alderlake          SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)\n        if (NOT MSVC)\n            # MSVC doesn't support AMX\n            ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)\n        endif()\n    elseif(GGML_SYSTEM_ARCH STREQUAL \"ARM\")\n        if (CMAKE_SYSTEM_NAME MATCHES \"Linux\")\n            # Many of these features are optional so we build versions with popular\n            # combinations and name the backends based on the version they were\n            # first released with\n            ggml_add_cpu_backend_variant(armv8.0_1)\n            ggml_add_cpu_backend_variant(armv8.2_1    DOTPROD)\n            ggml_add_cpu_backend_variant(armv8.2_2    DOTPROD FP16_VECTOR_ARITHMETIC)\n            ggml_add_cpu_backend_variant(armv8.2_3    DOTPROD FP16_VECTOR_ARITHMETIC SVE)\n            ggml_add_cpu_backend_variant(armv8.6_1    DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8)\n            
ggml_add_cpu_backend_variant(armv8.6_2    DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2)\n            ggml_add_cpu_backend_variant(armv9.2_1    DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SME)\n            ggml_add_cpu_backend_variant(armv9.2_2    DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2 SME)\n        elseif (CMAKE_SYSTEM_NAME MATCHES \"Android\")\n            # Android-specific backends with SoC-compatible feature sets\n            ggml_add_cpu_backend_variant(android_armv8.0_1)\n            ggml_add_cpu_backend_variant(android_armv8.2_1    DOTPROD)\n            ggml_add_cpu_backend_variant(android_armv8.2_2    DOTPROD FP16_VECTOR_ARITHMETIC)\n            ggml_add_cpu_backend_variant(android_armv8.6_1    DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)\n            ggml_add_cpu_backend_variant(android_armv9.0_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)\n            ggml_add_cpu_backend_variant(android_armv9.2_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)\n            ggml_add_cpu_backend_variant(android_armv9.2_2    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)\n        elseif (APPLE)\n            ggml_add_cpu_backend_variant(apple_m1             DOTPROD)\n            ggml_add_cpu_backend_variant(apple_m2_m3          DOTPROD MATMUL_INT8)\n            ggml_add_cpu_backend_variant(apple_m4             DOTPROD MATMUL_INT8 NOSVE SME)\n        else()\n            message(FATAL_ERROR \"Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}\")\n        endif()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"PowerPC\")\n        if (CMAKE_SYSTEM_NAME MATCHES \"Linux\")\n            ggml_add_cpu_backend_variant(power0)\n            ggml_add_cpu_backend_variant(power7_1       POWER7)\n            ggml_add_cpu_backend_variant(power7_2       POWER7  VSX)\n            ggml_add_cpu_backend_variant(power8_1       POWER8)\n            ggml_add_cpu_backend_variant(power8_2       POWER8  VSX)\n            ggml_add_cpu_backend_variant(power9         
POWER9  VSX)\n            ggml_add_cpu_backend_variant(power10        POWER10 VSX)\n            ggml_add_cpu_backend_variant(power11        POWER11 VSX)\n        else()\n            message(FATAL_ERROR \"Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}\")\n        endif()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"s390x\")\n        if (CMAKE_SYSTEM_NAME MATCHES \"Linux\")\n            ggml_add_cpu_backend_variant(z15    Z15 VXE2)\n            ggml_add_cpu_backend_variant(z16    Z16 VXE2 NNPA)\n        else()\n            message(FATAL_ERROR \"Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}\")\n        endif()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"riscv64\")\n        if (CMAKE_SYSTEM_NAME MATCHES \"Linux\")\n            ggml_add_cpu_backend_variant(riscv64_0)\n            ggml_add_cpu_backend_variant(riscv64_v   RVV)\n        else()\n            message(FATAL_ERROR \"Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}\")\n        endif()\n    else()\n        message(FATAL_ERROR \"GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}\")\n    endif()\nelseif (GGML_CPU)\n    ggml_add_cpu_backend_variant_impl(\"\")\nendif()\n\nggml_add_backend(BLAS)\nggml_add_backend(CANN)\nggml_add_backend(CUDA)\nggml_add_backend(HIP)\nggml_add_backend(METAL)\nggml_add_backend(MUSA)\nggml_add_backend(RPC)\nggml_add_backend(VirtGPU)\nggml_add_backend(SYCL)\nggml_add_backend(Vulkan)\nggml_add_backend(WebGPU)\nggml_add_backend(zDNN)\nggml_add_backend(OpenCL)\nggml_add_backend(Hexagon)\nggml_add_backend(ZenDNN)\nggml_add_backend(OPENVINO)\n\nforeach (target ggml-base ggml)\n    target_include_directories(${target} PUBLIC    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)\n    target_compile_features   (${target} PRIVATE c_std_11 cxx_std_17) # don't bump\nendforeach()\n\ntarget_link_libraries(ggml-base PRIVATE Threads::Threads)\n\nfind_library(MATH_LIBRARY m)\nif (MATH_LIBRARY)\n    if (NOT WIN32 OR NOT DEFINED 
ENV{ONEAPI_ROOT})\n        target_link_libraries(ggml-base PRIVATE m)\n    endif()\nendif()\n\nif (CMAKE_SYSTEM_NAME MATCHES \"Android\")\n    target_link_libraries(ggml-base PRIVATE dl)\nendif()\n\nif(CMAKE_SYSTEM_NAME MATCHES \"visionOS\")\n    target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE)\nendif()\n\nif (BUILD_SHARED_LIBS)\n    foreach (target ggml-base ggml)\n        set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON)\n        target_compile_definitions(${target} PRIVATE GGML_BUILD)\n        target_compile_definitions(${target} PUBLIC  GGML_SHARED)\n    endforeach()\nendif()\n"
  },
  {
    "path": "src/ggml-alloc.c",
    "content": "#include \"ggml-alloc.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml.h\"\n#include \"ggml-impl.h\"\n#include <assert.h>\n#include <limits.h>\n#include <stdarg.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n\n#define MAX(a, b) ((a) > (b) ? (a) : (b))\n#define MAX_FREE_BLOCKS 256\n\n//#define GGML_ALLOCATOR_DEBUG\n\n//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)\n#define AT_PRINTF(...)\n\n// ops that return true for this function must not use restrict pointers for their backend implementations\nbool ggml_op_can_inplace(enum ggml_op op) {\n    switch (op) {\n        case GGML_OP_FILL:\n        case GGML_OP_SCALE:\n        case GGML_OP_DIAG_MASK_ZERO:\n        case GGML_OP_DIAG_MASK_INF:\n        case GGML_OP_ADD:\n        case GGML_OP_ADD_ID:\n        case GGML_OP_ADD1:\n        case GGML_OP_SUB:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_LOG:\n        case GGML_OP_UNARY:\n        case GGML_OP_ROPE:\n        case GGML_OP_ROPE_BACK:\n        case GGML_OP_SILU_BACK:\n        case GGML_OP_RMS_NORM:\n        case GGML_OP_RMS_NORM_BACK:\n        case GGML_OP_SOFT_MAX:\n        case GGML_OP_SOFT_MAX_BACK:\n            return true;\n\n        default:\n            return false;\n    }\n}\n\nstatic size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {\n    assert(alignment && !(alignment & (alignment - 1))); // power of 2\n    size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;\n    return offset + align;\n}\n\n// tallocr\n\nstruct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer) {\n    void * base = ggml_backend_buffer_get_base(buffer);\n    size_t align = ggml_backend_buffer_get_alignment(buffer);\n\n    assert(align && !(align & (align - 1))); // power of 2\n\n    struct ggml_tallocr talloc = (struct ggml_tallocr) {\n        /*.buffer    = */ buffer,\n        /*.base      
= */ base,\n        /*.alignment = */ align,\n        /*.offset    = */ aligned_offset(base, 0, align),\n    };\n    return talloc;\n}\n\nenum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) {\n    size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);\n    size = GGML_PAD(size, talloc->alignment);\n\n    if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {\n        GGML_LOG_ERROR(\"%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\\n\",\n                __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);\n        GGML_ABORT(\"not enough space in the buffer\");\n    }\n\n    void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset;\n    talloc->offset += size;\n\n    assert(((uintptr_t)addr % talloc->alignment) == 0);\n\n    return ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);\n}\n\n// dynamic tensor allocator\n\n#define GGML_VBUFFER_MAX_CHUNKS 16\n\n// relative memory address within an allocation that can be split into multiple buffers (chunks)\nstruct buffer_address {\n    int chunk;     // index of a backend buffer\n    size_t offset; // local memory offset within the buffer\n};\n\nstatic const struct buffer_address GGML_BUFFER_ADDRESS_INVALID = { -1, SIZE_MAX };\n\nstatic bool ggml_buffer_address_less(struct buffer_address a, struct buffer_address b) {\n    return a.chunk != b.chunk ? 
a.chunk < b.chunk : a.offset < b.offset;\n}\n\nstruct free_block {\n    size_t offset;\n    size_t size;\n};\n\nstruct tallocr_chunk {\n    struct free_block free_blocks[MAX_FREE_BLOCKS];\n    int n_free_blocks;\n    size_t max_size;\n};\n\nstruct ggml_dyn_tallocr {\n    size_t alignment;\n    size_t max_chunk_size;\n    struct tallocr_chunk * chunks[GGML_VBUFFER_MAX_CHUNKS];\n    int n_chunks;\n\n#ifdef GGML_ALLOCATOR_DEBUG\n    struct {\n        const struct ggml_tensor * tensor;\n        struct buffer_address addr;\n    } allocated_tensors[1024];\n#endif\n};\n\nstatic void ggml_dyn_tallocr_insert_block(struct tallocr_chunk * chunk, size_t offset, size_t size) {\n    GGML_ASSERT(chunk->n_free_blocks < MAX_FREE_BLOCKS && \"out of free blocks\");\n    // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)\n    int insert_pos = 0;\n    while (insert_pos < chunk->n_free_blocks && chunk->free_blocks[insert_pos].offset < offset) {\n        insert_pos++;\n    }\n    // shift all blocks from insert_pos onward to make room for the new block\n    for (int i = chunk->n_free_blocks; i > insert_pos; i--) {\n        chunk->free_blocks[i] = chunk->free_blocks[i-1];\n    }\n    // insert the new block\n    chunk->free_blocks[insert_pos].offset = offset;\n    chunk->free_blocks[insert_pos].size = size;\n    chunk->n_free_blocks++;\n}\n\nstatic void ggml_dyn_tallocr_remove_block(struct tallocr_chunk * chunk, int idx) {\n    // shift all elements after idx by 1 to the left, overwriting the element at idx\n    for (int i = idx; i < chunk->n_free_blocks; i++) {\n        chunk->free_blocks[i] = chunk->free_blocks[i+1];\n    }\n    chunk->n_free_blocks--;\n}\n\nstatic int ggml_dyn_tallocr_new_chunk(struct ggml_dyn_tallocr * alloc, size_t min_size) {\n    if (alloc->n_chunks >= GGML_VBUFFER_MAX_CHUNKS) {\n        return -1;\n    }\n    struct tallocr_chunk * chunk = calloc(1, sizeof(struct tallocr_chunk));\n    
chunk->n_free_blocks = 1;\n    chunk->free_blocks[0].offset = 0;\n    // available space in a chunk is limited to max_chunk_size, but can be higher if:\n    // 1. a single tensor exceeds the maximum, and cannot fit any other way\n    // 2. we are running out of chunks\n    // backends will either manage to allocate the larger size, or report an error.\n    chunk->free_blocks[0].size = MAX(min_size, alloc->max_chunk_size);\n    if (alloc->n_chunks == GGML_VBUFFER_MAX_CHUNKS - 1) {\n        chunk->free_blocks[0].size = SIZE_MAX/2;\n    }\n    alloc->chunks[alloc->n_chunks] = chunk;\n    alloc->n_chunks++;\n    return alloc->n_chunks - 1;\n}\n\n#ifdef GGML_ALLOCATOR_DEBUG\nstatic void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, const struct ggml_tensor * tensor) {\n    for (int i = 0; i < 1024; i++) {\n        if (alloc->allocated_tensors[i].tensor == NULL) {\n            alloc->allocated_tensors[i].tensor = tensor;\n            alloc->allocated_tensors[i].addr = addr;\n            return;\n        }\n    }\n    GGML_ABORT(\"out of allocated_tensors\");\n}\nstatic void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, const struct ggml_tensor * tensor) {\n    for (int i = 0; i < 1024; i++) {\n        if (alloc->allocated_tensors[i].addr.chunk == addr.chunk && alloc->allocated_tensors[i].addr.offset == addr.offset) {\n            alloc->allocated_tensors[i].tensor = NULL;\n            return;\n        }\n    }\n    GGML_ABORT(\"tried to free tensor %s not found\\n\", tensor->name);\n}\n#endif\n\nstatic struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t size, const struct ggml_tensor * tensor) {\n    size = aligned_offset(NULL, size, alloc->alignment);\n\n    AT_PRINTF(\"%s: allocating %s (%zu bytes) - \", __func__, tensor->name, size);\n\n    int best_fit_chunk = -1;\n    int best_fit_block = -1;\n    size_t max_avail = 0;\n\n    // find the best fitting free 
block besides the last block, within any chunk\n    for (int c = 0; c < alloc->n_chunks; ++c) {\n        struct tallocr_chunk * chunk = alloc->chunks[c];\n        size_t best_fit_size = SIZE_MAX;\n        for (int i = 0; i < chunk->n_free_blocks - 1; i++) {\n            struct free_block * block = &chunk->free_blocks[i];\n            max_avail = MAX(max_avail, block->size);\n            if (block->size >= size && block->size <= best_fit_size) {\n                best_fit_chunk = c;\n                best_fit_block = i;\n                best_fit_size = block->size;\n            }\n        }\n    }\n\n    if (best_fit_block == -1) {\n        // no suitable block found, try the last block (this may grow a chunks size)\n        int64_t best_reuse = INT64_MIN;\n        for (int c = 0; c < alloc->n_chunks; ++c) {\n            struct tallocr_chunk * chunk = alloc->chunks[c];\n            if (chunk->n_free_blocks > 0) {\n                struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];\n                max_avail = MAX(max_avail, block->size);\n                int64_t reuse_factor = chunk->max_size - block->offset - size;\n                // reuse_factor < 0 : amount of extra memory that needs to be allocated\n                // reuse_factor = 0 : allocated free space exactly matches tensor size\n                // reuse_factor > 0 : superfluous memory that will remain unused\n                bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;\n                bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;\n                if (block->size >= size && (better_reuse || better_fit)) {\n                    best_fit_chunk = c;\n                    best_fit_block = chunk->n_free_blocks - 1;\n                    best_reuse = reuse_factor;\n                }\n            }\n        }\n    }\n\n    if (best_fit_block == -1) {\n        // none of the existing chunks have enough space left\n        best_fit_chunk = 
ggml_dyn_tallocr_new_chunk(alloc, size);\n        best_fit_block = 0;\n    }\n    if (best_fit_chunk == -1) {\n        // since the last chunk always has virtually endless memory, this should never happen\n        GGML_LOG_ERROR(\"%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\\n\",\n            __func__, size, max_avail);\n        GGML_ABORT(\"graph allocation: failed to reserve memory\");\n    }\n\n    struct tallocr_chunk * chunk = alloc->chunks[best_fit_chunk];\n    struct free_block    * block = &chunk->free_blocks[best_fit_block];\n    struct buffer_address  addr  = {.chunk = best_fit_chunk, .offset = block->offset };\n    block->offset += size;\n    block->size -= size;\n    if (block->size == 0) {\n        // remove block if empty\n        ggml_dyn_tallocr_remove_block(chunk, best_fit_block);\n    }\n\n    AT_PRINTF(\"block %d, offset %zu, chunk %d\\n\", best_fit_block, addr.offset, addr.chunk);\n\n#ifdef GGML_ALLOCATOR_DEBUG\n    add_allocated_tensor(alloc, addr, tensor);\n    size_t cur_max = addr.offset + size;\n    if (cur_max > chunk->max_size) {\n        // sort allocated_tensors by chunk/offset\n        for (int i = 0; i < 1024; i++) {\n            for (int j = i + 1; j < 1024; j++) {\n                if (ggml_buffer_address_less(alloc->allocated_tensors[j].addr, alloc->allocated_tensors[i].addr)) {\n                    const struct ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor;\n                    struct buffer_address tmp_addr = alloc->allocated_tensors[i].addr;\n                    alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor;\n                    alloc->allocated_tensors[i].addr = alloc->allocated_tensors[j].addr;\n                    alloc->allocated_tensors[j].tensor = tmp_tensor;\n                    alloc->allocated_tensors[j].addr = tmp_addr;\n                }\n            }\n        }\n        GGML_LOG_DEBUG(\"max_size[%d] = %.2f MB: tensors: \", 
addr.chunk, cur_max / 1024.0 / 1024.0);\n        for (int i = 0; i < 1024; i++) {\n            if (alloc->allocated_tensors[i].tensor) {\n                GGML_LOG_DEBUG(\"%s [%d: %zx-%zx] (%.2f MB) \", alloc->allocated_tensors[i].tensor->name,\n                    alloc->allocated_tensors[i].addr.chunk,\n                    alloc->allocated_tensors[i].addr.offset,\n                    alloc->allocated_tensors[i].addr.offset + ggml_nbytes(alloc->allocated_tensors[i].tensor),\n                    ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0);\n            }\n        }\n        GGML_LOG_DEBUG(\"\\n\");\n    }\n#endif\n\n    chunk->max_size = MAX(chunk->max_size, addr.offset + size);\n\n    return addr;\n\n    GGML_UNUSED(tensor);\n}\n\n// this is a very naive implementation, but for our case the number of free blocks should be very small\nstatic void ggml_dyn_tallocr_free_bytes(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, size_t size) {\n    size = aligned_offset(NULL, size, alloc->alignment);\n\n    struct tallocr_chunk * chunk = alloc->chunks[addr.chunk];\n\n    // see if we can merge with an existing block\n    for (int i = 0; i < chunk->n_free_blocks; i++) {\n        struct free_block * block = &chunk->free_blocks[i];\n        // check if ptr is at the end of the block\n        if (block->offset + block->size == addr.offset) {\n            block->size += size;\n            // check if we can merge with the next block\n            if (i < chunk->n_free_blocks - 1) {\n                struct free_block * next = &chunk->free_blocks[i+1];\n                if (block->offset + block->size == next->offset) {\n                    block->size += next->size;\n                    ggml_dyn_tallocr_remove_block(chunk, i+1);\n                }\n            }\n            return;\n        }\n        // check if ptr is at the beginning of the block\n        if (addr.offset + size == block->offset) {\n            block->offset = addr.offset;\n   
         block->size += size;\n            // check if we can merge with the previous block\n            if (i > 0) {\n                struct free_block * prev = &chunk->free_blocks[i-1];\n                if (prev->offset + prev->size == block->offset) {\n                    prev->size += block->size;\n                    ggml_dyn_tallocr_remove_block(chunk, i);\n                }\n            }\n            return;\n        }\n    }\n    // otherwise, add a new block\n    ggml_dyn_tallocr_insert_block(chunk, addr.offset, size);\n}\n\nstatic void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {\n    for (int i = 0; i < GGML_VBUFFER_MAX_CHUNKS; i++) {\n        free(alloc->chunks[i]);\n        alloc->chunks[i] = NULL;\n    }\n    alloc->n_chunks = 0;\n\n#ifdef GGML_ALLOCATOR_DEBUG\n    for (int i = 0; i < 1024; i++) {\n        alloc->allocated_tensors[i].tensor = NULL;\n    }\n#endif\n}\n\nstatic struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment, size_t max_buffer_size) {\n    struct ggml_dyn_tallocr * alloc = (struct ggml_dyn_tallocr *)malloc(sizeof(struct ggml_dyn_tallocr));\n\n    *alloc = (struct ggml_dyn_tallocr) {\n        /*.alignment      = */ alignment,\n        /*.max_chunk_size = */ MIN(max_buffer_size, SIZE_MAX/2), // clamp to avoid overflows\n        /*.chunks         = */ {NULL},\n        /*.n_chunks       = */ 0,\n#ifdef GGML_ALLOCATOR_DEBUG\n        /*.allocated_tensors = */ {{0}},\n#endif\n    };\n\n    ggml_dyn_tallocr_reset(alloc);\n\n    return alloc;\n}\n\nstatic void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {\n    for (int i = 0; i < alloc->n_chunks; ++i) {\n        free(alloc->chunks[i]);\n    }\n    free(alloc);\n}\n\nstatic size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) {\n    return chunk < alloc->n_chunks ? 
alloc->chunks[chunk]->max_size : 0;\n}\n\n\n// virtual buffer with contiguous memory range, split into multiple backend buffers (chunks)\n\nstruct vbuffer {\n    ggml_backend_buffer_t chunks[GGML_VBUFFER_MAX_CHUNKS];\n};\n\nstatic void ggml_vbuffer_free(struct vbuffer * buf) {\n    if (buf == NULL) {\n        return;\n    }\n    for (int i = 0; i < GGML_VBUFFER_MAX_CHUNKS; ++i) {\n        ggml_backend_buffer_free(buf->chunks[i]);\n    }\n    free(buf);\n}\n\nstatic size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) {\n    return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0;\n}\n\nstatic size_t ggml_vbuffer_size(struct vbuffer * buf) {\n    size_t size = 0;\n    for (int i = 0; i < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[i]; ++i) {\n        size += ggml_backend_buffer_get_size(buf->chunks[i]);\n    }\n    return size;\n}\n\nstatic struct vbuffer * ggml_vbuffer_alloc(ggml_backend_buffer_type_t buft, const struct ggml_dyn_tallocr * talloc, enum ggml_backend_buffer_usage usage) {\n    struct vbuffer * buf = (struct vbuffer *)calloc(1, sizeof(struct vbuffer));\n    if (buf == NULL) {\n        return NULL;\n    }\n\n    for (int n = 0; n < talloc->n_chunks; n++) {\n        size_t chunk_size = talloc->chunks[n]->max_size;\n        buf->chunks[n] = ggml_backend_buft_alloc_buffer(buft, chunk_size);\n        if (buf->chunks[n] == NULL) {\n            ggml_vbuffer_free(buf);\n            return NULL;\n        }\n        ggml_backend_buffer_set_usage(buf->chunks[n], usage);\n    }\n    return buf;\n}\n\nstatic void ggml_vbuffer_tensor_alloc(struct vbuffer * buf, struct ggml_tensor * tensor, struct buffer_address buf_addr) {\n    void * base = ggml_backend_buffer_get_base(buf->chunks[buf_addr.chunk]);\n    void * addr = (char *)base + buf_addr.offset;\n    ggml_backend_tensor_alloc(buf->chunks[buf_addr.chunk], tensor, addr);\n}\n\nstatic void ggml_vbuffer_reset(struct vbuffer * buf) {\n    for (int i = 0; i < GGML_VBUFFER_MAX_CHUNKS && 
buf->chunks[i]; ++i) {\n        ggml_backend_buffer_reset(buf->chunks[i]);\n    }\n}\n\n\n/////////////////////////////////////\n\n// graph allocator\n\nstruct hash_node {\n    int n_children;\n    int n_views;\n    int buffer_id;\n    struct buffer_address addr;\n    bool allocated;\n};\n\nstruct tensor_alloc {\n    int buffer_id;\n    struct buffer_address addr;\n    size_t size_max; // 0 = pre-allocated, unused, or view\n};\n\nstruct leaf_alloc {\n    struct tensor_alloc leaf;\n};\n\nstruct node_alloc {\n    struct tensor_alloc dst;\n    struct tensor_alloc src[GGML_MAX_SRC];\n};\n\nstruct ggml_gallocr {\n    ggml_backend_buffer_type_t * bufts; // [n_buffers]\n    struct vbuffer ** buffers; // [n_buffers]\n    struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]\n    int n_buffers;\n\n    struct ggml_hash_set hash_set;\n    struct hash_node * hash_values; // [hash_set.size]\n\n    struct node_alloc * node_allocs; // [n_nodes]\n    int n_nodes;\n\n    struct leaf_alloc * leaf_allocs; // [n_leafs]\n    int n_leafs;\n};\n\nggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs) {\n    ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(1, sizeof(struct ggml_gallocr));\n    GGML_ASSERT(galloc != NULL);\n\n    galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t));\n    GGML_ASSERT(galloc->bufts != NULL);\n\n    galloc->buffers = calloc(n_bufs, sizeof(struct vbuffer *));\n    GGML_ASSERT(galloc->buffers != NULL);\n\n    galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));\n    GGML_ASSERT(galloc->buf_tallocs != NULL);\n\n    for (int i = 0; i < n_bufs; i++) {\n        galloc->bufts[i] = bufts[i];\n        galloc->buffers[i] = NULL;\n\n        // check if the same buffer type is used multiple times and reuse the same allocator\n        for (int j = 0; j < i; j++) {\n            if (bufts[i] == bufts[j]) {\n                galloc->buf_tallocs[i] = galloc->buf_tallocs[j];\n                break;\n            }\n     
   }\n\n        if (galloc->buf_tallocs[i] == NULL) {\n            size_t alignment = ggml_backend_buft_get_alignment(bufts[i]);\n            size_t max_size = ggml_backend_buft_get_max_size(bufts[i]);\n            galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment, max_size);\n        }\n    }\n    galloc->n_buffers = n_bufs;\n\n    return galloc;\n}\n\nggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft) {\n    return ggml_gallocr_new_n(&buft, 1);\n}\n\nvoid ggml_gallocr_free(ggml_gallocr_t galloc) {\n    if (galloc == NULL) {\n        return;\n    }\n\n    for (int i = 0; i < galloc->n_buffers; i++) {\n        if (galloc->buffers != NULL) {\n            // skip if already freed\n            bool freed = false;\n            for (int j = 0; j < i; j++) {\n                if (galloc->buffers[j] == galloc->buffers[i]) {\n                    freed = true;\n                    break;\n                }\n            }\n            if (!freed) {\n                ggml_vbuffer_free(galloc->buffers[i]);\n            }\n        }\n        if (galloc->buf_tallocs != NULL) {\n            // skip if already freed\n            bool freed = false;\n            for (int j = 0; j < i; j++) {\n                if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {\n                    freed = true;\n                    break;\n                }\n            }\n            if (!freed) {\n                ggml_dyn_tallocr_free(galloc->buf_tallocs[i]);\n            }\n        }\n    }\n\n    ggml_hash_set_free(&galloc->hash_set);\n    free(galloc->hash_values);\n    free(galloc->bufts);\n    free(galloc->buffers);\n    free(galloc->buf_tallocs);\n    free(galloc->node_allocs);\n    free(galloc->leaf_allocs);\n    free(galloc);\n}\n\ntypedef struct ggml_gallocr * ggml_gallocr_t;\n\nstatic struct hash_node * ggml_gallocr_hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {\n    size_t i = ggml_hash_find_or_insert(&galloc->hash_set, t);\n    return 
&galloc->hash_values[i];\n}\n\nstatic bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {\n    return ggml_gallocr_hash_get(galloc, t)->allocated;\n}\n\nstatic bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {\n    return t->data != NULL // tensor data already set externally\n        || t->buffer // tensor on external buffer (but not yet allocated)\n        || ggml_gallocr_is_own(galloc, t); // tensor will be allocated by galloc\n}\n\n// free the extra space at the end if the new tensor is smaller\nstatic void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) {\n    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);\n    struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);\n\n    size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);\n    size_t node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);\n\n    GGML_ASSERT(parent_size >= node_size);\n\n    // note: we want after the freeing the chunks to continue to be aligned\n    struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];\n    parent_size = aligned_offset(NULL, parent_size, p_alloc->alignment);\n    node_size = aligned_offset(NULL, node_size, p_alloc->alignment);\n\n    if (parent_size > node_size) {\n        struct buffer_address p_addr = p_hn->addr;\n        p_addr.offset += node_size;\n        size_t extra_size = parent_size - node_size;\n        AT_PRINTF(\"freeing extra %zu bytes from parent %s for %s\\n\", extra_size, parent->name, node->name);\n        ggml_dyn_tallocr_free_bytes(p_alloc, p_addr, extra_size);\n    }\n}\n\nstatic void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {\n    GGML_ASSERT(buffer_id >= 0);\n    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);\n\n    if (!ggml_gallocr_is_allocated(galloc, node) && 
!ggml_impl_is_view(node)) {\n        hn->allocated = true;\n        assert(hn->addr.offset == 0);\n\n        // try to reuse a parent's buffer (inplace)\n        if (ggml_op_can_inplace(node->op)) {\n            for (int i = 0; i < GGML_MAX_SRC; i++) {\n                struct ggml_tensor * parent = node->src[i];\n                if (parent == NULL) {\n                    continue;\n                }\n\n                // if the node's data is external, then we cannot re-use it\n                if (!ggml_gallocr_is_own(galloc, parent)) {\n                    AT_PRINTF(\"not reusing parent %s for %s as %p is external\\n\", parent->name, node->name, parent->data);\n                    continue;\n                }\n\n                // outputs cannot be reused\n                if (parent->flags & GGML_TENSOR_FLAG_OUTPUT || (parent->view_src != NULL && parent->view_src->flags & GGML_TENSOR_FLAG_OUTPUT)) {\n                    AT_PRINTF(\"not reusing parent %s for %s as it is an output\\n\", parent->name, node->name);\n                    continue;\n                }\n\n                if (!ggml_are_same_layout(node, parent)) {\n                    AT_PRINTF(\"not reusing parent %s for %s as layouts are different\\n\", parent->name, node->name);\n                    continue;\n                }\n\n                struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);\n                if (p_hn->n_children == 1 && p_hn->n_views == 0) {\n                    if (ggml_impl_is_view(parent)) {\n                        struct ggml_tensor * view_src = parent->view_src;\n                        struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);\n                        if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {\n                            AT_PRINTF(\"reusing view parent %s (%s) for %s\\n\", parent->name, view_src->name, node->name);\n                            assert(view_src_hn->addr.chunk 
== p_hn->addr.chunk && view_src_hn->addr.offset == p_hn->addr.offset);\n                            hn->buffer_id = p_hn->buffer_id;\n                            hn->addr = p_hn->addr;\n                            p_hn->allocated = false; // avoid freeing the parent\n                            view_src_hn->allocated = false;\n                            ggml_gallocr_free_extra_space(galloc, node, view_src);\n                            return;\n                        }\n                    } else {\n                        AT_PRINTF(\"reusing parent %s for %s\\n\", parent->name, node->name);\n                        hn->buffer_id = p_hn->buffer_id;\n                        hn->addr = p_hn->addr;\n                        p_hn->allocated = false; // avoid freeing the parent\n                        ggml_gallocr_free_extra_space(galloc, node, parent);\n                        return;\n                    }\n                }\n            }\n        }\n        // allocate tensor from the buffer\n        struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];\n        ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];\n        size_t size = ggml_backend_buft_get_alloc_size(buft, node);\n        hn->buffer_id = buffer_id;\n        hn->addr = ggml_dyn_tallocr_alloc(alloc, size, node);\n    }\n}\n\nstatic void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {\n    // graph outputs are never freed\n    if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {\n        AT_PRINTF(\"not freeing output %s\\n\", node->name);\n        return;\n    }\n\n    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);\n    int buffer_id = hn->buffer_id;\n    struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];\n    ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];\n    size_t size = ggml_backend_buft_get_alloc_size(buft, node);\n\n    AT_PRINTF(\"%s: freeing %s at {chunk=%d, offset=%zu} (%zu bytes) - n_free_blocks = %d\\n\",\n 
       __func__, node->name, hn->addr.chunk, hn->addr.offset, size, alloc->chunks[hn->addr.chunk]->n_free_blocks);\n#ifdef GGML_ALLOCATOR_DEBUG\n    remove_allocated_tensor(alloc, hn->addr, node);\n#endif\n\n    ggml_dyn_tallocr_free_bytes(alloc, hn->addr, size);\n    hn->allocated = false;\n}\n\nstatic int get_node_buffer_id(const int * node_buffer_ids, int i) {\n    return node_buffer_ids ? node_buffer_ids[i] : 0;\n}\n\nstatic void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {\n    // clear hash tables\n    ggml_hash_set_reset(&galloc->hash_set);\n    memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size);\n\n    // allocate leafs\n    // these may be tensors that the application is not using in the graph, but may still want to allocate for other purposes\n    for (int i = 0; i < graph->n_leafs; i++) {\n        struct ggml_tensor * leaf = graph->leafs[i];\n        ggml_gallocr_allocate_node(galloc, leaf, get_node_buffer_id(leaf_buffer_ids, i));\n    }\n\n    // count number of children and views\n    // allocate other graph inputs and leafs first to avoid overwriting them\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n\n        // TODO: better way to add external dependencies\n        // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to\n        // control when some tensors are allocated and freed. 
in this case, the dependencies are in `src`, but the node\n        // itself is never used and should not be considered a dependency\n        if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {\n            struct ggml_tensor * view_src = node->view_src;\n            ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;\n        }\n\n        if (node->flags & GGML_TENSOR_FLAG_INPUT) {\n            ggml_gallocr_allocate_node(galloc, graph->nodes[i], get_node_buffer_id(node_buffer_ids, i));\n        }\n\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            struct ggml_tensor * src = node->src[j];\n            if (src == NULL) {\n                continue;\n            }\n\n            ggml_gallocr_hash_get(galloc, src)->n_children += 1;\n\n            // allocate explicit inputs\n            if (src->flags & GGML_TENSOR_FLAG_INPUT) {\n                ggml_gallocr_allocate_node(galloc, src, get_node_buffer_id(node_buffer_ids, i));\n            }\n        }\n    }\n\n    // allocate tensors\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        int buffer_id = get_node_buffer_id(node_buffer_ids, i);\n\n        // allocate parents (only leafs need to be allocated at this point)\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            struct ggml_tensor * parent = node->src[j];\n            if (parent == NULL) {\n                continue;\n            }\n            ggml_gallocr_allocate_node(galloc, parent, buffer_id);\n        }\n\n        // allocate node\n        ggml_gallocr_allocate_node(galloc, node, buffer_id);\n\n        AT_PRINTF(\"exec: %s (%s) <= \", ggml_op_desc(node), node->name);\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            struct ggml_tensor * parent = node->src[j];\n            if (parent == NULL) {\n                continue;\n            }\n            AT_PRINTF(\"%s\", parent->name);\n            if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {\n      
          AT_PRINTF(\", \");\n            }\n        }\n        AT_PRINTF(\"\\n\");\n\n        // update parents\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            struct ggml_tensor * parent = node->src[j];\n            if (parent == NULL) {\n                continue;\n            }\n            struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);\n            p_hn->n_children -= 1;\n\n            AT_PRINTF(\"parent %s: %d children, %d views, allocated: %d\\n\",\n                parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);\n\n            if (p_hn->n_children == 0 && p_hn->n_views == 0) {\n                if (ggml_impl_is_view(parent)) {\n                    struct ggml_tensor * view_src = parent->view_src;\n                    struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);\n                    view_src_hn->n_views -= 1;\n                    AT_PRINTF(\"view_src %s: %d children, %d views\\n\",\n                        view_src->name, view_src_hn->n_children, view_src_hn->n_views);\n                    if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) {\n                        ggml_gallocr_free_node(galloc, view_src);\n                    }\n                }\n                else if (p_hn->allocated) {\n                    ggml_gallocr_free_node(galloc, parent);\n                }\n            }\n            AT_PRINTF(\"\\n\");\n        }\n    }\n}\n\nstatic bool ggml_gallocr_reserve_n_impl(\n        ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids, bool no_alloc) {\n    size_t min_hash_size = graph->n_nodes + graph->n_leafs;\n    // add 25% margin to avoid hash collisions\n    min_hash_size += min_hash_size / 4;\n\n    // initialize hash table\n    if (galloc->hash_set.size < min_hash_size) {\n        ggml_hash_set_free(&galloc->hash_set);\n        galloc->hash_set = 
ggml_hash_set_new(min_hash_size);\n        GGML_ASSERT(galloc->hash_set.keys != NULL);\n\n        free(galloc->hash_values);\n        galloc->hash_values = malloc(sizeof(struct hash_node) * galloc->hash_set.size);\n        GGML_ASSERT(galloc->hash_values != NULL);\n    }\n\n    // reset allocators\n    for (int i = 0; i < galloc->n_buffers; i++) {\n        ggml_dyn_tallocr_reset(galloc->buf_tallocs[i]);\n    }\n\n    // allocate in hash table\n    ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids);\n\n    // set the node_allocs from the hash table\n    if (galloc->n_nodes < graph->n_nodes) {\n        free(galloc->node_allocs);\n        galloc->node_allocs = calloc(graph->n_nodes, sizeof(struct node_alloc));\n        GGML_ASSERT(galloc->node_allocs != NULL);\n    }\n    galloc->n_nodes = graph->n_nodes;\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        struct node_alloc * node_alloc = &galloc->node_allocs[i];\n        if (node->view_src || node->data) {\n            node_alloc->dst.buffer_id = -1;\n            node_alloc->dst.addr = GGML_BUFFER_ADDRESS_INVALID;\n            node_alloc->dst.size_max = 0;\n        } else {\n            struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);\n            node_alloc->dst.buffer_id = hn->buffer_id;\n            node_alloc->dst.addr = hn->addr;\n            node_alloc->dst.size_max  = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);\n        }\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            struct ggml_tensor * src = node->src[j];\n            if (!src || src->view_src || src->data) {\n                node_alloc->src[j].buffer_id = -1;\n                node_alloc->src[j].addr = GGML_BUFFER_ADDRESS_INVALID;\n                node_alloc->src[j].size_max = 0;\n            } else {\n                struct hash_node * hn = ggml_gallocr_hash_get(galloc, src);\n                
node_alloc->src[j].buffer_id = hn->buffer_id;\n                node_alloc->src[j].addr = hn->addr;\n                node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src);\n            }\n        }\n    }\n    if (galloc->n_leafs < graph->n_leafs) {\n        free(galloc->leaf_allocs);\n        galloc->leaf_allocs = calloc(graph->n_leafs, sizeof(galloc->leaf_allocs[0]));\n        GGML_ASSERT(galloc->leaf_allocs != NULL);\n    }\n    galloc->n_leafs = graph->n_leafs;\n    for (int i = 0; i < graph->n_leafs; i++) {\n        struct ggml_tensor * leaf = graph->leafs[i];\n        struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);\n        if (leaf->view_src || leaf->data) {\n            galloc->leaf_allocs[i].leaf.buffer_id = -1;\n            galloc->leaf_allocs[i].leaf.addr = GGML_BUFFER_ADDRESS_INVALID;\n            galloc->leaf_allocs[i].leaf.size_max = 0;\n        } else {\n            galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id;\n            galloc->leaf_allocs[i].leaf.addr = hn->addr;\n            galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf);\n        }\n    }\n\n    // reallocate buffers if needed\n    for (int i = 0; i < galloc->n_buffers; i++) {\n        // if the buffer type is used multiple times, we reuse the same buffer\n        for (int j = 0; j < i; j++) {\n            if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {\n                galloc->buffers[i] = galloc->buffers[j];\n                break;\n            }\n        }\n\n        // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views\n        bool realloc = galloc->buffers[i] == NULL;\n        size_t new_size = 0;\n        for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {\n            size_t cur_chunk_size = galloc->buffers[i] ? 
ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0;\n            size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c);\n            new_size += new_chunk_size;\n            if (new_chunk_size > cur_chunk_size) {\n                realloc = true;\n            }\n        }\n        if (realloc) {\n#ifndef NDEBUG\n            {\n                size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;\n                if (cur_size > 0) {\n                    GGML_LOG_DEBUG(\"%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\\n\",\n                        __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);\n                }\n            }\n#endif\n            ggml_vbuffer_free(galloc->buffers[i]);\n            if (no_alloc) {\n                galloc->buffers[i] = NULL;\n            } else {\n                galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);\n                if (galloc->buffers[i] == NULL) {\n                    GGML_LOG_ERROR(\"%s: failed to allocate %s buffer of size %zu\\n\", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);\n                    return false;\n                }\n            }\n        }\n    }\n\n    return true;\n}\n\nvoid ggml_gallocr_reserve_n_size(\n        ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids, size_t * sizes) {\n    GGML_ASSERT(ggml_gallocr_reserve_n_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids, /*no_alloc =*/ true));\n    for (int i = 0; i < galloc->n_buffers; i++) {\n        sizes[i] = 0;\n        for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {\n            sizes[i] += galloc->buf_tallocs[i]->chunks[c]->max_size;\n        }\n    }\n}\n\nbool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, 
const int * leaf_buffer_ids) {\n    return ggml_gallocr_reserve_n_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids, /*no_alloc =*/ false);\n}\n\nbool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {\n    return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);\n}\n\nstatic void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) {\n    int buffer_id = tensor_alloc->buffer_id;\n    assert(tensor->data || tensor->view_src || ggml_backend_buft_get_alloc_size(galloc->bufts[buffer_id], tensor) <= tensor_alloc->size_max);\n\n    if (tensor->view_src != NULL) {\n        if (tensor->buffer == NULL) {\n            assert(tensor_alloc->addr.offset == SIZE_MAX);\n            if (tensor->view_src->buffer == NULL) {\n                // this tensor was allocated without ggml-backend\n                return;\n            }\n            ggml_backend_view_init(tensor);\n        }\n    } else {\n        if (tensor->data == NULL) {\n            assert(tensor_alloc->addr.offset != SIZE_MAX);\n            assert(ggml_backend_buft_get_alloc_size(galloc->bufts[buffer_id], tensor) <= tensor_alloc->size_max);\n            ggml_vbuffer_tensor_alloc(galloc->buffers[buffer_id], tensor, tensor_alloc->addr);\n        } else {\n            if (tensor->buffer == NULL) {\n                // this tensor was allocated without ggml-backend\n                return;\n            }\n        }\n    }\n}\n\nstatic bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {\n    size_t node_size = 0;\n    if (!node->data && !node->view_src) {\n        // If we previously had data but don't now then reallocate\n        if (talloc->buffer_id < 0) {\n            return false;\n        }\n        node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);\n    }\n    return talloc->size_max >= node_size;\n}\n\nstatic bool 
ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {\n    if (galloc->n_nodes != graph->n_nodes) {\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: graph has different number of nodes\\n\", __func__);\n#endif\n        return true;\n    }\n\n    if (galloc->n_leafs != graph->n_leafs) {\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: graph has different number of leafs\\n\", __func__);\n#endif\n        return true;\n    }\n\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        struct node_alloc * node_alloc = &galloc->node_allocs[i];\n\n        if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) {\n#ifndef NDEBUG\n            GGML_LOG_DEBUG(\"%s: node %s is not valid\\n\", __func__, node->name);\n#endif\n            return true;\n        }\n\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            struct ggml_tensor * src = node->src[j];\n            if (src == NULL) {\n                continue;\n            }\n            if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) {\n#ifndef NDEBUG\n                GGML_LOG_DEBUG(\"%s: src %d (%s) of node %s is not valid\\n\", __func__, j, src->name, node->name);\n#endif\n                return true;\n            }\n        }\n    }\n\n    return false;\n}\n\nbool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {\n    if (ggml_gallocr_needs_realloc(galloc, graph)) {\n        if (galloc->n_buffers == 1) {\n#ifndef NDEBUG\n            GGML_LOG_DEBUG(\"%s: reallocating buffers automatically\\n\", __func__);\n#endif\n            if (!ggml_gallocr_reserve(galloc, graph)) {\n                return false;\n            }\n        } else {\n#ifndef NDEBUG\n            GGML_LOG_DEBUG(\"%s: cannot reallocate multi buffer graph automatically, call reserve\\n\", __func__);\n#endif\n            return false;\n        }\n    }\n\n    // reset buffers\n    for (int i = 0; i < 
galloc->n_buffers; i++) {\n        if (galloc->buffers[i] != NULL) {\n            ggml_vbuffer_reset(galloc->buffers[i]);\n        }\n    }\n\n    // allocate the graph tensors from the previous assignments\n    // leafs\n    for (int i = 0; i < graph->n_leafs; i++) {\n        struct ggml_tensor * leaf = graph->leafs[i];\n        struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i];\n        ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf);\n    }\n    // nodes\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        struct node_alloc * node_alloc = &galloc->node_allocs[i];\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            struct ggml_tensor * src = node->src[j];\n            if (src == NULL) {\n                continue;\n            }\n            ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]);\n        }\n        ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst);\n    }\n\n    return true;\n}\n\nsize_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {\n    GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);\n\n    if (galloc->buffers[buffer_id] == NULL) {\n        return 0;\n    }\n\n    for (int i = 0; i < buffer_id; i++) {\n        if (galloc->buffers[i] == galloc->buffers[buffer_id]) {\n            // this buffer is the same as a previous one due to the same buffer type being used multiple times\n            // only return the buffer size the first time it appears to avoid double counting\n            return 0;\n        }\n    }\n\n    return ggml_vbuffer_size(galloc->buffers[buffer_id]);\n}\n\n// utils\n\nstatic void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {\n    for (size_t i = 0; i < *n_buffers; i++) {\n        ggml_backend_buffer_free((*buffers)[i]);\n    }\n    free(*buffers);\n}\n\nstatic bool alloc_tensor_range(struct ggml_context * ctx,\n        struct ggml_tensor * first, struct ggml_tensor * 
last,\n        ggml_backend_buffer_type_t buft, size_t size,\n        ggml_backend_buffer_t ** buffers, size_t * n_buffers) {\n\n    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);\n    if (buffer == NULL) {\n        GGML_LOG_ERROR(\"%s: failed to allocate %s buffer of size %zu\\n\", __func__, ggml_backend_buft_name(buft), size);\n        free_buffers(buffers, n_buffers);\n        return false;\n    }\n\n    *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1));\n    (*buffers)[(*n_buffers)++] = buffer;\n\n    struct ggml_tallocr tallocr = ggml_tallocr_new(buffer);\n\n    for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) {\n        enum ggml_status status = GGML_STATUS_SUCCESS;\n        if (t->data == NULL) {\n            if (t->view_src == NULL) {\n                status = ggml_tallocr_alloc(&tallocr, t);\n            } else if (t->buffer == NULL) {\n                status = ggml_backend_view_init(t);\n            }\n        } else {\n            if (t->view_src != NULL && t->buffer == NULL) {\n                // view of a pre-allocated tensor\n                status = ggml_backend_view_init(t);\n            }\n        }\n        if (status != GGML_STATUS_SUCCESS) {\n            GGML_LOG_ERROR(\"%s: failed to initialize tensor %s\\n\", __func__, t->name);\n            free_buffers(buffers, n_buffers);\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft_impl(\n        struct ggml_context * ctx, ggml_backend_buffer_type_t buft, size_t * nbytes_total, bool no_alloc) {\n    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);\n\n    size_t alignment = ggml_backend_buft_get_alignment(buft);\n    size_t max_size = ggml_backend_buft_get_max_size(buft);\n\n    ggml_backend_buffer_t * buffers = NULL;\n    size_t n_buffers = 0;\n    *nbytes_total = 0;\n\n    size_t cur_buf_size = 0;\n    struct ggml_tensor * 
first = ggml_get_first_tensor(ctx);\n    for (struct ggml_tensor * t = first; t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n        size_t this_size = 0;\n        if (t->data == NULL && t->view_src == NULL) {\n            this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);\n        }\n\n        if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) {\n            // allocate tensors in the current buffer\n            if (!no_alloc && !alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {\n                return NULL;\n            }\n            first = t;\n            *nbytes_total += cur_buf_size;\n            cur_buf_size = this_size;\n        } else {\n            cur_buf_size += this_size;\n        }\n    }\n\n    // allocate remaining tensors\n    if (cur_buf_size > 0) {\n        *nbytes_total += cur_buf_size;\n        if (!no_alloc && !alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) {\n            return NULL;\n        }\n    }\n\n    if (no_alloc) {\n        return NULL;\n    }\n\n    if (n_buffers == 0) {\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: all tensors in the context are already allocated\\n\", __func__);\n#endif\n        GGML_ASSERT(!buffers);\n        return NULL;\n    }\n\n    ggml_backend_buffer_t buffer;\n    if (n_buffers == 1) {\n        buffer = buffers[0];\n    } else {\n        buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers);\n    }\n    if (buffers) {\n        free(buffers); // can be NULL if context is empty or no_alloc\n    }\n    return buffer;\n}\n\nsize_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {\n    size_t nbytes_total = 0;\n    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc=*/ true);\n    GGML_ASSERT(!buf);\n    return nbytes_total;\n}\n\nggml_backend_buffer_t 
ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {\n    size_t nbytes_total = 0;\n    return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false);\n}\n\nggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {\n    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));\n}\n"
  },
  {
    "path": "src/ggml-backend-dl.cpp",
    "content": "#include \"ggml-backend-dl.h\"\n\n#ifdef _WIN32\n\n#include <cstdio>\n\n// Error code of the last failed dl_load_library / dl_get_sym call on this thread (0 = none).\n// Captured immediately after the failing call so that dl_error() can describe the failure.\nstatic thread_local DWORD dl_last_error = 0;\n\n// Suppresses system error dialogs (e.g. for missing dependent DLLs) for the current thread while in scope.\n// SetThreadErrorMode is used instead of SetErrorMode because the latter changes the error mode of the\n// whole process, which races with other threads that load libraries or change the mode concurrently.\nstruct dl_error_mode_guard {\n    DWORD old_mode;\n    BOOL  changed;\n\n    dl_error_mode_guard() : old_mode(GetThreadErrorMode()) {\n        changed = SetThreadErrorMode(old_mode | SEM_FAILCRITICALERRORS, nullptr);\n    }\n\n    ~dl_error_mode_guard() {\n        if (changed) {\n            SetThreadErrorMode(old_mode, nullptr);\n        }\n    }\n\n    dl_error_mode_guard(const dl_error_mode_guard &) = delete;\n    dl_error_mode_guard & operator=(const dl_error_mode_guard &) = delete;\n};\n\ndl_handle * dl_load_library(const fs::path & path) {\n    dl_error_mode_guard guard;\n    HMODULE handle = LoadLibraryW(path.wstring().c_str());\n    // read the error code before the guard's destructor makes any other API call\n    dl_last_error = handle != nullptr ? 0 : GetLastError();\n    return handle;\n}\n\nvoid * dl_get_sym(dl_handle * handle, const char * name) {\n    dl_error_mode_guard guard;\n    void * p = (void *) GetProcAddress(handle, name);\n    dl_last_error = p != nullptr ? 0 : GetLastError();\n    return p;\n}\n\n// Describes the last failure recorded on this thread; returns an empty string if there is none.\n// The text lives in thread-local storage and is overwritten by the next call on the same thread.\nconst char * dl_error() {\n    static thread_local char buf[512];\n    if (dl_last_error == 0) {\n        return \"\";\n    }\n    DWORD n = FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,\n                             nullptr, dl_last_error, 0, buf, (DWORD) sizeof(buf), nullptr);\n    if (n == 0) {\n        snprintf(buf, sizeof(buf), \"error code %lu\", (unsigned long) dl_last_error);\n        return buf;\n    }\n    // system messages end with a line break; strip it so the caller's log line stays intact\n    while (n > 0 && (buf[n - 1] == '\\r' || buf[n - 1] == '\\n' || buf[n - 1] == ' ')) {\n        buf[--n] = '\\0';\n    }\n    return buf;\n}\n\n#else\n\ndl_handle * dl_load_library(const fs::path & path) {\n    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);\n    return handle;\n}\n\nvoid * dl_get_sym(dl_handle * handle, const char * name) {\n    return dlsym(handle, name);\n}\n\nconst char * dl_error() {\n    const char *rslt = dlerror();\n    return rslt != nullptr ? rslt : \"\";\n}\n\n#endif\n"
  },
  {
    "path": "src/ggml-backend-dl.h",
    "content": "#pragma once\n\n#ifdef _WIN32\n#   ifndef WIN32_LEAN_AND_MEAN\n#       define WIN32_LEAN_AND_MEAN\n#   endif\n#   ifndef NOMINMAX\n#       define NOMINMAX\n#   endif\n#   include <windows.h>\n#   include <winevt.h>\n#else\n#    include <dlfcn.h>\n#    include <unistd.h>\n#endif\n#include <filesystem>\n#include <memory>      // std::unique_ptr\n#include <type_traits> // std::remove_pointer_t\n\nnamespace fs = std::filesystem;\n\n#ifdef _WIN32\n\nusing dl_handle = std::remove_pointer_t<HMODULE>;\n\nstruct dl_handle_deleter {\n    void operator()(HMODULE handle) {\n        FreeLibrary(handle);\n    }\n};\n\n#else\n\nusing dl_handle = void;\n\nstruct dl_handle_deleter {\n    void operator()(void * handle) {\n        dlclose(handle);\n    }\n};\n\n#endif\n\n// owning library handle: closed with FreeLibrary/dlclose when destroyed, unless released first\nusing dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;\n\n// load a shared library; returns nullptr on failure\ndl_handle * dl_load_library(const fs::path & path);\n// look up an exported symbol; returns nullptr if it is not found\nvoid * dl_get_sym(dl_handle * handle, const char * name);\n// description of the most recent loader error; never nullptr, may be empty\nconst char * dl_error();\n"
  },
  {
    "path": "src/ggml-backend-impl.h",
    "content": "#pragma once\n\n// ggml-backend internal header\n\n#include \"ggml-backend.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\n    #define GGML_BACKEND_API_VERSION 2\n\n    //\n    // Backend buffer type\n    //\n\n    struct ggml_backend_buffer_type_i {\n        const char *          (*get_name)      (ggml_backend_buffer_type_t buft);\n        // allocate a buffer of this type\n        ggml_backend_buffer_t (*alloc_buffer)  (ggml_backend_buffer_type_t buft, size_t size);\n        // tensor alignment\n        size_t                (*get_alignment) (ggml_backend_buffer_type_t buft);\n        // (optional) max buffer size that can be allocated (defaults to SIZE_MAX)\n        size_t                (*get_max_size)  (ggml_backend_buffer_type_t buft);\n        // (optional) data size needed to allocate the tensor, including padding (defaults to ggml_nbytes)\n        size_t                (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);\n        // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)\n        bool                  (*is_host)       (ggml_backend_buffer_type_t buft);\n    };\n\n    struct ggml_backend_buffer_type {\n        struct ggml_backend_buffer_type_i  iface;\n        ggml_backend_dev_t device;\n        void * context;\n    };\n\n    //\n    // Backend buffer\n    //\n\n    struct ggml_backend_buffer_i {\n        // (optional) free the buffer\n        void         (*free_buffer)  (ggml_backend_buffer_t buffer);\n        // base address of the buffer\n        void *       (*get_base)     (ggml_backend_buffer_t buffer);\n        // (optional) initialize a tensor in the buffer (eg. add tensor extras)\n        enum ggml_status (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);\n        // tensor data access\n        void         (*memset_tensor)(ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);\n        void         (*set_tensor)   (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);\n        void         (*get_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);\n        // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported)\n        bool         (*cpy_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);\n        // clear the entire buffer\n        void         (*clear)        (ggml_backend_buffer_t buffer, uint8_t value);\n        // (optional) reset any internal state due to tensor initialization, such as tensor extras\n        void         (*reset)        (ggml_backend_buffer_t buffer);\n    };\n\n    struct ggml_backend_buffer {\n        struct ggml_backend_buffer_i  iface;\n        ggml_backend_buffer_type_t    buft;\n        void * context;\n        size_t size;\n        enum ggml_backend_buffer_usage usage;\n    };\n\n    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(\n                   ggml_backend_buffer_type_t buft,\n            struct ggml_backend_buffer_i      iface,\n                   void *                     context,\n                   size_t                     size);\n\n    // do not use directly, use ggml_backend_tensor_copy instead\n    GGML_API bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);\n\n    // multi-buffer\n    // buffer that contains a collection of buffers\n    GGML_API ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers);\n    GGML_API bool                  ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);\n    GGML_API void                  ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);\n\n    //\n    // Backend (stream)\n    //\n\n    struct ggml_backend_i {\n        const char * (*get_name)(ggml_backend_t backend);\n\n        void (*free)(ggml_backend_t backend);\n\n        // (optional) asynchronous tensor data access\n        void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);\n        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);\n        bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);\n\n        // (optional) complete all pending operations (required if the backend supports async operations)\n        void (*synchronize)(ggml_backend_t backend);\n\n        // (optional) graph plans (not used currently)\n        // compute graph with a plan\n        ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);\n        void                      (*graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);\n        // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology\n        void                      (*graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);\n        // compute the graph with the plan\n        enum ggml_status          (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);\n\n        // compute graph (always async if supported by the backend)\n        enum ggml_status          (*graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);\n\n        // (optional) event synchronization\n        // record an event on this stream\n        void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);\n        // wait for an event on a different stream\n        void (*event_wait)  (ggml_backend_t backend, ggml_backend_event_t event);\n\n        // (optional) sort/optimize the nodes in the graph\n        void                      (*graph_optimize)    (ggml_backend_t backend, struct ggml_cgraph * cgraph);\n    };\n\n    struct ggml_backend {\n        ggml_guid_t guid;\n        struct ggml_backend_i iface;\n        ggml_backend_dev_t device;\n        void * context;\n    };\n\n    struct ggml_backend_event {\n        struct ggml_backend_device * device;\n        void * context;\n    };\n\n    //\n    // Backend device\n    //\n\n    // Note: if additional properties are needed, we should add a struct with all of them\n    //       the current functions to obtain the properties can remain, since they are more convenient for often used properties\n    struct ggml_backend_device_i {\n        // device name: short identifier for this device, such as \"CPU\" or \"CUDA0\"\n        const char * (*get_name)(ggml_backend_dev_t dev);\n\n        // device description: short informative description of the device, could be the model name\n        const char * (*get_description)(ggml_backend_dev_t dev);\n\n        // device memory in bytes: 0 bytes to indicate no memory to report\n        void         (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);\n\n        // device type\n        enum ggml_backend_dev_type (*get_type)(ggml_backend_dev_t dev);\n\n        // device properties\n        void (*get_props)(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props);\n\n        // backend (stream) initialization\n        ggml_backend_t (*init_backend)(ggml_backend_dev_t dev, const char * params);\n\n        // preferred buffer type\n        ggml_backend_buffer_type_t (*get_buffer_type)(ggml_backend_dev_t dev);\n\n        // (optional) host buffer type (in system memory, typically this is a pinned memory buffer for faster transfers between host and device)\n        ggml_backend_buffer_type_t (*get_host_buffer_type)(ggml_backend_dev_t dev);\n\n        // (optional) buffer from pointer: create a buffer from a host pointer (useful for memory mapped models and importing data from other libraries)\n        ggml_backend_buffer_t (*buffer_from_host_ptr)(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size);\n\n        // check if the backend can compute an operation\n        bool (*supports_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op);\n\n        // check if the backend can use tensors allocated in a buffer type\n        bool (*supports_buft)(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft);\n\n        // (optional) check if the backend wants to run an operation, even if the weights are allocated in an incompatible buffer\n        // these should be expensive operations that may benefit from running on this backend instead of the CPU backend\n        bool (*offload_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op);\n\n        // (optional) event synchronization\n        ggml_backend_event_t (*event_new)         (ggml_backend_dev_t dev);\n        void                 (*event_free)        (ggml_backend_dev_t dev, ggml_backend_event_t event);\n        void                 (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);\n    };\n\n    struct ggml_backend_device {\n        struct ggml_backend_device_i iface;\n        ggml_backend_reg_t reg;\n        void * context;\n    };\n\n    //\n    // Backend (reg)\n    //\n\n    struct ggml_backend_reg_i {\n        const char * (*get_name)(ggml_backend_reg_t reg);\n\n        // enumerate available devices\n        size_t             (*get_device_count)(ggml_backend_reg_t reg);\n        ggml_backend_dev_t (*get_device)(ggml_backend_reg_t reg, size_t index);\n\n        // (optional) get a pointer to a function in the backend\n        // backends can add custom functions that are not part of the standard ggml-backend interface\n        void * (*get_proc_address)(ggml_backend_reg_t reg, const char * name);\n    };\n\n    struct ggml_backend_reg {\n        int api_version; // initialize to GGML_BACKEND_API_VERSION\n        struct ggml_backend_reg_i iface;\n        void * context;\n    };\n\n    // Add backend dynamic loading support to the backend\n\n    // Initialize the backend\n    typedef ggml_backend_reg_t (*ggml_backend_init_t)(void);\n    // Optional: obtain a score for the backend based on the system configuration\n    // Higher scores are preferred, 0 means the backend is not supported in the current system\n    typedef int                (*ggml_backend_score_t)(void);\n\n#ifdef GGML_BACKEND_DL\n#    ifdef __cplusplus\n#        define GGML_BACKEND_DL_IMPL(reg_fn)                             \\\n            extern \"C\" {                                                 \\\n            GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void); \\\n            }                                                            \\\n            ggml_backend_reg_t ggml_backend_init(void) {                 \\\n                return reg_fn();                                         \\\n            }\n#        define GGML_BACKEND_DL_SCORE_IMPL(score_fn)       \\\n            extern \"C\" {                                   \\\n            GGML_BACKEND_API int ggml_backend_score(void); \\\n            }                                              \\\n            int ggml_backend_score(void) {                 \\\n                return score_fn();                         \\\n            }\n#    else\n#        define GGML_BACKEND_DL_IMPL(reg_fn)                              \\\n            GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void);  \\\n            ggml_backend_reg_t                  ggml_backend_init(void) { \\\n                return reg_fn();                                          \\\n            }\n#        define GGML_BACKEND_DL_SCORE_IMPL(score_fn)        \\\n            GGML_BACKEND_API int ggml_backend_score(void);  \\\n            int                  ggml_backend_score(void) { \\\n                return score_fn();                          \\\n            }\n#    endif\n#else\n#    define GGML_BACKEND_DL_IMPL(reg_fn)\n#    define GGML_BACKEND_DL_SCORE_IMPL(score_fn)\n#endif\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-backend-reg.cpp",
    "content": "#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-backend-dl.h\"\n#include \"ggml-impl.h\"\n#include <algorithm>\n#include <cstring>\n#include <filesystem>\n#include <memory>\n#include <string>\n#include <type_traits>\n#include <vector>\n#include <cctype>\n\n#ifdef _WIN32\n#    define WIN32_LEAN_AND_MEAN\n#    ifndef NOMINMAX\n#        define NOMINMAX\n#    endif\n#    include <windows.h>\n#elif defined(__APPLE__)\n#    include <mach-o/dyld.h>\n#    include <dlfcn.h>\n#else\n#    include <dlfcn.h>\n#    include <unistd.h>\n#endif\n\n// Backend registry\n#ifdef GGML_USE_CPU\n#include \"ggml-cpu.h\"\n#endif\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#ifdef GGML_USE_SYCL\n#include \"ggml-sycl.h\"\n#endif\n\n#ifdef GGML_USE_VULKAN\n#include \"ggml-vulkan.h\"\n#endif\n\n#ifdef GGML_USE_WEBGPU\n#include \"ggml-webgpu.h\"\n#endif\n\n#ifdef GGML_USE_ZDNN\n#include \"ggml-zdnn.h\"\n#endif\n\n#ifdef GGML_USE_OPENCL\n#include \"ggml-opencl.h\"\n#endif\n\n#ifdef GGML_USE_HEXAGON\n#include \"ggml-hexagon.h\"\n#endif\n\n#ifdef GGML_USE_BLAS\n#include \"ggml-blas.h\"\n#endif\n\n#ifdef GGML_USE_RPC\n#include \"ggml-rpc.h\"\n#endif\n\n#ifdef GGML_USE_VIRTGPU_FRONTEND\n#include \"ggml-virtgpu.h\"\n#endif\n\n#ifdef GGML_USE_CANN\n#include \"ggml-cann.h\"\n#endif\n\n#ifdef GGML_USE_ZENDNN\n#include \"ggml-zendnn.h\"\n#endif\n\n#ifdef GGML_USE_OPENVINO\n#include \"ggml-openvino.h\"\n#endif\n\nnamespace fs = std::filesystem;\n\nstatic std::string path_str(const fs::path & path) {\n    try {\n#if defined(__cpp_lib_char8_t)\n        // C++20 and later: u8string() returns std::u8string\n        const std::u8string u8str = path.u8string();\n        return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());\n#else\n        // C++17: u8string() returns std::string\n        return path.u8string();\n#endif\n    } catch (...) 
{\n        return std::string();\n    }\n}\n\nstruct ggml_backend_reg_entry {\n    ggml_backend_reg_t reg;\n    dl_handle_ptr handle;\n};\n\nstruct ggml_backend_registry {\n    std::vector<ggml_backend_reg_entry> backends;\n    std::vector<ggml_backend_dev_t> devices;\n\n    ggml_backend_registry() {\n#ifdef GGML_USE_CUDA\n        register_backend(ggml_backend_cuda_reg());\n#endif\n#ifdef GGML_USE_METAL\n        register_backend(ggml_backend_metal_reg());\n#endif\n#ifdef GGML_USE_SYCL\n        register_backend(ggml_backend_sycl_reg());\n#endif\n#ifdef GGML_USE_VULKAN\n    // Add runtime disable check\n    if (getenv(\"GGML_DISABLE_VULKAN\") == nullptr) {\n        register_backend(ggml_backend_vk_reg());\n    } else {\n        GGML_LOG_DEBUG(\"Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\\n\");\n    }\n#endif\n#ifdef GGML_USE_WEBGPU\n        register_backend(ggml_backend_webgpu_reg());\n#endif\n#ifdef GGML_USE_ZDNN\n        register_backend(ggml_backend_zdnn_reg());\n#endif\n#ifdef GGML_USE_VIRTGPU_FRONTEND\n        register_backend(ggml_backend_virtgpu_reg());\n#endif\n\n#ifdef GGML_USE_OPENCL\n        register_backend(ggml_backend_opencl_reg());\n#endif\n#ifdef GGML_USE_ZENDNN\n        register_backend(ggml_backend_zendnn_reg());\n#endif\n#ifdef GGML_USE_HEXAGON\n        register_backend(ggml_backend_hexagon_reg());\n#endif\n#ifdef GGML_USE_CANN\n        register_backend(ggml_backend_cann_reg());\n#endif\n#ifdef GGML_USE_BLAS\n        register_backend(ggml_backend_blas_reg());\n#endif\n#ifdef GGML_USE_RPC\n        register_backend(ggml_backend_rpc_reg());\n#endif\n#ifdef GGML_USE_OPENVINO\n        register_backend(ggml_backend_openvino_reg());\n#endif\n#ifdef GGML_USE_CPU\n        register_backend(ggml_backend_cpu_reg());\n#endif\n    }\n\n    ~ggml_backend_registry() {\n        // FIXME: backends cannot be safely unloaded without a function to destroy all the backend resources,\n        // since backend threads may still be running and 
accessing resources from the dynamic library\n        for (auto & entry : backends) {\n            if (entry.handle) {\n                entry.handle.release(); // NOLINT\n            }\n        }\n    }\n\n    void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) {\n        if (!reg) {\n            return;\n        }\n\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: registered backend %s (%zu devices)\\n\",\n            __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));\n#endif\n        backends.push_back({ reg, std::move(handle) });\n        for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {\n            register_device(ggml_backend_reg_dev_get(reg, i));\n        }\n    }\n\n    void register_device(ggml_backend_dev_t device) {\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: registered device %s (%s)\\n\", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));\n#endif\n        devices.push_back(device);\n    }\n\n    ggml_backend_reg_t load_backend(const fs::path & path, bool silent) {\n        dl_handle_ptr handle { dl_load_library(path) };\n        if (!handle) {\n            if (!silent) {\n                GGML_LOG_ERROR(\"%s: failed to load %s: %s\\n\", __func__, path_str(path).c_str(), dl_error());\n            }\n            return nullptr;\n        }\n\n        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), \"ggml_backend_score\");\n        if (score_fn && score_fn() == 0) {\n            if (!silent) {\n                GGML_LOG_INFO(\"%s: backend %s is not supported on this system\\n\", __func__, path_str(path).c_str());\n            }\n            return nullptr;\n        }\n\n        auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), \"ggml_backend_init\");\n        if (!backend_init_fn) {\n            if (!silent) {\n                GGML_LOG_ERROR(\"%s: failed to find ggml_backend_init in %s\\n\", __func__, path_str(path).c_str());\n          
  }\n            return nullptr;\n        }\n\n        ggml_backend_reg_t reg = backend_init_fn();\n        if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {\n            if (!silent) {\n                if (!reg) {\n                    GGML_LOG_ERROR(\"%s: failed to initialize backend from %s: ggml_backend_init returned NULL\\n\",\n                        __func__, path_str(path).c_str());\n                } else {\n                    GGML_LOG_ERROR(\"%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\\n\",\n                        __func__, path_str(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);\n                }\n            }\n            return nullptr;\n        }\n\n        GGML_LOG_INFO(\"%s: loaded %s backend from %s\\n\", __func__, ggml_backend_reg_name(reg), path_str(path).c_str());\n\n        register_backend(reg, std::move(handle));\n\n        return reg;\n    }\n\n    void unload_backend(ggml_backend_reg_t reg, bool silent) {\n        auto it = std::find_if(backends.begin(), backends.end(),\n                               [reg](const ggml_backend_reg_entry & entry) { return entry.reg == reg; });\n\n        if (it == backends.end()) {\n            if (!silent) {\n                GGML_LOG_ERROR(\"%s: backend not found\\n\", __func__);\n            }\n            return;\n        }\n\n        if (!silent) {\n            GGML_LOG_DEBUG(\"%s: unloading %s backend\\n\", __func__, ggml_backend_reg_name(reg));\n        }\n\n        // remove devices\n        devices.erase(\n            std::remove_if(devices.begin(), devices.end(),\n                            [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }),\n            devices.end());\n\n        // remove backend\n        backends.erase(it);\n    }\n};\n\nstatic ggml_backend_registry & get_reg() {\n    static ggml_backend_registry reg;\n    return reg;\n}\n\n// Internal API\nvoid 
ggml_backend_register(ggml_backend_reg_t reg) {\n    get_reg().register_backend(reg);\n}\n\nvoid ggml_backend_device_register(ggml_backend_dev_t device) {\n    get_reg().register_device(device);\n}\n\n// Backend (reg) enumeration\nstatic bool striequals(const char * a, const char * b) {\n    for (; *a && *b; a++, b++) {\n        if (std::tolower(*a) != std::tolower(*b)) {\n            return false;\n        }\n    }\n    return *a == *b;\n}\n\nsize_t ggml_backend_reg_count() {\n    return get_reg().backends.size();\n}\n\nggml_backend_reg_t ggml_backend_reg_get(size_t index) {\n    GGML_ASSERT(index < ggml_backend_reg_count());\n    return get_reg().backends[index].reg;\n}\n\nggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {\n    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {\n        ggml_backend_reg_t reg = ggml_backend_reg_get(i);\n        if (striequals(ggml_backend_reg_name(reg), name)) {\n            return reg;\n        }\n    }\n    return nullptr;\n}\n\n// Device enumeration\nsize_t ggml_backend_dev_count() {\n    return get_reg().devices.size();\n}\n\nggml_backend_dev_t ggml_backend_dev_get(size_t index) {\n    GGML_ASSERT(index < ggml_backend_dev_count());\n    return get_reg().devices[index];\n}\n\nggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {\n    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {\n        ggml_backend_dev_t dev = ggml_backend_dev_get(i);\n        if (striequals(ggml_backend_dev_name(dev), name)) {\n            return dev;\n        }\n    }\n    return nullptr;\n}\n\nggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {\n    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {\n        ggml_backend_dev_t dev = ggml_backend_dev_get(i);\n        if (ggml_backend_dev_type(dev) == type) {\n            return dev;\n        }\n    }\n    return nullptr;\n}\n\n// Convenience functions\nggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {\n 
   ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);\n    if (!dev) {\n        return nullptr;\n    }\n    return ggml_backend_dev_init(dev, params);\n}\n\nggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {\n    ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);\n    if (!dev) {\n        return nullptr;\n    }\n    return ggml_backend_dev_init(dev, params);\n}\n\nggml_backend_t ggml_backend_init_best(void) {\n    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);\n    dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU);\n    dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);\n    if (!dev) {\n        return nullptr;\n    }\n    return ggml_backend_dev_init(dev, nullptr);\n}\n\n// Dynamic loading\nggml_backend_reg_t ggml_backend_load(const char * path) {\n    return get_reg().load_backend(path, false);\n}\n\nvoid ggml_backend_unload(ggml_backend_reg_t reg) {\n    get_reg().unload_backend(reg, true);\n}\n\nstatic fs::path get_executable_path() {\n#if defined(__APPLE__)\n    // get executable path\n    std::vector<char> path;\n    uint32_t size;\n    while (true) {\n        size = path.size();\n        if (_NSGetExecutablePath(path.data(), &size) == 0) {\n            break;\n        }\n        path.resize(size);\n    }\n    std::string base_path(path.data(), size);\n    // remove executable name\n    auto last_slash = base_path.find_last_of('/');\n    if (last_slash != std::string::npos) {\n        base_path = base_path.substr(0, last_slash);\n    }\n    return base_path + \"/\";\n#elif defined(__linux__) || defined(__FreeBSD__)\n    std::string base_path = \".\";\n    std::vector<char> path(1024);\n    while (true) {\n        // get executable path\n#    if defined(__linux__)\n        ssize_t len = readlink(\"/proc/self/exe\", path.data(), path.size());\n#    elif defined(__FreeBSD__)\n        ssize_t len = 
readlink(\"/proc/curproc/file\", path.data(), path.size());\n#    endif\n        if (len == -1) {\n            break;\n        }\n        if (len < (ssize_t) path.size()) {\n            base_path = std::string(path.data(), len);\n            // remove executable name\n            auto last_slash = base_path.find_last_of('/');\n            if (last_slash != std::string::npos) {\n                base_path = base_path.substr(0, last_slash);\n            }\n            break;\n        }\n        path.resize(path.size() * 2);\n    }\n\n    return base_path + \"/\";\n#elif defined(_WIN32)\n    std::vector<wchar_t> path(MAX_PATH);\n    DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());\n    if (len == 0) {\n        return {};\n    }\n    std::wstring base_path(path.data(), len);\n    // remove executable name\n    auto last_slash = base_path.find_last_of('\\\\');\n    if (last_slash != std::string::npos) {\n        base_path = base_path.substr(0, last_slash);\n    }\n    return base_path + L\"\\\\\";\n#else\n    return {};\n#endif\n}\n\nstatic fs::path backend_filename_prefix() {\n#ifdef _WIN32\n    return fs::u8path(\"ggml-\");\n#else\n    return fs::u8path(\"libggml-\");\n#endif\n}\n\nstatic fs::path backend_filename_extension() {\n#ifdef _WIN32\n    return fs::u8path(\".dll\");\n#else\n    return fs::u8path(\".so\");\n#endif\n}\n\nstatic ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {\n    // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths\n    const fs::path name_path = fs::u8path(name);\n    const fs::path file_prefix = backend_filename_prefix().native() + name_path.native() + fs::u8path(\"-\").native();\n    const fs::path file_extension = backend_filename_extension();\n\n    std::vector<fs::path> search_paths;\n    if (user_search_path == nullptr) {\n#ifdef GGML_BACKEND_DIR\n        search_paths.push_back(fs::u8path(GGML_BACKEND_DIR));\n#endif\n        // 
default search paths: executable directory, current directory\n        search_paths.push_back(get_executable_path());\n        search_paths.push_back(fs::current_path());\n    } else {\n        search_paths.push_back(fs::u8path(user_search_path));\n    }\n\n    int best_score = 0;\n    fs::path best_path;\n    std::error_code ec;\n\n    for (const auto & search_path : search_paths) {\n        if (!fs::exists(search_path, ec)) {\n            if (ec) {\n                GGML_LOG_DEBUG(\"%s: posix_stat(%s) failure, error-message: %s\\n\", __func__, path_str(search_path).c_str(), ec.message().c_str());\n            } else {\n                GGML_LOG_DEBUG(\"%s: search path %s does not exist\\n\", __func__, path_str(search_path).c_str());\n            }\n            continue;\n        }\n        fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);\n        for (const auto & entry : dir_it) {\n            if (entry.is_regular_file(ec)) {\n                auto filename = entry.path().filename();\n                auto ext = entry.path().extension();\n                if (filename.native().find(file_prefix) == 0 && ext == file_extension) {\n                    dl_handle_ptr handle { dl_load_library(entry) };\n                    if (!handle && !silent) {\n                        GGML_LOG_ERROR(\"%s: failed to load %s: %s\\n\", __func__, path_str(entry.path()).c_str(), dl_error());\n                    }\n                    if (handle) {\n                        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), \"ggml_backend_score\");\n                        if (score_fn) {\n                            int s = score_fn();\n#ifndef NDEBUG\n                            GGML_LOG_DEBUG(\"%s: %s score: %d\\n\", __func__, path_str(entry.path()).c_str(), s);\n#endif\n                            if (s > best_score) {\n                                best_score = s;\n                                best_path = entry.path();\n        
                    }\n                        } else {\n                            if (!silent) {\n                                GGML_LOG_INFO(\"%s: failed to find ggml_backend_score in %s\\n\", __func__, path_str(entry.path()).c_str());\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    if (best_score == 0) {\n        // try to load the base backend\n        for (const auto & search_path : search_paths) {\n            fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();\n            fs::path path = search_path / filename;\n            if (std::error_code ec; fs::exists(path, ec)) {\n                return get_reg().load_backend(path, silent);\n            } else {\n                if (ec) {\n                    GGML_LOG_DEBUG(\"%s: posix_stat(%s) failure, error-message: %s\\n\", __func__, path_str(path).c_str(), ec.message().c_str());\n                }\n            }\n        }\n        return nullptr;\n    }\n\n    return get_reg().load_backend(best_path, silent);\n}\n\nvoid ggml_backend_load_all() {\n    ggml_backend_load_all_from_path(nullptr);\n}\n\nvoid ggml_backend_load_all_from_path(const char * dir_path) {\n#ifdef NDEBUG\n    bool silent = true;\n#else\n    bool silent = false;\n#endif\n\n    ggml_backend_load_best(\"blas\", silent, dir_path);\n    ggml_backend_load_best(\"zendnn\", silent, dir_path);\n    ggml_backend_load_best(\"cann\", silent, dir_path);\n    ggml_backend_load_best(\"cuda\", silent, dir_path);\n    ggml_backend_load_best(\"hip\", silent, dir_path);\n    ggml_backend_load_best(\"metal\", silent, dir_path);\n    ggml_backend_load_best(\"rpc\", silent, dir_path);\n    ggml_backend_load_best(\"sycl\", silent, dir_path);\n    ggml_backend_load_best(\"vulkan\", silent, dir_path);\n    ggml_backend_load_best(\"virtgpu\", silent, dir_path);\n    ggml_backend_load_best(\"opencl\", silent, 
dir_path);\n    ggml_backend_load_best(\"hexagon\", silent, dir_path);\n    ggml_backend_load_best(\"musa\", silent, dir_path);\n    ggml_backend_load_best(\"openvino\", silent, dir_path);\n    ggml_backend_load_best(\"cpu\", silent, dir_path);\n    // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend\n    const char * backend_path = std::getenv(\"GGML_BACKEND_PATH\");\n    if (backend_path) {\n        ggml_backend_load(backend_path);\n    }\n}\n"
  },
  {
    "path": "src/ggml-backend.cpp",
    "content": "// Note: porting this file to C++ is a work in progress\n\n#ifdef _WIN32\n#define WIN32_LEAN_AND_MEAN\n#ifndef NOMINMAX\n#   define NOMINMAX\n#endif\n#include <windows.h>\n#endif\n\n#include \"ggml-backend.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-impl.h\"\n\n#include <assert.h>\n#include <limits.h>\n#include <stdarg.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n#include <algorithm>\n#include <vector>\n\n#ifdef __APPLE__\n#include <sys/types.h>\n#include <sys/sysctl.h>\n#endif\n\n\n// backend buffer type\n\nconst char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {\n    GGML_ASSERT(buft);\n    return buft->iface.get_name(buft);\n}\n\nggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    GGML_ASSERT(buft);\n    if (size == 0) {\n        // return a dummy buffer for zero-sized allocations\n        return ggml_backend_buffer_init(buft, {}, NULL, 0);\n    }\n    return buft->iface.alloc_buffer(buft, size);\n}\n\nsize_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {\n    GGML_ASSERT(buft);\n    return buft->iface.get_alignment(buft);\n}\n\nsize_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {\n    GGML_ASSERT(buft);\n    // get_max_size is optional, defaults to SIZE_MAX\n    if (buft->iface.get_max_size) {\n        return buft->iface.get_max_size(buft);\n    }\n    return SIZE_MAX;\n}\n\nsize_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {\n    GGML_ASSERT(buft);\n    // get_alloc_size is optional, defaults to ggml_nbytes\n    if (buft->iface.get_alloc_size) {\n        size_t size = buft->iface.get_alloc_size(buft, tensor);\n        assert(size >= ggml_nbytes(tensor));\n        return size;\n    }\n    return ggml_nbytes(tensor);\n}\n\nbool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {\n    GGML_ASSERT(buft);\n    if 
(buft->iface.is_host) {\n        return buft->iface.is_host(buft);\n    }\n    return false;\n}\n\nggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) {\n    GGML_ASSERT(buft);\n    return buft->device;\n}\n\n// backend buffer\n\nggml_backend_buffer_t ggml_backend_buffer_init(\n               ggml_backend_buffer_type_t buft,\n        struct ggml_backend_buffer_i      iface,\n               void *                     context,\n               size_t                     size) {\n    ggml_backend_buffer_t buffer = new ggml_backend_buffer {\n        /* .interface = */ iface,\n        /* .buft      = */ buft,\n        /* .context   = */ context,\n        /* .size      = */ size,\n        /* .usage     = */ GGML_BACKEND_BUFFER_USAGE_ANY\n    };\n\n    return buffer;\n}\n\nconst char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) {\n    return ggml_backend_buft_name(ggml_backend_buffer_get_type(buffer));\n}\n\nvoid ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {\n    if (buffer == NULL) {\n        return;\n    }\n\n    if (buffer->iface.free_buffer != NULL) {\n        buffer->iface.free_buffer(buffer);\n    }\n    delete buffer;\n}\n\nsize_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    return buffer->size;\n}\n\nvoid * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    // get_base is optional if the buffer is zero-sized\n    if (buffer->size == 0) {\n        return NULL;\n    }\n\n    // FIXME JG: a multi_buffer has a non-zero size, according to the above comment get_base is not optional,\n    //     I don't know whether the above comment is correct\n    if (!buffer->iface.get_base) {\n        return NULL;\n    }\n\n    void * base = buffer->iface.get_base(buffer);\n\n    GGML_ASSERT(base != NULL && \"backend buffer base cannot be NULL\");\n\n    return base;\n}\n\nenum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t 
buffer, struct ggml_tensor * tensor) {\n    GGML_ASSERT(buffer);\n    // init_tensor is optional\n    if (buffer->iface.init_tensor) {\n        return buffer->iface.init_tensor(buffer, tensor);\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\nvoid ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    GGML_ASSERT(buffer);\n    // clear is optional if the buffer is zero-sized\n    if (buffer->size == 0) {\n        return;\n    }\n\n    buffer->iface.clear(buffer, value);\n}\n\nsize_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {\n    return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer));\n}\n\nsize_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {\n    return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));\n}\n\nsize_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor) {\n    return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);\n}\n\nbool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {\n    return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));\n}\n\nvoid ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {\n    GGML_ASSERT(buffer);\n    buffer->usage = usage;\n\n    // FIXME: add a generic callback to the buffer interface\n    if (ggml_backend_buffer_is_multi_buffer(buffer)) {\n        ggml_backend_multi_buffer_set_usage(buffer, usage);\n    }\n}\n\nenum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    return buffer->usage;\n}\n\nggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    return buffer->buft;\n}\n\nvoid ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    if (buffer->iface.reset) {\n        buffer->iface.reset(buffer);\n    
}\n}\n\nbool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) {\n    ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;\n    if (dst_buf->iface.cpy_tensor) {\n        return dst_buf->iface.cpy_tensor(dst_buf, src, dst);\n    }\n    return false;\n}\n\n// backend\n\nggml_guid_t ggml_backend_guid(ggml_backend_t backend) {\n    if (backend == NULL) {\n        return NULL;\n    }\n    return backend->guid;\n}\n\nconst char * ggml_backend_name(ggml_backend_t backend) {\n    if (backend == NULL) {\n        return \"NULL\";\n    }\n    return backend->iface.get_name(backend);\n}\n\nvoid ggml_backend_free(ggml_backend_t backend) {\n    if (backend == NULL) {\n        return;\n    }\n\n    backend->iface.free(backend);\n}\n\nggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {\n    GGML_ASSERT(backend);\n    return ggml_backend_dev_buffer_type(backend->device);\n}\n\nggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {\n    return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);\n}\n\nsize_t ggml_backend_get_alignment(ggml_backend_t backend) {\n    return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));\n}\n\nsize_t ggml_backend_get_max_size(ggml_backend_t backend) {\n    return ggml_backend_buft_get_max_size(ggml_backend_get_default_buffer_type(backend));\n}\n\nvoid ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    GGML_ASSERT(backend);\n    GGML_ASSERT(tensor);\n    GGML_ASSERT(tensor->data != NULL && \"tensor not allocated\");\n    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && \"tensor write out of bounds\");\n\n    if (backend->iface.set_tensor_async == NULL) {\n        ggml_backend_synchronize(backend);\n        ggml_backend_tensor_set(tensor, data, offset, size);\n    
} else {\n        backend->iface.set_tensor_async(backend, tensor, data, offset, size);\n    }\n}\n\nvoid ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    GGML_ASSERT(backend);\n    GGML_ASSERT(tensor);\n    GGML_ASSERT(tensor->data != NULL && \"tensor not allocated\");\n    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && \"tensor read out of bounds\");\n\n    if (backend->iface.get_tensor_async == NULL) {\n        ggml_backend_synchronize(backend);\n        ggml_backend_tensor_get(tensor, data, offset, size);\n    } else {\n        backend->iface.get_tensor_async(backend, tensor, data, offset, size);\n    }\n}\n\nvoid ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    GGML_ASSERT(tensor);\n    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;\n\n    if (size == 0) {\n        return;\n    }\n\n    GGML_ASSERT(buf != NULL && \"tensor buffer not set\");\n    GGML_ASSERT(tensor->data != NULL && \"tensor not allocated\");\n    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && \"tensor write out of bounds\");\n\n    buf->iface.set_tensor(buf, tensor, data, offset, size);\n}\n\nvoid ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    GGML_ASSERT(tensor);\n    ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer;\n\n    if (size == 0) {\n        return;\n    }\n\n    GGML_ASSERT(buf != NULL && \"tensor buffer not set\");\n    GGML_ASSERT(tensor->data != NULL && \"tensor not allocated\");\n    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && \"tensor read out of bounds\");\n\n    buf->iface.get_tensor(buf, tensor, data, offset, size);\n}\n\nvoid ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {\n    GGML_ASSERT(tensor);\n    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;\n\n    if (size == 0) {\n        return;\n    }\n\n    GGML_ASSERT(buf != NULL && \"tensor buffer not set\");\n    GGML_ASSERT(tensor->data != NULL && \"tensor not allocated\");\n    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && \"tensor write out of bounds\");\n    GGML_ASSERT(buf->iface.memset_tensor != NULL && \"memset not implemented by backend buffer\");\n\n    buf->iface.memset_tensor(buf, tensor, value, offset, size);\n}\n\nvoid ggml_backend_synchronize(ggml_backend_t backend) {\n    GGML_ASSERT(backend);\n    if (backend->iface.synchronize == NULL) {\n        return;\n    }\n\n    backend->iface.synchronize(backend);\n}\n\nggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {\n    GGML_ASSERT(backend);\n    GGML_ASSERT(backend->iface.graph_plan_create != NULL);\n\n    return backend->iface.graph_plan_create(backend, cgraph);\n}\n\nvoid ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {\n    GGML_ASSERT(backend);\n    GGML_ASSERT(backend->iface.graph_plan_free != NULL);\n\n    backend->iface.graph_plan_free(backend, plan);\n}\n\nenum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {\n    GGML_ASSERT(backend);\n    GGML_ASSERT(backend->iface.graph_plan_compute != NULL);\n\n    return 
backend->iface.graph_plan_compute(backend, plan);\n}\n\nenum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {\n    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);\n    ggml_backend_synchronize(backend);\n    return err;\n}\n\nenum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {\n    GGML_ASSERT(backend);\n    return backend->iface.graph_compute(backend, cgraph);\n}\n\nbool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {\n    GGML_ASSERT(backend);\n    return ggml_backend_dev_supports_op(backend->device, op);\n}\n\nbool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {\n    GGML_ASSERT(backend);\n    return ggml_backend_dev_supports_buft(backend->device, buft);\n}\n\nbool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {\n    GGML_ASSERT(backend);\n    return ggml_backend_dev_offload_op(backend->device, op);\n}\n\nggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {\n    GGML_ASSERT(backend);\n    return backend->device;\n}\n\n// backend copy\n\nvoid ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {\n    GGML_ASSERT(ggml_are_same_layout(src, dst) && \"cannot copy tensors with different layouts\");\n\n    if (src == dst) {\n        return;\n    }\n\n    if (ggml_backend_buffer_is_host(src->buffer)) {\n        ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));\n    } else if (ggml_backend_buffer_is_host(dst->buffer)) {\n        ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));\n    } else if (!ggml_backend_buffer_copy_tensor(src, dst)) {\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: warning: slow copy from %s to %s\\n\", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer));\n#endif\n        size_t nbytes = ggml_nbytes(src);\n        void * data = 
malloc(nbytes);\n        ggml_backend_tensor_get(src, data, 0, nbytes);\n        ggml_backend_tensor_set(dst, data, 0, nbytes);\n        free(data);\n    }\n}\n\nvoid ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) {\n    GGML_ASSERT(ggml_are_same_layout(src, dst) && \"cannot copy tensors with different layouts\");\n\n    if (src == dst) {\n        return;\n    }\n\n    GGML_ASSERT(backend_dst);\n    if (backend_dst->iface.cpy_tensor_async != NULL) {\n        if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {\n            return;\n        }\n    }\n\n    // an async copy would normally happen after all the queued operations on both backends are completed\n    // to simulate the same behavior, we need to synchronize both backends first, and do a blocking copy\n    ggml_backend_synchronize(backend_src);\n    ggml_backend_synchronize(backend_dst);\n    ggml_backend_tensor_copy(src, dst);\n}\n\n// events\n\nggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device) {\n    // null device is allowed for the transition period to the device interface\n    if (device == NULL || device->iface.event_new == NULL) {\n        return NULL;\n    }\n    return device->iface.event_new(device);\n}\n\nvoid ggml_backend_event_free(ggml_backend_event_t event) {\n    if (event == NULL) {\n        return;\n    }\n    event->device->iface.event_free(event->device, event);\n}\n\nvoid ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) {\n    GGML_ASSERT(backend);\n    GGML_ASSERT(backend->iface.event_record != NULL);\n\n    backend->iface.event_record(backend, event);\n}\n\nvoid ggml_backend_event_synchronize(ggml_backend_event_t event) {\n    GGML_ASSERT(event);\n    GGML_ASSERT(event->device->iface.event_synchronize);\n\n    event->device->iface.event_synchronize(event->device, event);\n}\n\nvoid ggml_backend_event_wait(ggml_backend_t 
backend, ggml_backend_event_t event) {\n    GGML_ASSERT(backend);\n    GGML_ASSERT(backend->iface.event_wait != NULL);\n\n    backend->iface.event_wait(backend, event);\n}\n\nstatic void ggml_backend_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {\n    GGML_ASSERT(backend);\n    if (backend->iface.graph_optimize != NULL) {\n        backend->iface.graph_optimize(backend, cgraph);\n    }\n}\n\n// Backend device\n\nconst char * ggml_backend_dev_name(ggml_backend_dev_t device) {\n    GGML_ASSERT(device);\n    return device->iface.get_name(device);\n}\n\nconst char * ggml_backend_dev_description(ggml_backend_dev_t device) {\n    GGML_ASSERT(device);\n    return device->iface.get_description(device);\n}\n\nvoid ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {\n    GGML_ASSERT(device);\n    device->iface.get_memory(device, free, total);\n}\n\nenum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {\n    GGML_ASSERT(device);\n    return device->iface.get_type(device);\n}\n\nvoid ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) {\n    memset(props, 0, sizeof(*props));\n    device->iface.get_props(device, props);\n}\n\nggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) {\n    GGML_ASSERT(device);\n    return device->reg;\n}\n\nggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) {\n    GGML_ASSERT(device);\n    return device->iface.init_backend(device, params);\n}\n\nggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {\n    GGML_ASSERT(device);\n    return device->iface.get_buffer_type(device);\n}\n\nggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {\n    GGML_ASSERT(device);\n    if (device->iface.get_host_buffer_type == NULL) {\n        return NULL;\n    }\n\n    return 
device->iface.get_host_buffer_type(device);\n}\n\nggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) {\n    GGML_ASSERT(device);\n    return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size);\n}\n\nbool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {\n    GGML_ASSERT(device);\n    return device->iface.supports_op(device, op);\n}\n\nbool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) {\n    GGML_ASSERT(device);\n    return device->iface.supports_buft(device, buft);\n}\n\nbool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {\n    GGML_ASSERT(device);\n    if (device->iface.offload_op != NULL) {\n        return device->iface.offload_op(device, op);\n    }\n\n    return false;\n}\n\n// Backend (reg)\n\nconst char * ggml_backend_reg_name(ggml_backend_reg_t reg) {\n    GGML_ASSERT(reg);\n    return reg->iface.get_name(reg);\n}\n\nsize_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) {\n    GGML_ASSERT(reg);\n    return reg->iface.get_device_count(reg);\n}\n\nggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) {\n    GGML_ASSERT(reg);\n    return reg->iface.get_device(reg, index);\n}\n\nvoid * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    GGML_ASSERT(reg);\n    if (!reg->iface.get_proc_address) {\n        return NULL;\n    }\n    return reg->iface.get_proc_address(reg, name);\n}\n\n// multi-buffer buffer\n\nstruct ggml_backend_multi_buffer_context {\n    ggml_backend_buffer_t * buffers;\n    size_t n_buffers;\n};\n\nstatic void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;\n    for (size_t i = 0; i < ctx->n_buffers; i++) {\n   
     ggml_backend_buffer_free(ctx->buffers[i]);\n    }\n\n    free(ctx->buffers);\n    free(ctx);\n}\n\nstatic void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    GGML_ASSERT(buffer);\n    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;\n    for (size_t i = 0; i < ctx->n_buffers; i++) {\n        ggml_backend_buffer_clear(ctx->buffers[i], value);\n    }\n}\n\nstatic const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = {\n    /* .free_buffer     = */ ggml_backend_multi_buffer_free_buffer,\n    /* .get_base        = */ NULL,\n    /* .init_tensor     = */ NULL,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ NULL,\n    /* .get_tensor      = */ NULL,\n    /* .cpy_tensor      = */ NULL,\n    /* .clear           = */ ggml_backend_multi_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\nggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) {\n    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) malloc(sizeof(struct ggml_backend_multi_buffer_context));\n    ctx->n_buffers = n_buffers;\n    ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t));\n\n    GGML_ASSERT(ctx->buffers != NULL);\n\n    size_t total_size = 0;\n    for (size_t i = 0; i < n_buffers; i++) {\n        ctx->buffers[i] = buffers[i];\n        total_size += ggml_backend_buffer_get_size(buffers[i]);\n    }\n\n    return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_i, ctx, total_size);\n}\n\nbool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer;\n}\n\nvoid ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {\n    GGML_ASSERT(buffer);\n    
GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));\n    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;\n    for (size_t i = 0; i < ctx->n_buffers; i++) {\n        ggml_backend_buffer_set_usage(ctx->buffers[i], usage);\n    }\n}\n\n// creates a copy of the tensor with the same memory layout\nstatic struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {\n    struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        dup->nb[i] = tensor->nb[i];\n    }\n    return dup;\n}\n\nstatic bool ggml_is_view_op(enum ggml_op op) {\n    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;\n}\n\n// scheduler\n\n#ifndef GGML_SCHED_MAX_BACKENDS\n#define GGML_SCHED_MAX_BACKENDS 16\n#endif\n\n#ifndef GGML_SCHED_MAX_SPLIT_INPUTS\n#define GGML_SCHED_MAX_SPLIT_INPUTS 30\n#endif\n\n#ifndef GGML_SCHED_MAX_COPIES\n#define GGML_SCHED_MAX_COPIES 4\n#endif\n\nstruct ggml_backend_sched_split {\n    int backend_id;\n    int i_start;\n    int i_end;\n    struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS];\n    int n_inputs;\n    // graph view of this split\n    struct ggml_cgraph graph;\n};\n\nstruct ggml_backend_sched {\n    bool is_reset; // true if the scheduler has been reset since the last graph split\n    bool is_alloc;\n\n    int n_backends;\n\n    ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS];\n    ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];\n    ggml_gallocr_t galloc;\n\n    // hash map of the nodes in the graph\n    struct ggml_hash_set  hash_set;\n    int                 * hv_tensor_backend_ids; // [hash_set.size]\n    struct ggml_tensor ** hv_tensor_copies;      // [hash_set.size][n_backends][n_copies]\n\n    int * node_backend_ids; // [graph_size]\n    int * leaf_backend_ids; // [graph_size]\n\n    int * prev_node_backend_ids; // [graph_size]\n    
int * prev_leaf_backend_ids; // [graph_size]\n\n    // copy of the graph with modified inputs\n    struct ggml_cgraph graph;\n\n    // graph splits\n    struct ggml_backend_sched_split * splits;\n    int n_splits;\n    int splits_capacity;\n\n    // pipeline parallelism support\n    int n_copies;\n    int cur_copy;\n    int next_copy;\n    ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];\n    struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];\n    int n_graph_inputs;\n\n    struct ggml_context * ctx;\n\n    ggml_backend_sched_eval_callback callback_eval;\n    void * callback_eval_user_data;\n\n    char * context_buffer;\n    size_t context_buffer_size;\n\n    bool op_offload;\n\n    int debug;\n\n    // used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC]\n    // ref: https://github.com/ggml-org/llama.cpp/pull/17617\n    int debug_realloc;\n    int debug_graph_size;\n    int debug_prev_graph_size;\n};\n\n#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)\n#define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)]\n#define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)]\n#define tensor_copy(tensor, backend_id, copy_id) tensor_id_copy(hash_id(tensor), backend_id, copy_id)\n\n// returns the priority of the backend, lower id is higher priority\nstatic int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) {\n    for (int i = 0; i < sched->n_backends; i++) {\n        if (sched->backends[i] == backend) {\n            return i;\n        }\n    }\n    return -1;\n}\n\nstatic int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {\n    ggml_backend_buffer_t buffer = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer;\n    if (buffer == NULL) {\n        return -1;\n    }\n\n    // find highest prio backend that supports the buffer type and the op\n    for (int i = 0; i < sched->n_backends; i++) {\n        if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) &&\n            ggml_backend_supports_op(sched->backends[i], op)) {\n            return i;\n        }\n    }\n\n#ifndef NDEBUG\n    GGML_LOG_DEBUG(\"%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\\n\",\n        __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name);\n#endif\n\n    return -1;\n}\n\n#if 0\n#define GGML_SCHED_MAX_SPLITS_DEBUG 4096\nstatic char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only\n#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)\n#define GET_CAUSE(node) causes[hash_id(node)]\n#else\n#define SET_CAUSE(node, ...)\n#define GET_CAUSE(node) \"\"\n#endif\n\n// returns the backend that should be used for the node based on the current locations\nstatic int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {\n    // assign pre-allocated nodes to their backend\n    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);\n    if (cur_backend_id != -1) {\n        SET_CAUSE(tensor, \"1.dst\");\n        return cur_backend_id;\n    }\n\n    // view_src\n    if (tensor->view_src != NULL) {\n        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);\n        if (cur_backend_id != -1) {\n            SET_CAUSE(tensor, \"1.vsrc\");\n            return cur_backend_id;\n        }\n    }\n\n    if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {\n        // since the tensor is pre-allocated, it cannot be moved to another backend\n        
ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;\n        GGML_ABORT(\"pre-allocated tensor (%s) in a buffer (%s) that cannot run the operation (%s)\", tensor->name, ggml_backend_buffer_name(buffer), ggml_op_name(tensor->op));\n    }\n\n    // graph input\n    if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {\n        cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU)\n        SET_CAUSE(tensor, \"1.inp\");\n        return cur_backend_id;\n    }\n\n    // operations with weights are preferably run on the same backend as the weights\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        const struct ggml_tensor * src = tensor->src[i];\n        if (src == NULL) {\n            continue;\n        }\n        // skip ROPE since the rope freqs tensor is too small to choose a backend based on it\n        // not an ideal solution\n        if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {\n            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);\n            // check if a backend with higher prio wants to offload the op\n            if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {\n                for (int b = 0; b < src_backend_id; b++) {\n                    if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {\n                        SET_CAUSE(tensor, \"1.off\");\n                        return b;\n                    }\n                }\n            }\n            SET_CAUSE(tensor, \"1.wgt%d\", i);\n            return src_backend_id;\n        }\n    }\n\n    return -1;\n}\n\nstatic char * fmt_size(size_t size) {\n    static char buffer[128];\n    if (size >= 1024*1024) {\n        snprintf(buffer, sizeof(buffer), \"%zuM\", size/1024/1024);\n    } else {\n        snprintf(buffer, 
sizeof(buffer), \"%zuK\", size/1024);\n    }\n    return buffer;\n}\n\nstatic void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {\n    int cur_split = 0;\n    for (int i = 0; i < graph->n_nodes; i++) {\n        if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {\n            ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];\n            GGML_LOG_DEBUG(\"\\n## SPLIT #%d: %s # %d inputs\", cur_split, ggml_backend_name(split_backend),\n                sched->splits[cur_split].n_inputs);\n            for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {\n                if (j == 0) {\n                    GGML_LOG_DEBUG(\": \");\n                }\n                GGML_LOG_DEBUG(\"[%s (%5.5s)] \", sched->splits[cur_split].inputs[j]->name,\n                    fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));\n            }\n            GGML_LOG_DEBUG(\"\\n\");\n            cur_split++;\n        }\n        struct ggml_tensor * node = graph->nodes[i];\n        if (ggml_is_view_op(node->op)) {\n            continue;\n        }\n        if (sched->debug > 1) {\n            ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);\n            GGML_LOG_DEBUG(\"node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:\", i, ggml_op_name(node->op), node->name,\n                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : \"NULL\", GET_CAUSE(node),\n                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 
1 : 0);\n            for (int j = 0; j < GGML_MAX_SRC; j++) {\n                struct ggml_tensor * src = node->src[j];\n                if (src == NULL) {\n                    continue;\n                }\n                ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);\n                GGML_LOG_DEBUG(\" %20.20s (%5.5s) [%5.5s %8.8s]\", src->name,\n                    fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : \"NULL\", GET_CAUSE(src));\n            }\n            GGML_LOG_DEBUG(\"\\n\");\n        }\n    }\n}\n\nstatic bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) {\n    ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;\n    ggml_backend_buffer_type_t buft = NULL;\n\n    if (buf) {\n        // the tensor is already allocated\n        buft = buf->buft;\n    } else {\n        // see if the tensor already has a backend assigned, and use the buffer type of that backend\n        int tensor_backend_id = tensor_backend_id(t);\n        if (tensor_backend_id == -1 && t->view_src) {\n            tensor_backend_id = tensor_backend_id(t->view_src);\n        }\n        if (tensor_backend_id != -1) {\n            buft = sched->bufts[tensor_backend_id];\n        }\n    }\n\n    return buft != NULL && ggml_backend_supports_buft(sched->backends[backend_id], buft);\n}\n\nstatic void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {\n    if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {\n        *node_backend_id = cur_backend_id;\n        SET_CAUSE(node, \"2.sup\");\n    }\n}\n\n// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend\nvoid ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {\n    // reset splits\n    sched->n_splits = 0;\n    
sched->n_graph_inputs = 0;\n    sched->is_reset = false;\n\n    struct ggml_init_params params = {\n        /* .mem_size =   */ sched->context_buffer_size,\n        /* .mem_buffer = */ sched->context_buffer,\n        /* .no_alloc =   */ true\n    };\n\n    ggml_free(sched->ctx);\n\n    sched->ctx = ggml_init(params);\n    if (sched->ctx == NULL) {\n        GGML_ABORT(\"%s: failed to initialize context\\n\", __func__);\n    }\n\n    // pass 1: assign backends to ops with pre-allocated inputs\n    for (int i = 0; i < graph->n_leafs; i++) {\n        struct ggml_tensor * leaf = graph->leafs[i];\n        int * leaf_backend_id = &tensor_backend_id(leaf);\n        // do not overwrite user assignments\n        if (*leaf_backend_id == -1) {\n            *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);\n        }\n    }\n\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        int * node_backend_id = &tensor_backend_id(node);\n        // do not overwrite user assignments\n        if (*node_backend_id == -1) {\n            *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);\n\n#if 0\n            // src\n            if (node->op == GGML_OP_NONE) {\n                continue;\n            }\n\n            for (int j = 0; j < GGML_MAX_SRC; j++) {\n                struct ggml_tensor * src = node->src[j];\n                if (src == NULL) {\n                    continue;\n                }\n                int * src_backend_id = &tensor_backend_id(src);\n                if (*src_backend_id == -1) {\n                    *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src);\n                }\n            }\n#endif\n        }\n    }\n\n    // pass 2: expand current backend assignments\n    // assign the same backend to adjacent nodes\n    // expand gpu backends (i.e. 
non last prio) up and down, ignoring cpu (the lowest priority backend)\n    // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops\n    // ops unsupported by the backend being expanded will be left unassigned so that they can be assigned later when the locations of its inputs are known\n    // expand gpu down\n    {\n        int cur_backend_id = -1;\n        for (int i = 0; i < graph->n_nodes; i++) {\n            struct ggml_tensor * node = graph->nodes[i];\n            if (ggml_is_view_op(node->op)) {\n                continue;\n            }\n            int * node_backend_id = &tensor_backend_id(node);\n            if (*node_backend_id != -1) {\n                if (*node_backend_id == sched->n_backends - 1) {\n                    // skip cpu (lowest prio backend)\n                    cur_backend_id = -1;\n                } else {\n                    cur_backend_id = *node_backend_id;\n                }\n            } else if (cur_backend_id != -1) {\n                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);\n            }\n        }\n    }\n    // expand gpu up\n    {\n        int cur_backend_id = -1;\n        for (int i = graph->n_nodes - 1; i >= 0; i--) {\n            struct ggml_tensor * node = graph->nodes[i];\n            if (ggml_is_view_op(node->op)) {\n                continue;\n            }\n            int * node_backend_id = &tensor_backend_id(node);\n            if (*node_backend_id != -1) {\n                if (*node_backend_id == sched->n_backends - 1) {\n                    // skip cpu (lowest prio backend)\n                    cur_backend_id = -1;\n                } else {\n                    cur_backend_id = *node_backend_id;\n                }\n            } else if (cur_backend_id != -1) {\n                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);\n            }\n        }\n    }\n    // expand rest down\n    {\n     
   int cur_backend_id = -1;\n        for (int i = 0; i < graph->n_nodes; i++) {\n            struct ggml_tensor * node = graph->nodes[i];\n            if (ggml_is_view_op(node->op)) {\n                continue;\n            }\n            int * node_backend_id = &tensor_backend_id(node);\n            if (*node_backend_id != -1) {\n                cur_backend_id = *node_backend_id;\n            } else if (cur_backend_id != -1) {\n                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);\n            }\n        }\n    }\n    // expand rest up\n    {\n        int cur_backend_id = -1;\n        for (int i = graph->n_nodes - 1; i >= 0; i--) {\n            struct ggml_tensor * node = graph->nodes[i];\n            if (ggml_is_view_op(node->op)) {\n                continue;\n            }\n            int * node_backend_id = &tensor_backend_id(node);\n            if (*node_backend_id != -1) {\n                cur_backend_id = *node_backend_id;\n            } else if (cur_backend_id != -1) {\n                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);\n            }\n        }\n    }\n\n    // pass 3: upgrade nodes to higher prio backends with compatible buffer types\n    // if the tensor is already in the same buffer type (*) as another higher priority backend, we should move it there\n    // however, we also need to verify that the sources are in compatible buffer types\n    // (*) the actual requirement is more relaxed, the buffer type of the backend should be supported by all the users of this tensor further down the graph\n    // however, this is slow to verify, so we have a more strict requirement that the buffer type is the same\n    // this is not uncommon since multiple backends can use host memory, with the same buffer type (eg. 
BLAS and CPU)\n    // additionally, set remaining unassigned nodes to the backend with the most supported inputs\n    // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        if (ggml_is_view_op(node->op)) {\n            continue;\n        }\n        int * node_backend_id = &tensor_backend_id(node);\n        if (*node_backend_id == -1) {\n            // unassigned node: find the backend with the most supported inputs\n            int n_supported_best = -1;\n            for (int b = 0; b < sched->n_backends; b++) {\n                if (ggml_backend_supports_op(sched->backends[b], node)) {\n                    int n_supported = 0;\n                    for (int j = 0; j < GGML_MAX_SRC; j++) {\n                        struct ggml_tensor * src = node->src[j];\n                        if (src == NULL) {\n                            continue;\n                        }\n                        if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && ggml_backend_sched_buffer_supported(sched, src, b)) {\n                            n_supported++;\n                        }\n                    }\n                    if (n_supported > n_supported_best) {\n                        n_supported_best = n_supported;\n                        *node_backend_id = b;\n                        SET_CAUSE(node, \"3.best\");\n                    }\n                }\n            }\n        } else {\n            // assigned node: upgrade to higher prio backend if possible\n            for (int b = 0; b < *node_backend_id; b++) {\n                if (sched->bufts[b] == sched->bufts[*node_backend_id] && ggml_backend_supports_op(sched->backends[b], node)) {\n                    bool supported = true;\n                    for (int j = 0; j < GGML_MAX_SRC; j++) {\n                        
struct ggml_tensor * src = node->src[j];\n                        if (src == NULL) {\n                            continue;\n                        }\n                        if (!ggml_backend_sched_buffer_supported(sched, src, b)) {\n                            supported = false;\n                            break;\n                        }\n                    }\n                    if (supported) {\n                        *node_backend_id = b;\n                        SET_CAUSE(node, \"3.upg\");\n                        break;\n                    }\n                }\n            }\n        }\n    }\n\n    // pass 4: assign backends to remaining src from dst and view_src\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        int * cur_backend_id = &tensor_backend_id(node);\n        if (node->view_src != NULL && *cur_backend_id == -1) {\n            *cur_backend_id = tensor_backend_id(node->view_src);\n            SET_CAUSE(node, \"4.vsrc\");\n        }\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            struct ggml_tensor * src = node->src[j];\n            if (src == NULL) {\n                continue;\n            }\n            int * src_backend_id = &tensor_backend_id(src);\n            if (*src_backend_id == -1) {\n                if (src->view_src != NULL) {\n                    // views are always on the same backend as the source\n                    *src_backend_id = tensor_backend_id(src->view_src);\n                    SET_CAUSE(src, \"4.vsrc\");\n                } else {\n                    *src_backend_id = *cur_backend_id;\n                    SET_CAUSE(src, \"4.cur\");\n                }\n            }\n        }\n        // if the node is still unassigned, assign it to the first backend that supports it\n        for (int b = 0; b < sched->n_backends && *cur_backend_id == -1; b++) {\n            ggml_backend_sched_set_if_supported(sched, node, b, cur_backend_id);\n        }\n    
    GGML_ASSERT(*cur_backend_id != -1);\n    }\n\n    // pass 5: split graph, find tensors that need to be copied\n    {\n        int i_split = 0;\n        struct ggml_backend_sched_split * split = &sched->splits[0];\n        // find the backend of the first split, skipping view ops\n        int i = 0;\n        for (; i < graph->n_nodes; i++) {\n            struct ggml_tensor * node = graph->nodes[i];\n            if (!ggml_is_view_op(node->op)) {\n                split->backend_id = tensor_backend_id(node);\n                break;\n            }\n        }\n        split->i_start = 0;\n        split->n_inputs = 0;\n        int cur_backend_id = split->backend_id;\n        for (; i < graph->n_nodes; i++) {\n            struct ggml_tensor * node = graph->nodes[i];\n\n            if (ggml_is_view_op(node->op)) {\n                continue;\n            }\n\n            const int node_backend_id = tensor_backend_id(node);\n\n            GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback\n\n            // check if we should start a new split based on the sources of the current node\n            bool need_new_split = false;\n            if (node_backend_id == cur_backend_id && split->n_inputs > 0) {\n                for (int j = 0; j < GGML_MAX_SRC; j++) {\n                    struct ggml_tensor * src = node->src[j];\n                    if (src == NULL) {\n                        continue;\n                    }\n                    // check if a weight is on a different and incompatible backend\n                    // by starting a new split, the memory of the previously offloaded weights can be reused\n                    if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {\n                        int src_backend_id = tensor_backend_id(src);\n                        if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, 
cur_backend_id)) {\n                            need_new_split = true;\n                            break;\n                        }\n                    }\n                    // check if the split has too many inputs\n                    // FIXME: count the number of inputs instead of only checking when full\n                    if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {\n                        const size_t id = hash_id(src);\n                        int src_backend_id = sched->hv_tensor_backend_ids[id];\n                        bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);\n                        if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) {\n                            need_new_split = true;\n                            break;\n                        }\n                    }\n                }\n            }\n\n            if (node_backend_id != cur_backend_id || need_new_split) {\n                split->i_end = i;\n                i_split++;\n                if (i_split >= sched->splits_capacity) {\n                    sched->splits_capacity *= 2;\n                    sched->splits = (ggml_backend_sched_split *)\n                        realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));\n                    GGML_ASSERT(sched->splits != NULL);\n                }\n                split = &sched->splits[i_split];\n                split->backend_id = node_backend_id;\n                split->i_start = i;\n                split->n_inputs = 0;\n                cur_backend_id = node_backend_id;\n            }\n\n            // find inputs that are not on the same backend\n            for (int j = 0; j < GGML_MAX_SRC; j++) {\n                struct ggml_tensor * src = node->src[j];\n                if (src == NULL) {\n                    continue;\n                }\n\n                size_t src_id = hash_id(src);\n          
      const int src_backend_id = sched->hv_tensor_backend_ids[src_id];\n                GGML_ASSERT(src_backend_id != -1); // all inputs should be assigned by now\n\n                if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {\n                    if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {\n                        ggml_backend_t backend = sched->backends[src_backend_id];\n                        for (int c = 0; c < sched->n_copies; c++) {\n                            struct ggml_tensor * tensor_copy;\n                            if (c == sched->cur_copy) {\n                                tensor_copy = src; // use the original tensor as the current copy\n                            } else {\n                                tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);\n                                ggml_format_name(tensor_copy, \"%s#%s#%d\", ggml_backend_name(backend), src->name, c);\n                            }\n                            ggml_set_input(tensor_copy);\n                            ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor\n                            tensor_id_copy(src_id, src_backend_id, c) = tensor_copy;\n                            SET_CAUSE(tensor_copy, \"4.cpy\");\n                        }\n                        int n_graph_inputs = sched->n_graph_inputs++;\n                        GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);\n                        sched->graph_inputs[n_graph_inputs] = src;\n                    }\n                }\n\n                if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {\n                    // create a copy of the input in the split's backend\n                    if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) {\n                        ggml_backend_t backend = sched->backends[cur_backend_id];\n                        for (int c = 0; c < 
sched->n_copies; c++) {\n                            struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);\n                            ggml_format_name(tensor_copy, \"%s#%s#%d\", ggml_backend_name(backend), src->name, c);\n                            if (sched->n_copies > 1) {\n                                ggml_set_input(tensor_copy);\n                                ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor\n                            }\n                            tensor_id_copy(src_id, cur_backend_id, c) = tensor_copy;\n                            SET_CAUSE(tensor_copy, \"4.cpy\");\n                        }\n                        int n_inputs = split->n_inputs++;\n                        GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);\n                        split->inputs[n_inputs] = src;\n                    }\n                    node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy);\n                }\n            }\n        }\n        split->i_end = graph->n_nodes;\n        sched->n_splits = i_split + 1;\n    }\n\n    if (sched->debug) {\n        ggml_backend_sched_print_assignments(sched, graph);\n    }\n\n    // swap node_backend_ids and leaf _backend_ids with prevs\n    {\n        int * tmp = sched->node_backend_ids;\n        sched->node_backend_ids = sched->prev_node_backend_ids;\n        sched->prev_node_backend_ids = tmp;\n\n        tmp = sched->leaf_backend_ids;\n        sched->leaf_backend_ids = sched->prev_leaf_backend_ids;\n        sched->prev_leaf_backend_ids = tmp;\n    }\n\n    int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies;\n\n    // remember the actual graph_size for performing reallocation checks later [GGML_SCHED_DEBUG_REALLOC]\n    sched->debug_prev_graph_size = sched->debug_graph_size;\n    sched->debug_graph_size = graph_size;\n\n    if (sched->graph.size < graph_size) {\n    
    sched->graph.size = graph_size;\n        sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));\n        sched->graph.leafs = (ggml_tensor **) realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *));\n        GGML_ASSERT(sched->graph.nodes != NULL);\n        GGML_ASSERT(sched->graph.leafs != NULL);\n    }\n    sched->graph.n_nodes = 0;\n    sched->graph.n_leafs = 0;\n\n    struct ggml_cgraph * graph_copy = &sched->graph;\n\n    for (int i = 0; i < sched->n_splits; i++) {\n        struct ggml_backend_sched_split * split = &sched->splits[i];\n        split->graph = ggml_graph_view(graph, split->i_start, split->i_end);\n\n        // Optimize this split of the graph. This needs to happen before we make graph_copy,\n        // so they are in sync.\n        ggml_backend_graph_optimize(sched->backends[split->backend_id], &split->graph);\n\n        // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split\n        for (int j = 0; j < split->n_inputs; j++) {\n            assert(graph_copy->size > (graph_copy->n_nodes + 1));\n\n            struct ggml_tensor * input = split->inputs[j];\n            const size_t input_id = hash_id(input);\n            struct ggml_tensor * input_cpy = tensor_id_copy(input_id, split->backend_id, sched->cur_copy);\n\n            // add a dependency to the input source so that it is not freed before the copy is done\n            struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);\n            input_dep->src[0] = input;\n            sched->node_backend_ids[graph_copy->n_nodes] = sched->hv_tensor_backend_ids[input_id];\n            graph_copy->nodes[graph_copy->n_nodes++] = input_dep;\n\n            // add a dependency to the input copy so that it is allocated at the start of the split\n            sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id;\n            graph_copy->nodes[graph_copy->n_nodes++] = 
input_cpy;\n        }\n\n        for (int j = split->i_start; j < split->i_end; j++) {\n            assert(graph_copy->size > graph_copy->n_nodes);\n            sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]);\n            graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];\n        }\n    }\n\n    if (sched->n_copies > 1) {\n        // add input copies as leafs so that they are allocated first\n        for (int i = 0; i < sched->n_graph_inputs; i++) {\n            struct ggml_tensor * input = sched->graph_inputs[i];\n            size_t id = hash_id(input);\n            int backend_id = tensor_backend_id(input);\n            for (int c = 0; c < sched->n_copies; c++) {\n                struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);\n                sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;\n                assert(graph_copy->size > graph_copy->n_leafs);\n                graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;\n            }\n        }\n\n        for (int i = 0; i < sched->n_splits; i++) {\n            struct ggml_backend_sched_split * split = &sched->splits[i];\n            int backend_id = split->backend_id;\n            for (int j = 0; j < split->n_inputs; j++) {\n                struct ggml_tensor * input = split->inputs[j];\n                size_t id = hash_id(input);\n                for (int c = 0; c < sched->n_copies; c++) {\n                    struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);\n                    sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;\n                    assert(graph_copy->size > graph_copy->n_leafs);\n                    graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;\n                }\n            }\n        }\n    }\n\n    // add leafs from the original graph\n    for (int i = 0; i < graph->n_leafs; i++) {\n        struct ggml_tensor * leaf = graph->leafs[i];\n        
sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);\n        assert(graph_copy->size > graph_copy->n_leafs);\n        graph_copy->leafs[graph_copy->n_leafs++] = leaf;\n    }\n}\n\nstatic bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {\n    bool backend_ids_changed = false;\n    for (int i = 0; i < sched->graph.n_nodes; i++) {\n        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&\n            sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {\n            backend_ids_changed = true;\n            break;\n        }\n    }\n    if (!backend_ids_changed) {\n        for (int i = 0; i < sched->graph.n_leafs; i++) {\n            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&\n                sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {\n                backend_ids_changed = true;\n                break;\n            }\n        }\n    }\n\n    // allocate graph\n    if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: failed to allocate graph, reserving (backend_ids_changed = %d)\\n\", __func__, backend_ids_changed);\n#endif\n\n        if (sched->debug_realloc > 0) {\n            // we are interested only in situations where the graph was reallocated even though its size remained the same [GGML_SCHED_DEBUG_REALLOC]\n            // example: https://github.com/ggml-org/llama.cpp/pull/17143\n            const bool unexpected = !backend_ids_changed && sched->debug_prev_graph_size == sched->debug_graph_size;\n\n            if (unexpected || sched->debug_realloc > 1) {\n                GGML_ABORT(\"%s: unexpected graph reallocation (graph size = %d, nodes = %d, leafs = %d), debug_realloc = %d\\n\", __func__,\n                        sched->debug_graph_size, sched->graph.n_nodes, sched->graph.n_leafs, sched->debug_realloc);\n       
     }\n        }\n\n        // the re-allocation may cause the split inputs to be moved to a different address\n        // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy\n        for (int i = 0; i < sched->n_backends; i++) {\n            ggml_backend_synchronize(sched->backends[i]);\n        }\n\n        ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);\n        if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {\n            GGML_LOG_ERROR(\"%s: failed to allocate graph\\n\", __func__);\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {\n    GGML_ASSERT(sched);\n    struct ggml_backend_sched_split * splits = sched->splits;\n\n    ggml_tensor * prev_ids_tensor = nullptr;\n    std::vector<int32_t> ids;\n    std::vector<ggml_bitset_t> used_ids;\n\n    for (int split_id = 0; split_id < sched->n_splits; split_id++) {\n        struct ggml_backend_sched_split * split = &splits[split_id];\n        int split_backend_id = split->backend_id;\n        ggml_backend_t split_backend = sched->backends[split_backend_id];\n\n        // copy the input tensors to the split backend\n        for (int input_id = 0; input_id < split->n_inputs; input_id++) {\n            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);\n            struct ggml_tensor * input = split->inputs[input_id];\n            struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);\n\n            if (input->flags & GGML_TENSOR_FLAG_INPUT) {\n                // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done\n                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {\n                    
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);\n                } else {\n                    ggml_backend_synchronize(split_backend);\n                }\n                ggml_backend_tensor_copy(input, input_cpy);\n            } else {\n                // wait for the split backend to finish using the input before overwriting it\n                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {\n                    ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);\n                } else {\n                    ggml_backend_synchronize(split_backend);\n                }\n\n                // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used\n                ggml_tensor * node = split->graph.nodes[0];\n                if (split->graph.n_nodes > 0 &&\n                    ggml_backend_buffer_get_usage(input->buffer) == GGML_BACKEND_BUFFER_USAGE_WEIGHTS &&\n                    ggml_backend_buffer_is_host(input->buffer) && (\n                    (node->src[0] == input_cpy && node->op == GGML_OP_MUL_MAT_ID)\n                    //|| (node->src[1] == input_cpy && node->op == GGML_OP_ADD_ID) /* GGML_OP_ADD_ID weights are small and not worth splitting */\n                    )) {\n\n                    const int64_t n_expert   = node->op == GGML_OP_MUL_MAT_ID ? input->ne[2] : input->ne[1];\n                    const size_t expert_size = node->op == GGML_OP_MUL_MAT_ID ? 
input->nb[2] : input->nb[1];\n\n                    ggml_backend_synchronize(input_backend);\n\n                    // get the ids\n                    ggml_tensor * ids_tensor = node->src[2];\n                    ggml_backend_t ids_backend = split_backend;\n\n                    // if the ids tensor is also an input of the split, it may not have been copied yet to the split backend\n                    // in that case, we use the original ids tensor\n                    for (int i = input_id + 1; i < split->n_inputs; i++) {\n                        if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {\n                            ids_tensor = split->inputs[i];\n                            ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);\n                            break;\n                        }\n                    }\n\n                    if (ids_tensor != prev_ids_tensor) {\n                        ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));\n                        ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));\n                        ggml_backend_synchronize(ids_backend);\n\n                        // find the used experts\n                        used_ids.clear();\n                        used_ids.resize(ggml_bitset_size(n_expert));\n                        for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {\n                            for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {\n                                int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];\n                                GGML_ASSERT(id >= 0 && id < n_expert);\n                                ggml_bitset_set(used_ids.data(), id);\n                            }\n                        }\n\n                        prev_ids_tensor = ids_tensor;\n                    }\n\n                    // group 
consecutive experts and copy them together\n                    auto copy_experts = [&](int32_t first_id, int32_t last_id) {\n                        const size_t expert_offset = first_id * expert_size;\n                        const size_t expert_size_copy =  (last_id - first_id + 1) * expert_size;\n                        const size_t padding = std::min<size_t>(expert_size, 512);\n                        const size_t padding_end = last_id < n_expert - 1 ? padding : 0;\n\n                        ggml_backend_tensor_set_async(split_backend,\n                            input_cpy,\n                            (const uint8_t *)input->data + expert_offset, expert_offset,\n                            // copy a bit extra at the to ensure there are no NaNs in the padding of the last expert\n                            // this is necessary for MMQ in the CUDA backend\n                            expert_size_copy + padding_end);\n                    };\n\n                    int id = 0;\n                    while (!ggml_bitset_get(used_ids.data(), id)) {\n                        id++;\n                    }\n                    int32_t first_id = id;\n                    int32_t last_id = first_id;\n\n                    for (++id; id < n_expert; ++id) {\n                        if (!ggml_bitset_get(used_ids.data(), id)) {\n                            continue;\n                        }\n\n                        if (id == last_id + 1) {\n                            last_id = id;\n                            continue;\n                        }\n\n                        copy_experts(first_id, last_id);\n\n                        first_id = id;\n                        last_id = id;\n                    }\n                    copy_experts(first_id, last_id);\n                } else {\n                    // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies 
and events\n                    // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface\n                    if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {\n                        ggml_backend_synchronize(input_backend);\n                        if (sched->events[split_backend_id][sched->cur_copy] != NULL) {\n                            ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);\n                        } else {\n                            ggml_backend_synchronize(split_backend);\n                        }\n                        ggml_backend_tensor_copy(input, input_cpy);\n                    }\n                }\n            }\n        }\n\n        if (!sched->callback_eval) {\n            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);\n            if (ec != GGML_STATUS_SUCCESS) {\n                return ec;\n            }\n        } else {\n            // similar to ggml_backend_compare_graph_backend\n            for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {\n                struct ggml_tensor * t = split->graph.nodes[j0];\n\n                // check if the user needs data from this node\n                bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);\n\n                int j1 = j0;\n\n                // determine the range [j0, j1] of nodes that can be computed together\n                while (!need && j1 < split->graph.n_nodes - 1) {\n                    t = split->graph.nodes[++j1];\n                    need = sched->callback_eval(t, true, sched->callback_eval_user_data);\n                }\n\n                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);\n\n                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);\n                if (ec != 
GGML_STATUS_SUCCESS) {\n                    return ec;\n                }\n\n                // TODO: pass backend to the callback, then the user can decide if they want to synchronize\n                ggml_backend_synchronize(split_backend);\n\n                if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {\n                    break;\n                }\n\n                j0 = j1;\n            }\n        }\n\n        // record the event of this copy\n        if (split->n_inputs > 0) {\n            if (sched->events[split_backend_id][sched->cur_copy] != NULL) {\n                ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy], split_backend);\n            }\n        }\n    }\n\n    return GGML_STATUS_SUCCESS;\n}\n\nggml_backend_sched_t ggml_backend_sched_new(\n        ggml_backend_t * backends,\n        ggml_backend_buffer_type_t * bufts,\n        int n_backends,\n        size_t graph_size,\n        bool parallel,\n        bool op_offload) {\n    GGML_ASSERT(n_backends > 0);\n    GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);\n    GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);\n\n    struct ggml_backend_sched * sched = (ggml_backend_sched *) calloc(1, sizeof(struct ggml_backend_sched));\n\n    const char * GGML_SCHED_DEBUG = getenv(\"GGML_SCHED_DEBUG\");\n    sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0;\n\n    sched->debug_realloc = 0;\n#ifdef GGML_SCHED_NO_REALLOC\n    sched->debug_realloc = 1;\n#endif\n    const char * GGML_SCHED_DEBUG_REALLOC = getenv(\"GGML_SCHED_DEBUG_REALLOC\");\n    sched->debug_realloc = GGML_SCHED_DEBUG_REALLOC ? atoi(GGML_SCHED_DEBUG_REALLOC) : sched->debug_realloc;\n\n    sched->n_backends = n_backends;\n    sched->n_copies = parallel ? 
GGML_SCHED_MAX_COPIES : 1;\n\n    // initialize hash table\n    // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead)\n    sched->hash_set    = ggml_hash_set_new(graph_size);\n    sched->hv_tensor_backend_ids = (int *) malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));\n    sched->hv_tensor_copies      = (ggml_tensor **) malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));\n\n    const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph\n    const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;\n    sched->node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->node_backend_ids[0]));\n    sched->leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));\n    sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));\n    sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));\n\n    sched->debug_graph_size = 0;\n    sched->debug_prev_graph_size = 0;\n\n    sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);\n    sched->context_buffer = (char *) malloc(sched->context_buffer_size);\n\n    const int initial_splits_capacity = 16;\n    sched->splits = (ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0]));\n    sched->splits_capacity = initial_splits_capacity;\n\n    for (int b = 0; b < n_backends; b++) {\n        sched->backends[b] = backends[b];\n        sched->bufts[b] = bufts ? 
bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);\n        GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));\n\n        if (sched->n_copies > 1) {\n            for (int c = 0; c < sched->n_copies; c++) {\n                sched->events[b][c] = ggml_backend_event_new(backends[b]->device);\n            }\n        }\n    }\n\n    sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);\n    sched->op_offload = op_offload;\n\n    ggml_backend_sched_reset(sched);\n\n    return sched;\n}\n\nvoid ggml_backend_sched_free(ggml_backend_sched_t sched) {\n    if (sched == NULL) {\n        return;\n    }\n    for (int b = 0; b < sched->n_backends; b++) {\n        for (int c = 0; c < sched->n_copies; c++) {\n            ggml_backend_event_free(sched->events[b][c]);\n        }\n    }\n    ggml_gallocr_free(sched->galloc);\n    ggml_free(sched->ctx);\n    ggml_hash_set_free(&sched->hash_set);\n    free(sched->splits);\n    free(sched->hv_tensor_backend_ids);\n    free(sched->hv_tensor_copies);\n    free(sched->node_backend_ids);\n    free(sched->leaf_backend_ids);\n    free(sched->prev_node_backend_ids);\n    free(sched->prev_leaf_backend_ids);\n    free(sched->context_buffer);\n    free(sched->graph.nodes);\n    free(sched->graph.leafs);\n    free(sched);\n}\n\nvoid ggml_backend_sched_reset(ggml_backend_sched_t sched) {\n    GGML_ASSERT(sched);\n    // reset state for the next run\n    if (!sched->is_reset) {\n        ggml_hash_set_reset(&sched->hash_set);\n        memset(sched->hv_tensor_backend_ids, -1, sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));\n        memset(sched->hv_tensor_copies,       0, sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));\n        sched->is_reset = true;\n    }\n    sched->is_alloc = false;\n}\n\nvoid ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes) {\n    GGML_ASSERT(sched);\n    
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);\n    GGML_ASSERT(sizes);\n\n    ggml_backend_sched_reset(sched);\n\n    ggml_backend_sched_synchronize(sched);\n\n    ggml_backend_sched_split_graph(sched, measure_graph);\n\n    ggml_gallocr_reserve_n_size(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids, sizes);\n}\n\nbool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {\n    GGML_ASSERT(sched);\n    GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);\n\n    ggml_backend_sched_synchronize(sched);\n\n    ggml_backend_sched_split_graph(sched, measure_graph);\n\n    if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {\n        return false;\n    }\n\n    ggml_backend_sched_reset(sched);\n\n    return true;\n}\n\nbool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {\n    GGML_ASSERT(sched);\n    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);\n    GGML_ASSERT(!sched->is_alloc);\n\n    sched->cur_copy = sched->next_copy;\n    sched->next_copy = (sched->next_copy + 1) % sched->n_copies;\n\n    ggml_backend_sched_split_graph(sched, graph);\n\n    if (!ggml_backend_sched_alloc_splits(sched)) {\n        return false;\n    }\n\n    sched->is_alloc = true;\n\n    return true;\n}\n\nenum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {\n    enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph);\n    ggml_backend_sched_synchronize(sched);\n    return err;\n}\n\nenum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {\n    GGML_ASSERT(sched);\n    if (!sched->is_reset && !sched->is_alloc) {\n        ggml_backend_sched_reset(sched);\n    }\n\n    if (!sched->is_alloc) {\n    
    if (!ggml_backend_sched_alloc_graph(sched, graph)) {\n            return GGML_STATUS_ALLOC_FAILED;\n        }\n    }\n\n    return ggml_backend_sched_compute_splits(sched);\n}\n\nvoid ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {\n    GGML_ASSERT(sched);\n    for (int i = 0; i < sched->n_backends; i++) {\n        ggml_backend_synchronize(sched->backends[i]);\n    }\n    if (!sched->is_alloc) {\n        // if the graph is not already allocated, always use copy 0 after a synchronization\n        // this ensures that during generation the same copy is used every time,\n        // which avoids changes in the graph that could cause CUDA or other graphs to be disabled\n        sched->next_copy = 0;\n    }\n}\n\nvoid ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {\n    GGML_ASSERT(sched);\n    sched->callback_eval = callback;\n    sched->callback_eval_user_data = user_data;\n}\n\nint ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {\n    GGML_ASSERT(sched);\n    return sched->n_splits;\n}\n\nint ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {\n    GGML_ASSERT(sched);\n    return sched->n_copies;\n}\n\nint ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {\n    GGML_ASSERT(sched);\n    return sched->n_backends;\n}\n\nggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {\n    GGML_ASSERT(sched);\n    GGML_ASSERT(i >= 0 && i < sched->n_backends);\n    return sched->backends[i];\n}\n\nggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend) {\n    GGML_ASSERT(sched);\n    int backend_index = ggml_backend_sched_backend_id(sched, backend);\n    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);\n\n    return sched->bufts[backend_index];\n}\n\nsize_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {\n  
  GGML_ASSERT(sched);\n    int backend_index = ggml_backend_sched_backend_id(sched, backend);\n    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);\n\n    return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);\n}\n\nvoid ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {\n    GGML_ASSERT(sched);\n    int backend_index = ggml_backend_sched_backend_id(sched, backend);\n    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);\n    tensor_backend_id(node) = backend_index;\n    SET_CAUSE(node, \"usr\");\n    sched->is_reset = false;\n}\n\nggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {\n    GGML_ASSERT(sched);\n    int backend_index = tensor_backend_id(node);\n    if (backend_index == -1) {\n        return NULL;\n    }\n    return sched->backends[backend_index];\n}\n\n// utils\n\nenum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {\n    GGML_ASSERT(tensor);\n    GGML_ASSERT(tensor->buffer == NULL);\n    GGML_ASSERT(tensor->view_src != NULL);\n    GGML_ASSERT(tensor->view_src->buffer != NULL);\n    GGML_ASSERT(tensor->view_src->data != NULL);\n\n    tensor->buffer = tensor->view_src->buffer;\n    tensor->data = (char *)tensor->view_src->data + tensor->view_offs;\n    return ggml_backend_buffer_init_tensor(tensor->buffer, tensor);\n}\n\nenum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {\n    GGML_ASSERT(tensor);\n    GGML_ASSERT(tensor->buffer == NULL);\n    GGML_ASSERT(tensor->data == NULL);\n    GGML_ASSERT(tensor->view_src == NULL);\n    GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));\n    GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=\n                (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));\n\n    tensor->buffer = 
buffer;\n    tensor->data = addr;\n    return ggml_backend_buffer_init_tensor(buffer, tensor);\n}\n\nstatic struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,\n    struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {\n\n    GGML_ASSERT(src != NULL);\n    GGML_ASSERT(src->data && \"graph must be allocated\");\n\n    size_t id = ggml_hash_insert(&hash_set, src);\n    if (id == GGML_HASHSET_ALREADY_EXISTS) {\n        return node_copies[ggml_hash_find(&hash_set, src)];\n    }\n\n    struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);\n    if (src->view_src != NULL) {\n        dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);\n        dst->view_offs = src->view_offs;\n    }\n    dst->op = src->op;\n    dst->flags = src->flags;\n    memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));\n    ggml_set_name(dst, src->name);\n\n    // copy src\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        struct ggml_tensor * s = src->src[i];\n        if (s == NULL) {\n            continue;\n        }\n        dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);\n    }\n\n    node_copies[id] = dst;\n    return dst;\n}\n\nstatic void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {\n    size_t id = ggml_hash_find(hash_set, src);\n    if (node_init[id]) {\n        return;\n    }\n    node_init[id] = true;\n\n    struct ggml_tensor * dst = node_copies[id];\n    if (dst->view_src != NULL) {\n        graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);\n        enum ggml_status status = ggml_backend_view_init(dst);\n        GGML_ASSERT(status == GGML_STATUS_SUCCESS);\n    }\n    else {\n        
ggml_backend_tensor_copy(src, dst);\n    }\n\n    // init src\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        struct ggml_tensor * s = src->src[i];\n        if (s == NULL) {\n            continue;\n        }\n        graph_copy_init_tensor(hash_set, node_copies, node_init, s);\n    }\n}\n\nstruct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {\n    GGML_ASSERT(graph);\n    struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);\n    struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT\n    bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0]));\n\n    struct ggml_init_params params = {\n        /* .mem_size   = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),\n        /* .mem_buffer = */ NULL,\n        /* .no_alloc   = */ true\n    };\n\n    struct ggml_context * ctx_allocated = ggml_init(params);\n    struct ggml_context * ctx_unallocated = ggml_init(params);\n\n    if (ctx_allocated == NULL || ctx_unallocated == NULL) {\n        GGML_LOG_ERROR(\"%s: failed to allocate context for graph copy\\n\", __func__);\n        ggml_hash_set_free(&hash_set);\n        free(node_copies);\n        free(node_init);\n        ggml_free(ctx_allocated);\n        ggml_free(ctx_unallocated);\n        return {\n            /* .buffer           = */ NULL,\n            /* .ctx_allocated    = */ NULL,\n            /* .ctx_unallocated  = */ NULL,\n            /* .graph            = */ NULL,\n        };\n    }\n\n    // dup nodes\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);\n    }\n\n    // allocate nodes\n    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);\n    if (buffer == NULL) {\n        
GGML_LOG_ERROR(\"%s: failed to allocate buffer for graph copy\\n\", __func__);\n        ggml_hash_set_free(&hash_set);\n        free(node_copies);\n        free(node_init);\n        ggml_free(ctx_allocated);\n        ggml_free(ctx_unallocated);\n        return {\n            /* .buffer           = */ NULL,\n            /* .ctx_allocated    = */ NULL,\n            /* .ctx_unallocated  = */ NULL,\n            /* .graph            = */ NULL,\n        };\n    }\n\n    //printf(\"copy buffer size: %zu MB\\n\", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);\n\n    // copy data and init views\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        graph_copy_init_tensor(&hash_set, node_copies, node_init, node);\n    }\n\n    // build graph copy\n    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);\n    for (int i = 0; i < graph->n_nodes; i++) {\n        struct ggml_tensor * node = graph->nodes[i];\n        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(&hash_set, node)];\n        graph_copy->nodes[i] = node_copy;\n    }\n    graph_copy->n_nodes = graph->n_nodes;\n\n    ggml_hash_set_free(&hash_set);\n    free(node_copies);\n    free(node_init);\n\n    return {\n        /* .buffer           = */ buffer,\n        /* .ctx_allocated    = */ ctx_allocated,\n        /* .ctx_unallocated  = */ ctx_unallocated,\n        /* .graph            = */ graph_copy,\n    };\n}\n\nvoid ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {\n    ggml_backend_buffer_free(copy.buffer);\n    ggml_free(copy.ctx_allocated);\n    ggml_free(copy.ctx_unallocated);\n}\n\nbool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) {\n    struct ggml_backend_graph_copy copy = 
ggml_backend_graph_copy(backend2, graph);\n    if (copy.buffer == NULL) {\n        return false;\n    }\n\n    struct ggml_cgraph * g1 = graph;\n    struct ggml_cgraph * g2 = copy.graph;\n\n    assert(g1->n_nodes == g2->n_nodes);\n\n    if (num_test_nodes != 0) {\n        GGML_ASSERT(test_nodes);\n        // Compute the whole graph and only test the output for specific tensors\n        ggml_backend_graph_compute(backend1, g1);\n        ggml_backend_graph_compute(backend2, g2);\n\n        bool verified = false;\n        for (int i = 0; i < g1->n_nodes; i++) {\n            for (size_t j = 0; j < num_test_nodes; ++j) {\n                if (g1->nodes[i] == test_nodes[j]) {\n                    callback(i, g1->nodes[i], g2->nodes[i], user_data);\n                    verified = true;\n                }\n            }\n        }\n        GGML_ASSERT(verified);\n    } else {\n        for (int i = 0; i < g1->n_nodes; i++) {\n            struct ggml_tensor * t1 = g1->nodes[i];\n            struct ggml_tensor * t2 = g2->nodes[i];\n\n            assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));\n\n            struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);\n            struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);\n\n            ggml_backend_graph_compute(backend1, &g1v);\n            ggml_backend_graph_compute(backend2, &g2v);\n\n            if (ggml_is_view_op(t1->op)) {\n                continue;\n            }\n\n            // compare results, calculate rms etc\n            if (!callback(i, t1, t2, user_data)) {\n                break;\n            }\n        }\n    }\n    ggml_backend_graph_copy_free(copy);\n\n    return true;\n}\n\n// CPU backend - buffer\n\nstatic void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    uintptr_t data = (uintptr_t)buffer->context;\n\n    // align the buffer\n    if (data % TENSOR_ALIGNMENT != 0) {\n        data = GGML_PAD(data, TENSOR_ALIGNMENT);\n    }\n\n    
return (void *)data;\n}\n\nstatic void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    GGML_ASSERT(buffer);\n    ggml_aligned_free(buffer->context, buffer->size);\n}\n\nstatic void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {\n    GGML_ASSERT(tensor);\n    memset((char *)tensor->data + offset, value, size);\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    GGML_ASSERT(tensor);\n    memcpy((char *)tensor->data + offset, data, size);\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    GGML_ASSERT(tensor);\n    memcpy(data, (const char *)tensor->data + offset, size);\n\n    GGML_UNUSED(buffer);\n}\n\nstatic bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {\n    GGML_ASSERT(src);\n    if (ggml_backend_buffer_is_host(src->buffer)) {\n        memcpy(dst->data, src->data, ggml_nbytes(src));\n        return true;\n    }\n    return false;\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    GGML_ASSERT(buffer);\n    memset(buffer->context, value, buffer->size);\n}\n\nstatic const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {\n    /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,\n    /* .init_tensor     = */ NULL, // no initialization required\n    /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,\n    /* .get_tensor      = */ 
ggml_backend_cpu_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_cpu_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\nstatic const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {\n    /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed\n    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,\n    /* .init_tensor     = */ NULL, // no initialization required\n    /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_cpu_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\n// CPU backend buffer type\n\n// this buffer type is defined here to make it available to all backends\n\nstatic const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    return \"CPU\";\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    void * data = ggml_aligned_malloc(size);\n\n    if (data == NULL) {\n        GGML_LOG_ERROR(\"%s: failed to allocate buffer of size %zu\\n\", __func__, size);\n        return NULL;\n    }\n\n    return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, size);\n}\n\nstatic size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return TENSOR_ALIGNMENT;\n\n    GGML_UNUSED(buft);\n}\n\nstatic bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) {\n    return true;\n\n    GGML_UNUSED(buft);\n}\n\nggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {\n    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {\n        /* .iface   = 
*/ {\n            /* .get_name         = */ ggml_backend_cpu_buffer_type_get_name,\n            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,\n            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,\n            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX\n            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes\n            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,\n        },\n        /* .device  = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),\n        /* .context = */ NULL,\n    };\n\n    return &ggml_backend_cpu_buffer_type;\n}\n\nstatic const char * ggml_backend_cpu_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {\n    return \"CPU_Mapped\";\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_cpu_buffer_from_ptr_type(void) {\n    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {\n        /* .iface   = */ {\n            /* .get_name         = */ ggml_backend_cpu_buffer_from_ptr_type_get_name,\n            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,\n            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,\n            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX\n            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes\n            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,\n        },\n        /* .device  = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),\n        /* .context = */ NULL,\n    };\n\n    return &ggml_backend_cpu_buffer_type;\n}\n\nggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {\n    GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && \"buffer pointer must be aligned\");\n    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);\n}\n"
  },
  {
    "path": "src/ggml-blas/CMakeLists.txt",
    "content": "if (GGML_STATIC)\n    set(BLA_STATIC ON)\nendif()\n#if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22)\n#    set(BLA_SIZEOF_INTEGER 8)\n#endif()\n\nset(BLA_VENDOR ${GGML_BLAS_VENDOR})\nfind_package(BLAS)\n\nif (BLAS_FOUND)\n    message(STATUS \"BLAS found, Libraries: ${BLAS_LIBRARIES}\")\n\n    ggml_add_backend_library(ggml-blas\n                             ggml-blas.cpp\n                            )\n\n    if (${GGML_BLAS_VENDOR} MATCHES \"Apple\")\n        add_compile_definitions(ACCELERATE_NEW_LAPACK)\n        add_compile_definitions(ACCELERATE_LAPACK_ILP64)\n        add_compile_definitions(GGML_BLAS_USE_ACCELERATE)\n    elseif (\"${BLAS_INCLUDE_DIRS}\" STREQUAL \"\")\n        # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.\n        # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268\n        find_package(PkgConfig REQUIRED)\n        if (${GGML_BLAS_VENDOR} MATCHES \"Generic\")\n            pkg_check_modules(DepBLAS blas)\n        elseif (${GGML_BLAS_VENDOR} MATCHES \"OpenBLAS\")\n            # As of openblas v0.3.22, the 64-bit is named openblas64.pc\n            pkg_check_modules(DepBLAS openblas64)\n            if (NOT DepBLAS_FOUND)\n                pkg_check_modules(DepBLAS openblas)\n            endif()\n        elseif (${GGML_BLAS_VENDOR} MATCHES \"FLAME\")\n            pkg_check_modules(DepBLAS blis)\n        elseif (${GGML_BLAS_VENDOR} MATCHES \"ATLAS\")\n            pkg_check_modules(DepBLAS blas-atlas)\n        elseif (${GGML_BLAS_VENDOR} MATCHES \"FlexiBLAS\")\n            pkg_check_modules(DepBLAS flexiblas_api)\n        elseif (${GGML_BLAS_VENDOR} MATCHES \"Intel\")\n            # all Intel* libraries share the same include path\n            pkg_check_modules(DepBLAS mkl-sdl)\n        elseif (${GGML_BLAS_VENDOR} MATCHES \"NVHPC\")\n            # this doesn't provide pkg-config\n            # suggest to assign BLAS_INCLUDE_DIRS on your own\n            if (\"${NVHPC_VERSION}\" STREQUAL \"\")\n                message(WARNING 
\"Better to set NVHPC_VERSION\")\n            else()\n                set(DepBLAS_FOUND ON)\n                set(DepBLAS_INCLUDE_DIRS \"/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include\")\n            endif()\n        endif()\n        if (DepBLAS_FOUND)\n            set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})\n        else()\n            message(WARNING \"BLAS_INCLUDE_DIRS neither been provided nor been automatically\"\n            \" detected by pkgconfig, trying to find cblas.h from possible paths...\")\n            find_path(BLAS_INCLUDE_DIRS\n                NAMES cblas.h\n                HINTS\n                    /usr/include\n                    /usr/local/include\n                    /usr/include/openblas\n                    /opt/homebrew/opt/openblas/include\n                    /usr/local/opt/openblas/include\n                    /usr/include/x86_64-linux-gnu/openblas/include\n            )\n        endif()\n    endif()\n\n    message(STATUS \"BLAS found, Includes: ${BLAS_INCLUDE_DIRS}\")\n\n    target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})\n\n    if (\"${GGML_BLAS_VENDOR}\" STREQUAL \"\")\n        message(WARNING \"GGML_BLAS_VENDOR is not set; some methods may not link properly.\")\n    endif()\n\n    if (\"${GGML_BLAS_VENDOR}\" MATCHES \"Intel\" OR (\"${BLAS_INCLUDE_DIRS}\" MATCHES \"mkl\" AND \"${GGML_BLAS_VENDOR}\" MATCHES \"Generic\"))\n        add_compile_definitions(GGML_BLAS_USE_MKL)\n    endif()\n\n    if (\"${GGML_BLAS_VENDOR}\" MATCHES \"OpenBLAS\")\n        add_compile_definitions(GGML_BLAS_USE_OPENBLAS)\n    endif()\n\n    if (\"${GGML_BLAS_VENDOR}\" MATCHES \"FLAME\" OR \"${GGML_BLAS_VENDOR}\" MATCHES \"AOCL\" OR \"${GGML_BLAS_VENDOR}\" MATCHES \"AOCL_mt\")\n        add_compile_definitions(GGML_BLAS_USE_BLIS)\n    endif()\n\n    if (\"${GGML_BLAS_VENDOR}\" MATCHES \"NVPL\")\n        add_compile_definitions(GGML_BLAS_USE_NVPL)\n    endif()\n\n    target_link_libraries   
  (ggml-blas PRIVATE ${BLAS_LIBRARIES})\n    target_include_directories(ggml-blas SYSTEM PRIVATE ${BLAS_INCLUDE_DIRS})\nelse()\n    message(FATAL_ERROR \"BLAS not found, please refer to \"\n                        \"https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors\"\n                        \" to set correct GGML_BLAS_VENDOR\")\nendif()\n"
  },
  {
    "path": "src/ggml-blas/ggml-blas.cpp",
    "content": "#include \"ggml-impl.h\"\n#include \"ggml-blas.h\"\n#include \"ggml-backend-impl.h\"\n\n#include <future>\n#include <vector>\n#include <cstring>\n\n#if defined(GGML_BLAS_USE_ACCELERATE)\n#   include <Accelerate/Accelerate.h>\n#elif defined(GGML_BLAS_USE_MKL)\n#   include <mkl.h>\n#elif defined(GGML_BLAS_USE_BLIS)\n#   include <blis.h>\n#elif defined(GGML_BLAS_USE_NVPL)\n#   include <nvpl_blas.h>\n#else\n#   include <cblas.h>\n#endif\n\nstruct ggml_backend_blas_context {\n    int n_threads = GGML_DEFAULT_N_THREADS;\n    std::unique_ptr<char[]> work_data;\n    size_t work_size = 0;\n#ifndef GGML_USE_OPENMP\n    std::vector<std::future<void>> tasks;\n#endif\n};\n\nstatic void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {\n    const struct ggml_tensor * src0 = dst->src[0];\n    const struct ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const enum ggml_type type = src0->type;\n\n    GGML_ASSERT(ne0 == ne01);\n    GGML_ASSERT(ne1 == ne11);\n    GGML_ASSERT(ne2 == ne12);\n    GGML_ASSERT(ne3 == ne13);\n\n    // we don't support permuted src0 or src1\n    GGML_ASSERT(nb00 == ggml_type_size(type));\n    GGML_ASSERT(nb10 == ggml_type_size(src1->type));\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    // broadcast factors\n    const int64_t r2 = ne12/ne02;\n    const int64_t r3 = ne13/ne03;\n\n    const int64_t ne_plane      = ne01*ne00;\n    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 
0 : ne03*ne02*ne_plane*sizeof(float);\n\n    if (ctx->work_size < desired_wsize) {\n        ctx->work_data.reset(new char[desired_wsize]);\n        ctx->work_size = desired_wsize;\n    }\n    void * wdata = ctx->work_data.get();\n\n    // convert src0 to float\n    if (type != GGML_TYPE_F32) {\n        const auto * type_traits = ggml_get_type_traits(type);\n        ggml_to_float_t const to_float = type_traits->to_float;\n\n        for (int64_t i03 = 0; i03 < ne03; i03++) {\n            for (int64_t i02 = 0; i02 < ne02; i02++) {\n                const void  *       x      = (char *)  src0->data + i02*nb02          + i03*nb03;\n                      float * const wplane = (float *) wdata      + i02*ne_plane      + i03*ne02*ne_plane;\n\n                const int min_cols_per_thread = 4096;\n                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);\n                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);\n\n#ifdef GGML_USE_OPENMP\n                #pragma omp parallel for num_threads(n_threads)\n                for (int64_t i01 = 0; i01 < ne01; i01++) {\n                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);\n                }\n#else\n                for (int i = 1; i < n_threads; i++) {\n                    const int64_t start =       i*ne01/n_threads;\n                    const int64_t end   = (i + 1)*ne01/n_threads;\n                    if (start < end) {\n                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {\n                            for (int64_t i01 = start; i01 < end; i01++) {\n                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);\n                            }\n                        }));\n                    }\n                }\n                {\n                    // reuse the current thread for the first task\n                    const int64_t start = 0;\n           
         const int64_t end   = ne01/n_threads;\n                    for (int64_t i01 = start; i01 < end; i01++) {\n                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);\n                    }\n                }\n#endif\n            }\n        }\n\n#ifndef GGML_USE_OPENMP\n        // wait for all tasks to finish\n        for (auto & task : ctx->tasks) {\n            task.get();\n        }\n        ctx->tasks.clear();\n#endif\n    }\n\n#if defined(GGML_BLAS_USE_OPENBLAS)\n    openblas_set_num_threads(ctx->n_threads);\n#elif defined(GGML_BLAS_USE_BLIS)\n    bli_thread_set_num_threads(ctx->n_threads);\n#elif defined(GGML_BLAS_USE_NVPL)\n    nvpl_blas_set_num_threads(ctx->n_threads);\n#endif\n\n    for (int64_t i13 = 0; i13 < ne13; i13++) {\n        for (int64_t i12 = 0; i12 < ne12; i12++) {\n            const int64_t i03 = i13/r3;\n            const int64_t i02 = i12/r2;\n\n            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);\n            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);\n                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);\n\n            if (type != GGML_TYPE_F32) {\n                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;\n            }\n\n            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,\n                        ne1, ne01, ne10,\n                        1.0f,   y, ne10,\n                                x, ne00,\n                        0.0f,   d, ne01);\n        }\n    }\n}\n\nstatic void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {\n    const struct ggml_tensor * src0 = dst->src[0];\n    const struct ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    GGML_ASSERT(ne0  == ne00);\n    GGML_ASSERT(ne1  == ne10);\n    GGML_ASSERT(ne2  == ne02);\n    GGML_ASSERT(ne02 == ne12);\n    GGML_ASSERT(ne3  == ne13);\n    GGML_ASSERT(ne03 == 
ne13);\n\n    // we don't support permuted src0 or src1\n    GGML_ASSERT(nb00 == sizeof(float));\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    // GGML_ASSERT(nb0 <= nb1);\n    // GGML_ASSERT(nb1 <= nb2);\n    // GGML_ASSERT(nb2 <= nb3);\n\n    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)\n    // src0: (k,n)\n    // src1: (k,m)\n    // dst:  (m,n)\n    //\n    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)\n    // Also expressed as (major,minor)\n    // a: (m,k): so src1 transposed\n    // b: (k,n): so src0\n    // c: (m,n)\n    //\n    // However, if ggml_is_transposed(src1) is true, then\n    // src1->data already contains a transposed version, so sgemm mustn't\n    // transpose it further.\n\n    int n = src0->ne[0];\n    int k = src0->ne[1];\n    int m = src1->ne[0];\n\n    CBLAS_TRANSPOSE transposeA;\n    int lda;\n\n    if (!ggml_is_transposed(src1)) {\n        transposeA = CblasTrans;\n        lda = m;\n    } else {\n        transposeA = CblasNoTrans;\n        lda = k;\n    }\n\n    float * a = (float *) ((char *) src1->data);\n    float * b = (float *) ((char *) src0->data);\n    float * c = (float *) ((char *) dst->data);\n\n    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);\n\n    GGML_UNUSED(ctx);\n}\n\n// backend interface\n\nstatic const char * ggml_backend_blas_get_name(ggml_backend_t backend) {\n    return \"BLAS\";\n\n    GGML_UNUSED(backend);\n}\n\nstatic void ggml_backend_blas_free(ggml_backend_t backend) {\n    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;\n    delete ctx;\n    delete backend;\n}\n\nstatic enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {\n    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;\n\n    for (int i = 0; i < cgraph->n_nodes; i++) 
{\n        struct ggml_tensor * node = cgraph->nodes[i];\n\n        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n            continue;\n        }\n\n        switch (node->op) {\n            case GGML_OP_MUL_MAT:\n                ggml_backend_blas_mul_mat(ctx, node);\n                break;\n\n            case GGML_OP_OUT_PROD:\n                ggml_backend_blas_out_prod(ctx, node);\n                break;\n\n            case GGML_OP_NONE:\n            case GGML_OP_RESHAPE:\n            case GGML_OP_VIEW:\n            case GGML_OP_PERMUTE:\n            case GGML_OP_TRANSPOSE:\n                break;\n\n            default:\n                GGML_ABORT(\"%s: unsupported op %s\\n\", __func__, ggml_op_desc(node));\n        }\n    }\n\n    return GGML_STATUS_SUCCESS;\n\n    GGML_UNUSED(backend);\n}\n\nstatic struct ggml_backend_i blas_backend_i = {\n    /* .get_name                = */ ggml_backend_blas_get_name,\n    /* .free                    = */ ggml_backend_blas_free,\n    /* .set_tensor_async        = */ NULL,\n    /* .get_tensor_async        = */ NULL,\n    /* .cpy_tensor_async        = */ NULL,\n    /* .synchronize             = */ NULL,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_blas_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ NULL,\n};\n\nstatic ggml_guid_t ggml_backend_blas_guid(void) {\n    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };\n    return &guid;\n}\n\nggml_backend_t ggml_backend_blas_init(void) {\n    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;\n\n    ggml_backend_t backend = new ggml_backend {\n        /* .guid    = */ ggml_backend_blas_guid(),\n        /* .iface   = 
*/ blas_backend_i,\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),\n        /* .context = */ ctx,\n    };\n\n#if defined(GGML_BLAS_USE_OPENBLAS) && defined(GGML_USE_OPENMP)\n    if (openblas_get_parallel() != OPENBLAS_OPENMP) {\n        GGML_LOG_DEBUG(\"%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\\n\", __func__);\n    }\n#endif\n\n#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)\n    GGML_LOG_DEBUG(\"%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\\n\", __func__);\n#endif\n\n    return backend;\n}\n\nbool ggml_backend_is_blas(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());\n}\n\nvoid ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {\n    GGML_ASSERT(ggml_backend_is_blas(backend_blas));\n\n    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;\n    ctx->n_threads = n_threads;\n}\n\n// device interface\n\nstatic const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {\n    return \"BLAS\";\n\n    GGML_UNUSED(dev);\n}\n\nstatic const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {\n    #if defined(GGML_BLAS_USE_ACCELERATE)\n        return \"Accelerate\";\n    #elif defined(GGML_BLAS_USE_MKL)\n        return \"MKL\";\n    #elif defined(GGML_BLAS_USE_BLIS)\n        return \"BLIS\";\n    #elif defined(GGML_BLAS_USE_NVPL)\n        return \"NVPL\";\n    #elif defined(GGML_BLAS_USE_OPENBLAS)\n        return \"OpenBLAS\";\n    #else\n        return \"BLAS\";\n    #endif\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    // no memory to report\n    *free  = 0;\n    *total = 0;\n\n    GGML_UNUSED(dev);\n}\n\nstatic enum ggml_backend_dev_type 
ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {\n    return GGML_BACKEND_DEVICE_TYPE_ACCEL;\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_blas_device_get_name(dev);\n    props->description = ggml_backend_blas_device_get_description(dev);\n    props->type        = ggml_backend_blas_device_get_type(dev);\n    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);\n    props->caps = {\n        /* .async                 = */ false,\n        /* .host_buffer           = */ false,\n        /* .buffer_from_host_ptr  = */ true,\n        /* .events                = */ false,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {\n    return ggml_backend_blas_init();\n\n    GGML_UNUSED(dev);\n    GGML_UNUSED(params);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {\n    return ggml_backend_cpu_buffer_type();\n\n    GGML_UNUSED(dev);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {\n    return ggml_backend_cpu_buffer_from_ptr(ptr, size);\n\n    GGML_UNUSED(dev);\n    GGML_UNUSED(max_tensor_size);\n}\n\nstatic bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n\n    switch (op->op) {\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n            return true;\n\n        case GGML_OP_MUL_MAT:\n        {\n            // BLAS usually is only faster for large matrices\n            const struct ggml_tensor * src0 = op->src[0];\n        
    const struct ggml_tensor * src1 = op->src[1];\n\n            const int64_t ne10 = src1->ne[0];\n\n            const int64_t ne0 = op->ne[0];\n            const int64_t ne1 = op->ne[1];\n\n            // TODO: find the optimal value\n            const int64_t min_batch = 32;\n\n            return ggml_is_contiguous(src0) &&\n                   ggml_is_contiguous(src1) &&\n                   src1->type == GGML_TYPE_F32 &&\n                   (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&\n                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);\n        }\n\n        case GGML_OP_OUT_PROD:\n            return op->src[0]->type == GGML_TYPE_F32 &&\n                   op->src[1]->type == GGML_TYPE_F32 &&\n                   ggml_is_matrix(src0) &&\n                   ggml_is_matrix(src1) &&\n                   ggml_is_contiguous(src0) &&\n                   (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&\n                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);\n\n        default:\n            return false;\n\n    }\n\n    GGML_UNUSED(dev);\n}\n\nstatic bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    return ggml_backend_buft_is_host(buft);\n\n    GGML_UNUSED(dev);\n}\n\nstatic const struct ggml_backend_device_i ggml_backend_blas_device_i = {\n    /* .get_name             = */ ggml_backend_blas_device_get_name,\n    /* .get_description      = */ ggml_backend_blas_device_get_description,\n    /* .get_memory           = */ ggml_backend_blas_device_get_memory,\n    /* .get_type             = */ ggml_backend_blas_device_get_type,\n    /* .get_props            = */ ggml_backend_blas_device_get_props,\n    /* .init_backend         = */ ggml_backend_blas_device_init_backend,\n    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,\n    /* 
.buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,\n    /* .supports_op          = */ ggml_backend_blas_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,\n    /* .offload_op           = */ NULL,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ NULL,\n};\n\n// backend reg interface\n\nstatic const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {\n    return \"BLAS\";\n\n    GGML_UNUSED(reg);\n}\n\nstatic size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {\n    return 1;\n\n    GGML_UNUSED(reg);\n}\n\nstatic ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    GGML_ASSERT(index == 0);\n\n    static ggml_backend_device ggml_backend_blas_device = {\n        /* .iface   = */ ggml_backend_blas_device_i,\n        /* .reg     = */ reg,\n        /* .context = */ nullptr,\n    };\n\n    return &ggml_backend_blas_device;\n\n    GGML_UNUSED(reg);\n    GGML_UNUSED(index);\n}\n\nstatic void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    if (std::strcmp(name, \"ggml_backend_set_n_threads\") == 0) {\n        return (void *)ggml_backend_blas_set_n_threads;\n    }\n    return NULL;\n\n    GGML_UNUSED(reg);\n    GGML_UNUSED(name);\n}\n\nstatic const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {\n    /* .get_name         = */ ggml_backend_blas_reg_get_name,\n    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_blas_reg_get_device,\n    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,\n};\n\nggml_backend_reg_t ggml_backend_blas_reg(void) {\n    static struct ggml_backend_reg ggml_backend_blas_reg = {\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_blas_reg_i,\n        /* .context     = */ NULL,\n    
};\n\n    return &ggml_backend_blas_reg;\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_blas_reg)\n"
  },
  {
    "path": "src/ggml-cann/CMakeLists.txt",
    "content": "if (\"cann${CANN_INSTALL_DIR}\" STREQUAL \"cann\" AND DEFINED ENV{ASCEND_TOOLKIT_HOME})\n    set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME})\n    message(STATUS \"CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}\")\nendif()\n\n# Auto-detect the SoC type and version when SOC_TYPE is not given; abort the build if detection fails.\nset(SOC_VERSION \"\")\nfunction(detect_ascend_soc_type SOC_VERSION)\n    execute_process(\n        COMMAND bash -c \"npu-smi info|awk -F' ' 'NF > 0 && NR==7 {print $3}'\"\n        OUTPUT_VARIABLE npu_info\n        RESULT_VARIABLE npu_result\n        OUTPUT_STRIP_TRAILING_WHITESPACE\n    )\n    if(\"${npu_info}\" STREQUAL \"\" OR ${npu_result})\n        message(FATAL_ERROR \"Auto-detech ascend soc type failed, please specify manually or check ascend device working normally.\")\n    endif()\n    set(${SOC_VERSION} \"Ascend${npu_info}\" PARENT_SCOPE)\nendfunction()\n\nif(NOT SOC_TYPE)\n    detect_ascend_soc_type(SOC_VERSION)\n    set(SOC_TYPE \"${SOC_VERSION}\")\n    message(STATUS \"CANN: SOC_VERSION auto-detected is:${SOC_VERSION}\")\nendif()\n\nstring(TOLOWER ${SOC_TYPE} SOC_VERSION) # SOC_VERSION must be lowercase\n\n# Construct the SoC-specific compile option from the SoC major serial number, e.g. ASCEND_910B or ASCEND_310P.\nstring(REGEX MATCH \"[0-9]+[a-zA-Z]\" SOC_TYPE_MAJOR_SN \"${SOC_VERSION}\")\nset(SOC_TYPE_COMPILE_OPTION \"ASCEND_${SOC_TYPE_MAJOR_SN}\")\nstring(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)\nmessage(STATUS \"CANN: SOC_VERSION =  ${SOC_VERSION}\")\noption(USE_ACL_GRAPH \"Enable CANN graph execution (ACL graph mode)\" OFF)\n\nif(USE_ACL_GRAPH AND (SOC_TYPE_MAJOR_SN STREQUAL \"310P\" OR SOC_TYPE_COMPILE_OPTION STREQUAL \"ASCEND_310P\"))\n    message(FATAL_ERROR\n        \"CANN Graph (ACL graph mode) is not supported on 310P devices. 
\"\n        \"Please build with -DUSE_ACL_GRAPH=OFF or use a supported SOC.\")\nendif()\n\nif (CANN_INSTALL_DIR)\n    # Only Support Linux.\n    if (NOT UNIX)\n        message(FATAL_ERROR \"CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}\")\n    endif()\n\n    # Supported platforms: x86-64, arm64\n    if (CMAKE_SYSTEM_PROCESSOR STREQUAL \"aarch64\")\n    elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL \"x86_64\" OR CMAKE_SYSTEM_PROCESSOR STREQUAL \"amd64\")\n    else()\n        message(FATAL_ERROR \"CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}\")\n    endif()\n\n    # Set header and libs\n    set(CANN_INCLUDE_DIRS\n        ${CANN_INSTALL_DIR}/include\n        ${CANN_INSTALL_DIR}/include/aclnn\n        ${CANN_INSTALL_DIR}/acllib/include\n    )\n\n    list(APPEND CANN_LIBRARIES\n        ascendcl\n        nnopbase\n        opapi\n        acl_op_compiler\n    )\n\n    file(GLOB GGML_SOURCES_CANN \"*.cpp\")\n\n    ggml_add_backend_library(ggml-cann ${GGML_SOURCES_CANN})\n    target_link_libraries(ggml-cann PRIVATE ${CANN_LIBRARIES})\n    target_include_directories(ggml-cann PRIVATE ${CANN_INCLUDE_DIRS})\n    target_link_directories(ggml-cann PRIVATE ${CANN_INSTALL_DIR}/lib64)\n\n    target_compile_definitions(ggml-cann PRIVATE \"-D${SOC_TYPE_COMPILE_OPTION}\")\n\n    if (USE_ACL_GRAPH)\n        target_compile_definitions(ggml-cann PRIVATE USE_ACL_GRAPH)\n        message(STATUS \"CANN: USE_ACL_GRAPH is enabled.\")\n    else()\n        message(STATUS \"CANN: USE_ACL_GRAPH is disabled.\")\n    endif()\n\n    message(STATUS \"CANN: CANN_INCLUDE_DIRS =  ${CANN_INCLUDE_DIRS}\")\n    message(STATUS \"CANN: CANN_LIBRARIES =  ${CANN_LIBRARIES}\")\nelse()\n    message(FATAL_ERROR \"CANN: Can't find CANN_INSTALL_DIR, did you forget to source set_var.sh?\")\nendif()\n"
  },
  {
    "path": "src/ggml-cann/acl_tensor.cpp",
    "content": "/*\n * Copyright (c) 2023-2026 The ggml authors\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\n#include \"acl_tensor.h\"\n\n#include <algorithm>\n#include <cstring>\n\naclDataType ggml_cann_type_mapping(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_F32:\n            return ACL_FLOAT;\n        case GGML_TYPE_F16:\n            return ACL_FLOAT16;\n        case GGML_TYPE_BF16:\n            return ACL_BF16;\n        case GGML_TYPE_I8:\n            return ACL_INT8;\n        case GGML_TYPE_I16:\n            return ACL_INT16;\n        case GGML_TYPE_I32:\n            return ACL_INT32;\n        case GGML_TYPE_Q4_0:\n            return ACL_INT4;\n        case GGML_TYPE_Q8_0:\n            return ACL_INT8;\n        case GGML_TYPE_I64:\n            return ACL_INT64;\n        default:\n            return ACL_DT_UNDEFINED;\n    }\n}\n\nacl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,\n       
                                int64_t *           ne,\n                                       size_t *            nb,\n                                       int64_t             dims,\n                                       aclFormat           format,\n                                       size_t              offset) {\n    // If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be\n    // added.\n    int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];\n\n    if (ne == nullptr) {\n        for (int i = 0; i < GGML_MAX_DIMS; i++) {\n            acl_ne[i]     = tensor->ne[i];\n            // The step size of acl is in elements.\n            acl_stride[i] = tensor->nb[i] / ggml_element_size(tensor);\n        }\n    } else {\n        // With bcast\n        for (int i = 0; i < dims; i++) {\n            acl_ne[i]     = ne[i];\n            acl_stride[i] = nb[i] / ggml_element_size(tensor);\n        }\n    }\n\n    int64_t final_dims      = (dims == 0 ? GGML_MAX_DIMS : dims);\n    int64_t acl_storage_len = 1;\n    for (int i = 0; i < final_dims; i++) {\n        acl_storage_len += (acl_ne[i] - 1) * acl_stride[i];\n    }\n    size_t elem_offset = offset / ggml_element_size(tensor);\n    acl_storage_len += elem_offset;\n\n    // Reverse ne and stride.\n    std::reverse(acl_ne, acl_ne + final_dims);\n    std::reverse(acl_stride, acl_stride + final_dims);\n\n    aclTensor * raw = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, elem_offset,\n                                      format, &acl_storage_len, 1, tensor->data);\n\n    return acl_tensor_ptr(raw);\n}\n\nacl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size) {\n    aclIntArray * raw = aclCreateIntArray(value, size);\n    return acl_int_array_ptr(raw);\n}\n\nacl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType) {\n    aclScalar * raw = aclCreateScalar(value, dataType);\n    return 
acl_scalar_ptr(raw);\n}\n\nbool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) {\n            return true;\n        }\n    }\n    return false;\n}\n\nint64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0,\n                                  const ggml_tensor * src1,\n                                  int64_t *           bcast_src0_ne,\n                                  int64_t *           bcast_src1_ne,\n                                  size_t *            bcast_src0_nb,\n                                  size_t *            bcast_src1_nb) {\n    GGML_ASSERT(ggml_can_repeat(src1, src0));\n    int bcast_dim_cnt = 0;\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        int64_t nr                   = src0->ne[i] / src1->ne[i];\n        bcast_src0_ne[bcast_dim_cnt] = src0->ne[i] / nr;\n        bcast_src1_ne[bcast_dim_cnt] = src1->ne[i];\n        bcast_src0_nb[bcast_dim_cnt] = src0->nb[i];\n        bcast_src1_nb[bcast_dim_cnt] = src1->nb[i];\n        bcast_dim_cnt++;\n        if (nr != 1) {\n            // Need to add an extra dim.\n            bcast_src0_ne[bcast_dim_cnt] = nr;\n            bcast_src1_ne[bcast_dim_cnt] = 1;\n            bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] * bcast_src0_ne[bcast_dim_cnt - 1];\n            bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] * bcast_src1_ne[bcast_dim_cnt - 1];\n            bcast_dim_cnt++;\n        }\n    }\n    return bcast_dim_cnt;\n}\n\nint64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,\n                                         const int64_t * weight_ne,\n                                         const int64_t * dst_ne,\n                                         const size_t *  input_nb,\n                                         const size_t *  weight_nb,\n                                         const size_t *  dst_nb,\n               
                          int64_t *       bcast_input_ne,\n                                         int64_t *       bcast_weight_ne,\n                                         int64_t *       bcast_dst_ne,\n                                         size_t *        bcast_input_nb,\n                                         size_t *        bcast_weight_nb,\n                                         size_t *        bcast_dst_nb) {\n    // input and dst must have the same shape, except for the first two dims.\n    GGML_ASSERT(input_ne[2] == dst_ne[2]);\n    GGML_ASSERT(input_ne[3] == dst_ne[3]);\n\n    int bcast_dim_cnt = 0;\n\n    // For mul_mat, a dimension needs to be added before the dimension that\n    // weight needs to be expanded to satisfy the bcast rule of matrix\n    // multiplication.\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        int64_t nr = input_ne[i] / weight_ne[i];\n        // Do not use bcast in the first two dimensions because we only support\n        // the bcast batch dimension. 
Just copy them.\n        if (i < 2 || nr == 1) {\n            bcast_input_ne[bcast_dim_cnt]  = input_ne[i];\n            bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];\n            bcast_dst_ne[bcast_dim_cnt]    = dst_ne[i];\n\n            bcast_input_nb[bcast_dim_cnt]  = input_nb[i];\n            bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];\n            bcast_dst_nb[bcast_dim_cnt]    = dst_nb[i];\n            bcast_dim_cnt++;\n        } else {\n            // Need to add an extra dim.\n            bcast_input_ne[bcast_dim_cnt]  = nr;\n            bcast_dst_ne[bcast_dim_cnt]    = nr;\n            bcast_weight_ne[bcast_dim_cnt] = 1;\n            bcast_input_nb[bcast_dim_cnt]  = input_nb[i];\n            bcast_dst_nb[bcast_dim_cnt]    = dst_nb[i];\n            bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];\n            bcast_dim_cnt++;\n\n            bcast_input_ne[bcast_dim_cnt]  = input_ne[i] / nr;\n            bcast_dst_ne[bcast_dim_cnt]    = dst_ne[i] / nr;\n            bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];\n            bcast_input_nb[bcast_dim_cnt]  = bcast_input_nb[bcast_dim_cnt - 1] * bcast_input_ne[bcast_dim_cnt - 1];\n            bcast_dst_nb[bcast_dim_cnt]    = bcast_dst_nb[bcast_dim_cnt - 1] * bcast_dst_ne[bcast_dim_cnt - 1];\n            bcast_weight_nb[bcast_dim_cnt] = bcast_weight_nb[bcast_dim_cnt - 1] * bcast_weight_ne[bcast_dim_cnt - 1];\n            bcast_dim_cnt++;\n        }\n    }\n    return bcast_dim_cnt;\n}\n"
  },
  {
    "path": "src/ggml-cann/acl_tensor.h",
    "content": "/*\n * Copyright (c) 2023-2026 The ggml authors\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\n#ifndef CANN_ACL_TENSOR_H\n#define CANN_ACL_TENSOR_H\n\n#include \"common.h\"\n\n#include <aclnn/aclnn_base.h>\n\n#include <algorithm>\n#include <cstring>\n\n/**\n * @brief\tMaps a ggml_type to its corresponding aclDataType.\n *\n * @details\tThis function takes a ggml_type as input and returns the corresponding\n *\t\t\taclDataType. It supports mapping for various ggml_types. If the input type\n *\t\t\tdoes not match any of the predefined ggml_types, the function returns\n *          ACL_DT_UNDEFINED.\n *\n * @param\ttype    The ggml_type to be mapped.\n * @return\tThe corresponding aclDataType. 
If the input type is not recognized,\n *\t\t\tACL_DT_UNDEFINED is returned.\n */\naclDataType ggml_cann_type_mapping(ggml_type type);\n\n// Deleter for acl objects.\ntemplate <typename T, aclError (*DestroyFunc)(const T *)> struct acl_deleter {\n    void operator()(T * ptr) const noexcept {\n        if (ptr) {\n            ACL_CHECK(DestroyFunc(ptr));\n        }\n    }\n};\n\nusing acl_tensor_ptr      = std::unique_ptr<aclTensor, acl_deleter<aclTensor, aclDestroyTensor>>;\nusing acl_int_array_ptr   = std::unique_ptr<aclIntArray, acl_deleter<aclIntArray, aclDestroyIntArray>>;\nusing acl_scalar_ptr      = std::unique_ptr<aclScalar, acl_deleter<aclScalar, aclDestroyScalar>>;\nusing acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensorList, aclDestroyTensorList>>;\n\n/**\n * @brief   Creates an ACL tensor from a ggml_tensor with optional shape.\n *\n * @details This function creates an ACL tensor based on the properties of the\n *          provided ggml_tensor. It supports a custom shape by adjusting dimensions\n *          and strides accordingly. If a custom shape is applied, additional\n *          dimensions and strides are calculated based on the provided parameters.\n *\n * @param   tensor      Pointer to the ggml_tensor to be converted to ACL tensor.\n * @param   ne          Pointer to an array containing dimensions. Defaults to nullptr\n *                      if no custom shape is applied.\n * @param   nb          Pointer to an array containing strides. Defaults to nullptr\n *                      if no custom shape is applied.\n * @param   dims        Number of dimensions in the tensor. Defaults to 0 if no custom\n *                      shape is applied.\n * @param   format      ACL tensor format. Defaults to ACL_FORMAT_ND.\n * @param   offset      Offset in bytes for the ACL tensor data. 
Defaults to 0.\n * @return  Pointer to the created ACL tensor.\n */\nacl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,\n                                       int64_t *           ne     = nullptr,\n                                       size_t *            nb     = nullptr,\n                                       int64_t             dims   = 0,\n                                       aclFormat           format = ACL_FORMAT_ND,\n                                       size_t              offset = 0);\n\n/**\n * @brief   Template for creating an ACL tensor from provided parameters. typename TYPE\n *          should be size_t or float.\n *\n * @details This function creates an ACL tensor using the provided data pointer,\n *          data type, dimensions, strides, format, offset, and additional parameters.\n *          It calculates necessary dimensions and strides based on the provided ne and nb\n *          arrays, adjusting them for the ACL tensor creation. The ACL storage length\n *          is also calculated based on the provided dimensions and strides.\n *\n * @param   data_ptr    Pointer to the data buffer for the ACL tensor.\n * @param   dtype       ACL data type of the tensor.\n * @param   type_size   Size of each element in the tensor data buffer.\n * @param   ne          Pointer to an array containing tensor dimensions.\n * @param   nb          Pointer to an array containing tensor strides.\n * @param   dims        Number of dimensions of the tensor.\n * @param   format      ACL tensor format. Defaults to ACL_FORMAT_ND.\n * @param   offset      Offset in bytes for the ACL tensor data. 
Defaults to 0.\n * @return  Pointer to the created ACL tensor.\n */\ntemplate <typename TYPE>\nacl_tensor_ptr ggml_cann_create_tensor(void *      data_ptr,\n                                       aclDataType dtype,\n                                       TYPE        type_size,\n                                       int64_t *   ne,\n                                       TYPE *      nb,\n                                       int64_t     dims,\n                                       aclFormat   format = ACL_FORMAT_ND,\n                                       size_t      offset = 0) {\n    int64_t tmp_ne[GGML_MAX_DIMS * 2];\n    int64_t tmp_stride[GGML_MAX_DIMS * 2];\n\n    memcpy(tmp_ne, ne, dims * sizeof(int64_t));\n    for (int i = 0; i < dims; i++) {\n        tmp_stride[i] = nb[i] / type_size;\n    }\n\n    int64_t acl_storage_len = 1;\n    for (int i = 0; i < dims; i++) {\n        acl_storage_len += (tmp_ne[i] - 1) * tmp_stride[i];\n    }\n\n    std::reverse(tmp_ne, tmp_ne + dims);\n    std::reverse(tmp_stride, tmp_stride + dims);\n\n    aclTensor * raw =\n        aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);\n\n    return acl_tensor_ptr(raw);\n}\n\n/**\n * @brief Create an ACL int array resource wrapped in a smart pointer.\n *\n * This function constructs an aclIntArray from the provided int64_t values\n * and returns it as an acl_int_array_ptr (a std::unique_ptr with a custom\n * deleter). 
The returned pointer owns the ACL resource and will automatically\n * destroy it via aclDestroyIntArray().\n *\n * @param value  Pointer to the int64_t elements.\n * @param size   Number of elements in value.\n *\n * @return A smart pointer managing the created ACL int array.\n */\nacl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size);\n\n/**\n * @brief Create an ACL scalar resource wrapped in a smart pointer.\n *\n * This function constructs an aclScalar from the raw value pointer and ACL\n * data type, then returns it as an acl_scalar_ptr (a std::unique_ptr with\n * a custom deleter). The returned pointer owns the ACL scalar and will\n * automatically destroy it via aclDestroyScalar().\n *\n * @param value     Pointer to the raw scalar memory.\n * @param dataType  ACL data type of the scalar.\n *\n * @return A smart pointer managing the created ACL scalar.\n */\nacl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType);\n\n/**\n * @brief Create an ACL tensor list from multiple tensor smart pointers.\n *\n * This function accepts a variadic list of acl_tensor_ptr (a unique_ptr with\n * custom deleter) and produces an aclTensorList using aclCreateTensorList().\n *\n * The lifecycle management of the tensor objects changes as follows:\n *  - aclCreateTensorList() takes ownership of the tensors\n *  - Each input smart pointer releases ownership using release()\n *  - As a result, the tensors will NOT be destroyed by unique_ptr\n *  - Instead, they will be destroyed when aclDestroyTensorList() is called\n *\n * This ensures correct ownership transfer and prevents double-free situations.\n *\n * @param acl_tensor_ptr  Variadic template parameter; each argument must be\n *                         a unique_ptr-like type supporting get() and release().\n *\n * @param tensors  Variadic list of acl_tensor_ptr objects. 
Ownership of\n *                         each tensor is transferred away from these smart pointers.\n *\n * @return A smart pointer (acl_tensor_list_ptr) owning the created ACL tensor list.\n *\n * @note This implementation is C++11 compatible. The ownership-release process is\n *       executed using a pack expansion inside an initializer list.\n */\ntemplate <typename... acl_tensor_ptr> acl_tensor_list_ptr ggml_cann_create_tensor_list(acl_tensor_ptr &&... tensors) {\n    aclTensor *     raw_tensors[] = { tensors.get()... };\n    aclTensorList * raw           = aclCreateTensorList(raw_tensors, sizeof...(tensors));\n    // The aclTensorList now owns the aclTensors, so release ownership without\n    // destroying the tensors\n    int             dummy[]       = { (tensors.release(), 0)... };\n    GGML_UNUSED(dummy);\n    return acl_tensor_list_ptr(raw);\n}\n\n/**\n * @brief   Checks if tensors require broadcasting based on their shapes.\n *\n * @details This function determines if two ggml_tensors need to be broadcasted for\n *          element-wise operations. Broadcasting is necessary if, in at least one\n *          dimension, t1's size differs from t0's and is not equal to 1.\n *\n * @param   t0      Pointer to the first ggml_tensor.\n * @param   t1      Pointer to the second ggml_tensor.\n * @return  True if broadcasting is needed, False otherwise.\n *\n * @remarks This function iterates over the dimensions of t0 and t1. It checks if each\n *          dimension in t1 differs from t0's corresponding dimension and is not equal\n *          to 1. 
If such a dimension is found, broadcasting is required to align t1\n *          with t0 for element-wise operations.\n */\nbool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1);\n\n/**\n * @brief   Computes broadcast shapes and strides for two ggml_tensors.\n *\n * @details This function calculates the broadcast shapes and strides for two ggml_tensors,\n *          following the broadcasting rules similar to numpy. It adjusts dimensions and\n *          strides to ensure compatibility for element-wise operations where one tensor\n *          can be broadcasted to match the shape of another tensor.\n *\n * @param   src0                Pointer to the first ggml_tensor.\n * @param   src1                Pointer to the second ggml_tensor.\n * @param   bcast_ne_src0       Output array to store broadcasted dimensions for src0.\n * @param   bcast_ne_src1       Output array to store broadcasted dimensions for src1.\n * @param   bcast_nb_src0       Output array to store broadcasted strides for src0.\n * @param   bcast_nb_src1       Output array to store broadcasted strides for src1.\n * @return  Number of dimensions in the broadcasted shape.\n *\n * @pre     ggml_can_repeat(src1, src0) must return true, indicating src1 can be broadcasted\n *          to match src0.\n *\n * @remarks This function iterates over the dimensions of src0 and src1, calculating the\n *          necessary broadcast dimensions and strides. If a dimension requires broadcasting\n *          (i.e., its size in src1 is smaller than in src0), an additional dimension is\n *          added with size calculated to match src0's dimension. 
This adjustment ensures\n *          that src1 can be element-wise broadcasted to src0's shape.\n *\n *  How it works:\n *\n *  if dim0 has padding.\n *  a -> (2, 2) padding = 2\n *   a: [[1, 2, *, *]\n *       [2, 3, *, *]]\n *  nb = (8, 4, 2)\n *\n *  if a should bcast with b -> (2, 4)\n *  b' -> (2, 2, 2)\n *  b : [[1, 2, 3, 4, *, *]\n *       [5, 6, 7, 8, *, *]]\n *  nb = (12, 6, 1)\n *\n *  after bcast:\n *  a' -> (2, 1, 2)\n *  a': [[[1, 2], *, *]\n *       [[2, 3], *, *]]\n *  nb = (8, 4, 2, 1)\n *\n *  b' : [[[1, 2], [3, 4], *, *]\n *        [[5, 6], [7, 8], *, *]]\n *  nb = (12, 6, 2, 1)\n *  \\endcode\n *\n *  dim1 in a inserted dim, should add nb for dim1,\n *  and all other nb moves to next in order.\n */\nint64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0,\n                                  const ggml_tensor * src1,\n                                  int64_t *           bcast_ne_src0,\n                                  int64_t *           bcast_ne_src1,\n                                  size_t *            bcast_nb_src0,\n                                  size_t *            bcast_nb_src1);\n\n// Bcast macro to avoid duplicate code.\n#define BCAST_SHAPE(src0, src1)                                                                      \\\n    int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2];                                                    \\\n    int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2];                                                    \\\n    size_t  bcast_##src0##_nb[GGML_MAX_DIMS * 2];                                                    \\\n    size_t  bcast_##src1##_nb[GGML_MAX_DIMS * 2];                                                    \\\n    int64_t bcast_dims = ggml_cann_get_bcast_shape(src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, \\\n                                                   bcast_##src0##_nb, bcast_##src1##_nb);\n\n#define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims\n\n/**\n * @brief 
Calculates broadcast shapes for matrix multiplication.\n *\n * @details This function computes the broadcast shapes required for matrix multiplication\n *          based on the input, weight, and destination tensor shapes. It ensures that the\n *          dimensions of weight tensors are expanded appropriately to satisfy matrix\n *          multiplication broadcast rules.\n *\n * @param input_ne      Array containing the dimensions of the input tensor.\n * @param weight_ne     Array containing the dimensions of the weight tensor.\n * @param dst_ne        Array containing the dimensions of the destination tensor.\n * @param input_nb      Array containing the strides of the input tensor.\n * @param weight_nb     Array containing the strides of the weight tensor.\n * @param dst_nb        Array containing the strides of the destination tensor.\n * @param bcast_input_ne    Output array for broadcasted input tensor dimensions.\n * @param bcast_weight_ne   Output array for broadcasted weight tensor dimensions.\n * @param bcast_dst_ne      Output array for broadcasted destination tensor dimensions.\n * @param bcast_input_nb    Output array for broadcasted input tensor strides.\n * @param bcast_weight_nb   Output array for broadcasted weight tensor strides.\n * @param bcast_dst_nb      Output array for broadcasted destination tensor strides.\n * @return The number of dimensions in the broadcasted tensors.\n *\n * @remarks This function iterates over the tensor dimensions and calculates the broadcast\n *          shapes needed for matrix multiplication. 
It ensures that dimensions where\n *          weight tensor requires expansion are appropriately handled to conform with\n *          broadcasting rules.\n * @note compare with ggml_cann_get_bcast_shape, mul_mat broadcast need add this new dim\n *       before cast dim.\n * @sa ggml_cann_get_bcast_shape\n */\nint64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,\n                                         const int64_t * weight_ne,\n                                         const int64_t * dst_ne,\n                                         const size_t *  input_nb,\n                                         const size_t *  weight_nb,\n                                         const size_t *  dst_nb,\n                                         int64_t *       bcast_input_ne,\n                                         int64_t *       bcast_weight_ne,\n                                         int64_t *       bcast_dst_ne,\n                                         size_t *        bcast_input_nb,\n                                         size_t *        bcast_weight_nb,\n                                         size_t *        bcast_dst_nb);\n\n// Bcast macro to avoid duplicate code.\n#define BCAST_MUL_MAT_SHAPE(input, weight, dst)                                                                  \\\n    int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2];                                                               \\\n    int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2];                                                              \\\n    int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2];                                                                 \\\n    size_t  bcast_##input##_nb[GGML_MAX_DIMS * 2];                                                               \\\n    size_t  bcast_##weight##_nb[GGML_MAX_DIMS * 2];                                                              \\\n    size_t  bcast_##dst##_nb[GGML_MAX_DIMS * 2];                                               
                  \\\n    int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape(                                                       \\\n        input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, bcast_##input##_ne, bcast_##weight##_ne, \\\n        bcast_##dst##_ne, bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);\n\n#define BCAST_MUL_MAT_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims\n\n#endif  // CANN_ACL_TENSOR_H\n"
  },
  {
    "path": "src/ggml-cann/aclnn_ops.cpp",
    "content": "/*\n * Copyright (c) 2023-2026 The ggml authors\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\n#include \"aclnn_ops.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml.h\"\n\n#include <aclnnop/aclnn_add.h>\n#include <aclnnop/aclnn_add_rms_norm.h>\n#include <aclnnop/aclnn_addcdiv.h>\n#include <aclnnop/aclnn_argmax.h>\n#include <aclnnop/aclnn_avgpool2d.h>\n#include <aclnnop/aclnn_batch_matmul.h>\n#include <aclnnop/aclnn_cast.h>\n#include <aclnnop/aclnn_clamp.h>\n#include <aclnnop/aclnn_constant_pad_nd.h>\n#include <aclnnop/aclnn_convolution.h>\n#include <aclnnop/aclnn_copy.h>\n#include <aclnnop/aclnn_div.h>\n#include <aclnnop/aclnn_elu.h>\n#include <aclnnop/aclnn_embedding.h>\n#include <aclnnop/aclnn_eq_tensor.h>\n#include <aclnnop/aclnn_exp.h>\n#include <aclnnop/aclnn_fill_scalar.h>\n#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>\n#include <aclnnop/aclnn_ger.h>\n#include 
<aclnnop/aclnn_group_norm.h>\n#include <aclnnop/aclnn_grouped_matmul_v3.h>\n#include <aclnnop/aclnn_gt_scalar.h>\n#include <aclnnop/aclnn_im2col.h>\n#include <aclnnop/aclnn_index_copy.h>\n#include <aclnnop/aclnn_index_fill_tensor.h>\n#include <aclnnop/aclnn_index_select.h>\n#include <aclnnop/aclnn_layer_norm.h>\n#include <aclnnop/aclnn_log.h>\n#include <aclnnop/aclnn_matmul.h>\n#include <aclnnop/aclnn_max_pool.h>\n#include <aclnnop/aclnn_mean.h>\n#include <aclnnop/aclnn_mm.h>\n#include <aclnnop/aclnn_mul.h>\n#include <aclnnop/aclnn_mv.h>\n#include <aclnnop/aclnn_permute.h>\n#include <aclnnop/aclnn_pow.h>\n#include <aclnnop/aclnn_pow_tensor_tensor.h>\n#include <aclnnop/aclnn_reduce_sum.h>\n#include <aclnnop/aclnn_reflection_pad1d.h>\n#include <aclnnop/aclnn_repeat.h>\n#include <aclnnop/aclnn_repeat_interleave.h>\n#include <aclnnop/aclnn_rms_norm.h>\n#include <aclnnop/aclnn_roll.h>\n#include <aclnnop/aclnn_softmax.h>\n#include <aclnnop/aclnn_sub.h>\n#include <aclnnop/aclnn_sum.h>\n#include <aclnnop/aclnn_threshold.h>\n#include <aclnnop/aclnn_tril.h>\n#include <aclnnop/aclnn_triu.h>\n#include <aclnnop/aclnn_upsample_nearest_2d.h>\n#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>\n#include <aclnnop/aclnn_zero.h>\n#include <float.h>\n\n#include <cmath>\n#include <cstring>\n#include <exception>\n#include <vector>\n\n#define GGML_COMMON_DECL_C\n\n#include \"../ggml-common.h\"\n\nvoid bcast_shape(ggml_tensor *    src0,\n                 ggml_tensor *    src1,\n                 ggml_tensor *    dst,\n                 acl_tensor_ptr & acl_src0,\n                 acl_tensor_ptr & acl_src1,\n                 acl_tensor_ptr & acl_dst) {\n    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));\n    // Need bcast\n    if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {\n        BCAST_SHAPE(src0, src1)\n        acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));\n        acl_src1 = ggml_cann_create_tensor(src1, 
BCAST_PARAM(src1));\n        acl_dst  = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));\n    } else {\n        acl_src0 = ggml_cann_create_tensor(src0);\n        acl_src1 = ggml_cann_create_tensor(src1);\n        acl_dst  = ggml_cann_create_tensor(dst);\n    }\n}\n\nvoid ggml_cann_op_unary(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,\n                        ggml_backend_cann_context &                                                ctx,\n                        ggml_tensor *                                                              dst) {\n    ggml_tensor * src = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    unary_op(ctx, acl_src.get(), acl_dst.get());\n}\n\nvoid ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,\n                              ggml_backend_cann_context &                                                ctx,\n                              ggml_tensor *                                                              dst) {\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n    acl_tensor_ptr acl_src0, acl_src1;\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n\n        acl_src0 = ggml_cann_create_tensor(src0);\n        acl_src1 = ggml_cann_create_tensor(src1);\n    } else {\n        int64_t ne[] = { src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3] };\n        size_t  nb[] = { src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3] };\n        acl_src0     = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0);\n        acl_src1 = 
ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0));\n        if (swapped) {\n            std::swap(acl_src0, acl_src1);\n        }\n    }\n\n    unary_op(ctx, acl_src0.get(), acl_dst.get());\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst.get(), acl_src1.get());\n}\n\n/**\n * @brief Repeats elements of a tensor along each dimension according to the\n * specified repeat array.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor to be repeated.\n * @param acl_dst The destination tensor after repeating.\n * @param repeat_array The array specifying the number of repetitions along each\n * dimension.\n */\nstatic void aclnn_repeat(ggml_backend_cann_context & ctx,\n                         aclTensor *                 acl_src,\n                         aclTensor *                 acl_dst,\n                         int64_t *                   repeat_array) {\n    // repeat tensor along each dim with repeat_array\n    acl_int_array_ptr repeats = ggml_cann_create_int_array(repeat_array, GGML_MAX_DIMS);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_src, repeats.get(), acl_dst);\n}\n\n/**\n * @brief Casts the data type of a source tensor to a destination tensor.\n *\n * This function casts the data type of the source tensor `acl_src` to the\n * specified data type `cast_data_type` and stores the result in the destination\n * tensor `acl_dst`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor whose data type will be casted.\n * @param acl_dst The destination tensor where the casted result will be stored.\n * @param cast_data_type The target data type to which the source tensor will be\n * casted.\n */\nstatic void aclnn_cast(ggml_backend_cann_context & ctx,\n                       aclTensor *                 acl_src,\n                       aclTensor *                 acl_dst,\n                       aclDataType            
     cast_data_type) {\n    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src, cast_data_type, acl_dst);\n}\n\nvoid ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n    GGML_ASSERT(ggml_can_repeat(src, dst));\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    int64_t repeatsArray[] = { dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2], dst->ne[1] / src->ne[1],\n                               dst->ne[0] / src->ne[0] };\n\n    aclnn_repeat(ctx, acl_src.get(), acl_dst.get(), repeatsArray);\n}\n\nvoid aclnn_add(ggml_backend_cann_context & ctx, aclTensor * acl_src0, aclTensor * acl_src1, aclTensor * acl_dst) {\n    float          alphaValue = 1.0f;\n    acl_scalar_ptr alpha      = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);\n    if (acl_dst != nullptr) {\n        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha.get(), acl_dst);\n    } else {\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_src0, acl_src1, alpha.get());\n    }\n}\n\nvoid aclnn_sub(ggml_backend_cann_context & ctx, aclTensor * acl_src0, aclTensor * acl_src1, aclTensor * acl_dst) {\n    float          alphaValue = 1.0f;\n    acl_scalar_ptr alpha      = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);\n    if (acl_dst != nullptr) {\n        GGML_CANN_CALL_ACLNN_OP(ctx, Sub, acl_src0, acl_src1, alpha.get(), acl_dst);\n    } else {\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSub, acl_src0, acl_src1, alpha.get());\n    }\n}\n\nvoid aclnn_mul(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_other, aclTensor * acl_dst) {\n    if (acl_dst != nullptr) {\n        GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_src, acl_other, acl_dst);\n    } else {\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_src, acl_other);\n    }\n}\n\nvoid aclnn_div(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_other, 
aclTensor * acl_dst) {\n    if (acl_dst != nullptr) {\n        GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src, acl_other, acl_dst);\n    } else {\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDiv, acl_src, acl_other);\n    }\n}\n\n/**\n * @brief Multiplies elements of a tensor by a scalar value, optionally\n * in-place.\n *\n * This function multiplies each element of the source tensor `acl_src` by the\n * scalar `scale` and stores the result in the destination tensor `acl_dst`. If\n * `inplace` is true, `acl_dst` will not be used and the operation is performed\n *  in-place on `acl_src`.\n * The operation is defined as:\n * \\f[\n *     \\text {acl_dst }_i=\\text {acl_src }_i \\times \\text {scale}\n * \\f]\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor whose elements will be multiplied.\n * @param scale The scalar value by which each element of `acl_src` will be\n *  multiplied.\n * @param acl_dst The destination tensor where the result will be stored if\n * `inplace` is false.\n * @param inplace Flag indicating whether to perform the operation in-place on\n * `acl_src`.\n */\nstatic void aclnn_muls(ggml_backend_cann_context & ctx,\n                       aclTensor *                 acl_src,\n                       float                       scale,\n                       aclTensor *                 acl_dst,\n                       bool                        inplace) {\n    acl_scalar_ptr acl_scale = ggml_cann_create_scalar(&scale, aclDataType::ACL_FLOAT);\n    if (inplace) {\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_src, acl_scale.get());\n    } else {\n        GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, acl_scale.get(), acl_dst);\n    }\n}\n\nvoid ggml_cann_leaky_relu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    float 
negative_slope;\n    memcpy(&negative_slope, dst->op_params, sizeof(float));\n    acl_scalar_ptr acl_negative_slope = ggml_cann_create_scalar(&negative_slope, aclDataType::ACL_FLOAT);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, LeakyRelu, acl_src.get(), acl_negative_slope.get(), acl_dst.get());\n}\n\n/**\n * @brief Concatenates a list of tensors along a specified dimension and stores\n * the result in a destination tensor.\n *\n * @param ctx The context for the CANN backend operations.\n * @param tensorList The list of tensors to be concatenated.\n * @param acl_dst The destination tensor where the concatenated result will be\n * stored.\n * @param concat_dim The dimension along which the tensors will be concatenated.\n */\nstatic void aclnn_concat(ggml_backend_cann_context & ctx,\n                         aclTensorList *             tensorList,\n                         aclTensor *                 acl_dst,\n                         int64_t                     concat_dim) {\n    GGML_CANN_CALL_ACLNN_OP(ctx, Cat, tensorList, concat_dim, acl_dst);\n}\n\nvoid ggml_cann_concat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor *  src0     = dst->src[0];\n    ggml_tensor *  src1     = dst->src[1];\n    acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);\n    acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1);\n    acl_tensor_ptr acl_dst  = ggml_cann_create_tensor(dst);\n\n    const int32_t dim = ggml_get_op_params_i32(dst, 0);\n\n    GGML_ASSERT(dim >= 0 && dim < 4);\n    int32_t acl_dim = 3 - dim;\n\n    acl_tensor_list_ptr tensor_list = ggml_cann_create_tensor_list(acl_src0, acl_src1);\n    aclnn_concat(ctx, tensor_list.get(), acl_dst.get(), acl_dim);\n}\n\n/**\n * @brief Creates a tensor with values starting from `start`, incremented by\n * `step`, and ending before `stop`.\n *\n * This function performs the operation:\n * \\f[\n *    \\text {out }_{i+1}=\\text {out }_i+\\text {step}\n * \\f]\n * the range is [start, stop).\n *\n * @param ctx The 
context for the CANN backend operations.\n * @param acl_dst The destination tensor where the values will be stored.\n * @param start The starting value of the range.\n * @param stop The ending value of the range (exclusive).\n * @param step The step size between consecutive values.\n * @param n_elements The number of elements in the destination tensor.\n */\nstatic void aclnn_arange(ggml_backend_cann_context & ctx,\n                         aclTensor *                 acl_dst,\n                         float                       start,\n                         float                       stop,\n                         float                       step,\n                         int64_t                     n_elements) {\n    int64_t steps = (int64_t) std::ceil((stop - start) / step);\n    GGML_ASSERT(n_elements == steps);\n\n    acl_scalar_ptr acl_start = ggml_cann_create_scalar(&start, aclDataType::ACL_FLOAT);\n    acl_scalar_ptr acl_end   = ggml_cann_create_scalar(&stop, aclDataType::ACL_FLOAT);\n    acl_scalar_ptr acl_step  = ggml_cann_create_scalar(&step, aclDataType::ACL_FLOAT);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Arange, acl_start.get(), acl_end.get(), acl_step.get(), acl_dst);\n}\n\nvoid ggml_cann_arange(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    int64_t n_elements = ggml_nelements(dst);\n    float   start;\n    float   stop;\n    float   step;\n    memcpy(&start, (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));\n    memcpy(&step, (float *) dst->op_params + 2, sizeof(float));\n\n    aclnn_arange(ctx, acl_dst.get(), start, stop, step, n_elements);\n}\n\nvoid ggml_cann_clamp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n\n    float min;\n    float max;\n    memcpy(&min, dst->op_params, sizeof(float));\n    memcpy(&max, (float *) 
dst->op_params + 1, sizeof(float));\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    acl_scalar_ptr acl_min = ggml_cann_create_scalar(&min, aclDataType::ACL_FLOAT);\n    acl_scalar_ptr acl_max = ggml_cann_create_scalar(&max, aclDataType::ACL_FLOAT);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_src.get(), acl_min.get(), acl_max.get(), acl_dst.get());\n}\n\nvoid ggml_cann_scale(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n\n    // scale factor\n    float v;\n    memcpy(&v, dst->op_params, sizeof(float));\n\n    acl_scalar_ptr scale   = ggml_cann_create_scalar(&v, aclDataType::ACL_FLOAT);\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src.get(), scale.get(), acl_dst.get());\n}\n\nvoid ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor *        src   = dst->src[0];\n    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];\n\n    acl_tensor_ptr       acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr       acl_dst = ggml_cann_create_tensor(dst);\n    ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(int64_t));\n    void *               buffer = temp_buffer_allocator.get();\n    acl_tensor_ptr       tmp_tensor =\n        ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type), dst->ne, dst->nb, GGML_MAX_DIMS);\n    GGML_CANN_CALL_ACLNN_OP(ctx, Argsort, acl_src.get(), -1, (order == GGML_SORT_ORDER_DESC ? 
true : false),\n                            tmp_tensor.get());\n    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, tmp_tensor.get(), ggml_cann_type_mapping(dst->type), acl_dst.get());\n}\n\nvoid ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    std::vector<int64_t> normData = { dst->ne[0] };\n    acl_int_array_ptr    norm     = ggml_cann_create_int_array(normData.data(), normData.size());\n    GGML_CANN_CALL_ACLNN_OP(ctx, LayerNorm, acl_src.get(), norm.get(), nullptr, nullptr, eps, acl_dst.get(), nullptr,\n                            nullptr);\n}\n\nvoid ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    size_t               type_size = ggml_type_size(src->type);\n    int64_t              n_bytes   = src->ne[3] * src->ne[2] * src->ne[1] * type_size;\n    ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes);\n    void *               buffer = temp_buffer_allocator.get();\n\n    int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };\n    size_t  div_nb[GGML_MAX_DIMS];\n    div_nb[0] = sizeof(float);\n    for (int i = 1; i < GGML_MAX_DIMS; ++i) {\n        div_nb[i] = div_nb[i - 1] * div_ne[i - 1];\n    }\n    acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS);\n\n    std::vector<int64_t> norm_dims  = { 3 };\n    acl_int_array_ptr    dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size());\n\n    float          p_value  = 2.0f;\n    acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT);\n    GGML_CANN_CALL_ACLNN_OP(ctx, 
Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get());\n    GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get());\n}\n\nvoid ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n\n    const int64_t nc = src0->ne[0];\n    const int64_t nr = ggml_nrows(src0);\n\n    int64_t logits_ne[] = { nc, nr };\n    size_t  logits_nb[2];\n    logits_nb[0]              = ggml_type_size(src0->type);\n    logits_nb[1]              = logits_nb[0] * logits_ne[0];\n    acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);\n\n    size_t               log_softmax_type_size = sizeof(float);\n    int64_t              log_softmax_n_bytes   = nr * nc * log_softmax_type_size;\n    ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);\n    void *               log_softmax_buffer = log_softmax_allocator.get();\n\n    int64_t log_softmax_ne[] = { nc, nr };\n    size_t  log_softmax_nb[2];\n    log_softmax_nb[0]              = log_softmax_type_size;\n    log_softmax_nb[1]              = log_softmax_nb[0] * log_softmax_ne[0];\n    acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size,\n                                                             log_softmax_ne, log_softmax_nb, 2);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get());\n\n    int64_t labels_ne[] = { nc, nr };\n    size_t  labels_nb[2];\n    labels_nb[0]              = ggml_type_size(src1->type);\n    labels_nb[1]              = labels_nb[0] * labels_ne[0];\n    acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);\n\n    size_t               mul_type_size = sizeof(float);\n    int64_t              mul_n_bytes   = nr * nc * mul_type_size;\n   
 ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);\n    void *               mul_buffer = mul_allocator.get();\n\n    int64_t mul_ne[] = { nc, nr };\n    size_t  mul_nb[2];\n    mul_nb[0]                     = mul_type_size;\n    mul_nb[1]                     = mul_nb[0] * mul_ne[0];\n    acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get());\n\n    size_t               sum_per_sample_type_size = sizeof(float);\n    int64_t              sum_per_sample_n_bytes   = nr * sum_per_sample_type_size;\n    ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);\n    void *               sum_per_sample_buffer = sum_per_sample_allocator.get();\n\n    int64_t sum_per_sample_ne[] = { nr };\n    size_t  sum_per_sample_nb[1];\n    sum_per_sample_nb[0]              = sum_per_sample_type_size;\n    acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor(\n        sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);\n\n    std::vector<int64_t> sum_dims   = { 1 };\n    acl_int_array_ptr    dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size());\n    bool                 keep_dims  = false;\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT,\n                            acl_sum_per_sample.get());\n\n    size_t               total_sum_type_size = sizeof(float);\n    int64_t              total_sum_n_bytes   = 1 * total_sum_type_size;\n    ggml_cann_pool_alloc total_sum_allocator(ctx.pool(), total_sum_n_bytes);\n    void *               total_sum_buffer = total_sum_allocator.get();\n\n    int64_t total_sum_ne[] = { 1 };\n    size_t  total_sum_nb[1];\n    total_sum_nb[0] = total_sum_type_size;\n\n    acl_tensor_ptr acl_total_sum =\n        
ggml_cann_create_tensor(total_sum_buffer, ACL_FLOAT, total_sum_type_size, total_sum_ne, total_sum_nb, 1);\n\n    std::vector<int64_t> total_sum_dims    = { 0 };\n    acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size());\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,\n                            acl_total_sum.get());\n\n    float          value        = -1.0f / static_cast<float>(nr);\n    acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);\n    acl_tensor_ptr acl_dst =\n        ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_total_sum.get(), scale_factor.get(), acl_dst.get());\n}\n\nvoid ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    int n_groups = dst->op_params[0];\n\n    float eps;\n    memcpy(&eps, dst->op_params + 1, sizeof(float));\n\n    int64_t N   = src->ne[3];\n    int64_t C   = src->ne[2];\n    int64_t HxW = src->ne[1] * src->ne[0];\n\n    size_t  type_size = ggml_type_size(src->type);\n    int64_t ne[]      = { n_groups, N };\n    size_t  nb[]      = { type_size, type_size * n_groups };\n    size_t  n_bytes   = N * n_groups;\n\n    ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes * 2);\n    void *               buffer       = temp_buffer_allocator.get();\n    acl_tensor_ptr       acl_mean_out = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);\n    acl_tensor_ptr       acl_rstd_out =\n        ggml_cann_create_tensor((char *) buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, GroupNorm, acl_src.get(), nullptr, 
nullptr, N, C, HxW, n_groups, eps, acl_dst.get(),\n                            acl_mean_out.get(), acl_rstd_out.get());\n}\n\nvoid ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n\n    size_t nb1     = ((int32_t *) dst->op_params)[0];\n    size_t nb2     = ((int32_t *) dst->op_params)[1];\n    size_t nb3     = ((int32_t *) dst->op_params)[2];\n    size_t offset  = ((int32_t *) dst->op_params)[3];\n    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];\n\n    size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 };\n\n    acl_tensor_ptr acl_dst  = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);\n    acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1);\n\n    acl_scalar_ptr alpha      = nullptr;\n    float          alphaValue = 1.0f;\n    alpha                     = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);\n\n    if (!inplace) {\n        size_t cpy_size = ggml_nbytes(dst);\n        ACL_CHECK(\n            aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));\n        acl_tensor_ptr acl_src0 =\n            ggml_cann_create_tensor(src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);\n\n        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0.get(), acl_src1.get(), alpha.get(), acl_dst.get());\n    } else {\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), acl_src1.get(), alpha.get());\n    }\n}\n\n/**\n * @brief Performs sum reduction on a given tensor along specified dimensions.\n *\n * This function reduces the input tensor by summing along the specified dimensions.\n *\n * @param ctx The context for the CANN backend operations.\n * @param dst The destination tensor where the reduced result will be stored.\n * @param dim An array of dimension indices.\n * @param dim_size The number of dimensions.\n */\nstatic void 
aclnn_reduce_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t * dim, size_t dim_size) {\n    GGML_ASSERT(dst->ne[0] == 1);\n    ggml_tensor *     src         = dst->src[0];\n    acl_tensor_ptr    acl_src     = ggml_cann_create_tensor(src);\n    acl_tensor_ptr    acl_dst     = ggml_cann_create_tensor(dst);\n    acl_int_array_ptr reduce_dims = ggml_cann_create_int_array(dim, dim_size);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src.get(), reduce_dims.get(), true, ggml_cann_type_mapping(dst->type),\n                            acl_dst.get());\n}\n\nvoid ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    int64_t reduce_dims[] = { 3 };\n    aclnn_reduce_sum(ctx, dst, reduce_dims, 1);\n}\n\nvoid ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    int64_t reduce_dims[] = { 0, 1, 2, 3 };\n    aclnn_reduce_sum(ctx, dst, reduce_dims, 4);\n}\n\nvoid ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor *  src     = dst->src[0];\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);\n\n    std::vector<int64_t> output_size{ dst->ne[1], dst->ne[0] };\n    acl_int_array_ptr    output_size_array = ggml_cann_create_int_array(output_size.data(), 2);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, UpsampleNearest2d, acl_src.get(), output_size_array.get(), acl_dst.get());\n}\n\n/**\n * @brief Pads a tensor with a specified value along each dimension.\n *\n * This function performs padding of the source tensor `acl_src` and stores the\n * result in the destination tensor `acl_dst`. 
The padding values for each\n * dimension are specified in the `paddings` array.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor to be padded.\n * @param acl_dst The destination tensor where the padded result will be stored.\n * @param paddings An array specifying the padding values for each dimension.\n * The size of the array should be twice the number of dimensions of the tensor.\n * @param value The value to be used for padding. The default value is 0.0.\n */\nstatic void aclnn_pad(ggml_backend_cann_context & ctx,\n                      aclTensor *                 acl_src,\n                      aclTensor *                 acl_dst,\n                      int64_t *                   paddings,\n                      float                       value = 0.0f) {\n    acl_int_array_ptr acl_pad   = ggml_cann_create_int_array(paddings, GGML_MAX_DIMS * 2);\n    acl_scalar_ptr    acl_value = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_src, acl_pad.get(), acl_value.get(), acl_dst);\n}\n\nvoid ggml_cann_pad(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor *  src     = dst->src[0];\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    // padding: value in the array means how much distance will be padding.\n    // the position of elements in the array means which dirction to padding,\n    // each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind,\n    //                       dim2.front, dim2.behind, dim3.front, dim3.behind]\n    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);\n    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);\n    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);\n    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);\n    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);\n    const int32_t rp2 = 
ggml_get_op_params_i32(dst, 5);\n    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);\n    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);\n\n    int64_t paddings[] = { lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 };\n    aclnn_pad(ctx, acl_src.get(), acl_dst.get(), paddings);\n}\n\n/**\n * @brief Performs 2D average pooling on the input tensor and stores the result\n * in the destination tensor.\n *\n * This function performs average pooling on the source tensor and stores the\n * result in the destination tensor. The pooling parameters (kernel size,\n * strides, padding) are specified in the `op_params` of the destination tensor.\n *\n * @param ctx The context for the CANN backend operations.\n * @param dst The destination tensor where the result will be stored. The source\n * tensor is referenced by `dst->src[0]`.\n */\nstatic void ggml_cann_avg_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n    GGML_ASSERT(src->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);\n\n    const int32_t * opts = (const int32_t *) dst->op_params;\n    const int       k0   = opts[1];\n    const int       k1   = opts[2];\n    const int       s0   = opts[3];\n    const int       s1   = opts[4];\n    const int       p0   = opts[5];\n    const int       p1   = opts[6];\n\n    std::vector<int64_t> kernel_dims      = { k1, k0 };\n    std::vector<int64_t> stride_dims      = { s1, s0 };\n    std::vector<int64_t> padding_avg_dims = { p1, p0 };  // (padH, padW)\n\n    acl_int_array_ptr kernel_size  = ggml_cann_create_int_array(kernel_dims.data(), 2);\n    acl_int_array_ptr strides      = ggml_cann_create_int_array(stride_dims.data(), 2);\n    acl_int_array_ptr paddings_avg = ggml_cann_create_int_array(padding_avg_dims.data(), 
2);\n\n    bool    ceil_mode         = false;\n    bool    count_include_pad = true;\n    int64_t divisor_override  = 0;\n    int8_t  cube_math_type    = 0;\n#ifdef ASCEND_310P\n    cube_math_type = 1;\n#endif\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, AvgPool2d, acl_src.get(), kernel_size.get(), strides.get(), paddings_avg.get(),\n                            ceil_mode, count_include_pad, divisor_override, cube_math_type, acl_dst.get());\n}\n\n/**\n * @brief Performs 2D max pooling on the input tensor and stores the result in\n * the destination tensor.\n *\n * This function performs max pooling on the source tensor and stores the result\n * in the destination tensor. The pooling parameters (kernel size, strides,\n * padding) are specified in the `op_params` of the destination tensor.\n *\n * @param ctx The context for the CANN backend operations.\n * @param dst The destination tensor where the result will be stored. The source\n * tensor is referenced by `dst->src[0]`.\n */\nstatic void ggml_cann_max_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n    GGML_ASSERT(src->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);\n\n    const int32_t * opts = (const int32_t *) dst->op_params;\n    const int       k0   = opts[1];\n    const int       k1   = opts[2];\n    const int       s0   = opts[3];\n    const int       s1   = opts[4];\n    const int       p0   = opts[5];\n    const int       p1   = opts[6];\n\n    int64_t temp_ne[] = { src->ne[0] + p0 * 2, src->ne[1] + p1 * 2, src->ne[2], src->ne[3] };\n    size_t  temp_nb[GGML_MAX_DIMS];\n\n    temp_nb[0] = ggml_element_size(src);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        temp_nb[i] = temp_nb[i - 1] * temp_ne[i - 1];\n    }\n\n    ggml_cann_pool_alloc 
temp_buffer_allocator(ctx.pool(), ggml_nbytes(src) + p0 * 2 + p1 * 2 * src->nb[1]);\n    void *               buffer = temp_buffer_allocator.get();\n    acl_tensor_ptr tmp_tensor   = ggml_cann_create_tensor(buffer, ACL_FLOAT, ggml_element_size(src), temp_ne, temp_nb,\n                                                          GGML_MAX_DIMS, ACL_FORMAT_NCHW);\n\n    // pad: see padding in ggml_cann_pad()\n    int64_t paddings[] = { p0, p0, p1, p1, 0, 0, 0, 0 };\n    float   value      = -FLT_MAX;\n    aclnn_pad(ctx, acl_src.get(), tmp_tensor.get(), paddings, value);\n\n    // max_pool\n    std::vector<int64_t> kernel_dims      = { k1, k0 };\n    std::vector<int64_t> stride_dims      = { s1, s0 };\n    // padding_max_dims: [dim0_start, dim0_end, dim1_start, dim1_end]\n    std::vector<int64_t> padding_max_dims = { 0, 0, 0, 0 };\n    std::vector<int64_t> dilation_size    = { 1, 1 };\n    acl_int_array_ptr    kernel_size      = ggml_cann_create_int_array(kernel_dims.data(), 2);\n    acl_int_array_ptr    strides          = ggml_cann_create_int_array(stride_dims.data(), 2);\n    acl_int_array_ptr    paddings_max     = ggml_cann_create_int_array(padding_max_dims.data(), 4);\n    acl_int_array_ptr    dilations        = ggml_cann_create_int_array(dilation_size.data(), 2);\n\n    bool    ceil_mode = false;\n    int64_t auto_pads = 0;\n    GGML_CANN_CALL_ACLNN_OP(ctx, MaxPool, tmp_tensor.get(), kernel_size.get(), strides.get(), auto_pads,\n                            paddings_max.get(), dilations.get(), ceil_mode, acl_dst.get());\n}\n\nvoid ggml_cann_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    const int32_t *   opts = (const int32_t *) dst->op_params;\n    enum ggml_op_pool op   = static_cast<ggml_op_pool>(opts[0]);\n    switch (op) {\n        case GGML_OP_POOL_AVG:\n            ggml_cann_avg_pool2d(ctx, dst);\n            break;\n        case GGML_OP_POOL_MAX:\n            ggml_cann_max_pool2d(ctx, dst);\n            break;\n        case 
GGML_OP_POOL_COUNT:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n\n/**\n * @brief Copies data from the source tensor to the destination tensor.\n *\n * This function copies data from the source tensor `acl_src` to the destination\n * tensor `acl_dst`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor from which data will be copied.\n * @param acl_dst The destination tensor where the data will be copied to.\n */\nstatic void cann_copy(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst, acl_src);\n}\n\nvoid ggml_cann_dup(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n\n    if (ggml_are_same_shape(src0, dst)) {\n        acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);\n        acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n        if (dst->type == src0->type) {\n            cann_copy(ctx, acl_src.get(), acl_dst.get());\n        } else {\n            aclnn_cast(ctx, acl_src.get(), acl_dst.get(), ggml_cann_type_mapping(dst->type));\n        }\n    } else {\n        void *               src_trans_buffer = src0->data;\n        ggml_cann_pool_alloc src_buffer_allocator;\n        if (!ggml_is_contiguous(src0)) {\n            acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);\n            src_buffer_allocator.alloc(ctx.pool(), ggml_nelements(src0) * ggml_type_size(src0->type));\n            src_trans_buffer = src_buffer_allocator.get();\n            size_t src_trans_nb[GGML_MAX_DIMS];\n            src_trans_nb[0] = ggml_type_size(src0->type);\n            for (int i = 1; i < GGML_MAX_DIMS; i++) {\n                src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];\n            }\n            acl_tensor_ptr src_trans_tensor =\n                ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(src0->type),\n                            
            ggml_type_size(src0->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);\n            cann_copy(ctx, acl_src.get(), src_trans_tensor.get());\n        }\n\n        size_t src_reshape_nb[GGML_MAX_DIMS];\n        src_reshape_nb[0] = ggml_type_size(src0->type);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            src_reshape_nb[i] = src_reshape_nb[i - 1] * dst->ne[i - 1];\n        }\n\n        acl_tensor_ptr trans_acl_src =\n            ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),\n                                    dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);\n        acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n        if (dst->type == src0->type) {\n            cann_copy(ctx, trans_acl_src.get(), acl_dst.get());\n        } else {\n            aclnn_cast(ctx, trans_acl_src.get(), acl_dst.get(), ggml_cann_type_mapping(dst->type));\n        }\n    }\n}\n\n/**\n * @brief Creates an ACL tensor initialized with zeros using a provided buffer.\n *\n * This function initializes a tensor with zeros using the specified buffer and\n * tensor parameters.\n *\n * @param ctx The context for the CANN backend operations.\n * @param buffer The buffer to be used for the tensor data.\n * @param n_bytes The size of the buffer in bytes.\n * @param ne An array specifying the extents (sizes) of each dimension of the\n * tensor.\n * @param dims The number of dimensions of the tensor.\n * @param type The data type of the tensor.\n * @param type_size The size of each element in the tensor data type.\n * @return A tensor smart pointer initialized with zeros.\n */\nstatic acl_tensor_ptr aclnn_zero(ggml_backend_cann_context & ctx,\n                                 void *                      buffer,\n                                 size_t                      n_bytes,\n                                 int64_t *                   ne,\n                                 int64_t                    
 dims,\n                                 aclDataType                 type,\n                                 size_t                      type_size) {\n    size_t nb[GGML_MAX_DIMS];\n    nb[0] = type_size;\n    for (int i = 1; i < dims; i++) {\n        nb[i] = nb[i - 1] * ne[i - 1];\n    }\n\n    acl_tensor_ptr zero = ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, zero.get());\n    return zero;\n    GGML_UNUSED(n_bytes);\n}\n\n/**\n * @brief Creates an ACL tensor initialized with value using a provided buffer.\n *\n * This function initializes a tensor with value using the specified buffer and\n * tensor parameters.\n *\n * @param ctx The context for the CANN backend operations.\n * @param buffer The buffer to be used for the tensor data.\n * @param n_bytes The size of the buffer in bytes.\n * @param ne An array specifying the extents (sizes) of each dimension of the\n * tensor.\n * @param dims The number of dimensions of the tensor.\n * @param type The data type of the tensor.\n * @param type_size The size of each element in the tensor data type.\n * @param value The value to be used for initializing the tensor (default\n * is 1.0).\n * @return A tensor smart pointer initialized with value.\n */\nstatic acl_tensor_ptr aclnn_values(ggml_backend_cann_context & ctx,\n                                   void *                      buffer,\n                                   size_t                      n_bytes,\n                                   int64_t *                   ne,\n                                   int64_t                     dims,\n                                   aclDataType                 type,\n                                   size_t                      type_size,\n                                   float                       value = 1.0f) {\n    acl_tensor_ptr acl_tensor = aclnn_zero(ctx, buffer, n_bytes, ne, dims, type, type_size);\n    float          alpha_host = 1.0f;\n    
acl_scalar_ptr alpha      = ggml_cann_create_scalar(&alpha_host, aclDataType::ACL_FLOAT);\n    acl_scalar_ptr other      = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_tensor.get(), other.get(), alpha.get());\n    return acl_tensor;\n}\n\n/**\n * @brief Fills a tensor with a scalar value.\n *\n * This function fills the destination tensor `acl_dst` with the scalar value\n * `scalar`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param scalar The scalar value used to fill the tensor.\n * @param acl_dst The destination tensor to be filled with the scalar value.\n */\nstatic void aclnn_fill_scalar(ggml_backend_cann_context & ctx, float scalar, aclTensor * acl_dst) {\n    acl_scalar_ptr acl_scalar = ggml_cann_create_scalar(&scalar, aclDataType::ACL_FLOAT);\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar.get());\n}\n\n/**\n * @brief Get or expand a cached tensor filled with a scalar value.\n *\n * This function manages cached device memory for tensors. If the current\n * cache size is insufficient for the requested tensor shape, the old memory will\n * be released and new memory will be allocated. The allocated buffer is\n * initialized  with the given scalar value using CANN operations.\n * Finally, an aclTensor object is created from the cached memory and returned.\n *\n * @param ctx           The CANN backend context that manages device memory.\n * @param buffer        A pointer to the cached device buffer (will be allocated\n *                      or reallocated if necessary).\n * @param cache_element The current number of cached elements. 
This will be\n *                      updated when the cache is expanded.\n * @param ne            The tensor shape array (number of elements in each dimension).\n * @param nb            The stride size for each dimension.\n * @param dtype         Data type of cached tensor.\n * @param dims          The number of tensor dimensions.\n * @param value         The scalar value used to fill the tensor (supports zero\n *                      initialization via memset or arbitrary values via fill_scalar).\n * @return              A tensor smart pointer created from the cached buffer.\n */\nstatic acl_tensor_ptr get_cache_acl_tensor(ggml_backend_cann_context & ctx,\n                                           void **                     buffer,\n                                           int64_t &                   cache_element,\n                                           int64_t *                   ne,\n                                           size_t *                    nb,\n                                           ggml_type                   dtype,\n                                           int64_t                     dims,\n                                           float                       value) {\n    // Calculate total number of elements\n    int64_t n_element = 1;\n    for (int i = 0; i < dims; i++) {\n        n_element *= ne[i];\n    }\n    size_t size = n_element * ggml_type_size(dtype);\n\n    // Allocate or expand cache if needed\n    if (cache_element < n_element) {\n        if (*buffer != nullptr) {\n            aclrtFree(*buffer);\n            *buffer = nullptr;\n        }\n\n        ACL_CHECK(aclrtMalloc(buffer, size, ACL_MEM_MALLOC_HUGE_FIRST));\n        cache_element = n_element;\n\n        // Initialize cache\n        int64_t        pool_ne[1] = { n_element };\n        size_t         pool_nb[1] = { ggml_type_size(dtype) };\n        acl_tensor_ptr acl_value =\n            ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype), 
ggml_type_size(dtype), pool_ne, pool_nb, 1);\n        aclnn_fill_scalar(ctx, value, acl_value.get());\n    }\n\n    return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype), ne, nb, dims);\n}\n\nvoid ggml_cann_rms_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    // build gamma.\n    size_t acl_gamma_nb[GGML_MAX_DIMS];\n    // gamma's type is the same with dst.\n    acl_gamma_nb[0] = ggml_type_size(dst->type);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];\n    }\n    acl_tensor_ptr acl_gamma = get_cache_acl_tensor(\n        ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, src->ne, acl_gamma_nb, dst->type,\n        1,    // dims\n        1.0f  // value\n    );\n\n    // build rstd.\n    int64_t acl_rstd_ne[] = { src->ne[1], src->ne[2], src->ne[3] };\n    size_t  acl_rstd_nb[GGML_MAX_DIMS - 1];\n    // rstd will always be F32.\n    acl_rstd_nb[0] = sizeof(float);\n    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {\n        acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];\n    }\n    acl_tensor_ptr acl_rstd =\n        get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size,\n                             acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS - 1,\n                             0.0f  // value\n        );\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src.get(), acl_gamma.get(), eps, acl_dst.get(), acl_rstd.get());\n}\n\n// TODO: performace is low.\nvoid ggml_cann_diag_mask(ggml_backend_cann_context & ctx, ggml_tensor * dst, float value) {\n    ggml_tensor * src = dst->src[0];\n\n    acl_tensor_ptr acl_src = 
ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    const int n_past = ((int32_t *) dst->op_params)[0];\n\n    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src));\n    void *               buffer = one_tensor_allocator.get();\n\n    acl_tensor_ptr mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),\n                                                         ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);\n\n    aclnn_fill_scalar(ctx, value, mask_tensor.get());\n\n    float          alphaValue = 1.0f;\n    acl_scalar_ptr alpha      = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceTriu, mask_tensor.get(), n_past + 1);\n    GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), n_past + 1, acl_dst.get());\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), mask_tensor.get(), alpha.get());\n}\n\n/**\n * @brief Permutes the dimensions of a tensor according to a specified order.\n *\n * This function permutes the dimensions of the source tensor `acl_src`\n * according to the order specified in the `new_dim` array and stores the result\n * in the destination tensor `acl_dst`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor whose dimensions will be permuted.\n * @param acl_dst The destination tensor where the permuted result will be\n * stored.\n * @param new_dim An array specifying the new order of dimensions for the\n * tensor.\n * @param dims The number of dimensions in the tensor.\n */\nstatic void aclnn_permute(ggml_backend_cann_context & ctx,\n                          aclTensor *                 acl_src,\n                          aclTensor *                 acl_dst,\n                          int64_t *                   new_dim,\n                          uint64_t                    dims) {\n    acl_int_array_ptr acl_dims = 
ggml_cann_create_int_array(new_dim, dims);\n    GGML_CANN_CALL_ACLNN_OP(ctx, Permute, acl_src, acl_dims.get(), acl_dst);\n}\n\nstatic void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context & ctx,\n                                             ggml_tensor *               dst,\n                                             ggml_tensor *               src1,\n                                             aclTensor *                 tmp_cast_tensor,\n                                             aclTensor *                 tmp_im2col_tensor) {\n    // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]\n    int64_t        dst_ne[] = { dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3] };\n    size_t         dst_nb[] = { dst->nb[0], dst->nb[1], dst->nb[3] };\n    acl_tensor_ptr acl_dst  = ggml_cann_create_tensor(dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1);\n\n    int64_t permute_dim[] = { 0, 2, 1 };\n    if (src1->type != dst->type) {\n        aclnn_permute(ctx, tmp_cast_tensor, acl_dst.get(), permute_dim, 3);\n    } else {\n        aclnn_permute(ctx, tmp_im2col_tensor, acl_dst.get(), permute_dim, 3);\n    }\n}\n\nstatic void ggml_cann_im2col_1d_post_process(ggml_backend_cann_context &  ctx,\n                                             ggml_tensor *                dst,\n                                             ggml_tensor *                src1,\n                                             aclTensor *                  tmp_cast_tensor,\n                                             aclTensor *                  tmp_im2col_tensor,\n                                             const std::vector<int64_t> & im2col_op_params) {\n    // get params\n    const int64_t KH             = im2col_op_params[0];\n    const int64_t KW             = im2col_op_params[1];\n    const int64_t IW             = im2col_op_params[2];\n    const int64_t IC             = im2col_op_params[3];\n    const int64_t N              = im2col_op_params[4];\n    const int64_t OH             = 
im2col_op_params[5];\n    const int64_t OW             = im2col_op_params[6];\n    const int64_t s0             = im2col_op_params[7];\n    const int64_t p0             = im2col_op_params[8];\n    const int64_t d0             = im2col_op_params[9];\n    const int64_t n_bytes_factor = im2col_op_params[10];\n\n    // Permute: [N, IC * KH * KW, OW * OH] ->\n    // [N, OW * OH * n_bytes_factor, IC * KH * KW]\n    ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool());\n    tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);\n    void * tmp_permute_buffer = tmp_permute_allocator.get();\n\n    int64_t tmp_permute_ne[] = { IC * KH * KW, OW * OH * n_bytes_factor, N };\n    size_t  tmp_permute_nb[GGML_MAX_DIMS - 1];\n    tmp_permute_nb[0] = ggml_type_size(dst->type);\n    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {\n        tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];\n    }\n\n    acl_tensor_ptr tmp_permute_tensor =\n        ggml_cann_create_tensor(tmp_permute_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),\n                                tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND);\n\n    int64_t permute_dim[] = { 0, 2, 1 };\n    if (src1->type != dst->type) {\n        aclnn_permute(ctx, tmp_cast_tensor, tmp_permute_tensor.get(), permute_dim, 3);\n    } else {\n        aclnn_permute(ctx, tmp_im2col_tensor, tmp_permute_tensor.get(), permute_dim, 3);\n    }\n\n    // number of times the kernel moves in W dimension\n    const int n_step_w = (IW + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;\n    size_t    offset;\n    void *    cur_dst_buffer = dst->data, *cur_permute_buffer = tmp_permute_buffer;\n\n    // memory copy with offset to restore 1D im2col from 2d\n    if (IC > 1) {\n        offset          = IC * KH * KW * n_step_w * ggml_type_size(dst->type);\n        size_t cpy_size = KH * KW * ggml_type_size(dst->type);\n\n        for (int c = 0; c < IC; c++) {\n            cur_permute_buffer = (char *) 
tmp_permute_buffer + offset + KH * KW * c * ggml_type_size(dst->type);\n            cur_dst_buffer     = (char *) dst->data + c * KH * KW * n_step_w * ggml_type_size(dst->type);\n\n            for (int i = 0; i < n_step_w; i++) {\n                ACL_CHECK(aclrtMemcpyAsync(cur_dst_buffer, cpy_size, cur_permute_buffer, cpy_size,\n                                           ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));\n                cur_dst_buffer     = (char *) cur_dst_buffer + KH * KW * ggml_type_size(dst->type);\n                cur_permute_buffer = (char *) cur_permute_buffer + KH * KW * IC * ggml_type_size(dst->type);\n            }\n        }\n    } else {\n        offset = KH * KW * n_step_w * ggml_type_size(dst->type);  // equal to ggml_nbytes(dst)\n        ACL_CHECK(aclrtMemcpyAsync(dst->data, offset, (char *) tmp_permute_buffer + offset, offset,\n                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));\n    }\n}\n\nvoid ggml_cann_im2col(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];  // kernel\n    ggml_tensor * src1 = dst->src[1];  // input\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D\n    // im2col and do post-processing to restore it to 1D.\n    const bool    is_2D = ((const int32_t *) (dst->op_params))[6] == 1;\n    const int32_t s0    = ((const int32_t *) (dst->op_params))[0];\n    const int32_t s1    = is_2D ? ((const int32_t *) (dst->op_params))[1] : 1;\n    const int32_t p0    = ((const int32_t *) (dst->op_params))[2];\n    const int32_t p1    = is_2D ? ((const int32_t *) (dst->op_params))[3] : 1;\n    const int32_t d0    = ((const int32_t *) (dst->op_params))[4];\n    const int32_t d1    = is_2D ? 
((const int32_t *) (dst->op_params))[5] : 1;\n\n    const int64_t N  = ne13;\n    const int64_t IC = ne12;\n    const int64_t KH = ne01;\n    const int64_t KW = ne00;\n    const int64_t IW = ne10;\n\n    const int64_t OH = is_2D ? ne2 : 1;\n    const int64_t OW = ne1;\n\n    // memory allocated increased to 3x when is_2D == false\n    const int64_t n_bytes_factor = is_2D ? 1 : 3;\n\n    // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH * n_bytes_factor]\n    acl_tensor_ptr acl_src1        = ggml_cann_create_tensor(src1);\n    int64_t        tmp_im2col_ne[] = { OW * OH * n_bytes_factor, IC * KH * KW, N };\n    size_t         tmp_im2col_nb[GGML_MAX_DIMS - 1];\n\n    tmp_im2col_nb[0] = ggml_type_size(src1->type);\n    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {\n        tmp_im2col_nb[i] = tmp_im2col_nb[i - 1] * tmp_im2col_ne[i - 1];\n    }\n\n    // Calculate im2col.\n    // If dst is f16, tmp_buffer is f32, we need alloc src.typesize *\n    // dst.elemcount.\n    ggml_cann_pool_alloc im2col_allocator(ctx.pool(), ggml_nelements(dst) * ggml_element_size(src1) * n_bytes_factor);\n    void *               tmp_im2col_buffer = im2col_allocator.get();\n\n    acl_tensor_ptr tmp_im2col_tensor =\n        ggml_cann_create_tensor(tmp_im2col_buffer, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),\n                                tmp_im2col_ne, tmp_im2col_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND);\n\n    std::vector<int64_t> kernel_dims   = { KH, KW };\n    std::vector<int64_t> dilation_size = { d1, d0 };\n    std::vector<int64_t> padding_dims  = { p1, p0 };\n    std::vector<int64_t> stride_dims   = { s1, s0 };\n    acl_int_array_ptr    kernel_size   = ggml_cann_create_int_array(kernel_dims.data(), 2);\n    acl_int_array_ptr    dilations     = ggml_cann_create_int_array(dilation_size.data(), 2);\n    acl_int_array_ptr    paddings      = ggml_cann_create_int_array(padding_dims.data(), 2);\n    acl_int_array_ptr    strides       = 
ggml_cann_create_int_array(stride_dims.data(), 2);\n    GGML_CANN_CALL_ACLNN_OP(ctx, Im2col, acl_src1.get(), kernel_size.get(), dilations.get(), paddings.get(),\n                            strides.get(), tmp_im2col_tensor.get());\n\n    // Cast if dst is f16.\n    acl_tensor_ptr       tmp_cast_tensor;\n    ggml_cann_pool_alloc tmp_cast_allocator(ctx.pool());\n    void *               tmp_cast_buffer = nullptr;\n    if (src1->type != dst->type) {\n        tmp_cast_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);\n        tmp_cast_buffer = tmp_cast_allocator.get();\n        size_t temp_cast_nb[GGML_MAX_DIMS - 1];\n        temp_cast_nb[0] = ggml_type_size(dst->type);\n        for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {\n            temp_cast_nb[i] = temp_cast_nb[i - 1] * tmp_im2col_ne[i - 1];\n        }\n\n        tmp_cast_tensor =\n            ggml_cann_create_tensor(tmp_cast_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),\n                                    tmp_im2col_ne, temp_cast_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND);\n        aclnn_cast(ctx, tmp_im2col_tensor.get(), tmp_cast_tensor.get(), ggml_cann_type_mapping(dst->type));\n    }\n\n    // post-processing\n    if (is_2D) {\n        ggml_cann_im2col_2d_post_process(ctx, dst, src1, tmp_cast_tensor.get(), tmp_im2col_tensor.get());\n    } else {\n        std::vector<int64_t> im2col_op_params = { KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor };\n        ggml_cann_im2col_1d_post_process(ctx, dst, src1, tmp_cast_tensor.get(), tmp_im2col_tensor.get(),\n                                         im2col_op_params);\n    }\n}\n\n/**\n * @brief Applies element-wise exponential function to the elements of a tensor.\n *\n * This function computes the exponential of each element in the source tensor\n * `acl_src` and stores the result back into the same tensor.\n * The operation is defined as:\n * \\f[\n *     \\text {acl_src }_i=e^{acl\\_src_i}\n * \\f]\n *\n * @param ctx The context for the 
CANN backend operations.\n * @param acl_src The tensor on which the exponential function will be applied.\n */\nstatic void aclnn_exp(ggml_backend_cann_context & ctx, aclTensor * acl_src) {\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceExp, acl_src);\n}\n\nvoid aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {\n    if (acl_dst == nullptr) {\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);\n    } else {\n        GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);\n    }\n}\n\nvoid aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {\n    if (acl_dst == nullptr) {\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);\n    } else {\n        GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);\n    }\n}\n\nvoid ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src = dst->src[0];\n\n    GGML_ASSERT(src->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    const int dim        = dst->op_params[0];\n    const int max_period = dst->op_params[1];\n    int       half       = dim / 2;\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n\n    // arange: [0, ..., half)\n    float   start             = 0;\n    float   stop              = half;\n    float   step              = 1;\n    int64_t n_elements_arange = half;\n    int64_t tmp_arange_ne[]   = { half };\n    size_t  tmp_arange_nb[]   = { sizeof(dst->type) };\n\n    ggml_cann_pool_alloc arange_allocator(ctx.pool(), half * sizeof(dst->type));\n    void *               tmp_arange_buffer = arange_allocator.get();\n    acl_tensor_ptr       tmp_arange_tensor =\n        ggml_cann_create_tensor(tmp_arange_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),\n                                tmp_arange_ne, tmp_arange_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);\n\n    aclnn_arange(ctx, tmp_arange_tensor.get(), start, stop, step, 
n_elements_arange);\n\n    // freq\n    float freq_param = -logf(max_period) / half;\n    bool  inplace    = true;\n    aclnn_muls(ctx, tmp_arange_tensor.get(), freq_param, nullptr, inplace);\n    aclnn_exp(ctx, tmp_arange_tensor.get());\n\n    // permute: src [0,1,2,3]->[0,1,3,2]\n    int64_t tmp_permute_ne[] = { src->ne[1], src->ne[0], src->ne[2], src->ne[3] };\n    size_t  tmp_permute_nb[GGML_MAX_DIMS];\n    tmp_permute_nb[0] = ggml_type_size(src->type);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];\n    }\n\n    ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src));\n    void *               tmp_permute_buffer = permute_allocator.get();\n    acl_tensor_ptr       tmp_permute_tensor =\n        ggml_cann_create_tensor(tmp_permute_buffer, ggml_cann_type_mapping(src->type), ggml_type_size(src->type),\n                                tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);\n    int64_t permute_dim[] = { 0, 1, 3, 2 };\n    int64_t num_dims      = 4;\n    aclnn_permute(ctx, acl_src.get(), tmp_permute_tensor.get(), permute_dim, num_dims);\n\n    // timestep * freq\n    int64_t tmp_mul_ne[] = { src->ne[1] * half, src->ne[0], src->ne[2], src->ne[3] };\n    size_t  tmp_mul_nb[GGML_MAX_DIMS];\n    tmp_mul_nb[0] = ggml_type_size(src->type);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        tmp_mul_nb[i] = tmp_mul_nb[i - 1] * tmp_mul_ne[i - 1];\n    }\n\n    int mul_nelements = src->ne[1] * half * src->ne[0] * src->ne[2] * src->ne[3];\n\n    ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_nelements * ggml_type_size(src->type));\n    void *               tmp_mul_buffer = mul_allocator.get();\n    acl_tensor_ptr       tmp_mul_tensor =\n        ggml_cann_create_tensor(tmp_mul_buffer, ggml_cann_type_mapping(src->type), ggml_type_size(src->type),\n                                tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);\n    aclnn_mul(ctx, 
tmp_permute_tensor.get(), tmp_arange_tensor.get(), tmp_mul_tensor.get());\n\n    // cos\n    ggml_cann_pool_alloc cos_allocator(ctx.pool(), mul_nelements * ggml_type_size(src->type));\n    void *               tmp_cos_buffer = cos_allocator.get();\n    acl_tensor_ptr       tmp_cos_tensor =\n        ggml_cann_create_tensor(tmp_cos_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),\n                                tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);\n\n    aclnn_cos(ctx, tmp_mul_tensor.get(), tmp_cos_tensor.get());\n\n    // sin\n    ggml_cann_pool_alloc sin_allocator(ctx.pool(), mul_nelements * ggml_type_size(src->type));\n    void *               tmp_sin_buffer = sin_allocator.get();\n    acl_tensor_ptr       tmp_sin_tensor =\n        ggml_cann_create_tensor(tmp_sin_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),\n                                tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);\n\n    aclnn_sin(ctx, tmp_mul_tensor.get(), tmp_sin_tensor.get());\n\n    // concat\n    int64_t             concat_dim  = 3;\n    acl_tensor_ptr      acl_dst     = ggml_cann_create_tensor(dst);\n    acl_tensor_list_ptr tensor_list = ggml_cann_create_tensor_list(tmp_cos_tensor, tmp_sin_tensor);\n    aclnn_concat(ctx, tensor_list.get(), acl_dst.get(), concat_dim);\n}\n\n/**\n * @brief Raises each element of a tensor to the power of the corresponding\n * element in another tensor.\n *\n * This function computes the element-wise power of the destination tensor\n * `acl_dst` raised to the power of the exponent tensor `acl_exp`.\n * The operation is defined as:\n * \\f[\n *     \\text {acl_dst }_i=acl\\_dst_i^{\\text {acl_exp }_i}\n * \\f]\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_dst The destination tensor, which also serves as the base tensor.\n * @param acl_exp The exponent tensor, each element of which is used to raise\n * the corresponding element in the destination 
tensor.\n */\nstatic void aclnn_pow_tensor_tensor(ggml_backend_cann_context & ctx, aclTensor * acl_dst, aclTensor * acl_exp) {\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp);\n}\n\n/**\n * @brief Generate a range of values and apply a scalar base exponentiation.\n *\n * This function creates an evenly spaced sequence from `start` to `stop` (exclusive),\n * with step size `step`, stores it in a temporary buffer, and then computes:\n *\n * @f[\n * slope[i] = m^{\\left( start + i \\cdot step \\right)}, \\quad 0 \\le i < size\n * @f]\n *\n * The results are written to the provided @p slope_buffer.\n *\n * @param ctx           CANN backend context for memory allocation and operator execution.\n * @param slope_buffer  Pointer to the output buffer (float array) for the computed slope values.\n * @param m             Scalar base for the exponentiation.\n * @param size          Number of elements in the generated sequence.\n * @param start         Starting exponent offset.\n * @param stop          Stopping exponent offset (exclusive).\n * @param step          Step size for the exponent increment.\n * @param dtype         Data type for slope tensor.\n */\nstatic void aclnn_get_slope_inner(ggml_backend_cann_context & ctx,\n                                  void *                      slope_buffer,\n                                  float                       m,\n                                  int64_t                     size,\n                                  float                       start,\n                                  float                       stop,\n                                  float                       step,\n                                  ggml_type                   dtype) {\n    aclDataType acl_type  = ggml_cann_type_mapping(dtype);\n    size_t      type_size = ggml_type_size(dtype);\n\n    int64_t ne[] = { size };\n    size_t  nb[] = { type_size };\n\n    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * 
type_size);\n    void *               arange_buffer = arange_allocator.get();\n\n    acl_tensor_ptr arange_tensor = ggml_cann_create_tensor(arange_buffer, acl_type, type_size, ne, nb, 1);\n    aclnn_arange(ctx, arange_tensor.get(), start, stop, step, size);\n\n    acl_tensor_ptr slope_tensor = ggml_cann_create_tensor(slope_buffer, acl_type, type_size, ne, nb, 1);\n\n    acl_scalar_ptr sc = ggml_cann_create_scalar(&m, aclDataType::ACL_FLOAT);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc.get(), arange_tensor.get(), slope_tensor.get());\n}\n\n/**\n * @brief Compute slope values for multiple attention heads based on ALiBi bias parameters.\n *\n * This function generates slope values for each attention head according to the ALiBi\n * (Attention with Linear Biases) method. It splits the computation into two ranges depending\n * on whether the head index is less than @p n_head_log2 or not, and uses different base values\n * (`m0` and `m1`) for the exponentiation.\n *\n * @f[\n * slope[h] =\n * \\begin{cases}\n * m_0^{(h + 1)}, & h < n\\_head\\_log2 \\\\\n * m_1^{\\left( 2 \\cdot (h - n\\_head\\_log2) + 1 \\right)}, & h \\geq n\\_head\\_log2\n * \\end{cases}\n * \\quad , \\quad \\text{if } max\\_bias > 0\n * @f]\n *\n * If @p max_bias <= 0, all slope values are set to 1.0.\n *\n * @param ctx           CANN backend context for memory allocation and operator execution.\n * @param n_head        Total number of attention heads.\n * @param slope_buffer  Pointer to the output buffer (float array) for storing slopes.\n * @param max_bias      Maximum bias value for slope computation.\n * @param dtype         Data type for slope tensor.\n *\n*/\nstatic void aclnn_get_slope(ggml_backend_cann_context & ctx,\n                            int64_t                     n_head,\n                            void *                      slope_buffer,\n                            float                       max_bias,\n                            ggml_type                   dtype) {\n 
   const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));\n\n    float m0 = powf(2.0f, -(max_bias) / n_head_log2);\n    float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    // const float slope = (max_bias > 0.0f) ?\n    //                          h < n_head_log2 ?\n    //                              powf(m0, h + 1) :\n    //                              powf(m1, 2*(h - n_head_log2) + 1) :\n    //                          1.0f;\n    // arange1\n    float start = 0 + 1;\n    float end   = (n_head_log2 - 1) + 1;\n    float step  = 1;\n    float count = n_head_log2;\n    // end needs to be +1 because aclnn uses a left-closed, right-open interval.\n    aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step, dtype);\n    if (n_head_log2 < n_head) {\n        // arange2\n        start = 2 * (n_head_log2 - n_head_log2) + 1;\n        end   = 2 * ((n_head - 1) - n_head_log2) + 1;\n        step  = 2;\n        count = n_head - n_head_log2;\n        aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step,\n                              dtype);\n    }\n}\n\n/**\n * @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask.\n *\n * This function computes the ALiBi slopes for each attention head (if max_bias > 0),\n * multiplies them with the attention mask to produce bias tensors, and adds these biases\n * to the destination tensor (@p dst).\n *\n * The function performs necessary broadcasting of the mask and slope tensors to match\n * the shape of the destination tensor, then applies element-wise multiplication and addition\n * using CANN operators.\n *\n * @param ctx         CANN backend context for memory management and operator execution.\n * @param mask        Input attention mask tensor, assumed to be contiguous.\n * @param dst         Destination tensor to which ALiBi biases will be added.\n * @param dst_ptr     Pointer to the memory of the 
destination tensor.\n * @param max_bias    Maximum bias value controlling the slope scaling.\n *\n * @note\n * - Write data into dst_ptr using only the shape information of the dst tensor.\n * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting.\n * - The slope and bias scratch buffers are only allocated when max_bias > 0.\n */\nstatic void aclnn_add_alibi(ggml_backend_cann_context & ctx,\n                            ggml_tensor *               mask,\n                            ggml_tensor *               dst,\n                            void *                      dst_ptr,\n                            float                       max_bias) {\n    // The pool allocators live at function scope on purpose: declared inside\n    // the `if` below, they would hand slope_buffer/bias_buffer back to the\n    // pool at the end of that block, while the Mul/Add operators issued\n    // further down still read and write them (and the workspace those\n    // operators allocate from the same pool could alias them).\n    ggml_cann_pool_alloc slope_allocator(ctx.pool());\n    ggml_cann_pool_alloc bias_allocator(ctx.pool());\n    void *               slope_buffer = nullptr;\n    void *               bias_buffer  = nullptr;\n\n    if (max_bias > 0.0f) {\n        int64_t n_heads = dst->ne[2];\n        slope_allocator.alloc(n_heads * sizeof(float));\n        slope_buffer = slope_allocator.get();\n        // bias_tensor below is always ACL_FLOAT, so size the scratch in floats\n        // instead of by dst's element size.\n        bias_allocator.alloc(ggml_nelements(dst) * sizeof(float));\n        bias_buffer = bias_allocator.get();\n        aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias, GGML_TYPE_F32);\n    }\n\n    // broadcast for mask, slope and dst;\n    int64_t nr2 = dst->ne[2] / mask->ne[2];\n    int64_t nr3 = dst->ne[3] / mask->ne[3];\n\n    // broadcast the mask across rows\n    int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 };\n    size_t  mask_nb[] = { mask->nb[0], mask->nb[1], mask->nb[2], mask->nb[2], mask->nb[3], mask->nb[3] };\n\n    // dst's head axis is split into (mask->ne[2], nr2) and its batch axis into\n    // (mask->ne[3], nr3), with the mask-sized axis innermost. This matches the\n    // contiguous slope layout below (head h = i2 + i3 * mask->ne[2]), so the\n    // outer axis of each pair steps over a whole group of mask->ne[k] slices.\n    // When mask->ne[k] == 1 or the repeat count is 1 this equals nb[k].\n    int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 };\n    size_t  dst_nb[] = { dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[2] * (size_t) mask->ne[2],\n                         dst->nb[3], dst->nb[3] * (size_t) mask->ne[3] };\n\n    // slope is a 1 dim tensor, slope.ne2 == dst.ne2\n    int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 };\n
    size_t  slope_nb[GGML_MAX_DIMS + 2];\n    slope_nb[0] = sizeof(float);\n    for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {\n        slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];\n    }\n\n    acl_tensor_ptr acl_slope =\n        ggml_cann_create_tensor(slope_buffer, ACL_FLOAT, sizeof(float), slope_ne, slope_nb, GGML_MAX_DIMS + 2);\n    acl_tensor_ptr acl_mask = ggml_cann_create_tensor(mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2);\n\n    // write data into dst_ptr using only the shape information of the dst tensor.\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst_ptr, ggml_cann_type_mapping(dst->type),\n                                                     ggml_type_size(dst->type), dst_ne, dst_nb, GGML_MAX_DIMS + 2);\n\n    if (max_bias > 0.0f) {\n        int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 };\n        size_t  bias_nb[GGML_MAX_DIMS + 2];\n        bias_nb[0] = sizeof(float);\n        for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {\n            bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1];\n        }\n        acl_tensor_ptr bias_tensor =\n            ggml_cann_create_tensor(bias_buffer, ACL_FLOAT, sizeof(float), bias_ne, bias_nb, GGML_MAX_DIMS + 2);\n\n        aclnn_mul(ctx, acl_slope.get(), acl_mask.get(), bias_tensor.get());\n        aclnn_add(ctx, acl_dst.get(), bias_tensor.get());\n    } else {\n        aclnn_add(ctx, acl_dst.get(), acl_mask.get());\n    }\n}\n\n// Copy is implemented by delegating to ggml_cann_dup.\nvoid ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_cann_dup(ctx, dst);\n}\n\n/**\n * @brief Applies the softmax function to a tensor along a specified dimension.\n *\n * This function computes the softmax of the source tensor `acl_src` along the\n * specified dimension `dim` and stores the result in the destination tensor\n * `acl_dst`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor on which the softmax function will be\n * applied.\n * @param dim The dimension along which 
the softmax function will be computed.\n * @param acl_dst The destination tensor where the softmax results will be\n * stored.\n */\nstatic void aclnn_softmax(ggml_backend_cann_context & ctx, aclTensor * acl_src, int64_t dim, aclTensor * acl_dst) {\n    GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);\n}\n\nvoid ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];  // mask\n\n    acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);\n    acl_tensor_ptr acl_dst  = ggml_cann_create_tensor(dst);\n\n    float scale    = 1.0f;\n    float max_bias = 0.0f;\n\n    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));\n\n    // input mul scale\n    acl_scalar_ptr       acl_scale = ggml_cann_create_scalar(&scale, aclDataType::ACL_FLOAT);\n    ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0));\n    void *               src_tensor_buffer = src_tensor_allocator.get();\n    acl_tensor_ptr       softmax_tensor = ggml_cann_create_tensor(src_tensor_buffer, ggml_cann_type_mapping(src0->type),\n                                                                  ggml_element_size(src0), src0->ne, src0->nb, GGML_MAX_DIMS);\n\n    aclnn_muls(ctx, acl_src0.get(), scale, softmax_tensor.get(), false);\n\n    // mask\n    if (src1) {\n        aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias);\n    }\n    // softmax\n    aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get());\n}\n\n/**\n * @brief Performs index select operation on a 4D tensor using the CANN backend.\n *\n * This function applies the `IndexSelect` operation along a specific dimension\n * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).\n * It iterates over the last two dimensions of the source tensor, creates the corresponding\n * CANN tensors for the source, index, 
and output slices, and executes the `IndexSelect`\n * operation for each slice.\n *\n * @param ctx The context for CANN backend operations.\n * @param src_buffer The source buffer containing the 4D input tensor data.\n * @param src_ne The dimensions of the source tensor.\n * @param src_nb The strides (byte offsets) of the source tensor.\n * @param dst_buffer The destination buffer where the output tensor data will be written.\n * @param dst_ne The dimensions of the destination tensor.\n * @param dst_nb The strides (byte offsets) of the destination tensor.\n * @param index The index tensor specifying the indices to select from the source tensor.\n * @param type The data type of the source and destination tensors.\n */\nstatic void aclnn_index_select_4d(ggml_backend_cann_context & ctx,\n                                  void *                      src_buffer,\n                                  int64_t *                   src_ne,\n                                  size_t *                    src_nb,\n                                  void *                      dst_buffer,\n                                  int64_t *                   dst_ne,\n                                  size_t *                    dst_nb,\n                                  ggml_tensor *               index,\n                                  ggml_type                   type) {\n    for (int64_t i = 0; i < src_ne[3]; i++) {\n        for (int64_t j = 0; j < src_ne[2]; j++) {\n            // src\n            acl_tensor_ptr acl_src_tensor =\n                ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],\n                                        ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);\n\n            // index\n            acl_tensor_ptr acl_index = ggml_cann_create_tensor(\n                (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],\n                ggml_cann_type_mapping(index->type), 
ggml_element_size(index), index->ne, index->nb, 1);\n\n            // out\n            acl_tensor_ptr acl_out =\n                ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],\n                                        ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);\n            GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get());\n        }\n    }\n}\n\n/**\n * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend.\n *\n * This function applies the `IndexCopy` operation along a specific dimension of the\n * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`)\n * to positions specified by the index tensor (`index`).\n * It iterates over the last two dimensions of the tensors, creates the corresponding\n * CANN tensors for source, index, and destination slices, and performs the index copy\n * operation for each slice.\n *\n * @param ctx The context for CANN backend operations.\n * @param src_buffer The source buffer containing the 4D input tensor data to be copied.\n * @param src_ne The dimensions of the source tensor.\n * @param src_nb The strides (byte offsets) of the source tensor.\n * @param dst_buffer The destination buffer where values will be copied to.\n * @param dst_ne The dimensions of the destination tensor.\n * @param dst_nb The strides (byte offsets) of the destination tensor.\n * @param index The index tensor specifying target positions in the destination tensor.\n * @param type The data type of the source and destination tensors.\n */\nstatic void aclnn_index_copy_4d(ggml_backend_cann_context & ctx,\n                                void *                      src_buffer,\n                                int64_t *                   src_ne,\n                                size_t *                    src_nb,\n                                void *                      dst_buffer,\n        
                        int64_t *                   dst_ne,\n                                size_t *                    dst_nb,\n                                ggml_tensor *               index,\n                                ggml_type                   type) {\n    for (int64_t i = 0; i < src_ne[3]; i++) {\n        for (int64_t j = 0; j < src_ne[2]; j++) {\n            // src\n            acl_tensor_ptr acl_src_tensor =\n                ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],\n                                        ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);\n\n            // index\n            acl_tensor_ptr acl_index = ggml_cann_create_tensor(\n                (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],\n                ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);\n\n            // out\n            acl_tensor_ptr acl_out =\n                ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],\n                                        ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get());\n        }\n    }\n}\n\nvoid ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];  // src\n    ggml_tensor * src1 = dst->src[1];  // index\n\n    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n        case GGML_TYPE_F32:\n            if (src0->type == dst->type) {\n                aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1,\n                                      dst->type);\n            } else {\n                acl_tensor_ptr       acl_src0 = ggml_cann_create_tensor(src0);\n 
               ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));\n                void *               src_trans_buffer = src_buffer_allocator.get();\n                size_t               src_trans_nb[GGML_MAX_DIMS];\n                src_trans_nb[0] = dst->nb[0];\n                for (int i = 1; i < GGML_MAX_DIMS; i++) {\n                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];\n                }\n                acl_tensor_ptr src_trans_tensor =\n                    ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type),\n                                            ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);\n                aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));\n                aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,\n                                      dst->type);\n            }\n            break;\n        case GGML_TYPE_Q8_0:\n            {\n                // add 1 dim for bcast mul.\n                size_t  weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1];\n                int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne;\n                int64_t scale_offset = 0;\n                // [3,4,5,64] -> [3,4,5,2,32]\n                weight_ne[0]         = QK8_0;\n                weight_ne[1]         = src0->ne[0] / QK8_0;\n                weight_nb[0]         = sizeof(int8_t);\n                weight_nb[1]         = weight_nb[0] * weight_ne[0];\n                for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {\n                    weight_ne[i] = src0->ne[i - 1];\n                    weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];\n                }\n                // [3,4,5,64] -> [3,4,5,2,1]\n                scale_ne[0] = 1;\n                scale_ne[1] = 
src0->ne[0] / QK8_0;\n                scale_nb[0] = sizeof(uint16_t);\n                scale_nb[1] = scale_nb[0] * scale_ne[0];\n                for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {\n                    scale_ne[i] = src0->ne[i - 1];\n                    scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];\n                }\n                // [3,4,5,64] -> [3,4,5,2,32]\n                dequant_ne    = weight_ne;\n                dequant_nb[0] = ggml_type_size(dst->type);\n                for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {\n                    dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];\n                }\n                scale_offset = ggml_nelements(src0) * sizeof(int8_t);\n                ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(),\n                                                              ggml_nelements(src0) * ggml_type_size(dst->type));\n                acl_tensor_ptr       acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),\n                                                                                 weight_ne, weight_nb, GGML_MAX_DIMS + 1);\n                acl_tensor_ptr       acl_scale_tensor =\n                    ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,\n                                            GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);\n                acl_tensor_ptr dequant_tensor =\n                    ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type),\n                                            ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);\n                aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get());\n                dequant_nb[0] = ggml_type_size(dst->type);\n                dequant_ne    = src0->ne;\n                for (int i = 1; i < GGML_MAX_DIMS; i++) {\n                    dequant_nb[i] = dequant_nb[i - 
1] * src0->ne[i - 1];\n                }\n                aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne,\n                                      dst->nb, src1, dst->type);\n                break;\n            }\n        default:\n            GGML_ABORT(\"Unsupported tensor type for GGML_OP_GET_ROWS\");\n            break;\n    }\n}\n\nvoid ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];  // src\n    ggml_tensor * src1 = dst->src[1];  // index\n\n    switch (dst->type) {\n        case GGML_TYPE_F32:\n            {\n                aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type);\n                break;\n            }\n        case GGML_TYPE_F16:\n            {\n                acl_tensor_ptr       acl_src0 = ggml_cann_create_tensor(src0);\n                ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));\n                void *               src_trans_buffer = src_buffer_allocator.get();\n                size_t               src_trans_nb[GGML_MAX_DIMS];\n                src_trans_nb[0] = sizeof(uint16_t);\n                for (int i = 1; i < GGML_MAX_DIMS; i++) {\n                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];\n                }\n                acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(\n                    src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);\n                aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));\n                aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,\n                                    dst->type);\n                break;\n            }\n        default:\n            GGML_ABORT(\"Unsupported tensor type for 
GGML_OP_SET_ROWS\");\n            break;\n    }\n}\n\n/**\n * @brief Repeats elements of a tensor along a specified dimension.\n *\n * This function repeats each element of the source tensor `acl_src` a specified\n * number of times (`repeats`) along the specified dimension `dim` and stores\n * the result in the destination tensor `acl_dst`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor whose elements will be repeated.\n * @param acl_dst The destination tensor where the repeated elements will be\n * stored.\n * @param dim The dimension along which the elements will be repeated.\n * @param repeats The number of times each element will be repeated.\n * @param output_size The size of the output tensor.\n */\nstatic void aclnn_repeat_interleave(ggml_backend_cann_context & ctx,\n                                    aclTensor *                 acl_src,\n                                    aclTensor *                 acl_dst,\n                                    int64_t                     dim,\n                                    int64_t                     repeats,\n                                    int64_t                     output_size) {\n    GGML_CANN_CALL_ACLNN_OP(ctx, RepeatInterleaveIntWithDim, acl_src, repeats, dim, output_size, acl_dst);\n}\n\n/**\n * @brief Performs matrix multiplication with floating-point precision on\n * tensors using the CANN backend.\n *\n * This function performs matrix multiplication of the input tensor and the\n * weight tensor, handling broadcasting and transposing as needed, and stores\n * the result in the destination tensor `dst`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param dst The destination tensor where the result of the matrix\n * multiplication will be stored.\n */\nstatic void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * weight = dst->src[0];  // weight\n    ggml_tensor * input  = 
dst->src[1];  // input\n\n    // when weight ne2 or ne3 is 1, aclnnMatmulGetWorkspaceSize will auto\n    // broadcast, when weight ne2 or ne3 is not 1, weight need repeat.\n    BCAST_MUL_MAT_SHAPE(input, weight, dst);\n\n    int64_t n_dims = bcast_dims;\n    if (bcast_input_ne[3] == bcast_weight_ne[3] && bcast_input_ne[3] == 1) {\n        if (bcast_input_ne[2] == 1 && bcast_weight_ne[2] == 1) {\n            n_dims = 2;\n        } else if (bcast_input_ne[2] == 1) {\n            n_dims = 3;\n        }\n    }\n\n    acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(input, bcast_input_ne, bcast_input_nb, n_dims);\n    int64_t        transpose_ne[]   = { bcast_weight_ne[1], bcast_weight_ne[0], bcast_weight_ne[2],\n                                        bcast_weight_ne[3], bcast_weight_ne[4], bcast_weight_ne[5] };\n    size_t         transpose_nb[]   = { bcast_weight_nb[1], bcast_weight_nb[0], bcast_weight_nb[2],\n                                        bcast_weight_nb[3], bcast_weight_nb[4], bcast_weight_nb[5] };\n    acl_tensor_ptr acl_weight_tensor;\n\n    // Only check env once.\n    static bool weight_to_nz = parse_bool(get_env_as_lowercase(\"GGML_CANN_WEIGHT_NZ\").value_or(\"on\"));\n    if (weight_to_nz && is_matmul_weight(weight)) {\n        acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);\n    } else {\n        acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);\n    }\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims);\n\n    switch (n_dims) {\n        case 2:\n            GGML_CANN_CALL_ACLNN_OP(ctx, Mm, acl_input_tensor.get(), acl_weight_tensor.get(), acl_dst.get(), 2);\n            break;\n        case 3:\n            GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, acl_input_tensor.get(), acl_weight_tensor.get(), acl_dst.get(),\n                                    2);\n            break;\n        
default:\n            // ALLOW_FP32_DOWN_PRECISION, when input is\n            // fp32, atlas a2 will transpose it to HFLOAT32.\n            GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, acl_input_tensor.get(), acl_weight_tensor.get(), acl_dst.get(), 1);\n            break;\n    }\n}\n\n/**\n * @brief Performs matrix multiplication with quantized weights and\n * floating-point inputs using the CANN backend.\n *\n * This function performs matrix multiplication of the input tensor `src1` and\n * the weight tensor `src0`, handling broadcasting, transposing, and\n * quantization as needed, and stores the result in the destination tensor\n * `dst`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param dst The destination tensor where the result of the matrix\n * multiplication will be stored.\n */\nstatic void ggml_cann_mul_mat_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst, const enum ggml_type type) {\n    ggml_tensor * src0 = dst->src[0];  // weight\n    ggml_tensor * src1 = dst->src[1];  // input\n\n    // The shape of the weight is NCHW.\n    // Matrix multiplication uses HW dims.\n    // HC is regarded as batch.\n    // weight need transpose.\n    float weight_elem_size;\n    if (type == GGML_TYPE_Q4_0) {\n        weight_elem_size = float(sizeof(uint8_t)) / 2;\n    } else if (type == GGML_TYPE_Q8_0) {\n        weight_elem_size = float(sizeof(uint8_t));\n    } else {\n        GGML_ABORT(\"Only support Q4_0 and Q8_0 MUL_MAT\");\n    }\n    float  weight_nb[]   = { src0->ne[0] * weight_elem_size, weight_elem_size };\n    size_t weight_stride = src0->ne[1] * src0->ne[0] * weight_elem_size;\n    size_t weight_size   = weight_stride * src0->ne[2] * src0->ne[3];\n\n    // scale stored at the end of weight. 
Also need transpose.\n    size_t scale_elem_size = sizeof(uint16_t);\n    size_t scale_nb[]      = { src0->ne[0] / QK8_0 * scale_elem_size, scale_elem_size };\n    size_t scale_stride    = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;\n    char * scale_offset    = (char *) src0->data + weight_size;\n\n    // input\n    size_t               input_elem_size = sizeof(uint16_t);\n    int64_t              input_ne[]      = { src1->ne[0], src1->ne[1] };\n    size_t               input_nb[]      = { input_elem_size, input_ne[0] * input_elem_size };\n    size_t               input_stride    = input_ne[0] * input_ne[1] * input_elem_size;\n    ggml_cann_pool_alloc input_alloctor(ctx.pool());\n    void *               input_buffer = src1->data;\n\n    // case in\n    if (src1->type != GGML_TYPE_F16) {\n        acl_tensor_ptr acl_src1_tensor = ggml_cann_create_tensor(src1);\n        input_buffer                   = input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);\n\n        int64_t * input_cast_ne = src1->ne;\n        size_t    input_cast_nb[GGML_MAX_DIMS];\n        input_cast_nb[0] = sizeof(uint16_t);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1];\n        }\n\n        acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, input_elem_size,\n                                                                  input_cast_ne, input_cast_nb, GGML_MAX_DIMS);\n        aclnn_cast(ctx, acl_src1_tensor.get(), acl_input_tensor.get(), ACL_FLOAT16);\n    }\n\n    // output\n    size_t               output_elem_size = sizeof(uint16_t);\n    size_t               output_nb[]      = { output_elem_size, dst->ne[0] * output_elem_size };\n    ggml_cann_pool_alloc output_allocator(ctx.pool());\n    void *               output_buffer = output_allocator.alloc(ggml_nelements(dst) * output_elem_size);\n    size_t               output_stride = dst->ne[0] * dst->ne[1] * 
output_elem_size;\n\n    // aclnn\n    int64_t              max_elem_size = 65535;\n    int64_t              split_size    = (src0->ne[1] / max_elem_size) + 1;\n    ggml_cann_pool_alloc workspace_allocator(ctx.pool());\n    for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {\n        for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {\n            int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);\n            int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]);\n\n            int64_t batch1 = (n1 * src1->ne[2]) + c1;\n            int64_t batch0 = (n0 * src0->ne[2]) + c0;\n\n            acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(\n                (char *) input_buffer + batch1 * input_stride, ACL_FLOAT16, input_elem_size, input_ne, input_nb, 2);\n\n            // first split\n            int64_t weight_ne_offset = 0;\n            int64_t weight_ne[2]     = { max_elem_size > src0->ne[1] ? src0->ne[1] : max_elem_size, src0->ne[0] };\n            int64_t scale_ne_offset  = 0;\n            int64_t scale_ne[2]      = { weight_ne[0], weight_ne[1] / QK8_0 };\n            int64_t output_ne_offset = 0;\n            int64_t output_ne[2]     = { weight_ne[0], dst->ne[1] };\n\n            acl_tensor_ptr acl_weight_tensor =\n                ggml_cann_create_tensor((char *) src0->data + batch0 * weight_stride, ggml_cann_type_mapping(type),\n                                        weight_elem_size, weight_ne, weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset);\n            acl_tensor_ptr acl_scale_tensor =\n                ggml_cann_create_tensor(scale_offset + batch0 * scale_stride, ACL_FLOAT16, scale_elem_size, scale_ne,\n                                        scale_nb, 2, ACL_FORMAT_ND, scale_ne_offset);\n            acl_tensor_ptr acl_output_tensor =\n                ggml_cann_create_tensor((char *) output_buffer + batch1 * output_stride, ACL_FLOAT16, output_elem_size,\n                                        output_ne, output_nb, 2, ACL_FORMAT_ND, output_ne_offset);\n            
int64_t antiquantGroupSize = 0;\n            if (src0->ne[0] > QK8_0) {\n                antiquantGroupSize = QK8_0;\n            }\n            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor.get(), acl_weight_tensor.get(),\n                                    acl_scale_tensor.get(), nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,\n                                    acl_output_tensor.get());\n\n            // other splits\n            for (int64_t split = 1; split < split_size; split++) {\n                weight_ne_offset += weight_elem_size * weight_ne[0] * weight_ne[1];\n                weight_ne[0] =\n                    max_elem_size * (split + 1) > src0->ne[1] ? src0->ne[1] - (max_elem_size * split) : max_elem_size;\n                scale_ne_offset += scale_elem_size * scale_ne[0] * scale_ne[1];\n                scale_ne[0] = weight_ne[0];\n                output_ne_offset += output_elem_size * output_ne[0] * output_ne[1];\n                output_ne[0] = weight_ne[0];\n\n                acl_weight_tensor =\n                    ggml_cann_create_tensor((char *) src0->data + batch0 * weight_stride, ggml_cann_type_mapping(type),\n                                            weight_elem_size, weight_ne, weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset);\n                acl_scale_tensor =\n                    ggml_cann_create_tensor(scale_offset + batch0 * scale_stride, ACL_FLOAT16, scale_elem_size,\n                                            scale_ne, scale_nb, 2, ACL_FORMAT_ND, scale_ne_offset);\n                acl_output_tensor =\n                    ggml_cann_create_tensor((char *) output_buffer + batch1 * output_stride, ACL_FLOAT16,\n                                            output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, output_ne_offset);\n                GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor.get(), acl_weight_tensor.get(),\n                                        
acl_scale_tensor.get(), nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,\n                                        acl_output_tensor.get());\n            }\n        }\n    }\n\n    // cast out\n    if (dst->type != GGML_TYPE_F16) {\n        int64_t * output_cast_ne = dst->ne;\n        size_t    output_cast_nb[GGML_MAX_DIMS];\n        output_cast_nb[0] = sizeof(uint16_t);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];\n        }\n\n        acl_tensor_ptr acl_output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size,\n                                                                   output_cast_ne, output_cast_nb, GGML_MAX_DIMS);\n        acl_tensor_ptr acl_dst_tensor    = ggml_cann_create_tensor(dst);\n        aclnn_cast(ctx, acl_output_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));\n    }\n}\n\nvoid ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    const enum ggml_type type = dst->src[0]->type;\n    switch (type) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n            ggml_cann_mat_mul_fp(ctx, dst);\n            break;\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q8_0:\n            ggml_cann_mul_mat_quant(ctx, dst, type);\n            break;\n        default:\n            GGML_ABORT(\"Unsupported type for mul_mat\");\n            break;\n    }\n}\n\n/**\n * @brief Rolls the elements of a tensor along a specified dimension.\n *\n * This function rolls the elements of the source tensor `acl_src` by the\n * specified shifts `shifts` along the specified dimensions `dims`, and stores\n * the result in the destination tensor `acl_dst`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor whose elements will be rolled.\n * @param acl_dst The destination tensor where the rolled elements will be\n * stored.\n * @param shifts An array 
specifying the number of positions by which elements\n * are shifted.\n * @param dims An array specifying the dimensions along which elements are\n * shifted.\n */\nstatic void aclnn_roll(ggml_backend_cann_context & ctx,\n                       aclTensor *                 acl_src,\n                       aclTensor *                 acl_dst,\n                       int64_t *                   shifts,\n                       int64_t *                   dims) {\n    acl_int_array_ptr acl_shifts = ggml_cann_create_int_array(shifts, 1);\n    acl_int_array_ptr acl_dims   = ggml_cann_create_int_array(dims, 1);\n    GGML_CANN_CALL_ACLNN_OP(ctx, Roll, acl_src, acl_shifts.get(), acl_dims.get(), acl_dst);\n}\n\n/**\n * @brief Fills specified positions of a tensor with a scalar value.\n *\n * This function fills the positions in the source tensor `acl_src` specified by\n * `index` along the dimension `dim` with the scalar value `value`.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor where the positions will be filled.\n * @param dim The dimension along which the positions are specified.\n * @param index An array specifying the positions to be filled.\n * @param index_num The number of positions specified in the index array.\n * @param value The scalar value used to fill the specified positions.\n */\nstatic void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx,\n                                    aclTensor *                 acl_src,\n                                    int64_t                     dim,\n                                    int64_t *                   index,\n                                    int64_t                     index_num,\n                                    float                       value) {\n    acl_int_array_ptr acl_index = ggml_cann_create_int_array(index, index_num);\n    acl_scalar_ptr    acl_value = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);\n    
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexFillTensor, acl_src, dim, acl_index.get(), acl_value.get());\n}\n\n/**\n * @brief Initializes and caches all intermediate tensors required for RoPE\n *        (Rotary Position Embedding), including support for Yarn, mRoPE,\n *        i-mRoPE, Neox repeat strategy, independent sectors, frequency factors，\n *        and multi-section rotary groups.\n *\n * This function computes and caches the per-dimension θ coefficients used for\n * Q/K rotary embedding. The cache is shared across layers, and recomputed only\n * when any dependent parameter changes.\n *\n * The function now supports:\n *   - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor)\n *   - Per-dimension independent sector exponent rules (indep_sects + sections[])\n *   - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)\n *   - Frequency factor division (src2)\n *   - Neox / normal repeat expansion modes\n *\n * @param ctx                CANN backend context, containing memory pool,\n *                           cached buffers, and runtime stream.\n * @param dst                Destination ggml_tensor whose computation\n *                           depends on RoPE (typically Qcur or Kcur).\n * @param corr_dims          [low, high] Yarn correction range.\n * @param ext_factor         Yarn extrapolation strength. 
0 = disabled.\n * @param theta_scale        Base multiplier for per-dimension θ exponent.\n * @param freq_scale         Global frequency scaling factor.\n * @param attn_factor        Scaling applied to sin/cos; when ext_factor != 0 it is additionally\n *                           multiplied by (1 + 0.1 * ln(1 / freq_scale)).\n * @param is_neox            Whether to use Neox-style dimension interleave.\n * @param sections           4-way sector sizes for independent-section RoPE\n *                           and multi-section mRoPE (t/h/w/e).\n * @param mrope_used         Whether to enable multi-section rotary embedding.\n * @param is_imrope          Whether to apply interleaved mRoPE rules.\n * @param indep_sects        Whether each dimension runs independent exponent\n *                           resets based on @p sections.\n * @param rope_dims          Number of rotated dimensions; rope_dims / 2 frequencies are\n *                           computed for every position.\n *\n * The tables are written to ctx.rope_cache.sin_cache / cos_cache as float tensors of\n * shape {rope_dims, 1, dst->ne[2], 1}. When dst->src[2] (freq_factors) is absent and\n * ctx.rope_cache.equal() matches the current parameters, the previous tables are reused.\n */\nstatic void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,\n                                  ggml_tensor *               dst,\n                                  float *                     corr_dims,\n                                  float                       ext_factor,\n                                  float                       theta_scale,\n                                  float                       freq_scale,\n                                  float                       attn_factor,\n                                  bool                        is_neox,\n                                  int                         sections[4],\n                                  bool                        mrope_used,\n                                  bool                        is_imrope,\n                                  bool                        indep_sects,\n                                  int64_t                     rope_dims) {\n    ggml_tensor * src1 = dst->src[1];  // position\n    ggml_tensor * src2 = dst->src[2];  // freq_factors\n\n    int64_t theta_scale_length = rope_dims / 2;\n    int64_t position_length    = dst->ne[2];\n\n    // TODO: check theta_scale_length and position_length.\n    // NOTE(review): corr_dims is not part of the cache key checked here, so a change in corr_dims alone\n    // does not invalidate the cached sin/cos tables.\n    if (src2 == nullptr && ctx.rope_cache.cached &&\n        ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,\n                             is_neox, indep_sects, mrope_used, is_imrope, sections)) {\n        // use cache.\n        return;\n    }\n\n    // Step0: calculate tensor shape.\n    int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };\n    size_t  theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float),\n                                 theta_scale_length * sizeof(float) };\n\n    GGML_ASSERT(src1->type == GGML_TYPE_I32);\n    int64_t position_ne[] = { 1, 1, position_length, 1 };\n    size_t  position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };\n\n    int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 };\n    size_t  cache_nb[GGML_MAX_DIMS];\n    cache_nb[0] = sizeof(float);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1];\n    }\n\n    // Step1: Compute the coefficient of theta. During the cache_init process, aside from\n    // (1) multiplying by the position,\n    // (2) dividing by freq_factors,\n    // (3) computing the sine and cosine,\n    // the other parameters used in the computation generally do not change in most scenarios.\n    // Therefore, we can first compute this part of the result and then cache it.\n\n    // Step1.1: prepare the theta_scale exponents [theta_scale^0, theta_scale^1, ...], restarting at 1 on every\n    // section boundary when indep_sects is set (so the exponents then also depend on sections).\n    // theta_scale_cache only ever holds these unscaled exponents: the YaRN ramp / freq_scale factor is applied\n    // into a temporary buffer in Step1.3, so factors from earlier calls cannot accumulate in the cache.\n    bool theta_sections_changed = indep_sects && (ctx.rope_cache.sections[0] != sections[0] ||\n                                                  ctx.rope_cache.sections[1] != sections[1] ||\n                                                  ctx.rope_cache.sections[2] != sections[2] ||\n                                                  ctx.rope_cache.sections[3] != sections[3]);\n    bool theta_scale_updated    = false;\n    if (ctx.rope_cache.theta_scale_cache == nullptr || ctx.rope_cache.theta_scale_length != theta_scale_length ||\n        ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.indep_sects != indep_sects ||\n        theta_sections_changed) {\n        theta_scale_updated = true;\n        if (ctx.rope_cache.theta_scale_exp_host != nullptr) {\n            free(ctx.rope_cache.theta_scale_exp_host);\n        }\n        ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float));\n        GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr);\n        if (!indep_sects) {\n            ctx.rope_cache.theta_scale_exp_host[0] = 1;\n            for (int i = 1; i < theta_scale_length; i++) {\n                ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;\n            }\n        } else {\n            int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];\n            int sec_w     = sections[1] + sections[0];\n            int sec_e     = sections[2] + sec_w;\n\n            ctx.rope_cache.theta_scale_exp_host[0] = 1;\n            for (int i = 1; i < theta_scale_length; i++) {\n                int sector = i % sect_dims;\n                if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {\n                    ctx.rope_cache.theta_scale_exp_host[i] = 1;\n                    continue;\n                }\n                ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;\n            }\n        }\n\n        if (ctx.rope_cache.theta_scale_cache != nullptr) {\n            ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));\n        }\n        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),\n                              ACL_MEM_MALLOC_HUGE_FIRST));\n\n        ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),\n                                   ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),\n                                   ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));\n    }\n    acl_tensor_ptr acl_theta_scale_tensor = ggml_cann_create_tensor(\n        ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);\n\n    // Step1.2: prepare the YaRN ramp factor (only needed when ext_factor != 0). yarn_ramp_cache holds\n    // freq_scale + (freq_scale - 1) * ramp, so that theta = theta_extrap * yarn_ramp_cache.\n    // The ramp is dropped whenever the exponents were rebuilt or freq_scale differs from the previous call. A\n    // non-null yarn_ramp_cache therefore always matches the current theta_scale_length and freq_scale, even when\n    // calls with ext_factor == 0 (which never build a ramp) happened in between.\n    // NOTE(review): the ramp also depends on ext_factor and corr_dims, which are not tracked here; a change in\n    // only those values reuses the previous ramp.\n    if (ctx.rope_cache.yarn_ramp_cache != nullptr && (theta_scale_updated || ctx.rope_cache.freq_scale != freq_scale)) {\n        ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));\n        ctx.rope_cache.yarn_ramp_cache = nullptr;\n    }\n    acl_tensor_ptr acl_yarn_ramp_tensor;\n    if (ext_factor != 0) {\n        bool yarn_ramp_updated = ctx.rope_cache.yarn_ramp_cache == nullptr;\n        if (yarn_ramp_updated) {\n            ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float),\n                                  ACL_MEM_MALLOC_HUGE_FIRST));\n        }\n        acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),\n                                                       theta_scale_ne, theta_scale_nb, 1);\n        if (yarn_ramp_updated) {\n            // -rope_yarn_ramp\n            // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);\n            // return MIN(1, MAX(0, y)) - 1;\n            float          zero_value = 0, one_value = 1;\n            float          denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);\n            acl_scalar_ptr low              = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);\n            acl_scalar_ptr zero             = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);\n            acl_scalar_ptr one              = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);\n            acl_scalar_ptr denom_safe       = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);\n            acl_scalar_ptr ext_factor_sc    = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);\n\n            aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length);\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get());\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());\n\n            // theta_interp = freq_scale * theta_extrap;\n            // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;\n            // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;\n            // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;\n            // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);\n            //\n            // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix). The ramp computed above is the negated\n            // one (-ramp_mix), so this is cached as freq_scale + (freq_scale - 1) * ramp.\n            float          freq_scale_1    = freq_scale - 1;\n            acl_scalar_ptr freq_scale_sc   = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);\n            acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());\n            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());\n        }\n    }\n\n    // Step1.3: scale theta_extrap by the YaRN ramp (ext_factor != 0) or by freq_scale. The result goes to a\n    // temporary buffer so that theta_scale_cache keeps the unscaled exponents.\n    ggml_cann_pool_alloc theta_scale_scaled_allocator(ctx.pool());\n    if (ext_factor != 0 || freq_scale != 1) {\n        void *         theta_scale_scaled_buffer = theta_scale_scaled_allocator.alloc(theta_scale_length * sizeof(float));\n        acl_tensor_ptr acl_theta_scale_scaled_tensor = ggml_cann_create_tensor(\n            theta_scale_scaled_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);\n        if (ext_factor != 0) {\n            aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get(), acl_theta_scale_scaled_tensor.get());\n        } else {\n            aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, acl_theta_scale_scaled_tensor.get(), false);\n        }\n        std::swap(acl_theta_scale_tensor, acl_theta_scale_scaled_tensor);\n    }\n\n    // Step 1.4: prepare select index if mrope\n    acl_tensor_ptr position_select_index_tensor;\n    if (mrope_used) {\n        if (ctx.rope_cache.position_select_index == nullptr || ctx.rope_cache.sections[0] != sections[0] ||\n            ctx.rope_cache.sections[1] != sections[1] || ctx.rope_cache.sections[2] != sections[2] ||\n            ctx.rope_cache.sections[3] != sections[3] || ctx.rope_cache.theta_scale_length != theta_scale_length ||\n            ctx.rope_cache.is_imrope != is_imrope) {\n            if (ctx.rope_cache.position_select_index_host != nullptr) {\n                free(ctx.rope_cache.position_select_index_host);\n            }\n            ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int));\n            GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr);\n            int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];\n            int sec_w     = sections[1] + sections[0];\n            int sec_e     = sections[2] + sec_w;\n            // t,h,w,e\n            for (int i = 0; i < theta_scale_length; i++) {\n                int sector = i % sect_dims;\n\n                if (is_imrope) {  // qwen3vl apply interleaved mrope\n                    if (sector % 3 == 1 && sector < 3 * sections[1]) {\n                        ctx.rope_cache.position_select_index_host[i] = 1;\n                    } else if (sector % 3 == 2 && sector < 3 * sections[2]) {\n                        ctx.rope_cache.position_select_index_host[i] = 2;\n                    } else if (sector % 3 == 0 && sector < 3 * sections[0]) {\n                        ctx.rope_cache.position_select_index_host[i] = 0;\n                    } else {\n                        ctx.rope_cache.position_select_index_host[i] = 3;\n                    }\n                } else {\n                    if (sector >= sections[0] && sector < sec_w) {\n                        ctx.rope_cache.position_select_index_host[i] = 1;\n                    } else if (sector >= sec_w && sector < sec_e) {\n                        ctx.rope_cache.position_select_index_host[i] = 2;\n                    } else if (sector >= sec_e) {\n                        ctx.rope_cache.position_select_index_host[i] = 3;\n                    } else {\n                        ctx.rope_cache.position_select_index_host[i] = 0;\n                    }\n                }\n            }\n\n            if (ctx.rope_cache.position_select_index != nullptr) {\n                ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index));\n            }\n            ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),\n                                  ACL_MEM_MALLOC_HUGE_FIRST));\n\n            ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),\n                                       ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int),\n                                       ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));\n        }\n\n        position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32,\n                                                               sizeof(int), theta_scale_ne, theta_scale_nb, 1);\n    }\n\n    // Step2: divide by freq_factors\n    ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());\n    if (src2) {\n        freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));\n        void *         freq_fac_res_ptr = freq_fac_res_allocator.get();\n        acl_tensor_ptr acl_freq_factors_tensor =\n            ggml_cann_create_tensor(src2->data, ggml_cann_type_mapping(src2->type), ggml_type_size(src2->type),\n                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);\n        acl_tensor_ptr acl_freq_fac_res_tensor = ggml_cann_create_tensor(freq_fac_res_ptr, ACL_FLOAT, sizeof(float),\n                                                                         theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);\n        aclnn_div(ctx, acl_theta_scale_tensor.get(), acl_freq_factors_tensor.get(), acl_freq_fac_res_tensor.get());\n        std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);\n    }\n\n    // Step3: prepare position_tensor\n    acl_tensor_ptr       acl_position_tensor;\n    ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool());\n    if (mrope_used) {\n        // Step3.1: select current position;\n        // position :\n        // pos1: [[0, 1 ,2 ,3 ],\n        // pos2:  [4, 5 ,6 ,7 ],\n        // pos3:  [8, 9 ,10,11],\n        // pos4:  [12,13,14,15] ]\n        //\n        // select index = [0, 1, 2, 2, 1, 0]\n        //\n        // selected_tensor:\n        // [[0, 1 ,2 ,3 ],\n        //  [4, 5 ,6 ,7 ],\n        //  [8, 9 ,10,11],\n        //  [8, 9 ,10,11],\n        //  [4, 5 ,6 ,7 ],\n        //  [0, 1 ,2 ,3 ]]\n        //\n        // transpose, from [seq_len:dims] to [dims:seq_len]\n        // [0, 4, 8 ,8 ,4, 0],\n        // [1, 5, 9, 9, 5, 1],\n        // [2, 6, 10,10,6 ,2],\n        // [3, 7, 11,11,7 3 ]]\n        //\n        // multiply by theta_scale_tensor\n        // [theta_scale^0, theta_scale^1, ..., theta_scale ^ n]\n\n        int64_t        mrope_position_ne[] = { position_length, 4 };\n        size_t         mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) };\n        acl_tensor_ptr mrope_position =\n            ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),\n                                    mrope_position_ne, mrope_position_nb, 2);\n\n        // selected position tensor's shape is a transpose of cache tensor.\n        int64_t selected_position_ne[] = { position_length, theta_scale_length };\n        size_t  selected_position_nb[] = { sizeof(float), position_length * sizeof(float) };\n        mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float));\n        void * mrope_position_buffer = mrope_position_acllocator.get();\n        acl_position_tensor =\n            ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),\n                                    ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2);\n        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(),\n                                acl_position_tensor.get());\n\n        // transpose\n        int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 };\n        size_t  transposed_nb[GGML_MAX_DIMS];\n        transposed_nb[0] = sizeof(float);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1];\n        }\n\n        std::swap(transposed_ne[0], transposed_ne[2]);\n        std::swap(transposed_nb[0], transposed_nb[2]);\n\n        acl_position_tensor =\n            ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),\n                                    ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS);\n\n    } else {\n        // auto bcast.\n        acl_position_tensor =\n            ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),\n                                    position_ne, position_nb, GGML_MAX_DIMS);\n    }\n\n    // Step4: multiply by the position\n    int64_t              theta_length = theta_scale_length * position_length;\n    ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));\n    void *               theta_buffer = theta_allocator.get();\n\n    acl_tensor_ptr acl_theta_tensor =\n        ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS);\n    aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());\n\n    // Step5: calculate sin cos.\n    // init sin_repeat && cos_repeat, only to accelerate first layer on each device.\n    // The tables hold theta_scale_length * position_length * 2 floats, so they are reallocated whenever that\n    // product grows; checking position_length alone would let a call with a larger rope_dims but no more positions\n    // write past the end of the previous allocation. The rope_cache length fields still describe the previous call\n    // at this point, and that call's tables fit into the current allocation.\n    // NOTE(review): relies on set() below storing theta_scale_length and position_length, as the Step1.x\n    // comparisons already do.\n    if (ctx.rope_cache.sin_cache == nullptr || ctx.rope_cache.cos_cache == nullptr ||\n        theta_scale_length * position_length > ctx.rope_cache.theta_scale_length * ctx.rope_cache.position_length) {\n        ctx.rope_cache.position_length = position_length;\n        if (ctx.rope_cache.sin_cache != nullptr) {\n            ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));\n        }\n        if (ctx.rope_cache.cos_cache != nullptr) {\n            ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));\n        }\n        int64_t repeat_theta_length = theta_scale_length * position_length * 2;\n        ACL_CHECK(\n            aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));\n        ACL_CHECK(\n            aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));\n    }\n\n    // sin/cos\n    ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));\n    void *               sin_buffer = sin_allocator.get();\n    acl_tensor_ptr       acl_sin_tensor =\n        ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);\n    aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());\n\n    ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));\n    void *               cos_buffer = cos_allocator.get();\n    acl_tensor_ptr       acl_cos_tensor =\n        ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);\n    aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());\n\n    // NOTE(review): set() below records this adjusted attn_factor, while equal() at the top of the function is\n    // given the unadjusted value; with ext_factor != 0 and freq_scale != 1 the early-return cache presumably\n    // never hits. Confirm against rope_cache.equal().\n    if (ext_factor != 0) {\n        attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);\n    }\n\n    // Step 5: multiply by attn_factor\n    if (attn_factor != 1) {\n        aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);\n        aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);\n    }\n\n    int64_t sin_reshape_ne[4] = { rope_dims, 1, dst->ne[2], 1 };\n    size_t  sin_reshape_nb[GGML_MAX_DIMS];\n    sin_reshape_nb[0] = sizeof(float);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];\n    }\n    acl_tensor_ptr acl_sin_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),\n                                                                   sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);\n    acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),\n                                                                   sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);\n\n    // Step 6: repeat\n    if (is_neox) {\n        // tile the half-length tables: [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn]\n        int64_t repeatsArray[] = { 1, 1, 1, 2 };\n        aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);\n        aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);\n    } else {\n        int64_t num_repeats = 2;\n        int64_t dim         = 3;\n        int64_t output_size = theta_scale_length * num_repeats;\n        // duplicate every value in place: [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn]\n        aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);\n        aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);\n    }\n\n    // Update cached value.\n    ctx.rope_cache.cached = true;\n    ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox,\n                       indep_sects, mrope_used, is_imrope, sections);\n}\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\naclnnStatus aclnnRotaryPositionEmbeddingGetWorkspaceSize(const aclTensor * x,\n                                                         const aclTensor * cos,\n                                                         const aclTensor * sin,\n                                                         int64_t           mode,\n                                                         const aclTensor * yOut,\n                                                         uint64_t *        workspaceSize,\n                                                         aclOpExecutor **  executor);\naclnnStatus aclnnRotaryPositionEmbedding(void *          workspace,\n                                         uint64_t        workspaceSize,\n                                         aclOpExecutor * executor,\n                                         aclrtStream     stream);\n#ifdef __cplusplus\n}\n#endif\n\nvoid ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];  // input\n\n    // param\n    float     
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;\n    int       sections[4];\n    // const int n_past     = ((int32_t *) dst->op_params)[0];\n    const int n_dims     = ((int32_t *) dst->op_params)[1];\n    const int mode       = ((int32_t *) dst->op_params)[2];\n    // const int n_ctx      = ((int32_t *) dst->op_params)[3];\n    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));\n    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));\n    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));\n    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));\n    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));\n    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));\n    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);\n\n    GGML_ASSERT(n_dims % 2 == 0);\n    GGML_ASSERT(n_dims <= ne00);\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    float corr_dims[2];\n    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);\n\n    bool       is_neox    = mode & GGML_ROPE_TYPE_NEOX;\n    const bool is_imrope  = mode == GGML_ROPE_TYPE_IMROPE;  // qwen3vl apply interleaved mrope\n    // mrope_used means the GGML_ROPE_TYPE_MROPE bit is set.\n    // Note: this bit is also set for imrope and some vision modes,\n    // so mrope_used does NOT exclusively indicate pure mrope.\n    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;\n    const bool is_vision  = mode == GGML_ROPE_TYPE_VISION;\n\n    if (mrope_used) {\n        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);\n    }\n\n    if (is_vision) {\n        GGML_ASSERT(n_dims == ne0 / 2);\n    }\n\n    if (is_imrope || mrope_used) {\n        is_neox = true;\n    }\n\n    int64_t rope_dims = n_dims;\n\n    //Our current 
RotaryPositionEmbedding does not support the VISION mode,\n    //but essentially it only modifies theta_base in mrope,\n    //then repeats it at the end in the same way as is_neox.\n    //In fact, RoPE is still applied across all dimensions.\n    if (is_vision) {\n        rope_dims = src0->ne[0];\n    }\n    int64_t tail_dims = ne00 - rope_dims;\n    bool    has_tail  = tail_dims > 0;\n\n    // init ctx.rope_cos/rope_sin cache\n    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,\n                          mrope_used, is_imrope, is_vision, rope_dims);\n\n    // Cache is generated with ne00 dimensions, so we use ne00 for reshape\n    int64_t sin_reshape_ne[4] = { rope_dims, 1, ne02, 1 };\n    size_t  sin_reshape_nb[GGML_MAX_DIMS];\n    sin_reshape_nb[0] = sizeof(float);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];\n    }\n    acl_tensor_ptr acl_sin_reshape_tensor = ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),\n                                                                    sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);\n    acl_tensor_ptr acl_cos_reshape_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),\n                                                                    sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n#ifdef ASCEND_310P\n    // Special ROPE operation for 310P\n\n    // roll input\n    void *               input_roll_buffer;\n    acl_tensor_ptr       acl_minus_one_tensor;\n    void *               minus_one_scale_buffer = nullptr;\n    ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));\n    ggml_cann_pool_alloc minus_one_scale_allocator(ctx.pool(), sizeof(float) * src0->ne[0]);\n    if (!is_neox) {\n     
   // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]\n        input_roll_buffer        = roll_allocator.get();\n        int64_t input_roll_ne[4] = { 2, src0->ne[1] * (src0->ne[0] / 2), src0->ne[2], src0->ne[3] };\n        size_t  input_roll_nb[GGML_MAX_DIMS];\n        input_roll_nb[0] = ggml_type_size(src0->type);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1];\n        }\n        acl_tensor_ptr acl_input_roll_tensor =\n            ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),\n                                    input_roll_ne, input_roll_nb, GGML_MAX_DIMS);\n        acl_tensor_ptr acl_input_tensor =\n            ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),\n                                    input_roll_ne, input_roll_nb, GGML_MAX_DIMS);\n\n        int64_t shifts[] = { 1 };\n        int64_t dims[]   = { 3 };\n        aclnn_roll(ctx, acl_input_tensor.get(), acl_input_roll_tensor.get(), shifts, dims);\n\n        // init [-1, 1, -1, 1, ...]\n        minus_one_scale_buffer = minus_one_scale_allocator.get();\n\n        int64_t minus_one_ne[4] = { src0->ne[0], 1, 1, 1 };\n        size_t  minus_one_nb[GGML_MAX_DIMS];\n        minus_one_nb[0] = sizeof(float);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];\n        }\n        acl_minus_one_tensor = aclnn_values(ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0], minus_one_ne,\n                                            GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);\n        int64_t   dim        = 3;\n        int64_t * index      = new int64_t[src0->ne[0]];\n        for (int i = 0; i < src0->ne[0]; i++) {\n            index[i] = i / 2 * 2;\n        }\n        int64_t index_num = src0->ne[0];\n        float   value     = -1;\n        
aclnn_index_fill_tensor(ctx, acl_minus_one_tensor.get(), dim, index, index_num, value);\n    } else {\n        // roll input: [q0,q1,q2,...] ->\n        // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]\n        input_roll_buffer = roll_allocator.get();\n        acl_tensor_ptr acl_input_roll_tensor =\n            ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),\n                                    src0->ne, src0->nb, GGML_MAX_DIMS);\n        acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(src0);\n\n        int64_t shifts[] = { src0->ne[0] / 2 };\n        int64_t dims[]   = { 3 };\n        aclnn_roll(ctx, acl_input_tensor.get(), acl_input_roll_tensor.get(), shifts, dims);\n\n        // init [-1, -1, -1, 1, 1，1，...]\n        minus_one_scale_buffer  = minus_one_scale_allocator.get();\n        int64_t minus_one_ne[4] = { src0->ne[0], 1, 1, 1 };\n        size_t  minus_one_nb[GGML_MAX_DIMS];\n        minus_one_nb[0] = sizeof(float);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];\n        }\n        acl_minus_one_tensor     = aclnn_values(ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0], minus_one_ne,\n                                                GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);\n        // -1 * first half\n        int64_t first_half_ne[4] = { src0->ne[0] / 2, 1, 1, 1 };\n        size_t  first_half_nb[GGML_MAX_DIMS];\n        first_half_nb[0] = sizeof(float);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];\n        }\n        acl_tensor_ptr acl_first_half_tensor = ggml_cann_create_tensor(minus_one_scale_buffer, ACL_FLOAT, sizeof(float),\n                                                                       first_half_ne, first_half_nb, GGML_MAX_DIMS);\n        bool           inplace               = true;\n        float          
scale                 = -1;\n        aclnn_muls(ctx, acl_first_half_tensor.get(), scale, nullptr, inplace);\n    }\n\n    // TODO: n_dims < ne0\n    GGML_ASSERT(n_dims == src0->ne[0]);\n\n    // input * scale\n    ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(), ggml_nbytes(src0));\n    void *               input_roll_mul_scale_buffer = roll_mul_scale_allocator.get();\n    size_t               input_nb[GGML_MAX_DIMS];\n    input_nb[0] = ggml_type_size(src0->type);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        input_nb[i] = input_nb[i - 1] * src0->ne[i - 1];\n    }\n    acl_tensor_ptr acl_input_roll_mul_scale_tensor =\n        ggml_cann_create_tensor(input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type),\n                                ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);\n    acl_tensor_ptr acl_input_roll_reshape_tensor =\n        ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),\n                                src0->ne, input_nb, GGML_MAX_DIMS);\n\n    aclnn_mul(ctx, acl_input_roll_reshape_tensor.get(), acl_minus_one_tensor.get(),\n              acl_input_roll_mul_scale_tensor.get());\n\n    // output\n    void * output_fp32_buffer;\n    if (src0->type == GGML_TYPE_F32) {\n        aclnn_mul(ctx, acl_src.get(), acl_cos_reshape_tensor.get());\n        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor.get(), acl_sin_reshape_tensor.get());\n        aclnn_add(ctx, acl_src.get(), acl_input_roll_mul_scale_tensor.get(), acl_dst.get());\n        // TODO: ne0 != n_dims in mode2\n    } else if (src0->type == GGML_TYPE_F16) {\n        size_t input_fp32_nb[GGML_MAX_DIMS];\n        input_fp32_nb[0] = sizeof(float);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];\n        }\n        ggml_cann_pool_alloc fp32_allocator1(ctx.pool(), ggml_nelements(dst) * sizeof(float));\n        void *              
 input_fp32_buffer1 = fp32_allocator1.get();\n        acl_tensor_ptr       input_fp32_tensor1 = ggml_cann_create_tensor(input_fp32_buffer1, ACL_FLOAT, sizeof(float),\n                                                                          dst->ne, input_fp32_nb, GGML_MAX_DIMS);\n        ggml_cann_pool_alloc fp32_allocator2(ctx.pool(), ggml_nelements(dst) * sizeof(float));\n        void *               input_fp32_buffer2 = fp32_allocator2.get();\n        acl_tensor_ptr       input_fp32_tensor2 = ggml_cann_create_tensor(input_fp32_buffer2, ACL_FLOAT, sizeof(float),\n                                                                          dst->ne, input_fp32_nb, GGML_MAX_DIMS);\n\n        ggml_cann_pool_alloc fp32_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(float));\n        output_fp32_buffer                = fp32_allocator.get();\n        acl_tensor_ptr output_fp32_tensor = ggml_cann_create_tensor(output_fp32_buffer, ACL_FLOAT, sizeof(float),\n                                                                    dst->ne, input_fp32_nb, GGML_MAX_DIMS);\n        aclnn_mul(ctx, acl_src.get(), acl_cos_reshape_tensor.get(), input_fp32_tensor1.get());\n        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor.get(), acl_sin_reshape_tensor.get(), input_fp32_tensor2.get());\n        aclnn_add(ctx, input_fp32_tensor1.get(), input_fp32_tensor2.get(), output_fp32_tensor.get());\n        aclnn_cast(ctx, output_fp32_tensor.get(), acl_dst.get(), ACL_FLOAT16);\n    }\n    return;\n#endif\n    int64_t acl_mode = is_neox ? 
0 : 1;\n\n    // Pre-define head and tail dimensions for reuse\n    int64_t head_ne[GGML_MAX_DIMS] = { rope_dims, ne01, ne02, ne03 };\n    int64_t tail_ne[GGML_MAX_DIMS] = { tail_dims, ne01, ne02, ne03 };\n\n    // Step 1: Prepare trans tensors for F16 type conversion to F32 if needed\n    bool                 src_dst_need_trans = false;\n    ggml_cann_pool_alloc src_trans_allocator(ctx.pool());\n    ggml_cann_pool_alloc dst_trans_allocator(ctx.pool());\n    acl_tensor_ptr       acl_src_trans_tensor;\n    acl_tensor_ptr       acl_dst_trans_tensor;\n    void *               src_trans_buffer = nullptr;\n    void *               dst_trans_buffer = nullptr;\n    size_t               src_dst_trans_nb[GGML_MAX_DIMS];\n    if (src0->type == GGML_TYPE_F16) {\n        src_dst_need_trans = true;\n        src_trans_buffer   = src_trans_allocator.alloc(ggml_nelements(src0) * sizeof(float));\n        dst_trans_buffer   = dst_trans_allocator.alloc(ggml_nelements(dst) * sizeof(float));\n\n        src_dst_trans_nb[0] = sizeof(float);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            src_dst_trans_nb[i] = src_dst_trans_nb[i - 1] * src0->ne[i - 1];\n        }\n        acl_src_trans_tensor = ggml_cann_create_tensor(src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne,\n                                                       src_dst_trans_nb, GGML_MAX_DIMS);\n        acl_dst_trans_tensor = ggml_cann_create_tensor(dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne,\n                                                       src_dst_trans_nb, GGML_MAX_DIMS);\n        aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);\n    }\n\n    // Step 2: Prepare head tensors for tail splitting if needed\n    acl_tensor_ptr acl_src_head;\n    acl_tensor_ptr acl_dst_head;\n    if (has_tail) {\n        // Create head views for RotaryPositionEmbedding (only first rope_dims dimensions)\n        // RotaryPositionEmbedding requires contiguous dst tensor, so we use a temporary 
buffer\n        if (src_dst_need_trans) {\n            // Use F32 trans tensor strides\n            acl_src_head = ggml_cann_create_tensor((char *) src_trans_buffer, ACL_FLOAT, sizeof(float), head_ne,\n                                                   src_dst_trans_nb, GGML_MAX_DIMS);\n        } else {\n            // Use original F32 tensor strides\n            acl_src_head = ggml_cann_create_tensor((char *) src0->data, ACL_FLOAT, sizeof(float), head_ne, src0->nb,\n                                                   GGML_MAX_DIMS);\n        }\n\n        int64_t              head_elements = rope_dims * ne01 * ne02 * ne03;\n        ggml_cann_pool_alloc dst_head_contiguous_allocator(ctx.pool(), head_elements * sizeof(float));\n        void *               dst_head_contiguous_buffer = dst_head_contiguous_allocator.get();\n\n        size_t head_contiguous_nb[GGML_MAX_DIMS];\n        head_contiguous_nb[0] = sizeof(float);\n        for (int i = 1; i < GGML_MAX_DIMS; i++) {\n            head_contiguous_nb[i] = head_contiguous_nb[i - 1] * head_ne[i - 1];\n        }\n        acl_dst_head = ggml_cann_create_tensor(dst_head_contiguous_buffer, ACL_FLOAT, sizeof(float), head_ne,\n                                               head_contiguous_nb, GGML_MAX_DIMS);\n    }\n\n    // Step 3: Execute RotaryPositionEmbedding\n    if (has_tail) {\n        // Rotate only the head portion (first rope_dims dimensions)\n        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_head.get(), acl_cos_reshape_tensor.get(),\n                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst_head.get());\n\n        // Copy head result from contiguous buffer back to destination tensor\n        if (src_dst_need_trans) {\n            acl_tensor_ptr acl_dst_head_target = ggml_cann_create_tensor(\n                (char *) dst_trans_buffer, ACL_FLOAT, sizeof(float), head_ne, src_dst_trans_nb, GGML_MAX_DIMS);\n            cann_copy(ctx, acl_dst_head.get(), 
acl_dst_head_target.get());\n        } else {\n            acl_tensor_ptr acl_dst_head_target =\n                ggml_cann_create_tensor((char *) dst->data, ACL_FLOAT, sizeof(float), head_ne, dst->nb, GGML_MAX_DIMS);\n            cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());\n        }\n    } else if (src_dst_need_trans) {\n        // Rotate full tensor (no tail), using trans tensors\n        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),\n                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());\n    } else {\n        // Rotate full tensor (no tail), using original tensors\n        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),\n                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());\n    }\n\n    // Step 4: Copy unrotated tail portion from source to destination\n    if (has_tail) {\n        size_t src_tail_offset;\n        size_t dst_tail_offset;\n\n        auto copy_tail_device = [&](void * src_ptr, void * dst_ptr, aclDataType dtype, size_t elem_size,\n                                    size_t * nb_src_arr, size_t * nb_dst_arr) {\n            acl_tensor_ptr acl_src_tail =\n                ggml_cann_create_tensor(src_ptr, dtype, elem_size, tail_ne, nb_src_arr, GGML_MAX_DIMS);\n            acl_tensor_ptr acl_dst_tail =\n                ggml_cann_create_tensor(dst_ptr, dtype, elem_size, tail_ne, nb_dst_arr, GGML_MAX_DIMS);\n            cann_copy(ctx, acl_src_tail.get(), acl_dst_tail.get());\n        };\n\n        if (src_dst_need_trans) {\n            // Use F32 trans tensor strides and offsets\n            src_tail_offset = rope_dims * src_dst_trans_nb[0];\n            dst_tail_offset = rope_dims * src_dst_trans_nb[0];\n            copy_tail_device((char *) src_trans_buffer + src_tail_offset, (char *) dst_trans_buffer + dst_tail_offset,\n                     
        ACL_FLOAT, sizeof(float), src_dst_trans_nb, src_dst_trans_nb);\n        } else {\n            // Use original tensor strides and offsets\n            src_tail_offset = rope_dims * nb00;\n            dst_tail_offset = rope_dims * nb0;\n            copy_tail_device((char *) src0->data + src_tail_offset, (char *) dst->data + dst_tail_offset,\n                             ggml_cann_type_mapping(dst->type), ggml_element_size(dst), src0->nb, dst->nb);\n        }\n    }\n\n    // Step 5: Cast back to F16 if needed\n    if (src_dst_need_trans) {\n        aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);\n    }\n}\n\nvoid ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());\n}\n\nvoid ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n\n    // stride\n    int64_t s0 = ((const int32_t *) (dst->op_params))[0];\n\n    acl_tensor_ptr acl_input  = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);\n    acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);\n    acl_tensor_ptr acl_dst    = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);\n\n    // get base information of input and kernel\n    int64_t input_len   = *(src1->ne);\n    int64_t dst_len     = *(dst->ne);\n    int64_t kernel_size = *(src0->ne);\n\n    // set the max kernel size for each conv\n    int64_t max_kernel_size = 255;\n\n    // compute the partition of kernel\n    int64_t part_num = 1;\n    part_num         = (kernel_size + max_kernel_size - 1) / max_kernel_size;\n\n    int64_t strideVal[1];\n    strideVal[0]      
              = s0;\n    acl_int_array_ptr stride        = ggml_cann_create_int_array(strideVal, 1);\n    int64_t           paddingVal[]  = { 0 };\n    acl_int_array_ptr padding       = ggml_cann_create_int_array(paddingVal, 1);\n    int64_t           dilationVal[] = { 1 };\n    acl_int_array_ptr dilation      = ggml_cann_create_int_array(dilationVal, 1);\n    bool              transposed    = true;\n    int64_t           groups        = 1;\n    int8_t            cubeMathType  = 0;\n\n#ifdef ASCEND_310P\n    cubeMathType = 1;\n#endif\n\n    auto weight_type = ggml_cann_type_mapping(src0->type);\n    auto dst_type    = ggml_cann_type_mapping(dst->type);\n\n    // slice the kernel to make each conv available\n    int64_t slice_dim   = -1;\n    int64_t slice_start = 0;\n    int64_t slice_end   = max_kernel_size;\n    int64_t slice_step  = 1;\n    int64_t interval    = max_kernel_size;\n\n    int64_t left_pad_len  = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];\n    int64_t right_pad_len = 0;\n\n    acl_scalar_ptr alpha      = nullptr;\n    float          alphaValue = 1.0;\n    alpha                     = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);\n\n    // set zero to destination\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());\n\n    for (int k = 0; k < part_num; k++) {\n        // create part kernel tensor and slice from big kernel\n        slice_start = max_kernel_size * k;\n        if (k == part_num - 1) {\n            slice_end = kernel_size;\n            interval  = kernel_size - max_kernel_size * k;\n        } else {\n            slice_end = max_kernel_size * (k + 1);\n        }\n\n        int64_t part_ne[4];\n        for (int i = 0; i < 4; i++) {\n            part_ne[i] = *(src0->ne + i);\n        }\n        part_ne[0] = interval;\n\n        size_t part_nb[4];\n        part_nb[0] = sizeof(weight_type);\n        for (int i = 1; i < 4; i++) {\n            part_nb[i] = part_nb[i - 1] * part_ne[i - 1];\n        
}\n\n        ggml_cann_pool_alloc part_kernel_allocator;\n        part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);\n        void * part_kernel_buf = part_kernel_allocator.get();\n\n        acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0),\n                                                             part_ne, part_nb, 3, ACL_FORMAT_NCL);\n\n        GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step,\n                                part_kernel.get());\n\n        // create the part conv result tensor\n        int64_t part_dst_ne[4];\n        for (int i = 0; i < 4; i++) {\n            part_dst_ne[i] = *(dst->ne + i);\n        }\n        part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;\n\n        size_t part_dst_nb[4];\n        part_dst_nb[0] = sizeof(weight_type);\n        for (int i = 1; i < 4; i++) {\n            part_dst_nb[i] = part_dst_nb[i - 1] * part_dst_ne[i - 1];\n        }\n        ggml_cann_pool_alloc part_dst_allocator;\n        part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);\n        void * part_dst_buf = part_dst_allocator.get();\n\n        acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),\n                                                              part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());\n\n        // compute part conv transpose 1d\n        GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),\n                                padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(),\n                                cubeMathType);\n\n        // compute the position of part result in final result\n        int64_t global_start = slice_start;\n        int64_t global_end   = 
std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);\n\n        left_pad_len  = global_start;\n        right_pad_len = dst_len - global_end;\n\n        std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len };\n        acl_int_array_ptr    padData    = ggml_cann_create_int_array(padDataVal.data(), 2);\n\n        acl_scalar_ptr pad_value    = nullptr;\n        float          pad_valueVal = 0.0;\n        pad_value                   = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);\n\n        int64_t conv_result_ne[4];\n        for (int i = 0; i < 4; i++) {\n            conv_result_ne[i] = *(dst->ne + i);\n        }\n\n        size_t conv_result_nb[4];\n        conv_result_nb[0] = sizeof(weight_type);\n        for (int i = 1; i < 4; i++) {\n            conv_result_nb[i] = conv_result_nb[i - 1] * conv_result_ne[i - 1];\n        }\n\n        ggml_cann_pool_alloc conv_result_allocator;\n        conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);\n        void * conv_result_buf = conv_result_allocator.get();\n\n        acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),\n                                                             conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);\n\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());\n        GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(),\n                                conv_result.get());\n        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());\n    }\n}\n\nvoid ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n\n    acl_tensor_ptr acl_input = ggml_cann_create_tensor(src0);\n    acl_tensor_ptr acl_dst   = ggml_cann_create_tensor(dst);\n\n    float          alphaValue = 1.0f;\n    acl_scalar_ptr alpha      = nullptr;\n    alpha                     = 
ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Elu, acl_input.get(), alpha.get(), alpha.get(), alpha.get(), acl_dst.get());\n}\n\nvoid ggml_cann_mean(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    int64_t           reduceDimValue[] = { 3 };\n    acl_int_array_ptr reduceDim        = ggml_cann_create_int_array(reduceDimValue, 1);\n    bool              keepDim          = true;\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Mean, acl_src.get(), reduceDim.get(), keepDim, ACL_FLOAT, acl_dst.get());\n}\n\nvoid ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor *     src0             = dst->src[0];\n    int32_t *         opts             = (int32_t *) dst->op_params;\n    int64_t           paddingsArray[2] = { opts[0], opts[1] };\n    acl_int_array_ptr paddings         = ggml_cann_create_int_array(paddingsArray, 2);\n\n    for (int64_t i = 0; i < src0->ne[3]; i++) {\n        acl_tensor_ptr acl_src =\n            ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type),\n                                    ggml_element_size(src0), src0->ne, src0->nb, 3);\n\n        acl_tensor_ptr acl_dst =\n            ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type),\n                                    ggml_element_size(dst), dst->ne, dst->nb, 3);\n\n        GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());\n    }\n}\n\nvoid ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n\n    acl_tensor_ptr acl_self  = ggml_cann_create_tensor(src0);\n    acl_tensor_ptr acl_other = 
ggml_cann_create_tensor(src1);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get());\n\n    ggml_cann_sum(ctx, dst);\n}\n\nvoid ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    float          alphaValue = 0.0f;\n    acl_scalar_ptr alpha      = nullptr;\n    alpha                     = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get());\n}\n\n/**\n * @brief Performs expert-specific matrix multiplication (MoE) with\n * floating-point precision using the CANN backend.\n *\n * This function executes a matrix multiplication operation tailored for\n * Mixture of Experts (MoE) models, where the input tensor is multiplied\n * with expert-specific weight matrices. It uses the CANN backend for\n * efficient computation and stores the result in the destination tensor `dst`.\n * The operation may leverage identity-based optimizations or routing masks\n * as part of sparse expert selection.\n *\n * @param ctx The context for executing CANN backend operations.\n * @param dst The destination tensor where the MoE multiplication result\n * will be stored.\n *\n * @note This function assumes floating-point data types and is designed for\n * MoE architectures, possibly involving sparse expert routing.\n */\nstatic void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    //dst   [M, K, N, 1]\n    ggml_tensor * src0 = dst->src[0];  //src0\t[D, M, A, 1]  -> [D, M, K, 1]\n    ggml_tensor * src1 = dst->src[1];  //src1\t[D, B, N, 1], B = K or B = 1 -> [D, 1, K, 1]\n    ggml_tensor * ids  = dst->src[2];  //ids\t[K, N]\n\n    GGML_ASSERT(src0->ne[3] == 1);\n    GGML_ASSERT(src1->ne[3] == 1);\n    GGML_ASSERT(dst->ne[3] == 1);\n\n    int64_t 
batch = src1->ne[2];\n    GGML_ASSERT(batch == ids->ne[1]);\n\n    ggml_cann_pool_alloc export_allocator(ctx.pool(), src0->ne[0] * src0->ne[1] * ids->ne[0] * ggml_element_size(src0));\n    void *               export_ptr = export_allocator.get();\n    for (int64_t i = 0; i < batch; i++) {\n        acl_tensor_ptr select_index  = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);\n        acl_tensor_ptr export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3);\n\n        int64_t select_export_ne[] = { src0->ne[0], src0->ne[1], ids->ne[0] };\n        size_t  select_export_nb[3];\n        select_export_nb[0] = src0->nb[0];\n        for (int k = 1; k < 3; k++) {\n            select_export_nb[k] = select_export_nb[k - 1] * select_export_ne[k - 1];\n        }\n\n        acl_tensor_ptr select_export =\n            ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0),\n                                    select_export_ne, select_export_nb, 3);\n        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, export_weight.get(), 0, select_index.get(), select_export.get());\n\n        int64_t        select_transpose_ne[] = { select_export_ne[1], select_export_ne[0], select_export_ne[2] };\n        size_t         select_transpose_nb[] = { select_export_nb[1], select_export_nb[0], select_export_nb[2] };\n        acl_tensor_ptr select_export_transpose =\n            ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0),\n                                    select_transpose_ne, select_transpose_nb, 3);\n\n        int64_t        active_tensor_ne[] = { src1->ne[0], 1, src1->ne[1] };\n        size_t         active_tensor_nb[] = { src1->nb[0], src1->nb[1], src1->nb[1] };\n        acl_tensor_ptr active_tensor =\n            ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]);\n\n        int64_t        dst_ne[] = { 
dst->ne[0], 1, dst->ne[1] };\n        size_t         dst_nb[] = { dst->nb[0], dst->nb[1], dst->nb[1] };\n        acl_tensor_ptr acl_dst  = ggml_cann_create_tensor(dst, dst_ne, dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]);\n\n        GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, active_tensor.get(), select_export_transpose.get(), acl_dst.get(), 2);\n    }\n}\n\n/**\n * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)\n * models using the CANN backend.\n *\n * This function implements MUL_MAT_ID operation for quantized weight matrices\n * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on\n * the provided expert indices, and computes matrix multiplication using CANN's\n * WeightQuantBatchMatmulV2 operator.\n *\n * The function performs the following steps:\n * 1. Converts input/output tensors to F16 format if necessary\n * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices\n * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2\n * 4. 
Converts output back to the target type if needed\n *\n * Tensor shapes:\n * - dst:  [M, K, N, 1] - output tensor\n * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)\n * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)\n * - ids:  [K, N] - expert indices for routing\n *\n * @param ctx The CANN backend context for operation execution.\n * @param dst The destination tensor where the multiplication result will be stored.\n *\n * @note Only Q4_0 and Q8_0 quantization formats are supported.\n * @note The function handles automatic type conversion to/from F16 as needed by the hardware.\n */\nstatic void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    // dst:  [M, K, N, 1]\n    // src0: [D, M, A, 1] - quantized weights\n    // src1: [D, B, N, 1] - input activations, B = K or B = 1\n    // ids:  [K, N] - expert indices\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n    ggml_tensor * ids  = dst->src[2];\n\n    GGML_ASSERT(src0->ne[3] == 1);\n    GGML_ASSERT(src1->ne[3] == 1);\n    GGML_ASSERT(dst->ne[3] == 1);\n    GGML_ASSERT(src1->ne[2] == ids->ne[1]);\n\n    const int64_t        n_batches        = ids->ne[1];\n    const int64_t        n_select_experts = ids->ne[0];\n    const enum ggml_type type             = src0->type;\n\n    const int32_t group_size = QK8_0;  // Both Q4_0 and Q8_0 use group size of 32\n    GGML_ASSERT(group_size == QK4_0);\n\n    // Calculate element size for quantized weights\n    const float weight_elem_size =\n        (type == GGML_TYPE_Q4_0) ? 0.5f :\n        (type == GGML_TYPE_Q8_0) ? 
1.0f :\n                                   (GGML_ABORT(\"MUL_MAT_ID only supports Q4_0 and Q8_0\"), 0.0f);\n\n    // Calculate scale offset in memory\n    const size_t weight_size     = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size;\n    const size_t scale_elem_size = sizeof(uint16_t);\n    char *       scale_data      = (char *) src0->data + weight_size;\n\n    // Allocate buffers for selected expert weights and scales\n    const size_t         selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size;\n    ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);\n    void *               selected_weight_buffer = selected_weight_alloc.get();\n\n    const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;\n    ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);\n    void *               selected_scale_buffer = selected_scale_alloc.get();\n\n    // Helper lambda to allocate and cast tensor to F16 if needed\n    constexpr size_t f16_elem_size      = sizeof(uint16_t);\n    auto             prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,\n                                  bool need_cast = false) -> void * {\n        if (tensor->type == GGML_TYPE_F16) {\n            return tensor->data;\n        }\n\n        size_t total_size = f16_elem_size;\n        for (int i = 0; i < GGML_MAX_DIMS; i++) {\n            total_size *= tensor->ne[i];\n        }\n        void * buffer = allocator.alloc(total_size);\n\n        if (need_cast == false) {\n            return buffer;\n        }\n\n        int64_t ne[GGML_MAX_DIMS];\n        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };\n        for (int i = 0; i < GGML_MAX_DIMS; i++) {\n            ne[i] = tensor->ne[i];\n            if (i > 0) {\n                nb[i] = nb[i - 1] * ne[i - 1];\n            }\n        }\n\n        acl_tensor_ptr src_tensor = 
ggml_cann_create_tensor(tensor);\n        acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);\n        aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);\n\n        return buffer;\n    };\n\n    // Prepare input and output buffers\n    ggml_cann_pool_alloc input_alloc(ctx.pool());\n    void *               input_buffer = prepare_f16_buffer(src1, input_alloc, true);\n\n    ggml_cann_pool_alloc output_alloc(ctx.pool());\n    void *               output_buffer = prepare_f16_buffer(dst, output_alloc, false);\n\n    // Process each batch\n    for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {\n        // Create index tensor for current batch\n        const size_t   index_offset  = batch_idx * ids->nb[1];\n        acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);\n\n        // Select quantized weights using expert indices\n        // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte\n        const int64_t weight_d         = (type == GGML_TYPE_Q4_0) ? 
src0->ne[0] / 2 : src0->ne[0];\n        const int64_t weight_m         = src0->ne[1];\n        const int64_t weight_n_experts = src0->ne[2];\n\n        int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };\n        size_t  weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };\n\n        acl_tensor_ptr all_weights =\n            ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);\n\n        int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };\n        size_t  selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),\n                                          weight_d * weight_m * sizeof(int8_t) };\n\n        acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),\n                                                                  selected_weight_ne, selected_weight_nb, 3);\n\n        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());\n\n        // Select scales using the same expert indices\n        const int64_t scale_d     = src0->ne[0] / group_size;\n        int64_t       scale_ne[3] = { scale_d, weight_m, weight_n_experts };\n        size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };\n\n        acl_tensor_ptr all_scales =\n            ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);\n\n        int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };\n        size_t  selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,\n                                         scale_d * weight_m * scale_elem_size };\n\n        acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,\n                                                                 selected_scale_ne, 
selected_scale_nb, 3);\n\n        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());\n\n        // Process each expert for current batch\n        // IndexSelect output layout: [D, M, K] in contiguous format\n        // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride\n        for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {\n            // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input\n            const size_t input_offset =\n                (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;\n            const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;\n\n            // Create weight view for current expert: [D, M, K] -> [M, D]\n            int64_t      weight_view_ne[2]  = { weight_m, src0->ne[0] };\n            float        weight_view_nb[2]  = { src0->ne[0] * weight_elem_size, weight_elem_size };\n            const size_t weight_view_offset = expert_idx * selected_weight_nb[2];\n\n            acl_tensor_ptr weight_view =\n                ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,\n                                        weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);\n\n            // Create scale view for current expert: [D, M, K] -> [M, D]\n            int64_t      scale_view_ne[2]  = { weight_m, scale_d };\n            size_t       scale_view_nb[2]  = { selected_scale_nb[1], selected_scale_nb[0] };\n            const size_t scale_view_offset = expert_idx * selected_scale_nb[2];\n\n            acl_tensor_ptr scale_view =\n                ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,\n                                        scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);\n\n            // Create input activation 
tensor [D, 1]\n            int64_t input_ne[2] = { src1->ne[0], 1 };\n            size_t  input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };\n\n            acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,\n                                                                  input_nb, 2, ACL_FORMAT_ND, input_offset);\n\n            // Create output tensor [M, 1]\n            int64_t output_ne[2] = { dst->ne[0], 1 };\n            size_t  output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };\n\n            acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,\n                                                                   output_nb, 2, ACL_FORMAT_ND, output_offset);\n\n            // Perform quantized matrix multiplication\n            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),\n                                    scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,\n                                    output_tensor.get());\n        }\n    }\n\n    // Cast output back to original type if we used a temporary F16 buffer\n    if (dst->type != GGML_TYPE_F16) {\n        int64_t ne[GGML_MAX_DIMS];\n        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };\n        for (int i = 0; i < GGML_MAX_DIMS; i++) {\n            ne[i] = dst->ne[i];\n            if (i > 0) {\n                nb[i] = nb[i - 1] * ne[i - 1];\n            }\n        }\n\n        acl_tensor_ptr f16_output =\n            ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);\n        acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);\n\n        aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));\n    }\n}\n\nvoid ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    const enum ggml_type type = dst->src[0]->type;\n   
 switch (type) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n            ggml_cann_mul_mat_id_fp(ctx, dst);\n            break;\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q8_0:\n            ggml_cann_mul_mat_id_quant(ctx, dst);\n            break;\n        default:\n            GGML_ABORT(\"Unsupported type for mul_mat_id\");\n            break;\n    }\n}\n\nvoid ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];  // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)\n    ggml_tensor * src1 = dst->src[1];  // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)\n    ggml_tensor * src2 = dst->src[2];  // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)\n    ggml_tensor * src3 = dst->src[3];  // mask, fp16\n\n    // B, N, S, D (uncont) -> B, S, N, D (cont)\n    int64_t src0_bsnd_ne[GGML_MAX_DIMS];\n    memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));\n    size_t src0_bsnd_nb[GGML_MAX_DIMS];\n    memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));\n    int64_t src1_bsnd_ne[GGML_MAX_DIMS];\n    memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));\n    size_t src1_bsnd_nb[GGML_MAX_DIMS];\n    memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));\n    int64_t src2_bsnd_ne[GGML_MAX_DIMS];\n    memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));\n    size_t src2_bsnd_nb[GGML_MAX_DIMS];\n    memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));\n\n    auto transpose12 = [](int64_t * ne, size_t * nb) {\n        int64_t ne_tmp = ne[1];\n        size_t  nb_tmp = nb[1];\n        ne[1]          = ne[2];\n        nb[1]          = nb[2];\n        ne[2]          = ne_tmp;\n        nb[2]          = nb_tmp;\n    };\n\n    transpose12(src0_bsnd_ne, src0_bsnd_nb);\n    transpose12(src1_bsnd_ne, src1_bsnd_nb);\n    transpose12(src2_bsnd_ne, src2_bsnd_nb);\n\n    float maxBias      = 0.0f;\n    float scaleValue   = 1.0f;\n   
 float logitSoftcap = 0.0f;\n    memcpy(&scaleValue, (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&maxBias, (float *) dst->op_params + 1, sizeof(float));\n    memcpy(&logitSoftcap, (float *) dst->op_params + 2, sizeof(float));\n\n    if (logitSoftcap == 0.0f) {\n        size_t faElemSize = sizeof(uint16_t);\n        auto   faDataType = ACL_FLOAT16;  //ACL_BF16;\n\n        acl_tensor_ptr acl_q_tensor = nullptr;\n        acl_tensor_ptr acl_k_tensor = nullptr;\n        acl_tensor_ptr acl_v_tensor = nullptr;\n\n        // Step 1: cast the src0 (Query) to fp16 if needed\n        ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());\n        void *               src0_f16_buffer = nullptr;\n\n        if (ggml_cann_type_mapping(src0->type) != faDataType) {\n            acl_tensor_ptr acl_src0_f32_tensor =\n                ggml_cann_create_tensor(src0, src0_bsnd_ne, src0_bsnd_nb, GGML_MAX_DIMS);\n            src0_f16_buffer = src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize);\n\n            int64_t * src0_f16_ne = src0_bsnd_ne;\n            size_t    src0_f16_nb[GGML_MAX_DIMS];\n            src0_f16_nb[0] = sizeof(uint16_t);\n            for (int i = 1; i < GGML_MAX_DIMS; ++i) {\n                src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];\n            }\n\n            acl_q_tensor = ggml_cann_create_tensor(src0_f16_buffer, faDataType, faElemSize, src0_f16_ne, src0_f16_nb,\n                                                   GGML_MAX_DIMS);\n            aclnn_cast(ctx, acl_src0_f32_tensor.get(), acl_q_tensor.get(), faDataType);\n        } else {\n            acl_q_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, src0_bsnd_nb, GGML_MAX_DIMS);\n        }\n\n        // Step 2: create the acl tensors for src1 (Key), src2 (Value),\n        //         and the direct output from FusedInferAttention\n\n        acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS);\n        acl_v_tensor = 
ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS);\n\n        // Step 3: create the PSEShift tensor if needed\n        //         this tensor is considered as mask (f16) in the llama.cpp\n        acl_tensor_ptr       bcast_pse_tensor;\n        ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());\n        if (src3 != nullptr) {\n            // Construct the truncated pse tensor (common for prefill/decode)\n            int64_t trunc_pse_ne[GGML_MAX_DIMS] = {\n                src3->ne[0],  // D\n                src0->ne[1],  // S (number of Q tokens)\n                src3->ne[2],  // mask N\n                src3->ne[3]   // B\n            };\n            size_t * trunc_pse_nb = src3->nb;\n\n            acl_tensor_ptr acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(\n                src3->data, ACL_FLOAT16, sizeof(uint16_t), trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);\n\n            int64_t bcast_pse_ne[GGML_MAX_DIMS];\n            size_t  bcast_pse_nb[GGML_MAX_DIMS];\n            bcast_pse_ne[0] = src3->ne[0];  // D\n            bcast_pse_ne[1] = src0->ne[1];  // S\n            bcast_pse_ne[2] = src0->ne[2];  // N (num_heads)\n            bcast_pse_ne[3] = src3->ne[3];  // B\n            if (maxBias == 0.0f) {\n                // When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2)\n                // Construct the bcast tensor (simulate repeat on the head dimension using stride=0)\n                bcast_pse_nb[0] = sizeof(uint16_t);\n                bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];\n                bcast_pse_nb[2] = 0;  // <---- the head dimension shares the same data\n                bcast_pse_nb[3] = src3->nb[3];\n\n                bcast_pse_tensor = ggml_cann_create_tensor(src3->data, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne,\n                                                           bcast_pse_nb, GGML_MAX_DIMS);\n\n            } else {\n                bcast_pse_nb[0] = sizeof(uint16_t);\n                
for (int i = 1; i < GGML_MAX_DIMS; i++) {\n                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];\n                }\n\n                void * bcast_pse_buffer =\n                    bcast_pse_allocator.alloc(ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));\n\n                bcast_pse_tensor = ggml_cann_create_tensor(bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),\n                                                           bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);\n\n                int64_t repeats[] = { 1, src0->ne[2], 1, 1 };\n                aclnn_repeat(ctx, acl_mask_f16_trunc_tensor.get(), bcast_pse_tensor.get(), repeats);\n\n                // alibi\n                // Compute the slope if needed. Derived from ggml_cann_softmax().\n                const int64_t        n_heads = src0->ne[2];\n                ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));\n                void *               slope_buffer = slope_allocator.get();\n                aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias, GGML_TYPE_F16);\n\n                int64_t slope_ne[] = { 1, 1, n_heads, 1 };\n                size_t  slope_nb[GGML_MAX_DIMS];\n                slope_nb[0] = sizeof(uint16_t);\n                for (int i = 1; i < GGML_MAX_DIMS; i++) {\n                    slope_nb[i] = slope_nb[i - 1] * slope_ne[0];\n                }\n\n                acl_tensor_ptr slope_tensor = ggml_cann_create_tensor(slope_buffer, ACL_FLOAT16, sizeof(uint16_t),\n                                                                      slope_ne, slope_nb, GGML_MAX_DIMS);\n                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor.get(), slope_tensor.get());\n            }\n        }\n\n        // Step 4: set the inputs for FusedInferAttention.\n        acl_tensor_list_ptr acl_k_tensor_list = ggml_cann_create_tensor_list(acl_k_tensor);\n        acl_tensor_list_ptr acl_v_tensor_list = 
ggml_cann_create_tensor_list(acl_v_tensor);\n\n        int64_t numHeads           = src0->ne[2];  // N\n        int64_t numKeyValueHeads   = src1->ne[2];\n        // double  scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)\n        int64_t preTokens          = 65535;\n        int64_t nextTokens         = 65535;\n        char    layout[5]          = { 'B', 'S', 'N', 'D', 0 };\n        int64_t sparseMode         = 0;\n        int64_t innerPrecise       = (src0->ne[1] == 1) ? 0 : 2;\n        int64_t blockSize          = 0;\n        int64_t antiquantMode      = 0;\n        bool    softmaxLseFlag     = false;\n        int64_t keyAntiquantMode   = 0;\n        int64_t valueAntiquantMode = 0;\n\n        GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);\n        acl_tensor_ptr       fa_dst_tensor;\n        acl_tensor_ptr       acl_dst_tensor;\n        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());\n        if (dst->type == GGML_TYPE_F32) {\n            void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);\n\n            int64_t * out_f16_ne = src0_bsnd_ne;\n            size_t    out_f16_nb[GGML_MAX_DIMS];\n            out_f16_nb[0] = faElemSize;\n            for (int i = 1; i < GGML_MAX_DIMS; ++i) {\n                out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];\n            }\n\n            fa_dst_tensor =\n                ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS);\n        } else {\n            fa_dst_tensor = ggml_cann_create_tensor(dst);\n        }\n\n        GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, acl_q_tensor.get(), acl_k_tensor_list.get(),\n                                acl_v_tensor_list.get(),               // q, k, v\n                                bcast_pse_tensor.get(), nullptr,       // pse, mask\n                                nullptr, nullptr,                      // actSeqLen, actSeqLenkv\n                            
    nullptr, nullptr,                      // deqScale1, quantScale1\n                                nullptr, nullptr, nullptr,             // deqScale2, quantScale2, quantOffset2\n                                nullptr, nullptr,                      // antiquantScale, antiquantOffset\n                                nullptr,                               // blockTable\n                                nullptr, nullptr,                      // qPadSize, kvPadSize\n                                nullptr, nullptr,                      // kAntiquantScale, kAntiQuantOffset\n                                nullptr, nullptr,                      // vAntiquantScale, vAntiQuantOffset\n                                nullptr, nullptr, nullptr,             // kSharedPrefix, vSharedPrefix, actSharedLen\n                                numHeads, scaleValue,                  // heads, scaleValue\n                                preTokens, nextTokens,                 // preTokens, nextTokens\n                                layout,                                // inputLayout\n                                numKeyValueHeads,                      // numKVHeads\n                                sparseMode, innerPrecise,              // sparseMode, innerPrecise\n                                blockSize, antiquantMode,              // blockSize, antiquantMode\n                                softmaxLseFlag,                        // softmaxLseFlag\n                                keyAntiquantMode, valueAntiquantMode,  // keyAntiqMode, valueAntiqMode\n                                fa_dst_tensor.get(),                   // attentionOut\n                                nullptr                                // softmaxLse\n        );\n\n        if (dst->type == GGML_TYPE_F32) {\n            // Step 6: post-processing, permute and cast to f32\n            acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);\n            aclnn_cast(ctx, fa_dst_tensor.get(), 
acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));\n        }\n    } else {\n        GGML_ABORT(\"Function is not implemented.\");\n    }\n}\n\nstatic void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];  // weight\n    ggml_tensor * src1 = dst->src[1];  // input\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());\n\n    const int64_t dps2 = ne2 / ne02;\n    const int64_t dps3 = ne3 / ne03;\n    for (int64_t i3 = 0; i3 < ne3; i3++) {\n        for (int64_t i2 = 0; i2 < ne2; i2++) {\n            const int64_t i02 = i2 / dps2;\n            const int64_t i03 = i3 / dps3;\n\n            const int64_t  i12 = i2;\n            const int64_t  i13 = i3;\n            acl_tensor_ptr accumulator =\n                ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),\n                                        ggml_type_size(dst->type), dst->ne, dst->nb, 2);\n\n            // The outer product needs to be accumulated in this dimension.\n            for (int64_t i1 = 0; i1 < ne11; i1++) {\n                acl_tensor_ptr acl_input = ggml_cann_create_tensor(\n                    (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),\n                    ggml_type_size(src0->type), src1->ne, src1->nb, 1);\n\n                acl_tensor_ptr acl_weight = ggml_cann_create_tensor(\n                    (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),\n                    ggml_type_size(src0->type), src0->ne, src0->nb, 1);\n\n                ggml_cann_pool_alloc output_allocator(ctx.pool());\n                void *               output_buffer = output_allocator.alloc(ggml_nbytes(dst));\n                acl_tensor_ptr       acl_out = ggml_cann_create_tensor(output_buffer, 
ggml_cann_type_mapping(dst->type),\n                                                                       ggml_type_size(dst->type), dst->ne, dst->nb, 2);\n\n                GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());\n                float       alpha_value = 1.0f;\n                aclScalar * alpha       = aclCreateScalar(&alpha_value, ACL_FLOAT);\n                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);\n            }\n        }\n    }\n}\n\nvoid ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n\n    const enum ggml_type type = src0->type;\n\n    switch (type) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n            ggml_cann_out_prod_fp(ctx, dst);\n            break;\n        default:\n            GGML_ABORT(\"Unsupport type for GGML_OP_OUT_PROD\");\n            break;\n    }\n}\n\nvoid ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];  // conv_x\n    ggml_tensor * src1 = dst->src[1];  // conv1d.weight\n\n    // This op is currently defined only for F32 in ggml_cpu\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    // Shapes follow ggml_compute_forward_ssm_conv_f32\n    const int64_t nc  = src1->ne[0];   // d_conv\n    const int64_t ncs = src0->ne[0];   // d_conv - 1 + n_t\n    const int64_t nr  = src0->ne[1];   // d_inner\n    const int64_t n_s = src0->ne[2];   // n_seqs\n\n    const int64_t n_t = dst->ne[1];    // tokens per sequence\n\n    GGML_ASSERT(dst->ne[0] == nr);     // dst: {d_inner, n_t, n_s}\n    GGML_ASSERT(src1->ne[1] == nr);    // weight: {d_conv, d_inner}\n    GGML_ASSERT(ncs == nc - 1 + n_t);  // conv_x: {d_conv - 1 + n_t, d_inner, n_s}\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(src1->nb[0] == 
sizeof(float));\n\n    // --- Build CANN tensors ---\n\n    // 1) Input: conv_x as NCL\n    //\n    // src0->ne = { ncs, nr, n_s, 1 }  // {L_in, C, N}\n    // Passing ACL_FORMAT_NCL here means:\n    //   reversed dims -> [N, C, L_in] = [n_s, nr, ncs]\n    acl_tensor_ptr acl_x = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);\n\n    // 2) Weights: depthwise conv kernel, view src1 as {K, 1, C}\n    //\n    // src1 original:   ne = { nc, nr, 1, 1 }  // [K, C, 1, 1]\n    // we want a view:  ne_w = { nc, 1, nr }   // [K, 1, C]\n    // so that reversed dims -> [C, 1, K] which matches\n    //   [out_channels, in_channels/groups, kernel_size]\n    int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 };  // [K, 1 input ch. per group, C groups]\n    // Layout: src1 data is [K, C] with\n    //   offset(k, c) = k*nb0 + c*nb1\n    // We want offset_w(k, 0, c) = k*nb0 + c*nb1,\n    // so we can reuse nb0 and nb1, and set nb2 = nb1.\n    size_t  w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] };  // same as src1\n\n    acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type),\n                                                   ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);\n\n    // 3) Output: dst is { d_inner, n_t, n_s } (CLN)\n    //\n    // We need an NCL view of the same buffer:\n    //   desired NCL logical shape: { L_out = n_t, C = nr, N = n_s }\n    //\n    // Original CLN layout:\n    //   dst->ne = { nr, n_t, n_s }\n    //   dst->nb[0] = sizeof(float)\n    //   dst->nb[1] = nr * sizeof(float)\n    //   dst->nb[2] = nr * n_t * sizeof(float)\n    //\n    // We want offset_new(L, C, N) = offset_orig(C, L, N).\n    // Choose:\n    //   nb_y[0] = nr * sizeof(float);           // step in L\n    //   nb_y[1] = sizeof(float);                // step in C\n    //   nb_y[2] = nr * n_t * sizeof(float);     // step in N\n    int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 };  // [L_out, C, N]\n 
   size_t  y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float),\n                                    dst->nb[3] };       // [nr, 1, nr * n_t]\n\n    acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),\n                                                   ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);\n\n    // --- Conv1d parameters: depthwise, stride 1, no padding (\"valid\") ---\n    int64_t strideVal[1]   = { 1 };\n    int64_t paddingVal[1]  = { 0 };\n    int64_t dilationVal[1] = { 1 };\n\n    acl_int_array_ptr stride   = ggml_cann_create_int_array(strideVal, 1);\n    acl_int_array_ptr padding  = ggml_cann_create_int_array(paddingVal, 1);\n    acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);\n\n    const bool    transposed   = false;\n    const int64_t groups       = nr;  // depthwise: one group per inner dim\n    int8_t        cubeMathType = 0;\n\n#ifdef ASCEND_310P\n    cubeMathType = 1;\n#endif\n\n    GGML_CANN_CALL_ACLNN_OP(ctx, Convolution,\n                            acl_x.get(),    // input:  N, C, L_in = ncs\n                            acl_w.get(),    // weight: [C, 1, K] with groups=nr\n                            nullptr,        // bias\n                            stride.get(), padding.get(), dilation.get(), transposed,\n                            padding.get(),  // output padding (unused for non-transposed)\n                            groups, acl_y.get(), cubeMathType);\n}\n\nvoid ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,\n                                     ggml_tensor *               add_node,\n                                     ggml_tensor *               rms_norm_node) {\n    // Get the two input tensors for ADD operation\n    ggml_tensor * x1 = add_node->src[0];\n    ggml_tensor * x2 = add_node->src[1];\n\n    // Create ACL tensors for the two ADD inputs\n    acl_tensor_ptr acl_x1 = 
ggml_cann_create_tensor(x1);\n    acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2);\n\n    // Get epsilon parameter from rms_norm_tensor\n    float eps;\n    memcpy(&eps, rms_norm_node->op_params, sizeof(float));\n\n    // Build gamma tensor (RMS normalization scaling factor)\n    // Gamma should match the normalized dimensions (last dimension of x1)\n    size_t acl_gamma_nb[GGML_MAX_DIMS];\n    acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1];\n    }\n    acl_tensor_ptr acl_gamma =\n        get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne,\n                             acl_gamma_nb, rms_norm_node->type,\n                             1,    // dims - only the last dimension\n                             1.0f  // value\n        );\n\n    // Build rstdOut tensor (output for normalized standard deviation)\n    // Shape should be the dimensions that are NOT normalized\n    int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] };\n    size_t  acl_rstd_nb[GGML_MAX_DIMS - 1];\n    acl_rstd_nb[0] = sizeof(float);\n    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {\n        acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];\n    }\n    acl_tensor_ptr acl_rstd =\n        get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size,\n                             acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS,\n                             0.0f  // value\n        );\n\n    acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node);\n\n    // Create yOut tensor (final output after RMS normalization)\n    acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node);\n\n    // Call fused ADD + RMS_NORM operator\n    GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(),\n                            eps,  // 
double type\n                            acl_yout.get(), acl_rstd.get(), acl_xout.get());\n}\n\nvoid ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * k = dst->src[0];\n    ggml_tensor * v = dst->src[1];\n    ggml_tensor * q = dst->src[2];\n    ggml_tensor * g = dst->src[3];\n    ggml_tensor * s = dst->src[4];\n\n    int64_t B = dst->src[4]->ne[1];\n    int64_t T = dst->src[0]->ne[2];\n    int64_t H = dst->src[0]->ne[1];\n    int64_t C = dst->ne[0];\n    int64_t D = C / H;\n    int64_t L = T / B;\n\n    int64_t ne_qkg[2] = { 1, D };\n    int64_t ne_s[2]   = { D, D };\n    int64_t ne_st[2]  = { ne_s[1], ne_s[0] };\n    int64_t ne_vo[2]  = { D, 1 };\n    int64_t ne_q[1]   = { D };\n    size_t  nb_base   = ggml_type_size(k->type);\n    size_t  nb_qkg[2] = { nb_base, nb_base };\n    size_t  nb_s[2]   = { nb_base, D * nb_base };\n    size_t  nb_st[2]  = { nb_s[1], nb_s[0] };\n    size_t  nb_vo[2]  = { nb_base, D * nb_base };\n    size_t  nb_q[1]   = { nb_base };\n\n    const float scale = ggml_get_op_params_f32(dst, 0);\n\n    acl_tensor_ptr acl_s     = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND);\n    acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base);\n    cann_copy(ctx, acl_s.get(), new_state.get());\n\n    for (int64_t b = 0; b < B; b++) {\n        for (int64_t h = 0; h < H; h++) {\n            size_t         s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;\n            // D * D\n            acl_tensor_ptr acl_s_new =\n                ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);\n            acl_tensor_ptr acl_s_new_t =\n                ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);\n            for (int64_t l = 0; l < L; l++) {\n                size_t               qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * 
nb_base;\n                // D * 1\n                acl_tensor_ptr       acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);\n                acl_tensor_ptr       acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);\n                // D\n                acl_tensor_ptr       acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);\n                // 1 * D\n                acl_tensor_ptr       acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);\n                // D\n                acl_tensor_ptr       acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);\n                // k ⊗ v\n                size_t               buf_size = D * D * nb_base;\n                ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size);\n                acl_tensor_ptr       tmp_tensor = ggml_cann_create_tensor(\n                    buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2);\n                aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get());\n                //s_new = g ⊗ s_old + k ⊗ v\n                aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr);\n                aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr);\n                // compute output\n                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1);\n                aclnn_muls(ctx, acl_o.get(), scale, nullptr, true);\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-cann/aclnn_ops.h",
    "content": "/**\n * Copyright (c) 2023-2026 The ggml authors\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\n#ifndef CANN_ACLNN_OPS\n#define CANN_ACLNN_OPS\n\n#include \"acl_tensor.h\"\n#include \"common.h\"\n\n#include <aclnnop/aclnn_abs.h>\n#include <aclnnop/aclnn_arange.h>\n#include <aclnnop/aclnn_argsort.h>\n#include <aclnnop/aclnn_cat.h>\n#include <aclnnop/aclnn_clamp.h>\n#include <aclnnop/aclnn_cos.h>\n#include <aclnnop/aclnn_exp.h>\n#include <aclnnop/aclnn_gelu.h>\n#include <aclnnop/aclnn_gelu_v2.h>\n#include <aclnnop/aclnn_hardsigmoid.h>\n#include <aclnnop/aclnn_hardswish.h>\n#include <aclnnop/aclnn_leaky_relu.h>\n#include <aclnnop/aclnn_log.h>\n#include <aclnnop/aclnn_logsoftmax.h>\n#include <aclnnop/aclnn_neg.h>\n#include <aclnnop/aclnn_norm.h>\n#include <aclnnop/aclnn_relu.h>\n#include <aclnnop/aclnn_sigmoid.h>\n#include <aclnnop/aclnn_sign.h>\n#include <aclnnop/aclnn_silu.h>\n#include 
<aclnnop/aclnn_sin.h>\n#include <aclnnop/aclnn_slice.h>\n#include <aclnnop/aclnn_sqrt.h>\n#include <aclnnop/aclnn_tanh.h>\n\n#include <functional>\n#include <unordered_set>\n\n/**\n * @brief   Repeats a ggml tensor along each dimension to match the dimensions\n *          of another tensor.\n *\n * @details This function repeats the elements of a source ggml tensor along\n *          each dimension to create a destination tensor with the specified\n *          dimensions. The operation is performed using the ACL backend and\n *          executed asynchronously on the device.\n *\n * @param   ctx The CANN context used for operations.\n * @param   dst The ggml tensor representing the destination, which op is\n *              GGML_OP_REPEAT and specifies the desired dimensions.\n */\nvoid ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Applies the Leaky ReLU activation function to a tensor using the CANN\n *          backend.\n *\n * @details This function computes the Leaky ReLU activation for each element of\n *          the input tensor. The Leaky ReLU function allows a small gradient\n *          when the unit is not active (i.e., when the input is negative). 
The\n *          Leaky ReLU function is defined as:\n *          \\f[\n *              \\text{dst} = \\max(0, src) + \\text{negativeSlope} \\cdot \\min(0,\n *               src)\n *          \\f]\n *          `negativeSlope` is in dst->params.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the result of the Leaky ReLU\n *            activation is stored, whose op is `GGML_OP_LEAKY_RELU`.\n */\nvoid ggml_cann_leaky_relu(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief    Concatenates multiple tensors along a specified dimension using the\n *           CANN backend.\n *\n * @param ctx        The CANN context used for operations.\n * @param tensorList A pointer to the list of tensors to be concatenated.\n * @param dst        The destination tensor where the result of the\n *                   concatenation is stored. dst->op is `GGML_OP_CONCAT`.\n * @param concat_dim The dimension along which the tensors are concatenated.\n *\n * @attention tensorList length should be 2 and the dimension used for concat\n *            defaults to 1.\n */\nvoid ggml_cann_concat(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Generates a sequence of evenly spaced values within a specified\n *          interval for a ggml tensor using the CANN backend.\n *\n * @details This function creates a sequence of numbers over a specified\n *          interval, starting from `start`, ending before `stop`, and\n *          incrementing by `step`. 
The sequence is stored in the destination\n *          tensor `dst`.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the generated sequence will be stored.\n *            `start`, `stop` and `step` are in dst->op_params and dst->op is\n *            `GGML_OP_ARANGE`.\n */\nvoid ggml_cann_arange(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Applies a clamp operation to the elements of a ggml tensor using the\n *          CANN backend.\n *\n * @details This function clamps the elements of the input tensor `src` to a\n *          specified range defined by `min` and `max` values. The result is\n *          stored in the destination tensor `dst`. The operation is defined as:\n *          \\f[\n *              y = \\max(\\min(x, max\\_value), min\\_value)\n *          \\f]\n *          where `x` is an element of the input tensor, and `y` is the\n *          corresponding element in the output tensor.\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the clamped values will be stored.\n *            dst->op is `GGML_OP_CLAMP`; the `min` and `max` values are in\n *            dst->params.\n */\nvoid ggml_cann_clamp(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Scales the elements of a ggml tensor by a constant factor using the\n *          CANN backend.\n *\n * @details This function multiplies each element of the input tensor `src` by\n *          a scaling factor `scale`, storing the result in the destination\n *          tensor `dst`. 
The operation is defined as:\n *          \\f[\n *             dst = src \\times scale\n *          \\f]\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the scaled values will be stored.\n *            dst->op is `GGML_OP_SCALE` and the `scale` value is in dst->params.\n */\nvoid ggml_cann_scale(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Sorts the elements of a ggml tensor and returns the indices that\n *          would sort the tensor using the CANN backend.\n *\n * @details This function performs an argsort operation on the input tensor\n *          `src`. It sorts the elements of `src` in either ascending or\n *          descending order, depending on the `GGML_SORT_ORDER_DESC`,\n *          and returns the indices that would sort the original tensor.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the sorted indices will be stored.\n *            dst->op is `GGML_OP_ARGSORT`.\n */\nvoid ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the Layer Normalization for a ggml tensor using the CANN\n *          backend.\n *\n * @details This function applies the Layer Normalization operation on the\n *          input tensor `src` and stores the result in the destination tensor\n *          `dst`. Layer Normalization normalizes the features at each sample in\n *          a mini-batch independently. It is commonly used in neural networks\n *          to normalize the activations of a layer by adjusting and scaling\n *          the outputs.\n *          The operation is defined as:\n *          \\f[\n *              \\text { out }=\\frac{x-\\mathrm{E}[x]}{\\sqrt{\\text{Var}[x]+eps}}\n *          \\f]\n *          `Var` defaults to dst->ne[0]. 
`eps` is in dst->params.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the normalized values will be stored.\n * @attention `Var` defaults to dst->ne[0].\n */\nvoid ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the L2 Normalization for a ggml tensor using the CANN\n *          backend.\n *\n * @details This function applies the L2 Normalization operation on the\n *          input tensor `src` and stores the result in the destination tensor\n *          `dst`. L2 Normalization scales the input tensor such that the\n *          L2 norm along the specified dimension equals 1. This operation\n *          is commonly used in neural networks for feature normalization\n *          and vector scaling.\n *          The operation is defined as:\n *          \\f[\n *              \\text{out} = \\frac{x}{\\sqrt{\\sum{x^2}}}\n *          \\f]\n *          The normalization is performed along the last dimension by default.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the normalized values will be stored.\n * @attention The normalization is performed along the last dimension of the\n *            input tensor by default.\n */\nvoid ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the Cross Entropy Loss for a ggml tensor using the CANN\n *          backend.\n *\n * @details This function computes the cross entropy loss between the predicted\n *          logits and target probability distributions. The operation follows\n *          the same computation pattern as the CPU implementation:\n *          1. Applies log_softmax to the logits along the class dimension\n *          2. Element-wise multiplication with target distributions\n *          3. Summation along the class dimension to get per-sample losses\n *          4. 
Global summation and scaling by -1/nr to get final loss\n *\n *          The computation can be expressed as:\n *          \\f[\n *              \\text{loss} = -\\frac{1}{N} \\sum_{i=1}^{N} \\sum_{j=1}^{C} y_{ij} \\cdot \\log(\\text{softmax}(x_{ij}))\n *          \\f]\n *          where \\f$N\\f$ is the total number of samples, \\f$C\\f$ is the number\n *          of classes, \\f$x\\f$ are the logits, and \\f$y\\f$ are the target\n *          probability distributions.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the computed loss will be stored.\n *            This should be a scalar tensor containing the final loss value.\n *\n * @note This implementation computes cross entropy between probability\n *       distributions, not the typical classification cross entropy that\n *       expects class indices as targets. Both input tensors (src0 and src1)\n *       should have the same shape and represent probability distributions\n *       over the class dimension.\n * @note The function expects two source tensors:\n *       - dst->src[0]: Logits tensor (before softmax)\n *       - dst->src[1]: Target probability distributions tensor\n * @note The computation is performed using CANN backend operators including\n *       LogSoftmax, Mul, ReduceSum, and Muls for the final scaling.\n */\nvoid ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief  Computes the Group Normalization for a ggml tensor using the CANN\n *         backend.\n *\n * @details This function applies the Group Normalization operation on the input\n *         tensor `src` and stores the result in the destination tensor `dst`.\n *         Group Normalization divides the channels into groups and normalizes\n *         the features within each group across spatial locations.\n *         It is commonly used in convolutional neural networks to improve\n *         training stability and performance.\n *        
 The operation is defined as:\n *         \\f[\n *             \\text { out }=\\frac{x-\\mathrm{E}[x]}{\\sqrt{\\text{Var}[x]+eps}}\n *         \\f]\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the normalized values will be stored.\n *            `n_groups` is read from dst->op_params; the C channels are split\n *            into `n_groups` groups. dst->op is `GGML_OP_GROUP_NORM`.\n *\n * @attention eps defaults to 1e-6f.\n */\nvoid ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the accumulation of tensors using the CANN backend.\n *\n * @details This function performs an accumulation operation on two tensors.\n *          Depending on the `inplace` flag, it either updates the destination\n *          tensor `dst` in place by adding `alpha * src1` to it, or it creates\n *          a new tensor as the result of `src0 + alpha * src1` and stores it in\n *          `dst`.\n *          The operation is defined as:\n *          \\f[\n *               dst = src0 + alpha \\times src1\n *          \\f]\n *          If `inplace` is `true`, `src0` is the same tensor as `dst`.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the accumulated values will be stored.\n *            `inplace` is read from dst->op_params, and dst->op is `GGML_OP_ACC`.\n */\nvoid ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the sum of elements along the last dimension of a ggml tensor\n *          using the CANN backend.\n *\n * @details This function performs a reduction sum operation along the last\n *          dimension of the input tensor `src`. 
The result of the sum is stored\n *          in the destination tensor `dst`.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the reduced values will be stored.\n *            dst->op is `GGML_OP_SUM_ROWS`.\n *\n * @attention `reduce_dims` defaults to 3, which means the last dimension.\n */\nvoid ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the sum of elements in a ggml tensor.\n *\n * @details This function performs a reduction sum over all elements of the\n *          input tensor `src`. The result of the sum is stored in the\n *          destination tensor `dst`.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the reduced values will be stored.\n *\n */\n\nvoid ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Upsamples a ggml tensor using nearest neighbor interpolation with\n *          the CANN backend.\n *\n * @details This function performs upsampling of the input tensor `src` using\n *          nearest neighbor interpolation. The upsampling is applied to the\n *          height and width dimensions (last two dimensions) of the tensor. The\n *          result is stored in the destination tensor `dst`, which must have\n *          the appropriate dimensions for the upsampled output.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the upsampled values will be stored.\n *            dst->op is `GGML_OP_UPSCALE`.\n */\nvoid ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Pads a ggml tensor to match the dimensions of the destination tensor\n *          using the CANN backend.\n *\n * @details This function pads the input tensor `src` so that it matches the\n *          dimensions of the destination tensor `dst`. 
The amount of padding\n *          is calculated based on the difference in sizes between `src` and\n *          `dst` along each dimension. The padded tensor is stored in `dst`.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor, which specifies the target dimensions for\n *            padding. dst->op is `GGML_OP_PAD`.\n */\nvoid ggml_cann_pad(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Executes a 2D pooling operation on a ggml tensor using the CANN\n *          backend.\n *\n * @details This function dispatches the execution of a 2D pooling operation on\n *          the input tensor `dst`. The type of pooling (average or max) is\n *          determined by the `op` parameter, which is read from the operation\n *          parameters of `dst`. The function supports average pooling\n *          (`GGML_OP_POOL_AVG`) and max pooling (`GGML_OP_POOL_MAX`). If an\n *          invalid operation is encountered, the function asserts a failure.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor on which the pooling operation is to be\n *            performed. dst->op is `GGML_OP_POOL_2D`.\n */\nvoid ggml_cann_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Duplicates a ggml tensor using the CANN backend.\n *\n * @details This function duplicates the contents of the source tensor `src` to\n *          the destination tensor `dst`. The function supports various tensor\n *          types and configurations, including handling of extra data, type\n *          conversions, and special cases for contiguous and non-contiguous\n *          tensors.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the duplicated data will be stored.\n *            dst->op is `GGML_OP_DUP`\n *\n * @attention Only support Fp16/FP32. 
Not supported when src and dst have\n *            different shapes and dst is non-contiguous.\n * @note      This function needs to be simplified.\n */\nvoid ggml_cann_dup(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the Root Mean Square (RMS) normalization of a ggml tensor\n *          using the CANN backend.\n *\n * @details This function applies RMS normalization to the input tensor `src`\n *          and stores the result in the destination tensor `dst`. RMS\n *          normalization involves computing the root mean square of the input\n *          tensor along a specified dimension and then dividing each element of\n *          the tensor by this value, adjusted by a small epsilon value to\n *          prevent division by zero.\n *          The operation is defined as:\n *          \\f[\n *               \\text{RmsNorm}\\left(x_i\\right)=\\frac{x_i}{\\text{Rms}(\\mathbf{x})} g_i,\n *               \\quad \\text { where } \\text{Rms}(\\mathbf{x})=\\sqrt{\\frac{1}{n} \\sum_{i=1}^n x_i^2+\\text{eps}}\n *          \\f]\n *          `eps` is in dst->op_params.\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the normalized values will be stored.\n *            dst->op is `GGML_OP_RMS_NORM`.\n */\nvoid ggml_cann_rms_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Applies a diagonal mask to the tensor with a specified value.\n *\n * @details This function creates a mask tensor filled with ones, then applies\n *          an upper triangular and lower triangular operation to it based on\n *          the number of past elements specified. Afterward, it adds the masked\n *          tensor to the destination tensor in-place.\n *\n * @param ctx The backend CANN context used for operations.\n * @param dst The destination tensor where the result will be stored. 
dst->op is\n *            `GGML_OP_DIAG_MASK`\n * @param value The value to use for masking.\n */\nvoid ggml_cann_diag_mask(ggml_backend_cann_context & ctx, ggml_tensor * dst, float value);\n\n/**\n * @brief   Performs an image-to-column transformation on the input tensor.\n *\n * @details This function takes an input tensor and applies an image-to-column\n *          operation, converting spatial dimensions into column-like\n *          structures suitable for convolutional operations. It supports both\n *          half-precision (F16) and single-precision (F32) floating-point data\n *          types.\n *\n * @param ctx The backend CANN context for executing operations.\n * @param dst The destination tensor that stores the result of the operation.\n *            dst->op is `GGML_OP_IM2COL`.\n */\nvoid ggml_cann_im2col(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes time step embeddings using sine and cosine functions.\n *\n * @details This function calculates time step embeddings by applying sine and\n *          cosine transformations to a given input tensor, which is typically\n *          used in temporal models like diffusion models or transformers to\n *          encode time information effectively.\n *\n * @param ctx The backend CANN context for executing operations.\n * @param dst The destination tensor where the result of the embedding operation\n *            will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`.\n */\nvoid ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n// @see ggml_cann_dup.\nvoid ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the softmax activation with optional masking.\n *\n * @details This function computes the softmax activation over the input tensor,\n *          optionally applying a mask and scaling factor. 
It supports both FP16\n *          and FP32 data types and can handle masking by broadcasting the mask\n *          across rows if necessary.\n *          The function performs the following steps:\n *          1. Multiplies the input tensor by a scale factor.\n *          2. Optionally casts the mask tensor to FP32 if it is in FP16 format.\n *          3. Broadcasts the mask tensor if its dimensions do not match the\n *             input tensor's dimensions.\n *          4. Adds the mask to the scaled input tensor.\n *          5. Applies the softmax activation function along the specified\n *             dimension.\n *\n * @param ctx The backend CANN context for executing operations.\n * @param dst The destination tensor where the result will be stored. dst->op is\n *            `GGML_OP_SOFT_MAX`.\n */\nvoid ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Extracts specific rows from a tensor based on indices.\n *\n * @details This function retrieves rows from a source tensor src0 according to\n *          the indices provided in another tensor src1 and stores the result in\n *          a destination tensor (\\p dst).\n *\n * @param ctx The backend CANN context for executing operations.\n * @param dst The destination tensor where the extracted rows will be stored.\n */\nvoid ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Writes specific rows into a tensor at positions specified by indices.\n *\n * @details This function copies rows from a source tensor into a destination\n *          tensor (\\p dst) at the positions indicated by the indices in another\n *          tensor.\n *\n * @param ctx The backend CANN context for executing operations.\n * @param dst The destination tensor where the specified rows will be updated.\n */\nvoid ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Executes matrix multiplication for the given tensor.\n 
*\n * @details This function performs matrix multiplication on the source tensors\n *          associated with the destination tensor. It supports matrix\n *          multiplication F32, F16, and Q8_0.\n *\n * @param ctx The backend CANN context for executing operations.\n * @param dst The destination tensor for storing the result of the matrix\n *            multiplication. dst->op is `GGML_OP_MUL_MAT`.\n */\nvoid ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief Applies Rotary Positional Embedding (RoPE) to the input tensor.\n *\n * @details This function implements the RoPE mechanism, which is a method to\n *          encode positional information into sequence data, particularly\n *          useful in transformer models. It supports both F32 and F16 data\n *          types.\n *\n * @param ctx The backend CANN context for executing operations.\n * @param dst The destination tensor where the RoPE-transformed data will be\n *            stored. dst->op is `GGML_OP_ROPE`.\n *\n * @note The function currently does not support cases where the n_dims is less\n *       than the input tensor's first dimension.\n * @note The function currently does not support cases where the freq_factors is\n *       not NULL.\n * @note The function currently does not support cases where the ext_factor is\n *       not equal 0.\n * @note The function currently does not support cases where the freq_scale is\n *       not equal 1.\n */\nvoid ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the index of the maximum value along the specified dimension\n *          of a ggml tensor using the CANN backend.\n *\n * @details This function performs an argmax operation on the input tensor.\n *          It finds the index of the maximum value along the specified axis\n *          and stores these indices in the destination tensor `dst`. 
The\n *          operation is executed using the CANN backend for optimized performance.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the indices of the maximum values will\n *            be stored. dst->op is `GGML_OP_ARGMAX`.\n */\nvoid ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief Adds two tensors element-wise and stores the result in a destination\n * tensor.\n *\n * This function performs the operation:\n * \\f[\n *    dst = acl\\_src0 + alpha \\times acl\\_src1\n * \\f]\n * where alpha is a scalar fixed at 1.0f; it is not exposed as a parameter.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src0 The first source tensor.\n * @param acl_src1 The second source tensor.\n * @param acl_dst The destination tensor where the result will be stored.\n */\nvoid aclnn_add(ggml_backend_cann_context & ctx,\n               aclTensor *                 acl_src0,\n               aclTensor *                 acl_src1,\n               aclTensor *                 acl_dst = nullptr);\n\n/**\n * @brief Subtracts two tensors element-wise and stores the result in a\n * destination tensor.\n *\n * This function performs the operation:\n * \\f[\n *    dst = acl\\_src0 - alpha \\times acl\\_src1\n * \\f]\n * where alpha is a scalar fixed at 1.0f; it is not exposed as a parameter.\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src0 The first source tensor.\n * @param acl_src1 The second source tensor.\n * @param acl_dst The destination tensor where the result will be stored.\n */\nvoid aclnn_sub(ggml_backend_cann_context & ctx,\n               aclTensor *                 acl_src0,\n               aclTensor *                 acl_src1,\n               aclTensor *                 acl_dst = nullptr);\n\n/**\n * @brief Performs element-wise multiplication of two tensors and stores the\n * result in a destination tensor.\n *\n * This function performs element-wise multiplication of 
the tensors `acl_src`\n * and `acl_other` and stores the result in the destination tensor `acl_dst`.\n * The operation is defined as:\n * \\f[\n *     \\text {acl_dst }_i=\\text {acl_src }_i \\times \\text {acl_other }_i\n * \\f]\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The first tensor for element-wise multiplication.\n * @param acl_other The second tensor for element-wise multiplication.\n * @param acl_dst The destination tensor where the result will be stored.\n */\nvoid aclnn_mul(ggml_backend_cann_context & ctx,\n               aclTensor *                 acl_src,\n               aclTensor *                 acl_other,\n               aclTensor *                 acl_dst = nullptr);\n\n/**\n * @brief Element-wise division, optionally in-place.\n *\n * This function divides each element of the source tensor `acl_src` by the\n * corresponding element of `acl_other` and stores the result in the\n * destination tensor `acl_dst`. There is no separate `inplace` flag: when\n * `acl_dst` is nullptr (the default), the operation is performed in-place on\n * `acl_src`. The operation is defined as: \\f[\n *     \\text{dst}_i = \\frac{\\text{acl_src}_i}{\\text{acl_other}_i}\n * \\f]\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src Numerator tensor.\n * @param acl_other Denominator tensor.\n * @param acl_dst The destination tensor where the result will be stored, or\n * nullptr (the default) to write the result back into `acl_src`.\n */\nvoid aclnn_div(ggml_backend_cann_context & ctx,\n               aclTensor *                 acl_src,\n               aclTensor *                 acl_other,\n               aclTensor *                 acl_dst = nullptr);\n\n/**\n * @brief Applies element-wise cosine function to the elements of a tensor.\n *\n * This function computes the cosine of each element in the source tensor\n * `acl_src` and stores the result in the destination tensor `acl_dst`. 
The\n * operation is defined as:\n * \\f[\n *     \\text {acl_dst }_i=\\cos \\left(\\text {acl_src }_i\\right)\n * \\f]\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor on which the cosine function will be\n * applied.\n * @param acl_dst The destination tensor where the cosine results will be\n * stored.\n */\nvoid aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);\n\n/**\n * @brief Applies element-wise sine function to the elements of a tensor.\n *\n * This function computes the sine of each element in the source tensor\n * `acl_src` and stores the result in the destination tensor `acl_dst`.\n * The operation is defined as:\n * \\f[\n *     \\text {acl_dst }_i=\\sin \\left(\\text {acl_src }_i\\right)\n * \\f]\n *\n * @param ctx The context for the CANN backend operations.\n * @param acl_src The source tensor on which the sine function will be applied.\n * @param acl_dst The destination tensor where the sine results will be stored.\n */\nvoid aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);\n\n/**\n * @brief Prepares broadcast-compatible ACL tensors for two input tensors and one\n * output tensor.\n *\n * This function checks whether broadcasting is needed between `src0` and `src1`.\n * If broadcasting is required, it calculates the proper shapes and creates\n * ACL tensors with broadcast parameters. 
Otherwise, it directly creates ACL tensors\n * based on the original tensor shapes.\n *\n * @param src0     The first input tensor (reference shape).\n * @param src1     The second input tensor (possibly broadcasted).\n * @param dst      The destination/output tensor.\n * @param acl_src0 Output pointer to the created ACL tensor corresponding to src0.\n * @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.\n * @param acl_dst  Output pointer to the created ACL tensor corresponding to dst.\n */\nvoid bcast_shape(ggml_tensor *    src0,\n                 ggml_tensor *    src1,\n                 ggml_tensor *    dst,\n                 acl_tensor_ptr & acl_src0,\n                 acl_tensor_ptr & acl_src1,\n                 acl_tensor_ptr & acl_dst);\n\n/**\n * @brief   Computes the 1D transposed convolution (deconvolution) of a ggml\n * tensor using the CANN backend.\n *\n * @details This function performs a 1D transposed convolution (also known as\n * deconvolution) operation on the input tensor. The computed result is stored\n * in the destination tensor `dst`. The operation is optimized using the CANN\n * backend for improved performance.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the transposed convolution result\n * will be stored. 
dst->op is `GGML_OP_CONV_TRANSPOSE_1D`.\n */\nvoid ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Applies the ELU (Exponential Linear Unit) activation to a ggml tensor\n * using the CANN backend.\n *\n * @details This function performs an element-wise ELU activation on the input\n *          tensor.\n *          The result is written to the destination tensor `dst` in-place.\n *          The ELU function is defined as:\n *          \\f[\n *          \\text{ELU}(x) =\n *          \\begin{cases}\n *          x, & \\text{if } x > 0 \\\\\n *          \\alpha \\left( \\exp(x) - 1 \\right), & \\text{if } x \\leq 0\n *          \\end{cases}\n *          \\f]\n *          where α (alpha) is a hyperparameter, typically set to 1.0.\n *          This operation is optimized using the CANN backend for high-performance\n *          inference or training.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the ELU-activated result will be stored.\n *            dst->op is expected to be `GGML_OP_UNARY` with unary op\n *            `GGML_UNARY_OP_ELU`.\n */\nvoid ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Computes the mean of a ggml tensor element-wise using the CANN backend.\n *\n * @details This function calculates the element-wise mean of the input tensor.\n *          The result is written to the destination tensor `dst`.\n *          The mean is computed by averaging the values across the entire tensor.\n *\n *          This operation is optimized using the CANN backend for high-performance inference or training.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the mean result will be stored.\n *            dst->op is expected to be `GGML_OP_MEAN`.\n */\nvoid ggml_cann_mean(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Applies 1D reflect padding to a ggml tensor using the CANN backend.\n *\n * @details This function performs 1D 
reflect padding on the input tensor.\n *          The amount of padding on each side is specified by parameters stored in `dst->op_params`.\n *          The operation reflects the values at the borders of the tensor to generate the padded output.\n *\n *          This operation is optimized using the CANN backend for high-performance inference or training.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the padded result will be stored.\n *            dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`.\n */\nvoid ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Counts the number of equal elements in two ggml tensors using the CANN backend.\n *\n * @details This function performs an element-wise comparison between two input tensors,\n *          and counts the number of positions where the elements are equal. The result is\n *          stored in the destination tensor `dst` as a scalar.\n *\n *          The operation is optimized using the CANN backend, making it suitable for\n *          high-performance inference or training scenarios.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the result will be stored.\n *            dst->op is expected to be `GGML_OP_COUNT_EQUAL`.\n */\nvoid ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Applies the Step activation function to a ggml tensor using the CANN backend.\n *\n * @details This function applies a step function element-wise to the input tensor, where\n *          each element is transformed to 1.0 if it is greater than 0, and 0.0 otherwise.\n *          The result is stored in the destination tensor `dst`.\n *\n *          This operation is accelerated using the CANN backend to improve runtime performance.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the result 
will be stored.\n *            dst->op is expected to be `GGML_OP_UNARY` with unary op\n *            `GGML_UNARY_OP_STEP`.\n */\nvoid ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief   Performs the Flash Attention extended operator using the CANN backend.\n *\n * @details This function implements the memory-efficient Flash Attention algorithm\n *          for computing scaled dot-product attention with hardware acceleration.\n *          The result is stored in the destination tensor `dst`.\n *\n *          This operation is accelerated using the CANN backend to improve runtime performance.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the result will be stored.\n *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.\n */\nvoid ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief Computes the forward pass of Gated Linear Attention on the CANN backend.\n *\n * Expects dst->src[0..4] = {k, v, q, g, s} with shape conventions:\n *   k, v, q, g: [D] with outer dims T x H batched as ne[2]=T, ne[1]=H\n *   s: initial state [B, H, D, D], where B is batch and D=C/H\n * dst holds both outputs (o) and updated state; a scale factor is read from op params.\n *\n * The kernel updates per time step l: S_new = g ⊗ S_old + k ⊗ v, then computes o = (S_new^T q) * scale.\n *\n * @param ctx Backend context providing stream/allocator utilities.\n * @param dst Output tensor; src deps are k, v, q, g, s as above.\n */\nvoid ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief Launches an asynchronous task using the memory allocator.\n *\n * This macro submits an asynchronous task on the specified stream.\n * The task uses memory allocated by the allocator. 
It is guaranteed\n * that the memory will not be accessed by other tasks until this task\n * completes, due to the sequential execution order within the same stream.\n *\n * @param OP_NAME aclnn operator name.\n * @param args Additional arguments required by the task.\n *\n * @note\n * Memory from the allocator will be \"freed\" immediately and can be\n * reallocated to other pointers. However, it won't be accessed by any\n * other task before this asynchronous task ends, because all tasks in the\n * same stream are executed in queue order.\n */\n\n#    define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                           \\\n        do {                                                                                     \\\n            uint64_t        workspaceSize = 0;                                                   \\\n            aclOpExecutor * executor;                                                            \\\n            void *          workspaceAddr = nullptr;                                             \\\n            ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \\\n            /* workspace should alloced in main thread to keep malloc order when using vmm. 
*/   \\\n            if (workspaceSize > 0) {                                                             \\\n                ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);             \\\n                workspaceAddr = workspace_allocator.get();                                       \\\n            }                                                                                    \\\n            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));     \\\n        } while (0)\n\n/**\n * @brief   Performs sparse expert-based matrix multiplication using the CANN backend.\n *\n * @details This function implements a MoE-style batched matrix multiplication, where each input token\n *          is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix\n *          in the source tensor `src0`. The routing indices are provided via the `ids` tensor.\n *\n *          For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`,\n *          performs the matrix multiplication with the selected expert's weight submatrix (from `src0`),\n *          and stores the results in `dst`. 
This operation is optimized and executed on the CANN backend.\n *\n *          Dimensions:\n *              - src0: [D, M, A, 1], where A is the number of experts\n *              - src1: [D, B, N, 1], where N is batch size and B is the slot count per sample\n *              - ids : [K, N],       where K is the number of experts each token is routed to\n *              - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication\n *\n *          The function handles two main modes:\n *              - If `ne12 == 1`, a simpler per-token loop is used.\n *              - TODO: If `ne12 > 1`, grouped multiplication and memory copying is used for efficiency.\n *\n * @param ctx The CANN context used for operations.\n * @param dst The destination tensor where the expert-weighted token outputs are stored.\n *            Expected to be of shape [M, K, N, 1].\n */\nvoid ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief Performs fused ADD + RMS_NORM operation using the CANN backend.\n *\n * This function fuses the ADD and RMS_NORM operations into a single kernel call\n * for better performance. 
It first adds two input tensors (x1 + x2), then applies\n * RMS normalization to the result.\n *\n * @param ctx The context for the CANN backend operations.\n * @param add_node The ADD operation node; contains the two input tensors to be added.\n * @param rms_norm_node The RMS_NORM operation node; contains the gamma weights\n *                      and epsilon parameter.\n */\nvoid ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,\n                                     ggml_tensor *               add_node,\n                                     ggml_tensor *               rms_norm_node);\n\n/**\n * @brief   Check whether a tensor is a weight tensor for matrix multiplication.\n *\n * @details Checks whether the given tensor serves as weight parameters in matrix multiplication operations,\n *          typically within neural network layers. The function maintains a static set of canonical weight\n *          naming suffixes from Transformer-based architectures. Uses substring matching to identify weight\n *          tensors even with hierarchical naming patterns.\n *\n * @param tensor Pointer to the target ggml_tensor object (const-qualified).\n * @return  true if the tensor's name contains any of the listed weight suffixes, false otherwise.\n */\nstatic bool is_matmul_weight(const ggml_tensor * tensor) {\n    std::string                                  name = ggml_get_name(tensor);\n    static const std::unordered_set<std::string> weight_suffixes{ \"output.weight\",      \"attn_q.weight\",\n                                                                  \"attn_k.weight\",      \"attn_v.weight\",\n                                                                  \"attn_output.weight\", \"ffn_gate.weight\",\n                                                                  \"ffn_up.weight\",      \"ffn_down.weight\" };\n\n    for (const auto & suffix : weight_suffixes) {\n        if (name.find(suffix) != std::string::npos) {\n            return true;\n        }\n    }\n    return false;\n}\n\n/**\n * @brief Applies an element-wise operation to two 
input tensors using the CANN\n * backend.\n *\n * This templated function takes a binary operator and applies it to two source\n * tensors\n * associated with the destination tensor. The function handles broadcasting as\n * needed.\n *\n * @tparam binary_op A callable object (e.g., lambda or function pointer) representing\n *         the binary operation to be performed. It must take three arguments:\n *         (ggml_backend_cann_context&, aclTensor*, aclTensor*, aclTensor*).\n *\n * @param ctx The CANN backend context used to manage execution and resources.\n * @param dst The destination tensor.\n */\ntemplate <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n\n    acl_tensor_ptr acl_src0, acl_src1, acl_dst;\n\n    // Need bcast\n    bcast_shape(src0, src1, dst, acl_src0, acl_src1, acl_dst);\n    binary_op(ctx, acl_src0.get(), acl_src1.get(), acl_dst.get());\n}\n\n/**\n * @brief Applies a unary operation to an input tensor using the CANN backend.\n *\n * This templated function applies a unary operator to the source tensor of `dst`\n * and stores the result in the destination tensor.\n *\n * @tparam unary_op A callable with the signature:\n *         void(ggml_backend_cann_context&, aclTensor *, aclTensor *)\n *         where the first aclTensor is the source and the second is the destination.\n * @param ctx The CANN backend context for managing resources and execution.\n * @param dst The destination tensor. 
Its src[0] is treated as the input tensor.\n */\ntemplate <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>\nvoid ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src = dst->src[0];\n\n    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);\n    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);\n\n    unary_op(ctx, acl_src.get(), acl_dst.get());\n}\n\n/**\n * @brief Applies a unary operation to a ggml tensor using the CANN backend.\n *\n * @details This function applies a unary operation to the input tensor using\n * a user-provided lambda or callable `unary_op`. The lambda receives the\n * CANN backend context and two ACL tensors: the source and the destination.\n *\n * Internally, this function handles the conversion from GGML tensors to ACL tensors,\n * calls the provided unary op, and manages resource cleanup. The input is assumed\n * to be `dst->src[0]`, and the result is written to `dst`.\n *\n * This utility simplifies writing unary op wrappers by abstracting tensor preparation.\n *\n * @param unary_op A callable that performs the unary operation using CANN ACL APIs.\n * @param ctx The CANN context for operation execution.\n * @param dst The destination ggml_tensor where the result will be stored.\n *            The input tensor is assumed to be `dst->src[0]`.\n *\n * @see GGML_CANN_CALL_OP_UNARY\n */\nvoid ggml_cann_op_unary(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,\n                        ggml_backend_cann_context &                                                ctx,\n                        ggml_tensor *                                                              dst);\n\nvoid ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n\n/**\n * @brief Applies a gated (GLU-style) unary operation using the CANN backend.\n *\n * @details This function performs a gated activation such as GEGLU or ReGLU.\n * It supports two input 
modes:\n *\n * 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.\n *    These are used directly as the value and gate tensors.\n *\n * 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to\n *    contain a concatenation of value and gate along the first dimension. This tensor\n *    will be split into two equal halves to form the value and gate inputs.\n *\n * The function applies a user-provided unary operation (e.g., GELU) to the value tensor,\n * then multiplies the result in-place with the gate tensor:\n *\n * @code\n * dst = unary_op(value) * gate;\n * @endcode\n *\n * The `swapped` parameter (from `dst->op_params[1]`) allows flipping the\n * order of value/gate in the packed input case.\n *\n * @param unary_op A callable that performs the unary operation using CANN ACL APIs.\n *                 It receives (ctx, acl_value_tensor, acl_output_tensor).\n * @param ctx      The CANN context used for execution.\n * @param dst      The destination ggml_tensor. 
Source tensors are in `dst->src[0]` and optionally `src[1]`.\n *\n * @see GGML_CANN_CALL_OP_UNARY_GATED\n */\nvoid ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,\n                              ggml_backend_cann_context &                                                ctx,\n                              ggml_tensor *                                                              dst);\n\n/**\n * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.\n *\n * This macro wraps the specified ACLNN unary operator name into a lambda expression,\n * and passes it to `ggml_cann_op_unary`, which handles the common logic for executing\n * unary ops in the CANN backend.\n *\n * Internally, this macro expands to a lambda like:\n * @code\n * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {\n *     GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);\n * };\n * @endcode\n *\n * This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.\n *\n * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.\n *\n * @see ggml_cann_op_unary\n * @see GGML_CANN_CALL_ACLNN_OP\n */\n#    define GGML_CANN_CALL_OP_UNARY(OP_NAME)                                                              \\\n        do {                                                                                              \\\n            auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \\\n                GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);                                  \\\n            };                                                                                            \\\n            ggml_cann_op_unary(lambda, ctx, dst);                                                         \\\n        } while (0)\n\n/**\n * @brief Helper macro to call a gated unary ACL operator via 
ggml_cann_op_unary_gated.\n *\n * This macro wraps the specified ACLNN unary operator name into a lambda expression,\n * and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for\n * executing gated unary ops in the CANN backend.\n *\n * Internally, this macro expands to a lambda like:\n * @code\n * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {\n *     GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);\n * };\n * @endcode\n *\n * This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.\n *\n * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.\n *\n * @see ggml_cann_op_unary_gated\n * @see GGML_CANN_CALL_ACLNN_OP\n */\n#    define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME)                                                        \\\n        do {                                                                                              \\\n            auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \\\n                GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);                                  \\\n            };                                                                                            \\\n            ggml_cann_op_unary_gated(lambda, ctx, dst);                                                   \\\n        } while (0)\n\n#endif  // CANN_ACLNN_OPS\n\n/**\n * @brief Performs outer product operation on two ggml tensors using the CANN backend.\n *\n * @details This function computes the outer product of two input tensors (src0 and src1)\n * and stores the result in the destination tensor. The outer product operation is defined as:\n * dst[i,j,k,l] = sum_m (src0[i,m,k,l] * src1[j,m,k,l])\n *\n * The function supports multiple data types including F32, F16. 
For floating-point\n * types, it uses batch matrix multiplication for efficient computation.\n *\n * The implementation handles 4D tensor broadcasting and batch processing automatically.\n *\n * @param ctx The CANN backend context for operation execution and memory management.\n * @param dst The destination ggml_tensor where the outer product result will be stored.\n *            The input tensors are assumed to be `dst->src[0]` and `dst->src[1]`.\n *\n * @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation\n */\nvoid ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cann/common.h",
    "content": "/*\n * Copyright (c) 2023-2026 The ggml authors\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\n#ifndef CANN_COMMON_H\n#define CANN_COMMON_H\n\n#include \"../ggml-impl.h\"\n#include \"../include/ggml-cann.h\"\n#include \"../include/ggml.h\"\n\n#include <acl/acl.h>\n#include <unistd.h>\n\n#include <atomic>\n#include <condition_variable>\n#include <cstdio>\n#include <functional>\n#include <iostream>\n#include <list>\n#include <map>\n#include <memory>\n#include <mutex>\n#include <optional>\n#include <string>\n#include <thread>\n#include <vector>\n\n#define MATRIX_ROW_PADDING    512\n#define GGML_CANN_MAX_STREAMS 8\n\n/**\n * @brief Handles CANN-related errors by printing an error message and\n *        terminating the program.\n * @param stmt The statement that caused the error.\n * @param func The function in which the error occurred.\n * @param file The file in which the error occurred.\n * @param line The line 
number at which the error occurred.\n * @param msg The error message.\n */\n[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg);\n\n/**\n * @brief Checks the result of a CANN function call and invokes the error\n *        handler if the call fails.\n * @param stmt The CANN function call to check.\n * @param success The success code that indicates the call was successful.\n * @param error_fn The function to call to retrieve the error message.\n */\n#define ACL_CHECK_GEN(stmt, success, error_fn)                                \\\n    do {                                                                      \\\n        int err_code = (stmt);                                                \\\n        if (err_code != (success)) {                                          \\\n            ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \\\n        }                                                                     \\\n    } while (0);\n\n#define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg)\n\n/**\n * @brief Contains information about CANN devices.\n */\nstruct ggml_cann_device_info {\n    /**\n     * @brief Number of CANN devices available.\n     */\n    int32_t device_count;\n\n    /**\n     * @brief Information about a single CANN device.\n     */\n    struct cann_device_info {\n        int    cc;              /**< Compute capability.                   */\n        size_t smpb;            /**< Maximum shared memory per block.      */\n        bool   vmm;             /**< Virtual memory support.               */\n        size_t vmm_granularity; /**< Granularity of virtual memory.        */\n        size_t total_vram;      /**< Total video RAM available on the device. */\n    };\n\n    cann_device_info devices[GGML_CANN_MAX_DEVICES] = {}; /**< Array of CANN device information. 
*/\n};\n\nconst ggml_cann_device_info & ggml_cann_info();\n\nvoid    ggml_cann_set_device(int32_t device);\n\nstd::optional<std::string> get_env_as_lowercase(const std::string & name);\nbool                       parse_bool(const std::string & value);\nint                        parse_integer(const std::string & value);\n\n/**\n * @brief Abstract base class for memory pools used by CANN.\n */\nstruct ggml_cann_pool {\n    /**\n     * @brief Virtual destructor for the memory pool.\n     */\n    virtual ~ggml_cann_pool() = default;\n\n    /**\n     * @brief Allocates memory from the pool.\n     *\n     * @param size         The size of the memory block to allocate.\n     * @param actual_size  Pointer to a variable where the actual allocated size\n     *                     will be stored.\n     * @return             Pointer to the allocated memory block.\n     */\n    virtual void * alloc(size_t size, size_t * actual_size) = 0;\n\n    /**\n     * @brief Frees a previously allocated memory block.\n     *\n     * @param ptr   Pointer to the memory block to free.\n     * @param size  Size of the memory block to free.\n     * @note All CANN operators run asynchronously. Make sure the memory stays\n     *       available until every operator that uses it has finished.\n     */\n    virtual void free(void * ptr, size_t size) = 0;\n};\n\n/**\n * @brief RAII wrapper for managing memory allocations from a CANN memory pool.\n */\nstruct ggml_cann_pool_alloc {\n    ggml_cann_pool * pool        = nullptr; /**< Pointer to the memory pool. */\n    void *           ptr         = nullptr; /**< Pointer to the allocated memory block. */\n    size_t           actual_size = 0;       /**< Actual size of the allocated memory block. 
*/\n\n    /**\n     * @brief Default constructor.\n     */\n    ggml_cann_pool_alloc() = default;\n\n    /**\n     * @brief Constructor that initializes the memory pool.\n     * @param pool Reference to the memory pool.\n     */\n    explicit ggml_cann_pool_alloc(ggml_cann_pool & pool) : pool(&pool) {}\n\n    /**\n     * @brief Constructor that initializes the memory pool and allocates memory.\n     * @param pool Reference to the memory pool.\n     * @param size Size of the memory block to allocate.\n     */\n    ggml_cann_pool_alloc(ggml_cann_pool & pool, size_t size) : pool(&pool) { alloc(size); }\n\n    /**\n     * @brief Destructor that frees the allocated memory block.\n     */\n    ~ggml_cann_pool_alloc() {\n        if (ptr != nullptr) {\n            pool->free(ptr, actual_size);\n        }\n    }\n\n    /**\n     * @brief Allocates memory from the pool.\n     * @param size Size of the memory block to allocate.\n     * @return Pointer to the allocated memory block.\n     */\n    void * alloc(size_t size) {\n        GGML_ASSERT(pool != nullptr);\n        GGML_ASSERT(ptr == nullptr);\n        ptr = pool->alloc(size, &this->actual_size);\n        return ptr;\n    }\n\n    /**\n     * @brief Allocates memory from a specific memory pool.\n     * @param pool Reference to the memory pool.\n     * @param size Size of the memory block to allocate.\n     * @return Pointer to the allocated memory block.\n     */\n    void * alloc(ggml_cann_pool & pool, size_t size) {\n        this->pool = &pool;\n        return alloc(size);\n    }\n\n    /**\n     * @brief Gets the pointer to the allocated memory block.\n     * @return Pointer to the allocated memory block.\n     */\n    void * get() { return ptr; }\n\n    // Deleted copy constructor\n    ggml_cann_pool_alloc(const ggml_cann_pool_alloc &) = delete;\n\n    // Deleted move constructor\n    ggml_cann_pool_alloc(ggml_cann_pool_alloc &&) = delete;\n\n    // Deleted copy assignment operator\n    ggml_cann_pool_alloc & 
operator=(const ggml_cann_pool_alloc &) = delete;\n\n    // Deleted move assignment operator\n    ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;\n};\n\n#ifdef USE_ACL_GRAPH\nstruct ggml_graph_node_properties {\n    // dst tensor\n    void *  node_address;\n    int64_t ne[GGML_MAX_DIMS];\n    size_t  nb[GGML_MAX_DIMS];\n\n    // src tensor\n    void *  src_address[GGML_MAX_SRC];\n    int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];\n    size_t  src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];\n\n    // op\n    ggml_op node_op;\n    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];\n\n    /**\n     * @brief Check if a ggml tensor node matches this property set.\n     *\n     * This function compares all relevant fields (address, op type, shape, source inputs, op params)\n     * to determine whether the current node matches these previously recorded properties.\n     *\n     * @param node The current ggml tensor node.\n     * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.\n     */\n    bool has_matching_properties(ggml_tensor * node) {\n        if (node->data != this->node_address && node->op != GGML_OP_VIEW) {\n            return false;\n        }\n\n        if (node->op != this->node_op) {\n            return false;\n        }\n\n        for (int i = 0; i < GGML_MAX_DIMS; i++) {\n            if (node->ne[i] != this->ne[i]) {\n                return false;\n            }\n            if (node->nb[i] != this->nb[i]) {\n                return false;\n            }\n        }\n\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            if (node->src[i]) {\n                if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) {\n                    return false;\n                }\n\n                for (int d = 0; d < GGML_MAX_DIMS; d++) {\n                    if (node->src[i]->ne[d] != this->src_ne[i][d]) {\n                        return false;\n                    }\n                    if 
(node->src[i]->nb[d] != this->src_nb[i][d]) {\n                        return false;\n                    }\n                }\n            } else {\n                if (this->src_address[i] != nullptr) {\n                    return false;\n                }\n            }\n        }\n\n        if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {\n            return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;\n        }\n        return true;\n    }\n};\n\nstruct ggml_cann_graph {\n    ~ggml_cann_graph() {\n        if (graph != nullptr) {\n            ACL_CHECK(aclmdlRIDestroy(graph));\n        }\n    }\n\n    aclmdlRI graph = nullptr;\n\n    std::vector<ggml_graph_node_properties> ggml_graph_properties;\n\n    /**\n     * @brief Create a new CANN graph from a ggml computation graph.\n     *\n     * This function creates a new ggml_cann_graph object and fills its node properties\n     * (operation type, dimensions, strides, input sources, and operation parameters)\n     * based on the current ggml computation graph.\n     *\n     * Each node in the ggml graph is mapped to a property entry in the new CANN graph:\n     * - node address\n     * - operation type\n     * - shape (ne) and strides (nb)\n     * - source tensor addresses\n     * - operation parameters\n     *\n     * @param cgraph The current ggml computation graph.\n     * @return Pointer to the newly created ggml_cann_graph object.\n     */\n    static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {\n        ggml_cann_graph * new_graph = new ggml_cann_graph();\n        new_graph->ggml_graph_properties.resize(cgraph->n_nodes);\n\n        for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {\n            ggml_tensor * node = cgraph->nodes[node_idx];\n            auto &        prop = new_graph->ggml_graph_properties[node_idx];\n\n            prop.node_address = node->data;\n            prop.node_op      = node->op;\n\n          
  std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);\n            std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);\n\n            for (int src = 0; src < GGML_MAX_SRC; ++src) {\n                if (node->src[src]) {\n                    prop.src_address[src] = node->src[src]->data;\n                    std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);\n                    std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);\n                } else {\n                    prop.src_address[src] = nullptr;\n                    std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);\n                    std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);\n                }\n            }\n\n            memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);\n        }\n\n        return new_graph;\n    }\n\n    /**\n     * @brief Check whether this CANN graph matches the given ggml computation graph.\n     *\n     * This function compares the number of nodes and each node's properties\n     * (operation type, dimensions, strides, inputs, and operation parameters)\n     * to determine whether this CANN graph matches the given ggml graph.\n     *\n     * @param cgraph The current ggml computation graph.\n     * @return true if this CANN graph matches the ggml graph; false otherwise.\n     */\n    bool matches_cgraph(ggml_cgraph * cgraph) {\n        if (this->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {\n            return false;\n        }\n\n        for (int i = 0; i < cgraph->n_nodes; ++i) {\n            if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) {\n                return false;\n            }\n        }\n\n        return true;\n    }\n};\n\n/**\n * @brief LRU cache for managing ggml_cann_graph objects.\n *\n * This class maintains a list of shared_ptr to ggml_cann_graph objects\n * and enforces a maximum capacity. 
It provides methods to push new graphs,\n * move existing graphs to the front (most recently used), and clear the cache.\n */\nstruct ggml_cann_graph_lru_cache {\n    size_t capacity;                         /**< Maximum number of graphs in the cache. */\n\n    std::list<ggml_cann_graph *> cache_list; /**< List storing cached graphs as raw pointers. */\n\n    ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env_as_lowercase(\"GGML_CANN_GRAPH_CACHE_CAPACITY\").value_or(\"12\")); }\n\n    /**\n     * @brief Push a new graph to the front of the cache.\n     * If the cache is already at capacity, the least recently used graph is\n     * deleted before the new graph is inserted.\n     * @param new_node Pointer to the new ggml_cann_graph to cache.\n     *        Ownership is transferred to the cache (cache will delete it).\n     * NOTE(review): capacity is 0 when GGML_CANN_GRAPH_CACHE_CAPACITY is 0 or\n     *        not a valid integer (parse_integer returns 0); the first push then\n     *        calls cache_list.back() on an empty list, which is undefined behavior.\n     */\n    void push(ggml_cann_graph * new_node) {\n        if (cache_list.size() >= capacity) {\n            ggml_cann_graph * old = cache_list.back();\n            cache_list.pop_back();\n            delete old;  // free the old graph\n        }\n        cache_list.push_front(new_node);\n    }\n\n    /**\n     * @brief Clear all graphs from the cache (also frees memory).\n     */\n    void clear() {\n        for (auto ptr : cache_list) {\n            delete ptr;\n        }\n        cache_list.clear();\n    }\n\n    /**\n     * @brief Destructor that clears the cache and frees all cached graphs.\n     */\n    ~ggml_cann_graph_lru_cache() { clear(); }\n\n    /**\n     * @brief Find a cached CANN graph that matches the given ggml graph and move it to front.\n     *\n     * This function iterates through the cached CANN graphs stored in the LRU cache and\n     * compares them against the given ggml computation graph. If a matching graph is found,\n     * it is promoted to the front of the LRU cache and the function returns true. 
Otherwise, the function\n     * returns false.\n     *\n     * NOTE(review): `graph_ptr` is a reference to the list element that\n     * `cache_list.remove()` erases, so the subsequent `push_front(graph_ptr)`\n     * reads a dangling reference. Copying the pointer into a local first, or\n     * iterating with an explicit iterator and calling\n     * `cache_list.splice(cache_list.begin(), cache_list, it)`, avoids this.\n     *\n     * @param cgraph The current ggml computation graph.\n     * @return true if found; false otherwise.\n     */\n    bool find_and_move_to_front(ggml_cgraph * cgraph) {\n        for (auto & graph_ptr : this->cache_list) {\n            if (graph_ptr->matches_cgraph(cgraph)) {\n                cache_list.remove(graph_ptr);\n                cache_list.push_front(graph_ptr);\n                return true;\n            }\n        }\n        return false;\n    }\n};\n#endif  // USE_ACL_GRAPH\n\nstruct ggml_cann_rope_cache {\n    ~ggml_cann_rope_cache() {\n        if (theta_scale_cache) {\n            ACL_CHECK(aclrtFree(theta_scale_cache));\n        }\n        if (sin_cache) {\n            ACL_CHECK(aclrtFree(sin_cache));\n        }\n        if (cos_cache) {\n            ACL_CHECK(aclrtFree(cos_cache));\n        }\n        if (position_select_index) {\n            ACL_CHECK(aclrtFree(position_select_index));\n        }\n        if (theta_scale_exp_host) {\n            free(theta_scale_exp_host);\n        }\n        if (position_select_index_host) {\n            free(position_select_index_host);\n        }\n        if (yarn_ramp_cache) {\n            ACL_CHECK(aclrtFree(yarn_ramp_cache));\n        }\n    }\n\n    bool equal(int64_t theta_scale_length,\n               int64_t position_length,\n               float   ext_factor,\n               float   theta_scale,\n               float   freq_scale,\n               float   attn_factor,\n               bool    is_neox,\n               bool    indep_sects,\n               bool    mrope_used,\n               bool    is_imrope,\n               int     sections[4]) {\n        return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&\n               this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&\n               this->attn_factor == attn_factor && this->is_neox == is_neox 
&& this->indep_sects == indep_sects &&\n               this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&\n               this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];\n    }\n\n    void set(int64_t theta_scale_length,\n             int64_t position_length,\n             float   ext_factor,\n             float   theta_scale,\n             float   freq_scale,\n             float   attn_factor,\n             bool    is_neox,\n             bool    indep_sects,\n             bool    mrope_used,\n             bool    is_imrope,\n             int     sections[4]) {\n        this->theta_scale_length = theta_scale_length;\n        this->position_length    = position_length;\n        this->ext_factor         = ext_factor;\n        this->theta_scale        = theta_scale;\n        this->freq_scale         = freq_scale;\n        this->attn_factor        = attn_factor;\n        this->is_neox            = is_neox;\n        this->indep_sects        = indep_sects;\n        this->mrope_used         = mrope_used;\n        this->is_imrope          = is_imrope;\n        this->sections[0]        = sections[0];\n        this->sections[1]        = sections[1];\n        this->sections[2]        = sections[2];\n        this->sections[3]        = sections[3];\n    }\n\n    // memory cache, prepare before inferencing.\n    void *  theta_scale_cache          = nullptr;\n    float * theta_scale_exp_host       = nullptr;\n    int *   position_select_index_host = nullptr;\n    void *  position_select_index      = nullptr;\n    void *  yarn_ramp_cache            = nullptr;\n    // sin/cos cache, used only to accelerate first layer on each device\n    void *  sin_cache                  = nullptr;\n    void *  cos_cache                  = nullptr;\n    // Properties to check before reusing the sincos cache\n    int64_t theta_scale_length         = 0;\n    int64_t position_length            
= 0;\n    bool    cached                     = false;\n    float   ext_factor                 = 0.0f;\n    float   theta_scale                = 0.0f;\n    float   freq_scale                 = 0.0f;\n    float   attn_factor                = 0.0f;\n    bool    is_neox                    = false;\n    bool    indep_sects                = false;\n    bool    mrope_used                 = false;\n    int     sections[4]                = { 0, 0, 0, 0 };\n    bool    is_imrope                  = false;\n};\n\nstruct ggml_cann_tensor_cache {\n    ~ggml_cann_tensor_cache() {\n        if (cache != nullptr) {\n            ACL_CHECK(aclrtFree(cache));\n        }\n    }\n\n    void *  cache = nullptr;\n    int64_t size  = 0;\n};\n\n/**\n * @brief Context for managing CANN backend operations.\n */\nstruct ggml_backend_cann_context {\n    int32_t     device;               /**< Device ID. */\n    std::string name;                 /**< Name of the device. */\n    std::string description;          /**< Description of the device. */\n    aclrtEvent  copy_event = nullptr; /**< Event for managing copy operations. */\n#ifdef USE_ACL_GRAPH\n    /// Cached CANN ACL graph used for executing the current ggml computation graph.\n    ggml_cann_graph_lru_cache graph_lru_cache;\n    bool                      acl_graph_mode = true;\n#endif\n    bool                   async_mode;\n    // Rope Cache\n    ggml_cann_rope_cache   rope_cache;\n    // Constant Pool\n    ggml_cann_tensor_cache rms_norm_one_tensor_cache;\n    ggml_cann_tensor_cache rms_norm_zero_tensor_cache;\n\n    aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. 
*/\n\n    /**\n     * @brief Constructor for initializing the context with a given device.\n     * @param device Device ID.\n     */\n    explicit ggml_backend_cann_context(int device) : device(device), name(\"CANN\" + std::to_string(device)) {\n        ggml_cann_set_device(device);\n        description = aclrtGetSocName();\n\n#ifdef USE_ACL_GRAPH\n        acl_graph_mode = parse_bool(get_env_as_lowercase(\"GGML_CANN_ACL_GRAPH\").value_or(\"on\"));\n        GGML_LOG_INFO(\"%s: device %d execution mode is %s (%s)\\n\", __func__, device, acl_graph_mode ? \"GRAPH\" : \"EAGER\",\n                      acl_graph_mode ? \"acl graph enabled\" : \"acl graph disabled\");\n#endif\n    }\n\n    /**\n     * @brief Destructor for cleaning up resources.\n     */\n    ~ggml_backend_cann_context() {\n        ggml_cann_set_device(device);\n        if (copy_event != nullptr) {\n            ACL_CHECK(aclrtDestroyEvent(copy_event));\n        }\n        for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) {\n            if (streams[i] != nullptr) {\n                ACL_CHECK(aclrtDestroyStream(streams[i]));\n            }\n        }\n    }\n\n    /**\n     * @brief Get or create a stream for a given index.\n     * @param stream Index of the stream.\n     * @return The stream corresponding to the given index.\n     */\n    aclrtStream stream(int stream) {\n        if (streams[stream] == nullptr) {\n            // If the device is not set here, destroying the stream later may cause a mismatch\n            // between the thread contexts where the stream was created and destroyed.\n            // However, I printed the device_id, thread_id, and stream, and they are all consistent.\n            ACL_CHECK(aclrtSetDevice(device));\n            ACL_CHECK(aclrtCreateStream(&streams[stream]));\n        }\n        return streams[stream];\n    }\n\n    /**\n     * @brief Get or create the default stream (index 0).\n     * @return The default stream.\n     */\n    aclrtStream stream() { return stream(0); 
}\n\n    // TODO: each stream should have a memory pool.\n    std::unique_ptr<ggml_cann_pool> mem_pool; /**< Memory pool for the device. */\n\n    /**\n     * @brief Create a new memory pool for a given device.\n     * @param device Device ID.\n     * @return A unique pointer to the new memory pool.\n     */\n    static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device);\n\n    /**\n     * @brief Get or create the memory pool for the context.\n     * @return Reference to the memory pool.\n     */\n    ggml_cann_pool & pool() {\n        if (mem_pool == nullptr) {\n            mem_pool = new_pool_for_device(device);\n        }\n        return *mem_pool;\n    }\n};\n\n#endif  // CANN_COMMON_H\n"
  },
  {
    "path": "src/ggml-cann/ggml-cann.cpp",
    "content": "/*\n * Copyright (c) 2023-2026 The ggml authors\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\n#include \"ggml-cann.h\"\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-cann/aclnn_ops.h\"\n#include \"ggml-cann/common.h\"\n#include \"ggml-impl.h\"\n#include \"ggml.h\"\n\n#include <acl/acl.h>\n#include <aclnnop/aclnn_trans_matmul_weight.h>\n#include <stdarg.h>\n\n#include <chrono>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <mutex>\n#include <optional>\n#include <queue>\n#include <unordered_set>\n\n#define GGML_COMMON_DECL_C\n\n#include \"ggml-common.h\"\n\n#define GGML_CANN_NAME \"CANN\"\n\n/**\n * @brief Handles CANN errors by printing an error message and aborting.\n *\n * @param stmt The statement that caused the error.\n * @param func The function in which the error occurred.\n * @param file The file in which the error occurred.\n * @param line The line number where the error 
occurred.\n * @param msg The error message.\n */\n[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {\n    int32_t id = -1;\n    aclrtGetDevice(&id);\n\n    GGML_LOG_ERROR(\"CANN error: %s\\n\", msg);\n    GGML_LOG_ERROR(\"  current device: %d, in function %s at %s:%d\\n\", id, func, file, line);\n    GGML_LOG_ERROR(\"  %s\\n\", stmt);\n    // abort with GGML_ASSERT to get a stack trace\n    GGML_ABORT(\"CANN error\");\n}\n\n// Thread-local variable to record the current device of this thread.\nthread_local int g_current_cann_device = -1;\n\n/**\n * @brief Set the CANN device to be used.\n *\n * @param device The target device ID to set.\n */\nvoid ggml_cann_set_device(const int32_t device) {\n    // int current_device = -1;\n    // Note: In some CANN versions, if no device has been set yet,\n    //       aclrtGetDevice(&current_device) may return 0 by default.\n    // aclrtGetDevice(&current_device);\n\n    // If the current device is already the target one, no need to switch.\n    if (device == g_current_cann_device) {\n        return;\n    }\n\n    // Switch to the new device.\n    ACL_CHECK(aclrtSetDevice(device));\n\n    // Update the global device record.\n    g_current_cann_device = device;\n}\n\n/**\n * @brief Get the value of the specified environment variable (name) as lowercase.\n *        if not empty, return a std::string object\n */\nstd::optional<std::string> get_env_as_lowercase(const std::string & name) {\n    const char * val = std::getenv(name.c_str());\n    if (!val) {\n        return std::nullopt;\n    }\n    std::string res = std::string(val);\n    std::transform(res.begin(), res.end(), res.begin(), ::tolower);\n    return res;\n}\n\n/**\n * @brief Verify whether the environment variable is a valid value.\n */\nbool parse_bool(const std::string & value) {\n    static const std::unordered_set<std::string> valid_values = { \"on\", \"1\", \"yes\", \"y\", \"enable\", \"true\" 
};\n    return valid_values.find(value) != valid_values.end();\n}\n\n/**\n * @brief Parse a string as an integer, returning 0 if invalid.\n *\n * This function attempts to convert the input string `value` to an `int`.\n * If the string is not a valid integer or is out of the `int` range,\n * it returns 0.\n *\n * @param value The string to parse.\n * @return The parsed integer, or 0 if conversion fails.\n */\nint parse_integer(const std::string & value) {\n    try {\n        return std::stoi(value);\n    } catch (...) {\n        return 0;\n    }\n}\n\n/**\n * @brief Initialize the CANN device information.\n *\n * This function initializes the CANN device information by obtaining the\n * device count and setting the memory allocation granularity for each device.\n *\n * @return A structure containing the device information.\n */\nstatic ggml_cann_device_info ggml_cann_init() {\n    ggml_cann_device_info info = {};\n\n    aclError err = aclrtGetDeviceCount((uint32_t *) &info.device_count);\n\n    if (err != ACL_SUCCESS) {\n        GGML_LOG_ERROR(\"%s: failed to initialize CANN: %s\\n\", __func__, aclGetRecentErrMsg());\n        return info;\n    }\n\n    GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);\n\n    for (int id = 0; id < info.device_count; ++id) {\n        aclrtPhysicalMemProp prop = {};\n        prop.handleType           = ACL_MEM_HANDLE_TYPE_NONE;\n        prop.allocationType       = ACL_MEM_ALLOCATION_TYPE_PINNED;\n        prop.memAttr              = ACL_HBM_MEM_HUGE;\n        prop.location.type        = ACL_MEM_LOCATION_TYPE_DEVICE;\n        prop.location.id          = id;\n        prop.reserve              = 0;\n        err                       = aclrtMemGetAllocationGranularity(&prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,\n                                                                     &info.devices[id].vmm_granularity);\n        info.devices[id].vmm      = err == ACL_SUCCESS;\n\n        size_t free, total;\n        
ggml_backend_cann_get_device_memory(id, &free, &total);\n        info.devices[id].total_vram = free;\n    }\n\n    // TODO: add more device info later.\n    return info;\n}\n\n/**\n * @brief Retrieve the CANN device information.\n *\n * This function returns a reference to a structure containing the CANN device\n * information. The device information is initialized once and reused on\n * subsequent calls.\n *\n * @return A reference to the structure containing the device information.\n */\nconst ggml_cann_device_info & ggml_cann_info() {\n    static ggml_cann_device_info info = ggml_cann_init();\n    return info;\n}\n\n//#define DEBUG_CANN_MALLOC\n/**\n * @brief A pool of CANN buffers(priority segment buffer).\n *\n * This class manages a pool of CANN buffers for a specific device.\n */\nstruct ggml_cann_pool_buf_prio : public ggml_cann_pool {\n    /**\n     * @brief The maximum reuse margin for a buffer.\n     */\n    static const size_t max_reuse_margin = 1ull << 22;  // 4MB\n\n    /**\n     * @brief The minimum free margin for a buffer.\n     */\n    static const size_t min_free_margin = 1ull << 20;  // 1MB\n\n    /**\n     * @brief The alignment for buffer allocation.\n     */\n    static const size_t alignment = 128;\n\n    /**\n     * @brief The device ID associated with this buffer pool.\n     */\n    int device;\n\n    /**\n     * @brief Whether to disable clean during buffer allocation.\n     */\n    bool disable_clean = false;\n\n    /**\n     * @brief Structure representing a CANN buffer.\n     */\n    struct ggml_cann_buffer {\n        void *                                ptr  = nullptr;  ///< Pointer to the buffer.\n        size_t                                size = 0;        ///< Size of the buffer.\n        std::chrono::steady_clock::time_point last_used;       ///< Last used time.\n\n        bool operator>(const ggml_cann_buffer & other) const { return size > other.size; }\n    };\n\n    /**\n     * @brief Array of CANN buffers in the pool.\n     
*/\n    std::unordered_map<void *, size_t>                                                   buffer_pool;\n    std::priority_queue<ggml_cann_buffer, std::vector<ggml_cann_buffer>, std::greater<>> free_buffers;\n\n    /**\n     * @brief Total size of all buffers in the pool.\n     */\n    size_t pool_size = 0;\n\n    /**\n     * @brief Constructor to initialize the buffer pool for a specific device.\n     *\n     * @param device The device ID to associate with this buffer pool.\n     */\n    explicit ggml_cann_pool_buf_prio(int device) : device(device) {\n        disable_clean = parse_bool(get_env_as_lowercase(\"GGML_CANN_DISABLE_BUF_POOL_CLEAN\").value_or(\"\"));\n    }\n\n    /**\n     * @brief Destructor to free all buffers in the pool.\n     */\n    ~ggml_cann_pool_buf_prio() {\n        ggml_cann_set_device(device);\n        for (auto & [b_ptr, b_size] : buffer_pool) {\n            aclrtFree(b_ptr);\n            pool_size -= b_size;\n        }\n        buffer_pool.clear();\n        GGML_ASSERT(pool_size == 0);\n    }\n\n    /**\n     * @brief Allocate a buffer of the given size.\n     *\n     * @param size The size of the buffer to allocate.\n     * @param actual_size A pointer to a variable to receive the actual size of\n     * the allocated buffer.\n     * @return A pointer to the allocated buffer.\n     */\n    void * alloc(size_t size, size_t * actual_size) override {\n        size = GGML_PAD(size, alignment);\n        if (size == 0) {\n            size = alignment;\n        }\n\n        void * ptr = nullptr;\n        auto   now = std::chrono::steady_clock::now();\n\n        std::vector<ggml_cann_buffer> free_buffers_rest;\n        free_buffers_rest.reserve(free_buffers.size());\n        while (!free_buffers.empty()) {\n            auto b = free_buffers.top();\n            free_buffers.pop();\n\n            if (b.size >= size) {\n                // reuse the buffer if the size is enough\n                const size_t margin = b.size - size;\n                
if (margin <= max_reuse_margin) {\n                    *actual_size = b.size;\n                    ptr          = b.ptr;\n#ifdef DEBUG_CANN_MALLOC\n                    GGML_LOG_INFO(\n                        \"cann pool[%d]: reused   %p, \"\n                        \"pool_size = %5u MB, \"\n                        \"size = %5u MB, \"\n                        \"margin = %5u MB\\n\",\n                        device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),\n                        (uint32_t) (GGML_PAD(size, 1048576) / 1048576),\n                        (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));\n#endif\n                    break;\n                }\n            }\n\n            bool should_clean = !disable_clean && b.size > min_free_margin &&\n                                std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;\n            if (should_clean) {\n                // free the buffer if the size is needed to be freed\n                ACL_CHECK(aclrtFree(b.ptr));\n                pool_size -= b.size;\n                buffer_pool.erase(b.ptr);\n#ifdef DEBUG_CANN_MALLOC\n                GGML_LOG_INFO(\n                    \"cann pool[%d]: clean    %p, \"\n                    \"pool_size = %5u MB, \"\n                    \"size = %5u MB\\n\",\n                    device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),\n                    (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));\n#endif\n                continue;\n            }\n            free_buffers_rest.push_back(b);\n        }\n        for (ggml_cann_buffer & b : free_buffers_rest) {\n            free_buffers.push(std::move(b));\n        }\n\n#ifdef DEBUG_CANN_MALLOC\n        GGML_LOG_INFO(\"cann pool[%d] free pool_size = %5u MB\\n\\n\", device,\n                      (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));\n#endif\n        if (ptr != nullptr) {\n            return ptr;\n        }\n\n        // allocate a 
new buffer if no buffer can be reused\n        ggml_cann_set_device(device);\n        ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));\n        *actual_size = size;\n        pool_size += size;\n#ifdef DEBUG_CANN_MALLOC\n        GGML_LOG_INFO(\n            \"cann pool[%d]: allocate %p, \"\n            \"pool_size = %5u MB, \"\n            \"size = %5u MB\\n\",\n            device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),\n            (uint32_t) (GGML_PAD(size, 1048576) / 1048576));\n#endif\n        buffer_pool.emplace(ptr, size);\n        return ptr;\n    }\n\n    /**\n     * @brief Free a buffer and return it to the pool.\n     *\n     * @param ptr Pointer to the buffer to free.\n     * @param size Size of the buffer to free.\n     */\n    void free(void * ptr, size_t size) override {\n        GGML_UNUSED(size);\n        auto it = buffer_pool.find(ptr);\n        if (it == buffer_pool.end()) {\n            GGML_ABORT(\"cann pool[%d]: buffer %p not found in pool\\n\", device, ptr);\n        }\n\n        auto now = std::chrono::steady_clock::now();\n        free_buffers.emplace(ggml_cann_buffer{ ptr, it->second, now });\n#ifdef DEBUG_CANN_MALLOC\n        GGML_LOG_INFO(\n            \"cann pool[%d]: return   %p, \"\n            \"pool_size = %5u MB\\n\",\n            device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));\n#endif\n    }\n};\n\n/**\n * @brief A pool of CANN buffers(segment buffer).\n *\n * This class manages a pool of CANN buffers for a specific device.\n */\nstruct ggml_cann_pool_buf : public ggml_cann_pool {\n    /**\n     * @brief The maximum reuse margin for a buffer.\n     */\n    static const size_t max_reuse_margin = 1ull << 22;  // 4MB\n\n    /**\n     * @brief The minimum free margin for a buffer.\n     */\n    static const size_t min_free_margin = 1ull << 20;  // 1MB\n\n    /**\n     * @brief The alignment for buffer allocation.\n     */\n    static const size_t alignment = 128;\n\n    /**\n     * 
@brief The maximum number of buffers in the pool.\n     */\n    static const int MAX_BUFFERS = 256;\n\n    /**\n     * @brief The device ID associated with this buffer pool.\n     */\n    int device;\n\n    /**\n     * @brief Whether to disable clean during buffer allocation.\n     */\n    bool disable_clean = false;\n\n    /**\n     * @brief Structure representing a CANN buffer.\n     */\n    struct ggml_cann_buffer {\n        void *                                ptr  = nullptr;  ///< Pointer to the buffer memory.\n        size_t                                size = 0;        ///< Size of the buffer.\n        bool                                  used = false;    ///< Whether the buffer is currently in use.\n        std::chrono::steady_clock::time_point last_used;       ///< Last used time.\n    };\n\n    /**\n     * @brief Array of CANN buffers in the pool.\n     */\n    ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};\n\n    /**\n     * @brief Total size of all buffers in the pool.\n     */\n    size_t pool_size = 0;\n\n    /**\n     * @brief Constructor to initialize the buffer pool for a specific device.\n     *\n     * @param device The device ID to associate with this buffer pool.\n     */\n    explicit ggml_cann_pool_buf(int device) : device(device) {\n        disable_clean = parse_bool(get_env_as_lowercase(\"GGML_CANN_DISABLE_BUF_POOL_CLEAN\").value_or(\"\"));\n    }\n\n    /**\n     * @brief Destructor to free all buffers in the pool.\n     */\n    ~ggml_cann_pool_buf() {\n        ggml_cann_set_device(device);\n        for (int i = 0; i < MAX_BUFFERS; ++i) {\n            ggml_cann_buffer & b = buffer_pool[i];\n            if (b.ptr != nullptr) {\n                aclrtFree(b.ptr);\n                pool_size -= b.size;\n            }\n        }\n        GGML_ASSERT(pool_size == 0);\n    }\n\n    /**\n     * @brief Allocate a buffer of the given size.\n     *\n     * @param size The size of the buffer to allocate.\n     * @param actual_size A pointer to a 
variable to receive the actual size of\n     * the allocated buffer.\n     * @return A pointer to the allocated buffer.\n     */\n    void * alloc(size_t size, size_t * actual_size) override {\n        size = GGML_PAD(size, alignment);\n        if (size == 0) {\n            size = alignment;\n        }\n\n        void * ptr = nullptr;\n        auto   now = std::chrono::steady_clock::now();\n\n        int i = 0;\n        for (; i < MAX_BUFFERS; ++i) {\n            ggml_cann_buffer & b = buffer_pool[i];\n            if (b.ptr == nullptr) {\n                break;\n            }\n            if (b.used) {\n                continue;\n            }\n            if (b.size >= size) {\n                // reuse the buffer if the size is enough\n                const size_t margin = b.size - size;\n                if (margin <= max_reuse_margin) {\n                    *actual_size = b.size;\n                    b.used       = true;\n                    ptr          = b.ptr;\n#ifdef DEBUG_CANN_MALLOC\n                    GGML_LOG_INFO(\n                        \"cann pool[%d]: reused   %p, \"\n                        \"pool_size = %5u MB, \"\n                        \"size = %5u MB, \"\n                        \"margin = %5u MB\\n\",\n                        device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),\n                        (uint32_t) (GGML_PAD(size, 1048576) / 1048576),\n                        (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));\n#endif\n                    break;\n                }\n            }\n\n            bool should_clean = !disable_clean && b.size > min_free_margin &&\n                                std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;\n            if (should_clean) {\n                // free the buffer if the size is needed to be freed\n                ACL_CHECK(aclrtFree(b.ptr));\n                pool_size -= b.size;\n#ifdef DEBUG_CANN_MALLOC\n                
GGML_LOG_INFO(\n                    \"cann pool[%d]: clean    %p, \"\n                    \"pool_size = %5u MB, \"\n                    \"size = %5u MB\\n\",\n                    device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),\n                    (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));\n#endif\n                b.ptr = nullptr;\n            }\n        }\n        if (ptr != nullptr) {\n            return ptr;\n        }\n\n        if (i < MAX_BUFFERS) {\n            // allocate a new buffer if no buffer can be reused\n            ggml_cann_buffer & b = buffer_pool[i];\n            ggml_cann_set_device(device);\n            ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));\n            pool_size += size;\n            *actual_size = size;\n            b.size       = size;\n            b.used       = true;\n            if (i >= MAX_BUFFERS - 8) {\n                GGML_LOG_WARN(\"cann pool[%d]: slots almost full\\n\", device);\n            }\n#ifdef DEBUG_CANN_MALLOC\n            GGML_LOG_INFO(\n                \"cann pool[%d]: allocate %p, \"\n                \"pool_size = %5u MB, \"\n                \"size = %5u MB\\n\",\n                device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),\n                (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));\n#endif\n            return b.ptr;\n        }\n\n        GGML_ABORT(\"cann pool[%d]: slots full\\n\", device);\n    }\n\n    /**\n     * @brief Free a buffer and return it to the pool.\n     *\n     * @param ptr Pointer to the buffer to free.\n     * @param size Size of the buffer to free.\n     */\n    void free(void * ptr, size_t size) override {\n        GGML_UNUSED(size);\n        for (int i = 0; i < MAX_BUFFERS; ++i) {\n            ggml_cann_buffer & b = buffer_pool[i];\n            if (b.ptr != ptr) {\n                continue;\n            }\n            b.used      = false;\n            b.last_used = std::chrono::steady_clock::now();\n#ifdef 
DEBUG_CANN_MALLOC\n            GGML_LOG_INFO(\n                \"cann pool[%d]: return   %p, \"\n                \"pool_size = %5u MB\\n\",\n                device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));\n#endif\n            return;\n        }\n        GGML_ABORT(\"cann pool[%d]: slots full\\n\", device);\n    }\n};\n\n/**\n * @brief A pool of CANN buffers with virtual memory.\n *\n * This class manages a pool of CANN buffers with virtual memory for a specific\n * device.\n */\nstruct ggml_cann_pool_vmm : public ggml_cann_pool {\n    /**\n     * @brief The maximum size of the virtual memory pool (32 GB).\n     */\n    size_t max_size;\n\n    /**\n     * @brief The device ID associated with this buffer pool.\n     */\n    int device;\n\n    /**\n     * @brief Pointer to the start of the virtual memory pool.\n     */\n    void * pool_addr = 0;\n\n    /**\n     * @brief Amount of virtual memory used in the pool.\n     */\n    size_t pool_used = 0;\n\n    /**\n     * @brief Total size of the virtual memory pool.\n     */\n    size_t pool_size = 0;\n\n    /**\n     * @brief Allocation granularity for the virtual memory pool.\n     */\n    size_t granularity;\n\n    /**\n     * @brief Handles for the physical memory allocated.\n     */\n    std::vector<aclrtDrvMemHandle> handles;\n\n    /**\n     * @brief Offsets for the mapped memory regions.\n     */\n    std::vector<void *> map_offsets;\n\n    /**\n     * @brief Constructor to initialize the buffer pool with virtual memory for\n     * a specific device.\n     *\n     * @param device The device ID to associate with this buffer pool.\n     */\n    explicit ggml_cann_pool_vmm(int device) : device(device) {\n        auto dev    = ggml_cann_info().devices[device];\n        granularity = dev.vmm_granularity;\n        max_size    = dev.total_vram;\n    }\n\n    /**\n     * @brief Destructor to free all buffers in the virtual memory pool.\n     */\n    ~ggml_cann_pool_vmm() {\n        if (pool_addr != 0) 
{\n            for (auto & offset : map_offsets) {\n                ACL_CHECK(aclrtUnmapMem(offset));\n            }\n            for (auto & handle : handles) {\n                ACL_CHECK(aclrtFreePhysical(handle));\n            }\n            ACL_CHECK(aclrtReleaseMemAddress(pool_addr));\n        }\n    }\n\n    /**\n     * @brief Allocate a buffer of the given size in the virtual memory pool.\n     *\n     * @param size The size of the buffer to allocate.\n     * @param actual_size A pointer to a variable to receive the actual size of\n     * the allocated buffer.\n     * @return A pointer to the allocated buffer.\n     */\n    void * alloc(size_t size, size_t * actual_size) override {\n        // round up the allocation size to the alignment to ensure that all\n        // allocations are aligned for all data types\n        const size_t alignment = 128;\n        size                   = GGML_PAD(size, alignment);\n        if (size == 0) {\n            size = alignment;\n        }\n\n        size_t avail = pool_size - pool_used;\n\n        if (size > avail) {\n            // round up to the next multiple of the granularity\n            size_t reserve_size = size - avail;\n            reserve_size        = GGML_PAD(reserve_size, granularity);\n\n            GGML_ASSERT(pool_size + reserve_size <= max_size);\n\n            // allocate more physical memory\n            aclrtPhysicalMemProp prop = {};\n            prop.handleType           = ACL_MEM_HANDLE_TYPE_NONE;\n            prop.allocationType       = ACL_MEM_ALLOCATION_TYPE_PINNED;\n            prop.memAttr              = ACL_HBM_MEM_HUGE;\n            prop.location.type        = ACL_MEM_LOCATION_TYPE_DEVICE;\n            prop.location.id          = device;\n            prop.reserve              = 0;\n            aclrtDrvMemHandle handle;\n            ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));\n\n            // reserve virtual address space (if not already reserved)\n            if 
(pool_addr == 0) {\n                ACL_CHECK(aclrtReserveMemAddress(&pool_addr, max_size, 0, NULL, 1));\n            }\n\n            // map at the end of the pool\n            ACL_CHECK(aclrtMapMem((char *) pool_addr + pool_size, reserve_size, 0, handle, 0));\n\n            handles.push_back(handle);\n            map_offsets.push_back((char *) pool_addr + pool_size);\n\n            // add to the pool\n            pool_size += reserve_size;\n\n#ifdef DEBUG_CANN_MALLOC\n            GGML_LOG_INFO(\"cann pool[%d]: size increased to %llu MB (reserved %llu MB)\\n\", device,\n                          (unsigned long long) (pool_size / 1024 / 1024),\n                          (unsigned long long) (reserve_size / 1024 / 1024));\n#endif\n        }\n\n        GGML_ASSERT(pool_addr != 0);\n\n        void * ptr   = (void *) ((char *) pool_addr + pool_used);\n        *actual_size = size;\n        pool_used += size;\n\n#ifdef DEBUG_CANN_MALLOC\n        GGML_LOG_INFO(\"cann pool[%d]: allocated %llu bytes at %llx\\n\", device, (unsigned long long) size,\n                      (unsigned long long) ptr);\n#endif\n        return ptr;\n    }\n\n    /**\n     * @brief Free a buffer and return it to the virtual memory pool.\n     *\n     * @param ptr Pointer to the buffer to free.\n     * @param size Size of the buffer to free.\n     */\n    void free(void * ptr, size_t size) override {\n#ifdef DEBUG_CANN_MALLOC\n        GGML_LOG_INFO(\"cann pool[%d]: freed %llu bytes at %llx\\n\", device, (unsigned long long) size,\n                      (unsigned long long) ptr);\n#endif\n\n        pool_used -= size;\n\n        // all deallocations must be in reverse order of the allocations\n        GGML_ASSERT(ptr == (void *) ((char *) pool_addr + pool_used));\n    }\n};\n\n/**\n * @brief Create a new CANN pool for a specific device.\n *\n * Factory method to create a new CANN pool object based on the device type.\n *\n * @param device The device ID for which to create the pool.\n * @return A 
unique pointer to the created CANN pool.\n */\nstd::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {\n    std::string mem_pool_type = get_env_as_lowercase(\"GGML_CANN_MEM_POOL\").value_or(\"\");\n\n    if (mem_pool_type == \"prio\") {\n        GGML_LOG_INFO(\"%s: device %d use buffer pool with priority queue\\n\", __func__, device);\n        return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));\n    }\n\n    if (ggml_cann_info().devices[device].vmm && mem_pool_type != \"leg\") {\n        GGML_LOG_INFO(\"%s: device %d use vmm pool\\n\", __func__, device);\n        return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));\n    }\n\n    GGML_LOG_INFO(\"%s: device %d use buffer pool\\n\", __func__, device);\n    return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));\n}\n\n// cann buffer\n/**\n * @brief Context for managing a CANN buffer associated with a specific device.\n *\n * This structure holds information about a CANN buffer, including the device\n * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.\n */\nstruct ggml_backend_cann_buffer_context {\n    int32_t device;             ///< The device ID associated with this buffer context.\n    void *  dev_ptr = nullptr;  ///< Pointer to the device memory allocated for the buffer.\n\n    /**\n     * @brief Constructor to initialize the CANN buffer context.\n     *\n     * @param device The device ID associated with this buffer context.\n     * @param dev_ptr Pointer to the device memory allocated for the buffer.\n     */\n    ggml_backend_cann_buffer_context(int32_t device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}\n\n    /**\n     * @brief Destructor to free the device memory allocated for the buffer.\n     */\n    ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }\n};\n\n// cann buffer type\n/**\n * @brief Structure representing context information for a specific 
backend\n * buffer type.\n */\nstruct ggml_backend_cann_buffer_type_context {\n    int32_t     device; /**< Device identifier associated with the buffer context. */\n    std::string name;   /**< Name associated with the buffer context. */\n};\n\n/**\n * @brief Retrieves the name associated with a CANN buffer type.\n *\n * This function returns the descriptive name associated with the specified\n * CANN buffer type context.\n *\n * @param buft Pointer to the buffer type context.\n * @return Const pointer to the C-style string containing the name.\n */\nstatic const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;\n\n    return buft_ctx->name.c_str();\n}\n\n/**\n * @brief Checks if the backend buffer type is associated with the CANN backend.\n *\n * This function checks whether the provided backend buffer type is associated\n * with the CANN backend based on the comparison of its name retrieval function\n * pointer.\n *\n * @param buft Pointer to the backend buffer type to check.\n * @return bool Returns true if the buffer type is associated with the CANN\n * backend, otherwise false.\n */\nstatic bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {\n    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;\n}\n\n/**\n * @brief Free resources associated with a CANN buffer.\n *\n * This function frees the resources associated with a CANN buffer, including\n * its context.\n *\n * @param buffer The CANN buffer to free.\n */\nstatic void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;\n    delete ctx;\n}\n\n/**\n * @brief Retrieve the base pointer of a CANN buffer.\n *\n * This function returns the base pointer of a CANN buffer, which points to the\n * device memory allocated for the 
buffer.\n *\n * @param buffer The CANN buffer whose base pointer is to be retrieved.\n * @return A pointer to the base of the device memory allocated for the buffer.\n */\nstatic void * ggml_backend_cann_buffer_get_base(ggml_backend_buffer_t buffer) {\n    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;\n    return ctx->dev_ptr;\n}\n\n/**\n * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN\n * processing.\n *\n * This function transforms quantized Q4.0 tensor data into a format suitable\n * for CANN processing. It extracts quantization values and scales from the\n * source data and prepares them in a format expected by CANN operations.\n *\n * @param tensor Pointer to the tensor information.\n * @param src Pointer to the source data in Q4.0 format.\n * @param dst Pointer to the destination buffer where transformed data will be\n * stored.\n */\nstatic void ggml_backend_cann_transform_q4_0(ggml_tensor * tensor, const void * src, void * dst) {\n    int64_t n_elems     = ggml_nelements(tensor);\n    int64_t groups      = n_elems / QK4_0;\n    size_t  quant_bytes = n_elems * sizeof(uint8_t) / 2;\n\n    uint8_t *  quant_offset = (uint8_t *) dst;\n    uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);\n\n    for (int i = 0; i < groups; i++) {\n        const block_q4_0 * group = (const block_q4_0 *) ((const char *) src + i * sizeof(block_q4_0));\n        *scale_offset            = group->d;\n        scale_offset++;\n\n        // 0-15\n        for (int j = 0; j < QK4_0 / 2; j += 2) {\n            (*quant_offset) = (group->qs[j] & 0x0F);\n            (*quant_offset) |= ((group->qs[j + 1] << 4));\n            quant_offset++;\n        }\n\n        // 16-31\n        for (int j = 0; j < QK4_0 / 2; j += 2) {\n            (*quant_offset) = (group->qs[j] >> 4);\n            (*quant_offset) |= (group->qs[j + 1] & 0xF0);\n            quant_offset++;\n        }\n    }\n\n    // put (uint4b_t 
-8) into int4b_t\n    for (quant_offset = (uint8_t *) dst; quant_offset < (uint8_t *) dst + quant_bytes; quant_offset++) {\n        (*quant_offset) ^= 0x88;\n    }\n}\n\n/**\n * @brief Transform CANN processed data back into quantized Q4.0 format.\n *\n * This function transforms CANN processed data back into quantized Q4.0 format.\n * It reverses the transformation performed by\n * ggml_backend_cann_transform_q4_0(), converting the data back into its\n * original quantized form.\n *\n * @param tensor Pointer to the tensor information.\n * @param src Pointer to the source buffer containing transformed data.\n * @param dst Pointer to the destination buffer where the Q4.0 formatted data\n * will be stored.\n */\nstatic void ggml_backend_cann_transform_back_q4_0(const ggml_tensor * tensor, void * src, void * dst) {\n    int64_t n_elems     = ggml_nelements(tensor);\n    int64_t groups      = n_elems / QK4_0;\n    size_t  quant_bytes = n_elems * sizeof(uint8_t) / 2;\n\n    uint8_t *  quant_offset = (uint8_t *) src;\n    uint16_t * scale_offset = (uint16_t *) ((char *) src + quant_bytes);\n\n    for (; quant_offset < (uint8_t *) src + quant_bytes; quant_offset++) {\n        (*quant_offset) ^= 0x88;\n    }\n    quant_offset = (uint8_t *) src;\n\n    for (int i = 0; i < groups; i++) {\n        block_q4_0 * group = (block_q4_0 *) ((char *) dst + i * sizeof(block_q4_0));\n        group->d           = *scale_offset;\n        scale_offset++;\n\n        // 0-15\n        for (int j = 0; j < QK4_0 / 2; j += 2) {\n            group->qs[j]     = ((*quant_offset) & 0x0F);\n            group->qs[j + 1] = ((*quant_offset) >> 4);\n            quant_offset++;\n        }\n\n        // 16-31\n        for (int j = 0; j < QK4_0 / 2; j += 2) {\n            group->qs[j] |= ((*quant_offset) << 4);\n            group->qs[j + 1] |= ((*quant_offset) & 0xF0);\n            quant_offset++;\n        }\n    }\n}\n\n/**\n * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN\n 
* processing.\n *\n * This function transforms quantized Q8.0 tensor data into a format suitable\n * for CANN processing. It extracts quantization values and scales from the\n * source data and prepares them in a format expected by CANN operations.\n *\n * @param tensor Pointer to the tensor information.\n * @param src Pointer to the source data in Q8.0 format.\n * @param dst Pointer to the destination buffer where transformed data will be\n * stored.\n */\nstatic void ggml_backend_cann_transform_q8_0(ggml_tensor * tensor, const void * src, void * dst) {\n    int64_t n_elems     = ggml_nelements(tensor);\n    int64_t groups      = n_elems / QK8_0;\n    size_t  quant_bytes = n_elems * sizeof(uint8_t);\n\n    uint8_t *  quant_offset = (uint8_t *) dst;\n    uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);\n\n    for (int i = 0; i < groups; i++) {\n        const block_q8_0 * group = (const block_q8_0 *) ((const char *) src + i * sizeof(block_q8_0));\n        *scale_offset            = group->d;\n        scale_offset++;\n        size_t group_quant_size = QK8_0 * sizeof(uint8_t);\n        memcpy(quant_offset, group->qs, group_quant_size);\n        quant_offset += group_quant_size;\n    }\n}\n\n/**\n * @brief Transform CANN processed data back into quantized Q8.0 format.\n *\n * This function transforms CANN processed data back into quantized Q8.0 format.\n * It reverses the transformation performed by\n * ggml_backend_cann_transform_q8_0(), converting the data back into its\n * original quantized form.\n *\n * @param tensor Pointer to the tensor information.\n * @param src Pointer to the source buffer containing transformed data.\n * @param dst Pointer to the destination buffer where the Q8.0 formatted data\n * will be stored.\n */\nstatic void ggml_backend_cann_transform_back_q8_0(const ggml_tensor * tensor, const void * src, void * dst) {\n    int64_t n_elems     = ggml_nelements(tensor);\n    int64_t groups      = n_elems / QK8_0;\n    size_t  
quant_bytes = n_elems * sizeof(uint8_t);\n\n    const uint8_t *  quant_offset = (const uint8_t *) src;\n    const uint16_t * scale_offset = (const uint16_t *) ((const char *) src + quant_bytes);\n\n    for (int i = 0; i < groups; i++) {\n        block_q8_0 * group = (block_q8_0 *) ((char *) dst + i * sizeof(block_q8_0));\n        group->d           = *scale_offset;\n        scale_offset++;\n        size_t group_quant_size = QK8_0 * sizeof(uint8_t);\n        memcpy(group->qs, quant_offset, group_quant_size);\n        quant_offset += group_quant_size;\n    }\n}\n\n/**\n * @brief Transform tensor data based on its type for CANN processing.\n *\n * This function transforms tensor data based on its quantization type for CANN\n * processing. It dispatches the transformation based on the tensor's type to\n * specialized functions handling Q4.0 and Q8.0 formats.\n *\n * @param tensor Pointer to the tensor information.\n * @param src Pointer to the source data to be transformed.\n * @param dst Pointer to the destination buffer where transformed data will be\n * stored.\n */\nstatic void ggml_backend_cann_transform(ggml_tensor * tensor, const void * src, void * dst) {\n    switch (tensor->type) {\n        case GGML_TYPE_Q4_0:\n            ggml_backend_cann_transform_q4_0(tensor, src, dst);\n            break;\n        case GGML_TYPE_Q8_0:\n            ggml_backend_cann_transform_q8_0(tensor, src, dst);\n            break;\n        default:\n            break;\n    }\n}\n\n/**\n * @brief Transform CANN processed data back into tensor data based on its type.\n *\n * This function transforms CANN processed data back into tensor data based on\n * its quantization type for Q4.0 and Q8.0 formats. 
It dispatches the\n * transformation based on the tensor's type to specialized functions.\n *\n * @param tensor Pointer to the tensor information.\n * @param src Pointer to the source data containing CANN processed data.\n * @param dst Pointer to the destination buffer where transformed tensor data\n * will be stored.\n */\nstatic void ggml_backend_cann_transform_back(const ggml_tensor * tensor, void * src, void * dst) {\n    switch (tensor->type) {\n        case GGML_TYPE_Q4_0:\n            ggml_backend_cann_transform_back_q4_0(tensor, src, dst);\n            break;\n        case GGML_TYPE_Q8_0:\n            ggml_backend_cann_transform_back_q8_0(tensor, src, dst);\n            break;\n        default:\n            break;\n    }\n}\n\n/**\n * @brief Check if transformation is needed for a given tensor type.\n *\n * This function checks if transformation is needed for a given tensor type\n * to prepare data for CANN processing.\n *\n * @param type The tensor type to check.\n * @return true if transformation is needed, false otherwise.\n */\nstatic bool need_transform(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q8_0:\n            return true;\n        default:\n            return false;\n    }\n}\n\n/**\n * @brief Initialize a tensor using data from a CANN buffer.\n *\n * This function initializes a tensor using data from a CANN buffer.\n * It handles special cases such as views and quantization.\n *\n * @param buffer The CANN buffer from which to initialize the tensor.\n * @param tensor Pointer to the tensor to be initialized.\n */\nstatic enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    if (tensor->view_src != NULL && tensor->view_offs == 0) {\n        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);\n        return GGML_STATUS_SUCCESS;\n    }\n\n    // TODO: cann backend doesn't support quantized yet. 
Just leave the code\n    // here.\n    if (ggml_is_quantized(tensor->type)) {\n        // Initialize padding to 0 to avoid possible NaN values\n        size_t original_size = ggml_nbytes(tensor);\n        size_t padded_size   = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);\n\n        if (padded_size > original_size && tensor->view_src == nullptr) {\n            size_t memset_size = padded_size - original_size;\n            ACL_CHECK(aclrtMemset((char *) tensor->data + original_size, memset_size, 0, memset_size));\n        }\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\n/**\n * @brief Workspace for caching NZ buffers per device.\n *\n * This struct manages a device buffer used in NZ computations. It supports\n * allocation, reallocation, and clearing of cached memory. The struct is\n * designed to be used with a global array, one per device.\n */\nstruct ggml_cann_nz_workspace {\n    void * ptr;        // Pointer to allocated device buffer\n    size_t allocated;  // Size of currently allocated buffer in bytes\n\n    /**\n     * @brief Constructor. 
Initializes the workspace with no allocated memory.\n     */\n    ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}\n\n    /**\n     * @brief Free cached memory and reset the workspace.\n     *\n     * If a buffer has been allocated, this function releases it using\n     * aclrtFree and resets internal state.\n     */\n    void clear() {\n        if (ptr) {\n            ACL_CHECK(aclrtFree(ptr));\n            ptr       = nullptr;\n            allocated = 0;\n        }\n    }\n\n    /**\n     * @brief Allocate or reallocate the workspace buffer.\n     *\n     * If the requested size is larger than the currently allocated size,\n     * the old buffer will be freed and a new buffer of the requested size\n     * will be allocated on the device.\n     *\n     * @param new_size Size in bytes to allocate for the workspace.\n     */\n    void realloc(size_t new_size) {\n        if (new_size > allocated) {\n            clear();\n            ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));\n            allocated = new_size;\n        }\n    }\n\n    /**\n     * @brief Get the device buffer pointer.\n     *\n     * @return Pointer to the allocated buffer, or nullptr if not allocated.\n     */\n    void * get() const { return ptr; }\n};\n\n/**\n * @brief Global array of NZ workspaces, one per device.\n */\nstatic ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];\n\n/**\n * @brief Convert tensor weights to NZ format using Ascend CANN API.\n *\n * This function creates a transposed tensor descriptor and performs the\n * TransMatmulWeight operation. Converting tensor formats can significantly\n * improve performance on certain hardware.\n *\n * @param tensor Pointer to the input ggml_tensor containing the weights.\n * @param offset Byte offset within the tensor data buffer where weights start.\n * @param device device id.\n *\n * @note The workspace buffer used in this function is managed globally and reused\n *       across calls. 
This reduces overhead from repeated memory allocation and deallocation.\n */\nstatic void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {\n    acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);\n    uint64_t       workspaceSize    = 0;\n    aclOpExecutor * executor;\n\n    // TransMatmulWeight\n    ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));\n    // Avoid frequent malloc/free of the workspace.\n    g_nz_workspaces[device].realloc(workspaceSize);\n\n    void * g_nz_workspace = g_nz_workspaces[device].get();\n\n    ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));\n}\n\n// TODO: need handle tensor which has paddings.\n/**\n * @brief Set tensor data in a CANN buffer.\n *\n * This function sets tensor data in a CANN buffer, handling transformations\n * if needed based on the tensor's type.\n *\n * @param buffer The CANN buffer where the tensor data will be set.\n * @param tensor Pointer to the tensor whose data will be set.\n * @param data Pointer to the source data to be copied into the tensor.\n * @param offset Offset in the source data from where to start copying.\n * @param size Size of the data to be copied, in bytes.\n */\nstatic void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,\n                                                ggml_tensor *         tensor,\n                                                const void *          data,\n                                                size_t                offset,\n                                                size_t                size) {\n    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;\n\n    ggml_cann_set_device(ctx->device);\n    // TODO: refer to cann(#6017), it use thread's default stream.\n    // For acl, synchronous functions use this default stream.\n    // 
Why aclrtSynchronizeDevice?\n\n    // Only check env once.\n    static bool weight_to_nz = parse_bool(get_env_as_lowercase(\"GGML_CANN_WEIGHT_NZ\").value_or(\"on\"));\n    if (!need_transform(tensor->type)) {\n        ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));\n        if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {\n            GGML_ASSERT(tensor->ne[2] == 1);\n            GGML_ASSERT(tensor->ne[3] == 1);\n            weight_format_to_nz(tensor, offset, ctx->device);\n        }\n    } else {\n        void * transform_buffer = malloc(size);\n        ggml_backend_cann_transform(tensor, data, transform_buffer);\n\n        ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));\n        free(transform_buffer);\n    }\n}\n\n/**\n * @brief Get tensor data from a CANN buffer.\n *\n * This function retrieves tensor data from a CANN buffer, handling\n * transformations if needed based on the tensor's type.\n *\n * @param buffer The CANN buffer from which to retrieve tensor data.\n * @param tensor Pointer to the tensor whose data will be retrieved.\n * @param data Pointer to the destination buffer where the tensor data will be\n * copied.\n * @param offset Offset in the destination buffer where to start copying.\n * @param size Size of the data to be copied, in bytes.\n */\nstatic void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,\n                                                const ggml_tensor *   tensor,\n                                                void *                data,\n                                                size_t                offset,\n                                                size_t                size) {\n    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;\n\n    ggml_cann_set_device(ctx->device);\n\n    if (!need_transform(tensor->type)) 
{\n        ACL_CHECK(aclrtMemcpy(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));\n    } else {\n        void * transform_buffer = malloc(size);\n        ACL_CHECK(aclrtMemcpy(transform_buffer, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));\n        ggml_backend_cann_transform_back(tensor, transform_buffer, data);\n        free(transform_buffer);\n    }\n}\n\n/**\n * @brief Copy tensor data between CANN buffers if possible.\n *\n * This function copies tensor data between CANN buffers if the source and\n * destination buffers are CANN buffers and they meet the necessary conditions\n * (same device or devices can access each other).\n *\n * @param buffer The destination CANN buffer where the tensor data will be\n * copied.\n * @param src Pointer to the source tensor whose data will be copied.\n * @param dst Pointer to the destination tensor where the data will be copied.\n * @return true if the copy operation succeeded, false otherwise.\n */\nstatic bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,\n                                                const ggml_tensor *   src,\n                                                ggml_tensor *         dst) {\n    if (ggml_backend_buft_is_cann(src->buffer->buft)) {\n        ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;\n        ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;\n\n        size_t memcpy_size = ggml_nbytes(src);\n        // Same device.\n        if (src_ctx->device == dst_ctx->device) {\n            ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,\n                                  ACL_MEMCPY_DEVICE_TO_DEVICE));\n            return true;\n        } else {\n#ifdef ASCEND_310P\n            // TODO: Support 310p P2P copy\n            return false;\n#endif\n            // Different device 
but can access by peer.\n            int32_t canAccessPeer = 0;\n            ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, dst_ctx->device));\n            if (canAccessPeer) {\n                ggml_cann_set_device(src_ctx->device);\n                ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));\n                ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,\n                                      ACL_MEMCPY_DEVICE_TO_DEVICE));\n                return true;\n            }\n        }\n    }\n    return false;\n}\n\n/**\n * @brief Clear a CANN buffer by setting all its memory to a specified value.\n *\n * This function clears a CANN buffer by setting all its memory to a specified\n * value.\n *\n * @param buffer The CANN buffer to be cleared.\n * @param value The value to which each byte in the buffer will be set.\n */\nstatic void ggml_backend_cann_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;\n\n    ggml_cann_set_device(ctx->device);\n    ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));\n}\n\n/**\n * @brief Interface for a CANN buffer in the backend.\n *\n * This structure defines function pointers to operations that can be performed\n * on a CANN buffer within the backend.\n */\nstatic const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_cann_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_cann_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_cann_buffer_init_tensor,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ ggml_backend_cann_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_cann_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_cann_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_cann_buffer_clear,\n    /* 
.reset           = */ NULL,\n};\n\n/**\n * @brief Allocates a new CANN buffer of the specified type and size.\n *\n * This function allocates a new CANN buffer on the specified device with the\n * given size.\n *\n * @param buft Pointer to the buffer type context.\n * @param size Size in bytes of the buffer to allocate.\n * @return Pointer to the allocated buffer, or nullptr if allocation fails.\n */\nstatic ggml_backend_buffer_t ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;\n\n    ggml_cann_set_device(buft_ctx->device);\n\n    const size_t alignment = 128;\n    size                   = GGML_PAD(size, alignment);\n    if (size == 0) {\n        size = alignment;\n    }\n    void *   dev_ptr;\n    aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);\n    if (err != ACL_SUCCESS) {\n        GGML_LOG_ERROR(\"%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\\n\", __func__,\n                       size / 1024.0 / 1024.0, buft_ctx->device, aclGetRecentErrMsg());\n        return nullptr;\n    }\n\n    ggml_backend_cann_buffer_context * ctx = new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);\n\n    return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, ctx, size);\n}\n\n/**\n * @brief Retrieves the memory alignment requirement for CANN buffers of this\n * type.\n *\n * This function returns the alignment requirement in bytes for memory allocated\n * by the CANN buffer type.\n *\n * @param buft Pointer to the buffer type context (unused in this\n * implementation).\n * @return The alignment requirement in bytes (fixed at 128 bytes for CANN\n * buffers).\n */\nstatic size_t ggml_backend_cann_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 128;\n\n    GGML_UNUSED(buft);\n}\n\n/**\n * @brief Calculates the allocation size required for a 
tensor in a CANN buffer.\n *\n * Computes the total allocation size needed for storing the tensor's data in a\n * CANN buffer, considering any necessary padding or adjustments for quantized\n * types.\n *\n * @param buft Pointer to the buffer type context (unused in this\n * implementation).\n * @param tensor Pointer to the tensor for which the allocation size is\n * calculated.\n * @return The total allocation size in bytes required for the tensor in the\n * CANN buffer.\n */\nstatic size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,\n                                                           const ggml_tensor *        tensor) {\n    size_t  size = ggml_nbytes(tensor);\n    int64_t ne0  = tensor->ne[0];\n\n    // Only check env once.\n    static bool weight_to_nz = parse_bool(get_env_as_lowercase(\"GGML_CANN_WEIGHT_NZ\").value_or(\"on\"));\n\n    // last line must bigger than 32, because every single op deal at\n    // least 32 bytes.\n    // TODO: quantized type?\n    // int64_t line_size = ne0 * ggml_element_size(tensor);\n    // int64_t line_size_align_32 = (line_size + 31) & ~31;\n    // size += (line_size_align_32 - line_size);\n    if (ggml_is_quantized(tensor->type)) {\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n    } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {\n        // NZ format weight are not support quantized yet.\n        // If ND tensor transform to NZ, size may changed.\n        int64_t shape[] = { tensor->ne[1], tensor->ne[0] };\n        GGML_ASSERT(tensor->ne[2] == 1);\n        GGML_ASSERT(tensor->ne[3] == 1);\n        const aclIntArray * acl_shape = aclCreateIntArray(shape, 2);\n        size_t              new_size;\n        ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape, ggml_cann_type_mapping(tensor->type), &new_size));\n        
ACL_CHECK(aclDestroyIntArray(acl_shape));\n        size = std::max(size, new_size);\n    }\n\n    return size;\n\n    GGML_UNUSED(buft);\n}\n\nstatic bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {\n    return false;\n\n    GGML_UNUSED(buft);\n}\n\n/**\n * @brief Interface for managing CANN buffer types in the GGML backend.\n *\n * Provides function pointers for allocating, querying properties, and managing\n * memory for CANN buffer types in the GGML backend.\n */\nstatic const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_cann_buffer_type_name,\n    /* .alloc_buffer     = */ ggml_backend_cann_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_cann_buffer_type_get_alignment,\n    /* .get_max_size     = */ NULL,  // defaults to SIZE_MAX\n    /* .get_alloc_size   = */ ggml_backend_cann_buffer_type_get_alloc_size,\n    /* .is_host          = */ ggml_backend_cann_buffer_type_is_host,\n};\n\n/**\n * @brief Retrieves the CANN buffer type for a specified device.\n *\n * This function initializes and returns the buffer type interface associated\n * with the given device. 
It ensures thread-safe access using a mutex.\n *\n * @param device The device index for which to retrieve the buffer type.\n * @return A pointer to the buffer type interface for the specified device, or\n * nullptr if the device index is out of range.\n */\nggml_backend_buffer_type_t ggml_backend_cann_buffer_type(int32_t device) {\n    static std::mutex           mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    if (device >= ggml_backend_cann_get_device_count()) {\n        return nullptr;\n    }\n\n    static ggml_backend_buffer_type ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];\n\n    static bool ggml_backend_cann_buffer_type_initialized = false;\n\n    if (!ggml_backend_cann_buffer_type_initialized) {\n        for (int32_t i = 0; i < ggml_cann_info().device_count; i++) {\n            ggml_backend_cann_buffer_types[i] = {\n                /* .iface    = */ ggml_backend_cann_buffer_type_interface,\n                /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),\n                /* .context  = */\n                new ggml_backend_cann_buffer_type_context{ i, \"CANN\" + std::to_string(i) },\n            };\n        }\n        ggml_backend_cann_buffer_type_initialized = true;\n    }\n\n    return &ggml_backend_cann_buffer_types[device];\n}\n\n/**\n * @brief Retrieves the name associated with a CANN host buffer type.\n *\n * This function returns the descriptive name associated with the specified\n * CANN host buffer type context.\n *\n * @param buft Pointer to the host buffer type context.\n * @return Const pointer to the C-style string containing the name.\n */\nstatic const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {\n    return \"CANN_Host\";\n\n    GGML_UNUSED(buft);\n}\n\n/**\n * @brief Retrieves the name associated with a CANN host buffer.\n *\n * This function returns the descriptive name associated with the specified\n * CANN host buffer context.\n *\n * @param buft Pointer to 
the host buffer context.\n * @return Const pointer to the C-style string containing the name.\n */\nstatic const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {\n    return \"CANN_Host\";\n\n    GGML_UNUSED(buffer);\n}\n\n/**\n * @brief Free resources associated with a CANN host buffer.\n *\n * This function frees the resources associated with a CANN host buffer, including\n * its context.\n *\n * @param buffer The CANN host buffer to free.\n */\nstatic void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {\n    ACL_CHECK(aclrtFreeHost(buffer->context));\n}\n\n/**\n * @brief Allocates a new CANN host buffer of the specified size.\n *\n * This function allocates a new CANN host buffer with the given size.\n * @param size Size in bytes of the host buffer to allocate.\n * @return Pointer to the allocated host buffer, or nullptr if allocation fails.\n */\nstatic void * ggml_cann_host_malloc(size_t size) {\n    if (getenv(\"GGML_CANN_NO_PINNED\") != nullptr) {\n        return nullptr;\n    }\n\n    const size_t alignment = 128;\n    size                   = GGML_PAD(size, alignment);\n    if (size == 0) {\n        size = alignment;\n    }\n\n    void *   hostPtr = nullptr;\n    aclError err     = aclrtMallocHost((void **) &hostPtr, size);\n    if (err != ACL_SUCCESS) {\n        GGML_LOG_WARN(\"%s: failed to allocate %.2f MiB of pinned memory: %s\\n\", __func__, size / 1024.0 / 1024.0,\n                      aclGetRecentErrMsg());\n        return nullptr;\n    }\n    return hostPtr;\n}\n\n/**\n * @brief Allocates a new CANN host buffer of the specified type and size.\n *\n * @param buft Pointer to the host buffer type context.\n * @param size Size in bytes of the host buffer to allocate.\n * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.\n */\nstatic ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,\n                                    
                                         size_t                     size) {\n    void * hostPtr = ggml_cann_host_malloc(size);\n\n    if (hostPtr == nullptr) {\n        // fallback to cpu buffer\n        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);\n    }\n\n    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);\n    buffer->buft                 = buft;\n    buffer->iface.free_buffer    = ggml_backend_cann_host_buffer_free;\n\n    return buffer;\n}\n\n/**\n * @brief Interface for managing CANN host buffer types in the GGML backend.\n *\n * Provides function pointers for allocating, querying properties, and managing\n * memory for CANN buffer types in the GGML backend.\n */\nggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {\n    static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {\n        /* .iface    = */ {\n                           /* .get_name         = */ ggml_backend_cann_host_buffer_type_name,\n                           /* .alloc_buffer     = */ ggml_backend_cann_host_buffer_type_alloc_buffer,\n                           /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,\n                           /* .get_max_size     = */ NULL,  // defaults to SIZE_MAX\n            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,\n                           /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,\n                           },\n        /* .device   = */\n        ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),\n        /* .context  = */ nullptr,\n    };\n\n    return &ggml_backend_cann_buffer_type_host;\n}\n\n/**\n * @brief Computes the forward operation for a given tensor using CANN\n * operations.\n *\n * This function selects the appropriate CANN operation based on the type of\n * operation specified in the tensor and performs the computation.\n *\n * @param ctx The CANN 
context containing necessary resources and\n * configurations.\n * @param dst The destination tensor where the result of the computation will be\n * stored.\n * @return true if the computation was successful; false otherwise.\n */\nstatic bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct ggml_tensor * dst) {\n    switch (dst->op) {\n        case GGML_OP_REPEAT:\n            ggml_cann_repeat(ctx, dst);\n            break;\n        case GGML_OP_GET_ROWS:\n            ggml_cann_get_rows(ctx, dst);\n            break;\n        case GGML_OP_SET_ROWS:\n            ggml_cann_set_rows(ctx, dst);\n            break;\n        case GGML_OP_DUP:\n            ggml_cann_dup(ctx, dst);\n            break;\n        case GGML_OP_ADD:\n        case GGML_OP_ADD1:\n            ggml_cann_binary_op<aclnn_add>(ctx, dst);\n            break;\n        case GGML_OP_SUB:\n            ggml_cann_binary_op<aclnn_sub>(ctx, dst);\n            break;\n        case GGML_OP_ACC:\n            ggml_cann_acc(ctx, dst);\n            break;\n        case GGML_OP_MUL:\n            ggml_cann_binary_op<aclnn_mul>(ctx, dst);\n            break;\n        case GGML_OP_DIV:\n            ggml_cann_binary_op<aclnn_div>(ctx, dst);\n            break;\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(dst)) {\n                case GGML_UNARY_OP_ABS:\n                    GGML_CANN_CALL_OP_UNARY(Abs);\n                    break;\n                case GGML_UNARY_OP_NEG:\n                    GGML_CANN_CALL_OP_UNARY(Neg);\n                    break;\n                case GGML_UNARY_OP_GELU:\n                case GGML_UNARY_OP_GELU_ERF:\n                    // aclnnGelu internally uses the erf-based approximation.\n                    GGML_CANN_CALL_OP_UNARY(Gelu);\n                    break;\n                case GGML_UNARY_OP_SILU:\n                    GGML_CANN_CALL_OP_UNARY(Silu);\n                    break;\n                case GGML_UNARY_OP_GELU_QUICK:\n                    {\n 
                       auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {\n                            GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);\n                        };\n                        ggml_cann_op_unary(lambda, ctx, dst);\n                    }\n                    break;\n                case GGML_UNARY_OP_TANH:\n                    GGML_CANN_CALL_OP_UNARY(Tanh);\n                    break;\n                case GGML_UNARY_OP_RELU:\n                    GGML_CANN_CALL_OP_UNARY(Relu);\n                    break;\n                case GGML_UNARY_OP_SIGMOID:\n                    GGML_CANN_CALL_OP_UNARY(Sigmoid);\n                    break;\n                case GGML_UNARY_OP_HARDSIGMOID:\n                    GGML_CANN_CALL_OP_UNARY(Hardsigmoid);\n                    break;\n                case GGML_UNARY_OP_HARDSWISH:\n                    GGML_CANN_CALL_OP_UNARY(Hardswish);\n                    break;\n                case GGML_UNARY_OP_EXP:\n                    GGML_CANN_CALL_OP_UNARY(Exp);\n                    break;\n                case GGML_UNARY_OP_ELU:\n                    ggml_cann_elu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SGN:\n                    GGML_CANN_CALL_OP_UNARY(Sign);\n                    break;\n                case GGML_UNARY_OP_STEP:\n                    ggml_cann_step(ctx, dst);\n                    break;\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(dst)) {\n                case GGML_GLU_OP_REGLU:\n                    GGML_CANN_CALL_OP_UNARY_GATED(Relu);\n                    break;\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_GEGLU_ERF:\n                    // aclnnGelu internally uses the erf-based approximation.\n                    GGML_CANN_CALL_OP_UNARY_GATED(Gelu);\n                    
break;\n                case GGML_GLU_OP_SWIGLU:\n                    GGML_CANN_CALL_OP_UNARY_GATED(Silu);\n                    break;\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    {\n                        auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {\n                            GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);\n                        };\n                        ggml_cann_op_unary_gated(lambda, ctx, dst);\n                    }\n                    break;\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_NORM:\n            ggml_cann_norm(ctx, dst);\n            break;\n        case GGML_OP_GROUP_NORM:\n            ggml_cann_group_norm(ctx, dst);\n            break;\n        case GGML_OP_L2_NORM:\n            ggml_cann_l2_norm(ctx, dst);\n            break;\n        case GGML_OP_CROSS_ENTROPY_LOSS:\n            ggml_cann_cross_entropy_loss(ctx, dst);\n            break;\n        case GGML_OP_CONCAT:\n            ggml_cann_concat(ctx, dst);\n            break;\n        case GGML_OP_UPSCALE:\n            ggml_cann_upsample_nearest2d(ctx, dst);\n            break;\n        case GGML_OP_PAD:\n            ggml_cann_pad(ctx, dst);\n            break;\n        case GGML_OP_ARANGE:\n            ggml_cann_arange(ctx, dst);\n            break;\n        case GGML_OP_TIMESTEP_EMBEDDING:\n            ggml_cann_timestep_embedding(ctx, dst);\n            break;\n        case GGML_OP_LEAKY_RELU:\n            ggml_cann_leaky_relu(ctx, dst);\n            break;\n        case GGML_OP_RMS_NORM:\n            ggml_cann_rms_norm(ctx, dst);\n            break;\n        case GGML_OP_MUL_MAT:\n            ggml_cann_mul_mat(ctx, dst);\n            break;\n        case GGML_OP_MUL_MAT_ID:\n            ggml_cann_mul_mat_id(ctx, dst);\n            break;\n        case GGML_OP_SCALE:\n            ggml_cann_scale(ctx, dst);\n        
    break;\n        case GGML_OP_SQR:\n            GGML_ASSERT(dst->src[1] == nullptr);\n            dst->src[1] = dst->src[0];\n            ggml_cann_binary_op<aclnn_mul>(ctx, dst);\n            break;\n        case GGML_OP_SQRT:\n            GGML_CANN_CALL_OP_UNARY(Sqrt);\n            break;\n        case GGML_OP_CLAMP:\n            ggml_cann_clamp(ctx, dst);\n            break;\n        case GGML_OP_CPY:\n            ggml_cann_cpy(ctx, dst);\n            break;\n        case GGML_OP_CONT:\n            ggml_cann_dup(ctx, dst);\n            break;\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n            break;\n        case GGML_OP_DIAG_MASK_INF:\n            ggml_cann_diag_mask(ctx, dst, -INFINITY);\n            break;\n        case GGML_OP_SOFT_MAX:\n            ggml_cann_softmax(ctx, dst);\n            break;\n        case GGML_OP_ROPE:\n            ggml_cann_rope(ctx, dst);\n            break;\n        case GGML_OP_IM2COL:\n            ggml_cann_im2col(ctx, dst);\n            break;\n        case GGML_OP_POOL_2D:\n            ggml_cann_pool2d(ctx, dst);\n            break;\n        case GGML_OP_SUM:\n            ggml_cann_sum(ctx, dst);\n            break;\n        case GGML_OP_SUM_ROWS:\n            ggml_cann_sum_rows(ctx, dst);\n            break;\n        case GGML_OP_ARGSORT:\n            ggml_cann_argsort(ctx, dst);\n            break;\n        case GGML_OP_ARGMAX:\n            ggml_cann_argmax(ctx, dst);\n            break;\n        case GGML_OP_COS:\n            ggml_cann_op_unary<aclnn_cos>(ctx, dst);\n            break;\n        case GGML_OP_SIN:\n            ggml_cann_op_unary<aclnn_sin>(ctx, dst);\n            break;\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            ggml_cann_conv_transpose_1d(ctx, dst);\n            break;\n        case GGML_OP_LOG:\n            GGML_CANN_CALL_OP_UNARY(Log);\n            break;\n        case 
GGML_OP_MEAN:\n            ggml_cann_mean(ctx, dst);\n            break;\n        case GGML_OP_PAD_REFLECT_1D:\n            ggml_cann_pad_reflect_1d(ctx, dst);\n            break;\n        case GGML_OP_COUNT_EQUAL:\n            ggml_cann_count_equal(ctx, dst);\n            break;\n        case GGML_OP_FLASH_ATTN_EXT:\n            ggml_cann_flash_attn_ext(ctx, dst);\n            break;\n        case GGML_OP_OUT_PROD:\n            ggml_cann_out_prod(ctx, dst);\n            break;\n        case GGML_OP_GATED_LINEAR_ATTN:\n            ggml_cann_gated_linear_attn(ctx, dst);\n            break;\n        case GGML_OP_SSM_CONV:\n            ggml_cann_ssm_conv(ctx, dst);\n            break;\n        default:\n            return false;\n    }\n\n    return true;\n}\n\n// backend\n/**\n * @brief Retrieves the name associated with the CANN backend.\n *\n * This function returns the name assigned to the CANN backend, which is stored\n * in the context of the provided backend structure.\n *\n * @param backend Pointer to the CANN backend structure.\n * @return A pointer to a constant string representing the backend name.\n */\nstatic const char * ggml_backend_cann_name(ggml_backend_t backend) {\n    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;\n\n    return cann_ctx->name.c_str();\n}\n\n/**\n * @brief Frees resources associated with the CANN backend.\n *\n * This function releases resources associated with the CANN backend context\n * and resets the device associated with the backend to its initial state.\n *\n * @param backend Pointer to the CANN backend structure to be freed.\n */\nstatic void ggml_backend_cann_free(ggml_backend_t backend) {\n    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;\n    ACL_CHECK(aclrtSynchronizeDevice());\n    ACL_CHECK(aclrtResetDevice(cann_ctx->device));\n\n    delete cann_ctx;\n    delete backend;\n}\n\n/**\n * @brief Sets tensor data asynchronously in the CANN 
backend.\n *\n * This function asynchronously sets tensor data in the CANN backend.\n *\n * @param backend Pointer to the CANN backend structure.\n * @param tensor Pointer to the tensor structure to set data for.\n * @param data Pointer to the host data to copy to the tensor.\n * @param offset Byte offset into the tensor's device data at which writing starts.\n * @param size Size of the data to copy in bytes.\n */\nstatic void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,\n                                               ggml_tensor *  tensor,\n                                               const void *   data,\n                                               size_t         offset,\n                                               size_t         size) {\n    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;\n    ggml_backend_buffer_t       buf      = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;\n\n    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && \"unsupported buffer type\");\n    GGML_ASSERT(!ggml_is_quantized(tensor->type));\n\n    ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,\n                               cann_ctx->stream()));\n}\n\n/**\n * @brief Gets tensor data asynchronously in the CANN backend.\n *\n * This function asynchronously gets tensor data in the CANN backend.\n *\n * @param backend Pointer to the CANN backend structure.\n * @param tensor Pointer to the tensor structure to get data from.\n * @param data Pointer to the host buffer that receives the tensor data.\n * @param offset Byte offset into the tensor's device data at which reading starts.\n * @param size Size of the data to copy in bytes.\n */\nstatic void ggml_backend_cann_get_tensor_async(ggml_backend_t      backend,\n                                               const ggml_tensor * tensor,\n                                               void *              data,\n                                            
   size_t              offset,\n                                               size_t              size) {\n    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;\n    ggml_backend_buffer_t       buf      = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;\n\n    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && \"unsupported buffer type\");\n    GGML_ASSERT(!ggml_is_quantized(tensor->type));\n\n    ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,\n                               cann_ctx->stream()));\n}\n\n/**\n * @brief Asynchronously copies tensor data between CANN backends.\n *\n * This function copies tensor data asynchronously between two CANN backends. It\n * checks if both tensors reside in CANN buffers and whether the devices support\n * peer-to-peer access for direct copying. If not, it returns false.\n *\n * @param backend_src Pointer to the source CANN backend structure.\n * @param backend_dst Pointer to the destination CANN backend structure.\n * @param src Pointer to the source tensor to copy data from.\n * @param dst Pointer to the destination tensor to copy data to.\n * @return true if the copy operation succeeds, false otherwise.\n */\nstatic bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t      backend_src,\n                                               ggml_backend_t      backend_dst,\n                                               const ggml_tensor * src,\n                                               ggml_tensor *       dst) {\n    GGML_ASSERT(ggml_backend_is_cann(backend_src) || ggml_backend_is_cann(backend_dst));\n\n    GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));\n\n    if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {\n        return false;\n    }\n\n    ggml_backend_buffer_t buf_src = src->view_src ? 
src->view_src->buffer : src->buffer;\n    ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;\n\n    ggml_backend_cann_context * cann_ctx_src = (ggml_backend_cann_context *) backend_src->context;\n    ggml_backend_cann_context * cann_ctx_dst = (ggml_backend_cann_context *) backend_dst->context;\n\n    size_t copy_size = ggml_nbytes(dst);\n    if (copy_size == 0) {\n        return true;\n    }\n    if (backend_src != backend_dst) {\n#ifdef ASCEND_310P\n        // TODO: Support 310p P2P copy\n        return false;\n#endif\n        ggml_backend_cann_buffer_context * buf_ctx_src = (ggml_backend_cann_buffer_context *) buf_src->context;\n        ggml_backend_cann_buffer_context * buf_ctx_dst = (ggml_backend_cann_buffer_context *) buf_dst->context;\n\n        GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);\n        GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);\n\n        int32_t canAccessPeer = 0;\n        ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device, cann_ctx_dst->device));\n        if (!canAccessPeer) {\n            return false;\n        }\n\n        // need open both directions for memcpyasync between devices.\n        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));\n        ggml_cann_set_device(cann_ctx_src->device);\n        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));\n\n        // wait for task_queue empty to keep task order.\n        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,\n                                   cann_ctx_src->stream()));\n        // record event on src stream after the copy\n        // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream\n        // if (!cann_ctx_src->copy_event) {\n        //     ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));\n        // }\n        // 
ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));\n\n        // // wait on dst stream for the copy to complete\n        // ggml_cann_set_device(cann_ctx_dst->device);\n        // ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));\n        ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));\n    } else {\n        // src and dst are on the same backend\n        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,\n                                   cann_ctx_dst->stream()));\n    }\n\n    return true;\n}\n\n/**\n * @brief Synchronizes a CANN backend.\n *\n * This function synchronizes the specified CANN backend by waiting for all\n * operations in its associated stream to complete.\n *\n * @param backend Pointer to the CANN backend structure to synchronize.\n */\nstatic void ggml_backend_cann_synchronize(ggml_backend_t backend) {\n    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;\n    ggml_cann_set_device(cann_ctx->device);\n    ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));\n}\n\n/**\n * @brief Check if CANN backend can fuse the specified operation sequence\n *\n * This function determines whether an operation sequence starting from the specified node\n * can be fused into an optimized operation in the CANN backend. 
Operation fusion can reduce\n * memory access overhead and improve computational efficiency.\n *\n * @param cgraph Pointer to the computation graph\n * @param node_idx Index of the starting node in the computation graph\n * @param ops Sequence of operation types to check for fusion\n * @return true if the operations can be fused\n * @return false if the operations cannot be fused\n */\nstatic bool ggml_cann_can_fuse(const struct ggml_cgraph *          cgraph,\n                               int                                 node_idx,\n                               std::initializer_list<enum ggml_op> ops) {\n    if (!ggml_can_fuse(cgraph, node_idx, ops)) {\n        return false;\n    }\n\n    // CANN backend supports fusing ADD + RMS_NORM operations\n    if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {\n        ggml_tensor * add_node = cgraph->nodes[node_idx];\n        // TODO: support broadcast for ADD + RMS_NORM\n        if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||\n            add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {\n            return false;\n        }\n        return true;\n    }\n\n    return false;\n}\n\n/**\n * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.\n *\n * If CANN graph execution is enabled and graph capture is required, this function begins\n * graph capture, runs the graph, ends capture, and stores the captured graph.\n *\n * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.\n *\n * @param cann_ctx                     The CANN backend context.\n * @param cgraph                       The ggml computation graph.\n * @param use_cann_graph               Whether to use CANN graph execution.\n * @param cann_graph_capture_required  Whether graph capture is needed due to graph 
changes.\n */\nstatic void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,\n                                            ggml_cgraph *               cgraph,\n                                            bool                        use_cann_graph,\n                                            bool                        cann_graph_capture_required) {\n#ifdef USE_ACL_GRAPH\n    if (use_cann_graph && cann_graph_capture_required) {  // Begin CANN graph capture\n        ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));\n    }\n#endif  // USE_ACL_GRAPH\n    // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.\n    // With the use of CANN graphs, the execution will be performed by the graph launch.\n    static bool opt_fusion = parse_bool(get_env_as_lowercase(\"GGML_CANN_OPERATOR_FUSION\").value_or(\"\"));\n\n    if (!use_cann_graph || cann_graph_capture_required) {\n        for (int i = 0; i < cgraph->n_nodes; i++) {\n            ggml_tensor * node = cgraph->nodes[i];\n            if (opt_fusion) {\n                if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {\n                    ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);\n                    i++;\n                    continue;\n                }\n            }\n\n            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||\n                node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {\n                continue;\n            }\n\n            if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n                continue;\n            }\n\n            bool ok = ggml_cann_compute_forward(*cann_ctx, node);\n            if (!ok) {\n                GGML_LOG_ERROR(\"%s: op not supported %s (%s)\\n\", __func__, node->name, ggml_op_name(node->op));\n            }\n            
GGML_ASSERT(ok);\n        }\n    }\n\n#ifdef USE_ACL_GRAPH\n    if (use_cann_graph) {\n        GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());\n        ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();\n\n        if (cann_graph_capture_required) {  // End CANN graph capture\n            ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));\n        }\n\n        // Execute CANN graph\n        ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));\n    }\n#endif  // USE_ACL_GRAPH\n}\n\n/**\n * @brief Computes a computational graph using a CANN backend.\n *\n * This function computes the operations defined in the computational graph\n * using the specified CANN backend.\n *\n * @param backend Pointer to the CANN backend structure to use for computation.\n * @param cgraph Pointer to the computational graph structure containing nodes\n *               representing operations to be computed.\n * @return enum ggml_status Always GGML_STATUS_SUCCESS in this implementation.\n */\nstatic enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;\n    ggml_cann_set_device(cann_ctx->device);\n    g_nz_workspaces[cann_ctx->device].clear();\n\n    // Mark the rope cache stale so it is recalculated for the first layer on the current device.\n    cann_ctx->rope_cache.cached = false;\n\n    bool graph_capture_required = false;\n#ifdef USE_ACL_GRAPH\n    bool use_cann_graph = true;\n\n    static bool prefill_use_graph = parse_bool(get_env_as_lowercase(\"GGML_CANN_PREFILL_USE_GRAPH\").value_or(\"\"));\n    if (!prefill_use_graph) {\n        // Do not use acl_graph for prefill.\n        for (int i = 0; i < cgraph->n_nodes; i++) {\n            ggml_tensor * node = cgraph->nodes[i];\n            // TODO: Optimize here. 
Currently, we can only\n            // get seq_len by FA's input.\n            if (node->op == GGML_OP_FLASH_ATTN_EXT) {\n                // Q -> src[0], shape: [B, S, N, D]\n                use_cann_graph = (node->src[0]->ne[1] == 1);\n                break;\n            }\n        }\n    }\n\n    if (!cann_ctx->acl_graph_mode) {\n        use_cann_graph = false;\n    }\n\n    if (use_cann_graph) {\n        // If no matching graph is found, the graph needs to be recaptured.\n        graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);\n        if (graph_capture_required) {\n            // If no matching graph is found, add a new ACL graph.\n            ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);\n            cann_ctx->graph_lru_cache.push(new_graph);\n        }\n    }\n#else\n    bool use_cann_graph = false;\n#endif  // USE_ACL_GRAPH\n    evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);\n\n    return GGML_STATUS_SUCCESS;\n}\n\n/**\n * @brief Checks if the CANN backend supports a specific operation.\n *\n * This function checks whether the specified operation is supported by the\n * CANN backend.\n *\n * @param backend Pointer to the CANN backend structure to check support for\n *                the operation.\n * @param op Pointer to the tensor representing the operation to check.\n * @return bool Returns true if the operation is supported by the backend,\n *              otherwise false.\n */\nstatic bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    switch (op->op) {\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(op)) {\n                case GGML_UNARY_OP_ABS:\n                case GGML_UNARY_OP_NEG:\n                case GGML_UNARY_OP_GELU:\n                case GGML_UNARY_OP_SILU:\n                case GGML_UNARY_OP_RELU:\n                case GGML_UNARY_OP_SIGMOID:\n                case 
GGML_UNARY_OP_HARDSIGMOID:\n                case GGML_UNARY_OP_HARDSWISH:\n                case GGML_UNARY_OP_GELU_QUICK:\n                case GGML_UNARY_OP_TANH:\n                case GGML_UNARY_OP_EXP:\n                case GGML_UNARY_OP_ELU:\n                case GGML_UNARY_OP_SGN:\n                case GGML_UNARY_OP_STEP:\n                case GGML_UNARY_OP_GELU_ERF:\n                    return true;\n                default:\n                    return false;\n            }\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(op)) {\n                case GGML_GLU_OP_REGLU:\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_SWIGLU:\n                case GGML_GLU_OP_GEGLU_ERF:\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    return true;\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_MUL_MAT:\n            {\n                switch (op->src[0]->type) {\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_F32:\n                        return true;\n                    case GGML_TYPE_Q8_0:\n                    case GGML_TYPE_Q4_0:\n#ifdef ASCEND_310P\n                        // Q4 && Q8 per group is not support on 310p device\n                        return false;\n#endif\n                        // only support contiguous for quantized types.\n                        return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);\n                    default:\n                        return false;\n                }\n            }\n        case GGML_OP_MUL_MAT_ID:\n            switch (op->src[0]->type) {\n                case GGML_TYPE_F16:\n                case GGML_TYPE_F32:\n                    return true;\n                case GGML_TYPE_Q8_0:\n                case GGML_TYPE_Q4_0:\n#ifdef ASCEND_310P\n                    // Q4 && Q8 per group is not support on 310p device\n                    return 
false;\n#endif\n                    // only support contiguous for quantized types.\n                    return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);\n                default:\n                    return false;\n            }\n        // embedding\n        case GGML_OP_GET_ROWS:\n            {\n                switch (op->src[0]->type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_Q8_0:\n                        return true;\n                    default:\n                        return false;\n                }\n            }\n            break;\n        case GGML_OP_SET_ROWS:\n            {\n                switch (op->type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                        return true;\n                    default:\n                        return false;\n                }\n            }\n            break;\n        case GGML_OP_CPY:\n            {\n                ggml_tensor * src = op->src[0];\n                if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||\n                    (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {\n                    // only support F32 and F16.\n                    return false;\n                }\n                return true;\n            }\n            break;\n        case GGML_OP_CONT:\n            {\n                // TODO: support GGML_TYPE_BF16\n                switch (op->src[0]->type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                        return true;\n                    default:\n                        return false;\n                }\n            }\n        case GGML_OP_ROPE:\n            {\n                if (op->src[0]->ne[0] > 896) {\n                    return false;\n                }\n#ifdef ASCEND_310P\n                // TODO: Support rope_dim < ne00(dim)\n               
 if (op->src[0]->ne[0] != op->op_params[1]) {\n                    return false;\n                }\n                if (!ggml_is_contiguous(op->src[0])) {\n                    return false;\n                }\n#endif\n                return true;\n            }\n        case GGML_OP_UPSCALE:\n            {\n                // aclnnUpsampleNearest2dGetWorkspaceSize not support\n                // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal\n                if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {\n                    return false;\n                }\n                if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {\n                    return false;\n                }\n                if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {\n                    return false;\n                }\n                return true;\n            }\n        case GGML_OP_POOL_2D:\n            {\n                const int32_t * opts = (const int32_t *) op->op_params;\n#ifdef ASCEND_310P\n                enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);\n                if (opt == GGML_OP_POOL_MAX) {\n                    return false;\n                }\n#endif\n                const int k0 = opts[1];\n                const int k1 = opts[2];\n                const int p0 = opts[5];\n                const int p1 = opts[6];\n                // value of paddingH should be at most half of kernelH\n                // value of paddingW should be at most half of kernelW\n                return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));\n            }\n        case GGML_OP_SUM:\n            return ggml_is_contiguous_rows(op->src[0]);\n        case GGML_OP_L2_NORM:\n        case GGML_OP_CROSS_ENTROPY_LOSS:\n        case GGML_OP_DUP:\n        case GGML_OP_IM2COL:\n        case GGML_OP_CONCAT:\n        case GGML_OP_REPEAT:\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n      
  case GGML_OP_TRANSPOSE:\n        case GGML_OP_NORM:\n        case GGML_OP_ADD:\n        case GGML_OP_ADD1:\n        case GGML_OP_SUB:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n        case GGML_OP_RMS_NORM:\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_CLAMP:\n        case GGML_OP_DIAG_MASK_INF:\n        case GGML_OP_SUM_ROWS:\n        case GGML_OP_ARGSORT:\n        case GGML_OP_ACC:\n        case GGML_OP_GROUP_NORM:\n            return true;\n        case GGML_OP_PAD:\n            // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985\n            return ggml_get_op_params_i32(op, 8) == 0;\n        case GGML_OP_ARANGE:\n        case GGML_OP_TIMESTEP_EMBEDDING:\n        case GGML_OP_LEAKY_RELU:\n        case GGML_OP_ARGMAX:\n        case GGML_OP_COS:\n        case GGML_OP_SIN:\n        case GGML_OP_LOG:\n        case GGML_OP_MEAN:\n        case GGML_OP_PAD_REFLECT_1D:\n        case GGML_OP_COUNT_EQUAL:\n        case GGML_OP_GATED_LINEAR_ATTN:\n            return true;\n        case GGML_OP_OUT_PROD:\n            {\n#ifdef ASCEND_310P\n                // Ger is not supported on 310p device\n                return false;\n#endif\n                switch (op->src[0]->type) {\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_F32:\n                        return true;\n                    default:\n                        return false;\n                }\n            }\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            return true;\n        case GGML_OP_SCALE:\n            float bias;\n            memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));\n            return bias == 0.0f;  // TODO: support bias != 0.0f\n        case GGML_OP_SOFT_MAX:\n            // TODO: support attention sinks [TAG_ATTN_SINKS]\n            if (op->src[2]) {\n                return false;\n            }\n            return true;\n        case 
GGML_OP_FLASH_ATTN_EXT:\n            {\n#ifdef ASCEND_310P\n                // FA not support on 310p device\n                return false;\n#endif\n                // derived from [ggml-cuda.cu]\n                if (op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16) {\n                    return false;\n                }\n                if (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 &&\n                    op->src[1]->type != GGML_TYPE_BF16) {\n                    return false;\n                }\n                if (op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16) {\n                    return false;\n                }\n                // TODO: support attention sinks [TAG_ATTN_SINKS]\n                if (op->src[4]) {\n                    return false;\n                }\n                if (op->src[1]->ne[0] != op->src[2]->ne[0]) {\n                    // different head sizes of K and V are not supported yet\n                    return false;\n                }\n                if (op->src[0]->ne[0] % 16 != 0) {\n                    // TODO: padding to support\n                    return false;\n                }\n                float logitSoftcap = 0.0f;\n                memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));\n                if (logitSoftcap != 0.0f) {\n                    return false;\n                }\n                return true;\n            }\n        case GGML_OP_SSM_CONV:\n            return true;\n        default:\n            return false;\n    }\n\n    GGML_UNUSED(dev);\n}\n\n/**\n * @brief Records an event on the CANN backend stream.\n *\n * This function records the given event on the ACL runtime stream associated\n * with the backend context.\n *\n * @param event Pointer to the event structure to be recorded.\n */\nstatic void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {\n    
ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;\n    ACL_CHECK(aclrtRecordEvent((aclrtEvent) event->context, cann_ctx->stream()));\n}\n\n/**\n * @brief Waits for a recorded event to complete on the CANN backend stream.\n *\n * This function makes the given backend wait for the event to complete on its\n * ACL runtime stream.\n *\n * @param backend Pointer to the backend structure.\n * @param event Pointer to the event structure that the backend needs to wait\n * for.\n */\nstatic void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {\n    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;\n    if (ggml_backend_is_cann(backend)) {\n        ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), (aclrtEvent) event->context));\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\n/**\n * @brief Structure defining the interface for the CANN backend.\n *\n * This structure contains function pointers for various operations\n * supported by the CANN backend, including name retrieval, memory\n * management, tensor operations, synchronization, and event handling.\n */\nstatic const ggml_backend_i ggml_backend_cann_interface = {\n    /* .get_name                = */ ggml_backend_cann_name,\n    /* .free                    = */ ggml_backend_cann_free,\n    /* .set_tensor_async        = */ ggml_backend_cann_set_tensor_async,\n    /* .get_tensor_async        = */ ggml_backend_cann_get_tensor_async,\n    /* .cpy_tensor_async        = */ ggml_backend_cann_cpy_tensor_async,\n    /* .synchronize             = */ ggml_backend_cann_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_cann_graph_compute,\n    /* .event_record            = */ ggml_backend_cann_event_record,\n    /* 
.event_wait              = */ ggml_backend_cann_event_wait,\n    /* .graph_optimize          = */ NULL,\n};\n\n/**\n * @brief Return the hardcoded GUID for the CANN backend.\n *\n * This function returns a static GUID which uniquely identifies the CANN\n * backend.\n *\n * @return A pointer to the static GUID.\n */\nstatic ggml_guid_t ggml_backend_cann_guid() {\n    static ggml_guid guid = { 0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,\n                              0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64 };\n    return &guid;\n}\n\n// backend device\nstruct ggml_backend_cann_device_context {\n    int         device;\n    std::string name;\n    std::string description;\n    int op_offload_min_batch_size;\n};\n\nstatic const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {\n    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;\n    return ctx->name.c_str();\n}\n\nstatic const char * ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {\n    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;\n    return ctx->description.c_str();\n}\n\nstatic void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;\n    ggml_backend_cann_get_device_memory(ctx->device, free, total);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {\n    GGML_UNUSED(dev);\n    return GGML_BACKEND_DEVICE_TYPE_GPU;\n}\n\nstatic void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_cann_device_get_name(dev);\n    props->description = ggml_backend_cann_device_get_description(dev);\n    props->type        = ggml_backend_cann_device_get_type(dev);\n    ggml_backend_cann_device_get_memory(dev, &props->memory_free, 
&props->memory_total);\n\n    bool host_buffer = getenv(\"GGML_CANN_NO_PINNED\") == nullptr;\n\n    props->caps = {\n        /* .async                 = */ false,\n        /* .host_buffer           = */ host_buffer,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ true,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {\n    GGML_UNUSED(params);\n    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;\n    return ggml_backend_cann_init(ctx->device);\n}\n\n/**\n * @brief Checks if the CANN backend device supports a specific backend buffer type.\n *\n * This function determines whether the CANN backend device supports the given\n * buffer type. It returns true only if the buffer type is a CANN buffer type\n * whose device matches the device of the backend device context.\n *\n * @param dev Pointer to the CANN backend device.\n * @param buft Pointer to the backend buffer type to check.\n * @return bool Returns true if the CANN backend device supports the buffer type,\n *              otherwise false.\n */\nstatic bool ggml_backend_cann_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    if (ggml_backend_buft_is_cann(buft)) {\n        ggml_backend_cann_device_context *      dev_ctx  = (ggml_backend_cann_device_context *) dev->context;\n        ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;\n        return buft_ctx->device == dev_ctx->device;\n    }\n    return false;\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {\n    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;\n    return ggml_backend_cann_buffer_type(ctx->device);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t 
dev) {\n    GGML_UNUSED(dev);\n    return ggml_backend_cann_host_buffer_type();\n}\n\n/**\n * @brief Determines if a tensor operation should be offloaded to the CANN\n * backend.\n *\n * This function checks if a given tensor operation should be offloaded to the\n * CANN backend based on the operation type and the size of the tensor. It\n * returns true if the second dimension (ne[1]) of the tensor is greater than or\n * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.\n *\n * @param backend Pointer to the CANN backend.\n * @param op Pointer to the tensor operation to check.\n * @return bool Returns true if the operation should be offloaded, otherwise\n * false.\n */\nstatic bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;\n\n    return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS;\n}\n\n/**\n * @brief Creates a new event for the CANN backend device.\n *\n * This function initializes a new event for the CANN backend by setting the\n * device and creating an ACL runtime event. 
The created event is then wrapped\n * in a ggml_backend_event structure and returned.\n *\n * @param backend Pointer to the CANN backend.\n * @return ggml_backend_event_t Returns a pointer to the new event structure.\n */\nstatic ggml_backend_event_t ggml_backend_cann_device_event_new(ggml_backend_dev_t dev) {\n    ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context;\n\n    ggml_cann_set_device(dev_ctx->device);\n\n    aclrtEvent event;\n    ACL_CHECK(aclrtCreateEvent(&event));\n\n    return new ggml_backend_event{\n        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),\n        /* .context = */ event,\n    };\n}\n\n/**\n * @brief Frees a CANN backend event.\n *\n * This function destroys the ACL runtime event associated with the given CANN\n * backend event and then deletes the event structure itself.\n *\n * @param event Pointer to the event structure to be freed.\n */\nstatic void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {\n    ACL_CHECK(aclrtDestroyEvent((aclrtEvent) event->context));\n\n    delete event;\n    GGML_UNUSED(dev);\n}\n\n/**\n * @brief Synchronizes the given event on the CANN backend.\n *\n * This function waits for the specified event to complete on the ACL runtime.\n *\n * @param event Pointer to the event structure to be synchronized.\n */\nstatic void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {\n    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent) event->context));\n\n    GGML_UNUSED(dev);\n}\n\nstatic const ggml_backend_device_i ggml_backend_cann_device_interface = {\n    /* .get_name                = */ ggml_backend_cann_device_get_name,\n    /* .get_description         = */ ggml_backend_cann_device_get_description,\n    /* .get_memory              = */ ggml_backend_cann_device_get_memory,\n    /* .get_type                = */ ggml_backend_cann_device_get_type,\n    /* 
.get_props               = */ ggml_backend_cann_device_get_props,\n    /* .init_backend            = */ ggml_backend_cann_device_init,  // called for every card\n    /* .get_buffer_type         = */ ggml_backend_cann_device_get_buffer_type,\n    /* .get_host_buffer_type    = */ ggml_backend_cann_device_get_host_buffer_type,\n    /* .buffer_from_host_ptr    = */ NULL,  // not supported for CANN\n    /* .supports_op             = */ ggml_backend_cann_supports_op,\n    /* .supports_buft           = */ ggml_backend_cann_supports_buft,\n    /* .offload_op              = */ ggml_backend_cann_offload_op,\n    /* .event_new               = */ ggml_backend_cann_device_event_new,\n    /* .event_free              = */ ggml_backend_cann_device_event_free,\n    /* .event_synchronize       = */ ggml_backend_cann_device_event_synchronize,\n};\n\n// backend reg\nstruct ggml_backend_cann_reg_context {\n    std::vector<ggml_backend_dev_t> devices;\n};\n\nstatic const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {\n    GGML_UNUSED(reg);\n    return GGML_CANN_NAME;\n}\n\nstatic size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {\n    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;\n    return ctx->devices.size();\n}\n\nstatic ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;\n    GGML_ASSERT(index < ctx->devices.size());\n    return ctx->devices[index];\n}\n\nstatic void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    GGML_UNUSED(reg);\n    GGML_UNUSED(name);\n    // reserved for future use\n    return nullptr;\n}\n\nstatic const ggml_backend_reg_i ggml_backend_cann_reg_interface = {\n    /* .get_name          = */ ggml_backend_cann_reg_get_name,\n    /* .get_device_count  = */ ggml_backend_cann_reg_get_device_count,\n    /* .get_device   
     = */ ggml_backend_cann_reg_get_device,\n    /* .get_proc_address  = */ ggml_backend_cann_reg_get_proc_address,\n};\n\n// backend registry, called only once for cann backend\nggml_backend_reg_t ggml_backend_cann_reg() {\n    static ggml_backend_reg reg;\n    static bool             initialized = false;\n\n    {\n        static std::mutex           mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n        if (!initialized) {\n            aclInit(nullptr);\n            ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;\n            const int min_batch_size = getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\") ? atoi(getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\")) : 32;\n\n            for (int i = 0; i < ggml_cann_info().device_count; i++) {\n                ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context();\n                dev_ctx->description                       = aclrtGetSocName();\n                dev_ctx->device                            = i;\n                dev_ctx->name                              = GGML_CANN_NAME + std::to_string(i);\n                dev_ctx->op_offload_min_batch_size         = min_batch_size;\n                ggml_cann_set_device(i);\n                ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface   = */ ggml_backend_cann_device_interface,\n                                                                  /* .reg     = */ &reg,\n                                                                  /* .context = */ dev_ctx };\n                ctx->devices.push_back(dev);\n            }\n\n            reg = ggml_backend_reg{ /* .api_version = */ GGML_BACKEND_API_VERSION,\n                                    /* .iface       = */ ggml_backend_cann_reg_interface,\n                                    /* .context     = */ ctx };\n        }\n\n        initialized = true;\n    }\n\n    return &reg;\n}\n\nggml_backend_t ggml_backend_cann_init(int32_t device) {\n    aclInit(nullptr);\n    if 
(device < 0 || device >= ggml_backend_cann_get_device_count()) {\n        GGML_LOG_ERROR(\"%s: error: invalid device %d\\n\", __func__, device);\n        return nullptr;\n    }\n\n    ggml_backend_cann_context * ctx = new ggml_backend_cann_context(device);\n    if (ctx == nullptr) {\n        GGML_LOG_ERROR(\"%s: error: failed to allocate context\\n\", __func__);\n        return nullptr;\n    }\n    ggml_cann_set_device(ctx->device);\n    ggml_backend_t cann_backend =\n        new ggml_backend{ /* .guid      = */ ggml_backend_cann_guid(),\n                          /* .interface = */ ggml_backend_cann_interface,\n                          /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),\n                          /* .context   = */ ctx };\n\n    return cann_backend;\n}\n\nbool ggml_backend_is_cann(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cann_guid());\n}\n\nint32_t ggml_backend_cann_get_device_count() {\n    return ggml_cann_info().device_count;\n}\n\nvoid ggml_backend_cann_get_device_description(int32_t device, char * description, size_t description_size) {\n    ggml_cann_set_device(device);\n    const char * soc_name = aclrtGetSocName();\n    snprintf(description, description_size, \"%s\", soc_name);\n}\n\nvoid ggml_backend_cann_get_device_memory(int32_t device, size_t * free, size_t * total) {\n    ggml_cann_set_device(device);\n    ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_cann_reg)\n"
  },
  {
    "path": "src/ggml-common.h",
    "content": "#ifndef GGML_COMMON_DECL\n\n#if defined(GGML_COMMON_DECL_C)\n#include <stdint.h>\n\ntypedef uint16_t ggml_half;\ntypedef uint32_t ggml_half2;\n\n#define GGML_COMMON_AGGR_U\n#define GGML_COMMON_AGGR_S\n\n#define GGML_COMMON_DECL\n#elif defined(GGML_COMMON_DECL_CPP)\n#include <cstdint>\n\ntypedef uint16_t ggml_half;\ntypedef uint32_t ggml_half2;\n\n// std-c++ allow anonymous unions but some compiler warn on it\n#define GGML_COMMON_AGGR_U data\n// std-c++ do not allow it.\n#define GGML_COMMON_AGGR_S data\n\n#define GGML_COMMON_DECL\n#elif defined(GGML_COMMON_DECL_METAL)\n#include <metal_stdlib>\n\ntypedef half  ggml_half;\ntypedef half2 ggml_half2;\n\n#define GGML_COMMON_AGGR_U\n#define GGML_COMMON_AGGR_S\n\n#define GGML_COMMON_DECL\n#elif defined(GGML_COMMON_DECL_CUDA)\n#if defined(GGML_COMMON_DECL_MUSA)\n#include <musa_fp16.h>\n#else\n#include <cuda_fp16.h>\n#endif\n#include <cstdint>\n\ntypedef half  ggml_half;\ntypedef half2 ggml_half2;\n\n#define GGML_COMMON_AGGR_U\n#define GGML_COMMON_AGGR_S data\n\n#define GGML_COMMON_DECL\n#elif defined(GGML_COMMON_DECL_HIP)\n#include <hip/hip_fp16.h>\n#include <cstdint>\n\ntypedef half  ggml_half;\ntypedef half2 ggml_half2;\n\n#define GGML_COMMON_AGGR_U\n#define GGML_COMMON_AGGR_S data\n\n#define GGML_COMMON_DECL\n#elif defined(GGML_COMMON_DECL_SYCL)\n#include <sycl/half_type.hpp>\n#include <cstdint>\n\ntypedef sycl::half  ggml_half;\ntypedef sycl::half2 ggml_half2;\n\n#define GGML_COMMON_AGGR_U\n#define GGML_COMMON_AGGR_S data\n\n#define GGML_COMMON_DECL\n#endif\n\n#if defined(GGML_COMMON_DECL)\n\n#ifndef __cplusplus\n#ifndef static_assert\n#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)\n#define static_assert(cond, msg) _Static_assert(cond, msg)\n#else\n#define static_assert(cond, msg) struct global_scope_noop_trick\n#endif\n#endif\n#endif // __cplusplus\n\n// QK = number of values after dequantization\n// QK_K = super-block size\n\n#define QK_K 256\n#define K_SCALE_SIZE 12\n\n#if 
defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)\n// QR = QK / number of values before dequantization\n// QI = number of 32 bit integers before dequantization\n\n#define QI4_0 (QK4_0 / (4 * QR4_0))\n#define QR4_0 2\n\n#define QI4_1 (QK4_1 / (4 * QR4_1))\n#define QR4_1 2\n\n#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))\n#define QR_MXFP4 2\n\n#define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4))\n#define QR_NVFP4 2\n\n#define QI5_0 (QK5_0 / (4 * QR5_0))\n#define QR5_0 2\n\n#define QI5_1 (QK5_1 / (4 * QR5_1))\n#define QR5_1 2\n\n#define QI8_0 (QK8_0 / (4 * QR8_0))\n#define QR8_0 1\n\n#define QI8_1 (QK8_1 / (4 * QR8_1))\n#define QR8_1 1\n\n#define QI2_K (QK_K / (4*QR2_K))\n#define QR2_K 4\n\n#define QI3_K (QK_K / (4*QR3_K))\n#define QR3_K 4\n\n#define QI4_K (QK_K / (4*QR4_K))\n#define QR4_K 2\n\n#define QI5_K (QK_K / (4*QR5_K))\n#define QR5_K 2\n\n#define QI6_K (QK_K / (4*QR6_K))\n#define QR6_K 2\n\n#define QI2_XXS (QK_K / (4*QR2_XXS))\n#define QR2_XXS 4\n\n#define QI2_XS (QK_K / (4*QR2_XS))\n#define QR2_XS 4\n\n#define QI2_S (QK_K / (4*QR2_S))\n#define QR2_S 4\n\n#define QI3_XXS (QK_K / (4*QR3_XXS))\n#define QR3_XXS 4\n\n#define QI3_XS (QK_K / (4*QR3_XS))\n#define QR3_XS 4\n\n#define QI1_S (QK_K / (4*QR1_S))\n#define QR1_S 8\n\n#define QI1_M (QK_K / (4*QR1_M))\n#define QR1_M 8\n\n#define QI4_NL (QK4_NL / (4*QR4_NL))\n#define QR4_NL 2\n\n#define QI4_XS (QK_K / (4*QR4_XS))\n#define QR4_XS 2\n\n#define QI3_S (QK_K / (4*QR3_S))\n#define QR3_S 4\n\n#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP\n\n#ifdef _MSC_VER\n#define GGML_EXTENSION\n#else // _MSC_VER\n#define GGML_EXTENSION __extension__\n#endif // _MSC_VER\n\n#define QK4_0 32\ntypedef struct {\n    ggml_half d;           // delta\n    uint8_t qs[QK4_0 / 2]; // nibbles / quants\n} block_q4_0;\nstatic_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, \"wrong q4_0 block size/padding\");\n\n#define QK4_1 32\ntypedef struct {\n    GGML_EXTENSION union {\n        
struct {\n            ggml_half d; // delta\n            ggml_half m; // min\n        } GGML_COMMON_AGGR_S;\n        ggml_half2 dm;\n    } GGML_COMMON_AGGR_U;\n    uint8_t qs[QK4_1 / 2]; // nibbles / quants\n} block_q4_1;\nstatic_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, \"wrong q4_1 block size/padding\");\n\n#define QK_MXFP4 32\ntypedef struct {\n    uint8_t e; // E8M0\n    uint8_t qs[QK_MXFP4/2];\n} block_mxfp4;\nstatic_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, \"wrong mxfp4 block size/padding\");\n\n#define QK_NVFP4 64\n#define QK_NVFP4_SUB 16  // sub-block size for per-group scales\ntypedef struct {\n    uint8_t d[QK_NVFP4/QK_NVFP4_SUB]; // UE4M3 scales (4 bytes, one per 16-element sub-block)\n    uint8_t qs[QK_NVFP4/2];           // packed 4-bit E2M1 values (32 bytes)\n} block_nvfp4;\nstatic_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, \"wrong nvfp4 block size/padding\");\n\n#define QK5_0 32\ntypedef struct {\n    ggml_half d;           // delta\n    uint8_t qh[4];         // 5-th bit of quants\n    uint8_t qs[QK5_0 / 2]; // nibbles / quants\n} block_q5_0;\nstatic_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, \"wrong q5_0 block size/padding\");\n\n#define QK5_1 32\ntypedef struct {\n    GGML_EXTENSION union {\n        struct {\n            ggml_half d; // delta\n            ggml_half m; // min\n        } GGML_COMMON_AGGR_S;\n        ggml_half2 dm;\n    } GGML_COMMON_AGGR_U;\n    uint8_t qh[4];         // 5-th bit of quants\n    uint8_t qs[QK5_1 / 2]; // nibbles / quants\n} block_q5_1;\nstatic_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, \"wrong q5_1 block size/padding\");\n\n#define QK8_0 32\ntypedef struct {\n    ggml_half d;       // delta\n    int8_t  qs[QK8_0]; // quants\n} block_q8_0;\nstatic_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, \"wrong q8_0 block size/padding\");\n\n#define QK8_1 
32\ntypedef struct {\n    GGML_EXTENSION union {\n        struct {\n            ggml_half d; // delta\n            ggml_half s; // d * sum(qs[i])\n        } GGML_COMMON_AGGR_S;\n        ggml_half2 ds;\n    } GGML_COMMON_AGGR_U;\n    int8_t qs[QK8_1]; // quants\n} block_q8_1;\nstatic_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, \"wrong q8_1 block size/padding\");\n\n//\n// Ternary quantization\n//\n\n// 1.6875 bpw\ntypedef struct {\n    uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256)\n    uint8_t qh[QK_K/64]; // 4 elements per byte\n    ggml_half d;\n} block_tq1_0;\nstatic_assert(sizeof(block_tq1_0) == sizeof(ggml_half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, \"wrong tq1_0 block size/padding\");\n\n// 2.0625 bpw\ntypedef struct {\n    uint8_t qs[QK_K/4]; // 2 bits per element\n    ggml_half d;\n} block_tq2_0;\nstatic_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, \"wrong tq2_0 block size/padding\");\n\n//\n// Super-block quantization structures\n//\n\n// 2-bit quantization\n// weight is represented as x = a * q + b\n// 16 blocks of 16 elements each\n// Effectively 2.625 bits per weight\ntypedef struct {\n    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits\n    uint8_t qs[QK_K/4];      // quants\n    GGML_EXTENSION union {\n        struct {\n            ggml_half d;    // super-block scale for quantized scales\n            ggml_half dmin; // super-block scale for quantized mins\n        } GGML_COMMON_AGGR_S;\n        ggml_half2 dm;\n    } GGML_COMMON_AGGR_U;\n} block_q2_K;\nstatic_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, \"wrong q2_K block size/padding\");\n\n// 3-bit quantization\n// weight is represented as x = a * q\n// 16 blocks of 16 elements each\n// Effectively 3.4375 bits per weight\ntypedef struct {\n    uint8_t hmask[QK_K/8]; // quants - high bit\n    uint8_t qs[QK_K/4];    // quants - low 2 bits\n    uint8_t scales[12];    // scales, quantized with 6 
bits\n    ggml_half d;           // super-block scale\n} block_q3_K;\nstatic_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, \"wrong q3_K block size/padding\");\n\n// 4-bit quantization\n// 8 blocks of 32 elements each\n// weight is represented as x = a * q + b\n// Effectively 4.5 bits per weight\ntypedef struct {\n    GGML_EXTENSION union {\n        struct {\n            ggml_half d;    // super-block scale for quantized scales\n            ggml_half dmin; // super-block scale for quantized mins\n        } GGML_COMMON_AGGR_S;\n        ggml_half2 dm;\n    } GGML_COMMON_AGGR_U;\n    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits\n    uint8_t qs[QK_K/2];           // 4--bit quants\n} block_q4_K;\nstatic_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, \"wrong q4_K block size/padding\");\n\n// 5-bit quantization\n// 8 blocks of 32 elements each\n// weight is represented as x = a * q + b\n// Effectively 5.5 bits per weight\ntypedef struct {\n    GGML_EXTENSION union {\n        struct {\n            ggml_half d;    // super-block scale for quantized scales\n            ggml_half dmin; // super-block scale for quantized mins\n        } GGML_COMMON_AGGR_S;\n        ggml_half2 dm;\n    } GGML_COMMON_AGGR_U;\n    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits\n    uint8_t qh[QK_K/8];           // quants, high bit\n    uint8_t qs[QK_K/2];           // quants, low 4 bits\n} block_q5_K;\nstatic_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, \"wrong q5_K block size/padding\");\n\n// 6-bit quantization\n// weight is represented as x = a * q\n// 16 blocks of 16 elements each\n// Effectively 6.5625 bits per weight\ntypedef struct {\n    uint8_t ql[QK_K/2];      // quants, lower 4 bits\n    uint8_t qh[QK_K/4];      // quants, upper 2 bits\n    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits\n    ggml_half d;             // super-block 
scale\n} block_q6_K;\nstatic_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, \"wrong q6_K block size/padding\");\n\n// This is only used for intermediate quantization and dot products\ntypedef struct {\n    float   d;              // delta\n    int8_t  qs[QK_K];       // quants\n    int16_t bsums[QK_K/16]; // sum of quants in groups of 16\n} block_q8_K;\nstatic_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), \"wrong q8_K block size/padding\");\n\n// (Almost) \"true\" 2-bit quantization.\n// Due to the need to use blocks as per ggml design, it ends up using\n// 2.0625 bpw because of the 16-bit scale for each block of 256.\ntypedef struct {\n    ggml_half d;\n    uint16_t qs[QK_K/8];\n} block_iq2_xxs;\nstatic_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), \"wrong iq2_xxs block size/padding\");\n\n// 2.3125 bpw quants\ntypedef struct {\n    ggml_half d;\n    uint16_t qs[QK_K/8];\n    uint8_t  scales[QK_K/32];\n} block_iq2_xs;\nstatic_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, \"wrong iq2_xs block size/padding\");\n\n// 2.5625 bpw quants\ntypedef struct {\n    ggml_half d;\n    uint8_t qs[QK_K/4];\n    uint8_t qh[QK_K/32];\n    uint8_t scales[QK_K/32];\n} block_iq2_s;\nstatic_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, \"wrong iq2_s block size/padding\");\n\n// (Almost) \"true\" 3-bit quantization.\n// Due to the need to use blocks as per ggml design, it ends up using\n// 3.0625 bpw because of the 16-bit scale for each block of 256.\ntypedef struct {\n    ggml_half d;\n    uint8_t qs[3*QK_K/8];\n} block_iq3_xxs;\nstatic_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), \"wrong iq3_xxs block size/padding\");\n\n// 3.4375 bpw\n#define IQ3S_N_SCALE QK_K/64\ntypedef struct {\n    ggml_half d;\n    uint8_t qs[QK_K/4];\n    uint8_t qh[QK_K/32];\n    uint8_t signs[QK_K/8];\n    uint8_t scales[IQ3S_N_SCALE];\n} 
block_iq3_s;\nstatic_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, \"wrong iq3_s block size/padding\");\n\n// 1.5625 bpw\ntypedef struct {\n    ggml_half d;\n    uint8_t  qs[QK_K/8];\n    uint16_t qh[QK_K/32];\n} block_iq1_s;\nstatic_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, \"wrong iq1_s block size/padding\");\n\n// 1.75 bpw\ntypedef struct {\n    uint8_t  qs[QK_K/8];      // grid index, low 8 bits\n    uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)\n    uint8_t  scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)\n} block_iq1_m;\nstatic_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, \"wrong iq1_m block size/padding\");\n\n// Used by IQ1_M quants\ntypedef union {\n    ggml_half f16;\n    uint16_t  u16;\n} iq1m_scale_t;\n\n// Non-linear quants\n#define QK4_NL 32\ntypedef struct {\n    ggml_half d;\n    uint8_t qs[QK4_NL/2];\n} block_iq4_nl;\nstatic_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, \"wrong iq4_nl block size/padding\");\n\ntypedef struct {\n    ggml_half d;\n    uint16_t scales_h;\n    uint8_t  scales_l[QK_K/64];\n    uint8_t  qs[QK_K/2];\n} block_iq4_xs;\nstatic_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, \"wrong iq4_xs block size/padding\");\n\n#endif // GGML_COMMON_DECL\n#endif // GGML_COMMON_DECL\n\n////////////////////////////////////////////////////////////////////////////////\n\n#ifndef GGML_COMMON_IMPL\n\n#if defined(GGML_COMMON_IMPL_C)\n#include <stdint.h>\n\n#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {\n#define GGML_TABLE_END() };\n\n#define GGML_COMMON_IMPL\n#elif defined(GGML_COMMON_IMPL_CPP)\n#include <cstdint>\n\n#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {\n#define GGML_TABLE_END() };\n\n#define GGML_COMMON_IMPL\n#elif defined(GGML_COMMON_IMPL_METAL)\n#include <metal_stdlib>\n\n#define 
GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = {\n#define GGML_TABLE_END() };\n\n#define GGML_COMMON_IMPL\n#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)\n#include <cstdint>\n\n#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {\n#define GGML_TABLE_END() };\n\n#define GGML_COMMON_IMPL\n#elif defined(GGML_COMMON_IMPL_SYCL)\n\n#include <cstdint>\n\n#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {\n#define GGML_TABLE_END() };\n\n#define GGML_COMMON_IMPL\n#endif\n\n#if defined(GGML_COMMON_IMPL)\n\nGGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8)\n    1, 2, 4, 8, 16, 32, 64, 128\nGGML_TABLE_END()\n\nGGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)\n      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,\n    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,\n    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,\n     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,\n    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,\n     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,\n     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,\n    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,\nGGML_TABLE_END()\n\nGGML_TABLE_BEGIN(uint64_t, ksigns64, 128)\n    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,\n    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,\n    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,\n    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,\n    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,\n    0x000000ff00ff0000, 0xff0000ff00ff00ff, 
0xff0000ff00ffff00, 0x000000ff00ffffff,\n    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,\n    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,\n    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,\n    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,\n    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,\n    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,\n    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,\n    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,\n    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,\n    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,\n    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,\n    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,\n    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,\n    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,\n    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,\n    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,\n    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,\n    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,\n    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,\n    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,\n    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,\n    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,\n    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,\n    
0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,\n    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,\n    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,\nGGML_TABLE_END()\n\n\nGGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)\n    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,\n    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,\n    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,\n    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,\n    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,\n    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,\n    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,\n    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,\n    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,\n    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,\n    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,\n    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,\n    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,\n    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,\n    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,\n    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,\n    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,\n    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,\n    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,\n    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 
0x081908082b190808,\n    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,\n    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,\n    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,\n    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,\n    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,\n    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,\n    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,\n    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,\n    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,\n    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,\n    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,\n    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,\n    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,\n    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,\n    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,\n    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,\n    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,\n    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,\n    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,\n    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,\n    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,\n    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,\n    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,\n    0x19082b2b19081919, 
0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,\n    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,\n    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,\n    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,\n    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,\n    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,\n    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,\n    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,\n    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,\n    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,\n    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,\n    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,\n    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,\n    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,\n    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,\n    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,\n    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,\n    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,\n    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,\n    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,\n    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,\nGGML_TABLE_END()\n\nGGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512)\n    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,\n    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,\n    
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,\n    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,\n    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,\n    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,\n    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,\n    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,\n    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,\n    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,\n    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,\n    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,\n    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,\n    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,\n    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,\n    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,\n    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,\n    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,\n    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,\n    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,\n    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,\n    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,\n    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,\n    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,\n    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,\n    0x0808192b08081908, 0x0808192b08190808, 
0x0808192b082b192b, 0x0808192b19080808,\n    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,\n    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,\n    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,\n    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,\n    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,\n    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,\n    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,\n    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,\n    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,\n    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,\n    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,\n    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,\n    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,\n    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,\n    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,\n    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,\n    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,\n    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,\n    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,\n    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,\n    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,\n    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,\n    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,\n    
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,\n    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,\n    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,\n    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,\n    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,\n    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,\n    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,\n    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,\n    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,\n    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,\n    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,\n    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,\n    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,\n    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,\n    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,\n    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,\n    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,\n    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,\n    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,\n    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,\n    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,\n    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,\n    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,\n    0x190808192b2b082b, 0x1908082b08080819, 
0x1908082b08081908, 0x1908082b08190808,\n    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,\n    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,\n    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,\n    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,\n    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,\n    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,\n    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,\n    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,\n    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,\n    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,\n    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,\n    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,\n    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,\n    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,\n    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,\n    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,\n    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,\n    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,\n    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,\n    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,\n    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,\n    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,\n    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,\n    
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,\n    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,\n    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,\n    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,\n    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,\n    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,\n    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,\n    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,\n    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,\n    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,\n    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,\n    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,\n    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,\n    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,\n    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,\n    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,\n    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,\n    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,\n    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,\n    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,\n    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,\n    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,\n    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,\n    0x2b192b1919082b19, 0x2b192b2b08191919, 
0x2b192b2b192b0808, 0x2b2b080808080808,\n    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,\n    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,\n    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,\n    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,\n    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,\n    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,\n    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,\n    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,\nGGML_TABLE_END()\n\nGGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024)\n    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,\n    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,\n    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,\n    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,\n    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,\n    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,\n    0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,\n    0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,\n    0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,\n    0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,\n    0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,\n    0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,\n    0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,\n    0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,\n    0x080808192b081908, 
0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,\n    0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,\n    0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,\n    0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,\n    0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,\n    0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,\n    0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,\n    0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,\n    0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,\n    0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,\n    0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,\n    0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,\n    0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,\n    0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,\n    0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,\n    0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,\n    0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,\n    0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,\n    0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,\n    0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,\n    0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,\n    0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,\n    0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,\n    0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 
0x0808192b19082b08,\n    0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,\n    0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,\n    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,\n    0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,\n    0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,\n    0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,\n    0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,\n    0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,\n    0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,\n    0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,\n    0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,\n    0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,\n    0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,\n    0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,\n    0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,\n    0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,\n    0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,\n    0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,\n    0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,\n    0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,\n    0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,\n    0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,\n    0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,\n    0x081908190819192b, 
0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,\n    0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,\n    0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,\n    0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,\n    0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,\n    0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,\n    0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,\n    0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,\n    0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,\n    0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,\n    0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,\n    0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,\n    0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,\n    0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,\n    0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,\n    0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,\n    0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,\n    0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,\n    0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,\n    0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,\n    0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,\n    0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,\n    0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,\n    0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 
0x08192b0808080819,\n    0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,\n    0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,\n    0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,\n    0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,\n    0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,\n    0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,\n    0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,\n    0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,\n    0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,\n    0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,\n    0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,\n    0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,\n    0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,\n    0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,\n    0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,\n    0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,\n    0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,\n    0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,\n    0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,\n    0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,\n    0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,\n    0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,\n    0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,\n    0x082b190819191908, 
0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,\n    0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,\n    0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,\n    0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,\n    0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,\n    0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,\n    0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,\n    0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,\n    0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,\n    0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,\n    0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,\n    0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,\n    0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,\n    0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,\n    0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,\n    0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,\n    0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,\n    0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,\n    0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,\n    0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,\n    0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,\n    0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,\n    0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,\n    0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 
0x190808192b190819,\n    0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,\n    0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,\n    0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,\n    0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,\n    0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,\n    0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,\n    0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,\n    0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,\n    0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,\n    0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,\n    0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,\n    0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,\n    0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,\n    0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,\n    0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,\n    0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,\n    0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,\n    0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,\n    0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,\n    0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,\n    0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,\n    0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,\n    0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,\n    0x19082b0819082b08, 
0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,\n    0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,\n    0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,\n    0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,\n    0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,\n    0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,\n    0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,\n    0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,\n    0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,\n    0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,\n    0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,\n    0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,\n    0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,\n    0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,\n    0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,\n    0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,\n    0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,\n    0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,\n    0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,\n    0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,\n    0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,\n    0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,\n    0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,\n    0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 
0x1919190819190819,\n    0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,\n    0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,\n    0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,\n    0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,\n    0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,\n    0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,\n    0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,\n    0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,\n    0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,\n    0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,\n    0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,\n    0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,\n    0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,\n    0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,\n    0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,\n    0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,\n    0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,\n    0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,\n    0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,\n    0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,\n    0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,\n    0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,\n    0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,\n    0x192b192b08080808, 
0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,\n    0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,\n    0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,\n    0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,\n    0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,\n    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,\n    0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,\n    0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,\n    0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,\n    0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,\n    0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,\n    0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,\n    0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,\n    0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,\n    0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,\n    0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,\n    0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,\n    0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,\n    0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,\n    0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,\n    0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,\n    0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,\n    0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,\n    0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 
0x2b082b0808081919,\n    0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,\n    0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,\n    0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,\n    0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,\n    0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,\n    0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,\n    0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,\n    0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,\n    0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,\n    0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,\n    0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,\n    0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,\n    0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,\n    0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,\n    0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,\n    0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,\n    0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,\n    0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,\n    0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,\n    0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,\n    0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,\n    0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,\n    0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,\n    0x2b2b082b082b082b, 
0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,\n    0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,\n    0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,\n    0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,\n    0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,\n    0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,\n    0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,\nGGML_TABLE_END()\n\nGGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256)\n    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,\n    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,\n    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,\n    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,\n    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,\n    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,\n    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,\n    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,\n    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,\n    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,\n    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,\n    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,\n    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,\n    0x14041c0c, 0x14042414, 0x140c040c, 
0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,\n    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,\n    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,\n    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,\n    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,\n    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,\n    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,\n    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,\n    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,\n    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,\n    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,\n    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,\n    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,\n    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,\n    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,\n    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,\n    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,\n    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,\n    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,\nGGML_TABLE_END()\n\nGGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)\n    0x01010101, 0x01010103, 0x01010105, 0x0101010b, 
0x0101010f, 0x01010301, 0x01010303, 0x01010305,\n    0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,\n    0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,\n    0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,\n    0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,\n    0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,\n    0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,\n    0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,\n    0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,\n    0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,\n    0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,\n    0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,\n    0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,\n    0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,\n    0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,\n    0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,\n    0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,\n    0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,\n    0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,\n    0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,\n    0x03070701, 0x03070709, 
0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,\n    0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,\n    0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,\n    0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,\n    0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,\n    0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,\n    0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,\n    0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,\n    0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,\n    0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,\n    0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,\n    0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,\n    0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,\n    0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,\n    0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,\n    0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,\n    0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,\n    0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,\n    0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,\n    0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,\n    
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,\n    0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,\n    0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,\n    0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,\n    0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,\n    0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,\n    0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,\n    0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,\n    0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,\n    0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,\n    0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,\n    0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,\n    0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,\n    0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,\n    0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,\n    0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,\n    0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,\n    0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,\n    0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,\n    0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 
0x0f010109, 0x0f01010f,\n    0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,\n    0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,\n    0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,\n    0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,\nGGML_TABLE_END()\n\n// TODO: fix name to kvalues_iq4_nl\nGGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)\n    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,\nGGML_TABLE_END()\n\n// e2m1 values (doubled)\n// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf\nGGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)\n    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,\nGGML_TABLE_END()\n\n#define NGRID_IQ1S 2048\n#define IQ1S_DELTA 0.125f\n#define IQ1M_DELTA 0.125f\n#if defined(GGML_COMMON_IMPL_C)\nGGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)\n    0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,\n    0xffffffffffff0101, 0xffffffffff00ff00, 0xffffffffff000000, 0xffffffffff01ffff,\n    0xffffffffff01ff01, 0xffffffffff0101ff, 0xffffffffff010101, 0xffffffff00ff0000,\n    0xffffffff0000ff00, 0xffffffff000000ff, 0xffffffff00000001, 0xffffffff00010000,\n    0xffffffff01ffffff, 0xffffffff01ffff01, 0xffffffff01ff01ff, 0xffffffff01ff0101,\n    0xffffffff01000000, 0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff010101ff,\n    0xffffffff01010101, 0xffffff00ffff00ff, 0xffffff00ffff0000, 0xffffff00ff00ff00,\n    0xffffff00ff0000ff, 0xffffff00ff000001, 0xffffff00ff000100, 0xffffff00ff000101,\n    0xffffff00ff010000, 0xffffff0000ffff00, 0xffffff0000ff0001, 0xffffff0000ff0100,\n    0xffffff000000ff01, 0xffffff0000000000, 0xffffff0000000101, 0xffffff000001ff00,\n    0xffffff00000100ff, 0xffffff0000010001, 0xffffff00000101ff, 0xffffff0001ff0000,\n    0xffffff000100ff00, 
0xffffff00010000ff, 0xffffff0001000001, 0xffffff0001010000,\n    0xffffff01ffffffff, 0xffffff01ffffff01, 0xffffff01ffff01ff, 0xffffff01ffff0101,\n    0xffffff01ff000000, 0xffffff01ff01ffff, 0xffffff01ff01ff01, 0xffffff01ff0101ff,\n    0xffffff01ff010101, 0xffffff0100ff0000, 0xffffff010000ff00, 0xffffff0100000100,\n    0xffffff01000100ff, 0xffffff0100010100, 0xffffff0101ffffff, 0xffffff0101ffff01,\n    0xffffff0101ff01ff, 0xffffff0101ff0101, 0xffffff010100ff00, 0xffffff0101000000,\n    0xffffff0101000100, 0xffffff010101ffff, 0xffffff010101ff01, 0xffffff01010101ff,\n    0xffffff0101010101, 0xffff00ffff00ff00, 0xffff00ffff0000ff, 0xffff00ffff000001,\n    0xffff00ffff010000, 0xffff00ff00ffff00, 0xffff00ff00ff0100, 0xffff00ff00000000,\n    0xffff00ff00000101, 0xffff00ff000100ff, 0xffff00ff00010000, 0xffff00ff0100ff00,\n    0xffff00ff01000100, 0xffff00ff01010000, 0xffff0000ffffff00, 0xffff0000ffff00ff,\n    0xffff0000ffff0000, 0xffff0000ffff0001, 0xffff0000ff000000, 0xffff0000ff0001ff,\n    0xffff0000ff000101, 0xffff0000ff010100, 0xffff000000ffffff, 0xffff000000ff0000,\n    0xffff000000ff0101, 0xffff00000000ffff, 0xffff00000000ff00, 0xffff0000000000ff,\n    0xffff000000000000, 0xffff000000000001, 0xffff000000000100, 0xffff00000001ffff,\n    0xffff00000001ff01, 0xffff000000010000, 0xffff0000000101ff, 0xffff000000010101,\n    0xffff000001ffff00, 0xffff00000100ff00, 0xffff000001000000, 0xffff0000010001ff,\n    0xffff000001000101, 0xffff00000101ff00, 0xffff0000010100ff, 0xffff000001010000,\n    0xffff000001010001, 0xffff000001010100, 0xffff0001ff0000ff, 0xffff0001ff000100,\n    0xffff000100ffff00, 0xffff000100ff00ff, 0xffff00010000ffff, 0xffff00010000ff01,\n    0xffff000100000000, 0xffff0001000001ff, 0xffff00010001ffff, 0xffff00010001ff00,\n    0xffff000100010001, 0xffff000100010100, 0xffff000101ff0000, 0xffff00010100ff00,\n    0xffff0001010000ff, 0xffff000101000100, 0xffff01ffffffffff, 0xffff01ffffffff01,\n    0xffff01ffffff01ff, 0xffff01ffffff0101, 0xffff01ffff000000, 
0xffff01ffff01ffff,\n    0xffff01ffff01ff01, 0xffff01ffff0101ff, 0xffff01ffff010101, 0xffff01ff00ff0000,\n    0xffff01ff0000ff00, 0xffff01ff00000001, 0xffff01ff00010000, 0xffff01ff01ffffff,\n    0xffff01ff01ffff01, 0xffff01ff01ff01ff, 0xffff01ff01ff0101, 0xffff01ff01000000,\n    0xffff01ff0101ffff, 0xffff01ff0101ff01, 0xffff01ff010101ff, 0xffff01ff01010101,\n    0xffff0100ffff0000, 0xffff0100ff00ff00, 0xffff0100ff0000ff, 0xffff0100ff000100,\n    0xffff0100ff0100ff, 0xffff0100ff010000, 0xffff010000ffff00, 0xffff01000000ffff,\n    0xffff01000000ff00, 0xffff010000000000, 0xffff01000001ff00, 0xffff0100000100ff,\n    0xffff010000010100, 0xffff01000100ff00, 0xffff0100010000ff, 0xffff010001000001,\n    0xffff010001000100, 0xffff010001010000, 0xffff0101ffffffff, 0xffff0101ffffff01,\n    0xffff0101ffff01ff, 0xffff0101ffff0101, 0xffff0101ff000000, 0xffff0101ff01ffff,\n    0xffff0101ff01ff01, 0xffff0101ff0101ff, 0xffff0101ff010101, 0xffff010100ff0000,\n    0xffff01010000ff00, 0xffff010100000100, 0xffff01010001ff00, 0xffff010100010000,\n    0xffff010101ffffff, 0xffff010101ffff01, 0xffff010101ff0000, 0xffff010101ff01ff,\n    0xffff010101ff0101, 0xffff010101000000, 0xffff01010101ffff, 0xffff01010101ff01,\n    0xffff0101010101ff, 0xffff010101010101, 0xff00ffffff00ffff, 0xff00ffffff00ff00,\n    0xff00ffffff0000ff, 0xff00ffffff000100, 0xff00ffffff0100ff, 0xff00ffffff010000,\n    0xff00ffff00ffff00, 0xff00ffff00ff00ff, 0xff00ffff0000ffff, 0xff00ffff00000000,\n    0xff00ffff000001ff, 0xff00ffff0001ff00, 0xff00ffff000100ff, 0xff00ffff00010000,\n    0xff00ffff00010100, 0xff00ffff0100ff00, 0xff00ffff010000ff, 0xff00ffff01000001,\n    0xff00ffff0101ff00, 0xff00ffff01010000, 0xff00ff00ffffff00, 0xff00ff00ffff00ff,\n    0xff00ff00ffff0001, 0xff00ff00ffff0100, 0xff00ff00ff00ffff, 0xff00ff00ff00ff01,\n    0xff00ff00ff000000, 0xff00ff00ff0001ff, 0xff00ff00ff01ff00, 0xff00ff00ff0100ff,\n    0xff00ff00ff010100, 0xff00ff0000ff0000, 0xff00ff0000ff0101, 0xff00ff000000ffff,\n    0xff00ff000000ff00, 
0xff00ff000000ff01, 0xff00ff00000000ff, 0xff00ff0000000000,\n    0xff00ff0000000001, 0xff00ff0000000100, 0xff00ff000001ffff, 0xff00ff0000010000,\n    0xff00ff0001ff00ff, 0xff00ff000100ff01, 0xff00ff0001000000, 0xff00ff000101ff00,\n    0xff00ff00010100ff, 0xff00ff01ff00ff00, 0xff00ff01ff0000ff, 0xff00ff01ff000001,\n    0xff00ff01ff010000, 0xff00ff0100ffffff, 0xff00ff0100ff0001, 0xff00ff0100ff0100,\n    0xff00ff010000ff01, 0xff00ff0100000000, 0xff00ff01000001ff, 0xff00ff0100000101,\n    0xff00ff01000100ff, 0xff00ff0100010001, 0xff00ff0101ff0000, 0xff00ff010100ff00,\n    0xff00ff01010000ff, 0xff00ff0101000001, 0xff00ff0101010000, 0xff0000ffffffff00,\n    0xff0000ffffff0001, 0xff0000ffffff0100, 0xff0000ffff0000ff, 0xff0000ffff000000,\n    0xff0000ffff0001ff, 0xff0000ffff000100, 0xff0000ffff01ff00, 0xff0000ffff010001,\n    0xff0000ff00ffff00, 0xff0000ff00ff0000, 0xff0000ff00ff0001, 0xff0000ff00ff01ff,\n    0xff0000ff00ff0101, 0xff0000ff0000ff00, 0xff0000ff000000ff, 0xff0000ff00000000,\n    0xff0000ff00000001, 0xff0000ff00000100, 0xff0000ff0001ff01, 0xff0000ff00010000,\n    0xff0000ff000101ff, 0xff0000ff01ff00ff, 0xff0000ff01ff0100, 0xff0000ff0100ffff,\n    0xff0000ff010000ff, 0xff0000ff01000000, 0xff0000ff010001ff, 0xff0000ff01000100,\n    0xff0000ff01000101, 0xff0000ff0101ff00, 0xff0000ff010100ff, 0xff0000ff01010000,\n    0xff0000ff01010100, 0xff000000ffffff01, 0xff000000ffff0000, 0xff000000ffff0101,\n    0xff000000ff00ff00, 0xff000000ff0000ff, 0xff000000ff000000, 0xff000000ff000001,\n    0xff000000ff000100, 0xff000000ff01ffff, 0xff000000ff01ff01, 0xff000000ff010000,\n    0xff000000ff0101ff, 0xff000000ff010101, 0xff00000000ffff00, 0xff00000000ff00ff,\n    0xff00000000ff0000, 0xff00000000ff0001, 0xff0000000000ff00, 0xff0000000000ff01,\n    0xff000000000000ff, 0xff00000000000000, 0xff00000000000001, 0xff00000000000100,\n    0xff00000000000101, 0xff0000000001ff00, 0xff000000000100ff, 0xff00000000010000,\n    0xff00000000010001, 0xff00000000010100, 0xff00000001ffffff, 
0xff00000001ffff01,\n    0xff00000001ff00ff, 0xff00000001ff0000, 0xff00000001ff01ff, 0xff00000001ff0101,\n    0xff0000000100ffff, 0xff0000000100ff00, 0xff000000010000ff, 0xff00000001000000,\n    0xff00000001000001, 0xff00000001000100, 0xff00000001000101, 0xff0000000101ffff,\n    0xff0000000101ff01, 0xff00000001010000, 0xff000001ffffff00, 0xff000001ffff00ff,\n    0xff000001ffff0000, 0xff000001ffff0001, 0xff000001ff000000, 0xff000001ff000001,\n    0xff000001ff0001ff, 0xff000001ff000101, 0xff000001ff01ff00, 0xff000001ff010001,\n    0xff00000100ffffff, 0xff00000100ffff01, 0xff00000100ff00ff, 0xff00000100ff0000,\n    0xff00000100ff01ff, 0xff00000100ff0101, 0xff0000010000ff00, 0xff00000100000000,\n    0xff00000100000001, 0xff000001000001ff, 0xff00000100000100, 0xff0000010001ff00,\n    0xff000001000100ff, 0xff00000100010000, 0xff000001000101ff, 0xff00000100010100,\n    0xff00000100010101, 0xff00000101ff0001, 0xff00000101ff0101, 0xff0000010100ff01,\n    0xff00000101000000, 0xff000001010100ff, 0xff00000101010100, 0xff0001ffff00ff00,\n    0xff0001ffff000001, 0xff0001ffff010000, 0xff0001ff00ffff00, 0xff0001ff00ff00ff,\n    0xff0001ff00ff0001, 0xff0001ff00ff0100, 0xff0001ff0000ffff, 0xff0001ff00000000,\n    0xff0001ff000001ff, 0xff0001ff00000101, 0xff0001ff0001ffff, 0xff0001ff0001ff00,\n    0xff0001ff000100ff, 0xff0001ff00010001, 0xff0001ff00010100, 0xff0001ff01ff0000,\n    0xff0001ff0100ff00, 0xff0001ff010000ff, 0xff0001ff01010000, 0xff000100ff00ffff,\n    0xff000100ff00ff01, 0xff000100ff000000, 0xff000100ff000101, 0xff000100ff01ff00,\n    0xff000100ff010000, 0xff00010000ffff01, 0xff00010000ff00ff, 0xff00010000ff0000,\n    0xff00010000ff01ff, 0xff0001000000ff00, 0xff000100000000ff, 0xff00010000000000,\n    0xff00010000000001, 0xff00010000000100, 0xff00010000000101, 0xff0001000001ffff,\n    0xff00010000010000, 0xff00010000010101, 0xff00010001ff0100, 0xff0001000100ff00,\n    0xff0001000100ff01, 0xff00010001000000, 0xff000100010001ff, 0xff0001000101ff00,\n    0xff00010001010001, 
0xff00010001010100, 0xff000101ffff0100, 0xff000101ff000001,\n    0xff000101ff0100ff, 0xff000101ff010001, 0xff00010100ff00ff, 0xff00010100ff0001,\n    0xff00010100ff0100, 0xff0001010000ffff, 0xff0001010000ff01, 0xff00010100000000,\n    0xff000101000001ff, 0xff0001010001ff00, 0xff00010100010001, 0xff00010100010100,\n    0xff00010101ff0000, 0xff0001010100ff00, 0xff00010101000001, 0xff00010101000101,\n    0xff01ffffffffffff, 0xff01ffffffffff01, 0xff01ffffffff01ff, 0xff01ffffffff0101,\n    0xff01ffffff000000, 0xff01ffffff01ffff, 0xff01ffffff01ff01, 0xff01ffffff010000,\n    0xff01ffffff0101ff, 0xff01ffffff010101, 0xff01ffff00ff0000, 0xff01ffff0000ff00,\n    0xff01ffff00000100, 0xff01ffff0001ff00, 0xff01ffff00010000, 0xff01ffff01ffffff,\n    0xff01ffff01ffff01, 0xff01ffff01ff01ff, 0xff01ffff01ff0101, 0xff01ffff01000000,\n    0xff01ffff0101ffff, 0xff01ffff0101ff01, 0xff01ffff01010000, 0xff01ffff010101ff,\n    0xff01ffff01010101, 0xff01ff00ffff0000, 0xff01ff00ff00ff00, 0xff01ff00ff0000ff,\n    0xff01ff00ff000100, 0xff01ff00ff010000, 0xff01ff0000ffff01, 0xff01ff0000ff00ff,\n    0xff01ff0000ff0100, 0xff01ff0000000000, 0xff01ff00000001ff, 0xff01ff0000000101,\n    0xff01ff000001ff00, 0xff01ff00000100ff, 0xff01ff0000010000, 0xff01ff0000010001,\n    0xff01ff0001ff0000, 0xff01ff000100ffff, 0xff01ff0001000001, 0xff01ff0001000100,\n    0xff01ff0001010000, 0xff01ff01ffffff00, 0xff01ff01ffff01ff, 0xff01ff01ffff0101,\n    0xff01ff01ff00ff00, 0xff01ff01ff000000, 0xff01ff01ff01ffff, 0xff01ff01ff01ff01,\n    0xff01ff01ff0101ff, 0xff01ff01ff010101, 0xff01ff0100ff0000, 0xff01ff010000ff00,\n    0xff01ff0100000001, 0xff01ff0100000100, 0xff01ff0100010000, 0xff01ff0101ffff00,\n    0xff01ff0101ff01ff, 0xff01ff0101ff0101, 0xff01ff010100ff00, 0xff01ff0101000000,\n    0xff01ff010101ffff, 0xff01ff010101ff01, 0xff01ff01010101ff, 0xff01ff0101010101,\n    0xff0100ffffff0000, 0xff0100ffff0000ff, 0xff0100ffff000001, 0xff0100ffff000100,\n    0xff0100ffff010000, 0xff0100ff00ff00ff, 0xff0100ff00ff0000, 
0xff0100ff00ff0001,\n    0xff0100ff00ff0100, 0xff0100ff0000ff01, 0xff0100ff00000000, 0xff0100ff000001ff,\n    0xff0100ff00000101, 0xff0100ff00010001, 0xff0100ff01ff0000, 0xff0100ff0100ff00,\n    0xff0100ff010000ff, 0xff0100ff01000100, 0xff0100ff0101ff00, 0xff0100ff01010000,\n    0xff010000ffff0100, 0xff010000ff000000, 0xff010000ff01ff00, 0xff010000ff010100,\n    0xff01000000ffffff, 0xff01000000ff0000, 0xff01000000ff01ff, 0xff0100000000ff00,\n    0xff010000000000ff, 0xff01000000000000, 0xff01000000000100, 0xff0100000001ff01,\n    0xff01000000010000, 0xff010000000101ff, 0xff01000001ff0100, 0xff0100000100ffff,\n    0xff010000010000ff, 0xff01000001000000, 0xff010000010001ff, 0xff01000001000101,\n    0xff0100000101ff00, 0xff010000010100ff, 0xff01000001010001, 0xff01000001010100,\n    0xff010001ffff0000, 0xff010001ff00ffff, 0xff010001ff00ff01, 0xff010001ff000100,\n    0xff010001ff010000, 0xff01000100ffff00, 0xff01000100ff0100, 0xff01000100000000,\n    0xff0100010001ffff, 0xff0100010001ff00, 0xff01000100010100, 0xff01000101ff00ff,\n    0xff01000101ff0001, 0xff0100010100ffff, 0xff01000101000101, 0xff0101ffffffffff,\n    0xff0101ffffffff01, 0xff0101ffffff01ff, 0xff0101ffffff0101, 0xff0101ffff000000,\n    0xff0101ffff01ffff, 0xff0101ffff01ff01, 0xff0101ffff0101ff, 0xff0101ffff010101,\n    0xff0101ff00ff0000, 0xff0101ff0000ff00, 0xff0101ff000000ff, 0xff0101ff00010000,\n    0xff0101ff01ffffff, 0xff0101ff01ffff01, 0xff0101ff01ff01ff, 0xff0101ff01ff0101,\n    0xff0101ff0101ffff, 0xff0101ff0101ff01, 0xff0101ff010101ff, 0xff0101ff01010101,\n    0xff010100ffff0100, 0xff010100ff00ff00, 0xff010100ff0000ff, 0xff010100ff000100,\n    0xff010100ff010000, 0xff01010000ff0001, 0xff01010000ff0100, 0xff0101000000ff01,\n    0xff01010000000000, 0xff0101000001ff00, 0xff010100000100ff, 0xff01010000010001,\n    0xff01010000010100, 0xff01010001ff0000, 0xff0101000100ffff, 0xff01010001000001,\n    0xff01010001000100, 0xff010100010100ff, 0xff01010001010000, 0xff010101ffffffff,\n    0xff010101ffffff01, 
0xff010101ffff01ff, 0xff010101ffff0101, 0xff010101ff01ffff,\n    0xff010101ff01ff01, 0xff010101ff0101ff, 0xff010101ff010101, 0xff01010100ff0000,\n    0xff0101010000ff00, 0xff01010100000001, 0xff01010100000100, 0xff01010100010000,\n    0xff01010101ffffff, 0xff01010101ffff01, 0xff01010101ff01ff, 0xff01010101ff0101,\n    0xff01010101000000, 0xff0101010101ffff, 0xff0101010101ff01, 0xff010101010101ff,\n    0xff01010101010101, 0x00ffffffffff0000, 0x00ffffffff00ff00, 0x00ffffffff000001,\n    0x00ffffffff010000, 0x00ffffff00ff0100, 0x00ffffff0000ff01, 0x00ffffff00000000,\n    0x00ffffff000001ff, 0x00ffffff00000101, 0x00ffffff0001ff00, 0x00ffffff000100ff,\n    0x00ffffff00010001, 0x00ffffff010000ff, 0x00ffffff01000100, 0x00ffffff0101ff00,\n    0x00ffffff01010001, 0x00ffff00ffffffff, 0x00ffff00ffffff00, 0x00ffff00ffff00ff,\n    0x00ffff00ffff0001, 0x00ffff00ffff0100, 0x00ffff00ff00ff01, 0x00ffff00ff000000,\n    0x00ffff00ff000001, 0x00ffff00ff0001ff, 0x00ffff00ff000101, 0x00ffff00ff01ff00,\n    0x00ffff00ff010001, 0x00ffff00ff010100, 0x00ffff0000ff0000, 0x00ffff0000ff01ff,\n    0x00ffff0000ff0101, 0x00ffff000000ff00, 0x00ffff00000000ff, 0x00ffff0000000000,\n    0x00ffff0000000001, 0x00ffff0000000100, 0x00ffff0000000101, 0x00ffff0000010000,\n    0x00ffff00000101ff, 0x00ffff0000010101, 0x00ffff0001ffff00, 0x00ffff0001ff00ff,\n    0x00ffff0001ff0001, 0x00ffff000100ffff, 0x00ffff000100ff01, 0x00ffff0001000000,\n    0x00ffff000101ffff, 0x00ffff000101ff00, 0x00ffff000101ff01, 0x00ffff01ffff0000,\n    0x00ffff01ff00ff00, 0x00ffff01ff0000ff, 0x00ffff01ff000001, 0x00ffff01ff010000,\n    0x00ffff0100ffff00, 0x00ffff010000ff01, 0x00ffff0100000000, 0x00ffff0100000101,\n    0x00ffff01000100ff, 0x00ffff0100010100, 0x00ffff0101ff0100, 0x00ffff01010000ff,\n    0x00ffff0101010000, 0x00ff00ffffffff00, 0x00ff00ffff000000, 0x00ff00ffff000100,\n    0x00ff00ffff010100, 0x00ff00ff00ff0000, 0x00ff00ff00ff01ff, 0x00ff00ff00ff0101,\n    0x00ff00ff0000ff00, 0x00ff00ff000000ff, 0x00ff00ff00000000, 
0x00ff00ff00000001,\n    0x00ff00ff0001ff00, 0x00ff00ff0001ff01, 0x00ff00ff00010000, 0x00ff00ff000101ff,\n    0x00ff00ff00010101, 0x00ff00ff01ffff00, 0x00ff00ff01ff0001, 0x00ff00ff01ff0100,\n    0x00ff00ff0100ffff, 0x00ff00ff0100ff01, 0x00ff00ff01000000, 0x00ff00ff0101ffff,\n    0x00ff00ff0101ff00, 0x00ff00ff01010100, 0x00ff0000ffffff00, 0x00ff0000ffffff01,\n    0x00ff0000ffff0000, 0x00ff0000ffff0101, 0x00ff0000ff00ff00, 0x00ff0000ff0000ff,\n    0x00ff0000ff000000, 0x00ff0000ff000001, 0x00ff0000ff000100, 0x00ff0000ff01ffff,\n    0x00ff0000ff010000, 0x00ff0000ff010101, 0x00ff000000ffff00, 0x00ff000000ff00ff,\n    0x00ff000000ff0000, 0x00ff000000ff0001, 0x00ff000000ff0100, 0x00ff00000000ffff,\n    0x00ff00000000ff00, 0x00ff0000000000ff, 0x00ff000000000000, 0x00ff000000000001,\n    0x00ff0000000001ff, 0x00ff000000000100, 0x00ff00000001ff00, 0x00ff0000000100ff,\n    0x00ff000000010000, 0x00ff000000010001, 0x00ff000000010100, 0x00ff000001ffff01,\n    0x00ff000001ff00ff, 0x00ff000001ff0000, 0x00ff000001ff01ff, 0x00ff00000100ff00,\n    0x00ff0000010000ff, 0x00ff000001000000, 0x00ff000001000001, 0x00ff000001000100,\n    0x00ff000001000101, 0x00ff000001010000, 0x00ff0000010101ff, 0x00ff000001010101,\n    0x00ff0001ffffff00, 0x00ff0001ffff0000, 0x00ff0001ffff0100, 0x00ff0001ff0000ff,\n    0x00ff0001ff000000, 0x00ff0001ff0001ff, 0x00ff0001ff000101, 0x00ff0001ff01ff00,\n    0x00ff0001ff0100ff, 0x00ff0001ff010100, 0x00ff000100ffffff, 0x00ff000100ffff01,\n    0x00ff000100ff0000, 0x00ff000100ff01ff, 0x00ff00010000ffff, 0x00ff00010000ff00,\n    0x00ff00010000ff01, 0x00ff000100000000, 0x00ff000100000001, 0x00ff000100000100,\n    0x00ff00010001ff01, 0x00ff000100010000, 0x00ff0001000101ff, 0x00ff000101ffff00,\n    0x00ff000101ff0000, 0x00ff000101ff0101, 0x00ff0001010000ff, 0x00ff000101000000,\n    0x00ff00010101ff00, 0x00ff0001010100ff, 0x00ff000101010001, 0x00ff01ffffff0000,\n    0x00ff01ffff00ff00, 0x00ff01ffff000000, 0x00ff01ffff000101, 0x00ff01ffff010000,\n    0x00ff01ff00ffff01, 
0x00ff01ff00ff0100, 0x00ff01ff0000ffff, 0x00ff01ff00000000,\n    0x00ff01ff000001ff, 0x00ff01ff0001ff00, 0x00ff01ff000100ff, 0x00ff01ff00010001,\n    0x00ff01ff00010100, 0x00ff01ff01ff0000, 0x00ff01ff0100ff00, 0x00ff01ff010000ff,\n    0x00ff01ff01000001, 0x00ff01ff01000100, 0x00ff01ff01010000, 0x00ff0100ffffff00,\n    0x00ff0100ffff0000, 0x00ff0100ffff0001, 0x00ff0100ffff0101, 0x00ff0100ff00ffff,\n    0x00ff0100ff0000ff, 0x00ff0100ff000000, 0x00ff0100ff0001ff, 0x00ff0100ff01ff00,\n    0x00ff0100ff0100ff, 0x00ff0100ff010001, 0x00ff010000ffffff, 0x00ff010000ff0000,\n    0x00ff010000ff0101, 0x00ff01000000ff00, 0x00ff01000000ff01, 0x00ff0100000000ff,\n    0x00ff010000000000, 0x00ff010000000001, 0x00ff010000000100, 0x00ff01000001ffff,\n    0x00ff01000001ff01, 0x00ff010000010000, 0x00ff010000010001, 0x00ff010000010101,\n    0x00ff010001ff0001, 0x00ff010001ff0100, 0x00ff01000100ff01, 0x00ff010001000000,\n    0x00ff010001000001, 0x00ff0100010001ff, 0x00ff01000101ff00, 0x00ff0100010100ff,\n    0x00ff010001010001, 0x00ff010001010100, 0x00ff0101ff000001, 0x00ff010100ff00ff,\n    0x00ff010100ff0001, 0x00ff010100ff0100, 0x00ff010100000000, 0x00ff0101000001ff,\n    0x00ff010100000101, 0x00ff0101000100ff, 0x00ff010100010100, 0x00ff0101010000ff,\n    0x00ff010101010000, 0x0000ffffffffff00, 0x0000ffffffff00ff, 0x0000ffffffff0000,\n    0x0000ffffffff0001, 0x0000ffffffff0100, 0x0000ffffff00ff01, 0x0000ffffff000000,\n    0x0000ffffff000101, 0x0000ffffff01ff00, 0x0000ffffff0100ff, 0x0000ffffff010100,\n    0x0000ffff00ffffff, 0x0000ffff00ff0000, 0x0000ffff00ff01ff, 0x0000ffff0000ff00,\n    0x0000ffff000000ff, 0x0000ffff00000000, 0x0000ffff00000001, 0x0000ffff00000100,\n    0x0000ffff00010000, 0x0000ffff000101ff, 0x0000ffff01ff0001, 0x0000ffff01ff0100,\n    0x0000ffff01000000, 0x0000ffff010001ff, 0x0000ffff0101ffff, 0x0000ffff0101ff00,\n    0x0000ffff01010001, 0x0000ffff01010100, 0x0000ff00ffff0000, 0x0000ff00ffff01ff,\n    0x0000ff00ffff0100, 0x0000ff00ffff0101, 0x0000ff00ff00ff00, 
0x0000ff00ff0000ff,\n    0x0000ff00ff000000, 0x0000ff00ff000001, 0x0000ff00ff0001ff, 0x0000ff00ff000100,\n    0x0000ff00ff01ffff, 0x0000ff00ff010000, 0x0000ff00ff010001, 0x0000ff00ff0101ff,\n    0x0000ff00ff010101, 0x0000ff0000ffff00, 0x0000ff0000ff00ff, 0x0000ff0000ff0000,\n    0x0000ff0000ff0001, 0x0000ff0000ff0100, 0x0000ff000000ffff, 0x0000ff000000ff00,\n    0x0000ff000000ff01, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,\n    0x0000ff00000001ff, 0x0000ff0000000100, 0x0000ff0000000101, 0x0000ff000001ff00,\n    0x0000ff00000100ff, 0x0000ff0000010000, 0x0000ff0000010001, 0x0000ff0000010100,\n    0x0000ff0001ffff01, 0x0000ff0001ff0000, 0x0000ff000100ff00, 0x0000ff00010000ff,\n    0x0000ff0001000000, 0x0000ff0001000001, 0x0000ff0001000100, 0x0000ff000101ffff,\n    0x0000ff0001010000, 0x0000ff0001010101, 0x0000ff01ffffff00, 0x0000ff01ffff0001,\n    0x0000ff01ff00ff01, 0x0000ff01ff000000, 0x0000ff01ff000101, 0x0000ff01ff01ff00,\n    0x0000ff01ff0100ff, 0x0000ff0100ffff01, 0x0000ff0100ff0000, 0x0000ff0100ff0101,\n    0x0000ff010000ff00, 0x0000ff01000000ff, 0x0000ff0100000000, 0x0000ff0100000001,\n    0x0000ff0100000100, 0x0000ff010001ff01, 0x0000ff0100010000, 0x0000ff0101ff0000,\n    0x0000ff010100ffff, 0x0000ff010100ff01, 0x0000ff0101000000, 0x0000ff0101000100,\n    0x0000ff0101000101, 0x0000ff01010100ff, 0x000000ffffff00ff, 0x000000ffffff0000,\n    0x000000ffff00ff00, 0x000000ffff0000ff, 0x000000ffff000000, 0x000000ffff000001,\n    0x000000ffff0001ff, 0x000000ffff000100, 0x000000ffff01ff00, 0x000000ffff010000,\n    0x000000ffff0101ff, 0x000000ffff010101, 0x000000ff00ffff00, 0x000000ff00ff00ff,\n    0x000000ff00ff0000, 0x000000ff00ff0001, 0x000000ff00ff0100, 0x000000ff00ff0101,\n    0x000000ff0000ffff, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,\n    0x000000ff00000001, 0x000000ff000001ff, 0x000000ff00000100, 0x000000ff00000101,\n    0x000000ff0001ff00, 0x000000ff0001ff01, 0x000000ff000100ff, 0x000000ff00010000,\n    0x000000ff00010001, 
0x000000ff00010100, 0x000000ff01ffffff, 0x000000ff01ff01ff,\n    0x000000ff01ff0101, 0x000000ff0100ff00, 0x000000ff010000ff, 0x000000ff01000000,\n    0x000000ff01000001, 0x000000ff01000100, 0x000000ff0101ff00, 0x000000ff010100ff,\n    0x000000ff01010000, 0x000000ff01010101, 0x00000000ffffff00, 0x00000000ffffff01,\n    0x00000000ffff00ff, 0x00000000ffff0000, 0x00000000ffff0001, 0x00000000ffff0100,\n    0x00000000ff00ffff, 0x00000000ff00ff00, 0x00000000ff00ff01, 0x00000000ff0000ff,\n    0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff000101,\n    0x00000000ff01ff00, 0x00000000ff0100ff, 0x00000000ff010000, 0x00000000ff010001,\n    0x00000000ff010100, 0x0000000000ffffff, 0x0000000000ffff00, 0x0000000000ffff01,\n    0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, 0x0000000000ff01ff,\n    0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,\n    0x00000000000000ff, 0x0000000000000000, 0x0000000000000001, 0x00000000000001ff,\n    0x0000000000000100, 0x0000000000000101, 0x000000000001ffff, 0x000000000001ff00,\n    0x00000000000100ff, 0x0000000000010000, 0x0000000000010001, 0x00000000000101ff,\n    0x0000000000010100, 0x0000000000010101, 0x0000000001ffff00, 0x0000000001ff00ff,\n    0x0000000001ff0000, 0x0000000001ff0100, 0x0000000001ff0101, 0x000000000100ffff,\n    0x000000000100ff00, 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001,\n    0x00000000010001ff, 0x0000000001000100, 0x000000000101ff00, 0x00000000010100ff,\n    0x0000000001010000, 0x0000000001010001, 0x0000000001010100, 0x00000001ffffffff,\n    0x00000001ffffff00, 0x00000001ffffff01, 0x00000001ffff00ff, 0x00000001ffff0001,\n    0x00000001ffff01ff, 0x00000001ffff0100, 0x00000001ff00ff00, 0x00000001ff0000ff,\n    0x00000001ff000000, 0x00000001ff0001ff, 0x00000001ff000100, 0x00000001ff01ffff,\n    0x00000001ff01ff00, 0x00000001ff01ff01, 0x00000001ff0100ff, 0x00000001ff010000,\n    0x00000001ff010001, 0x00000001ff0101ff, 0x00000001ff010100, 
0x0000000100ffff00,\n    0x0000000100ff0000, 0x0000000100ff0001, 0x0000000100ff01ff, 0x0000000100ff0100,\n    0x0000000100ff0101, 0x000000010000ffff, 0x000000010000ff00, 0x000000010000ff01,\n    0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, 0x00000001000001ff,\n    0x0000000100000100, 0x0000000100000101, 0x000000010001ff00, 0x00000001000100ff,\n    0x0000000100010000, 0x0000000100010100, 0x0000000101ffff01, 0x0000000101ff0000,\n    0x0000000101ff0001, 0x0000000101ff01ff, 0x0000000101ff0100, 0x0000000101ff0101,\n    0x000000010100ff00, 0x0000000101000000, 0x0000000101000101, 0x000000010101ff01,\n    0x0000000101010000, 0x0000000101010001, 0x00000001010101ff, 0x0000000101010100,\n    0x000001ffffff00ff, 0x000001ffffff0000, 0x000001ffffff0001, 0x000001ffffff0100,\n    0x000001ffff00ffff, 0x000001ffff000000, 0x000001ffff0001ff, 0x000001ffff01ff00,\n    0x000001ffff010101, 0x000001ff00ff0000, 0x000001ff00ff01ff, 0x000001ff00ff0101,\n    0x000001ff0000ff00, 0x000001ff000000ff, 0x000001ff00000000, 0x000001ff00000001,\n    0x000001ff000001ff, 0x000001ff00000100, 0x000001ff0001ffff, 0x000001ff0001ff01,\n    0x000001ff000100ff, 0x000001ff00010000, 0x000001ff01ffff01, 0x000001ff01ff0100,\n    0x000001ff0100ffff, 0x000001ff0100ff01, 0x000001ff01000000, 0x000001ff010001ff,\n    0x000001ff0101ff00, 0x000001ff01010100, 0x00000100ffffff00, 0x00000100ffffff01,\n    0x00000100ffff0000, 0x00000100ffff0101, 0x00000100ff00ff00, 0x00000100ff0000ff,\n    0x00000100ff000000, 0x00000100ff000001, 0x00000100ff000100, 0x00000100ff010000,\n    0x0000010000ffff00, 0x0000010000ff00ff, 0x0000010000ff0000, 0x0000010000ff0001,\n    0x0000010000ff0100, 0x000001000000ffff, 0x000001000000ff00, 0x000001000000ff01,\n    0x00000100000000ff, 0x0000010000000000, 0x0000010000000001, 0x00000100000001ff,\n    0x0000010000000100, 0x0000010000000101, 0x000001000001ff00, 0x00000100000100ff,\n    0x0000010000010000, 0x0000010000010001, 0x0000010000010100, 0x0000010001ffff00,\n    0x0000010001ff0000, 
0x0000010001ff0100, 0x000001000100ff00, 0x00000100010000ff,\n    0x0000010001000000, 0x0000010001000001, 0x00000100010001ff, 0x0000010001000100,\n    0x0000010001010000, 0x00000101ffff00ff, 0x00000101ffff01ff, 0x00000101ff000000,\n    0x00000101ff000101, 0x00000101ff01ffff, 0x00000101ff010000, 0x00000101ff010001,\n    0x00000101ff010100, 0x0000010100ff0000, 0x0000010100ff01ff, 0x0000010100ff0100,\n    0x000001010000ff00, 0x0000010100000000, 0x0000010100000001, 0x00000101000001ff,\n    0x0000010100000100, 0x000001010001ff01, 0x0000010100010000, 0x00000101000101ff,\n    0x0000010100010101, 0x0000010101ffff00, 0x0000010101ff0101, 0x000001010100ff01,\n    0x0000010101000000, 0x0000010101000001, 0x00000101010001ff, 0x0000010101000101,\n    0x000001010101ff00, 0x0001ffffffff0000, 0x0001ffffff0000ff, 0x0001ffffff000001,\n    0x0001ffffff000100, 0x0001ffffff010000, 0x0001ffff00ff00ff, 0x0001ffff0000ffff,\n    0x0001ffff00000000, 0x0001ffff00000001, 0x0001ffff000001ff, 0x0001ffff00000101,\n    0x0001ffff0001ff00, 0x0001ffff000100ff, 0x0001ffff00010001, 0x0001ffff00010100,\n    0x0001ffff01ffff00, 0x0001ffff01000001, 0x0001ffff01010000, 0x0001ff00ffffff00,\n    0x0001ff00ffff00ff, 0x0001ff00ffff0001, 0x0001ff00ffff0100, 0x0001ff00ff00ff01,\n    0x0001ff00ff000000, 0x0001ff00ff01ff00, 0x0001ff00ff01ff01, 0x0001ff00ff010001,\n    0x0001ff00ff010100, 0x0001ff0000ff0000, 0x0001ff0000ff0100, 0x0001ff000000ff00,\n    0x0001ff0000000000, 0x0001ff0000000001, 0x0001ff0000000100, 0x0001ff0000010000,\n    0x0001ff0000010001, 0x0001ff0000010101, 0x0001ff0001ff00ff, 0x0001ff0001ff0101,\n    0x0001ff000100ff01, 0x0001ff0001000000, 0x0001ff000101ff00, 0x0001ff0001010001,\n    0x0001ff0001010100, 0x0001ff01ff00ff00, 0x0001ff01ff000001, 0x0001ff01ff000100,\n    0x0001ff0100ffffff, 0x0001ff0100ffff00, 0x0001ff0100ff0001, 0x0001ff0100000000,\n    0x0001ff0100000001, 0x0001ff01000001ff, 0x0001ff010001ffff, 0x0001ff0101ff0000,\n    0x0001ff010100ff00, 0x0001ff0101000001, 0x0001ff0101010000, 
0x000100ffff00ff00,\n    0x000100ffff00ff01, 0x000100ffff000000, 0x000100ffff000001, 0x000100ffff000101,\n    0x000100ffff01ff00, 0x000100ffff010001, 0x000100ffff010100, 0x000100ff00ffffff,\n    0x000100ff00ffff01, 0x000100ff00ff0000, 0x000100ff00ff01ff, 0x000100ff00ff0101,\n    0x000100ff0000ff00, 0x000100ff000000ff, 0x000100ff00000000, 0x000100ff00000001,\n    0x000100ff00000100, 0x000100ff00000101, 0x000100ff0001ffff, 0x000100ff0001ff01,\n    0x000100ff00010000, 0x000100ff01ff00ff, 0x000100ff01ff0000, 0x000100ff01ff0100,\n    0x000100ff0100ffff, 0x000100ff0100ff01, 0x000100ff010000ff, 0x000100ff01000000,\n    0x000100ff01000001, 0x000100ff010001ff, 0x000100ff01000101, 0x000100ff0101ff00,\n    0x000100ff010100ff, 0x000100ff01010100, 0x00010000ffff0000, 0x00010000ffff01ff,\n    0x00010000ffff0101, 0x00010000ff00ff00, 0x00010000ff000000, 0x00010000ff000001,\n    0x00010000ff000100, 0x0001000000ff00ff, 0x0001000000ff0000, 0x0001000000ff0001,\n    0x0001000000ff0100, 0x000100000000ffff, 0x000100000000ff00, 0x00010000000000ff,\n    0x0001000000000000, 0x0001000000000001, 0x0001000000000100, 0x000100000001ff00,\n    0x00010000000100ff, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100,\n    0x0001000001ff0001, 0x0001000001ff0100, 0x0001000001ff0101, 0x000100000100ff00,\n    0x0001000001000000, 0x0001000001000001, 0x0001000001000100, 0x0001000001000101,\n    0x000100000101ff01, 0x0001000001010000, 0x0001000001010001, 0x00010000010101ff,\n    0x00010001ffffff01, 0x00010001ffff0100, 0x00010001ff000000, 0x00010001ff01ffff,\n    0x00010001ff010001, 0x00010001ff0101ff, 0x00010001ff010100, 0x0001000100ffffff,\n    0x0001000100ff0000, 0x0001000100ff01ff, 0x0001000100ff0101, 0x000100010000ff00,\n    0x00010001000000ff, 0x0001000100000000, 0x0001000100000001, 0x00010001000001ff,\n    0x0001000100000101, 0x000100010001ffff, 0x0001000100010000, 0x00010001000101ff,\n    0x0001000101ffffff, 0x0001000101ffff01, 0x0001000101ff0000, 0x0001000101ff0101,\n    0x00010001010000ff, 
0x0001000101000001, 0x00010001010001ff, 0x0001000101000100,\n    0x000100010101ffff, 0x00010001010100ff, 0x0001000101010001, 0x0001000101010101,\n    0x000101ffff000001, 0x000101ffff000100, 0x000101ffff010000, 0x000101ff00ffff00,\n    0x000101ff0000ff01, 0x000101ff00000000, 0x000101ff00000101, 0x000101ff0001ff00,\n    0x000101ff00010100, 0x000101ff01ff0000, 0x000101ff0100ff00, 0x000101ff010001ff,\n    0x000101ff01010001, 0x00010100ffffff00, 0x00010100ffff00ff, 0x00010100ff00ffff,\n    0x00010100ff000000, 0x00010100ff01ff00, 0x00010100ff0100ff, 0x00010100ff010001,\n    0x00010100ff010100, 0x0001010000ffffff, 0x0001010000ffff00, 0x0001010000ff0000,\n    0x0001010000ff0001, 0x0001010000ff01ff, 0x000101000000ff00, 0x00010100000000ff,\n    0x0001010000000000, 0x0001010000000001, 0x0001010000000100, 0x000101000001ffff,\n    0x0001010000010000, 0x0001010000010101, 0x0001010001ffff01, 0x0001010001ff00ff,\n    0x0001010001ff0101, 0x0001010001000000, 0x000101000101ff00, 0x00010100010100ff,\n    0x0001010001010000, 0x0001010001010100, 0x00010101ff00ff00, 0x00010101ff000001,\n    0x00010101ff0001ff, 0x0001010100ffff00, 0x0001010100ff00ff, 0x0001010100ff0100,\n    0x000101010000ffff, 0x0001010100000000, 0x00010101000001ff, 0x0001010100000101,\n    0x00010101000100ff, 0x0001010100010000, 0x0001010100010100, 0x0001010101ff0001,\n    0x00010101010000ff, 0x00010101010001ff, 0x0001010101000101, 0x0001010101010001,\n    0x01ffffffffffffff, 0x01ffffffffffff01, 0x01ffffffffff01ff, 0x01ffffffffff0101,\n    0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, 0x01ffffffff010101,\n    0x01ffffff00ff0000, 0x01ffffff0000ffff, 0x01ffffff0000ff00, 0x01ffffff000000ff,\n    0x01ffffff00000001, 0x01ffffff00000100, 0x01ffffff00010000, 0x01ffffff01ffffff,\n    0x01ffffff01ffff01, 0x01ffffff01ff01ff, 0x01ffffff01ff0101, 0x01ffffff01000000,\n    0x01ffffff0101ffff, 0x01ffffff0101ff01, 0x01ffffff010101ff, 0x01ffffff01010101,\n    0x01ffff00ffff0000, 0x01ffff00ff00ff00, 0x01ffff00ff0000ff, 
0x01ffff00ff000001,\n    0x01ffff00ff000100, 0x01ffff00ff010000, 0x01ffff0000ffff00, 0x01ffff0000ff00ff,\n    0x01ffff0000ff0100, 0x01ffff000000ffff, 0x01ffff000000ff01, 0x01ffff0000000000,\n    0x01ffff0000000001, 0x01ffff00000001ff, 0x01ffff0000000100, 0x01ffff00000100ff,\n    0x01ffff0000010001, 0x01ffff0000010100, 0x01ffff0001ff0000, 0x01ffff0001ff0100,\n    0x01ffff00010000ff, 0x01ffff0001000001, 0x01ffff0001000100, 0x01ffff0001010000,\n    0x01ffff01ffffffff, 0x01ffff01ffffff01, 0x01ffff01ffff01ff, 0x01ffff01ffff0101,\n    0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff01ff01, 0x01ffff01ff0101ff,\n    0x01ffff01ff010101, 0x01ffff010000ff00, 0x01ffff01000000ff, 0x01ffff0100000100,\n    0x01ffff0100010000, 0x01ffff0101ffffff, 0x01ffff0101ffff01, 0x01ffff0101ff01ff,\n    0x01ffff0101ff0101, 0x01ffff0101000000, 0x01ffff010101ffff, 0x01ffff010101ff01,\n    0x01ffff01010101ff, 0x01ffff0101010101, 0x01ff00ffff0000ff, 0x01ff00ffff000100,\n    0x01ff00ff00ffff00, 0x01ff00ff00ff00ff, 0x01ff00ff0000ff00, 0x01ff00ff00000000,\n    0x01ff00ff00000101, 0x01ff00ff0001ff00, 0x01ff00ff000100ff, 0x01ff00ff00010100,\n    0x01ff00ff010000ff, 0x01ff00ff01000100, 0x01ff0000ffffff00, 0x01ff0000ffff0100,\n    0x01ff0000ff00ff01, 0x01ff0000ff000000, 0x01ff0000ff000101, 0x01ff0000ff010001,\n    0x01ff0000ff010100, 0x01ff000000ffffff, 0x01ff000000ffff00, 0x01ff000000ff0000,\n    0x01ff000000ff01ff, 0x01ff00000000ff00, 0x01ff0000000000ff, 0x01ff000000000000,\n    0x01ff000000000001, 0x01ff000000000100, 0x01ff000000000101, 0x01ff000000010000,\n    0x01ff000000010001, 0x01ff0000000101ff, 0x01ff000000010101, 0x01ff000001ffff00,\n    0x01ff000001ff00ff, 0x01ff000001ff0001, 0x01ff000001ff0100, 0x01ff00000100ffff,\n    0x01ff00000100ff01, 0x01ff000001000000, 0x01ff0000010001ff, 0x01ff000001010001,\n    0x01ff0001ff00ff00, 0x01ff0001ff000001, 0x01ff0001ff000100, 0x01ff0001ff010000,\n    0x01ff000100ffff00, 0x01ff000100ff00ff, 0x01ff000100ff0100, 0x01ff000100ff0101,\n    0x01ff00010000ffff, 
0x01ff000100000000, 0x01ff000100000100, 0x01ff000100000101,\n    0x01ff00010001ff00, 0x01ff000100010001, 0x01ff000100010101, 0x01ff000101ff0000,\n    0x01ff00010100ff00, 0x01ff000101000101, 0x01ff0001010100ff, 0x01ff01ffffffffff,\n    0x01ff01ffffffff01, 0x01ff01ffffff01ff, 0x01ff01ffffff0101, 0x01ff01ffff000000,\n    0x01ff01ffff01ffff, 0x01ff01ffff01ff01, 0x01ff01ffff0101ff, 0x01ff01ffff010101,\n    0x01ff01ff00ffff00, 0x01ff01ff00ff0000, 0x01ff01ff0000ff00, 0x01ff01ff000000ff,\n    0x01ff01ff00000100, 0x01ff01ff00010000, 0x01ff01ff00010100, 0x01ff01ff01ffffff,\n    0x01ff01ff01ffff01, 0x01ff01ff01ff01ff, 0x01ff01ff01ff0101, 0x01ff01ff01000000,\n    0x01ff01ff0101ffff, 0x01ff01ff0101ff01, 0x01ff01ff010101ff, 0x01ff01ff01010101,\n    0x01ff0100ffff0000, 0x01ff0100ffff0001, 0x01ff0100ff00ff00, 0x01ff0100ff0000ff,\n    0x01ff0100ff000001, 0x01ff0100ff010000, 0x01ff010000ffff00, 0x01ff010000ff00ff,\n    0x01ff010000ff0001, 0x01ff010000ff0100, 0x01ff01000000ffff, 0x01ff01000000ff01,\n    0x01ff010000000000, 0x01ff010000000101, 0x01ff01000001ff00, 0x01ff0100000100ff,\n    0x01ff010001ff0000, 0x01ff010001000001, 0x01ff010001000100, 0x01ff010001010000,\n    0x01ff0101ffffffff, 0x01ff0101ffffff01, 0x01ff0101ffff01ff, 0x01ff0101ffff0101,\n    0x01ff0101ff000000, 0x01ff0101ff01ffff, 0x01ff0101ff01ff01, 0x01ff0101ff0101ff,\n    0x01ff0101ff010101, 0x01ff010100ff0000, 0x01ff01010000ff00, 0x01ff0101000000ff,\n    0x01ff010100000001, 0x01ff010101ffffff, 0x01ff010101ffff01, 0x01ff010101ff01ff,\n    0x01ff010101ff0101, 0x01ff010101000000, 0x01ff01010101ffff, 0x01ff01010101ff01,\n    0x01ff0101010101ff, 0x01ff010101010101, 0x0100ffffffff0000, 0x0100ffffff00ff00,\n    0x0100ffffff000001, 0x0100ffffff0001ff, 0x0100ffffff000100, 0x0100ffffff010000,\n    0x0100ffff00ffff00, 0x0100ffff00ff0001, 0x0100ffff00ff0100, 0x0100ffff00000000,\n    0x0100ffff000001ff, 0x0100ffff00000101, 0x0100ffff00010100, 0x0100ffff00010101,\n    0x0100ffff01ff0000, 0x0100ffff0100ff00, 0x0100ffff010000ff, 
0x0100ffff01000001,\n    0x0100ffff01000100, 0x0100ffff01010000, 0x0100ff00ffffff00, 0x0100ff00ffff00ff,\n    0x0100ff00ffff0001, 0x0100ff00ffff0100, 0x0100ff00ff00ffff, 0x0100ff00ff000000,\n    0x0100ff00ff0001ff, 0x0100ff00ff000101, 0x0100ff00ff01ff00, 0x0100ff00ff0100ff,\n    0x0100ff00ff010001, 0x0100ff00ff010100, 0x0100ff0000ffffff, 0x0100ff0000ff0000,\n    0x0100ff000000ffff, 0x0100ff000000ff00, 0x0100ff00000000ff, 0x0100ff0000000000,\n    0x0100ff0000000001, 0x0100ff0000000100, 0x0100ff000001ff01, 0x0100ff0000010000,\n    0x0100ff0001ff00ff, 0x0100ff0001ff0001, 0x0100ff000100ff01, 0x0100ff0001000000,\n    0x0100ff00010001ff, 0x0100ff000101ff00, 0x0100ff00010100ff, 0x0100ff0001010001,\n    0x0100ff0001010100, 0x0100ff01ffff0000, 0x0100ff01ff00ff00, 0x0100ff01ff0000ff,\n    0x0100ff01ff000100, 0x0100ff01ff010000, 0x0100ff0100ff00ff, 0x0100ff0100ff0001,\n    0x0100ff0100ff0100, 0x0100ff010000ffff, 0x0100ff010000ff01, 0x0100ff0100000000,\n    0x0100ff01000001ff, 0x0100ff0100010001, 0x0100ff0100010100, 0x0100ff0101ff0000,\n    0x0100ff01010000ff, 0x0100ff0101000001, 0x0100ff0101010100, 0x010000ffffffff00,\n    0x010000ffffff00ff, 0x010000ffffff0001, 0x010000ffff00ffff, 0x010000ffff000000,\n    0x010000ffff0001ff, 0x010000ffff010001, 0x010000ff00ffffff, 0x010000ff00ff0101,\n    0x010000ff0000ff00, 0x010000ff000000ff, 0x010000ff00000000, 0x010000ff00000001,\n    0x010000ff000001ff, 0x010000ff00000100, 0x010000ff0001ffff, 0x010000ff0001ff00,\n    0x010000ff0001ff01, 0x010000ff00010000, 0x010000ff01ff00ff, 0x010000ff01ff0001,\n    0x010000ff0100ff01, 0x010000ff010000ff, 0x010000ff01000000, 0x010000ff010001ff,\n    0x010000ff0101ff00, 0x010000ff01010100, 0x01000000ffffffff, 0x01000000ffff0000,\n    0x01000000ffff01ff, 0x01000000ffff0101, 0x01000000ff00ffff, 0x01000000ff00ff00,\n    0x01000000ff0000ff, 0x01000000ff000000, 0x01000000ff000001, 0x01000000ff000100,\n    0x01000000ff01ff00, 0x01000000ff010000, 0x01000000ff010100, 0x01000000ff010101,\n    0x0100000000ffff00, 
0x0100000000ff00ff, 0x0100000000ff0000, 0x0100000000ff0001,\n    0x0100000000ff0100, 0x010000000000ffff, 0x010000000000ff00, 0x010000000000ff01,\n    0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, 0x01000000000001ff,\n    0x0100000000000100, 0x0100000000000101, 0x010000000001ff00, 0x01000000000100ff,\n    0x0100000000010000, 0x0100000000010001, 0x0100000000010100, 0x0100000001ffff00,\n    0x0100000001ff0000, 0x0100000001ff01ff, 0x010000000100ff00, 0x010000000100ff01,\n    0x01000000010000ff, 0x0100000001000000, 0x0100000001000001, 0x0100000001000100,\n    0x0100000001000101, 0x010000000101ffff, 0x010000000101ff01, 0x0100000001010000,\n    0x01000000010101ff, 0x0100000001010101, 0x01000001ffffff00, 0x01000001ffff00ff,\n    0x01000001ff00ffff, 0x01000001ff000000, 0x01000001ff000100, 0x01000001ff01ffff,\n    0x01000001ff010001, 0x01000001ff010100, 0x0100000100ff0000, 0x0100000100ff01ff,\n    0x0100000100ff0100, 0x010000010000ff00, 0x010000010000ff01, 0x0100000100000000,\n    0x0100000100000001, 0x0100000100000100, 0x0100000100010000, 0x01000001000101ff,\n    0x0100000101ffff01, 0x0100000101ff00ff, 0x0100000101ff0100, 0x0100000101ff0101,\n    0x010000010100ff01, 0x01000001010000ff, 0x0100000101000000, 0x01000001010100ff,\n    0x0100000101010001, 0x0100000101010100, 0x010001ffffff0000, 0x010001ffff000001,\n    0x010001ffff000100, 0x010001ffff010000, 0x010001ff00ffff00, 0x010001ff00ff0001,\n    0x010001ff0000ffff, 0x010001ff0000ff01, 0x010001ff00000000, 0x010001ff00000001,\n    0x010001ff00000101, 0x010001ff000100ff, 0x010001ff00010000, 0x010001ff01ff0000,\n    0x010001ff0100ff00, 0x010001ff01000001, 0x010001ff01000100, 0x010001ff01010000,\n    0x01000100ffff00ff, 0x01000100ffff0001, 0x01000100ffff0100, 0x01000100ff00ffff,\n    0x01000100ff00ff01, 0x01000100ff000000, 0x01000100ff0001ff, 0x01000100ff000101,\n    0x01000100ff01ffff, 0x01000100ff01ff00, 0x01000100ff0100ff, 0x01000100ff010001,\n    0x0100010000ffffff, 0x0100010000ffff01, 0x0100010000ff0000, 
0x0100010000ff01ff,\n    0x0100010000ff0101, 0x010001000000ff00, 0x01000100000000ff, 0x0100010000000000,\n    0x0100010000000001, 0x0100010000000100, 0x010001000001ff01, 0x0100010000010000,\n    0x0100010000010001, 0x0100010000010101, 0x0100010001ffff00, 0x0100010001ff00ff,\n    0x010001000100ffff, 0x010001000100ff01, 0x0100010001000000, 0x0100010001000101,\n    0x010001000101ff00, 0x0100010001010001, 0x01000101ffff0000, 0x01000101ff000000,\n    0x01000101ff010000, 0x0100010100ff00ff, 0x0100010100ff0001, 0x0100010100ff0100,\n    0x010001010000ffff, 0x0100010100000000, 0x01000101000001ff, 0x010001010001ff00,\n    0x0100010101ff0000, 0x010001010100ff00, 0x01000101010000ff, 0x0100010101000000,\n    0x0100010101000001, 0x0101ffffffffffff, 0x0101ffffffffff01, 0x0101ffffffff01ff,\n    0x0101ffffffff0101, 0x0101ffffff000000, 0x0101ffffff01ffff, 0x0101ffffff01ff01,\n    0x0101ffffff0101ff, 0x0101ffffff010101, 0x0101ffff00ff0000, 0x0101ffff0000ff00,\n    0x0101ffff000000ff, 0x0101ffff00000001, 0x0101ffff00000100, 0x0101ffff01ffffff,\n    0x0101ffff01ffff01, 0x0101ffff01ff01ff, 0x0101ffff01ff0101, 0x0101ffff01000000,\n    0x0101ffff0101ffff, 0x0101ffff0101ff01, 0x0101ffff010101ff, 0x0101ffff01010101,\n    0x0101ff00ffff0000, 0x0101ff00ffff0100, 0x0101ff00ff00ff00, 0x0101ff00ff0000ff,\n    0x0101ff00ff000001, 0x0101ff00ff000100, 0x0101ff00ff000101, 0x0101ff0000ff0001,\n    0x0101ff0000ff0100, 0x0101ff000000ff00, 0x0101ff0000000000, 0x0101ff00000001ff,\n    0x0101ff0000000101, 0x0101ff000001ff00, 0x0101ff00000100ff, 0x0101ff0001ff0000,\n    0x0101ff000100ffff, 0x0101ff000100ff01, 0x0101ff0001000001, 0x0101ff0001000100,\n    0x0101ff01ffffff01, 0x0101ff01ffff01ff, 0x0101ff01ffff0101, 0x0101ff01ff00ffff,\n    0x0101ff01ff000100, 0x0101ff01ff01ff01, 0x0101ff01ff0101ff, 0x0101ff01ff010101,\n    0x0101ff0100ff0000, 0x0101ff010000ff00, 0x0101ff0100000001, 0x0101ff0100000100,\n    0x0101ff0100010000, 0x0101ff0101ffffff, 0x0101ff0101ffff01, 0x0101ff0101ff01ff,\n    0x0101ff0101ff0101, 
0x0101ff0101000000, 0x0101ff010101ffff, 0x0101ff010101ff01,\n    0x0101ff01010101ff, 0x0101ff0101010101, 0x010100ffff000100, 0x010100ffff010000,\n    0x010100ff00ffff00, 0x010100ff00ff00ff, 0x010100ff0000ffff, 0x010100ff000000ff,\n    0x010100ff00000000, 0x010100ff000001ff, 0x010100ff00000101, 0x010100ff0001ff00,\n    0x010100ff00010000, 0x010100ff00010001, 0x010100ff000101ff, 0x010100ff00010100,\n    0x010100ff01ff0000, 0x01010000ffff0001, 0x01010000ffff0100, 0x01010000ff00ffff,\n    0x01010000ff00ff01, 0x01010000ff000000, 0x01010000ff0001ff, 0x01010000ff010001,\n    0x01010000ff010100, 0x0101000000ffff01, 0x0101000000ff0000, 0x010100000000ff00,\n    0x01010000000000ff, 0x0101000000000000, 0x0101000000000001, 0x0101000000000100,\n    0x0101000000010000, 0x0101000000010101, 0x0101000001ffff00, 0x0101000001ff00ff,\n    0x0101000001ff0000, 0x0101000001ff0001, 0x0101000001ff0100, 0x010100000100ff01,\n    0x0101000001000000, 0x01010000010001ff, 0x01010001ffff0000, 0x01010001ff00ff00,\n    0x01010001ff000001, 0x01010001ff000101, 0x01010001ff01ff00, 0x01010001ff010000,\n    0x0101000100ff00ff, 0x0101000100ff0001, 0x0101000100ff0101, 0x010100010000ff01,\n    0x0101000100000000, 0x0101000100000001, 0x01010001000001ff, 0x010100010001ffff,\n    0x010100010001ff01, 0x0101000101ff0001, 0x010100010100ffff, 0x0101000101000000,\n    0x0101000101000001, 0x0101000101000100, 0x010100010101ff00, 0x01010001010100ff,\n    0x0101000101010001, 0x010101ffffffffff, 0x010101ffffffff01, 0x010101ffffff01ff,\n    0x010101ffffff0101, 0x010101ffff01ffff, 0x010101ffff01ff01, 0x010101ffff0101ff,\n    0x010101ffff010101, 0x010101ff0000ff00, 0x010101ff000000ff, 0x010101ff00000001,\n    0x010101ff00000100, 0x010101ff01ffffff, 0x010101ff01ffff01, 0x010101ff01ff01ff,\n    0x010101ff01ff0101, 0x010101ff01000000, 0x010101ff0101ffff, 0x010101ff0101ff01,\n    0x010101ff010101ff, 0x010101ff01010101, 0x01010100ffff0000, 0x01010100ff0000ff,\n    0x01010100ff000100, 0x01010100ff01ff00, 0x01010100ff010000, 
0x0101010000ffff00,\n    0x010101000000ffff, 0x0101010000000000, 0x0101010000000101, 0x010101000001ff00,\n    0x0101010000010001, 0x0101010000010100, 0x010101000100ffff, 0x0101010001000001,\n    0x01010101ffffffff, 0x01010101ffffff01, 0x01010101ffff01ff, 0x01010101ffff0101,\n    0x01010101ff01ffff, 0x01010101ff01ff01, 0x01010101ff0101ff, 0x01010101ff010101,\n    0x010101010000ff00, 0x01010101000000ff, 0x0101010100000001, 0x0101010101ffffff,\n    0x0101010101ffff01, 0x0101010101ff01ff, 0x0101010101ff0101, 0x0101010101000000,\n    0x010101010101ffff, 0x010101010101ff01, 0x01010101010101ff, 0x0101010101010101,\nGGML_TABLE_END()\n#else\nGGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S)\n    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,\n    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,\n    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,\n    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,\n    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,\n    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,\n    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,\n    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,\n    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,\n    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,\n    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,\n    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,\n    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 
0x01021010,\n    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,\n    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,\n    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,\n    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,\n    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,\n    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,\n    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,\n    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,\n    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,\n    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,\n    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,\n    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,\n    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,\n    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,\n    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,\n    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,\n    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,\n    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,\n    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,\n    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 
0x02110122, 0x02120121, 0x00101001,\n    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,\n    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,\n    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,\n    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,\n    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,\n    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,\n    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,\n    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,\n    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,\n    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,\n    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,\n    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,\n    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,\n    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,\n    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,\n    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,\n    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,\n    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,\n    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,\n    0x02112012, 0x02112111, 0x02112210, 
0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,\n    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,\n    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,\n    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,\n    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,\n    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,\n    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,\n    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,\n    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,\n    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,\n    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,\n    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,\n    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,\n    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,\n    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,\n    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,\n    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,\n    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,\n    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,\n    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,\n    0x01202101, 
0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,\n    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,\n    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,\n    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,\n    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,\n    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,\n    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,\n    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,\n    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,\n    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,\n    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,\n    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,\n    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,\n    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,\n    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,\n    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,\n    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,\n    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,\n    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,\n    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 
0x11011010,\n    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,\n    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,\n    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,\n    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,\n    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,\n    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,\n    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,\n    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,\n    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,\n    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,\n    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,\n    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,\n    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,\n    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,\n    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,\n    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,\n    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,\n    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,\n    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,\n    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 
0x10110112, 0x10110210, 0x10110211,\n    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,\n    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,\n    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,\n    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,\n    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,\n    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,\n    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,\n    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,\n    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,\n    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,\n    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,\n    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,\n    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,\n    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,\n    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,\n    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,\n    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,\n    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,\n    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,\n    0x12111011, 0x12111110, 0x12111111, 
0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,\n    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,\n    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,\n    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,\n    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,\n    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,\n    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,\n    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,\n    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,\n    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,\n    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,\n    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,\n    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,\n    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,\n    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,\n    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,\n    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,\n    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,\n    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,\n    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,\n    0x12122021, 
0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,\n    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,\n    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,\n    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,\n    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,\n    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,\n    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,\n    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,\n    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,\n    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,\n    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,\n    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,\n    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,\n    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,\n    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,\n    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,\n    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,\n    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,\n    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,\n    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 
0x12221222,\n    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,\n    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,\n    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,\n    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,\n    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,\n    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,\n    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,\n    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,\n    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,\n    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,\n    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,\n    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,\n    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,\n    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,\n    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,\n    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,\n    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,\n    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,\n    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,\n    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 
0x21011211, 0x21011212, 0x21021111,\n    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,\n    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,\n    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,\n    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,\n    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,\n    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,\n    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,\n    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,\n    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,\n    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,\n    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,\n    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,\n    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,\n    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,\n    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,\n    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,\n    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,\n    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,\n    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,\n    0x20110221, 0x20120121, 0x21100120, 
0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,\n    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,\n    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,\n    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,\n    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,\n    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,\n    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,\n    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,\n    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,\n    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,\n    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,\n    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,\n    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,\n    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,\n    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,\n    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,\n    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,\n    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,\n    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,\n    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,\n    0x21122112, 
0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,\n    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,\n    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,\n    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,\n    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,\n    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,\n    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,\n    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,\n    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,\n    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,\n    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,\n    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,\n    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,\n    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,\n    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,\n    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,\n    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,\n    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,\n    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,\n    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 
0x20222200,\n    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,\n    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,\n    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,\n    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,\n    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,\n    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,\nGGML_TABLE_END()\n#endif\n\n#endif // GGML_COMMON_IMPL\n#endif // GGML_COMMON_IMPL\n"
  },
  {
    "path": "src/ggml-cpu/CMakeLists.txt",
    "content": "function(ggml_add_cpu_backend_features cpu_name arch)\n    # The feature detection code is compiled as a separate target so that\n    # it can be built without the architecture flags\n    # Since multiple variants of the CPU backend may be included in the same\n    # build, using set_source_files_properties() to set the arch flags is not possible\n    set(GGML_CPU_FEATS_NAME ${cpu_name}-feats)\n    add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/arch/${arch}/cpu-feats.cpp)\n    target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . ../include)\n    target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})\n    target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)\n    set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)\n    # Disable LTO for the feature detection code to prevent cross-module optimization\n    # from inlining architecture-specific instructions into the score function.\n    # Without this, LTO can cause SIGILL when loading backends on older CPUs\n    # (e.g., loading power10 backend on power9 crashes before feature check runs).\n    target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)\n    target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})\nendfunction()\n\nfunction(ggml_add_cpu_backend_variant_impl tag_name)\n    if (tag_name)\n        set(GGML_CPU_NAME ggml-cpu-${tag_name})\n    else()\n        set(GGML_CPU_NAME ggml-cpu)\n    endif()\n\n    ggml_add_backend_library(${GGML_CPU_NAME})\n\n    list (APPEND GGML_CPU_SOURCES\n        ggml-cpu/ggml-cpu.c\n        ggml-cpu/ggml-cpu.cpp\n        ggml-cpu/repack.cpp\n        ggml-cpu/repack.h\n        ggml-cpu/hbm.cpp\n        ggml-cpu/hbm.h\n        ggml-cpu/quants.c\n        ggml-cpu/quants.h\n        ggml-cpu/traits.cpp\n        ggml-cpu/traits.h\n        ggml-cpu/amx/amx.cpp\n        ggml-cpu/amx/amx.h\n        ggml-cpu/amx/mmq.cpp\n        
ggml-cpu/amx/mmq.h\n        ggml-cpu/ggml-cpu-impl.h\n        ggml-cpu/common.h\n        ggml-cpu/binary-ops.h\n        ggml-cpu/binary-ops.cpp\n        ggml-cpu/unary-ops.h\n        ggml-cpu/unary-ops.cpp\n        ggml-cpu/simd-mappings.h\n        ggml-cpu/vec.h\n        ggml-cpu/vec.cpp\n        ggml-cpu/ops.h\n        ggml-cpu/ops.cpp\n        )\n\n    target_compile_features(${GGML_CPU_NAME} PRIVATE c_std_11 cxx_std_17)\n    target_include_directories(${GGML_CPU_NAME} PRIVATE . ggml-cpu)\n\n    if (APPLE AND GGML_ACCELERATE)\n        find_library(ACCELERATE_FRAMEWORK Accelerate)\n        if (ACCELERATE_FRAMEWORK)\n            message(STATUS \"Accelerate framework found\")\n\n            target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_ACCELERATE)\n            target_compile_definitions(${GGML_CPU_NAME} PRIVATE ACCELERATE_NEW_LAPACK)\n            target_compile_definitions(${GGML_CPU_NAME} PRIVATE ACCELERATE_LAPACK_ILP64)\n\n            target_link_libraries(${GGML_CPU_NAME} PRIVATE ${ACCELERATE_FRAMEWORK})\n        else()\n            message(WARNING \"Accelerate framework not found\")\n        endif()\n    endif()\n\n    if (GGML_OPENMP)\n        find_package(OpenMP)\n        if (OpenMP_FOUND)\n            set(GGML_OPENMP_ENABLED \"ON\" CACHE INTERNAL \"\")\n            target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP)\n\n            target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)\n        else()\n            set(GGML_OPENMP_ENABLED \"OFF\" CACHE INTERNAL \"\")\n            message(WARNING \"OpenMP not found\")\n        endif()\n    endif()\n\n    if (GGML_LLAMAFILE)\n        target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_LLAMAFILE)\n\n        list(APPEND GGML_CPU_SOURCES\n                    ggml-cpu/llamafile/sgemm.cpp\n                    ggml-cpu/llamafile/sgemm.h)\n    endif()\n\n    if (GGML_CPU_HBM)\n        find_library(memkind memkind REQUIRED)\n\n        message(STATUS 
\"Using memkind for CPU HBM\")\n\n        target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_HBM)\n\n        target_link_libraries(${GGML_CPU_NAME} PUBLIC memkind)\n    endif()\n\n    if (GGML_SYSTEM_ARCH STREQUAL \"ARM\")\n        message(STATUS \"ARM detected\")\n        list(APPEND GGML_CPU_SOURCES\n            ggml-cpu/arch/arm/quants.c\n            ggml-cpu/arch/arm/repack.cpp\n            )\n\n        if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL \"Clang\")\n            message(FATAL_ERROR \"MSVC is not supported for ARM, use clang\")\n        else()\n            check_cxx_compiler_flag(-mfp16-format=ieee GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E)\n            if (NOT \"${GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E}\" STREQUAL \"\")\n                list(APPEND ARCH_FLAGS -mfp16-format=ieee)\n            endif()\n\n            if (GGML_NATIVE)\n                # -mcpu=native does not always enable all the features in some compilers,\n                # so we check for them manually and enable them if available\n\n                execute_process(\n                    COMMAND ${CMAKE_C_COMPILER} -mcpu=native -E -v -\n                    INPUT_FILE \"/dev/null\"\n                    OUTPUT_QUIET\n                    ERROR_VARIABLE ARM_MCPU\n                    RESULT_VARIABLE ARM_MCPU_RESULT\n                )\n                if (NOT ARM_MCPU_RESULT)\n                    string(REGEX MATCH \"-mcpu=[^ ']+\" ARM_MCPU_FLAG \"${ARM_MCPU}\")\n                    string(REGEX MATCH \"-march=[^ ']+\" ARM_MARCH_FLAG \"${ARM_MCPU}\")\n\n                    # on some old GCC we need to read -march=\n                    if (ARM_MARCH_FLAG AND NOT \"${ARM_MARCH_FLAG}\" STREQUAL \"-march=native\")\n                        set(ARM_NATIVE_FLAG \"${ARM_MARCH_FLAG}\")\n                    elseif(ARM_MCPU_FLAG AND NOT \"${ARM_MCPU_FLAG}\" STREQUAL \"-mcpu=native\")\n                        set(ARM_NATIVE_FLAG \"${ARM_MCPU_FLAG}\")\n                    endif()\n              
  endif()\n\n                if (\"${ARM_NATIVE_FLAG}\" STREQUAL \"\")\n                    set(ARM_NATIVE_FLAG -mcpu=native)\n                    message(WARNING \"ARM -march/-mcpu not found, -mcpu=native will be used\")\n                else()\n                    message(STATUS \"ARM detected flags: ${ARM_NATIVE_FLAG}\")\n                endif()\n\n                include(CheckCXXSourceRuns)\n\n                macro(check_arm_feature tag feature code)\n                    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})\n                    set(CMAKE_REQUIRED_FLAGS \"${ARM_NATIVE_FLAG}+${tag}\")\n                    check_cxx_source_runs(\"${code}\" GGML_MACHINE_SUPPORTS_${tag})\n                    if (GGML_MACHINE_SUPPORTS_${tag})\n                        set(ARM_NATIVE_FLAG_FIX \"${ARM_NATIVE_FLAG_FIX}+${tag}\")\n                    else()\n                        set(CMAKE_REQUIRED_FLAGS \"${ARM_NATIVE_FLAG}+no${tag}\")\n                        check_cxx_source_compiles(\"int main() { return 0; }\" GGML_MACHINE_SUPPORTS_no${tag})\n                        if (GGML_MACHINE_SUPPORTS_no${tag})\n                            set(ARM_NATIVE_FLAG_FIX \"${ARM_NATIVE_FLAG_FIX}+no${tag}\")\n                            list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature})\n                        endif()\n                    endif()\n                    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})\n                endmacro()\n\n                check_arm_feature(dotprod DOTPROD     \"#include <arm_neon.h>\\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }\")\n                check_arm_feature(i8mm    MATMUL_INT8 \"#include <arm_neon.h>\\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }\")\n                check_arm_feature(sve     SVE         \"#include <arm_sve.h>\\nint main()  { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }\")\n  
              check_arm_feature(sme     SME         \"#include <arm_sme.h>\\n__arm_locally_streaming int main() { __asm__ volatile(\\\"smstart; smstop;\\\"); return 0; }\")\n\n                list(APPEND ARCH_FLAGS \"${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}\")\n            else()\n                if (GGML_CPU_ARM_ARCH)\n                    list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})\n                elseif(GGML_CPU_ALL_VARIANTS)\n                    # Begin with the lowest baseline\n                    set(ARM_MCPU \"armv8-a\")\n                    set(ARCH_TAGS \"\")\n                    set(ARCH_DEFINITIONS \"\")\n\n                    # When a feature is selected, bump the MCPU to the first\n                    # version that supported it\n                    if (GGML_INTERNAL_DOTPROD)\n                        set(ARM_MCPU \"armv8.2-a\")\n                        set(ARCH_TAGS \"${ARCH_TAGS}+dotprod\")\n                        list(APPEND ARCH_DEFINITIONS GGML_USE_DOTPROD)\n                    endif()\n                    if (GGML_INTERNAL_FP16_VECTOR_ARITHMETIC)\n                        set(ARM_MCPU \"armv8.2-a\")\n                        set(ARCH_TAGS \"${ARCH_TAGS}+fp16\")\n                        list(APPEND ARCH_DEFINITIONS GGML_USE_FP16_VECTOR_ARITHMETIC)\n                    endif()\n                    if (GGML_INTERNAL_SVE)\n                        set(ARM_MCPU \"armv8.2-a\")\n                        set(ARCH_TAGS \"${ARCH_TAGS}+sve\")\n                        list(APPEND ARCH_DEFINITIONS GGML_USE_SVE)\n                    endif()\n                    if (GGML_INTERNAL_MATMUL_INT8)\n                        set(ARM_MCPU \"armv8.6-a\")\n                        set(ARCH_TAGS \"${ARCH_TAGS}+i8mm\")\n                        list(APPEND ARCH_DEFINITIONS GGML_USE_MATMUL_INT8)\n                    endif()\n                    if (GGML_INTERNAL_SVE2)\n                        set(ARM_MCPU \"armv8.6-a\")\n                        set(ARCH_TAGS 
\"${ARCH_TAGS}+sve2\")\n                        list(APPEND ARCH_DEFINITIONS GGML_USE_SVE2)\n                    endif()\n                    if (GGML_INTERNAL_NOSVE)\n                        set(ARCH_TAGS \"${ARCH_TAGS}+nosve\")\n                    endif()\n                    if (GGML_INTERNAL_SME)\n                        set(ARM_MCPU \"armv9.2-a\")\n                        set(ARCH_TAGS \"${ARCH_TAGS}+sme\")\n                        list(APPEND ARCH_DEFINITIONS GGML_USE_SME)\n                    endif()\n                    list(APPEND ARCH_FLAGS \"-march=${ARM_MCPU}${ARCH_TAGS}\")\n                    ggml_add_cpu_backend_features(${GGML_CPU_NAME} arm ${ARCH_DEFINITIONS})\n                endif()\n            endif()\n\n            message(STATUS \"Checking for ARM features using flags:\")\n            foreach(flag IN LISTS ARCH_FLAGS)\n                message(STATUS \"  ${flag}\")\n            endforeach()\n\n            include(CheckCXXSourceCompiles)\n            set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})\n            string(REPLACE \";\" \" \" ARCH_FLAGS_STR \"${ARCH_FLAGS}\")\n            set(CMAKE_REQUIRED_FLAGS \"${ARCH_FLAGS_STR}\")\n            foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)\n                set(ARM_FEATURE \"HAVE_${feature}\")\n                check_cxx_source_compiles(\n                    \"\n                    #if !defined(__ARM_FEATURE_${feature})\n                    #  error \\\"Feature ${feature} is not defined\\\"\n                    #endif\n                    int main() { return 0; }\n                    \"\n                    ${ARM_FEATURE}\n                )\n            endforeach()\n            set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})\n        endif()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"x86\")\n        message(STATUS \"x86 detected\")\n        list(APPEND GGML_CPU_SOURCES\n            ggml-cpu/arch/x86/quants.c\n            ggml-cpu/arch/x86/repack.cpp\n         
   )\n\n        if (MSVC)\n            # instruction set detection for MSVC only\n            if (GGML_NATIVE)\n                include(ggml-cpu/cmake/FindSIMD.cmake)\n            endif ()\n            if (GGML_AVX512)\n                list(APPEND ARCH_FLAGS /arch:AVX512)\n                # /arch:AVX512 includes: __AVX512F__, __AVX512CD__, __AVX512BW__, __AVX512DQ__, and __AVX512VL__\n                # MSVC has no compile-time flags enabling specific\n                # AVX512 extensions, neither it defines the\n                # macros corresponding to the extensions.\n                # Do it manually.\n                list(APPEND ARCH_DEFINITIONS GGML_AVX512)\n                if (GGML_AVX512_VBMI)\n                    list(APPEND ARCH_DEFINITIONS __AVX512VBMI__)\n                    if (CMAKE_C_COMPILER_ID STREQUAL \"Clang\")\n                        list(APPEND ARCH_FLAGS -mavx512vbmi)\n                    endif()\n                endif()\n                if (GGML_AVX512_VNNI)\n                    list(APPEND ARCH_DEFINITIONS __AVX512VNNI__ GGML_AVX512_VNNI)\n                    if (CMAKE_C_COMPILER_ID STREQUAL \"Clang\")\n                        list(APPEND ARCH_FLAGS -mavx512vnni)\n                    endif()\n                endif()\n                if (GGML_AVX512_BF16)\n                    list(APPEND ARCH_DEFINITIONS __AVX512BF16__ GGML_AVX512_BF16)\n                    if (CMAKE_C_COMPILER_ID STREQUAL \"Clang\")\n                        list(APPEND ARCH_FLAGS -mavx512bf16)\n                    endif()\n                endif()\n                if (GGML_AMX_TILE)\n                    list(APPEND ARCH_DEFINITIONS __AMX_TILE__ GGML_AMX_TILE)\n                endif()\n                if (GGML_AMX_INT8)\n                    list(APPEND ARCH_DEFINITIONS __AMX_INT8__ GGML_AMX_INT8)\n                endif()\n                if (GGML_AMX_BF16)\n                    list(APPEND ARCH_DEFINITIONS __AMX_BF16__ GGML_AMX_BF16)\n                endif()\n            elseif 
(GGML_AVX2)\n                list(APPEND ARCH_FLAGS /arch:AVX2)\n                list(APPEND ARCH_DEFINITIONS GGML_AVX2 GGML_FMA GGML_F16C)\n            elseif (GGML_AVX)\n                list(APPEND ARCH_FLAGS /arch:AVX)\n                list(APPEND ARCH_DEFINITIONS GGML_AVX)\n            elseif (GGML_SSE42)\n                list(APPEND ARCH_FLAGS /arch:SSE4.2)\n                list(APPEND ARCH_DEFINITIONS GGML_SSE42)\n            endif()\n            if (GGML_AVX_VNNI)\n                list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)\n            endif()\n            if (GGML_BMI2)\n                # MSVC does not define macro __BMI2__\n                list(APPEND ARCH_DEFINITIONS __BMI2__ GGML_BMI2)\n            endif()\n        else ()\n            if (GGML_NATIVE)\n                list(APPEND ARCH_FLAGS -march=native)\n            else ()\n                if (GGML_SSE42)\n                    list(APPEND ARCH_FLAGS -msse4.2)\n                    list(APPEND ARCH_DEFINITIONS GGML_SSE42)\n                endif()\n                if (GGML_F16C)\n                    list(APPEND ARCH_FLAGS -mf16c)\n                    list(APPEND ARCH_DEFINITIONS GGML_F16C)\n                endif()\n                if (GGML_FMA)\n                    list(APPEND ARCH_FLAGS -mfma)\n                    list(APPEND ARCH_DEFINITIONS GGML_FMA)\n                endif()\n                if (GGML_BMI2)\n                    list(APPEND ARCH_FLAGS -mbmi2)\n                    list(APPEND ARCH_DEFINITIONS GGML_BMI2)\n                endif()\n                if (GGML_AVX)\n                    list(APPEND ARCH_FLAGS -mavx)\n                    list(APPEND ARCH_DEFINITIONS GGML_AVX)\n                endif()\n                if (GGML_AVX2)\n                    list(APPEND ARCH_FLAGS -mavx2)\n                    list(APPEND ARCH_DEFINITIONS GGML_AVX2)\n                endif()\n                if (GGML_AVX_VNNI)\n                    list(APPEND ARCH_FLAGS -mavxvnni)\n                    
list(APPEND ARCH_DEFINITIONS GGML_AVX_VNNI)\n                endif()\n                if (GGML_AVX512)\n                    list(APPEND ARCH_FLAGS -mavx512f)\n                    list(APPEND ARCH_FLAGS -mavx512cd)\n                    list(APPEND ARCH_FLAGS -mavx512vl)\n                    list(APPEND ARCH_FLAGS -mavx512dq)\n                    list(APPEND ARCH_FLAGS -mavx512bw)\n                    list(APPEND ARCH_DEFINITIONS GGML_AVX512)\n                endif()\n                if (GGML_AVX512_VBMI)\n                    list(APPEND ARCH_FLAGS -mavx512vbmi)\n                    list(APPEND ARCH_DEFINITIONS GGML_AVX512_VBMI)\n                endif()\n                if (GGML_AVX512_VNNI)\n                    list(APPEND ARCH_FLAGS -mavx512vnni)\n                    list(APPEND ARCH_DEFINITIONS GGML_AVX512_VNNI)\n                endif()\n                if (GGML_AVX512_BF16)\n                    list(APPEND ARCH_FLAGS -mavx512bf16)\n                    list(APPEND ARCH_DEFINITIONS GGML_AVX512_BF16)\n                endif()\n                if (GGML_AMX_TILE)\n                    list(APPEND ARCH_FLAGS -mamx-tile)\n                    list(APPEND ARCH_DEFINITIONS GGML_AMX_TILE)\n                endif()\n                if (GGML_AMX_INT8)\n                    list(APPEND ARCH_FLAGS -mamx-int8)\n                    list(APPEND ARCH_DEFINITIONS GGML_AMX_INT8)\n                endif()\n                if (GGML_AMX_BF16)\n                    list(APPEND ARCH_FLAGS -mamx-bf16)\n                    list(APPEND ARCH_DEFINITIONS GGML_AMX_BF16)\n                endif()\n            endif()\n        endif()\n\n        if (GGML_BACKEND_DL)\n            if (GGML_NATIVE)\n                # the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE\n                message(FATAL_ERROR \"GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS\")\n            endif()\n            ggml_add_cpu_backend_features(${GGML_CPU_NAME} 
x86 ${ARCH_DEFINITIONS})\n        endif()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"PowerPC\")\n        message(STATUS \"PowerPC detected\")\n        list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/powerpc/quants.c)\n        if (GGML_NATIVE)\n            if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64\")\n                file(READ \"/proc/cpuinfo\" POWER10_M)\n            elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"powerpc\")\n                execute_process(COMMAND bash -c \"prtconf |grep 'Implementation' | head -n 1\" OUTPUT_VARIABLE POWER10_M)\n            endif()\n\n            string(TOUPPER \"${POWER10_M}\" POWER10_M_UPPER)\n            string(REGEX MATCHALL \"POWER *([0-9]+)\" MATCHED_STRING \"${POWER10_M_UPPER}\")\n            string(REGEX REPLACE \"POWER *([0-9]+)\" \"\\\\1\" EXTRACTED_NUMBER \"${MATCHED_STRING}\")\n\n            if (EXTRACTED_NUMBER GREATER_EQUAL 10)\n                list(APPEND ARCH_FLAGS -mcpu=power10)\n            elseif (EXTRACTED_NUMBER EQUAL 9)\n                list(APPEND ARCH_FLAGS -mcpu=power9)\n            elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64le\")\n                list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)\n            else()\n                list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)\n            endif()\n        elseif(GGML_CPU_ALL_VARIANTS)\n            # Begin with the lowest baseline\n            set(ARCH_DEFINITIONS \"\")\n\n            # When a feature is selected, bump the MCPU to the first\n            # version that supported it\n            foreach(PVER RANGE 7 11)\n                if(DEFINED GGML_INTERNAL_POWER${PVER})\n                    set(POWERPC_MCPU \"power${PVER}\")\n                    list(APPEND ARCH_DEFINITIONS GGML_USE_POWER${PVER})\n                endif()\n            endforeach()\n            if (GGML_INTERNAL_VSX)\n                list(APPEND ARCH_DEFINITIONS GGML_USE_VSX)\n                list(APPEND ARCH_FLAGS -mvsx)\n            endif()\n\n            if 
(DEFINED POWERPC_MCPU)\n                list(APPEND ARCH_FLAGS -mcpu=${POWERPC_MCPU})\n            endif()\n            ggml_add_cpu_backend_features(${GGML_CPU_NAME} powerpc ${ARCH_DEFINITIONS})\n        else()\n            if (GGML_CPU_POWERPC_CPUTYPE)\n                list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})\n            endif()\n        endif()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"loongarch64\")\n        message(STATUS \"loongarch64 detected\")\n        list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/loongarch/quants.c)\n\n        list(APPEND ARCH_FLAGS -march=loongarch64)\n        if (GGML_LASX)\n            list(APPEND ARCH_FLAGS -mlasx)\n        endif()\n        if (GGML_LSX)\n            list(APPEND ARCH_FLAGS -mlsx)\n        endif()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"riscv64\")\n        message(STATUS \"riscv64 detected\")\n        list(APPEND GGML_CPU_SOURCES\n            ggml-cpu/arch/riscv/quants.c\n            ggml-cpu/arch/riscv/repack.cpp\n            )\n        if (GGML_CPU_RISCV64_SPACEMIT)\n            target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC})\n            list(APPEND GGML_CPU_SOURCES\n                ggml-cpu/spacemit/ime.cpp\n                ggml-cpu/spacemit/ime.h\n                ggml-cpu/spacemit/ime1_kernels.cpp\n                ggml-cpu/spacemit/ime_kernels.h\n            )\n        endif()\n        if(NOT GGML_CPU_ALL_VARIANTS)\n            set(MARCH_STR \"rv64gc\")\n            if (GGML_RV_ZFH)\n                string(APPEND MARCH_STR \"_zfh\")\n            endif()\n\n            if (GGML_XTHEADVECTOR)\n                string(APPEND MARCH_STR \"_xtheadvector\")\n            elseif (GGML_RVV)\n                string(APPEND MARCH_STR \"_v\")\n                if (GGML_RV_ZVFH)\n                    string(APPEND MARCH_STR \"_zvfh\")\n                endif()\n                if (GGML_RV_ZVFBFWMA)\n                    string(APPEND MARCH_STR 
\"_zvfbfwma\")\n                endif()\n            endif()\n            if (GGML_RV_ZICBOP)\n                string(APPEND MARCH_STR \"_zicbop\")\n            endif()\n            if (GGML_RV_ZIHINTPAUSE)\n                string(APPEND MARCH_STR \"_zihintpause\")\n            endif()\n            list(APPEND ARCH_FLAGS \"-march=${MARCH_STR}\" -mabi=lp64d)\n        else()\n            # Begin with the lowest baseline\n            set(ARCH_DEFINITIONS \"\")\n\n            if (GGML_INTERNAL_RVV)\n                message(STATUS \"RVV enabled\")\n                list(APPEND ARCH_DEFINITIONS GGML_USE_RVV)\n                list(APPEND ARCH_FLAGS -march=rv64gc_v -mabi=lp64d)\n            endif()\n\n            ggml_add_cpu_backend_features(${GGML_CPU_NAME} riscv ${ARCH_DEFINITIONS})\n        endif()\n    elseif (GGML_SYSTEM_ARCH STREQUAL \"s390x\")\n        message(STATUS \"s390x detected\")\n        list(APPEND GGML_CPU_SOURCES\n            ggml-cpu/arch/s390/quants.c)\n\n        # for native compilation\n        if (GGML_NATIVE)\n            # check machine level to determine target\n            file(READ \"/proc/cpuinfo\" CPUINFO_CONTENTS)\n            string(REGEX REPLACE \"machine[ \\t\\r\\n]*=[ \\t\\r\\n]*([0-9]+)\" \"\\\\1\" S390X_M ${CPUINFO_CONTENTS})\n\n            # TODO: Separation to determine activation of VX/VXE/VXE2\n            if (${S390X_M} MATCHES \"8561|8562\")\n                message(STATUS \"z15 target\")\n                list(APPEND ARCH_FLAGS -march=z15)\n            elseif (${S390X_M} MATCHES \"3931\")\n                message(STATUS \"z16 target\")\n                list(APPEND ARCH_FLAGS -march=z16)\n            elseif (${S390X_M} MATCHES \"9175|9176\")\n                # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.\n                #       binutils must also be updated to the latest for the -march=z17 flag to work. 
Otherwise, use -march=arch15.\n                message(STATUS \"z17 target\")\n                list(APPEND ARCH_FLAGS -march=arch15)\n            else()\n                message(STATUS \"Unknown target\")\n                message(WARNING \"Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.\")\n                list(APPEND ARCH_FLAGS -march=native -mtune=native)\n            endif()\n        # for cross-compilation\n        elseif(GGML_CPU_ALL_VARIANTS)\n            # range through IBM z15 to z17\n            # NOTE: update when a new hardware level is released\n            foreach (ZHW RANGE 15 17)\n                if(DEFINED GGML_INTERNAL_Z${ZHW})\n                    message(STATUS \"z${ZHW} cross-compile target\")\n                    list(APPEND ARCH_FLAGS -march=z${ZHW})\n                endif()\n            endforeach()\n        endif()\n\n        if (GGML_VXE OR GGML_INTERNAL_VXE2)\n            message(STATUS \"VXE2 enabled\")\n            list(APPEND ARCH_FLAGS -mvx -mzvector)\n            list(APPEND ARCH_DEFINITIONS GGML_USE_VXE2)\n        endif()\n\n        if (GGML_INTERNAL_NNPA)\n            message(STATUS \"NNPA enabled\")\n            list(APPEND ARCH_DEFINITIONS GGML_USE_NNPA)\n        endif()\n\n        ggml_add_cpu_backend_features(${GGML_CPU_NAME} s390 ${ARCH_DEFINITIONS})\n    elseif (CMAKE_SYSTEM_PROCESSOR MATCHES \"wasm\")\n        message(STATUS \"Wasm detected\")\n        list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)\n    else()\n        message(WARNING \"Unknown CPU architecture. 
Falling back to generic implementations.\")\n        list(APPEND ARCH_FLAGS -DGGML_CPU_GENERIC)\n    endif()\n\n    if (GGML_CPU_REPACK)\n        target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_REPACK)\n    endif()\n\n    if (GGML_CPU_KLEIDIAI)\n        message(STATUS \"Using KleidiAI optimized kernels if applicable\")\n\n        # Disable the KleidiAI tests\n        set(KLEIDIAI_BUILD_TESTS  OFF)\n\n        # Fetch KleidiAI sources:\n        include(FetchContent)\n        set(KLEIDIAI_COMMIT_TAG \"v1.22.0\")\n        set(KLEIDIAI_DOWNLOAD_URL \"https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz\")\n        set(KLEIDIAI_ARCHIVE_MD5  \"54049037570ab0ee0a0d126b2ba5ece1\")\n\n        if (POLICY CMP0135)\n            cmake_policy(SET CMP0135 NEW)\n        endif()\n\n        # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+\n        # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28\n        FetchContent_Declare(KleidiAI_Download\n            URL ${KLEIDIAI_DOWNLOAD_URL}\n            DOWNLOAD_EXTRACT_TIMESTAMP NEW\n            URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})\n\n        FetchContent_GetProperties(KleidiAI_Download\n            SOURCE_DIR  KLEIDIAI_SRC\n            POPULATED   KLEIDIAI_POPULATED)\n\n        if (NOT KLEIDIAI_POPULATED)\n            FetchContent_Populate(KleidiAI_Download)\n            FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)\n        endif()\n\n        add_compile_definitions(GGML_USE_CPU_KLEIDIAI)\n\n        list(APPEND GGML_CPU_SOURCES\n            ggml-cpu/kleidiai/kleidiai.cpp\n            ggml-cpu/kleidiai/kernels.cpp\n            ggml-cpu/kleidiai/kleidiai.h\n            ggml-cpu/kleidiai/kernels.h\n            )\n\n        # KleidiAI\n        include_directories(\n            ${KLEIDIAI_SRC}/\n            ${KLEIDIAI_SRC}/kai/\n            
${KLEIDIAI_SRC}/kai/ukernels/\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)\n\n        set(ARCH_FLAGS_TEMP \"${ARCH_FLAGS}\")\n        if (NOT ARCH_FLAGS_TEMP)\n            string(REGEX MATCH \"-march=[^ ]+\" ARCH_FLAGS_TEMP \"${CMAKE_C_FLAGS}\")\n        endif()\n        string(FIND \"${ARCH_FLAGS_TEMP}\" \"+dotprod\" DOTPROD_ENABLED)\n        string(FIND \"${ARCH_FLAGS_TEMP}\" \"+i8mm\" I8MM_ENABLED)\n        string(FIND \"${ARCH_FLAGS_TEMP}\" \"+sme\" SME_ENABLED)\n        string(FIND \"${ARCH_FLAGS_TEMP}\" \"+sve\" SVE_ENABLED)\n\n        set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})\n\n        list(APPEND GGML_KLEIDIAI_SOURCES\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c\n            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)\n\n        if (NOT DOTPROD_ENABLED MATCHES -1)\n            list(APPEND GGML_KLEIDIAI_SOURCES\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c\n                
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)\n        endif()\n\n        if (NOT I8MM_ENABLED MATCHES -1)\n            list(APPEND GGML_KLEIDIAI_SOURCES\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)\n        endif()\n\n        if (NOT SME_ENABLED MATCHES -1)\n            list(APPEND GGML_KLEIDIAI_SOURCES\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c\n                
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa_asm.S\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_f16pmrx2_f32_neon.c\n                ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)\n            set(PRIVATE_ARCH_FLAGS \"-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2+sme2+fp16\")\n        endif()\n\n        if (NOT SVE_ENABLED MATCHES -1)\n            list(APPEND GGML_KLEIDIAI_SOURCES\n                ${KLEIDIAI_SRC}/kai/kai_common_sve_asm.S\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod_asm.S\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.c\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm_asm.S\n                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.c)\n        endif()\n\n        
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS \"${PRIVATE_ARCH_FLAGS}\")\n        list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})\n    endif()\n\n    message(STATUS \"Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}\")\n    target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})\n    target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})\n    target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS})\n\n    if (EMSCRIPTEN)\n        set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS \"-msimd128\")\n    endif()\n\n    if (CMAKE_CXX_COMPILER_ID STREQUAL \"IntelLLVM\")\n        # The compiler automatically enables \"-ffast-math\" which can cause NaNs in tests due to \"-fassociative-math\"\n        target_compile_options(${GGML_CPU_NAME} PRIVATE \"-fno-associative-math\")\n    endif()\nendfunction()\n"
  },
  {
    "path": "src/ggml-cpu/amx/amx.cpp",
    "content": "#include \"amx.h\"\n#include \"common.h\"\n#include \"mmq.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"traits.h\"\n\n#if defined(__linux__)\n#include <sys/syscall.h>\n#include <unistd.h>\n#endif\n\n#include <cstdlib>\n#include <cstring>\n#include <memory>\n\n#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)\n\n// AMX type_trais\nnamespace ggml::cpu::amx {\nclass tensor_traits : public ggml::cpu::tensor_traits {\n    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {\n        size = ggml_backend_amx_desired_wsize(op);\n        return true;\n    }\n\n    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {\n        if (op->op == GGML_OP_MUL_MAT) {\n            ggml_backend_amx_mul_mat(params, op);\n            return true;\n        }\n        return false;\n    }\n};\n\nstatic ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {\n    static tensor_traits traits;\n    return &traits;\n}\n}  // namespace ggml::cpu::amx\n\n// AMX buffer interface\nstatic void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    free(buffer->context);\n}\n\nstatic void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {\n    return (void *) (buffer->context);\n}\n\nstatic enum ggml_status ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {\n    tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);\n\n    GGML_UNUSED(buffer);\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,\n                                                  uint8_t value, size_t offset, size_t size) {\n    memset((char *) tensor->data + offset, value, size);\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void 
ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,\n                                               const void * data, size_t offset, size_t size) {\n    if (qtype_has_amx_kernels(tensor->type)) {\n        GGML_LOG_DEBUG(\"%s: amx repack tensor %s of type %s\\n\", __func__, tensor->name, ggml_type_name(tensor->type));\n        ggml_backend_amx_convert_weight(tensor, data, offset, size);\n    } else {\n        memcpy((char *) tensor->data + offset, data, size);\n    }\n\n    GGML_UNUSED(buffer);\n}\n\n/*\n// need to figure what we need to do with buffer->extra.\nstatic void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));\n    memcpy(data, (const char *)tensor->data + offset, size);\n\n    GGML_UNUSED(buffer);\n}\n\nstatic bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {\n    if (ggml_backend_buffer_is_host(src->buffer)) {\n        if (qtype_has_amx_kernels(src->type)) {\n            ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_nbytes(dst));\n        } else {\n            memcpy(dst->data, src->data, ggml_nbytes(src));\n        }\n        return true;\n    }\n    return false;\n\n    GGML_UNUSED(buffer);\n}\n*/\n\nstatic void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    memset(buffer->context, value, buffer->size);\n}\n\nstatic ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_amx_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_amx_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_amx_buffer_init_tensor,\n    /* .memset_tensor   = */ ggml_backend_amx_buffer_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_amx_buffer_set_tensor,\n    /* .get_tensor      = */ 
nullptr,\n    /* .cpy_tensor      = */ nullptr,\n    /* .clear           = */ ggml_backend_amx_buffer_clear,\n    /* .reset           = */ nullptr,\n};\n\nstatic const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    return \"AMX\";\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    void * data = ggml_aligned_malloc(size);\n    if (data == NULL) {\n        fprintf(stderr, \"%s: failed to allocate buffer of size %zu\\n\", __func__, size);\n        return NULL;\n    }\n\n    return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);\n}\n\nstatic size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return TENSOR_ALIGNMENT;\n\n    GGML_UNUSED(buft);\n}\n\nnamespace ggml::cpu::amx {\nclass extra_buffer_type : ggml::cpu::extra_buffer_type {\n    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {\n        if (op->op != GGML_OP_MUL_MAT) {\n            return false;\n        }\n        auto * src0 = op->src[0];\n        auto * src1 = op->src[1];\n\n        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {\n            return false;\n        }\n        if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {\n            return false;\n        }\n        if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {\n            return false;\n        }\n        if (op->ne[0] % (TILE_N * 2)) {\n            return false;\n        }\n        int alignment;\n        switch (src0->type) {\n            case GGML_TYPE_Q4_0:\n            case GGML_TYPE_Q4_1:\n            case GGML_TYPE_Q8_0:\n                alignment = TILE_K;\n                break;\n            case GGML_TYPE_Q4_K:\n            case GGML_TYPE_Q5_K:\n            case GGML_TYPE_Q6_K:\n            case GGML_TYPE_IQ4_XS:\n                alignment = 
256; // QK_K\n                break;\n            case GGML_TYPE_F16:\n                alignment = 16;\n                break;\n            default:\n                return false;\n        }\n        if (src0->ne[0] % alignment) {\n            return false;\n        }\n        if (src1->type != GGML_TYPE_F32) {\n            return false;\n        }\n        return true;\n    }\n\n    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {\n        if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer &&\n            op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) {\n            return (ggml::cpu::tensor_traits *) op->src[0]->extra;\n        }\n\n        return nullptr;\n    }\n};\n}  // namespace ggml::cpu::amx\n\nstatic size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    return ggml_backend_amx_get_alloc_size(tensor);\n\n    GGML_UNUSED(buft);\n}\n\n#define ARCH_GET_XCOMP_PERM     0x1022\n#define ARCH_REQ_XCOMP_PERM     0x1023\n#define XFEATURE_XTILECFG       17\n#define XFEATURE_XTILEDATA      18\n\nstatic bool ggml_amx_init() {\n#if defined(__linux__)\n    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {\n        fprintf(stderr, \"AMX is not ready to be used!\\n\");\n        return false;\n    }\n    return true;\n#elif defined(_WIN32)\n    return true;\n#else\n    return false;\n#endif\n}\n\nggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {\n    static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {\n        /* .iface = */ {\n                        /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,\n                        /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,\n                        /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,\n                        /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX\n                        
/* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,\n                        /* .is_host          = */ nullptr,\n                        },\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),\n        /* .context = */ new ggml::cpu::amx::extra_buffer_type(),\n    };\n\n    if (!ggml_amx_init()) {\n        return nullptr;\n    }\n\n    return &ggml_backend_buffer_type_amx;\n}\n\n#endif  // defined(__AMX_INT8__) && defined(__AVX512VNNI__)\n"
  },
  {
    "path": "src/ggml-cpu/amx/amx.h",
    "content": "#include \"ggml-backend.h\"\n#include \"ggml-cpu-impl.h\"\n\n// GGML internal header\n\n#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)\nggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/amx/common.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-cpu-impl.h\"\n\n#include <algorithm>\n#include <memory>\n#include <type_traits>\n\n#if defined(GGML_USE_OPENMP)\n#include <omp.h>\n#else\n#include <thread>\n#endif\n\n#define TILE_M 16\n#define TILE_N 16\n#define TILE_K 32\n#define VNNI_BLK 4\n\n#define AMX_BLK_SIZE 32\n\n#define TMM0 0\n#define TMM1 1\n#define TMM2 2\n#define TMM3 3\n#define TMM4 4\n#define TMM5 5\n#define TMM6 6\n#define TMM7 7\n\n// parallel routines\ntemplate <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>\ninline T div_up(T x, T y) { return (x + y - 1) / y; }\n\ntemplate <typename T>\ninline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {\n#if 0\n    // onednn partition pattern\n    T& n_my = n_end;\n    if (nth <= 1 || n == 0) {\n        n_start = 0;\n        n_my = n;\n    } else {\n        T n1 = div_up(n, nth);\n        T n2 = n1 - 1;\n        T T1 = n - n2 * nth;\n        n_my = ith < T1 ? n1 : n2;\n        n_start = ith <= T1 ? 
ith*n1 : T1 * n1 + (ith - T1) * n2;\n    }\n    n_end += n_start;\n#else\n    // pytorch aten partition pattern\n    T n_my = div_up(n, nth);\n    n_start = ith * n_my;\n    n_end = std::min(n_start + n_my, n);\n#endif\n}\n\ntemplate <typename func_t>\ninline void parallel_for(int n, const func_t & f) {\n    if (n <= 0) {\n        return;\n    }\n#if defined(GGML_USE_OPENMP)\n    #pragma omp parallel\n    {\n        int nth = omp_get_num_threads();\n        int ith = omp_get_thread_num();\n        int tbegin, tend;\n        balance211(n, nth, ith, tbegin, tend);\n        f(tbegin, tend);\n    }\n#else\n    int nth = std::thread::hardware_concurrency();\n    if (nth <= 1) {\n        f(0, n);\n        return;\n    }\n    if (nth > n) {\n        nth = n;\n    }\n    std::vector<std::thread> threads;\n    threads.reserve(nth);\n    for (int ith = 0; ith < nth; ++ith) {\n        threads.emplace_back([&f, n, ith, nth] {\n            int tbegin, tend;\n            balance211(n, nth, ith, tbegin, tend);\n            f(tbegin, tend);\n        });\n    }\n    for (auto & t : threads) {\n        t.join();\n    }\n#endif\n}\n\ntemplate <typename func_t>\ninline void parallel_for_ggml(const ggml_compute_params * params, int n, const func_t & f) {\n    int tbegin, tend;\n    balance211(n, params->nth, params->ith, tbegin, tend);\n    f(tbegin, tend);\n}\n\n// quantized types that have AMX support\ninline bool qtype_has_amx_kernels(const enum ggml_type type) {\n    // TODO: fix padding for vnni format\n    return (type == GGML_TYPE_Q4_0) ||\n        (type == GGML_TYPE_Q4_1) ||\n        (type == GGML_TYPE_Q8_0) ||\n        (type == GGML_TYPE_Q4_K) ||\n        (type == GGML_TYPE_Q5_K) ||\n        (type == GGML_TYPE_Q6_K) ||\n        (type == GGML_TYPE_IQ4_XS);\n}\n"
  },
  {
    "path": "src/ggml-cpu/amx/mmq.cpp",
    "content": "#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Wpedantic\"\n#pragma GCC diagnostic ignored \"-Wunused-local-typedefs\"\n#endif\n\n#include \"amx.h\"\n#include \"mmq.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"simd-mappings.h\"\n#include \"quants.h\"\n#include \"ggml-quants.h\"\n#include <algorithm>\n#include <type_traits>\n\n#if defined(__gnu_linux__)\n#include <sys/syscall.h>\n#include <unistd.h>\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define RESTRICT __restrict\n#else\n#define RESTRICT __restrict__\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define ALWAYS_INLINE __forceinline\n#elif __has_attribute(always_inline) || defined(__GNUC__)\n#define ALWAYS_INLINE __attribute__((__always_inline__)) inline\n#else\n#define ALWAYS_INLINE inline\n#endif\n\n#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)\n\nnamespace {\n\n// Forced unrolling\ntemplate <int n>\nstruct Unroll {\n    template <typename Func, typename... Args>\n    ALWAYS_INLINE void operator()(const Func& f, Args... args) const {\n        Unroll<n - 1>{}(f, args...);\n        f(std::integral_constant<int, n - 1>{}, args...);\n    }\n};\n\ntemplate <>\nstruct Unroll<1> {\n    template <typename Func, typename... Args>\n    ALWAYS_INLINE void operator()(const Func& f, Args... 
args) const {\n        f(std::integral_constant<int, 0>{}, args...);\n    }\n};\n\n// type traits\ntemplate <typename T> struct PackedTypes {};\ntemplate <> struct PackedTypes<block_q4_0> { using type = int8_t; };\ntemplate <> struct PackedTypes<block_q4_1> { using type = uint8_t; };\ntemplate <> struct PackedTypes<block_q8_0> { using type = int8_t; };\ntemplate <typename T> using packed_B_type = typename PackedTypes<T>::type;\n\ntemplate <typename T>\nstruct do_compensate : std::integral_constant<bool,\n    std::is_same<T, block_q8_0>::value> {};\n\ntemplate <typename T>\nstruct do_unpack : std::integral_constant<bool,\n    std::is_same<T, block_q4_0>::value ||\n    std::is_same<T, block_q4_1>::value> {};\n\ntemplate <typename T>\nstruct is_type_qkk : std::integral_constant<bool,\n    std::is_same<T, block_q4_K>::value ||\n    std::is_same<T, block_q5_K>::value ||\n    std::is_same<T, block_q6_K>::value ||\n    std::is_same<T, block_iq4_xs>::value> {};\n\n#define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...)                                        
\\\n    [&] {                                                                              \\\n        switch (TYPE) {                                                                \\\n            case GGML_TYPE_F16: {                                                      \\\n                using type = ggml_fp16_t;                                              \\\n                constexpr int blck_size = 16;                                          \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                          \\\n            case GGML_TYPE_BF16: {                                                     \\\n                using type = ggml_bf16_t;                                              \\\n                constexpr int blck_size = 32;                                          \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                          \\\n            default:                                                                   \\\n                fprintf(stderr, \"Unsupported floating data type\\n\");                   \\\n        }                                                                              \\\n    }()\n\n#define GGML_DISPATCH_QTYPES(QT, ...)                                                  
\\\n    [&] {                                                                              \\\n        switch (QT) {                                                                  \\\n            case GGML_TYPE_Q4_0: {                                                     \\\n                using type = block_q4_0;                                               \\\n                using vec_dot_type = block_q8_0;                                       \\\n                constexpr int blck_size = QK4_0;                                       \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                          \\\n            case GGML_TYPE_Q4_1: {                                                     \\\n                using type = block_q4_1;                                               \\\n                using vec_dot_type = block_q8_1;                                       \\\n                constexpr int blck_size = QK4_1;                                       \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                          \\\n            case GGML_TYPE_Q8_0: {                                                     \\\n                using type = block_q8_0;                                               \\\n                using vec_dot_type = block_q8_0;                                       \\\n                constexpr int blck_size = QK8_0;                                       \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                          \\\n            case GGML_TYPE_Q4_K: {                                                     \\\n                using type = block_q4_K;                                             
  \\\n                using vec_dot_type = block_q8_K;                                       \\\n                constexpr int blck_size = QK_K;                                        \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                          \\\n            case GGML_TYPE_Q5_K: {                                                     \\\n                using type = block_q5_K;                                               \\\n                using vec_dot_type = block_q8_K;                                       \\\n                constexpr int blck_size = QK_K;                                        \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                          \\\n            case GGML_TYPE_Q6_K: {                                                     \\\n                using type = block_q6_K;                                               \\\n                using vec_dot_type = block_q8_K;                                       \\\n                constexpr int blck_size = QK_K;                                        \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                          \\\n            case GGML_TYPE_IQ4_XS: {                                                   \\\n                using type = block_iq4_xs;                                             \\\n                using vec_dot_type = block_q8_K;                                       \\\n                constexpr int blck_size = QK_K;                                        \\\n                return __VA_ARGS__();                                                  \\\n            }                                                                      
    \\\n            default:                                                                   \\\n                fprintf(stderr, \"Unsupported quantized data type: %d\\n\", int(TYPE));   \\\n        }                                                                              \\\n    }()\n\n#define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...)                                     \\\n    [&] {                                                                              \\\n        if (BOOL_V) {                                                                  \\\n            constexpr bool BOOL_NAME = true;                                           \\\n            return __VA_ARGS__();                                                      \\\n        } else {                                                                       \\\n            constexpr bool BOOL_NAME = false;                                          \\\n            return __VA_ARGS__();                                                      \\\n        }                                                                              \\\n    }()\n\n// define amx tile config data structure\nstruct tile_config_t{\n    uint8_t palette_id = 0;\n    uint8_t start_row = 0;\n    uint8_t reserved_0[14] = {0};\n    uint16_t colsb[16] = {0};\n    uint8_t rows[16] = {0};\n};\n\n// Notes: amx tile config\n//\n// Typically, TMUL calculates A and B of size 16 x 64 containing INT8 values,\n// and accumulate the result to a 16 x 16 matrix C containing INT32 values,\n//\n// As many GGUF quantized types as `block_size` of 32, so a 16-16-32 config is used\n// instead of the normally used 16-16-64 config.\n//\n//    Block A: {16, 32}, dtype = int8_t\n//    Block B: {16, 32}, dtype = uint8_t/int8_t\n//    Block C: {16, 16}, dtype = int32_t\n//\n// Block B needs to be prepacked to vnni format before feeding into  TMUL:\n//    packed_B: from {n, k} to {k/vnni_blk, n, vnni_blck}, viewed in 2d, we get {8, 64}\n//\n// Therefore, we 
get tileconfig:\n//             A    B    C\n//    rows    16    8   16\n//    colsb   32   64   16\n//\n// For tile distribution, follow a 2-2-4 pattern, e.g. A used TMM2-TMM3, B used TMM0-TMM1,\n// C used TMM4-TMM7:\n//            B TMM0  B TMM1\n//    A TMM2  C TMM4  C TMM6\n//    A TMM3  C TMM5  C TMM7\n//\n// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB, when m < 2 * BLOCK_M, unpack A\n// will be needed.\n//\n// Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16;\n// and the single batch gemm (m=1) has a special fast path with `avx512-vnni`.\n//\n// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/\n//    advanced-matrix-extensions-intrinsics-functions.html\n//\n\ninline void ggml_tile_config_init(void) {\n    static thread_local bool done = false;\n\n    if (done) {\n        return;\n    }\n\n    alignas(64) tile_config_t tc = {};\n    tc.palette_id = 1;\n    tc.start_row = 0;\n    tc.rows[0] = 8;   tc.colsb[0] = 64;\n    tc.rows[1] = 8;   tc.colsb[1] = 64;\n    tc.rows[2] = 16;  tc.colsb[2] = 32;\n    tc.rows[3] = 16;  tc.colsb[3] = 32;\n    tc.rows[4] = 16;  tc.colsb[4] = 64;\n    tc.rows[5] = 16;  tc.colsb[5] = 64;\n    tc.rows[6] = 16;  tc.colsb[6] = 64;\n    tc.rows[7] = 16;  tc.colsb[7] = 64;\n\n    _tile_loadconfig(&tc);\n    done = true;\n}\n\n// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.\n// See the notes `s8s8 igemm compensation in avx512-vnni` for detail.\ntemplate <typename TB>\nint get_tile_size() {\n    int tile_size = TILE_N * sizeof(TB);\n    if (do_compensate<TB>::value) {\n        tile_size += TILE_N * sizeof(int32_t);\n    }\n    if (std::is_same<TB, block_q4_K>::value ||\n        std::is_same<TB, block_q5_K>::value) {\n        tile_size += TILE_N * 4;\n    }\n    if (std::is_same<TB, block_iq4_xs>::value) {\n        tile_size += TILE_N * 2;\n    }\n    return tile_size;\n}\n\ntemplate <typename TB, int BLOCK_K>\nint 
get_row_size(int K) {\n    int KB = K / BLOCK_K;\n    int row_size = KB * sizeof(TB);\n    if (do_compensate<TB>::value) {\n        row_size += KB * sizeof(int32_t);\n    }\n    if (std::is_same<TB, block_q4_K>::value ||\n        std::is_same<TB, block_q5_K>::value) {\n        row_size += KB * 4;\n    }\n    if (std::is_same<TB, block_iq4_xs>::value) {\n        row_size += KB * 2;\n    }\n    return row_size;\n}\n\n// transpose utils\n#define SHUFFLE_EPI32(a, b, mask) \\\n    _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))\ninline void transpose_8x8_32bit(__m256i * v, __m256i * v1) {\n    // unpacking and 32-bit elements\n    v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);\n    v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);\n    v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);\n    v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);\n    v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);\n    v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);\n    v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);\n    v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);\n\n    // shuffling the 32-bit elements\n    v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);\n    v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);\n    v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44);\n    v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee);\n    v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);\n    v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);\n    v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);\n    v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);\n\n    // shuffling 128-bit elements\n    v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);\n    v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);\n    v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);\n    v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);\n    v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);\n    v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);\n    v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);\n    v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);\n}\n\ninline void 
transpose_16x4_32bit(__m512i * r, __m512i * d) {\n\n    static const __m512i index1 = _mm512_set_epi32(\n        0x0f, 0x0b, 0x07, 0x03,\n        0x0e, 0x0a, 0x06, 0x02,\n        0x0d, 0x09, 0x05, 0x01,\n        0x0c, 0x08, 0x04, 0x00);\n\n    d[0] = _mm512_permutexvar_epi32(index1, r[0]);\n    d[1] = _mm512_permutexvar_epi32(index1, r[1]);\n    d[2] = _mm512_permutexvar_epi32(index1, r[2]);\n    d[3] = _mm512_permutexvar_epi32(index1, r[3]);\n\n    r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);\n    r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);\n    r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);\n    r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);\n\n    d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);\n    d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);\n    d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);\n    d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);\n}\n\ninline void transpose_16x16_32bit(__m512i * v) {\n    __m512i v1[16];\n    v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);\n    v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);\n    v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);\n    v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);\n    v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);\n    v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);\n    v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);\n    v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);\n    v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);\n    v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);\n    v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);\n    v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);\n    v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);\n    v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);\n    v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);\n    v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);\n\n    v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);\n    v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);\n    v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);\n    v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);\n    v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);\n    v[5] = 
_mm512_unpackhi_epi64(v1[4], v1[6]);\n    v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);\n    v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);\n    v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);\n    v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);\n    v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);\n    v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);\n    v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);\n    v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);\n    v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);\n    v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n    v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);\n    v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);\n    v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);\n    v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);\n    v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);\n    v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);\n    v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);\n    v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);\n    v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);\n    v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);\n    v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);\n    v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);\n    v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);\n    v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);\n    v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);\n    v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);\n\n    v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n    v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n    v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n    v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n    v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n    v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n    v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n    v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n    v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n    v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n    v[10] = 
_mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n    v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n    v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n    v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n    v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n    v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\nvoid quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) {\n    assert(k % QK_K == 0);\n    const int KB = k / QK_K;\n    constexpr int kVecs = QK_K / 16;\n\n    block_q8_K * y = reinterpret_cast<block_q8_K *>(vy);\n\n    // hold 16 float vecs from x\n    __m512  v[kVecs];\n\n    // hold the quants vecs\n    __m512i vq[kVecs / 4];\n\n    // hold the packed quants vecs\n    __m512i vq_packed[kVecs / 4];\n\n    const __m512 signBit = _mm512_set1_ps(-0.f);\n\n    for (int i = 0; i < KB; ++i) {\n        // Compute max(abs(e)) for the block\n        __m512 vamax = _mm512_set1_ps(0.f);\n        for (int j = 0; j < kVecs; ++j) {\n            v[j] = _mm512_loadu_ps(x); x += 16;\n            vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j]));\n        }\n        const float amax = _mm512_reduce_max_ps(vamax);\n\n        // Quantize these floats\n        const float iscale = 127.f / amax;\n        y[i].d = GGML_CPU_FP32_TO_FP16(1 / iscale);\n        const float id = ( amax != 0.0f ) ? 
iscale : 0.f;\n        const __m512 vscale = _mm512_set1_ps(id);\n\n        // Apply multiplier and round to nearest integer\n        for (int j = 0; j < kVecs; ++j) {\n            v[j] = _mm512_mul_ps(v[j], vscale);\n            v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n        }\n\n        // Pack to epi8 vecs\n        for (int j = 0; j < kVecs / 4; ++j) {\n            __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0]));\n            __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1]));\n            __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2]));\n            __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3]));\n\n            __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1);\n            __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1);\n\n            vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1);\n            _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]);\n        }\n\n        // Compute the bsums with vnni\n        transpose_16x4_32bit(vq, vq_packed);\n\n        const __m512i one = _mm512_set1_epi8(1);\n        __m512i sum = _mm512_setzero_si512();\n        for (int k = 0; k < 4; ++k) {\n            sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]);\n        }\n        _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum));\n    }\n}\n\n// quantize A from float to `vec_dot_type`\ntemplate <typename T>\ninline void from_float(const float * x, char * vy, int64_t k);\n\ntemplate <>\ninline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {\n    quantize_row_q8_0(x, (block_q8_0 *)vy, k);\n}\n\ntemplate <>\ninline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {\n    quantize_row_q8_1(x, (block_q8_1 *)vy, k);\n}\n\ntemplate <>\ninline void from_float<block_q8_K>(const float * x, char * vy, int64_t 
k) {\n#if 1\n    // TODO: this is reference impl!\n    quantize_row_q8_K_ref(x, (block_q8_K *)vy, k);\n#else\n    quantize_row_q8_K_vnni(x, vy, k);\n#endif\n}\n\n// load A from memory to array when nrows can not fill in whole tile\nvoid unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) {\n    assert(nr != TILE_M);\n    for (int m = 0; m < nr; ++m) {\n        const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));\n        _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);\n    }\n}\n\nvoid unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) {\n    assert(nr != TILE_M);\n    for (int m = 0; m < nr; ++m) {\n        const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));\n        _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);\n    }\n}\n\ntemplate <typename TB>\nvoid unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {\n    assert(nr <= TILE_M);\n    for (int m = 0; m < nr; ++m) {\n        const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32));\n        _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);\n    }\n}\n\ntemplate <>\nvoid unpack_A<block_q6_K>(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {\n    assert(nr <= TILE_M);\n    // zero padding k from 16 to 32, so that we don't have to re-config amx\n    const __m128i zero = _mm_setzero_si128();\n    for (int m = 0; m < nr; ++m) {\n        const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16));\n        const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1);\n        _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r);\n    }\n}\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\ninline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {\n    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);\n    const __m256i 
bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);\n    const __m256i lowMask = _mm256_set1_epi8(0xF);\n    return _mm256_and_si256(lowMask, bytes);\n}\n\n// used for block_q4_K\ninline __m512i bytes_from_nibbles_64(const uint8_t * rsi) {\n    const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi);\n    const __m256i lowMask = _mm256_set1_epi8(0xF);\n    const __m256i q4l = _mm256_and_si256(tmp, lowMask);\n    const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask);\n    return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1);\n}\n\n// used for block_q5_K\ninline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) {\n    const __m256i lowMask = _mm256_set1_epi8(0xF);\n    __m256i hmask = _mm256_set1_epi8(1);\n    hmask = _mm256_slli_epi16(hmask, k);\n\n    const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs);\n    const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh);\n\n    const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask);\n    const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4);\n    const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);\n    hmask = _mm256_slli_epi16(hmask, 1);\n\n    const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask);\n    const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4);\n    const __m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);\n\n    return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1);\n}\n\n// used for block_q6_K\ninline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) {\n    const __m256i m4 = _mm256_set1_epi8(0xF);\n    const __m256i m2 = _mm256_set1_epi8(0x3);\n\n    const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs);\n    const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32));\n    const __m256i q6bitsH = _mm256_loadu_si256((const __m256i 
*)qh);\n\n    const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256(                  q6bitsH,     m2), 4);\n    const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4);\n    const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4);\n    const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4);\n\n    const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0);\n    const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1);\n    const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2);\n    const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3);\n\n    r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1);\n    r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1);\n}\n\ninline __m512i packNibbles(__m512i r0, __m512i r1) {\n    return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4));\n}\n\ntemplate <typename TB>\ninline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) {\n    int8_t tmp[8 * 64];\n    __m256i v[8], v2[8];\n    for (int n = 0; n < 8; ++n) {\n        v[n] = bytes_from_nibbles_32(B[n * KB].qs);\n    }\n    transpose_8x8_32bit(v, v2);\n    for (int n = 0; n < 8; ++n) {\n        _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]);\n    }\n    for (int n = 0; n < 8; ++n) {\n        v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs);\n    }\n    transpose_8x8_32bit(v, v2);\n    for (int n = 0; n < 8; ++n) {\n        _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]);\n    }\n\n    // pack again with 128 to fully utilize vector length\n    for (int n = 0; n < 8; n += 2) {\n        __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64));\n        __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64));\n        __m512i r1r0 = packNibbles(r0, r1);\n        
_mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0);\n    }\n}\n\ntemplate <>\ninline void pack_qs<block_q8_0>(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {\n    __m256i v[8], v2[8];\n    for (int n = 0; n < 8; ++n) {\n        v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs));\n    }\n    transpose_8x8_32bit(v, v2);\n    for (int n = 0; n < 8; ++n) {\n        _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]);\n    }\n    for (int n = 0; n < 8; ++n) {\n        v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs));\n    }\n    transpose_8x8_32bit(v, v2);\n    for (int n = 0; n < 8; ++n) {\n        _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]);\n    }\n}\n\ntemplate <>\ninline void pack_qs<block_q4_K>(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {\n    __m512i v[16];\n    // QK_K 256 with 8 groups, handle 2 groups at a time\n    char * pb = (char *)packed_B;\n    for (int k = 0; k < QK_K / 64; ++k) {\n        // pack 2 groups { n, g,  k} to {g, k/4, 4n}\n        //          e.g. {16, 2, 32} to {2,   8, 64}\n        for (int n = 0; n < TILE_N; ++n) {\n            v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32);\n        }\n\n        transpose_16x16_32bit(v);\n\n        // pack again with 128 to fully utilize vector length\n        for (int n = 0; n < TILE_N; n += 2) {\n            _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));\n            pb += 64;\n        }\n    }\n}\n\ntemplate <>\ninline void pack_qs<block_q5_K>(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {\n    __m512i v[16];\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n    // QK_K 256 with 8 groups, handle 2 groups at a time\n    char * pb = (char *)packed_B;\n    char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;\n    for (int k = 0; k < QK_K / 64; ++k) {\n        // pack 2 groups { n, g,  k} to {g, k/4, 4n}\n        //          e.g. 
{16, 2, 32} to {2,   8, 64}\n        for (int n = 0; n < TILE_N; ++n) {\n            v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k);\n        }\n\n        transpose_16x16_32bit(v);\n\n        // 1. pack lower 4bits with 2 groups\n        for (int n = 0; n < TILE_N; n += 2) {\n            // get lower 4 bits\n            const __m512i r0 = _mm512_and_si512(v[n], lowMask);\n            const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);\n            _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;\n        }\n\n        // 2. pack higher 1bit with 2 groups\n        const __m512i hmask = _mm512_set1_epi8(0x10);\n        for (int g = 0; g < 2; ++g) {\n            __m512i hbits = _mm512_setzero_si512();\n            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4));\n            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3));\n            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2));\n            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1));\n            hbits = _mm512_add_epi8(hbits,                   _mm512_and_si512(v[g * 8 + 4], hmask)    );\n            hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1));\n            hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2));\n            hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3));\n            _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;\n        }\n    }\n}\n\ntemplate <>\ninline void pack_qs<block_q6_K>(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {\n    __m512i v[32];\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n    // QK_K 256 with 8 groups, handle 4 groups at a time\n    char * pb = (char *)packed_B;\n    char * ph = 
(char *)packed_B + (QK_K / 2) * TILE_N;\n    for (int k = 0; k < QK_K / 128; ++k) {\n        for (int n = 0; n < TILE_N; ++n) {\n            bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32);\n        }\n\n        // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7\n        transpose_16x16_32bit(v);\n        transpose_16x16_32bit(v + 16);\n\n        // 1. pack lower 4bits with 4 groups\n        for (int n = 0; n < 32; n += 2) {\n            const __m512i r0 = _mm512_and_si512(v[n], lowMask);\n            const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);\n            _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;\n        }\n\n        // 2. pack higher 2bit with 4 groups\n        const __m512i hmask = _mm512_set1_epi8(0x30);\n        for (int g = 0; g < 8; ++g) {\n            __m512i hbits = _mm512_setzero_si512();\n            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4));\n            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2));\n            hbits = _mm512_add_epi8(hbits,                   _mm512_and_si512(v[g * 4 + 2], hmask)    );\n            hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2));\n            _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;\n        }\n    }\n}\n\ntemplate <>\ninline void pack_qs<block_iq4_xs>(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {\n    __m512i v[16];\n    char * pb = (char *)packed_B;\n    for (int k = 0; k < QK_K / 64; ++k) {\n        for (int n = 0; n < TILE_N; ++n) {\n            __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 +  0);\n            __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16);\n            v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1);\n        }\n\n        transpose_16x16_32bit(v);\n\n        // pack again with 128 to fully utilize 
vector length\n        for (int n = 0; n < TILE_N; n += 2) {\n            _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));\n            pb += 64;\n        }\n    }\n}\n\n// pack B to vnni formats in 4bits or 8 bits\nvoid pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) {\n    pack_qs(packed_B, B, KB);\n    ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);\n    for (int n = 0; n < TILE_N; ++n) {\n        d0[n] = B[n * KB].d;\n    }\n}\n\nvoid pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) {\n    pack_qs(packed_B, B, KB);\n    ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);\n    ggml_half * m0 = d0 + TILE_N;\n    for (int n = 0; n < TILE_N; ++n) {\n        d0[n] = B[n * KB].d;\n        m0[n] = B[n * KB].m;\n    }\n}\n\ninline void s8s8_compensation(void * RESTRICT packed_B) {\n    // packed_B layout:\n    //   quants {TILE_N, TILEK}  int8_t\n    //   d0     {TILE_N}      ggml_half\n    //   comp   {TILE_N}        int32_t\n    const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);\n    __m512i vcomp = _mm512_setzero_si512();\n    const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n    for (int k = 0; k < 8; ++k) {\n        __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64));\n        vcomp = _mm512_dpbusd_epi32(vcomp, off, vb);\n    }\n    _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp);\n}\n\nvoid pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {\n    pack_qs(packed_B, B, KB);\n    ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K);\n    for (int n = 0; n < TILE_N; ++n) {\n        d0[n] = B[n * KB].d;\n    }\n    s8s8_compensation(packed_B);\n}\n\n// convert 8 * {min, scale} from int6 to int8\ninline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) {\n    const uint32_t kmask1 = 
0x3f3f3f3f;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    const uint32_t kmask3 = 0x03030303;\n\n    memcpy(utmp, scales, 12);\n    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n    const uint32_t uaux = utmp[1] & kmask1;\n    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n    utmp[2] = uaux;\n    utmp[0] &= kmask1;\n}\n\n// packed_B layout:\n//   quants {8, TILE_N, 16}  uint8\n//   scales {8, TILE_N}      uint8\n//   mins   {8, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\n//   dmin   {TILE_N}     ggml_half\nvoid pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {\n    pack_qs(packed_B, B, KB);\n\n    uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);\n    uint8_t * mins = scales + 8 * TILE_N;\n    ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);\n    ggml_half * dmin = d + TILE_N;\n\n    union {\n        uint32_t u32[4];\n        uint8_t  u8[16];\n    } s;\n\n    for (int n = 0; n < TILE_N; ++n) {\n        unpack_mins_and_scales(B[n * KB].scales, s.u32);\n        for (int k = 0; k < 8; ++k) {\n            scales[k * TILE_N + n] = s.u8[k];\n            mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];\n        }\n        d[n] = B[n * KB].d;\n        dmin[n] = B[n * KB].dmin;\n    }\n}\n\n// packed_B layout:\n//   quants {8, TILE_N, 16}  uint8\n//   qh     {8, TILE_N,  4}  uint8\n//   scales {8, TILE_N}      uint8\n//   mins   {8, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\n//   dmin   {TILE_N}     ggml_half\nvoid pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {\n    pack_qs(packed_B, B, KB);\n\n    uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);\n    uint8_t * mins = scales + 8 * TILE_N;\n    ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);\n    ggml_half * dmin = d + TILE_N;\n\n    union {\n        uint32_t 
u32[4];\n        uint8_t  u8[16];\n    } s;\n\n    for (int n = 0; n < TILE_N; ++n) {\n        unpack_mins_and_scales(B[n * KB].scales, s.u32);\n        for (int k = 0; k < 8; ++k) {\n            scales[k * TILE_N + n] = s.u8[k];\n            mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];\n        }\n        d[n] = B[n * KB].d;\n        dmin[n] = B[n * KB].dmin;\n    }\n}\n\n// packed_B layout:\n//   quants {16, TILE_N, 8}  uint8\n//   qh     {16, TILE_N, 4}  uint8\n//   scales {16, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\nvoid pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {\n    pack_qs(packed_B, B, KB);\n\n    uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);\n    ggml_half * d = reinterpret_cast<ggml_half *>(scales + 16 * TILE_N);\n    for (int n = 0; n < TILE_N; ++n) {\n        const int8_t * ps = B[n * KB].scales;\n        for (int k = 0; k < 16; ++k) {\n            scales[k * TILE_N + n] = ps[k];\n        }\n        d[n] = B[n * KB].d;\n    }\n}\n\n// packed_B layout:\n//   quants {8, TILE_N, 16}  uint8\n//   scales {8, TILE_N}       int8\n//   d      {TILE_N}     ggml_half\nvoid pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {\n    pack_qs(packed_B, B, KB);\n\n    int8_t * scales = reinterpret_cast<int8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);\n    ggml_half * d = reinterpret_cast<ggml_half *>(scales + 8 * TILE_N);\n\n    // pack the scales\n    for (int n = 0; n < TILE_N; ++n) {\n        uint16_t sh = B[n * KB].scales_h;\n        for (int k = 0; k < 8; k += 2) {\n            const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32;\n            const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >>  4) | ((sh << 2) & 0x30)) - 32;\n            scales[(k + 0) * TILE_N + n] = ls1;\n            scales[(k + 1) * TILE_N + n] = ls2;\n            sh >>= 4;\n        }\n        d[n] = B[n * KB].d;\n  
  }\n}\n\ntemplate<typename TB, typename packed_B_t = packed_B_type<TB>>\nvoid unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) {\n    GGML_UNUSED(tile);\n    GGML_UNUSED(packed_B);\n}\n\ntemplate <>\nvoid unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) {\n  const __m512i off = _mm512_set1_epi8(8);\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));\n    const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off);\n    const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off);\n    _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);\n    _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <>\nvoid unpack_B<block_q4_1>(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) {\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n    for (int n = 0; n < 8; n += 2) {\n        __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));\n        const __m512i r0 = _mm512_and_si512(bytes, lowMask);\n        const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);\n        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);\n    }\n}\n\n// packed_B_t for QKK is int8_t\ntemplate <typename TB>\nvoid unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {\n    const int packed_B_group_size = QK_K / 2 * TILE_N / 8;\n    const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size;\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n    for (int n = 0; n < 8; n += 2) {\n        __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32);\n        const __m512i r0 = _mm512_and_si512(bytes, lowMask);\n        const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 
4), lowMask);\n        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);\n        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);\n    }\n}\n\ntemplate <>\nvoid unpack_B<block_q5_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {\n    // lower 4bits, stride 256 bytes\n    const int packed_l4_group_size = QK_K / 2 * TILE_N / 8;\n    const char * pb = (const char *)packed_B + k * packed_l4_group_size;\n\n    // higher 1bit, stride 64 bytes\n    const int packed_h1_group_size = QK_K / 8 * TILE_N / 8;\n    const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size;\n    const __m512i hbits = _mm512_loadu_si512(ph);\n\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n    __m512i hmask0 = _mm512_set1_epi8(0x1);\n    __m512i hmask1 = _mm512_set1_epi8(0x2);\n\n    for (int n = 0; n < 8; n += 2) {\n        __m512i bytes = _mm512_loadu_si512(pb + n * 32);\n        __m512i r0 = _mm512_and_si512(bytes, lowMask);\n        __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n        __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4);\n        __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4);\n\n        hmask0 = _mm512_slli_epi16(hmask0, 2);\n        hmask1 = _mm512_slli_epi16(hmask1, 2);\n        r0 = _mm512_add_epi8(r0, h0);\n        r1 = _mm512_add_epi8(r1, h1);\n        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);\n        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);\n    }\n}\n\ntemplate <>\nvoid unpack_B<block_q6_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {\n    // lower 4bits, stride 128 bytes\n    const int packed_l4_group_size = QK_K / 2 * TILE_N / 16;\n    const char * pb = (const char *)packed_B + k * packed_l4_group_size;\n\n    // higher 2bits, stride 64 bytes\n    const int packed_h2_group_size = QK_K / 4 * TILE_N / 16;\n    const char * ph = (const 
char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size;\n    const __m512i hbits = _mm512_loadu_si512(ph);\n\n    const __m512i off = _mm512_set1_epi8(32);\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n    __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011\n    __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100\n\n    // notes: skip zero padding from row4 to row7 as we have done so in `unpack_A`\n    __m512i bytes = _mm512_loadu_si512(pb);\n    __m512i r0 = _mm512_and_si512(bytes, lowMask);\n    __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n    __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4);\n    __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2);\n    _mm512_storeu_si512((__m512i *)(tile +  0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));\n    _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));\n\n    hmask0 = _mm512_slli_epi16(hmask0, 4);\n    hmask1 = _mm512_slli_epi16(hmask1, 4);\n\n    bytes = _mm512_loadu_si512(pb + 64);\n    r0 = _mm512_and_si512(bytes, lowMask);\n    r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n    h0 =                   _mm512_and_si512(hbits, hmask0);\n    h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2);\n    _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));\n    _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));\n}\n\ntemplate <>\nvoid unpack_B<block_iq4_xs>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {\n    static const __m512i values128 = _mm512_set_epi8(\n        113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n        113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n        113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n        113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127\n    );\n\n    
const int packed_B_group_size = QK_K / 2 * TILE_N / 8;\n    const char * pb = (const char *)packed_B + k * packed_B_group_size;\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    for (int n = 0; n < 8; n += 2) {\n        __m512i bytes = _mm512_loadu_si512(pb + n * 32);\n        const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask));\n        const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));\n        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);\n        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);\n    }\n}\n\ntemplate <typename TA, typename TB, bool is_acc>\nstruct acc_C {};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_0, block_q4_0, is_acc> {\n    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {\n        const int offset = TILE_N * TILE_K / 2;\n        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));\n\n        for (int m = 0; m < nr; ++m) {\n            const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d));\n            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n            __m512 vsum;\n            if (is_acc) {\n                vsum = _mm512_loadu_ps(C + m * ldc);\n            } else {\n                vsum = _mm512_set1_ps(0.f);\n            }\n            vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n            _mm512_storeu_ps(C + m * ldc, vsum);\n        }\n    }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_1, block_q4_1, is_acc> {\n    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) {\n        const int offset = TILE_N * TILE_K / 2;\n        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i 
*)((const char *)packed_B + offset)));\n        const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half))));\n\n        for (int m = 0; m < nr; ++m) {\n            const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d));\n            const __m512 vs1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].s));\n            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n            __m512 vsum;\n            if (is_acc) {\n                vsum = _mm512_loadu_ps(C + m * ldc);\n            } else {\n                vsum = _mm512_set1_ps(0.f);\n            }\n            vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n            vsum = _mm512_fmadd_ps(vm0, vs1, vsum);\n            _mm512_storeu_ps(C + m * ldc, vsum);\n        }\n    }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_0, block_q8_0, is_acc> {\n    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {\n        const int offset = TILE_N * TILE_K;\n        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));\n\n        for (int m = 0; m < nr; ++m) {\n            const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d));\n            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n            __m512 vsum;\n            if (is_acc) {\n                vsum = _mm512_loadu_ps(C + m * ldc);\n            } else {\n                vsum = _mm512_set1_ps(0.f);\n            }\n            vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n            _mm512_storeu_ps(C + m * ldc, vsum);\n        }\n    }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_q4_K, is_acc> {\n    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * 
A, int lda, const void * packed_B, int nr) {\n        const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);\n        const uint8_t * mins = scales + 8 * TILE_N;\n        const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);\n        const ggml_half * dmin = d0 + TILE_N;\n\n        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));\n        const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));\n\n        for (int m = 0; m < nr; ++m) {\n            const float d1 = A[m * lda].d;\n            const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n            const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);\n            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n            __m512 vsum;\n            if (is_acc) {\n                vsum = _mm512_loadu_ps(C + m * ldc);\n            } else {\n                vsum = _mm512_set1_ps(0.f);\n            }\n\n            const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);\n            const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n\n            __m512i acc_m = _mm512_setzero_si512();\n            for (int k = 0; k < 4; ++k) {\n                __m512i vmask = _mm512_set1_epi32(k);\n                __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));\n                __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));\n                acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n            }\n\n            vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n            vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);\n            _mm512_storeu_ps(C + m * ldc, vsum);\n        }\n    }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_q5_K, is_acc> {\n    static void apply(float * 
RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {\n        const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);\n        const uint8_t * mins = scales + 8 * TILE_N;\n        const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);\n        const ggml_half * dmin = d0 + TILE_N;\n\n        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));\n        const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));\n\n        for (int m = 0; m < nr; ++m) {\n            const float d1 = A[m * lda].d;\n            const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n            const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);\n            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n            __m512 vsum;\n            if (is_acc) {\n                vsum = _mm512_loadu_ps(C + m * ldc);\n            } else {\n                vsum = _mm512_set1_ps(0.f);\n            }\n\n            const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);\n            const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n\n            __m512i acc_m = _mm512_setzero_si512();\n            for (int k = 0; k < 4; ++k) {\n                __m512i vmask = _mm512_set1_epi32(k);\n                __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));\n                __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));\n                acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n            }\n\n            vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n            vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);\n            _mm512_storeu_ps(C + m * ldc, vsum);\n        }\n    }\n};\n\ntemplate <bool 
is_acc>\nstruct acc_C<block_q8_K, block_q6_K, is_acc> {\n    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {\n        const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);\n        const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 16 * TILE_N);\n\n        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));\n\n        for (int m = 0; m < nr; ++m) {\n            const float d1 = A[m * lda].d;\n            const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n            __m512 vsum;\n            if (is_acc) {\n                vsum = _mm512_loadu_ps(C + m * ldc);\n            } else {\n                vsum = _mm512_set1_ps(0.f);\n            }\n\n            vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n            _mm512_storeu_ps(C + m * ldc, vsum);\n        }\n    }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_iq4_xs, is_acc> {\n    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {\n        const int8_t * scales = reinterpret_cast<const int8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);\n        const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 8 * TILE_N);\n\n        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));\n\n        for (int m = 0; m < nr; ++m) {\n            const float d1 = A[m * lda].d;\n            const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n            __m512 vsum;\n            if (is_acc) {\n                vsum = _mm512_loadu_ps(C + m * ldc);\n            } else {\n                
vsum = _mm512_set1_ps(0.f);\n            }\n\n            vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n            _mm512_storeu_ps(C + m * ldc, vsum);\n        }\n    }\n};\n\ntemplate <typename TB> constexpr int get_quants_size();\ntemplate <> constexpr int get_quants_size<block_q4_K>() { return (QK_K / 2) * TILE_N; }\ntemplate <> constexpr int get_quants_size<block_q5_K>() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; }\ntemplate <> constexpr int get_quants_size<block_q6_K>() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; }\ntemplate <> constexpr int get_quants_size<block_iq4_xs>() { return (QK_K / 2) * TILE_N; }\n\n// used for QKK format\ntemplate <typename TB, bool is_acc,\n          typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>\ninline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) {\n    const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + get_quants_size<TB>());\n    const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N)));\n\n    for (int m = 0; m < nr; ++m) {\n        __m512i vsumi;\n        if (is_acc) {\n            vsumi = _mm512_loadu_si512(sumi + m * TILE_N);\n        } else {\n            vsumi = _mm512_setzero_si512();\n        }\n        __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N);\n        vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale));\n        _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi);\n    }\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_avx {\n    static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) {\n        GGML_UNUSED(K);\n        GGML_UNUSED(A);\n        GGML_UNUSED(B);\n        GGML_UNUSED(C);\n        GGML_UNUSED(ldc);\n    }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_avx<float, 
ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n    static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) {\n        constexpr int ROWS = BLOCK_M;\n        constexpr int COLS = BLOCK_N;\n        assert(BLOCK_K == 16);\n\n        __m512 va;\n        __m512 vb[COLS];\n        __m512 vc[ROWS * COLS];\n\n        auto loadc = [&](auto idx) {\n            vc[idx] = _mm512_setzero_ps();\n        };\n        Unroll<ROWS * COLS>{}(loadc);\n\n        auto compute = [&](auto idx, auto k) {\n            constexpr int row = idx / COLS;\n            constexpr int col = idx % COLS;\n\n            if constexpr (col == 0) {\n                va = _mm512_loadu_ps(A + row * K + k);\n            }\n            if constexpr (row == 0) {\n                vb[col] =  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));\n            }\n            vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);\n        };\n\n        for (int k = 0; k < K; k += 16) {\n            Unroll<ROWS * COLS>{}(compute, k);\n        }\n\n        auto storec = [&](auto idx) {\n            constexpr int row = idx / COLS;\n            constexpr int col = idx % COLS;\n            C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);\n        };\n        Unroll<ROWS * COLS>{}(storec);\n    }\n};\n\n#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE)                                \\\n    tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply(    \\\n        K, (const float *)src1->data + src1_offset + mb_start * K,                  \\\n        (const type *)src0->data + src0_offset + nb_start * K,                      \\\n        (float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)\n\n\n// re-organize in the format {NB, KB, TILE_SIZE}:\n#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size\n\ntemplate<typename TB, int BLOCK_K>\nvoid convert_B_packed_format(void * RESTRICT packed_B, const TB * 
RESTRICT B, int N, int K) {\n    const int NB = N / TILE_N;\n    const int KB = K / BLOCK_K;\n    const int TILE_SIZE = get_tile_size<TB>();\n\n    // parallel on NB should be enough\n    parallel_for(NB, [&](int begin, int end) {\n        for (int n = begin; n < end; ++n) {\n            for (int k = 0; k < KB; ++k) {\n                int n0 = n * TILE_N;\n                pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);\n            }\n        }\n    });\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni {};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {\n\n        constexpr int COLS = BLOCK_N / 16;\n        const int TILE_SIZE = TILE_N * sizeof(block_q4_0);\n\n        const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);\n        const char * RESTRICT B = static_cast<const char *>(_B);\n\n        __m512i va[8];\n        __m512 vc[COLS];\n        __m512 vd1;\n\n        // sum of offsets, shared across COLS\n        //\n        // avx512-vnni does not have `_mm512_dpbssd_epi32`,\n        // need to transform ss to us:\n        //   a * (b - 8) is equivalent to b * a - 8 * a\n        //   s    u   u                   u   s   u   s\n        //\n        __m512i vcomp;\n\n        const __m512i off = _mm512_set1_epi8(8);\n        const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n        auto loadc = [&](auto col) {\n            vc[col] = _mm512_setzero_ps();\n        };\n        Unroll<COLS>{}(loadc);\n\n        auto compute = [&](auto col, auto i) {\n            // load a and compute compensation\n            if constexpr (col == 0) {\n                const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);\n                
vcomp = _mm512_setzero_si512();\n                for (int k = 0; k < 8; ++k) {\n                    va[k] = _mm512_set1_epi32(a_ptr[k]);\n                    vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);\n                }\n                vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d));\n            }\n\n            // load b\n            __m512i vsum = _mm512_setzero_si512();\n            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n            for (int k = 0; k < 8; k += 2) {\n                __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));\n                __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n                vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);\n                __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n                vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]);\n            }\n            const int offset = TILE_N * TILE_K / 2;\n            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));\n            vsum = _mm512_sub_epi32(vsum, vcomp);\n\n            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n        };\n\n        for (int i = 0; i < KB; ++i) {\n            Unroll<COLS>{}(compute, i);\n        }\n\n        //store to C\n        auto storec = [&](auto col) {\n            _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);\n        };\n        Unroll<COLS>{}(storec);\n    }\n};\n\ntemplate <int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> {\n    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {\n\n        constexpr int COLS = BLOCK_N / 16;\n        const int TILE_SIZE = TILE_N * sizeof(block_q4_1);\n\n        const block_q8_1 * RESTRICT A = static_cast<const block_q8_1 *>(_A);\n        const char * RESTRICT B = static_cast<const char 
*>(_B);\n\n        __m512i va[8];\n        __m512i vb[8];\n        __m512 vc[COLS];\n        __m512 vd1, vs1;\n\n        const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n        auto loadc = [&](auto col) {\n            vc[col] = _mm512_setzero_ps();\n        };\n        Unroll<COLS>{}(loadc);\n\n        auto compute = [&](auto col, auto i) {\n            // load a\n            if constexpr (col == 0) {\n                const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);\n                for (int k = 0; k < 8; ++k) {\n                    va[k] = _mm512_set1_epi32(a_ptr[k]);\n                }\n                vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d));\n                vs1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].s));\n            }\n\n            // load b\n            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n            for (int k = 0; k < 8; k += 2) {\n                __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));\n                vb[k + 0] = _mm512_and_si512(bytes, lowMask);\n                vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n            }\n            const int offset = TILE_N * TILE_K / 2;\n            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));\n            const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half))));\n\n            __m512i vsum = _mm512_setzero_si512();\n            for (int k = 0; k < 8; ++k) {\n                vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]);\n            }\n\n            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n            vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]);\n        };\n\n        for (int i = 0; i < KB; ++i) {\n            Unroll<COLS>{}(compute, i);\n        }\n\n        //store to C\n        auto storec = [&](auto col) {\n      
      _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);\n        };\n        Unroll<COLS>{}(storec);\n    }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {\n\n        constexpr int COLS = BLOCK_N / 16;\n        const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t);\n\n        const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);\n        const char * RESTRICT B = static_cast<const char *>(_B);\n\n        __m512i va[8];\n        __m512i vb[8];\n        __m512 vc[COLS];\n        __m512 vd1;\n\n        // Notes: s8s8 igemm compensation in avx512-vnni\n        // change s8s8 to u8s8 with compensate\n        //   a * b = (a + 128) * b - 128 * b\n        //   s   s       u       s    u    s\n        //\n        // (128 * b is pre-computed when packing B to vnni formats)\n        //\n        const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n\n        auto loadc = [&](auto col) {\n            vc[col] = _mm512_setzero_ps();\n        };\n        Unroll<COLS>{}(loadc);\n\n        auto compute = [&](auto col, auto i) {\n            // load a and add offset 128\n            if constexpr (col == 0) {\n                const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);\n                for (int k = 0; k < 8; ++k) {\n                    va[k] = _mm512_set1_epi32(a_ptr[k]);\n                    va[k] = _mm512_add_epi8(va[k], off);\n                }\n                vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d));\n            }\n\n            // load b\n            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n            for (int k = 0; k < 8; ++k) {\n                vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64));\n            }\n           
 const int offset = TILE_N * TILE_K;\n            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));\n            const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);\n            const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2));\n\n            __m512i vsum = _mm512_setzero_si512();\n            for (int k = 0; k < 8; ++k) {\n                vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]);\n            }\n            vsum = _mm512_sub_epi32(vsum, vcomp);\n\n            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n        };\n\n        for (int i = 0; i < KB; ++i) {\n            Unroll<COLS>{}(compute, i);\n        }\n\n        //store to C\n        auto storec = [&](auto col) {\n            _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);\n        };\n        Unroll<COLS>{}(storec);\n    }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {\n\n        constexpr int COLS = BLOCK_N / 16;\n        const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4;\n\n        const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);\n        const char * RESTRICT B = static_cast<const char *>(_B);\n\n        // a.qs:   8 groups, 32 bytes each group (m256i)\n        __m512i va[8];\n        // a.bsum: 8 groups,  2 bytes each group (m128i)\n        __m512i va_bsum;\n        __m512 vc[COLS];\n        __m512 vd1;\n\n        // packed_B:\n        const int offset_scales = (QK_K / 2) * TILE_N;\n        const int offset_mins   = (QK_K / 2) * TILE_N +  8 * TILE_N;\n        const int offset_d0     = (QK_K / 2) * TILE_N + 16 * TILE_N;\n        const int offset_dmin   = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);\n\n 
       const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n        auto loadc = [&](auto col) {\n            vc[col] = _mm512_setzero_ps();\n        };\n        Unroll<COLS>{}(loadc);\n\n        // Notes: vnni formats in QK_K\n        //   a) quants vnni format\n        //     int8  {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32\n        //     from {16, 32} to {8, 64}\n        //\n        //   b) min vnni format\n        //     int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8\n        //     from {16,  8} to {4, 32}\n        //\n        auto compute = [&](auto col, auto i) {\n            // load a\n            if constexpr (col == 0) {\n                for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n                    va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));\n                }\n                const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);\n                const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n                va_bsum = _mm512_castsi128_si512(q8s);\n                vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n            }\n\n            // step 1: accumultate the quants\n            __m512i acc = _mm512_setzero_si512();\n            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n            const char * b_qs  = b_ptr;\n            for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n                __m512i vsum = _mm512_setzero_si512();\n                for (int k = 0; k < 8; k += 2) {\n                    __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);\n                    __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);\n\n                    __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);\n                    __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n                    vsum = 
_mm512_dpbusd_epi32(vsum, vb0, va0);\n                    __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n                    vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n\n                    b_qs += 64;\n                }\n                // vacc += scale * (q8 @ q4)\n                const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));\n                acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n            }\n            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));\n            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n\n            // step 2: accumulate the mins\n            __m512i acc_m = _mm512_setzero_si512();\n            for (int k = 0; k < 4; ++k) {\n                __m512i vmask = _mm512_set1_epi32(k);\n                __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);\n                __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));\n                acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n            }\n            const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));\n            vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);\n        };\n\n        for (int i = 0; i < KB; ++i) {\n            Unroll<COLS>{}(compute, i);\n        }\n\n        //store to C\n        auto storec = [&](auto col) {\n            _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);\n        };\n        Unroll<COLS>{}(storec);\n    }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {\n\n        constexpr int 
COLS = BLOCK_N / 16;\n        const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4;\n\n        const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);\n        const char * RESTRICT B = static_cast<const char *>(_B);\n\n        // a.qs:   8 groups, 32 bytes each group (m256i)\n        __m512i va[8];\n        // a.bsum: 8 groups,  2 bytes each group (m128i)\n        __m512i va_bsum;\n        __m512 vc[COLS];\n        __m512 vd1;\n\n        // packed_B:\n        const int offset_qh     = (QK_K / 2) * TILE_N;\n        const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;\n        const int offset_mins   = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N +  8 * TILE_N;\n        const int offset_d0     = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;\n        const int offset_dmin   = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);\n\n        const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n        auto loadc = [&](auto col) {\n            vc[col] = _mm512_setzero_ps();\n        };\n        Unroll<COLS>{}(loadc);\n\n        // Q5_K and Q4_K shares the same vnni formats, refer to notes above.\n        auto compute = [&](auto col, auto i) {\n            // load a\n            if constexpr (col == 0) {\n                for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n                    va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));\n                }\n                const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);\n                const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n                va_bsum = _mm512_castsi128_si512(q8s);\n                vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n            }\n\n            // step 1: accumultate the quants\n            __m512i acc = _mm512_setzero_si512();\n            const char * 
b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n            const char * b_qs  = b_ptr;\n            const char * b_qh  = b_ptr + offset_qh;\n            for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n                __m512i vsum = _mm512_setzero_si512();\n                __m512i hmask0 = _mm512_set1_epi8(0x1);\n                __m512i hmask1 = _mm512_set1_epi8(0x2);\n                __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64));\n                for (int k = 0; k < 8; k += 2) {\n                    __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);\n                    __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);\n\n                    __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);\n                    __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n                    __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n\n                    __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4);\n                    __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4);\n\n                    hmask0 = _mm512_slli_epi16(hmask0, 2);\n                    hmask1 = _mm512_slli_epi16(hmask1, 2);\n                    vb0 = _mm512_add_epi8(vb0, vh0);\n                    vb1 = _mm512_add_epi8(vb1, vh1);\n\n                    vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n                    vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n\n                    b_qs += 64;\n                }\n                // vacc += scale * (q8 @ q5)\n                const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));\n                acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n            }\n            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + 
offset_d0)));\n            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n\n            // step 2: accumulate the mins\n            __m512i acc_m = _mm512_setzero_si512();\n            for (int k = 0; k < 4; ++k) {\n                __m512i vmask = _mm512_set1_epi32(k);\n                __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);\n                __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));\n                acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n            }\n            const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));\n            vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);\n        };\n\n        for (int i = 0; i < KB; ++i) {\n            Unroll<COLS>{}(compute, i);\n        }\n\n        //store to C\n        auto storec = [&](auto col) {\n            _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);\n        };\n        Unroll<COLS>{}(storec);\n    }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {\n\n        constexpr int COLS = BLOCK_N / 16;\n        const int TILE_SIZE = TILE_N * sizeof(block_q6_K);\n\n        const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);\n        const char * RESTRICT B = static_cast<const char *>(_B);\n\n        // load the 256 bytes from A to 4 avx512 vectors\n        __m512i va[4];\n        __m512 vc[COLS];\n        __m512 vd1;\n\n        // packed_B:\n        const int offset_qh     = (QK_K / 2) * TILE_N;\n        const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;\n        const int offset_d0     = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N;\n\n        // 
compensation\n        __m512i vcomp;\n\n        const __m512i m32s = _mm512_set1_epi32(32);\n        const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n        auto loadc = [&](auto col) {\n            vc[col] = _mm512_setzero_ps();\n        };\n        Unroll<COLS>{}(loadc);\n\n        auto compute = [&](auto col, auto i) {\n            if constexpr (col == 0) {\n                // load a\n                va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs +   0));\n                va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs +  64));\n                va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));\n                va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));\n\n                const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);\n                vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s);\n                vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n            }\n\n            // accmulate the quants\n            __m512i acc = _mm512_setzero_si512();\n            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n            const char * b_qs = b_ptr;\n            const char * b_qh = b_ptr + offset_qh;\n            int mask = 0;\n            for (int k_group = 0; k_group < QK_K / 16; ++k_group) {\n                int r = k_group >> 2;\n                __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n                __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n                __m512i vsum = _mm512_setzero_si512();\n                __m512i hmask = _mm512_set1_epi8(0x3);\n\n                __m512i bytes = _mm512_loadu_si512(b_qs);\n                __m512i hbits = _mm512_loadu_si512(b_qh);\n                __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n                __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n                __m512i vh0 = 
_mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4);\n                __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2);\n\n                vb0 = _mm512_add_epi8(vb0, vh0);\n                vb1 = _mm512_add_epi8(vb1, vh1);\n                vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n                vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n                b_qs += 64;\n\n                va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n                va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n                bytes = _mm512_loadu_si512(b_qs);\n                vb0 = _mm512_and_si512(bytes, lowMask);\n                vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n                vh0 =                   _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4));\n                vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2);\n                vb0 = _mm512_add_epi8(vb0, vh0);\n                vb1 = _mm512_add_epi8(vb1, vh1);\n                vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n                vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n                b_qs += 64;\n                b_qh += 64;\n\n                // B * A - 32 * A\n                __m512i vmask = _mm512_set1_epi32(k_group);\n                vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));\n\n                // vacc += scale * (q8 @ q6)\n                const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));\n                acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n            }\n            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));\n            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n        };\n\n        for (int i = 0; i < KB; ++i) {\n            
Unroll<COLS>{}(compute, i);\n        }\n\n        //store to C\n        auto storec = [&](int col) {\n            _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);\n        };\n        Unroll<COLS>{}(storec);\n    }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {\n\n        constexpr int COLS = BLOCK_N / 16;\n        const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2;\n\n        const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);\n        const char * RESTRICT B = static_cast<const char *>(_B);\n\n        // load the 256 bytes from A to 4 avx512 vectors\n        __m512i va[4];\n        __m512 vc[COLS];\n        __m512 vd1;\n\n        // packed_B:\n        const int offset_scales = (QK_K / 2) * TILE_N ;\n        const int offset_d0     = (QK_K / 2) * TILE_N + 8 * TILE_N;\n\n        // compensation\n        __m512i vcomp;\n\n        const __m256i m128s = _mm256_set1_epi16(128);\n        const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n        const __m512i values128 = _mm512_set_epi8(\n            113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n            113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n            113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n            113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127\n        );\n        const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n        const __m512i values256 = _mm512_add_epi8(values128, off);\n\n        auto loadc = [&](auto col) {\n            vc[col] = _mm512_setzero_ps();\n        };\n        Unroll<COLS>{}(loadc);\n\n        auto compute = [&](auto col, auto i) {\n            if constexpr (col == 0) {\n                // load 
a\n                va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs +   0));\n                va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs +  64));\n                va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));\n                va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));\n\n                // compensation: 128 * A\n                const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);\n                vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s));\n                vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n            }\n\n            // accmulate the quants\n            __m512i acc = _mm512_setzero_si512();\n            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n            const char * b_qs = b_ptr;\n            int mask = 0;\n            for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n                int r = k_group >> 1;\n                __m512i vmask = _mm512_set1_epi32(k_group);\n                __m512i vsum = _mm512_setzero_si512();\n                for (int k = 0; k < 8; k += 2) {\n                    __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n                    __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n                    __m512i bytes = _mm512_loadu_si512(b_qs);\n                    __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask));\n                    __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));\n\n                    vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n                    vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n                    b_qs += 64;\n                }\n                // (B + 128) * A - 128 * A\n                vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));\n\n                // vacc += scale * (q8 @ 
q4)\n                const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));\n                acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n            }\n            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));\n            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n        };\n\n        for (int i = 0; i < KB; ++i) {\n            Unroll<COLS>{}(compute, i);\n        }\n\n        //store to C\n        auto storec = [&](auto col) {\n            _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);\n        };\n        Unroll<COLS>{}(storec);\n    }\n};\n\n#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                                   \\\n    tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply(             \\\n        KB, wdata_batch,                                                                       \\\n        (const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \\\n        (float *) dst->data + dst_offset + nb_start, ldc)\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_K,\n          typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>\nvoid tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) {\n    using packed_B_t = packed_B_type<TB>;\n    const int TILE_SIZE = get_tile_size<TB>();\n    const bool need_unpack = do_unpack<TB>::value;\n\n    GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);\n    const TA * RESTRICT A = static_cast<const TA *>(_A);\n    const char * RESTRICT B = static_cast<const char *>(_B);\n\n    const int m0 = std::min(M, TILE_M);\n    const int m1 = std::max(M - TILE_M, 0);\n    const int lda = KB * sizeof(TA);\n    //const int ldb = KB * sizeof(TB);\n\n    static thread_local packed_B_t 
Tile0[TILE_N * TILE_K];\n    static thread_local packed_B_t Tile1[TILE_N * TILE_K];\n    static thread_local int8_t Tile23[TILE_M * TILE_K];\n\n    static thread_local int32_t TileC0[TILE_M * TILE_N * 4];\n    static thread_local int32_t TileC1[TILE_M * TILE_N * 4];\n\n    // double buffering C to interleave avx512 and amx\n    int32_t * C_cur = TileC0;\n    int32_t * C_pre = TileC1;\n\n    auto Tile4 = [&](int32_t * base) { return base; };\n    auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; };\n    auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; };\n    auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; };\n\n    if (M == 2 * TILE_M) {\n        // i = 0\n        const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);\n        const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);\n        if (need_unpack) {\n            unpack_B<TB>(Tile0, B_blk0);\n            _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n        } else {\n            _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n        }\n\n        _tile_zero(TMM4);\n        _tile_loadd(TMM2, A[0].qs, lda);\n        _tile_dpbssd(TMM4, TMM2, TMM0);\n        _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t));\n\n        _tile_zero(TMM5);\n        _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda);\n        _tile_dpbssd(TMM5, TMM3, TMM0);\n        _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));\n\n        if (need_unpack) {\n            unpack_B<TB>(Tile1, B_blk1);\n            _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n        } else {\n            _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n        }\n\n        _tile_zero(TMM6);\n        _tile_dpbssd(TMM6, TMM2, TMM1);\n        _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t));\n\n        _tile_zero(TMM7);\n        _tile_dpbssd(TMM7, TMM3, TMM1);\n        _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t));\n\n        for (int i = 1; i < KB; ++i) {\n      
      // index of previous iter\n            const int ii = i - 1;\n            const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);\n            const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);\n            GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {\n                if (need_unpack) {\n                    unpack_B<TB>(Tile0, B_blk0);\n                    _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n                } else {\n                    _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n                }\n                _tile_zero(TMM4);\n                _tile_loadd(TMM2, A[i].qs, lda);\n                acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n\n                _tile_dpbssd(TMM4, TMM2, TMM0);\n                _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));\n\n                _tile_zero(TMM5);\n                _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda);\n                acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n\n                _tile_dpbssd(TMM5, TMM3, TMM0);\n                _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));\n\n                if (need_unpack) {\n                    unpack_B<TB>(Tile1, B_blk1);\n                    _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n                } else {\n                    _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n                }\n                _tile_zero(TMM6);\n                acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);\n\n                _tile_dpbssd(TMM6, TMM2, TMM1);\n                _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));\n\n                _tile_zero(TMM7);\n                acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + 
PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);\n\n                _tile_dpbssd(TMM7, TMM3, TMM1);\n                _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));\n\n                std::swap(C_cur, C_pre);\n            });\n        }\n        // final accumulation\n        {\n            int ii = KB - 1;\n            acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n            acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n            acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);\n            acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);\n        }\n    } else {\n        for (int i = 0; i < KB; ++i) {\n            _tile_zero(TMM4);\n            _tile_zero(TMM6);\n            if (m1 != 0) {\n                _tile_zero(TMM5);\n                _tile_zero(TMM7);\n            }\n\n            const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);\n            const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);\n            if (need_unpack) {\n                unpack_B<TB>(Tile0, B_blk0);\n                _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n            } else {\n                _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n            }\n\n            if (need_unpack) {\n                unpack_B<TB>(Tile1, B_blk1);\n                _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n            } else {\n                _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n            }\n\n            if (m0 == TILE_M) {\n                _tile_loadd(TMM2, A[i].qs, lda);\n            } else {\n                unpack_A(Tile23, &A[i], KB, m0);\n                _tile_loadd(TMM2, Tile23, TILE_K);\n            }\n\n            
_tile_dpbssd(TMM4, TMM2, TMM0);\n            _tile_dpbssd(TMM6, TMM2, TMM1);\n\n            _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));\n            _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));\n\n            GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n                acc_C<TA, TB, is_acc>::apply(C,          ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);\n                acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);\n            });\n\n            if (m1 != 0) {\n                unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1);\n                _tile_loadd(TMM3, Tile23, TILE_K);\n\n                _tile_dpbssd(TMM5, TMM3, TMM0);\n                _tile_dpbssd(TMM7, TMM3, TMM1);\n                _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));\n                _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));\n                GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n                    acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc,          ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);\n                    acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);\n                });\n            }\n        }\n    }\n    return;\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_K,\n          typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>\nvoid tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {\n    static_assert(std::is_same<TA, block_q8_K>::value);\n    const int TILE_SIZE = get_tile_size<TB>();\n\n    GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);\n    const TA * RESTRICT A = static_cast<const TA *>(_A);\n    const char * RESTRICT B = static_cast<const char *>(_B);\n\n    const int m0 = 
std::min(M, TILE_M);\n    const int m1 = std::max(M - TILE_M, 0);\n    //const int lda = KB * sizeof(TA);\n\n    static thread_local int8_t Tile0[TILE_N * TILE_K];\n    static thread_local int8_t Tile1[TILE_N * TILE_K];\n    static thread_local int8_t Tile23[TILE_M * TILE_K];\n\n    // mat mul result for each group\n    static thread_local int32_t Tile4[TILE_M * TILE_N];\n    static thread_local int32_t Tile5[TILE_M * TILE_N];\n    static thread_local int32_t Tile6[TILE_M * TILE_N];\n    static thread_local int32_t Tile7[TILE_M * TILE_N];\n\n    // sum of each QK_K block, contains 8 groups, int32\n    static thread_local int32_t Sumi4[TILE_M * TILE_N];\n    static thread_local int32_t Sumi5[TILE_M * TILE_N];\n    static thread_local int32_t Sumi6[TILE_M * TILE_N];\n    static thread_local int32_t Sumi7[TILE_M * TILE_N];\n\n    const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;\n    for (int i = 0; i < KB; ++i) {\n        // step 1: accumulate the quants across 8 groups, each group with 32\n        for (int k = 0; k < QK_K / k_group_size; ++k) {\n            GGML_DISPATCH_BOOL(k > 0, is_acc, [&] {\n                _tile_zero(TMM4);\n                _tile_zero(TMM6);\n\n                unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k);\n                _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n\n                unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k);\n                _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n\n                unpack_A<TB>(Tile23, &A[i], KB, k, m0);\n                _tile_loadd(TMM2, Tile23, TILE_K);\n\n                _tile_dpbssd(TMM4, TMM2, TMM0);\n                _tile_dpbssd(TMM6, TMM2, TMM1);\n\n                _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));\n                _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));\n\n                scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0);\n                scale_C<TB, is_acc>(Tile6, Sumi6, B 
+ PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0);\n\n                if (m1 != 0) {\n                    _tile_zero(TMM5);\n                    _tile_zero(TMM7);\n\n                    unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1);\n                    _tile_loadd(TMM3, Tile23, TILE_K);\n\n                    _tile_dpbssd(TMM5, TMM3, TMM0);\n                    _tile_dpbssd(TMM7, TMM3, TMM1);\n\n                    _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));\n                    _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));\n\n                    scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1);\n                    scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1);\n                }\n            });\n        }\n\n        // step 2: accmulate the mins\n        GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n            acc_C<TA, TB, is_acc>::apply(C,          ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);\n            acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);\n            if (m1 != 0) {\n                acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc,          ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);\n                acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);\n            }\n        });\n    }\n    return;\n}\n\n} // anonymous namespace\n\n// get the packed tensor size for quantized weights\nsize_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) {\n    const enum ggml_type TYPE = tensor->type;\n\n    const int K = tensor->ne[0]; // ne0: in_features\n    const int N = tensor->ne[1]; // ne1: out_features\n\n    auto get_tensor_size = [&] {\n        size_t row_size_B{0};\n        GGML_DISPATCH_QTYPES(TYPE, [&] {\n            row_size_B = get_row_size<type, 
blck_size>(K);\n        });\n        return N * row_size_B;\n    };\n\n    if (qtype_has_amx_kernels(TYPE)) {\n        return get_tensor_size();\n    } else {\n        // for f16, bf16 we don't do packing\n        return ggml_nbytes(tensor);\n    }\n}\n\n// pack weight to vnni format\nvoid ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    GGML_ASSERT(offset == 0 && size == ggml_nbytes(tensor)); // only full tensor conversion is supported for now\n\n    const enum ggml_type TYPE = tensor->type;\n\n    const int K = tensor->ne[0]; // ne0: in_features\n    const int N = tensor->ne[1]; // ne1: out_features\n\n    GGML_DISPATCH_QTYPES(TYPE, [&] {\n        convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K);\n    });\n}\n\n// ne2 is passed explicitly to help compiler optimize repeated calls\ninline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {\n    const int64_t i2 = batch_idx % ne2;\n    const int64_t i3 = batch_idx / ne2;\n    return i3 * t->nb[3] + i2 * t->nb[2];\n}\n\nsize_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {\n    struct ggml_tensor * src0 = dst->src[0];\n\n    const enum ggml_type TYPE = src0->type;\n\n    const bool is_floating_type = TYPE == GGML_TYPE_F16;\n    if (is_floating_type) {\n        return 0;\n    }\n\n    const int M = dst->ne[1];\n    const int K = src0->ne[0];\n    const int64_t n_batch = dst->ne[2] * dst->ne[3];\n\n    size_t desired_wsize = 0;\n\n    GGML_DISPATCH_QTYPES(TYPE, [&] {\n        const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);\n        desired_wsize = n_batch * M * row_size_A;\n    });\n\n    return desired_wsize;\n}\n\n// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)\n//\n// src0: weight in shape of {N, K}, quantized\n// src1: input  in shape of {M, K}, float32\n// dst:  output in shape of {M, N}, float32\n//\n// the 
function performs: dst = src1 @ src0.T for each batch\n//\nvoid ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {\n    struct ggml_tensor * src0 = dst->src[0];\n    struct ggml_tensor * src1 = dst->src[1];\n\n    const enum ggml_type TYPE = src0->type;\n\n    // f16 only has avx512 kernels for now,\n    // amx kernels will be added once 6th gen xeon is released.\n    const bool is_floating_type = TYPE == GGML_TYPE_F16;\n\n    const int M = dst->ne[1];\n    const int N = dst->ne[0];\n    const int K = src0->ne[0];\n    const int ldc = dst->nb[1] / dst->nb[0];\n\n    const int64_t ne2 = dst->ne[2];\n    const int64_t n_batch = ne2 * dst->ne[3];\n\n    if (is_floating_type) {\n        constexpr int BLOCK_M = 4;\n        constexpr int BLOCK_N = 6;\n        const int MB = div_up(M, BLOCK_M);\n        const int NB = div_up(N, BLOCK_N);\n\n        parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {\n            GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {\n                for (int i = begin; i < end; ++i) {\n                    int batch_idx = i / (MB * NB);\n                    int remaining = i % (MB * NB);\n                    int mb = remaining / NB;\n                    int nb = remaining % NB;\n\n                    int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);\n                    int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);\n                    int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);\n\n                    int mb_start = mb * BLOCK_M;\n                    int mb_size = std::min(BLOCK_M, M - mb_start);\n                    int nb_start = nb * BLOCK_N;\n                    int nb_size = std::min(BLOCK_N, N - nb_start);\n\n                    switch (mb_size << 4 | nb_size) {\n                        case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break;\n                        case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break;\n                        case 
0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break;\n                        case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break;\n                        case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break;\n                        case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break;\n                        case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break;\n                        case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break;\n                        case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break;\n                        case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break;\n                        case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break;\n                        case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break;\n                        default: fprintf(stderr, \"Unexpected block size!\\n\");\n                    }\n                }\n            });\n        });\n        return;\n    }\n\n    // pointer to work space, used convert A from float to quantized type\n    void * wdata = params->wdata;\n\n    //TODO: performance improvement: merge quant A\n // if (params->ith == 0) {\n        GGML_DISPATCH_QTYPES(TYPE, [&] {\n            const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);\n            const size_t desired_wsize = n_batch * M * row_size_A;\n            if (params->wsize < desired_wsize) {\n                GGML_ABORT(\"insufficient work space size\");\n            }\n\n            // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size\n            // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size\n            GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);\n\n            parallel_for_ggml(params, n_batch, [&](int begin, int end) {\n                for (int batch_idx = begin; batch_idx < end; ++batch_idx) {\n                    int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);\n                    const float * A_data = (const float *)((const char *)src1->data + src1_offset);\n                    
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;\n\n                    for (int m = 0; m < M; ++m) {\n                        from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);\n                    }\n                }\n            });\n        });\n // }\n\n    ggml_barrier(params->threadpool);\n\n    if (M == 1) {\n        // MB = 1 and handle 8 tiles in each block\n        constexpr int kTilesN = 4;\n        constexpr int BLOCK_N = TILE_N * kTilesN;\n        const int NB = div_up(N, BLOCK_N);\n\n        parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {\n            GGML_DISPATCH_QTYPES(TYPE, [&] {\n                const int KB = K / blck_size;\n                const int TILE_SIZE = get_tile_size<type>();\n                const int row_size_A = KB * sizeof(vec_dot_type);\n                for (int i = begin; i < end; ++i) {\n                    int batch_idx = i / NB;\n                    int nb = i % NB;\n\n                    int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);\n                    int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);\n                    const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;\n\n                    int nb_start = nb * BLOCK_N;\n                    int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96\n\n                    switch (nb_size) {\n                        //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break;\n                        case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break;\n                        case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break;\n                        case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break;\n                        case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break;\n                        default: fprintf(stderr, \"Unexpected n block size!\\n\");\n                    }\n                }\n            });\n        });\n        return;\n    }\n\n    // handle 4 tiles 
at a tile\n    constexpr int BLOCK_M = TILE_M * 2;\n    constexpr int BLOCK_N = TILE_N * 2;\n    const int MB = div_up(M, BLOCK_M);\n    const int NB = div_up(N, BLOCK_N);\n\n    parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {\n        // init tile config for each thread\n        ggml_tile_config_init();\n\n        GGML_DISPATCH_QTYPES(TYPE, [&] {\n            const int KB = K / blck_size;\n            const int TILE_SIZE = get_tile_size<type>();\n            const int row_size_A = KB * sizeof(vec_dot_type);\n\n            for (int i = begin; i < end; ++i) {\n                int batch_idx = i / (MB * NB);\n                int remaining = i % (MB * NB);\n                int mb = remaining / NB;\n                int nb = remaining % NB;\n\n                int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);\n                int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);\n                const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;\n\n                int mb_start = mb * BLOCK_M;\n                int mb_size = std::min(BLOCK_M, M - mb_start);\n                int nb_start = nb * BLOCK_N;\n                int nb_size = BLOCK_N;\n\n                tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(\n                    mb_size, nb_size, KB,\n                    wdata_batch + mb_start * row_size_A,\n                    (const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),\n                    (float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);\n            }\n        });\n    });\n}\n\n#endif // if defined(__AMX_INT8__) && defined(__AVX512VNNI__)\n"
  },
  {
    "path": "src/ggml-cpu/amx/mmq.h",
    "content": "#pragma once\n#include \"common.h\"\n\nsize_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst);\n\nsize_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor);\n\nvoid ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);\n\nvoid ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cpu/arch/arm/cpu-feats.cpp",
    "content": "#include \"ggml-backend-impl.h\"\n\n#if defined(__aarch64__)\n\n#if defined(__linux__)\n#include <sys/auxv.h>\n#elif defined(__APPLE__)\n#include <sys/sysctl.h>\n#endif\n\n#if !defined(HWCAP2_SVE2)\n#define HWCAP2_SVE2 (1 << 1)\n#endif\n\n#if !defined(HWCAP2_I8MM)\n#define HWCAP2_I8MM (1 << 13)\n#endif\n\n#if !defined(HWCAP2_SME)\n#define HWCAP2_SME (1 << 23)\n#endif\n\nstruct aarch64_features {\n    // has_neon not needed, aarch64 has NEON guaranteed\n    bool has_dotprod     = false;\n    bool has_fp16_va     = false;\n    bool has_sve         = false;\n    bool has_sve2        = false;\n    bool has_i8mm        = false;\n    bool has_sme         = false;\n\n    aarch64_features() {\n#if defined(__linux__)\n        uint32_t hwcap = getauxval(AT_HWCAP);\n        uint32_t hwcap2 = getauxval(AT_HWCAP2);\n\n        has_dotprod = !!(hwcap & HWCAP_ASIMDDP);\n        has_fp16_va = !!(hwcap & HWCAP_FPHP);\n        has_sve     = !!(hwcap & HWCAP_SVE);\n        has_sve2    = !!(hwcap2 & HWCAP2_SVE2);\n        has_i8mm    = !!(hwcap2 & HWCAP2_I8MM);\n        has_sme     = !!(hwcap2 & HWCAP2_SME);\n#elif defined(__APPLE__)\n        int oldp = 0;\n        size_t size = sizeof(oldp);\n\n        if (sysctlbyname(\"hw.optional.arm.FEAT_DotProd\", &oldp, &size, NULL, 0) == 0) {\n            has_dotprod = static_cast<bool>(oldp);\n        }\n\n        if (sysctlbyname(\"hw.optional.arm.FEAT_I8MM\", &oldp, &size, NULL, 0) == 0) {\n            has_i8mm = static_cast<bool>(oldp);\n        }\n\n        if (sysctlbyname(\"hw.optional.arm.FEAT_SME\", &oldp, &size, NULL, 0) == 0) {\n            has_sme = static_cast<bool>(oldp);\n        }\n\n        // Apple apparently does not implement SVE yet\n#endif\n    }\n};\n\nstatic int ggml_backend_cpu_aarch64_score() {\n    int score = 1;\n    aarch64_features af;\n\n#ifdef GGML_USE_DOTPROD\n    if (!af.has_dotprod) { return 0; }\n    score += 1<<1;\n#endif\n#ifdef GGML_USE_FP16_VECTOR_ARITHMETIC\n    if (!af.has_fp16_va) { 
return 0; }\n    score += 1<<2;\n#endif\n#ifdef GGML_USE_SVE\n    if (!af.has_sve) { return 0; }\n    score += 1<<3;\n#endif\n#ifdef GGML_USE_MATMUL_INT8\n    if (!af.has_i8mm) { return 0; }\n    score += 1<<4;\n#endif\n#ifdef GGML_USE_SVE2\n    if (!af.has_sve2) { return 0; }\n    score += 1<<5;\n#endif\n#ifdef GGML_USE_SME\n    if (!af.has_sme) { return 0; }\n    score += 1<<6;\n#endif\n\n    return score;\n}\n\nGGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_aarch64_score)\n\n# endif // defined(__aarch64__)\n"
  },
  {
    "path": "src/ggml-cpu/arch/arm/quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n#include \"ggml-quants.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"simd-mappings.h\"\n\n#include \"../../quants.h\"\n#include \"../../ggml-cpu-impl.h\"\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <float.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\n#if defined(__ARM_NEON)\n#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s\n#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)\n#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)\n#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)\n#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)\n#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)\n#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)\n#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)\n\n// precomputed tables for expanding 8bits to 8 bytes:\nstatic const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4\nstatic const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4\n#endif\n\nvoid quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__ARM_NEON)\n    for (int i = 0; i < nb; i++) {\n        float32x4_t srcv [8];\n        float32x4_t asrcv[8];\n        float32x4_t amaxv[8];\n\n        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);\n        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);\n\n        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);\n        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);\n        for (int j = 0; j < 1; j++) amaxv[8*j] = 
vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);\n\n        const float amax = vmaxvq_f32(amaxv[0]);\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 1.0f/d : 0.0f;\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        for (int j = 0; j < 8; j++) {\n            const float32x4_t v  = vmulq_n_f32(srcv[j], id);\n            const int32x4_t   vi = vcvtnq_s32_f32(v);\n\n            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);\n            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);\n            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);\n            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);\n        }\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_0_ref(x, y, k);\n#endif\n}\n\nvoid quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK8_1 == 0);\n    const int nb = k / QK8_1;\n\n    block_q8_1 * GGML_RESTRICT y = vy;\n#if defined(__ARM_NEON)\n    for (int i = 0; i < nb; i++) {\n        float32x4_t srcv [8];\n        float32x4_t asrcv[8];\n        float32x4_t amaxv[8];\n\n        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);\n        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);\n\n        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);\n        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);\n        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);\n\n        const float amax = vmaxvq_f32(amaxv[0]);\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 
1.0f/d : 0.0f;\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        int32x4_t accv = vdupq_n_s32(0);\n\n        for (int j = 0; j < 8; j++) {\n            const float32x4_t v  = vmulq_n_f32(srcv[j], id);\n            const int32x4_t   vi = vcvtnq_s32_f32(v);\n\n            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);\n            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);\n            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);\n            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);\n\n            accv = vaddq_s32(accv, vi);\n        }\n\n        y[i].s = GGML_CPU_FP32_TO_FP16(d * vaddvq_s32(accv));\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_1_ref(x, y, k);\n#endif\n}\n\n// placeholder implementation for Apple targets\nvoid quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q8_K_ref(x, y, k);\n}\n\n//===================================== Dot products =================================\n\nvoid ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n#if defined(__ARM_FEATURE_MATMUL_INT8)\n    assert((nrc == 2) || (nrc == 1));\n#else\n    assert(nrc == 1);\n#endif\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__ARM_FEATURE_MATMUL_INT8)\n    if (nrc == 2) {\n        const block_q4_0 * GGML_RESTRICT vx0 = vx;\n        const block_q4_0 * GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);\n        const block_q8_0 * GGML_RESTRICT vy0 = vy;\n        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);\n\n        float32x4_t sumv0 = vdupq_n_f32(0.0f);\n\n        for (int i = 0; i < nb; i++) {\n            const block_q4_0 * 
GGML_RESTRICT b_x0 = &vx0[i];\n            const block_q4_0 * GGML_RESTRICT b_x1 = &vx1[i];\n            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];\n            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];\n\n            const uint8x16_t m4b = vdupq_n_u8(0x0F);\n            const int8x16_t  s8b = vdupq_n_s8(0x8);\n\n            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);\n            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);\n\n            // 4-bit -> 8-bit\n            const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));\n            const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));\n            const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));\n            const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));\n\n            // sub 8\n            const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);\n            const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);\n            const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);\n            const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);\n\n            // load y\n            const int8x16_t y0_l = vld1q_s8(b_y0->qs);\n            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);\n            const int8x16_t y1_l = vld1q_s8(b_y1->qs);\n            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);\n\n            float32_t _scale[4] = {\n                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),\n                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),\n                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),\n                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)\n            };\n            float32x4_t scale = vld1q_f32(_scale);\n\n            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));\n            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));\n\n            int8x16_t 
l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));\n            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));\n\n            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));\n            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));\n\n            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));\n            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));\n\n            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),\n                                                l1, r1)), l2, r2)), l3, r3))), scale);\n        }\n\n        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);\n        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);\n\n        vst1_f32(s,      vget_low_f32 (sumv2));\n        vst1_f32(s + bs, vget_high_f32(sumv2));\n\n        return;\n    }\n#endif\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__ARM_FEATURE_SVE)\n    svfloat32_t sumv0 = svdup_n_f32(0.0f);\n    svfloat32_t sumv1 = svdup_n_f32(0.0f);\n\n    const int vector_length = ggml_cpu_get_sve_cnt()*8;\n\n    // VLA Implementation using switch case\n    switch (vector_length) {\n        case 128:\n            {\n                // predicate for activating higher lanes for 4 float32 elements\n                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);\n\n                for (; ib + 1 < nb; ib += 2) {\n                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];\n                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];\n                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n                    // 
load x\n                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);\n                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);\n\n                    // 4-bit -> 8-bit\n                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));\n                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));\n                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));\n                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));\n\n                    // sub 8\n                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);\n                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);\n                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);\n                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);\n\n                    // load y\n                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);\n                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);\n                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);\n                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);\n\n                    // dot product\n                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,\n                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),\n                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,\n                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),\n                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n                }\n\n           
     sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));\n            } break;\n        case 256:\n            {\n                // predicate for activating higher lanes for 16 int8 elements\n                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);\n                // predicate for activating lower lanes for  16 int8 elements\n                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);\n\n                for (; ib + 1 < nb; ib += 2) {\n                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];\n                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];\n                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n                    // load x\n                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);\n                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);\n\n                    // 4-bit -> 8-bit\n                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));\n                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));\n\n                    // sub 8\n                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);\n                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);\n\n                    // load y\n                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);\n                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);\n\n                    // dot product\n                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),\n                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),\n                               
 svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n                }\n\n                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));\n            } break;\n        case 512:\n            {\n                // predicate for activating higher lanes for 32 int8 elements\n                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);\n\n                // predicate for activating higher lanes for 16 int8 elements\n                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);\n                // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes\n                const svbool_t pl16 = svnot_b_z(ph32, ph16);\n\n                for (; ib + 1 < nb; ib += 2) {\n                    const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];\n                    const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];\n                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n                    // load x\n                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);\n                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);\n\n                    // 4-bit -> 8-bit\n                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));\n                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));\n\n                    // sub 8\n                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);\n                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);\n\n                    // load y\n                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);\n                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);\n\n                    // dot product\n                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,\n                     
           svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,\n                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n                }\n\n                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));\n            } break;\n        default:\n            assert(false && \"Unsupported vector length\");\n            break;\n    }\n\n#elif defined(__ARM_NEON)\n    float32x4_t sumv0 = vdupq_n_f32(0.0f);\n    float32x4_t sumv1 = vdupq_n_f32(0.0f);\n\n    for (; ib + 1 < nb; ib += 2) {\n        const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0];\n        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n        const uint8x16_t m4b = vdupq_n_u8(0x0F);\n        const int8x16_t  s8b = vdupq_n_s8(0x8);\n\n        const uint8x16_t v0_0 = vld1q_u8(x0->qs);\n        const uint8x16_t v0_1 = vld1q_u8(x1->qs);\n\n        // 4-bit -> 8-bit\n        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));\n        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));\n        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));\n        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));\n\n        // sub 8\n        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);\n        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);\n        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);\n        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);\n\n        // load y\n        const int8x16_t v1_0l = vld1q_s8(y0->qs);\n        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);\n        const int8x16_t v1_1l = vld1q_s8(y1->qs);\n        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);\n\n        // dot product into int32x4_t\n      
  const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);\n        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);\n\n        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n    }\n\n    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);\n#endif\n    for (; ib < nb; ++ib) {\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int v0 = (x[ib].qs[j] & 0x0F) - 8;\n            const int v1 = (x[ib].qs[j] >>   4) - 8;\n\n            sumi0 += (v0 * y[ib].qs[j]);\n            sumi1 += (v1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n#if defined(__ARM_FEATURE_MATMUL_INT8)\n    assert((nrc == 2) || (nrc == 1));\n#else\n    assert(nrc == 1);\n#endif\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__ARM_FEATURE_MATMUL_INT8)\n    if (nrc == 2) {\n        const block_q4_1 * GGML_RESTRICT vx0 = vx;\n        const block_q4_1 * GGML_RESTRICT vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);\n        const block_q8_1 * GGML_RESTRICT vy0 = vy;\n        const block_q8_1 * GGML_RESTRICT vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);\n\n        float32x4_t sumv0 = vdupq_n_f32(0.0f);\n        float32x4_t summs0 = vdupq_n_f32(0.0f);\n\n        
for (int i = 0; i < nb; i++) {\n            const block_q4_1 * GGML_RESTRICT b_x0 = &vx0[i];\n            const block_q4_1 * GGML_RESTRICT b_x1 = &vx1[i];\n            const block_q8_1 * GGML_RESTRICT b_y0 = &vy0[i];\n            const block_q8_1 * GGML_RESTRICT b_y1 = &vy1[i];\n\n            float32_t summs_t[4] = {\n                GGML_CPU_FP16_TO_FP32(b_x0->m) * GGML_CPU_FP16_TO_FP32(b_y0->s),\n                GGML_CPU_FP16_TO_FP32(b_x1->m) * GGML_CPU_FP16_TO_FP32(b_y0->s),\n                GGML_CPU_FP16_TO_FP32(b_x0->m) * GGML_CPU_FP16_TO_FP32(b_y1->s),\n                GGML_CPU_FP16_TO_FP32(b_x1->m) * GGML_CPU_FP16_TO_FP32(b_y1->s)\n            };\n            summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));\n\n            const uint8x16_t m4b = vdupq_n_u8(0x0F);\n\n            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);\n            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);\n\n            // 4-bit -> 8-bit\n            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));\n            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));\n            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));\n            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));\n\n            // load y\n            const int8x16_t y0_l = vld1q_s8(b_y0->qs);\n            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);\n            const int8x16_t y1_l = vld1q_s8(b_y1->qs);\n            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);\n\n            // mmla into int32x4_t\n            float32_t _scale[4] = {\n                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),\n                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),\n                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),\n                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)\n            };\n            float32x4_t scale = vld1q_f32(_scale);\n\n            int8x16_t l0 = 
vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));\n            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));\n\n            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));\n            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));\n\n            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));\n            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));\n\n            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));\n            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));\n            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),\n                                                l1, r1)), l2, r2)), l3, r3))), scale);\n        }\n\n        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);\n        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);\n\n        sumv2 = vaddq_f32(sumv2, summs0);\n\n        vst1_f32(s,      vget_low_f32 (sumv2));\n        vst1_f32(s + bs, vget_high_f32(sumv2));\n\n        return;\n    }\n#endif\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__ARM_NEON)\n    float32x4_t sumv0 = vdupq_n_f32(0.0f);\n    float32x4_t sumv1 = vdupq_n_f32(0.0f);\n\n    float summs = 0;\n\n    for (; ib + 1 < nb; ib += 2) {\n        const block_q4_1 * GGML_RESTRICT x0 = &x[ib + 0];\n        const block_q4_1 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];\n        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];\n\n        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s) + 
GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);\n\n        const uint8x16_t m4b = vdupq_n_u8(0x0F);\n\n        const uint8x16_t v0_0 = vld1q_u8(x0->qs);\n        const uint8x16_t v0_1 = vld1q_u8(x1->qs);\n\n        // 4-bit -> 8-bit\n        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));\n        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));\n        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));\n        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));\n\n        // load y\n        const int8x16_t v1_0l = vld1q_s8(y0->qs);\n        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);\n        const int8x16_t v1_1l = vld1q_s8(y1->qs);\n        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);\n\n        // dot product into int32x4_t\n        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);\n        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);\n\n        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n    }\n\n    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;\n\n#endif\n    for (; ib < nb; ++ib) {\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int v0 = (x[ib].qs[j] & 0x0F);\n            const int v1 = (x[ib].qs[j] >>   4);\n\n            sumi0 += (v0 * y[ib].qs[j]);\n            sumi1 += (v1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, 
size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_MXFP4 == 0);\n    static_assert(QK_MXFP4 == QK8_0, \"QK_MXFP4 and QK8_0 must be the same\");\n\n    const block_mxfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_MXFP4;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined __ARM_NEON\n    const int8x16_t values = vld1q_s8(kvalues_mxfp4);\n    const uint8x16_t m4b = vdupq_n_u8(0x0f);\n    uint8x16x2_t q4bits;\n    int8x16x4_t q4b;\n    int8x16x4_t q8b;\n    int32x4_t prod_1;\n    int32x4_t prod_2;\n\n    for (; ib + 1 < nb; ib += 2) {\n        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);\n        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);\n        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);\n        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);\n        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);\n        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);\n\n        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));\n        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));\n        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));\n        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));\n\n        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);\n        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);\n\n        sumf +=\n            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +\n            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);\n    }\n\n#endif\n    for (; ib < nb; ++ib) {\n        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);\n        int sumi1 = 0;\n        int sumi2 = 0;\n        for 
(int j = 0; j < QK_MXFP4/2; ++j) {\n            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];\n            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];\n        }\n        sumf += d * (sumi1 + sumi2);\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_NVFP4 == 0);\n\n    const block_nvfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    // Each NVFP4 super-block (64 elements) spans 2 q8_0 blocks\n    const int nb = n / QK_NVFP4;\n\n    float sumf = 0;\n\n#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)\n    const int8x16_t values = vld1q_s8(kvalues_mxfp4);\n    const uint8x16_t m4b = vdupq_n_u8(0x0f);\n    float32x4_t acc = vdupq_n_f32(0.0f);\n\n    for (int ib = 0; ib < nb; ++ib) {\n        const uint8x16_t q4bits_0 = vld1q_u8(x[ib].qs);\n        const uint8x16_t q4bits_1 = vld1q_u8(x[ib].qs + 16);\n\n        const int8x16_t q4_lo_0 = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits_0, m4b));\n        const int8x16_t q4_hi_0 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_0, 4));\n        const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits_1, m4b));\n        const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4));\n\n        const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs);\n        const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16);\n        const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b));\n        const int8x16_t q8_hi_0 = vcombine_s8(vget_high_s8(q8_0a), vget_high_s8(q8_0b));\n\n        const int8x16_t q8_1a = vld1q_s8(y[2*ib+1].qs);\n        const int8x16_t q8_1b = vld1q_s8(y[2*ib+1].qs + 16);\n        const int8x16_t q8_lo_1 = vcombine_s8(vget_low_s8(q8_1a), 
vget_low_s8(q8_1b));\n        const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b));\n\n        const int32x4_t p0 = vaddq_s32(\n            ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),\n            ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));\n        const int32x4_t p1 = vaddq_s32(\n            ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),\n            ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));\n\n        const int32x4_t sums = vpaddq_s32(p0, p1);\n\n        // Decode 4 UE4M3 scales to f32 and multiply with q8 scales\n        const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);\n        const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);\n        const float32x4_t nvsc = {\n            ggml_ue4m3_to_fp32(x[ib].d[0]),\n            ggml_ue4m3_to_fp32(x[ib].d[1]),\n            ggml_ue4m3_to_fp32(x[ib].d[2]),\n            ggml_ue4m3_to_fp32(x[ib].d[3])\n        };\n        const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1});\n\n        acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales);\n    }\n    sumf = vaddvq_f32(acc);\n#else\n    for (int ib = 0; ib < nb; ++ib) {\n        for (int si = 0; si < 4; ++si) {\n            const float d = ggml_ue4m3_to_fp32(x[ib].d[si]);\n            const int q8b = si / 2;\n            const int q8o = (si % 2) * QK_NVFP4_SUB;\n            const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8b].d);\n\n            int sumi_lo = 0, sumi_hi = 0;\n            for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {\n                const uint8_t qv = x[ib].qs[si*(QK_NVFP4_SUB/2) + j];\n                sumi_lo += y[2*ib + q8b].qs[q8o + j +               0] * kvalues_mxfp4[qv & 0xf];\n                sumi_hi += y[2*ib + q8b].qs[q8o + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >>  4];\n            }\n            sumf += dy * d * (sumi_lo + sumi_hi);\n        }\n    }\n#endif\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__ARM_NEON)\n    float32x4_t sumv0 = vdupq_n_f32(0.0f);\n    float32x4_t sumv1 = vdupq_n_f32(0.0f);\n\n    uint32_t qh0;\n    uint32_t qh1;\n\n    uint64_t tmp0[4];\n    uint64_t tmp1[4];\n\n    for (; ib + 1 < nb; ib += 2) {\n        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];\n        const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];\n        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n        const uint8x16_t m4b = vdupq_n_u8(0x0F);\n\n        // extract the 5th bit via lookup table ((!b) << 4)\n        memcpy(&qh0, x0->qh, sizeof(qh0));\n        memcpy(&qh1, x1->qh, sizeof(qh1));\n\n        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];\n        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];\n        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];\n        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];\n\n        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];\n        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];\n        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];\n        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];\n\n        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));\n        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));\n        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));\n        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));\n\n        const uint8x16_t v0_0 = vld1q_u8(x0->qs);\n        const uint8x16_t v0_1 = vld1q_u8(x1->qs);\n\n        // 4-bit -> 8-bit\n        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));\n    
    int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));\n        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));\n        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));\n\n        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)\n        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);\n        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);\n        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);\n        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);\n\n        // load y\n        const int8x16_t v1_0l = vld1q_s8(y0->qs);\n        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);\n        const int8x16_t v1_1l = vld1q_s8(y1->qs);\n        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);\n\n        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(\n                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),\n                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(\n                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),\n                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n    }\n\n    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);\n\n#endif\n    for (; ib < nb; ++ib) {\n        uint32_t qh;\n        memcpy(&qh, x[ib].qh, sizeof(qh));\n\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;\n            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));\n\n            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);\n            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);\n\n            sumi0 += (x0 * y[ib].qs[j]);\n            sumi1 += (x1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n       
 sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_1);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__ARM_NEON)\n    float32x4_t sumv0 = vdupq_n_f32(0.0f);\n    float32x4_t sumv1 = vdupq_n_f32(0.0f);\n\n    float summs0 = 0.0f;\n    float summs1 = 0.0f;\n\n    uint32_t qh0;\n    uint32_t qh1;\n\n    uint64_t tmp0[4];\n    uint64_t tmp1[4];\n\n    for (; ib + 1 < nb; ib += 2) {\n        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];\n        const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];\n        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];\n\n        const uint8x16_t m4b = vdupq_n_u8(0x0F);\n\n        summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);\n        summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);\n\n        // extract the 5th bit via lookup table ((b) << 4)\n        memcpy(&qh0, x0->qh, sizeof(qh0));\n        memcpy(&qh1, x1->qh, sizeof(qh1));\n\n        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];\n        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];\n        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];\n        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];\n\n        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];\n        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];\n        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];\n        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];\n\n        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));\n       
 const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));\n        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));\n        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));\n\n        const uint8x16_t v0_0 = vld1q_u8(x0->qs);\n        const uint8x16_t v0_1 = vld1q_u8(x1->qs);\n\n        // 4-bit -> 8-bit\n        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));\n        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));\n        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));\n        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));\n\n        // add high bit\n        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);\n        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);\n        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);\n        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);\n\n        // load y\n        const int8x16_t v1_0l = vld1q_s8(y0->qs);\n        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);\n        const int8x16_t v1_1l = vld1q_s8(y1->qs);\n        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);\n\n        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(\n                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),\n                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(\n                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),\n                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n    }\n\n    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;\n\n#endif\n    for (; ib < nb; ++ib) {\n        uint32_t qh;\n        memcpy(&qh, x[ib].qh, sizeof(qh));\n\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) 
& 0x10;\n            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;\n\n            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;\n            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;\n\n            sumi0 += (x0 * y[ib].qs[j]);\n            sumi1 += (x1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n#if defined(__ARM_FEATURE_MATMUL_INT8)\n    assert((nrc == 2) || (nrc == 1));\n#else\n    assert(nrc == 1);\n#endif\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q8_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__ARM_FEATURE_MATMUL_INT8)\n    if (nrc == 2) {\n        const block_q8_0 * GGML_RESTRICT vx0 = vx;\n        const block_q8_0 * GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);\n        const block_q8_0 * GGML_RESTRICT vy0 = vy;\n        const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);\n\n        float32x4_t sumv0 = vdupq_n_f32(0.0f);\n\n        for (int i = 0; i < nb; i++) {\n            const block_q8_0 * GGML_RESTRICT b_x0 = &vx0[i];\n            const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i];\n\n            const block_q8_0 * GGML_RESTRICT b_x1 = &vx1[i];\n            const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i];\n\n            const int8x16_t x0_l = vld1q_s8(b_x0->qs);\n            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);\n            const int8x16_t x1_l = vld1q_s8(b_x1->qs);\n            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);\n\n            
// load y\n            const int8x16_t y0_l = vld1q_s8(b_y0->qs);\n            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);\n            const int8x16_t y1_l = vld1q_s8(b_y1->qs);\n            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);\n\n            float32_t _scale[4] = {\n                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),\n                GGML_CPU_FP16_TO_FP32(b_x0->d)*GGML_CPU_FP16_TO_FP32(b_y1->d),\n                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y0->d),\n                GGML_CPU_FP16_TO_FP32(b_x1->d)*GGML_CPU_FP16_TO_FP32(b_y1->d)\n            };\n            float32x4_t scale = vld1q_f32(_scale);\n\n            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));\n            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));\n\n            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));\n            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));\n\n            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));\n            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));\n\n            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));\n            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));\n\n            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),\n                                                l1, r1)), l2, r2)), l3, r3))), scale);\n        }\n\n        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);\n        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);\n\n        vst1_f32(s,      
vget_low_f32 (sumv2));\n        vst1_f32(s + bs, vget_high_f32(sumv2));\n\n        return;\n    }\n#endif\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__ARM_FEATURE_SVE)\n    svfloat32_t sumv0 = svdup_n_f32(0.0f);\n    svfloat32_t sumv1 = svdup_n_f32(0.0f);\n\n    const int vector_length = ggml_cpu_get_sve_cnt()*8;\n\n    //VLA Implementation for SVE\n    switch (vector_length) {\n        case 128:\n            {\n                // predicate for activating lanes for 16 Int8 elements\n                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);\n                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);\n\n                for (; ib + 1 < nb; ib += 2) {\n                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];\n                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n                    // load x\n                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);\n                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);\n                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);\n                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);\n\n                    // load y\n                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);\n                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);\n                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);\n                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);\n\n                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,\n                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),\n                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,\n           
                         svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),\n                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n                }\n\n                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));\n            } break;\n        case 256:\n            {\n                //printf(\"sve256\");\n                for (; ib + 1 < nb; ib += 2) {\n                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];\n                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n                    // load x\n                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);\n                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);\n\n                    // load y\n                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);\n                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);\n\n                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),\n                                svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),\n                                svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n                }\n\n                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));\n            } break;\n        case 512:\n            {\n                // predicate for activating high 256 bit\n                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);\n                // predicate for activating low 256 bit\n                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);\n\n                // predicate 
for activating high lanes for 8 float32 elements\n                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);\n                // predicate for activating low lanes for 8 float32 elements\n                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);\n\n                svfloat32_t sumv00 = svdup_n_f32(0.0f);\n\n                for (; ib + 1 < nb; ib += 2) {\n                    const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];\n                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n                    //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits\n                    // and add them to make one 64 element vector\n                    // load x\n                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);\n                          svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);\n\n                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);\n\n                    // load y\n                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);\n                          svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);\n\n                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);\n\n                    // scale creation\n                    const float32_t deq1 = GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d);\n                    const float32_t deq2 = GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d);\n\n                    // duplicate deq1 in first half of vector and deq2 in second half of vector\n                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);\n\n                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));\n\n                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);\n                }\n\n            
    sumf = svaddv_f32(svptrue_b32(), sumv00);\n                break;\n            }\n        default:\n            assert(false && \"Unsupported vector length\");\n            break;\n    }\n#elif defined(__ARM_NEON)\n    float32x4_t sumv0 = vdupq_n_f32(0.0f);\n    float32x4_t sumv1 = vdupq_n_f32(0.0f);\n\n    for (; ib + 1 < nb; ib += 2) {\n        const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0];\n        const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n        const int8x16_t x0_0 = vld1q_s8(x0->qs);\n        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);\n        const int8x16_t x1_0 = vld1q_s8(x1->qs);\n        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);\n\n        // load y\n        const int8x16_t y0_0 = vld1q_s8(y0->qs);\n        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);\n        const int8x16_t y1_0 = vld1q_s8(y1->qs);\n        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);\n\n        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(\n                        ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),\n                        ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_CPU_FP16_TO_FP32(x0->d)*GGML_CPU_FP16_TO_FP32(y0->d));\n\n        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(\n                        ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),\n                        ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_CPU_FP16_TO_FP32(x1->d)*GGML_CPU_FP16_TO_FP32(y1->d));\n    }\n\n    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);\n#endif\n    for (; ib < nb; ++ib) {\n        int sumi = 0;\n\n        for (int j = 0; j < qk; j++) {\n            sumi += x[ib].qs[j]*y[ib].qs[j];\n        }\n\n        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t 
bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_tq1_0 * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__ARM_NEON)\n    float sumf = 0.0f;\n\n    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};\n\n    const uint8x16_t shift = vld1q_u8(k_shift);\n\n    for (int i = 0; i < nb; ++i) {\n#if defined(__ARM_FEATURE_DOTPROD)\n        int32x4_t sumi0 = vdupq_n_s32(0);\n        int32x4_t sumi1 = vdupq_n_s32(0);\n#else\n        int16x8_t sumi0 = vdupq_n_s16(0);\n        int16x8_t sumi1 = vdupq_n_s16(0);\n#endif\n\n        // first 32 bytes of 5 elements\n        {\n            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);\n            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);\n            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));\n            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));\n            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));\n            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));\n            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));\n            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));\n            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));\n            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));\n\n            // multiply by 3 and keep the 2 bits above 8 bits\n            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));\n            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));\n            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));\n            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));\n            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));\n            int8x16_t sqx5 = 
vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));\n            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));\n            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));\n            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));\n            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));\n\n            const int8x16_t qy0 = vld1q_s8(y[i].qs +   0);\n            const int8x16_t qy1 = vld1q_s8(y[i].qs +  16);\n            const int8x16_t qy2 = vld1q_s8(y[i].qs +  32);\n            const int8x16_t qy3 = vld1q_s8(y[i].qs +  48);\n            const int8x16_t qy4 = vld1q_s8(y[i].qs +  64);\n            const int8x16_t qy5 = vld1q_s8(y[i].qs +  80);\n            const int8x16_t qy6 = vld1q_s8(y[i].qs +  96);\n            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);\n            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);\n            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);\n\n#if defined(__ARM_FEATURE_DOTPROD)\n            sumi0 = vdotq_s32(sumi0, sqx0, qy0);\n            sumi1 = vdotq_s32(sumi1, sqx1, qy1);\n            sumi0 = vdotq_s32(sumi0, sqx2, qy2);\n            sumi1 = vdotq_s32(sumi1, sqx3, qy3);\n            sumi0 = vdotq_s32(sumi0, sqx4, qy4);\n            sumi1 = vdotq_s32(sumi1, sqx5, qy5);\n            sumi0 = vdotq_s32(sumi0, sqx6, qy6);\n            sumi1 = vdotq_s32(sumi1, sqx7, qy7);\n            sumi0 = vdotq_s32(sumi0, sqx8, qy8);\n            sumi1 = vdotq_s32(sumi1, sqx9, qy9);\n#else\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), 
vget_low_s8(qy2));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));\n#endif\n        }\n\n        // last 16 bytes of 5-element, along with the 4 bytes of 4 elements\n        {\n            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);\n            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));\n            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));\n            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));\n            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));\n            uint32_t qh;\n            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned\n            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));\n            qx5 = vmulq_u8(qx5, shift);\n\n            // multiply by 3 and keep the 2 bits above 8 bits\n            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));\n            int8x16_t sqx1 = 
vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));\n            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));\n            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));\n            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));\n            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));\n\n            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);\n            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);\n            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);\n            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);\n            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);\n            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);\n\n#if defined(__ARM_FEATURE_DOTPROD)\n            sumi0 = vdotq_s32(sumi0, sqx0, qy0);\n            sumi1 = vdotq_s32(sumi1, sqx1, qy1);\n            sumi0 = vdotq_s32(sumi0, sqx2, qy2);\n            sumi1 = vdotq_s32(sumi1, sqx3, qy3);\n            sumi0 = vdotq_s32(sumi0, sqx4, qy4);\n            sumi1 = vdotq_s32(sumi1, sqx5, qy5);\n#else\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));\n            sumi0 = vmlal_s8(sumi0, 
vget_low_s8(sqx5), vget_low_s8(qy5));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));\n#endif\n        }\n\n        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);\n        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);\n\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n#if defined(__ARM_FEATURE_DOTPROD)\n        sumi0 = vaddq_s32(sumi0, sumi1);\n        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));\n\n        sumf += d * (float) vaddvq_s32(sumi0);\n#else\n        sumi0 = vaddq_s16(sumi0, sumi1);\n        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));\n\n        sumf += d * (float) vaddlvq_s16(sumi0);\n#endif\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_tq2_0 * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__ARM_NEON)\n    float sumf = 0.0f;\n\n    const uint8x16_t m3 = vdupq_n_u8(3);\n\n    for (int i = 0; i < nb; ++i) {\n#if defined(__ARM_FEATURE_DOTPROD)\n        int32x4_t sumi0 = vdupq_n_s32(0);\n        int32x4_t sumi1 = vdupq_n_s32(0);\n#else\n        int16x8_t sumi0 = vdupq_n_s16(0);\n        int16x8_t sumi1 = vdupq_n_s16(0);\n#endif\n\n        for (size_t j = 0; j < sizeof(x->qs); j += 32) {\n            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);\n            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);\n            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);\n            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);\n            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);\n            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);\n            uint8x16_t qx6 = 
vshrq_n_u8(qx0, 6);\n            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);\n\n            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));\n            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));\n            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));\n            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));\n            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));\n            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));\n            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));\n            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));\n\n            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 +   0);\n            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 +  16);\n            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 +  32);\n            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 +  48);\n            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 +  64);\n            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 +  80);\n            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 +  96);\n            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);\n\n#if defined(__ARM_FEATURE_DOTPROD)\n            sumi0 = vdotq_s32(sumi0, sqx0, qy0);\n            sumi1 = vdotq_s32(sumi1, sqx1, qy1);\n            sumi0 = vdotq_s32(sumi0, sqx2, qy2);\n            sumi1 = vdotq_s32(sumi1, sqx3, qy3);\n            sumi0 = vdotq_s32(sumi0, sqx4, qy4);\n            sumi1 = vdotq_s32(sumi1, sqx5, qy5);\n            sumi0 = vdotq_s32(sumi0, sqx6, qy6);\n            sumi1 = vdotq_s32(sumi1, sqx7, qy7);\n#else\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));\n            
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));\n            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));\n            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));\n#endif\n        }\n\n        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);\n        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);\n\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n#if defined(__ARM_FEATURE_DOTPROD)\n        sumi0 = vaddq_s32(sumi0, sumi1);\n        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));\n\n        sumf += d * (float) vaddvq_s32(sumi0);\n#else\n        sumi0 = vaddq_s16(sumi0, sumi1);\n        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));\n\n        sumf += d * (float) vaddlvq_s16(sumi0);\n#endif\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q2_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#ifdef __ARM_FEATURE_SVE\n    const int 
vector_length = svcntb()*8;\n    const svuint8_t m3s = svdup_n_u8(0x3);\n    const svuint32_t m4s = svdup_n_u32(0xF);\n    const svint32_t vzero_sv = svdup_n_s32(0);\n    svfloat32_t acc_sum = svdup_n_f32(0);\n    svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);\n\n    switch (vector_length) {\n        case 128:\n            for (int i = 0; i < nb; ++i) {\n                const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n                svfloat32_t d_broad = svdup_n_f32((float32_t)d);\n                const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);\n\n                const uint8_t * GGML_RESTRICT q2 = x[i].qs;\n                const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;\n                const uint8_t * GGML_RESTRICT sc = x[i].scales;\n\n                svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);\n                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));\n\n                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);\n                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));\n\n                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);\n                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);\n\n                const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));\n\n                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);\n                const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));\n\n                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);\n                const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));\n\n                
q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);\n                q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);\n\n                svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));\n\n                svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));\n\n                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);\n\n                svint32_t sumi1 = svdup_n_s32(0);\n\n                {\n                    const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);\n                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));\n                    svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n                    const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));\n\n                    const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));\n                    
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));\n\n\n                    const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));\n\n                    //-------------------------------\n\n                    q2 += 32;\n                    const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));\n                    
const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));\n\n                    const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));\n\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));\n\n\n                    const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));\n\n                    
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));\n\n\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));\n                }\n                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);\n            }\n            *s = svaddv_f32(svptrue_b32(), acc_sum);\n            break;\n\n        case 256:\n        case 512:\n            for (int i = 0; i < nb; ++i) {\n                const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n                svfloat32_t d_broad = svdup_n_f32((float32_t)d);\n                const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);\n\n                const uint8_t * GGML_RESTRICT q2 = x[i].qs;\n                const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;\n                const uint8_t * GGML_RESTRICT sc = x[i].scales;\n\n                const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;\n                const svint32_t scales_sv = 
svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));\n                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));\n                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);\n\n                const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);\n                const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));\n                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));\n\n                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);\n\n                svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));\n\n                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);\n\n                svint32_t sumi1 = svdup_n_s32(0);\n\n                {\n                    const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);\n                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));\n                    svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                    svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));\n                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                    
svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));\n                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));\n                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));\n                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);\n\n                    q2 += 32;\n\n                    const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));\n                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));\n                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n          
          scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));\n                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));\n                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);\n\n                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));\n                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));\n                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);\n                }\n                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);\n            }\n            *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);\n            break;\n\n        default:\n            assert(false && \"Unsupported vector length\");\n            break;\n    }\n\n#elif __ARM_NEON\n    const uint8x16_t m3 = vdupq_n_u8(0x3);\n    const uint8x16_t m4 = vdupq_n_u8(0xF);\n\n    const int32x4_t vzero = vdupq_n_s32(0);\n\n    ggml_int8x16x2_t q2bytes;\n    uint8_t aux[16];\n\n    float sum = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const uint8_t * GGML_RESTRICT 
q2 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        const uint8_t * GGML_RESTRICT sc = x[i].scales;\n\n        const uint8x16_t mins_and_scales = vld1q_u8(sc);\n        const uint8x16_t scales = vandq_u8(mins_and_scales, m4);\n        vst1q_u8(aux, scales);\n\n        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);\n        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);\n        const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};\n        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),\n                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));\n        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),\n                                       vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));\n        sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));\n\n        int isum = 0;\n        int is = 0;\n\n// We use this macro instead of a function call because for some reason\n// the code runs 2-3% slower, even if the function is declared inline\n#define MULTIPLY_ACCUM_WITH_SCALE(index)\\\n        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\\\n        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];\n\n#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\\\n        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\\\n        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\\\n        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\\\n        MULTIPLY_ACCUM_WITH_SCALE((index));\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;\n\n 
           ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\n            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));\n            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));\n\n            MULTIPLY_ACCUM_WITH_SCALE(0);\n\n            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);\n            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);\n            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);\n\n            is += 8;\n        }\n\n        sum += d * isum;\n    }\n\n    *s = sum;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    const block_q3_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__ARM_FEATURE_SVE)\n\n    uint32_t aux[3];\n    uint32_t utmp[4];\n\n    const int8_t m32 = 32;\n    const int vector_length = svcntb()*8;\n    const svuint8_t m3b_sv = svdup_n_u8(0x3);\n    const svint32_t vzero_sv = svdup_n_s32(0);\n\n    const svuint8_t m0_sv = svdup_n_u8(1);\n    const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);\n    const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);\n    const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);\n\n    float sum = 0;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q3_sv = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh_sv = x[i].hmask;\n        const int8_t  * GGML_RESTRICT q8_sv = y[i].qs;\n\n        // Set up scales\n        
memcpy(aux, x[i].scales, 12);\n        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);\n        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);\n        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);\n        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);\n\n        int8_t * scale = (int8_t *)utmp;\n\n        for (int j = 0; j < 16; ++j) scale[j] -= m32;\n\n        switch (vector_length) {\n            case 128:\n                {\n                    svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);\n                    svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);\n                    svuint8_t q3h_sv;\n\n                    svint32_t sumi1_1 = svdup_n_s32(0);\n                    svint8_t q3bytes_sv;\n\n                    for (int j = 0; j < QK_K/128; ++j) {\n\n                        const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;\n                        const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;\n                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);\n                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));\n\n                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);\n                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        sumi1_1 = 
svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));\n\n                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);\n                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));\n\n                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);\n                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));\n\n\n                        scale += 4;\n                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);\n                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));\n\n                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);\n                        q3bytes_sv = svsub_s8_x(svptrue_b8(), 
svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));\n\n\n                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;\n\n                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);\n                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));\n\n                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);\n                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));\n\n                        if (j == 0) {\n                            qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);\n                            qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);\n                        }\n\n                        scale += 4;\n                    }\n\n                    sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));\n                } break;\n            case 256:\n            case 512:\n                {\n                    svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);\n                    svuint8_t q3h_sv;\n\n                    
svint32_t sumi1_1 = svdup_n_s32(0);\n                    svint8_t q3bytes_sv;\n\n                    for (int j = 0; j < QK_K/128; ++j) {\n\n                        const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;\n                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);\n                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n\n                        svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));\n                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);\n\n                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);\n                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));\n                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);\n\n                        scale += 4;\n                        q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n                        q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;\n\n                        q3h_sv = 
svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);\n                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));\n                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);\n\n                        q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);\n                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));\n\n                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));\n                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);\n\n                        if (j == 0) {\n                            qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);\n                        }\n\n                        scale += 4;\n                    }\n\n                    sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));\n                } break;\n            default:\n                assert(false && \"Unsupported vector length\");\n                break;\n        }\n    }\n    *s = sum;\n\n#elif __ARM_NEON\n\n    uint32_t aux[3];\n    uint32_t utmp[4];\n\n    const uint8x16_t m3b = vdupq_n_u8(0x3);\n    const int32x4_t  vzero = vdupq_n_s32(0);\n\n    const uint8x16_t m0 = vdupq_n_u8(1);\n    const uint8x16_t m1 = vshlq_n_u8(m0, 1);\n    const uint8x16_t m2 = vshlq_n_u8(m0, 2);\n    const uint8x16_t m3 = 
vshlq_n_u8(m0, 3);\n    const int8_t m32 = 32;\n\n    ggml_int8x16x4_t q3bytes;\n\n    float sum = 0;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].hmask;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);\n\n        ggml_uint8x16x4_t q3h;\n\n        int32_t isum = 0;\n\n        // Set up scales\n        memcpy(aux, x[i].scales, 12);\n        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);\n        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);\n        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);\n        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);\n\n        int8_t * scale = (int8_t *)utmp;\n        for (int j = 0; j < 16; ++j) scale[j] -= m32;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n\n            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;\n            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;\n            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;\n\n            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);\n            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);\n            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);\n            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);\n\n            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));\n            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));\n            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));\n            q3bytes.val[3] = 
vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));\n\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];\n\n            scale += 4;\n\n            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);\n            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);\n            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);\n            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);\n\n            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));\n            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));\n            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));\n            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));\n\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];\n\n            scale += 4;\n\n            if (j == 0) {\n                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);\n                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);\n            }\n\n        }\n        sum += d * isum;\n\n   
 }\n\n    *s = sum;\n\n#else\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n\n}\n\n#ifdef __ARM_FEATURE_SVE\nstatic inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {\n    const svbool_t pg_all   = svptrue_pat_b32(SV_VL4);\n    const svbool_t pg_false = svpfalse_b();            // 0x0000\n    const svbool_t pg_lo_8  = svwhilelt_b8_s32(0,  8); // 0x00ff\n    const svbool_t pg_odd   = svzip1_b32(pg_false, pg_lo_8);\n\n    svuint32_t vutmp_hi, vutmp_lo;\n    svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);\n    vutmp_hi = svzip1_u32(vx01, vx01);\n    vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);\n    vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));\n    const svuint32_t vx2 = svdup_u32(vx_scales[2]);\n    vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));\n    vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));\n    svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);\n    return vutmp;\n}\n#endif\n\nvoid ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n#ifdef __ARM_FEATURE_MATMUL_INT8\n    assert((nrc == 2) || (nrc == 1));\n#else\n    assert(nrc == 1);\n#endif\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n#ifdef __ARM_FEATURE_SVE\n    const int vector_length = ggml_cpu_get_sve_cnt()*8;\n#endif\n\n#if defined(__ARM_FEATURE_SVE) && 
defined(__ARM_FEATURE_MATMUL_INT8)\n    if (nrc == 2) {\n        svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);\n\n        const block_q4_K * GGML_RESTRICT vx0 = vx;\n        const block_q8_K * GGML_RESTRICT vy0 = vy;\n        const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);\n        const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);\n\n        union {\n            uint32_t u32[8];\n            uint64_t u64[4];\n        } new_utmp;\n\n        svfloat32_t sumf1 = svdup_n_f32(0);\n\n        switch (vector_length) {\n            case 128:\n                {\n                    svbool_t pg_false = svpfalse_b();\n                    svbool_t pg_lo_8  = svwhilelt_b8_s32(0,  8);\n                    svbool_t vmins_mask1= svzip1_b32(pg_lo_8, pg_false);\n                    svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);\n                    svbool_t pg128_all  = svptrue_pat_b8(SV_VL16);\n                    for (int i = 0; i < nb; ++i) {\n                        svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));\n                        svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));\n                        svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);\n                        svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));\n                        svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));\n                        svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);\n                        const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;\n                        const int8_t  * GGML_RESTRICT q8_0 = vy0[i].qs;\n                        const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;\n                        const int8_t  * 
GGML_RESTRICT q8_1 = vy1[i].qs;\n                        svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);\n                        svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);\n                        svint16_t sum_tmp1 = svuzp1_s16(lo, hi);\n                        svint16_t sum_tmp2 = svuzp2_s16(lo, hi);\n                        svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);\n                        lo = svld1_s16(pg128_all, vy1[i].bsums + 0);\n                        hi = svld1_s16(pg128_all, vy1[i].bsums + 8);\n                        sum_tmp1 = svuzp1(lo, hi);\n                        sum_tmp2 = svuzp2(lo, hi);\n                        svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);\n                        svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);\n                        svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);\n                        svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);\n                        svst2_u32(pg128_all, new_utmp.u32, decoded_scales);\n                        svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));\n                        svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));\n                        svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));\n                        svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));\n                        svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);\n                        svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));\n     
                   svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));\n                        svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);\n                        svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));\n                        svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));\n                        svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);\n                        svint32_t svscales, sumi1, sumi2;\n                        svint32_t acc_sumif1 = svdup_n_s32(0);\n                        svint32_t acc_sumif2 = svdup_n_s32(0);\n                        svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,\n                                 q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;\n#pragma GCC unroll 1\n                        for (int j = 0; j < QK_K/64; ++j) {\n                            q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));\n                            q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));\n                            q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));\n                            q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));\n                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));\n                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));\n                            l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), 
svreinterpret_s64_s8(q4bytes_1_h)));\n                            l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));\n                            q8bytes_0_h = svld1_s8(pg128_all, q8_0);\n                            q8bytes_1_h = svld1_s8(pg128_all, q8_1);\n                            q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);\n                            q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);\n                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));\n                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));\n                            r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));\n                            r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));\n                            sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);\n                            svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));\n                            acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);\n\n                            q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));\n                            q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));\n                            q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));\n                            q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));\n                            l0 = 
svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));\n                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));\n                            l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));\n                            l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));\n                            q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);\n                            q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);\n                            q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);\n                            q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);\n                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));\n                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));\n                            r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));\n                            r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));\n                            sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);\n                            svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));\n                            acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);\n                            q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;\n                        }\n                        sumf1 = svmla_f32_x(pg128_all,\n                                svmla_f32_x(pg128_all,\n  
                                  sumf1,\n                                    svcvt_f32_x(pg128_all,\n                                        svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),\n                                    svsuper_block_scales),\n                                svdmins,\n                                svcvt_f32_s32_x(pg128_all, svsumfs_tmp));\n                    }  //end of for nb\n                } // end of case 128\n                break;\n            case 256:\n            case 512:\n                {\n                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);\n                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);\n                    const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);\n                    for (int i = 0; i < nb; ++i) {\n                        const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;\n                        const int8_t  * GGML_RESTRICT q8_0 = vy0[i].qs;\n                        const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;\n                        const int8_t  * GGML_RESTRICT q8_1 = vy1[i].qs;\n                        svint32_t svscales, sumi1, sumi2;\n                        svint32_t acc_sumif1 = svdup_n_s32(0);\n                        svint32_t acc_sumif2 = svdup_n_s32(0);\n                        svint8_t l0, l1, l2, l3, r0, r1, r2, r3;\n                        svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));\n                        svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));\n                        svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));\n                        svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);\n                        svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));\n                        svfloat64_t 
vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));\n                        svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));\n                        svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);\n                        svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));\n                        svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));\n                        svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);\n                        svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);\n                        svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);\n                        svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);\n                        svst2_u32(pg8_16, new_utmp.u32, decoded_scales);\n                        svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));\n                        svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));\n                        svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);\n                        svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);\n                        svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));\n                        svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));\n                        svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);\n                        svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);\n     
                   svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);\n                        svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);\n\n#pragma GCC unroll 1\n                        for (int j = 0; j < QK_K/64; ++j) {\n                            svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);\n                            svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);\n                            svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);\n                            svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);\n                            l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));\n                            l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));\n                            l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));\n                            l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));\n                            svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);\n                            svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);\n                            svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);\n                            svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);\n                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));\n                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));\n                            r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));\n                            r3 = 
svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));\n                            sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);\n                            svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));\n                            acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);\n                            sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);\n                            svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));\n                            acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);\n                            q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;\n                        }\n                        svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);\n                        svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);\n                        acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);\n                        sumf1 = svmla_f32_x(pg32_4,\n                                svmla_f32_x(pg32_4,\n                                    sumf1,\n                                    svcvt_f32_x(pg32_4, acc_sumif),\n                                    svsuper_block_scales),\n                                svdmins,\n                                svsumfs_tmp);\n                    } // end of for nb\n                } // end of case 256-512\n                break;\n            default:\n                assert(false && \"Unsupported vector length\");\n                break;\n        }\n\n        svst1_f32(pg32_2, s, sumf1);\n        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));\n\n        return;\n    }\n#elif 
defined(__ARM_FEATURE_MATMUL_INT8)\n    if (nrc == 2) {\n        const block_q4_K * GGML_RESTRICT x0 = x;\n        const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);\n        const block_q8_K * GGML_RESTRICT y0 = y;\n        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);\n\n        const uint8x16_t m4b = vdupq_n_u8(0x0f);\n\n        float32x4_t vfsum = vdupq_n_f32(0.0f);\n\n        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {\n            const uint8_t * GGML_RESTRICT qx0 = x0->qs;\n            const uint8_t * GGML_RESTRICT qx1 = x1->qs;\n            const  int8_t * GGML_RESTRICT qy0 = y0->qs;\n            const  int8_t * GGML_RESTRICT qy1 = y1->qs;\n\n            // decode scales and mins\n            int8_t x0_scales[8], x1_scales[8];\n            int16x8_t x0_mins, x1_mins;\n            {\n                uint32_t scales_mins[3];\n                memcpy(scales_mins, x0->scales, 12);\n                const uint32_t mins_0_3 = scales_mins[1] & kmask1;\n                const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);\n                const uint32x2_t mins = {mins_0_3, mins_4_7};\n                x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));\n                uint32_t scales[2];\n                scales[0] = scales_mins[0] & kmask1; // scales 0~3\n                scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7\n                memcpy(x0_scales, scales, 8);\n            }\n            {\n                uint32_t scales_mins[3];\n                memcpy(scales_mins, x1->scales, 12);\n                const uint32_t mins_0_3 = scales_mins[1] & kmask1;\n                const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);\n                const uint32x2_t mins = {mins_0_3, mins_4_7};\n                x1_mins = 
vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));\n                uint32_t scales[2];\n                scales[0] = scales_mins[0] & kmask1; // scales 0~3\n                scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7\n                memcpy(x1_scales, scales, 8);\n            }\n\n            int32x4_t visum = {0};\n\n            // process 64 data points per iteration, totally 256 data points\n            for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {\n                const int8x16x4_t vy0 = vld1q_s8_x4(qy0);\n                const int8x16x4_t vy1 = vld1q_s8_x4(qy1);\n\n                int8x16_t vx0[4], vx1[4];\n                {\n                    const uint8x16x2_t vv = vld1q_u8_x2(qx0);\n                    vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));\n                    vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));\n                    vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));\n                    vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));\n                }\n                {\n                    const uint8x16x2_t vv = vld1q_u8_x2(qx1);\n                    vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));\n                    vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));\n                    vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));\n                    vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));\n                }\n\n                // process 32 data points (share same block scale) per iteration\n                for (int k = 0; k < 2; ++k) {\n                    const int blk = j * 2 + k;\n                    const int32x4_t block_scale = {\n                        x0_scales[blk],\n                        x0_scales[blk],\n                        x1_scales[blk],\n                        x1_scales[blk],\n                    };\n\n                    int32x4_t vr = {0};\n   
                 for (int l = 0; l < 2; ++l) {\n                        const int idx = k * 2 + l;\n                        const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);\n                        const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);\n                        const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);\n                        const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);\n                        const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));\n                        const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));\n                        const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));\n                        const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));\n                        vr = vmmlaq_s32(vr, vx_l, vy_l);\n                        vr = vmmlaq_s32(vr, vx_h, vy_h);\n                    }\n                    // apply block scale, will NOT overflow\n                    // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits\n                    visum = vmlaq_s32(visum, vr, block_scale);\n                }\n            }\n\n            // adjust bias, apply superblock scale\n            {\n                int32_t bias[4];\n                // no obvious uplift from sve sdot-16, just use neon mul add\n                const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));\n                const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));\n                bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),\n                                               vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));\n                bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),\n                                               
vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));\n                bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),\n                                               vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));\n                bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),\n                                               vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));\n                const float32x4_t dmins = {\n                    GGML_CPU_FP16_TO_FP32(x0->dmin) * y0->d,\n                    GGML_CPU_FP16_TO_FP32(x0->dmin) * y1->d,\n                    GGML_CPU_FP16_TO_FP32(x1->dmin) * y0->d,\n                    GGML_CPU_FP16_TO_FP32(x1->dmin) * y1->d,\n                };\n                vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);\n\n                const float32x4_t superblock_scale = {\n                    GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,\n                    GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,\n                    GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,\n                    GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,\n                };\n                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);\n            }\n        }\n\n        // vfsum = ABCD -> ACBD\n        // AC -> s, BD -> (s+bs)\n        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));\n        vst1_f32(s,      vget_low_f32 (vfsum));\n        vst1_f32(s + bs, vget_high_f32(vfsum));\n\n        return;\n    }\n#endif\n\n#ifdef __ARM_FEATURE_SVE\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));\n\n        memcpy(utmp, x[i].scales, K_SCALE_SIZE);\n\n        uint32x2_t mins8 = { 0 };\n        mins8 = 
vset_lane_u32(utmp[1] & kmask1, mins8, 0);\n        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);\n\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[0] &= kmask1;\n\n        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));\n        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),\n                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));\n        sumf -= dmin * vaddvq_s32(prod);\n\n        const uint8_t * scales = (const uint8_t *)utmp;\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const svuint8_t m4b = svdup_n_u8(0xf);\n        const svint32_t mzero = svdup_n_s32(0);\n        svint32_t sumi1 = svdup_n_s32(0);\n        svint32_t sumi1_1 = svdup_n_s32(0);\n        svint32_t sumi1_2 = svdup_n_s32(0);\n        svint32_t sumi2 = svdup_n_s32(0);\n        svint32_t sumi2_1 = svdup_n_s32(0);\n        svint32_t sumi2_2 = svdup_n_s32(0);\n        switch (vector_length) {\n            case 128:\n                {\n                    for (int j = 0; j < QK_K/64; ++j) {\n                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));\n                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;\n                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);\n                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));\n                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;\n                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);\n\n                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), 
svld1_u8(svptrue_b8(), q4), 4));\n                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;\n                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);\n                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));\n                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;\n                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);\n                        q4 += 32;\n                    }\n                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);\n                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);\n                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));\n                } break;\n            case 256:\n            case 512:\n                {\n                    for (int j = 0; j < QK_K/64; ++j) {\n                        const svuint8_t q4bits  = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;\n                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));\n                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;\n                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);\n\n                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));\n                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;\n                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);\n                    }\n                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));\n                } break;\n            default:\n                assert(false && 
\"Unsupported vector length\");\n                break;\n        }\n    }\n    *s = sumf;\n#elif defined __ARM_NEON\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    const int32x4_t mzero = vdupq_n_s32(0);\n\n    ggml_int8x16x2_t q4bytes;\n    ggml_int8x16x2_t q8bytes;\n\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));\n\n        memcpy(utmp, x[i].scales, 12);\n\n        uint32x2_t mins8 = { 0 };\n        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);\n        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);\n\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[0] &= kmask1;\n\n        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));\n        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),\n                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));\n        sumf -= dmin * vaddvq_s32(prod);\n\n        const uint8_t * scales = (const uint8_t *)utmp;\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        int32_t sumi1 = 0;\n        int32_t sumi2 = 0;\n\n        for (int j = 0; j < QK_K/64; ++j) {\n            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;\n\n            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\n            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));\n            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));\n\n            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);\n            sumi1 += vaddvq_s32(p1) * 
scales[2*j+0];\n\n            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\n            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));\n            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));\n\n            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);\n\n            sumi2 += vaddvq_s32(p2) * scales[2*j+1];\n        }\n\n        sumf += d * (sumi1 + sumi2);\n\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n\n#ifdef __ARM_NEON\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    const uint8x16_t mone = vdupq_n_u8(1);\n    const uint8x16_t mtwo = vdupq_n_u8(2);\n    const int32x4_t mzero = vdupq_n_s32(0);\n\n    ggml_int8x16x4_t q5bytes;\n\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));\n\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = 
utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);\n        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));\n        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),\n                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));\n        int32_t sumi_mins = vaddvq_s32(prod);\n\n        const uint8_t * scales = (const uint8_t *)utmp;\n\n        const uint8_t * GGML_RESTRICT q5 = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);\n\n        ggml_uint8x16x4_t q5h;\n\n        int32_t sumi = 0;\n\n        for (int j = 0; j < QK_K/64; ++j) {\n\n            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;\n            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;\n\n            q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);\n            q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);\n            q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);\n            q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);\n            qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);\n            qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);\n\n            q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));\n            q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));\n            q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));\n            q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));\n\n            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], 
q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;\n            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;\n        }\n\n        sumf += d * sumi - dmin * sumi_mins;\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n#ifdef __ARM_FEATURE_MATMUL_INT8\n    assert((nrc == 2) || (nrc == 1));\n#else\n    assert(nrc == 1);\n#endif\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q6_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#ifdef __ARM_FEATURE_SVE\n    const int vector_length = ggml_cpu_get_sve_cnt()*8;\n#endif\n#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)\n    if (nrc == 2) {\n        const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);\n\n        svfloat32_t sum = svdup_n_f32(0);\n\n        const block_q6_K * GGML_RESTRICT vx0 = vx;\n        const block_q8_K * GGML_RESTRICT vy0 = vy;\n        const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);\n        const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);\n\n        switch (vector_length) {\n            case 128:\n                {\n                    const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);\n                    for (int i = 0; i < nb; ++i) {\n                        const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;\n                        const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;\n                        const uint8_t * 
GGML_RESTRICT ql1 = vx1[i].ql;\n                        const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;\n                        const int8_t  * GGML_RESTRICT q80 = vy0[i].qs;\n                        const int8_t  * GGML_RESTRICT q81 = vy1[i].qs;\n\n                        const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;\n                        const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;\n\n                        svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));\n                        svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));\n                        svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);\n                        // process q8sum summation 128 bit route\n                        const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);\n                        const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);\n                        const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);\n                        const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);\n                        const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);\n                        const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));\n                        const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));\n                        const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);\n                        const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));\n                        const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));\n                        const svint64_t prod = svdup_n_s64(0);\n\n                        svint32_t isum_tmp1 = 
svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));\n                        svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));\n                        svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);\n                        svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));\n                        svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));\n                        svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);\n                        svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));\n                        svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));\n                        svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);\n\n                        // process mmla\n                        svint8_t  l0, l1, r0, r1;\n                        svint32_t isum_tmp = svdup_n_s32(0);\n                        for (int j = 0; j < QK_K/128; ++j) {\n                            for (int k = 0; k < 8; ++k) {\n                                svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));\n                                svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));\n                                svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));\n                                svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));\n                                const int ql_pos = (k/4)*4;\n                                svuint8_t q6bytes_0_lo = (ql_pos < 4) ? 
svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);\n                                svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);\n                                const int qh_pos = (k/2)*2;\n                                svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);\n                                svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);\n                                svint8_t  q6bytes_0, q6bytes_1;\n                                if (qh_pos <= 4) {\n                                    q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));\n                                    q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));\n                                } else {\n                                    q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));\n                                    q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));\n                                }\n                                svint8_t  q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));\n                                svint8_t  q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));\n                                l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));\n                                l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));\n                                r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));\n                                r1 = 
svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));\n                                svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));\n                                isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);\n                            }\n                            qh0 += 32;  qh1 += 32;\n                            ql0 += 64;  ql1 += 64;\n                            q80 += 128; q81 += 128;\n                            scale0 += 8; scale1 += 8;\n                        }\n                        sum = svmla_f32_x(pg128_all, sum,\n                                svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,\n                                        svisum_mins, svdup_n_s32(-32))),\n                                svsuper_block_scales);\n                    }\n                } // end of case 128\n                break;\n            case 256:\n            case 512:\n                {\n                    const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);\n                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);\n                    for (int i = 0; i < nb; ++i) {\n                        const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;\n                        const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;\n                        const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;\n                        const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;\n                        const int8_t  * GGML_RESTRICT q80 = vy0[i].qs;\n                        const int8_t  * GGML_RESTRICT q81 = vy1[i].qs;\n\n                        const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;\n                        const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;\n                        svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));\n                     
   svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));\n                        svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));\n                        svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);\n                        // process q8sum summation 256 bit route\n                        const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);\n                        const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);\n                        const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));\n                        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));\n                        const svint64_t prod = svdup_n_s64(0);\n                        svint32_t isum_tmp1  = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));\n                        svint32_t isum_tmp2  = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));\n                        svint32_t isum_tmp3  = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));\n                        svint32_t isum_tmp4  = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));\n                        svint32_t isum_tmp5  = svtrn1_s32(isum_tmp1, isum_tmp2);\n                        svint32_t isum_tmp6  = svtrn1_s32(isum_tmp3, isum_tmp4);\n                        svint32_t isum_tmp7  = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));\n                        svint32_t isum_tmp8  = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));\n                        svint32_t isum_tmp9  = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);\n                        svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));\n                        svint32_t svisum_mins = 
svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);\n\n                        // process mmla\n                        svint8_t l0, l1, r0, r1;\n                        svint32_t isum_tmp = svdup_n_s32(0);\n                        for (int j = 0; j < QK_K/128; ++j) {\n                            for (int k = 0; k < 8; k+=2) { // process 2 block\n                                svuint8_t qhbits_0  = svld1_u8(pg256_all, qh0);\n                                svuint8_t qhbits_1  = svld1_u8(pg256_all, qh1);\n                                svuint8_t q6bits_0  = svld1_u8(pg256_all, ql0+32*((k%4)/2));\n                                svuint8_t q6bits_1  = svld1_u8(pg256_all, ql1+32*((k%4)/2));\n                                const int ql_pos = (k/4)*4;\n                                svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);\n                                svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);\n                                const int qh_pos = (k/2)*2;\n                                svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);\n                                svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);\n                                svint8_t  q6bytes_0, q6bytes_1;\n                                if (qh_pos <= 4) {\n                                    q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));\n                                    q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));\n                                } else {\n                                    q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));\n                                    q6bytes_1 = 
svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));\n                                }\n                                svint8_t  q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));\n                                svint8_t  q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));\n                                l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));\n                                l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));\n                                r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));\n                                r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));\n                                svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));\n                                svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));\n                                isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);\n                                isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);\n                            }\n                            qh0 += 32;  qh1 += 32;\n                            ql0 += 64;  ql1 += 64;\n                            q80 += 128; q81 += 128;\n                            scale0 += 8; scale1 += 8;\n                        } // end of for\n                        svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);\n                        isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);\n                        sum = svmla_f32_x(pg32_4, sum,\n                                svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,\n                                        
svisum_mins, svdup_n_s32(-32))),\n                                svsuper_block_scales);\n                    }\n                } // end of case 256\n                break;\n            default:\n                assert(false && \"Unsupported vector length\");\n                break;\n        } // end of switch\n\n        svst1_f32(pg32_2, s, sum);\n        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));\n\n        return;\n    }\n#elif defined(__ARM_FEATURE_MATMUL_INT8)\n    if (nrc == 2) {\n        const block_q6_K * GGML_RESTRICT x0 = x;\n        const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);\n        const block_q8_K * GGML_RESTRICT y0 = y;\n        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);\n\n        float32x4_t vfsum = vdupq_n_f32(0.0f);\n\n        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {\n            const uint8_t * GGML_RESTRICT ql0 = x0->ql;\n            const uint8_t * GGML_RESTRICT ql1 = x1->ql;\n            const uint8_t * GGML_RESTRICT qh0 = x0->qh;\n            const uint8_t * GGML_RESTRICT qh1 = x1->qh;\n            const  int8_t * GGML_RESTRICT qy0 = y0->qs;\n            const  int8_t * GGML_RESTRICT qy1 = y1->qs;\n\n            const uint8x16_t mone = vdupq_n_u8(0x30);\n            const uint8x16_t  m4b = vdupq_n_u8(0x0f);\n\n            int32x4_t visum = vdupq_n_s32(0);\n\n            // process 8 blocks per iteration, totally 16 blocks\n            for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {\n                int8x16_t vx0[8], vx1[8];\n\n                // de-quantize vx0[8]\n                {\n                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);\n                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);\n\n                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));\n                    uint8x16_t q6h_1 = 
vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));\n                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));\n                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));\n\n                    vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));\n                    vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));\n                    vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));\n                    vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));\n\n                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);\n                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);\n                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));\n                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));\n\n                    vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));\n                    vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));\n                    vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));\n                    vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));\n                }\n\n                // de-quantize vx1[8]\n                {\n                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);\n                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);\n\n                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));\n                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));\n                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));\n                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));\n\n                    vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));\n                    vx1[1] = 
vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));\n                    vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));\n                    vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));\n\n                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);\n                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);\n                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));\n                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));\n\n                    vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));\n                    vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));\n                    vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));\n                    vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));\n                }\n\n                // process 16 elements (one block with same scale) per iteration\n                // - vx = concat(ql, qh) - 32\n                // - r1,r2,r3,r4 = smmla(vx, vy)\n                for (int k = 0; k < 8; ++k) {\n                    const int blk = j * 8 + k;\n\n                    const int8x16_t vy0 = vld1q_s8(qy0);\n                    const int8x16_t vy1 = vld1q_s8(qy1);\n                    qy0 += 16;\n                    qy1 += 16;\n\n                    const int32x4_t block_scale = {\n                        x0->scales[blk],\n                        x0->scales[blk],\n                        x1->scales[blk],\n                        x1->scales[blk],\n                    };\n\n                    // calculate four results at once with outer product\n                    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));\n                    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), 
vreinterpretq_s64_s8(vx1[k])));\n                    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));\n                    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));\n                    int32x4_t vr = vdupq_n_s32(0);\n                    vr = vmmlaq_s32(vr, vx_l, vy_l);\n                    vr = vmmlaq_s32(vr, vx_h, vy_h);\n\n                    // apply block scale, will NOT overflow\n                    // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits\n                    visum = vmlaq_s32(visum, vr, block_scale);\n                }\n            }\n\n            // adjust bias, apply superblock scale\n            {\n                int32_t bias[4];\n                // NEON doesn't support int16 dot product, fallback to separated mul and add\n                const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);\n                const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);\n\n                int8x16_t scales_s8 = vld1q_s8(x0->scales);\n                const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};\n                scales_s8 = vld1q_s8(x1->scales);\n                const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};\n\n                int32x4_t prod;\n                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),\n                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),\n                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),\n                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));\n                bias[0] = vaddvq_s32(prod);\n                prod = 
vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),\n                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),\n                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),\n                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));\n                bias[1] = vaddvq_s32(prod);\n                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),\n                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),\n                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),\n                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));\n                bias[2] = vaddvq_s32(prod);\n                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),\n                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),\n                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),\n                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));\n                bias[3] = vaddvq_s32(prod);\n\n                const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);\n\n                const float32x4_t superblock_scale = {\n                    GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,\n                    GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,\n                    GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,\n                    GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,\n                };\n\n                visum = vsubq_s32(visum, vibias);\n      
          vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);\n            }\n        }\n\n        // vfsum = ABCD -> ACBD\n        // AC -> s, BD -> (s+bs)\n        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));\n        vst1_f32(s,      vget_low_f32 (vfsum));\n        vst1_f32(s + bs, vget_high_f32(vfsum));\n\n        return;\n    }\n#endif\n\n#ifdef __ARM_FEATURE_SVE\n    float sum = 0;\n    svuint8_t m4b = svdup_n_u8(0xf);\n    svint32_t vzero = svdup_n_s32(0);\n    svuint8_t mone = svdup_n_u8(0x30);\n    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;\n    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q6 = x[i].ql;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const int8_t * GGML_RESTRICT scale = x[i].scales;\n\n        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);\n        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);\n        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);\n        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));\n        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));\n        const svint64_t prod = svdup_n_s64(0);\n        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),\n                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));\n        int32_t isum = 0;\n\n        switch (vector_length) {\n            case 128:\n                {\n                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);\n                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);\n                    svint32_t isum_tmp = svdup_n_s32(0);\n                    for (int j = 0; j < QK_K/128; ++j) {\n     
                   svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);\n                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);\n                        qh += 32;\n                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);\n                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);\n                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);\n                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);\n                        q6 += 64;\n                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);\n                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);\n                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);\n                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);\n                        q8 += 64;\n\n                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));\n                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));\n                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));\n                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));\n                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));\n                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));\n                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));\n                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));\n                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);\n                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);\n                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, 
q8bytes_3), scale[2]);\n                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);\n\n                        scale += 4;\n                        q8bytes_1 = svld1_s8(pg8_16, q8);\n                        q8bytes_2 = svld1_s8(pg8_16, q8+16);\n                        q8bytes_3 = svld1_s8(pg8_16, q8+32);\n                        q8bytes_4 = svld1_s8(pg8_16, q8+48);\n                        q8 += 64;\n\n                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);\n                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);\n                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));\n                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));\n                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));\n                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));\n                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));\n                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));\n                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);\n                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);\n                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);\n                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);\n                        scale += 4;\n                    }\n                    isum += svaddv_s32(pg32_4, isum_tmp);\n                    sum += d_all * y[i].d * (isum - 32 * isum_mins);\n                }\n                break;\n            case 
256:\n            case 512:\n                {\n                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);\n                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);\n                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);\n                    svint32_t isum_tmp = svdup_n_s32(0);\n                    for (int j = 0; j < QK_K/128; j++) {\n                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);\n                        qh += 32;\n                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);\n                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);\n                        q6 += 64;\n                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);\n                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);\n                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);\n                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);\n                        q8 += 128;\n                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));\n                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));\n                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);\n                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));\n                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));\n                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));\n                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));\n                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));\n\n                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);\n                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);\n                        
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);\n                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);\n                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);\n                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);\n                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);\n                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);\n                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);\n                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);\n                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);\n                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);\n                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));\n                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));\n                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));\n                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));\n\n                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);\n                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);\n                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);\n                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);\n                        scale += 8;\n                    }\n                    isum += svaddv_s32(pg32_8, isum_tmp);\n                    sum += d_all * y[i].d * (isum - 32 * isum_mins);\n                }\n                break;\n            default:\n                
assert(false && \"Unsupported vector length\");\n                break;\n        }\n    }\n\n    *s = sum;\n\n#elif __ARM_NEON\n    float sum = 0;\n\n    const uint8x16_t m4b = vdupq_n_u8(0xF);\n    const int32x4_t  vzero = vdupq_n_s32(0);\n    //const int8x16_t  m32s = vdupq_n_s8(32);\n\n    const uint8x16_t mone = vdupq_n_u8(3);\n\n    ggml_int8x16x4_t q6bytes;\n    ggml_uint8x16x4_t q6h;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q6 = x[i].ql;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const int8_t * GGML_RESTRICT scale = x[i].scales;\n\n        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);\n        const int8x16_t scales = vld1q_s8(scale);\n        const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};\n\n        const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),\n                                                   vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),\n                                         vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),\n                                                   vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));\n        int32_t isum_mins = vaddvq_s32(prod);\n\n        int32_t isum = 0;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n\n            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;\n            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;\n            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;\n\n            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);\n            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);\n            uint8x16_t shifted = 
vshrq_n_u8(qhbits.val[0], 2);\n            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);\n            shifted = vshrq_n_u8(qhbits.val[1], 2);\n            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);\n\n            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);\n            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);\n            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);\n            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);\n            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));\n            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));\n            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));\n            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));\n\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +\n                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +\n                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +\n                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];\n\n            scale += 4;\n\n            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;\n\n            shifted = vshrq_n_u8(qhbits.val[0], 4);\n            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);\n            shifted = vshrq_n_u8(qhbits.val[1], 4);\n            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);\n            shifted = vshrq_n_u8(qhbits.val[0], 6);\n            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);\n            shifted = 
vshrq_n_u8(qhbits.val[1], 6);\n            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);\n\n            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);\n            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);\n            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);\n            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);\n            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));\n            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));\n            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));\n            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));\n\n            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +\n                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +\n                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +\n                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];\n            scale += 4;\n        }\n        //sum += isum * d_all * y[i].d;\n        sum += d_all * y[i].d * (isum - 32 * isum_mins);\n\n    }\n    *s = sum;\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n#if defined (__ARM_NEON)\nstatic const int8_t keven_signs_q2xs[1024] = {\n     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,\n     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, 
-1, -1, -1,  1,  1,  1,  1, -1,\n     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,\n     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,\n     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,\n     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,\n     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,\n     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,\n     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,\n     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,\n     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,\n     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,\n     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,\n     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,\n     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,\n     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,\n     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, 
-1,  1,  1,  1,  1, -1, -1,\n     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,\n     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,\n     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,\n     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,\n     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,\n     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,\n     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,\n     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,\n     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,\n     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,\n     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,\n     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,\n     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,\n     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,\n     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, 
-1, -1, -1, -1, -1, -1,\n};\n#endif\n\nvoid ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__ARM_NEON)\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    uint32_t aux32[4];\n    const uint8_t * aux8 = (const uint8_t *)aux32;\n\n    ggml_int8x16x4_t q2u;\n    ggml_int8x16x4_t q2s;\n    ggml_int8x16x4_t q8b;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        float sumf1 = 0, sumf2 = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;\n            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;\n            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));\n            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));\n            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));\n            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));\n            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));\n            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + 
((aux32[1] >> 21) & 127))));\n            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >>  7) & 127))));\n            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));\n            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);\n            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);\n            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);\n            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);\n            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);\n            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);\n            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));\n            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));\n        }\n        sumf += d*(sumf1 + sumf2);\n    }\n    *s = 0.25f * sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__ARM_NEON)\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    ggml_int8x16x4_t q2u;\n    ggml_int8x16x4_t q2s;\n    ggml_int8x16x4_t q8b;\n\n    int32x4x4_t scales32;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = 
x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        const uint8x8_t scales8 = vld1_u8(x[i].scales);\n        const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));\n        const uint8x8_t scales_h = vshr_n_u8(scales8, 4);\n        uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));\n        scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));\n        const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));\n        const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));\n        scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));\n        scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));\n        scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));\n        scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));\n        int32x4_t sumi = vdupq_n_s32(0);\n        for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {\n            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;\n            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));\n            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));\n            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));\n            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));\n            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));\n            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));\n            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void 
*)(signs64 + (q2[5] >> 9))));\n            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));\n            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);\n            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);\n            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);\n            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);\n            const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);\n            const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);\n            const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);\n            const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);\n            const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));\n            sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);\n            q2 += 8;\n        }\n        sumf += d*vaddvq_s32(sumi);\n    }\n    *s = 0.125f * sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__ARM_NEON)\n\n   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n   };\n\n    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 
0x20, 0x40, 0x80,};\n\n    const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);\n    const uint8x16_t        mask2 = vld1q_u8(k_mask2);\n    const uint8x16_t m1 = vdupq_n_u8(1);\n    const int32x4_t vzero = vdupq_n_s32(0);\n\n    uint8x16x2_t vs;\n    ggml_int8x16x4_t q2s;\n    ggml_int8x16x4_t q8b;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        int sumi1 = 0, sumi2 = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;\n            q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))),\n                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300)))));\n            q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))),\n                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300)))));\n            q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))),\n                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300)))));\n            q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))),\n                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));\n            qs += 8;\n\n            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));\n            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);\n            vs.val[0] = 
vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);\n            vs.val[0] = vceqq_u8(vs.val[0], mask2);\n            vs.val[1] = vceqq_u8(vs.val[1], mask2);\n\n            q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);\n            q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);\n\n            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));\n            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);\n            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);\n            vs.val[0] = vceqq_u8(vs.val[0], mask2);\n            vs.val[1] = vceqq_u8(vs.val[1], mask2);\n\n            signs += 4;\n\n            q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]);\n            q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]);\n\n            const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]);\n            const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]);\n            const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]);\n            const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]);\n\n            sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf));\n            sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >>  4));\n            sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf));\n            sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >>  4));\n        }\n        sumf += d*(sumi1 + sumi2);\n    }\n\n    *s = 0.125f * sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n\n}\n\nvoid ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 
0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__ARM_NEON)\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    uint32_t aux32[2];\n\n    ggml_int8x16x4_t q3s;\n    ggml_int8x16x4_t q8b;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        float sumf1 = 0, sumf2 = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;\n            memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);\n            const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);\n            const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);\n            const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);\n            const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);\n            q3 += 16;\n            q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >>  7) & 127))));\n            q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));\n            q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 
127))));\n            q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));\n            q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));\n            q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));\n            q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));\n            q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));\n            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);\n            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);\n            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));\n            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));\n        }\n        sumf += d*(sumf1 + sumf2);\n    }\n    *s = 0.5f * sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__ARM_NEON)\n\n    typedef union {\n        uint16x8_t vec_index;\n        uint16_t   index[8];\n    } vec_index_t;\n\n   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n   };\n\n    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 
0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};\n\n    static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};\n\n    const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);\n    const uint8x16_t        mask2 = vld1q_u8(k_mask2);\n\n    const int16x8_t  hshift = vld1q_s16(k_shift);\n    const uint16x8_t m256   = vdupq_n_u16(256);\n    const uint8x16_t m1     = vdupq_n_u8(1);\n\n    uint8x16x2_t vs;\n    ggml_int8x16x4_t q3s;\n    ggml_int8x16x4_t q8b;\n    vec_index_t idx;\n\n    uint32_t scales32[2];\n    const uint8_t * scales8 = (const uint8_t *)scales32;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n\n        memcpy(scales32, x[i].scales, 4);\n        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;\n        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;\n\n        int sumi1 = 0, sumi2 = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;\n\n            const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;\n            idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));\n            const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],\n                                                        iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);\n            const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],\n                                                        iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);\n            idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), 
vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));\n            const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],\n                                                        iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);\n            const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],\n                                                        iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);\n\n\n            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));\n            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);\n            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);\n            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);\n            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);\n\n            q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));\n            q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));\n\n            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));\n            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);\n            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);\n            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);\n            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);\n\n            signs += 4;\n\n            q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));\n            q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));\n\n            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);\n            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], 
q8b.val[3]);\n\n            sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];\n            sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];\n        }\n        sumf += d*(sumi1 + sumi2);\n    }\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __ARM_NEON\n\n    ggml_int8x16x4_t q1b;\n    ggml_int8x16x4_t q8b;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint16_t * qh = x[i].qh;\n\n        int sumi1 = 0, sumi2 = 0, sumi3 = 0;\n\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n\n            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),\n                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));\n            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),\n                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));\n            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),\n                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));\n            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),\n                                     vld1_s8((const int8_t *)(iq1s_grid + 
(qs[7] | ((qh[ib+1] >> 1) & 0x700)))));\n            qs += 8;\n\n            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;\n\n            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);\n            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);\n\n            const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;\n            const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;\n            sumi1 += vaddvq_s32(p1) * ls1;\n            sumi2 += vaddvq_s32(p2) * ls2;\n            sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)\n                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);\n\n        }\n\n        sumf += y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_m * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    iq1m_scale_t scale;\n\n#if defined __ARM_NEON\n    const int32x4_t mask  = vdupq_n_s32(0x7);\n    const int32x4_t mone  = vdupq_n_s32(1);\n    const int32x4_t mzero = vdupq_n_s32(0);\n\n    ggml_int8x16x4_t deltas;\n    deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));\n    deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));\n    deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));\n    deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));\n\n    ggml_int8x16x4_t q1b;\n    
ggml_int8x16x4_t q8b;\n\n    uint32_t aux32;\n    const uint8_t * aux8 = (const uint8_t *)&aux32;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint8_t  * qh = x[i].qh;\n        const uint16_t * sc = (const uint16_t *)x[i].scales;\n\n        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n\n        int32x4_t sumi1 = mzero;\n        int32x4_t sumi2 = mzero;\n\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n\n            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),\n                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));\n            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),\n                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));\n            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),\n                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));\n            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),\n                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));\n\n            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;\n\n            const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));\n            const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));\n            const int32x4_t p12 = vpaddq_s32(p1, p2);\n\n            const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that\n            aux32 = ((qh32[0] >> 3) 
& 0x01010101) | ((qh32[0] >> 6) & 0x02020202);\n\n            const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));\n            const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));\n            const int32x4_t p34 = vpaddq_s32(p3, p4);\n\n            int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);\n\n            scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);\n\n            sumi1 = vmlaq_s32(sumi1, scales_4, p12);\n            sumi2 = vmlaq_s32(sumi2, scales_4, p34);\n\n            qs += 8; qh += 4;\n\n        }\n\n        sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(scale);\n    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK4_NL == 0);\n    static_assert(QK4_NL == QK8_0, \"QK4_NL and QK8_0 must be the same\");\n\n    const block_iq4_nl * GGML_RESTRICT x = vx;\n    const block_q8_0   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK4_NL;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined __ARM_NEON\n    const int8x16_t values = vld1q_s8(kvalues_iq4nl);\n    const uint8x16_t m4b = vdupq_n_u8(0x0f);\n    uint8x16x2_t q4bits;\n    int8x16x4_t q4b;\n    int8x16x4_t q8b;\n    int32x4_t prod_1, prod_2;\n\n    for (; ib + 1 < nb; ib += 2) {\n\n        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);\n        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);\n        
q8b.val[0]    = vld1q_s8(y[ib + 0].qs);\n        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);\n        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);\n        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);\n\n        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));\n        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));\n        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));\n        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));\n\n        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);\n        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);\n\n        sumf +=\n            GGML_CPU_FP16_TO_FP32(x[ib+0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +\n            GGML_CPU_FP16_TO_FP32(x[ib+1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);\n    }\n\n#endif\n    for (; ib < nb; ++ib) {\n        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);\n        int sumi1 = 0, sumi2 = 0;\n        for (int j = 0; j < QK4_NL/2; ++j) {\n            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];\n            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];\n        }\n        sumf += d * (sumi1 + sumi2);\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_K == 0);\n\n    const block_iq4_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __ARM_NEON\n    const int8x16_t values = vld1q_s8(kvalues_iq4nl);\n    const uint8x16_t m4b = vdupq_n_u8(0x0f);\n    ggml_uint8x16x2_t q4bits;\n 
   ggml_int8x16x4_t q4b;\n    ggml_int8x16x4_t q8b;\n    int32x4_t prod_1, prod_2;\n\n    float sumf = 0;\n\n    for (int ibl = 0; ibl < nb; ++ibl) {\n\n        const int8_t  * q8 = y[ibl].qs;\n        const uint8_t * q4 = x[ibl].qs;\n        uint16_t h = x[ibl].scales_h;\n\n        int sumi1 = 0, sumi2 = 0;\n        for (int ib = 0; ib < QK_K/64; ++ib) {\n\n            q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;\n            q8b    = ggml_vld1q_s8_x4(q8); q8 += 64;\n\n            q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));\n            q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));\n            q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));\n            q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));\n\n            prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);\n            prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);\n\n            int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;\n            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;\n            h >>= 4;\n            sumi1 += vaddvq_s32(prod_1) * ls1;\n            sumi2 += vaddvq_s32(prod_2) * ls2;\n\n        }\n\n        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n"
  },
  {
    "path": "src/ggml-cpu/arch/arm/repack.cpp",
    "content": "#define GGML_COMMON_IMPL_CPP\n#define GGML_COMMON_DECL_CPP\n#include \"ggml-common.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"simd-mappings.h\"\n#include \"traits.h\"\n\n#include <cmath>\n#include <cstring>\n#include <cassert>\n#include <cstdlib> // for qsort\n#include <cstdio>  // for GGML_ASSERT\n\n#define GGML_CPU_CLANG_WORKAROUND\n#include \"../../repack.h\"\n\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Woverlength-strings\"\n#endif\n\n#define UNUSED GGML_UNUSED\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))\n// Helper for decoding scales and mins of Q4_K and Q5_K block formats\nstatic inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {\n    constexpr uint32_t kmask1 = 0x3f3f3f3f;\n    constexpr uint32_t kmask2 = 0x0f0f0f0f;\n    constexpr uint32_t kmask3 = 0x03030303;\n    constexpr uint8_t  scales_size = 12;\n\n    uint32_t sm[3];\n    memcpy(sm, scales_in, scales_size);\n\n    const uint32_t   mins_0_3 = sm[1] & kmask1;\n    const uint32_t   mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);\n    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };\n\n    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));\n\n    uint32_t scales_u32[2];\n    scales_u32[0] = sm[0] & kmask1;\n    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);\n    memcpy(out_scales, scales_u32, 8);\n}\n#endif\n\nvoid ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;\n\n#if defined(__ARM_NEON)\n    float32x4_t srcv[4][8];\n    float id[4];\n\n    for (int i = 0; i < nb; i++) {\n        
float32x4_t asrcv[8];\n        float32x4_t amaxv[8];\n\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);\n            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);\n\n            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);\n            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);\n            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);\n\n            const float amax = vmaxvq_f32(amaxv[0]);\n\n            const float d = amax / ((1 << 7) - 1);\n            id[row_iter] = d ? 1.0f / d : 0.0f;\n\n            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);\n        }\n\n        for (int j = 0; j < 8; j++) {\n            float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);\n            int32x4_t vi = vcvtnq_s32_f32(v);\n            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);\n            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);\n            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);\n            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);\n\n            v = vmulq_n_f32(srcv[1][j], id[1]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);\n            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);\n            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);\n            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);\n\n            v = vmulq_n_f32(srcv[2][j], id[2]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);\n            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);\n            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);\n            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);\n\n            v = vmulq_n_f32(srcv[3][j], id[3]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 
0);\n            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);\n            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);\n            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);\n        }\n    }\n#else\n    UNUSED(nb);\n    UNUSED(y);\n    ggml_quantize_mat_q8_0_4x4_generic(x, vy, k);\n#endif\n}\n\nvoid ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;\n\n#if defined(__ARM_NEON)\n    float32x4_t srcv[4][8];\n    float id[4];\n\n    for (int i = 0; i < nb; i++) {\n        float32x4_t asrcv[8];\n        float32x4_t amaxv[8];\n\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);\n            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);\n\n            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);\n            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);\n            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);\n\n            const float amax = vmaxvq_f32(amaxv[0]);\n\n            const float d = amax / ((1 << 7) - 1);\n            id[row_iter] = d ? 
1.0f / d : 0.0f;\n\n            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);\n        }\n\n        for (int j = 0; j < 4; j++) {\n            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);\n            int32x4_t vi = vcvtnq_s32_f32(v);\n            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);\n            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);\n            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);\n            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);\n            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);\n            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);\n            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);\n            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);\n\n            v = vmulq_n_f32(srcv[1][2 * j], id[1]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);\n            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);\n            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);\n            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);\n            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);\n            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);\n            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);\n            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);\n\n            v = vmulq_n_f32(srcv[2][2 * j], id[2]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);\n            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);\n            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);\n            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);\n            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);\n            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);\n    
        y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);\n            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);\n\n            v = vmulq_n_f32(srcv[3][2 * j], id[3]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);\n            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);\n            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);\n            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);\n            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);\n            vi = vcvtnq_s32_f32(v);\n            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);\n            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);\n            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);\n            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);\n        }\n    }\n\n#else\n    UNUSED(nb);\n    UNUSED(y);\n    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);\n#endif\n}\n\nvoid ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;\n\n    for (int c = 0; c < nc; c += ncols_interleaved) {\n        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n        float32x4_t acc = vdupq_n_f32(0);\n        for (int b = 0; b < nb; b++) {\n            int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);\n            int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);\n            int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);\n            int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);\n            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);\n\n            int8x16_t a0 = vld1q_s8(a_ptr->qs);\n            int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);\n            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);\n\n            int32x4_t ret = vdupq_n_s32(0);\n\n            ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);\n            ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);\n            ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);\n            ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);\n\n            ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);\n            ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);\n            ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);\n            ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);\n\n            acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),\n                            vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));\n            a_ptr++;\n            b_ptr++;\n        }\n        vst1q_f32(s, acc);\n        s += ncols_interleaved;\n    }\n    return;\n#endif // #if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;\n\n    for (int c = 0; c < nc; c += ncols_interleaved) {\n        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n        float32x4_t acc = vdupq_n_f32(0);\n        for (int b = 0; b < nb; b++) {\n            int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);\n            int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);\n            int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);\n            int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);\n            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);\n\n            int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);\n            int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);\n            int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);\n            int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);\n            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);\n\n            int32x4_t ret0 = vdupq_n_s32(0);\n            int32x4_t ret1 = vdupq_n_s32(0);\n\n            ret0 = vdotq_s32(ret0, b0 << 4, 
a0);\n            ret1 = vdotq_s32(ret1, b1 << 4, a0);\n            ret0 = vdotq_s32(ret0, b2 << 4, a1);\n            ret1 = vdotq_s32(ret1, b3 << 4, a1);\n\n            ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);\n            ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);\n            ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);\n            ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);\n\n            int32x4_t ret = vpaddq_s32(ret0, ret1);\n\n            acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),\n                    vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));\n            a_ptr++;\n            b_ptr++;\n        }\n        vst1q_f32(s, acc);\n        s += ncols_interleaved;\n    }\n    return;\n#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__)\n#if defined(__ARM_FEATURE_SVE)\n    if (ggml_cpu_get_sve_cnt() == QK8_0) {\n        const void * b_ptr = vx;\n        const void * a_ptr = vy;\n        float * res_ptr = s;\n\n        __asm__ __volatile__(\n            \"ptrue p0.b\\n\"\n            \"add %x[b_ptr], %x[b_ptr], #0x10\\n\"\n            \"1:\"  // Column loop\n            \"add x22, %x[a_ptr], #0x2\\n\"\n            \"mov z31.b, #0x0\\n\"\n            \"mov x21, %x[nb]\\n\"\n            \"2:\"  // Block loop\n            \"ld1b { z30.b }, p0/Z, [%x[b_ptr]]\\n\"\n            \"ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\\n\"\n            \"mov z28.s, #0x0\\n\"\n            \"mov z27.s, #0x0\\n\"\n            \"ld1rd { z26.d }, p0/Z, [x22]\\n\"\n            \"ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\\n\"\n            \"sub x20, x22, #0x2\\n\"\n            \"sub x21, x21, #0x1\\n\"\n            \"ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\\n\"\n            \"ld1rd { z23.d }, p0/Z, [x22, #8]\\n\"\n            \"lsl z22.b, z30.b, #0x4\\n\"\n            \"lsl z16.b, z29.b, #0x4\\n\"\n            \"and z30.b, z30.b, #0xf0\\n\"\n            \"and z29.b, z29.b, #0xf0\\n\"\n            \"ld1rd { z21.d }, p0/Z, [x22, #16]\\n\"\n            \"ld1rd { z20.d }, p0/Z, [x22, #24]\\n\"\n            \"lsl z19.b, z25.b, #0x4\\n\"\n            \"and z25.b, z25.b, #0xf0\\n\"\n            \"ld1rh { z17.h }, p0/Z, [x20]\\n\"\n            \"ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\\n\"\n            \"sdot z28.s, z22.b, z26.b\\n\"\n            \"sdot z27.s, z16.b, z26.b\\n\"\n            \"lsl z16.b, z24.b, #0x4\\n\"\n            \"add x22, x22, #0x22\\n\"\n            \"and z24.b, z24.b, #0xf0\\n\"\n            \"add %x[b_ptr], %x[b_ptr], #0x90\\n\"\n            \"fcvt z17.s, p0/m, z17.h\\n\"\n            \"fcvt z18.s, p0/m, z18.h\\n\"\n            \"sdot z28.s, z19.b, z23.b\\n\"\n            \"sdot z27.s, z16.b, z23.b\\n\"\n            \"fmul z18.s, z18.s, 
z17.s\\n\"\n            \"sdot z28.s, z30.b, z21.b\\n\"\n            \"sdot z27.s, z29.b, z21.b\\n\"\n            \"sdot z28.s, z25.b, z20.b\\n\"\n            \"sdot z27.s, z24.b, z20.b\\n\"\n            \"uzp1 z17.s, z28.s, z27.s\\n\"\n            \"uzp2 z16.s, z28.s, z27.s\\n\"\n            \"add z17.s, z17.s, z16.s\\n\"\n            \"asr z17.s, z17.s, #0x4\\n\"\n            \"scvtf z17.s, p0/m, z17.s\\n\"\n            \"fmla z31.s, p0/M, z17.s, z18.s\\n\"\n            \"cbnz x21, 2b\\n\"\n            \"sub %x[nc], %x[nc], #0x8\\n\"\n            \"st1w { z31.s }, p0, [%x[res_ptr]]\\n\"\n            \"add %x[res_ptr], %x[res_ptr], #0x20\\n\"\n            \"cbnz %x[nc], 1b\\n\"\n            : [b_ptr] \"+&r\" (b_ptr), [res_ptr] \"+&r\" (res_ptr), [nc] \"+&r\" (nc)\n            : [a_ptr] \"r\" (a_ptr), [nb] \"r\" (nb)\n            : \"memory\", \"p0\", \"x20\", \"x21\", \"x22\", \"z16\", \"z17\", \"z18\", \"z19\", \"z20\", \"z21\", \"z22\", \"z23\", \"z24\", \"z25\", \"z26\", \"z27\", \"z28\", \"z29\", \"z30\", \"z31\"\n        );\n        return;\n    }\n#endif // #if defined(__ARM_FEATURE_SVE)\n\n#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)\n    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    float * res_ptr = s;\n\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);\n\n        float32x4_t sumf = vdupq_n_f32(0);\n        for (int l = 0; l < nb; l++) {\n            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);\n            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);\n            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);\n            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);\n\n            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);\n            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);\n            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);\n            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);\n            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);\n            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);\n            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);\n            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);\n\n            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);\n            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);\n\n            int32x4_t sumi = vdupq_n_s32(0);\n            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);\n            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);\n            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);\n            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);\n            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);\n            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);\n            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);\n            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);\n\n            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));\n            float32x4_t b_d = vcvt_f32_f16(vld1_f16((const 
float16_t *)b_ptr[l].d));\n            float32x4_t d = a_d * b_d;\n\n            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));\n        }\n\n        vst1q_f32(res_ptr + x * 4, sumf);\n    }\n    return;\n#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)\n    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    float * res_ptr = s;\n\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);\n\n        float32x4_t sumf = vdupq_n_f32(0);\n        for (int l = 0; l < nb; l++) {\n            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);\n            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);\n            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);\n            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);\n\n            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);\n            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);\n            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);\n            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);\n            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);\n            int8x16_t b_2_lo = 
vqtbl1q_s8(kvalues, b_2 & 0x0F);\n            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);\n            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);\n\n            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);\n            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);\n\n            int32x4_t sumi = vdupq_n_s32(0);\n            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);\n            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);\n            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);\n            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);\n            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);\n            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);\n            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);\n            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);\n\n            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));\n            float32x4_t b_d = {\n                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),\n                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),\n                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),\n                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),\n            };\n            float32x4_t d = a_d * b_d;\n\n            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));\n        }\n\n        vst1q_f32(res_ptr + x * 4, sumf);\n    }\n    return;\n#endif // #if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)\n    ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    col_groups = ncols_interleaved / 4; // 0123 and 4567\n    const uint8x16_t m4b        = vdupq_n_u8(0x0f);\n\n    // 1x8 tile = 2 x 4\n    float32x4_t acc_f32[col_groups];\n\n    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;\n\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n\n        for (int i = 0; i < col_groups; i++) {\n            acc_f32[i] = vdupq_n_f32(0);\n        }\n\n        for (int b = 0; b < nb; b++) {\n            float32x4_t q4_d_0        = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3\n            float32x4_t q4_d_1        = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7\n            float32x4_t q8_d          = vdupq_n_f32(q8_ptr[b].d);\n            float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);\n            float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);\n            float32x4_t q4_dmin_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3\n            float32x4_t q4_dmin_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7\n            float32x4_t sb_min_0123   = vmulq_f32(q4_dmin_0, q8_d);\n            float32x4_t sb_min_4567   = 
vmulq_f32(q4_dmin_1, q8_d);\n\n            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567\n            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };\n            int32x4_t acc_lo[col_groups];\n            int32x4_t acc_hi[col_groups];\n\n            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block\n            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));\n            int16_t         bsums_arr[8];\n            vst1q_s16(bsums_arr, bsums);\n            for (int sb = 0; sb < QK_K / 64; sb++) {\n                for (int i = 0; i < col_groups; i++) {\n                    acc_lo[i] = vdupq_n_s32(0);\n                    acc_hi[i] = vdupq_n_s32(0);\n                }\n                // Need scales for the low and high nibbles\n                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                int16x8_t q4sb_mins[2];\n                int16x8_t q4sb_scales[2];\n                for (int i = 0; i < 2; i++) {\n                    int8_t    aux_q4sb[8];\n                    const int offset = sb * 24 + i * 12;\n                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);\n                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));\n                }\n\n                int8x16_t q8_qs[64 / 16];\n                for (int i = 0; i < 64 / 16; i++) {\n                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);\n                }\n\n                for (int c = 0; c < col_groups; c++) {\n                    uint8x16_t q4_cols[8];\n                    for (int i = 0; i < 8; i++) {\n                        q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);\n                    }\n\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], 
vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);\n\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);\n                }\n\n                // Scales\n                // row c0123 blk0 and blk1\n                const int16x4_t   sc_0123_lo = vget_low_s16(q4sb_scales[0]);\n                const int16x4_t   sc_0123_hi = 
vget_low_s16(q4sb_scales[1]);\n                const float32x4_t sumf_0123  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),\n                                                                       vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));\n                acc_f32[0]                   = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);\n                // row c4567 blk0 and blk1\n                const int16x4_t   sc_4567_lo = vget_high_s16(q4sb_scales[0]);\n                const int16x4_t   sc_4567_hi = vget_high_s16(q4sb_scales[1]);\n                const float32x4_t sumf_4567  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),\n                                                                       vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));\n                acc_f32[1]                   = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);\n\n                // Bias Correction\n                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);\n                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);\n\n                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));\n                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));\n                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));\n                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));\n            }  // for sb\n\n            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);\n            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);\n        }  // for b\n\n        int base = x * ncols_interleaved;\n        vst1q_f32(s + base, acc_f32[0]);\n        vst1q_f32(s + base + 4, acc_f32[1]);\n    }  // for x\n    return;\n#endif  // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    
ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q4_K_8x8_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    col_pairs = ncols_interleaved / 2;\n    const uint8x16_t m4b       = vdupq_n_u8(0x0f);\n\n    // 1x8 tile = 2 x 4\n    float32x4_t acc_f32[ncols_interleaved / 4];\n\n    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;\n\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n\n        for (int i = 0; i < ncols_interleaved / 4; i++) {\n            acc_f32[i] = vdupq_n_f32(0);\n        }\n\n        for (int b = 0; b < nb; b++) {\n            float32x4_t q4_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3\n            float32x4_t q4_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7\n            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);\n            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);\n            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);\n            float32x4_t q4_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3\n            float32x4_t q4_dmin_1  = 
vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7\n            float32x4_t sb_min_0   = vmulq_f32(q4_dmin_0, q8_d);\n            float32x4_t sb_min_1   = vmulq_f32(q4_dmin_1, q8_d);\n\n            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567\n            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };\n            // 2 sb each iteration\n            int32x4_t acc_lo[col_pairs];\n            int32x4_t acc_hi[col_pairs];\n\n            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block\n            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));\n            int16_t         bsums_arr[8];\n            vst1q_s16(bsums_arr, bsums);\n            for (int sb = 0; sb < QK_K / 64; sb++) {\n                for (int i = 0; i < col_pairs; i++) {\n                    acc_lo[i] = vdupq_n_s32(0);\n                    acc_hi[i] = vdupq_n_s32(0);\n                }\n                // Need scales for the low and high nibbles\n                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later\n                int16x8_t q4sb_scales[2];\n                for (int i = 0; i < 2; i++) {\n                    int8_t    aux_q4sb[8];\n                    const int offset = sb * 24 + i * 12;\n                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);\n                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));\n                }\n\n                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;\n\n                // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns\n                // but still need the qs to use the low and hi bits from q4\n                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;\n                int8x16_t      q8_qs[8];\n                for (int i = 0; i < 8; 
i++) {\n                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));\n                }\n\n                // Q4s columns iterated in pairs (01, 23, 45, 67)\n                for (int cp = 0; cp < col_pairs; cp++) {\n                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);\n                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);\n                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);\n                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);\n\n                    acc_lo[cp] =\n                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]);  // 0 .. 7\n                    acc_lo[cp] =\n                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]);  // 8 ..15\n                    acc_lo[cp] =\n                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]);  // 16..23\n                    acc_lo[cp] =\n                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]);  // 24..31\n\n                    acc_hi[cp] =\n                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]);  // 32..39\n                    acc_hi[cp] =\n                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]);  // 40..47\n                    acc_hi[cp] =\n                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]);  // 48..55\n                    acc_hi[cp] =\n                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]);  // 56..63\n                }\n\n                // Iterates over a pair of column pairs (4 columns) to use a single 128 register\n                // p = 0 -> 0123  p2 -> 4567\n                for 
(int i = 0, p = 0; p < col_pairs; i++, p += 2) {\n                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);\n                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);\n                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;\n\n                    // 0123 or 4567\n                    float32x4_t sumf_0 =\n                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));\n                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);\n\n                    float32x4_t sumf_1 =\n                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));\n                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);\n                }\n\n                // Multiply Acc bsum + mins\n                // Each pair of subblocks share the same bsums\n                // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).\n                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);\n                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);\n\n                // cols 0-3 bias\n                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));\n                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));\n\n                // cols 4-7 bias\n                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));\n                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));\n            }  // for sb\n\n            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);\n            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);\n        }  // for b\n\n        int base = x * ncols_interleaved;\n        
vst1q_f32(s + base, acc_f32[0]);\n        vst1q_f32(s + base + 4, acc_f32[1]);\n    }  // for x\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q5_K_8x4_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 4;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    col_groups = ncols_interleaved / 4;  // 0123 and 4567\n    const uint8x16_t m4b        = vdupq_n_u8(0x0f);\n    const uint8x16_t mone       = vdupq_n_u8(1);\n    const uint8x16_t mtwo       = vdupq_n_u8(2);\n\n    // 1x8 tile = 2 x 4\n    float32x4_t acc_f32[col_groups];\n\n    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;\n\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);\n\n        for (int i = 0; i < col_groups; i++) {\n            acc_f32[i] = vdupq_n_f32(0);\n        }\n\n        for (int b = 0; b < nb; b++) {\n            float32x4_t q5_d_0        = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));      // d0 d1 d2 d3\n            float32x4_t q5_d_1        = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));  // d4 d5 d6 d7\n            float32x4_t q8_d       
   = vdupq_n_f32(q8_ptr[b].d);\n            float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);\n            float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);\n            float32x4_t q5_dmin_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));      // dmin 0..3\n            float32x4_t q5_dmin_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));  // dmin 4..7\n            float32x4_t sb_min_0123   = vmulq_f32(q5_dmin_0, q8_d);\n            float32x4_t sb_min_4567   = vmulq_f32(q5_dmin_1, q8_d);\n\n            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567\n            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };\n            int32x4_t acc_lo[col_groups];\n            int32x4_t acc_hi[col_groups];\n\n            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block\n            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));\n            int16_t         bsums_arr[8];\n            vst1q_s16(bsums_arr, bsums);\n\n            uint8x16_t qh[col_groups][8];\n            for (int c = 0; c < col_groups; c++) {\n                for (int i = 0; i < 8; i++) {\n                    qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);\n                }\n            }\n\n            for (int sb = 0; sb < QK_K / 64; sb++) {\n                for (int i = 0; i < col_groups; i++) {\n                    acc_lo[i] = vdupq_n_s32(0);\n                    acc_hi[i] = vdupq_n_s32(0);\n                }\n                // Need scales for the low and high nibbles\n                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                int16x8_t q5sb_mins[2];\n                int16x8_t q5sb_scales[2];\n                for (int i = 0; i < 2; i++) {\n                    int8_t    aux_q5sb[8];\n                    const int offset = sb * 24 + i * 12;\n                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], 
&q5sb_mins[i], aux_q5sb);\n                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));\n                }\n\n                int8x16_t q8_qs[4];\n                for (int i = 0; i < 4; i++) {\n                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);\n                }\n\n                for (int c = 0; c < col_groups; c++) {\n                    uint8x16_t q5_cols[8];\n                    uint8x16_t hbit_lo[8];\n                    uint8x16_t hbit_hi[8];\n                    int8x16_t  q5_lo[8];\n                    int8x16_t  q5_hi[8];\n\n                    for (int i = 0; i < 8; i++) {\n                        q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);\n                        hbit_lo[i] = vandq_u8(qh[c][i], mone);\n                        hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);\n                        qh[c][i]   = vshrq_n_u8(qh[c][i], 2);\n                        q5_lo[i]   = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));\n                        q5_hi[i]   = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));\n                    }\n\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);\n                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);\n\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], 
q8_qs[2], 1);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);\n                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);\n                }\n\n                // Scales\n                // row c0123 blk0 and blk1\n                const int16x4_t   sc_0123_lo = vget_low_s16(q5sb_scales[0]);\n                const int16x4_t   sc_0123_hi = vget_low_s16(q5sb_scales[1]);\n                const float32x4_t sumf_0123  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),\n                                                                       vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));\n                acc_f32[0]                   = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);\n                // row c4567 blk0 and blk1\n                const int16x4_t   sc_4567_lo = vget_high_s16(q5sb_scales[0]);\n                const int16x4_t   sc_4567_hi = vget_high_s16(q5sb_scales[1]);\n                const float32x4_t sumf_4567  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),\n                                                                       vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));\n                acc_f32[1]                   = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);\n\n                // Bias Correction\n                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);\n                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);\n\n                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));\n                bias_acc[0] = vmlal_s16(bias_acc[0], 
bsums_vec_hi, vget_low_s16(q5sb_mins[1]));\n                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));\n                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));\n            }  // for sb\n\n            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);\n            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);\n        }  // for b\n\n        int base = x * ncols_interleaved;\n        vst1q_f32(s + base, acc_f32[0]);\n        vst1q_f32(s + base + 4, acc_f32[1]);\n    }  // for x\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q5_K_8x8_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    col_pairs = ncols_interleaved / 2;\n    const uint8x16_t m4b       = vdupq_n_u8(0x0f);\n    const uint8x16_t mone      = vdupq_n_u8(1);\n    const uint8x16_t mtwo      = vdupq_n_u8(2);\n\n    // 1x8 tile = 2 x 4\n    float32x4_t acc_f32[ncols_interleaved / 4];\n\n    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;\n\n    for (int x = 0; x < nc / 
ncols_interleaved; x++) {\n        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);\n\n        for (int i = 0; i < ncols_interleaved / 4; i++) {\n            acc_f32[i] = vdupq_n_f32(0);\n        }\n\n        for (int b = 0; b < nb; b++) {\n            float32x4_t q5_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));      // d0 d1 d2 d3\n            float32x4_t q5_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));  // d4 d5 d6 d7\n            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);\n            float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);\n            float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);\n            float32x4_t q5_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));      // dmin 0..3\n            float32x4_t q5_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));  // dmin 4..7\n            float32x4_t sb_min_0   = vmulq_f32(q5_dmin_0, q8_d);\n            float32x4_t sb_min_1   = vmulq_f32(q5_dmin_1, q8_d);\n\n            // 2 sb each iteration\n            int32x4_t acc_lo[col_pairs];\n            int32x4_t acc_hi[col_pairs];\n\n            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block\n            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));\n            int16_t         bsums_arr[8];\n            vst1q_s16(bsums_arr, bsums);\n\n            // Load qh once per block and shift after each subblock\n            const uint8_t * qh_base = q5_ptr[b].qh;\n            uint8x16_t      qh[col_pairs][4];\n            for (int cp = 0; cp < col_pairs; cp++) {\n                qh[cp][0] = vld1q_u8(qh_base + 16 * cp);\n                qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);\n                qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);\n                qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);\n            }\n\n            for (int sb = 0; sb < QK_K / 
64; sb++) {\n                for (int i = 0; i < col_pairs; i++) {\n                    acc_lo[i] = vdupq_n_s32(0);\n                    acc_hi[i] = vdupq_n_s32(0);\n                }\n                // Need scales for the low and high nibbles\n                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                int16x8_t q5sb_mins[2];  // int16 as it's needed for the bias correction later\n                int16x8_t q5sb_scales[2];\n                for (int i = 0; i < 2; i++) {\n                    int8_t    aux_q5sb[8];\n                    const int offset = sb * 24 + i * 12;\n                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);\n                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));\n                }\n\n                const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;\n\n                // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns\n                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;\n                int8x16_t      q8_qs[8];\n                for (int i = 0; i < 8; i++) {\n                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));\n                }\n\n                // Q5s column pair loop unrolled\n                {\n                    // Cols 01\n                    uint8x16_t qs_0 = vld1q_u8(qs_base);\n                    uint8x16_t qs_1 = vld1q_u8(qs_base + 64);\n                    uint8x16_t qs_2 = vld1q_u8(qs_base + 128);\n                    uint8x16_t qs_3 = vld1q_u8(qs_base + 192);\n\n                    uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);\n                    uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);\n                    uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);\n                    uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);\n                    uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);\n                    uint8x16_t hbit_hi_1 = 
vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);\n                    uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);\n                    uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);\n\n                    qh[0][0] = vshrq_n_u8(qh[0][0], 2);\n                    qh[0][1] = vshrq_n_u8(qh[0][1], 2);\n                    qh[0][2] = vshrq_n_u8(qh[0][2], 2);\n                    qh[0][3] = vshrq_n_u8(qh[0][3], 2);\n\n                    acc_lo[0] = ggml_vdotq_s32(\n                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);\n                    acc_lo[0] = ggml_vdotq_s32(\n                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);\n                    acc_lo[0] = ggml_vdotq_s32(\n                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);\n                    acc_lo[0] = ggml_vdotq_s32(\n                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);\n                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),\n                                               q8_qs[4]);\n                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),\n                                               q8_qs[5]);\n                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),\n                                               q8_qs[6]);\n                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),\n                                               q8_qs[7]);\n\n                    // Cols 23\n                    qs_0 = vld1q_u8(qs_base + 16);\n                    qs_1 = vld1q_u8(qs_base + 80);\n                    qs_2 = 
vld1q_u8(qs_base + 144);\n                    qs_3 = vld1q_u8(qs_base + 208);\n\n                    hbit_lo_0 = vandq_u8(qh[1][0], mone);\n                    hbit_lo_1 = vandq_u8(qh[1][1], mone);\n                    hbit_lo_2 = vandq_u8(qh[1][2], mone);\n                    hbit_lo_3 = vandq_u8(qh[1][3], mone);\n                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);\n                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);\n                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);\n                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);\n\n                    qh[1][0] = vshrq_n_u8(qh[1][0], 2);\n                    qh[1][1] = vshrq_n_u8(qh[1][1], 2);\n                    qh[1][2] = vshrq_n_u8(qh[1][2], 2);\n                    qh[1][3] = vshrq_n_u8(qh[1][3], 2);\n\n                    acc_lo[1] = ggml_vdotq_s32(\n                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);\n                    acc_lo[1] = ggml_vdotq_s32(\n                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);\n                    acc_lo[1] = ggml_vdotq_s32(\n                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);\n                    acc_lo[1] = ggml_vdotq_s32(\n                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);\n                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),\n                                               q8_qs[4]);\n                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),\n                                               q8_qs[5]);\n                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),\n               
                                q8_qs[6]);\n                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),\n                                               q8_qs[7]);\n\n                    // Cols 45\n                    qs_0 = vld1q_u8(qs_base + 32);\n                    qs_1 = vld1q_u8(qs_base + 96);\n                    qs_2 = vld1q_u8(qs_base + 160);\n                    qs_3 = vld1q_u8(qs_base + 224);\n\n                    hbit_lo_0 = vandq_u8(qh[2][0], mone);\n                    hbit_lo_1 = vandq_u8(qh[2][1], mone);\n                    hbit_lo_2 = vandq_u8(qh[2][2], mone);\n                    hbit_lo_3 = vandq_u8(qh[2][3], mone);\n                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);\n                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);\n                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);\n                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);\n\n                    qh[2][0] = vshrq_n_u8(qh[2][0], 2);\n                    qh[2][1] = vshrq_n_u8(qh[2][1], 2);\n                    qh[2][2] = vshrq_n_u8(qh[2][2], 2);\n                    qh[2][3] = vshrq_n_u8(qh[2][3], 2);\n\n                    acc_lo[2] = ggml_vdotq_s32(\n                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);\n                    acc_lo[2] = ggml_vdotq_s32(\n                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);\n                    acc_lo[2] = ggml_vdotq_s32(\n                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);\n                    acc_lo[2] = ggml_vdotq_s32(\n                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);\n                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), 
hbit_hi_0)),\n                                               q8_qs[4]);\n                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),\n                                               q8_qs[5]);\n                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),\n                                               q8_qs[6]);\n                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),\n                                               q8_qs[7]);\n\n                    // Cols 67\n                    qs_0 = vld1q_u8(qs_base + 48);\n                    qs_1 = vld1q_u8(qs_base + 112);\n                    qs_2 = vld1q_u8(qs_base + 176);\n                    qs_3 = vld1q_u8(qs_base + 240);\n\n                    hbit_lo_0 = vandq_u8(qh[3][0], mone);\n                    hbit_lo_1 = vandq_u8(qh[3][1], mone);\n                    hbit_lo_2 = vandq_u8(qh[3][2], mone);\n                    hbit_lo_3 = vandq_u8(qh[3][3], mone);\n                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);\n                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);\n                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);\n                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);\n\n                    qh[3][0] = vshrq_n_u8(qh[3][0], 2);\n                    qh[3][1] = vshrq_n_u8(qh[3][1], 2);\n                    qh[3][2] = vshrq_n_u8(qh[3][2], 2);\n                    qh[3][3] = vshrq_n_u8(qh[3][3], 2);\n\n                    acc_lo[3] = ggml_vdotq_s32(\n                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);\n                    acc_lo[3] = ggml_vdotq_s32(\n                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);\n                    acc_lo[3] = 
ggml_vdotq_s32(\n                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);\n                    acc_lo[3] = ggml_vdotq_s32(\n                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);\n                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),\n                                               q8_qs[4]);\n                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),\n                                               q8_qs[5]);\n                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),\n                                               q8_qs[6]);\n                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),\n                                               q8_qs[7]);\n                }\n\n                // Prepare bsum vectors for bias computation\n                // Each pair of subblocks shares the same bsums\n                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);\n                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);\n\n                // Iterates over a pair of column pairs (4 columns) to use a single 128-bit register\n                // p = 0 -> cols 0123, p = 2 -> cols 4567\n                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {\n                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);\n                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);\n                    int16x4_t   group_mins_lo   = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);\n                    int16x4_t   group_mins_hi   = p == 0 ? 
vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);\n                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;\n                    float32x4_t sb_min          = p == 0 ? sb_min_0 : sb_min_1;\n\n                    // 0123 or 4567\n                    float32x4_t sumf_0 =\n                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));\n                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);\n\n                    float32x4_t sumf_1 =\n                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));\n                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);\n\n                    // FUSED BIAS: Compute and subtract bias immediately\n                    // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min\n                    int32x4_t bias       = vmull_s16(bsums_vec_lo, group_mins_lo);\n                    bias                 = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);\n                    float32x4_t bias_f32 = vcvtq_f32_s32(bias);\n                    acc_f32[i]           = vmlsq_f32(acc_f32[i], sb_min, bias_f32);\n                }\n            }  // for sb\n        }  // for b\n\n        int base = x * ncols_interleaved;\n        vst1q_f32(s + base, acc_f32[0]);\n        vst1q_f32(s + base + 4, acc_f32[1]);\n    }  // for x\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q6_K_8x4_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             
int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 4;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    col_groups = ncols_interleaved / 4;\n    const uint8x16_t m4b        = vdupq_n_u8(0x0f);\n    const uint8x16_t mask_lo    = vdupq_n_u8(0x03);\n    const uint8x16_t mask_hi    = vdupq_n_u8(0x30);\n\n    // 1x8 tile = 2 x 4\n    float32x4_t acc_f32[2];\n\n    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;\n\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);\n\n        for (int i = 0; i < col_groups; i++) {\n            acc_f32[i] = vdupq_n_f32(0);\n        }\n\n        for (int b = 0; b < nb; b++) {\n            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3\n            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7\n            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);\n            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);\n            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);\n\n            int32x4_t acc[col_groups];\n            for (int i = 0; i < col_groups; i++) {\n                acc[i] = vdupq_n_s32(0);\n            }\n\n            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)\n            // Reused for bias and dequantization later\n            int16_t q6_scales[16 * 8];\n            for (int i = 0; i < 16; i++) {\n                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));\n                vst1q_s16(q6_scales + i * 8, scales);\n            
}\n\n            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift\n            int32x4_t bias_lo = vdupq_n_s32(0);\n            int32x4_t bias_hi = vdupq_n_s32(0);\n\n            // Load bsums in chunks of 4 to process with vectorized operations\n            for (int i = 0; i < 16; i += 4) {\n                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);\n                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);\n                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);\n                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);\n                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);\n                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);\n                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);\n                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);\n                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);\n\n                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);\n                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);\n                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);\n                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);\n                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);\n                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);\n                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);\n                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);\n            }\n            bias_lo = vshlq_n_s32(bias_lo, 5);\n            bias_hi = vshlq_n_s32(bias_hi, 5);\n\n            // Process two 128-value halves per superblock\n            for (int half = 0; half < 2; half++) {\n                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;\n                const uint8_t * qh_base = 
q6_ptr[b].qh + half * 256;\n\n                // A subblock (sb) is a set of weights that share the scale\n                // Since q6_K scales are per 16 elements\n                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;\n                    const int8_t * q8_base_h = q8_base_l + 64;\n\n                    // Load and duplicate q8 values (each register covers four interleaved columns of q6)\n                    int8x16_t q8_l[4];\n                    int8x16_t q8_h[4];\n                    for (int i = 0; i < 4; i++) {\n                        q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));\n                        q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));\n                    }\n\n                    const int ql_off_base = sb * QK_K / 2;\n                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes\n\n                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)\n                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);\n                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);\n                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);\n                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);\n\n                    // Adjust qh for subblocks 2 and 3 (shift right by 2)\n                    if (sb > 1) {\n                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);\n                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);\n                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);\n                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);\n                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);\n                 
       q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);\n                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);\n                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);\n                    }\n\n                    const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],\n                                                  q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };\n                    const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],\n                                                  q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };\n\n                    // Process column groups (0-3, 4-7)\n                    for (int g = 0; g < col_groups; g++) {\n                        int32x4_t sb_acc_l = vdupq_n_s32(0);\n                        int32x4_t sb_acc_h = vdupq_n_s32(0);\n\n                        for (int chunk = 0; chunk < 4; chunk++) {\n                            const int idx = chunk * 2 + g;\n\n                            const uint8x16_t q6_qs_l = q6_ql[idx];\n                            const uint8x16_t q6_qs_h = q6_qh[idx];\n\n                            // Extract high 2 bits for upper nibble reconstruction\n                            const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);\n\n                            // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)\n                            const int8x16_t q6_l =\n                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));\n                            const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));\n\n                            sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);\n                            sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);\n                        }\n\n                        const int scale_idx_l = 
half * 8 + sb;\n                        const int scale_idx_h = half * 8 + sb + 4;\n\n                        const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));\n                        const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));\n\n                        acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);\n                        acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);\n                    }\n                }\n            }  // for half\n\n            // Bias correction\n            acc[0] = vsubq_s32(acc[0], bias_lo);\n            acc[1] = vsubq_s32(acc[1], bias_hi);\n\n            // Apply superblock scale (no mins for q6_K)\n            // acc[g] has [c0, c1, c2, c3]\n            float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);\n            float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);\n\n            acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);\n            acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);\n        }  // for b\n\n        int base = x * ncols_interleaved;\n        vst1q_f32(s + base, acc_f32[0]);\n        vst1q_f32(s + base + 4, acc_f32[1]);\n    }  // for x\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q6_K_8x8_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 8;\n\n    assert(n % qk == 0);\n    
assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    col_pairs = ncols_interleaved / 2;\n    const uint8x16_t m4b       = vdupq_n_u8(0x0f);\n    const uint8x16_t mask_lo   = vdupq_n_u8(0x03);\n    const uint8x16_t mask_hi   = vdupq_n_u8(0x30);\n\n    // 1x8 tile = 2 x 4\n    float32x4_t acc_f32[2];\n\n    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;\n\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);\n\n        acc_f32[0] = vdupq_n_f32(0);\n        acc_f32[1] = vdupq_n_f32(0);\n\n        for (int b = 0; b < nb; b++) {\n            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3\n            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7\n            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);\n            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);\n            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);\n\n            int32x2_t acc[col_pairs];\n            for (int i = 0; i < col_pairs; i++) {\n                acc[i] = vdup_n_s32(0);\n            }\n\n            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)\n            // Reused for bias and dequantization later\n            int16_t q6_scales[16 * 8];\n            for (int i = 0; i < 16; i++) {\n                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));\n                vst1q_s16(q6_scales + i * 8, scales);\n            }\n\n            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift\n            int32x4_t bias_lo = vdupq_n_s32(0);\n            int32x4_t bias_hi = vdupq_n_s32(0);\n\n            // Load bsums in chunks of 4 to 
process with vectorized operations\n            for (int i = 0; i < 16; i += 4) {\n                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);\n                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);\n                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);\n                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);\n                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);\n                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);\n                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);\n                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);\n                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);\n\n                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);\n                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);\n                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);\n                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);\n                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);\n                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);\n                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);\n                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);\n            }\n            bias_lo = vshlq_n_s32(bias_lo, 5);\n            bias_hi = vshlq_n_s32(bias_hi, 5);\n\n            // Process two 128-value halves per superblock\n            for (int half = 0; half < 2; half++) {\n                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;\n                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;\n\n                // A subblock (sb) is a set of weights that share the scale\n                // Since q6_K scales are per 16 elements\n                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 
2 halves)\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;\n                    const int8_t * q8_base_h = q8_base_l + 64;\n\n                    // Load and duplicate q8 values (each register covers two interleaved columns of q6)\n                    int8x16_t q8_l[2];\n                    int8x16_t q8_h[2];\n                    for (int i = 0; i < 2; i++) {\n                        q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));\n                        q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));\n                    }\n\n                    const int ql_off_base = sb * QK_K / 2;\n                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes\n\n                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)\n                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);\n                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);\n                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);\n                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);\n\n                    // Adjust qh for subblocks 2 and 3 (shift right by 2)\n                    if (sb > 1) {\n                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);\n                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);\n                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);\n                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);\n                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);\n                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);\n                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);\n                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);\n                    }\n\n                    // 
Process column pairs (0-1, 2-3, 4-5, 6-7)\n                    for (int cp = 0; cp < col_pairs; cp++) {\n                        const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];\n                        const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];\n                        const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];\n                        const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];\n\n                        // Extract high 2 bits for upper nibble reconstruction\n                        const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);\n                        const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);\n\n                        // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)\n                        const int8x16_t q6_l0 = vreinterpretq_s8_u8(\n                            vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));\n                        const int8x16_t q6_l1 = vreinterpretq_s8_u8(\n                            vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));\n                        const int8x16_t q6_h0 =\n                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));\n                        const int8x16_t q6_h1 =\n                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));\n\n                        int32x4_t sb_acc_l = vdupq_n_s32(0);\n                        sb_acc_l           = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);\n                        sb_acc_l           = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);\n\n                        int32x4_t sb_acc_h = vdupq_n_s32(0);\n                        sb_acc_h           = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);\n                        sb_acc_h           = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);\n\n                        // Pairwise add to get per-column sums: [col0, col1]\n                        int32x2_t sum_l = 
vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));\n                        int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));\n\n                        const int scale_idx_l = half * 8 + sb;\n                        const int scale_idx_h = half * 8 + sb + 4;\n\n                        // Access scales using array indexing (scales are interleaved by column)\n                        const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],\n                                                        (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };\n                        const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],\n                                                        (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };\n\n                        // Accumulate scaled results\n                        acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);\n                        acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);\n                    }\n                }\n            }  // for half\n\n            // Bias correction\n            acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));\n            acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));\n            acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));\n            acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));\n\n            // Apply superblock scale (no mins for q6_K)\n            // acc[cp] has [c0, c1]\n            float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));\n            float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));\n            float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));\n            float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));\n\n            acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));\n            acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));\n 
       }  // for b\n\n        int base = x * ncols_interleaved;\n        vst1q_f32(s + base, acc_f32[0]);\n        vst1q_f32(s + base + 4, acc_f32[1]);\n    }  // for x\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q8_0_4x4_q8_0(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen          = 4;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;\n\n    for (int c = 0; c < nc; c += ncols_interleaved) {\n        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n        float32x4_t        acc   = vdupq_n_f32(0);\n        for (int b = 0; b < nb; b++) {\n            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);\n            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);\n            float16x4_t bd     = vld1_f16((const __fp16 *) b_ptr->d);\n\n            int8x16x2_t a  = vld1q_s8_x2(a_ptr->qs);\n            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);\n\n            int32x4_t ret = vdupq_n_s32(0);\n\n            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);\n            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);\n            ret = 
vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);\n            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);\n\n            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);\n            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);\n            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);\n            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);\n\n            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));\n            a_ptr++;\n            b_ptr++;\n        }\n        vst1q_f32(s, acc);\n        s += ncols_interleaved;\n    }\n    return;\n\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q8_0_4x8_q8_0(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;\n\n    for (int c = 0; c < nc; c += ncols_interleaved) {\n        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n        float32x4_t        acc   = vdupq_n_f32(0);\n\n        for (int b = 0; b < nb; b++) {\n            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);\n            int8x16x4_t b_high = 
vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);\n            float16x4_t bd     = vld1_f16((const __fp16 *) b_ptr->d);\n\n            int8x8x4_t  a_chunks = vld1_s8_x4(a_ptr->qs);\n            int8x16_t   a0       = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);\n            int8x16_t   a1       = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);\n            int8x16_t   a2       = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);\n            int8x16_t   a3       = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);\n            float16x4_t ad       = vld1_dup_f16((const __fp16 *) &a_ptr->d);\n\n            int32x4_t ret0 = vdupq_n_s32(0);\n            int32x4_t ret1 = vdupq_n_s32(0);\n\n            // 0..7\n            ret0 = vdotq_s32(ret0, b_low.val[0], a0);\n            ret1 = vdotq_s32(ret1, b_low.val[1], a0);\n            // 8..15\n            ret0 = vdotq_s32(ret0, b_low.val[2], a1);\n            ret1 = vdotq_s32(ret1, b_low.val[3], a1);\n            // 16..23\n            ret0 = vdotq_s32(ret0, b_high.val[0], a2);\n            ret1 = vdotq_s32(ret1, b_high.val[1], a2);\n            // 24..31\n            ret0 = vdotq_s32(ret0, b_high.val[2], a3);\n            ret1 = vdotq_s32(ret1, b_high.val[3], a3);\n\n            int32x4_t ret = vpaddq_s32(ret0, ret1);\n\n            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));\n            a_ptr++;\n            b_ptr++;\n        }\n        vst1q_f32(s, acc);\n        s += ncols_interleaved;\n    }\n    return;\n\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 
0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const void * b_ptr = vx;\n    const void * a_ptr = vy;\n    float * res_ptr = s;\n    size_t res_stride = bs * sizeof(float);\n\n    __asm__ __volatile__(\n        \"mov x10, %x[nr]\\n\"\n        \"mov x9, #0x88\\n\"\n        \"cmp x10, #0x10\\n\"\n        \"mul x9, %x[nb], x9\\n\"\n        \"blt 4f\\n\"\n        \"1:\"  // Row loop\n        \"add x28, %x[b_ptr], #0x8\\n\"\n        \"mov x27, %x[nc]\\n\"\n        \"add x26, %x[res_ptr], %x[res_stride], LSL #4\\n\"\n        \"2:\"  // Column loop\n        \"add x25, %x[a_ptr], #0x8\\n\"\n        \"movi v15.16b, #0x0\\n\"\n        \"movi v19.16b, #0x0\\n\"\n        \"mov x24, %x[nb]\\n\"\n        \"add x23, x25, x9\\n\"\n        \"movi v18.16b, #0x0\\n\"\n        \"movi v14.16b, #0x0\\n\"\n        \"add x22, x23, x9\\n\"\n        \"movi v11.16b, #0x0\\n\"\n        \"movi v13.16b, #0x0\\n\"\n        \"add x21, x22, x9\\n\"\n        \"movi v23.16b, #0x0\\n\"\n        \"movi v16.16b, #0x0\\n\"\n        \"movi v25.16b, #0x0\\n\"\n        \"movi v7.16b, #0x0\\n\"\n        \"movi v0.16b, #0x0\\n\"\n        \"movi v4.16b, #0x0\\n\"\n        \"movi v5.16b, #0x0\\n\"\n        \"movi v21.16b, #0x0\\n\"\n        \"movi v8.16b, #0x0\\n\"\n        \"movi v1.16b, #0x0\\n\"\n        \"3:\"  // Block loop\n        \"ldr q3, [x28, #0x0]\\n\"\n        \"ldr q31, [x25, #0x0]\\n\"\n        \"movi v28.16b, #0x4\\n\"\n        \"movi v10.4s, #0x0\\n\"\n        \"ldr q22, [x28, #0x10]\\n\"\n        \"ldr q6, [x25, #0x10]\\n\"\n        \"movi v29.4s, #0x0\\n\"\n        \"movi v9.4s, #0x0\\n\"\n        \"ldr q27, [x28, #0x20]\\n\"\n        \"ldr q30, [x28, #0x30]\\n\"\n        \"movi 
v20.4s, #0x0\\n\"\n        \"movi v24.16b, #0xf0\\n\"\n        \"ldr d2, [x25, #-0x8]\\n\"\n        \"ldr d26, [x23, #-0x8]\\n\"\n        \"sshl v12.16b, v3.16b, v28.16b\\n\"\n        \"sub x20, x28, #0x8\\n\"\n        \"ldr d17, [x20, #0x0]\\n\"\n        \"and v3.16b, v3.16b, v24.16b\\n\"\n        \"subs x24, x24, #0x1\\n\"\n        \"add x28, x28, #0x48\\n\"\n        \".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\\n\"\n        \".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\\n\"\n        \".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\\n\"\n        \".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\\n\"\n        \"sshl v31.16b, v22.16b, v28.16b\\n\"\n        \"and v22.16b, v22.16b, v24.16b\\n\"\n        \"fcvtl v17.4s, v17.4h\\n\"\n        \"fcvtl v2.4s, v2.4h\\n\"\n        \"fcvtl v26.4s, v26.4h\\n\"\n        \".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\\n\"\n        \".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\\n\"\n        \".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\\n\"\n        \".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\\n\"\n        \"sshl v6.16b, v27.16b, v28.16b\\n\"\n        \"sshl v28.16b, v30.16b, v28.16b\\n\"\n        \"and v27.16b, v27.16b, v24.16b\\n\"\n        \"and v30.16b, v30.16b, v24.16b\\n\"\n        \"ldr q24, [x25, #0x20]\\n\"\n        \".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x25, #0x30]\\n\"\n        \".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x25, #0x40]\\n\"\n        \".inst 0x4f98e06a  // sdot 
v10.4s, v3.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x25, #0x50]\\n\"\n        \".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x25, #0x60]\\n\"\n        \".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x25, #0x70]\\n\"\n        \"add x25, x25, #0x88\\n\"\n        \".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\\n\"\n        \"fmul v24.4s, v17.4s, v2.s[0]\\n\"\n        \"scvtf v10.4s, v10.4s, #0x4\\n\"\n        \"scvtf v29.4s, v29.4s, #0x4\\n\"\n        \"scvtf v9.4s, v9.4s, #0x4\\n\"\n        \"scvtf v20.4s, v20.4s, #0x4\\n\"\n        \"fmla v15.4s, v10.4s, v24.4s\\n\"\n        \"ldr q24, [x23, #0x0]\\n\"\n        \"fmul v10.4s, v17.4s, v2.s[1]\\n\"\n        \"fmla v19.4s, v29.4s, v10.4s\\n\"\n        \"ldr q10, [x23, #0x10]\\n\"\n        \"fmul v29.4s, v17.4s, v2.s[2]\\n\"\n        \"fmul v2.4s, v17.4s, v2.s[3]\\n\"\n        \"fmla v18.4s, v9.4s, v29.4s\\n\"\n        \"movi v9.4s, #0x0\\n\"\n        \"movi v29.4s, #0x0\\n\"\n        \".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\\n\"\n        \"fmla v14.4s, v20.4s, 
v2.4s\\n\"\n        \"movi v20.4s, #0x0\\n\"\n        \"movi v2.4s, #0x0\\n\"\n        \".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x23, #0x20]\\n\"\n        \".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\\n\"\n        \".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\\n\"\n        \".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\\n\"\n        \".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\\n\"\n        \"ldr q10, [x23, #0x30]\\n\"\n        \".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x23, #0x40]\\n\"\n        \".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\\n\"\n        \".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\\n\"\n        \".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\\n\"\n        \".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\\n\"\n        \"ldr q10, [x23, #0x50]\\n\"\n        \".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x23, #0x60]\\n\"\n        \".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\\n\"\n        \".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\\n\"\n        \".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\\n\"\n        \".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\\n\"\n        \"ldr q10, [x23, #0x70]\\n\"\n        \"add x23, x23, #0x88\\n\"\n        \".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\\n\"\n        \".inst 
0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x22, #0x0]\\n\"\n        \".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\\n\"\n        \".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\\n\"\n        \".inst 0x4f8aebd4  // sdot v20.4s, v30.16b, v10.4b[2]\\n\"\n        \".inst 0x4faaebc2  // sdot v2.4s, v30.16b, v10.4b[3]\\n\"\n        \"fmul v10.4s, v17.4s, v26.s[0]\\n\"\n        \"scvtf v9.4s, v9.4s, #0x4\\n\"\n        \"scvtf v29.4s, v29.4s, #0x4\\n\"\n        \"scvtf v20.4s, v20.4s, #0x4\\n\"\n        \"scvtf v2.4s, v2.4s, #0x4\\n\"\n        \"fmla v11.4s, v9.4s, v10.4s\\n\"\n        \"ldr q9, [x22, #0x10]\\n\"\n        \"fmul v10.4s, v17.4s, v26.s[1]\\n\"\n        \"fmla v13.4s, v29.4s, v10.4s\\n\"\n        \"ldr d29, [x22, #-0x8]\\n\"\n        \"fmul v10.4s, v17.4s, v26.s[2]\\n\"\n        \"fmul v26.4s, v17.4s, v26.s[3]\\n\"\n        \"fcvtl v29.4s, v29.4h\\n\"\n        \"fmla v23.4s, v20.4s, v10.4s\\n\"\n        \"movi v20.4s, #0x0\\n\"\n        \"movi v10.4s, #0x0\\n\"\n        \"fmla v16.4s, v2.4s, v26.4s\\n\"\n        \"movi v26.4s, #0x0\\n\"\n        \"movi v2.4s, #0x0\\n\"\n        \".inst 0x4f98e194  // sdot v20.4s, v12.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e99a  // sdot v26.4s, v12.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x22, #0x20]\\n\"\n        \".inst 0x4f89e3f4  // sdot v20.4s, v31.16b, v9.4b[0]\\n\"\n        \".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\\n\"\n        \".inst 0x4f89ebfa  // sdot v26.4s, v31.16b, v9.4b[2]\\n\"\n        \".inst 0x4fa9ebe2  // sdot v2.4s, v31.16b, v9.4b[3]\\n\"\n        \"ldr q9, [x22, #0x30]\\n\"\n        \".inst 0x4f98e0d4  // sdot v20.4s, v6.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e0ca  // sdot v10.4s, v6.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e8da  // sdot 
v26.4s, v6.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x22, #0x40]\\n\"\n        \".inst 0x4f89e394  // sdot v20.4s, v28.16b, v9.4b[0]\\n\"\n        \".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\\n\"\n        \".inst 0x4f89eb9a  // sdot v26.4s, v28.16b, v9.4b[2]\\n\"\n        \".inst 0x4fa9eb82  // sdot v2.4s, v28.16b, v9.4b[3]\\n\"\n        \"ldr q9, [x22, #0x50]\\n\"\n        \".inst 0x4f98e074  // sdot v20.4s, v3.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e06a  // sdot v10.4s, v3.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e87a  // sdot v26.4s, v3.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x22, #0x60]\\n\"\n        \".inst 0x4f89e2d4  // sdot v20.4s, v22.16b, v9.4b[0]\\n\"\n        \".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\\n\"\n        \".inst 0x4f89eada  // sdot v26.4s, v22.16b, v9.4b[2]\\n\"\n        \".inst 0x4fa9eac2  // sdot v2.4s, v22.16b, v9.4b[3]\\n\"\n        \"ldr q9, [x22, #0x70]\\n\"\n        \"add x22, x22, #0x88\\n\"\n        \".inst 0x4f98e374  // sdot v20.4s, v27.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e36a  // sdot v10.4s, v27.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98eb7a  // sdot v26.4s, v27.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\\n\"\n        \"ldr q24, [x21, #0x0]\\n\"\n        \".inst 0x4f89e3d4  // sdot v20.4s, v30.16b, v9.4b[0]\\n\"\n        \".inst 0x4fa9e3ca  // sdot v10.4s, v30.16b, v9.4b[1]\\n\"\n        \".inst 0x4f89ebda  // sdot v26.4s, v30.16b, v9.4b[2]\\n\"\n        \".inst 0x4fa9ebc2  // sdot v2.4s, v30.16b, v9.4b[3]\\n\"\n        \"fmul v9.4s, v17.4s, v29.s[0]\\n\"\n        \"scvtf v20.4s, v20.4s, #0x4\\n\"\n        \"scvtf v10.4s, v10.4s, #0x4\\n\"\n        \"scvtf v26.4s, v26.4s, #0x4\\n\"\n        \"scvtf v2.4s, v2.4s, #0x4\\n\"\n        \"fmla v25.4s, v20.4s, v9.4s\\n\"\n        \"ldr q9, [x21, #0x10]\\n\"\n        \"fmul v20.4s, 
v17.4s, v29.s[1]\\n\"\n        \"fmla v7.4s, v10.4s, v20.4s\\n\"\n        \"ldr d20, [x21, #-0x8]\\n\"\n        \"fmul v10.4s, v17.4s, v29.s[2]\\n\"\n        \"fmul v29.4s, v17.4s, v29.s[3]\\n\"\n        \"fcvtl v20.4s, v20.4h\\n\"\n        \"fmla v0.4s, v26.4s, v10.4s\\n\"\n        \"movi v26.4s, #0x0\\n\"\n        \"movi v10.4s, #0x0\\n\"\n        \"fmla v4.4s, v2.4s, v29.4s\\n\"\n        \"movi v2.4s, #0x0\\n\"\n        \"movi v29.4s, #0x0\\n\"\n        \".inst 0x4f98e19a  // sdot v26.4s, v12.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e982  // sdot v2.4s, v12.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e99d  // sdot v29.4s, v12.16b, v24.4b[3]\\n\"\n        \"ldr q12, [x21, #0x20]\\n\"\n        \"fmul v24.4s, v17.4s, v20.s[0]\\n\"\n        \".inst 0x4f89e3fa  // sdot v26.4s, v31.16b, v9.4b[0]\\n\"\n        \".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\\n\"\n        \".inst 0x4f89ebe2  // sdot v2.4s, v31.16b, v9.4b[2]\\n\"\n        \".inst 0x4fa9ebfd  // sdot v29.4s, v31.16b, v9.4b[3]\\n\"\n        \"ldr q9, [x21, #0x30]\\n\"\n        \"fmul v31.4s, v17.4s, v20.s[1]\\n\"\n        \".inst 0x4f8ce0da  // sdot v26.4s, v6.16b, v12.4b[0]\\n\"\n        \".inst 0x4face0ca  // sdot v10.4s, v6.16b, v12.4b[1]\\n\"\n        \".inst 0x4f8ce8c2  // sdot v2.4s, v6.16b, v12.4b[2]\\n\"\n        \".inst 0x4face8dd  // sdot v29.4s, v6.16b, v12.4b[3]\\n\"\n        \"ldr q12, [x21, #0x40]\\n\"\n        \"fmul v6.4s, v17.4s, v20.s[2]\\n\"\n        \"fmul v20.4s, v17.4s, v20.s[3]\\n\"\n        \".inst 0x4f89e39a  // sdot v26.4s, v28.16b, v9.4b[0]\\n\"\n        \".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\\n\"\n        \".inst 0x4f89eb82  // sdot v2.4s, v28.16b, v9.4b[2]\\n\"\n        \".inst 0x4fa9eb9d  // sdot v29.4s, v28.16b, v9.4b[3]\\n\"\n        \"ldr q9, [x21, #0x50]\\n\"\n        \".inst 0x4f8ce07a  // sdot v26.4s, v3.16b, v12.4b[0]\\n\"\n        \".inst 0x4face06a  // sdot v10.4s, v3.16b, 
v12.4b[1]\\n\"\n        \".inst 0x4f8ce862  // sdot v2.4s, v3.16b, v12.4b[2]\\n\"\n        \".inst 0x4face87d  // sdot v29.4s, v3.16b, v12.4b[3]\\n\"\n        \"ldr q12, [x21, #0x60]\\n\"\n        \".inst 0x4f89e2da  // sdot v26.4s, v22.16b, v9.4b[0]\\n\"\n        \".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\\n\"\n        \".inst 0x4f89eac2  // sdot v2.4s, v22.16b, v9.4b[2]\\n\"\n        \".inst 0x4fa9eadd  // sdot v29.4s, v22.16b, v9.4b[3]\\n\"\n        \"ldr q17, [x21, #0x70]\\n\"\n        \"add x21, x21, #0x88\\n\"\n        \".inst 0x4f8ce37a  // sdot v26.4s, v27.16b, v12.4b[0]\\n\"\n        \".inst 0x4face36a  // sdot v10.4s, v27.16b, v12.4b[1]\\n\"\n        \".inst 0x4f8ceb62  // sdot v2.4s, v27.16b, v12.4b[2]\\n\"\n        \".inst 0x4faceb7d  // sdot v29.4s, v27.16b, v12.4b[3]\\n\"\n        \".inst 0x4f91e3da  // sdot v26.4s, v30.16b, v17.4b[0]\\n\"\n        \".inst 0x4fb1e3ca  // sdot v10.4s, v30.16b, v17.4b[1]\\n\"\n        \".inst 0x4f91ebc2  // sdot v2.4s, v30.16b, v17.4b[2]\\n\"\n        \".inst 0x4fb1ebdd  // sdot v29.4s, v30.16b, v17.4b[3]\\n\"\n        \"scvtf v26.4s, v26.4s, #0x4\\n\"\n        \"scvtf v10.4s, v10.4s, #0x4\\n\"\n        \"fmla v5.4s, v26.4s, v24.4s\\n\"\n        \"scvtf v2.4s, v2.4s, #0x4\\n\"\n        \"scvtf v29.4s, v29.4s, #0x4\\n\"\n        \"fmla v21.4s, v10.4s, v31.4s\\n\"\n        \"fmla v8.4s, v2.4s, v6.4s\\n\"\n        \"fmla v1.4s, v29.4s, v20.4s\\n\"\n        \"bgt 3b\\n\"\n        \"mov x20, %x[res_ptr]\\n\"\n        \"subs x27, x27, #0x4\\n\"\n        \"add %x[res_ptr], %x[res_ptr], #0x10\\n\"\n        \"str q15, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q19, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q18, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q14, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q11, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str 
q13, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q23, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q16, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q25, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q7, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q0, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q4, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q5, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q21, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q8, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q1, [x20, #0x0]\\n\"\n        \"bne 2b\\n\"\n        \"mov x20, #0x4\\n\"\n        \"sub x10, x10, #0x10\\n\"\n        \"cmp x10, #0x10\\n\"\n        \"mov %x[res_ptr], x26\\n\"\n        \"madd %x[a_ptr], x20, x9, %x[a_ptr]\\n\"\n        \"bge 1b\\n\"\n        \"4:\"  // Row loop skip\n        \"cbz x10, 9f\\n\"\n        \"5:\"  // Row tail: Row loop\n        \"add x24, %x[b_ptr], #0x8\\n\"\n        \"mov x23, %x[nc]\\n\"\n        \"add x22, %x[res_ptr], %x[res_stride], LSL #2\\n\"\n        \"6:\"  // Row tail: Column loop\n        \"movi v15.16b, #0x0\\n\"\n        \"movi v19.16b, #0x0\\n\"\n        \"add x25, %x[a_ptr], #0x8\\n\"\n        \"mov x21, %x[nb]\\n\"\n        \"movi v18.16b, #0x0\\n\"\n        \"movi v14.16b, #0x0\\n\"\n        \"7:\"  // Row tail: Block loop\n        \"ldr q7, [x24, #0x0]\\n\"\n        \"ldr q5, [x25, #0x0]\\n\"\n        \"movi v9.16b, #0x4\\n\"\n        \"movi v4.4s, #0x0\\n\"\n        \"ldr q3, [x24, #0x10]\\n\"\n        \"ldr q2, [x25, #0x10]\\n\"\n        \"movi v1.4s, #0x0\\n\"\n        \"movi v0.4s, #0x0\\n\"\n        \"ldr q13, [x24, #0x20]\\n\"\n        \"ldr q31, [x25, #0x20]\\n\"\n        \"movi v30.4s, #0x0\\n\"\n        
\"movi v29.16b, #0xf0\\n\"\n        \"ldr q28, [x24, #0x30]\\n\"\n        \"ldr q27, [x25, #0x30]\\n\"\n        \"sshl v20.16b, v7.16b, v9.16b\\n\"\n        \"sub x20, x24, #0x8\\n\"\n        \"ldr q26, [x25, #0x40]\\n\"\n        \"ldr q25, [x25, #0x50]\\n\"\n        \"sshl v17.16b, v3.16b, v9.16b\\n\"\n        \"and v7.16b, v7.16b, v29.16b\\n\"\n        \"ldr q24, [x25, #0x60]\\n\"\n        \"ldr q16, [x25, #0x70]\\n\"\n        \"sshl v22.16b, v13.16b, v9.16b\\n\"\n        \"and v3.16b, v3.16b, v29.16b\\n\"\n        \"ldr d21, [x20, #0x0]\\n\"\n        \"ldr d12, [x25, #-0x8]\\n\"\n        \".inst 0x4f85e284  // sdot v4.4s, v20.16b, v5.4b[0]\\n\"\n        \".inst 0x4fa5e281  // sdot v1.4s, v20.16b, v5.4b[1]\\n\"\n        \".inst 0x4f85ea80  // sdot v0.4s, v20.16b, v5.4b[2]\\n\"\n        \".inst 0x4fa5ea9e  // sdot v30.4s, v20.16b, v5.4b[3]\\n\"\n        \"sshl v9.16b, v28.16b, v9.16b\\n\"\n        \"subs x21, x21, #0x1\\n\"\n        \"and v13.16b, v13.16b, v29.16b\\n\"\n        \"and v28.16b, v28.16b, v29.16b\\n\"\n        \"add x25, x25, #0x88\\n\"\n        \"add x24, x24, #0x48\\n\"\n        \"fcvtl v21.4s, v21.4h\\n\"\n        \"fcvtl v12.4s, v12.4h\\n\"\n        \".inst 0x4f82e224  // sdot v4.4s, v17.16b, v2.4b[0]\\n\"\n        \".inst 0x4fa2e221  // sdot v1.4s, v17.16b, v2.4b[1]\\n\"\n        \".inst 0x4f82ea20  // sdot v0.4s, v17.16b, v2.4b[2]\\n\"\n        \".inst 0x4fa2ea3e  // sdot v30.4s, v17.16b, v2.4b[3]\\n\"\n        \"fmul v11.4s, v21.4s, v12.s[0]\\n\"\n        \"fmul v23.4s, v21.4s, v12.s[1]\\n\"\n        \"fmul v17.4s, v21.4s, v12.s[2]\\n\"\n        \".inst 0x4f9fe2c4  // sdot v4.4s, v22.16b, v31.4b[0]\\n\"\n        \"fmul v6.4s, v21.4s, v12.s[3]\\n\"\n        \".inst 0x4fbfe2c1  // sdot v1.4s, v22.16b, v31.4b[1]\\n\"\n        \".inst 0x4f9feac0  // sdot v0.4s, v22.16b, v31.4b[2]\\n\"\n        \".inst 0x4fbfeade  // sdot v30.4s, v22.16b, v31.4b[3]\\n\"\n        \".inst 0x4f9be124  // sdot v4.4s, v9.16b, v27.4b[0]\\n\"\n        \".inst 0x4fbbe121  
// sdot v1.4s, v9.16b, v27.4b[1]\\n\"\n        \".inst 0x4f9be920  // sdot v0.4s, v9.16b, v27.4b[2]\\n\"\n        \".inst 0x4fbbe93e  // sdot v30.4s, v9.16b, v27.4b[3]\\n\"\n        \".inst 0x4f9ae0e4  // sdot v4.4s, v7.16b, v26.4b[0]\\n\"\n        \".inst 0x4fbae0e1  // sdot v1.4s, v7.16b, v26.4b[1]\\n\"\n        \".inst 0x4f9ae8e0  // sdot v0.4s, v7.16b, v26.4b[2]\\n\"\n        \".inst 0x4fbae8fe  // sdot v30.4s, v7.16b, v26.4b[3]\\n\"\n        \".inst 0x4f99e064  // sdot v4.4s, v3.16b, v25.4b[0]\\n\"\n        \".inst 0x4fb9e061  // sdot v1.4s, v3.16b, v25.4b[1]\\n\"\n        \".inst 0x4f99e860  // sdot v0.4s, v3.16b, v25.4b[2]\\n\"\n        \".inst 0x4fb9e87e  // sdot v30.4s, v3.16b, v25.4b[3]\\n\"\n        \".inst 0x4f98e1a4  // sdot v4.4s, v13.16b, v24.4b[0]\\n\"\n        \".inst 0x4fb8e1a1  // sdot v1.4s, v13.16b, v24.4b[1]\\n\"\n        \".inst 0x4f98e9a0  // sdot v0.4s, v13.16b, v24.4b[2]\\n\"\n        \".inst 0x4fb8e9be  // sdot v30.4s, v13.16b, v24.4b[3]\\n\"\n        \".inst 0x4f90e384  // sdot v4.4s, v28.16b, v16.4b[0]\\n\"\n        \".inst 0x4fb0e381  // sdot v1.4s, v28.16b, v16.4b[1]\\n\"\n        \".inst 0x4f90eb80  // sdot v0.4s, v28.16b, v16.4b[2]\\n\"\n        \".inst 0x4fb0eb9e  // sdot v30.4s, v28.16b, v16.4b[3]\\n\"\n        \"scvtf v4.4s, v4.4s, #0x4\\n\"\n        \"scvtf v1.4s, v1.4s, #0x4\\n\"\n        \"scvtf v0.4s, v0.4s, #0x4\\n\"\n        \"fmla v15.4s, v4.4s, v11.4s\\n\"\n        \"scvtf v30.4s, v30.4s, #0x4\\n\"\n        \"fmla v19.4s, v1.4s, v23.4s\\n\"\n        \"fmla v18.4s, v0.4s, v17.4s\\n\"\n        \"fmla v14.4s, v30.4s, v6.4s\\n\"\n        \"bgt 7b\\n\"\n        \"mov x20, %x[res_ptr]\\n\"\n        \"cmp x10, #0x1\\n\"\n        \"str q15, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"ble 8f\\n\"\n        \"cmp x10, #0x2\\n\"\n        \"str q19, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"ble 8f\\n\"\n        \"cmp x10, #0x3\\n\"\n        \"str q18, [x20, #0x0]\\n\"\n       
 \"add x20, x20, %x[res_stride]\\n\"\n        \"ble 8f\\n\"\n        \"str q14, [x20, #0x0]\\n\"\n        \"8:\"  // Row tail: Accumulator store skip\n        \"subs x23, x23, #0x4\\n\"\n        \"add %x[res_ptr], %x[res_ptr], #0x10\\n\"\n        \"bne 6b\\n\"\n        \"subs x10, x10, #0x4\\n\"\n        \"add %x[a_ptr], %x[a_ptr], x9\\n\"\n        \"mov %x[res_ptr], x22\\n\"\n        \"bgt 5b\\n\"\n        \"9:\"  // Row tail: Row loop skip\n        : [a_ptr] \"+&r\" (a_ptr), [res_ptr] \"+&r\" (res_ptr)\n        : [b_ptr] \"r\" (b_ptr), [nr] \"r\" (nr), [nb] \"r\" (nb), [res_stride] \"r\" (res_stride), [nc] \"r\" (nc)\n        : \"cc\", \"memory\", \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\", \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\", \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\", \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\", \"x9\", \"x10\", \"x20\", \"x21\", \"x22\", \"x23\", \"x24\", \"x25\", \"x26\", \"x27\", \"x28\"\n    );\n    return;\n#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)\n    ggml_gemm_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    const void * b_ptr = vx;\n    const void * a_ptr = vy;\n    float * res_ptr = s;\n    size_t res_stride = bs * sizeof(float);\n\n    __asm__ __volatile__(\n        \"mov x10, %x[nr]\\n\"\n        \"mov x9, #0x88\\n\"\n        \"cmp x10, #0x10\\n\"\n        \"mul x9, %x[nb], x9\\n\"\n        \"blt 4f\\n\"\n        \"1:\"  // Row loop\n        \"add x28, %x[b_ptr], #0x8\\n\"\n        \"mov x27, %x[nc]\\n\"\n        \"add x26, %x[res_ptr], %x[res_stride], LSL #4\\n\"\n        \"2:\"  // Column loop\n        \"add x25, %x[a_ptr], #0x8\\n\"\n        \"movi v2.16b, #0x0\\n\"\n        \"movi v10.16b, #0x0\\n\"\n        \"mov x24, %x[nb]\\n\"\n        \"add x23, x25, x9\\n\"\n        \"movi v12.16b, #0x0\\n\"\n        \"movi v28.16b, #0x0\\n\"\n        \"add x22, x23, x9\\n\"\n        \"movi v11.16b, #0x0\\n\"\n        \"movi v13.16b, #0x0\\n\"\n        \"add x21, x22, x9\\n\"\n        \"movi v22.16b, #0x0\\n\"\n        \"movi v23.16b, #0x0\\n\"\n        \"movi v25.16b, #0x0\\n\"\n        \"movi v5.16b, #0x0\\n\"\n        \"movi v7.16b, #0x0\\n\"\n        \"movi v4.16b, #0x0\\n\"\n        \"movi v6.16b, #0x0\\n\"\n        \"movi v30.16b, #0x0\\n\"\n        \"movi v24.16b, #0x0\\n\"\n        \"movi v14.16b, #0x0\\n\"\n        \"3:\"  // Block loop\n        \"ldr q21, [x28, #0x0]\\n\"\n        \"ldr q16, [x28, #0x10]\\n\"\n        \"movi v1.16b, #0x4\\n\"\n        \"movi v19.4s, #0x0\\n\"\n        \"ldr q27, [x25, #0x0]\\n\"\n        \"ldr q15, [x25, #0x10]\\n\"\n        \"movi v26.4s, #0x0\\n\"\n        \"movi v18.4s, #0x0\\n\"\n        \"ldr q29, [x28, #0x20]\\n\"\n        \"ldr q3, [x28, #0x30]\\n\"\n        \"movi v17.4s, #0x0\\n\"\n        \"movi v0.16b, #0xf0\\n\"\n        \"ldr d20, [x25, #-0x8]\\n\"\n        \"ldr d9, [x23, #-0x8]\\n\"\n        \"sshl v8.16b, v21.16b, v1.16b\\n\"\n        \"sshl v31.16b, v16.16b, v1.16b\\n\"\n        \"and v21.16b, v21.16b, 
v0.16b\\n\"\n        \"and v16.16b, v16.16b, v0.16b\\n\"\n        \"sub x20, x28, #0x8\\n\"\n        \"subs x24, x24, #0x1\\n\"\n        \"add x28, x28, #0x48\\n\"\n        \".inst 0x4e88a773  // smmla v19.4s, v27.16b, v8.16b\\n\"\n        \".inst 0x4e9fa77a  // smmla v26.4s, v27.16b, v31.16b\\n\"\n        \"ldr q27, [x25, #0x20]\\n\"\n        \".inst 0x4e88a5f2  // smmla v18.4s, v15.16b, v8.16b\\n\"\n        \".inst 0x4e9fa5f1  // smmla v17.4s, v15.16b, v31.16b\\n\"\n        \"sshl v15.16b, v29.16b, v1.16b\\n\"\n        \"sshl v1.16b, v3.16b, v1.16b\\n\"\n        \"and v29.16b, v29.16b, v0.16b\\n\"\n        \"and v3.16b, v3.16b, v0.16b\\n\"\n        \"ldr q0, [x25, #0x30]\\n\"\n        \"fcvtl v20.4s, v20.4h\\n\"\n        \".inst 0x4e8fa773  // smmla v19.4s, v27.16b, v15.16b\\n\"\n        \"fcvtl v9.4s, v9.4h\\n\"\n        \".inst 0x4e81a77a  // smmla v26.4s, v27.16b, v1.16b\\n\"\n        \"ldr q27, [x25, #0x40]\\n\"\n        \".inst 0x4e8fa412  // smmla v18.4s, v0.16b, v15.16b\\n\"\n        \".inst 0x4e81a411  // smmla v17.4s, v0.16b, v1.16b\\n\"\n        \"ldr q0, [x25, #0x50]\\n\"\n        \".inst 0x4e95a773  // smmla v19.4s, v27.16b, v21.16b\\n\"\n        \".inst 0x4e90a77a  // smmla v26.4s, v27.16b, v16.16b\\n\"\n        \"ldr q27, [x25, #0x60]\\n\"\n        \".inst 0x4e95a412  // smmla v18.4s, v0.16b, v21.16b\\n\"\n        \".inst 0x4e90a411  // smmla v17.4s, v0.16b, v16.16b\\n\"\n        \"ldr q0, [x25, #0x70]\\n\"\n        \"add x25, x25, #0x88\\n\"\n        \".inst 0x4e9da773  // smmla v19.4s, v27.16b, v29.16b\\n\"\n        \".inst 0x4e83a77a  // smmla v26.4s, v27.16b, v3.16b\\n\"\n        \"ldr d27, [x20, #0x0]\\n\"\n        \".inst 0x4e9da412  // smmla v18.4s, v0.16b, v29.16b\\n\"\n        \".inst 0x4e83a411  // smmla v17.4s, v0.16b, v3.16b\\n\"\n        \"fcvtl v27.4s, v27.4h\\n\"\n        \"uzp1 v0.2d, v19.2d, v26.2d\\n\"\n        \"uzp2 v26.2d, v19.2d, v26.2d\\n\"\n        \"fmul v19.4s, v27.4s, v20.s[0]\\n\"\n        \"scvtf v0.4s, v0.4s, 
#0x4\\n\"\n        \"scvtf v26.4s, v26.4s, #0x4\\n\"\n        \"fmla v2.4s, v0.4s, v19.4s\\n\"\n        \"ldr q19, [x23, #0x0]\\n\"\n        \"uzp1 v0.2d, v18.2d, v17.2d\\n\"\n        \"uzp2 v18.2d, v18.2d, v17.2d\\n\"\n        \"fmul v17.4s, v27.4s, v20.s[1]\\n\"\n        \"scvtf v0.4s, v0.4s, #0x4\\n\"\n        \"scvtf v18.4s, v18.4s, #0x4\\n\"\n        \"fmla v10.4s, v26.4s, v17.4s\\n\"\n        \"ldr q17, [x23, #0x10]\\n\"\n        \"fmul v26.4s, v27.4s, v20.s[2]\\n\"\n        \"fmul v20.4s, v27.4s, v20.s[3]\\n\"\n        \"fmla v12.4s, v0.4s, v26.4s\\n\"\n        \"ldr d0, [x22, #-0x8]\\n\"\n        \"ldr d26, [x21, #-0x8]\\n\"\n        \"fcvtl v0.4s, v0.4h\\n\"\n        \"fmla v28.4s, v18.4s, v20.4s\\n\"\n        \"movi v20.4s, #0x0\\n\"\n        \"movi v18.4s, #0x0\\n\"\n        \".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\\n\"\n        \".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\\n\"\n        \"ldr q19, [x23, #0x20]\\n\"\n        \"fcvtl v26.4s, v26.4h\\n\"\n        \".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\\n\"\n        \".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\\n\"\n        \"ldr q19, [x23, #0x40]\\n\"\n        \".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\\n\"\n        \".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\\n\"\n        \"ldr q19, [x23, #0x60]\\n\"\n        \".inst 0x4e9da674  // smmla v20.4s, v19.16b, v29.16b\\n\"\n        \".inst 0x4e83a672  // smmla v18.4s, v19.16b, v3.16b\\n\"\n        \"uzp1 v19.2d, v20.2d, v18.2d\\n\"\n        \"scvtf v19.4s, v19.4s, #0x4\\n\"\n        \"uzp2 v20.2d, v20.2d, v18.2d\\n\"\n        \"fmul v18.4s, v27.4s, v9.s[0]\\n\"\n        \"scvtf v20.4s, v20.4s, #0x4\\n\"\n        \"fmla v11.4s, v19.4s, v18.4s\\n\"\n        \"ldr q18, [x22, #0x0]\\n\"\n        \"fmul v19.4s, v27.4s, v9.s[1]\\n\"\n        \"fmla v13.4s, v20.4s, v19.4s\\n\"\n        \"movi v19.4s, #0x0\\n\"\n        \"movi v20.4s, #0x0\\n\"\n        \".inst 0x4e88a633  // smmla v19.4s, v17.16b, 
v8.16b\\n\"\n        \".inst 0x4e9fa634  // smmla v20.4s, v17.16b, v31.16b\\n\"\n        \"ldr q17, [x23, #0x30]\\n\"\n        \".inst 0x4e8fa633  // smmla v19.4s, v17.16b, v15.16b\\n\"\n        \".inst 0x4e81a634  // smmla v20.4s, v17.16b, v1.16b\\n\"\n        \"ldr q17, [x23, #0x50]\\n\"\n        \".inst 0x4e95a633  // smmla v19.4s, v17.16b, v21.16b\\n\"\n        \".inst 0x4e90a634  // smmla v20.4s, v17.16b, v16.16b\\n\"\n        \"ldr q17, [x23, #0x70]\\n\"\n        \"add x23, x23, #0x88\\n\"\n        \".inst 0x4e9da633  // smmla v19.4s, v17.16b, v29.16b\\n\"\n        \".inst 0x4e83a634  // smmla v20.4s, v17.16b, v3.16b\\n\"\n        \"uzp1 v17.2d, v19.2d, v20.2d\\n\"\n        \"scvtf v17.4s, v17.4s, #0x4\\n\"\n        \"uzp2 v20.2d, v19.2d, v20.2d\\n\"\n        \"fmul v19.4s, v27.4s, v9.s[2]\\n\"\n        \"fmul v9.4s, v27.4s, v9.s[3]\\n\"\n        \"scvtf v20.4s, v20.4s, #0x4\\n\"\n        \"fmla v22.4s, v17.4s, v19.4s\\n\"\n        \"ldr q17, [x22, #0x10]\\n\"\n        \"movi v19.4s, #0x0\\n\"\n        \".inst 0x4e88a653  // smmla v19.4s, v18.16b, v8.16b\\n\"\n        \"fmla v23.4s, v20.4s, v9.4s\\n\"\n        \"movi v20.4s, #0x0\\n\"\n        \"movi v9.4s, #0x0\\n\"\n        \".inst 0x4e9fa654  // smmla v20.4s, v18.16b, v31.16b\\n\"\n        \"ldr q18, [x22, #0x20]\\n\"\n        \".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\\n\"\n        \".inst 0x4e8fa653  // smmla v19.4s, v18.16b, v15.16b\\n\"\n        \".inst 0x4e81a654  // smmla v20.4s, v18.16b, v1.16b\\n\"\n        \"ldr q18, [x22, #0x40]\\n\"\n        \".inst 0x4e95a653  // smmla v19.4s, v18.16b, v21.16b\\n\"\n        \".inst 0x4e90a654  // smmla v20.4s, v18.16b, v16.16b\\n\"\n        \"ldr q18, [x22, #0x60]\\n\"\n        \".inst 0x4e9da653  // smmla v19.4s, v18.16b, v29.16b\\n\"\n        \".inst 0x4e83a654  // smmla v20.4s, v18.16b, v3.16b\\n\"\n        \"movi v18.4s, #0x0\\n\"\n        \".inst 0x4e9fa632  // smmla v18.4s, v17.16b, v31.16b\\n\"\n        \"ldr q17, [x22, #0x30]\\n\"\n        
\".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\\n\"\n        \".inst 0x4e81a632  // smmla v18.4s, v17.16b, v1.16b\\n\"\n        \"ldr q17, [x22, #0x50]\\n\"\n        \".inst 0x4e95a629  // smmla v9.4s, v17.16b, v21.16b\\n\"\n        \".inst 0x4e90a632  // smmla v18.4s, v17.16b, v16.16b\\n\"\n        \"ldr q17, [x22, #0x70]\\n\"\n        \"add x22, x22, #0x88\\n\"\n        \".inst 0x4e9da629  // smmla v9.4s, v17.16b, v29.16b\\n\"\n        \".inst 0x4e83a632  // smmla v18.4s, v17.16b, v3.16b\\n\"\n        \"uzp1 v17.2d, v19.2d, v20.2d\\n\"\n        \"uzp2 v20.2d, v19.2d, v20.2d\\n\"\n        \"fmul v19.4s, v27.4s, v0.s[0]\\n\"\n        \"scvtf v17.4s, v17.4s, #0x4\\n\"\n        \"scvtf v20.4s, v20.4s, #0x4\\n\"\n        \"fmla v25.4s, v17.4s, v19.4s\\n\"\n        \"ldr q19, [x21, #0x0]\\n\"\n        \"fmul v17.4s, v27.4s, v0.s[1]\\n\"\n        \"fmla v5.4s, v20.4s, v17.4s\\n\"\n        \"ldr q17, [x21, #0x10]\\n\"\n        \"uzp1 v20.2d, v9.2d, v18.2d\\n\"\n        \"uzp2 v9.2d, v9.2d, v18.2d\\n\"\n        \"fmul v18.4s, v27.4s, v0.s[2]\\n\"\n        \"fmul v0.4s, v27.4s, v0.s[3]\\n\"\n        \"scvtf v20.4s, v20.4s, #0x4\\n\"\n        \"scvtf v9.4s, v9.4s, #0x4\\n\"\n        \"fmla v7.4s, v20.4s, v18.4s\\n\"\n        \"movi v20.4s, #0x0\\n\"\n        \"movi v18.4s, #0x0\\n\"\n        \".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\\n\"\n        \".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\\n\"\n        \"ldr q19, [x21, #0x20]\\n\"\n        \"fmla v4.4s, v9.4s, v0.4s\\n\"\n        \"movi v9.4s, #0x0\\n\"\n        \"movi v0.4s, #0x0\\n\"\n        \".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\\n\"\n        \"fmul v8.4s, v27.4s, v26.s[0]\\n\"\n        \".inst 0x4e9fa620  // smmla v0.4s, v17.16b, v31.16b\\n\"\n        \"ldr q17, [x21, #0x30]\\n\"\n        \".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\\n\"\n        \"fmul v31.4s, v27.4s, v26.s[1]\\n\"\n        \".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\\n\"\n        \"ldr q19, 
[x21, #0x40]\\n\"\n        \".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\\n\"\n        \"fmul v15.4s, v27.4s, v26.s[2]\\n\"\n        \"fmul v27.4s, v27.4s, v26.s[3]\\n\"\n        \".inst 0x4e81a620  // smmla v0.4s, v17.16b, v1.16b\\n\"\n        \"ldr q1, [x21, #0x50]\\n\"\n        \".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\\n\"\n        \".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\\n\"\n        \"ldr q26, [x21, #0x60]\\n\"\n        \".inst 0x4e95a429  // smmla v9.4s, v1.16b, v21.16b\\n\"\n        \".inst 0x4e90a420  // smmla v0.4s, v1.16b, v16.16b\\n\"\n        \"ldr q21, [x21, #0x70]\\n\"\n        \"add x21, x21, #0x88\\n\"\n        \".inst 0x4e9da754  // smmla v20.4s, v26.16b, v29.16b\\n\"\n        \".inst 0x4e83a752  // smmla v18.4s, v26.16b, v3.16b\\n\"\n        \".inst 0x4e9da6a9  // smmla v9.4s, v21.16b, v29.16b\\n\"\n        \".inst 0x4e83a6a0  // smmla v0.4s, v21.16b, v3.16b\\n\"\n        \"uzp1 v29.2d, v20.2d, v18.2d\\n\"\n        \"uzp2 v21.2d, v20.2d, v18.2d\\n\"\n        \"scvtf v29.4s, v29.4s, #0x4\\n\"\n        \"uzp1 v18.2d, v9.2d, v0.2d\\n\"\n        \"uzp2 v16.2d, v9.2d, v0.2d\\n\"\n        \"scvtf v21.4s, v21.4s, #0x4\\n\"\n        \"fmla v6.4s, v29.4s, v8.4s\\n\"\n        \"scvtf v18.4s, v18.4s, #0x4\\n\"\n        \"scvtf v16.4s, v16.4s, #0x4\\n\"\n        \"fmla v30.4s, v21.4s, v31.4s\\n\"\n        \"fmla v24.4s, v18.4s, v15.4s\\n\"\n        \"fmla v14.4s, v16.4s, v27.4s\\n\"\n        \"bgt 3b\\n\"\n        \"mov x20, %x[res_ptr]\\n\"\n        \"subs x27, x27, #0x4\\n\"\n        \"add %x[res_ptr], %x[res_ptr], #0x10\\n\"\n        \"str q2, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q10, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q12, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q28, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q11, [x20, #0x0]\\n\"\n        \"add x20, x20, 
%x[res_stride]\\n\"\n        \"str q13, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q22, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q23, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q25, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q5, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q7, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q4, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q6, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q30, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q24, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"str q14, [x20, #0x0]\\n\"\n        \"bne 2b\\n\"\n        \"mov x20, #0x4\\n\"\n        \"sub x10, x10, #0x10\\n\"\n        \"cmp x10, #0x10\\n\"\n        \"mov %x[res_ptr], x26\\n\"\n        \"madd %x[a_ptr], x20, x9, %x[a_ptr]\\n\"\n        \"bge 1b\\n\"\n        \"4:\"  // Row loop skip\n        \"cbz x10, 9f\\n\"\n        \"5:\"  // Row tail: Row loop\n        \"add x24, %x[b_ptr], #0x8\\n\"\n        \"mov x23, %x[nc]\\n\"\n        \"add x22, %x[res_ptr], %x[res_stride], LSL #2\\n\"\n        \"6:\"  // Row tail: Column loop\n        \"movi v2.16b, #0x0\\n\"\n        \"movi v10.16b, #0x0\\n\"\n        \"add x25, %x[a_ptr], #0x8\\n\"\n        \"mov x21, %x[nb]\\n\"\n        \"movi v12.16b, #0x0\\n\"\n        \"movi v28.16b, #0x0\\n\"\n        \"7:\"  // Row tail: Block loop\n        \"ldr q6, [x24, #0x0]\\n\"\n        \"ldr q5, [x24, #0x10]\\n\"\n        \"movi v17.16b, #0x4\\n\"\n        \"movi v8.4s, #0x0\\n\"\n        \"ldr q4, [x25, #0x0]\\n\"\n        \"ldr q13, [x25, #0x10]\\n\"\n        \"movi v27.4s, #0x0\\n\"\n        \"movi v0.4s, #0x0\\n\"\n        \"ldr q31, [x24, #0x20]\\n\"\n        \"ldr q14, [x24, #0x30]\\n\"\n       
 \"movi v29.4s, #0x0\\n\"\n        \"movi v22.16b, #0xf0\\n\"\n        \"ldr q11, [x25, #0x20]\\n\"\n        \"ldr q23, [x25, #0x30]\\n\"\n        \"sshl v21.16b, v6.16b, v17.16b\\n\"\n        \"sshl v16.16b, v5.16b, v17.16b\\n\"\n        \"ldr q20, [x25, #0x40]\\n\"\n        \"ldr q26, [x25, #0x50]\\n\"\n        \"and v6.16b, v6.16b, v22.16b\\n\"\n        \"and v5.16b, v5.16b, v22.16b\\n\"\n        \"ldr q25, [x25, #0x60]\\n\"\n        \"ldr q3, [x25, #0x70]\\n\"\n        \"sshl v19.16b, v31.16b, v17.16b\\n\"\n        \"sshl v18.16b, v14.16b, v17.16b\\n\"\n        \"ldr d17, [x25, #-0x8]\\n\"\n        \".inst 0x4e95a488  // smmla v8.4s, v4.16b, v21.16b\\n\"\n        \".inst 0x4e90a49b  // smmla v27.4s, v4.16b, v16.16b\\n\"\n        \"and v31.16b, v31.16b, v22.16b\\n\"\n        \".inst 0x4e95a5a0  // smmla v0.4s, v13.16b, v21.16b\\n\"\n        \".inst 0x4e90a5bd  // smmla v29.4s, v13.16b, v16.16b\\n\"\n        \"and v14.16b, v14.16b, v22.16b\\n\"\n        \"sub x20, x24, #0x8\\n\"\n        \"ldr d16, [x20, #0x0]\\n\"\n        \"subs x21, x21, #0x1\\n\"\n        \"add x25, x25, #0x88\\n\"\n        \"fcvtl v17.4s, v17.4h\\n\"\n        \"add x24, x24, #0x48\\n\"\n        \".inst 0x4e93a568  // smmla v8.4s, v11.16b, v19.16b\\n\"\n        \".inst 0x4e92a57b  // smmla v27.4s, v11.16b, v18.16b\\n\"\n        \".inst 0x4e93a6e0  // smmla v0.4s, v23.16b, v19.16b\\n\"\n        \".inst 0x4e92a6fd  // smmla v29.4s, v23.16b, v18.16b\\n\"\n        \"fcvtl v16.4s, v16.4h\\n\"\n        \".inst 0x4e86a688  // smmla v8.4s, v20.16b, v6.16b\\n\"\n        \".inst 0x4e85a69b  // smmla v27.4s, v20.16b, v5.16b\\n\"\n        \"fmul v23.4s, v16.4s, v17.s[0]\\n\"\n        \"fmul v21.4s, v16.4s, v17.s[1]\\n\"\n        \"fmul v1.4s, v16.4s, v17.s[2]\\n\"\n        \"fmul v20.4s, v16.4s, v17.s[3]\\n\"\n        \".inst 0x4e86a740  // smmla v0.4s, v26.16b, v6.16b\\n\"\n        \".inst 0x4e85a75d  // smmla v29.4s, v26.16b, v5.16b\\n\"\n        \".inst 0x4e9fa728  // smmla v8.4s, v25.16b, 
v31.16b\\n\"\n        \".inst 0x4e8ea73b  // smmla v27.4s, v25.16b, v14.16b\\n\"\n        \".inst 0x4e9fa460  // smmla v0.4s, v3.16b, v31.16b\\n\"\n        \".inst 0x4e8ea47d  // smmla v29.4s, v3.16b, v14.16b\\n\"\n        \"uzp1 v19.2d, v8.2d, v27.2d\\n\"\n        \"uzp2 v18.2d, v8.2d, v27.2d\\n\"\n        \"scvtf v19.4s, v19.4s, #0x4\\n\"\n        \"uzp1 v17.2d, v0.2d, v29.2d\\n\"\n        \"uzp2 v16.2d, v0.2d, v29.2d\\n\"\n        \"scvtf v18.4s, v18.4s, #0x4\\n\"\n        \"fmla v2.4s, v19.4s, v23.4s\\n\"\n        \"scvtf v17.4s, v17.4s, #0x4\\n\"\n        \"scvtf v16.4s, v16.4s, #0x4\\n\"\n        \"fmla v10.4s, v18.4s, v21.4s\\n\"\n        \"fmla v12.4s, v17.4s, v1.4s\\n\"\n        \"fmla v28.4s, v16.4s, v20.4s\\n\"\n        \"bgt 7b\\n\"\n        \"mov x20, %x[res_ptr]\\n\"\n        \"cmp x10, #0x1\\n\"\n        \"str q2, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"ble 8f\\n\"\n        \"cmp x10, #0x2\\n\"\n        \"str q10, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"ble 8f\\n\"\n        \"cmp x10, #0x3\\n\"\n        \"str q12, [x20, #0x0]\\n\"\n        \"add x20, x20, %x[res_stride]\\n\"\n        \"ble 8f\\n\"\n        \"str q28, [x20, #0x0]\\n\"\n        \"8:\"  // Row tail: Accumulator store skip\n        \"subs x23, x23, #0x4\\n\"\n        \"add %x[res_ptr], %x[res_ptr], #0x10\\n\"\n        \"bne 6b\\n\"\n        \"subs x10, x10, #0x4\\n\"\n        \"add %x[a_ptr], %x[a_ptr], x9\\n\"\n        \"mov %x[res_ptr], x22\\n\"\n        \"bgt 5b\\n\"\n        \"9:\"  // Row tail: Row loop skip\n        : [a_ptr] \"+&r\" (a_ptr), [res_ptr] \"+&r\" (res_ptr)\n        : [b_ptr] \"r\" (b_ptr), [nr] \"r\" (nr), [nb] \"r\" (nb), [res_stride] \"r\" (res_stride), [nc] \"r\" (nc)\n        : \"cc\", \"memory\", \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\", \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\", \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\", 
\"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\", \"x9\", \"x10\", \"x20\", \"x21\", \"x22\", \"x23\", \"x24\", \"x25\", \"x26\", \"x27\", \"x28\"\n    );\n    return;\n#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    ggml_gemm_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)\n#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)\n    if (ggml_cpu_get_sve_cnt() == QK8_0) {\n        const void * b_ptr = vx;\n        const void * a_ptr = vy;\n        float * res_ptr = s;\n        size_t res_stride = bs * sizeof(float);\n\n        __asm__ __volatile__(\n            \"mov x20, #0x4\\n\"\n            \"mov x13, %x[nr]\\n\"\n            \"mov z28.s, #-0x4\\n\"\n            \"mov x12, #0x88\\n\"\n            \"ptrue p1.b\\n\"\n            \"whilelt p0.s, XZR, x20\\n\"\n            \"cmp x13, #0x10\\n\"\n            \"mul x12, %x[nb], x12\\n\"\n            \"blt 4f\\n\"\n            \"1:\"  // Row loop\n            \"add x11, %x[b_ptr], #0x10\\n\"\n            \"mov x10, %x[nc]\\n\"\n            \"add x9, %x[res_ptr], %x[res_stride], LSL #4\\n\"\n            \"2:\"  // Column loop\n            \"add x28, %x[a_ptr], #0x8\\n\"\n            \"mov z24.b, #0x0\\n\"\n            \"mov z15.b, #0x0\\n\"\n            \"mov 
x27, %x[nb]\\n\"\n            \"add x26, x28, x12\\n\"\n            \"mov z12.b, #0x0\\n\"\n            \"mov z0.b, #0x0\\n\"\n            \"add x25, x26, x12\\n\"\n            \"mov z13.b, #0x0\\n\"\n            \"mov z1.b, #0x0\\n\"\n            \"add x24, x25, x12\\n\"\n            \"mov z20.b, #0x0\\n\"\n            \"mov z25.b, #0x0\\n\"\n            \"mov z11.b, #0x0\\n\"\n            \"mov z16.b, #0x0\\n\"\n            \"mov z19.b, #0x0\\n\"\n            \"mov z26.b, #0x0\\n\"\n            \"mov z8.b, #0x0\\n\"\n            \"mov z29.b, #0x0\\n\"\n            \"mov z27.b, #0x0\\n\"\n            \"mov z10.b, #0x0\\n\"\n            \"3:\"  // Block loop\n            \"ld1b { z30.b }, p1/Z, [x11]\\n\"\n            \"ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\\n\"\n            \"mov z18.s, #0x0\\n\"\n            \"mov z7.s, #0x0\\n\"\n            \"ld1rqb { z3.b }, p1/Z, [x28]\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x28, #16]\\n\"\n            \"mov z9.s, #0x0\\n\"\n            \"mov z22.s, #0x0\\n\"\n            \"ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\\n\"\n            \"ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\\n\"\n            \"sub x20, x11, #0x10\\n\"\n            \"sub x23, x28, #0x8\\n\"\n            \"lsl z31.b, z30.b, #0x4\\n\"\n            \"lsl z6.b, z21.b, #0x4\\n\"\n            \"ld1h { z23.s }, p1/Z, [x20]\\n\"\n            \"sub x22, x26, #0x8\\n\"\n            \"and z30.b, z30.b, #0xf0\\n\"\n            \"and z21.b, z21.b, #0xf0\\n\"\n            \"sub x21, x25, #0x8\\n\"\n            \"sub x20, x24, #0x8\\n\"\n            \"lsl z14.b, z4.b, #0x4\\n\"\n            \"lsl z2.b, z17.b, #0x4\\n\"\n            \"subs x27, x27, #0x1\\n\"\n            \"add x11, x11, #0x90\\n\"\n            \".inst 0x451f9872  // smmla z18.s, z3.b, z31.b\\n\"\n            \".inst 0x45069867  // smmla z7.s, z3.b, z6.b\\n\"\n            \"ld1rqb { z3.b }, p1/Z, [x28, #32]\\n\"\n            \"and z4.b, z4.b, #0xf0\\n\"\n            \".inst 0x451f98a9  // smmla z9.s, z5.b, 
z31.b\\n\"\n            \".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x28, #48]\\n\"\n            \"and z17.b, z17.b, #0xf0\\n\"\n            \"fcvt z23.s, p1/m, z23.h\\n\"\n            \".inst 0x450e9872  // smmla z18.s, z3.b, z14.b\\n\"\n            \".inst 0x45029867  // smmla z7.s, z3.b, z2.b\\n\"\n            \"ld1rqb { z3.b }, p1/Z, [x28, #64]\\n\"\n            \".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\\n\"\n            \".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x28, #80]\\n\"\n            \"fscale z23.s, p1/m, z23.s, z28.s\\n\"\n            \".inst 0x451e9872  // smmla z18.s, z3.b, z30.b\\n\"\n            \".inst 0x45159867  // smmla z7.s, z3.b, z21.b\\n\"\n            \"ld1rqb { z3.b }, p1/Z, [x28, #96]\\n\"\n            \".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\\n\"\n            \".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x28, #112]\\n\"\n            \"add x28, x28, #0x88\\n\"\n            \".inst 0x45049872  // smmla z18.s, z3.b, z4.b\\n\"\n            \".inst 0x45119867  // smmla z7.s, z3.b, z17.b\\n\"\n            \"ld1h { z3.s }, p0/Z, [x23]\\n\"\n            \".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\\n\"\n            \".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\\n\"\n            \"fcvt z3.s, p1/m, z3.h\\n\"\n            \"uzp1 z5.d, z18.d, z7.d\\n\"\n            \"uzp2 z18.d, z18.d, z7.d\\n\"\n            \"mov z3.q, z3.q[0]\\n\"\n            \"uzp1 z7.d, z9.d, z22.d\\n\"\n            \"uzp2 z22.d, z9.d, z22.d\\n\"\n            \"fmul z9.s, z23.s, z3.s[0]\\n\"\n            \"scvtf z5.s, p1/m, z5.s\\n\"\n            \"scvtf z18.s, p1/m, z18.s\\n\"\n            \"scvtf z7.s, p1/m, z7.s\\n\"\n            \"scvtf z22.s, p1/m, z22.s\\n\"\n            \"fmla z24.s, p1/M, z5.s, z9.s\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x26]\\n\"\n            \"fmul z9.s, z23.s, z3.s[1]\\n\"\n            \"fmla 
z15.s, p1/M, z18.s, z9.s\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x26, #16]\\n\"\n            \"fmul z9.s, z23.s, z3.s[2]\\n\"\n            \"fmul z3.s, z23.s, z3.s[3]\\n\"\n            \"fmla z12.s, p1/M, z7.s, z9.s\\n\"\n            \"mov z9.s, #0x0\\n\"\n            \"ld1h { z7.s }, p0/Z, [x22]\\n\"\n            \".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\\n\"\n            \"fmla z0.s, p1/M, z22.s, z3.s\\n\"\n            \"mov z22.s, #0x0\\n\"\n            \"ld1h { z3.s }, p0/Z, [x21]\\n\"\n            \".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x26, #32]\\n\"\n            \"fcvt z7.s, p1/m, z7.h\\n\"\n            \"fcvt z3.s, p1/m, z3.h\\n\"\n            \".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\\n\"\n            \".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x26, #64]\\n\"\n            \"mov z7.q, z7.q[0]\\n\"\n            \"mov z3.q, z3.q[0]\\n\"\n            \".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\\n\"\n            \".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x26, #96]\\n\"\n            \".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\\n\"\n            \".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\\n\"\n            \"uzp1 z5.d, z9.d, z22.d\\n\"\n            \"scvtf z5.s, p1/m, z5.s\\n\"\n            \"uzp2 z22.d, z9.d, z22.d\\n\"\n            \"fmul z9.s, z23.s, z7.s[0]\\n\"\n            \"scvtf z22.s, p1/m, z22.s\\n\"\n            \"fmla z13.s, p1/M, z5.s, z9.s\\n\"\n            \"ld1rqb { z9.b }, p1/Z, [x25]\\n\"\n            \"fmul z5.s, z23.s, z7.s[1]\\n\"\n            \"fmla z1.s, p1/M, z22.s, z5.s\\n\"\n            \"mov z5.s, #0x0\\n\"\n            \"mov z22.s, #0x0\\n\"\n            \".inst 0x451f9a45  // smmla z5.s, z18.b, z31.b\\n\"\n            \".inst 0x45069a56  // smmla z22.s, z18.b, z6.b\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x26, #48]\\n\"\n            \".inst 0x450e9a45  // smmla 
z5.s, z18.b, z14.b\\n\"\n            \".inst 0x45029a56  // smmla z22.s, z18.b, z2.b\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x26, #80]\\n\"\n            \".inst 0x451e9a45  // smmla z5.s, z18.b, z30.b\\n\"\n            \".inst 0x45159a56  // smmla z22.s, z18.b, z21.b\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x26, #112]\\n\"\n            \"add x26, x26, #0x88\\n\"\n            \".inst 0x45049a45  // smmla z5.s, z18.b, z4.b\\n\"\n            \".inst 0x45119a56  // smmla z22.s, z18.b, z17.b\\n\"\n            \"uzp1 z18.d, z5.d, z22.d\\n\"\n            \"scvtf z18.s, p1/m, z18.s\\n\"\n            \"uzp2 z22.d, z5.d, z22.d\\n\"\n            \"fmul z5.s, z23.s, z7.s[2]\\n\"\n            \"fmul z7.s, z23.s, z7.s[3]\\n\"\n            \"scvtf z22.s, p1/m, z22.s\\n\"\n            \"fmla z20.s, p1/M, z18.s, z5.s\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x25, #16]\\n\"\n            \"ld1h { z5.s }, p0/Z, [x20]\\n\"\n            \"fcvt z5.s, p1/m, z5.h\\n\"\n            \"fmla z25.s, p1/M, z22.s, z7.s\\n\"\n            \"mov z22.s, #0x0\\n\"\n            \"mov z7.s, #0x0\\n\"\n            \".inst 0x451f9936  // smmla z22.s, z9.b, z31.b\\n\"\n            \".inst 0x45069927  // smmla z7.s, z9.b, z6.b\\n\"\n            \"ld1rqb { z9.b }, p1/Z, [x25, #32]\\n\"\n            \"mov z5.q, z5.q[0]\\n\"\n            \".inst 0x450e9936  // smmla z22.s, z9.b, z14.b\\n\"\n            \".inst 0x45029927  // smmla z7.s, z9.b, z2.b\\n\"\n            \"ld1rqb { z9.b }, p1/Z, [x25, #64]\\n\"\n            \".inst 0x451e9936  // smmla z22.s, z9.b, z30.b\\n\"\n            \".inst 0x45159927  // smmla z7.s, z9.b, z21.b\\n\"\n            \"ld1rqb { z9.b }, p1/Z, [x25, #96]\\n\"\n            \".inst 0x45049936  // smmla z22.s, z9.b, z4.b\\n\"\n            \".inst 0x45119927  // smmla z7.s, z9.b, z17.b\\n\"\n            \"uzp1 z9.d, z22.d, z7.d\\n\"\n            \"scvtf z9.s, p1/m, z9.s\\n\"\n            \"uzp2 z22.d, z22.d, z7.d\\n\"\n            \"fmul z7.s, z23.s, z3.s[0]\\n\"\n           
 \"scvtf z22.s, p1/m, z22.s\\n\"\n            \"fmla z11.s, p1/M, z9.s, z7.s\\n\"\n            \"ld1rqb { z9.b }, p1/Z, [x24]\\n\"\n            \"fmul z7.s, z23.s, z3.s[1]\\n\"\n            \"fmla z16.s, p1/M, z22.s, z7.s\\n\"\n            \"mov z22.s, #0x0\\n\"\n            \"mov z7.s, #0x0\\n\"\n            \".inst 0x451f9a56  // smmla z22.s, z18.b, z31.b\\n\"\n            \".inst 0x45069a47  // smmla z7.s, z18.b, z6.b\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x25, #48]\\n\"\n            \".inst 0x450e9a56  // smmla z22.s, z18.b, z14.b\\n\"\n            \".inst 0x45029a47  // smmla z7.s, z18.b, z2.b\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x25, #80]\\n\"\n            \".inst 0x451e9a56  // smmla z22.s, z18.b, z30.b\\n\"\n            \".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x25, #112]\\n\"\n            \"add x25, x25, #0x88\\n\"\n            \".inst 0x45049a56  // smmla z22.s, z18.b, z4.b\\n\"\n            \".inst 0x45119a47  // smmla z7.s, z18.b, z17.b\\n\"\n            \"uzp1 z18.d, z22.d, z7.d\\n\"\n            \"scvtf z18.s, p1/m, z18.s\\n\"\n            \"uzp2 z7.d, z22.d, z7.d\\n\"\n            \"fmul z22.s, z23.s, z3.s[2]\\n\"\n            \"fmul z3.s, z23.s, z3.s[3]\\n\"\n            \"scvtf z7.s, p1/m, z7.s\\n\"\n            \"fmla z19.s, p1/M, z18.s, z22.s\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x24, #16]\\n\"\n            \"fmul z22.s, z23.s, z5.s[0]\\n\"\n            \"fmla z26.s, p1/M, z7.s, z3.s\\n\"\n            \"mov z3.s, #0x0\\n\"\n            \"mov z7.s, #0x0\\n\"\n            \".inst 0x451f9923  // smmla z3.s, z9.b, z31.b\\n\"\n            \".inst 0x45069927  // smmla z7.s, z9.b, z6.b\\n\"\n            \"ld1rqb { z9.b }, p1/Z, [x24, #32]\\n\"\n            \".inst 0x450e9923  // smmla z3.s, z9.b, z14.b\\n\"\n            \".inst 0x45029927  // smmla z7.s, z9.b, z2.b\\n\"\n            \"mov z9.s, #0x0\\n\"\n            \".inst 0x451f9a49  // smmla z9.s, z18.b, z31.b\\n\"\n            
\"mov z31.s, #0x0\\n\"\n            \".inst 0x45069a5f  // smmla z31.s, z18.b, z6.b\\n\"\n            \"ld1rqb { z6.b }, p1/Z, [x24, #48]\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x24, #64]\\n\"\n            \".inst 0x450e98c9  // smmla z9.s, z6.b, z14.b\\n\"\n            \"fmul z14.s, z23.s, z5.s[1]\\n\"\n            \".inst 0x450298df  // smmla z31.s, z6.b, z2.b\\n\"\n            \"ld1rqb { z6.b }, p1/Z, [x24, #80]\\n\"\n            \"fmul z2.s, z23.s, z5.s[2]\\n\"\n            \"fmul z23.s, z23.s, z5.s[3]\\n\"\n            \".inst 0x451e9a43  // smmla z3.s, z18.b, z30.b\\n\"\n            \".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\\n\"\n            \"ld1rqb { z5.b }, p1/Z, [x24, #96]\\n\"\n            \".inst 0x451e98c9  // smmla z9.s, z6.b, z30.b\\n\"\n            \".inst 0x451598df  // smmla z31.s, z6.b, z21.b\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x24, #112]\\n\"\n            \"add x24, x24, #0x88\\n\"\n            \".inst 0x450498a3  // smmla z3.s, z5.b, z4.b\\n\"\n            \".inst 0x451198a7  // smmla z7.s, z5.b, z17.b\\n\"\n            \".inst 0x45049a49  // smmla z9.s, z18.b, z4.b\\n\"\n            \".inst 0x45119a5f  // smmla z31.s, z18.b, z17.b\\n\"\n            \"uzp1 z18.d, z3.d, z7.d\\n\"\n            \"uzp2 z5.d, z3.d, z7.d\\n\"\n            \"scvtf z18.s, p1/m, z18.s\\n\"\n            \"uzp1 z6.d, z9.d, z31.d\\n\"\n            \"uzp2 z9.d, z9.d, z31.d\\n\"\n            \"scvtf z5.s, p1/m, z5.s\\n\"\n            \"fmla z8.s, p1/M, z18.s, z22.s\\n\"\n            \"scvtf z6.s, p1/m, z6.s\\n\"\n            \"scvtf z9.s, p1/m, z9.s\\n\"\n            \"fmla z29.s, p1/M, z5.s, z14.s\\n\"\n            \"fmla z27.s, p1/M, z6.s, z2.s\\n\"\n            \"fmla z10.s, p1/M, z9.s, z23.s\\n\"\n            \"bgt 3b\\n\"\n            \"mov x20, %x[res_ptr]\\n\"\n            \"subs x10, x10, #0x8\\n\"\n            \"add %x[res_ptr], %x[res_ptr], #0x20\\n\"\n            \"st1w { z24.s }, p1, [x20]\\n\"\n            \"add x20, x20, 
%x[res_stride]\\n\"\n            \"st1w { z15.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z12.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z0.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z13.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z1.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z20.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z25.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z11.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z16.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z19.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z26.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z8.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z29.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z27.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"st1w { z10.s }, p1, [x20]\\n\"\n            \"bne 2b\\n\"\n            \"mov x20, #0x4\\n\"\n            \"sub x13, x13, #0x10\\n\"\n            \"cmp x13, #0x10\\n\"\n            \"mov %x[res_ptr], x9\\n\"\n            \"madd %x[a_ptr], x20, x12, %x[a_ptr]\\n\"\n            \"bge 1b\\n\"\n            \"4:\"  // Row loop skip\n            \"cbz x13, 9f\\n\"\n            \"5:\"  // Row tail: Row loop\n            \"add x25, %x[b_ptr], #0x10\\n\"\n            \"mov x24, %x[nc]\\n\"\n            \"add x23, %x[res_ptr], %x[res_stride], LSL #2\\n\"\n            \"6:\"  // Row tail: Column loop\n            \"mov z24.b, #0x0\\n\"\n     
       \"mov z15.b, #0x0\\n\"\n            \"add x28, %x[a_ptr], #0x8\\n\"\n            \"mov x22, %x[nb]\\n\"\n            \"mov z12.b, #0x0\\n\"\n            \"mov z0.b, #0x0\\n\"\n            \"7:\"  // Row tail: Block loop\n            \"ld1b { z3.b }, p1/Z, [x25]\\n\"\n            \"ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\\n\"\n            \"mov z2.s, #0x0\\n\"\n            \"mov z25.s, #0x0\\n\"\n            \"ld1rqb { z26.b }, p1/Z, [x28]\\n\"\n            \"ld1rqb { z21.b }, p1/Z, [x28, #16]\\n\"\n            \"mov z27.s, #0x0\\n\"\n            \"mov z19.s, #0x0\\n\"\n            \"ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\\n\"\n            \"ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\\n\"\n            \"sub x21, x25, #0x10\\n\"\n            \"sub x20, x28, #0x8\\n\"\n            \"lsl z20.b, z3.b, #0x4\\n\"\n            \"lsl z4.b, z6.b, #0x4\\n\"\n            \"ld1rqb { z10.b }, p1/Z, [x28, #32]\\n\"\n            \"ld1rqb { z23.b }, p1/Z, [x28, #48]\\n\"\n            \"and z3.b, z3.b, #0xf0\\n\"\n            \"and z6.b, z6.b, #0xf0\\n\"\n            \"ld1rqb { z11.b }, p1/Z, [x28, #64]\\n\"\n            \"ld1rqb { z7.b }, p1/Z, [x28, #80]\\n\"\n            \"lsl z8.b, z29.b, #0x4\\n\"\n            \"lsl z14.b, z16.b, #0x4\\n\"\n            \"ld1rqb { z18.b }, p1/Z, [x28, #96]\\n\"\n            \"ld1rqb { z30.b }, p1/Z, [x28, #112]\\n\"\n            \".inst 0x45149b42  // smmla z2.s, z26.b, z20.b\\n\"\n            \".inst 0x45049b59  // smmla z25.s, z26.b, z4.b\\n\"\n            \"and z29.b, z29.b, #0xf0\\n\"\n            \"ld1h { z17.s }, p1/Z, [x21]\\n\"\n            \".inst 0x45149abb  // smmla z27.s, z21.b, z20.b\\n\"\n            \".inst 0x45049ab3  // smmla z19.s, z21.b, z4.b\\n\"\n            \"and z16.b, z16.b, #0xf0\\n\"\n            \"ld1h { z4.s }, p0/Z, [x20]\\n\"\n            \"subs x22, x22, #0x1\\n\"\n            \"add x28, x28, #0x88\\n\"\n            \"fcvt z17.s, p1/m, z17.h\\n\"\n            \"add x25, x25, #0x90\\n\"\n            \".inst 
0x45089942  // smmla z2.s, z10.b, z8.b\\n\"\n            \".inst 0x450e9959  // smmla z25.s, z10.b, z14.b\\n\"\n            \"fcvt z4.s, p1/m, z4.h\\n\"\n            \".inst 0x45089afb  // smmla z27.s, z23.b, z8.b\\n\"\n            \".inst 0x450e9af3  // smmla z19.s, z23.b, z14.b\\n\"\n            \"fscale z17.s, p1/m, z17.s, z28.s\\n\"\n            \"mov z4.q, z4.q[0]\\n\"\n            \".inst 0x45039962  // smmla z2.s, z11.b, z3.b\\n\"\n            \".inst 0x45069979  // smmla z25.s, z11.b, z6.b\\n\"\n            \"fmul z23.s, z17.s, z4.s[0]\\n\"\n            \"fmul z9.s, z17.s, z4.s[1]\\n\"\n            \"fmul z21.s, z17.s, z4.s[2]\\n\"\n            \"fmul z4.s, z17.s, z4.s[3]\\n\"\n            \".inst 0x450398fb  // smmla z27.s, z7.b, z3.b\\n\"\n            \".inst 0x450698f3  // smmla z19.s, z7.b, z6.b\\n\"\n            \".inst 0x451d9a42  // smmla z2.s, z18.b, z29.b\\n\"\n            \".inst 0x45109a59  // smmla z25.s, z18.b, z16.b\\n\"\n            \".inst 0x451d9bdb  // smmla z27.s, z30.b, z29.b\\n\"\n            \".inst 0x45109bd3  // smmla z19.s, z30.b, z16.b\\n\"\n            \"uzp1 z31.d, z2.d, z25.d\\n\"\n            \"uzp2 z13.d, z2.d, z25.d\\n\"\n            \"scvtf z31.s, p1/m, z31.s\\n\"\n            \"uzp1 z17.d, z27.d, z19.d\\n\"\n            \"uzp2 z18.d, z27.d, z19.d\\n\"\n            \"scvtf z13.s, p1/m, z13.s\\n\"\n            \"fmla z24.s, p1/M, z31.s, z23.s\\n\"\n            \"scvtf z17.s, p1/m, z17.s\\n\"\n            \"scvtf z18.s, p1/m, z18.s\\n\"\n            \"fmla z15.s, p1/M, z13.s, z9.s\\n\"\n            \"fmla z12.s, p1/M, z17.s, z21.s\\n\"\n            \"fmla z0.s, p1/M, z18.s, z4.s\\n\"\n            \"bgt 7b\\n\"\n            \"mov x20, %x[res_ptr]\\n\"\n            \"cmp x13, #0x1\\n\"\n            \"st1w { z24.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"ble 8f\\n\"\n            \"cmp x13, #0x2\\n\"\n            \"st1w { z15.s }, p1, [x20]\\n\"\n            \"add x20, x20, 
%x[res_stride]\\n\"\n            \"ble 8f\\n\"\n            \"cmp x13, #0x3\\n\"\n            \"st1w { z12.s }, p1, [x20]\\n\"\n            \"add x20, x20, %x[res_stride]\\n\"\n            \"ble 8f\\n\"\n            \"st1w { z0.s }, p1, [x20]\\n\"\n            \"8:\"  // Row tail: Accumulator store skip\n            \"subs x24, x24, #0x8\\n\"\n            \"add %x[res_ptr], %x[res_ptr], #0x20\\n\"\n            \"bne 6b\\n\"\n            \"subs x13, x13, #0x4\\n\"\n            \"add %x[a_ptr], %x[a_ptr], x12\\n\"\n            \"mov %x[res_ptr], x23\\n\"\n            \"bgt 5b\\n\"\n            \"9:\"  // Row tail: Row loop skip\n            : [a_ptr] \"+&r\" (a_ptr), [res_ptr] \"+&r\" (res_ptr)\n            : [b_ptr] \"r\" (b_ptr), [nr] \"r\" (nr), [nb] \"r\" (nb), [res_stride] \"r\" (res_stride), [nc] \"r\" (nc)\n            : \"cc\", \"memory\", \"p0\", \"p1\", \"x9\", \"x10\", \"x11\", \"x12\", \"x13\", \"x20\", \"x21\", \"x22\", \"x23\", \"x24\", \"x25\", \"x26\", \"x27\", \"x28\", \"z0\", \"z1\", \"z2\", \"z3\", \"z4\", \"z5\", \"z6\", \"z7\", \"z8\", \"z9\", \"z10\", \"z11\", \"z12\", \"z13\", \"z14\", \"z15\", \"z16\", \"z17\", \"z18\", \"z19\", \"z20\", \"z21\", \"z22\", \"z23\", \"z24\", \"z25\", \"z26\", \"z27\", \"z28\", \"z29\", \"z30\", \"z31\"\n        );\n        return;\n    }\n#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)\n\n#endif // #if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__)\n    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);\n\n            float32x4_t sumf[4];\n            for (int m = 0; m < 4; m++) {\n                sumf[m] = vdupq_n_f32(0);\n            }\n\n            for (int l = 0; l < nb; l++) {\n                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));\n                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));\n\n                int32x4_t sumi_0 = vdupq_n_s32(0);\n                int32x4_t sumi_1 = vdupq_n_s32(0);\n                int32x4_t sumi_2 = vdupq_n_s32(0);\n                int32x4_t sumi_3 = vdupq_n_s32(0);\n\n                for (int k = 0; k < 4; k++) {\n                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);\n                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);\n\n                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);\n                    int8x16_t b_hi = 
vqtbl1q_s8(kvalues, b >> 4);\n                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);\n\n                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);\n                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);\n                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);\n                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);\n                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);\n                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);\n                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);\n                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);\n                }\n\n                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));\n                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));\n                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));\n                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));\n            }\n\n            for (int m = 0; m < 4; m++) {\n                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);\n            }\n        }\n    }\n    return;\n#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)\n    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if ! 
((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);\n\n            float32x4_t sumf[4];\n            for (int m = 0; m < 4; m++) {\n                sumf[m] = vdupq_n_f32(0);\n            }\n\n            for (int l = 0; l < nb; l++) {\n                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));\n                float32x4_t b_d = {\n                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),\n                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),\n                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),\n                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),\n                };\n\n                int32x4_t sumi_0 = vdupq_n_s32(0);\n                int32x4_t sumi_1 = vdupq_n_s32(0);\n                int32x4_t sumi_2 = vdupq_n_s32(0);\n                int32x4_t sumi_3 = vdupq_n_s32(0);\n\n                for (int k = 0; k < 4; k++) {\n                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);\n                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);\n\n                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);\n                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);\n                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);\n\n                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);\n                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);\n                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);\n                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);\n                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);\n            
        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);\n                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);\n                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);\n                }\n\n                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));\n                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));\n                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));\n                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));\n            }\n\n            for (int m = 0; m < 4; m++) {\n                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);\n            }\n        }\n    }\n    return;\n#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)\n    ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 4;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    q8_k_blocklen = 4;\n    constexpr int    acc_size  = 2 * 4;  // 2 row pairs × 4 col pairs\n    const uint8x16_t m4b       = vdupq_n_u8(0x0f);\n\n    // 8 accumulators: 2 row pairs × 4 col pairs\n    float32x4_t acc_f32[acc_size];\n\n    for (int y = 0; y < nr / q8_k_blocklen; y++) {\n        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n\n            for (int i = 0; i < acc_size; i++) {\n                acc_f32[i] = vdupq_n_f32(0);\n            }\n\n            for (int b = 0; b < nb; b++) {\n                // d4 0 1 2 3, 4 5 6 7\n                float32x4_t q4_d_0123    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));\n                float32x4_t q4_d_4567    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));\n                // d8 0 1 2 3\n                float32x4_t q8_d_0123    = vld1q_f32(q8_ptr[b].d);\n                // mins\n                float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));\n                float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));\n\n                // Precomputation of scales and mins\n                float32x4_t sbd_scale_0123[q8_k_blocklen];\n                float32x4_t sbd_scale_4567[q8_k_blocklen];\n                float32x4_t sbd_min_0123[q8_k_blocklen];\n                float32x4_t sbd_min_4567[q8_k_blocklen];\n\n                sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);\n                sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);\n                sbd_min_0123[0]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);\n                sbd_min_4567[0]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);\n\n                sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);\n                sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);\n                sbd_min_0123[1]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);\n                sbd_min_4567[1]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);\n\n                sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);\n                sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);\n                sbd_min_0123[2]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);\n      
          sbd_min_4567[2]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);\n\n                sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);\n                sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);\n                sbd_min_0123[3]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);\n                sbd_min_4567[3]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);\n\n                // Precomputation of bsums, each vpaddq calcs all the bsums for each row\n                const int16x8_t bsums[q8_k_blocklen] = {\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),\n                };\n                int16_t bsums_arr[QK_K / 64][8];\n                for (int q8_row = 0; q8_row < 4; q8_row++) {\n                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);\n                }\n\n                // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..\n                int32x4_t bias_acc[acc_size];\n                for (int i = 0; i < acc_size; i++) {\n                    bias_acc[i] = vdupq_n_s32(0);\n                }\n\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n                    // Int accumulators for qs vecdot (4 row x 2 col quartets)\n                    int32x4_t acc_lo[acc_size];\n                    int32x4_t acc_hi[acc_size];\n                    for (int i = 0; i < acc_size; i++) {\n                        acc_lo[i] = vdupq_n_s32(0);\n                        acc_hi[i] = vdupq_n_s32(0);\n                    }\n                    // Need scales for the low and high nibbles\n                    
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                    int16x8_t q4sb_scales[2];\n                    int16x8_t q4sb_mins[2];\n                    for (int i = 0; i < 2; i++) {\n                        int8_t    aux_q4sb[8];\n                        const int offset = sb * 24 + i * 12;\n                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);\n                        q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));\n                    }\n\n                    constexpr int reads_per_sb = 8;  // 8 * 16 bytes each => 32 qs * 4 rows\n                    for (int k = 0; k < reads_per_sb; k++) {\n                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);\n                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);\n\n                        // 0..3 & 32..35\n                        const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);\n                        const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);\n\n                        const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));\n                        const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));\n\n                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0);  //  0..3  r0 c0123\n                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1);  //  0..3  r1 c0123\n                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2);  //  0..3  r2 c0123\n                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3);  //  0..3  r3 c0123\n\n                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0);  // 32..35 r0 c0123\n                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1);  // 32..35 r1 c0123\n                        
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2);  // 32..35 r2 c0123\n                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3);  // 32..35 r3 c0123\n\n                        const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));\n                        const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));\n\n                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0);  //  0..3  r0 c4567\n                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1);  //  0..3  r1 c4567\n                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2);  //  0..3  r2 c4567\n                        acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3);  //  0..3  r3 c4567\n\n                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0);  // 32..35 r0 c4567\n                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1);  // 32..35 r1 c4567\n                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2);  // 32..35 r2 c4567\n                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3);  // 32..35 r3 c4567\n                    }\n\n                    // Scale and bias application\n                    // acc is stored interleaved to match output layout\n                    const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);\n                    const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);\n                    const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);\n                    const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);\n                    for (int row = 0; row < q8_k_blocklen; row++) {\n                        // Bias correction\n                        // row c0123 blk0 and blk1\n                        const float32x4_t sumf_0123 =\n                            
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),\n                                                    vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));\n                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);\n\n                        // row c4567 blk0 and blk1\n                        const float32x4_t sumf_4567 =\n                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),\n                                                    vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));\n                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);\n\n                        // Bias\n                        const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);\n                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);\n\n                        // row c0123 blk0 and blk1\n                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));\n                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));\n\n                        // row c4567 blk0 and blk1\n                        bias_acc[2 * row + 1] =\n                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));\n                        bias_acc[2 * row + 1] =\n                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));\n                    }\n                }  // for sb\n\n                for (int row = 0; row < q8_k_blocklen; row++) {\n                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);\n                    acc_f32[2 * row + 1] =\n                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);\n             
   }\n            }  // for b\n\n            for (int i = 0; i < q8_k_blocklen; i++) {\n                int row = y * q8_k_blocklen + i;\n                for (int j = 0; j < 2; j++) {\n                    int col    = x * ncols_interleaved + j * 4;\n                    int offset = row * bs + col;\n                    vst1q_f32(s + offset, acc_f32[2 * i + j]);\n                }\n            }\n        }  // for x\n    }  // for y\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q5_K_8x4_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 4;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    q8_k_blocklen = 4;\n    constexpr int    acc_size      = 2 * 4;  // 2 row pairs, 4 col pairs\n    constexpr int    col_groups    = ncols_interleaved / 4;\n    const uint8x16_t m4b           = vdupq_n_u8(0x0f);\n    const uint8x16_t mone          = vdupq_n_u8(1);\n    const uint8x16_t mtwo          = vdupq_n_u8(2);\n\n    // 8 accumulators: 2 row pairs, 4 col pairs\n    float32x4_t acc_f32[acc_size];\n\n    for (int y = 0; y < nr / q8_k_blocklen; y++) {\n        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y 
* nb);\n\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);\n\n            for (int i = 0; i < acc_size; i++) {\n                acc_f32[i] = vdupq_n_f32(0);\n            }\n\n            for (int b = 0; b < nb; b++) {\n                // d5 0 1 2 3, 4 5 6 7\n                float32x4_t q5_d_0123    = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));\n                float32x4_t q5_d_4567    = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));\n                // d8 0 1 2 3\n                float32x4_t q8_d_0123    = vld1q_f32(q8_ptr[b].d);\n                // mins\n                float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));\n                float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));\n\n                // Precomputation of scales and mins\n                float32x4_t sbd_scale_0123[q8_k_blocklen];\n                float32x4_t sbd_scale_4567[q8_k_blocklen];\n                float32x4_t sbd_min_0123[q8_k_blocklen];\n                float32x4_t sbd_min_4567[q8_k_blocklen];\n\n                sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);\n                sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);\n                sbd_min_0123[0]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);\n                sbd_min_4567[0]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);\n\n                sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);\n                sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);\n                sbd_min_0123[1]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);\n                sbd_min_4567[1]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);\n\n                sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);\n                sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);\n      
          sbd_min_0123[2]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);\n                sbd_min_4567[2]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);\n\n                sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);\n                sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);\n                sbd_min_0123[3]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);\n                sbd_min_4567[3]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);\n\n                // Precomputation of bsums, each vpaddq calcs all the bsums for each row\n                const int16x8_t bsums[q8_k_blocklen] = {\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),\n                };\n                int16_t bsums_arr[QK_K / 64][8];\n                for (int q8_row = 0; q8_row < 4; q8_row++) {\n                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);\n                }\n\n                // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..\n                int32x4_t bias_acc[acc_size];\n                for (int i = 0; i < acc_size; i++) {\n                    bias_acc[i] = vdupq_n_s32(0);\n                }\n\n                uint8x16_t qh[col_groups][8];\n                for (int c = 0; c < col_groups; c++) {\n                    for (int i = 0; i < 8; i++) {\n                        qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);\n                    }\n                }\n\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n                    // Int accumulators for qs vecdot (4 row * 2 col quartets)\n              
      int32x4_t acc_lo[acc_size];\n                    int32x4_t acc_hi[acc_size];\n                    for (int i = 0; i < acc_size; i++) {\n                        acc_lo[i] = vdupq_n_s32(0);\n                        acc_hi[i] = vdupq_n_s32(0);\n                    }\n                    // Need scales for the low and high nibbles\n                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                    int16x8_t q5sb_scales[2];\n                    int16x8_t q5sb_mins[2];\n                    for (int i = 0; i < 2; i++) {\n                        int8_t    aux_q5sb[8];\n                        const int offset = sb * 24 + i * 12;\n                        decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);\n                        q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));\n                    }\n\n                    constexpr int reads_per_sb = 8;  // 8 * 16 bytes each => 32 qs * 4 rows\n                    for (int k = 0; k < reads_per_sb; k++) {\n                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);\n                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);\n\n                        // 0..3 & 32..35\n                        const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);\n                        const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);\n\n                        // NOTE: This is the only difference with q4_K\n                        const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);\n                        const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);\n                        qh[0][k]                      = vshrq_n_u8(qh[0][k], 2);\n                        const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);\n                        const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);\n           
             qh[1][k]                      = vshrq_n_u8(qh[1][k], 2);\n                        // From here, same as q4_K\n\n                        const int8x16_t q5_0123_lo =\n                            vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));\n                        const int8x16_t q5_0123_hi =\n                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));\n\n                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0);  //  0..3  r0 c0123\n                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1);  //  0..3  r1 c0123\n                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2);  //  0..3  r2 c0123\n                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3);  //  0..3  r3 c0123\n\n                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0);  // 32..35 r0 c0123\n                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1);  // 32..35 r1 c0123\n                        acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2);  // 32..35 r2 c0123\n                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3);  // 32..35 r3 c0123\n\n                        const int8x16_t q5_4567_lo =\n                            vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));\n                        const int8x16_t q5_4567_hi =\n                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));\n\n                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0);  //  0..3  r0 c4567\n                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1);  //  0..3  r1 c4567\n                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2);  //  0..3  r2 c4567\n                        acc_lo[7] = 
vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3);  //  0..3  r3 c4567\n\n                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0);  // 32..35 r0 c4567\n                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1);  // 32..35 r1 c4567\n                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2);  // 32..35 r2 c4567\n                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3);  // 32..35 r3 c4567\n                    }\n\n                    // Scale and bias application\n                    // acc is stored interleaved to match output layout\n                    const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);\n                    const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);\n                    const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);\n                    const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);\n                    for (int row = 0; row < q8_k_blocklen; row++) {\n                        // Bias correction\n                        // row c0123 blk0 and blk1\n                        const float32x4_t sumf_0123 =\n                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),\n                                                    vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));\n                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);\n\n                        // row c4567 blk0 and blk1\n                        const float32x4_t sumf_4567 =\n                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),\n                                                    vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));\n                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);\n\n                        // Bias\n                        
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);\n                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);\n\n                        // row c0123 blk0 and blk1\n                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));\n                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));\n\n                        // row c4567 blk0 and blk1\n                        bias_acc[2 * row + 1] =\n                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));\n                        bias_acc[2 * row + 1] =\n                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));\n                    }\n                }  // for sb\n\n                for (int row = 0; row < q8_k_blocklen; row++) {\n                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);\n                    acc_f32[2 * row + 1] =\n                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);\n                }\n            }  // for b\n\n            for (int i = 0; i < q8_k_blocklen; i++) {\n                int row = y * q8_k_blocklen + i;\n                for (int j = 0; j < 2; j++) {\n                    int col    = x * ncols_interleaved + j * 4;\n                    int offset = row * bs + col;\n                    vst1q_f32(s + offset, acc_f32[2 * i + j]);\n                }\n            }\n        }  // for x\n    }  // for y\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q4_K_8x8_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t       
              bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)\n    if (svcntb() * 8 == 256) {\n        constexpr int    q8_k_blocklen = 4;\n        const svuint8_t m4b_1          = svdup_n_u8(0x0f);\n        // 8 accumulators: 2 row pairs × 4 col pairs\n        svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;\n        uint32_t idx_arr[8] = { 0, 2, 4, 6,  1, 3, 5, 7 };\n        svbool_t pg = svptrue_pat_b32(SV_VL8);\n        svuint32_t idx = svld1(pg, idx_arr);\n\n        static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};\n        svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);\n\n        for (int y = 0; y < nr / q8_k_blocklen; y++) {\n            const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n\n            for (int x = 0; x < nc / ncols_interleaved; x++) {\n                const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n\n                acc_f32_01 = svdup_n_f32(0);\n                acc_f32_23 = svdup_n_f32(0);\n                acc_f32_45 = svdup_n_f32(0);\n                acc_f32_67 = svdup_n_f32(0);\n\n                for (int b = 0; b < nb; b++) {\n                    // bsums pairs belongs to the same q8_k subblock\n                    // 64 elements loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum\n                    const int16x8_t 
bsums[4]{\n                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),\n                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),\n                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),\n                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),\n                    };\n\n                    int32_t bsums_arr32[4][8];\n\n                    for (int q8_row = 0; q8_row < 4; q8_row++) {\n                        int16x8_t v16 = bsums[q8_row];\n\n                        // low 4\n                        int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));\n                        vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);\n\n                        // high 4\n                        int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));\n                        vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);\n                    }\n\n                    svint32_t sb_acc_0 = svdup_n_s32(0);\n                    svint32_t sb_acc_2 = svdup_n_s32(0);\n\n                    svint32_t acc_00 = svdup_n_s32(0);\n                    svint32_t acc_11 = svdup_n_s32(0);\n                    svint32_t acc_22 = svdup_n_s32(0);\n                    svint32_t acc_33 = svdup_n_s32(0);\n                    svint32_t acc_44 = svdup_n_s32(0);\n                    svint32_t acc_55 = svdup_n_s32(0);\n                    svint32_t acc_66 = svdup_n_s32(0);\n                    svint32_t acc_77 = svdup_n_s32(0);\n\n                    svint32_t bias_acc_00 = svdup_n_s32(0);\n                    svint32_t bias_acc_22 = svdup_n_s32(0);\n                    svint32_t bias_acc_44 = svdup_n_s32(0);\n                    svint32_t bias_acc_66 = svdup_n_s32(0);\n\n                    for (int sb = 0; sb < QK_K / 64; sb++) {\n                        // Need scales for the low and high nibbles\n 
                       // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                        svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;\n                        svint32_t q4sb_mins_0, q4sb_mins_1;\n                        {\n                            // 2-superblock I am working on\n                            const int offset = sb * 24 + 0 * 12;\n                            const uint8_t * scales_in = &q4_ptr[b].scales[offset];\n\n                            const int offset1 = sb * 24 + 12;\n                            const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];\n\n                            constexpr uint32_t kmask1 = 0x3f3f3f3f;\n                            constexpr uint32_t kmask2 = 0x0f0f0f0f;\n                            constexpr uint32_t kmask3 = 0x03030303;\n                            constexpr uint8_t  scales_size = 12;\n\n                            uint32_t sm[3];\n                            memcpy(sm, scales_in, scales_size);\n\n                            uint32_t sm1[3];\n                            memcpy(sm1, scales_in1, scales_size);\n\n                            const uint32_t mins_0_3 = sm[1] & kmask1;\n                            const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);\n\n                            const uint32_t mins_0_3_1 = sm1[1] & kmask1;\n                            const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);\n\n                            svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));\n                            svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));\n\n                            /* reinterpret u32 → u8 */\n                            svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);\n                            svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);\n\n          
                  /* widen u8 → u16->u32 (lower half only) */\n                            svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));\n                            svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));\n\n                            q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);\n                            q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);\n\n                            uint32_t scales_u32_0 = sm[0] & kmask1;\n                            uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);\n                            uint32_t scales_u32_2 = sm1[0] & kmask1;\n                            uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);\n\n                            svuint32_t S01 = svdup_n_u32(scales_u32_0);\n                            svuint32_t S23 = svdup_n_u32(scales_u32_1);\n                            svuint32_t R01 = svdup_n_u32(scales_u32_2);\n                            svuint32_t R23 = svdup_n_u32(scales_u32_3);\n\n                            svint8_t S01_b = svreinterpret_s8_u32(S01);\n                            svint8_t S23_b = svreinterpret_s8_u32(S23);\n                            svint8_t R01_b = svreinterpret_s8_u32(R01);\n                            svint8_t R23_b = svreinterpret_s8_u32(R23);\n\n                            svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));\n                            svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));\n                            svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));\n                            svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));\n\n                            block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);\n                            block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);\n                            block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), 
idx);\n                            block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);\n                        }\n\n                        const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;\n\n                        // Load 32-byte per row pair, 1 subblock each time\n                        // predicate for activating higher lanes for 16 int8 elements\n                        const svbool_t ph16 = svptrue_pat_b8(SV_VL16);\n                        // predicate for activating lower lanes for  16 int8 elements\n                        const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);\n\n                        svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));\n                        svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));\n                        svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));\n                        svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));\n\n                        svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));\n                        svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));\n                        svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));\n                        svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));\n\n                        // Q4s columns iterated in pairs (01, 23, 45, 67)\n                        for (int cp = 0; cp < ncols_interleaved / 2; cp++) {\n\n                            sb_acc_0 = svdup_n_s32(0);\n                            sb_acc_2 = svdup_n_s32(0);\n\n                            svuint8_t 
q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);\n                            svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);\n                            svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);\n                            svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);\n\n                            svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));\n                            svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));\n                            svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));\n                            svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));\n\n                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);\n                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);\n\n                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);\n                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);\n\n                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);\n                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);\n\n                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);\n                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);\n\n                            if(cp == 0) {\n                                acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);\n                                acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);\n                            }\n                            
if(cp == 1) {\n                                acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);\n                                acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);\n                            }\n                            if(cp == 2) {\n                                acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);\n                                acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);\n                            }\n                            if(cp == 3) {\n                                acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);\n                                acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);\n                            }\n                        }\n\n                        bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);\n                        bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);\n\n                        bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);\n                        bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);\n\n                        bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);\n                        bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);\n\n                        bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);\n                        bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);\n                    }  // for sb\n\n\n                    acc_00 
= svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));\n                    acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));\n                    acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));\n                    acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));\n                    acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));\n                    acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));\n                    acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));\n                    acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));\n\n                    svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);\n                    svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);\n\n                    svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);\n                    svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);\n\n                    // Broadcast q8 scalar\n                    svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);\n\n                    svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));\n\n                    svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));\n\n                    svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);\n                    svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, 
q8_d);\n\n                    acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);\n                    acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);\n\n                    q8_d = svdup_f32(q8_ptr[b].d[1]);\n\n                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);\n                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);\n\n                    acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);\n                    acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);\n\n                    q8_d = svdup_f32(q8_ptr[b].d[2]);\n\n\n                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);\n                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);\n\n                    acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);\n                    acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);\n\n                    q8_d = svdup_f32(q8_ptr[b].d[3]);\n\n                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);\n                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);\n\n                    acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);\n                    acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);\n\n                }  // for b\n\n                // With the previous reorder, the tile is already in the correct memory layout.\n                // Predicate for exactly 4 lanes\n          
      svbool_t pg4 = svptrue_pat_b32(SV_VL4);\n                for (int i = 0; i < q8_k_blocklen; i++) {\n                    int row = y * q8_k_blocklen + i;\n                    for (int j = 0; j < 2; j++) {\n                        int col    = x * ncols_interleaved + j * 4;\n                        int offset = row * bs + col;\n\n                        if (i == 0 && j == 0) {\n                            // acc_f32_0 → lower half of acc_f32_01\n                            svst1_f32(pg4, s + offset, acc_f32_01);\n                        } else if (i == 0 && j == 1) {\n                            // acc_f32_1 → upper half of acc_f32_01\n                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));\n                        } else if (i == 1 && j == 0) {\n                            // acc_f32_2\n                            svst1_f32(pg4, s + offset, acc_f32_23);\n                        } else if (i == 1 && j == 1) {\n                            // acc_f32_3\n                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));\n                        } else if (i == 2 && j == 0) {\n                            // acc_f32_4\n                            svst1_f32(pg4, s + offset, acc_f32_45);\n                        } else if (i == 2 && j == 1) {\n                            // acc_f32_5\n                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));\n                        } else if (i == 3 && j == 0) {\n                            // acc_f32_6\n                            svst1_f32(pg4, s + offset, acc_f32_67);\n                        } else if (i == 3 && j == 1) {\n                            // acc_f32_7\n                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));\n                        }\n                    }\n                }\n            }  // for x\n        }  // for y\n        return;\n    }\n#endif  // SVE compile-time 
end\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    constexpr int    q8_k_blocklen = 4;\n    const uint8x16_t m4b           = vdupq_n_u8(0x0f);\n\n    // 8 accumulators: 2 row pairs × 4 col pairs\n    float32x4_t acc_f32[blocklen];\n\n    for (int y = 0; y < nr / q8_k_blocklen; y++) {\n        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n\n            for (int i = 0; i < blocklen; i++) {\n                acc_f32[i] = vdupq_n_f32(0);\n            }\n\n            for (int b = 0; b < nb; b++) {\n                // bsums pairs belongs to the same q8_k subblock\n                const int16x8_t bsums[4]{\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),\n                };\n                int16_t bsums_arr[4][8];\n                for (int q8_row = 0; q8_row < 4; q8_row++) {\n                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);\n                }\n\n                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results\n                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]\n                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...\n                for (int i = 0; i < 8; i++) {\n                    acc[i]      = vdupq_n_s32(0);\n                    bias_acc[i] = vdupq_n_s32(0);\n                }\n\n            
    for (int sb = 0; sb < QK_K / 64; sb++) {\n                    // Need scales for the low and high nibbles\n                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                    int8_t    q4sb_scales[2][8];\n                    int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later\n                    for (int i = 0; i < 2; i++) {\n                        const int offset = sb * 24 + i * 12;\n                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);\n                    }\n\n                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)\n                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;\n\n                    int8x16_t q8_qs_01[8];\n                    int8x16_t q8_qs_23[8];\n\n                    // Load 32-byte per row pair, 1 subblock each time\n                    for (int i = 0; i < 8; i++) {\n                        const int offset = i * 32;  // 16 for row 01, 16 for row 23\n                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);\n                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);\n                    }\n\n                    const int8x16_t q8s[2][8] = {\n                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],\n                          q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },\n                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],\n                          q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },\n                    };\n\n                    // Q4s columns iterated in pairs (01, 23, 45, 67)\n                    for (int cp = 0; cp < ncols_interleaved / 2; cp++) {\n                        for (int i = 0; i < 4; i++) {\n                            sb_acc[i] = vdupq_n_s32(0);\n                        }\n\n                        uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 
7 & 32..39\n                        uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47\n                        uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55\n                        uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63\n                        const int8x16_t q4_nibbles[2][4] = {\n                            {\n                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),\n                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),\n                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),\n                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),\n                            },\n                            {\n                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),\n                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),\n                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),\n                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),\n                            }\n                        };\n\n                        // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8\n                        // for each of the internal 32 qs subblock (blk)\n                        for (int rp = 0; rp < 2; rp++) {\n                            for (int blk = 0; blk < 2; blk++) {\n                                const int8x16_t * q8  = &q8s[rp][4 * blk];\n                                const int8x16_t * q4  = q4_nibbles[blk];\n                                int32x4_t         acc = sb_acc[2 * rp + blk];\n                                // mul add for each qs in the same subblock\n                                for (int qs_offset = 0; qs_offset < 4; qs_offset++) {\n                                    acc = vmmlaq_s32(acc, q4[qs_offset], 
q8[qs_offset]);\n                                }\n                                sb_acc[2 * rp + blk] = acc;\n                            }\n                        }\n\n                        // Scales[i] corresponds to column i\n                        const int scale_offset = cp * 2;\n                        const int32_t scale_00 = q4sb_scales[0][scale_offset];\n                        const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];\n                        const int32_t scale_10 = q4sb_scales[1][scale_offset];\n                        const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];\n                        const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));\n                        const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));\n\n                        acc[cp]     = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);\n                        acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);\n                        acc[cp]     = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);\n                        acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);\n                    }\n\n                    // Multiply Acc bsum + mins\n                    for (int q8_row = 0; q8_row < 4; q8_row++) {\n                        // Each pair of subblocks share the same bsums\n                        // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).\n                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);\n                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);\n\n                        bias_acc[2 * q8_row] =\n                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));\n                        bias_acc[2 * q8_row] =\n                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));\n            
            bias_acc[2 * q8_row + 1] =\n                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));\n                        bias_acc[2 * q8_row + 1] =\n                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));\n                    }\n                }  // for sb\n\n                // Reorder of i8mm output with bias and output layout\n                for (int i = 0; i < 8; i++) {\n                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));\n                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);\n                }\n                int32x4_t reorder_acc[8] = {\n                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),\n                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),\n                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),\n                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),\n                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),\n                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),\n                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),\n                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),\n                };\n\n                for (int i = 0; i < q8_k_blocklen; i++) {\n                    for (int j = 0; j < 2; j++) {\n                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);\n                        float32x4_t       q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));\n                        const float32x4_t dmins   = vmulq_f32(q4_dmin, q8_d);\n\n                        float32x4_t       q4_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));\n                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);\n\n                        acc_f32[2 
* i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);\n                        acc_f32[2 * i + j] =\n                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);\n                    }\n                }\n            }  // for b\n\n            // With the previous reorder, the tile is already in the correct memory layout.\n            for (int i = 0; i < q8_k_blocklen; i++) {\n                int row = y * q8_k_blocklen + i;\n                for (int j = 0; j < 2; j++) {\n                    int col    = x * ncols_interleaved + j * 4;\n                    int offset = row * bs + col;\n                    vst1q_f32(s + offset, acc_f32[2 * i + j]);\n                }\n            }\n        }  // for x\n    }  // for y\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q5_K_8x8_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    constexpr int    q8_k_blocklen = 4;\n    constexpr int    col_pairs     = ncols_interleaved / 2;\n    const uint8x16_t m4b           = vdupq_n_u8(0x0f);\n    const uint8x16_t mone    
      = vdupq_n_u8(1);\n    const uint8x16_t mtwo          = vdupq_n_u8(2);\n\n    // 8 accumulators: 2 row pairs × 4 col pairs\n    float32x4_t acc_f32[blocklen];\n\n    for (int y = 0; y < nr / q8_k_blocklen; y++) {\n        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);\n\n            for (int i = 0; i < blocklen; i++) {\n                acc_f32[i] = vdupq_n_f32(0);\n            }\n\n            for (int b = 0; b < nb; b++) {\n                // bsums pairs belongs to the same q8_k subblock\n                const int16x8_t bsums[4]{\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),\n                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),\n                };\n                int16_t bsums_arr[4][8];\n                for (int q8_row = 0; q8_row < 4; q8_row++) {\n                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);\n                }\n\n                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results\n                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]\n                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...\n                for (int i = 0; i < 8; i++) {\n                    acc[i]      = vdupq_n_s32(0);\n                    bias_acc[i] = vdupq_n_s32(0);\n                }\n\n                // Load qh once per block and shift after each subblock\n                const uint8_t * qh_base = 
q5_ptr[b].qh;\n                uint8x16_t      qh[col_pairs][4];\n                for (int cp = 0; cp < col_pairs; cp++) {\n                    qh[cp][0] = vld1q_u8(qh_base + 16 * cp);\n                    qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);\n                    qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);\n                    qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);\n                }\n\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n                    // Need scales for the low and high nibbles\n                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total\n                    int8_t    q5sb_scales[2][8];\n                    int16x8_t q5sb_mins[2];  // int16 as its needed for bias_acc later\n                    for (int i = 0; i < 2; i++) {\n                        const int offset = sb * 24 + i * 12;\n                        decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);\n                    }\n\n                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)\n                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;\n\n                    int8x16_t q8_qs_01[8];\n                    int8x16_t q8_qs_23[8];\n\n                    // Load 32-byte per row pair, 1 subblock each time\n                    for (int i = 0; i < 8; i++) {\n                        const int offset = i * 32;  // 16 for row 01, 16 for row 23\n                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);\n                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);\n                    }\n\n                    const int8x16_t q8s[2][8] = {\n                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],\n                         q8_qs_01[7] },\n                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],\n                         q8_qs_23[7] },\n         
           };\n\n                    // Q5s columns iterated in pairs (01, 23, 45, 67)\n                    for (int cp = 0; cp < col_pairs; cp++) {\n                        for (int i = 0; i < 4; i++) {\n                            sb_acc[i] = vdupq_n_s32(0);\n                        }\n\n                        uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39\n                        uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47\n                        uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55\n                        uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63\n\n                        // This is the only part of the algorithm that differs with Q4_K\n                        // Extract High bits and pack into 5 bit weights\n                        uint8x16_t hbit_lo_0    = vandq_u8(qh[cp][0], mone);\n                        uint8x16_t hbit_hi_0    = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);\n                        qh[cp][0]               = vshrq_n_u8(qh[cp][0], 2);\n                        // Same as Q4_K, i8mm to dequantize the weights.\n                        const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));\n                        int32x4_t       acc_0   = sb_acc[0];\n                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);\n                        int32x4_t acc_2         = sb_acc[2];\n                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);\n                        const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));\n                        int32x4_t       acc_1   = sb_acc[1];\n                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);\n                        int32x4_t acc_3         = 
sb_acc[3];\n                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);\n\n                        // Repeat for the other 3 columns (8..15, 16..23, 24..31)\n                        uint8x16_t hbit_hi_1    = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);\n                        uint8x16_t hbit_lo_1    = vandq_u8(qh[cp][1], mone);\n                        qh[cp][1]               = vshrq_n_u8(qh[cp][1], 2);\n                        const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));\n                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);\n                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);\n                        const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));\n                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);\n                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);\n\n                        uint8x16_t hbit_hi_2    = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);\n                        uint8x16_t hbit_lo_2    = vandq_u8(qh[cp][2], mone);\n                        qh[cp][2]               = vshrq_n_u8(qh[cp][2], 2);\n                        const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));\n                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);\n                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);\n                        const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));\n                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);\n                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);\n\n                        uint8x16_t hbit_lo_3    = vandq_u8(qh[cp][3], mone);\n                        uint8x16_t 
hbit_hi_3    = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);\n                        qh[cp][3]               = vshrq_n_u8(qh[cp][3], 2);\n                        const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));\n                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);\n                        sb_acc[0]               = acc_0;\n                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);\n                        sb_acc[2]               = acc_2;\n\n                        // Scales[i] corresponds to column i\n                        const int       scale_offset = cp * 2;\n                        const int32_t   s0           = q5sb_scales[0][scale_offset];\n                        const int32_t   s1           = q5sb_scales[0][scale_offset + 1];\n                        const int32x4_t block_scale  = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));\n                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[0], block_scale);\n                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);\n\n                        const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));\n                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);\n                        sb_acc[1]               = acc_1;\n                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);\n                        sb_acc[3]               = acc_3;\n\n                        const int32_t   s2           = q5sb_scales[1][scale_offset];\n                        const int32_t   s3           = q5sb_scales[1][scale_offset + 1];\n                        const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));\n                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);\n                        acc[cp + 4]      
            = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);\n                    }\n\n                    // Multiply Acc bsum + mins\n                    for (int q8_row = 0; q8_row < 4; q8_row++) {\n                        // Each pair of subblocks share the same bsums\n                        // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).\n                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);\n                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);\n\n                        bias_acc[2 * q8_row] =\n                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));\n                        bias_acc[2 * q8_row] =\n                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));\n                        bias_acc[2 * q8_row + 1] =\n                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));\n                        bias_acc[2 * q8_row + 1] =\n                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));\n                    }\n                }  // for sb\n\n                // Reorder of i8mm output with bias and output layout\n                for (int i = 0; i < 8; i++) {\n                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));\n                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);\n                }\n                int32x4_t reorder_acc[8] = {\n                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),\n                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),\n                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),\n                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),\n                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),\n                   
 vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),\n                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),\n                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),\n                };\n\n                for (int i = 0; i < q8_k_blocklen; i++) {\n                    for (int j = 0; j < 2; j++) {\n                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);\n                        float32x4_t       q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));\n                        const float32x4_t dmins   = vmulq_f32(q5_dmin, q8_d);\n\n                        float32x4_t       q5_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));\n                        const float32x4_t scale = vmulq_f32(q5_d, q8_d);\n\n                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);\n                        acc_f32[2 * i + j] =\n                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);\n                    }\n                }\n            }  // for b\n\n            // With the previous reorder, the tile is already in the correct memory layout.\n            for (int i = 0; i < q8_k_blocklen; i++) {\n                int row = y * q8_k_blocklen + i;\n                for (int j = 0; j < 2; j++) {\n                    int col    = x * ncols_interleaved + j * 4;\n                    int offset = row * bs + col;\n                    vst1q_f32(s + offset, acc_f32[2 * i + j]);\n                }\n            }\n        }  // for x\n    }  // for y\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q6_K_8x4_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t    
                 bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 4;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    constexpr int    q8_k_blocklen = 4;\n    constexpr int    col_groups    = ncols_interleaved / 4;\n    constexpr int    acc_size      = q8_k_blocklen * col_groups;  // 4 rows, 2 column groups\n    const uint8x16_t m4b           = vdupq_n_u8(0x0f);\n    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);\n    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);\n    const int8x16_t  m32s          = vdupq_n_s8(32);\n\n    float32x4_t acc_f32[acc_size];\n\n    for (int y = 0; y < nr / q8_k_blocklen; y++) {\n        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);\n\n            for (int i = 0; i < acc_size; i++) {\n                acc_f32[i] = vdupq_n_f32(0);\n            }\n\n            for (int b = 0; b < nb; b++) {\n                float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));\n                float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));\n                float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);\n\n                float32x4_t sbd_scale_0123[q8_k_blocklen];\n                float32x4_t sbd_scale_4567[q8_k_blocklen];\n\n                
sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);\n                sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);\n                sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);\n                sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);\n                sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);\n                sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);\n                sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);\n                sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);\n\n                int32x4_t acc_s32[acc_size];\n                for (int i = 0; i < acc_size; i++) {\n                    acc_s32[i] = vdupq_n_s32(0);\n                }\n\n                int16_t q6_scales[8 * 16];\n                for (int i = 0; i < 16; i++) {\n                    int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));\n                    vst1q_s16(q6_scales + i * 8, scales);\n                }\n\n                for (int half = 0; half < 2; half++) {\n                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;\n                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;\n\n                    for (int sb = 0; sb < QK_K / 64; sb++) {\n                        int32x4_t acc_lo[acc_size];\n                        int32x4_t acc_hi[acc_size];\n                        for (int i = 0; i < acc_size; i++) {\n                            acc_lo[i] = vdupq_n_s32(0);\n                            acc_hi[i] = vdupq_n_s32(0);\n                        }\n\n                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;\n                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;\n\n                        // 4 rows * 16 elements per scale\n                        // 4 reads of 16 bytes each\n                        constexpr int reads_per_sb = 4;\n                    
    int8x16_t     q8_l[reads_per_sb];\n                        int8x16_t     q8_h[reads_per_sb];\n                        for (int k = 0; k < reads_per_sb; k++) {\n                            q8_l[k] = vld1q_s8(q8_base_l + 16 * k);\n                            q8_h[k] = vld1q_s8(q8_base_h + 16 * k);\n                        }\n\n                        const int ql_off_base = sb * QK_K / 2;\n                        const int qh_off_base = ql_off_base & 255;\n\n                        uint8x16_t q6_ql_0123[reads_per_sb];\n                        uint8x16_t q6_ql_4567[reads_per_sb];\n                        uint8x16_t q6_qh_0123[reads_per_sb];\n                        uint8x16_t q6_qh_4567[reads_per_sb];\n\n                        for (int k = 0; k < reads_per_sb; k++) {\n                            q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);\n                            q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);\n                            q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);\n                            q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);\n                        }\n\n                        if (sb > 1) {\n                            for (int k = 0; k < reads_per_sb; k++) {\n                                q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);\n                                q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);\n                            }\n                        }\n\n                        for (int k = 0; k < reads_per_sb; k++) {\n                            // q = (ql | qh) - 32\n                            const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);\n                            const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);\n                            const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);\n                            const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);\n\n      
                      const int8x16_t q6_0123_lo = vsubq_s8(\n                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);\n                            const int8x16_t q6_0123_hi = vsubq_s8(\n                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);\n\n                            acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0);  //  0..3  r0 c0123\n                            acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1);  //  0..3  r1 c0123\n                            acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2);  //  0..3  r2 c0123\n                            acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3);  //  0..3  r3 c0123\n\n                            acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0);  // 64..67 r0 c0123\n                            acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1);  // 64..67 r1 c0123\n                            acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2);  // 64..67 r2 c0123\n                            acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3);  // 64..67 r3 c0123\n\n                            const int8x16_t q6_4567_lo = vsubq_s8(\n                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);\n                            const int8x16_t q6_4567_hi = vsubq_s8(\n                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);\n\n                            acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0);  //  0..3  r0 c4567\n                            acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1);  //  0..3  r1 c4567\n                            acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2);  //  0..3  r2 c4567\n             
               acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3);  //  0..3  r3 c4567\n\n                            acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0);  // 64..67 r0 c4567\n                            acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1);  // 64..67 r1 c4567\n                            acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2);  // 64..67 r2 c4567\n                            acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3);  // 64..67 r3 c4567\n                        }\n\n                        // Scale and bias\n                        const int scale_idx_l = half * 8 + sb;\n                        const int scale_idx_h = half * 8 + sb + 4;\n\n                        for (int g = 0; g < col_groups; g++) {\n                            const int16x4_t scales_l16  = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);\n                            const int16x4_t scales_h16  = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);\n                            const int32x4_t scale_vec_l = vmovl_s16(scales_l16);\n                            const int32x4_t scale_vec_h = vmovl_s16(scales_h16);\n                            const int       acc_offset  = g * q8_k_blocklen;\n\n                            for (int row = 0; row < q8_k_blocklen; row++) {\n                                const int idx = row * 2 + g;\n                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);\n                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);\n                            }\n                        }\n                    }\n                }\n\n                // Finally we apply the superblock scales\n                for (int row = 0; row < q8_k_blocklen; row++) {\n                    const int       idx0     = 2 * row;\n                    const int       idx1     = 2 * row + 1;\n           
         const int32x4_t acc_0123 = acc_s32[idx0];\n                    const int32x4_t acc_4567 = acc_s32[idx1];\n\n                    acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);\n                    acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);\n                }\n            }  // for b\n\n            for (int i = 0; i < q8_k_blocklen; i++) {\n                int row = y * q8_k_blocklen + i;\n                for (int j = 0; j < 2; j++) {\n                    int col    = x * ncols_interleaved + j * 4;\n                    int offset = row * bs + col;\n                    vst1q_f32(s + offset, acc_f32[2 * i + j]);\n                }\n            }\n        }  // for x\n    }  // for y\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q6_K_8x8_q8_K(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    constexpr int qk = QK_K;\n    const int     nb = n / qk;\n\n    constexpr int ncols_interleaved = 8;\n    constexpr int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    constexpr int    q8_k_blocklen = 4;\n    const uint8x16_t m4b           = vdupq_n_u8(0x0f);\n    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);\n    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);\n 
   const int8x16_t  m32s          = vdupq_n_s8(32);\n\n    // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)\n    float32x4_t acc_f32[blocklen];\n\n    for (int y = 0; y < nr / q8_k_blocklen; y++) {\n        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);\n\n            for (int i = 0; i < blocklen; i++) {\n                acc_f32[i] = vdupq_n_f32(0);\n            }\n\n            for (int b = 0; b < nb; b++) {\n                int32x4_t acc[8];  // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]\n                for (int i = 0; i < 8; i++) {\n                    acc[i] = vdupq_n_s32(0);\n                }\n\n                // Q6_K has simple 8-bit scales, 16 per block (one per 16 values)\n                // Reused for bias and dequantization later\n                int16_t q6_scales[16 * 8];\n                for (int i = 0; i < 16; ++i) {\n                    int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));\n                    vst1q_s16(q6_scales + i * 8, s16);\n                }\n\n                // Process two 128-value halves per superblock\n                for (int half = 0; half < 2; half++) {\n\n                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;\n                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;\n\n                    // A subblock (sb) is a set of weights that share the scale\n                    // Since q6_K scales are per 16 elements\n                    // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)\n                    for (int sb = 0; sb < QK_K / 64; sb++) {\n                        // Q6_K weight index increasing by 64 instead of 32 requires\n                        // loading various q8 memory regions\n                        const int8_t * q8_base_l = 
q8_ptr[b].qs + half * 512 + sb * 64;\n                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;\n\n                        int8x16_t q8_l_01[2];\n                        int8x16_t q8_l_23[2];\n                        for (int i = 0; i < 2; i++) {\n                            const int offset = i * 32;\n                            q8_l_01[i]       = vld1q_s8(q8_base_l + offset);       // 0..7 & 8..15 (r01)\n                            q8_l_23[i]       = vld1q_s8(q8_base_l + offset + 16);  // 0..7 & 8..15 (r23)\n                        }\n\n                        int8x16_t q8_h_01[2];\n                        int8x16_t q8_h_23[2];\n                        for (int i = 0; i < 2; i++) {\n                            const int offset = i * 32;\n                            q8_h_01[i]       = vld1q_s8(q8_base_h + offset);\n                            q8_h_23[i]       = vld1q_s8(q8_base_h + offset + 16);\n                        }\n\n                        const int ql_off_base = sb * QK_K / 2;\n\n                        uint8x16_t q6_ql_0[4];\n                        uint8x16_t q6_ql_1[4];\n                        for (int k = 0; k < 4; k++) {\n                            q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);\n                            q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);\n                        }\n\n                        const int  qh_off_base = (sb * QK_K / 2) & 255;  // wrap after 256 bytes\n                        uint8x16_t q6_qh_0[4];\n                        uint8x16_t q6_qh_1[4];\n                        for (int k = 0; k < 4; k++) {\n                            q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);\n                            q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);\n                        }\n\n                        // Adjust for the proper high bits (Sb 2 and 3)\n                        if (sb > 1) {\n                            for (int k = 
0; k < 4; k++) {\n                                q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);\n                                q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);\n                            }\n                        }\n\n                        // Process column pairs (0-1, 2-3, 4-5, 6-7)\n                        for (int cp = 0; cp < ncols_interleaved / 2; cp++) {\n                            const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];\n                            const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];\n                            const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];\n                            const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];\n\n                            // Extract high 2 bits for upper nibble reconstruction\n                            const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);\n                            const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);\n\n                            // q6 = (low4 | high2<<4) - 32\n                            // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K)\n                            const int8x16_t q6_l0 = vsubq_s8(\n                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),\n                                m32s);\n                            const int8x16_t q6_l1 = vsubq_s8(\n                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),\n                                m32s);\n                            const int8x16_t q6_h0 = vsubq_s8(\n                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);\n                            const int8x16_t q6_h1 = vsubq_s8(\n                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);\n\n                            // row pair 0, base_l\n          
                  int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);\n                            sb_acc_0l           = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);\n                            // row pair 0, base_h\n                            int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);\n                            sb_acc_0h           = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);\n                            // row pair 1, base_l\n                            int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);\n                            sb_acc_1l           = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);\n                            // row pair 1, base_h\n                            int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);\n                            sb_acc_1h           = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);\n\n                            const int scale_idx_l = half * 8 + sb;\n                            const int scale_idx_h = half * 8 + sb + 4;\n\n                            const int32x4_t scale_vec_l = {\n                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],\n                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],\n                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],\n                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],\n                            };\n                            const int32x4_t scale_vec_h = {\n                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],\n                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],\n                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],\n                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],\n                            };\n\n                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);\n                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc_0h, 
scale_vec_h);\n                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);\n                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);\n                        }\n                    }\n                }  // for half\n\n                // Reorder i8mm output to match memory layout\n                for (int i = 0; i < 8; i++) {\n                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));\n                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);\n                }\n                int32x4_t reorder_acc[8] = {\n                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),\n                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),\n                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),\n                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),\n                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),\n                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),\n                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),\n                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),\n                };\n\n                // Apply superblock scale (no mins for q6_K)\n                for (int i = 0; i < q8_k_blocklen; i++) {\n                    for (int j = 0; j < 2; j++) {\n                        float32x4_t       q8_d  = vdupq_n_f32(q8_ptr[b].d[i]);\n                        float32x4_t       q6_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));\n                        const float32x4_t scale = vmulq_f32(q6_d, q8_d);\n\n                        acc_f32[2 * i + j] =\n                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);\n                    }\n                }\n            }  // for b\n\n            // Store results\n            for 
(int i = 0; i < q8_k_blocklen; i++) {\n                int row = y * q8_k_blocklen + i;\n                for (int j = 0; j < 2; j++) {\n                    int col    = x * ncols_interleaved + j * 4;\n                    int offset = row * bs + col;\n                    vst1q_f32(s + offset, acc_f32[2 * i + j]);\n                }\n            }\n        }  // for x\n    }  // for y\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q8_0_4x4_q8_0(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n                             int                        nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen          = 4;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);\n\n            float32x4_t sumf[4];\n            for (int m = 0; m < 4; m++) {\n                sumf[m] = vdupq_n_f32(0);\n            }\n\n            for (int l = 0; l < nb; l++) {\n                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));\n                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const 
float16_t *) b_ptr[l].d));\n\n                int32x4_t sumi_0 = vdupq_n_s32(0);\n                int32x4_t sumi_1 = vdupq_n_s32(0);\n                int32x4_t sumi_2 = vdupq_n_s32(0);\n                int32x4_t sumi_3 = vdupq_n_s32(0);\n\n                for (int k_group = 0; k_group < 8; k_group += 4) {\n                    int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);\n                    int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);\n\n                    for (int k = 0; k < 4; k++) {\n                        sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);\n                        sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);\n                        sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);\n                        sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);\n                    }\n                }\n\n                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));\n                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));\n                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));\n                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));\n            }\n\n            for (int m = 0; m < 4; m++) {\n                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);\n            }\n        }\n    }\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)\n    ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q8_0_4x8_q8_0(int                        n,\n                             float * GGML_RESTRICT      s,\n                             size_t                     bs,\n                             const void * GGML_RESTRICT vx,\n                             const void * GGML_RESTRICT vy,\n                             int                        nr,\n               
              int                        nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;\n\n    for (int y = 0; y < nr; y += 4) {\n        const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;\n\n        for (int x = 0; x < nc; x += ncols_interleaved) {\n            const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;\n            const block_q8_0x4 * a_ptr = a_ptr_base;\n\n            float32x4_t acc_f32[4];\n            for (int i = 0; i < 4; i++) {\n                acc_f32[i] = vdupq_n_f32(0);\n            }\n\n            for (int b = 0; b < nb; b++) {\n                int32x4_t acc[4];\n                for (int i = 0; i < 4; i++) {\n                    acc[i] = vdupq_n_s32(0);\n                }\n\n                // Process 4 chunks of 8 positions each\n                for (int chunk = 0; chunk < 4; chunk++) {\n                    int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);\n                    int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);\n                    int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);\n                    int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);\n\n                    acc[0] = vmmlaq_s32(acc[0], a01, b01);\n                    acc[1] = vmmlaq_s32(acc[1], a01, b23);\n                    acc[2] = vmmlaq_s32(acc[2], a23, b01);\n                    acc[3] = vmmlaq_s32(acc[3], a23, b23);\n                }\n\n                // Reorder outputs from 2×2 tiles to row-major\n                // acc[0] = [r0c0, r0c1, r1c0, r1c1]\n                // acc[1] 
= [r0c2, r0c3, r1c2, r1c3]\n                // acc[2] = [r2c0, r2c1, r3c0, r3c1]\n                // acc[3] = [r2c2, r2c3, r3c2, r3c3]\n                int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));\n                int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));\n                int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));\n                int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));\n\n                // Scales\n                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));\n                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));\n\n                acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));\n                acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));\n                acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));\n                acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));\n\n                a_ptr++;\n                b_ptr++;\n            }\n\n            for (int row = 0; row < 4; row++) {\n                vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);\n            }\n        }\n    }\n    return;\n#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)\n    ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n"
  },
  {
    "path": "src/ggml-cpu/arch/loongarch/quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n#include \"ggml-quants.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"simd-mappings.h\"\n\n#include \"../../quants.h\"\n#include \"../../ggml-cpu-impl.h\"\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <float.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\n#if defined(__loongarch_sx)\n\nstatic __m128i lsx_packs_w(__m128i a, __m128i b) {\n    __m128i tmp, tmp1;\n    tmp = __lsx_vsat_w(a, 15);\n    tmp1 = __lsx_vsat_w(b, 15);\n    return __lsx_vpickev_h(tmp1, tmp);\n}\n\nstatic __m128i lsx_packs_h(__m128i a, __m128i b) {\n    __m128i tmp, tmp1;\n    tmp = __lsx_vsat_h(a, 7);\n    tmp1 = __lsx_vsat_h(b, 7);\n    return __lsx_vpickev_b(tmp1, tmp);\n}\n\nstatic __m128i lsx_packus_h(__m128i a, __m128i b) {\n    __m128i tmp, tmp1;\n    tmp = __lsx_vsat_hu(a, 7);\n    tmp1 = __lsx_vsat_hu(b, 7);\n    return __lsx_vpickev_b(tmp1, tmp);\n}\n\nstatic __m128i lsx_maddubs_h(__m128i a, __m128i b) {\n    __m128i tmp1, tmp2;\n    tmp1 = __lsx_vmulwev_h_b(a, b);\n    tmp2 = __lsx_vmulwod_h_b(a, b);\n    return __lsx_vsadd_h(tmp1, tmp2);\n}\n\nstatic __m128i lsx_madd_h(__m128i a, __m128i b) {\n    __m128i tmp1, tmp2;\n    tmp1 = __lsx_vmulwev_w_h(a, b);\n    tmp2 = __lsx_vmulwod_w_h(a, b);\n    return __lsx_vadd_w(tmp1, tmp2);\n}\n\nstatic __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {\n    v4i32 __ret = {d, c, b, a};\n    return (__m128i)__ret;\n}\n\nstatic __m128i lsx_shuffle_b(__m128i a, __m128i b) {\n    __m128i mask_f, zero, tmp0, tmp2, mask;\n    int f = 0x8f;\n    mask_f = __lsx_vreplgr2vr_b(f);\n    zero = __lsx_vldi(0);\n    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits\n    tmp0 = 
__lsx_vori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive\n    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask\n    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones\n    return __lsx_vshuf_b(a, zero, tmp2);\n}\n\nstatic __m128i lsx_hadd_h(__m128i a, __m128i b) {\n    __m128i tmp1 = __lsx_vpickev_h(b, a);\n    __m128i tmp2 = __lsx_vpickod_h(b, a);\n    return __lsx_vadd_h(tmp1, tmp2);\n}\n\nstatic __m128i lsx_hadd_w(__m128i a, __m128i b) {\n    __m128i tmp1 = __lsx_vpickev_w(b, a);\n    __m128i tmp2 = __lsx_vpickod_w(b, a);\n    return __lsx_vadd_w(tmp1, tmp2);\n}\n\nstatic __m128 lsx_hadd_s(__m128 a, __m128 b) {\n    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);\n    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);\n\n    return __lsx_vfadd_s(tmp1, tmp2);\n}\n\nstatic inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {\n    __m128 res_0 =lsx_hadd_s(a, b);\n    __m128 res_1 =lsx_hadd_s(c, d);\n    __m128 res =lsx_hadd_s(res_0, res_1);\n    res =lsx_hadd_s(res, res);\n    res =lsx_hadd_s(res, res);\n\n    return ((v4f32)res)[0];\n}\n\n// multiply int8_t, add results pairwise twice\nstatic inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {\n    // Get absolute values of x vectors\n    const __m128i ax = __lsx_vsigncov_b(x, x);\n    // Sign the values of the y vectors\n    const __m128i sy = __lsx_vsigncov_b(x, y);\n    // Perform multiplication and create 16-bit values\n    const __m128i dot = lsx_maddubs_h(ax, sy);\n    const __m128i ones = __lsx_vreplgr2vr_h(1);\n    return lsx_madd_h(ones, dot);\n}\n#endif\n\n#if defined(__loongarch_asx)\n\n#ifdef __clang__\n#define VREGS_PREFIX \"$vr\"\n#define XREGS_PREFIX \"$xr\"\n#else // GCC\n#define VREGS_PREFIX \"$f\"\n#define XREGS_PREFIX \"$f\"\n#endif\n#define __ALL_REGS \"0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31\"\n// Convert __m128i to 
__m256i\nstatic inline __m256i ____m256i(__m128i in) {\n    __m256i out = __lasx_xvldi(0);\n    __asm__ volatile (\n        \".irp i,\" __ALL_REGS                \"\\n\\t\"\n        \" .ifc %[out], \" XREGS_PREFIX\"\\\\i    \\n\\t\"\n        \"  .irp j,\" __ALL_REGS              \"\\n\\t\"\n        \"   .ifc %[in], \" VREGS_PREFIX \"\\\\j  \\n\\t\"\n        \"    xvpermi.q $xr\\\\i, $xr\\\\j, 0x20  \\n\\t\"\n        \"   .endif                           \\n\\t\"\n        \"  .endr                             \\n\\t\"\n        \" .endif                             \\n\\t\"\n        \".endr                               \\n\\t\"\n        : [out] \"+f\" (out) : [in] \"f\" (in)\n    );\n    return out;\n}\n// Convert two __m128i to __m256i\nstatic inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {\n    __m256i out;\n    __asm__ volatile (\n        \".irp i,\" __ALL_REGS                \"\\n\\t\"\n        \" .ifc %[hi], \" VREGS_PREFIX \"\\\\i    \\n\\t\"\n        \"  .irp j,\" __ALL_REGS              \"\\n\\t\"\n        \"   .ifc %[lo], \" VREGS_PREFIX \"\\\\j  \\n\\t\"\n        \"    xvpermi.q $xr\\\\i, $xr\\\\j, 0x20  \\n\\t\"\n        \"   .endif                           \\n\\t\"\n        \"  .endr                             \\n\\t\"\n        \" .endif                             \\n\\t\"\n        \".endr                               \\n\\t\"\n        \".ifnc %[out], %[hi]                 \\n\\t\"\n        \".irp i,\" __ALL_REGS                \"\\n\\t\"\n        \" .ifc %[out], \" XREGS_PREFIX \"\\\\i   \\n\\t\"\n        \"  .irp j,\" __ALL_REGS              \"\\n\\t\"\n        \"   .ifc %[hi], \" VREGS_PREFIX \"\\\\j  \\n\\t\"\n        \"    xvori.b $xr\\\\i, $xr\\\\j, 0       \\n\\t\"\n        \"   .endif                           \\n\\t\"\n        \"  .endr                             \\n\\t\"\n        \" .endif                             \\n\\t\"\n        \".endr                               \\n\\t\"\n        \".endif                              
\\n\\t\"\n        : [out] \"=f\" (out), [hi] \"+f\" (inhi)\n        : [lo] \"f\" (inlo)\n    );\n    return out;\n}\n// Convert __m256i low part to __m128i\nstatic inline __m128i lasx_extracti128_lo(__m256i in) {\n    __m128i out;\n    __asm__ volatile (\n        \".ifnc %[out], %[in]                 \\n\\t\"\n        \".irp i,\" __ALL_REGS                \"\\n\\t\"\n        \" .ifc %[out], \" VREGS_PREFIX \"\\\\i   \\n\\t\"\n        \"  .irp j,\" __ALL_REGS              \"\\n\\t\"\n        \"   .ifc %[in], \" XREGS_PREFIX \"\\\\j  \\n\\t\"\n        \"    vori.b $vr\\\\i, $vr\\\\j, 0        \\n\\t\"\n        \"   .endif                           \\n\\t\"\n        \"  .endr                             \\n\\t\"\n        \" .endif                             \\n\\t\"\n        \".endr                               \\n\\t\"\n        \".endif                              \\n\\t\"\n        : [out] \"=f\" (out) : [in] \"f\" (in)\n    );\n    return out;\n}\n// Convert __m256i high part to __m128i\nstatic inline __m128i lasx_extracti128_hi(__m256i in) {\n    __m128i out;\n    __asm__ volatile (\n        \".irp i,\" __ALL_REGS                \"\\n\\t\"\n        \" .ifc %[out], \" VREGS_PREFIX \"\\\\i   \\n\\t\"\n        \"  .irp j,\" __ALL_REGS              \"\\n\\t\"\n        \"   .ifc %[in], \" XREGS_PREFIX \"\\\\j  \\n\\t\"\n        \"    xvpermi.q $xr\\\\i, $xr\\\\j, 0x11  \\n\\t\"\n        \"   .endif                           \\n\\t\"\n        \"  .endr                             \\n\\t\"\n        \" .endif                             \\n\\t\"\n        \".endr                               \\n\\t\"\n        : [out] \"=f\" (out) : [in] \"f\" (in)\n    );\n    return out;\n}\n\nstatic __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {\n    v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};\n    return (__m256i)__ret;\n}\n\nstatic __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {\n    v4i64 __ret = {d, c, b, a};\n    return 
(__m256i)__ret;\n}\n\nstatic __m256i lasx_insertf128( __m128i x, __m128i y) {\n    return lasx_set_q(x, y);\n}\n\nstatic __m256i lasx_shuffle_b(__m256i a, __m256i b) {\n    __m256i mask_f, zero, tmp0, tmp2, mask;\n    int f = 0x8f;\n    mask_f = __lasx_xvreplgr2vr_b(f);\n    zero = __lasx_xvldi(0);\n    tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits\n    tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive\n    mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask\n    tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones\n    return __lasx_xvshuf_b(a, zero, tmp2);\n}\n\nstatic __m256i lasx_extu8_16(__m128i a) {\n    return __lasx_vext2xv_hu_bu(____m256i(a));\n}\n\nstatic __m256i lasx_ext8_16(__m128i a) {\n    return __lasx_vext2xv_h_b(____m256i(a));\n}\n\nstatic __m256i lasx_ext16_32(__m128i a) {\n    return __lasx_vext2xv_w_h(____m256i(a));\n}\n\nstatic __m128i lasx_extracti128( __m256i a, int pos) {\n    __m128i ret;\n    if( pos == 0)\n    {\n       ret = lasx_extracti128_lo(a);\n    } else {\n       ret = lasx_extracti128_hi(a);\n    }\n    return ret;\n}\n\nstatic __m128 lasx_extractf128( __m256 a, int pos) {\n    __m128 ret;\n    if( pos == 0)\n    {\n       ret = (__m128)lasx_extracti128_lo((__m256i)a);\n    } else {\n       ret = (__m128)lasx_extracti128_hi((__m256i)a);\n    }\n    return ret;\n}\n\nstatic __m256i lasx_maddubs_h(__m256i a, __m256i b) {\n    __m256i tmp1, tmp2;\n    tmp1 = __lasx_xvmulwev_h_b(a, b);\n    tmp2 = __lasx_xvmulwod_h_b(a, b);\n    return __lasx_xvsadd_h(tmp1, tmp2);\n}\n\nstatic __m256i lasx_madd_h(__m256i a, __m256i b) {\n    __m256i tmp1, tmp2;\n    tmp1 = __lasx_xvmulwev_w_h(a, b);\n    tmp2 = __lasx_xvmulwod_w_h(a, b);\n    return __lasx_xvadd_w(tmp1, tmp2);\n}\n\nstatic __m256i lasx_packs_w(__m256i a, __m256i b) {\n    __m256i tmp, tmp1;\n    tmp = __lasx_xvsat_w(a, 15);\n    tmp1 = __lasx_xvsat_w(b, 15);\n    return __lasx_xvpickev_h(tmp1, 
tmp);\n}\n\nstatic __m256i lasx_packs_h(__m256i a, __m256i b) {\n    __m256i tmp, tmp1;\n    tmp = __lasx_xvsat_h(a, 7);\n    tmp1 = __lasx_xvsat_h(b, 7);\n    return __lasx_xvpickev_b(tmp1, tmp);\n}\n\nstatic inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {\n    __m256i tmp1, tmp2;\n    tmp1 = __lasx_xvmulwev_h_b(a, b);\n    tmp2 = __lasx_xvmulwod_h_b(a, b);\n    return __lasx_xvadd_h(tmp1, tmp2);\n}\n\nstatic inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {\n    switch (b) {\n        case 0: return __lasx_xvrepl128vei_h(a, 0);\n        case 1: return __lasx_xvrepl128vei_h(a, 1);\n        case 2: return __lasx_xvrepl128vei_h(a, 2);\n        case 3: return __lasx_xvrepl128vei_h(a, 3);\n        case 4: return __lasx_xvrepl128vei_h(a, 4);\n        case 5: return __lasx_xvrepl128vei_h(a, 5);\n        case 6: return __lasx_xvrepl128vei_h(a, 6);\n        case 7: return __lasx_xvrepl128vei_h(a, 7);\n        default: __builtin_unreachable();\n    }\n}\n\nstatic inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {\n    switch (b) {\n        case 0: return __lasx_xvandi_b(a, 1 << 0);\n        case 1: return __lasx_xvandi_b(a, 1 << 1);\n        case 2: return __lasx_xvandi_b(a, 1 << 2);\n        case 3: return __lasx_xvandi_b(a, 1 << 3);\n        case 4: return __lasx_xvandi_b(a, 1 << 4);\n        case 5: return __lasx_xvandi_b(a, 1 << 5);\n        case 6: return __lasx_xvandi_b(a, 1 << 6);\n        case 7: return __lasx_xvandi_b(a, 1 << 7);\n        default: __builtin_unreachable();\n    }\n}\n\n// horizontally add 8 floats\nstatic inline float hsum_float_8(const __m256 x) {\n    __m128 res = lasx_extractf128(x, 1);\n    res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));\n    res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));\n    res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));\n    return ((v4f32)res)[0];\n}\n\n// horizontally add 8 int32_t\nstatic inline 
int hsum_i32_8(const __m256i a) {\n\n    __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);\n    __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);\n\n    __m128i  tmp1_128 = lasx_extracti128_lo(tmp1);\n    __m128i  tmp2_128 = lasx_extracti128_lo(tmp2);\n\n    __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);\n\n    __m128i ev = __lsx_vpickev_w(sum128, sum128);\n    __m128i od = __lsx_vpickod_w(sum128, sum128);\n    __m128i sum64 = __lsx_vadd_w(ev, od);\n\n    int sum64_1, sum64_2;\n    sum64_1 = __lsx_vpickve2gr_w(sum64, 0);\n    sum64_2 = __lsx_vpickve2gr_w(sum64, 1);\n\n    return  sum64_1 + sum64_2;\n}\n\n// horizontally add 4 int32_t\nstatic inline int hsum_i32_4(const __m128i a) {\n    __m128i ev = __lsx_vpickev_w(a, a);\n    __m128i od = __lsx_vpickod_w(a, a);\n    __m128i sum64 = __lsx_vadd_w(ev, od);\n\n    int sum64_1, sum64_2;\n    sum64_1 = __lsx_vpickve2gr_w(sum64, 0);\n    sum64_2 = __lsx_vpickve2gr_w(sum64, 1);\n\n    return  sum64_1 + sum64_2;\n}\n\n// spread 32 bits to 32 bytes { 0x00, 0xFF }\nstatic inline __m256i bytes_from_bits_32(const uint8_t * x) {\n\n    uint32_t x32;\n    memcpy(&x32, x, sizeof(uint32_t));\n    const __m256i shuf_mask = lasx_set_d(\n            0x0303030303030303, 0x0202020202020202,\n            0x0101010101010101, 0x0000000000000000);\n\n    __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);\n    const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);\n    bytes = __lasx_xvor_v(bytes, bit_mask);\n    return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));\n}\n\n// Unpack 32 4-bit fields into 32 bytes\n// The output vector contains 32 bytes, each one in [ 0 .. 
15 ] interval\nstatic inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {\n    const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);\n    __m128i hi = __lsx_vsrli_h(lo, 4);\n    return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);\n}\n\n// add int16_t pairwise and return as float vector\nstatic inline __m256 sum_i16_pairs_float(const __m256i x) {\n    __m256i v = __lasx_xvpackod_h(x, x);\n    __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);\n    return __lasx_xvffint_s_w(summed_pairs);\n}\n\nstatic inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {\n    // Perform multiplication and create 16-bit values\n    const __m256i dot = lasx_maddubs_h(ax, sy);\n    return sum_i16_pairs_float(dot);\n}\n\n// multiply int8_t, add results pairwise twice and return as float vector\nstatic inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {\n    const __m256i dot = lasx_madd_h_b(x, y);\n    return sum_i16_pairs_float(dot);\n}\n\nstatic inline __m128i packNibbles( __m256i bytes ) {\n    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh\n    const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);\n     __m256i high = __lasx_xvandn_v(lowByte, bytes);\n    __m256i low = __lasx_xvand_v(lowByte, bytes);\n    high = __lasx_xvsrli_h(high, 4);\n    bytes = __lasx_xvor_v(low, high);\n    // Compress uint16_t lanes into bytes\n    __m128i *r0 = (__m128i *)&bytes;\n    __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);\n    __m128i *r1 = (__m128i *)&tmp_h128;\n\n    __m128i zero = __lsx_vldi(0);\n    __m128i tmp, tmp2, tmp3;\n\n    tmp = __lsx_vmax_h(zero, *r0);\n    tmp2 = __lsx_vsat_hu(tmp, 7);\n\n    tmp = __lsx_vmax_h(zero, *r1);\n    tmp3 = __lsx_vsat_hu(tmp, 7);\n    return  __lsx_vpickev_b(tmp3, tmp2);\n}\n#endif  //__loongarch_asx\n\nvoid quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int 
nb = k / QK8_0;\n\n    block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__loongarch_asx)\n    for (int i = 0; i < nb; i++) {\n        __m256 v0 = (__m256)__lasx_xvld( x , 0);\n        __m256 v1 = (__m256)__lasx_xvld( x , 32);\n        __m256 v2 = (__m256)__lasx_xvld( x , 64);\n        __m256 v3 = (__m256)__lasx_xvld( x , 96);\n        x += 32;\n\n        // Compute max(abs(e)) for the block\n        const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );\n        __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );\n        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );\n        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );\n        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );\n\n        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) );\n        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );\n        __m128 tmp = max4;\n        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));\n        const float max_scalar = ((v4f32)max4)[0];\n\n        // Quantize these floats\n        const float d = max_scalar / 127.f;\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n        const float id = ( max_scalar != 0.0f ) ? 
127.f / max_scalar : 0.0f;\n        const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id );\n\n        // Apply the multiplier\n        v0 = __lasx_xvfmul_s( v0, mul );\n        v1 = __lasx_xvfmul_s( v1, mul );\n        v2 = __lasx_xvfmul_s( v2, mul );\n        v3 = __lasx_xvfmul_s( v3, mul );\n\n        // Round to nearest integer\n        __m256i i0 = __lasx_xvftintrne_w_s( v0 );\n        __m256i i1 = __lasx_xvftintrne_w_s( v1 );\n        __m256i i2 = __lasx_xvftintrne_w_s( v2 );\n        __m256i i3 = __lasx_xvftintrne_w_s( v3 );\n\n        __m128i ni0 = lasx_extracti128( i0, 0 );\n        __m128i ni1 = lasx_extracti128( i0, 1);\n        __m128i ni2 = lasx_extracti128( i1, 0);\n        __m128i ni3 = lasx_extracti128( i1, 1);\n        __m128i ni4 = lasx_extracti128( i2, 0);\n        __m128i ni5 = lasx_extracti128( i2, 1);\n        __m128i ni6 = lasx_extracti128( i3, 0);\n        __m128i ni7 = lasx_extracti128( i3, 1);\n\n        // Convert int32 to int16\n        ni0 = lsx_packs_w( ni0, ni1 );\n        ni2 = lsx_packs_w( ni2, ni3 );\n        ni4 = lsx_packs_w( ni4, ni5 );\n        ni6 = lsx_packs_w( ni6, ni7 );\n        // Convert int16 to int8\n        ni0 = lsx_packs_h( ni0, ni2 );\n        ni4 = lsx_packs_h( ni4, ni6 );\n\n        __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);\n        __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);\n\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_0_ref(x, y, k);\n#endif\n}\n\nvoid quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK8_1 == 0);\n    const int nb = k / QK8_1;\n\n    block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__loongarch_asx)\n    for (int i = 0; i < nb; i++) {\n        __m256 v0 = (__m256)__lasx_xvld( x , 0 );\n        __m256 v1 = (__m256)__lasx_xvld( x , 32 );\n        __m256 v2 = (__m256)__lasx_xvld( x , 64 );\n        __m256 v3 = (__m256)__lasx_xvld( x , 96 );\n        x += 32;\n\n        // Compute max(abs(e)) for the block\n     
   const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );\n        __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );\n        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );\n        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );\n        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );\n\n        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );\n        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );\n        __m128 tmp = max4;\n        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));\n        const float max_scalar = ((v4f32)max4)[0];\n\n        // Quantize these floats\n        const float d = max_scalar / 127.f;\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n        const float id = ( max_scalar != 0.0f ) ? 
127.f / max_scalar : 0.0f;\n        const __m256 mul = __lasx_xvreplfr2vr_s( id );\n\n        // Apply the multiplier\n        v0 = __lasx_xvfmul_s( v0, mul );\n        v1 = __lasx_xvfmul_s( v1, mul );\n        v2 = __lasx_xvfmul_s( v2, mul );\n        v3 = __lasx_xvfmul_s( v3, mul );\n\n        // Round to nearest integer\n        __m256i i0 = __lasx_xvftintrne_w_s( v0 );\n        __m256i i1 = __lasx_xvftintrne_w_s( v1 );\n        __m256i i2 = __lasx_xvftintrne_w_s( v2 );\n        __m256i i3 = __lasx_xvftintrne_w_s( v3 );\n\n        __m128i ni0 = lasx_extracti128(i0, 0);\n        __m128i ni1 = lasx_extracti128( i0, 1);\n        __m128i ni2 = lasx_extracti128( i1, 0);\n        __m128i ni3 = lasx_extracti128( i1, 1);\n        __m128i ni4 = lasx_extracti128( i2, 0 );\n        __m128i ni5 = lasx_extracti128( i2, 1);\n        __m128i ni6 = lasx_extracti128( i3, 0);\n        __m128i ni7 = lasx_extracti128( i3, 1);\n\n        // Compute the sum of the quants and set y[i].s\n        const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3));\n        const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7));\n        y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1)));\n\n        // Convert int32 to int16\n        ni0 = lsx_packs_w( ni0, ni1 );\n        ni2 = lsx_packs_w( ni2, ni3 );\n        ni4 = lsx_packs_w( ni4, ni5 );\n        ni6 = lsx_packs_w( ni6, ni7 );\n        // Convert int16 to int8\n        ni0 = lsx_packs_h( ni0, ni2 );\n        ni4 = lsx_packs_h( ni4, ni6 );\n\n        __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);\n        __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_1_ref(x, y, k);\n#endif\n}\n\n\n//===================================== Dot products =================================\n\n//\n// Helper functions\n//\n\n#if defined(__loongarch_asx)\n// shuffles to pick the required scales in dot products\nstatic inline __m256i 
get_scale_shuffle_q3k(int i) {\n    static const uint8_t k_shuffle[128] = {\n         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,\n         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,\n         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,\n        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,\n    };\n    return __lasx_xvld((const __m256i*)k_shuffle + i, 0);\n}\nstatic inline __m256i get_scale_shuffle_k4(int i) {\n    static const uint8_t k_shuffle[256] = {\n         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,\n         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,\n         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,\n         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,\n        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,\n        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,\n        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15\n    };\n    return __lasx_xvld((const __m256i*)k_shuffle + i, 0);\n}\nstatic inline __m128i get_scale_shuffle(int i) {\n    static const uint8_t k_shuffle[128] = {\n         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,\n         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,\n         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,\n         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,\n        
10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,\n        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,\n        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15\n    };\n    return __lsx_vld((const __m128i*)k_shuffle + i, 0);\n}\n#endif\n\nvoid ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__loongarch_asx)\n    // Initialize accumulator with zeros\n    __m256 acc = (__m256)__lasx_xvldi(0);\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        /* Compute combined scale for the block */\n        const __m256 d = __lasx_xvreplfr2vr_s( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );\n\n        __m256i qx = bytes_from_nibbles_32(x[ib].qs);\n\n        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. 
+7 ] interval.\n        const __m256i off = __lasx_xvreplgr2vr_b( 8 );\n        qx = __lasx_xvsub_b( qx, off );\n\n        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);\n\n        const __m256 q = mul_sum_i8_pairs_float(qx, qy);\n\n        /* Multiply q with scale and accumulate */\n        acc = __lasx_xvfmadd_s( d, q, acc );\n    }\n\n    sumf = hsum_float_8(acc);\n\n#elif defined(__loongarch_sx)\n    // set constants\n    const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);\n    const __m128i off = __lsx_vreplgr2vr_b(8);\n\n    // Initialize accumulator with zeros\n    __m128 acc_0 = (__m128)__lsx_vldi(0);\n    __m128 acc_1 = (__m128)__lsx_vldi(0);\n    __m128 acc_2 = (__m128)__lsx_vldi(0);\n    __m128 acc_3 = (__m128)__lsx_vldi(0);\n\n    for (; ib + 1 < nb; ib += 2) {\n\n        // Compute combined scale for the block 0 and 1\n        const float ft0 = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);\n        const __m128 d_0_1 = (__m128)(v4f32){ft0, ft0, ft0, ft0};\n\n        const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);\n\n        __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);\n        __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);\n        bx_0 = __lsx_vsub_b(bx_0, off);\n        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);\n\n        __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));\n        __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);\n        bx_1 = __lsx_vsub_b(bx_1, off);\n        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);\n\n        // Compute combined scale for the block 2 and 3\n        const float ft1 = GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d);\n        const __m128 d_2_3 = (__m128)(v4f32){ft1, ft1, ft1, ft1};\n\n        const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);\n\n        __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);\n        __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);\n   
     bx_2 = __lsx_vsub_b(bx_2, off);\n        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);\n\n        __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));\n        __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);\n        bx_3 = __lsx_vsub_b(bx_3, off);\n        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);\n\n        // Convert int32_t to float\n        __m128 p0 = __lsx_vffint_s_w(i32_0);\n        __m128 p1 = __lsx_vffint_s_w(i32_1);\n        __m128 p2 = __lsx_vffint_s_w(i32_2);\n        __m128 p3 = __lsx_vffint_s_w(i32_3);\n\n        // Apply the scale\n        __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 );\n        __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 );\n        __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 );\n        __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 );\n\n        // Acummulate\n        acc_0 = __lsx_vfadd_s(p0_d, acc_0);\n        acc_1 = __lsx_vfadd_s(p1_d, acc_1);\n        acc_2 = __lsx_vfadd_s(p2_d, acc_2);\n        acc_3 = __lsx_vfadd_s(p3_d, acc_3);\n    }\n\n    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);\n\n#endif\n    for (; ib < nb; ++ib) {\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int v0 = (x[ib].qs[j] & 0x0F) - 8;\n            const int v1 = (x[ib].qs[j] >>   4) - 8;\n\n            sumi0 += (v0 * y[ib].qs[j]);\n            sumi1 += (v1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * 
GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__loongarch_asx)\n    // Initialize accumulator with zeros\n    __m256 acc = (__m256)__lasx_xvldi(0);\n\n    float summs = 0;\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);\n        const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);\n\n        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);\n\n        const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );\n        const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );\n\n        // Compute combined scales\n        const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );\n\n        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes\n        const __m256i qx = bytes_from_nibbles_32(x[ib].qs);\n        const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);\n\n        const __m256 xy = mul_sum_us8_pairs_float(qx, qy);\n\n        // Accumulate d0*d1*x*y\n        acc = __lasx_xvfmadd_s( d0d1, xy, acc );\n    }\n\n    sumf = hsum_float_8(acc) + summs;\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__loongarch_asx)\n    // Initialize accumulator with zeros\n    __m256 acc = (__m256)__lasx_xvldi(0);\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        /* Compute combined scale for the block */\n        const 
__m256 d = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); //FIXME\n\n        __m256i qx = bytes_from_nibbles_32(x[ib].qs);\n        __m256i bxhi = bytes_from_bits_32(x[ib].qh);\n        bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0));\n        qx = __lasx_xvor_v(qx, bxhi);\n\n        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);\n\n        const __m256 q = mul_sum_i8_pairs_float(qx, qy);\n\n        /* Multiply q with scale and accumulate */\n        acc = __lasx_xvfmadd_s(d, q, acc);\n    }\n\n    sumf = hsum_float_8(acc);\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(sumf);\n    UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_1);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__loongarch_asx)\n    // Initialize accumulator with zeros\n    __m256 acc = (__m256)__lasx_xvldi(0);\n\n    float summs = 0.0f;\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        const __m256 dx = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d));\n\n        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);\n\n        __m256i qx = bytes_from_nibbles_32(x[ib].qs);\n        __m256i bxhi = bytes_from_bits_32(x[ib].qh);\n        bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));\n        qx = __lasx_xvor_v(qx, bxhi);\n\n        const __m256 dy = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        const __m256i qy = __lasx_xvld((const __m256i 
*)y[ib].qs, 0);\n\n        const __m256 q = mul_sum_us8_pairs_float(qx, qy);\n\n        acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);\n    }\n\n    sumf = hsum_float_8(acc) + summs;\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(sumf);\n    UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q8_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__loongarch_asx)\n    // Initialize accumulator with zeros\n    __m256 acc = (__m256)__lasx_xvldi(0);\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        // Compute combined scale for the block\n        const __m256 d = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));\n        __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);\n        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);\n\n        const __m256 q = mul_sum_i8_pairs_float(qx, qy);\n\n        // Multiply q with scale and accumulate\n        acc = __lasx_xvfmadd_s( d, q, acc );\n    }\n\n    sumf = hsum_float_8(acc);\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(sumf);\n    UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q2_K * 
GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __loongarch_asx\n\n    __m256 acc = (__m256)__lasx_xvldi(0);\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const uint8_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);\n        const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);\n        const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));\n        const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));\n\n        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);\n\n        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};\n        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));\n\n        __m256i sumi = __lasx_xvldi(0);\n\n        for (int j = 0; j < QK_K/128; ++j) {\n\n            const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32;\n\n            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n\n            const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);\n            const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);\n            const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);\n            const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);\n\n            __m256i p0 = lasx_madd_h_b(q2_0, q8_0);\n            __m256i p1 = lasx_madd_h_b(q2_1, q8_1);\n 
           __m256i p2 = lasx_madd_h_b(q2_2, q8_2);\n            __m256i p3 = lasx_madd_h_b(q2_3, q8_3);\n\n            p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);\n            p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);\n            p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);\n            p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);\n\n            p0 = __lasx_xvadd_w(p0, p1);\n            p2 = __lasx_xvadd_w(p2, p3);\n\n            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2));\n        }\n\n        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);\n\n    }\n\n    *s = hsum_float_8(acc);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    const block_q3_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __loongarch_asx\n\n    const __m128i m32 = __lsx_vreplgr2vr_b(32);\n\n    __m256 acc = (__m256)__lasx_xvldi(0);\n\n    uint32_t aux[3];\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        // Set up scales\n        memcpy(aux, x[i].scales, 12);\n        __m128i scales128 = lsx_set_w(\n                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),\n                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),\n    
            (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),\n                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));\n        scales128 = __lsx_vsub_b(scales128, m32);\n\n        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};\n        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));\n\n        // high bit\n        const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);\n\n        // integer accumulator\n        __m256i sumi = __lasx_xvldi(0);\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            // load low 2 bits\n            const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;\n\n            // prepare low and high bits\n            const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);\n            const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);\n            const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);\n            const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);\n            const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);\n            const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);\n            const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);\n            const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);\n            const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);\n            const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);\n            const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);\n            const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);\n\n            // load Q8 quants\n            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_2 = 
__lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n\n            __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);\n            __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);\n            __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);\n            __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);\n\n            // multiply with scales\n            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);\n            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);\n            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);\n            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);\n\n            // accumulate\n            p16_0 = __lasx_xvadd_w(p16_0, p16_1);\n            p16_2 = __lasx_xvadd_w(p16_2, p16_3);\n            sumi  = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));\n        }\n        // multiply with block scale and accumulate\n        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);\n    }\n\n    *s = hsum_float_8(acc);\n\n#else\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined __loongarch_asx\n\n    __m256 acc = (__m256)__lasx_xvldi(0);\n    __m128 
acc_m = (__m128)__lsx_vldi(0);\n\n   for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);\n        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);\n        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);\n\n        const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);\n        const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));\n        const __m128i prod = lsx_madd_h(mins128, q8s);\n        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);\n\n        const __m256i scales = lasx_insertf128(scales128, scales128);\n\n        __m256i sumi = __lasx_xvldi(0);\n\n        for (int j = 0; j < QK_K/64; ++j) {\n\n            const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);\n            const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);\n\n            const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;\n            const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);\n            const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);\n\n            const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            __m256i p16l = lasx_madd_h_b(q4l, q8l);\n            p16l = lasx_madd_h(scale_l, p16l);\n\n            const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            __m256i p16h = 
lasx_madd_h_b(q4h, q8h);\n            p16h = lasx_madd_h(scale_h, p16h);\n            const __m256i sumj = __lasx_xvadd_w(p16l, p16h);\n\n            sumi = __lasx_xvadd_w(sumi, sumj);\n        }\n\n        __m256 vd = __lasx_xvreplfr2vr_s(d);\n        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);\n\n    }\n\n    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));\n    __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);\n    acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);\n\n\n    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined __loongarch_asx\n\n    __m256 acc = (__m256)__lasx_xvldi(0);\n    __m128 acc_m = (__m128)__lsx_vldi(0);\n\n    for (int i = 0; i < nb; ++i) {\n\n        const uint8_t * GGML_RESTRICT q5 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & 
kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);\n        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);\n        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);\n\n        const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);\n        const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));\n        const __m128i prod = lsx_madd_h(mins128, q8s);\n        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);\n\n        const __m256i scales = lasx_insertf128(scales128, scales128);\n\n        const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);\n\n        __m256i sumi = __lasx_xvldi(0);\n\n        for (int j = 0; j < QK_K/64; ++j) {\n\n            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);\n            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);\n\n            const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;\n\n            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);\n            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);\n            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);\n            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);\n            const __m256i q5_0  = __lasx_xvor_v(q5l_0, q5h_0);\n            const __m256i q5_1  = __lasx_xvor_v(q5l_1, q5h_1);\n\n            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n\n            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);\n            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);\n\n            p16_0 = lasx_madd_h(scale_0, p16_0);\n        
    p16_1 = lasx_madd_h(scale_1, p16_1);\n\n            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));\n\n        }\n\n        __m256 vd = __lasx_xvreplfr2vr_s(d);\n        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);\n\n    }\n\n    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));\n    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));\n\n    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q6_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __loongarch_asx\n\n    const __m256i m32s = __lasx_xvreplgr2vr_b(32);\n\n    __m256 acc = (__m256)__lasx_xvldi(0);\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].ql;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);\n        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};\n        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));\n\n        __m256i sumi = __lasx_xvldi(0);\n\n        for (int j = 0; j < QK_K/128; ++j) {\n\n            const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;\n            const __m256i q4bits2 = __lasx_xvld((const 
__m256i*)q4, 0); q4 += 32;\n            const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;\n\n            const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);\n            const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);\n            const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);\n            const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);\n\n            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);\n            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);\n            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);\n            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);\n\n            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n\n            __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);\n            __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);\n            __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);\n            __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);\n\n            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);\n            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);\n            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);\n            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);\n\n            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));\n            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));\n        }\n\n        acc = 
__lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);\n    }\n\n    *s = hsum_float_8(acc);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n#if defined(__loongarch_asx)\nstatic const int8_t keven_signs_q2xs[1024] = {\n     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,\n     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,\n     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,\n     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,\n     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,\n     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,\n     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,\n     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,\n     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,\n     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,\n     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,\n     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,\n     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, 
-1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,\n     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,\n     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,\n     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,\n     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,\n     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,\n     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,\n     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,\n     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,\n     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,\n     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,\n     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,\n     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,\n     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,\n     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,\n     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  
1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,\n     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,\n     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,\n     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,\n     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,\n};\n#endif\n\nvoid ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__loongarch_asx)\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    uint32_t aux32[4];\n    const uint8_t * aux8 = (const uint8_t *)aux32;\n\n    __m256 accumf = (__m256)__lasx_xvldi(0);\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        __m256i sumi1 = __lasx_xvldi(0);\n        __m256i sumi2 = __lasx_xvldi(0);\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;\n\n            const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);\n            const __m256i 
q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);\n            const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],\n                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);\n            const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],\n                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);\n            const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);\n            const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);\n            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);\n            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);\n            const uint16_t ls1 = aux32[1] >> 28;\n            const uint16_t ls2 = aux32[3] >> 28;\n            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));\n            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));\n            sumi1 = __lasx_xvadd_w(sumi1, p1);\n            sumi2 = __lasx_xvadd_w(sumi2, p2);\n        }\n\n        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);\n    }\n\n    *s = 0.125f * hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__loongarch_asx)\n\n    const __m256i mone = 
__lasx_xvreplgr2vr_b(1);\n    static const char block_sign_shuffle_mask_1[32] = {\n        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,\n        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,\n    };\n    static const char block_sign_shuffle_mask_2[32] = {\n        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,\n        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,\n    };\n    static const uint8_t bit_selector_mask_bytes[32] = {\n        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n    const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0);\n    const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0);\n    const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0);\n\n    static const uint8_t k_bit_helper[32] = {\n        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,\n        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,\n    };\n    const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0);\n    const __m256i m511 = __lasx_xvreplgr2vr_h(511);\n    const __m128i m4 = __lsx_vreplgr2vr_b(0xf);\n    const __m128i m1 = __lsx_vreplgr2vr_b(1);\n\n    uint64_t aux64;\n\n    // somewhat hacky, but gives a significant boost in performance\n    __m256i aux_gindex;\n    const uint16_t * gindex = (const uint16_t *)&aux_gindex;\n\n    __m256 accumf = (__m256)__lasx_xvldi(0);\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t 
* GGML_RESTRICT q2 = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n\n        memcpy(&aux64, x[i].scales, 8);\n        __m128i stmp = __lsx_vreplgr2vr_d(aux64);\n        stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4));\n        const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1);\n\n        __m256i sumi1 = __lasx_xvldi(0);\n        __m256i sumi2 = __lasx_xvldi(0);\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {\n\n            const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0);  q2 += 16;\n            aux_gindex = __lasx_xvand_v(q2_data, m511);\n\n            const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9);\n            const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13);\n            const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper);\n\n            const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting);\n            const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits);\n\n            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n\n            const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],\n                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);\n            const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],\n                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);\n            const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],\n                                                   iq2xs_grid[gindex[ 9]], 
iq2xs_grid[gindex[ 8]]);\n            const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],\n                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);\n\n            const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);\n            const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);\n            const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);\n            const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);\n\n            __m256i signs;\n            signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);\n            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);\n            const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1);\n\n            signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2);\n            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);\n            const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2);\n\n            signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1);\n            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);\n            const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3);\n\n            signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2);\n            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);\n            const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4);\n\n            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);\n            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);\n            const __m256i dot3  = lasx_maddubs_h(q2_3, q8s_3);\n            const __m256i dot4  = lasx_maddubs_h(q2_4, q8s_4);\n\n            const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0)));\n            
const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1)));\n            const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2)));\n            const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3)));\n\n            sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1));\n            sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2));\n            sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3));\n            sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4));\n        }\n\n        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);\n\n    }\n\n    *s = 0.125f * hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__loongarch_asx)\n\n   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n   };\n\n    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n\n    const __m128i m4 = __lsx_vreplgr2vr_b(0xf);\n    const __m128i m1 = __lsx_vreplgr2vr_b(1);\n\n    
const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);\n    const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);\n    uint64_t aux64;\n\n    __m256 accumf = (__m256)__lasx_xvldi(0);\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        __m128i tmp1;\n        memcpy(&aux64, x[i].scales, 8);\n        tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0);\n        tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1);\n        const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1);\n        const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15\n\n        __m256i sumi1 = __lasx_xvldi(0);\n        __m256i sumi2 = __lasx_xvldi(0);\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],\n                                                   iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],\n                                                   iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],\n                                                   iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);\n            const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],\n                                                   iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],\n                                                   iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],\n                                                   iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 
0x300)]);\n            qs += 8;\n\n            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));\n            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);\n            const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);\n            const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);\n\n            aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));\n            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);\n            const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);\n            const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);\n\n            signs += 4;\n\n            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1\n            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3\n\n            const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0)));\n            const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1)));\n            sumi1 = __lasx_xvadd_w(sumi1, p1);\n            sumi2 = __lasx_xvadd_w(sumi2, p2);\n        }\n\n        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);\n    }\n\n    *s = 0.125f * hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__loongarch_asx)\n\n    const uint64_t * signs64 = 
(const uint64_t *)keven_signs_q2xs;\n\n    uint32_t aux32[2];\n\n    __m256 accumf = (__m256)__lasx_xvldi(0);\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        __m256i sumi1 = __lasx_xvldi(0);\n        __m256i sumi2 = __lasx_xvldi(0);\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],\n                                                iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);\n            q3 += 8;\n            const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],\n                                                iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);\n            q3 += 8;\n            memcpy(aux32, gas, 8); gas += 8;\n\n            const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],\n                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);\n            const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],\n                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);\n            const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);\n            const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);\n            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);\n            const __m256i dot2  = lasx_maddubs_h(q2_2, 
q8s_2);\n            const uint16_t ls1 = aux32[0] >> 28;\n            const uint16_t ls2 = aux32[1] >> 28;\n\n            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));\n            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));\n            sumi1 = __lasx_xvadd_w(sumi1, p1);\n            sumi2 = __lasx_xvadd_w(sumi2, p2);\n        }\n\n        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);\n    }\n\n    *s = 0.25f * hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__loongarch_asx)\n\n   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n   };\n\n    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n    const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);\n    const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);\n\n    __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8);\n    const __m256i idx_mask  = __lasx_xvreplgr2vr_w(256);\n\n    typedef union {\n        
__m256i  vec[2];\n        uint32_t index[16];\n    } index_t;\n\n    index_t idx;\n\n    __m256 accumf = (__m256)__lasx_xvldi(0);\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        __m256i sumi1 = __lasx_xvldi(0);\n        __m256i sumi2 = __lasx_xvldi(0);\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16;\n            idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]);\n            idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]);\n            idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask);\n            idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask);\n            idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0)));\n            idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1)));\n\n            // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. 
Strange.\n            //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);\n            //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);\n            const __m256i q2_1 = lasx_set_w(\n                    iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],\n                    iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]\n            );\n            const __m256i q2_2 = lasx_set_w(\n                    iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],\n                    iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]\n            );\n\n            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16));\n            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);\n            const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);\n            const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);\n\n            aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16));\n            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);\n            const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);\n            const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);\n\n            signs += 4;\n\n            const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);\n            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);\n            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;\n            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;\n            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));\n            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));\n            sumi1 = __lasx_xvadd_w(sumi1, p1);\n            sumi2 = __lasx_xvadd_w(sumi2, p2);\n        }\n\n        accumf = 
__lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);\n    }\n\n    *s = hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n#if defined(__loongarch_asx)\nstatic inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {\n    const __m256i a = __lasx_xvmulwev_h_b(x, y);\n    const __m256i b = __lasx_xvmulwod_h_b(x, y);\n    return __lasx_xvadd_h(a, b);\n}\n#endif\n\nvoid ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__loongarch_asx)\n\n    __m256 accum = (__m256)__lasx_xvldi(0);\n    float accum1 = 0;\n    for (int i = 0; i < nb; ++i) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint16_t * qh = x[i].qh;\n\n        __m256i sumi = __lasx_xvldi(0);\n        int sumi1 = 0;\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n            __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0);\n            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1);\n            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2);\n            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3);\n\n            __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0);\n            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1);\n            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | 
((qh[ib+1] << 2) & 0x700)], 2);\n            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3);\n\n            qs += 8;\n            const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n            const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;\n\n            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);\n            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);\n            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;\n            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;\n\n            __m256i tmp1, tmp5, tmp6;\n            tmp1 = __lasx_xvreplgr2vr_h(ls1);\n            tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1);\n            tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1);\n            const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6);\n\n            tmp1 = __lasx_xvreplgr2vr_h(ls2);\n            tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1);\n            tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1);\n            const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6);\n\n            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2));\n            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1\n                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? 
-1 : 1) * ls2;\n        }\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum);\n        accum1 += d * sumi1;\n    }\n\n    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK4_NL == 0);\n    static_assert(QK4_NL == QK8_0, \"QK4_NL and QK8_0 must be the same\");\n\n    const block_iq4_nl * GGML_RESTRICT x = vx;\n    const block_q8_0   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK4_NL;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined (__loongarch_asx)\n\n    const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);\n    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);\n    const __m256i mone = __lasx_xvreplgr2vr_h(1);\n\n    __m256 accum1 = (__m256)__lasx_xvldi(0);\n    __m256 accum2 = (__m256)__lasx_xvldi(0);\n    for (; ib + 1 < nb; ib += 2) {\n        const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);\n        const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);\n        const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);\n        const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);\n        const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),\n                                              lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));\n        const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),\n                                           
   lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));\n        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);\n        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);\n        const __m256i p_1 = lasx_madd_h(p16_1, mone);\n        const __m256i p_2 = lasx_madd_h(p16_2, mone);\n        accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),\n                __lasx_xvffint_s_w(p_1), accum1);\n        accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),\n                __lasx_xvffint_s_w(p_2), accum2);\n    }\n\n    sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));\n\n#endif\n    for (; ib < nb; ++ib) {\n        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);\n        int sumi1 = 0, sumi2 = 0;\n        for (int j = 0; j < QK4_NL/2; ++j) {\n            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];\n            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];\n        }\n        sumf += d * (sumi1 + sumi2);\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_K == 0);\n\n    const block_iq4_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__loongarch_asx)\n\n    const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);\n\n    __m256 accum = (__m256)__lasx_xvldi(0);\n\n    for (int ibl = 0; ibl < nb; ++ibl) {\n        const uint8_t * qs = x[ibl].qs;\n        const int8_t  * q8 = y[ibl].qs;\n        uint16_t sh = x[ibl].scales_h;\n        __m256i sumi1 = __lasx_xvldi(0);\n        __m256i sumi2 = __lasx_xvldi(0);\n        for (int 
ib = 0; ib < QK_K/32; ib += 2) {\n            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;\n            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;\n            const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;\n            const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),\n                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));\n            const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),\n                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));\n            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);\n            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);\n            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;\n            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;\n            sh >>= 4;\n            const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));\n            const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));\n            sumi1 = __lasx_xvadd_w(p_1, sumi1);\n            sumi2 = __lasx_xvadd_w(p_2, sumi2);\n        }\n        accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),\n                __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum);\n    }\n\n    *s = hsum_float_8(accum);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n"
  },
  {
    "path": "src/ggml-cpu/arch/powerpc/cpu-feats.cpp",
    "content": "# include \"ggml-backend-impl.h\"\n\n#if defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)\n\n#if defined(__linux__)\n#include <sys/auxv.h>\n#endif\n\n#include <string>\n\nstruct powerpc_features {\n    std::string platform = \"\";\n    int power_version    = -1;\n\n    bool has_vsx         = false;\n\n    powerpc_features() {\n#if defined(__linux__)\n        unsigned long auxval = getauxval(AT_PLATFORM);\n        if (auxval) {\n            platform = std::string(reinterpret_cast<const char*>(auxval));\n            // TBD: Do systems exist that return this in uppercase?\n            if (platform.substr(0, 5) == \"power\") {\n                // Extractt a numeric suffix, if one exists\n                int vpos = -1;\n                for (int i = platform.length() - 1; i >= 0; i--) {\n                    if (std::isdigit(platform[i])) {\n                        vpos = i;\n                    } else {\n                        break;\n                    }\n                }\n                if (vpos > -1) {\n                    power_version = std::stoi(platform.substr(vpos));\n                }\n            }\n        }\n#endif\n        if (power_version >= 9) {\n            has_vsx = true;\n        }\n    }\n};\n\nstatic int ggml_backend_cpu_powerpc_score() {\n    int score = 1;\n    powerpc_features pf;\n\n// Platform scores\n#if defined(GGML_USE_POWER7)\n    if (pf.power_version < 7) { return 0; }\n    score += 1<<1;\n#endif\n#if defined(GGML_USE_POWER8)\n    if (pf.power_version < 8) { return 0; }\n    score += 1<<2;\n#endif\n#if defined(GGML_USE_POWER9)\n    if (pf.power_version < 9) { return 0; }\n    score += 1<<3;\n#endif\n#if defined(GGML_USE_POWER10)\n    if (pf.power_version < 10) { return 0; }\n    score += 1<<4;\n#endif\n#if defined(GGML_USE_POWER11)\n    if (pf.power_version < 11) { return 0; }\n    score += 1<<5;\n#endif\n\n// Feature scores\n#if defined(GGML_USE_VSX)\n    if (!pf.has_vsx) { return 0; }\n    score += 
1<<6;\n#endif\n\n    return score;\n}\n\nGGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_powerpc_score)\n\n#endif // defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)\n"
  },
  {
    "path": "src/ggml-cpu/arch/powerpc/quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n#include \"ggml-quants.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"simd-mappings.h\"\n\n#include \"../../quants.h\"\n#include \"../../ggml-cpu-impl.h\"\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <float.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\n#if defined(__POWER9_VECTOR__)\n#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s\n#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)\n#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)\n#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)\n#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)\n#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)\n#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)\n#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)\n\n// precomputed tables for expanding 8bits to 8 bytes:\nstatic const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4\nstatic const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4\n#endif\n\nvoid quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__POWER9_VECTOR__)\n    for (int i = 0; i < nb; i++) {\n        vector float srcv [8];\n        vector float asrcv[8];\n        vector float amaxv[8];\n        vector signed int vi[8];\n\n        for (int j = 0; j < 8; j++) srcv[j]  = vec_xl(0, x + i*32 + 4*j);\n        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);\n\n        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);\n        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);\n      
  for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);\n\n        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),\n                                   vec_extract(amaxv[0], 1)),\n                               MAX(vec_extract(amaxv[0], 2),\n                                   vec_extract(amaxv[0], 3)));\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 1.0f/d : 0.0f;\n        const vector float vid = vec_splats(id);\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        for (int j = 0; j < 8; j++) {\n            const vector float v  = vec_round(vec_mul(srcv[j], vid));\n            vi[j] = vec_cts(v, 0);\n        }\n        vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])),  0, &y[i].qs[0]);\n        vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_0_ref(x, y, k);\n#endif\n}\n\nvoid quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK8_1 == 0);\n    const int nb = k / QK8_1;\n\n    block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__POWER9_VECTOR__)\n    for (int i = 0; i < nb; i++) {\n        vector float srcv [8];\n        vector float asrcv[8];\n        vector float amaxv[8];\n        vector signed int vi[8];\n\n        for (int j = 0; j < 8; j++) srcv[j]  = vec_xl(0, x + i*32 + 4*j);\n        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);\n\n        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);\n        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);\n        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);\n\n        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),\n                                   vec_extract(amaxv[0], 1)),\n                               MAX(vec_extract(amaxv[0], 2),\n                                   
vec_extract(amaxv[0], 3)));\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 1.0f/d : 0.0f;\n        const vector float vid = vec_splats(id);\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        vector int accv = vec_splats(0);\n\n        for (int j = 0; j < 8; j++) {\n            const vector float v  = vec_round(vec_mul(srcv[j], vid));\n            vi[j] = vec_cts(v, 0);\n\n            accv = vec_add(accv, vi[j]);\n        }\n        vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])),  0, &y[i].qs[0]);\n        vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);\n\n        accv = vec_add(accv, vec_sld(accv, accv, 4));\n        accv = vec_add(accv, vec_sld(accv, accv, 8));\n        y[i].s = GGML_CPU_FP32_TO_FP16(d * vec_extract(accv, 0));\n    }\n\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_1_ref(x, y, k);\n#endif\n}\n\n\n//===================================== Dot products =================================\n\nvoid ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector signed int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n    const vector signed char v8 = vec_splats((signed char)0x8);\n\n    vector float vsumf0 = vec_splats(0.0f);\n\n#pragma GCC unroll 8\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n        
vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));\n        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);\n        vector signed char q8y0 = vec_xl( 0, y[ib].qs);\n        vector signed char q8y1 = vec_xl(16, y[ib].qs);\n\n        vector signed char q4x0 = vec_and(qxs, lowMask);\n        vector signed char q4x1 = vec_sr(qxs, v4);\n\n        q4x0 = vec_sub(q4x0, v8);\n        q4x1 = vec_sub(q4x1, v8);\n\n        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));\n        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));\n\n        vector signed int vsumi0 = v0;\n\n        vsumi0 = vec_sum4s(qv0, vsumi0);\n        vsumi0 = vec_sum4s(qv1, vsumi0);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n    }\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    sumf = vec_extract(vsumf0, 0);\n\n    *s = sumf;\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector signed int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n\n    vector float vsumf0 = 
vec_splats(0.0f);\n\n#pragma GCC unroll 4\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));\n        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].m));\n        vector float vys = {GGML_CPU_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f};\n        vsumf0 = vec_madd(vxmin, vys, vsumf0);\n\n        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);\n        vector signed char q8y0 = vec_xl( 0, y[ib].qs);\n        vector signed char q8y1 = vec_xl(16, y[ib].qs);\n\n        vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask);\n        vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4);\n\n        vector signed int vsumi0 = v0;\n\n        vsumi0 = vec_msum(q8y0, q4x0, vsumi0);\n        vsumi0 = vec_msum(q8y1, q4x1, vsumi0);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n    }\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    sumf = vec_extract(vsumf0, 0);\n\n    *s = sumf;\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_MXFP4 == 0);\n    static_assert(QK_MXFP4 == QK8_0, \"QK_MXFP4 and QK8_0 must be the same\");\n\n    const block_mxfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_MXFP4;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if 
defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector unsigned char vshift4 = vec_splats((unsigned char)4);\n    vector float vsumf0 = vec_splats(0.0f);\n\n    vector signed char kv = vec_xl(0, (const signed char *)kvalues_mxfp4);\n\n#pragma GCC unroll 8\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) *\n                                      GGML_E8M0_TO_FP32_HALF(x[ib].e));\n\n        vector signed char q8y0 = vec_xl( 0, y[ib].qs);\n        vector signed char q8y1 = vec_xl(16, y[ib].qs);\n\n        vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs);\n\n        vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask);\n        vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4);\n\n        vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles);\n        vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles);\n\n        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));\n        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));\n\n        vector signed int vsumi0 = vec_splats((int32_t)0);\n        vsumi0 = vec_sum4s(qv0, vsumi0);\n        vsumi0 = vec_sum4s(qv1, vsumi0);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0);\n    }\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n    sumf = vec_extract(vsumf0, 0);\n    *s = sumf;\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = 
QK8_0;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector unsigned char v4 = vec_splats((unsigned char)4);\n\n    vector float vsumf0 = vec_splats(0.0f);\n\n#pragma GCC unroll 4\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));\n        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])};\n        vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])};\n\n        vector signed char qh0 = (vector signed char)aux64x2_0;\n        vector signed char qh1 = (vector signed char)aux64x2_1;\n\n        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);\n\n        vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0);\n        vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1);\n\n        vector signed char q8y0 = vec_xl(  0, y[ib].qs);\n        vector signed char q8y1 = vec_xl( 16, y[ib].qs);\n\n        vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0));\n        vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1));\n\n        qv0 = vec_add(qv0, qv1);\n\n        vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n    }\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, 
vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    sumf = vec_extract(vsumf0, 0);\n\n    *s = sumf;\n#else\n    UNUSED(ib);\n    UNUSED(sumf);\n    UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_1);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector signed int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n\n    vector float vsumf0 = vec_splats(0.0f);\n\n#pragma GCC unroll 4\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));\n        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].m));\n        vector float vys = {GGML_CPU_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f};\n        vsumf0 = vec_madd(vxmin, vys, vsumf0);\n\n        vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])};\n        vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])};\n\n        vector signed char qh0 = (vector signed char)aux64x2_0;\n        vector signed char qh1 = (vector signed char)aux64x2_1;\n\n        
vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);\n\n        vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0);\n        vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1);\n\n        vector signed char q8y0 = vec_xl(  0, y[ib].qs);\n        vector signed char q8y1 = vec_xl( 16, y[ib].qs);\n\n        vector signed int vsumi0 = v0;\n\n        vsumi0 = vec_msum(q8y0, q5x0, vsumi0);\n        vsumi0 = vec_msum(q8y1, q5x1, vsumi0);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n    }\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    sumf = vec_extract(vsumf0, 0);\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(sumf);\n    UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q8_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed int v0 = vec_splats((int32_t)0);\n    vector float vsumf0 = vec_splats(0.0f);\n\n#pragma GCC unroll 8\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));\n        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed char q8x0 = vec_xl( 0, x[ib].qs);\n        vector signed char q8x1 = vec_xl(16, x[ib].qs);\n      
  vector signed char q8y0 = vec_xl( 0, y[ib].qs);\n        vector signed char q8y1 = vec_xl(16, y[ib].qs);\n\n        vector signed short qv0 = vec_mule(q8x0, q8y0);\n        vector signed short qv1 = vec_mulo(q8x0, q8y0);\n        vector signed short qv2 = vec_mule(q8x1, q8y1);\n        vector signed short qv3 = vec_mulo(q8x1, q8y1);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n\n        vsumi0 = vec_sum4s(qv0, vsumi0);\n        vsumi1 = vec_sum4s(qv1, vsumi1);\n        vsumi0 = vec_sum4s(qv2, vsumi0);\n        vsumi1 = vec_sum4s(qv3, vsumi1);\n\n        vsumi0 = vec_add(vsumi0, vsumi1);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n    }\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    sumf = vec_extract(vsumf0, 0);\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q2_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0x3);\n    const vector signed char lowScaleMask = vec_splats((signed char)0xF);\n    const vector int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v2 = vec_splats((unsigned char)0x2);\n    const vector unsigned char v6 = vec_splats((unsigned char)0x6);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = 
vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin));\n        vector float vdmin = vec_mul(vxmin, vyd);\n\n        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);\n        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);\n\n        vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales);\n        vector signed char vscales = vec_and(q2xmins, lowScaleMask);\n\n        q2xmins = vec_sr(q2xmins, v4);\n        vector signed short q2xmins0 = vec_unpackh(q2xmins);\n        vector signed short q2xmins1 = vec_unpackl(q2xmins);\n\n        vector signed int prod0 = vec_mule(q2xmins0, q8ysums0);\n        vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0);\n        vector signed int prod2 = vec_mule(q2xmins1, q8ysums1);\n        vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1);\n\n        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);\n        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);\n        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);\n        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n        vector signed int vsumi4 = v0;\n        vector signed int vsumi5 = v0;\n        vector signed int vsumi6 = v0;\n        vector signed int vsumi7 = v0;\n\n        const uint8_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            __builtin_prefetch(q2, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            vector signed char qxs0 = (vector signed char)vec_xl( 0, 
q2);\n            vector signed char qxs1 = (vector signed char)vec_xl(16, q2);\n            q2 += 32;\n\n            vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask);\n            vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask);\n            vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask);\n            vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask);\n            vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask);\n            vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask);\n            vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask);\n            vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask);\n\n            vector signed char q8y00 = vec_xl(  0, q8);\n            vector signed char q8y10 = vec_xl( 16, q8);\n            vector signed char q8y01 = vec_xl( 32, q8);\n            vector signed char q8y11 = vec_xl( 48, q8);\n            vector signed char q8y02 = vec_xl( 64, q8);\n            vector signed char q8y12 = vec_xl( 80, q8);\n            vector signed char q8y03 = vec_xl( 96, q8);\n            vector signed char q8y13 = vec_xl(112, q8);\n            q8 += 128;\n\n            vector signed int qv0 = vec_msum(q8y00, q2x00, v0);\n            vector signed int qv1 = vec_msum(q8y01, q2x01, v0);\n            vector signed int qv2 = vec_msum(q8y02, q2x02, v0);\n            vector signed int qv3 = vec_msum(q8y03, q2x03, v0);\n            vector signed int qv4 = vec_msum(q8y10, q2x10, v0);\n            vector signed int qv5 = vec_msum(q8y11, q2x11, v0);\n            vector signed int qv6 = vec_msum(q8y12, q2x12, v0);\n            vector signed int qv7 = vec_msum(q8y13, q2x13, v0);\n\n            vector signed short vscales_07 = vec_unpackh(vscales);\n            vector signed int vscales_03 = 
vec_unpackh(vscales_07);\n            vector signed int vscales_47 = vec_unpackl(vscales_07);\n            vector signed int vs0 = vec_splat(vscales_03, 0);\n            vector signed int vs1 = vec_splat(vscales_03, 1);\n            vector signed int vs2 = vec_splat(vscales_03, 2);\n            vector signed int vs3 = vec_splat(vscales_03, 3);\n            vector signed int vs4 = vec_splat(vscales_47, 0);\n            vector signed int vs5 = vec_splat(vscales_47, 1);\n            vector signed int vs6 = vec_splat(vscales_47, 2);\n            vector signed int vs7 = vec_splat(vscales_47, 3);\n            vscales = vec_sld(vscales, vscales, 8);\n\n            vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0);\n            vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1);\n            vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2);\n            vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3);\n            vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4);\n            vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5);\n            vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6);\n            vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7);\n        }\n\n        vsumi0 = vec_add(vsumi0, vsumi4);\n        vsumi1 = vec_add(vsumi1, vsumi5);\n        vsumi2 = vec_add(vsumi2, vsumi6);\n        vsumi3 = vec_add(vsumi3, vsumi7);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q3_K_q8_K(int 
n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    const block_q3_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0x3);\n    const vector signed char lowMask1 = vec_splats((int8_t)0xf);\n    const vector signed char lowMask2 = vec_splats((int8_t)0x30);\n    const vector int v0 = vec_splats((int32_t)0);\n    const vector signed char v1 = vec_splats((signed char)0x1);\n    const vector unsigned char v2 = vec_splats((unsigned char)0x2);\n    const vector unsigned char v3 = vec_splats((unsigned char)0x3);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n    const vector unsigned char v6 = vec_splats((unsigned char)0x6);\n    const vector signed char off = vec_splats((signed char)0x20);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        UNUSED(kmask1);\n        UNUSED(kmask2);\n\n        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);\n        vector signed char u1 = vec_and(u0, lowMask1);\n        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);\n        vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2));\n        vector signed char u30 = vec_sl(vec_and(u3, 
lowMask), v4);\n        vector signed char u31 = vec_and(u3, lowMask2);\n\n        u1 = vec_or(u1, u30);\n        u2 = vec_or(vec_sr(u0, v4), u31);\n\n        vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2);\n        vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask);\n        vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask);\n\n        vscales = vec_sub(vscales, off);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n        vector signed int vsumi4 = v0;\n        vector signed int vsumi5 = v0;\n        vector signed int vsumi6 = v0;\n        vector signed int vsumi7 = v0;\n\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            __builtin_prefetch(q3, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            vector signed char qxs0 = (vector signed char)vec_xl( 0, q3);\n            vector signed char qxs1 = (vector signed char)vec_xl(16, q3);\n            q3 += 32;\n\n            //the low 2 bits\n            vector signed char qxs00 = vec_and(qxs0, lowMask);\n            vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask);\n            vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask);\n            vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask);\n            vector signed char qxs10 = vec_and(qxs1, lowMask);\n            vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask);\n            vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask);\n            vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask);\n\n            //the 3rd bit\n            vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2);\n            vector signed char qxh01 = vec_sl(vec_andc(v1, 
vec_sr(qxhs0, (vector unsigned char)v1)), v2);\n            vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2);\n            vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2);\n            vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2);\n            vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2);\n            vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2);\n            vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2);\n            qxhs0 = vec_sr(qxhs0, v4);\n            qxhs1 = vec_sr(qxhs1, v4);\n\n            vector signed char q3x00 = vec_sub(qxs00, qxh00);\n            vector signed char q3x01 = vec_sub(qxs01, qxh01);\n            vector signed char q3x02 = vec_sub(qxs02, qxh02);\n            vector signed char q3x03 = vec_sub(qxs03, qxh03);\n            vector signed char q3x10 = vec_sub(qxs10, qxh10);\n            vector signed char q3x11 = vec_sub(qxs11, qxh11);\n            vector signed char q3x12 = vec_sub(qxs12, qxh12);\n            vector signed char q3x13 = vec_sub(qxs13, qxh13);\n\n            vector signed char q8y00 = vec_xl(  0, q8);\n            vector signed char q8y10 = vec_xl( 16, q8);\n            vector signed char q8y01 = vec_xl( 32, q8);\n            vector signed char q8y11 = vec_xl( 48, q8);\n            vector signed char q8y02 = vec_xl( 64, q8);\n            vector signed char q8y12 = vec_xl( 80, q8);\n            vector signed char q8y03 = vec_xl( 96, q8);\n            vector signed char q8y13 = vec_xl(112, q8);\n            q8 += 128;\n\n            vector signed short vscales_h = vec_unpackh(vscales);\n            vector signed short vs0 = vec_splat(vscales_h, 0);\n            vector signed short vs1 = vec_splat(vscales_h, 1);\n            vector signed short vs2 = vec_splat(vscales_h, 2);\n            vector signed short vs3 = vec_splat(vscales_h, 3);\n            vector signed short vs4 = 
vec_splat(vscales_h, 4);\n            vector signed short vs5 = vec_splat(vscales_h, 5);\n            vector signed short vs6 = vec_splat(vscales_h, 6);\n            vector signed short vs7 = vec_splat(vscales_h, 7);\n            vscales = vec_sld(vscales, vscales, 8);\n\n            vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00));\n            vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01));\n            vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02));\n            vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03));\n            vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10));\n            vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11));\n            vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12));\n            vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13));\n\n            vsumi0 = vec_msum(qv00, vs0, vsumi0);\n            vsumi1 = vec_msum(qv01, vs2, vsumi1);\n            vsumi2 = vec_msum(qv02, vs4, vsumi2);\n            vsumi3 = vec_msum(qv03, vs6, vsumi3);\n            vsumi4 = vec_msum(qv10, vs1, vsumi4);\n            vsumi5 = vec_msum(qv11, vs3, vsumi5);\n            vsumi6 = vec_msum(qv12, vs5, vsumi6);\n            vsumi7 = vec_msum(qv13, vs7, vsumi7);\n        }\n\n        vsumi0 = vec_add(vsumi0, vsumi4);\n        vsumi1 = vec_add(vsumi1, vsumi5);\n        vsumi2 = vec_add(vsumi2, vsumi6);\n        vsumi3 = vec_add(vsumi3, vsumi7);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = 
vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector signed char lowMask1 = vec_splats((int8_t)0x3f);\n    const vector signed char lowMask2 = vec_splats((int8_t)0x30);\n    const vector int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v2 = vec_splats((uint8_t)2);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin));\n        vector float vdmin = vec_mul(vxmin, vyd);\n\n        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);\n        vector signed short q8ysums1 = vec_xl(16, 
y[i].bsums);\n\n        UNUSED(kmask1);\n        UNUSED(kmask2);\n        UNUSED(kmask3);\n        UNUSED(utmp);\n\n        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);\n        vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);\n        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);\n        vector signed char u3 = vec_sr(u2, v4);\n\n        vector signed char u30 = u1;\n        vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);\n\n        u1 = vec_and(u0, lowMask1);\n        u2 = vec_or(u30, u31);\n\n        vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);\n\n        vector signed short vscales = vec_unpackh(utmps);\n        vector signed short q4xmins = vec_unpackl(utmps);\n        vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins);\n        vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins);\n\n        vector signed int prod0 = vec_mule(q4xmins0, q8ysums0);\n        vector signed int prod1 = vec_mule(q4xmins1, q8ysums1);\n        vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0);\n        vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1);\n\n        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);\n        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);\n        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);\n        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        for (int j = 0; j < QK_K/64; j+=2) {\n            __builtin_prefetch(q4, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            vector signed char qxs0 = (vector signed 
char)vec_xl( 0, q4);\n            vector signed char qxs1 = (vector signed char)vec_xl(16, q4);\n            vector signed char qxs2 = (vector signed char)vec_xl(32, q4);\n            vector signed char qxs3 = (vector signed char)vec_xl(48, q4);\n            q4 += 64;\n\n            vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask);\n            vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4);\n            vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask);\n            vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4);\n            vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask);\n            vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4);\n            vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask);\n            vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4);\n\n            vector signed char q8y00 = vec_xl(  0, q8);\n            vector signed char q8y10 = vec_xl( 16, q8);\n            vector signed char q8y01 = vec_xl( 32, q8);\n            vector signed char q8y11 = vec_xl( 48, q8);\n            vector signed char q8y20 = vec_xl( 64, q8);\n            vector signed char q8y30 = vec_xl( 80, q8);\n            vector signed char q8y21 = vec_xl( 96, q8);\n            vector signed char q8y31 = vec_xl(112, q8);\n            q8 += 128;\n\n            vector signed int qv00 = vec_msum(q8y00, q4x00, v0);\n            vector signed int qv01 = vec_msum(q8y01, q4x01, v0);\n            vector signed int qv10 = vec_msum(q8y10, q4x10, v0);\n            vector signed int qv11 = vec_msum(q8y11, q4x11, v0);\n            vector signed int qv20 = vec_msum(q8y20, q4x20, v0);\n            vector signed int qv21 = vec_msum(q8y21, q4x21, v0);\n            vector signed int qv30 = vec_msum(q8y30, q4x30, v0);\n            vector signed int qv31 = vec_msum(q8y31, q4x31, v0);\n\n            vector signed int 
vscales_h = vec_unpackh(vscales);\n            vector signed int vs0 = vec_splat(vscales_h, 0);\n            vector signed int vs1 = vec_splat(vscales_h, 1);\n            vector signed int vs2 = vec_splat(vscales_h, 2);\n            vector signed int vs3 = vec_splat(vscales_h, 3);\n            vscales = vec_sld(vscales, vscales, 8);\n\n            vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);\n            vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1);\n            vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2);\n            vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3);\n\n            vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0);\n            vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1);\n            vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2);\n            vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3);\n        }\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static 
const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector signed char lowMask1 = vec_splats((int8_t)0x3f);\n    const vector signed char lowMask2 = vec_splats((int8_t)0x30);\n    const vector int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v1 = vec_splats((unsigned char)0x1);\n    const vector unsigned char v2 = vec_splats((unsigned char)0x2);\n    const vector unsigned char v3 = vec_splats((unsigned char)0x3);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin));\n        vector float vdmin = vec_mul(vxmin, vyd);\n\n        UNUSED(kmask1);\n        UNUSED(kmask2);\n        UNUSED(kmask3);\n        UNUSED(utmp);\n\n        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);\n        vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);\n        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);\n        vector signed char u3 = vec_sr(u2, v4);\n\n        vector signed char u30 = u1;\n        vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);\n\n        u1 = vec_and(u0, lowMask1);\n        u2 = vec_or(u30, u31);\n\n        vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);\n\n        vector 
signed short q8ysums0 = vec_xl( 0, y[i].bsums);\n        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);\n\n        vector signed short vscales = vec_unpackh(utmps);\n\n        vector signed short q5xmins = vec_unpackl(utmps);\n        vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins);\n        vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins);\n\n        vector signed int prod0 = vec_mule(q5xmins0, q8ysums0);\n        vector signed int prod1 = vec_mule(q5xmins1, q8ysums1);\n        vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0);\n        vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1);\n\n        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);\n        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);\n        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);\n        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);\n\n        vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);\n        vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n\n        const uint8_t * GGML_RESTRICT q5 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        for (int j = 0; j < QK_K/64; ++j) {\n            __builtin_prefetch(q5, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            vector signed char qxs0 = (vector signed char)vec_xl( 0, q5);\n            vector signed char qxs1 = (vector signed char)vec_xl(16, q5);\n            q5 += 32;\n\n            vector signed char qxs00 = vec_and(qxs0, lowMask);\n            vector signed char qxs01 = vec_sr(qxs0, v4);\n            vector signed char qxs10 = vec_and(qxs1, lowMask);\n            vector signed char qxs11 = vec_sr(qxs1, v4);\n\n            vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4);\n            vector signed char 
q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3);\n            vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4);\n            vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3);\n            qxhs0 = vec_sr(qxhs0, v2);\n            qxhs1 = vec_sr(qxhs1, v2);\n\n            vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00);\n            vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01);\n            vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10);\n            vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11);\n\n            vector signed char q8y00 = vec_xl( 0, q8);\n            vector signed char q8y10 = vec_xl(16, q8);\n            vector signed char q8y01 = vec_xl(32, q8);\n            vector signed char q8y11 = vec_xl(48, q8);\n            q8 += 64;\n\n            vector signed int qv00 = vec_msum(q8y00, q5x00, v0);\n            vector signed int qv01 = vec_msum(q8y01, q5x01, v0);\n            vector signed int qv10 = vec_msum(q8y10, q5x10, v0);\n            vector signed int qv11 = vec_msum(q8y11, q5x11, v0);\n\n            vector signed int vscales_h = vec_unpackh(vscales);\n            vector signed int vs0 = vec_splat(vscales_h, 0);\n            vector signed int vs1 = vec_splat(vscales_h, 1);\n            vscales = vec_sld(vscales, vscales, 12);\n\n            vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);\n            vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1);\n            vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2);\n            vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3);\n        }\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, 
vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q6_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v2 = vec_splats((unsigned char)0x2);\n    const vector unsigned char v3 = vec_splats((unsigned char)0x3);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n    const vector unsigned char v6 = vec_splats((unsigned char)0x6);\n    const vector signed char off = vec_splats((signed char)0x20);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n        vector signed int vsumi4 = v0;\n        vector signed int vsumi5 = v0;\n        vector signed int vsumi6 
= v0;\n        vector signed int vsumi7 = v0;\n\n        const uint8_t * GGML_RESTRICT q6 = x[i].ql;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT qs = x[i].scales;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            __builtin_prefetch(q6, 0, 0);\n            __builtin_prefetch(qh, 0, 0);\n            __builtin_prefetch(q8, 0, 0);\n\n            vector signed char qxs0 = (vector signed char)vec_xl( 0, q6);\n            vector signed char qxs1 = (vector signed char)vec_xl(16, q6);\n            vector signed char qxs2 = (vector signed char)vec_xl(32, q6);\n            vector signed char qxs3 = (vector signed char)vec_xl(48, q6);\n            q6 += 64;\n\n            vector signed char qxs00 = vec_and(qxs0, lowMask);\n            vector signed char qxs01 = vec_sr(qxs0, v4);\n            vector signed char qxs10 = vec_and(qxs1, lowMask);\n            vector signed char qxs11 = vec_sr(qxs1, v4);\n            vector signed char qxs20 = vec_and(qxs2, lowMask);\n            vector signed char qxs21 = vec_sr(qxs2, v4);\n            vector signed char qxs30 = vec_and(qxs3, lowMask);\n            vector signed char qxs31 = vec_sr(qxs3, v4);\n\n            vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh);\n            vector signed char qxhs1 = (vector signed char)vec_xl(16, qh);\n            qh += 32;\n\n            vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4);\n            vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4);\n            vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4);\n            vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4);\n            vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4);\n            vector signed char qxh21 = vec_sl(vec_and((vector 
signed char)v3, vec_sr(qxhs0, v6)), v4);\n            vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4);\n            vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4);\n\n            vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off);\n            vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off);\n            vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off);\n            vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off);\n            vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off);\n            vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off);\n            vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off);\n            vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off);\n\n            vector signed char q8y00 = vec_xl(  0, q8);\n            vector signed char q8y10 = vec_xl( 16, q8);\n            vector signed char q8y20 = vec_xl( 32, q8);\n            vector signed char q8y30 = vec_xl( 48, q8);\n            vector signed char q8y01 = vec_xl( 64, q8);\n            vector signed char q8y11 = vec_xl( 80, q8);\n            vector signed char q8y21 = vec_xl( 96, q8);\n            vector signed char q8y31 = vec_xl(112, q8);\n            q8 += 128;\n\n            vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00));\n            vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10));\n            vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20));\n            vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30));\n            vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01));\n            vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11));\n            vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), 
vec_mulo(q6x21, q8y21));\n            vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31));\n\n            vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8));\n            qs += 8;\n\n            vector signed short vs0 = vec_splat(vscales, 0);\n            vector signed short vs1 = vec_splat(vscales, 1);\n            vector signed short vs2 = vec_splat(vscales, 2);\n            vector signed short vs3 = vec_splat(vscales, 3);\n            vector signed short vs4 = vec_splat(vscales, 4);\n            vector signed short vs5 = vec_splat(vscales, 5);\n            vector signed short vs6 = vec_splat(vscales, 6);\n            vector signed short vs7 = vec_splat(vscales, 7);\n\n            vsumi0 = vec_msum(qv00, vs0, vsumi0);\n            vsumi1 = vec_msum(qv01, vs4, vsumi1);\n            vsumi2 = vec_msum(qv10, vs1, vsumi2);\n            vsumi3 = vec_msum(qv11, vs5, vsumi3);\n            vsumi4 = vec_msum(qv20, vs2, vsumi4);\n            vsumi5 = vec_msum(qv21, vs6, vsumi5);\n            vsumi6 = vec_msum(qv30, vs3, vsumi6);\n            vsumi7 = vec_msum(qv31, vs7, vsumi7);\n        }\n\n        vsumi0 = vec_add(vsumi0, vsumi4);\n        vsumi1 = vec_add(vsumi1, vsumi5);\n        vsumi2 = vec_add(vsumi2, vsumi6);\n        vsumi3 = vec_add(vsumi3, vsumi7);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n#if defined 
(__POWER9_VECTOR__)\nstatic const int8_t keven_signs_q2xs[1024] = {\n     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,\n     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,\n     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,\n     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,\n     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,\n     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,\n     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,\n     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,\n     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,\n     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,\n     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,\n     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,\n     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,\n     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,\n     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  
1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,\n     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,\n     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,\n     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,\n     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,\n     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,\n     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,\n     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,\n     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,\n     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,\n     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,\n     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,\n     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,\n     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,\n     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,\n     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, 
-1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,\n     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,\n     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,\n};\n#endif\n\nvoid ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    const vector int v0 = vec_splats((int32_t)0);\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;\n\n        for (int j = 0; j < QK_K/32; j += 2) {\n            __builtin_prefetch(q2, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            uint32_t aux32[4];\n            const uint8_t * aux8 = (const uint8_t *)aux32;\n\n            memcpy(aux32, q2, 4*sizeof(uint32_t));\n            q2 += 8;\n\n            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t 
*)(iq2xxs_grid + aux8[ 1])};\n            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])};\n            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])};\n            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])};\n\n            vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >>  7) & 127))};\n            vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))};\n            vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >>  7) & 127))};\n            vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))};\n\n            vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);\n            vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);\n            vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);\n            vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);\n\n            vector signed char q8y0 = vec_xl( 0, q8);\n            vector signed char q8y1 = vec_xl(16, q8);\n            vector signed char q8y2 = vec_xl(32, q8);\n            vector signed char q8y3 = vec_xl(48, q8);\n            q8 += 64;\n\n            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));\n            vector signed short qv1 = 
vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));\n            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));\n            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));\n\n            const uint16_t ls0 = aux32[1] >> 28;\n            const uint16_t ls1 = aux32[3] >> 28;\n\n            vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1));\n            vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1));\n\n            vsumi0 = vec_msum(qv0, vscales01, vsumi0);\n            vsumi1 = vec_msum(qv1, vscales01, vsumi1);\n            vsumi2 = vec_msum(qv2, vscales23, vsumi2);\n            vsumi3 = vec_msum(qv3, vscales23, vsumi3);\n        }\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = 0.125f * vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    const vector int v0 = vec_splats((int32_t)0);\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = 
vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const uint8_t  * GGML_RESTRICT sc = x[i].scales;\n        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;\n\n        for (int j = 0; j < QK_K/64; ++j) {\n            __builtin_prefetch(q2, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))};\n            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))};\n            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))};\n            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))};\n\n            vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))};\n            vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))};\n            vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))};\n            vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 
9)))};\n            q2 += 8;\n\n            vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);\n            vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);\n            vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);\n            vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);\n\n            vector signed char q8y0 = vec_xl( 0, q8);\n            vector signed char q8y1 = vec_xl(16, q8);\n            vector signed char q8y2 = vec_xl(32, q8);\n            vector signed char q8y3 = vec_xl(48, q8);\n            q8 += 64;\n\n            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));\n            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));\n            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));\n            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));\n\n            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);\n            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);\n            const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);\n            const uint16_t ls3 = (uint16_t)(sc[1] >>  4);\n            sc += 2;\n\n            vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));\n            vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));\n            vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));\n            vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));\n\n            vsumi0 = vec_msum(qv0, vscales0, vsumi0);\n            vsumi1 = vec_msum(qv1, vscales1, vsumi1);\n            vsumi2 = vec_msum(qv2, vscales2, vsumi2);\n            vsumi3 = vec_msum(qv3, vscales3, vsumi3);\n        }\n\n        vsumf0 = 
vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = 0.125f * vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n    };\n\n    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};\n\n    const vector int v0 = vec_splats((int32_t)0);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    const vector unsigned char mask0 = vec_xl( 0, k_mask1);\n    const vector unsigned char mask1 = vec_xl(16, k_mask1);\n    const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);\n\n    for (int i = 0; i < nb; 
++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n\n        const uint8_t *  GGML_RESTRICT q2 = x[i].qs;\n        const uint8_t *  GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);\n        const uint8_t *  GGML_RESTRICT sc = x[i].scales;\n        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;\n\n        for (int j = 0; j < QK_K/32; j += 2) {\n            __builtin_prefetch(q2, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))};\n            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))};\n            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))};\n            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))};\n            q2 += 8;\n            qh += 2;\n\n            vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);\n            vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);\n            signs += 4;\n\n            vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);\n            vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);\n            vector signed char vsigns2 = 
vec_perm(vsigns23, vsigns23, mask0);\n            vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1);\n\n            vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);\n            vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);\n            vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);\n            vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);\n\n            vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0);\n            vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1);\n            vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2);\n            vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3);\n\n            vector signed char q8y0 = vec_xl( 0, q8);\n            vector signed char q8y1 = vec_xl(16, q8);\n            vector signed char q8y2 = vec_xl(32, q8);\n            vector signed char q8y3 = vec_xl(48, q8);\n            q8 += 64;\n\n            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));\n            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));\n            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));\n            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));\n\n            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);\n            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);\n            const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);\n            const uint16_t ls3 = (uint16_t)(sc[1] >>  4);\n            sc += 2;\n\n            vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));\n            vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));\n            vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));\n      
      vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));\n\n            vsumi0 = vec_msum(qv0, vscales0, vsumi0);\n            vsumi1 = vec_msum(qv1, vscales1, vsumi1);\n            vsumi2 = vec_msum(qv2, vscales2, vsumi2);\n            vsumi3 = vec_msum(qv3, vscales3, vsumi3);\n        }\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = 0.125f * vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    const vector int v0 = vec_splats((int32_t)0);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed int vsumi0 = v0;\n        vector 
signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint32_t * GGML_RESTRICT signs = (const uint32_t *)(x[i].qs + QK_K/4);\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n#pragma GCC unroll 1\n        for (int j = 0; j < QK_K/32; j += 2) {\n            __builtin_prefetch(q3, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};\n            vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};\n            vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};\n            vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};\n            q3 += 16;\n\n            vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >>  0) & 127]), (uint64_t)(signs64[(signs[0] >>  7) & 127])};\n            vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])};\n            vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >>  0) & 127]), (uint64_t)(signs64[(signs[1] >>  7) & 127])};\n            vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])};\n\n            vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0);\n            vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1);\n            vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2);\n            vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector 
signed char)aux32x4_3);\n\n            vector signed char q8y0 = vec_xl( 0, q8);\n            vector signed char q8y1 = vec_xl(16, q8);\n            vector signed char q8y2 = vec_xl(32, q8);\n            vector signed char q8y3 = vec_xl(48, q8);\n            q8 += 64;\n\n            vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));\n            vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));\n            vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));\n            vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));\n\n            const uint16_t ls0 = (uint16_t)(signs[0] >> 28);\n            const uint16_t ls1 = (uint16_t)(signs[1] >> 28);\n            signs += 2;\n\n            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));\n            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));\n\n            vsumi0 = vec_msum(qv0, vscales01, vsumi0);\n            vsumi1 = vec_msum(qv1, vscales01, vsumi1);\n            vsumi2 = vec_msum(qv2, vscales23, vsumi2);\n            vsumi3 = vec_msum(qv3, vscales23, vsumi3);\n        }\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = 0.25f * vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n    };\n\n    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};\n\n    const vector int v0 = vec_splats((int32_t)0);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    const vector unsigned char mask0 = vec_xl( 0, k_mask1);\n    const vector unsigned char mask1 = vec_xl(16, k_mask1);\n    const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        const uint8_t *  GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t *  GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].signs);\n        const uint8_t *  GGML_RESTRICT sc = x[i].scales;\n        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n\n        for (int j = 0; j < QK_K/32; j += 2) {\n            __builtin_prefetch(q3, 0, 1);\n            
__builtin_prefetch(q8, 0, 1);\n\n            vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)],\n                                             iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]};\n            vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)],\n                                             iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]};\n            vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)],\n                                             iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]};\n            vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)],\n                                             iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]};\n            q3 += 16;\n            qh += 2;\n\n            vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);\n            vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);\n            signs += 4;\n\n            vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);\n            vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);\n            vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0);\n            vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1);\n\n            vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);\n            vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);\n            vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);\n            vsigns3 = (vector signed 
char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);\n\n            vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0);\n            vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1);\n            vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2);\n            vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3);\n\n            vector signed char q8y0 = vec_xl( 0, q8);\n            vector signed char q8y1 = vec_xl(16, q8);\n            vector signed char q8y2 = vec_xl(32, q8);\n            vector signed char q8y3 = vec_xl(48, q8);\n            q8 += 64;\n\n            vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));\n            vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));\n            vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));\n            vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));\n\n            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);\n            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);\n            sc ++;\n\n            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));\n            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));\n\n            vsumi0 = vec_msum(qv0, vscales01, vsumi0);\n            vsumi1 = vec_msum(qv1, vscales01, vsumi1);\n            vsumi2 = vec_msum(qv2, vscales23, vsumi2);\n            vsumi3 = vec_msum(qv3, vscales23, vsumi3);\n        }\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n  
  vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    const vector unsigned char v0 = vec_splats((unsigned char)0x0);\n    const vector unsigned short vsign = vec_splats((unsigned short)0x8000);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    for (int i = 0; i < nb; ++i) {\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));\n        vector float vyd = vec_splats(y[i].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed int vsumi0 = vec_splats((int32_t)0);\n        vector signed int vsumi1 = vec_splats((int32_t)0);\n        vector signed int vsumi2 = vec_splats((int32_t)0);\n        vector signed int vsumi3 = vec_splats((int32_t)0);\n        vector signed int vsumi8 = vec_splats((int32_t)0);\n\n        const uint8_t  * GGML_RESTRICT q1 = x[i].qs;\n        const uint16_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        const int16_t  * GGML_RESTRICT qs = y[i].bsums;\n\n        for (int j = 0; j < QK_K/32; j += 2) {\n            __builtin_prefetch(q1, 0, 1);\n            __builtin_prefetch(qh, 0, 1);\n            
__builtin_prefetch(q8, 0, 1);\n\n            vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))};\n            vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))};\n            vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))};\n            vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))};\n            q1 += 8;\n\n            vector signed char q1x0 = (vector signed char)aux64x2_0;\n            vector signed char q1x1 = (vector signed char)aux64x2_1;\n            vector signed char q1x2 = (vector signed char)aux64x2_2;\n            vector signed char q1x3 = (vector signed char)aux64x2_3;\n\n            vector signed char q8y0 = vec_xl( 0, q8);\n            vector signed char q8y1 = vec_xl(16, q8);\n            vector signed char q8y2 = vec_xl(32, q8);\n            vector signed char q8y3 = vec_xl(48, q8);\n            q8 += 64;\n\n            vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0));\n            vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1));\n            vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2));\n            vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3));\n\n            const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7);\n            const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7);\n\n            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));\n            vector signed short vscales23 = (vector signed 
short)vec_splats((uint16_t)(2*ls1+1));\n            vector signed short vscales = vec_sld(vscales23, vscales01, 8);\n\n            vsumi0 = vec_msum(qv0, vscales01, vsumi0);\n            vsumi1 = vec_msum(qv1, vscales01, vsumi1);\n            vsumi2 = vec_msum(qv2, vscales23, vsumi2);\n            vsumi3 = vec_msum(qv3, vscales23, vsumi3);\n\n            vector signed short q8ysums = vec_xl_len(qs, 8);\n            qs += 4;\n            q8ysums = vec_mergeh(q8ysums, (vector signed short)v0);\n\n            vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8);\n            qh += 2;\n            vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0);\n\n            vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel);\n\n            vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8);\n        }\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK4_NL == 0);\n    
static_assert(QK4_NL == QK8_0, \"QK4_NL and QK8_0 must be the same\");\n\n    const block_iq4_nl * GGML_RESTRICT x = vx;\n    const block_q8_0   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK4_NL;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector signed int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n\n    const vector signed char values = vec_xl( 0, kvalues_iq4nl);\n\n#pragma GCC unroll 4\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));\n        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);\n        vector signed char q4x0 = vec_and(qxs, lowMask);\n        vector signed char q4x1 = vec_sr(qxs, v4);\n\n        q4x0 = vec_perm(values, values, (vector unsigned char)q4x0);\n        q4x1 = vec_perm(values, values, (vector unsigned char)q4x1);\n\n        vector signed char q8y0 = vec_xl( 0, y[ib].qs);\n        vector signed char q8y1 = vec_xl(16, y[ib].qs);\n\n        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));\n        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n\n        vsumi0 = vec_sum4s(qv0, vsumi0);\n        vsumi1 = vec_sum4s(qv1, vsumi1);\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    
vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    sumf = vec_extract(vsumf0, 0);\n\n    *s = sumf;\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_K == 0);\n\n    const block_iq4_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__POWER9_VECTOR__)\n    const vector signed char lowMask = vec_splats((signed char)0xF);\n    const vector int v0 = vec_splats((int32_t)0);\n    const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n\n    vector float vsumf0 = vec_splats(0.0f);\n    vector float vsumf1 = vec_splats(0.0f);\n    vector float vsumf2 = vec_splats(0.0f);\n    vector float vsumf3 = vec_splats(0.0f);\n\n    const vector signed char values = vec_xl( 0, kvalues_iq4nl);\n\n    for (int ibl = 0; ibl < nb; ++ibl) {\n\n        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ibl].d));\n        vector float vyd = vec_splats(y[ibl].d);\n        vector float vd = vec_mul(vxd, vyd);\n\n        vector signed int vsumi0 = v0;\n        vector signed int vsumi1 = v0;\n        vector signed int vsumi2 = v0;\n        vector signed int vsumi3 = v0;\n\n        uint16_t h = x[ibl].scales_h;\n\n        const uint8_t * GGML_RESTRICT q4 = x[ibl].qs;\n        const uint8_t * GGML_RESTRICT sc = x[ibl].scales_l;\n        const int8_t  * GGML_RESTRICT q8 = y[ibl].qs;\n\n        for (int ib = 0; ib < QK_K/64; ib ++ ) {\n            __builtin_prefetch(q4, 0, 1);\n            __builtin_prefetch(q8, 0, 1);\n\n            vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);\n            
vector signed char qxs1 = (vector signed char)vec_xl(16, q4);\n            q4 += 32;\n\n            vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask);\n            vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4);\n            vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask);\n            vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4);\n\n            q4x00 = vec_perm(values, values, (vector unsigned char)q4x00);\n            q4x01 = vec_perm(values, values, (vector unsigned char)q4x01);\n            q4x10 = vec_perm(values, values, (vector unsigned char)q4x10);\n            q4x11 = vec_perm(values, values, (vector unsigned char)q4x11);\n\n            vector signed char q8y0 = vec_xl( 0, q8);\n            vector signed char q8y1 = vec_xl(16, q8);\n            vector signed char q8y2 = vec_xl(32, q8);\n            vector signed char q8y3 = vec_xl(48, q8);\n            q8 += 64;\n\n            vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0));\n            vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1));\n            vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2));\n            vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3));\n\n            const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32);\n            const uint16_t ls1 = (uint16_t)(((sc[0] >>  4) | ((h << 2) & 0x30)) - 32);\n            h >>= 4;\n            sc ++;\n\n            vector signed short vscales01 = vec_splats((int16_t)ls0);\n            vector signed short vscales23 = vec_splats((int16_t)ls1);\n\n            vsumi0 = vec_msum(qv0, vscales01, vsumi0);\n            vsumi1 = vec_msum(qv1, vscales01, vsumi1);\n            vsumi2 = vec_msum(qv2, vscales23, vsumi2);\n            vsumi3 = vec_msum(qv3, vscales23, vsumi3);\n        }\n\n        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, 
vsumf0);\n        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);\n        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);\n        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);\n    }\n\n    vsumf0 = vec_add(vsumf0, vsumf2);\n    vsumf1 = vec_add(vsumf1, vsumf3);\n\n    vsumf0 = vec_add(vsumf0, vsumf1);\n\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));\n    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));\n\n    *s = vec_extract(vsumf0, 0);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n"
  },
  {
    "path": "src/ggml-cpu/arch/riscv/cpu-feats.cpp",
    "content": "#include \"ggml-backend-impl.h\"\n\n#if defined(__riscv) && __riscv_xlen == 64\n#include <asm/hwprobe.h>\n#include <asm/unistd.h>\n#include <unistd.h>\n\nstruct riscv64_features {\n    bool has_rvv = false;\n\n    riscv64_features() {\n        struct riscv_hwprobe probe;\n        probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;\n        probe.value = 0;\n\n        int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);\n\n        if (0 == ret) {\n            has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);\n        }\n    }\n};\n\nstatic int ggml_backend_cpu_riscv64_score() {\n    int score = 1;\n    riscv64_features rf;\n\n#ifdef GGML_USE_RVV\n    if (!rf.has_rvv) { return 0; }\n    score += 1 << 1;\n#endif\n\n    return score;\n}\n\nGGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score)\n\n#endif  // __riscv && __riscv_xlen == 64\n"
  },
  {
    "path": "src/ggml-cpu/arch/riscv/quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n#include \"ggml-quants.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"simd-mappings.h\"\n\n#include \"../../quants.h\"\n#include \"../../ggml-cpu-impl.h\"\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <float.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\nvoid quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__riscv_v)\n\n    size_t vl = QK8_0;\n\n    for (int i = 0; i < nb; i++) {\n        // load elements\n        vfloat32m8_t v_x   = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);\n\n        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);\n        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);\n        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);\n        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 
1.0f/d : 0.0f;\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);\n\n        // convert to integer\n        vint16m4_t   vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);\n        vint8m2_t    vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);\n\n        // store result\n        __riscv_vse8_v_i8m2(y[i].qs , vs, vl);\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_0_ref(x, y, k);\n#endif\n}\n\nvoid quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK8_1 == 0);\n    const int nb = k / QK8_1;\n\n    block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__riscv_v)\n\n    size_t vl = QK8_1;\n\n    for (int i = 0; i < nb; i++) {\n        // load elements\n        vfloat32m8_t v_x   = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);\n\n        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);\n        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);\n        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);\n        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);\n\n        const float d  = amax / ((1 << 7) - 1);\n        const float id = d ? 
1.0f/d : 0.0f;\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);\n\n        // convert to integer\n        vint16m4_t   vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);\n        vint8m2_t    vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);\n\n        // store result\n        __riscv_vse8_v_i8m2(y[i].qs , vs, vl);\n\n        // compute sum for y[i].s\n        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);\n        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);\n\n        // set y[i].s\n        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);\n        y[i].s = GGML_CPU_FP32_TO_FP16(sum*d);\n    }\n\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_1_ref(x, y, k);\n#endif\n}\n\nvoid quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    block_q8_K * y_blocks = (block_q8_K *)y;\n    size_t nb = k / QK_K;\n\n#if defined(__riscv_v_intrinsic)\n    const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8();\n\n    for (size_t i = 0; i < nb; i++) {\n        const float* x_block = x + i * QK_K;\n        block_q8_K* y_block = &y_blocks[i];\n\n        // 1. 
Calculate Min/Max\n        vfloat32m8_t max_v = __riscv_vfmv_v_f_f32m8(-__builtin_inff(), vlmax_f32m8);\n        vfloat32m8_t min_v = __riscv_vfmv_v_f_f32m8(__builtin_inff(), vlmax_f32m8);\n\n        size_t rem = QK_K;\n        size_t offset = 0;\n        while (rem > 0) {\n            size_t vl = __riscv_vsetvl_e32m8(rem);\n            vfloat32m8_t v_curr = __riscv_vle32_v_f32m8(x_block + offset, vl);\n            max_v = __riscv_vfmax_vv_f32m8(max_v, v_curr, vl);\n            min_v = __riscv_vfmin_vv_f32m8(min_v, v_curr, vl);\n            rem -= vl;\n            offset += vl;\n        }\n\n        vfloat32m1_t v_init_max = __riscv_vfmv_s_f_f32m1(-__builtin_inff(), 1);\n        vfloat32m1_t v_init_min = __riscv_vfmv_s_f_f32m1(__builtin_inff(), 1);\n\n        vfloat32m1_t v_scalar_max = __riscv_vfredmax_vs_f32m8_f32m1(max_v, v_init_max, vlmax_f32m8);\n        vfloat32m1_t v_scalar_min = __riscv_vfredmin_vs_f32m8_f32m1(min_v, v_init_min, vlmax_f32m8);\n\n        float max_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_max);\n        float min_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_min);\n\n        float amax = fabsf(max_val) > fabsf(min_val) ? fabsf(max_val) : fabsf(min_val);\n\n        if (amax == 0.0f) {\n            y_block->d = 0.0f;\n            memset(y_block->qs, 0, QK_K);\n            memset(y_block->bsums, 0, sizeof(y_block->bsums));\n            continue;\n        }\n\n        const float iscale = -127.f / (fabsf(max_val) > fabsf(min_val) ? max_val : min_val);\n        y_block->d = 1.0f / iscale;\n\n        // 2. 
Quantize and Calculate Sums\n        offset = 0;\n        rem = QK_K;\n        vint16m1_t v_zero_sum = __riscv_vmv_v_x_i16m1(0, 1);\n\n        while (rem > 0) {\n            size_t vl = __riscv_vsetvl_e32m8(rem);\n            vfloat32m8_t v_f = __riscv_vle32_v_f32m8(x_block + offset, vl);\n\n            v_f = __riscv_vfmul_vf_f32m8(v_f, iscale, vl);\n\n            vint32m8_t v_i32 = __riscv_vfcvt_x_f_v_i32m8_rm(v_f, __RISCV_FRM_RNE, vl);\n            vint16m4_t v_i16 = __riscv_vnclip_wx_i16m4(v_i32, 0, __RISCV_VXRM_RNE, vl);\n            vint8m2_t v_q = __riscv_vnclip_wx_i8m2(v_i16, 0, __RISCV_VXRM_RNE, vl);\n\n            __riscv_vse8_v_i8m2(y_block->qs + offset, v_q, vl);\n\n            // first iteration clear\n\n            int sum_idx;\n            vint8m1_t chunk_m1;\n            vint16m1_t v_sum;\n            sum_idx = offset / 16;\n            chunk_m1 = __riscv_vget_v_i8m2_i8m1(v_q, 0);\n            v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16);\n            y_block->bsums[sum_idx] = (int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum);\n\n            // remaining iterations\n            vint8m2_t slid_q = v_q;\n            for (size_t k = 16; k < vl; k += 16) {\n                slid_q = __riscv_vslidedown_vx_i8m2(slid_q, 16, vl);\n\n                sum_idx = (offset + k) / 16;\n                chunk_m1 = __riscv_vget_v_i8m2_i8m1(slid_q, 0);\n\n                v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16);\n                y_block->bsums[sum_idx] =(int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum);\n            }\n\n            rem -= vl;\n            offset += vl;\n        }\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_K_ref(x, y, k);\n#endif\n}\n\n//===================================== Dot products =================================\n\nvoid ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if 
defined(__riscv_v)\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n    size_t vl = qk / 2;\n\n    for (; ib < nb; ++ib) {\n        // load elements\n        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);\n\n        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);\n        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);\n\n        // mask and store lower part of x, and then upper part\n        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);\n        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);\n\n        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);\n        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);\n\n        // subtract offset\n        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);\n        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);\n\n        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);\n        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);\n\n        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);\n        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);\n\n        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);\n\n        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);\n    }\n\n    *s = sumf;\n#else\n    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined(__riscv_v)\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const 
block_q4_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n    size_t vl = qk / 2;\n\n    for (; ib < nb; ++ib) {\n        // load elements\n        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);\n\n        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);\n        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);\n\n        // mask and store lower part of x, and then upper part\n        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);\n        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);\n\n        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);\n        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);\n\n        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);\n        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);\n\n        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);\n        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);\n\n        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);\n\n        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);\n    }\n\n    *s = sumf;\n#else\n    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined(__riscv_v)\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    size_t vl;\n    size_t vlenb = __riscv_vlenb();\n\n    for (; ib < nb; ++ib) {\n        vl = qk / 2;\n        vuint8m1_t v0 = 
__riscv_vle8_v_u8m1(x[ib].qs, vl);\n        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));\n        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));\n        vint8m2_t v0c;\n        if (vlenb == 16) {\n            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);\n        } else {\n            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);\n            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);\n        }\n\n        vl = qk;\n        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);\n        qh = __riscv_vmnand_mm_b4(qh, qh, vl);\n        vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);\n        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);\n        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);\n        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);\n        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);\n        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);\n\n        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;\n    }\n\n    *s = sumf;\n#else\n    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined(__riscv_v)\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_1);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n    size_t vl;\n    size_t vlenb = __riscv_vlenb();\n\n    for (; ib < nb; ++ib) {\n        vl = qk / 2;\n        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);\n        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));\n        vint8m1_t 
v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));\n        vint8m2_t v0c;\n        if (vlenb == 16) {\n            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);\n        } else {\n            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);\n            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);\n        }\n\n        vl = qk;\n        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);\n        vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);\n        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);\n        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);\n        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);\n        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);\n        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);\n\n        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);\n    }\n\n    *s = sumf;\n#else\n    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q8_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__riscv_v)\n    size_t vl = qk;\n\n    for (; ib < nb; ++ib) {\n        // load elements\n        vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);\n        vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);\n\n        vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);\n\n        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);\n        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);\n\n        int sumi = 
__riscv_vmv_x_s_i32m1_i32(v_sum);\n\n        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));\n    }\n\n    *s = sumf;\n#else\n\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n\n    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q2_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __riscv_xtheadvector\n\n    float sumf = 0;\n    uint8_t atmp[16];\n\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * q2 = x[i].qs;\n        const  int8_t * q8 = y[i].qs;\n        const uint8_t * sc = x[i].scales;\n        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n        uint8_t *patmp = atmp;\n        int vsums;\n        int tmp;\n        __asm__ __volatile__(\n            \"th.vsetvli zero, %[vl16], e8, m1\\n\\t\"\n            \"th.vmv.v.x v8, zero\\n\\t\"\n            \"th.vlb.v v1, (%[sc])\\n\\t\"\n            \"th.vand.vi v0, v1, 0xF\\n\\t\"\n            \"th.vsrl.vi v1, v1, 4\\n\\t\"\n            \"th.vsb.v v0, (%[scale])\\n\\t\"\n            \"th.vwaddu.vx v16, v1, zero\\n\\t\"\n            \"th.vsetvli zero, %[vl16], e16, m2\\n\\t\"\n            \"th.vlh.v v2, (%[bsums])\\n\\t\"\n            \"th.vwmul.vv v4, v16, v2\\n\\t\"\n            \"th.vsetvli zero, %[vl16], e32, m4\\n\\t\"\n            \"th.vredsum.vs v8, v4, v8\\n\\t\"\n            \"th.vmv.x.s %[vsums], v8\"\n            : [tmp] \"=&r\" (tmp), [vsums] \"=&r\" (vsums)\n            : [sc] \"r\" (sc), [scale] \"r\" (atmp), [bsums] \"r\" (y[i].bsums)\n            , [vl16] \"r\" (16)\n      
      : \"memory\"\n            , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n            , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n            , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n            , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n        );\n        sumf += dmin * vsums;\n        int isum = 0;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            __asm__ __volatile__(\n                \"th.vsetvli zero, %[vl32], e8, m2\\n\\t\"\n                \"th.vlb.v v0, (%[q2])\\n\\t\"\n                \"th.vsrl.vi v2, v0, 2\\n\\t\"\n                \"th.vsrl.vi v4, v0, 4\\n\\t\"\n                \"th.vsrl.vi v6, v0, 6\\n\\t\"\n                \"th.vand.vi v0, v0, 0x3\\n\\t\"\n                \"th.vand.vi v2, v2, 0x3\\n\\t\"\n                \"th.vand.vi v4, v4, 0x3\\n\\t\"\n                \"th.vsetvli zero, %[vl128], e8, m8\\n\\t\"\n                \"th.vlb.v v8, (%[q8])\\n\\t\"\n                \"th.vsetvli zero, %[vl64], e8, m4\\n\\t\"\n                \"th.vwmul.vv v16, v0, v8\\n\\t\"\n                \"th.vwmul.vv v24, v4, v12\\n\\t\"\n                \"th.vsetvli zero, %[vl16], e16, m2\\n\\t\"\n                \"th.vmv.v.x v0, zero\\n\\t\"\n                \"th.vwredsum.vs v10, v16, v0\\n\\t\"\n                \"th.vwredsum.vs v9, v18, v0\\n\\t\"\n                \"th.vwredsum.vs v8, v20, v0\\n\\t\"\n                \"th.vwredsum.vs v7, v22, v0\\n\\t\"\n                \"th.vwredsum.vs v11, v24, v0\\n\\t\"\n                \"th.vwredsum.vs v12, v26, v0\\n\\t\"\n                \"th.vwredsum.vs v13, v28, v0\\n\\t\"\n                \"th.vwredsum.vs v14, v30, v0\\n\\t\"\n                \"li %[tmp], 4\\n\\t\"\n                \"th.vsetvli zero, %[tmp], e32, m1\\n\\t\"\n                \"th.vslideup.vi v10, v9, 1\\n\\t\"\n                \"th.vslideup.vi v8, v7, 1\\n\\t\"\n                \"th.vslideup.vi v11, v12, 1\\n\\t\"\n                
\"th.vslideup.vi v13, v14, 1\\n\\t\"\n                \"th.vslideup.vi v10, v8, 2\\n\\t\"\n                \"th.vslideup.vi v11, v13, 2\\n\\t\"\n                \"li %[tmp], 8\\n\\t\"\n                \"th.vsetvli zero, %[tmp], e32, m2\\n\\t\"\n                \"th.vlbu.v v12, (%[scale])\\n\\t\"\n                \"th.vmul.vv v10, v10, v12\\n\\t\"\n                \"th.vredsum.vs v0, v10, v0\\n\\t\"\n                \"th.vmv.x.s %[tmp], v0\\n\\t\"\n                \"add %[isum], %[isum], %[tmp]\"\n                : [tmp] \"=&r\" (tmp), [isum] \"+&r\" (isum)\n                : [q2] \"r\" (q2), [scale] \"r\" (patmp), [q8] \"r\" (q8)\n                , [vl16] \"r\" (16), [vl32] \"r\" (32), [vl64] \"r\" (64), [vl128] \"r\" (128)\n                : \"memory\"\n                , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n            );\n            q2 += 32; q8 += 128; patmp += 8;\n        }\n\n        sumf += dall * isum;\n    }\n\n    *s = sumf;\n\n#elif defined __riscv_v\n\n    float sumf = 0;\n    uint8_t atmp[16];\n\n    const int vector_length = __riscv_vlenb() * 8;\n    uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };\n\n    switch (vector_length) {\n    case 256:\n        for (int i = 0; i < nb; ++i) {\n            const uint8_t * q2 = x[i].qs;\n            const int8_t *  q8 = y[i].qs;\n            const uint8_t * sc = x[i].scales;\n\n            const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n            const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n            size_t vl = 16;\n\n            vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);\n       
     vuint8m1_t aux    = __riscv_vand_vx_u8m1(scales, 0x0F, vl);\n\n            vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);\n\n            vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);\n            vuint8mf2_t mins8    = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);\n            vint16m1_t  mins     = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));\n            vint32m2_t  prod     = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);\n            vint32m1_t  vsums    = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);\n\n            sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);\n\n            vl = 32;\n\n            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);\n            vuint8m1_t v_b   = __riscv_vle8_v_u8m1(temp_01, vl);\n\n            uint8_t is   = 0;\n            int     isum = 0;\n\n            for (int j = 0; j < QK_K / 128; ++j) {\n                // load Q2\n                vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);\n\n                vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);\n                vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);\n                vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);\n                vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);\n\n                // duplicate scale elements for product\n                vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);\n                vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);\n                vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);\n                vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);\n\n                vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));\n     
           vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));\n                vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));\n                vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));\n\n                // load Q8\n                vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);\n                vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);\n                vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);\n                vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);\n\n                vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);\n                vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);\n                vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);\n                vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);\n\n                vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);\n                vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);\n\n                isum += __riscv_vmv_x_s_i32m1_i32(isum1);\n\n                q2 += 32;\n                q8 += 128;\n                is = 8;\n            }\n\n            sumf += dall * isum;\n        }\n        break;\n    case 128:\n        for (int i = 0; i < nb; ++i) {\n            const uint8_t * q2 = x[i].qs;\n            const  int8_t * q8 = y[i].qs;\n            const uint8_t * sc = x[i].scales;\n            const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n            const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n            uint8_t *patmp = atmp;\n            int vsums;\n            int tmp, t1, t2, t3, t4, t5, t6, t7;\n            __asm__ __volatile__(\n                \"vsetivli zero, 16, 
e8, m1\\n\\t\"\n                \"vmv.v.x v8, zero\\n\\t\"\n                \"lb zero, 15(%[sc])\\n\\t\"\n                \"vle8.v v1, (%[sc])\\n\\t\"\n                \"vle8.v v2, (%[bsums])\\n\\t\"\n                \"addi %[tmp], %[bsums], 16\\n\\t\"\n                \"vand.vi v0, v1, 0xF\\n\\t\"\n                \"vsrl.vi v1, v1, 4\\n\\t\"\n                \"vle8.v v3, (%[tmp])\\n\\t\"\n                \"vse8.v v0, (%[scale])\\n\\t\"\n                \"vsetivli zero, 16, e16, m2\\n\\t\"\n                \"vzext.vf2 v0, v1\\n\\t\"\n                \"vwmul.vv v4, v0, v2\\n\\t\"\n                \"vsetivli zero, 16, e32, m4\\n\\t\"\n                \"vredsum.vs v8, v4, v8\\n\\t\"\n                \"vmv.x.s %[vsums], v8\"\n                : [tmp] \"=&r\" (tmp), [vsums] \"=&r\" (vsums)\n                : [sc] \"r\" (sc), [scale] \"r\" (atmp), [bsums] \"r\" (y[i].bsums)\n                : \"memory\"\n                , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n            );\n            sumf += dmin * vsums;\n            int isum = 0;\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                __asm__ __volatile__(\n                    \"lb zero, 31(%[q2])\\n\\t\"\n                    \"addi %[tmp], %[q2], 16\\n\\t\"\n                    \"addi %[t1], %[q8], 16\\n\\t\"\n                    \"vsetivli zero, 16, e8, m1\\n\\t\"\n                    \"vle8.v v0, (%[q2])\\n\\t\"\n                    \"vle8.v v1, (%[tmp])\\n\\t\"\n                    \"vsrl.vi v2, v0, 2\\n\\t\"\n                    \"vsrl.vi v3, v1, 2\\n\\t\"\n                    \"vsrl.vi v4, v0, 4\\n\\t\"\n                    \"addi %[tmp], %[q8], 32\\n\\t\"\n                    \"vle8.v v8, (%[q8])\\n\\t\"\n    
                \"vle8.v v9, (%[t1])\\n\\t\"\n                    \"addi %[t1], %[t1], 32\\n\\t\"\n                    \"vsrl.vi v5, v1, 4\\n\\t\"\n                    \"vsrl.vi v6, v0, 6\\n\\t\"\n                    \"vsrl.vi v7, v1, 6\\n\\t\"\n                    \"vle8.v v10, (%[tmp])\\n\\t\"\n                    \"vle8.v v11, (%[t1])\\n\\t\"\n                    \"addi %[tmp], %[tmp], 32\\n\\t\"\n                    \"addi %[t1], %[t1], 32\\n\\t\"\n                    \"vand.vi v0, v0, 0x3\\n\\t\"\n                    \"vand.vi v1, v1, 0x3\\n\\t\"\n                    \"vand.vi v2, v2, 0x3\\n\\t\"\n                    \"vle8.v v12, (%[tmp])\\n\\t\"\n                    \"vle8.v v13, (%[t1])\\n\\t\"\n                    \"addi %[tmp], %[tmp], 32\\n\\t\"\n                    \"addi %[t1], %[t1], 32\\n\\t\"\n                    \"vand.vi v3, v3, 0x3\\n\\t\"\n                    \"vand.vi v4, v4, 0x3\\n\\t\"\n                    \"vand.vi v5, v5, 0x3\\n\\t\"\n                    \"vle8.v v14, (%[tmp])\\n\\t\"\n                    \"vle8.v v15, (%[t1])\\n\\t\"\n                    \"vwmul.vv v16, v0, v8\\n\\t\"\n                    \"vwmul.vv v18, v1, v9\\n\\t\"\n                    \"vwmul.vv v20, v2, v10\\n\\t\"\n                    \"vwmul.vv v22, v3, v11\\n\\t\"\n                    \"vwmul.vv v24, v4, v12\\n\\t\"\n                    \"vwmul.vv v26, v5, v13\\n\\t\"\n                    \"vwmul.vv v28, v6, v14\\n\\t\"\n                    \"vwmul.vv v30, v7, v15\\n\\t\"\n                    \"vsetivli zero, 8, e16, m1\\n\\t\"\n                    \"vmv.v.x v0, zero\\n\\t\"\n                    \"lbu %[tmp], 0(%[scale])\\n\\t\"\n                    \"vwredsum.vs v8, v16, v0\\n\\t\"\n                    \"vwredsum.vs v9, v18, v0\\n\\t\"\n                    \"lbu %[t1], 1(%[scale])\\n\\t\"\n                    \"vwredsum.vs v10, v20, v0\\n\\t\"\n                    \"vwredsum.vs v11, v22, v0\\n\\t\"\n                    \"lbu %[t2], 2(%[scale])\\n\\t\"\n           
         \"vwredsum.vs v12, v24, v0\\n\\t\"\n                    \"vwredsum.vs v13, v26, v0\\n\\t\"\n                    \"lbu %[t3], 3(%[scale])\\n\\t\"\n                    \"vwredsum.vs v14, v28, v0\\n\\t\"\n                    \"vwredsum.vs v15, v30, v0\\n\\t\"\n                    \"lbu %[t4], 4(%[scale])\\n\\t\"\n                    \"vwredsum.vs v8, v17, v8\\n\\t\"\n                    \"vwredsum.vs v9, v19, v9\\n\\t\"\n                    \"lbu %[t5], 5(%[scale])\\n\\t\"\n                    \"vwredsum.vs v10, v21, v10\\n\\t\"\n                    \"vwredsum.vs v11, v23, v11\\n\\t\"\n                    \"lbu %[t6], 6(%[scale])\\n\\t\"\n                    \"vwredsum.vs v12, v25, v12\\n\\t\"\n                    \"vwredsum.vs v13, v27, v13\\n\\t\"\n                    \"lbu %[t7], 7(%[scale])\\n\\t\"\n                    \"vwredsum.vs v14, v29, v14\\n\\t\"\n                    \"vwredsum.vs v15, v31, v15\\n\\t\"\n                    \"vsetivli zero, 4, e32, m1\\n\\t\"\n                    \"vmul.vx v0, v8, %[tmp]\\n\\t\"\n                    \"vmul.vx v1, v9, %[t1]\\n\\t\"\n                    \"vmacc.vx v0, %[t2], v10\\n\\t\"\n                    \"vmacc.vx v1, %[t3], v11\\n\\t\"\n                    \"vmacc.vx v0, %[t4], v12\\n\\t\"\n                    \"vmacc.vx v1, %[t5], v13\\n\\t\"\n                    \"vmacc.vx v0, %[t6], v14\\n\\t\"\n                    \"vmacc.vx v1, %[t7], v15\\n\\t\"\n                    \"vmv.x.s %[tmp], v0\\n\\t\"\n                    \"vmv.x.s %[t1], v1\\n\\t\"\n                    \"add %[isum], %[isum], %[tmp]\\n\\t\"\n                    \"add %[isum], %[isum], %[t1]\"\n                    : [tmp] \"=&r\" (tmp), [t1] \"=&r\" (t1), [t2] \"=&r\" (t2), [t3] \"=&r\" (t3)\n                    , [t4] \"=&r\" (t4), [t5] \"=&r\" (t5), [t6] \"=&r\" (t6), [t7] \"=&r\" (t7)\n                    , [isum] \"+&r\" (isum)\n                    : [q2] \"r\" (q2), [scale] \"r\" (patmp), [q8] \"r\" (q8)\n                    : \"memory\"\n   
                 , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                    , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                    , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                    , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n                );\n                q2 += 32; q8 += 128; patmp += 8;\n            }\n\n            sumf += dall * isum;\n        }\n        break;\n    default:\n        assert(false && \"Unsupported vector length\");\n        break;\n    }\n\n    *s = sumf;\n\n#else\n\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n\n    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    const block_q3_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __riscv_xtheadvector\n\n    uint32_t utmp[4];\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * restrict q3 = x[i].qs;\n        const uint8_t * restrict qh = x[i].hmask;\n        const  int8_t * restrict q8 = y[i].qs;\n\n        int8_t * scale = (int8_t *)utmp;\n        int tmp;\n        __asm__ __volatile__(\n            \"li %[tmp], 12\\n\\t\"\n            \"th.vsetvli zero, %[tmp], e8, m1\\n\\t\"\n            \"th.vlb.v v0, (%[s6b])\\n\\t\"\n            \"th.vmv.v.v v2, v0\\n\\t\"\n            \"li %[tmp], 2\\n\\t\"\n            \"th.vsetvli zero, %[tmp], e64, m1\\n\\t\"\n            \"th.vmv.v.x v9, %[sh]\\n\\t\"\\\n            \"th.vslidedown.vi v1, v0, 1\\n\\t\"\n           
 \"th.vslide1up.vx v8, v9, zero\\n\\t\" // {0, 0, 4, 4}\n            \"th.vslideup.vi v0, v2, 1\\n\\t\" // {aux[0], aux[1], aux[0], aux[1]}\n            \"li %[tmp], 4\\n\\t\"\n            \"th.vsetvli zero, %[tmp], e32, m1\\n\\t\"\n            \"th.vid.v v9\\n\\t\"\n            \"th.vmv.x.s %[tmp], v1\\n\\t\"\n            \"th.vsll.vi v9, v9, 1\\n\\t\" // {0, 2, 4, 6}\n            \"th.vmv.v.x v1, %[tmp]\\n\\t\" // {aux[2], aux[2], aux[2], aux[2]}\n            \"th.vsrl.vv v4, v1, v9\\n\\t\"\n            \"th.vsrl.vv v2, v0, v8\\n\\t\"\n            \"th.vand.vx v5, v4, %[kmask1]\\n\\t\"\n            \"th.vand.vx v3, v2, %[kmask2]\\n\\t\"\n            \"th.vsll.vi v6, v5, 4\\n\\t\"\n            \"th.vor.vv v7, v6, v3\\n\\t\"\n            \"li %[tmp], 16\\n\\t\"\n            \"th.vsetvli zero, %[tmp], e8, m1\\n\\t\"\n            \"th.vsub.vx v0, v7, %[c]\\n\\t\"\n            \"th.vsb.v v0, (%[scale])\"\n            : [tmp] \"=&r\" (tmp)\n            : [sh] \"r\" (0x0000000400000004), [s6b] \"r\" (x[i].scales), [c] \"r\" (32)\n            , [scale] \"r\" (scale), [kmask1] \"r\" (kmask1), [kmask2] \"r\" (kmask2)\n            : \"memory\"\n            , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n            , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n            , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n            , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n        );\n\n        uint8_t m = 1;\n        int isum = 0;\n        for (int j = 0; j < QK_K; j += 128) {\n            __asm__ __volatile__(\n                // fixme: use v0p7 mask layout directly\n                \"th.vsetvli zero, %[vl32], e8, m2\\n\\t\"\n                \"th.vlb.v v8, (%[q3])\\n\\t\"\n                \"th.vsrl.vi v10, v8, 2\\n\\t\"\n                \"th.vsrl.vi v12, v8, 4\\n\\t\"\n                \"th.vsrl.vi v14, v8, 6\\n\\t\"\n                \"th.vand.vi v8, v8, 3\\n\\t\"\n         
       \"th.vand.vi v10, v10, 3\\n\\t\"\n                \"th.vand.vi v12, v12, 3\\n\\t\"\n                \"th.vlb.v v2, (%[qh])\\n\\t\"\n                \"th.vand.vx v4, v2, %[m]\\n\\t\"\n                \"slli %[m], %[m], 1\\n\\t\"\n                \"th.vmseq.vx v0, v4, zero\\n\\t\"\n                \"th.vadd.vi v8, v8, -4, v0.t\\n\\t\"\n                \"th.vand.vx v4, v2, %[m]\\n\\t\"\n                \"slli %[m], %[m], 1\\n\\t\"\n                \"th.vmseq.vx v0, v4, zero\\n\\t\"\n                \"th.vadd.vi v10, v10, -4, v0.t\\n\\t\"\n                \"th.vand.vx v4, v2, %[m]\\n\\t\"\n                \"slli %[m], %[m], 1\\n\\t\"\n                \"th.vmseq.vx v0, v4, zero\\n\\t\"\n                \"th.vadd.vi v12, v12, -4, v0.t\\n\\t\"\n                \"th.vand.vx v4, v2, %[m]\\n\\t\"\n                \"slli %[m], %[m], 1\\n\\t\"\n                \"th.vmseq.vx v0, v4, zero\\n\\t\"\n                \"th.vadd.vi v14, v14, -4, v0.t\\n\\t\"\n                \"th.vsetvli zero, %[vl128], e8, m8\\n\\t\"\n                \"th.vlb.v v0, (%[q8])\\n\\t\"\n                \"th.vsetvli zero, %[vl64], e8, m4\\n\\t\"\n                \"th.vwmul.vv v16, v0, v8\\n\\t\"\n                \"th.vwmul.vv v24, v4, v12\\n\\t\"\n                \"li %[tmp], 16\\n\\t\"\n                \"th.vsetvli zero, %[tmp], e16, m2\\n\\t\"\n                \"th.vmv.v.x v0, zero\\n\\t\"\n                \"th.vwredsum.vs v10, v16, v0\\n\\t\"\n                \"th.vwredsum.vs v9, v18, v0\\n\\t\"\n                \"th.vwredsum.vs v8, v20, v0\\n\\t\"\n                \"th.vwredsum.vs v7, v22, v0\\n\\t\"\n                \"th.vwredsum.vs v11, v24, v0\\n\\t\"\n                \"th.vwredsum.vs v12, v26, v0\\n\\t\"\n                \"th.vwredsum.vs v13, v28, v0\\n\\t\"\n                \"th.vwredsum.vs v14, v30, v0\\n\\t\"\n                \"li %[tmp], 4\\n\\t\"\n                \"th.vsetvli zero, %[tmp], e32, m1\\n\\t\"\n                \"th.vslideup.vi v10, v9, 1\\n\\t\"\n                
\"th.vslideup.vi v8, v7, 1\\n\\t\"\n                \"th.vslideup.vi v11, v12, 1\\n\\t\"\n                \"th.vslideup.vi v13, v14, 1\\n\\t\"\n                \"th.vslideup.vi v10, v8, 2\\n\\t\"\n                \"th.vslideup.vi v11, v13, 2\\n\\t\"\n                \"li %[tmp], 8\\n\\t\"\n                \"th.vsetvli zero, %[tmp], e32, m2\\n\\t\"\n                \"th.vlb.v v12, (%[scale])\\n\\t\"\n                \"th.vmul.vv v10, v10, v12\\n\\t\"\n                \"th.vredsum.vs v0, v10, v0\\n\\t\"\n                \"th.vmv.x.s %[tmp], v0\\n\\t\"\n                \"add %[isum], %[isum], %[tmp]\"\n                : [tmp] \"=&r\" (tmp), [m] \"+&r\" (m), [isum] \"+&r\" (isum)\n                : [vl128] \"r\" (128), [vl64] \"r\" (64), [vl32] \"r\" (32)\n                , [q3] \"r\" (q3), [qh] \"r\" (qh), [scale] \"r\" (scale), [q8] \"r\" (q8)\n                : \"memory\"\n                , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n            );\n            q3 += 32;    q8 += 128;   scale += 8;\n        }\n\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        sumf += d * isum;\n    }\n\n    *s = sumf;\n\n#elif defined __riscv_v\n\n    uint32_t utmp[4];\n    float sumf = 0;\n    uint32_t aux[3];\n    const int vector_length = __riscv_vlenb() * 8;\n\n    switch (vector_length) {\n    case 256:\n        for (int i = 0; i < nb; ++i) {\n\n            const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n            const uint8_t * GGML_RESTRICT qh = x[i].hmask;\n            const  int8_t * GGML_RESTRICT q8 = y[i].qs;\n\n            memcpy(aux, x[i].scales, 12);\n            utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);\n            utmp[2] = 
((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);\n            utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);\n            utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);\n\n            int8_t * scale = (int8_t *)utmp;\n            for (int j = 0; j < 16; ++j) scale[j] -= 32;\n\n\n            size_t vl = 32;\n            uint8_t m =  1;\n\n            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);\n            vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);\n\n            int sum_t = 0;\n\n            for (int j = 0; j < QK_K; j += 128) {\n\n                vl = 32;\n\n                // load Q3\n                vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);\n\n                vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));\n                vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));\n                vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));\n                vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));\n\n                // compute mask for subtraction\n                vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);\n                vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);\n                vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);\n                m <<= 1;\n\n                vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);\n                vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);\n                vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);\n                m <<= 1;\n\n                vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);\n                vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);\n                vint8m1_t q3_m2 = 
__riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);\n                m <<= 1;\n\n                vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);\n                vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);\n                vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);\n                m <<= 1;\n\n                // load Q8 and take product with Q3\n                vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);\n                vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);\n                vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);\n                vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);\n\n                vl = 16;\n\n                // retrieve lane to multiply with scale\n                vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);\n                vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);\n                vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);\n                vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);\n                vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);\n                vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);\n                vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);\n                vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);\n\n                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);\n                vint32m1_t isum1 = 
__riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);\n                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);\n                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);\n\n                sum_t +=  __riscv_vmv_x_s_i32m1_i32(isum3);\n\n                q3 += 32;    q8 += 128;   scale += 8;\n\n            }\n\n            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n            sumf += d*sum_t;\n\n        }\n        break;\n    case 128:\n        for (int i = 0; i < nb; ++i) {\n            const uint8_t * restrict q3 = x[i].qs;\n            const uint8_t * restrict qh = x[i].hmask;\n            const  int8_t * restrict q8 = y[i].qs;\n\n            int8_t * scale = (int8_t *)utmp;\n            int tmp, t1, t2, t3, t4, t5, t6, t7;\n            __asm__ __volatile__(\n                \"vsetivli zero, 12, e8, m1\\n\\t\"\n                \"vle8.v v0, (%[s6b])\\n\\t\"\n                \"vmv1r.v v2, v0\\n\\t\"\n                \"vsetivli zero, 2, e64, m1\\n\\t\"\n                \"vmv.v.x v9, %[sh]\\n\\t\"\\\n                \"vslidedown.vi v1, v0, 1\\n\\t\"\n                \"vslide1up.vx v8, v9, zero\\n\\t\" // {0, 0, 4, 4}\n                \"vslideup.vi v0, v2, 1\\n\\t\" // {aux[0], aux[1], aux[0], aux[1]}\n                \"vsetivli zero, 4, e32, m1\\n\\t\"\n                \"vid.v v9\\n\\t\"\n                \"vmv.x.s %[tmp], v1\\n\\t\"\n                \"vsll.vi v9, v9, 1\\n\\t\" // {0, 2, 4, 6}\n                \"vmv.v.x v1, %[tmp]\\n\\t\" // {aux[2], aux[2], aux[2], aux[2]}\n                \"vsrl.vv v4, v1, v9\\n\\t\"\n                \"vsrl.vv v2, v0, v8\\n\\t\"\n                \"vand.vx v5, v4, %[kmask1]\\n\\t\"\n                \"vand.vx v3, v2, %[kmask2]\\n\\t\"\n                \"vsll.vi v6, v5, 4\\n\\t\"\n                \"vor.vv v7, v6, v3\\n\\t\"\n                \"vsetivli zero, 16, 
e8, m1\\n\\t\"\n                \"vsub.vx v0, v7, %[c]\\n\\t\"\n                \"vse8.v v0, (%[scale])\"\n                : [tmp] \"=&r\" (tmp)\n                : [sh] \"r\" (0x0000000400000004), [s6b] \"r\" (x[i].scales), [c] \"r\" (32)\n                , [scale] \"r\" (scale), [kmask1] \"r\" (kmask1), [kmask2] \"r\" (kmask2)\n                : \"memory\"\n                , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n            );\n\n            uint8_t m = 1;\n            int isum = 0;\n            for (int j = 0; j < QK_K; j += 128) {\n                __asm__ __volatile__(\n                    \"lb zero, 31(%[q3])\\n\\t\"\n                    \"vsetvli zero, %[vl32], e8, m2, ta, mu\\n\\t\"\n                    \"vle8.v v8, (%[q3])\\n\\t\"\n                    \"vsrl.vi v10, v8, 2\\n\\t\"\n                    \"vsrl.vi v12, v8, 4\\n\\t\"\n                    \"vsrl.vi v14, v8, 6\\n\\t\"\n                    \"lb zero, 64(%[q8])\\n\\t\"\n                    \"vand.vi v8, v8, 3\\n\\t\"\n                    \"vand.vi v10, v10, 3\\n\\t\"\n                    \"vand.vi v12, v12, 3\\n\\t\"\n                    \"vle8.v v2, (%[qh])\\n\\t\"\n                    \"lb zero, 127(%[q8])\\n\\t\"\n                    \"vand.vx v4, v2, %[m]\\n\\t\"\n                    \"slli %[m], %[m], 1\\n\\t\"\n                    \"vmseq.vx v0, v4, zero\\n\\t\"\n                    \"vadd.vi v8, v8, -4, v0.t\\n\\t\"\n                    \"lb zero, 0(%[q8])\\n\\t\"\n                    \"vand.vx v4, v2, %[m]\\n\\t\"\n                    \"slli %[m], %[m], 1\\n\\t\"\n                    \"vmseq.vx v0, v4, zero\\n\\t\"\n                    \"vadd.vi v10, v10, -4, v0.t\\n\\t\"\n                    \"vand.vx 
v4, v2, %[m]\\n\\t\"\n                    \"slli %[m], %[m], 1\\n\\t\"\n                    \"vmseq.vx v0, v4, zero\\n\\t\"\n                    \"vadd.vi v12, v12, -4, v0.t\\n\\t\"\n                    \"vand.vx v4, v2, %[m]\\n\\t\"\n                    \"slli %[m], %[m], 1\\n\\t\"\n                    \"vmseq.vx v0, v4, zero\\n\\t\"\n                    \"vadd.vi v14, v14, -4, v0.t\\n\\t\"\n                    \"vsetvli zero, %[vl128], e8, m8\\n\\t\"\n                    \"vle8.v v0, (%[q8])\\n\\t\"\n                    \"lb %[tmp], 0(%[scale])\\n\\t\"\n                    \"lb %[t1], 1(%[scale])\\n\\t\"\n                    \"lb %[t2], 2(%[scale])\\n\\t\"\n                    \"lb %[t3], 3(%[scale])\\n\\t\"\n                    \"vsetvli zero, %[vl64], e8, m4\\n\\t\"\n                    \"vwmul.vv v16, v0, v8\\n\\t\"\n                    \"vwmul.vv v24, v4, v12\\n\\t\"\n                    \"vsetivli zero, 16, e16, m2\\n\\t\"\n                    \"vmv.v.x v0, zero\\n\\t\"\n                    \"vwredsum.vs v8, v16, v0\\n\\t\"\n                    \"lb %[t4], 4(%[scale])\\n\\t\"\n                    \"lb %[t5], 5(%[scale])\\n\\t\"\n                    \"vwredsum.vs v9, v18, v0\\n\\t\"\n                    \"vwredsum.vs v10, v20, v0\\n\\t\"\n                    \"vwredsum.vs v11, v22, v0\\n\\t\"\n                    \"vwredsum.vs v12, v24, v0\\n\\t\"\n                    \"lb %[t6], 6(%[scale])\\n\\t\"\n                    \"lb %[t7], 7(%[scale])\\n\\t\"\n                    \"vwredsum.vs v13, v26, v0\\n\\t\"\n                    \"vwredsum.vs v14, v28, v0\\n\\t\"\n                    \"vwredsum.vs v15, v30, v0\\n\\t\"\n                    \"vsetivli zero, 4, e32, m1\\n\\t\"\n                    \"vmul.vx v0, v8, %[tmp]\\n\\t\"\n                    \"vmul.vx v1, v9, %[t1]\\n\\t\"\n                    \"vmacc.vx v0, %[t2], v10\\n\\t\"\n                    \"vmacc.vx v1, %[t3], v11\\n\\t\"\n                    \"vmacc.vx v0, %[t4], v12\\n\\t\"\n                    
\"vmacc.vx v1, %[t5], v13\\n\\t\"\n                    \"vmacc.vx v0, %[t6], v14\\n\\t\"\n                    \"vmacc.vx v1, %[t7], v15\\n\\t\"\n                    \"vmv.x.s %[tmp], v0\\n\\t\"\n                    \"vmv.x.s %[t1], v1\\n\\t\"\n                    \"add %[isum], %[isum], %[tmp]\\n\\t\"\n                    \"add %[isum], %[isum], %[t1]\"\n                    : [tmp] \"=&r\" (tmp), [t1] \"=&r\" (t1), [t2] \"=&r\" (t2), [t3] \"=&r\" (t3)\n                    , [t4] \"=&r\" (t4), [t5] \"=&r\" (t5), [t6] \"=&r\" (t6), [t7] \"=&r\" (t7)\n                    , [m] \"+&r\" (m), [isum] \"+&r\" (isum)\n                    : [vl128] \"r\" (128), [vl64] \"r\" (64), [vl32] \"r\" (32)\n                    , [q3] \"r\" (q3), [qh] \"r\" (qh), [scale] \"r\" (scale), [q8] \"r\" (q8)\n                    : \"memory\"\n                    , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                    , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                    , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                    , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n                );\n                q3 += 32;    q8 += 128;   scale += 8;\n            }\n\n            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n            sumf += d * isum;\n        }\n        break;\n    default:\n        assert(false && \"Unsupported vector length\");\n        break;\n    }\n\n    *s = sumf;\n\n#else\n\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n\n    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n\n}\n\nvoid ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    
UNUSED(bs);\n\n    const block_q4_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined __riscv_xtheadvector\n\n    const uint8_t * scales = (const uint8_t*)&utmp[0];\n    const uint8_t * mins   = (const uint8_t*)&utmp[2];\n\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        int tmp, tmp2, sumi;\n        __asm__ __volatile__(\n            \"li %[t1], 12\\n\\t\"\n            \"th.vsetvli zero, %[t1], e8, m1\\n\\t\"\n            \"th.vlb.v v1, (%[s6b])\\n\\t\" // {aux[0], aux[1], aux[2]}\n            \"li %[t1], 4\\n\\t\"\n            \"th.vsetvli zero, %[t1], e32, m1\\n\\t\"\n            \"th.vslidedown.vi v2, v1, 2\\n\\t\"\n            \"th.vmv.v.v v3, v2\\n\\t\"\n            \"th.vslideup.vi v2, v3, 1\\n\\t\" // {aux[2], aux[2]}\n            \"li %[t1], 2\\n\\t\"\n            \"th.vsetvli zero, %[t1], e32, m1\\n\\t\"\n            \"th.vmv.v.i v4, 4\\n\\t\"\n            \"th.vand.vx v8, v1, %[kmask1]\\n\\t\"\n            \"th.vslide1up.vx v5, v4, zero\\n\\t\" // {0, 4}\n            \"th.vsrl.vi v6, v1, 6\\n\\t\"\n            \"th.vsrl.vv v7, v2, v5\\n\\t\"\n            \"th.vand.vx v0, v6, %[kmask3]\\n\\t\"\n            \"th.vand.vx v2, v7, %[kmask2]\\n\\t\"\n            \"th.vsll.vi v6, v0, 4\\n\\t\"\n            \"li %[t2], 8\\n\\t\"\n            \"addi %[t1], %[utmp], 4\\n\\t\"\n            \"th.vor.vv v1, v6, v2\\n\\t\"\n            \"th.vssw.v v8, (%[utmp]), %[t2]\\n\\t\"\n            \"th.vssw.v v1, (%[t1]), %[t2]\\n\\t\"\n            \"th.vsetvli zero, zero, e32, m2\\n\\t\" // vl == 8\n            \"th.vlw.v v2, (%[bsums])\\n\\t\"\n            \"th.vsetvli zero, %[t2], e16, 
m1\\n\\t\"\n            \"th.vnsrl.vi v0, v2, 0\\n\\t\"\n            \"th.vnsrl.vi v1, v2, 16\\n\\t\"\n            \"th.vadd.vv v2, v0, v1\\n\\t\"\n            \"th.vlbu.v v4, (%[mins])\\n\\t\"\n            \"th.vwmul.vv v6, v4, v2\\n\\t\"\n            \"th.vmv.v.x v0, zero\\n\\t\"\n            \"th.vsetvli zero, %[t2], e32, m2\\n\\t\"\n            \"th.vredsum.vs v0, v6, v0\\n\\t\"\n            \"th.vmv.x.s %[sumi], v0\"\n            : [t1] \"=&r\" (tmp), [t2] \"=&r\" (tmp2), [sumi] \"=&r\" (sumi)\n            : [bsums] \"r\" (y[i].bsums), [mins] \"r\" (mins), [utmp] \"r\" (utmp)\n            , [s6b] \"r\" (x[i].scales), [kmask1] \"r\" (kmask1)\n            , [kmask2] \"r\" (kmask2), [kmask3] \"r\" (kmask3)\n            : \"memory\"\n            , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n            , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n            , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n            , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n        );\n        sumf -= dmin * sumi;\n\n        const uint8_t * restrict q4 = x[i].qs;\n        const int8_t  * restrict q8 = y[i].qs;\n\n        sumi = 0;\n        const uint8_t * scale = scales;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            int vl128 = 128, vl64 = 64, vl32 = 32;\n            __asm__ __volatile__(\n                \"th.vsetvli zero, %[vl128], e8, m8\\n\\t\"\n                \"th.vlb.v v8, (%[q8])\\n\\t\"\n                \"th.vsetvli zero, %[vl64], e8, m4\\n\\t\"\n                \"th.vlb.v v0, (%[q4])\\n\\t\"\n                \"th.vsrl.vi v4, v0, 4\\n\\t\"\n                \"th.vand.vi v0, v0, 0xF\\n\\t\"\n                \"th.vsetvli zero, %[vl32], e8, m2\\n\\t\"\n                \"th.vwmul.vv v28, v6, v14\\n\\t\"\n                \"th.vwmul.vv v20, v4, v10\\n\\t\"\n                \"th.vwmul.vv v24, v2, v12\\n\\t\"\n                \"th.vwmul.vv v16, v0, v8\\n\\t\"\n  
              \"li %[tmp], 4\\n\\t\"\n                \"th.vsetvli zero, %[tmp], e32, m1\\n\\t\"\n                \"th.vlbu.v v1, (%[scale])\\n\\t\"\n                \"th.vmv.v.x v0, zero\\n\\t\"\n                \"th.vsetvli zero, %[vl32], e16, m4\\n\\t\"\n                \"th.vwredsum.vs v6, v24, v0\\n\\t\"\n                \"th.vwredsum.vs v7, v28, v0\\n\\t\"\n                \"th.vwredsum.vs v4, v16, v0\\n\\t\"\n                \"th.vwredsum.vs v5, v20, v0\\n\\t\"\n                \"th.vsetvli zero, %[tmp], e32, m1\\n\\t\"\n                \"th.vslideup.vi v6, v7, 1\\n\\t\"\n                \"th.vslideup.vi v4, v5, 1\\n\\t\"\n                \"th.vslideup.vi v4, v6, 2\\n\\t\"\n                \"th.vmul.vv v8, v4, v1\\n\\t\"\n                \"th.vredsum.vs v0, v8, v0\\n\\t\"\n                \"th.vmv.x.s %[tmp], v0\\n\\t\"\n                \"add %[sumi], %[sumi], %[tmp]\"\n                : [tmp] \"=&r\" (tmp), [sumi] \"+&r\" (sumi)\n                : [vl128] \"r\" (vl128), [vl64] \"r\" (vl64), [vl32] \"r\" (vl32)\n                , [q4] \"r\" (q4), [q8] \"r\" (q8), [scale] \"r\" (scale)\n                : \"memory\"\n                , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n            );\n\n            q4 += 64;    q8 += 128;    scale += 4;\n        }\n\n        sumf += d * sumi;\n\n    }\n\n    *s = sumf;\n\n#elif defined __riscv_v\n\n    const uint8_t * scales = (const uint8_t*)&utmp[0];\n    const uint8_t * mins   = (const uint8_t*)&utmp[2];\n\n    float sumf = 0;\n    const int vector_length = __riscv_vlenb() * 8;\n\n    switch (vector_length) {\n    case 256:\n        for (int i = 0; i < nb; ++i) {\n\n            size_t vl = 8;\n\n            const float d = y[i].d * 
GGML_CPU_FP16_TO_FP32(x[i].d);\n            const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n            vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);\n            vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);\n            vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);\n\n            memcpy(utmp, x[i].scales, 12);\n            utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n            const uint32_t uaux = utmp[1] & kmask1;\n            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n            utmp[2] = uaux;\n            utmp[0] &= kmask1;\n\n            vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);\n            vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));\n            vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);\n\n            vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);\n            sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);\n\n            const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n            const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n            vl = 32;\n\n            int32_t sum_1 = 0;\n            int32_t sum_2 = 0;\n\n            vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);\n\n            for (int j = 0; j < QK_K/64; ++j) {\n                // load Q4\n                vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);\n\n                // load Q8 and multiply it with lower Q4 nibble\n                vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);\n                vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));\n                vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);\n                vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);\n\n                sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];\n\n              
  // load Q8 and multiply it with upper Q4 nibble\n                vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);\n                vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));\n                vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);\n                vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);\n\n                sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];\n\n                q4 += 32;    q8 += 64;\n\n            }\n\n            sumf += d*(sum_1 + sum_2);\n\n        }\n        break;\n    case 128:\n        for (int i = 0; i < nb; ++i) {\n            const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n            const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n            float ftmp, ft2;\n            const uint8_t * restrict q40;\n            const uint8_t * restrict q41;\n            const uint8_t * restrict q42;\n            const uint8_t * restrict q43;\n            const int8_t  * restrict q80;\n            const int8_t  * restrict q81;\n            const int8_t  * restrict q82;\n            const int8_t  * restrict q83;\n            int s0, s1, s2, s3;\n\n            __asm__ __volatile__(\n                \"li %[s1], 8\\n\\t\"\n                \"vsetivli zero, 4, e32, m1, ta, ma\\n\\t\"\n                \"vle32.v v1, (%[s6b])\\n\\t\"\n                \"vslide1down.vx v1, v1, zero\\n\\t\"\n                \"vmv.v.x v16, zero\\n\\t\"\n                \"vslidedown.vi v2, v1, 2\\n\\t\"\n                \"vmv1r.v v3, v2\\n\\t\"\n                \"vslideup.vi v2, v3, 1\\n\\t\" // {aux[2], aux[2]}\n                \"vsetivli zero, 2, e32, m1, ta, ma\\n\\t\"\n                \"vmv.v.i v4, 4\\n\\t\"\n                \"vand.vx v8, v1, %[kmask1]\\n\\t\"\n                \"vslide1up.vx v5, v4, zero\\n\\t\" // {0, 4}\n                \"vsrl.vi v6, v1, 6\\n\\t\"\n                \"vsrl.vv v7, v2, v5\\n\\t\"\n                \"vsse32.v v8, (%[utmp]), 
%[s1]\\n\\t\"\n                \"vand.vx v0, v6, %[kmask3]\\n\\t\"\n                \"vand.vx v2, v7, %[kmask2]\\n\\t\"\n                \"vsll.vi v6, v0, 4\\n\\t\"\n                \"addi %[s0], %[utmp], 4\\n\\t\"\n                \"vor.vv v1, v6, v2\\n\\t\"\n                \"vsse32.v v1, (%[s0]), %[s1]\\n\\t\"\n                \"vsetivli zero, 8, e16, m1, ta, ma\\n\\t\"\n                \"vle32.v v2, (%[bsums])\\n\\t\"\n                \"vnsrl.wi v0, v2, 0\\n\\t\"\n                \"vnsrl.wi v1, v2, 16\\n\\t\"\n                \"vadd.vv v2, v0, v1\\n\\t\"\n                \"vle8.v v3, (%[mins])\\n\\t\"\n                \"vzext.vf2 v4, v3\\n\\t\"\n                \"vwmul.vv v6, v4, v2\\n\\t\"\n                \"vsetivli zero, 4, e32, m1, ta, ma\\n\\t\"\n                \"vredsum.vs v0, v6, v16\\n\\t\"\n                \"vredsum.vs v0, v7, v0\\n\\t\"\n                \"vfcvt.f.x.v v0, v0\\n\\t\"\n                \"vfmv.f.s %[ftmp], v0\\n\\t\"\n                \"vsetivli zero, 16, e8, m1, ta, ma\\n\\t\"\n                \"vle8.v v0, (%[xs])\\n\\t\"\n                \"fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\\n\\t\"\n                \"addi %[q40], %[xs], 64\\n\\t\"\n                \"addi %[q41], %[xs], 16\\n\\t\"\n                \"addi %[q42], %[xs], 32\\n\\t\"\n                \"addi %[q43], %[xs], 48\\n\\t\"\n                \"addi %[q80], %[ys], 64\\n\\t\"\n                \"vle8.v v1, (%[q41])\\n\\t\"\n                \"vle8.v v2, (%[q42])\\n\\t\"\n                \"addi %[q81], %[ys], 16\\n\\t\"\n                \"addi %[q41], %[q41], 64\\n\\t\"\n                \"addi %[q82], %[ys], 32\\n\\t\"\n                \"vle8.v v3, (%[q43])\\n\\t\"\n                \"vle8.v v8, (%[ys])\\n\\t\"\n                \"addi %[q42], %[q42], 64\\n\\t\"\n                \"addi %[q83], %[ys], 48\\n\\t\"\n                \"addi %[q43], %[q43], 64\\n\\t\"\n                \"vsrl.vi v4, v0, 4\\n\\t\"\n                \"vle8.v v9, (%[q81])\\n\\t\"\n                \"vle8.v v10, 
(%[q82])\\n\\t\"\n                \"vand.vi v0, v0, 0xF\\n\\t\"\n                \"addi %[q81], %[q81], 64\\n\\t\"\n                \"vsrl.vi v5, v1, 4\\n\\t\"\n                \"addi %[q82], %[q82], 64\\n\\t\"\n                \"vle8.v v11, (%[q83])\\n\\t\"\n                \"vle8.v v12, (%[q80])\\n\\t\"\n                \"vand.vi v1, v1, 0xF\\n\\t\"\n                \"addi %[q83], %[q83], 64\\n\\t\"\n                \"vsrl.vi v6, v2, 4\\n\\t\"\n                \"addi %[q80], %[q80], 64\\n\\t\"\n                \"vle8.v v13, (%[q81])\\n\\t\"\n                \"vle8.v v14, (%[q82])\\n\\t\"\n                \"vand.vi v2, v2, 0xF\\n\\t\"\n                \"addi %[q81], %[q81], 64\\n\\t\"\n                \"vsrl.vi v7, v3, 4\\n\\t\"\n                \"addi %[q82], %[q82], 64\\n\\t\"\n                \"vwmul.vv v16, v0, v8\\n\\t\"\n                \"vle8.v v15, (%[q83])\\n\\t\"\n                \"vle8.v v0, (%[q40])\\n\\t\"\n                \"vand.vi v3, v3, 0xF\\n\\t\"\n                \"addi %[q83], %[q83], 64\\n\\t\"\n                \"vwmul.vv v24, v2, v12\\n\\t\"\n                \"vwmul.vv v20, v4, v10\\n\\t\"\n                \"vwmul.vv v28, v6, v14\\n\\t\"\n                \"vwmacc.vv v16, v1, v9\\n\\t\"\n                \"vle8.v v1, (%[q41])\\n\\t\"\n                \"vle8.v v2, (%[q42])\\n\\t\"\n                \"vwmacc.vv v24, v3, v13\\n\\t\"\n                \"vwmacc.vv v20, v5, v11\\n\\t\"\n                \"vwmacc.vv v28, v7, v15\\n\\t\"\n                \"addi %[q40], %[q80], 64\\n\\t\"\n                \"addi %[q41], %[q81], 64\\n\\t\"\n                \"vle8.v v3, (%[q43])\\n\\t\"\n                \"vle8.v v8, (%[q80])\\n\\t\"\n                \"addi %[q42], %[q82], 64\\n\\t\"\n                \"addi %[q43], %[q83], 64\\n\\t\"\n                \"vsrl.vi v4, v0, 4\\n\\t\"\n                \"vle8.v v9, (%[q81])\\n\\t\"\n                \"vle8.v v10, (%[q82])\\n\\t\"\n                \"vand.vi v0, v0, 0xF\\n\\t\"\n                \"vsrl.vi v5, v1, 
4\\n\\t\"\n                \"vsrl.vi v7, v3, 4\\n\\t\"\n                \"vand.vi v3, v3, 0xF\\n\\t\"\n                \"vle8.v v11, (%[q83])\\n\\t\"\n                \"vle8.v v12, (%[q40])\\n\\t\"\n                \"vand.vi v1, v1, 0xF\\n\\t\"\n                \"vsrl.vi v6, v2, 4\\n\\t\"\n                \"vand.vi v2, v2, 0xF\\n\\t\"\n                \"vwmul.vv v18, v0, v8\\n\\t\"\n                \"vle8.v v13, (%[q41])\\n\\t\"\n                \"vle8.v v14, (%[q42])\\n\\t\"\n                \"vwmul.vv v26, v2, v12\\n\\t\"\n                \"vwmul.vv v22, v4, v10\\n\\t\"\n                \"vwmul.vv v30, v6, v14\\n\\t\"\n                \"vwmacc.vv v18, v1, v9\\n\\t\"\n                \"vle8.v v15, (%[q43])\\n\\t\"\n                \"vwmacc.vv v26, v3, v13\\n\\t\"\n                \"vwmacc.vv v22, v5, v11\\n\\t\"\n                \"vwmacc.vv v30, v7, v15\\n\\t\"\n                \"vmv.v.x v0, zero\\n\\t\"\n                \"vsetivli zero, 16, e16, m2, ta, ma\\n\\t\"\n                \"vwredsum.vs v4, v16, v0\\n\\t\"\n                \"lbu %[s0], 0(%[scale])\\n\\t\"\n                \"vwredsum.vs v5, v20, v0\\n\\t\"\n                \"lbu %[s1], 1(%[scale])\\n\\t\"\n                \"vwredsum.vs v6, v24, v0\\n\\t\"\n                \"lbu %[s2], 2(%[scale])\\n\\t\"\n                \"vwredsum.vs v7, v28, v0\\n\\t\"\n                \"lbu %[s3], 3(%[scale])\\n\\t\"\n                \"vwredsum.vs v8, v18, v0\\n\\t\"\n                \"lbu %[q40], 4(%[scale])\\n\\t\"\n                \"vwredsum.vs v9, v22, v0\\n\\t\"\n                \"lbu %[q41], 5(%[scale])\\n\\t\"\n                \"vwredsum.vs v10, v26, v0\\n\\t\"\n                \"lbu %[q42], 6(%[scale])\\n\\t\"\n                \"vwredsum.vs v11, v30, v0\\n\\t\"\n                \"lbu %[q43], 7(%[scale])\\n\\t\"\n                \"vsetivli zero, 4, e32, m1, ta, ma\\n\\t\"\n                \"vmul.vx v0, v4, %[s0]\\n\\t\"\n                \"vmul.vx v1, v8, %[q40]\\n\\t\"\n                \"vmacc.vx v0, %[s1], 
v5\\n\\t\"\n                \"vmacc.vx v1, %[q41], v9\\n\\t\"\n                \"vmacc.vx v0, %[s2], v6\\n\\t\"\n                \"vmacc.vx v1, %[q42], v10\\n\\t\"\n                \"vmacc.vx v0, %[s3], v7\\n\\t\"\n                \"vmacc.vx v1, %[q43], v11\\n\\t\"\n                \"vfcvt.f.x.v v0, v0\\n\\t\"\n                \"vfcvt.f.x.v v1, v1\\n\\t\"\n                \"vfmv.f.s %[ft2], v0\\n\\t\"\n                \"vfmv.f.s %[ftmp], v1\\n\\t\"\n                \"fadd.s %[ft2], %[ft2], %[ftmp]\\n\\t\"\n                \"fmadd.s %[sumf], %[d], %[ft2], %[sumf]\"\n                : [ftmp] \"=&f\" (ftmp), [sumf] \"+&f\" (sumf), [ft2] \"=&f\" (ft2)\n                , [s0] \"=&r\" (s0), [s1] \"=&r\" (s1), [s2] \"=&r\" (s2), [s3] \"=&r\" (s3)\n                , [q40] \"=&r\" (q40), [q41] \"=&r\" (q41), [q42] \"=&r\" (q42), [q43] \"=&r\" (q43)\n                , [q80] \"=&r\" (q80), [q81] \"=&r\" (q81), [q82] \"=&r\" (q82), [q83] \"=&r\" (q83)\n                : [d] \"f\" (d), [ys] \"r\" (y[i].qs), [xs] \"r\" (x[i].qs), [scale] \"r\" (scales)\n                , [bsums] \"r\" (y[i].bsums), [mins] \"r\" (mins), [utmp] \"r\" (utmp)\n                , [s6b] \"r\" (&x[i]), [kmask1] \"r\" (kmask1), [dmin] \"f\" (dmin)\n                , [kmask2] \"r\" (kmask2), [kmask3] \"r\" (kmask3)\n                : \"memory\"\n                , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n            );\n        }\n        break;\n    default:\n        assert(false && \"Unsupported vector length\");\n        break;\n    }\n\n    *s = sumf;\n\n#else\n\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(nb);\n    UNUSED(utmp);\n\n    
ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined __riscv_v\n\n    const uint8_t * scales = (const uint8_t*)&utmp[0];\n    const uint8_t * mins   = (const uint8_t*)&utmp[2];\n\n    float sumf = 0;\n    float sums = 0.0;\n\n    size_t vl;\n\n    for (int i = 0; i < nb; ++i) {\n\n        vl = 8;\n\n        const uint8_t * GGML_RESTRICT q5 = x[i].qs;\n        const uint8_t * GGML_RESTRICT hm = x[i].qh;\n        const  int8_t * GGML_RESTRICT q8 = y[i].qs;\n\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;\n\n        vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);\n        vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);\n        vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);\n\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);\n        vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));\n        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);\n\n        
vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);\n        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);\n\n        vl = 32;\n        int32_t aux32 = 0;\n        int is = 0;\n\n        uint8_t m = 1;\n        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);\n        vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);\n\n        for (int j = 0; j < QK_K/64; ++j) {\n            // load Q5 and Q8\n            vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);\n            vint8m2_t  q8_y1 = __riscv_vle8_v_i8m2(q8, vl);\n            vint8m2_t  q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);\n\n            // compute mask for addition\n            vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));\n            vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);\n            vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);\n            vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);\n            m <<= 1;\n\n            vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));\n            vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);\n            vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);\n            vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);\n            m <<= 1;\n\n            vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);\n            vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);\n\n            vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);\n            vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);\n\n            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);\n            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);\n\n            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);\n            q5 += 32;    q8 += 64;\n\n        }\n\n        sums += aux32 * d;\n\n    }\n\n    *s = sumf+sums;\n\n#else\n\n   
 UNUSED(x);\n    UNUSED(y);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(nb);\n    UNUSED(utmp);\n\n    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q6_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __riscv_xtheadvector\n\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n        const uint8_t * restrict q6 = x[i].ql;\n        const uint8_t * restrict qh = x[i].qh;\n        const  int8_t * restrict q8 = y[i].qs;\n\n        const int8_t * restrict scale = x[i].scales;\n\n        int sum_t = 0;\n        int t0;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            __asm__ __volatile__(\n                \"th.vsetvli zero, %[vl32], e8, m2\\n\\t\" // vl == 32\n                \"th.vlb.v v4, (%[qh])\\n\\t\"\n                \"th.vsll.vi v0, v4, 4\\n\\t\"\n                \"th.vsll.vi v2, v4, 2\\n\\t\"\n                \"th.vsrl.vi v6, v4, 2\\n\\t\"\n                \"th.vsetvli zero, %[vl64], e8, m4\\n\\t\" // vl == 64\n                \"th.vlb.v v8, (%[q6])\\n\\t\"\n                \"th.vsrl.vi v12, v8, 4\\n\\t\"\n                \"th.vand.vi v8, v8, 0xF\\n\\t\"\n                \"th.vsetvli zero, %[vl128], e8, m8\\n\\t\" // vl == 128\n                \"th.vand.vx v0, v0, %[mask]\\n\\t\"\n                \"th.vor.vv v8, v8, v0\\n\\t\"\n                \"th.vlb.v v0, (%[q8])\\n\\t\"\n                \"th.vsub.vx v8, v8, %[vl32]\\n\\t\"\n                \"th.vsetvli zero, %[vl64], e8, m4\\n\\t\" // vl == 64\n                \"th.vwmul.vv v16, v0, 
v8\\n\\t\"\n                \"th.vwmul.vv v24, v4, v12\\n\\t\"\n                \"li %[t0], 16\\n\\t\"\n                \"th.vsetvli zero, %[t0], e16, m2\\n\\t\" // vl == 16\n                \"th.vmv.v.x v0, zero\\n\\t\"\n                \"th.vwredsum.vs v10, v16, v0\\n\\t\"\n                \"th.vwredsum.vs v9, v18, v0\\n\\t\"\n                \"th.vwredsum.vs v8, v20, v0\\n\\t\"\n                \"th.vwredsum.vs v7, v22, v0\\n\\t\"\n                \"th.vwredsum.vs v11, v24, v0\\n\\t\"\n                \"th.vwredsum.vs v12, v26, v0\\n\\t\"\n                \"th.vwredsum.vs v13, v28, v0\\n\\t\"\n                \"th.vwredsum.vs v14, v30, v0\\n\\t\"\n                \"li %[t0], 4\\n\\t\"\n                \"th.vsetvli zero, %[t0], e32, m1\\n\\t\" // vl == 4\n                \"th.vslideup.vi v10, v9, 1\\n\\t\"\n                \"th.vslideup.vi v8, v7, 1\\n\\t\"\n                \"th.vslideup.vi v11, v12, 1\\n\\t\"\n                \"th.vslideup.vi v13, v14, 1\\n\\t\"\n                \"th.vslideup.vi v10, v8, 2\\n\\t\"\n                \"th.vslideup.vi v11, v13, 2\\n\\t\"\n                \"li %[t0], 8\\n\\t\"\n                \"th.vsetvli zero, %[t0], e32, m2\\n\\t\" // vl == 8\n                \"th.vlb.v v4, (%[scale])\\n\\t\"\n                \"th.vmul.vv v2, v4, v10\\n\\t\"\n                \"th.vredsum.vs v0, v2, v0\\n\\t\"\n                \"th.vmv.x.s %[t0], v0\\n\\t\"\n                \"add %[sumi], %[sumi], %[t0]\"\n                : [sumi] \"+&r\" (sum_t), [t0] \"=&r\" (t0)\n                : [qh] \"r\" (qh), [q6] \"r\" (q6), [q8] \"r\" (q8), [scale] \"r\" (scale)\n                , [vl32] \"r\" (32), [vl64] \"r\" (64), [vl128] \"r\" (128)\n                , [mask] \"r\" (0x30)\n                : \"memory\"\n                , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                , \"v16\", \"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", 
\"v23\"\n                , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n            );\n            q6 += 64;   qh += 32;   q8 += 128;   scale += 8;\n        }\n\n        sumf += d * sum_t;\n\n    }\n\n    *s = sumf;\n\n#elif defined __riscv_v\n\n    float sumf = 0;\n    const int vector_length = __riscv_vlenb() * 8;\n\n    switch (vector_length) {\n    case 256:\n        for (int i = 0; i < nb; ++i) {\n\n            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n            const uint8_t * GGML_RESTRICT q6 = x[i].ql;\n            const uint8_t * GGML_RESTRICT qh = x[i].qh;\n            const  int8_t * GGML_RESTRICT q8 = y[i].qs;\n\n            const int8_t * GGML_RESTRICT scale = x[i].scales;\n\n            size_t vl;\n\n            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);\n\n            int sum_t = 0;\n            int is = 0;\n\n            for (int j = 0; j < QK_K/128; ++j) {\n\n                vl = 32;\n\n                // load qh\n                vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);\n\n                // load Q6\n                vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);\n                vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);\n\n                vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);\n                vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);\n                vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);\n                vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);\n\n                vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);\n                vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);\n                vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);\n                vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);\n\n                vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, 
__riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);\n                vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);\n                vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);\n                vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);\n\n                vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);\n                vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);\n                vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);\n                vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);\n\n                // load Q8 and take product\n                vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);\n                vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);\n                vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);\n                vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);\n\n                vl = 16;\n\n                vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);\n                vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);\n                vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);\n                vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);\n                vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);\n                vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);\n                vint32m2_t vaux_6 = 
__riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);\n                vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);\n\n                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);\n                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);\n                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);\n                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);\n\n                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);\n\n                q6 += 64;   qh += 32;   q8 += 128;   is=8;\n\n            }\n\n            sumf += d * sum_t;\n\n        }\n        break;\n    case 128:\n        for (int i = 0; i < nb; ++i) {\n\n            __builtin_prefetch(&x[i + 1].d, 0, 1);\n\n            const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n            const uint8_t * restrict q6 = x[i].ql;\n            const uint8_t * restrict qh = x[i].qh;\n            const  int8_t * restrict q8 = y[i].qs;\n\n            const int8_t * restrict scale = x[i].scales;\n\n            int q6h;\n            float ftmp;\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                __asm__ __volatile__(\n                    \"addi %[q6h], %[q6], 32\\n\\t\"\n                    \"ld t0, 0(%[scale])\\n\\t\"\n                    \"addi %[scale], %[scale], 8\\n\\t\"\n                    \"slli t6, t0, 1 * 8\\n\\t\"\n                    \"lb zero, 0(%[q6])\\n\\t\"\n                    \"slli t5, t0, 2 * 8\\n\\t\"\n                    \"slli t4, t0, 3 * 8\\n\\t\"\n                    \"lb zero, 0(%[q6h])\\n\\t\"\n                    \"slli t3, t0, 4 * 8\\n\\t\"\n                    \"slli t2, t0, 5 * 8\\n\\t\"\n                    \"lb zero, 0(%[qh])\\n\\t\"\n          
          \"lb zero, 31(%[q6h])\\n\\t\"\n                    \"slli t1, t0, 6 * 8\\n\\t\"\n                    \"srai a7, t0, 56\\n\\t\"\n                    \"vsetvli zero, %[vl32], e8, m2\\n\\t\"\n                    \"vle8.v v8, (%[q6])\\n\\t\"\n                    \"srai t6, t6, 56\\n\\t\"\n                    \"srai t5, t5, 56\\n\\t\"\n                    \"srai t4, t4, 56\\n\\t\"\n                    \"srai t3, t3, 56\\n\\t\"\n                    \"vle8.v v10, (%[q6h])\\n\\t\"\n                    \"addi %[q6], %[q6], 64\\n\\t\"\n                    \"slli t0, t0, 7 * 8\\n\\t\"\n                    \"srai t2, t2, 56\\n\\t\"\n                    \"srai t1, t1, 56\\n\\t\"\n                    \"srai t0, t0, 56\\n\\t\"\n                    \"vle8.v v4, (%[qh])\\n\\t\"\n                    \"vsrl.vi v12, v8, 4\\n\\t\"\n                    \"vsrl.vi v14, v10, 4\\n\\t\"\n                    \"lb zero, 0(%[q8])\\n\\t\"\n                    \"vand.vi v8, v8, 0xF\\n\\t\"\n                    \"vand.vi v10, v10, 0xF\\n\\t\"\n                    \"lb zero, 32(%[q8])\\n\\t\"\n                    \"vsll.vi v0, v4, 4\\n\\t\"\n                    \"vsll.vi v2, v4, 2\\n\\t\"\n                    \"lb zero, 64(%[q8])\\n\\t\"\n                    \"vsrl.vi v6, v4, 2\\n\\t\"\n                    \"vand.vx v0, v0, %[mask]\\n\\t\"\n                    \"lb zero, 96(%[q8])\\n\\t\"\n                    \"vand.vx v2, v2, %[mask]\\n\\t\"\n                    \"vand.vx v4, v4, %[mask]\\n\\t\"\n                    \"vand.vx v6, v6, %[mask]\\n\\t\"\n                    \"vor.vv v8, v8, v0\\n\\t\"\n                    \"lb zero, 127(%[q8])\\n\\t\"\n                    \"vor.vv v10, v10, v2\\n\\t\"\n                    \"vor.vv v12, v12, v4\\n\\t\"\n                    \"vor.vv v14, v14, v6\\n\\t\"\n                    \"vsetvli zero, %[vl128], e8, m8\\n\\t\"\n                    \"vle8.v v0, (%[q8])\\n\\t\"\n                    \"vsub.vx v8, v8, %[vl32]\\n\\t\"\n                    
\"vsetvli zero, %[vl64], e8, m4\\n\\t\"\n                    \"vwmul.vv v16, v0, v8\\n\\t\"\n                    \"vwmul.vv v24, v4, v12\\n\\t\"\n                    \"vsetivli zero, 16, e16, m2\\n\\t\"\n                    \"vmv.v.x v0, zero\\n\\t\"\n                    \"vwredsum.vs v10, v16, v0\\n\\t\"\n                    \"vwredsum.vs v9, v18, v0\\n\\t\"\n                    \"vwredsum.vs v8, v20, v0\\n\\t\"\n                    \"vwredsum.vs v7, v22, v0\\n\\t\"\n                    \"vwredsum.vs v11, v24, v0\\n\\t\"\n                    \"vwredsum.vs v12, v26, v0\\n\\t\"\n                    \"vwredsum.vs v13, v28, v0\\n\\t\"\n                    \"vwredsum.vs v14, v30, v0\\n\\t\"\n                    \"vsetivli zero, 4, e32, m1\\n\\t\"\n                    \"vmul.vx v0, v10, t0\\n\\t\"\n                    \"vmul.vx v1, v9, t1\\n\\t\"\n                    \"vmacc.vx v0, t2, v8\\n\\t\"\n                    \"vmacc.vx v1, t3, v7\\n\\t\"\n                    \"vmacc.vx v0, t4, v11\\n\\t\"\n                    \"vmacc.vx v1, t5, v12\\n\\t\"\n                    \"vmacc.vx v0, t6, v13\\n\\t\"\n                    \"vmacc.vx v1, a7, v14\\n\\t\"\n                    \"vadd.vv v0, v0, v1\\n\\t\"\n                    \"vfcvt.f.x.v v0, v0\\n\\t\"\n                    \"vfmv.f.s %[ftmp], v0\\n\\t\"\n                    \"fmadd.s %[sumf], %[d], %[ftmp], %[sumf]\"\n                    : [q6] \"+&r\" (q6), [q6h] \"=&r\" (q6h)\n                    , [scale] \"+&r\" (scale)\n                    , [sumf] \"+&f\" (sumf), [ftmp] \"=&f\" (ftmp)\n                    : [qh] \"r\" (qh), [q8] \"r\" (q8)\n                    , [vl32] \"r\" (32), [vl64] \"r\" (64), [vl128] \"r\" (128)\n                    , [mask] \"r\" (0x30), [d] \"f\" (d)\n                    : \"memory\"\n                    , \"v0\", \"v1\", \"v2\", \"v3\", \"v4\", \"v5\", \"v6\", \"v7\"\n                    , \"v8\", \"v9\", \"v10\", \"v11\", \"v12\", \"v13\", \"v14\", \"v15\"\n                    , \"v16\", 
\"v17\", \"v18\", \"v19\", \"v20\", \"v21\", \"v22\", \"v23\"\n                    , \"v24\", \"v25\", \"v26\", \"v27\", \"v28\", \"v29\", \"v30\", \"v31\"\n                    , \"t0\", \"t1\", \"t2\", \"t3\", \"t4\", \"t5\", \"t6\", \"a7\"\n                    , \"a6\", \"a5\", \"a4\", \"a3\"\n                );\n                qh += 32;   q8 += 128;\n            }\n        }\n        break;\n    default:\n        assert(false && \"Unsupported vector length\");\n        break;\n    }\n\n    *s = sumf;\n\n#else\n\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n\n    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        // Load qh once for the entire superblock.\n        vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8);\n\n        // Calculate ls.\n        vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8);\n        temp = __riscv_vand_vx_u16mf2(temp, 7, 8);\n        vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8));\n        ls = __riscv_vadd_vx_i32m1(ls, 1, 8);\n\n        // Calculate delta.\n        vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8);\n        vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8);\n        vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8);\n        vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8);\n\n        // Load qs.\n        vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32);\n\n        // Prepare the 
indices.\n        const uint64_t shift = 0x0009000600030000;\n        vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8));\n        vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2(\n            __riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32));\n        vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh));\n        vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32);\n        qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32);\n        qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32);\n        qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32);\n        qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32);\n        vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32);\n\n        // Final lsums.\n        int32_t lsums_s[8];\n        vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1);\n\n        // Sub-blocks 1-4\n        {\n            vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0);\n            vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16));\n            vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128);\n            vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128);\n            lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32));\n            lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32));\n            lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32));\n            lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32));\n        }\n        __asm__ __volatile__(\"\" ::: 
\"memory\");\n        // Sub-blocks 5-8\n        {\n            vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1);\n            vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16));\n            vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128);\n            vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128);\n            lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32));\n            lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32));\n            lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32));\n            lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32));\n        }\n        __asm__ __volatile__(\"\" ::: \"memory\");\n        vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8);\n\n        // Calculate the bsums.\n        vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16);\n        const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0));\n        const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8));\n        const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8));\n        const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8);\n\n        // Accumulation.\n        vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8);\n        vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8);\n\n        // Update sumf.\n        int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));\n     
   int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));\n        sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 256:\n            ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_m * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    iq1m_scale_t scale;\n    float sumf = 0.0f;\n    for (int i = 0; i < nb; ++i) {\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint8_t  * qh = x[i].qh;\n        const uint16_t * sc = (const uint16_t *)x[i].scales;\n\n        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n\n        // Accumulators.\n        vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16);\n        vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16);\n\n        // We process 4 sub-blocks together.\n        for (int ib = 0; ib < QK_K/128; ib++) {\n            // Load qh for 4 sub-blocks.\n            const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8);\n            const vuint16mf2_t qh_16_lo = 
__riscv_vzext_vf2_u16mf2(qh_8, 8);\n            const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8);\n            const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1(\n                __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16);\n            qh += 8;\n\n            // Prepare grid indices.\n            const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16);\n            const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8));\n            vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16);\n            index = __riscv_vsll_vx_u16m1(index, 3, 16);\n            qs += 16;\n\n            // Load the grid.\n            const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4(\n                __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16)));\n\n            // Prepare the deltas.\n            const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16(\n                __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16);\n            const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16);\n            const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16);\n            const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4(\n                __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16));\n\n            // Load q8 for sub-blocks.\n            const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);\n            q8 += 128;\n\n            // Calculate the lsums.\n            const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128);\n            const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128);\n\n            // Prepare the scales.\n            const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1;\n            const int16_t 
ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1;\n            const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1;\n            const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1;\n            const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1;\n            const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1;\n            const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1;\n            const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1;\n            sc += 2;\n\n            // Accumulate in acc0 and acc1 for each sub-block.\n            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16);\n            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16);\n            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16);\n            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16);\n            //\n            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16);\n            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16);\n            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16);\n            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16);\n            //\n            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16);\n            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16);\n            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16);\n            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16);\n            //\n            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16);\n            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16);\n            acc2 = 
__riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16);\n            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16);\n        }\n\n        // Reduce and accumulate in `sumf`.\n        vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1);\n        int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16));\n        int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16));\n        sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 256:\n            ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic const uint8_t sign_gather_indices_arr[64] = {\n    0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,\n    4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7\n};\n\nstatic const uint8_t sign_bit_masks_arr[64] = {\n    1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128,\n    1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128\n};\n\n\nstatic void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);\n\n    const block_iq2_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / 
QK_K;\n    const uint64_t * grid64 = (const uint64_t *)iq2s_grid;\n\n    // Pre-load Constants\n    vuint8m2_t v_ids = __riscv_vid_v_u8m2(32);\n    vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32);\n    vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32);\n    vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32);\n    vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32);\n    uint16_t shift_qh_arr[4] = {11, 9, 7, 5};\n    vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4);\n\n    float sumf = 0.0f;\n\n    for (int i = 0; i < nb; ++i) {\n        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint8_t * GGML_RESTRICT scales = x[i].scales;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const uint8_t * signs_ptr = qs + 32;\n        float sum_block = 0.0f;\n\n        for (int ib = 0; ib < 8; ++ib) {\n\n            // Load Low Bits [4 bytes]\n            vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4);\n            qs += 4;\n\n            // Load 1 byte. 
It contains bits for 4 mini-blocks.\n            uint8_t qh_val = *qh++;\n\n            // Combine Low + High bits of 10bit indices\n            vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4);\n            vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4);\n            vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4);\n            v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4);\n            vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4);\n            vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4);\n            vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4);\n\n            // Lookup Grid\n            vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4)));\n\n            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4);\n            signs_ptr += 4;\n            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);\n            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32);\n\n            // generating sign mask\n            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32);\n            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32);\n\n            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32);\n            q8 += 32;\n\n            // apply signs\n            vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32);\n            vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32);\n\n            // Reduction\n            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);\n\n            // Reduce 0-15 (First Half)\n            int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(\n                __riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16));\n\n            // Reduce 16-31 
(Second Half)\n            int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(\n                __riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16));\n\n            // Apply sub Scales\n            uint8_t sc = *scales++;\n\n            sum_block += s0 * (2 * (sc & 0xF) + 1);\n            sum_block += s1 * (2 * (sc >> 4)  + 1);\n        }\n        sumf += sum_block * combined_scale;\n    }\n    *s = 0.125f * sumf;\n}\n\nstatic void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);\n\n    const block_iq2_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n    const uint64_t * grid64 = (const uint64_t *)iq2s_grid;\n\n    // --- Pre-load Constants ---\n    uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1};\n    vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8);\n    uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5};\n    vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8);\n\n    // Constants for sign extraction\n    vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);\n    vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);\n\n    float sumf = 0.0f;\n\n    for (int i = 0; i < nb; ++i) {\n        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint8_t * GGML_RESTRICT scales = x[i].scales;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const uint8_t * signs_ptr = qs + 32;\n\n        float sum_block = 0.0f;\n\n        for (int ib = 0; ib < 4; ++ib) {\n            // Combine low + high bits\n            vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8);\n    
        qs += 8;\n            uint16_t qh_val;\n            memcpy(&qh_val, qh, 2);\n            qh += 2;\n            vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2);\n            vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2);\n            vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16);\n            vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8);\n            v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8);\n\n            // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000\n            v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8);\n            vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8);\n\n            // Multiply by 8 to get byte offset, instead of element offset\n            v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8);\n            vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8);\n\n            // Lookup Grid using Byte Offsets\n            vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8);\n\n            vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals);\n            vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8);\n\n            // Load signs and generate sign mask\n            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8);\n            signs_ptr += 8;\n\n            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);\n            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);\n\n            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);\n            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);\n\n            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);\n            q8 += 64;\n\n            vint8m2_t v_q8_signed = 
__riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);\n            vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64);\n\n            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);\n\n            int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(\n                __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16));\n            int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(\n                __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16));\n            int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(\n                __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16));\n            int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(\n                __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16));\n\n            uint8_t sc0 = scales[0];\n            uint8_t sc1 = scales[1];\n            scales += 2;\n\n            sum_block += s0 * (2 * (sc0 & 0xF) + 1);\n            sum_block += s1 * (2 * (sc0 >> 4)  + 1);\n            sum_block += s2 * (2 * (sc1 & 0xF) + 1);\n            sum_block += s3 * (2 * (sc1 >> 4)  + 1);\n        }\n        sumf += sum_block * combined_scale;\n    }\n    *s = 0.125f * sumf;\n}\n\nvoid ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 128:\n            ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        case 256:\n            ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n#if defined(__riscv_v)\nstatic const int8_t keven_signs_q2xs[1024] = {\n     1,  1,  
1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,\n     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,\n     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,\n     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,\n     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,\n     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,\n     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,\n     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,\n     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,\n     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,\n     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,\n     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,\n     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,\n     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,\n     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,\n     1,  1, -1, 
-1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,\n     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,\n     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,\n     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,\n     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,\n     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,\n     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,\n     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,\n     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,\n     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,\n     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,\n     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,\n     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,\n     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,\n     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,\n     1,  1,  1, -1, 
-1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,\n     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,\n};\n#endif\n\nstatic void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n    const uint64_t * grid64  = (const uint64_t *)iq2xs_grid;\n\n    float sumf = 0.0f;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT qs = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        const uint8_t  * GGML_RESTRICT scales = x[i].scales;\n\n        int32_t sum_int = 0;\n\n        // Loop over 4 subblocks of 64 elements (QK_K = 256)\n        for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) {\n            // Load 8 uint16 indices (controls 64 values)\n            vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8);\n            qs += 8;\n\n            // Extract indices for grid (low 9 bits) and signs (high 7 bits)\n            // Multiply by 8 (<< 3) for byte offsets into the uint64 tables\n            vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8);\n            vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8);\n\n            vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8);\n            vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8);\n\n            vint8m2_t q2u = 
__riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64));\n            vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64));\n\n            vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64);\n\n            vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64);\n            q8 += 64;\n\n            vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64);\n\n            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);\n\n            int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(\n                           __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16));\n            int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(\n                           __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16));\n            int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(\n                           __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16));\n            int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(\n                           __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16));\n\n            const uint8_t scale_byte_1 = scales[0];\n            const uint8_t scale_byte_2 = scales[1];\n            scales += 2;\n\n            sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1);\n            sum_int += sum1 * ((scale_byte_1 >> 4)   * 2 + 1);\n            sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1);\n            sum_int += sum3 * ((scale_byte_2 >> 4)   * 2 + 1);\n        }\n\n        sumf += d * sum_int;\n    }\n    *s = 0.125f * sumf;\n}\n\nvoid ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n      switch (__riscv_vlenb() * 8) {\n          case 256:\n              ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n              break;\n          
default:\n              ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n              break;\n      }\n#else\n    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n    const uint64_t * grid64  = (const uint64_t *)iq2xxs_grid;\n\n    uint32_t shift_constants[4] = {0, 7, 14, 21};\n    vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shift_constants, 4);\n\n    float sumf = 0.0f;\n    for (int i = 0; i < nb; ++i) {\n        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n        const uint8_t  * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n\n        float sum = 0.0f;\n\n        #pragma GCC unroll 1\n        for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) {\n            vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32;\n            vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32;\n\n            vuint8mf4_t v_raw_q2_1 = __riscv_vle8_v_u8mf4(q2_ptr, 4);\n            vuint8mf4_t v_raw_q2_2 = __riscv_vle8_v_u8mf4(q2_ptr + 8, 4);\n\n            vuint16mf2_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_1, 4);\n            vuint16mf2_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_2, 4);\n\n            vidx_q2_1 = __riscv_vsll_vx_u16mf2(vidx_q2_1, 3, 4);\n            vidx_q2_2 = __riscv_vsll_vx_u16mf2(vidx_q2_2, 3, 4);\n\n            uint32_t s_packed_1, s_packed_2;\n            memcpy(&s_packed_1, q2_ptr + 4, 4);\n            
memcpy(&s_packed_2, q2_ptr + 12, 4);\n\n            vuint32m1_t v_s_1 = __riscv_vmv_v_x_u32m1(s_packed_1, 4);\n            vuint32m1_t v_s_2 = __riscv_vmv_v_x_u32m1(s_packed_2, 4);\n            v_s_1 = __riscv_vsrl_vv_u32m1(v_s_1, v_shifts, 4);\n            v_s_2 = __riscv_vsrl_vv_u32m1(v_s_2, v_shifts, 4);\n\n            v_s_1 = __riscv_vand_vx_u32m1(v_s_1, 127, 4);\n            v_s_2 = __riscv_vand_vx_u32m1(v_s_2, 127, 4);\n\n            vuint16mf2_t vidx_s2_1 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_1, 4), 3, 4);\n            vuint16mf2_t vidx_s2_2 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_2, 4), 3, 4);\n\n            vuint64m2_t vq2_64_1 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_1, 4);\n            vuint64m2_t vq2_64_2 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_2, 4);\n\n            vint8m2_t q2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_1));\n            vint8m2_t q2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_2));\n\n            vuint64m2_t vs2_64_1 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_1, 4);\n            vuint64m2_t vs2_64_2 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_2, 4);\n            vint8m2_t s2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_1));\n            vint8m2_t s2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_2));\n\n            vint8m2_t q8s_1 = __riscv_vmul_vv_i8m2(q8_1, s2_1, 32);\n            vint8m2_t q8s_2 = __riscv_vmul_vv_i8m2(q8_2, s2_2, 32);\n\n            vint16m4_t dot1 = __riscv_vwmul_vv_i16m4(q8s_1, q2_1, 32);\n            vint16m4_t dot2 = __riscv_vwmul_vv_i16m4(q8s_2, q2_2, 32);\n\n            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);\n            vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m4_i32m1(dot1, zero_vec, 32);\n            vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m4_i32m1(dot2, zero_vec, 32);\n\n            int32_t scalar_sum1 = 
__riscv_vmv_x_s_i32m1_i32(sumv1);\n            int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2);\n\n            int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1;\n            int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1;\n\n            sum += scalar_sum1 * scale1 + scalar_sum2 * scale2;\n            q2_ptr += 16;\n        }\n        sumf += sum * combined_scale;\n    }\n    *s = 0.125f * sumf;\n}\n\nstatic void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n    const uint64_t * grid64  = (const uint64_t *)iq2xxs_grid;\n\n    uint32_t shift_constants[4] = {0, 7, 14, 21};\n    vuint32mf2_t v_shifts = __riscv_vle32_v_u32mf2(shift_constants, 4);\n\n    float sumf = 0.0f;\n\n    for (int i = 0; i < nb; ++i) {\n        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n        const uint8_t  * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n\n        float sum = 0.0f;\n\n        for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) {\n            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32;\n            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32;\n\n            vuint8mf8_t v_raw_q2_1 = __riscv_vle8_v_u8mf8(q2_ptr, 4);\n            vuint8mf8_t v_raw_q2_2 = __riscv_vle8_v_u8mf8(q2_ptr + 8, 4);\n\n            vuint16mf4_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_1, 4);\n            vuint16mf4_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_2, 4);\n\n            vidx_q2_1 = __riscv_vsll_vx_u16mf4(vidx_q2_1, 3, 4);\n     
       vidx_q2_2 = __riscv_vsll_vx_u16mf4(vidx_q2_2, 3, 4);\n\n            uint32_t s_packed_1, s_packed_2;\n            memcpy(&s_packed_1, q2_ptr + 4, 4);\n            memcpy(&s_packed_2, q2_ptr + 12, 4);\n\n            vuint32mf2_t v_s_1 = __riscv_vmv_v_x_u32mf2(s_packed_1, 4);\n            vuint32mf2_t v_s_2 = __riscv_vmv_v_x_u32mf2(s_packed_2, 4);\n\n            v_s_1 = __riscv_vsrl_vv_u32mf2(v_s_1, v_shifts, 4);\n            v_s_2 = __riscv_vsrl_vv_u32mf2(v_s_2, v_shifts, 4);\n\n            v_s_1 = __riscv_vand_vx_u32mf2(v_s_1, 127, 4);\n            v_s_2 = __riscv_vand_vx_u32mf2(v_s_2, 127, 4);\n\n            // Narrow u32 -> u16 (vncvt) and Scale by 8 to get byte offsets\n            vuint16mf4_t vidx_s2_1 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_1, 4), 3, 4);\n            vuint16mf4_t vidx_s2_2 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_2, 4), 3, 4);\n\n            // Load q2 values from lookup grid\n            vuint64m1_t vq2_64_1 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_1, 4);\n            vuint64m1_t vq2_64_2 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_2, 4);\n            vint8m1_t q2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_1));\n            vint8m1_t q2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_2));\n\n            // Load sign values\n            vuint64m1_t vs2_64_1 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_1, 4);\n            vuint64m1_t vs2_64_2 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_2, 4);\n            vint8m1_t s2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_1));\n            vint8m1_t s2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_2));\n\n            // Apply signs to q8\n            vint8m1_t q8s_1 = __riscv_vmul_vv_i8m1(q8_1, s2_1, 32);\n            vint8m1_t q8s_2 = __riscv_vmul_vv_i8m1(q8_2, s2_2, 32);\n\n            // multiplying q2 with q8\n            vint16m2_t 
dot1 = __riscv_vwmul_vv_i16m2(q8s_1, q2_1, 32);\n            vint16m2_t dot2 = __riscv_vwmul_vv_i16m2(q8s_2, q2_2, 32);\n\n            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);\n            vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m2_i32m1(dot1, zero_vec, 32);\n            vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m2_i32m1(dot2, zero_vec, 32);\n            int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1);\n            int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2);\n            int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1;\n            int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1;\n\n            sum += scalar_sum1 * scale1 + scalar_sum2 * scale2;\n            q2_ptr += 16;\n        }\n        sumf += sum * combined_scale;\n    }\n    *s = 0.125f * sumf;\n}\n\n// Dispatch on the runtime vector length (VLEN = vlenb * 8 bits): VLEN 128 uses the\n// vl128 kernel, every other VLEN uses the vl256 kernel. Builds without RVV\n// intrinsics use the portable generic implementation.\nvoid ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 128:\n            ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    // Must be the _generic variant: calling ggml_vec_dot_iq2_xxs_q8_K here would recurse forever.\n    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    const uint64_t * grid64 = (const uint64_t *)iq3s_grid;\n\n    // --- Pre-load Constants ---\n    const uint16_t qh_bit_shifts_arr[16] = {\n        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15\n    };\n    
vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);\n    vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);\n    vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16);\n\n    float sumf = 0.0f;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float combined_scale = d * y[i].d;\n\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint8_t * GGML_RESTRICT scales = x[i].scales;\n        const uint8_t * GGML_RESTRICT signs = x[i].signs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        float sum_block = 0.0f;\n\n        // Loop: Process 64 weights (16 mini-blocks of 4) per iteration\n        for (int ib = 0; ib < 4; ++ib) {\n\n            vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16);\n            qs += 16;\n\n            uint16_t qh_val;\n            memcpy(&qh_val, qh, 2);\n            qh += 2;\n\n            vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16);\n            // Extract bits: (qh >> i) & 1\n            v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16);\n            v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16);\n\n            vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16);\n            v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16);\n            v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16);\n            vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16);\n\n            // Grid value is 4xuint8\n            vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16);\n            vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed);\n            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8);\n            signs += 8;\n\n            // Generate sign mask\n            vuint8m2_t v_signs_source = 
__riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);\n            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);\n            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);\n            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);\n\n            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);\n            q8 += 64;\n\n            // Apply Signs\n            vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);\n            vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64);\n\n            // Reduction\n            vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0);\n            vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1);\n            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);\n\n            int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32));\n            int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32));\n\n            // Apply sub-scales\n            uint8_t sc_byte = *scales++;\n            int sc_lo = (sc_byte & 0xF) * 2 + 1;\n            int sc_hi = (sc_byte >> 4)  * 2 + 1;\n\n            sum_block += s_lo * sc_lo + s_hi * sc_hi;\n        }\n        sumf += sum_block * combined_scale;\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 256:\n            ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void 
ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n    const int nb = n / QK_K;\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n    const uint32_t * grid32  = (const uint32_t *)iq3xxs_grid;\n\n    // constants for unpacking logic\n    const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21};\n    vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shifts_val, 8);\n\n    const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1};\n    vuint32m1_t v_gather_idx = __riscv_vle32_v_u32m1(gather_idx_val, 8);\n\n    uint32_t aux32[2];\n    float sumf = 0.0f;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n\n        const uint8_t * GGML_RESTRICT q3_indices = x[i].qs;\n        const uint8_t * GGML_RESTRICT metadata   = x[i].qs + QK_K/4;\n        const int8_t  * GGML_RESTRICT q8         = y[i].qs;\n\n        float block_sum = 0.0f;\n\n        for (int ib = 0; ib < QK_K / 64; ++ib) {\n            // Load q8 (64 bytes)\n            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);\n            q8 += 64;\n\n            // load of metadata via memcpy\n            memcpy(aux32, metadata, 2 * sizeof(uint32_t));\n            metadata += 2 * sizeof(uint32_t);\n\n            // Load q3 indices and gather magnitudes\n            vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 16);\n            q3_indices += 16;\n\n            vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 16);\n            vuint32m2_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 16);\n            vint8m2_t v_q3_magnitudes = 
__riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u32m2_u8m2(v_q3_magnitudes_u32));\n\n            // --- Unpacking of Sign Indices ---\n\n            // 1. Load the 2 auxiliary 32-bit integers into a vector\n            vuint32m1_t v_aux = __riscv_vle32_v_u32m1(aux32, 2);\n\n            // 2. Broadcast/Gather: replicate aux[0] to first 4 lanes, aux[1] to next 4 lanes\n            vuint32m1_t v_aux_expanded = __riscv_vrgather_vv_u32m1(v_aux, v_gather_idx, 8);\n\n            // 3. Apply Shifts and Mask: ((val >> shift) & 127)\n            vuint32m1_t v_s_vals_raw = __riscv_vand_vx_u32m1(__riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 8), 127, 8);\n\n            // 4. Narrow to u16 (required for vluxei index) and multiply by 8 (byte offset for u64 table)\n            vuint16mf2_t sign_indices_byte_offset = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_vals_raw, 8), 3, 8);\n\n            // 5. Gather Signs\n            vuint64m2_t v_s_vals_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_indices_byte_offset, 8);\n            vint8m2_t v_s_vals = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(v_s_vals_u64));\n\n            vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_s_vals, 64);\n            vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_signed, 64);\n\n            vint16m2_t v_dot_1 = __riscv_vget_v_i16m4_i16m2(v_dot, 0);\n            vint16m2_t v_dot_2 = __riscv_vget_v_i16m4_i16m2(v_dot, 1);\n\n            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);\n            vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_1, v_zero, 32);\n            vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_2, v_zero, 32);\n\n            int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1);\n            int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2);\n\n            const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1);\n            const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1);\n\n          
  block_sum += sum1_i * scale1_f + sum2_i * scale2_f;\n        }\n\n        sumf += d * block_sum;\n    }\n    *s = 0.25f * sumf;\n}\n\nvoid ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 256:\n            ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK4_NL == 0);\n    static_assert(QK4_NL == QK8_0, \"QK4_NL and QK8_0 must be the same\");\n\n    const block_iq4_nl * GGML_RESTRICT x = vx;\n    const block_q8_0   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK4_NL;\n\n    int ib = 0;\n    float sumf = 0;\n\n    // Load the lookup table once.\n    const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16);\n    int acc1, acc2;\n\n    // We process 2 blocks at once.\n    for (; ib + 1 < nb; ib += 2) {\n        // Weights and activations.\n        vuint8m1_t iq4_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16);\n        vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32);\n        vuint8m1_t iq4_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16);\n        vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32);\n\n        // Unpack the weight blocks.\n        vuint8m2_t iq4bits1;\n        iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 0, __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16));\n        iq4bits1 = 
__riscv_vset_v_u8m1_u8m2(iq4bits1, 1, __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16));\n        vuint8m2_t iq4bits2;\n        iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 0, __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16));\n        iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 1, __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16));\n\n        // Gather values from the lookup table.\n        vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32);\n        vint8m2_t iq4b2 = __riscv_vrgather_vv_i8m2(values, iq4bits2, 32);\n\n        // Accumulation.\n        vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, iq4b1, 32);\n        vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, iq4b2, 32);\n        __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);\n        __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);\n        sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1));\n        sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2));\n    }\n\n    // Leftover block when nb is odd: the loop above only consumes blocks in pairs.\n    // Same layout as the vector path: low nibbles pair with q8[0..15], high nibbles with q8[16..31].\n    for (; ib < nb; ++ib) {\n        int sumi = 0;\n        for (int j = 0; j < QK4_NL/2; ++j) {\n            sumi += y[ib].qs[j]            * kvalues_iq4nl[x[ib].qs[j] & 0xf];\n            sumi += y[ib].qs[j + QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];\n        }\n        sumf += GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) * sumi;\n    }\n\n    *s = sumf;\n}\n\nstatic void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK4_NL == 0);\n    static_assert(QK4_NL == QK8_0, \"QK4_NL and QK8_0 must be the same\");\n\n    const block_iq4_nl * GGML_RESTRICT x = vx;\n    const block_q8_0   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK4_NL;\n\n    int ib = 0;\n    float sumf = 0;\n\n    // Load the lookup table once.\n    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);\n    int acc1, acc2;\n\n    // We process 2 blocks at once.\n    for (; ib + 1 < nb; ib += 2) {\n        // Weights and activations.\n        vuint8mf2_t iq4_packed1 = __riscv_vle8_v_u8mf2(x[ib + 
0].qs, 16);\n        vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16);\n        vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16);\n        vuint8mf2_t iq4_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16);\n        vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16);\n        vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16);\n\n        // Unpack the weight blocks.\n        vuint8mf2_t iq4bits_lo1 = __riscv_vand_vx_u8mf2(iq4_packed1, 0xf, 16);\n        vuint8mf2_t iq4bits_hi1 = __riscv_vsrl_vx_u8mf2(iq4_packed1, 4, 16);\n        vuint8mf2_t iq4bits_lo2 = __riscv_vand_vx_u8mf2(iq4_packed2, 0xf, 16);\n        vuint8mf2_t iq4bits_hi2 = __riscv_vsrl_vx_u8mf2(iq4_packed2, 4, 16);\n\n        // Gather values from the lookup table.\n        vint8mf2_t iq4b_lo1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo1, 16);\n        vint8mf2_t iq4b_hi1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi1, 16);\n        vint8mf2_t iq4b_lo2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo2, 16);\n        vint8mf2_t iq4b_hi2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi2, 16);\n\n        // Accumulation.\n        vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, iq4b_lo1, 16);\n        sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, iq4b_hi1, 16);\n        vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, iq4b_lo2, 16);\n        sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, iq4b_hi2, 16);\n        __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);\n        __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);\n        sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1));\n        sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2));\n    }\n\n    // Leftover block when nb is odd: the loop above only consumes blocks in pairs.\n    // Same layout as the vector path: low nibbles pair with q8[0..15], high nibbles with q8[16..31].\n    for (; ib < nb; ++ib) {\n        int sumi = 0;\n        for (int j = 0; j < QK4_NL/2; ++j) {\n            sumi += y[ib].qs[j]            * kvalues_iq4nl[x[ib].qs[j] & 0xf];\n            sumi += y[ib].qs[j + QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];\n        }\n        sumf += GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) * sumi;\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, 
size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 128:\n            ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_K == 0);\n\n    const block_iq4_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __riscv_v_intrinsic\n    const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);\n    float sumf = 0;\n    int acc[4];\n\n    // Indices for re-ordering IQ4 data.\n    uint64_t index[16] = {\n        0, 1, 8, 9,\n        2, 3, 10, 11,\n        4, 5,12, 13,\n        6, 7, 14, 15,\n    };\n    vuint64m4_t i_vec = __riscv_vle64_v_u64m4(index, 16);\n\n    for (int ibl = 0; ibl < nb; ++ibl) {\n        const int8_t  * q8 = y[ibl].qs;\n        const uint8_t * iq4 = x[ibl].qs;\n        uint16_t h = x[ibl].scales_h;\n\n        int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;\n\n        for (int ib = 0; ib < QK_K / 128; ++ib) {\n            // Weights and activations.\n            vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64);\n            vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);\n            iq4 += 64;\n            q8 += 128;\n\n            // Unpack the weight blocks.\n            vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64);\n            vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64);\n            vuint8m4_t iq4bits;\n         
   iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 0, iq4bits_lo);\n            iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 1, iq4bits_hi);\n            vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgather_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16));\n            vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128);\n\n            // Multiply with activations.\n            vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128);\n\n            // Reduce separately.\n            __riscv_vse32_v_i32m1(&acc[0],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);\n            __riscv_vse32_v_i32m1(&acc[1],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);\n            __riscv_vse32_v_i32m1(&acc[2],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);\n            __riscv_vse32_v_i32m1(&acc[3],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);\n\n            int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf)  | ((h << 4) & 0x30)) - 32;\n            int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >>  4)  | ((h << 2) & 0x30)) - 32;\n            int ls3 = ((x[ibl].scales_l[ib * 2 + 1] &  0xf) | ((h << 0) & 0x30)) - 32;\n            int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >>  4)  | ((h >> 2) & 0x30)) - 32;\n            h >>= 8;\n\n            sumi1 += acc[0] * ls1;\n            sumi2 += acc[1] * ls2;\n            sumi3 += acc[2] * ls3;\n            sumi4 += acc[3] * ls4;\n        }\n\n        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4);\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t 
bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 256:\n            ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_tq1_0 * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0.0f;\n    uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};\n\n    for (int i = 0; i < nb; i++) {\n        // First loop.\n        vint32m4_t suml1;\n        {\n            const int vl = 32;\n            vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl);\n\n            vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl);\n            vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl);\n            vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl);\n            vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl);\n            vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl);\n\n            vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl);\n            vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), 
vl);\n            vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl);\n            vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl);\n            vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl);\n\n            vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl);\n            vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl);\n            vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl);\n            vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl);\n            vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl);\n\n            vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl);\n            vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl);\n            suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl);\n        }\n\n        // Second loop.\n        vint32m2_t suml2;\n        {\n            const int vl = 16;\n            vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl);\n\n            vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl);\n            vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl);\n            vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl);\n            vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl);\n            vuint16m1_t tq4 = 
__riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl);\n\n            vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl);\n            vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl);\n            vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl);\n            vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl);\n            vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl);\n\n            vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);\n            vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl);\n            vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl);\n            vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl);\n            vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl);\n\n            vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl);\n            vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl);\n            suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl);\n        }\n\n        // Third loop.\n        vint32m2_t suml3;\n        {\n            const int vl = 16;\n\n            uint32_t qh;\n            memcpy(&qh, &x[i].qh[0], 4);\n            // Prevent fusion with vmv.\n            __asm__ __volatile__(\"\" : \"+r\"(qh));\n            vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4));\n\n            vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl);\n\n       
     vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl);\n\n            vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl);\n\n            vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);\n            suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl);\n        }\n\n        vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16);\n        sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16);\n        sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16);\n\n        vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16);\n        sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 256:\n            ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_tq2_0 * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0.0f;\n    for (int i = 0; i < nb; ++i) {\n        int32_t sumi = 0;\n\n        for (size_t j = 
0; j < sizeof(x[0].qs); j += 32) {\n            const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32];\n            const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32];\n            const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32];\n            const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32];\n            const uint8_t* px  = &x[i].qs[j];\n\n            size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32);\n            vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2);\n\n            size_t vl = __riscv_vsetvl_e8m1(32);\n\n            vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl);\n\n            vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl);\n            vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl);\n            vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl);\n            vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl);\n\n            // l=0 (bits 1:0)\n            vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl);\n            vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl);\n\n            // l=1 (bits 3:2)\n            vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl);\n            vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl);\n\n            // l=2 (bits 5:4)\n            vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl);\n            vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl);\n\n            // l=3 (bits 7:6)\n            vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros\n            vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl);\n\n            // 4. 
Multiply and accumulate\n            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl);\n            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl);\n            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl);\n            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl);\n\n            vlmax_16m2 = __riscv_vsetvl_e16m2(32);\n            vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1);\n            vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2);\n\n            sumi += __riscv_vmv_x_s_i32m1_i32(vred32);\n        }\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        sumf += (float)sumi * d;\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 256:\n            ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nstatic void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_MXFP4 == 0);\n    static_assert(QK_MXFP4 == QK8_0, \"QK_MXFP4 and QK8_0 must be the same\");\n\n    const block_mxfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_MXFP4;\n\n    int ib = 0;\n    float sumf = 0;\n\n    // Load the lookup table once.\n    const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_mxfp4, 16);\n    int acc1, acc2;\n\n    // We process 2 blocks at once.\n  
  for (; ib + 1 < nb; ib += 2) {\n        // Weights and activations.\n        vuint8m1_t mx_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16);\n        vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32);\n        vuint8m1_t mx_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16);\n        vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32);\n\n        // Unpack the weight blocks.\n        vuint8m2_t mxbits1;\n        mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 0, __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16));\n        mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 1, __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16));\n        vuint8m2_t mxbits2;\n        mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 0, __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16));\n        mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 1, __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16));\n\n        // Gather values from the lookup table.\n        vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32);\n        vint8m2_t mxb2 = __riscv_vrgather_vv_i8m2(values, mxbits2, 32);\n\n        // Accumulation.\n        vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, mxb1, 32);\n        vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, mxb2, 32);\n        __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);\n        __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);\n        sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1));\n        sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2));\n    }\n\n    *s = sumf;\n}\n\nstatic void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_MXFP4 == 0);\n    static_assert(QK_MXFP4 == QK8_0, 
\"QK_MXFP4 and QK8_0 must be the same\");\n\n    const block_mxfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_MXFP4;\n\n    int ib = 0;\n    float sumf = 0;\n\n    // Load the lookup table once.\n    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_mxfp4, 16);\n    int acc1, acc2;\n\n    // We process 2 blocks at once.\n    for (; ib + 1 < nb; ib+=2) {\n        // Weights and activations.\n        vuint8mf2_t mx_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16);\n        vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16);\n        vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16);\n        vuint8mf2_t mx_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16);\n        vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16);\n        vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16);\n\n        // Unpack the weight blocks.\n        vuint8mf2_t mxbits_lo1 = __riscv_vand_vx_u8mf2(mx_packed1, 0xf, 16);\n        vuint8mf2_t mxbits_hi1 = __riscv_vsrl_vx_u8mf2(mx_packed1, 4, 16);\n        vuint8mf2_t mxbits_lo2 = __riscv_vand_vx_u8mf2(mx_packed2, 0xf, 16);\n        vuint8mf2_t mxbits_hi2 = __riscv_vsrl_vx_u8mf2(mx_packed2, 4, 16);\n\n        // Gather values from the lookup table.\n        vint8mf2_t mxb_lo1 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo1, 16);\n        vint8mf2_t mxb_hi1 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi1, 16);\n        vint8mf2_t mxb_lo2 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo2, 16);\n        vint8mf2_t mxb_hi2 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi2, 16);\n\n        // Accumulation.\n        vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, mxb_lo1, 16);\n        sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, mxb_hi1, 16);\n        vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, mxb_lo2, 16);\n        sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, mxb_hi2, 16);\n        
__riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);\n        __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);\n        sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1));\n        sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2));\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n#if defined __riscv_v_intrinsic\n    switch (__riscv_vlenb() * 8) {\n        case 128:\n            ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n        default:\n            ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);\n            break;\n    }\n#else\n    return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n"
  },
  {
    "path": "src/ggml-cpu/arch/riscv/repack.cpp",
    "content": "#define GGML_COMMON_IMPL_CPP\n#define GGML_COMMON_DECL_CPP\n#include \"ggml-common.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"simd-mappings.h\"\n#include \"traits.h\"\n\n#include <cmath>\n#include <cstring>\n#include <cassert>\n#include <cstdlib> // for qsort\n#include <cstdio>  // for GGML_ASSERT\n\n#define GGML_CPU_CLANG_WORKAROUND\n#include \"../../repack.h\"\n\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Woverlength-strings\"\n#endif\n\n#define UNUSED GGML_UNUSED\n\nvoid ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n#if defined(__riscv_v_intrinsic)\n    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;\n    const size_t vl_calc = __riscv_vsetvl_e32m8(QK8_0);\n    const size_t vl_save = __riscv_vsetvl_e64m2(4);\n    vfloat32m1_t v_scalar_zero = __riscv_vfmv_s_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1));\n\n    for (int i = 0; i < nb; i++) {\n        const float *x_block_base = x + i * QK8_0;\n        vint8m2_t q_r0, q_r1, q_r2, q_r3;\n        {\n            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 0 * k, vl_calc);\n            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);\n            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);\n            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);\n\n            float d = amax / 127.0f;\n            y[i].d[0] = GGML_CPU_FP32_TO_FP16(d);\n\n            float id = d ? 
1.0f / d : 0.0f;\n            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);\n            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);\n            q_r0 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);\n        }\n        asm volatile (\"\" ::: \"memory\");\n\n        {\n            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 1 * k, vl_calc);\n            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);\n            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);\n            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);\n\n            float d = amax / 127.0f;\n            y[i].d[1] = GGML_CPU_FP32_TO_FP16(d);\n            float id = d ? 1.0f / d : 0.0f;\n\n            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);\n            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);\n            q_r1 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);\n        }\n        asm volatile (\"\" ::: \"memory\");\n        {\n            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 2 * k, vl_calc);\n            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);\n            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);\n            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);\n\n            float d = amax / 127.0f;\n            y[i].d[2] = GGML_CPU_FP32_TO_FP16(d);\n            float id = d ? 
1.0f / d : 0.0f;\n\n            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);\n            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);\n            q_r2 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);\n        }\n        asm volatile (\"\" ::: \"memory\");\n        {\n            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 3 * k, vl_calc);\n            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);\n            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);\n            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);\n\n            float d = amax / 127.0f;\n            y[i].d[3] = GGML_CPU_FP32_TO_FP16(d);\n            float id = d ? 1.0f / d : 0.0f;\n\n            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);\n            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);\n            q_r3 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);\n        }\n        vint64m2_t v_q64_r0 = __riscv_vreinterpret_v_i8m2_i64m2(q_r0);\n        vint64m2_t v_q64_r1 = __riscv_vreinterpret_v_i8m2_i64m2(q_r1);\n        vint64m2_t v_q64_r2 = __riscv_vreinterpret_v_i8m2_i64m2(q_r2);\n        vint64m2_t v_q64_r3 = __riscv_vreinterpret_v_i8m2_i64m2(q_r3);\n        vint64m2x4_t v_quant_tuple = __riscv_vcreate_v_i64m2x4(v_q64_r0, v_q64_r1, v_q64_r2, v_q64_r3);\n        __riscv_vsseg4e64_v_i64m2x4((int64_t*)y[i].qs, v_quant_tuple, vl_save);\n    }\n#else\n    UNUSED(nb);\n    UNUSED(y);\n    ggml_quantize_mat_q8_0_4x4_generic(x, vy, k);\n#endif\n}\n\nvoid ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    
UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v\n    if (__riscv_vlenb() >= QK4_0) {\n        const size_t vl = QK4_0;\n\n        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);\n\n            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);\n            for (int l = 0; l < nb; l++) {\n                const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];\n                const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];\n                const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];\n                const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];\n                __asm__ __volatile__(\"\" ::: \"memory\"); // prevent gcc from emitting fused vlse64, violating alignment constraints\n                const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));\n                const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));\n                const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));\n                const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));\n\n                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);\n                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);\n                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);\n                const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);\n                const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);\n                const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);\n      
          const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);\n\n                const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);\n                const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);\n                const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);\n                const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);\n\n                const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));\n                const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);\n                const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);\n                const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);\n                const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);\n                const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);\n                const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);\n                const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);\n                const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);\n                const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));\n                const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));\n                const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);\n                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);\n\n                // vector version needs Zvfhmin extension\n                const float a_scale = GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n       
         const float b_scales[8] = {\n                    GGML_CPU_FP16_TO_FP32(b_ptr[l].d[0]),\n                    GGML_CPU_FP16_TO_FP32(b_ptr[l].d[1]),\n                    GGML_CPU_FP16_TO_FP32(b_ptr[l].d[2]),\n                    GGML_CPU_FP16_TO_FP32(b_ptr[l].d[3]),\n                    GGML_CPU_FP16_TO_FP32(b_ptr[l].d[4]),\n                    GGML_CPU_FP16_TO_FP32(b_ptr[l].d[5]),\n                    GGML_CPU_FP16_TO_FP32(b_ptr[l].d[6]),\n                    GGML_CPU_FP16_TO_FP32(b_ptr[l].d[7])\n                };\n                const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);\n                const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);\n                sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);\n            }\n            __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);\n        }\n        return;\n    }\n\n#endif\n    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v_intrinsic\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);\n\n        // 1x16 Accumulator\n        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n\n        for (int l = 0; l < nb; l++) {\n            // 1x16 Integer Accumulator\n            vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n 
           vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n\n            // Accumulation loop.\n            for (int i = 0; i < QK4_0 / 2; i++) {\n                // Load `b_ptr`.\n                const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);\n                const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16);\n                const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16);\n\n                sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16);\n                sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16);\n            }\n\n            const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16);\n\n            const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);\n            const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);\n\n            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);\n        }\n\n        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);\n    }\n    return;\n#endif\n    ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v_intrinsic\n    const block_q8_K * a_ptr = (const block_q8_K *) vy;\n\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_Kx16 * b_ptr = 
(const block_q4_Kx16 *) vx + (x * nb);\n\n        // 1x16 Accumulator\n        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n\n        for (int l = 0; l < nb; l++) {\n            vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);\n\n            // Load `dmin`.\n            const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2(\n                __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16);\n\n            // We process 4 sub-blocks at once.\n            for (int j = 0; j < QK_K / 128; j++) {\n                // Extract the scales and the mins.\n                //\n                // Low bits.\n                vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64);\n                vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64);\n                vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64);\n\n                // High bits.\n                vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64);\n                vuint8m2_t scales_hi;\n                vuint8m2_t mins_hi;\n                if (!j) {\n                    scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64);\n                    mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64);\n                } else {\n                    scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64);\n                    mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64);\n                }\n                vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64);\n                vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64));\n\n                // Reduce the mins and multiply with `dmin`.\n                //\n                // Correct in `sumf`.\n      
          vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16);\n                bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16);\n                bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16);\n                bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16);\n                bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16);\n\n                sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16);\n\n                // Accumulation for 2 sub-blocks.\n                //\n                // This might overflow, so we accumulate in two steps.\n                //\n                // Recheck.\n                for (int k = 0; k < 2; k++) {\n                    vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                    vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n\n                    for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {\n                        // Load `b_ptr`.\n                        const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16);\n                        const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));\n                        const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));\n\n                        sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16);\n                        sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16);\n                    }\n\n                    sumi = 
__riscv_vwmacc_vv_i32m2(sumi,\n                        __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),\n                        sumi_s_0_16, 16);\n                    sumi = __riscv_vwmacc_vv_i32m2(sumi,\n                        __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),\n                        sumi_s_1_16, 16);\n                }\n                // Accumulation for 2 sub-blocks.\n                //\n                // This might overflow, so we accumulate in two steps.\n                //\n                // Recheck.\n                for (int k = 0; k < 2; k++) {\n                    vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                    vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n\n                    for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {\n                        // Load `b_ptr`.\n                        const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16);\n                        const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));\n                        const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));\n\n                        sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16);\n                        sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16);\n                    }\n\n                    sumi = __riscv_vwmacc_vv_i32m2(sumi,\n                        __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),\n                        sumi_s_0_16, 16);\n                    sumi = __riscv_vwmacc_vv_i32m2(sumi,\n                        __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),\n                        sumi_s_1_16, 16);\n                }\n            }\n\n          
  const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16);\n            const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16);\n\n            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);\n        }\n\n        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);\n    }\n    return;\n#endif\n    ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v_intrinsic\n    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);\n\n        // 1x16 Accumulator1\n        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n\n        for (int l = 0; l < nb; l++) {\n            // 1x16 integer accumulator\n            vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16);\n\n            // Accumulation loop.\n            for (int i = 0; i < QK4_NL / 2; i++) {\n                // Load `b_ptr`.\n                const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);\n                const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);\n                const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, 
__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);\n                // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);\n                // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);\n\n                const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16);\n                const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16);\n                sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16);\n            }\n\n            const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);\n            const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);\n\n            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);\n        }\n\n        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);\n    }\n    return;\n#endif\n    ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n    UNUSED(bs);\n\n#if defined __riscv_v_intrinsic\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);\n\n        // 1x16 Accumulator\n        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n\n        for (int l = 0; l < nb; l++) {\n            // 1x16 Integer Accumulator\n            vint32m2_t sumi = 
__riscv_vmv_v_x_i32m2(0.0f, 16);\n\n            // Accumulation loop.\n            for (int i = 0; i < QK8_0; i++) {\n                // Load `b_ptr`.\n                const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);\n                // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16);\n\n                sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16);\n            }\n\n            const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);\n            const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);\n\n            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);\n        }\n\n        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);\n    }\n    return;\n#endif\n    ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    assert(n % QK_K == 0);\n    assert(nr == 1);\n    assert(nc % 16 == 0);\n\n    UNUSED(bs);\n\n    const int N_COLS_TILE = 16;\n    const int num_k_blocks = n / QK_K;\n\n    const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE);\n    for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) {\n\n        const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy;\n        const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks;\n\n        vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl);\n\n        for (int k_block = 0; k_block < num_k_blocks; ++k_block) {\n            const block_q8_K* lhs_current = &lhs_base_ptr[k_block];\n            const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block];\n\n            // 1. 
Prepare Global Min Scales\n            vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);\n            vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);\n\n            vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl);\n\n            vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl);\n\n            const uint8_t* rhs_qs_ptr = rhs_current->qs;\n            const uint8_t* rhs_sc_ptr = rhs_current->scales;\n            const int8_t*  lhs_qs_ptr = lhs_current->qs;\n\n            // --- Phase Loop (4 phases x 64 elements) ---\n            for (int phase = 0; phase < 4; ++phase) {\n\n                // A. Load Scales/Mins\n                vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3;\n                vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3;\n\n                {\n                    vuint8mf2_t v_raw;\n                    // Sub-block 0\n                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);\n                    v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);\n                    v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);\n\n                    // Sub-block 1\n                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);\n                    v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);\n                    v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);\n\n                    // Sub-block 2\n                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);\n                    v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);\n                    v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);\n\n                    // Sub-block 3\n                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);\n                    v_d_sb_3 = 
__riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);\n                    v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);\n\n                    rhs_sc_ptr += 64;\n                }\n\n                int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);\n                int k_offsets[4] = {0, 32, 64, 96};\n\n                // B. Inner Dot Product Loop\n                for (int l = 0; l < 16; ++l) {\n                    vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl);\n                    rhs_qs_ptr += 16;\n\n                    // Sub-block 0\n                    {\n                        vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl);\n                        vint16m1_t v_w = __riscv_vmul_vv_i16m1(\n                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),\n                            __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl);\n\n                        int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l];\n                        v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);\n                    }\n                    // Sub-block 1\n                    {\n                        vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl);\n                        vint16m1_t v_w = __riscv_vmul_vv_i16m1(\n                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),\n                            __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl);\n\n                        int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l];\n                        v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);\n                    }\n                    // Sub-block 2\n                    {\n                        vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl);\n                        
vint16m1_t v_w = __riscv_vmul_vv_i16m1(\n                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),\n                            __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl);\n\n                        int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l];\n                        v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);\n                    }\n                    // Sub-block 3\n                    {\n                        vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl);\n                        vint16m1_t v_w = __riscv_vmul_vv_i16m1(\n                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),\n                            __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl);\n\n                        int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l];\n                        v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);\n                    }\n                }\n\n                // correction\n                int sb_base_abs = base_k_phase / 16;\n\n                // Sub-block 0\n                {\n                    int sb_idx = sb_base_abs + (k_offsets[0] / 16);\n                    int16_t bsum = lhs_current->bsums[sb_idx];\n                    vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0);\n                    vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);\n                    vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);\n                    v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);\n                }\n                // Sub-block 1\n                {\n                    int sb_idx = sb_base_abs + (k_offsets[1] / 16);\n                    int16_t bsum = lhs_current->bsums[sb_idx];\n                    vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1);\n                    
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);\n                    vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);\n                    v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);\n                }\n                // Sub-block 2\n                {\n                    int sb_idx = sb_base_abs + (k_offsets[2] / 16);\n                    int16_t bsum = lhs_current->bsums[sb_idx];\n                    vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2);\n                    vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);\n                    vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);\n                    v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);\n                }\n                // Sub-block 3\n                {\n                    int sb_idx = sb_base_abs + (k_offsets[3] / 16);\n                    int16_t bsum = lhs_current->bsums[sb_idx];\n                    vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3);\n                    vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);\n                    vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);\n                    v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);\n                }\n\n            } // End Phase Loop\n\n            // Apply global Scales\n            vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl);\n            vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl);\n\n            vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl);\n            vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl);\n            v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl);\n            v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl);\n\n        } // End K-Block\n        
__riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl);\n\n    }\n}\n\nvoid ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v\n    if (__riscv_vlenb() >= QK4_0) {\n        const size_t vl = QK4_0;\n\n        for (int y = 0; y < nr / 4; y++) {\n            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n            for (int x = 0; x < nc / ncols_interleaved; x++) {\n                const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);\n                vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);\n                vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);\n                vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);\n                vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);\n                for (int l = 0; l < nb; l++) {\n                    const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);\n                    const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);\n                    const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);\n                    const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);\n                    const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);\n                    const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);\n                    const vint8m2_t rhs_vec_hi_1 = 
__riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);\n\n                    // vector version needs Zvfhmin extension\n                    const float a_scales[4] = {\n                        GGML_CPU_FP16_TO_FP32(a_ptr[l].d[0]),\n                        GGML_CPU_FP16_TO_FP32(a_ptr[l].d[1]),\n                        GGML_CPU_FP16_TO_FP32(a_ptr[l].d[2]),\n                        GGML_CPU_FP16_TO_FP32(a_ptr[l].d[3])\n                    };\n                    const float b_scales[8] = {\n                        GGML_CPU_FP16_TO_FP32(b_ptr[l].d[0]),\n                        GGML_CPU_FP16_TO_FP32(b_ptr[l].d[1]),\n                        GGML_CPU_FP16_TO_FP32(b_ptr[l].d[2]),\n                        GGML_CPU_FP16_TO_FP32(b_ptr[l].d[3]),\n                        GGML_CPU_FP16_TO_FP32(b_ptr[l].d[4]),\n                        GGML_CPU_FP16_TO_FP32(b_ptr[l].d[5]),\n                        GGML_CPU_FP16_TO_FP32(b_ptr[l].d[6]),\n                        GGML_CPU_FP16_TO_FP32(b_ptr[l].d[7])\n                    };\n                    const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);\n\n                    const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];\n                    const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];\n                    const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];\n                    const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];\n                    __asm__ __volatile__(\"\" ::: \"memory\"); // prevent gcc from emitting fused vlse64, violating alignment\n                    vint16m4_t sumi_l0;\n                    {\n                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));\n                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));\n                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));\n                        const vint8m2_t 
lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));\n                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);\n                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);\n                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);\n                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);\n\n                        sumi_l0 = sumi_hi_m;\n                    }\n\n                    {\n                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));\n                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);\n                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);\n                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);\n                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);\n                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);\n                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);\n                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);\n                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);\n                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));\n                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));\n                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);\n              
          const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);\n\n                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);\n                        sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);\n                    }\n\n                    const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];\n                    const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];\n                    const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];\n                    const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];\n                    __asm__ __volatile__(\"\" ::: \"memory\"); // prevent gcc from emitting fused vlse64, violating alignment\n                    vint16m4_t sumi_l1;\n                    {\n                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));\n                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));\n                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));\n                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));\n                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);\n                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);\n                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);\n                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);\n\n                        sumi_l1 = sumi_hi_m;\n                    }\n\n                    {\n                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));\n                   
     const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);\n                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);\n                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);\n                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);\n                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);\n                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);\n                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);\n                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);\n                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));\n                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));\n                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);\n                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);\n\n                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);\n                        sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);\n                    }\n\n                    const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];\n                    const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];\n                    const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];\n                    const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];\n                    __asm__ __volatile__(\"\" ::: \"memory\"); // prevent gcc from emitting fused vlse64, violating alignment\n                    vint16m4_t sumi_l2;\n                    {\n   
                     const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));\n                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));\n                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));\n                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));\n                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);\n                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);\n                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);\n                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);\n\n                        sumi_l2 = sumi_hi_m;\n                    }\n\n                    {\n                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));\n                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);\n                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);\n                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);\n                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);\n                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);\n                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);\n                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);\n                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);\n        
                const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));\n                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));\n                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);\n                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);\n\n                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);\n                        sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);\n                    }\n\n                    const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];\n                    const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];\n                    const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];\n                    const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];\n                    __asm__ __volatile__(\"\" ::: \"memory\"); // prevent gcc from emitting fused vlse64, violating alignment\n                    vint16m4_t sumi_l3;\n                    {\n                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));\n                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));\n                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));\n                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));\n                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);\n                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);\n                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, 
rhs_vec_hi_0, lhs_2_8, vl * 2);\n                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);\n\n                        sumi_l3 = sumi_hi_m;\n                    }\n\n                    {\n                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));\n                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);\n                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);\n                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);\n                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);\n                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);\n                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);\n                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);\n                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);\n                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));\n                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));\n                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);\n                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);\n\n                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);\n                        sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);\n                    }\n                }\n                __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * 
ncols_interleaved], sumf0, vl / 4);\n                __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);\n                __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);\n                __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);\n            }\n        }\n\n        return;\n    }\n\n#endif\n    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v_intrinsic\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);\n\n            // 4x16 Accumulators\n            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n\n            for (int l = 0; l < nb; l++) {\n                // 4x16 integer accumulators\n                vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                vint16m1_t sumi_3_lo_16 = 
__riscv_vmv_v_x_i16m1(0.0f, 16);\n                vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n\n                // Accumulation loop.\n                for (int i = 0; i < QK4_0 / 2; i++) {\n                    // Load `b_ptr`.\n                    const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);\n                    const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16);\n                    const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16);\n\n                    sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16);\n                    sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16);\n                    sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16);\n                    sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16);\n\n                    sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16);\n                    sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16);\n                    sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16);\n                    sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16);\n                }\n\n                // Do the final accumulation in i32 to prevent overflow.\n                const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16);\n                const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16);\n        
        const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16);\n                const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16);\n\n                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);\n                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);\n                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);\n                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);\n                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);\n\n                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);\n                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);\n                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);\n                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);\n            }\n\n            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);\n        }\n    }\n    return;\n#endif\n    ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    
UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v_intrinsic\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);\n\n            // 4x16 Accumulators\n            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n\n            for (int l = 0; l < nb; l++) {\n                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);\n                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);\n                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);\n                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);\n\n                // Load `dmin`.\n                const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16);\n\n                // We process 4 sub-blocks at once.\n                for (int j = 0; j < QK_K / 128; j++) {\n                    // Extract the scales and the mins.\n                    //\n                    // Low bits.\n                    vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64);\n                    vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64);\n                    vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64);\n\n                    // High bits.\n                    vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64);\n                    vuint8m2_t scales_hi;\n                    vuint8m2_t mins_hi;\n                  
  if (!j) {\n                        scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64);\n                        mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64);\n                    } else {\n                        scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64);\n                        mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64);\n                    }\n                    vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64);\n                    vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64));\n\n                    // Reduce the mins and multiply with `dmin`.\n                    //\n                    // Correct in `sumf`.\n                    vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16);\n                    vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16);\n                    vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16);\n                    vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16);\n\n                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,\n                                a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4],\n                                __riscv_vget_v_i16m4_i16m1(mins, 0), 16);\n                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,\n                                a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5],\n                                __riscv_vget_v_i16m4_i16m1(mins, 0), 16);\n                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,\n                                a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6],\n                                __riscv_vget_v_i16m4_i16m1(mins, 0), 16);\n                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,\n                                a_ptr[l].bsums[j * 32 + 3] + 
a_ptr[l].bsums[j * 32 + 7],\n                                __riscv_vget_v_i16m4_i16m1(mins, 0), 16);\n                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,\n                                a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4],\n                                __riscv_vget_v_i16m4_i16m1(mins, 1), 16);\n                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,\n                                a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5],\n                                __riscv_vget_v_i16m4_i16m1(mins, 1), 16);\n                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,\n                                a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6],\n                                __riscv_vget_v_i16m4_i16m1(mins, 1), 16);\n                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,\n                                a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7],\n                                __riscv_vget_v_i16m4_i16m1(mins, 1), 16);\n                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,\n                                a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4],\n                                __riscv_vget_v_i16m4_i16m1(mins, 2), 16);\n                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,\n                                a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5],\n                                __riscv_vget_v_i16m4_i16m1(mins, 2), 16);\n                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,\n                                a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6],\n                                __riscv_vget_v_i16m4_i16m1(mins, 2), 16);\n                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,\n                                a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7],\n                                __riscv_vget_v_i16m4_i16m1(mins, 2), 
16);\n                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,\n                                a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4],\n                                __riscv_vget_v_i16m4_i16m1(mins, 3), 16);\n                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,\n                                a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5],\n                                __riscv_vget_v_i16m4_i16m1(mins, 3), 16);\n                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,\n                                a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6],\n                                __riscv_vget_v_i16m4_i16m1(mins, 3), 16);\n                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,\n                                a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7],\n                                __riscv_vget_v_i16m4_i16m1(mins, 3), 16);\n\n                    const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], 16);\n                    const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], 16);\n                    const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], 16);\n                    const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], 16);\n\n                    sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16);\n                    sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16);\n                    sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16);\n                    sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16);\n\n\n                    // Accumulation for 2 
sub-blocks (slots 0 and 1 of this group of four).\n                    //\n                    // The partial dot products are kept in i16 and flushed into the i32\n                    // accumulators (multiplied by the sub-block scales) after every k step.\n                    // Each k step runs QK4_0 / 2 = 16 inner iterations, and every product\n                    // of a 4-bit weight (0..15) and an int8 activation is at most\n                    // 15 * 128 = 1920 in magnitude, so the i16 sums stay within\n                    // 16 * 1920 = 30720 and cannot overflow.\n                    for (int k = 0; k < 2; k++) {\n                        // 4x16 integer accumulators\n                        vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n\n                        for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {\n                            // Load `b_ptr`.\n                            const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16);\n                            const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));\n                            const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));\n\n                            sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16);\n                            sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16);\n                            sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16);\n                            sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16);\n\n               
             sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16);\n                            sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16);\n                            sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16);\n                            sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16);\n                        }\n\n                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),\n                                    sumi_0_s_0_16, 16);\n                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),\n                                    sumi_0_s_1_16, 16);\n                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),\n                                    sumi_1_s_0_16, 16);\n                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),\n                                    sumi_1_s_1_16, 16);\n                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),\n                                    sumi_2_s_0_16, 16);\n                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),\n                                    sumi_2_s_1_16, 16);\n                        sumi_3 = 
__riscv_vwmacc_vv_i32m2(sumi_3,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),\n                                    sumi_3_s_0_16, 16);\n                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),\n                                    sumi_3_s_1_16, 16);\n                    }\n                    // Accumulation for the other 2 sub-blocks (slots 2 and 3 of this group of four).\n                    //\n                    // The partial dot products are kept in i16 and flushed into the i32\n                    // accumulators (multiplied by the sub-block scales) after every k step.\n                    // Each k step runs QK4_0 / 2 = 16 inner iterations, and every product\n                    // of a 4-bit weight (0..15) and an int8 activation is at most\n                    // 15 * 128 = 1920 in magnitude, so the i16 sums stay within\n                    // 16 * 1920 = 30720 and cannot overflow.\n                    for (int k = 0; k < 2; k++) {\n                        // 4x16 integer accumulators\n                        vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n                        vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);\n\n                        for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {\n                            // Load `b_ptr`.\n                            const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16);\n                            const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));\n                            const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));\n\n                         
   sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16);\n                            sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16);\n                            sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16);\n                            sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16);\n\n                            sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16);\n                            sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16);\n                            sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16);\n                            sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16);\n                        }\n\n                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),\n                                    sumi_0_s_0_16, 16);\n                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),\n                                    sumi_0_s_1_16, 16);\n                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),\n                                    sumi_1_s_0_16, 16);\n                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),\n                               
     sumi_1_s_1_16, 16);\n                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),\n                                    sumi_2_s_0_16, 16);\n                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),\n                                    sumi_2_s_1_16, 16);\n                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),\n                                    sumi_3_s_0_16, 16);\n                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,\n                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),\n                                    sumi_3_s_1_16, 16);\n                    }\n                }\n\n                const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16);\n                const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16);\n                const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16);\n                const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16);\n                const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16);\n\n                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);\n                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);\n                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);\n                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);\n            }\n\n            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * 
bs + x * 16, sumf_0, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);\n        }\n    }\n    return;\n#endif\n    ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v_intrinsic\n    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);\n\n            // 4x16 Accumulators\n            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n\n            for (int l = 0; l < nb; l++) {\n                // 4x16 integer accumulators\n                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16);\n                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16);\n                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16);\n                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16);\n\n                // 
Accumulation loop.\n                for (int i = 0; i < QK4_NL / 2; i++) {\n                    // Load `b_ptr`.\n                    const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);\n                    const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);\n                    const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);\n                    // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);\n                    // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);\n\n                    const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16);\n                    const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16);\n                    const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16);\n                    const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16);\n\n                    const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16);\n                    const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16);\n                    const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16);\n                    const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16);\n\n                    sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16);\n                    sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16);\n                    sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16);\n                    sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, 
__riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16);\n                }\n\n                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);\n                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);\n                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);\n                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);\n                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);\n\n                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);\n                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);\n                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);\n                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);\n            }\n\n            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);\n        }\n    }\n    return;\n#endif\n    ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    
UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined __riscv_v_intrinsic\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);\n\n            // 4x16 Accumulators\n            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);\n\n            for (int l = 0; l < nb; l++) {\n                // 4x16 Integer Accumulators\n                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16);\n                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16);\n                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16);\n                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16);\n\n                // Accumulation loop.\n                for (int i = 0; i < QK8_0; i++) {\n                    // Load `b_ptr`.\n                    const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);\n                    // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16);\n\n                    sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16);\n                    sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16);\n                    sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16);\n                    sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16);\n                }\n\n                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);\n                const 
vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);\n                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);\n                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);\n                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);\n\n                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);\n                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);\n                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);\n                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);\n            }\n\n            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);\n            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);\n        }\n    }\n    return;\n#endif\n    ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    assert(n % QK_K == 0);\n    const int num_k_blocks = n / QK_K;\n    const int N_ROWS_TILE = 4;\n    const int N_COLS_TILE = 16;\n    assert(nr % N_ROWS_TILE == 0);\n    assert(nc % N_COLS_TILE == 0);\n\n    const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE);\n    // --- Tiling Loops ---\n#pragma GCC unroll 1\n    for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) {\n#pragma GCC unroll 1\n        for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) {\n            // Base Pointers\n            const block_q8_Kx4* 
lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks;\n            const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks;\n\n            // Persistent Float Accumulators\n            vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl);\n            vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl);\n            vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl);\n            vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl);\n\n            // --- Super-Block Loop (K=0..255) ---\n#pragma GCC unroll 1\n            for (int k_block = 0; k_block < num_k_blocks; ++k_block) {\n                const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block];\n                const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block];\n\n                // 1. Load Global Min Scales (Keep as F16/LMUL=1 to save registers)\n                vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);\n                vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);\n\n                // 2. Initialize Integer Accumulators\n                vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl);\n                vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl);\n                vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl);\n                vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl);\n\n                const uint8_t* rhs_qs_ptr = rhs_current->qs;\n                const uint8_t* rhs_sc_ptr = rhs_current->scales;\n                const int8_t*  lhs_qs_ptr = lhs_current->qs;\n\n                // --- Phase Loop (4 phases x 64 elements) ---\n#pragma GCC unroll 1\n                for (int phase = 0; phase < 4; ++phase) {\n\n                    // A. 
Load Scales/Mins for the 4 interleaved sub-blocks\n                    vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3;\n                    vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3;\n\n                    // Unrolled Load Logic\n                    {\n                        vuint8mf2_t v_raw;\n                        // Sub-block 0\n                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);\n                        v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);\n                        v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);\n\n                        // Sub-block 1\n                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);\n                        v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);\n                        v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);\n\n                        // Sub-block 2\n                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);\n                        v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);\n                        v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);\n\n                        // Sub-block 3\n                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);\n                        v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);\n                        v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);\n\n                        rhs_sc_ptr += 64;\n                    }\n\n                    int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);\n                    int k_offsets[4] = {0, 32, 64, 96};\n\n                    // B. 
Inner Dot Product Loop\n#pragma GCC unroll 1\n                    for (int l = 0; l < 16; ++l) {\n                        vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl);\n                        rhs_qs_ptr += 16;\n\n                        // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase)\n\n                        // --- Sub-block 0 ---\n                        {\n                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl);\n                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(\n                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),\n                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl);\n\n                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[0] + l) * 4];\n                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);\n                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);\n                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);\n                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);\n                        }\n                        // --- Sub-block 1 ---\n                        {\n                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl);\n                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(\n                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),\n                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl);\n\n                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[1] + l) * 4];\n                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);\n                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, 
(int16_t)q8[1], v_w, vl);\n                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);\n                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);\n                        }\n                        // --- Sub-block 2 ---\n                        {\n                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl);\n                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(\n                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),\n                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl);\n\n                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[2] + l) * 4];\n                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);\n                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);\n                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);\n                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);\n                        }\n                        // --- Sub-block 3 ---\n                        {\n                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl);\n                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(\n                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),\n                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl);\n\n                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[3] + l) * 4];\n                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);\n                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);\n           
                 v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);\n                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);\n                        }\n                    }\n\n                    // C CORRECTION\n                    int sb_base_abs = base_k_phase / 16;\n\n                    // --- Correction Sub-block 0 ---\n                    {\n                        int sb_abs = sb_base_abs + (k_offsets[0] / 16);\n                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0);\n\n                        // Row 0\n                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);\n                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);\n                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);\n\n                        // Row 1\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);\n\n                        // Row 2\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);\n\n                        // Row 3\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], 
vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);\n                    }\n\n                    // --- Correction Sub-block 1 ---\n                    {\n                        int sb_abs = sb_base_abs + (k_offsets[1] / 16);\n                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1);\n\n                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);\n                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);\n                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);\n                        vf_c = 
__riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);\n                    }\n\n                    // --- Correction Sub-block 2 ---\n                    {\n                        int sb_abs = sb_base_abs + (k_offsets[2] / 16);\n                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2);\n\n                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);\n                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);\n                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);\n   
                 }\n\n                    // --- Correction Sub-block 3 ---\n                    {\n                        int sb_abs = sb_base_abs + (k_offsets[3] / 16);\n                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3);\n\n                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);\n                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);\n                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);\n\n                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);\n                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);\n                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);\n                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);\n                    }\n\n                } // End Phase Loop\n\n                // --- Apply Main Scales ---\n                vfloat16m1_t v_g_all_f16 = 
__riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl);\n                vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl);\n\n                {\n                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[0], vl);\n                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl);\n                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);\n                    v_sumf_0 = __riscv_vfadd_vv_f32m2(v_sumf_0, v_sum, vl);\n                }\n                // Row 1\n                {\n                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[1], vl);\n                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl);\n                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);\n                    v_sumf_1 = __riscv_vfadd_vv_f32m2(v_sumf_1, v_sum, vl);\n                }\n                // Row 2\n                {\n                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[2], vl);\n                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl);\n                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);\n                    v_sumf_2 = __riscv_vfadd_vv_f32m2(v_sumf_2, v_sum, vl);\n                }\n                // Row 3\n                {\n                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[3], vl);\n                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl);\n                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);\n                    v_sumf_3 = __riscv_vfadd_vv_f32m2(v_sumf_3, v_sum, vl);\n                }\n\n            } // End K-Block\n\n            __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl);\n            __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl);\n            __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs 
+ col_tile, v_sumf_2, vl);\n            __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl);\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-cpu/arch/s390/cpu-feats.cpp",
    "content": "#include \"ggml-backend-impl.h\"\n\n#if defined(__s390x__)\n#include <sys/auxv.h>\n\n// HWCAP bit positions for s390x, as found in asm/elf.h. Defined here in case\n// the system headers do not provide them under these names.\n#ifndef HWCAP_VXRS_EXT2\n#define HWCAP_VXRS_EXT2 (1UL << 15)\n#endif\n\n#ifndef HWCAP_NNPA\n#define HWCAP_NNPA (1UL << 20)\n#endif\n\n// Host feature flags, read from the auxiliary vector (AT_HWCAP) on construction.\nstruct s390x_features {\n    bool has_vxe2 = false;\n    bool has_nnpa = false;\n\n    s390x_features() {\n        // getauxval() returns unsigned long; keep the full width so that no\n        // capability bit is truncated if higher bits are tested later.\n        const unsigned long hwcap = getauxval(AT_HWCAP);\n        // NOTE: use hwcap2 with DFLT for z17 and later\n        // const unsigned long hwcap2 = getauxval(AT_HWCAP2);\n\n        has_vxe2 = (hwcap & HWCAP_VXRS_EXT2) != 0;\n        has_nnpa = (hwcap & HWCAP_NNPA) != 0;\n    }\n};\n\n// Score this build for the host CPU: 0 if the host lacks a feature the build\n// was compiled to require (GGML_USE_VXE2 / GGML_USE_NNPA), otherwise 1 plus a\n// distinct bit for each such feature.\nstatic int ggml_backend_cpu_s390x_score() {\n    int score = 1;\n    s390x_features sf;\n\n// IBM z15 / LinuxONE 3\n#ifdef GGML_USE_VXE2\n    if (!sf.has_vxe2) { return 0; }\n    score += 1 << 1;\n#endif\n\n// IBM z16 / LinuxONE 4 and z17 / LinuxONE 5\n#ifdef GGML_USE_NNPA\n    if (!sf.has_nnpa) { return 0; }\n    score += 1 << 2;\n#endif\n\n    return score;\n}\n\nGGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_s390x_score)\n\n#endif  // __s390x__\n"
  },
  {
    "path": "src/ggml-cpu/arch/s390/quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n#include \"ggml-quants.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"simd-mappings.h\"\n\n#include \"../../quants.h\"\n#include \"../../ggml-cpu-impl.h\"\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <float.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\n#if defined(__VXE__) || defined(__VXE2__)\n#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s\n#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)\n#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)\n#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)\n#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)\n#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)\n#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)\n#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)\n\n// precomputed tables for expanding 8bits to 8 bytes:\nstatic const __attribute__((aligned(16))) uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b ) << 4\nstatic const __attribute__((aligned(16))) uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4\n\n// permute mask for byteswapping\nstatic const uint8x16_t v_kperm = (const uint8x16_t){\n     7,  6,  5,  4,  3,  2, 1, 0,\n    15, 14, 13, 12, 11, 10, 9, 8\n};\n#endif\n\nvoid quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    for (int i = 0; i < nb; i++) {\n        float32x4_t srcv [8];\n        float32x4_t asrcv[8];\n        float32x4_t amaxv[8];\n\n        for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);\n        for (int j = 0; j < 8; j++) 
asrcv[j] = vec_abs(srcv[j]);\n        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);\n        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);\n        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);\n\n        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),\n                                   vec_extract(amaxv[0], 1)),\n                               MAX(vec_extract(amaxv[0], 2),\n                                   vec_extract(amaxv[0], 3)));\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 1.0f / d : 0.0f;\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        for (int j = 0; j < 8; j++) {\n            const float32x4_t v = vec_mul(srcv[j], vec_splats(id));\n            /* Uses non-default rounding for vec_signed or vec_round */\n            const int32x4_t vi = vec_signed(__builtin_s390_vfisb(v, 4, 1));\n\n            y[i].qs[4*j + 0] = vec_extract(vi, 0);\n            y[i].qs[4*j + 1] = vec_extract(vi, 1);\n            y[i].qs[4*j + 2] = vec_extract(vi, 2);\n            y[i].qs[4*j + 3] = vec_extract(vi, 3);\n        }\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_0_ref(x, y, k);\n#endif\n}\n\nvoid quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK8_1 == 0);\n    const int nb = k / QK8_1;\n\n    block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    for (int i = 0; i < nb; i++) {\n        float32x4_t srcv [8];\n        float32x4_t asrcv[8];\n        float32x4_t amaxv[8];\n\n        for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);\n        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);\n        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);\n        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);\n        for (int j = 0; j < 1; j++) amaxv[8*j] = 
vec_max(amaxv[8*j], amaxv[8*j+4]);\n\n        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),\n                                   vec_extract(amaxv[0], 1)),\n                               MAX(vec_extract(amaxv[0], 2),\n                                   vec_extract(amaxv[0], 3)));\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 1.0f / d : 0.0f;\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        int32x4_t acc = vec_splats(0);\n\n        for (int j = 0; j < 8; j++) {\n            const float32x4_t v = vec_mul(srcv[j], vec_splats(id));\n            /* Uses non-default rounding for vec_signed or vec_round */\n            const int32x4_t vi = vec_signed(__builtin_s390_vfisb(v, 4, 1));\n\n            y[i].qs[4*j + 0] = vec_extract(vi, 0);\n            y[i].qs[4*j + 1] = vec_extract(vi, 1);\n            y[i].qs[4*j + 2] = vec_extract(vi, 2);\n            y[i].qs[4*j + 3] = vec_extract(vi, 3);\n\n            acc = vec_add(acc, vi);\n        }\n\n        y[i].s = GGML_CPU_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3]));\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_1_ref(x, y, k);\n#endif\n}\n\n\n//===================================== Dot products =================================\n\nvoid ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    float32x4_t acc = vec_splats(0.0f);\n\n    const uint8x16_t v_m = vec_splats((const uint8_t)0x0F);\n    const int8x16_t  v_s = vec_splats( (const int8_t)0x08);\n\n    for (; ib < nb; ++ib) {\n        
const uint8x16_t v_x = vec_xl(0, x[ib].qs);\n        const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);\n        const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);\n\n        const int8x16_t v_xls = vec_sub(v_xl, v_s);\n        const int8x16_t v_xhs = vec_sub(v_xh, v_s);\n\n        const int8x16_t v_yl = vec_xl(0      , y[ib].qs);\n        const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);\n\n        const int16x8_t v_xylso = vec_mulo(v_xls, v_yl);\n        const int16x8_t v_xyl = vec_meadd(v_xls, v_yl, v_xylso);\n        const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh);\n        const int16x8_t v_xyh = vec_meadd(v_xhs, v_yh, v_xyhso);\n\n        int16x8_t v_xy_ = v_xyl + v_xyh; v_xy_ += vec_reve(v_xy_);\n\n        const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_));\n        const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));\n\n        acc = vec_madd(v_xy, v_d, acc);\n    }\n\n    sumf = vec_hsum_f32x4(acc);\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    float summs = 0;\n    float32x4_t acc = vec_splats(0.0f);\n\n    const uint8x16_t v_m = vec_splat_u8(0x0F);\n\n#pragma GCC unroll 4\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * 
GGML_CPU_FP16_TO_FP32(y[ib].s);\n\n        const uint8x16_t v_x = vec_xl(0, x[ib].qs);\n        const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);\n        const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);\n\n        const int8x16_t v_yl = vec_xl(0      , y[ib].qs);\n        const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs);\n\n        const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);\n        const float32x4_t v_xy = vec_float(v_xy_);\n\n        const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));\n\n        acc = vec_madd(v_xy, v_d, acc);\n    }\n\n    sumf = vec_hsum_f32x4(acc) + summs;\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_MXFP4 == 0);\n    static_assert(QK_MXFP4 == QK8_0, \"QK_MXFP4 and QK8_0 must be the same\");\n\n    const int qk = QK_MXFP4;\n    const int nb = n / qk;\n\n    const block_mxfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0  * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0.0f;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    const int8x16_t  v_k = vec_xl(0, kvalues_mxfp4);\n    const uint8x16_t v_m = vec_splats((const uint8_t)0x0F);\n\n    float32x4_t v_acc = vec_splats(0.0f);\n\n    #pragma GCC unroll 8\n    for (; ib + 1 < nb; ib += 2) {\n        const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];\n        const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_0  * GGML_RESTRICT y0 = &y[ib + 0];\n        const block_q8_0  * GGML_RESTRICT y1 = &y[ib + 1];\n\n        const uint8x16_t v_x0 = 
vec_xl(0, x0->qs);\n        const uint8x16_t v_x1 = vec_xl(0, x1->qs);\n\n        int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);\n        int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);\n        int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);\n        int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);\n\n        v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);\n        v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);\n        v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);\n        v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);\n\n        const int8x16_t v_y0l = vec_xl(0,       y0->qs);\n        const int8x16_t v_y0h = vec_xl(QK8_0/2, y0->qs);\n        const int8x16_t v_y1l = vec_xl(0,       y1->qs);\n        const int8x16_t v_y1h = vec_xl(QK8_0/2, y1->qs);\n\n        const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0l), v_x0h, v_y0h);\n        const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y1l), v_x1h, v_y1h);\n\n        const float32x4_t v_xy0f = vec_float(v_xy0);\n        const float32x4_t v_xy1f = vec_float(v_xy1);\n\n        const float32x4_t v_d0 = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d));\n        const float32x4_t v_d1 = vec_splats(GGML_E8M0_TO_FP32_HALF(x1->e) * GGML_CPU_FP16_TO_FP32(y1->d));\n\n        v_acc = vec_madd(v_xy0f, v_d0, v_acc);\n        v_acc = vec_madd(v_xy1f, v_d1, v_acc);\n    }\n\n    for (; ib < nb; ++ib) {\n        const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];\n        const block_q8_0  * GGML_RESTRICT y0 = &y[ib + 0];\n\n        const uint8x16_t v_x = vec_xl(0, x0->qs);\n\n        int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);\n        int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);\n\n        v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);\n        v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);\n\n        const int8x16_t v_yl = vec_xl(0,       y0->qs);\n        const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);\n\n        const int32x4_t v_xy = 
ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);\n        const float32x4_t v_xyf = vec_float(v_xy);\n\n        const float32x4_t v_d = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d));\n        v_acc = vec_madd(v_xyf, v_d, v_acc);\n    }\n\n    sumf = vec_hsum_f32x4(v_acc);\n    *s = sumf;\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0.0f;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    float32x4_t v_sum0 = vec_splats(0.0f);\n    float32x4_t v_sum1 = vec_splats(0.0f);\n\n    uint32_t qh0, qh1;\n    uint64_t tmp0[4], tmp1[4];\n\n    const uint8x16_t v_m = vec_splats((uint8_t)0x0F);\n\n    #pragma GCC unroll 4\n    for (; ib + 1 < nb; ib += 2) {\n        const block_q5_0 * GGML_RESTRICT x0 = &x[ib + 0];\n        const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];\n        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n        memcpy(&qh0, x0->qh, sizeof(qh0));\n        memcpy(&qh1, x1->qh, sizeof(qh1));\n\n        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];\n        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];\n        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];\n        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];\n\n        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];\n        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];\n        tmp1[2] = table_b2b_1[(qh1 >> 16) & 
0xFF];\n        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];\n\n        int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));\n        int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));\n        int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));\n        int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));\n\n        // required for fixing the byteorder\n        v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);\n        v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);\n        v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);\n        v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);\n\n        const uint8x16_t v_x0 = vec_xl(0, (const uint8_t *)x0->qs);\n        const uint8x16_t v_x1 = vec_xl(0, (const uint8_t *)x1->qs);\n\n        int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);\n        int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);\n        int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);\n        int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);\n\n        const int8x16_t v_x0lf = vec_sub(v_x0l, v_qh0l);\n        const int8x16_t v_x0hf = vec_sub(v_x0h, v_qh0h);\n        const int8x16_t v_x1lf = vec_sub(v_x1l, v_qh1l);\n        const int8x16_t v_x1hf = vec_sub(v_x1h, v_qh1h);\n\n        const int8x16_t v_y0l = vec_xl(0,       (const int8_t *)y0->qs);\n        const int8x16_t v_y0h = vec_xl(QK8_0/2, (const int8_t *)y0->qs);\n        const int8x16_t v_y1l = vec_xl(0,       (const int8_t *)y1->qs);\n        const int8x16_t v_y1h = vec_xl(QK8_0/2, (const int8_t *)y1->qs);\n\n        const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);\n        const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);\n\n        const float32x4_t v_xy0f = vec_float(v_xy0);\n        const float32x4_t v_xy1f = vec_float(v_xy1);\n\n        const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));\n        const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * 
GGML_CPU_FP16_TO_FP32(y1->d));\n\n        v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);\n        v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);\n    }\n\n    sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1);\n\n    #pragma GCC unroll 4\n    for (; ib < nb; ++ib) {\n        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];\n        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];\n\n        uint32_t qh;\n        memcpy(&qh, x0->qh, sizeof(qh));\n\n        uint64_t tmp[4];\n        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];\n        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];\n        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];\n        tmp[3] = table_b2b_1[(qh >> 24)       ];\n\n        int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));\n        int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));\n\n        // required for fixing the byteorder\n        v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);\n        v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);\n\n        const uint8x16_t v_x = vec_xl(0, (const uint8_t *)x0->qs);\n        int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);\n        int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);\n\n        const int8x16_t v_xlf = vec_sub(v_xl, v_qhl);\n        const int8x16_t v_xhf = vec_sub(v_xh, v_qhh);\n\n        const int8x16_t v_yl = vec_xl(0,       (const int8_t *)y0->qs);\n        const int8x16_t v_yh = vec_xl(QK8_0/2, (const int8_t *)y0->qs);\n\n        const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);\n        const float32x4_t v_xyf = vec_float(v_xy);\n\n        const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));\n        const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));\n\n        sumf += vec_hsum_f32x4(v_acc);\n    }\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_1_q8_1(int n, 
float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_1);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0.0f;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    float32x4_t v_sum0 = vec_splats(0.0f);\n    float32x4_t v_sum1 = vec_splats(0.0f);\n\n    float summs0 = 0.0f;\n    float summs1 = 0.0f;\n\n    uint32_t qh0;\n    uint32_t qh1;\n\n    uint64_t tmp0[4];\n    uint64_t tmp1[4];\n\n    const uint8x16_t v_m = vec_splats((uint8_t)0x0F);\n\n    #pragma GCC unroll 4\n    for (; ib + 1 < nb; ib += 2) {\n        const block_q5_1 * GGML_RESTRICT x0 = &x[ib + 0];\n        const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];\n        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];\n\n        summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);\n        summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);\n\n        memcpy(&qh0, x0->qh, sizeof(qh0));\n        memcpy(&qh1, x1->qh, sizeof(qh1));\n\n        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];\n        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];\n        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];\n        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];\n\n        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];\n        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];\n        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];\n        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];\n\n        int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));\n        int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));\n        int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));\n        int8x16_t 
v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));\n\n        // required for fixing the byteorder\n        v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);\n        v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);\n        v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);\n        v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);\n\n        const uint8x16_t v_x0 = vec_xl(0, x0->qs);\n        const uint8x16_t v_x1 = vec_xl(0, x1->qs);\n\n        const int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);\n        const int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);\n        const int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);\n        const int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);\n\n        const int8x16_t v_x0lf = vec_or(v_x0l, v_qh0l);\n        const int8x16_t v_x0hf = vec_or(v_x0h, v_qh0h);\n        const int8x16_t v_x1lf = vec_or(v_x1l, v_qh1l);\n        const int8x16_t v_x1hf = vec_or(v_x1h, v_qh1h);\n\n        const int8x16_t v_y0l = vec_xl(0      , y0->qs);\n        const int8x16_t v_y0h = vec_xl(QK8_1/2, y0->qs);\n        const int8x16_t v_y1l = vec_xl(0      , y1->qs);\n        const int8x16_t v_y1h = vec_xl(QK8_1/2, y1->qs);\n\n        const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);\n        const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);\n\n        const float32x4_t v_xy0f = vec_float(v_xy0);\n        const float32x4_t v_xy1f = vec_float(v_xy1);\n\n        const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));\n        const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));\n\n        v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);\n        v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);\n    }\n\n    sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1) + summs0 + summs1;\n\n    #pragma GCC unroll 4\n    for (; ib < nb; ++ib) {\n        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];\n        const block_q8_1 
* GGML_RESTRICT y0 = &y[ib];\n\n        float summs = GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);\n\n        uint32_t qh;\n        memcpy(&qh, x0->qh, sizeof(qh));\n\n        uint64_t tmp[4];\n        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];\n        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];\n        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];\n        tmp[3] = table_b2b_0[(qh >> 24)       ];\n\n        int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));\n        int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));\n\n        // required for fixing the byteorder\n        v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);\n        v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);\n\n        const uint8x16_t v_x = vec_xl(0, x0->qs);\n        const int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);\n        const int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);\n\n        const int8x16_t v_xlf = vec_or(v_xl, v_qhl);\n        const int8x16_t v_xhf = vec_or(v_xh, v_qhh);\n\n        const int8x16_t v_yl = vec_xl(0      , y0->qs);\n        const int8x16_t v_yh = vec_xl(QK8_1/2, y0->qs);\n\n        const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);\n        const float32x4_t v_xyf = vec_float(v_xy);\n\n        const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));\n        const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc);\n\n        sumf += vec_hsum_f32x4(v_acc) + summs;\n    }\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    
UNUSED(bs);\n\n    const block_q8_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    float32x4_t acc = vec_splats(0.0f);\n\n#pragma GCC unroll 8\n    for (; ib < nb; ++ib) {\n        __builtin_prefetch(x[ib].qs, 0, 1);\n        __builtin_prefetch(y[ib].qs, 0, 1);\n\n        const int8x16_t v_xl = vec_xl(0      , x[ib].qs);\n        const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs);\n        const int8x16_t v_yl = vec_xl(0      , y[ib].qs);\n        const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);\n\n        const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);\n        const float32x4_t v_xy = vec_float(v_xy_);\n        const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));\n\n        acc = vec_madd(v_xy, v_d, acc);\n    }\n\n    sumf = vec_hsum_f32x4(acc);\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    const block_q3_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    uint32_t aux[3];\n    uint32_t utmp[4];\n\n    const int32x4_t v_z = vec_splat_s32(0);\n    const uint8x16_t v_3m = vec_splat_u8(0x03);\n\n    const uint8x16_t v_0c = vec_splat_u8(1);\n    const uint8x16_t v_1c = vec_sl(v_0c, 1);\n    const uint8x16_t v_2c = vec_sl(v_0c, 2);\n    const uint8x16_t 
v_3c = vec_sl(v_0c, 3);\n\n    uint8x16_t q3h[4];\n    uint8x16_t q3b[2];\n    int8x16_t q3bytes[4];\n    int8x16_t q8bytes[8];\n    uint8x16_t qhbits[2];\n\n    float sum = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * restrict x0l = x[i].qs;\n        const uint8_t * restrict x0h = x[i].hmask;\n        const int8_t  * restrict y0  = y[i].qs;\n\n        qhbits[0] = vec_xl(0 , x0h);\n        qhbits[1] = vec_xl(16, x0h);\n\n        int32_t isum = 0;\n\n        memcpy(aux, x[i].scales, 12);\n        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);\n        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);\n        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);\n        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);\n\n        int8_t * scale = (int8_t *)utmp;\n        for (int j = 0; j < 16; ++j) scale[j] -= 32;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            int32x4_t isum0, isum1, isum2, isum3;\n\n            q3b[0] = vec_xl(0 , x0l);\n            q3b[1] = vec_xl(16, x0l);\n            x0l += 32;\n\n            q8bytes[0] = vec_xl(0  , y0);\n            q8bytes[1] = vec_xl(16 , y0);\n            q8bytes[2] = vec_xl(32 , y0);\n            q8bytes[3] = vec_xl(48 , y0);\n            q8bytes[4] = vec_xl(64 , y0);\n            q8bytes[5] = vec_xl(80 , y0);\n            q8bytes[6] = vec_xl(96 , y0);\n            q8bytes[7] = vec_xl(112, y0);\n            y0 += 128;\n\n            q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);\n            q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);\n            q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);\n            q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);\n\n            q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);\n            q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);\n            q3bytes[2] = 
vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);\n            q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);\n\n            isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);\n            isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);\n            isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);\n            isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);\n\n            isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];\n            isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];\n            isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];\n            isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];\n\n            scale += 4;\n\n            q3h[0] = vec_andc(v_2c, qhbits[0]);\n            q3h[1] = vec_andc(v_2c, qhbits[1]);\n            q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);\n            q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);\n\n            q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);\n            q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);\n            q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);\n            q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);\n\n            isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);\n            isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);\n            isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);\n            isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);\n\n            isum += vec_hsum_i32x4(isum0) * scale[0];\n            isum += vec_hsum_i32x4(isum1) * scale[1];\n            isum += vec_hsum_i32x4(isum2) * scale[2];\n            isum += vec_hsum_i32x4(isum3) * scale[3];\n\n            scale += 4;\n\n            if (j == 0) {\n                qhbits[0] = vec_sr(qhbits[0], 4);\n                qhbits[1] 
= vec_sr(qhbits[1], 4);\n            }\n        }\n\n        sum += d * isum;\n    }\n\n    *s = sum;\n\n#else\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined(__VXE__) || defined(__VXE2__)\n    const uint8x16_t v_lm = vec_splat_u8(0x0F);\n    const int32x4_t v_z = vec_splat_s32(0);\n\n    uint8x16_t v_x[2];\n    int8x16_t  v_xl[2];\n    int8x16_t  v_y[2];\n\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);\n        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);\n        const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);\n\n        memcpy(utmp, x[i].scales, 12);\n\n        uint32x4_t v_mins8 = { 0 };\n        v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);\n        v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);\n\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[0] &= kmask1;\n\n        const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);\n\n        const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);\n        const int32x4_t v_mins 
= vec_meadd(v_ysums, v_minsh, v_minso);\n        sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);\n\n        const uint8_t * scales = (const uint8_t *)utmp;\n        const uint8_t * GGML_RESTRICT x0 = x[i].qs;\n        const int8_t  * GGML_RESTRICT y0 = y[i].qs;\n\n        int32_t sumi1 = 0;\n        int32_t sumi2 = 0;\n\n        for (int j = 0; j < QK_K/64; ++j) {\n            v_x[0] = vec_xl(0 , x0);\n            v_x[1] = vec_xl(16, x0);\n            x0 += 32;\n\n            v_y[0] = vec_xl(0 , y0);\n            v_y[1] = vec_xl(16, y0);\n            y0 += 32;\n\n            v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);\n            v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);\n\n            const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);\n            sumi1 += vec_hsum_i32x4(p1) * scales[2*j+0];\n\n            v_y[0] = vec_xl(0 , y0);\n            v_y[1] = vec_xl(16, y0);\n            y0 += 32;\n\n            v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);\n            v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);\n\n            const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);\n            sumi2 += vec_hsum_i32x4(p2) * scales[2*j+1];\n        }\n\n        sumf += d * (sumi1 + sumi2);\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    
static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined(__VXE__) || defined(__VXE2__)\n    const uint8x16_t v_lm = vec_splat_u8(0x0F);\n    const uint8x16_t v_1m = vec_splat_u8(0x01);\n    const uint8x16_t v_2m = vec_splat_u8(0x02);\n\n    const int32x4_t v_z = vec_splat_s32(0);\n\n    const uchar8x16_t v_minsm = {\n        0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,\n        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF\n    };\n\n    int8x16_t  q5b[4];\n    uint8x16_t q5h[4];\n\n    uint8x16_t v_xl[2];\n    uint8x16_t v_xh[2];\n    int8x16_t  v_y[4];\n\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);\n        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);\n        const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);\n\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);\n        const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);\n        const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);\n\n        const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);\n        const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minsho);\n        const int32_t mins = vec_hsum_i32x4(v_mins);\n\n        const uint8_t * scales = (const uint8_t *)utmp;\n        const uint8_t * GGML_RESTRICT x0l = x[i].qs;\n        const uint8_t * GGML_RESTRICT x0h = x[i].qh;\n        const int8_t  * GGML_RESTRICT y0 = y[i].qs;\n\n        v_xh[0] = vec_xl(0 , x0h);\n      
  v_xh[1] = vec_xl(16, x0h);\n\n        int32_t sumi = 0;\n        for (int j = 0; j < QK_K/64; ++j) {\n            v_xl[0] = vec_xl(0 , x0l);\n            v_xl[1] = vec_xl(16, x0l);\n            x0l += 32;\n\n            v_y[0] = vec_xl(0 , y0);\n            v_y[1] = vec_xl(16, y0);\n            v_y[2] = vec_xl(32, y0);\n            v_y[3] = vec_xl(48, y0);\n            y0 += 64;\n\n            q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);\n            q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);\n            q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);\n            q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);\n            v_xh[0] = vec_sr(v_xh[0], 2);\n            v_xh[1] = vec_sr(v_xh[1], 2);\n\n            q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);\n            q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);\n            q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);\n            q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);\n\n            int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);\n            int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);\n\n            sumi += vec_hsum_i32x4(sumi0) * *scales++;\n            sumi += vec_hsum_i32x4(sumi1) * *scales++;\n        }\n\n        sumf += d * sumi - dmin * mins;\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q6_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n 
/ QK_K;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    float sum = 0;\n\n    // Lower 4-bit and upper 2-bit masks\n    const uint8x16_t v_lm = vec_splat_u8(0x0F);\n    const uint8x16_t v_um = vec_splat_u8(0x03);\n\n    const int32x4_t v_z = vec_splat_s32(0);\n\n    int8x16_t  q6b[4];\n    uint8x16_t q6h[4];\n\n    uint8x16_t v_xl[4];\n    uint8x16_t v_xh[2];\n    int8x16_t  v_y[4];\n\n    for (int i = 0; i < nb; ++i) {\n        const float d_all = GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT x0l = x[i].ql;\n        const uint8_t * GGML_RESTRICT x0h = x[i].qh;\n        const int8_t  * GGML_RESTRICT y0 = y[i].qs;\n\n        const int8_t  * GGML_RESTRICT scale = x[i].scales;\n\n        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);\n        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);\n\n        const int8x16_t v_scale  = vec_xl(0, scale);\n        const int16x8_t v_scalel = vec_unpackh(v_scale);\n        const int16x8_t v_scaleh = vec_unpackl(v_scale);\n\n        const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);\n        const int32x4_t v_minsl = vec_meadd(v_ysumsl, v_scalel, v_minslo);\n        const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);\n        const int32x4_t v_minsh = vec_meadd(v_ysumsh, v_scaleh, v_minsho);\n        const int32x4_t v_mins = vec_add(v_minsl, v_minsh);\n\n        const int32_t mins = vec_hsum_i32x4(v_mins);\n\n        int32_t isum = 0;\n        for (int j = 0; j < QK_K/128; ++j) {\n            // Load model upper 2 bits\n            v_xh[0] = vec_xl(0 , x0h);\n            v_xh[1] = vec_xl(16, x0h);\n            x0h += 32;\n\n            // Load model lower 4 bits\n            v_xl[0] = vec_xl(0 , x0l);\n            v_xl[1] = vec_xl(16, x0l);\n            v_xl[2] = vec_xl(32, x0l);\n            v_xl[3] = vec_xl(48, x0l);\n            x0l += 64;\n\n            // Load activation quants\n            v_y[0] = vec_xl(0 , y0);\n            v_y[1] = vec_xl(16, y0);\n            v_y[2] = 
vec_xl(32, y0);\n            v_y[3] = vec_xl(48, y0);\n            y0 += 64;\n\n            q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);\n            q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);\n            uint8x16_t shifted = vec_sr(v_xh[0], 2);\n            q6h[2] = vec_sl(vec_and(v_um, shifted), 4);\n            shifted = vec_sr(v_xh[1], 2);\n            q6h[3] = vec_sl(vec_and(v_um, shifted), 4);\n\n            q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));\n            q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));\n            q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));\n            q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));\n\n            int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);\n            int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);\n            int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);\n            int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);\n\n            isum += vec_hsum_i32x4(summs0) * scale[0] +\n                    vec_hsum_i32x4(summs1) * scale[1] +\n                    vec_hsum_i32x4(summs2) * scale[2] +\n                    vec_hsum_i32x4(summs3) * scale[3];\n\n            scale += 4;\n\n\n            // Load activation quants\n            v_y[0] = vec_xl(0 , y0);\n            v_y[1] = vec_xl(16, y0);\n            v_y[2] = vec_xl(32, y0);\n            v_y[3] = vec_xl(48, y0);\n            y0 += 64;\n\n            shifted = vec_sr(v_xh[0], 4);\n            q6h[0] = vec_sl(vec_and(v_um, shifted), 4);\n            shifted = vec_sr(v_xh[1], 4);\n            q6h[1] = vec_sl(vec_and(v_um, shifted), 4);\n            shifted = vec_sr(v_xh[0], 6);\n            q6h[2] = vec_sl(vec_and(v_um, shifted), 4);\n            shifted = vec_sr(v_xh[1], 6);\n            q6h[3] = vec_sl(vec_and(v_um, shifted), 4);\n\n            q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));\n            q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));\n   
         q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));\n            q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));\n\n            summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);\n            summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);\n            summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);\n            summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);\n\n            isum += vec_hsum_i32x4(summs0) * scale[0] +\n                    vec_hsum_i32x4(summs1) * scale[1] +\n                    vec_hsum_i32x4(summs2) * scale[2] +\n                    vec_hsum_i32x4(summs3) * scale[3];\n\n            scale += 4;\n        }\n\n        sum += d_all * y[i].d * (isum - 32 * mins);\n    }\n\n    *s = sum;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n// #if defined(__VXE__) || defined(__VXE2__)\n// static const int8_t keven_signs_q2xs[1024] = {\n//      1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,\n//      1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,\n//      1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,\n//      1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,\n//      1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,\n//      1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,\n//      1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,\n//      1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  
1,  1, -1, -1, -1, -1, -1,  1,  1, -1,\n//      1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,\n//      1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,\n//      1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,\n//      1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,\n//      1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,\n//      1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,\n//      1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,\n//      1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,\n//      1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,\n//      1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,\n//      1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,\n//      1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,\n//      1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,\n//      1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,\n//      1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, 
-1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,\n//      1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,\n//      1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,\n//      1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,\n//      1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,\n//      1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,\n//      1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,\n//      1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,\n//      1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,\n//      1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,\n// };\n// #endif\n\n// void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n//     assert(n % QK_K == 0);\n//     assert(nrc == 1);\n//     UNUSED(nrc);\n//     UNUSED(bx);\n//     UNUSED(by);\n//     UNUSED(bs);\n\n//     const block_iq2_xxs * GGML_RESTRICT x = vx;\n//     const block_q8_K    * GGML_RESTRICT y = vy;\n\n//     const int nb = n / QK_K;\n\n// #if defined(__VXE__) || defined(__VXE2__)\n//    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n//    uint32_t aux32[4];\n//    const uint8_t * aux8 = (const uint8_t *)aux32;\n\n//    float sumf = 
0;\n\n//    for (int i = 0; i < nb; ++i) {\n//        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n//        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n//        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n\n//        float sumf1 = 0, sumf2 = 0;\n\n//        for (int ib32 = 0; ib32 < QK_K/32; ib += 2) {\n//            int8x16_t q8b0 = vec_xl( 0, q8);\n//            int8x16_t qb81 = vec_xl(16, q8);\n//            int8x16_t q8b2 = vec_xl(32, q8);\n//            int8x16_t q8b3 = vec_xl(48, q8);\n//            q8 += 64;\n\n//            memcpy(aux32, q2, 4 * sizeof(uint32_t));\n//            q2 += 8;\n\n//            int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };\n//            int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };\n//            int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };\n//            int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };\n\n//            int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >>  7) & 127)) };\n//            int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };\n//            int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >>  7) & 127)) };\n//            int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };\n\n//            q2u0 = vec_mul(q2u0, q2s0);\n//            q2u1 = vec_mul(q2u1, q2s1);\n//            q2u2 = vec_mul(q2u2, q2s2);\n//            q2u3 = vec_mul(q2u3, q2s3);\n\n//            const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);\n//  
          const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);\n\n//            sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));\n//            sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));\n//        }\n\n//        sumf += d * (sumf1 + sumf2);\n//    }\n\n//    *s = 0.25f * sumf;\n\n// #else\n\n//     uint32_t aux32[2];\n//     const uint8_t * aux8 = (const uint8_t *)aux32;\n\n//     float sumf = 0.f;\n//     for (int i = 0; i < nb; ++i) {\n//         const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n//         const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n//         const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n//         int32_t bsum = 0;\n//         for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {\n//             memcpy(aux32, q2, 2*sizeof(uint32_t));\n//             q2 += 4;\n//             const uint32_t ls = 2*(aux32[1] >> 28) + 1;\n//             int32_t sumi = 0;\n//             for (int l = 0; l < 4; ++l) {\n//                 const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);\n//                 const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];\n//                 for (int j = 0; j < 8; ++j) {\n//                     sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? 
-1 : 1);\n//                 }\n//                 q8 += 8;\n//             }\n//             bsum += sumi * ls;\n//         }\n//         sumf += d * bsum;\n//     }\n//     *s = 0.125f * sumf;\n// #endif\n// }\n\nvoid ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK4_NL == 0);\n    static_assert(QK4_NL == QK8_0, \"QK4_NL and QK8_0 must be the same\");\n\n    const block_iq4_nl * GGML_RESTRICT x = vx;\n    const block_q8_0   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK4_NL;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);\n    const uint8x16_t v_m = vec_splat_u8(0x0F);\n\n    for (; ib < nb; ++ib) {\n        const block_iq4_nl * GGML_RESTRICT x0 = &x[ib];\n        const block_q8_0   * GGML_RESTRICT y0 = &y[ib];\n\n        const uint8x16_t v_x = vec_xl(0, x0->qs);\n        int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);\n        int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);\n\n        v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);\n        v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);\n\n        const int8x16_t v_yl = vec_xl(0      , y0->qs);\n        const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);\n        const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);\n\n        sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy);\n    }\n\n    *s = sumf;\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int 
nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_K == 0);\n\n    const block_iq4_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__VXE__) || defined(__VXE2__)\n    const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);\n    const uint8x16_t v_m = vec_splat_u8(0x0F);\n\n    float sumf = 0;\n\n    for (int ibl = 0; ibl < nb; ++ibl) {\n        const uint8_t * GGML_RESTRICT q4 = x[ibl].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[ibl].qs;\n\n        uint16_t h = x[ibl].scales_h;\n\n        int sumi1 = 0, sumi2 = 0;\n        for (int ib = 0; ib < QK_K/64; ++ib) {\n            const uint8x16_t v_x0 = vec_xl(0       , q4);\n            const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);\n            q4 += 32;\n\n            int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);\n            int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);\n            int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);\n            int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);\n\n            v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);\n            v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);\n            v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);\n            v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);\n\n            const int8x16_t v_y0 = vec_xl( 0, q8);\n            const int8x16_t v_y1 = vec_xl(16, q8);\n            const int8x16_t v_y2 = vec_xl(32, q8);\n            const int8x16_t v_y3 = vec_xl(48, q8);\n            q8 += 64;\n\n            int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);\n            int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);\n\n            int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;\n            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;\n\n            h >>= 4;\n\n            sumi1 += 
vec_hsum_i32x4(vsumi0) * ls1;\n            sumi2 += vec_hsum_i32x4(vsumi1) * ls2;\n        }\n\n        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n"
  },
  {
    "path": "src/ggml-cpu/arch/wasm/quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n#include \"ggml-quants.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"simd-mappings.h\"\n\n#include \"../../quants.h\"\n#include \"../../ggml-cpu-impl.h\"\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <float.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\n#if defined(__wasm_simd128__)\n#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s\n#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)\n#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)\n#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)\n#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)\n#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)\n#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)\n#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)\n\n// precomputed tables for expanding 8bits to 8 bytes:\nstatic const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4\nstatic const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4\n#endif\n\nvoid quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined __wasm_simd128__\n    for (int i = 0; i < nb; i++) {\n        v128_t srcv [8];\n        v128_t asrcv[8];\n        v128_t amaxv[8];\n\n        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);\n        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);\n\n        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);\n        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);\n        for (int j = 0; j < 1; 
j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);\n\n        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),\n                                   wasm_f32x4_extract_lane(amaxv[0], 1)),\n                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),\n                                   wasm_f32x4_extract_lane(amaxv[0], 3)));\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 1.0f/d : 0.0f;\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        for (int j = 0; j < 8; j++) {\n            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));\n            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);\n\n            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);\n            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);\n            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);\n            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);\n        }\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_0_ref(x, y, k);\n#endif\n}\n\nvoid quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK8_1 == 0);\n    const int nb = k / QK8_1;\n\n    block_q8_1 * GGML_RESTRICT y = vy;\n#if defined __wasm_simd128__\n    for (int i = 0; i < nb; i++) {\n        v128_t srcv [8];\n        v128_t asrcv[8];\n        v128_t amaxv[8];\n\n        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);\n        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);\n\n        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);\n        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);\n        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);\n\n        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),\n                                   wasm_f32x4_extract_lane(amaxv[0], 1)),\n                       
        MAX(wasm_f32x4_extract_lane(amaxv[0], 2),\n                                   wasm_f32x4_extract_lane(amaxv[0], 3)));\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 1.0f/d : 0.0f;\n\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n\n        v128_t accv = wasm_i32x4_splat(0);\n\n        for (int j = 0; j < 8; j++) {\n            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));\n            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);\n\n            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);\n            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);\n            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);\n            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);\n\n            accv = wasm_i32x4_add(accv, vi);\n        }\n\n        y[i].s = GGML_CPU_FP32_TO_FP16(\n                d * (wasm_i32x4_extract_lane(accv, 0) +\n                     wasm_i32x4_extract_lane(accv, 1) +\n                     wasm_i32x4_extract_lane(accv, 2) +\n                     wasm_i32x4_extract_lane(accv, 3)));\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_1_ref(x, y, k);\n#endif\n}\n\n//===================================== Q8_K ==============================================\n\nvoid quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n#ifdef __wasm_simd128__\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n    block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type\n\n    for (int i = 0; i < nb; i++) {\n        const float * x_block = x + i * QK_K;\n\n        v128_t min_vec = wasm_v128_load(x_block);\n        v128_t max_vec = min_vec;\n\n        for (int j = 4; j < QK_K; j += 4) {\n            v128_t x_vec = wasm_v128_load(x_block + j);\n            max_vec = wasm_f32x4_pmax(max_vec, x_vec);\n            min_vec = wasm_f32x4_pmin(min_vec, x_vec);\n        }\n        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, 
max_vec, 2, 3, 0, 1));\n        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));\n        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));\n        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));\n        float max = wasm_f32x4_extract_lane(max_vec, 0);\n        float min = wasm_f32x4_extract_lane(min_vec, 0);\n        float amax = -min > max ? min : max;\n\n        if (amax == 0.0f) {\n            yc[i].d = 0.0f;\n            const v128_t zero = wasm_i8x16_splat(0);\n            for (int j = 0; j < QK_K; j += 16) {\n                wasm_v128_store(yc[i].qs + j, zero);\n            }\n            continue;\n        }\n\n        const float iscale = -127.0f / amax;\n        const v128_t scale_vec = wasm_f32x4_splat(iscale);\n\n        // Process 16 elements per iteration\n        for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {\n            // Load and quantize 16 floats\n            v128_t x0 = wasm_v128_load(x_block + j);\n            v128_t x1 = wasm_v128_load(x_block + j + 4);\n            v128_t x2 = wasm_v128_load(x_block + j + 8);\n            v128_t x3 = wasm_v128_load(x_block + j + 12);\n\n            v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));\n            v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));\n            v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));\n            v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));\n\n            // Convert to i32 with saturation\n            v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);\n            v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);\n            v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);\n            v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);\n\n            // Pack into 16 i8 values\n            v128_t i8 = wasm_i8x16_narrow_i16x8(\n                wasm_i16x8_narrow_i32x4(i0, i1),\n                wasm_i16x8_narrow_i32x4(i2, 
i3)\n            );\n            wasm_v128_store(yc[i].qs + j, i8);\n\n            // Calculate bsums using SIMD\n            v128_t sum16 = wasm_i16x8_add(\n                wasm_i16x8_extend_low_i8x16(i8),\n                wasm_i16x8_extend_high_i8x16(i8)\n            );\n            v128_t sum32 = wasm_i32x4_add(\n                wasm_i32x4_extend_low_i16x8(sum16),\n                wasm_i32x4_extend_high_i16x8(sum16)\n            );\n            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));\n            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));\n            yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);\n        }\n\n        yc[i].d = 1.0f / iscale;\n    }\n#else\n    quantize_row_q8_K_ref(x, y, k);\n#endif\n}\n\n\n//===================================== Dot products =================================\n\nvoid ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined __wasm_simd128__\n    v128_t sumv = wasm_f32x4_splat(0.0f);\n\n    const v128_t m4b = wasm_i8x16_splat(0x0F);\n    const v128_t s8b = wasm_i8x16_splat(0x8);\n\n    for (; ib + 1 < nb; ib += 2) {\n        const block_q4_0 * GGML_RESTRICT x0 = &x[ib];\n        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];\n        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];\n        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];\n\n        // Load and process x0\n        v128_t v0_0 = wasm_v128_load(x0->qs);\n        v128_t v0_0l = wasm_v128_and(v0_0, m4b);\n        v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);\n        v128_t 
v0_0ls = wasm_i8x16_sub(v0_0l, s8b);\n        v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);\n\n        // Load y0 vectors\n        v128_t y0_l = wasm_v128_load(y0->qs);\n        v128_t y0_h = wasm_v128_load(y0->qs + 16);\n\n        // Extend to i16x8 and compute dot products\n        v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);\n        v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);\n        v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);\n        v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);\n\n        v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);\n        v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);\n        v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);\n        v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);\n\n        v128_t dp0 = wasm_i32x4_add(\n            wasm_i32x4_add(\n                wasm_i32x4_dot_i16x8(dx0l, dy0ll),\n                wasm_i32x4_dot_i16x8(dx0h, dy0lh)\n            ),\n            wasm_i32x4_add(\n                wasm_i32x4_dot_i16x8(dx0hl, dy0hl),\n                wasm_i32x4_dot_i16x8(dx0hh, dy0hh)\n            )\n        );\n\n        // Load and process x1\n        v128_t v0_1 = wasm_v128_load(x1->qs);\n        v128_t v0_1l = wasm_v128_and(v0_1, m4b);\n        v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);\n        v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);\n        v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);\n\n        // Load y1 vectors\n        v128_t y1_l = wasm_v128_load(y1->qs);\n        v128_t y1_h = wasm_v128_load(y1->qs + 16);\n\n        // Extend to i16x8 and compute dot products\n        v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);\n        v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);\n        v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);\n        v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);\n\n        v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);\n        v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);\n        v128_t dy1hl = 
wasm_i16x8_extend_low_i8x16(y1_h);\n        v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);\n\n        v128_t dp1 = wasm_i32x4_add(\n            wasm_i32x4_add(\n                wasm_i32x4_dot_i16x8(dx1l, dy1ll),\n                wasm_i32x4_dot_i16x8(dx1h, dy1lh)\n            ),\n            wasm_i32x4_add(\n                wasm_i32x4_dot_i16x8(dx1hl, dy1hl),\n                wasm_i32x4_dot_i16x8(dx1hh, dy1hh)\n            )\n        );\n\n        // Accumulate results with scaling\n        float scale0 = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);\n        float scale1 = GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d);\n\n        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));\n        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));\n    }\n\n    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +\n           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);\n\n#endif\n    for (; ib < nb; ++ib) {\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int v0 = (x[ib].qs[j] & 0x0F) - 8;\n            const int v1 = (x[ib].qs[j] >>   4) - 8;\n\n            sumi0 += (v0 * y[ib].qs[j]);\n            sumi1 += (v1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_0 * GGML_RESTRICT 
x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined __wasm_simd128__\n    v128_t sumv = wasm_f32x4_splat(0.0f);\n\n    uint32_t qh_;\n    uint64_t tmp[4];\n\n    // TODO: check if unrolling this is better\n    for (; ib < nb; ++ib) {\n        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];\n        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];\n\n        const v128_t m4b  = wasm_i8x16_splat(0x0F);\n\n        // extract the 5th bit\n        memcpy(&qh_, x0->qh, sizeof(qh_));\n\n        tmp[0] = table_b2b_1[(qh_ >>  0) & 0xFF];\n        tmp[1] = table_b2b_1[(qh_ >>  8) & 0xFF];\n        tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];\n        tmp[3] = table_b2b_1[(qh_ >> 24)       ];\n\n        const v128_t qhl = wasm_v128_load(tmp + 0);\n        const v128_t qhh = wasm_v128_load(tmp + 2);\n\n        const v128_t v0 = wasm_v128_load(x0->qs);\n\n        // 4-bit -> 8-bit\n        const v128_t v0l = wasm_v128_and (v0, m4b);\n        const v128_t v0h = wasm_u8x16_shr(v0, 4);\n\n        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)\n        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);\n        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);\n\n        // load y\n        const v128_t v1l = wasm_v128_load(y0->qs);\n        const v128_t v1h = wasm_v128_load(y0->qs + 16);\n\n        // int8x16 -> int16x8\n        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);\n        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);\n        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);\n        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);\n\n        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);\n        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);\n        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);\n        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);\n\n        // dot product\n        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(\n                        
wasm_i32x4_add(\n                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),\n                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),\n                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),\n                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),\n                    wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));\n    }\n\n    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +\n           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(sumf);\n    UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_1);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined __wasm_simd128__\n    v128_t sumv = wasm_f32x4_splat(0.0f);\n\n    float summs = 0.0f;\n\n    uint32_t qh_;\n    uint64_t tmp[4];\n\n    // TODO: check if unrolling this is better\n    for (; ib < nb; ++ib) {\n        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];\n        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];\n\n        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);\n\n        const v128_t m4b = wasm_i8x16_splat(0x0F);\n\n        // extract the 5th bit\n        memcpy(&qh_, x0->qh, sizeof(qh_));\n\n        tmp[0] = table_b2b_0[(qh_ >>  0) & 0xFF];\n        tmp[1] = table_b2b_0[(qh_ >>  8) & 0xFF];\n        tmp[2] = table_b2b_0[(qh_ >> 16) & 
0xFF];\n        tmp[3] = table_b2b_0[(qh_ >> 24)       ];\n\n        const v128_t qhl = wasm_v128_load(tmp + 0);\n        const v128_t qhh = wasm_v128_load(tmp + 2);\n\n        const v128_t v0 = wasm_v128_load(x0->qs);\n\n        // 4-bit -> 8-bit\n        const v128_t v0l = wasm_v128_and (v0, m4b);\n        const v128_t v0h = wasm_u8x16_shr(v0, 4);\n\n        // add high bit\n        const v128_t v0lf = wasm_v128_or(v0l, qhl);\n        const v128_t v0hf = wasm_v128_or(v0h, qhh);\n\n        // load y\n        const v128_t v1l = wasm_v128_load(y0->qs);\n        const v128_t v1h = wasm_v128_load(y0->qs + 16);\n\n        // int8x16 -> int16x8\n        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);\n        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);\n        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);\n        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);\n\n        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);\n        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);\n        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);\n        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);\n\n        // dot product\n        sumv = wasm_f32x4_add(sumv,\n                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(\n                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),\n                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),\n                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),\n                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),\n                    wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));\n    }\n\n    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +\n           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(sumf);\n    
UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q8_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined __wasm_simd128__\n    v128_t sumv = wasm_f32x4_splat(0.0f);\n\n    for (; ib < nb; ++ib) {\n        const block_q8_0 * GGML_RESTRICT x0 = &x[ib];\n        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];\n\n        const v128_t x0_0 = wasm_v128_load(x0->qs);\n        const v128_t x0_1 = wasm_v128_load(x0->qs + 16);\n        const v128_t y0_0 = wasm_v128_load(y0->qs);\n        const v128_t y0_1 = wasm_v128_load(y0->qs + 16);\n\n        // Extend 8-bit to 16-bit\n        const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);\n        const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);\n        const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);\n        const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);\n\n        const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);\n        const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);\n        const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);\n        const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);\n\n        // Compute dot products\n        const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);\n        const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);\n        const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);\n        const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);\n\n        // Sum all dot products\n        const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), 
wasm_i32x4_add(dx1_0, dx1_1));\n\n        // Convert to float and accumulate\n        const float scale = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);\n        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));\n    }\n\n    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +\n           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);\n\n    *s = sumf;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    UNUSED(sumf);\n    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q2_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __wasm_simd128__\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * q2 = x[i].qs;\n        const int8_t * q8 = y[i].qs;\n        const uint8_t * sc = x[i].scales;\n\n        // Vectorized summs calculation\n        v128_t summs_vec = wasm_i32x4_splat(0);\n        {\n            v128_t sc_vec = wasm_v128_load(sc);\n            v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);\n\n            v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);\n            v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);\n\n            v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);\n            v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);\n\n            summs_vec = wasm_i32x4_add(\n                wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),\n                               wasm_i32x4_dot_i16x8(sc_high, bsums2)),\n                summs_vec\n            );\n\n            summs_vec = wasm_i32x4_add(summs_vec, 
wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));\n            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));\n        }\n        int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);\n\n        // Vectorized isum calculation\n        int32_t isum = 0;\n        const uint8_t * sc_ptr = sc;\n        const int k_iters = QK_K/128;\n\n        for (int k = 0; k < k_iters; ++k) {\n            v128_t isum_vec = wasm_i32x4_splat(0);\n            int shift = 0;\n\n            for (int j = 0; j < 4; ++j) {\n                const int d0 = (sc_ptr[0] & 0xF);\n                const int d1 = (sc_ptr[1] & 0xF);\n                sc_ptr += 2;\n\n                // Process first 16 elements\n                v128_t q2_0 = wasm_v128_load(q2);\n                v128_t q8_0 = wasm_v128_load(q8);\n                v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);\n                v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));\n\n                // Process next 16 elements\n                v128_t q2_1 = wasm_v128_load(q2 + 16);\n                v128_t q8_1 = wasm_v128_load(q8 + 16);\n                v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);\n                v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));\n\n                // Calculate dot products\n                v128_t p0 = wasm_i32x4_dot_i16x8(\n                    wasm_i16x8_extend_low_i8x16(q8_0),\n                    wasm_i16x8_extend_low_i8x16(q2_bits_0)\n                );\n                v128_t p1 = wasm_i32x4_dot_i16x8(\n                    wasm_i16x8_extend_high_i8x16(q8_0),\n                    wasm_i16x8_extend_high_i8x16(q2_bits_0)\n                );\n                v128_t p2 = wasm_i32x4_dot_i16x8(\n                    wasm_i16x8_extend_low_i8x16(q8_1),\n                    wasm_i16x8_extend_low_i8x16(q2_bits_1)\n                );\n                v128_t p3 = wasm_i32x4_dot_i16x8(\n                    
wasm_i16x8_extend_high_i8x16(q8_1),\n                    wasm_i16x8_extend_high_i8x16(q2_bits_1)\n                );\n\n                // Accumulate scaled results\n                v128_t scaled = wasm_i32x4_add(\n                    wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),\n                    wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))\n                );\n\n                isum_vec = wasm_i32x4_add(isum_vec, scaled);\n                q8 += 32;\n                shift += 2;\n            }\n            q2 += 32;\n\n            // Horizontal sum of isum_vec\n            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));\n            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));\n            isum += wasm_i32x4_extract_lane(isum_vec, 0);\n        }\n\n        const float dall = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;\n        sumf += dall * isum - dmin * summs;\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    const block_q3_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __wasm_simd128__\n    int8_t  aux8[QK_K];\n    float   sums[8] = {0};\n    uint32_t auxs[4];\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t * GGML_RESTRICT hm = 
x[i].hmask;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        // Process blocks with SIMD\n        int8_t * a = aux8;\n        uint8_t m = 1;\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int shift = 0; shift <= 6; shift += 2) {\n                v128_t v_m = wasm_i8x16_splat(m);\n                for (int l = 0; l < 32; l += 16) {\n                    v128_t v_q3 = wasm_v128_load(q3 + l);\n                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);\n                    v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));\n\n                    v128_t v_hm = wasm_v128_load(hm + l);\n                    v128_t v_mask = wasm_v128_and(v_hm, v_m);\n                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));\n\n                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));\n                    wasm_v128_store(a + l, v_low2);\n                }\n                a += 32;\n                m <<= 1;\n            }\n            q3 += 32;\n        }\n\n        // Extract scales\n        memcpy(auxs, x[i].scales, 12);\n        uint32_t tmp = auxs[2];\n        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n        const int8_t * scales = (const int8_t *)auxs;\n\n        // SIMD dot product with register accumulators\n        v128_t v_acc0 = wasm_i32x4_splat(0);\n        v128_t v_acc1 = wasm_i32x4_splat(0);\n        a = aux8;\n        for (int j = 0; j < QK_K/16; ++j) {\n            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);\n\n            // Process 16 elements per iteration\n            for (int k = 0; k < 2; ++k) {\n                const v128_t v_q8 = wasm_i16x8_load8x8(q8);\n                const v128_t v_a = 
wasm_i16x8_load8x8(a);\n\n                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);\n                v_prod = wasm_i16x8_mul(v_prod, v_scale);\n\n                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));\n                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));\n\n                q8 += 8;\n                a += 8;\n            }\n        }\n\n        // Accumulate results\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const v128_t v_d = wasm_f32x4_splat(d);\n        v128_t v_sum = wasm_f32x4_add(\n            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),\n            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)\n        );\n\n        // Accumulate into sums vector\n        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));\n    }\n\n    // Horizontal sum\n    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));\n    sumf = wasm_f32x4_extract_lane(v_sum, 0) +\n           wasm_f32x4_extract_lane(v_sum, 1) +\n           wasm_f32x4_extract_lane(v_sum, 2) +\n           wasm_f32x4_extract_lane(v_sum, 3);\n\n    *s = sumf;\n\n#else\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n\n}\n\nvoid ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined __wasm_simd128__\n    const uint8_t * 
scales = (const uint8_t*)&utmp[0];\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Corrected sign\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        // Process scales and mins\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        // Sum mins * q8sums\n        int32_t sumi = 0;\n        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;\n        const uint8_t * m = (const uint8_t *)&utmp[2];\n        for (int j = 0; j < 16; j += 2) {\n            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];\n        }\n        sumf -= dmin * sumi;\n\n        int32_t sumi1 = 0;\n        int32_t sumi2 = 0;\n\n        for (int j = 0; j < QK_K/64; ++j) {\n            // Load 64 4-bit weights (32 bytes)\n            const v128_t q4x0 = wasm_v128_load(q4);\n            const v128_t q4x1 = wasm_v128_load(q4 + 16);\n            q4 += 32;\n\n            // Split into low/high nibbles\n            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));\n            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);\n            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));\n            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);\n\n            // Load 64 8-bit values (64 bytes)\n            const v128_t q8x0 = wasm_v128_load(q8);\n            const v128_t q8x1 = wasm_v128_load(q8 + 16);\n            const v128_t q8x2 = wasm_v128_load(q8 + 32);\n            const v128_t q8x3 = wasm_v128_load(q8 + 48);\n            q8 += 64;\n\n            // Low nibble products\n            v128_t vacc1 = 
wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_low_i8x16(q4l0),\n                wasm_i16x8_extend_low_i8x16(q8x0)\n            );\n            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_high_i8x16(q4l0),\n                wasm_i16x8_extend_high_i8x16(q8x0)\n            ));\n            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_low_i8x16(q4l1),\n                wasm_i16x8_extend_low_i8x16(q8x1)\n            ));\n            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_high_i8x16(q4l1),\n                wasm_i16x8_extend_high_i8x16(q8x1)\n            ));\n\n            // High nibble products\n            v128_t vacc2 = wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_low_i8x16(q4h0),\n                wasm_i16x8_extend_low_i8x16(q8x2)\n            );\n            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_high_i8x16(q4h0),\n                wasm_i16x8_extend_high_i8x16(q8x2)\n            ));\n            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_low_i8x16(q4h1),\n                wasm_i16x8_extend_low_i8x16(q8x3)\n            ));\n            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_high_i8x16(q4h1),\n                wasm_i16x8_extend_high_i8x16(q8x3)\n            ));\n\n            // Accumulate scaled results\n            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +\n                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);\n            sumi1 += vacc1_sum * scales[2*j];\n\n            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +\n                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);\n            sumi2 += 
vacc2_sum * scales[2*j+1];\n        }\n\n        sumf += d * (sumi1 + sumi2);\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined __wasm_simd128__\n    //const uint8_t * scales = (const uint8_t*)&utmp[0];\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Fixed sign\n\n        const uint8_t * GGML_RESTRICT q5 = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        // Process scales and mins\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        // Sum mins * q8sums\n        int32_t sumi_mins = 0;\n        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;\n        const uint8_t * m = (const uint8_t *)&utmp[2];\n        for (int j = 0; j < 16; j += 2) {\n            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];\n        
}\n        sumf -= dmin * sumi_mins; // Correct subtraction\n\n        v128_t qh0 = wasm_v128_load(qh);\n        v128_t qh1 = wasm_v128_load(qh + 16);\n        const uint8_t * sc = (const uint8_t *)utmp;\n\n        int32_t sumi = 0;\n\n        for (int j = 0; j < QK_K/64; ++j) {\n            const int shift = j * 2;\n            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);\n            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);\n\n            v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);\n            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);\n            v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);\n            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);\n\n            v128_t q5_0 = wasm_v128_load(q5);\n            v128_t q5_1 = wasm_v128_load(q5 + 16);\n            q5 += 32;\n\n            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);\n            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);\n            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);\n            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);\n\n            v128_t q8_0 = wasm_v128_load(q8);\n            v128_t q8_1 = wasm_v128_load(q8 + 16);\n            v128_t q8_2 = wasm_v128_load(q8 + 32);\n            v128_t q8_3 = wasm_v128_load(q8 + 48);\n            q8 += 64;\n\n            // Process low quants\n            v128_t pl0 = wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_low_i8x16(q5l_0),\n                wasm_i16x8_extend_low_i8x16(q8_0)\n            );\n            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_high_i8x16(q5l_0),\n                wasm_i16x8_extend_high_i8x16(q8_0)\n            ));\n            v128_t pl1 = wasm_i32x4_dot_i16x8(\n                
wasm_i16x8_extend_low_i8x16(q5l_1),\n                wasm_i16x8_extend_low_i8x16(q8_1)\n            );\n            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_high_i8x16(q5l_1),\n                wasm_i16x8_extend_high_i8x16(q8_1)\n            ));\n            v128_t sum_low = wasm_i32x4_add(pl0, pl1);\n\n            // Process high quants\n            v128_t ph0 = wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_low_i8x16(q5h_0),\n                wasm_i16x8_extend_low_i8x16(q8_2)\n            );\n            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_high_i8x16(q5h_0),\n                wasm_i16x8_extend_high_i8x16(q8_2)\n            ));\n            v128_t ph1 = wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_low_i8x16(q5h_1),\n                wasm_i16x8_extend_low_i8x16(q8_3)\n            );\n            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(\n                wasm_i16x8_extend_high_i8x16(q5h_1),\n                wasm_i16x8_extend_high_i8x16(q8_3)\n            ));\n            v128_t sum_high = wasm_i32x4_add(ph0, ph1);\n\n            // Accumulate with scale factors\n            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +\n                        wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);\n            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +\n                        wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);\n\n            sumi += sl * sc[2*j] + sh * sc[2*j+1];\n        }\n\n        sumf += d * sumi;\n    }\n\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, 
size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q6_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __wasm_simd128__\n    int8_t aux8[QK_K] __attribute__((aligned(16)));\n    int32_t aux32[8] __attribute__((aligned(16))) = {0};\n    float sums[8] __attribute__((aligned(16))) = {0};\n\n    for (int i = 0; i < nb; ++i) {\n        // Unpack 6-bit quantized data into aux8 (unchanged)\n        const uint8_t * GGML_RESTRICT q4 = x[i].ql;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        int8_t * a = aux8;\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) {\n                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n            }\n            a += 128;\n            q4 += 64;\n            qh += 32;\n        }\n\n        const int8_t * GGML_RESTRICT a_ptr = aux8;\n        const int8_t * GGML_RESTRICT q8 = y[i].qs;\n        v128_t acc0 = wasm_i32x4_splat(0);\n        v128_t acc1 = wasm_i32x4_splat(0);\n\n        for (int j = 0; j < QK_K/16; ++j) {\n            const int scale = x[i].scales[j];\n            const v128_t vscale = wasm_i32x4_splat(scale);\n\n            // Load 16 elements from a and q8\n            const v128_t a_vec = wasm_v128_load(a_ptr);\n            const v128_t q8_vec = wasm_v128_load(q8);\n\n            // Process low 8 elements\n            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);\n            v128_t q8_low = 
wasm_i16x8_extend_low_i8x16(q8_vec);\n            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);\n            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);\n            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);\n\n            // Process high 8 elements\n            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);\n            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);\n            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);\n            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);\n            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);\n\n            // Scale and accumulate\n            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);\n            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);\n            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);\n            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);\n\n            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));\n            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));\n\n            a_ptr += 16;\n            q8 += 16;\n        }\n\n        // Store accumulated results\n        wasm_v128_store(&aux32[0], acc0);\n        wasm_v128_store(&aux32[4], acc1);\n\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        for (int l = 0; l < 8; ++l) {\n            sums[l] += d * aux32[l];\n        }\n    }\n\n    // Sum final results\n    float sumf = 0;\n    for (int l = 0; l < 8; ++l) {\n        sumf += sums[l];\n    }\n    *s = sumf;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n"
  },
  {
    "path": "src/ggml-cpu/arch/x86/cpu-feats.cpp",
    "content": "#include \"ggml-backend-impl.h\"\n\n#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))\n\n#ifdef _MSC_VER\n#include <intrin.h>\n#endif\n\n#include <cstring>\n#include <vector>\n#include <bitset>\n#include <array>\n#include <string>\n\n// ref: https://cdrdv2-public.intel.com/782156/325383-sdm-vol-2abcd.pdf\nstruct cpuid_x86 {\n    bool SSE3(void) { return f_1_ecx[0]; }\n    bool PCLMULQDQ(void) { return f_1_ecx[1]; }\n    bool MONITOR(void) { return f_1_ecx[3]; }\n    bool SSSE3(void) { return f_1_ecx[9]; }\n    bool FMA(void) { return f_1_ecx[12]; }\n    bool CMPXCHG16B(void) { return f_1_ecx[13]; }\n    bool SSE41(void) { return f_1_ecx[19]; }\n    bool SSE42(void) { return f_1_ecx[20]; }\n    bool MOVBE(void) { return f_1_ecx[22]; }\n    bool POPCNT(void) { return f_1_ecx[23]; }\n    bool AES(void) { return f_1_ecx[25]; }\n    bool XSAVE(void) { return f_1_ecx[26]; }\n    bool OSXSAVE(void) { return f_1_ecx[27]; }\n    bool AVX(void) { return f_1_ecx[28]; }\n    bool F16C(void) { return f_1_ecx[29]; }\n    bool RDRAND(void) { return f_1_ecx[30]; }\n\n    bool MSR(void) { return f_1_edx[5]; }\n    bool CX8(void) { return f_1_edx[8]; }\n    bool SEP(void) { return f_1_edx[11]; }\n    bool CMOV(void) { return f_1_edx[15]; }\n    bool CLFSH(void) { return f_1_edx[19]; }\n    bool MMX(void) { return f_1_edx[23]; }\n    bool FXSR(void) { return f_1_edx[24]; }\n    bool SSE(void) { return f_1_edx[25]; }\n    bool SSE2(void) { return f_1_edx[26]; }\n\n    bool FSGSBASE(void) { return f_7_ebx[0]; }\n    bool BMI1(void) { return f_7_ebx[3]; }\n    bool HLE(void) { return is_intel && f_7_ebx[4]; }\n    bool AVX2(void) { return f_7_ebx[5]; }\n    bool BMI2(void) { return f_7_ebx[8]; }\n    bool ERMS(void) { return f_7_ebx[9]; }\n    bool INVPCID(void) { return f_7_ebx[10]; }\n    bool RTM(void) { return is_intel && f_7_ebx[11]; }\n    bool AVX512F(void) { return f_7_ebx[16]; }\n    bool AVX512DQ(void) { return f_7_ebx[17]; }\n    bool 
RDSEED(void) { return f_7_ebx[18]; }\n    bool ADX(void) { return f_7_ebx[19]; }\n    bool AVX512PF(void) { return f_7_ebx[26]; }\n    bool AVX512ER(void) { return f_7_ebx[27]; }\n    bool AVX512CD(void) { return f_7_ebx[28]; }\n    bool AVX512BW(void) { return f_7_ebx[30]; }\n    bool AVX512VL(void) { return f_7_ebx[31]; }\n\n    bool SHA(void) { return f_7_ebx[29]; }\n\n    bool PREFETCHWT1(void) { return f_7_ecx[0]; }\n\n    bool LAHF(void) { return f_81_ecx[0]; }\n    bool LZCNT(void) { return is_intel && f_81_ecx[5]; }\n    bool ABM(void) { return is_amd && f_81_ecx[5]; }\n    bool SSE4a(void) { return is_amd && f_81_ecx[6]; }\n    bool XOP(void) { return is_amd && f_81_ecx[11]; }\n    bool TBM(void) { return is_amd && f_81_ecx[21]; }\n\n    bool SYSCALL(void) { return is_intel && f_81_edx[11]; }\n    bool MMXEXT(void) { return is_amd && f_81_edx[22]; }\n    bool RDTSCP(void) { return is_intel && f_81_edx[27]; }\n    bool _3DNOWEXT(void) { return is_amd && f_81_edx[30]; }\n    bool _3DNOW(void) { return is_amd && f_81_edx[31]; }\n\n    bool AVX512_VBMI(void) { return f_7_ecx[1]; }\n    bool AVX512_VNNI(void) { return f_7_ecx[11]; }\n    bool AVX512_FP16(void) { return f_7_edx[23]; }\n    bool AVX512_BF16(void) { return f_7_1_eax[5]; }\n    bool AVX_VNNI(void) { return f_7_1_eax[4]; }\n\n    bool AMX_TILE(void) { return f_7_edx[24]; }\n    bool AMX_INT8(void) { return f_7_edx[25]; }\n    bool AMX_FP16(void) { return f_7_1_eax[21]; }\n    bool AMX_BF16(void) { return f_7_edx[22]; }\n\n#ifdef _MSC_VER\n    static void cpuid(int cpu_info[4], int eax) {\n        __cpuid(cpu_info, eax);\n    }\n    static void cpuidex(int cpu_info[4], int eax, int ecx) {\n        __cpuidex(cpu_info, eax, ecx);\n    }\n#else\n    static void cpuid(int cpu_info[4], int eax) {\n        __asm__ __volatile__(\n            \"cpuid\"\n            : \"=a\"(cpu_info[0]), \"=b\"(cpu_info[1]), \"=c\"(cpu_info[2]), \"=d\"(cpu_info[3])\n            : \"a\"(eax), \"c\"(0));\n    }\n    static 
void cpuidex(int cpu_info[4], int eax, int ecx) {\n        __asm__ __volatile__(\n            \"cpuid\"\n            : \"=a\"(cpu_info[0]), \"=b\"(cpu_info[1]), \"=c\"(cpu_info[2]), \"=d\"(cpu_info[3])\n            : \"a\"(eax), \"c\"(ecx));\n    }\n#endif\n\n    cpuid_x86() {\n        std::array<int, 4> cpui;\n        std::vector<std::array<int, 4>> data;\n\n        // calling __cpuid with 0x0 as the function_id argument\n        // gets the number of the highest valid function ID.\n        cpuid(cpui.data(), 0);\n        int n_ids = cpui[0];\n\n        for (int i = 0; i <= n_ids; ++i) {\n            cpuidex(cpui.data(), i, 0);\n            data.push_back(cpui);\n        }\n\n        // capture vendor string\n        char vendor[0x20] = {};\n        *reinterpret_cast<int *>(vendor)     = data[0][1];\n        *reinterpret_cast<int *>(vendor + 4) = data[0][3];\n        *reinterpret_cast<int *>(vendor + 8) = data[0][2];\n        this->vendor = vendor;\n        if (this->vendor == \"GenuineIntel\") {\n            is_intel = true;\n        } else if (this->vendor == \"AuthenticAMD\") {\n            is_amd = true;\n        }\n\n        // load bitset with flags for function 0x00000001\n        if (n_ids >= 1) {\n            f_1_ecx = data[1][2];\n            f_1_edx = data[1][3];\n        }\n\n        // load bitset with flags for function 0x00000007\n        if (n_ids >= 7) {\n            f_7_ebx = data[7][1];\n            f_7_ecx = data[7][2];\n            f_7_edx = data[7][3];\n            cpuidex(cpui.data(), 7, 1);\n            f_7_1_eax = cpui[0];\n        }\n\n        // calling __cpuid with 0x80000000 as the function_id argument\n        // gets the number of the highest valid extended ID.\n        cpuid(cpui.data(), 0x80000000);\n        unsigned int n_ex_ids = cpui[0];\n\n        std::vector<std::array<int, 4>> ext_data;\n        for (unsigned int i = 0x80000000; i <= n_ex_ids; ++i) {\n            cpuidex(cpui.data(), i, 0);\n            
ext_data.push_back(cpui);\n        }\n\n        // load bitset with flags for function 0x80000001\n        if (n_ex_ids >= 0x80000001) {\n            f_81_ecx = ext_data[1][2];\n            f_81_edx = ext_data[1][3];\n        }\n\n        // interpret CPU brand string if reported\n        char brand[0x40] = {};\n        if (n_ex_ids >= 0x80000004) {\n            std::memcpy(brand, ext_data[2].data(), sizeof(cpui));\n            std::memcpy(brand + 16, ext_data[3].data(), sizeof(cpui));\n            std::memcpy(brand + 32, ext_data[4].data(), sizeof(cpui));\n            this->brand = brand;\n        }\n    }\n\n    bool is_intel = false;\n    bool is_amd = false;\n    std::string vendor;\n    std::string brand;\n    std::bitset<32> f_1_ecx;\n    std::bitset<32> f_1_edx;\n    std::bitset<32> f_7_ebx;\n    std::bitset<32> f_7_ecx;\n    std::bitset<32> f_7_edx;\n    std::bitset<32> f_7_1_eax;\n    std::bitset<32> f_81_ecx;\n    std::bitset<32> f_81_edx;\n};\n\n#if 0\nvoid test_x86_is() {\n    cpuid_x86 is;\n    printf(\"CPU Vendor: %s\\n\", is.vendor.c_str());\n    printf(\"Brand: %s\\n\", is.brand.c_str());\n    printf(\"is_intel: %d\\n\", is.is_intel);\n    printf(\"is_amd: %d\\n\", is.is_amd);\n    printf(\"sse3: %d\\n\", is.SSE3());\n    printf(\"pclmulqdq: %d\\n\", is.PCLMULQDQ());\n    printf(\"ssse3: %d\\n\", is.SSSE3());\n    printf(\"fma: %d\\n\", is.FMA());\n    printf(\"cmpxchg16b: %d\\n\", is.CMPXCHG16B());\n    printf(\"sse41: %d\\n\", is.SSE41());\n    printf(\"sse42: %d\\n\", is.SSE42());\n    printf(\"movbe: %d\\n\", is.MOVBE());\n    printf(\"popcnt: %d\\n\", is.POPCNT());\n    printf(\"aes: %d\\n\", is.AES());\n    printf(\"xsave: %d\\n\", is.XSAVE());\n    printf(\"osxsave: %d\\n\", is.OSXSAVE());\n    printf(\"avx: %d\\n\", is.AVX());\n    printf(\"f16c: %d\\n\", is.F16C());\n    printf(\"rdrand: %d\\n\", is.RDRAND());\n    printf(\"msr: %d\\n\", is.MSR());\n    printf(\"cx8: %d\\n\", is.CX8());\n    printf(\"sep: %d\\n\", is.SEP());\n    
printf(\"cmov: %d\\n\", is.CMOV());\n    printf(\"clflush: %d\\n\", is.CLFSH());\n    printf(\"mmx: %d\\n\", is.MMX());\n    printf(\"fxsr: %d\\n\", is.FXSR());\n    printf(\"sse: %d\\n\", is.SSE());\n    printf(\"sse2: %d\\n\", is.SSE2());\n    printf(\"fsgsbase: %d\\n\", is.FSGSBASE());\n    printf(\"bmi1: %d\\n\", is.BMI1());\n    printf(\"hle: %d\\n\", is.HLE());\n    printf(\"avx2: %d\\n\", is.AVX2());\n    printf(\"bmi2: %d\\n\", is.BMI2());\n    printf(\"erms: %d\\n\", is.ERMS());\n    printf(\"invpcid: %d\\n\", is.INVPCID());\n    printf(\"rtm: %d\\n\", is.RTM());\n    printf(\"avx512f: %d\\n\", is.AVX512F());\n    printf(\"rdseed: %d\\n\", is.RDSEED());\n    printf(\"adx: %d\\n\", is.ADX());\n    printf(\"avx512pf: %d\\n\", is.AVX512PF());\n    printf(\"avx512er: %d\\n\", is.AVX512ER());\n    printf(\"avx512cd: %d\\n\", is.AVX512CD());\n    printf(\"sha: %d\\n\", is.SHA());\n    printf(\"prefetchwt1: %d\\n\", is.PREFETCHWT1());\n    printf(\"lahf: %d\\n\", is.LAHF());\n    printf(\"lzcnt: %d\\n\", is.LZCNT());\n    printf(\"abm: %d\\n\", is.ABM());\n    printf(\"sse4a: %d\\n\", is.SSE4a());\n    printf(\"xop: %d\\n\", is.XOP());\n    printf(\"tbm: %d\\n\", is.TBM());\n    printf(\"syscall: %d\\n\", is.SYSCALL());\n    printf(\"mmxext: %d\\n\", is.MMXEXT());\n    printf(\"rdtscp: %d\\n\", is.RDTSCP());\n    printf(\"3dnowext: %d\\n\", is._3DNOWEXT());\n    printf(\"3dnow: %d\\n\", is._3DNOW());\n    printf(\"avx512_vbmi: %d\\n\", is.AVX512_VBMI());\n    printf(\"avx512_vnni: %d\\n\", is.AVX512_VNNI());\n    printf(\"avx512_fp16: %d\\n\", is.AVX512_FP16());\n    printf(\"avx512_bf16: %d\\n\", is.AVX512_BF16());\n    printf(\"amx_tile: %d\\n\", is.AMX_TILE());\n    printf(\"amx_int8: %d\\n\", is.AMX_INT8());\n    printf(\"amx_fp16: %d\\n\", is.AMX_FP16());\n    printf(\"amx_bf16: %d\\n\", is.AMX_BF16());\n}\n#endif\n\nstatic int ggml_backend_cpu_x86_score() {\n    // FIXME: this does not check for OS support\n\n    int score = 1;\n    cpuid_x86 is;\n\n#ifdef 
GGML_FMA\n    if (!is.FMA()) { return 0; }\n    score += 1;\n#endif\n#ifdef GGML_F16C\n    if (!is.F16C()) { return 0; }\n    score += 1<<1;\n#endif\n#ifdef GGML_SSE42\n    if (!is.SSE42()) { return 0; }\n    score += 1<<2;\n#endif\n#ifdef GGML_BMI2\n    if (!is.BMI2()) { return 0; }\n    score += 1<<3;\n#endif\n#ifdef GGML_AVX\n    if (!is.AVX()) { return 0; }\n    score += 1<<4;\n#endif\n#ifdef GGML_AVX2\n    if (!is.AVX2()) { return 0; }\n    score += 1<<5;\n#endif\n#ifdef GGML_AVX_VNNI\n    if (!is.AVX_VNNI()) { return 0; }\n    score += 1<<6;\n#endif\n#ifdef GGML_AVX512\n    if (!is.AVX512F()) { return 0; }\n    if (!is.AVX512CD()) { return 0; }\n    if (!is.AVX512VL()) { return 0; }\n    if (!is.AVX512DQ()) { return 0; }\n    if (!is.AVX512BW()) { return 0; }\n    score += 1<<7;\n#endif\n#ifdef GGML_AVX512_VBMI\n    if (!is.AVX512_VBMI()) { return 0; }\n    score += 1<<8;\n#endif\n#ifdef GGML_AVX512_BF16\n    if (!is.AVX512_BF16()) { return 0; }\n    score += 1<<9;\n#endif\n#ifdef GGML_AVX512_VNNI\n    if (!is.AVX512_VNNI()) { return 0; }\n    score += 1<<10;\n#endif\n#ifdef GGML_AMX_INT8\n    if (!is.AMX_INT8()) { return 0; }\n    score += 1<<11;\n#endif\n\n    return score;\n}\n\nGGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_x86_score)\n\n#endif // defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))\n"
  },
  {
    "path": "src/ggml-cpu/arch/x86/quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n#include \"ggml-quants.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"simd-mappings.h\"\n\n#include \"../../quants.h\"\n#include \"../../ggml-cpu-impl.h\"\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\n// some compilers don't provide _mm256_set_m128i, e.g. gcc 7\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)\n// multiply int8_t, add results pairwise twice\nstatic inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {\n    // Get absolute values of x vectors\n    const __m128i ax = _mm_sign_epi8(x, x);\n    // Sign the values of the y vectors\n    const __m128i sy = _mm_sign_epi8(y, x);\n    // Perform multiplication and create 16-bit values\n    const __m128i dot = _mm_maddubs_epi16(ax, sy);\n    const __m128i ones = _mm_set1_epi16(1);\n    return _mm_madd_epi16(ones, dot);\n}\n\n#if __AVX__ || __AVX2__ || __AVX512F__\n// horizontally add 8 floats\nstatic inline float hsum_float_8(const __m256 x) {\n    __m128 res = _mm256_extractf128_ps(x, 1);\n    res = _mm_add_ps(res, _mm256_castps256_ps128(x));\n    res = _mm_add_ps(res, _mm_movehl_ps(res, res));\n    res = _mm_add_ss(res, _mm_movehdup_ps(res));\n    return _mm_cvtss_f32(res);\n}\n\n// horizontally add 8 int32_t\nstatic inline int hsum_i32_8(const __m256i a) {\n    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));\n    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);\n    const __m128i sum64 = _mm_add_epi32(hi64, sum128);\n    
const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));\n    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));\n}\n\n// horizontally add 4 int32_t\nstatic inline int hsum_i32_4(const __m128i a) {\n    const __m128i hi64 = _mm_unpackhi_epi64(a, a);\n    const __m128i sum64 = _mm_add_epi32(hi64, a);\n    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));\n    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));\n}\n\n#if defined(__AVX2__) || defined(__AVX512F__)\nstatic inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {\n    const __m256i ax = _mm256_sign_epi8(x, x);\n    const __m256i sy = _mm256_sign_epi8(y, x);\n    return _mm256_maddubs_epi16(ax, sy);\n}\n\n// spread 32 bits to 32 bytes { 0x00, 0xFF }\nstatic inline __m256i bytes_from_bits_32(const uint8_t * x) {\n    uint32_t x32;\n    memcpy(&x32, x, sizeof(uint32_t));\n    const __m256i shuf_mask = _mm256_set_epi64x(\n            0x0303030303030303, 0x0202020202020202,\n            0x0101010101010101, 0x0000000000000000);\n    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);\n    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);\n    bytes = _mm256_or_si256(bytes, bit_mask);\n    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));\n}\n\n// Unpack 32 4-bit fields into 32 bytes\n// The output vector contains 32 bytes, each one in [ 0 .. 
15 ] interval\nstatic inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)\n{\n    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);\n    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);\n    const __m256i lowMask = _mm256_set1_epi8( 0xF );\n    return _mm256_and_si256(lowMask, bytes);\n}\n\n// add int16_t pairwise and return as float vector\nstatic inline __m256 sum_i16_pairs_float(const __m256i x) {\n    const __m256i ones = _mm256_set1_epi16(1);\n    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);\n    return _mm256_cvtepi32_ps(summed_pairs);\n}\n\nstatic inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {\n#if defined(__AVX512VNNI__) && defined(__AVX512VL__)\n    const __m256i zero = _mm256_setzero_si256();\n    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);\n    return _mm256_cvtepi32_ps(summed_pairs);\n#elif defined(__AVXVNNI__)\n    const __m256i zero = _mm256_setzero_si256();\n    const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);\n    return _mm256_cvtepi32_ps(summed_pairs);\n#else\n    // Perform multiplication and create 16-bit values\n    const __m256i dot = _mm256_maddubs_epi16(ax, sy);\n    return sum_i16_pairs_float(dot);\n#endif\n}\n\n// multiply int8_t, add results pairwise twice and return as float vector\nstatic inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {\n#if __AVXVNNIINT8__\n    const __m256i zero = _mm256_setzero_si256();\n    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);\n    return _mm256_cvtepi32_ps(summed_pairs);\n#else\n    // Get absolute values of x vectors\n    const __m256i ax = _mm256_sign_epi8(x, x);\n    // Sign the values of the y vectors\n    const __m256i sy = _mm256_sign_epi8(y, x);\n    return mul_sum_us8_pairs_float(ax, sy);\n#endif\n}\n\nstatic inline __m128i packNibbles( __m256i bytes )\n{\n    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh\n#if 
__AVX512F__\n    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000\n    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh\n    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh\n#else\n    const __m256i lowByte = _mm256_set1_epi16( 0xFF );\n    __m256i high = _mm256_andnot_si256( lowByte, bytes );\n    __m256i low = _mm256_and_si256( lowByte, bytes );\n    high = _mm256_srli_epi16( high, 4 );\n    bytes = _mm256_or_si256( low, high );\n\n    // Compress uint16_t lanes into bytes\n    __m128i r0 = _mm256_castsi256_si128( bytes );\n    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );\n    return _mm_packus_epi16( r0, r1 );\n#endif\n}\n#elif defined(__AVX__)\nstatic inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )\n{\n    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh\n    const __m128i lowByte = _mm_set1_epi16( 0xFF );\n    __m128i high = _mm_andnot_si128( lowByte, bytes1 );\n    __m128i low = _mm_and_si128( lowByte, bytes1 );\n    high = _mm_srli_epi16( high, 4 );\n    bytes1 = _mm_or_si128( low, high );\n    high = _mm_andnot_si128( lowByte, bytes2 );\n    low = _mm_and_si128( lowByte, bytes2 );\n    high = _mm_srli_epi16( high, 4 );\n    bytes2 = _mm_or_si128( low, high );\n\n    return _mm_packus_epi16( bytes1, bytes2);\n}\n\nstatic inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {\n    const __m128i ax = _mm_sign_epi8(x, x);\n    const __m128i sy = _mm_sign_epi8(y, x);\n    return _mm_maddubs_epi16(ax, sy);\n}\n\n// spread 32 bits to 32 bytes { 0x00, 0xFF }\nstatic inline __m256i bytes_from_bits_32(const uint8_t * x) {\n    uint32_t x32;\n    memcpy(&x32, x, sizeof(uint32_t));\n    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);\n    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);\n    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), 
shuf_maskl);\n    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);\n    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);\n    bytesl = _mm_or_si128(bytesl, bit_mask);\n    bytesh = _mm_or_si128(bytesh, bit_mask);\n    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));\n    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));\n    return MM256_SET_M128I(bytesh, bytesl);\n}\n\n// Unpack 32 4-bit fields into 32 bytes\n// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval\nstatic inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)\n{\n    // Load 16 bytes from memory\n    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);\n    __m128i tmph = _mm_srli_epi16(tmpl, 4);\n    const __m128i lowMask = _mm_set1_epi8(0xF);\n    tmpl = _mm_and_si128(lowMask, tmpl);\n    tmph = _mm_and_si128(lowMask, tmph);\n    return MM256_SET_M128I(tmph, tmpl);\n}\n\n// add int16_t pairwise and return as float vector\nstatic inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {\n    const __m128i ones = _mm_set1_epi16(1);\n    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);\n    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);\n    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);\n    return _mm256_cvtepi32_ps(summed_pairs);\n}\n\nstatic inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {\n    const __m128i axl = _mm256_castsi256_si128(ax);\n    const __m128i axh = _mm256_extractf128_si256(ax, 1);\n    const __m128i syl = _mm256_castsi256_si128(sy);\n    const __m128i syh = _mm256_extractf128_si256(sy, 1);\n    // Perform multiplication and create 16-bit values\n    const __m128i dotl = _mm_maddubs_epi16(axl, syl);\n    const __m128i doth = _mm_maddubs_epi16(axh, syh);\n    return sum_i16_pairs_float(doth, dotl);\n}\n\n// multiply int8_t, add results pairwise twice and return as float vector\nstatic inline __m256 mul_sum_i8_pairs_float(const 
__m256i x, const __m256i y) {\n    const __m128i xl = _mm256_castsi256_si128(x);\n    const __m128i xh = _mm256_extractf128_si256(x, 1);\n    const __m128i yl = _mm256_castsi256_si128(y);\n    const __m128i yh = _mm256_extractf128_si256(y, 1);\n    // Get absolute values of x vectors\n    const __m128i axl = _mm_sign_epi8(xl, xl);\n    const __m128i axh = _mm_sign_epi8(xh, xh);\n    // Sign the values of the y vectors\n    const __m128i syl = _mm_sign_epi8(yl, xl);\n    const __m128i syh = _mm_sign_epi8(yh, xh);\n    // Perform multiplication and create 16-bit values\n    const __m128i dotl = _mm_maddubs_epi16(axl, syl);\n    const __m128i doth = _mm_maddubs_epi16(axh, syh);\n    return sum_i16_pairs_float(doth, dotl);\n}\n\n// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors\nstatic inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,\n                                           const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {\n    const __m128i mone = _mm_set1_epi16(1);\n\n    const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);\n    const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);\n    const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);\n    const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);\n    const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);\n    const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);\n    const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);\n    const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);\n    const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);\n    const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);\n    return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));\n}\n\n// quad fp16 delta calculation\nstatic inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {\n    // GGML_CPU_FP16_TO_FP32 is faster than Intel F16C\n    
return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),\n                           _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));\n}\n\nstatic inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) {\n    return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),\n                           _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));\n}\n#endif\n#elif defined(__SSSE3__)\n// horizontally add 4x4 floats\nstatic inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {\n    __m128 res_0 =_mm_hadd_ps(a, b);\n    __m128 res_1 =_mm_hadd_ps(c, d);\n    __m128 res =_mm_hadd_ps(res_0, res_1);\n    res =_mm_hadd_ps(res, res);\n    res =_mm_hadd_ps(res, res);\n\n    return _mm_cvtss_f32(res);\n}\n#endif // __AVX__ || __AVX2__ || __AVX512F__\n#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)\n\nvoid quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__AVX2__) || defined(__AVX__)\n    for (int i = 0; i < nb; i++) {\n        // Load elements into 4 AVX vectors\n        __m256 v0 = _mm256_loadu_ps( x );\n        __m256 v1 = _mm256_loadu_ps( x + 8 );\n        __m256 v2 = _mm256_loadu_ps( x + 16 );\n        __m256 v3 = _mm256_loadu_ps( x + 24 );\n        x += 32;\n\n        // Compute max(abs(e)) for the block\n        const __m256 signBit = _mm256_set1_ps( -0.0f );\n        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );\n        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );\n        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );\n        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );\n\n        
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );\n        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );\n        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );\n        const float maxScalar = _mm_cvtss_f32( max4 );\n\n        // Quantize these floats\n        const float d = maxScalar / 127.f;\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;\n        const __m256 mul = _mm256_set1_ps( id );\n\n        // Apply the multiplier\n        v0 = _mm256_mul_ps( v0, mul );\n        v1 = _mm256_mul_ps( v1, mul );\n        v2 = _mm256_mul_ps( v2, mul );\n        v3 = _mm256_mul_ps( v3, mul );\n\n        // Round to nearest integer\n        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );\n        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );\n        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );\n        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );\n\n        // Convert floats to integers\n        __m256i i0 = _mm256_cvtps_epi32( v0 );\n        __m256i i1 = _mm256_cvtps_epi32( v1 );\n        __m256i i2 = _mm256_cvtps_epi32( v2 );\n        __m256i i3 = _mm256_cvtps_epi32( v3 );\n\n#if defined(__AVX2__)\n        // Convert int32 to int16\n        i0 = _mm256_packs_epi32( i0, i1 );\t// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15\n        i2 = _mm256_packs_epi32( i2, i3 );\t// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31\n                                            // Convert int16 to int8\n        i0 = _mm256_packs_epi16( i0, i2 );\t// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31\n\n        // We got our precious signed bytes, but the order is now wrong\n        // These AVX2 pack instructions process 16-byte pieces independently\n        // The following instruction is fixing the order\n        const __m256i perm = _mm256_setr_epi32( 0, 
4, 1, 5, 2, 6, 3, 7 );\n        i0 = _mm256_permutevar8x32_epi32( i0, perm );\n\n        _mm256_storeu_si256((__m256i *)y[i].qs, i0);\n#else\n        // Since we don't have in AVX some necessary functions,\n        // we split the registers in half and call AVX2 analogs from SSE\n        __m128i ni0 = _mm256_castsi256_si128( i0 );\n        __m128i ni1 = _mm256_extractf128_si256( i0, 1);\n        __m128i ni2 = _mm256_castsi256_si128( i1 );\n        __m128i ni3 = _mm256_extractf128_si256( i1, 1);\n        __m128i ni4 = _mm256_castsi256_si128( i2 );\n        __m128i ni5 = _mm256_extractf128_si256( i2, 1);\n        __m128i ni6 = _mm256_castsi256_si128( i3 );\n        __m128i ni7 = _mm256_extractf128_si256( i3, 1);\n\n        // Convert int32 to int16\n        ni0 = _mm_packs_epi32( ni0, ni1 );\n        ni2 = _mm_packs_epi32( ni2, ni3 );\n        ni4 = _mm_packs_epi32( ni4, ni5 );\n        ni6 = _mm_packs_epi32( ni6, ni7 );\n        // Convert int16 to int8\n        ni0 = _mm_packs_epi16( ni0, ni2 );\n        ni4 = _mm_packs_epi16( ni4, ni6 );\n\n        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);\n        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);\n#endif\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    quantize_row_q8_0_ref(x, y, k);\n#endif\n}\n\nvoid quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK8_1 == 0);\n    const int nb = k / QK8_1;\n\n    block_q8_1 * GGML_RESTRICT y = vy;\n#if defined(__AVX2__) || defined(__AVX__)\n    for (int i = 0; i < nb; i++) {\n        // Load elements into 4 AVX vectors\n        __m256 v0 = _mm256_loadu_ps( x );\n        __m256 v1 = _mm256_loadu_ps( x + 8 );\n        __m256 v2 = _mm256_loadu_ps( x + 16 );\n        __m256 v3 = _mm256_loadu_ps( x + 24 );\n        x += 32;\n\n        // Compute max(abs(e)) for the block\n        const __m256 signBit = _mm256_set1_ps( -0.0f );\n        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );\n        maxAbs = 
_mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );\n        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );\n        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );\n\n        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );\n        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );\n        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );\n        const float max_scalar = _mm_cvtss_f32( max4 );\n\n        // Quantize these floats\n        const float d = max_scalar / 127.f;\n        y[i].d = GGML_CPU_FP32_TO_FP16(d);\n        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;\n        const __m256 mul = _mm256_set1_ps( id );\n\n        // Apply the multiplier\n        v0 = _mm256_mul_ps( v0, mul );\n        v1 = _mm256_mul_ps( v1, mul );\n        v2 = _mm256_mul_ps( v2, mul );\n        v3 = _mm256_mul_ps( v3, mul );\n\n        // Round to nearest integer\n        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );\n        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );\n        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );\n        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );\n\n        // Convert floats to integers\n        __m256i i0 = _mm256_cvtps_epi32( v0 );\n        __m256i i1 = _mm256_cvtps_epi32( v1 );\n        __m256i i2 = _mm256_cvtps_epi32( v2 );\n        __m256i i3 = _mm256_cvtps_epi32( v3 );\n\n#if defined(__AVX2__)\n        // Compute the sum of the quants and set y[i].s\n        y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));\n\n        // Convert int32 to int16\n        i0 = _mm256_packs_epi32( i0, i1 );\t// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15\n        i2 = _mm256_packs_epi32( i2, i3 );\t// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31\n                                            // Convert int16 to int8\n        i0 = 
_mm256_packs_epi16( i0, i2 );\t// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31\n\n        // We got our precious signed bytes, but the order is now wrong\n        // These AVX2 pack instructions process 16-byte pieces independently\n        // The following instruction is fixing the order\n        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );\n        i0 = _mm256_permutevar8x32_epi32( i0, perm );\n\n        _mm256_storeu_si256((__m256i *)y[i].qs, i0);\n#else\n        // Since we don't have in AVX some necessary functions,\n        // we split the registers in half and call AVX2 analogs from SSE\n        __m128i ni0 = _mm256_castsi256_si128( i0 );\n        __m128i ni1 = _mm256_extractf128_si256( i0, 1);\n        __m128i ni2 = _mm256_castsi256_si128( i1 );\n        __m128i ni3 = _mm256_extractf128_si256( i1, 1);\n        __m128i ni4 = _mm256_castsi256_si128( i2 );\n        __m128i ni5 = _mm256_extractf128_si256( i2, 1);\n        __m128i ni6 = _mm256_castsi256_si128( i3 );\n        __m128i ni7 = _mm256_extractf128_si256( i3, 1);\n\n        // Compute the sum of the quants and set y[i].s\n        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));\n        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));\n        y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));\n\n        // Convert int32 to int16\n        ni0 = _mm_packs_epi32( ni0, ni1 );\n        ni2 = _mm_packs_epi32( ni2, ni3 );\n        ni4 = _mm_packs_epi32( ni4, ni5 );\n        ni6 = _mm_packs_epi32( ni6, ni7 );\n        // Convert int16 to int8\n        ni0 = _mm_packs_epi16( ni0, ni2 );\n        ni4 = _mm_packs_epi16( ni4, ni6 );\n\n        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);\n        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);\n#endif\n    }\n#else\n    GGML_UNUSED(nb);\n    // scalar\n    
quantize_row_q8_1_ref(x, y, k);\n#endif\n}\n\n// placeholder implementation for Apple targets\nvoid quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q8_K_ref(x, y, k);\n}\n\n//===================================== Dot products =================================\n\n//\n// Helper functions\n//\n\n#if __AVX__ || __AVX2__ || __AVX512F__\n\n// shuffles to pick the required scales in dot products\nstatic inline __m256i get_scale_shuffle_q3k(int i) {\n    static const uint8_t k_shuffle[128] = {\n         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,\n         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,\n         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,\n        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,\n    };\n    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);\n}\nstatic inline __m256i get_scale_shuffle_k4(int i) {\n    static const uint8_t k_shuffle[256] = {\n         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,\n         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,\n         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,\n         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,\n        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,\n        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,\n        
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15\n    };\n    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);\n}\nstatic inline __m128i get_scale_shuffle(int i) {\n    static const uint8_t k_shuffle[128] = {\n         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,\n         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,\n         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,\n         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,\n        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,\n        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,\n        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15\n    };\n    return _mm_loadu_si128((const __m128i*)k_shuffle + i);\n}\n#endif\n\nvoid ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__AVX2__)\n    // Initialize accumulator with zeros\n    __m256 acc = _mm256_setzero_ps();\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        /* Compute combined scale for the block */\n        const __m256 d = _mm256_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );\n\n        __m256i qx = bytes_from_nibbles_32(x[ib].qs);\n\n        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. 
+7 ] interval.\n        const __m256i off = _mm256_set1_epi8( 8 );\n        qx = _mm256_sub_epi8( qx, off );\n\n        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);\n\n        const __m256 q = mul_sum_i8_pairs_float(qx, qy);\n\n        /* Multiply q with scale and accumulate */\n        acc = _mm256_fmadd_ps( d, q, acc );\n    }\n\n    sumf = hsum_float_8(acc);\n#elif defined(__AVX__)\n    __m256 accum = _mm256_setzero_ps();\n    for (; ib + 1 < nb; ib += 2) {\n        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);\n        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);\n        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);\n        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);\n        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);\n        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);\n\n        const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));\n        const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));\n        const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));\n        const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));\n\n        const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);\n        const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);\n        const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);\n        const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);\n        const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);\n        const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);\n        const __m256 p =  sum_i16_pairs_float(p_2, p_1);\n\n        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, 
y[ib + 1].d);\n        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);\n    }\n\n    sumf = hsum_float_8(accum);\n#elif defined(__SSSE3__)\n    // set constants\n    const __m128i lowMask = _mm_set1_epi8(0xF);\n    const __m128i off = _mm_set1_epi8(8);\n\n    // Initialize accumulator with zeros\n    __m128 acc_0 = _mm_setzero_ps();\n    __m128 acc_1 = _mm_setzero_ps();\n    __m128 acc_2 = _mm_setzero_ps();\n    __m128 acc_3 = _mm_setzero_ps();\n\n    for (; ib + 1 < nb; ib += 2) {\n        _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);\n        _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);\n\n        // Compute combined scale for the block 0 and 1\n        const __m128 d_0_1 = _mm_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );\n\n        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);\n\n        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);\n        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);\n        bx_0 = _mm_sub_epi8(bx_0, off);\n        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);\n\n        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));\n        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));\n        bx_1 = _mm_sub_epi8(bx_1, off);\n        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);\n\n        _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);\n        _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);\n\n        // Compute combined scale for the block 2 and 3\n        const __m128 d_2_3 = _mm_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );\n\n        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);\n\n        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);\n        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);\n        bx_2 = _mm_sub_epi8(bx_2, off);\n        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);\n\n       
 __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));\n        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));\n        bx_3 = _mm_sub_epi8(bx_3, off);\n        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);\n\n        // Convert int32_t to float\n        __m128 p0 = _mm_cvtepi32_ps(i32_0);\n        __m128 p1 = _mm_cvtepi32_ps(i32_1);\n        __m128 p2 = _mm_cvtepi32_ps(i32_2);\n        __m128 p3 = _mm_cvtepi32_ps(i32_3);\n\n        // Apply the scale\n        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );\n        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );\n        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );\n        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );\n\n        // Acummulate\n        acc_0 = _mm_add_ps(p0_d, acc_0);\n        acc_1 = _mm_add_ps(p1_d, acc_1);\n        acc_2 = _mm_add_ps(p2_d, acc_2);\n        acc_3 = _mm_add_ps(p3_d, acc_3);\n    }\n\n    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);\n\n#endif\n    for (; ib < nb; ++ib) {\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int v0 = (x[ib].qs[j] & 0x0F) - 8;\n            const int v1 = (x[ib].qs[j] >>   4) - 8;\n\n            sumi0 += (v0 * y[ib].qs[j]);\n            sumi1 += (v1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n\n#if defined(__AVX2__) || defined(__AVX__)\n    // Initialize accumulator with zeros\n    __m256 
acc = _mm256_setzero_ps();\n\n    float summs = 0;\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);\n        const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);\n\n        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);\n\n        const __m256 d0v = _mm256_set1_ps( d0 );\n        const __m256 d1v = _mm256_set1_ps( d1 );\n\n        // Compute combined scales\n        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );\n\n        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes\n        const __m256i qx = bytes_from_nibbles_32(x[ib].qs);\n        const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );\n\n        const __m256 xy = mul_sum_us8_pairs_float(qx, qy);\n\n        // Accumulate d0*d1*x*y\n#if defined(__AVX2__)\n        acc = _mm256_fmadd_ps( d0d1, xy, acc );\n#else\n        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );\n#endif\n    }\n\n    *s = hsum_float_8(acc) + summs;\n#else\n    UNUSED(nb);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(ib);\n    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_MXFP4 == 0);\n    static_assert(QK_MXFP4 == QK8_0, \"QK_MXFP4 and QK8_0 must be the same\");\n\n    const block_mxfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_MXFP4;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined __AVX2__\n\n    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);\n    const __m128i m4b  = _mm_set1_epi8(0x0f);\n    const __m256i mone = _mm256_set1_epi16(1);\n\n    __m256 accum1 = _mm256_setzero_ps();\n    __m256 accum2 = _mm256_setzero_ps();\n\n    
for (; ib + 1 < nb; ib += 2) {\n        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);\n        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);\n        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);\n        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);\n        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),\n                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));\n        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),\n                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));\n        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);\n        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);\n        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);\n        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);\n        const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e));\n        const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e));\n        accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1);\n        accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2);\n    }\n\n    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));\n\n#elif defined __AVX__\n    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);\n    const __m128i m4b  = _mm_set1_epi8(0x0f);\n\n    __m256 accum = _mm256_setzero_ps();\n    for (; ib + 1 < nb; ib += 2) {\n        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);\n        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);\n        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 
0].qs);\n        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);\n        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);\n        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);\n\n        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));\n        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));\n        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));\n        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));\n\n        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);\n        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);\n        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);\n    }\n\n    sumf = hsum_float_8(accum);\n\n#endif\n    for (; ib < nb; ++ib) {\n        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e);\n        int sumi1 = 0;\n        int sumi2 = 0;\n        for (int j = 0; j < QK_MXFP4/2; ++j) {\n            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];\n            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];\n        }\n        sumf += d * (sumi1 + sumi2);\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    int ib = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n#if defined(__AVX2__)\n    // Initialize 
accumulator with zeros\n    __m256 acc = _mm256_setzero_ps();\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        /* Compute combined scale for the block */\n        const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));\n\n        __m256i qx = bytes_from_nibbles_32(x[ib].qs);\n        __m256i bxhi = bytes_from_bits_32(x[ib].qh);\n        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));\n        qx = _mm256_or_si256(qx, bxhi);\n\n        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);\n\n        const __m256 q = mul_sum_i8_pairs_float(qx, qy);\n\n        /* Multiply q with scale and accumulate */\n        acc = _mm256_fmadd_ps(d, q, acc);\n    }\n\n    *s = hsum_float_8(acc);\n#elif defined(__AVX__)\n    // Initialize accumulator with zeros\n    __m256 acc = _mm256_setzero_ps();\n    __m128i mask = _mm_set1_epi8((char)0xF0);\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        /* Compute combined scale for the block */\n        const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));\n\n        __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);\n        const __m256i bxhi = bytes_from_bits_32(x[ib].qh);\n        __m128i bxhil = _mm256_castsi256_si128(bxhi);\n        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);\n        bxhil = _mm_andnot_si128(bxhil, mask);\n        bxhih = _mm_andnot_si128(bxhih, mask);\n        __m128i bxl = _mm256_castsi256_si128(bx_0);\n        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);\n        bxl = _mm_or_si128(bxl, bxhil);\n        bxh = _mm_or_si128(bxh, bxhih);\n        bx_0 = MM256_SET_M128I(bxh, bxl);\n\n        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);\n\n        const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);\n\n        /* Multiply q with scale and accumulate */\n        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);\n    }\n\n    *s = hsum_float_8(acc);\n#else\n    
UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    int ib = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_1);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n#if defined(__AVX2__)\n    // Initialize accumulator with zeros\n    __m256 acc = _mm256_setzero_ps();\n\n    float summs = 0.0f;\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));\n\n        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);\n\n        __m256i qx = bytes_from_nibbles_32(x[ib].qs);\n        __m256i bxhi = bytes_from_bits_32(x[ib].qh);\n        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));\n        qx = _mm256_or_si256(qx, bxhi);\n\n        const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);\n\n        const __m256 q = mul_sum_us8_pairs_float(qx, qy);\n\n        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);\n    }\n\n    *s = hsum_float_8(acc) + summs;\n#elif defined(__AVX__)\n    // Initialize accumulator with zeros\n    __m256 acc = _mm256_setzero_ps();\n    __m128i mask = _mm_set1_epi8(0x10);\n\n    float summs = 0.0f;\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));\n\n        summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);\n\n        __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);\n        const __m256i bxhi = 
bytes_from_bits_32(x[ib].qh);\n        __m128i bxhil = _mm256_castsi256_si128(bxhi);\n        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);\n        bxhil = _mm_and_si128(bxhil, mask);\n        bxhih = _mm_and_si128(bxhih, mask);\n        __m128i bxl = _mm256_castsi256_si128(bx_0);\n        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);\n        bxl = _mm_or_si128(bxl, bxhil);\n        bxh = _mm_or_si128(bxh, bxhih);\n        bx_0 = MM256_SET_M128I(bxh, bxl);\n\n        const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib].d));\n        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);\n\n        const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);\n\n        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);\n    }\n\n    *s = hsum_float_8(acc) + summs;\n#else\n    UNUSED(nb);\n    UNUSED(ib);\n    UNUSED(x);\n    UNUSED(y);\n    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q8_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined(__AVX2__)\n    // Initialize accumulator with zeros\n    __m256 acc = _mm256_setzero_ps();\n\n    // Main loop\n    for (; ib < nb; ++ib) {\n        // Compute combined scale for the block\n        const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));\n        __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);\n        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);\n\n        const __m256 q = mul_sum_i8_pairs_float(qx, qy);\n\n        // Multiply q with scale and 
accumulate\n        acc = _mm256_fmadd_ps( d, q, acc );\n    }\n\n    sumf = hsum_float_8(acc);\n#elif defined(__AVX__)\n    __m256 accum = _mm256_setzero_ps();\n\n    for (; ib + 1 < nb; ib += 2) {\n        const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs);\n        const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1);\n        const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);\n        const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1);\n        const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);\n        const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1);\n        const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);\n        const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);\n\n        const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1);\n        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);\n        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);\n    }\n\n    sumf = hsum_float_8(accum);\n#endif\n    for (; ib < nb; ++ib) {\n        int sumi = 0;\n\n        for (int j = 0; j < qk; j++) {\n            sumi += x[ib].qs[j]*y[ib].qs[j];\n        }\n\n        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_tq1_0 * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__AVX2__)\n    __m256 sumf = _mm256_setzero_ps();\n\n    for (int i = 0; i < nb; ++i) {\n        // 16-bit sums\n        __m256i sumi0 = 
_mm256_setzero_si256();\n        __m256i sumi1 = _mm256_setzero_si256();\n        __m256i sumi2 = _mm256_setzero_si256();\n\n        // first 32 bytes of 5 elements\n        {\n            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));\n            // 8-bit multiplies with shifts, masks and adds\n            __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3\n            __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9\n            __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9\n            __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9\n\n            // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?\n\n            // Cancel the +1 from avg so that it behaves like a halving add\n            qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));\n            qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));\n            qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));\n            qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));\n            qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));\n            // Multiply by 3 and get the top 2 bits\n            qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));\n            qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));\n            qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));\n            qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));\n            qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));\n            qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));\n            qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));\n            qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));\n            
qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));\n            qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));\n\n            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs +   0));\n            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  32));\n            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  64));\n            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  96));\n            const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));\n\n            qx0 = _mm256_maddubs_epi16(qx0, qy0);\n            qx1 = _mm256_maddubs_epi16(qx1, qy1);\n            qx2 = _mm256_maddubs_epi16(qx2, qy2);\n            qx3 = _mm256_maddubs_epi16(qx3, qy3);\n            qx4 = _mm256_maddubs_epi16(qx4, qy4);\n\n            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));\n            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));\n            sumi2 = _mm256_add_epi16(sumi2, qx4);\n        }\n\n        // last 16 bytes of 5-element, along with the 4 bytes of 4 elements\n        {\n            __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));\n            uint32_t qh;\n            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned\n            __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));\n            __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3\n            __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9\n            __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9\n            __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9\n            __m256i qx01 = MM256_SET_M128I(qx1, qx0);\n            __m256i qx23 = MM256_SET_M128I(qx3, qx2);\n\n            // avx2 does not have 8-bit multiplies, so 16-bit 
it is.\n            qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));\n            qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));\n            __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));\n\n            __m256i qx45 = MM256_SET_M128I(qx5, qx4);\n\n            // Cancel the +1 from avg so that it behaves like a halving add\n            qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));\n            qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));\n            qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));\n            // Multiply by 3 and get the top 2 bits\n            qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));\n            qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));\n            qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));\n            qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));\n            qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));\n            qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));\n\n            const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));\n            const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));\n            const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));\n\n            qx01 = _mm256_maddubs_epi16(qx01, qy01);\n            qx23 = _mm256_maddubs_epi16(qx23, qy23);\n            qx45 = _mm256_maddubs_epi16(qx45, qy45);\n\n            sumi0 = _mm256_add_epi16(sumi0, qx01);\n            sumi1 = _mm256_add_epi16(sumi1, qx23);\n            sumi2 = _mm256_add_epi16(sumi2, qx45);\n        }\n\n        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);\n        const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d));\n\n        
sumi0 = _mm256_sub_epi16(sumi0, ysum);\n        sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));\n        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));\n\n        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);\n    }\n\n    *s = hsum_float_8(sumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_tq2_0 * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__AVX2__)\n    __m256 sumf = _mm256_setzero_ps();\n\n    for (int i = 0; i < nb; ++i) {\n        // 16-bit sums, because 256*127 still fits\n        __m256i sumi0 = _mm256_setzero_si256();\n        __m256i sumi1 = _mm256_setzero_si256();\n\n        for (size_t j = 0; j < sizeof(x->qs); j += 32) {\n            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));\n            __m256i qx1 = _mm256_srli_epi16(qx0, 2);\n            __m256i qx2 = _mm256_srli_epi16(qx0, 4);\n            __m256i qx3 = _mm256_srli_epi16(qx0, 6);\n\n            // 0, 1, 2 (should not be 3)\n            qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));\n            qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));\n            qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));\n            qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));\n\n            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 +  0));\n            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));\n            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));\n            const __m256i qy3 = 
_mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));\n\n            qx0 = _mm256_maddubs_epi16(qx0, qy0);\n            qx1 = _mm256_maddubs_epi16(qx1, qy1);\n            qx2 = _mm256_maddubs_epi16(qx2, qy2);\n            qx3 = _mm256_maddubs_epi16(qx3, qy3);\n\n            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));\n            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));\n        }\n\n        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);\n        const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d));\n\n        sumi0 = _mm256_add_epi16(sumi0, sumi1);\n        sumi0 = _mm256_sub_epi16(sumi0, ysum);\n        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));\n\n        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);\n    }\n\n    *s = hsum_float_8(sumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q2_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __AVX2__\n\n    const __m256i m3 = _mm256_set1_epi8(3);\n    const __m128i m4 = _mm_set1_epi8(0xF);\n\n    __m256 acc = _mm256_setzero_ps();\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const uint8_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, 
m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        const __m256i mins = _mm256_cvtepi8_epi16(mins8);\n        const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));\n\n        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);\n\n        const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);\n        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);\n        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);\n        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};\n\n        __m256i sumi = _mm256_setzero_si256();\n\n        for (int j = 0; j < QK_K/128; ++j) {\n\n            const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;\n\n            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n\n            const __m256i q2_0 = _mm256_and_si256(q2bits, m3);\n            const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);\n            const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);\n            const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);\n\n            __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);\n            __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);\n            __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);\n            __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);\n\n            p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);\n            p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);\n            p2 = 
_mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);\n            p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);\n\n            p0 = _mm256_add_epi32(p0, p1);\n            p2 = _mm256_add_epi32(p2, p3);\n\n            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));\n        }\n\n        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);\n\n    }\n\n    *s = hsum_float_8(acc);\n\n#elif defined __AVX__\n\n    const __m128i m3 = _mm_set1_epi8(0x3);\n    const __m128i m4 = _mm_set1_epi8(0xF);\n    const __m128i m2 = _mm_set1_epi8(0x2);\n\n    __m256 acc = _mm256_setzero_ps();\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const uint8_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        // load mins and scales from block_q2_K.scales[QK_K/16]\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);\n        const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));\n\n        // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2\n        const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));\n        const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));\n\n        // sumf += -dmin * summs in 32bits*8\n        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);\n\n        const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);\n        const __m128i scales_1 
= _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));\n        const __m128i scales[2] = { scales_0, scales_1 };\n\n        __m128i sumi_0 = _mm_setzero_si128();\n        __m128i sumi_1 = _mm_setzero_si128();\n\n        for (int j = 0; j < QK_K/128; ++j) {\n\n            // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]\n            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n\n            // load 2bits*16*8 from block_q2_K.qs[QK_K/4]\n            __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;\n            const __m128i q2_0 = _mm_and_si128(q2bits, m3);\n            const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);\n            const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);\n            const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);\n            q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;\n            const __m128i q2_1 = _mm_and_si128(q2bits, m3);\n            const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);\n            const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);\n            const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);\n\n            // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8\n            __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);\n            __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);\n           
 __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);\n            __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);\n            __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);\n            __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);\n            __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);\n            __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);\n\n            // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8\n            __m128i shuffle = _mm_set1_epi16(0x0100);\n            p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);\n\n            p0 = _mm_add_epi32(p0, p1);\n            p2 = _mm_add_epi32(p2, p3);\n            p4 = _mm_add_epi32(p4, p5);\n            p6 = _mm_add_epi32(p6, p7);\n\n            // isum in 32bits*4*2\n            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));\n            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));\n        }\n\n        // sumf += dall * isum - dmin * summs in 32bits\n        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);\n        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), 
_mm256_cvtepi32_ps(sumi)), acc);\n    }\n\n    *s = hsum_float_8(acc);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    const block_q3_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __AVX2__\n\n    const __m256i m3 = _mm256_set1_epi8(3);\n    const __m256i mone = _mm256_set1_epi8(1);\n    const __m128i m32 = _mm_set1_epi8(32);\n\n    __m256 acc = _mm256_setzero_ps();\n\n    uint32_t aux[3];\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        // Set up scales\n        memcpy(aux, x[i].scales, 12);\n        __m128i scales128 = _mm_set_epi32(\n                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),\n                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),\n                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),\n                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));\n        scales128 = _mm_sub_epi8(scales128, m32);\n        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);\n        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);\n        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);\n        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};\n\n        // high bit\n        const __m256i hbits = 
_mm256_loadu_si256((const __m256i*)x[i].hmask);\n\n        // integer accumulator\n        __m256i sumi = _mm256_setzero_si256();\n\n        int bit = 0;\n        int is  = 0;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            // load low 2 bits\n            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;\n\n            // prepare low and high bits\n            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);\n            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);\n            ++bit;\n\n            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);\n            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);\n            ++bit;\n\n            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);\n            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);\n            ++bit;\n\n            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);\n            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);\n            ++bit;\n\n            // load Q8 quants\n            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n\n            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,\n            // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,\n            // and 2 if the high bit was set)\n            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);\n            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);\n            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);\n            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);\n\n            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);\n            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);\n            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);\n            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);\n\n            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);\n            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);\n            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);\n            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);\n\n            // multiply with scales\n            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);\n            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);\n            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);\n            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);\n\n            // accumulate\n            p16_0 = _mm256_add_epi32(p16_0, p16_1);\n            p16_2 = _mm256_add_epi32(p16_2, p16_3);\n            sumi  = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));\n\n        }\n\n        // multiply with block scale and accumulate\n        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);\n\n    }\n\n    *s = hsum_float_8(acc);\n\n#elif defined __AVX__\n\n    const __m128i m3 = _mm_set1_epi8(3);\n    const __m128i mone = _mm_set1_epi8(1);\n    const __m128i m32 = _mm_set1_epi8(32);\n    const __m128i m2 = _mm_set1_epi8(2);\n\n    __m256 acc = 
_mm256_setzero_ps();\n\n    const uint32_t *aux;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        // Set up scales\n        aux = (const uint32_t *)x[i].scales;\n        __m128i scales128 = _mm_set_epi32(\n                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),\n                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),\n                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),\n                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));\n        scales128 = _mm_sub_epi8(scales128, m32);\n        const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);\n        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));\n        const __m128i scales[2] = { scales_0, scales_1 };\n\n        // high bit *128*2 from block_q3_K.hmask[QK_K/8]\n        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);\n        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);\n\n        // integer accumulator\n        __m128i sumi_0 = _mm_setzero_si128();\n        __m128i sumi_1 = _mm_setzero_si128();\n\n        for (int j = 0; j < QK_K/128; ++j) {\n            // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]\n            const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;\n            const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;\n\n            // prepare low and high bits\n            const int bit = j << 2;\n\n            const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);\n            const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);\n            const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);\n            const __m128i q3h_1 = 
_mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);\n\n            const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);\n            const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);\n            const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);\n            const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);\n\n            const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);\n            const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);\n            const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);\n            const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);\n\n            const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);\n            const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);\n            const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);\n            const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);\n\n            // load Q8 quants from block_q8_K.qs[QK_K]\n            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_6 = _mm_loadu_si128((const 
__m128i*)q8); q8 += 16;\n            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n\n            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,\n            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,\n            // and 2 if the high bit was set)\n            __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);\n            __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);\n            __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);\n            __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);\n            __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);\n            __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);\n            __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);\n            __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);\n\n            __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);\n            __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);\n            __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);\n            __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);\n            __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);\n            __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);\n            __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);\n            __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);\n\n            p16_0 = _mm_sub_epi16(p16_0, q8s_0);\n            p16_1 = _mm_sub_epi16(p16_1, q8s_1);\n            p16_2 = _mm_sub_epi16(p16_2, q8s_2);\n            p16_3 = _mm_sub_epi16(p16_3, q8s_3);\n            p16_4 = _mm_sub_epi16(p16_4, q8s_4);\n            p16_5 = _mm_sub_epi16(p16_5, q8s_5);\n            p16_6 = _mm_sub_epi16(p16_6, q8s_6);\n            p16_7 = _mm_sub_epi16(p16_7, q8s_7);\n\n            // multiply with scales\n            __m128i shuffle = _mm_set1_epi16(0x0100);\n            p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);\n            shuffle = 
_mm_add_epi16(shuffle, m2);\n            p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);\n\n            // accumulate\n            p16_0 = _mm_add_epi32(p16_0, p16_1);\n            p16_2 = _mm_add_epi32(p16_2, p16_3);\n            p16_4 = _mm_add_epi32(p16_4, p16_5);\n            p16_6 = _mm_add_epi32(p16_6, p16_7);\n            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));\n            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));\n\n        }\n\n        // multiply with block scale and accumulate\n        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);\n        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);\n\n    }\n\n    *s = hsum_float_8(acc);\n\n#else\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    
UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined __AVX2__\n\n    const __m256i m4 = _mm256_set1_epi8(0xF);\n\n    __m256 acc = _mm256_setzero_ps();\n    __m128 acc_m = _mm_setzero_ps();\n\n   for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));\n\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);\n        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);\n        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);\n\n        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);\n        const __m256i scales = MM256_SET_M128I(sc128, sc128);\n\n        __m256i sumi = _mm256_setzero_si256();\n\n        for (int j = 0; j < QK_K/64; ++j) {\n\n            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));\n            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));\n\n          
  const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;\n            const __m256i q4l = _mm256_and_si256(q4bits, m4);\n            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);\n\n            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);\n            p16l = _mm256_madd_epi16(scale_l, p16l);\n\n            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);\n            p16h = _mm256_madd_epi16(scale_h, p16h);\n            const __m256i sumj = _mm256_add_epi32(p16l, p16h);\n\n            sumi = _mm256_add_epi32(sumi, sumj);\n        }\n\n        __m256 vd = _mm256_set1_ps(d);\n        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);\n\n    }\n\n    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));\n    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));\n\n    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);\n\n#elif defined __AVX__\n\n    const __m128i m4 = _mm_set1_epi8(0xF);\n    const __m128i m2 = _mm_set1_epi8(0x2);\n\n    __m256 acc = _mm256_setzero_ps();\n    __m128 acc_m = _mm_setzero_ps();\n\n   for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);\n        const __m128i scales = _mm_cvtepu8_epi16(utmps);\n        const __m128i mins = 
_mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));\n\n        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);\n        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);\n        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);\n        const __m128i prod = _mm_madd_epi16(mins, q8s);\n        acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);\n\n        __m128i sumi_0 = _mm_setzero_si128();\n        __m128i sumi_1 = _mm_setzero_si128();\n\n        __m128i shuffle = _mm_set1_epi16(0x0100);\n        for (int j = 0; j < QK_K/64; ++j) {\n\n            const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);\n            shuffle = _mm_add_epi16(shuffle, m2);\n\n            __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;\n            const __m128i q4l_0 = _mm_and_si128(q4bits, m4);\n            const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);\n            q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;\n            const __m128i q4l_1 = _mm_and_si128(q4bits, m4);\n            const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);\n\n            const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);\n            p16l = _mm_madd_epi16(scale_l, p16l);\n            sumi_0 = _mm_add_epi32(sumi_0, p16l);\n            const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            p16l = _mm_maddubs_epi16(q4l_1, q8l_1);\n            p16l = _mm_madd_epi16(scale_l, p16l);\n            sumi_1 = _mm_add_epi32(sumi_1, p16l);\n\n            const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);\n            p16h = _mm_madd_epi16(scale_h, p16h);\n          
  sumi_0 = _mm_add_epi32(sumi_0, p16h);\n            const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            p16h = _mm_maddubs_epi16(q4h_1, q8h_1);\n            p16h = _mm_madd_epi16(scale_h, p16h);\n            sumi_1 = _mm_add_epi32(sumi_1, p16h);\n\n        }\n\n        __m256 vd = _mm256_set1_ps(d);\n        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);\n        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);\n\n    }\n\n    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));\n    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));\n\n    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n#if defined __AVX2__\n\n    const __m256i m4 = _mm256_set1_epi8(0xF);\n    const __m128i mzero = _mm_setzero_si128();\n    const __m256i mone  = _mm256_set1_epi8(1);\n\n    __m256 acc = _mm256_setzero_ps();\n\n    float summs = 0.f;\n\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * GGML_RESTRICT q5 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        memcpy(utmp, 
x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));\n\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);\n        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);\n        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);\n        summs += dmin * _mm_extract_epi32(hsum, 0);\n\n        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);\n        const __m256i scales = MM256_SET_M128I(sc128, sc128);\n\n        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);\n        __m256i hmask = mone;\n\n        __m256i sumi = _mm256_setzero_si256();\n\n        int bit = 0;\n\n        for (int j = 0; j < QK_K/64; ++j) {\n\n            const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));\n            const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));\n\n            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;\n\n            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);\n            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);\n            const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);\n            hmask = _mm256_slli_epi16(hmask, 1);\n\n            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);\n            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);\n            const 
__m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);\n            hmask = _mm256_slli_epi16(hmask, 1);\n\n            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n\n            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);\n            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);\n\n            p16_0 = _mm256_madd_epi16(scale_0, p16_0);\n            p16_1 = _mm256_madd_epi16(scale_1, p16_1);\n\n            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));\n\n        }\n\n        __m256 vd = _mm256_set1_ps(d);\n        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);\n\n    }\n\n    *s = hsum_float_8(acc) + summs;\n\n#elif defined __AVX__\n\n    const __m128i m4 = _mm_set1_epi8(0xF);\n    const __m128i mzero = _mm_setzero_si128();\n    const __m128i mone  = _mm_set1_epi8(1);\n    const __m128i m2 = _mm_set1_epi8(2);\n\n    __m256 acc = _mm256_setzero_ps();\n\n    float summs = 0.f;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        const uint8_t * GGML_RESTRICT q5 = x[i].qs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);\n        const __m128i scales = _mm_cvtepu8_epi16(utmps);\n        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));\n\n        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);\n        const __m128i q8sums_1 = _mm_loadu_si128((const 
__m128i*)&y[i].bsums[8]);\n        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);\n        const __m128i prod = _mm_madd_epi16(mins, q8s);\n        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);\n        summs += dmin * _mm_extract_epi32(hsum, 0);\n\n        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);\n        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);\n        __m128i hmask = mone;\n\n        __m128i sumi_0 = _mm_setzero_si128();\n        __m128i sumi_1 = _mm_setzero_si128();\n\n        int bit = 0;\n\n        __m128i shuffle = _mm_set1_epi16(0x0100);\n        for (int j = 0; j < QK_K/64; ++j) {\n\n            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);\n            shuffle = _mm_add_epi16(shuffle, m2);\n            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);\n            shuffle = _mm_add_epi16(shuffle, m2);\n\n            const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;\n            const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;\n\n            __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);\n            __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);\n            __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);\n            __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);\n            __m128i q5_0  = _mm_add_epi8(q5l_0, q5h_0);\n            __m128i q5_1  = _mm_add_epi8(q5l_1, q5h_1);\n            hmask = _mm_slli_epi16(hmask, 1);\n\n            __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);\n            __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);\n            p16_0 = _mm_madd_epi16(scale_0, p16_0);\n            p16_1 = _mm_madd_epi16(scale_0, p16_1);\n\n            q5l_0 = 
_mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);\n            q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);\n            q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);\n            q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);\n            q5_0  = _mm_add_epi8(q5l_0, q5h_0);\n            q5_1  = _mm_add_epi8(q5l_1, q5h_1);\n            hmask = _mm_slli_epi16(hmask, 1);\n\n            q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);\n            __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);\n            p16_2 = _mm_madd_epi16(scale_1, p16_2);\n            p16_3 = _mm_madd_epi16(scale_1, p16_3);\n\n            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));\n            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));\n\n        }\n\n        __m256 vd = _mm256_set1_ps(d);\n        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);\n        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);\n\n    }\n\n    *s = hsum_float_8(acc) + summs;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    UNUSED(utmp);\n    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q6_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __AVX2__\n\n    const __m256i m4 = _mm256_set1_epi8(0xF);\n    const __m256i m2 = _mm256_set1_epi8(3);\n    const __m256i m32s = 
_mm256_set1_epi8(32);\n\n    __m256 acc = _mm256_setzero_ps();\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].ql;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n\n        __m256i sumi = _mm256_setzero_si256();\n\n        int is = 0;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n\n            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));\n            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));\n            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));\n            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));\n            is += 4;\n\n            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;\n            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;\n            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;\n\n            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);\n            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);\n            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);\n            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);\n\n            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);\n            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);\n            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);\n            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), 
q4h_3);\n\n            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n\n            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);\n            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);\n            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);\n            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);\n\n            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);\n            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);\n            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);\n            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);\n\n            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);\n            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);\n            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);\n            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);\n\n            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);\n            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);\n            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);\n            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);\n\n            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));\n            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));\n\n        }\n\n        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);\n    }\n\n    *s = hsum_float_8(acc);\n\n#elif defined __AVX__\n\n    const __m128i m3 = _mm_set1_epi8(3);\n    const __m128i m15 = _mm_set1_epi8(15);\n\n    __m256 acc = _mm256_setzero_ps();\n\n    for (int i = 0; i < nb; ++i) {\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q4 = x[i].ql;\n        
const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        // handle the q6_k -32 offset separately using bsums\n        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);\n        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);\n        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);\n        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));\n        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);\n        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);\n\n        __m128i sumi_0 = _mm_setzero_si128();\n        __m128i sumi_1 = _mm_setzero_si128();\n\n        int is = 0;\n\n        for (int j = 0; j < QK_K/128; ++j) {\n\n            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;\n            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;\n\n            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);\n            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);\n            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);\n            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);\n            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));\n            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));\n            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);\n            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);\n\n            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;\n            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;\n            const 
__m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;\n            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;\n\n            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);\n            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);\n            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);\n            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);\n            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);\n            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);\n            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);\n            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);\n\n            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;\n\n            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);\n            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);\n            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);\n            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);\n            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);\n            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);\n            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);\n        
    __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);\n\n            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));\n            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));\n            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));\n            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));\n            is += 4;\n\n            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);\n            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);\n            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);\n            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);\n            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);\n            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);\n            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);\n            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);\n\n            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));\n            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));\n            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));\n            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));\n\n        }\n\n        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);\n        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);\n        const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);\n        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);\n    }\n\n    *s = hsum_float_8(acc);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\n#if defined (__AVX__) || defined (__AVX2__)\nstatic const int8_t keven_signs_q2xs[1024] = {\n     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  
1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,\n     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,\n     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,\n     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,\n     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,\n     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,\n     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,\n     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,\n     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,\n     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,\n     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,\n     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,\n     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,\n     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,\n     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,\n     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  
1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,\n     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,\n     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,\n     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,\n     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,\n     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,\n     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,\n     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,\n     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,\n     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,\n     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,\n     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,\n     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,\n     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,\n     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,\n     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, 
-1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,\n     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,\n};\n#endif\n\nvoid ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__AVX2__)\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    uint32_t aux32[4];\n    const uint8_t * aux8 = (const uint8_t *)aux32;\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        __m256i sumi1 = _mm256_setzero_si256();\n        __m256i sumi2 = _mm256_setzero_si256();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;\n            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);\n            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);\n            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],\n                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);\n            const __m256i 
s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],\n                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);\n            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);\n            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);\n            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);\n            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);\n            const uint16_t ls1 = aux32[1] >> 28;\n            const uint16_t ls2 = aux32[3] >> 28;\n            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));\n            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));\n            sumi1 = _mm256_add_epi32(sumi1, p1);\n            sumi2 = _mm256_add_epi32(sumi2, p2);\n        }\n\n        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);\n\n    }\n\n    *s = 0.125f * hsum_float_8(accumf);\n\n#elif defined(__AVX__)\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    uint32_t aux32[4];\n    const uint8_t * aux8 = (const uint8_t *)aux32;\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        __m128i sumi1_0 = _mm_setzero_si128();\n        __m128i sumi1_1 = _mm_setzero_si128();\n        __m128i sumi2_0 = _mm_setzero_si128();\n        __m128i sumi2_1 = _mm_setzero_si128();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_1 = 
_mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;\n            const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);\n            const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);\n            const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);\n            const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);\n            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);\n            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);\n            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);\n            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);\n            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);\n            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);\n            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);\n            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);\n            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);\n            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);\n            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);\n            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);\n            const uint16_t ls1 = aux32[1] >> 28;\n            const uint16_t ls2 = aux32[3] >> 28;\n            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));\n            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));\n            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));\n            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));\n            sumi1_0 = 
_mm_add_epi32(sumi1_0, p1_0);\n            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);\n            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);\n            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);\n        }\n\n        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);\n\n    }\n\n    *s = 0.125f * hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__AVX2__)\n\n    const __m256i mone = _mm256_set1_epi8(1);\n    static const char block_sign_shuffle_mask_1[32] = {\n        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,\n        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,\n    };\n    static const char block_sign_shuffle_mask_2[32] = {\n        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,\n        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,\n    };\n    static const uint8_t bit_selector_mask_bytes[32] = {\n        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);\n    
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);\n    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);\n\n    static const uint8_t k_bit_helper[32] = {\n        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,\n        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,\n    };\n    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);\n    const __m256i m511 = _mm256_set1_epi16(511);\n    const __m128i m4 = _mm_set1_epi8(0xf);\n    const __m128i m1 = _mm_set1_epi8(1);\n\n    uint64_t aux64;\n\n    // somewhat hacky, but gives a significant boost in performance\n    __m256i aux_gindex;\n    const uint16_t * gindex = (const uint16_t *)&aux_gindex;\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n\n        memcpy(&aux64, x[i].scales, 8);\n        __m128i stmp = _mm_set1_epi64x(aux64);\n        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));\n        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);\n\n        __m256i sumi1 = _mm256_setzero_si256();\n        __m256i sumi2 = _mm256_setzero_si256();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {\n\n            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2);  q2 += 16;\n            aux_gindex = _mm256_and_si256(q2_data, m511);\n\n            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);\n            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);\n            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);\n\n          
  const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);\n            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);\n\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n\n            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],\n                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);\n            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],\n                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);\n            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],\n                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);\n            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],\n                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);\n\n            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);\n            const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);\n            const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);\n            const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);\n\n            __m256i signs;\n            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);\n            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);\n            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));\n\n            signs = 
_mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);\n            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);\n            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));\n\n            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);\n            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);\n            const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));\n\n            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);\n            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);\n            const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));\n\n            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);\n            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);\n            const __m256i dot3  = _mm256_maddubs_epi16(q2_3, q8s_3);\n            const __m256i dot4  = _mm256_maddubs_epi16(q2_4, q8s_4);\n\n            const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));\n            const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));\n            const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));\n            const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));\n\n            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));\n            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));\n            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));\n            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));\n        }\n\n        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);\n\n    }\n\n    *s = 0.125f * 
hsum_float_8(accumf);\n\n#elif defined(__AVX__)\n    const __m128i mone = _mm_set1_epi8(1);\n    static const char block_sign_shuffle_mask_1[32] = {\n        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,\n        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,\n    };\n    static const char block_sign_shuffle_mask_2[32] = {\n        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,\n        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,\n    };\n    static const uint8_t bit_selector_mask_bytes[32] = {\n        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n    const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);\n    const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);\n    const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);\n    const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);\n    const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);\n    const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);\n\n    static const uint8_t k_bit_helper[32] = {\n        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,\n        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,\n    };\n    const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);\n    const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);\n    const __m128i 
m511 = _mm_set1_epi16(511);\n    const __m128i m4 = _mm_set1_epi8(0xf);\n    const __m128i m1 = _mm_set1_epi8(1);\n\n    uint64_t aux64;\n\n    // somewhat hacky, but gives a significant boost in performance\n    __m256i aux_gindex;\n    const uint16_t * gindex = (const uint16_t *)&aux_gindex;\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n\n        memcpy(&aux64, x[i].scales, 8);\n        __m128i stmp = _mm_set1_epi64x(aux64);\n        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));\n        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);\n\n        __m128i sumi1_0 = _mm_setzero_si128();\n        __m128i sumi1_1 = _mm_setzero_si128();\n        __m128i sumi2_0 = _mm_setzero_si128();\n        __m128i sumi2_1 = _mm_setzero_si128();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {\n\n            const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);\n            const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1);  q2 += 16;\n            aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));\n\n            const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);\n            const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);\n            const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);\n            const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);\n            const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);\n            const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);\n\n            const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, 
partial_sign_bits_for_counting_0);\n            const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);\n            const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);\n            const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);\n\n            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n\n            const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);\n            const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);\n            const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);\n            const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);\n            const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);\n            const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);\n            const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);\n            const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);\n\n            // AVX2 full_signs_1 is full_sign_bits_0 here\n            // AVX2 full_signs_2 is full_sign_bits_1 here\n            __m128i signs_0, signs_1;\n            signs_0 = 
_mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);\n            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);\n            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);\n            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);\n            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));\n            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));\n\n            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);\n            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);\n            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);\n            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);\n            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));\n            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));\n\n            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);\n            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);\n            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);\n            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);\n            const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));\n            const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));\n\n            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);\n            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);\n            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);\n            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), 
bit_selector_mask_1);\n            const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));\n            const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));\n\n            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);\n            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);\n            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);\n            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);\n            const __m128i dot3_0  = _mm_maddubs_epi16(q2_3_0, q8s_3_0);\n            const __m128i dot3_1  = _mm_maddubs_epi16(q2_3_1, q8s_3_1);\n            const __m128i dot4_0  = _mm_maddubs_epi16(q2_4_0, q8s_4_0);\n            const __m128i dot4_1  = _mm_maddubs_epi16(q2_4_1, q8s_4_1);\n\n            __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));\n            const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);\n            const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));\n            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));\n            const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);\n            const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));\n            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));\n            const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);\n            const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));\n            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));\n            const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);\n            const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));\n\n            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));\n            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));\n            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));\n            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));\n            
sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));\n            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));\n            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));\n            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));\n        }\n\n        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);\n\n    }\n\n    *s = 0.125f * hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__AVX2__)\n\n   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n   };\n\n    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n    const __m128i m4 = _mm_set1_epi8(0xf);\n    const __m128i m1 = _mm_set1_epi8(1);\n\n    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);\n    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);\n\n    uint64_t aux64;\n\n    __m256 accumf = _mm256_setzero_ps();\n    for 
(int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        memcpy(&aux64, x[i].scales, 8);\n        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);\n        const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15\n\n        __m256i sumi1 = _mm256_setzero_si256();\n        __m256i sumi2 = _mm256_setzero_si256();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],\n                                                   iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],\n                                                   iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],\n                                                   iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);\n            const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],\n                                                   iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],\n                                                   iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],\n                                                   iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);\n            qs += 8;\n\n            __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));\n            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);\n            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, 
mask2);\n            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);\n\n            aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));\n            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);\n            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);\n            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);\n\n            signs += 4;\n\n            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1\n            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3\n\n            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0)));\n            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1)));\n            sumi1 = _mm256_add_epi32(sumi1, p1);\n            sumi2 = _mm256_add_epi32(sumi2, p2);\n        }\n\n        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);\n\n    }\n\n    *s = 0.125f * hsum_float_8(accumf);\n\n#elif defined(__AVX__)\n   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n   };\n\n    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n    const __m128i m4 = _mm_set1_epi8(0xf);\n    const __m128i m1 = _mm_set1_epi8(1);\n\n    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);\n    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);\n    const 
__m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);\n    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);\n\n    uint64_t aux64;\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n\n        memcpy(&aux64, x[i].scales, 8);\n        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);\n        const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);\n        const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));\n\n        __m128i sumi1_0 = _mm_setzero_si128();\n        __m128i sumi1_1 = _mm_setzero_si128();\n        __m128i sumi2_0 = _mm_setzero_si128();\n        __m128i sumi2_1 = _mm_setzero_si128();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],\n                                                  iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);\n            const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],\n                                                  iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);\n            const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],\n                                                  
iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);\n            const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],\n                                                  iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);\n            qs += 8;\n\n            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));\n            __m128i aux128_1 = aux128_0;\n            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);\n            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);\n            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);\n            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);\n            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);\n            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);\n\n            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));\n            aux128_1 = aux128_0;\n            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);\n            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);\n            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);\n            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);\n            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);\n            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);\n\n            signs += 4;\n\n            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);\n            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);\n            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);\n            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);\n\n            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));\n            const __m128i p1_1 
= _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));\n            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));\n            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));\n            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);\n            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);\n            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);\n            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);\n        }\n\n        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);\n\n    }\n\n    *s = 0.125f * hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__AVX2__)\n\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    uint32_t aux32[2];\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        __m256i sumi1 = _mm256_setzero_si256();\n        __m256i sumi2 = _mm256_setzero_si256();\n        for (int ib32 = 0; ib32 < 
QK_K/32; ib32 += 2) {\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],\n                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);\n            q3 += 8;\n            const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],\n                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);\n            q3 += 8;\n            memcpy(aux32, gas, 8); gas += 8;\n            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],\n                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);\n            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],\n                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);\n            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);\n            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);\n            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);\n            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);\n            const uint16_t ls1 = aux32[0] >> 28;\n            const uint16_t ls2 = aux32[1] >> 28;\n            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));\n            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));\n            sumi1 = _mm256_add_epi32(sumi1, p1);\n            sumi2 = _mm256_add_epi32(sumi2, p2);\n        }\n\n        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), 
_mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);\n\n    }\n\n    *s = 0.25f * hsum_float_8(accumf);\n\n#elif defined(__AVX__)\n    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;\n\n    uint32_t aux32[2];\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        __m128i sumi1_0 = _mm_setzero_si128();\n        __m128i sumi1_1 = _mm_setzero_si128();\n        __m128i sumi2_0 = _mm_setzero_si128();\n        __m128i sumi2_1 = _mm_setzero_si128();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);\n            const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);\n            q3 += 8;\n            const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);\n            const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);\n            q3 += 8;\n            memcpy(aux32, gas, 8); gas += 8;\n            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);\n            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);\n            const 
__m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);\n            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);\n            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);\n            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);\n            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);\n            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);\n            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);\n            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);\n            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);\n            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);\n            const uint16_t ls1 = aux32[0] >> 28;\n            const uint16_t ls2 = aux32[1] >> 28;\n            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));\n            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));\n            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));\n            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));\n            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);\n            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);\n            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);\n            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);\n        }\n\n        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);\n\n    }\n\n    *s = 0.25f * hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K 
== 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined(__AVX2__)\n\n   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n   };\n\n    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);\n    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);\n\n    const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);\n    const __m256i idx_mask  = _mm256_set1_epi32(256);\n\n    typedef union {\n        __m256i  vec[2];\n        uint32_t index[16];\n    } index_t;\n\n    index_t idx;\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        __m256i sumi1 = _mm256_setzero_si256();\n        __m256i sumi2 = _mm256_setzero_si256();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i 
*)qs)); qs += 16;\n            idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);\n            idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);\n            idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);\n            idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);\n            idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));\n            idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));\n\n            // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.\n            //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);\n            //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);\n            const __m256i q2_1 = _mm256_set_epi32(\n                    iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],\n                    iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]\n            );\n            const __m256i q2_2 = _mm256_set_epi32(\n                    iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],\n                    iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]\n            );\n\n            __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));\n            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);\n            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);\n            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);\n\n            aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));\n            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);\n            const __m256i s2_2 = 
_mm256_cmpeq_epi8(aux256, mask2);\n            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);\n\n            signs += 4;\n\n            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);\n            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);\n            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;\n            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;\n            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));\n            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));\n            sumi1 = _mm256_add_epi32(sumi1, p1);\n            sumi2 = _mm256_add_epi32(sumi2, p2);\n        }\n\n        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);\n\n    }\n\n    *s = hsum_float_8(accumf);\n\n#elif defined(__AVX__)\n   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,\n                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03\n   };\n\n    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,\n    };\n\n    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);\n    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);\n    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);\n    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);\n\n    const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);\n    const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);\n    const __m128i idx_mask  = _mm_set1_epi32(256);\n\n    typedef union {\n        __m128i  vec[4];\n        uint32_t index[16];\n    } 
index_t;\n\n    index_t idx;\n\n    __m256 accumf = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        __m128i sumi1_0 = _mm_setzero_si128();\n        __m128i sumi1_1 = _mm_setzero_si128();\n        __m128i sumi2_0 = _mm_setzero_si128();\n        __m128i sumi2_1 = _mm_setzero_si128();\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);\n            const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);\n            const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;\n            idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);\n            idx.vec[1] = idx.vec[0];\n            idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);\n            idx.vec[3] = idx.vec[2];\n\n            idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);\n            idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);\n            idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);\n            idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);\n\n            idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));\n            idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));\n            idx.vec[2] = 
_mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));\n            idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));\n\n            const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);\n            const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);\n            const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);\n            const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);\n\n            __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));\n            __m128i aux128_1 = aux128_0;\n            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);\n            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);\n            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);\n            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);\n            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);\n            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);\n\n            aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));\n            aux128_1 = aux128_0;\n            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);\n            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);\n            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);\n            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);\n            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);\n            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);\n\n            signs += 
4;\n\n            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);\n            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);\n            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);\n            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);\n            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;\n            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;\n            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));\n            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));\n            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));\n            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));\n            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);\n            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);\n            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);\n            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);\n        }\n\n        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);\n\n    }\n\n    *s = hsum_float_8(accumf);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __AVX2__\n\n    __m256 accum = _mm256_setzero_ps();\n    float accum1 = 0;\n    for (int i = 0; i < nb; ++i) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint16_t * qh = 
x[i].qh;\n\n        __m256i sumi = _mm256_setzero_si256();\n        int sumi1 = 0;\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n#ifdef __BMI2__\n            const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL);\n            const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL);\n            const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);\n            const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);\n            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);\n            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);\n#else\n            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],\n                                                    iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);\n            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],\n                                                    iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);\n#endif\n            qs += 8;\n            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n\n            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);\n            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);\n            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;\n            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;\n            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));\n            const __m256i 
p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));\n\n            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));\n            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1\n                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;\n        }\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);\n        accum1 += d * sumi1;\n\n    }\n\n    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;\n\n#elif defined __AVX__\n    __m256 accum = _mm256_setzero_ps();\n    float accum1 = 0;\n    for (int i = 0; i < nb; ++i) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint16_t * qh = x[i].qh;\n\n        __m128i sumi1_0 = _mm_setzero_si128();\n        __m128i sumi1_1 = _mm_setzero_si128();\n        int sumi1 = 0;\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n            const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);\n            const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);\n            const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);\n            const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);\n            qs += 8;\n            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n\n            const __m128i dot1_0 = 
mul_add_epi8_sse(q1b_1_0, q8b_1_0);\n            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);\n            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);\n            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);\n            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;\n            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;\n            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));\n            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));\n            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));\n            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));\n\n            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));\n            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));\n            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1\n                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? 
-1 : 1) * ls2;\n        }\n\n        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);\n        accum1 += d * sumi1;\n\n    }\n\n    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_m * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    iq1m_scale_t scale;\n\n#if defined __AVX2__\n\n    const __m256i mask = _mm256_set1_epi16(0x7);\n    const __m256i mone = _mm256_set1_epi16(1);\n    const __m256i mone8 = _mm256_set1_epi8(1);\n    const __m256i mtwo8 = _mm256_set1_epi8(2);\n    // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half.\n    const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);\n\n    __m256 accum1 = _mm256_setzero_ps();\n    __m256 accum2 = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint8_t  * qh = x[i].qh;\n        const uint16_t * sc = (const uint16_t *)x[i].scales;\n\n        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n        // Extract 3-bit scales (16 values)\n        __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);\n        scales = _mm256_srlv_epi64(scales, scales_shift);\n        scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);\n\n        // Indices to repeat each scale 8 times.\n      
  __m256i scales_idx1 = _mm256_set1_epi16(0x0100);\n        __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));\n\n        __m256i sumi1 = _mm256_setzero_si256();\n        __m256i sumi2 = _mm256_setzero_si256();\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n#ifdef __BMI2__\n            const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL)\n                                       | _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL);\n            const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL)\n                                       | _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL);\n            const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);\n            const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);\n            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);\n            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);\n\n            // Convert signs to bytes 0x81 (negative) or 0x01 (positive)\n            const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL);\n            const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign)));\n            const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32)));\n#else\n            const __m256i q1b_1 = _mm256_set_epi64x(\n                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],\n                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]\n            );\n            const __m256i q1b_2 = _mm256_set_epi64x(\n                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 
0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],\n                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]\n            );\n\n            const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);\n            const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[2] & 0x08 ? 
0xffffffffffffffff : 0x0101010101010101);\n#endif\n            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;\n\n            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);\n            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);\n            const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));\n            const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));\n\n            __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);\n            __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);\n\n            scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);\n            scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);\n\n            const __m256i p1 = _mm256_madd_epi16(dot1, scale1);\n            const __m256i p2 = _mm256_madd_epi16(dot2, scale2);\n            const __m256i p3 = _mm256_madd_epi16(dot3, scale1);\n            const __m256i p4 = _mm256_madd_epi16(dot4, scale2);\n\n            sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));\n            sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));\n\n            qs += 8; qh += 4;\n        }\n\n        const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16));\n\n        accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);\n        accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);\n    }\n\n    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);\n\n#elif defined __AVX__\n    const __m128i mask = _mm_set1_epi16(0x7);\n    const __m128i mone = _mm_set1_epi16(1);\n\n    __m256 accum1 = _mm256_setzero_ps();\n    __m256 accum2 = _mm256_setzero_ps();\n    for (int i = 0; i < nb; ++i) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint8_t  * qh = x[i].qh;\n        const uint16_t * sc = (const uint16_t 
*)x[i].scales;\n\n        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n\n        __m128i sumi1_0 = _mm_setzero_si128();\n        __m128i sumi1_1 = _mm_setzero_si128();\n        __m128i sumi2_0 = _mm_setzero_si128();\n        __m128i sumi2_1 = _mm_setzero_si128();\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n            const __m128i q1b_1_0 = _mm_set_epi64x(\n                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);\n            const __m128i q1b_1_1 = _mm_set_epi64x(\n                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);\n            const __m128i q1b_2_0 = _mm_set_epi64x(\n                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);\n            const __m128i q1b_2_1 = _mm_set_epi64x(\n                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);\n            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n\n            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);\n            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);\n            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);\n            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);\n\n            const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[0] & 0x08 ? 
0xffffffffffffffff : 0x0101010101010101);\n            const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);\n            const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);\n            const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,\n                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);\n\n            const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);\n            const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);\n            const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);\n            const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);\n\n            __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);\n            __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);\n            __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);\n            __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);\n\n            scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);\n            scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);\n            scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);\n            scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);\n            const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);\n            const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);\n            const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);\n            const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);\n            const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);\n            const __m128i 
p3_1 = _mm_madd_epi16(dot3_1, scale1_1);\n            const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);\n            const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);\n\n            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));\n            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));\n            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));\n            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));\n\n            qs += 8; qh += 4;\n        }\n\n        const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16));\n\n        accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);\n        accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);\n    }\n\n    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    UNUSED(scale);\n    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n\nvoid ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK4_NL == 0);\n    static_assert(QK4_NL == QK8_0, \"QK4_NL and QK8_0 must be the same\");\n\n    const block_iq4_nl * GGML_RESTRICT x = vx;\n    const block_q8_0   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK4_NL;\n\n    int ib = 0;\n    float sumf = 0;\n\n#if defined __AVX2__\n\n    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);\n    const __m128i m4b  = _mm_set1_epi8(0x0f);\n    const __m256i mone = _mm256_set1_epi16(1);\n\n    __m256 accum1 = _mm256_setzero_ps();\n    __m256 accum2 = _mm256_setzero_ps();\n    for (; ib + 1 < nb; ib += 2) {\n        const __m128i q4bits_1 = _mm_loadu_si128((const 
__m128i*)x[ib + 0].qs);\n        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);\n        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);\n        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);\n        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),\n                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));\n        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),\n                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));\n        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);\n        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);\n        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);\n        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);\n        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),\n                _mm256_cvtepi32_ps(p_1), accum1);\n        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),\n                _mm256_cvtepi32_ps(p_2), accum2);\n    }\n\n    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));\n\n#elif defined __AVX__\n    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);\n    const __m128i m4b  = _mm_set1_epi8(0x0f);\n\n    __m256 accum = _mm256_setzero_ps();\n    for (; ib + 1 < nb; ib += 2) {\n        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);\n        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);\n        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);\n        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);\n        const __m128i q8b_2_0 = _mm_loadu_si128((const 
__m128i *)y[ib + 1].qs);\n        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);\n\n        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));\n        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));\n        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));\n        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));\n\n        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);\n        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);\n        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);\n    }\n\n    sumf = hsum_float_8(accum);\n\n#endif\n    for (; ib < nb; ++ib) {\n        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);\n        int sumi1 = 0, sumi2 = 0;\n        for (int j = 0; j < QK4_NL/2; ++j) {\n            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];\n            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];\n        }\n        sumf += d * (sumi1 + sumi2);\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_K == 0);\n\n    const block_iq4_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n#if defined __AVX2__\n\n    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);\n    const __m128i m4b  = _mm_set1_epi8(0x0f);\n\n    __m256 accum = _mm256_setzero_ps();\n    for (int ibl = 0; ibl < nb; ++ibl) {\n        const uint8_t * qs = x[ibl].qs;\n        
const int8_t  * q8 = y[ibl].qs;\n        uint16_t sh = x[ibl].scales_h;\n        __m256i sumi1 = _mm256_setzero_si256();\n        __m256i sumi2 = _mm256_setzero_si256();\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;\n            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;\n            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;\n            const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),\n                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));\n            const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),\n                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));\n            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);\n            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);\n            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;\n            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;\n            sh >>= 4;\n            const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));\n            const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));\n            sumi1 = _mm256_add_epi32(p_1, sumi1);\n            sumi2 = _mm256_add_epi32(p_2, sumi2);\n        }\n        accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),\n                _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);\n    }\n\n    *s = hsum_float_8(accum);\n\n#elif defined __AVX__\n    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);\n    const __m128i m4b  = 
_mm_set1_epi8(0x0f);\n\n    __m256 accum = _mm256_setzero_ps();\n    for (int ibl = 0; ibl < nb; ++ibl) {\n        const uint8_t * qs = x[ibl].qs;\n        const int8_t  * q8 = y[ibl].qs;\n        uint16_t sh = x[ibl].scales_h;\n        __m128i sumi1_0 = _mm_setzero_si128();\n        __m128i sumi1_1 = _mm_setzero_si128();\n        __m128i sumi2_0 = _mm_setzero_si128();\n        __m128i sumi2_1 = _mm_setzero_si128();\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;\n            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;\n            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;\n            const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));\n            const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));\n            const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));\n            const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));\n            const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);\n            const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);\n            const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);\n            const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);\n            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;\n            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;\n            sh >>= 4;\n            const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));\n            
const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));\n            const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));\n            const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));\n            sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);\n            sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);\n            sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);\n            sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);\n        }\n        __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);\n        __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);\n        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),\n                _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);\n    }\n\n    *s = hsum_float_8(accum);\n\n#else\n    UNUSED(x);\n    UNUSED(y);\n    UNUSED(nb);\n    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);\n#endif\n}\n"
  },
  {
    "path": "src/ggml-cpu/arch/x86/repack.cpp",
    "content": "#define GGML_COMMON_IMPL_CPP\n#define GGML_COMMON_DECL_CPP\n#include \"ggml-common.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"simd-mappings.h\"\n#include \"traits.h\"\n\n#include <cmath>\n#include <cstring>\n#include <cassert>\n#include <cstdlib> // for qsort\n#include <cstdio>  // for GGML_ASSERT\n\n#define GGML_CPU_CLANG_WORKAROUND\n#include \"../../repack.h\"\n\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Woverlength-strings\"\n#endif\n\n#define UNUSED GGML_UNUSED\n\n#if defined(__AVX__)\n#if defined(__F16C__)\n#if defined(__AVX512F__)\n#define GGML_F32Cx8x2_LOAD(x, y)     _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x))))\n#define GGML_F32Cx16_REPEAT_LOAD(x)  _mm512_cvtph_ps(_mm256_set_m128i(x, x))\n#endif\n// the  _mm256_cvt intrinsics require F16C\n#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))\n#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask)     _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))\n#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask)     _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))\n#else\n#if defined(__AVX512F__)\nstatic inline __m512 __avx512_f32cx8x2_load(ggml_fp16_t *x, ggml_fp16_t *y) {\n    float tmp[16];\n\n    for (int i = 0; i < 8; i++) {\n        tmp[i] = GGML_CPU_FP16_TO_FP32(x[i]);\n    }\n\n    for (int i = 0; i < 8; i++) {\n        tmp[i + 8] = GGML_CPU_FP16_TO_FP32(y[i]);\n    }\n\n    return _mm512_loadu_ps(tmp);\n}\nstatic inline __m512 __avx512_repeat_f32cx16_load(__m128i x) {\n    float tmp[16];\n    uint16_t tmphalf[8];\n    _mm_storeu_si128((__m128i*)tmphalf, x);\n\n    for (int i = 0; i < 4; i++) {\n        tmp[i] = GGML_CPU_FP16_TO_FP32(tmphalf[i]);\n        tmp[i + 4] = GGML_CPU_FP16_TO_FP32(tmphalf[i]);\n        tmp[i + 8] = 
GGML_CPU_FP16_TO_FP32(tmphalf[i]);\n        tmp[i + 12] = GGML_CPU_FP16_TO_FP32(tmphalf[i]);\n    }\n\n    return _mm512_loadu_ps(tmp);\n}\n#endif\nstatic inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {\n    float tmp[8];\n\n    for (int i = 0; i < 8; i++) {\n        tmp[i] = GGML_CPU_FP16_TO_FP32(x[i]);\n    }\n\n    return _mm256_loadu_ps(tmp);\n}\nstatic inline __m256 __avx_repeat_f32cx8_load(ggml_fp16_t *x) {\n    float tmp[8];\n\n    for (int i = 0; i < 4; i++) {\n        tmp[i] = GGML_CPU_FP16_TO_FP32(x[i]);\n        tmp[i + 4] = GGML_CPU_FP16_TO_FP32(x[i]);\n    }\n\n    return _mm256_loadu_ps(tmp);\n}\nstatic inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrangeMask) {\n    uint16_t tmphalf[8];\n    float tmp[8];\n\n    _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask));\n    for (int i = 0; i < 8; i++) {\n        tmp[i] = GGML_CPU_FP16_TO_FP32(tmphalf[i]);\n    }\n\n    return _mm256_loadu_ps(tmp);\n}\n\n#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)\n#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask)     __avx_repeat_f32cx8_load(x)\n#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask)     __avx_rearranged_f32cx8_load(x, arrangeMask)\n#if defined(__AVX512F__)\n#define GGML_F32Cx8x2_LOAD(x, y)     __avx512_f32cx8x2_load(x, y)\n#define GGML_F32Cx16_REPEAT_LOAD(x)  __avx512_repeat_f32cx16_load(x)\n#endif\n#endif\n#endif\n\nstatic inline int nearest_int(float fval) {\n    assert(fabsf(fval) <= 4194303.f);\n    float val = fval + 12582912.f;\n    int i; memcpy(&i, &val, sizeof(int));\n    return (i & 0x007fffff) - 0x00400000;\n}\n\n#if defined(__AVX2__) || defined(__AVX512F__)\n#if defined(__AVX512F__)\n// add int16_t pairwise and return as 512 bit int vector, then add the accumulator\nstatic inline __m512i sum_i16_pairs_acc_int32x16(const __m512i acc, const __m512i x) {\n    const __m512i ones = _mm512_set1_epi16(1);\n    return _mm512_add_epi32(acc, _mm512_madd_epi16(ones, 
x));\n}\n\nstatic inline __m512i mul_sum_us8_pairs_acc_int32x16(const __m512i acc, const __m512i ax, const __m512i sy) {\n#if defined(__AVX512VNNI__)\n    return _mm512_dpbusd_epi32(acc, ax, sy);\n#else\n    // Perform multiplication and create 16-bit values\n    const __m512i dot = _mm512_maddubs_epi16(ax, sy);\n    return sum_i16_pairs_acc_int32x16(acc, dot);\n#endif\n}\n\n// multiply int8_t, add results pairwise twice and return as 512 bit int vector, then add the accumulator\nstatic inline __m512i mul_sum_i8_pairs_acc_int32x16(const __m512i acc, const __m512i x, const __m512i y) {\n    const __m512i zero = _mm512_setzero_si512();\n    // Get absolute values of x vectors\n    const __m512i ax = _mm512_abs_epi8(x);\n    // Sign the values of the y vectors\n    __mmask64 blt0 = _mm512_movepi8_mask(x);\n    const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);\n    return mul_sum_us8_pairs_acc_int32x16(acc, ax, sy);\n}\n#endif\n\n// add int16_t pairwise and return as 256 bit int vector, then add the accumulator\nstatic inline __m256i sum_i16_pairs_acc_int32x8(const __m256i acc, const __m256i x) {\n    const __m256i ones = _mm256_set1_epi16(1);\n    return _mm256_add_epi32(acc, _mm256_madd_epi16(ones, x));\n}\n\nstatic inline __m256i mul_sum_us8_pairs_acc_int32x8(const __m256i acc, const __m256i ax, const __m256i sy) {\n#if defined(__AVX512VNNI__) && defined(__AVX512VL__)\n    return _mm256_dpbusd_epi32(acc, ax, sy);\n#elif defined(__AVXVNNI__)\n    return _mm256_dpbusd_avx_epi32(acc, ax, sy);\n#else\n    // Perform multiplication and create 16-bit values\n    const __m256i dot = _mm256_maddubs_epi16(ax, sy);\n    return sum_i16_pairs_acc_int32x8(acc, dot);\n#endif\n}\n\n// Integer variant of the function defined in ggml-quants.c\n// multiply int8_t, add results pairwise twice and return as 256 bit int vector, then add the accumulator\nstatic inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m256i x, const __m256i y) {\n#if 
defined(__AVXVNNIINT8__)\n    return _mm256_dpbssd_epi32(acc, x, y);\n#else\n    // Get absolute values of x vectors\n    const __m256i ax = _mm256_sign_epi8(x, x);\n    // Sign the values of the y vectors\n    const __m256i sy = _mm256_sign_epi8(y, x);\n    return mul_sum_us8_pairs_acc_int32x8(acc, ax, sy);\n#endif\n}\n#endif\n\nvoid ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;\n\n#if defined(__AVX2__) || defined(__AVX__)\n    float id[4];\n    __m256 srcv[4][4];\n    __m256 idvec[4];\n\n    for (int i = 0; i < nb; i++) {\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            // Load elements into 4 AVX vectors\n            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 );\n            __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 );\n            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 );\n            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 );\n\n            // Compute max(abs(e)) for the block\n            const __m256 signBit = _mm256_set1_ps( -0.0f );\n            __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );\n            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );\n            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );\n            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );\n\n            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );\n            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );\n            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );\n            const float maxScalar = _mm_cvtss_f32( max4 );\n\n            // Divided by 127.f to mirror results in quantize_row_q8_0\n            const float d = maxScalar  / 127.f;\n            
id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f;\n\n            // Store the scale for the individual block\n            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);\n\n            // Store the values in blocks of eight values - Aim is to use these later for block interleaving\n            srcv[row_iter][0] = v0;\n            srcv[row_iter][1] = v1;\n            srcv[row_iter][2] = v2;\n            srcv[row_iter][3] = v3;\n            idvec[row_iter] = _mm256_set1_ps(id[row_iter]);\n        }\n\n        // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved\n        for (int j = 0; j < 4; j++) {\n            // Apply the multiplier\n            __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]);\n            __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]);\n            __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]);\n            __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]);\n\n            // Round to nearest integer\n            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );\n            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );\n            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );\n            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );\n\n            // Convert floats to integers\n            __m256i i0 = _mm256_cvtps_epi32( v0 );\n            __m256i i1 = _mm256_cvtps_epi32( v1 );\n            __m256i i2 = _mm256_cvtps_epi32( v2 );\n            __m256i i3 = _mm256_cvtps_epi32( v3 );\n\n#if defined(__AVX2__)\n            // Convert int32 to int16\n            i0 = _mm256_packs_epi32( i0, i1 );\n            i2 = _mm256_packs_epi32( i2, i3 );\n            // Convert int16 to int8\n            i0 = _mm256_packs_epi16( i0, i2 );\n\n            //  Permute and store the quantized weights in the required order after the pack instruction\n            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );\n            i0 
= _mm256_permutevar8x32_epi32( i0, perm );\n\n            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);\n#else\n            // Since we don't have in AVX some necessary functions,\n            // we split the registers in half and call AVX2 analogs from SSE\n            __m128i ni0 = _mm256_castsi256_si128( i0 );\n            __m128i ni1 = _mm256_extractf128_si256( i0, 1);\n            __m128i ni2 = _mm256_castsi256_si128( i1 );\n            __m128i ni3 = _mm256_extractf128_si256( i1, 1);\n            __m128i ni4 = _mm256_castsi256_si128( i2 );\n            __m128i ni5 = _mm256_extractf128_si256( i2, 1);\n            __m128i ni6 = _mm256_castsi256_si128( i3 );\n            __m128i ni7 = _mm256_extractf128_si256( i3, 1);\n\n            // Convert int32 to int16\n            ni0 = _mm_packs_epi32( ni0, ni1 );\n            ni2 = _mm_packs_epi32( ni2, ni3 );\n            ni4 = _mm_packs_epi32( ni4, ni5 );\n            ni6 = _mm_packs_epi32( ni6, ni7 );\n            // Convert int16 to int8\n            ni0 = _mm_packs_epi16( ni0, ni2 );\n            ni4 = _mm_packs_epi16( ni4, ni6 );\n            _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0);\n            _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4);\n#endif\n        }\n    }\n\n#else\n    UNUSED(nb);\n    UNUSED(y);\n    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);\n#endif\n}\n\nvoid ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK_K == 256);\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;\n\n#if defined(__AVX2__)\n    float iscale[4];\n    __m256 srcv[4][32];\n    __m256 iscale_vec[4];\n\n    for (int i = 0; i < nb; i++) {\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            // Load elements into 4 AVX vectors\n            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );\n            __m256 v1 = _mm256_loadu_ps( x + row_iter * 
k + i * 256 + 8 );\n            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );\n            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );\n\n            // Compute max(abs(e)) for the block\n            const __m256 signBit = _mm256_set1_ps( -0.0f );\n            __m256 abs0 = _mm256_andnot_ps( signBit, v0 );\n            __m256 abs1 = _mm256_andnot_ps( signBit, v1 );\n            __m256 abs2 = _mm256_andnot_ps( signBit, v2 );\n            __m256 abs3 = _mm256_andnot_ps( signBit, v3 );\n\n            __m256 maxAbs = _mm256_max_ps( abs0, abs1 );\n            maxAbs = _mm256_max_ps( maxAbs, abs2 );\n            maxAbs = _mm256_max_ps( maxAbs, abs3 );\n\n            __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );\n            __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );\n            __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );\n            __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );\n\n            __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));\n\n            srcv[row_iter][0] = v0;\n            srcv[row_iter][1] = v1;\n            srcv[row_iter][2] = v2;\n            srcv[row_iter][3] = v3;\n\n            for (int sb = 1; sb < 8; sb++) {\n                // Temporarily stores absolute quant values\n                __m256 tempAbs = maxAbs;\n\n                // Load elements into 4 AVX vectors\n                __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);\n                __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );\n                __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );\n                __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );\n\n                // Compute max(abs(e)) for the block\n                __m256 abs0 = _mm256_andnot_ps( signBit, v0 );\n                __m256 abs1 = _mm256_andnot_ps( signBit, v1 );\n                __m256 abs2 = 
_mm256_andnot_ps( signBit, v2 );\n                __m256 abs3 = _mm256_andnot_ps( signBit, v3 );\n\n                maxAbs = _mm256_max_ps( maxAbs, abs0 );\n                maxAbs = _mm256_max_ps( maxAbs, abs1 );\n                maxAbs = _mm256_max_ps( maxAbs, abs2 );\n                maxAbs = _mm256_max_ps( maxAbs, abs3 );\n\n                __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );\n                maskAbs = _mm256_and_ps( maskAbs, mask_prev );\n\n                mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );\n                mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );\n                mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );\n                mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );\n\n                __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));\n                maskAbs =  _mm256_or_ps(maskAbs, mask_curr);\n\n                srcv[row_iter][sb * 4] = v0;\n                srcv[row_iter][sb * 4 + 1] = v1;\n                srcv[row_iter][sb * 4 + 2] = v2;\n                srcv[row_iter][sb * 4 + 3] = v3;\n            }\n\n            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );\n            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );\n            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );\n            const float maxScalar = _mm_cvtss_f32( max4 );\n\n            __m256 maxScalarVec = _mm256_set1_ps(maxScalar);\n\n            __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );\n            __m256 finalMask = _mm256_and_ps(maskAbs, mask_next);\n\n            const int mask = _mm256_movemask_ps(finalMask);\n            iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;\n\n            if(mask) {\n                iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;\n            }\n\n            y[i].d[row_iter] = maxScalar ? 
1/iscale[row_iter] : 0;\n            iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);\n        }\n\n        __m256i quants_interleaved[32];\n        for (int j = 0; j < 32; j++) {\n            // Apply the multiplier\n            __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);\n            __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);\n            __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);\n            __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);\n\n            // Round to nearest integer\n            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );\n            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );\n            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );\n            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );\n\n            // Convert floats to integers\n            __m256i i0 = _mm256_cvtps_epi32( v0 );\n            __m256i i1 = _mm256_cvtps_epi32( v1 );\n            __m256i i2 = _mm256_cvtps_epi32( v2 );\n            __m256i i3 = _mm256_cvtps_epi32( v3 );\n\n            // Convert int32 to int16\n            i0 = _mm256_packs_epi32( i0, i1 );\n            i2 = _mm256_packs_epi32( i2, i3 );\n            // Convert int16 to int8\n            i0 = _mm256_packs_epi16( i0, i2 );\n\n            //  Permute and store the quantized weights in the required order after the pack instruction\n            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );\n            i0 = _mm256_permutevar8x32_epi32( i0, perm );\n\n            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);\n            quants_interleaved[j] = i0;\n        }\n\n        // Masks to shuffle the quants of corresponding sub blocks for rearranging quants for vectorized bsums computation\n        __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));\n        shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);\n        __m256i 
shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));\n        shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);\n        __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));\n        shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);\n\n        for (int k = 0; k < 4; k++) {\n            // Quants from four different sub blocks are taken\n            __m256i q0 = quants_interleaved[k * 8 + 0];\n            __m256i q1 = quants_interleaved[k * 8 + 1];\n            __m256i q2 = quants_interleaved[k * 8 + 2];\n            __m256i q3 = quants_interleaved[k * 8 + 3];\n            __m256i q4 = quants_interleaved[k * 8 + 4];\n            __m256i q5 = quants_interleaved[k * 8 + 5];\n            __m256i q6 = quants_interleaved[k * 8 + 6];\n            __m256i q7 = quants_interleaved[k * 8 + 7];\n\n\n            // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time\n            __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);\n            __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);\n            __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);\n            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);\n            __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);\n            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);\n\n            __m256i one = _mm256_set1_epi8(1);\n            __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);\n\n            for (int l = 0; l < 3; l++) {\n                // Quants value shifted to process next two values from each sub block\n                q0 = _mm256_srli_epi64(q0, 16);\n                q2 = 
_mm256_srli_epi64(q2, 16);\n                q4 = _mm256_srli_epi64(q4, 16);\n                q6 = _mm256_srli_epi64(q6, 16);\n\n                sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);\n                sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);\n                sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);\n                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);\n                sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);\n                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);\n\n                bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));\n            }\n\n            // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time\n            __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);\n            __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);\n            __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);\n            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);\n            __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);\n            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);\n\n            __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);\n\n            for (int l = 0; l < 3; l++) {\n                // Quants value shifted to process next two values from each sub block\n                q1 = _mm256_srli_epi64(q1, 16);\n                q3 = _mm256_srli_epi64(q3, 16);\n                q5 = _mm256_srli_epi64(q5, 16);\n                q7 = _mm256_srli_epi64(q7, 16);\n\n                sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);\n                sb_h2_interleaved = _mm256_blend_epi16(q1, 
sb2_h2_shuffled, 34);\n                sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);\n                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);\n                sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);\n                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);\n\n                bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));\n            }\n\n            // Overall bsums in interleaved fashion computed by adding results of both halves\n            __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);\n            _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);\n        }\n    }\n\n#else\n    UNUSED(nb);\n    UNUSED(y);\n    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);\n#endif\n}\n\n//\n// GEMV/GEMM templates\n//\n\n#if defined(__AVX2__) || defined(__AVX512F__)\n\n// GEMV for 8x blocks of 32 4-bit quants with a single scale factor per block\ntemplate<typename block_tx8>\nstatic void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {\n    static_assert(\n            std::is_same_v<block_tx8, block_q4_0x8> ||\n            std::is_same_v<block_tx8, block_iq4_nlx8> ||\n            std::is_same_v<block_tx8, block_mxfp4x8>,\n            \"Unsupported block type\");\n\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    UNUSED(bs);\n\n    // Permute masks used for easier vector processing at later stages\n    __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);\n    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);\n\n    // Mask to mask out nibbles from packed bytes\n    const __m256i m4b = _mm256_set1_epi8(0x0F);\n\n    int64_t b_nb = n / 32;\n\n    const block_tx8  * b_ptr_start = (const block_tx8  *)vx;\n    const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy;\n\n   
 // Process Q8_0 blocks one by one\n    for (int64_t y = 0; y < nr; y++) {\n\n        // Pointers to LHS blocks of block_q8_0 format\n        const block_q8_0 * a_ptr = a_ptr_start + (y * nb);\n\n        // Take group of eight blocks at each pass of the loop and perform dot product operation\n        for (int64_t x = 0; x < nc / 8; x++) {\n\n            // Pointers to RHS blocks\n            const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulator\n            __m256 acc_row = _mm256_setzero_ps();\n\n            for (int64_t b = 0; b < nb; b++) {\n                // Load 8 blocks of 32 interleaved as 8 bytes (B0 - B7)\n                const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));\n                const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1);\n                const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2);\n                const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3);\n\n                // 4-bit -> 8-bit - Sign is maintained\n                const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7)\n                const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7)\n                const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15)\n                const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B4(8-15) B5(8-15) B6(8-15) B7(8-15)\n\n                const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23)\n  
              const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23)\n                const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31)\n                const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31)\n\n                // Load the scale values for the 8 blocks interleaved in block_tx8\n                __m256 col_scale_f32;\n                if constexpr (\n                        std::is_same_v<block_tx8, block_q4_0x8> ||\n                        std::is_same_v<block_tx8, block_iq4_nlx8>) {\n                    col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);\n                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {\n                    // Load 8 E8M0 exponents and convert to float via LUT\n                    // Rearranged to match changemask order: 0,4,1,5,2,6,3,7\n                    col_scale_f32 = _mm256_set_ps(\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));\n                }\n\n                // Load and convert to FP32 scale from block_q8_0\n                const __m256 row_scale_f32 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(a_ptr[b].d));\n\n                // 
Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector\n                __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs));\n                __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16)));\n\n                lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0 (0-15) A0(0-15)\n                lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0 (16-31) A0(16-31))\n\n                __m256i iacc = _mm256_setzero_si256();\n\n                // Dot product done within 32 bit lanes and accumulated in the same vector\n                // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)\n                // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)\n                // ...........................................................................\n                // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)\n\n                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0));\n                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85));\n\n                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170));\n                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255));\n\n                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), 
_mm256_shuffle_epi32(lhs_vec_1, 0));\n                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85));\n\n                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170));\n                iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255));\n\n                // Accumulated values multiplied with appropriate scales\n                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);\n            }\n\n            // Accumulated output values permuted so as to be stored in appropriate order post accumulation\n            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);\n            _mm256_storeu_ps(s + (y * nr + x * 8), acc_row);\n        }\n    }\n}\n\n// GEMM for 8x blocks of 32 4-bit quants with a single scale factor per block\ntemplate<typename block_tx8>\nstatic void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {\n    static_assert(\n            std::is_same_v<block_tx8, block_q4_0x8> ||\n            std::is_same_v<block_tx8, block_iq4_nlx8> ||\n            std::is_same_v<block_tx8, block_mxfp4x8>,\n            \"Unsupported block type\");\n\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    const block_tx8    * b_ptr_start = (const block_tx8    *)vx;\n    const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;\n\n    int64_t b_nb = n / 32;\n    int64_t y = 0;\n    // Mask to mask out nibbles from packed bytes\n    const __m256i m4b = _mm256_set1_epi8(0x0F);\n    const __m128i loadMask = 
_mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);\n    // Permute mask used for easier vector processing at later stages\n    __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);\n    int64_t xstart = 0;\n    int anr = nr - nr%16; // Used to align nr with boundary of 16\n#if defined(__AVX512BW__) && defined(__AVX512DQ__)\n    int anc = nc - nc%16; // Used to align nc with boundary of 16\n                          // Mask to mask out nibbles from packed bytes expanded to 512 bit length\n    const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);\n    // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length\n    __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);\n\n    // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation\n    for (; y < anr / 4; y += 4) {\n\n        const block_q8_0x4 * a_ptrs[4];\n\n        a_ptrs[0] = a_ptr_start + (y * nb);\n        for (int i = 0; i < 3; ++i) {\n            a_ptrs[i + 1] = a_ptrs[i] + nb;\n        }\n\n        // Take group of two block_tx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = 0; x < anc / 8; x += 2) {\n\n            const block_tx8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);\n            const block_tx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);\n\n            // Master FP accumulators\n            __m512 acc_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_rows[i] = _mm512_setzero_ps();\n            }\n\n            for (int64_t b = 0; b < nb; b++) {\n                // Load the sixteen blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF\n                const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));\n                const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i 
*)(b_ptr_0[b].qs + 32));\n                const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));\n                const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));\n\n                const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));\n                const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));\n                const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));\n                const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));\n\n                // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values\n                const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n                const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n\n                const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);\n                const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);\n                const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);\n                const 
__m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);\n\n                const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);\n                const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);\n                const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);\n                const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);\n\n                // 4-bit -> 8-bit - Sign is maintained\n                const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)\n                const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)\n\n                const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)\n                const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)\n\n                const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)\n                const __m512i rhs_mat_2367ABEF_2 = 
_mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)\n\n                const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)\n                const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)\n\n                // Shuffle pattern one - right side input\n                const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)\n                const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)\n\n                const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)\n                const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)\n\n                const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) 
B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)\n                const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)\n\n                const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)\n                const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)\n\n                // Shuffle pattern two - right side input\n\n                const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)\n                const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)\n\n                const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)\n                const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) 
BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)\n\n                const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)\n                const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)\n\n                const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)\n                const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)\n\n                // Scale values - Load the weight scale values of two block_tx8\n                __m512 col_scale_f32;\n                if constexpr (\n                        std::is_same_v<block_tx8, block_q4_0x8> ||\n                        std::is_same_v<block_tx8, block_iq4_nlx8>) {\n                    col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);\n                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {\n                    //TODO: simd-ify\n                    col_scale_f32 = _mm512_set_ps(\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),\n                        
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));\n                }\n\n                // Process LHS in pairs of rows\n                for (int rp = 0; rp < 4; rp++) {\n\n                    // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3\n                    // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector\n                    __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));\n                    __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);\n                    __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);\n                    __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));\n                    __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);\n                    __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);\n                    
__m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));\n                    __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);\n                    __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);\n                    __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));\n                    __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);\n                    __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);\n\n                    __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);\n                    __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);\n                    __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);\n                    __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);\n                    __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);\n                    __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);\n                    __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);\n                    __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);\n\n                    // Shuffle pattern one - left side input\n\n                    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) 
A1(0-3)\n                    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)\n\n                    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)\n                    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)\n\n                    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)\n                    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)\n\n                    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)\n                    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)\n\n                    // Shuffle pattern two - left side input\n\n                    const __m512i lhs_mat_01_0_sp2 = 
_mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)\n                    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)\n\n                    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)\n                    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)\n\n                    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)\n                    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)\n\n                    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)\n                    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) 
A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)\n\n                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                    // Resembles MMLAs into 2x2 matrices in ARM Version\n                    const __m512i zero = _mm512_setzero_epi32();\n                    __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);\n                    __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);\n                    __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);\n                    __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);\n                    __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, 
lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);\n                    __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);\n                    __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);\n                    __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);\n\n                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                    __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);\n                    __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);\n                    __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);\n                    __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);\n\n\n                    // Straighten out to make 4 row vectors\n                    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));\n                    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, 
_mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);\n                    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));\n                    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);\n\n                    // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes\n                    const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);\n                    const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);\n\n                    // Multiply with appropriate scales and accumulate\n                    acc_rows[rp * 4]     = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[rp * 4]);\n                    acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[rp * 4 + 1]);\n                    acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);\n                    acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);\n                }\n            }\n\n            // Store the accumulated values\n            for (int i = 0; i < 16; i++) {\n                _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);\n            }\n        }\n    }\n\n    // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation\n    for (; y < nr / 4; y ++) {\n        const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);\n\n      
  // Take group of two block_tx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = 0; x < anc / 8; x += 2) {\n\n            const block_tx8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);\n            const block_tx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);\n\n            // Master FP accumulators\n            __m512 acc_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_rows[i] = _mm512_setzero_ps();\n            }\n\n            for (int64_t b = 0; b < nb; b++) {\n                // Load the sixteen blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF\n                const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));\n                const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));\n                const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));\n                const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));\n\n                const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));\n                const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));\n                const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));\n                const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));\n\n                // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values\n                const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), 
rhs_raw_mat_4567_0, 240);\n                const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n\n                const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);\n                const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);\n                const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);\n                const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);\n\n                const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);\n                const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);\n                const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);\n                const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);\n\n                // 4-bit -> 8-bit - Sign is maintained\n                const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)\n                const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); 
//B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)\n\n                const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)\n                const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)\n\n                const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)\n                const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)\n\n                const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)\n                const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)\n\n                // Shuffle pattern one - right side input\n                const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)\n                const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) 
B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)\n\n                const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)\n                const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)\n\n                const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)\n                const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)\n\n                const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)\n                const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)\n\n                // Shuffle pattern two - right side input\n\n                const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) 
B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)\n                const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)\n\n                const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)\n                const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)\n\n                const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)\n                const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)\n\n                const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)\n                const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) 
BF(28-31)\n\n\n                // Scale values - Load the weight scale values of two block_tx8\n                __m512 col_scale_f32;\n                if constexpr (\n                        std::is_same_v<block_tx8, block_q4_0x8> ||\n                        std::is_same_v<block_tx8, block_iq4_nlx8>) {\n                    col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);\n                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {\n                    //TODO: simd-ify\n                    col_scale_f32 = _mm512_set_ps(\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));\n                }\n\n                // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3\n                // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector\n                __m256i 
lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));\n                __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);\n                __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);\n                __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));\n                __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);\n                __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);\n                __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));\n                __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);\n                __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);\n                __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));\n                __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);\n                __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);\n\n                __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);\n                __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);\n                __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);\n                __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);\n                __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);\n                __m512i lhs_mat_23_2 = 
_mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);\n                __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);\n                __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);\n\n                // Shuffle pattern one - left side input\n\n                const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)\n                const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)\n\n                const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)\n                const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)\n\n                const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)\n                const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)\n\n                const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 
(_MM_PERM_ENUM)160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)\n                const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)\n\n                // Shuffle pattern two - left side input\n\n                const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)\n                const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)\n\n                const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)\n                const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)\n\n                const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)\n                const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) 
A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)\n\n                const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)\n                const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)\n\n                // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                // Resembles MMLAs into 2x2 matrices in ARM Version\n                const __m512i zero = _mm512_setzero_epi32();\n                __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);\n                __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);\n                __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);\n                __m512i 
iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);\n                __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);\n                __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);\n                __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);\n                __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);\n\n                // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);\n                __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);\n           
     __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);\n                __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);\n\n\n                // Straighten out to make 4 row vectors\n                __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));\n                __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);\n                __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));\n                __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);\n\n                // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes\n                const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);\n                const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);\n\n                // Multiply with appropriate scales and accumulate\n                acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[0]);\n                acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[1]);\n                acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);\n                acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);\n            }\n\n            // Store the accumulated values\n            for (int i = 0; i < 4; i++) {\n                
_mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);\n            }\n        }\n    }\n    if (anc != nc) {\n        xstart = anc/8;\n        y = 0;\n    }\n#endif // __AVX512BW__ && __AVX512DQ__\n\n    // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation\n\n    for (; y < anr / 4; y += 4) {\n        const block_q8_0x4 * a_ptrs[4];\n\n        a_ptrs[0] = a_ptr_start + (y * nb);\n        for (int i = 0; i < 3; ++i) {\n            a_ptrs[i + 1] = a_ptrs[i] + nb;\n        }\n\n        // Take group of eight block_tx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = xstart; x < nc / 8; x++) {\n\n            const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulators\n            __m256 acc_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_rows[i] = _mm256_setzero_ps();\n            }\n\n            for (int64_t b = 0; b < nb; b++) {\n                // Load the eight blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7\n                const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));\n                const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));\n                const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));\n                const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));\n\n                // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values\n                const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, 
requiredOrder), rhs_raw_mat_4567_0, 240);\n                const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n\n                // 4-bit -> 8-bit - Sign is maintained\n                const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)\n                const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)\n\n                const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)\n                const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)\n\n                const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)\n                const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)\n\n                const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)\n                const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)\n\n                // Shuffle pattern one - right side input\n                const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); 
 //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)\n                const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)\n\n                const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)\n                const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)\n\n                const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)\n                const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)\n\n                const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)\n                const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)\n\n                // Shuffle pattern two - right side input\n\n                const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)\n                const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)\n\n                const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)\n                const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) 
B6(12-15) B7(12-15) B6(12-15) B7(12-15)\n\n                const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)\n                const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)\n\n                const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)\n                const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)\n\n                // Scale values - Load the wight scale values of block_tx8\n                __m256 col_scale_f32;\n                if constexpr (\n                        std::is_same_v<block_tx8, block_q4_0x8> ||\n                        std::is_same_v<block_tx8, block_iq4_nlx8>) {\n                    col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);\n                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {\n                    col_scale_f32 = _mm256_set_ps(\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));\n                }\n\n                // Process LHS in groups of four\n                for (int rp = 0; rp < 4; rp++) {\n                    // Load the four blocks of quantized values interleaved with each 
other in chunks of eight - A0,A1,A2,A3\n                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector\n                    __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));\n                    __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);\n                    __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);\n                    __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));\n                    __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);\n                    __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);\n                    __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));\n                    __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);\n                    __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);\n                    __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));\n                    __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);\n                    __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);\n\n                    // Shuffle pattern one - left side input\n                    const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)\n                    const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)\n\n                    const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)\n                    
const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)\n\n                    const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)\n                    const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)\n\n                    const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)\n                    const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)\n\n                    // Shuffle pattern two - left side input\n                    const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)\n                    const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)\n\n                    const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)\n                    const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)\n\n                    const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)\n                    const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)\n\n                    const 
__m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)\n                    const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)\n\n                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                    // Resembles MMLAs into 2x2 matrices in ARM Version\n                    const __m256i zero = _mm256_setzero_si256();\n                    __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);\n                    __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);\n                    __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);\n                    __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);\n                 
   __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);\n                    __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);\n                    __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);\n                    __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);\n\n                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                    __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);\n                    __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);\n                    __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);\n                    __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);\n\n                    // Straighten out to make 4 row vectors\n                    __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);\n                    
__m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);\n                    __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);\n                    __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);\n\n                    // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes\n                    const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);\n\n                    // Multiply with appropriate scales and accumulate\n                    acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);\n                    acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);\n                    acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);\n                    acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32,  255)), acc_rows[rp * 4 + 3]);\n                }\n            }\n\n            // Store the accumulated values\n            for (int i = 0; i < 16; i++) {\n                _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);\n            }\n        }\n    }\n\n    // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation\n    for (; y < nr / 4; y ++) {\n        const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);\n\n        // Load the eight blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7\n        for (int64_t x = xstart; 
x < nc / 8; x++) {\n            const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulators\n            __m256 acc_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_rows[i] = _mm256_setzero_ps();\n            }\n\n            for (int64_t b = 0; b < nb; b++) {\n                // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7\n                const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));\n                const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));\n                const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));\n                const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));\n\n                // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values\n                const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n                const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n\n                // 4-bit -> 8-bit - Sign is maintained\n                const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b));  //B0(0-7) B1(0-7) B4(0-7) B5(0-7)\n                const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, 
_mm256_and_si256(rhs_raw_mat_2367_0, m4b));  //B2(0-7) B3(0-7) B6(0-7) B7(0-7)\n\n                const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b));  //B0(8-15) B1(8-15) B4(8-15) B5(8-15)\n                const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b));  //B2(8-15) B3(8-15) B6(8-15) B7(8-15)\n\n                const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b));  //B0(16-23) B1(16-23) B4(16-23) B5(16-23)\n                const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b));  //B2(16-23) B3(16-23) B6(16-23) B7(16-23)\n\n                const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b));  //B0(24-31) B1(24-31) B4(24-31) B5(24-31)\n                const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b));  //B2(24-31) B3(24-31) B6(24-31) B7(24-31)\n\n                // Shuffle pattern one - right side input\n                const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)\n                const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)\n\n                const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)\n                const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)\n\n                const __m256i rhs_mat_0145_2_sp1 = 
_mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)\n                const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)\n\n                const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)\n                const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)\n\n                // Shuffle pattern two - right side input\n\n                const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)\n                const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)\n\n                const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)\n                const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)\n\n                const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)\n                const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)\n\n                const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)\n                const __m256i rhs_mat_2367_3_sp2 = 
_mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)\n\n                // Scale values - Load the wight scale values of block_tx8\n                __m256 col_scale_f32;\n                if constexpr (\n                        std::is_same_v<block_tx8, block_q4_0x8> ||\n                        std::is_same_v<block_tx8, block_iq4_nlx8>) {\n                    col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);\n                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {\n                    col_scale_f32 = _mm256_set_ps(\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),\n                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));\n                }\n\n                // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3\n                // Loaded as set of 128 bit vectors and repeated into a 256 bit vector\n                __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));\n                __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);\n                __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);\n                __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));\n                __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);\n                __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);\n  
              __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));\n                __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);\n                __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);\n                __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));\n                __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);\n                __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);\n\n                // Shuffle pattern one - left side input\n\n                const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)\n                const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)\n\n                const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)\n                const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)\n\n                const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)\n                const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)\n\n                const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)\n                const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) 
A2(24-27) A3(24-27) A3(24-27)\n\n                // Shuffle pattern two - left side input\n\n                const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)\n                const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)\n\n                const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)\n                const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)\n\n                const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)\n                const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)\n\n                const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)\n                const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)\n\n                // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                // Resembles MMLAs into 2x2 matrices in ARM Version\n                const __m256i zero = _mm256_setzero_si256();\n                __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, 
rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);\n                __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);\n                __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);\n                __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);\n                __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);\n                __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);\n                __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), 
lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);\n                __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);\n\n                // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);\n                __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);\n                __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);\n                __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);\n\n\n                // Straighten out to make 4 row vectors\n                __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);\n                __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);\n                __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);\n                __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);\n\n                // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes\n                const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);\n\n                // Multiply with appropriate scales and accumulate\n                acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);\n                acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, 
row_scale_f32, 85)), acc_rows[1]);\n                acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);\n                acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);\n            }\n\n            // Store the accumulated values\n            for (int i = 0; i < 4; i++) {\n                _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);\n            }\n        }\n    }\n}\n\n#endif // defined(__AVX2__) || defined(__AVX512F__)\n\nvoid ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n#if defined(__AVX2__) || defined(__AVX512F__)\n    {\n        // Lookup table to convert signed nibbles to signed bytes\n        __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));\n        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);\n\n        gemv_q4_b32_8x8_q8_0_lut_avx<block_q4_0x8>(n, s, bs, vx, vy, nr, nc, signextendlut);\n\n        return;\n    }\n#endif\n\n    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    
UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__AVX2__)\n    // Lookup table to convert signed nibbles to signed bytes\n    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));\n    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);\n    // Shuffle masks to rearrange delta and scale values to multiply with appropriate scales\n    __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);\n    __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);\n    // Permute mask used for easier vector processing at later stages\n    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);\n\n    // Mask to extract nibbles from bytes\n    const __m256i m4b = _mm256_set1_epi8(0x0F);\n\n    int64_t b_nb = n / QK_K;\n\n    const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;\n    const block_q8_K * a_ptr_start = (const block_q8_K *)vy;\n\n    // Process Q8_K blocks one by one\n    for (int64_t y = 0; y < nr; y++) {\n\n        // Pointers to LHS blocks of block_q8_K format\n        const block_q8_K * a_ptr = a_ptr_start + (y * nb);\n\n        // Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation\n        for (int64_t x = 0; x < nc / 8; x++) {\n\n            // Pointers to RHS blocks\n            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulators\n            __m256 acc_row = _mm256_setzero_ps();\n            __m256 acc_min_rows = _mm256_setzero_ps();\n\n            for (int64_t b = 0; b < nb; b++) {\n\n                // Load and convert to FP32 scale from block_q8_K\n                const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));\n\n                // Load the scale values for the 8 blocks interleaved in block_q4_Kx8\n                // col_scale_f32 rearranged so as to multiply with 
appropriate quants\n                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);\n                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);\n\n                __m256i iacc_b = _mm256_setzero_si256();\n                __m256i iacc_min_b = _mm256_setzero_si256();\n\n                const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));\n                __m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));\n                q8s = _mm256_permute2f128_si256(q8s, q8s, 0);\n\n                // Processes two sub blocks from each Q4_K in each iteration\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n\n                    // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7\n                    const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));\n                    const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));\n\n                    // 4-bit -> 8-bit\n                    // Values of the first sub 
block of eight block_q4_K structures for the sb loop\n                    const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);\n                    const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);\n                    const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);\n                    const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);\n                    const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);\n                    const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);\n                    const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);\n                    const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);\n\n                    // Values of the second sub block of eight block_q4_K structures when sb = 1\n                    const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);\n                    const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);\n                    const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);\n                    const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);\n                    const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);\n                    const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);\n                    const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);\n                    const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);\n\n                    uint32_t utmp_0[4], utmp_1[4];\n\n                    // Scales and Mins of corresponding sub blocks from 
different Q8_K structures are stored together\n                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);\n                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_0 = utmp_0[1] & kmask1;\n                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);\n                    utmp_0[2] = uaux_0;\n                    utmp_0[0] &= kmask1;\n\n                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);\n                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_1 = utmp_1[1] & kmask1;\n                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);\n                    utmp_1[2] = uaux_1;\n                    utmp_1[0] &= kmask1;\n\n                    // Scales of first sub block in the sb loop\n                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);\n                    __m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);\n                    __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);\n\n                    // Scales of second sub block in the sb loop\n                    __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);\n                    __m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);\n                    __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);\n\n                    // Mins of first and second sub block of Q4_K block are arranged side by side\n                    __m256i 
mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));\n\n                    // Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector\n                    __m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));\n                    __m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));\n                    __m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));\n                    __m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));\n\n                    lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);\n                    lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);\n                    lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);\n                    lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);\n\n                    // Dot product done within 32 bit lanes and accumulated in the same vector\n                    // First done for first sub block and then for second sub block in each sb\n                    // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)\n                    // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)\n                    // ...........................................................................\n                    // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)\n\n\n                    __m256i iacc_0 = _mm256_setzero_si256();\n                    __m256i iacc_1 = _mm256_setzero_si256();\n\n                    iacc_0 = _mm256_add_epi16(iacc_0, 
_mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));\n\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));\n\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));\n\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));\n\n                    iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);\n\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 
170), _mm256_shuffle_epi32(lhs_vec_10, 85)));\n\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));\n\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));\n\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));\n\n                    iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);\n\n                    // Accumulate the iacc value for one sb\n                    __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);\n\n                    // Broadcast the bsums of the two sub blocks  of the iteration of Q8_K across the vector\n                    // Multiply-Add with corresponding mins of Q4_Kx8 with bsums\n                    __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);\n                    __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);\n                    q8s = _mm256_bsrli_epi128(q8s, 4);\n\n                    // Accumulate for the complete block\n                    iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);\n     
               iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);\n                }\n\n                // Multiply-Add with scale values for the complete super block\n                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);\n                acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);\n\n            }\n\n            // Accumulated output values permuted so as to be stored in appropriate order post accumulation\n            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);\n            _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));\n        }\n    }\n\n#else\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n#endif\n}\n\nvoid ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n#if defined(__AVX2__)\n    __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_iq4nl));\n    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);\n\n    gemv_q4_b32_8x8_q8_0_lut_avx<block_iq4_nlx8>(n, s, bs, vx, vy, nr, nc, signextendlut);\n\n    return;\n#endif\n\n    ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n#if defined(__AVX2__)\n    __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));\n    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);\n\n    gemv_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut);\n\n    return;\n#endif\n\n    ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, 
nc);\n}\n\nvoid ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__AVX2__)\n    // Lookup table to convert signed nibbles to signed bytes\n    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));\n    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);\n    // Shuffle masks to rearrange delta values to multiply with appropriate scales\n    __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);\n    // Permute mask used for easier vector processing at later stages\n    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);\n\n    const __m256i m3b = _mm256_set1_epi8(3);\n    const __m128i m4b_sse = _mm_set1_epi8(0xF);\n\n    //Mask to get appropriate scales\n    __m128i scalemask1 = _mm_set_epi8(14,14,6,6,12,12,4,4,10,10,2,2,8,8,0,0);\n    __m128i scalemask2 = _mm_set_epi8(15,15,7,7,13,13,5,5,11,11,3,3,9,9,1,1);\n\n    int64_t b_nb = n / QK_K;\n\n    const block_q2_Kx8 * b_ptr_start = (const block_q2_Kx8 *)vx;\n    const block_q8_K * a_ptr_start = (const block_q8_K *)vy;\n\n    // Process Q8_K blocks one by one\n    for (int64_t y = 0; y < nr; y++) {\n\n        // Pointers to LHS blocks of block_q8_K format\n        const block_q8_K * a_ptr = a_ptr_start + (y * nb);\n\n        // Take group of eight interleaved block_q2_K structures at each pass of the loop and perform dot product operation\n        for(int64_t x = 0; x < nc / 8; x++) {\n\n            // 
Pointers to RHS blocks\n            const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulators\n            __m256 acc_row = _mm256_setzero_ps();\n            __m256 acc_min_rows = _mm256_setzero_ps();\n\n            for (int64_t b = 0; b < nb; b++) {\n\n                // Load and convert to FP32 delta from block_q8_K\n                const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));\n\n                // Load the delta values for the 8 blocks interleaved in block_q2_Kx8\n                // col_scale_f32 rearranged so as to multiply with appropriate quants\n                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);\n                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);\n\n                __m256i iacc_b = _mm256_setzero_si256();\n                __m256i iacc_min_b = _mm256_setzero_si256();\n\n                // Processes eight sub blocks from each Q2_K in each iteration\n                for(int sb = 0; sb < QK_K / 128; sb++) {\n\n                    // Load the eight block_q2_K for eight sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7\n                    const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));\n                    const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));\n                    const __m256i 
rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));\n\n                    // 2-bit -> 8-bit\n                    // Values of the 0th,2nd,4th,6th sub blocks of eight block_q2_K structures for the sb loop\n                    const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m3b); //B00(0-7) B01(0-7) B02(0-7) B03(0-7)\n                    const __m256i rhs_vec_0123_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 2), m3b); //B20(0-7) B21(0-7) B22(0-7) B23(0-7)\n                    const __m256i rhs_vec_0123_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m3b); //B40(0-7) B41(0-7) B42(0-7) B43(0-7)\n                    const __m256i rhs_vec_0123_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 6), m3b); //B60(0-7) B61(0-7) B62(0-7) B63(0-7)\n\n                    const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m3b); //B04(0-7) B05(0-7) B06(0-7) B07(0-7)\n                    const __m256i rhs_vec_4567_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 2), m3b); //B24(0-7) B25(0-7) B26(0-7) B27(0-7)\n                    const __m256i rhs_vec_4567_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m3b); //B44(0-7) B45(0-7) B46(0-7) B47(0-7)\n                    const __m256i rhs_vec_4567_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 6), m3b); //B64(0-7) B65(0-7) B66(0-7) B67(0-7)\n\n                    const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m3b); //B00(8-15) B01(8-15) B02(8-15) B03(8-15)\n                    const __m256i rhs_vec_0123_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 2), m3b); //B20(8-15) B21(8-15) B22(8-15) B23(8-15)\n                    const __m256i rhs_vec_0123_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m3b); 
//B40(8-15) B41(8-15) B42(8-15) B43(8-15)\n                    const __m256i rhs_vec_0123_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 6), m3b); //B60(8-15) B61(8-15) B62(8-15) B63(8-15)\n\n                    const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m3b); //B04(8-15) B05(8-15) B06(8-15) B07(8-15)\n                    const __m256i rhs_vec_4567_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 2), m3b); //B24(8-15) B25(8-15) B26(8-15) B27(8-15)\n                    const __m256i rhs_vec_4567_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m3b); //B44(8-15) B45(8-15) B46(8-15) B47(8-15)\n                    const __m256i rhs_vec_4567_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 6), m3b); //B64(8-15) B65(8-15) B66(8-15) B67(8-15)\n\n                    // Values of the 1st,3rd,5th,7th sub blocks of eight block_q2_K structures for the sb loop\n                    const __m256i rhs_vec_0123_10 = _mm256_and_si256(rhs_raw_vec_0123_2, m3b); //B10(0-7) B11(0-7) B12(0-7) B13(0-7)\n                    const __m256i rhs_vec_0123_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 2), m3b); //B30(0-7) B31(0-7) B32(0-7) B33(0-7)\n                    const __m256i rhs_vec_0123_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m3b); //B50(0-7) B51(0-7) B52(0-7) B53(0-7)\n                    const __m256i rhs_vec_0123_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 6), m3b); //B70(0-7) B71(0-7) B72(0-7) B73(0-7)\n\n                    const __m256i rhs_vec_4567_10 = _mm256_and_si256(rhs_raw_vec_4567_2, m3b); //B14(0-7) B15(0-7) B16(0-7) B17(0-7)\n                    const __m256i rhs_vec_4567_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 2), m3b); //B34(0-7) B35(0-7) B36(0-7) B37(0-7)\n                    const __m256i rhs_vec_4567_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m3b); //B54(0-7) B55(0-7) B56(0-7) B57(0-7)\n             
       const __m256i rhs_vec_4567_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 6), m3b); //B74(0-7) B75(0-7) B76(0-7) B77(0-7)\n\n                    const __m256i rhs_vec_0123_11 = _mm256_and_si256(rhs_raw_vec_0123_3, m3b); //B10(8-15) B11(8-15) B12(8-15) B13(8-15)\n                    const __m256i rhs_vec_0123_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 2), m3b); //B30(8-15) B31(8-15) B32(8-15) B33(8-15)\n                    const __m256i rhs_vec_0123_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m3b); //B50(8-15) B51(8-15) B52(8-15) B53(8-15)\n                    const __m256i rhs_vec_0123_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 6), m3b); //B70(8-15) B71(8-15) B72(8-15) B73(8-15)\n\n                    const __m256i rhs_vec_4567_11 = _mm256_and_si256(rhs_raw_vec_4567_3, m3b); //B14(8-15) B15(8-15) B16(8-15) B17(8-15)\n                    const __m256i rhs_vec_4567_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 2), m3b); //B34(8-15) B35(8-15) B36(8-15) B37(8-15)\n                    const __m256i rhs_vec_4567_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m3b); //B54(8-15) B55(8-15) B56(8-15) B57(8-15)\n                    const __m256i rhs_vec_4567_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 6), m3b); //B74(8-15) B75(8-15) B76(8-15) B77(8-15)\n\n                    //Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together\n                    //s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71\n\n                    const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64));\n                    const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64));\n                    const __m128i mins_and_scales_45 = _mm_loadu_si128((const 
__m128i *)(b_ptr[b].scales + 32 + sb * 64));\n                    const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64));\n\n                    // Extract scales which is lower half from mins_and_scales\n                    const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse);\n                    const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse);\n                    const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse);\n                    const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse);\n\n                    // Extract mins which is upper half from mins_and_scales\n                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse));\n                    const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse));\n                    const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse));\n                    const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse));\n\n                    // Scales of sub blocks in the sb loop\n                    // Scales of the 0th sub block from each super block\n                    __m128i scales_rearrange_0 = _mm_shuffle_epi8(scales_01, scalemask1);\n                    __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);\n\n                    // Scales of the 1st sub block from each super block\n                    __m128i scales_rearrange_1 = _mm_shuffle_epi8(scales_01, scalemask2);\n                    __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);\n\n                    // Scales of the 2nd sub block from each super block\n                    __m128i scales_rearrange_2 = _mm_shuffle_epi8(scales_23, scalemask1);\n                    __m256i scales_2 = 
_mm256_cvtepu8_epi16(scales_rearrange_2);\n\n                    // Scales of the 3rd sub block from each super block\n                    __m128i scales_rearrange_3 = _mm_shuffle_epi8(scales_23, scalemask2);\n                    __m256i scales_3 = _mm256_cvtepu8_epi16(scales_rearrange_3);\n\n                    // Scales of the 4th sub block from each super block\n                    __m128i scales_rearrange_4 = _mm_shuffle_epi8(scales_45, scalemask1);\n                    __m256i scales_4 = _mm256_cvtepu8_epi16(scales_rearrange_4);\n\n                    // Scales of the 5th sub block from each super block\n                    __m128i scales_rearrange_5 = _mm_shuffle_epi8(scales_45, scalemask2);\n                    __m256i scales_5 = _mm256_cvtepu8_epi16(scales_rearrange_5);\n\n                    // Scales of the 6th sub block from each super block\n                    __m128i scales_rearrange_6 = _mm_shuffle_epi8(scales_67, scalemask1);\n                    __m256i scales_6 = _mm256_cvtepu8_epi16(scales_rearrange_6);\n\n                    // Scales of the 7th sub block from each super block\n                    __m128i scales_rearrange_7 = _mm_shuffle_epi8(scales_67, scalemask2);\n                    __m256i scales_7 = _mm256_cvtepu8_epi16(scales_rearrange_7);\n\n                    // Load the sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector\n                    __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 128)));\n                    __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 128)));\n                    __m256i lhs_vec_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 128)));\n                    __m256i lhs_vec_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 128)));\n                    __m256i lhs_vec_4 = 
_mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64 + sb * 128)));\n                    __m256i lhs_vec_5 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80 + sb * 128)));\n                    __m256i lhs_vec_6 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96 + sb * 128)));\n                    __m256i lhs_vec_7 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112 + sb * 128)));\n\n                    lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0);\n                    lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0);\n                    lhs_vec_2 = _mm256_permute2f128_si256(lhs_vec_2, lhs_vec_2, 0);\n                    lhs_vec_3 = _mm256_permute2f128_si256(lhs_vec_3, lhs_vec_3, 0);\n                    lhs_vec_4 = _mm256_permute2f128_si256(lhs_vec_4, lhs_vec_4, 0);\n                    lhs_vec_5 = _mm256_permute2f128_si256(lhs_vec_5, lhs_vec_5, 0);\n                    lhs_vec_6 = _mm256_permute2f128_si256(lhs_vec_6, lhs_vec_6, 0);\n                    lhs_vec_7 = _mm256_permute2f128_si256(lhs_vec_7, lhs_vec_7, 0);\n\n                    __m256i iacc_0 = _mm256_setzero_si256();\n                    __m256i iacc_1 = _mm256_setzero_si256();\n                    __m256i iacc_2 = _mm256_setzero_si256();\n                    __m256i iacc_3 = _mm256_setzero_si256();\n                    __m256i iacc_4 = _mm256_setzero_si256();\n                    __m256i iacc_5 = _mm256_setzero_si256();\n                    __m256i iacc_6 = _mm256_setzero_si256();\n                    __m256i iacc_7 = _mm256_setzero_si256();\n\n                    // Dot product done within 32 bit lanes and accumulated in the same vector\n                    // First done for 0th sub block and then for seven (1st - 7th) other sub blocks processed for each sb (sb < QK_K/128 loop)                    // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with 
A0(0-3)\n                    // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)\n                    // B0(8-11) B4(8-11) B1(8-11) B5(8-11) B2(8-11) B6(8-11) B3(8-11) B7(8-11) with A0(8-11)\n                    // B0(12-15) B4(12-15) B1(12-15) B5(12-15) B2(12-15) B6(12-15) B3(12-15) B7(12-15) with A0(12-15)\n\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));\n\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));\n                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));\n\n                    iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);\n\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));\n\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));\n                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), 
_mm256_shuffle_epi32(lhs_vec_1, 255)));\n\n                    iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);\n\n                    iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_20 ,_mm256_shuffle_epi32(rhs_vec_4567_20, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 0)));\n                    iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_20, 177) ,rhs_vec_4567_20, 170), _mm256_shuffle_epi32(lhs_vec_2, 85)));\n\n                    iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_21 ,_mm256_shuffle_epi32(rhs_vec_4567_21, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 170)));\n                    iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_21, 177) ,rhs_vec_4567_21, 170), _mm256_shuffle_epi32(lhs_vec_2, 255)));\n\n                    iacc_2 = _mm256_madd_epi16(iacc_2, scales_2);\n\n                    iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_30 ,_mm256_shuffle_epi32(rhs_vec_4567_30, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 0)));\n                    iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_30, 177) ,rhs_vec_4567_30, 170), _mm256_shuffle_epi32(lhs_vec_3, 85)));\n\n                    iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_31 ,_mm256_shuffle_epi32(rhs_vec_4567_31, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 170)));\n                    iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_31, 177) ,rhs_vec_4567_31, 170), _mm256_shuffle_epi32(lhs_vec_3, 255)));\n\n                    iacc_3 = _mm256_madd_epi16(iacc_3, scales_3);\n\n                    iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_40 
,_mm256_shuffle_epi32(rhs_vec_4567_40, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 0)));\n                    iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_40, 177) ,rhs_vec_4567_40, 170), _mm256_shuffle_epi32(lhs_vec_4, 85)));\n\n                    iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_41 ,_mm256_shuffle_epi32(rhs_vec_4567_41, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 170)));\n                    iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_41, 177) ,rhs_vec_4567_41, 170), _mm256_shuffle_epi32(lhs_vec_4, 255)));\n\n                    iacc_4 = _mm256_madd_epi16(iacc_4, scales_4);\n\n                    iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_50 ,_mm256_shuffle_epi32(rhs_vec_4567_50, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 0)));\n                    iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_50, 177) ,rhs_vec_4567_50, 170), _mm256_shuffle_epi32(lhs_vec_5, 85)));\n\n                    iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_51 ,_mm256_shuffle_epi32(rhs_vec_4567_51, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 170)));\n                    iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_51, 177) ,rhs_vec_4567_51, 170), _mm256_shuffle_epi32(lhs_vec_5, 255)));\n\n                    iacc_5 = _mm256_madd_epi16(iacc_5, scales_5);\n\n                    iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_60 ,_mm256_shuffle_epi32(rhs_vec_4567_60, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 0)));\n                    iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_60, 177) 
,rhs_vec_4567_60, 170), _mm256_shuffle_epi32(lhs_vec_6, 85)));\n\n                    iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_61 ,_mm256_shuffle_epi32(rhs_vec_4567_61, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 170)));\n                    iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_61, 177) ,rhs_vec_4567_61, 170), _mm256_shuffle_epi32(lhs_vec_6, 255)));\n\n                    iacc_6 = _mm256_madd_epi16(iacc_6, scales_6);\n\n                    iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_70 ,_mm256_shuffle_epi32(rhs_vec_4567_70, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 0)));\n                    iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_70, 177) ,rhs_vec_4567_70, 170), _mm256_shuffle_epi32(lhs_vec_7, 85)));\n\n                    iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_71 ,_mm256_shuffle_epi32(rhs_vec_4567_71, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 170)));\n                    iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_71, 177) ,rhs_vec_4567_71, 170), _mm256_shuffle_epi32(lhs_vec_7, 255)));\n\n                    iacc_7 = _mm256_madd_epi16(iacc_7, scales_7);\n\n                    // Accumulate the iacc value for one sb\n                    __m256i iacc_sb = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_0, iacc_1), _mm256_add_epi32(iacc_2, iacc_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_4, iacc_5), _mm256_add_epi32(iacc_6, iacc_7)));\n\n                    __m128i q8sums = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + sb * 8));\n                    __m256i q8s = _mm256_castsi128_si256(q8sums);\n                    q8s= _mm256_permute2f128_si256(q8s, q8s, 0);\n\n                    // Broadcast the bsums of the two 
corresponding subblocks of q8_k\n                    // Multiply-Add with corresponding mins of Q2_Kx8 with bsums\n                    __m256i iacc_min_sb_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 0), mins_01);\n                    __m256i iacc_min_sb_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 85), mins_23);\n                    __m256i iacc_min_sb_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 170), mins_45);\n                    __m256i iacc_min_sb_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 255), mins_67);\n\n                    __m256i iacc_min_sb = _mm256_add_epi32(_mm256_add_epi32(iacc_min_sb_01, iacc_min_sb_23), _mm256_add_epi32(iacc_min_sb_45,iacc_min_sb_67));\n\n                    // Accumulate for the complete block\n                    iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);\n                    iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);\n                }\n\n                //Multiply-Add with scale values for complete super block\n                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);\n                acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);\n            }\n            // Accumulated output values permuted so as to be stored in appropriate order post accumulation\n            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);\n            _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));\n        }\n    }\n#else\n\n    ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n\n#endif\n}\n\nvoid ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n#if defined(__AVX2__) || defined(__AVX512F__)\n    {\n        // Lookup table to convert signed nibbles to signed bytes\n        __m256i signextendlut = 
_mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));\n        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);\n\n        gemm_q4_b32_8x8_q8_0_lut_avx<block_q4_0x8>(n, s, bs, vx, vy, nr, nc, signextendlut);\n\n        return;\n    }\n#endif // defined(__AVX2__) || defined(__AVX512F__)\n\n    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__AVX2__) || defined(__AVX512F__)\n    const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;\n    const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;\n    int64_t b_nb = n / QK_K;\n    int64_t y = 0;\n\n    // Mask to mask out nibbles from packed bytes\n    const __m256i m4b = _mm256_set1_epi8(0x0F);\n    // Permute mask used for easier vector processing at later stages\n    __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);\n    int64_t xstart = 0;\n    int anr = nr - nr % 16;; // Used to align nr with boundary of 16\n#if defined(__AVX512BW__) && defined(__AVX512DQ__)\n    int anc = nc - nc % 16; // Used to align nc with boundary of 16\n    // Mask to mask out nibbles from packed bytes expanded to 512 bit length\n    const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);\n    //Take group of four block_q8_Kx4 structures at 
each pass of the loop and perform dot product operation\n    for (; y < anr / 4; y += 4) {\n\n        const block_q8_Kx4 * a_ptrs[4];\n\n        a_ptrs[0] = a_ptr_start + (y * nb);\n        for (int i = 0; i < 3; ++i) {\n            a_ptrs[i + 1] = a_ptrs[i] + nb;\n        }\n\n        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = 0; x < anc / 8; x += 2) {\n\n            const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);\n            const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);\n\n            // Master FP accumulators\n            __m512 acc_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_rows[i] = _mm512_setzero_ps();\n            }\n\n            __m512 acc_min_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_min_rows[i] = _mm512_setzero_ps();\n            }\n\n            // For super block\n            for (int64_t b = 0; b < nb; b++) {\n                // Scale values - Load the sixteen scale values from two block_q4_kx8 structures\n                const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);\n\n                // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures\n                const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);\n\n                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n\n                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));\n                    const 
__m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));\n\n                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));\n\n                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n                    const __m256i 
rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);\n                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);\n\n                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);\n                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);\n                    const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);\n                    const __m256i 
rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);\n\n                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);\n                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);\n\n                    const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);\n                    const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);\n\n                    //4-bit -> 8-bit\n                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)\n                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)\n                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); 
//B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)\n                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)\n\n                    const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)\n                    const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)\n                    const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)\n                    const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)\n\n                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)\n                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)\n                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)\n                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)\n\n                    const __m512i rhs_mat_014589CD_12 = 
_mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)\n                    const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)\n                    const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)\n                    const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)\n\n                    // Shuffle pattern one - right side input\n                    const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)\n                    const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)\n                    const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)\n                    const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) 
B0E(8-11) B0F(8-11)\n                    const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)\n                    const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)\n                    const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)\n                    const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)\n\n                    const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)\n                    const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)\n                    const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) 
B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)\n                    const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)\n                    const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)\n                    const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)\n                    const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)\n                    const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)\n\n                    // Shuffle pattern two - right side input\n                    const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)\n                    const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, 
(_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)\n                    const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)\n                    const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)\n                    const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)\n                    const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)\n                    const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) B0D(28-31)\n                    const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)\n\n                  
  const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)\n                    const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)\n                    const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)\n                    const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)\n                    const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)\n                    const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)\n                    const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) 
B1D(28-31) B1C(28-31) B1D(28-31)\n                    const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)\n\n                    uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];\n\n                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together\n                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);\n                    utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_00 = utmp_00[1] & kmask1;\n                    utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);\n                    utmp_00[2] = uaux_00;\n                    utmp_00[0] &= kmask1;\n\n                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);\n                    utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_01 = utmp_01[1] & kmask1;\n                    utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);\n                    utmp_01[2] = uaux_01;\n                    utmp_01[0] &= kmask1;\n\n                    memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);\n                    utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_10 = utmp_10[1] & kmask1;\n                    utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) 
<< 4);\n                    utmp_10[2] = uaux_10;\n                    utmp_10[0] &= kmask1;\n\n                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);\n                    utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_11 = utmp_11[1] & kmask1;\n                    utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);\n                    utmp_11[2] = uaux_11;\n                    utmp_11[0] &= kmask1;\n\n                    // Scales of first sub block in the sb loop\n                    const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);\n                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));\n\n                    // Scales of second sub block in the sb loop\n                    const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);\n                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));\n\n                    // Mins of first and second sub block of Q4_K block are arranged side by side\n                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));\n\n                    const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);\n      
              const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);\n\n                    for (int rp = 0; rp < 4; rp++) {\n\n                        // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3\n                        // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector\n                        __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));\n                        __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);\n                        __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);\n                        __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));\n                        __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);\n                        __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);\n                        __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));\n                        __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);\n                        __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);\n                        __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));\n                        __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);\n                        __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);\n                        __m256i 
lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));\n                        __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);\n                        __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);\n                        __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));\n                        __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);\n                        __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);\n                        __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));\n                        __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);\n                        __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);\n                        __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));\n                        __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);\n                        __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);\n\n                        __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);\n                        __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);\n                        __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);\n                        __m512i lhs_mat_23_01 = 
_mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);\n                        __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);\n                        __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);\n                        __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);\n                        __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);\n\n                        __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);\n                        __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);\n                        __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);\n                        __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);\n                        __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);\n                        __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);\n                        __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);\n                        __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);\n\n                        // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks\n                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));\n                        __m256i 
lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));\n                        lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);\n                        __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);\n\n                        // Shuffle pattern one - left side input\n                        const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)\n                        const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)\n                        const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)\n                        const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)\n                        const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)\n                        const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, 
(_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)\n                        const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)\n                        const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)\n\n                        const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)\n                        const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)\n                        const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)\n                        const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)\n                        const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) 
A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)\n                        const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)\n                        const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)\n                        const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)\n\n                        const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)\n                        const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)\n                        const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)\n                        const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) 
A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)\n                        const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)\n                        const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)\n                        const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)\n                        const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)\n\n                        const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)\n                        const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)\n                        const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) 
A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)\n                        const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)\n                        const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)\n                        const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)\n                        const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)\n                        const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)\n\n                        // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                        __m512i iacc_mat_00_0_sp1 = 
_mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));\n                        __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));\n                        __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));\n                        __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));\n                        __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));\n                        __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), 
_mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));\n                        __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));\n                        __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));\n\n                        __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));\n                        __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));\n                        __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), 
_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));\n                        __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));\n                        __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));\n                        __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));\n                        __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));\n                        __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));\n\n                        // Output of both shuffle patterns are added in 
order to sum dot product outputs of all 32 values in block\n                        __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);\n                        __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);\n                        __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);\n                        __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);\n\n                        __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);\n                        __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);\n                        __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);\n                        __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);\n\n                        iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);\n                        iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);\n                        iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);\n                        iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);\n\n                        iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);\n                        iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);\n                        iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);\n                        iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);\n\n                        // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)\n                        __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));\n                        __m512i iacc_row_1_0 = 
_mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);\n                        __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));\n                        __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);\n                        __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));\n                        __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);\n                        __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));\n                        __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);\n\n                        __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);\n                        __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);\n                        __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);\n                        __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);\n\n                        // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes\n                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);\n                        const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);\n                        const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);\n\n                        // Multiply with appropriate scales and accumulate (for both d and dmin) below\n                        acc_rows[rp * 4] = 
_mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);\n                        acc_rows[rp * 4  + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);\n                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);\n                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);\n\n                        __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);\n                        __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);\n                        __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);\n                        __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);\n\n                        acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);\n                        acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);\n                        acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);\n                        
acc_min_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);\n                    }\n                }\n            }\n            // Store the accumulated values\n            for (int i = 0; i < 16; i++) {\n                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));\n            }\n        }\n    }\n\n    for (; y < nr / 4; y++) {\n\n        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);\n\n        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = 0; x < anc / 8; x += 2) {\n\n            const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);\n            const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);\n\n            // Master FP accumulators\n            __m512 acc_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_rows[i] = _mm512_setzero_ps();\n            }\n\n            __m512 acc_min_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_min_rows[i] = _mm512_setzero_ps();\n            }\n\n            // For super block\n            for (int64_t b = 0; b < nb; b++) {\n                // Scale values - Load the sixteen scale values from two block_q4_kx8 structures\n                const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);\n\n                // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures\n                const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);\n\n                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n\n                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * 
)(b_ptr_0[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));\n\n                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));\n\n                    const __m256i rhs_raw_mat_0145_0 = 
_mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);\n                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);\n\n                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);\n                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);\n                    const __m256i rhs_raw_mat_89CD_2 = 
_mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);\n                    const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);\n\n                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);\n                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);\n\n                    const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);\n                    const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);\n\n                    //4-bit -> 8-bit\n                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) 
B0C(0-7) B0D(0-7)\n                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)\n                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)\n                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)\n\n                    const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)\n                    const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)\n                    const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)\n                    const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)\n\n                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)\n                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)\n                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) 
B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)\n                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)\n\n                    const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)\n                    const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)\n                    const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)\n                    const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)\n\n                    // Shuffle pattern one - right side input\n                    const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)\n                    const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)\n                    const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) 
B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)\n                    const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)\n                    const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)\n                    const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)\n                    const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)\n                    const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)\n\n                    const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)\n                    const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) 
B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)\n                    const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)\n                    const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)\n                    const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)\n                    const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)\n                    const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)\n                    const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)\n\n                    // Shuffle pattern two - right side input\n                    const __m512i rhs_mat_014589CD_00_sp2 = 
_mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)\n                    const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)\n                    const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)\n                    const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)\n                    const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)\n                    const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)\n                    const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31)\n          
          const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)\n\n                    const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)\n                    const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)\n                    const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)\n                    const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)\n                    const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)\n                    const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) 
B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)\n                    const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)\n                    const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)\n\n                    uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];\n\n                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together\n                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);\n                    utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_00 = utmp_00[1] & kmask1;\n                    utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);\n                    utmp_00[2] = uaux_00;\n                    utmp_00[0] &= kmask1;\n\n                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);\n                    utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_01 = utmp_01[1] & kmask1;\n                    utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);\n                    utmp_01[2] = uaux_01;\n                    utmp_01[0] &= kmask1;\n\n              
      // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);\n                    utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_10 = utmp_10[1] & kmask1;\n                    utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);\n                    utmp_10[2] = uaux_10;\n                    utmp_10[0] &= kmask1;\n\n                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);\n                    utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_11 = utmp_11[1] & kmask1;\n                    utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);\n                    utmp_11[2] = uaux_11;\n                    utmp_11[0] &= kmask1;\n\n                    // Scales of first sub block in the sb loop\n                    const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);\n                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));\n\n                    // Scales of second sub block in the sb loop\n                    const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);\n                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));\n\n                    // Mins of first and second sub block of Q4_K block are arranged side by side\n                    const __m512i mins_01 = 
_mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));\n\n                    const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);\n\n                    // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3\n                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector\n                    __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));\n                    __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);\n                    __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);\n                    __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));\n                    __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);\n                    __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);\n                    __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));\n                    __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);\n                    __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);\n                    __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * 
sb)));\n                    __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);\n                    __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);\n                    __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));\n                    __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);\n                    __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);\n                    __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));\n                    __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);\n                    __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);\n                    __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));\n                    __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);\n                    __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);\n                    __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));\n                    __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);\n                    __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);\n\n                    //Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into a 512 bit vector\n                    __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);\n      
              __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);\n                    __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);\n                    __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);\n                    __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);\n                    __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);\n                    __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);\n                    __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);\n\n                    __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);\n                    __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);\n                    __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);\n                    __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);\n                    __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);\n                    __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);\n                    __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);\n                    __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);\n\n                    // Bsums are loaded - four 
bsums are loaded (for two sub blocks) for the different Q8_K blocks\n                    __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));\n                    __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));\n                    lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);\n                    __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);\n\n                    // Shuffle pattern one - left side input\n                    const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)\n                    const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)\n                    const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)\n                    const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)\n                    const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) 
A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)\n                    const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)\n                    const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)\n                    const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)\n\n                    const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)\n                    const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)\n                    const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)\n                    const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) 
A12(8-11) A13(8-11) A13(8-11)\n                    const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)\n                    const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)\n                    const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)\n                    const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)\n\n                    const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)\n                    const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)\n                    const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)\n 
                   const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)\n                    const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)\n                    const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)\n                    const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)\n                    const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)\n\n                    const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)\n                    const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)\n                    const 
__m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)\n                    const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)\n                    const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)\n                    const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)\n                    const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)\n                    const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)\n\n                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                    __m512i iacc_mat_00_0_sp1 = 
_mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));\n                    __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));\n                    __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));\n                    __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));\n                    __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));\n                    __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, 
lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));\n                    __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));\n                    __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));\n\n                    __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));\n                    __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));\n                    __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));\n       
             __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));\n                    __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));\n                    __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));\n                    __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));\n                    __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));\n\n                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                    __m512i iacc_mat_00_0 
= _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);\n                    __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);\n                    __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);\n                    __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);\n\n                    __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);\n                    __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);\n                    __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);\n                    __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);\n\n                    iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);\n                    iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);\n                    iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);\n                    iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);\n\n                    iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);\n                    iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);\n                    iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);\n                    iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);\n\n                    // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)\n                    __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));\n                    __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);\n                    __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, 
iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));\n                    __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);\n                    __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));\n                    __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);\n                    __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));\n                    __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);\n\n                    __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);\n                    __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);\n                    __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);\n                    __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);\n\n                    // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes\n                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);\n                    const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);\n                    const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);\n\n                    // Multiply with appropriate scales and accumulate (for both d and dmin) below\n                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);\n                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, 
_mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);\n                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);\n                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);\n\n                    __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);\n                    __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);\n                    __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);\n                    __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);\n\n                    acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);\n                    acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);\n                    acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);\n                    acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);\n                }\n            }\n            // Store accumulated values\n            for (int i = 0; i < 4; i++) {\n                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));\n 
           }\n        }\n    }\n    if (anc != nc) {\n        xstart = anc/8;\n        y = 0;\n    }\n#endif // __AVX512BW__ && __AVX512DQ__\n\n    // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation\n    for (; y < anr / 4; y += 4) {\n\n        const block_q8_Kx4 * a_ptrs[4];\n\n        a_ptrs[0] = a_ptr_start + (y * nb);\n        for (int i = 0; i < 3; ++i) {\n            a_ptrs[i + 1] = a_ptrs[i] + nb;\n        }\n\n        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = xstart; x < nc / 8; x++) {\n\n            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulators\n            __m256 acc_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_rows[i] = _mm256_setzero_ps();\n            }\n\n            __m256 acc_min_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_min_rows[i] = _mm256_setzero_ps();\n            }\n\n            // For super block\n            for (int64_t b = 0; b < nb; b++) {\n\n                // Scale values - Load the eight scale values of block_q4_kx8\n                const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);\n\n                // dmin values - Load the eight dmin values of block_q4_kx8\n                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);\n\n                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n\n                    // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7\n                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const 
__m256i * )(b_ptr[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));\n\n                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values\n                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);\n                    const __m256i rhs_raw_mat_0145_3 = 
_mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);\n\n                    // 4-bit -> 8-bit\n                    // First sub block of the two sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)\n                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)\n\n                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)\n                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)\n\n                    const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)\n                    const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)\n\n                    const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)\n                    const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)\n\n                    // Second sub block of the two sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)\n                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)\n\n                    const __m256i rhs_mat_0145_11 = 
_mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)\n                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)\n\n                    const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)\n                    const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)\n\n                    const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)\n                    const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)\n\n                    // Shuffle pattern one - right side input\n                    const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)\n                    const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)\n\n                    const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)\n                    const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)\n\n                    const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)\n                    const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 
136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)\n\n                    const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)\n                    const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)\n\n                    const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)\n                    const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)\n\n                    const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)\n                    const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)\n\n                    const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)\n                    const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)\n\n                    const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)\n                    const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)\n\n\n                    // 
Shuffle pattern two - right side input\n                    const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)\n                    const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)\n\n                    const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)\n                    const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)\n\n                    const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)\n                    const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)\n\n                    const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)\n                    const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)\n\n                    const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)\n                    const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)\n\n                    const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); 
//B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)\n                    const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)\n\n                    const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)\n                    const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)\n\n                    const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)\n                    const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)\n\n                    uint32_t utmp_0[4], utmp_1[4];\n\n                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together\n                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);\n                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_0 = utmp_0[1] & kmask1;\n                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);\n                    utmp_0[2] = uaux_0;\n                    utmp_0[0] &= kmask1;\n\n                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);\n        
            utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_1 = utmp_1[1] & kmask1;\n                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);\n                    utmp_1[2] = uaux_1;\n                    utmp_1[0] &= kmask1;\n\n                    // Scales of first sub block in the sb loop\n                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);\n                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));\n\n                    // Scales of second sub block in the sb loop\n                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);\n                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));\n\n                    // Mins of first and second sub block of Q4_K block are arranged side by side\n                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));\n\n                    const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);\n                    const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);\n\n                    const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);\n                    const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);\n\n                    for (int rp = 0; rp < 4; rp++) {\n\n                        // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3\n                        // Loaded as set of 128 bit vectors and repeated into a 256 bit vector\n                        __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));\n                      
  __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);\n                        __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);\n                        __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));\n                        __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);\n                        __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);\n                        __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));\n                        __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);\n                        __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);\n                        __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));\n                        __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);\n                        __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);\n                        __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));\n                        __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);\n                        __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);\n                        __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));\n                        __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);\n                        __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);\n                       
 __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));\n                        __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);\n                        __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);\n                        __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));\n                        __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);\n                        __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);\n\n                        // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks\n                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));\n                        __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));\n                        lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);\n\n                        // Shuffle pattern one - left side input\n                        const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)\n                        const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)\n\n                        const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160);  //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)\n                        const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160);  //A02(8-11) A03(8-11) A02(8-11) A03(8-11) 
A02(8-11) A03(8-11) A02(8-11) A03(8-11)\n\n                        const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160);  //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)\n                        const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160);  //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)\n\n                        const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160);  //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)\n                        const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)\n\n                        const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)\n                        const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)\n\n                        const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160);  //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)\n                        const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160);  //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)\n\n                        const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160);  //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)\n                        const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160);  //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)\n\n                        const __m256i lhs_mat_01_13_sp1 = 
_mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)\n                        const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)\n\n                        // Shuffle pattern two- left side input\n                        const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)\n                        const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)\n\n                        const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)\n                        const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)\n\n                        const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)\n                        const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)\n\n                        const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)\n                        const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)\n\n                        const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); 
//A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)\n                        const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)\n\n                        const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)\n                        const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)\n\n                        const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)\n                        const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)\n\n                        const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)\n                        const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)\n\n                        // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                        __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));\n                        
__m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));\n                        __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));\n                        __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));\n                        __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));\n                        __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));\n                        __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), 
_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));\n                        __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));\n\n                        __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));\n                        __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));\n                        __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));\n                        __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));\n                        __m256i iacc_mat_00_1_sp2 = 
_mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));\n                        __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));\n                        __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));\n                        __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));\n\n                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                        __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);\n                        __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);\n                        __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);\n                        __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);\n\n                        __m256i iacc_mat_00_1 = 
_mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);\n                        __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);\n                        __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);\n                        __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);\n\n                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                        iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);\n                        iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);\n                        iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);\n                        iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);\n\n                        iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);\n                        iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);\n                        iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);\n                        iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);\n\n                        // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)\n                        __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);\n                        __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);\n                        __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);\n                        __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);\n                        __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 
204);\n                        __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);\n                        __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);\n                        __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);\n\n                        __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);\n                        __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);\n                        __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);\n                        __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);\n\n                        // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes\n                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);\n                        const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);\n\n                        // Multiply with appropriate scales and accumulate (for both d and dmin) below\n                        acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);\n                        acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);\n                        acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);\n                        acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 
255)), acc_rows[rp * 4 + 3]);\n\n                        __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);\n                        __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);\n                        __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);\n                        __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);\n\n                        acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);\n                        acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);\n                        acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);\n                        acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);\n\n                    }\n                }\n            }\n            // Store the accumulated values\n            for (int i = 0; i < 16; i++) {\n                _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));\n            }\n        }\n    }\n    for (; y < nr / 4; y++) {\n\n        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);\n\n        for (int64_t x = xstart; x < nc / 8; x++) {\n\n            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulators\n            __m256 acc_rows[4];\n        
    for (int i = 0; i < 4; i++) {\n                acc_rows[i] = _mm256_setzero_ps();\n            }\n\n            __m256 acc_min_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_min_rows[i] = _mm256_setzero_ps();\n            }\n\n            for (int64_t b = 0; b < nb; b++) {\n\n                // Scale values - Load the eight scale values of block_q4_Kx8\n                const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);\n\n                // dmin values - Load the eight dmin values of block_q4_Kx8\n                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);\n\n                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration\n                for (int sb = 0; sb < QK_K / 64; sb++) {\n\n                    // Load the eight block_q4_k for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7\n                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));\n\n                    // Save the 
values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values\n                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);\n                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);\n\n                    // 4-bit -> 8-bit\n                    // First sub block of the two sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)\n                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)\n\n                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)\n                    const 
__m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)\n\n                    const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)\n                    const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)\n\n                    const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)\n                    const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)\n\n                    // Second sub block of the two sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)\n                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)\n\n                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)\n                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)\n\n                    const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)\n                    const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)\n\n                    const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)\n                    const __m256i rhs_mat_2367_13 = 
_mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)\n\n                    // Shuffle pattern one - right side input\n                    const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)\n                    const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)\n\n                    const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)\n                    const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)\n\n                    const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)\n                    const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)\n\n                    const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)\n                    const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)\n\n                    const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)\n                    const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) 
B16(0-3) B17(0-3)\n\n                    const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)\n                    const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)\n\n                    const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)\n                    const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)\n\n                    const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)\n                    const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)\n\n                    // Shuffle pattern two - right side input\n                    const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)\n                    const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)\n\n                    const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)\n                    const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)\n\n                    const __m256i 
rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)\n                    const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)\n\n                    const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)\n                    const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)\n\n                    const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)\n                    const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)\n\n                    const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)\n                    const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)\n\n                    const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)\n                    const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)\n\n                    const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) 
B14(28-31) B15(28-31) B14(28-31) B15(28-31)\n                    const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)\n\n                    uint32_t utmp_0[4], utmp_1[4];\n\n                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together\n                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop\n                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);\n                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_0 = utmp_0[1] & kmask1;\n                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);\n                    utmp_0[2] = uaux_0;\n                    utmp_0[0] &= kmask1;\n\n                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures when sb = 1\n                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);\n                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_1 = utmp_1[1] & kmask1;\n                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);\n                    utmp_1[2] = uaux_1;\n                    utmp_1[0] &= kmask1;\n\n                    // Scales of first sub block in the sb loop\n                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);\n                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));\n\n                    // Scales of second sub block in the sb loop\n                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], 
utmp_1[0]);\n                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));\n\n                    // Mins of first and second sub block of Q4_K block are arranged side by side\n                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));\n\n                    const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);\n                    const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);\n\n                    const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);\n                    const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);\n\n                    // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3\n                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector\n                    __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));\n                    __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);\n                    __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);\n                    __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));\n                    __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);\n                    __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);\n                    __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));\n                    __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);\n                    __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);\n        
            __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));\n                    __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);\n                    __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);\n                    __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));\n                    __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);\n                    __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);\n                    __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));\n                    __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);\n                    __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);\n                    __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));\n                    __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);\n                    __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);\n                    __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));\n                    __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);\n                    __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);\n\n                    // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks\n                    __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));\n                    __m256i lhs_bsums_hsum_0123_01 = 
_mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));\n                    lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);\n\n                    // Shuffle pattern one - left side input\n                    const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)\n                    const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)\n\n                    const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160);  //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)\n                    const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160);  //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)\n\n                    const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160);  //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)\n                    const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160);  //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)\n\n                    const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160);  //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)\n                    const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)\n\n                    const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)\n                    const __m256i 
lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)\n\n                    const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160);  //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)\n                    const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160);  //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)\n\n                    const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160);  //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)\n                    const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160);  //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)\n\n                    const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)\n                    const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)\n\n                    // Shuffle pattern two- left side input\n                    const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)\n                    const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)\n\n                    const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)\n                    const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) 
A03(12-15) A02(12-15) A03(12-15)\n\n                    const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)\n                    const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)\n\n                    const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)\n                    const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)\n\n                    const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)\n                    const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)\n\n                    const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)\n                    const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)\n\n                    const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)\n                    const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)\n\n                    const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) 
A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)\n                    const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)\n\n                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                    __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));\n                    __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));\n                    __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));\n                    __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));\n                    __m256i iacc_mat_00_1_sp1 = 
_mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));\n                    __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));\n                    __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));\n                    __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));\n\n                    __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));\n                    __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), 
_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));\n                    __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));\n                    __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));\n                    __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));\n                    __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));\n                    __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));\n                    __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), 
_mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));\n\n                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                    __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);\n                    __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);\n                    __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);\n                    __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);\n\n                    __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);\n                    __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);\n                    __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);\n                    __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);\n\n                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                    iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);\n                    iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);\n                    iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);\n                    iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);\n\n                    iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);\n                    iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);\n                    iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);\n                    iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);\n\n                    // Straighten out to make 4 row 
vectors (4 for each sub block which are accumulated together in the next step)\n                    __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);\n                    __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);\n                    __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);\n                    __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);\n                    __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);\n                    __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);\n                    __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);\n                    __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);\n\n                    __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);\n                    __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);\n                    __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);\n                    __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);\n\n                    // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes\n                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);\n                    const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);\n\n                    // Multiply with appropriate scales and accumulate (for both d and dmin) below\n                    acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, 
_mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);\n                    acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);\n                    acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);\n                    acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);\n\n                    __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);\n                    __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);\n                    __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);\n                    __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);\n\n                    acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);\n                    acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);\n                    acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);\n                    acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);\n                }\n            }\n\n            // Store the accumulated values\n            for (int i = 0; i < 4; i++) 
{\n                _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));\n            }\n        }\n    }\n\n#else\n    UNUSED(kmask1);\n    UNUSED(kmask2);\n    UNUSED(kmask3);\n    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n#endif\n}\n\nvoid ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n#if defined(__AVX2__) || defined(__AVX512F__)\n    {\n        __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_iq4nl));\n        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);\n\n        gemm_q4_b32_8x8_q8_0_lut_avx<block_iq4_nlx8>(n, s, bs, vx, vy, nr, nc, signextendlut);\n\n        return;\n    }\n#endif // defined(__AVX2__) || defined(__AVX512F__)\n\n    ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n#if defined(__AVX2__) || defined(__AVX512F__)\n    {\n        __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));\n        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);\n\n        gemm_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut);\n\n        return;\n    }\n#endif // defined(__AVX2__) || defined(__AVX512F__)\n\n    ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    
UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n#if defined(__AVX2__) || defined(__AVX512F__)\n    const block_q2_Kx8 * b_ptr_start = (const block_q2_Kx8 * ) vx;\n    const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;\n    int64_t b_nb = n / QK_K;\n    int64_t y = 0;\n\n    // Permute mask used for easier vector processing at later stages\n    __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);\n    int64_t xstart = 0;\n    int anr = nr - nr % 16; // Used to align nr with boundary of 16\n\n    // Mask to convert 2 bit and 4 bit values into a bytes\n    const __m256i m3b = _mm256_set1_epi8(3);\n    const __m128i m4b_sse = _mm_set1_epi8(0xF);\n\n    //Mask to get appropriate scales\n    __m128i scalesmask1_sse = _mm_set_epi8(14,14,12,12,10,10,8,8,6,6,4,4,2,2,0,0);\n    __m128i scalesmask2_sse = _mm_set_epi8(15,15,13,13,11,11,9,9,7,7,5,5,3,3,1,1);\n\n    __m256i scalesmask1 = _mm256_castsi128_si256(scalesmask1_sse);\n    scalesmask1 = _mm256_permute2f128_si256(scalesmask1, scalesmask1, 0);\n    __m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);\n    scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);\n\n#if defined(__AVX512BW__) && defined(__AVX512DQ__)\n\n    int anc = nc - nc % 16; // Used to align nc with boundary of 16\n\n    // Mask to mask out nibbles from packed bytes\n    const __m256i m4b = _mm256_set1_epi8(0x0F);\n    // Mask to mask out nibbles from packed bytes expanded to 512 bit length\n    const __m512i m3bexpanded = _mm512_set1_epi8(3);\n    //Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation\n    for (; y < anr / 4; y += 4) {\n\n        const block_q8_Kx4 * a_ptrs[4];\n\n        a_ptrs[0] = a_ptr_start + (y * nb);\n        for (int i = 0; i < 3; ++i) {\n            a_ptrs[i + 1] = a_ptrs[i] + nb;\n        }\n\n        // Take group of eight 
block_q2_kx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = 0; x < anc / 8; x += 2) {\n\n            const block_q2_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);\n            const block_q2_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);\n\n            // Master FP accumulators\n            __m512 acc_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_rows[i] = _mm512_setzero_ps();\n            }\n\n            __m512 acc_min_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_min_rows[i] = _mm512_setzero_ps();\n            }\n            // For super block\n            for (int64_t b = 0; b < nb; b++) {\n                // Delta values - Load the sixteen scale values from two block_q2_kx8 structures\n                const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);\n\n                // dmin values - Load the sixteen dmin values from two block_q2_kx8 structures\n                const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);\n\n                // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration\n                for (int sb = 0; sb < QK_K / 128; sb++) {\n\n                    // Load the eight block_q2_k for eight sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7\n                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));\n                    const __m256i 
rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));\n\n                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));\n\n                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n        
            const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);\n                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);\n\n                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);\n                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);\n                    const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);\n                    const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);\n                    const 
__m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);\n\n                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);\n                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);\n\n                    const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);\n                    const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);\n\n                    //2-bit -> 8-bit\n                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0,m3bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)\n                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0,m3bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)\n                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1,m3bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)\n                    const __m512i rhs_mat_2367ABEF_01 = 
_mm512_and_si512(rhs_raw_mat_2367ABEF_1,m3bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)\n                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(rhs_raw_mat_014589CD_2,m3bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)\n                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2,m3bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)\n                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(rhs_raw_mat_014589CD_3,m3bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)\n                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3,m3bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)\n\n                    const __m512i rhs_mat_014589CD_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 2), m3bexpanded); //B20(0-7) B21(0-7) B24(0-7) B25(0-7) B28(0-7) B29(0-7) B2C(0-7) B2D(0-7)\n                    const __m512i rhs_mat_2367ABEF_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 2), m3bexpanded); //B22(0-7) B23(0-7) B26(0-7) B27(0-7) B2A(0-7) B2B(0-7) B2E(0-7) B2F(0-7)\n\n                    const __m512i rhs_mat_014589CD_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 2), m3bexpanded); //B20(8-15) B21(8-15) B24(8-15) B25(8-15) B28(8-15) B29(8-15) B2C(8-15) B2D(8-15)\n                    const __m512i rhs_mat_2367ABEF_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 2), m3bexpanded); //B22(8-15) B23(8-15) B26(8-15) B27(8-15) B2A(8-15) B2B(8-15) B2E(8-15) B2F(8-15)\n\n                    const __m512i rhs_mat_014589CD_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 2), m3bexpanded); //B30(0-7) B31(0-7) B34(0-7) B35(0-7) B38(0-7) B39(0-7) B3C(0-7) B3D(0-7)\n                    const __m512i 
rhs_mat_2367ABEF_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 2), m3bexpanded); //B32(0-7) B33(0-7) B36(0-7) B37(0-7) B3A(0-7) B3B(0-7) B3E(0-7) B3F(0-7)\n\n                    const __m512i rhs_mat_014589CD_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 2), m3bexpanded); //B30(8-15) B31(8-15) B34(8-15) B35(8-15) B38(8-15) B39(8-15) B3C(8-15) B3D(8-15)\n                    const __m512i rhs_mat_2367ABEF_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 2), m3bexpanded); //B32(8-15) B33(8-15) B36(8-15) B37(8-15) B3A(8-15) B3B(8-15) B3E(8-15) B3F(8-15)\n\n                    const __m512i rhs_mat_014589CD_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m3bexpanded); //B40(0-7) B41(0-7) B44(0-7) B45(0-7) B48(0-7) B49(0-7) B4C(0-7) B4D(0-7)\n                    const __m512i rhs_mat_2367ABEF_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m3bexpanded); //B42(0-7) B43(0-7) B46(0-7) B47(0-7) B4A(0-7) B4B(0-7) B4E(0-7) B4F(0-7)\n\n                    const __m512i rhs_mat_014589CD_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m3bexpanded); //B40(8-15) B41(8-15) B44(8-15) B45(8-15) B48(8-15) B49(8-15) B4C(8-15) B4D(8-15)\n                    const __m512i rhs_mat_2367ABEF_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m3bexpanded); //B42(8-15) B43(8-15) B46(8-15) B47(8-15) B4A(8-15) B4B(8-15) B4E(8-15) B4F(8-15)\n\n                    const __m512i rhs_mat_014589CD_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m3bexpanded); //B50(0-7) B51(0-7) B54(0-7) B55(0-7) B58(0-7) B59(0-7) B5C(0-7) B5D(0-7)\n                    const __m512i rhs_mat_2367ABEF_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m3bexpanded); //B52(0-7) B53(0-7) B56(0-7) B57(0-7) B5A(0-7) B5B(0-7) B5E(0-7) B5F(0-7)\n\n                    const __m512i rhs_mat_014589CD_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), 
m3bexpanded); //B50(8-15) B51(8-15) B54(8-15) B55(8-15) B58(8-15) B59(8-15) B5C(8-15) B5D(8-15)\n                    const __m512i rhs_mat_2367ABEF_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m3bexpanded); //B52(8-15) B53(8-15) B56(8-15) B57(8-15) B5A(8-15) B5B(8-15) B5E(8-15) B5F(8-15)\n\n                    const __m512i rhs_mat_014589CD_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 6), m3bexpanded); //B60(0-7) B61(0-7) B64(0-7) B65(0-7) B68(0-7) B69(0-7) B6C(0-7) B6D(0-7)\n                    const __m512i rhs_mat_2367ABEF_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 6), m3bexpanded); //B62(0-7) B63(0-7) B66(0-7) B67(0-7) B6A(0-7) B6B(0-7) B6E(0-7) B6F(0-7)\n\n                    const __m512i rhs_mat_014589CD_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 6), m3bexpanded); //B60(8-15) B61(8-15) B64(8-15) B65(8-15) B68(8-15) B69(8-15) B6C(8-15) B6D(8-15)\n                    const __m512i rhs_mat_2367ABEF_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 6), m3bexpanded); //B62(8-15) B63(8-15) B66(8-15) B67(8-15) B6A(8-15) B6B(8-15) B6E(8-15) B6F(8-15)\n\n                    const __m512i rhs_mat_014589CD_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 6), m3bexpanded); //B70(0-7) B71(0-7) B74(0-7) B75(0-7) B78(0-7) B79(0-7) B7C(0-7) B7D(0-7)\n                    const __m512i rhs_mat_2367ABEF_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 6), m3bexpanded); //B72(0-7) B73(0-7) B76(0-7) B77(0-7) B7A(0-7) B7B(0-7) B7E(0-7) B7F(0-7)\n\n                    const __m512i rhs_mat_014589CD_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 6), m3bexpanded); //B70(8-15) B71(8-15) B74(8-15) B75(8-15) B78(8-15) B79(8-15) B7C(8-15) B7D(8-15)\n                    const __m512i rhs_mat_2367ABEF_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 6), m3bexpanded); //B72(8-15) B73(8-15) B76(8-15) B77(8-15) B7A(8-15) B7B(8-15) 
B7E(8-15) B7F(8-15)\n\n                    const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)\n                    const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)\n\n                    const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)\n                    const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)\n\n                    const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)\n                    const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)\n\n                    const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)\n                    const 
__m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)\n\n                    const __m512i rhs_mat_014589CD_20_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3) B28(0-3) B29(0-3) B28(0-3) B29(0-3) B2C(0-3) B2D(0-3) B2C(0-3) B2D(0-3)\n                    const __m512i rhs_mat_2367ABEF_20_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3) B2A(0-3) B2B(0-3) B2A(0-3) B2B(0-3) B2E(0-3) B2F(0-3) B2E(0-3) B2F(0-3)\n\n                    const __m512i rhs_mat_014589CD_21_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11) B28(8-11) B29(8-11) B28(8-11) B29(8-11) B2C(8-11) B2D(8-11) B2C(8-11) B2D(8-11)\n                    const __m512i rhs_mat_2367ABEF_21_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11) B2A(8-11) B2B(8-11) B2A(8-11) B2B(8-11) B2E(8-11) B2F(8-11) B2E(8-11) B2F(8-11)\n\n                    const __m512i rhs_mat_014589CD_30_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)136); ///B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3) B38(0-3) B39(0-3) B38(0-3) B39(0-3) B3C(0-3) B3D(0-3) B3C(0-3) B3D(0-3)\n                    const __m512i rhs_mat_2367ABEF_30_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3) B3A(0-3) B3B(0-3) B3A(0-3) B3B(0-3) B3E(0-3) B3F(0-3) B3E(0-3) B3F(0-3)\n\n                    const __m512i rhs_mat_014589CD_31_sp1 = 
_mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11) B38(8-11) B39(8-11) B38(8-11) B39(8-11) B3C(8-11) B3D(8-11) B3C(8-11) B3D(8-11)\n                    const __m512i rhs_mat_2367ABEF_31_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11) B3A(8-11) B3B(8-11) B3A(8-11) B3B(8-11) B3E(8-11) B3F(8-11) B3E(8-11) B3F(8-11)\n\n                    const __m512i rhs_mat_014589CD_40_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3) B48(0-3) B49(0-3) B48(0-3) B49(0-3) B4C(0-3) B4D(0-3) B4C(0-3) B4D(0-3)\n                    const __m512i rhs_mat_2367ABEF_40_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3) B4A(0-3) B4B(0-3) B4A(0-3) B4B(0-3) B4E(0-3) B4F(0-3) B4E(0-3) B4F(0-3)\n\n                    const __m512i rhs_mat_014589CD_41_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11) B48(8-11) B49(8-11) B48(8-11) B49(8-11) B4C(8-11) B4D(8-11) B4C(8-11) B4D(8-11)\n                    const __m512i rhs_mat_2367ABEF_41_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11) B4A(8-11) B4B(8-11) B4A(8-11) B4B(8-11) B4E(8-11) B4F(8-11) B4E(8-11) B4F(8-11)\n\n                    const __m512i rhs_mat_014589CD_50_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3) B58(0-3) B59(0-3) B58(0-3) B59(0-3) B5C(0-3) B5D(0-3) B5C(0-3) B5D(0-3)\n                    const __m512i rhs_mat_2367ABEF_50_sp1 = 
_mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3) B5A(0-3) B5B(0-3) B5A(0-3) B5B(0-3) B5E(0-3) B5F(0-3) B5E(0-3) B5F(0-3)\n\n                    const __m512i rhs_mat_014589CD_51_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11) B58(8-11) B59(8-11) B58(8-11) B59(8-11) B5C(8-11) B5D(8-11) B5C(8-11) B5D(8-11)\n                    const __m512i rhs_mat_2367ABEF_51_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11) B5A(8-11) B5B(8-11) B5A(8-11) B5B(8-11) B5E(8-11) B5F(8-11) B5E(8-11) B5F(8-11)\n\n                    const __m512i rhs_mat_014589CD_60_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3) B68(0-3) B69(0-3) B68(0-3) B69(0-3) B6C(0-3) B6D(0-3) B6C(0-3) B6D(0-3)\n                    const __m512i rhs_mat_2367ABEF_60_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3) B6A(0-3) B6B(0-3) B6A(0-3) B6B(0-3) B6E(0-3) B6F(0-3) B6E(0-3) B6F(0-3)\n\n                    const __m512i rhs_mat_014589CD_61_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11) B68(8-11) B69(8-11) B68(8-11) B69(8-11) B6C(8-11) B6D(8-11) B6C(8-11) B6D(8-11)\n                    const __m512i rhs_mat_2367ABEF_61_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11) B6A(8-11) B6B(8-11) B6A(8-11) B6B(8-11) B6E(8-11) B6F(8-11) B6E(8-11) B6F(8-11)\n\n                    const __m512i rhs_mat_014589CD_70_sp1 = 
_mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3) B78(0-3) B79(0-3) B78(0-3) B79(0-3) B7C(0-3) B7D(0-3) B7C(0-3) B7D(0-3)\n                    const __m512i rhs_mat_2367ABEF_70_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3) B7A(0-3) B7B(0-3) B7A(0-3) B7B(0-3) B7E(0-3) B7F(0-3) B7E(0-3) B7F(0-3)\n\n                    const __m512i rhs_mat_014589CD_71_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)\n                    const __m512i rhs_mat_2367ABEF_71_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11) B7A(8-11) B7B(8-11) B7A(8-11) B7B(8-11) B7E(8-11) B7F(8-11) B7E(8-11) B7F(8-11)\n\n                    const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)\n                    const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)\n\n                    const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)\n                    const __m512i rhs_mat_2367ABEF_01_sp2 = 
_mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)\n\n                    const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)\n                    const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)\n\n                    const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)\n                    const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)\n\n                    const __m512i rhs_mat_014589CD_20_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7) B28(4-7) B29(4-7) B28(4-7) B29(4-7) B2C(4-7) B2D(4-7) B2C(4-7) B2D(4-7)\n                    const __m512i rhs_mat_2367ABEF_20_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7) B2A(4-7) B2B(4-7) B2A(4-7) B2B(4-7) B2E(4-7) B2F(4-7) B2E(4-7) B2F(4-7)\n\n                    const __m512i rhs_mat_014589CD_21_sp2 = 
_mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15) B28(12-15) B29(12-15) B28(12-15) B29(12-15) B2C(12-15) B2D(12-15) B2C(12-15) B2D(12-15)\n                    const __m512i rhs_mat_2367ABEF_21_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15) B2A(12-15) B2B(12-15) B2A(12-15) B2B(12-15) B2E(12-15) B2F(12-15) B2E(12-15) B2F(12-15)\n\n                    const __m512i rhs_mat_014589CD_30_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7) B38(4-7) B39(4-7) B38(4-7) B39(4-7) B3C(4-7) B3D(4-7) B3C(4-7) B3D(4-7)\n                    const __m512i rhs_mat_2367ABEF_30_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7) B3A(4-7) B3B(4-7) B3A(4-7) B3B(4-7) B3E(4-7) B3F(4-7) B3E(4-7) B3F(4-7)\n\n                    const __m512i rhs_mat_014589CD_31_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15) B38(12-15) B39(12-15) B38(12-15) B39(12-15) B3C(12-15) B3D(12-15) B3C(12-15) B3D(12-15)\n                    const __m512i rhs_mat_2367ABEF_31_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15) B3A(12-15) B3B(12-15) B3A(12-15) B3B(12-15) B3E(12-15) B3F(12-15) B3E(12-15) B3F(12-15)\n\n                    const __m512i rhs_mat_014589CD_40_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7) B48(4-7) B49(4-7) B48(4-7) B49(4-7) B4C(4-7) B4D(4-7) B4C(4-7) B4D(4-7)\n                    const __m512i 
rhs_mat_2367ABEF_40_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7) B4A(4-7) B4B(4-7) B4A(4-7) B4B(4-7) B4E(4-7) B4F(4-7) B4E(4-7) B4F(4-7)\n\n                    const __m512i rhs_mat_014589CD_41_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15) B48(12-15) B49(12-15) B48(12-15) B49(12-15) B4C(12-15) B4D(12-15) B4C(12-15) B4D(12-15)\n                    const __m512i rhs_mat_2367ABEF_41_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15) B4A(12-15) B4B(12-15) B4A(12-15) B4B(12-15) B4E(12-15) B4F(12-15) B4E(12-15) B4F(12-15)\n\n                    const __m512i rhs_mat_014589CD_50_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7) B58(4-7) B59(4-7) B58(4-7) B59(4-7) B5C(4-7) B5D(4-7) B5C(4-7) B5D(4-7)\n                    const __m512i rhs_mat_2367ABEF_50_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7) B5A(4-7) B5B(4-7) B5A(4-7) B5B(4-7) B5E(4-7) B5F(4-7) B5E(4-7) B5F(4-7)\n\n                    const __m512i rhs_mat_014589CD_51_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15) B58(12-15) B59(12-15) B58(12-15) B59(12-15) B5C(12-15) B5D(12-15) B5C(12-15) B5D(12-15)\n                    const __m512i rhs_mat_2367ABEF_51_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15) B5A(12-15) B5B(12-15) B5A(12-15) B5B(12-15) B5E(12-15) B5F(12-15) B5E(12-15) B5F(12-15)\n\n        
            const __m512i rhs_mat_014589CD_60_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7) B68(4-7) B69(4-7) B68(4-7) B69(4-7) B6C(4-7) B6D(4-7) B6C(4-7) B6D(4-7)\n                    const __m512i rhs_mat_2367ABEF_60_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7) B6A(4-7) B6B(4-7) B6A(4-7) B6B(4-7) B6E(4-7) B6F(4-7) B6E(4-7) B6F(4-7)\n\n                    const __m512i rhs_mat_014589CD_61_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15) B68(12-15) B69(12-15) B68(12-15) B69(12-15) B6C(12-15) B6D(12-15) B6C(12-15) B6D(12-15)\n                    const __m512i rhs_mat_2367ABEF_61_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15) B6A(12-15) B6B(12-15) B6A(12-15) B6B(12-15) B6E(12-15) B6F(12-15) B6E(12-15) B6F(12-15)\n\n                    const __m512i rhs_mat_014589CD_70_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7) B78(4-7) B79(4-7) B78(4-7) B79(4-7) B7C(4-7) B7D(4-7) B7C(4-7) B7D(4-7)\n                    const __m512i rhs_mat_2367ABEF_70_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7) B7A(4-7) B7B(4-7) B7A(4-7) B7B(4-7) B7E(4-7) B7F(4-7) B7E(4-7) B7F(4-7)\n\n                    const __m512i rhs_mat_014589CD_71_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15) B78(12-15) B79(12-15) B78(12-15) B79(12-15) B7C(12-15) B7D(12-15) B7C(12-15) B7D(12-15)\n                
    const __m512i rhs_mat_2367ABEF_71_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15) B7A(12-15) B7B(12-15) B7A(12-15) B7B(12-15) B7E(12-15) B7F(12-15) B7E(12-15) B7F(12-15)\n\n                    //notation:superblock subblock\n                    //s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71\n\n                    const __m128i mins_and_scales_01_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + sb * 64));\n                    const __m128i mins_and_scales_23_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 16 + sb * 64));\n                    const __m128i mins_and_scales_45_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 32 + sb * 64));\n                    const __m128i mins_and_scales_67_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 48 + sb * 64));\n\n                    const __m128i mins_and_scales_01_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + sb * 64));\n                    const __m128i mins_and_scales_23_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 16 + sb * 64));\n                    const __m128i mins_and_scales_45_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 32 + sb * 64));\n                    const __m128i mins_and_scales_67_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 48 + sb * 64));\n\n                    // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop\n                    const __m256i mins_and_scales_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_01_0), mins_and_scales_01_1, 1);\n                    const __m256i mins_and_scales_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_23_0), mins_and_scales_23_1, 1);\n                    const __m256i mins_and_scales_45 = 
_mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_45_0), mins_and_scales_45_1, 1);\n                    const __m256i mins_and_scales_67 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_67_0), mins_and_scales_67_1, 1);\n\n                    // Extract scales which is lower half from mins_and_scales\n                    const __m256i scales_01 = _mm256_and_si256(mins_and_scales_01, m4b);\n                    const __m256i scales_23 = _mm256_and_si256(mins_and_scales_23, m4b);\n                    const __m256i scales_45 = _mm256_and_si256(mins_and_scales_45, m4b);\n                    const __m256i scales_67 = _mm256_and_si256(mins_and_scales_67, m4b);\n\n                    // Extract mins which is upper half from mins_and_scales\n                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_01, 4), m4b));\n                    const __m512i mins_23 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_23, 4), m4b));\n                    const __m512i mins_45 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_45, 4), m4b));\n                    const __m512i mins_67 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_67, 4), m4b));\n\n                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01,scalesmask1));\n                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01,scalesmask2));\n                    const __m512i scales_2 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23,scalesmask1));\n                    const __m512i scales_3 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23,scalesmask2));\n                    const __m512i scales_4 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45,scalesmask1));\n                    const __m512i scales_5 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45,scalesmask2));\n                    const 
__m512i scales_6 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67,scalesmask1));\n                    const __m512i scales_7 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67,scalesmask2));\n\n                    const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)238);\n\n\n                    for (int rp = 0; rp < 4; rp++) {\n\n                        // Load the 
four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3\n                        // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector\n                        __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);\n                        __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);\n                        __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);\n                        __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);\n                        __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);\n                        __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);\n                        __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);\n                        __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);\n                        __m256i lhs_mat_ymm_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_20 = 
_mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 0);\n                        __m256i lhs_mat_ymm_23_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 17);\n                        __m256i lhs_mat_ymm_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 0);\n                        __m256i lhs_mat_ymm_23_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 17);\n                        __m256i lhs_mat_ymm_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 0);\n                        __m256i lhs_mat_ymm_23_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 17);\n                        __m256i lhs_mat_ymm_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 0);\n                        __m256i lhs_mat_ymm_23_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 17);\n\n                        __m256i lhs_mat_ymm_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 0);\n                        __m256i lhs_mat_ymm_23_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 17);\n                        __m256i lhs_mat_ymm_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 288 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 0);\n          
              __m256i lhs_mat_ymm_23_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 17);\n                        __m256i lhs_mat_ymm_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 320 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 0);\n                        __m256i lhs_mat_ymm_23_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 17);\n                        __m256i lhs_mat_ymm_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 352 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 0);\n                        __m256i lhs_mat_ymm_23_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 17);\n                        __m256i lhs_mat_ymm_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 384 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 0);\n                        __m256i lhs_mat_ymm_23_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 17);\n                        __m256i lhs_mat_ymm_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 416 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 0);\n                        __m256i lhs_mat_ymm_23_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 17);\n                        __m256i lhs_mat_ymm_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 448 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 0);\n                        __m256i lhs_mat_ymm_23_70 = 
_mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 17);\n                        __m256i lhs_mat_ymm_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 480 + 512 * sb)));\n                        __m256i lhs_mat_ymm_01_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 0);\n                        __m256i lhs_mat_ymm_23_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 17);\n\n\n                        __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);\n                        __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);\n                        __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);\n                        __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);\n\n                        __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);\n                        __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);\n                        __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);\n                        __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);\n\n                        __m512i lhs_mat_01_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_20), lhs_mat_ymm_01_20, 1);\n                        __m512i lhs_mat_23_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_20), lhs_mat_ymm_23_20, 1);\n                        __m512i lhs_mat_01_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_21), lhs_mat_ymm_01_21, 1);\n                        __m512i lhs_mat_23_21 = 
_mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_21), lhs_mat_ymm_23_21, 1);\n\n                        __m512i lhs_mat_01_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_30), lhs_mat_ymm_01_30, 1);\n                        __m512i lhs_mat_23_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_30), lhs_mat_ymm_23_30, 1);\n                        __m512i lhs_mat_01_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_31), lhs_mat_ymm_01_31, 1);\n                        __m512i lhs_mat_23_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_31), lhs_mat_ymm_23_31, 1);\n\n                        __m512i lhs_mat_01_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_40), lhs_mat_ymm_01_40, 1);\n                        __m512i lhs_mat_23_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_40), lhs_mat_ymm_23_40, 1);\n                        __m512i lhs_mat_01_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_41), lhs_mat_ymm_01_41, 1);\n                        __m512i lhs_mat_23_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_41), lhs_mat_ymm_23_41, 1);\n\n                        __m512i lhs_mat_01_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_50), lhs_mat_ymm_01_50, 1);\n                        __m512i lhs_mat_23_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_50), lhs_mat_ymm_23_50, 1);\n                        __m512i lhs_mat_01_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_51), lhs_mat_ymm_01_51, 1);\n                        __m512i lhs_mat_23_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_51), lhs_mat_ymm_23_51, 1);\n\n                        __m512i lhs_mat_01_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_60), lhs_mat_ymm_01_60, 1);\n                        __m512i lhs_mat_23_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_60), lhs_mat_ymm_23_60, 1);\n                        
__m512i lhs_mat_01_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_61), lhs_mat_ymm_01_61, 1);\n                        __m512i lhs_mat_23_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_61), lhs_mat_ymm_23_61, 1);\n\n                        __m512i lhs_mat_01_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_70), lhs_mat_ymm_01_70, 1);\n                        __m512i lhs_mat_23_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_70), lhs_mat_ymm_23_70, 1);\n                        __m512i lhs_mat_01_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_71), lhs_mat_ymm_01_71, 1);\n                        __m512i lhs_mat_23_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_71), lhs_mat_ymm_23_71, 1);\n\n                        // Bsums are loaded for the different Q8_K blocks\n                        __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 32 * sb)));\n                        __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 8 + 32 * sb));\n                        __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 16 + 32 * sb)));\n                        __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 24 + 32 * sb));\n\n                        __m256i lhs_bsums_ymm_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1);\n                        __m512i lhs_bsums_01_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_0123), lhs_bsums_ymm_01_0123, 1);\n                        __m256i lhs_bsums_ymm_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1);\n                        __m512i lhs_bsums_23_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_0123), lhs_bsums_ymm_23_0123, 1);                        
__m256i lhs_bsums_ymm_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1);\n                        __m512i lhs_bsums_01_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_4567), lhs_bsums_ymm_01_4567, 1);\n                        __m256i lhs_bsums_ymm_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1);\n                        __m512i lhs_bsums_23_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_4567), lhs_bsums_ymm_23_4567, 1);\n\n                        // Shuffle pattern one - left side input\n                        const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)\n                        const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)\n\n                        const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)\n                        const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)\n\n                        const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)\n        
                const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)\n\n                        const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)\n                        const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)\n\n                        const __m512i lhs_mat_01_20_sp1 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3)\n                        const __m512i lhs_mat_23_20_sp1 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)160); //A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3)\n\n                        const __m512i lhs_mat_01_21_sp1 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11)\n                        const __m512i lhs_mat_23_21_sp1 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)160); //A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11)\n\n                        const __m512i lhs_mat_01_30_sp1 = _mm512_shuffle_epi32(lhs_mat_01_30, 
(_MM_PERM_ENUM)160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3)\n                        const __m512i lhs_mat_23_30_sp1 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)160); //A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3)\n\n                        const __m512i lhs_mat_01_31_sp1 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11)\n                        const __m512i lhs_mat_23_31_sp1 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)160); //A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11)\n\n                        const __m512i lhs_mat_01_40_sp1 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3)\n                        const __m512i lhs_mat_23_40_sp1 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)160); //A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3)\n\n                        const __m512i lhs_mat_01_41_sp1 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11)\n                        const __m512i lhs_mat_23_41_sp1 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)160); //A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) 
A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11)\n\n                        const __m512i lhs_mat_01_50_sp1 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3)\n                        const __m512i lhs_mat_23_50_sp1 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)160); //A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3)\n\n                        const __m512i lhs_mat_01_51_sp1 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11)\n                        const __m512i lhs_mat_23_51_sp1 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)160); //A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11)\n\n                        const __m512i lhs_mat_01_60_sp1 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3)\n                        const __m512i lhs_mat_23_60_sp1 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)160); //A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3)\n\n                        const __m512i lhs_mat_01_61_sp1 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11)\n                        
const __m512i lhs_mat_23_61_sp1 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)160); //A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11)\n\n                        const __m512i lhs_mat_01_70_sp1 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3)\n                        const __m512i lhs_mat_23_70_sp1 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)160); //A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3)\n\n                        const __m512i lhs_mat_01_71_sp1 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11)\n                        const __m512i lhs_mat_23_71_sp1 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)160); //A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11)\n\n                        const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)\n                        const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)\n\n                        const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) 
A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)\n                        const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)\n\n                        const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)\n                        const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)\n\n                        const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)\n                        const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)\n\n                        const __m512i lhs_mat_01_20_sp2 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7)\n                        const __m512i lhs_mat_23_20_sp2 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)245); //A22(4-7) A22(4-7) A23(4-7) A23(4-7) 
A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7)\n\n                        const __m512i lhs_mat_01_21_sp2 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15)\n                        const __m512i lhs_mat_23_21_sp2 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)245); //A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15)\n\n                        const __m512i lhs_mat_01_30_sp2 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7)\n                        const __m512i lhs_mat_23_30_sp2 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)245); //A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7)\n\n                        const __m512i lhs_mat_01_31_sp2 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15)\n                        const __m512i lhs_mat_23_31_sp2 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)245); //A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15)\n\n                        const __m512i lhs_mat_01_40_sp2 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) 
A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7)\n                        const __m512i lhs_mat_23_40_sp2 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)245); //A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7)\n\n                        const __m512i lhs_mat_01_41_sp2 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15)\n                        const __m512i lhs_mat_23_41_sp2 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)245); //A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15)\n\n                        const __m512i lhs_mat_01_50_sp2 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7)\n                        const __m512i lhs_mat_23_50_sp2 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)245); //A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7)\n\n                        const __m512i lhs_mat_01_51_sp2 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15)\n                        const __m512i lhs_mat_23_51_sp2 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)245); //A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) 
A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15)\n\n                        const __m512i lhs_mat_01_60_sp2 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7)\n                        const __m512i lhs_mat_23_60_sp2 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)245); //A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7)\n\n                        const __m512i lhs_mat_01_61_sp2 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15)\n                        const __m512i lhs_mat_23_61_sp2 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)245); //A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15)\n\n                        const __m512i lhs_mat_01_70_sp2 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7)\n                        const __m512i lhs_mat_23_70_sp2 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)245); //A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7)\n\n                        const __m512i lhs_mat_01_71_sp2 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15)\n      
                  const __m512i lhs_mat_23_71_sp2 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)245); //A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15)\n\n                        // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                        __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1));\n                        __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1));\n\n                        __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1));\n                        __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1));\n\n                        __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1));\n                        __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1));\n\n                        __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1));\n                        __m512i iacc_mat_11_1_sp1 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1));\n\n                        __m512i iacc_mat_00_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_01_21_sp1));\n                        __m512i iacc_mat_01_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_01_21_sp1));\n\n                        __m512i iacc_mat_10_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_23_21_sp1));\n                        __m512i iacc_mat_11_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_23_21_sp1));\n\n                        __m512i iacc_mat_00_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_01_31_sp1));\n                        __m512i iacc_mat_01_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_01_31_sp1));\n\n                        __m512i iacc_mat_10_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_23_31_sp1));\n                        __m512i iacc_mat_11_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_23_31_sp1));\n\n                        __m512i iacc_mat_00_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_01_41_sp1));\n                      
  __m512i iacc_mat_01_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_01_41_sp1));\n\n                        __m512i iacc_mat_10_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_23_41_sp1));\n                        __m512i iacc_mat_11_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_23_41_sp1));\n\n                        __m512i iacc_mat_00_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_01_51_sp1));\n                        __m512i iacc_mat_01_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_01_51_sp1));\n\n                        __m512i iacc_mat_10_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_23_51_sp1));\n                        __m512i iacc_mat_11_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_23_51_sp1));\n\n                        __m512i iacc_mat_00_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_01_61_sp1));\n                        __m512i iacc_mat_01_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_01_61_sp1));\n\n                        __m512i iacc_mat_10_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, 
lhs_mat_23_61_sp1));\n                        __m512i iacc_mat_11_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_23_61_sp1));\n\n                        __m512i iacc_mat_00_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_01_71_sp1));\n                        __m512i iacc_mat_01_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_01_71_sp1));\n\n                        __m512i iacc_mat_10_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_23_71_sp1));\n                        __m512i iacc_mat_11_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_23_71_sp1));\n\n\n                        __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2));\n                        __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2));\n\n                        __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2));\n                        __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2));\n\n                        __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, 
lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2));\n                        __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2));\n\n                        __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2));\n                        __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2));\n\n                        __m512i iacc_mat_00_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_01_21_sp2));\n                        __m512i iacc_mat_01_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_01_21_sp2));\n\n                        __m512i iacc_mat_10_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_23_21_sp2));\n                        __m512i iacc_mat_11_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_23_21_sp2));\n\n                        __m512i iacc_mat_00_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_01_31_sp2));\n                        __m512i iacc_mat_01_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_01_31_sp2));\n\n                        __m512i iacc_mat_10_3_sp2 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_23_31_sp2));\n                        __m512i iacc_mat_11_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_23_31_sp2));\n\n                        __m512i iacc_mat_00_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_01_41_sp2));\n                        __m512i iacc_mat_01_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_01_41_sp2));\n\n                        __m512i iacc_mat_10_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_23_41_sp2));\n                        __m512i iacc_mat_11_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_23_41_sp2));\n\n                        __m512i iacc_mat_00_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_01_51_sp2));\n                        __m512i iacc_mat_01_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_01_51_sp2));\n\n                        __m512i iacc_mat_10_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_23_51_sp2));\n                        __m512i iacc_mat_11_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_23_51_sp2));\n\n                      
  __m512i iacc_mat_00_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_01_61_sp2));\n                        __m512i iacc_mat_01_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_01_61_sp2));\n\n                        __m512i iacc_mat_10_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_23_61_sp2));\n                        __m512i iacc_mat_11_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_23_61_sp2));\n\n                        __m512i iacc_mat_00_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_01_71_sp2));\n                        __m512i iacc_mat_01_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_01_71_sp2));\n\n                        __m512i iacc_mat_10_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_23_71_sp2));\n                        __m512i iacc_mat_11_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_23_71_sp2));\n\n                        // Combine results from both shuffle patterns for each output block\n                        __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);\n                        __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);\n                        __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, 
iacc_mat_10_0_sp2);\n                        __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);\n\n                        __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);\n                        __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);\n                        __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);\n                        __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);\n\n                        __m512i iacc_mat_00_2 = _mm512_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2);\n                        __m512i iacc_mat_01_2 = _mm512_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2);\n                        __m512i iacc_mat_10_2 = _mm512_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2);\n                        __m512i iacc_mat_11_2 = _mm512_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2);\n\n                        __m512i iacc_mat_00_3 = _mm512_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2);\n                        __m512i iacc_mat_01_3 = _mm512_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2);\n                        __m512i iacc_mat_10_3 = _mm512_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2);\n                        __m512i iacc_mat_11_3 = _mm512_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2);\n\n                        __m512i iacc_mat_00_4 = _mm512_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2);\n                        __m512i iacc_mat_01_4 = _mm512_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2);\n                        __m512i iacc_mat_10_4 = _mm512_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2);\n                        __m512i iacc_mat_11_4 = _mm512_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2);\n\n                        __m512i iacc_mat_00_5 = _mm512_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2);\n                        __m512i iacc_mat_01_5 = 
_mm512_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2);\n                        __m512i iacc_mat_10_5 = _mm512_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2);\n                        __m512i iacc_mat_11_5 = _mm512_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2);\n\n                        __m512i iacc_mat_00_6 = _mm512_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2);\n                        __m512i iacc_mat_01_6 = _mm512_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2);\n                        __m512i iacc_mat_10_6 = _mm512_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2);\n                        __m512i iacc_mat_11_6 = _mm512_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2);\n\n                        __m512i iacc_mat_00_7 = _mm512_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2);\n                        __m512i iacc_mat_01_7 = _mm512_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2);\n                        __m512i iacc_mat_10_7 = _mm512_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2);\n                        __m512i iacc_mat_11_7 = _mm512_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2);\n\n                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                        iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);\n                        iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);\n                        iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);\n                        iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);\n\n                        iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);\n                        iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);\n                        iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);\n                        iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);\n\n       
                 iacc_mat_00_2 = _mm512_madd_epi16(iacc_mat_00_2, scale_014589CD_2);\n                        iacc_mat_01_2 = _mm512_madd_epi16(iacc_mat_01_2, scale_2367ABEF_2);\n                        iacc_mat_10_2 = _mm512_madd_epi16(iacc_mat_10_2, scale_014589CD_2);\n                        iacc_mat_11_2 = _mm512_madd_epi16(iacc_mat_11_2, scale_2367ABEF_2);\n\n                        iacc_mat_00_3 = _mm512_madd_epi16(iacc_mat_00_3, scale_014589CD_3);\n                        iacc_mat_01_3 = _mm512_madd_epi16(iacc_mat_01_3, scale_2367ABEF_3);\n                        iacc_mat_10_3 = _mm512_madd_epi16(iacc_mat_10_3, scale_014589CD_3);\n                        iacc_mat_11_3 = _mm512_madd_epi16(iacc_mat_11_3, scale_2367ABEF_3);\n\n                        iacc_mat_00_4 = _mm512_madd_epi16(iacc_mat_00_4, scale_014589CD_4);\n                        iacc_mat_01_4 = _mm512_madd_epi16(iacc_mat_01_4, scale_2367ABEF_4);\n                        iacc_mat_10_4 = _mm512_madd_epi16(iacc_mat_10_4, scale_014589CD_4);\n                        iacc_mat_11_4 = _mm512_madd_epi16(iacc_mat_11_4, scale_2367ABEF_4);\n\n                        iacc_mat_00_5 = _mm512_madd_epi16(iacc_mat_00_5, scale_014589CD_5);\n                        iacc_mat_01_5 = _mm512_madd_epi16(iacc_mat_01_5, scale_2367ABEF_5);\n                        iacc_mat_10_5 = _mm512_madd_epi16(iacc_mat_10_5, scale_014589CD_5);\n                        iacc_mat_11_5 = _mm512_madd_epi16(iacc_mat_11_5, scale_2367ABEF_5);\n\n                        iacc_mat_00_6 = _mm512_madd_epi16(iacc_mat_00_6, scale_014589CD_6);\n                        iacc_mat_01_6 = _mm512_madd_epi16(iacc_mat_01_6, scale_2367ABEF_6);\n                        iacc_mat_10_6 = _mm512_madd_epi16(iacc_mat_10_6, scale_014589CD_6);\n                        iacc_mat_11_6 = _mm512_madd_epi16(iacc_mat_11_6, scale_2367ABEF_6);\n\n                        iacc_mat_00_7 = _mm512_madd_epi16(iacc_mat_00_7, scale_014589CD_7);\n                        iacc_mat_01_7 = 
_mm512_madd_epi16(iacc_mat_01_7, scale_2367ABEF_7);\n                        iacc_mat_10_7 = _mm512_madd_epi16(iacc_mat_10_7, scale_014589CD_7);\n                        iacc_mat_11_7 = _mm512_madd_epi16(iacc_mat_11_7, scale_2367ABEF_7);\n\n                        __m512i iacc_mat_00 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm512_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm512_add_epi32(iacc_mat_00_6, iacc_mat_00_7)));\n                        __m512i iacc_mat_01 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm512_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm512_add_epi32(iacc_mat_01_6, iacc_mat_01_7)));\n                        __m512i iacc_mat_10 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm512_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm512_add_epi32(iacc_mat_10_6, iacc_mat_10_7)));\n                        __m512i iacc_mat_11 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm512_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm512_add_epi32(iacc_mat_11_6, iacc_mat_11_7)));\n\n                        // Straighten out to make 4 row vectors\n                        __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));\n                        __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);\n                        __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));\n                        __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, 
_mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);\n\n                        // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes\n                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);\n                        const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);\n                        const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);\n\n                        // Multiply with appropriate scales and accumulate (for both d and dmin) below\n                        acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);\n                        acc_rows[rp * 4  + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);\n                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);\n                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);\n\n                        // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K\n                        __m512i iacc_row_min_0_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)0), mins_01);\n                        __m512i iacc_row_min_1_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)170), mins_01);\n                        __m512i iacc_row_min_2_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)0), mins_01);\n                        
__m512i iacc_row_min_3_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)170), mins_01);\n\n                        __m512i iacc_row_min_0_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)85), mins_23);\n                        __m512i iacc_row_min_1_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)255), mins_23);\n                        __m512i iacc_row_min_2_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)85), mins_23);\n                        __m512i iacc_row_min_3_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)255), mins_23);\n\n                        __m512i iacc_row_min_0_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)0), mins_45);\n                        __m512i iacc_row_min_1_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)170), mins_45);\n                        __m512i iacc_row_min_2_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)0), mins_45);\n                        __m512i iacc_row_min_3_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)170), mins_45);\n\n                        __m512i iacc_row_min_0_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)85), mins_67);\n                        __m512i iacc_row_min_1_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)255), mins_67);\n                        __m512i iacc_row_min_2_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)85), mins_67);\n                        __m512i iacc_row_min_3_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)255), mins_67);\n\n                        __m512i iacc_row_min_0 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), 
_mm512_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67));\n                        __m512i iacc_row_min_1 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm512_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67));\n                        __m512i iacc_row_min_2 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm512_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67));\n                        __m512i iacc_row_min_3 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm512_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67));\n\n                        acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);\n                        acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);\n                        acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);\n                        acc_min_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);\n                    }\n                }\n            }\n            // Store the accumulated values\n            for (int i = 0; i < 16; i++) {\n                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));\n            }\n        }\n    }\n\n    for (; y < nr / 4; y ++) {\n\n        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);\n\n        // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = 0; x < anc / 8; 
x += 2) {\n\n            const block_q2_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);\n            const block_q2_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);\n\n            // Master FP accumulators\n            __m512 acc_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_rows[i] = _mm512_setzero_ps();\n            }\n\n            __m512 acc_min_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_min_rows[i] = _mm512_setzero_ps();\n            }\n            // For super block\n            for (int64_t b = 0; b < nb; b++) {\n                // Delta values - Load the sixteen scale values from two block_q2_kx8 structures\n                const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);\n\n                // dmin values - Load the sixteen dmin values from two block_q2_kx8 structures\n                const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);\n\n                // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration\n                for (int sb = 0; sb < QK_K / 128; sb++) {\n\n                    // Load the eight block_q2_k for eight sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7\n                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));\n                    const __m256i 
rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));\n\n                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));\n\n                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, 
requiredOrder), rhs_raw_mat_4567_1, 240);\n                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);\n                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);\n\n                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);\n                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);\n                    const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);\n                    const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), 
rhs_raw_mat_CDEF_3, 240);\n\n                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);\n                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);\n\n                    const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);\n                    const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);\n                    const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);\n\n                    //2-bit -> 8-bit\n                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0,m3bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)\n                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0,m3bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)\n                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1,m3bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)\n                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1,m3bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) 
B0F(8-15)\n                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(rhs_raw_mat_014589CD_2,m3bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)\n                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2,m3bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)\n                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(rhs_raw_mat_014589CD_3,m3bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)\n                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3,m3bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)\n\n                    const __m512i rhs_mat_014589CD_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 2), m3bexpanded); //B20(0-7) B21(0-7) B24(0-7) B25(0-7) B28(0-7) B29(0-7) B2C(0-7) B2D(0-7)\n                    const __m512i rhs_mat_2367ABEF_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 2), m3bexpanded); //B22(0-7) B23(0-7) B26(0-7) B27(0-7) B2A(0-7) B2B(0-7) B2E(0-7) B2F(0-7)\n\n                    const __m512i rhs_mat_014589CD_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 2), m3bexpanded); //B20(8-15) B21(8-15) B24(8-15) B25(8-15) B28(8-15) B29(8-15) B2C(8-15) B2D(8-15)\n                    const __m512i rhs_mat_2367ABEF_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 2), m3bexpanded); //B22(8-15) B23(8-15) B26(8-15) B27(8-15) B2A(8-15) B2B(8-15) B2E(8-15) B2F(8-15)\n\n                    const __m512i rhs_mat_014589CD_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 2), m3bexpanded); //B30(0-7) B31(0-7) B34(0-7) B35(0-7) B38(0-7) B39(0-7) B3C(0-7) B3D(0-7)\n                    const __m512i rhs_mat_2367ABEF_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 2), m3bexpanded); //B32(0-7) B33(0-7) B36(0-7) B37(0-7) 
B3A(0-7) B3B(0-7) B3E(0-7) B3F(0-7)\n\n                    const __m512i rhs_mat_014589CD_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 2), m3bexpanded); //B30(8-15) B31(8-15) B34(8-15) B35(8-15) B38(8-15) B39(8-15) B3C(8-15) B3D(8-15)\n                    const __m512i rhs_mat_2367ABEF_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 2), m3bexpanded); //B32(8-15) B33(8-15) B36(8-15) B37(8-15) B3A(8-15) B3B(8-15) B3E(8-15) B3F(8-15)\n\n                    const __m512i rhs_mat_014589CD_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m3bexpanded); //B40(0-7) B41(0-7) B44(0-7) B45(0-7) B48(0-7) B49(0-7) B4C(0-7) B4D(0-7)\n                    const __m512i rhs_mat_2367ABEF_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m3bexpanded); //B42(0-7) B43(0-7) B46(0-7) B47(0-7) B4A(0-7) B4B(0-7) B4E(0-7) B4F(0-7)\n\n                    const __m512i rhs_mat_014589CD_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m3bexpanded); //B40(8-15) B41(8-15) B44(8-15) B45(8-15) B48(8-15) B49(8-15) B4C(8-15) B4D(8-15)\n                    const __m512i rhs_mat_2367ABEF_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m3bexpanded); //B42(8-15) B43(8-15) B46(8-15) B47(8-15) B4A(8-15) B4B(8-15) B4E(8-15) B4F(8-15)\n\n                    const __m512i rhs_mat_014589CD_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m3bexpanded); //B50(0-7) B51(0-7) B54(0-7) B55(0-7) B58(0-7) B59(0-7) B5C(0-7) B5D(0-7)\n                    const __m512i rhs_mat_2367ABEF_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m3bexpanded); //B52(0-7) B53(0-7) B56(0-7) B57(0-7) B5A(0-7) B5B(0-7) B5E(0-7) B5F(0-7)\n\n                    const __m512i rhs_mat_014589CD_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m3bexpanded); //B50(8-15) B51(8-15) B54(8-15) B55(8-15) B58(8-15) B59(8-15) B5C(8-15) B5D(8-15)\n                    const __m512i 
rhs_mat_2367ABEF_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m3bexpanded); //B52(8-15) B53(8-15) B56(8-15) B57(8-15) B5A(8-15) B5B(8-15) B5E(8-15) B5F(8-15)\n\n                    const __m512i rhs_mat_014589CD_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 6), m3bexpanded); //B60(0-7) B61(0-7) B64(0-7) B65(0-7) B68(0-7) B69(0-7) B6C(0-7) B6D(0-7)\n                    const __m512i rhs_mat_2367ABEF_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 6), m3bexpanded); //B62(0-7) B63(0-7) B66(0-7) B67(0-7) B6A(0-7) B6B(0-7) B6E(0-7) B6F(0-7)\n\n                    const __m512i rhs_mat_014589CD_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 6), m3bexpanded); //B60(8-15) B61(8-15) B64(8-15) B65(8-15) B68(8-15) B69(8-15) B6C(8-15) B6D(8-15)\n                    const __m512i rhs_mat_2367ABEF_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 6), m3bexpanded); //B62(8-15) B63(8-15) B66(8-15) B67(8-15) B6A(8-15) B6B(8-15) B6E(8-15) B6F(8-15)\n\n                    const __m512i rhs_mat_014589CD_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 6), m3bexpanded); //B70(0-7) B71(0-7) B74(0-7) B75(0-7) B78(0-7) B79(0-7) B7C(0-7) B7D(0-7)\n                    const __m512i rhs_mat_2367ABEF_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 6), m3bexpanded); //B72(0-7) B73(0-7) B76(0-7) B77(0-7) B7A(0-7) B7B(0-7) B7E(0-7) B7F(0-7)\n\n                    const __m512i rhs_mat_014589CD_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 6), m3bexpanded); //B70(8-15) B71(8-15) B74(8-15) B75(8-15) B78(8-15) B79(8-15) B7C(8-15) B7D(8-15)\n                    const __m512i rhs_mat_2367ABEF_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 6), m3bexpanded); //B72(8-15) B73(8-15) B76(8-15) B77(8-15) B7A(8-15) B7B(8-15) B7E(8-15) B7F(8-15)\n\n                    const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, 
(_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)\n                    const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)\n\n                    const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)\n                    const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)\n\n                    const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)\n                    const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)\n\n                    const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)\n                    const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) 
B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)\n\n                    const __m512i rhs_mat_014589CD_20_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3) B28(0-3) B29(0-3) B28(0-3) B29(0-3) B2C(0-3) B2D(0-3) B2C(0-3) B2D(0-3)\n                    const __m512i rhs_mat_2367ABEF_20_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3) B2A(0-3) B2B(0-3) B2A(0-3) B2B(0-3) B2E(0-3) B2F(0-3) B2E(0-3) B2F(0-3)\n\n                    const __m512i rhs_mat_014589CD_21_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11) B28(8-11) B29(8-11) B28(8-11) B29(8-11) B2C(8-11) B2D(8-11) B2C(8-11) B2D(8-11)\n                    const __m512i rhs_mat_2367ABEF_21_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11) B2A(8-11) B2B(8-11) B2A(8-11) B2B(8-11) B2E(8-11) B2F(8-11) B2E(8-11) B2F(8-11)\n                    const __m512i rhs_mat_014589CD_30_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)136); ///B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3) B38(0-3) B39(0-3) B38(0-3) B39(0-3) B3C(0-3) B3D(0-3) B3C(0-3) B3D(0-3)\n                    const __m512i rhs_mat_2367ABEF_30_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3) B3A(0-3) B3B(0-3) B3A(0-3) B3B(0-3) B3E(0-3) B3F(0-3) B3E(0-3) B3F(0-3)\n\n                    const __m512i rhs_mat_014589CD_31_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11) 
B38(8-11) B39(8-11) B38(8-11) B39(8-11) B3C(8-11) B3D(8-11) B3C(8-11) B3D(8-11)\n                    const __m512i rhs_mat_2367ABEF_31_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11) B3A(8-11) B3B(8-11) B3A(8-11) B3B(8-11) B3E(8-11) B3F(8-11) B3E(8-11) B3F(8-11)\n\n                    const __m512i rhs_mat_014589CD_40_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3) B48(0-3) B49(0-3) B48(0-3) B49(0-3) B4C(0-3) B4D(0-3) B4C(0-3) B4D(0-3)\n                    const __m512i rhs_mat_2367ABEF_40_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3) B4A(0-3) B4B(0-3) B4A(0-3) B4B(0-3) B4E(0-3) B4F(0-3) B4E(0-3) B4F(0-3)\n\n                    const __m512i rhs_mat_014589CD_41_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11) B48(8-11) B49(8-11) B48(8-11) B49(8-11) B4C(8-11) B4D(8-11) B4C(8-11) B4D(8-11)\n                    const __m512i rhs_mat_2367ABEF_41_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11) B4A(8-11) B4B(8-11) B4A(8-11) B4B(8-11) B4E(8-11) B4F(8-11) B4E(8-11) B4F(8-11)\n\n                    const __m512i rhs_mat_014589CD_50_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3) B58(0-3) B59(0-3) B58(0-3) B59(0-3) B5C(0-3) B5D(0-3) B5C(0-3) B5D(0-3)\n                    const __m512i rhs_mat_2367ABEF_50_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3) B5A(0-3) B5B(0-3) B5A(0-3) B5B(0-3) B5E(0-3) 
B5F(0-3) B5E(0-3) B5F(0-3)\n\n                    const __m512i rhs_mat_014589CD_51_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11) B58(8-11) B59(8-11) B58(8-11) B59(8-11) B5C(8-11) B5D(8-11) B5C(8-11) B5D(8-11)\n                    const __m512i rhs_mat_2367ABEF_51_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11) B5A(8-11) B5B(8-11) B5A(8-11) B5B(8-11) B5E(8-11) B5F(8-11) B5E(8-11) B5F(8-11)\n\n                    const __m512i rhs_mat_014589CD_60_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3) B68(0-3) B69(0-3) B68(0-3) B69(0-3) B6C(0-3) B6D(0-3) B6C(0-3) B6D(0-3)\n                    const __m512i rhs_mat_2367ABEF_60_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3) B6A(0-3) B6B(0-3) B6A(0-3) B6B(0-3) B6E(0-3) B6F(0-3) B6E(0-3) B6F(0-3)\n\n                    const __m512i rhs_mat_014589CD_61_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11) B68(8-11) B69(8-11) B68(8-11) B69(8-11) B6C(8-11) B6D(8-11) B6C(8-11) B6D(8-11)\n                    const __m512i rhs_mat_2367ABEF_61_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11) B6A(8-11) B6B(8-11) B6A(8-11) B6B(8-11) B6E(8-11) B6F(8-11) B6E(8-11) B6F(8-11)\n\n                    const __m512i rhs_mat_014589CD_70_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3) B78(0-3) B79(0-3) B78(0-3) B79(0-3) B7C(0-3) B7D(0-3) B7C(0-3) B7D(0-3)\n          
          const __m512i rhs_mat_2367ABEF_70_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3) B7A(0-3) B7B(0-3) B7A(0-3) B7B(0-3) B7E(0-3) B7F(0-3) B7E(0-3) B7F(0-3)\n\n                    const __m512i rhs_mat_014589CD_71_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)\n                    const __m512i rhs_mat_2367ABEF_71_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11) B7A(8-11) B7B(8-11) B7A(8-11) B7B(8-11) B7E(8-11) B7F(8-11) B7E(8-11) B7F(8-11)\n\n                    const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)\n                    const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)\n\n                    const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)\n                    const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)\n\n                
    const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)\n                    const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)\n\n                    const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)\n                    const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)\n\n                    const __m512i rhs_mat_014589CD_20_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7) B28(4-7) B29(4-7) B28(4-7) B29(4-7) B2C(4-7) B2D(4-7) B2C(4-7) B2D(4-7)\n                    const __m512i rhs_mat_2367ABEF_20_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7) B2A(4-7) B2B(4-7) B2A(4-7) B2B(4-7) B2E(4-7) B2F(4-7) B2E(4-7) B2F(4-7)\n\n                    const __m512i rhs_mat_014589CD_21_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15) B28(12-15) B29(12-15) B28(12-15) B29(12-15) B2C(12-15) B2D(12-15) B2C(12-15) B2D(12-15)\n                    
const __m512i rhs_mat_2367ABEF_21_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15) B2A(12-15) B2B(12-15) B2A(12-15) B2B(12-15) B2E(12-15) B2F(12-15) B2E(12-15) B2F(12-15)\n\n                    const __m512i rhs_mat_014589CD_30_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7) B38(4-7) B39(4-7) B38(4-7) B39(4-7) B3C(4-7) B3D(4-7) B3C(4-7) B3D(4-7)\n                    const __m512i rhs_mat_2367ABEF_30_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7) B3A(4-7) B3B(4-7) B3A(4-7) B3B(4-7) B3E(4-7) B3F(4-7) B3E(4-7) B3F(4-7)\n\n                    const __m512i rhs_mat_014589CD_31_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15) B38(12-15) B39(12-15) B38(12-15) B39(12-15) B3C(12-15) B3D(12-15) B3C(12-15) B3D(12-15)\n                    const __m512i rhs_mat_2367ABEF_31_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15) B3A(12-15) B3B(12-15) B3A(12-15) B3B(12-15) B3E(12-15) B3F(12-15) B3E(12-15) B3F(12-15)\n\n                    const __m512i rhs_mat_014589CD_40_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7) B48(4-7) B49(4-7) B48(4-7) B49(4-7) B4C(4-7) B4D(4-7) B4C(4-7) B4D(4-7)\n                    const __m512i rhs_mat_2367ABEF_40_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7) B4A(4-7) B4B(4-7) B4A(4-7) B4B(4-7) B4E(4-7) B4F(4-7) B4E(4-7) B4F(4-7)\n\n                    const 
__m512i rhs_mat_014589CD_41_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15) B48(12-15) B49(12-15) B48(12-15) B49(12-15) B4C(12-15) B4D(12-15) B4C(12-15) B4D(12-15)\n                    const __m512i rhs_mat_2367ABEF_41_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15) B4A(12-15) B4B(12-15) B4A(12-15) B4B(12-15) B4E(12-15) B4F(12-15) B4E(12-15) B4F(12-15)\n\n                    const __m512i rhs_mat_014589CD_50_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7) B58(4-7) B59(4-7) B58(4-7) B59(4-7) B5C(4-7) B5D(4-7) B5C(4-7) B5D(4-7)\n                    const __m512i rhs_mat_2367ABEF_50_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7) B5A(4-7) B5B(4-7) B5A(4-7) B5B(4-7) B5E(4-7) B5F(4-7) B5E(4-7) B5F(4-7)\n\n                    const __m512i rhs_mat_014589CD_51_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15) B58(12-15) B59(12-15) B58(12-15) B59(12-15) B5C(12-15) B5D(12-15) B5C(12-15) B5D(12-15)\n                    const __m512i rhs_mat_2367ABEF_51_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15) B5A(12-15) B5B(12-15) B5A(12-15) B5B(12-15) B5E(12-15) B5F(12-15) B5E(12-15) B5F(12-15)\n\n                    const __m512i rhs_mat_014589CD_60_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7) B68(4-7) B69(4-7) B68(4-7) B69(4-7) B6C(4-7) B6D(4-7) B6C(4-7) B6D(4-7)\n  
                  const __m512i rhs_mat_2367ABEF_60_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7) B6A(4-7) B6B(4-7) B6A(4-7) B6B(4-7) B6E(4-7) B6F(4-7) B6E(4-7) B6F(4-7)\n\n                    const __m512i rhs_mat_014589CD_61_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15) B68(12-15) B69(12-15) B68(12-15) B69(12-15) B6C(12-15) B6D(12-15) B6C(12-15) B6D(12-15)\n                    const __m512i rhs_mat_2367ABEF_61_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15) B6A(12-15) B6B(12-15) B6A(12-15) B6B(12-15) B6E(12-15) B6F(12-15) B6E(12-15) B6F(12-15)\n\n                    const __m512i rhs_mat_014589CD_70_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7) B78(4-7) B79(4-7) B78(4-7) B79(4-7) B7C(4-7) B7D(4-7) B7C(4-7) B7D(4-7)\n                    const __m512i rhs_mat_2367ABEF_70_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7) B7A(4-7) B7B(4-7) B7A(4-7) B7B(4-7) B7E(4-7) B7F(4-7) B7E(4-7) B7F(4-7)\n\n                    const __m512i rhs_mat_014589CD_71_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15) B78(12-15) B79(12-15) B78(12-15) B79(12-15) B7C(12-15) B7D(12-15) B7C(12-15) B7D(12-15)\n                    const __m512i rhs_mat_2367ABEF_71_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15) B7A(12-15) B7B(12-15) B7A(12-15) B7B(12-15) B7E(12-15) B7F(12-15) 
B7E(12-15) B7F(12-15)\n\n                    //notation:superblock subblock\n                    //s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71\n\n                    const __m128i mins_and_scales_01_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + sb * 64));\n                    const __m128i mins_and_scales_23_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 16 + sb * 64));\n                    const __m128i mins_and_scales_45_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 32 + sb * 64));\n                    const __m128i mins_and_scales_67_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 48 + sb * 64));\n\n                    const __m128i mins_and_scales_01_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + sb * 64));\n                    const __m128i mins_and_scales_23_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 16 + sb * 64));\n                    const __m128i mins_and_scales_45_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 32 + sb * 64));\n                    const __m128i mins_and_scales_67_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 48 + sb * 64));\n\n                    // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop\n                    const __m256i mins_and_scales_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_01_0), mins_and_scales_01_1, 1);\n                    const __m256i mins_and_scales_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_23_0), mins_and_scales_23_1, 1);\n                    const __m256i mins_and_scales_45 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_45_0), mins_and_scales_45_1, 1);\n                    const __m256i mins_and_scales_67 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_67_0), mins_and_scales_67_1, 1);\n\n                    // 
Extract scales which is lower half from mins_and_scales\n                    const __m256i scales_01 = _mm256_and_si256(mins_and_scales_01, m4b);\n                    const __m256i scales_23 = _mm256_and_si256(mins_and_scales_23, m4b);\n                    const __m256i scales_45 = _mm256_and_si256(mins_and_scales_45, m4b);\n                    const __m256i scales_67 = _mm256_and_si256(mins_and_scales_67, m4b);\n\n                    // Extract mins which is upper half from mins_and_scales\n                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_01, 4), m4b));\n                    const __m512i mins_23 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_23, 4), m4b));\n                    const __m512i mins_45 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_45, 4), m4b));\n                    const __m512i mins_67 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_67, 4), m4b));\n\n                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01, scalesmask1));\n                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01, scalesmask2));\n                    const __m512i scales_2 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23, scalesmask1));\n                    const __m512i scales_3 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23, scalesmask2));\n                    const __m512i scales_4 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45, scalesmask1));\n                    const __m512i scales_5 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45, scalesmask2));\n                    const __m512i scales_6 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67, scalesmask1));\n                    const __m512i scales_7 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67, scalesmask2));\n\n                    const __m512i scale_014589CD_0 = 
_mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)238);\n\n                    const __m512i scale_014589CD_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)68);\n                    const __m512i scale_2367ABEF_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)238);\n\n                    // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3\n                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector\n                    __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 512 * sb)));\n       
             __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);\n                    __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);\n                    __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);\n                    __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);\n                    __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);\n                    __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);\n                    __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);\n                    __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);\n                    __m256i lhs_mat_ymm_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 0);\n                    __m256i lhs_mat_ymm_23_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 17);\n                    __m256i lhs_mat_ymm_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 0);\n                    __m256i lhs_mat_ymm_23_21 = 
_mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 17);\n                    __m256i lhs_mat_ymm_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 0);\n                    __m256i lhs_mat_ymm_23_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 17);\n                    __m256i lhs_mat_ymm_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 0);\n                    __m256i lhs_mat_ymm_23_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 17);\n\n                    __m256i lhs_mat_ymm_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 0);\n                    __m256i lhs_mat_ymm_23_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 17);\n                    __m256i lhs_mat_ymm_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 288 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 0);\n                    __m256i lhs_mat_ymm_23_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 17);\n                    __m256i lhs_mat_ymm_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 320 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 0);\n                    __m256i lhs_mat_ymm_23_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 17);\n                    __m256i lhs_mat_ymm_0123_51 = _mm256_loadu_si256((const __m256i * 
)((a_ptr[b].qs + 352 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 0);\n                    __m256i lhs_mat_ymm_23_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 17);\n                    __m256i lhs_mat_ymm_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 384 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 0);\n                    __m256i lhs_mat_ymm_23_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 17);\n                    __m256i lhs_mat_ymm_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 416 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 0);\n                    __m256i lhs_mat_ymm_23_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 17);\n                    __m256i lhs_mat_ymm_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 448 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 0);\n                    __m256i lhs_mat_ymm_23_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 17);\n                    __m256i lhs_mat_ymm_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 480 + 512 * sb)));\n                    __m256i lhs_mat_ymm_01_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 0);\n                    __m256i lhs_mat_ymm_23_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 17);\n\n                    __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);\n                    __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), 
lhs_mat_ymm_23_00, 1);\n                    __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);\n                    __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);\n\n                    __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);\n                    __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);\n                    __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);\n                    __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);\n\n                    __m512i lhs_mat_01_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_20), lhs_mat_ymm_01_20, 1);\n                    __m512i lhs_mat_23_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_20), lhs_mat_ymm_23_20, 1);\n                    __m512i lhs_mat_01_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_21), lhs_mat_ymm_01_21, 1);\n                    __m512i lhs_mat_23_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_21), lhs_mat_ymm_23_21, 1);\n\n                    __m512i lhs_mat_01_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_30), lhs_mat_ymm_01_30, 1);\n                    __m512i lhs_mat_23_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_30), lhs_mat_ymm_23_30, 1);\n                    __m512i lhs_mat_01_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_31), lhs_mat_ymm_01_31, 1);\n                    __m512i lhs_mat_23_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_31), lhs_mat_ymm_23_31, 1);\n\n                    __m512i lhs_mat_01_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_40), lhs_mat_ymm_01_40, 1);\n                  
  __m512i lhs_mat_23_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_40), lhs_mat_ymm_23_40, 1);\n                    __m512i lhs_mat_01_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_41), lhs_mat_ymm_01_41, 1);\n                    __m512i lhs_mat_23_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_41), lhs_mat_ymm_23_41, 1);\n\n                    __m512i lhs_mat_01_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_50), lhs_mat_ymm_01_50, 1);\n                    __m512i lhs_mat_23_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_50), lhs_mat_ymm_23_50, 1);\n                    __m512i lhs_mat_01_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_51), lhs_mat_ymm_01_51, 1);\n                    __m512i lhs_mat_23_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_51), lhs_mat_ymm_23_51, 1);\n\n                    __m512i lhs_mat_01_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_60), lhs_mat_ymm_01_60, 1);\n                    __m512i lhs_mat_23_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_60), lhs_mat_ymm_23_60, 1);\n                    __m512i lhs_mat_01_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_61), lhs_mat_ymm_01_61, 1);\n                    __m512i lhs_mat_23_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_61), lhs_mat_ymm_23_61, 1);\n\n                    __m512i lhs_mat_01_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_70), lhs_mat_ymm_01_70, 1);\n                    __m512i lhs_mat_23_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_70), lhs_mat_ymm_23_70, 1);\n                    __m512i lhs_mat_01_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_71), lhs_mat_ymm_01_71, 1);\n                    __m512i lhs_mat_23_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_71), lhs_mat_ymm_23_71, 1);\n\n                    // Bsums are loaded for the different 
Q8_K blocks\n                    __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 32 * sb)));\n                    __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 8 + 32 * sb));\n                    __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 16 + 32 * sb)));\n                    __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 24 + 32 * sb));\n\n                    __m256i lhs_bsums_ymm_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1);\n                    __m512i lhs_bsums_01_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_0123), lhs_bsums_ymm_01_0123, 1);\n                    __m256i lhs_bsums_ymm_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1);\n                    __m512i lhs_bsums_23_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_0123), lhs_bsums_ymm_23_0123, 1);\n                    __m256i lhs_bsums_ymm_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1);\n                    __m512i lhs_bsums_01_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_4567), lhs_bsums_ymm_01_4567, 1);\n                    __m256i lhs_bsums_ymm_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1);\n                    __m512i lhs_bsums_23_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_4567), lhs_bsums_ymm_23_4567, 1);\n\n                    // Shuffle pattern one - left side input\n                    const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)\n             
       const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)\n\n                    const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)\n                    const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)\n\n                    const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)\n                    const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)\n\n                    const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)\n                    const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)\n\n                    const __m512i lhs_mat_01_20_sp1 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)160); //A20(0-3) A20(0-3) 
A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3)\n                    const __m512i lhs_mat_23_20_sp1 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)160); //A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3)\n\n                    const __m512i lhs_mat_01_21_sp1 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11)\n                    const __m512i lhs_mat_23_21_sp1 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)160); //A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11)\n\n                    const __m512i lhs_mat_01_30_sp1 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3)\n                    const __m512i lhs_mat_23_30_sp1 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)160); //A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3)\n\n                    const __m512i lhs_mat_01_31_sp1 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11)\n                    const __m512i lhs_mat_23_31_sp1 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)160); //A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) 
A33(8-11)\n\n                    const __m512i lhs_mat_01_40_sp1 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3)\n                    const __m512i lhs_mat_23_40_sp1 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)160); //A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3)\n\n                    const __m512i lhs_mat_01_41_sp1 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11)\n                    const __m512i lhs_mat_23_41_sp1 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)160); //A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11)\n\n                    const __m512i lhs_mat_01_50_sp1 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3)\n                    const __m512i lhs_mat_23_50_sp1 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)160); //A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3)\n\n                    const __m512i lhs_mat_01_51_sp1 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11)\n                    const __m512i lhs_mat_23_51_sp1 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)160); //A52(8-11) 
A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11)\n\n                    const __m512i lhs_mat_01_60_sp1 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3)\n                    const __m512i lhs_mat_23_60_sp1 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)160); //A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3)\n\n                    const __m512i lhs_mat_01_61_sp1 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11)\n                    const __m512i lhs_mat_23_61_sp1 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)160); //A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11)\n\n                    const __m512i lhs_mat_01_70_sp1 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3)\n                    const __m512i lhs_mat_23_70_sp1 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)160); //A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3)\n\n                    const __m512i lhs_mat_01_71_sp1 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) 
A71(8-11) A71(8-11)\n                    const __m512i lhs_mat_23_71_sp1 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)160); //A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11)\n\n                    const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)\n                    const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)\n\n                    const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)\n                    const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)\n\n                    const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)\n                    const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)\n\n                    const __m512i lhs_mat_01_11_sp2 = 
_mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)\n                    const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)\n\n                    const __m512i lhs_mat_01_20_sp2 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7)\n                    const __m512i lhs_mat_23_20_sp2 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)245); //A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7)\n\n                    const __m512i lhs_mat_01_21_sp2 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15)\n                    const __m512i lhs_mat_23_21_sp2 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)245); //A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15)\n\n                    const __m512i lhs_mat_01_30_sp2 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7)\n                    const __m512i lhs_mat_23_30_sp2 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)245); 
//A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7)\n\n                    const __m512i lhs_mat_01_31_sp2 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15)\n                    const __m512i lhs_mat_23_31_sp2 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)245); //A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15)\n\n                    const __m512i lhs_mat_01_40_sp2 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7)\n                    const __m512i lhs_mat_23_40_sp2 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)245); //A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7)\n\n                    const __m512i lhs_mat_01_41_sp2 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15)\n                    const __m512i lhs_mat_23_41_sp2 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)245); //A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15)\n\n                    const __m512i lhs_mat_01_50_sp2 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) 
A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7)\n                    const __m512i lhs_mat_23_50_sp2 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)245); //A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7)\n\n                    const __m512i lhs_mat_01_51_sp2 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15)\n                    const __m512i lhs_mat_23_51_sp2 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)245); //A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15)\n\n                    const __m512i lhs_mat_01_60_sp2 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7)\n                    const __m512i lhs_mat_23_60_sp2 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)245); //A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7)\n\n                    const __m512i lhs_mat_01_61_sp2 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15)\n                    const __m512i lhs_mat_23_61_sp2 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)245); //A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) 
A62(12-15) A63(12-15) A63(12-15)\n\n                    const __m512i lhs_mat_01_70_sp2 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7)\n                    const __m512i lhs_mat_23_70_sp2 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)245); //A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7)\n\n                    const __m512i lhs_mat_01_71_sp2 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15)\n                    const __m512i lhs_mat_23_71_sp2 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)245); //A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15)\n\n                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                    __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1));\n                    __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1));\n\n                    __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1));\n                    __m512i iacc_mat_11_0_sp1 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1));\n\n                    __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1));\n                    __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1));\n\n                    __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1));\n                    __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1));\n\n                    __m512i iacc_mat_00_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_01_21_sp1));\n                    __m512i iacc_mat_01_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_01_21_sp1));\n\n                    __m512i iacc_mat_10_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_23_21_sp1));\n                    __m512i iacc_mat_11_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_23_21_sp1));\n\n                    __m512i iacc_mat_00_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_01_31_sp1));\n                    __m512i iacc_mat_01_3_sp1 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_01_31_sp1));\n\n                    __m512i iacc_mat_10_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_23_31_sp1));\n                    __m512i iacc_mat_11_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_23_31_sp1));\n\n                    __m512i iacc_mat_00_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_01_41_sp1));\n                    __m512i iacc_mat_01_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_01_41_sp1));\n\n                    __m512i iacc_mat_10_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_23_41_sp1));\n                    __m512i iacc_mat_11_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_23_41_sp1));\n\n                    __m512i iacc_mat_00_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_01_51_sp1));\n                    __m512i iacc_mat_01_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_01_51_sp1));\n\n                    __m512i iacc_mat_10_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_23_51_sp1));\n                    __m512i iacc_mat_11_5_sp1 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_23_51_sp1));\n\n                    __m512i iacc_mat_00_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_01_61_sp1));\n                    __m512i iacc_mat_01_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_01_61_sp1));\n\n                    __m512i iacc_mat_10_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_23_61_sp1));\n                    __m512i iacc_mat_11_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_23_61_sp1));\n\n                    __m512i iacc_mat_00_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_01_71_sp1));\n                    __m512i iacc_mat_01_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_01_71_sp1));\n\n                    __m512i iacc_mat_10_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_23_71_sp1));\n                    __m512i iacc_mat_11_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_23_71_sp1));\n\n\n                    __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2));\n                    __m512i iacc_mat_01_0_sp2 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2));\n\n                    __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2));\n                    __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2));\n\n                    __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2));\n                    __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2));\n\n                    __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2));\n                    __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2));\n\n                    __m512i iacc_mat_00_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_01_21_sp2));\n                    __m512i iacc_mat_01_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_01_21_sp2));\n\n                    __m512i iacc_mat_10_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_23_21_sp2));\n                    __m512i iacc_mat_11_2_sp2 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_23_21_sp2));\n\n                    __m512i iacc_mat_00_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_01_31_sp2));\n                    __m512i iacc_mat_01_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_01_31_sp2));\n\n                    __m512i iacc_mat_10_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_23_31_sp2));\n                    __m512i iacc_mat_11_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_23_31_sp2));\n\n                    __m512i iacc_mat_00_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_01_41_sp2));\n                    __m512i iacc_mat_01_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_01_41_sp2));\n\n                    __m512i iacc_mat_10_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_23_41_sp2));\n                    __m512i iacc_mat_11_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_23_41_sp2));\n\n                    __m512i iacc_mat_00_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_01_51_sp2));\n                    __m512i iacc_mat_01_5_sp2 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_01_51_sp2));\n\n                    __m512i iacc_mat_10_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_23_51_sp2));\n                    __m512i iacc_mat_11_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_23_51_sp2));\n\n                    __m512i iacc_mat_00_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_01_61_sp2));\n                    __m512i iacc_mat_01_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_01_61_sp2));\n\n                    __m512i iacc_mat_10_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_23_61_sp2));\n                    __m512i iacc_mat_11_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_23_61_sp2));\n\n                    __m512i iacc_mat_00_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_01_71_sp2));\n                    __m512i iacc_mat_01_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_01_71_sp2));\n\n                    __m512i iacc_mat_10_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_23_71_sp2));\n                    __m512i iacc_mat_11_7_sp2 = 
_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_23_71_sp2));\n\n                    // Combine results from both shuffle patterns for each output block\n                    __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);\n                    __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);\n                    __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);\n                    __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);\n\n                    __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);\n                    __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);\n                    __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);\n                    __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);\n\n                    __m512i iacc_mat_00_2 = _mm512_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2);\n                    __m512i iacc_mat_01_2 = _mm512_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2);\n                    __m512i iacc_mat_10_2 = _mm512_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2);\n                    __m512i iacc_mat_11_2 = _mm512_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2);\n\n                    __m512i iacc_mat_00_3 = _mm512_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2);\n                    __m512i iacc_mat_01_3 = _mm512_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2);\n                    __m512i iacc_mat_10_3 = _mm512_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2);\n                    __m512i iacc_mat_11_3 = _mm512_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2);\n\n                    __m512i iacc_mat_00_4 = _mm512_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2);\n                    __m512i 
iacc_mat_01_4 = _mm512_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2);\n                    __m512i iacc_mat_10_4 = _mm512_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2);\n                    __m512i iacc_mat_11_4 = _mm512_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2);\n\n                    __m512i iacc_mat_00_5 = _mm512_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2);\n                    __m512i iacc_mat_01_5 = _mm512_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2);\n                    __m512i iacc_mat_10_5 = _mm512_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2);\n                    __m512i iacc_mat_11_5 = _mm512_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2);\n\n                    __m512i iacc_mat_00_6 = _mm512_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2);\n                    __m512i iacc_mat_01_6 = _mm512_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2);\n                    __m512i iacc_mat_10_6 = _mm512_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2);\n                    __m512i iacc_mat_11_6 = _mm512_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2);\n\n                    __m512i iacc_mat_00_7 = _mm512_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2);\n                    __m512i iacc_mat_01_7 = _mm512_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2);\n                    __m512i iacc_mat_10_7 = _mm512_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2);\n                    __m512i iacc_mat_11_7 = _mm512_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2);\n\n                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                    iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);\n                    iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);\n                    iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);\n                    iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);\n\n                   
 iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);\n                    iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);\n                    iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);\n                    iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);\n\n                    iacc_mat_00_2 = _mm512_madd_epi16(iacc_mat_00_2, scale_014589CD_2);\n                    iacc_mat_01_2 = _mm512_madd_epi16(iacc_mat_01_2, scale_2367ABEF_2);\n                    iacc_mat_10_2 = _mm512_madd_epi16(iacc_mat_10_2, scale_014589CD_2);\n                    iacc_mat_11_2 = _mm512_madd_epi16(iacc_mat_11_2, scale_2367ABEF_2);\n\n                    iacc_mat_00_3 = _mm512_madd_epi16(iacc_mat_00_3, scale_014589CD_3);\n                    iacc_mat_01_3 = _mm512_madd_epi16(iacc_mat_01_3, scale_2367ABEF_3);\n                    iacc_mat_10_3 = _mm512_madd_epi16(iacc_mat_10_3, scale_014589CD_3);\n                    iacc_mat_11_3 = _mm512_madd_epi16(iacc_mat_11_3, scale_2367ABEF_3);\n\n                    iacc_mat_00_4 = _mm512_madd_epi16(iacc_mat_00_4, scale_014589CD_4);\n                    iacc_mat_01_4 = _mm512_madd_epi16(iacc_mat_01_4, scale_2367ABEF_4);\n                    iacc_mat_10_4 = _mm512_madd_epi16(iacc_mat_10_4, scale_014589CD_4);\n                    iacc_mat_11_4 = _mm512_madd_epi16(iacc_mat_11_4, scale_2367ABEF_4);\n\n                    iacc_mat_00_5 = _mm512_madd_epi16(iacc_mat_00_5, scale_014589CD_5);\n                    iacc_mat_01_5 = _mm512_madd_epi16(iacc_mat_01_5, scale_2367ABEF_5);\n                    iacc_mat_10_5 = _mm512_madd_epi16(iacc_mat_10_5, scale_014589CD_5);\n                    iacc_mat_11_5 = _mm512_madd_epi16(iacc_mat_11_5, scale_2367ABEF_5);\n\n                    iacc_mat_00_6 = _mm512_madd_epi16(iacc_mat_00_6, scale_014589CD_6);\n                    iacc_mat_01_6 = _mm512_madd_epi16(iacc_mat_01_6, scale_2367ABEF_6);\n                    iacc_mat_10_6 = 
_mm512_madd_epi16(iacc_mat_10_6, scale_014589CD_6);\n                    iacc_mat_11_6 = _mm512_madd_epi16(iacc_mat_11_6, scale_2367ABEF_6);\n\n                    iacc_mat_00_7 = _mm512_madd_epi16(iacc_mat_00_7, scale_014589CD_7);\n                    iacc_mat_01_7 = _mm512_madd_epi16(iacc_mat_01_7, scale_2367ABEF_7);\n                    iacc_mat_10_7 = _mm512_madd_epi16(iacc_mat_10_7, scale_014589CD_7);\n                    iacc_mat_11_7 = _mm512_madd_epi16(iacc_mat_11_7, scale_2367ABEF_7);\n\n                    __m512i iacc_mat_00 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm512_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm512_add_epi32(iacc_mat_00_6, iacc_mat_00_7)));\n                    __m512i iacc_mat_01 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm512_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm512_add_epi32(iacc_mat_01_6, iacc_mat_01_7)));\n                    __m512i iacc_mat_10 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm512_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm512_add_epi32(iacc_mat_10_6, iacc_mat_10_7)));\n                    __m512i iacc_mat_11 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm512_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm512_add_epi32(iacc_mat_11_6, iacc_mat_11_7)));\n\n                    // Straighten out to make 4 row vectors\n                    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));\n                    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), 
iacc_mat_01);\n                    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));\n                    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);\n\n                    // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes\n                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);\n                    const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);\n                    const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);\n\n                    // Multiply with appropiate scales and accumulate (for both d and dmin) below\n                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);\n                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);\n                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);\n                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);\n\n                    // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K\n                    __m512i iacc_row_min_0_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)0), mins_01);\n                    __m512i iacc_row_min_1_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)170), mins_01);\n                    __m512i iacc_row_min_2_01 = 
_mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)0), mins_01);\n                    __m512i iacc_row_min_3_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)170), mins_01);\n\n                    __m512i iacc_row_min_0_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)85), mins_23);\n                    __m512i iacc_row_min_1_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)255), mins_23);\n                    __m512i iacc_row_min_2_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)85), mins_23);\n                    __m512i iacc_row_min_3_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)255), mins_23);\n\n                    __m512i iacc_row_min_0_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)0), mins_45);\n                    __m512i iacc_row_min_1_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)170), mins_45);\n                    __m512i iacc_row_min_2_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)0), mins_45);\n                    __m512i iacc_row_min_3_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)170), mins_45);\n\n                    __m512i iacc_row_min_0_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)85), mins_67);\n                    __m512i iacc_row_min_1_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)255), mins_67);\n                    __m512i iacc_row_min_2_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)85), mins_67);\n                    __m512i iacc_row_min_3_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)255), mins_67);\n\n                    __m512i iacc_row_min_0 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), 
_mm512_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67));\n                    __m512i iacc_row_min_1 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm512_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67));\n                    __m512i iacc_row_min_2 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm512_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67));\n                    __m512i iacc_row_min_3 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm512_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67));\n\n                    acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);\n                    acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);\n                    acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);\n                    acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);\n                }\n            }\n            // Store accumulated values\n            for (int i = 0; i < 4; i++) {\n                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));\n            }\n        }\n    }\n\n    if (anc != nc) {\n        xstart = anc/8;\n        y = 0;\n    }\n\n#endif // __AVX512BW__ && __AVX512DQ__\n\n    // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation\n    for (; y < anr / 4; y += 4) {\n\n        const block_q8_Kx4 * a_ptrs[4];\n\n        a_ptrs[0] = a_ptr_start + (y * nb);\n        for (int i = 0; 
i < 3; ++i) {\n            a_ptrs[i + 1] = a_ptrs[i] + nb;\n        }\n\n        // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = xstart; x < nc / 8; x++) {\n\n            const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulators\n            __m256 acc_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_rows[i] = _mm256_setzero_ps();\n            }\n\n            __m256 acc_min_rows[16];\n            for (int i = 0; i < 16; i++) {\n                acc_min_rows[i] = _mm256_setzero_ps();\n            }\n\n            // For super block\n            for (int64_t b = 0; b < nb; b++) {\n                // Delta values - Load the eight scale values of block_q2_kx8\n                const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);\n\n                // dmin values - Load the eight dmin values of block_q2_kx8\n                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);\n\n                // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration\n                for (int sb = 0; sb < QK_K / 128; sb++) {\n\n                    // Load the eight block_q2_K for eight sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7\n                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 128 + 
sb * 256));\n                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 224 + sb * 256));\n\n                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values\n                    //superblock    sub block   which part of sub block\n                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n\n                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n\n                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);\n\n                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);\n\n                    // 2-bit -> 8-bit\n                    // 
First sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m3b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)\n                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m3b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)\n\n                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m3b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)\n                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m3b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)\n\n                    // Second sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(rhs_raw_mat_0145_2, m3b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)\n                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(rhs_raw_mat_2367_2, m3b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)\n\n                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(rhs_raw_mat_0145_3, m3b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)\n                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(rhs_raw_mat_2367_3, m3b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)\n\n                    // Third sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 2), m3b); //B20(0-7) B21(0-7) B24(0-7) B25(0-7)\n                    const __m256i rhs_mat_2367_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 2), m3b); //B22(0-7) B23(0-7) B26(0-7) B27(0-7)\n\n                    const __m256i rhs_mat_0145_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 2), m3b); //B20(8-15) B21(8-15) B24(8-15) B25(8-15)\n                    const __m256i rhs_mat_2367_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 2), m3b); //B22(8-15) B23(8-15) B26(8-15) B27(8-15)\n\n                 
   // Fourth sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 2), m3b); //B30(0-7) B31(0-7) B34(0-7) B35(0-7)\n                    const __m256i rhs_mat_2367_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 2), m3b); //B32(0-7) B33(0-7) B36(0-7) B37(0-7)\n\n                    const __m256i rhs_mat_0145_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 2), m3b); //B30(8-15) B31(8-15) B34(8-15) B35(8-15)\n                    const __m256i rhs_mat_2367_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 2), m3b); //B32(8-15) B33(8-15) B36(8-15) B37(8-15)\n\n                    // Fifth sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m3b); //B40(0-7) B41(0-7) B44(0-7) B45(0-7)\n                    const __m256i rhs_mat_2367_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m3b); //B42(0-7) B43(0-7) B46(0-7) B47(0-7)\n\n                    const __m256i rhs_mat_0145_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m3b); //B40(8-15) B41(8-15) B44(8-15) B45(8-15)\n                    const __m256i rhs_mat_2367_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m3b); //B42(8-15) B43(8-15) B46(8-15) B47(8-15)\n\n                    // Sixth sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m3b); //B50(0-7) B51(0-7) B54(0-7) B55(0-7)\n                    const __m256i rhs_mat_2367_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m3b); //B52(0-7) B53(0-7) B56(0-7) B57(0-7)\n\n                    const __m256i rhs_mat_0145_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m3b); //B50(8-15) B51(8-15) B54(8-15) 
B55(8-15)\n                    const __m256i rhs_mat_2367_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m3b); //B52(8-15) B53(8-15) B56(8-15) B57(8-15)\n\n                    // Seventh sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 6), m3b); //B60(0-7) B61(0-7) B64(0-7) B65(0-7)\n                    const __m256i rhs_mat_2367_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 6), m3b); //B62(0-7) B63(0-7) B66(0-7) B67(0-7)\n\n                    const __m256i rhs_mat_0145_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 6), m3b); //B60(8-15) B61(8-15) B64(8-15) B65(8-15)\n                    const __m256i rhs_mat_2367_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 6), m3b); //B62(8-15) B63(8-15) B66(8-15) B67(8-15)\n\n                    // Eighth sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 6), m3b); //B70(0-7) B71(0-7) B74(0-7) B75(0-7)\n                    const __m256i rhs_mat_2367_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 6), m3b); //B72(0-7) B73(0-7) B76(0-7) B77(0-7)\n\n                    const __m256i rhs_mat_0145_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 6), m3b); //B70(8-15) B71(8-15) B74(8-15) B75(8-15)\n                    const __m256i rhs_mat_2367_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 6), m3b); //B72(8-15) B73(8-15) B76(8-15) B77(8-15)\n\n                    // Shuffle pattern one - right side input\n                    const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)\n                    const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) 
B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)\n\n                    const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)\n                    const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)\n\n                    const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)\n                    const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)\n\n                    const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)\n                    const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)\n\n                    const __m256i rhs_mat_0145_20_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_20, 136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3)\n                    const __m256i rhs_mat_2367_20_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_20, 136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3)\n\n                    const __m256i rhs_mat_0145_21_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_21, 136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11)\n                    const __m256i rhs_mat_2367_21_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_21, 136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11)\n\n                    const __m256i rhs_mat_0145_30_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_30, 136); //B30(0-3) B31(0-3) B30(0-3) B31(0-3) 
B34(0-3) B35(0-3) B34(0-3) B35(0-3)\n                    const __m256i rhs_mat_2367_30_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_30, 136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3)\n\n                    const __m256i rhs_mat_0145_31_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_31, 136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11\n                    const __m256i rhs_mat_2367_31_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_31, 136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11)\n\n                    const __m256i rhs_mat_0145_40_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_40, 136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3)\n                    const __m256i rhs_mat_2367_40_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_40, 136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3)\n\n                    const __m256i rhs_mat_0145_41_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_41, 136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11)\n                    const __m256i rhs_mat_2367_41_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_41, 136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11)\n\n                    const __m256i rhs_mat_0145_50_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_50, 136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3)\n                    const __m256i rhs_mat_2367_50_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_50, 136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3)\n\n                    const __m256i rhs_mat_0145_51_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_51, 136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11)\n                    const __m256i rhs_mat_2367_51_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_51, 136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) 
B56(8-11) B57(8-11)\n\n                    const __m256i rhs_mat_0145_60_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_60, 136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3)\n                    const __m256i rhs_mat_2367_60_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_60, 136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3)\n\n                    const __m256i rhs_mat_0145_61_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_61, 136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11)\n                    const __m256i rhs_mat_2367_61_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_61, 136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11)\n\n                    const __m256i rhs_mat_0145_70_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_70, 136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3)\n                    const __m256i rhs_mat_2367_70_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_70, 136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3)\n\n                    const __m256i rhs_mat_0145_71_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_71, 136); //B70(8-11) B71(8-11) B70(8-11) B71(8-11) B74(8-11) B75(8-11) B74(8-11) B75(8-11)\n                    const __m256i rhs_mat_2367_71_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_71, 136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11)\n\n\n                    // Shuffle pattern two - right side input\n                    const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)\n                    const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)\n\n                    const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) 
B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)\n                    const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)\n\n                    const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)\n                    const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)\n\n                    const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)\n                    const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)\n\n                    const __m256i rhs_mat_0145_20_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_20, 221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7)\n                    const __m256i rhs_mat_2367_20_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_20, 221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7)\n\n                    const __m256i rhs_mat_0145_21_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_21, 221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15)\n                    const __m256i rhs_mat_2367_21_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_21, 221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15)\n\n                    const __m256i rhs_mat_0145_30_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_30, 221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7)\n                    const __m256i rhs_mat_2367_30_sp2 = 
_mm256_shuffle_epi32(rhs_mat_2367_30, 221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7)\n\n                    const __m256i rhs_mat_0145_31_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_31, 221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15)\n                    const __m256i rhs_mat_2367_31_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_31, 221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15)\n\n                    const __m256i rhs_mat_0145_40_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_40, 221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7)\n                    const __m256i rhs_mat_2367_40_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_40, 221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7)\n\n                    const __m256i rhs_mat_0145_41_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_41, 221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15)\n                    const __m256i rhs_mat_2367_41_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_41, 221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15)\n\n                    const __m256i rhs_mat_0145_50_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_50, 221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7)\n                    const __m256i rhs_mat_2367_50_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_50, 221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7)\n\n                    const __m256i rhs_mat_0145_51_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_51, 221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15)\n                    const __m256i rhs_mat_2367_51_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_51, 221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15)\n\n                    
const __m256i rhs_mat_0145_60_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_60, 221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7)\n                    const __m256i rhs_mat_2367_60_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_60, 221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7)\n\n                    const __m256i rhs_mat_0145_61_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_61, 221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15)\n                    const __m256i rhs_mat_2367_61_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_61, 221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15)\n\n                    const __m256i rhs_mat_0145_70_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_70, 221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7)\n                    const __m256i rhs_mat_2367_70_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_70, 221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7)\n\n                    const __m256i rhs_mat_0145_71_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_71, 221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15)\n                    const __m256i rhs_mat_2367_71_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_71, 221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15)\n\n                    //Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together\n                    //s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71\n\n                    // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop\n                    const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64));\n                    const __m128i 
mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64));\n                    const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64));\n                    const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64));\n\n                    // Extract scales which is lower half from mins_and_scales\n                    const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse);\n                    const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse);\n                    const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse);\n                    const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse);\n\n                    // Extract mins which is upper half from mins_and_scales\n                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse));\n                    const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse));\n                    const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse));\n                    const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse));\n\n                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask1_sse));\n                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask2_sse));\n\n                    const __m256i scales_2 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask1_sse));\n                    const __m256i scales_3 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask2_sse));\n\n                    const __m256i scales_4 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask1_sse));\n                    const __m256i 
scales_5 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask2_sse));\n\n                    const __m256i scales_6 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask1_sse));\n                    const __m256i scales_7 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask2_sse));\n\n                    const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);\n                    const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);\n\n                    const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);\n                    const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);\n\n                    const __m256i scale_0145_2 = _mm256_shuffle_epi32(scales_2, 68);\n                    const __m256i scale_2367_2 = _mm256_shuffle_epi32(scales_2, 238);\n\n                    const __m256i scale_0145_3 = _mm256_shuffle_epi32(scales_3, 68);\n                    const __m256i scale_2367_3 = _mm256_shuffle_epi32(scales_3, 238);\n\n                    const __m256i scale_0145_4 = _mm256_shuffle_epi32(scales_4, 68);\n                    const __m256i scale_2367_4 = _mm256_shuffle_epi32(scales_4, 238);\n\n                    const __m256i scale_0145_5 = _mm256_shuffle_epi32(scales_5, 68);\n                    const __m256i scale_2367_5 = _mm256_shuffle_epi32(scales_5, 238);\n\n                    const __m256i scale_0145_6 = _mm256_shuffle_epi32(scales_6, 68);\n                    const __m256i scale_2367_6 = _mm256_shuffle_epi32(scales_6, 238);\n\n                    const __m256i scale_0145_7 = _mm256_shuffle_epi32(scales_7, 68);\n                    const __m256i scale_2367_7 = _mm256_shuffle_epi32(scales_7, 238);\n\n\n                    for (int rp = 0; rp < 4; rp++) {\n\n                        // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3\n                        // Loaded as set of 128 bit vectors and repeated into a 256 bit 
vector\n                        __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 512 * sb)));\n                        __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);\n                        __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);\n                        __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 512 * sb)));\n                        __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);\n                        __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);\n                        __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 512 * sb)));\n                        __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);\n                        __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);\n                        __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 512 * sb)));\n                        __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);\n                        __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);\n                        __m256i lhs_mat_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 512 * sb)));\n                        __m256i lhs_mat_01_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 0);\n                        __m256i lhs_mat_23_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 17);\n                        __m256i lhs_mat_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 512 * sb)));\n                        __m256i lhs_mat_01_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, 
lhs_mat_0123_21, 0);\n                        __m256i lhs_mat_23_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 17);\n                        __m256i lhs_mat_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 512 * sb)));\n                        __m256i lhs_mat_01_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 0);\n                        __m256i lhs_mat_23_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 17);\n                        __m256i lhs_mat_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 512 * sb)));\n                        __m256i lhs_mat_01_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 0);\n                        __m256i lhs_mat_23_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 17);\n\n                        __m256i lhs_mat_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 + 512 * sb)));\n                        __m256i lhs_mat_01_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 0);\n                        __m256i lhs_mat_23_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 17);\n                        __m256i lhs_mat_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 288 + 512 * sb)));\n                        __m256i lhs_mat_01_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 0);\n                        __m256i lhs_mat_23_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 17);\n                        __m256i lhs_mat_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 320 + 512 * sb)));\n                        __m256i lhs_mat_01_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 0);\n                        __m256i lhs_mat_23_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 17);\n                        __m256i lhs_mat_0123_51 = _mm256_loadu_si256((const __m256i * 
)((a_ptrs[rp][b].qs + 352 + 512 * sb)));\n                        __m256i lhs_mat_01_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 0);\n                        __m256i lhs_mat_23_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 17);\n                        __m256i lhs_mat_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 384 + 512 * sb)));\n                        __m256i lhs_mat_01_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 0);\n                        __m256i lhs_mat_23_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 17);\n                        __m256i lhs_mat_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 416 + 512 * sb)));\n                        __m256i lhs_mat_01_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 0);\n                        __m256i lhs_mat_23_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 17);\n                        __m256i lhs_mat_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 448 + 512 * sb)));\n                        __m256i lhs_mat_01_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 0);\n                        __m256i lhs_mat_23_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 17);\n                        __m256i lhs_mat_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 480 + 512 * sb)));\n                        __m256i lhs_mat_01_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 0);\n                        __m256i lhs_mat_23_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 17);\n\n                        // Bsums are loaded for the different Q8_K blocks\n                        __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 32 * sb)));\n                        __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i 
*)(a_ptrs[rp][b].bsums + 8 + 32 * sb));\n                        __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 16 + 32 * sb)));\n                        __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 24 + 32 * sb));\n\n                        // Shuffle pattern one - left side input\n                        const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)\n                        const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)\n\n                        const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)\n                        const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)\n\n                        const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)\n                        const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)\n\n                        const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)\n                        const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)\n\n                        const __m256i lhs_mat_01_20_sp1 = _mm256_shuffle_epi32(lhs_mat_01_20, 160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3)\n                
        const __m256i lhs_mat_23_20_sp1 = _mm256_shuffle_epi32(lhs_mat_23_20, 160); //A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3)\n\n                        const __m256i lhs_mat_01_21_sp1 = _mm256_shuffle_epi32(lhs_mat_01_21, 160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11)\n                        const __m256i lhs_mat_23_21_sp1 = _mm256_shuffle_epi32(lhs_mat_23_21, 160); //A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11)\n\n                        const __m256i lhs_mat_01_30_sp1 = _mm256_shuffle_epi32(lhs_mat_01_30, 160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3)\n                        const __m256i lhs_mat_23_30_sp1 = _mm256_shuffle_epi32(lhs_mat_23_30, 160); //A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3)\n\n                        const __m256i lhs_mat_01_31_sp1 = _mm256_shuffle_epi32(lhs_mat_01_31, 160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11)\n                        const __m256i lhs_mat_23_31_sp1 = _mm256_shuffle_epi32(lhs_mat_23_31, 160); //A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11)\n\n                        const __m256i lhs_mat_01_40_sp1 = _mm256_shuffle_epi32(lhs_mat_01_40, 160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3)\n                        const __m256i lhs_mat_23_40_sp1 = _mm256_shuffle_epi32(lhs_mat_23_40, 160); //A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3)\n\n                        const __m256i lhs_mat_01_41_sp1 = _mm256_shuffle_epi32(lhs_mat_01_41, 160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11)\n                        const __m256i lhs_mat_23_41_sp1 = _mm256_shuffle_epi32(lhs_mat_23_41, 160); //A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11)\n\n                        const 
__m256i lhs_mat_01_50_sp1 = _mm256_shuffle_epi32(lhs_mat_01_50, 160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3)\n                        const __m256i lhs_mat_23_50_sp1 = _mm256_shuffle_epi32(lhs_mat_23_50, 160); //A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3)\n\n                        const __m256i lhs_mat_01_51_sp1 = _mm256_shuffle_epi32(lhs_mat_01_51, 160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11)\n                        const __m256i lhs_mat_23_51_sp1 = _mm256_shuffle_epi32(lhs_mat_23_51, 160); //A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11)\n\n                        const __m256i lhs_mat_01_60_sp1 = _mm256_shuffle_epi32(lhs_mat_01_60, 160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3)\n                        const __m256i lhs_mat_23_60_sp1 = _mm256_shuffle_epi32(lhs_mat_23_60, 160); //A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3)\n\n                        const __m256i lhs_mat_01_61_sp1 = _mm256_shuffle_epi32(lhs_mat_01_61, 160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11)\n                        const __m256i lhs_mat_23_61_sp1 = _mm256_shuffle_epi32(lhs_mat_23_61, 160); //A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11)\n\n                        const __m256i lhs_mat_01_70_sp1 = _mm256_shuffle_epi32(lhs_mat_01_70, 160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3)\n                        const __m256i lhs_mat_23_70_sp1 = _mm256_shuffle_epi32(lhs_mat_23_70, 160); //A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3)\n\n                        const __m256i lhs_mat_01_71_sp1 = _mm256_shuffle_epi32(lhs_mat_01_71, 160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11)\n                        const __m256i 
lhs_mat_23_71_sp1 = _mm256_shuffle_epi32(lhs_mat_23_71, 160); //A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11)\n\n                        // Shuffle pattern two- left side input\n                        const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)\n                        const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)\n\n                        const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)\n                        const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)\n\n                        const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)\n                        const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)\n\n                        const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)\n                        const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)\n\n                        const __m256i lhs_mat_01_20_sp2 = _mm256_shuffle_epi32(lhs_mat_01_20, 245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7)\n                        const __m256i lhs_mat_23_20_sp2 = _mm256_shuffle_epi32(lhs_mat_23_20, 245); //A22(4-7) A23(4-7) A22(4-7) A23(4-7) 
A22(4-7) A23(4-7) A22(4-7) A23(4-7)\n\n                        const __m256i lhs_mat_01_21_sp2 = _mm256_shuffle_epi32(lhs_mat_01_21, 245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15)\n                        const __m256i lhs_mat_23_21_sp2 = _mm256_shuffle_epi32(lhs_mat_23_21, 245); //A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15)\n\n                        const __m256i lhs_mat_01_30_sp2 = _mm256_shuffle_epi32(lhs_mat_01_30, 245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7)\n                        const __m256i lhs_mat_23_30_sp2 = _mm256_shuffle_epi32(lhs_mat_23_30, 245); //A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7)\n\n                        const __m256i lhs_mat_01_31_sp2 = _mm256_shuffle_epi32(lhs_mat_01_31, 245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15)\n                        const __m256i lhs_mat_23_31_sp2 = _mm256_shuffle_epi32(lhs_mat_23_31, 245); //A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15)\n\n                        const __m256i lhs_mat_01_40_sp2 = _mm256_shuffle_epi32(lhs_mat_01_40, 245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7)\n                        const __m256i lhs_mat_23_40_sp2 = _mm256_shuffle_epi32(lhs_mat_23_40, 245); //A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7)\n\n                        const __m256i lhs_mat_01_41_sp2 = _mm256_shuffle_epi32(lhs_mat_01_41, 245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15)\n                        const __m256i lhs_mat_23_41_sp2 = _mm256_shuffle_epi32(lhs_mat_23_41, 245); //A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15)\n\n                        const __m256i lhs_mat_01_50_sp2 = _mm256_shuffle_epi32(lhs_mat_01_50, 245); 
//A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7)\n                        const __m256i lhs_mat_23_50_sp2 = _mm256_shuffle_epi32(lhs_mat_23_50, 245); //A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7)\n\n                        const __m256i lhs_mat_01_51_sp2 = _mm256_shuffle_epi32(lhs_mat_01_51, 245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15)\n                        const __m256i lhs_mat_23_51_sp2 = _mm256_shuffle_epi32(lhs_mat_23_51, 245); //A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15)\n\n                        const __m256i lhs_mat_01_60_sp2 = _mm256_shuffle_epi32(lhs_mat_01_60, 245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7)\n                        const __m256i lhs_mat_23_60_sp2 = _mm256_shuffle_epi32(lhs_mat_23_60, 245); //A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7)\n\n                        const __m256i lhs_mat_01_61_sp2 = _mm256_shuffle_epi32(lhs_mat_01_61, 245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15)\n                        const __m256i lhs_mat_23_61_sp2 = _mm256_shuffle_epi32(lhs_mat_23_61, 245); //A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15)\n\n                        const __m256i lhs_mat_01_70_sp2 = _mm256_shuffle_epi32(lhs_mat_01_70, 245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7)\n                        const __m256i lhs_mat_23_70_sp2 = _mm256_shuffle_epi32(lhs_mat_23_70, 245); //A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7)\n\n                        const __m256i lhs_mat_01_71_sp2 = _mm256_shuffle_epi32(lhs_mat_01_71, 245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15)\n                        const __m256i lhs_mat_23_71_sp2 = 
_mm256_shuffle_epi32(lhs_mat_23_71, 245); //A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15)\n\n                        // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                        __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1));\n                        __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1));\n\n                        __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1));\n                        __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1));\n\n                        __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1));\n                        __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1));\n\n                        __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1));\n                        __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1));\n\n                        __m256i iacc_mat_00_2_sp1 = 
_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_01_21_sp1));\n                        __m256i iacc_mat_01_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_01_21_sp1));\n\n                        __m256i iacc_mat_10_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_23_21_sp1));\n                        __m256i iacc_mat_11_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_23_21_sp1));\n\n                        __m256i iacc_mat_00_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_01_31_sp1));\n                        __m256i iacc_mat_01_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_01_31_sp1));\n\n                        __m256i iacc_mat_10_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_23_31_sp1));\n                        __m256i iacc_mat_11_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_23_31_sp1));\n\n                        __m256i iacc_mat_00_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_01_41_sp1));\n                        __m256i iacc_mat_01_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_01_41_sp1));\n\n                        __m256i iacc_mat_10_4_sp1 = 
_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_23_41_sp1));\n                        __m256i iacc_mat_11_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_23_41_sp1));\n\n                        __m256i iacc_mat_00_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_01_51_sp1));\n                        __m256i iacc_mat_01_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_01_51_sp1));\n\n                        __m256i iacc_mat_10_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_23_51_sp1));\n                        __m256i iacc_mat_11_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_23_51_sp1));\n\n                        __m256i iacc_mat_00_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_01_61_sp1));\n                        __m256i iacc_mat_01_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_01_61_sp1));\n\n                        __m256i iacc_mat_10_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_23_61_sp1));\n                        __m256i iacc_mat_11_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_23_61_sp1));\n\n                        __m256i iacc_mat_00_7_sp1 = 
_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_01_71_sp1));\n                        __m256i iacc_mat_01_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_01_71_sp1));\n\n                        __m256i iacc_mat_10_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_23_71_sp1));\n                        __m256i iacc_mat_11_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_23_71_sp1));\n\n\n                        __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2));\n                        __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2));\n\n                        __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2));\n                        __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2));\n\n                        __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2));\n                        __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2));\n\n                        __m256i iacc_mat_10_1_sp2 = 
_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2));\n                        __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2));\n\n                        __m256i iacc_mat_00_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_01_21_sp2));\n                        __m256i iacc_mat_01_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_01_21_sp2));\n\n                        __m256i iacc_mat_10_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_23_21_sp2));\n                        __m256i iacc_mat_11_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_23_21_sp2));\n\n                        __m256i iacc_mat_00_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_01_31_sp2));\n                        __m256i iacc_mat_01_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_01_31_sp2));\n\n                        __m256i iacc_mat_10_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_23_31_sp2));\n                        __m256i iacc_mat_11_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_23_31_sp2));\n\n                        __m256i iacc_mat_00_4_sp2 = 
_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_01_41_sp2));\n                        __m256i iacc_mat_01_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_01_41_sp2));\n\n                        __m256i iacc_mat_10_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_23_41_sp2));\n                        __m256i iacc_mat_11_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_23_41_sp2));\n\n                        __m256i iacc_mat_00_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_01_51_sp2));\n                        __m256i iacc_mat_01_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_01_51_sp2));\n\n                        __m256i iacc_mat_10_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_23_51_sp2));\n                        __m256i iacc_mat_11_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_23_51_sp2));\n\n                        __m256i iacc_mat_00_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_01_61_sp2));\n                        __m256i iacc_mat_01_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_01_61_sp2));\n\n                        __m256i iacc_mat_10_6_sp2 = 
_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_23_61_sp2));\n                        __m256i iacc_mat_11_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_23_61_sp2));\n\n                        __m256i iacc_mat_00_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_01_71_sp2));\n                        __m256i iacc_mat_01_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_01_71_sp2));\n\n                        __m256i iacc_mat_10_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_23_71_sp2));\n                        __m256i iacc_mat_11_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_23_71_sp2));\n\n                        // Combine results from both shuffle patterns for each output block\n                        __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);\n                        __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);\n                        __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);\n                        __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);\n\n                        __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);\n                        __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);\n                        __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);\n                        __m256i iacc_mat_11_1 = 
_mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);\n\n                        __m256i iacc_mat_00_2 = _mm256_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2);\n                        __m256i iacc_mat_01_2 = _mm256_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2);\n                        __m256i iacc_mat_10_2 = _mm256_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2);\n                        __m256i iacc_mat_11_2 = _mm256_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2);\n\n                        __m256i iacc_mat_00_3 = _mm256_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2);\n                        __m256i iacc_mat_01_3 = _mm256_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2);\n                        __m256i iacc_mat_10_3 = _mm256_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2);\n                        __m256i iacc_mat_11_3 = _mm256_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2);\n\n                        __m256i iacc_mat_00_4 = _mm256_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2);\n                        __m256i iacc_mat_01_4 = _mm256_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2);\n                        __m256i iacc_mat_10_4 = _mm256_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2);\n                        __m256i iacc_mat_11_4 = _mm256_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2);\n\n                        __m256i iacc_mat_00_5 = _mm256_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2);\n                        __m256i iacc_mat_01_5 = _mm256_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2);\n                        __m256i iacc_mat_10_5 = _mm256_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2);\n                        __m256i iacc_mat_11_5 = _mm256_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2);\n\n                        __m256i iacc_mat_00_6 = _mm256_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2);\n                        __m256i iacc_mat_01_6 = _mm256_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2);\n                        __m256i 
iacc_mat_10_6 = _mm256_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2);\n                        __m256i iacc_mat_11_6 = _mm256_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2);\n\n                        __m256i iacc_mat_00_7 = _mm256_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2);\n                        __m256i iacc_mat_01_7 = _mm256_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2);\n                        __m256i iacc_mat_10_7 = _mm256_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2);\n                        __m256i iacc_mat_11_7 = _mm256_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2);\n\n                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                        iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);\n                        iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);\n                        iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);\n                        iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);\n\n                        iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);\n                        iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);\n                        iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);\n                        iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);\n\n                        iacc_mat_00_2 = _mm256_madd_epi16(iacc_mat_00_2, scale_0145_2);\n                        iacc_mat_01_2 = _mm256_madd_epi16(iacc_mat_01_2, scale_2367_2);\n                        iacc_mat_10_2 = _mm256_madd_epi16(iacc_mat_10_2, scale_0145_2);\n                        iacc_mat_11_2 = _mm256_madd_epi16(iacc_mat_11_2, scale_2367_2);\n\n                        iacc_mat_00_3 = _mm256_madd_epi16(iacc_mat_00_3, scale_0145_3);\n                        iacc_mat_01_3 = _mm256_madd_epi16(iacc_mat_01_3, scale_2367_3);\n              
          iacc_mat_10_3 = _mm256_madd_epi16(iacc_mat_10_3, scale_0145_3);\n                        iacc_mat_11_3 = _mm256_madd_epi16(iacc_mat_11_3, scale_2367_3);\n\n                        iacc_mat_00_4 = _mm256_madd_epi16(iacc_mat_00_4, scale_0145_4);\n                        iacc_mat_01_4 = _mm256_madd_epi16(iacc_mat_01_4, scale_2367_4);\n                        iacc_mat_10_4 = _mm256_madd_epi16(iacc_mat_10_4, scale_0145_4);\n                        iacc_mat_11_4 = _mm256_madd_epi16(iacc_mat_11_4, scale_2367_4);\n\n                        iacc_mat_00_5 = _mm256_madd_epi16(iacc_mat_00_5, scale_0145_5);\n                        iacc_mat_01_5 = _mm256_madd_epi16(iacc_mat_01_5, scale_2367_5);\n                        iacc_mat_10_5 = _mm256_madd_epi16(iacc_mat_10_5, scale_0145_5);\n                        iacc_mat_11_5 = _mm256_madd_epi16(iacc_mat_11_5, scale_2367_5);\n\n                        iacc_mat_00_6 = _mm256_madd_epi16(iacc_mat_00_6, scale_0145_6);\n                        iacc_mat_01_6 = _mm256_madd_epi16(iacc_mat_01_6, scale_2367_6);\n                        iacc_mat_10_6 = _mm256_madd_epi16(iacc_mat_10_6, scale_0145_6);\n                        iacc_mat_11_6 = _mm256_madd_epi16(iacc_mat_11_6, scale_2367_6);\n\n                        iacc_mat_00_7 = _mm256_madd_epi16(iacc_mat_00_7, scale_0145_7);\n                        iacc_mat_01_7 = _mm256_madd_epi16(iacc_mat_01_7, scale_2367_7);\n                        iacc_mat_10_7 = _mm256_madd_epi16(iacc_mat_10_7, scale_0145_7);\n                        iacc_mat_11_7 = _mm256_madd_epi16(iacc_mat_11_7, scale_2367_7);\n\n                        __m256i iacc_mat_00 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm256_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm256_add_epi32(iacc_mat_00_6, iacc_mat_00_7)));\n                        __m256i iacc_mat_01 = 
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm256_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm256_add_epi32(iacc_mat_01_6, iacc_mat_01_7)));\n                        __m256i iacc_mat_10 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm256_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm256_add_epi32(iacc_mat_10_6, iacc_mat_10_7)));\n                        __m256i iacc_mat_11 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm256_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm256_add_epi32(iacc_mat_11_6, iacc_mat_11_7)));\n\n                        // Straighten out to make 4 row vectors\n                        __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);\n                        __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);\n                        __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);\n                        __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);\n\n                        // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes\n                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);\n                        const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);\n\n                        // Multiply with appropriate scales and accumulate (for both d and dmin) below\n                        acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);\n    
                    acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);\n                        acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);\n                        acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);\n\n                        __m256i lhs_bsums_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1);\n                        __m256i lhs_bsums_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1);\n                        __m256i lhs_bsums_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1);\n                        __m256i lhs_bsums_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1);\n\n                       // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K\n                        __m256i iacc_row_min_0_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 0), mins_01);\n                        __m256i iacc_row_min_1_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 170), mins_01);\n                        __m256i iacc_row_min_2_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 0), mins_01);\n                        __m256i iacc_row_min_3_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 170), mins_01);\n\n                        __m256i iacc_row_min_0_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 85), mins_23);\n                        
__m256i iacc_row_min_1_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 255), mins_23);\n                        __m256i iacc_row_min_2_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 85), mins_23);\n                        __m256i iacc_row_min_3_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 255), mins_23);\n\n                        __m256i iacc_row_min_0_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 0), mins_45);\n                        __m256i iacc_row_min_1_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 170), mins_45);\n                        __m256i iacc_row_min_2_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 0), mins_45);\n                        __m256i iacc_row_min_3_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 170), mins_45);\n\n                        __m256i iacc_row_min_0_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 85), mins_67);\n                        __m256i iacc_row_min_1_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 255), mins_67);\n                        __m256i iacc_row_min_2_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 85), mins_67);\n                        __m256i iacc_row_min_3_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 255), mins_67);\n\n                        __m256i iacc_row_min_0 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm256_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67));\n                        __m256i iacc_row_min_1 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm256_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67));\n                        __m256i iacc_row_min_2 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm256_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67));\n                        __m256i iacc_row_min_3 = 
_mm256_add_epi32(_mm256_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm256_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67));\n\n                        acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);\n                        acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);\n                        acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);\n                        acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);\n\n                    }\n                }\n            }\n            // Store the accumulated values\n            for (int i = 0; i < 16; i++) {\n                _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));\n\n            }\n        }\n    }\n\n    for (; y < nr / 4; y ++) {\n\n        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);\n\n        // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation\n        for (int64_t x = xstart; x < nc / 8; x++) {\n\n            const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb);\n\n            // Master FP accumulators\n            __m256 acc_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_rows[i] = _mm256_setzero_ps();\n            }\n\n            __m256 acc_min_rows[4];\n            for (int i = 0; i < 4; i++) {\n                acc_min_rows[i] = _mm256_setzero_ps();\n            }\n\n            for (int64_t b = 0; b < 
nb; b++) {\n                // Delta values - Load the eight scale values of block_q2_kx8\n                const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);\n\n                // dmin values - Load the eight dmin values of block_q2_kx8\n                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);\n\n                // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration\n                for (int sb = 0; sb < QK_K / 128; sb++) {\n\n                    // Load the eight block_q2_k for eight sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7\n                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + sb * 256));\n                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 128 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 160 + sb * 256));\n                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 192 + sb * 256));\n                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 224 + sb * 256));\n\n                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values\n                    //superblock    sub block   which part of sub block\n                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, 
_mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);\n\n                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);\n\n                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);\n\n                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);\n                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);\n\n                    // 2-bit -> 8-bit\n                    // First sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m3b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)\n                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m3b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)\n\n                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m3b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)\n                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m3b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)\n\n                    // Second sub block of the eight sub blocks processed in the 
iteration\n                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(rhs_raw_mat_0145_2, m3b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)\n                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(rhs_raw_mat_2367_2, m3b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)\n\n                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(rhs_raw_mat_0145_3, m3b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)\n                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(rhs_raw_mat_2367_3, m3b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)\n\n                    // Third sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 2), m3b); //B20(0-7) B21(0-7) B24(0-7) B25(0-7)\n                    const __m256i rhs_mat_2367_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 2), m3b); //B22(0-7) B23(0-7) B26(0-7) B27(0-7)\n\n                    const __m256i rhs_mat_0145_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 2), m3b); //B20(8-15) B21(8-15) B24(8-15) B25(8-15)\n                    const __m256i rhs_mat_2367_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 2), m3b); //B22(8-15) B23(8-15) B26(8-15) B27(8-15)\n\n                    // Fourth sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 2), m3b); //B30(0-7) B31(0-7) B34(0-7) B35(0-7)\n                    const __m256i rhs_mat_2367_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 2), m3b); //B32(0-7) B33(0-7) B36(0-7) B37(0-7)\n\n                    const __m256i rhs_mat_0145_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 2), m3b); //B30(8-15) B31(8-15) B34(8-15) B35(8-15)\n                    const __m256i rhs_mat_2367_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 2), m3b); //B32(8-15) B33(8-15) 
B36(8-15) B37(8-15)\n\n                    // Fifth sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m3b); //B40(0-7) B41(0-7) B44(0-7) B45(0-7)\n                    const __m256i rhs_mat_2367_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m3b); //B42(0-7) B43(0-7) B46(0-7) B47(0-7)\n\n                    const __m256i rhs_mat_0145_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m3b); //B40(8-15) B41(8-15) B44(8-15) B45(8-15)\n                    const __m256i rhs_mat_2367_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m3b); //B42(8-15) B43(8-15) B46(8-15) B47(8-15)\n\n                    // Sixth sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m3b); //B50(0-7) B51(0-7) B54(0-7) B55(0-7)\n                    const __m256i rhs_mat_2367_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m3b); //B52(0-7) B53(0-7) B56(0-7) B57(0-7)\n\n                    const __m256i rhs_mat_0145_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m3b); //B50(8-15) B51(8-15) B54(8-15) B55(8-15)\n                    const __m256i rhs_mat_2367_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m3b); //B52(8-15) B53(8-15) B56(8-15) B57(8-15)\n\n                    // Seventh sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 6), m3b); //B60(0-7) B61(0-7) B64(0-7) B65(0-7)\n                    const __m256i rhs_mat_2367_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 6), m3b); //B62(0-7) B63(0-7) B66(0-7) B67(0-7)\n\n                    const __m256i rhs_mat_0145_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 6), 
m3b); //B60(8-15) B61(8-15) B64(8-15) B65(8-15)\n                    const __m256i rhs_mat_2367_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 6), m3b); //B62(8-15) B63(8-15) B66(8-15) B67(8-15)\n\n                    // Eighth sub block of the eight sub blocks processed in the iteration\n                    const __m256i rhs_mat_0145_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 6), m3b); //B70(0-7) B71(0-7) B74(0-7) B75(0-7)\n                    const __m256i rhs_mat_2367_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 6), m3b); //B72(0-7) B73(0-7) B76(0-7) B77(0-7)\n\n                    const __m256i rhs_mat_0145_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 6), m3b); //B70(8-15) B71(8-15) B74(8-15) B75(8-15)\n                    const __m256i rhs_mat_2367_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 6), m3b); //B72(8-15) B73(8-15) B76(8-15) B77(8-15)\n\n                    // Shuffle pattern one - right side input\n                    const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)\n                    const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)\n\n                    const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)\n                    const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)\n\n                    const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)\n                    const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); 
//B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)\n\n                    const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)\n                    const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)\n\n                    const __m256i rhs_mat_0145_20_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_20, 136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3)\n                    const __m256i rhs_mat_2367_20_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_20, 136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3)\n\n                    const __m256i rhs_mat_0145_21_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_21, 136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11)\n                    const __m256i rhs_mat_2367_21_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_21, 136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11)\n\n                    const __m256i rhs_mat_0145_30_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_30, 136); //B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3)\n                    const __m256i rhs_mat_2367_30_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_30, 136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3)\n\n                    const __m256i rhs_mat_0145_31_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_31, 136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11\n                    const __m256i rhs_mat_2367_31_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_31, 136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11)\n\n                    const __m256i rhs_mat_0145_40_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_40, 136); //B40(0-3) 
B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3)\n                    const __m256i rhs_mat_2367_40_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_40, 136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3)\n\n                    const __m256i rhs_mat_0145_41_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_41, 136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11)\n                    const __m256i rhs_mat_2367_41_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_41, 136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11)\n\n                    const __m256i rhs_mat_0145_50_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_50, 136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3)\n                    const __m256i rhs_mat_2367_50_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_50, 136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3)\n\n                    const __m256i rhs_mat_0145_51_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_51, 136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11)\n                    const __m256i rhs_mat_2367_51_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_51, 136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11)\n\n                    const __m256i rhs_mat_0145_60_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_60, 136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3)\n                    const __m256i rhs_mat_2367_60_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_60, 136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3)\n\n                    const __m256i rhs_mat_0145_61_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_61, 136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11)\n                    const __m256i rhs_mat_2367_61_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_61, 136); //B62(8-11) B63(8-11) B62(8-11) 
B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11)\n\n                    const __m256i rhs_mat_0145_70_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_70, 136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3)\n                    const __m256i rhs_mat_2367_70_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_70, 136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3)\n\n                    const __m256i rhs_mat_0145_71_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_71, 136); //B70(8-11) B71(8-11) B70(8-11) B71(8-11) B74(8-11) B75(8-11) B74(8-11) B75(8-11)\n                    const __m256i rhs_mat_2367_71_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_71, 136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11)\n\n\n                    // Shuffle pattern two - right side input\n                    const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)\n                    const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)\n\n                    const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)\n                    const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)\n\n                    const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)\n                    const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)\n\n                    const __m256i rhs_mat_0145_11_sp2 = 
_mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)\n                    const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)\n\n                    const __m256i rhs_mat_0145_20_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_20, 221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7)\n                    const __m256i rhs_mat_2367_20_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_20, 221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7)\n\n                    const __m256i rhs_mat_0145_21_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_21, 221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15)\n                    const __m256i rhs_mat_2367_21_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_21, 221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15)\n\n                    const __m256i rhs_mat_0145_30_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_30, 221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7)\n                    const __m256i rhs_mat_2367_30_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_30, 221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7)\n\n                    const __m256i rhs_mat_0145_31_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_31, 221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15)\n                    const __m256i rhs_mat_2367_31_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_31, 221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15)\n\n                    const __m256i rhs_mat_0145_40_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_40, 221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7)\n                    
const __m256i rhs_mat_2367_40_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_40, 221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7)\n\n                    const __m256i rhs_mat_0145_41_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_41, 221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15)\n                    const __m256i rhs_mat_2367_41_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_41, 221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15)\n\n                    const __m256i rhs_mat_0145_50_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_50, 221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7)\n                    const __m256i rhs_mat_2367_50_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_50, 221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7)\n\n                    const __m256i rhs_mat_0145_51_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_51, 221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15)\n                    const __m256i rhs_mat_2367_51_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_51, 221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15)\n\n                    const __m256i rhs_mat_0145_60_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_60, 221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7)\n                    const __m256i rhs_mat_2367_60_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_60, 221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7)\n\n                    const __m256i rhs_mat_0145_61_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_61, 221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15)\n                    const __m256i rhs_mat_2367_61_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_61, 221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) 
B67(12-15)\n\n                    const __m256i rhs_mat_0145_70_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_70, 221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7)\n                    const __m256i rhs_mat_2367_70_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_70, 221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7)\n\n                    const __m256i rhs_mat_0145_71_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_71, 221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15)\n                    const __m256i rhs_mat_2367_71_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_71, 221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15)\n\n\n                    //Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together\n                    //s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71\n\n                    // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop\n                    const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64));\n                    const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64));\n                    const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64));\n                    const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64));\n\n                    // Extract scales which is lower half from mins_and_scales\n                    const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse);\n                    const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse);\n                    const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse);\n              
      const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse);\n\n                    // Extract mins which is upper half from mins_and_scales\n                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse));\n                    const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse));\n                    const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse));\n                    const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse));\n\n                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask1_sse));\n                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask2_sse));\n\n                    const __m256i scales_2 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask1_sse));\n                    const __m256i scales_3 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask2_sse));\n\n                    const __m256i scales_4 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask1_sse));\n                    const __m256i scales_5 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask2_sse));\n\n                    const __m256i scales_6 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask1_sse));\n                    const __m256i scales_7 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask2_sse));\n\n                    const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);\n                    const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);\n\n                    const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);\n                    const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);\n\n                    const __m256i scale_0145_2 = 
_mm256_shuffle_epi32(scales_2, 68);\n                    const __m256i scale_2367_2 = _mm256_shuffle_epi32(scales_2, 238);\n\n                    const __m256i scale_0145_3 = _mm256_shuffle_epi32(scales_3, 68);\n                    const __m256i scale_2367_3 = _mm256_shuffle_epi32(scales_3, 238);\n\n                    const __m256i scale_0145_4 = _mm256_shuffle_epi32(scales_4, 68);\n                    const __m256i scale_2367_4 = _mm256_shuffle_epi32(scales_4, 238);\n\n                    const __m256i scale_0145_5 = _mm256_shuffle_epi32(scales_5, 68);\n                    const __m256i scale_2367_5 = _mm256_shuffle_epi32(scales_5, 238);\n\n                    const __m256i scale_0145_6 = _mm256_shuffle_epi32(scales_6, 68);\n                    const __m256i scale_2367_6 = _mm256_shuffle_epi32(scales_6, 238);\n\n                    const __m256i scale_0145_7 = _mm256_shuffle_epi32(scales_7, 68);\n                    const __m256i scale_2367_7 = _mm256_shuffle_epi32(scales_7, 238);\n\n                    // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3\n                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector\n                    __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 512 * sb)));\n                    __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);\n                    __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);\n                    __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 512 * sb)));\n                    __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);\n                    __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);\n                    __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 
512 * sb)));\n                    __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);\n                    __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);\n                    __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 512 * sb)));\n                    __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);\n                    __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);\n                    __m256i lhs_mat_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 512 * sb)));\n                    __m256i lhs_mat_01_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 0);\n                    __m256i lhs_mat_23_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 17);\n                    __m256i lhs_mat_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 512 * sb)));\n                    __m256i lhs_mat_01_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 0);\n                    __m256i lhs_mat_23_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 17);\n                    __m256i lhs_mat_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 512 * sb)));\n                    __m256i lhs_mat_01_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 0);\n                    __m256i lhs_mat_23_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 17);\n                    __m256i lhs_mat_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 512 * sb)));\n                    __m256i lhs_mat_01_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 0);\n                    __m256i lhs_mat_23_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 17);\n\n                    __m256i lhs_mat_0123_40 = _mm256_loadu_si256((const 
__m256i * )((a_ptr[b].qs + 256 + 512 * sb)));\n                    __m256i lhs_mat_01_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 0);\n                    __m256i lhs_mat_23_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 17);\n                    __m256i lhs_mat_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 288 + 512 * sb)));\n                    __m256i lhs_mat_01_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 0);\n                    __m256i lhs_mat_23_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 17);\n                    __m256i lhs_mat_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 320 + 512 * sb)));\n                    __m256i lhs_mat_01_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 0);\n                    __m256i lhs_mat_23_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 17);\n                    __m256i lhs_mat_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 352 + 512 * sb)));\n                    __m256i lhs_mat_01_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 0);\n                    __m256i lhs_mat_23_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 17);\n                    __m256i lhs_mat_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 384 + 512 * sb)));\n                    __m256i lhs_mat_01_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 0);\n                    __m256i lhs_mat_23_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 17);\n                    __m256i lhs_mat_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 416 + 512 * sb)));\n                    __m256i lhs_mat_01_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 0);\n                    __m256i lhs_mat_23_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 17);\n                    __m256i lhs_mat_0123_70 
= _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 448 + 512 * sb)));\n                    __m256i lhs_mat_01_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 0);\n                    __m256i lhs_mat_23_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 17);\n                    __m256i lhs_mat_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 480 + 512 * sb)));\n                    __m256i lhs_mat_01_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 0);\n                    __m256i lhs_mat_23_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 17);\n\n                    // Bsums are loaded for the different Q8_K blocks\n                    __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 32 * sb)));\n                    __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 8 + 32 * sb));\n                    __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 16 + 32 * sb)));\n                    __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 24 + 32 * sb));\n\n                    // Shuffle pattern one - left side input\n                    const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)\n                    const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)\n\n                    const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)\n                    const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)\n\n                    const __m256i 
lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)\n                    const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)\n\n                    const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)\n                    const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)\n\n                    const __m256i lhs_mat_01_20_sp1 = _mm256_shuffle_epi32(lhs_mat_01_20, 160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3)\n                    const __m256i lhs_mat_23_20_sp1 = _mm256_shuffle_epi32(lhs_mat_23_20, 160); //A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3)\n\n                    const __m256i lhs_mat_01_21_sp1 = _mm256_shuffle_epi32(lhs_mat_01_21, 160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11)\n                    const __m256i lhs_mat_23_21_sp1 = _mm256_shuffle_epi32(lhs_mat_23_21, 160); //A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11)\n\n                    const __m256i lhs_mat_01_30_sp1 = _mm256_shuffle_epi32(lhs_mat_01_30, 160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3)\n                    const __m256i lhs_mat_23_30_sp1 = _mm256_shuffle_epi32(lhs_mat_23_30, 160); //A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3)\n\n                    const __m256i lhs_mat_01_31_sp1 = _mm256_shuffle_epi32(lhs_mat_01_31, 160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11)\n                    const __m256i lhs_mat_23_31_sp1 = _mm256_shuffle_epi32(lhs_mat_23_31, 160); 
//A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11)\n\n                    const __m256i lhs_mat_01_40_sp1 = _mm256_shuffle_epi32(lhs_mat_01_40, 160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3)\n                    const __m256i lhs_mat_23_40_sp1 = _mm256_shuffle_epi32(lhs_mat_23_40, 160); //A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3)\n\n                    const __m256i lhs_mat_01_41_sp1 = _mm256_shuffle_epi32(lhs_mat_01_41, 160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11)\n                    const __m256i lhs_mat_23_41_sp1 = _mm256_shuffle_epi32(lhs_mat_23_41, 160); //A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11)\n\n                    const __m256i lhs_mat_01_50_sp1 = _mm256_shuffle_epi32(lhs_mat_01_50, 160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3)\n                    const __m256i lhs_mat_23_50_sp1 = _mm256_shuffle_epi32(lhs_mat_23_50, 160); //A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3)\n\n                    const __m256i lhs_mat_01_51_sp1 = _mm256_shuffle_epi32(lhs_mat_01_51, 160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11)\n                    const __m256i lhs_mat_23_51_sp1 = _mm256_shuffle_epi32(lhs_mat_23_51, 160); //A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11)\n\n                    const __m256i lhs_mat_01_60_sp1 = _mm256_shuffle_epi32(lhs_mat_01_60, 160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3)\n                    const __m256i lhs_mat_23_60_sp1 = _mm256_shuffle_epi32(lhs_mat_23_60, 160); //A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3)\n\n                    const __m256i lhs_mat_01_61_sp1 = _mm256_shuffle_epi32(lhs_mat_01_61, 160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) 
A61(8-11) A61(8-11)\n                    const __m256i lhs_mat_23_61_sp1 = _mm256_shuffle_epi32(lhs_mat_23_61, 160); //A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11)\n\n                    const __m256i lhs_mat_01_70_sp1 = _mm256_shuffle_epi32(lhs_mat_01_70, 160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3)\n                    const __m256i lhs_mat_23_70_sp1 = _mm256_shuffle_epi32(lhs_mat_23_70, 160); //A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3)\n\n                    const __m256i lhs_mat_01_71_sp1 = _mm256_shuffle_epi32(lhs_mat_01_71, 160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11)\n                    const __m256i lhs_mat_23_71_sp1 = _mm256_shuffle_epi32(lhs_mat_23_71, 160); //A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11)\n\n                    // Shuffle pattern two- left side input\n                    const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)\n                    const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)\n\n                    const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)\n                    const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)\n\n                    const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)\n                    const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) 
A12(4-7) A13(4-7) A12(4-7) A13(4-7)\n\n                    const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)\n                    const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)\n\n                    const __m256i lhs_mat_01_20_sp2 = _mm256_shuffle_epi32(lhs_mat_01_20, 245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7)\n                    const __m256i lhs_mat_23_20_sp2 = _mm256_shuffle_epi32(lhs_mat_23_20, 245); //A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7)\n\n                    const __m256i lhs_mat_01_21_sp2 = _mm256_shuffle_epi32(lhs_mat_01_21, 245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15)\n                    const __m256i lhs_mat_23_21_sp2 = _mm256_shuffle_epi32(lhs_mat_23_21, 245); //A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15)\n\n                    const __m256i lhs_mat_01_30_sp2 = _mm256_shuffle_epi32(lhs_mat_01_30, 245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7)\n                    const __m256i lhs_mat_23_30_sp2 = _mm256_shuffle_epi32(lhs_mat_23_30, 245); //A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7)\n\n                    const __m256i lhs_mat_01_31_sp2 = _mm256_shuffle_epi32(lhs_mat_01_31, 245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15)\n                    const __m256i lhs_mat_23_31_sp2 = _mm256_shuffle_epi32(lhs_mat_23_31, 245); //A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15)\n\n                    const __m256i lhs_mat_01_40_sp2 = _mm256_shuffle_epi32(lhs_mat_01_40, 245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) 
A40(4-7) A41(4-7) A41(4-7)\n                    const __m256i lhs_mat_23_40_sp2 = _mm256_shuffle_epi32(lhs_mat_23_40, 245); //A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7)\n\n                    const __m256i lhs_mat_01_41_sp2 = _mm256_shuffle_epi32(lhs_mat_01_41, 245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15)\n                    const __m256i lhs_mat_23_41_sp2 = _mm256_shuffle_epi32(lhs_mat_23_41, 245); //A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15)\n\n                    const __m256i lhs_mat_01_50_sp2 = _mm256_shuffle_epi32(lhs_mat_01_50, 245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7)\n                    const __m256i lhs_mat_23_50_sp2 = _mm256_shuffle_epi32(lhs_mat_23_50, 245); //A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7)\n\n                    const __m256i lhs_mat_01_51_sp2 = _mm256_shuffle_epi32(lhs_mat_01_51, 245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15)\n                    const __m256i lhs_mat_23_51_sp2 = _mm256_shuffle_epi32(lhs_mat_23_51, 245); //A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15)\n\n                    const __m256i lhs_mat_01_60_sp2 = _mm256_shuffle_epi32(lhs_mat_01_60, 245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7)\n                    const __m256i lhs_mat_23_60_sp2 = _mm256_shuffle_epi32(lhs_mat_23_60, 245); //A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7)\n\n                    const __m256i lhs_mat_01_61_sp2 = _mm256_shuffle_epi32(lhs_mat_01_61, 245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15)\n                    const __m256i lhs_mat_23_61_sp2 = _mm256_shuffle_epi32(lhs_mat_23_61, 245); //A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15) 
A62(12-15) A63(12-15)\n\n                    const __m256i lhs_mat_01_70_sp2 = _mm256_shuffle_epi32(lhs_mat_01_70, 245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7)\n                    const __m256i lhs_mat_23_70_sp2 = _mm256_shuffle_epi32(lhs_mat_23_70, 245); //A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7)\n\n                    const __m256i lhs_mat_01_71_sp2 = _mm256_shuffle_epi32(lhs_mat_01_71, 245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15)\n                    const __m256i lhs_mat_23_71_sp2 = _mm256_shuffle_epi32(lhs_mat_23_71, 245); //A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15)\n\n                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane\n                    __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1));\n                    __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1));\n\n                    __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1));\n                    __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1));\n\n                    __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1));\n                    __m256i iacc_mat_01_1_sp1 = 
_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1));\n\n                    __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1));\n                    __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1));\n\n                    __m256i iacc_mat_00_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_01_21_sp1));\n                    __m256i iacc_mat_01_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_01_21_sp1));\n\n                    __m256i iacc_mat_10_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_23_21_sp1));\n                    __m256i iacc_mat_11_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_23_21_sp1));\n\n                    __m256i iacc_mat_00_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_01_31_sp1));\n                    __m256i iacc_mat_01_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_01_31_sp1));\n\n                    __m256i iacc_mat_10_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_23_31_sp1));\n                    __m256i iacc_mat_11_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, 
lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_23_31_sp1));\n\n                    __m256i iacc_mat_00_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_01_41_sp1));\n                    __m256i iacc_mat_01_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_01_41_sp1));\n\n                    __m256i iacc_mat_10_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_23_41_sp1));\n                    __m256i iacc_mat_11_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_23_41_sp1));\n\n                    __m256i iacc_mat_00_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_01_51_sp1));\n                    __m256i iacc_mat_01_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_01_51_sp1));\n\n                    __m256i iacc_mat_10_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_23_51_sp1));\n                    __m256i iacc_mat_11_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_23_51_sp1));\n\n                    __m256i iacc_mat_00_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_01_61_sp1));\n                    __m256i iacc_mat_01_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_01_61_sp1));\n\n     
               __m256i iacc_mat_10_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_23_61_sp1));\n                    __m256i iacc_mat_11_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_23_61_sp1));\n\n                    __m256i iacc_mat_00_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_01_71_sp1));\n                    __m256i iacc_mat_01_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_01_71_sp1));\n\n                    __m256i iacc_mat_10_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_23_71_sp1));\n                    __m256i iacc_mat_11_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_23_71_sp1));\n\n\n                    __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2));\n                    __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2));\n\n                    __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2));\n                    __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2));\n\n                    __m256i iacc_mat_00_1_sp2 = 
_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2));\n                    __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2));\n\n                    __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2));\n                    __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2));\n\n                    __m256i iacc_mat_00_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_01_21_sp2));\n                    __m256i iacc_mat_01_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_01_21_sp2));\n\n                    __m256i iacc_mat_10_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_23_21_sp2));\n                    __m256i iacc_mat_11_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_23_21_sp2));\n\n                    __m256i iacc_mat_00_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_01_31_sp2));\n                    __m256i iacc_mat_01_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_01_31_sp2));\n\n                    __m256i iacc_mat_10_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, 
lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_23_31_sp2));\n                    __m256i iacc_mat_11_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_23_31_sp2));\n\n                    __m256i iacc_mat_00_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_01_41_sp2));\n                    __m256i iacc_mat_01_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_01_41_sp2));\n\n                    __m256i iacc_mat_10_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_23_41_sp2));\n                    __m256i iacc_mat_11_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_23_41_sp2));\n\n                    __m256i iacc_mat_00_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_01_51_sp2));\n                    __m256i iacc_mat_01_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_01_51_sp2));\n\n                    __m256i iacc_mat_10_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_23_51_sp2));\n                    __m256i iacc_mat_11_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_23_51_sp2));\n\n                    __m256i iacc_mat_00_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_01_61_sp2));\n       
             __m256i iacc_mat_01_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_01_61_sp2));\n\n                    __m256i iacc_mat_10_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_23_61_sp2));\n                    __m256i iacc_mat_11_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_23_61_sp2));\n\n                    __m256i iacc_mat_00_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_01_71_sp2));\n                    __m256i iacc_mat_01_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_01_71_sp2));\n\n                    __m256i iacc_mat_10_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_23_71_sp2));\n                    __m256i iacc_mat_11_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_23_71_sp2));\n\n                    // Combine results from both shuffle patterns for each output block.\n                    __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);\n                    __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);\n                    __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);\n                    __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);\n\n                    __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);\n                    __m256i iacc_mat_01_1 = 
_mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);\n                    __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);\n                    __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);\n\n                    __m256i iacc_mat_00_2 = _mm256_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2);\n                    __m256i iacc_mat_01_2 = _mm256_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2);\n                    __m256i iacc_mat_10_2 = _mm256_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2);\n                    __m256i iacc_mat_11_2 = _mm256_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2);\n\n                    __m256i iacc_mat_00_3 = _mm256_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2);\n                    __m256i iacc_mat_01_3 = _mm256_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2);\n                    __m256i iacc_mat_10_3 = _mm256_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2);\n                    __m256i iacc_mat_11_3 = _mm256_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2);\n\n                    __m256i iacc_mat_00_4 = _mm256_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2);\n                    __m256i iacc_mat_01_4 = _mm256_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2);\n                    __m256i iacc_mat_10_4 = _mm256_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2);\n                    __m256i iacc_mat_11_4 = _mm256_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2);\n\n                    __m256i iacc_mat_00_5 = _mm256_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2);\n                    __m256i iacc_mat_01_5 = _mm256_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2);\n                    __m256i iacc_mat_10_5 = _mm256_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2);\n                    __m256i iacc_mat_11_5 = _mm256_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2);\n\n                    __m256i iacc_mat_00_6 = _mm256_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2);\n              
      __m256i iacc_mat_01_6 = _mm256_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2);\n                    __m256i iacc_mat_10_6 = _mm256_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2);\n                    __m256i iacc_mat_11_6 = _mm256_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2);\n\n                    __m256i iacc_mat_00_7 = _mm256_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2);\n                    __m256i iacc_mat_01_7 = _mm256_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2);\n                    __m256i iacc_mat_10_7 = _mm256_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2);\n                    __m256i iacc_mat_11_7 = _mm256_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2);\n\n                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block\n                    iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);\n                    iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);\n                    iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);\n                    iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);\n\n                    iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);\n                    iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);\n                    iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);\n                    iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);\n\n                    iacc_mat_00_2 = _mm256_madd_epi16(iacc_mat_00_2, scale_0145_2);\n                    iacc_mat_01_2 = _mm256_madd_epi16(iacc_mat_01_2, scale_2367_2);\n                    iacc_mat_10_2 = _mm256_madd_epi16(iacc_mat_10_2, scale_0145_2);\n                    iacc_mat_11_2 = _mm256_madd_epi16(iacc_mat_11_2, scale_2367_2);\n\n                    iacc_mat_00_3 = _mm256_madd_epi16(iacc_mat_00_3, scale_0145_3);\n                    iacc_mat_01_3 = 
_mm256_madd_epi16(iacc_mat_01_3, scale_2367_3);\n                    iacc_mat_10_3 = _mm256_madd_epi16(iacc_mat_10_3, scale_0145_3);\n                    iacc_mat_11_3 = _mm256_madd_epi16(iacc_mat_11_3, scale_2367_3);\n\n                    iacc_mat_00_4 = _mm256_madd_epi16(iacc_mat_00_4, scale_0145_4);\n                    iacc_mat_01_4 = _mm256_madd_epi16(iacc_mat_01_4, scale_2367_4);\n                    iacc_mat_10_4 = _mm256_madd_epi16(iacc_mat_10_4, scale_0145_4);\n                    iacc_mat_11_4 = _mm256_madd_epi16(iacc_mat_11_4, scale_2367_4);\n\n                    iacc_mat_00_5 = _mm256_madd_epi16(iacc_mat_00_5, scale_0145_5);\n                    iacc_mat_01_5 = _mm256_madd_epi16(iacc_mat_01_5, scale_2367_5);\n                    iacc_mat_10_5 = _mm256_madd_epi16(iacc_mat_10_5, scale_0145_5);\n                    iacc_mat_11_5 = _mm256_madd_epi16(iacc_mat_11_5, scale_2367_5);\n\n                    iacc_mat_00_6 = _mm256_madd_epi16(iacc_mat_00_6, scale_0145_6);\n                    iacc_mat_01_6 = _mm256_madd_epi16(iacc_mat_01_6, scale_2367_6);\n                    iacc_mat_10_6 = _mm256_madd_epi16(iacc_mat_10_6, scale_0145_6);\n                    iacc_mat_11_6 = _mm256_madd_epi16(iacc_mat_11_6, scale_2367_6);\n\n                    iacc_mat_00_7 = _mm256_madd_epi16(iacc_mat_00_7, scale_0145_7);\n                    iacc_mat_01_7 = _mm256_madd_epi16(iacc_mat_01_7, scale_2367_7);\n                    iacc_mat_10_7 = _mm256_madd_epi16(iacc_mat_10_7, scale_0145_7);\n                    iacc_mat_11_7 = _mm256_madd_epi16(iacc_mat_11_7, scale_2367_7);\n\n                    __m256i iacc_mat_00 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm256_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm256_add_epi32(iacc_mat_00_6, iacc_mat_00_7)));\n                    __m256i iacc_mat_01 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_0, 
iacc_mat_01_1), _mm256_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm256_add_epi32(iacc_mat_01_6, iacc_mat_01_7)));\n                    __m256i iacc_mat_10 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm256_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm256_add_epi32(iacc_mat_10_6, iacc_mat_10_7)));\n                    __m256i iacc_mat_11 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm256_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm256_add_epi32(iacc_mat_11_6, iacc_mat_11_7)));\n\n                    // Straighten out to make 4 row vectors\n                    __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);\n                    __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);\n                    __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);\n                    __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);\n\n                    // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes\n                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);\n                    const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);\n\n                    // Multiply with appropriate scales and accumulate (for both d and dmin) below\n                    acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);\n                    acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, 
_mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);\n                    acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);\n                    acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);\n\n                    __m256i lhs_bsums_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1);\n                    __m256i lhs_bsums_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1);\n                    __m256i lhs_bsums_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1);\n                    __m256i lhs_bsums_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1);\n\n                    // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K\n                    __m256i iacc_row_min_0_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 0), mins_01);\n                    __m256i iacc_row_min_1_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 170), mins_01);\n                    __m256i iacc_row_min_2_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 0), mins_01);\n                    __m256i iacc_row_min_3_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 170), mins_01);\n\n                    __m256i iacc_row_min_0_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 85), mins_23);\n                    __m256i iacc_row_min_1_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 255), mins_23);\n                    __m256i iacc_row_min_2_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 85), 
mins_23);\n                    __m256i iacc_row_min_3_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 255), mins_23);\n\n                    __m256i iacc_row_min_0_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 0), mins_45);\n                    __m256i iacc_row_min_1_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 170), mins_45);\n                    __m256i iacc_row_min_2_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 0), mins_45);\n                    __m256i iacc_row_min_3_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 170), mins_45);\n\n                    __m256i iacc_row_min_0_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 85), mins_67);\n                    __m256i iacc_row_min_1_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 255), mins_67);\n                    __m256i iacc_row_min_2_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 85), mins_67);\n                    __m256i iacc_row_min_3_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 255), mins_67);\n\n                    __m256i iacc_row_min_0 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm256_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67));\n                    __m256i iacc_row_min_1 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm256_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67));\n                    __m256i iacc_row_min_2 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm256_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67));\n                    __m256i iacc_row_min_3 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm256_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67));\n\n                    acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), 
acc_min_rows[0]);\n                    acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);\n                    acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);\n                    acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);\n                }\n            }\n            // Store the accumulated values\n            for (int i = 0; i < 4; i++) {\n                _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));\n            }\n        }\n    }\n#else\n\n    ggml_gemm_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);\n\n\n#endif\n}\n"
  },
  {
    "path": "src/ggml-cpu/arch-fallback.h",
    "content": "\n#pragma once\n\n// Rename `_generic` functions if no native implementation is available.\n// This effectively selects the generic implementation.\n\n#if defined(GGML_CPU_GENERIC)\n// quants.c\n#define quantize_row_q8_0_generic quantize_row_q8_0\n#define quantize_row_q8_1_generic quantize_row_q8_1\n#define quantize_row_q8_K_generic quantize_row_q8_K\n#define ggml_vec_dot_q4_0_q8_0_generic ggml_vec_dot_q4_0_q8_0\n#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1\n#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0\n#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1\n#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0\n#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0\n#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0\n#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K\n#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K\n#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K\n#define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K\n#define ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K\n#define ggml_vec_dot_q5_K_q8_K_generic ggml_vec_dot_q5_K_q8_K\n#define ggml_vec_dot_q6_K_q8_K_generic ggml_vec_dot_q6_K_q8_K\n#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K\n#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K\n#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K\n#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K\n#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K\n#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K\n#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K\n#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0\n#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K\n// repack.cpp\n#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4\n#define ggml_quantize_mat_q8_0_4x8_generic 
ggml_quantize_mat_q8_0_4x8\n#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4\n#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8\n#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0\n#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0\n#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0\n#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K\n#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K\n#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K\n#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K\n#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K\n#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K\n#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K\n#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0\n#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0\n#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0\n#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0\n#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0\n#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0\n#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0\n#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0\n#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0\n#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K\n#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K\n#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K\n#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K\n#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K\n#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K\n#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K\n#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0\n#define ggml_gemm_iq4_nl_8x8_q8_0_generic 
ggml_gemm_iq4_nl_8x8_q8_0\n#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0\n#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0\n#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0\n#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0\n#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)\n// repack.cpp\n#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4\n#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8\n#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0\n#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0\n#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K\n#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0\n#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0\n#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K\n#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)\n// quants.c\n#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0\n// repack.cpp\n#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4\n#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4\n#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0\n#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0\n#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K\n#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K\n#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K\n#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K\n#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K\n#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0\n#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0\n#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0\n#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0\n#define 
ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0\n#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0\n#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K\n#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K\n#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K\n#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K\n#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K\n#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0\n#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0\n#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0\n#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0\n#elif defined(__POWERPC__) || defined(__powerpc__)\n// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679\n// quants.c\n#define quantize_row_q8_K_generic quantize_row_q8_K\n#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0\n#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K\n#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K\n#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K\n// repack.cpp\n#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4\n#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8\n#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4\n#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8\n#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0\n#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0\n#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0\n#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K\n#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K\n#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K\n#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K\n#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K\n#define 
ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K\n#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K\n#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0\n#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0\n#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0\n#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0\n#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0\n#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0\n#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0\n#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0\n#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0\n#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K\n#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K\n#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K\n#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K\n#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K\n#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K\n#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K\n#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0\n#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0\n#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0\n#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0\n#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0\n#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0\n#elif defined(__loongarch64)\n// quants.c\n#define quantize_row_q8_K_generic quantize_row_q8_K\n#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K\n#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K\n#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K\n#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0\n#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0\n// 
repack.cpp\n#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4\n#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8\n#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4\n#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8\n#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0\n#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0\n#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0\n#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K\n#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K\n#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K\n#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K\n#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K\n#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K\n#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K\n#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0\n#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0\n#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0\n#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0\n#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0\n#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0\n#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0\n#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0\n#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0\n#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K\n#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K\n#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K\n#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K\n#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K\n#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K\n#define ggml_gemm_q6_K_8x8_q8_K_generic 
ggml_gemm_q6_K_8x8_q8_K\n#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0\n#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0\n#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0\n#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0\n#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0\n#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0\n#elif defined(__riscv)\n// quants.c\n#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0\n// repack.cpp\n#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1\n#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4\n#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1\n#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4\n#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8\n#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0\n#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0\n#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K\n#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K\n#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K\n#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K\n#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K\n#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K\n#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K\n#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0\n#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0\n#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0\n#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0\n#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0\n#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0\n#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0\n#define ggml_gemm_q4_0_4x8_q8_0_generic 
ggml_gemm_q4_0_4x8_q8_0\n#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K\n#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K\n#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K\n#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K\n#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K\n#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K\n#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K\n#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0\n#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0\n#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0\n#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0\n#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0\n#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0\n#elif defined(__s390x__)\n// quants.c\n#define quantize_row_q8_K_generic quantize_row_q8_K\n#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0\n#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K\n#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K\n#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K\n#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K\n#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K\n#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K\n#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K\n#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K\n#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K\n#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K\n// repack.cpp\n#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4\n#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8\n#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4\n#define ggml_quantize_mat_q8_K_4x8_generic 
ggml_quantize_mat_q8_K_4x8\n#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0\n#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0\n#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0\n#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K\n#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K\n#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K\n#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K\n#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K\n#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K\n#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K\n#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0\n#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0\n#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0\n#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0\n#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0\n#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0\n#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0\n#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0\n#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0\n#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K\n#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K\n#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K\n#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K\n#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K\n#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K\n#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K\n#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0\n#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0\n#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0\n#define ggml_gemm_mxfp4_8x8_q8_0_generic 
ggml_gemm_mxfp4_8x8_q8_0\n#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0\n#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0\n#elif defined(__wasm__)\n// quants.c\n#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1\n#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K\n#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K\n#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K\n#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K\n#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K\n#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K\n#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K\n#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K\n#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K\n#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0\n#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K\n#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0\n#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0\n// repack.cpp\n#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4\n#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8\n#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4\n#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8\n#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0\n#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0\n#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0\n#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K\n#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K\n#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K\n#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K\n#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K\n#define ggml_gemv_q6_K_8x4_q8_K_generic 
ggml_gemv_q6_K_8x4_q8_K\n#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K\n#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0\n#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0\n#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0\n#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0\n#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0\n#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0\n#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0\n#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0\n#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0\n#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K\n#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K\n#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K\n#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K\n#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K\n#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K\n#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K\n#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0\n#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0\n#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0\n#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0\n#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0\n#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/binary-ops.cpp",
    "content": "#include \"binary-ops.h\"\n\n#if defined(GGML_USE_ACCELERATE)\n#include <Accelerate/Accelerate.h>\n\nusing vDSP_fn_t = void (*)(const float *, vDSP_Stride, const float *, vDSP_Stride, float *, vDSP_Stride, vDSP_Length);\n#endif\n\nstatic inline float op_add(float a, float b) {\n    return a + b;\n}\n\nstatic inline float op_sub(float a, float b) {\n    return a - b;\n}\n\nstatic inline float op_mul(float a, float b) {\n    return a * b;\n}\n\nstatic inline float op_div(float a, float b) {\n    return a / b;\n}\n\ntemplate <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>\nstatic inline void vec_binary_op_contiguous(const int64_t n, dst_t * z, const src0_t * x, const src1_t * y) {\n    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;\n    constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;\n    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;\n\n    for (int i = 0; i < n; i++) {\n        z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(y[i])));\n    }\n}\n\ntemplate <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>\nstatic inline void vec_binary_op_non_contiguous(const int64_t n, const int64_t ne10, const int64_t nb10, dst_t * z, const src0_t * x, const src1_t * y) {\n    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;\n    constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;\n    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;\n\n    for (int i = 0; i < n; i++) {\n        int i10 = i % ne10;\n        const src1_t * y_ptr = (const src1_t *)((const char *)y + i10*nb10);\n        z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(*y_ptr)));\n    }\n}\n\ntemplate <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>\nstatic void apply_binary_op(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const 
ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    GGML_ASSERT( nb0 == sizeof(dst_t));\n    GGML_ASSERT(nb00 == sizeof(src0_t));\n\n    const auto [ir0, ir1] = get_thread_range(params, src0);\n    const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);\n\n#ifdef GGML_USE_ACCELERATE\n    vDSP_fn_t vDSP_op = nullptr;\n    // TODO - avoid the f32-only check using type 'trait' lookup tables and row-based src-to-float conversion functions\n    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n        if (op == op_add) {\n            vDSP_op = vDSP_vadd;\n        } else if (op == op_sub) {\n            vDSP_op = vDSP_vsub;\n        } else if (op == op_mul) {\n            vDSP_op = vDSP_vmul;\n        } else if (op == op_div) {\n            vDSP_op = vDSP_vdiv;\n        }\n    }\n#endif\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir/(ne02*ne01);\n        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;\n        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);\n\n        const int64_t i13 = i03 % ne13;\n        const int64_t i12 = i02 % ne12;\n        const int64_t i11 = i01 % ne11;\n\n        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );\n        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);\n        const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);\n\n        if (is_src1_contiguous_rows) {\n            // src1 is broadcastable across src0 and dst in i1, i2, i3\n            const int64_t nr0 = ne00 / ne10;\n\n            for (int64_t r = 0; r < nr0; ++r) {\n#ifdef GGML_USE_ACCELERATE\n                if constexpr (std::is_same_v<src0_t, float> && std::is_same_v<src1_t, float> && std::is_same_v<dst_t, 
float>) {\n                    if (vDSP_op != nullptr) {\n                        vDSP_op(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);\n                        continue;\n                    }\n                }\n#endif\n                vec_binary_op_contiguous<op>(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);\n            }\n        } else {\n            vec_binary_op_non_contiguous<op>(ne0, ne10, nb10, dst_ptr, src0_ptr, src1_ptr);\n        }\n    }\n}\n\n// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates\ntemplate <float (*op)(float, float)>\nstatic void binary_op(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    /*  */ if (src0->type == GGML_TYPE_F32  && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32\n        apply_binary_op<op, float, float, float>(params, dst);\n    } else if (src0->type == GGML_TYPE_F16  && src1->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16\n        apply_binary_op<op, ggml_fp16_t, ggml_fp16_t, ggml_fp16_t>(params, dst);\n    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16\n        apply_binary_op<op, ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(params, dst);\n    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_BF16) {\n        apply_binary_op<op, ggml_bf16_t, float, ggml_bf16_t>(params, dst);\n    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) {\n        apply_binary_op<op, ggml_bf16_t, float, float>(params, dst);\n    } else if (src0->type == GGML_TYPE_F16  && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F16) {\n        apply_binary_op<op, ggml_fp16_t, float, ggml_fp16_t>(params, dst);\n    } else if (src0->type == 
GGML_TYPE_F16  && src1->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) {\n        apply_binary_op<op, ggml_fp16_t, float, float>(params, dst);\n    } else {\n        GGML_ABORT(\"%s: unsupported types: dst: %s, src0: %s, src1: %s\\n\", __func__,\n            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));\n    }\n}\n\nvoid ggml_compute_forward_add_non_quantized(const ggml_compute_params * params, ggml_tensor * dst) {\n    binary_op<op_add>(params, dst);\n}\n\nvoid ggml_compute_forward_sub(const ggml_compute_params * params, ggml_tensor * dst) {\n    binary_op<op_sub>(params, dst);\n}\n\nvoid ggml_compute_forward_mul(const ggml_compute_params * params, ggml_tensor * dst) {\n    binary_op<op_mul>(params, dst);\n}\n\nvoid ggml_compute_forward_div(const ggml_compute_params * params, ggml_tensor * dst) {\n    binary_op<op_div>(params, dst);\n}\n"
  },
  {
    "path": "src/ggml-cpu/binary-ops.h",
    "content": "#pragma once\n\n#include \"common.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nvoid ggml_compute_forward_add_non_quantized(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_sub(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_mul(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_div(const struct ggml_compute_params * params, struct ggml_tensor * dst);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/cmake/FindSIMD.cmake",
    "content": "include(CheckCSourceRuns)\n\nset(AVX_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256 a;\n        a = _mm256_set1_ps(0);\n        return 0;\n    }\n\")\n\nset(AVX512_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0);\n        __m512i b = a;\n        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);\n        return 0;\n    }\n\")\n\nset(AVX2_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256i a = {0};\n        a = _mm256_abs_epi16(a);\n        __m256i x;\n        _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code\n        return 0;\n    }\n\")\n\nset(FMA_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256 acc = _mm256_setzero_ps();\n        const __m256 d = _mm256_setzero_ps();\n        const __m256 p = _mm256_setzero_ps();\n        acc = _mm256_fmadd_ps( d, p, acc );\n        return 0;\n    }\n\")\n\nmacro(check_sse type flags)\n    set(__FLAG_I 1)\n    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})\n    foreach (__FLAG ${flags})\n        if (NOT ${type}_FOUND)\n            set(CMAKE_REQUIRED_FLAGS ${__FLAG})\n            check_c_source_runs(\"${${type}_CODE}\" HAS_${type}_${__FLAG_I})\n            if (HAS_${type}_${__FLAG_I})\n                set(${type}_FOUND TRUE CACHE BOOL \"${type} support\")\n                set(${type}_FLAGS \"${__FLAG}\" CACHE STRING \"${type} flags\")\n            endif()\n            math(EXPR __FLAG_I \"${__FLAG_I}+1\")\n   
     endif()\n    endforeach()\n    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})\n\n    if (NOT ${type}_FOUND)\n        set(${type}_FOUND FALSE CACHE BOOL \"${type} support\")\n        set(${type}_FLAGS \"\" CACHE STRING \"${type} flags\")\n    endif()\n\n    mark_as_advanced(${type}_FOUND ${type}_FLAGS)\nendmacro()\n\n# flags are for MSVC only!\ncheck_sse(\"AVX\" \" ;/arch:AVX\")\nif (NOT ${AVX_FOUND})\n    set(GGML_AVX OFF)\nelse()\n    set(GGML_AVX ON)\nendif()\n\ncheck_sse(\"AVX2\" \" ;/arch:AVX2\")\ncheck_sse(\"FMA\" \" ;/arch:AVX2\")\nif ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))\n    set(GGML_AVX2 OFF)\nelse()\n    set(GGML_AVX2 ON)\nendif()\n\ncheck_sse(\"AVX512\" \" ;/arch:AVX512\")\nif (NOT ${AVX512_FOUND})\n    set(GGML_AVX512 OFF)\nelse()\n    set(GGML_AVX512 ON)\nendif()\n"
  },
  {
    "path": "src/ggml-cpu/common.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"traits.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"ggml-impl.h\"\n#include \"simd-mappings.h\"\n\n#define GGML_FA_TILE_Q  64\n#define GGML_FA_TILE_KV 64\n\n#ifdef __cplusplus\n\n#include <utility>\n\n// convenience functions/macros for use in template calls\n// note: these won't be required after the 'traits' lookup table is used.\nstatic inline ggml_fp16_t f32_to_f16(float x) {\n    return GGML_CPU_FP32_TO_FP16(x);\n}\n\nstatic inline float f16_to_f32(ggml_fp16_t x) {\n    return GGML_CPU_FP16_TO_FP32(x);\n}\n\nstatic inline ggml_bf16_t f32_to_bf16(float x) {\n    return GGML_FP32_TO_BF16(x);\n}\n\nstatic inline float bf16_to_f32(ggml_bf16_t x) {\n    return GGML_BF16_TO_FP32(x);\n}\n\nstatic inline float i32_to_f32(int32_t x) {\n    return x;\n}\n\nstatic inline int32_t f32_to_i32(float x) {\n    return x;\n}\n\nstatic inline float f32_to_f32(float x) {\n    return x;\n}\n\n// TODO - merge this into the traits table, after using row-based conversions\ntemplate <class T>\nstruct type_conversion_table;\n\ntemplate <>\nstruct type_conversion_table<ggml_fp16_t> {\n    static constexpr float (*to_f32)(ggml_fp16_t) = f16_to_f32;\n    static constexpr ggml_fp16_t (*from_f32)(float) = f32_to_f16;\n};\n\ntemplate <>\nstruct type_conversion_table<float> {\n    static constexpr float (*to_f32)(float) = f32_to_f32;\n    static constexpr float (*from_f32)(float) = f32_to_f32;\n};\n\ntemplate <>\nstruct type_conversion_table<ggml_bf16_t> {\n    static constexpr float (*to_f32)(ggml_bf16_t) = bf16_to_f32;\n    static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16;\n};\n\ntemplate <>\nstruct type_conversion_table<int32_t> {\n    static constexpr float (*to_f32)(int32_t) = i32_to_f32;\n    static constexpr int32_t (*from_f32)(float) = f32_to_i32;\n};\n\nstatic std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) {\n    const int64_t ith = 
params->ith;\n    const int64_t nth = params->nth;\n\n    const int64_t nr  = ggml_nrows(src0);\n\n    // rows per thread\n    const int64_t dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int64_t ir0 = dr*ith;\n    const int64_t ir1 = MIN(ir0 + dr, nr);\n\n    return {ir0, ir1};\n}\n\nstruct ggml_fa_tile_config {\n    static constexpr size_t Q  = GGML_FA_TILE_Q;\n    static constexpr size_t KV = GGML_FA_TILE_KV;\n};\n\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/ggml-cpu-impl.h",
    "content": "#pragma once\n\n// GGML CPU internal header\n\n#include \"ggml.h\"\n#include \"ggml-impl.h\"\n\n#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/\n//#include <stddef.h>\n#include <stdbool.h>\n#include <string.h> // memcpy\n#include <math.h>   // fabsf\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nstruct ggml_compute_params {\n    // ith = thread index, nth = number of threads\n    int ith, nth;\n\n    // work buffer for all threads\n    size_t wsize;\n    void * wdata;\n\n    struct ggml_threadpool * threadpool;\n\n    // use reference implementation\n    bool use_ref;\n};\n\n\n#if defined(_MSC_VER)\n\n#define m512bh(p) p\n#define m512i(p) p\n\n#else\n\n#define m512bh(p) (__m512bh)(p)\n#define m512i(p) (__m512i)(p)\n\n#endif\n\n// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512\n#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))\n#ifndef __FMA__\n#define __FMA__\n#endif\n#ifndef __F16C__\n#define __F16C__\n#endif\n#endif\n\n// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available\n#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))\n#ifndef __SSE3__\n#define __SSE3__\n#endif\n#ifndef __SSSE3__\n#define __SSSE3__\n#endif\n#endif\n\n#if defined(__s390x__) && defined(__VEC__)\n#ifndef __VXE__\n#define __VXE__\n#endif  // __VXE__\n#ifndef __VXE2__\n#define __VXE2__\n#endif  // __VXE2__\n#endif  // __s390x__ && __VEC__\n\n#if defined(__ARM_FEATURE_SVE) && defined(__linux__)\n#include <sys/prctl.h>\n#endif\n\n#if defined(__ARM_NEON)\n\n// ref: https://github.com/ggml-org/llama.cpp/pull/5404\n#ifdef _MSC_VER\n#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }\n#else\n#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }\n#endif // _MSC_VER\n\n#if !defined(__aarch64__)\n\n// 
32-bit ARM compatibility\n\n// vaddlvq_s16\n// vpaddq_s16\n// vpaddq_s32\n// vaddvq_s32\n// vaddvq_f32\n// vmaxvq_f32\n// vcvtnq_s32_f32\n// vzip1_u8\n// vzip2_u8\n\ninline static int32_t vaddlvq_s16(int16x8_t v) {\n    int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));\n    return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);\n}\n\ninline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {\n    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));\n    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));\n    return vcombine_s16(a0, b0);\n}\n\ninline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {\n    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));\n    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));\n    return vcombine_s32(a0, b0);\n}\n\ninline static int32_t vaddvq_s32(int32x4_t v) {\n    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);\n}\n\ninline static float vaddvq_f32(float32x4_t v) {\n    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);\n}\n\ninline static float vmaxvq_f32(float32x4_t v) {\n    return\n        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),\n            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));\n}\n\ninline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {\n    int32x4_t res;\n\n    res[0] = roundf(vgetq_lane_f32(v, 0));\n    res[1] = roundf(vgetq_lane_f32(v, 1));\n    res[2] = roundf(vgetq_lane_f32(v, 2));\n    res[3] = roundf(vgetq_lane_f32(v, 3));\n\n    return res;\n}\n\ninline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {\n    uint8x8_t res;\n\n    res[0] = a[0]; res[1] = b[0];\n    res[2] = a[1]; res[3] = b[1];\n    res[4] = a[2]; res[5] = b[2];\n    res[6] = a[3]; res[7] = b[3];\n\n    return res;\n}\n\ninline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {\n    uint8x8_t res;\n\n    res[0] = a[4]; res[1] = b[4];\n    res[2] = a[5]; res[3] 
= b[5];\n    res[4] = a[6]; res[5] = b[6];\n    res[6] = a[7]; res[7] = b[7];\n\n    return res;\n}\n\n// vld1q_s16_x2\n// vld1q_u8_x2\n// vld1q_u8_x4\n// vld1q_s8_x2\n// vld1q_s8_x4\n// TODO: double-check these work correctly\n\ntypedef struct ggml_int16x8x2_t {\n    int16x8_t val[2];\n} ggml_int16x8x2_t;\n\ninline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {\n    ggml_int16x8x2_t res;\n\n    res.val[0] = vld1q_s16(ptr + 0);\n    res.val[1] = vld1q_s16(ptr + 8);\n\n    return res;\n}\n\ntypedef struct ggml_uint8x16x2_t {\n    uint8x16_t val[2];\n} ggml_uint8x16x2_t;\n\ninline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {\n    ggml_uint8x16x2_t res;\n\n    res.val[0] = vld1q_u8(ptr + 0);\n    res.val[1] = vld1q_u8(ptr + 16);\n\n    return res;\n}\n\ntypedef struct ggml_uint8x16x4_t {\n    uint8x16_t val[4];\n} ggml_uint8x16x4_t;\n\ninline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {\n    ggml_uint8x16x4_t res;\n\n    res.val[0] = vld1q_u8(ptr + 0);\n    res.val[1] = vld1q_u8(ptr + 16);\n    res.val[2] = vld1q_u8(ptr + 32);\n    res.val[3] = vld1q_u8(ptr + 48);\n\n    return res;\n}\n\ntypedef struct ggml_int8x16x2_t {\n    int8x16_t val[2];\n} ggml_int8x16x2_t;\n\ninline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {\n    ggml_int8x16x2_t res;\n\n    res.val[0] = vld1q_s8(ptr + 0);\n    res.val[1] = vld1q_s8(ptr + 16);\n\n    return res;\n}\n\ntypedef struct ggml_int8x16x4_t {\n    int8x16_t val[4];\n} ggml_int8x16x4_t;\n\ninline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {\n    ggml_int8x16x4_t res;\n\n    res.val[0] = vld1q_s8(ptr + 0);\n    res.val[1] = vld1q_s8(ptr + 16);\n    res.val[2] = vld1q_s8(ptr + 32);\n    res.val[3] = vld1q_s8(ptr + 48);\n\n    return res;\n}\n\n// NOTE: not tested\ninline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {\n    int8x16_t res;\n\n    res[ 0] = a[b[ 0]];\n    res[ 1] = a[b[ 1]];\n    res[ 2] = a[b[ 2]];\n    res[ 3] = 
a[b[ 3]];\n    res[ 4] = a[b[ 4]];\n    res[ 5] = a[b[ 5]];\n    res[ 6] = a[b[ 6]];\n    res[ 7] = a[b[ 7]];\n    res[ 8] = a[b[ 8]];\n    res[ 9] = a[b[ 9]];\n    res[10] = a[b[10]];\n    res[11] = a[b[11]];\n    res[12] = a[b[12]];\n    res[13] = a[b[13]];\n    res[14] = a[b[14]];\n    res[15] = a[b[15]];\n\n    return res;\n}\n\n// NOTE: not tested\ninline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {\n    uint8x16_t res;\n\n    res[ 0] = a[b[ 0]];\n    res[ 1] = a[b[ 1]];\n    res[ 2] = a[b[ 2]];\n    res[ 3] = a[b[ 3]];\n    res[ 4] = a[b[ 4]];\n    res[ 5] = a[b[ 5]];\n    res[ 6] = a[b[ 6]];\n    res[ 7] = a[b[ 7]];\n    res[ 8] = a[b[ 8]];\n    res[ 9] = a[b[ 9]];\n    res[10] = a[b[10]];\n    res[11] = a[b[11]];\n    res[12] = a[b[12]];\n    res[13] = a[b[13]];\n    res[14] = a[b[14]];\n    res[15] = a[b[15]];\n\n    return res;\n}\n\n#else\n\n#define ggml_int16x8x2_t  int16x8x2_t\n#define ggml_uint8x16x2_t uint8x16x2_t\n#define ggml_uint8x16x4_t uint8x16x4_t\n#define ggml_int8x16x2_t  int8x16x2_t\n#define ggml_int8x16x4_t  int8x16x4_t\n\n#define ggml_vld1q_s16_x2 vld1q_s16_x2\n#define ggml_vld1q_u8_x2  vld1q_u8_x2\n#define ggml_vld1q_u8_x4  vld1q_u8_x4\n#define ggml_vld1q_s8_x2  vld1q_s8_x2\n#define ggml_vld1q_s8_x4  vld1q_s8_x4\n#define ggml_vqtbl1q_s8   vqtbl1q_s8\n#define ggml_vqtbl1q_u8   vqtbl1q_u8\n\n#endif // !defined(__aarch64__)\n\n#if !defined(__ARM_FEATURE_DOTPROD)\n\ninline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {\n    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));\n    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));\n\n    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));\n}\n\n#else\n\n#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)\n\n#endif // !defined(__ARM_FEATURE_DOTPROD)\n\n#endif // defined(__ARM_NEON)\n\n#ifdef __wasm_simd128__\n#include <wasm_simd128.h>\n#endif\n\n#ifdef __POWER9_VECTOR__\n#include <altivec.h>\n#endif\n\n#if 
defined(_MSC_VER) || defined(__MINGW32__)\n#include <intrin.h>\n#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__)\n#include <immintrin.h>\n#endif\n\n#ifdef __riscv_v_intrinsic\n#include <riscv_vector.h>\n#endif\n\n#if defined(__loongarch64)\n#if defined(__loongarch_asx)\n#include <lasxintrin.h>\n#endif\n#if defined(__loongarch_sx)\n#include <lsxintrin.h>\n#endif\n#endif\n\n#if defined(__VXE__) || defined(__VXE2__)\n#include <vecintrin.h>\n\n#define vec_neg(a)    (-(a))                // Vector Negate\n#define vec_add(a, b) ((a) + (b))           // Vector Add\n#define vec_sub(a, b) ((a) - (b))           // Vector Subtract\n#define vec_mul(a, b) ((a) * (b))           // Vector Multiply\n#define vec_div(a, b) ((a) / (b))           // Vector Divide\n#define vec_sl(a, b)  ((a) << (b))          // Vector Shift Left\n#define vec_sra(a, b) ((a) >> (b))          // Vector Shift Right Algebraic\n#define vec_sr(a, b)  ((a) >> (b))          // Vector Shift Right\n#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet\n#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet\n\n#ifndef vec_and\n#define vec_and(a, b) ((a) & (b)) // Vector AND\n#endif\n\n#ifndef vec_or\n#define vec_or(a, b)  ((a) | (b)) // Vector OR\n#endif\n\n#ifndef vec_xor\n#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR\n#endif\n\ntypedef signed   char char8x16_t  __attribute__((vector_size(16)));\ntypedef unsigned char uchar8x16_t __attribute__((vector_size(16)));\n\ntypedef int8_t  int8x16_t __attribute__((vector_size(16)));\ntypedef int16_t int16x8_t __attribute__((vector_size(16)));\ntypedef int32_t int32x4_t __attribute__((vector_size(16)));\n\ntypedef uint8_t  uint8x16_t __attribute__((vector_size(16)));\ntypedef uint16_t uint16x8_t __attribute__((vector_size(16)));\ntypedef uint32_t uint32x4_t __attribute__((vector_size(16)));\n\ntypedef 
float  float32x4_t  __attribute__((vector_size(16)));\ntypedef double double64x2_t __attribute__((vector_size(16)));\n\ntypedef signed   long long long64x2_t  __attribute__((vector_size(16)));\ntypedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));\n\ntypedef struct ggml_uint8x16x2_t {\n    uint8x16_t val[2];\n} ggml_uint8x16x2_t;\n\ninline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {\n    ggml_uint8x16x2_t res;\n\n    res.val[0] = vec_xl( 0, ptr);\n    res.val[1] = vec_xl(16, ptr);\n\n    return res;\n}\n\ntypedef struct ggml_uint8x16x4_t {\n    uint8x16_t val[4];\n} ggml_uint8x16x4_t;\n\ninline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {\n    ggml_uint8x16x4_t res;\n\n    res.val[0] = vec_xl( 0, ptr);\n    res.val[1] = vec_xl(16, ptr);\n    res.val[2] = vec_xl(32, ptr);\n    res.val[3] = vec_xl(48, ptr);\n\n    return res;\n}\n\ntypedef struct ggml_int8x16x4_t {\n    int8x16_t val[4];\n} ggml_int8x16x4_t;\n\ninline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {\n    ggml_int8x16x4_t res;\n\n    res.val[0] = vec_xl( 0, ptr);\n    res.val[1] = vec_xl(16, ptr);\n    res.val[2] = vec_xl(32, ptr);\n    res.val[3] = vec_xl(48, ptr);\n\n    return res;\n}\n\ntypedef struct ggml_int16x8x2_t {\n    int16x8_t val[2];\n} ggml_int16x8x2_t;\n\ninline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {\n    ggml_int16x8x2_t res;\n\n    res.val[0] = vec_xl( 0, ptr);\n    res.val[1] = vec_xl(16, ptr);\n\n    return res;\n}\n\n/*\n    ! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs\n    !          
or iq4_nl for example implementation.\n*/\ninline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {\n    int8x16_t res;\n\n    res[ 0] = a[b[ 0]];\n    res[ 1] = a[b[ 1]];\n    res[ 2] = a[b[ 2]];\n    res[ 3] = a[b[ 3]];\n    res[ 4] = a[b[ 4]];\n    res[ 5] = a[b[ 5]];\n    res[ 6] = a[b[ 6]];\n    res[ 7] = a[b[ 7]];\n    res[ 8] = a[b[ 8]];\n    res[ 9] = a[b[ 9]];\n    res[10] = a[b[10]];\n    res[11] = a[b[11]];\n    res[12] = a[b[12]];\n    res[13] = a[b[13]];\n    res[14] = a[b[14]];\n    res[15] = a[b[15]];\n\n    return res;\n}\n\ninline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {\n    const uchar8x16_t v_maske = {  0,  1,  4,  5,  8,  9, 12, 13,\n                                  16, 17, 20, 21, 24, 25, 28, 29 };\n\n    const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);\n    const int16x8_t v_abe = vec_perm(a, b, v_maske);\n    return v_abo + v_abe;\n}\n\n/**\n * @see https://github.com/ggml-org/llama.cpp/pull/14037\n */\ninline static float vec_hsum_f32x4(float32x4_t v) {\n    float32x4_t v_temp = v + vec_reve(v);\n    return v_temp[0] + v_temp[1];\n}\n\ninline static int32_t vec_hsum_i32x4(int32x4_t v) {\n    int32x4_t v_temp = v + vec_reve(v);\n    return v_temp[0] + v_temp[1];\n}\n\ninline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {\n    const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);\n    return acc + (vec_unpackh(p) + vec_unpackl(p));\n}\n\n#endif\n\n#if defined(__loongarch_sx)\n/* float type data load instructions */\nstatic __m128 __lsx_vreplfr2vr_s(const float val) {\n    v4f32 res = {val, val, val, val};\n    return (__m128)res;\n}\n#endif\n\n#if defined(__loongarch_asx)\nstatic __m256 __lasx_xvreplfr2vr_s(const float val) {\n    v8f32 res = {val, val, val, val, val, val, val, val};\n    return (__m256)res;\n}\n#endif\n\n// TODO: move to ggml-threading\nvoid ggml_barrier(struct ggml_threadpool * tp);\n\nvoid ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int 
value);\nint  ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/ggml-cpu.c",
    "content": "#define _CRT_SECURE_NO_DEPRECATE // Disables \"unsafe\" warnings on Windows\n#define _USE_MATH_DEFINES // For M_PI on MSVC\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"traits.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"ggml-impl.h\"\n#include \"quants.h\"\n#include \"ggml-threading.h\"\n#include \"unary-ops.h\"\n#include \"binary-ops.h\"\n#include \"vec.h\"\n#include \"ops.h\"\n#include \"ggml.h\"\n#include \"common.h\"\n\n#if defined(_MSC_VER) || defined(__MINGW32__)\n#include <malloc.h> // using malloc.h with MSC/MINGW\n#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)\n#include <alloca.h>\n#endif\n\n#include <assert.h>\n#include <errno.h>\n#include <time.h>\n#include <math.h>\n#include <stdlib.h>\n#include <string.h>\n#include <stdint.h>\n#include <inttypes.h>\n#include <stdio.h>\n#include <float.h>\n#include <limits.h>\n#include <stdarg.h>\n#include <signal.h>\n#if defined(__gnu_linux__)\n#include <syscall.h>\n#endif\n\n#ifdef GGML_USE_OPENMP\n#include <omp.h>\n#endif\n\n#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)\n#undef GGML_USE_LLAMAFILE\n#endif\n\n#ifdef GGML_USE_LLAMAFILE\n#include \"llamafile/sgemm.h\"\n#endif\n\n// Note: once we move threading into a separate C++ file\n// will use std::hardware_destructive_interference_size instead of hardcoding it here\n// and we'll use C++ attribute syntax.\n#define GGML_CACHE_LINE  64\n\n#if defined(__clang__) || defined(__GNUC__)\n#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))\n#endif\n\n#if defined(__has_feature)\n#if __has_feature(thread_sanitizer)\n#define GGML_TSAN_ENABLED 1\n#endif\n#else  // __has_feature\n#if defined(__SANITIZE_THREAD__)\n#define GGML_TSAN_ENABLED 1\n#endif\n#endif // __has_feature\n\n#define UNUSED GGML_UNUSED\n#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)\n\n// precomputed f32 table for f16 (256 KB) (simd-mappings.h)\nfloat ggml_table_f32_f16[1 
<< 16];\n\n// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)\nfloat ggml_table_f32_e8m0_half[1 << 8];\n\n#if defined(__ARM_ARCH)\nstruct ggml_arm_arch_features_type {\n    int sve_cnt;\n} ggml_arm_arch_features = { 0 };\n#endif\n\n#if defined(__riscv)\nstruct ggml_riscv_arch_features_type {\n    int rvv_vlen;\n} ggml_riscv_arch_features = { 0 };\n#endif\n\n#if defined(_WIN32)\n\n#define WIN32_LEAN_AND_MEAN\n#ifndef NOMINMAX\n    #define NOMINMAX\n#endif\n#include <windows.h>\n\n#if defined(_MSC_VER) && !defined(__clang__)\n#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))\n\ntypedef volatile LONG atomic_int;\ntypedef atomic_int atomic_bool;\ntypedef atomic_int atomic_flag;\n\n#define ATOMIC_FLAG_INIT 0\n\ntypedef enum {\n    memory_order_relaxed,\n    memory_order_consume,\n    memory_order_acquire,\n    memory_order_release,\n    memory_order_acq_rel,\n    memory_order_seq_cst\n} memory_order;\n\nstatic void atomic_store(atomic_int * ptr, LONG val) {\n    InterlockedExchange(ptr, val);\n}\nstatic void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {\n    // TODO: add support for explicit memory order\n    InterlockedExchange(ptr, val);\n}\nstatic LONG atomic_load(atomic_int * ptr) {\n    return InterlockedCompareExchange(ptr, 0, 0);\n}\nstatic LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {\n    // TODO: add support for explicit memory order\n    return InterlockedCompareExchange(ptr, 0, 0);\n}\nstatic LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {\n    return InterlockedExchangeAdd(ptr, inc);\n}\nstatic LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {\n    // TODO: add support for explicit memory order\n    return InterlockedExchangeAdd(ptr, inc);\n}\nstatic atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {\n    return InterlockedExchange(ptr, 1);\n}\nstatic void atomic_flag_clear(atomic_flag * ptr) {\n    InterlockedExchange(ptr, 0);\n}\nstatic void 
atomic_thread_fence(memory_order mo) {\n    MemoryBarrier();\n}\n#else // clang\n#include <stdatomic.h>\n#endif\n\ntypedef HANDLE pthread_t;\n\ntypedef DWORD thread_ret_t;\nstatic int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {\n    (void) unused;\n    HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);\n    if (handle == NULL)\n    {\n        return EAGAIN;\n    }\n\n    *out = handle;\n    return 0;\n}\n\nstatic int pthread_join(pthread_t thread, void * unused) {\n    (void) unused;\n    int ret = (int) WaitForSingleObject(thread, INFINITE);\n    CloseHandle(thread);\n    return ret;\n}\n\nstatic int sched_yield (void) {\n    Sleep (0);\n    return 0;\n}\n#else\n\n#include <pthread.h>\n#include <stdatomic.h>\n#include <sched.h>\n#if defined(__FreeBSD__)\n#include <pthread_np.h>\n#endif\n\ntypedef void * thread_ret_t;\n\n#include <sys/types.h>\n#include <sys/stat.h>\n#include <unistd.h>\n\n#endif\n\ntypedef pthread_t ggml_thread_t;\n\n#define GGML_THREADPOOL_N_THREADS_MASK (0xffffU)\n#define GGML_THREADPOOL_N_THREADS_BITS (16)\n\n#if defined(__APPLE__)\n#include <unistd.h>\n#include <mach/mach.h>\n#include <TargetConditionals.h>\n#endif\n\nstatic const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {\n    [GGML_TYPE_F32] = {\n        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,\n        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,\n        .vec_dot_type             = GGML_TYPE_F32,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_F16] = {\n        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_fp16,\n        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,\n        .vec_dot_type             = GGML_TYPE_F16,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_Q4_0] = {\n        .from_float               = quantize_row_q4_0,\n        .vec_dot                  = 
ggml_vec_dot_q4_0_q8_0,\n        .vec_dot_type             = GGML_TYPE_Q8_0,\n#if defined (__ARM_FEATURE_MATMUL_INT8)\n        .nrows                    = 2,\n#else\n        .nrows                    = 1,\n#endif\n    },\n    [GGML_TYPE_Q4_1] = {\n        .from_float               = quantize_row_q4_1,\n        .vec_dot                  = ggml_vec_dot_q4_1_q8_1,\n        .vec_dot_type             = GGML_TYPE_Q8_1,\n#if defined (__ARM_FEATURE_MATMUL_INT8)\n        .nrows                    = 2,\n#else\n        .nrows                    = 1,\n#endif\n    },\n    [GGML_TYPE_Q5_0] = {\n        .from_float               = quantize_row_q5_0,\n        .vec_dot                  = ggml_vec_dot_q5_0_q8_0,\n        .vec_dot_type             = GGML_TYPE_Q8_0,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_Q5_1] = {\n        .from_float               = quantize_row_q5_1,\n        .vec_dot                  = ggml_vec_dot_q5_1_q8_1,\n        .vec_dot_type             = GGML_TYPE_Q8_1,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_Q8_0] = {\n        .from_float               = quantize_row_q8_0,\n        .vec_dot                  = ggml_vec_dot_q8_0_q8_0,\n        .vec_dot_type             = GGML_TYPE_Q8_0,\n#if defined (__ARM_FEATURE_MATMUL_INT8)\n        .nrows                    = 2,\n#else\n        .nrows                    = 1,\n#endif\n    },\n    [GGML_TYPE_Q8_1] = {\n        .from_float               = quantize_row_q8_1,\n        .vec_dot_type             = GGML_TYPE_Q8_1,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_MXFP4] = {\n        .from_float               = quantize_row_mxfp4,\n        .vec_dot                  = ggml_vec_dot_mxfp4_q8_0,\n        .vec_dot_type             = GGML_TYPE_Q8_0,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_NVFP4] = {\n        .from_float               = quantize_row_nvfp4,\n        .vec_dot                  = ggml_vec_dot_nvfp4_q8_0,\n        .vec_dot_type             
= GGML_TYPE_Q8_0,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_Q2_K] = {\n        .from_float               = quantize_row_q2_K,\n        .vec_dot                  = ggml_vec_dot_q2_K_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_Q3_K] = {\n        .from_float               = quantize_row_q3_K,\n        .vec_dot                  = ggml_vec_dot_q3_K_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_Q4_K] = {\n        .from_float               = quantize_row_q4_K,\n        .vec_dot                  = ggml_vec_dot_q4_K_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n#if defined (__ARM_FEATURE_MATMUL_INT8)\n        .nrows                    = 2,\n#else\n        .nrows                    = 1,\n#endif\n    },\n    [GGML_TYPE_Q5_K] = {\n        .from_float               = quantize_row_q5_K,\n        .vec_dot                  = ggml_vec_dot_q5_K_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_Q6_K] = {\n        .from_float               = quantize_row_q6_K,\n        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n#if defined (__ARM_FEATURE_MATMUL_INT8)\n        .nrows                    = 2,\n#else\n        .nrows                    = 1,\n#endif\n    },\n    [GGML_TYPE_IQ2_XXS] = {\n        .from_float               = NULL,\n        .vec_dot                  = ggml_vec_dot_iq2_xxs_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_IQ2_XS] = {\n        .from_float               = NULL,\n        .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_IQ3_XXS] = {\n        // NOTE: 
from_float for iq3 and iq2_s was removed because these quants require initialization in ggml_quantize_init\n        //.from_float               = quantize_row_iq3_xxs,\n        .vec_dot                  = ggml_vec_dot_iq3_xxs_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_IQ3_S] = {\n        //.from_float               = quantize_row_iq3_s,\n        .vec_dot                  = ggml_vec_dot_iq3_s_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_IQ2_S] = {\n        //.from_float               = quantize_row_iq2_s,\n        .vec_dot                  = ggml_vec_dot_iq2_s_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_IQ1_S] = {\n        .from_float               = NULL,\n        .vec_dot                  = ggml_vec_dot_iq1_s_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_IQ1_M] = {\n        .from_float               = NULL,\n        .vec_dot                  = ggml_vec_dot_iq1_m_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_IQ4_NL] = {\n        .from_float               = quantize_row_iq4_nl,\n        .vec_dot                  = ggml_vec_dot_iq4_nl_q8_0,\n        .vec_dot_type             = GGML_TYPE_Q8_0,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_IQ4_XS] = {\n        .from_float               = quantize_row_iq4_xs,\n        .vec_dot                  = ggml_vec_dot_iq4_xs_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_Q8_K] = {\n        .from_float               = quantize_row_q8_K,\n    },\n    [GGML_TYPE_BF16] = {\n        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_bf16,\n      
  .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_bf16,\n        .vec_dot_type             = GGML_TYPE_BF16,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_TQ1_0] = {\n        .from_float               = quantize_row_tq1_0,\n        .vec_dot                  = ggml_vec_dot_tq1_0_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_TQ2_0] = {\n        .from_float               = quantize_row_tq2_0,\n        .vec_dot                  = ggml_vec_dot_tq2_0_q8_K,\n        .vec_dot_type             = GGML_TYPE_Q8_K,\n        .nrows                    = 1,\n    },\n    [GGML_TYPE_I32] = {\n        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_i32,\n    },\n};\n\nconst struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {\n    return &type_traits_cpu[type];\n}\n\n//\n// Threading defs\n//\n\ntypedef pthread_t          ggml_thread_t;\n\n#if defined(_WIN32)\n\ntypedef CONDITION_VARIABLE ggml_cond_t;\ntypedef SRWLOCK            ggml_mutex_t;\n\n#define ggml_mutex_init(m)   InitializeSRWLock(m)\n#define ggml_mutex_destroy(m)\n#define ggml_mutex_lock(m)   AcquireSRWLockExclusive(m)\n#define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)\n#define ggml_mutex_lock_shared(m)   AcquireSRWLockShared(m)\n#define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)\n\n#define ggml_cond_init(c)    InitializeConditionVariable(c)\n#define ggml_cond_destroy(c)\n#define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)\n#define ggml_cond_broadcast(c) WakeAllConditionVariable(c)\n\n#define ggml_thread_create pthread_create\n#define ggml_thread_join   pthread_join\n\n#else\n\ntypedef pthread_cond_t     ggml_cond_t;\ntypedef pthread_mutex_t    ggml_mutex_t;\n\n#define ggml_mutex_init(m)          pthread_mutex_init(m, NULL)\n#define ggml_mutex_destroy(m)       pthread_mutex_destroy(m)\n#define 
ggml_mutex_lock(m)          pthread_mutex_lock(m)\n#define ggml_mutex_unlock(m)        pthread_mutex_unlock(m)\n#define ggml_mutex_lock_shared(m)   pthread_mutex_lock(m)\n#define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)\n\n#define ggml_lock_init(x)    UNUSED(x)\n#define ggml_lock_destroy(x) UNUSED(x)\n#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))\n#define ggml_lock_lock(x)    _mm_pause()\n#else\n#define ggml_lock_lock(x)    UNUSED(x)\n#endif\n#define ggml_lock_unlock(x)  UNUSED(x)\n\n#define GGML_LOCK_INITIALIZER 0\n#define ggml_cond_init(c)      pthread_cond_init(c, NULL)\n#define ggml_cond_destroy(c)   pthread_cond_destroy(c)\n#define ggml_cond_wait(c, m)   pthread_cond_wait(c, m)\n#define ggml_cond_broadcast(c) pthread_cond_broadcast(c)\n\n#define ggml_thread_create pthread_create\n#define ggml_thread_join   pthread_join\n\n#endif\n\n// Threadpool def\nstruct ggml_threadpool {\n    ggml_mutex_t mutex;       // mutex for cond.var\n    ggml_cond_t  cond;        // cond.var for waiting for new work\n\n    struct ggml_cgraph * cgraph;\n    struct ggml_cplan  * cplan;\n\n    // synchronization primitives\n    atomic_int n_graph;       // updated when there is work to be done (i.e. for each graph); holds both the graph count and the active thread count.\n    atomic_int GGML_CACHE_ALIGN n_barrier;\n    atomic_int GGML_CACHE_ALIGN n_barrier_passed;\n    atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.\n\n    // these are atomic as an annotation for thread-sanitizer\n    atomic_bool stop;         // Used for stopping the threadpool altogether\n    atomic_bool pause;        // Used for pausing the threadpool or individual threads\n    atomic_int  abort;        // Used for aborting processing of a graph\n\n    struct ggml_compute_state * workers;   // per thread state\n    int          n_threads;   // Number of threads in the pool\n    int32_t      prio;        // Scheduling priority\n    
uint32_t     poll;        // Polling level (0 - no polling)\n\n    enum ggml_status ec;\n};\n\n// Per-thread state\nstruct ggml_compute_state {\n#ifndef GGML_USE_OPENMP\n    ggml_thread_t thrd;\n    int  last_graph;\n    bool pending;\n#endif\n    bool cpumask[GGML_MAX_N_THREADS];\n    struct ggml_threadpool * threadpool;\n    int ith;\n};\n\n// Helpers for polling loops\n#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )\nstatic inline void ggml_thread_cpu_relax(void) {\n    __asm__ volatile(\"yield\" ::: \"memory\");\n}\n#elif defined(__x86_64__)\nstatic inline void ggml_thread_cpu_relax(void) {\n    _mm_pause();\n}\n#elif defined(__riscv)\nstatic inline void ggml_thread_cpu_relax(void) {\n    #ifdef __riscv_zihintpause\n        __asm__ __volatile__ (\"pause\");\n    #else\n        /* Encoding of the pause instruction */\n        __asm__ __volatile__ (\".4byte 0x100000F\");\n    #endif\n}\n#else\nstatic inline void ggml_thread_cpu_relax(void) {;}\n#endif\n\n//\n// NUMA support\n//\n\n#define GGML_NUMA_MAX_NODES 8\n#define GGML_NUMA_MAX_CPUS 512\n\nstruct ggml_numa_node {\n    uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node\n    uint32_t n_cpus;\n};\n\nstruct ggml_numa_nodes {\n    enum ggml_numa_strategy numa_strategy;\n    struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];\n    uint32_t n_nodes;\n    uint32_t total_cpus; // hardware threads on system\n    uint32_t current_node; // node on which main process is execting\n#if defined(__gnu_linux__)\n    cpu_set_t cpuset; // cpuset from numactl\n#else\n    uint32_t cpuset; // no NUMA support outside of Linux at this time. 
Use a portable datatype\n#endif\n};\n\n//\n// ggml state\n//\n\nstruct ggml_state {\n    struct ggml_numa_nodes numa;\n};\n\nstatic struct ggml_state g_state = {0};\n\nvoid ggml_barrier(struct ggml_threadpool * tp) {\n    int n_threads = atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK;\n    if (n_threads == 1) {\n        return;\n    }\n\n#ifdef GGML_USE_OPENMP\n    #pragma omp barrier\n#else\n    int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);\n\n    // enter barrier (full seq-cst fence)\n    int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);\n\n    if (n_barrier == (n_threads - 1)) {\n        // last thread\n        atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);\n\n        // exit barrier (full seq-cst fence)\n        atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);\n        return;\n    }\n\n    // wait for other threads\n    while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {\n        ggml_thread_cpu_relax();\n    }\n\n    // exit barrier (full seq-cst fence)\n    // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead\n    #ifdef GGML_TSAN_ENABLED\n    atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);\n    #else\n    atomic_thread_fence(memory_order_seq_cst);\n    #endif\n#endif\n}\n\nvoid ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value) {\n    atomic_store_explicit(&tp->current_chunk, value, memory_order_relaxed);\n}\n\nint ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value) {\n    return atomic_fetch_add_explicit(&tp->current_chunk, value, memory_order_relaxed);\n}\n\n#if defined(__gnu_linux__)\nstatic cpu_set_t ggml_get_numa_affinity(void) {\n    cpu_set_t cpuset;\n    pthread_t thread;\n    thread = pthread_self();\n    CPU_ZERO(&cpuset);\n    pthread_getaffinity_np(thread, 
sizeof(cpu_set_t), &cpuset);\n    return cpuset;\n}\n#else\nstatic uint32_t ggml_get_numa_affinity(void) {\n    return 0; // no NUMA support\n}\n#endif\n\nvoid ggml_numa_init(enum ggml_numa_strategy numa_flag) {\n    if (g_state.numa.n_nodes > 0) {\n        fprintf(stderr, \"ggml_numa_init: NUMA already initialized\\n\");\n\n        return;\n    }\n\n#if defined(__gnu_linux__)\n    struct stat st;\n    char path[256];\n    int rv;\n\n    // set numa scheme\n    g_state.numa.numa_strategy = numa_flag;\n\n    GGML_PRINT_DEBUG(\"numa strategy %u\\n\",g_state.numa.numa_strategy);\n\n    g_state.numa.cpuset = ggml_get_numa_affinity();\n\n    // enumerate nodes\n    while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {\n        rv = snprintf(path, sizeof(path), \"/sys/devices/system/node/node%u\", g_state.numa.n_nodes);\n        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));\n        if (stat(path, &st) != 0) { break; }\n        ++g_state.numa.n_nodes;\n    }\n\n    // enumerate CPUs\n    while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {\n        rv = snprintf(path, sizeof(path), \"/sys/devices/system/cpu/cpu%u\", g_state.numa.total_cpus);\n        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));\n        if (stat(path, &st) != 0) { break; }\n        ++g_state.numa.total_cpus;\n    }\n\n    GGML_PRINT_DEBUG(\"found %u numa nodes, %u CPUs\\n\", g_state.numa.n_nodes, g_state.numa.total_cpus);\n\n    // figure out which node we're on\n    uint current_cpu;\n    int getcpu_ret = 0;\n#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__)\n    getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);\n#else\n    // old glibc doesn't have a wrapper for this call. 
Fall back on direct syscall\n#   if !defined(SYS_getcpu) && defined(SYS_get_cpu)\n#       define SYS_getcpu SYS_get_cpu // some older glibc versions use this name\n#   endif\n    getcpu_ret = syscall(SYS_getcpu, &current_cpu, &g_state.numa.current_node);\n#endif\n\n    if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {\n        g_state.numa.n_nodes = 0;\n        return;\n    }\n\n    GGML_PRINT_DEBUG(\"found our process on numa node %u, CPU %u\\n\", g_state.numa.current_node, current_cpu);\n\n    for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {\n        struct ggml_numa_node * node = &g_state.numa.nodes[n];\n        GGML_PRINT_DEBUG(\"CPUs on node %u:\", n);\n        node->n_cpus = 0;\n        for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {\n            rv = snprintf(path, sizeof(path), \"/sys/devices/system/node/node%u/cpu%u\", n, c);\n            GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));\n            if (stat(path, &st) == 0) {\n                node->cpus[node->n_cpus++] = c;\n                GGML_PRINT_DEBUG(\" %u\", c);\n            }\n        }\n        GGML_PRINT_DEBUG(\"\\n\");\n    }\n\n    if (ggml_is_numa()) {\n        FILE *fptr = fopen(\"/proc/sys/kernel/numa_balancing\", \"r\");\n        if (fptr != NULL) {\n            char buf[42];\n            if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, \"0\\n\", sizeof(buf)) != 0) {\n                GGML_LOG_WARN(\"/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\\n\");\n            }\n            fclose(fptr);\n        }\n    }\n#else\n    UNUSED(numa_flag);\n    // TODO\n#endif\n}\n\nbool ggml_is_numa(void) {\n    return g_state.numa.n_nodes > 1;\n}\n\n#if defined(__ARM_ARCH)\n#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)\n#include <arm_sve.h>\nstatic void ggml_init_arm_arch_features(void) {\n    ggml_arm_arch_features.sve_cnt = svcntb();\n}\n#else\nstatic void ggml_init_arm_arch_features(void) 
{}\n#endif\n#endif // __ARM_ARCH\n\n#if defined(__riscv) && defined(__riscv_v_intrinsic)\n#include <riscv_vector.h>\nstatic void ggml_init_riscv_arch_features(void) {\n    ggml_riscv_arch_features.rvv_vlen = __riscv_vlenb();\n}\n#else\nstatic void ggml_init_riscv_arch_features(void) {}\n#endif\n\nstruct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {\n    GGML_ASSERT(!ggml_get_no_alloc(ctx));\n\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);\n\n    ggml_set_i32(result, value);\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {\n    GGML_ASSERT(!ggml_get_no_alloc(ctx));\n\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n\n    ggml_set_f32(result, value);\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {\n    const int n     = ggml_nrows(tensor);\n    const int nc    = tensor->ne[0];\n    const size_t n1 = tensor->nb[1];\n\n    char * const data = tensor->data;\n\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            {\n                assert(tensor->nb[0] == sizeof(int8_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);\n                }\n            } break;\n        case GGML_TYPE_I16:\n            {\n                assert(tensor->nb[0] == sizeof(int16_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);\n                }\n            } break;\n        case GGML_TYPE_I32:\n            {\n                assert(tensor->nb[0] == sizeof(int32_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);\n                }\n            } break;\n        case GGML_TYPE_F16:\n            {\n                assert(tensor->nb[0] == 
sizeof(ggml_fp16_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_CPU_FP32_TO_FP16(value));\n                }\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                assert(tensor->nb[0] == sizeof(ggml_fp16_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));\n                }\n            } break;\n        case GGML_TYPE_F32:\n            {\n                assert(tensor->nb[0] == sizeof(float));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);\n                }\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n\n    return tensor;\n}\n\nstruct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {\n    const int n     = ggml_nrows(tensor);\n    const int nc    = tensor->ne[0];\n    const size_t n1 = tensor->nb[1];\n\n    char * const data = tensor->data;\n\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            {\n                assert(tensor->nb[0] == sizeof(int8_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);\n                }\n            } break;\n        case GGML_TYPE_I16:\n            {\n                assert(tensor->nb[0] == sizeof(int16_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);\n                }\n            } break;\n        case GGML_TYPE_I32:\n            {\n                assert(tensor->nb[0] == sizeof(int32_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);\n                }\n            } break;\n        case GGML_TYPE_F16:\n   
         {\n                assert(tensor->nb[0] == sizeof(ggml_fp16_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_CPU_FP32_TO_FP16(value));\n                }\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                assert(tensor->nb[0] == sizeof(ggml_bf16_t));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));\n                }\n            } break;\n        case GGML_TYPE_F32:\n            {\n                assert(tensor->nb[0] == sizeof(float));\n                for (int i = 0; i < n; i++) {\n                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);\n                }\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n\n    return tensor;\n}\n\nint32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {\n    if (!ggml_is_contiguous(tensor)) {\n        int64_t id[4] = { 0, 0, 0, 0 };\n        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);\n        return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);\n    }\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));\n                return ((int8_t *)(tensor->data))[i];\n            }\n        case GGML_TYPE_I16:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));\n                return ((int16_t *)(tensor->data))[i];\n            }\n        case GGML_TYPE_I32:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));\n                return ((int32_t *)(tensor->data))[i];\n            }\n        case GGML_TYPE_F16:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));\n                return GGML_CPU_FP16_TO_FP32(((ggml_fp16_t 
*)(tensor->data))[i]);\n            }\n        case GGML_TYPE_BF16:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));\n                return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);\n            }\n        case GGML_TYPE_F32:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(float));\n                return ((float *)(tensor->data))[i];\n            }\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nvoid ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {\n    if (!ggml_is_contiguous(tensor)) {\n        int64_t id[4] = { 0, 0, 0, 0 };\n        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);\n        ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);\n        return;\n    }\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));\n                ((int8_t *)(tensor->data))[i] = value;\n            } break;\n        case GGML_TYPE_I16:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));\n                ((int16_t *)(tensor->data))[i] = value;\n            } break;\n        case GGML_TYPE_I32:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));\n                ((int32_t *)(tensor->data))[i] = value;\n            } break;\n        case GGML_TYPE_F16:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));\n                ((ggml_fp16_t *)(tensor->data))[i] = GGML_CPU_FP32_TO_FP16(value);\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));\n                ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                GGML_ASSERT(tensor->nb[0] == sizeof(float));\n                
((float *)(tensor->data))[i] = value;\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nint32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {\n    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            return ((int8_t *) data)[0];\n        case GGML_TYPE_I16:\n            return ((int16_t *) data)[0];\n        case GGML_TYPE_I32:\n            return ((int32_t *) data)[0];\n        case GGML_TYPE_F16:\n            return GGML_CPU_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);\n        case GGML_TYPE_BF16:\n            return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);\n        case GGML_TYPE_F32:\n            return ((float *) data)[0];\n        default:\n            GGML_ABORT(\"fatal error\");\n    }\n}\n\nvoid ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {\n    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            {\n                ((int8_t *)(data))[0] = value;\n            } break;\n        case GGML_TYPE_I16:\n            {\n                ((int16_t *)(data))[0] = value;\n            } break;\n        case GGML_TYPE_I32:\n            {\n                ((int32_t *)(data))[0] = value;\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ((ggml_fp16_t *)(data))[0] = GGML_CPU_FP32_TO_FP16(value);\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ((float *)(data))[0] = value;\n            } break;\n        default:\n            {\n              
  GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nfloat ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {\n    if (!ggml_is_contiguous(tensor)) {\n        int64_t id[4] = { 0, 0, 0, 0 };\n        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);\n        return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);\n    }\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            {\n                return ((int8_t *)(tensor->data))[i];\n            }\n        case GGML_TYPE_I16:\n            {\n                return ((int16_t *)(tensor->data))[i];\n            }\n        case GGML_TYPE_I32:\n            {\n                return ((int32_t *)(tensor->data))[i];\n            }\n        case GGML_TYPE_F16:\n            {\n                return GGML_CPU_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);\n            }\n        case GGML_TYPE_BF16:\n            {\n                return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);\n            }\n        case GGML_TYPE_F32:\n            {\n                return ((float *)(tensor->data))[i];\n            }\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nvoid ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {\n    if (!ggml_is_contiguous(tensor)) {\n        int64_t id[4] = { 0, 0, 0, 0 };\n        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);\n        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);\n        return;\n    }\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            {\n                ((int8_t *)(tensor->data))[i] = value;\n            } break;\n        case GGML_TYPE_I16:\n            {\n                ((int16_t *)(tensor->data))[i] = value;\n            } break;\n        case GGML_TYPE_I32:\n            {\n                ((int32_t *)(tensor->data))[i] = value;\n            } break;\n        case GGML_TYPE_F16:\n            {\n        
        ((ggml_fp16_t *)(tensor->data))[i] = GGML_CPU_FP32_TO_FP16(value);\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ((float *)(tensor->data))[i] = value;\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nfloat ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {\n    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            return ((int8_t *) data)[0];\n        case GGML_TYPE_I16:\n            return ((int16_t *) data)[0];\n        case GGML_TYPE_I32:\n            return ((int32_t *) data)[0];\n        case GGML_TYPE_F16:\n            return GGML_CPU_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);\n        case GGML_TYPE_BF16:\n            return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);\n        case GGML_TYPE_F32:\n            return ((float *) data)[0];\n        default:\n            GGML_ABORT(\"fatal error\");\n    }\n}\n\nvoid ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {\n    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];\n    switch (tensor->type) {\n        case GGML_TYPE_I8:\n            {\n                ((int8_t *)(data))[0] = value;\n            } break;\n        case GGML_TYPE_I16:\n            {\n                ((int16_t *)(data))[0] = value;\n            } break;\n        case GGML_TYPE_I32:\n            {\n                ((int32_t *)(data))[0] = value;\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ((ggml_fp16_t *)(data))[0] = GGML_CPU_FP32_TO_FP16(value);\n            } 
break;\n        case GGML_TYPE_BF16:\n            {\n                ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ((float *)(data))[0] = value;\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\n// ggml_compute_forward_mul_mat\n\nstatic void ggml_compute_forward_mul_mat_one_chunk(\n    const struct ggml_compute_params * params,\n    struct ggml_tensor * dst,\n    const enum ggml_type type,\n    const int64_t num_rows_per_vec_dot,\n    const int64_t ir0_start,\n    const int64_t ir0_end,\n    const int64_t ir1_start,\n    const int64_t ir1_end) {\n\n    const struct ggml_tensor * src0 = dst->src[0];\n    const struct ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const bool src1_cont = ggml_is_contiguous(src1);\n\n    ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;\n    enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;\n\n    // broadcast factors\n    const int64_t r2 = ne12 / ne02;\n    const int64_t r3 = ne13 / ne03;\n\n    //printf(\"ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\\n\", ir0_start, ir0_end, ir1_start, ir1_end);\n\n    // threads with no work simply yield (not sure if it helps)\n    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {\n        return;\n    }\n\n    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;\n    const size_t row_size = ggml_row_size(vec_dot_type, ne10);\n\n    assert(ne12 % ne02 == 0);\n    assert(ne13 % ne03 == 0);\n\n    // block-tiling attempt\n    const int64_t blck_0 = 16;\n    const int64_t blck_1 = 16;\n\n    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? 
row_size : nb11;\n\n    // attempt to reduce false-sharing (does not seem to make a difference)\n    // 16 * 2, accounting for mmla kernels\n    float tmp[32];\n\n    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {\n        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {\n            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {\n                const int64_t i13 = (ir1 / (ne12 * ne1));\n                const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;\n                const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);\n\n                // broadcast src0 into src1\n                const int64_t i03 = i13 / r3;\n                const int64_t i02 = i12 / r2;\n\n                const int64_t i1 = i11;\n                const int64_t i2 = i12;\n                const int64_t i3 = i13;\n\n                const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);\n\n                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides\n                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using\n                //       the original src1 data pointer, so we should index using the indices directly\n                // TODO: this is a bit of a hack, we should probably have a better way to handle this\n                const char * src1_col = (const char*)wdata +\n                    (src1_cont || src1->type != vec_dot_type\n                        ? 
(i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size\n                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));\n                float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));\n\n                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {\n                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);\n                //}\n\n                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {\n                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);\n                }\n\n                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {\n                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_mul_mat(\n        const struct ggml_compute_params * params,\n              struct ggml_tensor * dst) {\n\n    const struct ggml_tensor * src0 = dst->src[0];\n    const struct ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    enum ggml_type           const vec_dot_type         = type_traits_cpu[src0->type].vec_dot_type;\n    ggml_from_float_t        const from_float           = type_traits_cpu[vec_dot_type].from_float;\n    int64_t                  const vec_dot_num_rows     = type_traits_cpu[src0->type].nrows;\n\n    GGML_ASSERT(ne0 == ne01);\n    GGML_ASSERT(ne1 == ne11);\n    GGML_ASSERT(ne2 == ne12);\n    GGML_ASSERT(ne3 == ne13);\n\n    // we don't support permuted src0 or src1\n    GGML_ASSERT(nb00 == ggml_type_size(src0->type));\n    GGML_ASSERT(nb10 == ggml_type_size(src1->type));\n\n    // dst cannot be 
transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    // nb01 >= nb00 - src0 is not transposed\n    //   compute by src0 rows\n\n    // TODO: extract to \"extra_op\"\n#if GGML_USE_LLAMAFILE\n    // broadcast factors\n    const int64_t r2 = ne12 / ne02;\n    const int64_t r3 = ne13 / ne03;\n\n    const bool src1_cont = ggml_is_contiguous(src1);\n\n    if (src1_cont) {\n        for (int64_t i13 = 0; i13 < ne13; i13++)\n            for (int64_t i12 = 0; i12 < ne12; i12++)\n                if (!llamafile_sgemm(params,\n                                     ne01, ne11, ne00/ggml_blck_size(src0->type),\n                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,\n                                     nb01/ggml_type_size(src0->type),\n                                     (const char *)src1->data + i12*nb12 + i13*nb13,\n                                     nb11/ggml_type_size(src1->type),\n                                     (char *)dst->data + i12*nb2 + i13*nb3,\n                                     nb1/ggml_type_size(dst->type),\n                                     src0->type,\n                                     src1->type,\n                                     dst->type))\n                    goto UseGgmlGemm1;\n        return;\n    }\nUseGgmlGemm1:;\n#endif\n\n    if (src1->type != vec_dot_type) {\n        char * wdata = params->wdata;\n\n        const size_t nbw0 = ggml_type_size(vec_dot_type);\n        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);\n        const size_t nbw2 = nbw1*ne11;\n        const size_t nbw3 = nbw2*ne12;\n\n        assert(params->wsize >= ne13*nbw3);\n        GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n    #if 0\n        for (int64_t i13 = 0; i13 < ne13; ++i13) {\n            for (int64_t i12 = 0; i12 < ne12; ++i12) {\n                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {\n               
     from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),\n                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),\n                                ne10);\n                }\n            }\n        }\n    #else\n        for (int64_t i13 = 0; i13 < ne13; ++i13) {\n            for (int64_t i12 = 0; i12 < ne12; ++i12) {\n                for (int64_t i11 = 0; i11 < ne11; ++i11) {\n                    size_t bs = ggml_blck_size(vec_dot_type);\n                    int64_t ne10_block_start = (ith * ne10/bs) / nth;\n                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;\n                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),\n                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),\n                               (ne10_block_end - ne10_block_start) * bs);\n                }\n            }\n        }\n    #endif\n    }\n\n    if (ith == 0) {\n        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.\n        atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);\n    }\n\n    ggml_barrier(params->threadpool);\n\n#if GGML_USE_LLAMAFILE\n    if (src1->type != vec_dot_type) {\n        const void* wdata = (src1->type == vec_dot_type) ? 
src1->data : params->wdata;\n        const size_t row_size = ggml_row_size(vec_dot_type, ne10);\n\n        for (int64_t i13 = 0; i13 < ne13; i13++)\n            for (int64_t i12 = 0; i12 < ne12; i12++)\n                if (!llamafile_sgemm(params,\n                                     ne01, ne11, ne00/ggml_blck_size(src0->type),\n                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,\n                                     nb01/ggml_type_size(src0->type),\n                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,\n                                     row_size/ggml_type_size(vec_dot_type),\n                                     (char *)dst->data + i12*nb2 + i13*nb3,\n                                     nb1/ggml_type_size(dst->type),\n                                     src0->type,\n                                     vec_dot_type,\n                                     dst->type))\n                    goto UseGgmlGemm2;\n        return;\n    }\nUseGgmlGemm2:;\n#endif\n\n    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)\n    const int64_t nr0 = ne0;\n\n    // This is the size of the rest of the dimensions of the result\n    const int64_t nr1 = ne1 * ne2 * ne3;\n\n    // Now select a reasonable chunk size.\n    int chunk_size = 16;\n\n    // We need to step up the size if it's small\n    if (nr0 == 1 || nr1 == 1) {\n        chunk_size = 64;\n    }\n\n    // distribute the work across the inner or outer loop based on which one is larger\n    // The number of chunks in the 0/1 dim.\n    // CEIL(nr0/chunk_size)\n    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;\n    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;\n\n    // If the chunking is poor for the number of threads on this setup, scrap the whole plan.  
Re-chunk it by thread.\n    //   Also, chunking by thread was measured to have perform better on NUMA systems.  See https://github.com/ggml-org/llama.cpp/pull/6915\n    //   In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.\n    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {\n        // distribute the thread work across the inner or outer loop based on which one is larger\n        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows\n        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows\n    }\n\n    // The number of elements in each chunk\n    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;\n    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;\n\n    // The first chunk comes from our thread_id, the rest will get auto-assigned.\n    int current_chunk = ith;\n\n    while (current_chunk < nchunk0 * nchunk1) {\n        const int64_t ith0 = current_chunk % nchunk0;\n        const int64_t ith1 = current_chunk / nchunk0;\n\n        const int64_t ir0_start = dr0 * ith0;\n        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);\n\n        const int64_t ir1_start = dr1 * ith1;\n        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);\n\n        // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols\n        int64_t num_rows_per_vec_dot = vec_dot_num_rows;\n\n        // these checks are needed to avoid crossing dim1 boundaries\n        // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity\n        if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {\n            num_rows_per_vec_dot = 1;\n        }\n        ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);\n\n        if (nth >= nchunk0 * nchunk1) {\n            break;\n        }\n\n        current_chunk = 
atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);\n    }\n}\n\n// ggml_compute_forward_mul_mat_id\n\n#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]\n\nstruct mmid_row_mapping {\n    int32_t i1;\n    int32_t i2;\n};\n\nstatic void ggml_compute_forward_mul_mat_id_one_chunk(\n    struct ggml_tensor * dst,\n    const struct ggml_tensor * src0,\n    const struct ggml_tensor * src1,\n    const struct ggml_tensor * ids,\n    const int64_t cur_a,\n    const int64_t ir0_start,\n    const int64_t ir0_end,\n    const int64_t ir1_start,\n    const int64_t ir1_end,\n    const char * src0_cur,\n    const struct mmid_row_mapping * matrix_rows,\n    const size_t row_size,\n    const bool src1_cont,\n    const void * wdata) {\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const enum ggml_type type = src0->type;\n\n    ggml_vec_dot_t    const vec_dot      = type_traits_cpu[type].vec_dot;\n    enum ggml_type    const vec_dot_type = type_traits_cpu[type].vec_dot_type;\n\n    const int64_t blck_0 = 16;\n    const int64_t blck_1 = 16;\n\n    float tmp[16];\n\n    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {\n        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {\n            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {\n                const int64_t _i12 = ir1; // logical row index for this expert\n\n                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);\n                const int id       = row_mapping.i1; // selected expert index\n\n                const int64_t  i11 = id % ne11;\n                const int64_t  i12 = row_mapping.i2; // row index in src1\n\n                const int64_t  i1 = id;  // selected expert index\n                const int64_t  i2 = i12; // row\n\n                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides\n                //       if it 
is, then we have either copied the data to params->wdata and made it contiguous or we are using\n                //       the original src1 data pointer, so we should index using the indices directly\n                // TODO: this is a bit of a hack, we should probably have a better way to handle this\n                const char * src1_col = (const char *) wdata +\n                    (src1_cont || src1->type != vec_dot_type\n                    ? (i11      + i12*ne11)*row_size\n                    : (i11*nb11 + i12*nb12));\n\n                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));\n\n                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {\n                    vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);\n                }\n\n                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));\n            }\n        }\n    }\n}\n\nstatic void * incr_ptr_aligned(void ** p, size_t size, size_t align) {\n\n    void * ptr = *p;\n    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);\n    *p = (void *) ((char *) ptr + size);\n    return ptr;\n}\n\nstatic void ggml_compute_forward_mul_mat_id(\n        const struct ggml_compute_params * params,\n              struct ggml_tensor * dst) {\n\n    const struct ggml_tensor * src0 = dst->src[0];\n    const struct ggml_tensor * src1 = dst->src[1];\n    const struct ggml_tensor * ids = dst->src[2];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const enum ggml_type type = src0->type;\n\n    const bool src1_cont = ggml_is_contiguous(src1);\n\n    enum ggml_type    const vec_dot_type    = type_traits_cpu[type].vec_dot_type;\n    ggml_from_float_t const from_float      = type_traits_cpu[vec_dot_type].from_float;\n\n    // we don't support permuted src0 or src1\n    GGML_ASSERT(nb00 == ggml_type_size(type));\n    GGML_ASSERT(nb10 == 
ggml_type_size(src1->type));\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    // row groups\n    const int n_ids = ids->ne[0]; // n_expert_used\n    const int n_as  = ne02;       // n_expert\n\n    void * wdata_cur = params->wdata;\n\n    if (src1->type != vec_dot_type) {\n        incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));\n    }\n\n    int64_t * matrix_row_counts = // [n_as]\n        incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));\n\n    struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]\n        incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));\n\n    char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]\n        incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);\n\n    GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));\n\n    if (src1->type != vec_dot_type) {\n        char * wdata = params->wdata;\n\n        const size_t nbw0 = ggml_type_size(vec_dot_type);\n        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);\n        const size_t nbw2 = nbw1*ne11;\n        const size_t nbw3 = nbw2*ne12;\n\n        assert(params->wsize >= ne13*nbw3);\n        GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n#if 0\n        for (int64_t i13 = 0; i13 < ne13; ++i13) {\n            for (int64_t i12 = ith; i12 < ne12; i12 += nth) {\n                for (int64_t i11 = 0; i11 < ne11; ++i11) {\n                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),\n                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),\n                               ne10);\n                }\n            }\n        }\n#else\n        for (int64_t i13 = 0; i13 < ne13; ++i13) {\n            
for (int64_t i12 = 0; i12 < ne12; ++i12) {\n                for (int64_t i11 = 0; i11 < ne11; ++i11) {\n                    size_t bs = ggml_blck_size(vec_dot_type);\n                    int64_t ne10_block_start = (ith * ne10/bs) / nth;\n                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;\n                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),\n                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),\n                               (ne10_block_end - ne10_block_start) * bs);\n                }\n            }\n        }\n#endif\n    }\n\n    if (ith == 0) {\n        // initialize matrix_row_counts\n        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));\n\n        // group rows by src0 matrix\n        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {\n            for (int id = 0; id < n_ids; ++id) {\n                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);\n\n                assert(i02 >= 0 && i02 < n_as);\n\n                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};\n                matrix_row_counts[i02] += 1;\n            }\n        }\n    }\n\n    // reset current_chunk\n    for (int cur_a = ith; cur_a < n_as; cur_a += nth) {\n        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);\n        *current_chunk_ctr = nth;\n    }\n\n    ggml_barrier(params->threadpool);\n\n    for (int cur_a = 0; cur_a < n_as; ++cur_a) {\n        const int64_t cne1 = matrix_row_counts[cur_a];\n\n        if (cne1 == 0) {\n            continue;\n        }\n\n        const char * src0_cur = (const char *) src0->data + cur_a * nb02;\n        const void * wdata = (src1->type == vec_dot_type) ? 
src1->data : params->wdata;\n        const size_t row_size = ggml_row_size(vec_dot_type, ne10);\n\n        const int64_t nr0 = ne01;\n        const int64_t nr1 = cne1;\n\n        int chunk_size = 16;\n        if (nr0 == 1 || nr1 == 1) {\n            chunk_size = 64;\n        }\n\n        // disable for NUMA\n        const bool disable_chunking = ggml_is_numa();\n\n        int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;\n        int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;\n\n        if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {\n            nchunk0 = nr0 > nr1 ? nth : 1;\n            nchunk1 = nr0 > nr1 ? 1 : nth;\n        }\n\n        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;\n        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;\n\n        int current_chunk = ith;\n\n        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);\n\n        while (current_chunk < nchunk0 * nchunk1) {\n            const int64_t ith0 = current_chunk % nchunk0;\n            const int64_t ith1 = current_chunk / nchunk0;\n\n            const int64_t ir0_start = dr0 * ith0;\n            const int64_t ir0_end = MIN(ir0_start + dr0, nr0);\n\n            const int64_t ir1_start = dr1 * ith1;\n            const int64_t ir1_end = MIN(ir1_start + dr1, nr1);\n\n            ggml_compute_forward_mul_mat_id_one_chunk(\n                dst, src0, src1, ids, cur_a,\n                ir0_start, ir0_end, ir1_start, ir1_end,\n                src0_cur, matrix_rows, row_size, src1_cont, wdata\n            );\n\n            if (nth >= nchunk0 * nchunk1) {\n                break;\n            }\n\n            current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);\n        }\n    }\n}\n\n/////////////////////////////////\n\nstatic void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {\n    GGML_ASSERT(params);\n\n    if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) 
{\n        return;\n    }\n\n    // extra_buffer op?\n    if (ggml_cpu_extra_compute_forward(params, tensor)) {\n        return;\n    }\n\n    switch (tensor->op) {\n        case GGML_OP_DUP:\n            {\n                ggml_compute_forward_dup(params, tensor);\n            } break;\n        case GGML_OP_ADD:\n            {\n                ggml_compute_forward_add(params, tensor);\n            } break;\n        case GGML_OP_ADD_ID:\n            {\n                ggml_compute_forward_add_id(params, tensor);\n            } break;\n        case GGML_OP_ADD1:\n            {\n                ggml_compute_forward_add1(params, tensor);\n            } break;\n        case GGML_OP_ACC:\n            {\n                ggml_compute_forward_acc(params, tensor);\n            } break;\n        case GGML_OP_SUB:\n            {\n                ggml_compute_forward_sub(params, tensor);\n            } break;\n        case GGML_OP_MUL:\n            {\n                ggml_compute_forward_mul(params, tensor);\n            } break;\n        case GGML_OP_DIV:\n            {\n                ggml_compute_forward_div(params, tensor);\n            } break;\n        case GGML_OP_SQR:\n            {\n                ggml_compute_forward_sqr(params, tensor);\n            } break;\n        case GGML_OP_SQRT:\n            {\n                ggml_compute_forward_sqrt(params, tensor);\n            } break;\n        case GGML_OP_LOG:\n            {\n                ggml_compute_forward_log(params, tensor);\n            } break;\n        case GGML_OP_SIN:\n            {\n                ggml_compute_forward_sin(params, tensor);\n            } break;\n        case GGML_OP_COS:\n            {\n                ggml_compute_forward_cos(params, tensor);\n            } break;\n        case GGML_OP_SUM:\n            {\n                ggml_compute_forward_sum(params, tensor);\n            } break;\n        case GGML_OP_SUM_ROWS:\n            {\n                ggml_compute_forward_sum_rows(params, 
tensor);\n            } break;\n        case GGML_OP_CUMSUM:\n            {\n                ggml_compute_forward_cumsum(params, tensor);\n            } break;\n        case GGML_OP_MEAN:\n            {\n                ggml_compute_forward_mean(params, tensor);\n            } break;\n        case GGML_OP_ARGMAX:\n            {\n                ggml_compute_forward_argmax(params, tensor);\n            } break;\n        case GGML_OP_COUNT_EQUAL:\n            {\n                ggml_compute_forward_count_equal(params, tensor);\n            } break;\n        case GGML_OP_REPEAT:\n            {\n                ggml_compute_forward_repeat(params, tensor);\n            } break;\n        case GGML_OP_REPEAT_BACK:\n            {\n                ggml_compute_forward_repeat_back(params, tensor);\n            } break;\n        case GGML_OP_CONCAT:\n            {\n                ggml_compute_forward_concat(params, tensor);\n            } break;\n        case GGML_OP_SILU_BACK:\n            {\n                ggml_compute_forward_silu_back(params, tensor);\n            } break;\n        case GGML_OP_NORM:\n            {\n                ggml_compute_forward_norm(params, tensor);\n            } break;\n        case GGML_OP_RMS_NORM:\n            {\n                ggml_compute_forward_rms_norm(params, tensor);\n            } break;\n        case GGML_OP_RMS_NORM_BACK:\n            {\n                ggml_compute_forward_rms_norm_back(params, tensor);\n            } break;\n        case GGML_OP_GROUP_NORM:\n            {\n                ggml_compute_forward_group_norm(params, tensor);\n            } break;\n        case GGML_OP_L2_NORM:\n            {\n                ggml_compute_forward_l2_norm(params, tensor);\n            } break;\n        case GGML_OP_MUL_MAT:\n            {\n                ggml_compute_forward_mul_mat(params, tensor);\n            } break;\n        case GGML_OP_MUL_MAT_ID:\n            {\n                ggml_compute_forward_mul_mat_id(params, 
tensor);\n            } break;\n        case GGML_OP_OUT_PROD:\n            {\n                ggml_compute_forward_out_prod(params, tensor);\n            } break;\n        case GGML_OP_SCALE:\n            {\n                ggml_compute_forward_scale(params, tensor);\n            } break;\n        case GGML_OP_SET:\n            {\n                ggml_compute_forward_set(params, tensor);\n            } break;\n        case GGML_OP_CPY:\n            {\n                ggml_compute_forward_cpy(params, tensor);\n            } break;\n        case GGML_OP_CONT:\n            {\n                ggml_compute_forward_cont(params, tensor);\n            } break;\n        case GGML_OP_GET_ROWS:\n            {\n                ggml_compute_forward_get_rows(params, tensor);\n            } break;\n        case GGML_OP_GET_ROWS_BACK:\n            {\n                ggml_compute_forward_get_rows_back(params, tensor);\n            } break;\n        case GGML_OP_SET_ROWS:\n            {\n                ggml_compute_forward_set_rows(params, tensor);\n            } break;\n        case GGML_OP_DIAG:\n            {\n                ggml_compute_forward_diag(params, tensor);\n            } break;\n        case GGML_OP_DIAG_MASK_INF:\n            {\n                ggml_compute_forward_diag_mask_inf(params, tensor);\n            } break;\n        case GGML_OP_DIAG_MASK_ZERO:\n            {\n                ggml_compute_forward_diag_mask_zero(params, tensor);\n            } break;\n        case GGML_OP_SOFT_MAX:\n            {\n                ggml_compute_forward_soft_max(params, tensor);\n            } break;\n        case GGML_OP_SOFT_MAX_BACK:\n            {\n                ggml_compute_forward_soft_max_ext_back(params, tensor);\n            } break;\n        case GGML_OP_ROPE:\n            {\n                ggml_compute_forward_rope(params, tensor);\n            } break;\n        case GGML_OP_ROPE_BACK:\n            {\n                ggml_compute_forward_rope_back(params, 
tensor);\n            } break;\n        case GGML_OP_CLAMP:\n            {\n                ggml_compute_forward_clamp(params, tensor);\n            } break;\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            {\n                ggml_compute_forward_conv_transpose_1d(params, tensor);\n            } break;\n        case GGML_OP_IM2COL:\n            {\n                ggml_compute_forward_im2col(params, tensor);\n            } break;\n        case GGML_OP_IM2COL_BACK:\n            {\n                ggml_compute_forward_im2col_back_f32(params, tensor);\n            } break;\n        case GGML_OP_IM2COL_3D:\n            {\n                ggml_compute_forward_im2col_3d(params, tensor);\n            } break;\n        case GGML_OP_CONV_2D:\n            {\n                ggml_compute_forward_conv_2d(params, tensor);\n            } break;\n        case GGML_OP_CONV_3D:\n            {\n                ggml_compute_forward_conv_3d(params, tensor);\n            } break;\n        case GGML_OP_CONV_2D_DW:\n            {\n                ggml_compute_forward_conv_2d_dw(params, tensor);\n            } break;\n        case GGML_OP_CONV_TRANSPOSE_2D:\n            {\n                ggml_compute_forward_conv_transpose_2d(params, tensor);\n            } break;\n        case GGML_OP_POOL_1D:\n            {\n                ggml_compute_forward_pool_1d(params, tensor);\n            } break;\n        case GGML_OP_POOL_2D:\n            {\n                ggml_compute_forward_pool_2d(params, tensor);\n            } break;\n        case GGML_OP_POOL_2D_BACK:\n            {\n                ggml_compute_forward_pool_2d_back(params, tensor);\n            } break;\n        case GGML_OP_UPSCALE:\n            {\n                ggml_compute_forward_upscale(params, tensor);\n            } break;\n        case GGML_OP_PAD:\n            {\n                ggml_compute_forward_pad(params, tensor);\n            } break;\n        case GGML_OP_PAD_REFLECT_1D:\n            {\n                
ggml_compute_forward_pad_reflect_1d(params, tensor);\n            } break;\n        case GGML_OP_ROLL:\n            {\n                ggml_compute_forward_roll(params, tensor);\n            } break;\n        case GGML_OP_ARANGE:\n            {\n                ggml_compute_forward_arange(params, tensor);\n            } break;\n        case GGML_OP_TIMESTEP_EMBEDDING:\n            {\n                ggml_compute_forward_timestep_embedding(params, tensor);\n            } break;\n        case GGML_OP_ARGSORT:\n            {\n                ggml_compute_forward_argsort(params, tensor);\n            } break;\n        case GGML_OP_TOP_K:\n            {\n                ggml_compute_forward_top_k(params, tensor);\n            } break;\n        case GGML_OP_LEAKY_RELU:\n            {\n                ggml_compute_forward_leaky_relu(params, tensor);\n            } break;\n        case GGML_OP_TRI:\n            {\n                ggml_compute_forward_tri(params, tensor);\n            } break;\n        case GGML_OP_FILL:\n            {\n                ggml_compute_forward_fill(params, tensor);\n            } break;\n        case GGML_OP_FLASH_ATTN_EXT:\n            {\n                ggml_compute_forward_flash_attn_ext(params, tensor);\n            } break;\n        case GGML_OP_FLASH_ATTN_BACK:\n            {\n                int32_t t = ggml_get_op_params_i32(tensor, 0);\n                GGML_ASSERT(t == 0 || t == 1);\n                bool masked = t != 0;\n                ggml_compute_forward_flash_attn_back(params, masked, tensor);\n            } break;\n        case GGML_OP_SSM_CONV:\n            {\n                ggml_compute_forward_ssm_conv(params, tensor);\n            } break;\n        case GGML_OP_SSM_SCAN:\n            {\n                ggml_compute_forward_ssm_scan(params, tensor);\n            } break;\n        case GGML_OP_WIN_PART:\n            {\n                ggml_compute_forward_win_part(params, tensor);\n            } break;\n        case 
GGML_OP_WIN_UNPART:\n            {\n                ggml_compute_forward_win_unpart(params, tensor);\n            } break;\n        case GGML_OP_UNARY:\n            {\n                ggml_compute_forward_unary(params, tensor);\n            } break;\n        case GGML_OP_GLU:\n            {\n                ggml_compute_forward_glu(params, tensor);\n            } break;\n        case GGML_OP_GET_REL_POS:\n            {\n                ggml_compute_forward_get_rel_pos(params, tensor);\n            } break;\n        case GGML_OP_ADD_REL_POS:\n            {\n                ggml_compute_forward_add_rel_pos(params, tensor);\n            } break;\n        case GGML_OP_RWKV_WKV6:\n            {\n                ggml_compute_forward_rwkv_wkv6(params, tensor);\n            } break;\n        case GGML_OP_GATED_LINEAR_ATTN:\n            {\n                ggml_compute_forward_gla(params, tensor);\n            } break;\n        case GGML_OP_RWKV_WKV7:\n            {\n                ggml_compute_forward_rwkv_wkv7(params, tensor);\n            } break;\n        case GGML_OP_SOLVE_TRI:\n            {\n                ggml_compute_forward_solve_tri(params, tensor);\n            } break;\n        case GGML_OP_GATED_DELTA_NET:\n            {\n                ggml_compute_forward_gated_delta_net(params, tensor);\n            } break;\n        case GGML_OP_MAP_CUSTOM1:\n            {\n                ggml_compute_forward_map_custom1(params, tensor);\n            }\n            break;\n        case GGML_OP_MAP_CUSTOM2:\n            {\n                ggml_compute_forward_map_custom2(params, tensor);\n            }\n            break;\n        case GGML_OP_MAP_CUSTOM3:\n            {\n                ggml_compute_forward_map_custom3(params, tensor);\n            }\n            break;\n        case GGML_OP_CUSTOM:\n            {\n                ggml_compute_forward_custom(params, tensor);\n            }\n            break;\n        case GGML_OP_CROSS_ENTROPY_LOSS:\n            {\n    
            ggml_compute_forward_cross_entropy_loss(params, tensor);\n            }\n            break;\n        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:\n            {\n                ggml_compute_forward_cross_entropy_loss_back(params, tensor);\n            }\n            break;\n        case GGML_OP_OPT_STEP_ADAMW:\n            {\n                ggml_compute_forward_opt_step_adamw(params, tensor);\n            }\n            break;\n        case GGML_OP_OPT_STEP_SGD:\n            {\n                ggml_compute_forward_opt_step_sgd(params, tensor);\n            }\n            break;\n        case GGML_OP_NONE:\n            {\n                // nop\n            } break;\n        case GGML_OP_RESHAPE:\n            {\n                // nop\n            } break;\n        case GGML_OP_PERMUTE:\n            {\n                // nop\n            } break;\n        case GGML_OP_VIEW:\n            {\n                // nop\n            } break;\n        case GGML_OP_TRANSPOSE:\n            {\n                // nop\n            } break;\n        case GGML_OP_COUNT:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// Android's libc implementation \"bionic\" does not support setting affinity\n#if defined(__gnu_linux__)\nstatic void set_numa_thread_affinity(int thread_n) {\n    if (!ggml_is_numa()) {\n        return;\n    }\n\n    int node_num;\n    int rv;\n    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);\n\n    switch(g_state.numa.numa_strategy) {\n        case GGML_NUMA_STRATEGY_DISTRIBUTE:\n            // run thread on node_num thread_n / (threads per node)\n            node_num = thread_n % g_state.numa.n_nodes;\n            break;\n        case GGML_NUMA_STRATEGY_ISOLATE:\n            // run thread on current_node\n            node_num = g_state.numa.current_node;\n            break;\n        case GGML_NUMA_STRATEGY_NUMACTL:\n            // use the cpuset that numactl gave us\n            rv = 
pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset);\n            if (rv) {\n                fprintf(stderr, \"warning: pthread_setaffinity_np() failed: %s\\n\",strerror(rv));\n            }\n            return;\n        default:\n            return;\n    }\n\n    struct ggml_numa_node * node = &g_state.numa.nodes[node_num];\n\n    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);\n    CPU_ZERO_S(setsize, cpus);\n    for (size_t i = 0; i < node->n_cpus; ++i) {\n        CPU_SET_S(node->cpus[i], setsize, cpus);\n    }\n\n    rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);\n    if (rv) {\n            fprintf(stderr, \"warning: pthread_setaffinity_np() failed: %s\\n\", strerror(rv));\n    }\n\n    CPU_FREE(cpus);\n}\n\nstatic void clear_numa_thread_affinity(void) {\n    if (!ggml_is_numa()) {\n        return;\n    }\n\n    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);\n\n    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);\n    CPU_ZERO_S(setsize, cpus);\n    for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {\n        CPU_SET_S(i, setsize, cpus);\n    }\n\n    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);\n    if (rv) {\n        fprintf(stderr, \"warning: pthread_setaffinity_np() failed: %s\\n\", strerror(rv));\n    }\n\n    CPU_FREE(cpus);\n}\n#else\n// TODO: Windows etc.\n// (the linux implementation may also work on BSD, someone should test)\nstatic void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n);  }\nstatic void clear_numa_thread_affinity(void) {}\n#endif\n\nstatic int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {\n    int n_tasks = 0;\n\n    if (ggml_is_empty(node)) {\n        // no need to multi-thread a no-op\n        n_tasks = 1;\n        return n_tasks;\n    }\n\n    switch (node->op) {\n        case GGML_OP_CPY:\n        case GGML_OP_DUP:\n        case GGML_OP_CONT:\n        case GGML_OP_ADD:\n        case GGML_OP_ADD_ID:\n        case GGML_OP_ADD1:\n    
    case GGML_OP_ACC:\n        case GGML_OP_CUMSUM:\n        case GGML_OP_TRI:\n        case GGML_OP_FILL:\n            {\n                n_tasks = n_threads;\n            } break;\n        case GGML_OP_SUB:\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_LOG:\n        case GGML_OP_SIN:\n        case GGML_OP_COS:\n        case GGML_OP_SUM:\n        case GGML_OP_SUM_ROWS:\n        case GGML_OP_MEAN:\n        case GGML_OP_ARGMAX:\n            {\n                n_tasks = 1;\n            } break;\n        case GGML_OP_COUNT_EQUAL:\n        case GGML_OP_SOLVE_TRI:\n        case GGML_OP_GATED_DELTA_NET:\n            {\n                n_tasks = n_threads;\n            } break;\n        case GGML_OP_REPEAT:\n        case GGML_OP_REPEAT_BACK:\n        case GGML_OP_LEAKY_RELU:\n            {\n                n_tasks = 1;\n            } break;\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(node)) {\n                case GGML_UNARY_OP_ABS:\n                case GGML_UNARY_OP_SGN:\n                case GGML_UNARY_OP_NEG:\n                case GGML_UNARY_OP_STEP:\n                case GGML_UNARY_OP_TANH:\n                case GGML_UNARY_OP_ELU:\n                case GGML_UNARY_OP_RELU:\n                case GGML_UNARY_OP_SIGMOID:\n                case GGML_UNARY_OP_HARDSWISH:\n                case GGML_UNARY_OP_HARDSIGMOID:\n                case GGML_UNARY_OP_EXP:\n                case GGML_UNARY_OP_SOFTPLUS:\n                case GGML_UNARY_OP_EXPM1:\n                case GGML_UNARY_OP_FLOOR:\n                case GGML_UNARY_OP_CEIL:\n                case GGML_UNARY_OP_ROUND:\n                case GGML_UNARY_OP_TRUNC:\n                    {\n                        n_tasks = 1;\n                    } break;\n\n                case GGML_UNARY_OP_GELU:\n                case GGML_UNARY_OP_GELU_ERF:\n                case GGML_UNARY_OP_GELU_QUICK:\n                case GGML_UNARY_OP_SILU:\n                case 
GGML_UNARY_OP_XIELU:\n                    {\n                        n_tasks = n_threads;\n                    } break;\n                default:\n                    GGML_ABORT(\"fatal error\");\n            }\n            break;\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(node)) {\n                case GGML_GLU_OP_REGLU:\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_SWIGLU:\n                case GGML_GLU_OP_SWIGLU_OAI:\n                case GGML_GLU_OP_GEGLU_ERF:\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    {\n                        n_tasks = n_threads;\n                    } break;\n                default:\n                    GGML_ABORT(\"fatal error\");\n            }\n            break;\n        case GGML_OP_SILU_BACK:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n        case GGML_OP_NORM:\n        case GGML_OP_RMS_NORM:\n        case GGML_OP_RMS_NORM_BACK:\n        case GGML_OP_L2_NORM:\n        case GGML_OP_GROUP_NORM:\n        case GGML_OP_CONCAT:\n        case GGML_OP_MUL_MAT:\n        case GGML_OP_MUL_MAT_ID:\n        case GGML_OP_OUT_PROD:\n            {\n                n_tasks = n_threads;\n            } break;\n        case GGML_OP_GET_ROWS:\n        case GGML_OP_SET_ROWS:\n            {\n                // FIXME: get_rows can use additional threads, but the cost of launching additional threads\n                // decreases performance with GPU offloading\n                //n_tasks = n_threads;\n                n_tasks = 1;\n            } break;\n        case GGML_OP_SCALE:\n        case GGML_OP_SET:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_GET_ROWS_BACK:\n        case GGML_OP_DIAG:\n            {\n                n_tasks = 1;\n            } break;\n        case GGML_OP_DIAG_MASK_ZERO:\n        case GGML_OP_DIAG_MASK_INF:\n        case GGML_OP_SOFT_MAX_BACK:\n       
 case GGML_OP_ROPE:\n        case GGML_OP_ROPE_BACK:\n        case GGML_OP_ADD_REL_POS:\n            {\n                n_tasks = n_threads;\n            } break;\n        case GGML_OP_CLAMP:\n            {\n                n_tasks = 1; //TODO\n            } break;\n        case GGML_OP_SOFT_MAX:\n            {\n                n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));\n            } break;\n        case GGML_OP_IM2COL:\n        case GGML_OP_IM2COL_BACK:\n        case GGML_OP_IM2COL_3D:\n        case GGML_OP_CONV_2D:\n        case GGML_OP_CONV_3D:\n        case GGML_OP_CONV_2D_DW:\n        case GGML_OP_CONV_TRANSPOSE_1D:\n        case GGML_OP_CONV_TRANSPOSE_2D:\n            {\n                n_tasks = n_threads;\n            } break;\n        case GGML_OP_POOL_1D:\n        case GGML_OP_POOL_2D:\n        case GGML_OP_POOL_2D_BACK:\n            {\n                n_tasks = 1;\n            } break;\n        case GGML_OP_UPSCALE:\n        case GGML_OP_PAD:\n        case GGML_OP_PAD_REFLECT_1D:\n        case GGML_OP_ROLL:\n        case GGML_OP_ARANGE:\n        case GGML_OP_TIMESTEP_EMBEDDING:\n        case GGML_OP_ARGSORT:\n        case GGML_OP_TOP_K:\n        case GGML_OP_FLASH_ATTN_EXT:\n        case GGML_OP_FLASH_ATTN_BACK:\n        case GGML_OP_SSM_CONV:\n        case GGML_OP_SSM_SCAN:\n        case GGML_OP_RWKV_WKV6:\n        case GGML_OP_GATED_LINEAR_ATTN:\n        case GGML_OP_RWKV_WKV7:\n            {\n                n_tasks = n_threads;\n            } break;\n        case GGML_OP_WIN_PART:\n        case GGML_OP_WIN_UNPART:\n        case GGML_OP_GET_REL_POS:\n            {\n                n_tasks = 1;\n            } break;\n        case GGML_OP_MAP_CUSTOM1:\n            {\n                struct ggml_map_custom1_op_params p;\n                memcpy(&p, node->op_params, sizeof(p));\n                if (p.n_tasks == GGML_N_TASKS_MAX) {\n                    n_tasks = n_threads;\n                } else {\n                    n_tasks = MIN(p.n_tasks, 
n_threads);\n                }\n            } break;\n        case GGML_OP_MAP_CUSTOM2:\n            {\n                struct ggml_map_custom2_op_params p;\n                memcpy(&p, node->op_params, sizeof(p));\n                if (p.n_tasks == GGML_N_TASKS_MAX) {\n                    n_tasks = n_threads;\n                } else {\n                    n_tasks = MIN(p.n_tasks, n_threads);\n                }\n            } break;\n        case GGML_OP_MAP_CUSTOM3:\n            {\n                struct ggml_map_custom3_op_params p;\n                memcpy(&p, node->op_params, sizeof(p));\n                if (p.n_tasks == GGML_N_TASKS_MAX) {\n                    n_tasks = n_threads;\n                } else {\n                    n_tasks = MIN(p.n_tasks, n_threads);\n                }\n            } break;\n        case GGML_OP_CUSTOM:\n            {\n                struct ggml_custom_op_params p;\n                memcpy(&p, node->op_params, sizeof(p));\n                if (p.n_tasks == GGML_N_TASKS_MAX) {\n                    n_tasks = n_threads;\n                } else {\n                    n_tasks = MIN(p.n_tasks, n_threads);\n                }\n            } break;\n        case GGML_OP_CROSS_ENTROPY_LOSS:\n        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:\n        case GGML_OP_OPT_STEP_ADAMW:\n        case GGML_OP_OPT_STEP_SGD:\n            {\n                n_tasks = n_threads;\n            } break;\n        case GGML_OP_NONE:\n            {\n                n_tasks = 1;\n            } break;\n        case GGML_OP_COUNT:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n        default:\n            {\n                fprintf(stderr, \"%s: op not implemented: \", __func__);\n                if (node->op < GGML_OP_COUNT) {\n                    fprintf(stderr, \"%s\\n\", ggml_op_name(node->op));\n                } else {\n                    fprintf(stderr, \"%d\\n\", node->op);\n                }\n                GGML_ABORT(\"fatal 
error\");\n            }\n    }\n\n    assert(n_tasks > 0);\n\n    return n_tasks;\n}\n\nstatic thread_ret_t ggml_graph_compute_secondary_thread(void* data);\n\n#if defined(_WIN32)\n#include \"windows.h\"\n\n// TODO: support > 64 CPUs\nstatic bool ggml_thread_apply_affinity(bool * mask) {\n    HANDLE    h = GetCurrentThread();\n    uint64_t  bitmask = 0ULL;\n\n    assert(GGML_MAX_N_THREADS >= 64);\n\n    for (int32_t i = 0; i < 8; i++) {\n        int32_t idx = i * 8;\n        uint8_t val = 0;\n        val |= mask[idx + 0] << 0;\n        val |= mask[idx + 1] << 1;\n        val |= mask[idx + 2] << 2;\n        val |= mask[idx + 3] << 3;\n        val |= mask[idx + 4] << 4;\n        val |= mask[idx + 5] << 5;\n        val |= mask[idx + 6] << 6;\n        val |= mask[idx + 7] << 7;\n        bitmask |= (uint64_t)val << idx;\n    }\n\n    for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) {\n        if (mask[i]) {\n            fprintf(stderr, \"warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\\n\");\n            break;\n        }\n    }\n\n    DWORD_PTR m = (DWORD_PTR)bitmask;\n\n    m = SetThreadAffinityMask(h, m);\n\n    return m != 0;\n}\n\nstatic bool ggml_thread_apply_priority(int32_t prio) {\n    // Note that on Windows the Process Priority Class must be updated in order to set Thread priority.\n    // This is up to the applications.\n    DWORD p = THREAD_PRIORITY_NORMAL;\n    switch (prio) {\n        case GGML_SCHED_PRIO_LOW:      p = THREAD_PRIORITY_BELOW_NORMAL;  break;\n        case GGML_SCHED_PRIO_NORMAL:   p = THREAD_PRIORITY_NORMAL;        break;\n        case GGML_SCHED_PRIO_MEDIUM:   p = THREAD_PRIORITY_ABOVE_NORMAL;  break;\n        case GGML_SCHED_PRIO_HIGH:     p = THREAD_PRIORITY_HIGHEST;       break;\n        case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;\n    }\n\n    if (prio != GGML_SCHED_PRIO_LOW) {\n        // Tell Windows that this thread should not be throttled (needs its own CPU core).\n        // 
Newer Windows 11 versions aggressively park (offline) CPU cores and often place\n        // all our threads onto the first 4 cores which results in terrible performance with\n        // n_threads > 4\n        #if _WIN32_WINNT >= 0x0602\n        THREAD_POWER_THROTTLING_STATE t;\n        ZeroMemory(&t, sizeof(t));\n        t.Version     = THREAD_POWER_THROTTLING_CURRENT_VERSION;\n        t.ControlMask = THREAD_POWER_THROTTLING_EXECUTION_SPEED;\n        t.StateMask   = 0;\n\n        if (!SetThreadInformation(GetCurrentThread(), ThreadPowerThrottling, &t, sizeof(t))) {\n            GGML_LOG_DEBUG(\"failed to disable thread power throttling %d : (%d)\\n\", prio, (int) GetLastError());\n            return false;\n        }\n        #endif\n    }\n\n    if (prio == GGML_SCHED_PRIO_NORMAL) {\n        // Keep inherited policy/priority\n        return true;\n    }\n\n    if (!SetThreadPriority(GetCurrentThread(), p)) {\n        fprintf(stderr, \"warn: failed to set thread priority %d : (%d)\\n\", prio, (int) GetLastError());\n        return false;\n    }\n\n    return true;\n}\n\n#elif defined(__APPLE__)\n#include <sys/types.h>\n#include <sys/resource.h>\n\nstatic bool ggml_thread_apply_affinity(const bool * mask) {\n    // Not supported on Apple platforms\n    UNUSED(mask);\n    return true;\n}\n\nstatic bool ggml_thread_apply_priority(int32_t prio) {\n    struct sched_param p;\n    int32_t policy = SCHED_OTHER;\n    switch (prio) {\n        // TODO: there seems to be no way to set lower prio on Apple platforms\n        case GGML_SCHED_PRIO_LOW:      policy = SCHED_OTHER; p.sched_priority = 0;  break;\n        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;\n        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;\n        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;\n        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;\n    }\n\n  
  if (prio == GGML_SCHED_PRIO_NORMAL) {\n        // Keep inherited policy/priority\n        return true;\n    }\n\n    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);\n    if (err != 0) {\n        fprintf(stderr, \"warn: failed to set thread priority %d : %s (%d)\\n\", prio, strerror(err), err);\n        return false;\n    }\n\n    return true;\n}\n\n#elif defined(__gnu_linux__)\n// TODO: this may not work on BSD, to be verified\n\nstatic bool ggml_thread_apply_affinity(const bool * mask) {\n    cpu_set_t cpuset;\n    int err;\n\n    CPU_ZERO(&cpuset);\n\n    for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) {\n        if (mask[i]) {\n            GGML_PRINT_DEBUG(\"Thread %lx: adding %d to cpuset\\n\", pthread_self(), i);\n            CPU_SET(i, &cpuset);\n        }\n    }\n\n#ifdef __ANDROID__\n    err = sched_setaffinity(0, sizeof(cpuset), &cpuset);\n    if (err < 0) {\n        err = errno;\n    }\n#else\n    err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);\n#endif\n    if (err != 0) {\n        fprintf(stderr, \"warn: failed to set affinity mask 0x%llx : %s (%d)\\n\", (unsigned long long)mask, strerror(err), err);\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_thread_apply_priority(int32_t prio) {\n    struct sched_param p;\n    int32_t policy = SCHED_OTHER;\n    switch (prio) {\n        case GGML_SCHED_PRIO_LOW:      policy = SCHED_BATCH; p.sched_priority = 0;  break;\n        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;\n        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;\n        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;\n        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;\n    }\n\n    if (prio == GGML_SCHED_PRIO_NORMAL) {\n        // Keep inherited policy/priority\n        return true;\n    }\n\n    int32_t err = 
pthread_setschedparam(pthread_self(), policy, &p);\n    if (err != 0) {\n        fprintf(stderr, \"warn: failed to set thread priority %d : %s (%d)\\n\", prio, strerror(err), err);\n        return false;\n    }\n\n    return true;\n}\n\n#else // unsupported platforms\n\nstatic bool ggml_thread_apply_affinity(const bool * mask) {\n    UNUSED(mask);\n    return true;\n}\n\nstatic bool ggml_thread_apply_priority(int32_t prio) {\n    UNUSED(prio);\n    return true;\n}\n\n#endif\n\nstatic bool ggml_thread_cpumask_is_valid(const bool * mask) {\n    for (int i = 0; i < GGML_MAX_N_THREADS; i++) {\n        if (mask[i]) { return true; }\n    }\n    return false;\n}\n\nstatic void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {\n    if (!strict) {\n        memcpy(local_mask, global_mask, GGML_MAX_N_THREADS);\n        return;\n    } else {\n        memset(local_mask, 0, GGML_MAX_N_THREADS);\n        int32_t base_idx = *iter;\n        for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {\n            int32_t idx = base_idx + i;\n            if (idx >= GGML_MAX_N_THREADS) {\n                // Just a cheaper modulo\n                idx -= GGML_MAX_N_THREADS;\n            }\n            if (global_mask[idx]) {\n                local_mask[idx] = 1;\n                *iter = idx + 1;\n                return;\n            }\n        }\n    }\n}\n\nvoid ggml_threadpool_free(struct ggml_threadpool* threadpool) {\n    if (!threadpool) return;\n\n    const int n_threads = threadpool->n_threads;\n\n#ifndef GGML_USE_OPENMP\n    struct ggml_compute_state* workers = threadpool->workers;\n\n    ggml_mutex_lock(&threadpool->mutex);\n\n    threadpool->stop = true;\n    threadpool->pause = false;\n\n    ggml_cond_broadcast(&threadpool->cond);\n    ggml_mutex_unlock(&threadpool->mutex);\n\n    for (int j = 1; j < n_threads; j++) {\n        int32_t rc = ggml_thread_join(workers[j].thrd, NULL);\n        GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == 
GGML_EXIT_ABORTED);\n        UNUSED(rc);\n    }\n\n    ggml_mutex_destroy(&threadpool->mutex);\n    ggml_cond_destroy(&threadpool->cond);\n#endif // GGML_USE_OPENMP\n\n    const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;\n    ggml_aligned_free(threadpool->workers, workers_size);\n    ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));\n}\n\n#ifndef GGML_USE_OPENMP\n// pause/resume must be called under mutex\nstatic void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) {\n    GGML_PRINT_DEBUG(\"Pausing threadpool\\n\");\n    threadpool->pause = true;\n    ggml_cond_broadcast(&threadpool->cond);\n}\n\nstatic void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) {\n    GGML_PRINT_DEBUG(\"Resuming threadpool\\n\");\n    threadpool->pause = false;\n    ggml_cond_broadcast(&threadpool->cond);\n}\n#endif\n\nvoid ggml_threadpool_pause(struct ggml_threadpool * threadpool) {\n#ifndef GGML_USE_OPENMP\n    ggml_mutex_lock(&threadpool->mutex);\n    if (!threadpool->pause) {\n       ggml_threadpool_pause_locked(threadpool);\n    }\n    ggml_mutex_unlock(&threadpool->mutex);\n#else\n    UNUSED(threadpool);\n#endif\n}\n\nvoid ggml_threadpool_resume(struct ggml_threadpool * threadpool) {\n#ifndef GGML_USE_OPENMP\n    ggml_mutex_lock(&threadpool->mutex);\n    if (threadpool->pause) {\n       ggml_threadpool_resume_locked(threadpool);\n    }\n    ggml_mutex_unlock(&threadpool->mutex);\n#else\n    UNUSED(threadpool);\n#endif\n}\n\nstruct ggml_cplan ggml_graph_plan(\n          const struct ggml_cgraph * cgraph,\n                               int   n_threads,\n            struct ggml_threadpool * threadpool) {\n\n    if (threadpool == NULL) {\n        //GGML_PRINT_DEBUG(\"Threadpool is not specified. Will create a disposable threadpool : n_threads %d\\n\", n_threads);\n    }\n    if (n_threads <= 0) {\n        n_threads = threadpool ? 
threadpool->n_threads : GGML_DEFAULT_N_THREADS;\n    }\n\n#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)\n    // Emscripten without pthreads support can only use a single thread\n    n_threads = 1;\n#endif\n\n    size_t work_size = 0;\n\n    struct ggml_cplan cplan;\n    memset(&cplan, 0, sizeof(struct ggml_cplan));\n\n    int max_tasks = 1;\n\n    // thread scheduling for the different operations + work buffer size estimation\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        struct ggml_tensor * node = cgraph->nodes[i];\n\n        const int n_tasks = ggml_get_n_tasks(node, n_threads);\n\n        max_tasks = MAX(max_tasks, n_tasks);\n\n        size_t cur = 0;\n\n        if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {\n            switch (node->op) {\n                case GGML_OP_CPY:\n                case GGML_OP_DUP:\n                    {\n                        if (ggml_is_quantized(node->type) ||\n                            // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32\n                            (node->src[0]->type == GGML_TYPE_F16  && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||\n                            (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||\n                            // conversion between F32 and I32\n                            (node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||\n                            (node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {\n                            cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;\n                        }\n                    } break;\n                case GGML_OP_ADD:\n                case GGML_OP_ADD_ID:\n                case GGML_OP_ADD1:\n                    {\n                        if (ggml_is_quantized(node->src[0]->type)) {\n                            cur = 
ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;\n                        }\n                    } break;\n                case GGML_OP_ACC:\n                    {\n                        if (ggml_is_quantized(node->src[0]->type)) {\n                            cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;\n                        }\n                    } break;\n                case GGML_OP_COUNT_EQUAL:\n                    {\n                        cur = ggml_type_size(node->type)*n_tasks;\n                    } break;\n                case GGML_OP_MUL_MAT:\n                    {\n                        const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;\n\n                        if (node->src[1]->type != vec_dot_type) {\n                            cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));\n                        }\n                    } break;\n                case GGML_OP_MUL_MAT_ID:\n                    {\n                        cur = 0;\n                        const struct ggml_tensor * src0 = node->src[0];\n                        const struct ggml_tensor * src1 = node->src[1];\n                        const struct ggml_tensor * ids = node->src[2];\n                        const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;\n                        const int n_as = src0->ne[2];\n                        // src1\n                        if (src1->type != vec_dot_type) {\n                            cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);\n                        }\n                        // matrix_row_counts\n                        cur += n_as * sizeof(int64_t) + sizeof(int64_t);\n                        // matrix_rows\n                        cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);\n                        // atomic_current_chunk\n                      
  cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;\n                    } break;\n                case GGML_OP_OUT_PROD:\n                    {\n                        if (ggml_is_quantized(node->src[0]->type)) {\n                            cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;\n                        }\n                    } break;\n                case GGML_OP_SOFT_MAX:\n                case GGML_OP_ROPE:\n                case GGML_OP_ROPE_BACK:\n                    {\n                        cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;\n                    } break;\n                case GGML_OP_CONV_TRANSPOSE_1D:\n                    {\n                        GGML_ASSERT(node->src[0]->ne[3] == 1);\n                        GGML_ASSERT(node->src[1]->ne[2] == 1);\n                        GGML_ASSERT(node->src[1]->ne[3] == 1);\n\n                        const int64_t ne00 = node->src[0]->ne[0];  // K\n                        const int64_t ne01 = node->src[0]->ne[1];  // Cout\n                        const int64_t ne02 = node->src[0]->ne[2];  // Cin\n                        const int64_t ne10 = node->src[1]->ne[0];  // L\n                        const int64_t ne11 = node->src[1]->ne[1];  // Cin\n\n                        if ((node->src[0]->type == GGML_TYPE_F16 ||\n                             node->src[0]->type == GGML_TYPE_BF16) &&\n                            node->src[1]->type == GGML_TYPE_F32) {\n                            cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;\n                            cur += sizeof(ggml_fp16_t)*ne10*ne11;\n                        } else if (node->src[0]->type == GGML_TYPE_F32 &&\n                                   node->src[1]->type == GGML_TYPE_F32) {\n                            cur += sizeof(float)*ne00*ne01*ne02;\n                            cur += sizeof(float)*ne10*ne11;\n                        } else {\n                            GGML_ABORT(\"fatal error\");\n                    
    }\n                    } break;\n                case GGML_OP_CONV_2D:\n                case GGML_OP_CONV_3D:\n                    {\n                        cur = GGML_IM2COL_WORK_SIZE;\n                    } break;\n                case GGML_OP_CONV_TRANSPOSE_2D:\n                    {\n                        const int64_t ne00 = node->src[0]->ne[0]; // W\n                        const int64_t ne01 = node->src[0]->ne[1]; // H\n                        const int64_t ne02 = node->src[0]->ne[2]; // Channels Out\n                        const int64_t ne03 = node->src[0]->ne[3]; // Channels In\n\n                        const int64_t ne10 = node->src[1]->ne[0]; // W\n                        const int64_t ne11 = node->src[1]->ne[1]; // H\n                        const int64_t ne12 = node->src[1]->ne[2]; // Channels In\n\n                        cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;\n                        cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;\n                    } break;\n                case GGML_OP_TOP_K:\n                    {\n                        cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;\n                    } break;\n                case GGML_OP_FLASH_ATTN_EXT:\n                    {\n                        const int64_t neq2 = node->src[0]->ne[2]; // number of query heads\n                        const int64_t DK = node->src[1]->ne[0];\n                        const int64_t DV = node->src[2]->ne[0];\n\n                        // Tiled flash attention scratch (tile sizes defined in common.h)\n                        // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding\n                        size_t prefill  = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;\n\n                        // Decode path: n_kv_chunks = n_tasks (one chunk per thread)\n                        // Per-thread: VKQ accmulator (DV), partial M, partial S + 
intra-thread scratch for V, Q and VKQ\n                        size_t n_chunks = n_tasks;\n                        size_t decode   = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));\n\n                        cur += MAX(prefill, decode);\n                    } break;\n                case GGML_OP_FLASH_ATTN_BACK:\n                    {\n                        const int64_t    D = node->src[0]->ne[0];\n                        const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);\n                        const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back\n                        if (node->src[1]->type == GGML_TYPE_F32) {\n                            cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)\n                            cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2\n                        } else if (node->src[1]->type == GGML_TYPE_F16) {\n                            cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)\n                            cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2\n                        } else if (node->src[1]->type == GGML_TYPE_BF16) {\n                            cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)\n                            cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2\n                        }\n                    } break;\n\n                case GGML_OP_CROSS_ENTROPY_LOSS:\n                    {\n                        cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);\n                    } break;\n                case GGML_OP_GATED_DELTA_NET:\n                    {\n                        const int64_t S_v = node->src[2]->ne[0];\n                        cur = S_v * sizeof(float) * n_tasks;\n                    } break;\n                case GGML_OP_COUNT:\n                    
{\n                        GGML_ABORT(\"fatal error\");\n                    }\n                default:\n                    break;\n            }\n        }\n\n        work_size = MAX(work_size, cur);\n    }\n\n    if (work_size > 0) {\n        work_size += CACHE_LINE_SIZE*(n_threads);\n    }\n\n    cplan.threadpool = threadpool;\n    cplan.n_threads  = MIN(max_tasks, n_threads);\n    cplan.work_size  = work_size;\n    cplan.work_data  = NULL;\n\n    return cplan;\n}\n\nstatic thread_ret_t ggml_graph_compute_thread(void * data) {\n    struct ggml_compute_state * state = (struct ggml_compute_state *) data;\n    struct ggml_threadpool    * tp    = state->threadpool;\n\n    const struct ggml_cgraph * cgraph = tp->cgraph;\n    const struct ggml_cplan  * cplan  = tp->cplan;\n\n    set_numa_thread_affinity(state->ith);\n\n    struct ggml_compute_params params = {\n        /*.ith        =*/ state->ith,\n        /*.nth        =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,\n        /*.wsize      =*/ cplan->work_size,\n        /*.wdata      =*/ cplan->work_data,\n        /*.threadpool =*/ tp,\n        /*.use_ref    =*/ cplan->use_ref,\n    };\n\n#ifdef GGML_USE_OPENMP\n    GGML_PRINT_DEBUG(\"thread #%d compute-start cplan %p\\n\", state->ith, (const void *)cplan);\n#else\n    GGML_PRINT_DEBUG(\"thread #%d compute-start cplan %p last-graph %d\\n\", state->ith, (const void *)cplan, state->last_graph);\n#endif\n\n    for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {\n        struct ggml_tensor * node = cgraph->nodes[node_n];\n\n        if (ggml_op_is_empty(node->op)) {\n            // skip NOPs\n            continue;\n        }\n\n        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n            continue;\n        }\n\n        ggml_compute_forward(&params, node);\n\n        if (state->ith == 0 && cplan->abort_callback &&\n                
cplan->abort_callback(cplan->abort_callback_data)) {\n            atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);\n            tp->ec    = GGML_STATUS_ABORTED;\n        }\n\n        if (node_n + 1 < cgraph->n_nodes) {\n            ggml_barrier(state->threadpool);\n        }\n    }\n\n#ifdef GGML_USE_OPENMP\n    GGML_PRINT_DEBUG(\"thread #%d compute-done cplan %p\\n\", state->ith, (const void *)cplan);\n#else\n    GGML_PRINT_DEBUG(\"thread #%d compute-done cplan %p last-graph %d\\n\", state->ith, (const void *)cplan, state->last_graph);\n#endif\n\n    ggml_barrier(state->threadpool);\n\n    return 0;\n}\n\n#ifndef GGML_USE_OPENMP\n\n// check if thread is ready to proceed (exit from polling or sleeping)\n// returns true if loops should exit, sets state->pending to indicate new work\nstatic inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) {\n    struct ggml_threadpool * threadpool = state->threadpool;\n\n    if (state->pending || threadpool->stop || threadpool->pause) { return true; }\n\n    // check for new graph/work\n    int n_graph   = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);\n    int n_threads = n_graph & GGML_THREADPOOL_N_THREADS_MASK;\n    if (n_graph != state->last_graph) {\n        state->pending    = (state->ith < n_threads);\n        state->last_graph = n_graph;\n        return true;\n    }\n\n    return false;\n}\n\n// sync thread state after polling\nstatic inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {\n    // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead\n    #ifdef GGML_TSAN_ENABLED\n    atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);\n    #else\n    atomic_thread_fence(memory_order_seq_cst);\n    #endif\n    UNUSED(state);\n}\n\nstatic inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {\n    struct ggml_threadpool * threadpool = 
state->threadpool;\n\n    // This seems to make 0 ... 100 a decent range for polling level across modern processors.\n    // Perhaps, we can adjust it dynamically based on load and things.\n    const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;\n\n    for (uint64_t i=0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {\n        // No new work. Keep polling.\n        ggml_thread_cpu_relax();\n    }\n\n    return state->pending;\n}\n\nstatic inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {\n    struct ggml_threadpool * threadpool = state->threadpool;\n\n    if (ggml_graph_compute_poll_for_work(state)) {\n        ggml_graph_compute_thread_sync(state);\n        return state->pending;\n    }\n\n    ggml_mutex_lock_shared(&threadpool->mutex);\n    while (!ggml_graph_compute_thread_ready(state)) {\n        // No new work. Wait for the signal.\n        GGML_PRINT_DEBUG(\"thread #%d waiting for work (sleeping)\\n\", state->ith);\n        ggml_cond_wait(&threadpool->cond, &threadpool->mutex);\n    }\n    ggml_mutex_unlock_shared(&threadpool->mutex);\n\n    return state->pending;\n}\n\nstatic thread_ret_t ggml_graph_compute_secondary_thread(void* data) {\n    struct ggml_compute_state * state = (struct ggml_compute_state *) data;\n    struct ggml_threadpool * threadpool = state->threadpool;\n\n    ggml_thread_apply_priority(threadpool->prio);\n    if (ggml_thread_cpumask_is_valid(state->cpumask)) {\n        ggml_thread_apply_affinity(state->cpumask);\n    }\n\n    while (true) {\n        // Check if we need to sleep\n        while (threadpool->pause) {\n            GGML_PRINT_DEBUG(\"thread #%d inside pause loop\\n\", state->ith);\n            ggml_mutex_lock_shared(&threadpool->mutex);\n            if (threadpool->pause) {\n                ggml_cond_wait(&threadpool->cond, &threadpool->mutex);\n            }\n            GGML_PRINT_DEBUG(\"thread #%d resuming after wait\\n\", state->ith);\n            
ggml_mutex_unlock_shared(&threadpool->mutex);\n        }\n\n        // This needs to be checked for after the cond_wait\n        if (threadpool->stop) break;\n\n        // Check if there is new work\n        // The main thread is the only one that can dispatch new work\n\n        ggml_graph_compute_check_for_work(state);\n        if (state->pending) {\n            state->pending = false;\n            ggml_graph_compute_thread(state);\n        }\n    }\n\n    return (thread_ret_t) 0;\n}\n\n// Start processing new graph\nstatic void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads)\n{\n    // Always take the mutex here because the worker threads are doing hybrid poll/wait\n\n    ggml_mutex_lock(&threadpool->mutex);\n\n    // Update the number of active threads and the graph count\n    int n_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed) >> GGML_THREADPOOL_N_THREADS_BITS;\n    n_graph = ((n_graph + 1) << GGML_THREADPOOL_N_THREADS_BITS) | (n_threads & GGML_THREADPOOL_N_THREADS_MASK);\n\n    GGML_PRINT_DEBUG(\"compute-kickoff: n_threads %d n_graph %d\\n\", n_threads, n_graph);\n\n    // Indicate the graph is ready to be processed\n    // We need the full seq-cst fence here because of the polling threads (used in thread_sync)\n    atomic_store_explicit(&threadpool->n_graph, n_graph, memory_order_seq_cst);\n\n    if (threadpool->pause) {\n       // Update main thread prio and affinity to match the threadpool settings\n       ggml_thread_apply_priority(threadpool->prio);\n       if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {\n           ggml_thread_apply_affinity(threadpool->workers[0].cpumask);\n       }\n\n       // resume does cond broadcast\n       ggml_threadpool_resume_locked(threadpool);\n    } else {\n       ggml_cond_broadcast(&threadpool->cond);\n    }\n\n    ggml_mutex_unlock(&threadpool->mutex);\n}\n\n#endif // GGML_USE_OPENMP\n\nstatic struct ggml_threadpool * 
ggml_threadpool_new_impl(\n    struct ggml_threadpool_params * tpp,\n               struct ggml_cgraph * cgraph,\n                struct ggml_cplan * cplan) {\n\n    struct ggml_threadpool * threadpool =\n        ggml_aligned_malloc(sizeof(struct ggml_threadpool));\n    {\n        threadpool->cgraph           = cgraph;\n        threadpool->cplan            = cplan;\n        threadpool->n_graph          = 0;\n        threadpool->n_barrier        = 0;\n        threadpool->n_barrier_passed = 0;\n        threadpool->current_chunk    = 0;\n        threadpool->stop             = false;\n        threadpool->pause            = tpp->paused;\n        threadpool->abort            = -1;\n        threadpool->workers          = NULL;\n        threadpool->n_threads        = tpp->n_threads;\n        threadpool->poll             = tpp->poll;\n        threadpool->prio             = tpp->prio;\n        threadpool->ec               = GGML_STATUS_SUCCESS;\n    }\n\n    // Allocate and init workers state\n    const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;\n    struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);\n\n    memset(workers, 0, workers_size);\n    for (int j = 0; j < tpp->n_threads; j++) {\n        workers[j].threadpool = threadpool;\n        workers[j].ith        = j;\n    }\n\n    threadpool->workers = workers;\n\n#ifdef GGML_USE_OPENMP\n    int32_t cpumask_iter = 0;\n\n    // Compute CPU masks for each thread\n    for (int j = 0; j < tpp->n_threads; j++) {\n        ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);\n    }\n#else // GGML_USE_OPENMP\n    ggml_mutex_init(&threadpool->mutex);\n    ggml_cond_init(&threadpool->cond);\n\n    // Spin the threads for all workers, and update CPU placements.\n    // Place the main thread last (towards the higher numbered CPU cores).\n\n    int32_t cpumask_iter = 0;\n\n    for (int j = 1; j < tpp->n_threads; j++) {\n        
ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);\n\n        int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]);\n        GGML_ASSERT(rc == 0);\n    }\n\n    ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);\n\n    if (!threadpool->pause) {\n        // Update main thread prio and affinity at the start, otherwise we'll do it in resume\n        ggml_thread_apply_priority(threadpool->prio);\n        if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {\n            ggml_thread_apply_affinity(threadpool->workers[0].cpumask);\n        }\n    }\n#endif // GGML_USE_OPENMP\n\n    return threadpool;\n}\n\nstruct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) {\n    return ggml_threadpool_new_impl(tpp, NULL, NULL);\n}\n\nenum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {\n    ggml_cpu_init();\n\n    GGML_ASSERT(cplan);\n    GGML_ASSERT(cplan->n_threads > 0);\n    GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);\n\n    int n_threads                               = cplan->n_threads;\n    struct ggml_threadpool * threadpool = cplan->threadpool;\n\n    bool disposable_threadpool = false;\n\n    if (threadpool == NULL) {\n        //GGML_PRINT_DEBUG(\"Threadpool is not specified. 
Will create a disposable threadpool : n_threads %d\\n\", n_threads);\n        disposable_threadpool = true;\n\n        struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads);\n        threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan);\n    } else {\n        // Reset some of the parameters that need resetting\n        // No worker threads should be accessing the parameters below at this stage\n        threadpool->cgraph           = cgraph;\n        threadpool->cplan            = cplan;\n        threadpool->current_chunk    = 0;\n        threadpool->abort            = -1;\n        threadpool->ec               = GGML_STATUS_SUCCESS;\n    }\n\n#ifdef GGML_USE_OPENMP\n    if (n_threads > 1) {\n        #pragma omp parallel num_threads(n_threads)\n        {\n            #pragma omp single\n            {\n                // update the number of threads from the actual number of threads that we got from OpenMP\n                n_threads = omp_get_num_threads();\n                atomic_store_explicit(&threadpool->n_graph, n_threads, memory_order_relaxed);\n            }\n\n            // Apply thread CPU mask and priority\n            int ith = omp_get_thread_num();\n\n            ggml_thread_apply_priority(threadpool->prio);\n            if (ggml_thread_cpumask_is_valid(threadpool->workers[ith].cpumask)) {\n                ggml_thread_apply_affinity(threadpool->workers[ith].cpumask);\n            }\n            ggml_graph_compute_thread(&threadpool->workers[ith]);\n        }\n    } else {\n        atomic_store_explicit(&threadpool->n_graph, 1, memory_order_relaxed);\n        ggml_graph_compute_thread(&threadpool->workers[0]);\n    }\n#else\n    if (n_threads > threadpool->n_threads) {\n        GGML_LOG_WARN(\"cplan requested more threads (%d) than available (%d)\\n\", n_threads, threadpool->n_threads);\n        n_threads = threadpool->n_threads;\n    }\n\n    // Kick all threads to start the new graph\n    
ggml_graph_compute_kickoff(threadpool, n_threads);\n\n    // This is a work thread too\n    ggml_graph_compute_thread(&threadpool->workers[0]);\n#endif\n\n    // don't leave affinity set on the main thread\n    clear_numa_thread_affinity();\n\n    enum ggml_status ret = threadpool->ec;\n\n    if (disposable_threadpool) {\n        ggml_threadpool_free(threadpool);\n    }\n\n    return ret;\n}\n\nenum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {\n    struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL);\n\n    cplan.work_data = (uint8_t *)ggml_new_buffer(ctx, cplan.work_size);\n\n    return ggml_graph_compute(cgraph, &cplan);\n}\n\nvoid ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {\n    memcpy(y, x, n * sizeof(float));\n}\n\nvoid ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {\n    int64_t i = 0;\n#if defined(__F16C__)\n#if defined(__AVX512F__)\n    for (; i + 15 < n; i += 16) {\n        __m512 x_vec = _mm512_loadu_ps(x + i);\n        __m256i y_vec = _mm512_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);\n        _mm256_storeu_si256((__m256i *)(y + i), y_vec);\n    }\n#endif\n    for (; i + 7 < n; i += 8) {\n        __m256 x_vec = _mm256_loadu_ps(x + i);\n        __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);\n        _mm_storeu_si128((__m128i *)(y + i), y_vec);\n    }\n    for (; i + 3 < n; i += 4) {\n        __m128 x_vec = _mm_loadu_ps(x + i);\n        __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);\n        _mm_storel_epi64((__m128i *)(y + i), y_vec);\n    }\n#elif defined(__riscv_zvfh)\n    for (int vl; i < n; i += vl) {\n        vl = __riscv_vsetvl_e32m2(n - i);\n        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);\n        vfloat16m1_t vy = __riscv_vfncvt_f_f_w_f16m1(vx, vl);\n        __riscv_vse16_v_f16m1((_Float16 *)&y[i], vy, vl);\n    }\n#endif\n    for (; i < n; ++i) {\n        y[i] = 
GGML_CPU_FP32_TO_FP16(x[i]);\n    }\n}\n\nvoid ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {\n    int64_t i = 0;\n#if defined(__F16C__)\n#if defined(__AVX512F__)\n    for (; i + 15 < n; i += 16) {\n        __m256i x_vec = _mm256_loadu_si256((const __m256i *)(x + i));\n        __m512 y_vec = _mm512_cvtph_ps(x_vec);\n        _mm512_storeu_ps(y + i, y_vec);\n    }\n#endif\n    for (; i + 7 < n; i += 8) {\n        __m128i x_vec = _mm_loadu_si128((const __m128i *)(x + i));\n        __m256 y_vec = _mm256_cvtph_ps(x_vec);\n        _mm256_storeu_ps(y + i, y_vec);\n    }\n    for (; i + 3 < n; i += 4) {\n        __m128i x_vec = _mm_loadl_epi64((const __m128i *)(x + i));\n        __m128 y_vec = _mm_cvtph_ps(x_vec);\n        _mm_storeu_ps(y + i, y_vec);\n    }\n\n#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfhmin)\n    // calculate step size\n    const int epr = __riscv_vsetvlmax_e16m2();\n    const int step = epr * 2;\n    const int np = (n & ~(step - 1));\n\n    // unroll by 2\n    for (; i < np; i += step) {\n        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, epr);\n        vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, epr);\n        __riscv_vse32_v_f32m4(y + i, ay0, epr);\n\n        vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16*)x + i + epr, epr);\n        vfloat32m4_t ay1 = __riscv_vfwcvt_f_f_v_f32m4(ax1, epr);\n        __riscv_vse32_v_f32m4(y + i + epr, ay1, epr);\n    }\n\n    // leftovers\n    int vl;\n    for (i = np; i < n; i += vl) {\n        vl = __riscv_vsetvl_e16m2(n - i);\n        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, vl);\n        vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, vl);\n        __riscv_vse32_v_f32m4(y + i, ay0, vl);\n    }\n\n#endif\n\n    for (; i < n; ++i) {\n        y[i] = GGML_CPU_FP16_TO_FP32(x[i]);\n    }\n}\n\nvoid ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {\n    int64_t i = 0;\n    for (; i < n; ++i) {\n        
y[i] = GGML_FP32_TO_BF16(x[i]);\n    }\n}\n\nvoid ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {\n    int64_t i = 0;\n    for (; i < n; ++i) {\n        y[i] = x[i];\n    }\n}\n\nvoid ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {\n    int64_t i = 0;\n#if defined(__AVX2__)\n#if defined(__AVX512F__)\n    for (; i + 15 < n; i += 16) {\n        _mm512_storeu_ps(y + i,\n                        _mm512_castsi512_ps(\n                            _mm512_slli_epi32(\n                                _mm512_cvtepu16_epi32(\n                                    _mm256_loadu_si256(\n                                        (const __m256i *)(x + i))),\n                                16)));\n    }\n#endif\n    for (; i + 7 < n; i += 8) {\n        _mm256_storeu_ps(y + i,\n                        _mm256_castsi256_ps(\n                            _mm256_slli_epi32(\n                                _mm256_cvtepu16_epi32(\n                                    _mm_loadu_si128(\n                                        (const __m128i *)(x + i))),\n                                16)));\n    }\n#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfmin)\n    // calculate step size\n    const int epr = __riscv_vsetvlmax_e16m2();\n    const int step = epr * 2;\n    const int np = (n & ~(step - 1));\n\n    // unroll by 2\n    for (; i < np; i += step) {\n        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, epr);\n        vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, epr);\n        __riscv_vse32_v_f32m4(y + i, ay0, epr);\n\n        vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16*)x + i + epr, epr);\n        vfloat32m4_t ay1 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax1, epr);\n        __riscv_vse32_v_f32m4(y + i + epr, ay1, epr);\n    }\n\n    // leftovers\n    int vl;\n    for (i = np; i < n; i += vl) {\n        vl = __riscv_vsetvl_e16m2(n - i);\n        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, 
vl);\n        vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, vl);\n        __riscv_vse32_v_f32m4(y + i, ay0, vl);\n    }\n#endif\n    for (; i < n; i++) {\n        y[i] = GGML_BF16_TO_FP32(x[i]);\n    }\n}\n\nint ggml_cpu_has_avx(void) {\n#if defined(__AVX__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_avx_vnni(void) {\n#if defined(__AVXVNNI__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_avx2(void) {\n#if defined(__AVX2__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_avx512(void) {\n#if defined(__AVX512F__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_avx512_vbmi(void) {\n#if defined(__AVX512VBMI__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_avx512_vnni(void) {\n#if defined(__AVX512VNNI__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_avx512_bf16(void) {\n#if defined(__AVX512BF16__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_amx_int8(void) {\n#if defined(__AMX_INT8__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_bmi2(void) {\n#if defined(__BMI2__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_fma(void) {\n#if defined(__FMA__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_arm_fma(void) {\n#if defined(__ARM_FEATURE_FMA)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_riscv_v(void) {\n#if defined(__riscv_v_intrinsic)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_get_rvv_vlen(void) {\n#if defined(__riscv) && defined(__riscv_v_intrinsic)\n    return ggml_riscv_arch_features.rvv_vlen;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_f16c(void) {\n#if defined(__F16C__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_fp16_va(void) {\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_wasm_simd(void) {\n#if 
defined(__wasm_simd128__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_llamafile(void) {\n#if defined(GGML_USE_LLAMAFILE)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_sse3(void) {\n#if defined(__SSE3__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_ssse3(void) {\n#if defined(__SSSE3__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_vsx(void) {\n#if defined(__POWER9_VECTOR__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_vxe(void) {\n#if defined(__VXE__) || defined(__VXE2__)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_neon(void) {\n#if defined(__ARM_ARCH) && defined(__ARM_NEON)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_dotprod(void) {\n#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_sve(void) {\n#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_matmul_int8(void) {\n#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_get_sve_cnt(void) {\n#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)\n    return ggml_arm_arch_features.sve_cnt;\n#else\n    return 0;\n#endif\n}\n\nint ggml_cpu_has_sme(void) {\n#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)\n    return 1;\n#else\n    return 0;\n#endif\n}\n\nvoid ggml_cpu_init(void) {\n    // needed to initialize ggml_time\n    {\n        struct ggml_init_params params = { 0, NULL, false };\n        struct ggml_context * ctx = ggml_init(params);\n        ggml_free(ctx);\n    }\n\n    ggml_critical_section_start();\n\n    static bool is_first_call = true;\n\n    if (is_first_call) {\n        // initialize GELU, Quick GELU, SILU and EXP F32 tables\n        {\n            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);\n\n            for 
(int i = 0; i < (1 << 16); ++i) {\n                union {\n                    uint16_t u16;\n                    ggml_fp16_t fp16;\n                } u = {i};\n                float f = GGML_COMPUTE_FP16_TO_FP32(u.fp16);\n                ggml_table_f32_f16[i] = f;\n                ggml_table_gelu_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_f32(f));\n                ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f));\n            }\n\n            // initialize E8M0 half table (256 entries)\n            for (int i = 0; i < (1 << 8); ++i) {\n                ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);\n            }\n\n            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);\n\n            GGML_PRINT_DEBUG(\"%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\\n\", __func__, (t_end - t_start)/1000.0);\n\n#ifdef GGML_USE_OPENMP\n            //if (!getenv(\"OMP_WAIT_POLICY\")) {\n            //    // set the wait policy to active, so that OpenMP threads don't sleep\n            //    setenv(\"OMP_WAIT_POLICY\", \"active\", 0)\n            //}\n\n            if (!getenv(\"KMP_BLOCKTIME\")) {\n                // set the time to wait before sleeping a thread\n                // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases\n#ifdef _WIN32\n                _putenv_s(\"KMP_BLOCKTIME\", \"200\"); // 200ms\n#else\n                setenv(\"KMP_BLOCKTIME\", \"200\", 0); // 200ms\n#endif\n            }\n#endif\n        }\n\n#if defined(__ARM_ARCH)\n        ggml_init_arm_arch_features();\n#endif\n\n#if defined(__riscv)\n        ggml_init_riscv_arch_features();\n#endif\n\n        is_first_call = false;\n    }\n\n    ggml_critical_section_end();\n}\n"
  },
  {
    "path": "src/ggml-cpu/ggml-cpu.cpp",
    "content": "#include \"ggml-backend.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"repack.h\"\n#include \"traits.h\"\n#include \"ggml-impl.h\"\n#include \"amx/amx.h\"\n\n#include <cctype>\n#include <string>\n#include <vector>\n\n#ifdef GGML_USE_CPU_HBM\n#    include \"hbm.h\"\n#endif\n\n#ifdef GGML_USE_CPU_KLEIDIAI\n#    include \"kleidiai/kleidiai.h\"\n#endif\n\n#ifdef GGML_USE_CPU_RISCV64_SPACEMIT\n#    include \"spacemit/ime.h\"\n#endif\n\n#if defined(_WIN32)\n#    define WIN32_LEAN_AND_MEAN\n#    ifndef NOMINMAX\n#        define NOMINMAX\n#    endif\n#    include <windows.h>\n#else\n#    include <unistd.h>\n#endif\n\n#if defined(__APPLE__)\n#    include <sys/sysctl.h>\n#    include <sys/types.h>\n#endif\n\n// ggml-backend interface\n\nstd::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {\n    static std::vector<ggml_backend_buffer_type_t> bufts = []() {\n        std::vector<ggml_backend_buffer_type_t> bufts;\n\n#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)\n        if (ggml_backend_amx_buffer_type()) {\n            bufts.push_back(ggml_backend_amx_buffer_type());\n        }\n#endif\n\n#ifdef GGML_USE_CPU_RISCV64_SPACEMIT\n        if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) {\n            bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type());\n        }\n#endif\n\n#ifdef GGML_USE_CPU_KLEIDIAI\n        if (ggml_backend_cpu_kleidiai_buffer_type()) {\n            bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());\n        }\n#endif\n\n#ifdef GGML_USE_CPU_REPACK\n        if (ggml_backend_cpu_repack_buffer_type()) {\n            bufts.push_back(ggml_backend_cpu_repack_buffer_type());\n        }\n#endif\n\n        return bufts;\n    }();\n\n    return bufts;\n}\n\nstatic ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {\n    static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {\n        
std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();\n        bufts.push_back(nullptr);\n        return bufts;\n    }();\n\n    return extra_bufts.data();\n\n    GGML_UNUSED(device);\n}\n\nstatic bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {\n    for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {\n        if (extra == buft) {\n            return true;\n        }\n    }\n    return false;\n}\n\n// CPU backend - backend (stream)\n\nstruct ggml_backend_cpu_context {\n    int                 n_threads;\n    ggml_threadpool_t   threadpool;\n\n    uint8_t *           work_data;\n    size_t              work_size;\n\n    ggml_abort_callback abort_callback;\n    void *              abort_callback_data;\n\n    bool                use_ref;  // use reference implementation\n};\n\nstatic const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {\n    return \"CPU\";\n\n    GGML_UNUSED(backend);\n}\n\nstatic void ggml_backend_cpu_free(ggml_backend_t backend) {\n    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;\n    delete[] cpu_ctx->work_data;\n    delete cpu_ctx;\n    delete backend;\n}\n\nstruct ggml_backend_plan_cpu {\n    struct ggml_cplan cplan;\n    struct ggml_cgraph cgraph;\n};\n\nstatic ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {\n    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;\n\n    struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu;\n\n    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);\n    cpu_plan->cgraph = *cgraph; // FIXME: deep copy\n\n    if (cpu_plan->cplan.work_size > 0) {\n        cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size];\n        if (cpu_plan->cplan.work_data == NULL) {\n            delete cpu_plan;\n           
 return NULL;\n        }\n    }\n\n    cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;\n    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;\n    cpu_plan->cplan.use_ref             = cpu_ctx->use_ref;\n\n    return cpu_plan;\n}\n\nstatic void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {\n    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;\n\n    delete[] cpu_plan->cplan.work_data;\n    delete cpu_plan;\n\n    GGML_UNUSED(backend);\n}\n\nstatic enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {\n    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;\n\n    return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);\n\n    GGML_UNUSED(backend);\n}\n\nstatic enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {\n    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;\n\n    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);\n\n    if (cpu_ctx->work_size < cplan.work_size) {\n        delete[] cpu_ctx->work_data;\n        cpu_ctx->work_data = new uint8_t[cplan.work_size];\n        if (cpu_ctx->work_data == NULL) {\n            cpu_ctx->work_size = 0;\n            return GGML_STATUS_ALLOC_FAILED;\n        }\n        cpu_ctx->work_size = cplan.work_size;\n    }\n    cplan.work_data = (uint8_t *)cpu_ctx->work_data;\n\n    cplan.abort_callback      = cpu_ctx->abort_callback;\n    cplan.abort_callback_data = cpu_ctx->abort_callback_data;\n    cplan.use_ref             = cpu_ctx->use_ref;\n\n    return ggml_graph_compute(cgraph, &cplan);\n}\n\nstatic const struct ggml_backend_i ggml_backend_cpu_i = {\n    /* .get_name                = */ ggml_backend_cpu_get_name,\n    /* .free                    = */ ggml_backend_cpu_free,\n    /* 
.set_tensor_async        = */ NULL,\n    /* .get_tensor_async        = */ NULL,\n    /* .cpy_tensor_async        = */ NULL,\n    /* .synchronize             = */ NULL,\n    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,\n    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,\n    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ NULL,\n};\n\nstatic ggml_guid_t ggml_backend_cpu_guid(void) {\n    static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };\n    return &guid;\n}\n\nggml_backend_t ggml_backend_cpu_init(void) {\n    // initialize CPU backend now to avoid slowing the first graph computation\n    ggml_cpu_init();\n\n    struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context;\n    if (ctx == NULL) {\n        return NULL;\n    }\n\n    ctx->n_threads           = GGML_DEFAULT_N_THREADS;\n    ctx->threadpool          = NULL;\n    ctx->work_data           = NULL;\n    ctx->work_size           = 0;\n    ctx->abort_callback      = NULL;\n    ctx->abort_callback_data = NULL;\n    ctx->use_ref             = false;\n\n    ggml_backend_t cpu_backend = new ggml_backend {\n        /* .guid    = */ ggml_backend_cpu_guid(),\n        /* .iface   = */ ggml_backend_cpu_i,\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),\n        /* .context = */ ctx,\n    };\n\n    if (cpu_backend == NULL) {\n        delete ctx;\n        return NULL;\n    }\n\n    return cpu_backend;\n}\n\nbool ggml_backend_is_cpu(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());\n}\n\nvoid ggml_backend_cpu_set_n_threads(ggml_backend_t 
backend_cpu, int n_threads) {\n    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));\n\n    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;\n    ctx->n_threads = n_threads;\n}\n\nvoid ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {\n    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));\n\n    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;\n\n    if (ctx->threadpool && ctx->threadpool != threadpool) {\n        // already had a different threadpool, pause/suspend it before switching\n        ggml_threadpool_pause(ctx->threadpool);\n    }\n    ctx->threadpool = threadpool;\n}\n\nvoid ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {\n    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));\n\n    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;\n    ctx->abort_callback = abort_callback;\n    ctx->abort_callback_data = abort_callback_data;\n}\n\nvoid ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) {\n    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));\n\n    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;\n    ctx->use_ref = use_ref;\n}\n\n// CPU backend - device\n\nstruct ggml_backend_cpu_device_context {\n    std::string description = \"CPU\";\n\n    ggml_backend_cpu_device_context() {\n#ifdef __APPLE__\n        size_t len = 0;\n        if (!sysctlbyname(\"machdep.cpu.brand_string\", NULL, &len, NULL, 0)) {\n            description.resize(len);\n            sysctlbyname(\"machdep.cpu.brand_string\", &description[0], &len, NULL, 0); // NOLINT\n        }\n#elif defined(__linux__)\n        FILE * f = fopen(\"/proc/cpuinfo\", \"r\");\n        if (f) {\n            char buf[1024];\n            while (fgets(buf, sizeof(buf), f)) {\n                
if (strncmp(buf, \"model name\", 10) == 0) {\n                    char * p = strchr(buf, ':');\n                    if (p) {\n                        p++;\n                        while (std::isspace(*p)) {\n                            p++;\n                        }\n                        while (std::isspace(p[strlen(p) - 1])) {\n                            p[strlen(p) - 1] = '\\0';\n                        }\n                        description = p;\n                        break;\n                    }\n                }\n            }\n            fclose(f);\n        }\n#elif defined(_WIN32)\n        HKEY hKey;\n        if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,\n                        TEXT(\"HARDWARE\\\\DESCRIPTION\\\\System\\\\CentralProcessor\\\\0\"),\n                        0,\n                        KEY_READ,\n                        &hKey) == ERROR_SUCCESS) {\n            DWORD cpu_brand_size = 0;\n            if (RegQueryValueExA(hKey,\n                                \"ProcessorNameString\",\n                                NULL,\n                                NULL,\n                                NULL,\n                                &cpu_brand_size) == ERROR_SUCCESS) {\n                description.resize(cpu_brand_size);\n                if (RegQueryValueExA(hKey,\n                                    \"ProcessorNameString\",\n                                    NULL,\n                                    NULL,\n                                    (LPBYTE)&description[0], // NOLINT\n                                    &cpu_brand_size) == ERROR_SUCCESS) {\n                    if (description.find('\\0') != std::string::npos) {\n                        description.resize(description.find('\\0'));\n                    }\n                }\n            }\n            RegCloseKey(hKey);\n        }\n#endif\n    }\n};\n\nstatic const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) {\n    return \"CPU\";\n\n    GGML_UNUSED(dev);\n}\n\nstatic 
const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) {\n    struct ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context;\n\n    return ctx->description.c_str();\n}\n\nstatic void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n#ifdef _WIN32\n    MEMORYSTATUSEX status;\n    status.dwLength = sizeof(status);\n    GlobalMemoryStatusEx(&status);\n    *total = status.ullTotalPhys;\n    *free = status.ullAvailPhys;\n#else\n    long pages = sysconf(_SC_PHYS_PAGES);\n    long page_size = sysconf(_SC_PAGE_SIZE);\n    *total = pages * page_size;\n\n    // \"free\" system memory is ill-defined, for practical purposes assume that all of it is free:\n    *free = *total;\n#endif // _WIN32\n\n    GGML_UNUSED(dev);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {\n    return GGML_BACKEND_DEVICE_TYPE_CPU;\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_cpu_device_get_name(dev);\n    props->description = ggml_backend_cpu_device_get_description(dev);\n    props->type        = ggml_backend_cpu_device_get_type(dev);\n    ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);\n    props->caps = {\n        /* .async                 = */ false,\n        /* .host_buffer           = */ false,\n        /* .buffer_from_host_ptr  = */ true,\n        /* .events                = */ false,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) {\n    return ggml_backend_cpu_init();\n\n    GGML_UNUSED(dev);\n    GGML_UNUSED(params);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) {\n    return ggml_backend_cpu_buffer_type();\n\n    GGML_UNUSED(dev);\n}\n\nstatic 
ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {\n    return ggml_backend_cpu_buffer_from_ptr(ptr, size);\n\n    GGML_UNUSED(dev);\n    GGML_UNUSED(max_tensor_size);\n}\n\nstatic bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n\n    if (op->op == GGML_OP_NONE || op->op == GGML_OP_RESHAPE || op->op == GGML_OP_VIEW || op->op == GGML_OP_PERMUTE || op->op == GGML_OP_TRANSPOSE) {\n        return true;\n    }\n\n    // check extra buffer types\n    // note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary\n    for (int i = 0; i < 4; i++) {\n        if (op->src[i] && op->src[i]->buffer &&\n            ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {\n            auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;\n            return buf_extra->supports_op(dev, op);\n        }\n    }\n\n    switch (op->op) {\n        case GGML_OP_CPY:\n        case GGML_OP_SET_ROWS:\n            return\n                op->type != GGML_TYPE_IQ3_XXS &&\n                op->type != GGML_TYPE_IQ3_S   &&\n                op->type != GGML_TYPE_IQ2_XXS &&\n                op->type != GGML_TYPE_IQ2_XS  &&\n                op->type != GGML_TYPE_IQ2_S   &&\n                op->type != GGML_TYPE_IQ1_S   &&\n                op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float\n        case GGML_OP_MUL_MAT:\n            return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;\n        case GGML_OP_SOFT_MAX_BACK: {\n            if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {\n                return false;\n            }\n            float max_bias = 0.0f;\n\n            
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));\n\n            return max_bias == 0.0f;\n        }\n        case GGML_OP_IM2COL_BACK:\n            return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;\n        case GGML_OP_GET_ROWS_BACK:\n            return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16;\n        case GGML_OP_OUT_PROD:\n            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&\n                src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;\n        default:\n            return true;\n    }\n}\n\nstatic bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_is_extra_buffer_type(buft);\n    GGML_UNUSED(dev);\n}\n\nstatic const struct ggml_backend_device_i ggml_backend_cpu_device_i = {\n    /* .get_name             = */ ggml_backend_cpu_device_get_name,\n    /* .get_description      = */ ggml_backend_cpu_device_get_description,\n    /* .get_memory           = */ ggml_backend_cpu_device_get_memory,\n    /* .get_type             = */ ggml_backend_cpu_device_get_type,\n    /* .get_props            = */ ggml_backend_cpu_device_get_props,\n    /* .init_backend         = */ ggml_backend_cpu_device_init_backend,\n    /* .get_buffer_type      = */ ggml_backend_cpu_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,\n    /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr,\n    /* .supports_op          = */ ggml_backend_cpu_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_cpu_device_supports_buft,\n    /* .offload_op           = */ NULL,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ NULL,\n};\n\n// CPU backend - backend (reg)\n\nstatic const char * 
ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {\n    return \"CPU\";\n\n    GGML_UNUSED(reg);\n}\n\nstatic size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) {\n    return 1;\n\n    GGML_UNUSED(reg);\n}\n\nstatic ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    GGML_ASSERT(index == 0);\n\n    static ggml_backend_cpu_device_context ctx;\n    static ggml_backend_device ggml_backend_cpu_device = {\n        /* .iface   = */ ggml_backend_cpu_device_i,\n        /* .reg     = */ reg,\n        /* .context = */ &ctx,\n    };\n\n    return &ggml_backend_cpu_device;\n}\n\n// This is intended to replace the ggml_cpu_has_* functions when loading the CPU backend dynamically,\n// and additionally to allow other backends to expose their own list of features that applications can query using the same API\nstatic ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t reg) {\n    static std::vector<ggml_backend_feature> features = []() {\n        ggml_cpu_init();\n\n        std::vector<ggml_backend_feature> features;\n        if (ggml_cpu_has_sse3()) {\n            features.push_back({ \"SSE3\", \"1\" });\n        }\n        if (ggml_cpu_has_ssse3()) {\n            features.push_back({ \"SSSE3\", \"1\" });\n        }\n        if (ggml_cpu_has_avx()) {\n            features.push_back({ \"AVX\", \"1\" });\n        }\n        if (ggml_cpu_has_avx_vnni()) {\n            features.push_back({ \"AVX_VNNI\", \"1\" });\n        }\n        if (ggml_cpu_has_avx2()) {\n            features.push_back({ \"AVX2\", \"1\" });\n        }\n        if (ggml_cpu_has_f16c()) {\n            features.push_back({ \"F16C\", \"1\" });\n        }\n        if (ggml_cpu_has_fma()) {\n            features.push_back({ \"FMA\", \"1\" });\n        }\n        if (ggml_cpu_has_bmi2()) {\n            features.push_back({ \"BMI2\", \"1\" });\n        }\n        if (ggml_cpu_has_avx512()) {\n            features.push_back({ 
\"AVX512\", \"1\" });\n        }\n        if (ggml_cpu_has_avx512_vbmi()) {\n            features.push_back({ \"AVX512_VBMI\", \"1\" });\n        }\n        if (ggml_cpu_has_avx512_vnni()) {\n            features.push_back({ \"AVX512_VNNI\", \"1\" });\n        }\n        if (ggml_cpu_has_avx512_bf16()) {\n            features.push_back({ \"AVX512_BF16\", \"1\" });\n        }\n        if (ggml_cpu_has_amx_int8()) {\n            features.push_back({ \"AMX_INT8\", \"1\" });\n        }\n        if (ggml_cpu_has_neon()) {\n            features.push_back({ \"NEON\", \"1\" });\n        }\n        if (ggml_cpu_has_arm_fma()) {\n            features.push_back({ \"ARM_FMA\", \"1\" });\n        }\n        if (ggml_cpu_has_fp16_va()) {\n            features.push_back({ \"FP16_VA\", \"1\" });\n        }\n        if (ggml_cpu_has_matmul_int8()) {\n            features.push_back({ \"MATMUL_INT8\", \"1\" });\n        }\n        if (ggml_cpu_has_sve()) {\n            features.push_back({ \"SVE\", \"1\" });\n        }\n        if (ggml_cpu_has_dotprod()) {\n            features.push_back({ \"DOTPROD\", \"1\" });\n        }\n        if (ggml_cpu_get_sve_cnt() > 0) {\n            static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());\n            features.push_back({ \"SVE_CNT\", sve_cnt.c_str() });\n        }\n        if (ggml_cpu_has_sme()) {\n            features.push_back({ \"SME\", \"1\" });\n        }\n        if (ggml_cpu_has_riscv_v()) {\n            features.push_back({ \"RISCV_V\", \"1\" });\n        }\n        if (ggml_cpu_get_rvv_vlen() > 0) {\n            static std::string rvv_vlen = std::to_string(ggml_cpu_get_rvv_vlen());\n            features.push_back({ \"RVV_VLEN\", rvv_vlen.c_str() });\n        }\n        if (ggml_cpu_has_vsx()) {\n            features.push_back({ \"VSX\", \"1\" });\n        }\n        if (ggml_cpu_has_vxe()) {\n            features.push_back({ \"VXE\", \"1\" });\n        }\n        if (ggml_cpu_has_wasm_simd()) {\n            
features.push_back({ \"WASM_SIMD\", \"1\" });\n        }\n        if (ggml_cpu_has_llamafile()) {\n            features.push_back({ \"LLAMAFILE\", \"1\" });\n        }\n    #ifdef GGML_USE_ACCELERATE\n        features.push_back({ \"ACCELERATE\", \"1\" });\n    #endif\n    #ifdef GGML_USE_CPU_HBM\n        features.push_back({ \"CPU_HBM\", \"1\" });\n    #endif\n    #ifdef GGML_USE_OPENMP\n        features.push_back({ \"OPENMP\", \"1\" });\n    #endif\n    #ifdef GGML_USE_CPU_KLEIDIAI\n        features.push_back({ \"KLEIDIAI\", \"1\" });\n    #endif\n    #ifdef GGML_USE_CPU_REPACK\n        features.push_back({ \"REPACK\", \"1\" });\n    #endif\n\n        features.push_back({ nullptr, nullptr });\n\n        return features;\n    }();\n\n    return features.data();\n\n    GGML_UNUSED(reg);\n}\n\nstatic void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    if (strcmp(name, \"ggml_backend_set_n_threads\") == 0) {\n        ggml_backend_set_n_threads_t fct = ggml_backend_cpu_set_n_threads;\n        return (void *)fct;\n    }\n    if (strcmp(name, \"ggml_backend_dev_get_extra_bufts\") == 0) {\n        ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_cpu_device_get_extra_buffers_type;\n        return (void *)fct;\n    }\n    if (strcmp(name, \"ggml_backend_get_features\") == 0) {\n        return (void *)ggml_backend_cpu_get_features;\n    }\n    if (strcmp(name, \"ggml_backend_set_abort_callback\") == 0) {\n        return (void *)ggml_backend_cpu_set_abort_callback;\n    }\n    if (strcmp(name, \"ggml_backend_cpu_numa_init\") == 0) {\n        return (void *)ggml_numa_init;\n    }\n    if (strcmp(name, \"ggml_backend_cpu_is_numa\") == 0) {\n        return (void *)ggml_is_numa;\n    }\n    if (strcmp(name, \"ggml_backend_cpu_set_use_ref\") == 0) {\n        return (void *)ggml_backend_cpu_set_use_ref;\n    }\n\n    // threadpool - TODO:  move to ggml-base\n    if (strcmp(name, \"ggml_threadpool_new\") == 0) {\n        return (void 
*)ggml_threadpool_new;\n    }\n    if (strcmp(name, \"ggml_threadpool_free\") == 0) {\n        return (void *)ggml_threadpool_free;\n    }\n    if (strcmp(name, \"ggml_backend_cpu_set_threadpool\") == 0) {\n        return (void *)ggml_backend_cpu_set_threadpool;\n    }\n\n    return NULL;\n\n    GGML_UNUSED(reg);\n}\n\nstatic const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {\n    /* .get_name         = */ ggml_backend_cpu_reg_get_name,\n    /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_cpu_reg_get_device,\n    /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,\n};\n\nggml_backend_reg_t ggml_backend_cpu_reg(void) {\n    // init CPU feature detection\n    ggml_cpu_init();\n\n    static struct ggml_backend_reg ggml_backend_cpu_reg = {\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_cpu_reg_i,\n        /* .context     = */ NULL,\n    };\n\n    return &ggml_backend_cpu_reg;\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_cpu_reg)\n"
  },
  {
    "path": "src/ggml-cpu/hbm.cpp",
    "content": "#ifdef GGML_USE_CPU_HBM\n\n#include \"ggml-backend.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-impl.h\"\n\n#include \"hbm.h\"\n\n// buffer type HBM\n\n#include <hbwmalloc.h>\n\nstatic const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    return \"CPU_HBM\";\n\n    GGML_UNUSED(buft);\n}\n\nstatic void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    hbw_free(buffer->context);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,\n                                                                           size_t                     size) {\n    void * ptr;\n    int    result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);\n    if (result != 0) {\n        GGML_LOG_ERROR(\"failed to allocate HBM buffer of size %zu\\n\", size);\n        return NULL;\n    }\n\n    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);\n    buffer->buft                 = buft;\n    buffer->iface.free_buffer    = ggml_backend_cpu_hbm_buffer_free_buffer;\n\n    return buffer;\n}\n\nggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {\n    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {\n        /* .iface    = */ {\n                           /* .get_name         = */ ggml_backend_cpu_hbm_buffer_type_get_name,\n                           /* .alloc_buffer     = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,\n                           /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,\n                           /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX\n                           /* .get_alloc_size   = */ nullptr,  // defaults to ggml_nbytes\n                           /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,\n                           },\n        /* 
.context  = */ nullptr,\n    };\n\n    return &ggml_backend_cpu_buffer_type_hbm;\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/hbm.h",
    "content": "#pragma once\n\n#include \"ggml-backend.h\"\n#include \"ggml.h\"\n\n// GGML CPU internal header\n\nggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);\n"
  },
  {
    "path": "src/ggml-cpu/kleidiai/kernels.cpp",
    "content": "// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>\n// SPDX-License-Identifier: MIT\n//\n\n// KleidiAI micro-kernels\n#include \"kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h\"\n#include \"kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h\"\n#include \"kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h\"\n#include \"kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h\"\n#include \"kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h\"\n#include \"kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h\"\n#include \"kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h\"\n#include \"kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h\"\n#include \"kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h\"\n#include \"kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h\"\n#include \"kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h\"\n#include \"kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h\"\n#include \"kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h\"\n#include \"kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h\"\n#include \"kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h\"\n#include \"kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h\"\n#include \"kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.h\"\n\n#include \"kai_lhs_pack_bf16p2vlx2_f32_sme.h\"\n#include \"kai_lhs_quant_pack_qsi8d32p_f32.h\"\n#include \"kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h\"\n#include \"kai_lhs_quant_pack_qsi8d32p_f32_neon.h\"\n#include \"kai_lhs_quant_pack_qai8dxp_f32.h\"\n\n#include \"kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h\"\n#include \"kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h\"\n#include \"kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h\"\n#include \"kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h\"\n#include 
\"kai_lhs_pack_f16pmrx2_f32_neon.h\"\n\n#include \"kai_common.h\"\n\n#include \"simd-mappings.h\"\n\n#define GGML_COMMON_DECL_CPP\n#include \"ggml-common.h\"\n\n#include \"kernels.h\"\n\n#define NELEMS(x) (sizeof(x) / sizeof(*x))\n\ntemplate<size_t(*Fn)(size_t,size_t,size_t)>\nstatic inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {\n    return Fn(a, b, c);\n}\n\ntemplate<size_t(*Fn)(size_t,size_t)>\nstatic inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) {\n    return Fn(a, b);\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>\nstatic inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl,\n                                     const void* lhs, const void* rhs, void* dst,\n                                     size_t dst_stride_row, size_t dst_stride_col,\n                                     float clamp_min, float clamp_max) {\n    Fn(m, n, k, bl, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,const void*,const void*,void*,size_t,size_t,float,float)>\nstatic inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,\n                                   const void* lhs, const void* rhs, void* dst,\n                                   size_t dst_stride_row, size_t dst_stride_col,\n                                   float clamp_min, float clamp_max) {\n    Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>\nstatic inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,\n                                         const void* lhs, const void* rhs, void* dst,\n                                         size_t dst_stride_row, size_t dst_stride_col,\n                                         float clamp_min, float clamp_max) 
{\n    Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);\n}\n\ntemplate<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>\nstatic inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {\n    return Fn(m, k, bl, mr, kr, sr);\n}\n\ntemplate<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>\nstatic inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {\n    return Fn(m, k, mr, kr, sr);\n}\n\ntemplate<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>\nstatic inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {\n    return Fn(m_idx, k, bl, mr, kr, sr);\n}\n\ntemplate<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>\nstatic inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {\n    return Fn(m_idx, k, mr, kr, sr);\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>\nstatic inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,\n                                            size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {\n    Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>\nstatic inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,\n                                           size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {\n    Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>\nstatic inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,\n                           
                  size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {\n    Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>\nstatic inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,\n                                            size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {\n    Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);\n}\n\ntemplate<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>\nstatic inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {\n    return Fn(n, k, nr, kr, bl);\n}\n\ntemplate<size_t(*Fn)(size_t,size_t)>\nstatic inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {\n    return Fn(n, k);\n}\n\ntemplate<size_t(*Fn)(size_t,size_t,size_t,size_t)>\nstatic inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) {\n    return Fn(k, nr, kr, bl);\n}\n\ntemplate<size_t(*Fn)(size_t)>\nstatic inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {\n    return Fn(k);\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const uint8_t*,const float*,void*,size_t,const struct kai_rhs_pack_qs4cxs1s0_param*)>\nstatic inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,\n                                      size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/,\n                                      void* rhs_packed, size_t extra_bytes, const void* params) {\n    Fn(num_groups, n, k, nr, kr, sr, bl,\n       static_cast<const uint8_t*>(rhs),\n       static_cast<const float*>(bias),\n       rhs_packed, extra_bytes,\n       static_cast<const 
kai_rhs_pack_qs4cxs1s0_param*>(params));\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>\nstatic inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,\n                                       size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,\n                                       void* rhs_packed, size_t extra_bytes, const void* params) {\n    Fn(num_groups, n, k, nr, kr, sr,\n       static_cast<const int8_t*>(rhs),\n       static_cast<const float*>(bias),\n       static_cast<const float*>(scale),\n       rhs_packed, extra_bytes,\n       static_cast<const kai_rhs_pack_qsi8cx_params*>(params));\n}\n\ntemplate<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>\nstatic inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,\n                                               size_t rhs_stride, const void* rhs, const void* bias, const void* scale,\n                                               void* rhs_packed, size_t extra_bytes, const void* params) {\n    Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params);\n}\n\nstatic const size_t INT4_PER_BYTE = 2;\nstatic const size_t INT4_BITS     = 4;\nstatic const int Q4_0_ZERO_POINT  = 8;\nconst size_t INT4_PER_UINT16      = 4;\n\nstatic void dequantize_row_qsi4c32pscalef16(\n    const void *packed_data,\n    int32_t row_idx,\n    int64_t nc,\n    float *out,\n    size_t nr_pack,\n    size_t packed_row_stride,\n    size_t kr,\n    size_t bl,\n    size_t num_bytes_multiplier\n) {\n    size_t group_idx = row_idx / nr_pack;\n    size_t row_in_group = row_idx % nr_pack;\n    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * 
packed_row_stride;\n    size_t num_blocks = nc / bl;\n    const uint8_t *block_ptr = packed_group;\n\n    for (size_t b = 0; b < num_blocks; ++b) {\n        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));\n        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);\n\n        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;\n        size_t num_segments = bl / kr;\n        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;\n\n        for (size_t s = 0; s < num_segments; ++s) {\n            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;\n            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;\n            for (size_t k = 0; k < num_bytes_per_segment; ++k) {\n                uint8_t byte = qbytes[k] ^ 0x88;\n                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;\n                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;\n                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;\n                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;\n            }\n        }\n        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;\n    }\n}\n\nstatic void dequantize_row_qsi4c32ps1s0scalef16(\n    const void *packed_data,\n    int32_t row_idx,\n    int64_t k,\n    float *out,\n    size_t nr,\n    size_t packed_row_stride,\n    size_t kr,\n    size_t bl,\n    size_t num_bytes_multiplier\n) {\n    const size_t num_blocks = k / bl;\n    const size_t bl4 = bl / INT4_PER_UINT16;\n\n    size_t group_idx = row_idx / nr;\n    size_t row_in_group = row_idx % nr;\n\n    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;\n    const uint16_t *qdata = (const uint16_t *)packed_group;\n    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));\n\n    for (size_t block_idx = 0; 
block_idx < num_blocks; ++block_idx) {\n        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];\n        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);\n\n        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {\n            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];\n\n            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {\n                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;\n                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;\n            }\n        }\n    }\n    GGML_UNUSED(kr);\n}\n\nstatic void dequantize_row_qsi8cxp(\n    const void *packed_data,\n    int32_t row_idx,\n    int64_t k,\n    float *out,\n    size_t nr,\n    size_t packed_row_stride,\n    size_t kr,\n    size_t bl,\n    size_t num_bytes_multiplier\n) {\n    GGML_UNUSED(bl);\n    GGML_UNUSED(num_bytes_multiplier);\n\n    const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;\n    const size_t group_idx = row_idx / nr;\n    const size_t row_in_group = row_idx % nr;\n\n    const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;\n    const int8_t  * data_base = reinterpret_cast<const int8_t *>(group_ptr);\n\n    const size_t num_blocks = k_internal / kr;\n\n    for (size_t block = 0; block < num_blocks; ++block) {\n        const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;\n        for (size_t i = 0; i < kr; ++i) {\n            const size_t k_idx = block * kr + i;\n            if (k_idx < (size_t) k) {\n                out[k_idx] = static_cast<float>(block_ptr[i]);\n            }\n        }\n    }\n\n    const uint8_t * sums_ptr = group_ptr + nr * k_internal;\n    GGML_UNUSED(sums_ptr);\n\n    const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));\n    const float scale = scale_ptr[row_in_group];\n\n    if (scale == 0.0f) {\n        for (size_t i = 0; i < (size_t) k; ++i) {\n            
out[i] = 0.0f;\n        }\n        return;\n    }\n\n    for (size_t i = 0; i < (size_t) k; ++i) {\n        out[i] *= scale;\n    }\n}\n\nstatic ggml_kleidiai_kernels gemm_gemv_kernels[] = {\n#if defined(__ARM_FEATURE_SME)\n    {\n        /* SME GEMM */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,\n        },\n\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_f16pmrx2_f32_neon,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_pack_f16pmrx2_f32_neon>,\n            /* .packed_size_ex        = */ 
&lhs_ps_fn6<kai_get_lhs_packed_size_lhs_pack_f16pmrx2_f32_neon>,\n            /* .pack_func_ex          = */ &lhs_pack_void_fn10<kai_run_lhs_pack_f16pmrx2_f32_neon>,\n        },\n        /* SME GEMV */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,\n            /* 
.pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,\n            /* .to_float              = */ dequantize_row_qsi4c32ps1s0scalef16,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,\n            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_SME,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q4_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n    {\n        /* SME GEMM */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_lhs_offset_ex     = */ 
&kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,\n            /* .run_kernel_ex         = */ &kernel_run_fn10<kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,\n            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,\n            /* .pack_func_ex          = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,\n        },\n        /* SME GEMV */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,\n            /* .get_lhs_offset_ex     = */ nullptr,\n            /* .get_rhs_packed_offset_ex = */ nullptr,\n            /* .run_kernel_ex         = */ nullptr,\n   
     },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,\n            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,\n            /* .pack_func_ex          = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ nullptr,\n            /* .to_float              = */ nullptr,\n            /* .packed_size_ex        = */ &rhs_ps_fn2<kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn1<kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,\n            /* .pack_func_ex          = */ &rhs_pack_fn13<kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_SME,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_F16,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif\n#if defined(__APPLE__)\n#if defined(__ARM_FEATURE_DOTPROD)\n    {\n        /* DOTPROD GEMM */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_sr                = */ 
kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,\n        },\n        /* DOTPROD GEMV */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* 
.get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,\n            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_DOTPROD,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q4_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif\n#if 
defined(__ARM_FEATURE_MATMUL_INT8)\n    {\n        /* i8mm GEMM */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n        },\n        /* i8mm GEMV */\n        /* .kern_info 
= */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ 
kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,\n            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_I8MM,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q4_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif\n#else\n#if defined(__ARM_FEATURE_SVE)\n    {\n        /* SVE i8mm GEMM */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,\n            /* .get_rhs_packed_offset_ex = */ 
&kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n        },\n        /* SVE dotprod GEMV */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ 
&kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,\n            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q4_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif\n#if defined(__ARM_FEATURE_MATMUL_INT8)\n    {\n        /* i8mm GEMM */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_mr                = */ 
kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,\n        },\n        /* i8mm GEMV */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n      
      /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,\n            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* 
.pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_I8MM,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q4_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif // __ARM_FEATURE_MATMUL_INT8\n#if defined(__ARM_FEATURE_DOTPROD)\n    {\n        /* DOTPROD GEMM */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,\n            /* 
.get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,\n        },\n        /* DOTPROD GEMV */\n        /* .kern_info = */ {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,\n      
      /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,\n            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_DOTPROD,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q4_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif\n#endif\n    { /* Sentinel */ }\n};\n\nstatic ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {\n#if defined(__ARM_FEATURE_SME)\n    {\n        /* SME GEMM */\n        {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,\n            /* .get_dst_offset        = */ 
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,\n            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,\n        },\n        /* SME GEMV */\n        {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,\n            /* .get_dst_size          = 
*/ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,\n            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,\n            /* .to_float              = */ dequantize_row_qsi8cxp,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n            /* .pack_func_ex          = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_SME,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q8_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif\n#if defined(__ARM_FEATURE_MATMUL_INT8)\n    {\n        /* I8MM GEMM */\n        {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,\n            /* .get_n_step           
 = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,\n            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,\n        },\n        /* I8MM GEMV (dotprod fallback) */\n        {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,\n            /* .get_mr                = */ 
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,\n            /* .to_float              = */ dequantize_row_qsi8cxp,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n            /* 
.pack_func_ex          = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_I8MM,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q8_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif\n#if defined(__ARM_FEATURE_DOTPROD)\n    {\n        /* DOTPROD GEMM */\n        {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,\n        },\n        /* .gemm_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,\n            /* .get_packed_offset_ex  = */ 
&lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,\n            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,\n        },\n        /* DOTPROD GEMV */\n        {\n            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,\n            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,\n            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,\n            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,\n            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,\n            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,\n            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,\n            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,\n            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,\n            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,\n            /* .run_kernel_ex         = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,\n        },\n        /* .gemv_lhs_info = */ {\n            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,\n            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,\n            /* .packed_size_ex        = */ 
&lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,\n            /* .pack_func_ex          = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,\n        },\n        /* .rhs_info = */ {\n            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,\n            /* .to_float              = */ dequantize_row_qsi8cxp,\n            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n            /* .pack_func_ex          = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,\n        },\n        /* .required_cpu       = */ CPU_FEATURE_DOTPROD,\n        /* .lhs_type           = */ GGML_TYPE_F32,\n        /* .rhs_type           = */ GGML_TYPE_Q8_0,\n        /* .op_type            = */ GGML_TYPE_F32,\n    },\n#endif\n    { /* Sentinel */ }\n};\n\nggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {\n    ggml_kleidiai_kernels * kernel = nullptr;\n\n    if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {\n#if defined(__ARM_FEATURE_SME)          ||  \\\n    defined(__ARM_FEATURE_DOTPROD)      ||  \\\n    defined(__ARM_FEATURE_MATMUL_INT8)  ||  \\\n    defined(__ARM_FEATURE_SVE)\n        auto try_table = [&](auto & table) {\n            for (size_t i = 0; i < NELEMS(table) - 1; ++i) {\n                if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&\n                    table[i].lhs_type == tensor->src[1]->type &&\n                    table[i].rhs_type == tensor->src[0]->type &&\n                    table[i].op_type  == tensor->type) {\n                    kernel = &table[i];\n                    return true;\n                }\n            }\n            return false;\n        };\n\n        if 
(tensor->src[0]->type == GGML_TYPE_Q8_0) {\n            try_table(gemm_gemv_kernels_q8);\n        } else {\n            try_table(gemm_gemv_kernels);\n        }\n#else\n    GGML_UNUSED(gemm_gemv_kernels);\n    GGML_UNUSED(gemm_gemv_kernels_q8);\n    GGML_UNUSED(cpu_features);\n#endif\n    }\n\n    return kernel;\n}\n\nggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {\n    ggml_kleidiai_kernels * kernels = nullptr;\n\n#if defined(__ARM_FEATURE_SME)          ||  \\\n    defined(__ARM_FEATURE_DOTPROD)      ||  \\\n    defined(__ARM_FEATURE_MATMUL_INT8)  ||  \\\n    defined(__ARM_FEATURE_SVE)\n    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {\n        if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {\n            kernels = &gemm_gemv_kernels[i];\n            break;\n        }\n    }\n#else\n    GGML_UNUSED(features);\n#endif\n\n    return kernels;\n}\n\nggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {\n    ggml_kleidiai_kernels * kernels = nullptr;\n\n#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)\n    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {\n        if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {\n            kernels = &gemm_gemv_kernels_q8[i];\n            break;\n        }\n    }\n#else\n    GGML_UNUSED(features);\n#endif\n\n    return kernels;\n}\n"
  },
  {
    "path": "src/ggml-cpu/kleidiai/kernels.h",
    "content": "// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>\n// SPDX-License-Identifier: MIT\n//\n\n#pragma once\n\n#include \"ggml.h\"\n\nenum cpu_feature {\n    CPU_FEATURE_NONE    = 0,\n    CPU_FEATURE_DOTPROD = 1,\n    CPU_FEATURE_I8MM    = 2,\n    CPU_FEATURE_SVE     = 4,\n    CPU_FEATURE_SME     = 8\n};\n\ninline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {\n    lhs = static_cast<cpu_feature>(lhs | rhs);\n    return lhs;\n}\ninline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {\n    return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));\n}\n\nstruct kernel_info {\n    size_t (*get_m_step)(void);\n    size_t (*get_n_step)(void);\n    size_t (*get_mr)(void);\n    size_t (*get_nr)(void);\n    size_t (*get_kr)(void);\n    size_t (*get_sr)(void);\n\n    size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);\n    size_t (*get_dst_size)(size_t m, size_t n);\n\n    size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl);\n\n    size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl);\n\n    void (*run_kernel_ex)(\n        size_t m, size_t n, size_t k, size_t bl,\n        const void* lhs_packed, const void* rhs_packed,\n        void* dst, size_t dst_stride_row, size_t dst_stride_col,\n        float clamp_min, float clamp_max);\n};\n\nstruct lhs_packing_info {\n    size_t (*get_offset)(size_t m_idx, size_t lhs_stride);\n\n    size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);\n\n    size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);\n\n    void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,\n        size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed);\n};\n\nstruct rhs_packing_info {\n    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);\n\n    void 
(*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out,\n                     size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl,\n                     size_t num_bytes_multiplier);\n\n    size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);\n\n    size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl);\n\n    void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,\n        size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params);\n};\n\nstruct ggml_kleidiai_kernels {\n    kernel_info      gemm;\n    lhs_packing_info gemm_lhs_info;\n\n    kernel_info      gemv;\n    lhs_packing_info gemv_lhs_info;\n\n    rhs_packing_info rhs_info;\n\n    cpu_feature required_cpu;\n    ggml_type lhs_type;\n    ggml_type rhs_type;\n    ggml_type op_type;\n};\n\nggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);\nggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);\nggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);\n"
  },
  {
    "path": "src/ggml-cpu/kleidiai/kleidiai.cpp",
    "content": "// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>\n// SPDX-License-Identifier: MIT\n//\n#include <arm_neon.h>\n#include <assert.h>\n#include <stdio.h>\n#include <atomic>\n#include <cfloat>\n#include <algorithm>\n#include <cmath>\n#include <stdexcept>\n#include <stdint.h>\n#include <string.h>\n#include <string>\n#include <vector>\n#include <array>\n#include <cstddef>\n#include <cstdint>\n#include <fstream>\n#include <set>\n#include <iostream>\n#include <climits>\n#if defined(__linux__)\n#include <asm/hwcap.h>\n#include <sys/auxv.h>\n#include <sys/types.h>\n#include <sys/stat.h>\n#include <unistd.h>\n#elif defined(__APPLE__)\n#include <string_view>\n#include <sys/sysctl.h>\n#include <sys/types.h>\n#elif defined(_WIN32)\n#include <windows.h>\n#include <excpt.h>\n#endif\n\n#include \"kleidiai.h\"\n\n#include \"ggml-cpu.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-threading.h\"\n#include \"traits.h\"\n\n#include \"kernels.h\"\n\n#include \"kai_common.h\"\n\n#define GGML_COMMON_DECL_CPP\n#include \"ggml-common.h\"\n\nstatic constexpr int      GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;\nstatic constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC       = 0x4b4c4149; // \"KLAI\"\nstatic constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION     = 1;\nstatic constexpr size_t   GGML_KLEIDIAI_PACK_ALIGN       = 64;\n\nstruct ggml_kleidiai_context {\n    cpu_feature features;\n    ggml_kleidiai_kernels * kernels_q4;\n    ggml_kleidiai_kernels * kernels_q8;\n    int sme_thread_cap; // <= 0 means “SME disabled/unknown”;\n    int thread_hint;    // <= 0 means “no hint”\n} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };\n\nstatic const char* cpu_feature_to_string(cpu_feature f) {\n    if (f == CPU_FEATURE_NONE) {\n        return \"NONE\";\n    } else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {\n        return \"SME\";\n    } else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {\n   
     return \"SVE\";\n    }\n    else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {\n        return \"I8MM\";\n    } else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {\n        return \"DOTPROD\";\n    }\n    else {\n        return \"UNKNOWN\";\n    }\n}\n\nstatic size_t detect_num_smcus() {\n    if (!ggml_cpu_has_sme()) {\n        return 0;\n    }\n\n#if defined(__linux__) && defined(__aarch64__)\n    // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.\n    size_t num_private = 0;\n    std::set<uint32_t> shared_ids;\n\n    for (size_t cpu = 0;; ++cpu) {\n        const std::string path =\n            \"/sys/devices/system/cpu/cpu\" + std::to_string(cpu) +\n            \"/regs/identification/smidr_el1\";\n\n        std::ifstream file(path);\n        if (!file.is_open()) {\n            break;\n        }\n\n        uint64_t smidr = 0;\n        if (!(file >> std::hex >> smidr)) {\n            continue;\n        }\n\n        // Arm ARM: SMIDR_EL1\n        const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);\n        // Build an \"affinity-like\" identifier for shared SMCUs.\n        // Keep the original packing logic, but isolate it here.\n        const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));\n\n        switch (sh) {\n            case 0b10: // private SMCU\n                ++num_private;\n                break;\n            case 0b11: // shared SMCU\n                shared_ids.emplace(id);\n                break;\n            case 0b00:\n                // Ambiguous / implementation-defined. 
Be conservative:\n                // treat id==0 as private, otherwise as shared.\n                if (id == 0) ++num_private;\n                else shared_ids.emplace(id);\n                break;\n            default:\n                break;\n        }\n    }\n\n    return num_private + shared_ids.size();\n\n#elif defined(__APPLE__) && defined(__aarch64__)\n    // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.\n    char chip_name[256] = {};\n    size_t size = sizeof(chip_name);\n\n    if (sysctlbyname(\"machdep.cpu.brand_string\", chip_name, &size, nullptr, 0) == 0) {\n        const std::string brand(chip_name);\n\n        struct ModelSMCU { const char *match; size_t smcus; };\n        static const ModelSMCU table[] = {\n            { \"M4 Ultra\", 2 },\n            { \"M4 Max\",   2 },\n            { \"M4 Pro\",   2 },\n            { \"M4\",       1 },\n        };\n\n        for (const auto &e : table) {\n            if (brand.find(e.match) != std::string::npos) {\n                return e.smcus;\n            }\n        }\n    }\n    return 1;\n\n#else\n    return 1;\n#endif\n}\n\nstatic int parse_uint_env(const char *s, const char *name, bool *ok) {\n    if (!s) { *ok = false; return 0; }\n    char *end = nullptr;\n    long v = strtol(s, &end, 10);\n    if (end == s || *end != '\\0') {\n        GGML_LOG_WARN(\"kleidiai: invalid %s='%s' (expected integer)\\n\", name, s);\n        *ok = false;\n        return 0;\n    }\n    if (v < 0 || v > INT_MAX) {\n        GGML_LOG_WARN(\"kleidiai: out-of-range %s='%s'\\n\", name, s);\n        *ok = false;\n        return 0;\n    }\n    *ok = true;\n    return (int)v;\n}\n\nstatic void init_kleidiai_context(void) {\n    ggml_critical_section_start();\n    static bool initialized = false;\n\n    if (!initialized) {\n        initialized = true;\n\n        const char *env_sme     = getenv(\"GGML_KLEIDIAI_SME\");\n        const char *env_threads = getenv(\"GGML_TOTAL_THREADS\");\n\n        const bool 
cpu_has_sme = ggml_cpu_has_sme();\n        size_t detected_smcus = 0;\n\n        ctx.features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |\n                        (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |\n                        ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);\n\n        if (env_threads) {\n            bool ok = false;\n            int hint = parse_uint_env(env_threads, \"GGML_TOTAL_THREADS\", &ok);\n            if (ok && hint > 0) {\n                ctx.thread_hint = hint;\n            }\n        }\n\n        // SME policy:\n        // - If CPU doesn't support SME: SME always off.\n        // - Else:\n        //   - env unset => auto-detect cores; enable if detected > 0.\n        //   - env=0     => force off.\n        //   - env>0     => force N cores (skip detection).\n        int sme_cores = 0;\n        bool sme_env_ok = false;\n        bool sme_env_set = (env_sme != nullptr);\n\n        if (!cpu_has_sme) {\n            if (sme_env_set) {\n                bool ok = false;\n                int req = parse_uint_env(env_sme, \"GGML_KLEIDIAI_SME\", &ok);\n                if (ok && req > 0) {\n                    GGML_LOG_WARN(\"kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\\n\", req);\n                }\n            }\n            sme_cores = 0;\n        } else {\n            if (sme_env_set) {\n                bool ok = false;\n                int v = parse_uint_env(env_sme, \"GGML_KLEIDIAI_SME\", &ok);\n                sme_env_ok = ok;\n\n                if (!ok) {\n                    GGML_LOG_WARN(\"kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\\n\");\n                    detected_smcus = detect_num_smcus();\n                    sme_cores = detected_smcus > 0 ? 
(int)detected_smcus : 0;\n                } else if (v == 0) {\n                    sme_cores = 0;\n                } else {\n                    sme_cores = v;\n                }\n            } else {\n                detected_smcus = detect_num_smcus();\n                sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;\n            }\n\n            if (!sme_env_set && sme_cores == 0) {\n                GGML_LOG_WARN(\"kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\\n\");\n            }\n\n            if (sme_cores > 0) {\n                ctx.features |= CPU_FEATURE_SME;\n            }\n        }\n\n        // Kernel selection\n        ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);\n        ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);\n\n        if (!ctx.kernels_q4) {\n            GGML_LOG_INFO(\"kleidiai: no compatible q4 kernels found for CPU features mask %d\\n\", (int)ctx.features);\n        } else {\n            GGML_LOG_INFO(\"kleidiai: primary q4 kernel feature %s\\n\", cpu_feature_to_string(ctx.kernels_q4->required_cpu));\n        }\n\n        if (!ctx.kernels_q8) {\n            GGML_LOG_INFO(\"kleidiai: no compatible q8 kernels found for CPU features mask %d\\n\", (int)ctx.features);\n        } else {\n            GGML_LOG_INFO(\"kleidiai: primary q8 kernel feature %s\\n\", cpu_feature_to_string(ctx.kernels_q8->required_cpu));\n        }\n\n        ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? 
sme_cores : 0;\n\n        if (ctx.features & CPU_FEATURE_SME) {\n            if (sme_env_set && sme_env_ok && sme_cores > 0) {\n                GGML_LOG_INFO(\"kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\\n\", sme_cores);\n            } else {\n                GGML_LOG_INFO(\"kleidiai: SME enabled (runtime-detected SME cores=%d)\\n\", sme_cores);\n            }\n        } else {\n            GGML_LOG_INFO(\"kleidiai: SME disabled\\n\");\n        }\n    }\n\n    ggml_critical_section_end();\n}\n\nstatic inline int kleidiai_sme_thread_cap() {\n    return ctx.sme_thread_cap;\n}\n\nstatic inline size_t align_up(size_t value, size_t alignment) {\n    if (alignment == 0) {\n        return value;\n    }\n    const size_t remainder = value % alignment;\n    return remainder == 0 ? value : value + (alignment - remainder);\n}\n\nstatic inline bool kleidiai_pack_fallback_allowed() {\n    if (ctx.sme_thread_cap <= 0) {\n        return false;\n    }\n    if (ctx.thread_hint <= 0) {\n        return true;\n    }\n    return ctx.thread_hint > ctx.sme_thread_cap;\n}\n\nstruct kleidiai_weight_header {\n    uint32_t magic;\n    uint16_t version;\n    uint16_t slot_count;\n    uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];\n    uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];\n};\n\nstatic inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {\n    return reinterpret_cast<kleidiai_weight_header *>(data);\n}\n\nstatic inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {\n    return reinterpret_cast<const kleidiai_weight_header *>(data);\n}\n\nstatic inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {\n    if (!header) {\n        return false;\n    }\n    if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {\n        return false;\n    }\n    if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {\n        
return false;\n    }\n    return true;\n}\n\nstatic inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {\n    if (!kleidiai_is_weight_header_valid(header)) {\n        return nullptr;\n    }\n    if (slot < 0 || slot >= header->slot_count) {\n        return nullptr;\n    }\n    return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];\n}\n\nstatic inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {\n    if (!kleidiai_is_weight_header_valid(header)) {\n        return nullptr;\n    }\n    if (slot < 0 || slot >= header->slot_count) {\n        return nullptr;\n    }\n    return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];\n}\n\nstatic inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {\n    return ctx.kernels_q4;\n}\n\nstatic inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {\n    return ctx.kernels_q8;\n}\n\ntemplate <typename SelectFallback>\nstatic int kleidiai_collect_kernel_chain_common(\n        ggml_kleidiai_kernels * primary,\n        cpu_feature features,\n        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,\n        SelectFallback select_fallback) {\n    int count = 0;\n    if (!primary) {\n        return 0;\n    }\n    out[count++] = primary;\n\n    if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {\n        const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);\n        if (fallback_mask != CPU_FEATURE_NONE) {\n            ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);\n            if (fallback && fallback != primary &&\n                fallback->lhs_type == primary->lhs_type &&\n                fallback->rhs_type == primary->rhs_type &&\n                fallback->op_type  == primary->op_type) {\n                out[count++] = fallback;\n            }\n        }\n    }\n\n    return count;\n}\n\nstatic int 
kleidiai_collect_kernel_chain(const struct ggml_tensor * op,\n        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {\n    ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);\n    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,\n        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });\n}\n\nstatic int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {\n    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();\n    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,\n        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });\n}\n\nstatic int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {\n    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();\n    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,\n        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });\n}\n\nstatic inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {\n    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);\n    return tensor->ne[dim];\n}\n\nnamespace ggml::cpu::kleidiai {\n\nstatic size_t round_down(size_t x, size_t y) {\n    return y == 0 ? 
x : x - (x % y);\n}\n\nstatic void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {\n    size_t src_stride = rhs_stride / sizeof(uint16_t);\n    size_t dst_stride = n;\n\n    for (size_t k_idx = 0; k_idx < k; ++k_idx) {\n        for (size_t n_idx = 0; n_idx < n; ++n_idx) {\n            uint16_t v = *(src + k_idx + n_idx * src_stride);\n            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);\n        }\n    }\n}\n\nclass tensor_traits : public ggml::cpu::tensor_traits {\n    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {\n        if (op->op != GGML_OP_MUL_MAT) {\n            return false;\n        }\n\n        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;\n        const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);\n        if (slot_count == 0) {\n            return false;\n        }\n\n        const bool is_gemv = op->src[1]->ne[1] == 1;\n        const size_t k = op->src[0]->ne[0];\n        const size_t n = op->src[0]->ne[1];\n        const size_t m = op->src[1]->ne[1];\n\n        if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {\n            const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;\n\n            size_t cursor = 0;\n            bool any_slot = false;\n\n            for (int slot = 0; slot < slot_count; ++slot) {\n                ggml_kleidiai_kernels * kernels = kernel_chain[slot];\n                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;\n                kernel_info * kernel        = is_gemv ? 
&kernels->gemv : &kernels->gemm;\n\n                if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {\n                    return false;\n                }\n\n                const size_t mr = kernel->get_mr();\n                const size_t kr = kernel->get_kr();\n                const size_t sr = kernel->get_sr();\n\n                const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);\n\n                cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n                cursor += packed;\n                any_slot = true;\n            }\n\n            if (!any_slot) {\n                return false;\n            }\n\n            size = cursor;\n            return true;\n        }\n\n        if (op->src[0]->type == GGML_TYPE_F16) {\n            const int64_t lhs_batch_size0 = op->src[1]->ne[2];\n            const int64_t rhs_batch_size0 = op->src[0]->ne[2];\n            GGML_ASSERT(rhs_batch_size0 > 0);\n            const int64_t r = lhs_batch_size0 / rhs_batch_size0;\n\n            size_t cursor = 0;\n            bool any_slot = false;\n\n            for (int slot = 0; slot < slot_count; ++slot) {\n                ggml_kleidiai_kernels * kernels = kernel_chain[slot];\n                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;\n                kernel_info * kernel        = is_gemv ? 
&kernels->gemv : &kernels->gemm;\n                if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {\n                    return false;\n                }\n\n                const size_t mr = kernel->get_mr();\n                const size_t kr = kernel->get_kr();\n                const size_t sr = kernel->get_sr();\n\n                cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n                cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);\n                any_slot = true;\n            }\n\n            for (int slot = 0; slot < slot_count; ++slot) {\n                ggml_kleidiai_kernels * kernels = kernel_chain[slot];\n                kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;\n                if (!kernel || !kernels->rhs_info.packed_size_ex) {\n                    return false;\n                }\n                cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n                cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);\n            }\n\n            cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n            cursor += k * n * sizeof(float);\n            cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n            cursor += n * sizeof(float);\n\n            if (!any_slot) {\n                return false;\n            }\n\n            size = cursor;\n            return true;\n        }\n\n        return false;\n    }\n\n    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {\n        if (dst->op == GGML_OP_MUL_MAT) {\n            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {\n                return compute_forward_qx(params, dst);\n            } else if (dst->src[0]->type == GGML_TYPE_F16) {\n                return compute_forward_fp16(params, dst);\n            }\n        } else if (dst->op == GGML_OP_GET_ROWS) {\n            if 
(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {\n                return compute_forward_get_rows(params, dst);\n            }\n        }\n        return false;\n    }\n\n    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {\n        const ggml_tensor * src0 = dst->src[0];\n        const ggml_tensor * src1 = dst->src[1];\n\n        GGML_TENSOR_BINARY_OP_LOCALS\n\n        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);\n        if (!kernels) {\n            return false;\n        }\n\n        const bool is_gemv = src1->ne[1] == 1;\n        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;\n        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;\n        GGML_ASSERT(kernel);\n        if (!kernels->rhs_info.pack_func_ex ||\n            !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {\n            return false;\n        }\n\n        const int nth = params->nth;\n        const int ith = params->ith;\n\n        const int64_t lhs_batch_size0 = ne12;\n        const int64_t rhs_batch_size0 = ne02;\n        const int64_t batch_size      = lhs_batch_size0;\n\n        GGML_ASSERT(rhs_batch_size0 > 0);\n        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);\n        const int64_t r = lhs_batch_size0 / rhs_batch_size0;\n\n        const int64_t m_group = ne11;\n        const int64_t m       = m_group;\n        const int64_t n       = ne01;\n        const int64_t k       = ne00;\n\n        const size_t lhs_stride = src1->nb[1];\n        const size_t rhs_stride = src0->nb[1];\n        const size_t dst_stride = dst->nb[1];\n\n        const int64_t mr = (int64_t) kernel->get_mr();\n        const int64_t nr = (int64_t) kernel->get_nr();\n        const int64_t kr = (int64_t) kernel->get_kr();\n        const int64_t sr = (int64_t) kernel->get_sr();\n\n        const size_t 
lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);\n        const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);\n        const size_t kxn_size        = k * n * sizeof(float);\n        const size_t bias_size       = n * sizeof(float);\n\n        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;\n        GGML_ASSERT(wsize_required <= params->wsize);\n\n        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);\n        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;\n        uint8_t * rhs_kxn    = rhs_packed + rhs_packed_size;\n        uint8_t * bias       = rhs_kxn + kxn_size;\n\n        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {\n            const int64_t rhs_batch_idx = batch_idx / r;\n            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];\n            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];\n\n            // LHS packing (threaded over m, honoring mr alignment and KV groups)\n            {\n                const int64_t m_roundup_mr = kai_roundup(m, mr);\n                const int64_t num_threads  = KAI_MIN(m_roundup_mr / mr, nth);\n\n                if (ith < num_threads) {\n                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);\n                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;\n\n                    const int64_t m_start = ith * num_m_per_thread0;\n                    const int64_t m_count = (ith == num_threads - 1) ? 
num_m_per_threadN_1 : num_m_per_thread0;\n\n                    // Base packed offset (aligned) and per-row stride in bytes\n                    const size_t base_packed_off  = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);\n                    const size_t next_block_off   = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);\n                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;\n\n                    int64_t remaining = m_count;\n                    int64_t cur       = m_start;\n\n                    while (remaining > 0) {\n                        const int64_t row_in_group = cur;\n                        const int64_t avail        = m_group - row_in_group;\n                        const int64_t take         = std::min(avail, remaining);\n\n                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];\n                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;\n                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;\n                        void * dst_ptr       = lhs_packed + dst_off;\n\n                        lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);\n\n                        cur       += take;\n                        remaining -= take;\n                    }\n                }\n            }\n\n            // RHS packing (single thread), then synchronize\n            if (ith == 0) {\n                memset(bias, 0, (size_t)n * sizeof(float));\n                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,\n                                        reinterpret_cast<float *>(rhs_kxn),\n                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),\n                                        rhs_stride);\n\n                kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 
0, n * sizeof(float),\n                             rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);\n            }\n\n            ggml_barrier(params->threadpool);\n\n            // Matmul (threaded over n)\n            {\n                const int64_t n_step  = (int64_t) kernel->get_n_step();\n                int64_t num_threads_n = KAI_MIN(n / n_step, nth);\n                if (num_threads_n <= 0) {\n                    num_threads_n = 1;\n                }\n\n                if (ith < num_threads_n) {\n                    const int64_t num_n_per_thread0   = round_down((size_t)(n / num_threads_n), (size_t)n_step);\n                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;\n\n                    const int64_t n_start      = ith * num_n_per_thread0;\n                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;\n\n                    // LHS packed base at row 0 (consistent with packing above)\n                    const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);\n                    const size_t rhs_packed_offset  = kernel->get_rhs_packed_offset_ex(n_start, k, 0);\n                    const size_t dst_offset         = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);\n\n                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;\n                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;\n                    float * dst_ptr      = reinterpret_cast<float *>(dst_batch_base + dst_offset);\n\n                    kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);\n                }\n            }\n\n            if (batch_idx != batch_size - 1) {\n                ggml_barrier(params->threadpool);\n            }\n        }\n\n        return true;\n    }\n\n    bool compute_forward_qx(struct ggml_compute_params * params, 
struct ggml_tensor * dst) {\n        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);\n\n        const ggml_tensor * src0 = dst->src[0];\n        const ggml_tensor * src1 = dst->src[1];\n\n        GGML_TENSOR_BINARY_OP_LOCALS\n\n        const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);\n        const bool has_header = kleidiai_is_weight_header_valid(header);\n        const bool is_gemv = src1->ne[1] == 1;\n        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;\n        const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);\n\n        auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {\n            if (slot_index < 0 || slot_index >= slot_total) {\n                return nullptr;\n            }\n            if (has_header) {\n                if (slot_index < header->slot_count) {\n                    size_out = static_cast<size_t>(header->sizes[slot_index]);\n                    return kleidiai_weight_slot_ptr(header, slot_index);\n                }\n                return nullptr;\n            }\n            if (slot_index == 0) {\n                size_out = ggml_nbytes(src0);\n                return static_cast<const uint8_t *>(src0->data);\n            }\n            return nullptr;\n        };\n\n        struct runtime_slot {\n            int slot_index;\n            ggml_kleidiai_kernels * kernels;\n            kernel_info * kernel;\n            lhs_packing_info * lhs_info;\n            size_t mr;\n            size_t nr;\n            size_t kr;\n            size_t sr;\n            size_t n_step;\n            size_t lhs_packed_size;\n            size_t lhs_offset;\n            size_t n_offset;\n            size_t n_cols;\n            int assigned_threads;\n            int thread_begin;\n            int thread_end;\n            const uint8_t * rhs_base;\n        };\n\n        std::array<runtime_slot, 
GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};\n        int runtime_count = 0;\n\n        for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {\n            ggml_kleidiai_kernels * kernels = kernel_chain[slot];\n            kernel_info * kinfo      = is_gemv ? &kernels->gemv : &kernels->gemm;\n            lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;\n            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||\n                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {\n                continue;\n            }\n\n            size_t rhs_size = 0;\n            const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);\n            if (!rhs_ptr || rhs_size == 0) {\n                continue;\n            }\n\n            runtime[runtime_count] = {\n                slot,\n                kernels,\n                kinfo,\n                linfo,\n                kinfo->get_mr(),\n                kinfo->get_nr(),\n                kinfo->get_kr(),\n                kinfo->get_sr(),\n                kinfo->get_n_step(),\n                0,\n                0,\n                0,\n                0,\n                0,\n                0,\n                0,\n                rhs_ptr\n            };\n            ++runtime_count;\n        }\n\n        if (runtime_count == 0) {\n            ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);\n            if (!fallback) {\n                return false;\n            }\n            kernel_info * kinfo      = is_gemv ? &fallback->gemv : &fallback->gemm;\n            lhs_packing_info * linfo = is_gemv ? 
&fallback->gemv_lhs_info : &fallback->gemm_lhs_info;\n            rhs_packing_info * rinfo = &fallback->rhs_info;\n            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||\n                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||\n                !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {\n                return false;\n            }\n            kernel_chain[0] = fallback;\n            runtime[0] = {\n                0,\n                fallback,\n                kinfo,\n                linfo,\n                kinfo->get_mr(),\n                kinfo->get_nr(),\n                kinfo->get_kr(),\n                kinfo->get_sr(),\n                kinfo->get_n_step(),\n                0,\n                0,\n                0,\n                0,\n                0,\n                0,\n                0,\n                nullptr\n            };\n            size_t rhs_size_fallback = 0;\n            const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);\n            if (!rhs_base) {\n                rhs_base = static_cast<const uint8_t *>(src0->data);\n            }\n            runtime[0].rhs_base = rhs_base;\n            runtime_count = 1;\n        }\n\n        const int nth_total = params->nth > 0 ? 
params->nth : 1;\n        const int ith_total = params->ith;\n\n        int sme_slot = -1;\n        for (int i = 0; i < runtime_count; ++i) {\n            if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {\n                sme_slot = i;\n                break;\n            }\n        }\n\n        const int sme_cap_limit = ctx.sme_thread_cap;\n        const bool use_hybrid = sme_cap_limit > 0 &&\n                                 runtime_count > 1 &&\n                                 nth_total > sme_cap_limit;\n        // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.\n        // If rows are small or average columns per thread are small, keep single-slot.\n        size_t min_cols_per_thread = 0;\n        if (runtime_count > 0 && nth_total > 0) {\n            min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);\n        }\n        const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);\n\n        const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;\n\n        if (!hybrid_enabled) {\n            int chosen_slot = 0;\n            if (too_small_for_hybrid && sme_slot != -1) {\n                chosen_slot = sme_slot;\n            } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {\n                chosen_slot = 1;\n            }\n            if (chosen_slot != 0 && chosen_slot < runtime_count) {\n                runtime[0] = runtime[chosen_slot];\n            }\n            runtime_count = runtime_count > 0 ? 
1 : 0;\n\n            // Recompute SME slot based on the collapsed runtime[0]\n            sme_slot = -1;\n            if (runtime_count > 0 &&\n                (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {\n                sme_slot = 0;\n            }\n        }\n\n        int sme_cap = kleidiai_sme_thread_cap();\n        if (sme_cap < 0) {\n            sme_cap = nth_total;\n        }\n        sme_cap = std::min(sme_cap, nth_total);\n\n        int threads_remaining = nth_total;\n        if (sme_slot != -1) {\n            int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);\n            runtime[sme_slot].assigned_threads = sme_threads;\n            threads_remaining -= sme_threads;\n        }\n\n        int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];\n        int fallback_count = 0;\n        for (int i = 0; i < runtime_count; ++i) {\n            if (i == sme_slot) {\n                continue;\n            }\n            fallback_indices[fallback_count++] = i;\n        }\n\n        for (int fi = 0; fi < fallback_count; ++fi) {\n            if (threads_remaining <= 0) {\n                break;\n            }\n            const int slot_index = fallback_indices[fi];\n            const int slots_left = fallback_count - fi;\n            int share = (threads_remaining + slots_left - 1) / slots_left;\n            share     = std::min(share, threads_remaining);\n            runtime[slot_index].assigned_threads = share;\n            threads_remaining -= share;\n        }\n\n        if (threads_remaining > 0) {\n            const int fallback_slot = (sme_slot != -1) ? 
sme_slot : 0;\n            runtime[fallback_slot].assigned_threads += threads_remaining;\n            threads_remaining = 0;\n        }\n\n        int thread_cursor = 0;\n        for (int i = 0; i < runtime_count; ++i) {\n            runtime[i].thread_begin = thread_cursor;\n            thread_cursor += runtime[i].assigned_threads;\n            runtime[i].thread_end = thread_cursor;\n        }\n\n        if (thread_cursor < nth_total && runtime_count > 0) {\n            runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;\n            runtime[runtime_count - 1].thread_end = nth_total;\n        }\n\n        int local_slot = -1;\n        int local_ith  = 0;\n        for (int i = 0; i < runtime_count; ++i) {\n            if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {\n                local_slot = i;\n                local_ith  = ith_total - runtime[i].thread_begin;\n                break;\n            }\n        }\n        if (local_slot == -1) {\n            return false;\n        }\n\n        const size_t k = ne00;\n        const size_t m = ne11;\n        const size_t n = ne01;\n\n        size_t cursor = 0;\n        for (int i = 0; i < runtime_count; ++i) {\n            const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;\n            const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :\n                                              slot_rhs_type == GGML_TYPE_Q8_0 ? 
QK8_0 : 0;\n            runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);\n            cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n            runtime[i].lhs_offset = cursor;\n            cursor += runtime[i].lhs_packed_size;\n        }\n\n        GGML_ASSERT(cursor <= params->wsize);\n        uint8_t * scratch = static_cast<uint8_t *>(params->wdata);\n\n        size_t assigned_cols = 0;\n        uint64_t weighted_total = 0;\n        if (runtime_count > 1 && sme_slot != -1) {\n            for (int i = 0; i < runtime_count; ++i) {\n                const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;\n                weighted_total += (uint64_t)runtime[i].assigned_threads * weight;\n            }\n        }\n        for (int i = 0; i < runtime_count; ++i) {\n            runtime[i].n_offset = assigned_cols;\n            if (runtime[i].assigned_threads == 0) {\n                runtime[i].n_cols = 0;\n                continue;\n            }\n            const size_t remaining_cols = n - assigned_cols;\n            if (remaining_cols == 0) {\n                runtime[i].n_cols = 0;\n                continue;\n            }\n            const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;\n            size_t target      = 0;\n            if (weighted_total > 0) {\n                const uint64_t weight = (i == sme_slot) ? 
(sme_cap << 1) : 1;\n                target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);\n            } else {\n                target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);\n            }\n            target             = std::min(target, remaining_cols);\n            size_t aligned     = round_down(target, step);\n            if (aligned == 0 && remaining_cols >= step) {\n                aligned = step;\n            }\n            runtime[i].n_cols = aligned;\n            assigned_cols += aligned;\n        }\n\n        if (assigned_cols < n) {\n            for (int i = runtime_count - 1; i >= 0; --i) {\n                if (runtime[i].assigned_threads > 0) {\n                    runtime[i].n_cols += n - assigned_cols;\n                    break;\n                }\n            }\n        }\n        const size_t dst_stride = dst->nb[1];\n\n        for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {\n            const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];\n            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];\n\n            if (runtime[local_slot].assigned_threads > 0) {\n                runtime_slot & slot = runtime[local_slot];\n                const ggml_type slot_rhs_type = slot.kernels->rhs_type;\n                const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :\n                                                 slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;\n                const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);\n                int64_t max_threads = slot.mr ? 
(m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;\n                max_threads = std::max<int64_t>(1, max_threads);\n                const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);\n\n                if (local_ith < use_threads) {\n                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);\n                    const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;\n\n                    const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;\n                    const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;\n\n                    const size_t base_packed_off  = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);\n                    const size_t next_block_off   = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);\n                    const size_t row_stride_bytes = slot.mr ? 
(next_block_off - base_packed_off) / slot.mr : 0;\n\n                    int64_t remaining = m_count;\n                    int64_t cur       = m_start;\n\n                    uint8_t * lhs_packed = scratch + slot.lhs_offset;\n                    while (remaining > 0) {\n                        const int64_t row_in_group = cur;\n                        const int64_t avail        = (int64_t)m - row_in_group;\n                        const int64_t take         = std::min(avail, remaining);\n\n                        const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);\n                        const void * src_ptr = lhs_batch_base + src_off;\n                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;\n                        void * dst_ptr       = lhs_packed + dst_off;\n\n                        slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);\n\n                        cur       += take;\n                        remaining -= take;\n                    }\n                }\n            }\n\n            ggml_barrier(params->threadpool);\n\n            runtime_slot & slot = runtime[local_slot];\n            if (slot.n_cols > 0 && slot.assigned_threads > 0) {\n                int64_t active_threads = slot.assigned_threads;\n                const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;\n                if (max_threads > 0) {\n                    active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));\n                }\n                active_threads = std::max<int64_t>(1, active_threads);\n\n                if (local_ith < active_threads) {\n                    const size_t step = slot.n_step ? 
slot.n_step : 1;\n                    const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);\n                    const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;\n                    const size_t local_start = (size_t)local_ith * chunk0;\n                    const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;\n\n                    if (cols > 0) {\n                        const ggml_type slot_rhs_type = slot.kernels->rhs_type;\n                        const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :\n                                                         slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;\n                        const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :\n                                                          slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;\n                        const size_t global_start = slot.n_offset + local_start;\n                        const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);\n                        const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);\n                        const size_t dst_offset        = slot.kernel->get_dst_offset(0, global_start, dst_stride);\n\n                        const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;\n                        const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;\n                        float * dst_ptr         = reinterpret_cast<float *>(dst_batch_base + dst_offset);\n\n                        slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,\n                                                   lhs_ptr,\n                                                   rhs_ptr,\n                                                   dst_ptr,\n                                                   
dst_stride,\n                                                   sizeof(float),\n                                                   -FLT_MAX,\n                                                   FLT_MAX);\n                    }\n                }\n            }\n\n            if (batch_idx != ne12 - 1) {\n                ggml_barrier(params->threadpool);\n            }\n        }\n\n        return true;\n    }\n\n    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {\n        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);\n        const ggml_tensor * src0 = dst->src[0];\n        const ggml_tensor * src1 = dst->src[1];\n\n        GGML_TENSOR_BINARY_OP_LOCALS\n\n        const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);\n        const bool has_header = kleidiai_is_weight_header_valid(header);\n\n        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;\n        const bool want_q8 = src0->type == GGML_TYPE_Q8_0;\n        const int chain_count = want_q8 ? 
kleidiai_collect_q8_chain(kernel_chain)\n                                        : kleidiai_collect_q4_chain(kernel_chain);\n\n        ggml_kleidiai_kernels * kernels = nullptr;\n        const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);\n\n        if (has_header && chain_count > 0) {\n            int select_slot = 0;\n            if (select_slot >= header->slot_count) {\n                select_slot = header->slot_count - 1;\n            }\n            if (select_slot >= 0 && select_slot < chain_count) {\n                kernels = kernel_chain[select_slot];\n                const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);\n                if (slot_ptr) {\n                    packed_base = slot_ptr;\n                }\n            }\n        }\n\n        if (!kernels && chain_count > 0) {\n            kernels = kernel_chain[0];\n            if (has_header) {\n                const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);\n                if (slot_ptr) {\n                    packed_base = slot_ptr;\n                }\n            }\n        }\n\n        if (!kernels) {\n            return false;\n        }\n\n        rhs_packing_info * rhs_info = &kernels->rhs_info;\n        kernel_info * kernel        = &kernels->gemm;\n        if (!rhs_info->to_float || !kernel->get_nr) {\n            return false;\n        }\n\n        const int64_t nc     = ne00;\n        const int64_t nr     = ggml_nelements(src1);\n\n        const ggml_type rhs_type = kernels->rhs_type;\n        size_t block_len = 0;\n        size_t num_bytes_multiplier = 0;\n        if (rhs_type == GGML_TYPE_Q4_0) {\n            block_len = QK4_0;\n            num_bytes_multiplier = sizeof(uint16_t);\n        } else if (rhs_type == GGML_TYPE_Q8_0) {\n            block_len = QK8_0;\n            num_bytes_multiplier = sizeof(float);\n        } else {\n            return false;\n        }\n\n        const size_t block_rows = kernel->get_nr();\n        
const size_t kr         = kernel->get_kr();\n\n        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);\n\n        const int ith = params->ith;\n        const int nth = params->nth;\n\n        const int dr = (nr + nth - 1) / nth;\n        const int ir0 = dr * ith;\n        const int ir1 = MIN(ir0 + dr, nr);\n\n        for (int64_t i = ir0; i < ir1; ++i) {\n            GGML_ASSERT(src1->type == GGML_TYPE_I32);\n            int64_t row_idx = ((const int32_t *)src1->data)[i];\n            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);\n\n            float *out = (float *)((char *)dst->data + i * nb1);\n            rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);\n        }\n\n        return true;\n    }\n\npublic:\n    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {\n        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);\n        const size_t n = tensor->ne[1];\n        const size_t k = tensor->ne[0];\n\n        kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);\n        if (!header) {\n            return -1;\n        }\n\n        header->magic      = GGML_KLEIDIAI_PACK_MAGIC;\n        header->version    = GGML_KLEIDIAI_PACK_VERSION;\n        header->slot_count = 0;\n\n        uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);\n        size_t cursor = sizeof(kleidiai_weight_header);\n        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n\n        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;\n        const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;\n        const int slot_total = want_q8 ? 
kleidiai_collect_q8_chain(kernel_chain)\n                                       : kleidiai_collect_q4_chain(kernel_chain);\n        const bool allow_fallback = kleidiai_pack_fallback_allowed();\n\n        std::vector<int8_t> qdata;\n        std::vector<float>  scales;\n\n        if (want_q8 && slot_total > 0) {\n            qdata.resize(n * k, 0);\n            scales.resize(n, 0.0f);\n\n            const size_t row_stride = tensor->nb[1];\n            const size_t k_blocks   = (k + QK8_0 - 1) / QK8_0;\n\n            for (size_t row = 0; row < n; ++row) {\n                const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(\n                    static_cast<const uint8_t *>(data) + row * row_stride);\n\n                float max_abs = 0.0f;\n                for (size_t block = 0; block < k_blocks; ++block) {\n                    const block_q8_0 & blk = row_blocks[block];\n                    const float d = GGML_FP16_TO_FP32(blk.d);\n                    for (size_t l = 0; l < QK8_0; ++l) {\n                        const size_t linear_idx = block * QK8_0 + l;\n                        if (linear_idx >= k) {\n                            break;\n                        }\n                        const float value = d * static_cast<float>(blk.qs[l]);\n                        max_abs = std::max(max_abs, std::fabs(value));\n                    }\n                }\n\n                float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;\n                scales[row] = scale;\n                const float inv_scale = scale > 0.0f ? 
1.0f / scale : 0.0f;\n\n                for (size_t block = 0; block < k_blocks; ++block) {\n                    const block_q8_0 & blk = row_blocks[block];\n                    const float d = GGML_FP16_TO_FP32(blk.d);\n                    for (size_t l = 0; l < QK8_0; ++l) {\n                        const size_t linear_idx = block * QK8_0 + l;\n                        if (linear_idx >= k) {\n                            break;\n                        }\n                        const float value = d * static_cast<float>(blk.qs[l]);\n                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;\n                        q = std::clamp(q, -127, 127);\n                        qdata[row * k + linear_idx] = static_cast<int8_t>(q);\n                    }\n                }\n            }\n        }\n\n        for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {\n            if (!allow_fallback && slot > 0) {\n                break;\n            }\n            ggml_kleidiai_kernels * kernels = kernel_chain[slot];\n            kernel_info * kernel = &kernels->gemm;\n            rhs_packing_info * rhs_info = &kernels->rhs_info;\n            if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {\n                continue;\n            }\n\n            const size_t nr = kernel->get_nr();\n            const size_t kr = kernel->get_kr();\n            const size_t sr = kernel->get_sr();\n            const ggml_type rhs_type = kernels->rhs_type;\n            const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :\n                                     rhs_type == GGML_TYPE_Q4_0 ? 
QK4_0 : 0;\n            if (block_len == 0) {\n                continue;\n            }\n\n            const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);\n            const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n\n            uint8_t * dst_ptr = base_ptr + aligned_cursor;\n\n            if (rhs_type == GGML_TYPE_Q4_0) {\n                struct kai_rhs_pack_qs4cxs1s0_param params;\n                params.lhs_zero_point = 1;\n                params.rhs_zero_point = 8;\n                rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,\n                                       static_cast<const uint8_t *>(data), nullptr, nullptr,\n                                       dst_ptr, 0, &params);\n            } else if (rhs_type == GGML_TYPE_Q8_0) {\n                struct kai_rhs_pack_qsi8cx_params params;\n                params.lhs_zero_point = 1;\n                params.scale_multiplier = 1.0f;\n                rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,\n                                       qdata.data(), nullptr, scales.data(),\n                                       dst_ptr, 0, &params);\n            } else {\n                continue;\n            }\n\n            header->offsets[header->slot_count] = aligned_cursor;\n            header->sizes[header->slot_count]   = packed_size;\n            ++header->slot_count;\n\n            cursor = aligned_cursor + packed_size;\n        }\n\n        if (header->slot_count == 0) {\n            header->magic   = 0;\n            header->version = 0;\n            memcpy(tensor->data, data, data_size);\n        }\n\n        return 0;\n    }\n};\n\nstatic ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {\n    static tensor_traits traits;\n    return &traits;\n}\n}  // namespace ggml::cpu::kleidiai\n\nstatic enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) 
{\n    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);\n\n    return GGML_STATUS_SUCCESS;\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,\n                                                       const void * data, size_t offset, size_t size) {\n    GGML_ASSERT(offset == 0);\n    GGML_ASSERT(size == ggml_nbytes(tensor));\n\n    auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;\n    auto OK            = tensor_traits->repack(tensor, data, size);\n\n    GGML_ASSERT(OK == 0);\n    GGML_UNUSED(buffer);\n}\n\nstatic const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    GGML_UNUSED(buft);\n    return \"CPU_KLEIDIAI\";\n}\n\nstatic ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);\n\n    if (buffer == nullptr) {\n        return nullptr;\n    }\n\n    buffer->buft              = buft;\n    buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;\n    buffer->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;\n    buffer->iface.get_tensor  = nullptr;\n    buffer->iface.cpy_tensor  = nullptr;\n    return buffer;\n}\n\nstatic size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    GGML_UNUSED(buft);\n    return TENSOR_ALIGNMENT;\n}\n\nstatic size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {\n    GGML_UNUSED(buft);\n\n    if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {\n        return ggml_nbytes(tensor);\n    }\n\n    const size_t n = tensor->ne[1];\n    const size_t k = tensor->ne[0];\n\n    size_t cursor = 
sizeof(kleidiai_weight_header);\n    cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n\n    std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;\n    const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;\n    const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)\n                                   : kleidiai_collect_q4_chain(kernel_chain);\n    const bool allow_fallback = kleidiai_pack_fallback_allowed();\n\n    size_t slot_count = 0;\n    for (int slot = 0; slot < slot_total; ++slot) {\n        if (!allow_fallback && slot > 0) {\n            break;\n        }\n        ggml_kleidiai_kernels * kernels = kernel_chain[slot];\n        if (!kernels) {\n            continue;\n        }\n        kernel_info * kernel = &kernels->gemm;\n        rhs_packing_info * rhs_info = &kernels->rhs_info;\n        if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {\n            continue;\n        }\n\n        const ggml_type rhs_type = kernels->rhs_type;\n        const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :\n                                 rhs_type == GGML_TYPE_Q8_0 ? 
QK8_0 : 0;\n        if (block_len == 0) {\n            continue;\n        }\n\n        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);\n        cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);\n        ++slot_count;\n    }\n\n    if (slot_count == 0) {\n        return ggml_nbytes(tensor);\n    }\n\n    return std::max(cursor, ggml_nbytes(tensor));\n}\n\nnamespace ggml::cpu::kleidiai {\nclass extra_buffer_type : ggml::cpu::extra_buffer_type {\n    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {\n        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;\n        const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);\n        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&\n            (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&\n            op->src[0]->buffer &&\n            (ggml_n_dims(op->src[0]) == 2) &&\n            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&\n            slot_total > 0) {\n            if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {\n                return false;\n            }\n            if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {\n                return false;\n            }\n            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {\n                return false;\n            }\n            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&\n                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {\n                return true;\n            }\n        }\n        return false;\n    }\n\n    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {\n        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {\n            if (op->src[0]->buffer && op->src[0]->buffer->buft 
== ggml_backend_cpu_kleidiai_buffer_type()) {\n                return (ggml::cpu::tensor_traits *) op->src[0]->extra;\n            } else {\n                std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;\n                const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);\n                const bool has_kernel = slot_total > 0;\n                if (has_kernel && op->src[1]->ne[1] > 1) {\n                    if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||\n                        (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {\n                        return nullptr;\n                    }\n                    return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);\n                }\n            }\n        }\n        return nullptr;\n    }\n};\n}  // namespace ggml::cpu::kleidiai\n\nggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {\n    static ggml::cpu::kleidiai::extra_buffer_type ctx;\n    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {\n        /* .iface    = */ {\n                           /* .get_name         = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,\n                           /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,\n                           /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,\n                           /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX\n                           /* .get_alloc_size   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,\n                           /* .is_host          = */ nullptr,\n                           },\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),\n        /* .context = */ &ctx,\n    };\n\n    init_kleidiai_context();\n\n    return &ggml_backend_cpu_buffer_type_kleidiai;\n}\n"
  },
  {
    "path": "src/ggml-cpu/kleidiai/kleidiai.h",
    "content": "// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>\n// SPDX-License-Identifier: MIT\n//\n\n#pragma once\n\n#include \"ggml-alloc.h\"\n\n#ifdef  __cplusplus\nextern \"C\" {\n#endif\n\nggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);\n\n#ifdef  __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/llamafile/sgemm.cpp",
    "content": "// Copyright 2024 Mozilla Foundation\n//\n// Permission is hereby granted, free of charge, to any person obtaining\n// a copy of this software and associated documentation files (the\n// \"Software\"), to deal in the Software without restriction, including\n// without limitation the rights to use, copy, modify, merge, publish,\n// distribute, sublicense, and/or sell copies of the Software, and to\n// permit persons to whom the Software is furnished to do so, subject to\n// the following conditions:\n//\n// The above copyright notice and this permission notice shall be\n// included in all copies or substantial portions of the Software.\n//\n// THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\n// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\n// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS\n// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN\n// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN\n// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n// SOFTWARE.\n\n//\n//                   _   _          ___ _      _   ___\n//                  | |_(_)_ _ _  _| _ ) |    /_\\ / __|\n//                  |  _| | ' \\ || | _ \\ |__ / _ \\\\__ \\.\n//                   \\__|_|_||_\\_, |___/____/_/ \\_\\___/\n//                             |__/\n//\n//                    BASIC LINEAR ALGEBRA SUBPROGRAMS\n//\n//\n// This file implements multithreaded CPU matrix multiplication for the\n// common contiguous use case C = Aᵀ * B. These kernels are designed to\n// have excellent performance[1] for matrices that fit in the CPU cache\n// without imposing any overhead such as cache filling or malloc calls.\n//\n// This implementation does not guarantee any upper bound with rounding\n// errors, which grow along with k. 
Our goal's to maximally exploit the\n// hardware for performance, and then use whatever resources remain for\n// improving numerical accuracy.\n//\n// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].\n//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].\n\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Wpedantic\"\n#pragma GCC diagnostic ignored \"-Wignored-attributes\"\n#endif\n\n#include \"sgemm.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"ggml-quants.h\"\n#include \"simd-mappings.h\"\n\n#include <array>\n#include <type_traits>\n\n#ifdef _MSC_VER\n#define NOINLINE __declspec(noinline)\n#else\n#define NOINLINE __attribute__((__noinline__))\n#endif\n\n#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)\n#define VECTOR_REGISTERS 32\n#else\n#define VECTOR_REGISTERS 16\n#endif\n\n#if defined(__riscv_v_intrinsic)\n#define LMUL 4\n#endif\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\n\nnamespace {\n\ninline float unhalf(ggml_fp16_t d) {\n    return GGML_CPU_FP16_TO_FP32(d);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED ARITHMETIC OPERATIONS\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }\ninline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }\ninline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }\n#endif  // __SSE__\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }\ninline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }\ninline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }\n#endif // __AVX__\n\n#if defined(__AVX512F__)\ninline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); 
}\ninline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }\ninline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }\n#endif // __AVX512F__\n\n#if defined(__ARM_NEON)\ninline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }\ninline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }\ninline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }\n#endif // __ARM_NEON\n\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\ninline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }\ninline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }\ninline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }\n#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC\n\n#if defined(__VXE__) || defined(__VXE2__)\ninline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }\ninline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }\ninline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }\n#endif\n\n#if defined(__MMA__)\ntypedef vector unsigned char vec_t;\ntypedef __vector_quad acc_t;\n#endif\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED FUSED MULTIPLY ADD\n\n/**\n * Computes a * b + c.\n */\ntemplate <typename T, typename U>\ninline U madd(T a, T b, U c) {\n    return add(mul(a, b), c);\n}\n\n#if defined(__FMA__)\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m256 madd(__m256 a, __m256 b, __m256 c) {\n    return _mm256_fmadd_ps(a, b, c);\n}\n#endif\n#if defined(__AVX512F__)\ntemplate <>\ninline __m512 madd(__m512 a, __m512 b, __m512 c) {\n    return _mm512_fmadd_ps(a, b, c);\n}\n#endif\n#if defined(__AVX512BF16__)\ntemplate <>\ninline __m512 madd(__m512bh a, __m512bh b, __m512 c) {\n    return _mm512_dpbf16_ps(c, a, b);\n}\ntemplate <>\ninline __m256 
madd(__m256bh a, __m256bh b, __m256 c) {\n    return _mm256_dpbf16_ps(c, a, b);\n}\n#endif\n#endif\n\n#if defined(__ARM_FEATURE_FMA)\ntemplate <>\ninline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {\n    return vfmaq_f32(c, b, a);\n}\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\ntemplate <>\ninline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {\n    return vfmaq_f16(c, b, a);\n}\n#endif\n#endif\n\n#if defined(__VXE__) || defined(__VXE2__)\ntemplate <>\ninline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {\n    return vec_madd(a, b, c);\n}\n#endif\n\n#if defined(__riscv_zvfh)\ntemplate <>\ninline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {\n    return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());\n}\ninline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {\n    return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());\n}\ninline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {\n    return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());\n}\ninline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {\n    return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());\n}\ninline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {\n    return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());\n}\ninline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {\n    return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());\n}\ninline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {\n    return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());\n}\ninline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {\n    return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());\n}\n#endif\n\n#if defined(__riscv_zvfbfwma)\ninline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t 
b, vfloat32m1_t c) {\n    return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());\n}\ninline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {\n    return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());\n}\ninline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {\n    return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());\n}\n#endif\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED HORIZONTAL SUM\n\n#if defined(__ARM_NEON)\ninline float hsum(float32x4_t x) {\n    return vaddvq_f32(x);\n}\n#endif // __ARM_NEON\n\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\ninline float hsum(float16x8_t x) {\n    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),\n                                vcvt_f32_f16(vget_high_f16(x))));\n}\n#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC\n\n#if defined(__VXE__) || defined(__VXE2__)\ninline float hsum(float32x4_t x) {\n    float32x4_t tmp = x + vec_reve(x);\n    return tmp[0] + tmp[1];\n}\n#endif\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline float hsum(__m128 x) {\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\n    x = _mm_add_ps(x, _mm_movehl_ps(x, x));\n    x = _mm_add_ss(x, _mm_movehdup_ps(x));\n#else\n    __m128 t;\n    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));\n    x = _mm_add_ps(x, t);\n    t = _mm_movehl_ps(t, x);\n    x = _mm_add_ss(x, t);\n#endif\n    return _mm_cvtss_f32(x);\n}\n#endif\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline float hsum(__m256 x) {\n    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),\n                           _mm256_castps256_ps128(x)));\n}\n#endif // __AVX__\n\n#if defined(__AVX512F__)\ninline float hsum(__m512 x) {\n    return _mm512_reduce_add_ps(x);\n}\n#endif // __AVX512F__\n\n#if 
defined(__riscv_zvfh)\ninline float hsum(vfloat32m1_t x) {\n    return __riscv_vfmv_f_s_f32m1_f32(\n        __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));\n}\ninline float hsum(vfloat32m2_t x) {\n    return __riscv_vfmv_f_s_f32m1_f32(\n        __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));\n}\ninline float hsum(vfloat32m4_t x) {\n    return __riscv_vfmv_f_s_f32m1_f32(\n        __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));\n}\ninline float hsum(vfloat32m8_t x) {\n    return __riscv_vfmv_f_s_f32m1_f32(\n        __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));\n}\n#endif\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED MEMORY LOADING\n\ntemplate <typename T, typename U> T load(const U *);\n\n#if defined(__ARM_NEON)\ntemplate <> inline float32x4_t load(const float *p) {\n    return vld1q_f32(p);\n}\n#if !defined(_MSC_VER)\n// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC\ntemplate <> inline float16x8_t load(const ggml_fp16_t *p) {\n    return vld1q_f16((const float16_t *)p);\n}\ntemplate <> inline float32x4_t load(const ggml_fp16_t *p) {\n    return vcvt_f32_f16(vld1_f16((const float16_t *)p));\n}\n#endif // _MSC_VER\n#endif // __ARM_NEON\n\n#if defined(__VXE__) || defined(__VXE2__)\ntemplate <> inline float32x4_t load(const ggml_fp16_t * p) {\n    float tmp[4];\n\n    for (int i = 0; i < 4; i++) {\n        tmp[i] = GGML_CPU_FP16_TO_FP32(p[i]);\n    }\n\n    return vec_xl(0, (const float *)(tmp));\n}\ntemplate <> inline float32x4_t load(const float * p) {\n    return vec_xl(0, p);\n}\n#endif\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <> inline __m128 load(const float *p) {\n    return _mm_loadu_ps(p);\n}\n#endif  // __SSE__\n\n#if 
defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <> inline __m256 load(const float *p) {\n    return _mm256_loadu_ps(p);\n}\n#endif // __AVX__\n\n#if defined(__AVX2__) || defined(__AVX512F__)\ntemplate <> inline __m256 load(const ggml_bf16_t *p) {\n    return _mm256_castsi256_ps(\n        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));\n}\n#endif // __AVX2__\n\n#if defined(__F16C__)\ntemplate <> inline __m256 load(const ggml_fp16_t *p) {\n    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));\n}\n#endif // __F16C__\n\n#if defined(__AVX512F__)\ntemplate <> inline __m512 load(const float *p) {\n    return _mm512_loadu_ps(p);\n}\ntemplate <> inline __m512 load(const ggml_fp16_t *p) {\n    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));\n}\ntemplate <> inline __m512 load(const ggml_bf16_t *p) {\n    return _mm512_castsi512_ps(\n        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));\n}\n#endif // __AVX512F__\n\n#if defined(__AVX512BF16__)\ntemplate <> inline __m512bh load(const ggml_bf16_t *p) {\n    return (__m512bh)_mm512_loadu_ps((const float *)p);\n}\ntemplate <> inline __m256bh load(const ggml_bf16_t *p) {\n    return (__m256bh)_mm256_loadu_ps((const float *)p);\n}\ntemplate <> inline __m512bh load(const float *p) {\n    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));\n}\ntemplate <> inline __m256bh load(const float *p) {\n    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));\n}\n#endif\n\n#if defined(__riscv_zvfh)\ntemplate <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {\n    return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());\n}\ntemplate <> inline vfloat16m1_t load(const ggml_fp16_t *p) {\n    return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());\n}\ntemplate <> inline vfloat16m2_t load(const ggml_fp16_t *p) {\n    return 
__riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());\n}\ntemplate <> inline vfloat16m4_t load(const ggml_fp16_t *p) {\n    return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());\n}\ntemplate <> inline vfloat32m1_t load(const float *p) {\n    return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());\n}\ntemplate <> inline vfloat32m2_t load(const float *p) {\n    return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());\n}\ntemplate <> inline vfloat32m4_t load(const float *p) {\n    return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());\n}\ntemplate <> inline vfloat32m8_t load(const float *p) {\n    return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());\n}\n#endif\n\n#if defined(__riscv_zvfbfwma)\ntemplate <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {\n    return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());\n}\ntemplate <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {\n    return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());\n}\ntemplate <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {\n    return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());\n}\n#endif\n\n#if defined(__riscv_zvfh)\ntemplate <typename T> T set_zero();\n\ntemplate <> inline vfloat16mf2_t set_zero() {\n    return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());\n}\ntemplate <> inline vfloat16m1_t set_zero() {\n    return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());\n}\ntemplate <> inline vfloat16m2_t set_zero() {\n    return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());\n}\ntemplate <> inline vfloat16m4_t set_zero() {\n    return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());\n}\ntemplate <> inline vfloat32m1_t set_zero() {\n    return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());\n}\ntemplate <> inline vfloat32m2_t set_zero() {\n   
 return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());\n}\ntemplate <> inline vfloat32m4_t set_zero() {\n    return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());\n}\ntemplate <> inline vfloat32m8_t set_zero() {\n    return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());\n}\n#endif\n\n#if defined(__riscv_v_intrinsic)\ntemplate <typename T> size_t vlmax() {\n    if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return  __riscv_vsetvlmax_e16mf2(); }\n    else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return  __riscv_vsetvlmax_e16m1(); }\n    else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return  __riscv_vsetvlmax_e16m2(); }\n    else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return  __riscv_vsetvlmax_e16m4(); }\n    else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return  __riscv_vsetvlmax_e32m1(); }\n    else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return  __riscv_vsetvlmax_e32m2(); }\n    else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return  __riscv_vsetvlmax_e32m4(); }\n    else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return  __riscv_vsetvlmax_e32m8(); }\n    return 0;\n}\n#endif\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// FLOATING POINT MATRIX MULTIPLICATION\n\ntemplate <int M>\nstatic inline int64_t BLOCK_SIZE(size_t m) {\n    const int64_t NB_BLOC_M = (m + M - 1) / M;\n    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;\n}\n\nstatic constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {\n    return ib < ibN ? 
ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);\n}\n\ntemplate <int KN, typename D, typename V, typename TA, typename TB, typename TC>\nclass tinyBLAS {\n  public:\n    tinyBLAS(const ggml_compute_params * params, int64_t k,\n             const TA *A, int64_t lda,\n             const TB *B, int64_t ldb,\n             TC *C, int64_t ldc)\n        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {\n    }\n\n    bool matmul(int64_t m, int64_t n) {\n        if (k % KN != 0)\n            return false;\n        // compute RM for only need tile with size RM&RM-1\n#if VECTOR_REGISTERS == 32\n        if (m % 16 == 0 && (m/16 >= params->nth)) {\n            const int64_t SIZE_N = BLOCK_SIZE<6>(n);\n            mnpack<4, 6, 4>(m, n, SIZE_N, 12);\n            return true;\n        }\n        if (m % 8 == 0 ) {\n            const int64_t SIZE_N = BLOCK_SIZE<6>(n);\n            mnpack<4, 6, 2>(m, n, SIZE_N, 12);\n            return true;\n        }\n        if (m % 4 == 0) {\n            const int64_t SIZE_N = BLOCK_SIZE<6>(n);\n            mnpack<4, 6, 1>(m, n, SIZE_N, 12);\n            return true;\n        }\n#else  // VECTOR_REGISTERS == 16\n        if (m % 16 == 0 && (m/16 >= params->nth)) {\n            const int64_t SIZE_N = BLOCK_SIZE<3>(n);\n            mnpack<4, 3, 4>(m, n, SIZE_N, 24);\n            return true;\n        }\n        if (m % 8 == 0 ) {\n            const int64_t SIZE_N = BLOCK_SIZE<3>(n);\n            mnpack<4, 3, 2>(m, n, SIZE_N, 24);\n            return true;\n        }\n        if (m % 4 == 0) {\n            const int64_t SIZE_N = BLOCK_SIZE<3>(n);\n            mnpack<4, 3, 1>(m, n, SIZE_N, 24);\n            return true;\n        }\n#endif\n        return false;\n    }\n\n  private:\n    template <int RM, int RN, int BM>\n    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {\n        if (SIZE_N == RN) {\n            return gemm<RM, RN, BM>(m, n, BN);\n        }\n        if constexpr (RN > 1) 
{\n            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);\n        } else {\n            GGML_LOG_ERROR(\"mnpack<%d, %d> block size not supported\\n\", RM, (int)SIZE_N);\n            GGML_ASSERT(false); // we have miss something.\n        }\n    }\n\n    template <int RM, int RN>\n    inline void gemm_bloc(int64_t ii, int64_t jj) {\n        D Cv[RN][RM] = {};\n        for (int64_t l = 0; l < k; l += KN) {\n            // help compiler for op order.\n            if constexpr (RM <= RN) {\n                V Av[RM];\n                for (int64_t i = 0; i < RM; ++i) {\n                    Av[i] = load<V>(A + lda * (ii + i) + l);\n                }\n                for (int64_t j = 0; j < RN; ++j) {\n                    V Bv = load<V>(B + ldb * (jj + j) + l);\n                    for (int64_t i = 0; i < RM; ++i) {\n                        Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);\n                    }\n                }\n            } else {\n                V Bv[RN];\n                for (int64_t j = 0; j < RN; ++j) {\n                    Bv[j] = load<V>(B + ldb * (jj + j) + l);\n                }\n                for (int64_t i = 0; i < RM; ++i) {\n                    V Av = load<V>(A + lda * (ii + i) + l);\n                    for (int64_t j = 0; j < RN; ++j) {\n                        Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);\n                    }\n                }\n            }\n        }\n        for (int64_t j = 0; j < RN; ++j)\n            for (int64_t i = 0; i < RM; ++i)\n                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);\n    }\n\n    template <int RM, int RN, int BM>\n    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {\n        GGML_ASSERT(m % (RM * BM) == 0);\n        const int64_t ytiles = m / (RM * BM);\n        const int64_t xtiles = (n + RN -1) / RN;\n        const int64_t jj_RN = (xtiles - (xtiles * RN - n));\n\n        // \"round\" bloc_size to \"nearest\" BN\n        const int64_t NB_BN = xtiles < BN ? 
1 : (xtiles + BN / 2) / BN;\n        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;\n        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));\n        const int64_t nb_job = ytiles * NB_BN;\n\n        if (params->ith == 0) {\n            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);\n            // Every thread starts at ith, so the first unprocessed chunk is nth.  This saves a bit of coordination right at the start.\n            ggml_threadpool_chunk_set(params->threadpool, params->nth);\n        }\n\n        ggml_barrier(params->threadpool);\n\n        int64_t job = params->ith;\n        while (job < nb_job) {\n            const int64_t ii = (job % ytiles) * RM * BM;\n            const int64_t jb =  job / ytiles;\n            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);\n            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);\n\n            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);\n            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);\n            const int64_t jj1 = jj2 < jj_RN * RN ? 
jj2 : jj_RN * RN;\n\n            for (int64_t bi = 0; bi < BM * RM; bi += RM) {\n                int64_t jj = jj0;\n                for (; jj < jj1; jj += RN) {\n                    gemm_bloc<RM, RN>(ii + bi, jj);\n                }\n                if constexpr (RN > 1) {\n                    for (; jj < jj2; jj += RN - 1) {\n                        gemm_bloc<RM, RN-1>(ii + bi, jj);\n                    }\n                }\n                GGML_ASSERT(jj == jj2);\n            }\n\n            job = ggml_threadpool_chunk_add(params->threadpool, 1);\n        }\n\n        ggml_barrier(params->threadpool);\n        return;\n    }\n\n    const ggml_compute_params * params;\n    const TA *const A;\n    const TB *const B;\n    TC *const C;\n    const int64_t k;\n    const int64_t lda;\n    const int64_t ldb;\n    const int64_t ldc;\n};\n\n#if defined(__riscv_v_intrinsic)\ntemplate <typename D, typename V, typename TA, typename TB, typename TC>\nclass tinyBLAS_RVV {\n  public:\n    tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,\n             const TA *A, int64_t lda,\n             const TB *B, int64_t ldb,\n             TC *C, int64_t ldc)\n        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {\n    }\n\n    bool matmul(int64_t m, int64_t n) {\n        if (k % vlmax<V>() != 0) {\n            return false;\n        }\n\n#if LMUL == 1\n        if (m % 16 == 0 && (m/16 >= params->nth)) {\n            const int64_t SIZE_N = BLOCK_SIZE<6>(n);\n            mnpack<4, 6, 4>(m, n, SIZE_N, 12);\n            return true;\n        }\n        if (m % 8 == 0 ) {\n            const int64_t SIZE_N = BLOCK_SIZE<6>(n);\n            mnpack<4, 6, 2>(m, n, SIZE_N, 12);\n            return true;\n        }\n        if (m % 4 == 0) {\n            const int64_t SIZE_N = BLOCK_SIZE<6>(n);\n            mnpack<4, 6, 1>(m, n, SIZE_N, 12);\n            return true;\n        }\n#elif LMUL == 2\n        if (m % 16 == 0 && (m/16 >= params->nth)) {\n            
const int64_t SIZE_N = BLOCK_SIZE<3>(n);\n            mnpack<4, 3, 4>(m, n, SIZE_N, 24);\n            return true;\n        }\n        if (m % 8 == 0 ) {\n            const int64_t SIZE_N = BLOCK_SIZE<3>(n);\n            mnpack<4, 3, 2>(m, n, SIZE_N, 24);\n            return true;\n        }\n        if (m % 4 == 0) {\n            const int64_t SIZE_N = BLOCK_SIZE<3>(n);\n            mnpack<4, 3, 1>(m, n, SIZE_N, 24);\n            return true;\n        }\n#else // LMUL = 4\n        if (m % 16 == 0 && (m/16 >= params->nth)) {\n            const int64_t SIZE_N = BLOCK_SIZE<2>(n);\n            mnpack<2, 2, 8>(m, n, SIZE_N, 36);\n            return true;\n        }\n        if (m % 8 == 0 ) {\n            const int64_t SIZE_N = BLOCK_SIZE<2>(n);\n            mnpack<2, 2, 4>(m, n, SIZE_N, 36);\n            return true;\n        }\n        if (m % 4 == 0) {\n            const int64_t SIZE_N = BLOCK_SIZE<2>(n);\n            mnpack<2, 2, 2>(m, n, SIZE_N, 36);\n            return true;\n        }\n#endif\n        return false;\n    }\n\n  private:\n    template<int RM, int RN, int BM>\n    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {\n        if (SIZE_N == RN) {\n            return gemm<RM, RN, BM>(m, n, BN);\n        }\n        if constexpr (RN > 1) {\n            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);\n        } else {\n            GGML_LOG_ERROR(\"mnpack<%d, %d> block size not supported\\n\", RM, (int)SIZE_N);\n            GGML_ASSERT(false); // we have miss something.\n        }\n    }\n\n    inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {\n        size_t vl = vlmax<V>();\n        D Cv00 = set_zero<D>();\n        D Cv01 = set_zero<D>();\n        D Cv02 = set_zero<D>();\n        D Cv03 = set_zero<D>();\n        D Cv10 = set_zero<D>();\n        D Cv11 = set_zero<D>();\n        D Cv12 = set_zero<D>();\n        D Cv13 = set_zero<D>();\n        D Cv20 = set_zero<D>();\n        D Cv21 = set_zero<D>();\n        D Cv22 = set_zero<D>();\n    
    D Cv23 = set_zero<D>();\n        D Cv30 = set_zero<D>();\n        D Cv31 = set_zero<D>();\n        D Cv32 = set_zero<D>();\n        D Cv33 = set_zero<D>();\n        D Cv40 = set_zero<D>();\n        D Cv41 = set_zero<D>();\n        D Cv42 = set_zero<D>();\n        D Cv43 = set_zero<D>();\n        D Cv50 = set_zero<D>();\n        D Cv51 = set_zero<D>();\n        D Cv52 = set_zero<D>();\n        D Cv53 = set_zero<D>();\n\n        for (int64_t l = 0; l < k; l += vl) {\n            V Bv0 = load<V>(B + ldb * (jj + 0) + l);\n            V Bv1 = load<V>(B + ldb * (jj + 1) + l);\n            V Bv2 = load<V>(B + ldb * (jj + 2) + l);\n            V Bv3 = load<V>(B + ldb * (jj + 3) + l);\n            V Bv4 = load<V>(B + ldb * (jj + 4) + l);\n            V Bv5 = load<V>(B + ldb * (jj + 5) + l);\n\n            V Av0 = load<V>(A + lda * (ii + 0) + l);\n            Cv00 = madd(Av0, Bv0, Cv00);\n            Cv10 = madd(Av0, Bv1, Cv10);\n            Cv20 = madd(Av0, Bv2, Cv20);\n            Cv30 = madd(Av0, Bv3, Cv30);\n            Cv40 = madd(Av0, Bv4, Cv40);\n            Cv50 = madd(Av0, Bv5, Cv50);\n\n            V Av1 = load<V>(A + lda * (ii + 1) + l);\n            Cv01 = madd(Av1, Bv0, Cv01);\n            Cv11 = madd(Av1, Bv1, Cv11);\n            Cv21 = madd(Av1, Bv2, Cv21);\n            Cv31 = madd(Av1, Bv3, Cv31);\n            Cv41 = madd(Av1, Bv4, Cv41);\n            Cv51 = madd(Av1, Bv5, Cv51);\n\n            V Av2 = load<V>(A + lda * (ii + 2) + l);\n            Cv02 = madd(Av2, Bv0, Cv02);\n            Cv12 = madd(Av2, Bv1, Cv12);\n            Cv22 = madd(Av2, Bv2, Cv22);\n            Cv32 = madd(Av2, Bv3, Cv32);\n            Cv42 = madd(Av2, Bv4, Cv42);\n            Cv52 = madd(Av2, Bv5, Cv52);\n\n            V Av3 = load<V>(A + lda * (ii + 3) + l);\n            Cv03 = madd(Av3, Bv0, Cv03);\n            Cv13 = madd(Av3, Bv1, Cv13);\n            Cv23 = madd(Av3, Bv2, Cv23);\n            Cv33 = madd(Av3, Bv3, Cv33);\n            Cv43 = madd(Av3, Bv4, Cv43);\n            
Cv53 = madd(Av3, Bv5, Cv53);\n        }\n\n        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);\n        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);\n        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);\n        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);\n        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);\n        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);\n        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);\n        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);\n        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);\n        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);\n        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);\n        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);\n        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);\n        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);\n        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);\n        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);\n        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);\n        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);\n        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);\n        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);\n        C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);\n        C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);\n        C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);\n        C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);\n    }\n\n    inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {\n        size_t vl = vlmax<V>();\n        D Cv00 = set_zero<D>();\n        D Cv01 = set_zero<D>();\n        D Cv02 = set_zero<D>();\n        D Cv03 = set_zero<D>();\n        D Cv10 = set_zero<D>();\n        D Cv11 = set_zero<D>();\n        D Cv12 = set_zero<D>();\n        D Cv13 = set_zero<D>();\n        D Cv20 = set_zero<D>();\n        D Cv21 = set_zero<D>();\n        D Cv22 = set_zero<D>();\n        D Cv23 = set_zero<D>();\n        D Cv30 = set_zero<D>();\n        D Cv31 = set_zero<D>();\n        D Cv32 = set_zero<D>();\n        D Cv33 = set_zero<D>();\n        D Cv40 = set_zero<D>();\n        D Cv41 = set_zero<D>();\n        D Cv42 
= set_zero<D>();\n        D Cv43 = set_zero<D>();\n\n        for (int64_t l = 0; l < k; l += vl) {\n            V Bv0 = load<V>(B + ldb * (jj + 0) + l);\n            V Bv1 = load<V>(B + ldb * (jj + 1) + l);\n            V Bv2 = load<V>(B + ldb * (jj + 2) + l);\n            V Bv3 = load<V>(B + ldb * (jj + 3) + l);\n            V Bv4 = load<V>(B + ldb * (jj + 4) + l);\n\n            V Av0 = load<V>(A + lda * (ii + 0) + l);\n            Cv00 = madd(Av0, Bv0, Cv00);\n            Cv10 = madd(Av0, Bv1, Cv10);\n            Cv20 = madd(Av0, Bv2, Cv20);\n            Cv30 = madd(Av0, Bv3, Cv30);\n            Cv40 = madd(Av0, Bv4, Cv40);\n\n            V Av1 = load<V>(A + lda * (ii + 1) + l);\n            Cv01 = madd(Av1, Bv0, Cv01);\n            Cv11 = madd(Av1, Bv1, Cv11);\n            Cv21 = madd(Av1, Bv2, Cv21);\n            Cv31 = madd(Av1, Bv3, Cv31);\n            Cv41 = madd(Av1, Bv4, Cv41);\n\n            V Av2 = load<V>(A + lda * (ii + 2) + l);\n            Cv02 = madd(Av2, Bv0, Cv02);\n            Cv12 = madd(Av2, Bv1, Cv12);\n            Cv22 = madd(Av2, Bv2, Cv22);\n            Cv32 = madd(Av2, Bv3, Cv32);\n            Cv42 = madd(Av2, Bv4, Cv42);\n\n            V Av3 = load<V>(A + lda * (ii + 3) + l);\n            Cv03 = madd(Av3, Bv0, Cv03);\n            Cv13 = madd(Av3, Bv1, Cv13);\n            Cv23 = madd(Av3, Bv2, Cv23);\n            Cv33 = madd(Av3, Bv3, Cv33);\n            Cv43 = madd(Av3, Bv4, Cv43);\n        }\n\n        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);\n        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);\n        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);\n        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);\n        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);\n        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);\n        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);\n        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);\n        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);\n        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);\n        C[ldc * (jj + 2) + (ii + 
2)] = hsum(Cv22);\n        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);\n        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);\n        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);\n        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);\n        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);\n        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);\n        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);\n        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);\n        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);\n    }\n\n    inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {\n        size_t vl = vlmax<V>();\n        D Cv00 = set_zero<D>();\n        D Cv01 = set_zero<D>();\n        D Cv02 = set_zero<D>();\n        D Cv03 = set_zero<D>();\n        D Cv10 = set_zero<D>();\n        D Cv11 = set_zero<D>();\n        D Cv12 = set_zero<D>();\n        D Cv13 = set_zero<D>();\n        D Cv20 = set_zero<D>();\n        D Cv21 = set_zero<D>();\n        D Cv22 = set_zero<D>();\n        D Cv23 = set_zero<D>();\n        D Cv30 = set_zero<D>();\n        D Cv31 = set_zero<D>();\n        D Cv32 = set_zero<D>();\n        D Cv33 = set_zero<D>();\n\n        for (int64_t l = 0; l < k; l += vl) {\n            V Av0 = load<V>(A + lda * (ii + 0) + l);\n            V Av1 = load<V>(A + lda * (ii + 1) + l);\n            V Av2 = load<V>(A + lda * (ii + 2) + l);\n            V Av3 = load<V>(A + lda * (ii + 3) + l);\n\n            V Bv0 = load<V>(B + ldb * (jj + 0) + l);\n            Cv00 = madd(Av0, Bv0, Cv00);\n            Cv01 = madd(Av1, Bv0, Cv01);\n            Cv02 = madd(Av2, Bv0, Cv02);\n            Cv03 = madd(Av3, Bv0, Cv03);\n\n            V Bv1 = load<V>(B + ldb * (jj + 1) + l);\n            Cv10 = madd(Av0, Bv1, Cv10);\n            Cv11 = madd(Av1, Bv1, Cv11);\n            Cv12 = madd(Av2, Bv1, Cv12);\n            Cv13 = madd(Av3, Bv1, Cv13);\n\n            V Bv2 = load<V>(B + ldb * (jj + 2) + l);\n            Cv20 = madd(Av0, Bv2, Cv20);\n            Cv21 = madd(Av1, Bv2, Cv21);\n            Cv22 = madd(Av2, 
Bv2, Cv22);\n            Cv23 = madd(Av3, Bv2, Cv23);\n\n            V Bv3 = load<V>(B + ldb * (jj + 3) + l);\n            Cv30 = madd(Av0, Bv3, Cv30);\n            Cv31 = madd(Av1, Bv3, Cv31);\n            Cv32 = madd(Av2, Bv3, Cv32);\n            Cv33 = madd(Av3, Bv3, Cv33);\n        }\n\n        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);\n        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);\n        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);\n        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);\n        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);\n        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);\n        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);\n        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);\n        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);\n        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);\n        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);\n        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);\n        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);\n        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);\n        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);\n        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);\n    }\n\n    inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {\n        size_t vl = vlmax<V>();\n        D Cv00 = set_zero<D>();\n        D Cv01 = set_zero<D>();\n        D Cv02 = set_zero<D>();\n        D Cv03 = set_zero<D>();\n        D Cv10 = set_zero<D>();\n        D Cv11 = set_zero<D>();\n        D Cv12 = set_zero<D>();\n        D Cv13 = set_zero<D>();\n        D Cv20 = set_zero<D>();\n        D Cv21 = set_zero<D>();\n        D Cv22 = set_zero<D>();\n        D Cv23 = set_zero<D>();\n\n        for (int64_t l = 0; l < k; l += vl) {\n            V Av0 = load<V>(A + lda * (ii + 0) + l);\n            V Av1 = load<V>(A + lda * (ii + 1) + l);\n            V Av2 = load<V>(A + lda * (ii + 2) + l);\n            V Av3 = load<V>(A + lda * (ii + 3) + l);\n\n            V Bv0 = load<V>(B + ldb * (jj + 0) + l);\n            Cv00 = madd(Av0, Bv0, Cv00);\n            Cv01 
= madd(Av1, Bv0, Cv01);\n            Cv02 = madd(Av2, Bv0, Cv02);\n            Cv03 = madd(Av3, Bv0, Cv03);\n\n            V Bv1 = load<V>(B + ldb * (jj + 1) + l);\n            Cv10 = madd(Av0, Bv1, Cv10);\n            Cv11 = madd(Av1, Bv1, Cv11);\n            Cv12 = madd(Av2, Bv1, Cv12);\n            Cv13 = madd(Av3, Bv1, Cv13);\n\n            V Bv2 = load<V>(B + ldb * (jj + 2) + l);\n            Cv20 = madd(Av0, Bv2, Cv20);\n            Cv21 = madd(Av1, Bv2, Cv21);\n            Cv22 = madd(Av2, Bv2, Cv22);\n            Cv23 = madd(Av3, Bv2, Cv23);\n        }\n\n        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);\n        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);\n        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);\n        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);\n        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);\n        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);\n        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);\n        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);\n        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);\n        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);\n        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);\n        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);\n    }\n\n    inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {\n        size_t vl = vlmax<V>();\n        D Cv00 = set_zero<D>();\n        D Cv01 = set_zero<D>();\n        D Cv02 = set_zero<D>();\n        D Cv03 = set_zero<D>();\n        D Cv10 = set_zero<D>();\n        D Cv11 = set_zero<D>();\n        D Cv12 = set_zero<D>();\n        D Cv13 = set_zero<D>();\n\n        for (int64_t l = 0; l < k; l += vl) {\n            V Av0 = load<V>(A + lda * (ii + 0) + l);\n            V Av1 = load<V>(A + lda * (ii + 1) + l);\n            V Av2 = load<V>(A + lda * (ii + 2) + l);\n            V Av3 = load<V>(A + lda * (ii + 3) + l);\n\n            V Bv0 = load<V>(B + ldb * (jj + 0) + l);\n            Cv00 = madd(Av0, Bv0, Cv00);\n            Cv01 = madd(Av1, Bv0, Cv01);\n            Cv02 = madd(Av2, Bv0, 
Cv02);\n            Cv03 = madd(Av3, Bv0, Cv03);\n\n            V Bv1 = load<V>(B + ldb * (jj + 1) + l);\n            Cv10 = madd(Av0, Bv1, Cv10);\n            Cv11 = madd(Av1, Bv1, Cv11);\n            Cv12 = madd(Av2, Bv1, Cv12);\n            Cv13 = madd(Av3, Bv1, Cv13);\n        }\n\n        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);\n        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);\n        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);\n        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);\n        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);\n        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);\n        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);\n        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);\n    }\n\n    inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {\n        size_t vl = vlmax<V>();\n        D Cv00 = set_zero<D>();\n        D Cv01 = set_zero<D>();\n        D Cv02 = set_zero<D>();\n        D Cv03 = set_zero<D>();\n\n        for (int64_t l = 0; l < k; l += vl) {\n            V Av0 = load<V>(A + lda * (ii + 0) + l);\n            V Av1 = load<V>(A + lda * (ii + 1) + l);\n            V Av2 = load<V>(A + lda * (ii + 2) + l);\n            V Av3 = load<V>(A + lda * (ii + 3) + l);\n\n            V Bv0 = load<V>(B + ldb * (jj + 0) + l);\n            Cv00 = madd(Av0, Bv0, Cv00);\n            Cv01 = madd(Av1, Bv0, Cv01);\n            Cv02 = madd(Av2, Bv0, Cv02);\n            Cv03 = madd(Av3, Bv0, Cv03);\n        }\n\n        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);\n        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);\n        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);\n        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);\n    }\n\n    inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {\n        size_t vl = vlmax<V>();\n        D Cv00 = set_zero<D>();\n        D Cv01 = set_zero<D>();\n        D Cv10 = set_zero<D>();\n        D Cv11 = set_zero<D>();\n\n        for (int64_t l = 0; l < k; l += vl) {\n            V Av0 = load<V>(A + lda * (ii + 0) + l);\n            V Av1 = 
load<V>(A + lda * (ii + 1) + l);\n\n            V Bv0 = load<V>(B + ldb * (jj + 0) + l);\n            Cv00 = madd(Av0, Bv0, Cv00);\n            Cv01 = madd(Av1, Bv0, Cv01);\n\n            V Bv1 = load<V>(B + ldb * (jj + 1) + l);\n            Cv10 = madd(Av0, Bv1, Cv10);\n            Cv11 = madd(Av1, Bv1, Cv11);\n        }\n\n        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);\n        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);\n        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);\n        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);\n    }\n\n    inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {\n        size_t vl = vlmax<V>();\n        D Cv00 = set_zero<D>();\n        D Cv01 = set_zero<D>();\n\n        for (int64_t l = 0; l < k; l += vl) {\n            V Av0 = load<V>(A + lda * (ii + 0) + l);\n            V Av1 = load<V>(A + lda * (ii + 1) + l);\n\n            V Bv0 = load<V>(B + ldb * (jj + 0) + l);\n            Cv00 = madd(Av0, Bv0, Cv00);\n            Cv01 = madd(Av1, Bv0, Cv01);\n        }\n\n        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);\n        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);\n    }\n\n    template <int RM, int RN>\n    inline void gemm_bloc(int64_t ii, int64_t jj) {\n        if constexpr (RM == 4) {\n            if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }\n            if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }\n            if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }\n            if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }\n            if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }\n            if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }\n        } else if constexpr (RM == 2) {\n            if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }\n            if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }\n        }\n    }\n\n    template <int RM, int RN, int BM>\n    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {\n        GGML_ASSERT(m % (RM * BM) == 
0);\n        const int64_t ytiles = m / (RM * BM);\n        const int64_t xtiles = (n + RN -1) / RN;\n        const int64_t jj_RN = (xtiles - (xtiles * RN - n));\n\n        // \"round\" bloc_size to \"nearest\" BN\n        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;\n        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;\n        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));\n        const int64_t nb_job = ytiles * NB_BN;\n\n        if (params->ith == 0) {\n            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);\n            // Every thread starts at ith, so the first unprocessed chunk is nth.  This saves a bit of coordination right at the start.\n            ggml_threadpool_chunk_set(params->threadpool, params->nth);\n        }\n\n        ggml_barrier(params->threadpool);\n\n        int64_t job = params->ith;\n        while (job < nb_job) {\n            const int64_t ii = (job % ytiles) * RM * BM;\n            const int64_t jb =  job / ytiles;\n            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);\n            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);\n\n            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);\n            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);\n            const int64_t jj1 = jj2 < jj_RN * RN ? 
jj2 : jj_RN * RN;\n\n            for (int64_t bi = 0; bi < BM * RM; bi += RM) {\n                int64_t jj = jj0;\n                for (; jj < jj1; jj += RN) {\n                    gemm_bloc<RM, RN>(ii + bi, jj);\n                }\n                if constexpr (RN > 1) {\n                    for (; jj < jj2; jj += RN - 1) {\n                        gemm_bloc<RM, RN-1>(ii + bi, jj);\n                    }\n                }\n                GGML_ASSERT(jj == jj2);\n            }\n\n            job = ggml_threadpool_chunk_add(params->threadpool, 1);\n        }\n\n        ggml_barrier(params->threadpool);\n        return;\n    }\n\n    const ggml_compute_params * params;\n    const TA *const A;\n    const TB *const B;\n    TC *const C;\n    const int64_t k;\n    const int64_t lda;\n    const int64_t ldb;\n    const int64_t ldc;\n};\n#endif\n\n//////////////////////////////////////////////////////////////////////////////////////////\n// QUANT ZERO MATRIX MULTIPLICATION\n\n#if defined(__ARM_FEATURE_DOTPROD)\ntemplate <typename TA>\nclass tinyBLAS_Q0_ARM {\n  public:\n    tinyBLAS_Q0_ARM(int64_t k,\n                    const TA *A, int64_t lda,\n                    const block_q8_0 *B, int64_t ldb,\n                    float *C, int64_t ldc,\n                    int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(int64_t m, int64_t n) {\n        mnpack(0, m, 0, n);\n    }\n\n  private:\n    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t mc, nc, mp, np;\n        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {\n        case 0x33:\n            mc = 3;\n            nc = 3;\n            gemm<3, 3>(m0, m, n0, n);\n            break;\n        case 0x32:\n            mc = 3;\n            nc = 2;\n            gemm<3, 2>(m0, m, n0, n);\n            break;\n        case 0x23:\n            mc = 2;\n            nc = 3;\n            gemm<2, 3>(m0, m, n0, n);\n     
       break;\n        case 0x22:\n            mc = 2;\n            nc = 2;\n            gemm<2, 2>(m0, m, n0, n);\n            break;\n        case 0x31:\n            mc = 3;\n            nc = 1;\n            gemm<3, 1>(m0, m, n0, n);\n            break;\n        case 0x13:\n            mc = 1;\n            nc = 3;\n            gemm<1, 3>(m0, m, n0, n);\n            break;\n        case 0x21:\n            mc = 2;\n            nc = 1;\n            gemm<2, 1>(m0, m, n0, n);\n            break;\n        case 0x12:\n            mc = 1;\n            nc = 2;\n            gemm<1, 2>(m0, m, n0, n);\n            break;\n        case 0x11:\n            mc = 1;\n            nc = 1;\n            gemm<1, 1>(m0, m, n0, n);\n            break;\n        default:\n            return;\n        }\n        mp = m0 + (m - m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    template <int RM, int RN>\n    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            float32x4_t Cv[RN][RM] = {};\n            for (int64_t l = 0; l < k; ++l)\n                for (int64_t j = 0; j < RN; ++j)\n                    for (int64_t i = 0; i < RM; ++i)\n                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],\n                                               vcvtq_f32_s32(vdotq_s32(\n                                                   vdotq_s32(vdupq_n_s32(0),\n                                                             load_lo(A + lda * (ii + i) + l),\n       
                                                      load_lo(B + ldb * (jj + j) + l)),\n                                                   load_hi(A + lda * (ii + i) + l),\n                                                   load_hi(B + ldb * (jj + j) + l))),\n                                               unhalf(A[lda * (ii + i) + l].d) *\n                                               unhalf(B[ldb * (jj + j) + l].d));\n            for (int64_t j = 0; j < RN; ++j)\n                for (int64_t i = 0; i < RM; ++i)\n                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);\n        }\n    }\n\n    inline int8x16_t load_lo(const block_q8_0 *b) {\n        return vld1q_s8(b->qs);\n    }\n\n    inline int8x16_t load_hi(const block_q8_0 *b) {\n        return vld1q_s8(b->qs + 16);\n    }\n\n    inline int8x16_t load_lo(const block_q4_0 *b) {\n        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),\n                                                     vdupq_n_u8(0x0f))),\n                        vdupq_n_s8(0x8));\n    }\n\n    inline int8x16_t load_hi(const block_q4_0 *b) {\n        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),\n                        vdupq_n_s8(0x8));\n    }\n\n    const TA *const A;\n    const block_q8_0 *const B;\n    float *const C;\n    const int64_t k;\n    const int64_t lda;\n    const int64_t ldb;\n    const int64_t ldc;\n    const int ith;\n    const int nth;\n};\n#endif // __ARM_FEATURE_DOTPROD\n\n#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)\ntemplate <typename TA, typename TB, typename TC>\nclass tinyBLAS_Q0_AVX {\n  public:\n    tinyBLAS_Q0_AVX(int64_t k,\n                    const TA *A, int64_t lda,\n                    const TB *B, int64_t ldb,\n                    TC *C, int64_t ldc,\n                    int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n        const int8_t kvalues_iq4nl[16] = {\n            -127, -104, 
-83, -65,\n            -49,  -35,  -22, -10,\n              1,   13,   25,  38,\n             53,   69,   89, 113\n        };\n\n        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n    }\n\n    void matmul(int64_t m, int64_t n) {\n        mnpack(0, m, 0, n);\n    }\n\n  private:\n    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t mc, nc, mp, np;\n        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {\n#if VECTOR_REGISTERS == 32\n        case 0x44:\n            mc = 4;\n            nc = 4;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemm4xN<4>(m0, m, n0, n);\n#else\n            gemm<4, 4>(m0, m, n0, n);\n#endif\n            break;\n        case 0x43:\n            mc = 4;\n            nc = 3;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemm4xN<3>(m0, m, n0, n);\n#else\n            gemm<4, 3>(m0, m, n0, n);\n#endif\n            break;\n        case 0x34:\n            mc = 3;\n            nc = 4;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemmMx4<3>(m0, m, n0, n);\n#else\n            gemm<3, 4>(m0, m, n0, n);\n#endif\n            break;\n        case 0x33:\n            mc = 3;\n            nc = 3;\n            gemm<3, 3>(m0, m, n0, n);\n            break;\n        case 0x42:\n            mc = 4;\n            nc = 2;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemm4xN<2>(m0, m, n0, n);\n#else\n            gemm<4, 2>(m0, m, n0, n);\n#endif\n            break;\n        case 0x24:\n            mc = 2;\n            nc = 4;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemmMx4<2>(m0, m, n0, n);\n#else\n            gemm<2, 4>(m0, m, n0, n);\n#endif\n            break;\n#else\n        case 0x44:\n        case 0x43:\n        case 0x42:\n            mc = 4;\n            nc = 2;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemm4xN<2>(m0, m, n0, n);\n#else\n            gemm<4, 2>(m0, m, n0, n);\n#endif\n            break;\n        case 0x34:\n        case 
0x24:\n            mc = 2;\n            nc = 4;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemmMx4<2>(m0, m, n0, n);\n#else\n            gemm<2, 4>(m0, m, n0, n);\n#endif\n            break;\n        case 0x33:\n#endif\n        case 0x32:\n            mc = 3;\n            nc = 2;\n            gemm<3, 2>(m0, m, n0, n);\n            break;\n        case 0x23:\n            mc = 2;\n            nc = 3;\n            gemm<2, 3>(m0, m, n0, n);\n            break;\n        case 0x41:\n            mc = 4;\n            nc = 1;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemm4xN<1>(m0, m, n0, n);\n#else\n            gemm<4, 1>(m0, m, n0, n);\n#endif\n            break;\n        case 0x22:\n            mc = 2;\n            nc = 2;\n            gemm<2, 2>(m0, m, n0, n);\n            break;\n        case 0x14:\n            mc = 1;\n            nc = 4;\n#if defined(__AVX2__) && defined(__F16C__)\n            gemmMx4<1>(m0, m, n0, n);\n#else\n            gemm<1, 4>(m0, m, n0, n);\n#endif\n            break;\n        case 0x31:\n            mc = 3;\n            nc = 1;\n            gemm<3, 1>(m0, m, n0, n);\n            break;\n        case 0x13:\n            mc = 1;\n            nc = 3;\n            gemm<1, 3>(m0, m, n0, n);\n            break;\n        case 0x21:\n            mc = 2;\n            nc = 1;\n            gemm<2, 1>(m0, m, n0, n);\n            break;\n        case 0x12:\n            mc = 1;\n            nc = 2;\n            gemm<1, 2>(m0, m, n0, n);\n            break;\n        case 0x11:\n            mc = 1;\n            nc = 1;\n            gemm<1, 1>(m0, m, n0, n);\n            break;\n        default:\n            return;\n        }\n        mp = m0 + (m - m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n#if defined(__AVX2__) && defined(__F16C__)\n// Templated functions for gemm of dimensions 4xN\n    template <int RN>\n    NOINLINE void gemm4xN(int64_t m0, 
int64_t m, int64_t n0, int64_t n) {\n        int64_t ytiles = (m - m0) / 4;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * 4;\n            int64_t jj = n0 + job % xtiles * RN;\n            __m256 Cv[RN][4] = {};\n            for (int64_t l = 0; l < k; ++l) {\n                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);\n                // Convert delta values for four blocks to float values\n                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));\n                __m256i avec0 = load(A + lda * (ii + 0) + l);\n                __m256i avec1 = load(A + lda * (ii + 1) + l);\n                __m256i avec2 = load(A + lda * (ii + 2) + l);\n                __m256i avec3 = load(A + lda * (ii + 3) + l);\n                for (int64_t j = 0; j < RN; ++j) {\n                        __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));\n                        // Computation of product of delta values for four blocks and replicate it across 256 bit lane\n                        __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));\n                        dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);\n                        // Computation of dot product and multiplication with appropriate delta value products\n                        Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),\n                                    updot(_mm256_sign_epi8(avec0, avec0),\n                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),\n                                    Cv[j][0]);\n                  
      Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),\n                                    updot(_mm256_sign_epi8(avec1, avec1),\n                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),\n                                    Cv[j][1]);\n                        Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),\n                                    updot(_mm256_sign_epi8(avec2, avec2),\n                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),\n                                    Cv[j][2]);\n                        Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),\n                                    updot(_mm256_sign_epi8(avec3, avec3),\n                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),\n                                    Cv[j][3]);\n                }\n            }\n\n            for (int64_t j = 0; j < RN; ++j)\n                for (int64_t i = 0; i < 4; ++i)\n                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);\n        }\n    }\n\n    // Templated functions for gemm of dimensions Mx4\n    template <int RM>\n    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / 4;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * 4;\n            __m256 Cv[4][RM] = {};\n            for (int64_t l = 0; l < k; ++l) {\n                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);\n              
  // Convert delta values for four blocks to float values\n                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));\n                __m256i bvec0 = load(B + ldb * (jj + 0) + l);\n                __m256i bvec1 = load(B + ldb * (jj + 1) + l);\n                __m256i bvec2 = load(B + ldb * (jj + 2) + l);\n                __m256i bvec3 = load(B + ldb * (jj + 3) + l);\n                for (int64_t i = 0; i < RM; ++i) {\n                    __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));\n                    // Computation of product of delta values for four blocks and replicate it across 256 bit lane\n                    __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));\n                    dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);\n                    // Computation of dot product and multiplication with appropriate delta value products\n                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),\n                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),\n                                                            load(A + lda * (ii + i) + l)),\n                                            _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),\n                                    Cv[0][i]);\n                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),\n                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),\n                                                            load(A + lda * (ii + i) + l)),\n                                            _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),\n                                    Cv[1][i]);\n                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),\n                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),\n                                                            load(A + lda * (ii + i) + l)),\n                                            
_mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),\n                                    Cv[2][i]);\n                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),\n                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),\n                                                            load(A + lda * (ii + i) + l)),\n                                            _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),\n                                    Cv[3][i]);\n                }\n            }\n            for (int64_t j = 0; j < 4; ++j)\n                for (int64_t i = 0; i < RM; ++i)\n                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);\n        }\n    }\n#endif\n\n    template <int RM, int RN>\n    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            __m256 Cv[RN][RM] = {};\n            for (int64_t l = 0; l < k; ++l)\n                for (int64_t j = 0; j < RN; ++j)\n                    for (int64_t i = 0; i < RM; ++i) {\n#if defined(__AVX2__)\n                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),\n                                                              load(A + lda * (ii + i) + l)),\n                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),\n                                                              load(A + lda * (ii + i) + l)));\n#else\n                        __m128i ali0 = load0(A + lda * (ii + i) + l);\n                        __m128i ali1 = 
load1(A + lda * (ii + i) + l);\n                        __m128i blj0 = load0(B + ldb * (jj + j) + l);\n                        __m128i blj1 = load1(B + ldb * (jj + j) + l);\n\n                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);\n                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);\n                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);\n                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);\n\n                        // updot\n                        const __m128i oneFill = _mm_set1_epi16(1);\n                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);\n                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);\n                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));\n#endif\n                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *\n                                                       unhalf(B[ldb * (jj + j) + l].d)),\n                                                       udTmp,\n                                                       Cv[j][i]);\n                    }\n            for (int64_t j = 0; j < RN; ++j)\n                for (int64_t i = 0; i < RM; ++i)\n                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);\n        }\n    }\n\n    inline __m256i load(const block_q8_0 *b) {\n        return _mm256_loadu_si256((const __m256i *)b->qs);\n    }\n\n    inline __m128i load0(const block_q8_0 *b) {\n        return _mm_loadu_si128((const __m128i *)b->qs);\n    }\n\n    inline __m128i load1(const block_q8_0 *b) {\n        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);\n    }\n\n    inline __m256i load(const block_q4_0 *b) {\n        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));\n    }\n\n    inline __m128i load0(const block_q4_0 *b) {\n        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));\n        return 
_mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));\n    }\n\n    inline __m128i load1(const block_q4_0 *b) {\n        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));\n        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));\n    }\n\n    inline __m256i load(const block_q5_0 *b) {\n        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));\n    }\n\n    inline __m128i load0(const block_q5_0* b) {\n        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));\n        uint32_t x32;\n        memcpy(&x32, b->qh, sizeof(uint32_t));\n        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);\n        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),\n                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),\n                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),\n                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));\n        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));\n        return _mm_or_si128(qxl, bytesl);\n    }\n\n    inline __m128i load1(const block_q5_0* b) {\n        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));\n        uint32_t x32;\n        memcpy(&x32, b->qh, sizeof(uint32_t));\n        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));\n        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),\n                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),\n                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),\n                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));\n        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));\n        return _mm_or_si128(qxh, bytesh);\n    }\n\n    inline __m256i load(const block_iq4_nl *b) 
{\n        return MM256_SET_M128I(load1(b), load0(b));\n    }\n\n    inline __m128i load0(const block_iq4_nl *b) {\n        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));\n        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));\n    }\n\n    inline __m128i load1(const block_iq4_nl *b) {\n        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));\n        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));\n    }\n\n    inline __m256 updot(__m256i u, __m256i s) {\n        __m256i res;\n#if defined(__AVX512VNNI__) && defined(__AVX512VL__)\n        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);\n#elif defined(__AVXVNNI__)\n        res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);\n#else\n        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));\n#endif\n        return _mm256_cvtepi32_ps(res);\n    }\n\n    static inline __m256i denibble(const uint8_t *p) {\n        __m128i x = _mm_loadu_si128((const __m128i *)p);\n        return _mm256_and_si256(_mm256_set1_epi8(15),\n                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),\n                                                        _mm_srli_epi16(x, 4), 1));\n    }\n\n    static inline __m256i bittobyte(const uint8_t *p) {\n        uint32_t x32;\n        memcpy(&x32, p, sizeof(uint32_t));\n        __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),\n                                          _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),\n                                                          _mm256_shuffle_epi8(_mm256_set1_epi32(x32),\n                                                                              _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,\n                                                                                                0x0101010101010101, 0x0000000000000000))));\n        return _mm256_andnot_si256(bytes, 
_mm256_set1_epi8((char)0xF0));\n    }\n\n    const TA *const A;\n    const TB *const B;\n    TC *const C;\n    const int64_t k;\n    const int64_t lda;\n    const int64_t ldb;\n    const int64_t ldc;\n    const int ith;\n    const int nth;\n    __m128i iq4nlt;\n};\n#endif // __AVX__\n\n//PPC Implementation\n#if defined(__MMA__)\n\n#define SAVE_ACC(ACC, ii, jj) \\\n   __builtin_mma_disassemble_acc(vec_C, ACC); \\\n   for (int I = 0; I < 4; I++) { \\\n      for (int J = 0; J < 4; J++) { \\\n         *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \\\n      } \\\n   } \\\n\ntemplate<typename T>\nstruct mma_instr;\n\ntemplate<>\nstruct mma_instr<ggml_bf16_t> {\n    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {\n        __builtin_mma_xvbf16ger2pp(acc, a, b);\n    }\n};\n\ntemplate<>\nstruct mma_instr<ggml_fp16_t> {\n    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {\n        __builtin_mma_xvf16ger2pp(acc, a, b);\n    }\n};\n\ntemplate <typename TA, typename TB, typename TC>\nclass tinyBLAS_HP16_PPC {\n  public:\n    tinyBLAS_HP16_PPC(int64_t k,\n                const TA *A, int64_t lda,\n                const TB *B, int64_t ldb,\n                TC *C, int64_t ldc,\n                int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(int64_t m, int64_t n) {\n        mnpack(0, m, 0, n);\n    }\n\n  private:\n    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {\n        vec_t t[8], s[8];\n        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};\n        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};\n        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};\n        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};\n\n        if (numVec == 2) {\n            t[0] = vec_perm(c[0], c[1], swiz1);\n            t[1] = 
vec_perm(c[2], c[3], swiz1);\n            s[0] = vec_perm(t[0], t[1], swiz3);\n            s[1] = vec_perm(t[0], t[1], swiz4);\n            vec_xst(s[0], 0, (vec_t*)vecOffset);\n            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));\n        } else if (numVec == 4) {\n            t[0] = vec_perm(c[0], c[1], swiz1);\n            t[1] = vec_perm(c[0], c[1], swiz2);\n            t[2] = vec_perm(c[2], c[3], swiz1);\n            t[3] = vec_perm(c[2], c[3], swiz2);\n            s[0] = vec_perm(t[0], t[2], swiz3);\n            s[1] = vec_perm(t[0], t[2], swiz4);\n            s[2] = vec_perm(t[1], t[3], swiz3);\n            s[3] = vec_perm(t[1], t[3], swiz4);\n            for (int i = 0; i < 4; ++i)\n                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));\n        } else if (numVec == 8) {\n            for (int i = 0; i < 4; i += 2) {\n                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);\n                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);\n            }\n            for (int i = 4; i < 8; i += 2) {\n                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);\n                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);\n            }\n            s[0] = vec_perm(t[0], t[2], swiz3);\n            s[1] = vec_perm(t[0], t[2], swiz4);\n            s[2] = vec_perm(t[1], t[3], swiz3);\n            s[3] = vec_perm(t[1], t[3], swiz4);\n            s[4] = vec_perm(t[4], t[6], swiz3);\n            s[5] = vec_perm(t[4], t[6], swiz4);\n            s[6] = vec_perm(t[5], t[7], swiz3);\n            s[7] = vec_perm(t[5], t[7], swiz4);\n            for (int i = 0; i < 8; ++i)\n                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));\n        }\n    }\n\n    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {\n        int64_t i, j;\n        TA *aoffset = NULL;\n        unsigned char *vecOffset = NULL;\n        TA * aoffsets[8];\n        vector unsigned char c_arr[8];\n        aoffset = const_cast<TA*>(a);\n        vecOffset = vec;\n      
  j = (rows >> 3);\n        if (j > 0) {\n            do {\n                if (cols == 4) {\n                    aoffsets[0] = aoffset;\n                    for (int it = 1; it < 4; ++it)\n                        aoffsets[it] = aoffsets[it-1] + lda;\n                    aoffset += 4 * lda;\n                    for (int i = 0; i < 4; ++i)\n                        c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);\n                    vector_permute_store(c_arr, 4, vecOffset);\n                    for (int i = 0; i<4; i++)\n                        aoffsets[i] = aoffsets[i]+lda;\n                    vecOffset +=64;\n                }\n                i = (cols >> 3);\n                if (i > 0) {\n                    aoffsets[0] = aoffset;\n                    for (int it = 1; it < 8; ++it) {\n                        aoffsets[it] = aoffsets[it-1] + lda;\n                    }\n                    aoffset += 8 * lda;\n                    do {\n                        for (int it = 0; it < 8; ++it)\n                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);\n                        vector_permute_store(c_arr, 8, vecOffset);\n                        for (int it = 0; it < 8; ++it)\n                            aoffsets[it] = aoffsets[it] + 8*lda;\n                        vecOffset += 128;\n                        i--;\n                    } while(i > 0);\n                }\n                j--;\n            } while(j > 0);\n        }\n        if (rows & 4) {\n            aoffsets[0] = aoffset;\n            for (int it = 1; it < 4; ++it)\n                aoffsets[it] = aoffsets[it-1] + lda;\n            aoffset += 4 * lda;\n            if (cols == 4) {\n                for (int it = 0; it < 4; ++it)\n                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);\n                vector_permute_store(c_arr, 2, vecOffset);\n                for (int it = 0; it< 4; it++)\n                    aoffsets[it] = aoffsets[it] + 
lda;\n                vecOffset += 32;\n            }\n            i = (cols >> 3);\n            if (i > 0) {\n                do {\n                    for (int it = 0; it < 4; ++it)\n                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);\n                    vector_permute_store(c_arr, 4, vecOffset);\n                    for (int it = 0; it< 4; it++)\n                        aoffsets[it] = aoffsets[it] + 8*lda;\n                    vecOffset += 64;\n                    i--;\n                } while(i > 0);\n            }\n        }\n        if (rows & 3) {\n            aoffsets[0] = aoffset;\n            for (int it = 1; it < 4; ++it)\n                aoffsets[it] = aoffsets[it-1] + lda;\n            if (cols == 4) {\n                switch(rows) {\n                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);\n                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);\n                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);\n                        break;\n                }\n                vector_permute_store(c_arr, 2, vecOffset);\n                for (int it = 0; it< 4; it++)\n                     aoffsets[it] = aoffsets[it] + lda;\n                vecOffset += 32;\n            }\n            i = (cols >> 3);\n            if (i > 0) {\n                do {\n                    switch(rows) {\n                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);\n                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);\n                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);\n                            break;\n                    }\n                    vector_permute_store(c_arr, 4, vecOffset);\n                    for (int it = 0; it <4; it++)\n                         aoffsets[it] = aoffsets[it] + 8* lda;\n                    vecOffset += 64;\n                    
i--;\n                } while(i > 0);\n            }\n        }\n    }\n\n    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t mc, nc, mp, np;\n        int m_rem = MIN(m - m0, 8);\n        int n_rem = MIN(n - n0, 8);\n\n        if (m_rem >= 8 && n_rem >= 8) {\n            mc = 8;\n            nc = 8;\n            gemm<8,8>(m0, m, n0, n);\n        } else if (m_rem >= 4 && n_rem >= 8) {\n            mc = 4;\n            nc = 8;\n            gemm<4,8>(m0, m, n0, n);\n        } else if (m_rem >=8 && n_rem >=4){\n                mc = 8;\n                nc = 4;\n                gemm<8,4>(m0, m, n0, n);\n        } else if ((m_rem < 4) && (n_rem >= 8)) {\n            nc = 8;\n            switch(m_rem) {\n                case 1:\n                    mc = 1;\n                    gemm_Mx8<1>(m0, m, n0, n);\n                    break;\n                case 2:\n                    mc = 2;\n                    gemm_Mx8<2>(m0, m, n0, n);\n                    break;\n                case 3:\n                    mc = 3;\n                    gemm_Mx8<3>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else if (m_rem >= 4 && n_rem >= 4) {\n            mc = 4;\n            nc = 4;\n            gemm_small<4, 4>(m0, m, n0, n);\n        } else if ((m_rem > 4) && (n_rem < 4)) {\n            mc = 4;\n            switch(n_rem) {\n                case 1:\n                    nc = 1;\n                    gemm_small<4, 1>(m0, m, n0, n);\n                    break;\n                case 2:\n                    nc = 2;\n                    gemm_small<4, 2>(m0, m, n0, n);\n                    break;\n                case 3:\n                    nc = 3;\n                    gemm_small<4, 3>(m0, m, n0, n);\n                    break;\n\n                default:\n                    return;\n            }\n        } else {\n            switch((m_rem << 4) | n_rem) {\n                case 0x43:\n   
                 mc = 4;\n                    nc = 3;\n                    gemm_small<4, 3>(m0, m, n0, n);\n                    break;\n                case 0x42:\n                    mc = 4;\n                    nc = 2;\n                    gemm_small<4, 2>(m0, m, n0, n);\n                    break;\n                case 0x41:\n                    mc = 4;\n                    nc = 1;\n                    gemm_small<4, 1>(m0, m, n0, n);\n                    break;\n                case 0x34:\n                    mc = 3;\n                    nc = 4;\n                    gemm_small<3, 4>(m0, m, n0, n);\n                    break;\n                case 0x33:\n                    mc = 3;\n                    nc = 3;\n                    gemm_small<3, 3>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                    mc = 3;\n                    nc = 2;\n                    gemm_small<3, 2>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                    mc = 3;\n                    nc = 1;\n                    gemm_small<3, 1>(m0, m, n0, n);\n                    break;\n                case 0x24:\n                    mc = 2;\n                    nc = 4;\n                    gemm_small<2,4>(m0, m, n0, n);\n                    break;\n                case 0x23:\n                    mc = 2;\n                    nc = 3;\n                    gemm_small<2, 3>(m0, m, n0, n);\n                    break;\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm_small<2, 2>(m0, m, n0, n);\n                    break;\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm_small<2, 1>(m0, m, n0, n);\n                    break;\n                case 0x14:\n                    mc = 1;\n                    nc = 4;\n                    gemm_small<1, 4>(m0, m, n0, n);\n                    break;\n                case 
0x13:\n                    mc = 1;\n                    nc = 3;\n                    gemm_small<1, 3>(m0, m, n0, n);\n                    break;\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm_small<1, 2>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm_small<1, 1>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n        mp = m0 + (m - m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    void KERNEL_4x8(int64_t ii, int64_t jj) {\n        vec_t vec_A[4], vec_B[8] , vec_C[4];\n        acc_t acc_0, acc_1;\n        __builtin_mma_xxsetaccz(&acc_0);\n        __builtin_mma_xxsetaccz(&acc_1);\n        for (int l = 0; l < k; l+=8) {\n            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);\n            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);\n            for (int x = 0; x < 4; x++) {\n                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);\n                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);\n            }\n        }\n        SAVE_ACC(&acc_0, ii, jj);\n        SAVE_ACC(&acc_1, ii, jj+4);\n    }\n\n    void KERNEL_8x4(int64_t ii, int64_t jj) {\n        vec_t vec_A[8], vec_B[4] , vec_C[4];\n        acc_t acc_0, acc_1;\n        __builtin_mma_xxsetaccz(&acc_0);\n        __builtin_mma_xxsetaccz(&acc_1);\n        for (int l = 0; l < k; l+=8) {\n            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);\n            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);\n            for (int x = 0; x < 4; x++) {\n                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);\n                mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);\n            }\n        
}\n        SAVE_ACC(&acc_0, ii, jj);\n        SAVE_ACC(&acc_1, ii+4, jj);\n    }\n\n\n    void KERNEL_8x8(int64_t ii, int64_t jj) {\n        vec_t vec_A[8], vec_B[8], vec_C[4];\n        acc_t acc_0, acc_1, acc_2, acc_3;\n        __builtin_mma_xxsetaccz(&acc_0);\n        __builtin_mma_xxsetaccz(&acc_1);\n        __builtin_mma_xxsetaccz(&acc_2);\n        __builtin_mma_xxsetaccz(&acc_3);\n        for (int l = 0; l < k; l+=8) {\n            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);\n            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);\n            for (int x = 0; x < 4; x++) {\n                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);\n                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);\n                mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);\n                mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);\n            }\n        }\n\n        SAVE_ACC(&acc_0, ii, jj);\n        SAVE_ACC(&acc_1, ii, jj+4);\n        SAVE_ACC(&acc_2, ii+4, jj);\n        SAVE_ACC(&acc_3, ii+4, jj+4);\n    }\n\n    template<int RM, int RN>\n    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            vec_t vec_C[4];\n            acc_t acc_0;\n            __builtin_mma_xxsetaccz(&acc_0);\n            vec_t vec_A[2], vec_B[2];\n            for (int l=0; l<k; l+=4) {\n                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);\n                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);\n                for 
(int x = 0; x<2; x++) {\n                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);\n                }\n            }\n            __builtin_mma_disassemble_acc(vec_C, &acc_0);\n            for (int I = 0; I < RM; I++) {\n                for (int J = 0; J < RN; J++) {\n                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);\n                }\n            }\n        }\n    }\n\n    template<int RM>\n    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int RN = 8;\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            vec_t vec_C[4];\n            acc_t acc_0, acc_1;\n            __builtin_mma_xxsetaccz(&acc_0);\n            __builtin_mma_xxsetaccz(&acc_1);\n            vec_t vec_A[4], vec_B[8];\n            for (int l=0; l<k; l+=8) {\n                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);\n                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);\n                for (int x = 0; x<4; x++) {\n                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);\n                    mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);\n                }\n            }\n            __builtin_mma_disassemble_acc(vec_C, &acc_0);\n            for (int I = 0; I < RM; I++) {\n                for (int J = 0; J < 4; J++) {\n                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);\n                }\n            }\n            __builtin_mma_disassemble_acc(vec_C, &acc_1);\n            for (int I = 0; I < RM; I++) {\n                for (int J = 0; J < 4; 
J++) {\n                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);\n                }\n            }\n        }\n    }\n\n    template<int RM, int RN>\n    inline void kernel(int64_t ii, int64_t jj) {\n       if constexpr(RM == 4 && RN == 8) {\n          KERNEL_4x8(ii,jj);\n       } else if constexpr(RM == 8 && RN == 8) {\n          KERNEL_8x8(ii,jj);\n       } else if constexpr(RM == 8 && RN == 4) {\n          KERNEL_8x4(ii,jj);\n       } else {\n          assert(false && \"RN/RM values not supported\");\n       }\n    }\n\n    template <int RM, int RN>\n    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            kernel<RM, RN>(ii, jj);\n        }\n    }\n\n    const TA *const A;\n    const TB *const B;\n    TC *C;\n    const int64_t k;\n    const int64_t lda;\n    const int64_t ldb;\n    const int64_t ldc;\n    const int ith;\n    const int nth;\n};\n\ntemplate <typename TA>\nclass tinyBLAS_Q0_PPC {\n  public:\n    tinyBLAS_Q0_PPC(int64_t k,\n             const TA * A, int64_t lda,\n             const block_q8_0 * B, int64_t ldb,\n             float * C, int64_t ldc,\n             int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(int64_t m, int64_t n) {\n        const int64_t mc = 64;\n        const int64_t kc = 64;\n        int64_t nc = 64;\n        int64_t n_aligned = 0;\n        if (n % 64 == 0) {\n            n_aligned = n;\n        } else if (n == 4) {\n            n_aligned = 4;\n        } else 
if (n < 64) {\n            n_aligned = (n / 8) * 8;\n        } else {\n            n_aligned = (n / 64) * 64;\n        }\n\n        if (n_aligned > 0) {\n            if (n_aligned % 64 == 0)      nc = 64;\n            else if (n_aligned == n)      nc = n;\n            else if (n_aligned % 32 == 0) nc = 32;\n            else if (n_aligned % 24 == 0) nc = 24;\n            else if (n_aligned % 16 == 0) nc = 16;\n            else                          nc = 8;\n        }\n        bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);\n        if (can_use_tiled) {\n            matmul_tiled(m, n_aligned, mc, nc, kc);\n            if (n > n_aligned) {\n                mnpack(0, m, n_aligned, n);\n            }\n        } else {\n            mnpack(0, m, 0, n);\n        }\n    }\n\n  private:\n    inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {\n        for (int I = 0; I < RM; I++) {\n            for (int J = 0; J < RN; J++) {\n                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);\n            }\n        }\n    }\n\n    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {\n        vec_t vec_C[4];\n        __builtin_mma_disassemble_acc(vec_C, ACC);\n        for (int I = 0; I < 4; I++) {\n            for (int J = 0; J < 4; J++) {\n                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);\n            }\n        }\n    }\n\n    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {\n        vec_t vec_C[4];\n        __builtin_mma_disassemble_acc(vec_C, ACC);\n        for (int I = 0; I < 4; I++) {\n            for (int J = 0; J < 4; J++) {\n                float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);\n                *c_ptr += *((float *)&vec_C[I] + J);\n            }\n        }\n    }\n\n    template<typename ArrayType>\n    inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector 
float * vs, vector float * fin_res) {\n        vector signed int vec_C[4];\n        vector float CA[4] = {0};\n        vector float res[4] = {0};\n        __builtin_mma_disassemble_acc(vec_C, ACC);\n        for (int i = 0; i < 4; i++) {\n            CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));\n            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);\n            fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);\n        }\n    }\n\n    inline void process_q4_elements(vector signed char (&c)[2], int * ca) {\n        const vector signed char lowMask = vec_splats((signed char)0xF);\n        const vector unsigned char v4 = vec_splats((unsigned char)0x4);\n        const vector signed char v8 = vec_splats((signed char)0x8);\n        vector signed int vsum = {0};\n        vector signed int vsum2 = {0};\n        c[0] = vec_and(c[1], lowMask);\n        c[1] = vec_sr(c[1], v4);\n        c[0] = vec_sub(c[0], v8);\n        c[1] = vec_sub(c[1], v8);\n        vsum = vec_sum4s(c[0], vsum);\n        vsum2 = vec_sum4s(c[1], vsum2);\n        vsum = vec_add(vsum, vsum2);\n        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];\n    }\n\n    template <typename V1, typename V2>\n    inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {\n        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};\n        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};\n        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};\n        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};\n        V2 t1, t2, t3, t4, t5, t6, t7, t8;\n        vector unsigned char xor_vector;\n        uint8_t flip_vec = 0x80;\n        xor_vector = vec_splats(flip_vec);\n        t1 = vec_perm(s1, s2, swiz1);\n        t2 = vec_perm(s1, s2, swiz2);\n        t3 = vec_perm(s3, s4, 
swiz1);\n        t4 = vec_perm(s3, s4, swiz2);\n        t5 = vec_perm(t1, t3, swiz3);\n        t6 = vec_perm(t1, t3, swiz4);\n        t7 = vec_perm(t2, t4, swiz3);\n        t8 = vec_perm(t2, t4, swiz4);\n        if (flip == true) {\n            t5 = vec_xor(t5, xor_vector);\n            t6 = vec_xor(t6, xor_vector);\n            t7 = vec_xor(t7, xor_vector);\n            t8 = vec_xor(t8, xor_vector);\n        }\n        vec_xst(t5, 0, vecOffset);\n        vec_xst(t6, 0, vecOffset + 16);\n        vec_xst(t7, 0, vecOffset + 32);\n        vec_xst(t8, 0, vecOffset + 48);\n    }\n\n    inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {\n        const vector signed char lowMask = vec_splats((signed char)0x0F);\n        const vector signed char v8      = vec_splats((signed char)0x08);\n        const vector unsigned char v4    = vec_splats((unsigned char)4);\n        lo = vec_and(packed, lowMask);\n        hi = vec_sr(packed, v4);\n        lo = vec_sub(lo, v8);\n        hi = vec_sub(hi, v8);\n    }\n\n    inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {\n        vec_t t[8], s[8];\n        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};\n        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};\n        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};\n        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};\n        for (int i = 0; i < 4; i += 2) {\n            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);\n            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);\n        }\n        for (int i = 4; i < 8; i += 2) {\n            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);\n            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);\n        }\n        s[0] = vec_perm(t[0], t[2], swiz3);\n        s[1] = vec_perm(t[0], t[2], swiz4);\n        s[2] = vec_perm(t[1], t[3], swiz3);\n    
    s[3] = vec_perm(t[1], t[3], swiz4);\n        s[4] = vec_perm(t[4], t[6], swiz3);\n        s[5] = vec_perm(t[4], t[6], swiz4);\n        s[6] = vec_perm(t[5], t[7], swiz3);\n        s[7] = vec_perm(t[5], t[7], swiz4);\n        for (int i = 0; i < 8; ++i) {\n            vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));\n        }\n    }\n\n    static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {\n        vector signed short i16_hi = vec_unpackh(raw);\n        vector signed short i16_lo = vec_unpackl(raw);\n\n        vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);\n        vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);\n        vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);\n        vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);\n        out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));\n        out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));\n    }\n\n    void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {\n        unsigned char * vecOffset = vec;\n        for (int i = 0; i < rows; i += 8) {\n            const block_q4_0 * rows_base[8];\n            for (int r = 0; r < 8; r++) {\n                rows_base[r] = a + (i + r) * lda;\n            }\n            for (int blk = 0; blk < blocks; blk++) {\n                vector unsigned short hp_res[8][4];\n                for (int r = 0; r < 8; r++) {\n                    const block_q4_0 * current_blk = rows_base[r] + blk;\n                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));\n                    vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs);\n                    vector signed char c1, c2;\n                    unpack_q4_to_q8(v_qs, c1, c2);\n                    convert_and_scale_q8(c1, 
v_scale, hp_res[r][0], hp_res[r][1]);\n                    convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);\n                }\n                for (int c = 0; c < 4; c++) {\n                    vector unsigned char c_arr[8];\n                    for (int r = 0; r < 8; r++) {\n                        c_arr[r] = (vector unsigned char)hp_res[r][c];\n                    }\n                    vector_permute_store_fp16((vec_t *)c_arr, vecOffset);\n                    vecOffset += 128;\n                }\n            }\n        }\n    }\n\n    template <int chunk_size>\n    static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {\n        unsigned char * vecOffset = vec;\n        const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};\n        const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};\n        const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};\n        const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};\n\n        for (int i = 0; i < rows; i += chunk_size) {\n            const block_q8_0 * rows_base[chunk_size];\n            for (int r = 0; r < chunk_size; r++) {\n                rows_base[r] = a + (i + r) * lda;\n            }\n            for (int blk = 0; blk < blocks; blk++) {\n                vector unsigned short hp_res[chunk_size][4];\n                for (int r = 0; r < chunk_size; r++) {\n                    const block_q8_0 * b = rows_base[r] + blk;\n                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));\n                    vector signed char c[2];\n                    __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);\n                    __builtin_vsx_disassemble_pair(c, & pair);\n                    convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);\n                    convert_and_scale_q8(c[1], 
v_scale, hp_res[r][2], hp_res[r][3]);\n                }\n                for (int col = 0; col < 4; col++) {\n                    if constexpr (chunk_size == 8) {\n                        vec_t t[8];\n                        t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);\n                        t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);\n                        t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);\n                        t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);\n                        t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);\n                        t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);\n                        t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);\n                        t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);\n\n                        vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));\n                        vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));\n                        vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));\n                        vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));\n                        vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));\n                        vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));\n                        vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));\n                        vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));\n                        vecOffset += 128;\n                    } else {\n                        vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);\n                        vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);\n                     
   vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);\n                        vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);\n\n                        vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));\n                        vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));\n                        vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));\n                        vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));\n                        vecOffset += 64;\n                    }\n                }\n            }\n        }\n    }\n\n    void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {\n        if (rows == 4) {\n            pack_q8_block<4>(a, lda, rows, blocks, vec);\n        } else {\n            pack_q8_block<8>(a, lda, rows, blocks, vec);\n        }\n    }\n\n    template<int size>\n    void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {\n        int64_t i, j;\n        TA * aoffset = NULL;\n        int8_t * vecOffset = NULL;\n        TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;\n        TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;\n        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};\n        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};\n        aoffset = const_cast<TA *>(a);\n        vecOffset = vec;\n        j = (rows >> 3);\n        if (j > 0) {\n            do {\n                aoffset1 = aoffset;\n                aoffset2 = aoffset1 + lda;\n                aoffset3 = aoffset2 + lda;\n                aoffset4 = aoffset3 + lda;\n                aoffset5 = aoffset4 + lda;\n                aoffset6 = aoffset5 + lda;\n                aoffset7 = aoffset6 + lda;\n                aoffset8 = aoffset7 + lda;\n         
       aoffset += 8 * lda;\n                i = (cols >> 2);\n                if (i > 0) {\n                    do {\n                        c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);\n                        c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);\n                        c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);\n                        c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);\n                        c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs);\n                        c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs);\n                        c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs);\n                        c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs);\n\n                        process_q4_elements(c1, & comparray[0]);\n                        process_q4_elements(c2, & comparray[1]);\n                        process_q4_elements(c3, & comparray[2]);\n                        process_q4_elements(c4, & comparray[3]);\n                        process_q4_elements(c5, & comparray[4]);\n                        process_q4_elements(c6, & comparray[5]);\n                        process_q4_elements(c7, & comparray[6]);\n                        process_q4_elements(c8, & comparray[7]);\n                        vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);\n                        vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);\n                        vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);\n                        vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);\n                        aoffset1 += lda;\n                        aoffset2 += lda;\n                        aoffset3 += lda;\n                        aoffset4 += 
lda;\n                        aoffset5 += lda;\n                        aoffset6 += lda;\n                        aoffset7 += lda;\n                        aoffset8 += lda;\n                        vecOffset += 256;\n                        i--;\n                    } while (i > 0);\n                }\n                j--;\n            } while (j > 0);\n        }\n\n        if (rows & 4) {\n            aoffset1 = aoffset;\n            aoffset2 = aoffset1 + lda;\n            aoffset3 = aoffset2 + lda;\n            aoffset4 = aoffset3 + lda;\n            aoffset += 4 * lda;\n            i = (cols >> 2);\n            if (i > 0) {\n                do {\n                    c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);\n                    c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);\n                    c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);\n                    c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);\n\n                    process_q4_elements(c1, & comparray[0]);\n                    process_q4_elements(c2, & comparray[1]);\n                    process_q4_elements(c3, & comparray[2]);\n                    process_q4_elements(c4, & comparray[3]);\n                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);\n                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);\n                    aoffset1 += lda;\n                    aoffset2 += lda;\n                    aoffset3 += lda;\n                    aoffset4 += lda;\n                    vecOffset += 128;\n                    i--;\n                } while (i > 0);\n            }\n        }\n\n        if (rows & 3) {\n            aoffset1 = aoffset;\n            aoffset2 = aoffset1 + lda;\n            aoffset3 = aoffset2 + lda;\n            i = (cols >> 2);\n            if (i > 0) {\n                do {\n                    
switch(rows) {\n                        case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);\n                        case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);\n                        case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);\n                            break;\n                    }\n                    process_q4_elements(c1, & comparray[0]);\n                    process_q4_elements(c2, & comparray[1]);\n                    process_q4_elements(c3, & comparray[2]);\n                    process_q4_elements(c4, & comparray[3]);\n                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);\n                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);\n                    aoffset1 += lda;\n                    aoffset2 += lda;\n                    aoffset3 += lda;\n                    vecOffset += 128;\n                    i--;\n                } while(i > 0);\n            }\n        }\n    }\n\n    template<typename VA, typename VB>\n    void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {\n        int64_t i, j;\n        block_q8_0 * aoffset = NULL;\n        VA * vecOffset = NULL;\n        block_q8_0 * aoffsets[8];\n        __vector_pair arr[8];\n        VB c[8][2] = {0};\n        VB c1[8] = {0}; VB c2[8] = {0};\n        aoffset = const_cast<block_q8_0 *>(a);\n        vecOffset = vec;\n        j = (rows >> 3);\n        if (j > 0) {\n            do {\n                aoffsets[0] = aoffset;\n                for (int it = 1; it < 8; it++)\n                    aoffsets[it] = aoffsets[it - 1] + lda;\n                aoffset += 8 * lda;\n\n                i = (cols >> 3);\n                if (i > 0) {\n                do {\n                    for (int it = 0; it < 8; it++) {\n                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair 
*)aoffsets[it]->qs);\n                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);\n                        c1[it] = c[it][0];\n                        c2[it] = c[it][1];\n                    }\n                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);\n                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);\n                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);\n                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);\n                    for (int it = 0; it < 8; it++)\n                        aoffsets[it] += lda;\n                    vecOffset += 256;\n                    i--;\n               } while(i > 0);\n            }\n            j--;\n        } while(j > 0);\n    }\n    if (rows & 4) {\n            aoffsets[0]  = aoffset;\n            for (int it = 1; it < 4; it++ )\n                aoffsets[it] = aoffsets[it-1] + lda;\n            aoffset += 4 * lda;\n        i = (cols >> 3);\n            if (i > 0) {\n               do {\n                    for (int it = 0; it < 4; it++) {\n                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);\n                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);\n                        c1[it] = c[it][0];\n                        c2[it] = c[it][1];\n                    }\n                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);\n                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);\n                    for (int it = 0; it < 4; it++) {\n                        aoffsets[it] += lda;\n                    }\n                    vecOffset += 128;\n                    i--;\n               } while(i > 0);\n            }\n        }\n\n        if (rows & 3) {\n            aoffsets[0]  = aoffset;\n            for (int it = 
1; it < 3; it++ )\n                aoffsets[it] = aoffsets[it - 1] + lda;\n            i = (cols >> 3);\n            if (i > 0) {\n                do {\n                    switch(rows) {\n                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);\n                                __builtin_vsx_disassemble_pair(c[2], & arr[2]);\n                                c1[2] = c[2][0]; c2[2] = c[2][1];\n                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);\n                                __builtin_vsx_disassemble_pair(c[1], & arr[1]);\n                                c1[1] = c[1][0]; c2[1] = c[1][1];\n                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);\n                                __builtin_vsx_disassemble_pair(c[0], & arr[0]);\n                                c1[0] = c[0][0]; c2[0] = c[0][1];\n                                break;\n                    }\n                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);\n                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);\n                    for (int it = 0; it < 3; it++)\n                         aoffsets[it] += lda;\n                    vecOffset += 128;\n                    i--;\n               } while(i > 0);\n            }\n        }\n    }\n\n    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int m_rem = MIN(m - m0, 16);\n        int n_rem = MIN(n - n0, 16);\n\n        int mc = 0, nc = 0;\n\n        if (m_rem >= 8 && n_rem >= 8) {\n           mc = 8;\n           nc = 8;\n           gemm<8, 8>(m0, m, n0, n);\n        } else if (m_rem >= 4 && n_rem >= 8) {\n            mc = 4;\n            nc = 8;\n            gemm<4, 8>(m0, m, n0, n);\n        } else if (m_rem >= 8 && n_rem >= 4) {\n            mc = 8;\n            nc = 4;\n            gemm<8, 4>(m0, m, n0, n);\n        } else if (m_rem >= 
4 && n_rem >= 4) {\n            mc = 4;\n            nc = 4;\n            gemm_small(m0, m, n0, n, mc, nc);\n        } else {\n            mc = (m_rem >= 4) ? 4 : m_rem;\n            nc = (n_rem >= 4) ? 4 : n_rem;\n            if (mc == 0 || nc == 0)\n               return;\n            gemm_small(m0, m, n0, n, mc, nc);\n        }\n\n        int64_t mp = m0 + ((m - m0) / mc) * mc;\n        int64_t np = n0 + ((n - n0) / nc) * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n\n    void KERNEL_4x8(int64_t ii, int64_t jj) {\n        vec_t vec_A[8], vec_B[16] = {0};\n        acc_t acc_0, acc_1;\n        std::array<int, 4> comparray {};\n        vector float fin_res[8] = {0};\n        vector float vs[8] = {0};\n        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;\n        for (int l = 0; l < k; l++) {\n            __builtin_mma_xxsetaccz(& acc_0);\n            __builtin_mma_xxsetaccz(& acc_1);\n            if (std::is_same_v<TA, block_q4_0>) {\n               packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);\n            } else {\n               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);\n            }\n            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);\n            for(int x = 0; x < 8; x++) {\n                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);\n                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);\n            }\n            for (int I = 0; I<4; I++) {\n                for (int J = 0; J<4; J++) {\n                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));\n                    *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));\n                }\n            }\n            if (!isAblock_q4) {\n               
 auto aoffset = A + (ii * lda) + l;\n                for (int i = 0; i < 4; i++) {\n                    comparray[i] = 0;\n                    int ca = 0;\n                    auto *at = aoffset->qs;\n                    for (int j = 0; j < 32; j++)\n                        ca += (int)*at++;\n                    comparray[i] = ca;\n                    aoffset += lda;\n                }\n            }\n            compute(& acc_0, 0, 0, comparray, vs, fin_res);\n            compute(& acc_1, 0, 4, comparray, vs, fin_res);\n        }\n        save_res(ii, jj, 0, fin_res);\n        save_res(ii, jj + 4, 4, fin_res);\n    }\n\n    void KERNEL_8x4(int64_t ii, int64_t jj) {\n        vec_t vec_A[16], vec_B[8] = {0};\n        acc_t acc_0, acc_1;\n        std::array<int, 8> comparray {};\n        vector float fin_res[8] = {0};\n        vector float vs[8] = {0};\n        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;\n        for (int l = 0; l < k; l++) {\n            __builtin_mma_xxsetaccz(& acc_0);\n            __builtin_mma_xxsetaccz(& acc_1);\n            if (std::is_same_v<TA, block_q4_0>) {\n               packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);\n            } else {\n               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);\n            }\n            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);\n            for(int x = 0; x < 8; x++) {\n                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);\n                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);\n            }\n            for (int I = 0; I < 8; I++) {\n                for (int J = 0; J < 4; J++) {\n                    *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));\n                }\n            }\n            if (!isAblock_q4) {\n                auto 
aoffset = A + (ii * lda) + l;\n                for (int i = 0; i < 8; i++) {\n                    comparray[i] = 0;\n                    int ca = 0;\n                    auto *at = aoffset->qs;\n                    for (int j = 0; j < 32; j++)\n                        ca += (int)*at++;\n                    comparray[i] = ca;\n                    aoffset += lda;\n                }\n            }\n            compute(& acc_0, 0, 0, comparray, vs, fin_res);\n            compute(& acc_1, 4, 4, comparray, vs, fin_res);\n        }\n        save_res(ii, jj, 0, fin_res);\n        save_res(ii + 4, jj, 4, fin_res);\n    }\n\n    void KERNEL_8x8(int64_t ii, int64_t jj) {\n        vec_t vec_A[16], vec_B[16] = {0};\n        acc_t acc_0, acc_1, acc_2, acc_3;\n        acc_t acc_4, acc_5, acc_6, acc_7;\n        std::array<int, 8> comparray {};\n        vector float fin_res[16] = {0};\n        vector float vs[16] = {0};\n        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;\n        for (int l = 0; l < k; l++) {\n            __builtin_mma_xxsetaccz(& acc_0);\n            __builtin_mma_xxsetaccz(& acc_1);\n            __builtin_mma_xxsetaccz(& acc_2);\n            __builtin_mma_xxsetaccz(& acc_3);\n            if (std::is_same_v<TA, block_q4_0>) {\n               packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);\n            } else {\n               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);\n            }\n            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);\n            for(int x = 0; x < 8; x++) {\n                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);\n                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);\n                __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);\n                __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);\n            }\n    
        for (int I = 0; I < 8 ; I++) {\n                for (int J = 0; J < 4; J++) {\n                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));\n                    *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));\n                }\n            }\n            if (!isAblock_q4) {\n                auto aoffset = A + (ii * lda) + l;\n                for (int i = 0; i < 8; i++) {\n                    comparray[i] = 0;\n                    int ca = 0;\n                    auto *at = aoffset->qs;\n                    for (int j = 0; j < 32; j++)\n                        ca += (int)*at++;\n                    comparray[i] = ca;\n                    aoffset += lda;\n                }\n            }\n            compute(& acc_0, 0, 0, comparray, vs, fin_res);\n            compute(& acc_1, 4, 4, comparray, vs, fin_res);\n            compute(& acc_2, 0, 8, comparray, vs, fin_res);\n            compute(& acc_3, 4, 12, comparray, vs, fin_res);\n        }\n        save_res(ii, jj, 0, fin_res);\n        save_res(ii + 4, jj, 4, fin_res);\n        save_res(ii, jj + 4, 8, fin_res);\n        save_res(ii + 4, jj + 4, 12, fin_res);\n    }\n\n    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {\n        acc_t acc[8];\n        for (int i = 0; i < mc ; i += 16) {\n            for (int j = 0; j < nc; j += 8) {\n                int A0_base = (i / 16) * (2 * 32 * kc);\n                int B0_base = (j / 8) * (32 * kc);\n                for (int x = 0; x < 8; x++) {\n                     __builtin_mma_xxsetaccz(&acc[x]);\n                }\n                for (int64_t kk = 0; kk < kc; kk++) {\n                    int A0_block_idx = A0_base + kk * 32;\n                    int B0_block_idx = B0_base + kk * 32;\n                    int A1_block_idx = A0_block_idx + 32 * kc;\n                  
  int B1_block_idx = B0_block_idx + 32 * kc;\n                    vec_t * A0_block = & vec_A[A0_block_idx];\n                    vec_t * B0_block = & vec_B[B0_block_idx];\n                    vec_t * A1_block = & vec_A[A1_block_idx];\n                    for (int it = 0; it < 4; it++) {\n                        for (int x = 0; x < 4; x++) {\n                            __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);\n                            __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);\n                            __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);\n                            __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);\n                            __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);\n                            __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);\n                            __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);\n                            __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);\n                        }\n                    }\n                }\n                if (l == 0) {\n                    save_acc(& acc[0], ii + i, jj + j);\n                    save_acc(& acc[1], ii + i, jj + j + 4);\n                    save_acc(& acc[2], ii + i + 4, jj + j);\n                    save_acc(& acc[3], ii + i + 4, jj + j + 4);\n                    save_acc(& acc[4], ii + i + 8, jj + j);\n                    save_acc(& acc[5], ii + i + 8, jj + j + 4);\n                    save_acc(& acc[6], ii + i + 12, jj + j);\n                    save_acc(& acc[7], ii + i + 12, jj + j + 4);\n                } else {\n                    add_save_acc(& acc[0], ii + i, jj + j);\n                    add_save_acc(& acc[1], ii + i, jj + j + 
4);\n                    add_save_acc(& acc[2], ii + i + 4, jj + j);\n                    add_save_acc(& acc[3], ii + i + 4, jj + j + 4);\n                    add_save_acc(& acc[4], ii + i + 8, jj + j);\n                    add_save_acc(& acc[5], ii + i + 8, jj + j + 4);\n                    add_save_acc(& acc[6], ii + i + 12, jj + j);\n                    add_save_acc(& acc[7], ii + i + 12, jj + j + 4);\n                }\n            }\n        }\n    }\n\n    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {\n        vec_t A_pack[mc * kc * 4];\n        vec_t B_pack[nc * kc * 4];\n        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;\n        int64_t ytiles = m / mc;\n        int64_t xtiles = n / nc;\n        int64_t tiles  = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles) {\n            end = tiles;\n        }\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = (job / xtiles) * mc;\n            int64_t jj = (job % xtiles) * nc;\n            for (int64_t kk = 0; kk < k; kk += kc) {\n                if constexpr(is_Ablock_q4) {\n                    packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);\n                } else {\n                    packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);\n                }\n                packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);\n                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);\n            }\n        }\n    }\n\n    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        
vec_t vec_A[8] = {0}, vec_B[8] = {0};\n        vector signed int vec_C[4];\n        acc_t acc_0;\n        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;\n\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            std::array<int, 4> comparray{};\n            vector float res[4] = {0};\n            vector float fin_res[4] = {0};\n            vector float vs[4] = {0};\n            vector float CA[4] = {0};\n            __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value\n            __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value\n            for (int l = 0; l < k; l++) {\n                __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead\n                __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead\n                __builtin_mma_xxsetaccz(& acc_0);\n                if (isAblock_q4) {\n                    packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);\n                } else {\n                    packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);\n                }\n                packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);\n                for (int x = 0; x < 8; x += 4) {\n                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);\n                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);\n                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);\n                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);\n                }\n                for (int I = 0; I < RM; I++) {\n                    for (int J = 0; J < RN; J++) {\n             
           *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));\n                    }\n                }\n                __builtin_mma_disassemble_acc(vec_C, & acc_0);\n                if (!isAblock_q4) {\n                    auto aoffset = A + (ii * lda) + l;\n                    for (int i = 0; i < RM; i++) {\n                        comparray[i] = 0;\n                        int ca = 0;\n                        auto *at = aoffset->qs;\n                        for (int j = 0; j < 32; j++)\n                            ca += (int)*at++;\n                        comparray[i] = ca;\n                        aoffset += lda;\n                    }\n                }\n                for (int i = 0; i < RM; i++) {\n                    CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));\n                    res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);\n                    fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);\n                }\n            }\n            save_res(ii, jj, 0, fin_res, RM, RN);\n        }\n    }\n\n    template<int RM, int RN>\n    inline void kernel(int64_t ii, int64_t jj) {\n        if constexpr(RM == 4 && RN == 8) {\n            KERNEL_4x8(ii,jj);\n        } else if constexpr(RM == 8 && RN == 4) {\n            KERNEL_8x4(ii,jj);\n        } else if constexpr(RM == 8 && RN == 8) {\n            KERNEL_8x8(ii,jj);\n        } else {\n            assert(false && \"RN/RM values not supported\");\n        }\n    }\n\n    template <int RM, int RN>\n    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t 
ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            kernel<RM, RN>(ii, jj);\n        }\n    }\n    const TA * const A;\n    const block_q8_0 * const B;\n    float * C;\n    const int64_t k;\n    int64_t kc;\n    const int64_t lda;\n    const int64_t ldb;\n    const int64_t ldc;\n    const int ith;\n    const int nth;\n};\n\nclass tinyBLAS_PPC {\n  public:\n    tinyBLAS_PPC(int64_t k,\n                const float * A, int64_t lda,\n                const float * B, int64_t ldb,\n                float * C, int64_t ldc,\n                int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(int64_t m, int64_t n) {\n        int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;\n        if (m % mc == 0 && n % nc == 0 && k % kc == 0) {\n            matmul_tiled(m, n, mc, nc, kc);\n        } else {\n            mnpack(0, m, 0, n);\n        }\n    }\n\n  private:\n\n    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {\n        vec_t vec_C[4];\n        __builtin_mma_disassemble_acc(vec_C, ACC);\n        for (int I = 0; I < 4; I++) {\n            for (int J = 0; J < 4; J++) {\n                *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);\n            }\n        }\n    }\n\n    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {\n        vec_t vec_C[4];\n        __builtin_mma_disassemble_acc(vec_C, ACC);\n        for (int I = 0; I < 4; I++) {\n            for (int J = 0; J < 4; J++) {\n                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);\n                *c_ptr += *((float *)&vec_C[I]+J);\n            }\n        }\n    }\n\n    inline void vector_permute_store_4(vector float * src, float * vecOffset) {\n        vector float t1, t2, t3, t4, t5, t6, t7, t8;\n        t1 = vec_mergeh(src[0], src[1]);\n        t2 = vec_mergeh(src[2], src[3]);\n        t3 = vec_mergel(src[0], src[1]);\n        t4 = vec_mergel(src[2], 
src[3]);\n\n        t5 = vec_xxpermdi(t1, t2, 0);\n        t6 = vec_xxpermdi(t1, t2, 3);\n        t7 = vec_xxpermdi(t3, t4, 0);\n        t8 = vec_xxpermdi(t3, t4, 3);\n\n        vec_xst(t5, 0, vecOffset);\n        vec_xst(t6, 0, vecOffset + 4);\n        vec_xst(t7, 0, vecOffset + 8);\n        vec_xst(t8, 0, vecOffset + 12);\n    }\n\n    inline void vector_permute_store_8(vector float * src, float * vecOffset) {\n        vector float t1, t2, t3, t4, t5, t6, t7, t8;\n        t1 = vec_mergeh(src[0], src[1]);\n        t2 = vec_mergeh(src[2], src[3]);\n        t3 = vec_mergeh(src[4], src[5]);\n        t4 = vec_mergeh(src[6], src[7]);\n\n        t5 = vec_xxpermdi(t1, t2, 0);\n        t6 = vec_xxpermdi(t3, t4, 0);\n        t7 = vec_xxpermdi(t1, t2, 3);\n        t8 = vec_xxpermdi(t3, t4, 3);\n\n        vec_xst(t5, 0, vecOffset);\n        vec_xst(t6, 0, vecOffset + 4);\n        vec_xst(t7, 0, vecOffset + 8);\n        vec_xst(t8, 0, vecOffset + 12);\n\n        t1 = vec_mergel(src[0], src[1]);\n        t2 = vec_mergel(src[2], src[3]);\n        t3 = vec_mergel(src[4], src[5]);\n        t4 = vec_mergel(src[6], src[7]);\n\n        t5 = vec_xxpermdi(t1, t2, 0);\n        t6 = vec_xxpermdi(t3, t4, 0);\n        t7 = vec_xxpermdi(t1, t2, 3);\n        t8 = vec_xxpermdi(t3, t4, 3);\n\n        vec_xst(t5, 0, vecOffset + 16);\n        vec_xst(t6, 0, vecOffset + 20);\n        vec_xst(t7, 0, vecOffset + 24);\n        vec_xst(t8, 0, vecOffset + 28);\n    }\n\n    void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {\n        int64_t i, j;\n        float * aoffsets[8];\n        float * aoffset = NULL, * boffset = NULL;\n        __vector_pair arr[8];\n        vector float c[8][2] = {0};\n        vector float c1[8] = {0};\n        vector float c2[8] = {0};\n        aoffset = const_cast<float *>(a);\n        boffset = vec;\n        j = (rows >> 3);\n        if (j > 0) {\n            do {\n                aoffsets[0] = aoffset;\n                for (int it = 1; it 
< 8; it++)\n                    aoffsets[it] = aoffsets[it-1] + lda;\n                aoffset += 8 * lda;\n                i = (cols >> 3);\n                if (i > 0) {\n                    do {\n                        for (int it = 0; it < 8; it++) {\n                            arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);\n                            __builtin_vsx_disassemble_pair(c[it], &arr[it]);\n                            c1[it] = c[it][0];\n                            c2[it] = c[it][1];\n                        }\n\n                        vector_permute_store_8(c1, boffset);\n                        vector_permute_store_8(c2, boffset + 32);\n                        boffset += 64;\n                        i--;\n                        if (i > 0) {\n                           for (int it = 0; it < 8; it++) {\n                               aoffsets[it] = aoffsets[it] + 8;\n                           }\n                        }\n                    } while(i > 0);\n                }\n                if (cols & 4) {\n                    for (int it = 0; it < 8 ; it++)\n                        c1[it] = vec_xl(0, aoffsets[it]);\n                    vector_permute_store_8(c1, boffset);\n                }\n            j--;\n            } while(j > 0);\n        }\n\n        if (rows & 4) {\n            aoffsets[0] = aoffset;\n            for (int it = 1; it < 4; it++)\n                aoffsets[it] = aoffsets[it-1] + lda;\n            aoffset += 4 * lda;\n            i = (cols >> 3);\n            if (i > 0) {\n                do {\n                    for (int it = 0; it < 4; it++) {\n                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);\n                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);\n                        c1[it] = c[it][0];\n                        c2[it] = c[it][1];\n                    }\n                    vector_permute_store_4(c1, boffset);\n                    
vector_permute_store_4(c2, boffset + 16);\n                    for (int it = 0; it < 4; it++)\n                        aoffsets[it] += 8 * lda;\n                    boffset += 32;\n                    i--;\n                } while(i > 0);\n            }\n\n            if (cols & 4) {\n               for (int it = 0; it < 4; it++)\n                   c1[it] = vec_xl(0, aoffsets[it]);\n                vector_permute_store_4(c1, boffset);\n            }\n        }\n        if (rows & 3) {\n            aoffsets[0] = aoffset;\n            for (int it = 1; it < 3; it++)\n                aoffsets[it] = aoffsets[it-1] + lda;\n            if (cols & 4) {\n                for (int it = 0; it < 3; it++)\n                    c1[it] = vec_xl(0, aoffsets[it]);\n                vector_permute_store_4(c1, boffset);\n            }\n        }\n    }\n\n    void KERNEL_4x4(int64_t ii, int64_t jj) {\n        vec_t vec_A[4], vec_B[4], vec_C[4];\n        acc_t acc_0;\n        __builtin_mma_xxsetaccz(&acc_0);\n        for (int l = 0; l < k; l += 4) {\n            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);\n            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);\n            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);\n            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);\n            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);\n            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);\n        }\n        save_acc(&acc_0, ii, jj);\n    }\n\n    void KERNEL_4x8(int64_t ii, int64_t jj) {\n        vec_t vec_A[4], vec_B[8], vec_C[4];\n        acc_t acc_0, acc_1;\n        __builtin_mma_xxsetaccz(&acc_0);\n        __builtin_mma_xxsetaccz(&acc_1);\n        for (int64_t l = 0; l < k; l += 4) {\n            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);\n            packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);\n            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);\n       
     __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);\n            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);\n            __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);\n            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);\n            __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);\n            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);\n            __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);\n        }\n        save_acc(&acc_0, ii, jj);\n        save_acc(&acc_1, ii, jj + 4);\n    }\n\n    void KERNEL_8x4(int64_t ii, int64_t jj) {\n        vec_t vec_A[8], vec_B[4], vec_C[4];\n        acc_t acc_0, acc_1;\n        __builtin_mma_xxsetaccz(&acc_0);\n        __builtin_mma_xxsetaccz(&acc_1);\n        for (int64_t l = 0; l < k; l += 4) {\n            packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);\n            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);\n            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);\n            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);\n            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);\n            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);\n            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);\n            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);\n            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);\n            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);\n        }\n        save_acc(&acc_0, ii, jj);\n        save_acc(&acc_1, ii + 4, jj);\n    }\n\n    void KERNEL_8x8(int64_t ii, int64_t jj) {\n        vec_t vec_A[16], vec_B[16], vec_C[4];\n        acc_t acc_0, acc_1, acc_2, acc_3;\n        __builtin_mma_xxsetaccz(&acc_0);\n        __builtin_mma_xxsetaccz(&acc_1);\n        __builtin_mma_xxsetaccz(&acc_2);\n        __builtin_mma_xxsetaccz(&acc_3);\n   
     for (int l = 0; l < k; l+=8) {\n            packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);\n            packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);\n            for(int x = 0; x < 16; x+=2) {\n                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);\n                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);\n                __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);\n                __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);\n            }\n        }\n        save_acc(&acc_0, ii, jj);\n        save_acc(&acc_1, ii, jj + 4);\n        save_acc(&acc_2, ii + 4, jj);\n        save_acc(&acc_3, ii + 4, jj + 4);\n    }\n\n    inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {\n        for (int x = 0; x < 16; x += 2) {\n            __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);\n            __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);\n            __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);\n            __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);\n            __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);\n            __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);\n            __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);\n            __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);\n        }\n    }\n\n    void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {\n        for (int64_t i = 0; i < mc; i += 16) {\n            int A_base_addr = (mc / 8) * (i / 8) * 16;\n            for (int64_t j = 0; j < nc; j += 8) {\n                 int B_base_addr = (nc / 8) * (j / 8) * 16;\n                 acc_t acc[8];\n                 vec_t A0_block[16]; vec_t A1_block[16];\n                 for (int x = 0; x < 8; x++)\n                     
__builtin_mma_xxsetaccz(&acc[x]);\n                 for (int64_t l = 0; l < kc; l += 8) {\n                     int A0_block_idx = A_base_addr + (l / 8) * 16;\n                     int A1_block_idx = A0_block_idx + (mc / 8) * 16;\n                     int B_block_idx = B_base_addr + (l / 8) * 16;\n                     vec_t* A0_block = &vec_A[A0_block_idx];\n                     vec_t* A1_block = &vec_A[A1_block_idx];\n                     vec_t* B_block = &vec_B[B_block_idx];\n                     MMA_16x8(A0_block, A1_block, B_block, acc);\n                 }\n                 if (kk == 0) {\n                     save_acc(&acc[0], ii + i, jj + j);\n                     save_acc(&acc[1], ii + i, jj + j + 4);\n                     save_acc(&acc[2], ii + i + 4, jj + j);\n                     save_acc(&acc[3], ii + i + 4, jj + j + 4);\n                     save_acc(&acc[4], ii + i + 8, jj + j);\n                     save_acc(&acc[5], ii + i + 8, jj + j + 4);\n                     save_acc(&acc[6], ii + i + 12, jj + j);\n                     save_acc(&acc[7], ii + i + 12, jj + j + 4);\n                 } else {\n                     add_save_acc(&acc[0], ii + i, jj + j);\n                     add_save_acc(&acc[1], ii + i, jj + j + 4);\n                     add_save_acc(&acc[2], ii + i + 4, jj + j);\n                     add_save_acc(&acc[3], ii + i + 4, jj + j + 4);\n                     add_save_acc(&acc[4], ii + i + 8, jj + j);\n                     add_save_acc(&acc[5], ii + i + 8, jj + j + 4);\n                     add_save_acc(&acc[6], ii + i + 12, jj + j);\n                     add_save_acc(&acc[7], ii + i + 12, jj + j + 4);\n                 }\n            }\n        }\n    }\n\n    void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {\n        int64_t ytiles = m / mc;\n        int64_t xtiles = n / nc;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        
int64_t end = start + duty;\n        if (end > tiles) {\n            end = tiles;\n        }\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = (job / xtiles) * mc;\n            int64_t jj = (job % xtiles) * nc;\n            for (int64_t kk = 0; kk < k; kk += kc) {\n                 vec_t A_pack[kc * mc / 4];\n                 vec_t B_pack[kc * nc / 4];\n                 packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);\n                 packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);\n                 KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);\n            }\n        }\n    }\n\n    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int m_rem = MIN(m - m0, 8);\n        int n_rem = MIN(n - n0, 8);\n        int mc = 0, nc = 0;\n        if (m_rem >= 8 && n_rem >= 8) {\n            mc = 8;\n            nc = 8;\n            gemm<8, 8>(m0, m, n0, n);\n        } else if (m_rem >= 4 && n_rem >= 8) {\n            mc = 4;\n            nc = 8;\n            gemm<4, 8>(m0, m, n0, n);\n        } else if (m_rem >= 8 && n_rem >= 4) {\n            mc = 8;\n            nc = 4;\n            gemm<8, 4>(m0, m, n0, n);\n        } else if (m_rem >= 4 && n_rem >= 4) {\n            mc = 4;\n            nc = 4;\n            gemm<4, 4>(m0, m, n0, n);\n        } else {\n            mc = (m_rem >= 4) ? 4 : m_rem;\n            nc = (n_rem >= 4) ? 
4 : n_rem;\n            if (mc == 0 || nc == 0)\n                return;\n            gemm_small(m0, m, n0, n, mc, nc);\n        }\n        int64_t mp = m0 + ((m - m0) / mc) * mc;\n        int64_t np = n0 + ((n - n0) / nc) * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            vec_t vec_C[4];\n            acc_t acc_0;\n            __builtin_mma_xxsetaccz(&acc_0);\n            vec_t vec_A[4] = {0}, vec_B[4] = {0};\n            for (int l = 0; l < k; l += 4) {\n                /* 'GEMV Forwarding' concept is used in first two conditional loops.\n                 * when one of the matrix has a single row/column, the elements are\n                 * broadcasted, instead of using packing routine to prepack the\n                 * matrix elements.\n                 */\n                if (RM == 1) {\n                    float * a = const_cast<float *>(A + (ii) * lda + l);\n                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);\n                    vec_A[0] = (vec_t)vec_xl(0,a);\n                    vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));\n                    vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));\n                    vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));\n                } else if (RN == 1) {\n                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);\n                    float * b = const_cast<float *>(B + (jj) * ldb 
+ l);\n                    vec_B[0] = (vec_t)vec_xl(0,b);\n                    vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));\n                    vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));\n                    vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));\n                } else {\n                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);\n                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);\n                }\n                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);\n                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);\n                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);\n                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);\n            }\n            __builtin_mma_disassemble_acc(vec_C, &acc_0);\n            for (int I = 0; I < RM; I++) {\n                for (int J = 0; J < RN; J++) {\n                    *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);\n                }\n            }\n       }\n    }\n\n    template<int RM, int RN>\n    inline void kernel(int64_t ii, int64_t jj) {\n        if constexpr(RM == 4 && RN == 4) {\n            KERNEL_4x4(ii, jj);\n        } else if constexpr(RM == 4 && RN == 8) {\n            KERNEL_4x8(ii, jj);\n        } else if constexpr(RM == 8 && RN == 4) {\n            KERNEL_8x4(ii, jj);\n        } else if constexpr(RM == 8 && RN == 8) {\n            KERNEL_8x8(ii, jj);\n        } else {\n            static_assert(false, \"RN/RM values not supported\");\n        }\n    }\n\n    template <int RM, int RN>\n    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {\n        int64_t ytiles = (m - m0) / RM;\n        int64_t xtiles = (n - n0) / RN;\n        int64_t tiles = xtiles * ytiles;\n        int64_t duty = (tiles + nth - 1) / nth;\n        int64_t start = duty * ith;\n        int64_t end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        
for (int64_t job = start; job < end; ++job) {\n            int64_t ii = m0 + job / xtiles * RM;\n            int64_t jj = n0 + job % xtiles * RN;\n            kernel<RM, RN>(ii, jj);\n        }\n    }\n\n    const float * const A;\n    const float * const B;\n    float * C;\n    const int64_t k;\n    const int64_t lda;\n    const int64_t ldb;\n    const int64_t ldc;\n    const int ith;\n    const int nth;\n};\n#endif\n} // namespace\n\n/**\n * Performs optimized matrix multiplication on CPU.\n *\n * This subroutine may compute C = Aᵀ * B with column major ordering.\n * Despite its name, this isn't a generalized implementation. Work is\n * only performed when a handwritten kernel is written and available.\n * Otherwise the caller should fall back to a general matmul routine.\n *\n * For example, for single-threaded single-precision GEMM you can say\n *\n *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,\n *                     0, 1,\n *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);\n *\n * @param m is rows in `A` and `C`\n * @param n is cols in `B` and `C`\n * @param k is cols in `A` and rows in `B`\n * @param A is first input matrix (always transposed)\n * @param lda is row stride of `A`\n * @param B is second input matrix (never transposed)\n * @param ldb is row stride of `B`\n * @param C is input/output array of output matrices\n * @param ldc is row stride of `C`\n * @param ith is thread id (must be less than `nth`)\n * @param nth is number of threads (must be greater than zero)\n * @param Atype is GGML data type of `A`\n * @param Btype is GGML data type of `B`\n * @param Ctype is GGML data type of `C`\n * @return true if this function was able to service the matmul request\n */\nbool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,\n                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,\n                     int64_t ldc, int Atype, int Btype, int Ctype) {\n\n    
assert(m >= 0);\n    assert(n >= 0);\n    assert(k >= 0);\n    assert(lda >= k);\n    assert(ldb >= k);\n    assert(ldc >= m);\n    assert(params->nth > 0);\n    assert(params->ith < params->nth);\n\n    // only enable sgemm for prompt processing\n#if !defined(__MMA__)\n    if (n < 2)\n        return false;\n#endif\n\n    if (Ctype != GGML_TYPE_F32)\n        return false;\n\n    switch (Atype) {\n\n    case GGML_TYPE_F32: {\n        if (Btype != GGML_TYPE_F32)\n            return false;\n#if defined(__AVX512F__)\n        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,\n            k, (const float *)A, lda,\n            (const float *)B, ldb,\n            (float *)C, ldc};\n        return tb.matmul(m, n);\n#elif defined(__AVX__) || defined(__AVX2__)\n        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,\n            k, (const float *)A, lda,\n            (const float *)B, ldb,\n            (float *)C, ldc};\n        return tb.matmul(m, n);\n#elif defined(__ARM_NEON)\n        if (n < 4)\n            return false;\n        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,\n            k, (const float *)A, lda,\n            (const float *)B, ldb,\n            (float *)C, ldc};\n        return tb.matmul(m, n);\n#elif defined(__VXE__) || defined(__VXE2__)\n        if (n < 4)\n            return false;\n        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,\n            k, (const float *)A, lda,\n            (const float *)B, ldb,\n            (float *)C, ldc};\n        return tb.matmul(m, n);\n#elif defined(__MMA__)\n        if (k % 8)\n            return false;\n        tinyBLAS_PPC tb{\n            k, (const float *)A, lda,\n            (const float *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#elif defined(__riscv_zvfh)\n    #if LMUL == 1\n        tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ 
params,\n            k, (const float *)A, lda,\n            (const float *)B, ldb,\n            (float *)C, ldc};\n    #elif LMUL == 2\n        tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,\n            k, (const float *)A, lda,\n            (const float *)B, ldb,\n            (float *)C, ldc};\n    #else // LMUL = 4\n        tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,\n            k, (const float *)A, lda,\n            (const float *)B, ldb,\n            (float *)C, ldc};\n    #endif\n        return tb.matmul(m, n);\n#else\n        return false;\n#endif\n    }\n\n    case GGML_TYPE_BF16: {\n#if defined(__AVX512BF16__)\n        if (Btype == GGML_TYPE_BF16) {\n            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,\n                (const ggml_bf16_t *)A, lda,\n                (const ggml_bf16_t *)B, ldb,\n                (float *)C, ldc};\n            return tb.matmul(m, n);\n        }\n#elif defined(__AVX512F__)\n        if (Btype == GGML_TYPE_BF16) {\n            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,\n                (const ggml_bf16_t *)A, lda,\n                (const ggml_bf16_t *)B, ldb,\n                (float *)C, ldc};\n            return tb.matmul(m, n);\n        }\n#elif defined(__AVX2__)\n        if (Btype == GGML_TYPE_BF16) {\n            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,\n                (const ggml_bf16_t *)A, lda,\n                (const ggml_bf16_t *)B, ldb,\n                (float *)C, ldc};\n            return tb.matmul(m, n);\n        }\n#elif defined(__MMA__)\n        if (k % 8) {\n            return false;\n        }\n\n        if (Btype == GGML_TYPE_BF16) {\n            tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,\n                (const ggml_bf16_t *)A, lda,\n                (const ggml_bf16_t *)B, ldb,\n                (float *)C, ldc,\n                
params->ith, params->nth };\n\n            tb.matmul(m, n);\n            return true;\n        }\n#elif defined(__riscv_zvfbfwma)\n        #if LMUL == 1\n            tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,\n                k, (const ggml_bf16_t *)A, lda,\n                (const ggml_bf16_t *)B, ldb,\n                (float *)C, ldc};\n        #elif LMUL == 2\n            tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,\n                k, (const ggml_bf16_t *)A, lda,\n                (const ggml_bf16_t *)B, ldb,\n                (float *)C, ldc};\n        #else // LMUL = 4\n            tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,\n                k, (const ggml_bf16_t *)A, lda,\n                (const ggml_bf16_t *)B, ldb,\n                (float *)C, ldc};\n        #endif\n            return tb.matmul(m, n);\n#endif\n        return false;\n    }\n\n    case GGML_TYPE_F16: {\n#if defined(__AVX512F__)\n        if (Btype == GGML_TYPE_F16) {\n            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,\n                (const ggml_fp16_t *)A, lda,\n                (const ggml_fp16_t *)B, ldb,\n                (float *)C, ldc};\n            return tb.matmul(m, n);\n        }\n#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)\n        if (Btype == GGML_TYPE_F16) {\n            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,\n                (const ggml_fp16_t *)A, lda,\n                (const ggml_fp16_t *)B, ldb,\n                (float *)C, ldc};\n            return tb.matmul(m, n);\n        }\n#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\n        if (n < 8)\n            return false;\n        if (Btype == GGML_TYPE_F16) {\n            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,\n                k, 
(const ggml_fp16_t *)A, lda,\n                (const ggml_fp16_t *)B, ldb,\n                (float *)C, ldc};\n            return tb.matmul(m, n);\n        }\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n        if (Btype == GGML_TYPE_F32) {\n            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,\n                k, (const ggml_fp16_t *)A, lda,\n                (const float *)B, ldb,\n                (float *)C, ldc};\n            return tb.matmul(m, n);\n        }\n#elif defined(__VXE__) || defined(__VXE2__)\n        if (n < 4)\n            return false;\n        if (Btype == GGML_TYPE_F16) {\n            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,\n                k, (const ggml_fp16_t *)A, lda,\n                (const ggml_fp16_t *)B, ldb,\n                (float *)C, ldc};\n            return tb.matmul(m, n);\n        }\n#elif defined(__riscv_zvfh)\n        if (Btype == GGML_TYPE_F16) {\n        #if LMUL == 1\n            tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,\n                k, (const ggml_fp16_t *)A, lda,\n                (const ggml_fp16_t *)B, ldb,\n                (float *)C, ldc};\n        #elif LMUL == 2\n            tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,\n                k, (const ggml_fp16_t *)A, lda,\n                (const ggml_fp16_t *)B, ldb,\n                (float *)C, ldc};\n        #else // LMUL = 4\n            tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,\n                k, (const ggml_fp16_t *)A, lda,\n                (const ggml_fp16_t *)B, ldb,\n                (float *)C, ldc};\n        #endif\n            return tb.matmul(m, n);\n        }\n#elif defined(__MMA__)\n        if (k % 8) {\n            return false;\n        }\n\n        if (Btype == GGML_TYPE_F16) {\n            tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> 
tb{ k,\n                (const ggml_fp16_t *)A, lda,\n                (const ggml_fp16_t *)B, ldb,\n                (float *)C, ldc,\n                params->ith, params->nth };\n\n            tb.matmul(m, n);\n            return true;\n        }\n#endif\n        return false;\n    }\n\n    case GGML_TYPE_Q8_0: {\n        if (Btype != GGML_TYPE_Q8_0)\n           return false;\n#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)\n        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{\n            k, (const block_q8_0 *)A, lda,\n            (const block_q8_0 *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#elif defined(__ARM_FEATURE_DOTPROD)\n        tinyBLAS_Q0_ARM<block_q8_0> tb{\n            k, (const block_q8_0 *)A, lda,\n            (const block_q8_0 *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#elif defined(__MMA__)\n    //TO-DO: Remove this condition once gemv forwarding is enabled.\n        if (n < 8 && n != 4)\n           return false;\n        if (m < 8 && m != 4)\n           return false;\n        tinyBLAS_Q0_PPC<block_q8_0> tb{\n            k, (const block_q8_0 *)A, lda,\n            (const block_q8_0 *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#else\n        return false;\n#endif\n    }\n\n    case GGML_TYPE_Q4_0: {\n        if (Btype != GGML_TYPE_Q8_0)\n            return false;\n#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)\n        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{\n            k, (const block_q4_0 *)A, lda,\n            (const block_q8_0 *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#elif defined(__ARM_FEATURE_DOTPROD)\n        tinyBLAS_Q0_ARM<block_q4_0> tb{\n          
  k, (const block_q4_0 *)A, lda,\n            (const block_q8_0 *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#elif defined(__MMA__)\n    //TO-DO: Remove this condition once gemv forwarding is enabled.\n        if (n < 8 && n != 4)\n           return false;\n        if (m < 8 && m != 4)\n           return false;\n        tinyBLAS_Q0_PPC<block_q4_0> tb{\n            k, (const block_q4_0 *)A, lda,\n            (const block_q8_0 *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#else\n        return false;\n#endif\n    }\n\n    case GGML_TYPE_Q5_0: {\n        if (Btype != GGML_TYPE_Q8_0)\n            return false;\n#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)\n        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{\n            k, (const block_q5_0 *)A, lda,\n            (const block_q8_0 *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#else\n        return false;\n#endif\n    }\n\n    case GGML_TYPE_IQ4_NL: {\n        if (Btype != GGML_TYPE_Q8_0)\n            return false;\n#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)\n        tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{\n            k, (const block_iq4_nl *)A, lda,\n            (const block_q8_0 *)B, ldb,\n            (float *)C, ldc,\n            params->ith, params->nth};\n        tb.matmul(m, n);\n        return true;\n#else\n        return false;\n#endif\n    }\n\n    default:\n        return false;\n    }\n\n    (void)params;\n    (void)m;\n    (void)n;\n    (void)k;\n    (void)A;\n    (void)lda;\n    (void)B;\n    (void)ldb;\n    (void)C;\n    (void)ldc;\n    (void)Atype;\n    (void)Btype;\n    (void)Ctype;\n}\n"
  },
  {
    "path": "src/ggml-cpu/llamafile/sgemm.h",
    "content": "#pragma once\n#include <stdint.h>\n#include <stdbool.h>\n\n#if defined(__VXE__) || defined(__VXE2__)\n#include <vecintrin.h>\n#endif\n\n#ifdef _MSC_VER\n#define NOINLINE __declspec(noinline)\n#else\n#define NOINLINE __attribute__((__noinline__))\n#endif\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nbool llamafile_sgemm(const struct ggml_compute_params * params, int64_t, int64_t, int64_t,\n                     const void *, int64_t, const void *, int64_t, void *, int64_t,\n                     int, int, int);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/ops.cpp",
    "content": "#include \"ops.h\"\n\n#include \"ggml-cpu.h\"\n#include \"ggml-impl.h\"\n#include \"binary-ops.h\"\n#include \"simd-gemm.h\"\n#include \"ggml.h\"\n#include \"unary-ops.h\"\n#include \"vec.h\"\n\n#include <algorithm>\n#include <cfloat>\n#include <cmath>\n\n// ggml_compute_forward_dup\n\nstatic void ggml_compute_forward_dup_same_cont(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));\n    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));\n    GGML_ASSERT(src0->type == dst->type);\n\n    const size_t nb0 = ggml_type_size(src0->type);\n\n    const int ith = params->ith; // thread index\n    const int nth = params->nth; // number of threads\n\n    // parallelize by blocks\n    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);\n    const int dr = (nk + nth - 1) / nth;\n    const int k0 = dr * ith;\n    const int k1 = MIN(k0 + dr, nk);\n\n    if (k0 < k1) {\n        memcpy(\n            ((char *)  dst->data + k0*nb0),\n            ((char *) src0->data + k0*nb0),\n            (k1 - k0) * nb0);\n    }\n}\n\ntemplate<typename src_t, typename dst_t>\nstatic void ggml_compute_forward_dup_flt(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));\n    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    const int ith = params->ith; // thread index\n    const int nth = params->nth; // number of threads\n\n    // parallelize by rows\n    const int nr = ne01;\n    // number of rows per thread\n    const int dr = (nr + nth - 1) / nth;\n    // row range for this thread\n    const int ir0 = dr * ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    // case: type & row size equal\n    if (src0->type == 
dst->type &&\n        ne00 == ne0 &&\n        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {\n        // copy by rows\n        const size_t rs = ne00*nb00;\n        for (int64_t i03 = 0; i03 < ne03; i03++) {\n            for (int64_t i02 = 0; i02 < ne02; i02++) {\n                for (int64_t i01 = ir0; i01 < ir1; i01++) {\n                    memcpy(\n                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),\n                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),\n                        rs);\n                }\n            }\n        }\n        return;\n    }\n\n    // case: dst tensor is contiguous\n    if (ggml_is_contiguous(dst)) {\n        if (nb00 == sizeof(src_t)) {\n            if constexpr (std::is_same_v<dst_t, src_t>) {\n                // same type\n                size_t id = 0;\n                const size_t rs = ne00 * nb00;\n                char * dst_ptr = (char *) dst->data;\n\n                for (int i03 = 0; i03 < ne03; i03++) {\n                    for (int i02 = 0; i02 < ne02; i02++) {\n                        id += rs * ir0;\n                        for (int i01 = ir0; i01 < ir1; i01++) {\n                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;\n                            memcpy(dst_ptr + id, src0_ptr, rs);\n                            id += rs;\n                        }\n                        id += rs * (ne01 - ir1);\n                    }\n                }\n            } else {\n                // casting between non-quantized types\n                size_t id = 0;\n                dst_t * dst_ptr = (dst_t *) dst->data;\n\n                for (int i03 = 0; i03 < ne03; i03++) {\n                    for (int i02 = 0; i02 < ne02; i02++) {\n                        id += ne00 * ir0;\n                        for (int i01 = ir0; i01 < ir1; i01++) {\n                            const src_t * src0_ptr = (src_t *) ((char *) 
src0->data + i01*nb01 + i02*nb02 + i03*nb03);\n                            for (int i00 = 0; i00 < ne00; i00++) {\n                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);\n                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);\n                                id++;\n                            }\n                        }\n                        id += ne00 * (ne01 - ir1);\n                    }\n                }\n            }\n        } else {\n            //printf(\"%s: this is not optimal - fix me\\n\", __func__);\n\n            size_t id = 0;\n            dst_t * dst_ptr = (dst_t *) dst->data;\n\n            for (int i03 = 0; i03 < ne03; i03++) {\n                for (int i02 = 0; i02 < ne02; i02++) {\n                    id += ne00 * ir0;\n                    for (int i01 = ir0; i01 < ir1; i01++) {\n                        for (int i00 = 0; i00 < ne00; i00++) {\n                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);\n\n                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);\n                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);\n                            id++;\n                        }\n                    }\n                    id += ne00 * (ne01 - ir1);\n                }\n            }\n        }\n        return;\n    }\n\n    // dst counters\n    int64_t i10 = 0;\n    int64_t i11 = 0;\n    int64_t i12 = 0;\n    int64_t i13 = 0;\n\n    if constexpr (std::is_same_v<dst_t, src_t>) {\n        for (int64_t i03 = 0; i03 < ne03; i03++) {\n            for (int64_t i02 = 0; i02 < ne02; i02++) {\n                i10 += ne00 * ir0;\n                while (i10 >= ne0) {\n                    i10 -= ne0;\n                    if (++i11 == ne1) {\n                        i11 = 0;\n                        if (++i12 == ne2) {\n                            
i12 = 0;\n                            if (++i13 == ne3) {\n                                i13 = 0;\n                            }\n                        }\n                    }\n                }\n                for (int64_t i01 = ir0; i01 < ir1; i01++) {\n                    for (int64_t i00 = 0; i00 < ne00; i00++) {\n                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);\n                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);\n\n                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));\n\n                        if (++i10 == ne00) {\n                            i10 = 0;\n                            if (++i11 == ne01) {\n                                i11 = 0;\n                                if (++i12 == ne02) {\n                                    i12 = 0;\n                                    if (++i13 == ne03) {\n                                        i13 = 0;\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n                i10 += ne00 * (ne01 - ir1);\n                while (i10 >= ne0) {\n                    i10 -= ne0;\n                    if (++i11 == ne1) {\n                        i11 = 0;\n                        if (++i12 == ne2) {\n                            i12 = 0;\n                            if (++i13 == ne3) {\n                                i13 = 0;\n                            }\n                        }\n                    }\n                }\n            }\n        }\n\n    } else {\n        for (int64_t i03 = 0; i03 < ne03; i03++) {\n            for (int64_t i02 = 0; i02 < ne02; i02++) {\n                i10 += ne00 * ir0;\n                while (i10 >= ne0) {\n                    i10 -= ne0;\n                    if (++i11 == ne1) {\n                        i11 = 0;\n                 
       if (++i12 == ne2) {\n                            i12 = 0;\n                            if (++i13 == ne3) {\n                                i13 = 0;\n                            }\n                        }\n                    }\n                }\n                for (int64_t i01 = ir0; i01 < ir1; i01++) {\n                    for (int64_t i00 = 0; i00 < ne00; i00++) {\n                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);\n                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);\n\n                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);\n                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);\n\n                        if (++i10 == ne0) {\n                            i10 = 0;\n                            if (++i11 == ne1) {\n                                i11 = 0;\n                                if (++i12 == ne2) {\n                                    i12 = 0;\n                                    if (++i13 == ne3) {\n                                        i13 = 0;\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n                i10 += ne00 * (ne01 - ir1);\n                while (i10 >= ne0) {\n                    i10 -= ne0;\n                    if (++i11 == ne1) {\n                        i11 = 0;\n                        if (++i12 == ne2) {\n                            i12 = 0;\n                            if (++i13 == ne3) {\n                                i13 = 0;\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\n\ntemplate<typename src_t>\nstatic void ggml_compute_forward_dup_to_q(\n        const ggml_compute_params * params,\n        
ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));\n    GGML_ASSERT(!ggml_is_quantized(src0->type));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    const int ith = params->ith; // thread index\n    const int nth = params->nth; // number of threads\n\n    // parallelize by rows\n    const int nr = ne01;\n    // number of rows per thread\n    const int dr = (nr + nth - 1) / nth;\n    // row range for this thread\n    const int ir0 = dr * ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    if (ggml_is_contiguous(dst) &&\n            nb00 == sizeof(src_t) &&\n            ggml_get_type_traits_cpu(dst->type)->from_float) {\n        // casting non-quantized types --> intermediate f32 --> quantized\n        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;\n        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;\n\n        size_t id = 0;\n        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));\n        char * dst_ptr = (char *) dst->data;\n\n        for (int i03 = 0; i03 < ne03; i03++) {\n            for (int i02 = 0; i02 < ne02; i02++) {\n                id += rs * ir0;\n                for (int i01 = ir0; i01 < ir1; i01++) {\n                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);\n\n                    for (int i00 = 0; i00 < ne00; i00++) {\n                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);\n                    }\n\n                    quantize_row_q(src0_f32, dst_ptr + id, ne00);\n                    id += rs;\n                }\n                id += rs * (ne01 - ir1);\n            }\n        }\n    } else {\n        // printf(\"%s %s\\n\", ggml_type_name(src0->type), ggml_type_name(dst->type));\n        GGML_ABORT(\"not implemented\");\n    }\n}\n\n// A simplified version of ggml_compute_forward_dup that doesn't do 
float upcasting, and just plain old memcpy.\nstatic void ggml_compute_forward_dup_bytes(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));\n    GGML_ASSERT(src0->type == dst->type);\n\n    GGML_TENSOR_UNARY_OP_LOCALS;\n\n    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {\n        ggml_compute_forward_dup_same_cont(params, dst);\n        return;\n    }\n\n    const size_t type_size = ggml_type_size(src0->type);\n\n    const int ith = params->ith; // thread index\n    const int nth = params->nth; // number of threads\n\n    // parallelize by rows\n    const int nr = ne01;\n    // number of rows per thread\n    const int dr = (nr + nth - 1) / nth;\n    // row range for this thread\n    const int ir0 = dr * ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    if (src0->type == dst->type &&\n        ggml_are_same_shape(src0, dst) &&\n        nb00 == type_size && nb0 == type_size) {\n        // copy by rows\n        const size_t rs = ggml_row_size(src0->type, ne00);\n        for (int64_t i03 = 0; i03 < ne03; i03++) {\n            for (int64_t i02 = 0; i02 < ne02; i02++) {\n                for (int64_t i01 = ir0; i01 < ir1; i01++) {\n                    memcpy(\n                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),\n                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),\n                        rs);\n                }\n            }\n        }\n        return;\n    }\n\n    if (ggml_is_contiguous(dst)) {\n        size_t id = 0;\n        char * dst_ptr = (char *) dst->data;\n        const size_t rs = ne00 * type_size;\n\n        if (nb00 == type_size) {\n            // src0 is contiguous on first dimension, copy by rows\n            for (int64_t i03 = 0; i03 < ne03; i03++) {\n                for (int64_t i02 = 0; i02 < ne02; i02++) {\n                    id += rs * ir0;\n          
          for (int64_t i01 = ir0; i01 < ir1; i01++) {\n                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;\n                        memcpy(dst_ptr + id, src0_ptr, rs);\n                        id += rs;\n                    }\n                    id += rs * (ne01 - ir1);\n                }\n            }\n        } else {\n            //printf(\"%s: this is not optimal - fix me\\n\", __func__);\n\n            for (int64_t i03 = 0; i03 < ne03; i03++) {\n                for (int64_t i02 = 0; i02 < ne02; i02++) {\n                    id += rs * ir0;\n                    for (int64_t i01 = ir0; i01 < ir1; i01++) {\n                        for (int64_t i00 = 0; i00 < ne00; i00++) {\n                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;\n                            memcpy(dst_ptr + id, src0_ptr, type_size);\n\n                            id += type_size;\n                        }\n                    }\n                    id += rs * (ne01 - ir1);\n                }\n            }\n        }\n\n        return;\n    }\n\n    // dst counters\n    int64_t k10 = 0;\n    int64_t i11 = 0;\n    int64_t i12 = 0;\n    int64_t i13 = 0;\n\n    // number of blocks in a row\n    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);\n    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);\n\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            k10 += nk00 * ir0;\n            while (k10 >= nk0) {\n                k10 -= nk0;\n                if (++i11 == ne1) {\n                    i11 = 0;\n                    if (++i12 == ne2) {\n                        i12 = 0;\n                        if (++i13 == ne3) {\n                            i13 = 0;\n                        }\n                    }\n                }\n            }\n            for (int64_t i01 = ir0; i01 < ir1; i01++) {\n                for 
(int64_t k00 = 0; k00 < nk00; k00++) {\n                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);\n                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);\n\n                    memcpy(dst_ptr, src0_ptr, type_size);\n\n                    if (++k10 == nk0) {\n                        k10 = 0;\n                        if (++i11 == ne1) {\n                            i11 = 0;\n                            if (++i12 == ne2) {\n                                i12 = 0;\n                                if (++i13 == ne3) {\n                                    i13 = 0;\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n            k10 += nk00 * (ne01 - ir1);\n            while (k10 >= nk0) {\n                k10 -= nk0;\n                if (++i11 == ne1) {\n                    i11 = 0;\n                    if (++i12 == ne2) {\n                        i12 = 0;\n                        if (++i13 == ne3) {\n                            i13 = 0;\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nstatic void ggml_compute_forward_dup_from_q(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const ggml_type type = src0->type;\n    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;\n\n    size_t qk = ggml_blck_size(type);\n    const int64_t nr = ggml_nelements(src1) / qk;\n\n    // destination must be contiguous in the first dimension\n    GGML_ASSERT(nb10 == ggml_type_size(dst->type));\n    // must either have first dimension large enough to hold a row, or fully contiguous\n    GGML_ASSERT((ne10 % qk) == 0 || 
ggml_is_contiguous(dst));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n\n        uint32_t i = ir * qk;\n\n        const int64_t i03 = i/(ne00 * ne01 * ne02);\n        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);\n        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;\n        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;\n        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;\n\n        const int64_t i13 = i/(ne10 * ne11 * ne12);\n        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);\n        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;\n        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;\n        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;\n\n        dequantize_row_q(\n                (const void *) ((char *) src0->data + x_offset),\n                     (float *) ((char *)  dst->data + dst_offset), qk);\n    }\n}\n\nvoid ggml_compute_forward_dup(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (src0->type == dst->type) {\n        ggml_compute_forward_dup_bytes(params, dst);\n        return;\n    }\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n            {\n                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);\n                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);\n                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);\n                else 
ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);\n                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);\n                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);\n                else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);\n                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);\n                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);\n                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);\n                else ggml_compute_forward_dup_to_q<float>(params, dst);\n            } break;\n        case GGML_TYPE_I32:\n            {\n                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);\n                else GGML_ABORT(\"not implemented\");\n            } break;\n        default:\n            {\n                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {\n                    ggml_compute_forward_dup_from_q(params, dst);\n                    break;\n                }\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_add\n\nstatic void ggml_compute_forward_add_q_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const ggml_type type = src0->type;\n    const ggml_type dtype = dst->type;\n    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;\n    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;\n\n    // we don't support permuted src0 or src1\n    GGML_ASSERT(nb00 == ggml_type_size(type));\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    GGML_ASSERT(ggml_is_quantized(src0->type));\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 indices\n        const int i03 = ir/(ne02*ne01);\n        const int i02 = (ir - i03*ne02*ne01)/ne01;\n        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);\n\n        // src1 and dst are same shape as src0 => same indices\n        const int i13 = i03;\n        const int i12 = i02;\n        const int i11 = i01;\n\n        const int i3 = i03;\n        const int i2 = i02;\n        const int i1 = i01;\n\n        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));\n        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));\n        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));\n\n        assert(ne00 % 32 == 0);\n\n        // unquantize row from src0 to temp buffer\n        dequantize_row_q(src0_row, wdata, 
ne00);\n        // add src1\n        ggml_vec_acc_f32(ne00, wdata, src1_row);\n        // quantize row to dst\n        if (quantize_row_q != NULL) {\n            quantize_row_q(wdata, dst_row, ne00);\n        } else {\n            memcpy(dst_row, wdata, ne0*nb0);\n        }\n    }\n}\n\nvoid ggml_compute_forward_add(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n            {\n                ggml_compute_forward_add_non_quantized(params, dst);\n            } break;\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_MXFP4:\n        case GGML_TYPE_NVFP4:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_TQ1_0:\n        case GGML_TYPE_TQ2_0:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ2_S:\n            {\n                ggml_compute_forward_add_q_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_add_id\n\nstatic void ggml_compute_forward_add_id_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == 
GGML_TYPE_F32);\n    GGML_ASSERT(src2->type == GGML_TYPE_I32);\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(src1->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_TERNARY_OP_LOCALS\n\n    GGML_ASSERT( nb0 == sizeof(float));\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 indices\n        const int i3 = ir/(ne2*ne1);\n        const int i2 = (ir - i3*ne2*ne1)/ne1;\n        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);\n\n        // src1 indices\n        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);\n\n        GGML_ASSERT(i11 >= 0 && i11 < ne11);\n\n        ggml_vec_add_f32(ne0,\n                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),\n                (float *) ((char *) src1->data + i11*nb11));\n    }\n}\n\nvoid ggml_compute_forward_add_id(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_add_id_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"unsupported type for ggml_compute_forward_add_id: %s\", ggml_type_name(src0->type));\n            }\n    }\n}\n\n// ggml_compute_forward_add1\n\nstatic void ggml_compute_forward_add1_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    
GGML_ASSERT(ggml_is_scalar(src1));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT( nb0 == sizeof(float));\n    GGML_ASSERT(nb00 == sizeof(float));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are same shape => same indices\n        const int i3 = ir/(ne2*ne1);\n        const int i2 = (ir - i3*ne2*ne1)/ne1;\n        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);\n\n#ifdef GGML_USE_ACCELERATE\n        GGML_UNUSED(ggml_vec_add1_f32);\n\n        vDSP_vadd(\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,\n                (float *) ((char *) src1->data), 0,\n                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,\n                ne0);\n#else\n        ggml_vec_add1_f32(ne0,\n                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),\n               *(float *) src1->data);\n#endif\n    }\n}\n\nstatic void ggml_compute_forward_add1_f16_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_scalar(src1));\n\n    // scalar to add\n    const float v = *(float *) src1->data;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F16);\n\n    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));\n    
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are same shape => same indices\n        const int i3 = ir/(ne2*ne1);\n        const int i2 = (ir - i3*ne2*ne1)/ne1;\n        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);\n\n        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );\n        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);\n        for (int i = 0; i < ne0; i++) {\n            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);\n        }\n    }\n}\n\nstatic void ggml_compute_forward_add1_f16_f16(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_scalar(src1));\n\n    // scalar to add\n    const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F16);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F16);\n\n    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));\n    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are same shape => same indices\n        const int i3 = ir/(ne2*ne1);\n        const int i2 = (ir - i3*ne2*ne1)/ne1;\n        const int i1 = (ir - i3*ne2*ne1 - 
i2*ne1);\n\n        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );\n        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);\n        for (int i = 0; i < ne0; i++) {\n            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);\n        }\n    }\n}\n\nstatic void ggml_compute_forward_add1_q_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_scalar(src1));\n\n    // scalar to add\n    const float v = *(float *) src1->data;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    const ggml_type type = src0->type;\n    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;\n    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;\n\n    // we don't support permuted src0\n    GGML_ASSERT(nb00 == ggml_type_size(type));\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    GGML_ASSERT(ggml_is_quantized(src0->type));\n    GGML_ASSERT(dst->type == src0->type);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are same shape => same indices\n        const int i3 = ir/(ne2*ne1);\n        const int i2 = (ir - i3*ne2*ne1)/ne1;\n        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);\n\n        void  * src0_row = (void 
*) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));\n        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb0 ));\n\n        assert(ne0 % 32 == 0);\n\n        // unquantize row from src0 to temp buffer\n        dequantize_row_q(src0_row, wdata, ne0);\n        // add src1\n        ggml_vec_acc1_f32(ne0, wdata, v);\n        // quantize row to dst\n        quantize_row_q(wdata, dst_row, ne0);\n    }\n}\n\nstatic void ggml_compute_forward_add1_bf16_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_scalar(src1));\n\n    // scalar to add\n    const float v = *(float *) src1->data;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(src0->type == GGML_TYPE_BF16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);\n\n    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));\n    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are same shape => same indices\n        const int i3 = ir/(ne2*ne1);\n        const int i2 = (ir - i3*ne2*ne1)/ne1;\n        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);\n\n        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );\n        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);\n        for (int i = 0; i < ne0; i++) {\n            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);\n        }\n    }\n}\n\nstatic void 
ggml_compute_forward_add1_bf16_bf16(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_scalar(src1));\n\n    // scalar to add\n    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(src0->type == GGML_TYPE_BF16);\n    GGML_ASSERT(src1->type == GGML_TYPE_BF16);\n    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);\n\n    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));\n    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are same shape => same indices\n        const int i3 = ir/(ne2*ne1);\n        const int i2 = (ir - i3*ne2*ne1)/ne1;\n        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);\n\n        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );\n        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);\n        for (int i = 0; i < ne0; i++) {\n            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);\n        }\n    }\n}\n\nvoid ggml_compute_forward_add1(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_add1_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                if (src1->type == GGML_TYPE_F16) {\n         
           ggml_compute_forward_add1_f16_f16(params, dst);\n                }\n                else if (src1->type == GGML_TYPE_F32) {\n                    ggml_compute_forward_add1_f16_f32(params, dst);\n                }\n                else {\n                    GGML_ABORT(\"fatal error\");\n                }\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                if (src1->type == GGML_TYPE_BF16) {\n                    ggml_compute_forward_add1_bf16_bf16(params, dst);\n                }\n                else if (src1->type == GGML_TYPE_F32) {\n                    ggml_compute_forward_add1_bf16_f32(params, dst);\n                }\n                else {\n                    GGML_ABORT(\"fatal error\");\n                }\n            } break;\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q8_1:\n        case GGML_TYPE_MXFP4:\n        case GGML_TYPE_NVFP4:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_TQ1_0:\n        case GGML_TYPE_TQ2_0:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ2_S:\n            {\n                ggml_compute_forward_add1_q_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_acc\n\nstatic void ggml_compute_forward_acc_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    
GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));\n\n    // view src0 and dst with these strides and data offset inbytes during acc\n    // nb0 is implicitly element_size because src0 and dst are contiguous\n    size_t nb1     = ((int32_t *) dst->op_params)[0];\n    size_t nb2     = ((int32_t *) dst->op_params)[1];\n    size_t nb3     = ((int32_t *) dst->op_params)[2];\n    size_t offset  = ((int32_t *) dst->op_params)[3];\n    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];\n\n    if (!inplace) {\n        if (params->ith == 0) {\n            // memcpy needs to be synchronized across threads to avoid race conditions.\n            // => do it in INIT phase\n            memcpy(\n                ((char *)  dst->data),\n                ((char *) src0->data),\n                ggml_nbytes(dst));\n        }\n        ggml_barrier(params->threadpool);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr = ggml_nrows(src1);\n    const int nc = src1->ne[0];\n\n    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)\n\n    // src0 and dst as viewed during acc\n    const size_t nb0 = ggml_element_size(src0);\n\n    const size_t nb00 = nb0;\n    const size_t nb01 = nb1;\n    const size_t nb02 = nb2;\n    const size_t nb03 = nb3;\n\n    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));\n    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 
0 : ne13-1)*nb03 < ggml_nbytes(src0));\n\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are viewed with shape of src1 and offset\n        // => same indices\n        const int i3 = ir/(ne12*ne11);\n        const int i2 = (ir - i3*ne12*ne11)/ne11;\n        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);\n\n#ifdef GGML_USE_ACCELERATE\n        vDSP_vadd(\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,\n                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,\n                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);\n#else\n        ggml_vec_add_f32(nc,\n                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),\n                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));\n#endif\n    }\n}\n\nvoid ggml_compute_forward_acc(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_acc_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q8_1:\n        case GGML_TYPE_MXFP4:\n        case GGML_TYPE_NVFP4:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_TQ1_0:\n        case GGML_TYPE_TQ2_0:\n   
     case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ2_S:\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_sum\n\nstatic void ggml_compute_forward_sum_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    assert(ggml_is_scalar(dst));\n    assert(src0->nb[0] == sizeof(float));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n\n    ggml_float sum     = 0;\n    ggml_float row_sum = 0;\n\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = 0; i01 < ne01; i01++) {\n                ggml_vec_sum_f32_ggf(ne00,\n                        &row_sum,\n                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));\n                sum += row_sum;\n            }\n        }\n    }\n    ((float *) dst->data)[0] = sum;\n}\n\nstatic void ggml_compute_forward_sum_f16(\n    const ggml_compute_params * params,\n          ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    assert(ggml_is_scalar(dst));\n\n    assert(src0->nb[0] == sizeof(ggml_fp16_t));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n\n    float sum = 0;\n    float row_sum = 0;\n\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = 0; i01 < ne01; i01++) {\n                ggml_vec_sum_f16_ggf(ne00,\n                    &row_sum,\n    
                (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));\n                sum += row_sum;\n            }\n        }\n    }\n    ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);\n}\n\nstatic void ggml_compute_forward_sum_bf16(\n    const ggml_compute_params * params,\n          ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    assert(ggml_is_scalar(dst));\n\n    assert(src0->nb[0] == sizeof(ggml_bf16_t));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n\n    float sum = 0;\n    float row_sum = 0;\n\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = 0; i01 < ne01; i01++) {\n                ggml_vec_sum_bf16_ggf(ne00,\n                    &row_sum,\n                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));\n                sum += row_sum;\n            }\n        }\n    }\n    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);\n}\n\nvoid ggml_compute_forward_sum(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_sum_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_sum_f16(params, dst);\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                ggml_compute_forward_sum_bf16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_cumsum\n\nstatic void ggml_compute_forward_cumsum_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * 
src0 = dst->src[0];\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(dst->nb[0] == sizeof(float));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(ne0 == ne00);\n    GGML_ASSERT(ne1 == ne01);\n    GGML_ASSERT(ne2 == ne02);\n    GGML_ASSERT(ne3 == ne03);\n\n    const auto [ir0, ir1] = get_thread_range(params, src0);\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir/(ne02*ne01);\n        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;\n        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);\n\n        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);\n        float * dst_row = (float *) ((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);\n\n        ggml_vec_cumsum_f32(ne00, dst_row, src_row);\n    }\n}\n\nvoid ggml_compute_forward_cumsum(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_cumsum_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_sum_rows\n\nstatic void ggml_compute_forward_sum_rows_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(dst->nb[0] == sizeof(float));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(ne0 == 1);\n    GGML_ASSERT(ne1 == ne01);\n    GGML_ASSERT(ne2 == ne02);\n    GGML_ASSERT(ne3 == ne03);\n\n    for (int64_t i3 = 0; i3 < ne03; i3++) {\n        for (int64_t i2 = 0; i2 < ne02; i2++) {\n            for (int64_t i1 = 0; i1 < ne01; i1++) {\n                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);\n        
        float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);\n                float row_sum = 0;\n                ggml_vec_sum_f32(ne00, &row_sum, src_row);\n                dst_row[0] = row_sum;\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_sum_rows(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_sum_rows_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_mean\n\nstatic void ggml_compute_forward_mean_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    assert(src0->nb[0] == sizeof(float));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    assert(ne0 == 1);\n    assert(ne1 == ne01);\n    assert(ne2 == ne02);\n    assert(ne3 == ne03);\n\n    GGML_UNUSED(ne0);\n    GGML_UNUSED(ne1);\n    GGML_UNUSED(ne2);\n    GGML_UNUSED(ne3);\n\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = 0; i01 < ne01; i01++) {\n                ggml_vec_sum_f32(ne00,\n                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),\n                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));\n\n                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_mean(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                
ggml_compute_forward_mean_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_argmax\n\nstatic void ggml_compute_forward_argmax_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    assert(src0->nb[0] == sizeof(float));\n    assert(dst->nb[0] == sizeof(float));\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne01 = src0->ne[1];\n\n    const size_t nb01 = src0->nb[1];\n    const size_t nb0 = dst->nb[0];\n\n    for (int64_t i1 = 0; i1 < ne01; i1++) {\n        float * src = (float *) ((char *) src0->data + i1*nb01);\n        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);\n        int v = 0;\n        ggml_vec_argmax_f32(ne00, &v, src);\n        dst_[0] = v;\n    }\n}\n\nvoid ggml_compute_forward_argmax(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_argmax_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_count_equal\n\nstatic void ggml_compute_forward_count_equal_i32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    GGML_ASSERT(src0->type == GGML_TYPE_I32);\n    GGML_ASSERT(src1->type == GGML_TYPE_I32);\n    GGML_ASSERT(ggml_are_same_shape(src0, src1));\n    GGML_ASSERT(ggml_is_scalar(dst));\n    GGML_ASSERT(dst->type == GGML_TYPE_I64);\n\n    const int64_t nr = ggml_nrows(src0);\n\n    const int ith = params->ith;\n    const int nth 
= params->nth;\n\n    int64_t * sums = (int64_t *) params->wdata;\n    int64_t sum_thread = 0;\n\n    // rows per thread\n    const int64_t dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int64_t ir0 = dr*ith;\n    const int64_t ir1 = MIN(ir0 + dr, nr);\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 =  ir                        / (ne02*ne01);\n        const int64_t i02 = (ir - i03*ne03)            /       ne01;\n        const int64_t i01 =  ir - i03*ne03 - i02*ne02;\n\n        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;\n        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;\n\n        for (int64_t i00 = 0; i00 < ne00; ++i00) {\n            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));\n            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));\n\n            sum_thread += val0 == val1;\n        }\n    }\n    if (ith != 0) {\n        sums[ith] = sum_thread;\n    }\n    ggml_barrier(params->threadpool);\n\n    if (ith != 0) {\n        return;\n    }\n\n    for (int ith_other = 1; ith_other < nth; ++ith_other) {\n        sum_thread += sums[ith_other];\n    }\n    *((int64_t *) dst->data) = sum_thread;\n}\n\nvoid ggml_compute_forward_count_equal(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_I32:\n            {\n                ggml_compute_forward_count_equal_i32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_repeat\n\nstatic void ggml_compute_forward_repeat_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    
GGML_ASSERT(ggml_can_repeat(src0, dst));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    // guaranteed to be an integer due to the check in ggml_can_repeat\n    const int nr0 = (int)(ne0/ne00);\n    const int nr1 = (int)(ne1/ne01);\n    const int nr2 = (int)(ne2/ne02);\n    const int nr3 = (int)(ne3/ne03);\n\n    // TODO: support for transposed / permuted tensors\n    GGML_ASSERT(nb0  == sizeof(float));\n    GGML_ASSERT(nb00 == sizeof(float));\n\n    // TODO: maybe this is not optimal?\n    for                         (int i3 = 0; i3 < nr3;  i3++) {\n        for                     (int k3 = 0; k3 < ne03; k3++) {\n            for                 (int i2 = 0; i2 < nr2;  i2++) {\n                for             (int k2 = 0; k2 < ne02; k2++) {\n                    for         (int i1 = 0; i1 < nr1;  i1++) {\n                        for     (int k1 = 0; k1 < ne01; k1++) {\n                            for (int i0 = 0; i0 < nr0;  i0++) {\n                                ggml_vec_cpy_f32(ne00,\n                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),\n                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nstatic void ggml_compute_forward_repeat_f16(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    GGML_ASSERT(ggml_can_repeat(src0, dst));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    // guaranteed to be an integer due to the check in ggml_can_repeat\n    const int nr0 = (int)(ne0/ne00);\n    const int nr1 = (int)(ne1/ne01);\n    const int nr2 = (int)(ne2/ne02);\n    const int nr3 = (int)(ne3/ne03);\n\n    // TODO: support for 
transposed / permuted tensors\n    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));\n    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));\n\n    // TODO: maybe this is not optimal?\n    for                         (int i3 = 0; i3 < nr3;  i3++) {\n        for                     (int k3 = 0; k3 < ne03; k3++) {\n            for                 (int i2 = 0; i2 < nr2;  i2++) {\n                for             (int k2 = 0; k2 < ne02; k2++) {\n                    for         (int i1 = 0; i1 < nr1;  i1++) {\n                        for     (int k1 = 0; k1 < ne01; k1++) {\n                            for (int i0 = 0; i0 < nr0;  i0++) {\n                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);\n                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);\n                                // ggml_vec_cpy_f16(ne00, y, x)\n                                for (int i = 0; i < ne00; ++i) {\n                                    y[i]  = x[i];\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_repeat(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n        case GGML_TYPE_I16:\n            {\n                ggml_compute_forward_repeat_f16(params, dst);\n            } break;\n        case GGML_TYPE_F32:\n        case GGML_TYPE_I32:\n            {\n                ggml_compute_forward_repeat_f32(params, dst);\n            } break;\n        // TODO: templateify the implementation and support for I64\n        //       ref 
https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225\n        //case GGML_TYPE_I64:\n        //    {\n        //        ggml_compute_forward_repeat_i64(params, dst);\n        //    } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_repeat_back\n\nstatic void ggml_compute_forward_repeat_back_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    GGML_ASSERT(ggml_can_repeat(dst, src0));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    // guaranteed to be an integer due to the check in ggml_can_repeat\n    const int nr0 = (int)(ne00/ne0);\n    const int nr1 = (int)(ne01/ne1);\n    const int nr2 = (int)(ne02/ne2);\n    const int nr3 = (int)(ne03/ne3);\n\n    // TODO: support for transposed / permuted tensors\n    GGML_ASSERT(nb0  == sizeof(float));\n    GGML_ASSERT(nb00 == sizeof(float));\n\n    if (ggml_is_contiguous(dst)) {\n        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);\n    } else {\n        for         (int k3 = 0; k3 < ne3; k3++) {\n            for     (int k2 = 0; k2 < ne2; k2++) {\n                for (int k1 = 0; k1 < ne1; k1++) {\n                    ggml_vec_set_f32(ne0,\n                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),\n                        0);\n                }\n            }\n        }\n    }\n\n    // TODO: maybe this is not optimal?\n    for                         (int i3 = 0; i3 < nr3; i3++) {\n        for                     (int k3 = 0; k3 < ne3; k3++) {\n            for                 (int i2 = 0; i2 < nr2; i2++) {\n                for             (int k2 = 0; k2 < ne2; k2++) {\n                    for         (int i1 = 0; i1 < nr1; i1++) {\n                        for     (int k1 = 0; k1 < ne1; k1++) {\n                            for (int i0 = 0; i0 < nr0; 
i0++) {\n                                ggml_vec_acc_f32(ne0,\n                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),\n                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_repeat_back(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_repeat_back_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_concat\n\nstatic void ggml_compute_forward_concat_any(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    const size_t len = ggml_type_size(src0->type);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int32_t dim = ggml_get_op_params_i32(dst, 0);\n\n    GGML_ASSERT(dim >= 0 && dim < 4);\n\n    int64_t o[4] = {0, 0, 0, 0};\n    o[dim] = src0->ne[dim];\n\n    const char * x;\n\n    // TODO: smarter multi-theading\n    for (int i3 = 0; i3 < ne3; i3++) {\n        for (int i2 = ith; i2 < ne2; i2 += nth) {\n            for (int i1 = 0; i1 < ne1; i1++) {\n                for (int i0 = 0; i0 < ne0; i0++) {\n                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {\n                        x = (const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03;\n                    } else {\n                      
  x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;\n                    }\n\n                    char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;\n\n                    memcpy(y, x, len);\n                }\n            }\n        }\n    }\n}\n\nstatic void ggml_compute_forward_concat_i8(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int32_t dim = ggml_get_op_params_i32(dst, 0);\n\n    GGML_ASSERT(dim >= 0 && dim < 4);\n\n    int64_t o[4] = {0, 0, 0, 0};\n    o[dim] = src0->ne[dim];\n\n    const int8_t * x;\n\n    // TODO: smarter multi-theading\n    for (int i3 = 0; i3 < ne3; i3++) {\n        for (int i2 = ith; i2 < ne2; i2 += nth) {\n            for (int i1 = 0; i1 < ne1; i1++) {\n                for (int i0 = 0; i0 < ne0; i0++) {\n                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {\n                        x = (const int8_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);\n                    } else {\n                        x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);\n                    }\n\n                    int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);\n\n                    *y = *x;\n                }\n            }\n        }\n    }\n}\n\nstatic void ggml_compute_forward_concat_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_type_size(src0->type) == 
sizeof(ggml_fp16_t));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int32_t dim = ggml_get_op_params_i32(dst, 0);\n\n    GGML_ASSERT(dim >= 0 && dim < 4);\n\n    int64_t o[4] = {0, 0, 0, 0};\n    o[dim] = src0->ne[dim];\n\n    const ggml_fp16_t * x;\n\n    // TODO: smarter multi-theading\n    for (int i3 = 0; i3 < ne3; i3++) {\n        for (int i2 = ith; i2 < ne2; i2 += nth) {\n            for (int i1 = 0; i1 < ne1; i1++) {\n                for (int i0 = 0; i0 < ne0; i0++) {\n                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {\n                        x = (const ggml_fp16_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);\n                    } else {\n                        x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);\n                    }\n\n                    ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);\n\n                    *y = *x;\n                }\n            }\n        }\n    }\n}\n\nstatic void ggml_compute_forward_concat_f32(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int32_t dim = ggml_get_op_params_i32(dst, 0);\n\n    GGML_ASSERT(dim >= 0 && dim < 4);\n\n    int64_t o[4] = {0, 0, 0, 0};\n    o[dim] = src0->ne[dim];\n\n    const float * x;\n\n    // TODO: smarter multi-theading\n    for (int i3 = 0; i3 < ne3; i3++) {\n        for (int i2 = ith; i2 < ne2; i2 += nth) {\n            for (int i1 = 0; i1 < ne1; i1++) {\n                for (int i0 = 0; i0 < ne0; i0++) {\n         
           if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {\n                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);\n                    } else {\n                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);\n                    }\n\n                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);\n\n                    *y = *x;\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_concat(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n        case GGML_TYPE_I16:\n            {\n                ggml_compute_forward_concat_f16(params, dst);\n            } break;\n        case GGML_TYPE_I8:\n            {\n                ggml_compute_forward_concat_i8(params, dst);\n            } break;\n        case GGML_TYPE_F32:\n        case GGML_TYPE_I32:\n            {\n                ggml_compute_forward_concat_f32(params, dst);\n            } break;\n        default:\n            {\n                ggml_compute_forward_concat_any(params, dst);\n            }\n    }\n}\n\n// ggml_compute_forward_gelu\n\nstatic void ggml_compute_forward_gelu_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(ggml_is_contiguous_rows(src0));\n    assert(ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = 
ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int i3 = ir/(ne02*ne01);\n        const int i2 = (ir - i3*ne02*ne01)/ne01;\n        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);\n\n        ggml_vec_gelu_f32(nc,\n                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_gelu_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(ggml_is_contiguous_rows(src0));\n    assert(ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int i3 = ir/(ne02*ne01);\n        const int i2 = (ir - i3*ne02*ne01)/ne01;\n        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);\n\n        ggml_vec_gelu_f16(nc,\n                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),\n                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + 
i2*nb02 + i1*nb01));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];\n            const float v = GGML_CPU_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_gelu(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_gelu_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_gelu_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_fill\n\nstatic void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {\n    const float c = ggml_get_op_params_f32(dst, 0);\n\n    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);\n    GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);\n\n    const auto [ir0, ir1] = get_thread_range(params, dst);\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir/(ne2*ne1);\n        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;\n        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);\n\n        float * dst_ptr  = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);\n\n        ggml_vec_set_f32(ne0, dst_ptr, c);\n    }\n}\n\nvoid ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {\n    ggml_compute_forward_fill_f32(params, dst);\n}\n\n// ggml_compute_tri\n\nstatic void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    const ggml_tri_type ttype = (ggml_tri_type) 
ggml_get_op_params_i32(dst, 0);\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    const auto [ir0, ir1] = get_thread_range(params, src0);\n\n    bool (*bipred)(int, int);\n\n    switch (ttype) {\n        case GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;\n        case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;\n        case GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;\n        case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;\n        default: GGML_ABORT(\"invalid tri type\");\n    }\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir/(ne02*ne01);\n        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;\n        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);\n\n        const float * src_ptr = (const float  *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);\n              float * dst_ptr = (      float  *) ((      char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1);\n\n        for (int i0 = 0; i0 < ne0; ++i0) {\n            dst_ptr[i0] = bipred(i0, i01) ? 
src_ptr[i0] : 0.0f;\n        }\n    }\n}\n\nvoid ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_tri_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_gelu_erf\n\nstatic void ggml_compute_forward_gelu_erf_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(ggml_is_contiguous_rows(src0));\n    assert(ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int i3 = ir/(ne02*ne01);\n        const int i2 = (ir - i3*ne02*ne01)/ne01;\n        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);\n\n        ggml_vec_gelu_erf_f32(nc,\n                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_gelu_erf_f16(\n    const 
ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(ggml_is_contiguous_rows(src0));\n    assert(ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int i3 = ir/(ne02*ne01);\n        const int i2 = (ir - i3*ne02*ne01)/ne01;\n        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);\n\n        ggml_vec_gelu_erf_f16(nc,\n                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),\n                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];\n            const float v = GGML_CPU_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_gelu_erf(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_gelu_erf_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_gelu_erf_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n 
           }\n    }\n}\n\n// ggml_compute_forward_gelu_quick\n\nstatic void ggml_compute_forward_gelu_quick_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(ggml_is_contiguous_rows(src0));\n    assert(ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int i3 = ir/(ne02*ne01);\n        const int i2 = (ir - i3*ne02*ne01)/ne01;\n        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);\n\n        ggml_vec_gelu_quick_f32(nc,\n                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_gelu_quick_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(ggml_is_contiguous_rows(src0));\n    assert(ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n    const int ith = 
params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int i3 = ir/(ne02*ne01);\n        const int i2 = (ir - i3*ne02*ne01)/ne01;\n        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);\n\n        ggml_vec_gelu_quick_f16(nc,\n                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),\n                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];\n            const float v = GGML_CPU_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_gelu_quick(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_gelu_quick_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_gelu_quick_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_silu\n\nstatic void ggml_compute_forward_silu_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(ggml_is_contiguous_rows(src0));\n    assert(ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, 
src0, nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int i3 = ir/(ne02*ne01);\n        const int i2 = (ir - i3*ne02*ne01)/ne01;\n        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);\n\n        ggml_vec_silu_f32(nc,\n                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),\n                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_silu_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(ggml_is_contiguous_rows(src0));\n    assert(ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int i3 = ir/(ne02*ne01);\n        const int i2 = (ir - i3*ne02*ne01)/ne01;\n        
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);\n\n        ggml_vec_silu_f16(nc,\n                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),\n                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];\n            const float v = GGML_CPU_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_silu(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_silu_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_silu_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n// ggml_compute_forward_leaky_relu\n\nstatic void ggml_compute_forward_leaky_relu_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    assert(ggml_is_contiguous_1(src0));\n    assert(ggml_is_contiguous_1(dst));\n    assert(ggml_are_same_shape(src0, dst));\n\n    const int n  = ggml_nrows(src0);\n    const int nc = src0->ne[0];\n\n    float negative_slope;\n    memcpy(&negative_slope, dst->op_params, sizeof(float));\n\n    assert(dst->nb[0]  == sizeof(float));\n    assert(src0->nb[0] == sizeof(float));\n\n    for (int i = 0; i < n; i++) {\n        ggml_vec_leaky_relu_f32(nc,\n                (float *) ((char *) dst->data  + i*( dst->nb[1])),\n                (float *) ((char *) src0->data + 
i*(src0->nb[1])), negative_slope);\n    }\n}\n\nstatic void ggml_compute_forward_leaky_relu_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    assert(ggml_is_contiguous_1(src0));\n    assert(ggml_is_contiguous_1(dst));\n    assert(ggml_are_same_shape(src0, dst));\n\n    const int n  = ggml_nrows(src0);\n    const int nc = src0->ne[0];\n\n    float negative_slope;\n    memcpy(&negative_slope, dst->op_params, sizeof(float));\n\n    assert(dst->nb[0]  == sizeof(ggml_fp16_t));\n    assert(src0->nb[0] == sizeof(ggml_fp16_t));\n\n    for (int i = 0; i < n; i++) {\n        ggml_vec_leaky_relu_f16(nc,\n                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),\n                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);\n    }\n}\n\nvoid ggml_compute_forward_leaky_relu(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_leaky_relu_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_leaky_relu_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_silu_back\n\nstatic void ggml_compute_forward_silu_back_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * grad = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    assert(ggml_is_contiguous_1(grad));\n    assert(ggml_is_contiguous_1(src1));\n    assert(ggml_is_contiguous_1(dst));\n    assert(ggml_are_same_shape(src1, dst));\n    assert(ggml_are_same_shape(src1, grad));\n\n    const int ith = params->ith;\n    const int nth = 
params->nth;\n\n    const int nc = src1->ne[0];\n    const int nr = ggml_nrows(src1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        ggml_vec_silu_backward_f32(nc,\n                (float *) ((char *) dst->data  + i1*( dst->nb[1])),\n                (float *) ((char *) src1->data + i1*(src1->nb[1])),\n                (float *) ((char *) grad->data + i1*(grad->nb[1])));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_silu_back_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * grad = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    assert(ggml_is_contiguous_1(grad));\n    assert(ggml_is_contiguous_1(src1));\n    assert(ggml_is_contiguous_1(dst));\n    assert(ggml_are_same_shape(src1, dst));\n    assert(ggml_are_same_shape(src1, grad));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1->ne[0];\n    const int nr = ggml_nrows(src1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        ggml_vec_silu_backward_f16(nc,\n                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),\n                (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),\n                (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((ggml_fp16_t *) ((char *) dst->data + 
i1*( dst->nb[1])))[k];\n            const float v = GGML_CPU_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nvoid ggml_compute_forward_silu_back(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_silu_back_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_silu_back_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_reglu\n\nstatic void ggml_compute_forward_reglu_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float * src0_p = (float *) (src0_d + i1*src0_o);\n        float * src1_p = (float *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_reglu_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);\n        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            const float v = GGML_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_reglu(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_reglu_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_reglu_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_geglu\n\nstatic void ggml_compute_forward_geglu_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char 
*) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float * src0_p = (float *) (src0_d + i1*src0_o);\n        float * src1_p = (float *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_geglu_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? 
src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);\n        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 
0 : nc;\n        }\n\n        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            const float v = GGML_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_geglu(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_geglu_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_geglu_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_swiglu\n\nstatic void ggml_compute_forward_swiglu_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float * src0_p = (float *) (src0_d + i1*src0_o);\n        float * src1_p = (float *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_swiglu_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);\n        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            const float v = GGML_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_swiglu(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_swiglu_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_swiglu_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_swiglu_oai\n\nstatic void ggml_compute_forward_swiglu_oai_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * 
src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n    const float alpha = ggml_get_op_params_f32(dst, 2);\n    const float limit = ggml_get_op_params_f32(dst, 3);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float * src0_p = (float *) (src0_d + i1*src0_o);\n        float * src1_p = (float *) (src1_d + i1*src1_o);\n        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 
0 : nc;\n        }\n\n        for (int k = 0; k < nc; k++) {\n            const float x = std::min(src0_p[k], limit);\n            const float y = std::clamp(src1_p[k], -limit, limit);\n            const float out_glu = x / (1.f + expf(alpha * (-x)));\n            dst_p[k] = out_glu * (y + 1.f);\n        }\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = dst_p[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_swiglu_oai(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_swiglu_oai_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_geglu_erf\n\nstatic void ggml_compute_forward_geglu_erf_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float * src0_p = (float *) (src0_d + i1*src0_o);\n        float * src1_p = (float *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_geglu_erf_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);\n        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            const float v = GGML_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_geglu_erf(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_geglu_erf_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_geglu_erf_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_geglu_quick\n\nstatic void ggml_compute_forward_geglu_quick_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) 
src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float * src0_p = (float *) (src0_d + i1*src0_o);\n        float * src1_p = (float *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            GGML_UNUSED(x);\n            assert(!isnan(x));\n            assert(!isinf(x));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_geglu_quick_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    char * src0_d = (char *) src0->data;\n    char * src1_d = (char *) (src1 ? src1->data : src0->data);\n    const size_t src0_o = src0->nb[1];\n    const size_t src1_o = src1 ? 
src1->nb[1] : src0->nb[1];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(ggml_is_contiguous_1(dst));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;\n    const int nr = ggml_nrows(src0);\n\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == nr);\n\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);\n        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 
0 : nc;\n        }\n\n        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);\n\n#ifndef NDEBUG\n        for (int k = 0; k < nc; k++) {\n            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];\n            const float v = GGML_FP16_TO_FP32(x);\n            GGML_UNUSED(v);\n            assert(!isnan(v));\n            assert(!isinf(v));\n        }\n#endif // NDEBUG\n    }\n}\n\nstatic void ggml_compute_forward_geglu_quick(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_geglu_quick_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_geglu_quick_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_norm\n\nstatic void ggml_compute_forward_norm_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    GGML_ASSERT(eps >= 0.0f);\n\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {\n                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);\n\n                float sum = 0.0;\n                ggml_vec_sum_f32(ne00, &sum, x);\n                float mean = sum/ne00;\n\n                float * y = (float *) ((char *) 
dst->data + i01*nb1 + i02*nb2 + i03*nb3);\n                float variance = 0;\n\n#ifdef GGML_USE_ACCELERATE\n                mean = -mean;\n                vDSP_vsadd(x, 1, &mean, y, 1, ne00);\n                vDSP_measqv(y, 1, &variance, ne00);\n#else\n                variance = ggml_vec_cvar_f32(ne00, y, x, mean);\n#endif //GGML_USE_ACCELERATE\n\n                const float scale = 1.0f/sqrtf(variance + eps);\n                ggml_vec_scale_f32(ne00, y, scale);\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_norm(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_norm_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_group_rms_norm\n\nstatic void ggml_compute_forward_rms_norm_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    GGML_ASSERT(eps >= 0.0f);\n\n    // TODO: optimize\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {\n                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);\n\n                ggml_float sum = 0.0;\n                for (int64_t i00 = 0; i00 < ne00; i00++) {\n                    sum += (ggml_float)(x[i00] * x[i00]);\n                }\n\n                const float mean = sum/ne00;\n\n                float * y = (float 
*) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);\n\n                memcpy(y, x, ne00 * sizeof(float));\n                // for (int i00 = 0; i00 < ne00; i00++) {\n                //     y[i00] = x[i00];\n                // }\n\n                const float scale = 1.0f/sqrtf(mean + eps);\n\n                // if you hit this, likely you got an inf somewhere earlier\n                assert(scale > 0.0f);\n\n                ggml_vec_scale_f32(ne00, y, scale);\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_rms_norm(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_rms_norm_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nstatic void ggml_compute_forward_rms_norm_back_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output\n    const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(src1->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    // TODO: optimize\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {\n                // src1 is same shape as src0 => same indices\n                const int64_t i11 = i01;\n                const int64_t i12 = i02;\n                const int64_t i13 = i03;\n\n                const float * dz = (float *) 
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);\n                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);\n\n                ggml_float sum_xx  = 0.0;\n                ggml_float sum_xdz = 0.0;\n\n                for (int64_t i00 = 0; i00 < ne00; i00++) {\n                    sum_xx  += (ggml_float)(x[i00] * x[i00]);\n                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);\n                }\n\n                //const float mean     = (float)(sum_xx)/ne00;\n                const float mean_eps = (float)(sum_xx)/ne00 + eps;\n                const float sum_eps  = (float)(sum_xx) + eps*ne00;\n                //const float mean_xdz = (float)(sum_xdz)/ne00;\n                // we could cache rms from forward pass to improve performance.\n                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.\n                //const float rms      = sqrtf(mean_eps);\n                const float rrms     = 1.0f / sqrtf(mean_eps);\n                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)\n\n                {\n                    // z = rms_norm(x)\n                    //\n                    // rms_norm(src1) =\n                    //     scale(\n                    //         src1,\n                    //         div(\n                    //             1,\n                    //             sqrt(\n                    //                 add(\n                    //                     scale(\n                    //                         sum(\n                    //                             sqr(\n                    //                                 src1)),\n                    //                         (1.0/N)),\n                    //                     eps))));\n\n                    // postorder:\n                    // ## op    args         grad\n                    // 00 param src1         grad[#00]\n                    // 01 const 1\n                    
// 02 sqr   (#00)        grad[#02]\n                    // 03 sum   (#02)        grad[#03]\n                    // 04 const 1/N\n                    // 05 scale (#03, #04)   grad[#05]\n                    // 06 const eps\n                    // 07 add   (#05, #06)   grad[#07]\n                    // 08 sqrt  (#07)        grad[#08]\n                    // 09 div   (#01,#08)    grad[#09]\n                    // 10 scale (#00,#09)    grad[#10]\n                    //\n                    // backward pass, given grad[#10]\n                    // #10: scale\n                    // grad[#00] += scale(grad[#10],#09)\n                    // grad[#09] += sum(mul(grad[#10],#00))\n                    // #09: div\n                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))\n                    // #08: sqrt\n                    // grad[#07] += mul(grad[#08], div(0.5, #08))\n                    // #07: add\n                    // grad[#05] += grad[#07]\n                    // #05: scale\n                    // grad[#03] += scale(grad[#05],#04)\n                    // #03: sum\n                    // grad[#02] += repeat(grad[#03], #02)\n                    // #02:\n                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)\n                    //\n                    // substitute and simplify:\n                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)\n                    // grad[#02] = repeat(grad[#03], #02)\n                    // grad[#02] = repeat(scale(grad[#05],#04), #02)\n                    // grad[#02] = repeat(scale(grad[#07],#04), #02)\n                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)\n                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)\n                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)\n                    // grad[#02] = 
repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)\n                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)\n                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)\n                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)\n                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)\n                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)\n                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)\n                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))\n                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))\n                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))\n                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))\n                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))\n                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))\n                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))\n                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))\n                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))\n                    // a = b*c + d*e\n                    // a = b*c*f/f + d*e*f/f\n 
                   // a = (b*c*f + d*e*f)*(1/f)\n                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))\n                    // a = (b + d*e/c)*c\n                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)\n                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms\n                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms\n                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms\n                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms\n                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms\n                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms\n                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)\n                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)\n                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)\n                }\n                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)\n                // post-order:\n                // dx := x\n                // dx := scale(dx,-mean_xdz/mean_eps)\n                // dx := add(dx, dz)\n                // dx := scale(dx, rrms)\n                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);\n\n                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)\n                ggml_vec_cpy_f32  (ne00, dx, x);\n                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);\n                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);\n                ggml_vec_acc_f32  (ne00, dx, dz);\n                ggml_vec_scale_f32(ne00, dx, rrms);\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_rms_norm_back(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                
ggml_compute_forward_rms_norm_back_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_group_norm\n\nstatic void ggml_compute_forward_group_norm_f32(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    // TODO: optimize\n\n    float eps;\n    memcpy(&eps, dst->op_params + 1, sizeof(float));\n\n    int n_channels = src0->ne[2];\n    int n_groups = dst->op_params[0];\n    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;\n    for (int i = ith; i < n_groups; i += nth) {\n        int start = i * n_channels_per_group;\n        int end = start + n_channels_per_group;\n        if (end > n_channels) {\n            end = n_channels;\n        }\n        int step = end - start;\n\n        for (int64_t i03 = 0; i03 < ne03; i03++) {\n            ggml_float sum = 0.0;\n            for (int64_t i02 = start; i02 < end; i02++) {\n                for (int64_t i01 = 0; i01 < ne01; i01++) {\n                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);\n\n                    ggml_float sumr = 0.0;\n                    for (int64_t i00 = 0; i00 < ne00; i00++) {\n                        sumr += (ggml_float)x[i00];\n                    }\n                    sum += sumr;\n                }\n            }\n            const float mean = sum / (ne00 * ne01 * step);\n\n            ggml_float sum2 = 0.0;\n            for (int64_t i02 = start; i02 < end; i02++) {\n                for (int64_t i01 = 0; i01 < ne01; i01++) {\n                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);\n\n          
          float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);\n\n                    ggml_float sumr = 0.0;\n                    for (int64_t i00 = 0; i00 < ne00; i00++) {\n                        float v = x[i00] - mean;\n                        y[i00] = v;\n                        sumr += (ggml_float)(v * v);\n                    }\n                    sum2 += sumr;\n                }\n            }\n            const float variance = sum2 / (ne00 * ne01 * step);\n            const float scale = 1.0f / sqrtf(variance + eps);\n\n            for (int64_t i02 = start; i02 < end; i02++) {\n                for (int64_t i01 = 0; i01 < ne01; i01++) {\n                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);\n                    ggml_vec_scale_f32(ne00, y, scale);\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_group_norm(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_group_norm_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_l2_norm\n\nstatic void ggml_compute_forward_l2_norm_f32(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    GGML_ASSERT(eps >= 0.0f);\n\n    // TODO: optimize\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = ith; i01 < ne01; i01 
+= nth) {\n                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);\n\n                ggml_float sum = 0.0;\n                for (int64_t i00 = 0; i00 < ne00; i00++) {\n                    sum += (ggml_float)(x[i00] * x[i00]);\n                }\n\n                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);\n\n                memcpy(y, x, ne00 * sizeof(float));\n\n                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);\n\n                ggml_vec_scale_f32(ne00, y, scale);\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_l2_norm(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_l2_norm_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_out_prod\n\nstatic void ggml_compute_forward_out_prod_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_ASSERT(ne0 == ne00);\n    GGML_ASSERT(ne1 == ne10);\n    GGML_ASSERT(ne2 == ne12);\n    GGML_ASSERT(ne3 == ne13);\n\n    GGML_ASSERT(ne2 % ne02 == 0);\n    GGML_ASSERT(ne3 % ne03 == 0);\n\n    // we don't support permuted src0 or src1\n    GGML_ASSERT(nb00 == sizeof(float));\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    // GGML_ASSERT(nb0 <= nb1);\n    // GGML_ASSERT(nb1 <= nb2);\n    // GGML_ASSERT(nb2 <= nb3);\n\n    
// nb01 >= nb00 - src0 is not transposed\n    //   compute by src0 rows\n\n    if (ith == 0) {\n        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);\n    }\n    ggml_barrier(params->threadpool);\n\n    // dst[:,:,:,:] = 0\n    // for i2,i3:\n    //   for i1:\n    //     for i01:\n    //       for i0:\n    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]\n\n    // parallelize by last three dimensions\n\n    // total rows in dst\n    const int64_t nr = ne1*ne2*ne3;\n\n    // rows per thread\n    const int64_t dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int64_t ir0 = dr*ith;\n    const int64_t ir1 = MIN(ir0 + dr, nr);\n\n    // block-tiling attempt\n    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);\n    const int64_t blck_1 = 16;\n\n    // dps == dst per src0, used for group query attention\n    const int64_t dps2 = ne2 / ne02;\n    const int64_t dps3 = ne3 / ne03;\n\n    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {\n        const int64_t bir1 = MIN(bir + blck_1, ir1);\n        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {\n            const int64_t bne01 = MIN(bi01 + blck_0, ne01);\n            for (int64_t ir = bir; ir < bir1; ++ir) {\n                // dst indices\n                const int64_t i3 = ir/(ne2*ne1);\n                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;\n                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);\n\n                const int64_t i02 = i2 / dps2;\n                const int64_t i03 = i3 / dps3;\n\n                //const int64_t i10 = i1;\n                const int64_t i12 = i2;\n                const int64_t i13 = i3;\n\n#if GGML_VEC_MAD_UNROLL > 2\n                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);\n                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {\n                    const int64_t i11 = i01;\n\n                    float * s0 = (float *) ((char *) src0->data + (         
 i01*nb01 + i02*nb02 + i03*nb03));\n                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));\n                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));\n\n                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);\n                }\n                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {\n                    const int64_t i11 = i01;\n\n                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));\n                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));\n                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));\n\n                    ggml_vec_mad_f32(ne0, d, s0, *s1);\n                }\n#else\n                for (int64_t i01 = bi01; i01 < bne01; ++i01) {\n                    const int64_t i11 = i01;\n\n                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));\n                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));\n                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));\n\n                    ggml_vec_mad_f32(ne0, d, s0, *s1);\n                }\n#endif\n            }\n        }\n    }\n}\n\nstatic void ggml_compute_forward_out_prod_q_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const ggml_type type = src0->type;\n    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;\n\n    GGML_ASSERT(ne02 == ne12);\n    GGML_ASSERT(ne03 == ne13);\n    GGML_ASSERT(ne2  
== ne12);\n    GGML_ASSERT(ne3  == ne13);\n\n    // we don't support permuted src0 dim0\n    GGML_ASSERT(nb00 == ggml_type_size(type));\n\n    // dst dim0 cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    // GGML_ASSERT(nb0 <= nb1);\n    // GGML_ASSERT(nb1 <= nb2);\n    // GGML_ASSERT(nb2 <= nb3);\n\n    GGML_ASSERT(ne0 == ne00);\n    GGML_ASSERT(ne1 == ne10);\n    GGML_ASSERT(ne2 == ne02);\n    GGML_ASSERT(ne3 == ne03);\n\n    // nb01 >= nb00 - src0 is not transposed\n    //   compute by src0 rows\n\n    if (ith == 0) {\n        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);\n    }\n    ggml_barrier(params->threadpool);\n\n    // parallelize by last three dimensions\n\n    // total rows in dst\n    const int64_t nr = ne1*ne2*ne3;\n\n    // rows per thread\n    const int64_t dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int64_t ir0 = dr*ith;\n    const int64_t ir1 = MIN(ir0 + dr, nr);\n\n    // dst[:,:,:,:] = 0\n    // for i2,i3:\n    //   for i1:\n    //     for i01:\n    //       for i0:\n    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]\n\n    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        // dst indices\n        const int64_t i3 = ir/(ne2*ne1);\n        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;\n        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);\n\n        const int64_t i02 = i2;\n        const int64_t i03 = i3;\n\n        //const int64_t i10 = i1;\n        const int64_t i12 = i2;\n        const int64_t i13 = i3;\n\n        for (int64_t i01 = 0; i01 < ne01; ++i01) {\n            const int64_t i11 = i01;\n\n            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));\n            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));\n            float * d  = (float *) ((char *)  dst->data + (          
i1*nb1 + i2*nb2 + i3*nb3));\n\n            dequantize_row_q(s0, wdata, ne0);\n            ggml_vec_mad_f32(ne0, d, wdata, *s1);\n        }\n    }\n}\n\nvoid ggml_compute_forward_out_prod(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_MXFP4:\n        case GGML_TYPE_NVFP4:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_TQ1_0:\n        case GGML_TYPE_TQ2_0:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ2_S:\n            {\n                ggml_compute_forward_out_prod_q_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                GGML_ABORT(\"fatal error\"); // todo\n                // ggml_compute_forward_out_prod_f16_f32(params, dst);\n            }\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_out_prod_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_scale\n\nstatic void ggml_compute_forward_scale_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n\n    float s; // scale factor\n    float b; // bias\n\n    memcpy(&s, (float *) 
dst->op_params + 0, sizeof(float));\n    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    const size_t nb01 = src0->nb[1];\n\n    const size_t nb1 = dst->nb[1];\n\n    if (b == 0.0f) {\n        for (int i1 = ir0; i1 < ir1; i1++) {\n            if (dst->data != src0->data) {\n                // src0 is same shape as dst => same indices\n                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy\n                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));\n            }\n            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);\n        }\n    } else {\n        for (int i1 = ir0; i1 < ir1; i1++) {\n            ggml_vec_mad1_f32(nc,\n                (float *) ((char *) dst->data  + i1*nb1),\n                (float *) ((char *) src0->data + i1*nb1),\n                s, b);\n        }\n    }\n}\n\nvoid ggml_compute_forward_scale(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_scale_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_set\n\nstatic void ggml_compute_forward_set_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_contiguous(dst) && 
ggml_is_contiguous(src0));\n\n    // view src0 and dst with these strides and data offset in bytes during set\n    // nb0 is implicitly element_size because src0 and dst are contiguous\n    size_t nb1     = ((int32_t *) dst->op_params)[0];\n    size_t nb2     = ((int32_t *) dst->op_params)[1];\n    size_t nb3     = ((int32_t *) dst->op_params)[2];\n    size_t offset  = ((int32_t *) dst->op_params)[3];\n    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];\n\n    if (!inplace) {\n        if (params->ith == 0) {\n            // memcpy needs to be synchronized across threads to avoid race conditions.\n            // => only thread 0 copies; the barrier below makes all threads wait for it\n            memcpy(\n                ((char *)  dst->data),\n                ((char *) src0->data),\n                ggml_nbytes(dst));\n        }\n        ggml_barrier(params->threadpool);\n    }\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr = ggml_nrows(src1);\n    const int nc = src1->ne[0];\n\n    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)\n\n    // src0 and dst as viewed during set\n    const size_t nb0 = ggml_element_size(src0);\n\n    const int im0 = (ne10 == 0 ? 0 : ne10-1);\n    const int im1 = (ne11 == 0 ? 0 : ne11-1);\n    const int im2 = (ne12 == 0 ? 0 : ne12-1);\n    const int im3 = (ne13 == 0 ? 
0 : ne13-1);\n\n    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));\n\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are viewed with shape of src1 and offset\n        // => same indices\n        const int i3 = ir/(ne12*ne11);\n        const int i2 = (ir - i3*ne12*ne11)/ne11;\n        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);\n\n        ggml_vec_cpy_f32(nc,\n                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),\n                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));\n    }\n}\n\nstatic void ggml_compute_forward_set_i32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));\n\n    // view src0 and dst with these strides and data offset in bytes during set\n    // nb0 is implicitly element_size because src0 and dst are contiguous\n    size_t nb1     = ((int32_t *) dst->op_params)[0];\n    size_t nb2     = ((int32_t *) dst->op_params)[1];\n    size_t nb3     = ((int32_t *) dst->op_params)[2];\n    size_t offset  = ((int32_t *) dst->op_params)[3];\n    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];\n\n    if (!inplace) {\n        if (params->ith == 0) {\n            // memcpy needs to be synchronized across threads to avoid race conditions.\n            // => only thread 0 copies; the barrier below makes all threads wait for it\n            memcpy(\n                ((char *)  dst->data),\n                ((char *) src0->data),\n                ggml_nbytes(dst));\n        }\n        ggml_barrier(params->threadpool);\n    }\n\n    const int ith = 
params->ith;\n    const int nth = params->nth;\n\n    const int nr = ggml_nrows(src1);\n    const int nc = src1->ne[0];\n\n    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)\n\n    // src0 and dst as viewed during set\n    const size_t nb0 = ggml_element_size(src0);\n\n    const int im0 = (ne10 == 0 ? 0 : ne10-1);\n    const int im1 = (ne11 == 0 ? 0 : ne11-1);\n    const int im2 = (ne12 == 0 ? 0 : ne12-1);\n    const int im3 = (ne13 == 0 ? 0 : ne13-1);\n\n    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));\n\n    GGML_ASSERT(nb10 == sizeof(int32_t));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // src0 and dst are viewed with shape of src1 and offset\n        // => same indices\n        const int i3 = ir/(ne12*ne11);\n        const int i2 = (ir - i3*ne12*ne11)/ne11;\n        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);\n\n        ggml_vec_cpy_i32(nc,\n                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),\n                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));\n    }\n}\n\nvoid ggml_compute_forward_set(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_set_f32(params, dst);\n            } break;\n        case GGML_TYPE_I32:\n            {\n                ggml_compute_forward_set_i32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q8_1:\n        case 
GGML_TYPE_MXFP4:\n        case GGML_TYPE_NVFP4:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_TQ1_0:\n        case GGML_TYPE_TQ2_0:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ2_S:\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_cpy\n\nvoid ggml_compute_forward_cpy(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    ggml_compute_forward_dup(params, dst);\n}\n\n// ggml_compute_forward_cont\n\nvoid ggml_compute_forward_cont(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    ggml_compute_forward_dup(params, dst);\n}\n\n// ggml_compute_forward_get_rows\n\nstatic void ggml_compute_forward_get_rows_q(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int64_t nc = ne00;\n    const int64_t nr = ggml_nelements(src1);\n\n    const ggml_type type = src0->type;\n    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;\n\n    assert(ne0  == nc);\n    assert(ne02 == ne11);\n    assert(nb00 == ggml_type_size(type));\n    assert(ggml_nrows(dst) == nr);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int64_t i = ir0; i < ir1; ++i) {\n        const int64_t i12 = i/(ne11*ne10);\n        
const int64_t i11 = (i - i12*ne11*ne10)/ne10;\n        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);\n        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);\n\n        GGML_ASSERT(i01 >= 0 && i01 < ne01);\n\n        dequantize_row_q(\n                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),\n                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);\n    }\n}\n\nstatic void ggml_compute_forward_get_rows_f16(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int64_t nc = ne00;\n    const int64_t nr = ggml_nelements(src1);\n\n    assert(ne0  == nc);\n    assert(ne02 == ne11);\n    assert(nb00 == sizeof(ggml_fp16_t));\n    assert(ggml_nrows(dst) == nr);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int64_t i = ir0; i < ir1; ++i) {\n        const int64_t i12 = i/(ne11*ne10);\n        const int64_t i11 = (i - i12*ne11*ne10)/ne10;\n        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);\n        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);\n\n        GGML_ASSERT(i01 >= 0 && i01 < ne01);\n\n        ggml_cpu_fp16_to_fp32(\n            (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),\n                       (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);\n    }\n}\n\nstatic void ggml_compute_forward_get_rows_bf16(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    
GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int64_t nc = ne00;\n    const int64_t nr = ggml_nelements(src1);\n\n    assert(ne0  == nc);\n    assert(ne02 == ne11);\n    assert(nb00 == sizeof(ggml_bf16_t));\n    assert(ggml_nrows(dst) == nr);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int64_t i = ir0; i < ir1; ++i) {\n        const int64_t i12 = i/(ne11*ne10);\n        const int64_t i11 = (i - i12*ne11*ne10)/ne10;\n        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);\n        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);\n\n        GGML_ASSERT(i01 >= 0 && i01 < ne01);\n\n        ggml_cpu_bf16_to_fp32(\n            (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),\n                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);\n    }\n}\n\nstatic void ggml_compute_forward_get_rows_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int64_t nc = ne00;\n    const int64_t nr = ggml_nelements(src1);\n\n    assert(ne0  == nc);\n    assert(ne02 == ne11);\n    assert(nb00 == sizeof(float));\n    assert(ggml_nrows(dst) == nr);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int64_t i = ir0; i < ir1; ++i) {\n        const int64_t i12 = i/(ne11*ne10);\n        const int64_t i11 = (i - i12*ne11*ne10)/ne10;\n        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);\n        const int64_t i01 = 
*(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);\n\n        GGML_ASSERT(i01 >= 0 && i01 < ne01);\n\n        ggml_vec_cpy_f32(nc,\n                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),\n                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));\n    }\n}\n\nvoid ggml_compute_forward_get_rows(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q8_1:\n        case GGML_TYPE_MXFP4:\n        case GGML_TYPE_NVFP4:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_TQ1_0:\n        case GGML_TYPE_TQ2_0:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ2_S:\n            {\n                ggml_compute_forward_get_rows_q(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_get_rows_f16(params, dst);\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                ggml_compute_forward_get_rows_bf16(params, dst);\n            } break;\n        case GGML_TYPE_F32:\n        case GGML_TYPE_I32:\n            {\n                ggml_compute_forward_get_rows_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n\n    //static bool first = true;\n    //printf(\"ne0 = %d, ne1 = %d, ne2 = %d\\n\", dst->ne[0], dst->ne[1], 
dst->ne[2]);\n    //if (first) {\n    //    first = false;\n    //} else {\n    //    for (int k = 0; k < dst->ne[1]; ++k) {\n    //        for (int j = 0; j < dst->ne[0]/16; ++j) {\n    //            for (int i = 0; i < 16; ++i) {\n    //                printf(\"%8.4f \", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);\n    //            }\n    //            printf(\"\\n\");\n    //        }\n    //        printf(\"\\n\");\n    //    }\n    //    printf(\"\\n\");\n    //    exit(0);\n    //}\n}\n\ntemplate<typename idx_t>\nstatic void ggml_compute_forward_set_rows_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int64_t nc = ne00;\n    const int64_t nr = ne01;\n\n    assert(ne0  == nc);\n    assert(ne2  == ne02);\n    assert(ne3  == ne03);\n    assert(src0->type == GGML_TYPE_F32);\n    assert(ne02 % ne11 == 0);\n    assert(ne03 % ne12 == 0);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    // rows per thread\n    const int64_t dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int64_t ir0 = dr*ith;\n    const int64_t ir1 = std::min(ir0 + dr, nr);\n\n    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;\n\n    for (int64_t i03 = 0; i03 < ne03; ++i03) {\n        for (int64_t i02 = 0; i02 < ne02; ++i02) {\n            for (int64_t i = ir0; i < ir1; ++i) {\n                const int64_t i12 = i03%ne12;\n                const int64_t i11 = i02%ne11;\n                const int64_t i10 = i;\n\n                const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);\n\n                GGML_ASSERT(i1 >= 0 && i1 < ne1);\n\n                from_float(\n                        (const float *) ((char *) src0->data +  i*nb01 + i02*nb02 + i03*nb03),\n                                        
((char *)  dst->data + i1*nb1  + i02*nb2  + i03*nb3), nc);\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_set_rows(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                if (src1->type == GGML_TYPE_I64) {\n                    ggml_compute_forward_set_rows_f32<int64_t>(params, dst);\n                } else if (src1->type == GGML_TYPE_I32) {\n                    ggml_compute_forward_set_rows_f32<int32_t>(params, dst);\n                } else {\n                    GGML_ABORT(\"src1->type = %d (%s) not supported\", src1->type, ggml_type_name(src1->type));\n                }\n            } break;\n        default:\n            {\n                GGML_ABORT(\"src0->type = %d (%s) not supported\", src0->type, ggml_type_name(src0->type));\n            }\n    }\n}\n\n// ggml_compute_forward_get_rows_back\n\nstatic void ggml_compute_forward_get_rows_back_f32_f16(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    // ggml_compute_forward_dup_same_cont(params, opt0, dst);\n\n    memset(dst->data, 0, ggml_nbytes(dst));\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nelements(src1);\n\n    GGML_ASSERT( dst->ne[0] == nc);\n    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));\n\n    for (int i = 0; i < nr; ++i) {\n        const int r = ((int32_t *) src1->data)[i];\n\n        for (int j = 0; j < nc; ++j) {\n            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];\n            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);\n        }\n    }\n}\n\nstatic void 
ggml_compute_forward_get_rows_back_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    // ggml_compute_forward_dup_same_cont(params, opt0, dst);\n\n    memset(dst->data, 0, ggml_nbytes(dst));\n\n    const int nc = src0->ne[0];\n    const int nr = ggml_nelements(src1);\n\n    GGML_ASSERT( dst->ne[0] == nc);\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n    for (int i = 0; i < nr; ++i) {\n        const int r = ((int32_t *) src1->data)[i];\n\n        ggml_vec_add_f32(nc,\n                (float *) ((char *)  dst->data + r*dst->nb[1]),\n                (float *) ((char *)  dst->data + r*dst->nb[1]),\n                (float *) ((char *) src0->data + i*src0->nb[1]));\n    }\n}\n\nvoid ggml_compute_forward_get_rows_back(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_get_rows_back_f32_f16(params, dst);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_get_rows_back_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n\n    //static bool first = true;\n    //printf(\"ne0 = %d, ne1 = %d, ne2 = %d\\n\", dst->ne[0], dst->ne[1], dst->ne[2]);\n    //if (first) {\n    //    first = false;\n    //} else {\n    //    for (int k = 0; k < dst->ne[1]; ++k) {\n    //        for (int j = 0; j < dst->ne[0]/16; ++j) {\n    //            for (int i = 0; i < 16; ++i) {\n    //                printf(\"%8.4f \", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);\n    //            }\n    //            printf(\"\\n\");\n    //        }\n  
  //        printf(\"\\n\");\n    //    }\n    //    printf(\"\\n\");\n    //    exit(0);\n    //}\n}\n\n// ggml_compute_forward_diag\n\nstatic void ggml_compute_forward_diag_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    // TODO: handle transposed/permuted matrices\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(ne00 == ne0);\n    GGML_ASSERT(ne00 == ne1);\n    GGML_ASSERT(ne01 == 1);\n    GGML_ASSERT(ne02 == ne2);\n    GGML_ASSERT(ne03 == ne3);\n\n    GGML_ASSERT(nb00 == sizeof(float));\n    GGML_ASSERT(nb0  == sizeof(float));\n\n    for (int i3 = 0; i3 < ne3; i3++) {\n        for (int i2 = 0; i2 < ne2; i2++) {\n            for (int i1 = 0; i1 < ne1; i1++) {\n                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);\n                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);\n                for (int i0 = 0; i0 < i1; i0++) {\n                    d[i0] = 0;\n                }\n                d[i1] = s[i1];\n                for (int i0 = i1+1; i0 < ne0; i0++) {\n                    d[i0] = 0;\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_diag(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_diag_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_diag_mask_inf\n\nstatic void ggml_compute_forward_diag_mask_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst,\n        const float value) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    
const int  n_past  = ((int32_t *) dst->op_params)[0];\n    const bool inplace = src0->data == dst->data;\n\n    GGML_ASSERT(n_past >= 0);\n\n    if (!inplace) {\n        if (ith == 0) {\n            // memcpy needs to be synchronized across threads to avoid race conditions.\n            // => do it in INIT phase\n            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));\n            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));\n            memcpy(\n                ((char *)  dst->data),\n                ((char *) src0->data),\n                ggml_nbytes(dst));\n        }\n        ggml_barrier(params->threadpool);\n    }\n\n    // TODO: handle transposed/permuted matrices\n\n    const int n  = ggml_nrows(src0);\n    const int nc = src0->ne[0];\n    const int nr = src0->ne[1];\n    const int nz = n/nr;\n\n    GGML_ASSERT( dst->nb[0] == sizeof(float));\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n    for (int k = 0; k < nz; k++) {\n        for (int j = ith; j < nr; j += nth) {\n            for (int i = n_past; i < nc; i++) {\n                if (i > n_past + j) {\n                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_diag_mask_inf(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nvoid ggml_compute_forward_diag_mask_zero(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                
ggml_compute_forward_diag_mask_f32(params, dst, 0);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_soft_max\n\nstatic void ggml_compute_forward_soft_max_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n\n    assert(ggml_is_contiguous(dst));\n    assert(ggml_are_same_shape(src0, dst));\n\n    float scale    = 1.0f;\n    float max_bias = 0.0f;\n\n    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    const int64_t nb11 = src1 ? src1->nb[1] : 1;\n    const int64_t nb12 = src1 ? src1->nb[2] : 1;\n    const int64_t nb13 = src1 ? src1->nb[3] : 1;\n\n    const int64_t ne12 = src1 ? src1->ne[2] : 1;\n    const int64_t ne13 = src1 ? src1->ne[3] : 1;\n\n    // TODO: is this supposed to be ceil instead of floor?\n    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370\n    const uint32_t n_head      = ne02;\n    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;\n\n    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);\n\n    // sinks\n    const float * sk = src2 ? 
(float *)((char *) src2->data) : nullptr;\n\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {\n                const int64_t i11 = i01;\n                const int64_t i12 = i02%ne12;\n                const int64_t i13 = i03%ne13;\n\n                // ALiBi\n                const uint32_t h = i02; // head\n                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;\n\n                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);\n                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);\n\n                // broadcast the mask across rows\n                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;\n                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;\n\n                ggml_vec_cpy_f32  (ne00, wp, sp);\n                ggml_vec_scale_f32(ne00, wp, scale);\n                if (mp_f32) {\n                    if (use_f16) {\n                        for (int i = 0; i < ne00; ++i) {\n                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);\n                        }\n                    } else {\n                        for (int i = 0; i < ne00; ++i) {\n                            wp[i] += slope*mp_f32[i];\n                        }\n                    }\n                }\n\n#ifndef NDEBUG\n                for (int i = 0; i < ne00; ++i) {\n                    //printf(\"p[%d] = %f\\n\", i, p[i]);\n                    assert(!isnan(wp[i]));\n                }\n#endif // NDEBUG\n\n                float max = -INFINITY;\n                ggml_vec_max_f32(ne00, &max, wp);\n\n                // if we have sinks, make a correction as if they were included in the softmax\n  
              if (sk) {\n                    max = MAX(max, sk[i02]);\n                }\n\n                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);\n                assert(sum > 0.0);\n\n                if (sk) {\n                    sum += (ggml_float) expf(sk[i02] - max);\n                }\n\n                sum = 1.0/sum;\n                ggml_vec_scale_f32(ne00, dp, sum);\n\n#ifndef NDEBUG\n                for (int i = 0; i < ne00; ++i) {\n                    assert(!isnan(dp[i]));\n                    assert(!isinf(dp[i]));\n                }\n#endif // NDEBUG\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_soft_max(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_soft_max_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n\n// ggml_compute_forward_soft_max_ext_back\n\nstatic void ggml_compute_forward_soft_max_ext_back_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_are_same_shape(src1, dst));\n\n    float scale    = 1.0f;\n    float max_bias = 0.0f;\n\n    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));\n\n    GGML_ASSERT(max_bias == 0.0f);\n\n    // TODO: handle transposed/permuted matrices\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc = src0->ne[0];\n    const int nr = 
ggml_nrows(src0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);\n        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);\n        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);\n\n#ifndef NDEBUG\n        for (int i = 0; i < nc; ++i) {\n            //printf(\"p[%d] = %f\\n\", i, p[i]);\n            assert(!isnan(dy[i]));\n            assert(!isnan(y[i]));\n        }\n#endif // NDEBUG\n        // Jii = yi - yi*yi\n        // Jij = -yi*yj\n        // J = diag(y)-y.T*y\n        // dx = J * dy\n        // dxk = sum_i(Jki * dyi)\n        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk\n        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk\n        // dxk = sum_i(-yk*yi * dyi) + yk*dyk\n        // dxk = -yk * sum_i(yi * dyi) + yk*dyk\n        // dxk = -yk * dot(y, dy) + yk*dyk\n        // dxk = yk * (- dot(y, dy) + dyk)\n        // dxk = yk * (dyk - dot(y, dy))\n        //\n        // post-order:\n        // dot_y_dy := dot(y, dy)\n        // dx := dy\n        // dx := dx - dot_y_dy\n        // dx := dx * y\n\n        // linear runtime, no additional memory\n        float dot_y_dy = 0;\n        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);\n        ggml_vec_cpy_f32  (nc, dx, dy);\n        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);\n        ggml_vec_mul_f32  (nc, dx, dx, y);\n        ggml_vec_scale_f32(nc, dx, scale);\n\n#ifndef NDEBUG\n        for (int i = 0; i < nc; ++i) {\n            assert(!isnan(dx[i]));\n            assert(!isinf(dx[i]));\n        }\n#endif // NDEBUG\n    }\n}\n\nvoid ggml_compute_forward_soft_max_ext_back(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) 
{\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_soft_max_ext_back_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_clamp\n\nstatic void ggml_compute_forward_clamp_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    float min;\n    float max;\n    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int n  = ggml_nrows(src0);\n    const int nc = src0->ne[0];\n\n    const size_t nb00 = src0->nb[0];\n    const size_t nb01 = src0->nb[1];\n\n    const size_t nb0 = dst->nb[0];\n    const size_t nb1 = dst->nb[1];\n\n    GGML_ASSERT( nb0 == sizeof(float));\n    GGML_ASSERT(nb00 == sizeof(float));\n\n    for (int j = ith; j < n; j += nth) {\n        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);\n        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);\n\n        for (int i = 0; i < nc; i++) {\n            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);\n        }\n    }\n}\n\nstatic void ggml_compute_forward_clamp_f16(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    float min;\n    float max;\n    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int n  = ggml_nrows(src0);\n    const int nc = src0->ne[0];\n\n    const size_t nb00 = src0->nb[0];\n    const size_t nb01 = src0->nb[1];\n\n    const size_t nb0 = dst->nb[0];\n    const size_t nb1 = dst->nb[1];\n\n    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));\n    GGML_ASSERT(nb00 == 
sizeof(ggml_fp16_t));\n\n    for (int j = ith; j < n; j += nth) {\n        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *)  dst->data + j*nb1);\n        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);\n\n        for (int i = 0; i < nc; i++) {\n            float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);\n            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));\n        }\n    }\n}\n\nvoid ggml_compute_forward_clamp(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_clamp_f32(params, dst);\n            } break;\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_clamp_f16(params, dst);\n            } break;\n        case GGML_TYPE_BF16:\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q8_1:\n        case GGML_TYPE_MXFP4:\n        case GGML_TYPE_NVFP4:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_TQ1_0:\n        case GGML_TYPE_TQ2_0:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_Q8_K:\n        case GGML_TYPE_I8:\n        case GGML_TYPE_I16:\n        case GGML_TYPE_I32:\n        case GGML_TYPE_I64:\n        case GGML_TYPE_F64:\n        case GGML_TYPE_COUNT:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_rope\n\nstatic float rope_yarn_ramp(const 
float low, const float high, const int i0) {\n    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);\n    return 1 - MIN(1, MAX(0, y));\n}\n\n// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn\n// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.\nstatic void rope_yarn(\n    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,\n    float * cos_theta, float * sin_theta) {\n    // Get n-d rotational scaling corrected for extrapolation\n    float theta_interp = freq_scale * theta_extrap;\n    float theta = theta_interp;\n    if (ext_factor != 0.0f) {\n        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;\n        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;\n\n        // Get n-d magnitude scaling corrected for interpolation\n        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);\n    }\n    *cos_theta = cosf(theta) * mscale;\n    *sin_theta = sinf(theta) * mscale;\n}\n\nstatic void ggml_rope_cache_init(\n     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,\n     float * cache, float sin_sign, float theta_scale) {\n    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py\n    float theta = theta_base;\n    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {\n        const float ff = freq_factors ? 
freq_factors[i0/2] : 1.0f;\n        rope_yarn(\n            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]\n        );\n        cache[i0 + 1] *= sin_sign;\n\n        theta *= theta_scale;\n    }\n}\n\nstatic void ggml_mrope_cache_init(\n     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,\n     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,\n     float * cache, float sin_sign, float theta_scale) {\n    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py\n    float theta_t = theta_base_t;\n    float theta_h = theta_base_h;\n    float theta_w = theta_base_w;\n    float theta_e = theta_base_e;  // extra position id for vision encoder\n    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];\n    int sec_w = sections[1] + sections[0];\n    int sec_e = sections[2] + sec_w;\n    GGML_ASSERT(sect_dims <= ne0);\n\n    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {\n        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;\n\n        int sector = (i0 / 2) % sect_dims;\n        if (indep_sects) {\n            // compute theta independently for each dim sections\n            // (i.e. 
reset corresponding theta when `i0` go from one section to another)\n            if (sector == 0) {\n                theta_t = theta_base_t;\n            }\n            else if (sector == sections[0]) {\n                theta_h = theta_base_h;;\n            }\n            else if (sector == sec_w) {\n                theta_w = theta_base_w;\n            }\n            else if (sector == sec_e) {\n                theta_e = theta_base_e;\n            }\n        }\n\n        float theta = theta_t;\n        if (is_imrope) { // qwen3vl apply interleaved mrope\n            if (sector % 3 == 1 && sector < 3 * sections[1]) {\n                theta = theta_h;\n            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {\n                theta = theta_w;\n            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {\n                theta = theta_t;\n            } else {\n                theta = theta_e;\n            }\n        } else {\n            if (sector >= sections[0] && sector < sec_w) {\n                theta = theta_h;\n            }\n            else if (sector >= sec_w && sector < sec_w + sections[2]) {\n                theta = theta_w;\n            }\n            else if (sector >= sec_w + sections[2]) {\n                theta = theta_e;\n            }\n        }\n\n        rope_yarn(\n            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]\n        );\n        cache[i0 + 1] *= sin_sign;\n\n        theta_t *= theta_scale;\n        theta_w *= theta_scale;\n        theta_h *= theta_scale;\n        theta_e *= theta_scale;\n    }\n}\n\n\ntemplate<typename T>\nstatic void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {\n  for (int64_t i0 = 0; i0 < n; i0 += 2) {\n    const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2\n\n    const float cos_theta = cache[i0 + 
0];\n    const float sin_theta = cache[i0 + 1];\n\n    const T * const src = src_data + ic;\n    T * dst             = dst_data + ic;\n\n    const float x0 = type_conversion_table<T>::to_f32(src[0]);\n    const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);\n\n    dst[0]        = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);\n    dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);\n  }\n}\n\ntemplate<typename T> //float or ggml_fp16_t\nstatic void ggml_compute_forward_rope_flt(\n        const ggml_compute_params * params,\n        ggml_tensor * dst,\n        const bool forward) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_I32);\n\n    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;\n    int sections[4];\n\n    //const int n_past     = ((int32_t *) dst->op_params)[0];\n    const int n_dims     = ((int32_t *) dst->op_params)[1];\n    const int mode       = ((int32_t *) dst->op_params)[2];\n    //const int n_ctx      = ((int32_t *) dst->op_params)[3];\n    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];\n\n    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));\n    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));\n    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));\n    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));\n    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));\n    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));\n    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    //printf(\"ne0: %d, ne1: %d, ne2: %d, ne3: %d\\n\", ne0, ne1, ne2, ne3);\n    //printf(\"n_past = %d, 
ne2 = %d\\n\", n_past, ne2);\n\n    GGML_ASSERT(nb0 == nb00);\n    GGML_ASSERT(nb0 == sizeof(T));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr = ggml_nrows(dst);\n\n    GGML_ASSERT(n_dims <= ne0);\n    GGML_ASSERT(n_dims % 2 == 0);\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    // row index used to determine which thread to use\n    int ir = 0;\n\n    const float theta_scale = powf(freq_base, -2.0f/n_dims);\n\n    float corr_dims[2];\n    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);\n\n    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope\n    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope\n    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n\n    if (mrope_used) {\n        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);\n    }\n\n    if (is_vision) {\n        GGML_ASSERT(n_dims == ne0/2);\n    }\n\n    const float * freq_factors = NULL;\n    if (src2 != NULL) {\n        GGML_ASSERT(src2->type == GGML_TYPE_F32);\n        GGML_ASSERT(src2->ne[0] >= n_dims / 2);\n        freq_factors = (const float *) src2->data;\n    }\n\n    // backward process uses inverse rotation by cos and sin.\n    // cos and sin build a rotation matrix, where the inverse is the transpose.\n    // this essentially just switches the sign of sin.\n    const float sin_sign = forward ? 
1.0f : -1.0f;\n\n    const int32_t * pos = (const int32_t *) src1->data;\n\n    int64_t last_i2 = -1;\n\n    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch\n        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len\n            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads\n                if (ir++ < ir0) continue; // skip rows mapped to other threads\n                if (ir   > ir1) break;\n\n                float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;\n                if (last_i2 != i2) {\n                    if (!mrope_used) {\n                        const int64_t p = pos[i2];\n                        ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);\n                    }\n                    else {\n                        const int64_t p_t = pos[i2];\n                        const int64_t p_h = pos[i2 + ne2];\n                        const int64_t p_w = pos[i2 + ne2 * 2];\n                        const int64_t p_e = pos[i2 + ne2 * 3];\n                        ggml_mrope_cache_init(\n                            p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,\n                            freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);\n                    }\n\n                    last_i2 = i2;\n                }\n\n                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);\n                T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1);\n\n                switch (mode) {\n                    case GGML_ROPE_TYPE_NORMAL:\n                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);\n                        break;\n                    case GGML_ROPE_TYPE_NEOX:\n                    case GGML_ROPE_TYPE_MROPE:\n                    case GGML_ROPE_TYPE_IMROPE:\n                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);\n 
                       break;\n                    case GGML_ROPE_TYPE_VISION:\n                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);\n                        break;\n                    default:\n                        GGML_ABORT(\"rope type not supported\");\n                }\n\n                if (!is_vision) {\n                    // fill the remain channels with data from src tensor\n                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {\n                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n                        T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n                        dst_data[0] = src[0];\n                        dst_data[1] = src[1];\n                    }\n                }\n            } //attn-heads\n        }\n    }\n}\n\nvoid ggml_compute_forward_rope(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_rope_flt<float>(params, dst, true);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_rope_back\n\nvoid ggml_compute_forward_rope_back(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_rope_flt<float>(params, dst, false);\n            } break;\n        
default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_conv_transpose_1d\n\nstatic void ggml_compute_forward_conv_transpose_1d_f16_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nk = ne00*ne01*ne02;\n\n    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    if (ith == 0) {\n        memset(params->wdata, 0, params->wsize);\n\n        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)\n        {\n            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;\n\n            for (int64_t i02 = 0; i02 < ne02; i02++) {\n                for (int64_t i01 = 0; i01 < ne01; i01++) {\n                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);\n                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;\n                    for (int64_t i00 = 0; i00 < ne00; i00++) {\n                        dst_data[i00*ne02 + i02] = src[i00];\n                    }\n                }\n            }\n        }\n\n        // permute source data (src1) from (L x Cin) to (Cin x L)\n        {\n            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;\n            ggml_fp16_t * dst_data = wdata;\n\n            for (int64_t i11 = 0; i11 < ne11; i11++) {\n                const float * const src = (float *)((char *) src1->data + i11*nb11);\n                for (int64_t i10 = 0; i10 < ne10; i10++) {\n                    dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);\n                }\n   
         }\n        }\n\n        // need to zero dst since we are accumulating into it\n        memset(dst->data, 0, ggml_nbytes(dst));\n    }\n    ggml_barrier(params->threadpool);\n\n    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];\n\n    // total rows in dst\n    const int nr = ne1;\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;\n    ggml_fp16_t * const wdata_src = wdata + nk;\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float * dst_data = (float *)((char *) dst->data + i1*nb1);\n        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;\n        for (int i10 = 0; i10 < ne10; i10++) {\n            const int i1n = i10*ne11;\n            for (int i00 = 0; i00 < ne00; i00++) {\n                float v = 0;\n                ggml_vec_dot_f16(ne02, &v, 0,\n                        (ggml_fp16_t *)    wdata_src + i1n, 0,\n                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);\n                dst_data[i10*s0 + i00] += v;\n            }\n        }\n    }\n}\n\nstatic void ggml_compute_forward_conv_transpose_1d_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nk = ne00*ne01*ne02;\n\n    GGML_ASSERT(nb00 == sizeof(float));\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    if (ith == 0) {\n        memset(params->wdata, 0, params->wsize);\n\n        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)\n        {\n            float * const 
wdata = (float *) params->wdata + 0;\n\n            for (int64_t i02 = 0; i02 < ne02; i02++) {\n                for (int64_t i01 = 0; i01 < ne01; i01++) {\n                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);\n                    float * dst_data = wdata + i01*ne00*ne02;\n                    for (int64_t i00 = 0; i00 < ne00; i00++) {\n                        dst_data[i00*ne02 + i02] = src[i00];\n                    }\n                }\n            }\n        }\n\n        // prepare source data (src1)\n        {\n            float * const wdata = (float *) params->wdata + nk;\n            float * dst_data = wdata;\n\n            for (int64_t i11 = 0; i11 < ne11; i11++) {\n                const float * const src = (float *)((char *) src1->data + i11*nb11);\n                for (int64_t i10 = 0; i10 < ne10; i10++) {\n                    dst_data[i10*ne11 + i11] = src[i10];\n                }\n            }\n        }\n\n        // need to zero dst since we are accumulating into it\n        memset(dst->data, 0, ggml_nbytes(dst));\n    }\n    ggml_barrier(params->threadpool);\n\n    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];\n\n    // total rows in dst\n    const int nr = ne1;\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    float * const wdata     = (float *) params->wdata + 0;\n    float * const wdata_src = wdata + nk;\n\n    for (int i1 = ir0; i1 < ir1; i1++) {\n        float * dst_data = (float *)((char *) dst->data + i1*nb1);\n        float * wdata_kernel = wdata + i1*ne02*ne00;\n        for (int i10 = 0; i10 < ne10; i10++) {\n            const int i1n = i10*ne11;\n            for (int i00 = 0; i00 < ne00; i00++) {\n                float v = 0;\n                ggml_vec_dot_f32(ne02, &v, 0,\n                        wdata_src + i1n, 0,\n                        wdata_kernel + 
i00*ne02, 0, 1);\n                dst_data[i10*s0 + i00] += v;\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_conv_transpose_1d(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_conv_transpose_1d_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_im2col_f32\n// src0: kernel [OC, IC, KH, KW]\n// src1: image [N, IC, IH, IW]\n// dst:  result [N, OH, OW, IC*KH*KW]\nstatic void ggml_compute_forward_im2col_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];\n    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];\n    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];\n    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];\n    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];\n    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t N  = is_2D ? ne13 : ne12;\n    const int64_t IC = is_2D ? ne12 : ne11;\n    const int64_t IH = is_2D ? ne11 : 1;\n    const int64_t IW = ne10;\n\n    const int64_t KH = is_2D ? ne01 : 1;\n    const int64_t KW = ne00;\n\n    const int64_t OH = is_2D ? 
ne2 : 1;\n    const int64_t OW = ne1;\n\n    int ofs0 = is_2D ? nb13 : nb12;\n    int ofs1 = is_2D ? nb12 : nb11;\n\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]\n    {\n        float * const wdata = (float *) dst->data;\n\n        for (int64_t in = 0; in < N; in++) {\n            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1\n                for (int64_t iow = 0; iow < OW; iow++) {\n                    for (int64_t iic = ith; iic < IC; iic += nth) {\n\n                        // micro kernel\n                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]\n                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]\n\n                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1\n                            for (int64_t ikw = 0; ikw < KW; ikw++) {\n                                const int64_t iiw = iow*s0 + ikw*d0 - p0;\n                                const int64_t iih = ioh*s1 + ikh*d1 - p1;\n\n                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {\n                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;\n                                } else {\n                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\n\n// ggml_compute_forward_im2col_f16\n// src0: kernel [OC, IC, KH, KW]\n// src1: image [N, IC, IH, IW]\n// dst:  result [N, OH, OW, IC*KH*KW]\nstatic void ggml_compute_forward_im2col_f16(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == 
GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F16);\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];\n    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];\n    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];\n    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];\n    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];\n    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t N  = is_2D ? ne13 : ne12;\n    const int64_t IC = is_2D ? ne12 : ne11;\n    const int64_t IH = is_2D ? ne11 : 1;\n    const int64_t IW = ne10;\n\n    const int64_t KH = is_2D ? ne01 : 1;\n    const int64_t KW = ne00;\n\n    const int64_t OH = is_2D ? ne2 : 1;\n    const int64_t OW = ne1;\n\n    int ofs0 = is_2D ? nb13 : nb12;\n    int ofs1 = is_2D ? nb12 : nb11;\n\n    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));\n    GGML_ASSERT(nb10 == ggml_type_size(src1->type));\n\n    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]\n    {\n        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;\n\n        for (int64_t in = 0; in < N; in++) {\n            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1\n                for (int64_t iow = 0; iow < OW; iow++) {\n                    for (int64_t iic = ith; iic < IC; iic += nth) {\n\n                        // micro kernel\n                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]\n                        const float * const src_data_f32 = src1->type == GGML_TYPE_F32\n                            ? 
(const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)\n                            : nullptr; // [IH, IW]\n                        const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16\n                            ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)\n                            : nullptr; // [IH, IW]\n\n                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1\n                            for (int64_t ikw = 0; ikw < KW; ikw++) {\n                                const int64_t iiw = iow*s0 + ikw*d0 - p0;\n                                const int64_t iih = ioh*s1 + ikh*d1 - p1;\n\n                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {\n                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;\n                                } else {\n                                    if (src_data_f32 != nullptr) {\n                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);\n                                    } else {\n                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_im2col(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n    switch (dst->type) {\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_im2col_f16(params, dst);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_im2col_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_im2col_back_f32\n\nvoid 
ggml_compute_forward_im2col_back_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output\n    const ggml_tensor * src1 = dst->src[1]; // convolution kernel\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];\n    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];\n    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];\n    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];\n    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];\n    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t N  = is_2D ? ne3 : ne2;\n    const int64_t IC = is_2D ? ne2 : ne1;\n    const int64_t IH = is_2D ? ne1 : 1;\n    const int64_t IW = ne0;\n\n    const int64_t KH = is_2D ? ne11 : 1;\n    const int64_t KW = ne10;\n\n    const int64_t OH = is_2D ? ne02 : 1;\n    const int64_t OW = ne01;\n\n    int ofs0 = is_2D ? nb3 : nb2;\n    int ofs1 = is_2D ? 
nb2 : nb1;\n\n    GGML_ASSERT(nb0  == sizeof(float));\n\n    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]\n    {\n        float * const wdata = (float *) dst->data;\n\n        for (int64_t in = 0; in < N; in++) {\n            for (int64_t iic = ith; iic < IC; iic += nth) {\n                for (int64_t iih = 0; iih < IH; iih++) {\n                    for (int64_t iiw = 0; iiw < IW; iiw++) {\n\n                        // micro kernel\n                        float grad = 0.0f;\n                        for (int64_t ikh = 0; ikh < KH; ikh++) {\n                            for (int64_t ikw = 0; ikw < KW; ikw++) {\n                                // For s0 > 1 some values were skipped over in the forward pass.\n                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.\n                                const int64_t tmpw = (iiw + p0 - ikw*d0);\n                                if (tmpw % s0 != 0) {\n                                    continue;\n                                }\n                                const int64_t iow = tmpw / s0;\n\n                                // Equivalent logic as above except for s1.\n                                int64_t ioh;\n                                if (is_2D) {\n                                    const int64_t tmph = iih + p1 - ikh*d1;\n\n                                    if (tmph % s1 != 0) {\n                                        continue;\n                                    }\n\n                                    ioh = tmph / s1;\n                                } else {\n                                    ioh = 0;\n                                }\n\n                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {\n                                    continue;\n                                }\n\n                                const float * const grad_in = (const float *) src0->data\n                                    + 
(in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]\n                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];\n                            }\n                        }\n                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]\n                        dst_data[iih*IW + iiw] = grad;\n                    }\n                }\n            }\n        }\n    }\n}\n\n\n// ggml_compute_forward_im2col_3d_f16\n// src0: kernel [OC*IC, KD, KH, KW]\n// src1: image [N*IC, ID, IH, IW]\n// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]\nstatic void ggml_compute_forward_im2col_3d_f16(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F16);\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];\n    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];\n    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];\n    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];\n    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];\n    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];\n    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];\n    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];\n    const int32_t IC = ((const int32_t *)(dst->op_params))[9];\n\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t N  = ne13 / IC;\n    const int64_t ID = ne12;\n    const int64_t IH = ne11;\n    const int64_t IW = ne10;\n\n    const int64_t OC = ne03 / IC;\n    GGML_UNUSED(OC);\n    const int64_t KD = ne02;\n    const int64_t KH = ne01;\n    const int64_t KW = ne00;\n\n    
const int64_t OD = ne3 / N;\n    const int64_t OH = ne2;\n    const int64_t OW = ne1;\n    const int64_t OH_OW = OH*OW;\n    const int64_t KD_KH_KW = KD*KH*KW;\n    const int64_t KH_KW = KH*KW;\n    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;\n\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]\n    {\n        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;\n\n        for (int64_t in = 0; in < N; in++) {\n            for (int64_t iod = 0; iod < OD; iod++) {\n                for (int64_t ioh = 0; ioh < OH; ioh++) {\n                    for (int64_t iow = 0; iow < OW; iow++) {\n                        for (int64_t iic = ith; iic < IC; iic += nth) {\n\n                            // micro kernel\n                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]\n                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]\n\n                            for (int64_t ikd = 0; ikd < KD; ikd++) {\n                                for (int64_t ikh = 0; ikh < KH; ikh++) {\n                                    for (int64_t ikw = 0; ikw < KW; ikw++) {\n                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;\n                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;\n                                        const int64_t iid = iod*s2 + ikd*d2 - p2;\n\n                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {\n                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;\n                                        } else {\n                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]\n                                            
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);\n                                        }\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\n// ggml_compute_forward_im2col_3d_f32\n// src0: kernel [OC*IC, KD, KH, KW]\n// src1: image [N*IC, ID, IH, IW]\n// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]\nstatic void ggml_compute_forward_im2col_3d_f32(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];\n    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];\n    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];\n    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];\n    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];\n    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];\n    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];\n    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];\n    const int32_t IC = ((const int32_t *)(dst->op_params))[9];\n\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t N  = ne13 / IC;\n    const int64_t ID = ne12;\n    const int64_t IH = ne11;\n    const int64_t IW = ne10;\n\n    const int64_t OC = ne03 / IC;\n    GGML_UNUSED(OC);\n    const int64_t KD = ne02;\n    const int64_t KH = ne01;\n    const int64_t KW = ne00;\n\n    const int64_t OD = ne3 / N;\n    const int64_t OH = ne2;\n    const int64_t OW = ne1;\n\n    const int64_t OH_OW = OH*OW;\n    const int64_t KD_KH_KW 
= KD*KH*KW;\n    const int64_t KH_KW = KH*KW;\n    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;\n\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]\n    {\n        float * const wdata = (float *) dst->data;\n\n        for (int64_t in = 0; in < N; in++) {\n            for (int64_t iod = 0; iod < OD; iod++) {\n                for (int64_t ioh = 0; ioh < OH; ioh++) {\n                    for (int64_t iow = 0; iow < OW; iow++) {\n                        for (int64_t iic = ith; iic < IC; iic += nth) {\n\n                            // micro kernel\n                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]\n                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]\n\n                            for (int64_t ikd = 0; ikd < KD; ikd++) {\n                                for (int64_t ikh = 0; ikh < KH; ikh++) {\n                                    for (int64_t ikw = 0; ikw < KW; ikw++) {\n                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;\n                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;\n                                        const int64_t iid = iod*s2 + ikd*d2 - p2;\n\n                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {\n                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;\n                                        } else {\n                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]\n                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;\n                                        }\n                                    }\n     
                           }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\n\nvoid ggml_compute_forward_im2col_3d(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n    switch (dst->type) {\n        case GGML_TYPE_F16:\n            {\n                ggml_compute_forward_im2col_3d_f16(params, dst);\n            } break;\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_im2col_3d_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nstatic void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,\n                              void * a, void * b, float * c) {\n    const ggml_type_traits * traits = ggml_get_type_traits(type);\n    struct ggml_tensor src1 = {};\n    src1.type  = type;\n    src1.ne[0] = k;\n    src1.ne[1] = m;\n    src1.ne[2] = 1;\n    src1.ne[3] = 1;\n    src1.nb[0] = traits->type_size;\n    src1.nb[1] = k * traits->type_size;\n    src1.nb[2] = src1.nb[1];\n    src1.nb[3] = src1.nb[2];\n    src1.data  = a;\n\n    struct ggml_tensor src0 = {};\n    src0.type  = type;\n    src0.ne[0] = k;\n    src0.ne[1] = n;\n    src0.ne[2] = 1;\n    src0.ne[3] = 1;\n    src0.nb[0] = traits->type_size;\n    src0.nb[1] = k * traits->type_size;\n    src0.nb[2] = src0.nb[1];\n    src0.nb[3] = src0.nb[2];\n    src0.data  = b;\n\n    struct ggml_tensor dst = {};\n    dst.ne[0] = n;\n    dst.ne[1] = m;\n    dst.ne[2] = 1;\n    dst.ne[3] = 1;\n    dst.nb[0] = sizeof(float);\n    dst.nb[1] = n * sizeof(float);\n    dst.nb[2] = dst.nb[1];\n    dst.nb[3] = dst.nb[2];\n    dst.data  = c;\n    dst.src[0] = &src0;\n    dst.src[1] = &src1;\n\n    ggml_compute_forward_mul_mat(params, &dst);\n}\n\nstatic inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {\n    return (coord  + 
size) % size; // adding size avoids negative number weirdness\n}\n\n// ggml_compute_forward_conv_2d\n\n\nstatic void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,\n                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]\n                                              const ggml_tensor *         src,     // [W, H, C, N]\n                                              ggml_tensor *               dst,     // [OW, OH, OC, N]\n                                              ggml_type                   kernel_type) {\n\n    GGML_ASSERT(ggml_is_contiguous(kernel));\n    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);\n    GGML_ASSERT(kernel->type == kernel_type);\n\n    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);\n\n    const int32_t stride_x   = dst->op_params[0];\n    const int32_t stride_y   = dst->op_params[1];\n    const int32_t pad_x      = dst->op_params[2];\n    const int32_t pad_y      = dst->op_params[3];\n    const int32_t dilation_x = dst->op_params[4];\n    const int32_t dilation_y = dst->op_params[5];\n\n    const int64_t c_in  = src->ne[2];\n    const int64_t c_out = kernel->ne[3];\n    GGML_ASSERT(c_in == kernel->ne[2]);\n\n    const int64_t src_w = src->ne[0];\n    const int64_t src_h = src->ne[1];\n    const int64_t knl_w = kernel->ne[0];\n    const int64_t knl_h = kernel->ne[1];\n    const int64_t dst_w = dst->ne[0];\n    const int64_t dst_h = dst->ne[1];\n\n    const float * src_data = (float *) src->data;\n    void  * knl_data       = kernel->data;\n    float * dst_data       = (float *) dst->data;\n\n    const int64_t knl_n           = knl_w * knl_h * c_in;\n    const int64_t patch_total     = dst->ne[3] * dst_w * dst_h;\n\n    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);\n    const int64_t batch_size        = params->wsize / space_per_patch;\n    const int64_t patches_per_batch = batch_size > 
8 ? (batch_size / 8) * 8 : batch_size;\n    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;\n\n    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);\n\n    void * tmp = params->wdata;\n\n    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {\n\n        const int64_t patch_start_batch = batch_i * patches_per_batch;\n        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,\n                                              patch_total);\n        const int64_t patch_n           = patch_end_batch - patch_start_batch;\n\n        const int64_t patch_per_thread  = (patch_n + params->nth - 1) / params->nth;\n        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;\n        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);\n\n        //im2col for a patch\n        for (int64_t p = patch_start; p < patch_end; ++p) {\n            const int64_t  batch_n     =  p / (dst_w * dst_h);\n            const int64_t  src_x       = (p / dst_w) % dst_h;\n            const int64_t  src_y       =  p % dst_w;\n\n            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);\n            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;\n\n            for (int64_t ic = 0; ic < c_in; ++ic) {\n                for (int64_t ky = 0; ky < knl_h; ++ky) {\n                    for (int64_t kx = 0; kx < knl_w; ++kx) {\n                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;\n                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;\n\n                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;\n\n                        float src_val;\n                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {\n                            src_val = 0.0f;\n                       
 } else {\n                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);\n                            src_val               = *src_ptr;\n                        }\n\n                        char * element_ptr = dst_row + dst_idx * traits->type_size;\n                        if (kernel_type == GGML_TYPE_F32) {\n                            *(float *) element_ptr = src_val;\n                        } else if (kernel_type == GGML_TYPE_F16) {\n                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);\n                        }\n                    }\n                }\n            }\n        }   // patches handled by this thread\n\n        ggml_barrier(params->threadpool);\n\n        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);\n\n        GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);\n\n        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]\n        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);\n\n        ggml_barrier(params->threadpool);\n\n\n        //permute back [OC, N, OH, OW] to [N, OC, OH, OW]\n        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;\n        const int64_t permute_start = params->ith * permute_per_thread;\n        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);\n\n        for (int64_t i = permute_start; i < permute_end; ++i) {\n            const int64_t p       = patch_start_batch + i;\n            const int64_t batch_n = p / (dst_w * dst_h);\n            const int64_t dst_y   = (p / dst_w) % dst_h;\n            const int64_t dst_x   = p % dst_w;\n\n            for (int64_t oc = 0; oc < c_out; ++oc) {\n                const float value = gemm_output[i * c_out + oc];\n                float * dst_ptr = (float 
*)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);\n                *dst_ptr = value;\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_conv_2d(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);\n}\n\n// ggml_compute_forward_conv_3d\n\nstatic void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,\n                                              const ggml_tensor *         kernel,\n                                              const ggml_tensor *         src,\n                                              ggml_tensor *               dst,\n                                              ggml_type                   kernel_type) {\n\n    GGML_ASSERT(ggml_is_contiguous(kernel));\n    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);\n    GGML_ASSERT(kernel->type == kernel_type);\n\n    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);\n\n    const int32_t s0 = dst->op_params[0];\n    const int32_t s1 = dst->op_params[1];\n    const int32_t s2 = dst->op_params[2];\n    const int32_t p0 = dst->op_params[3];\n    const int32_t p1 = dst->op_params[4];\n    const int32_t p2 = dst->op_params[5];\n    const int32_t d0 = dst->op_params[6];\n    const int32_t d1 = dst->op_params[7];\n    const int32_t d2 = dst->op_params[8];\n    const int32_t c  = dst->op_params[9];\n    const int32_t n  = dst->op_params[10];\n    const int32_t oc = dst->op_params[11];\n\n    const int64_t src_w = src->ne[0];\n    const int64_t src_h = src->ne[1];\n    const int64_t src_d = src->ne[2];\n    const int64_t knl_w = kernel->ne[0];\n    const int64_t knl_h = kernel->ne[1];\n    const int64_t knl_d = kernel->ne[2];\n    const int64_t dst_w = dst->ne[0];\n    const int64_t dst_h = 
dst->ne[1];\n    const int64_t dst_d = dst->ne[2];\n\n    const float * src_data = (float *) src->data;\n    void  * knl_data       = kernel->data;\n    float * dst_data       = (float *) dst->data;\n\n    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;\n    const int64_t knl_n_total       = knl_n_per_channel * c;\n    const int64_t patch_total       = n * dst_w * dst_h * dst_d;\n\n    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);\n    const int64_t batch_size        = params->wsize / space_per_patch;\n    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;\n    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;\n\n    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);\n\n    void * tmp = params->wdata;\n\n    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {\n        const int64_t patch_start_batch = batch_i * patches_per_batch;\n        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);\n        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;\n\n        const int64_t patch_per_thread  = (patch_n_in_batch + params->nth - 1) / params->nth;\n        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;\n        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);\n\n        for (int64_t p = patch_start; p < patch_end; ++p) {\n            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);\n            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);\n            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);\n            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);\n            const int64_t dst_y      = p_in_depth / dst_w;\n            const int64_t dst_x      = p_in_depth % dst_w;\n\n            char * dst_row = (char *) tmp + (p % patches_per_batch) * 
knl_n_total * traits->type_size;\n\n            for (int64_t ic = 0; ic < c; ++ic) {\n                for (int64_t kz = 0; kz < knl_d; ++kz) {\n                    for (int64_t ky = 0; ky < knl_h; ++ky) {\n                        for (int64_t kx = 0; kx < knl_w; ++kx) {\n                            const int64_t sz = dst_z * s2 + kz * d2 - p2;\n                            const int64_t sy = dst_y * s1 + ky * d1 - p1;\n                            const int64_t sx = dst_x * s0 + kx * d0 - p0;\n\n                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;\n\n                            float src_val;\n                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {\n                                src_val = 0.0f;\n                            } else {\n                                const int64_t cn_idx = batch_idx * c + ic;\n                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);\n                                src_val = *src_ptr;\n                            }\n\n                            char * element_ptr = dst_row + dst_idx * traits->type_size;\n                            if (kernel_type == GGML_TYPE_F32) {\n                                *(float *)element_ptr = src_val;\n                            } else if (kernel_type == GGML_TYPE_F16) {\n                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);\n                            }\n                        }\n                    }\n                }\n            }\n        }\n\n        ggml_barrier(params->threadpool);\n\n        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);\n        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);\n\n        
ggml_barrier(params->threadpool);\n\n        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;\n        const int64_t permute_start = params->ith * permute_per_thread;\n        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);\n\n        for (int64_t i = permute_start; i < permute_end; ++i) {\n            const int64_t p = patch_start_batch + i;\n            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);\n            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);\n            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);\n            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);\n            const int64_t dst_y      = p_in_depth / dst_w;\n            const int64_t dst_x      = p_in_depth % dst_w;\n\n            for (int64_t ioc = 0; ioc < oc; ++ioc) {\n                const float value = gemm_output[i * oc + ioc];\n                const int64_t ocn_idx = batch_idx * oc + ioc;\n                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);\n                *dst_ptr = value;\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_conv_3d(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);\n}\n\n// ggml_compute_forward_conv_transpose_2d\n\nvoid ggml_compute_forward_conv_transpose_2d(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int ith = params->ith;\n    const int 
nth = params->nth;\n\n    const int nk = ne00*ne01*ne02*ne03;\n\n    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    if (ith == 0) {\n        memset(params->wdata, 0, params->wsize);\n\n        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)\n        {\n            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;\n\n            for (int64_t i03 = 0; i03 < ne03; i03++) {\n                for (int64_t i02 = 0; i02 < ne02; i02++) {\n                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);\n                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;\n                    for (int64_t i01 = 0; i01 < ne01; i01++) {\n                        for (int64_t i00 = 0; i00 < ne00; i00++) {\n                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];\n                        }\n                    }\n                }\n            }\n        }\n\n        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)\n        {\n            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;\n            for (int i12 = 0; i12 < ne12; i12++) {\n                for (int i11 = 0; i11 < ne11; i11++) {\n                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);\n                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;\n                    for (int i10 = 0; i10 < ne10; i10++) {\n                        dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);\n                    }\n                }\n            }\n        }\n\n        memset(dst->data, 0, ggml_nbytes(dst));\n    }\n    ggml_barrier(params->threadpool);\n\n    const int32_t stride = ggml_get_op_params_i32(dst, 0);\n\n    // total patches in dst\n    const int np = ne2;\n\n    // patches per thread\n    const int dp = (np + nth - 1)/nth;\n\n    // patch range 
for this thread\n    const int ip0 = dp*ith;\n    const int ip1 = MIN(ip0 + dp, np);\n\n    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;\n    ggml_fp16_t * const wdata_src = wdata + nk;\n\n    for (int i2 = ip0; i2 < ip1; i2++) { // Cout\n        float * dst_data = (float *)((char *) dst->data + i2*nb2);\n        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;\n        for (int i11 = 0; i11 < ne11; i11++) {\n            for (int i10 = 0; i10 < ne10; i10++) {\n                const int i1n = i11*ne10*ne12 + i10*ne12;\n                for (int i01 = 0; i01 < ne01; i01++) {\n                    for (int i00 = 0; i00 < ne00; i00++) {\n                        float v = 0;\n                        ggml_vec_dot_f16(ne03, &v, 0,\n                                wdata_src + i1n, 0,\n                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);\n                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;\n                    }\n                }\n            }\n        }\n    }\n}\n\n// ggml_compute_forward_conv_2d_dw\n\nstruct ggml_conv_2d_dw_params {\n    int64_t channels;\n    int64_t batch;\n    int64_t src_w;\n    int64_t src_h;\n    int64_t dst_w;\n    int64_t dst_h;\n    int64_t knl_w;\n    int64_t knl_h;\n    int stride_x;\n    int stride_y;\n    int pad_x;\n    int pad_y;\n    int dilation_x;\n    int dilation_y;\n};\n\nstatic void ggml_compute_forward_conv_2d_dw_cwhn(\n        const ggml_compute_params * params,\n        const ggml_tensor * src,\n        const ggml_tensor * kernel,\n        ggml_tensor * dst,\n        const ggml_conv_2d_dw_params & p) {\n\n    const int64_t c = p.channels;\n    const float * knl_data = (const float *)kernel->data;\n\n    const int64_t rows_total = p.dst_h * p.batch;\n    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;\n    const int64_t row_start = params->ith * rows_per_thread;\n    const int64_t row_end = MIN(row_start + 
rows_per_thread, rows_total);\n\n#ifdef GGML_SIMD\n    #if defined(__ARM_FEATURE_SVE)\n        const int64_t pkg_size = svcntw();\n    #else\n        const int64_t pkg_size = GGML_F32_EPR;\n    #endif\n    const int64_t pkg_count = c / pkg_size;\n    const int64_t c_pkg_end = pkg_count * pkg_size;\n#else\n    const int64_t c_pkg_end = 0;\n#endif\n\n    for (int64_t row = row_start; row < row_end; ++row) {\n        const int64_t dst_y = row % p.dst_h;\n        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;\n        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {\n            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;\n            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;\n            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;\n\n#ifdef GGML_SIMD\n            // Vectorized loop\n            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {\n                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;\n                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {\n                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;\n                    if (src_y < 0 || src_y >= p.src_h) {\n                        continue;\n                    }\n                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {\n                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;\n                        if (src_x < 0 || src_x >= p.src_w) {\n                            continue;\n                        }\n                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);\n                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);\n                        sum = GGML_F32_VEC_FMA(sum, k, s);\n                    }\n                }\n                GGML_F32_VEC_STORE(dst_data + c_i, sum);\n            }\n#endif\n            // Scalar loop\n      
      for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {\n                float sum = 0.0f;\n                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {\n                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;\n                    if (src_y < 0 || src_y >= p.src_h) {\n                        continue;\n                    }\n                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {\n                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;\n                        if (src_x < 0 || src_x >= p.src_w) {\n                            continue;\n                        }\n                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]\n                             * src_data[(src_y * p.src_w + src_x) * c + c_i];\n                    }\n                }\n                dst_data[c_i] = sum;\n            }\n        }\n    }\n}\n\nstatic void ggml_compute_forward_conv_2d_dw_whcn(\n        const ggml_compute_params * params,\n        const ggml_tensor * src,\n        const ggml_tensor * kernel,\n        ggml_tensor * dst,\n        const ggml_conv_2d_dw_params & p) {\n\n    const int64_t n = p.channels * p.batch;\n    const int64_t per_thread = (n + params->nth - 1) / params->nth;\n    const int64_t start = params->ith * per_thread;\n    const int64_t end = MIN(start + per_thread, n);\n\n    for (int64_t i = start; i < end; ++i) {\n        const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;\n        const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;\n        float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;\n\n        for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {\n            for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {\n\n                float sum = 0.0f;\n                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {\n                    const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - 
p.pad_y;\n                    if (src_y < 0 || src_y >= p.src_h) {\n                        continue;\n                    }\n                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {\n                        const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;\n                        if (src_x < 0 || src_x >= p.src_w) {\n                            continue;\n                        }\n                        sum += knl_data[knl_y * p.knl_w + knl_x]\n                             * src_data[src_y * p.src_w + src_x];\n                    }\n                }\n                dst_data[dst_y * p.dst_w + dst_x] = sum;\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_conv_2d_dw(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * kernel = dst->src[0];\n    const ggml_tensor * src = dst->src[1];\n    ggml_conv_2d_dw_params p;\n    p.channels = src->ne[2];\n    p.batch = src->ne[3];\n    p.src_w = src->ne[0];\n    p.src_h = src->ne[1];\n    p.dst_w = dst->ne[0];\n    p.dst_h = dst->ne[1];\n    p.knl_w = kernel->ne[0];\n    p.knl_h = kernel->ne[1];\n    p.stride_x = dst->op_params[0];\n    p.stride_y = dst->op_params[1];\n    p.pad_x = dst->op_params[2];\n    p.pad_y = dst->op_params[3];\n    p.dilation_x = dst->op_params[4];\n    p.dilation_y = dst->op_params[5];\n\n    GGML_ASSERT(kernel->ne[3] == p.channels);\n    GGML_ASSERT(dst->ne[3] == p.batch);\n\n    if (ggml_is_contiguous(src)) {\n        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);\n    } else if (ggml_is_contiguous_channels(src)) {\n        // kernel should also have channels most contiguous in memory\n        GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);\n        ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);\n    } else {\n        GGML_ABORT(\"non-contiguous memory layout not supported\");\n    }\n}\n\n// 
ggml_compute_forward_pool_1d_ksp\nstatic void ggml_compute_forward_pool_1d_ksp(\n        const ggml_compute_params * params,\n        const ggml_op_pool op,\n        const int k,\n        const int s,\n        const int p,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src = dst->src[0];\n\n    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    const int64_t IW = src->ne[0];\n    const int64_t OW = dst->ne[0];\n\n    const int64_t nr = ggml_nrows(src);\n\n    for (int64_t ir = 0; ir < nr; ++ir) {\n        const char * srow_bytes =            (const char *) src->data + ir * src->nb[1];\n        float      * drow       = (float *) ((      char *) dst->data + ir * dst->nb[1]);\n\n        for (int64_t ow = 0; ow < OW; ++ow) {\n            float res = 0;\n            switch (op) {\n                case GGML_OP_POOL_AVG: res = 0.0f;     break;\n                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;\n                case GGML_OP_POOL_COUNT: GGML_ABORT(\"fatal error\");\n            }\n\n            int count = 0;\n            const int base = (int) ow * s - p;\n\n            for (int ki = 0; ki < k; ++ki) {\n                const int j = base + ki;\n                if (j < 0 || j >= (int) IW) {\n                    continue;\n                }\n\n                float v;\n                if (src->type == GGML_TYPE_F32) {\n                    v = ((const float *) srow_bytes)[j];\n                } else {\n                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);\n                }\n\n                switch (op) {\n                    case GGML_OP_POOL_AVG: res += v;                break;\n                    case GGML_OP_POOL_MAX: res =  std::max(v, res); break;\n                    case GGML_OP_POOL_COUNT: GGML_ABORT(\"fatal error\");\n                }\n\n                ++count;\n            }\n\n            switch (op) {\n                case 
GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;\n                case GGML_OP_POOL_MAX:                                           break;\n                case GGML_OP_POOL_COUNT: GGML_ABORT(\"fatal error\");\n            }\n\n            drow[ow] = res;\n        }\n    }\n}\n\n// ggml_compute_forward_pool_1d\n\nvoid ggml_compute_forward_pool_1d(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const int32_t * opts = (const int32_t *)dst->op_params;\n    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);\n    const int k0 = opts[1];\n    const int s0 = opts[2];\n    const int p0 = opts[3];\n\n    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);\n}\n\n// ggml_compute_forward_pool_2d\n\nvoid ggml_compute_forward_pool_2d(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src = dst->src[0];\n\n    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    const int32_t * opts = (const int32_t *)dst->op_params;\n\n    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);\n    const int k0 = opts[1];\n    const int k1 = opts[2];\n    const int s0 = opts[3];\n    const int s1 = opts[4];\n    const int p0 = opts[5];\n    const int p1 = opts[6];\n    const char * cdata = (const char*)src->data;\n    const char * const data_end = cdata + ggml_nbytes(src);\n\n    const int64_t px = dst->ne[0];\n    const int64_t py = dst->ne[1];\n    const int64_t pa = px * py;\n\n    float * dplane = (float *)dst->data;\n\n    const int ka = k0 * k1;\n    const int offset0 = -p0;\n    const int offset1 = -p1;\n\n    while (cdata < data_end) {\n        for (int oy = 0; oy < py; ++oy) {\n            float * const drow = dplane + oy * px;\n            float * const out  = drow;\n\n            for (int ox = 0; ox < px; ++ox) {\n                float res = 0;\n                switch (op) {\n               
     case GGML_OP_POOL_AVG: res = 0;        break;\n                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;\n                    case GGML_OP_POOL_COUNT: GGML_ABORT(\"fatal error\");\n                }\n\n                const int ix = offset0 + ox * s0;\n                const int iy = offset1 + oy * s1;\n\n                for (int ky = 0; ky < k1; ++ky) {\n                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {\n                        continue;\n                    }\n\n                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));\n                    for (int kx = 0; kx < k0; ++kx) {\n                        int j = ix + kx;\n                        if (j < 0 || j >= src->ne[0]) {\n                            continue;\n                        }\n\n                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);\n                        switch (op) {\n                            case GGML_OP_POOL_AVG: res += srow_j;                break;\n                            case GGML_OP_POOL_MAX: res =  std::max(srow_j, res); break;\n                            case GGML_OP_POOL_COUNT:               GGML_ABORT(\"fatal error\");\n                        }\n                    }\n                }\n                switch (op) {\n                    case GGML_OP_POOL_AVG:           res /= ka; break;\n                    case GGML_OP_POOL_MAX:                      break;\n                    case GGML_OP_POOL_COUNT: GGML_ABORT(\"fatal error\");\n                }\n\n                out[ox] = res;\n            }\n        }\n\n        cdata  += src->nb[2];\n        dplane += pa;\n    }\n}\n\n// ggml_compute_forward_pool_2d_back\n\nvoid ggml_compute_forward_pool_2d_back(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src  = dst->src[0];\n    const ggml_tensor * dstf = dst->src[1]; // 
forward tensor of dst\n\n    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);\n\n    if (params->ith != 0) {\n        return;\n    }\n\n    const int32_t * opts = (const int32_t *)dst->op_params;\n    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);\n    const int k0 = opts[1];\n    const int k1 = opts[2];\n    const int s0 = opts[3];\n    const int s1 = opts[4];\n    const int p0 = opts[5];\n    const int p1 = opts[6];\n\n    char       * cdata  = (char       *) dst->data;\n    const char * cdataf = (const char *) dstf->data;\n    const char * const data_end = cdata + ggml_nbytes(dst);\n\n    GGML_ASSERT(params->ith == 0);\n    memset(cdata, 0, ggml_nbytes(dst));\n\n    const int64_t px = src->ne[0];\n    const int64_t py = src->ne[1];\n    const int64_t pa = px * py;\n\n    const float * splane = (const float *) src->data;\n\n    const int ka = k0 * k1;\n    const int offset0 = -p0;\n    const int offset1 = -p1;\n\n    while (cdata < data_end) {\n        for (int oy = 0; oy < py; ++oy) {\n            const float * const srow = splane + oy * px;\n            for (int ox = 0; ox < px; ++ox) {\n                const float grad0 = srow[ox];\n\n                const int ix = offset0 + ox * s0;\n                const int iy = offset1 + oy * s1;\n\n                if (op == GGML_OP_POOL_MAX) {\n                    float maxval = -FLT_MAX;\n                    int kxmax = -1;\n                    int kymax = -1;\n\n                    for (int ky = 0; ky < k1; ++ky) {\n                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {\n                            continue;\n                        }\n                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));\n                        for (int kx = 0; kx < k0; ++kx) {\n                            int j = ix + kx;\n                            if (j < 0 || j >= dst->ne[0]) {\n                                continue;\n                            }\n\n                  
          const float val = dst->type == GGML_TYPE_F32 ?\n                                ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);\n                            if (val <= maxval) {\n                                continue;\n                            }\n\n                            maxval = val;\n                            kxmax = kx;\n                            kymax = ky;\n                        }\n                    }\n\n                    if (kxmax == -1 || kymax == -1) {\n                        continue;\n                    }\n\n                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));\n                    const int j = ix + kxmax;\n                    if (dst->type == GGML_TYPE_F32) {\n                        ((float *) drow)[j] += grad0;\n                    } else {\n                        ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));\n                    }\n                } else if (op == GGML_OP_POOL_AVG) {\n                    const float grad = grad0 / ka;\n\n                    for (int ky = 0; ky < k1; ++ky) {\n                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {\n                            continue;\n                        }\n                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));\n                        for (int kx = 0; kx < k0; ++kx) {\n                            int j = ix + kx;\n                            if (j < 0 || j >= dst->ne[0]) {\n                                continue;\n                            }\n\n                            if (dst->type == GGML_TYPE_F32) {\n                                ((float *) drow)[j] += grad;\n                            } else {\n                                ((ggml_fp16_t *) drow)[j] += GGML_CPU_FP32_TO_FP16(grad);\n                            }\n                        }\n                    }\n                } 
else {\n                    GGML_ASSERT(false);\n                }\n            }\n        }\n\n        cdata  += dst->nb[2];\n        cdataf += dst->nb[2];\n        splane += pa;\n    }\n}\n\n// ggml_compute_forward_upscale\n\nstatic void ggml_compute_forward_upscale_f32(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    float sf0 = (float)ne0/src0->ne[0];\n    float sf1 = (float)ne1/src0->ne[1];\n    float sf2 = (float)ne2/src0->ne[2];\n    float sf3 = (float)ne3/src0->ne[3];\n    float pixel_offset = 0.5f;\n\n    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);\n    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);\n\n    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {\n        pixel_offset = 0.0f;\n        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;\n        sf1 = ne1 > 1 && ne01 > 1 ? 
(float)(ne1 - 1) / (ne01 - 1) : sf1;\n    }\n\n    if (mode == GGML_SCALE_MODE_NEAREST) {\n        for (int64_t i3 = 0; i3 < ne3; i3++) {\n            const int64_t i03 = i3 / sf3;\n            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {\n                const int64_t i02 = i2 / sf2;\n                for (int64_t i1 = 0; i1 < ne1; i1++) {\n                    const int64_t i01 = i1 / sf1;\n                    for (int64_t i0 = 0; i0 < ne0; i0++) {\n                        const int64_t i00 = i0 / sf0;\n\n                        const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);\n                              float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);\n\n                        *y = *x;\n                    }\n                }\n            }\n        }\n    } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {\n        // Similar to F.interpolate(..., mode=\"bilinear\", align_corners=False, antialias=True)\n        // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp\n        auto triangle_filter = [](float x) -> float {\n            return std::max(1.0f - fabsf(x), 0.0f);\n        };\n\n        // support and invscale, minimum 1 pixel for bilinear\n        const float support1  = std::max(1.0f, 1.0f / sf1);\n        const float invscale1 = 1.0f / support1;\n        const float support0  = std::max(1.0f, 1.0f / sf0);\n        const float invscale0 = 1.0f / support0;\n\n        for (int64_t i3 = 0; i3 < ne3; i3++) {\n            const int64_t i03 = i3 / sf3;\n            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {\n                const int64_t i02 = i2 / sf2;\n                for (int64_t i1 = 0; i1 < ne1; i1++) {\n                    const float y = ((float) i1 + pixel_offset) / sf1;\n                    for (int64_t i0 = 0; i0 < ne0; i0++) {\n                      
  const float x = ((float) i0 + pixel_offset) / sf0;\n\n                        // the range of source pixels that contribute\n                        const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);\n                        const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);\n                        const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);\n                        const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);\n\n                        // bilinear filter with antialiasing\n                        float val = 0.0f;\n                        float total_weight = 0.0f;\n\n                        for (int64_t sy = y_min; sy < y_max; sy++) {\n                            const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);\n\n                            for (int64_t sx = x_min; sx < x_max; sx++) {\n                                const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);\n                                const float weight = weight_x * weight_y;\n\n                                if (weight <= 0.0f) {\n                                    continue;\n                                }\n\n                                const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);\n                                val += pixel * weight;\n                                total_weight += weight;\n                            }\n                        }\n\n                        if (total_weight > 0.0f) {\n                            val /= total_weight;\n                        }\n\n                        float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);\n                        *dst_ptr = val;\n                    }\n                }\n            }\n        }\n    } else if (mode == GGML_SCALE_MODE_BILINEAR) {\n        for (int64_t 
i3 = 0; i3 < ne3; i3++) {\n            const int64_t i03 = i3 / sf3;\n            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {\n                const int64_t i02 = i2 / sf2;\n                for (int64_t i1 = 0; i1 < ne1; i1++) {\n                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;\n                    int64_t y0 = (int64_t)floorf(y);\n                    int64_t y1 = y0 + 1;\n\n                    y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));\n                    y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));\n\n                    float dy = y - (float)y0;\n                    dy = std::max(0.0f, std::min(dy, 1.0f));\n\n                    for (int64_t i0 = 0; i0 < ne0; i0++) {\n                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;\n                        int64_t x0 = (int64_t)floorf(x);\n                        int64_t x1 = x0 + 1;\n\n                        x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));\n                        x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));\n\n                        float dx = x - (float)x0;\n                        dx = std::max(0.0f, std::min(dx, 1.0f));\n\n                        // fetch the four surrounding pixel values and interpolate\n                        const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);\n                        const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);\n                        const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);\n                        const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);\n\n                        const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;\n\n                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 
+ i3*nb3);\n                        *y_dst = val;\n                    }\n                }\n            }\n        }\n    } else if (mode == GGML_SCALE_MODE_BICUBIC) {\n        // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm\n        const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)\n        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };\n        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };\n        auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {\n            const float w0 = weight2(x + 1);\n            const float w1 = weight1(x + 0);\n            const float w2 = weight1(1 - x);\n            const float w3 = weight2(2 - x);\n            return p0*w0 + p1*w1 + p2*w2 + p3*w3;\n        };\n\n        for (int64_t i3 = 0; i3 < ne3; i3++) {\n            const int64_t i03 = i3 / sf3;\n            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {\n                const int64_t i02 = i2 / sf2;\n                for (int64_t i1 = 0; i1 < ne1; i1++) {\n                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;\n                    const int64_t y0 = (int64_t)floorf(y);\n                    const float dy = y - (float)y0;\n\n                    for (int64_t i0 = 0; i0 < ne0; i0++) {\n                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;\n                        const int64_t x0 = (int64_t)floorf(x);\n                        const float dx = x - (float)x0;\n\n                        auto p = [=](int64_t x_off, int64_t y_off) -> float {\n                            int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));\n                            int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));\n                            return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);\n                        };\n\n    
                    const float val = bicubic(\n                            bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),\n                            bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),\n                            bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),\n                            bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);\n\n                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);\n                        *y_dst = val;\n                    }\n                }\n            }\n        }\n    } else {\n        GGML_ABORT(\"unsupported upscale mode\");\n    }\n}\n\nvoid ggml_compute_forward_upscale(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_upscale_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n\n// ggml_compute_forward_pad\n\ntemplate<bool circular_t>\nstatic void ggml_compute_forward_pad_f32(\n    const ggml_compute_params * params,\n          ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    assert(dst->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    float * dst_ptr = (float *) dst->data;\n    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);\n    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);\n    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);\n    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);\n    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);\n    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);\n    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);\n    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);\n\n    // TODO: optimize\n\n    for (int64_t i2 = 
0; i2 < ne2; ++i2) {\n        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {\n            for (int64_t i0 = 0; i0 < ne0; ++i0) {\n                for (int64_t i3 = 0; i3 < ne3; ++i3) {\n                    // circular means wrap around on a torus, so x and y loop around\n                    if constexpr (circular_t) {\n                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;\n                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);\n                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);\n                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);\n                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);\n\n                        const int64_t src_idx =\n                            src_i3*nb03 +\n                            src_i2*nb02 +\n                            src_i1*nb01 +\n                            src_i0*nb00;\n\n                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);\n                        dst_ptr[dst_idx] = *src_ptr;\n                    } else {\n                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;\n                        if ((i0 >= lp0 && i0 < ne0 - rp0) \\\n                            && (i1 >= lp1 && i1 < ne1 - rp1) \\\n                            && (i2 >= lp2 && i2 < ne2 - rp2) \\\n                            && (i3 >= lp3 && i3 < ne3 - rp3)) {\n                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;\n                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);\n                            dst_ptr[dst_idx] = *src_ptr;\n                        } else {\n                            dst_ptr[dst_idx] = 0;\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\n\nvoid 
ggml_compute_forward_pad(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                if (circular) {\n                    ggml_compute_forward_pad_f32<true>(params, dst);\n                } else {\n                    ggml_compute_forward_pad_f32<false>(params, dst);\n                }\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_pad_reflect_1d\n\nvoid ggml_compute_forward_pad_reflect_1d(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int32_t * opts = (const int32_t *) dst->op_params;\n    const int p0 = opts[0];\n    const int p1 = opts[1];\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    for (int64_t i3 = 0; i3 < ne3; i3++) {\n        for (int64_t i2 = 0; i2 < ne2; i2++) {\n            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {\n                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);\n                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);\n\n                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));\n\n                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }\n                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }\n            }\n        }\n    }\n}\n\n// ggml_compute_forward_roll\n\nstatic int64_t ggml_wrap_index(int64_t i, int64_t ne) {\n    if (i < 0) {\n        return i + ne;\n    } else if (i >= ne) {\n    
    return i - ne;\n    }\n    return i;\n}\n\nstatic void ggml_compute_forward_roll_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src_data = (const float *) src0->data;\n    float * dst_data = (float *) dst->data;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    const int s0 = ggml_get_op_params_i32(dst, 0);\n    const int s1 = ggml_get_op_params_i32(dst, 1);\n    const int s2 = ggml_get_op_params_i32(dst, 2);\n    const int s3 = ggml_get_op_params_i32(dst, 3);\n\n    const int64_t total = ne1 * ne2 * ne3;\n    const int64_t per_thread = (total + params->nth) / params->nth;\n    const int64_t start = params->ith * per_thread;\n    const int64_t end   = std::min(start + per_thread, total);\n\n    for (int64_t i = start; i < end; ++i) {\n        const int64_t i1 = i % ne1;\n        const int64_t i2 = (i / ne1) % ne2;\n        const int64_t i3 = i / (ne2 * ne1);\n        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);\n\n        const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);\n        const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);\n        const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);\n        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);\n\n        const int64_t s = ggml_wrap_index(-s0, ne00);\n        const int64_t n = ne00 - s;\n        ggml_vec_cpy_f32(n, dst_row,     src_row + s);\n        ggml_vec_cpy_f32(s, dst_row + n, src_row);\n    }\n}\n\nvoid ggml_compute_forward_roll(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_roll_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_arange\n\nstatic 
void ggml_compute_forward_arange_f32(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    GGML_ASSERT(dst->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const float start = ggml_get_op_params_f32(dst, 0);\n    const float stop  = ggml_get_op_params_f32(dst, 1);\n    const float step  = ggml_get_op_params_f32(dst, 2);\n\n    const int64_t steps = (int64_t) ceilf((stop - start) / step);\n\n    GGML_ASSERT(ggml_nelements(dst) == steps);\n\n    for (int64_t i = ith; i < steps; i+= nth) {\n        float value = start + step * i;\n        ((float *)dst->data)[i] = value;\n    }\n}\n\nvoid ggml_compute_forward_arange(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n    switch (dst->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_arange_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nstatic void ggml_compute_forward_timestep_embedding_f32(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    const int dim = ggml_get_op_params_i32(dst, 0);\n    const int max_period = ggml_get_op_params_i32(dst, 1);\n\n    int half = dim / 2;\n\n    for (int64_t i = 0; i < ne00; i++) {\n        float * embed_data = (float *)((char *)  dst->data +  i*nb1);\n        for (int64_t j = ith; j < half; j += nth) {\n            float timestep = ((float *)src0->data)[i];\n            float freq = (float)expf(-logf(max_period) * j / half);\n            float arg = timestep * freq;\n            embed_data[j] = cosf(arg);\n            embed_data[j + half] = sinf(arg);\n        }\n        if (dim % 2 != 0 && ith == 0) {\n            embed_data[2 * 
half] = 0.f;\n        }\n    }\n}\n\nvoid ggml_compute_forward_timestep_embedding(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_timestep_embedding_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_argsort\n\ntemplate<enum ggml_sort_order order>\nstruct cmp_argsort {\n    const float * data;\n    bool operator()(int32_t a, int32_t b) const {\n        if constexpr (order == GGML_SORT_ORDER_ASC) {\n            return data[a] < data[b];\n        } else {\n            return data[a] > data[b];\n        }\n    }\n};\n\nstatic void ggml_compute_forward_argsort_f32(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(nb0 == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t nr = ggml_nrows(src0);\n\n    ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);\n\n    for (int64_t i = ith; i < nr; i += nth) {\n        const float * src_data = (float *)((char *) src0->data + i*nb01);\n\n        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);\n\n        for (int64_t j = 0; j < ne0; j++) {\n            dst_data[j] = j;\n        }\n\n        switch (order) {\n            case GGML_SORT_ORDER_ASC:\n                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});\n                break;\n\n            case GGML_SORT_ORDER_DESC:\n                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});\n                break;\n\n            default:\n                GGML_ABORT(\"invalid sort order\");\n        }\n    }\n}\n\nvoid 
ggml_compute_forward_argsort(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_argsort_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_top_k\n\nstruct cmp_top_k {\n    const float * data;\n    bool operator()(int32_t a, int32_t b) const {\n        return data[a] > data[b];\n    }\n};\n\nstatic void ggml_compute_forward_top_k_f32(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT(nb0 == sizeof(float));\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t nr = ggml_nrows(src0);\n\n    const int top_k = ne0;\n\n    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;\n\n    for (int64_t i = ith; i < nr; i += nth) {\n        const float * src_data = (float *)((char *) src0->data + i*nb01);\n\n        for (int64_t j = 0; j < ne00; j++) {\n            tmp[j] = j;\n        }\n\n        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});\n\n        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);\n\n        std::copy(tmp, tmp + top_k, dst_data);\n\n        // emphasize that the order is not important\n        if (top_k > 1) {\n            std::swap(dst_data[0], dst_data[1]);\n        }\n    }\n}\n\nvoid ggml_compute_forward_top_k(\n    const ggml_compute_params * params,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_top_k_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal 
error\");\n            }\n    }\n}\n\nstatic void ggml_compute_forward_flash_attn_ext_f16_one_chunk(\n        const ggml_compute_params * params,\n        ggml_tensor * dst,\n        int ir0, int ir1,\n        int64_t ic_start, int64_t ic_end,\n        float * partials, int64_t partial_stride) {\n\n    const bool write_partials = (partials != nullptr);\n    const ggml_tensor * q     = dst->src[0];\n    const ggml_tensor * k     = dst->src[1];\n    const ggml_tensor * v     = dst->src[2];\n    const ggml_tensor * mask  = dst->src[3];\n    const ggml_tensor * sinks = dst->src[4];\n\n    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)\n\n    const int64_t DK = nek0;\n    const int64_t DV = nev0;\n    const int64_t N  = neq1;\n\n    GGML_ASSERT(ne0 == DV);\n    GGML_ASSERT(ne2 == N);\n\n    // input tensor rows must be contiguous\n    GGML_ASSERT(nbq0 == ggml_type_size(q->type));\n    GGML_ASSERT(nbk0 == ggml_type_size(k->type));\n    GGML_ASSERT(nbv0 == ggml_type_size(v->type));\n\n    GGML_ASSERT(neq0 == DK);\n    GGML_ASSERT(nek0 == DK);\n    GGML_ASSERT(nev0 == DV);\n\n    GGML_ASSERT(neq1 == N);\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    // broadcast factors\n    const int64_t rk2 = neq2/nek2;\n    const int64_t rk3 = neq3/nek3;\n\n    const int64_t rv2 = neq2/nev2;\n    const int64_t rv3 = neq3/nev3;\n\n    // parallelize by q rows using ggml_vec_dot_f32\n\n    float scale         = 1.0f;\n    float max_bias      = 0.0f;\n    float logit_softcap = 0.0f;\n\n    memcpy(&scale,         (float *) dst->op_params 
+ 0, sizeof(float));\n    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));\n    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));\n\n    if (logit_softcap != 0) {\n        scale /= logit_softcap;\n    }\n\n    const uint32_t n_head      = neq2;\n    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;\n    ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;\n    ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;\n    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;\n\n    GGML_ASSERT((                            q_to_vec_dot) && \"fattn: unsupported K-type\");\n    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && \"fattn: unsupported V-type\");\n\n    int ith = params->ith;\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // q indices\n        const int iq3 = ir/(neq2*neq1);\n        const int iq2 = (ir - iq3*neq2*neq1)/neq1;\n        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);\n\n        const uint32_t h = iq2; // head index\n        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;\n\n        float S = 0.0f;      // sum\n        float M = -INFINITY; // maximum KQ value\n\n        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator\n        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer\n        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator\n        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16\n\n        if (v->type == GGML_TYPE_F16) {\n            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));\n        } else {\n            memset(VKQ32, 0, DV*sizeof(float));\n        }\n\n        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;\n\n        // k indices\n        const int ik3 = iq3 / rk3;\n        const int ik2 = iq2 / rk2;\n\n        // v indices\n        const int iv3 = iq3 / rv3;\n        const int iv2 = iq2 / rv2;\n\n        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));\n        q_to_vec_dot(pq, Q_q, DK);\n\n        // online softmax / attention\n        // loop over n_kv and n_head_kv\n        // ref: https://arxiv.org/pdf/2112.05682.pdf\n\n        for (int64_t ic = ic_start; ic < ic_end; ++ic) {\n            const float mv = mp ? 
slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;\n            if (mv == -INFINITY) {\n                continue;\n            }\n\n            float s; // KQ value\n\n            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);\n            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);\n\n            s = s*scale; // scale KQ value\n\n            if (logit_softcap != 0.0f) {\n                s = logit_softcap*tanhf(s);\n            }\n\n            s += mv; // apply mask\n\n            const float Mold = M;\n\n            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value\n            float vs = 1.0f; // post-softmax KQ value, expf(s - M)\n\n            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));\n\n            if (v->type == GGML_TYPE_F16) {\n                if (s > M) {\n                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f\n                    M = s;\n                    ms = expf(Mold - M);\n\n                    // V = V*expf(Mold - M)\n                    ggml_vec_scale_f16(DV, VKQ16, ms);\n                } else {\n                    // no new maximum, ms == 1.0f, vs != 1.0f\n                    vs = expf(s - M);\n                }\n\n                // V += v*expf(s - M)\n                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);\n            } else {\n                if (s > M) {\n                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f\n                    M = s;\n                    ms = expf(Mold - M);\n\n                    // V = V*expf(Mold - M)\n                    ggml_vec_scale_f32(DV, VKQ32, ms);\n                } else {\n                    // no new maximum, ms == 1.0f, vs != 1.0f\n                    vs = expf(s - M);\n                }\n\n                // V += v*expf(s - M)\n                if (v_to_float) {\n                    v_to_float(v_data, V32, DV);\n                    
ggml_vec_mad_f32(DV, VKQ32, V32, vs);\n                } else {\n                    // V is F32\n                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);\n                }\n            }\n\n            S = S*ms + vs; // scale and increment sum with partial sum\n        }\n\n        if (v->type == GGML_TYPE_F16) {\n            for (int64_t d = 0; d < DV; ++d) {\n                VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);\n            }\n        }\n\n        // sinks - apply only on the first kv-chunk\n        if (sinks && ic_start == 0) {\n            const float s = ((float *)((char *) sinks->data))[h];\n\n            float ms = 1.0f;\n            float vs = 1.0f;\n\n            if (s > M) {\n                ms = expf(M - s);\n                M = s;\n                ggml_vec_scale_f32(DV, VKQ32, ms);\n            } else {\n                vs = expf(s - M);\n            }\n\n            S = S*ms + vs;\n        }\n\n        if (write_partials) {\n            // Write M, S, VKQ to partials for later reduction\n            // partials layout: [M, S, VKQ[DV]] per query head\n            float * partial = partials + ir * partial_stride;\n            partial[0] = M;\n            partial[1] = S;\n            memcpy(partial + 2, VKQ32, DV * sizeof(float));\n        } else {\n            // V /= S\n            const float S_inv = S == 0.0f ? 
0.0f : 1.0f/S;\n            ggml_vec_scale_f32(DV, VKQ32, S_inv);\n\n            // dst indices\n            const int i1 = iq1;\n            const int i2 = iq2;\n            const int i3 = iq3;\n\n            // permute(0, 2, 1, 3)\n            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);\n        }\n    }\n}\n\nstatic void ggml_compute_forward_flash_attn_ext_tiled(\n        const ggml_compute_params * params,\n        ggml_tensor * dst,\n        int ir0, int ir1) {\n    const ggml_tensor * q     = dst->src[0];\n    const ggml_tensor * k     = dst->src[1];\n    const ggml_tensor * v     = dst->src[2];\n    const ggml_tensor * mask  = dst->src[3];\n    const ggml_tensor * sinks = dst->src[4];\n\n    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)\n\n    const int64_t DK = nek0;\n    const int64_t DV = nev0;\n    const int64_t N  = neq1;\n\n    GGML_ASSERT(ne0 == DV);\n    GGML_ASSERT(ne2 == N);\n\n    // input tensor rows must be contiguous\n    GGML_ASSERT(nbq0 == ggml_type_size(q->type));\n    GGML_ASSERT(nbk0 == ggml_type_size(k->type));\n    GGML_ASSERT(nbv0 == ggml_type_size(v->type));\n\n    GGML_ASSERT(neq0 == DK);\n    GGML_ASSERT(nek0 == DK);\n    GGML_ASSERT(nev0 == DV);\n\n    GGML_ASSERT(neq1 == N);\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    GGML_ASSERT(k->type == v->type);\n    const ggml_type kv_type = k->type;\n\n\n    // broadcast factors\n    const int64_t rk2 = neq2/nek2;\n    const int64_t rk3 = neq3/nek3;\n\n    const int64_t rv2 = neq2/nev2;\n    const int64_t 
rv3 = neq3/nev3;\n\n    float scale         = 1.0f;\n    float max_bias      = 0.0f;\n    float logit_softcap = 0.0f;\n\n    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));\n    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));\n\n    if (logit_softcap != 0) {\n        scale /= logit_softcap;\n    }\n\n    const uint32_t n_head      = neq2;\n    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    int ith = params->ith;\n\n    static constexpr int Q_TILE_SZ  = ggml_fa_tile_config::Q;\n    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;\n\n    int ir = ir0;\n    while (ir < ir1) {\n        // q indices for the start of this tile\n        const int iq3 = ir/(neq2*neq1);\n        const int iq2 = (ir - iq3*neq2*neq1)/neq1;\n        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);\n\n        // Number of valid rows in this tile:\n        // - limited by tile size (Q_TILE_SZ)\n        // - limited by chunk boundary (ir1 - ir)\n        // - limited by head boundary (neq1 - iq1) to avoid crossing into next head\n        const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));\n        GGML_ASSERT(tile_rows > 0);\n\n        const uint32_t h = iq2; // head index\n        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;\n\n        float S[Q_TILE_SZ];\n        float M[Q_TILE_SZ];\n\n        for (int i = 0 ; i < Q_TILE_SZ; ++i) {\n            S[i] = 0.;\n            M[i] = -INFINITY;\n        }\n\n        // Per-thread scratch layout:\n        // Q_q:    Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)\n        // KQ:     Q_TILE_SZ * KV_TILE_SZ (attention scores in float)\n        // mask:   Q_TILE_SZ * KV_TILE_SZ (mask in float)\n        // VKQ32:  Q_TILE_SZ * DV (FP32 output accumulator)\n        // V32:    KV_TILE_SZ * DV (F32 buffer for V tile)\n        // K_f32:  KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)\n        float * base  = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);\n\n        void  * Q_q    = base;\n        float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));\n        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;\n        float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;\n        float * V32    = VKQ32 + Q_TILE_SZ * DV;\n        float * K_f32  = V32 + KV_TILE_SZ * DV;\n\n        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));\n        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));\n\n        // k indices\n        const int ik3 = iq3 / rk3;\n        const int ik2 = iq2 / rk2;\n\n        // v indices\n        const int iv3 = iq3 / rv3;\n        const int iv2 = iq2 / rv2;\n\n        {\n            float * Q_f32 = (float *)Q_q;\n            for (int tq = 0; tq < tile_rows; tq++) {\n                const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));\n                memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));\n            }\n            for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {\n                memset(Q_f32 + tq * DK, 0, DK * sizeof(float));\n            }\n        }\n\n        memset(K_f32, 0, DK * 
KV_TILE_SZ * sizeof(float));\n        memset(V32,   0, KV_TILE_SZ * DV * sizeof(float));\n\n        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {\n            const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);\n\n            // skip the tile entirely if all the masks are -inf\n            if (mask) {\n                bool can_skip = true;\n                for (int tq = 0; tq < tile_rows; tq++) {\n                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);\n                    for (int tk = 0; tk < kv_tile; tk++) {\n                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);\n                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {\n                            can_skip = false;\n                        }\n                    }\n                    // Pad remaining mask entries with -inf\n                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {\n                        mask32[tq * KV_TILE_SZ + tk] = -INFINITY;\n                    }\n                }\n\n                if (can_skip) {\n                    continue;\n                }\n            }\n\n            // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)\n            // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns\n            for (int tk = 0; tk < kv_tile; tk++) {\n                const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;\n                if (kv_type == GGML_TYPE_F16) {\n                    const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;\n                    for (int64_t dk = 0; dk < DK; dk++) {\n                        K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);\n                    }\n                } else {\n                    const float * k_f32_src = (const 
float *)k_data;\n                    for (int64_t dk = 0; dk < DK; dk++) {\n                        K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];\n                    }\n                }\n            }\n            memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));\n            simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);\n            ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);\n\n            // Set padded KQ entries to -inf so softmax gives them zero weight\n            if (kv_tile < KV_TILE_SZ) {\n                for (int tq = 0; tq < Q_TILE_SZ; tq++) {\n                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {\n                        KQ[tq * KV_TILE_SZ + tk] = -INFINITY;\n                    }\n                }\n            }\n\n            if (logit_softcap != 0.0f) {\n                ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);\n                ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);\n            }\n\n            if (mask) {\n                ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);\n            }\n\n            bool skip[Q_TILE_SZ] = {};\n\n            for (int tq = 0; tq < Q_TILE_SZ; tq++) {\n                float * kq_row = KQ + tq * KV_TILE_SZ;\n\n                float tile_max;\n                ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);\n\n                if (tile_max == -INFINITY) {\n                    skip[tq] = true;\n                    continue;\n                }\n\n                const float Mold = M[tq];\n                const float Mnew = fmaxf(Mold, tile_max);\n\n                if (Mnew > Mold) {\n                    const float ms = expf(Mold - Mnew);\n                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);\n                    S[tq] *= ms;\n                }\n                M[tq] = Mnew;\n\n\n                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);\n            }\n\n            // V accumulation: VKQ32 += 
softmax(KQ) * V\n            // Pack V tile to contiguous F32, zero-padded\n            for (int tk = 0; tk < kv_tile; tk++) {\n                const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;\n                if (kv_type == GGML_TYPE_F16) {\n                    ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);\n                } else {\n                    memcpy(V32 + tk * DV, v_data, DV * sizeof(float));\n                }\n            }\n            for (int tq = 0; tq < Q_TILE_SZ; tq++) {\n                if (skip[tq]) {\n                    memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));\n                }\n            }\n            simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);\n        }\n\n        // sinks (apply only to valid rows in the tile)\n        if (sinks) {\n            const float s = ((float *)((char *) sinks->data))[h];\n\n            for (int tq = 0; tq < tile_rows; tq++) {\n                float ms = 1.0f;\n                float vs = 1.0f;\n\n                if (s > M[tq]) {\n                    ms = expf(M[tq] - s);\n                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);\n                } else {\n                    vs = expf(s - M[tq]);\n                }\n\n                S[tq] = S[tq] * ms + vs;\n            }\n        }\n\n        for (int tq = 0; tq < tile_rows; tq++) {\n            // V /= S\n            const float S_inv = S[tq] == 0.0f ? 
0.0f : 1.0f / S[tq];\n            ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);\n\n            // dst indices\n            const int i1 = iq1 + tq;\n            const int i2 = iq2;\n            const int i3 = iq3;\n\n            // permute(0, 2, 1, 3)\n            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);\n        }\n\n        ir += tile_rows;\n    }\n}\n\n// Reduction function: combines partial results across KV chunks\n// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]\nstatic void ggml_flash_attn_ext_reduce_partials(\n        const ggml_compute_params * params,\n        ggml_tensor * dst,\n        const int64_t n_chunks,\n        const int64_t chunk_size) {\n\n    const ggml_tensor * q = dst->src[0];\n    const ggml_tensor * k = dst->src[1];\n    const ggml_tensor * v = dst->src[2];\n\n    const int64_t DK        = k->ne[0];\n    const int64_t DV        = v->ne[0];\n    const int64_t nek1      = k->ne[1];\n    const int64_t n_q_heads = q->ne[2];\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;\n    float *       thread_wdata     = (float *) params->wdata + ith * wdata_per_thread;\n\n    const int64_t partials_offset  = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);\n    const int64_t partial_size     = 2 + DV;\n    const float * partials_base    = (const float *) params->wdata + partials_offset;\n\n    // Output layout\n    const int64_t ne1 = dst->ne[1];\n    const int64_t ne2 = dst->ne[2];\n    const size_t  nb1 = dst->nb[1];\n\n    // Each thread reduces a subset of query heads\n    for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {\n        float   M_final   = -INFINITY;\n        float   S_final   = 0.0f;\n        float * VKQ_final = thread_wdata;\n        memset(VKQ_final, 0, DV * sizeof(float));\n\n        // Combine partials from all chunks\n        for (int64_t chunk_idx = 0; chunk_idx < n_chunks; 
++chunk_idx) {\n            const int64_t ic_start = chunk_idx * chunk_size;\n            if (ic_start >= nek1) continue;\n\n            const float * partial   = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;\n            const float   M_chunk   = partial[0];\n            const float   S_chunk   = partial[1];\n            const float * VKQ_chunk = partial + 2;\n\n            if (S_chunk == 0.0f) continue;\n\n            const float M_new     = fmaxf(M_final, M_chunk);\n            const float scale_old = expf(M_final - M_new);\n            const float scale_new = expf(M_chunk - M_new);\n\n            for (int64_t d = 0; d < DV; ++d) {\n                VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;\n            }\n            S_final = S_final * scale_old + S_chunk * scale_new;\n            M_final = M_new;\n        }\n\n        // Normalize and write to output\n        if (S_final != 0.0f) {\n            const float S_inv = 1.0f / S_final;\n            ggml_vec_scale_f32(DV, VKQ_final, S_inv);\n        }\n        // iq1=0, iq3=0 for decode\n        memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);\n    }\n}\n\nstatic void ggml_compute_forward_flash_attn_ext_f16(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * q     = dst->src[0];\n    const ggml_tensor * k     = dst->src[1];\n    const ggml_tensor * v     = dst->src[2];\n\n    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)\n\n    const int64_t DK = nek0;\n    const int64_t DV = nev0;\n    const int64_t N  = neq1;\n\n\n    GGML_ASSERT(ne0 == DV);\n    GGML_ASSERT(ne2 == N);\n\n  
  // input tensor rows must be contiguous\n    GGML_ASSERT(nbq0 == ggml_type_size(q->type));\n    GGML_ASSERT(nbk0 == ggml_type_size(k->type));\n    GGML_ASSERT(nbv0 == ggml_type_size(v->type));\n\n    GGML_ASSERT(neq0 == DK);\n    GGML_ASSERT(nek0 == DK);\n    GGML_ASSERT(nev0 == DV);\n\n    GGML_ASSERT(neq1 == N);\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)\n    const bool use_ref = params->use_ref;\n\n    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);\n    const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;\n\n    if (use_split_kv_path) {\n        const int64_t chunk_size = (nek1 + nth - 1) / nth;\n\n        // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]\n        const int64_t partial_size  = 2 + DV;\n        float *       partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);\n\n        const int64_t ic_start = ith * chunk_size;\n        const int64_t ic_end   = std::min(ic_start + chunk_size, nek1);\n\n        const int64_t partial_stride = nth * partial_size;\n        float *       chunk_partials = partials_base + ith * partial_size;\n\n        if (ic_start < nek1) {\n            for (int64_t q_head = 0; q_head < neq2; q_head++) {\n                ggml_compute_forward_flash_attn_ext_f16_one_chunk(\n                    params, dst, q_head, q_head + 1, ic_start, ic_end,\n                    chunk_partials, partial_stride);\n            }\n        } else {\n            for (int64_t q_head = 0; q_head < neq2; q_head++) {\n                float * q_partials = chunk_partials + q_head * 
partial_stride;\n                q_partials[0] = -INFINITY;  // M\n                q_partials[1] = 0.0f;       // S\n            }\n        }\n\n        ggml_barrier(params->threadpool);\n        ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);\n    } else {\n\n        // total rows in q\n        const int64_t nr = neq1*neq2*neq3;\n\n        // disable for NUMA\n        const bool disable_chunking = ggml_is_numa();\n\n        // 4x chunks per thread\n        int nth_scaled = nth * 4;\n        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;\n        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;\n\n        if (nth == 1 || nchunk < nth || disable_chunking) {\n            nchunk = nth;\n        }\n\n        if (ith == 0) {\n            ggml_threadpool_chunk_set(params->threadpool, nth);\n        }\n\n        ggml_barrier(params->threadpool);\n\n        const int64_t dr = (nr + nchunk - 1) / nchunk;\n\n        static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;\n        bool use_tiled = !use_ref &&\n                               (q->type == GGML_TYPE_F32 &&\n                                kv_is_f32_or_f16 &&\n                                k->type == v->type &&\n                                neq1 >= Q_TILE_SZ);\n#ifdef GGML_SIMD\n        use_tiled &= (DV % GGML_F32_EPR == 0);\n#endif\n        int current_chunk = ith;\n\n        while (current_chunk < nchunk) {\n            const int64_t ir0 = dr * current_chunk;\n            const int64_t ir1 = MIN(ir0 + dr, nr);\n\n            if (use_tiled) {\n                ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);\n            } else {\n                ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);\n            }\n\n            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);\n        }\n    }\n}\n\nvoid ggml_compute_forward_flash_attn_ext(\n        const ggml_compute_params * params,\n       
 ggml_tensor * dst) {\n    switch (dst->op_params[3]) {\n        case GGML_PREC_DEFAULT:\n        case GGML_PREC_F32:\n            {\n                // uses F32 accumulators\n                ggml_compute_forward_flash_attn_ext_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_flash_attn_back\n\nstatic void ggml_compute_forward_flash_attn_back_f32(\n        const ggml_compute_params * params,\n        const bool masked,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * q = dst->src[0];\n    const ggml_tensor * k = dst->src[1];\n    const ggml_tensor * v = dst->src[2];\n    const ggml_tensor * d = dst->src[3];\n\n    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)\n    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t D = neq0;\n    const int64_t N = neq1;\n    const int64_t P = nek1 - N;\n    const int64_t M = P + N;\n\n    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);\n    const int mxDM = MAX(D, Mup);\n\n    // GGML_ASSERT(ne0 == D);\n    // GGML_ASSERT(ne1 == N);\n    GGML_ASSERT(P >= 0);\n\n    GGML_ASSERT(nbq0 == sizeof(float));\n    GGML_ASSERT(nbk0 == sizeof(float));\n    GGML_ASSERT(nbv0 == sizeof(float));\n\n    GGML_ASSERT(neq0 == D);\n    GGML_ASSERT(nek0 == D);\n    GGML_ASSERT(nev1 == D);\n    GGML_ASSERT(ned0 == D);\n\n    GGML_ASSERT(neq1 == N);\n    GGML_ASSERT(nek1 == N + P);\n    GGML_ASSERT(nev1 == D);\n    GGML_ASSERT(ned1 == N);\n\n    // dst cannot be transposed 
or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    if (ith == 0) {\n        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);\n    }\n    ggml_barrier(params->threadpool);\n\n    const int64_t elem_q = ggml_nelements(q);\n    const int64_t elem_k = ggml_nelements(k);\n\n    ggml_type result_type = dst->type;\n    GGML_ASSERT(ggml_blck_size(result_type) == 1);\n    const size_t tsize = ggml_type_size(result_type);\n\n    const size_t offs_q = 0;\n    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);\n    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);\n\n    void * grad_q = (char *) dst->data;\n    void * grad_k = (char *) dst->data + offs_k;\n    void * grad_v = (char *) dst->data + offs_v;\n\n    const size_t nbgq1 = nb0*neq0;\n    const size_t nbgq2 = nb0*neq0*neq1;\n    const size_t nbgq3 = nb0*neq0*neq1*neq2;\n\n    const size_t nbgk1 = nb0*nek0;\n    const size_t nbgk2 = nb0*nek0*nek1;\n    const size_t nbgk3 = nb0*nek0*nek1*neq2;\n\n    const size_t nbgv1 = nb0*nev0;\n    const size_t nbgv2 = nb0*nev0*nev1;\n    const size_t nbgv3 = nb0*nev0*nev1*neq2;\n\n    // parallelize by k rows using ggml_vec_dot_f32\n\n    // total rows in k\n    const int nr = nek2*nek3;\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    const float scale = 1.0f/sqrtf(D);\n\n    //printf(\"P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\\n\", P, N, D, ir0, ir1, scale);\n\n    // how often k2 (and v2) is repeated in q2\n    int nrep = neq2/nek2;\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        // q indices\n        const int ik3 = ir/(nek2);\n        const int ik2 = ir - ik3*nek2;\n\n        const int iq3 = ik3;\n        const int id3 = ik3;\n        const int iv3 = ik3;\n        const int iv2 = ik2;\n\n        for (int irep = 
0; irep < nrep; ++irep) {\n            const int iq2 = ik2 + irep*nek2;\n            const int id2 = iq2;\n\n            // (ik2 + irep*nek2) % nek2 == ik2\n            for (int iq1 = 0; iq1 < neq1; ++iq1) {\n                const int id1 = iq1;\n\n                // not sure about CACHE_LINE_SIZE_F32..\n                // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?\n                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);\n                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);\n\n                for (int i = M; i < Mup; ++i) {\n                    S[i] = -INFINITY;\n                }\n\n                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;\n                for (int64_t ic = 0; ic < masked_begin; ++ic) {\n                    // k indices\n                    const int ik1 = ic;\n\n                    // S indices\n                    const int i1 = ik1;\n\n                    ggml_vec_dot_f32(neq0,\n                            S + i1, 0,\n                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,\n                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);\n                }\n\n                // scale\n                ggml_vec_scale_f32(masked_begin, S, scale);\n\n                for (int64_t i = masked_begin; i < M; i++) {\n                    S[i] = -INFINITY;\n                }\n\n                // softmax\n                // exclude known -INF S[..] 
values from max and loop\n                // dont forget to set their SM values to zero\n                {\n                    float max = -INFINITY;\n                    ggml_vec_max_f32(masked_begin, &max, S);\n\n                    ggml_float sum = 0.0;\n                    {\n#ifdef GGML_SOFT_MAX_ACCELERATE\n                        max = -max;\n                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);\n                        vvexpf(SM, SM, &Mup);\n                        ggml_vec_sum_f32(Mup, &sum, SM);\n#else\n                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);\n#endif\n                    }\n\n                    assert(sum > 0.0);\n\n                    sum = 1.0/sum;\n                    ggml_vec_scale_f32(masked_begin, SM, sum);\n\n                }\n\n                // step-by-step explanation\n                {\n                    // forward-process                    shape      grads from backward process\n                    // parallel_for ik2,ik3:\n                    //  for irep:\n                    //   iq2 = ik2 + irep*nek2\n                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]\n                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]\n                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]\n                    //   for iq1:\n                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur\n                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur\n                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4\n                    //    S0     = -Inf                   [D,1,1,1]\n                    //   ~S1[i]  = dot(kcur[:D,i], qcur)\n                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale\n                    //    
S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)\n                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))\n                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur\n                    //   ~S5[i]  = dot(vcur[:,i], S4)\n                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]\n                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^\n                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]\n                    // dst                               backward-/ grad[dst]                 = d\n                    //\n                    // output gradients with their dependencies:\n                    //\n                    // grad[kcur] = grad[S1].T @ qcur\n                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale\n                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))\n                    // grad[S4]   = grad[S5] @ vcur\n                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur\n                    // grad[qcur] = grad[S1]   @ kcur\n                    // grad[vcur] = grad[S5].T @ S4\n                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4\n                    //\n                    // in post-order:\n                    //\n                    // S1         = qcur @ kcur.T\n                    // S2         = S1 * scale\n                    // S3         = diag_mask_inf(S2, P)\n                    // S4         = softmax(S3)\n                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur\n                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))\n                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale\n                    // grad[qcur] = grad[S1]   @ kcur\n                    // grad[kcur] = grad[S1].T @ qcur\n                    // grad[vcur] = 
d[:D,id1,id2,id3].T @ S4\n                    //\n                    // using less variables (SM=S4):\n                    //\n                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)\n                    // SM            = softmax(S)\n                    // S             = d[:D,iq1,iq2,iq3] @ vcur\n                    // dot_SM_gradSM = dot(SM, S)\n                    // S             = SM * (S - dot(SM, S))\n                    // S             = diag_mask_zero(S, P) * scale\n                    //\n                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur\n                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur\n                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM\n                }\n\n                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]\n                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]\n                // for ic:\n                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]\n                // exclude known future zero S[..] values from operation\n                ggml_vec_set_f32(masked_begin, S, 0);\n                for (int64_t ic = 0; ic < D; ++ic) {\n                    ggml_vec_mad_f32(masked_begin,\n                            S,\n                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),\n                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));\n                }\n\n                // S = SM * (S - dot(SM, S))\n                float dot_SM_gradSM = 0;\n                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);\n                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);\n                ggml_vec_mul_f32 (masked_begin, S, S, SM);\n\n                // S = diag_mask_zero(S, P) * scale\n                // already done by above ggml_vec_set_f32\n\n                // exclude known zero S[..] 
values from operation\n                ggml_vec_scale_f32(masked_begin, S, scale);\n\n                // S    shape [M,1]\n                // SM   shape [M,1]\n                // kcur shape [D,M]\n                // qcur shape [D,1]\n                // vcur shape [M,D]\n\n                // grad[q][:D,iq1,iq2,iq3] += S @ kcur\n                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]\n                // for ic:\n                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]\n                // exclude known zero S[..] values from loop\n                for (int64_t ic = 0; ic < masked_begin; ++ic) {\n                    ggml_vec_mad_f32(D,\n                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),\n                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),\n                            S[ic]);\n                }\n\n                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur\n                // for ic:\n                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]\n                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]\n                // exclude known zero S[..] values from loop\n                for (int64_t ic = 0; ic < masked_begin; ++ic) {\n                    ggml_vec_mad_f32(D,\n                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),\n                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),\n                            S[ic]);\n                }\n\n                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM\n                // for ic:\n                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]\n                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]\n                // exclude known zero SM[..] 
values from mad\n                for (int64_t ic = 0; ic < D; ++ic) {\n                    ggml_vec_mad_f32(masked_begin,\n                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),\n                            SM,\n                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_flash_attn_back(\n        const ggml_compute_params * params,\n        const bool masked,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * q = dst->src[0];\n\n    switch (q->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_ssm_conv\n\nstatic void ggml_compute_forward_ssm_conv_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0]; // conv_x\n    const ggml_tensor * src1 = dst->src[1]; // conv1d.weight\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nc  = src1->ne[0]; // d_conv\n    const int ncs = src0->ne[0]; // d_conv - 1 + n_t\n    const int nr  = src0->ne[1]; // d_inner\n    const int n_t =  dst->ne[1]; // tokens per sequence\n    const int n_s =  dst->ne[2]; // number of sequences in the batch\n\n    GGML_ASSERT( dst->ne[0] == nr);\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(src1->nb[0] == sizeof(float));\n    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n    const int ir  = ir1 - ir0;\n\n    for (int i3 = 0; i3 < n_s; ++i3) {\n        for (int i2 = 0; i2 < n_t; ++i2) {\n  
          // {d_conv - 1 + n_t, d_inner, n_seqs}\n            // sliding window\n            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}\n            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}\n            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}\n\n            // TODO: transpose the output for smaller strides for big batches?\n            // d_inner\n            for (int i1 = 0; i1 < ir; ++i1) {\n                // rowwise dot product\n                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision\n                float sumf = 0.0f;\n\n                // d_conv\n                for (int i0 = 0; i0 < nc; ++i0) {\n                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];\n                }\n                x[i1] = sumf;\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_ssm_conv(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    switch (dst->src[0]->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_ssm_conv_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_ssm_scan\n\nstatic void ggml_compute_forward_ssm_scan_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs+}\n    const ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}\n    const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}\n    const ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {1, n_head}\n    const ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, 
n_seq_tokens, n_seqs}\n    const ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}\n    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t nc = src0->ne[0]; // d_state\n    const int64_t nr = src0->ne[1]; // dim\n    const int64_t nh = src1->ne[1]; // n_head\n    const int64_t ng = src4->ne[1];\n    const int64_t nt = src1->ne[2]; // number of tokens per sequence\n    const int64_t ns = src1->ne[3]; // number of sequences in the batch\n\n    // can't use ggml_nbytes because src1 is not necessarily contiguous\n    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);\n\n    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(src1->nb[0] == sizeof(float));\n    GGML_ASSERT(src2->nb[0] == sizeof(float));\n    GGML_ASSERT(src3->nb[0] == sizeof(float));\n    GGML_ASSERT(src4->nb[0] == sizeof(float));\n    GGML_ASSERT(src5->nb[0] == sizeof(float));\n    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));\n    GGML_ASSERT(nh % ng == 0);\n\n    // heads per thread\n    const int dh = (nh + nth - 1)/nth;\n\n    // head range for this thread\n    const int ih0 = dh*ith;\n    const int ih1 = MIN(ih0 + dh, nh);\n\n    const int32_t * ids = (const int32_t *) src6->data;\n\n    for (int i3 = 0; i3 < ns; ++i3) {\n        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}\n              float * s  = (      float *) ((      char *) dst->data  + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}\n\n        for (int i2 = 0; i2 < nt; ++i2) {\n            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}\n            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, 
ns}\n            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}\n            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}\n            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}\n                  float * y  = (      float *) ((      char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}\n\n            if (src3->ne[0] == 1) {\n                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop\n\n                // n_head\n                for (int h = ih0; h < ih1; ++h) {\n                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16\n                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);\n                    const float dA = expf(dt_soft_plus * A[h]);\n                    const int g = h / (nh / ng); // repeat_interleave\n\n                    // dim\n                    for (int i1 = 0; i1 < nr; ++i1) {\n                        const int ii = i1 + h*nr;\n                        const float x_dt = x[ii] * dt_soft_plus;\n                        float sumf = 0.0f;\n#if defined(GGML_SIMD)\n    #if defined(__ARM_FEATURE_SVE)\n                        const int ggml_f32_epr = svcntw();\n                        const int ggml_f32_step = 1 * ggml_f32_epr;\n\n                        const int np = (nc & ~(ggml_f32_step - 1));\n\n                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;\n\n                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);\n                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);\n\n                        for (int i = 0; i < np; i += ggml_f32_step) {\n                            // TODO: maybe unroll more?\n                           
 for (int j = 0; j < 1; j++) {\n                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);\n                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);\n                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);\n\n                                t0 = GGML_F32_VEC_MUL(t0, adA);\n                                t1 = GGML_F32_VEC_MUL(t1, axdt);\n\n                                t0 = GGML_F32_VEC_ADD(t0, t1);\n\n                                sum = GGML_F32_VEC_FMA(sum, t0, t2);\n\n                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);\n                            }\n                        }\n\n                        sumf = GGML_F32xt_REDUCE_ONE(sum);\n    #elif defined(__riscv_v_intrinsic)\n                        // todo: RVV implementation\n                        const int np = 0;\n    #else\n                        const int np = (nc & ~(GGML_F32_STEP - 1));\n\n                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };\n\n                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);\n                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);\n\n                        GGML_F32_VEC ax[GGML_F32_ARR];\n                        GGML_F32_VEC ay[GGML_F32_ARR];\n                        GGML_F32_VEC az[GGML_F32_ARR];\n\n                        for (int i = 0; i < np; i += GGML_F32_STEP) {\n                            for (int j = 0; j < GGML_F32_ARR; j++) {\n                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);\n                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);\n                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);\n\n                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);\n                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);\n\n        
                        ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);\n\n                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);\n\n                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);\n                            }\n                        }\n\n                        // reduce sum0..sum3 to sum0\n                        GGML_F32_VEC_REDUCE(sumf, sum);\n    #endif\n#else\n                        const int np = 0;\n#endif\n                        // d_state\n                        for (int i0 = np; i0 < nc; ++i0) {\n                            const int i = i0 + ii*nc;\n                            const int ig = i0 + g*nc;\n                            // state = prev_state * dA + dB * x\n                            const float state = (s0[i] * dA) + (B[ig] * x_dt);\n                            // y = rowwise_dotprod(state, C)\n                            sumf += state * C[ig];\n                            s[i] = state;\n                        }\n                        y[ii] = sumf;\n                    }\n                }\n            } else {\n                // Mamba-1 has an element-wise decay factor for the states\n\n                // n_head\n                for (int h = ih0; h < ih1; ++h) {\n                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16\n                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);\n                    const int g = h / (nh / ng); // repeat_interleave\n\n                    // dim\n                    for (int i1 = 0; i1 < nr; ++i1) {\n                        const int ii = i1 + h*nr;\n                        const float x_dt = x[ii] * dt_soft_plus;\n#if defined(__ARM_FEATURE_SVE)\n                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);\n                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);\n                        
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;\n\n                        // d_state\n                        // TODO: what happens when (d_state % svcntw()) != 0?\n                        for (int64_t k = 0; k < nc; k += svcntw()) {\n                            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);\n                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);\n                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);\n                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);\n\n                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);\n                            t1 = exp_ps_sve(svptrue_b32(), t1);\n                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);\n\n                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);\n                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);\n\n                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);\n                        }\n                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);\n#else\n                        float sumf = 0.0f;\n                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16\n                        //       and also because expf is used within the loop.\n                        // d_state\n                        for (int i0 = 0; i0 < nc; ++i0) {\n                            const int i = i0 + ii*nc;\n                            const int ig = i0 + g*nc;\n                            // state = prev_state * dA + dB * x\n                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);\n                            // y = rowwise_dotprod(state, C)\n                            sumf += state * C[ig];\n                            s[i] = state;\n                        }\n                        y[ii] = sumf;\n#endif\n                    }\n                }\n            
}\n            // use the output as the source when it's not the first token-wise iteration\n            s0 = s;\n        }\n    }\n}\n\nvoid ggml_compute_forward_ssm_scan(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    switch (dst->src[0]->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_ssm_scan_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_win_part\n\nstatic void ggml_compute_forward_win_part_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    GGML_UNUSED(params);\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n\n    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];\n    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];\n    const int32_t w    = ((const int32_t *)(dst->op_params))[2];\n\n    assert(ne00 == ne0);\n    assert(ne3  == nep0*nep1);\n\n    // TODO: optimize / multi-thread\n    for (int py = 0; py < nep1; ++py) {\n        for (int px = 0; px < nep0; ++px) {\n            const int64_t i3 = py*nep0 + px;\n            for (int64_t i2 = 0; i2 < ne2; ++i2) {\n                for (int64_t i1 = 0; i1 < ne1; ++i1) {\n                    for (int64_t i0 = 0; i0 < ne0; ++i0) {\n                        const int64_t i02 = py*w + i2;\n                        const int64_t i01 = px*w + i1;\n                        const int64_t i00 = i0;\n\n                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;\n                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;\n\n                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {\n                            ((float *) dst->data)[i] = 0.0f;\n                        } else {\n                            
((float *) dst->data)[i] = ((float *) src0->data)[j];\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_win_part(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_win_part_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_win_unpart\n\nstatic void ggml_compute_forward_win_unpart_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    GGML_UNUSED(params);\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)\n\n    const int32_t w = ((const int32_t *)(dst->op_params))[0];\n\n    // padding\n    const int px = (w - ne1%w)%w;\n    //const int py = (w - ne2%w)%w;\n\n    const int npx = (px + ne1)/w;\n    //const int npy = (py + ne2)/w;\n\n    assert(ne0 == ne00);\n\n    // TODO: optimize / multi-thread\n    for (int64_t i2 = 0; i2 < ne2; ++i2) {\n        for (int64_t i1 = 0; i1 < ne1; ++i1) {\n            for (int64_t i0 = 0; i0 < ne0; ++i0) {\n                const int ip2 = i2/w;\n                const int ip1 = i1/w;\n\n                const int64_t i02 = i2%w;\n                const int64_t i01 = i1%w;\n                const int64_t i00 = i0;\n\n                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;\n                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;\n\n                ((float *) dst->data)[j] = ((float *) src0->data)[i];\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_win_unpart(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n   
 const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_win_unpart_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n//ggml_compute_forward_unary\n\nvoid ggml_compute_forward_unary(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_unary_op op = ggml_get_unary_op(dst);\n\n    switch (op) {\n        case GGML_UNARY_OP_ABS:\n            {\n                ggml_compute_forward_abs(params, dst);\n            } break;\n        case GGML_UNARY_OP_SGN:\n            {\n                ggml_compute_forward_sgn(params, dst);\n            } break;\n        case GGML_UNARY_OP_NEG:\n            {\n                ggml_compute_forward_neg(params, dst);\n            } break;\n        case GGML_UNARY_OP_STEP:\n            {\n                ggml_compute_forward_step(params, dst);\n            } break;\n        case GGML_UNARY_OP_TANH:\n            {\n                ggml_compute_forward_tanh(params, dst);\n            } break;\n        case GGML_UNARY_OP_ELU:\n            {\n                ggml_compute_forward_elu(params, dst);\n            } break;\n        case GGML_UNARY_OP_RELU:\n            {\n                ggml_compute_forward_relu(params, dst);\n            } break;\n        case GGML_UNARY_OP_SIGMOID:\n            {\n                ggml_compute_forward_sigmoid(params, dst);\n            } break;\n        case GGML_UNARY_OP_GELU:\n            {\n                ggml_compute_forward_gelu(params, dst);\n            } break;\n        case GGML_UNARY_OP_GELU_ERF:\n            {\n                ggml_compute_forward_gelu_erf(params, dst);\n            } break;\n        case GGML_UNARY_OP_GELU_QUICK:\n            {\n                ggml_compute_forward_gelu_quick(params, dst);\n            } break;\n        case GGML_UNARY_OP_SILU:\n            
{\n                ggml_compute_forward_silu(params, dst);\n            } break;\n        case GGML_UNARY_OP_HARDSWISH:\n            {\n                ggml_compute_forward_hardswish(params, dst);\n            } break;\n        case GGML_UNARY_OP_HARDSIGMOID:\n            {\n                ggml_compute_forward_hardsigmoid(params, dst);\n            } break;\n        case GGML_UNARY_OP_EXP:\n            {\n                ggml_compute_forward_exp(params, dst);\n            } break;\n        case GGML_UNARY_OP_FLOOR:\n            {\n                ggml_compute_forward_floor(params, dst);\n            } break;\n        case GGML_UNARY_OP_CEIL:\n            {\n                ggml_compute_forward_ceil(params, dst);\n            } break;\n        case GGML_UNARY_OP_ROUND:\n            {\n                ggml_compute_forward_round(params, dst);\n            } break;\n        case GGML_UNARY_OP_TRUNC:\n            {\n                ggml_compute_forward_trunc(params, dst);\n            } break;\n        case GGML_UNARY_OP_XIELU:\n            {\n                ggml_compute_forward_xielu(params, dst);\n            } break;\n        case GGML_UNARY_OP_EXPM1:\n            {\n                ggml_compute_forward_expm1(params, dst);\n            } break;\n        case GGML_UNARY_OP_SOFTPLUS:\n            {\n                ggml_compute_forward_softplus(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n//ggml_compute_forward_glu\n\nvoid ggml_compute_forward_glu(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_glu_op op = ggml_get_glu_op(dst);\n\n    switch (op) {\n        case GGML_GLU_OP_REGLU:\n            {\n                ggml_compute_forward_reglu(params, dst);\n            } break;\n        case GGML_GLU_OP_GEGLU:\n            {\n                ggml_compute_forward_geglu(params, dst);\n            } break;\n        case 
GGML_GLU_OP_SWIGLU:\n            {\n                ggml_compute_forward_swiglu(params, dst);\n            } break;\n        case GGML_GLU_OP_SWIGLU_OAI:\n            {\n                ggml_compute_forward_swiglu_oai(params, dst);\n            } break;\n        case GGML_GLU_OP_GEGLU_ERF:\n            {\n                ggml_compute_forward_geglu_erf(params, dst);\n            } break;\n        case GGML_GLU_OP_GEGLU_QUICK:\n            {\n                ggml_compute_forward_geglu_quick(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_get_rel_pos\n\nstatic void ggml_compute_forward_get_rel_pos_f16(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    GGML_UNUSED(params);\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    const int64_t w = ne1;\n\n    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;\n    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;\n\n    for (int64_t i2 = 0; i2 < ne2; ++i2) {\n        for (int64_t i1 = 0; i1 < ne1; ++i1) {\n            const int64_t pos = (w - i1 - 1) + i2;\n            for (int64_t i0 = 0; i0 < ne0; ++i0) {\n                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_get_rel_pos(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n            {\n                ggml_compute_forward_get_rel_pos_f16(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// 
ggml_compute_forward_add_rel_pos\n\nstatic void ggml_compute_forward_add_rel_pos_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n\n    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];\n    if (!inplace) {\n        if (params->ith == 0) {\n            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));\n        }\n        ggml_barrier(params->threadpool);\n    }\n    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359\n\n    float * src1_data = (float *) src1->data;\n    float * src2_data = (float *) src2->data;\n    float * dst_data  = (float *) dst->data;\n\n    const int64_t ne10 = src1->ne[0];\n    const int64_t ne11 = src1->ne[1];\n    const int64_t ne12 = src1->ne[2];\n    const int64_t ne13 = src1->ne[3];\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    // total patches in dst\n    const int np = ne13;\n\n    // patches per thread\n    const int dp = (np + nth - 1)/nth;\n\n    // patch range for this thread\n    const int ip0 = dp*ith;\n    const int ip1 = MIN(ip0 + dp, np);\n\n    for (int64_t i13 = ip0; i13 < ip1; ++i13) {\n        for (int64_t i12 = 0; i12 < ne12; ++i12) {\n            for (int64_t i11 = 0; i11 < ne11; ++i11) {\n                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;\n                for (int64_t i10 = 0; i10 < ne10; ++i10) {\n                    const int64_t jp0  = jp1 + i10;\n                    const float src1_e = src1_data[jp0];\n                    const float src2_e = src2_data[jp0];\n\n                    const int64_t jdh = jp0 * ne10;\n                    const int64_t jdw = jdh - (ne10 - 1) * i10;\n\n                    for (int64_t j = 0; j < ne10; ++j) {\n                        dst_data[jdh + j     ] += 
src2_e;\n                        dst_data[jdw + j*ne10] += src1_e;\n                    }\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_compute_forward_add_rel_pos(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_add_rel_pos_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_rwkv_wkv6\n\nstatic void ggml_compute_forward_rwkv_wkv6_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    const int64_t T = dst->src[1]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t HEADS = dst->src[1]->ne[1];\n    const int64_t n_seqs = dst->src[5]->ne[1];\n    const int64_t head_size = C / HEADS;\n\n    float * dst_data = (float *) dst->data;\n    float * state = ((float *) dst->data) + C * T;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    if (ith >= HEADS) {\n        return;\n    }\n\n    const int h_start = (HEADS * ith) / nth;\n    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?\n                (HEADS * (ith + 1)) / nth : HEADS;\n\n    float * k =          (float *) dst->src[0]->data;\n    float * v =          (float *) dst->src[1]->data;\n    float * r =          (float *) dst->src[2]->data;\n    float * time_faaaa = (float *) dst->src[3]->data;\n    float * time_decay = (float *) dst->src[4]->data;\n\n    size_t t_stride = HEADS * head_size; // Same to C\n\n    size_t h_stride = C / HEADS;\n    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS\n    size_t h_stride_2d = head_size * head_size;\n\n    if (ith == 0) {\n        memset(dst_data, 0, T * C * sizeof(float));\n    }\n    ggml_barrier(params->threadpool);\n\n\n    #if defined(__AVX__) && !defined(__AVX512F__)\n  
      #define GGML_F32X GGML_F32x8\n        #define GGML_F32X_SET1 GGML_F32x8_SET1\n        #define GGML_F32X_LOAD GGML_F32x8_LOAD\n        #define GGML_F32X_STORE GGML_F32x8_STORE\n        #define GGML_F32X_MUL GGML_F32x8_MUL\n        #define GGML_F32X_FMA GGML_F32x8_FMA\n        #define WKV_VECTOR_SIZE 8\n    #elif defined(__AVX512F__)\n        #define GGML_F32X GGML_F32x16\n        #define GGML_F32X_SET1 GGML_F32x16_SET1\n        #define GGML_F32X_LOAD GGML_F32x16_LOAD\n        #define GGML_F32X_STORE GGML_F32x16_STORE\n        #define GGML_F32X_MUL GGML_F32x16_MUL\n        #define GGML_F32X_FMA GGML_F32x16_FMA\n        #define WKV_VECTOR_SIZE 16\n    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)\n        #define GGML_F32X GGML_F32xt\n        #define GGML_F32X_SET1 GGML_F32xt_SET1\n        #define GGML_F32X_LOAD GGML_F32xt_LOAD\n        #define GGML_F32X_STORE GGML_F32xt_STORE\n        #define GGML_F32X_MUL GGML_F32xt_MUL\n        #define GGML_F32X_FMA GGML_F32xt_FMA\n        #define WKV_VECTOR_SIZE 8\n    #elif defined(__ARM_NEON) && defined(__aarch64__)\n        #define GGML_F32X GGML_F32x4\n        #define GGML_F32X_SET1 GGML_F32x4_SET1\n        #define GGML_F32X_LOAD GGML_F32x4_LOAD\n        #define GGML_F32X_STORE GGML_F32x4_STORE\n        #define GGML_F32X_MUL GGML_F32x4_MUL\n        #define GGML_F32X_FMA GGML_F32x4_FMA\n        #define WKV_VECTOR_SIZE 4\n    #endif\n\n    #ifdef WKV_VECTOR_SIZE\n        int wkv_vector_size;\n        #if defined(__ARM_FEATURE_SVE)\n            wkv_vector_size = svcntw();\n        #else\n            wkv_vector_size = WKV_VECTOR_SIZE;\n        #endif\n        const int64_t vec_count = head_size / wkv_vector_size;\n\n        for (int64_t t = 0; t < T; t++) {\n            size_t t_offset = t * t_stride;\n            size_t state_offset = head_size * C * (t / (T / n_seqs));\n            float * state_cur = state + state_offset;\n            float * state_prev = t % (T / n_seqs) ? 
state_cur : (float*)dst->src[5]->data + state_offset;\n\n            for (int64_t h = h_start; h < h_end; h++) {\n                size_t h_offset = h * h_stride;\n                size_t t_h_offset = t_offset + h_offset;\n                size_t h_2d_offset = h * h_stride_2d;\n\n                for (int64_t i = 0; i < head_size; i++) {\n                    size_t t_h_i_offset = t_h_offset + i;\n                    size_t h_i_offset = h_offset + i;\n                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;\n\n                    float k_val = k[t_h_i_offset];\n                    float r_val = r[t_h_i_offset];\n                    float time_faaaa_val = time_faaaa[h_i_offset];\n                    float time_decay_val = time_decay[t_h_i_offset];\n\n                    // Broadcast scalar values to vectors\n                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);\n                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);\n                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);\n                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);\n\n                    for (int64_t j = 0; j < vec_count; j++) {\n                        size_t base_j = j * wkv_vector_size;\n                        size_t t_h_j_offset = t_h_offset + base_j;\n                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;\n\n                        // Load x elements at once\n                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);\n                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);\n                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);\n\n                        // Compute kv = v * k\n                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);\n\n                        // Compute temp = kv * time_faaaa + prev_state\n                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);\n\n     
                   // Update dst: dst += temp * r\n                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);\n                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);\n\n                        // Update state: state = prev_state * time_decay + kv\n                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);\n                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);\n                    }\n\n                    // Handle remaining elements, this will not be used.\n                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {\n                        size_t t_h_j_offset = t_h_offset + j;\n                        size_t h_2d_i_j_offset = h_2d_i_offset + j;\n                        float v_val = v[t_h_j_offset];\n                        float kv_val = v_val * k_val;\n                        float prev_state_val = state_prev[h_2d_i_j_offset];\n                        float temp_val = kv_val * time_faaaa_val + prev_state_val;\n                        dst_data[t_h_j_offset] += temp_val * r_val;\n                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;\n                    }\n                }\n            }\n        }\n\n    #else\n        // basically fused operations:\n        // dst = r @ (time_faaaa * (k @ v) + state),\n        // state = time_decay * state + (k @ v),\n        // recursive through each token\n        for (int64_t t = 0; t < T; t++) {\n            size_t t_offset = t * t_stride;\n            size_t state_offset = head_size * C * (t / (T / n_seqs));\n            float * state_cur = state + state_offset;\n            float * state_prev = t % (T / n_seqs) ? 
state_cur : (float*)dst->src[5]->data + state_offset;\n\n            for (int64_t h = h_start; h < h_end; h++) {\n                size_t h_offset = h * h_stride;\n                size_t t_h_offset = t_offset + h_offset;\n                size_t h_2d_offset = h * h_stride_2d;\n\n                for (int64_t i = 0; i < head_size; i++) {\n                    size_t t_h_i_offset = t_h_offset + i;\n                    size_t h_i_offset = h_offset + i;\n                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;\n\n                    float k_val = k[t_h_i_offset];\n                    float r_val = r[t_h_i_offset];\n                    float time_faaaa_val = time_faaaa[h_i_offset];\n                    // RWKV v6: different time_decay for each token.\n                    float time_decay_val = time_decay[t_h_i_offset];\n\n                    for (int64_t j = 0; j < head_size; j++) {\n                        size_t t_h_j_offset = t_h_offset + j;\n                        size_t h_2d_i_j_offset = h_2d_i_offset + j;\n\n                        float v_val = v[t_h_j_offset];\n                        float kv_val = v_val * k_val;\n                        float prev_state_val = state_prev[h_2d_i_j_offset];\n                        float temp_val = kv_val * time_faaaa_val + prev_state_val;\n                        dst_data[t_h_j_offset] += temp_val * r_val;\n                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;\n                    }\n                }\n            }\n        }\n    #endif\n}\n\n\nvoid ggml_compute_forward_rwkv_wkv6(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_rwkv_wkv6_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// 
ggml_compute_forward_gla\n\nstatic void ggml_compute_forward_gla_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    const int64_t T = dst->src[1]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t HEADS = dst->src[1]->ne[1];\n    const int64_t n_seqs = dst->src[4]->ne[1];\n    const int64_t head_size = C / HEADS;\n    const float scale = ggml_get_op_params_f32(dst, 0);\n\n    float * dst_data = (float *) dst->data;\n    float * state = ((float *) dst->data) + C * T;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    if (ith >= HEADS) {\n        return;\n    }\n\n    const int h_start = (HEADS * ith) / nth;\n    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?\n                (HEADS * (ith + 1)) / nth : HEADS;\n\n    float * k = (float *) dst->src[0]->data;\n    float * v = (float *) dst->src[1]->data;\n    float * q = (float *) dst->src[2]->data;\n    float * g = (float *) dst->src[3]->data;\n\n    size_t t_stride = HEADS * head_size; // Same to C\n\n    size_t h_stride = C / HEADS;\n    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS\n    size_t h_stride_2d = head_size * head_size;\n\n    if (ith == 0) {\n        memset(dst_data, 0, T * C * sizeof(float));\n    }\n    ggml_barrier(params->threadpool);\n\n\n    #if defined(__AVX__) && !defined(__AVX512F__)\n        #define GGML_F32X GGML_F32x8\n        #define GGML_F32X_SET1 GGML_F32x8_SET1\n        #define GGML_F32X_LOAD GGML_F32x8_LOAD\n        #define GGML_F32X_STORE GGML_F32x8_STORE\n        #define GGML_F32X_MUL GGML_F32x8_MUL\n        #define GGML_F32X_FMA GGML_F32x8_FMA\n        #define GLA_VECTOR_SIZE 8\n    #elif defined(__AVX512F__)\n        #define GGML_F32X GGML_F32x16\n        #define GGML_F32X_SET1 GGML_F32x16_SET1\n        #define GGML_F32X_LOAD GGML_F32x16_LOAD\n        #define GGML_F32X_STORE GGML_F32x16_STORE\n        #define GGML_F32X_MUL GGML_F32x16_MUL\n        #define GGML_F32X_FMA GGML_F32x16_FMA\n     
   #define GLA_VECTOR_SIZE 16\n    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)\n        #define GGML_F32X GGML_F32xt\n        #define GGML_F32X_SET1 GGML_F32xt_SET1\n        #define GGML_F32X_LOAD GGML_F32xt_LOAD\n        #define GGML_F32X_STORE GGML_F32xt_STORE\n        #define GGML_F32X_MUL GGML_F32xt_MUL\n        #define GGML_F32X_FMA GGML_F32xt_FMA\n        #define GLA_VECTOR_SIZE 8\n    #elif defined(__ARM_NEON) && defined(__aarch64__)\n        #define GGML_F32X GGML_F32x4\n        #define GGML_F32X_SET1 GGML_F32x4_SET1\n        #define GGML_F32X_LOAD GGML_F32x4_LOAD\n        #define GGML_F32X_STORE GGML_F32x4_STORE\n        #define GGML_F32X_MUL GGML_F32x4_MUL\n        #define GGML_F32X_FMA GGML_F32x4_FMA\n        #define GLA_VECTOR_SIZE 4\n    #endif\n\n    #ifdef GLA_VECTOR_SIZE\n        int gla_vector_size;\n        #if defined(__ARM_FEATURE_SVE)\n            gla_vector_size = svcntw();\n        #else\n            gla_vector_size = GLA_VECTOR_SIZE;\n        #endif\n        const int64_t vec_count = head_size / gla_vector_size;\n\n        for (int64_t t = 0; t < T; t++) {\n            size_t t_offset = t * t_stride;\n            size_t state_offset = head_size * C * (t / (T / n_seqs));\n            float * state_cur = state + state_offset;\n            float * state_prev = t % (T / n_seqs) ? 
state_cur : (float*)dst->src[4]->data + state_offset;\n\n            for (int64_t h = h_start; h < h_end; h++) {\n                size_t h_offset = h * h_stride;\n                size_t t_h_offset = t_offset + h_offset;\n                size_t h_2d_offset = h * h_stride_2d;\n\n                for (int64_t i = 0; i < head_size; i++) {\n                    size_t t_h_i_offset = t_h_offset + i;\n                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;\n\n                    float k_val = k[t_h_i_offset];\n                    float q_val = q[t_h_i_offset] * scale;\n                    float g_val = g[t_h_i_offset];\n\n                    // Broadcast scalar values to vectors\n                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);\n                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);\n                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);\n\n                    for (int64_t j = 0; j < vec_count; j++) {\n                        size_t base_j = j * gla_vector_size;\n                        size_t t_h_j_offset = t_h_offset + base_j;\n                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;\n\n                        // Load x elements at once\n                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);\n                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);\n                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);\n\n                        // Compute kv = v * k\n                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);\n\n                        // Compute temp = prev_state * g + kv\n                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);\n\n                        // Update dst: dst += temp * q\n                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);\n                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);\n\n                        // Update state\n   
                     GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);\n                    }\n\n                    // Handle remaining elements, this will not be used.\n                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {\n                        size_t t_h_j_offset = t_h_offset + j;\n                        size_t h_2d_i_j_offset = h_2d_i_offset + j;\n                        float v_val = v[t_h_j_offset];\n                        float kv_val = v_val * k_val;\n                        float prev_state_val = state_prev[h_2d_i_j_offset];\n                        float temp_val = kv_val + prev_state_val * g_val;\n                        dst_data[t_h_j_offset] += temp_val * q_val;\n                        state_cur[h_2d_i_j_offset] = temp_val;\n                    }\n                }\n            }\n        }\n\n    #else\n        for (int64_t t = 0; t < T; t++) {\n            size_t t_offset = t * t_stride;\n            size_t state_offset = head_size * C * (t / (T / n_seqs));\n            float * state_cur = state + state_offset;\n            float * state_prev = t % (T / n_seqs) ? 
state_cur : (float*)dst->src[4]->data + state_offset;\n\n            for (int64_t h = h_start; h < h_end; h++) {\n                size_t h_offset = h * h_stride;\n                size_t t_h_offset = t_offset + h_offset;\n                size_t h_2d_offset = h * h_stride_2d;\n\n                for (int64_t i = 0; i < head_size; i++) {\n                    size_t t_h_i_offset = t_h_offset + i;\n                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;\n\n                    float k_val = k[t_h_i_offset];\n                    float q_val = q[t_h_i_offset] * scale;\n                    float g_val = g[t_h_i_offset];\n\n                    for (int64_t j = 0; j < head_size; j++) {\n                        size_t t_h_j_offset = t_h_offset + j;\n                        size_t h_2d_i_j_offset = h_2d_i_offset + j;\n\n                        float v_val = v[t_h_j_offset];\n                        float kv_val = v_val * k_val;\n                        float prev_state_val = state_prev[h_2d_i_j_offset];\n                        float temp_val = prev_state_val * g_val + kv_val;\n                        dst_data[t_h_j_offset] += temp_val * q_val;\n                        state_cur[h_2d_i_j_offset] = temp_val;\n                    }\n                }\n            }\n        }\n    #endif\n}\n\n\nvoid ggml_compute_forward_gla(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_gla_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nstatic void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {\n    const struct ggml_tensor * src0 = dst->src[0];  // A (lower triangular)\n    const struct ggml_tensor * src1 = dst->src[1];  // B (RHS)\n\n    
GGML_TENSOR_BINARY_OP_LOCALS;\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n\n    GGML_ASSERT(ne00 == ne01); // A must be square\n    GGML_ASSERT(ne0  == ne10); // solution cols == B cols\n    GGML_ASSERT(ne1  == ne11); // solution rows == B rows\n\n    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);\n    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int64_t k = ne10;   // number of RHS columns\n    const int64_t n = ne11;   // A is n×n\n    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit\n\n    // chunks per thread\n    const int64_t dr = (nr + nth - 1)/nth;\n\n    // chunk range for this thread\n    const int64_t ir0 = dr*ith;\n    const int64_t ir1 = MIN(ir0 + dr, nr);\n\n    const float * A = (const float *) src0->data;  // [n, n, B1, B2]\n    const float * B = (const float *) src1->data;  // [n, k, B1, B2]\n          float * X = (      float *) dst->data;   // [n, k, B1, B2]\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir/(ne02*k);\n        const int64_t i02 = (ir - i03*ne02*k)/k;\n        const int64_t i01 = (ir - i03*ne02*k - i02*k);\n\n        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);\n        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);\n\n        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);\n\n        for (int64_t i00 = 0; i00 < n; ++i00) {\n            float sum = 0.0f;\n            for (int64_t t = 0; t < i00; ++t) {\n                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];\n            }\n\n            const float diag = A_batch[i00 * n + i00];\n            assert(diag != 0.0f && \"Zero diagonal in triangular matrix\");\n\n            X_batch[i00 * k + i01] = 
(B_batch[i00 * k + i01] - sum) / diag;\n        }\n    }\n}\n\nvoid ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {\n        ggml_compute_forward_solve_tri_f32(params, dst);\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\n// ggml_compute_forward_gated_delta_net\nstatic void ggml_compute_forward_gated_delta_net_one_chunk(\n    const ggml_compute_params * params,\n    ggml_tensor * dst,\n    int64_t ir0,\n    int64_t ir1) {\n\n    ggml_tensor * src_q     = dst->src[0];\n    ggml_tensor * src_k     = dst->src[1];\n    ggml_tensor * src_v     = dst->src[2];\n    ggml_tensor * src_g     = dst->src[3];\n    ggml_tensor * src_beta  = dst->src[4];\n    ggml_tensor * src_state = dst->src[5];\n\n    const int64_t S_v      = src_v->ne[0];\n    const int64_t H        = src_v->ne[1];\n    const int64_t n_tokens = src_v->ne[2];\n    const int64_t n_seqs   = src_v->ne[3];\n\n    GGML_ASSERT(ggml_is_contiguous_rows(src_q));\n    GGML_ASSERT(ggml_is_contiguous_rows(src_k));\n    GGML_ASSERT(ggml_is_contiguous_rows(src_v));\n    GGML_ASSERT(ggml_is_contiguous(src_g));\n    GGML_ASSERT(ggml_is_contiguous(src_beta));\n    GGML_ASSERT(ggml_is_contiguous(src_state));\n\n    GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);\n    GGML_ASSERT(src_beta->ne[0] == 1);\n\n    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);\n    GGML_TENSOR_LOCALS(size_t,  nbq, src_q, nb);\n    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);\n    GGML_TENSOR_LOCALS(size_t,  nbk, src_k, nb);\n    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);\n    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);\n    GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);\n    GGML_TENSOR_LOCALS(size_t,  nbg, src_g, nb);\n    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);\n\n    const bool kda = (neg0 == S_v);\n\n   
 // scratch layout per thread: [delta(S_v)]\n    const int64_t scratch_per_thread = S_v;\n    const int ith = params->ith;\n\n    float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;\n\n    // output layout: [attn_scores | new_states]\n    // attn_scores: S_v * H * n_tokens * n_seqs floats\n    // new_states:  S_v * S_v * H * n_seqs floats\n    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;\n    float * attn_out_base  = (float *)dst->data;\n    float * state_out_base = (float *)dst->data + attn_score_elems;\n\n    const float * state_in_base = (const float *)src_state->data;\n\n  //const int64_t rq1 = nev1 / neq1;\n  //const int64_t rk1 = nev1 / nek1;\n    const int64_t rq3 = nev3 / neq3;\n    const int64_t rk3 = nev3 / nek3;\n\n    const float scale = 1.0f / sqrtf((float) S_v);\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t iv1 = ir % H; // head_index\n        const int64_t iv3 = ir / H; // sequence\n\n        const int64_t iq1 = iv1 % neq1;\n        const int64_t ik1 = iv1 % nek1;\n\n        const int64_t iq3 = iv3 / rq3;\n        const int64_t ik3 = iv3 / rk3;\n\n        float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;\n\n        // copy input state into output buffer and operate in-place\n        const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;\n        memcpy(s_out, s_in, S_v * S_v * sizeof(float));\n\n        // attn output pointer for first token of this (head, seq)\n        float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;\n\n        for (int64_t t = 0; t < n_tokens; t++) {\n            const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);\n            const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);\n            const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);\n\n            const 
float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);\n            const float * g_d    =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);\n\n            // state is stored transposed: s_out[j*S_v + i] = S[i][j]\n            // so row j of s_out = column j of S (contiguous access)\n\n            if (kda) {\n                // precompute exp(g) into delta scratch (reused below)\n                for (int64_t i = 0; i < S_v; ++i) {\n                    delta[i] = expf(g_d[i]);\n                }\n                // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])\n                for (int64_t j = 0; j < S_v; ++j) {\n                    ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);\n                }\n            } else {\n                ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));\n            }\n\n            // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)\n            for (int64_t j = 0; j < S_v; ++j) {\n                float sum = 0.0f;\n                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);\n                delta[j] = (v_d[j] - sum) * beta_val;\n            }\n\n            // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]\n            for (int64_t j = 0; j < S_v; ++j) {\n                ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);\n            }\n\n            // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)\n            for (int64_t j = 0; j < S_v; ++j) {\n                float sum = 0.0f;\n                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);\n                attn_data[j] = sum * scale;\n            }\n\n            attn_data += S_v * H; // advance to next token\n        }\n    }\n}\n\n\nstatic void ggml_compute_forward_gated_delta_net_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    ggml_tensor * V 
= dst->src[2];\n    int64_t nr = V->ne[1] * V->ne[3];\n\n    // disable for NUMA\n    const bool disable_chunking = ggml_is_numa();\n\n    int nth = params->nth;\n    int ith = params->ith;\n\n    // 4x chunks per thread\n    int nth_scaled = nth * 4;\n    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;\n    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;\n\n    if (nth == 1 || nchunk < nth || disable_chunking) {\n      nchunk = nth;\n    }\n\n    if (ith == 0) {\n      ggml_threadpool_chunk_set(params->threadpool, nth);\n    }\n\n    ggml_barrier(params->threadpool);\n\n    const int64_t dr = (nr + nchunk - 1) / nchunk;\n\n    int current_chunk = ith;\n\n    while (current_chunk < nchunk) {\n        const int64_t ir0 = dr * current_chunk;\n        const int64_t ir1 = MIN(ir0 + dr, nr);\n\n        ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);\n        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);\n    }\n}\n\nvoid ggml_compute_forward_gated_delta_net(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_gated_delta_net_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_rwkv_wkv7\n\nstatic void ggml_compute_forward_rwkv_wkv7_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n    const int64_t T = dst->src[1]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t HEADS = dst->src[1]->ne[1];\n    const int64_t n_seqs = dst->src[6]->ne[1];\n    const int64_t head_size = C / HEADS;\n\n    float * dst_data = (float *) dst->data;\n    float * state = ((float *) dst->data) + C * T;\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    if (ith >= HEADS) {\n     
   return;\n    }\n\n    const int h_start = (HEADS * ith) / nth;\n    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?\n                (HEADS * (ith + 1)) / nth : HEADS;\n\n    float * r = (float *) dst->src[0]->data;\n    float * w = (float *) dst->src[1]->data;\n    float * k = (float *) dst->src[2]->data;\n    float * v = (float *) dst->src[3]->data;\n    float * a = (float *) dst->src[4]->data;\n    float * b = (float *) dst->src[5]->data;\n\n    int64_t t_stride = HEADS * head_size; // Same to C\n\n    int64_t h_stride = C / HEADS;\n    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS\n    int64_t h_stride_2d = head_size * head_size;\n\n    #if defined(GGML_SIMD)\n        #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)\n            // scalar Route to scalar implementation       //TODO: Write SVE code and RVV code\n            for (int64_t t = 0; t < T; t++) {\n                int64_t t_offset = t * t_stride;\n                int64_t state_offset = head_size * C * (t / (T / n_seqs));\n                float * state_cur = state + state_offset;\n                float * state_prev = t % (T / n_seqs) ? 
state_cur : (float*)dst->src[6]->data + state_offset;\n\n                for (int64_t h = h_start; h < h_end; h++) {\n                    int64_t h_offset = h * h_stride;\n                    int64_t t_h_offset = t_offset + h_offset;\n                    int64_t h_2d_offset = h * h_stride_2d;\n\n                    for (int64_t i = 0; i < head_size; i++) {\n                        int64_t t_h_i_offset = t_h_offset + i;\n                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;\n\n                        float v_val = v[t_h_i_offset];\n\n                        float sa = 0, result = 0;\n                        for (int64_t j = 0; j < head_size; j++) {\n                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];\n                        }\n\n                        for (int64_t j = 0; j < head_size; j++) {\n                            int64_t t_h_j_offset = t_h_offset + j;\n                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;\n\n                            float r_val = r[t_h_j_offset];\n                            float w_val = w[t_h_j_offset];\n                            float k_val = k[t_h_j_offset];\n                            float b_val = b[t_h_j_offset];\n                            float kv_val = v_val * k_val;\n                            float prev_state_val = state_prev[h_2d_i_j_offset];\n                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;\n                            result += state_cur[h_2d_i_j_offset] * r_val;\n                        }\n                        dst_data[t_h_i_offset] = result;\n                    }\n                }\n            }\n        #else\n            for (int64_t t = 0; t < T; t++) {\n                int64_t t_offset = t * t_stride;\n                int64_t state_offset = head_size * C * (t / (T / n_seqs));\n                float * state_cur = state + state_offset;\n                float * state_prev = t % (T / 
n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;\n\n                for (int64_t h = h_start; h < h_end; h++) {\n                    int64_t h_offset = h * h_stride;\n                    int64_t t_h_offset = t_offset + h_offset;\n                    int64_t h_2d_offset = h * h_stride_2d;\n\n                    for (int64_t ii = 0; ii < head_size; ii++) {\n                        int64_t t_h_i_offset = t_h_offset + ii;\n                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;\n\n                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);\n\n                        float sa = 0;\n                        {\n                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };\n                            GGML_F32_VEC ax[GGML_F32_ARR];\n                            GGML_F32_VEC ay[GGML_F32_ARR];\n                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {\n                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {\n                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);\n                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);\n                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);\n                                }\n                            }\n                            GGML_F32_VEC_REDUCE(sa, sum);\n                        }\n\n                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);\n\n                        int64_t j = 0;\n                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };\n                        for (; j < head_size; j += GGML_F32_STEP) {\n                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {\n                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;\n                                int64_t h_2d_i_j_offset 
= h_2d_i_offset + j + kk * GGML_F32_EPR;\n\n                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);\n                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);\n                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);\n                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);\n\n                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);\n\n                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);\n                                // kv + s * decay + sa * b\n                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);\n                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);\n                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);\n\n                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);\n                            }\n                        }\n                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);\n\n                        // There shouldn't be left-overs though.\n                        for (; j < head_size; j++) {\n                            int64_t t_h_j_offset = t_h_offset + j;\n                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;\n\n                            float r_val = r[t_h_j_offset];\n                            float w_val = w[t_h_j_offset];\n                            float k_val = k[t_h_j_offset];\n                            float b_val = b[t_h_j_offset];\n                            float kv_val = v[t_h_i_offset] * k_val;\n\n                            float prev_state_val = state_prev[h_2d_i_j_offset];\n                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;\n                            dst_data[t_h_i_offset] += 
state_cur[h_2d_i_j_offset] * r_val;\n                        }\n                    }\n                }\n            }\n        #endif\n    #else\n        for (int64_t t = 0; t < T; t++) {\n            int64_t t_offset = t * t_stride;\n            int64_t state_offset = head_size * C * (t / (T / n_seqs));\n            float * state_cur = state + state_offset;\n            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;\n\n            for (int64_t h = h_start; h < h_end; h++) {\n                int64_t h_offset = h * h_stride;\n                int64_t t_h_offset = t_offset + h_offset;\n                int64_t h_2d_offset = h * h_stride_2d;\n\n                for (int64_t i = 0; i < head_size; i++) {\n                    int64_t t_h_i_offset = t_h_offset + i;\n                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;\n\n                    float v_val = v[t_h_i_offset];\n\n                    float sa = 0, result = 0;\n                    for (int64_t j = 0; j < head_size; j++) {\n                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];\n                    }\n\n                    for (int64_t j = 0; j < head_size; j++) {\n                        int64_t t_h_j_offset = t_h_offset + j;\n                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;\n\n                        float r_val = r[t_h_j_offset];\n                        float w_val = w[t_h_j_offset];\n                        float k_val = k[t_h_j_offset];\n                        float b_val = b[t_h_j_offset];\n                        float kv_val = v_val * k_val;\n                        float prev_state_val = state_prev[h_2d_i_j_offset];\n                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;\n                        result += state_cur[h_2d_i_j_offset] * r_val;\n                    }\n                    dst_data[t_h_i_offset] = result;\n                }\n            }\n 
       }\n    #endif\n}\n\n\nvoid ggml_compute_forward_rwkv_wkv7(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_rwkv_wkv7_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_map_custom1\n\nvoid ggml_compute_forward_map_custom1(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * a = dst->src[0];\n\n    struct ggml_map_custom1_op_params p;\n    memcpy(&p, dst->op_params, sizeof(p));\n\n    p.fun(dst, a, params->ith, params->nth, p.userdata);\n}\n\n// ggml_compute_forward_map_custom2\n\nvoid ggml_compute_forward_map_custom2(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * a = dst->src[0];\n    const ggml_tensor * b = dst->src[1];\n\n    struct ggml_map_custom2_op_params p;\n    memcpy(&p, dst->op_params, sizeof(p));\n\n    p.fun(dst, a, b, params->ith, params->nth, p.userdata);\n}\n\n// ggml_compute_forward_map_custom3\n\nvoid ggml_compute_forward_map_custom3(\n        const ggml_compute_params * params,\n              ggml_tensor * dst) {\n\n    const ggml_tensor * a = dst->src[0];\n    const ggml_tensor * b = dst->src[1];\n    const ggml_tensor * c = dst->src[2];\n\n    struct ggml_map_custom3_op_params p;\n    memcpy(&p, dst->op_params, sizeof(p));\n\n    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);\n}\n\n// ggml_compute_forward_custom\n\nvoid ggml_compute_forward_custom(\n    const struct ggml_compute_params * params,\n          struct ggml_tensor * dst) {\n\n    struct ggml_custom_op_params p;\n    memcpy(&p, dst->op_params, sizeof(p));\n\n    p.fun(dst, params->ith, params->nth, p.userdata);\n}\n\n// 
ggml_compute_forward_cross_entropy_loss\n\nstatic void ggml_compute_forward_cross_entropy_loss_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));\n    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));\n    GGML_ASSERT(ggml_are_same_shape(src0, src1));\n    GGML_ASSERT(ggml_is_scalar(dst));\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    // TODO: handle transposed/permuted matrices\n    const int64_t nc = src0->ne[0];\n    const int64_t nr = ggml_nrows(src0);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    float * sums =  (float *) params->wdata;\n    float * st   = ((float *) params->wdata) + nth + ith*nc;\n    float sum_thread = 0.0f;\n\n    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));\n\n    // rows per thread\n    const int64_t dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int64_t ir0 = dr*ith;\n    const int64_t ir1 = MIN(ir0 + dr, nr);\n\n    for (int64_t i1 = ir0; i1 < ir1; ++i1) {\n        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);\n        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);\n\n#ifndef NDEBUG\n        for (int64_t i = 0; i < nc; ++i) {\n            //printf(\"p[%d] = %f\\n\", i, p[i]);\n            assert(!isnan(s0[i]));\n            assert(!isnan(s1[i]));\n        }\n#endif // NDEBUG\n\n        float max = -INFINITY;\n        ggml_vec_max_f32(nc, &max, s0);\n        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);\n        assert(sum_softmax >= 0.0);\n\n        ggml_vec_add1_f32(nc, st, st, -sum_softmax);\n        ggml_vec_mul_f32(nc, st, st, s1);\n\n        float sum_st = 0.0f;\n        
ggml_vec_sum_f32(nc, &sum_st, st);\n        sum_thread += sum_st;\n\n#ifndef NDEBUG\n        for (int64_t i = 0; i < nc; ++i) {\n            assert(!isnan(st[i]));\n            assert(!isinf(st[i]));\n        }\n#endif // NDEBUG\n    }\n    sums[ith] = sum_thread;\n    ggml_barrier(params->threadpool);\n\n    if (ith == 0) {\n        float * dp = (float *) dst->data;\n        ggml_vec_sum_f32(nth, dp, sums);\n        dp[0] *= -1.0f / (float) nr;\n    }\n}\n\nvoid ggml_compute_forward_cross_entropy_loss(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_cross_entropy_loss_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\n// ggml_compute_forward_cross_entropy_loss_back\n\nstatic void ggml_compute_forward_cross_entropy_loss_back_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output\n    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass\n    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass\n\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(src0f));\n    GGML_ASSERT(ggml_is_contiguous(src1f));\n    GGML_ASSERT(ggml_is_contiguous(grad));\n    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));\n\n    const int64_t ith = params->ith;\n    const int64_t nth = params->nth;\n\n    // TODO: handle transposed/permuted matrices\n    const int64_t nc = src0f->ne[0];\n    const int64_t nr = ggml_nrows(src0f);\n\n    // rows per thread\n    const int64_t dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int64_t ir0 = dr*ith;\n    const int64_t ir1 = MIN(ir0 + dr, nr);\n\n    const float 
d_by_nr = ((const float *) grad->data)[0] / (float) nr;\n\n    for (int64_t i1 = ir0; i1 < ir1; i1++) {\n        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);\n        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);\n        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);\n\n#ifndef NDEBUG\n        for (int64_t i = 0; i < nc; ++i) {\n            //printf(\"p[%d] = %f\\n\", i, p[i]);\n            assert(!isnan(s0[i]));\n            assert(!isnan(s1[i]));\n        }\n#endif // NDEBUG\n\n        // soft_max\n        float max = -INFINITY;\n        ggml_vec_max_f32(nc, &max, s0);\n        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);\n        assert(sum > 0.0);\n        ggml_vec_scale_f32(nc, ds0, 1.0/sum);\n\n        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr\n        ggml_vec_sub_f32(nc, ds0, ds0, s1);\n        ggml_vec_scale_f32(nc, ds0, d_by_nr);\n\n#ifndef NDEBUG\n        for (int64_t i = 0; i < nc; ++i) {\n            assert(!isnan(ds0[i]));\n            assert(!isinf(ds0[i]));\n        }\n#endif // NDEBUG\n    }\n}\n\nvoid ggml_compute_forward_cross_entropy_loss_back(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nstatic void ggml_compute_forward_opt_step_adamw_f32(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0         = dst->src[0];\n    const ggml_tensor * src0_grad    = dst->src[1];\n    const ggml_tensor * src0_grad_m  = dst->src[2];\n    const ggml_tensor * src0_grad_v  = 
dst->src[3];\n    const ggml_tensor * adamw_params = dst->src[4];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));\n    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));\n    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));\n    GGML_ASSERT(ggml_nelements(adamw_params) == 7);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr  = ggml_nrows(src0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n    GGML_ASSERT(nb00 == sizeof(float));\n\n    // rows per thread\n    const int dr = (nr + nth - 1)/nth;\n\n    // row range for this thread\n    const int ir0 = dr*ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);\n\n    const float alpha  = adamw_params_ptr[0];\n    const float beta1  = adamw_params_ptr[1];\n    const float beta2  = adamw_params_ptr[2];\n    const float eps    = adamw_params_ptr[3];\n    const float wd     = adamw_params_ptr[4];\n    const float beta1h = adamw_params_ptr[5];\n    const float beta2h = adamw_params_ptr[6];\n    const float keep   = 1.f - alpha * wd;\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir/(ne02*ne01);\n        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;\n        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);\n\n        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;\n\n        float       * w = (float       *) ((char       *) src0->data        + offset); // weight\n        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad\n        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);\n        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);\n\n        for (int i00 = 0; i00 < ne00; ++i00) {\n            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);\n            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);\n\n            const float mh =       m[i00]*beta1h;\n            const 
float vh = sqrtf(v[i00]*beta2h) + eps;\n\n            // The weight decay is applied independently of the Adam momenta m and v.\n            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.\n            // See: https://arxiv.org/pdf/1711.05101v3.pdf\n            w[i00] = w[i00] * keep - alpha * mh / vh;\n        }\n    }\n}\n\nvoid ggml_compute_forward_opt_step_adamw(\n        const ggml_compute_params * params,\n        ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_opt_step_adamw_f32(params, dst);\n            } break;\n        default:\n            {\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n}\n\nstatic void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0       = dst->src[0];\n    const ggml_tensor * src0_grad  = dst->src[1];\n    const ggml_tensor * sgd_params = dst->src[2];\n\n    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));\n    GGML_ASSERT(ggml_nelements(sgd_params) == 2);\n\n    const int ith = params->ith;\n    const int nth = params->nth;\n\n    const int nr = ggml_nrows(src0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n    GGML_ASSERT(nb00 == sizeof(float));\n\n    // rows per thread\n    const int dr = (nr + nth - 1) / nth;\n\n    // row range for this thread\n    const int ir0 = dr * ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    // using adamw param subset we care about - alpha, wd - could have a separate struct\n    const float * sgd_params_ptr   = ggml_get_data_f32(sgd_params);\n    const float   alpha            = sgd_params_ptr[0];\n    const float   keep             = 1.f - alpha * sgd_params_ptr[1];\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir / (ne02 * ne01);\n        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;\n        const int64_t 
i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);\n\n        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;\n\n        float *       w = (float *) ((char *) src0->data + offset);                   // weight\n        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad\n\n        for (int i00 = 0; i00 < ne00; ++i00) {\n            w[i00] = w[i00] * keep - alpha * g[i00];\n        }\n    }\n}\n\nvoid ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            {\n                ggml_compute_forward_opt_step_sgd_f32(params, dst);\n            }\n            break;\n        default:\n            {\n                GGML_ABORT(\"fatal error - sgd is F32 only\");\n            }\n    }\n}\n"
  },
  {
    "path": "src/ggml-cpu/ops.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n\n//\n// cache line\n//\n\n#if defined(__cpp_lib_hardware_interference_size)\n#define CACHE_LINE_SIZE std::hardware_destructive_interference_size\n#else\n#if defined(__POWER9_VECTOR__)\n#define CACHE_LINE_SIZE 128\n#elif defined(__VXE__) || defined(__VXE2__)\n#define CACHE_LINE_SIZE 256\n#else\n#define CACHE_LINE_SIZE 64\n#endif\n#endif\n\nstatic const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);\n\n// Work buffer size for im2col operations in CONV2D\n#define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024)\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nvoid ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_repeat(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_repeat_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid 
ggml_compute_forward_concat(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_out_prod(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_scale(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_set(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_cpy(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_cont(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_soft_max(const struct ggml_compute_params * params, struct ggml_tensor * 
dst);\nvoid ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_pool_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_roll(const struct 
ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_flash_attn_back(\n        const struct ggml_compute_params * params,\n        const bool masked,\n        struct ggml_tensor * dst);\nvoid ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, 
struct ggml_tensor * dst);\nvoid ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n\n#include \"ggml-cpu-impl.h\"\n#include \"simd-mappings.h\"\n#include \"ggml-quants.h\"\n#include \"quants.h\"\n\n#include \"arch-fallback.h\"\n\n#include <string.h>\n#include <assert.h>\n#include <float.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\nvoid quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q4_0_ref(x, y, k);\n}\n\nvoid quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q4_1_ref(x, y, k);\n}\n\nvoid quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q5_0_ref(x, y, k);\n}\n\nvoid quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q5_1_ref(x, y, k);\n}\n\nvoid quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q8_0_ref(x, y, k);\n}\n\nvoid quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q8_1_ref(x, y, k);\n}\n\nvoid quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_mxfp4_ref(x, y, k);\n}\n\nvoid quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_nvfp4_ref(x, y, k);\n}\n\n//\n// 2-6 bit quantization in super-blocks\n//\n\n//========================- 2-bit (de)-quantization\n\nvoid quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    quantize_row_q2_K_ref(x, vy, k);\n}\n\n//========================= 3-bit (de)-quantization\n\nvoid quantize_row_q3_K(const float * 
GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    quantize_row_q3_K_ref(x, vy, k);\n}\n\n// ====================== 4-bit (de)-quantization\n\nvoid quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK_K == 0);\n    block_q4_K * GGML_RESTRICT y = vy;\n    quantize_row_q4_K_ref(x, y, k);\n}\n\n// ====================== 5-bit (de)-quantization\n\nvoid quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK_K == 0);\n    block_q5_K * GGML_RESTRICT y = vy;\n    quantize_row_q5_K_ref(x, y, k);\n}\n\n// ====================== 6-bit (de)-quantization\n\nvoid quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK_K == 0);\n    block_q6_K * GGML_RESTRICT y = vy;\n    quantize_row_q6_K_ref(x, y, k);\n}\n\n// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)\n\nvoid quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK_K == 0);\n    block_tq1_0 * GGML_RESTRICT y = vy;\n    quantize_row_tq1_0_ref(x, y, k);\n}\n\nvoid quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(k % QK_K == 0);\n    block_tq2_0 * GGML_RESTRICT y = vy;\n    quantize_row_tq2_0_ref(x, y, k);\n}\n\n//===================================== Q8_K ==============================================\n\nvoid quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    quantize_row_q8_K_ref(x, y, k);\n}\n\n//===================================== Dot products =================================\n\nvoid ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    
UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n    for (; ib < nb; ++ib) {\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int v0 = (x[ib].qs[j] & 0x0F) - 8;\n            const int v1 = (x[ib].qs[j] >>   4) - 8;\n\n            sumi0 += (v0 * y[ib].qs[j]);\n            sumi1 += (v1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);\n    }\n\n    *s = sumf;\n}\n\n// TODO: add WASM SIMD\nvoid ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n    for (; ib < nb; ++ib) {\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int v0 = (x[ib].qs[j] & 0x0F);\n            const int v1 = (x[ib].qs[j] >>   4);\n\n            sumi0 += (v0 * y[ib].qs[j]);\n            sumi1 += (v1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    
assert(n % QK_MXFP4 == 0);\n    static_assert(QK_MXFP4 == QK8_0, \"QK_MXFP4 and QK8_0 must be the same\");\n\n    const block_mxfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_MXFP4;\n\n    int ib = 0;\n    float sumf = 0;\n\n    for (; ib < nb; ++ib) {\n        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);\n\n        int sumi1 = 0;\n        int sumi2 = 0;\n        for (int j = 0; j < QK_MXFP4/2; ++j) {\n            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];\n            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];\n        }\n        sumf += d * (sumi1 + sumi2);\n    }\n    *s = sumf;\n}\n\n// NVFP4: super-block of 64 elements = 4 sub-blocks of 16 = 2 q8_0 blocks\nvoid ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_NVFP4 == 0);\n\n    const block_nvfp4 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_NVFP4;\n\n    float sumf = 0;\n\n    for (int ib = 0; ib < nb; ++ib) {\n        for (int s_idx = 0; s_idx < 4; ++s_idx) {\n            const float d = ggml_ue4m3_to_fp32(x[ib].d[s_idx]);\n            const int q8_block = s_idx / 2;\n            const int q8_off   = (s_idx % 2) * QK_NVFP4_SUB;\n            const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);\n\n            int sumi_lo = 0, sumi_hi = 0;\n            for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {\n                const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];\n                sumi_lo += y[2*ib + q8_block].qs[q8_off + j +               0] * kvalues_mxfp4[qv & 0xf];\n                sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >>  4];\n    
        }\n\n            sumf += dy * d * (sumi_lo + sumi_hi);\n        }\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    for (; ib < nb; ++ib) {\n        uint32_t qh;\n        memcpy(&qh, x[ib].qh, sizeof(qh));\n\n        int sumi0 = 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;\n            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));\n\n            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);\n            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);\n\n            sumi0 += (x0 * y[ib].qs[j]);\n            sumi1 += (x1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_1;\n    const int nb = n / qk;\n\n    int ib = 0;\n    float sumf = 0;\n\n    assert(n % qk == 0);\n    assert(qk == QK5_1);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_1 * GGML_RESTRICT x = vx;\n    const block_q8_1 * GGML_RESTRICT y = vy;\n\n    for (; ib < nb; ++ib) {\n        uint32_t qh;\n        memcpy(&qh, x[ib].qh, sizeof(qh));\n\n        int sumi0 
= 0;\n        int sumi1 = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;\n            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;\n\n            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;\n            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;\n\n            sumi0 += (x0 * y[ib].qs[j]);\n            sumi1 += (x1 * y[ib].qs[j + qk/2]);\n        }\n\n        int sumi = sumi0 + sumi1;\n        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n\n    assert(n % qk == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q8_0 * GGML_RESTRICT x = vx;\n    const block_q8_0 * GGML_RESTRICT y = vy;\n\n    int ib = 0;\n    float sumf = 0;\n\n    for (; ib < nb; ++ib) {\n        int sumi = 0;\n\n        for (int j = 0; j < qk; j++) {\n            sumi += x[ib].qs[j]*y[ib].qs[j];\n        }\n\n        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_tq1_0 * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};\n\n    float sumf = 0.0f;\n\n    for (int i = 0; i < nb; ++i) {\n        int sum = 0;\n\n        for (size_t j = 0; j < sizeof(x->qs) - 
sizeof(x->qs) % 32; j += 32) {\n            for (size_t l = 0; l < 5; ++l) {\n                for (size_t m = 0; m < 32; ++m) {\n                    uint8_t q = x[i].qs[j + m] * pow3[l];\n                    uint16_t xi = ((uint16_t) q * 3) >> 8;\n                    sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];\n                }\n            }\n        }\n        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {\n            for (size_t l = 0; l < 5; ++l) {\n                for (size_t m = 0; m < 16; ++m) {\n                    uint8_t q = x[i].qs[j + m] * pow3[l];\n                    uint16_t xi = ((uint16_t) q * 3) >> 8;\n                    sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];\n                }\n            }\n        }\n\n        for (size_t l = 0; l < 4; ++l) {\n            for (size_t j = 0; j < sizeof(x->qh); ++j) {\n                uint8_t q = x[i].qh[j] * pow3[l];\n                uint16_t xi = ((uint16_t) q * 3) >> 8;\n                sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];\n            }\n        }\n\n        sumf += (float) sum * (GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_tq2_0 * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n    float sumf = 0.0f;\n\n    for (int i = 0; i < nb; ++i) {\n        int32_t sumi = 0;\n\n        for (size_t j = 0; j < sizeof(x->qs); j += 32) {\n            for (size_t l = 0; l < 4; ++l) {\n                for (size_t k = 0; k < 32; ++k) {\n                    sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);\n                }\n            }\n        }\n\n        const float 
d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n\n        sumf += (float) sumi * d;\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q2_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0;\n\n    for (int i = 0; i < nb; ++i) {\n\n        const uint8_t * q2 = x[i].qs;\n        const  int8_t * q8 = y[i].qs;\n        const uint8_t * sc = x[i].scales;\n\n        int summs = 0;\n        for (int j = 0; j < 16; ++j) {\n            summs += y[i].bsums[j] * (sc[j] >> 4);\n        }\n\n        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);\n        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);\n\n        int isum = 0;\n        int is = 0;\n        int d;\n        for (int k = 0; k < QK_K/128; ++k) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                d = sc[is++] & 0xF;\n                int isuml = 0;\n                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);\n                isum += d * isuml;\n                d = sc[is++] & 0xF;\n                isuml = 0;\n                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);\n                isum += d * isuml;\n                shift += 2;\n                q8 += 32;\n            }\n            q2 += 32;\n        }\n        sumf += dall * isum - dmin * summs;\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    
const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    const block_q3_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    // scalar version\n    // This function is written like this so the compiler can manage to vectorize most of it\n    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the\n    // manually vectorized version above. Every other version I tried would run at least 4 times slower.\n    // The ideal situation would be if we could just write the code once, and the compiler would\n    // automatically produce the best possible set of machine instructions, instead of us having to manually\n    // write vectorized versions for AVX, ARM_NEON, etc.\n\n    int8_t  aux8[QK_K];\n    int16_t aux16[8];\n    float   sums [8];\n    int32_t aux32[8];\n    memset(sums, 0, 8*sizeof(float));\n\n    uint32_t auxs[4];\n    const int8_t * scales = (const int8_t*)auxs;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t * GGML_RESTRICT hm = x[i].hmask;\n        const  int8_t * GGML_RESTRICT q8 = y[i].qs;\n        memset(aux32, 0, 8*sizeof(int32_t));\n        int8_t * GGML_RESTRICT a = aux8;\n        uint8_t m = 1;\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;\n            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);\n            a += 32; m <<= 1;\n            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;\n            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);\n            a += 32; m <<= 1;\n            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;\n            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 
0 : 4);\n            a += 32; m <<= 1;\n            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;\n            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);\n            a += 32; m <<= 1;\n            q3 += 32;\n        }\n        a = aux8;\n\n        memcpy(auxs, x[i].scales, 12);\n        uint32_t tmp = auxs[2];\n        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n        for (int j = 0; j < QK_K/16; ++j) {\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];\n            q8 += 8; a += 8;\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];\n            q8 += 8; a += 8;\n        }\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];\n    }\n    for (int l = 0; l < 8; ++l) sumf += sums[l];\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q4_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n    const uint8_t * scales = (const uint8_t*)&utmp[0];\n    const uint8_t * mins   = (const uint8_t*)&utmp[2];\n\n    int8_t  aux8[QK_K];\n    int16_t 
aux16[8];\n    float   sums [8];\n    int32_t aux32[8];\n    memset(sums, 0, 8*sizeof(float));\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const  int8_t * GGML_RESTRICT q8 = y[i].qs;\n        memset(aux32, 0, 8*sizeof(int32_t));\n        int8_t * GGML_RESTRICT a = aux8;\n        for (int j = 0; j < QK_K/64; ++j) {\n            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);\n            a += 32;\n            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);\n            a += 32; q4 += 32;\n        }\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        int sumi = 0;\n        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];\n        a = aux8;\n        int is = 0;\n        for (int j = 0; j < QK_K/32; ++j) {\n            int32_t scale = scales[is++];\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n        }\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];\n        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * 
y[i].d;\n        sumf -= dmin * sumi;\n    }\n    for (int l = 0; l < 8; ++l) sumf += sums[l];\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q5_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    uint32_t utmp[4];\n\n    const uint8_t * scales = (const uint8_t*)&utmp[0];\n    const uint8_t * mins   = (const uint8_t*)&utmp[2];\n\n    int8_t  aux8[QK_K];\n    int16_t aux16[8];\n    float   sums [8];\n    int32_t aux32[8];\n    memset(sums, 0, 8*sizeof(float));\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * GGML_RESTRICT q4 = x[i].qs;\n        const uint8_t * GGML_RESTRICT hm = x[i].qh;\n        const  int8_t * GGML_RESTRICT q8 = y[i].qs;\n        memset(aux32, 0, 8*sizeof(int32_t));\n        int8_t * GGML_RESTRICT a = aux8;\n        uint8_t m = 1;\n        for (int j = 0; j < QK_K/64; ++j) {\n            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);\n            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);\n            a += 32; m <<= 1;\n            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);\n            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 
16 : 0);\n            a += 32; m <<= 1;\n            q4 += 32;\n        }\n        memcpy(utmp, x[i].scales, 12);\n        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n        const uint32_t uaux = utmp[1] & kmask1;\n        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n        utmp[2] = uaux;\n        utmp[0] &= kmask1;\n\n        int sumi = 0;\n        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];\n        a = aux8;\n        int is = 0;\n        for (int j = 0; j < QK_K/32; ++j) {\n            int32_t scale = scales[is++];\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n        }\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];\n        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;\n        sumf -= dmin * sumi;\n    }\n    for (int l = 0; l < 8; ++l) sumf += sums[l];\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_q6_K * GGML_RESTRICT x = vx;\n    const block_q8_K * GGML_RESTRICT y = vy;\n\n    const int nb = n 
/ QK_K;\n\n    int8_t  aux8[QK_K];\n    int16_t aux16[8];\n    float   sums [8];\n    int32_t aux32[8];\n    memset(sums, 0, 8*sizeof(float));\n\n    float sumf = 0;\n    for (int i = 0; i < nb; ++i) {\n        const uint8_t * GGML_RESTRICT q4 = x[i].ql;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const  int8_t * GGML_RESTRICT q8 = y[i].qs;\n        memset(aux32, 0, 8*sizeof(int32_t));\n        int8_t * GGML_RESTRICT a = aux8;\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) {\n                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n            }\n            a  += 128;\n            q4 += 64;\n            qh += 32;\n        }\n        a = aux8;\n        int is = 0;\n        for (int j = 0; j < QK_K/16; ++j) {\n            int scale = x[i].scales[is++];\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];\n            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];\n            q8 += 8; a += 8;\n        }\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];\n    }\n    for (int l = 0; l < 8; ++l) sumf += sums[l];\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const 
block_iq2_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    uint32_t aux32[2];\n    const uint8_t * aux8 = (const uint8_t *)aux32;\n\n    float sumf = 0.f;\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        int32_t bsum = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {\n            memcpy(aux32, q2, 2*sizeof(uint32_t));\n            q2 += 4;\n            const uint32_t ls = 2*(aux32[1] >> 28) + 1;\n            int32_t sumi = 0;\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);\n                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];\n                for (int j = 0; j < 8; ++j) {\n                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);\n                }\n                q8 += 8;\n            }\n            bsum += sumi * ls;\n        }\n        sumf += d * bsum;\n    }\n    *s = 0.125f * sumf;\n}\n\nvoid ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0.f;\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint16_t * GGML_RESTRICT q2 = x[i].qs;\n        const uint8_t  * GGML_RESTRICT sc = x[i].scales;\n        const int8_t   * GGML_RESTRICT q8 = y[i].qs;\n        int32_t bsum = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {\n            const 
uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;\n            const uint16_t ls2 = 2*(sc[ib32] >>  4) + 1;\n            int32_t sumi = 0;\n            for (int l = 0; l < 2; ++l) {\n                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));\n                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];\n                for (int j = 0; j < 8; ++j) {\n                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);\n                }\n                q8 += 8;\n            }\n            bsum += sumi * ls1;\n            sumi = 0;\n            for (int l = 2; l < 4; ++l) {\n                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));\n                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];\n                for (int j = 0; j < 8; ++j) {\n                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);\n                }\n                q8 += 8;\n            }\n            bsum += sumi * ls2;\n            q2 += 4;\n        }\n        sumf += d * bsum;\n    }\n    *s = 0.125f * sumf;\n}\n\nvoid ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq2_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; i++) {\n\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const int8_t  * q8 = y[i].qs;\n        const uint8_t * qs = x[i].qs;\n        const uint8_t * qh = x[i].qh;\n        const uint8_t * signs = qs + QK_K/8;\n\n        int bsum = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {\n            int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);\n            int ls2 = 1 + 2*(x[i].scales[ib32] >>  4);\n           
 int sumi1 = 0, sumi2 = 0;\n            for (int l = 0; l < 2; ++l) {\n                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));\n                for (int j = 0; j < 8; ++j) {\n                    sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);\n                }\n                q8 += 8;\n            }\n            for (int l = 2; l < 4; ++l) {\n                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));\n                for (int j = 0; j < 8; ++j) {\n                    sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);\n                }\n                q8 += 8;\n            }\n            bsum += ls1 * sumi1 + ls2 * sumi2;\n            qs += 4;\n            signs += 4;\n        }\n\n        sumf += d * bsum;\n    }\n\n    *s = 0.125f * sumf;\n}\n\nvoid ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_xxs * GGML_RESTRICT x = vx;\n    const block_q8_K    * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    uint32_t aux32;\n\n    float sumf = 0.f;\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT q3 = x[i].qs;\n        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        int32_t bsum = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {\n            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);\n            const uint32_t ls = 2*(aux32 >> 28) + 1;\n            int32_t sumi = 0;\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t * grid1 = (const uint8_t 
*)(iq3xxs_grid + q3[2*l+0]);\n                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);\n                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];\n                for (int j = 0; j < 4; ++j) {\n                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);\n                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);\n                }\n                q8 += 8;\n            }\n            q3 += 8;\n            bsum += sumi * ls;\n        }\n        sumf += d * bsum;\n    }\n    *s = 0.25f * sumf;\n}\n\nvoid ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq3_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0.f;\n    for (int i = 0; i < nb; ++i) {\n        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;\n        const uint8_t * GGML_RESTRICT qs = x[i].qs;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const uint8_t * GGML_RESTRICT signs = x[i].signs;\n        const int8_t  * GGML_RESTRICT q8 = y[i].qs;\n        int32_t bsum = 0;\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;\n            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;\n            int32_t sumi = 0;\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));\n                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));\n                for (int j = 0; j < 4; ++j) {\n                    sumi += grid1[j] * q8[j+0] * 
(signs[l] & kmask_iq2xs[j+0] ? -1 : 1);\n                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);\n                }\n                q8 += 8;\n            }\n            qs += 8;\n            signs += 4;\n            bsum += sumi * ls1;\n            sumi = 0;\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));\n                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));\n                for (int j = 0; j < 4; ++j) {\n                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);\n                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);\n                }\n                q8 += 8;\n            }\n            qs += 8;\n            signs += 4;\n            bsum += sumi * ls2;\n        }\n        sumf += d * bsum;\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_s * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0;\n    for (int i = 0; i < nb; i++) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint16_t * qh = x[i].qh;\n\n        int sumi = 0, sumi1 = 0;\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            const int ls = 2*((qh[ib] >> 12) & 7) + 1;\n            const int delta = qh[ib] & 0x8000 ? 
-1 : 1;\n            int lsum = 0;\n            for (int l = 0; l < 4; ++l) {\n                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));\n                for (int j = 0; j < 8; ++j) {\n                    lsum += q8[j] * grid[j];\n                }\n                q8 += 8;\n            }\n            sumi  += ls * lsum;\n            sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);\n            qs += 4;\n        }\n\n        sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(n % QK_K == 0);\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n\n    const block_iq1_m * GGML_RESTRICT x = vx;\n    const block_q8_K  * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    iq1m_scale_t scale;\n\n    int sum1[2], sum2[2], delta[4];\n\n    float sumf = 0;\n    for (int i = 0; i < nb; i++) {\n\n        const int8_t   * q8 = y[i].qs;\n        const uint8_t  * qs = x[i].qs;\n        const uint8_t  * qh = x[i].qh;\n        const uint16_t * sc = (const uint16_t *)x[i].scales;\n\n        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n\n        int sumi1 = 0, sumi2 = 0;\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            delta[0] = qh[0] & 0x08 ? -1 : 1;\n            delta[1] = qh[0] & 0x80 ? -1 : 1;\n            delta[2] = qh[1] & 0x08 ? -1 : 1;\n            delta[3] = qh[1] & 0x80 ? 
-1 : 1;\n            sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;\n            for (int l = 0; l < 4; ++l) {\n                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));\n                int lsum1 = 0, lsum2 = 0;\n                for (int j = 0; j < 8; ++j) {\n                    lsum1 += q8[j] * grid[j];\n                    lsum2 += q8[j];\n                }\n                q8 += 8;\n                sum1[l/2] += lsum1;\n                sum2[l/2] += lsum2*delta[l];\n            }\n\n            const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;\n            const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;\n\n            sumi1 += sum1[0] * ls1 + sum1[1] * ls2;\n            sumi2 += sum2[0] * ls1 + sum2[1] * ls2;\n            qs += 4;\n            qh += 2;\n        }\n\n        sumf += GGML_CPU_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);\n    }\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK4_NL == 0);\n    static_assert(QK4_NL == QK8_0, \"QK4_NL and QK8_0 must be the same\");\n\n    const block_iq4_nl * GGML_RESTRICT x = vx;\n    const block_q8_0   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK4_NL;\n\n    int ib = 0;\n    float sumf = 0;\n\n    for (; ib < nb; ++ib) {\n        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);\n        int sumi1 = 0, sumi2 = 0;\n        for (int j = 0; j < QK4_NL/2; ++j) {\n            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];\n            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];\n        }\n        sumf += d * (sumi1 + sumi2);\n    }\n    *s = sumf;\n}\n\nvoid 
ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {\n    assert(nrc == 1);\n    UNUSED(nrc);\n    UNUSED(bx);\n    UNUSED(by);\n    UNUSED(bs);\n    assert(n % QK_K == 0);\n\n    const block_iq4_xs * GGML_RESTRICT x = vx;\n    const block_q8_K   * GGML_RESTRICT y = vy;\n\n    const int nb = n / QK_K;\n\n    float sumf = 0;\n    for (int ibl = 0; ibl < nb; ++ibl) {\n        const float d4d8 = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;\n        uint16_t h = x[ibl].scales_h;\n        const uint8_t * qs = x[ibl].qs;\n        const int8_t  * q8 = y[ibl].qs;\n        for (int ib = 0; ib < QK_K/32; ib += 2) {\n            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);\n            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);\n            h >>= 4;\n            const float d1 = d4d8*(ls1 - 32);\n            const float d2 = d4d8*(ls2 - 32);\n            int sumi1 = 0, sumi2 = 0;\n            for (int j = 0; j < 16; ++j) {\n                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];\n                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];\n            }\n            sumf += d1 * (sumi1 + sumi2);\n            qs += 16;\n            q8 += 32;\n            sumi1 = sumi2 = 0;\n            for (int j = 0; j < 16; ++j) {\n                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];\n                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];\n            }\n            sumf += d2 * (sumi1 + sumi2);\n            qs += 16;\n            q8 += 32;\n        }\n    }\n    *s = sumf;\n}\n\n// ============================ 4-bit non-linear quants\n\nvoid quantize_row_iq4_nl(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK4_NL == 0);\n    quantize_row_iq4_nl_ref(x, y, k);\n}\n\nvoid quantize_row_iq4_xs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) 
{\n    assert(k % QK_K == 0);\n    quantize_iq4_xs(x, y, 1, k, NULL);\n}\n"
  },
  {
    "path": "src/ggml-cpu/quants.h",
    "content": "#pragma once\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n\n#include \"ggml.h\"\n\n// GGML CPU internal header\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n// Quantization\nvoid quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\n\nvoid quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\n\nvoid quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\n\nvoid quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\n\nvoid quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\n\n// Dot product\nvoid ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const 
void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\nvoid ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\nvoid ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\nvoid ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, 
const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\nvoid ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq2_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq1_m_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\n// Generic implementation\nvoid quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid 
quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);\nvoid ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\nvoid ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\nvoid ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\nvoid ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, 
size_t by, int nrc);\nvoid ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc);\nvoid ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\nvoid ggml_vec_dot_iq4_xs_q8_K_generic(int n, 
float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/repack.cpp",
    "content": "#define GGML_COMMON_IMPL_CPP\n#define GGML_COMMON_DECL_CPP\n#include \"ggml-common.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"simd-mappings.h\"\n#include \"traits.h\"\n\n#include \"arch-fallback.h\"\n\n#include <cmath>\n#include <cstring>\n#include <cassert>\n#include <cstdio>  // for GGML_ASSERT\n\n#include \"repack.h\"\n\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Woverlength-strings\"\n#endif\n\n#define UNUSED GGML_UNUSED\n\nstatic inline int nearest_int(float fval) {\n    assert(fabsf(fval) <= 4194303.f);\n    float val = fval + 12582912.f;\n    int i; memcpy(&i, &val, sizeof(int));\n    return (i & 0x007fffff) - 0x00400000;\n}\n\n// Functions to create the interleaved data layout formats\n\n// interleave 4 block_q4_0s in blocks of blck_size_interleave\n// returns an interleaved block_q4_0x4\n// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks\n// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave\n//\n// - in                  : an array of block_q4_0 pointers\n// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of\n//                         blck_size_interleave bytes\n// - xor_mask            : the mask to convert the nibbles in block_q4_0 quants bytes\n//                         from bias offset form to pure sign form (this saves subtract\n//                         operations durin unpacking)\n//\n\nextern \"C\" {\n\n#if defined __riscv_zvfh\nvoid ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;\n\n    // scalar\n    const int blck_size_interleave = 1;\n    float srcv[4][QK8_0];\n    float id[4];\n\n    for (int i = 0; i < nb; i++) {\n        for (int row_iter 
= 0; row_iter < 4; row_iter++) {\n            float amax = 0.0f; // absolute max\n\n            for (int j = 0; j < QK8_0; j++) {\n                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];\n                amax = MAX(amax, fabsf(srcv[row_iter][j]));\n            }\n\n            const float d = amax / ((1 << 7) - 1);\n            id[row_iter] = d ? 1.0f / d : 0.0f;\n\n            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);\n        }\n\n        for (int j = 0; j < QK8_0 * 4; j++) {\n            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;\n            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;\n            src_offset += (j % blck_size_interleave);\n\n            float x0 = srcv[src_id][src_offset] * id[src_id];\n            y[i].qs[j] = roundf(x0);\n        }\n    }\n}\n\nvoid ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK_K == 256);\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;\n\n    const int blck_size_interleave = 1;\n    float srcv[4][QK_K];\n    float iscale[4];\n\n    for (int i = 0; i < nb; i++) {\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            float amax = 0.0f; // absolute max\n            float max = 0;\n\n            for (int j = 0; j < QK_K; j++) {\n                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];\n                // Update the maximum value of the corresponding super block\n                if(amax < fabsf(srcv[row_iter][j])) {\n                    amax = fabsf(srcv[row_iter][j]);\n                    max = srcv[row_iter][j];\n                }\n            }\n\n            iscale[row_iter] = amax ? -127.f/max : 0;\n            y[i].d[row_iter] = amax ? 
1/iscale[row_iter] : 0;\n        }\n\n        for (int j = 0; j < QK_K / 4; j++) {\n            y[i].bsums[j] = 0;\n        }\n        for (int j = 0; j < QK_K * 4; j++) {\n            int src_id = j % 4;\n            int src_offset = j / 4;\n            int index = ((j >> 6) << 2) + (j & 3);\n\n            float x0 = srcv[src_id][src_offset] * iscale[src_id];\n            y[i].qs[j] = nearest_int(x0);\n            y[i].bsums[index] += y[i].qs[j];\n        }\n    }\n}\n#endif\n\nvoid ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;\n\n    // scalar\n    const int blck_size_interleave = 4;\n    float srcv[4][QK8_0];\n    float id[4];\n\n    for (int i = 0; i < nb; i++) {\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            float amax = 0.0f; // absolute max\n\n            for (int j = 0; j < QK8_0; j++) {\n                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];\n                amax = MAX(amax, fabsf(srcv[row_iter][j]));\n            }\n\n            const float d = amax / ((1 << 7) - 1);\n            id[row_iter] = d ? 
1.0f / d : 0.0f;\n\n            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);\n        }\n\n        for (int j = 0; j < QK8_0 * 4; j++) {\n            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;\n            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;\n            src_offset += (j % blck_size_interleave);\n\n            float x0 = srcv[src_id][src_offset] * id[src_id];\n            y[i].qs[j] = roundf(x0);\n        }\n    }\n}\n\nvoid ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK8_0 == 32);\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;\n\n    // scalar\n    const int blck_size_interleave = 8;\n    float srcv[4][QK8_0];\n    float id[4];\n\n    for (int i = 0; i < nb; i++) {\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            float amax = 0.0f; // absolute max\n\n            for (int j = 0; j < QK8_0; j++) {\n                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];\n                amax = MAX(amax, fabsf(srcv[row_iter][j]));\n            }\n\n            const float d = amax / ((1 << 7) - 1);\n            id[row_iter] = d ? 
1.0f / d : 0.0f;\n\n            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);\n        }\n\n        for (int j = 0; j < QK8_0 * 4; j++) {\n            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;\n            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;\n            src_offset += (j % blck_size_interleave);\n\n            float x0 = srcv[src_id][src_offset] * id[src_id];\n            y[i].qs[j] = roundf(x0);\n        }\n    }\n}\n\nvoid ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK_K == 256);\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;\n\n    // scalar\n    const int blck_size_interleave = 4;\n    float srcv[4][QK_K];\n    float iscale[4];\n\n    for (int i = 0; i < nb; i++) {\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            float amax = 0.0f; // absolute max\n            float max = 0;\n\n            for (int j = 0; j < QK_K; j++) {\n                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];\n                // Update the maximum value of the corresponding super block\n                if(amax < fabsf(srcv[row_iter][j])) {\n                    amax = fabsf(srcv[row_iter][j]);\n                    max = srcv[row_iter][j];\n                }\n            }\n\n            iscale[row_iter] = amax ? -127.f/max : 0;\n\n            y[i].d[row_iter] = amax ? 
1/iscale[row_iter] : 0;\n        }\n\n        for (int j = 0; j < QK_K / 4; j++) {\n            y[i].bsums[j] = 0;\n        }\n\n        // Quants values are interleaved in sequence of four bytes from corresponding super blocks\n        // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving\n        // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on\n        for (int j = 0; j < QK_K * 4; j++) {\n            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;\n            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;\n            src_offset += (j % blck_size_interleave);\n            int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);\n\n            float x0 = srcv[src_id][src_offset] * iscale[src_id];\n            y[i].qs[j] = nearest_int(x0);\n            y[i].bsums[index] += y[i].qs[j];\n        }\n    }\n}\n\nvoid ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\n    assert(QK_K == 256);\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;\n\n    // scalar\n    const int blck_size_interleave = 8;\n    float srcv[4][QK_K];\n    float iscale[4];\n\n    for (int i = 0; i < nb; i++) {\n        for (int row_iter = 0; row_iter < 4; row_iter++) {\n            float amax = 0.0f; // absolute max\n            float max = 0;\n\n            for (int j = 0; j < QK_K; j++) {\n                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];\n                // Update the maximum value of the corresponding super block\n                if(amax < fabsf(srcv[row_iter][j])) {\n                    amax = fabsf(srcv[row_iter][j]);\n                    max = srcv[row_iter][j];\n                }\n            }\n\n            iscale[row_iter] = amax ? 
-127.f/max : 0;\n\n            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;\n        }\n\n        for (int j = 0; j < QK_K / 4; j++) {\n            y[i].bsums[j] = 0;\n        }\n\n        // Quants values are interleaved in sequence of eight bytes from corresponding super blocks\n        // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving\n        // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on\n        for (int j = 0; j < QK_K * 4; j++) {\n            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;\n            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;\n            src_offset += (j % blck_size_interleave);\n            int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);\n\n            float x0 = srcv[src_id][src_offset] * iscale[src_id];\n            y[i].qs[j] = nearest_int(x0);\n            y[i].bsums[index] += y[i].qs[j];\n        }\n    }\n}\n\n} // extern \"C\"\n\ntemplate <int64_t INTER_SIZE, ggml_type PARAM_TYPE>\nvoid ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);\n\ntemplate <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {\n    assert(nrow == 4);\n    UNUSED(nrow);\n    ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);\n}\n\ntemplate <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {\n    assert(nrow == 4);\n    UNUSED(nrow);\n    ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);\n}\n\ntemplate <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {\n    assert(nrow == 4);\n    UNUSED(nrow);\n    ggml_quantize_mat_q8_K_4x4(x, vy, 
n_per_row);\n}\n\ntemplate <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {\n    assert(nrow == 4);\n    UNUSED(nrow);\n    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);\n}\n\n#if defined __riscv_zvfh\ntemplate <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {\n    assert(nrow == 4);\n    UNUSED(nrow);\n    ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row);\n}\n\ntemplate <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {\n    assert(nrow == 4);\n    UNUSED(nrow);\n    ggml_quantize_mat_q8_K_4x1(x, vy, n_per_row);\n}\n#endif\n\ntemplate <int M, int N>\nstatic void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int                        n,\n                                                 float * GGML_RESTRICT      s,\n                                                 size_t                     bs,\n                                                 const void * GGML_RESTRICT vx,\n                                                 const void * GGML_RESTRICT vy,\n                                                 int                        nr,\n                                                 int                        nc) {\n    constexpr int blocklen          = M;\n    constexpr int ncols_interleaved = N;\n    const int     qk                = QK_K;\n    const int     nb                = n / qk;\n    const int     blocks_per_half   = 64 / blocklen;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[8];\n\n    const block_q8_K * a_ptr = (const block_q8_K *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) {\n         
   sumf[j] = 0.0f;\n        }\n\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;\n                const int base_h = base_l + 64;\n\n                const int scale_idx_l = base_l / 16;\n                const int scale_idx_h = base_h / 16;\n\n                const int qh_shift_l = ((base_l % 128) / 32) * 2;\n                const int qh_shift_h = ((base_h % 128) / 32) * 2;\n\n                const int qh_half_l = (base_l / 128) * 32;\n                const int qh_half_h = (base_h / 128) * 32;\n\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];\n                    const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];\n\n                    int sumi_l = 0;\n                    int sumi_h = 0;\n\n                    for (int i = 0; i < blocklen; i++) {\n                        const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;\n                        const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;\n                        const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;\n\n                        const int qh_idx_l    = qh_half_l + ((base_l + i) % 32);\n                        const int qh_chunk_l  = qh_idx_l / blocklen;\n                        const int qh_pos_l    = qh_idx_l % blocklen;\n                        const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;\n                        const int hi_2_l      = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;\n\n                        const int qh_idx_h    = qh_half_h + ((base_h + i) % 32);\n                        const int qh_chunk_h  = qh_idx_h / blocklen;\n                        const int qh_pos_h    = qh_idx_h % blocklen;\n                        const int 
qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;\n                        const int hi_2_h      = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;\n\n                        const int q_l = ((hi_2_l << 4) | l_4) - 32;\n                        const int q_h = ((hi_2_h << 4) | hi_4) - 32;\n\n                        const int8_t a_l = a_ptr[l].qs[base_l + i];\n                        const int8_t a_h = a_ptr[l].qs[base_h + i];\n\n                        sumi_l += q_l * a_l;\n                        sumi_h += q_h * a_h;\n                    }\n\n                    sumf[j] +=\n                        (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;\n                }\n            }\n        }\n\n        for (int j = 0; j < ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j];\n        }\n    }\n}\n\ntemplate <int M, int N>\nstatic void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int                        n,\n                                                 float * GGML_RESTRICT      s,\n                                                 size_t                     bs,\n                                                 const void * GGML_RESTRICT vx,\n                                                 const void * GGML_RESTRICT vy,\n                                                 int                        nr,\n                                                 int                        nc) {\n    constexpr int blocklen          = M;\n    constexpr int ncols_interleaved = N;\n    const int     qk                = QK_K;\n    const int     nb                = n / qk;\n    const int     blocks_per_half   = 64 / blocklen;\n    const int     q8_half_stride    = 512;\n    const int     q8_low_high_step  = 256;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n\n    float sumf[4][8];\n\n    for (int y = 0; y < nr / 4; y++) 
{\n        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);\n\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j] = 0.0f;\n                }\n            }\n\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;\n                    const int base_h = base_l + 64;\n\n                    const int scale_idx_l = base_l / 16;\n                    const int scale_idx_h = base_h / 16;\n\n                    const int qh_shift_l = ((base_l % 128) / 32) * 2;\n                    const int qh_shift_h = ((base_h % 128) / 32) * 2;\n\n                    const int qh_half_l = (base_l / 128) * 32;\n                    const int qh_half_h = (base_h / 128) * 32;\n\n                    const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);\n\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];\n                            const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];\n\n                            int sumi_l = 0;\n                            int sumi_h = 0;\n\n                            for (int i = 0; i < blocklen; i++) {\n                                const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;\n                                const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;\n                                const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;\n\n                                const int 
qh_idx_l   = qh_half_l + ((base_l + i) % 32);\n                                const int qh_chunk_l = qh_idx_l / blocklen;\n                                const int qh_pos_l   = qh_idx_l % blocklen;\n                                const int qh_offset_l =\n                                    qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;\n                                const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;\n\n                                const int qh_idx_h   = qh_half_h + ((base_h + i) % 32);\n                                const int qh_chunk_h = qh_idx_h / blocklen;\n                                const int qh_pos_h   = qh_idx_h % blocklen;\n                                const int qh_offset_h =\n                                    qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;\n                                const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;\n\n                                const int q_l = ((hi_2_l << 4) | l_4) - 32;\n                                const int q_h = ((hi_2_h << 4) | hi_4) - 32;\n\n                                const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];\n                                const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];\n\n                                sumi_l += q_l * q8_l;\n                                sumi_h += q_h * q8_h;\n                            }\n\n                            sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *\n                                          a_ptr[l].d[m];\n                        }\n                    }\n                }\n            }\n\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n                }\n            }\n        }\n    
}\n}\n\ntemplate <int M, int N>\nstatic void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int                        n,\n                                                 float * GGML_RESTRICT      s,\n                                                 size_t                     bs,\n                                                 const void * GGML_RESTRICT vx,\n                                                 const void * GGML_RESTRICT vy,\n                                                 int                        nr,\n                                                 int                        nc) {\n    constexpr int         blocklen          = M;\n    constexpr int         ncols_interleaved = N;\n    const int             qk                = QK_K;\n    const int             nb                = n / qk;\n    static const uint32_t kmask1            = 0x3f3f3f3f;\n    static const uint32_t kmask2            = 0x0f0f0f0f;\n    static const uint32_t kmask3            = 0x03030303;\n\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float    sumf[ncols_interleaved];\n    float    sum_minf[ncols_interleaved];\n    uint32_t utmp[32];\n    int      sumi1;\n    int      sumi2;\n    int      sumi;\n\n    const block_q8_K * a_ptr = (const block_q8_K *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) {\n            sumf[j]     = 0.0;\n            sum_minf[j] = 0.0;\n        }\n        for (int l = 0; l < nb; l++) {\n            for (int sb = 0; sb < 8; sb++) {\n                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);\n                utmp[sb * 4 + 3]      = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);\n                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;\n                utmp[sb * 4 + 1]      = (utmp[sb 
* 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);\n                utmp[sb * 4 + 2]      = uaux_0;\n                utmp[sb * 4 + 0] &= kmask1;\n            }\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                constexpr int scale_stride = 32;\n                uint8_t *     scales_0     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;\n                uint8_t *     scales_1     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;\n\n                const int qh_shift = (k / (32 / blocklen)) * 2;\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi1 = 0;\n                    sumi2 = 0;\n                    sumi  = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;\n\n                        const int qh_idx      = (k * blocklen + i) % 32;\n                        const int qh_chunk    = qh_idx / blocklen;\n                        const int qh_pos      = qh_idx % blocklen;\n                        const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;\n\n                        const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];\n                        const uint8_t h0     = (qh_val >> qh_shift) & 1;\n                        const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;\n\n                        const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));\n                        const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));\n\n                        const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i;\n\n                        sumi1 = (v0 * a_ptr[l].qs[q8_offset]);\n                        sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);\n                        sumi1 = sumi1 * scales_0[j];\n                        sumi2 = sumi2 * scales_1[j];\n    
                    sumi += sumi1 + sumi2;\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;\n                }\n            }\n            for (int sb = 0; sb < 8; sb++) {\n                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *\n                                   GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];\n        }\n    }\n}\n\ntemplate <int M, int N>\nstatic void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int                        n,\n                                                 float * GGML_RESTRICT      s,\n                                                 size_t                     bs,\n                                                 const void * GGML_RESTRICT vx,\n                                                 const void * GGML_RESTRICT vy,\n                                                 int                        nr,\n                                                 int                        nc) {\n    constexpr int         blocklen          = M;\n    constexpr int         ncols_interleaved = N;\n    const int             qk                = QK_K;\n    const int             nb                = n / qk;\n    static const uint32_t kmask1            = 0x3f3f3f3f;\n    static const uint32_t kmask2            = 0x0f0f0f0f;\n    static const uint32_t kmask3            = 0x03030303;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    float    sumf[4][ncols_interleaved];\n    float    sum_minf[4][ncols_interleaved];\n    uint32_t utmp[32];\n    int      sumi1;\n    int      sumi2;\n    int      sumi;\n\n 
   for (int y = 0; y < nr / 4; y++) {\n        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j]     = 0.0;\n                    sum_minf[m][j] = 0.0;\n                }\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int sb = 0; sb < 8; sb++) {\n                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);\n                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;\n                    utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);\n                    utmp[sb * 4 + 2]      = uaux_0;\n                    utmp[sb * 4 + 0] &= kmask1;\n                }\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    constexpr int scale_stride = 32;\n                    uint8_t *     scales_0     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;\n                    uint8_t *     scales_1     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;\n\n                    const int qh_shift = (k / (32 / blocklen)) * 2;\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi1 = 0;\n                            sumi2 = 0;\n                            sumi  = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;\n\n                                const int qh_idx   = (k * blocklen 
+ i) % 32;\n                                const int qh_chunk = qh_idx / blocklen;\n                                const int qh_pos   = qh_idx % blocklen;\n                                const int b_qh_offset =\n                                    qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;\n\n                                const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];\n                                const uint8_t h0     = (qh_val >> qh_shift) & 1;\n                                const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;\n\n                                const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));\n                                const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));\n\n                                const int q8_offset = (k / (32 / blocklen)) * 256 +\n                                                      (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i;\n\n                                sumi1 = (v0 * a_ptr[l].qs[q8_offset]);\n                                sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);\n                                sumi1 = sumi1 * scales_0[j];\n                                sumi2 = sumi2 * scales_1[j];\n                                sumi += sumi1 + sumi2;\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n                for (int sb = 0; sb < 8; sb++) {\n                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;\n                    for (int m = 0; m < 4; m++) {\n                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *\n                                              
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];\n                }\n            }\n        }\n    }\n}\n\nextern \"C\" {\n\nvoid ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[4];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);\n                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);\n                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * 
GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[4];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);\n                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);\n                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, 
size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[8];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);\n                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);\n                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 4;\n    static const uint32_t kmask1 = 
0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[8];\n    float sum_minf[8];\n    uint32_t utmp[32];\n    int sumi1;\n    int sumi2;\n    int sumi;\n\n    const block_q8_K * a_ptr = (const block_q8_K *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) {\n            sumf[j] = 0.0;\n            sum_minf[j] = 0.0;\n        }\n        for (int l = 0; l < nb; l++) {\n            for (int sb = 0; sb < 8; sb++) {\n                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);\n                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);\n                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;\n                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);\n                utmp[sb * 4 + 2] = uaux_0;\n                utmp[sb * 4 + 0] &= kmask1;\n            }\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;\n                uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi1 = 0;\n                    sumi2 = 0;\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);\n                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);\n                        sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);\n                        sumi2 
= (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);\n                        sumi1 = sumi1 * scales_0[j];\n                        sumi2 = sumi2 * scales_1[j];\n                        sumi += sumi1 + sumi2;\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;\n                }\n            }\n            for (int sb = 0; sb < 8; sb++) {\n                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];\n        }\n    }\n}\n\nvoid ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[8];\n    float sum_minf[8];\n    uint32_t utmp[32];\n    int sumi1;\n    int sumi2;\n    int sumi;\n\n    const block_q8_K * a_ptr = (const block_q8_K *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) {\n            sumf[j] = 0.0;\n            sum_minf[j] = 0.0;\n        }\n        for (int l = 0; l < nb; l++) {\n            for (int sb = 0; sb < 8; sb++) {\n                memcpy(utmp + sb * 4, 
b_ptr[l].scales + sb * 12, 12);\n                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);\n                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;\n                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);\n                utmp[sb * 4 + 2] = uaux_0;\n                utmp[sb * 4 + 0] &= kmask1;\n            }\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;\n                uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi1 = 0;\n                    sumi2 = 0;\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);\n                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);\n                        sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);\n                        sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);\n                        sumi1 = sumi1 * scales_0[j];\n                        sumi2 = sumi2 * scales_1[j];\n                        sumi += sumi1 + sumi2;\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;\n                }\n            }\n            for (int sb = 0; sb < 8; sb++) {\n                uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;\n                }\n            }\n        }\n        for (int j = 0; j < 
ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];\n        }\n    }\n}\n\nvoid ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[8];\n    float sum_minf[8];\n    int sumi1,sumi2,sumi3,sumi4;\n    int sumi;\n\n    const block_q8_K * a_ptr = (const block_q8_K *)vy;\n    for(int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);\n        for (int j = 0; j < ncols_interleaved; j++) {\n            sumf[j] = 0.0;\n            sum_minf[j] = 0.0;\n        }\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (4 * blocklen)); k++) {\n                const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;\n                const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;\n                const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;\n                const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi1 = 0;\n                    sumi2 = 0;\n                    sumi3 = 0;\n                    sumi4 = 0;\n                    sumi = 0;\n                    int offset = ((k / 2) % 2) + j * 2;\n                    for (int i = 0; i < blocklen; ++i){\n                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);\n                        const int v1 = (int8_t) ((b_ptr[l].qs[k * 
ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);\n                        const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);\n                        const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);\n                        sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);\n                        sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);\n                        sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);\n                        sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);\n\n                        sumi1 = sumi1 * (scales_0[offset] & 0xF);\n                        sumi2 = sumi2 * (scales_1[offset] & 0xF);\n                        sumi3 = sumi3 * (scales_2[offset] & 0xF);\n                        sumi4 = sumi4 * (scales_3[offset] & 0xF);\n                        sumi += sumi1 + sumi2 + sumi3 + sumi4;\n                    }\n                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;\n                }\n            }\n            for(int sb = 0; sb < 8; sb++) {\n                const uint8_t *mins = b_ptr[l].scales + sb * 16;\n                for(int j = 0; j < ncols_interleaved; j++){\n                    sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];\n        }\n    }\n}\n\nvoid ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid 
ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);\n}\n\n\nvoid ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[4];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                        sumi += ((v0 * a_ptr[l].qs[k * 
blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[8];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid 
ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[4];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                        const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));\n                    }\n                    sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 
0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[8];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                        const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));\n                    }\n                    sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid ggml_gemv_q8_0_4x4_q8_0_generic(int                        n,\n                                     float * GGML_RESTRICT      s,\n                                     size_t                     bs,\n                                     const void * GGML_RESTRICT vx,\n                                     const void * GGML_RESTRICT vy,\n                                     int                        nr,\n                                     int                        nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen          = 4;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 
0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[4];\n    int   sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) {\n            sumf[j] = 0.0;\n        }\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / blocklen); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];\n                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j];\n        }\n    }\n}\n\nvoid ggml_gemv_q8_0_4x8_q8_0_generic(int                        n,\n                                     float * GGML_RESTRICT      s,\n                                     size_t                     bs,\n                                     const void * GGML_RESTRICT vx,\n                                     const void * GGML_RESTRICT vy,\n                                     int                        nr,\n                                     int                        nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen          = 8;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[4];\n    int   sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc 
/ ncols_interleaved; x++) {\n        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) {\n            sumf[j] = 0.0;\n        }\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / blocklen); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];\n                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j];\n        }\n    }\n}\n\n#if defined __riscv_zvfh\nvoid ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[16];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n       
             for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);\n                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);\n                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n    assert (n % qk == 0);\n    assert (nc % ncols_interleaved == 0);\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n    float sumf[16];\n    float sum_minf[16];\n    uint8_t scales[128];\n    uint8_t mins[128];\n    int sumi1;\n    int sumi2;\n    int sumi;\n    const block_q8_K * a_ptr = (const block_q8_K *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);\n        for (int j = 0; j < ncols_interleaved; j++) {\n            sumf[j] = 0.0f;\n            sum_minf[j] = 0.0f;\n        }\n        for (int l = 0; l < nb; l++) {\n            for (int i = 0; i < 128; i++) {\n                scales[i] = b_ptr[l].scales[i] & 0x0F;\n                mins[i] = b_ptr[l].scales[i] >> 4;\n            }\n            for (int i = 0; i < 64; i++) {\n                scales[i] |= (b_ptr[l].scales[128 
+ i] & 0x03) << 4;\n                mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2;\n                scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30);\n                mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2;\n            }\n            for (int sb = 0; sb < 8; sb++) {\n                uint8_t *min = &mins[sb * 16];\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;\n                }\n            }\n            for (int sb = 0; sb < 8; sb += 2) {\n                uint8_t *scales_0 = &scales[sb * 16];\n                uint8_t *scales_1 = &scales[(sb + 1) * 16];\n                for (int i = 0; i < QK4_0; i++) {\n                    for (int j = 0; j < ncols_interleaved; j++) {\n                        sumi1 = 0;\n                        sumi2 = 0;\n                        sumi = 0;\n                        const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF);\n                        const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4);\n                        sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]);\n                        sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]);\n                        sumi1 = sumi1 * scales_0[j];\n                        sumi2 = sumi2 * scales_1[j];\n                        sumi += sumi1 + sumi2;\n                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;\n                    }\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];\n        }\n    }\n}\n\nvoid ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int 
ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[16];\n    int sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / ncols_interleaved; x++) {\n        const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];\n    }\n}\n\nvoid ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen          = 1;\n\n    assert(nr == 1);\n    assert(n % qk == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n    UNUSED(nr);\n\n    float sumf[16];\n    int   sumi;\n\n    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;\n    for (int x = 0; x < nc / 
ncols_interleaved; x++) {\n        const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);\n\n        for (int j = 0; j < ncols_interleaved; j++) {\n            sumf[j] = 0.0;\n        }\n        for (int l = 0; l < nb; l++) {\n            for (int k = 0; k < (qk / blocklen); k++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumi = 0;\n                    for (int i = 0; i < blocklen; ++i) {\n                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];\n                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];\n                    }\n                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);\n                }\n            }\n        }\n        for (int j = 0; j < ncols_interleaved; j++) {\n            s[x * ncols_interleaved + j] = sumf[j];\n        }\n    }\n}\n\nvoid ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    assert(n % QK_K == 0);\n    assert(nr == 1);\n    assert(nc % 16 == 0);\n\n    UNUSED(bs);\n\n    const int nb = n / QK_K;\n    const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx;\n    const block_q8_K    * y = (const block_q8_K *)vy;\n\n    // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...)\n    const int sb_perm[16] = {\n        0, 4, 1, 5, 2, 6, 3, 7,  // 0-7\n        8, 12, 9, 13, 10, 14, 11, 15 // 8-15\n    };\n\n    for (int col_tile = 0; col_tile < nc; col_tile += 16) {\n        const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb;\n        const block_q8_K    * y_ptr = y;\n\n        float sumf[16] = {0};\n\n        // Loop over K-blocks\n        for (int k_block = 0; k_block < nb; ++k_block) {\n            int32_t isum[16]  = {0};\n            int32_t summs[16] = {0};\n\n            const uint8_t * qs_rhs = x_ptr[k_block].qs;\n            const 
uint8_t * sc_rhs = x_ptr[k_block].scales;\n            const int8_t  * qs_lhs = y_ptr[k_block].qs;\n            const int16_t * bs_lhs = y_ptr[k_block].bsums;\n\n            // Iterate over sub-blocks 0..15\n            for (int sb = 0; sb < 16; ++sb) {\n                // Correction Term\n                int16_t bsum = bs_lhs[sb];\n                int scale_offset = sb_perm[sb] * 16;\n\n                for (int col = 0; col < 16; ++col) {\n                    uint8_t sc_val = sc_rhs[scale_offset + col];\n                    summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits\n                }\n\n                // Main Dot Product\n                // Calculate base offsets for Q2 unpacking based on SB\n                int byte_base;\n                if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16;\n                else        byte_base = (sb % 2 == 0) ? 32 : 48;\n\n                int shift = ((sb / 2) % 4) * 2;\n\n                for (int col = 0; col < 16; ++col) {\n                    uint8_t sc_val = sc_rhs[scale_offset + col];\n                    int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits\n\n                    // Process 16 elements (l=0..15)\n                    for (int l = 0; l < 16; ++l) {\n                        // Q2: Interleaved by column. 
Byte `l` contains 4 k-values.\n                        int qs_idx = (byte_base + l) * 16 + col;\n                        uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3;\n\n                        // Q8: Linear access\n                        int k = sb * 16 + l;\n                        int8_t q8_val = qs_lhs[k];\n\n                        isum[col] += q8_val * q2_val * d_sb;\n                    }\n                }\n            }\n\n            // Finalize K-Block\n            for (int col = 0; col < 16; ++col) {\n                float d_lhs = y_ptr[k_block].d;\n                float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]);\n                float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]);\n\n                float d_all = d_lhs * d_rhs;\n                float d_min = d_lhs * dm_rhs;\n\n                sumf[col] += (isum[col] * d_all) - (summs[col] * d_min);\n            }\n        }\n\n        for (int col = 0; col < 16; ++col) {\n            s[col_tile + col] = sumf[col];\n        }\n    }\n}\n#endif\n\nvoid ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    {\n        float sumf[4][4];\n        int sumi;\n\n        for (int y = 0; y < nr / 4; y++) {\n            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n            for (int x = 0; x < nc / ncols_interleaved; x++) {\n                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);\n                for (int m = 0; m < 4; m++) {\n                    for 
(int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n                }\n                for (int l = 0; l < nb; l++) {\n                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                        for (int m = 0; m < 4; m++) {\n                            for (int j = 0; j < ncols_interleaved; j++) {\n                                sumi = 0;\n                                for (int i = 0; i < blocklen; ++i) {\n                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);\n                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);\n                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;\n                                }\n                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                            }\n                        }\n                    }\n                }\n                for (int m = 0; m < 4; m++) {\n                    for (int j = 0; j < ncols_interleaved; j++)\n                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    
UNUSED(blocklen);\n\n    float sumf[4][4];\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);\n                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);\n                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                        (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++)\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    
const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[4][8];\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);\n                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);\n                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++)\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n            }\n    
    }\n    }\n}\n\nvoid ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 4;\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[4][8];\n    float sum_minf[4][8];\n    uint32_t utmp[32];\n    int sumi1;\n    int sumi2;\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j] = 0.0;\n                    sum_minf[m][j] = 0.0;\n                }\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int sb = 0; sb < 8; sb++) {\n                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);\n                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;\n                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);\n                    utmp[sb * 4 + 2] = uaux_0;\n                    utmp[sb * 4 + 0] &= kmask1;\n                }\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;\n                    uint8_t * scales_1 = 
(uint8_t *) utmp + (k / 8) * 32 + 16;\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi1 = 0;\n                            sumi2 = 0;\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);\n                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);\n                                sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);\n                                sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);\n                                sumi1 = sumi1 * scales_0[j];\n                                sumi2 = sumi2 * scales_1[j];\n                                sumi += sumi1 + sumi2;\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n                for (int sb = 0; sb < 8; sb++) {\n                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;\n                    for(int m = 0; m < 4; m++) {\n                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);\n                        for(int j = 0; j < ncols_interleaved; j++) {\n                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - 
sum_minf[m][j];\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n    static const uint32_t kmask1 = 0x3f3f3f3f;\n    static const uint32_t kmask2 = 0x0f0f0f0f;\n    static const uint32_t kmask3 = 0x03030303;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(bs);\n\n    float sumf[4][8];\n    float sum_minf[4][8];\n    uint32_t utmp[32];\n    int sumi1;\n    int sumi2;\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j] = 0.0;\n                    sum_minf[m][j] = 0.0;\n                }\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int sb = 0; sb < 8; sb++) {\n                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);\n                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);\n                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;\n                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);\n                    utmp[sb * 4 + 2] = uaux_0;\n                    utmp[sb * 4 + 0] &= kmask1;\n                }\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;\n                    uint8_t *scales_1 = 
(uint8_t*) utmp + (k / 4) * 32 + 16;\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi1 = 0;\n                            sumi2 = 0;\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);\n                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);\n                                sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);\n                                sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);\n                                sumi1 = sumi1 * scales_0[j];\n                                sumi2 = sumi2 * scales_1[j];\n                                sumi += sumi1 + sumi2;\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n                for (int sb = 0; sb < 8; sb++) {\n                    uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;\n                    for(int m = 0; m < 4; m++) {\n                        const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);\n                        for(int j = 0; j < ncols_interleaved; j++) {\n                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - 
sum_minf[m][j];\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[4][8];\n    float sum_minf[4][8];\n    int sumi1, sumi2, sumi3, sumi4;\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j] = 0.0;\n                    sum_minf[m][j] = 0.0;\n                }\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / (4 * blocklen)); k++) {\n\n                    const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;\n                    const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;\n                    const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;\n                    const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi1 = 0;\n                            sumi2 = 0;\n                            sumi3 = 0;\n                            sumi4 = 0;\n                            sumi = 0;\n                    
        int offset = ((k / 2) % 2) + j * 2;\n                            for (int i = 0; i < blocklen; ++i){\n                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);\n                                const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);\n                                const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);\n                                const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);\n                                sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);\n                                sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512  + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);\n                                sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512  + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);\n                                sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512  + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);\n                                sumi1 = sumi1 * (scales_0[offset] & 0xF);\n                                sumi2 = sumi2 * (scales_1[offset] & 0xF);\n                                sumi3 = sumi3 * (scales_2[offset] & 0xF);\n                                sumi4 = sumi4 * (scales_3[offset] & 0xF);\n                                sumi += sumi1 + sumi2 + sumi3 + sumi4;\n                            }\n                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n                for(int sb = 0; sb < 8; sb++) {\n                    const uint8_t *mins = b_ptr[l].scales + sb * 16;\n                    for(int m = 0; m < 4; m++) {\n                        const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) *  6);\n 
                       for(int j = 0; j < ncols_interleaved; j++) {\n                            int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);\n                            sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n            }\n\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n   ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);\n}\n\nvoid ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    
UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    {\n        float sumf[4][4];\n        int sumi;\n\n        for (int y = 0; y < nr / 4; y++) {\n            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n            for (int x = 0; x < nc / ncols_interleaved; x++) {\n                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);\n                for (int m = 0; m < 4; m++) {\n                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n                }\n                for (int l = 0; l < nb; l++) {\n                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                        for (int m = 0; m < 4; m++) {\n                            for (int j = 0; j < ncols_interleaved; j++) {\n                                sumi = 0;\n                                for (int i = 0; i < blocklen; ++i) {\n                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));\n                                }\n                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                            }\n                        }\n                    }\n                }\n                for (int m = 0; m < 4; m++) {\n                    for (int j = 0; j < ncols_interleaved; j++)\n                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n                }\n     
       }\n        }\n    }\n}\n\nvoid ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    float sumf[4][8];\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                                const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < 
ncols_interleaved; j++)\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen = 4;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    float sumf[4][4];\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n   
             }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++)\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 8;\n    const int blocklen = 8;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    float sumf[4][8];\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));\n                            }\n                            sumf[m][j] += sumi * 
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++)\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_q8_0_4x4_q8_0_generic(int                        n,\n                                     float * GGML_RESTRICT      s,\n                                     size_t                     bs,\n                                     const void * GGML_RESTRICT vx,\n                                     const void * GGML_RESTRICT vy,\n                                     int                        nr,\n                                     int                        nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen          = 4;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    float sumf[4][4];\n    int   sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j] = 0.0;\n                }\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / blocklen); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = 
b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];\n                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];\n                            }\n                            sumf[m][j] +=\n                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n                }\n            }\n        }\n    }\n}\n\n\n\nvoid ggml_gemm_q8_0_4x8_q8_0_generic(int                        n,\n                                     float * GGML_RESTRICT      s,\n                                     size_t                     bs,\n                                     const void * GGML_RESTRICT vx,\n                                     const void * GGML_RESTRICT vy,\n                                     int                        nr,\n                                     int                        nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 4;\n    const int blocklen          = 8;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    float sumf[4][4];\n    int   sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j] = 0.0;\n                }\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / 
blocklen); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];\n                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];\n                            }\n                            sumf[m][j] +=\n                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n                }\n            }\n        }\n    }\n}\n\n#if defined __riscv_zvfh\nvoid ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[4][16];\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n            
}\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);\n                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);\n                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++)\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK_K;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert (n % qk == 0);\n    assert (nr % 4 == 0);\n    assert (nc % ncols_interleaved == 0);\n\n    UNUSED(s);\n    UNUSED(bs);\n    UNUSED(vx);\n    UNUSED(vy);\n    UNUSED(nr);\n    UNUSED(nc);\n    UNUSED(nb);\n    UNUSED(ncols_interleaved);\n    UNUSED(blocklen);\n\n    float sumf[4][16];\n    float sum_minf[4][16];\n    uint8_t scales[128];\n    uint8_t mins[128];\n    int sumi1;\n    int sumi2;\n    int sumi;\n\n 
   for (int y = 0; y < nr / 4; y++) {\n        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j] = 0.0;\n                    sum_minf[m][j] = 0.0;\n                }\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int i = 0; i < 128; i++) {\n                    scales[i] = b_ptr[l].scales[i] & 0x0F;\n                    mins[i] = b_ptr[l].scales[i] >> 4;\n                }\n                for (int i = 0; i < 64; i++) {\n                    scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4;\n                    mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2;\n                    scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30);\n                    mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2;\n                }\n\n                for (int sb = 0; sb < 8; sb++) {\n                    uint8_t *min = &mins[sb * 16];\n                    for(int m = 0; m < 4; m++) {\n                        const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4];\n                        for(int j = 0; j < ncols_interleaved; j++) {\n                            sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];\n                        }\n                    }\n                }\n\n                for (int sb = 0; sb < 8; sb += 2) {\n                    uint8_t *scales_0 = &scales[sb * 16];\n                    uint8_t *scales_1 = &scales[(sb + 1) * 16];\n\n                    for (int i = 0; i < QK4_0; i++) {\n                        for (int m = 0; m < 4; m++) {\n                            for (int j = 0; j < ncols_interleaved; j++) {\n                                sumi1 = 
0;\n                                sumi2 = 0;\n                                sumi = 0;\n\n                                const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF);\n                                const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4);\n                                sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]);\n                                sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]);\n                                sumi1 = sumi1 * scales_0[j];\n                                sumi2 = sumi2 * scales_1[j];\n                                sumi += sumi1 + sumi2;\n\n                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];\n                            }\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];\n                }\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk = QK8_0;\n    const int nb = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen = 1;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    float sumf[4][16];\n    int sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;\n            }\n            for (int l = 0; l < 
nb; l++) {\n                for (int k = 0; k < (qk / (2 * blocklen)); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];\n                                const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];\n                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +\n                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4]));\n                            }\n                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++)\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n            }\n        }\n    }\n}\n\nvoid ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    const int qk                = QK8_0;\n    const int nb                = n / qk;\n    const int ncols_interleaved = 16;\n    const int blocklen          = 1;\n\n    assert(n % qk == 0);\n    assert(nr % 4 == 0);\n    assert(nc % ncols_interleaved == 0);\n\n    float sumf[4][16];\n    int   sumi;\n\n    for (int y = 0; y < nr / 4; y++) {\n        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);\n        for (int x = 0; x < nc / ncols_interleaved; x++) {\n            const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);\n            for 
(int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    sumf[m][j] = 0.0;\n                }\n            }\n            for (int l = 0; l < nb; l++) {\n                for (int k = 0; k < (qk / blocklen); k++) {\n                    for (int m = 0; m < 4; m++) {\n                        for (int j = 0; j < ncols_interleaved; j++) {\n                            sumi = 0;\n                            for (int i = 0; i < blocklen; ++i) {\n                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];\n                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];\n                            }\n                            sumf[m][j] +=\n                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);\n                        }\n                    }\n                }\n            }\n            for (int m = 0; m < 4; m++) {\n                for (int j = 0; j < ncols_interleaved; j++) {\n                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];\n                }\n            }\n        }\n    }\n}\n\n\nvoid ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {\n    assert(n % QK_K == 0);\n    assert(nr % 4 == 0);\n    assert(nc % 16 == 0);\n    const int nb = n / QK_K;\n    const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx;\n    const block_q8_Kx4  * y = (const block_q8_Kx4 *)vy;\n\n    const int sb_perm[16] = {\n        0, 4, 1, 5, 2, 6, 3, 7,\n        8, 12, 9, 13, 10, 14, 11, 15\n    };\n\n    // Iterate Rows in tiles of 4\n    for (int row_tile = 0; row_tile < nr; row_tile += 4) {\n        // Iterate Columns in tiles of 16\n        for (int col_tile = 0; col_tile < nc; col_tile += 16) {\n\n            const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * 
nb;\n            const block_q8_Kx4  * y_ptr = y + (row_tile / 4) * nb;\n\n            float sumf[4][16];\n            memset(sumf, 0, sizeof(sumf));\n\n            for (int k_block = 0; k_block < nb; ++k_block) {\n                int32_t isum[4][16];\n                int32_t summs[4][16];\n                memset(isum, 0, sizeof(isum));\n                memset(summs, 0, sizeof(summs));\n\n                const uint8_t * qs_rhs = x_ptr[k_block].qs;\n                const uint8_t * sc_rhs = x_ptr[k_block].scales;\n                const int8_t  * qs_lhs = y_ptr[k_block].qs;\n                const int16_t * bs_lhs = y_ptr[k_block].bsums;\n\n                for (int sb = 0; sb < 16; ++sb) {\n                    int scale_offset = sb_perm[sb] * 16;\n\n                    int byte_base;\n                    if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16;\n                    else        byte_base = (sb % 2 == 0) ? 32 : 48;\n                    int shift = ((sb / 2) % 4) * 2;\n\n                    for (int col = 0; col < 16; ++col) {\n                        uint8_t sc_val = sc_rhs[scale_offset + col];\n                        int32_t d_sb = sc_val & 0xF;\n                        int32_t m_sb = sc_val >> 4;\n\n                        // Correction Term\n                        for (int r = 0; r < 4; ++r) {\n                            int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4);\n                            summs[r][col] += bs_lhs[bsum_idx] * m_sb;\n                        }\n\n                        // Main Dot Product\n                        for (int l = 0; l < 16; ++l) {\n                            int qs_idx = (byte_base + l) * 16 + col;\n                            uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3;\n\n                            // Calculate Q8 index for this specific k and row\n                            int k = sb * 16 + l;\n                            int q8_idx = (k / 4) * 16 + (k % 4);\n\n                            for (int r = 0; r < 4; 
++r) {\n                                // Add r*4 to jump to the correct row within the 4x4 chunk\n                                int8_t q8_val = qs_lhs[q8_idx + r * 4];\n                                isum[r][col] += q8_val * q2_val * d_sb;\n                            }\n                        }\n                    }\n                }\n\n                // Finalize K-Block\n                for (int col = 0; col < 16; ++col) {\n                    float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]);\n                    float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]);\n\n                    for (int r = 0; r < 4; ++r) {\n                        float d_lhs = y_ptr[k_block].d[r];\n                        float d_all = d_lhs * d_rhs;\n                        float d_min = d_lhs * dm_rhs;\n                        sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min);\n                    }\n                }\n            }\n\n            for (int r = 0; r < 4; ++r) {\n                for (int col = 0; col < 16; ++col) {\n                    s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col];\n                }\n            }\n        }\n    }\n}\n#endif\n\n} // extern \"C\"\n\nstatic block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {\n    block_q8_0x4 out;\n\n    for (int i = 0; i < 4; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end = QK8_0 * 4 / blck_size_interleave;\n    for (int i = 0; i < end; ++i) {\n        int src_id     = i % 4;\n        int src_offset = (i / 4) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);\n    }\n    return out;\n}\n\nstatic block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {\n    block_q4_0x4 out;\n\n    for (int i = 0; i < 4; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end = QK4_0 * 2 / 
blck_size_interleave;\n\n    if (blck_size_interleave == 8) {\n        const uint64_t xor_mask = 0x8888888888888888ULL;\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 4;\n            int src_offset = (i / 4) * blck_size_interleave;\n            int dst_offset = i * blck_size_interleave;\n\n            uint64_t elems;\n            // Using memcpy to avoid unaligned memory accesses\n            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));\n            elems ^= xor_mask;\n            memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));\n        }\n    } else if (blck_size_interleave == 4) {\n        const uint32_t xor_mask = 0x88888888;\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 4;\n            int src_offset = (i / 4) * blck_size_interleave;\n            int dst_offset = i * blck_size_interleave;\n\n            uint32_t elems;\n            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));\n            elems ^= xor_mask;\n            memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\n// interleave 8 block_q4_0s in blocks of blck_size_interleave\n// returns an interleaved block_q4_0x8\n// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks\n// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave\nstatic block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {\n    block_q4_0x8 out;\n\n    for (int i = 0; i < 8; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end = QK4_0 * 4 / blck_size_interleave;\n    const uint64_t xor_mask = 0x8888888888888888ULL;\n\n    for (int i = 0; i < end; ++i) {\n        int src_id = i % 8;\n        int src_offset = (i / 8) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n\n        uint64_t elems;\n        memcpy(&elems, &in[src_id].qs[src_offset], 
sizeof(uint64_t));\n        elems ^= xor_mask;\n        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));\n    }\n\n    return out;\n}\n\nstatic block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {\n    block_q4_0x16 out;\n\n    for (int i = 0; i < 16; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end = QK4_0 * 8 / blck_size_interleave;\n\n    if (blck_size_interleave == 1) {\n        const uint8_t xor_mask = 0x88;\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 16;\n            int src_offset = i / 16;\n            int dst_offset = i;\n\n            out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask;\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\nstatic block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {\n    block_q4_Kx8 out;\n    //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure\n    for (int i = 0; i < 8; i++) {\n        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;\n    }\n\n    for (int i = 0; i < 8; i++) {\n        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;\n    }\n\n    const int end = QK_K * 4 / blck_size_interleave;\n\n    // Interleave Q4_K quants by taking 8 bytes at a time\n    for (int i = 0; i < end; ++i) {\n        int src_id = i % 8;\n        int src_offset = (i / 8) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n\n        // buffer large enough for the max interleave block size (8 bytes)\n        uint64_t elems;\n        memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);\n        memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);\n    }\n\n    // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K\n    // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)\n    // The 
output Q4_Kx8 structure has 96 bytes\n    // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure\n    // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures\n    uint8_t s[8], m[8];\n\n    for (int i = 0; i < 4; i++) {\n        for (int j = 0; j < 8; j++) {\n            s[j] = in[j].scales[i] & 63;\n            m[j] = in[j].scales[i + 4] & 63;\n        }\n\n        out.scales[i * 12]      = (s[0] & 63) + ((s[4] & 48) << 2);\n        out.scales[i * 12 + 1]  = (s[1] & 63) + ((s[5] & 48) << 2);\n        out.scales[i * 12 + 2]  = (s[2] & 63) + ((s[6] & 48) << 2);\n        out.scales[i * 12 + 3]  = (s[3] & 63) + ((s[7] & 48) << 2);\n        out.scales[i * 12 + 4]  = (m[0] & 63) + ((m[4] & 48) << 2);\n        out.scales[i * 12 + 5]  = (m[1] & 63) + ((m[5] & 48) << 2);\n        out.scales[i * 12 + 6]  = (m[2] & 63) + ((m[6] & 48) << 2);\n        out.scales[i * 12 + 7]  = (m[3] & 63) + ((m[7] & 48) << 2);\n        out.scales[i * 12 + 8]  = (s[4] & 15) + ((m[4] & 15) << 4);\n        out.scales[i * 12 + 9]  = (s[5] & 15) + ((m[5] & 15) << 4);\n        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);\n        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);\n\n    }\n\n    for (int i = 0; i < 4; i++) {\n        for (int j = 0; j < 8; j++) {\n            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);\n            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);\n        }\n\n        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);\n        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);\n        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);\n        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);\n        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);\n        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);\n        
out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);\n        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);\n        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);\n        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);\n        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);\n        out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);\n\n    }\n\n    return out;\n}\n\nstatic block_q4_Kx16 make_block_q4_Kx16(block_q4_K * in, unsigned int blck_size_interleave) {\n    block_q4_Kx16 out;\n    //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure\n    for (int i = 0; i < 16; i++) {\n        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;\n    }\n\n    for (int i = 0; i < 16; i++) {\n        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;\n    }\n\n    const int end = QK_K * 8 / blck_size_interleave;\n\n    if (blck_size_interleave == 1) {\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 16;\n            int src_offset = i / 16;\n            int dst_offset = i;\n\n            out.qs[dst_offset] = in[src_id].qs[src_offset];\n        }\n\n        // RVV repacking.\n        //\n        // Extract sums and mins for all 8 sub-blocks for each block of Q4_K.\n        uint8_t s[128], m[128];\n        for (int i = 0; i < 4; i++) {\n            for (int j = 0; j < 16; j++) {\n                s[i * 16 + j] = in[j].scales[i] & 63;\n                m[i * 16 + j] = in[j].scales[i + 4] & 63;\n            }\n        }\n        for (int i = 0; i < 4; i++) {\n            for (int j = 0; j < 16; j++) {\n                s[64 + i * 16 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);\n                m[64 + i * 16 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);\n            }\n        }\n\n        for (int i = 0; i < 128; i++) {\n            out.scales[i] = (s[i] & 15) | ((m[i] & 
15) << 4);\n        }\n        for (int i = 0; i < 64; i++) {\n            out.scales[128 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[64 + i] & 48) | ((m[64 + i] & 48) << 2);\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\nstatic block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) {\n    block_q2_Kx8 out;\n\n    // Delta(scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure\n    for (int i = 0; i < 8; i++) {\n        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;\n    }\n\n    for (int i = 0; i < 8; i++) {\n        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;\n    }\n\n    const int end = QK_K * 2 / blck_size_interleave;\n\n    // Interleave Q2_K quants by taking 8 bytes at a time\n    for (int i = 0; i < end; ++i) {\n        int src_id = i % 8;\n        int src_offset = (i / 8) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n\n        uint64_t elems;\n        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));\n        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));\n    }\n\n    // The below logic is designed so as to unpack and rearrange scales and mins values in Q2_K\n    // Currently the Q2_K structure has 16 scales and 16 mins packed in 16 bytes ( 4 bits for each value)\n    // The output Q2_Kx8 structure has 128 bytes for storing scales and mins\n    // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure\n    // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures\n\n    for (int i = 0; i < 128; i++) {\n        // Index for selecting which q2k super block\n        int src1 = (i % 16) / 2;\n        // Index for selecting scale\n        int src2 = ((i / 16) * 2) + (i % 2);\n\n        out.scales[i] = in[src1].scales[src2];\n    }\n    
return out;\n}\n\nstatic block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {\n    block_q5_Kx8 out;\n    //Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure\n    for (int i = 0; i < 8; i++) {\n        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;\n    }\n\n    for (int i = 0; i < 8; i++) {\n        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;\n    }\n\n    const int end = QK_K * 4 / blck_size_interleave;\n\n    // Interleave Q5_K quants by taking blck_size_interleave bytes at a time\n    for (int i = 0; i < end; ++i) {\n        int src_id     = i % 8;\n        int src_offset = (i / 8) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n\n        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);\n    }\n\n    // Repeat for high bits with the same chunk size, since\n    // the high bits are interleaved in Q5_K and the index is\n    // qh_idx = (qs_idx % 32);\n    // qh_val = qh[qh_idx] >> (qs_idx / 32);\n    for (int i = 0; i < end / 4; ++i) {\n        int src_id     = i % 8;\n        int src_offset = (i / 8) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n\n        memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave);\n    }\n\n    // The below logic is copied over from Q4_K\n    // The point is to unpack all the scales and mins for each sub block every time we load 12 bytes.\n    // Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)\n    // The output Q5_Kx8 structure has 96 bytes\n    // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure\n    // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures\n    uint8_t s[8], m[8];\n\n    for (int i = 0; i < 4; i++) {\n        for (int j = 
0; j < 8; j++) {\n            s[j] = in[j].scales[i] & 63;\n            m[j] = in[j].scales[i + 4] & 63;\n        }\n\n        out.scales[i * 12]      = (s[0] & 63) + ((s[4] & 48) << 2);\n        out.scales[i * 12 + 1]  = (s[1] & 63) + ((s[5] & 48) << 2);\n        out.scales[i * 12 + 2]  = (s[2] & 63) + ((s[6] & 48) << 2);\n        out.scales[i * 12 + 3]  = (s[3] & 63) + ((s[7] & 48) << 2);\n        out.scales[i * 12 + 4]  = (m[0] & 63) + ((m[4] & 48) << 2);\n        out.scales[i * 12 + 5]  = (m[1] & 63) + ((m[5] & 48) << 2);\n        out.scales[i * 12 + 6]  = (m[2] & 63) + ((m[6] & 48) << 2);\n        out.scales[i * 12 + 7]  = (m[3] & 63) + ((m[7] & 48) << 2);\n        out.scales[i * 12 + 8]  = (s[4] & 15) + ((m[4] & 15) << 4);\n        out.scales[i * 12 + 9]  = (s[5] & 15) + ((m[5] & 15) << 4);\n        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);\n        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);\n    }\n\n    for (int i = 0; i < 4; i++) {\n        for (int j = 0; j < 8; j++) {\n            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);\n            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);\n        }\n\n        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);\n        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);\n        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);\n        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);\n        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);\n        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);\n        out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);\n        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);\n        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);\n        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);\n        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);\n        out.scales[i * 
12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);\n    }\n\n    return out;\n}\n\nstatic block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) {\n    block_q6_Kx8  out;\n    constexpr int n_blocks = 8;  // Kx8\n    for (int i = 0; i < n_blocks; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end_ls = QK_K * 4 / blck_size_interleave;\n    // Interleave Q6_K quants by taking blck_size_interleave bytes at a time\n    for (int i = 0; i < end_ls; ++i) {\n        int src_id     = i % n_blocks;\n        int src_offset = (i / n_blocks) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n\n        uint64_t elem_ls;\n        memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);\n        memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);\n    }\n\n    // Interleave high bits using same chunk size as low bits\n    const int end_hs = end_ls / 2;\n    for (int i = 0; i < end_hs; ++i) {\n        int src_id     = i % n_blocks;\n        int src_offset = (i / n_blocks) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n\n        uint64_t elem_hs;\n        memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);\n        memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);\n    }\n\n    // The below logic is designed so as to unpack and rearrange scales in Q6_K\n    // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants\n    // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales\n    // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 
15 bl7]  (bl = block)\n    constexpr int n_scales = QK_K / 16;\n\n    for (int i = 0; i < n_blocks; i++) {\n        for (int j = 0; j < n_scales; j++) {\n            out.scales[j * n_blocks + i] = in[i].scales[j];\n        }\n    }\n\n    return out;\n}\n\nstatic block_q2_Kx16 make_block_q2_Kx16(const block_q2_K * in, unsigned int blck_size_interleave) {\n    block_q2_Kx16 out;\n    constexpr int N_COLS = 16;\n\n    // 1. Copy Super-Scales (d) and Super-Mins (dmin)\n    for (int i = 0; i < N_COLS; i++) {\n        out.d[i]    = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;\n        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;\n    }\n\n    // 2. Interleave Q2_K Data\n    const int bytes_per_col = 64;\n    const int total_bytes = N_COLS * bytes_per_col;\n    const int end = total_bytes / blck_size_interleave;\n\n    for (int i = 0; i < end; ++i) {\n        int src_col_id = i % N_COLS;\n        int src_offset = (i / N_COLS) * blck_size_interleave;\n        int dst_offset = i * blck_size_interleave;\n        memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], blck_size_interleave);\n    }\n\n    // 3. 
Repack Scales into the Optimized \"Sequential-Parallel\" Layout\n    int out_idx = 0;\n\n    // Arrays define the sub-block order for each group\n    const int even_low_sbs[]  = {0, 2, 4, 6};\n    const int odd_low_sbs[]   = {1, 3, 5, 7};\n    const int even_high_sbs[] = {8, 10, 12, 14};\n    const int odd_high_sbs[]  = {9, 11, 13, 15};\n\n    // Pack Group 1: Even-Low\n    for (int sb : even_low_sbs) {\n        for (int col = 0; col < N_COLS; col++) {\n            out.scales[out_idx++] = in[col].scales[sb];\n        }\n    }\n\n    // Pack Group 2: Odd-Low\n    for (int sb : odd_low_sbs) {\n        for (int col = 0; col < N_COLS; col++) {\n            out.scales[out_idx++] = in[col].scales[sb];\n        }\n    }\n\n    // Pack Group 3: Even-High\n    for (int sb : even_high_sbs) {\n        for (int col = 0; col < N_COLS; col++) {\n            out.scales[out_idx++] = in[col].scales[sb];\n        }\n    }\n\n    // Pack Group 4: Odd-High\n    for (int sb : odd_high_sbs) {\n        for (int col = 0; col < N_COLS; col++) {\n            out.scales[out_idx++] = in[col].scales[sb];\n        }\n    }\n\n    return out;\n}\n\nstatic int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);\n    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);\n    constexpr int nrows_interleaved = 4;\n\n    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;\n    const block_q4_0 * src = (const block_q4_0 *)data;\n    block_q4_0 dst_tmp[4];\n    int nrow = ggml_nrows(t);\n    int nblocks = t->ne[0] / QK4_0;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + 
i * nblocks];\n            }\n            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);\n    GGML_ASSERT(interleave_block == 8 || interleave_block == 4);\n    constexpr int nrows_interleaved = 8;\n\n    block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;\n    const block_q4_K * src = (const block_q4_K*) data;\n    block_q4_K dst_tmp[8];\n    int nrow = ggml_nrows(t);\n    int nblocks = t->ne[0] / QK_K;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i  = 0; i < nrows_interleaved; i++ ) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);\n    constexpr int nrows_interleaved = 16;\n\n    block_q4_Kx16 * dst = (block_q4_Kx16*)t->data;\n    const block_q4_K * src = (const block_q4_K*) data;\n    block_q4_K dst_tmp[16];\n    int nrow = ggml_nrows(t);\n    int nblocks = t->ne[0] / QK_K;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n       
     for (int i  = 0; i < nrows_interleaved; i++ ) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q4_Kx16(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q2_K);\n    GGML_ASSERT(interleave_block == 8);\n    constexpr int nrows_interleaved = 8;\n\n    block_q2_Kx8 * dst = (block_q2_Kx8*)t->data;\n    const block_q2_K * src = (const block_q2_K*) data;\n    block_q2_K dst_tmp[8];\n    int nrow = ggml_nrows(t);\n    int nblocks = t->ne[0] / QK_K;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q2_K);\n    constexpr int nrows_interleaved = 16;\n\n    block_q2_Kx16 * dst = (block_q2_Kx16*)t->data;\n    const block_q2_K * src = (const block_q2_K*) data;\n\n    block_q2_K dst_tmp[nrows_interleaved];\n\n    int nrow = ggml_nrows(t);\n    int nblocks = t->ne[0] / QK_K;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < 
nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            // This loop gathers 16 separate blocks (one from each column)\n            // that correspond to the same K-dimension chunk.\n            for (int i  = 0; i < nrows_interleaved; i++ ) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n\n            *dst++ = make_block_q2_Kx16(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);\n    constexpr int nrows_interleaved = 16;\n\n    block_q4_0x16 * dst = (block_q4_0x16*)t->data;\n    const block_q4_0 * src = (const block_q4_0*) data;\n    block_q4_0 dst_tmp[16];\n    int nrow = ggml_nrows(t);\n    int nblocks = t->ne[0] / QK4_0;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i  = 0; i < nrows_interleaved; i++ ) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor *       t,\n                                    int                        interleave_block,\n                                    const void * GGML_RESTRICT data,\n                                    size_t                     data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q5_K);\n    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);\n    constexpr int nrows_interleaved = 
8;\n\n    block_q5_Kx8 *     dst = (block_q5_Kx8 *) t->data;\n    const block_q5_K * src = (const block_q5_K *) data;\n    block_q5_K         dst_tmp[8];\n    int                nrow    = ggml_nrows(t);\n    int                nblocks = t->ne[0] / QK_K;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n}\n\nstatic int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q6_K);\n    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);\n    constexpr int nrows_interleaved = 8;\n\n    block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;\n    const block_q6_K * src = (const block_q6_K *) data;\n    block_q6_K dst_tmp[8];\n    int nrow = ggml_nrows(t);\n    int nblocks = t->ne[0] / QK_K;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n}\n\nstatic int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);\n    GGML_ASSERT(interleave_block == 8);\n    constexpr int nrows_interleaved = 8;\n\n    block_q4_0x8 * dst = (block_q4_0x8*)t->data;\n    const block_q4_0 * src = (const block_q4_0*) data;\n    block_q4_0 dst_tmp[8];\n    int nrow = ggml_nrows(t);\n    int nblocks = t->ne[0] / QK4_0;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i  = 0; i < nrows_interleaved; i++ ) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor *       t,\n                                    int                        interleave_block,\n                                    const void * GGML_RESTRICT data,\n                                    size_t                     data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q8_0);\n    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);\n    constexpr int nrows_interleaved = 4;\n\n    block_q8_0x4 *     dst = (block_q8_0x4 *) t->data;\n    const block_q8_0 * src = (const block_q8_0 *) data;\n    block_q8_0         dst_tmp[4];\n    int                nrow    = ggml_nrows(t);\n    int                nblocks = t->ne[0] / QK8_0;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n  
          }\n            *dst++ = make_block_q8_0x4(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n}\n\nstatic block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_interleave) {\n    block_q8_0x16 out;\n\n    for (int i = 0; i < 16; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end = QK8_0 * 16 / blck_size_interleave;\n\n    if (blck_size_interleave == 1) {\n        for (int i = 0; i < end; ++i) {\n            int src_id     = i % 16;\n            int src_offset = i / 16;\n            int dst_offset = i;\n            out.qs[dst_offset] = in[src_id].qs[src_offset];\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\nstatic int repack_q8_0_to_q8_0_16_bl(struct ggml_tensor *       t,\n                                    int                        interleave_block,\n                                    const void * GGML_RESTRICT data,\n                                    size_t                     data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q8_0);\n    constexpr int nrows_interleaved = 16;\n\n    block_q8_0x16 *     dst = (block_q8_0x16 *) t->data;\n    const block_q8_0 * src = (const block_q8_0 *) data;\n    block_q8_0         dst_tmp[16];\n    int                nrow    = ggml_nrows(t);\n    int                nblocks = t->ne[0] / QK8_0;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q8_0x16(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n}\n\nstatic block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * 
in, unsigned int blck_size_interleave) {\n    block_iq4_nlx4 out;\n\n    for (int i = 0; i < 4; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end = QK4_NL * 2 / blck_size_interleave;\n\n    // TODO: this branch seems wrong\n    //if (blck_size_interleave == 8) {\n    //    for (int i = 0; i < end; ++i) {\n    //        int src_id = i % 4;\n    //        int src_offset = (i / 4) * blck_size_interleave;\n    //        int dst_offset = i * blck_size_interleave;\n\n    //        // Using memcpy to avoid unaligned memory accesses\n    //        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));\n    //    }\n    //} else\n    if (blck_size_interleave == 4) {\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 4;\n            int src_offset = (i / 4) * blck_size_interleave;\n            int dst_offset = i * blck_size_interleave;\n\n            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\nstatic int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);\n    GGML_ASSERT(interleave_block == 4);\n\n    const block_iq4_nl   * src = (const block_iq4_nl   *)data;\n          block_iq4_nlx4 * dst = (      block_iq4_nlx4 *)t->data;\n\n    block_iq4_nl dst_tmp[4];\n\n    int nrow = ggml_nrows(t);\n    int nrows_interleaved = 4;\n    int nblocks = t->ne[0] / QK4_NL;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = 
make_block_iq4_nlx4(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) {\n    block_iq4_nlx8 out;\n\n    for (int i = 0; i < 8; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end = QK4_NL * 4 / blck_size_interleave;\n\n    if (blck_size_interleave == 8) {\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 8;\n            int src_offset = (i / 8) * blck_size_interleave;\n            int dst_offset = i * blck_size_interleave;\n\n            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\nstatic int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);\n    GGML_ASSERT(interleave_block == 8);\n\n    const block_iq4_nl   * src = (const block_iq4_nl   *)data;\n          block_iq4_nlx8 * dst = (      block_iq4_nlx8 *)t->data;\n\n    block_iq4_nl dst_tmp[8];\n\n    int nrow = ggml_nrows(t);\n    int nrows_interleaved = 8;\n    int nblocks = t->ne[0] / QK4_NL;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));\n\n    if (t->ne[1] % nrows_interleaved != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck_size_interleave) {\n    
block_iq4_nlx16 out;\n\n    for (int i = 0; i < 16; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    const int end = QK4_NL * 8 / blck_size_interleave;\n\n    if (blck_size_interleave == 1) {\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 16;\n            int src_offset = i / 16;\n            int dst_offset = i;\n\n            out.qs[dst_offset] = in[src_id].qs[src_offset];\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\nstatic int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);\n    GGML_ASSERT(interleave_block == 1);\n\n    const block_iq4_nl    * src = (const block_iq4_nl   *)data;\n          block_iq4_nlx16 * dst = (      block_iq4_nlx16 *)t->data;\n\n    block_iq4_nl dst_tmp[16];\n\n    int nrow = ggml_nrows(t);\n    int nrows_interleaved = 16;\n    int nblocks = t->ne[0] / QK4_NL;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));\n\n    if (t->ne[1] % nrows_interleaved != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_iq4_nlx16(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) {\n    block_mxfp4x4 out;\n\n    for (int i = 0; i < 4; i++) {\n        out.e[i] = in[i].e;\n    }\n\n    const int end = QK_MXFP4 * 2 / blck_size_interleave;\n\n    if (blck_size_interleave == 4) {\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 4;\n            int src_offset = (i / 4) * blck_size_interleave;\n            
int dst_offset = i * blck_size_interleave;\n\n            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\nstatic int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);\n    GGML_ASSERT(interleave_block == 4);\n\n    const block_mxfp4   * src = (const block_mxfp4   *)data;\n          block_mxfp4x4 * dst = (      block_mxfp4x4 *)t->data;\n\n    block_mxfp4 dst_tmp[4];\n\n    int nrow = ggml_nrows(t);\n    int nrows_interleaved = 4;\n    int nblocks = t->ne[0] / QK_MXFP4;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_mxfp4x4(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) {\n    block_mxfp4x8 out;\n\n    for (int i = 0; i < 8; i++) {\n        out.e[i] = in[i].e;\n    }\n\n    const int end = QK_MXFP4 * 4 / blck_size_interleave;\n\n    if (blck_size_interleave == 8) {\n        for (int i = 0; i < end; ++i) {\n            int src_id = i % 8;\n            int src_offset = (i / 8) * blck_size_interleave;\n            int dst_offset = i * blck_size_interleave;\n\n            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));\n        }\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    return out;\n}\n\nstatic int repack_mxfp4_to_mxfp4_8_bl(struct 
ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);\n    GGML_ASSERT(interleave_block == 8);\n\n    const block_mxfp4   * src = (const block_mxfp4   *)data;\n          block_mxfp4x8 * dst = (      block_mxfp4x8 *)t->data;\n\n    block_mxfp4 dst_tmp[8];\n\n    int nrow = ggml_nrows(t);\n    int nrows_interleaved = 8;\n    int nblocks = t->ne[0] / QK_MXFP4;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));\n\n    if (t->ne[1] % nrows_interleaved != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_mxfp4x8(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nnamespace ggml::cpu::repack {\n// repack\ntemplate <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>\nint repack(struct ggml_tensor *, const void *, size_t);\n\n// TODO: generalise.\ntemplate <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);\n}\n\ntemplate <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);\n}\n\ntemplate <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);\n}\n\ntemplate <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);\n}\n\ntemplate <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return 
repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);\n}\n\ntemplate <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);\n}\n\ntemplate <> int repack<block_q5_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size);\n}\n\ntemplate <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);\n}\n\ntemplate <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);\n}\n\ntemplate <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);\n}\n\ntemplate <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);\n}\n\n// TODO: needs to be revisited\n//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {\n//    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);\n//}\n\ntemplate <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);\n}\n\ntemplate <> int repack<block_mxfp4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size);\n}\n\ntemplate <> int repack<block_mxfp4, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size);\n}\n\ntemplate <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q8_0_to_q8_0_4_bl(t, 4, data, 
data_size);\n}\n\ntemplate <> int repack<block_q8_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);\n}\n\n#if defined __riscv_zvfh\ntemplate <> int repack<block_q4_0, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size);\n}\n\ntemplate <> int repack<block_q4_K, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size);\n}\n\ntemplate <> int repack<block_iq4_nl, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size);\n}\n\ntemplate <> int repack<block_q8_0, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size);\n}\n\ntemplate <> int repack<block_q2_K, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q2_K_to_q2_K_16_bl(t, 1, data, data_size);\n}\n#endif\n\n// gemv\ntemplate <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>\nvoid gemv(int, float *, size_t, const void *, const void *, int, int);\n\ntemplate <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <>\nvoid gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int          n,\n                                            float *      s,\n 
                                           size_t       bs,\n                                            const void * vx,\n                                            const void * vy,\n                                            int          nr,\n                                            int          nc) {\n    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void 
gemv<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\n#if defined __riscv_zvfh\ntemplate <> void gemv<block_q4_0, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q4_K, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_iq4_nl, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q8_0, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemv<block_q2_K, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemv_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n#endif\n\n// gemm\ntemplate <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>\nvoid gemm(int, float *, size_t, const void *, const void *, int, 
int);\n\ntemplate <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <>\nvoid gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int          n,\n                                            float *      s,\n                                            size_t       bs,\n                                            const void * vx,\n                                            const void * vy,\n                                            int          nr,\n                                            int          nc) {\n    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, 
const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\n#if defined __riscv_zvfh\ntemplate <> void gemm<block_q4_0, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q4_K, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_iq4_nl, 1, 16, 
GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q8_0, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);\n}\n\ntemplate <> void gemm<block_q2_K, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {\n    ggml_gemm_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);\n}\n#endif\n\nclass tensor_traits_base : public ggml::cpu::tensor_traits {\n  public:\n    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;\n};\n\ntemplate <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {\n\n    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {\n        // not realy a GGML_TYPE_Q8_0 but same size.\n        switch (op->op) {\n            case GGML_OP_MUL_MAT:\n                {\n                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));\n                    return true;\n                }\n            case GGML_OP_MUL_MAT_ID:\n                {\n                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));\n                    size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.\n\n                    const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert\n                    const int64_t ne12 = op->src[1]->ne[2]; // n_tokens\n\n                    const size_t sizeof_mmid_row_mapping = sizeof(int64_t);\n\n                    size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);\n\n                    return true;\n                }\n            default:\n                // GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return false;\n    }\n\n    bool 
compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {\n        switch (op->op) {\n            case GGML_OP_MUL_MAT:\n                forward_mul_mat(params, op);\n                return true;\n            case GGML_OP_MUL_MAT_ID:\n                forward_mul_mat_id(params, op);\n                return true;\n            default:\n                // GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return false;\n    }\n\n    void forward_mul_mat_one_chunk(ggml_compute_params * params,\n                                   ggml_tensor *         op,\n                                   int64_t               src0_start,\n                                   int64_t               src0_end,\n                                   int64_t               src1_start,\n                                   int64_t               src1_end) {\n        const ggml_tensor * src0 = op->src[0];\n        const ggml_tensor * src1 = op->src[1];\n        ggml_tensor *       dst  = op;\n\n        GGML_TENSOR_BINARY_OP_LOCALS\n\n        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);\n\n        GGML_ASSERT(ne03 == 1 && ne13 == 1);\n        GGML_ASSERT(ne12 % ne02 == 0);\n        const int64_t r2 = ne12 / ne02;\n\n        const int64_t i12 = src1_start / ne1;\n        const int64_t i11 = src1_start - i12 * ne1;\n\n        // Determine batch index\n        const int64_t i02 = i12 / r2;\n\n        const int64_t i1 = i11;\n        const int64_t i2 = i12;\n\n        const char * src0_ptr = (const char *) src0->data + i02 * nb02;\n        const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;\n        char *       dst_ptr  = ((char *) dst->data + (i1 * nb1 + i2 * nb2));\n\n        const int64_t nrows = src1_end - src1_start;\n        const int64_t ncols = src0_end - src0_start;\n\n        GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);\n\n        
// If there are more than three rows in src1, use gemm; otherwise, use gemv.\n        if (nrows > 3) {\n            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,\n                                                             src0_ptr + src0_start * nb01, src1_ptr,\n                                                             nrows - (nrows % 4), ncols);\n        }\n        for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {\n            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,\n                                                             ne01, src0_ptr + src0_start * nb01,\n                                                             src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);\n        }\n    }\n\n    void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {\n        const ggml_tensor * src0 = op->src[0];\n        const ggml_tensor * src1 = op->src[1];\n        ggml_tensor *       dst  = op;\n\n        GGML_TENSOR_BINARY_OP_LOCALS\n\n        const int ith = params->ith;\n        const int nth = params->nth;\n\n        GGML_ASSERT(ne0 == ne01);\n        GGML_ASSERT(ne1 == ne11);\n        GGML_ASSERT(ne2 == ne12);\n        GGML_ASSERT(ne3 == ne13);\n\n        // dst cannot be transposed or permuted\n        GGML_ASSERT(nb0 == sizeof(float));\n        GGML_ASSERT(nb0 <= nb1);\n        GGML_ASSERT(nb1 <= nb2);\n        GGML_ASSERT(nb2 <= nb3);\n\n        // TODO: General batched mul mat for 4D tensors\n        // Currently only supports 3D tensors\n        GGML_ASSERT(ne03 == 1);\n        GGML_ASSERT(ne13 == 1);\n        GGML_ASSERT(ne3 == 1);\n\n        GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n        GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);\n        // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);\n\n        char *       wdata = static_cast<char *>(params->wdata);\n        const size_t nbw1  = 
ggml_row_size(PARAM_TYPE, ne10);\n        const size_t nbw2  = nbw1 * ne11;\n\n        assert(params->wsize >= nbw2 * ne12);\n\n        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;\n\n        // INFO: Quantization is done in planes to avoid extra complexity in chunking.\n        // Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how\n        // the planes are broadcast.\n        for (int64_t i12 = 0; i12 < ne12; i12++) {\n            char * data_ptr  = (char *) src1->data + i12 * nb12;\n            char * wdata_ptr = wdata + i12 * nbw2;\n\n            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {\n                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),\n                                                            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);\n            }\n\n            const int64_t i11_processed = ne11 - ne11 % 4;\n            for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {\n                from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);\n            }\n        }\n\n        // disable for NUMA\n        const bool disable_chunking = ggml_is_numa();\n\n        // 4x chunks per thread\n        const int64_t nr0 = ggml_nrows(op->src[0]);\n\n        int     nth_scaled  = nth * 4;\n        int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;\n        int64_t nchunk0     = (nr0 + chunk_size0 - 1) / chunk_size0;\n\n        // src1 is chunked only by full planes.\n        // When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE\n        // to route them through GEMV.\n        // nchunk1 = ne12 also avoids messing up the chunking for models with no 3d tensors\n        // to avoid affecting their performance\n        int64_t nchunk1 = ne12;\n\n        // Ensure minimum chunk size to avoid alignment issues with high thread 
counts\n        // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment\n        const int64_t min_chunk_size = NB_COLS;\n        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {\n            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;\n        }\n\n        int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;\n        // Only increase nchunk0 to nth if it won't make chunks too small\n        if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {\n            nchunk0 = nth;\n            dr0 = (nr0 + nchunk0 - 1) / nchunk0;\n        }\n\n        // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size\n        // This prevents creating too many tiny chunks that could overlap after alignment\n        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;\n        nchunk0                  = MIN(nchunk0, max_nchunk);\n\n        if (ith == 0) {\n            // Every thread starts at ith, so the first unprocessed chunk is nth.  
This saves a bit of coordination right at the start.\n            ggml_threadpool_chunk_set(params->threadpool, nth);\n        }\n\n        ggml_barrier(params->threadpool);\n\n        // The first chunk comes from our thread_id, the rest will get auto-assigned.\n        int current_chunk = ith;\n\n        while (current_chunk < nchunk0 * nchunk1) {\n            const int64_t ith0 = current_chunk % nchunk0;\n            const int64_t ith1 = current_chunk / nchunk0;\n\n            int64_t src0_start = dr0 * ith0;\n            int64_t src0_end   = MIN(src0_start + dr0, nr0);\n\n            // full-plane range for src1\n            int64_t src1_start = ith1 * ne11;\n            int64_t src1_end = (ith1 + 1) * ne11;\n\n            // Align boundaries to NB_COLS - round up to ensure all data is included\n            // The chunk size limiting above ensures chunks are large enough to prevent overlaps\n            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;\n            src0_end   = (src0_end % NB_COLS) ? 
src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;\n            src0_end   = MIN(src0_end, ne01);\n\n            // Skip chunks whose NB_COLS-aligned row range is empty, but still claim the next chunk\n            if (src0_start >= src0_end) {\n                current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);\n                continue;\n            }\n\n            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);\n\n            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);\n        }\n    }\n\n    void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {\n        const ggml_tensor * src0 = op->src[0];\n        const ggml_tensor * src1 = op->src[1];\n        const ggml_tensor * ids  = op->src[2];\n        ggml_tensor *       dst  = op;\n\n        GGML_TENSOR_BINARY_OP_LOCALS\n\n        const int ith = params->ith;\n        const int nth = params->nth;\n\n        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;\n\n        // we don't support permuted src0 or src1\n        GGML_ASSERT(nb00 == ggml_type_size(src0->type));\n        GGML_ASSERT(nb10 == ggml_type_size(src1->type));\n\n        // dst cannot be transposed or permuted\n        GGML_ASSERT(nb0 == sizeof(float));\n        GGML_ASSERT(nb0 <= nb1);\n        GGML_ASSERT(nb1 <= nb2);\n        GGML_ASSERT(nb2 <= nb3);\n\n        GGML_ASSERT(ne03 == 1);\n        GGML_ASSERT(ne13 == 1);\n        GGML_ASSERT(ne3  == 1);\n\n        GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n        // row groups\n        const int n_ids = ids->ne[0]; // n_expert_used\n        const int n_as  = ne02;       // n_expert\n\n        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);\n        const size_t nbw2 = nbw1*ne11;\n        const size_t nbw3 = nbw2*ne12;\n\n        struct mmid_row_mapping {\n            int32_t i1;\n            int32_t i2;\n        };\n\n        GGML_ASSERT(params->wsize >=\n                (GGML_PAD(nbw3, 
sizeof(int64_t)) +\n                 n_as*(ne12 + 1)*sizeof(mmid_row_mapping))\n                );\n\n        auto * wdata          = (char *)params->wdata;\n        auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));\n\n        // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)\n        auto * matrix_row_counts = (int64_t *) (wdata_src1_end);                                        // [n_as]\n        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]\n\n        // src1: float32 => param type\n        for (int64_t i12 = 0; i12 < ne12; ++i12) {\n            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {\n                from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),\n                           (void *)               (wdata + i12 * nbw2 + i11 * nbw1),\n                           ne10);\n            }\n        }\n\n#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]\n\n        if (ith == 0) {\n            // initialize matrix_row_counts\n            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));\n\n            // group rows by src0 matrix\n            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {\n                for (int32_t id = 0; id < n_ids; ++id) {\n                    const int32_t i02 =\n                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);\n\n                    GGML_ASSERT(i02 >= 0 && i02 < n_as);\n\n                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };\n                    matrix_row_counts[i02] += 1;\n                }\n            }\n        }\n\n        ggml_barrier(params->threadpool);\n\n        // compute each matrix multiplication in sequence\n        for (int cur_a = 0; cur_a < n_as; ++cur_a) {\n            const int64_t cne1 = matrix_row_counts[cur_a];\n\n            if (cne1 == 0) {\n               
 continue;\n            }\n\n            const auto * src0_cur = (const char *) src0->data + cur_a*nb02;\n\n            //const int64_t nr0 = ne01; // src0 rows\n            const int64_t nr1 = cne1; // src1 rows\n\n            int64_t src0_cur_start = (ith * ne01) / nth;\n            int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;\n\n            // Align boundaries to NB_COLS - round up to ensure all data is included\n            src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;\n            src0_cur_end   = (src0_cur_end   % NB_COLS) ? src0_cur_end   + NB_COLS - (src0_cur_end   % NB_COLS) : src0_cur_end;\n            if (src0_cur_end > ne01) {\n                src0_cur_end = ne01;\n            }\n\n            if (src0_cur_start >= src0_cur_end) {\n                return;\n            }\n\n            for (int ir1 = 0; ir1 < nr1; ir1++) {\n                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);\n\n                const int id = row_mapping.i1;  // selected expert index\n\n                const int64_t i11 = id % ne11;\n                const int64_t i12 = row_mapping.i2;  // row index in src1\n\n                const int64_t i1 = id;               // selected expert index\n                const int64_t i2 = i12;              // row\n\n                const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);\n\n                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(\n                    ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,\n                    src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);\n            }\n        }\n#undef MMID_MATRIX_ROW\n    }\n\n    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {\n        GGML_LOG_DEBUG(\"%s: repack tensor %s with %s_%dx%d\\n\", __func__, t->name, ggml_type_name(t->type),\n                      
 (int) NB_COLS, (int) INTER_SIZE);\n        return ggml::cpu::repack::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);\n    }\n};\n\n}  // namespace ggml::cpu::repack\n\nstatic const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {\n    // instance for Q4\n    static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;\n    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;\n    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;\n\n    // instance for Q4_K\n    static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;\n    static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;\n\n    // instance for Q5_K\n    static const ggml::cpu::repack::tensor_traits<block_q5_K, 4, 8, GGML_TYPE_Q8_K> q5_K_8x4_q8_K;\n    static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;\n\n    // instance for Q6_K\n    static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K;\n    static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;\n\n    // instance for Q2\n    static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;\n\n    // instance for IQ4\n    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;\n    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;\n\n    // instance for MXFP4\n    static const ggml::cpu::repack::tensor_traits<block_mxfp4, 4, 4, GGML_TYPE_Q8_0> mxfp4_4x4_q8_0;\n    static const ggml::cpu::repack::tensor_traits<block_mxfp4, 8, 8, GGML_TYPE_Q8_0> mxfp4_8x8_q8_0;\n\n    // instance for Q8_0\n    static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, 
GGML_TYPE_Q8_0> q8_0_4x4_q8_0;\n    static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;\n\n    // instances for RISC-V\n    //\n    // These implement outer-product style matrix multiplication kernels with\n    // an interleave of 1.\n#if defined __riscv_zvfh\n    static const ggml::cpu::repack::tensor_traits<block_q4_0, 1, 16, GGML_TYPE_Q8_0> q4_0_16x1_q8_0;\n    static const ggml::cpu::repack::tensor_traits<block_q4_K, 1, 16, GGML_TYPE_Q8_K> q4_K_16x1_q8_K;\n    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 1, 16, GGML_TYPE_Q8_0> iq4_nl_16x1_q8_0;\n    static const ggml::cpu::repack::tensor_traits<block_q8_0, 1, 16, GGML_TYPE_Q8_0> q8_0_16x1_q8_0;\n    static const ggml::cpu::repack::tensor_traits<block_q2_K, 1, 16, GGML_TYPE_Q8_K> q2_K_16x1_q8_K;\n#endif\n\n    if (cur->type == GGML_TYPE_Q4_0) {\n        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {\n            if (cur->ne[1] % 8 == 0) {\n                return &q4_0_8x8_q8_0;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {\n            if (cur->ne[1] % 4 == 0) {\n                return &q4_0_4x8_q8_0;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {\n            if (cur->ne[1] % 4 == 0) {\n                return &q4_0_4x4_q8_0;\n            }\n        }\n        if (ggml_cpu_has_riscv_v()) {\n            #if defined __riscv_zvfh\n            switch (__riscv_vlenb() * 8) {\n                case 128:  { break; } // TODO\n                case 256:  { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; }\n                case 512:  { break; } // TODO\n                case 1024: { break; } // TODO\n                default:   { return nullptr; }\n            }\n            #endif\n        }\n    } else if (cur->type == GGML_TYPE_Q4_K) {\n        if (ggml_cpu_has_avx2()) {\n            if (cur->ne[1] % 
8 == 0) {\n                return &q4_K_8x8_q8_K;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &q4_K_8x8_q8_K;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &q4_K_8x4_q8_K;\n            }\n        }\n        if (ggml_cpu_has_riscv_v()) {\n            #if defined __riscv_zvfh\n            switch (__riscv_vlenb() * 8) {\n                case 128:  { break; } // TODO\n                case 256:  { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; }\n                case 512:  { break; } // TODO\n                case 1024: { break; } // TODO\n                default:   { return nullptr; }\n            }\n            #endif\n        }\n    } else if (cur->type == GGML_TYPE_Q2_K) {\n        if (ggml_cpu_has_avx512()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &q2_K_8x8_q8_K;\n            }\n        }\n        if (ggml_cpu_has_riscv_v()) {\n            #if defined __riscv_zvfh\n            switch (__riscv_vlenb() * 8) {\n                case 128:  { break; } // TODO\n                case 256:  { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; }\n                case 512:  { break; } // TODO\n                case 1024: { break; } // TODO\n                default:   { return nullptr; }\n            }\n            #endif\n        }\n    } else if (cur->type == GGML_TYPE_Q5_K) {\n        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &q5_K_8x8_q8_K;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &q5_K_8x4_q8_K;\n            }\n        }\n    } else if (cur->type == GGML_TYPE_Q6_K) {\n        if (ggml_cpu_has_neon() && 
ggml_cpu_has_matmul_int8()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &q6_K_8x8_q8_K;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &q6_K_8x4_q8_K;\n            }\n        }\n    } else if (cur->type == GGML_TYPE_IQ4_NL) {\n        if (ggml_cpu_has_avx2()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &iq4_nl_8x8_q8_0;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {\n            if (cur->ne[1] % 4 == 0) {\n                return &iq4_nl_4x4_q8_0;\n            }\n        }\n        if (ggml_cpu_has_riscv_v()) {\n            #if defined __riscv_zvfh\n            switch (__riscv_vlenb() * 8) {\n                case 128:  { break; } // TODO\n                case 256:  { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; }\n                case 512:  { break; } // TODO\n                case 1024: { break; } // TODO\n                default:   { return nullptr; }\n            }\n            #endif\n        }\n    } else if (cur->type == GGML_TYPE_MXFP4) {\n        if (ggml_cpu_has_avx2()) {\n            if (cur->ne[1] % 8 == 0) {\n                return &mxfp4_8x8_q8_0;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {\n            if (cur->ne[1] % 4 == 0) {\n                return &mxfp4_4x4_q8_0;\n            }\n        }\n    } else if (cur->type == GGML_TYPE_Q8_0) {\n        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {\n            if (cur->ne[1] % 4 == 0) {\n                return &q8_0_4x8_q8_0;\n            }\n        }\n        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {\n            if (cur->ne[1] % 4 == 0) {\n                return &q8_0_4x4_q8_0;\n            }\n        }\n        if (ggml_cpu_has_riscv_v()) {\n            #if defined __riscv_zvfh\n            switch (__riscv_vlenb() * 8) {\n             
   case 128:  { break; } // TODO\n                case 256:  { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; }\n                case 512:  { break; } // TODO\n                case 1024: { break; } // TODO\n                default:   { return nullptr; }\n            }\n            #endif\n        }\n    }\n\n    return nullptr;\n}\n\nstatic enum ggml_status ggml_backend_cpu_repack_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {\n    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type(tensor));\n\n    GGML_UNUSED(buffer);\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_cpu_repack_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,\n                                                       const void * data, size_t offset, size_t size) {\n    GGML_ASSERT(offset == 0);\n    GGML_ASSERT(size == ggml_nbytes(tensor));\n\n    auto tensor_traits = (ggml::cpu::repack::tensor_traits_base *) tensor->extra;\n    auto OK            = tensor_traits->repack(tensor, data, size);\n\n    GGML_ASSERT(OK == 0);\n    GGML_UNUSED(buffer);\n}\n\nstatic const char * ggml_backend_cpu_repack_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    return \"CPU_REPACK\";\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_cpu_repack_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);\n\n    if (buffer == nullptr) {\n        return nullptr;\n    }\n\n    buffer->buft              = buft;\n    buffer->iface.init_tensor = ggml_backend_cpu_repack_buffer_init_tensor;\n    buffer->iface.set_tensor  = ggml_backend_cpu_repack_buffer_set_tensor;\n    buffer->iface.get_tensor  = nullptr;\n    buffer->iface.cpy_tensor  = nullptr;\n    return buffer;\n}\n\nstatic size_t 
ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return TENSOR_ALIGNMENT;\n\n    GGML_UNUSED(buft);\n}\n\nnamespace ggml::cpu::repack {\nclass extra_buffer_type : ggml::cpu::extra_buffer_type {\n    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {\n        if (    op->op == GGML_OP_MUL_MAT &&\n                op->src[0]->buffer &&\n                (ggml_n_dims(op->src[0]) == 2) &&\n                op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type() &&\n                ggml_repack_get_optimal_repack_type(op->src[0])\n                ) {\n            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {\n                return false;\n            }\n            if (op->src[1]->type == GGML_TYPE_F32) {\n                return true;\n            }\n            //if (op->src[1]->type == GGML_TYPE_Q8_0) {\n            //    return true;\n            //}\n            // may be possible if Q8_0 packed...\n        } else if (op->op == GGML_OP_MUL_MAT_ID\n                && op->src[0]->buffer\n                && (ggml_n_dims(op->src[0]) == 3)\n                && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()\n                && ggml_repack_get_optimal_repack_type(op->src[0])\n                ) {\n            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {\n                return false;\n            }\n            if (op->src[1]->type == GGML_TYPE_F32) {\n                return true;\n            }\n            //if (op->src[1]->type == GGML_TYPE_Q8_0) {\n            //    return true;\n            //}\n        }\n        return false;\n    }\n\n    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {\n        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {\n            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) 
{\n                return (ggml::cpu::tensor_traits *) op->src[0]->extra;\n            }\n        }\n        return nullptr;\n    }\n};\n}  // namespace ggml::cpu::repack\n\nggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) {\n    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_repack = {\n        /* .iface    = */ {\n                           /* .get_name         = */ ggml_backend_cpu_repack_buffer_type_get_name,\n                           /* .alloc_buffer     = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer,\n                           /* .get_alignment    = */ ggml_backend_cpu_repack_buffer_type_get_alignment,\n                           /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX\n                           /* .get_alloc_size   = */ nullptr,  // defaults to ggml_nbytes\n                           /* .is_host          = */ nullptr,\n                           },\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),\n        /* .context = */ new ggml::cpu::repack::extra_buffer_type(),\n    };\n\n    return &ggml_backend_cpu_buffer_type_repack;\n}\n"
  },
  {
    "path": "src/ggml-cpu/repack.h",
    "content": "#pragma once\n\n#define GGML_COMMON_DECL_CPP\n#include \"ggml-common.h\"\n\n#include \"traits.h\"\n#include \"ggml.h\"\n\n// GGML internal header\n\nggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void);\n\ntemplate <int K> constexpr int QK_0() {\n    if constexpr (K == 4) {\n        return QK4_0;\n    }\n    if constexpr (K == 8) {\n        return QK8_0;\n    }\n    return -1;\n}\n\ntemplate <int K, int N> struct block {\n    ggml_half d[N];                         // deltas for N qK_0 blocks\n    int8_t    qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_0 blocks\n};\n\n// control size\nstatic_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, \"wrong block<4,4> size/padding\");\nstatic_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, \"wrong block<4,8> size/padding\");\nstatic_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, \"wrong block<4,16> size/padding\");\nstatic_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, \"wrong block<8,4> size/padding\");\nstatic_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, \"wrong block<8,8> size/padding\");\nstatic_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, \"wrong block<8,16> size/padding\");\n\nusing block_q4_0x4 = block<4, 4>;\nusing block_q4_0x8 = block<4, 8>;\nusing block_q4_0x16 = block<4, 16>;\nusing block_q8_0x4 = block<8, 4>;\nusing block_q8_0x8 = block<8, 8>;\nusing block_q8_0x16 = block<8, 16>;\n\nstruct block_q4_Kx8 {\n    ggml_half d[8];      // super-block scale for quantized scales\n    ggml_half dmin[8];   // super-block scale for quantized mins\n    uint8_t scales[96];  // scales and mins, quantized with 6 bits\n    uint8_t qs[1024];    // 4--bit quants\n};\n\nstatic_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, \"wrong q4_K block size/padding\");\nstruct block_q4_Kx16 {\n    ggml_half d[16];      // super-block scale for quantized 
scales\n    ggml_half dmin[16];   // super-block scale for quantized mins\n    uint8_t scales[192];  // scales and mins, quantized with 6 bits\n    uint8_t qs[2048];    // 4--bit quants\n};\n\nstatic_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, \"wrong q4_K block size/padding\");\nstruct block_q2_Kx8 {\n    ggml_half d[8];      // super-block scale for quantized scales\n    ggml_half dmin[8];   // super-block scale for quantized mins\n    uint8_t scales[128];  // scales and mins, quantized with 4 bits\n    uint8_t qs[512];    // 2--bit quants\n};\n\nstatic_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, \"wrong q2_K block size/padding\");\nstruct block_q2_Kx16 {\n    ggml_half d[16];       // Super-block scale for quantized scales\n    ggml_half dmin[16];    // Super-block scale for quantized mins\n    uint8_t   scales[256]; // Sub-block scales (16 cols * 16 sub-blocks)\n    uint8_t   qs[1024];    // Data (16 cols * 64 bytes per block)\n};\nstatic_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, \"wrong q2_K block size/padding\");\n\nstruct block_q5_Kx8 {\n    ggml_half d[8];              // super-block scale for quantized scales\n    ggml_half dmin[8];           // super-block scale for quantized mins\n    uint8_t   scales[96];        // scales and mins, quantized with 6 bits\n    uint8_t   qh[QK_K * 8 / 8];  // high bits of 5-bit quants\n    uint8_t   qs[QK_K * 8 / 2];  // low bits of 5-bit quants (in groups of 4)\n};\n\nstatic_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,\n              \"wrong q5_K block size/padding\");\n\nstruct block_q6_Kx8 {\n    ggml_half d[8];\n    int8_t    scales[QK_K / 16 * 8];\n    uint8_t   ql[QK_K / 2 * 8];  // low bits of 6-bit quants (groups of 2)\n    uint8_t   qh[QK_K / 4 * 8];  // high bits of 6-bit quants (groups of 4)\n};\n\nstatic_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 
* 8 + 3 * QK_K / 4 * 8,\n              \"wrong q6_K block size/padding\");\n\nstruct block_q8_Kx4 {\n    float d[4];              // delta\n    int8_t qs[QK_K * 4];     // quants\n    int16_t bsums[QK_K / 4]; // sum of quants in groups of 16\n};\n\nstatic_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), \"wrong q8_K block size/padding\");\n\nstruct block_iq4_nlx4 {\n    ggml_half d[4];            // deltas for 4 iq4_nl blocks\n    uint8_t   qs[QK4_NL * 2];  // nibbles / quants for 4 iq4_nl blocks\n};\n\nstatic_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, \"wrong iq4_nlx4 block size/padding\");\n\nstruct block_iq4_nlx8 {\n    ggml_half d[8];            // deltas for 8 iq4_nl blocks\n    uint8_t   qs[QK4_NL * 4];  // nibbles / quants for 8 iq4_nl blocks\n};\n\nstatic_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, \"wrong iq4_nlx8 block size/padding\");\n\nstruct block_iq4_nlx16 {\n    ggml_half d[16];            // deltas for 16 iq4_nl blocks\n    uint8_t   qs[QK4_NL * 8];  // nibbles / quants for 16 iq4_nl blocks\n};\n\nstatic_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, \"wrong iq4_nlx16 block size/padding\");\nstruct block_mxfp4x4 {\n    uint8_t e[4];\n    uint8_t qs[QK_MXFP4 * 2];\n};\nstatic_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, \"wrong mxfp4x4 block size/padding\");\n\nstruct block_mxfp4x8 {\n    uint8_t e[8];\n    uint8_t qs[QK_MXFP4 * 4];\n};\nstatic_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, \"wrong mxfp4x8 block size/padding\");\n\n#if defined(__cplusplus)\nextern \"C\" {\n#endif\n\nvoid ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_quantize_mat_q8_K_4x8(const float * 
GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int 
nr, int nc);\nvoid ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\n#if defined __riscv_zvfh\nvoid ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, 
const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\n#endif\n\n// Native implementations\nvoid ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid 
ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t 
bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, 
int nc);\nvoid ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\n#if defined __riscv_zvfh\nvoid ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);\nvoid ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc);\nvoid ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);\n#endif\n\n#if defined(__cplusplus)\n} // extern \"C\"\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/simd-gemm.h",
    "content": "#pragma once\n\n// Computes C[M x N] += A[M x K] * B[K x N]\n//\n// All three matrices are row-major with no padding between rows: element\n// (i, k) of A is A[i * K + k], (k, j) of B is B[k * N + j] and (i, j) of C\n// is C[i * N + j].\n\n#include \"simd-mappings.h\"\n\n// TODO: add support for sizeless vector types\n#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)\n\n#if defined(__GNUC__) && !defined(__clang__)\n// Paired with the pop at the end of this section so the includer's diagnostic\n// state is restored; an unmatched pop would reset it to the command-line state.\n#pragma GCC diagnostic push\n#endif\n\n// TODO: untested on avx512\n// Register blocking of the micro-kernel: GEMM_RM rows of C by GEMM_RN vectors\n// of GGML_F32_EPR floats each. The trailing notes count accumulators + B\n// vectors + broadcast A against the number of vector registers (32 on\n// AVX-512 and AArch64 NEON, 16 on AVX/AVX2).\n#if defined(__AVX512F__) || (defined(__ARM_NEON) && defined(__aarch64__))\n    static constexpr int GEMM_RM = 4;\n    static constexpr int GEMM_RN = 4; // 16+4+1 = 21/32\n#elif defined(__AVX2__) || defined(__AVX__)\n    static constexpr int GEMM_RM = 6;\n    static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16\n#else\n    static constexpr int GEMM_RM = 2;\n    static constexpr int GEMM_RN = 2;\n#endif\n\n// Micro-kernel for one tile: C[RM x RN*EPR] += A[RM x K] * B[K x RN*EPR].\n// C and B point at the tile's first column, A at the tile's first row.\n// N is the row stride of B and C; K is the inner dimension and A's row stride.\ntemplate <int RM, int RN>\nstatic inline void simd_gemm_ukernel(\n    float       * GGML_RESTRICT C,\n    const float * GGML_RESTRICT A,\n    const float * GGML_RESTRICT B,\n    int K, int N)\n{\n    static constexpr int KN = GGML_F32_EPR;\n\n    // load the C tile into the accumulators\n    GGML_F32_VEC acc[RM][RN];\n    for (int64_t i = 0; i < RM; i++) {\n        for (int r = 0; r < RN; r++) {\n            acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);\n        }\n    }\n\n    // one rank-1 update per kk: broadcast A[i][kk] and multiply it by row kk of B\n    for (int64_t kk = 0; kk < K; kk++) {\n        GGML_F32_VEC Bv[RN];\n        for (int r = 0; r < RN; r++) {\n            Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);\n        }\n        for (int64_t i = 0; i < RM; i++) {\n            GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);\n            for (int r = 0; r < RN; r++) {\n                acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);\n            }\n        }\n    }\n\n    // store the updated tile back to C\n    for (int64_t i = 0; i < RM; i++) {\n        for (int r = 0; r < RN; r++) {\n            GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);\n        }\n    }\n}\n\n// C[M x N] += A[M x K] * B[K x N]\n// Panels of GEMM_RM rows go through the register-blocked micro-kernel. Columns\n// past the last full vector and rows past the last full panel are handled by\n// narrower kernels and scalar loops.\nstatic inline void simd_gemm(\n    float       * GGML_RESTRICT C,\n    const float * GGML_RESTRICT A,\n    const float * GGML_RESTRICT B,\n    int M, int K, int N)\n{\n    static constexpr int KN = GGML_F32_EPR;\n\n    int64_t ii = 0;\n    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {\n        int64_t jj = 0;\n        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {\n            simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);\n        }\n        for (; jj + KN <= N; jj += KN) {\n            simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);\n        }\n        // scalar tail columns; row i of this panel of A starts at A + i * K\n        for (; jj < N; jj++) {\n            for (int64_t i = 0; i < GEMM_RM; i++) {\n                float a = C[i * N + jj];\n                for (int64_t kk = 0; kk < K; kk++) {\n                    a += A[i * K + kk] * B[kk * N + jj];\n                }\n                C[i * N + jj] = a;\n            }\n        }\n\n        // advance to the next panel of GEMM_RM rows\n        A += GEMM_RM * K;\n        C += GEMM_RM * N;\n    }\n\n    // Tail rows: one at a time\n    for (; ii < M; ii++) {\n        int64_t jj = 0;\n        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {\n            simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);\n        }\n        for (; jj + KN <= N; jj += KN) {\n            simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);\n        }\n        for (; jj < N; jj++) {\n            float a = C[jj];\n            for (int64_t kk = 0; kk < K; kk++) {\n                a += A[kk] * B[kk * N + jj];\n            }\n            C[jj] = a;\n        }\n\n        A += K;\n        C += N;\n    }\n}\n\n#if defined(__GNUC__) && !defined(__clang__)\n#pragma GCC diagnostic pop\n#endif\n\n#else // scalar path\n\n// C[M x N] += A[M x K] * B[K x N] as a plain triple loop.\nstatic inline void simd_gemm(\n    float       * GGML_RESTRICT C,\n    const float * GGML_RESTRICT A,\n    const float * GGML_RESTRICT B,\n    int M, int K, int N)\n{\n    for (int64_t i = 0; i < M; i++) {\n        for (int64_t j = 0; j < N; j++) {\n            float sum = C[i * N + j];\n            for (int64_t kk = 0; kk < K; kk++) {\n                sum += A[i * K + kk] * B[kk * N + j];\n            }\n            C[i * N + j] = sum;\n        }\n    }\n}\n\n#endif // GGML_SIMD\n"
  },
  {
    "path": "src/ggml-cpu/simd-mappings.h",
    "content": "#pragma once\n\n#include \"ggml-cpu-impl.h\"\n\n#ifdef __ARM_FEATURE_SVE\n#include <arm_sve.h>\n#endif // __ARM_FEATURE_SVE\n\n#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)\n// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:\n//\n//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/\n//\n#include <arm_neon.h>\n#endif\n\n#if defined(__riscv_v_intrinsic)\n#include <riscv_vector.h>\n#endif\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n//\n// simd mappings\n//\n\n// FP16 to FP32 conversion\n\n// 16-bit float\n// on Arm, we use __fp16\n// on x86, we use uint16_t\n//\n// for old CUDA compilers (<= 11), we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/10616\n// for     MUSA compilers        , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843\n//\n#if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)\n    #define GGML_CPU_COMPUTE_FP16_TO_FP32(x) neon_compute_fp16_to_fp32(x)\n    #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) neon_compute_fp32_to_fp16(x)\n\n    #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)\n\n    static inline float neon_compute_fp16_to_fp32(ggml_fp16_t h) {\n        __fp16 tmp;\n        memcpy(&tmp, &h, sizeof(ggml_fp16_t));\n        return (float)tmp;\n    }\n\n    static inline ggml_fp16_t neon_compute_fp32_to_fp16(float f) {\n        ggml_fp16_t res;\n        __fp16 tmp = f;\n        memcpy(&res, &tmp, sizeof(ggml_fp16_t));\n        return res;\n    }\n#elif defined(__F16C__)\n    #ifdef _MSC_VER\n        #define GGML_CPU_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))\n        #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)\n    #else\n        #define GGML_CPU_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)\n        #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)\n    
#endif\n#elif defined(__POWER9_VECTOR__)\n    #define GGML_CPU_COMPUTE_FP16_TO_FP32(x) power_compute_fp16_to_fp32(x)\n    #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) power_compute_fp32_to_fp16(x)\n    /* the inline asm below is about 12% faster than the lookup method */\n    #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)\n    #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x)\n\n    static inline float power_compute_fp16_to_fp32(ggml_fp16_t h) {\n        float f;\n        double d;\n        __asm__(\n            \"mtfprd %0,%2\\n\"\n            \"xscvhpdp %0,%0\\n\"\n            \"frsp %1,%0\\n\" :\n            /* temp */ \"=d\"(d),\n            /* out */  \"=f\"(f):\n            /* in */   \"r\"(h));\n        return f;\n    }\n\n    static inline ggml_fp16_t power_compute_fp32_to_fp16(float f) {\n        double d;\n        ggml_fp16_t r;\n        __asm__( /* xscvdphp can work on double or single precision */\n            \"xscvdphp %0,%2\\n\"\n            \"mffprd %1,%0\\n\" :\n            /* temp */ \"=d\"(d),\n            /* out */  \"=r\"(r):\n            /* in */   \"f\"(f));\n        return r;\n    }\n#elif defined(__riscv) && defined(__riscv_zfhmin)\n    static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) {\n        _Float16 hf;\n        memcpy(&hf, &h, sizeof(ggml_fp16_t));\n        return hf;\n    }\n\n    static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) {\n        ggml_fp16_t res;\n        _Float16 hf = (_Float16)f;\n        memcpy(&res, &hf, sizeof(ggml_fp16_t));\n        return res;\n    }\n\n    #define GGML_CPU_COMPUTE_FP16_TO_FP32(x) riscv_compute_fp16_to_fp32(x)\n    #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x)\n    #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)\n    #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x)\n#endif\n\n// precomputed f32 table for f16 (256 KB)\n// defined in ggml-cpu.c, initialized in ggml_cpu_init()\nextern 
float ggml_table_f32_f16[1 << 16];\n\n// precomputed f32 table for e8m0 half (1 KB)\n// defined in ggml-cpu.c, initialized in ggml_cpu_init()\nextern float ggml_table_f32_e8m0_half[1 << 8];\n\n// Use lookup table for E8M0 on x86 (faster than bit manipulation)\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\n#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]\n#else\n#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)\n#endif\n\n// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,\n// so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.\n// This is also true for POWER9.\n#if !defined(GGML_CPU_FP16_TO_FP32)\ninline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {\n    uint16_t s;\n    memcpy(&s, &f, sizeof(uint16_t));\n    return ggml_table_f32_f16[s];\n}\n\n#define GGML_CPU_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)\n#endif\n\n#if !defined(GGML_CPU_FP32_TO_FP16)\n#define GGML_CPU_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)\n#endif\n\n\n// we define a common set of C macros which map to specific intrinsics based on the current architecture\n// we then implement the fundamental computation operations below using only these macros\n// adding support for new architectures requires to define the corresponding SIMD macros\n//\n// GGML_F32_STEP / GGML_F16_STEP\n//   number of elements to process in a single step\n//\n// GGML_F32_EPR / GGML_F16_EPR\n//   number of elements to fit in a single register\n//\n\n#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_FMA)\n\n#define GGML_SIMD\n\n// F32 SVE\n#define GGML_F32_EPR 8\n#define DEFAULT_PG svptrue_b32()\n\n#define GGML_F32xt                        svfloat32_t\n#define GGML_F32xt_ZERO                   svdup_n_f32(0.0f)\n#define GGML_F32xt_SET1(x)                svdup_n_f32(x)\n#define GGML_F32xt_LOAD_IMPL(pg, a)       svld1_f32(pg, a)\n#define GGML_F32xt_LOAD(a)                
GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)\n#define GGML_F32xt_STORE_IMPL(pg, a, b)   svst1_f32(pg, a, b)\n#define GGML_F32xt_STORE(a, b)            GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)\n#define GGML_F32xt_FMA_IMPL(pg, a, b, c)  svmad_f32_m(pg, b, c, a)\n#define GGML_F32xt_FMA(a, b, c)           GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)\n#define GGML_F32xt_ADD_IMPL(pg, a, b)     svadd_f32_m(pg, a, b)\n#define GGML_F32xt_ADD(a, b)              GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)\n#define GGML_F32xt_MUL_IMPL(pg, a, b)     svmul_f32_m(pg, a, b)\n#define GGML_F32xt_MUL(a, b)              GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)\n#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)\n#define GGML_F32xt_REDUCE_ONE(a)          GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)\n#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)  \\\n{                                                      \\\n    sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2);        \\\n    sum3 = svadd_f32_m(DEFAULT_PG, sum3, sum4);        \\\n    sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum6);        \\\n    sum7 = svadd_f32_m(DEFAULT_PG, sum7, sum8);        \\\n    sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum3);        \\\n    sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum7);        \\\n    sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5);        \\\n    (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1);  \\\n}\n#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)  \\\n        GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)\n\n#define GGML_F32_VEC        GGML_F32xt\n#define GGML_F32_VEC_ZERO   GGML_F32xt_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32xt_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32xt_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32xt_STORE\n#define GGML_F32_VEC_FMA    GGML_F32xt_FMA\n#define GGML_F32_VEC_ADD    GGML_F32xt_ADD\n#define GGML_F32_VEC_MUL    GGML_F32xt_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE\n\n// F16 
SVE\n#define DEFAULT_PG32    svptrue_b32()\n#define DEFAULT_PG16    svptrue_b16()\n\n#define GGML_F32Cxt                         svfloat16_t\n#define GGML_F32Cxt_ZERO                    svdup_n_f16(0.0f)\n#define GGML_F32Cxt_SET1(x)                 svdup_n_f16(x)\n#define GGML_F32Cxt_LOAD(p)                 svld1_f16(DEFAULT_PG16, (const __fp16 *)(p))\n#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))\n\n#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c)   svmad_f16_x(pg, b, c, a)\n#define GGML_F32Cxt_FMA(a, b, c)            GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)\n#define GGML_F32Cxt_ADD_IMPL(pg, a, b)      svadd_f16_x(pg, a, b)\n#define GGML_F32Cxt_ADD(a, b)               GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)\n#define GGML_F32Cxt_MUL_IMPL(pg, a, b)      svmul_f16_x(pg, a, b)\n#define GGML_F32Cxt_MUL(a, b)               GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)\n#define GGML_F32Cxt_REDUCE                  GGML_F16xt_REDUCE_MIXED\n\n#define GGML_F16x_VEC                GGML_F32Cxt\n#define GGML_F16x_VEC_ZERO           GGML_F32Cxt_ZERO\n#define GGML_F16x_VEC_SET1           GGML_F32Cxt_SET1\n#define GGML_F16x_VEC_LOAD(p, i)     GGML_F32Cxt_LOAD(p)\n#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r)\n#define GGML_F16x_VEC_FMA            GGML_F32Cxt_FMA\n#define GGML_F16x_VEC_ADD            GGML_F32Cxt_ADD\n#define GGML_F16x_VEC_MUL            GGML_F32Cxt_MUL\n#define GGML_F16x_VEC_REDUCE         GGML_F32Cxt_REDUCE\n\n#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)\n#define GGML_F16xt_REDUCE_ONE(a)          GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)\n\n#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4)  \\\n{                                                      \\\n    sum1 = svadd_f16_x(pg16, sum1, sum2);              \\\n    sum3 = svadd_f16_x(pg16, sum3, sum4);              \\\n    sum1 = svadd_f16_x(pg16, sum1, sum3);              \\\n    __fp16 sum_f16 = 
svaddv_f16(pg16, sum1);           \\\n    (res) = (ggml_float) sum_f16;                      \\\n}\n#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4)  \\\n        GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)\n\n// F16 NEON\n\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\n    #define GGML_F16_STEP 32\n    #define GGML_F16_EPR  8\n\n    #define GGML_F16x8              float16x8_t\n    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)\n    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)\n    #define GGML_F16x8_LOAD(x)      vld1q_f16((const __fp16 *)(x))\n    #define GGML_F16x8_STORE        vst1q_f16\n    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)\n    #define GGML_F16x8_ADD          vaddq_f16\n    #define GGML_F16x8_MUL          vmulq_f16\n    #define GGML_F16x8_REDUCE(res, x)                               \\\n    do {                                                            \\\n        int offset = GGML_F16_ARR >> 1;                             \\\n        for (int i = 0; i < offset; ++i) {                          \\\n            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \\\n        }                                                           \\\n        offset >>= 1;                                               \\\n        for (int i = 0; i < offset; ++i) {                          \\\n            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \\\n        }                                                           \\\n        offset >>= 1;                                               \\\n        for (int i = 0; i < offset; ++i) {                          \\\n            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \\\n        }                                                           \\\n        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \\\n        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \\\n        (res) = (ggml_float) 
vaddvq_f32(vaddq_f32(t0, t1));         \\\n    } while (0)\n\n    #define GGML_F16_VEC                GGML_F16x8\n    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO\n    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1\n    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)\n    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i])\n    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA\n    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD\n    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL\n    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE\n#else\n    // if FP16 vector arithmetic is not supported, we use FP32 instead\n    // and take advantage of the vcvt_ functions to convert to/from FP16\n\n    #define GGML_F16_STEP 16\n    #define GGML_F16_EPR  4\n\n    #define GGML_F32Cx4              float32x4_t\n    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)\n    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)\n    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))\n    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))\n    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)\n    #define GGML_F32Cx4_ADD          vaddq_f32\n    #define GGML_F32Cx4_MUL          vmulq_f32\n    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE\n\n    #define GGML_F16_VEC                GGML_F32Cx4\n    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO\n    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1\n    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)\n    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i])\n    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA\n    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD\n    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL\n    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE\n#endif\n\n#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)\n\n#define GGML_SIMD\n\n// F32 NEON\n\n#define GGML_F32_STEP 
16\n#define GGML_F32_EPR  4\n\n#define GGML_F32x4              float32x4_t\n#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)\n#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)\n#define GGML_F32x4_LOAD         vld1q_f32\n#define GGML_F32x4_STORE        vst1q_f32\n#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)\n#define GGML_F32x4_ADD          vaddq_f32\n#define GGML_F32x4_MUL          vmulq_f32\n#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)\n#define GGML_F32x4_REDUCE(res, x)                       \\\n{                                                       \\\n    int offset = GGML_F32_ARR >> 1;                     \\\n    for (int i = 0; i < offset; ++i) {                  \\\n        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \\\n    }                                                   \\\n    offset >>= 1;                                       \\\n    for (int i = 0; i < offset; ++i) {                  \\\n        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \\\n    }                                                   \\\n    offset >>= 1;                                       \\\n    for (int i = 0; i < offset; ++i) {                  \\\n        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \\\n    }                                                   \\\n    (res) = (ggml_float) GGML_F32x4_REDUCE_ONE((x)[0]); \\\n}\n\n#define GGML_F32_VEC        GGML_F32x4\n#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32x4_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x4_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x4_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x4_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x4_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE\n\n// F16 NEON\n\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\n    #define GGML_F16_STEP 32\n    #define GGML_F16_EPR  8\n\n    #define GGML_F16x8              float16x8_t\n    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)\n    #define 
GGML_F16x8_SET1(x)      vdupq_n_f16(x)\n    #define GGML_F16x8_LOAD(x)      vld1q_f16((const __fp16 *)(x))\n    #define GGML_F16x8_STORE        vst1q_f16\n    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)\n    #define GGML_F16x8_ADD          vaddq_f16\n    #define GGML_F16x8_MUL          vmulq_f16\n    #define GGML_F16x8_REDUCE(res, x)                               \\\n    do {                                                            \\\n        int offset = GGML_F16_ARR >> 1;                             \\\n        for (int i = 0; i < offset; ++i) {                          \\\n            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \\\n        }                                                           \\\n        offset >>= 1;                                               \\\n        for (int i = 0; i < offset; ++i) {                          \\\n            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \\\n        }                                                           \\\n        offset >>= 1;                                               \\\n        for (int i = 0; i < offset; ++i) {                          \\\n            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \\\n        }                                                           \\\n        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \\\n        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \\\n        (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \\\n    } while (0)\n\n    #define GGML_F16_VEC                GGML_F16x8\n    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO\n    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1\n    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)\n    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i])\n    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA\n    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD\n    #define GGML_F16_VEC_MUL     
       GGML_F16x8_MUL\n    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE\n#else\n    // if FP16 vector arithmetic is not supported, we use FP32 instead\n    // and take advantage of the vcvt_ functions to convert to/from FP16\n\n    #define GGML_F16_STEP 16\n    #define GGML_F16_EPR  4\n\n    #define GGML_F32Cx4              float32x4_t\n    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)\n    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)\n    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))\n    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))\n    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)\n    #define GGML_F32Cx4_ADD          vaddq_f32\n    #define GGML_F32Cx4_MUL          vmulq_f32\n    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE\n\n    #define GGML_F16_VEC                GGML_F32Cx4\n    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO\n    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1\n    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)\n    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i])\n    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA\n    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD\n    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL\n    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE\n#endif\n\n#elif defined(__AVX512F__)\n\n#define GGML_SIMD\n\n// F32 AVX512\n\n#define GGML_F32_STEP 64\n#define GGML_F32_EPR  16\n\n#define GGML_F32x16         __m512\n#define GGML_F32x16_ZERO    _mm512_setzero_ps()\n#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)\n#define GGML_F32x16_LOAD    _mm512_loadu_ps\n#define GGML_F32x16_STORE   _mm512_storeu_ps\n// _mm512_fmadd_ps is defined in AVX512F so no guard is required\n#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)\n#define GGML_F32x16_ADD     _mm512_add_ps\n#define GGML_F32x16_MUL     _mm512_mul_ps\n#define GGML_F32x16_REDUCE(res, x)                                    
\\\ndo {                                                                  \\\n    int offset = GGML_F32_ARR >> 1;                                   \\\n    for (int i = 0; i < offset; ++i) {                                \\\n        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \\\n    }                                                                 \\\n    offset >>= 1;                                                     \\\n    for (int i = 0; i < offset; ++i) {                                \\\n        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \\\n    }                                                                 \\\n    offset >>= 1;                                                     \\\n    for (int i = 0; i < offset; ++i) {                                \\\n        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \\\n    }                                                                 \\\n    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                    \\\n} while (0)\n\n// TODO: is this optimal ?\n\n#define GGML_F32_VEC        GGML_F32x16\n#define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32x16_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x16_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x16_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x16_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x16_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE\n\n// F16 AVX512\n\n#if defined(__AVX512FP16__)\n\n#define GGML_F16_STEP 128\n#define GGML_F16_EPR  32\n\n#define GGML_F16x32              __m512h\n#define GGML_F16x32_ZERO         _mm512_setzero_ph()\n#define GGML_F16x32_SET1(x)      _mm512_set1_ph(__extension__(_Float16)(x))\n#define GGML_F16x32_LOAD(x)      _mm512_loadu_ph(x)\n#define GGML_F16x32_STORE(x, y)  _mm512_storeu_ph(x, y)\n#define GGML_F16x32_FMA(a, b, c) _mm512_fmadd_ph(b, c, a)\n#define GGML_F16x32_ADD          _mm512_add_ph\n#define 
GGML_F16x32_MUL          _mm512_mul_ph\n#define GGML_F16x32_REDUCE(res, x)                                     \\\ndo {                                                                   \\\n    int offset = GGML_F16_ARR >> 1;                                    \\\n    for (int i = 0; i < offset; ++i) {                                 \\\n        x[i] = _mm512_add_ph(x[i], x[offset+i]);                       \\\n    }                                                                  \\\n    offset >>= 1;                                                      \\\n    for (int i = 0; i < offset; ++i) {                                 \\\n        x[i] = _mm512_add_ph(x[i], x[offset+i]);                       \\\n    }                                                                  \\\n    offset >>= 1;                                                      \\\n    for (int i = 0; i < offset; ++i) {                                 \\\n        x[i] = _mm512_add_ph(x[i], x[offset+i]);                       \\\n    }                                                                  \\\n    res = (ggml_float) _mm512_reduce_add_ph(x[0]);                     \\\n} while (0)\n\n#define GGML_F16_VEC                GGML_F16x32\n#define GGML_F16_VEC_ZERO           GGML_F16x32_ZERO\n#define GGML_F16_VEC_SET1           GGML_F16x32_SET1\n#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x32_LOAD(p)\n#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x32_STORE(p, r[i])\n#define GGML_F16_VEC_FMA            GGML_F16x32_FMA\n#define GGML_F16_VEC_ADD            GGML_F16x32_ADD\n#define GGML_F16_VEC_MUL            GGML_F16x32_MUL\n#define GGML_F16_VEC_REDUCE         GGML_F16x32_REDUCE\n\n#else // Fallback FP16 <-> FP32\n\n#define GGML_F16_STEP 64\n#define GGML_F16_EPR  16\n\n#define GGML_F32Cx16             __m512\n#define GGML_F32Cx16_ZERO        _mm512_setzero_ps()\n#define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)\n\n// unlike  _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in 
AVX512F\n// so F16C guard isn't required\n#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))\n#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))\n\n#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)\n#define GGML_F32Cx16_ADD         _mm512_add_ps\n#define GGML_F32Cx16_MUL         _mm512_mul_ps\n#define GGML_F32Cx16_REDUCE(res, x)                               \\\ndo {                                                              \\\n    int offset = GGML_F32_ARR >> 1;                               \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    offset >>= 1;                                                 \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    offset >>= 1;                                                 \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                \\\n} while (0)\n\n#define GGML_F16_VEC                GGML_F32Cx16\n#define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO\n#define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1\n#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p)\n#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])\n#define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA\n#define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD\n#define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL\n\n#define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE\n\n#endif // __AVX512FP16__\n#elif 
defined(__AVX__)\n\n#define GGML_SIMD\n\n// F32 AVX\n\n#define GGML_F32_STEP 32\n#define GGML_F32_EPR  8\n\n#define GGML_F32x8         __m256\n#define GGML_F32x8_ZERO    _mm256_setzero_ps()\n#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)\n#define GGML_F32x8_LOAD    _mm256_loadu_ps\n#define GGML_F32x8_STORE   _mm256_storeu_ps\n#if defined(__FMA__)\n    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)\n#else\n    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)\n#endif\n#define GGML_F32x8_ADD     _mm256_add_ps\n#define GGML_F32x8_MUL     _mm256_mul_ps\n#define GGML_F32x8_REDUCE(res, x)                                 \\\ndo {                                                              \\\n    int offset = GGML_F32_ARR >> 1;                               \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    offset >>= 1;                                                 \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    offset >>= 1;                                                 \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \\\n                                 _mm256_extractf128_ps(x[0], 1)); \\\n    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \\\n    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1));        \\\n} while (0)\n// TODO: is this optimal ?\n\n#define GGML_F32_VEC        GGML_F32x8\n#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO\n#define GGML_F32_VEC_SET1   
GGML_F32x8_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x8_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x8_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x8_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x8_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE\n\n// F16 AVX\n\n#define GGML_F16_STEP 32\n#define GGML_F16_EPR  8\n\n// F16 arithmetic is not supported by AVX, so we use F32 instead\n\n#define GGML_F32Cx8             __m256\n#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()\n#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)\n\n#if defined(__F16C__)\n// the  _mm256_cvt intrinsics require F16C\n#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))\n#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))\n#else\nstatic inline __m256 __avx_f32cx8_load(const ggml_fp16_t * x) {\n    float tmp[8];\n\n    for (int i = 0; i < 8; i++) {\n        tmp[i] = GGML_CPU_FP16_TO_FP32(x[i]);\n    }\n\n    return _mm256_loadu_ps(tmp);\n}\nstatic inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {\n    float arr[8];\n\n    _mm256_storeu_ps(arr, y);\n\n    for (int i = 0; i < 8; i++)\n        x[i] = GGML_CPU_FP32_TO_FP16(arr[i]);\n}\n#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)\n#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)\n#endif\n\n#define GGML_F32Cx8_FMA         GGML_F32x8_FMA\n#define GGML_F32Cx8_ADD         _mm256_add_ps\n#define GGML_F32Cx8_MUL         _mm256_mul_ps\n#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE\n\n#define GGML_F16_VEC                GGML_F32Cx8\n#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO\n#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1\n#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)\n#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])\n#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA\n#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD\n#define GGML_F16_VEC_MUL            
GGML_F32Cx8_MUL\n#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE\n\n#elif defined(__POWER9_VECTOR__)\n\n#define GGML_SIMD\n\n// F32 POWER9\n\n#define GGML_F32_STEP 32\n#define GGML_F32_EPR  4\n\n#define GGML_F32x4              vector float\n#define GGML_F32x4_ZERO         {0.0f}\n#define GGML_F32x4_SET1         vec_splats\n#define GGML_F32x4_LOAD(p)      vec_xl(0, p)\n#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)\n#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)\n#define GGML_F32x4_ADD          vec_add\n#define GGML_F32x4_MUL          vec_mul\n#define GGML_F32x4_REDUCE(res, x)              \\\n{                                              \\\n    int offset = GGML_F32_ARR >> 1;            \\\n    for (int i = 0; i < offset; ++i) {         \\\n        x[i] = vec_add(x[i], x[offset+i]);     \\\n    }                                          \\\n    offset >>= 1;                              \\\n    for (int i = 0; i < offset; ++i) {         \\\n        x[i] = vec_add(x[i], x[offset+i]);     \\\n    }                                          \\\n    offset >>= 1;                              \\\n    for (int i = 0; i < offset; ++i) {         \\\n        x[i] = vec_add(x[i], x[offset+i]);     \\\n    }                                          \\\n    res = vec_extract(x[0], 0) +               \\\n          vec_extract(x[0], 1) +               \\\n          vec_extract(x[0], 2) +               \\\n          vec_extract(x[0], 3);                \\\n}\n#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3)        \\\n{                                                       \\\n    vector float v = vec_add(vec_add(s0, s1),           \\\n                             vec_add(s2, s3));          \\\n    v = vec_add(v, vec_sld(v, v, 8));                   \\\n    v = vec_add(v, vec_sld(v, v, 4));                   \\\n    res += (ggml_float) vec_extract(v, 0);              \\\n}\n\n#define GGML_F32_VEC        GGML_F32x4\n#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO\n#define 
GGML_F32_VEC_SET1   GGML_F32x4_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x4_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x4_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x4_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x4_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE\n\n// F16 POWER9\n#define GGML_F16_STEP       GGML_F32_STEP\n#define GGML_F16_EPR        GGML_F32_EPR\n#define GGML_F16_VEC        GGML_F32x4\n#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO\n#define GGML_F16_VEC_SET1   GGML_F32x4_SET1\n#define GGML_F16_VEC_FMA    GGML_F32x4_FMA\n#define GGML_F16_VEC_ADD    GGML_F32x4_ADD\n#define GGML_F16_VEC_MUL    GGML_F32x4_MUL\n#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE\n// Use vec_xl, not vec_ld, in case the load address is not aligned.\n#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \\\n  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \\\n  vec_extract_fp32_from_shortl(vec_xl(0, p))\nstatic inline unsigned char ggml_endian_byte(int i) {\n       uint16_t tmp_val = 1;\n       return ((unsigned char *)&tmp_val)[i];\n}\n#define GGML_ENDIAN_BYTE(i) ggml_endian_byte(i)\n#define GGML_F16_VEC_STORE(p, r, i)                             \\\n  if (i & 0x1)                                                  \\\n    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \\\n                                   r[i - GGML_ENDIAN_BYTE(0)]), \\\n            0, p - GGML_F16_EPR)\n\n//BF16 POWER9\n#define GGML_BF16_STEP 16\n#define GGML_BF16_EPR  8\n\n#define GGML_BF16x8         vector unsigned short\n#define GGML_BF16x8_ZERO    vec_splats((unsigned short)0)\n#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))\n\n#define GGML_BF16_VEC          GGML_BF16x8\n#define GGML_BF16_VEC_ZERO     GGML_BF16x8_ZERO\n#define GGML_BF16_VEC_LOAD     GGML_BF16x8_LOAD\n#if defined(__LITTLE_ENDIAN__)\n#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel(GGML_BF16_VEC_ZERO, (v)))\n#define GGML_BF16_TO_F32_HI(v) 
((vector float) vec_mergeh(GGML_BF16_VEC_ZERO, (v)))\n#else\n#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel((v), GGML_BF16_VEC_ZERO))\n#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh((v), GGML_BF16_VEC_ZERO))\n#endif\n#define GGML_BF16_FMA_LO(acc, x, y) \\\n    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))\n#define GGML_BF16_FMA_HI(acc, x, y) \\\n    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))\n\n#elif defined(__wasm_simd128__)\n\n#define GGML_SIMD\n\n// F32 WASM\n\n#define GGML_F32_STEP 16\n#define GGML_F32_EPR  4\n\n#define GGML_F32x4              v128_t\n#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)\n#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)\n#define GGML_F32x4_LOAD         wasm_v128_load\n#define GGML_F32x4_STORE        wasm_v128_store\n#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)\n#define GGML_F32x4_ADD          wasm_f32x4_add\n#define GGML_F32x4_MUL          wasm_f32x4_mul\n#define GGML_F32x4_REDUCE(res, x)                  \\\n{                                                  \\\n    int offset = GGML_F32_ARR >> 1;                \\\n    for (int i = 0; i < offset; ++i) {             \\\n        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \\\n    }                                              \\\n    offset >>= 1;                                  \\\n    for (int i = 0; i < offset; ++i) {             \\\n        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \\\n    }                                              \\\n    offset >>= 1;                                  \\\n    for (int i = 0; i < offset; ++i) {             \\\n        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \\\n    }                                              \\\n    res = wasm_f32x4_extract_lane(x[0], 0) +       \\\n          wasm_f32x4_extract_lane(x[0], 1) +       \\\n          wasm_f32x4_extract_lane(x[0], 2) +       \\\n          
wasm_f32x4_extract_lane(x[0], 3);        \\\n}\n\n#define GGML_F32_VEC        GGML_F32x4\n#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32x4_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x4_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x4_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x4_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x4_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE\n\n// F16 WASM\n\n#define GGML_F16_STEP 16\n#define GGML_F16_EPR  4\n\ninline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {\n    float tmp[4];\n\n    tmp[0] = GGML_CPU_FP16_TO_FP32(p[0]);\n    tmp[1] = GGML_CPU_FP16_TO_FP32(p[1]);\n    tmp[2] = GGML_CPU_FP16_TO_FP32(p[2]);\n    tmp[3] = GGML_CPU_FP16_TO_FP32(p[3]);\n\n    return wasm_v128_load(tmp);\n}\n\ninline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {\n    float tmp[4];\n\n    wasm_v128_store(tmp, x);\n\n    p[0] = GGML_CPU_FP32_TO_FP16(tmp[0]);\n    p[1] = GGML_CPU_FP32_TO_FP16(tmp[1]);\n    p[2] = GGML_CPU_FP32_TO_FP16(tmp[2]);\n    p[3] = GGML_CPU_FP32_TO_FP16(tmp[3]);\n}\n\n#define GGML_F16x4             v128_t\n#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)\n#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)\n#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)\n#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)\n#define GGML_F16x4_FMA         GGML_F32x4_FMA\n#define GGML_F16x4_ADD         wasm_f32x4_add\n#define GGML_F16x4_MUL         wasm_f32x4_mul\n#define GGML_F16x4_REDUCE(res, x)                           \\\n{                                                           \\\n    int offset = GGML_F16_ARR >> 1;                         \\\n    for (int i = 0; i < offset; ++i) {                      \\\n        x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \\\n    }                                                       \\\n    offset >>= 1;                                           \\\n    for (int i = 0; i < offset; ++i) { 
                     \\\n        x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \\\n    }                                                       \\\n    offset >>= 1;                                           \\\n    for (int i = 0; i < offset; ++i) {                      \\\n        x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \\\n    }                                                       \\\n    res = (ggml_float) (wasm_f32x4_extract_lane(x[0], 0) +  \\\n          wasm_f32x4_extract_lane(x[0], 1) +                \\\n          wasm_f32x4_extract_lane(x[0], 2) +                \\\n          wasm_f32x4_extract_lane(x[0], 3));                \\\n}\n\n#define GGML_F16_VEC                GGML_F16x4\n#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO\n#define GGML_F16_VEC_SET1           GGML_F16x4_SET1\n#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)\n#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])\n#define GGML_F16_VEC_FMA            GGML_F16x4_FMA\n#define GGML_F16_VEC_ADD            GGML_F16x4_ADD\n#define GGML_F16_VEC_MUL            GGML_F16x4_MUL\n#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE\n\n#elif defined(__SSE3__)\n\n#define GGML_SIMD\n\n// F32 SSE\n\n#define GGML_F32_STEP 32\n#define GGML_F32_EPR  4\n\n#define GGML_F32x4         __m128\n#define GGML_F32x4_ZERO    _mm_setzero_ps()\n#define GGML_F32x4_SET1(x) _mm_set1_ps(x)\n#define GGML_F32x4_LOAD    _mm_loadu_ps\n#define GGML_F32x4_STORE   _mm_storeu_ps\n#if defined(__FMA__)\n    // TODO: Does this work?\n    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)\n#else\n    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)\n#endif\n#define GGML_F32x4_ADD     _mm_add_ps\n#define GGML_F32x4_MUL     _mm_mul_ps\n#define GGML_F32x4_REDUCE(res, x)                                 \\\n{                                                                 \\\n    int offset = GGML_F32_ARR >> 1;                               \\\n    for (int i = 0; i < offset; ++i) {  
                          \\\n        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \\\n    }                                                             \\\n    offset >>= 1;                                                 \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \\\n    }                                                             \\\n    offset >>= 1;                                                 \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \\\n    }                                                             \\\n    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \\\n    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0));        \\\n}\n// TODO: is this optimal ?\n\n#define GGML_F32_VEC        GGML_F32x4\n#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32x4_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x4_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x4_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x4_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x4_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE\n\n// F16 SSE\n\n#define GGML_F16_STEP 32\n#define GGML_F16_EPR  4\n\nstatic inline __m128 __sse_f16x4_load(const ggml_fp16_t * x) {\n    float tmp[4];\n\n    tmp[0] = GGML_CPU_FP16_TO_FP32(x[0]);\n    tmp[1] = GGML_CPU_FP16_TO_FP32(x[1]);\n    tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]);\n    tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]);\n\n    return _mm_loadu_ps(tmp);\n}\n\nstatic inline void __sse_f16x4_store(ggml_fp16_t * x, __m128 y) {\n    float arr[4];\n\n    _mm_storeu_ps(arr, y);\n\n    x[0] = GGML_CPU_FP32_TO_FP16(arr[0]);\n    x[1] = GGML_CPU_FP32_TO_FP16(arr[1]);\n    x[2] = GGML_CPU_FP32_TO_FP16(arr[2]);\n    x[3] = GGML_CPU_FP32_TO_FP16(arr[3]);\n}\n\n#define GGML_F32Cx4             
__m128\n#define GGML_F32Cx4_ZERO        _mm_setzero_ps()\n#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)\n#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)\n#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)\n#define GGML_F32Cx4_FMA         GGML_F32x4_FMA\n#define GGML_F32Cx4_ADD         _mm_add_ps\n#define GGML_F32Cx4_MUL         _mm_mul_ps\n#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE\n\n#define GGML_F16_VEC                 GGML_F32Cx4\n#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO\n#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1\n#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)\n#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])\n#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA\n#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD\n#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL\n#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE\n\n#elif defined(__loongarch_asx)\n\n#define GGML_SIMD\n\n// F32 LASX\n#define GGML_F32_STEP 32\n#define GGML_F32_EPR  8\n\n#define GGML_F32x8         __m256\n#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)\n#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))\n#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)\n#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)\n#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)\n#define GGML_F32x8_ADD     __lasx_xvfadd_s\n#define GGML_F32x8_MUL     __lasx_xvfmul_s\n#define GGML_F32x8_REDUCE(res, x)                                 \\\ndo {                                                              \\\n    int offset = GGML_F32_ARR >> 1;                               \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    offset >>= 1;                                                 \\\n    for (int i = 0; i < offset; ++i) {                            
\\\n        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    offset >>= 1;                                                 \\\n    for (int i = 0; i < offset; ++i) {                            \\\n        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \\\n    }                                                             \\\n    float *tmp_p = (float *)&x[0]; \\\n    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7];  \\\n} while (0)\n// TODO: is this optimal ?\n\n#define GGML_F32_VEC        GGML_F32x8\n#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32x8_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x8_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x8_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x8_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x8_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE\n\n// F16 LASX\n\n#define GGML_F16_STEP 32\n#define GGML_F16_EPR  8\n\n// F16 arithmetic is not supported by LASX, so we use F32 instead\n\n#define GGML_F32Cx8          __m256\n#define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)\n#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))\n\nstatic inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {\n    __m256i a;\n    memcpy(&a, x, sizeof(ggml_fp16_t) * 8);\n    a = __lasx_xvpermi_d(a, 0 | (1 << 4));\n    return __lasx_xvfcvtl_s_h(a);\n}\n\nstatic inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {\n    __m256i a = __lasx_xvfcvt_h_s(y, y);\n    a = __lasx_xvpermi_d(a, 0 | (2 << 2));\n    memcpy(x, &a, sizeof(ggml_fp16_t) * 8);\n}\n#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)\n#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)\n\n#define GGML_F32Cx8_FMA         GGML_F32x8_FMA\n#define GGML_F32Cx8_ADD         __lasx_xvfadd_s\n#define GGML_F32Cx8_MUL         __lasx_xvfmul_s\n#define 
GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE\n\n#define GGML_F16_VEC                GGML_F32Cx8\n#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO\n#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1\n#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)\n#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])\n#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA\n#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD\n#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL\n#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE\n\n#elif defined(__loongarch_sx)\n\n#define GGML_SIMD\n\n// F32 LSX\n\n#define GGML_F32_STEP 32\n#define GGML_F32_EPR  4\n\n#define GGML_F32x4         __m128\n#define GGML_F32x4_ZERO    (__m128)__lsx_vldi(0)\n#define GGML_F32x4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))\n#define GGML_F32x4_LOAD(x) (__m128)__lsx_vld((x), 0)\n#define GGML_F32x4_STORE(x, y)   __lsx_vst(y, x, 0)\n#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)\n#define GGML_F32x4_ADD     __lsx_vfadd_s\n#define GGML_F32x4_MUL     __lsx_vfmul_s\n\n#define GGML_F32x4_REDUCE(res, x)                               \\\n{                                                               \\\n    int offset = GGML_F32_ARR >> 1;                             \\\n    for (int i = 0; i < offset; ++i) {                          \\\n        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);                \\\n    }                                                           \\\n    offset >>= 1;                                               \\\n    for (int i = 0; i < offset; ++i) {                          \\\n        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);                \\\n    }                                                           \\\n    offset >>= 1;                                               \\\n    for (int i = 0; i < offset; ++i) {                          \\\n        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);                \\\n    }                                                       
    \\\n    __m128i t0 = __lsx_vpickev_w((__m128i)x[0], (__m128i)x[0]); \\\n    __m128i t1 = __lsx_vpickod_w((__m128i)x[0], (__m128i)x[0]); \\\n    __m128 t2 = __lsx_vfadd_s((__m128)t0, (__m128)t1);          \\\n    __m128i t3 = __lsx_vpickev_w((__m128i)t2, (__m128i)t2);     \\\n    __m128i t4 = __lsx_vpickod_w((__m128i)t2, (__m128i)t2);     \\\n    __m128 t5 = __lsx_vfadd_s((__m128)t3, (__m128)t4);          \\\n    res = (ggml_float) ((v4f32)t5)[0];                          \\\n}\n\n#define GGML_F32_VEC        GGML_F32x4\n#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32x4_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x4_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x4_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x4_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x4_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE\n\n// F16 LSX\n\n#define GGML_F16_STEP 32\n#define GGML_F16_EPR  4\n\nstatic inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {\n    float tmp[4];\n\n    tmp[0] = GGML_CPU_FP16_TO_FP32(x[0]);\n    tmp[1] = GGML_CPU_FP16_TO_FP32(x[1]);\n    tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]);\n    tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]);\n\n    return (__m128)__lsx_vld(tmp, 0);\n}\n\nstatic inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {\n    float arr[4];\n\n    __lsx_vst(y, arr, 0);\n\n    x[0] = GGML_CPU_FP32_TO_FP16(arr[0]);\n    x[1] = GGML_CPU_FP32_TO_FP16(arr[1]);\n    x[2] = GGML_CPU_FP32_TO_FP16(arr[2]);\n    x[3] = GGML_CPU_FP32_TO_FP16(arr[3]);\n}\n\n#define GGML_F32Cx4             __m128\n#define GGML_F32Cx4_ZERO        (__m128)__lsx_vldi(0)\n#define GGML_F32Cx4_SET1(x)     (__m128)__lsx_vreplfr2vr_s((x))\n#define GGML_F32Cx4_LOAD(x)     (__m128)__lsx_f16x4_load(x)\n#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)\n#define GGML_F32Cx4_FMA         GGML_F32x4_FMA\n#define GGML_F32Cx4_ADD         __lsx_vfadd_s\n#define GGML_F32Cx4_MUL         __lsx_vfmul_s\n#define 
GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE\n\n#define GGML_F16_VEC                 GGML_F32Cx4\n#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO\n#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1\n#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)\n#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])\n#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA\n#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD\n#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL\n#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE\n\n#elif defined(__VXE__) || defined(__VXE2__)\n\n#define GGML_SIMD\n\n// F32 s390x\n\n#define GGML_F32_STEP 32\n#define GGML_F32_EPR  4\n\n#define GGML_F32x4              float32x4_t\n#define GGML_F32x4_ZERO         vec_splats(0.0f)\n#define GGML_F32x4_SET1         vec_splats\n#define GGML_F32x4_LOAD(p)      vec_xl(0, p)\n#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)\n#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)\n#define GGML_F32x4_ADD          vec_add\n#define GGML_F32x4_MUL          vec_mul\n#define GGML_F32x4_REDUCE(res, x)                   \\\n{                                                   \\\n    int offset = GGML_F32_ARR >> 1;                 \\\n    for (int i = 0; i < offset; ++i) {              \\\n        x[i] = vec_add(x[i], x[offset + i]);        \\\n    }                                               \\\n    offset >>= 1;                                   \\\n    for (int i = 0; i < offset; ++i) {              \\\n        x[i] = vec_add(x[i], x[offset + i]);        \\\n    }                                               \\\n    offset >>= 1;                                   \\\n    for (int i = 0; i < offset; ++i) {              \\\n        x[i] = vec_add(x[i], x[offset + i]);        \\\n    }                                               \\\n    float32x4_t tmp = x[0] + vec_reve(x[0]);        \\\n    res = tmp[0] + tmp[1];                          \\\n}\n#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, 
s3) \\\n{                                                \\\n    float32x4_t v = vec_add(vec_add(s0, s1),     \\\n                            vec_add(s2, s3));    \\\n    v = vec_add(v, vec_sld(v, v, 8));            \\\n    v = vec_add(v, vec_sld(v, v, 4));            \\\n    res += (ggml_float)vec_extract(v, 0);        \\\n}\n\n#define GGML_F32_VEC        GGML_F32x4\n#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32x4_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x4_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x4_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x4_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x4_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE\n\n// F16 s390x\n#define GGML_F16_STEP GGML_F32_STEP\n#define GGML_F16_EPR  GGML_F32_EPR\n\nstatic inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) {\n    float tmp[4];\n\n    for (int i = 0; i < 4; i++) {\n        tmp[i] = GGML_CPU_FP16_TO_FP32(x[i]);\n    }\n\n    // note: keep type-cast here to prevent compiler bugs\n    // see: https://github.com/ggml-org/llama.cpp/issues/12846\n    return vec_xl(0, (const float *)(tmp));\n}\n\nstatic inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {\n    float arr[4];\n\n    // note: keep type-cast here to prevent compiler bugs\n    // see: https://github.com/ggml-org/llama.cpp/issues/12846\n    vec_xst(v_y, 0, (float *)(arr));\n\n    for (int i = 0; i < 4; i++) {\n        x[i] = GGML_CPU_FP32_TO_FP16(arr[i]);\n    }\n}\n\n#define GGML_F16_VEC                GGML_F32x4\n#define GGML_F16_VEC_ZERO           GGML_F32x4_ZERO\n#define GGML_F16_VEC_SET1           GGML_F32x4_SET1\n#define GGML_F16_VEC_LOAD(p, i)     __lzs_f16cx4_load(p)\n#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])\n#define GGML_F16_VEC_FMA            GGML_F32x4_FMA\n#define GGML_F16_VEC_ADD            GGML_F32x4_ADD\n#define GGML_F16_VEC_MUL            GGML_F32x4_MUL\n#define GGML_F16_VEC_REDUCE         
GGML_F32x4_REDUCE\n\n// BF16 s390x\n#define GGML_BF16_STEP 16\n#define GGML_BF16_EPR  8\n\n#define GGML_BF16x8         __vector unsigned short\n#define GGML_BF16x8_ZERO    vec_splats((unsigned short)0)\n#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))\n\n#define GGML_BF16_VEC      GGML_BF16x8\n#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO\n#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD\n#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))\n#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))\n#define GGML_BF16_FMA_LO(acc, x, y) \\\n    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))\n#define GGML_BF16_FMA_HI(acc, x, y) \\\n    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))\n\n#elif defined(__riscv_v_intrinsic)\n\n// compatible with vlen >= 128\n\n#define GGML_SIMD\n\n// F32\n\n#define GGML_F32_STEP 16\n#define GGML_F32_EPR  4\n\n#define GGML_F32x4              vfloat32m1_t\n#define GGML_F32x4_ZERO         __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR)\n#define GGML_F32x4_SET1(x)      __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR)\n#define GGML_F32x4_LOAD(x)      __riscv_vle32_v_f32m1(x, GGML_F32_EPR)\n#define GGML_F32x4_STORE(b, v)  __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR)\n#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR)\n#define GGML_F32x4_ADD(a, b)    __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR)\n#define GGML_F32x4_MUL(a, b)    __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR)\n\n#define GGML_F32_VEC        GGML_F32x4\n#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO\n#define GGML_F32_VEC_SET1   GGML_F32x4_SET1\n#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD\n#define GGML_F32_VEC_STORE  GGML_F32x4_STORE\n#define GGML_F32_VEC_FMA    GGML_F32x4_FMA\n#define GGML_F32_VEC_ADD    GGML_F32x4_ADD\n#define GGML_F32_VEC_MUL    GGML_F32x4_MUL\n#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE\n\n#endif\n\n// GGML_F32_ARR / GGML_F16_ARR\n//   
number of registers to use per step\n#ifdef GGML_SIMD\n#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)\n#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)\n#endif\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/spacemit/ime.cpp",
    "content": "#define GGML_COMMON_IMPL_CPP\n#define GGML_COMMON_DECL_CPP\n\n#include \"ime.h\"\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-common.h\"\n#include \"ggml-cpu.h\"\n#include \"ime_kernels.h\"\n#include \"traits.h\"\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstdio>  // for GGML_ASSERT\n#include <stdexcept>\n#include <thread>\n\n// clang-format off\n#if defined(__riscv)\n\n#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic)\n#error \"riscv v extension or v_intrinsic not enabled\"\n#else\n#include <riscv_vector.h>\n#endif\n\n#if !defined(__riscv_zfh)\n#error \"riscv zfh extension not enabled\"\n#endif\n\n#if defined(RISCV64_SPACEMIT_IME1)\n#else\n#error \"RISCV64_SPACEMIT_IME1 not defined\"\n#endif\n\n#else\n\n#error \"riscv not enabled in this build\"\n\n#endif\n\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Woverlength-strings\"\n#pragma GCC diagnostic ignored \"-Wcast-qual\"\n#pragma GCC diagnostic ignored \"-Wunused-parameter\"\n#endif\n\n#if defined(RISCV64_SPACEMIT_IME1)\n#define QGEMM_STRIDEN_THREAD_ALIGN 16\n#else\n#define QGEMM_STRIDEN_THREAD_ALIGN 32\n#endif\n\n// clang-format on\n\nstruct qnbitgemm_spacemit_ime_args {\n    const float *     a_ptr               = nullptr;\n    size_t            lda                 = 0;\n    const std::byte * packed_quant_b_data = nullptr;\n    const float *     quant_b_scale       = nullptr;\n    const void *      quant_b_zp          = nullptr;\n    const float *     quant_b_blksum      = nullptr;\n    const float *     bias                = nullptr;\n    float *           c_ptr               = nullptr;\n    size_t            ldc                 = 0;\n};\n\nconstexpr size_t div_round_up(size_t up, size_t down) {\n    return (up + down - 1) / down;\n}\n\nconstexpr size_t q8_blk_size(size_t blk_len) {\n    const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t);\n    // Currently, the strictest alignment requirement of a block is for a float.\n    
// Ensure contiguous blocks are suitably aligned.\n    assert(blk_size % alignof(float) == 0);\n    return blk_size;\n}\n\nnamespace ggml::cpu::riscv64_spacemit {\n\nconst int num_ai_cores = std::thread::hardware_concurrency() / 2;\n\n}  // namespace ggml::cpu::riscv64_spacemit\n\nstatic void sqnbitgemm_spacemit_ime_i8i4(const size_t                        blk_len,\n                                         const size_t                        gemm_k,\n                                         const qnbitgemm_spacemit_ime_args * gemm_args,\n                                         void * const                        per_gemm_ws,\n                                         const size_t                        m_start,\n                                         const size_t                        m_count,\n                                         const size_t                        n_start,\n                                         const size_t                        n_count) {\n    constexpr size_t scale_stride = sizeof(uint16_t);\n    constexpr size_t blk_bitwidth = 4;\n\n    const size_t k_blks = div_round_up(gemm_k, blk_len);\n\n    const size_t      lda         = k_blks * q8_blk_size(blk_len);\n    const size_t      ldc         = gemm_args->ldc;\n    const size_t      ldb         = k_blks * (blk_len * blk_bitwidth / 8);\n    const std::byte * quant_a_ptr = static_cast<const std::byte *>(per_gemm_ws) + m_start * lda;\n\n    const size_t      zero_point_stride   = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0;\n    const size_t      packed_b_stride     = ldb + k_blks * (scale_stride + zero_point_stride);\n    const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride;\n\n    float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start;\n\n    size_t       count_n               = 0;\n    const size_t compute_block_count_n = m_count == 1 ? 
n_count : 16;\n    for (size_t n = 0; n < n_count; n += count_n) {\n        count_n = std::min(n_count - n, compute_block_count_n);\n\n        const std::byte * a_row    = quant_a_ptr;\n        const std::byte * b_col    = packed_quant_b_data + n * packed_b_stride;\n        const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr;\n        float *           c_blk    = c_ptr + n;\n\n        int32_t rows_remaining = m_count;\n\n        while (rows_remaining > 0) {\n            const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(\n                blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr,\n                scale_stride);\n\n            c_blk += rows_handled * ldc;\n            a_row += rows_handled * lda;\n\n            rows_remaining -= rows_handled;\n        }\n    }\n}\n\ntemplate <int K> constexpr int QK_0() {\n    if constexpr (K == 4) {\n        return QK4_0;\n    }\n    if constexpr (K == 8) {\n        return QK8_0;\n    }\n    return -1;\n}\n\ntemplate <int K, int N> struct block {\n    ggml_half d[N];                         // deltas for N qK_0 blocks\n    uint8_t   qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_0 blocks\n};\n\ntemplate <int K, int N> struct block_with_zp {\n    ggml_half d[N];                         // deltas for N qK_1 blocks\n    uint8_t   zp[N];                        // zero points for N qK_1 blocks\n    uint8_t   qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_1 blocks\n};\n\n// control size\nstatic_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, \"wrong block<4,16> size/padding\");\nstatic_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t),\n              \"wrong block_with_zp<4,16> size/padding\");\nstatic_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, \"wrong block<8,16> size/padding\");\n\nusing block_q4_0x16 = block<4, 16>;\nusing block_q4_1x16 = 
block_with_zp<4, 16>;\nusing block_q8_0x16 = block<8, 16>;\n\nstatic block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {\n    block_q4_0x16 out;\n    GGML_ASSERT(QK4_0 / blck_size_interleave == 2);\n\n    for (int i = 0; i < 16; i++) {\n        out.d[i] = in[i].d;\n    }\n\n    for (int i = 0; i < 16; i++) {\n        // [0, 15], in.d & 0x0F\n        for (int j = 0; j < QK4_0 / 4; j++) {\n            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]\n            //dst [b0 b8] ......... [b7 b15]\n            out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4);\n        }\n    }\n\n    for (int i = 0; i < 16; i++) {\n        // [16, 31], in.d & 0xF0\n        for (int j = 0; j < QK4_0 / 4; j++) {\n            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]\n            //dst [b16 b24] ......... [b23 b31]\n            out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0);\n        }\n    }\n\n    return out;\n}\n\nstatic block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) {\n    block_q4_1x16 out;\n    GGML_ASSERT(QK4_1 / blck_size_interleave == 2);\n\n    for (int i = 0; i < 16; i++) {\n        float d   = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);\n        float m   = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m);\n        float mid = -std::nearbyintf(m / d);\n        mid       = std::min(15.0f, std::max(0.0f, mid));\n        out.d[i]  = GGML_FP32_TO_FP16(d);\n        out.zp[i] = static_cast<uint8_t>(mid);\n    }\n\n    for (int i = 0; i < 16; i++) {\n        // [0, 15], in.d & 0x0F\n        for (int j = 0; j < QK4_1 / 4; j++) {\n            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]\n            //dst [b0 b8] ......... 
[b7 b15]\n            out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4);\n        }\n    }\n\n    for (int i = 0; i < 16; i++) {\n        // [16, 31], in.d & 0xF0\n        for (int j = 0; j < QK4_1 / 4; j++) {\n            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]\n            //dst [b16 b24] ......... [b23 b31]\n            out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0);\n        }\n    }\n\n    return out;\n}\n\nstatic int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor *       t,\n                                     int                        interleave_block,\n                                     const void * GGML_RESTRICT data,\n                                     size_t                     data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);\n    GGML_ASSERT(interleave_block == 16);\n\n    constexpr int nrows_interleaved = 16;\n\n    block_q4_0x16 *    dst = (block_q4_0x16 *) t->data;\n    const block_q4_0 * src = (const block_q4_0 *) data;\n    block_q4_0         dst_tmp[16];\n    int                nrow    = ggml_nrows(t);\n    int                nblocks = t->ne[0] / QK4_0;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor *       t,\n                                     int                        interleave_block,\n                                     const void 
* GGML_RESTRICT data,\n                                     size_t                     data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q4_1);\n    GGML_ASSERT(interleave_block == 16);\n\n    constexpr int nrows_interleaved = 16;\n\n    block_q4_1x16 *    dst = (block_q4_1x16 *) t->data;\n    const block_q4_1 * src = (const block_q4_1 *) data;\n    block_q4_1         dst_tmp[16];\n    int                nrow    = ggml_nrows(t);\n    int                nblocks = t->ne[0] / QK4_1;\n\n    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1));\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int i = 0; i < nrows_interleaved; i++) {\n                dst_tmp[i] = src[x + i * nblocks];\n            }\n            *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nstatic inline void get_scale_min_k4(int                           j,\n                                    const uint8_t * GGML_RESTRICT q,\n                                    uint8_t * GGML_RESTRICT       d,\n                                    uint8_t * GGML_RESTRICT       m) {\n    if (j < 4) {\n        *d = q[j] & 63;\n        *m = q[j + 4] & 63;\n    } else {\n        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);\n        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);\n    }\n}\n\nstatic int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor *       t,\n                                     int                        interleave_block,\n                                     const void * GGML_RESTRICT data,\n                                     size_t                     data_size) {\n    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);\n    GGML_ASSERT(interleave_block == 16);\n    GGML_ASSERT(QK_K / QK4_1 == 8);\n\n    
constexpr int nrows_interleaved = 16;\n\n    block_q4_1x16 *    dst = (block_q4_1x16 *) t->data;\n    const block_q4_K * src = (const block_q4_K *) data;\n    block_q4_1         dst_tmp[16];\n    int                nrow    = ggml_nrows(t);\n    int                nblocks = t->ne[0] / QK_K;\n\n    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) {\n        return -1;\n    }\n\n    for (int b = 0; b < nrow; b += nrows_interleaved) {\n        for (int64_t x = 0; x < nblocks; x++) {\n            for (int j = 0; j < 8; j++) {\n                for (int i = 0; i < nrows_interleaved; i++) {\n                    uint8_t     sc, m;\n                    const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);\n                    const float min =\n                        GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin);\n                    get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m);\n                    const float d1 = d * sc;\n                    const float m1 = min * m;\n\n                    dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1);\n                    dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1);\n                    // src -> [b0, b32] [b1, b33] ... [b31, b63]\n                    // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... 
[b47, b63]\n                    const uint8_t * q                                  = src[x + i * nblocks].qs + (j / 2) * QK4_1;\n                    if (j % 2 == 0) {\n                        for (int ii = 0; ii < 16; ii++) {\n                            dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4);\n                        }\n                    } else {\n                        for (int ii = 0; ii < 16; ii++) {\n                            dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0);\n                        }\n                    }\n                }\n                *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);\n            }\n        }\n        src += nrows_interleaved * nblocks;\n    }\n    return 0;\n\n    GGML_UNUSED(data_size);\n}\n\nnamespace ggml::cpu::riscv64_spacemit {\n\ntemplate <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>\nint repack(struct ggml_tensor *, const void *, size_t);\n\ntemplate <> int repack<block_q4_0, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);\n}\n\ntemplate <> int repack<block_q4_1, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size);\n}\n\ntemplate <> int repack<block_q4_K, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {\n    return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size);\n}\n\nclass tensor_traits_base : public ggml::cpu::tensor_traits {\n  public:\n    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;\n};\n\ntemplate <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {\n    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {\n        switch (op->op) {\n            case GGML_OP_MUL_MAT:\n                size = ggml_row_size(GGML_TYPE_Q8_0, 
ggml_nelements(op->src[1])) * 4;\n                size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float));\n                return true;\n            default:\n                // GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return false;\n    }\n\n    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {\n        switch (op->op) {\n            case GGML_OP_MUL_MAT:\n                if (op->src[0]->type == GGML_TYPE_Q4_0 ||  //\n                    op->src[0]->type == GGML_TYPE_Q4_1 ||  //\n                    op->src[0]->type == GGML_TYPE_Q4_K) {\n                    forward_mul_mat_q4(params, op);\n                    return true;\n                }\n            default:\n                // GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return false;\n    }\n\n    void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) {\n        const ggml_tensor * src0 = op->src[0];\n        const ggml_tensor * src1 = op->src[1];\n        ggml_tensor *       dst  = op;\n\n        GGML_TENSOR_BINARY_OP_LOCALS\n\n        int ith = params->ith;\n        int nth = params->nth;\n\n        [[maybe_unused]] const enum ggml_type type = src0->type;\n\n        void *        w_data  = (void *) src0->data;\n        const float * feature = (const float *) src1->data;\n        float *       output  = (float *) dst->data;\n\n        const size_t                  batch_feature = ne12 * ne13;\n        [[maybe_unused]] const size_t batch_weight  = ne02 * ne03;\n        const size_t                  gemm_m        = ne11;\n        const size_t                  gemm_k        = ne10;\n        const size_t                  gemm_n        = ne01;\n\n        GGML_ASSERT(batch_weight == 1);\n\n        const size_t block_count_k           = div_round_up(gemm_k, QK4_0);\n        const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0);\n        const 
size_t per_gemm_workspace_stride =\n            div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t);\n        const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride;\n        const size_t desired_wsize       = gemm_workspace_size + alignof(uint64_t) - 1;\n\n        if (ith == 0 && params->wsize < desired_wsize) {\n            throw std::runtime_error(\"wsize less than desired_wsize\");\n        }\n\n        std::vector<qnbitgemm_spacemit_ime_args> qnbitgemm_args(batch_feature);\n\n        for (size_t i = 0; i < batch_feature; i++) {\n            qnbitgemm_args[i].a_ptr               = feature + gemm_m * gemm_k * i;\n            qnbitgemm_args[i].lda                 = gemm_k;\n            qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data;\n            qnbitgemm_args[i].quant_b_scale       = nullptr;\n\n            if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {\n                qnbitgemm_args[i].quant_b_zp = nullptr;\n            } else {\n                qnbitgemm_args[i].quant_b_zp = w_data;\n            }\n\n            qnbitgemm_args[i].bias  = nullptr;\n            qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i;\n            qnbitgemm_args[i].ldc   = gemm_n;\n        }\n\n        const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);\n        void *          ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1)));\n        const size_t    quant_a_stride = block_count_k * q8_blk_size(QK4_0);\n\n        {\n            constexpr size_t block_size_m           = 4;\n            size_t           per_gemm_block_count_m = div_round_up(gemm_m, block_size_m);\n            int32_t          task_count             = batch_feature * per_gemm_block_count_m;\n            int32_t          task_per_thread        = (task_count + nth - 1) / nth;\n            int32_t          start                  = ith * task_per_thread;\n            int32_t          
end                    = std::min((ith + 1) * task_per_thread, task_count);\n            for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {\n                int32_t                             gemm_idx = compute_idx / per_gemm_block_count_m;\n                int32_t                             block_idx_in_gemm = compute_idx % per_gemm_block_count_m;\n                int32_t                             m_idx    = block_idx_in_gemm * block_size_m;\n                const qnbitgemm_spacemit_ime_args & data     = qnbitgemm_args[gemm_idx];\n                int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);\n\n                if (rows_tobe_handled == block_size_m) {\n                    const float * a_row_ptr = data.a_ptr + m_idx * data.lda;\n                    std::byte *   quant_a_row_ptr =\n                        static_cast<std::byte *>(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;\n                    sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);\n                } else {\n                    while (rows_tobe_handled) {\n                        const float * a_row_ptr       = data.a_ptr + m_idx * data.lda;\n                        std::byte *   quant_a_row_ptr = static_cast<std::byte *>(ws) +\n                                                      gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;\n                        sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);\n                        rows_tobe_handled -= 1;\n                        m_idx += 1;\n                    }\n                }\n            }\n        }\n\n        ggml_barrier(params->threadpool);\n\n        if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) {\n            return;\n        }\n        nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores });\n\n        size_t           
threads_per_gemm = nth / batch_feature;\n        constexpr size_t gemm_m_stride    = 128;\n        size_t           nc               = gemm_n;\n        const size_t     gemm_m_blocked   = div_round_up(gemm_m, gemm_m_stride);\n        const size_t     max_nc           = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm);\n        if (max_nc < nc) {\n            nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN);\n        }\n        const size_t gemm_n_stride  = nc;\n        const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride);\n        const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride);\n        threads_per_gemm            = thread_count_m * thread_count_n;\n\n        {\n            int task_count      = batch_feature * threads_per_gemm;\n            int task_per_thread = (task_count + nth - 1) / nth;\n            int start           = ith * task_per_thread;\n            int end             = std::min((ith + 1) * task_per_thread, task_count);\n            for (int compute_idx = start; compute_idx < end; compute_idx++) {\n                const auto   gemm_i = compute_idx / threads_per_gemm;\n                const auto   blk_i  = compute_idx % threads_per_gemm;\n                const auto * data   = &qnbitgemm_args[gemm_i];\n\n                const auto tid_n = blk_i / thread_count_m;\n                const auto tid_m = blk_i % thread_count_m;\n\n                const size_t m_start = tid_m * gemm_m_stride;\n                const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride);\n\n                const size_t n_start = tid_n * gemm_n_stride;\n                const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride);\n\n                void * per_gemm_ws = reinterpret_cast<std::byte *>(ws) + gemm_i * per_gemm_workspace_stride;\n\n                sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count);\n            
}\n        }\n    }\n\n    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {\n        GGML_LOG_DEBUG(\"%s: repack tensor %s with %s_%dx%d\\n\", __func__, t->name, ggml_type_name(t->type),\n                       (int) NB_COLS, (int) INTER_SIZE);\n        return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);\n    }\n};\n\nclass tensor_traits_common : public tensor_traits_base {\n    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {\n        switch (op->op) {\n            case GGML_OP_NORM:\n            case GGML_OP_RMS_NORM:\n                size = 0;\n                return true;\n            default:\n                // GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return false;\n    }\n\n    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {\n        switch (op->op) {\n            case GGML_OP_NORM:\n                forward_norm_f32(params, op);\n                return true;\n            case GGML_OP_RMS_NORM:\n                forward_rms_norm_f32(params, op);\n                return true;\n            default:\n                // GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return false;\n    }\n\n    void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) {\n        const ggml_tensor * src0 = op->src[0];\n        ggml_tensor *       dst  = op;\n        GGML_ASSERT(ggml_are_same_shape(src0, dst));\n        GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n        const int ith = params->ith;\n        const int nth = params->nth;\n\n        GGML_TENSOR_UNARY_OP_LOCALS\n\n        float epsilon;\n        memcpy(&epsilon, dst->op_params, sizeof(float));\n\n        GGML_ASSERT(epsilon > 0.0f);\n\n        auto * input  = (float *) src0->data;\n        auto * output = (float *) dst->data;\n\n        const auto hidden_size     = ne00;\n        
const auto task_count      = ne01 * ne02 * ne03;\n        const auto task_per_thread = (task_count + nth - 1) / nth;\n\n        const auto task_begin = ith * task_per_thread;\n        const auto task_end   = std::min((ith + 1) * task_per_thread, task_count);\n\n        for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {\n            auto   offset  = task_idx * hidden_size;\n            auto * p_input = const_cast<float *>(input + offset);\n\n            auto *       p_output      = output + offset;\n            auto *       p_temp_output = p_output;\n            auto *       p_gamma_data  = (const float *) nullptr;\n            auto *       p_beta_data   = (const float *) nullptr;\n            size_t       gvl           = __riscv_vsetvlmax_e32m4();\n            vfloat32m4_t sum           = __riscv_vfmv_v_f_f32m4(0.f, gvl);\n            vfloat32m4_t sum_sq        = __riscv_vfmv_v_f_f32m4(0.f, gvl);\n            int64_t      length        = hidden_size;\n            while (length > 0) {\n                gvl                   = __riscv_vsetvl_e32m4(length);\n                // load data\n                vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);\n\n                sum    = __riscv_vfadd_vv_f32m4(sum, src_data, gvl);\n                sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);\n\n                __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);\n\n                p_input += gvl;\n                p_temp_output += gvl;\n                length -= gvl;\n            }\n\n            gvl = __riscv_vsetvlmax_e32m1();\n\n            float        mean   = 0.f;\n            vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);\n            vfloat32m1_t mean_v =\n                __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl);\n            mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl);\n            mean_v = 
__riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl);\n            mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl);\n            mean   = __riscv_vfmv_f_s_f32m1_f32(mean_v);\n            mean /= hidden_size;\n\n            vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),\n                                                                __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);\n            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);\n            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);\n            mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);\n\n            float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);\n            mean_square /= hidden_size;\n            mean_square = sqrt(mean_square - mean * mean + epsilon);\n\n            mean_square   = 1.0f / mean_square;\n            length        = hidden_size;\n            p_temp_output = p_output;\n\n            if (p_gamma_data == nullptr && p_beta_data == nullptr) {\n                while (length > 0) {\n                    gvl                   = __riscv_vsetvl_e32m4(length);\n                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);\n                    src_data              = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);\n                    src_data              = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);\n                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);\n                    p_temp_output += gvl;\n                    p_output += gvl;\n                    length -= gvl;\n                }\n            } else if (p_beta_data == nullptr) {\n                while (length > 0) {\n                    gvl                       = __riscv_vsetvl_e32m4(length);\n                    vfloat32m4_t src_data     = 
__riscv_vle32_v_f32m4(p_temp_output, gvl);\n                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);\n                    src_data                  = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);\n                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);\n                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);\n                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);\n                    p_temp_output += gvl;\n                    p_output += gvl;\n                    p_gamma_data += gvl;\n                    length -= gvl;\n                }\n            } else if (p_gamma_data != nullptr) {\n                while (length > 0) {\n                    gvl                       = __riscv_vsetvl_e32m4(length);\n                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);\n                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);\n                    src_data                  = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);\n                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);\n                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);\n                    vfloat32m4_t beta_data_v  = __riscv_vle32_v_f32m4(p_beta_data, gvl);\n                    src_data                  = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);\n                    p_beta_data += gvl;\n                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);\n                    p_temp_output += gvl;\n                    p_output += gvl;\n                    p_gamma_data += gvl;\n                    length -= gvl;\n                }\n            }\n        }\n    }\n\n    void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) {\n        const ggml_tensor * src0 = op->src[0];\n        ggml_tensor *       
dst  = op;\n        GGML_ASSERT(ggml_are_same_shape(src0, dst));\n        GGML_ASSERT(src0->nb[0] == sizeof(float));\n\n        const int ith = params->ith;\n        const int nth = params->nth;\n\n        GGML_TENSOR_UNARY_OP_LOCALS\n\n        float epsilon;\n        memcpy(&epsilon, dst->op_params, sizeof(float));\n\n        GGML_ASSERT(epsilon > 0.0f);\n\n        auto * input  = (float *) src0->data;\n        auto * output = (float *) dst->data;\n\n        const auto hidden_size     = ne00;\n        const auto task_count      = ne01 * ne02 * ne03;\n        const auto task_per_thread = (task_count + nth - 1) / nth;\n\n        const auto task_begin = ith * task_per_thread;\n        const auto task_end   = std::min((ith + 1) * task_per_thread, task_count);\n\n        for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {\n            auto   offset        = task_idx * hidden_size;\n            auto * p_input       = const_cast<float *>(input + offset);\n            auto * p_output      = output + offset;\n            auto * p_temp_output = p_output;\n            auto * p_gamma_data  = (const float *) nullptr;\n            auto * p_beta_data   = (const float *) nullptr;\n\n            size_t       gvl    = __riscv_vsetvlmax_e32m4();\n            // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);\n            vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);\n            int64_t      length = hidden_size;\n            while (length > 0) {\n                gvl                   = __riscv_vsetvl_e32m4(length);\n                // load data\n                vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);\n\n                sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);\n\n                __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);\n\n                p_input += gvl;\n                p_temp_output += gvl;\n                length -= gvl;\n            }\n\n            gvl = __riscv_vsetvlmax_e32m1();\n\n       
     // float mean = 0.f;\n            vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);\n\n            vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),\n                                                                __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);\n            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);\n            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);\n            mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);\n\n            float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);\n            mean_square /= hidden_size;\n\n            mean_square = sqrt(mean_square + epsilon);\n\n            mean_square   = 1.0f / mean_square;\n            length        = hidden_size;\n            p_temp_output = p_output;\n\n            if (p_gamma_data == nullptr && p_beta_data == nullptr) {\n                while (length > 0) {\n                    gvl                   = __riscv_vsetvl_e32m4(length);\n                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);\n                    src_data              = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);\n                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);\n                    p_temp_output += gvl;\n                    p_output += gvl;\n                    length -= gvl;\n                }\n            } else if (p_beta_data == nullptr) {\n                while (length > 0) {\n                    gvl                       = __riscv_vsetvl_e32m4(length);\n                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);\n                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);\n                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);\n                    src_data          
        = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);\n                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);\n                    p_temp_output += gvl;\n                    p_output += gvl;\n                    p_gamma_data += gvl;\n                    length -= gvl;\n                }\n            } else if (p_gamma_data != nullptr) {\n                while (length > 0) {\n                    gvl                       = __riscv_vsetvl_e32m4(length);\n                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);\n                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);\n                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);\n                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);\n                    vfloat32m4_t beta_data_v  = __riscv_vle32_v_f32m4(p_beta_data, gvl);\n                    src_data                  = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);\n                    p_beta_data += gvl;\n                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);\n                    p_temp_output += gvl;\n                    p_output += gvl;\n                    p_gamma_data += gvl;\n                    length -= gvl;\n                }\n            }\n        }\n    }\n\n    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {\n        memcpy(t->data, data, data_size);\n        return 0;\n    }\n};\n\nstatic const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0;\nstatic const tensor_traits<block_q4_1, 8, 16> q4_1_16x8_q8_0;\nstatic const tensor_traits<block_q4_K, 8, 16> q4_k_16x8_q8_0;\nstatic const tensor_traits_common             rvv_impl;\n\n}  // namespace ggml::cpu::riscv64_spacemit\n\nstatic const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) {\n    if (cur->type == GGML_TYPE_Q4_0) {\n  
      if (cur->ne[1] % 16 == 0) {\n            return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;\n        }\n    } else if (cur->type == GGML_TYPE_Q4_1) {\n        if (cur->ne[1] % 16 == 0) {\n            return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0;\n        }\n    } else if (cur->type == GGML_TYPE_Q4_K) {\n        if (cur->ne[1] % 16 == 0) {\n            return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0;\n        }\n    } else if (cur->type == GGML_TYPE_F32) {\n        return &ggml::cpu::riscv64_spacemit::rvv_impl;\n    }\n\n    return nullptr;\n}\n\nstatic enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer,\n                                                                         struct ggml_tensor *  tensor) {\n    tensor->extra =\n        (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor));\n\n    GGML_UNUSED(buffer);\n\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer,\n                                                            struct ggml_tensor *  tensor,\n                                                            const void *          data,\n                                                            size_t                offset,\n                                                            size_t                size) {\n    GGML_ASSERT(offset == 0);\n    GGML_ASSERT(size == ggml_nbytes(tensor));\n\n    auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base *) tensor->extra;\n    if (tensor_traits) {\n        auto OK = tensor_traits->repack(tensor, data, size);\n        GGML_ASSERT(OK == 0);\n    }\n\n    GGML_UNUSED(buffer);\n}\n\nstatic const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    return \"CPU_RISCV64_SPACEMIT\";\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_t 
ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,\n                                                                                        size_t size) {\n    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);\n\n    if (buffer == nullptr) {\n        return nullptr;\n    }\n\n    buffer->buft              = buft;\n    buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor;\n    buffer->iface.set_tensor  = ggml_backend_riscv64_spacemit_buffer_set_tensor;\n    buffer->iface.get_tensor  = nullptr;\n    buffer->iface.cpy_tensor  = nullptr;\n    return buffer;\n}\n\nstatic size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 64;\n\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft,\n                                                       const struct ggml_tensor * tensor) {\n    for (int i = 0; i < GGML_MAX_DIMS; ++i) {\n        if (tensor->ne[i] <= 0) {\n            return 0;\n        }\n    }\n\n    size_t       nbytes;\n    const size_t blck_size = ggml_blck_size(tensor->type);\n    if (blck_size == 1) {\n        nbytes = ggml_type_size(tensor->type);\n        for (int i = 0; i < GGML_MAX_DIMS; ++i) {\n            nbytes += (tensor->ne[i] - 1) * tensor->nb[i];\n        }\n    } else {\n        nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;\n        if (tensor->type == GGML_TYPE_Q4_K) {\n            GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0);\n            nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;\n            for (int i = 1; i < GGML_MAX_DIMS; ++i) {\n                nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;\n            }\n        } else {\n            for (int i = 1; i < GGML_MAX_DIMS; ++i) {\n                nbytes += (tensor->ne[i] - 1) * 
tensor->nb[i];\n            }\n        }\n    }\n\n    GGML_UNUSED(buft);\n    return nbytes;\n}\n\nnamespace ggml::cpu::riscv64_spacemit {\n\nclass extra_buffer_type : ggml::cpu::extra_buffer_type {\n    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {\n        switch (op->op) {\n            case GGML_OP_MUL_MAT:\n                if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) &&\n                    op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() &&\n                    ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) {\n                    if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {\n                        return false;\n                    }\n                    if (op->src[1]->type == GGML_TYPE_F32) {\n                        return true;\n                    }\n                }\n                break;\n            case GGML_OP_NORM:\n            case GGML_OP_RMS_NORM:\n                if (op->src[0]->type == GGML_TYPE_F32) {\n                    return true;\n                }\n                break;\n            default:\n                // GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return false;\n    }\n\n    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {\n        switch (op->op) {\n            case GGML_OP_MUL_MAT:\n                if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) {\n                    return (ggml::cpu::tensor_traits *) op->src[0]->extra;\n                }\n                break;\n            case GGML_OP_NORM:\n            case GGML_OP_RMS_NORM:\n                return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl);\n            default:\n                // GGML_ABORT(\"fatal error\");\n                break;\n        }\n\n        return nullptr;\n    }\n};\n\n}  // namespace 
ggml::cpu::riscv64_spacemit\n\nggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {\n    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {\n  /* .iface    = */\n        {\n         /* .get_name         = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name,\n         /* .alloc_buffer     = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer,\n         /* .get_alignment    = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment,\n         /* .get_max_size     = */ nullptr,\n         /* .get_alloc_size   = */ ggml_backend_cpu_riscv64_spacemit_nbytes,\n         /* .is_host          = */ nullptr,\n         },\n /* .device  = */\n        ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),\n /* .context = */\n        new ggml::cpu::riscv64_spacemit::extra_buffer_type(),\n    };\n\n    return &ggml_backend_cpu_buffer_type_riscv64_spacemit;\n}\n"
  },
  {
    "path": "src/ggml-cpu/spacemit/ime.h",
    "content": "#pragma once\n\n#include \"ggml-alloc.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/spacemit/ime1_kernels.cpp",
    "content": "#include \"ggml.h\"\n#include \"ime_kernels.h\"\n\n#include <algorithm>\n#include <cmath>\n\n// clang-format off\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Woverlength-strings\"\n#pragma GCC diagnostic ignored \"-Wcast-qual\"\n#pragma GCC diagnostic ignored \"-Wunused-parameter\"\n#endif\n// clang-format on\nnamespace sqnbitgemm_spacemit_ime {\n\n#define QUANTIZEM4ROW_KERNEL                           \\\n    \"vmv.s.x            v16, zero                \\n\\t\" \\\n    \"vfabs.v            v8, v0                   \\n\\t\" \\\n    \"vfredmax.vs        v16, v8, v16             \\n\\t\" \\\n    \"vfmv.f.s           f10, v16                 \\n\\t\" \\\n    \"fmul.s             f10, f10, %[RMAXREC]     \\n\\t\" \\\n    \"fsw                f10, (a1)                \\n\\t\" \\\n    \"fdiv.s             f11, %[FONE], f10        \\n\\t\" \\\n    \"vfmul.vf           v16, v0, f11             \\n\\t\" \\\n    \"vfcvt.x.f.v        v16, v16                 \\n\\t\" \\\n    \"vsetvli            t0, zero, e16, mf2       \\n\\t\" \\\n    \"vnclip.wx          v16, v16, zero           \\n\\t\" \\\n    \"vnclip.wx          v17, v17, zero           \\n\\t\" \\\n    \"vnclip.wx          v18, v18, zero           \\n\\t\" \\\n    \"vnclip.wx          v19, v19, zero           \\n\\t\" \\\n    \"vnclip.wx          v20, v20, zero           \\n\\t\" \\\n    \"vnclip.wx          v21, v21, zero           \\n\\t\" \\\n    \"vnclip.wx          v22, v22, zero           \\n\\t\" \\\n    \"vnclip.wx          v23, v23, zero           \\n\\t\" \\\n    \"vsetvli            t0, zero, e8, mf4        \\n\\t\" \\\n    \"vnclip.wx          v24, v16, zero           \\n\\t\" \\\n    \"vnclip.wx          v25, v17, zero           \\n\\t\" \\\n    \"vnclip.wx          v26, v18, zero           \\n\\t\" \\\n    \"vnclip.wx          v27, v19, zero           \\n\\t\" \\\n    \"vnclip.wx          v28, v20, zero           \\n\\t\" \\\n    \"vnclip.wx          v29, v21, zero          
 \\n\\t\" \\\n    \"vnclip.wx          v30, v22, zero           \\n\\t\" \\\n    \"vnclip.wx          v31, v23, zero           \\n\\t\"\n\n#define QUANTIZEM4ROW_STORE                            \\\n    \"addi               t1, %[BlkLen], 0         \\n\\t\" \\\n    \"vsetvli            t0, t1, e8, mf4          \\n\\t\" \\\n    \"vse8.v             v24, (s1)                \\n\\t\" \\\n    \"addi               s1, s1, 32               \\n\\t\" \\\n    \"sub                t1, t1, t0               \\n\\t\" \\\n    \"vsetvli            t0, t1, e8, mf4          \\n\\t\" \\\n    \"vse8.v             v25, (s1)                \\n\\t\" \\\n    \"addi               s1, s1, 32               \\n\\t\" \\\n    \"sub                t1, t1, t0               \\n\\t\" \\\n    \"vsetvli            t0, t1, e8, mf4          \\n\\t\" \\\n    \"vse8.v             v26, (s1)                \\n\\t\" \\\n    \"addi               s1, s1, 32               \\n\\t\" \\\n    \"sub                t1, t1, t0               \\n\\t\" \\\n    \"vsetvli            t0, t1, e8, mf4          \\n\\t\" \\\n    \"vse8.v             v27, (s1)                \\n\\t\" \\\n    \"addi               s1, s1, 32               \\n\\t\" \\\n    \"sub                t1, t1, t0               \\n\\t\" \\\n    \"vsetvli            t0, t1, e8, mf4          \\n\\t\" \\\n    \"vse8.v             v28, (s1)                \\n\\t\" \\\n    \"addi               s1, s1, 32               \\n\\t\" \\\n    \"sub                t1, t1, t0               \\n\\t\" \\\n    \"vsetvli            t0, t1, e8, mf4          \\n\\t\" \\\n    \"vse8.v             v29, (s1)                \\n\\t\" \\\n    \"addi               s1, s1, 32               \\n\\t\" \\\n    \"sub                t1, t1, t0               \\n\\t\" \\\n    \"vsetvli            t0, t1, e8, mf4          \\n\\t\" \\\n    \"vse8.v             v30, (s1)                \\n\\t\" \\\n    \"addi               s1, s1, 32               \\n\\t\" \\\n    \"sub                t1, t1, t0   
            \\n\\t\" \\\n    \"vsetvli            t0, t1, e8, mf4          \\n\\t\" \\\n    \"vse8.v             v31, (s1)                \\n\\t\"\n\nnamespace ime1 {\nvoid quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {\n    constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);\n    const float     fone                 = 1.0f;\n\n    if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {\n        for (size_t row_index = 0; row_index < 4; ++row_index) {\n            const float * SRC = A + row_index * CountK;\n            std::byte *   DST = QuantA + row_index * sizeof(float);\n\n            const size_t offset = (4 - row_index) * 4 + row_index * 8;\n            const size_t stride = 4 * (sizeof(float) + BlkLen);\n            __asm__ volatile(\n                \"vsetvli            t0, zero, e32, m8        \\n\\t\"\n                \"addi               t2, %[CountK], 0         \\n\\t\"\n                \"addi               a1, %[DST], 0            \\n\\t\"\n                \"blt                t2, %[BlkLen], TAIL%=    \\n\\t\"\n\n                \"LOOP%=:                                     \\n\\t\"\n                \"vsetvli            t0, %[BlkLen], e32, m8   \\n\\t\"\n                \"vle32.v            v0, (%[SRC])             \\n\\t\"\n                \"sub                t2, t2, t0               \\n\\t\"\n                \"slli               t1, t0, 2                \\n\\t\"\n                \"add                %[SRC], %[SRC], t1       \\n\\t\"\n                \"add                s1, a1, %[OFFSET]        \\n\\t\"\n\n                QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE\n\n                \"add                a1, a1, %[STRIDE]        \\n\\t\"\n                \"bge                t2, %[BlkLen], LOOP%=    \\n\\t\"\n\n                \"TAIL%=:                                     \\n\\t\"\n                \"blez               t2, QUIT%=               \\n\\t\"\n                \"vsetvli            t0, zero, e32, 
m8        \\n\\t\"\n                \"vxor.vv            v16, v16, v16            \\n\\t\"\n                \"vxor.vv            v24, v24, v24            \\n\\t\"\n                \"vsetvli            t0, t2, e32, m8          \\n\\t\"\n                \"vle32.v            v0, (%[SRC])             \\n\\t\"\n                \"add                s1, a1, %[OFFSET]        \\n\\t\"\n\n                QUANTIZEM4ROW_KERNEL\n\n                \"addi               t3, %[BlkLen], 0         \\n\\t\"\n                \"addi               s2, s1, 0                \\n\\t\"\n                \"vsetvli            t0, zero, e8, mf4        \\n\\t\"\n                \"vxor.vv            v8, v8, v8               \\n\\t\"\n                \"SET_ZERO%=:                                 \\n\\t\"\n                \"vse8.v             v8, (s2)                 \\n\\t\"\n                \"addi               s2, s2, 32               \\n\\t\"\n                \"addi               t3, t3, -8               \\n\\t\"\n                \"bnez               t3, SET_ZERO%=           \\n\\t\"\n\n                QUANTIZEM4ROW_STORE\n\n                \"QUIT%=:                                     \\n\\t\"\n                : [SRC] \"+r\"(SRC)\n                : [DST] \"r\"(DST), [BlkLen] \"r\"(BlkLen), [OFFSET] \"r\"(offset), [STRIDE] \"r\"(stride),\n                  [CountK] \"r\"(CountK), [FONE] \"f\"(fone), [RMAXREC] \"f\"(range_max_reciprocal)\n                : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"s1\", \"s2\", \"f10\", \"f11\");\n        }\n    } else if (BlkLen == 128) {\n        for (size_t row_index = 0; row_index < 4; ++row_index) {\n            const float * SRC = A + row_index * CountK;\n            std::byte *   DST = QuantA + row_index * sizeof(float);\n\n            const size_t offset = (4 - row_index) * 4 + row_index * 8;\n            const size_t stride = 4 * (sizeof(float) + BlkLen);\n            __asm__ volatile(\n                \"vsetvli            t0, zero, e32, m8        
\\n\\t\"\n                \"li                 t6, 32                   \\n\\t\"\n                \"addi               t2, %[CountK], 0         \\n\\t\"\n                \"addi               a1, %[DST], 0            \\n\\t\"\n                \"add                s1, a1, %[OFFSET]        \\n\\t\"\n                \"blt                t2, %[BlkLen], TAIL%=    \\n\\t\"\n\n                \"LOOP%=:                                     \\n\\t\"\n                \"vsetvli            t0, zero, e32, m8        \\n\\t\"\n                \"vle32.v            v0, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vle32.v            v8, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"addi               t2, t2, -128             \\n\\t\"\n\n                \"QUANTIZE%=:                                 \\n\\t\"\n                \"add                s1, a1, %[OFFSET]        \\n\\t\"\n                \"vfabs.v            v16, v0                  \\n\\t\"\n                \"vfabs.v            v24, v8                  \\n\\t\"\n                \"vfmax.vv           v16, v24, v16            \\n\\t\"\n                \"vfredmax.vs        v24, v16, v24            \\n\\t\"\n                \"vfmv.f.s           f10, v24                 \\n\\t\"\n                \"fmul.s             f10, f10, %[RMAXREC]     \\n\\t\"\n                \"fsw                f10, (a1)                \\n\\t\"\n                \"fdiv.s             f11, %[FONE], f10        \\n\\t\"\n                \"vfmul.vf           v16, v0, f11             \\n\\t\"\n                \"vfmul.vf           v24, v8, f11             \\n\\t\"\n                \"vfcvt.x.f.v        v16, v16                 \\n\\t\"\n                \"vfcvt.x.f.v        v24, v24                 \\n\\t\"\n                \"vsetvli            t0, zero, e16, m4        \\n\\t\"\n                \"vnclip.wx          v16, 
v16, zero           \\n\\t\"\n                \"vnclip.wx          v20, v24, zero           \\n\\t\"\n                \"vsetvli            t0, zero, e8, m4         \\n\\t\"\n                \"vnclip.wx          v16, v16, zero           \\n\\t\"\n                \"vsetvli            t0, zero, e64, m4        \\n\\t\"\n                \"vsse64.v           v16, (s1), t6            \\n\\t\"\n                \"add                a1, a1, %[STRIDE]        \\n\\t\"\n                \"bge                t2, %[BlkLen], LOOP%=    \\n\\t\"\n\n                \"TAIL%=:                                     \\n\\t\"\n                \"blez               t2, QUIT%=               \\n\\t\"\n                \"vsetvli            t0, zero, e32, m8        \\n\\t\"\n                \"vxor.vv             v0, v0, v0              \\n\\t\"\n                \"vxor.vv             v8, v8, v8              \\n\\t\"\n                \"vxor.vv             v16, v16, v16           \\n\\t\"\n                \"vxor.vv             v24, v24, v24           \\n\\t\"\n                \"vsetvli            t0, t2, e32, m8          \\n\\t\"\n                \"sub                t2, t2, t0               \\n\\t\"\n                \"vle32.v            v0, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vsetvli            t0, t2, e32, m8          \\n\\t\"\n                \"vle32.v            v8, (%[SRC])             \\n\\t\"\n                \"sub                t2, t2, t2               \\n\\t\"\n                \"vsetvli            t0, zero, e32, m8        \\n\\t\"\n                \"jal                x0, QUANTIZE%=           \\n\\t\"\n\n                \"QUIT%=:                                     \\n\\t\"\n                : [SRC] \"+r\"(SRC)\n                : [DST] \"r\"(DST), [BlkLen] \"r\"(BlkLen), [OFFSET] \"r\"(offset), [STRIDE] \"r\"(stride),\n                  [CountK] \"r\"(CountK), [FONE] \"f\"(fone), [RMAXREC] 
\"f\"(range_max_reciprocal)\n                : \"cc\", \"t0\", \"t1\", \"t2\", \"t6\", \"a1\", \"s1\", \"s2\", \"f10\", \"f11\");\n        }\n    } else if (BlkLen == 256) {\n        for (size_t row_index = 0; row_index < 4; ++row_index) {\n            const float * SRC    = A + row_index * CountK;\n            std::byte *   DST    = QuantA + row_index * sizeof(float);\n            const size_t  offset = (4 - row_index) * 4 + row_index * 8;\n            const size_t  stride = 4 * (sizeof(float) + BlkLen);\n            __asm__ volatile(\n                \"vsetvli            t0, zero, e32, m8        \\n\\t\"\n                \"li                 t6, 32                   \\n\\t\"\n                \"addi               t2, %[CountK], 0         \\n\\t\"\n                \"addi               a1, %[DST], 0            \\n\\t\"\n                \"add                s1, a1, %[OFFSET]        \\n\\t\"\n                \"blt                t2, %[BlkLen], TAIL%=    \\n\\t\"\n\n                \"LOOP%=:                                     \\n\\t\"\n                \"vsetvli            t0, zero, e32, m8        \\n\\t\"\n                \"vle32.v            v0, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vle32.v            v8, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vle32.v            v16, (%[SRC])            \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vle32.v            v24, (%[SRC])            \\n\\t\"\n                \"addi               %[SRC], %[SRC], -768     \\n\\t\"\n                \"addi               t2, t2, -256             \\n\\t\"\n                \"vfabs.v            v0, v0                   \\n\\t\"\n                \"vfabs.v            v8, v8                   \\n\\t\"\n                \"vfabs.v            v16, v16                 \\n\\t\"\n               
 \"vfabs.v            v24, v24                 \\n\\t\"\n                \"vfmax.vv           v8, v0, v8               \\n\\t\"\n                \"vfmax.vv           v24, v24, v16            \\n\\t\"\n                \"vfmax.vv           v8, v8, v24              \\n\\t\"\n                \"vfredmax.vs        v24, v8, v24             \\n\\t\"\n                \"vfmv.f.s           f10, v24                 \\n\\t\"\n                \"vle32.v            v0, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vle32.v            v8, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vle32.v            v16, (%[SRC])            \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vle32.v            v24, (%[SRC])            \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n\n                \"QUANTIZE%=:                                 \\n\\t\"\n                \"add                s1, a1, %[OFFSET]        \\n\\t\"\n                \"fmul.s             f10, f10, %[RMAXREC]     \\n\\t\"\n                \"fsw                f10, (a1)                \\n\\t\"\n                \"fdiv.s             f11, %[FONE], f10        \\n\\t\"\n                \"vfmul.vf           v0, v0, f11              \\n\\t\"\n                \"vfmul.vf           v8, v8, f11              \\n\\t\"\n                \"vfmul.vf           v16, v16, f11            \\n\\t\"\n                \"vfmul.vf           v24, v24, f11            \\n\\t\"\n                \"vfcvt.x.f.v        v0, v0                   \\n\\t\"\n                \"vfcvt.x.f.v        v8, v8                   \\n\\t\"\n                \"vfcvt.x.f.v        v16, v16                 \\n\\t\"\n                \"vfcvt.x.f.v        v24, v24                 \\n\\t\"\n                \"vsetvli            t0, zero, e16, m4        
\\n\\t\"\n                \"vnclip.wx          v0, v0, zero             \\n\\t\"\n                \"vnclip.wx          v4, v8, zero             \\n\\t\"\n                \"vnclip.wx          v8, v16, zero            \\n\\t\"\n                \"vnclip.wx          v12, v24, zero           \\n\\t\"\n                \"vsetvli            t0, zero, e8, m4         \\n\\t\"\n                \"vnclip.wx          v0, v0, zero             \\n\\t\"\n                \"vnclip.wx          v4, v8, zero             \\n\\t\"\n                \"vsetvli            t0, zero, e64, m8        \\n\\t\"\n                \"vsse64.v           v0, (s1), t6             \\n\\t\"\n                \"add                a1, a1, %[STRIDE]        \\n\\t\"\n                \"bge                t2, %[BlkLen], LOOP%=    \\n\\t\"\n\n                \"TAIL%=:                                     \\n\\t\"\n                \"blez               t2, QUIT%=               \\n\\t\"\n                \"vsetvli            t0, zero, e32, m8        \\n\\t\"\n                \"vxor.vv            v0, v0, v0               \\n\\t\"\n                \"vxor.vv            v8, v8, v8               \\n\\t\"\n                \"vxor.vv            v16, v16, v16            \\n\\t\"\n                \"vxor.vv            v24, v24, v24            \\n\\t\"\n                \"addi               t1, t2, 0                \\n\\t\"\n                \"vsetvli            t0, t1, e32, m8          \\n\\t\"\n                \"sub                t1, t1, t0               \\n\\t\"\n                \"vle32.v            v0, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vsetvli            t0, t1, e32, m8          \\n\\t\"\n                \"sub                t1, t1, t0               \\n\\t\"\n                \"vle32.v            v8, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vsetvli            t0, 
t1, e32, m8          \\n\\t\"\n                \"sub                t1, t1, t0               \\n\\t\"\n                \"vle32.v            v16, (%[SRC])            \\n\\t\"\n                \"addi               %[SRC], %[SRC], 256      \\n\\t\"\n                \"vsetvli            t0, t1, e32, m8          \\n\\t\"\n                \"vle32.v            v24, (%[SRC])            \\n\\t\"\n                \"addi               %[SRC], %[SRC], -768     \\n\\t\"\n                \"vsetvli            t0, zero, e32, m8        \\n\\t\"\n                \"vfabs.v            v0, v0                   \\n\\t\"\n                \"vfabs.v            v8, v8                   \\n\\t\"\n                \"vfabs.v            v16, v16                 \\n\\t\"\n                \"vfabs.v            v24, v24                 \\n\\t\"\n                \"vfmax.vv           v8, v0, v8               \\n\\t\"\n                \"vfmax.vv           v24, v16, v24            \\n\\t\"\n                \"vfmax.vv           v8, v8, v24              \\n\\t\"\n                \"vfredmax.vs        v24, v8, v24             \\n\\t\"\n                \"vfmv.f.s           f10, v24                 \\n\\t\"\n                \"add                s1, a1, %[OFFSET]        \\n\\t\"\n                \"fmul.s             f10, f10, %[RMAXREC]     \\n\\t\"\n                \"fsw                f10, (a1)                \\n\\t\"\n                \"fdiv.s             f11, %[FONE], f10        \\n\\t\"\n                \"vsetvli            t0, zero, e64, m8        \\n\\t\"\n                \"vxor.vv            v0, v0, v0               \\n\\t\"\n                \"vsse64.v           v0, (s1), t6             \\n\\t\"\n\n                \"TAIL_LOOP%=:                                \\n\\t\"\n                \"vsetvli            t0, zero, e32, m4        \\n\\t\"\n                \"vxor.vv            v0, v0, v0               \\n\\t\"\n                \"vsetvli            t0, t2, e32, m1          \\n\\t\"\n                \"sub  
              t2, t2, t0               \\n\\t\"\n                \"vle32.v            v0, (%[SRC])             \\n\\t\"\n                \"addi               %[SRC], %[SRC], 32       \\n\\t\"\n                \"vfmul.vf           v1, v0, f11              \\n\\t\"\n                \"vfcvt.x.f.v        v2, v1                   \\n\\t\"\n                \"vsetvli            t0, zero, e16, mf2       \\n\\t\"\n                \"vnclip.wx          v3, v2, zero             \\n\\t\"\n                \"vsetvli            t0, zero, e8, mf4        \\n\\t\"\n                \"vnclip.wx          v3, v3, zero             \\n\\t\"\n                \"vse8.v             v3, (s1)                 \\n\\t\"\n                \"addi               s1, s1, 32               \\n\\t\"\n                \"bnez               t2, TAIL_LOOP%=          \\n\\t\"\n\n                \"QUIT%=:                                     \\n\\t\"\n                : [SRC] \"+r\"(SRC)\n                : [DST] \"r\"(DST), [BlkLen] \"r\"(BlkLen), [OFFSET] \"r\"(offset), [STRIDE] \"r\"(stride),\n                  [CountK] \"r\"(CountK), [FONE] \"f\"(fone), [RMAXREC] \"f\"(range_max_reciprocal)\n                : \"cc\", \"t0\", \"t1\", \"t2\", \"t6\", \"a1\", \"s1\", \"s2\", \"f10\", \"f11\");\n        }\n    }\n}\n\nvoid quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {\n    const float *   SRC                  = A;\n    std::byte *     DST                  = QuantA;\n    constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);\n    const float     fone                 = 1.0f;\n    std::byte *     QuantA_offset        = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);\n    size_t          offset               = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;\n\n    if (CountK <= BlkLen) {\n        float max_abs_A = 0.0f;\n        for (size_t k = 0; k < CountK; k++) {\n            max_abs_A = std::max(max_abs_A, fabsf(A[k]));\n        }\n        float scale_A = 
max_abs_A * range_max_reciprocal;\n\n        ((float *) QuantA)[0] = scale_A;\n\n        auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));\n\n        for (size_t k = 0; k < CountK; k++) {\n            QuantAData_offset[k] =\n                (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),\n                                    (float) std::numeric_limits<int8_t>::max());\n        }\n        for (size_t k = CountK; k < BlkLen; k++) {\n            QuantAData_offset[k] = 0;\n        }\n\n        return;\n    }\n\n    if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) {\n        __asm__ volatile(\n            \"vsetvli      t0, zero, e8, m8        \\n\\t\"\n            \"vxor.vv      v24, v24, v24           \\n\\t\"\n            \"LOOP%=:                              \\n\\t\"\n            \"vsetvli      t0, %[CNT], e8, m8      \\n\\t\"\n            \"vse8.v       v24, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 128     \\n\\t\"\n            \"sub          %[CNT], %[CNT], t0      \\n\\t\"\n            \"bnez         %[CNT], LOOP%=          \\n\\t\"\n            : [DST] \"+r\"(QuantA_offset), [CNT] \"+r\"(offset)\n            :\n            : \"cc\", \"t0\");\n    }\n    if (BlkLen == 16) {\n        float buffer[64] = { 0.0f };\n        __asm__ volatile(\n            \"addi         t3, zero, 16*8          \\n\\t\"\n            \"addi         t2, zero, 16            \\n\\t\"\n            \"blt          %[K], t3, LOOP_K%=      \\n\\t\"\n            \"blt          %[K], t2, TAIL%=        \\n\\t\"\n            \"LOOP_MAIN%=:                         \\n\\t\"\n            \"vsetvli      t1, zero, e32, m2       \\n\\t\"\n            \"addi         %[K], %[K], -128        \\n\\t\"\n            \"vle32.v      v0, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 64      \\n\\t\"\n            \"vle32.v      v2, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], 
%[SRC], 64      \\n\\t\"\n            \"vle32.v      v4, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 64      \\n\\t\"\n            \"vle32.v      v6, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 64      \\n\\t\"\n            \"vle32.v      v8, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 64      \\n\\t\"\n            \"vle32.v      v10, (%[SRC])           \\n\\t\"\n            \"addi         %[SRC], %[SRC], 64      \\n\\t\"\n            \"vle32.v      v12, (%[SRC])           \\n\\t\"\n            \"addi         %[SRC], %[SRC], 64      \\n\\t\"\n            \"vle32.v      v14, (%[SRC])           \\n\\t\"\n            \"addi         %[SRC], %[SRC], 64      \\n\\t\"\n            \"addi         a1, %[BUFFER], 0        \\n\\t\"\n            \"vfabs.v      v16, v0                 \\n\\t\"\n            \"vfabs.v      v18, v2                 \\n\\t\"\n            \"vfabs.v      v20, v4                 \\n\\t\"\n            \"vfabs.v      v22, v6                 \\n\\t\"\n            \"vfabs.v      v24, v8                 \\n\\t\"\n            \"vfabs.v      v26, v10                \\n\\t\"\n            \"vfabs.v      v28, v12                \\n\\t\"\n            \"vfabs.v      v30, v14                \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v16, v16, v17           \\n\\t\"\n            \"vfmax.vv     v18, v18, v19           \\n\\t\"\n            \"vfmax.vv     v20, v20, v21           \\n\\t\"\n            \"vfmax.vv     v22, v22, v23           \\n\\t\"\n            \"vfmax.vv     v24, v24, v25           \\n\\t\"\n            \"vfmax.vv     v26, v26, v27           \\n\\t\"\n            \"vfmax.vv     v28, v28, v29           \\n\\t\"\n            \"vfmax.vv     v30, v30, v31           \\n\\t\"\n            \"vse32.v      v16, (a1)               \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"vse32.v 
     v18, (a1)               \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"vse32.v      v20, (a1)               \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"vse32.v      v22, (a1)               \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"vse32.v      v24, (a1)               \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"vse32.v      v26, (a1)               \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"vse32.v      v28, (a1)               \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"vse32.v      v30, (a1)               \\n\\t\"\n            \"addi         a1, %[BUFFER], 0        \\n\\t\"\n            \"flw          f0, (a1)                \\n\\t\"\n            \"flw          f1, 4(a1)               \\n\\t\"\n            \"flw          f2, 8(a1)               \\n\\t\"\n            \"flw          f3, 12(a1)              \\n\\t\"\n            \"flw          f4, 16(a1)              \\n\\t\"\n            \"flw          f5, 20(a1)              \\n\\t\"\n            \"flw          f6, 24(a1)              \\n\\t\"\n            \"flw          f7, 28(a1)              \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f10, f3, f7             \\n\\t\"\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fsw          f10, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n         
   \"fdiv.s       f10, %[FONE], f10       \\n\\t\"\n            \"flw          f0, (a1)                \\n\\t\"\n            \"flw          f1, 4(a1)               \\n\\t\"\n            \"flw          f2, 8(a1)               \\n\\t\"\n            \"flw          f3, 12(a1)              \\n\\t\"\n            \"flw          f4, 16(a1)              \\n\\t\"\n            \"flw          f5, 20(a1)              \\n\\t\"\n            \"flw          f6, 24(a1)              \\n\\t\"\n            \"flw          f7, 28(a1)              \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f11, f3, f7             \\n\\t\"\n            \"fmul.s       f11, f11, %[RMAXREC]    \\n\\t\"\n            \"fsw          f11, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f11       \\n\\t\"\n            \"flw          f0, (a1)                \\n\\t\"\n            \"flw          f1, 4(a1)               \\n\\t\"\n            \"flw          f2, 8(a1)               \\n\\t\"\n            \"flw          f3, 12(a1)              \\n\\t\"\n            \"flw          f4, 16(a1)              \\n\\t\"\n            \"flw          f5, 20(a1)              \\n\\t\"\n            \"flw          f6, 24(a1)              \\n\\t\"\n            \"flw          f7, 28(a1)              \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              
\\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f12, f3, f7             \\n\\t\"\n            \"fmul.s       f12, f12, %[RMAXREC]    \\n\\t\"\n            \"fsw          f12, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"fdiv.s       f12, %[FONE], f12       \\n\\t\"\n            \"flw          f0, (a1)                \\n\\t\"\n            \"flw          f1, 4(a1)               \\n\\t\"\n            \"flw          f2, 8(a1)               \\n\\t\"\n            \"flw          f3, 12(a1)              \\n\\t\"\n            \"flw          f4, 16(a1)              \\n\\t\"\n            \"flw          f5, 20(a1)              \\n\\t\"\n            \"flw          f6, 24(a1)              \\n\\t\"\n            \"flw          f7, 28(a1)              \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f13, f3, f7             \\n\\t\"\n            \"fmul.s       f13, f13, %[RMAXREC]    \\n\\t\"\n            \"fsw          f13, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"fdiv.s       f13, %[FONE], f13       \\n\\t\"\n            \"flw          f0, (a1)                \\n\\t\"\n            \"flw          f1, 4(a1)               \\n\\t\"\n            \"flw          f2, 8(a1)               \\n\\t\"\n            \"flw          f3, 12(a1)              \\n\\t\"\n            \"flw          f4, 16(a1) 
             \\n\\t\"\n            \"flw          f5, 20(a1)              \\n\\t\"\n            \"flw          f6, 24(a1)              \\n\\t\"\n            \"flw          f7, 28(a1)              \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f14, f3, f7             \\n\\t\"\n            \"fmul.s       f14, f14, %[RMAXREC]    \\n\\t\"\n            \"fsw          f14, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"fdiv.s       f14, %[FONE], f14       \\n\\t\"\n            \"flw          f0, (a1)                \\n\\t\"\n            \"flw          f1, 4(a1)               \\n\\t\"\n            \"flw          f2, 8(a1)               \\n\\t\"\n            \"flw          f3, 12(a1)              \\n\\t\"\n            \"flw          f4, 16(a1)              \\n\\t\"\n            \"flw          f5, 20(a1)              \\n\\t\"\n            \"flw          f6, 24(a1)              \\n\\t\"\n            \"flw          f7, 28(a1)              \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f15, f3, f7             \\n\\t\"\n            \"fmul.s       f15, f15, %[RMAXREC]    \\n\\t\"\n            \"fsw        
  f15, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"fdiv.s       f15, %[FONE], f15       \\n\\t\"\n            \"flw          f0, (a1)                \\n\\t\"\n            \"flw          f1, 4(a1)               \\n\\t\"\n            \"flw          f2, 8(a1)               \\n\\t\"\n            \"flw          f3, 12(a1)              \\n\\t\"\n            \"flw          f4, 16(a1)              \\n\\t\"\n            \"flw          f5, 20(a1)              \\n\\t\"\n            \"flw          f6, 24(a1)              \\n\\t\"\n            \"flw          f7, 28(a1)              \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f16, f3, f7             \\n\\t\"\n            \"fmul.s       f16, f16, %[RMAXREC]    \\n\\t\"\n            \"fsw          f16, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"fdiv.s       f16, %[FONE], f16       \\n\\t\"\n            \"flw          f0, (a1)                \\n\\t\"\n            \"flw          f1, 4(a1)               \\n\\t\"\n            \"flw          f2, 8(a1)               \\n\\t\"\n            \"flw          f3, 12(a1)              \\n\\t\"\n            \"flw          f4, 16(a1)              \\n\\t\"\n            \"flw          f5, 20(a1)              \\n\\t\"\n            \"flw          f6, 24(a1)              \\n\\t\"\n            \"flw          f7, 28(a1)              \\n\\t\"\n            \"addi         a1, a1, 32              \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            
\"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f17, f3, f7             \\n\\t\"\n            \"fmul.s       f17, f17, %[RMAXREC]    \\n\\t\"\n            \"fsw          f17, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], -136    \\n\\t\"\n            \"fdiv.s       f17, %[FONE], f17       \\n\\t\"\n            \"vsetvli      t0, zero, e32, m2       \\n\\t\"\n            \"vfmul.vf     v16, v0, f10            \\n\\t\"\n            \"vfmul.vf     v18, v2, f11            \\n\\t\"\n            \"vfmul.vf     v20, v4, f12            \\n\\t\"\n            \"vfmul.vf     v22, v6, f13            \\n\\t\"\n            \"vfmul.vf     v24, v8, f14            \\n\\t\"\n            \"vfmul.vf     v26, v10, f15           \\n\\t\"\n            \"vfmul.vf     v28, v12, f16           \\n\\t\"\n            \"vfmul.vf     v30, v14, f17           \\n\\t\"\n            \"vfcvt.x.f.v  v16, v16                \\n\\t\"\n            \"vfcvt.x.f.v  v18, v18                \\n\\t\"\n            \"vfcvt.x.f.v  v20, v20                \\n\\t\"\n            \"vfcvt.x.f.v  v22, v22                \\n\\t\"\n            \"vfcvt.x.f.v  v24, v24                \\n\\t\"\n            \"vfcvt.x.f.v  v26, v26                \\n\\t\"\n            \"vfcvt.x.f.v  v28, v28                \\n\\t\"\n            \"vfcvt.x.f.v  v30, v30                \\n\\t\"\n            \"vsetvli      t0, zero, e16, m1       \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vnclip.wx    v18, v18, zero          \\n\\t\"\n            \"vnclip.wx    v20, v20, zero          \\n\\t\"\n            \"vnclip.wx    v22, v22, zero          \\n\\t\"\n            \"vnclip.wx    v24, v24, zero          
\\n\\t\"\n            \"vnclip.wx    v26, v26, zero          \\n\\t\"\n            \"vnclip.wx    v28, v28, zero          \\n\\t\"\n            \"vnclip.wx    v30, v30, zero          \\n\\t\"\n            \"vsetvli      t0, t1, e8, mf2         \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vnclip.wx    v18, v18, zero          \\n\\t\"\n            \"vnclip.wx    v20, v20, zero          \\n\\t\"\n            \"vnclip.wx    v22, v22, zero          \\n\\t\"\n            \"vnclip.wx    v24, v24, zero          \\n\\t\"\n            \"vnclip.wx    v26, v26, zero          \\n\\t\"\n            \"vnclip.wx    v28, v28, zero          \\n\\t\"\n            \"vnclip.wx    v30, v30, zero          \\n\\t\"\n            \"vse8.v       v16, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"vse8.v       v18, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"vse8.v       v20, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"vse8.v       v22, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"vse8.v       v24, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"vse8.v       v26, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"vse8.v       v28, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 20      \\n\\t\"\n            \"vse8.v       v30, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 16      \\n\\t\"\n            \"bge          %[K], t3, LOOP_MAIN%=   \\n\\t\"\n            \"blt          %[K], t2, TAIL%=        \\n\\t\"\n            \"LOOP_K%=:                            \\n\\t\"\n            \"vsetvli      t1, %[K], e32, m2       \\n\\t\"\n            \"vle32.v      v0, 
(%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 64      \\n\\t\"\n            \"sub          %[K], %[K], t1          \\n\\t\"\n            \"vfabs.v      v16, v0                 \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v16, v16, v17           \\n\\t\"\n            \"vse32.v      v16, (%[BUFFER])        \\n\\t\"\n            \"flw          f0, (%[BUFFER])         \\n\\t\"\n            \"flw          f1, 4(%[BUFFER])        \\n\\t\"\n            \"flw          f2, 8(%[BUFFER])        \\n\\t\"\n            \"flw          f3, 12(%[BUFFER])       \\n\\t\"\n            \"flw          f4, 16(%[BUFFER])       \\n\\t\"\n            \"flw          f5, 20(%[BUFFER])       \\n\\t\"\n            \"flw          f6, 24(%[BUFFER])       \\n\\t\"\n            \"flw          f7, 28(%[BUFFER])       \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f10, f3, f7             \\n\\t\"\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fsw          f10, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 4       \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f10       \\n\\t\"\n            \"vsetvli      t0, zero, e32, m2       \\n\\t\"\n            \"vfmul.vf     v16, v0, f11            \\n\\t\"\n            \"vfcvt.x.f.v  v16, v16                \\n\\t\"\n            \"vsetvli      t0, zero, e16, m1       \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vsetvli      t0, t1, e8, mf2         \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            
\"vse8.v       v16, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 16      \\n\\t\"\n            \"bge          %[K], t2, LOOP_K%=      \\n\\t\"\n            \"TAIL%=:                              \\n\\t\"\n            \"blez         %[K], END%=             \\n\\t\"\n            \"vsetvli      t0, t3, e32, m2         \\n\\t\"\n            \"vxor.vv      v16, v16, v16           \\n\\t\"\n            \"jal          x0, LOOP_K%=            \\n\\t\"\n            \"END%=:                               \\n\\t\"\n            : [SRC] \"+r\"(SRC), [DST] \"+r\"(DST), [K] \"+r\"(CountK)\n            : [FONE] \"f\"(fone), [RMAXREC] \"f\"(range_max_reciprocal), [BUFFER] \"r\"(buffer)\n            : \"cc\", \"t3\", \"t2\", \"t1\", \"t0\", \"a1\", \"f0\", \"f1\", \"f2\", \"f3\", \"f4\", \"f5\", \"f6\", \"f7\", \"f10\", \"f11\", \"f12\",\n              \"f13\", \"f14\", \"f15\", \"f16\", \"f17\");\n    } else if (BlkLen == 32) {\n        __asm__ volatile(\n            \"addi         t3, zero, 32*4          \\n\\t\"\n            \"addi         t2, zero, 32            \\n\\t\"\n\n            \"addi         a1, %[SRC], 0           \\n\\t\"\n            \"addi         a2, %[SRC], 128         \\n\\t\"\n            \"addi         a3, %[SRC], 256         \\n\\t\"\n            \"addi         a4, %[SRC], 384         \\n\\t\"\n\n            \"addi         s1, %[DST], 0           \\n\\t\"\n            \"addi         s2, %[DST], 36          \\n\\t\"\n            \"addi         s3, %[DST], 72          \\n\\t\"\n            \"addi         s4, %[DST], 108         \\n\\t\"\n            \"blt          %[K], t3, LOOP_K%=      \\n\\t\"\n            \"blt          %[K], t2, TAIL%=        \\n\\t\"\n\n            \"LOOP_MAIN%=:                         \\n\\t\"\n            \"vsetvli      t1, zero, e32, m4       \\n\\t\"\n            \"addi         %[K], %[K], -128        \\n\\t\"\n            \"vle32.v      v0, (a1)                \\n\\t\"\n            \"addi         a1, a1, 
512             \\n\\t\"\n            \"vle32.v      v4, (a2)                \\n\\t\"\n            \"addi         a2, a2, 512             \\n\\t\"\n            \"vle32.v      v8, (a3)                \\n\\t\"\n            \"addi         a3, a3, 512             \\n\\t\"\n            \"vle32.v      v12, (a4)               \\n\\t\"\n            \"addi         a4, a4, 512             \\n\\t\"\n            \"vfabs.v      v16, v0                 \\n\\t\"\n            \"vfabs.v      v20, v4                 \\n\\t\"\n            \"vfabs.v      v24, v8                 \\n\\t\"\n            \"vfabs.v      v28, v12                \\n\\t\"\n            \"vsetvli      t0, zero, e32, m2       \\n\\t\"\n            \"vfmax.vv     v16, v16, v18           \\n\\t\"\n            \"vfmax.vv     v20, v20, v22           \\n\\t\"\n            \"vfmax.vv     v24, v24, v26           \\n\\t\"\n            \"vfmax.vv     v28, v28, v30           \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v16, v16, v17           \\n\\t\"\n            \"vfmax.vv     v20, v20, v21           \\n\\t\"\n            \"vfmax.vv     v24, v24, v25           \\n\\t\"\n            \"vfmax.vv     v28, v28, v29           \\n\\t\"\n\n            \"vfredmax.vs  v17, v16, v17           \\n\\t\"\n            \"vfredmax.vs  v21, v20, v21           \\n\\t\"\n            \"vfredmax.vs  v25, v24, v25           \\n\\t\"\n            \"vfredmax.vs  v29, v28, v29           \\n\\t\"\n            \"vfmv.f.s     f10,  v17               \\n\\t\"\n            \"vfmv.f.s     f11,  v21               \\n\\t\"\n            \"vfmv.f.s     f12,  v25               \\n\\t\"\n            \"vfmv.f.s     f13,  v29               \\n\\t\"\n\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fmul.s       f11, f11, %[RMAXREC]    \\n\\t\"\n            \"fmul.s       f12, f12, %[RMAXREC]    \\n\\t\"\n            \"fmul.s       f13, f13, %[RMAXREC]    \\n\\t\"\n            \"fsw 
         f10, (s1)               \\n\\t\"\n            \"addi         s1, s1, 4               \\n\\t\"\n\n            \"fsw          f11, (s2)               \\n\\t\"\n            \"addi         s2, s2, 4               \\n\\t\"\n            \"fsw          f12, (s3)               \\n\\t\"\n            \"addi         s3, s3, 4               \\n\\t\"\n            \"fsw          f13, (s4)               \\n\\t\"\n            \"addi         s4, s4, 4               \\n\\t\"\n            \"fdiv.s       f10, %[FONE], f10       \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f11       \\n\\t\"\n            \"fdiv.s       f12, %[FONE], f12       \\n\\t\"\n            \"fdiv.s       f13, %[FONE], f13       \\n\\t\"\n            \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n            \"vfmul.vf     v16, v0, f10            \\n\\t\"\n            \"vfmul.vf     v20, v4, f11            \\n\\t\"\n            \"vfmul.vf     v24, v8, f12            \\n\\t\"\n            \"vfmul.vf     v28, v12, f13           \\n\\t\"\n            \"vfcvt.x.f.v  v16, v16                \\n\\t\"\n            \"vfcvt.x.f.v  v20, v20                \\n\\t\"\n            \"vfcvt.x.f.v  v24, v24                \\n\\t\"\n            \"vfcvt.x.f.v  v28, v28                \\n\\t\"\n            \"vsetvli      t0, zero, e16, m2       \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vnclip.wx    v20, v20, zero          \\n\\t\"\n            \"vnclip.wx    v24, v24, zero          \\n\\t\"\n            \"vnclip.wx    v28, v28, zero          \\n\\t\"\n            \"vsetvli      t0, t1, e8, m1          \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vnclip.wx    v20, v20, zero          \\n\\t\"\n            \"vnclip.wx    v24, v24, zero          \\n\\t\"\n            \"vnclip.wx    v28, v28, zero          \\n\\t\"\n            \"vse8.v       v16, (s1)               \\n\\t\"\n            \"addi         s1, s1, 140             \\n\\t\"\n   
         \"vse8.v       v20, (s2)               \\n\\t\"\n            \"addi         s2, s2, 140             \\n\\t\"\n            \"vse8.v       v24, (s3)               \\n\\t\"\n            \"addi         s3, s3, 140             \\n\\t\"\n            \"vse8.v       v28, (s4)               \\n\\t\"\n            \"addi         s4, s4, 140             \\n\\t\"\n            \"bge          %[K], t3, LOOP_MAIN%=   \\n\\t\"\n            \"blt          %[K], t2, TAIL%=        \\n\\t\"\n            \"LOOP_K%=:                            \\n\\t\"\n            \"vsetvli      t1, %[K], e32, m4       \\n\\t\"\n            \"vle32.v      v0, (a1)                \\n\\t\"\n            \"addi         a1, a1, 128             \\n\\t\"\n            \"sub          %[K], %[K], t1          \\n\\t\"\n            \"vfabs.v      v16, v0                 \\n\\t\"\n            \"vsetvli      t0, zero, e32, m2       \\n\\t\"\n            \"vfmax.vv     v16, v16, v18           \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v16, v16, v17           \\n\\t\"\n            \"vfredmax.vs  v17, v16, v17           \\n\\t\"\n            \"vfmv.f.s     f10,  v17               \\n\\t\"\n\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fsw          f10, (s1)               \\n\\t\"\n            \"addi         s1, s1, 4               \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f10       \\n\\t\"\n            \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n            \"vfmul.vf     v16, v0, f11            \\n\\t\"\n            \"vfcvt.x.f.v  v16, v16                \\n\\t\"\n            \"vsetvli      t0, zero, e16, m2       \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vse8.v       v16, (s1)               \\n\\t\"\n            \"addi         s1, s1, 32            
  \\n\\t\"\n            \"bge          %[K], t2, LOOP_K%=      \\n\\t\"\n            \"TAIL%=:                              \\n\\t\"\n            \"blez         %[K], END%=             \\n\\t\"\n            \"vsetvli      t0, t3, e32, m4         \\n\\t\"\n            \"vxor.vv      v0, v0, v0              \\n\\t\"\n            \"vxor.vv      v16, v16, v16           \\n\\t\"\n            \"jal          x0, LOOP_K%=            \\n\\t\"\n            \"END%=:                               \\n\\t\"\n            : [K] \"+r\"(CountK)\n            : [FONE] \"f\"(fone), [RMAXREC] \"f\"(range_max_reciprocal), [SRC] \"r\"(SRC), [DST] \"r\"(DST)\n            : \"cc\", \"t3\", \"t2\", \"t1\", \"t0\", \"a1\", \"a2\", \"a3\", \"a4\", \"s1\", \"s2\", \"s3\", \"s4\", \"f10\", \"f11\", \"f12\", \"f13\");\n    } else if (BlkLen == 64) {\n        __asm__ volatile(\n            \"addi         t3, zero, 64*2          \\n\\t\"\n            \"addi         t2, zero, 64            \\n\\t\"\n            \"addi         a1, %[SRC], 0           \\n\\t\"\n            \"addi         a2, %[SRC], 256         \\n\\t\"\n            \"addi         s1, %[DST], 0           \\n\\t\"\n            \"addi         s2, %[DST], 68          \\n\\t\"\n            \"blt          %[K], t3, LOOP_K%=      \\n\\t\"\n            \"blt          %[K], t2, TAIL%=        \\n\\t\"\n            \"LOOP_MAIN%=:                         \\n\\t\"\n            \"vsetvli      t1, zero, e32, m8       \\n\\t\"\n            \"addi         %[K], %[K], -128        \\n\\t\"\n            \"vle32.v      v0, (a1)                \\n\\t\"\n            \"addi         a1, a1, 512             \\n\\t\"\n            \"vle32.v      v8, (a2)                \\n\\t\"\n            \"addi         a2, a2, 512             \\n\\t\"\n            \"vfabs.v      v16, v0                 \\n\\t\"\n            \"vfabs.v      v24, v8                 \\n\\t\"\n            \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n            \"vfmax.vv     v16, v16, v20    
       \\n\\t\"\n            \"vfmax.vv     v24, v24, v28           \\n\\t\"\n            \"vsetvli      t0, zero, e32, m2       \\n\\t\"\n            \"vfmax.vv     v16, v16, v18           \\n\\t\"\n            \"vfmax.vv     v24, v24, v26           \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v16, v16, v17           \\n\\t\"\n            \"vfmax.vv     v24, v24, v25           \\n\\t\"\n            \"vfredmax.vs  v17, v16, v17           \\n\\t\"\n            \"vfredmax.vs  v25, v24, v25           \\n\\t\"\n            \"vfmv.f.s     f10,  v17               \\n\\t\"\n            \"vfmv.f.s     f11,  v25               \\n\\t\"\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fmul.s       f11, f11, %[RMAXREC]    \\n\\t\"\n            \"fsw          f10, (s1)               \\n\\t\"\n            \"addi         s1, s1, 4               \\n\\t\"\n            \"fsw          f11, (s2)               \\n\\t\"\n            \"addi         s2, s2, 4               \\n\\t\"\n            \"fdiv.s       f10, %[FONE], f10       \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f11       \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vfmul.vf     v16, v0, f10            \\n\\t\"\n            \"vfmul.vf     v24, v8, f11            \\n\\t\"\n            \"vfcvt.x.f.v  v16, v16                \\n\\t\"\n            \"vfcvt.x.f.v  v24, v24                \\n\\t\"\n            \"vsetvli      t0, zero, e16, m4       \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vnclip.wx    v24, v24, zero          \\n\\t\"\n            \"vsetvli      t0, t1, e8, m2          \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vnclip.wx    v24, v24, zero          \\n\\t\"\n            \"vse8.v       v16, (s1)               \\n\\t\"\n            \"addi         s1, s1, 132             \\n\\t\"\n            \"vse8.v       
v24, (s2)               \\n\\t\"\n            \"addi         s2, s2, 132             \\n\\t\"\n            \"bge          %[K], t3, LOOP_MAIN%=   \\n\\t\"\n            \"blt          %[K], t2, TAIL%=        \\n\\t\"\n            \"LOOP_K%=:                            \\n\\t\"\n            \"vsetvli      t1, %[K], e32, m8       \\n\\t\"\n            \"vle32.v      v0, (a1)                \\n\\t\"\n            \"addi         a1, a1, 256             \\n\\t\"\n            \"sub          %[K], %[K], t1          \\n\\t\"\n            \"vfabs.v      v16, v0                 \\n\\t\"\n            \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n            \"vfmax.vv     v16, v16, v20           \\n\\t\"\n            \"vsetvli      t0, zero, e32, m2       \\n\\t\"\n            \"vfmax.vv     v16, v16, v18           \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v16, v16, v17           \\n\\t\"\n            \"vfredmax.vs  v17, v16, v17           \\n\\t\"\n            \"vfmv.f.s     f10,  v17               \\n\\t\"\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fsw          f10, (s1)               \\n\\t\"\n            \"addi         s1, s1, 4               \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f10       \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vfmul.vf     v16, v0, f11            \\n\\t\"\n            \"vfcvt.x.f.v  v16, v16                \\n\\t\"\n            \"vsetvli      t0, zero, e16, m4       \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vsetvli      t0, zero, e8, m2        \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vse8.v       v16, (s1)               \\n\\t\"\n            \"addi         s1, s1, 64              \\n\\t\"\n            \"bge          %[K], t2, LOOP_K%=      \\n\\t\"\n            \"TAIL%=:                              \\n\\t\"\n            
\"blez         %[K], END%=             \\n\\t\"\n            \"vsetvli      t0, t3, e32, m8         \\n\\t\"\n            \"vxor.vv      v0, v0, v0              \\n\\t\"\n            \"vxor.vv      v16, v16, v16           \\n\\t\"\n            \"jal          x0, LOOP_K%=            \\n\\t\"\n            \"END%=:                               \\n\\t\"\n            : [K] \"+r\"(CountK)\n            : [SRC] \"r\"(SRC), [DST] \"r\"(DST), [FONE] \"f\"(fone), [RMAXREC] \"f\"(range_max_reciprocal)\n            : \"cc\", \"t3\", \"t2\", \"t1\", \"t0\", \"a1\", \"a2\", \"s1\", \"s2\", \"f10\", \"f11\");\n    } else if (BlkLen == 128) {\n        __asm__ volatile(\n            \"addi         t2, zero, 128           \\n\\t\"\n            \"addi         a1, %[SRC], 0           \\n\\t\"\n            \"addi         a2, %[SRC], 256         \\n\\t\"\n            \"blt          %[K], t2, TAIL%=        \\n\\t\"\n            \"LOOP_K%=:                            \\n\\t\"\n            \"vsetvli      t1, zero, e32, m8       \\n\\t\"\n            \"vle32.v      v0, (a1)                \\n\\t\"\n            \"addi         a1, a1, 512             \\n\\t\"\n            \"vle32.v      v8, (a2)                \\n\\t\"\n            \"addi         a2, a2, 512             \\n\\t\"\n            \"sub          %[K], %[K], t2          \\n\\t\"\n            \"QUANT%=:                             \\n\\t\"\n            \"vfabs.v      v16, v0                 \\n\\t\"\n            \"vfabs.v      v24, v8                 \\n\\t\"\n            \"vfmax.vv     v24, v16, v24           \\n\\t\"\n            \"vsetvli      t1, zero, e32, m4       \\n\\t\"\n            \"vfmax.vv     v28, v24, v28           \\n\\t\"\n            \"vsetvli      t0, zero, e32, m2       \\n\\t\"\n            \"vfmax.vv     v30, v28, v30           \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v30, v30, v31           \\n\\t\"\n            \"vfredmax.vs  v31, v30, v31           
\\n\\t\"\n            \"vfmv.f.s     f10, v31                \\n\\t\"\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fsw          f10, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 4       \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f10       \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vfmul.vf     v16, v0, f11            \\n\\t\"\n            \"vfmul.vf     v24, v8, f11            \\n\\t\"\n            \"vfcvt.x.f.v  v16, v16                \\n\\t\"\n            \"vfcvt.x.f.v  v24, v24                \\n\\t\"\n            \"vsetvli      t0, zero, e16, m4       \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vnclip.wx    v20, v24, zero          \\n\\t\"\n            \"vsetvli      t0, zero, e8, m4        \\n\\t\"\n            \"vnclip.wx    v16, v16, zero          \\n\\t\"\n            \"vse8.v       v16, (%[DST])           \\n\\t\"\n            \"addi         %[DST], %[DST], 128     \\n\\t\"\n            \"bge          %[K], t2, LOOP_K%=      \\n\\t\"\n            \"TAIL%=:                              \\n\\t\"\n            \"blez         %[K], END%=             \\n\\t\"\n            \"vsetvli      t1, zero, e32, m8       \\n\\t\"\n            \"vxor.vv      v0, v0, v0              \\n\\t\"\n            \"vxor.vv      v8, v8, v8              \\n\\t\"\n            \"vsetvli      t0, %[K], e32, m8       \\n\\t\"\n            \"vle32.v      v0, (a1)                \\n\\t\"\n            \"sub          %[K], %[K], t0          \\n\\t\"\n            \"vsetvli      t0, %[K], e32, m8       \\n\\t\"\n            \"vle32.v      v8, (a2)                \\n\\t\"\n            \"sub          %[K], %[K], t0          \\n\\t\"\n            \"vsetvli      t1, zero, e32, m8       \\n\\t\"\n            \"jal          x0, QUANT%=             \\n\\t\"\n            \"END%=:                               \\n\\t\"\n\n            : [DST] \"+r\"(DST), 
[K] \"+r\"(CountK)\n            : [FONE] \"f\"(fone), [RMAXREC] \"f\"(range_max_reciprocal), [SRC] \"r\"(SRC)\n            : \"cc\", \"t2\", \"t1\", \"t0\", \"a1\", \"a2\", \"f10\", \"f11\");\n    } else {\n        float  buffer[8] = { 0.0f };\n        size_t cnt       = BlkLen / 256;\n\n        __asm__ volatile(\n            \"slli         t3, %[BLK], 2           \\n\\t\"\n            \"blt       %[K], %[BLK], LOOP_TAIL%=  \\n\\t\"\n            \"LOOP_MAIN%=:                         \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vxor.vv      v31, v31, v31           \\n\\t\"\n            \"vse32.v      v31, (%[BUFFER])        \\n\\t\"\n            \"addi         t6, %[CNT], 0           \\n\\t\"\n            \"LOOP_CMP%=:                          \\n\\t\"\n            \"addi         t6, t6, -1              \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vle32.v      v0, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"vle32.v      v8, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"vle32.v      v16, (%[SRC])           \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"vle32.v      v24, (%[SRC])           \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"vfabs.v      v0, v0                  \\n\\t\"\n            \"vfabs.v      v8, v8                  \\n\\t\"\n            \"vfabs.v      v16, v16                \\n\\t\"\n            \"vfabs.v      v24, v24                \\n\\t\"\n            \"vfmax.vv     v8, v0, v8              \\n\\t\"\n            \"vfmax.vv     v16, v16, v24           \\n\\t\"\n            \"vfmax.vv     v0, v0, v16             \\n\\t\"\n            \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n            \"vfmax.vv     v0, v0, v4              \\n\\t\"\n            \"vsetvli      t0, zero, 
e32, m2       \\n\\t\"\n            \"vfmax.vv     v0, v0, v2              \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v0, v0, v1              \\n\\t\"\n            \"vle32.v      v30, (%[BUFFER])        \\n\\t\"\n            \"vfmax.vv     v31, v30,  v0           \\n\\t\"\n            \"vse32.v      v31, (%[BUFFER])        \\n\\t\"\n            \"bnez         t6, LOOP_CMP%=          \\n\\t\"\n            \"sub          %[SRC], %[SRC], t3      \\n\\t\"\n            \"addi         t6, %[CNT], 0           \\n\\t\"\n            \"flw          f0, (%[BUFFER])         \\n\\t\"\n            \"flw          f1, 4(%[BUFFER])        \\n\\t\"\n            \"flw          f2, 8(%[BUFFER])        \\n\\t\"\n            \"flw          f3, 12(%[BUFFER])       \\n\\t\"\n            \"flw          f4, 16(%[BUFFER])       \\n\\t\"\n            \"flw          f5, 20(%[BUFFER])       \\n\\t\"\n            \"flw          f6, 24(%[BUFFER])       \\n\\t\"\n            \"flw          f7, 28(%[BUFFER])       \\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f10, f3, f7             \\n\\t\"\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fsw          f10,  (%[DST])          \\n\\t\"\n            \"addi         %[DST], %[DST], 4       \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f10       \\n\\t\"\n            \"addi         t6,  %[CNT], 0          \\n\\t\"\n            \"LOOP_QUANT%=:                        \\n\\t\"\n            \"addi         t6, t6, -1              \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vle32.v   
   v0, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"vle32.v      v8, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"vle32.v      v16, (%[SRC])           \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"vle32.v      v24, (%[SRC])           \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vfmul.vf     v0, v0, f11             \\n\\t\"\n            \"vfmul.vf     v8, v8, f11             \\n\\t\"\n            \"vfmul.vf     v16, v16, f11           \\n\\t\"\n            \"vfmul.vf     v24, v24, f11           \\n\\t\"\n            \"vfcvt.x.f.v  v0, v0                  \\n\\t\"\n            \"vfcvt.x.f.v  v8, v8                  \\n\\t\"\n            \"vfcvt.x.f.v  v16, v16                \\n\\t\"\n            \"vfcvt.x.f.v  v24, v24                \\n\\t\"\n            \"vsetvli      t0, zero, e16, m4       \\n\\t\"\n            \"vnclip.wx    v0, v0, zero            \\n\\t\"\n            \"vnclip.wx    v4, v8, zero            \\n\\t\"\n            \"vnclip.wx    v8, v16, zero           \\n\\t\"\n            \"vnclip.wx    v12, v24, zero          \\n\\t\"\n            \"vsetvli      t0, zero, e8, m4        \\n\\t\"\n            \"vnclip.wx    v0, v0, zero            \\n\\t\"\n            \"vnclip.wx    v4, v8, zero            \\n\\t\"\n            \"vse8.v       v0, (%[DST])            \\n\\t\"\n            \"addi         %[DST], %[DST], 128     \\n\\t\"\n            \"vse8.v       v4, (%[DST])            \\n\\t\"\n            \"addi         %[DST], %[DST], 128     \\n\\t\"\n            \"bnez         t6, LOOP_QUANT%=        \\n\\t\"\n            \"sub           %[K], %[K], %[BLK]     \\n\\t\"\n            \"bge        %[K], %[BLK], LOOP_MAIN%= \\n\\t\"\n            \"blez         %[K], END%=             \\n\\t\"\n           
 \"LOOP_TAIL%=:                         \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vxor.vv      v31, v31, v31           \\n\\t\"\n            \"vse32.v      v31, (%[BUFFER])        \\n\\t\"\n            \"addi         t6, %[K], 0             \\n\\t\"\n            \"addi         s1, %[SRC], 0           \\n\\t\"\n            \"TAIL_CMP%=:                          \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vxor.vv       v0, v0, v0             \\n\\t\"\n            \"vsetvli      t0, t6, e32, m8         \\n\\t\"\n            \"vle32.v      v0, (%[SRC])            \\n\\t\"\n            \"addi         %[SRC], %[SRC], 256     \\n\\t\"\n            \"sub          t6, t6, t0              \\n\\t\"\n            \"vfabs.v      v0, v0                  \\n\\t\"\n            \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n            \"vfmax.vv     v0, v0, v4              \\n\\t\"\n            \"vsetvli      t0, zero, e32, m2       \\n\\t\"\n            \"vfmax.vv     v0, v0, v2              \\n\\t\"\n            \"vsetvli      t0, zero, e32, m1       \\n\\t\"\n            \"vfmax.vv     v0, v0, v1              \\n\\t\"\n            \"vle32.v      v30, (%[BUFFER])        \\n\\t\"\n            \"vfmax.vv     v31, v30,  v0           \\n\\t\"\n            \"vse32.v      v31, (%[BUFFER])        \\n\\t\"\n            \"bnez         t6, TAIL_CMP%=          \\n\\t\"\n            \"addi         t6, %[K], 0             \\n\\t\"\n            \"flw          f0, (%[BUFFER])         \\n\\t\"\n            \"flw          f1, 4(%[BUFFER])        \\n\\t\"\n            \"flw          f2, 8(%[BUFFER])        \\n\\t\"\n            \"flw          f3, 12(%[BUFFER])       \\n\\t\"\n            \"flw          f4, 16(%[BUFFER])       \\n\\t\"\n            \"flw          f5, 20(%[BUFFER])       \\n\\t\"\n            \"flw          f6, 24(%[BUFFER])       \\n\\t\"\n            \"flw          f7, 28(%[BUFFER])       
\\n\\t\"\n            \"fmax.s       f1, f0, f1              \\n\\t\"\n            \"fmax.s       f3, f2, f3              \\n\\t\"\n            \"fmax.s       f5, f4, f5              \\n\\t\"\n            \"fmax.s       f7, f6, f7              \\n\\t\"\n            \"fmax.s       f3, f1, f3              \\n\\t\"\n            \"fmax.s       f7, f5, f7              \\n\\t\"\n            \"fmax.s       f10, f3, f7             \\n\\t\"\n            \"fmul.s       f10, f10, %[RMAXREC]    \\n\\t\"\n            \"fsw          f10,  (%[DST])          \\n\\t\"\n            \"addi         %[DST], %[DST], 4       \\n\\t\"\n            \"fdiv.s       f11, %[FONE], f10       \\n\\t\"\n            \"addi         t6,  %[K], 0            \\n\\t\"\n            \"TAIL_QUANT%=:                        \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vxor.vv       v0, v0, v0             \\n\\t\"\n            \"vsetvli      t1, t6, e32, m8         \\n\\t\"\n            \"vle32.v      v0, (s1)                \\n\\t\"\n            \"addi         s1, s1, 256             \\n\\t\"\n            \"sub          t6, t6, t1              \\n\\t\"\n            \"vsetvli      t0, zero, e32, m8       \\n\\t\"\n            \"vfmul.vf     v0, v0, f11             \\n\\t\"\n            \"vfcvt.x.f.v  v0, v0                  \\n\\t\"\n            \"vsetvli      t0, zero, e16, m4       \\n\\t\"\n            \"vnclip.wx    v0, v0, zero            \\n\\t\"\n            \"vsetvli      t0, t1, e8, m2          \\n\\t\"\n            \"vnclip.wx    v0, v0, zero            \\n\\t\"\n            \"vse8.v       v0, (%[DST])            \\n\\t\"\n            \"addi         %[DST], %[DST], 64      \\n\\t\"\n            \"bnez         t6, TAIL_QUANT%=        \\n\\t\"\n            \"END%=:                               \\n\\t\"\n            : [SRC] \"+r\"(SRC), [DST] \"+r\"(DST), [K] \"+r\"(CountK)\n            : [FONE] \"f\"(fone), [RMAXREC] \"f\"(range_max_reciprocal), [BLK] 
\"r\"(BlkLen), [BUFFER] \"r\"(buffer),\n              [CNT] \"r\"(cnt)\n            : \"cc\", \"t1\", \"t0\", \"t6\", \"s1\", \"f0\", \"f1\", \"f2\", \"f3\", \"f4\", \"f5\", \"f6\");\n    }\n}\n\n}  // namespace ime1\n\nnamespace {\n#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4          \\\n    \"vmadot       v16, v14, v0            \\n\\t\" \\\n    \"vmadot       v18, v14, v1            \\n\\t\" \\\n    \"vmadot       v20, v14, v2            \\n\\t\" \\\n    \"vmadot       v22, v14, v3            \\n\\t\" \\\n    \"vmadot       v16, v15, v4            \\n\\t\" \\\n    \"vmadot       v18, v15, v5            \\n\\t\" \\\n    \"vmadot       v20, v15, v6            \\n\\t\" \\\n    \"vmadot       v22, v15, v7            \\n\\t\"\n\n#define SQ4BIT_KERNEL_ACC_1X4X4                 \\\n    \"vfcvt.f.x.v  v16,  v16               \\n\\t\" \\\n    \"vfcvt.f.x.v  v18,  v18               \\n\\t\" \\\n    \"vfcvt.f.x.v  v20,  v20               \\n\\t\" \\\n    \"vfcvt.f.x.v  v22,  v22               \\n\\t\" \\\n    \"addi         s2, s1, 16              \\n\\t\" \\\n    \"addi         s3, s1, 32              \\n\\t\" \\\n    \"addi         s4, s1, 48              \\n\\t\" \\\n    \"addi         s6, s5, 12              \\n\\t\" \\\n    \"vfmacc.vv    v28, v16, v24           \\n\\t\" \\\n    \"vfmacc.vv    v29, v18, v25           \\n\\t\" \\\n    \"vfmacc.vv    v30, v20, v26           \\n\\t\" \\\n    \"vfmacc.vv    v31, v22, v27           \\n\\t\"\n\n#define SQ4BIT_KERNEL_ACC_F16_1X4X4             \\\n    \"vfcvt.f.x.v  v16,  v16               \\n\\t\" \\\n    \"vfcvt.f.x.v  v18,  v18               \\n\\t\" \\\n    \"vfcvt.f.x.v  v20,  v20               \\n\\t\" \\\n    \"vfcvt.f.x.v  v22,  v22               \\n\\t\" \\\n    \"addi         s2, s1, 8               \\n\\t\" \\\n    \"addi         s3, s1, 16              \\n\\t\" \\\n    \"addi         s4, s1, 24              \\n\\t\" \\\n    \"addi         s6, s5, 12              \\n\\t\" \\\n    \"vfmacc.vv    v28, v16, v24           
\\n\\t\" \\\n    \"vfmacc.vv    v29, v18, v25           \\n\\t\" \\\n    \"vfmacc.vv    v30, v20, v26           \\n\\t\" \\\n    \"vfmacc.vv    v31, v22, v27           \\n\\t\"\n\n#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4          \\\n    \"vle8.v       v4, (s1)                \\n\\t\" \\\n    \"addi         s1, s1, 128             \\n\\t\" \\\n    \"vle8.v       v5, (s2)                \\n\\t\" \\\n    \"addi         s2, s2, 128             \\n\\t\" \\\n    \"vle8.v       v6, (s3)                \\n\\t\" \\\n    \"addi         s3, s3, 128             \\n\\t\" \\\n    \"vle8.v       v7, (s4)                \\n\\t\" \\\n    \"addi         s4, s4, 128             \\n\\t\" \\\n    \"vsetvli      t0, zero, e8, mf4       \\n\\t\" \\\n    \"vle8.v       v14, (s5)               \\n\\t\" \\\n    \"addi         s5, s5, 16              \\n\\t\" \\\n    \"vle8.v       v15, (s6)               \\n\\t\" \\\n    \"addi         s6, s6, 16              \\n\\t\" \\\n    \"addi         t5, t5, -1              \\n\\t\" \\\n    \"vsetvli      t0, zero, e8, m1        \\n\\t\" \\\n    \"vand.vi      v0, v4, 15              \\n\\t\" \\\n    \"vand.vi      v1, v5, 15              \\n\\t\" \\\n    \"vand.vi      v2, v6, 15              \\n\\t\" \\\n    \"vand.vi      v3, v7, 15              \\n\\t\" \\\n    \"vsrl.vi      v4, v4, 4               \\n\\t\" \\\n    \"vsrl.vi      v5, v5, 4               \\n\\t\" \\\n    \"vsrl.vi      v6, v6, 4               \\n\\t\" \\\n    \"vsrl.vi      v7, v7, 4               \\n\\t\"\n\n#define SQ4BIT_KERNEL_LOAD_ZP_16X1              \\\n    \"vsetvli      t0, zero, e8, mf2       \\n\\t\" \\\n    \"vle8.v       v1, (s7)                \\n\\t\" \\\n    \"vsetvli      t0, zero, e8, m1        \\n\\t\" \\\n    \"vrgather.vv  v8, v1, v13             \\n\\t\" \\\n    \"vadd.vi      v13, v13, 4             \\n\\t\" \\\n    \"vrgather.vv  v9, v1, v13             \\n\\t\" \\\n    \"vadd.vi      v13, v13, 4             \\n\\t\" \\\n    \"vrgather.vv  v10, v1, v13      
      \\n\\t\" \\\n    \"vadd.vi      v13, v13, 4             \\n\\t\" \\\n    \"vrgather.vv  v11, v1, v13            \\n\\t\" \\\n    \"vadd.vi      v13, v13, -12           \\n\\t\"\n\n// using for M4Kernel\n#define LOAD_B_16x8x2                           \\\n    \"vsetvli      t0, zero, e8, m1        \\n\\t\" \\\n    \"vle8.v       v6, (s1)                \\n\\t\" \\\n    \"addi         s1, s1, 32*4            \\n\\t\" \\\n    \"vle8.v       v7, (s2)                \\n\\t\" \\\n    \"addi         s2, s2, 32*4            \\n\\t\" \\\n    \"vle8.v       v8, (s3)                \\n\\t\" \\\n    \"addi         s3, s3, 32*4            \\n\\t\" \\\n    \"vle8.v       v9, (s4)                \\n\\t\" \\\n    \"addi         s4, s4, 32*4            \\n\\t\" \\\n                                                \\\n    \"vand.vi      v2, v6, 15              \\n\\t\" \\\n    \"vand.vi      v3, v7, 15              \\n\\t\" \\\n    \"vand.vi      v4, v8, 15              \\n\\t\" \\\n    \"vand.vi      v5, v9, 15              \\n\\t\" \\\n                                                \\\n    \"vsrl.vi      v6, v6, 4               \\n\\t\" \\\n    \"vsrl.vi      v7, v7, 4               \\n\\t\" \\\n    \"vsrl.vi      v8, v8, 4               \\n\\t\" \\\n    \"vsrl.vi      v9, v9, 4               \\n\\t\"\n\n// [s2|s5, s3, s4, s6]\n#define LOAD_SCALE_4x16_FP16                    \\\n    \"addi         s2, s5, -8              \\n\\t\" \\\n    \"addi         s3, s5, 8               \\n\\t\" \\\n    \"addi         s4, s5, 16              \\n\\t\" \\\n    \"addi         s6, s5, 24              \\n\\t\" \\\n    \"li           t1, 0xf0                \\n\\t\" \\\n    \"vmv.s.x      v0, t1                  \\n\\t\" \\\n    \"vsetvli      t0, zero, e16, mf4      \\n\\t\" \\\n    \"vle16.v      v9, (s5)                \\n\\t\" \\\n    \"vle16.v      v11, (s3)               \\n\\t\" \\\n    \"vle16.v      v13, (s4)               \\n\\t\" \\\n    \"vle16.v      v15, (s6)               
\\n\\t\" \\\n    \"vsetvli      t0, zero, e16, mf2      \\n\\t\" \\\n    \"vle16.v      v9, (s2), v0.t          \\n\\t\" \\\n    \"vle16.v      v11, (s5), v0.t         \\n\\t\" \\\n    \"vle16.v      v13, (s3), v0.t         \\n\\t\" \\\n    \"vle16.v      v15, (s4), v0.t         \\n\\t\" \\\n    \"vfwcvt.f.f.v v8, v9                  \\n\\t\" \\\n    \"vfwcvt.f.f.v v10, v11                \\n\\t\" \\\n    \"vfwcvt.f.f.v v12, v13                \\n\\t\" \\\n    \"vfwcvt.f.f.v v14, v15                \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, m1       \\n\\t\" \\\n    \"vmv.v.v      v9, v8                  \\n\\t\" \\\n    \"vmv.v.v      v11, v10                \\n\\t\" \\\n    \"vmv.v.v      v13, v12                \\n\\t\" \\\n    \"vmv.v.v      v15, v14                \\n\\t\" \\\n    \"li           t1, 0xf0                \\n\\t\" \\\n    \"vmv.s.x      v0, t1                  \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, mf2      \\n\\t\" \\\n    \"vfmul.vf     v8, v8, f1              \\n\\t\" \\\n    \"vfmul.vf     v10, v10, f1            \\n\\t\" \\\n    \"vfmul.vf     v12, v12, f1            \\n\\t\" \\\n    \"vfmul.vf     v14, v14, f1            \\n\\t\" \\\n    \"vfmul.vf     v9, v9, f3              \\n\\t\" \\\n    \"vfmul.vf     v11, v11, f3            \\n\\t\" \\\n    \"vfmul.vf     v13, v13, f3            \\n\\t\" \\\n    \"vfmul.vf     v15, v15, f3            \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, m1       \\n\\t\" \\\n    \"vfmul.vf     v8, v8, f2, v0.t        \\n\\t\" \\\n    \"vfmul.vf     v10, v10, f2, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v12, v12, f2, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v14, v14, f2, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v9, v9, f4, v0.t        \\n\\t\" \\\n    \"vfmul.vf     v11, v11, f4, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v13, v13, f4, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v15, v15, f4, v0.t      \\n\\t\"\n\n// [s2|s5, s3, s4, s6]\n#define LOAD_SCALE_4x16                         \\\n    
\"addi         s2, s5, -16             \\n\\t\" \\\n    \"addi         s3, s5, 16              \\n\\t\" \\\n    \"addi         s4, s5, 32              \\n\\t\" \\\n    \"addi         s6, s5, 48              \\n\\t\" \\\n    \"li           t1, 0xf0                \\n\\t\" \\\n    \"vmv.s.x      v0, t1                  \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, mf2      \\n\\t\" \\\n    \"vle32.v      v8, (s5)                \\n\\t\" \\\n    \"vle32.v      v10, (s3)               \\n\\t\" \\\n    \"vle32.v      v12, (s4)               \\n\\t\" \\\n    \"vle32.v      v14, (s6)               \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, m1       \\n\\t\" \\\n    \"vle32.v      v8, (s2), v0.t          \\n\\t\" \\\n    \"vle32.v      v10, (s5), v0.t         \\n\\t\" \\\n    \"vle32.v      v12, (s3), v0.t         \\n\\t\" \\\n    \"vle32.v      v14, (s4), v0.t         \\n\\t\" \\\n    \"vmv.v.v      v9, v8                  \\n\\t\" \\\n    \"vmv.v.v      v11, v10                \\n\\t\" \\\n    \"vmv.v.v      v13, v12                \\n\\t\" \\\n    \"vmv.v.v      v15, v14                \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, mf2      \\n\\t\" \\\n    \"vfmul.vf     v8, v8, f1              \\n\\t\" \\\n    \"vfmul.vf     v10, v10, f1            \\n\\t\" \\\n    \"vfmul.vf     v12, v12, f1            \\n\\t\" \\\n    \"vfmul.vf     v14, v14, f1            \\n\\t\" \\\n    \"vfmul.vf     v9, v9, f3              \\n\\t\" \\\n    \"vfmul.vf     v11, v11, f3            \\n\\t\" \\\n    \"vfmul.vf     v13, v13, f3            \\n\\t\" \\\n    \"vfmul.vf     v15, v15, f3            \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, m1       \\n\\t\" \\\n    \"vfmul.vf     v8, v8, f2, v0.t        \\n\\t\" \\\n    \"vfmul.vf     v10, v10, f2, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v12, v12, f2, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v14, v14, f2, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v9, v9, f4, v0.t        \\n\\t\" \\\n    \"vfmul.vf     v11, v11, f4, v0.t      
\\n\\t\" \\\n    \"vfmul.vf     v13, v13, f4, v0.t      \\n\\t\" \\\n    \"vfmul.vf     v15, v15, f4, v0.t      \\n\\t\"\n\n//[s1| BIAS, s2, s3, s4]\n#define LOAD_BIAS                               \\\n    \"vsetvli      t0, zero, e32, mf2      \\n\\t\" \\\n    \"li           t1, 0xf0                \\n\\t\" \\\n    \"vmv.s.x      v0, t1                  \\n\\t\" \\\n    \"addi         s1, %[BIAS], -16        \\n\\t\" \\\n    \"addi         s2, %[BIAS], 16         \\n\\t\" \\\n    \"addi         s3, %[BIAS], 32         \\n\\t\" \\\n    \"addi         s4, %[BIAS], 48         \\n\\t\" \\\n                                                \\\n    \"vle32.v      v24, (%[BIAS])          \\n\\t\" \\\n    \"vle32.v      v26, (s2)               \\n\\t\" \\\n    \"vle32.v      v28, (s3)               \\n\\t\" \\\n    \"vle32.v      v30, (s4)               \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, m1       \\n\\t\" \\\n    \"vle32.v      v24, (s1), v0.t         \\n\\t\" \\\n    \"vle32.v      v26, (%[BIAS]), v0.t    \\n\\t\" \\\n    \"vle32.v      v28, (s2), v0.t         \\n\\t\" \\\n    \"vle32.v      v30, (s3), v0.t         \\n\\t\" \\\n    \"vmv.v.v      v25, v24                \\n\\t\" \\\n    \"vmv.v.v      v27, v26                \\n\\t\" \\\n    \"vmv.v.v      v29, v28                \\n\\t\" \\\n    \"vmv.v.v      v31, v30                \\n\\t\"\n\n#define SQ4BIT_KERNEL_COMP_4x16x16              \\\n    \"vmadot       v16, v10, v2            \\n\\t\" \\\n    \"vmadot       v18, v10, v3            \\n\\t\" \\\n    \"vmadot       v20, v10, v4            \\n\\t\" \\\n    \"vmadot       v22, v10, v5            \\n\\t\" \\\n    \"vmadot       v16, v11, v6            \\n\\t\" \\\n    \"vmadot       v18, v11, v7            \\n\\t\" \\\n    \"vmadot       v20, v11, v8            \\n\\t\" \\\n    \"vmadot       v22, v11, v9            \\n\\t\"\n\n#define SAVE_RESULT_4x16                        \\\n    \"addi         a1, %[C], 0             \\n\\t\" \\\n    \"add          
a2, %[C], %[LDC]        \\n\\t\" \\\n    \"add          a3, a2, %[LDC]          \\n\\t\" \\\n    \"add          a4, a3, %[LDC]          \\n\\t\" \\\n    \"addi         a2, a2, -16             \\n\\t\" \\\n    \"addi         a4, a4, -16             \\n\\t\" \\\n    \"li           t1, 0xf0                \\n\\t\" \\\n    \"vmv.s.x      v0, t1                  \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, mf2      \\n\\t\" \\\n                                                \\\n    \"vse32.v      v24, (a1)               \\n\\t\" \\\n    \"addi         a1, a1, 16              \\n\\t\" \\\n    \"vse32.v      v25, (a3)               \\n\\t\" \\\n    \"addi         a3, a3, 16              \\n\\t\" \\\n                                                \\\n    \"vse32.v      v26, (a1)               \\n\\t\" \\\n    \"addi         a1, a1, 16              \\n\\t\" \\\n    \"vse32.v      v27, (a3)               \\n\\t\" \\\n    \"addi         a3, a3, 16              \\n\\t\" \\\n                                                \\\n    \"vse32.v      v28, (a1)               \\n\\t\" \\\n    \"addi         a1, a1, 16              \\n\\t\" \\\n    \"vse32.v      v29, (a3)               \\n\\t\" \\\n    \"addi         a3, a3, 16              \\n\\t\" \\\n                                                \\\n    \"vse32.v      v30, (a1)               \\n\\t\" \\\n    \"vse32.v      v31, (a3)               \\n\\t\" \\\n    \"vsetvli      t0, zero, e32, m1       \\n\\t\" \\\n                                                \\\n    \"vse32.v      v24, (a2), v0.t         \\n\\t\" \\\n    \"addi         a2, a2, 16              \\n\\t\" \\\n    \"vse32.v      v25, (a4), v0.t         \\n\\t\" \\\n    \"addi         a4, a4, 16              \\n\\t\" \\\n                                                \\\n    \"vse32.v      v26, (a2), v0.t         \\n\\t\" \\\n    \"addi         a2, a2, 16              \\n\\t\" \\\n    \"vse32.v      v27, (a4), v0.t         \\n\\t\" \\\n    \"addi         a4, a4, 
16              \\n\\t\" \\\n                                                \\\n    \"vse32.v      v28, (a2), v0.t         \\n\\t\" \\\n    \"addi         a2, a2, 16              \\n\\t\" \\\n    \"vse32.v      v29, (a4), v0.t         \\n\\t\" \\\n    \"addi         a4, a4, 16              \\n\\t\" \\\n                                                \\\n    \"vse32.v      v30, (a2), v0.t         \\n\\t\" \\\n    \"vse32.v      v31, (a4), v0.t         \\n\\t\"\n\n#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2           \\\n    \"vsetvli      t0, zero, e8, mf2       \\n\\t\" \\\n    \"vle8.v       v11, (s6)               \\n\\t\" \\\n    \"vsetvli      t0, zero, e8, m1        \\n\\t\" \\\n    \"vrgather.vv  v12, v11, v1            \\n\\t\" \\\n    \"vadd.vi      v1, v1, 4               \\n\\t\" \\\n    \"vrgather.vv  v13, v11, v1            \\n\\t\" \\\n    \"vadd.vi      v1, v1, 4               \\n\\t\" \\\n    \"vrgather.vv  v14, v11, v1            \\n\\t\" \\\n    \"vadd.vi      v1, v1, 4               \\n\\t\" \\\n    \"vrgather.vv  v15, v11, v1            \\n\\t\" \\\n    \"vadd.vi      v1, v1, -12             \\n\\t\"\n\ntemplate <bool HasZeroPoint>\nvoid SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t            BlkLen,\n                                                const std::byte * QuantA,\n                                                const std::byte * QuantBData,\n                                                const float *     QuantBScale,\n                                                const std::byte * QuantBZeroPoint,\n                                                float *           C,\n                                                size_t            CountN,\n                                                size_t            BlockCountK,\n                                                const float *     Bias,\n                                                const size_t      ldc) {\n    GGML_UNUSED(QuantBScale);\n    GGML_UNUSED(QuantBZeroPoint);\n    
size_t       LDC   = ldc * sizeof(float);\n    const size_t INNER = BlkLen / 16;\n    float        tmp[4 * 16];\n\n    if constexpr (HasZeroPoint) {\n        for (size_t n = 0; n < CountN; n += 16) {\n            size_t      NBLKS         = (CountN - n) > 16 ? 16 : CountN - n;\n            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //\n                                        n * BlockCountK * BlkLen / 2 +       // b data\n                                        n * BlockCountK * sizeof(uint8_t) +  // zp\n                                        n * BlockCountK * sizeof(_Float16);    // scale\n            float * CPtr = C + n;\n            if (NBLKS < 16) {\n                CPtr = tmp;\n                LDC  = 16 * sizeof(float);\n            }\n            if (Bias != nullptr) {\n                const float * bias = Bias + n;\n                if (NBLKS < 16) {\n                    __asm__ volatile(\n                        \"vsetvli        t0, %[N], e32, m2     \\n\\t\"\n                        \"vle32.v        v0, (%[SRC])          \\n\\t\"\n                        \"vse32.v        v0, (%[DST])          \\n\\t\"\n                        :\n                        : [SRC] \"r\"(bias), [DST] \"r\"(tmp), [N] \"r\"(NBLKS)\n                        : \"cc\", \"t0\");\n                    bias = tmp;\n                }\n                __asm__ volatile(LOAD_BIAS\n\n                                 \"addi               t3, %[BlockCountK], 0       \\n\\t\"\n\n                                 \"vsetvli            t0, zero, e8, m1            \\n\\t\"\n                                 \"li                 s1, 24                      \\n\\t\"\n                                 \"vmv.v.i            v1, 3                       \\n\\t\"\n                                 \"vsetvli            t0, s1, e8, m1              \\n\\t\"\n                                 \"vmv.v.i            v1, 2                       \\n\\t\"\n                                 \"vsetvli  
          t0, zero, e8, mf2           \\n\\t\"\n                                 \"vmv.v.i            v1, 1                       \\n\\t\"\n                                 \"vsetvli            t0, zero, e8, mf4           \\n\\t\"\n                                 \"vmv.v.i            v1, 0                       \\n\\t\"\n\n                                 \"addi               a1, %[A], 0                 \\n\\t\"\n                                 \"addi               s1, %[B], 0                 \\n\\t\"\n\n                                 \"BLOCK_COUNTK_LOOP%=:                           \\n\\t\"\n                                 // scale offset\n                                 \"addi               s5, s1, 0                   \\n\\t\"\n                                 // zp offset\n                                 \"addi               s6, s1, 32                  \\n\\t\"\n                                 \"addi               s1, s6, 16                  \\n\\t\"\n                                 \"addi               s2, s1, 32                  \\n\\t\"\n                                 \"addi               s3, s1, 32*2                \\n\\t\"\n                                 \"addi               s4, s1, 32*3                \\n\\t\"\n\n                                 \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                                 \"vxor.vv            v16, v16, v16               \\n\\t\"\n                                 // load a scale\n                                 \"flw                f1, (a1)                    \\n\\t\"\n                                 \"flw                f2, 4(a1)                   \\n\\t\"\n                                 \"flw                f3, 8(a1)                   \\n\\t\"\n                                 \"flw                f4, 12(a1)                  \\n\\t\"\n                                 \"addi               a1, a1, 16                  \\n\\t\"\n                                 \"addi            
   t2, %[INNER], 0             \\n\\t\"\n\n                                 SQ4BIT_KERNEL_LOAD_ZP_16X1_v2\n\n                                 \"BLOCK_INNER_LOOP%=:                            \\n\\t\"\n\n                                 LOAD_B_16x8x2\n\n                                 \"vle8.v             v10, (a1)                   \\n\\t\"\n                                 \"addi               a1, a1, 32                  \\n\\t\"\n                                 \"vle8.v             v11, (a1)                   \\n\\t\"\n                                 \"addi               a1, a1, 32                  \\n\\t\"\n                                 \"vsub.vv            v2, v2, v12                 \\n\\t\"\n                                 \"vsub.vv            v6, v6, v12                 \\n\\t\"\n                                 \"vsub.vv            v3, v3, v13                 \\n\\t\"\n                                 \"vsub.vv            v7, v7, v13                 \\n\\t\"\n                                 \"vsub.vv            v4, v4, v14                 \\n\\t\"\n                                 \"vsub.vv            v8, v8, v14                 \\n\\t\"\n                                 \"vsub.vv            v5, v5, v15                 \\n\\t\"\n                                 \"vsub.vv            v9, v9, v15                 \\n\\t\"\n\n                                 SQ4BIT_KERNEL_COMP_4x16x16\n\n                                 \"addi               t2, t2, -1                  \\n\\t\"\n                                 \"bnez               t2, BLOCK_INNER_LOOP%=      \\n\\t\"\n\n                                 LOAD_SCALE_4x16_FP16\n\n                                 \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                                 \"vfcvt.f.x.v        v16, v16                    \\n\\t\"\n                                 \"vfmacc.vv          v24, v16, v8                \\n\\t\"\n                                 \"addi               t3, 
t3, -1                  \\n\\t\"\n                                 \"bnez               t3, BLOCK_COUNTK_LOOP%=     \\n\\t\"\n\n                                 \"RESULT_SAVE%=:                                 \\n\\t\"\n\n                                 SAVE_RESULT_4x16\n\n                                 :\n                                 : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [LDC] \"r\"(LDC),\n                                   [BlockCountK] \"r\"(BlockCountK), [C] \"r\"(CPtr), [BIAS] \"r\"(bias)\n                                 : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"a2\", \"a3\", \"a4\", \"f1\", \"f2\", \"f3\", \"f4\", \"s1\",\n                                   \"s2\", \"s3\", \"s4\", \"s5\", \"s6\");\n\n            } else {\n                __asm__ volatile(\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vxor.vv            v24, v24, v24               \\n\\t\"\n                    \"addi               t3, %[BlockCountK], 0       \\n\\t\"\n                    \"vsetvli            t0, zero, e8, m1            \\n\\t\"\n                    \"li                 s1, 24                      \\n\\t\"\n                    \"vmv.v.i            v1, 3                       \\n\\t\"\n                    \"vsetvli            t0, s1, e8, m1              \\n\\t\"\n                    \"vmv.v.i            v1, 2                       \\n\\t\"\n                    \"vsetvli            t0, zero, e8, mf2           \\n\\t\"\n                    \"vmv.v.i            v1, 1                       \\n\\t\"\n                    \"vsetvli            t0, zero, e8, mf4           \\n\\t\"\n                    \"vmv.v.i            v1, 0                       \\n\\t\"\n                    \"addi               a1, %[A], 0                 \\n\\t\"\n                    \"addi               s1, %[B], 0                 \\n\\t\"\n                    \"BLOCK_COUNTK_LOOP%=:                           
\\n\\t\"\n                    // scale offset\n                    \"addi               s5, s1, 0                   \\n\\t\"\n                    // zp offset\n                    \"addi               s6, s1, 32                  \\n\\t\"\n                    \"addi               s1, s6, 16                  \\n\\t\"\n                    \"addi               s2, s1, 32                  \\n\\t\"\n                    \"addi               s3, s1, 32*2                \\n\\t\"\n                    \"addi               s4, s1, 32*3                \\n\\t\"\n\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vxor.vv            v16, v16, v16               \\n\\t\"\n                    // load a scale\n                    \"flw                f1, (a1)                    \\n\\t\"\n                    \"flw                f2, 4(a1)                   \\n\\t\"\n                    \"flw                f3, 8(a1)                   \\n\\t\"\n                    \"flw                f4, 12(a1)                  \\n\\t\"\n                    \"addi               a1, a1, 16                  \\n\\t\"\n                    \"addi               t2, %[INNER], 0             \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_ZP_16X1_v2\n\n                    \"BLOCK_INNER_LOOP%=:                            \\n\\t\"\n\n                    LOAD_B_16x8x2\n\n                    \"vle8.v             v10, (a1)                   \\n\\t\"\n                    \"addi               a1, a1, 32                  \\n\\t\"\n                    \"vle8.v             v11, (a1)                   \\n\\t\"\n                    \"addi               a1, a1, 32                  \\n\\t\"\n                    \"vsub.vv            v2, v2, v12                 \\n\\t\"\n                    \"vsub.vv            v6, v6, v12                 \\n\\t\"\n                    \"vsub.vv            v3, v3, v13                 \\n\\t\"\n                    \"vsub.vv            v7, 
v7, v13                 \\n\\t\"\n                    \"vsub.vv            v4, v4, v14                 \\n\\t\"\n                    \"vsub.vv            v8, v8, v14                 \\n\\t\"\n                    \"vsub.vv            v5, v5, v15                 \\n\\t\"\n                    \"vsub.vv            v9, v9, v15                 \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_4x16x16\n\n                    \"addi               t2, t2, -1                  \\n\\t\"\n                    \"bnez               t2, BLOCK_INNER_LOOP%=      \\n\\t\"\n\n                    LOAD_SCALE_4x16_FP16\n\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vfcvt.f.x.v        v16, v16                    \\n\\t\"\n                    \"vfmacc.vv          v24, v16, v8                \\n\\t\"\n                    \"addi               t3, t3, -1                  \\n\\t\"\n                    \"bnez               t3, BLOCK_COUNTK_LOOP%=     \\n\\t\"\n\n                    \"RESULT_SAVE%=:                                 \\n\\t\"\n\n                    SAVE_RESULT_4x16\n\n                    :\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [LDC] \"r\"(LDC),\n                      [BlockCountK] \"r\"(BlockCountK), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"a2\", \"a3\", \"a4\", \"f1\", \"f2\", \"f3\", \"f4\", \"s1\", \"s2\", \"s3\",\n                      \"s4\", \"s5\", \"s6\");\n            }\n        }\n    } else {\n        for (size_t n = 0; n < CountN; n += 16) {\n            size_t      NBLKS         = (CountN - n) > 16 ? 
16 : CountN - n;\n            std::byte * QuantBDataPtr = (std::byte *) QuantBData +         //\n                                        n * BlockCountK * BlkLen / 2 +     // b data\n                                        n * BlockCountK * sizeof(_Float16);  // scale\n            float * CPtr = C + n;\n            if (NBLKS < 16) {\n                CPtr = tmp;\n                LDC  = 16 * sizeof(float);\n            }\n            if (Bias != nullptr) {\n                const float * bias = Bias + n;\n                if (NBLKS < 16) {\n                    __asm__ volatile(\n                        \"vsetvli        t0, %[N], e32, m2     \\n\\t\"\n                        \"vle32.v        v0, (%[SRC])          \\n\\t\"\n                        \"vse32.v        v0, (%[DST])          \\n\\t\"\n                        :\n                        : [SRC] \"r\"(bias), [DST] \"r\"(tmp), [N] \"r\"(NBLKS)\n                        : \"cc\", \"t0\");\n                    bias = tmp;\n                }\n                __asm__ volatile(LOAD_BIAS\n\n                                 \"addi               t3, %[BlockCountK], 0       \\n\\t\"\n                                 \"addi               a1, %[A], 0                 \\n\\t\"\n                                 \"addi               s1, %[B], 0                 \\n\\t\"\n                                 \"BLOCK_COUNTK_LOOP%=:                           \\n\\t\"\n                                 \"addi               s5, s1, 0                   \\n\\t\"\n                                 \"addi               s1, s5, 32                  \\n\\t\"\n                                 \"addi               s2, s1, 32                  \\n\\t\"\n                                 \"addi               s3, s1, 32*2                \\n\\t\"\n                                 \"addi               s4, s1, 32*3                \\n\\t\"\n                                 \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                             
    \"vxor.vv            v16, v16, v16               \\n\\t\"\n                                 // load a scale\n                                 \"flw                f1, (a1)                    \\n\\t\"\n                                 \"flw                f2, 4(a1)                   \\n\\t\"\n                                 \"flw                f3, 8(a1)                   \\n\\t\"\n                                 \"flw                f4, 12(a1)                  \\n\\t\"\n                                 \"addi               a1, a1, 16                  \\n\\t\"\n                                 \"addi               t2, %[INNER], 0             \\n\\t\"\n                                 \"BLOCK_INNER_LOOP%=:                            \\n\\t\"\n\n                                 LOAD_B_16x8x2\n\n                                 \"vsetvli            t0, zero, e8, m1            \\n\\t\"\n                                 \"vle8.v             v10, (a1)                   \\n\\t\"\n                                 \"addi               a1, a1, 32                  \\n\\t\"\n                                 \"vle8.v             v11, (a1)                   \\n\\t\"\n                                 \"addi               a1, a1, 32                  \\n\\t\"\n                                 \"vadd.vi            v2, v2, -8                  \\n\\t\"\n                                 \"vadd.vi            v3, v3, -8                  \\n\\t\"\n                                 \"vadd.vi            v4, v4, -8                  \\n\\t\"\n                                 \"vadd.vi            v5, v5, -8                  \\n\\t\"\n                                 \"vadd.vi            v6, v6, -8                  \\n\\t\"\n                                 \"vadd.vi            v7, v7, -8                  \\n\\t\"\n                                 \"vadd.vi            v8, v8, -8                  \\n\\t\"\n                                 \"vadd.vi            v9, v9, -8                  
\\n\\t\"\n\n                                 SQ4BIT_KERNEL_COMP_4x16x16\n\n                                 \"addi               t2, t2, -1                  \\n\\t\"\n                                 \"bnez               t2, BLOCK_INNER_LOOP%=      \\n\\t\"\n\n                                 LOAD_SCALE_4x16_FP16\n\n                                 \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                                 \"vfcvt.f.x.v        v16, v16                    \\n\\t\"\n                                 \"vfmacc.vv          v24, v16, v8                \\n\\t\"\n                                 \"addi               t3, t3, -1                  \\n\\t\"\n                                 \"bnez               t3, BLOCK_COUNTK_LOOP%=     \\n\\t\"\n                                 \"RESULT_SAVE%=:                                 \\n\\t\"\n\n                                 SAVE_RESULT_4x16\n\n                                 :\n                                 : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [LDC] \"r\"(LDC),\n                                   [BlockCountK] \"r\"(BlockCountK), [C] \"r\"(CPtr), [BIAS] \"r\"(bias)\n                                 : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"a2\", \"a3\", \"a4\", \"f1\", \"f2\", \"f3\", \"f4\", \"s1\",\n                                   \"s2\", \"s3\", \"s4\", \"s5\", \"s6\");\n\n            } else {\n                __asm__ volatile(\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vxor.vv            v24, v24, v24               \\n\\t\"\n                    \"addi               t3, %[BlockCountK], 0       \\n\\t\"\n                    \"addi               a1, %[A], 0                 \\n\\t\"\n                    \"addi               s1, %[B], 0                 \\n\\t\"\n                    \"BLOCK_COUNTK_LOOP%=:                           \\n\\t\"\n                    \"addi               s5, s1, 0    
               \\n\\t\"\n                    \"addi               s1, s5, 32                  \\n\\t\"\n                    \"addi               s2, s1, 32                  \\n\\t\"\n                    \"addi               s3, s1, 32*2                \\n\\t\"\n                    \"addi               s4, s1, 32*3                \\n\\t\"\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vxor.vv            v16, v16, v16               \\n\\t\"\n                    // load a scale\n                    \"flw                f1, (a1)                    \\n\\t\"\n                    \"flw                f2, 4(a1)                   \\n\\t\"\n                    \"flw                f3, 8(a1)                   \\n\\t\"\n                    \"flw                f4, 12(a1)                  \\n\\t\"\n                    \"addi               a1, a1, 16                  \\n\\t\"\n                    \"addi               t2, %[INNER], 0             \\n\\t\"\n                    \"BLOCK_INNER_LOOP%=:                            \\n\\t\"\n\n                    LOAD_B_16x8x2\n\n                    \"vsetvli            t0, zero, e8, m1            \\n\\t\"\n                    \"vle8.v             v10, (a1)                   \\n\\t\"\n                    \"addi               a1, a1, 32                  \\n\\t\"\n                    \"vle8.v             v11, (a1)                   \\n\\t\"\n                    \"addi               a1, a1, 32                  \\n\\t\"\n                    \"vadd.vi            v2, v2, -8                  \\n\\t\"\n                    \"vadd.vi            v3, v3, -8                  \\n\\t\"\n                    \"vadd.vi            v4, v4, -8                  \\n\\t\"\n                    \"vadd.vi            v5, v5, -8                  \\n\\t\"\n                    \"vadd.vi            v6, v6, -8                  \\n\\t\"\n                    \"vadd.vi            v7, v7, -8                  \\n\\t\"\n   
                 \"vadd.vi            v8, v8, -8                  \\n\\t\"\n                    \"vadd.vi            v9, v9, -8                  \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_4x16x16\n\n                    \"addi               t2, t2, -1                  \\n\\t\"\n                    \"bnez               t2, BLOCK_INNER_LOOP%=      \\n\\t\"\n\n                    LOAD_SCALE_4x16_FP16\n\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vfcvt.f.x.v        v16, v16                    \\n\\t\"\n                    \"vfmacc.vv          v24, v16, v8                \\n\\t\"\n                    \"addi               t3, t3, -1                  \\n\\t\"\n                    \"bnez               t3, BLOCK_COUNTK_LOOP%=     \\n\\t\"\n                    \"RESULT_SAVE%=:                                 \\n\\t\"\n\n                    SAVE_RESULT_4x16\n\n                    :\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [LDC] \"r\"(LDC),\n                      [BlockCountK] \"r\"(BlockCountK), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"a2\", \"a3\", \"a4\", \"f1\", \"f2\", \"f3\", \"f4\", \"s1\", \"s2\", \"s3\",\n                      \"s4\", \"s5\", \"s6\");\n            }\n        }\n    }\n    if (CountN % 16 != 0) {\n        // store output from tmp to C when NBLKS is less than 16.\n        float *      CPtr = C + CountN / 16 * 16;\n        const size_t N    = CountN % 16;\n        LDC               = ldc * sizeof(float);\n        __asm__ volatile(\n            \"vsetvli            t0, %[N], e32, m2       \\n\\t\"\n            \"vle32.v            v0, (%[SRC])            \\n\\t\"\n            \"addi               s2, %[SRC], 64          \\n\\t\"\n            \"addi               s3, %[SRC], 64*2        \\n\\t\"\n            \"addi               s4, %[SRC], 64*3        \\n\\t\"\n            \"vle32.v            v2, 
(s2)                \\n\\t\"\n            \"vle32.v            v4, (s3)                \\n\\t\"\n            \"vle32.v            v6, (s4)                \\n\\t\"\n            \"add                t2, %[DST], %[LDC]      \\n\\t\"\n            \"add                t3, t2, %[LDC]          \\n\\t\"\n            \"add                t4, t3, %[LDC]          \\n\\t\"\n            \"vse32.v            v0, (%[DST])            \\n\\t\"\n            \"vse32.v            v2, (t2)                \\n\\t\"\n            \"vse32.v            v4, (t3)                \\n\\t\"\n            \"vse32.v            v6, (t4)                \\n\\t\"\n            :\n            : [N] \"r\"(N), [SRC] \"r\"(tmp), [DST] \"r\"(CPtr), [LDC] \"r\"(LDC)\n            : \"cc\", \"t0\", \"t2\", \"t3\", \"t4\", \"s2\", \"s3\", \"s4\");\n    }\n}\n\ntemplate <bool HasZeroPoint>\nvoid SQ4BitGemmM4Kernel_CompInt8_Impl(size_t            BlkLen,\n                                      const std::byte * QuantA,\n                                      const std::byte * QuantBData,\n                                      const float *     QuantBScale,\n                                      const std::byte * QuantBZeroPoint,\n                                      float *           C,\n                                      size_t            CountN,\n                                      size_t            BlockCountK,\n                                      const float *     Bias,\n                                      const size_t      ldc) {\n    GGML_UNUSED(QuantBScale);\n    GGML_UNUSED(QuantBZeroPoint);\n    size_t       LDC   = ldc * sizeof(float);\n    const size_t INNER = BlkLen / 16;\n    float        tmp[4 * 16];\n\n    if constexpr (HasZeroPoint) {\n        for (size_t n = 0; n < CountN; n += 16) {\n            size_t      NBLKS         = (CountN - n) > 16 ? 
16 : CountN - n;\n            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //\n                                        n * BlockCountK * BlkLen / 2 +       // b data\n                                        n * BlockCountK * sizeof(uint8_t) +  // zp\n                                        n * BlockCountK * sizeof(float);     // scale\n            float * CPtr = C + n;\n            if (NBLKS < 16) {\n                CPtr = tmp;\n                LDC  = 16 * sizeof(float);\n            }\n            if (Bias != nullptr) {\n                const float * bias = Bias + n;\n                if (NBLKS < 16) {\n                    __asm__ volatile(\n                        \"vsetvli        t0, %[N], e32, m2     \\n\\t\"\n                        \"vle32.v        v0, (%[SRC])          \\n\\t\"\n                        \"vse32.v        v0, (%[DST])          \\n\\t\"\n                        :\n                        : [SRC] \"r\"(bias), [DST] \"r\"(tmp), [N] \"r\"(NBLKS)\n                        : \"cc\", \"t0\");\n                    bias = tmp;\n                }\n\n                __asm__ volatile(LOAD_BIAS\n                                 \"addi               t3, %[BlockCountK], 0       \\n\\t\"\n                                 \"vsetvli            t0, zero, e8, m1            \\n\\t\"\n                                 \"li                 s1, 24                      \\n\\t\"\n                                 \"vmv.v.i            v1, 3                       \\n\\t\"\n                                 \"vsetvli            t0, s1, e8, m1              \\n\\t\"\n                                 \"vmv.v.i            v1, 2                       \\n\\t\"\n                                 \"vsetvli            t0, zero, e8, mf2           \\n\\t\"\n                                 \"vmv.v.i            v1, 1                       \\n\\t\"\n                                 \"vsetvli            t0, zero, e8, mf4           \\n\\t\"\n                                 
\"vmv.v.i            v1, 0                       \\n\\t\"\n                                 \"addi               a1, %[A], 0                 \\n\\t\"\n                                 \"addi               s1, %[B], 0                 \\n\\t\"\n                                 \"BLOCK_COUNTK_LOOP%=:                           \\n\\t\"\n                                 // scale offset\n                                 \"addi               s5, s1, 0                   \\n\\t\"\n                                 // zp offset\n                                 \"addi               s6, s1, 64                  \\n\\t\"\n                                 \"addi               s1, s6, 16                  \\n\\t\"\n                                 \"addi               s2, s1, 32                  \\n\\t\"\n                                 \"addi               s3, s1, 32*2                \\n\\t\"\n                                 \"addi               s4, s1, 32*3                \\n\\t\"\n                                 \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                                 \"vxor.vv            v16, v16, v16               \\n\\t\"\n                                 // load a scale\n                                 \"flw                f1, (a1)                    \\n\\t\"\n                                 \"flw                f2, 4(a1)                   \\n\\t\"\n                                 \"flw                f3, 8(a1)                   \\n\\t\"\n                                 \"flw                f4, 12(a1)                  \\n\\t\"\n                                 \"addi               a1, a1, 16                  \\n\\t\"\n                                 \"addi               t2, %[INNER], 0             \\n\\t\"\n\n                                 SQ4BIT_KERNEL_LOAD_ZP_16X1_v2\n\n                                 \"BLOCK_INNER_LOOP%=:                            \\n\\t\"\n\n                                 LOAD_B_16x8x2\n\n                  
               \"vle8.v             v10, (a1)                   \\n\\t\"\n                                 \"addi               a1, a1, 32                  \\n\\t\"\n                                 \"vle8.v             v11, (a1)                   \\n\\t\"\n                                 \"addi               a1, a1, 32                  \\n\\t\"\n                                 \"vsub.vv            v2, v2, v12                 \\n\\t\"\n                                 \"vsub.vv            v6, v6, v12                 \\n\\t\"\n                                 \"vsub.vv            v3, v3, v13                 \\n\\t\"\n                                 \"vsub.vv            v7, v7, v13                 \\n\\t\"\n                                 \"vsub.vv            v4, v4, v14                 \\n\\t\"\n                                 \"vsub.vv            v8, v8, v14                 \\n\\t\"\n                                 \"vsub.vv            v5, v5, v15                 \\n\\t\"\n                                 \"vsub.vv            v9, v9, v15                 \\n\\t\"\n\n                                 SQ4BIT_KERNEL_COMP_4x16x16\n\n                                 \"addi               t2, t2, -1                  \\n\\t\"\n                                 \"bnez               t2, BLOCK_INNER_LOOP%=      \\n\\t\"\n\n                                 LOAD_SCALE_4x16\n\n                                 \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                                 \"vfcvt.f.x.v        v16, v16                    \\n\\t\"\n                                 \"vfmacc.vv          v24, v16, v8                \\n\\t\"\n                                 \"addi               t3, t3, -1                  \\n\\t\"\n                                 \"bnez               t3, BLOCK_COUNTK_LOOP%=     \\n\\t\"\n\n                                 \"RESULT_SAVE%=:                                 \\n\\t\"\n\n                                 SAVE_RESULT_4x16\n\n  
                               :\n                                 : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [LDC] \"r\"(LDC),\n                                   [BlockCountK] \"r\"(BlockCountK), [C] \"r\"(CPtr), [BIAS] \"r\"(bias)\n                                 : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"a2\", \"a3\", \"a4\", \"f1\", \"f2\", \"f3\", \"f4\", \"s1\",\n                                   \"s2\", \"s3\", \"s4\", \"s5\", \"s6\");\n\n            } else {\n                __asm__ volatile(\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vxor.vv            v24, v24, v24               \\n\\t\"\n                    \"addi               t3, %[BlockCountK], 0       \\n\\t\"\n                    \"vsetvli            t0, zero, e8, m1            \\n\\t\"\n                    \"li                 s1, 24                      \\n\\t\"\n                    \"vmv.v.i            v1, 3                       \\n\\t\"\n                    \"vsetvli            t0, s1, e8, m1              \\n\\t\"\n                    \"vmv.v.i            v1, 2                       \\n\\t\"\n                    \"vsetvli            t0, zero, e8, mf2           \\n\\t\"\n                    \"vmv.v.i            v1, 1                       \\n\\t\"\n                    \"vsetvli            t0, zero, e8, mf4           \\n\\t\"\n                    \"vmv.v.i            v1, 0                       \\n\\t\"\n                    \"addi               a1, %[A], 0                 \\n\\t\"\n                    \"addi               s1, %[B], 0                 \\n\\t\"\n                    \"BLOCK_COUNTK_LOOP%=:                           \\n\\t\"\n                    // scale offset\n                    \"addi               s5, s1, 0                   \\n\\t\"\n                    // zp offset\n                    \"addi               s6, s1, 64                  \\n\\t\"\n                    \"addi               
s1, s6, 16                  \\n\\t\"\n                    \"addi               s2, s1, 32                  \\n\\t\"\n                    \"addi               s3, s1, 32*2                \\n\\t\"\n                    \"addi               s4, s1, 32*3                \\n\\t\"\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vxor.vv            v16, v16, v16               \\n\\t\"\n                    // load a scale\n                    // load a scale\n                    \"flw                f1, (a1)                    \\n\\t\"\n                    \"flw                f2, 4(a1)                   \\n\\t\"\n                    \"flw                f3, 8(a1)                   \\n\\t\"\n                    \"flw                f4, 12(a1)                  \\n\\t\"\n                    \"addi               a1, a1, 16                  \\n\\t\"\n                    \"addi               t2, %[INNER], 0             \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_ZP_16X1_v2\n\n                    \"BLOCK_INNER_LOOP%=:                            \\n\\t\"\n\n                    LOAD_B_16x8x2\n\n                    \"vle8.v             v10, (a1)                   \\n\\t\"\n                    \"addi               a1, a1, 32                  \\n\\t\"\n                    \"vle8.v             v11, (a1)                   \\n\\t\"\n                    \"addi               a1, a1, 32                  \\n\\t\"\n                    \"vsub.vv            v2, v2, v12                 \\n\\t\"\n                    \"vsub.vv            v6, v6, v12                 \\n\\t\"\n                    \"vsub.vv            v3, v3, v13                 \\n\\t\"\n                    \"vsub.vv            v7, v7, v13                 \\n\\t\"\n                    \"vsub.vv            v4, v4, v14                 \\n\\t\"\n                    \"vsub.vv            v8, v8, v14                 \\n\\t\"\n                    \"vsub.vv            v5, v5, v15    
             \\n\\t\"\n                    \"vsub.vv            v9, v9, v15                 \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_4x16x16\n\n                    \"addi               t2, t2, -1                  \\n\\t\"\n                    \"bnez               t2, BLOCK_INNER_LOOP%=      \\n\\t\"\n\n                    LOAD_SCALE_4x16\n\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vfcvt.f.x.v        v16, v16                    \\n\\t\"\n                    \"vfmacc.vv          v24, v16, v8                \\n\\t\"\n                    \"addi               t3, t3, -1                  \\n\\t\"\n                    \"bnez               t3, BLOCK_COUNTK_LOOP%=     \\n\\t\"\n\n                    \"RESULT_SAVE%=:                                 \\n\\t\"\n\n                    SAVE_RESULT_4x16\n\n                    :\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [LDC] \"r\"(LDC),\n                      [BlockCountK] \"r\"(BlockCountK), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"a2\", \"a3\", \"a4\", \"f1\", \"f2\", \"f3\", \"f4\", \"s1\", \"s2\", \"s3\",\n                      \"s4\", \"s5\", \"s6\");\n            }\n        }\n    } else {\n        for (size_t n = 0; n < CountN; n += 16) {\n            size_t      NBLKS         = (CountN - n) > 16 ? 
16 : CountN - n;\n            std::byte * QuantBDataPtr = (std::byte *) QuantBData +        //\n                                        n * BlockCountK * BlkLen / 2 +    // b data\n                                        n * BlockCountK * sizeof(float);  // scale\n            float * CPtr = C + n;\n            if (NBLKS < 16) {\n                CPtr = tmp;\n                LDC  = 16 * sizeof(float);\n            }\n            if (Bias != nullptr) {\n                const float * bias = Bias + n;\n                if (NBLKS < 16) {\n                    __asm__ volatile(\n                        \"vsetvli        t0, %[N], e32, m2     \\n\\t\"\n                        \"vle32.v        v0, (%[SRC])          \\n\\t\"\n                        \"vse32.v        v0, (%[DST])          \\n\\t\"\n                        :\n                        : [SRC] \"r\"(bias), [DST] \"r\"(tmp), [N] \"r\"(NBLKS)\n                        : \"cc\", \"t0\");\n                    bias = tmp;\n                }\n                __asm__ volatile(LOAD_BIAS\n                                 \"addi               t3, %[BlockCountK], 0       \\n\\t\"\n                                 \"addi               a1, %[A], 0                 \\n\\t\"\n                                 \"addi               s1, %[B], 0                 \\n\\t\"\n                                 \"BLOCK_COUNTK_LOOP%=:                           \\n\\t\"\n                                 \"addi               s5, s1, 0                   \\n\\t\"\n                                 \"addi               s1, s5, 64                  \\n\\t\"\n                                 \"addi               s2, s1, 32                  \\n\\t\"\n                                 \"addi               s3, s1, 32*2                \\n\\t\"\n                                 \"addi               s4, s1, 32*3                \\n\\t\"\n                                 \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                                 
\"vxor.vv            v16, v16, v16               \\n\\t\"\n                                 // load a scale\n                                 \"flw                f1, (a1)                    \\n\\t\"\n                                 \"flw                f2, 4(a1)                   \\n\\t\"\n                                 \"flw                f3, 8(a1)                   \\n\\t\"\n                                 \"flw                f4, 12(a1)                  \\n\\t\"\n                                 \"addi               a1, a1, 16                  \\n\\t\"\n                                 \"addi               t2, %[INNER], 0             \\n\\t\"\n                                 \"BLOCK_INNER_LOOP%=:                            \\n\\t\"\n\n                                 LOAD_B_16x8x2\n\n                                 \"vsetvli            t0, zero, e8, m1            \\n\\t\"\n                                 \"vle8.v             v10, (a1)                   \\n\\t\"\n                                 \"addi               a1, a1, 32                  \\n\\t\"\n                                 \"vle8.v             v11, (a1)                   \\n\\t\"\n                                 \"addi               a1, a1, 32                  \\n\\t\"\n                                 \"vadd.vi            v2, v2, -8                  \\n\\t\"\n                                 \"vadd.vi            v3, v3, -8                  \\n\\t\"\n                                 \"vadd.vi            v4, v4, -8                  \\n\\t\"\n                                 \"vadd.vi            v5, v5, -8                  \\n\\t\"\n                                 \"vadd.vi            v6, v6, -8                  \\n\\t\"\n                                 \"vadd.vi            v7, v7, -8                  \\n\\t\"\n                                 \"vadd.vi            v8, v8, -8                  \\n\\t\"\n                                 \"vadd.vi            v9, v9, -8                  
\\n\\t\"\n\n                                 SQ4BIT_KERNEL_COMP_4x16x16\n\n                                 \"addi               t2, t2, -1                  \\n\\t\"\n                                 \"bnez               t2, BLOCK_INNER_LOOP%=      \\n\\t\"\n\n                                 LOAD_SCALE_4x16\n\n                                 \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                                 \"vfcvt.f.x.v        v16, v16                    \\n\\t\"\n                                 \"vfmacc.vv          v24, v16, v8                \\n\\t\"\n                                 \"addi               t3, t3, -1                  \\n\\t\"\n                                 \"bnez               t3, BLOCK_COUNTK_LOOP%=     \\n\\t\"\n\n                                 \"RESULT_SAVE%=:                                 \\n\\t\"\n\n                                 SAVE_RESULT_4x16\n\n                                 :\n                                 : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [LDC] \"r\"(LDC),\n                                   [BlockCountK] \"r\"(BlockCountK), [C] \"r\"(CPtr), [BIAS] \"r\"(bias)\n                                 : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"a2\", \"a3\", \"a4\", \"f1\", \"f2\", \"f3\", \"f4\", \"s1\",\n                                   \"s2\", \"s3\", \"s4\", \"s5\", \"s6\");\n\n            } else {\n                __asm__ volatile(\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vxor.vv            v24, v24, v24               \\n\\t\"\n                    \"addi               t3, %[BlockCountK], 0       \\n\\t\"\n                    \"addi               a1, %[A], 0                 \\n\\t\"\n                    \"addi               s1, %[B], 0                 \\n\\t\"\n                    \"BLOCK_COUNTK_LOOP%=:                           \\n\\t\"\n                    \"addi               s5, s1, 0       
            \\n\\t\"\n                    \"addi               s1, s5, 64                  \\n\\t\"\n                    \"addi               s2, s1, 32                  \\n\\t\"\n                    \"addi               s3, s1, 32*2                \\n\\t\"\n                    \"addi               s4, s1, 32*3                \\n\\t\"\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vxor.vv            v16, v16, v16               \\n\\t\"\n                    // load a scale\n                    \"flw                f1, (a1)                    \\n\\t\"\n                    \"flw                f2, 4(a1)                   \\n\\t\"\n                    \"flw                f3, 8(a1)                   \\n\\t\"\n                    \"flw                f4, 12(a1)                  \\n\\t\"\n                    \"addi               a1, a1, 16                  \\n\\t\"\n                    \"addi               t2, %[INNER], 0             \\n\\t\"\n                    \"BLOCK_INNER_LOOP%=:                            \\n\\t\"\n\n                    LOAD_B_16x8x2\n\n                    \"vsetvli            t0, zero, e8, m1            \\n\\t\"\n                    \"vle8.v             v10, (a1)                   \\n\\t\"\n\n                    \"addi               a1, a1, 32                  \\n\\t\"\n                    \"vle8.v             v11, (a1)                   \\n\\t\"\n                    \"addi               a1, a1, 32                  \\n\\t\"\n                    \"vadd.vi            v2, v2, -8                  \\n\\t\"\n                    \"vadd.vi            v3, v3, -8                  \\n\\t\"\n                    \"vadd.vi            v4, v4, -8                  \\n\\t\"\n                    \"vadd.vi            v5, v5, -8                  \\n\\t\"\n                    \"vadd.vi            v6, v6, -8                  \\n\\t\"\n                    \"vadd.vi            v7, v7, -8                  \\n\\t\"\n    
                \"vadd.vi            v8, v8, -8                  \\n\\t\"\n                    \"vadd.vi            v9, v9, -8                  \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_4x16x16\n\n                    \"addi               t2, t2, -1                  \\n\\t\"\n                    \"bnez               t2, BLOCK_INNER_LOOP%=      \\n\\t\"\n\n                    LOAD_SCALE_4x16\n\n                    \"vsetvli            t0, zero, e32, m8           \\n\\t\"\n                    \"vfcvt.f.x.v        v16, v16                    \\n\\t\"\n                    \"vfmacc.vv          v24, v16, v8                \\n\\t\"\n                    \"addi               t3, t3, -1                  \\n\\t\"\n                    \"bnez               t3, BLOCK_COUNTK_LOOP%=     \\n\\t\"\n\n                    \"RESULT_SAVE%=:                                 \\n\\t\"\n\n                    SAVE_RESULT_4x16\n\n                    :\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [LDC] \"r\"(LDC),\n                      [BlockCountK] \"r\"(BlockCountK), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t1\", \"t2\", \"t3\", \"a1\", \"a2\", \"a3\", \"a4\", \"f1\", \"f2\", \"f3\", \"f4\", \"s1\", \"s2\", \"s3\",\n                      \"s4\", \"s5\", \"s6\");\n            }\n        }\n    }\n    if (CountN % 16 != 0) {\n        // stroe output from tmp to C when NBLKS less than 16.\n        float *      CPtr = C + CountN / 16 * 16;\n        const size_t N    = CountN % 16;\n        LDC               = ldc * sizeof(float);\n        __asm__ volatile(\n            \"vsetvli            t0, %[N], e32, m2       \\n\\t\"\n            \"vle32.v            v0, (%[SRC])            \\n\\t\"\n            \"addi               s2, %[SRC], 64          \\n\\t\"\n            \"addi               s3, %[SRC], 64*2        \\n\\t\"\n            \"addi               s4, %[SRC], 64*3        \\n\\t\"\n            \"vle32.v            v2, (s2)  
              \\n\\t\"\n            \"vle32.v            v4, (s3)                \\n\\t\"\n            \"vle32.v            v6, (s4)                \\n\\t\"\n            \"add                t2, %[DST], %[LDC]      \\n\\t\"\n            \"add                t3, t2, %[LDC]          \\n\\t\"\n            \"add                t4, t3, %[LDC]          \\n\\t\"\n            \"vse32.v            v0, (%[DST])            \\n\\t\"\n            \"vse32.v            v2, (t2)                \\n\\t\"\n            \"vse32.v            v4, (t3)                \\n\\t\"\n            \"vse32.v            v6, (t4)                \\n\\t\"\n            :\n            : [N] \"r\"(N), [SRC] \"r\"(tmp), [DST] \"r\"(CPtr), [LDC] \"r\"(LDC)\n            : \"cc\", \"t0\", \"t2\", \"t3\", \"t4\", \"s2\", \"s3\", \"s4\");\n    }\n}\n\ntemplate <bool HasZeroPoint>\nvoid SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t            BlkLen,\n                                                const std::byte * QuantA,\n                                                const std::byte * QuantBData,\n                                                const float *     QuantBScale,\n                                                const std::byte * QuantBZeroPoint,\n                                                float *           C,\n                                                size_t            CountN,\n                                                size_t            BlockCountK,\n                                                const float *     Bias) {\n    GGML_UNUSED(QuantBScale);\n    GGML_UNUSED(QuantBZeroPoint);\n    size_t INNER = BlkLen / 16;\n\n    if constexpr (HasZeroPoint) {\n        for (size_t n = 0; n < CountN; n += 16) {\n            size_t      nblks         = (CountN - n) > 16 ? 
16 : CountN - n;\n            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //\n                                        n * BlockCountK * BlkLen / 2 +       // b data\n                                        n * BlockCountK * sizeof(uint8_t) +  // zp\n                                        n * BlockCountK * sizeof(_Float16);    // scale\n            float * CPtr = C + n;\n            size_t  cnt  = BlockCountK;\n            if (Bias != nullptr) {\n                const float * bias = Bias + n;\n                __asm__ volatile(\n                    \"addi         t3, %[NBLKS], 0         \\n\\t\"\n                    \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n\n                    \"vmv.v.i      v13, 3                  \\n\\t\"\n                    \"li           s1, 24                  \\n\\t\"\n                    \"vsetvli      t0, s1, e8, m1          \\n\\t\"\n                    \"vmv.v.i      v13, 2                  \\n\\t\"\n                    \"vsetvli      t0, zero, e8, mf2       \\n\\t\"\n                    \"vmv.v.i      v13, 1                  \\n\\t\"\n                    \"vsetvli      t0, zero, e8, mf4       \\n\\t\"\n                    \"vmv.v.i      v13, 0                  \\n\\t\"\n                    \"addi         s1, %[B], 0             \\n\\t\"\n                    \"addi         s2, %[B], 8             \\n\\t\"\n                    \"addi         s3, %[B], 16            \\n\\t\"\n                    \"addi         s4, %[B], 24            \\n\\t\"\n                    // zp offset\n                    \"addi         s7, %[B], 32            \\n\\t\"\n                    // a offset\n                    \"addi         s5, %[A], 0             \\n\\t\"\n                    \"addi         s6, %[A], 12            \\n\\t\"\n\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v28, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              
\\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v29, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v30, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v31, (%[BIAS])          \\n\\t\"\n\n                    \"LOOP_K%=:                            \\n\\t\"\n                    \"vsetvli      t0, zero, e16, mf4      \\n\\t\"\n\n                    \"vle16.v      v4, (s1)                \\n\\t\"\n                    \"addi         s1, s1, 48              \\n\\t\"\n                    \"vle16.v      v5, (s2)                \\n\\t\"\n                    \"addi         s2, s2, 72              \\n\\t\"\n                    \"vle16.v      v6, (s3)                \\n\\t\"\n                    \"addi         s3, s3, 96              \\n\\t\"\n                    \"vle16.v      v7, (s4)                \\n\\t\"\n                    \"addi         s4, s4, 120             \\n\\t\"\n                    \"flw          f1, (s5)                \\n\\t\"\n                    \"addi         s5, s5, 4               \\n\\t\"\n                    \"vfwcvt.f.f.v v8, v4                  \\n\\t\"\n                    \"vfwcvt.f.f.v v9, v5                  \\n\\t\"\n                    \"vfwcvt.f.f.v v10, v6                 \\n\\t\"\n                    \"vfwcvt.f.f.v v11, v7                 \\n\\t\"\n\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n                    \"addi         t5, %[INNER], 0  
       \\n\\t\"\n                    \"vxor.vv      v16, v16, v16           \\n\\t\"\n                    \"vxor.vv      v18, v18, v18           \\n\\t\"\n                    \"vxor.vv      v20, v20, v20           \\n\\t\"\n                    \"vxor.vv      v22, v22, v22           \\n\\t\"\n                    \"vfmul.vf     v24, v8, f1             \\n\\t\"\n                    \"vfmul.vf     v25, v9, f1             \\n\\t\"\n                    \"vfmul.vf     v26, v10, f1            \\n\\t\"\n                    \"vfmul.vf     v27, v11, f1            \\n\\t\"\n                    \"addi         %[CNT], %[CNT], -1      \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_ZP_16X1\n\n                    \"LOOP_INNER%=:                        \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4\n\n                    \"vsub.vv      v0, v0, v8              \\n\\t\"\n                    \"vsub.vv      v4, v4, v8              \\n\\t\"\n                    \"vsub.vv      v1, v1, v9              \\n\\t\"\n                    \"vsub.vv      v5, v5, v9              \\n\\t\"\n                    \"vsub.vv      v2, v2, v10             \\n\\t\"\n                    \"vsub.vv      v6, v6, v10             \\n\\t\"\n                    \"vsub.vv      v3, v3, v11             \\n\\t\"\n                    \"vsub.vv      v7, v7, v11             \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4\n\n                    \"bnez         t5, LOOP_INNER%=        \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    SQ4BIT_KERNEL_ACC_F16_1X4X4\n                    \"addi         s7, s1, 32              \\n\\t\"\n\n                    \"bnez         %[CNT], LOOP_K%=        \\n\\t\"\n                    \"addi         t3, zero, 16            \\n\\t\"\n                    \"addi         s1, %[C], 16            \\n\\t\"\n                    \"addi         s2, %[C], 32            \\n\\t\"\n                    \"addi         s3, 
%[C], 48            \\n\\t\"\n                    \"blt          %[NBLKS], t3, ST_TAIL%= \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"jal          x0, END%=               \\n\\t\"\n\n                    \"ST_TAIL%=:                           \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"END%=:                               \\n\\t\"\n\n                    : [CNT] \"+r\"(cnt), [NBLKS] \"+r\"(nblks), [BIAS] \"+r\"(bias)\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t5\", \"t3\", \"f1\", \"s1\", \"s2\", \"s3\", \"s4\", \"s5\", \"s6\", \"s7\");\n            } else {\n                __asm__ volatile(\n                    \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n                    \"vxor.vv      v28, v28, v28           \\n\\t\"\n\n                    \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n         
           \"vmv.v.i      v13, 3                  \\n\\t\"\n                    \"li           s1, 24                  \\n\\t\"\n                    \"vsetvli      t0, s1, e8, m1          \\n\\t\"\n                    \"vmv.v.i      v13, 2                  \\n\\t\"\n                    \"vsetvli      t0, zero, e8, mf2       \\n\\t\"\n                    \"vmv.v.i      v13, 1                  \\n\\t\"\n                    \"vsetvli      t0, zero, e8, mf4       \\n\\t\"\n                    \"vmv.v.i      v13, 0                  \\n\\t\"\n\n                    \"addi         s1, %[B], 0             \\n\\t\"\n                    \"addi         s2, %[B], 8             \\n\\t\"\n                    \"addi         s3, %[B], 16            \\n\\t\"\n                    \"addi         s4, %[B], 24            \\n\\t\"\n\n                    \"addi         s7, %[B], 32            \\n\\t\"\n\n                    \"addi         s5, %[A], 0             \\n\\t\"\n                    \"addi         s6, %[A], 12            \\n\\t\"\n                    \"LOOP_K%=:                            \\n\\t\"\n                    \"vsetvli      t0, zero, e16, mf4      \\n\\t\"\n                    \"vle16.v      v4, (s1)                \\n\\t\"\n                    \"addi         s1, s1, 48              \\n\\t\"\n                    \"vle16.v      v5, (s2)                \\n\\t\"\n                    \"addi         s2, s2, 72              \\n\\t\"\n                    \"vle16.v      v6, (s3)                \\n\\t\"\n                    \"addi         s3, s3, 96              \\n\\t\"\n                    \"vle16.v      v7, (s4)                \\n\\t\"\n                    \"addi         s4, s4, 120             \\n\\t\"\n                    \"flw          f1, (s5)                \\n\\t\"\n                    \"addi         s5, s5, 4               \\n\\t\"\n\n                    \"vfwcvt.f.f.v v8, v4                  \\n\\t\"\n                    \"vfwcvt.f.f.v v9, v5                  
\\n\\t\"\n                    \"vfwcvt.f.f.v v10, v6                 \\n\\t\"\n                    \"vfwcvt.f.f.v v11, v7                 \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    \"addi         t5, %[INNER], 0         \\n\\t\"\n                    \"vxor.vv      v16, v16, v16           \\n\\t\"\n                    \"vxor.vv      v18, v18, v18           \\n\\t\"\n                    \"vxor.vv      v20, v20, v20           \\n\\t\"\n                    \"vxor.vv      v22, v22, v22           \\n\\t\"\n                    \"vfmul.vf     v24, v8, f1             \\n\\t\"\n                    \"vfmul.vf     v25, v9, f1             \\n\\t\"\n                    \"vfmul.vf     v26, v10, f1            \\n\\t\"\n                    \"vfmul.vf     v27, v11, f1            \\n\\t\"\n                    \"addi         %[CNT], %[CNT], -1      \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_ZP_16X1\n\n                    \"LOOP_INNER%=:                        \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4\n\n                    \"vsub.vv      v0, v0, v8              \\n\\t\"\n                    \"vsub.vv      v4, v4, v8              \\n\\t\"\n                    \"vsub.vv      v1, v1, v9              \\n\\t\"\n                    \"vsub.vv      v5, v5, v9              \\n\\t\"\n                    \"vsub.vv      v2, v2, v10             \\n\\t\"\n                    \"vsub.vv      v6, v6, v10             \\n\\t\"\n                    \"vsub.vv      v3, v3, v11             \\n\\t\"\n                    \"vsub.vv      v7, v7, v11             \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4\n\n                    \"bnez         t5, LOOP_INNER%=        \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    SQ4BIT_KERNEL_ACC_F16_1X4X4\n                    \"addi         s7, s1, 32              \\n\\t\"\n\n                    \"bnez         %[CNT], 
LOOP_K%=        \\n\\t\"\n                    \"addi         t3, zero, 16            \\n\\t\"\n                    \"addi         s1, %[C], 16            \\n\\t\"\n                    \"addi         s2, %[C], 32            \\n\\t\"\n                    \"addi         s3, %[C], 48            \\n\\t\"\n                    \"blt          %[NBLKS], t3, ST_TAIL%= \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"jal          x0, END%=               \\n\\t\"\n\n                    \"ST_TAIL%=:                           \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"END%=:                               \\n\\t\"\n\n                    : [CNT] \"+r\"(cnt), [NBLKS] \"+r\"(nblks)\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t5\", \"t3\", \"f1\", \"s1\", \"s2\", \"s3\", \"s4\", \"s5\", \"s6\", \"s7\");\n            }\n        
}\n    } else {\n        for (size_t n = 0; n < CountN; n += 16) {\n            size_t      nblks         = (CountN - n) > 16 ? 16 : CountN - n;\n            std::byte * QuantBDataPtr = (std::byte *) QuantBData +         //\n                                        n * BlockCountK * BlkLen / 2 +     // b data\n                                        n * BlockCountK * sizeof(_Float16);  // scale\n            float * CPtr = C + n;\n            size_t  cnt  = BlockCountK;\n            if (Bias != nullptr) {\n                const float * bias = Bias + n;\n                __asm__ volatile(\n                    \"addi         t3, %[NBLKS], 0         \\n\\t\"\n                    \"addi         s1, %[B], 0             \\n\\t\"\n                    \"addi         s2, %[B], 8             \\n\\t\"\n                    \"addi         s3, %[B], 16            \\n\\t\"\n                    \"addi         s4, %[B], 24            \\n\\t\"\n                    \"addi         s5, %[A], 0             \\n\\t\"\n                    \"addi         s6, %[A], 12            \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v28, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v29, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v30, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    
\"vle32.v      v31, (%[BIAS])          \\n\\t\"\n\n                    \"LOOP_K%=:                            \\n\\t\"\n                    \"vsetvli      t0, zero, e16, mf4      \\n\\t\"\n\n                    \"vle16.v      v4, (s1)                \\n\\t\"\n                    \"addi         s1, s1, 32              \\n\\t\"\n                    \"vle16.v      v5, (s2)                \\n\\t\"\n                    \"addi         s2, s2, 56              \\n\\t\"\n                    \"vle16.v      v6, (s3)                \\n\\t\"\n                    \"addi         s3, s3, 80              \\n\\t\"\n                    \"vle16.v      v7, (s4)                \\n\\t\"\n                    \"addi         s4, s4, 104             \\n\\t\"\n                    \"flw          f1, (s5)                \\n\\t\"\n                    \"addi         s5, s5, 4               \\n\\t\"\n                    \"vfwcvt.f.f.v v8, v4                  \\n\\t\"\n                    \"vfwcvt.f.f.v v9, v5                  \\n\\t\"\n                    \"vfwcvt.f.f.v v10, v6                 \\n\\t\"\n                    \"vfwcvt.f.f.v v11, v7                 \\n\\t\"\n\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n                    \"addi         t5, %[INNER], 0         \\n\\t\"\n                    \"vxor.vv      v16, v16, v16           \\n\\t\"\n                    \"vxor.vv      v18, v18, v18           \\n\\t\"\n                    \"vxor.vv      v20, v20, v20           \\n\\t\"\n                    \"vxor.vv      v22, v22, v22           \\n\\t\"\n                    \"vfmul.vf     v24, v8, f1             \\n\\t\"\n                    \"vfmul.vf     v25, v9, f1             \\n\\t\"\n                    \"vfmul.vf     v26, v10, f1            \\n\\t\"\n                    \"vfmul.vf     v27, v11, f1            \\n\\t\"\n                    \"addi         %[CNT], %[CNT], -1      \\n\\t\"\n                    \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n             
       \"LOOP_INNER%=:                        \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4\n\n                    \"vadd.vi      v0, v0, -8              \\n\\t\"\n                    \"vadd.vi      v1, v1, -8              \\n\\t\"\n                    \"vadd.vi      v2, v2, -8              \\n\\t\"\n                    \"vadd.vi      v3, v3, -8              \\n\\t\"\n                    \"vadd.vi      v4, v4, -8              \\n\\t\"\n                    \"vadd.vi      v5, v5, -8              \\n\\t\"\n                    \"vadd.vi      v6, v6, -8              \\n\\t\"\n                    \"vadd.vi      v7, v7, -8              \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4\n\n                    \"bnez         t5, LOOP_INNER%=        \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    SQ4BIT_KERNEL_ACC_F16_1X4X4\n\n                    \"bnez         %[CNT], LOOP_K%=        \\n\\t\"\n                    \"addi         t3, zero, 16            \\n\\t\"\n                    \"addi         s1, %[C], 16            \\n\\t\"\n                    \"addi         s2, %[C], 32            \\n\\t\"\n                    \"addi         s3, %[C], 48            \\n\\t\"\n                    \"blt          %[NBLKS], t3, ST_TAIL%= \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"jal          x0, END%=               \\n\\t\"\n\n                    \"ST_TAIL%=:                           \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], 
e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"END%=:                               \\n\\t\"\n\n                    : [CNT] \"+r\"(cnt), [NBLKS] \"+r\"(nblks), [BIAS] \"+r\"(bias)\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t5\", \"t3\", \"f1\", \"s1\", \"s2\", \"s3\", \"s4\", \"s5\", \"s6\");\n            } else {\n                __asm__ volatile(\n                    \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n                    \"vxor.vv      v28, v28, v28           \\n\\t\"\n                    \"addi         s1, %[B], 0             \\n\\t\"\n                    \"addi         s2, %[B], 8             \\n\\t\"\n                    \"addi         s3, %[B], 16            \\n\\t\"\n                    \"addi         s4, %[B], 24            \\n\\t\"\n\n                    \"addi         s5, %[A], 0             \\n\\t\"\n                    \"addi         s6, %[A], 12            \\n\\t\"\n                    \"LOOP_K%=:                            \\n\\t\"\n                    \"vsetvli      t0, zero, e16, mf4      \\n\\t\"\n                    \"vle16.v      v4, (s1)                \\n\\t\"\n                    \"addi         s1, s1, 32              \\n\\t\"\n                    \"vle16.v      v5, (s2)                \\n\\t\"\n                    \"addi         s2, s2, 56              \\n\\t\"\n                    
\"vle16.v      v6, (s3)                \\n\\t\"\n                    \"addi         s3, s3, 80              \\n\\t\"\n                    \"vle16.v      v7, (s4)                \\n\\t\"\n                    \"addi         s4, s4, 104             \\n\\t\"\n                    \"flw          f1, (s5)                \\n\\t\"\n                    \"addi         s5, s5, 4               \\n\\t\"\n\n                    \"vfwcvt.f.f.v v8, v4                  \\n\\t\"\n                    \"vfwcvt.f.f.v v9, v5                  \\n\\t\"\n                    \"vfwcvt.f.f.v v10, v6                 \\n\\t\"\n                    \"vfwcvt.f.f.v v11, v7                 \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    \"addi         t5, %[INNER], 0         \\n\\t\"\n                    \"vxor.vv      v16, v16, v16           \\n\\t\"\n                    \"vxor.vv      v18, v18, v18           \\n\\t\"\n                    \"vxor.vv      v20, v20, v20           \\n\\t\"\n                    \"vxor.vv      v22, v22, v22           \\n\\t\"\n                    \"vfmul.vf     v24, v8, f1             \\n\\t\"\n                    \"vfmul.vf     v25, v9, f1             \\n\\t\"\n                    \"vfmul.vf     v26, v10, f1            \\n\\t\"\n                    \"vfmul.vf     v27, v11, f1            \\n\\t\"\n                    \"addi         %[CNT], %[CNT], -1      \\n\\t\"\n                    \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n                    \"LOOP_INNER%=:                        \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4\n\n                    \"vadd.vi      v0, v0, -8              \\n\\t\"\n                    \"vadd.vi      v1, v1, -8              \\n\\t\"\n                    \"vadd.vi      v2, v2, -8              \\n\\t\"\n                    \"vadd.vi      v3, v3, -8              \\n\\t\"\n                    \"vadd.vi      v4, v4, -8              \\n\\t\"\n                    
\"vadd.vi      v5, v5, -8              \\n\\t\"\n                    \"vadd.vi      v6, v6, -8              \\n\\t\"\n                    \"vadd.vi      v7, v7, -8              \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4\n\n                    \"bnez         t5, LOOP_INNER%=        \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    SQ4BIT_KERNEL_ACC_F16_1X4X4\n\n                    \"bnez         %[CNT], LOOP_K%=        \\n\\t\"\n                    \"addi         t3, zero, 16            \\n\\t\"\n                    \"addi         s1, %[C], 16            \\n\\t\"\n                    \"addi         s2, %[C], 32            \\n\\t\"\n                    \"addi         s3, %[C], 48            \\n\\t\"\n                    \"blt          %[NBLKS], t3, ST_TAIL%= \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"jal          x0, END%=               \\n\\t\"\n\n                    \"ST_TAIL%=:                           \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], 
%[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"END%=:                               \\n\\t\"\n\n                    : [CNT] \"+r\"(cnt), [NBLKS] \"+r\"(nblks)\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t5\", \"t3\", \"f1\", \"s1\", \"s2\", \"s3\", \"s4\", \"s5\", \"s6\");\n            }\n        }\n    }\n}\n\ntemplate <bool HasZeroPoint>\nvoid SQ4BitGemmM1Kernel_CompInt8_Impl(size_t            BlkLen,\n                                      const std::byte * QuantA,\n                                      const std::byte * QuantBData,\n                                      const float *     QuantBScale,\n                                      const std::byte * QuantBZeroPoint,\n                                      float *           C,\n                                      size_t            CountN,\n                                      size_t            BlockCountK,\n                                      const float *     Bias) {\n    GGML_UNUSED(QuantBScale);\n    GGML_UNUSED(QuantBZeroPoint);\n    const size_t INNER = BlkLen / 16;\n    if constexpr (HasZeroPoint) {\n        for (size_t n = 0; n < CountN; n += 16) {\n            size_t      nblks         = (CountN - n) > 16 ? 
16 : CountN - n;\n            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //\n                                        n * BlockCountK * BlkLen / 2 +       // b data\n                                        n * BlockCountK * sizeof(uint8_t) +  // zp\n                                        n * BlockCountK * sizeof(float);     // scale\n            float * CPtr = C + n;\n            size_t  cnt  = BlockCountK;\n            if (Bias != nullptr) {\n                const float * bias = Bias + n;\n                __asm__ volatile(\n                    \"addi         t3, %[NBLKS], 0         \\n\\t\"\n                    \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n                    \"vmv.v.i      v13, 3                  \\n\\t\"\n                    \"li           s1, 24                  \\n\\t\"\n                    \"vsetvli      t0, s1, e8, m1          \\n\\t\"\n                    \"vmv.v.i      v13, 2                  \\n\\t\"\n                    \"vsetvli      t0, zero, e8, mf2       \\n\\t\"\n                    \"vmv.v.i      v13, 1                  \\n\\t\"\n                    \"vsetvli      t0, zero, e8, mf4       \\n\\t\"\n                    \"vmv.v.i      v13, 0                  \\n\\t\"\n                    \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n                    \"vxor.vv      v28, v28, v28           \\n\\t\"\n\n                    // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0\n                    \"addi         s1, %[B], 0             \\n\\t\"\n                    \"addi         s2, %[B], 16            \\n\\t\"\n                    \"addi         s3, %[B], 32            \\n\\t\"\n                    \"addi         s4, %[B], 48            \\n\\t\"\n                    // zp offset\n                    \"addi         s7, %[B], 64            \\n\\t\"\n                    // a offset\n                    \"addi         s5, %[A], 0             \\n\\t\"\n                    \"addi         s6, %[A], 
12            \\n\\t\"\n\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v28, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v29, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v30, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v31, (%[BIAS])          \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n                    \"LOOP_K%=:                            \\n\\t\"\n\n                    // load scale\n                    \"vle32.v      v8, (s1)                \\n\\t\"\n                    \"addi         s1, s1, 80              \\n\\t\"\n                    \"vle32.v      v9, (s2)                \\n\\t\"\n                    \"addi         s2, s2, 96              \\n\\t\"\n                    \"vle32.v      v10, (s3)               \\n\\t\"\n                    \"addi         s3, s3, 112             \\n\\t\"\n                    \"vle32.v      v11, (s4)               \\n\\t\"\n                    \"addi         s4, s4, 128             \\n\\t\"\n\n                    // load a scale\n                    \"flw          f1, (s5)                \\n\\t\"\n                    \"addi         s5, s5, 4               \\n\\t\"\n\n                    \"addi         t5, %[INNER], 0         \\n\\t\"\n                    \"vxor.vv    
  v16, v16, v16           \\n\\t\"\n                    \"vxor.vv      v18, v18, v18           \\n\\t\"\n                    \"vxor.vv      v20, v20, v20           \\n\\t\"\n                    \"vxor.vv      v22, v22, v22           \\n\\t\"\n\n                    // a scale * b scale\n                    \"vfmul.vf     v24, v8, f1             \\n\\t\"\n                    \"vfmul.vf     v25, v9, f1             \\n\\t\"\n                    \"vfmul.vf     v26, v10, f1            \\n\\t\"\n                    \"vfmul.vf     v27, v11, f1            \\n\\t\"\n                    \"addi         %[CNT], %[CNT], -1      \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_ZP_16X1\n\n                    \"LOOP_INNER%=:                        \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4\n\n                    \"vsub.vv      v0, v0, v8              \\n\\t\"\n                    \"vsub.vv      v4, v4, v8              \\n\\t\"\n                    \"vsub.vv      v1, v1, v9              \\n\\t\"\n                    \"vsub.vv      v5, v5, v9              \\n\\t\"\n                    \"vsub.vv      v2, v2, v10             \\n\\t\"\n                    \"vsub.vv      v6, v6, v10             \\n\\t\"\n                    \"vsub.vv      v3, v3, v11             \\n\\t\"\n                    \"vsub.vv      v7, v7, v11             \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4\n\n                    \"bnez         t5, LOOP_INNER%=        \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    SQ4BIT_KERNEL_ACC_1X4X4\n                    \"addi         s7, s1, 64              \\n\\t\"\n\n                    \"bnez         %[CNT], LOOP_K%=        \\n\\t\"\n\n                    \"addi         t3, zero, 16            \\n\\t\"\n                    \"addi         s1, %[C], 16            \\n\\t\"\n                    \"addi         s2, %[C], 32            \\n\\t\"\n                    \"addi         s3, %[C], 48  
          \\n\\t\"\n                    \"blt          %[NBLKS], t3, ST_TAIL%= \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"jal          x0, END%=               \\n\\t\"\n\n                    \"ST_TAIL%=:                           \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"END%=:                               \\n\\t\"\n\n                    : [CNT] \"+r\"(cnt), [NBLKS] \"+r\"(nblks), [BIAS] \"+r\"(bias)\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t5\", \"t3\", \"f1\", \"s1\", \"s2\", \"s3\", \"s4\", \"s5\", \"s6\", \"s7\");\n            } else {\n                __asm__ volatile(\n                    \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n                    \"vxor.vv      v28, v28, v28           \\n\\t\"\n\n                    \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n                   
 \"vmv.v.i      v13, 3                  \\n\\t\"\n                    \"li           s1, 24                  \\n\\t\"\n                    \"vsetvli      t0, s1, e8, m1          \\n\\t\"\n                    \"vmv.v.i      v13, 2                  \\n\\t\"\n                    \"vsetvli      t0, zero, e8, mf2       \\n\\t\"\n                    \"vmv.v.i      v13, 1                  \\n\\t\"\n                    \"vsetvli      t0, zero, e8, mf4       \\n\\t\"\n                    \"vmv.v.i      v13, 0                  \\n\\t\"\n                    \"addi         s1, %[B], 0             \\n\\t\"\n                    \"addi         s2, %[B], 16            \\n\\t\"\n                    \"addi         s3, %[B], 32            \\n\\t\"\n                    \"addi         s4, %[B], 48            \\n\\t\"\n\n                    \"addi         s7, %[B], 64            \\n\\t\"\n\n                    \"addi         s5, %[A], 0             \\n\\t\"\n                    \"addi         s6, %[A], 12            \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    \"LOOP_K%=:                            \\n\\t\"\n                    \"vle32.v      v8, (s1)                \\n\\t\"\n                    \"addi         s1, s1, 80              \\n\\t\"\n                    \"vle32.v      v9, (s2)                \\n\\t\"\n                    \"addi         s2, s2, 96              \\n\\t\"\n                    \"vle32.v      v10, (s3)               \\n\\t\"\n                    \"addi         s3, s3, 112             \\n\\t\"\n                    \"vle32.v      v11, (s4)               \\n\\t\"\n                    \"addi         s4, s4, 128             \\n\\t\"\n\n                    \"flw          f1, (s5)                \\n\\t\"\n                    \"addi         s5, s5, 4               \\n\\t\"\n\n                    \"addi         t5, %[INNER], 0         \\n\\t\"\n                    \"vxor.vv      v16, v16, v16           \\n\\t\"\n        
            \"vxor.vv      v18, v18, v18           \\n\\t\"\n                    \"vxor.vv      v20, v20, v20           \\n\\t\"\n                    \"vxor.vv      v22, v22, v22           \\n\\t\"\n\n                    \"vfmul.vf     v24, v8, f1             \\n\\t\"\n                    \"vfmul.vf     v25, v9, f1             \\n\\t\"\n                    \"vfmul.vf     v26, v10, f1            \\n\\t\"\n                    \"vfmul.vf     v27, v11, f1            \\n\\t\"\n                    \"addi         %[CNT], %[CNT], -1      \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_ZP_16X1\n\n                    \"LOOP_INNER%=:                        \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4\n\n                    \"vsub.vv      v0, v0, v8              \\n\\t\"\n                    \"vsub.vv      v4, v4, v8              \\n\\t\"\n                    \"vsub.vv      v1, v1, v9              \\n\\t\"\n                    \"vsub.vv      v5, v5, v9              \\n\\t\"\n                    \"vsub.vv      v2, v2, v10             \\n\\t\"\n                    \"vsub.vv      v6, v6, v10             \\n\\t\"\n                    \"vsub.vv      v3, v3, v11             \\n\\t\"\n                    \"vsub.vv      v7, v7, v11             \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4\n\n                    \"bnez         t5, LOOP_INNER%=        \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    SQ4BIT_KERNEL_ACC_1X4X4\n                    \"addi         s7, s1, 64              \\n\\t\"\n\n                    \"bnez         %[CNT], LOOP_K%=        \\n\\t\"\n\n                    \"addi         t3, zero, 16            \\n\\t\"\n                    \"addi         s1, %[C], 16            \\n\\t\"\n                    \"addi         s2, %[C], 32            \\n\\t\"\n                    \"addi         s3, %[C], 48            \\n\\t\"\n                    \"blt          %[NBLKS], t3, ST_TAIL%= 
\\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"jal          x0, END%=               \\n\\t\"\n\n                    \"ST_TAIL%=:                           \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"END%=:                               \\n\\t\"\n\n                    : [CNT] \"+r\"(cnt), [NBLKS] \"+r\"(nblks)\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t5\", \"t3\", \"f1\", \"s1\", \"s2\", \"s3\", \"s4\", \"s5\", \"s6\", \"s7\");\n            }\n        }\n    } else {\n        for (size_t n = 0; n < CountN; n += 16) {\n            size_t      nblks         = (CountN - n) > 16 ? 
16 : CountN - n;\n            std::byte * QuantBDataPtr = (std::byte *) QuantBData +        //\n                                        n * BlockCountK * BlkLen / 2 +    // b data\n                                        n * BlockCountK * sizeof(float);  // scale\n            float * CPtr = C + n;\n            size_t  cnt  = BlockCountK;\n            if (Bias != nullptr) {\n                const float * bias = Bias + n;\n                __asm__ volatile(\n                    \"addi         t3, %[NBLKS], 0         \\n\\t\"\n                    \"addi         s1, %[B], 0             \\n\\t\"\n                    \"addi         s2, %[B], 16            \\n\\t\"\n                    \"addi         s3, %[B], 32            \\n\\t\"\n                    \"addi         s4, %[B], 48            \\n\\t\"\n                    \"addi         s5, %[A], 0             \\n\\t\"\n                    \"addi         s6, %[A], 12            \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v28, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v29, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v30, (%[BIAS])          \\n\\t\"\n                    \"sub          t3, t3, t0              \\n\\t\"\n                    \"addi         %[BIAS], %[BIAS], 16    \\n\\t\"\n                    \"vsetvli      t0, t3, e32, mf2        \\n\\t\"\n                    \"vle32.v      v31, (%[BIAS])          \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n                    
\"LOOP_K%=:                            \\n\\t\"\n                    \"vle32.v      v8, (s1)                \\n\\t\"\n                    \"addi         s1, s1, 64              \\n\\t\"\n                    \"vle32.v      v9, (s2)                \\n\\t\"\n                    \"addi         s2, s2, 80              \\n\\t\"\n                    \"vle32.v      v10, (s3)               \\n\\t\"\n                    \"addi         s3, s3, 96              \\n\\t\"\n                    \"vle32.v      v11, (s4)               \\n\\t\"\n                    \"addi         s4, s4, 112             \\n\\t\"\n                    \"flw          f1, (s5)                \\n\\t\"\n                    \"addi         s5, s5, 4               \\n\\t\"\n\n                    \"addi         t5, %[INNER], 0         \\n\\t\"\n                    \"vxor.vv      v16, v16, v16           \\n\\t\"\n                    \"vxor.vv      v18, v18, v18           \\n\\t\"\n                    \"vxor.vv      v20, v20, v20           \\n\\t\"\n                    \"vxor.vv      v22, v22, v22           \\n\\t\"\n                    \"vfmul.vf     v24, v8, f1             \\n\\t\"\n                    \"vfmul.vf     v25, v9, f1             \\n\\t\"\n                    \"vfmul.vf     v26, v10, f1            \\n\\t\"\n                    \"vfmul.vf     v27, v11, f1            \\n\\t\"\n                    \"addi         %[CNT], %[CNT], -1      \\n\\t\"\n                    \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n                    \"LOOP_INNER%=:                        \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4\n\n                    \"vadd.vi      v0, v0, -8              \\n\\t\"\n                    \"vadd.vi      v1, v1, -8              \\n\\t\"\n                    \"vadd.vi      v2, v2, -8              \\n\\t\"\n                    \"vadd.vi      v3, v3, -8              \\n\\t\"\n                    \"vadd.vi      v4, v4, -8              \\n\\t\"\n                    \"vadd.vi 
     v5, v5, -8              \\n\\t\"\n                    \"vadd.vi      v6, v6, -8              \\n\\t\"\n                    \"vadd.vi      v7, v7, -8              \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4\n\n                    \"bnez         t5, LOOP_INNER%=        \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    SQ4BIT_KERNEL_ACC_1X4X4\n\n                    \"bnez         %[CNT], LOOP_K%=        \\n\\t\"\n                    \"addi         t3, zero, 16            \\n\\t\"\n                    \"addi         s1, %[C], 16            \\n\\t\"\n                    \"addi         s2, %[C], 32            \\n\\t\"\n                    \"addi         s3, %[C], 48            \\n\\t\"\n                    \"blt          %[NBLKS], t3, ST_TAIL%= \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"jal          x0, END%=               \\n\\t\"\n\n                    \"ST_TAIL%=:                           \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  
\\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"END%=:                               \\n\\t\"\n\n                    : [CNT] \"+r\"(cnt), [NBLKS] \"+r\"(nblks), [BIAS] \"+r\"(bias)\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t5\", \"t3\", \"f1\", \"s1\", \"s2\", \"s3\", \"s4\", \"s5\", \"s6\");\n            } else {\n                __asm__ volatile(\n                    \"vsetvli      t0, zero, e32, m4       \\n\\t\"\n                    \"vxor.vv      v28, v28, v28           \\n\\t\"\n                    \"addi         s1, %[B], 0             \\n\\t\"\n                    \"addi         s2, %[B], 16            \\n\\t\"\n                    \"addi         s3, %[B], 32            \\n\\t\"\n                    \"addi         s4, %[B], 48            \\n\\t\"\n\n                    \"addi         s5, %[A], 0             \\n\\t\"\n                    \"addi         s6, %[A], 12            \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n                    \"LOOP_K%=:                            \\n\\t\"\n                    \"vle32.v      v8, (s1)                \\n\\t\"\n                    \"addi         s1, s1, 64              \\n\\t\"\n                    \"vle32.v      v9, (s2)                \\n\\t\"\n                    \"addi         s2, s2, 80              \\n\\t\"\n                    \"vle32.v      v10, (s3)               \\n\\t\"\n                    \"addi         s3, s3, 96              \\n\\t\"\n                    \"vle32.v      v11, (s4)               \\n\\t\"\n                    \"addi         s4, s4, 112             \\n\\t\"\n                    \"flw          f1, (s5)                \\n\\t\"\n                    \"addi         s5, s5, 4               \\n\\t\"\n\n                    \"addi         t5, %[INNER], 0         \\n\\t\"\n                    \"vxor.vv      
v16, v16, v16           \\n\\t\"\n                    \"vxor.vv      v18, v18, v18           \\n\\t\"\n                    \"vxor.vv      v20, v20, v20           \\n\\t\"\n                    \"vxor.vv      v22, v22, v22           \\n\\t\"\n                    \"vfmul.vf     v24, v8, f1             \\n\\t\"\n                    \"vfmul.vf     v25, v9, f1             \\n\\t\"\n                    \"vfmul.vf     v26, v10, f1            \\n\\t\"\n                    \"vfmul.vf     v27, v11, f1            \\n\\t\"\n                    \"addi         %[CNT], %[CNT], -1      \\n\\t\"\n                    \"vsetvli      t0, zero, e8, m1        \\n\\t\"\n                    \"LOOP_INNER%=:                        \\n\\t\"\n\n                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4\n\n                    \"vadd.vi      v0, v0, -8              \\n\\t\"\n                    \"vadd.vi      v1, v1, -8              \\n\\t\"\n                    \"vadd.vi      v2, v2, -8              \\n\\t\"\n                    \"vadd.vi      v3, v3, -8              \\n\\t\"\n                    \"vadd.vi      v4, v4, -8              \\n\\t\"\n                    \"vadd.vi      v5, v5, -8              \\n\\t\"\n                    \"vadd.vi      v6, v6, -8              \\n\\t\"\n                    \"vadd.vi      v7, v7, -8              \\n\\t\"\n\n                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4\n\n                    \"bnez         t5, LOOP_INNER%=        \\n\\t\"\n                    \"vsetvli      t0, zero, e32, mf2      \\n\\t\"\n\n                    SQ4BIT_KERNEL_ACC_1X4X4\n\n                    \"bnez         %[CNT], LOOP_K%=        \\n\\t\"\n                    \"addi         t3, zero, 16            \\n\\t\"\n                    \"addi         s1, %[C], 16            \\n\\t\"\n                    \"addi         s2, %[C], 32            \\n\\t\"\n                    \"addi         s3, %[C], 48            \\n\\t\"\n                    \"blt          %[NBLKS], t3, ST_TAIL%= \\n\\t\"\n           
         \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"jal          x0, END%=               \\n\\t\"\n\n                    \"ST_TAIL%=:                           \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v28, (%[C])             \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v29, (s1)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v30, (s2)               \\n\\t\"\n                    \"vsetvli      t0, %[NBLKS], e32, mf2  \\n\\t\"\n                    \"sub          %[NBLKS], %[NBLKS], t0  \\n\\t\"\n                    \"vse32.v      v31, (s3)               \\n\\t\"\n                    \"END%=:                               \\n\\t\"\n\n                    : [CNT] \"+r\"(cnt), [NBLKS] \"+r\"(nblks)\n                    : [INNER] \"r\"(INNER), [A] \"r\"(QuantA), [B] \"r\"(QuantBDataPtr), [C] \"r\"(CPtr)\n                    : \"cc\", \"t0\", \"t5\", \"t3\", \"f1\", \"s1\", \"s2\", \"s3\", \"s4\", \"s5\", \"s6\");\n            }\n        }\n    }\n}\n\ntemplate <bool HasZeroPoint>\ninline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t            BlkLen,\n                                                         const std::byte * QuantA,\n                                                         const std::byte * QuantBData,\n                                                         const float *     QuantBScale,\n 
                                                        const std::byte * QuantBZeroPoint,\n                                                         float *           C,\n                                                         size_t            CountM,\n                                                         size_t            CountN,\n                                                         size_t            BlockStrideQuantB,\n                                                         const float *     Bias,\n                                                         const size_t      ldc,\n                                                         const size_t      scalestride) {\n    if (scalestride == 4) {\n        SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,\n                                                       CountN, BlockStrideQuantB, Bias, ldc);\n\n    } else if (scalestride == 2) {\n        SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(\n            BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);\n    }\n}\n\ntemplate <bool HasZeroPoint>\ninline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t            BlkLen,\n                                                         const std::byte * QuantA,\n                                                         const std::byte * QuantBData,\n                                                         const float *     QuantBScale,\n                                                         const std::byte * QuantBZeroPoint,\n                                                         float *           C,\n                                                         size_t            CountM,\n                                                         size_t            CountN,\n                                                         size_t            BlockStrideQuantB,\n                               
                          const float *     Bias,\n                                                         const size_t      ldc,\n                                                         const size_t      scalestride) {\n    if (scalestride == 4) {\n        SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,\n                                                       CountN, BlockStrideQuantB, Bias);\n    } else if (scalestride == 2) {\n        SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,\n                                                                 QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);\n    }\n}\n\n}  // namespace\n\nnamespace ime1 {\nsize_t gemm_kernel_i8i4(size_t            BlkLen,\n                        const std::byte * QuantA,\n                        const std::byte * QuantBData,\n                        const float *     QuantBScale,\n                        const std::byte * QuantBZeroPoint,\n                        float *           C,\n                        size_t            CountM,\n                        size_t            CountN,\n                        size_t            CountK,\n                        size_t            BlockCountK,\n                        size_t            ldc,\n                        const float *     Bias,\n                        const size_t      ScaleStride) {\n    GGML_UNUSED(CountM);\n    GGML_UNUSED(CountK);\n    GGML_UNUSED(ldc);\n    if (CountM >= 4) {\n        if (QuantBZeroPoint != nullptr) {\n            SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,\n                                                               C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);\n        } else {\n            SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,\n                                     
                           QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,\n                                                                ldc, ScaleStride);\n        }\n        return 4;\n    } else {\n        if (QuantBZeroPoint != nullptr) {\n            SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,\n                                                               C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);\n        } else {\n            SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,\n                                                                QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,\n                                                                ldc, ScaleStride);\n        }\n        return 1;\n    }\n}\n}  // namespace ime1\n}  // namespace sqnbitgemm_spacemit_ime\n"
  },
  {
    "path": "src/ggml-cpu/spacemit/ime_kernels.h",
    "content": "#pragma once\n\n#include <cstddef>\n\nnamespace sqnbitgemm_spacemit_ime {\nnamespace ime1 {\nsize_t gemm_kernel_i8i4(size_t            blk_len,\n                        const std::byte * quant_a_ptr,\n                        const std::byte * quant_b_data,\n                        const float *     quant_b_scale,\n                        const std::byte * quant_b_zp,\n                        float *           c_ptr,\n                        size_t            count_m,\n                        size_t            count_n,\n                        size_t            count_k,\n                        size_t            block_count_k,\n                        size_t            ldc,\n                        const float *     bias,\n                        const size_t      scale_stride);\n\nvoid quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);\n\nvoid quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);\n\n}  // namespace ime1\n}  // namespace sqnbitgemm_spacemit_ime\n"
  },
  {
    "path": "src/ggml-cpu/traits.cpp",
    "content": "#include \"traits.h\"\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n\nnamespace ggml::cpu {\ntensor_traits::~tensor_traits() {}\n\nextra_buffer_type::~extra_buffer_type() {}\n}  // namespace ggml::cpu\n\nbool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {\n    for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {\n        if (extra && extra->context) {\n            auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;\n            auto tensor_traits = buf_extra->get_tensor_traits(op);\n            if (tensor_traits && tensor_traits->compute_forward(params, op)) {\n                return true;\n            }\n        }\n    }\n    return false;\n}\n\nbool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {\n    for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {\n        if (extra && extra->context) {\n            auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;\n            auto tensor_traits = buf_extra->get_tensor_traits(op);\n            if (tensor_traits && tensor_traits->work_size(n_threads, op, *size)) {\n                return true;\n            }\n        }\n    }\n    return false;\n}\n"
  },
  {
    "path": "src/ggml-cpu/traits.h",
    "content": "#pragma once\n#include \"ggml-backend-impl.h\"\n#include \"ggml-cpu-impl.h\"\n#include \"ggml.h\"\n\n#ifdef __cplusplus\n#    include <vector>\nextern \"C\" {\n#endif\n\n// return true if op is part of an extra \"accelerator\"\nbool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op);\nbool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size);\n\n#ifdef __cplusplus\n}\n\nnamespace ggml::cpu {\n// register in tensor->extra\nclass tensor_traits {\n  public:\n    virtual ~tensor_traits();\n    virtual bool work_size(int n_threads, const struct ggml_tensor * op, size_t & size)        = 0;\n    virtual bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) = 0;\n};\n\nclass extra_buffer_type {\n  public:\n    virtual ~extra_buffer_type();\n    virtual bool            supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) = 0;\n    virtual tensor_traits * get_tensor_traits(const struct ggml_tensor * op)                   = 0;\n};\n}  // namespace ggml::cpu\n\n// implemented in ggml-cpu.cpp.\nstd::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types();\n\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/unary-ops.cpp",
    "content": "#include \"unary-ops.h\"\n\nstatic inline float op_abs(float x) {\n    return fabsf(x);\n}\n\nstatic inline float op_sgn(float x) {\n    return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);\n}\n\nstatic inline float op_neg(float x) {\n    return -x;\n}\n\nstatic inline float op_step(float x) {\n    return (x > 0.f) ? 1.f : 0.f;\n}\n\nstatic inline float op_tanh(float x) {\n    return tanhf(x);\n}\n\nstatic inline float op_elu(float x) {\n    return (x > 0.f) ? x : expm1f(x);\n}\n\nstatic inline float op_relu(float x) {\n    return (x > 0.f) ? x : 0.f;\n}\n\nstatic inline float op_sigmoid(float x) {\n    return 1.f / (1.f + expf(-x));\n}\n\nstatic inline float op_hardsigmoid(float x) {\n    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));\n}\n\nstatic inline float op_exp(float x) {\n    return expf(x);\n}\n\nstatic inline float op_hardswish(float x) {\n    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));\n}\n\nstatic inline float op_sqr(float x) {\n    return x * x;\n}\n\nstatic inline float op_sqrt(float x) {\n    return sqrtf(x);\n}\n\nstatic inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {\n    if (x > 0.0f) {\n        return alpha_p * x * x + beta * x;\n    } else {\n        const float min_x_eps = fminf(x, eps);\n        return (expm1f(min_x_eps) - x) * alpha_n + beta * x;\n    }\n}\n\nstatic inline float op_sin(float x) {\n    return sinf(x);\n}\n\nstatic inline float op_cos(float x) {\n    return cosf(x);\n}\n\nstatic inline float op_log(float x) {\n    return logf(x);\n}\n\nstatic inline float op_expm1(float x) {\n    return expf(x) - 1.0f;\n}\n\nstatic inline float op_softplus(float x) {\n    return (x > 20.0f) ? 
x : logf(1.0f + expf(x));\n}\n\nstatic inline float op_floor(float x) {\n    return floorf(x);\n}\n\nstatic inline float op_ceil(float x) {\n    return ceilf(x);\n}\n\nstatic inline float op_round(float x) {\n    return roundf(x);\n}\n\nstatic inline float op_trunc(float x) {\n    return truncf(x);\n}\n\ntemplate <float (*op)(float), typename src0_t, typename dst_t>\nstatic inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {\n    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;\n    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;\n\n    for (int i = 0; i < n; i++) {\n        y[i] = f32_to_dst(op(src0_to_f32(x[i])));\n    }\n}\n\ntemplate <float (*op)(float), typename src0_t, typename dst_t>\nstatic void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT( nb0 == sizeof(dst_t));\n    GGML_ASSERT(nb00 == sizeof(src0_t));\n\n    const auto [ir0, ir1] = get_thread_range(params, src0);\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir/(ne02*ne01);\n        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;\n        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);\n\n        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );\n        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);\n\n        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);\n    }\n}\n\n// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates\ntemplate <float (*op)(float)>\nstatic void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    /*  */ if (src0->type 
== GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32\n        apply_unary_op<op, float, float>(params, dst);\n    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16\n        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);\n    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16\n        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);\n    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {\n        apply_unary_op<op, ggml_bf16_t, float>(params, dst);\n    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {\n        apply_unary_op<op, ggml_fp16_t, float>(params, dst);\n    } else {\n        fprintf(stderr, \"%s: unsupported types: dst: %s, src0: %s\\n\", __func__,\n            ggml_type_name(dst->type), ggml_type_name(src0->type));\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\ntemplate <float (*op)(float, ggml_tensor *)>\nstatic void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32\n        apply_unary_op<op, float, float>(params, dst);\n    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16\n        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);\n    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16\n        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);\n    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {\n        apply_unary_op<op, ggml_bf16_t, float>(params, dst);\n    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {\n        apply_unary_op<op, ggml_fp16_t, float>(params, dst);\n    } else {\n        fprintf(stderr, \"%s: unsupported types: dst: %s, src0: %s\\n\", __func__,\n            ggml_type_name(dst->type), 
ggml_type_name(src0->type));\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\n// Extend vec_unary_op to support functors\ntemplate <typename Op, typename src0_t, typename dst_t>\nstatic inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {\n    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;\n    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;\n\n    for (int i = 0; i < n; i++) {\n        y[i] = f32_to_dst(op(src0_to_f32(x[i])));\n    }\n}\n\n// Extend apply_unary_op to support functors\ntemplate <typename Op, typename src0_t, typename dst_t>\nstatic void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    GGML_ASSERT( nb0 == sizeof(dst_t));\n    GGML_ASSERT(nb00 == sizeof(src0_t));\n\n    const auto [ir0, ir1] = get_thread_range(params, src0);\n\n    for (int64_t ir = ir0; ir < ir1; ++ir) {\n        const int64_t i03 = ir/(ne02*ne01);\n        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;\n        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);\n\n        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );\n        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);\n\n        vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);\n    }\n}\n\n// Generic dispatcher for functors\ntemplate <typename Op>\nstatic void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32\n        apply_unary_op_functor<Op, float, float>(params, dst, op);\n    } else if (src0->type == GGML_TYPE_F16  && dst->type == 
GGML_TYPE_F16) { // all f16\n        apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);\n    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16\n        apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);\n    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {\n        apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);\n    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {\n        apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);\n    } else {\n        fprintf(stderr, \"%s: unsupported types: dst: %s, src0: %s\\n\", __func__,\n            ggml_type_name(dst->type), ggml_type_name(src0->type));\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\nvoid ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_abs>(params, dst);\n}\n\nvoid ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_sgn>(params, dst);\n}\n\nvoid ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_neg>(params, dst);\n}\n\nvoid ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_step>(params, dst);\n}\n\nvoid ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_tanh>(params, dst);\n}\n\nvoid ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_elu>(params, dst);\n}\n\nvoid ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_relu>(params, dst);\n}\n\nvoid ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_sigmoid>(params, dst);\n}\n\nvoid ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_hardsigmoid>(params, dst);\n}\n\nvoid 
ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_exp>(params, dst);\n}\n\nvoid ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_hardswish>(params, dst);\n}\n\nvoid ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_sqr>(params, dst);\n}\n\nvoid ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_sqrt>(params, dst);\n}\n\nvoid ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_sin>(params, dst);\n}\n\nvoid ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_cos>(params, dst);\n}\n\nvoid ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_log>(params, dst);\n}\n\nvoid ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_expm1>(params, dst);\n}\n\nvoid ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_softplus>(params, dst);\n}\n\nvoid ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_floor>(params, dst);\n}\n\nvoid ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_ceil>(params, dst);\n}\n\nvoid ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_round>(params, dst);\n}\n\nvoid ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {\n    unary_op<op_trunc>(params, dst);\n}\n\nvoid ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {\n    const float alpha_n = ggml_get_op_params_f32(dst, 1);\n    const float alpha_p = ggml_get_op_params_f32(dst, 2);\n    const float beta = ggml_get_op_params_f32(dst, 3);\n    const float eps = 
ggml_get_op_params_f32(dst, 4);\n\n    const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {\n        return op_xielu(f, alpha_n, alpha_p, beta, eps);\n    };\n\n    unary_op_functor(params, dst, xielu_op_params);\n}\n\n"
  },
  {
    "path": "src/ggml-cpu/unary-ops.h",
    "content": "#pragma once\n\n#include \"common.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nvoid ggml_compute_forward_abs(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_sgn(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_neg(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_step(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_hardswish(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_floor(const struct 
ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst);\nvoid ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cpu/vec.cpp",
    "content": "#include \"vec.h\"\n\n#include <cassert>\n\n// precomputed gelu table for f16 (128 KB)\nggml_fp16_t ggml_table_gelu_f16[1 << 16];\n\n// precomputed quick gelu table for f16 (128 KB)\nggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];\n\nvoid ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {\n   assert(nrc == 1);\n   GGML_UNUSED(nrc);\n   GGML_UNUSED(bx);\n   GGML_UNUSED(by);\n   GGML_UNUSED(bs);\n\n#if defined(GGML_SIMD)\n    float sumf = 0.0f;\n\n    #if defined(__ARM_FEATURE_SVE)\n        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;\n        const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16\n        const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers\n\n        const int np = (n & ~(ggml_f32_step - 1));\n        svfloat32_t sum1 = svdup_n_f32(0.0f);\n        svfloat32_t sum2 = svdup_n_f32(0.0f);\n        svfloat32_t sum3 = svdup_n_f32(0.0f);\n        svfloat32_t sum4 = svdup_n_f32(0.0f);\n        svfloat32_t sum5 = svdup_n_f32(0.0f);\n        svfloat32_t sum6 = svdup_n_f32(0.0f);\n        svfloat32_t sum7 = svdup_n_f32(0.0f);\n        svfloat32_t sum8 = svdup_n_f32(0.0f);\n        svfloat32_t ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8;\n        svfloat32_t ay1,ay2,ay3,ay4,ay5,ay6,ay7,ay8;\n        for (int i = 0; i < np; i += ggml_f32_step) {\n            ax1 = GGML_F32_VEC_LOAD(x + i);\n            ay1 = GGML_F32_VEC_LOAD(y + i);\n            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);\n\n            ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);\n            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);\n            sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);\n\n            ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);\n            ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);\n            sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);\n\n            ax4 = GGML_F32_VEC_LOAD(x 
+ i + 3*ggml_f32_epr);\n            ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);\n            sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);\n\n            ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);\n            ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);\n            sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);\n\n            ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);\n            ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);\n            sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);\n\n            ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);\n            ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);\n            sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);\n\n            ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);\n            ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);\n            sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);\n        }\n        // leftovers\n        // Since the loop above is unrolled 8x, leftovers lie in the range [0, ggml_f32_step); whole vectors among them are handled by the loop below\n        const int np2 = (n & ~(ggml_f32_epr - 1));\n        for (int i = np; i < np2; i += ggml_f32_epr) {\n            ax1 = GGML_F32_VEC_LOAD(x + i);\n            ay1 = GGML_F32_VEC_LOAD(y + i);\n            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);\n        }\n        // maximum number of leftover elements will be less than ggml_f32_epr. 
Apply predicated svmad on available elements only\n        if (np2 < n) {\n            svbool_t pg = svwhilelt_b32(np2, n);\n            ax1 = svld1_f32(pg, x + np2);\n            ay1 = svld1_f32(pg, y + np2);\n            sum1 = svmad_f32_m(pg, ax1, ay1, sum1);\n        }\n        // reduce the accumulators sum1..sum8 into sumf\n        GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);\n    #elif defined(__riscv_v_intrinsic)\n        int vl = __riscv_vsetvlmax_e32m8();\n        vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);\n        vfloat32m8_t vsum;\n        vfloat32m8_t ax;\n        vfloat32m8_t ay;\n        vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl);\n        for (int i = 0; i < n; i += vl) {\n            vl = __riscv_vsetvl_e32m8(n - i);\n            ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl);\n            ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl);\n            vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl);\n        }\n        vl = __riscv_vsetvlmax_e32m8();\n        vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl);\n        sumf += __riscv_vfmv_f_s_f32m1_f32(vs);\n    #else\n        const int np = (n & ~(GGML_F32_STEP - 1));\n\n        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };\n\n        GGML_F32_VEC ax[GGML_F32_ARR];\n        GGML_F32_VEC ay[GGML_F32_ARR];\n\n        for (int i = 0; i < np; i += GGML_F32_STEP) {\n            for (int j = 0; j < GGML_F32_ARR; j++) {\n                ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);\n                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);\n\n                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);\n            }\n        }\n\n        // reduce the GGML_F32_ARR accumulators in sum[] into sumf\n        GGML_F32_VEC_REDUCE(sumf, sum);\n\n        // leftovers\n        for (int i = np; i < n; ++i) {\n            sumf += x[i]*y[i];\n        }\n    #endif\n#else\n    // scalar\n    ggml_float sumf = 0.0;\n    for (int i = 0; i < n; ++i) {\n        sumf += (ggml_float)(x[i]*y[i]);\n   
 }\n#endif\n\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc) {\n    assert(nrc == 1);\n    GGML_UNUSED(nrc);\n    GGML_UNUSED(bx);\n    GGML_UNUSED(by);\n    GGML_UNUSED(bs);\n    int i = 0;\n    ggml_float sumf = 0;\n\n#if defined(__AVX512BF16__)\n    __m512 c1 = _mm512_setzero_ps();\n    __m512 c2 = _mm512_setzero_ps();\n    for (; i + 64 <= n; i += 64) {\n        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),\n                             m512bh(_mm512_loadu_si512((y + i))));\n        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),\n                             m512bh(_mm512_loadu_si512((y + i + 32))));\n    }\n    sumf += (ggml_float)_mm512_reduce_add_ps(c1);\n    sumf += (ggml_float)_mm512_reduce_add_ps(c2);\n\n#elif defined(__AVX512F__)\n#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))\n    __m512 c1 = _mm512_setzero_ps();\n    __m512 c2 = _mm512_setzero_ps();\n    for (; i + 32 <= n; i += 32) {\n        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);\n        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);\n    }\n    sumf += (ggml_float)_mm512_reduce_add_ps(c1);\n    sumf += (ggml_float)_mm512_reduce_add_ps(c2);\n\n#undef LOAD\n#elif defined(__AVX2__) || defined(__AVX__)\n#if defined(__AVX2__)\n#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))\n#else\n#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))\n#endif\n    __m256 c1 = _mm256_setzero_ps();\n    __m256 c2 = _mm256_setzero_ps();\n    
__m256 c3 = _mm256_setzero_ps();\n    __m256 c4 = _mm256_setzero_ps();\n    for (; i + 32 <= n; i += 32) {\n        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);\n        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);\n        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);\n        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);\n    }\n    __m128 g;\n    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),\n                       _mm256_add_ps(c2, c4));\n    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),\n                   _mm256_castps256_ps128(c1));\n    g = _mm_add_ps(g, _mm_movehl_ps(g, g));\n    g = _mm_add_ss(g, _mm_movehdup_ps(g));\n    sumf += (ggml_float)_mm_cvtss_f32(g);\n\n#undef LOAD\n#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma)\n    size_t vl = __riscv_vsetvlmax_e32m4();\n\n    // initialize accumulators to all zeroes\n    vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);\n    vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);\n\n    // calculate step size\n    const size_t epr = __riscv_vsetvlmax_e16m2();\n    const size_t step = epr * 2;\n    const int np = (n & ~(step - 1));\n\n    // unroll by 2\n    for (; i < np; i += step) {\n        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr);\n        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr);\n        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr);\n        __asm__ __volatile__ (\"\" ::: \"memory\");\n\n        vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr);\n        vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr);\n        vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr);\n        __asm__ __volatile__ (\"\" ::: \"memory\");\n    }\n\n    // accumulate in 1 register\n    vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl);\n\n    // leftovers\n    for (i = np; i 
< n; i += vl) {\n        vl = __riscv_vsetvl_e16m2(n - i);\n        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl);\n        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl);\n        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl);\n    }\n\n    // reduce\n    vl = __riscv_vsetvlmax_e32m4();\n    vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);\n    sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);\n\n#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)\n    const int np = (n & ~(GGML_BF16_STEP - 1));\n    if (np > 0) {\n        GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};\n        for (; i < np; i += GGML_BF16_STEP) {\n            GGML_BF16_VEC vx0 = GGML_BF16_VEC_LOAD(x + i);\n            GGML_BF16_VEC vx1 = GGML_BF16_VEC_LOAD(x + i + 8);\n            GGML_BF16_VEC vy0 = GGML_BF16_VEC_LOAD(y + i);\n            GGML_BF16_VEC vy1 = GGML_BF16_VEC_LOAD(y + i + 8);\n            GGML_BF16_FMA_LO(sum[0], vx0, vy0);\n            GGML_BF16_FMA_HI(sum[1], vx0, vy0);\n            GGML_BF16_FMA_LO(sum[2], vx1, vy1);\n            GGML_BF16_FMA_HI(sum[3], vx1, vy1);\n        }\n        GGML_F32x4_REDUCE_4(sumf, sum[0], sum[1], sum[2], sum[3]);\n    }\n#endif\n\n    for (; i < n; ++i) {\n        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *\n                             GGML_BF16_TO_FP32(y[i]));\n    }\n    *s = sumf;\n}\n\nvoid ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc) {\n    assert(nrc == 1);\n    GGML_UNUSED(nrc);\n    GGML_UNUSED(bx);\n    GGML_UNUSED(by);\n    GGML_UNUSED(bs);\n\n    ggml_float sumf = 0.0;\n\n\n#if defined(GGML_SIMD)\n    #if defined(__ARM_FEATURE_SVE)\n        const int sve_register_length = svcntb() * 8; //get vector length\n        const int ggml_f16_epr = sve_register_length / 16; // running when 16\n        const int 
ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers\n\n        const int np= (n & ~(ggml_f16_step - 1));\n        svfloat16_t sum1 = svdup_n_f16(0.0f);\n        svfloat16_t sum2 = svdup_n_f16(0.0f);\n        svfloat16_t sum3 = svdup_n_f16(0.0f);\n        svfloat16_t sum4 = svdup_n_f16(0.0f);\n\n        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;\n        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;\n        for (int i = 0; i < np; i += ggml_f16_step) {\n            ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);\n            ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);\n            sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);\n\n            ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);\n            ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);\n            sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);\n\n            ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);\n            ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);\n            sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);\n\n            ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);\n            ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);\n            sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);\n\n            ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);\n            ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);\n            sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);\n\n            ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);\n            ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);\n            sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);\n\n            ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);\n            ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);\n            sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);\n\n            ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);\n            ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);\n            sum4 = 
GGML_F16x_VEC_FMA(sum4, ax8, ay8);\n        }\n\n        const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8\n        for (int k = np; k < np2; k += ggml_f16_epr) {\n            svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);\n            svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);\n            sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);\n        }\n\n        if (np2 < n) {\n            svbool_t pg = svwhilelt_b16(np2, n);\n            svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));\n            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));\n\n            sum1 = svmad_f16_x(pg, hx, hy, sum1);\n        }\n        GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);\n    #elif defined(__riscv_v_intrinsic)\n        #if defined(__riscv_zvfh)\n            int vl = __riscv_vsetvlmax_e32m2();\n            vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);\n            vfloat32m2_t vsum;\n            vfloat16m1_t ax;\n            vfloat16m1_t ay;\n            vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl));\n            for (int i = 0; i < n; i += vl) {\n                vl = __riscv_vsetvl_e16m1(n - i);\n                ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl);\n                ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl);\n                vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl);\n            }\n            vl = __riscv_vsetvlmax_e32m1();\n            vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl);\n            vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl);\n            sumf += __riscv_vfmv_f_s_f32m1_f32(vs);\n        #else\n            for (int i = 0; i < n; ++i) {\n                sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));\n            }\n        #endif // __riscv_zvfh\n    #else\n        const int np = (n & ~(GGML_F16_STEP - 1));\n\n        
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };\n\n        GGML_F16_VEC ax[GGML_F16_ARR];\n        GGML_F16_VEC ay[GGML_F16_ARR];\n\n        for (int i = 0; i < np; i += GGML_F16_STEP) {\n            for (int j = 0; j < GGML_F16_ARR; j++) {\n                ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);\n                ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);\n\n                sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);\n            }\n        }\n\n        // reduce sum0..sum3 to sum0\n        GGML_F16_VEC_REDUCE(sumf, sum);\n\n        // leftovers\n        for (int i = np; i < n; ++i) {\n            sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));\n        }\n        // if you hit this, you are likely running outside the FP range\n        assert(!isnan(sumf) && !isinf(sumf));\n    #endif\n#else\n    for (int i = 0; i < n; ++i) {\n        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));\n    }\n#endif // GGML_SIMD\n\n    *s = sumf;\n}\n\nvoid ggml_vec_silu_f32(const int n, float * y, const float * x) {\n    int i = 0;\n#if defined(__AVX512F__) && defined(__AVX512DQ__)\n    for (; i + 15 < n; i += 16) {\n        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));\n    }\n#elif defined(__AVX2__) && defined(__FMA__)\n    for (; i + 7 < n; i += 8) {\n        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));\n    }\n#elif defined(__SSE2__)\n    for (; i + 3 < n; i += 4) {\n        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));\n    }\n#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)\n    const int vlen = svcntw();\n    for (; i < n; i += vlen) {\n        const svbool_t pg = svwhilelt_b32_s32(i, n);\n        svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));\n    }\n#elif defined(__ARM_NEON) && defined(__aarch64__)\n    for (; i + 3 < n; i += 4) {\n        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));\n    }\n#elif 
defined(__riscv_v_intrinsic)\n    for (int vl; i < n; i += vl) {\n        vl = __riscv_vsetvl_e32m2(n - i);\n        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);\n        vfloat32m2_t vy = ggml_v_silu_m2(vx, vl);\n        __riscv_vse32_v_f32m2(&y[i], vy, vl);\n    }\n#endif\n    for (; i < n; ++i) {\n        y[i] = ggml_silu_f32(x[i]);\n    }\n}\n\nvoid ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {\n    int i = 0;\n#if defined(__AVX512F__) && defined(__AVX512DQ__)\n    for (; i + 15 < n; i += 16) {\n        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));\n    }\n#elif defined(__AVX2__) && defined(__FMA__)\n    for (; i + 7 < n; i += 8) {\n        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));\n    }\n#elif defined(__SSE2__)\n    for (; i + 3 < n; i += 4) {\n        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));\n    }\n#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)\n    const int vlen = svcntw();\n    for (; i < n; i += vlen) {\n        const svbool_t pg = svwhilelt_b32_s32(i, n);\n        svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));\n    }\n#elif defined(__ARM_NEON) && defined(__aarch64__)\n    for (; i + 3 < n; i += 4) {\n        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));\n    }\n#elif defined(__riscv_v_intrinsic)\n    for (int vl; i < n; i += vl) {\n        vl = __riscv_vsetvl_e32m2(n - i);\n        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);\n        vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl);\n        vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl);\n        __riscv_vse32_v_f32m2(&y[i], vy, vl);\n    }\n#endif\n    for (; i < n; ++i) {\n        y[i] = ggml_silu_f32(x[i]) * g[i];\n    }\n}\n\nggml_float ggml_vec_cvar_f32(const int n, 
float * y, const float * x, const float mean) {\n    int i = 0;\n    ggml_float sum = 0;\n// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE\n// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344\n#if defined(__AVX512F__) && defined(__AVX512DQ__)\n    for (; i + 15 < n; i += 16) {\n        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),\n                                   _mm512_set1_ps(mean));\n        _mm512_storeu_ps(y + i, val);\n        sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));\n    }\n#elif defined(__AVX2__) && defined(__FMA__)\n    for (; i + 7 < n; i += 8) {\n        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),\n                                   _mm256_set1_ps(mean));\n        _mm256_storeu_ps(y + i, val);\n        val = _mm256_mul_ps(val,val);\n        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),\n                                 _mm256_castps256_ps128(val));\n        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));\n        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));\n        sum += (ggml_float)_mm_cvtss_f32(val2);\n    }\n#elif defined(__SSE2__)\n    for (; i + 3 < n; i += 4) {\n        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),\n                                _mm_set1_ps(mean));\n        _mm_storeu_ps(y + i, val);\n        val = _mm_mul_ps(val, val);\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\n        val = _mm_add_ps(val, _mm_movehl_ps(val, val));\n        val = _mm_add_ss(val, _mm_movehdup_ps(val));\n#else\n        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));\n        val = _mm_add_ps(val, tmp);\n        tmp = _mm_movehl_ps(tmp, val);\n        val = _mm_add_ss(val, tmp);\n#endif  // __AVX__ || __AVX2__ || __AVX512F__\n        sum += (ggml_float)_mm_cvtss_f32(val);\n    }\n#elif defined(__ARM_NEON) && defined(__aarch64__)\n    for (; i + 3 < n; i += 4) {\n        
float32x4_t val = vsubq_f32(vld1q_f32(x + i),\n                                    vdupq_n_f32(mean));\n        vst1q_f32(y + i, val);\n        val = vmulq_f32(val, val);\n        sum += (ggml_float)vaddvq_f32(val);\n    }\n#elif defined(__VXE__) || defined(__VXE2__)\n    for (; i + 3 < n; i += 4) {\n        float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));\n        vec_xst(val, 0, y + i);\n        val = vec_mul(val, val);\n        sum += (ggml_float)vec_hsum_f32x4(val);\n    }\n#elif defined(__riscv_v_intrinsic)\n    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);\n    for (int vl; i < n; i += vl) {\n        vl = __riscv_vsetvl_e32m2(n - i);\n        vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);\n        __riscv_vse32_v_f32m2(&y[i], val, vl);\n        val = __riscv_vfmul_vv_f32m2(val, val, vl);\n        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);\n    }\n    sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);\n#endif\n    for (; i < n; ++i) {\n        float val = x[i] - mean;\n        y[i] = val;\n        val *= val;\n        sum += (ggml_float)val;\n    }\n    return sum/n;\n}\n\nggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {\n    int i = 0;\n    ggml_float sum = 0;\n#if defined(__AVX512F__) && defined(__AVX512DQ__)\n    for (; i + 15 < n; i += 16) {\n        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),\n                                               _mm512_set1_ps(max)));\n        _mm512_storeu_ps(y + i, val);\n        sum += (ggml_float)_mm512_reduce_add_ps(val);\n    }\n#elif defined(__AVX2__) && defined(__FMA__)\n    for (; i + 7 < n; i += 8) {\n        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),\n                                               _mm256_set1_ps(max)));\n        _mm256_storeu_ps(y + i, val);\n        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),\n                                 
_mm256_castps256_ps128(val));\n        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));\n        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));\n        sum += (ggml_float)_mm_cvtss_f32(val2);\n    }\n#elif defined(__SSE2__)\n    for (; i + 3 < n; i += 4) {\n        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),\n                                            _mm_set1_ps(max)));\n        _mm_storeu_ps(y + i, val);\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\n        val = _mm_add_ps(val, _mm_movehl_ps(val, val));\n        val = _mm_add_ss(val, _mm_movehdup_ps(val));\n#else\n        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));\n        val = _mm_add_ps(val, tmp);\n        tmp = _mm_movehl_ps(tmp, val);\n        val = _mm_add_ss(val, tmp);\n#endif\n        sum += (ggml_float)_mm_cvtss_f32(val);\n    }\n#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)\n    const int vlen = svcntw();\n    for (; i < n; i += vlen) {\n        const svbool_t pg = svwhilelt_b32_s32(i, n);\n        svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),\n                                                svdup_n_f32_x(pg, max)));\n        svst1_f32(pg, y + i, val);\n        sum += (ggml_float)svaddv_f32(pg, val);\n    }\n#elif defined(__ARM_NEON) && defined(__aarch64__)\n    for (; i + 3 < n; i += 4) {\n        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),\n                                                vdupq_n_f32(max)));\n        vst1q_f32(y + i, val);\n        sum += (ggml_float)vaddvq_f32(val);\n    }\n#elif defined(__riscv_v_intrinsic)\n    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);\n    for (int avl; i < n; i += avl) {\n        avl = __riscv_vsetvl_e32m2(n - i);\n        vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);\n        __riscv_vse32_v_f32m2(&y[i], val, avl);\n        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, 
avl);\n    }\n    return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);\n#endif\n    for (; i < n; ++i) {\n        float val = expf(x[i] - max);\n        sum += (ggml_float)val;\n        y[i] = val;\n    }\n    return sum;\n}\n\nggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {\n    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)\n\n    int i = 0;\n    ggml_float sum = 0;\n    for (; i < n; ++i) {\n        float val = x[i] - max;\n        y[i] = val;\n        sum += (ggml_float)expf(val);\n    }\n    return sum = (ggml_float)logf(sum);\n}\n"
  },
  {
    "path": "src/ggml-cpu/vec.h",
    "content": "// Vectorized functions for fundamental operations\n\n#pragma once\n\n#include \"ggml-impl.h\"\n#include \"simd-mappings.h\"\n#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#if defined(GGML_USE_ACCELERATE)\n#include <Accelerate/Accelerate.h>\n#endif\n\n// floating point type used to accumulate sums\ntypedef double ggml_float;\n\n#define GGML_GELU_FP16\n#define GGML_GELU_QUICK_FP16\n\n#define GGML_SOFT_MAX_UNROLL 4\n#define GGML_VEC_DOT_UNROLL  2\n#define GGML_VEC_MAD_UNROLL  32\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n//\n// global data\n//\n\n// precomputed gelu table for f16 (128 KB)\nextern ggml_fp16_t ggml_table_gelu_f16[1 << 16];\n\n// precomputed quick gelu table for f16 (128 KB)\nextern ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];\n\n//\n// fundamental operations\n//\n\nvoid ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);\nvoid ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);\nvoid ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);\n\nvoid ggml_vec_silu_f32(const int n, float * y, const float * x);\nggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean )\nggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);\nggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);\n\ninline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }\ninline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }\n\ninline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t   v) 
{ for (int i = 0; i < n; ++i) x[i] = v;    }\ninline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }\n\ninline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }\ninline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }\n\ninline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {\n    int i = 0;\n#if defined(__AVX2__)\n    for (; i + 7 < n; i += 8) {\n        __m256 vx = _mm256_loadu_ps(x + i);\n        __m256 vy = _mm256_loadu_ps(y + i);\n        __m256 vz = _mm256_add_ps(vx, vy);\n        _mm256_storeu_ps(z + i, vz);\n    }\n#endif\n    for (; i < n; ++i) {\n        z[i] = x[i] + y[i];\n    }\n}\n\ninline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {\n    for (int i = 0; i < n; ++i) {\n        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));\n    }\n}\ninline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }\ninline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }\ninline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }\ninline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }\ninline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {\n    for (int i = 0; i < n; ++i) {\n        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) - GGML_CPU_FP16_TO_FP32(y[i]));\n    }\n}\ninline static void 
ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }\ninline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }\ninline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }\ninline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(-GGML_CPU_FP16_TO_FP32(x[i]));\n    }\n}\n\ninline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }\ninline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {\n    for (int i = 0; i < n; ++i) {\n        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) * GGML_CPU_FP16_TO_FP32(y[i]));\n    }\n}\ninline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }\ninline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {\n    for (int i = 0; i < n; ++i) {\n        z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) / GGML_CPU_FP16_TO_FP32(y[i]));\n    }\n}\n\n// compute GGML_VEC_DOT_UNROLL dot products at once\n// xs - x row stride in bytes\ninline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GGML_RESTRICT s, void * GGML_RESTRICT xv, ggml_fp16_t * GGML_RESTRICT y) {\n    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };\n\n    ggml_fp16_t * GGML_RESTRICT x[GGML_VEC_DOT_UNROLL];\n\n    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {\n        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);\n    }\n\n#if defined(GGML_SIMD)\n    #if defined(__ARM_FEATURE_SVE)\n\n        const 
int sve_register_length = svcntb() * 8;\n        const int ggml_f16_epr = sve_register_length / 16; // running when 16\n        const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers\n\n        const int np = (n & ~(ggml_f16_step - 1));\n\n        svfloat16_t sum_00 = svdup_n_f16(0.0f);\n        svfloat16_t sum_01 = svdup_n_f16(0.0f);\n        svfloat16_t sum_02 = svdup_n_f16(0.0f);\n        svfloat16_t sum_03 = svdup_n_f16(0.0f);\n\n        svfloat16_t sum_10 = svdup_n_f16(0.0f);\n        svfloat16_t sum_11 = svdup_n_f16(0.0f);\n        svfloat16_t sum_12 = svdup_n_f16(0.0f);\n        svfloat16_t sum_13 = svdup_n_f16(0.0f);\n\n        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;\n        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;\n\n        for (int i = 0; i < np; i += ggml_f16_step) {\n            ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements\n\n            ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements\n            sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1);     // sum_00 = sum_00+ax1*ay1\n            ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements\n            sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);\n\n            ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements\n\n            ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements\n            sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);\n            ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);\n            sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);\n\n            ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);\n\n            ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);\n            sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);\n            ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);\n            sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);\n\n            ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);\n\n     
       ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);\n            sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);\n            ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);\n            sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);\n\n            ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);\n\n            ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);\n\n            sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);\n            ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);\n            sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);\n\n            ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);\n\n            ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);\n\n            sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);\n            ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);\n            sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);\n\n            ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);\n\n            ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);\n\n            sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);\n            ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);\n            sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);\n\n            ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);\n\n            ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);\n\n            sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);\n            ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);\n            sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);\n        }\n\n        const int np2 = (n & ~(ggml_f16_epr - 1));\n        for (int k = np; k < np2; k += ggml_f16_epr) {\n            svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);\n\n            svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);\n            sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);\n            rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);\n            sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);\n        }\n\n     
   if (np2 < n) {\n            svbool_t pg = svwhilelt_b16(np2, n);\n            svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));\n            svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));\n            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));\n\n            sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);\n            sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);\n        }\n        GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);\n        GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);\n\n    #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)\n        size_t vl = __riscv_vsetvlmax_e32m4();\n\n        // initialize accumulators to all zeroes\n        vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);\n        vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);\n        vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);\n        vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);\n\n        // calculate step size\n        const size_t epr = __riscv_vsetvlmax_e16m2();\n        const size_t step = epr * 2;\n        const int np = (n & ~(step - 1));\n\n        // unroll by 2 along the row dimension\n        for (int i = 0; i < np; i += step) {\n            vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);\n            vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);\n            vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);\n            vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);\n            vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);\n\n            vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);\n            vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);\n            vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);\n            vsum0_1 = 
__riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);\n            vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);\n        }\n\n        vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);\n        vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);\n\n        // leftovers\n        for (int i = np; i < n; i += vl) {\n            vl = __riscv_vsetvl_e16m2(n - i);\n            vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);\n            vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);\n            vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);\n\n            vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);\n            vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);\n        }\n\n        // reduce\n        vl = __riscv_vsetvlmax_e32m2();\n        vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),\n                                    __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);\n        vl = __riscv_vsetvlmax_e32m1();\n        vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),\n        __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);\n        vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(\n                                    acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);\n\n        vl = __riscv_vsetvlmax_e32m2();\n        vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),\n                                    __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);\n        vl = __riscv_vsetvlmax_e32m1();\n        vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),\n                                    __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);\n        vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(\n                                    acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);\n        sumf[0] = 
__riscv_vfmv_f_s_f32m1_f32(redsum0);\n        sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);\n\n    #else\n        const int np = (n & ~(GGML_F16_STEP - 1));\n\n        GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };\n\n        GGML_F16_VEC ax[GGML_F16_ARR];\n        GGML_F16_VEC ay[GGML_F16_ARR];\n\n        for (int i = 0; i < np; i += GGML_F16_STEP) {\n            for (int j = 0; j < GGML_F16_ARR; j++) {\n                ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);\n\n                for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {\n                    ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);\n\n                    sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);\n                }\n            }\n        }\n\n        // reduce sum0..sum3 to sum0\n        for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {\n            GGML_F16_VEC_REDUCE(sumf[k], sum[k]);\n        }\n\n        // leftovers\n        for (int i = np; i < n; ++i) {\n            for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {\n                sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));\n            }\n        }\n    #endif\n#else\n    for (int i = 0; i < n; ++i) {\n        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {\n            sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));\n        }\n    }\n#endif\n\n    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {\n        s[i] = (float)sumf[i];\n    }\n}\n\ninline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) {\n#if defined(GGML_SIMD)\n    #if defined(__ARM_FEATURE_SVE)\n\n        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;\n        const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16\n        const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers\n        GGML_F32_VEC vx = 
GGML_F32_VEC_SET1(v);\n\n        const int np = (n & ~(ggml_f32_step - 1));\n        svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;\n        svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;\n        for (int i = 0; i < np; i += ggml_f32_step) {\n\n            ax1 = GGML_F32_VEC_LOAD(x + i);\n            ay1 = GGML_F32_VEC_LOAD(y + i);\n            ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);\n\n            GGML_F32_VEC_STORE(y + i, ay1);\n\n            ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);\n            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);\n            ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);\n\n            GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);\n\n            ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);\n            ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);\n            ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);\n\n            GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);\n\n            ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);\n            ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);\n            ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);\n\n            GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);\n\n            ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);\n            ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);\n            ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);\n\n            GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);\n\n            ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);\n            ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);\n            ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);\n\n            GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);\n\n            ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);\n            ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);\n            ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);\n\n            GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);\n\n            ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);\n            ay8 = GGML_F32_VEC_LOAD(y + i + 
7*ggml_f32_epr);\n            ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);\n\n            GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);\n        }\n        // leftovers\n        // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop\n        const int np2 = (n & ~(ggml_f32_epr - 1));\n        for (int i = np; i < np2; i += ggml_f32_epr) {\n            ax1 = GGML_F32_VEC_LOAD(x + i);\n            ay1 = GGML_F32_VEC_LOAD(y + i);\n            ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);\n\n            GGML_F32_VEC_STORE(y + i, ay1);\n        }\n        // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only\n        if (np2 < n) {\n            svbool_t pg =svwhilelt_b32(np2, n);\n            ax1 = svld1_f32(pg, x + np2);\n            ay1 = svld1_f32(pg, y + np2);\n            ay1 = svmad_f32_m(pg, ax1, vx, ay1);\n\n            svst1_f32(pg, y + np2, ay1);\n        }\n    #elif defined(__riscv_v_intrinsic)\n        for (int i = 0, avl; i < n; i += avl) {\n            avl = __riscv_vsetvl_e32m8(n - i);\n            vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);\n            vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);\n            vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);\n            __riscv_vse32_v_f32m8(&y[i], ny, avl);\n        }\n    #else\n        const int np = (n & ~(GGML_F32_STEP - 1));\n\n        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);\n\n        GGML_F32_VEC ax[GGML_F32_ARR];\n        GGML_F32_VEC ay[GGML_F32_ARR];\n\n        for (int i = 0; i < np; i += GGML_F32_STEP) {\n            for (int j = 0; j < GGML_F32_ARR; j++) {\n                ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);\n                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);\n                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);\n\n                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);\n            }\n        }\n\n  
      // leftovers\n        for (int i = np; i < n; ++i) {\n            y[i] += x[i]*v;\n        }\n    #endif\n#else\n    // scalar\n    for (int i = 0; i < n; ++i) {\n        y[i] += x[i]*v;\n    }\n#endif\n}\n\ninline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {\n#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)\n    const int sve_register_length = svcntb() * 8;\n    const int ggml_f16_epr = sve_register_length / 16;\n    const int ggml_f16_step = 8 * ggml_f16_epr;\n\n    GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);\n\n    int np = (n & ~(ggml_f16_step - 1));\n\n    svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;\n    svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;\n    for (int i = 0; i < np; i += ggml_f16_step) {\n        ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);\n        ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);\n        ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);\n\n        GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);\n\n        ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);\n        ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);\n        ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);\n\n        GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);\n\n        ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);\n        ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);\n        ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);\n\n        GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);\n\n        ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);\n        ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);\n        ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);\n\n        GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);\n\n        ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);\n        ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);\n        ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);\n\n        
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);\n\n        ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);\n        ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);\n        ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);\n\n        GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);\n\n        ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);\n        ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);\n        ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);\n\n        GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);\n\n        ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);\n        ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);\n        ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);\n\n        GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);\n    }\n    const int np2 = (n & ~(ggml_f16_epr - 1));\n    for (int k = np; k < np2; k += ggml_f16_epr) {\n        svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);\n        svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);\n        ry = GGML_F16x_VEC_FMA(ry, rx, vx);\n\n        GGML_F16x_VEC_STORE(y + k, ry, 0);\n    }\n\n    if (np2 < n) {\n        svbool_t pg = svwhilelt_b16(np2, n);\n        svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));\n        svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));\n        hy = svmad_f16_x(pg, hx, vx, hy);\n        svst1_f16(pg, (__fp16 *)(y + np2), hy);\n    }\n    np = n;\n#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic\n    const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);\n    const _Float16 scale = *(const _Float16*)(&s);\n\n    // calculate step size\n    const int epr = __riscv_vsetvlmax_e16m4();\n    const int step = epr * 2;\n    int np = (n & ~(step - 1));\n\n    // unroll by 2\n    for (int i = 0; i < np; i += step) {\n        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);\n        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);\n        ay0 = __riscv_vfmacc_vf_f16m4(ay0, 
scale, ax0, epr);\n        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);\n        __asm__ __volatile__ (\"\" ::: \"memory\");\n\n        vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);\n        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);\n        ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);\n        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);\n        __asm__ __volatile__ (\"\" ::: \"memory\");\n    }\n\n    // leftovers\n    int vl;\n    for (int i = np; i < n; i += vl) {\n        vl = __riscv_vsetvl_e16m4(n - i);\n        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);\n        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);\n        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);\n        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);\n    }\n    np = n;\n#elif defined(GGML_SIMD)\n    const int np = (n & ~(GGML_F16_STEP - 1));\n\n    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);\n\n    GGML_F16_VEC ax[GGML_F16_ARR];\n    GGML_F16_VEC ay[GGML_F16_ARR];\n\n    for (int i = 0; i < np; i += GGML_F16_STEP) {\n        for (int j = 0; j < GGML_F16_ARR; j++) {\n            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);\n            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);\n            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);\n\n            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);\n        }\n    }\n#else\n    const int np = 0;\n#endif\n\n    // leftovers\n    for (int i = np; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);\n    }\n}\n\n// xs and vs are byte strides of x and v\ninline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * GGML_RESTRICT y, const float * GGML_RESTRICT xv, const float * GGML_RESTRICT vv) {\n\n    const float * GGML_RESTRICT x[GGML_VEC_MAD_UNROLL];\n    const float * GGML_RESTRICT 
v[GGML_VEC_MAD_UNROLL];\n\n    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {\n        x[i] = (const float *) ((const char *) xv + i*xs);\n        v[i] = (const float *) ((const char *) vv + i*vs);\n    }\n\n#if defined(GGML_SIMD)\n    #if defined(__ARM_FEATURE_SVE)\n        // scalar Route to scalar implementation       //TODO: Write SVE code\n        for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {\n            for (int i = 0; i < n; ++i) {\n                y[i] += x[k][i]*v[k][0];\n            }\n        }\n    #elif defined(__riscv_v_intrinsic)\n        for (int i = 0, avl; i < n; i += avl) {\n            avl = __riscv_vsetvl_e32m8(n - i);\n            vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);\n            for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) {\n                vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);\n                ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);\n            }\n            __riscv_vse32_v_f32m8(&y[i], ay, avl);\n        }\n    #else\n        const int np = (n & ~(GGML_F32_STEP - 1));\n\n        GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];\n\n        for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {\n            vx[k] = GGML_F32_VEC_SET1(v[k][0]);\n        }\n\n        GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];\n        GGML_F32_VEC ay[GGML_F32_ARR];\n\n        for (int i = 0; i < np; i += GGML_F32_STEP) {\n            for (int j = 0; j < GGML_F32_ARR; j++) {\n                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);\n\n                for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {\n                    ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);\n                    ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);\n                }\n\n                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);\n            }\n        }\n\n        // leftovers\n        for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {\n            for (int i = np; i < n; ++i) {\n                y[i] += 
x[k][i]*v[k][0];\n            }\n        }\n    #endif\n#else\n    // scalar\n    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {\n        for (int i = 0; i < n; ++i) {\n            y[i] += x[k][i]*v[k][0];\n        }\n    }\n#endif\n}\n\ninline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {\n#if defined(GGML_USE_ACCELERATE)\n    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);\n#elif defined(GGML_SIMD)\n    #if defined(__ARM_FEATURE_SVE)\n        // scalar ; TODO: Write SVE code\n        for (int i = 0; i < n; ++i) {\n            y[i] = x[i]*s + b;\n        }\n    #elif defined(__riscv_v_intrinsic)\n        for (int i = 0, avl; i < n; i += avl) {\n            avl = __riscv_vsetvl_e32m8(n - i);\n            vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);\n            vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);\n            vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);\n            __riscv_vse32_v_f32m8(&y[i], ny, avl);\n        }\n    #else\n        const int np = (n & ~(GGML_F32_STEP - 1));\n\n        GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);\n        GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);\n\n        GGML_F32_VEC ay[GGML_F32_ARR];\n\n        for (int i = 0; i < np; i += GGML_F32_STEP) {\n            for (int j = 0; j < GGML_F32_ARR; j++) {\n                ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);\n                ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);\n\n                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);\n            }\n        }\n\n        // leftovers\n        for (int i = np; i < n; ++i) {\n            y[i] = x[i]*s + b;\n        }\n    #endif\n#else\n    // scalar\n    for (int i = 0; i < n; ++i) {\n        y[i] = x[i]*s + b;\n    }\n#endif\n}\n\n//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }\ninline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {\n#if 
defined(GGML_USE_ACCELERATE)\n    vDSP_vsmul(y, 1, &v, y, 1, n);\n#elif defined(GGML_SIMD)\n    #if defined(__ARM_FEATURE_SVE)\n        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;\n        const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16\n        const int ggml_f32_step = 2 * ggml_f32_epr;\n\n        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);\n        const int np = (n & ~(ggml_f32_step - 1));\n        svfloat32_t ay1;\n        svfloat32_t ay2;\n        for (int i = 0; i < np; i += ggml_f32_step) {\n            ay1 = GGML_F32_VEC_LOAD(y + i);\n            ay1 = GGML_F32_VEC_MUL(ay1, vx);\n            GGML_F32_VEC_STORE(y + i, ay1);\n\n            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);\n            ay2 = GGML_F32_VEC_MUL(ay2, vx);\n            GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);\n        }\n        // leftovers\n        // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only\n        for (int i = np; i < n; i += ggml_f32_epr) {\n            svbool_t pg = svwhilelt_b32(i, n);\n            ay1 = svld1_f32(pg, y + i);\n            ay1 = svmul_f32_m(pg, ay1, vx);\n            svst1_f32(pg, y + i, ay1);\n        }\n    #elif defined(__riscv_v_intrinsic)\n        for (int i = 0, avl; i < n; i += avl) {\n            avl = __riscv_vsetvl_e32m8(n - i);\n            vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);\n            vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);\n            __riscv_vse32_v_f32m8(&y[i], ny, avl);\n        }\n    #else\n        const int np = (n & ~(GGML_F32_STEP - 1));\n\n        GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);\n\n        GGML_F32_VEC ay[GGML_F32_ARR];\n\n        for (int i = 0; i < np; i += GGML_F32_STEP) {\n            for (int j = 0; j < GGML_F32_ARR; j++) {\n                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);\n                ay[j] = GGML_F32_VEC_MUL(ay[j], 
vx);\n\n                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);\n            }\n        }\n\n        // leftovers\n        for (int i = np; i < n; ++i) {\n            y[i] *= v;\n        }\n    #endif\n#else\n    // scalar\n    for (int i = 0; i < n; ++i) {\n        y[i] *= v;\n    }\n#endif\n}\n\ninline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {\n#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)\n    const int sve_register_length = svcntb() * 8;\n    const int ggml_f16_epr = sve_register_length / 16;\n    const int ggml_f16_step = 2 * ggml_f16_epr;\n\n    GGML_F16x_VEC vx =  GGML_F16x_VEC_SET1(v);\n    const int np = (n & ~(ggml_f16_step - 1));\n    svfloat16_t ay1, ay2;\n\n    for (int i = 0; i < np; i += ggml_f16_step) {\n        ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);\n        ay1 = GGML_F16x_VEC_MUL(ay1, vx);\n        GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);\n\n        ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);\n        ay2 = GGML_F16x_VEC_MUL(ay2, vx);\n        GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);\n    }\n    // leftovers\n    // maximum number of leftover elements will be less that ggmlF_16x_epr. 
Apply predicated svmad on available elements only\n    if (np < n) {\n        svbool_t pg = svwhilelt_b16(np, n);\n        svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));\n        svfloat16_t out = svmul_f16_m(pg, hy, vx);\n        svst1_f16(pg, (__fp16 *)(y + np), out);\n    }\n#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)\n    const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);\n    const _Float16 scale = *(const _Float16*)(&s);\n\n    // calculate step size\n    const int epr = __riscv_vsetvlmax_e16m4();\n    const int step = epr * 2;\n    const int np = (n & ~(step - 1));\n\n    // unroll by 2\n    for (int i = 0; i < np; i += step) {\n        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);\n        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);\n        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);\n        __asm__ __volatile__ (\"\" ::: \"memory\");\n\n        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);\n        ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);\n        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);\n        __asm__ __volatile__ (\"\" ::: \"memory\");\n    }\n\n    // leftovers\n    int vl;\n    for (int i = np; i < n; i += vl) {\n        vl = __riscv_vsetvl_e16m4(n - i);\n        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);\n        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);\n        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);\n    }\n#elif defined(GGML_SIMD)\n    const int np = (n & ~(GGML_F16_STEP - 1));\n\n    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);\n\n    GGML_F16_VEC ay[GGML_F16_ARR];\n\n    for (int i = 0; i < np; i += GGML_F16_STEP) {\n        for (int j = 0; j < GGML_F16_ARR; j++) {\n            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);\n            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);\n\n            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);\n        }\n    }\n\n    // leftovers\n    for (int i = 
np; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);\n    }\n#else\n    // scalar\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);\n    }\n#endif\n}\n\ninline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }\ninline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }\ninline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(x[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16(v*v);\n    }\n}\ninline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }\ninline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(sqrtf(GGML_CPU_FP16_TO_FP32(x[i])));\n    }\n}\ninline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }\ninline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(logf(GGML_CPU_FP16_TO_FP32(x[i])));\n    }\n}\ninline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }\ninline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(sinf(GGML_CPU_FP16_TO_FP32(x[i])));\n    }\n}\ninline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }\ninline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t 
* x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(cosf(GGML_CPU_FP16_TO_FP32(x[i])));\n    }\n}\ninline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }\ninline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(fabsf(GGML_CPU_FP16_TO_FP32(x[i])));\n    }\n}\ninline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }\ninline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(x[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f));\n    }\n}\ninline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }\ninline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16((GGML_CPU_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f);\n    }\n}\ninline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }\ninline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(tanhf(GGML_CPU_FP16_TO_FP32(x[i])));\n    }\n}\ninline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 
x[i] : expm1f(x[i]); }\ninline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        const float v = GGML_CPU_FP16_TO_FP32(x[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));\n    }\n}\ninline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }\ninline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(x[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : 0.f);\n    }\n}\ninline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }\ninline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) {\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(x[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? 
v : 0.f));\n    }\n}\ninline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }\ninline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(1.f / (1.f + expf(-GGML_CPU_FP16_TO_FP32(x[i]))));\n    }\n}\n// TODO: optimize performance\ninline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }\ninline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(x[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f)));\n    }\n}\ninline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }\ninline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_CPU_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f)));\n    }\n}\ninline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }\ninline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = GGML_CPU_FP32_TO_FP16(expf(GGML_CPU_FP16_TO_FP32(x[i])));\n    }\n}\n\nstatic const float GELU_COEF_A     = 0.044715f;\nstatic const float GELU_QUICK_COEF = -1.702f;\nstatic const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;\nstatic const float SQRT_2_INV      = 0.70710678118654752440084436210484f;\n\ninline static float ggml_gelu_f32(float x) {\n    return 0.5f*x*(1.0f + 
tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));\n}\n\ninline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    const uint16_t * i16 = (const uint16_t *) x;\n    for (int i = 0; i < n; ++i) {\n        y[i] = ggml_table_gelu_f16[i16[i]];\n    }\n}\n\ninline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        float xi = GGML_CPU_FP16_TO_FP32(x[i]);\n        float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));\n        y[i] = GGML_CPU_FP32_TO_FP16(res);\n    }\n}\n\n#ifdef GGML_GELU_FP16\ninline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {\n    uint16_t t;\n    for (int i = 0; i < n; ++i) {\n        if (x[i] <= -10.0f) {\n            y[i] = 0.0f;\n        } else if (x[i] >= 10.0f) {\n            y[i] = x[i];\n        } else {\n            ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);\n            memcpy(&t, &fp16, sizeof(uint16_t));\n            y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]);\n        }\n    }\n}\n#else\ninline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = ggml_gelu_f32(x[i]);\n    }\n}\n#endif\n\ninline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {\n    for (int i = 0; i < n; ++i) {\n        float xi = x[i];\n        y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));\n    }\n}\n\ninline static float ggml_gelu_quick_f32(float x) {\n    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));\n}\n\n//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n//    const uint16_t * i16 = (const uint16_t *) x;\n//    for (int i = 0; i < n; ++i) {\n//        y[i] = ggml_table_gelu_quick_f16[i16[i]];\n//    }\n//}\n\n#ifdef GGML_GELU_QUICK_FP16\ninline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {\n    uint16_t t;\n    for (int i = 0; i < 
n; ++i) {\n        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);\n        memcpy(&t, &fp16, sizeof(uint16_t));\n        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);\n    }\n}\n#else\ninline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = ggml_gelu_quick_f32(x[i]);\n    }\n}\n#endif\n\ninline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(x[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v))));\n    }\n}\n\n// Sigmoid Linear Unit (SiLU) function\ninline static float ggml_silu_f32(float x) {\n    return x/(1.0f + expf(-x));\n}\ninline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {\n    float v = GGML_CPU_FP16_TO_FP32(x);\n    return GGML_CPU_FP32_TO_FP16(v/(1.0f + expf(-v)));\n}\n\n#if __FINITE_MATH_ONLY__\n#error \"some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix\"\n#error \"ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461\"\n#endif\n\n/* Below function was borrowed from the GitHub repository:\nhttps://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */\n#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)\n    inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {\n        // Constants\n        const svfloat32_t log2_e = svdup_n_f32(1.4426950409f);\n        const svfloat32_t ln2 = svdup_n_f32(0.6931473921f);\n        const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);\n        const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1));\n        const svfloat32_t one = svdup_n_f32(1.0f);\n        const svfloat32_t inactive1 = svdup_n_f32(0.0f);\n        const svint32_t inactive2 = svdup_n_s32(0);\n\n        // Algorithm starts here\n        
svfloat32_t t0 = svmul_f32_m(pg, src, log2_e);  // y = x * log2(e)\n        svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0);         // rount to int (float)\n        svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1);         // n\n\n        t1 = svsub_f32_m(pg, t0, t1);   // a = y - floor(y)\n        t1 = svadd_f32_m(pg, t1, one);  // b = a + 1\n\n        svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17);  // v = b >> 17 (u32)\n        svfloat32_t t4 = svexpa_f32(t3);                                   // c = fexpa(v)\n        t4 = svscale_f32_m(pg, t4, t2);                                    // fexpa(v) * 2^(n)\n\n        // and_(t2.d, t1.d, not_mask17.d)\n        svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));\n        t5 = svsub_f32_m(pg, t1, t5);                // z\n        t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq);  // ln2 + half_ln2_sq * z\n        t0 = svmla_f32_m(pg, one, t5, t0);           // 1 + (ln2 * z) + (half_ln2_sq * z * z)\n        t0 = svmul_f32_m(pg, t0, t4);                // Final result\n\n        return t0;\n    }\n#endif\n\n#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)\n\ninline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {\n    const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f);\n    const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f);\n    const svfloat32_t n = svsub_f32_x(pg, z, r);\n    const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f);\n    const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23);\n    const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1))));\n    const svbool_t c = svacgt_n_f32(pg, n, 126);\n    const svfloat32_t u = svmul_f32_x(pg, b, b);\n    const svfloat32_t j = svmla_f32_x(pg,\n        svmul_n_f32_x(pg, b, 0x1.ffffecp-1f),\n        svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), 
svdup_n_f32_x(pg, 0x1.555e66p-3f), b),\n                        svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u);\n    const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000);\n    const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000));\n    const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d));\n    return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1),\n                     svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));\n}\n\n// computes silu x/(1+exp(-x)) in single precision vector\ninline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {\n    const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);\n    const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);\n    const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);\n    const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);\n    const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);\n    return svdiv_f32_x(pg, x, one_plus_exp_neg_x);\n}\n\n#elif defined(__ARM_NEON) && defined(__aarch64__)\n\n// adapted from arm limited optimized routine\n// the maximum error is 1.45358 plus 0.5 ulps\n// numbers above 88.38 will flush to infinity\n// numbers beneath -103.97 will flush to zero\ninline static float32x4_t ggml_v_expf(float32x4_t x) {\n    const float32x4_t r = vdupq_n_f32(0x1.8p23f);\n    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));\n    const float32x4_t n = vsubq_f32(z, r);\n    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,\n                                    vdupq_n_f32(0x1.7f7d1cp-20f));\n    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);\n    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));\n    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));\n    const float32x4_t u = vmulq_f32(b, b);\n    const 
float32x4_t j = vfmaq_f32(\n        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),\n        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),\n                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);\n    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))\n        return vfmaq_f32(k, j, k);\n    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));\n    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));\n    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));\n    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),\n                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));\n}\n\n// computes silu x/(1+exp(-x)) in single precision vector\ninline static float32x4_t ggml_v_silu(float32x4_t x) {\n    const float32x4_t one = vdupq_n_f32(1.0f);\n    const float32x4_t zero = vdupq_n_f32(0.0f);\n    const float32x4_t neg_x = vsubq_f32(zero, x);\n    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);\n    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);\n    return vdivq_f32(x, one_plus_exp_neg_x);\n}\n\n#elif defined(__AVX512F__) && defined(__AVX512DQ__)\n\n// adapted from arm limited optimized routine\n// the maximum error is 1.45358 plus 0.5 ulps\n// numbers above 88.38 will flush to infinity\n// numbers beneath -103.97 will flush to zero\ninline static __m512 ggml_v_expf(__m512 x) {\n  const __m512 r = _mm512_set1_ps(0x1.8p23f);\n  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);\n  const __m512 n = _mm512_sub_ps(z, r);\n  const __m512 b =\n      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),\n                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));\n  const __mmask16 d =\n      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);\n  const __m512 u = _mm512_mul_ps(b, b);\n  const __m512 j = _mm512_fmadd_ps(\n      
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,\n                                      _mm512_set1_ps(0x1.573e2ep-5f)),\n                      u,\n                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,\n                                      _mm512_set1_ps(0x1.fffdb6p-2f))),\n      u,\n      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));\n  const __m512 res = _mm512_scalef_ps(j, n);\n  if (_mm512_kortestz(d, d))\n    return res;\n  const __m512 zero = _mm512_setzero_ps();\n  const __m512 alt = _mm512_mask_blend_ps(\n      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);\n  return _mm512_mask_blend_ps(d, res, alt);\n}\n\n// computes silu x/(1+exp(-x)) in single precision vector\ninline static __m512 ggml_v_silu(__m512 x) {\n    const __m512 one = _mm512_set1_ps(1);\n    const __m512 zero = _mm512_setzero_ps();\n    const __m512 neg_x = _mm512_sub_ps(zero, x);\n    const __m512 exp_neg_x = ggml_v_expf(neg_x);\n    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);\n    return _mm512_div_ps(x, one_plus_exp_neg_x);\n}\n\n#elif defined(__AVX2__) && defined(__FMA__)\n\n// adapted from arm limited optimized routine\n// the maximum error is 1.45358 plus 0.5 ulps\n// numbers above 88.38 will flush to infinity\n// numbers beneath -103.97 will flush to zero\ninline static __m256 ggml_v_expf(__m256 x) {\n  const __m256 r = _mm256_set1_ps(0x1.8p23f);\n  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);\n  const __m256 n = _mm256_sub_ps(z, r);\n  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),\n                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));\n  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);\n  const __m256 k = _mm256_castsi256_ps(\n      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));\n  const __m256i c = _mm256_castps_si256(\n      
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),\n                    _mm256_set1_ps(126), _CMP_GT_OQ));\n  const __m256 u = _mm256_mul_ps(b, b);\n  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,\n                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,\n                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,\n                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),\n                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));\n  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))\n    return _mm256_fmadd_ps(j, k, k);\n  const __m256i g = _mm256_and_si256(\n      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),\n      _mm256_set1_epi32(0x82000000u));\n  const __m256 s1 =\n      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));\n  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));\n  const __m256i d = _mm256_castps_si256(\n      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),\n                    _mm256_set1_ps(192), _CMP_GT_OQ));\n  return _mm256_or_ps(\n      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),\n      _mm256_andnot_ps(\n          _mm256_castsi256_ps(d),\n          _mm256_or_ps(\n              _mm256_and_ps(_mm256_castsi256_ps(c),\n                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),\n              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));\n}\n\n// computes silu x/(1+exp(-x)) in single precision vector\ninline static __m256 ggml_v_silu(__m256 x) {\n    const __m256 one = _mm256_set1_ps(1);\n    const __m256 zero = _mm256_setzero_ps();\n    const __m256 neg_x = _mm256_sub_ps(zero, x);\n    const __m256 exp_neg_x = ggml_v_expf(neg_x);\n    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);\n  
  return _mm256_div_ps(x, one_plus_exp_neg_x);\n}\n\n#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON\n\n#if defined(__FMA__)\n#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)\n#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)\n#else\n#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)\n#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))\n#endif\n\n// adapted from arm limited optimized routine\n// the maximum error is 1.45358 plus 0.5 ulps\n// numbers above 88.38 will flush to infinity\n// numbers beneath -103.97 will flush to zero\ninline static __m128 ggml_v_expf(__m128 x) {\n    const __m128 r = _mm_set1_ps(0x1.8p23f);\n    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);\n    const __m128 n = _mm_sub_ps(z, r);\n    const __m128 b =\n        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));\n    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);\n    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));\n    const __m128i c =\n        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));\n    const __m128 u = _mm_mul_ps(b, b);\n    const __m128 j =\n        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,\n                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),\n                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));\n    if (!_mm_movemask_epi8(c))\n        return MADD128(j, k, k);\n    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),\n                                    _mm_set1_epi32(0x82000000u));\n    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));\n    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));\n    const __m128i d =\n        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));\n    return _mm_or_ps(\n        
_mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),\n        _mm_andnot_ps(_mm_castsi128_ps(d),\n                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),\n                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));\n}\n\n// computes silu x/(1+exp(-x)) in single precision vector\ninline static __m128 ggml_v_silu(__m128 x) {\n    const __m128 one = _mm_set1_ps(1);\n    const __m128 zero = _mm_setzero_ps();\n    const __m128 neg_x = _mm_sub_ps(zero, x);\n    const __m128 exp_neg_x = ggml_v_expf(neg_x);\n    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);\n    return _mm_div_ps(x, one_plus_exp_neg_x);\n}\n\n#elif defined(__riscv_v_intrinsic)\n\n// adapted from arm limited optimized routine\n// the maximum error is 1.45358 plus 0.5 ulps\n// numbers above 88.38 will flush to infinity\n// numbers beneath -103.97 will flush to zero\ninline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {\n    const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);\n#ifdef __riscv_xtheadvector\n    // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')\n    vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);\n    z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);\n#else\n    const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);\n#endif\n    const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);\n    const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),\n                                                    0x1.7f7d1cp-20f, n, vl);\n    const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);\n    const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f\n    const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);\n    const vfloat32m2_t u = 
__riscv_vfmul_vv_f32m2(b, b, vl);\n    const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(\n        __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),\n        __riscv_vfmacc_vv_f32m2(\n            __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),\n            __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),\n            u, vl), u, vl);\n    if (!__riscv_vcpop_m_b16(c, vl))\n        return __riscv_vfmacc_vv_f32m2(k, j, k, vl);\n    const vbool16_t  dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);\n    const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);\n    const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));\n    const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));\n    const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(\n        __riscv_vfmacc_vv_f32m2(k, k, j, vl),\n        __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),\n        c, vl);\n    return __riscv_vmerge_vvm_f32m2(\n        r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),\n        __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),\n        vl);\n}\n\n// computes silu x/(1+exp(-x)) in single precision vector\ninline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {\n    const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);\n    const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);\n    const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);\n    return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);\n}\n\n#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic\n\ninline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = ggml_silu_f16(x[i]);\n    }\n}\n\ninline static float ggml_silu_backward_f32(float x, float dy) {\n    
const float s = 1.0f/(1.0f + expf(-x));\n    return dy*s*(1.0f + x*(1.0f - s));\n}\n\ninline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) {\n    const float v = GGML_CPU_FP16_TO_FP32(x);\n    const float s = 1.0f/(1.0f + expf(-v));\n    return GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s)));\n}\n\ninline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {\n    for (int i = 0; i < n; ++i) {\n        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);\n    }\n}\n\ninline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {\n    for (int i = 0; i < n; ++i) {\n        dx[i] = ggml_silu_backward_f16(x[i], dy[i]);\n    }\n}\n\ninline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;\n    }\n}\n\ninline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(x[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? 
v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);\n    }\n}\n\n#ifdef GGML_GELU_FP16\ninline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {\n    uint16_t t;\n    for (int i = 0; i < n; ++i) {\n        if (x[i] <= -10.0f) {\n            y[i] = 0.0f;\n        } else if (x[i] >= 10.0f) {\n            y[i] = x[i] * g[i];\n        } else {\n            ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);\n            memcpy(&t, &fp16, sizeof(uint16_t));\n            y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];\n        }\n    }\n}\n#else\ninline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = ggml_gelu_f32(x[i]) * g[i];\n    }\n}\n#endif\n\ninline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {\n    const uint16_t * i16 = (const uint16_t *) x;\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(g[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);\n    }\n}\n\nvoid ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);\n\ninline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {\n    for (int i = 0; i < n; ++i) {\n        float xi = GGML_CPU_FP16_TO_FP32(x[i]);\n        float gi = GGML_CPU_FP16_TO_FP32(g[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);\n    }\n}\n\ninline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {\n    for (int i = 0; i < n; ++i) {\n        float xi = x[i];\n        y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];\n    }\n}\n\ninline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {\n    for (int i = 0; i < n; ++i) {\n        float xi = 
GGML_CPU_FP16_TO_FP32(x[i]);\n        float gi = GGML_CPU_FP16_TO_FP32(g[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);\n    }\n}\n\n#ifdef GGML_GELU_QUICK_FP16\ninline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {\n    uint16_t t;\n    for (int i = 0; i < n; ++i) {\n        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);\n        memcpy(&t, &fp16, sizeof(uint16_t));\n        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];\n    }\n}\n#else\ninline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {\n    for (int i = 0; i < n; ++i) {\n        y[i] = ggml_gelu_quick_f32(x[i]) * g[i];\n    }\n}\n#endif\n\ninline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {\n    const uint16_t * i16 = (const uint16_t *) x;\n    for (int i = 0; i < n; ++i) {\n        float v = GGML_CPU_FP16_TO_FP32(g[i]);\n        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);\n    }\n}\n\ninline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {\n#ifndef GGML_USE_ACCELERATE\n    ggml_float sum = 0.0;\n    for (int i = 0; i < n; ++i) {\n        sum += (ggml_float)x[i];\n    }\n    *s = (float)sum;\n#else\n    vDSP_sve(x, 1, s, n);\n#endif\n}\n\ninline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {\n    for (int i = 0; i < n; ++i) {\n        if (i == 0) {\n            y[i] = x[i];\n        } else {\n            y[i] = y[i - 1] + x[i];\n        }\n    }\n}\n\ninline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {\n    ggml_float sum = 0.0;\n    for (int i = 0; i < n; ++i) {\n        sum += (ggml_float)x[i];\n    }\n    *s = sum;\n}\n\ninline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {\n    float sum = 0.0f;\n    for 
(int i = 0; i < n; ++i) {\n        sum += GGML_CPU_FP16_TO_FP32(x[i]);\n    }\n    *s = sum;\n}\n\ninline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {\n    float sum = 0.0f;\n    for (int i = 0; i < n; ++i) {\n        sum += GGML_BF16_TO_FP32(x[i]);\n    }\n    *s = sum;\n}\n\ninline static void ggml_vec_max_f32(const int n, float * s, const float * x) {\n#ifndef GGML_USE_ACCELERATE\n    float max = -INFINITY;\n    for (int i = 0; i < n; ++i) {\n        max = MAX(max, x[i]);\n    }\n    *s = max;\n#else\n    vDSP_maxv(x, 1, s, n);\n#endif\n}\n\ninline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {\n    ggml_vec_norm_f32(n, s, x);\n    *s = 1.f/(*s);\n}\n\ninline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {\n    float max = -INFINITY;\n    int idx = 0;\n    for (int i = 0; i < n; ++i) {\n        max = MAX(max, x[i]);\n        if (max == x[i]) { idx = i; }\n    }\n    *s = idx;\n}\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-cuda/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.18)  # for CMAKE_CUDA_ARCHITECTURES\n\nfind_package(CUDAToolkit)\n\nif (CUDAToolkit_FOUND)\n    message(STATUS \"CUDA Toolkit found\")\n\n    if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)\n        # native == GPUs available at build time\n        # 50     == Maxwell, lowest CUDA 12 standard\n        # 60     == P100, FP16 CUDA intrinsics\n        # 61     == Pascal, __dp4a instruction (per-byte integer dot product)\n        # 70     == V100, FP16 tensor cores\n        # 75     == Turing, int8 tensor cores\n        # 80     == Ampere, asynchronous data loading, faster tensor core instructions\n        # 86     == RTX 3000, needs CUDA v11.1\n        # 89     == RTX 4000, needs CUDA v11.8\n        # 120    == Blackwell, needs CUDA v12.8, FP4 tensor cores\n        #\n        # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run\n        # XX-real    == compile CUDA code as device code for this specific architecture\n        # no suffix  == compile as both PTX and device code\n        #\n        # The default behavior for a non-native is to build virtual architectures as needed to cover all features needed\n        #     for best performance and to also build real architectures for the most commonly used GPUs.\n        if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL \"11.6\" AND CMAKE_VERSION VERSION_GREATER_EQUAL \"3.24\")\n            set(CMAKE_CUDA_ARCHITECTURES \"native\")\n        else()\n            if (CUDAToolkit_VERSION VERSION_LESS \"13\")\n                list(APPEND CMAKE_CUDA_ARCHITECTURES 50-virtual 61-virtual 70-virtual)\n            endif ()\n\n            list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real)\n\n            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL \"11.8\")\n                list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)\n            endif()\n\n            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL \"12.8\")\n                
# The CUDA architecture 120f-virtual would in principle work for Blackwell support\n                #     but the newly added \"f\" suffix conflicted with a preexising regex for validating CUDA architectures in CMake.\n                # So either a recent CMake version or one with the backported fix is needed.\n                # The following versions should work:\n                #   - CMake >= v3.31.8 && CMake < v4.0.0\n                #   - CMake >= v4.0.2\n                # This is NOT documented in the CMake release notes,\n                #     check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead.\n                # However, the architectures 120a-real and 121a-real should work with basically any CMake version and\n                #     until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell.\n                list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real)\n            endif()\n            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL \"12.9\")\n                list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real)\n            endif()\n        endif()\n    endif()\n\n    enable_language(CUDA)\n\n    # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit\n    if (GGML_CUDA_CUB_3DOT2)\n        include(FetchContent)\n\n        FetchContent_Declare(\n            CCCL\n            GIT_REPOSITORY https://github.com/nvidia/cccl.git\n            GIT_TAG        v3.2.0\n            GIT_SHALLOW    TRUE\n        )\n\n        FetchContent_MakeAvailable(CCCL)\n    endif()\n\n    # Replace any plain 12X CUDA architectures with their \"architecture-specific\" equivalents 12Xa.\n    # 12X is forwards-compatible, 12Xa is not.\n    # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.\n    # But while 12X vs. 
12Xa can be checked in device code, there is (to my knowledge) no easy way to do the same check in host code.\n    # So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released.\n    foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE)\n        set(FIXED_ARCHS \"\")\n        foreach(ARCH IN LISTS ${ARCHS})\n            if (ARCH MATCHES \"^12[0-9](-real|-virtual)?$\")\n                string(REGEX REPLACE \"^(12[0-9])((-real|-virtual)?)$\" \"\\\\1a\\\\2\" FIXED_ARCH ${ARCH})\n                message(STATUS \"Replacing ${ARCH} in ${ARCHS} with ${FIXED_ARCH}\")\n                list(APPEND FIXED_ARCHS \"${FIXED_ARCH}\")\n            else()\n                list(APPEND FIXED_ARCHS \"${ARCH}\")\n            endif()\n        endforeach()\n        set(${ARCHS} ${FIXED_ARCHS})\n    endforeach()\n\n    # If we try to compile a \"native\" build it will use the 12X architectures and fail.\n    # So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa.\n    # But if no GPUs are connected at build time, CMAKE_CUDA_ARCHITECTURES_NATIVE will contain garbage that we should not use.\n    if (CMAKE_CUDA_ARCHITECTURES STREQUAL \"native\" AND CMAKE_CUDA_ARCHITECTURES_NATIVE MATCHES \"^[0-9]+(a|f)?(-real|-virtual)?(;[0-9]+(a|f)?(-real|-virtual)?|;)*$\")\n        set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NATIVE})\n    endif()\n    message(STATUS \"Using CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CUDA_ARCHITECTURES_NATIVE=${CMAKE_CUDA_ARCHITECTURES_NATIVE}\")\n\n    file(GLOB   GGML_HEADERS_CUDA \"*.cuh\")\n    list(APPEND GGML_HEADERS_CUDA \"../../include/ggml-cuda.h\")\n\n    file(GLOB   GGML_SOURCES_CUDA \"*.cu\")\n    file(GLOB   SRCS \"template-instances/fattn-tile*.cu\")\n    list(APPEND GGML_SOURCES_CUDA ${SRCS})\n    file(GLOB   SRCS \"template-instances/fattn-mma*.cu\")\n    list(APPEND GGML_SOURCES_CUDA ${SRCS})\n    file(GLOB 
  SRCS \"template-instances/mmq*.cu\")\n    list(APPEND GGML_SOURCES_CUDA ${SRCS})\n    file(GLOB   SRCS \"template-instances/mmf*.cu\")\n    list(APPEND GGML_SOURCES_CUDA ${SRCS})\n\n    if (GGML_CUDA_FA_ALL_QUANTS)\n        file(GLOB   SRCS \"template-instances/fattn-vec*.cu\")\n        list(APPEND GGML_SOURCES_CUDA ${SRCS})\n        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)\n    else()\n        file(GLOB   SRCS \"template-instances/fattn-vec*q4_0-q4_0.cu\")\n        list(APPEND GGML_SOURCES_CUDA ${SRCS})\n        file(GLOB   SRCS \"template-instances/fattn-vec*q8_0-q8_0.cu\")\n        list(APPEND GGML_SOURCES_CUDA ${SRCS})\n        file(GLOB   SRCS \"template-instances/fattn-vec*f16-f16.cu\")\n        list(APPEND GGML_SOURCES_CUDA ${SRCS})\n    endif()\n\n    ggml_add_backend_library(ggml-cuda\n                             ${GGML_HEADERS_CUDA}\n                             ${GGML_SOURCES_CUDA}\n                            )\n\n    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})\n\n    if (GGML_CUDA_GRAPHS)\n        add_compile_definitions(GGML_CUDA_USE_GRAPHS)\n    endif()\n\n    if (GGML_CUDA_FORCE_MMQ)\n        add_compile_definitions(GGML_CUDA_FORCE_MMQ)\n    endif()\n\n    if (GGML_CUDA_FORCE_CUBLAS)\n        add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)\n    endif()\n\n    if (GGML_CUDA_NO_VMM)\n        add_compile_definitions(GGML_CUDA_NO_VMM)\n    endif()\n\n    if (NOT GGML_CUDA_FA)\n        add_compile_definitions(GGML_CUDA_NO_FA)\n    endif()\n\n    if (GGML_CUDA_NO_PEER_COPY)\n        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)\n    endif()\n\n    if (GGML_STATIC)\n        if (WIN32)\n            # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library\n            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)\n        else ()\n            if (GGML_CUDA_CUB_3DOT2)\n                target_link_libraries(ggml-cuda PRIVATE  CCCL::CCCL)\n            
endif()\n            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL \"10.1\")\n                target_link_libraries(ggml-cuda PRIVATE  CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)\n            else()\n                target_link_libraries(ggml-cuda PRIVATE  CUDA::cudart_static CUDA::cublas_static)\n            endif()\n        endif()\n    else()\n        if (GGML_CUDA_CUB_3DOT2)\n            target_link_libraries(ggml-cuda PRIVATE  CCCL::CCCL)\n        endif()\n        target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)\n    endif()\n\n    if (GGML_CUDA_NO_VMM)\n        # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)\n    else()\n        target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver)\n    endif()\n\n    set(CUDA_CXX_FLAGS \"\")\n\n    set(CUDA_FLAGS -use_fast_math -extended-lambda)\n\n    if (GGML_CUDA_DEBUG)\n        list(APPEND CUDA_FLAGS -lineinfo)\n        add_compile_definitions(GGML_CUDA_DEBUG)\n    endif()\n\n    if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL \"12.8\")\n        # Options are:\n        # - none (not recommended)\n        # - speed (nvcc's default)\n        # - balance\n        # - size\n        list(APPEND CUDA_FLAGS -compress-mode=${GGML_CUDA_COMPRESSION_MODE})\n    endif()\n\n    if (GGML_FATAL_WARNINGS)\n        list(APPEND CUDA_FLAGS -Werror all-warnings)\n    endif()\n\n    if (GGML_ALL_WARNINGS AND NOT MSVC)\n        set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)\n        if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL \"\")\n            list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})\n        endif()\n\n        execute_process(\n            COMMAND ${NVCC_CMD} -Xcompiler --version\n            OUTPUT_VARIABLE CUDA_CCFULLVER\n            ERROR_QUIET\n        )\n\n        if (NOT CUDA_CCFULLVER MATCHES clang)\n            set(CUDA_CCID \"GNU\")\n            execute_process(\n                COMMAND ${NVCC_CMD} -Xcompiler \"-dumpfullversion -dumpversion\"\n        
        OUTPUT_VARIABLE CUDA_CCVER\n                ERROR_QUIET\n                OUTPUT_STRIP_TRAILING_WHITESPACE\n            )\n        else()\n            if (CUDA_CCFULLVER MATCHES Apple)\n                set(CUDA_CCID \"AppleClang\")\n            else()\n                set(CUDA_CCID \"Clang\")\n            endif()\n            string(REGEX REPLACE \"^.* version ([0-9.]*).*$\" \"\\\\1\" CUDA_CCVER ${CUDA_CCFULLVER})\n        endif()\n\n        message(STATUS \"CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}\")\n\n        ggml_get_flags(${CUDA_CCID} ${CUDA_CCVER})\n        list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS})  # This is passed to -Xcompiler later\n    endif()\n\n    if (NOT MSVC)\n        list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)\n    else()\n        # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC\n        # https://github.com/NVIDIA/cccl/pull/6827\n        list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)\n    endif()\n\n    list(JOIN   CUDA_CXX_FLAGS \" \" CUDA_CXX_FLAGS_JOINED)  # pass host compiler flags as a single argument\n\n    if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL \"\")\n        list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})\n    endif()\n\n    target_compile_options(ggml-cuda PRIVATE \"$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>\")\nelse()\n    message(FATAL_ERROR \"CUDA Toolkit not found\")\nendif()\n"
  },
  {
    "path": "src/ggml-cuda/acc.cu",
    "content": "#include \"acc.cuh\"\n\nstatic __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,\n        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,\n        const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {\n    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;\n\n    if (i >= ne) {\n        return;\n    }\n\n    int64_t src1_idx = i - offset;\n\n    int64_t tmp = src1_idx;\n    const int64_t i13 = tmp / s13;\n    tmp -= i13 * s13;\n    const int64_t i12 = tmp / s12;\n    tmp -= i12 * s12;\n    const int64_t i11 = tmp / s11;\n    tmp -= i11 * s11;\n    const int64_t i10 = tmp;\n\n    float val = x[i];\n    if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {\n        val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];\n    }\n    dst[i] = val;\n}\n\nstatic void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,\n        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,\n        const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {\n    const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;\n    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);\n}\n\nvoid ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float       * dst_d  = (float       *)  dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(src1));\n    
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));\n    GGML_ASSERT(ggml_is_contiguously_allocated(dst));\n\n    const int64_t s1     = dst->op_params[0] / sizeof(float);\n    const int64_t s2     = dst->op_params[1] / sizeof(float);\n    const int64_t s3     = dst->op_params[2] / sizeof(float);\n    const int64_t offset = dst->op_params[3] / sizeof(float);\n\n    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/acc.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_ACC_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/add-id.cu",
    "content": "#include \"add-id.cuh\"\n\nstatic __global__ void add_id_kernel(\n        const float * src0, const float * src1, const int32_t * src2, float * dst,\n        int64_t ne0, int64_t ne1,\n        size_t nb01, size_t nb02,\n        size_t nb11,\n        size_t nb21\n    ) {\n\n    const int64_t i1 = blockIdx.x;\n    const int64_t i2 = blockIdx.y;\n\n    const int i11 = *(const int32_t *) ((const char *) src2 + i1*sizeof(int32_t) + i2*nb21);\n\n    const size_t nb1 = ne0 * sizeof(float);\n    const size_t nb2 = ne1 * nb1;\n\n    float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);\n    const float * src0_row = (const float *)((const char *)src0 +  i1*nb01 + i2*nb02);\n    const float * src1_row = (const float *)((const char *)src1 + i11*nb11);\n\n    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {\n        dst_row[i0] = src0_row[i0] + src1_row[i0];\n    }\n}\n\nvoid ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n\n    GGML_TENSOR_TERNARY_OP_LOCALS\n\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(src2->type == GGML_TYPE_I32);\n\n    GGML_ASSERT(nb00 == sizeof(float));\n    GGML_ASSERT(nb10 == sizeof(float));\n    GGML_ASSERT(nb20 == sizeof(int32_t));\n\n    const float * src0_d = (const float *)src0->data;\n    const float * src1_d = (const float *)src1->data;\n    const int32_t * src2_d = (const int32_t *)src2->data;\n    float * dst_d = (float *)dst->data;\n\n    int threads = std::min((int)ne00, 768); // cols\n    dim3 blocks(ne01, ne02); // n_experts_used, n_tokens\n    add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(\n        src0_d, src1_d, src2_d, dst_d,\n        ne0, ne1,\n        nb01, nb02,\n        nb11,\n        nb21\n    );\n}\n"
  },
  {
    "path": "src/ggml-cuda/add-id.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/arange.cu",
    "content": "#include \"arange.cuh\"\n\nstatic __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {\n    // blockIdx.x: index of the chunk of ne0 handled by this block (chunk size = blockDim.x)\n    int nidx = threadIdx.x + blockIdx.x * blockDim.x;\n    if (nidx >= ne0) {\n        return;\n    }\n    dst[nidx] = start + step * nidx;\n}\n\nstatic void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {\n    int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;\n    arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start,  step);\n}\n\nvoid ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    float start;\n    float stop;\n    float step;\n    memcpy(&start, (float *)dst->op_params + 0, sizeof(float));\n    memcpy(&stop,  (float *)dst->op_params + 1, sizeof(float));\n    memcpy(&step,  (float *)dst->op_params + 2, sizeof(float));\n\n    int64_t steps = (int64_t)ceil((stop - start) / step);\n    GGML_ASSERT(ggml_nelements(dst) == steps);\n\n    arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/arange.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_ARANGE_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/argmax.cu",
    "content": "#include <algorithm>\n#include <cstdint>\n\n#include \"argmax.cuh\"\n#include \"common.cuh\"\n#include \"sum.cuh\"\n\nstatic __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {\n    const int64_t row = blockIdx.x;\n\n    float maxval = -FLT_MAX;\n    int   argmax = -1;\n    const float * rowx = x + row * ncols;\n\n    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {\n        const float val = rowx[col];\n        if (val > maxval) {\n            maxval = val;\n            argmax = col;\n        }\n    }\n\n#pragma unroll\n    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {\n        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);\n        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);\n        if (val > maxval) {\n            maxval = val;\n            argmax = col;\n        }\n    }\n\n    const int n_warps = blockDim.x / WARP_SIZE;\n    const int lane_id = threadIdx.x % WARP_SIZE;\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    if (n_warps > 1) {\n        constexpr int    max_warps = 1024 / WARP_SIZE;\n        __shared__ float shared_maxval[max_warps];\n        __shared__ int   shared_argmax[max_warps];\n        if (lane_id == 0) {\n            shared_maxval[warp_id] = maxval;\n            shared_argmax[warp_id] = argmax;\n        }\n\n        __syncthreads();\n\n        if (warp_id == 0) {\n            if (lane_id < n_warps) {\n                maxval = shared_maxval[lane_id];\n                argmax = shared_argmax[lane_id];\n            }\n#pragma unroll\n            for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {\n                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);\n                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);\n                if (val > maxval) {\n                    maxval = val;\n                    argmax = 
col;\n                }\n            }\n        }\n    }\n\n    if (warp_id == 0 && lane_id == 0) {\n        dst[row] = argmax;\n    }\n}\n\nvoid ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_I32);\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    const int64_t ne00  = src0->ne[0];\n    const int64_t nrows = ggml_nrows(src0);\n\n    const float * src0_d = (const float *) src0->data;\n    int32_t     * dst_d  = (int32_t     *) dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    const int64_t num_blocks = nrows;\n    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);\n    const dim3 blocks_dim(num_threads, 1, 1);\n    const dim3 blocks_num(num_blocks, 1, 1);\n\n    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);\n}\n"
  },
  {
    "path": "src/ggml-cuda/argmax.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/argsort.cu",
    "content": "#include \"argsort.cuh\"\n\n#ifdef GGML_CUDA_USE_CUB\n#    include <cub/cub.cuh>\n#    if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)\n#        define STRIDED_ITERATOR_AVAILABLE\n#    endif\nusing namespace cub;\n#endif  // GGML_CUDA_USE_CUB\n\nstatic __global__ void init_indices(int * indices, const int ncols, const int nrows) {\n    const int col = blockIdx.x * blockDim.x + threadIdx.x;\n    const int row = blockIdx.y;\n\n    if (col < ncols && row < nrows) {\n        indices[row * ncols + col] = col;\n    }\n}\n\n#ifndef STRIDED_ITERATOR_AVAILABLE\nstatic __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {\n    const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx <= nrows) {\n        offsets[idx] = idx * ncols;\n    }\n}\n#endif  // STRIDED_ITERATOR_AVAILABLE\n\n#ifdef GGML_CUDA_USE_CUB\nvoid argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,\n                              const float *    x,\n                              int *            dst,\n                              const int        ncols,\n                              const int        nrows,\n                              ggml_sort_order  order,\n                              cudaStream_t     stream) {\n    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);\n    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);\n\n    int *   temp_indices = temp_indices_alloc.get();\n    float * temp_keys    = temp_keys_alloc.get();\n\n    static const int block_size = 256;\n    const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);\n    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);\n\n#ifdef STRIDED_ITERATOR_AVAILABLE\n    auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);\n#else\n    ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);\n    int *                     offset_iterator = offsets_alloc.get();\n    const dim3         
       offset_grid((nrows + block_size - 1) / block_size);\n    init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);\n#endif\n    CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));\n\n    size_t temp_storage_bytes = 0;\n\n    if (order == GGML_SORT_ORDER_ASC) {\n        if (nrows == 1) {\n            DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)\n                                       temp_indices, dst,                                  // values (indices)\n                                       ncols, 0, sizeof(float) * 8, stream);\n        } else {\n            DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)\n                                           temp_indices, dst,                                  // values (indices)\n                                           ncols * nrows, nrows,  // num items, num segments\n                                           offset_iterator, offset_iterator + 1, stream);\n        }\n    } else {\n        if (nrows == 1) {\n            DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)\n                                                 temp_indices, dst,                                  // values (indices)\n                                                 ncols, 0, sizeof(float) * 8, stream);\n        } else {\n            DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,\n                                                     dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,\n                                                     stream);\n        }\n    }\n\n    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);\n    void *                        d_temp_storage = temp_storage_alloc.get();\n\n   
 if (order == GGML_SORT_ORDER_ASC) {\n        if (nrows == 1) {\n            DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)\n                                       temp_indices, dst,  // values (indices)\n                                       ncols, 0, sizeof(float) * 8, stream);\n        } else {\n            DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,\n                                           ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);\n        }\n    } else {\n        if (nrows == 1) {\n            DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)\n                                                 temp_indices, dst,                                  // values (indices)\n                                                 ncols, 0, sizeof(float) * 8, stream);\n        } else {\n            DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,\n                                                     temp_indices, dst, ncols * nrows, nrows, offset_iterator,\n                                                     offset_iterator + 1, stream);\n        }\n    }\n}\n#endif  // GGML_CUDA_USE_CUB\n\n// Bitonic sort implementation\ntemplate<typename T>\nstatic inline __device__ void ggml_cuda_swap(T & a, T & b) {\n    T tmp = a;\n    a = b;\n    b = tmp;\n}\n\ntemplate<ggml_sort_order order>\nstatic __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {\n    // bitonic sort\n    int col = threadIdx.x;\n    int row = blockIdx.x;\n\n    if (col >= ncols_pad) {\n        return;\n    }\n\n    const float * x_row = x + row * ncols;\n    extern __shared__ int dst_row[];\n\n    // initialize indices\n    dst_row[col] = col;\n\n    __syncthreads();\n\n    for (int k = 2; k <= ncols_pad; k *= 
2) {\n        for (int j = k / 2; j > 0; j /= 2) {\n            int ixj = col ^ j;\n            if (ixj > col) {\n                if ((col & k) == 0) {\n                    if (dst_row[col] >= ncols ||\n                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?\n                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :\n                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))\n                    ) {\n                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);\n                    }\n                } else {\n                    if (dst_row[ixj] >= ncols ||\n                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?\n                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :\n                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))\n                    ) {\n                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);\n                    }\n                }\n            }\n            __syncthreads();\n        }\n    }\n\n    // copy the result to dst without the padding\n    if (col < ncols) {\n        dst[row * ncols + col] = dst_row[col];\n    }\n}\n\nstatic int next_power_of_2(int x) {\n    int n = 1;\n    while (n < x) {\n        n *= 2;\n    }\n    return n;\n}\n\nvoid argsort_f32_i32_cuda_bitonic(const float *   x,\n                                  int *           dst,\n                                  const int       ncols,\n                                  const int       nrows,\n                                  ggml_sort_order order,\n                                  cudaStream_t    stream) {\n    // bitonic sort requires ncols to be power of 2\n    const int ncols_pad = next_power_of_2(ncols);\n\n    const dim3 block_dims(ncols_pad, 1, 1);\n    const dim3 block_nums(nrows, 1, 1);\n    const size_t shared_mem = ncols_pad * sizeof(int);\n\n    // FIXME: this limit could be raised by ~2-4x on Ampere or newer\n    GGML_ASSERT(shared_mem 
<= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);\n\n    if (order == GGML_SORT_ORDER_ASC) {\n        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>\n            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);\n    } else if (order == GGML_SORT_ORDER_DESC) {\n        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>\n            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\nvoid ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_I32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    const int64_t ncols = src0->ne[0];\n    const int64_t nrows = ggml_nrows(src0);\n\n    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];\n\n#ifdef GGML_CUDA_USE_CUB\n    const int    ncols_pad      = next_power_of_2(ncols);\n    const size_t shared_mem     = ncols_pad * sizeof(int);\n    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;\n\n    if (shared_mem > max_shared_mem || ncols > 1024) {\n        ggml_cuda_pool & pool = ctx.pool();\n        argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);\n    } else {\n        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);\n    }\n#else\n    argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);\n#endif\n}\n"
  },
  {
    "path": "src/ggml-cuda/argsort.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\n#ifdef GGML_CUDA_USE_CUB\nvoid argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,\n                              const float *    x,\n                              int *            dst,\n                              const int        ncols,\n                              const int        nrows,\n                              ggml_sort_order  order,\n                              cudaStream_t     stream);\n#endif  // GGML_CUDA_USE_CUB\nvoid argsort_f32_i32_cuda_bitonic(const float *   x,\n                                  int *           dst,\n                                  const int       ncols,\n                                  const int       nrows,\n                                  ggml_sort_order order,\n                                  cudaStream_t    stream);\n"
  },
  {
    "path": "src/ggml-cuda/binbcast.cu",
    "content": "#include \"binbcast.cuh\"\n#include <cstdint>\n#include <utility>\n\nstatic __device__ __forceinline__ float op_repeat(const float a, const float b) {\n    return b;\n    GGML_UNUSED(a);\n}\n\nstatic __device__ __forceinline__ float op_add(const float a, const float b) {\n    return a + b;\n}\n\nstatic __device__ __forceinline__ float op_sub(const float a, const float b) {\n    return a - b;\n}\n\nstatic __device__ __forceinline__ float op_mul(const float a, const float b) {\n    return a * b;\n}\n\nstatic __device__ __forceinline__ float op_div(const float a, const float b) {\n    return a / b;\n}\n\ntemplate <float (*bin_op)(const float, const float),\n          typename src0_t,\n          typename src1_t,\n          typename dst_t,\n          typename... src1_ptrs>\nstatic __global__ void k_bin_bcast(const src0_t *         src0,\n                                   const src1_t *         src1,\n                                   dst_t *                dst,\n                                   const int              ne0,\n                                   const int              ne1,\n                                   const int              ne2,\n                                   const uint3            ne3,\n                                   const uint3            ne10,\n                                   const uint3            ne11,\n                                   const uint3            ne12,\n                                   const uint3            ne13,\n                                 /*const int              s0,*/\n                                   const int              s1,\n                                   const int              s2,\n                                   const int              s3,\n                                   const int              s00,\n                                   const int              s01,\n                                   const int              s02,\n                                   const int    
          s03,\n                                   const int              s10,\n                                   const int              s11,\n                                   const int              s12,\n                                   const int              s13,\n                                   src1_ptrs... src1s) {\n    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;\n    const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);\n    const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);\n    const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);\n\n    if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {\n        return;\n    }\n\n    const uint32_t i11 = fastmodulo(i1, ne11);\n    const uint32_t i12 = fastmodulo(i2, ne12);\n    const uint32_t i13 = fastmodulo(i3, ne13);\n\n    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;\n    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;\n    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;\n\n    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;\n    dst_t * dst_row = dst + i_dst;\n\n    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {\n        const uint32_t i10 = fastmodulo(i0, ne10);\n\n        float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;\n        if constexpr (sizeof...(src1_ptrs) > 0) {\n            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));\n        } else {\n            result = bin_op(result, (float)src1[i_src1 + i10*s10]);\n        }\n\n        dst_row[i0] = (dst_t) result;\n    }\n}\n\ntemplate <float (*bin_op)(const float, const float),\n          typename src0_t,\n          typename src1_t,\n          typename dst_t,\n          typename... 
src1_ptrs>\nstatic __global__ void k_bin_bcast_unravel(const src0_t *         src0,\n                                           const src1_t *         src1,\n                                           dst_t *                dst,\n                                           const uint3            ne0,\n                                           const uint3            ne1,\n                                           const uint3            ne2,\n                                           const uint32_t         ne3,\n                                           const uint3            prod_012,\n                                           const uint3            prod_01,\n                                           const uint3            ne10,\n                                           const uint3            ne11,\n                                           const uint3            ne12,\n                                           const uint3            ne13,\n                                         /*const int              s0,*/\n                                           const int              s1,\n                                           const int              s2,\n                                           const int              s3,\n                                           const int              s00,\n                                           const int              s01,\n                                           const int              s02,\n                                           const int              s03,\n                                           const int              s10,\n                                           const int              s11,\n                                           const int              s12,\n                                           const int              s13,\n                                           src1_ptrs... 
src1s) {\n    const int i = blockDim.x*blockIdx.x + threadIdx.x;\n\n    const uint32_t i3 = fastdiv(i, prod_012);\n    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);\n    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);\n    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;\n\n    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {\n        return;\n    }\n\n    const int i11 = fastmodulo(i1, ne11);\n    const int i12 = fastmodulo(i2, ne12);\n    const int i13 = fastmodulo(i3, ne13);\n\n    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;\n    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;\n    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;\n\n    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;\n    dst_t * dst_row = dst + i_dst;\n\n    const int i10 = fastmodulo(i0, ne10);\n\n    float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;\n    if constexpr (sizeof...(src1_ptrs) > 0) {\n        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));\n    } else {\n        result = bin_op(result, (float)src1[i_src1 + i10*s10]);\n    }\n\n    dst_row[i0] = (dst_t) result;\n}\n\ntemplate <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... 
I>\nstatic void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,\n                                  const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,\n                                  cudaStream_t stream, std::index_sequence<I...>) {\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    int nr0 = ne10 / ne0;\n    int nr1 = ne11 / ne1;\n    int nr2 = ne12 / ne2;\n    int nr3 = ne13 / ne3;\n\n    int nr[4] = { nr0, nr1, nr2, nr3 };\n\n    int64_t cne[]  = { ne0, ne1, ne2, ne3 };\n    int64_t cne0[] = { ne00, ne01, ne02, ne03 };\n    int64_t cne1[] = { ne10, ne11, ne12, ne13 };\n\n    size_t cnb[]  = { nb0, nb1, nb2, nb3 };\n    size_t cnb0[] = { nb00, nb01, nb02, nb03 };\n    size_t cnb1[] = { nb10, nb11, nb12, nb13 };\n\n    auto collapse = [](int64_t cne[]) {\n        cne[0] *= cne[1];\n        cne[1] = cne[2];\n        cne[2] = cne[3];\n        cne[3] = 1;\n    };\n\n    auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {\n        cnb[1] *= cne[1];\n        cnb[2] *= cne[2];\n        cnb[3] *= cne[3];\n    };\n\n    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {\n        for (int i = 0; i < 4; i++) {\n            if (nr[i] != 1) {\n                break;\n            }\n            if (i > 0) {\n                collapse_nb(cnb, cne);\n                collapse_nb(cnb0, cne0);\n                collapse_nb(cnb1, cne1);\n                collapse(cne);\n                collapse(cne0);\n                collapse(cne1);\n            }\n        }\n    }\n\n    {\n        int64_t ne0 = cne[0];\n        int64_t ne1 = cne[1];\n        int64_t ne2 = cne[2];\n        int64_t ne3 = cne[3];\n\n        //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);\n        //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);\n        //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);\n        //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);\n\n        size_t nb0 = cnb[0];\n        size_t nb1 = 
cnb[1];\n        size_t nb2 = cnb[2];\n        size_t nb3 = cnb[3];\n\n        size_t nb00 = cnb0[0];\n        size_t nb01 = cnb0[1];\n        size_t nb02 = cnb0[2];\n        size_t nb03 = cnb0[3];\n\n        size_t nb10 = cnb1[0];\n        size_t nb11 = cnb1[1];\n        size_t nb12 = cnb1[2];\n        size_t nb13 = cnb1[3];\n\n      //size_t s0 = nb0 / sizeof(dst_t);\n        size_t s1 = nb1 / sizeof(dst_t);\n        size_t s2 = nb2 / sizeof(dst_t);\n        size_t s3 = nb3 / sizeof(dst_t);\n\n        size_t s10 = nb10 / sizeof(src1_t);\n        size_t s11 = nb11 / sizeof(src1_t);\n        size_t s12 = nb12 / sizeof(src1_t);\n        size_t s13 = nb13 / sizeof(src1_t);\n\n        size_t s00 = nb00 / sizeof(src0_t);\n        size_t s01 = nb01 / sizeof(src0_t);\n        size_t s02 = nb02 / sizeof(src0_t);\n        size_t s03 = nb03 / sizeof(src0_t);\n\n        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);\n        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);\n        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);\n        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);\n\n        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);\n        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);\n        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);\n        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);\n\n        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);\n        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);\n        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);\n        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);\n\n        const int block_size = 128;\n\n        int64_t hne0 = std::max(ne0 / 2LL, 1LL);\n\n        dim3 block_dims;\n        block_dims.x = std::min<unsigned int>(hne0, block_size);\n        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);\n        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);\n\n        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,\n                        (ne2 * ne3 + 
block_dims.z - 1) / block_dims.z);\n\n        const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);\n        const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);\n        const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);\n        const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);\n\n        if (block_nums.z > 65535 || block_nums.y > 65535) {\n            int         block_num  = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;\n            const uint3 prod_012    = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));\n            const uint3 prod_01     = init_fastdiv_values((uint32_t) (ne0 * ne1));\n            const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);\n            const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);\n            const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);\n\n            if constexpr (sizeof...(I) > 0) {\n                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(\n                    src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,\n                    ne12, ne13,\n                  /*s0,*/ s1,  s2,  s3,\n                    s00, s01, s02, s03,\n                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);\n            } else {\n                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>\n                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,\n                                                           ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,\n                                                         /*s0,*/ s1,  s2,  s3,\n                                                           s00, s01, s02, s03,\n                                                           s10, s11, s12, s13);\n            }\n        } else {\n            const uint3 ne3_fastdiv = 
init_fastdiv_values((uint32_t) ne3);\n            if constexpr (sizeof...(I) > 0) {\n                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(\n                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,\n                  /*s0,*/ s1, s2,  s3,\n                    s00 ,s01, s02, s03,\n                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);\n            } else {\n                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(\n                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,\n                  /*s0,*/ s1,  s2,  s3,\n                    s00, s01, s02, s03,\n                    s10, s11, s12, s13);\n            }\n        }\n    }\n}\n\ntemplate <typename T>\nstatic __global__ void k_repeat_back(\n    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n    const size_t s00, const size_t s01, const size_t s02, const size_t s03,\n    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {\n\n    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;\n    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;\n    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;\n    const int64_t tid2  = tid23 % ne2;\n    const int64_t tid3  = tid23 / ne2;\n\n    if (tid0 >= ne0) {\n        return;\n    }\n\n    T sum = 0;\n    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {\n        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {\n            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {\n                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {\n                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];\n                }\n            }\n        }\n    }\n    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;\n}\n\ntemplate 
<float (*bin_op)(const float, const float), int n_fuse = 1>\nstruct bin_bcast_cuda {\n    template<typename src0_t, typename src1_t, typename dst_t>\n    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,\n            const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,\n            cudaStream_t stream) {\n        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(\n            src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});\n    }\n};\n\ntemplate <typename T>\nstatic void repeat_back_cuda(\n    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n    const size_t s00, const size_t s01, const size_t s02, const size_t s03,\n    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {\n\n    const dim3 block_dims(WARP_SIZE, 1, 1);\n    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);\n    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>\n        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);\n}\n\ntemplate<class op>\nstatic void ggml_cuda_op_bin_bcast(\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,\n    const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);\n\n    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n        op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n        op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n        op()(src0, src1, dst, 
(const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);\n    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {\n        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);\n    } else {\n        fprintf(stderr, \"%s: unsupported types: dst: %s, src0: %s, src1: %s\\n\", __func__,\n            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\nvoid ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());\n}\n\nvoid ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());\n}\n\nvoid ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());\n}\n\nvoid ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());\n}\n\nvoid ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());\n}\n\ntemplate <float (*op)(const float, const float), int n_fuse>\nstatic void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    cudaStream_t stream = ctx.stream();\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    if (src0->type == GGML_TYPE_F32 && dst->type == 
GGML_TYPE_F32) {\n        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,\n            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,\n            stream, std::make_index_sequence<n_fuse>{});\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,\n            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,\n            stream, std::make_index_sequence<n_fuse>{});\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,\n            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,\n            stream, std::make_index_sequence<n_fuse>{});\n    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {\n        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,\n            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,\n            stream, std::make_index_sequence<n_fuse>{});\n    } else {\n        fprintf(stderr,\n                \"%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\\n\",\n                __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\n\nvoid ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {\n    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);\n\n    switch (n_fuse) {\n        case 2:\n            ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);\n            break;\n        case 3:\n            ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);\n            break;\n        case 4:\n            ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);\n            break;\n        case 5:\n            
ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);\n            break;\n        case 6:\n            ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);\n            break;\n        case 7:\n            ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);\n            break;\n        case 8:\n            ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);\n            break;\n        default:\n            GGML_ASSERT(false && \"Unsupported n_fuse value\");\n    }\n}\n\nvoid ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(src0->type == dst->type);\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_can_repeat(dst, src0));\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_TENSOR_UNARY_OP_LOCALS;\n\n    GGML_ASSERT(ne2*ne3 <= (1 << 15));\n\n    const size_t ts = ggml_type_size(src0->type);\n    const size_t s00 = nb00 / ts;\n    const size_t s01 = nb01 / ts;\n    const size_t s02 = nb02 / ts;\n    const size_t s03 = nb03 / ts;\n\n    switch (dst->type) {\n        case GGML_TYPE_F32: {\n            const float * src0_d = (const float *) src0->data;\n            float       * dst_d  = (float       *) dst->data;\n            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);\n        } break;\n        default: {\n            GGML_ASSERT(false);\n        } break;\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/binbcast.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\nvoid ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\nvoid ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\nvoid ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\nvoid ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);\n"
  },
  {
    "path": "src/ggml-cuda/clamp.cu",
    "content": "#include \"clamp.cuh\"\n\nstatic __device__ __forceinline__ float op_clamp(float x, float min, float max) {\n    return fminf(fmaxf(x, min), max);\n}\n\ntemplate <class T>\nstatic __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {\n    const int i = blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);\n}\n\ntemplate <class T>\nstatic void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {\n    const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;\n    op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);\n}\n\n\nvoid ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const void * src0_d = src0->data;\n    void * dst_d = dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(src0->type == dst->type);\n\n    float min;\n    float max;\n    memcpy(&min, dst->op_params, sizeof(float));\n    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));\n\n    if (src0->type == GGML_TYPE_F16) {\n        clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream);\n    } else {\n        clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/clamp.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_CLAMP_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/common.cuh",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cuda.h\"\n\n#include <cstdint>\n#include <memory>\n\n#if defined(GGML_USE_HIP)\n#define GGML_COMMON_DECL_HIP\n#define GGML_COMMON_IMPL_HIP\n#else\n#define GGML_COMMON_DECL_CUDA\n#define GGML_COMMON_IMPL_CUDA\n#if defined(GGML_USE_MUSA)\n#define GGML_COMMON_DECL_MUSA\n#define GGML_COMMON_IMPL_MUSA\n#endif\n#endif\n#include \"ggml-common.h\"\n\n#include <array>\n#include <algorithm>\n#include <cassert>\n#include <cfloat>\n#include <cstdio>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\n#if defined(GGML_USE_HIP)\n#include \"vendors/hip.h\"\n#elif defined(GGML_USE_MUSA)\n#include \"vendors/musa.h\"\n#else\n#include \"vendors/cuda.h\"\n#endif // defined(GGML_USE_HIP)\n\n#define STRINGIZE_IMPL(...) #__VA_ARGS__\n#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)\n\n#define WARP_SIZE 32\n#define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)\n#define CUDART_HMASK  12000 // CUDA 12.0, min. ver. 
for half2 -> uint mask comparisons\n\n#define GGML_CUDA_CC_PASCAL          600\n#define GGML_CUDA_CC_DP4A            610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products\n#define GGML_CUDA_CC_VOLTA           700\n#define GGML_CUDA_CC_TURING          750\n#define GGML_CUDA_CC_AMPERE          800\n#define GGML_CUDA_CC_ADA_LOVELACE    890\n// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see\n// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms\n#define GGML_CUDA_CC_BLACKWELL       1200\n#define GGML_CUDA_CC_DGX_SPARK       1210\n#define GGML_CUDA_CC_RUBIN           1300\n#define GGML_CUDA_CC_OFFSET_AMD      0x1000000\n#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000\n#define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)\n\n// AMD\n// GCN/CDNA, wave size is 64\n#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16\n#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue\n#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a\n#define GGML_CUDA_CC_CDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers\n#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renameing\n#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300\n\n// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32\n#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000\n#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a\n#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA\n#define GGML_CUDA_CC_RDNA3_5    (GGML_CUDA_CC_OFFSET_AMD + 0x1150) // AI 370, AI Max 395 
laptops.\n#define GGML_CUDA_CC_RDNA4      (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000\n\n#define GGML_CUDA_CC_IS_AMD(cc)     (cc >= GGML_CUDA_CC_OFFSET_AMD)\n#define GGML_CUDA_CC_IS_RDNA(cc)    (cc >= GGML_CUDA_CC_RDNA1)\n#define GGML_CUDA_CC_IS_RDNA1(cc)   (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)\n#define GGML_CUDA_CC_IS_RDNA2(cc)   (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)\n#define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA3_5)\n#define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4)\n#define GGML_CUDA_CC_IS_RDNA3(cc)   (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc))\n#define GGML_CUDA_CC_IS_RDNA4(cc)   (cc >= GGML_CUDA_CC_RDNA4)\n#define GGML_CUDA_CC_IS_GCN(cc)     (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)\n#define GGML_CUDA_CC_IS_CDNA(cc)    (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)\n#define GGML_CUDA_CC_IS_CDNA1(cc)   (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)\n#define GGML_CUDA_CC_IS_CDNA2(cc)   (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)\n#define GGML_CUDA_CC_IS_CDNA3(cc)   (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)\n\n// Moore Threads\n#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. 
for half2 -> uint mask comparisons\n\n#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000\n#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000\n#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000\n\n#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)\n#define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)\n#define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)\n#define GGML_CUDA_CC_IS_PH1(cc)      (cc >= GGML_CUDA_CC_PH1)\n\n#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070\n#    define GGML_CUDA_USE_CUB\n#endif  // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070\n\n#ifdef __CUDA_ARCH_LIST__\nconstexpr bool ggml_cuda_has_arch_impl(int) {\n    return false;\n}\n\ntemplate<class ... Archs>\nconstexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {\n    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);\n}\n\nconstexpr bool ggml_cuda_has_arch(const int arch) {\n    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);\n}\n\nconstexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {\n    if (cur == 0) {\n        return -1;\n    }\n    return cur;\n}\n\ntemplate<class ... Archs>\nconstexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... 
rest) {\n    if (first <= arch && first > cur) {\n        return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);\n    } else {\n        return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);\n    }\n}\n\nconstexpr int ggml_cuda_highest_compiled_arch(const int arch) {\n    return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);\n}\n#else\nstatic int ggml_cuda_highest_compiled_arch(const int arch) {\n    return arch;\n}\n#endif // __CUDA_ARCH_LIST__\n\n// ---------------------------------------------------------------------------------------------------------\n\n#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses\n\n#define GGML_CUDA_MAX_STREAMS 8\n\n[[noreturn]]\nvoid ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);\n\n#define CUDA_CHECK_GEN(err, success, error_fn)                                      \\\n     do {                                                                           \\\n        auto err_ = (err);                                                          \\\n        if (err_ != (success)) {                                                    \\\n            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_));    \\\n        }                                                                           \\\n    } while (0)\n\n#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)\n\n#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)\n    static const char * cublas_get_error_str(const cublasStatus_t err) {\n        return cublasGetStatusString(err);\n    }\n#else\n    static const char * cublas_get_error_str(const cublasStatus_t err) {\n        switch (err) {\n            case CUBLAS_STATUS_SUCCESS: return \"CUBLAS_STATUS_SUCCESS\";\n            case CUBLAS_STATUS_NOT_INITIALIZED: return \"CUBLAS_STATUS_NOT_INITIALIZED\";\n            case 
CUBLAS_STATUS_ALLOC_FAILED: return \"CUBLAS_STATUS_ALLOC_FAILED\";\n            case CUBLAS_STATUS_INVALID_VALUE: return \"CUBLAS_STATUS_INVALID_VALUE\";\n            case CUBLAS_STATUS_ARCH_MISMATCH: return \"CUBLAS_STATUS_ARCH_MISMATCH\";\n            case CUBLAS_STATUS_MAPPING_ERROR: return \"CUBLAS_STATUS_MAPPING_ERROR\";\n            case CUBLAS_STATUS_EXECUTION_FAILED: return \"CUBLAS_STATUS_EXECUTION_FAILED\";\n            case CUBLAS_STATUS_INTERNAL_ERROR: return \"CUBLAS_STATUS_INTERNAL_ERROR\";\n            case CUBLAS_STATUS_NOT_SUPPORTED: return \"CUBLAS_STATUS_NOT_SUPPORTED\";\n            default: return \"unknown error\";\n        }\n    }\n#endif // CUDART_VERSION >= 12000\n\n#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)\n\n#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)\nstatic const char * cu_get_error_str(CUresult err) {\n    const char * err_str;\n    cuGetErrorString(err, &err_str);\n    return err_str;\n}\n#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)\n#endif\n\n#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)\n#    define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes)                                                       \\\n        do {                                                                                                   \\\n            static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false };                         \\\n            const int   id                                                = ggml_cuda_get_device();            \\\n            if (!shared_memory_limit_raised[id]) {                                                             \\\n                CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \\\n                shared_memory_limit_raised[id] = true;                                                         \\\n            }                                              
                                                    \\\n        } while (0)\n#else\n#    define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \\\n        do {                                             \\\n            GGML_UNUSED(nbytes);                         \\\n        } while (0)\n#endif // !(defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)\n\n#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)\n#define GGML_CUDA_ASSUME(x) __builtin_assume(x)\n#else\n#define GGML_CUDA_ASSUME(x)\n#endif // CUDART_VERSION >= 11010\n\n#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))\n#define GGML_USE_VMM\n#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))\n\n#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL\n#define FP16_AVAILABLE\n#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL\n\n#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610\n#define FAST_FP16_AVAILABLE\n#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610\n\n#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)\n#define AMD_MFMA_AVAILABLE\n#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)\n\n#if defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))\n#define AMD_WMMA_AVAILABLE\n#endif // defined(GGML_USE_HIP) && defined(RDNA4)\n\n// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:\n#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n#define VOLTA_MMA_AVAILABLE\n#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n\n#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING\n#define TURING_MMA_AVAILABLE\n#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING\n\n#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= 
GGML_CUDA_CC_AMPERE\n#define AMPERE_MMA_AVAILABLE\n#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n\n#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN\n#    define BLACKWELL_MMA_AVAILABLE\n#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL\n\n#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n#define CP_ASYNC_AVAILABLE\n#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n\n#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)\n#define FLASH_ATTN_AVAILABLE\n#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)\n\n#if defined(TURING_MMA_AVAILABLE)\n#define LDMATRIX_TRANS_AVAILABLE\n#endif // defined(TURING_MMA_AVAILABLE)\n\nstatic bool fp16_available(const int cc) {\n    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||\n        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);\n}\n\nstatic bool fast_fp16_available(const int cc) {\n    return GGML_CUDA_CC_IS_AMD(cc) ||\n        (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||\n        (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));\n}\n\n// To be used for feature selection of external libraries, e.g. cuBLAS.\nstatic bool fast_fp16_hardware_available(const int cc) {\n    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||\n        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);\n}\n\n// To be used for feature selection of external libraries, e.g. 
cuBLAS.\nstatic bool fp16_mma_hardware_available(const int cc) {\n    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||\n        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||\n        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);\n}\n\nstatic bool bf16_mma_hardware_available(const int cc) {\n    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||\n        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||\n        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);\n}\n\nstatic bool fp32_mma_hardware_available(const int cc) {\n    return GGML_CUDA_CC_IS_CDNA(cc);\n}\n\nstatic bool amd_mfma_available(const int cc) {\n#if !defined(GGML_HIP_NO_MMQ_MFMA)\n    return GGML_CUDA_CC_IS_CDNA(cc);\n#else\n    return false;\n#endif //!defined(GGML_HIP_NO_MMQ_MFMA)\n}\n\nstatic bool amd_wmma_available(const int cc) {\n    return (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc));\n}\n\nstatic bool volta_mma_available(const int cc) {\n    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;\n}\n\nstatic bool turing_mma_available(const int cc) {\n    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;\n}\n\nstatic bool ampere_mma_available(const int cc) {\n    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;\n}\n\nstatic bool cp_async_available(const int cc) {\n    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;\n}\n\nstatic bool blackwell_mma_available(const int cc) {\n    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&\n           ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;\n}\n\nstatic constexpr __device__ int ggml_cuda_get_physical_warp_size() {\n#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))\n    return 
64;\n#else\n    return 32;\n#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))\n}\n\n// Maximum number of bytes that can be copied in a single instruction.\nstatic constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {\n#ifdef GGML_USE_HIP\n    return 16;\n#else\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n    return 16;\n#else\n    return 8;\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n#endif // GGML_USE_HIP\n}\n\n\n[[noreturn]]\nstatic __device__ void no_device_code(\n    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {\n\n#if defined(GGML_USE_HIP)\n    printf(\"%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\\n\",\n           file_name, line, function_name, arch);\n    GGML_UNUSED(arch_list);\n#else\n    printf(\"%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\\n\",\n           file_name, line, function_name, arch, arch_list);\n#endif // defined(GGML_USE_HIP)\n    __trap();\n\n    GGML_UNUSED(no_device_code); // suppress unused function warning\n\n#if defined(GGML_USE_MUSA)\n    __builtin_unreachable();\n#endif // defined(GGML_USE_MUSA)\n}\n\n#ifdef __CUDA_ARCH__\n#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))\n#else\n#define NO_DEVICE_CODE //GGML_ABORT(\"NO_DEVICE_CODE not valid in host code.\")\n#endif // __CUDA_ARCH__\n\n// The compiler is not always able to unroll loops that contain continue expressions.\n// In such cases loop unrolling can still be achieved via recursion:\ntemplate <int n>\nstruct ggml_cuda_unroll {\n    template <typename Func, typename... Args>\n    __device__ void operator()(const Func & f, Args... 
args) const {\n        f(n - 1, args...);\n        ggml_cuda_unroll<n - 1>{}(f, args...);\n    }\n};\n\ntemplate <>\nstruct ggml_cuda_unroll<1> {\n    template <typename Func, typename... 
Args>\n    __device__ void operator()(const Func & f, Args... args) const {\n        f(0, args...);\n    }\n};\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ int warp_reduce_sum(int x) {\n#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n    return __reduce_add_sync(0xffffffff, x);\n#else\n#pragma unroll\n    for (int offset = width/2; offset > 0; offset >>= 1) {\n        x += __shfl_xor_sync(0xffffffff, x, offset, width);\n    }\n    return x;\n#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ float warp_reduce_sum(float x) {\n#pragma unroll\n    for (int offset = width/2; offset > 0; offset >>= 1) {\n        x += __shfl_xor_sync(0xffffffff, x, offset, width);\n    }\n    return x;\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {\n#pragma unroll\n    for (int offset = width/2; offset > 0; offset >>= 1) {\n        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);\n        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);\n    }\n    return a;\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {\n#ifdef FP16_AVAILABLE\n#pragma unroll\n    for (int offset = width/2; offset > 0; offset >>= 1) {\n        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));\n    }\n    return a;\n\n#else\n    NO_DEVICE_CODE;\n    return a;\n#endif // FP16_AVAILABLE\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ int warp_reduce_all(int x) {\n    if (width == ggml_cuda_get_physical_warp_size()) {\n        return __all_sync(0xffffffff, x);\n    } else {\n#pragma unroll\n        for (int offset = width/2; offset > 0; offset >>= 1) {\n            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;\n        }\n        return x;\n    }\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ 
__forceinline__ int warp_reduce_any(int x) {\n    if (width == ggml_cuda_get_physical_warp_size()) {\n        return __any_sync(0xffffffff, x);\n    } else {\n#pragma unroll\n        for (int offset = width/2; offset > 0; offset >>= 1) {\n            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;\n        }\n        return x;\n    }\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ float warp_reduce_max(float x) {\n#pragma unroll\n    for (int offset = width/2; offset > 0; offset >>= 1) {\n        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));\n    }\n    return x;\n}\n\ntemplate<typename T, int width = WARP_SIZE>\nstatic __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {\n    const int lane_id = threadIdx.x % width;\n#pragma unroll\n    for (int offset = 1; offset < width; offset <<= 1) {\n        const T t = __shfl_up_sync(0xffffffff, x, offset, width);\n        if (lane_id >= offset) {\n            x += t;\n        }\n    }\n    return x;\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {\n    const int lane_id = threadIdx.x % width;\n#pragma unroll\n    for (int offset = 1; offset < width; offset <<= 1) {\n        const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);\n        const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);\n        if (lane_id >= offset) {\n            a.x += t_x;\n            a.y += t_y;\n        }\n    }\n    return a;\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {\n#ifdef FP16_AVAILABLE\n    const int lane_id = threadIdx.x % width;\n#pragma unroll\n    for (int offset = 1; offset < width; offset <<= 1) {\n        const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);\n        if (lane_id >= offset) {\n            a = __hadd2(a, t);\n        }\n    }\n    return a;\n\n#else\n    NO_DEVICE_CODE;\n    return a;\n#endif 
// FP16_AVAILABLE\n}\n\nenum class block_reduce_method {\n    MAX,\n    SUM,\n};\n\ntemplate<block_reduce_method method_t, typename T>\nstruct block_reduce_policy;\n\ntemplate <typename T, typename... Ts>\ninline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);\n\ntemplate<typename...>\ninline constexpr bool ggml_cuda_dependent_false_v = false;\n\ntemplate <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> {\n    static __device__ T reduce(T val) {\n        if constexpr(is_any<T, float, float2, half2, int>) {\n            return warp_reduce_sum(val);\n        } else {\n            static_assert(ggml_cuda_dependent_false_v<T>, \"Unsupported type for block reduce sum\");\n        }\n    }\n\n    static __device__ T sentinel() {\n        if constexpr (std::is_same_v<T, float>) {\n            return 0.0f;\n        } else if constexpr (std::is_same_v<T, float2>) {\n            return make_float2(0.0f, 0.0f);\n        } else if constexpr (std::is_same_v<T, half2>) {\n            return make_half2(0.0f, 0.0f);\n        } else if constexpr (std::is_same_v<T, int>) {\n            return 0;\n        } else {\n            static_assert(ggml_cuda_dependent_false_v<T>, \"Unsupported type for block reduce sum\");\n        }\n    }\n};\n\ntemplate <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> {\n    static __device__ T reduce(T val) {\n        if constexpr (is_any<T, float, half2>) {\n            return warp_reduce_max(val);\n        } else {\n            static_assert(ggml_cuda_dependent_false_v<T>, \"Unsupported type for block reduce max\");\n        }\n    }\n\n    static __device__ T sentinel() {\n        if constexpr (std::is_same_v<T, float>) {\n            return -INFINITY;\n        } else if constexpr (std::is_same_v<T, half2>) {\n            return make_half2(-INFINITY, -INFINITY);\n        } else {\n            static_assert(ggml_cuda_dependent_false_v<T>, \"Unsupported type for block reduce max\");\n        }\n    
}\n};\n\ntemplate <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T>\nstatic __device__ T block_reduce(T val, T * shared_vals) {\n    val                           = block_reduce_policy<reduce_method_t, T>::reduce(val);\n    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;\n    if (block_size > WARP_SIZE) {\n        assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        if (lane_id == 0) {\n            shared_vals[warp_id] = val;\n        }\n        __syncthreads();\n        val = block_reduce_policy<reduce_method_t, T>::sentinel();\n        if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {\n            val = shared_vals[lane_id];\n        }\n        return block_reduce_policy<reduce_method_t, T>::reduce(val);\n    }\n\n    return val;\n}\n\nstatic __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {\n#ifdef FP16_AVAILABLE\n\n#if !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX\n    return __float2half(fmaxf(__half2float(a), __half2float(b)));\n#else\n    return __hmax(a, b);\n#endif // !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX\n\n#else\n   NO_DEVICE_CODE;\n   GGML_UNUSED(b);\n   return a;\n#endif // FP16_AVAILABLE\n}\n\nstatic __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {\n#if defined(GGML_USE_HIP)\n    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));\n#elif CUDART_VERSION >= CUDART_HMAX\n    return __hmax2(a, b);\n#else\n    half2 ret;\n    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));\n    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));\n    return ret;\n#endif\n}\n\ntemplate<int width = WARP_SIZE>\nstatic __device__ __forceinline__ half2 warp_reduce_max(half2 x) {\n#if 
!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)\n#pragma unroll\n   for (int offset = width/2; offset > 0; offset >>= 1) {\n       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));\n   }\n   return x;\n#else\n   GGML_UNUSED(x);\n   NO_DEVICE_CODE;\n#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)\n}\n\n#if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \\\n    (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)\nstatic __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {\n    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));\n    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));\n    return mask_low | mask_high;\n}\n#endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)\n\nstatic __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {\n#if defined(GGML_USE_HIP)\n#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)\n    c = __builtin_amdgcn_sdot4(a, b, c, false);\n#elif defined(RDNA3) || defined(RDNA4)\n    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);\n#elif defined(RDNA1) || defined(__gfx900__)\n    int tmp1;\n    int tmp2;\n    asm(\"\\n \\\n        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \\n \\\n        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \\n \\\n        v_add3_u32 %0, %1, %2, %0 \\n \\\n        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \\n \\\n        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \\n \\\n        
v_add3_u32 %0, %1, %2, %0 \\n \\\n        \"\n        : \"+v\"(c), \"=&v\"(tmp1), \"=&v\"(tmp2)\n        : \"v\"(a), \"v\"(b)\n    );\n#else\n    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);\n    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);\n    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];\n#endif\n    return c;\n\n#else // defined(GGML_USE_HIP)\n\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)\n    return __dp4a(a, b, c);\n#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)\n    const int8_t * a8 = (const int8_t *) &a;\n    const int8_t * b8 = (const int8_t *) &b;\n    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)\n\n#endif // defined(GGML_USE_HIP)\n}\n\nstatic __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {\n    acc += v*u;\n}\n\nstatic __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {\n    acc += v.x*u.x;\n    acc += v.y*u.y;\n}\n\n#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))\n#define V_DOT2_F32_F16_AVAILABLE\n#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))\n\nstatic __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {\n#ifdef V_DOT2_F32_F16_AVAILABLE\n    asm volatile(\"v_dot2_f32_f16 %0, %1, %2, %0\" : \"+v\"(acc) : \"v\"(v), \"v\"(u));\n#else\n#ifdef FAST_FP16_AVAILABLE\n    const float2 tmp = __half22float2(v*u);\n    acc += tmp.x + tmp.y;\n#else\n    const float2 tmpv = __half22float2(v);\n    const float2 tmpu = __half22float2(u);\n    acc += tmpv.x * tmpu.x;\n    acc += tmpv.y * tmpu.y;\n#endif // FAST_FP16_AVAILABLE\n#endif // V_DOT2_F32_F16_AVAILABLE\n}\n\nstatic __device__ __forceinline__ void 
ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {\n#ifdef FAST_FP16_AVAILABLE\n    acc += v*u;\n#else\n    const float2 tmpv = __half22float2(v);\n    const float2 tmpu = __half22float2(u);\n    float2 tmpacc = __half22float2(acc);\n    tmpacc.x += tmpv.x * tmpu.x;\n    tmpacc.y += tmpv.y * tmpu.y;\n    acc = make_half2(tmpacc.x, tmpacc.y);\n#endif // FAST_FP16_AVAILABLE\n}\n\n// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.\n// Important: do not use this function if dst and src both point at registers.\n//     Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.\n//     The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.\n//     If dst and src point at different address spaces then they are guaranteed to not be aliased.\ntemplate <int nbytes, int alignment = 0>\nstatic __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {\n    static_assert(\n        nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0,\n        \"You are misusing the alignment parameter for ggml_cuda_memcpy_1. \"\n        \"The intent is for the parameter is only as a workaround if either one of the pointers is not properly aligned. \"\n        \"If you use it to do more bytes per copy than ggml_cuda_max_cpy_bytes() the reads and writes may not be coalesced. \"\n        \"Call ggml_cuda_memcpy_1 in a loop instead.\");\n    if constexpr (alignment != 0) {\n        static_assert(nbytes % alignment == 0, \"bad alignment\");\n    }\n    constexpr int nb_per_cpy = alignment == 0 ? 
nbytes : alignment;\n\n#pragma unroll\n    for (int i = 0; i < nbytes/nb_per_cpy; ++i) {\n        if constexpr (nb_per_cpy == 1) {\n            ((char *) dst)[i] = ((const char *) src)[i];\n        } else if constexpr (nb_per_cpy == 2) {\n            ((short *) dst)[i] = ((const short *) src)[i];\n        } else if constexpr (nb_per_cpy == 4) {\n            ((int *) dst)[i] = ((const int *) src)[i];\n        } else if constexpr (nb_per_cpy == 8) {\n            ((int2 *) dst)[i] = ((const int2 *) src)[i];\n        } else if constexpr (nb_per_cpy == 16) {\n            ((int4 *) dst)[i] = ((const int4 *) src)[i];\n        } else {\n            static_assert(nbytes == 0 && nbytes == -1, \"bad nbytes\");\n        }\n    }\n}\n\nstatic __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {\n#if CUDART_VERSION >= 12080\n    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);\n    return (float) e;\n#else\n    uint32_t bits;\n    if (x == 0) {\n        bits = 0x00400000;\n    } else {\n        bits = (uint32_t) x << 23;\n    }\n\n    float result;\n    memcpy(&result, &bits, sizeof(float));\n    return result;\n#endif // CUDART_VERSION >= 12080\n}\n\n__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {\n    const uint8_t sign_bit = (x < 0.0f) << 3;\n    float         ax       = fabsf(x) * e;\n\n    // Positive LUT\n    static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };\n\n    int   best_i   = 0;\n    float best_err = fabsf(ax - pos_lut[0]);\n\n#pragma unroll\n    for (int i = 1; i < 8; ++i) {\n        const float err = fabsf(ax - pos_lut[i]);\n        if (err < best_err) {\n            best_err = err;\n            best_i   = i;\n        }\n    }\n\n    return static_cast<uint8_t>(best_i | sign_bit);\n}\n\n// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.\n// Precompute mp (m' in the paper) and L such that division\n// can be computed using a multiply (high 32b of 64b result)\n// 
and a shift:\n//\n// n/d = (mulhi(n, mp) + n) >> L;\nstatic const uint3 init_fastdiv_values(uint64_t d_64) {\n    GGML_ASSERT(d_64 != 0);\n    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());\n\n    uint32_t d = (uint32_t)d_64;\n\n    // compute L = ceil(log2(d));\n    uint32_t L = 0;\n    while (L < 32 && (uint32_t{ 1 } << L) < d) {\n        L++;\n    }\n\n    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);\n    // pack divisor as well to reduce error surface\n    return make_uint3(mp, L, d);\n}\n\nstatic __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {\n    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>\n    // fastdiv_values.z is unused and optimized away by the compiler.\n    // Compute high 32 bits of n * mp\n    const uint32_t hi = __umulhi(n, fastdiv_values.x);\n    // add n, apply bit shift\n    return (hi + n) >> fastdiv_values.y;\n}\n\nstatic __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {\n    // expects  fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)\n    return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;\n}\n\n// Calculate both division and modulo at once, returns <n/divisor, n%divisor>\nstatic __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {\n    // expects  fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)\n    const uint32_t div_val = fastdiv(n, fastdiv_values);\n    const uint32_t mod_val = n - div_val * fastdiv_values.z;\n    return make_uint2(div_val, mod_val);\n}\n\ntypedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);\n\nstatic __device__ __forceinline__ float get_alibi_slope(\n    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1\n) {\n    if (max_bias <= 0.0f) {\n        return 1.0f;\n    }\n 
   const float base = h < n_head_log2 ? m0 : m1;\n    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n    return powf(base, exph);\n}\n\ntemplate <ggml_type type>\nstruct ggml_cuda_type_traits;\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_F16> {\n    static constexpr int qk = 1;\n    static constexpr int qr = 1;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {\n    static constexpr int qk = QK4_0;\n    static constexpr int qr = QR4_0;\n    static constexpr int qi = QI4_0;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {\n    static constexpr int qk = QK4_1;\n    static constexpr int qr = QR4_1;\n    static constexpr int qi = QI4_1;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {\n    static constexpr int qk = QK5_0;\n    static constexpr int qr = QR5_0;\n    static constexpr int qi = QI5_0;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {\n    static constexpr int qk = QK5_1;\n    static constexpr int qr = QR5_1;\n    static constexpr int qi = QI5_1;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {\n    static constexpr int qk = QK8_0;\n    static constexpr int qr = QR8_0;\n    static constexpr int qi = QI8_0;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {\n    static constexpr int qk = QK_MXFP4;\n    static constexpr int qr = QR_MXFP4;\n    static constexpr int qi = QI_MXFP4;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR2_K;\n    static constexpr int qi = QI2_K;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR3_K;\n    static constexpr int qi = QI3_K;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR4_K;\n    static constexpr int qi = 
QI4_K;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR5_K;\n    static constexpr int qi = QI5_K;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR6_K;\n    static constexpr int qi = QI6_K;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR2_XXS;\n    static constexpr int qi = QI2_XXS;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR2_XS;\n    static constexpr int qi = QI2_XS;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR2_S;\n    static constexpr int qi = QI2_S;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR3_XXS;\n    static constexpr int qi = QI3_XXS;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR1_S;\n    static constexpr int qi = QI1_S;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR1_M;\n    static constexpr int qi = QI1_M;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {\n    static constexpr int qk = QK4_NL;\n    static constexpr int qr = QR4_NL;\n    static constexpr int qi = QI4_NL;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR4_XS;\n    static constexpr int qi = QI4_XS;\n};\n\ntemplate<>\nstruct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {\n    static constexpr int qk = QK_K;\n    static constexpr int qr = QR3_S;\n    static constexpr int qi 
= QI3_S;\n};\n\n//////////////////////\n\nstruct ggml_cuda_device_info {\n    int device_count;\n\n    struct cuda_device_info {\n        int     cc;                             // compute capability\n        int     nsm;                            // number of streaming multiprocessors\n        size_t  smpb;                           // max. shared memory per block\n        size_t  smpbo;                          // max. shared memory per block (with opt-in)\n        bool    integrated;                     // Device is integrated as opposed to discrete\n        bool    vmm;                            // virtual memory support\n        size_t  vmm_granularity;                // granularity of virtual memory\n        size_t  total_vram;\n        int     warp_size;                      // Number of threads in a dispatch\n        bool    supports_cooperative_launch;    // whether cooperative launch is supported\n    };\n\n    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};\n\n    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};\n};\n\nconst ggml_cuda_device_info & ggml_cuda_info();\n\nvoid ggml_cuda_set_device(int device);\nint ggml_cuda_get_device();\n\nstruct ggml_cuda_pool {\n    virtual ~ggml_cuda_pool() = default;\n\n    virtual void * alloc(size_t size, size_t * actual_size) = 0;\n    virtual void free(void * ptr, size_t size) = 0;\n};\n\ntemplate<typename T>\nstruct ggml_cuda_pool_alloc {\n    ggml_cuda_pool * pool = nullptr;\n    T * ptr = nullptr;\n    size_t actual_size = 0;\n\n    ggml_cuda_pool_alloc() = default;\n\n    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {\n    }\n\n    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {\n        alloc(size);\n    }\n\n    ~ggml_cuda_pool_alloc() {\n        if (ptr != nullptr) {\n            pool->free(ptr, actual_size);\n        }\n    }\n\n    // size is in number of elements\n    T * alloc(size_t size) {\n        GGML_ASSERT(pool != 
nullptr);\n        GGML_ASSERT(ptr == nullptr);\n        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);\n        return ptr;\n    }\n\n    T * alloc(ggml_cuda_pool & pool, size_t size) {\n        this->pool = &pool;\n        return alloc(size);\n    }\n\n    T * get() {\n        return ptr;\n    }\n\n    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;\n    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;\n    ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;\n    ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;\n};\n\n\n// backend interface\n\nstruct ggml_tensor_extra_gpu {\n    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors\n    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs\n};\n\n\n#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)\n#define USE_CUDA_GRAPH\n#endif\n\nstruct ggml_cuda_graph_node_properties {\n    void * node_data;\n    ggml_op node_op;\n    enum ggml_type node_type;\n    int32_t flags;\n    int64_t ne[GGML_MAX_DIMS];\n    size_t nb[GGML_MAX_DIMS];\n    void * src_data[GGML_MAX_SRC];\n    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];\n};\n\nstatic_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, \"ggml_cuda_graph_node_properties must be trivial\");\n\nstruct ggml_cuda_graph {\n#ifdef USE_CUDA_GRAPH\n    ~ggml_cuda_graph() {\n        if (instance != nullptr) {\n            CUDA_CHECK(cudaGraphExecDestroy(instance));\n        }\n        if (graph != nullptr) {\n            CUDA_CHECK(cudaGraphDestroy(graph));\n        }\n    }\n    cudaGraph_t graph = nullptr;\n    cudaGraphExec_t instance = nullptr;\n    size_t num_nodes = 0;\n    std::vector<cudaGraphNode_t> nodes;\n    bool disable_due_to_gpu_arch = false;\n    bool warmup_complete = false;\n    std::vector<ggml_cuda_graph_node_properties> props;\n\n    
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes\n    // they properties also have to match in order to be able to safely reuse a CUDA graph\n    // ref: https://github.com/ggml-org/llama.cpp/pull/18583\n    // ref: https://github.com/ggml-org/llama.cpp/pull/19165\n    std::vector<ggml_cuda_graph_node_properties> extra;\n\n    bool is_enabled() const {\n        static const bool disable_cuda_graphs_due_to_env = (getenv(\"GGML_CUDA_DISABLE_GRAPHS\") != nullptr);\n        return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env);\n    }\n#endif\n};\n\nstruct ggml_cuda_concurrent_event {\n    std::vector<cudaEvent_t> join_events;\n    cudaEvent_t              fork_event = nullptr;\n\n    int                                          n_streams = 0;\n    std::unordered_map<const ggml_tensor *, int> stream_mapping;\n\n    // Original order of nodes in this concurrent region (before interleaving)\n    // Used to restore grouping for fusion within streams\n    std::vector<const ggml_tensor *> original_order;\n\n    const ggml_tensor * join_node;\n\n    ggml_cuda_concurrent_event() = default;\n\n    ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;\n    ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;\n\n    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {\n        join_events.resize(n_streams);\n\n        for (size_t i = 0; i < join_events.size(); ++i) {\n            CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));\n        }\n\n        CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));\n    }\n\n    ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept\n    : join_events(std::move(other.join_events))\n    , fork_event(other.fork_event)\n    , n_streams(other.n_streams)\n    , stream_mapping(std::move(other.stream_mapping))\n    , 
original_order(std::move(other.original_order))\n    , join_node(other.join_node) {\n        other.fork_event = nullptr;\n    }\n\n    // 1. check if any branches write to overlapping memory ranges (except the join node)\n    // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event\n    // we assume all nodes have the same buffer\n    bool is_valid() const {\n        std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;\n        write_ranges.resize(n_streams);\n\n        // get join_node's memory range to exclude from overlap checking.\n        // multiple nodes can use join_node's buffer; we synchronize on the join node.\n        const ggml_tensor * join_t     = join_node->view_src ? join_node->view_src : join_node;\n        const int64_t       join_start = (int64_t) join_t->data;\n        const int64_t       join_end   = join_start + ggml_nbytes(join_t);\n\n        for (const auto & [tensor, stream] : stream_mapping) {\n            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;\n            const int64_t       t_start = (int64_t) t->data;\n            const int64_t       t_end   = t_start + ggml_nbytes(t);\n\n            // skip tensors that overlap with join_node's buffer.\n            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {\n                continue;\n            }\n\n            // concurrent streams begin from 1\n            write_ranges[stream - 1].emplace_back(t_start, t_end);\n        }\n\n        for (int i = 0; i < n_streams; ++i) {\n            // sorts first by start then by end of write range\n            std::sort(write_ranges[i].begin(), write_ranges[i].end());\n        }\n\n        bool writes_overlap = false;\n        bool dependent_srcs = false;\n        for (const auto & [tensor, stream] : stream_mapping) {\n            const ggml_tensor * t = tensor->view_src ? 
tensor->view_src : tensor;\n            const int64_t       t_start = (int64_t) t->data;\n            const int64_t       t_end   = t_start + ggml_nbytes(t);\n\n            // skip tensors that overlap with join_node's buffer\n            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {\n                continue;\n            }\n\n            // check if this buffer's write data overlaps with another stream's\n            std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);\n            for (int i = 0; i < n_streams; ++i) {\n                if (i == stream - 1) {\n                    continue;\n                }\n                auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);\n\n                if (it != write_ranges[i].end()) {\n                    const std::pair<int64_t, int64_t> & other = *it;\n\n                    // std::lower_bound returns the first element where other >= data_range (lexicographically).\n                    // This guarantees other.first >= data_range.first.\n                    // Therefore, overlap occurs iff other.first < data_range.second\n                    // (i.e., the other range starts before this range ends).\n                    if (other.first < data_range.second) {\n                        GGML_LOG_DEBUG(\"Writes overlap for %s\", tensor->name);\n                        writes_overlap = true;\n                        break;\n                    }\n                }\n            }\n\n            //check if all srcs are either in branch or don't have a branch\n            for (int i = 0; i < GGML_MAX_SRC; ++i) {\n                if (!tensor->src[i]) {\n                    continue;\n                }\n\n                auto it = stream_mapping.find(tensor->src[i]);\n\n                if (it == stream_mapping.end()) {\n                    continue;\n                }\n\n                if (it->second != stream) {\n     
               dependent_srcs = true;\n                    break;\n                }\n            }\n\n            if (dependent_srcs || writes_overlap) {\n                break;\n            }\n        }\n\n        return !writes_overlap && !dependent_srcs;\n    }\n\n    ~ggml_cuda_concurrent_event() {\n        if (fork_event != nullptr) {\n            CUDA_CHECK(cudaEventDestroy(fork_event));\n        }\n        for (cudaEvent_t e : join_events) {\n            if (e != nullptr) {\n                CUDA_CHECK(cudaEventDestroy(e));\n            }\n        }\n    }\n};\n\nstruct ggml_cuda_stream_context {\n    std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;\n\n    void reset() {\n        concurrent_events.clear();\n    }\n};\n\nstruct ggml_backend_cuda_context {\n    int device;\n    std::string name;\n    cudaEvent_t copy_event = nullptr;\n\n    cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };\n    cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};\n\n    int curr_stream_no = 0;\n\n#ifdef USE_CUDA_GRAPH\n    // Map from first_node_ptr to cuda_graph - allows multiple graphs per context\n    // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)\n    std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;\n\n    ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {\n        auto it = cuda_graphs.find(first_node_ptr);\n        if (it == cuda_graphs.end()) {\n            cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();\n            return cuda_graphs[first_node_ptr].get();\n        }\n        return it->second.get();\n    }\n\n    // Check if any CUDA graph is enabled for this context (used by kernels that need to know\n    // if graphs are in use without having access to the specific graph key)\n    bool any_cuda_graph_enabled() const {\n        for (const auto & [key, graph] : cuda_graphs) {\n            if (graph 
&& graph->is_enabled()) {\n                return true;\n            }\n        }\n        return false;\n    }\n\n    // Check if any CUDA graph has an instance for this context\n    bool any_cuda_graph_has_instance() const {\n        for (const auto & [key, graph] : cuda_graphs) {\n            if (graph && graph->instance != nullptr) {\n                return true;\n            }\n        }\n        return false;\n    }\n#endif // USE_CUDA_GRAPH\n\n    explicit ggml_backend_cuda_context(int device) :\n        device(device),\n        name(GGML_CUDA_NAME + std::to_string(device)) {\n    }\n\n    ggml_cuda_stream_context concurrent_stream_context;\n\n    ~ggml_backend_cuda_context();\n\n    cudaStream_t stream(int device, int stream) {\n        if (streams[device][stream] == nullptr) {\n            ggml_cuda_set_device(device);\n            CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));\n        }\n        return streams[device][stream];\n    }\n\n    cudaStream_t stream() { return stream(device, curr_stream_no); }\n\n    ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }\n\n    cublasHandle_t cublas_handle(int device) {\n        if (cublas_handles[device] == nullptr) {\n            ggml_cuda_set_device(device);\n            CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));\n            CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));\n        }\n        return cublas_handles[device];\n    }\n\n    cublasHandle_t cublas_handle() {\n        return cublas_handle(device);\n    }\n\n    // pool\n    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];\n\n    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);\n\n    ggml_cuda_pool & pool(int device) {\n        if (pools[device][curr_stream_no] == nullptr) {\n            pools[device][curr_stream_no] = new_pool_for_device(device, 
curr_stream_no);\n        }\n        return *pools[device][curr_stream_no];\n    }\n\n    ggml_cuda_pool & pool() {\n        return pool(device);\n    }\n};\n\nstruct ggml_cuda_mm_fusion_args_host {\n    const ggml_tensor * x_bias = nullptr;\n    const ggml_tensor * gate = nullptr;\n    const ggml_tensor * gate_bias = nullptr;\n    ggml_glu_op glu_op;\n};\nstruct ggml_cuda_mm_fusion_args_device {\n    const void * x_bias = nullptr;\n    const void * gate = nullptr;\n    const void * gate_bias = nullptr;\n    ggml_glu_op glu_op;\n};\n"
  },
  {
    "path": "src/ggml-cuda/concat.cu",
    "content": "#include \"concat.cuh\"\n\n// contiguous kernels\nstatic __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {\n    int nidx = threadIdx.x + blockIdx.x * blockDim.x;\n    if (nidx >= ne0) {\n        return;\n    }\n\n    int offset_dst =\n        nidx +\n        blockIdx.y * ne0 +\n        blockIdx.z * ne0 * gridDim.y;\n\n    if (nidx < ne00) { // src0\n        int offset_src =\n            nidx +\n            blockIdx.y * ne00 +\n            blockIdx.z * ne00 * gridDim.y;\n        dst[offset_dst] = x[offset_src];\n    } else {\n        int offset_src =\n            (nidx - ne00) +\n            blockIdx.y * (ne0 - ne00) +\n            blockIdx.z * (ne0 - ne00) * gridDim.y;\n        dst[offset_dst] = y[offset_src];\n    }\n}\n\nstatic __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {\n    int nidx = threadIdx.x + blockIdx.x * blockDim.x;\n    if (nidx >= ne0) {\n        return;\n    }\n\n    int offset_dst =\n        nidx +\n        blockIdx.y * ne0 +\n        blockIdx.z * ne0 * gridDim.y;\n\n    if (blockIdx.y < (unsigned)ne01) { // src0\n        int offset_src =\n            nidx +\n            blockIdx.y * ne0 +\n            blockIdx.z * ne0 * ne01;\n        dst[offset_dst] = x[offset_src];\n    } else {\n        int offset_src =\n            nidx +\n            (blockIdx.y - ne01) * ne0 +\n            blockIdx.z * ne0 * (gridDim.y - ne01);\n        dst[offset_dst] = y[offset_src];\n    }\n}\n\nstatic __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {\n    int nidx = threadIdx.x + blockIdx.x * blockDim.x;\n    if (nidx >= ne0) {\n        return;\n    }\n\n    int offset_dst =\n        nidx +\n        blockIdx.y * ne0 +\n        blockIdx.z * ne0 * gridDim.y;\n\n    if (blockIdx.z < (unsigned)ne02) { // src0\n        int offset_src =\n            nidx +\n            blockIdx.y * ne0 +\n            blockIdx.z * ne0 * gridDim.y;\n        dst[offset_dst] = x[offset_src];\n    } else {\n        int offset_src =\n            nidx +\n            blockIdx.y * ne0 +\n            (blockIdx.z - ne02) * ne0 *  gridDim.y;\n        dst[offset_dst] = y[offset_src];\n    }\n}\n\nstatic void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {\n    int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;\n    dim3 gridDim(num_blocks, ne1, ne2);\n    if (dim == 0) {\n        concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);\n        return;\n    }\n    if (dim == 1) {\n        concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);\n        return;\n    }\n    concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);\n}\n\n// non-contiguous kernel (slow)\ntemplate <int dim>\nstatic __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)\n    concat_f32_non_cont(\n        const char * src0,\n        const char * src1,\n              char * dst,\n           int64_t   ne00,\n           int64_t   ne01,\n           int64_t   ne02,\n           int64_t   ne03,\n          uint64_t   nb00,\n          uint64_t   nb01,\n          uint64_t   nb02,\n          uint64_t   nb03,\n           int64_t /*ne10*/,\n           int64_t /*ne11*/,\n           int64_t /*ne12*/,\n           int64_t /*ne13*/,\n          uint64_t   nb10,\n          uint64_t   nb11,\n          uint64_t   nb12,\n          uint64_t   nb13,\n           int64_t   ne0,\n           int64_t /*ne1*/,\n           int64_t /*ne2*/,\n           int64_t /*ne3*/,\n          uint64_t   nb0,\n          uint64_t   nb1,\n          uint64_t   nb2,\n          uint64_t   nb3){\n    static_assert(dim >= 0 && dim <= 3, \"dim must be in [0, 3]\");\n\n    const int64_t i3 = blockIdx.z;\n    const int64_t i2 = blockIdx.y;\n    const int64_t i1 = blockIdx.x;\n\n    const float * x;\n\n    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {\n        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {\n            x = (const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);\n        } else {\n            if constexpr (dim == 0) {\n                x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);\n            } else if constexpr (dim == 1) {\n                x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);\n            } else if constexpr (dim == 2) {\n                x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);\n            } else if constexpr (dim == 3) {\n                x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);\n            }\n        }\n\n        float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);\n\n        *y = *x;\n    }\n}\n\n\nvoid ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    cudaStream_t stream = ctx.stream();\n\n    const int32_t dim = ((int32_t *) dst->op_params)[0];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n\n    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {\n        const float * src0_d = (const float *)src0->data;\n        const float * src1_d = (const float *)src1->data;\n\n        float * dst_d = (float *)dst->data;\n\n        if (dim != 3) {\n            // nb[] strides are in bytes; divide by sizeof(float) to get float element offsets\n            for (int i3 = 0; i3 < dst->ne[3]; i3++) {\n                concat_f32_cuda(\n                        src0_d + i3 * (src0->nb[3] / sizeof(float)),\n                        src1_d + i3 * (src1->nb[3] / sizeof(float)),\n                        dst_d + i3 * ( dst->nb[3] / sizeof(float)),\n                        src0->ne[0], src0->ne[1], src0->ne[2],\n                        dst->ne[0],  dst->ne[1],  dst->ne[2], dim, stream);\n            }\n        } else {\n            const size_t size0 = ggml_nbytes(src0);\n            const size_t size1 = ggml_nbytes(src1);\n\n            // for dim == 3 the result is src0's bytes immediately followed by src1's bytes;\n            // size0 is in bytes, so convert it to a float element offset for the second copy\n            CUDA_CHECK(cudaMemcpyAsync(dst_d,                       src0_d, size0, cudaMemcpyDeviceToDevice, stream));\n            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/sizeof(float), src1_d, size1, cudaMemcpyDeviceToDevice, stream));\n        }\n    } else {\n        dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);\n        auto launch_kernel = [&](auto dim) {\n            concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(\n                (const char *) src0->data, (const char *) src1->data, (char *) dst->data,\n                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],\n                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],\n                src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],\n                dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);\n        };\n        switch (dim) {\n            case 0:\n                launch_kernel(std::integral_constant<int, 0>{});\n                break;\n            case 1:\n                launch_kernel(std::integral_constant<int, 1>{});\n                break;\n            case 2:\n                launch_kernel(std::integral_constant<int, 2>{});\n                break;\n            case 3:\n                launch_kernel(std::integral_constant<int, 3>{});\n                break;\n            default:\n                GGML_ABORT(\"Invalid dim: %d\", dim);\n                break;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/concat.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n#define CUDA_CONCAT_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/conv-transpose-1d.cu",
    "content": "#include \"conv-transpose-1d.cuh\"\n\nstatic  __global__ void conv_transpose_1d_kernel(\n        const int s0, const int p0, const int d0, const int output_size,\n        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,\n        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,\n        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,\n        const float * src0, const float * src1,  float * dst) {\n    int global_index = threadIdx.x + blockIdx.x * blockDim.x;\n    if (global_index >= output_size) {\n        return;\n    }\n\n    int out_index = global_index / dst_ne0;\n\n    float accumulator = 0;\n\n    for (int c = 0; c < src0_ne2; c++) {\n        int idx = global_index % dst_ne0;\n\n        int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);\n        int input_offset = src1_ne0 * c;\n\n        for (int i = 0; i < src1_ne0; i++) {\n            if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {\n                continue;\n            }\n            int weight_idx = idx - i*s0;\n\n            float kernel_weight = src0[kernel_offset + weight_idx];\n            float input_value =  src1[input_offset+i];\n\n            accumulator += kernel_weight * input_value;\n        }\n    }\n    dst[global_index] = accumulator;\n    GGML_UNUSED_VARS(p0, d0, src0_ne3, src1_ne3, dst_ne3, src1_ne1, dst_ne1, src1_ne2, dst_ne2);\n}\n\nstatic void conv_transpose_1d_f32_f32_cuda(\n        const int s0, const int p0, const int d0, const int output_size,\n        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,\n        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,\n        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,\n        const float * src0, const float * src1,  float * dst,\n        cudaStream_t stream) {\n\n    const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;\n    conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(\n        s0,p0,d0,output_size,\n        src0_ne0, src0_ne1,  src0_ne2, src0_ne3,\n        src1_ne0, src1_ne1,  src1_ne2, src1_ne3,\n        dst_ne0,  dst_ne1,   dst_ne2,  dst_ne3,\n        src0,src1, dst);\n}\n\nvoid ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n\n    const ggml_tensor * src1 = dst->src[1];\n    const float * src1_d = (const float *)src1->data;\n\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    // the kernel reads both src0 and src1 as float and writes float, so all three must be F32\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n\n    const int32_t * opts = (const int32_t *)dst->op_params;\n\n    const int s0 = opts[0];\n    const int p0 = 0;//opts[3];\n    const int d0 = 1;//opts[4];\n\n    const int64_t output_size = ggml_nelements(dst);\n\n    conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,\n        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],\n        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],\n        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n        src0_d, src1_d, dst_d, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/conv-transpose-1d.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/conv2d-dw.cu",
    "content": "#include \"conv2d-dw.cuh\"\n\nstruct conv_params {\n    int in_w, in_h;\n    int out_w, out_h;\n    int kernel_w, kernel_h;\n    int stride_x, stride_y;\n    int padding_x, padding_y;\n    int dilation_x, dilation_y;\n    int channels, batches;\n};\n\nstruct kernel_bounds {\n    int y_min, y_max;\n    int x_min, x_max;\n};\n\n// Half-open ranges [min, max) of kernel taps whose input coordinate lands inside the unpadded input,\n// so the accumulation loop can index the input without per-tap bounds checks.\nstatic __device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {\n    kernel_bounds bounds;\n    bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);\n    bounds.y_max =\n        min(params.kernel_h,\n            (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);\n    bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);\n    bounds.x_max =\n        min(params.kernel_w,\n            (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);\n    return bounds;\n}\n\nstatic __device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {\n    return out_coord * stride + kern_coord * dilation - padding;\n}\n\nstruct whcn_layout {\n    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {\n        return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;\n    }\n\n    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {\n        return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;\n    }\n\n    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {\n        return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +\n               y * params.out_w + x;\n    }\n\n    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,\n                                          int & out_x) {\n        out_x = global_idx % params.out_w;\n        out_y = (global_idx / params.out_w) % params.out_h;\n        c     = (global_idx / (params.out_w * params.out_h)) % params.channels;\n        n     = global_idx / (params.out_w * params.out_h * params.channels);\n    }\n};\n\nstruct cwhn_layout {\n    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {\n        return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;\n    }\n\n    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {\n        return (ky * params.kernel_w + kx) * params.channels + c;\n    }\n\n    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {\n        return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +\n               x * params.channels + c;\n    }\n\n    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,\n                                          int & out_x) {\n        c     = global_idx % params.channels;\n        out_x = (global_idx / params.channels) % params.out_w;\n        out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;\n        n     = global_idx / (params.channels * params.out_w * params.out_h);\n    }\n};\n\ntemplate <typename T, typename Layout>\nstatic __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,\n                                        const int in_w, const int in_h, const int out_w, const int out_h,\n                                        const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,\n                                        const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,\n                                        const int channels, const int batches) {\n    const int global_idx     = blockIdx.x * blockDim.x + threadIdx.x;\n    const int total_elements = batches * channels * out_h * out_w;\n\n    if (global_idx >= total_elements) {\n        return;\n    }\n\n    conv_params params = { in_w,     in_h,      out_w,     out_h,      kernel_w,   kernel_h, stride_x,\n                           stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };\n\n    int batch_idx, channel_idx, out_y_idx, out_x_idx;\n    Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);\n\n    T accumulator = 0;\n    kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);\n\n    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {\n        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);\n\n        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {\n            int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);\n\n            const T input_val  = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];\n            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];\n\n            accumulator += input_val * kernel_val;\n        }\n    }\n\n    output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;\n}\n\nvoid ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * kernel = dst->src[0];\n    const ggml_tensor * input  = dst->src[1];\n\n    GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);\n    const float * w_d = (const float *) kernel->data;\n    const float * x_d = (const float *) input->data;\n    float *       y_d = (float *) dst->data;\n\n    const int32_t * p          = (const int32_t *) dst->op_params;\n    const int       stride_x   = p[0];\n    const int       stride_y   = p[1];\n    const int       padding_x  = p[2];\n    const int       padding_y  = p[3];\n    const int       dilation_x = p[4];\n    const int       dilation_y = p[5];\n\n    const int in_w     = input->ne[0];\n    const int in_h     = input->ne[1];\n    const int kernel_w = kernel->ne[0];\n    const int kernel_h = kernel->ne[1];\n    const int out_w    = dst->ne[0];\n    const int out_h    = dst->ne[1];\n    const int channels = dst->ne[2];\n    const int batches  = dst->ne[3];\n\n    cudaStream_t st = ctx.stream();\n\n    const int total  = batches * channels * out_h * out_w;\n    const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;\n\n    if (ggml_is_contiguous(input)) {\n        conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(\n            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,\n            dilation_x, dilation_y, channels, batches);\n    } else if (ggml_is_contiguous_channels(input)) {\n        conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(\n            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,\n            dilation_x, dilation_y, channels, batches);\n    } else {\n        GGML_ABORT(\"Unsupported memory layout for conv_2d_dw\");\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/conv2d-dw.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n// Threads per block used when launching the depthwise conv2d kernel.\n#define CUDA_CONV2D_DW_BLOCK_SIZE 256\n\n// Depthwise 2D convolution: dst->src[0] is the kernel, dst->src[1] is the input.\nvoid ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/conv2d-transpose.cu",
    "content": "#include <algorithm>\n\n#include \"conv2d-transpose.cuh\"\n#include \"ggml.h\"\n\nstatic __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,\n                                               float * __restrict__ output, const int in_w, const int in_h, const int out_w,\n                                               const int out_h, const int kernel_w, const int kernel_h, const int stride,\n                                               const int c_in, const int c_out, const int batches) {\n    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    const int total_elements = out_w * out_h * c_out * batches;\n\n    if (global_idx >= total_elements) {\n        return;\n    }\n\n    const int out_x_idx = global_idx % out_w;\n    const int out_y_idx = (global_idx / out_w) % out_h;\n    const int c_idx     = (global_idx / (out_w * out_h)) % c_out;\n    const int n_idx     = global_idx / (out_w * out_h * c_out);\n\n    float accumulator = 0;\n    // For each output idx, find the inputs that contribute to it by checking stride alignment and bounds\n\n    for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {\n        for (int kh = 0; kh < kernel_h; ++kh) {\n            int in_y = out_y_idx - kh;\n            if (in_y < 0 || in_y % stride) continue;\n            in_y /= stride;\n            if (in_y >= in_h) continue;\n\n            for (int kw = 0; kw < kernel_w; ++kw) {\n                int in_x = out_x_idx - kw;\n                if (in_x < 0 || in_x % stride) continue;\n                in_x /= stride;\n                if (in_x >= in_w) continue;\n\n                const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;\n                const int kernel_idx =\n                    (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;\n\n                float input_val = input[input_idx];\n                half  kern_val  = kernel[kernel_idx];\n\n                accumulator += input_val * (float) kern_val;\n            }\n        }\n    }\n\n    output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;\n}\n\n//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)\nvoid ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * kernel = dst->src[0];\n    const ggml_tensor * input  = dst->src[1];\n\n    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);\n\n    const float * input_data  = (const float *) input->data;\n    float *       output_data = (float *) dst->data;\n    const half * kernel_data = (const half *) kernel->data;\n\n    const int input_w      = input->ne[0];\n    const int input_h      = input->ne[1];\n    const int output_w     = dst->ne[0];\n    const int output_h     = dst->ne[1];\n    const int channels_in  = input->ne[2];\n    const int channels_out = kernel->ne[2];\n    const int kernel_w     = kernel->ne[0];\n    const int kernel_h     = kernel->ne[1];\n    const int stride       = dst->op_params[0];\n    const int batches      = input->ne[3];\n\n    GGML_ASSERT(channels_in == kernel->ne[3]);\n    // the kernel writes output_w * output_h * channels_out * batches elements indexed with dst's\n    // contiguous layout, so dst's channel and batch dims must match the kernel/input shapes\n    GGML_ASSERT(channels_out == dst->ne[2]);\n    GGML_ASSERT(batches == dst->ne[3]);\n    GGML_ASSERT(stride > 0);\n\n    cudaStream_t st = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous(input));\n    GGML_ASSERT(ggml_is_contiguous(kernel));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    const int total  = (output_w * output_h * channels_out * batches);\n    const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;\n\n    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(\n        input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,\n        channels_in, channels_out, batches);\n}\n"
  },
  {
    "path": "src/ggml-cuda/conv2d-transpose.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256\nvoid ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/conv2d.cu",
    "content": "#include \"conv2d.cuh\"\n#include \"convert.cuh\"\n\nstruct conv_params {\n    const int64_t IW, IH;\n    const int64_t OW, OH;\n    const int64_t KW, KH;\n    const int64_t ST_X, ST_Y;\n    const int64_t PD_X, PD_Y;\n    const int64_t DL_X, DL_Y;\n    const int64_t IC, OC;\n    const int64_t B;\n    const int64_t TOTAL;\n};\n\nstruct kernel_bounds {\n    int64_t y_min, y_max;\n    int64_t x_min, x_max;\n};\n\n__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {\n    return (a > b) ? a : b;\n}\n\n__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {\n    return (a < b) ? a : b;\n}\n\n__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {\n    kernel_bounds bounds;\n    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);\n    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);\n    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);\n    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);\n    return bounds;\n}\n\n__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,\n                                                     int64_t kern_coord,\n                                                     int64_t stride,\n                                                     int64_t dilation,\n                                                     int64_t padding) {\n    return out_coord * stride + kern_coord * dilation - padding;\n}\n\nstruct whcn_layout {\n    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {\n        return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;\n    }\n\n    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {\n        return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;\n  
  }\n\n    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {\n        return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;\n    }\n\n    __device__ static void unpack_indices(int64_t             global_idx,\n                                          const conv_params & P,\n                                          int64_t &           n,\n                                          int64_t &           c,\n                                          int64_t &           out_y,\n                                          int64_t &           out_x) {\n        out_x = global_idx % P.OW;\n        out_y = (global_idx / P.OW) % P.OH;\n        c     = (global_idx / (P.OW * P.OH)) % P.OC;\n        n     = global_idx / (P.OW * P.OH * P.OC);\n    }\n};\n\ntemplate <typename T, typename Layout>\nstatic __global__ void conv2d_kernel(const float * __restrict__ input,\n                                     const T * __restrict__ kernel,\n                                     float * __restrict__ output,\n                                     const conv_params P) {\n    const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    if (global_idx >= P.TOTAL) {\n        return;\n    }\n\n    int64_t n, c_out, out_y, out_x;\n    Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);\n\n    float acc = 0.0f;\n\n    for (int64_t c_in = 0; c_in < P.IC; ++c_in) {\n        kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);\n\n        for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {\n            const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);\n\n            for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {\n                const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);\n\n                const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];\n                const T kernel_val = 
kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];\n                acc += (input_val * ggml_cuda_cast<float>(kernel_val));\n            }\n        }\n    }\n\n    // [N, OC, OH, OW]\n    output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;\n}\n\ntemplate <typename T>\nstatic void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {\n    const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;\n    conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);\n}\n\nstatic void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {\n    conv2d_cuda<half>(X_D, K_D, Y_D, P, st);\n}\n\nstatic void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {\n    conv2d_cuda<float>(X_D, K_D, Y_D, P, st);\n}\n\nvoid ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * kernel = dst->src[0];\n    const ggml_tensor * input  = dst->src[1];\n    float *             K_D    = (float *) kernel->data;\n    const float *       X_D    = (const float *) input->data;\n    float *             Y_D    = (float *) dst->data;\n\n    GGML_ASSERT(ggml_is_contiguous(kernel));\n    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);\n\n    // same number of input channels\n    GGML_ASSERT(input->ne[2] == kernel->ne[2]);\n\n    cudaStream_t st = ctx.stream();\n\n    const int32_t * p    = (const int32_t *) dst->op_params;\n    const int       ST_X = p[0];  // stride_x\n    const int       ST_Y = p[1];  // stride_y\n    const int       PD_X = p[2];  // padding_x\n    const int       PD_Y = p[3];  // padding_y\n    const int       DL_X = p[4];  // dilation_x\n    const int       DL_Y = p[5];  // dilation_y\n\n    // No cwhn\n    GGML_ASSERT(p[6] == false);\n\n    const int IW = input->ne[0];   // input_w\n    const 
int IH = input->ne[1];   // input_h\n    const int OW = dst->ne[0];     // output_w\n    const int OH = dst->ne[1];     // output_h\n    const int KW = kernel->ne[0];  // kernel_w\n    const int KH = kernel->ne[1];  // kernel_h\n    const int IC = input->ne[2];   // input_channels\n    const int OC = kernel->ne[3];  // ouptut_chanles\n    const int B  = input->ne[3];   // n_batches\n\n    const int64_t total  = B * OC * OH * OW;\n    conv_params   params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };\n\n    if (kernel->type == GGML_TYPE_F16) {\n        conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);\n    } else {\n        conv2d_cuda_f32(X_D, K_D, Y_D, params, st);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/conv2d.cuh",
    "content": "#pragma once\n#include \"common.cuh\"\n\n#define CUDA_CONV2D_BLOCK_SIZE 256\nvoid ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/convert.cu",
    "content": "#include \"convert.cuh\"\n#include \"dequantize.cuh\"\n\n#include <cstdint>\n\n#define CUDA_Q8_0_NE_ALIGN 2048\n\ntemplate <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,\n        const int64_t ne00, const int64_t ne01,\n        const int64_t ne0203, const uint3 ne02,\n        const int64_t s01, const int64_t s02, const int64_t s03) {\n    const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);\n\n    if (i00 >= ne00) {\n        return;\n    }\n\n    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {\n        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {\n            const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);\n            const int64_t i02 = dm.y;\n            const int64_t i03 = dm.x;\n\n            const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;\n\n            const int64_t ib = ibx0 + i00/qk; // block index\n            const int64_t iqs = (i00%qk)/qr; // quant index\n            const int64_t iybs = i00 - i00%qk; // y block start index\n            const int64_t y_offset = qr == 1 ? 
1 : qk/2;\n\n            // dequantize\n            float2 v;\n            dequantize_kernel(vx, ib, iqs, v);\n\n            const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;\n            y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);\n            y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);\n        }\n    }\n}\n\ntemplate <bool need_check>\nstatic __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL\n    constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;\n\n    const int64_t   i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;\n    const int * x0 = ((int *) vx) + blockIdx.x * nint;\n    half2 * y2 = (half2 *) (y + i0);\n\n    __shared__ int vals[nint];\n\n#pragma unroll\n    for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {\n        if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {\n            break;\n        }\n\n        const int ix = ix0 + threadIdx.x;\n        vals[ix] = x0[ix];\n    }\n\n    __syncthreads();\n\n#pragma unroll\n    for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {\n        if (need_check && i0 + iy + 2*threadIdx.x >= k) {\n            return;\n        }\n\n        const half * b0 = ((const half  *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);\n        const half    d = *b0;\n        const char2  qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];\n\n        y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));\n    }\n#else\n    GGML_UNUSED_VARS(vx, y, k);\n    NO_DEVICE_CODE;\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {\n\n    const int64_t i = blockIdx.x;\n\n    // assume 32 threads\n    const int64_t tid = threadIdx.x;\n    const int64_t il  = tid/8;\n  
  const int64_t ir  = tid%8;\n    const int64_t ib = 8*i + ir;\n    if (ib >= nb32) {\n        return;\n    }\n\n    dst_t * y = yy + 256*i + 32*ir + 4*il;\n\n    const block_q4_0 * x = (const block_q4_0 *)vx + ib;\n    const float d = __half2float(x->d);\n    const float dm = -8*d;\n\n    const uint8_t * q = x->qs + 4*il;\n\n    for (int l = 0; l < 4; ++l) {\n        y[l+ 0] = d * (q[l] & 0xF) + dm;\n        y[l+16] = d * (q[l] >>  4) + dm;\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {\n\n    const int64_t i = blockIdx.x;\n\n    // assume 32 threads\n    const int64_t tid = threadIdx.x;\n    const int64_t il  = tid/8;\n    const int64_t ir  = tid%8;\n    const int64_t ib = 8*i + ir;\n    if (ib >= nb32) {\n        return;\n    }\n\n    dst_t * y = yy + 256*i + 32*ir + 4*il;\n\n    const block_q4_1 * x = (const block_q4_1 *)vx + ib;\n    const float2 d = __half22float2(x->dm);\n\n    const uint8_t * q = x->qs + 4*il;\n\n    for (int l = 0; l < 4; ++l) {\n        y[l+ 0] = d.x * (q[l] & 0xF) + d.y;\n        y[l+16] = d.x * (q[l] >>  4) + d.y;\n    }\n}\n\n//================================== k-quants\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_q2_K * x = (const block_q2_K *) vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t n   = tid/32;\n    const int64_t l   = tid - 32*n;\n    const int64_t is  = 8*n + l/16;\n\n    const uint8_t q = x[i].qs[32*n + l];\n    dst_t * y = yy + i*QK_K + 128*n;\n\n    float dall = __low2half(x[i].dm);\n    float dmin = __high2half(x[i].dm);\n    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);\n    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);\n    y[l+64] = dall * (x[i].scales[is+4] & 
0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);\n    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i = blockIdx.x;\n    const block_q3_K * x = (const block_q3_K *) vx;\n\n    const int64_t r = threadIdx.x/4;\n    const int64_t tid = r/2;\n    const int64_t is0 = r%2;\n    const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);\n    const int64_t n = tid / 4;\n    const int64_t j = tid - 4*n;\n\n    uint8_t m = 1 << (4*n + j);\n    int64_t is = 8*n + 2*j + is0;\n    int shift = 2*j;\n\n    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :\n                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :\n                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :\n                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);\n    float d_all = x[i].d;\n    float dl = d_all * (us - 32);\n\n    dst_t * y = yy + i*QK_K + 128*n + 32*j;\n    const uint8_t * q = x[i].qs + 32*n;\n    const uint8_t * hm = x[i].hmask;\n\n    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 
0 : 4));\n}\n\nstatic inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {\n    if (j < 4) {\n        d = q[j] & 63; m = q[j + 4] & 63;\n    } else {\n        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);\n        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n    const block_q4_K * x = (const block_q4_K *) vx;\n\n    const int64_t i = blockIdx.x;\n\n    // assume 32 threads\n    const int64_t tid = threadIdx.x;\n    const int64_t il  = tid/8;\n    const int64_t ir  = tid%8;\n    const int64_t is  = 2*il;\n    const int64_t n   = 4;\n\n    dst_t * y = yy + i*QK_K + 64*il + n*ir;\n\n    const float dall = __low2half(x[i].dm);\n    const float dmin = __high2half(x[i].dm);\n\n    const uint8_t * q = x[i].qs + 32*il + n*ir;\n\n    uint8_t sc, m;\n    get_scale_min_k4(is + 0, x[i].scales, sc, m);\n    const float d1 = dall * sc; const float m1 = dmin * m;\n    get_scale_min_k4(is + 1, x[i].scales, sc, m);\n    const float d2 = dall * sc; const float m2 = dmin * m;\n    for (int l = 0; l < n; ++l) {\n        y[l + 0] = d1 * (q[l] & 0xF) - m1;\n        y[l +32] = d2 * (q[l] >>  4) - m2;\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n    const block_q5_K * x = (const block_q5_K *) vx;\n\n    const int64_t i = blockIdx.x;\n\n    // assume 64 threads - this is very slightly better than the one below\n    const int64_t tid = threadIdx.x;\n    const int64_t il  = tid/16;   // il is in 0...3\n    const int64_t ir  = tid%16;   // ir is in 0...15\n    const int64_t is  = 2*il;     // is is in 0...6\n\n    dst_t * y = yy + i*QK_K + 64*il + 2*ir;\n\n    const float dall = __low2half(x[i].dm);\n    const float dmin = __high2half(x[i].dm);\n\n    const uint8_t * ql = x[i].qs + 32*il + 2*ir;\n    const uint8_t * qh = 
x[i].qh + 2*ir;\n\n    uint8_t sc, m;\n    get_scale_min_k4(is + 0, x[i].scales, sc, m);\n    const float d1 = dall * sc; const float m1 = dmin * m;\n    get_scale_min_k4(is + 1, x[i].scales, sc, m);\n    const float d2 = dall * sc; const float m2 = dmin * m;\n\n    uint8_t   hm  = 1 << (2*il);\n    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;\n    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;\n    hm <<= 1;\n    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;\n    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n    const block_q6_K * x = (const block_q6_K *) vx;\n\n    const int64_t i = blockIdx.x;\n\n    // assume 64 threads - this is very slightly better than the one below\n    const int64_t tid = threadIdx.x;\n    const int64_t ip  = tid/32;   // ip is 0 or 1\n    const int64_t il  = tid - 32*ip; // 0...32\n    const int64_t is  = 8*ip + il/16;\n\n    dst_t * y = yy + i*QK_K + 128*ip + il;\n\n    const float d = x[i].d;\n\n    const uint8_t * ql = x[i].ql + 64*ip + il;\n    const uint8_t   qh = x[i].qh[32*ip + il];\n    const int8_t  * sc = x[i].scales + is;\n\n    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);\n    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);\n    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);\n    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 
8*il;\n    const uint16_t * q2 = x[i].qs + 4*ib;\n    const uint8_t  * aux8 = (const uint8_t *)q2;\n    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);\n    const uint32_t aux32 = q2[2] | (q2[3] << 16);\n    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;\n    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];\n    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_iq2_xs * x = (const block_iq2_xs *) vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint16_t * q2 = x[i].qs + 4*ib;\n    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));\n    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;\n    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];\n    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_iq2_s * x = (const block_iq2_s *) vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));\n    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;\n    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];\n    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f);\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint8_t  * q3 = x[i].qs + 8*ib;\n    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;\n    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);\n    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);\n    const uint32_t aux32 = gas[0] | (gas[1] << 16);\n    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;\n    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];\n    for (int j = 0; j < 4; ++j) {\n        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);\n        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_iq3_s * x = (const block_iq3_s *) vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint8_t * qs = x[i].qs + 8*ib;\n    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));\n    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));\n    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));\n    const uint8_t signs = x[i].signs[4*ib + il];\n    for (int j = 0; j < 4; ++j) {\n        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? 
-1.f : 1.f);\n        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_iq1_s * x = (const block_iq1_s  *) vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;\n    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);\n    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;\n    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];\n    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;\n    grid32[0] &= 0x0f0f0f0f;\n    for (int j = 0; j < 8; ++j) {\n        y[j] = d * (q[j] + delta);\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_iq1_m * x = (const block_iq1_m  *) vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint16_t * sc = (const uint16_t *)x[i].scales;\n    iq1m_scale_t scale;\n    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n    const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);\n    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);\n    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA;\n    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;\n    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];\n    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;\n    grid32[0] &= 0x0f0f0f0f;\n    for (int j = 0; j < 8; ++j) {\n        y[j] = d * (q[j] + delta);\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 4*il;\n    const uint8_t  * q4 = x[ib].qs + 4*il;\n    const float d = (float)x[ib].d;\n    for (int j = 0; j < 4; ++j) {\n        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];\n        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n    const int64_t i   = blockIdx.x;\n    const block_iq4_xs * x = (const block_iq4_xs *)vx;\n\n    const int64_t tid = threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 4*il;\n    const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;\n    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);\n    for (int j = 0; j < 4; ++j) {\n        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];\n        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];\n    }\n}\n\ntemplate<typename dst_t>\nstatic __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {\n\n    const int64_t i   = blockIdx.x;\n    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);\n\n    const int64_t tid = 
threadIdx.x;\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 4*il;\n    const uint8_t  * q4 = x[ib].qs + 4*il;\n    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);\n    for (int j = 0; j < 4; ++j) {\n        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;\n        y[j+16] = d * kvalues_mxfp4[q4[j] >>  4]*0.5f;\n    }\n}\n\ntemplate <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic void dequantize_block_cuda(const void * vx, dst_t * y,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {\n    const int64_t ne0203 = ne02*ne03;\n    const uint3 ne02_fdv = init_fastdiv_values(ne02);\n    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));\n    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>\n        (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);\n}\n\ntemplate <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {\n    dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);\n}\n\nstatic void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {\n    const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;\n    if (k % CUDA_Q8_0_NE_ALIGN == 0) {\n        const bool need_check = false;\n        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);\n    } else {\n        const bool need_check = true;\n        dequantize_block_q8_0_f16<need_check><<<num_blocks, 
WARP_SIZE, 0, stream>>>(vx, y, k);\n    }\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb32 = k / 32;\n    const int nb = (k + 255) / 256;\n    dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb32 = k / 32;\n    const int nb = (k + 255) / 256;\n    dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void 
dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = (k + QK_K - 1) / QK_K;\n    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = k / QK_K;\n    dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = (k + QK_K - 1) / QK_K;\n    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    const int nb = (k + QK_K - 1) / QK_K;\n    
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);\n}\n\ntemplate <typename src_t, typename dst_t>\nstatic __global__ void convert_unary(\n        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,\n        const int64_t ne0203, const uint3 ne02,\n        const int64_t s01, const int64_t s02, const int64_t s03) {\n    const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i00 >= ne00) {\n        return;\n    }\n\n    const src_t * x = (const src_t *) vx;\n\n    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {\n        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {\n            const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);\n            const int64_t i02 = dm.y;\n            const int64_t i03 = dm.x;\n\n            const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;\n            const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;\n            y[iy] = ggml_cuda_cast<dst_t>(x[ix]);\n        }\n    }\n}\n\ntemplate <typename src_t, typename dst_t>\nstatic void convert_unary_cuda(const void * vx, dst_t * y,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {\n    const int64_t ne0203 = ne02*ne03;\n    const uint3 ne02_fdv = init_fastdiv_values(ne02);\n    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));\n    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>\n        (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);\n}\n\ntemplate <typename src_t, typename dst_t>\nstatic void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {\n    convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, k, k, k, stream);\n}\n\nto_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type 
type) {\n    switch (type) {\n        case GGML_TYPE_F32:\n            return convert_unary_cont_cuda<float>;\n        case GGML_TYPE_F16:\n            return convert_unary_cont_cuda<half>;\n        default:\n            return nullptr;\n    }\n}\n\nto_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n            return dequantize_row_q4_0_cuda;\n        case GGML_TYPE_Q4_1:\n            return dequantize_row_q4_1_cuda;\n        case GGML_TYPE_Q5_0:\n            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;\n        case GGML_TYPE_Q5_1:\n            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;\n        case GGML_TYPE_Q8_0:\n            if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {\n                return dequantize_block_q8_0_f16_cuda;\n            }\n            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;\n        case GGML_TYPE_Q2_K:\n            return dequantize_row_q2_K_cuda;\n        case GGML_TYPE_Q3_K:\n            return dequantize_row_q3_K_cuda;\n        case GGML_TYPE_Q4_K:\n            return dequantize_row_q4_K_cuda;\n        case GGML_TYPE_Q5_K:\n            return dequantize_row_q5_K_cuda;\n        case GGML_TYPE_Q6_K:\n            return dequantize_row_q6_K_cuda;\n        case GGML_TYPE_IQ2_XXS:\n            return dequantize_row_iq2_xxs_cuda;\n        case GGML_TYPE_IQ2_XS:\n            return dequantize_row_iq2_xs_cuda;\n        case GGML_TYPE_IQ2_S:\n            return dequantize_row_iq2_s_cuda;\n        case GGML_TYPE_IQ3_XXS:\n            return dequantize_row_iq3_xxs_cuda;\n        case GGML_TYPE_IQ1_S:\n            return dequantize_row_iq1_s_cuda;\n        case GGML_TYPE_IQ1_M:\n            return dequantize_row_iq1_m_cuda;\n        case GGML_TYPE_IQ4_NL:\n            return dequantize_row_iq4_nl_cuda;\n        case GGML_TYPE_IQ4_XS:\n            return dequantize_row_iq4_xs_cuda;\n        case 
GGML_TYPE_IQ3_S:\n            return dequantize_row_iq3_s_cuda;\n        case GGML_TYPE_MXFP4:\n            return dequantize_row_mxfp4_cuda;\n        case GGML_TYPE_F32:\n            return convert_unary_cont_cuda<float>;\n        case GGML_TYPE_BF16:\n            return convert_unary_cont_cuda<nv_bfloat16>;\n        default:\n            return nullptr;\n    }\n}\n\nto_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n            return dequantize_row_q4_0_cuda;\n        case GGML_TYPE_Q4_1:\n            return dequantize_row_q4_1_cuda;\n        case GGML_TYPE_Q5_0:\n            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;\n        case GGML_TYPE_Q5_1:\n            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;\n        case GGML_TYPE_Q8_0:\n            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;\n        case GGML_TYPE_Q2_K:\n            return dequantize_row_q2_K_cuda;\n        case GGML_TYPE_Q3_K:\n            return dequantize_row_q3_K_cuda;\n        case GGML_TYPE_Q4_K:\n            return dequantize_row_q4_K_cuda;\n        case GGML_TYPE_Q5_K:\n            return dequantize_row_q5_K_cuda;\n        case GGML_TYPE_Q6_K:\n            return dequantize_row_q6_K_cuda;\n        case GGML_TYPE_IQ2_XXS:\n            return dequantize_row_iq2_xxs_cuda;\n        case GGML_TYPE_IQ2_XS:\n            return dequantize_row_iq2_xs_cuda;\n        case GGML_TYPE_IQ2_S:\n            return dequantize_row_iq2_s_cuda;\n        case GGML_TYPE_IQ3_XXS:\n            return dequantize_row_iq3_xxs_cuda;\n        case GGML_TYPE_IQ1_S:\n            return dequantize_row_iq1_s_cuda;\n        case GGML_TYPE_IQ1_M:\n            return dequantize_row_iq1_m_cuda;\n        case GGML_TYPE_IQ4_NL:\n            return dequantize_row_iq4_nl_cuda;\n        case GGML_TYPE_IQ4_XS:\n            return dequantize_row_iq4_xs_cuda;\n        case GGML_TYPE_IQ3_S:\n            return 
dequantize_row_iq3_s_cuda;\n        case GGML_TYPE_MXFP4:\n            return dequantize_row_mxfp4_cuda;\n        case GGML_TYPE_F16:\n            return convert_unary_cont_cuda<half>;\n        case GGML_TYPE_BF16:\n            return convert_unary_cont_cuda<nv_bfloat16>;\n        default:\n            return nullptr;\n    }\n}\n\nto_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_F32:\n            return convert_unary_cuda<float>;\n        case GGML_TYPE_Q4_0:\n            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;\n        case GGML_TYPE_Q4_1:\n            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;\n        case GGML_TYPE_Q5_0:\n            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;\n        case GGML_TYPE_Q5_1:\n            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;\n        case GGML_TYPE_Q8_0:\n            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;\n        case GGML_TYPE_BF16:\n            return convert_unary_cuda<nv_bfloat16>;\n        default:\n            return nullptr;\n    }\n}\n\nto_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_F32:\n            return convert_unary_cuda<float, nv_bfloat16>;\n        case GGML_TYPE_Q4_0:\n            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;\n        case GGML_TYPE_Q4_1:\n            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;\n        case GGML_TYPE_Q5_0:\n            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;\n        case GGML_TYPE_Q5_1:\n            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;\n        case GGML_TYPE_Q8_0:\n            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;\n        case GGML_TYPE_F16:\n            return convert_unary_cuda<half, nv_bfloat16>;\n        default:\n            return nullptr;\n    
}\n}\n\nto_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_F16:\n            return convert_unary_cuda<half, float>;\n        case GGML_TYPE_Q4_0:\n            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;\n        case GGML_TYPE_Q4_1:\n            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;\n        case GGML_TYPE_Q5_0:\n            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;\n        case GGML_TYPE_Q5_1:\n            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;\n        case GGML_TYPE_Q8_0:\n            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;\n        case GGML_TYPE_BF16:\n            return convert_unary_cuda<nv_bfloat16, float>;\n        default:\n            return nullptr;\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/convert.cuh",
    "content": "#pragma once\n#include \"common.cuh\"\n\n#define CUDA_DEQUANTIZE_BLOCK_SIZE 256\n\ntemplate<typename T>\nusing to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);\n\ntypedef to_t_cuda_t<float> to_fp32_cuda_t;\ntypedef to_t_cuda_t<half> to_fp16_cuda_t;\ntypedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t;\n\nto_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);\n\nto_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);\n\nto_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);\n\n// TODO more general support for non-contiguous inputs\n\ntemplate<typename T>\nusing to_t_nc_cuda_t = void (*)(const void * x, T * y,\n    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,\n    int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);\n\ntypedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;\ntypedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;\ntypedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;\n\nto_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);\nto_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);\nto_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);\n\ntemplate<typename dst_t, typename src_t>\n __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {\n    if constexpr (std::is_same_v<dst_t, src_t>) {\n        return x;\n    } else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {\n        return __float2bfloat16(float(x));\n    } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {\n        return __bfloat162float(x);\n    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {\n        return __float22half2_rn(x);\n    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {\n        // bypass compile error on cuda 12.0.1\n#ifdef GGML_USE_HIP\n        return __float22bfloat162_rn(x);\n#else\n        return {x.x, x.y};\n#endif // GGML_USE_HIP\n    } else if constexpr(std::is_same_v<dst_t, int32_t>) {\n        return int32_t(x);\n    } else 
{\n        return float(x);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/count-equal.cu",
    "content": "#include \"common.cuh\"\n#include \"count-equal.cuh\"\n\n#include <cstdint>\n\ntemplate <typename T>\nstatic __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {\n    const int64_t i0 = (int64_t) blockIdx.x*dk;\n    const int64_t i1 = min(i0 + dk, k);\n\n    int nequal = 0;\n\n    for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {\n        const T xi = x[i];\n        const T yi = y[i];\n        nequal += xi == yi;\n    }\n\n    nequal = warp_reduce_sum(nequal);\n\n    if (threadIdx.x != 0) {\n        return;\n    }\n\n    atomicAdd((int *) dst, nequal);\n}\n\nvoid ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == src1->type);\n    GGML_ASSERT( dst->type == GGML_TYPE_I64);\n\n    GGML_ASSERT(ggml_are_same_shape(src0, src1));\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    int64_t * dst_d  = (int64_t *) dst->data;\n\n    cudaStream_t stream = ctx.stream();\n    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;\n\n    const int64_t ne = ggml_nelements(src0);\n    GGML_ASSERT(ne < (1 << 30) && \"atomicAdd implementation only supports int\");\n    const int64_t dne = GGML_PAD((ne + 4*nsm - 1) / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);\n\n    CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));\n\n    const dim3 blocks_dim(WARP_SIZE, 1, 1);\n    const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);\n\n    switch (src0->type) {\n        case GGML_TYPE_I32: {\n            const int * src0_d = (const int *) src0->data;\n            const int * src1_d = (const int *) src1->data;\n            count_equal<<<blocks_num, 
blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);\n        } break;\n        default:\n            GGML_ASSERT(false);\n            break;\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/count-equal.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128\n\nvoid ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/cp-async.cuh",
    "content": "// Simplified API for asynchronous data loading.\n\n#include \"common.cuh\"\n\n\nstatic __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {\n#ifdef CP_ASYNC_AVAILABLE\n    return __cvta_generic_to_shared(generic_ptr);\n#else\n    GGML_UNUSED(generic_ptr);\n    NO_DEVICE_CODE;\n    return 0;\n#endif // CP_ASYNC_AVAILABLE\n}\n\n// Copies data from global to shared memory, cg == cache global.\n// Both the src and dst pointers must be aligned to 16 bit.\n// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.\n// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.\n// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.\ntemplate <int preload>\nstatic __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {\n    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, \"bad preload\");\n#ifdef CP_ASYNC_AVAILABLE\n#if CUDART_VERSION >= 11040\n    if (preload == 256) {\n        asm volatile(\"cp.async.cg.shared.global.L2::256B [%0], [%1], 16;\"\n            : : \"r\"(dst), \"l\"(src));\n    } else if (preload == 128) {\n        asm volatile(\"cp.async.cg.shared.global.L2::128B [%0], [%1], 16;\"\n            : : \"r\"(dst), \"l\"(src));\n    } else if (preload == 64) {\n        asm volatile(\"cp.async.cg.shared.global.L2::64B [%0], [%1], 16;\"\n            : : \"r\"(dst), \"l\"(src));\n    } else\n#endif // CUDART_VERSION >= 11040\n    {\n        asm volatile(\"cp.async.cg.shared.global [%0], [%1], 16;\"\n            : : \"r\"(dst), \"l\"(src));\n    }\n#else\n    GGML_UNUSED(dst);\n    GGML_UNUSED(src);\n    NO_DEVICE_CODE;\n#endif // CP_ASYNC_AVAILABLE\n}\n\n// Makes each thread wait until its asynchronous data copies are done.\n// This does NOT provide any additional synchronization.\n// In particular, when copying data with multiple warps 
a call to __syncthreads will be needed.\nstatic __device__ __forceinline__ void cp_async_wait_all() {\n#ifdef CP_ASYNC_AVAILABLE\n    asm volatile(\"cp.async.wait_all;\");\n#else\n    NO_DEVICE_CODE;\n#endif // CP_ASYNC_AVAILABLE\n}\n"
  },
  {
    "path": "src/ggml-cuda/cpy-utils.cuh",
    "content": "#pragma once\n\n#include \"ggml-common.h\"\n#include \"convert.cuh\"\n\nstatic __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {\n    if (x <= val[0]) return 0;\n    if (x >= val[n-1]) return n-1;\n    int ml = 0, mu = n-1;\n    while (mu-ml > 1) {\n        int mav = (ml+mu)/2;\n        if (x < val[mav]) mu = mav; else ml = mav;\n    }\n    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;\n}\n\nstatic __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {\n    float amax = 0.0f;\n    float vmax = 0.0f;\n\n    for (int j = 0; j < QK4_0; ++j) {\n        const float v = x[j];\n        if (amax < fabsf(v)) {\n            amax = fabsf(v);\n            vmax = v;\n        }\n    }\n\n    const float d  = vmax / -8;\n    const float id = d ? 1.0f/d : 0.0f;\n\n    y->d = d;\n\n    for (int j = 0; j < QK4_0/2; ++j) {\n        const float x0 = x[0       + j]*id;\n        const float x1 = x[QK4_0/2 + j]*id;\n\n        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));\n        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));\n\n        y->qs[j]  = xi0;\n        y->qs[j] |= xi1 << 4;\n    }\n}\n\nstatic __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {\n    float vmin = FLT_MAX;\n    float vmax = -FLT_MAX;\n\n    for (int j = 0; j < QK4_1; ++j) {\n        const float v = x[j];\n        if (v < vmin) vmin = v;\n        if (v > vmax) vmax = v;\n    }\n\n    const float d  = (vmax - vmin) / ((1 << 4) - 1);\n    const float id = d ? 
1.0f/d : 0.0f;\n\n    y->dm.x = d;\n    y->dm.y = vmin;\n\n    for (int j = 0; j < QK4_1/2; ++j) {\n        const float x0 = (x[0       + j] - vmin)*id;\n        const float x1 = (x[QK4_1/2 + j] - vmin)*id;\n\n        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));\n        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));\n\n        y->qs[j]  = xi0;\n        y->qs[j] |= xi1 << 4;\n    }\n}\n\nstatic __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {\n    float amax = 0.0f;\n    float vmax = 0.0f;\n\n    for (int j = 0; j < QK5_0; ++j) {\n        const float v = x[j];\n        if (amax < fabsf(v)) {\n            amax = fabsf(v);\n            vmax = v;\n        }\n    }\n\n    const float d  = vmax / -16;\n    const float id = d ? 1.0f/d : 0.0f;\n\n    y->d = d;\n\n    uint32_t qh = 0;\n    for (int j = 0; j < QK5_0/2; ++j) {\n        const float x0 = x[0       + j]*id;\n        const float x1 = x[QK5_0/2 + j]*id;\n\n        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));\n        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));\n\n        y->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);\n        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);\n    }\n    memcpy(y->qh, &qh, sizeof(qh));\n}\n\nstatic __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {\n    float min = x[0];\n    float max = x[0];\n\n    for (int j = 1; j < QK5_1; ++j) {\n        const float v = x[j];\n        min = v < min ? v : min;\n        max = v > max ? v : max;\n    }\n\n    const float d  = (max - min) / 31;\n    const float id = d ? 
1.0f/d : 0.0f;\n\n    y->dm.x = d;\n    y->dm.y = min;\n\n    uint32_t qh = 0;\n    for (int j = 0; j < QK5_1/2; ++j) {\n        const float x0 = (x[0       + j] - min)*id;\n        const float x1 = (x[QK5_1/2 + j] - min)*id;\n\n        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);\n        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);\n\n        y->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);\n        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);\n    }\n    memcpy(y->qh, &qh, sizeof(qh));\n}\n\nstatic __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {\n    float amax = 0.0f; // absolute max\n\n    for (int j = 0; j < QK8_0; j++) {\n        const float v = x[j];\n        amax = fmaxf(amax, fabsf(v));\n    }\n\n    const float d = amax / ((1 << 7) - 1);\n    const float id = d ? 1.0f/d : 0.0f;\n\n    y->d = d;\n\n    for (int j = 0; j < QK8_0; ++j) {\n        const float x0 = x[j]*id;\n        y->qs[j] = roundf(x0);\n    }\n}\n\nstatic __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {\n    float amax = 0.0f;\n    float vmax = 0.0f;\n\n    for (int j = 0; j < QK4_NL; ++j) {\n        const float v = x[j];\n        if (amax < fabsf(v)) {\n            amax = fabsf(v);\n            vmax = v;\n        }\n    }\n\n    float d = vmax / kvalues_iq4nl[0];\n    const float id = d ? 
1.0f/d : 0.0f;\n\n    float sumqx = 0, sumq2 = 0;\n    for (int j = 0; j < QK4_NL/2; ++j) {\n        const float x0 = x[0        + j]*id;\n        const float x1 = x[QK4_NL/2 + j]*id;\n        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);\n        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);\n        y->qs[j] = xi0 | (xi1 << 4);\n        const float v0 = kvalues_iq4nl[xi0];\n        const float v1 = kvalues_iq4nl[xi1];\n        const float w0 = x[0        + j]*x[0        + j];\n        const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];\n        sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];\n        sumq2 += w0*v0*v0 + w1*v1*v1;\n    }\n\n    y->d = sumq2 > 0 ? sumqx/sumq2 : d;\n}\n\n// Wrapper functions for cpy.cu compatibility\nstatic __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {\n    quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);\n}\n\nstatic __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {\n    quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);\n}\n\nstatic __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {\n    quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);\n}\n\nstatic __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {\n    quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);\n}\n\nstatic __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {\n    quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);\n}\n\nstatic __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {\n    quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);\n}\n\ntemplate<typename src_t, typename dst_t>\nstatic __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {\n    *(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);\n}\n"
  },
  {
    "path": "src/ggml-cuda/cpy.cu",
    "content": "#include \"cpy.cuh\"\n#include \"dequantize.cuh\"\n#include \"cpy-utils.cuh\"\n#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)\n#include \"ggml-musa/mudnn.cuh\"\n#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY\n\ntypedef void (*cpy_kernel_t)(const char * cx, char * cdst);\n\nconst int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks\nconst int CUDA_CPY_BLOCK_NM = 8;     // block size of 3rd dimension if available\nconst int CUDA_CPY_BLOCK_ROWS = 8;   // block dimension for marching through rows\n\ntemplate <cpy_kernel_t cpy_1>\nstatic __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,\n                                  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n                                  const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,\n                                  const int64_t nb12, const int64_t nb13) {\n    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i >= ne) {\n        return;\n    }\n\n    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor\n    // then combine those indices with the corresponding byte offsets to get the total offsets\n    const int64_t i03 = i/(ne00 * ne01 * ne02);\n    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);\n    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;\n    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;\n    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;\n\n    const int64_t i13 = i/(ne10 * ne11 * ne12);\n    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);\n    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;\n    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;\n    const int64_t 
dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;\n\n    cpy_1(cx + x_offset, cdst + dst_offset);\n}\n\ntemplate <typename T>\nstatic __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,\n                               const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n                               const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,\n                               const int64_t nb12, const int64_t nb13) {\n\n    const T* src = reinterpret_cast<const T*>(cx);\n    T* dst = reinterpret_cast<T*>(cdst);\n\n    const int64_t nmat = ne / (ne00 * ne01);\n    const int64_t n = ne00 * ne01;\n\n    const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;\n    const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;\n    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x;  // transpose block offset\n    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;\n\n    __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];\n    int cur_tile_buf = 0;\n\n#pragma unroll\n    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {\n\n        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;\n        if (imat >= nmat)\n            break;\n\n#pragma unroll\n        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {\n            if(x < ne01 && y + j < ne00){\n                const int row = threadIdx.y+j;\n                const int col = threadIdx.x * sizeof(float)/sizeof(T);\n                T *tile2 = reinterpret_cast<T*>(tile[cur_tile_buf][row]);\n                tile2[col] = src[imat*n + (y+j)*ne01 + x];\n            }\n        }\n\n        __syncthreads();\n\n#pragma unroll\n        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {\n            if (ty + j < ne01 && tx < ne00) {\n                const 
int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);\n                const T *tile2 = reinterpret_cast<const T*>(tile[cur_tile_buf][threadIdx.x]);\n                dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];\n            }\n        }\n\n        cur_tile_buf = (cur_tile_buf + 1) % 2;\n    }\n\n    GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,\n        nb12, nb13);\n}\n\nstatic __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {\n    float * cdstf = (float *)(cdsti);\n\n#pragma unroll\n    for (int j = 0; j < QK8_0; j += 2) {\n        float2 dq;\n        dequantize_q8_0(cxi, 0, j, dq);\n        *(cdstf + j) = dq.x;\n        *(cdstf + j + 1) = dq.y;\n    }\n}\n\ntemplate<dequantize_kernel_t dequant, int qk>\nstatic __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {\n    float * cdstf = (float *)(cdsti);\n\n#pragma unroll\n    for (int j = 0; j < qk/2; j++) {\n        float2 dq;\n        dequant(cxi, 0, j, dq);\n        *(cdstf + j) = dq.x;\n        *(cdstf + j + qk/2) = dq.y;\n    }\n}\n\ntemplate <cpy_kernel_t cpy_blck, int qk>\nstatic __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,\n                                 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n                                 const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,\n                                 const int64_t nb12, const int64_t nb13) {\n    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;\n\n    if (i >= ne) {\n        return;\n    }\n\n    const int64_t i03 = i/(ne00 * ne01 * ne02);\n    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);\n    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;\n    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;\n    const int64_t x_offset = i00*nb00 + 
i01*nb01 + i02*nb02 + i03 * nb03;\n\n    const int64_t i13 = i/(ne10 * ne11 * ne12);\n    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);\n    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;\n    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;\n    const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;\n\n    cpy_blck(cx + x_offset, cdst + dst_offset);\n}\n\ntemplate <cpy_kernel_t cpy_blck, int qk>\nstatic __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,\n                                 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n                                 const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,\n                                 const int64_t nb12, const int64_t nb13) {\n    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;\n\n    if (i >= ne) {\n        return;\n    }\n\n    const int64_t i03 = i/(ne00 * ne01 * ne02);\n    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);\n    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;\n    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;\n    const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;\n\n    const int64_t i13 = i/(ne10 * ne11 * ne12);\n    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);\n    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;\n    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;\n    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;\n\n    cpy_blck(cx + x_offset, cdst + dst_offset);\n}\n\ntemplate<typename src_t, typename dst_t>\nstatic __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {\n    const int64_t i = 
(int64_t)blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i >= ne) {\n        return;\n    }\n\n    const src_t * x = (const src_t *) cx;\n    dst_t *     dst = (dst_t *) cdst;\n\n    dst[i] = ggml_cuda_cast<dst_t>(x[i]);\n}\n\ntemplate<typename src_t, typename dst_t>\nstatic void ggml_cpy_scalar_contiguous_cuda(\n    const char * cx, char * cdst, const int64_t ne,\ncudaStream_t stream) {\n\n    const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>\n        (cx, cdst, ne);\n}\n\ntemplate<typename src_t, typename dst_t, bool transposed = false>\nstatic void ggml_cpy_scalar_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {\n\n    if (transposed) {\n        GGML_ASSERT(ne == ne00*ne01*ne02);  // ne[3] is 1 assumed\n        int64_t ne00n, ne01n, ne02n;\n        if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here\n            ne00n = ne00;\n            ne01n = ne01;\n            ne02n = ne02;\n        } else {\n            ne00n = ne00;\n            ne01n = ne01*ne02;\n            ne02n = 1;\n        }\n\n        int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;\n        int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;\n        int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;\n        GGML_ASSERT(grid_x < UINT_MAX);\n        GGML_ASSERT(grid_y < USHRT_MAX);\n        GGML_ASSERT(grid_z < USHRT_MAX);\n        dim3 dimGrid(grid_x, grid_y, grid_z);\n        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, 
CUDA_CPY_BLOCK_ROWS, 1);\n        cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>\n            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n    } else {\n        const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;\n        GGML_ASSERT(num_blocks < UINT_MAX);\n        cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>\n            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n    }\n}\n\nstatic void ggml_cpy_f32_q8_0_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {\n\n    GGML_ASSERT(ne % QK8_0 == 0);\n    const int64_t num_blocks = ne / QK8_0;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>\n        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_q8_0_f32_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {\n\n    const int64_t num_blocks = ne;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>\n        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_f32_q4_0_cuda(\n   
 const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {\n\n    GGML_ASSERT(ne % QK4_0 == 0);\n    const int64_t num_blocks = ne / QK4_0;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>\n        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_q4_0_f32_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02,\n    const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,\n    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,\n    cudaStream_t stream) {\n    const int64_t num_blocks = ne;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(\n        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n         ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_f32_q4_1_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {\n\n    GGML_ASSERT(ne % QK4_1 == 0);\n    const int64_t num_blocks = ne / QK4_1;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>\n        (cx, cdst, ne, 
ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_q4_1_f32_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02,\n    const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,\n    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,\n    cudaStream_t stream) {\n    const int64_t num_blocks = ne;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(\n        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n         ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_f32_q5_0_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {\n\n    GGML_ASSERT(ne % QK5_0 == 0);\n    const int64_t num_blocks = ne / QK5_0;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>\n        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_q5_0_f32_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02,\n    const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,\n    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,\n    cudaStream_t stream) {\n    const int64_t num_blocks = ne;\n    GGML_ASSERT(num_blocks < 
UINT_MAX);\n    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(\n        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n        ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_f32_q5_1_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {\n\n    GGML_ASSERT(ne % QK5_1 == 0);\n    const int64_t num_blocks = ne / QK5_1;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>\n        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_q5_1_f32_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02,\n    const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,\n    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,\n    cudaStream_t stream) {\n    const int64_t num_blocks = ne;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(\n        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n        ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nstatic void ggml_cpy_f32_iq4_nl_cuda(\n    const char * cx, char * cdst, const int64_t ne,\n    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,\n    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const 
int64_t nb12, const int64_t nb13, cudaStream_t stream) {\n\n    GGML_ASSERT(ne % QK4_NL == 0);\n    const int64_t num_blocks = ne / QK4_NL;\n    GGML_ASSERT(num_blocks < UINT_MAX);\n    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>\n        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);\n}\n\nvoid ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {\n    const int64_t ne = ggml_nelements(src0);\n    GGML_ASSERT(ne == ggml_nelements(src1));\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne01 = src0->ne[1];\n    const int64_t ne02 = src0->ne[2];\n\n    //GGML_ASSERT(src0->ne[3] == 1);\n\n    const int64_t nb00 = src0->nb[0];\n    const int64_t nb01 = src0->nb[1];\n    const int64_t nb02 = src0->nb[2];\n    const int64_t nb03 = src0->nb[3];\n\n    const int64_t ne10 = src1->ne[0];\n    const int64_t ne11 = src1->ne[1];\n    const int64_t ne12 = src1->ne[2];\n\n    //GGML_ASSERT(src1->ne[3] == 1);\n\n    const int64_t nb10 = src1->nb[0];\n    const int64_t nb11 = src1->nb[1];\n    const int64_t nb12 = src1->nb[2];\n    const int64_t nb13 = src1->nb[3];\n\n    cudaStream_t main_stream = ctx.stream();\n\n    char * src0_ddc = (char *) src0->data;\n    char * src1_ddc = (char *) src1->data;\n\n    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);\n    const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&\n        src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);\n\n    if (src0->type == src1->type && contiguous_srcs) {\n        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));\n#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)\n        if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {\n            CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));\n        } else\n#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY\n        {\n            
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));\n        }\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {\n        if (can_be_transposed) {\n            ggml_cpy_scalar_cuda<float, float, true>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<float, float>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {\n        if (contiguous_srcs) {\n            ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>\n                (src0_ddc, src1_ddc, ne, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<float, nv_bfloat16>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {\n        if (contiguous_srcs) {\n            ggml_cpy_scalar_contiguous_cuda<float, half>\n                (src0_ddc, src1_ddc, ne, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<float, half>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {\n        ggml_cpy_f32_q8_0_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q8_0_f32_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, 
nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {\n        ggml_cpy_f32_q4_0_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q4_0_f32_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {\n        ggml_cpy_f32_q4_1_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q4_1_f32_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {\n        ggml_cpy_f32_q5_0_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q5_0_f32_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {\n        ggml_cpy_f32_iq4_nl_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {\n        ggml_cpy_f32_q5_1_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, 
ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q5_1_f32_cuda\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {\n        if (can_be_transposed) {\n            ggml_cpy_scalar_cuda<half, half, true>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<half, half>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {\n        if (contiguous_srcs) {\n            ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>\n                (src0_ddc, src1_ddc, ne, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<half, nv_bfloat16>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {\n        if (contiguous_srcs) {\n            ggml_cpy_scalar_contiguous_cuda<half, float>\n                (src0_ddc, src1_ddc, ne, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<half, float>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {\n        if (can_be_transposed) {\n            ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, 
nb10, nb11, nb12, nb13, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {\n        if (contiguous_srcs) {\n            ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>\n                (src0_ddc, src1_ddc, ne, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<nv_bfloat16, half>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {\n        if (contiguous_srcs) {\n            ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>\n                (src0_ddc, src1_ddc, ne, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<nv_bfloat16, float>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {\n        if (can_be_transposed) {\n            ggml_cpy_scalar_cuda<int32_t, int32_t, true>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<int32_t, int32_t>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {\n        if (contiguous_srcs) {\n            ggml_cpy_scalar_contiguous_cuda<float, int32_t>\n                (src0_ddc, src1_ddc, ne, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<float, 
int32_t>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {\n        if (contiguous_srcs) {\n            ggml_cpy_scalar_contiguous_cuda<int32_t, float>\n                (src0_ddc, src1_ddc, ne, main_stream);\n        } else {\n            ggml_cpy_scalar_cuda<int32_t, float>\n                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n        }\n    } else {\n        GGML_ABORT(\"%s: unsupported type combination (%s to %s)\\n\", __func__,\n                ggml_type_name(src0->type), ggml_type_name(src1->type));\n    }\n}\n\nvoid ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    ggml_cuda_cpy(ctx, src0, dst);\n}\n"
  },
  {
    "path": "src/ggml-cuda/cpy.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_CPY_BLOCK_SIZE 64\n\nvoid ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);\n\nvoid ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/cross-entropy-loss.cu",
    "content": "#include \"common.cuh\"\n#include \"cross-entropy-loss.cuh\"\n#include \"sum.cuh\"\n\n#include <cmath>\n#include <cstdint>\n\ntemplate <bool use_shared>\nstatic __global__ void cross_entropy_loss_f32(\n        const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {\n    extern __shared__ float tmp[];\n\n    logits += int64_t(blockIdx.x)*nclasses;\n    labels += int64_t(blockIdx.x)*nclasses;\n\n    // Find maximum for softmax:\n    float max_logit = -INFINITY;\n    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {\n        const float val = logits[i];\n        max_logit = fmaxf(max_logit, val);\n\n        if (use_shared) {\n            tmp[i] = val;\n        }\n    }\n    max_logit = warp_reduce_max(max_logit);\n\n    // Calculate log(softmax(logits)) which is just logits - max:\n    float sum = 0.0f;\n    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {\n        const float logit_i = use_shared ? tmp[i] : logits[i];\n        sum += expf(logit_i - max_logit);\n    }\n    sum = warp_reduce_sum(sum);\n    sum = logf(sum);\n\n    // log(exp(logits - max) / sum) = (logits - max) - log(sum)\n    float loss = 0.0f;\n    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {\n        const float logit_i = use_shared ? 
tmp[i] : logits[i];\n        loss += (logit_i - max_logit - sum) * labels[i];\n    }\n    loss = -warp_reduce_sum(loss) / (float)k;\n\n    if (threadIdx.x != 0) {\n        return;\n    }\n\n    dst[blockIdx.x] = loss;\n}\n\ntemplate <bool use_shared>\nstatic __global__ void cross_entropy_loss_back_f32(\n        const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,\n        float * __restrict__ dst, const int nclasses) {\n    extern __shared__ float tmp[];\n\n    logits += int64_t(blockIdx.x)*nclasses;\n    labels += int64_t(blockIdx.x)*nclasses;\n    dst    += int64_t(blockIdx.x)*nclasses;\n\n    float maxval = -INFINITY;\n    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {\n        const float val = logits[i];\n        maxval = fmaxf(maxval, val);\n\n        if (use_shared) {\n            tmp[i] = val;\n        }\n    }\n    maxval = warp_reduce_max(maxval);\n\n    float sum = 0.0f;\n    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {\n        const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);\n        sum += val;\n\n        if (use_shared) {\n            tmp[i] = val;\n        } else {\n            dst[i] = val;\n        }\n    }\n    sum = warp_reduce_sum(sum);\n    const float sm_scale = 1.0f/sum;\n\n    const float d_by_nrows = *grad/gridDim.x;\n    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {\n        const float val = use_shared ? 
tmp[i] : dst[i];\n        dst[i] = (val*sm_scale - labels[i])*d_by_nrows;\n    }\n}\n\nvoid ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    const int64_t ne00  = src0->ne[0];\n    const int64_t nrows = ggml_nrows(src0);\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float       * dst_d  = (float       *) dst->data;\n\n    ggml_cuda_pool & pool = ctx.pool();\n    cudaStream_t stream = ctx.stream();\n\n    const dim3 blocks_dim(WARP_SIZE, 1, 1);\n    const dim3 blocks_num(nrows, 1, 1);\n    const size_t nbytes_shared = ne00*sizeof(float);\n\n    const int    id    = ggml_cuda_get_device();\n    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;\n\n    ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);\n\n    if (nbytes_shared <= smpbo) {\n        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);\n        cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);\n    } else {\n        cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);\n    }\n    CUDA_CHECK(cudaGetLastError());\n\n    // Combine results from individual blocks:\n    sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);\n}\n\nvoid ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * grad  = dst->src[0];\n    const ggml_tensor * src0f = dst->src[1];\n    const ggml_tensor * src1f = dst->src[2];\n\n    
GGML_ASSERT(src0f->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1f->type == GGML_TYPE_F32);\n    GGML_ASSERT( grad->type == GGML_TYPE_F32);\n    GGML_ASSERT(  dst->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_scalar(grad));\n    GGML_ASSERT(ggml_is_contiguous(src0f));\n    GGML_ASSERT(ggml_is_contiguous(src1f));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_are_same_shape(src0f, src1f));\n    GGML_ASSERT(ggml_are_same_shape(src0f, dst));\n\n    const int64_t ne00  = src0f->ne[0];\n    const int64_t nrows = ggml_nrows(src0f);\n\n    const float * grad_d  = (const float *) grad->data;\n    const float * src0f_d = (const float *) src0f->data;\n    const float * src1f_d = (const float *) src1f->data;\n    float       * dst_d   = (float       *) dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    const dim3 blocks_dim(WARP_SIZE, 1, 1);\n    const dim3 blocks_num(nrows, 1, 1);\n    const size_t nbytes_shared = ne00*sizeof(float);\n\n    const int    id    = ggml_cuda_get_device();\n    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;\n\n    if (nbytes_shared <= smpbo) {\n        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);\n        cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);\n    } else {\n        cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/cross-entropy-loss.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256\n\nvoid ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/cumsum.cu",
    "content": "#include <algorithm>\n#include \"cumsum.cuh\"\n#include \"convert.cuh\"\n#include \"ggml-cuda/common.cuh\"\n#include \"ggml.h\"\n\n#ifdef GGML_CUDA_USE_CUB\n#   include <cub/cub.cuh>\n#endif // GGML_CUDA_USE_CUB\n\ntemplate<typename T, int BLOCK_SIZE>\nstatic __global__ void cumsum_cub_kernel(\n        const T * __restrict__ src,\n        T * __restrict__ dst,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t  s01, const int64_t  s02, const int64_t  s03,\n        const int64_t   s1,  const int64_t   s2,  const int64_t   s3) {\n#ifdef GGML_CUDA_USE_CUB\n    using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;\n\n    __shared__ typename BlockScanT::TempStorage temp_storage;\n    __shared__ T block_carry;\n\n    const int tid = threadIdx.x;\n    constexpr int UNROLL_FACTOR = 4;\n    constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;\n\n    const int64_t i1 = blockIdx.x;\n    const int64_t i2 = blockIdx.y;\n    const int64_t i3 = blockIdx.z;\n\n    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {\n        return;\n    }\n\n    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;\n    T *       dst_row = dst + i1 * s1  + i2 * s2  + i3 * s3;\n\n    if (tid == 0) {\n        block_carry = 0;\n    }\n    __syncthreads();\n\n    for (int64_t start = 0; start < ne00; start += TILE_SIZE) {\n        T items[UNROLL_FACTOR];\n        T thread_sum = T(0);\n\n#pragma unroll\n        for (int i = 0; i < UNROLL_FACTOR; i++) {\n            int64_t idx = start + tid * UNROLL_FACTOR + i;\n            T val = (idx < ne00) ? 
src_row[idx] : T(0);\n            thread_sum += val;\n            items[i] = thread_sum;\n        }\n\n        // Block-wide scan on thread sums\n        T thread_prefix;\n        T block_total;\n        BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);\n        __syncthreads();\n\n        // Add offset to each item and store\n        T thread_offset = thread_prefix - thread_sum + block_carry;\n#pragma unroll\n        for (int i = 0; i < UNROLL_FACTOR; i++) {\n            int64_t idx = start + tid * UNROLL_FACTOR + i;\n            if (idx < ne00) {\n                dst_row[idx] = items[i] + thread_offset;\n            }\n        }\n\n        __syncthreads();\n\n        // Update carry for next tile\n        if (tid == 0) {\n            block_carry += block_total;\n        }\n    }\n#else\n    NO_DEVICE_CODE;\n#endif // GGML_CUDA_USE_CUB\n}\n\n// Fallback kernel implementation\ntemplate<typename T>\nstatic __global__ void cumsum_kernel(\n        const T * src, T * dst,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t  s00, const int64_t  s01, const int64_t  s02, const int64_t  s03,\n        const int64_t   s0, const int64_t   s1, const int64_t   s2, const int64_t   s3) {\n\n    GGML_UNUSED_VARS(s00, s0);\n\n    const int tid = threadIdx.x;\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n    const int lane = tid % warp_size;\n    const int warp = tid / warp_size;\n    const int warps_per_block = blockDim.x / warp_size;\n\n    extern __shared__ float smem[];\n    float *                 s_vals        = smem;\n    float *                 s_warp_sums   = smem + blockDim.x;\n    float *                 s_carry       = smem + blockDim.x + warps_per_block;\n    float *                 s_chunk_total = s_carry + 1;\n\n    // Initialize carry\n    if (tid == 0) {\n        *s_carry = 0.0f;\n    }\n    __syncthreads();\n\n    const int64_t i3 = blockIdx.z;\n    const 
int64_t i2 = blockIdx.y;\n    const int64_t i1 = blockIdx.x;\n    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {\n        return;\n    }\n\n    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;\n    T       * dst_row = dst + i1 * s1  + i2 * s2  + i3 * s3;\n\n    // register blocking: process 4 elements per thread to hide latency\n    // and reduce synchronization overhead\n    constexpr int num_unroll = 4;\n    T             temp[num_unroll];\n\n    for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {\n        int64_t idx = i + tid * num_unroll;\n\n        // thread local sequential scan\n        temp[0] = (idx < ne00 ? src_row[idx] : T(0));\n#pragma unroll\n        for (int64_t j = 1; j < num_unroll; j++) {\n            temp[j] = temp[j - 1];\n            if (idx + j < ne00) {\n                temp[j] += src_row[idx + j];\n            } else {\n                temp[j] += 0;\n            }\n        }\n\n        // last emenent is sum of all values assigned to thread\n        float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;\n\n        // Warp inclusive scan\n        val = warp_prefix_inclusive_sum<T, warp_size>(val);\n        s_vals[tid] = val;\n\n        if (lane == warp_size - 1) {\n            s_warp_sums[warp] = val;\n        }\n        __syncthreads();\n\n        // Exclusive scan of warp sums (warp 0 only)\n        if (warp == 0) {\n            float w = (tid < warps_per_block) ? 
s_warp_sums[tid] : 0.0f;\n            float inc = warp_prefix_inclusive_sum<T, warp_size>(w);\n            if (tid < warps_per_block) {\n                s_warp_sums[tid] = inc - w;   // exclusive sum\n            }\n            if (tid == warps_per_block - 1) {\n                *s_chunk_total = inc;          // total sum of this chunk\n            }\n        }\n        __syncthreads();\n\n        // write back results\n        float carry = *s_carry;\n        // calculate sum offset for this thread\n        float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];\n\n#pragma unroll\n        for (int32_t j = 0; j < num_unroll; j++) {\n            if (idx + j < ne00) {\n                dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);\n            }\n        }\n\n        __syncthreads();\n\n        // Update carry for next chunk\n        if (tid == 0) {\n            *s_carry += *s_chunk_total;\n        }\n    }\n}\n\n#ifdef GGML_CUDA_USE_CUB\ntemplate <typename T>\nstatic void cumsum_cub(ggml_cuda_pool & pool,\n                       const T *        src,\n                       T *              dst,\n                       int64_t          ne,\n                       cudaStream_t     stream) {\n    size_t tmp_size = 0;\n\n    // Query how much temp storage CUDA UnBound (CUB) needs\n    cub::DeviceScan::InclusiveSum(nullptr,   // d_temp_storage (null = just query size)\n                                  tmp_size,  // reference to size (will be set by CUB)\n                                  src,       // input pointer\n                                  dst,       // output pointer\n                                  ne,        // number of elements\n                                  stream     // CUDA stream to use\n    );\n\n    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);\n\n    // Perform the inclusive scan\n    cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, 
stream);\n}\n#endif // GGML_CUDA_USE_CUB\n\ntemplate<typename T>\nstatic void cumsum_cuda(\n        [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int64_t  nb0,  const int64_t nb1, const int64_t  nb2, const int64_t  nb3,\n        cudaStream_t stream) {\n\n    const size_t type_size = sizeof(T);\n    bool use_cub = false;\n#ifdef GGML_CUDA_USE_CUB\n    // Check if we can use CUB (data must be contiguous along innermost dimension)\n    const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);\n\n    if (is_contiguous) {\n        use_cub = true;\n        const int64_t nrows = ne01 * ne02 * ne03;\n        // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released\n        // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004\n        if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {\n            for (int i=0; i<nrows; i++) {\n                cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);\n            }\n            return;\n        }\n    }\n#endif // GGML_CUDA_USE_CUB\n    dim3 grid_dims(ne01, ne02, ne03);\n    const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];\n    const int warp_size = info.warp_size;\n    const int num_warps = (ne00 + warp_size - 1) / warp_size;\n    int block_size = num_warps * warp_size;\n    block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);\n    dim3 block_dims(block_size, 1, 1);\n    const int warps_per_block = block_size / warp_size;\n    const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);\n\n    if (use_cub && ne00 >= 1024) {\n        cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(\n    
        src, dst,\n            ne00, ne01, ne02, ne03,\n            nb01 / type_size, nb02 / type_size, nb03 / type_size,\n            nb1 / type_size,  nb2 / type_size,  nb3 / type_size\n        );\n    } else {\n        cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(\n            src, dst,\n            ne00, ne01, ne02, ne03,\n            nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,\n            nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size\n        );\n    }\n}\n\nvoid ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == dst->type);\n    switch(src0->type) {\n        case GGML_TYPE_F32:\n            {\n                cumsum_cuda(\n                    ctx, (const float *)src0->data, (float *)dst->data,\n                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],\n                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],\n                    stream\n                );\n            } break;\n        // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms\n        /*case GGML_TYPE_F16:\n            {\n                cumsum_cuda(\n                    (const half *)src0->data, (half *)dst->data,\n                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],\n                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],\n                    stream\n                );\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                cumsum_cuda(\n                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,\n                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],\n                
    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],\n                    stream\n                );\n            } break;*/\n        default:\n            GGML_ABORT(\"fatal error\");\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/cumsum.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_CUMSUM_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/dequantize.cuh",
    "content": "#include \"common.cuh\"\n\nstatic __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){\n    const block_q4_0 * x = (const block_q4_0 *) vx;\n\n    const float d = x[ib].d;\n\n    const int vui = x[ib].qs[iqs];\n\n    v.x = vui & 0xF;\n    v.y = vui >> 4;\n\n    v.x = (v.x - 8.0f) * d;\n    v.y = (v.y - 8.0f) * d;\n}\n\nstatic __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){\n    const block_q4_1 * x = (const block_q4_1 *) vx;\n\n    const float2 dm = __half22float2(x[ib].dm);\n\n    const int vui = x[ib].qs[iqs];\n\n    v.x = vui & 0xF;\n    v.y = vui >> 4;\n\n    v.x = (v.x * dm.x) + dm.y;\n    v.y = (v.y * dm.x) + dm.y;\n}\n\nstatic __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){\n    const block_q5_0 * x = (const block_q5_0 *) vx;\n\n    const float d = x[ib].d;\n\n    uint32_t qh;\n    memcpy(&qh, x[ib].qh, sizeof(qh));\n\n    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;\n    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;\n\n    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);\n    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);\n\n    v.x = (v.x - 16.0f) * d;\n    v.y = (v.y - 16.0f) * d;\n}\n\nstatic __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){\n    const block_q5_1 * x = (const block_q5_1 *) vx;\n\n    const float2 dm = __half22float2(x[ib].dm);\n\n    uint32_t qh;\n    memcpy(&qh, x[ib].qh, sizeof(qh));\n\n    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;\n    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;\n\n    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);\n    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);\n\n    v.x = (v.x * dm.x) + dm.y;\n    v.y = (v.y * dm.x) + dm.y;\n}\n\nstatic __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){\n    const 
block_q8_0 * x = (const block_q8_0 *) vx;\n\n    const float d = x[ib].d;\n\n    v.x = x[ib].qs[iqs + 0];\n    v.y = x[ib].qs[iqs + 1];\n\n    v.x *= d;\n    v.y *= d;\n}\n"
  },
  {
    "path": "src/ggml-cuda/diag.cu",
    "content": "#include \"convert.cuh\"\n#include \"diag.cuh\"\n#include \"ggml.h\"\n\ntemplate <typename T>\nstatic __global__ void diag_kernel(T * __restrict__ dst,\n                                   const T * __restrict__ src,\n                                   const int64_t ne0,\n                                   const int64_t ne1,\n                                   const int64_t ne2,\n                                   const int64_t ne3,\n                                   const int64_t total_elements) {\n    const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    if (global_idx >= total_elements) {\n        return;\n    }\n\n    const int64_t i0 = global_idx % ne0;\n    const int64_t i1 = (global_idx / ne0) % ne1;\n    const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2;\n    const int64_t i3 = global_idx / (ne0 * ne1 * ne2);\n\n    const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;\n\n    if (i0 == i1) {\n        const int64_t batch_idx = i3 * ne2 + i2;\n        const int64_t src_idx   = batch_idx * ne0 + i0;\n        dst[dst_idx]            = src[src_idx];\n    } else {\n        dst[dst_idx] = ggml_cuda_cast<T>(0);\n    }\n    GGML_UNUSED_VARS(ne3);\n}\n\nvoid ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    void *       dst_d  = dst->data;\n    const void * src0_d = src0->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne01 = src0->ne[1];\n    const int64_t ne02 = src0->ne[2];\n    const int64_t ne03 = src0->ne[3];\n\n    const int64_t ne0 = dst->ne[0];\n    const int64_t ne1 = dst->ne[1];\n    const int64_t ne2 = dst->ne[2];\n    const int64_t ne3 = dst->ne[3];\n\n    GGML_ASSERT(ne00 == ne0);\n    GGML_ASSERT(ne01 == 1);\n    GGML_ASSERT(ne02 == ne2);\n    GGML_ASSERT(ne03 == ne3);\n\n    const 
int64_t n_elems    = ggml_nelements(dst);\n    const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE;\n\n    switch (dst->type) {\n        case GGML_TYPE_F32:\n            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((float *) dst_d, (const float *) src0_d, ne0,\n                                                                         ne1, ne2, ne3, n_elems);\n            break;\n        case GGML_TYPE_F16:\n            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((half *) dst_d, (const half *) src0_d, ne0,\n                                                                         ne1, ne2, ne3, n_elems);\n            break;\n        default:\n            GGML_ABORT(\"unsupported type\");\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/diag.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_DIAG_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/diagmask.cu",
    "content": "#include \"diagmask.cuh\"\n\nstatic __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {\n    const int col = blockDim.y*blockIdx.y + threadIdx.y;\n    const int row = blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (col >= ncols) {\n        return;\n    }\n\n    const int i = row*ncols + col;\n    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];\n    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU\n    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;\n}\n\nstatic void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {\n    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);\n    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;\n    const dim3 block_nums(nrows_x, block_num_x, 1);\n    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);\n}\n\nvoid ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne01 = src0->ne[1];\n    const int nrows0 = ggml_nrows(src0);\n\n    const int n_past = ((int32_t *) dst->op_params)[0];\n\n    diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/diagmask.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32\n\nvoid ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/fattn-common.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n#include \"convert.cuh\"\n#include \"vecdotq.cuh\"\n\n#include <cstdint>\n\n#define FATTN_KQ_STRIDE       256\n#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.\n#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.\n\n// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable\n//     by the VKQ accumulators is effectively being shifted up by a factor of 2.\n// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.\n// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.\n// Still, the value range should be shifted as much as necessary but as little as possible.\n// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .\n#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)\n\ntypedef void (* fattn_kernel_t)(\n        const char * __restrict__ Q,\n        const char * __restrict__ K,\n        const char * __restrict__ V,\n        const char * __restrict__ mask,\n        const char * __restrict__ sinks,\n        const int  * __restrict__ KV_max,\n        float      * __restrict__ dst,\n        float2     * __restrict__ dst_meta,\n        const float scale,\n        const float max_bias,\n        const float m0,\n        const float m1,\n        const uint32_t n_head_log2,\n        const float logit_softcap,\n        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,\n                            const int32_t nb01, const int32_t nb02, const int32_t nb03,\n        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,\n                        
    const int32_t nb11, const int32_t nb12, const int64_t nb13,\n                            const int32_t nb21, const int32_t nb22, const int64_t nb23,\n                            const int32_t ne31, const int32_t ne32, const int32_t ne33,\n                            const int32_t nb31, const int32_t nb32, const int64_t nb33);\n\ntypedef float (*vec_dot_KQ_t)(\n    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);\n\ntemplate <int D, int nthreads>\nstatic __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(\n    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {\n\n    const half2 * K_h2 = (const half2 *) K_c;\n    GGML_UNUSED(Q_q8);\n    GGML_UNUSED(Q_ds_v);\n\n    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {\n        __align__(16) half2 tmp[cpy_ne];\n        ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);\n#pragma unroll\n        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {\n#ifdef V_DOT2_F32_F16_AVAILABLE\n            ggml_cuda_mad(sum,                tmp[k_KQ_1] , ((const half2  *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);\n#else\n            ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);\n#endif // V_DOT2_F32_F16_AVAILABLE\n        }\n    }\n\n    return sum;\n}\n\ntemplate<int D, int nthreads>\nstatic __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(\n    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {\n\n    const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < 
int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);\n\n        const int ib    = k_KQ /  QI8_1;\n        const int iqs4  = k_KQ %  QI4_0;\n        const int shift = k_KQ & (QI8_1/2);\n\n        int v;\n        ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);\n        v = (v >> shift) & 0x0F0F0F0F;\n        const int u = Q_q8[k_KQ_0/nthreads];\n\n        const int sumi = ggml_cuda_dp4a(v, u, 0);\n\n        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];\n        sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);\n    }\n\n    return sum;\n}\n\ntemplate<int D, int nthreads>\nstatic __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(\n    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {\n\n    const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads);\n\n        const int ib    = k_KQ /  QI8_1;\n        const int iqs4  = k_KQ %  QI4_1;\n        const int shift = k_KQ & (QI8_1/2);\n\n        int v;\n        ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);\n        v = (v >> shift) & 0x0F0F0F0F;\n        const int u = Q_q8[k_KQ_0/nthreads];\n\n        const int sumi = ggml_cuda_dp4a(v, u, 0);\n\n        const float2 K_dm = __half22float2(K_q4_1[ib].dm);\n        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];\n\n        sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;\n    }\n\n    return sum;\n}\n\ntemplate<int D, int nthreads>\nstatic __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(\n    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {\n\n    const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads);\n\n        const int ib    = k_KQ /  QI8_1;\n        const int iqs4  = k_KQ %  QI5_0;\n        const int iqs8  = k_KQ %  QI8_1;\n        const int shift = k_KQ & (QI8_1/2);\n\n        int v;\n        ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);\n        v = (v >> shift) & 0x0F0F0F0F;\n\n        {\n            int vh;\n            ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);\n            vh >>= iqs8 * QI5_0;\n\n            v |= (vh <<  4) & 0x00000010; // 0 ->  4\n            v |= (vh << 11) & 0x00001000; // 1 -> 12\n            v |= (vh << 18) & 0x00100000; // 2 -> 20\n            v |= (vh << 25) & 0x10000000; // 3 -> 28\n        }\n\n        const int u = Q_q8[k_KQ_0/nthreads];\n\n        const int sumi = ggml_cuda_dp4a(v, u, 0);\n\n        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];\n\n        sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);\n    }\n\n    return sum;\n}\n\ntemplate<int D, int nthreads>\nstatic __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(\n    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {\n\n    const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads);\n\n        const int ib    = k_KQ /  QI8_1;\n        const int iqs4  = k_KQ %  QI5_1;\n        const int iqs8  = k_KQ %  QI8_1;\n        const int shift = k_KQ & (QI8_1/2);\n\n        int v;\n        ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);\n        v = (v >> shift) & 0x0F0F0F0F;\n\n        {\n            int vh;\n            ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);\n            vh >>= iqs8 * QI5_0;\n\n            v |= (vh <<  4) & 0x00000010; // 0 ->  4\n            v |= (vh << 11) & 0x00001000; // 1 -> 12\n            v |= (vh << 18) & 0x00100000; // 2 -> 20\n            v |= (vh << 25) & 0x10000000; // 3 -> 28\n        }\n\n        const int u = Q_q8[k_KQ_0/nthreads];\n\n        const int sumi = ggml_cuda_dp4a(v, u, 0);\n\n        const float2 K_dm = __half22float2(K_q5_1[ib].dm);\n        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];\n\n        sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;\n    }\n\n    return sum;\n}\n\ntemplate <int D, int nthreads>\nstatic __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(\n    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {\n\n    const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads);\n\n        const int ib  = k_KQ / QI8_0;\n        const int iqs = k_KQ % QI8_0;\n\n        int v;\n        ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);\n\n        const float2 * Q_ds = (const float2 *) Q_ds_v;\n        const float Q_d = Q_ds[k_KQ_0/nthreads].x;\n\n        sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);\n    }\n\n    return sum;\n}\n\ntemplate <typename Tds, int ni>\nstatic __device__ __forceinline__ void quantize_q8_1_to_shared(\n    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {\n\n    float vals[sizeof(int)] = {0.0f};\n#pragma unroll\n    for (int l = 0; l < int(sizeof(int)); ++l) {\n        vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;\n    }\n\n    float amax = fabsf(vals[0]);\n    float sum  = vals[0];\n#pragma unroll\n    for (int l = 1; l < int(sizeof(int)); ++l) {\n        amax = fmaxf(amax, fabsf(vals[l]));\n        sum += vals[l];\n    }\n#pragma unroll\n    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {\n        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));\n        sum +=             __shfl_xor_sync(0xFFFFFFFF, sum,  mask, 32);\n    }\n\n    const float d = amax / 127;\n    int q32 = 0;\n    int8_t * q8 = (int8_t *) &q32;\n\n    if (d != 0.0f) {\n#pragma unroll\n        for (int l = 0; l < int(sizeof(int)); ++l) {\n            q8[l] = roundf(vals[l] / d);\n        }\n    }\n\n    yq32[threadIdx.x] = q32;\n    if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {\n        if (std::is_same<Tds, half2>::value) {\n            ((half2  *) yds)[threadIdx.x/QI8_1] =  make_half2(d, sum);\n        } else {\n            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);\n        }\n    }\n}\n\ntypedef void (*dequantize_V_t)(const void *, void *, const int64_t);\n\ntemplate <typename T, int ne>\nstatic __device__ 
__forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    if constexpr (std::is_same_v<T, half>) {\n        ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);\n    } else if constexpr (std::is_same_v<T, float>) {\n        static_assert(ne % 2 == 0, \"bad ne\");\n        __align__(16) half2 tmp[ne/2];\n        ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);\n        float2 * dst_f2 = (float2 *) dst;\n#pragma unroll\n        for (int l = 0; l < ne/2; ++l) {\n            dst_f2[l] = __half22float2(tmp[l]);\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"unsupported type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q4_0 * x = (const block_q4_0 *) vx;\n\n    const int64_t ib    =  i0          /  QK4_0;\n    const int     iqs   =  i0          % (QK4_0/2);\n    const int     shift = (i0 % QK4_0) / (QK4_0/2);\n\n    int q;\n    static_assert(ne == 2 || ne == 4, \"bad ne\");\n    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);\n    q >>= 4*shift;\n    q &= 0x0F0F0F0F;\n    q = __vsubss4(q, 0x08080808);\n\n    const int8_t * q8 = (const int8_t *) &q;\n\n#ifdef FP16_AVAILABLE\n    if constexpr (std::is_same_v<T, half>) {\n        const half2 d = __half2half2(x[ib].d);\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);\n        }\n    } else\n#endif // FP16_AVAILABLE\n    if constexpr (std::is_same_v<T, float>) {\n        const float d = x[ib].d;\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = d * q8[l];\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"bad type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __device__ __forceinline__ void 
dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q4_1 * x = (const block_q4_1 *) vx;\n\n    const int64_t ib    =  i0          /  QK4_1;\n    const int     iqs   =  i0          % (QK4_1/2);\n    const int     shift = (i0 % QK4_1) / (QK4_1/2);\n\n    int q;\n    static_assert(ne == 2 || ne == 4, \"bad ne\");\n    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);\n    q >>= 4*shift;\n    q &= 0x0F0F0F0F;\n\n    const int8_t * q8 = (const int8_t *) &q;\n\n#ifdef FP16_AVAILABLE\n    if constexpr (std::is_same_v<T, half>) {\n        const half2 dm = x[ib].dm;\n        const half2 d  = __half2half2( __low2half(dm));\n        const half2 m  = __half2half2(__high2half(dm));\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;\n        }\n    } else\n#endif // FP16_AVAILABLE\n    if constexpr (std::is_same_v<T, float>) {\n        const float2 dm = __half22float2(x[ib].dm);\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = dm.x * q8[l] + dm.y;\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"bad type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q5_0 * x = (const block_q5_0 *) vx;\n\n    const int64_t ib    =  i0          /  QK5_0;\n    const int     idq   =  i0          %  QK5_0;\n    const int     iqs   =  i0          % (QK5_0/2);\n    const int     shift = (i0 % QK5_0) / (QK5_0/2);\n\n    int q;\n    static_assert(ne == 2 || ne == 4, \"bad ne\");\n    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);\n    q >>= 4*shift;\n    q &= 0x0F0F0F0F;\n\n    {\n        int qh;\n        ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            q |= ((qh >> (idq + l)) & 
0x00000001) << (8*l + 4);\n        }\n    }\n\n    q = __vsubss4(q, 0x10101010);\n\n    const int8_t * q8 = (const int8_t *) &q;\n\n#ifdef FP16_AVAILABLE\n    if constexpr (std::is_same_v<T, half>) {\n        const half2 d = __half2half2(x[ib].d);\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);\n        }\n    } else\n#endif // FP16_AVAILABLE\n    if constexpr (std::is_same_v<T, float>) {\n        const float d = x[ib].d;\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = d * q8[l];\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"bad type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q5_1 * x = (const block_q5_1 *) vx;\n\n    const int64_t ib    =  i0          /  QK5_1;\n    const int     idq   =  i0          %  QK5_1;\n    const int     iqs   =  i0          % (QK5_1/2);\n    const int     shift = (i0 % QK5_1) / (QK5_1/2);\n\n    int q;\n    static_assert(ne == 2 || ne == 4, \"bad ne\");\n    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);\n    q >>= 4*shift;\n    q &= 0x0F0F0F0F;\n\n    {\n        int qh;\n        ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);\n        }\n    }\n\n    const int8_t * q8 = (const int8_t *) &q;\n\n#ifdef FP16_AVAILABLE\n    if constexpr (std::is_same_v<T, half>) {\n        const half2 dm = x[ib].dm;\n        const half2 d  = __half2half2( __low2half(dm));\n        const half2 m  = __half2half2(__high2half(dm));\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;\n        }\n    } else\n#endif // FP16_AVAILABLE\n    if constexpr 
(std::is_same_v<T, float>) {\n        const float2 dm = __half22float2(x[ib].dm);\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = dm.x * q8[l] + dm.y;\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"bad type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q8_0 * x = (const block_q8_0 *) vx;\n\n    const int64_t ib  = i0 / QK8_0;\n    const int     iqs = i0 % QK8_0;\n\n    static_assert(ne % 2 == 0, \"bad ne\");\n    int8_t qs[ne];\n    ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);\n\n#ifdef FP16_AVAILABLE\n    if constexpr (std::is_same<T, half>::value) {\n        const half2 d = __half2half2(x[ib].d);\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);\n        }\n    } else\n#endif // FP16_AVAILABLE\n    if constexpr (std::is_same<T, float>::value) {\n        const float d = x[ib].d;\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = d * qs[l];\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"unsupported type\");\n    }\n}\n\ntemplate <ggml_type type_K, int D, int nthreads>\nconstexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {\n    if constexpr (type_K == GGML_TYPE_F16) {\n        return vec_dot_fattn_vec_KQ_f16<D, nthreads>;\n    } else if constexpr (type_K == GGML_TYPE_Q4_0) {\n        return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;\n    } else if constexpr (type_K == GGML_TYPE_Q4_1) {\n        return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;\n    } else if constexpr (type_K == GGML_TYPE_Q5_0) {\n        return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;\n    } else if constexpr (type_K == GGML_TYPE_Q5_1) {\n        return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;\n    } else if constexpr (type_K == 
GGML_TYPE_Q8_0) {\n        return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;\n    } else {\n        static_assert(type_K == -1, \"bad type\");\n        return nullptr;\n    }\n}\n\ntemplate <ggml_type type_V, typename T, int ne>\nconstexpr __device__ dequantize_V_t get_dequantize_V() {\n    if constexpr (type_V == GGML_TYPE_F16) {\n        return dequantize_V_f16<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q4_0) {\n        return dequantize_V_q4_0<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q4_1) {\n        return dequantize_V_q4_1<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q5_0) {\n        return dequantize_V_q5_0<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q5_1) {\n        return dequantize_V_q5_1<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q8_0) {\n        return dequantize_V_q8_0<T, ne>;\n    } else {\n        static_assert(type_V == -1, \"bad type\");\n        return nullptr;\n    }\n}\n\ntemplate <int ncols1>\n__launch_bounds__(FATTN_KQ_STRIDE/2, 1)\nstatic __global__ void flash_attn_mask_to_KV_max(\n        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {\n    const int ne31     = gridDim.x;\n    const int tid      = threadIdx.x;\n    const int sequence = blockIdx.y;\n    const int jt       = blockIdx.x;\n\n    mask += sequence*s33 + jt*ncols1*s31;\n\n    __shared__ int buf_iw[WARP_SIZE];\n    if (tid < WARP_SIZE) {\n        buf_iw[tid] = 1;\n    }\n    __syncthreads();\n\n    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;\n    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {\n        int all_inf = 1;\n\n#pragma unroll\n        for (int j = 0; j < ncols1; ++j) {\n            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);\n            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));\n        }\n\n        all_inf = warp_reduce_all(all_inf);\n        if (tid % WARP_SIZE == 0) {\n            buf_iw[tid / WARP_SIZE] = 
all_inf;\n        }\n        __syncthreads();\n        all_inf = buf_iw[tid % WARP_SIZE];\n        __syncthreads();\n        all_inf = warp_reduce_all(all_inf);\n\n        if (!all_inf) {\n            break;\n        }\n    }\n\n    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.\n    // If the break was triggered it's the lower edge of the tile with the first non-masked values.\n    // In either case, walk back the decrementation by FATTN_KQ_STRIDE.\n    KV_max_sj += FATTN_KQ_STRIDE;\n\n    if (threadIdx.x != 0) {\n        return;\n    }\n\n    KV_max[sequence*ne31 + jt] = KV_max_sj;\n}\n\ntemplate<int D, int ncols1, int ncols2> // D == head size\n__launch_bounds__(D, 1)\nstatic __global__ void flash_attn_stream_k_fixup(\n        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,\n        const int ne11, const int ne12, const int nbatch_fa) {\n    constexpr int ncols = ncols1*ncols2;\n\n    const int bidx0 = blockIdx.x;\n    const int j     = blockIdx.y;\n    const int c     = blockIdx.z;\n    const int jc    = j*ncols2 + c;\n    const int tid   = threadIdx.x;\n\n    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);\n\n    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.\n\n    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;\n    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;\n    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;\n\n    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;\n    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;\n\n    const bool did_not_have_any_data   = kbc0 == kbc0_stop;\n    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;\n    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % 
iter_k != 0;\n    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {\n        return;\n    }\n\n    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index\n    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);\n    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);\n    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);\n    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;\n\n    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.\n\n    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {\n        return;\n    }\n\n    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;\n\n    // Load the partial result that needs a fixup:\n    float dst_val = 0.0f;\n    float max_val = 0.0f;\n    float rowsum  = 0.0f;\n    {\n        dst_val = *dst;\n\n        const float2 tmp = dst_fixup[bidx0*ncols + jc];\n        max_val = tmp.x;\n        rowsum  = tmp.y;\n    }\n\n    // Iterate over previous blocks and compute the combined results.\n    // All CUDA blocks that get here must have a previous block that needs a fixup.\n    int bidx = bidx0 - 1;\n    int kbc_stop = kbc0;\n    while(true) {\n        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;\n        if (kbc == kbc_stop) { // Did not have any data.\n            bidx--;\n            kbc_stop = kbc;\n            continue;\n        }\n\n        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];\n\n        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];\n\n        // Scale the current and new value accumulators depending on the max. 
values.\n        const float max_val_new = fmaxf(max_val, tmp.x);\n\n        const float diff_val = max_val - max_val_new;\n        const float diff_add = tmp.x   - max_val_new;\n\n        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;\n        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;\n\n        dst_val = scale_val*dst_val + scale_add*dst_add;\n        rowsum  = scale_val*rowsum  + scale_add*tmp.y;\n\n        max_val = max_val_new;\n\n        // If this block started in a previous tile we are done and don't need to combine additional partial results.\n        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {\n            break;\n        }\n        bidx--;\n        kbc_stop = kbc;\n    }\n\n    // Write back final result:\n    *dst = dst_val / rowsum;\n}\n\ntemplate<int D> // D == head size\n__launch_bounds__(D, 1)\nstatic __global__ void flash_attn_combine_results(\n        const float  * __restrict__ VKQ_parts,\n        const float2 * __restrict__ VKQ_meta,\n        float * __restrict__ dst,\n        const int parallel_blocks) {\n    // Dimension 0: threadIdx.x\n    // Dimension 1: blockIdx.x\n    // Dimension 2: blockIdx.y\n    // Dimension 3: blockIdx.z\n    // Memory layout is permuted with [0, 2, 1, 3]\n\n    const int ne01 = gridDim.x;\n    const int ne02 = gridDim.y;\n\n    const int col      = blockIdx.x;\n    const int head     = blockIdx.y;\n    const int sequence = blockIdx.z;\n\n    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;\n\n    VKQ_parts += j_dst_unrolled * parallel_blocks*D;\n    VKQ_meta  += j_dst_unrolled * parallel_blocks;\n    dst       += j_dst_unrolled *                 D;\n\n    const int tid = threadIdx.x;\n    __builtin_assume(tid < D);\n\n    extern __shared__ float2 meta[];\n    for (int i = tid; i < 2*parallel_blocks; i += D) {\n        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];\n    }\n\n    __syncthreads();\n\n    float kqmax 
= meta[0].x;\n    for (int l = 1; l < parallel_blocks; ++l) {\n        kqmax = max(kqmax, meta[l].x);\n    }\n\n    float VKQ_numerator   = 0.0f;\n    float VKQ_denominator = 0.0f;\n    for (int l = 0; l < parallel_blocks; ++l) {\n        const float KQ_max_scale = expf(meta[l].x - kqmax);\n\n        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];\n        VKQ_denominator += KQ_max_scale * meta[l].y;\n    }\n\n    dst[tid] = VKQ_numerator / VKQ_denominator;\n}\n\ntemplate <int DV, int ncols1, int ncols2>\nvoid launch_fattn(\n    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,\n    const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE\n) {\n    constexpr int ncols = ncols1 * ncols2;\n\n    const ggml_tensor * Q = dst->src[0];\n    const ggml_tensor * K = dst->src[1];\n    const ggml_tensor * V = dst->src[2];\n\n    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));\n\n    const ggml_tensor * mask  = dst->src[3];\n    const ggml_tensor * sinks = dst->src[4];\n\n    ggml_tensor * KQV = dst;\n\n    GGML_ASSERT(Q->type == GGML_TYPE_F32);\n    GGML_ASSERT(KQV->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));\n    GGML_ASSERT(K->nb[0] == ggml_element_size(K));\n    GGML_ASSERT(V->nb[0] == ggml_element_size(V));\n\n    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);\n\n    ggml_cuda_pool & pool = ctx.pool();\n    cudaStream_t main_stream = ctx.stream();\n    const int id  = ggml_cuda_get_device();\n    const int cc  = ggml_cuda_info().devices[id].cc;\n    const int nsm = ggml_cuda_info().devices[id].nsm;\n\n    ggml_cuda_pool_alloc<half>   K_f16(pool);\n    ggml_cuda_pool_alloc<half>   V_f16(pool);\n    ggml_cuda_pool_alloc<int>    KV_max(pool);\n    ggml_cuda_pool_alloc<float>  dst_tmp(pool);\n    
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);\n\n    const char * K_data = (const char *) K->data;\n    size_t nb11 = K->nb[1];\n    size_t nb12 = K->nb[2];\n    size_t nb13 = K->nb[3];\n\n    const char * V_data = (const char *) V->data;\n    size_t nb21 = V->nb[1];\n    size_t nb22 = V->nb[2];\n    size_t nb23 = V->nb[3];\n\n    if (need_f16_K && K->type != GGML_TYPE_F16) {\n        const size_t bs = ggml_blck_size(K->type);\n        const size_t ts = ggml_type_size(K->type);\n\n        K_f16.alloc(ggml_nelements(K));\n        if (ggml_is_contiguously_allocated(K)) {\n            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);\n            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);\n\n            nb11 = nb11*bs*sizeof(half)/ts;\n            nb12 = nb12*bs*sizeof(half)/ts;\n            nb13 = nb13*bs*sizeof(half)/ts;\n        } else {\n            GGML_ASSERT(K->nb[0] == ts);\n            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);\n            const int64_t s01 = nb11 / ts;\n            const int64_t s02 = nb12 / ts;\n            const int64_t s03 = nb13 / ts;\n            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);\n\n            nb11 = K->ne[0] * sizeof(half);\n            nb12 = K->ne[1] * nb11;\n            nb13 = K->ne[2] * nb12;\n        }\n        K_data = (char *) K_f16.ptr;\n    }\n\n    if (need_f16_V && V->type != GGML_TYPE_F16) {\n        if (V_is_K_view) {\n            V_data = K_data;\n            nb21   = nb11;\n            nb22   = nb12;\n            nb23   = nb13;\n        } else {\n            const size_t bs = ggml_blck_size(V->type);\n            const size_t ts = ggml_type_size(V->type);\n\n            V_f16.alloc(ggml_nelements(V));\n            if (ggml_is_contiguously_allocated(V)) {\n                to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);\n                to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);\n         
       V_data = (char *) V_f16.ptr;\n\n                nb21 = nb21*bs*sizeof(half)/ts;\n                nb22 = nb22*bs*sizeof(half)/ts;\n                nb23 = nb23*bs*sizeof(half)/ts;\n            } else {\n                GGML_ASSERT(V->nb[0] == ts);\n                to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);\n                const int64_t s01 = nb21 / ts;\n                const int64_t s02 = nb22 / ts;\n                const int64_t s03 = nb23 / ts;\n                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);\n\n                nb21 = V->ne[0] * sizeof(half);\n                nb22 = V->ne[1] * nb21;\n                nb23 = V->ne[2] * nb22;\n            }\n            V_data = (char *) V_f16.ptr;\n        }\n    }\n\n    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);\n    const int gqa_ratio    = Q->ne[2] / K->ne[2];\n    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);\n    const int ntiles_dst   = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];\n\n    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.\n    // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or\n    //     multiple sequences of possibly different lengths.\n    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {\n        const int s31 = mask->nb[1] / sizeof(half2);\n        const int s33 = mask->nb[3] / sizeof(half2);\n\n        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);\n        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);\n\n        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;\n        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;\n\n        KV_max.alloc(ne_KV_max);\n        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>\n            ((const half2 *) mask->data, 
KV_max.ptr, iter_k, s31, s33);\n        CUDA_CHECK(cudaGetLastError());\n    }\n\n    const dim3 block_dim(warp_size, nwarps, 1);\n    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.\n    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));\n    GGML_ASSERT(max_blocks_per_sm > 0);\n    int parallel_blocks = max_blocks_per_sm;\n\n    const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length.\n\n    dim3 blocks_num;\n    if (stream_k) {\n        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.\n        const int max_blocks = max_blocks_per_sm*nsm;\n        const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;\n        const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);\n\n        const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);\n\n        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;\n\n        blocks_num.x = use_stream_k ? 
nblocks_stream_k : ntiles_dst;\n        blocks_num.y = 1;\n        blocks_num.z = 1;\n\n        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.\n            dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));\n        }\n    } else {\n        // parallel_blocks must not be larger than what the tensor size allows:\n        parallel_blocks = std::min(parallel_blocks, ntiles_KV);\n\n        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.\n        // Test whether parallel_blocks can be set to a higher value for better efficiency.\n        const int blocks_per_wave = nsm * max_blocks_per_sm;\n        int nwaves_best = 0;\n        int efficiency_percent_best = 0;\n        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) {\n            const int nblocks_total = ntiles_dst * parallel_blocks_test;\n            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;\n            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);\n\n            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.\n            if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {\n                break;\n            }\n\n            if (efficiency_percent > efficiency_percent_best) {\n                nwaves_best = nwaves;\n                efficiency_percent_best = efficiency_percent;\n                parallel_blocks = parallel_blocks_test;\n            }\n        }\n\n        blocks_num.x = ntiles_x;\n        blocks_num.y = parallel_blocks;\n        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];\n\n        if (parallel_blocks > 1) {\n            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));\n            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));\n        }\n    }\n\n    float scale         = 1.0f;\n    
float max_bias      = 0.0f;\n    float logit_softcap = 0.0f;\n\n    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));\n    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));\n    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));\n\n    if (logit_softcap != 0.0f) {\n        scale /= logit_softcap;\n    }\n\n    const uint32_t n_head      = Q->ne[2];\n    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    // TODO other tensor dimensions after removal of WMMA kernel:\n    const uint3 ne01 = init_fastdiv_values(Q->ne[1]);\n\n    GGML_ASSERT(block_dim.x % warp_size == 0);\n    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(\n        (const char *) Q->data,\n        K_data,\n        V_data,\n        mask ? ((const char *) mask->data) : nullptr,\n        sinks ? ((const char *) sinks->data) : nullptr,\n        KV_max.ptr,\n        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,\n        scale, max_bias, m0, m1, n_head_log2, logit_softcap,\n        Q->ne[0], ne01,     Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],\n        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,\n        nb21, nb22, nb23,\n        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,\n        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? 
mask->nb[3] : 0\n    );\n    CUDA_CHECK(cudaGetLastError());\n\n    if (stream_k) {\n        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.\n            const dim3 block_dim_combine(DV, 1, 1);\n            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};\n\n            flash_attn_stream_k_fixup<DV, ncols1, ncols2>\n                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>\n                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);\n        }\n    } else if (parallel_blocks > 1) {\n        const dim3 block_dim_combine(DV, 1, 1);\n        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);\n        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);\n\n        flash_attn_combine_results<DV>\n            <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>\n            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);\n    }\n    CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "src/ggml-cuda/fattn-mma-f16.cuh",
    "content": "#include \"common.cuh\"\n#include \"cp-async.cuh\"\n#include \"mma.cuh\"\n#include \"fattn-common.cuh\"\n\nusing namespace ggml_cuda_mma;\n\n// Config options for the MMA kernel.\n// Should not affect results, only speed/register pressure/shared memory use.\nstruct fattn_mma_config {\n    int  nthreads;       // Number of threads per CUDA block.\n    int  occupancy;      // Targeted occupancy for the MMA kernel.\n    int  nbatch_fa;      // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.\n    int  nbatch_K2;      // Number of K half2 values in direction of DKQ to load in parallel.\n    int  nbatch_V2;      // Number of V half2 values in direction of DV to load in parallel.\n    int  nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.\n    int  nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.\n    bool Q_in_reg;       // Whether the Q values should be kept permanently in registers.\n\n    constexpr __host__ __device__ fattn_mma_config(\n            int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :\n        nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),\n        nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}\n};\n\n#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \\\n    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                                                                                       \\\n        static_assert((nthreads_)       % 32 == 0 && (nthreads_)       <= 512, \"bad nthreads\");                                                                    \\\n       
 static_assert(                               (occupancy_)      <=   8, \"bad occupancy\");                                                                   \\\n        static_assert((nbatch_fa_)      % 32 == 0 && (nbatch_fa_)      <= 256, \"bad nbatch_fa\");                                                                   \\\n        static_assert((nbatch_K2_)      %  4 == 0 && (nbatch_K2_)      <= 512, \"bad nbatch_K2\");                                                                   \\\n        static_assert((nbatch_V2_)      %  4 == 0 && (nbatch_V2_)      <= 256, \"bad nbatch_V2\");                                                                   \\\n        static_assert((nbatch_combine_) %  4 == 0 && (nbatch_combine_) <= 128, \"bad nbatch_combine\");                                                              \\\n        static_assert((nstages_target_)      >= 1 && (nstages_target_) <=   2, \"bad nstages_target\");                                                              \\\n        return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)};           \\\n    }                                                                                                                                                              \\\n\nstatic constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32,  32, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 16, 128, 2,  64,  32,  32,  32, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 32, 128, 2,  64,  32,  32,  32, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 64, 128, 2,  64,  32,  32,  32, 2, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80,  8, 128, 2, 128,  40,  40,  40, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 16, 128, 2,  64,  40,  
40,  40, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 32, 128, 2,  64,  40,  40,  40, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 64, 128, 2,  64,  40,  40,  40, 2, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96,  8, 128, 2, 128,  48,  48,  48, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 16, 128, 2,  64,  48,  48,  48, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 32, 128, 2,  64,  48,  48,  48, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 64, 128, 2,  64,  48,  48,  48, 2, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112,  8, 128, 2, 128,  56,  56,  56, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2,  64,  56,  56,  56, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2,  64,  56,  56,  56, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2,  64,  56,  56,  56, 2, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128,  8, 128, 2, 128,  64,  64,  64, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2,  64,  64,  64,  64, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2,  64,  64,  64,  64, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2,  64,  64,  64,  64, 2, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8,  64, 4,  64, 128, 128, 128, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16,  64, 4,  32, 128, 128, 128, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  32, 128, 128, 128, 2, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256, 128, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256, 128, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);\n\n    
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);\n}\n\nstatic constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8, 128, 2,  64, 128, 128, 128, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2,  64, 128, 128, 128, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32,  96,  64, 128, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);\n\n    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);\n}\n\nstatic constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256,  64, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128,  64, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128,  64, 1, false);\n\n    // TODO tune specifically for Volta\n    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);\n}\n\nstatic constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2,  64, 128, 128, 128, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  
64, 2, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);\n\n    // TODO tune specifically for RDNA\n    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);\n}\n\nstatic constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {\n    // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32,  32, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 16, 128, 2,  64,  32,  32,  32, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 32, 128, 2,  64,  32,  32,  32, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  32,  32,  32, 1, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80,  8, 128, 2, 128,  40,  40,  40, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 16, 128, 2,  64,  40,  40,  40, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 32, 128, 2,  64,  40,  40,  40, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 64, 256, 2,  64,  40,  40,  40, 1, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96,  8, 128, 2, 128,  48,  48,  48, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 16, 128, 2,  64,  48,  48,  48, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 32, 128, 2,  64,  48,  48,  48, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 64, 256, 2,  64,  48,  48,  48, 1, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112,  8, 128, 2, 128,  56,  56,  56, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2,  64,  56,  56,  56, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2,  64,  56,  56,  56, 1, true);\n    
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2,  64,  56,  56,  56, 1, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128,  8, 128, 2, 128,  64,  64,  64, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2,  64,  64,  64,  64, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2,  64,  64,  64,  64, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2,  64,  64,  64,  64, 1, true);\n\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8,  64, 4,  64, 128, 128, 128, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16,  64, 4,  32, 128, 128, 128, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 1, true);\n    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2,  32, 128, 128, 128, 1, true);\n\n    // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy\n    // compile-time static_asserts even though the kernel guard prevents runtime execution.\n    // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.\n    return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);\n}\n\nstatic __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {\n    if (ampere_mma_available(cc)) {\n        return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);\n    }\n    if (turing_mma_available(cc)) {\n        return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);\n    }\n    if (amd_mfma_available(cc)) {\n        return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);\n    }\n    if (amd_wmma_available(cc)) {\n        return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);\n    }\n    GGML_ASSERT(volta_mma_available(cc));\n    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);\n}\n\nstatic constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {\n#if 
defined(AMPERE_MMA_AVAILABLE)\n    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);\n#elif defined(TURING_MMA_AVAILABLE)\n    return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);\n#elif defined(AMD_MFMA_AVAILABLE)\n    return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);\n#elif defined(VOLTA_MMA_AVAILABLE)\n    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);\n#elif defined(AMD_WMMA_AVAILABLE)\n    return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);\n#else\n    GGML_UNUSED_VARS(DKQ, DV, ncols);\n    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);\n#endif // defined(AMPERE_MMA_AVAILABLE)\n}\n\nstatic __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;\n}\n\nstatic __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;\n}\n\nstatic __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;\n}\n\nstatic __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, 
cc).nbatch_K2;\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;\n}\n\nstatic __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;\n}\n\nstatic __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;\n}\n\nstatic __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;\n}\n\nstatic __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;\n}\n\nstatic constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {\n    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;\n}\n\nstatic constexpr __device__ int get_cols_per_thread() {\n#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n    return 1; // AMD has a single column per thread.\n#else\n    return 2; // This is 
specifically KQ columns, Volta only has a single VKQ column.\n#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n}\n\nstatic __host__ int get_cols_per_warp(const int cc) {\n    if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {\n        return 16;\n    } else {\n        // Volta\n        return 32;\n    }\n}\n\n// ------------------------------------------------------------------------------------------------------------------\n\nstatic __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {\n    return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {\n#ifdef CP_ASYNC_AVAILABLE\n    return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;\n#else\n    GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);\n    return 0;\n#endif // CP_ASYNC_AVAILABLE\n}\n\n// ------------------------------------------------------------------------------------------------------------------\n\ntemplate<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>\nstatic __device__ __forceinline__ void flash_attn_ext_f16_load_tile(\n        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.\n    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.\n    if constexpr (use_cp_async) {\n        static_assert(!oob_check, \"OOB check not compatible with cp_async\");\n        constexpr int preload = 64;\n        constexpr int h2_per_chunk = 16/sizeof(half2);\n        
const int chunks_per_row = D2 / h2_per_chunk;\n\n        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);\n\n        auto load = [&] __device__ (auto n) {\n            const int stride_k = warp_size >> n;\n            const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);\n            const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);\n            const int stride_i = warp_size / stride_k;\n\n            if (k0_start == k0_stop) {\n                return;\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {\n                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);\n\n                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {\n                    break;\n                }\n\n#pragma unroll\n                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {\n                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);\n\n                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);\n                }\n            }\n        };\n        // 1: max 32*16=512 bytes, 256 half\n        // 2: max 16*16=256 bytes, 128 half\n        // 3: max  8*16=128 bytes,  64 half\n        // 4: max  4*16= 64 bytes,  32 half\n        // 5: max  2*16= 32 bytes,  16 half\n        // 6: max  1*16= 16 bytes,   8 half\n        ggml_cuda_unroll<6>{}(load);\n    } else {\n        // TODO use ggml_cuda_memcpy_1\n        auto load = [&] __device__ (const int n) {\n            const int stride_k = warp_size >> n;\n            const int k0_start = stride_k == warp_size ? 
0 : D2 - D2 % (2*stride_k);\n            const int k0_stop  =                             D2 - D2 % (1*stride_k);\n            const int stride_i = warp_size / stride_k;\n\n            if (k0_start == k0_stop) {\n                return;\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {\n                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);\n\n                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {\n                    break;\n                }\n\n#pragma unroll\n                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {\n                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);\n\n                    tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);\n                }\n            }\n        };\n        // 1: max 32* 4=128 bytes,  64 half\n        // 2: max 16* 4= 64 bytes,  32 half\n        // 3: max  8* 4= 32 bytes,  16 half\n        // 4: max  4* 4= 16 bytes,   8 half\n        ggml_cuda_unroll<4>{}(load);\n    }\n}\n\ntemplate<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>\nstatic __device__ __forceinline__ void flash_attn_ext_f16_load_mask(\n        const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,\n        const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n    if constexpr (use_cp_async) {\n        static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, \"bad nbatch_fa\");\n        static_assert(!oob_check, \"OOB check incompatible with cp_async\");\n        constexpr int preload = nbatch_fa >= 32 ? 
nbatch_fa * sizeof(half) : 64;\n        constexpr int cols_per_warp = 8*warp_size/nbatch_fa;\n        constexpr int stride_j = nwarps * cols_per_warp;\n\n        const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);\n\n#pragma unroll\n        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {\n            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);\n            const int j_vram = fastmodulo(j0 + j_sram, ne01);\n\n            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {\n                break;\n            }\n\n            const int i = 8 * (threadIdx.x % (nbatch_fa/8));\n\n            cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);\n        }\n    } else if constexpr (oob_check) {\n#pragma unroll\n        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {\n            const int j_sram = j1 + threadIdx.y;\n            const int j_vram = fastmodulo(j0 + j_sram, ne01);\n\n            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {\n                break;\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? 
mask_h[j_vram*stride_mask + i] : half(0.0f);\n            }\n        }\n    } else if constexpr (nbatch_fa < 2*warp_size) {\n        constexpr int cols_per_warp = 2*warp_size/nbatch_fa;\n        constexpr int stride_j = nwarps * cols_per_warp;\n#pragma unroll\n        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {\n            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);\n            const int j_vram = fastmodulo(j0 + j_sram, ne01);\n\n            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {\n                break;\n            }\n\n            const int i = threadIdx.x % (warp_size/cols_per_warp);\n\n            ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);\n        }\n    } else {\n#pragma unroll\n        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {\n            const int j_sram = j1 + threadIdx.y;\n            const int j_vram = fastmodulo(j0 + j_sram, ne01);\n\n            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {\n                break;\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {\n                const int i = i0 + 2*threadIdx.x;\n\n                ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);\n            }\n        }\n    }\n}\n\ntemplate<int DKQ, int DV, int ncols1, int ncols2, int nwarps,\n    bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,\n    typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>\nstatic __device__ __forceinline__ void flash_attn_ext_f16_iter(\n        const float2 * const __restrict__ Q_f2,\n        const half2  * const __restrict__ K_h2,\n        const half2  * const __restrict__ V_h2,\n        const half   * const __restrict__ mask_h,\n        float2       * const __restrict__ dstk,\n   
     float2       * const __restrict__ dstk_fixup,\n        const float scale,\n        const float slope,\n        const float logit_softcap,\n        const uint3 ne01,\n        const int ne02,\n        const int stride_K,\n        const int stride_V,\n        const int stride_mask,\n        half2        * const __restrict__ tile_Q,\n        half2        * const __restrict__ tile_K,\n        half2        * const __restrict__ tile_V,\n        half         * const __restrict__ tile_mask,\n        T_B_KQ       * const __restrict__ Q_B,\n        T_C_VKQ      * const __restrict__ VKQ_C,\n        float        * const __restrict__ KQ_max,\n        float        * const __restrict__ KQ_rowsum,\n        const int jt,\n        const int kb0,\n        const int k_VKQ_sup) {\n#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)\n    constexpr int  warp_size       = ggml_cuda_get_physical_warp_size();\n    constexpr int  ncols           = ncols1 * ncols2;\n    constexpr int  cols_per_warp   = T_B_KQ::I;\n    constexpr int  cols_per_thread = get_cols_per_thread();\n    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.\n    constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);\n    constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);\n    constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);\n    constexpr bool Q_in_reg        = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);\n    constexpr int  nstages         = ggml_cuda_fattn_mma_get_nstages  (DKQ, DV, ncols1, ncols2);\n\n    constexpr int stride_tile_Q = DKQ/2     + 4;\n    constexpr int stride_tile_K = nbatch_K2 + 4;\n\n    constexpr int stride_tile_V = V_is_K_view ? 
stride_tile_K : nbatch_V2 + 4;\n\n    const int k_VKQ_0 = kb0 * nbatch_fa;\n#if defined(TURING_MMA_AVAILABLE)\n    T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];\n#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];\n#else // Volta\n    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];\n#endif // defined(TURING_MMA_AVAILABLE)\n\n    if constexpr (nstages > 1) {\n        static_assert(!oob_check, \"OOB check incompatible with multi-stage pipeline\");\n        static_assert(!V_is_K_view, \"K data reuse not implemented multi-stage loading\");\n        static_assert(nbatch_K2 == DKQ/2, \"batching not implemented for multi stage loading\");\n        constexpr bool use_cp_async = true;\n        cp_async_wait_all();\n        __syncthreads();\n        flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>\n            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);\n    } else {\n        constexpr bool use_cp_async = nstages == 1;\n        if (ncols2 > 1 || mask_h) {\n            flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>\n                (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);\n        }\n    }\n\n    // For MLA K and V have the same data.\n    // Therefore, iterate over K in reverse and later re-use the data if possible.\n#pragma unroll\n    for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {\n        const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? 
k0_start + nbatch_K2 : DKQ/2;\n        const int k0_diff = k0_stop - k0_start;\n\n        if constexpr (nstages <= 1) {\n            constexpr bool use_cp_async = nstages == 1;\n            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>\n                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);\n            if (use_cp_async) {\n                cp_async_wait_all();\n            }\n            __syncthreads();\n        }\n\n        // Calculate tile of KQ:\n        if constexpr (Q_in_reg) {\n#pragma unroll\n            for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {\n                const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;\n#pragma unroll\n                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {\n                    T_A_KQ K_A;\n                    load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);\n                    if constexpr (cols_per_warp == 8) {\n                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);\n                    } else {\n                        // Wide version of KQ_C is column-major\n#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                        // AMD matrix C is column-major.\n                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);\n#else\n                        // swap A and B for CUDA.\n                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);\n#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    }\n                }\n            }\n        } else {\n#pragma unroll\n            for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {\n                load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);\n\n#pragma unroll\n                for (int 
i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {\n                    const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;\n\n                    T_A_KQ K_A;\n                    load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);\n\n                    if constexpr (cols_per_warp == 8) {\n                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);\n                    } else {\n                        // Wide version of KQ_C is column-major\n#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                        // AMD matrix C is column-major.\n                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);\n#else\n                        // swap A and B for CUDA.\n                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);\n#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    }\n                }\n            }\n        }\n\n        if constexpr (nstages <= 1) {\n            __syncthreads(); // Only needed if tile_K == tile_V.\n        }\n    }\n\n    if (use_logit_softcap) {\n        constexpr int stride = cols_per_warp == 8 ? 
np*T_C_KQ::I : np*T_C_KQ::J;\n        static_assert(nbatch_fa % stride == 0, \"bad loop size\");\n#pragma unroll\n        for (int i = 0; i < nbatch_fa/stride; ++i) {\n#pragma unroll\n            for (int l = 0; l < T_C_KQ::ne; ++l) {\n                KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);\n            }\n        }\n    }\n\n    float KQ_max_new[cols_per_thread];\n#pragma unroll\n    for (int col = 0; col < cols_per_thread; ++col) {\n        KQ_max_new[col] = KQ_max[col];\n    }\n    float KQ_rowsum_add[cols_per_thread] = {0.0f};\n\n    if constexpr (cols_per_warp == 8) {\n        if (ncols2 > 1 || mask_h) {\n#pragma unroll\n            for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {\n                const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;\n#pragma unroll\n                for (int l = 0; l < T_C_KQ::ne; ++l) {\n                    const int i = i0 + T_C_KQ::get_i(l);\n                    const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;\n\n                    KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);\n                }\n            }\n        }\n\n        // Calculate softmax for each KQ column using the current max. 
value.\n        // The divisor is stored in KQ_rowsum and will be applied at the end.\n        static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, \"bad loop size\");\n#pragma unroll\n        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {\n#pragma unroll\n            for (int l = 0; l < T_C_KQ::ne; ++l) {\n                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {\n#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    constexpr int KQ_idx = 0;\n#else\n                    // Turing + Volta:\n                    const int KQ_idx = l % 2;\n#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);\n                }\n            }\n        }\n\n        // Values per KQ column are spread across 8 threads:\n#pragma unroll\n        for (int col = 0; col < cols_per_thread; ++col) {\n#pragma unroll\n            for (int offset = 16; offset >= 4; offset >>= 1) {\n                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));\n            }\n        }\n\n        static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, \"bad loop size\");\n#pragma unroll\n        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {\n#pragma unroll\n            for (int l = 0; l < T_C_KQ::ne; ++l) {\n                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {\n#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    constexpr int KQ_idx = 0;\n#else\n                    // Turing + Volta:\n                    const int KQ_idx = l % 2;\n#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);\n                    KQ_rowsum_add[KQ_idx] += 
KQ_C[k0/(np*T_C_KQ::I)].x[l];\n                } else {\n                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;\n                }\n            }\n        }\n    } else { // not Turing mma or T_B_KQ::I > 8\n        if (ncols2 > 1 || mask_h) {\n#pragma unroll\n            for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {\n                const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;\n#pragma unroll\n                for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {\n                    const int i = (i0 + T_C_KQ::get_j(l0)) / 2;\n                    const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;\n\n                    const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);\n                    KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;\n                    KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;\n                }\n            }\n        }\n\n        // Calculate softmax for each KQ column using the current max. 
value.\n        // The divisor is stored in KQ_rowsum and will be applied at the end.\n        static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, \"bad loop size\");\n#pragma unroll\n        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {\n#pragma unroll\n            for (int l = 0; l < T_C_KQ::ne; ++l) {\n                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {\n#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    constexpr int KQ_idx = 0;\n#else\n                    // Turing + Volta:\n                    const int KQ_idx = (l/2) % 2;\n#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);\n                }\n            }\n        }\n\n#pragma unroll\n        for (int col = 0; col < cols_per_thread; ++col) {\n#if defined(TURING_MMA_AVAILABLE)\n            // Values per KQ column are spread across 4 threads:\n            constexpr int offset_first = 2;\n            constexpr int offset_last  = 1;\n#elif defined(AMD_MFMA_AVAILABLE)\n            // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).\n            constexpr int offset_first = 32;\n            constexpr int offset_last  = 16;\n#elif defined(AMD_WMMA_AVAILABLE)\n            // Values per KQ column are spread across 2 threads:\n            constexpr int offset_first = 16;\n            constexpr int offset_last  = 16;\n#else // Volta\n            // Values per KQ column are spread across 2 threads:\n            constexpr int offset_first = 2;\n            constexpr int offset_last  = 2;\n#endif // defined(TURING_MMA_AVAILABLE)\n#pragma unroll\n            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {\n                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));\n            }\n        
}\n\n        static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, \"bad loop size\");\n#pragma unroll\n        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {\n#pragma unroll\n            for (int l = 0; l < T_C_KQ::ne; ++l) {\n                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {\n#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    constexpr int KQ_idx = 0;\n#else\n                    // Turing + Volta:\n                    const int KQ_idx = (l/2) % 2;\n#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);\n                    KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];\n                } else {\n                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;\n                }\n            }\n        }\n    }\n\n    {\n        float KQ_max_scale[cols_per_thread];\n#pragma unroll\n        for (int col = 0; col < cols_per_thread; ++col) {\n            const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];\n            KQ_max_scale[col] = expf(KQ_max_diff);\n            KQ_max[col] = KQ_max_new[col];\n\n            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;\n\n            // Scale previous KQ_rowsum to account for a potential increase in KQ_max:\n            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];\n        }\n\n#if defined(TURING_MMA_AVAILABLE)\n        if constexpr (cols_per_warp == 8) {\n            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);\n#pragma unroll\n            for (int i = 0; i < DV/T_C_VKQ::I; ++i) {\n#pragma unroll\n                for (int l = 0; l < T_C_VKQ::ne; ++l) {\n                    VKQ_C[i].x[l] *= KQ_max_scale_h2;\n                }\n            }\n        } else {\n#pragma unroll\n            for (int col = 0; 
col < cols_per_thread; ++col) {\n                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);\n#pragma unroll\n                for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {\n#pragma unroll\n                    for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {\n                        VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;\n                    }\n                }\n            }\n        }\n#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n        const half2 KQ_max_scale_h2 = make_half2(\n            KQ_max_scale[0], KQ_max_scale[0]);\n#pragma unroll\n        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {\n#pragma unroll\n            for (int l = 0; l < T_C_VKQ::ne; ++l) {\n                VKQ_C[i].x[l] *= KQ_max_scale_h2;\n            }\n        }\n#else // Volta\n        const half2 KQ_max_scale_h2 = make_half2(\n            KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);\n#pragma unroll\n        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {\n#pragma unroll\n            for (int l = 0; l < T_C_VKQ::ne; ++l) {\n                VKQ_C[i].x[l] *= KQ_max_scale_h2;\n            }\n        }\n#endif // defined(TURING_MMA_AVAILABLE)\n    }\n\n    // Convert KQ C tiles into B tiles for VKQ calculation:\n    T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];\n    static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, \"bad loop size\");\n    if constexpr (cols_per_warp == 8) {\n#pragma unroll\n        for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {\n            B[k] = get_transposed(get_half2(KQ_C[k]));\n        }\n    } else {\n        for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {\n            B[k] = get_half2(KQ_C[k]);\n        }\n    }\n\n    if constexpr (nstages > 1) {\n        static_assert(!V_is_K_view, \"K data reuse not implemented multi-stage loading\");\n        // Preload K tile for next iteration:\n        constexpr bool use_cp_async = true;\n        cp_async_wait_all();\n        
__syncthreads();\n        if (!last_iter) {\n            if (ncols2 > 1 || mask_h) {\n                flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>\n                    (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);\n            }\n            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>\n                (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);\n        }\n    }\n\n\n#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)\n    T_A_VKQ A_identity;\n    make_identity_mat(A_identity);\n#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)\n\n    // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:\n#pragma unroll\n    for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {\n        static_assert(DV % (2*nbatch_V2) == 0, \"bad loop size\");\n        const int i0_stop = i0_start + 2*nbatch_V2;\n        const int i0_diff = i0_stop - i0_start;\n\n        if constexpr (nstages <= 1) {\n            if (!V_is_K_view || i0_stop > 2*nbatch_K2) {\n                constexpr bool use_cp_async = nstages == 1;\n                flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>\n                    (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);\n                if (use_cp_async) {\n                    cp_async_wait_all();\n                }\n                __syncthreads();\n            }\n        }\n        const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;\n\n#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n        constexpr int i0_stride = cols_per_warp == 8 ? 
T_C_VKQ::I : 2*T_C_VKQ::J;\n#pragma unroll\n        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {\n            static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, \"bad loop size\");\n#pragma unroll\n            for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {\n                const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;\n\n                T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.\n#if defined(LDMATRIX_TRANS_AVAILABLE)\n                load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);\n#elif defined(AMD_MFMA_AVAILABLE)\n                // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].\n                // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.\n                // Load with transposed addressing: 4 strided half loads.\n                {\n                    const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;\n                    const half * xs0_h = (const half *) xs0;\n                    const int stride_h = stride_tile_V * 2; // stride in half units\n                    half * A_h = (half *) A.x;\n#pragma unroll\n                    for (int l = 0; l < 4; ++l) {\n                        A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];\n                    }\n                }\n#else\n                // TODO: Try to transpose tile_V when loading gmem to smem.\n                // Use mma to transpose T_A_VKQ for RDNA.\n                T_A_VKQ A_trans;\n                load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);\n                mma(A, A_trans, A_identity);\n#endif // defined(LDMATRIX_TRANS_AVAILABLE)\n                if constexpr (T_B_KQ::I == 8) {\n                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);\n                } else {\n                    // Wide version of VKQ_C is 
column-major.\n#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                    // AMD matrix C is column-major.\n                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);\n#else\n                    // swap A and B for CUDA.\n                    mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);\n#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n                }\n            }\n        }\n#else // Volta\n        constexpr int i0_stride = 2*T_C_VKQ::J;\n#pragma unroll\n        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {\n            static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, \"bad loop size\");\n            static_assert(2*T_B_VKQ::J == T_A_VKQ::I, \"bad tile sizes\");\n#pragma unroll\n            for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {\n                const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;\n\n                T_A_VKQ A; // Transposed in both SRAM and registers, load normally.\n                load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);\n                mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);\n            }\n        }\n#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n\n        if constexpr (nstages <= 1) {\n            __syncthreads(); // Only needed if tile_K == tile_V.\n        }\n    }\n#else\n    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,\n        scale, slope, logit_softcap, ne01, ne02,\n        stride_K, stride_V, stride_mask,\n        tile_Q, tile_K, tile_V, tile_mask,\n        Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);\n    NO_DEVICE_CODE;\n#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)\n}\n\n#if defined(TURING_MMA_AVAILABLE)\ntemplate<int ncols> struct mma_tile_sizes {\n    using T_A_KQ  = 
tile<16,  8, half2>; // row-major\n    using T_B_KQ  = tile<16,  8, half2>; // column-major\n    using T_C_KQ  = tile<16, 16, float>; // column-major\n    using T_A_VKQ = tile<16,  8, half2>; // row-major\n    using T_B_VKQ = tile<16,  8, half2>; // column-major\n    using T_C_VKQ = tile<16,  8, half2>; // column-major\n};\ntemplate<> struct mma_tile_sizes<8> {\n    using T_A_KQ  = tile<16,  8, half2>; // row-major\n    using T_B_KQ  = tile< 8,  8, half2>; // column-major\n    using T_C_KQ  = tile<16,  8, float>; // row-major\n    using T_A_VKQ = tile<16,  8, half2>; // row-major\n    using T_B_VKQ = tile< 8,  8, half2>; // column-major\n    using T_C_VKQ = tile<16,  4, half2>; // row-major\n};\n#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\ntemplate<int ncols> struct mma_tile_sizes {\n    using T_A_KQ  = tile<16,  8, half2>; // row-major\n    using T_B_KQ  = tile<16,  8, half2>; // column-major\n    using T_C_KQ  = tile<16, 16, float>; // column-major\n    using T_A_VKQ = tile<16,  8, half2>; // row-major\n    using T_B_VKQ = tile<16,  8, half2>; // column-major\n    using T_C_VKQ = tile<16,  8, half2>; // column-major\n};\n#else // Volta\ntemplate<int ncols> struct mma_tile_sizes {\n    using T_A_KQ  = tile< 8,  4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major\n    using T_B_KQ  = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major\n    using T_C_KQ  = tile<32,  8, float, DATA_LAYOUT_I_MAJOR>;          // column-major\n    using T_A_VKQ = tile< 8,  4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major\n    using T_B_VKQ = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major\n    using T_C_VKQ = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major\n};\n#endif // defined(TURING_MMA_AVAILABLE)\n\ntemplate<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>\nstatic __device__ __forceinline__ void 
flash_attn_ext_f16_process_tile(\n        const float2 * const __restrict__ Q_f2,\n        const half2  * const __restrict__ K_h2,\n        const half2  * const __restrict__ V_h2,\n        const half   * const __restrict__ mask_h,\n        const float  * const __restrict__ sinks_f,\n        float2       * const __restrict__ dstk,\n        float2       * const __restrict__ dstk_fixup,\n        const float scale,\n        const float slope,\n        const float logit_softcap,\n        const uint3 ne01,\n        const int ne02,\n        const int gqa_ratio,\n        const int ne11,\n        const int stride_Q1,\n        const int stride_Q2,\n        const int stride_K,\n        const int stride_V,\n        const int stride_mask,\n        const int jt,\n        const int zt_gqa,\n        const int kb0_start,\n        const int kb0_stop) {\n#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)\n    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.\n\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n    constexpr int ncols = ncols1 * ncols2;\n    using     T_A_KQ    = typename mma_tile_sizes<ncols>::T_A_KQ;\n    using     T_B_KQ    = typename mma_tile_sizes<ncols>::T_B_KQ;\n    using     T_C_KQ    = typename mma_tile_sizes<ncols>::T_C_KQ;\n    using     T_A_VKQ   = typename mma_tile_sizes<ncols>::T_A_VKQ;\n    using     T_B_VKQ   = typename mma_tile_sizes<ncols>::T_B_VKQ;\n    using     T_C_VKQ   = typename mma_tile_sizes<ncols>::T_C_VKQ;\n\n    constexpr int  cols_per_warp   = T_B_KQ::I;\n    constexpr int  cols_per_thread = get_cols_per_thread();\n    constexpr int  np              = cols_per_warp > ncols ? 
nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.\n    constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);\n    constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);\n    constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols);\n    constexpr int  nbatch_combine  = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);\n    constexpr bool Q_in_reg        = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols);\n    constexpr int  nstages         = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2);\n\n    if (cols_per_warp > ncols) {\n        NO_DEVICE_CODE;\n        return;\n    }\n\n    static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, \"bad nwarps\");\n\n    constexpr int stride_tile_Q = DKQ/2     + 4;\n    constexpr int stride_tile_K = nbatch_K2 + 4;\n\n    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;\n    constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;\n\n    extern __shared__ half2 tile_Q[];\n    half2 * tile_K    = Q_in_reg              ? tile_Q                             : tile_Q + ncols     * stride_tile_Q;\n    half2 * tile_V    =           nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;\n    half  * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);\n\n    T_B_KQ    Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];\n#if defined(TURING_MMA_AVAILABLE)\n    T_C_VKQ VKQ_C[cols_per_warp == 8 ? 
DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];\n#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n    T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];\n#else // Volta\n    T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];\n#endif // defined(TURING_MMA_AVAILABLE)\n\n    float KQ_rowsum[cols_per_thread] = {0.0f};\n    float KQ_max[cols_per_thread];\n#pragma unroll\n    for (int col = 0; col < cols_per_thread; ++col) {\n        KQ_max[col] = -FLT_MAX/2.0f;\n    }\n\n    // Load Q data into tile_Q, either temporarily or permanently.\n    // Q in registers is faster, but register pressure is the biggest bottleneck.\n    // The loading is done with decreasing granularity for D for better memory bandwidth.\n    const half2 scale_h2 = make_half2(scale, scale);\n#pragma unroll\n    for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {\n        const int k0_start  = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);\n        const int k0_stop   =                             DKQ/2 - (DKQ/2) % (1*stride_k);\n        const int stride_jc = warp_size / stride_k;\n\n        if (k0_start == k0_stop) {\n            continue;\n        }\n\n#pragma unroll\n        for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {\n            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);\n\n            if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {\n                break;\n            }\n\n            const int j = jc / ncols2;\n            const int c = jc % ncols2;\n\n            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {\n#pragma unroll\n                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {\n                    const int k = k0 + (stride_k == warp_size ? 
threadIdx.x : threadIdx.x % stride_k);\n\n                    const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];\n                    tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);\n                }\n            } else {\n#pragma unroll\n                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {\n                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);\n\n                    tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);\n                }\n            }\n        }\n    }\n\n    __syncthreads();\n\n    if (Q_in_reg) {\n        const int j0 = (threadIdx.y / np) * cols_per_warp;\n\n#pragma unroll\n        for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {\n            load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);\n        }\n    }\n\n    __syncthreads();\n\n    int kb0 = kb0_start;\n\n    // Preload mask and K data for first iteration when using cp_async with multiple stages:\n    if constexpr (nstages > 1) {\n        static_assert(nbatch_K2 == DKQ/2, \"batching not implemented for multi-stage pipeline\");\n        constexpr bool use_cp_async = true;\n        constexpr bool oob_check    = false;\n        constexpr int  k_VKQ_sup    = nbatch_fa;\n        if (ncols2 > 1 || mask_h) {\n            flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>\n                (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);\n        }\n        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>\n            (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);\n    }\n\n    // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.\n    if constexpr (ncols2 == 1) {\n        constexpr bool oob_check = true;\n        for (; kb0 < kb0_stop-1; ++kb0) {\n            constexpr bool last_iter = false;\n     
       constexpr int  k_VKQ_sup = nbatch_fa;\n            flash_attn_ext_f16_iter\n                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,\n                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>\n                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,\n                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,\n                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);\n        }\n        constexpr bool last_iter = true;\n        const     int  k_VKQ_sup = ne11 - kb0*nbatch_fa;\n        flash_attn_ext_f16_iter\n            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,\n              T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>\n            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,\n             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,\n             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);\n    } else {\n        constexpr bool oob_check = false;\n        for (; kb0 < kb0_stop-1; ++kb0) {\n            constexpr bool last_iter = false;\n            constexpr int  k_VKQ_sup = nbatch_fa;\n            flash_attn_ext_f16_iter\n                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,\n                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>\n                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,\n                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,\n                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);\n        }\n        constexpr bool last_iter = true;\n        constexpr int  k_VKQ_sup = nbatch_fa;\n        flash_attn_ext_f16_iter\n            <DKQ, DV, ncols1, ncols2, nwarps, 
use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,\n             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>\n            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,\n             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,\n             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);\n    }\n\n    // With multi-stage loading there is no __syncthreads at the end of the iter,\n    //     there can be a race condition on shared memory access for combining/writing back results.\n    if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {\n        __syncthreads();\n    }\n\n    // Finally, sum up partial KQ rowsums.\n    {\n#if defined(TURING_MMA_AVAILABLE)\n        // The partial sums are spread across 8/4 threads.\n        constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;\n        constexpr int offset_last  = cols_per_warp == 8 ?  4 : 1;\n#elif defined(AMD_MFMA_AVAILABLE)\n        // The partial sums are spread across 4 threads (wavefront64, 16 cols).\n        constexpr int offset_first = 32;\n        constexpr int offset_last  = 16;\n#elif defined(AMD_WMMA_AVAILABLE)\n        // The partial sums are spread across 2 threads.\n        constexpr int offset_first = 16;\n        constexpr int offset_last  = 16;\n#else // Volta\n        // The partial sums are spread across 2 threads.\n        constexpr int offset_first = 2;\n        constexpr int offset_last  = 2;\n#endif // defined(TURING_MMA_AVAILABLE)\n#pragma unroll\n        for (int col = 0; col < cols_per_thread; ++col) {\n#pragma unroll\n            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {\n                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);\n            }\n        }\n    }\n\n    // If attention sinks are used, potentially re-scale if KQ_max is small.\n    // Also add the sink as a value to KQ_rowsum, this is done after 
synchronization of KQ_rowsum\n    //     so it's being done unconditionally for every thread.\n    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {\n        float KQ_max_scale[cols_per_thread];\n#pragma unroll\n        for (int col = 0; col < cols_per_thread; ++col) {\n            const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);\n            const float sink = sinks_f[jc % ncols2];\n\n            const float KQ_max_new = fmaxf(KQ_max[col], sink);\n            const float KQ_max_diff = KQ_max[col] - KQ_max_new;\n            KQ_max_scale[col] = expf(KQ_max_diff);\n            KQ_max[col] = KQ_max_new;\n\n            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;\n\n            const float KQ_max_add = expf(sink - KQ_max_new);\n            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;\n        }\n\n#if defined(TURING_MMA_AVAILABLE)\n        if constexpr (cols_per_warp == 8) {\n            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);\n#pragma unroll\n            for (int i = 0; i < DV/T_C_VKQ::I; ++i) {\n#pragma unroll\n                for (int l = 0; l < T_C_VKQ::ne; ++l) {\n                    VKQ_C[i].x[l] *= KQ_max_scale_h2;\n                }\n            }\n        } else {\n#pragma unroll\n            for (int col = 0; col < cols_per_thread; ++col) {\n                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);\n#pragma unroll\n                for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {\n#pragma unroll\n                    for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {\n                        VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;\n                    }\n                }\n            }\n        }\n#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);\n#pragma unroll\n        for (int i = 
0; i < (DV/2)/T_C_VKQ::J; ++i) {\n#pragma unroll\n            for (int l = 0; l < T_C_VKQ::ne; ++l) {\n                VKQ_C[i].x[l] *= KQ_max_scale_h2;\n            }\n        }\n#else // Volta\n        const int col = (threadIdx.x / 2) % 2;\n        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);\n#pragma unroll\n        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {\n#pragma unroll\n            for (int l = 0; l < T_C_VKQ::ne; ++l) {\n                VKQ_C[i].x[l] *= KQ_max_scale_h2;\n            }\n        }\n#endif // defined(TURING_MMA_AVAILABLE)\n    }\n\n    // Combine VKQ accumulator values if np > 1.\n    // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.\n    // So also write VKQ accumulators to shared memory in column-major format if np == 1.\n\n    constexpr int tile_stride = nbatch_combine + 4;\n    static_assert((DV/2) % nbatch_combine == 0, \"bad nbatch_combine\");\n\n    if constexpr (cols_per_warp == 8) {\n        const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset\n        const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta\n        const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum\n\n        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {\n            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.\n            ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;\n        }\n\n        __syncthreads();\n\n        if (np == 1) {\n            // No combination is needed, the meta data can be directly written from registers to VRAM.\n            if (needs_fixup && threadIdx.x < T_B_KQ::I) {\n                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;\n                dstk_fixup_meta[jc_cwm] = 
KQ_cmr;\n            }\n            if (is_fixup && threadIdx.x < T_B_KQ::I) {\n                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;\n                dstk_fixup_meta[jc_cwm] = KQ_cmr;\n            }\n        }\n    } else {\n        // jc_cwm = jc combine write meta\n        // KQ_cmr = KQ combine max rowsum\n        // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.\n#if defined(TURING_MMA_AVAILABLE)\n        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);\n        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);\n        const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;\n#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);\n        const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);\n        const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;\n#else // Volta\n        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);\n        const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);\n        const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;\n#endif // defined(TURING_MMA_AVAILABLE)\n\n        if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {\n            ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;\n        }\n\n        __syncthreads();\n\n        if (np == 1) {\n            // No combination is needed, the meta data can be directly written from registers to VRAM.\n            if (needs_fixup && thread_should_write) {\n                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;\n                dstk_fixup_meta[jc_cwm] = KQ_cmr;\n            }\n            if (is_fixup && 
thread_should_write) {\n                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;\n                dstk_fixup_meta[jc_cwm] = KQ_cmr;\n            }\n        }\n    }\n\n    if (np > 1 && threadIdx.y % np == 0) {\n        // Combine the meta data for parallel warps via shared memory.\n        // Warps with threadIdx.y % np != 0 must NOT return early.\n        // All threads must return simultaneously to avoid race conditions with work on the next tile.\n\n        constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;\n\n        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);\n        float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;\n        float2 meta[nmeta];\n#pragma unroll\n        for (int imeta = 0; imeta < nmeta; ++imeta) {\n            meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];\n        }\n\n        float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.\n#pragma unroll\n        for (int imeta = 1; imeta < nmeta; ++imeta) {\n            KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);\n        }\n#pragma unroll\n        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {\n            if (offset < warp_size) {\n                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));\n            }\n        }\n\n        float KQ_cms[nmeta]; // KQ combine max scale per warp.\n#pragma unroll\n        for (int imeta = 0; imeta < nmeta; ++imeta) {\n            KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);\n        }\n\n        float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.\n#pragma unroll\n        for (int imeta = 1; imeta < nmeta; ++imeta) {\n            KQ_crs += KQ_cms[imeta]*meta[imeta].y;\n        }\n#pragma unroll\n        for (int offset = 
np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {\n            if (offset < warp_size) {\n                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);\n            }\n        }\n\n        __syncthreads();\n\n        // Write back combined meta data:\n#pragma unroll\n        for (int imeta = 0; imeta < nmeta; ++imeta) {\n            if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {\n                // Combined KQ max scale + rowsum.\n                meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);\n            }\n        }\n\n        // Combined KQ max + rowsum.\n        static_assert(cols_per_warp <= warp_size);\n        if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {\n            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;\n            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);\n        }\n        if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {\n            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;\n            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);\n        }\n    } else if (np > 1) {\n        // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.\n        // Therefore, all other warps also need to execute a __syncthreads().\n        // Otherwise the points at which warps synchronize with each other would become misaligned.\n        __syncthreads();\n    }\n\n#pragma unroll\n    for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {\n        if constexpr (cols_per_warp == 8) {\n            const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data\n#pragma unroll\n            for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {\n                const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // 
Conversion of C to B matrix puts it in column-major format.\n\n#pragma unroll\n                for (int l = 0; l < T_B_KQ::ne; ++l) {\n                    const int k = k1 + T_B_KQ::get_j(l);\n\n                    tile_Q[jc_cwd*tile_stride + k] = B.x[l];\n                }\n            }\n        } else {\n            const int j0 = threadIdx.y*cols_per_warp;\n#pragma unroll\n            for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {\n#pragma unroll\n                for (int l = 0; l < T_C_VKQ::ne; ++l) {\n                    const int j = j0 + T_C_VKQ::get_i(l);\n                    const int k = k1 + T_C_VKQ::get_j(l);\n\n                    tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];\n                }\n            }\n        }\n\n        __syncthreads();\n\n        if (np == 1 || threadIdx.y % np == 0) {\n            // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.\n            // The values after that are for the partial results of the individual blocks.\n            float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));\n\n#pragma unroll\n            for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {\n                const int k0_start  = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);\n                const int k0_stop   =                             nbatch_combine - nbatch_combine % (1*stride_k);\n                const int stride_jc = warp_size / stride_k;\n\n                if (k0_start == k0_stop) {\n                    continue;\n                }\n\n#pragma unroll\n                for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {\n                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 
0 : threadIdx.x / stride_k);\n\n                    if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {\n                        break;\n                    }\n\n                    const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;\n\n                    const int j_dst = jc_dst / ncols2;\n                    const int c_dst = jc_dst % ncols2;\n\n                    if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {\n                        continue;\n                    }\n\n                    const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;\n#pragma unroll\n                    for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {\n                        const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);\n\n                        float2 dstk_val = make_float2(0.0f, 0.0f);\n#pragma unroll\n                        for (int ip = 0; ip < np; ++ip) {\n                            const float KQ_crs = np == 1 ? 
1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];\n                            const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);\n                            dstk_val.x += dstk_val_add.x*KQ_crs;\n                            dstk_val.y += dstk_val_add.y*KQ_crs;\n                        }\n\n                        if (!needs_fixup && !is_fixup) {\n                            const float KQ_rowsum_j = meta_j[1];\n                            dstk_val.x /= KQ_rowsum_j;\n                            dstk_val.y /= KQ_rowsum_j;\n                        }\n\n                        if (is_fixup) {\n                            dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;\n                        } else {\n                            dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;\n                        }\n                    }\n                }\n            }\n        }\n        if (np > 1) {\n            __syncthreads();\n        }\n    }\n#else\n    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,\n        scale, slope, logit_softcap, ne01, ne02, gqa_ratio,\n        stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,\n        jt, kb0_start, kb0_stop);\n    NO_DEVICE_CODE;\n#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)\n}\n\ntemplate<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>\n__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))\nstatic __global__ void flash_attn_ext_f16(\n        const char * __restrict__ Q,\n        const char * __restrict__ K,\n        const char * __restrict__ V,\n        const char * __restrict__ mask,\n        const char * __restrict__ sinks,\n        const int  * __restrict__ KV_max,\n        float      * 
__restrict__ dst,\n        float2     * __restrict__ dst_meta,\n        const float scale,\n        const float max_bias,\n        const float m0,\n        const float m1,\n        const uint32_t n_head_log2,\n        const float logit_softcap,\n        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,\n                            const int32_t nb01, const int32_t nb02, const int32_t nb03,\n        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,\n                            const int32_t nb11, const int32_t nb12, const int64_t nb13,\n                            const int32_t nb21, const int32_t nb22, const int64_t nb23,\n                            const int32_t ne31, const int32_t ne32, const int32_t ne33,\n                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {\n#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))\n\n    // Skip unused kernel variants for faster compilation:\n    if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {\n        NO_DEVICE_CODE;\n        return;\n    }\n#ifdef VOLTA_MMA_AVAILABLE\n    if (ncols1*ncols2 < 32) {\n        NO_DEVICE_CODE;\n        return;\n    }\n#endif // VOLTA_MMA_AVAILABLE\n\n#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING\n    if (ncols1*ncols2 > 32) {\n        NO_DEVICE_CODE;\n        return;\n    }\n#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING\n\n#if defined(AMD_WMMA_AVAILABLE)\n    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {\n        NO_DEVICE_CODE;\n        return;\n    }\n#endif // defined(AMD_WMMA_AVAILABLE)\n\n#if defined(AMD_MFMA_AVAILABLE)\n    if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {\n        NO_DEVICE_CODE;\n        return;\n    }\n#endif // defined(AMD_MFMA_AVAILABLE)\n\n    constexpr int warp_size = 
ggml_cuda_get_physical_warp_size();\n    constexpr int ncols     = ncols1 * ncols2;\n    constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);\n    constexpr int nthreads  = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);\n    constexpr int nwarps    = nthreads / warp_size;\n\n    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.\n\n    const int stride_Q1   = nb01 / sizeof(float2);\n    const int stride_Q2   = nb02 / sizeof(float2);\n    const int stride_K    = nb11 / sizeof(half2);\n    const int stride_mask = nb31 / sizeof(half);\n\n    const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);\n\n    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;\n    const int iter_j     = (ne01.z    + (ncols1    - 1)) / ncols1;\n    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;\n\n    // kbc == k block continuous, current index in continuous ijk space.\n    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;\n    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;\n\n    // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.\n    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).\n    // In the most general case >2 seams can fall into the same tile.\n\n    // kb0 == k start index when in the output tile.\n    int kb0_start = kbc % iter_k;\n    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);\n\n    while (kbc < kbc_stop && kb0_stop == iter_k) {\n        // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index\n        const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);\n        const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);\n        const 
int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);\n        const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;\n\n        const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.\n\n        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);\n        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);\n        const half   * mask_h = ncols2 == 1 && !mask ? nullptr :\n            (const half *) (mask + nb33*(sequence % ne33));\n        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);\n\n        const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);\n        const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;\n\n        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;\n\n        if (KV_max) {\n            kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);\n        }\n        constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.\n        if (kb0_start == 0) {\n            constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.\n            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>\n                (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,\n                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);\n        } else {\n            constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.\n            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, 
V_is_K_view, needs_fixup, is_fixup>\n                (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,\n                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);\n        }\n\n        kbc += iter_k;\n        kbc -= kbc % iter_k;\n\n        kb0_start = 0;\n        kb0_stop  = min(iter_k, kbc_stop - kbc);\n    }\n\n    if (kbc >= kbc_stop) {\n        return;\n    }\n\n    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.\n    const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);\n    const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);\n    const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);\n    const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;\n\n    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.\n\n    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);\n    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);\n    const half   * mask_h = ncols2 == 1 && !mask ? nullptr :\n        (const half *) (mask + nb33*(sequence % ne33));\n    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);\n\n    const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);\n    const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;\n\n    const float slope = ncols2 == 1 ? 
get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;\n\n    if (KV_max) {\n        kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);\n    }\n\n    constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.\n    constexpr bool needs_fixup = false;\n    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>\n        (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,\n         ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);\n#else\n    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n        max_bias, m0, m1, n_head_log2, logit_softcap,\n        ne00, ne01, ne02, ne03,\n              nb01, nb02, nb03,\n        ne10, ne11, ne12, ne13,\n              nb11, nb12, nb13,\n              nb21, nb22, nb23,\n              ne31, ne32, ne33,\n              nb31, nb32, nb33);\n    NO_DEVICE_CODE;\n#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))\n}\n\ntemplate <int DKQ, int DV, int ncols1, int ncols2>\nvoid ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV = dst;\n    const int id = ggml_cuda_get_device();\n    const int cc = ggml_cuda_info().devices[id].cc;\n\n    constexpr int ncols = ncols1 * ncols2;\n\n    const int  nthreads       = ggml_cuda_fattn_mma_get_nthreads      (DKQ, DV, ncols, cc);\n    const int  nbatch_fa      = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols, cc);\n    const int  nbatch_K2      = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols, cc);\n    const int  nbatch_V2      = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols, cc);\n    const int  nbatch_combine = 
ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);\n    const bool Q_in_reg       = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols, cc);\n    const int  nstages        = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2, cc);\n\n    const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));\n    const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;\n    const int nwarps         = nthreads / warp_size_host;\n\n    constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu\n\n    const size_t nbytes_shared_KV_1stage = nbatch_fa            * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);\n    const size_t nbytes_shared_KV_2stage = nbatch_fa            *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);\n    const size_t nbytes_shared_Q         = ncols                * (DKQ/2 + 4)                             * sizeof(half2);\n    const size_t nbytes_shared_mask      = ncols1               * (nbatch_fa/2 + 4)                       * sizeof(half2);\n    const size_t nbytes_shared_combine   = nwarps*cols_per_warp * (nbatch_combine + 4)                    * sizeof(half2);\n\n    const size_t nbytes_shared_KV = nstages <= 1 ? 
nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;\n\n    const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?\n        std::max(nbytes_shared_Q,  nbytes_shared_KV + nbytes_shared_mask) :\n                 nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);\n\n    float logit_softcap;\n    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));\n\n#if defined(GGML_USE_HIP)\n    using fattn_kernel_ptr_t = const void*;\n#else\n    using fattn_kernel_ptr_t = fattn_kernel_t;\n#endif // defined(GGML_USE_HIP)\n    fattn_kernel_t fattn_kernel;\n    if (logit_softcap == 0.0f) {\n        constexpr bool use_logit_softcap = false;\n        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;\n\n#if !defined(GGML_USE_MUSA)\n        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};\n        if (!shared_memory_limit_raised[id]) {\n            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));\n            shared_memory_limit_raised[id] = true;\n        }\n#endif // !defined(GGML_USE_MUSA)\n    } else {\n        constexpr bool use_logit_softcap = true;\n        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;\n\n#if !defined(GGML_USE_MUSA)\n        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};\n        if (!shared_memory_limit_raised[id]) {\n            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));\n            shared_memory_limit_raised[id] = true;\n        }\n#endif // !defined(GGML_USE_MUSA)\n    }\n\n    launch_fattn<DV, ncols1, ncols2>\n        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);\n}\n\n\n#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, 
ncols1, ncols2)                          \\\n    template void ggml_cuda_flash_attn_ext_mma_f16_case                           \\\n    <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \\\n\n#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols)   \\\n    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1,  1); \\\n    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2,  2); \\\n    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4,  4); \\\n    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8,  8); \\\n    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \\\n\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,   8)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,   8)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,   8)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,   8)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,   8)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,   8)\n\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  16)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  16)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  16)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  16)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  16)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  16)\n\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  32)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  32)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  32)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  32)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  32)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  32)\n\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  64)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  64)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  64)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  64)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  64)\nDECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  64)\n\n// The number of viable configurations for Deepseek is very limited:\nextern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);\nextern DECL_FATTN_MMA_F16_CASE(576, 
512, 2, 16);\nextern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);\n\n// For GLM 4.7 Flash\nextern DECL_FATTN_MMA_F16_CASE(576, 512,  4,  4);\nextern DECL_FATTN_MMA_F16_CASE(576, 512,  8,  4);\nextern DECL_FATTN_MMA_F16_CASE(576, 512, 16,  4);\nextern DECL_FATTN_MMA_F16_CASE(576, 512,  1, 32);\nextern DECL_FATTN_MMA_F16_CASE(576, 512,  2, 32);\n"
  },
  {
    "path": "src/ggml-cuda/fattn-tile.cu",
    "content": "#include \"common.cuh\"\n#include \"fattn-tile.cuh\"\n#include \"fattn-wmma-f16.cuh\"\n\nvoid ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * K = dst->src[1];\n    const ggml_tensor * V = dst->src[2];\n    switch (K->ne[0]) {\n        case  40: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_cuda_flash_attn_ext_tile_case< 40,  40>(ctx, dst);\n        } break;\n        case  64: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_cuda_flash_attn_ext_tile_case< 64,  64>(ctx, dst);\n        } break;\n        case  72: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_cuda_flash_attn_ext_tile_case< 72,  72>(ctx, dst);\n        } break;\n        case  80: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_cuda_flash_attn_ext_tile_case< 80,  80>(ctx, dst);\n        } break;\n        case  96: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_cuda_flash_attn_ext_tile_case< 96,  96>(ctx, dst);\n        } break;\n        case 112: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);\n        } break;\n        case 128: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);\n        } break;\n        case 256: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);\n        } break;\n        case 576: {\n            GGML_ASSERT(V->ne[0] == 512);\n            ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);\n        } break;\n        default: {\n            GGML_ABORT(\"Unsupported head size\");\n        } break;\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/fattn-tile.cuh",
    "content": "#include \"common.cuh\"\n#include \"fattn-common.cuh\"\n#include \"fattn-wmma-f16.cuh\"\n\n// nbatch_fa == number of KQ rows to process per iteration\n// nbatch_K == number of K columns to load in parallel for KQ calculation\n\n// TODO optimize kernel parameters for FP16 NVIDIA (P100)\n// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112\n\n// The ROCm compiler cannot handle templating in __launch_bounds__.\n// As a workaround, define a macro to package the kernel parameters as uint32_t:\n#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \\\n    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                          \\\n        static_assert((nthreads)          <= 512, \"bad nthreads\");                                    \\\n        static_assert((occupancy)         <=   8, \"bad occupancy\");                                   \\\n        static_assert((nbatch_fa)         <= 256, \"bad nbatch_fa\");                                   \\\n        static_assert((nbatch_K)          <= 256, \"bad nbatch_K\");                                    \\\n        return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23);    \\\n    }                                                                                                 \\\n\nstatic constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) {\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  64,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  64,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  64,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  64,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  64,  40)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 
128, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  64,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  64,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  64,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  64,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  64,  72)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  64,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  64,  40)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  64,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  64,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  64,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  64,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  64,  48)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  64,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  64,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  64,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  64,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  64,  56)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)\n\n    
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)\n\n    return 0;\n}\n\nstatic constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) {\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2, 128, 3,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 3,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)\n    
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 128, 3,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 3,  32, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 3,  64, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3,  32, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  32, 256)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)\n\n    return 0;\n}\n\nstatic constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)\n    
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 3,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 2,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2, 128,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)\n    
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 256, 2, 128,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2,  64,  32)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)\n\n    return 0;\n}\n\nstatic constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)\n    
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 8,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4,  64, 8,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 5, 128,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 5, 128,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)\n    
GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 8,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 8,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 8,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64,  64)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8,  32,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6,  32, 256)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6,  32, 256)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)\n\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)\n    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)\n\n    return 0;\n}\n\nstatic __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {\n    if (GGML_CUDA_CC_IS_AMD(cc)) {\n        if (GGML_CUDA_CC_IS_RDNA(cc)) {\n            return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);\n        }\n        return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);\n    }\n    if (fast_fp16_available(cc)) {\n        return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);\n    }\n    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);\n}\n\nstatic constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, 
const int DV, const int ncols) {\n#ifdef GGML_USE_HIP\n#ifdef RDNA\n    return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);\n#else\n    return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);\n#endif // RDNA\n#else\n#ifdef FAST_FP16_AVAILABLE\n    return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);\n#else\n    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);\n#endif // FAST_FP16_AVAILABLE\n#endif // GGML_USE_HIP\n}\n\nstatic __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {\n    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {\n    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);\n}\n\nstatic __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {\n    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {\n    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);\n}\n\nstatic __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {\n    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {\n    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);\n}\n\nstatic __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {\n    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);\n}\n\nstatic constexpr __device__ int 
ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {\n    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);\n}\n\n// TODO: deduplicate with mma-f16\ntemplate<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>\nstatic __device__ __forceinline__ void flash_attn_tile_load_tile(\n        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {\n    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    auto load = [&] __device__ (const int n) {\n        const int stride_j = warp_size >> n;\n\n        if (stride_j == 0) {\n            return;\n        }\n\n        const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);\n        const int j0_stop  =                             ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);\n        const int stride_i = warp_size / stride_j;\n\n        if (j0_start == j0_stop) {\n            return;\n        }\n\n#pragma unroll\n        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {\n            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);\n\n            if (i0 + nwarps*stride_i <= I || i < I) {\n#pragma unroll\n                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {\n                    const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;\n\n                    const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};\n                    ggml_cuda_memcpy_1<cpy_nb>(\n                        tile_KV + i*(J/2 + J_padding) + j,\n                        !oob_check || i < i_sup ? 
KV + i*stride_KV + j : zero);\n                }\n            }\n        }\n    };\n    // 1: max 64*16=512 bytes, 512 half\n    // 2: max 32*16=512 bytes, 256 half\n    // 3: max 16*16=256 bytes, 128 half\n    // 4: max  8*16=128 bytes,  64 half\n    // 5: max  4*16= 64 bytes,  32 half\n    // 6: max  2*16= 32 bytes,  16 half\n    // 7: max  1*16= 16 bytes,   8 half\n    static_assert(J % 8 == 0, \"bad J\");\n    static_assert((J/2) % cpy_ne == 0, \"bad J\");\n    ggml_cuda_unroll<7>{}(load);\n}\n\ntemplate<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>\nstatic __device__ __forceinline__ void flash_attn_tile_load_tile(\n        const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {\n    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    auto load = [&] __device__ (const int n) {\n        const int stride_j = warp_size >> n;\n\n        if (stride_j == 0) {\n            return;\n        }\n\n        const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);\n        const int j0_stop  =                             (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);\n        const int stride_i = warp_size / stride_j;\n\n        if (j0_start == j0_stop) {\n            return;\n        }\n\n#pragma unroll\n        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {\n            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);\n\n            if (i0 + nwarps*stride_i <= I || i < I) {\n#pragma unroll\n                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {\n                    const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? 
threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);\n\n                    const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};\n                    __align__(16) half2 tmp_h2[cpy_ne/2];\n                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(\n                        tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);\n\n                    __align__(16) float2 tmp_f2[cpy_ne/2];\n#pragma unroll\n                    for (int l = 0; l < cpy_ne/2; ++l) {\n                        tmp_f2[l] = __half22float2(tmp_h2[l]);\n                    }\n                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);\n                }\n            }\n        }\n    };\n    // 1: max 32*16=512 bytes, 128 float\n    // 2: max 16*16=256 bytes,  64 float\n    // 3: max  8*16=128 bytes,  32 float\n    // 4: max  4*16= 64 bytes,  16 float\n    // 5: max  2*16= 32 bytes,   8 float\n    static_assert(J % 8 == 0, \"bad J\");\n    static_assert(J % cpy_ne == 0, \"bad J\");\n    ggml_cuda_unroll<5>{}(load);\n}\n\n// Function that performs a single iteration in for the KQ matrix multiplication:\ntemplate <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K,\n    bool use_logit_softcap, bool oob_check, typename T_vec_dot>\nstatic __device__ __forceinline__ void flash_attn_tile_iter_KQ(\n        T_vec_dot   * const Q_tmp,\n        const half2 * const __restrict__ K_h2,\n        T_vec_dot   * const KV_tmp,\n        const int stride_K2,\n        const int k_VKQ_0,\n        const int k_VKQ_sup,\n        const int k_KQ_0,\n        float * KQ_acc) {\n    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    constexpr int ncols = ncols1*ncols2;\n    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp\n    constexpr int np    = nwarps > ncols ? 
nwarps/ncols : 1; // number of parallel warps per Q column\n\n    flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>\n        (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);\n    __syncthreads();\n\n#ifdef FAST_FP16_AVAILABLE\n    static_assert((nbatch_K/2) % cpy_ne == 0, \"bad nbatch_K\");\n#pragma unroll\n    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {\n        __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];\n        __align__(16) half2 Q_k[cpw][cpy_ne];\n#else\n    static_assert(nbatch_K % cpy_ne == 0, \"bad nbatch_K\");\n#pragma unroll\n    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {\n        __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];\n        __align__(16) float Q_k[cpw][cpy_ne];\n#endif // FAST_FP16_AVAILABLE\n\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {\n            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;\n\n#ifdef FAST_FP16_AVAILABLE\n            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);\n#else\n            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K   + cpy_ne) + k_KQ_1]);\n#endif // FAST_FP16_AVAILABLE\n        }\n#pragma unroll\n        for (int jc0 = 0; jc0 < cpw; ++jc0) {\n            const int jc = jc0 + (threadIdx.y / np)*cpw;\n\n#ifdef FAST_FP16_AVAILABLE\n            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);\n#else\n            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ    + k_KQ_0   + k_KQ_1]);\n#endif // FAST_FP16_AVAILABLE\n        }\n\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {\n#pragma unroll\n            for (int jc0 = 0; jc0 < cpw; ++jc0) {\n#pragma unroll\n                for (int k = 0; k < cpy_ne; ++k) {\n                    
ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);\n                }\n            }\n        }\n    }\n\n    if (k_KQ_0 + nbatch_K < DKQ) {\n        __syncthreads(); // Sync not needed on last iteration.\n    }\n}\n\n// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.\ntemplate <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K,\n    bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc>\nstatic __device__ __forceinline__ void flash_attn_tile_iter(\n        T_vec_dot * const Q_tmp,\n        const half2 * const __restrict__ K_h2,\n        const half2 * const __restrict__ V_h2,\n        const half  * const __restrict__ mask,\n        const uint3 ne01,\n        const float logit_softcap,\n        const float slope,\n        T_KQ      * const KQ,\n        T_vec_dot * const KV_tmp,\n        const int stride_K2,\n        const int stride_V2,\n        const int stride_mask,\n        float * const KQ_max,\n        float * const KQ_sum,\n        T_acc * const VKQ,\n        const int k_VKQ_0,\n        const int k_VKQ_max,\n        const int col_Q_0) {\n    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    constexpr int ncols = ncols1*ncols2;\n    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp\n    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column\n\n    constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.\n\n    // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory.\n    // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs].\n#ifdef FAST_FP16_AVAILABLE\n    constexpr int KQ_cs = cpw < 2*cpy_ne ? 
cpw : 2*cpy_ne;\n#else\n    constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;\n#endif // FAST_FP16_AVAILABLE\n    static_assert(cpw % KQ_cs == 0, \"bad KQ_cs\");\n    const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data\n\n    float KQ_max_new[cpw];\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        KQ_max_new[jc0] = KQ_max[jc0];\n    }\n\n    float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.\n\n    // KQ = K @ Q matrix multiplication:\n    constexpr int nbatch_K_last = DKQ % nbatch_K;\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {\n        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(\n            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);\n    }\n    if (nbatch_K_last > 0) {\n        constexpr int k_KQ_0 = DKQ - nbatch_K_last;\n        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(\n            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);\n    }\n\n    // Apply logit softcap + mask, update KQ_max:\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);\n\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {\n            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;\n\n#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)\n            // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.\n            // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.\n            KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;\n#endif // 
defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)\n\n            if (use_logit_softcap) {\n                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);\n            }\n\n            if (!oob_check || i_KQ < k_VKQ_sup) {\n                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?\n                    slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;\n\n                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);\n            }\n        }\n\n        KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);\n    }\n\n    if constexpr (np == 1) {\n        __syncthreads();\n    } else {\n        static_assert(cpw == 1, \"bad cpw\");\n        __shared__ float KQ_max_new_shared[nwarps];\n        if (threadIdx.x == 0) {\n            KQ_max_new_shared[threadIdx.y] = KQ_max_new[0];\n        }\n        __syncthreads();\n        KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np];\n        KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);\n    }\n\n    // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {\n#ifdef FAST_FP16_AVAILABLE\n        __align__(16) half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];\n#else\n        __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];\n#endif // FAST_FP16_AVAILABLE\n\n#pragma unroll\n        for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {\n            const int jc = jc0 + jc1;\n\n            const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]);\n            KQ_max[jc] = KQ_max_new[jc];\n\n            float KQ_sum_add = 0.0f;\n#pragma unroll\n            for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {\n                const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) 
?\n                    expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;\n                KQ_sum_add += val;\n                tmp[i0/(np*warp_size)][jc1] = val;\n            }\n            KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;\n\n#ifdef FAST_FP16_AVAILABLE\n            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;\n            }\n#else\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;\n                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;\n            }\n#endif // FAST_FP16_AVAILABLE\n        }\n\n#pragma unroll\n        for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {\n            const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x;\n\n            ggml_cuda_memcpy_1<sizeof(tmp[0])>(\n                KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs,\n                tmp[i0/(np*warp_size)]);\n        }\n    }\n\n    // VKQ = V @ KQ matrix multiplication:\n    static_assert(DV <= DKQ, \"bad DV\");\n    static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), \"bad nbatch_K\");\n    constexpr int nbatch_V = (DV % nbatch_K == 0 ? 
nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.\n    static_assert(nbatch_fa % nbatch_V == 0, \"bad nbatch_V\");\n    static_assert(nbatch_V % np == 0, \"bad nbatch_V\");\n#pragma unroll\n    for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {\n        flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>\n            (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);\n        __syncthreads();\n\n#ifdef FAST_FP16_AVAILABLE\n#pragma unroll\n        for (int k1 = 0; k1 < nbatch_V; k1 += np) {\n            __align__(16) half2 V_k[(DVp/2)/warp_size];\n            __align__(16) half2 KQ_k[cpw];\n\n            constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {\n                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]);\n            }\n#pragma unroll\n            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {\n                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);\n\n                __align__(16) half tmp[KQ_cs];\n                ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(\n                    &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);\n#pragma unroll\n                for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {\n                    KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]);\n                }\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n#pragma unroll\n                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {\n                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0];\n                }\n            }\n        }\n#else\n#pragma unroll\n        for (int k1 = 0; k1 < nbatch_V; k1 += np) {\n          
  __align__(16) float2 V_k[(DVp/2)/warp_size];\n            __align__(16) float  KQ_k[cpw];\n\n            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {\n                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]);\n            }\n#pragma unroll\n            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {\n                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);\n\n                ggml_cuda_memcpy_1<KQ_cs*sizeof(float)>(\n                    &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n#pragma unroll\n                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {\n                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0];\n                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0];\n                }\n            }\n        }\n#endif // FAST_FP16_AVAILABLE\n\n        __syncthreads();\n    }\n}\n\ntemplate<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size\n__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2))\nstatic __global__ void flash_attn_tile(\n        const char * __restrict__ Q,\n        const char * __restrict__ K,\n        const char * __restrict__ V,\n        const char * __restrict__ mask,\n        const char * __restrict__ sinks,\n        const int  * __restrict__ KV_max,\n        float      * __restrict__ dst,\n        float2     * __restrict__ dst_meta,\n        const float scale,\n        const float max_bias,\n        const float m0,\n        const float m1,\n        const 
uint32_t n_head_log2,\n        const float logit_softcap,\n        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,\n                            const int32_t nb01, const int32_t nb02, const int32_t nb03,\n        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,\n                            const int32_t nb11, const int32_t nb12, const int64_t nb13,\n                            const int32_t nb21, const int32_t nb22, const int64_t nb23,\n                            const int32_t ne31, const int32_t ne32, const int32_t ne33,\n                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {\n#ifdef FLASH_ATTN_AVAILABLE\n\n    // Skip unused kernel variants for faster compilation:\n\n    if (\n#ifdef GGML_USE_WMMA_FATTN\n            (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||\n#endif // GGML_USE_WMMA_FATTN\n            (use_logit_softcap && !(DV == 128 || DV == 256))\n    ) {\n        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n            max_bias, m0, m1, n_head_log2, logit_softcap,\n            ne00, ne01, ne02, ne03,\n                  nb01, nb02, nb03,\n            ne10, ne11, ne12, ne13,\n                  nb11, nb12, nb13,\n                  nb21, nb22, nb23,\n                  ne31, ne32, ne33,\n                  nb31, nb32, nb33);\n        NO_DEVICE_CODE;\n        return;\n    }\n\n    static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, \"kernel config not defined\");\n\n    constexpr int ncols     = ncols1*ncols2;\n    constexpr int warp_size = 32;\n    constexpr int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;\n    constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);\n    constexpr int nbatch_K  = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);\n\n    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.\n\n    
const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on.\n\n    const int sequence = blockIdx.z / (ne02/ncols2);\n    const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)\n    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.\n    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0);\n    const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));\n    const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape\n\n    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;\n\n    const int stride_K2   = nb11 / sizeof(half2);\n    const int stride_V2   = nb21 / sizeof(half2);\n    const int stride_mask = nb31 / sizeof(half);\n\n    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;\n\n    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.\n    constexpr int np  = nwarps > ncols ? 
nwarps/ncols : 1; // Number of parallel warps per Q column.\n    static_assert(cpw == 1 || np == 1, \"bad cpw / np\");\n    static_assert(nbatch_fa % (np*warp_size) == 0, \"nbatch_fa % (np*warp_size) != 0\");\n\n    constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.\n    constexpr int DVp  = (DV  + 2*warp_size - 1) & ~(2*warp_size - 1); // DV  padded to multiple of 2*warp_size.\n\n    // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.\n    // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.\n    //     KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).\n    // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.\n    // VKQ == Accumulators in registers for the final VKQ result.\n#ifdef FAST_FP16_AVAILABLE\n    __shared__ half2 Q_tmp[ncols * DKQ/2];\n    __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];\n    __shared__ half  KQ[ncols * nbatch_fa];\n    __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};\n#else\n    __shared__ float Q_tmp[ncols * DKQ];\n    __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];\n    __shared__ float KQ[ncols * nbatch_fa];\n    __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};\n#endif // FAST_FP16_AVAILABLE\n\n    float KQ_max[cpw];\n#pragma unroll\n    for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;\n    }\n    float KQ_sum[cpw] = {0.0f};\n\n    // Load Q data, convert to FP16 if fast:\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        const int jc = jc0 + (threadIdx.y / np)*cpw;\n\n        const int j = jc / ncols2;\n        const int c = jc % ncols2;\n\n        constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? 
cpy_ne : DKQp/warp_size;\n\n#pragma unroll\n        for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {\n            if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {\n                __align__(16) float tmp_f[cpy_ne_D] = {0.0f};\n                ggml_cuda_memcpy_1<sizeof(tmp_f)>\n                    (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))\n                                 + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);\n\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {\n                    tmp_f[i1] *= scale;\n                }\n\n#ifdef FAST_FP16_AVAILABLE\n                __align__(16) half2 tmp_h2[cpy_ne_D/2];\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {\n                    tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);\n#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)\n                    // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.\n                    // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.\n                    tmp_h2[i1/2] *= make_half2(0.25f, 0.25f);\n#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)\n                }\n                ggml_cuda_memcpy_1<sizeof(tmp_h2)>(\n                    &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],\n                    tmp_h2);\n#else\n                ggml_cuda_memcpy_1<sizeof(tmp_f)>(\n                    &Q_tmp[jc* DKQ    + i0   + (threadIdx.y % np)*(warp_size*cpy_ne_D)   + threadIdx.x* cpy_ne_D],\n                    tmp_f);\n#endif // FAST_FP16_AVAILABLE\n            }\n        }\n    }\n\n    __syncthreads();\n\n    // Main loop over KV cache:\n    const int k_VKQ_max = KV_max ? 
KV_max[sequence*gridDim.x + blockIdx.x] : ne11;\n    if (ncols2 == 1) {\n        // Branch with out-of-bounds checks.\n        int k_VKQ_0 = blockIdx.y*nbatch_fa;\n        while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {\n            constexpr bool oob_check = false;\n            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>\n                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,\n                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);\n            k_VKQ_0 += gridDim.y*nbatch_fa;\n        }\n        if (k_VKQ_0 < k_VKQ_max) {\n            constexpr bool oob_check = true;\n            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>\n                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,\n                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);\n        }\n    } else {\n        // Branch without out-of-bounds checks.\n        for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {\n            constexpr bool oob_check = false;\n            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>\n                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,\n                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);\n        }\n    }\n\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);\n    }\n\n    if constexpr (np > 1) {\n        static_assert(cpw == 1, \"bad cpw\");\n        static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, \"KV_tmp too small\");\n\n#ifdef FAST_FP16_AVAILABLE\n        half2 * VKQ_combine    = (half2 *) KV_tmp;\n#else\n        float * VKQ_combine  
  = (float *) KV_tmp;\n#endif // FAST_FP16_AVAILABLE\n        float * KQ_sum_combine = (float *) Q_tmp;\n\n        if (threadIdx.y % np != 0) {\n#ifdef FAST_FP16_AVAILABLE\n            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {\n                ggml_cuda_memcpy_1<cpy_ne_D*4>(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]);\n            }\n#else\n            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {\n                ggml_cuda_memcpy_1<cpy_ne_D*4>(\n                    &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);\n            }\n#endif // FAST_FP16_AVAILABLE\n\n            if (threadIdx.x == 0) {\n                KQ_sum_combine[threadIdx.y] = KQ_sum[0];\n            }\n\n            return;\n        }\n\n        __syncthreads();\n\n#pragma unroll\n        for (int ip = 1; ip < np; ++ip) {\n#ifdef FAST_FP16_AVAILABLE\n            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {\n                __align__(16) half2 tmp[cpy_ne_D];\n                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {\n                    VKQ[i0/warp_size + i1] += tmp[i1];\n                }\n            }\n#else\n            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? 
cpy_ne : DVp/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {\n                __align__(16) float tmp[cpy_ne_D];\n                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {\n                    ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];\n                }\n            }\n#endif // FAST_FP16_AVAILABLE\n\n            KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip];\n        }\n    }\n\n    // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:\n    if (sinks && blockIdx.y == 0) {\n#pragma unroll\n        for (int jc0 = 0; jc0 < cpw; ++jc0) {\n            const int jc = jc0 + (threadIdx.y/np)*cpw;\n            const float sink = ((const float *) sinks)[head0 + jc % ncols2];\n\n            float KQ_max_new_j = fmaxf(KQ_max[jc0], sink);\n            const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j);\n            KQ_max[jc0] = KQ_max_new_j;\n\n            const float val = expf(sink - KQ_max[jc0]);\n            KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;\n\n#ifdef FAST_FP16_AVAILABLE\n            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;\n            }\n#else\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;\n                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;\n            }\n#endif // FAST_FP16_AVAILABLE\n        }\n    }\n\n    // Write back results:\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        const int jc = jc0 + (threadIdx.y/np)*cpw;\n\n        const int j = jc / ncols2;\n        const int c = jc % ncols2;\n\n   
     if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {\n            return;\n        }\n\n        const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;\n\n        const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;\n\n#ifdef FAST_FP16_AVAILABLE\n        constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;\n#pragma unroll\n        for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {\n            __align__(16) float2 tmp[cpy_ne_D];\n#pragma unroll\n            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {\n                tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);\n                tmp[i1].x *= scale;\n                tmp[i1].y *= scale;\n            }\n            if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {\n                ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);\n            }\n        }\n#else\n        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? 
cpy_ne : DVp/warp_size;\n#pragma unroll\n        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {\n            if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {\n                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;\n                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;\n                }\n                ggml_cuda_memcpy_1<cpy_ne_D*4>(\n                    &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],\n                    &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);\n            }\n        }\n#endif // FAST_FP16_AVAILABLE\n\n        if (gridDim.y != 1 && threadIdx.x == 0) {\n            dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);\n        }\n    }\n#else\n    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n        max_bias, m0, m1, n_head_log2, logit_softcap,\n        ne00, ne01, ne02, ne03,\n              nb01, nb02, nb03,\n        ne10, ne11, ne12, ne13,\n              nb11, nb12, nb13,\n              nb21, nb22, nb23,\n              ne31, ne32, ne33,\n              nb31, nb32, nb33);\n    NO_DEVICE_CODE;\n#endif // FLASH_ATTN_AVAILABLE\n}\n\ntemplate <int DKQ, int DV, int ncols2, bool use_logit_softcap>\nstatic void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * Q = dst->src[0];\n\n    const int id        = ggml_cuda_get_device();\n    const int cc        = ggml_cuda_info().devices[id].cc;\n    const int warp_size = 32;\n\n    constexpr size_t nbytes_shared = 0;\n\n#ifdef GGML_USE_HIP\n    if constexpr (DV <= 128) {\n        if (Q->ne[1] > 32/ncols2) {\n            constexpr int cols_per_block = 64;\n            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n            const int nbatch_fa = 
ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;\n            launch_fattn<DV, cols_per_block/ncols2, ncols2>\n                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);\n            return;\n        }\n    }\n#endif // GGML_USE_HIP\n\n#ifndef GGML_USE_HIP\n    if constexpr (DV <= 256)\n#endif // GGML_USE_HIP\n    {\n        if (Q->ne[1] > 16/ncols2) {\n            constexpr int cols_per_block = 32;\n            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;\n            launch_fattn<DV, cols_per_block/ncols2, ncols2>\n                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);\n            return;\n        }\n    }\n\n    if (Q->ne[1] > 8/ncols2) {\n        constexpr int cols_per_block = 16;\n        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;\n        launch_fattn<DV, cols_per_block/ncols2, ncols2>\n            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);\n        return;\n    }\n\n    if constexpr (ncols2 <= 8) {\n        if (Q->ne[1] > 4/ncols2) {\n            constexpr int cols_per_block = 8;\n            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n            const int nbatch_fa = 
ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;\n            launch_fattn<DV, cols_per_block/ncols2, ncols2>\n                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);\n            return;\n        }\n    }\n\n    if constexpr (ncols2 <= 4) {\n        if (Q->ne[1] > 2/ncols2) {\n            constexpr int cols_per_block = 4;\n            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;\n            launch_fattn<DV, cols_per_block/ncols2, ncols2>\n                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);\n            return;\n        }\n    }\n\n    if constexpr (ncols2 <= 2) {\n        constexpr int cols_per_block = 2;\n        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;\n        launch_fattn<DV, cols_per_block/ncols2, ncols2>\n            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);\n        return;\n    }\n\n    GGML_ABORT(\"fatal error\");\n}\n\ntemplate <int DKQ, int DV, bool use_logit_softcap>\nstatic void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV  = dst;\n    const ggml_tensor * Q    = dst->src[0];\n    const ggml_tensor * K    = dst->src[1];\n    const ggml_tensor * mask = dst->src[3];\n\n   
 float max_bias = 0.0f;\n    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));\n\n    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);\n    const int gqa_ratio = Q->ne[2] / K->ne[2];\n\n    // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.\n    // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.\n    const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);\n    const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;\n    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;\n\n    if constexpr (DV == 512) {\n        if (use_gqa_opt && gqa_ratio % 16 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);\n            return;\n        }\n        if (use_gqa_opt && gqa_ratio % 4 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);\n            return;\n        }\n    }\n\n    if constexpr (DV <= 256) {\n        if (use_gqa_opt && gqa_ratio % 8 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);\n            return;\n        }\n\n        if (use_gqa_opt && gqa_ratio % 4 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);\n            return;\n        }\n\n        if (use_gqa_opt && gqa_ratio % 2 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);\n            return;\n        }\n\n        launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);\n        return;\n    }\n    GGML_ABORT(\"fatal error\");\n}\n\ntemplate <int DKQ, int DV>\nvoid ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV = dst;\n\n    float logit_softcap;\n    memcpy(&logit_softcap, (const float 
*) KQV->op_params + 2, sizeof(float));\n\n    if (logit_softcap == 0.0f) {\n        constexpr bool use_logit_softcap = false;\n        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);\n    } else {\n        constexpr bool use_logit_softcap = true;\n        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);\n    }\n}\n\nvoid ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\n#define DECL_FATTN_TILE_CASE(DKQ, DV)                             \\\n    template void ggml_cuda_flash_attn_ext_tile_case              \\\n    <DKQ, DV>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \\\n\nextern DECL_FATTN_TILE_CASE( 40,  40);\nextern DECL_FATTN_TILE_CASE( 64,  64);\nextern DECL_FATTN_TILE_CASE( 72,  72);\nextern DECL_FATTN_TILE_CASE( 80,  80);\nextern DECL_FATTN_TILE_CASE( 96,  96);\nextern DECL_FATTN_TILE_CASE(112, 112);\nextern DECL_FATTN_TILE_CASE(128, 128);\nextern DECL_FATTN_TILE_CASE(256, 256);\nextern DECL_FATTN_TILE_CASE(576, 512);\n"
  },
  {
    "path": "src/ggml-cuda/fattn-vec.cuh",
    "content": "#include \"common.cuh\"\n#include \"fattn-common.cuh\"\n\nstatic int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {\n    return 128;\n    GGML_UNUSED(cc);\n}\n\nstatic constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {\n    return 128;\n}\n\n// Currently llvm with the amdgcn target does not support unrolling loops\n// that contain a break that can not be resolved at compile time.\n#ifdef __clang__\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wpass-failed\"\n#endif // __clang__\ntemplate<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size\n__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)\nstatic __global__ void flash_attn_ext_vec(\n        const char * __restrict__ Q,\n        const char * __restrict__ K,\n        const char * __restrict__ V,\n        const char * __restrict__ mask,\n        const char * __restrict__ sinks,\n        const int  * __restrict__ KV_max,\n        float      * __restrict__ dst,\n        float2     * __restrict__ dst_meta,\n        const float scale,\n        const float max_bias,\n        const float m0,\n        const float m1,\n        const uint32_t n_head_log2,\n        const float logit_softcap,\n        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,\n                            const int32_t nb01, const int32_t nb02, const int32_t nb03,\n        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,\n                            const int32_t nb11, const int32_t nb12, const int64_t nb13,\n                            const int32_t nb21, const int32_t nb22, const int64_t nb23,\n                            const int32_t ne31, const int32_t ne32, const int32_t ne33,\n                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {\n#ifdef FLASH_ATTN_AVAILABLE\n\n    // Skip unused kernel variants for faster compilation:\n    if 
(use_logit_softcap && !(D == 128 || D == 256)) {\n        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n            max_bias, m0, m1, n_head_log2, logit_softcap,\n            ne00, ne01, ne02, ne03,\n                  nb01, nb02, nb03,\n            ne10, ne11, ne12, ne13,\n                  nb11, nb12, nb13,\n                  nb21, nb22, nb23,\n                  ne31, ne32, ne33,\n                  nb31, nb32, nb33);\n        NO_DEVICE_CODE;\n        return;\n    }\n\n    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.\n\n    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n#ifdef GGML_USE_HIP\n#ifdef RDNA\n    constexpr int nthreads_KQ_q = 2;\n#else\n    constexpr int nthreads_KQ_q = 4;\n#endif // RDNA\n    constexpr int nthreads_V_q  = (D/4 < 32 ? D/4 : 32);\n#else\n    constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32);\n    constexpr int nthreads_V_q  = (D/4 < 32 ? D/4 : 32);\n#endif // GGML_USE_HIP\n\n    constexpr int nthreads    = ggml_cuda_fattn_vec_get_nthreads_device();\n    constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;\n    constexpr int nthreads_V  = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;\n\n    static_assert(WARP_SIZE % nthreads_KQ == 0, \"bad nthreads_K\");\n    static_assert(WARP_SIZE % nthreads_V  == 0, \"bad nthreads_V\");\n\n    constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 
2*cpy_ne : 4;\n    constexpr int V_cols_per_iter   = WARP_SIZE / nthreads_V;\n\n    constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();\n    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;\n#ifdef V_DOT2_F32_F16_AVAILABLE\n    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half,  V_rows_per_thread>();\n#else\n    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();\n#endif // V_DOT2_F32_F16_AVAILABLE\n\n    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.\n\n    const int sequence = blockIdx.z / ne02;\n    const int head = blockIdx.z - sequence*ne02;\n    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.\n    Q += nb03*sequence + nb02* head              + nb01*ic0;\n    K += nb13*sequence + nb12*(head / gqa_ratio);\n    V += nb23*sequence + nb22*(head / gqa_ratio);\n\n    const half * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);\n\n    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);\n\n    static_assert(D % (2*WARP_SIZE) == 0, \"D not divisible by 2*WARP_SIZE == 64.\");\n    constexpr int nwarps = nthreads / WARP_SIZE;\n    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;\n    __builtin_assume(tid < nthreads);\n\n    constexpr int ne_KQ      = ncols*D;\n    constexpr int ne_combine = nwarps*V_cols_per_iter*D;\n#ifdef V_DOT2_F32_F16_AVAILABLE\n    half2            VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};\n    __shared__ half   KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];\n#else\n    float2           VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};\n    __shared__ float  KQ[ne_KQ > ne_combine ? 
ne_KQ : ne_combine];\n#endif // V_DOT2_F32_F16_AVAILABLE\n\n    float KQ_max[ncols];\n    float KQ_sum[ncols];\n#pragma unroll\n    for (int j = 0; j < ncols; ++j) {\n        KQ_max[j] = -FLT_MAX/2.0f;\n        KQ_sum[j] = 0.0f;\n    }\n\n    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:\n#ifdef V_DOT2_F32_F16_AVAILABLE\n    half2  Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.\n#else\n    __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.\n#endif // V_DOT2_F32_F16_AVAILABLE\n    int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];\n    float2  Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];\n    if constexpr (Q_q8_1) {\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n            if (j0 + nwarps > ncols && j >= ncols) {\n                break;\n            }\n\n            // Reuse KQ as temporary storage for converting Q to q8_1:\n            int    * tmp_q_i32 = (int    *) &KQ[j*D];\n            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));\n\n            // Set memory to zero if out of bounds:\n            if (ncols > 1 && ic0 + j >= int(ne01.z)) {\n#pragma unroll\n                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {\n                    const int i = i0 + threadIdx.x;\n\n                    if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {\n                        tmp_q_i32[i] = 0;\n                    }\n                }\n                if (threadIdx.x < D/QK8_1) {\n                    tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);\n                }\n            } else {\n                const float * Q_f = (const float *) (Q + j*nb01);\n                constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? 
D/sizeof(int) : WARP_SIZE;\n#pragma unroll\n                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {\n                    quantize_q8_1_to_shared<float2, nthreads_quantize>\n                        (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);\n                }\n            }\n        }\n\n        __syncthreads();\n\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n            int    * tmp_q_i32 = (int    *) &KQ[j*D];\n            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));\n\n#pragma unroll\n            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {\n                const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ);\n\n                Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];\n                Q_ds[j][i0/nthreads_KQ]  = tmp_q_ds[i/QI8_1];\n            }\n        }\n\n        __syncthreads();\n    } else {\n#ifdef V_DOT2_F32_F16_AVAILABLE\n        const half2 scale_h2 = make_half2(scale, scale);\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n            const float2 * Q_j = (const float2 *) (Q + j*nb01);\n#pragma unroll\n            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {\n                const int i = i0 + (nthreads_KQ == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;\n\n                __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};\n                if (ncols == 1 || ic0 + j < int(ne01.z)) {\n                    ggml_cuda_memcpy_1<cpy_nb>(tmp,            &Q_j[i]);\n                    ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);\n                }\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne; ++i1) {\n                    Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y);\n                }\n            }\n#pragma unroll\n            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {\n                Q_reg[j][k] *= scale_h2;\n            }\n        }\n#else\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n            const float2 * Q_j = (const float2 *) (Q + j*nb01);\n#pragma unroll\n            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {\n                const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;\n                if (ncols == 1 || ic0 + j < int(ne01.z)) {\n                    ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ],            &Q_j[i]);\n                    ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);\n                }\n            }\n#pragma unroll\n            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {\n                Q_reg[j][k].x *= scale;\n                Q_reg[j][k].y *= scale;\n            }\n        }\n#endif // V_DOT2_F32_F16_AVAILABLE\n    }\n\n    const int k_VKQ_max = KV_max ? 
KV_max[sequence*gridDim.x + blockIdx.x] : ne11;\n    K     += blockIdx.y*nthreads * nb11;\n    V     += blockIdx.y*nthreads * nb21;\n    maskh += blockIdx.y*nthreads;\n    for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads,\n             // Increment pointers after each loop:\n             K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) {\n\n        // Calculate KQ tile and keep track of new maximum KQ values:\n        float KQ_reg[ncols]; // KQ in registers.\n\n        float KQ_max_new[ncols];\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n            KQ_max_new[j] = KQ_max[j];\n        }\n\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {\n            const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0;\n\n#pragma unroll\n            for (int j = 0; j < ncols; ++j) {\n                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);\n                sum = warp_reduce_sum<nthreads_KQ>(sum);\n\n                if (use_logit_softcap) {\n                    sum = logit_softcap*tanhf(sum);\n                }\n\n                if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {\n                    sum += slope*__half2float(maskh[j*ne11 + i_KQ]);\n                }\n\n                KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);\n\n                if ((nthreads_KQ == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {\n                    KQ_reg[j] = sum;\n                }\n            }\n        }\n\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n#pragma unroll\n            for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {\n                KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));\n            }\n            const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);\n            KQ_max[j] = KQ_max_new[j];\n\n            KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);\n            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];\n            KQ[j*nthreads + tid] = KQ_reg[j];\n\n#ifdef V_DOT2_F32_F16_AVAILABLE\n            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;\n            }\n#else\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n                VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;\n                VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;\n            }\n#endif // V_DOT2_F32_F16_AVAILABLE\n        }\n\n#ifndef GGML_USE_HIP\n        __syncwarp();\n#endif // GGML_USE_HIP\n\n#pragma unroll\n        for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {\n            const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 
0 : threadIdx.x / nthreads_V);\n\n#ifdef V_DOT2_F32_F16_AVAILABLE\n            half2 KQ_k[ncols];\n#pragma unroll\n            for (int j = 0; j < ncols; ++j) {\n                KQ_k[j] = __half2half2(KQ[j*nthreads + k]);\n            }\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {\n                half2 tmp[V_rows_per_thread/2];\n                dequantize_V(V + k*nb21, tmp,\n                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);\n#pragma unroll\n                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {\n#pragma unroll\n                    for (int j = 0; j < ncols; ++j) {\n                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];\n                    }\n                }\n            }\n#else\n            float KQ_k[ncols];\n#pragma unroll\n            for (int j = 0; j < ncols; ++j) {\n                KQ_k[j] = KQ[j*nthreads + k];\n            }\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {\n                float2 tmp[V_rows_per_thread/2];\n                dequantize_V(V + k*nb21, tmp,\n                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);\n#pragma unroll\n                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {\n#pragma unroll\n                    for (int j = 0; j < ncols; ++j) {\n                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];\n                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];\n                    }\n                }\n            }\n#endif // V_DOT2_F32_F16_AVAILABLE\n        }\n    }\n\n    if (sinks && blockIdx.y == 0) {\n        const float sink = ((const float *) sinks)[head];\n\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n            if (j0 + nwarps > ncols && j >= ncols) {\n                break;\n            }\n\n            const float kqmax_new_j = fmaxf(sink, KQ_max[j]);\n            const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);\n            KQ_max[j] = kqmax_new_j;\n\n            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? 
expf(sink - KQ_max[j]) : 0.0f);\n\n#ifdef V_DOT2_F32_F16_AVAILABLE\n            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;\n            }\n#else\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n                VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;\n                VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;\n            }\n#endif // V_DOT2_F32_F16_AVAILABLE\n        }\n    }\n\n    __shared__ float KQ_max_shared[ncols][WARP_SIZE];\n    __shared__ float KQ_sum_shared[ncols][WARP_SIZE];\n#pragma unroll\n    for (int j = 0; j < ncols; ++j) {\n        if (threadIdx.y == 0) {\n            KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;\n            KQ_sum_shared[j][threadIdx.x] = 0.0f;\n        }\n    }\n\n    __syncthreads();\n\n#pragma unroll\n    for (int j = 0; j < ncols; ++j) {\n        if (threadIdx.x == 0) {\n            KQ_max_shared[j][threadIdx.y] = KQ_max[j];\n        }\n    }\n    __syncthreads();\n\n#pragma unroll\n    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {\n        if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {\n            break;\n        }\n\n        float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];\n        kqmax_new = warp_reduce_max(kqmax_new);\n        const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);\n        KQ_max[j_VKQ] = kqmax_new;\n\n#ifdef V_DOT2_F32_F16_AVAILABLE\n        half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)\n            + (nthreads_V == WARP_SIZE ? 
0 : threadIdx.x / nthreads_V)*(D/2);\n\n        const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale);\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n            VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;\n        }\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {\n            const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);\n\n            ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);\n        }\n#else\n        float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)\n            + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);\n\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n            VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;\n            VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;\n        }\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {\n            const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);\n\n            ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ,                       &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);\n            ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);\n        }\n#endif // V_DOT2_F32_F16_AVAILABLE\n\n        KQ_sum[j_VKQ] *= kqmax_scale;\n        KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);\n        if (threadIdx.x == 0) {\n            KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];\n        }\n\n        __syncthreads();\n\n        if (nthreads <= D || tid < D) {\n            KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];\n            KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);\n\n#pragma unroll\n            for (int i0 = 0; i0 < D; i0 += nthreads) {\n                float dst_val = 0;\n#pragma unroll\n                for (int w = 0; w < nwarps; ++w) {\n#pragma unroll\n                    for (int v = 0; v < V_cols_per_iter; ++v) {\n                        dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);\n                    }\n                }\n                if (gridDim.y == 1) {\n                    dst_val /= KQ_sum[j_VKQ];\n                }\n                dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;\n            }\n        }\n\n        if (j_VKQ < ncols-1) {\n            __syncthreads();\n        }\n\n    }\n\n    if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {\n        dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);\n    }\n#else\n    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n        max_bias, m0, m1, n_head_log2, logit_softcap,\n        ne00, ne01, ne02, ne03,\n              nb01, nb02, nb03,\n        ne10, ne11, ne12, 
ne13,\n              nb11, nb12, nb13,\n              nb21, nb22, nb23,\n              ne31, ne32, ne33,\n              nb31, nb32, nb33);\n    NO_DEVICE_CODE;\n#endif // FLASH_ATTN_AVAILABLE\n}\n#ifdef __clang__\n#pragma clang diagnostic pop\n#endif // __clang__\n\ntemplate <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>\nvoid ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n\n    const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);\n    const int nwarps   = nthreads / WARP_SIZE;\n    fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;\n    const bool need_f16_K = type_K == GGML_TYPE_F16;\n    const bool need_f16_V = type_V == GGML_TYPE_F16;\n    constexpr size_t nbytes_shared = 0;\n    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);\n}\n\ntemplate <int D, ggml_type type_K, ggml_type type_V>\nvoid ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV = dst;\n    const ggml_tensor * Q   = dst->src[0];\n\n    float logit_softcap;\n    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));\n\n    if (Q->ne[1] == 1) {\n        constexpr int cols_per_block = 1;\n        if (logit_softcap == 0.0f) {\n            constexpr bool use_logit_softcap = false;\n            ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);\n        } else {\n            constexpr bool use_logit_softcap = true;\n            ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);\n        }\n        return;\n    }\n\n    constexpr int cols_per_block = 2;\n    if (logit_softcap == 0.0f) {\n        constexpr bool use_logit_softcap = 
false;\n        ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);\n    } else {\n        constexpr bool use_logit_softcap = true;\n        ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);\n    }\n}\n\n#define DECL_FATTN_VEC_CASE(D, type_K, type_V)                              \\\n    template void ggml_cuda_flash_attn_ext_vec_case                         \\\n    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \\\n\n#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K)             \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16);  \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \\\n\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)\n\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)\n\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)\n"
  },
  {
    "path": "src/ggml-cuda/fattn-wmma-f16.cu",
    "content": "// Old and deprecated WMMA FlashAttention implementation.\n// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.\n// Long-term the WMMA code should be replaced with a dedicated Volta implementation.\n\n#include \"common.cuh\"\n#include \"fattn-common.cuh\"\n#include \"fattn-wmma-f16.cuh\"\n\n#ifdef GGML_USE_WMMA_FATTN\n#if !defined(GGML_USE_HIP)\n#include <mma.h>\n#if defined(GGML_USE_MUSA)\nnamespace wmma = mtmusa::wmma;\n#else // GGML_USE_MUSA\nnamespace wmma = nvcuda::wmma;\n#endif // GGML_USE_MUSA\n#elif defined(GGML_USE_HIP)\n#include <rocwmma/rocwmma.hpp>\nnamespace wmma = rocwmma;\n#endif // !defined(GGML_USE_HIP)\n#endif // GGML_USE_WMMA_FATTN\n\n// D == head size, VKQ_stride == num VKQ rows calculated in parallel:\ntemplate<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>\n__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)\nstatic __global__ void flash_attn_ext_f16(\n        const char * __restrict__ Q,\n        const char * __restrict__ K,\n        const char * __restrict__ V,\n        const char * __restrict__ mask,\n        const char * __restrict__ sinks,\n        const int  * __restrict__ KV_max,\n        float      * __restrict__ dst,\n        float2     * __restrict__ dst_meta,\n        const float scale,\n        const float max_bias,\n        const float m0,\n        const float m1,\n        const uint32_t n_head_log2,\n        const float logit_softcap,\n        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,\n                            const int32_t nb01, const int32_t nb02, const int32_t nb03,\n        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,\n                            const int32_t nb11, const int32_t nb12, const int64_t nb13,\n                            const int32_t nb21, const int32_t nb22, const int64_t nb23,\n                            const int32_t 
ne31, const int32_t ne32, const int32_t ne33,\n                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {\n#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))\n    // Skip unused kernel variants for faster compilation:\n    if (use_logit_softcap && !(D == 128 || D == 256)) {\n        NO_DEVICE_CODE;\n        return;\n    }\n\n    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.\n\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.\n\n    static_assert(D <= FATTN_KQ_STRIDE, \"D must be <= FATTN_KQ_STRIDE.\");\n    static_assert(ncols == 8 || ncols % 16 == 0, \"ncols must be 8 or a multiple of 16.\");\n    constexpr int frag_m = ncols == 8 ? 32 : 16;\n    constexpr int frag_n = ncols == 8 ?  8 : 16;\n    static_assert(D % frag_m == 0, \"If ncols == 8 then D % frag_m must be 0.\");\n#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000\n    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;\n    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;\n    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;\n    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;\n    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16>                          frag_c_VKQ;\n#else\n    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;\n    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;\n    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;\n    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      
frag_c_KQ;\n    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;\n#endif\n\n    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.\n    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.\n    static_assert(VKQ_ratio <= nwarps, \"VKQ_ratio must be <= nwarps.\");\n\n    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:\n    constexpr int D_padded = D + 8;\n    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;\n    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);\n\n    const int sequence = blockIdx.z / ne02;\n    const int head = blockIdx.z - sequence*ne02;\n    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.\n    const float * Q_f    = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);\n    const half  * K_h    = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));\n    const half  * V_h    = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape\n    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);\n    const half2 * mask2  = (const half2 *)  maskh;\n    const float * sinksf = (const float *) sinks;\n\n    const int stride_Q  = nb01 / sizeof(float);\n    const int stride_KV = nb11 / sizeof(half);\n\n    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);\n    const half  slopeh = __float2half(slopef);\n    const half2 slope2 = make_half2(slopef, slopef);\n\n    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);\n\n    frag_b Q_b[D/16][ncols/frag_n];\n\n    // A single buffer for temporarily holding tiles of KQ and VKQ parts:\n    constexpr int mem_KQ = ncols*kqs_padded*kqar;\n    constexpr int mem_VKQ_parts 
= VKQ_ratio*ncols*D_padded;\n    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];\n    float * KQ_f = (float *) KQ;\n    half2 * KQ2 = (half2 *) KQ;\n\n    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};\n    float       KQ_max_f[ncols/nwarps];\n    float KQ_max_scale_f[ncols/nwarps] = {0.0f};\n\n#pragma unroll\n    for (int j = 0; j < ncols/nwarps; ++j) {\n        KQ_max_f[j] = -FLT_MAX/2.0f;\n    }\n\n    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};\n    half2       KQ_max_h2[ncols/nwarps];\n    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};\n\n#pragma unroll\n    for (int j = 0; j < ncols/nwarps; ++j) {\n        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);\n    }\n\n    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.\n    half2 * VKQ2 = (half2 *) VKQ;\n\n#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000\n    const _Float16 * K_h_f16  = reinterpret_cast<const _Float16 *>(K_h);\n    const _Float16 * V_h_f16  = reinterpret_cast<const _Float16 *>(V_h);\n    _Float16       * KQ_f16   = reinterpret_cast<_Float16 *>(KQ);\n    _Float16       * VKQ_f16  = reinterpret_cast<_Float16 *>(VKQ);\n#else\n    const half * K_h_f16  = K_h;\n    const half * V_h_f16  = V_h;\n    half       * KQ_f16   = KQ;\n    half       * VKQ_f16  = VKQ;\n#endif\n\n#pragma unroll\n    for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n        const int j = j0 + threadIdx.y;\n#pragma unroll\n        for (int i0 = 0; i0 < D/2; i0 += warp_size) {\n            const int i = i0 + threadIdx.x;\n            if (i0 + warp_size > D/2 && i >= D/2) {\n                break;\n            }\n            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);\n        }\n    }\n\n    // Convert Q to half and apply scale, temporarily store in KQ:\n#pragma unroll\n    for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n        const int j = j0 + threadIdx.y;\n#pragma unroll\n        for (int i0 = 0; i0 < D; i0 += warp_size) {\n            const int i = 
i0 + threadIdx.x;\n            if (i0 + warp_size > D && i >= D) {\n                break;\n            }\n            KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;\n        }\n    }\n\n    __syncthreads();\n\n    // Load Q into tensor core fragments/registers since it will be used frequently:\n#pragma unroll\n    for (int i0 = 0; i0 < D; i0 += 16) {\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += frag_n) {\n            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);\n        }\n    }\n\n    __syncthreads();\n\n    // Iterate over ne11 == previous tokens:\n    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;\n    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {\n        // Calculate tile of KQ:\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {\n            frag_c_KQ KQ_c[ncols/frag_n];\n#pragma unroll\n            for (int j = 0; j < ncols/frag_n; ++j) {\n                wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));\n            }\n#pragma unroll\n            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {\n                frag_a_K K_a;\n                wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);\n#pragma unroll\n                for (int j = 0; j < ncols/frag_n; ++j) {\n                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);\n                }\n            }\n#pragma unroll\n            for (int j0 = 0; j0 < ncols; j0 += frag_n) {\n                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);\n            }\n        }\n\n        __syncthreads();\n\n        // Calculate softmax for each KQ column using the current max. 
value.\n        // The divisor is stored in KQ_rowsum and will be applied at the end.\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n            if (std::is_same<KQ_acc_t, float>::value) {\n                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];\n#pragma unroll\n                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {\n                    const int k = k0 + threadIdx.x;\n\n                    KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];\n\n                    if (use_logit_softcap) {\n                        KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);\n                    }\n                }\n\n                float KQ_max_new = KQ_max_f[j0/nwarps];\n#pragma unroll\n                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {\n                    const int k = k0 + threadIdx.x;\n\n                    KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?\n                        __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;\n                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);\n                }\n                KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);\n\n                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;\n                KQ_max_scale_f[j0/nwarps] = expf(diff);\n                if (diff <= SOFTMAX_FTZ_THRESHOLD) {\n                    KQ_max_scale_f[j0/nwarps] = 0.0f;\n                }\n                KQ_max_f[j0/nwarps] = KQ_max_new;\n\n                float KQ_rowsum_add = 0.0f;\n#pragma unroll\n                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {\n                    const int k = k0 + threadIdx.x;\n\n                    const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];\n                    KQ_f_tmp[k0/warp_size] = expf(diff);\n                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {\n      
                  KQ_f_tmp[k0/warp_size] = 0.0f;\n                    }\n                    KQ_rowsum_add += KQ_f_tmp[k0/warp_size];\n                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];\n                }\n                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);\n\n                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:\n                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;\n            } else {\n                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];\n#pragma unroll\n                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {\n                    const int k = k0 + threadIdx.x;\n\n                    KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];\n\n                    if (use_logit_softcap) {\n                        // There is no dedicated tangens hyperbolicus function for half2.\n                        KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));\n                        KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))\n                                               /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));\n\n                        KQ2_tmp[k0/warp_size] *= logit_softcap_2;\n                    }\n                }\n\n                half2 KQ_max_new = KQ_max_h2[j0/nwarps];\n#pragma unroll\n                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {\n                    const int k = k0 + threadIdx.x;\n\n                    KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? 
slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);\n                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);\n                }\n                KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));\n                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;\n                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);\n                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));\n                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;\n                KQ_max_h2[j0/nwarps] = KQ_max_new;\n\n                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);\n#pragma unroll\n                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {\n                    const int k = k0 + threadIdx.x;\n\n                    const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];\n                    KQ2_tmp[k0/warp_size] = h2exp(diff);\n                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));\n                    *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;\n                    KQ_rowsum_add += KQ2_tmp[k0/warp_size];\n                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];\n                }\n                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);\n\n                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:\n                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;\n            }\n        }\n\n        __syncthreads();\n\n        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += frag_n) {\n#pragma unroll\n            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {\n                const int k = k0 + 
(threadIdx.y % VKQ_ratio)*16;\n                wmma::load_matrix_sync(\n                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],\n                    KQ_f16 + j0*(kqar*kqs_padded) + k,\n                    kqar*kqs_padded);\n            }\n        }\n\n        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {\n#pragma unroll\n            for (int j = 0; j < ncols/frag_n; ++j) {\n                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));\n            }\n\n#pragma unroll\n            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {\n                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;\n\n                frag_a_V v_a;\n                wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);\n#pragma unroll\n                for (int j = 0; j < ncols/frag_n; ++j) {\n                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);\n                }\n            }\n        }\n\n        __syncthreads();\n\n        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {\n#pragma unroll\n            for (int j0 = 0; j0 < ncols; j0 += frag_n) {\n                wmma::store_matrix_sync(\n                    KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),\n                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],\n                    D_padded, wmma::mem_col_major);\n            }\n        }\n\n        __syncthreads();\n\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n            half2 VKQ_scale;\n            if (std::is_same<KQ_acc_t, float>::value) {\n                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], 
KQ_max_scale_f[j0/nwarps]);\n            } else {\n                VKQ_scale = KQ_max_scale_h2[j0/nwarps];\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < D/2; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n                if (i0 + warp_size > D/2 && i >= D/2) {\n                    break;\n                }\n\n                half2 VKQ_add = make_half2(0.0f, 0.0f);\n#pragma unroll\n                for (int l = 0; l < VKQ_ratio; ++l) {\n                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];\n                }\n                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;\n            }\n        }\n\n        __syncthreads();\n    }\n\n    // Apply attention sinks\n    if (sinksf && blockIdx.y == 0) {\n        const float sinkf = sinksf[head];\n        const half  sinkh = __float2half(sinkf);\n\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n            if (std::is_same<KQ_acc_t, float>::value) {\n                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);\n\n                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);\n                KQ_max_f[j0/nwarps] = kqmax_new;\n\n                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);\n\n                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n                for (int i0 = 0; i0 < D/2; i0 += warp_size) {\n                    const int i = i0 + threadIdx.x;\n                    if (i0 + warp_size > D/2 && i >= D/2) break;\n                    VKQ2[j*(D_padded/2) + i] *= scale_h2;\n                }\n            } else {\n                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);\n                half kqmax_new = fmaxf(kqmax_old, sinkh);\n                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);\n\n                const half  
KQ_max_scale_h = hexp(kqmax_old - kqmax_new);\n                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);\n\n                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;\n                const half val = hexp(sinkh - kqmax_new);\n                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);\n\n#pragma unroll\n                for (int i0 = 0; i0 < D/2; i0 += warp_size) {\n                    const int i = i0 + threadIdx.x;\n                    if (i0 + warp_size > D/2 && i >= D/2) break;\n                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;\n                }\n            }\n        }\n\n        __syncthreads();\n    }\n#pragma unroll\n    for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n        const int j_VKQ = j0 + threadIdx.y;\n        if (ic0 + j_VKQ >= int(ne01.z)) {\n            return;\n        }\n\n        float KQ_rowsum_j;\n        if (std::is_same<KQ_acc_t, float>::value) {\n            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];\n        } else {\n            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);\n        }\n\n        const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;\n\n#pragma unroll\n        for (int i0 = 0; i0 < D; i0 += warp_size) {\n            const int i = i0 + threadIdx.x;\n            if (i0 + warp_size > D && i >= D) {\n                break;\n            }\n            float dst_val = VKQ[j_VKQ*D_padded + i];\n            if (gridDim.y == 1) {\n                dst_val /= KQ_rowsum_j;\n            }\n            dst[j_dst_unrolled*D + i] = dst_val;\n        }\n\n        if (gridDim.y == 1 || threadIdx.x != 0) {\n            continue;\n        }\n\n        float2 dst_meta_val;\n        if (std::is_same<KQ_acc_t, float>::value) {\n            dst_meta_val.x = KQ_max_f[j0/nwarps];\n        } else {\n            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);\n        }\n        
dst_meta_val.y = KQ_rowsum_j;\n        dst_meta[j_dst_unrolled] = dst_meta_val;\n    }\n#else\n    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n        max_bias, m0, m1, n_head_log2, logit_softcap,\n        ne00, ne01, ne02, ne03,\n              nb01, nb02, nb03,\n        ne10, ne11, ne12, ne13,\n              nb11, nb12, nb13,\n              nb21, nb22, nb23,\n              ne31, ne32, ne33,\n              nb31, nb32, nb33);\n    NO_DEVICE_CODE;\n#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))\n}\n\nconstexpr int get_max_power_of_2(int x) {\n    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;\n}\n\nstatic_assert(get_max_power_of_2(1) == 1, \"Test failed.\");\nstatic_assert(get_max_power_of_2(2) == 2, \"Test failed.\");\nstatic_assert(get_max_power_of_2(4) == 4, \"Test failed.\");\nstatic_assert(get_max_power_of_2(6) == 2, \"Test failed.\");\n\n// Number of VKQ rows calculated in parallel:\nconstexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {\n    return (get_max_power_of_2(D/frag_m) < nwarps ? 
get_max_power_of_2(D/frag_m) : nwarps)*frag_m;\n}\n\nstatic_assert(get_VKQ_stride(128, 1, 32) ==  32, \"Test failed.\");\nstatic_assert(get_VKQ_stride(128, 2, 32) ==  64, \"Test failed.\");\nstatic_assert(get_VKQ_stride(128, 4, 32) == 128, \"Test failed.\");\nstatic_assert(get_VKQ_stride( 64, 1, 32) ==  32, \"Test failed.\");\nstatic_assert(get_VKQ_stride( 64, 2, 32) ==  64, \"Test failed.\");\nstatic_assert(get_VKQ_stride( 64, 4, 32) ==  64, \"Test failed.\");\nstatic_assert(get_VKQ_stride( 80, 1, 16) ==  16, \"Test failed.\");\nstatic_assert(get_VKQ_stride( 80, 2, 16) ==  16, \"Test failed.\");\nstatic_assert(get_VKQ_stride( 80, 4, 16) ==  16, \"Test failed.\");\n\ntemplate <int D, int cols_per_block, typename KQ_acc_t>\nvoid ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV = dst;\n\n    constexpr int nwarps = 4;\n\n    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;\n    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;\n\n    float logit_softcap;\n    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));\n\n    fattn_kernel_t fattn_kernel;\n    if (logit_softcap == 0.0f) {\n        constexpr bool use_logit_softcap = false;\n        fattn_kernel = flash_attn_ext_f16<\n            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;\n    } else {\n        constexpr bool use_logit_softcap = true;\n        fattn_kernel = flash_attn_ext_f16<\n            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;\n    }\n    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);\n}\n\nvoid ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV = dst;\n    const ggml_tensor * Q   = dst->src[0];\n\n    const enum ggml_prec prec = 
ggml_flash_attn_ext_get_prec(KQV);\n    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;\n\n    if (prec != GGML_PREC_DEFAULT) {\n        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {\n            constexpr int cols_per_block = 16;\n            switch (Q->ne[0]) {\n                case 64:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);\n                    break;\n                case 80:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);\n                    break;\n                case 96:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);\n                    break;\n                case 112:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);\n                    break;\n                case 128:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);\n                    break;\n                case 256:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);\n                    break;\n                default:\n                    GGML_ABORT(\"fatal error\");\n                    break;\n            }\n        } else {\n            constexpr int cols_per_block = 32;\n            switch (Q->ne[0]) {\n                case 64:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);\n                    break;\n                case 80:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);\n                    break;\n                case 96:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);\n                    break;\n                case 112:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, 
dst);\n                    break;\n                case 128:\n                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);\n                    break;\n                // case 256:\n                //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);\n                //     break;\n                default:\n                    GGML_ABORT(\"fatal error\");\n                    break;\n            }\n        }\n        return;\n    }\n\n#if !defined(GGML_USE_HIP)\n    if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {\n        constexpr int cols_per_block = 8;\n        switch (Q->ne[0]) {\n            case 64:\n                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);\n                break;\n            case 96:\n                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);\n                break;\n            case 128:\n                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);\n                break;\n            case 256:\n                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);\n                break;\n            default:\n                GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return;\n    }\n#endif // !defined(GGML_USE_HIP)\n\n    if (Q->ne[1] <= 32) {\n        constexpr int cols_per_block = 16;\n        switch (Q->ne[0]) {\n            case 64:\n                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);\n                break;\n            case 80:\n                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);\n                break;\n            case 96:\n                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);\n                break;\n            case 112:\n                ggml_cuda_flash_attn_ext_wmma_f16_case<112, 
cols_per_block, half>(ctx, dst);\n                break;\n            case 128:\n                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);\n                break;\n            case 256:\n                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);\n                break;\n            default:\n                GGML_ABORT(\"fatal error\");\n                break;\n        }\n        return;\n    }\n\n    constexpr int cols_per_block = 32;\n    switch (Q->ne[0]) {\n        case 64:\n            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);\n            break;\n        case 80:\n            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);\n            break;\n        case 96:\n            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);\n            break;\n        case 112:\n            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);\n            break;\n        case 128:\n            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);\n            break;\n        case 256:\n            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);\n            break;\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/fattn-wmma-f16.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n#if defined(GGML_USE_MUSA)\n#define GGML_USE_WMMA_FATTN\n#endif // defined(GGML_USE_MUSA)\n\n#if defined(GGML_HIP_ROCWMMA_FATTN)\n#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)\n#define GGML_USE_WMMA_FATTN\n#elif defined(CDNA)\n#warning \"rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance\"\n#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)\n#if defined(RDNA3)\n#define GGML_USE_WMMA_FATTN\n#endif // defined(RDNA3)\n#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1\n#define GGML_USE_WMMA_FATTN\n#elif defined(RDNA4)\n#warning \"rocwmma fattn is not supported on RDNA4 on rocwmma < v2.0.0, expect degraded performance\"\n#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1\n#endif // defined(GGML_HIP_ROCWMMA_FATTN)\n\n// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.\nstatic bool ggml_cuda_should_use_wmma_fattn(const int cc) {\n#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)\n    return false;\n#else\n    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||\n        GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {\n        return true;\n    } else if (GGML_CUDA_CC_IS_CDNA(cc)){\n#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)\n        return true;\n#else\n        return false;\n#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)\n    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {\n#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1\n        return true;\n#else\n        return false;\n#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1\n    } else {\n        return false;\n    
}\n#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)\n}\n\nvoid ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/fattn.cu",
    "content": "#include \"common.cuh\"\n#include \"fattn-common.cuh\"\n#include \"fattn-mma-f16.cuh\"\n#include \"fattn-tile.cuh\"\n#include \"fattn-vec.cuh\"\n#include \"fattn-wmma-f16.cuh\"\n#include \"fattn.cuh\"\n\ntemplate <int DKQ, int DV, int ncols2>\nstatic void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n    const ggml_tensor * Q = dst->src[0];\n\n    if constexpr (ncols2 <= 8) {\n        if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {\n            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);\n            return;\n        }\n    }\n\n    if constexpr (ncols2 <= 16) {\n        if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {\n            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);\n            return;\n        }\n    }\n\n    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {\n        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);\n        return;\n    }\n\n    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);\n}\n\ntemplate <int DKQ, int DV>\nstatic void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n    const ggml_tensor * KQV  = dst;\n    const ggml_tensor * Q    = dst->src[0];\n    const ggml_tensor * K    = dst->src[1];\n    const ggml_tensor * V    = dst->src[2];\n    const ggml_tensor * mask = dst->src[3];\n\n    float max_bias = 0.0f;\n    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));\n\n    // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers\n    //     are put into the template specialization 
without GQA optimizations.\n    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;\n    for (const ggml_tensor * t : {Q, K, V, mask}) {\n        if (t == nullptr || ggml_is_quantized(t->type)) {\n            continue;\n        }\n        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {\n            if (t->nb[i] % 16 != 0) {\n                use_gqa_opt = false;\n                break;\n            }\n        }\n    }\n\n    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);\n    const int gqa_ratio = Q->ne[2] / K->ne[2];\n\n    // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:\n    if (cc == GGML_CUDA_CC_VOLTA) {\n        if (use_gqa_opt && gqa_ratio % 8 == 0) {\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);\n            return;\n        }\n\n        if (use_gqa_opt && gqa_ratio % 4 == 0) {\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);\n            return;\n        }\n\n        if (use_gqa_opt && gqa_ratio % 2 == 0) {\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);\n            return;\n        }\n\n        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);\n        return;\n    }\n\n    if (use_gqa_opt && gqa_ratio > 4) {\n        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);\n        return;\n    }\n\n    if (use_gqa_opt && gqa_ratio > 2) {\n        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);\n        return;\n    }\n\n    if (use_gqa_opt && gqa_ratio > 1) {\n        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);\n        return;\n    }\n\n    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);\n}\n\nstatic void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n    const ggml_tensor * KQV  = 
dst;\n    const ggml_tensor * Q    = dst->src[0];\n    const ggml_tensor * K    = dst->src[1];\n    const ggml_tensor * V    = dst->src[2];\n    const ggml_tensor * mask = dst->src[3];\n\n    switch (Q->ne[0]) {\n        case 64:\n            GGML_ASSERT(V->ne[0] == 64);\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64,  64>(ctx, dst);\n            break;\n        case 80:\n            GGML_ASSERT(V->ne[0] == 80);\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80,  80>(ctx, dst);\n            break;\n        case 96:\n            GGML_ASSERT(V->ne[0] == 96);\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96,  96>(ctx, dst);\n            break;\n        case 112:\n            GGML_ASSERT(V->ne[0] == 112);\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);\n            break;\n        case 128:\n            GGML_ASSERT(V->ne[0] == 128);\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);\n            break;\n        case 256:\n            GGML_ASSERT(V->ne[0] == 256);\n            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);\n            break;\n        case 576: {\n            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.\n            GGML_ASSERT(V->ne[0] == 512);\n            float max_bias = 0.0f;\n            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));\n\n            const bool use_gqa_opt = mask && max_bias == 0.0f;\n            GGML_ASSERT(use_gqa_opt);\n\n            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);\n            const int gqa_ratio = Q->ne[2] / K->ne[2];\n            if (gqa_ratio == 20) { // GLM 4.7 Flash\n                if (cc >= GGML_CUDA_CC_DGX_SPARK) {\n                    if (Q->ne[1] <= 8) {\n                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);\n                        break;\n                    }\n              
      ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);\n                    break;\n                }\n                if (cc >= GGML_CUDA_CC_BLACKWELL) {\n                    if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {\n                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);\n                        break;\n                    }\n                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);\n                    break;\n                }\n                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {\n                    if (Q->ne[1] <= 4) {\n                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);\n                        break;\n                    }\n                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);\n                    break;\n                }\n                if (cc >= GGML_CUDA_CC_TURING) {\n                    if (Q->ne[1] <= 4) {\n                        if (K->ne[1] <= 16384) {\n                            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);\n                            break;\n                        }\n                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);\n                        break;\n                    }\n                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);\n                    break;\n                }\n                // Volta:\n                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);\n            } else if (gqa_ratio % 16 == 0) {\n                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);\n            } else {\n                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512,  4>(ctx, dst);\n            }\n        } break;\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    
}\n}\n\n#define FATTN_VEC_CASE(D, type_K, type_V)                                                                        \\\n    {                                                                                                            \\\n        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \\\n        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \\\n        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                     \\\n            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                      \\\n            return;                                                                                              \\\n        }                                                                                                        \\\n    }                                                                                                            \\\n\n#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \\\n    FATTN_VEC_CASE( 64, type_K, type_V)       \\\n    FATTN_VEC_CASE(128, type_K, type_V)       \\\n    FATTN_VEC_CASE(256, type_K, type_V)       \\\n\nstatic void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * Q = dst->src[0];\n    ggml_tensor * K = dst->src[1];\n    ggml_tensor * V = dst->src[2];\n\n#ifdef GGML_CUDA_FA_ALL_QUANTS\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)\n    
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)\n#else\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)\n#endif // GGML_CUDA_FA_ALL_QUANTS\n\n    GGML_ABORT(\"fatal error\");\n}\n\n// Best FlashAttention kernel for a specific GPU:\nenum best_fattn_kernel {\n    
BEST_FATTN_KERNEL_NONE     =   0,\n    BEST_FATTN_KERNEL_TILE     = 200,\n    BEST_FATTN_KERNEL_VEC      = 100,\n    BEST_FATTN_KERNEL_WMMA_F16 = 300,\n    BEST_FATTN_KERNEL_MMA_F16  = 400,\n};\n\nstatic best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {\n#ifndef FLASH_ATTN_AVAILABLE\n    GGML_UNUSED(device); GGML_UNUSED(dst);\n    return BEST_FATTN_KERNEL_NONE;\n#endif// FLASH_ATTN_AVAILABLE\n\n    const ggml_tensor * KQV   = dst;\n    const ggml_tensor * Q     = dst->src[0];\n    const ggml_tensor * K     = dst->src[1];\n    const ggml_tensor * V     = dst->src[2];\n    const ggml_tensor * mask  = dst->src[3];\n\n    const int gqa_ratio = Q->ne[2] / K->ne[2];\n    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);\n\n    float max_bias = 0.0f;\n    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));\n\n    // The effective batch size for the kernel can be increased by gqa_ratio.\n    // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,\n    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;\n    for (const ggml_tensor * t : {Q, K, V, mask}) {\n        if (t == nullptr || ggml_is_quantized(t->type)) {\n            continue;\n        }\n        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {\n            if (t->nb[i] % 16 != 0) {\n                gqa_opt_applies = false;\n                break;\n            }\n        }\n    }\n\n    const int cc = ggml_cuda_info().devices[device].cc;\n\n    switch (K->ne[0]) {\n        case  40:\n        case  64:\n        case  72:\n        case  80:\n        case  96:\n        case 128:\n        case 112:\n        case 256:\n            if (V->ne[0] != K->ne[0]) {\n                return BEST_FATTN_KERNEL_NONE;\n            }\n            break;\n        case 576:\n            if (V->ne[0] != 512) {\n                return BEST_FATTN_KERNEL_NONE;\n           
 }\n            if (!gqa_opt_applies) {\n                return BEST_FATTN_KERNEL_NONE;\n            }\n            break;\n        default:\n            return BEST_FATTN_KERNEL_NONE;\n    }\n\n#ifndef GGML_CUDA_FA_ALL_QUANTS\n    if (K->type != V->type) {\n        return BEST_FATTN_KERNEL_NONE;\n    }\n#endif // GGML_CUDA_FA_ALL_QUANTS\n\n    switch (K->type) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n            break;\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n#ifndef GGML_CUDA_FA_ALL_QUANTS\n            return BEST_FATTN_KERNEL_NONE;\n#endif // GGML_CUDA_FA_ALL_QUANTS\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q8_0:\n            break;\n        default:\n            return BEST_FATTN_KERNEL_NONE;\n    }\n\n    if (mask && mask->ne[2] != 1) {\n        return BEST_FATTN_KERNEL_NONE;\n    }\n\n    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:\n    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;\n\n    // If Turing tensor cores are available, use them:\n    if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {\n        if (can_use_vector_kernel) {\n            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {\n                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {\n                    return BEST_FATTN_KERNEL_VEC;\n                }\n            } else {\n                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {\n                    if (Q->ne[1] <= 2) {\n                        return BEST_FATTN_KERNEL_VEC;\n                    }\n                } else {\n                    if (Q->ne[1] == 1) {\n                        return BEST_FATTN_KERNEL_VEC;\n                    }\n                }\n            }\n            if (!gqa_opt_applies && Q->ne[1] == 1) {\n      
          return BEST_FATTN_KERNEL_VEC;\n            }\n        }\n        return BEST_FATTN_KERNEL_MMA_F16;\n    }\n\n    if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {\n        int gqa_ratio_eff = 1;\n        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;\n        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {\n            gqa_ratio_eff *= 2;\n        }\n        if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {\n            return BEST_FATTN_KERNEL_VEC;\n        }\n        if (Q->ne[1] * gqa_ratio_eff <= 16) {\n            return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.\n        }\n        return BEST_FATTN_KERNEL_MMA_F16;\n    }\n\n    // Use the WMMA kernel if possible:\n    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {\n        if (can_use_vector_kernel && Q->ne[1] <= 2) {\n            return BEST_FATTN_KERNEL_VEC;\n        }\n        return BEST_FATTN_KERNEL_WMMA_F16;\n    }\n\n    if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {\n        if (can_use_vector_kernel) {\n            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {\n                if (Q->ne[1] == 1) {\n                    if (!gqa_opt_applies) {\n                        return BEST_FATTN_KERNEL_VEC;\n                    }\n                }\n            } else {\n                if (Q->ne[1] <= 2) {\n                    return BEST_FATTN_KERNEL_VEC;\n                }\n            }\n        }\n        int gqa_ratio_eff = 1;\n        const int ncols2_max = Q->ne[0] == 576 ? 
16 : 8;\n        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {\n            gqa_ratio_eff *= 2;\n        }\n        if (Q->ne[1] * gqa_ratio_eff <= 8) {\n            return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.\n        }\n        return BEST_FATTN_KERNEL_MMA_F16;\n    }\n\n    // Use MFMA flash attention for CDNA (MI100+):\n    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {\n        const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);\n        // MMA vs tile crossover benchmarked on MI300X @ d32768:\n        //   hsk=64  (gqa=4): MMA wins at eff >= 128 (+11%)\n        //   hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%)\n        if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) {\n            return BEST_FATTN_KERNEL_MMA_F16;\n        }\n        // Fall through to tile kernel for small effective batch sizes.\n    }\n\n    // If there are no tensor cores available, use the generic tile kernel:\n    if (can_use_vector_kernel) {\n        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {\n            if (Q->ne[1] == 1) {\n                if (!gqa_opt_applies) {\n                    return BEST_FATTN_KERNEL_VEC;\n                }\n            }\n        } else {\n            if (Q->ne[1] <= 2) {\n                return BEST_FATTN_KERNEL_VEC;\n            }\n        }\n    }\n    return BEST_FATTN_KERNEL_TILE;\n}\n\nvoid ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_set_device(ctx.device);\n    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {\n        case BEST_FATTN_KERNEL_NONE:\n            GGML_ABORT(\"fatal error\");\n        case BEST_FATTN_KERNEL_TILE:\n            ggml_cuda_flash_attn_ext_tile(ctx, dst);\n            break;\n        case BEST_FATTN_KERNEL_VEC:\n            
ggml_cuda_flash_attn_ext_vec(ctx, dst);\n            break;\n        case BEST_FATTN_KERNEL_WMMA_F16:\n            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);\n            break;\n        case BEST_FATTN_KERNEL_MMA_F16:\n            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);\n            break;\n    }\n}\n\nbool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {\n    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;\n}\n"
  },
  {
    "path": "src/ggml-cuda/fattn.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nbool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/fill.cu",
    "content": "#include \"fill.cuh\"\n#include \"convert.cuh\"\n\n#define CUDA_FILL_BLOCK_SIZE 256\n\ntemplate <typename T>\nstatic __global__ void fill_kernel(T * dst, const int64_t k, const T value) {\n    const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;\n    if (i >= k) {\n        return;\n    }\n    dst[i] = value;\n}\n\nvoid ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    void * dst_d = dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    float value;\n    memcpy(&value, dst->op_params, sizeof(float));\n\n    const int64_t k = ggml_nelements(dst);\n    const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;\n\n    switch (dst->type) {\n        case GGML_TYPE_F32:\n            fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value);\n            break;\n        case GGML_TYPE_F16:\n            fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value));\n            break;\n        default:\n            GGML_ABORT(\"unsupported type\");\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/fill.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/gated_delta_net.cu",
    "content": "#include \"gated_delta_net.cuh\"\n\ntemplate <int S_v, bool KDA>\n__global__ void gated_delta_net_cuda(const float * q,\n                                     const float * k,\n                                     const float * v,\n                                     const float * g,\n                                     const float * beta,\n                                     const float * curr_state,\n                                     float *       dst,\n                                     int64_t       H,\n                                     int64_t       n_tokens,\n                                     int64_t       n_seqs,\n                                     int64_t       sq1,\n                                     int64_t       sq2,\n                                     int64_t       sq3,\n                                     int64_t       sv1,\n                                     int64_t       sv2,\n                                     int64_t       sv3,\n                                     int64_t       sb1,\n                                     int64_t       sb2,\n                                     int64_t       sb3,\n                                     const uint3   neqk1_magic,\n                                     const uint3   rq3_magic,\n                                     float         scale) {\n    const uint32_t h_idx    = blockIdx.x;\n    const uint32_t sequence = blockIdx.y;\n    // each warp owns one column, using warp-level primitives to reduce across rows\n    const int      lane     = threadIdx.x;\n    const int      col      = blockIdx.z * blockDim.y + threadIdx.y;\n\n    const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);\n    const uint32_t iq3 = fastdiv(sequence, rq3_magic);\n\n    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;\n    float *       attn_data        = dst;\n    float *       state            = dst + attn_score_elems;\n\n    const int64_t state_offset = (sequence * H + h_idx) * S_v 
* S_v;\n    state += state_offset;\n    curr_state += state_offset;\n    attn_data += (sequence * n_tokens * H + h_idx) * S_v;\n\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;\n    static_assert(S_v % warp_size == 0, \"S_v must be a multiple of warp_size\");\n    constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;\n    float         s_shard[rows_per_lane];\n    // state is stored transposed: M[col][i] = S[i][col], row col is contiguous\n#pragma unroll\n    for (int r = 0; r < rows_per_lane; r++) {\n        const int i = r * warp_size + lane;\n        s_shard[r]  = curr_state[col * S_v + i];\n    }\n\n    for (int t = 0; t < n_tokens; t++) {\n        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;\n        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;\n        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;\n\n        const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;\n        const float * beta_t = beta + gb_offset;\n        const float * g_t    = g    + gb_offset * (KDA ? 
S_v : 1);\n\n        const float beta_val = *beta_t;\n\n        if constexpr (!KDA) {\n            const float g_val = expf(*g_t);\n\n            // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]\n            float kv_shard = 0.0f;\n#pragma unroll\n            for (int r = 0; r < rows_per_lane; r++) {\n                const int i = r * warp_size + lane;\n                kv_shard += s_shard[r] * k_t[i];\n            }\n            float kv_col = warp_reduce_sum<warp_size>(kv_shard);\n\n            // delta[col] = (v[col] - g * kv[col]) * beta\n            float delta_col = (v_t[col] - g_val * kv_col) * beta_val;\n\n            // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]\n            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]\n            float attn_partial = 0.0f;\n#pragma unroll\n            for (int r = 0; r < rows_per_lane; r++) {\n                const int i = r * warp_size + lane;\n                s_shard[r]  = g_val * s_shard[r] + k_t[i] * delta_col;\n                attn_partial += s_shard[r] * q_t[i];\n            }\n\n            float attn_col = warp_reduce_sum<warp_size>(attn_partial);\n\n            if (lane == 0) {\n                attn_data[col] = attn_col * scale;\n            }\n        } else {\n            // kv[col] = sum_i g[i] * S[i][col] * k[i]\n            float kv_shard = 0.0f;\n#pragma unroll\n            for (int r = 0; r < rows_per_lane; r++) {\n                const int i = r * warp_size + lane;\n                kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];\n            }\n\n            float kv_col = warp_reduce_sum<warp_size>(kv_shard);\n\n            // delta[col] = (v[col] - kv[col]) * beta\n            float delta_col = (v_t[col] - kv_col) * beta_val;\n\n            // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]\n            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]\n            float attn_partial = 0.0f;\n#pragma unroll\n            for (int r = 0; r < rows_per_lane; r++) {\n     
           const int i = r * warp_size + lane;\n                s_shard[r]  = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;\n                attn_partial += s_shard[r] * q_t[i];\n            }\n\n            float attn_col = warp_reduce_sum<warp_size>(attn_partial);\n\n            if (lane == 0) {\n                attn_data[col] = attn_col * scale;\n            }\n        }\n\n        attn_data += S_v * H;\n    }\n\n    // Write state back to global memory (transposed layout)\n#pragma unroll\n    for (int r = 0; r < rows_per_lane; r++) {\n        const int i          = r * warp_size + lane;\n        state[col * S_v + i] = s_shard[r];\n    }\n}\n\ntemplate <bool KDA>\nstatic void launch_gated_delta_net(\n        const float * q_d, const float * k_d, const float * v_d,\n        const float * g_d, const float * b_d, const float * s_d,\n        float * dst_d,\n        int64_t S_v,   int64_t H, int64_t n_tokens, int64_t n_seqs,\n        int64_t sq1,   int64_t sq2, int64_t sq3,\n        int64_t sv1,   int64_t sv2, int64_t sv3,\n        int64_t sb1,   int64_t sb2, int64_t sb3,\n        int64_t neqk1, int64_t rq3,\n        float scale, cudaStream_t stream) {\n    //TODO: Add chunked kernel for even faster pre-fill\n    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;\n    const int num_warps = 4;\n    dim3      grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);\n    dim3      block_dims(warp_size <= S_v ? 
warp_size : S_v, num_warps, 1);\n\n    const uint3 neqk1_magic = init_fastdiv_values(neqk1);\n    const uint3 rq3_magic   = init_fastdiv_values(rq3);\n\n    int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n\n    switch (S_v) {\n        case 16:\n            gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(\n                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,\n                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,\n                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);\n            break;\n        case 32:\n            gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(\n                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,\n                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,\n                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);\n            break;\n        case 64: {\n            gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(\n                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,\n                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,\n                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);\n            break;\n        }\n        case 128: {\n            gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(\n                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,\n                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,\n                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);\n            break;\n        }\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n\nvoid ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src_q     = dst->src[0];\n    ggml_tensor * src_k     = dst->src[1];\n    ggml_tensor * src_v     = dst->src[2];\n    ggml_tensor * src_g     = dst->src[3];\n    ggml_tensor * src_beta  = dst->src[4];\n    ggml_tensor * src_state = dst->src[5];\n\n    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);\n    
GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);\n    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);\n    GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);\n    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);\n    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);\n    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);\n\n    const int64_t S_v      = nev0;\n    const int64_t H        = nev1;\n    const int64_t n_tokens = nev2;\n    const int64_t n_seqs   = nev3;\n\n    const bool kda = (src_g->ne[0] == S_v);\n\n    GGML_ASSERT(neq1 == nek1);\n    const int64_t neqk1 = neq1;\n\n    const int64_t rq3 = nev3 / neq3;\n\n    const float * q_d = (const float *) src_q->data;\n    const float * k_d = (const float *) src_k->data;\n    const float * v_d = (const float *) src_v->data;\n    const float * g_d = (const float *) src_g->data;\n    const float * b_d = (const float *) src_beta->data;\n\n    const float * s_d   = (const float *) src_state->data;\n    float *       dst_d = (float *) dst->data;\n\n    GGML_ASSERT(ggml_is_contiguous_rows(src_q));\n    GGML_ASSERT(ggml_is_contiguous_rows(src_k));\n    GGML_ASSERT(ggml_is_contiguous_rows(src_v));\n    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));\n    GGML_ASSERT(src_g->ne[0] == 1 || kda);\n    GGML_ASSERT(ggml_is_contiguous(src_g));\n    GGML_ASSERT(ggml_is_contiguous(src_beta));\n    GGML_ASSERT(ggml_is_contiguous(src_state));\n\n    // strides in floats (beta strides used for both g and beta offset computation)\n    const int64_t sq1 = nbq1 / sizeof(float);\n    const int64_t sq2 = nbq2 / sizeof(float);\n    const int64_t sq3 = nbq3 / sizeof(float);\n    const int64_t sv1 = nbv1 / sizeof(float);\n    const int64_t sv2 = nbv2 / sizeof(float);\n    const int64_t sv3 = nbv3 / sizeof(float);\n    const int64_t sb1 = nbb1 / sizeof(float);\n    const int64_t sb2 = nbb2 / sizeof(float);\n    const int64_t sb3 = nbb3 / sizeof(float);\n\n    const float scale = 1.0f / sqrtf((float) S_v);\n\n    cudaStream_t stream = ctx.stream();\n\n    if (kda) 
{\n        launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,\n            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,\n            sb1, sb2, sb3, neqk1, rq3, scale, stream);\n    } else {\n        launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,\n            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,\n            sb1, sb2, sb3, neqk1, rq3, scale, stream);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/gated_delta_net.cuh",
    "content": "#include \"common.cuh\"\n#include \"ggml.h\"\n\nvoid ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/getrows.cu",
    "content": "#include \"getrows.cuh\"\n#include \"dequantize.cuh\"\n#include \"convert.cuh\"\n\ntemplate<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic __global__ void k_get_rows(\n        const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,\n        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/\n        /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/\n        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,\n        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,\n        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {\n\n    for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {\n        for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {\n            // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.\n            const int i10 =  blockIdx.x;\n            const int i11 =  z / ne12; // TODO fastdiv\n            const int i12 =  z % ne12;\n\n            const int i01 = src1[i10*s10 + i11*s11 + i12*s12];\n\n            dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;\n            const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;\n\n            const int ib   =  i00/qk;      // block index\n            const int iqs  = (i00%qk)/qr;  // quant index\n            const int iybs = i00 - i00%qk; // dst block start index\n            const int y_offset = qr == 1 ? 
1 : qk/2;\n\n            // dequantize\n            float2 v;\n            dequantize_kernel(src0_row, ib, iqs, v);\n\n            dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);\n            dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);\n        }\n    }\n}\n\ntemplate<typename src0_t, typename dst_t>\nstatic __global__ void k_get_rows_float(\n        const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,\n        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/\n        /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/\n        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,\n        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,\n        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {\n\n    for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {\n        for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {\n            // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.\n            const int i10 = blockIdx.x;\n            const int i11 = z / ne12; // TODO fastdiv\n            const int i12 = z % ne12;\n\n            if (i00 >= ne00) {\n                return;\n            }\n\n            const int i01 = src1[i10*s10 + i11*s11 + i12*s12];\n\n            dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;\n            const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);\n\n            dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);\n        }\n    }\n}\n\ntemplate<typename grad_t, typename dst_t>\nstatic __global__ void k_get_rows_back_float(\n        const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t 
nrows_grad) {\n    const int col = blockIdx.x*blockDim.x + threadIdx.x;\n\n    if (col >= ncols) {\n        return;\n    }\n\n    const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;\n\n    float sum = 0.0f;\n\n    for (int64_t i = 0; i < nrows_grad; ++i) {\n        if (rows[i] != dst_row) {\n            continue;\n        }\n        sum += grad[i*ncols + col];\n    }\n\n    dst[dst_row*ncols + col] = sum;\n}\n\ntemplate<int qk, int qr, dequantize_kernel_t dq, typename dst_t>\nstatic void get_rows_cuda_q(\n        const void * src0_d, const int32_t * src1_d, dst_t * dst_d,\n        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,\n        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,\n        const size_t nb1, const size_t nb2, const size_t nb3,\n        cudaStream_t stream) {\n    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);\n    const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);\n    const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));\n\n    // strides in elements\n    // const size_t s0 = nb0 / sizeof(dst_t);\n    const size_t s1 = nb1 / sizeof(dst_t);\n    const size_t s2 = nb2 / sizeof(dst_t);\n    const size_t s3 = nb3 / sizeof(dst_t);\n\n    const size_t s10 = nb10 / sizeof(int32_t);\n    const size_t s11 = nb11 / sizeof(int32_t);\n    const size_t s12 = nb12 / sizeof(int32_t);\n    // const size_t s13 = nb13 / sizeof(int32_t);\n\n    GGML_ASSERT(ne00 % 2 == 0);\n\n    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(\n        src0_d, src1_d, dst_d,\n        ne00, /*ne01, ne02, ne03,*/\n        /*ne10,*/ ne11, ne12, /*ne13,*/\n        /* s0,*/ s1, s2, s3,\n        /* nb00,*/ nb01, nb02, nb03,\n        s10, s11, s12/*, s13*/);\n}\n\ntemplate<typename src0_t, typename dst_t>\nstatic void get_rows_cuda_float(\n        const src0_t * src0_d, const int32_t * 
src1_d, dst_t * dst_d,\n        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,\n        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,\n        const size_t nb1, const size_t nb2, const size_t nb3,\n        cudaStream_t stream) {\n    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);\n    const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;\n    const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));\n\n    // strides in elements\n    // const size_t s0 = nb0 / sizeof(dst_t);\n    const size_t s1 = nb1 / sizeof(dst_t);\n    const size_t s2 = nb2 / sizeof(dst_t);\n    const size_t s3 = nb3 / sizeof(dst_t);\n\n    const size_t s10 = nb10 / sizeof(int32_t);\n    const size_t s11 = nb11 / sizeof(int32_t);\n    const size_t s12 = nb12 / sizeof(int32_t);\n    // const size_t s13 = nb13 / sizeof(int32_t);\n\n    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(\n        src0_d, src1_d, dst_d,\n        ne00, /*ne01, ne02, ne03,*/\n        /*ne10,*/ ne11, ne12, /*ne13,*/\n        /* s0,*/ s1, s2, s3,\n        /* nb00,*/ nb01, nb02, nb03,\n        s10, s11, s12/*, s13*/);\n}\n\ntemplate <typename dst_t>\nstatic void ggml_cuda_get_rows_switch_src0_type(\n        const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,\n        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,\n        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,\n        const size_t nb1, const size_t nb2, const size_t nb3,\n        cudaStream_t stream) {\n    switch (src0_type) {\n        case GGML_TYPE_F16:\n            get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        
case GGML_TYPE_F32:\n            get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_I32:\n            get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_BF16:\n            get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q4_0:\n            get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q4_1:\n            get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q5_0:\n            get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q5_1:\n            get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q8_0:\n            get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        default:\n            // TODO: k-quants\n            GGML_ABORT(\"%s: unsupported src0 type: %s\\n\", __func__, ggml_type_name(src0_type));\n            
break;\n    }\n}\n\nvoid get_rows_cuda(\n        const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,\n        int64_t ne00, size_t nb01, size_t nb02, size_t nb03,\n        int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,\n        size_t nb1, size_t nb2, size_t nb3,\n        cudaStream_t stream) {\n    switch (dst_type) {\n        case GGML_TYPE_F32:\n            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_I32:\n            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_F16:\n            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_BF16:\n            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,\n                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n            break;\n        default:\n            GGML_ABORT(\"%s: unsupported dst type: %s\\n\", __func__, ggml_type_name(dst_type));\n            break;\n    }\n}\n\nvoid ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    GGML_ASSERT(src1->type == GGML_TYPE_I32);\n    GGML_ASSERT(ne13 == 1);\n\n    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));\n    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));\n    
GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));\n\n    get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,\n        ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);\n}\n\nvoid ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output\n    const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const float   * src0_d = (const float   *) src0->data;\n    const int32_t * src1_d = (const int32_t *) src1->data;\n    float         * dst_d  = (float         *) dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_I32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    GGML_ASSERT(ne02*ne03 == 1);\n    GGML_ASSERT(ne12*ne13 == 1);\n    GGML_ASSERT(ne2*ne3 == 1);\n\n    const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);\n    const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;\n    const dim3 block_nums(block_num_x, ne1, 1);\n\n    k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);\n}\n"
  },
  {
    "path": "src/ggml-cuda/getrows.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_GET_ROWS_BLOCK_SIZE 256\n#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256\n\nvoid get_rows_cuda(\n        const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,\n        int64_t ne00, size_t nb01, size_t nb02, size_t nb03,\n        int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,\n        size_t nb1, size_t nb2, size_t nb3,\n        cudaStream_t stream);\n\nvoid ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/ggml-cuda.cu",
    "content": "#include \"ggml-cuda.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-cuda/common.cuh\"\n#include \"ggml-cuda/acc.cuh\"\n#include \"ggml-cuda/add-id.cuh\"\n#include \"ggml-cuda/arange.cuh\"\n#include \"ggml-cuda/argmax.cuh\"\n#include \"ggml-cuda/argsort.cuh\"\n#include \"ggml-cuda/binbcast.cuh\"\n#include \"ggml-cuda/clamp.cuh\"\n#include \"ggml-cuda/concat.cuh\"\n#include \"ggml-cuda/conv-transpose-1d.cuh\"\n#include \"ggml-cuda/conv2d.cuh\"\n#include \"ggml-cuda/conv2d-dw.cuh\"\n#include \"ggml-cuda/conv2d-transpose.cuh\"\n#include \"ggml-cuda/convert.cuh\"\n#include \"ggml-cuda/count-equal.cuh\"\n#include \"ggml-cuda/cpy.cuh\"\n#include \"ggml-cuda/cross-entropy-loss.cuh\"\n#include \"ggml-cuda/cumsum.cuh\"\n#include \"ggml-cuda/diagmask.cuh\"\n#include \"ggml-cuda/diag.cuh\"\n#include \"ggml-cuda/fattn.cuh\"\n#include \"ggml-cuda/getrows.cuh\"\n#include \"ggml-cuda/im2col.cuh\"\n#include \"ggml-cuda/mmf.cuh\"\n#include \"ggml-cuda/mmq.cuh\"\n#include \"ggml-cuda/mmvf.cuh\"\n#include \"ggml-cuda/mmvq.cuh\"\n#include \"ggml-cuda/norm.cuh\"\n#include \"ggml-cuda/opt-step-adamw.cuh\"\n#include \"ggml-cuda/opt-step-sgd.cuh\"\n#include \"ggml-cuda/out-prod.cuh\"\n#include \"ggml-cuda/pad.cuh\"\n#include \"ggml-cuda/pool2d.cuh\"\n#include \"ggml-cuda/quantize.cuh\"\n#include \"ggml-cuda/rope.cuh\"\n#include \"ggml-cuda/roll.cuh\"\n#include \"ggml-cuda/scale.cuh\"\n#include \"ggml-cuda/softcap.cuh\"\n#include \"ggml-cuda/softmax.cuh\"\n#include \"ggml-cuda/ssm-conv.cuh\"\n#include \"ggml-cuda/ssm-scan.cuh\"\n#include \"ggml-cuda/sum.cuh\"\n#include \"ggml-cuda/sumrows.cuh\"\n#include \"ggml-cuda/top-k.cuh\"\n#include \"ggml-cuda/mean.cuh\"\n#include \"ggml-cuda/tsembd.cuh\"\n#include \"ggml-cuda/topk-moe.cuh\"\n#include \"ggml-cuda/unary.cuh\"\n#include \"ggml-cuda/upscale.cuh\"\n#include \"ggml-cuda/wkv.cuh\"\n#include \"ggml-cuda/gla.cuh\"\n#include \"ggml-cuda/gated_delta_net.cuh\"\n#include 
\"ggml-cuda/set.cuh\"\n#include \"ggml-cuda/set-rows.cuh\"\n#include \"ggml-cuda/pad_reflect_1d.cuh\"\n#include \"ggml-cuda/solve_tri.cuh\"\n#include \"ggml-cuda/tri.cuh\"\n#include \"ggml-cuda/cumsum.cuh\"\n#include \"ggml-cuda/fill.cuh\"\n#include \"ggml.h\"\n\n#include <algorithm>\n#include <array>\n#include <atomic>\n#include <charconv>\n#include <cinttypes>\n#include <condition_variable>\n#include <cstddef>\n#include <cstdint>\n#include <cfloat>\n#include <initializer_list>\n#include <limits>\n#include <map>\n#include <memory>\n#include <mutex>\n#include <cstdarg>\n#include <cstdio>\n#include <cstdlib>\n#include <string>\n#include <vector>\n#include <unordered_set>\n\nstatic_assert(sizeof(half) == sizeof(ggml_fp16_t), \"wrong fp16 size\");\n\n[[noreturn]]\nvoid ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {\n    int id = -1; // in case cudaGetDevice fails\n    (void)cudaGetDevice(&id);\n\n    GGML_LOG_ERROR(GGML_CUDA_NAME \" error: %s\\n\", msg);\n    GGML_LOG_ERROR(\"  current device: %d, in function %s at %s:%d\\n\", id, func, file, line);\n    GGML_LOG_ERROR(\"  %s\\n\", stmt);\n    // abort with GGML_ABORT to get a stack trace\n    GGML_ABORT(GGML_CUDA_NAME \" error\");\n}\n\n// this is faster on Windows\n// probably because the Windows CUDA libraries forget to make this check before invoking the drivers\nvoid ggml_cuda_set_device(int device) {\n    int current_device;\n    CUDA_CHECK(cudaGetDevice(&current_device));\n\n    if (device == current_device) {\n        return;\n    }\n\n    CUDA_CHECK(cudaSetDevice(device));\n}\n\nint ggml_cuda_get_device() {\n    int id;\n    CUDA_CHECK(cudaGetDevice(&id));\n    return id;\n}\n\nstatic cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {\n    ggml_cuda_set_device(device);\n    cudaError_t err;\n    if (getenv(\"GGML_CUDA_ENABLE_UNIFIED_MEMORY\") != nullptr) {\n        err = cudaMallocManaged(ptr, size);\n#if defined(GGML_USE_HIP)\n  
      if (err == hipSuccess) {\n            // hipMemAdviseSetCoarseGrain is an optional performance hint;\n            // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).\n            cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);\n            (void)hipGetLastError(); // clear any error\n        }\n\n        // fall back to cudaMalloc if not supported (e.g. on Windows)\n        if (err == hipErrorNotSupported) {\n            static bool warned_unsupported = false;\n            if (!warned_unsupported) {\n                GGML_LOG_WARN(\"hipMallocManaged unsupported, falling back to hipMalloc.\\n\");\n                warned_unsupported = true;\n            }\n\n            err = cudaMalloc(ptr, size);\n        }\n#endif // defined(GGML_USE_HIP)\n    } else {\n        err = cudaMalloc(ptr, size);\n    }\n    return err;\n}\n\n#if defined(GGML_USE_HIP)\nstatic int ggml_cuda_parse_id(char devName[]) {\n    // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp\n    // these values are not stable so this is susceptible to breakage\n    // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp\n    int archMajor = 0x0;\n    int archMinor = 0x0;\n    int archNum = GGML_CUDA_CC_OFFSET_AMD;\n    int archLen = strlen(devName);\n    char archName[archLen + 1];\n\n    // strip leading 'gfx' while copying into our buffer\n    if (archLen > 3) {\n        strcpy(archName, &devName[3]);\n        archLen -= 3;\n    }\n\n    // trim trailing :xnack- or :sramecc- statuses\n    archLen = strcspn(archName, \":\");\n    archName[archLen] = '\\0';\n\n    // tease out the version information\n    if (archLen > 8) {\n        // versions labeled generic use '-' as delimiter\n        // strip the trailing \"-generic\" then iterate through what remains\n        if ((strstr(archName, \"-generic\"))) {\n            archName[archLen - 8] = '\\0';\n            char * pch;\n            if ((pch = strtok(archName, 
\"-\"))) {\n                archMajor = (int)strtoul(pch, 0, 16);\n                if ((pch = strtok(NULL, \"-\"))) {\n                    archMinor = 0x10 * (int)strtoul(pch, 0, 16);\n                }\n            }\n        }\n    } else if (archLen >= 3) {\n        // last two digits should be the minor * 0x10 + stepping\n        archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);\n        archName[archLen - 2] = '\\0';\n\n        // only the major version remains\n        archMajor = (int)strtoul(archName, 0, 16);\n    }\n    archNum += archMajor * 0x100;\n    archNum += archMinor;\n    return archNum;\n}\n#endif // defined(GGML_USE_HIP)\n\nstatic ggml_cuda_device_info ggml_cuda_init() {\n    ggml_cuda_device_info info = {};\n\n    cudaError_t err = cudaGetDeviceCount(&info.device_count);\n    if (err != cudaSuccess) {\n        GGML_LOG_ERROR(\"%s: failed to initialize \" GGML_CUDA_NAME \": %s\\n\", __func__, cudaGetErrorString(err));\n        return info;\n    }\n\n    GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);\n\n    int64_t total_vram = 0;\n    for (int id = 0; id < info.device_count; ++id) {\n        cudaDeviceProp prop;\n        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));\n        total_vram += prop.totalGlobalMem;\n    }\n    GGML_LOG_INFO(\"%s: found %d \" GGML_CUDA_NAME \" devices (Total VRAM: %zu MiB):\\n\",\n                  __func__, info.device_count, (size_t)(total_vram / (1024 * 1024)));\n    total_vram = 0;\n\n    std::vector<std::pair<int, std::string>> turing_devices_without_mma;\n    for (int id = 0; id < info.device_count; ++id) {\n        int device_vmm = 0;\n\n#if defined(GGML_USE_VMM)\n        CUdevice device;\n        CU_CHECK(cuDeviceGet(&device, id));\n        CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));\n\n        if (device_vmm) {\n            CUmemAllocationProp alloc_prop = {};\n            alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;\n     
       alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;\n            alloc_prop.location.id = id;\n            CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));\n        }\n#endif // defined(GGML_USE_VMM)\n        info.devices[id].vmm = !!device_vmm;\n\n        cudaDeviceProp prop;\n        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));\n\n        info.default_tensor_split[id] = total_vram;\n        total_vram += prop.totalGlobalMem;\n        info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)\n        info.devices[id].nsm        = prop.multiProcessorCount;\n        info.devices[id].smpb       = prop.sharedMemPerBlock;\n        info.devices[id].warp_size  = prop.warpSize;\n\n#ifndef GGML_USE_MUSA\n        int supports_coop_launch = 0;\n        CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));\n        info.devices[id].supports_cooperative_launch = !!supports_coop_launch;\n#else\n        info.devices[id].supports_cooperative_launch = false;\n#endif // !(GGML_USE_MUSA)\n\n#if defined(GGML_USE_HIP)\n        info.devices[id].smpbo = prop.sharedMemPerBlock;\n\n        info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);\n        if ((info.devices[id].cc & 0xff00) == 0x0) {\n            GGML_LOG_WARN(\"invalid architecture ID received for device %d %s: %s  cc %d.%d\\n\",\n                            id, prop.name, prop.gcnArchName, prop.major, prop.minor);\n\n            // Fallback to prop.major and prop.minor\n            if (prop.major > 0) {\n                info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;\n                info.devices[id].cc += prop.minor * 0x10;\n            }\n        }\n        GGML_LOG_INFO(\"  Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB\\n\",\n                      id, prop.name, prop.gcnArchName, 
info.devices[id].cc & 0xffff,\n                      device_vmm ? \"yes\" : \"no\", prop.warpSize,\n                      (size_t)(prop.totalGlobalMem / (1024 * 1024)));\n#elif defined(GGML_USE_MUSA)\n        // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.\n        info.devices[id].warp_size = 32;\n        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;\n        info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;\n        info.devices[id].cc += prop.minor * 0x10;\n        GGML_LOG_INFO(\"  Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\\n\",\n                      id, prop.name, prop.major, prop.minor, device_vmm ? \"yes\" : \"no\",\n                      (size_t)(prop.totalGlobalMem / (1024 * 1024)));\n#else\n        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;\n        info.devices[id].cc = 100*prop.major + 10*prop.minor;\n        GGML_LOG_INFO(\"  Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\\n\",\n                      id, prop.name, prop.major, prop.minor, device_vmm ? 
\"yes\" : \"no\",\n                      (size_t)(prop.totalGlobalMem / (1024 * 1024)));\n        std::string device_name(prop.name);\n        if (device_name == \"NVIDIA GeForce MX450\") {\n            turing_devices_without_mma.push_back({ id, device_name });\n        } else if (device_name == \"NVIDIA GeForce MX550\") {\n            turing_devices_without_mma.push_back({ id, device_name });\n        } else if (device_name.substr(0, 21) == \"NVIDIA GeForce GTX 16\") {\n            turing_devices_without_mma.push_back({ id, device_name });\n        }\n\n        // Temporary performance fix:\n        // Setting device scheduling strategy for iGPUs with cc121 to \"spinning\" to avoid delays in cuda synchronize calls.\n        // TODO: Check for future drivers the default scheduling strategy and\n        // remove this call again when cudaDeviceScheduleSpin is default.\n        if (prop.major == 12 && prop.minor == 1) {\n            CUDA_CHECK(cudaSetDevice(id));\n            CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));\n        }\n\n#endif  // defined(GGML_USE_HIP)\n    }\n\n    if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {\n        GGML_LOG_INFO(\"The following devices will have suboptimal performance due to a lack of tensor cores:\\n\");\n        for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {\n            GGML_LOG_INFO(\n                \"  Device %d: %s\\n\", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());\n        }\n        GGML_LOG_INFO(\n            \"Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\\n\");\n    }\n\n    for (int id = 0; id < info.device_count; ++id) {\n        info.default_tensor_split[id] /= total_vram;\n    }\n\n    // configure logging to stdout\n    // 
CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));\n\n    return info;\n}\n\nconst ggml_cuda_device_info & ggml_cuda_info() {\n    static ggml_cuda_device_info info = ggml_cuda_init();\n    return info;\n}\n\n// #define DEBUG_CUDA_MALLOC\n\n// buffer pool for cuda (legacy)\nstruct ggml_cuda_pool_leg : public ggml_cuda_pool {\n    static const int MAX_BUFFERS = 256;\n\n    int device;\n    struct ggml_cuda_buffer {\n        void * ptr = nullptr;\n        size_t size = 0;\n    };\n\n    ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};\n    size_t pool_size = 0;\n\n    explicit ggml_cuda_pool_leg(int device) :\n        device(device) {\n    }\n\n    ~ggml_cuda_pool_leg() {\n        ggml_cuda_set_device(device);\n        for (int i = 0; i < MAX_BUFFERS; ++i) {\n            ggml_cuda_buffer & b = buffer_pool[i];\n            if (b.ptr != nullptr) {\n                CUDA_CHECK(cudaFree(b.ptr));\n                pool_size -= b.size;\n            }\n        }\n        GGML_ASSERT(pool_size == 0);\n    }\n\n    void * alloc(size_t size, size_t * actual_size) override {\n#ifdef DEBUG_CUDA_MALLOC\n        int nnz = 0;\n        size_t max_size = 0;\n#endif\n        size_t best_diff = 1ull << 36;\n        int ibest = -1;\n        for (int i = 0; i < MAX_BUFFERS; ++i) {\n            ggml_cuda_buffer& b = buffer_pool[i];\n            if (b.ptr != nullptr) {\n#ifdef DEBUG_CUDA_MALLOC\n                ++nnz;\n                if (b.size > max_size) max_size = b.size;\n#endif\n                if (b.size >= size) {\n                    size_t diff = b.size - size;\n                    if (diff < best_diff) {\n                        best_diff = diff;\n                        ibest = i;\n                        if (!best_diff) {\n                            void * ptr = b.ptr;\n                            *actual_size = b.size;\n                            b.ptr = nullptr;\n                            b.size = 0;\n                            return ptr;\n                        }\n  
                  }\n                }\n            }\n        }\n        if (ibest >= 0) {\n            ggml_cuda_buffer& b = buffer_pool[ibest];\n            void * ptr = b.ptr;\n            *actual_size = b.size;\n            b.ptr = nullptr;\n            b.size = 0;\n            return ptr;\n        }\n        void * ptr;\n        size_t look_ahead_size = (size_t) (1.05 * size);\n        look_ahead_size = 256 * ((look_ahead_size + 255)/256);\n        ggml_cuda_set_device(device);\n        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));\n        *actual_size = look_ahead_size;\n        pool_size += look_ahead_size;\n#ifdef DEBUG_CUDA_MALLOC\n        GGML_LOG_INFO(\"%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\\n\", __func__, device, nnz,\n                           (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));\n#endif\n        return ptr;\n    }\n\n    void free(void * ptr, size_t size) override {\n        for (int i = 0; i < MAX_BUFFERS; ++i) {\n            ggml_cuda_buffer& b = buffer_pool[i];\n            if (b.ptr == nullptr) {\n                b.ptr = ptr;\n                b.size = size;\n                return;\n            }\n        }\n        GGML_LOG_DEBUG(GGML_CUDA_NAME \" buffer pool full, increase MAX_CUDA_BUFFERS\\n\");\n        ggml_cuda_set_device(device);\n        CUDA_CHECK(cudaFree(ptr));\n        pool_size -= size;\n    }\n};\n\n// pool with virtual memory\n#if defined(GGML_USE_VMM)\nstruct ggml_cuda_pool_vmm : public ggml_cuda_pool {\n    static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB\n\n    int device;\n    CUdeviceptr pool_addr = 0;\n    size_t pool_used = 0;\n    size_t pool_size = 0;\n    size_t granularity;\n#if defined(GGML_USE_HIP)\n    std::vector<std::pair<CUdeviceptr, size_t>> mappings;\n#endif\n\n    explicit ggml_cuda_pool_vmm(int device) :\n        device(device),\n        
granularity(ggml_cuda_info().devices[device].vmm_granularity) {\n    }\n\n    ~ggml_cuda_pool_vmm() {\n        if (pool_addr != 0) {\n#if defined(GGML_USE_HIP)\n            // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285\n            for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {\n                CU_CHECK(cuMemUnmap(mapping.first, mapping.second));\n            }\n#else\n            CU_CHECK(cuMemUnmap(pool_addr, pool_size));\n#endif\n            CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));\n        }\n    }\n\n    void * alloc(size_t size, size_t * actual_size) override {\n        // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types\n        const size_t alignment = 128;\n        size = alignment * ((size + alignment - 1) / alignment);\n\n        size_t avail = pool_size - pool_used;\n\n        if (size > avail) {\n            // round up to the next multiple of the granularity\n            size_t reserve_size = size - avail;\n            reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);\n\n            GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);\n\n            // allocate more physical memory\n            CUmemAllocationProp prop = {};\n            prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;\n            prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;\n            prop.location.id = device;\n            CUmemGenericAllocationHandle handle;\n            CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));\n\n            // reserve virtual address space (if not already reserved)\n            if (pool_addr == 0) {\n                CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));\n            }\n\n            // map at the end of the pool\n            CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);\n            CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, 
handle, 0));\n#if defined(GGML_USE_HIP)\n            mappings.push_back({start_ptr, reserve_size});\n#endif\n\n            // the memory allocation handle is no longer needed after mapping\n            CU_CHECK(cuMemRelease(handle));\n\n            // set access\n            CUmemAccessDesc access = {};\n            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;\n            access.location.id = device;\n            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;\n            CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));\n\n            // add to the pool\n            pool_size += reserve_size;\n\n            //printf(\"cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\\n\",\n            //       device, (unsigned long long) (pool_size/1024/1024),\n            //       (unsigned long long) (reserve_size/1024/1024));\n        }\n\n        GGML_ASSERT(pool_addr != 0);\n\n        void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used));\n        *actual_size = size;\n        pool_used += size;\n\n#ifdef DEBUG_CUDA_MALLOC\n        printf(\"cuda pool[%d]: allocated %llu bytes at %llx\\n\", device, (unsigned long long) size, ptr);\n#endif\n\n        return ptr;\n    }\n\n    void free(void * ptr, size_t size) override {\n#ifdef DEBUG_CUDA_MALLOC\n        printf(\"cuda pool[%d]: freed %llu bytes at %llx\\n\", device, (unsigned long long) size, ptr);\n#endif\n\n        pool_used -= size;\n\n        // all deallocations must be in reverse order of the allocations\n        GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));\n    }\n};\n#endif // defined(GGML_USE_VMM)\n\nstd::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int                  device,\n                                                                               [[maybe_unused]] int stream_no) {\n#if defined(GGML_USE_VMM)\n    if (ggml_cuda_info().devices[device].vmm) {\n        return 
std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));\n    }\n#endif // defined(GGML_USE_VMM)\n    return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));\n}\n\n// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error\n// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured\n\nstatic std::mutex ggml_cuda_lock;\nstatic std::condition_variable ggml_cuda_lock_cv;\nstatic std::atomic<int> ggml_cuda_lock_counter;\n\nggml_backend_cuda_context::~ggml_backend_cuda_context() {\n    std::unique_lock<std::mutex> lock(ggml_cuda_lock);\n    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });\n\n    if (copy_event != nullptr) {\n        CUDA_CHECK(cudaEventDestroy(copy_event));\n    }\n    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {\n        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {\n            if (streams[i][j] != nullptr) {\n                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));\n            }\n        }\n        if (cublas_handles[i] != nullptr) {\n            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));\n        }\n    }\n}\n\n\n// cuda buffer\n\nstruct ggml_backend_cuda_buffer_context {\n    int device;\n    void * dev_ptr = nullptr;\n    std::string name;\n\n    ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :\n        device(device), dev_ptr(dev_ptr),\n        name(GGML_CUDA_NAME + std::to_string(device)) {\n    }\n\n    ~ggml_backend_cuda_buffer_context() {\n        CUDA_CHECK(cudaFree(dev_ptr));\n    }\n};\n\nstatic void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;\n    delete ctx;\n}\n\nstatic bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {\n    return buffer->iface.free_buffer == 
ggml_backend_cuda_buffer_free_buffer;\n}\n\nstatic void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {\n    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;\n    return ctx->dev_ptr;\n}\n\nstatic enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;\n\n    if (tensor->view_src != NULL) {\n        assert(tensor->view_src->buffer->buft == buffer->buft);\n        return GGML_STATUS_SUCCESS;\n    }\n\n    if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {\n        // initialize padding to 0 to avoid possible NaN values\n        const size_t original_size = ggml_nbytes(tensor);\n        const size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);\n\n        if (padded_size > original_size) {\n            ggml_cuda_set_device(ctx->device);\n            CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));\n        }\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {\n    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;\n\n    ggml_cuda_set_device(ctx->device);\n    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));\n    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));\n}\n\nstatic void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;\n\n    ggml_cuda_set_device(ctx->device);\n    
CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));\n    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));\n}\n\nstatic void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;\n\n    ggml_cuda_set_device(ctx->device);\n    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));\n    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));\n}\n\nstatic bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {\n    if (ggml_backend_buffer_is_cuda(src->buffer)) {\n        ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;\n        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;\n        if (src_ctx->device == dst_ctx->device) {\n            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));\n        } else {\n#ifdef GGML_CUDA_NO_PEER_COPY\n            return false;\n#else\n            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));\n#endif\n        }\n        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));\n        return true;\n    }\n    return false;\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;\n\n    ggml_cuda_set_device(ctx->device);\n    CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));\n    
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));\n}\n\nstatic const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_cuda_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,\n    /* .memset_tensor   = */ ggml_backend_cuda_buffer_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_cuda_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_cuda_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\n// cuda buffer type\nstruct ggml_backend_cuda_buffer_type_context {\n    int device;\n    std::string name;\n};\n\nstatic const char * ggml_backend_cuda_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;\n\n    return ctx->name.c_str();\n}\n\nstatic bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {\n    return buft->iface.get_name == ggml_backend_cuda_buffer_type_get_name;\n}\n\nstatic ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;\n\n    ggml_cuda_set_device(buft_ctx->device);\n\n    void * dev_ptr;\n    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);\n    if (err != cudaSuccess) {\n        // clear the error\n        (void)cudaGetLastError();\n        GGML_LOG_ERROR(\"%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\\n\", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));\n        return nullptr;\n    }\n\n    ggml_backend_cuda_buffer_context * ctx = new 
ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);\n\n    return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);\n}\n\nstatic size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 128;\n\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    size_t size = ggml_nbytes(tensor);\n    int64_t ne0 = tensor->ne[0];\n\n    if (ggml_is_quantized(tensor->type)) {\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor));\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n    }\n\n    return size;\n\n    GGML_UNUSED(buft);\n}\n\nstatic const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_cuda_buffer_type_get_name,\n    /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,\n    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX\n    /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,\n    /* .is_host          = */ NULL,\n};\n\nggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    if (device >= ggml_backend_cuda_get_device_count()) {\n        return nullptr;\n    }\n\n    static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];\n\n    static bool ggml_backend_cuda_buffer_type_initialized = false;\n\n    if (!ggml_backend_cuda_buffer_type_initialized) {\n        for (int i = 0; i < ggml_backend_cuda_get_device_count(); i++) {\n            ggml_backend_cuda_buffer_types[i] = {\n                /* .iface    = */ ggml_backend_cuda_buffer_type_interface,\n                /* 
.device   = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), i),\n                /* .context  = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},\n            };\n        }\n        ggml_backend_cuda_buffer_type_initialized = true;\n    }\n\n    return &ggml_backend_cuda_buffer_types[device];\n}\n\n// cuda split buffer\n\nstatic int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {\n    int64_t row_rounding = 0;\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {\n            continue;\n        }\n\n        const int cc = ggml_cuda_info().devices[id].cc;\n        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));\n    }\n    return row_rounding;\n}\n\nstatic void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {\n    const int64_t nrows = ggml_nrows(tensor);\n    const int64_t rounding = get_row_rounding(tensor_split);\n\n    *row_low = id == 0 ? 
0 : nrows*tensor_split[id];\n    *row_low -= *row_low % rounding;\n\n    if (id == ggml_backend_cuda_get_device_count() - 1) {\n        *row_high = nrows;\n    } else {\n        *row_high = nrows*tensor_split[id + 1];\n        *row_high -= *row_high % rounding;\n    }\n}\n\nstatic size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);\n}\n\nstruct ggml_backend_cuda_split_buffer_type_context {\n    int main_device;\n    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;\n    std::string name;\n};\n\nstruct ggml_backend_cuda_split_buffer_context {\n    ~ggml_backend_cuda_split_buffer_context() {\n        for (ggml_tensor_extra_gpu * extra : tensor_extras) {\n            for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {\n                for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {\n                    if (extra->events[id][is] != nullptr) {\n                        CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));\n                    }\n                }\n                if (extra->data_device[id] != nullptr) {\n                    CUDA_CHECK(cudaFree(extra->data_device[id]));\n                }\n            }\n            delete extra;\n        }\n    }\n\n    std::vector<ggml_tensor_extra_gpu *> tensor_extras;\n};\n\n\nstatic void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;\n    delete ctx;\n}\n\nstatic void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {\n    // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced\n    return (void *)0x1000;\n\n    GGML_UNUSED(buffer);\n}\n\nstatic enum ggml_status 
ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported\n    GGML_ASSERT(ggml_is_contiguous(tensor) && \"split buffers only supported for contiguous tensors\");\n\n    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;\n    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;\n\n    const int64_t ne0 = tensor->ne[0];\n\n    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};\n    ctx->tensor_extras.push_back(extra);\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        int64_t row_low, row_high;\n        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);\n\n        int64_t nrows_split = row_high - row_low;\n        if (nrows_split == 0) {\n            continue;\n        }\n\n        size_t size = ggml_nbytes_split(tensor, nrows_split);\n        const size_t original_size = size;\n\n        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n\n        // FIXME: do not crash if cudaMalloc fails\n        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first\n        ggml_cuda_set_device(id);\n        char * buf;\n        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));\n\n        // set padding to 0 to avoid possible NaN values\n        if (size > original_size) {\n            CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));\n        }\n\n        extra->data_device[id] = buf;\n\n        for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {\n            
CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));\n        }\n    }\n    tensor->extra = extra;\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    // split tensors must always be set in their entirety at once\n    GGML_ASSERT(offset == 0);\n    GGML_ASSERT(size == ggml_nbytes(tensor));\n    GGML_ASSERT(ggml_is_contiguous(tensor) && \"split buffers only supported for contiguous tensors\");\n\n    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;\n\n    const int64_t ne0 = tensor->ne[0];\n    const size_t nb1 = tensor->nb[1];\n    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        int64_t row_low, row_high;\n        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);\n\n        int64_t nrows_split = row_high - row_low;\n        if (nrows_split == 0) {\n            continue;\n        }\n\n        const size_t offset_split = row_low*nb1;\n        size_t size = ggml_nbytes_split(tensor, nrows_split);\n        const size_t original_size = size;\n\n        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n\n        const char * buf_host = (const char *)data + offset_split;\n        CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));\n    }\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));\n    }\n}\n\nstatic void 
ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    // split tensors must always be set in their entirety at once\n    GGML_ASSERT(offset == 0);\n    GGML_ASSERT(size == ggml_nbytes(tensor));\n    GGML_ASSERT(ggml_is_contiguous(tensor) && \"split buffers only supported for contiguous tensors\");\n\n    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;\n\n    const int64_t ne0 = tensor->ne[0];\n    const size_t nb1 = tensor->nb[1];\n    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        int64_t row_low, row_high;\n        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);\n\n        int64_t nrows_split = row_high - row_low;\n        if (nrows_split == 0) {\n            continue;\n        }\n\n        const size_t offset_split = row_low*nb1;\n        size_t size = ggml_nbytes_split(tensor, nrows_split);\n        const size_t original_size = size;\n\n        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n\n        char * buf_host = (char *)data + offset_split;\n        CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));\n    }\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));\n    }\n}\n\nstatic void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    GGML_UNUSED(buffer);\n    GGML_UNUSED(value);\n}\n\nstatic const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {\n    /* 
.free_buffer     = */ ggml_backend_cuda_split_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_cuda_split_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_cuda_split_buffer_init_tensor,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ ggml_backend_cuda_split_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_cuda_split_buffer_get_tensor,\n    /* .cpy_tensor      = */ NULL,\n    /* .clear           = */ ggml_backend_cuda_split_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\n// cuda split buffer type\n\nstatic const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;\n\n    return ctx->name.c_str();\n}\n\nstatic bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {\n    return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_get_name;\n}\n\nstatic ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point\n    // instead, we allocate them for each tensor separately in init_tensor\n    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,\n    // as returned by get_alloc_size. 
this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.\n    ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context();\n\n    return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size);\n}\n\nstatic size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 128;\n\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;\n    GGML_ASSERT(ggml_is_contiguous(tensor) && \"split buffers only supported for contiguous tensors\");\n\n    size_t total_size = 0;\n\n    const int64_t ne0 = tensor->ne[0];\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        int64_t row_low, row_high;\n        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);\n\n        int64_t nrows_split = row_high - row_low;\n        if (nrows_split == 0) {\n            continue;\n        }\n\n        total_size += ggml_nbytes_split(tensor, nrows_split);\n\n        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n    }\n\n    return total_size;\n}\n\nstatic bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {\n    return false;\n\n    GGML_UNUSED(buft);\n}\n\nstatic const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_cuda_split_buffer_type_get_name,\n    /* .alloc_buffer     = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ 
ggml_backend_cuda_split_buffer_type_get_alignment,\n    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX\n    /* .get_alloc_size   = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,\n    /* .is_host          = */ ggml_backend_cuda_split_buffer_type_is_host,\n};\n\nggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    static std::map<std::pair<int, std::array<float, GGML_CUDA_MAX_DEVICES>>, struct ggml_backend_buffer_type> buft_map;\n\n    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};\n\n    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });\n    if (all_zero) {\n        tensor_split_arr = ggml_cuda_info().default_tensor_split;\n    } else {\n        float split_sum = 0.0f;\n        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {\n            tensor_split_arr[i] = split_sum;\n            split_sum += tensor_split[i];\n        }\n        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {\n            tensor_split_arr[i] /= split_sum;\n        }\n    }\n\n    auto it = buft_map.find({main_device, tensor_split_arr});\n    if (it != buft_map.end()) {\n        return &it->second;\n    }\n    auto * ctx = new ggml_backend_cuda_split_buffer_type_context{\n        main_device,\n        tensor_split_arr,\n        GGML_CUDA_NAME + std::to_string(main_device) + \"_Split\",\n    };\n\n    struct ggml_backend_buffer_type buft {\n        /* .iface   = */ ggml_backend_cuda_split_buffer_type_interface,\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device),\n        /* .context = */ ctx,\n    };\n\n    auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);\n    return &result.first->second;\n}\n\n// host buffer type\n\nstatic const 
char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {\n    return GGML_CUDA_NAME \"_Host\";\n\n    GGML_UNUSED(buft);\n}\n\nstatic bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {\n    return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;\n}\n\nstatic void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    CUDA_CHECK(cudaFreeHost(buffer->context));\n}\n\nstatic void * ggml_cuda_host_malloc(size_t size) {\n    if (getenv(\"GGML_CUDA_NO_PINNED\") != nullptr) {\n        return nullptr;\n    }\n\n    void * ptr = nullptr;\n    cudaError_t err = cudaMallocHost((void **) &ptr, size);\n    if (err != cudaSuccess) {\n        // clear the error\n        (void)cudaGetLastError();\n        GGML_LOG_DEBUG(\"%s: failed to allocate %.2f MiB of pinned memory: %s\\n\", __func__,\n                           size / 1024.0 / 1024.0, cudaGetErrorString(err));\n        return nullptr;\n    }\n\n    return ptr;\n}\n\nstatic ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    void * ptr = ggml_cuda_host_malloc(size);\n\n    if (ptr == nullptr) {\n        // fallback to cpu buffer\n        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);\n    }\n\n    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);\n    buffer->buft = buft;\n    buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;\n\n    return buffer;\n}\n\nggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {\n    static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {\n        /* .iface    = */ {\n            /* .get_name         = */ ggml_backend_cuda_host_buffer_type_name,\n            /* .alloc_buffer     = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,\n            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,\n            
/* .get_max_size     = */ NULL, // defaults to SIZE_MAX\n            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,\n            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,\n        },\n        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0),\n        /* .context  = */ nullptr,\n    };\n\n    return &ggml_backend_cuda_buffer_type_host;\n}\n\n//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {\n//    return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;\n//}\n\n/// kernels\n\ntypedef void (*ggml_cuda_op_mul_mat_t)(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,\n    const int64_t src1_padded_row_size, cudaStream_t stream);\n\n#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE\n#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128\n#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE\n\n#define MUL_MAT_SRC1_COL_STRIDE 128\n\nstatic cudaError_t ggml_cuda_cpy_tensor_2d(\n    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {\n\n    const char * src_ptr = (const char *) src->data;\n    char       * dst_ptr = (char       *) dst;\n\n    const int64_t ne0 = src->ne[0];\n    const int64_t nb0 = src->nb[0];\n    const int64_t nb1 = src->nb[1];\n    const int64_t nb2 = src->nb[2];\n    const int64_t nb3 = src->nb[3];\n    const enum ggml_type type = src->type;\n    const int64_t ts = ggml_type_size(type);\n    const int64_t bs = ggml_blck_size(type);\n    const int64_t i1_diff = i1_high - i1_low;\n\n    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;\n    if (nb0 == ts && nb1 == ts*ne0/bs) {\n        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, 
cudaMemcpyDeviceToDevice, stream);\n    } else if (nb0 == ts) {\n        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);\n    } else {\n        for (int64_t i1 = 0; i1 < i1_diff; i1++) {\n            const void * rx = (const void *) ((const char *) x + i1*nb1);\n            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);\n            // pretend the row is a matrix with cols=1\n            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);\n            if (r != cudaSuccess) {\n                return r;\n            }\n        }\n        return cudaSuccess;\n    }\n}\n\nstruct cublas_force_compute_type {\n    bool fp32 = false;\n    bool fp16 = false;\n};\n\nstatic const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() {\n    static const cublas_force_compute_type compute_type = [] {\n        cublas_force_compute_type result;\n\n        const bool ggml_cuda_force_cublas_compute_32f_env = getenv(\"GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\") != nullptr;\n        const bool ggml_cuda_force_cublas_compute_16f_env = getenv(\"GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\") != nullptr;\n\n        GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false);\n\n        if (ggml_cuda_force_cublas_compute_32f_env) {\n            GGML_LOG_INFO(\"Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\\n\");\n            result.fp32 = true;\n        } else if (ggml_cuda_force_cublas_compute_16f_env) {\n            GGML_LOG_INFO(\"Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\\n\");\n            result.fp16 = true;\n        }\n\n        return result;\n    }();\n\n    return compute_type;\n}\n\nstatic void ggml_cuda_op_mul_mat_cublas(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n    const char * src1_ddq_i, float 
* dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,\n    const int64_t src1_padded_row_size, cudaStream_t stream) {\n\n    GGML_ASSERT(src0_dd_i  != nullptr);\n    GGML_ASSERT(src1_ddf_i != nullptr);\n    GGML_ASSERT(dst_dd_i   != nullptr);\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne10 = src1->ne[0];\n\n    const int64_t ne0 = dst->ne[0];\n\n    const int64_t row_diff = row_high - row_low;\n\n    int id = ggml_cuda_get_device();\n\n    // the main device has a larger memory buffer to hold the results from all GPUs\n    // ldc == nrows of the matrix that cuBLAS writes into\n    int64_t ldc = id == ctx.device ? ne0 : row_diff;\n\n    const int cc = ggml_cuda_info().devices[id].cc;\n\n    const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||\n        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);\n\n    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;\n\n    if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {\n        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));\n        if (src1->type != GGML_TYPE_BF16) {\n            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);\n            GGML_ASSERT(to_bf16_cuda != nullptr);\n            size_t ne = src1_ncols*ne10;\n            src1_as_bf16.alloc(ne);\n            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);\n        }\n        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? 
(const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();\n        const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;\n        ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);\n\n        const float alpha_f32 = 1.0f;\n        const float beta_f32  = 0.0f;\n\n        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));\n        CUBLAS_CHECK(\n            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,\n                    row_diff, src1_ncols, ne10,\n                    &alpha_f32,  src0_ptr,       CUDA_R_16BF, ne00,\n                                 src1_ptr,       CUDA_R_16BF, ne10,\n                    &beta_f32,   dst_bf16.get(), CUDA_R_16BF, ldc,\n                    CUBLAS_COMPUTE_32F,\n                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);\n        to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);\n    } else if (fast_fp16_hardware_available(cc) && use_fp16) {\n        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32\n        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));\n        if (src0->type != GGML_TYPE_F16) {\n            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);\n            GGML_ASSERT(to_fp16_cuda != nullptr);\n            size_t ne = row_diff*ne00;\n            src0_as_f16.alloc(ne);\n            to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);\n        }\n        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? 
(const half *) src0_dd_i : src0_as_f16.get();\n\n        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));\n        if (src1->type != GGML_TYPE_F16) {\n            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);\n            GGML_ASSERT(to_fp16_cuda != nullptr);\n            size_t ne = src1_ncols*ne10;\n            src1_as_f16.alloc(ne);\n            to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);\n        }\n        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();\n\n        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));\n\n        const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();\n\n        if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)\n                                        || GGML_CUDA_CC_IS_RDNA4(cc)\n                                        || cc == GGML_CUDA_CC_VOLTA\n                                        || force_compute_type.fp32))\n        {\n            const float alpha = 1.0f;\n            const float beta = 0.0f;\n            CUBLAS_CHECK(\n                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,\n                        row_diff, src1_ncols, ne10,\n                        &alpha, src0_ptr,  CUDA_R_16F, ne00,\n                                src1_ptr,  CUDA_R_16F, ne10,\n                        &beta,   dst_dd_i, CUDA_R_32F, ldc,\n                        CUBLAS_COMPUTE_32F,\n                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n        } else {\n            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);\n\n            const half alpha_f16 = 1.0f;\n            const half beta_f16 = 0.0f;\n\n            CUBLAS_CHECK(\n                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,\n                        row_diff, src1_ncols, ne10,\n                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,\n                                    src1_ptr,   
   CUDA_R_16F, ne10,\n                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,\n                        CUBLAS_COMPUTE_16F,\n                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);\n            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);\n        }\n    } else {\n        ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));\n        ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));\n\n        if (src0->type != GGML_TYPE_F32) {\n            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);\n            GGML_ASSERT(to_fp32_cuda != nullptr);\n            src0_ddq_as_f32.alloc(row_diff*ne00);\n            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);\n        }\n        if (src1->type != GGML_TYPE_F32) {\n            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);\n            GGML_ASSERT(to_fp32_cuda != nullptr);\n            src1_ddq_as_f32.alloc(src1_ncols*ne10);\n            to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);\n        }\n\n        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();\n        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? 
(const float *) src1_ddf_i : src1_ddq_as_f32.get();\n\n        const float alpha = 1.0f;\n        const float beta = 0.0f;\n\n        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));\n        CUBLAS_CHECK(\n            cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,\n                    row_diff, src1_ncols, ne10,\n                    &alpha, src0_ddf_i,  ne00,\n                            src1_ddf1_i, ne10,\n                    &beta,  dst_dd_i,    ldc));\n    }\n\n    GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);\n}\n\nstatic void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {\n    static bool peer_access_enabled = false;\n\n    const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;\n\n    if (peer_access_enabled == enable_peer_access) {\n        return;\n    }\n\n#ifdef NDEBUG\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        ggml_cuda_set_device(id);\n        CUDA_CHECK(cudaDeviceSynchronize());\n    }\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        ggml_cuda_set_device(id);\n\n        for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {\n            if (id == id_other) {\n                continue;\n            }\n            if (id != main_device && id_other != main_device) {\n                continue;\n            }\n\n            int can_access_peer;\n            CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));\n            if (can_access_peer) {\n                if (enable_peer_access) {\n                    cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);\n                    if (err != cudaErrorPeerAccessAlreadyEnabled) {\n                        CUDA_CHECK(err);\n                    } else {\n                        // reset the error\n                        (void)cudaGetLastError();\n                    }\n                } else {\n                   
 cudaError_t err = cudaDeviceDisablePeerAccess(id_other);\n                    if (err != cudaErrorPeerAccessNotEnabled) {\n                        CUDA_CHECK(err);\n                    } else {\n                        // reset the error\n                        (void)cudaGetLastError();\n                    }\n                }\n            }\n        }\n    }\n\n    ggml_cuda_set_device(main_device);\n#endif // NDEBUG\n\n    peer_access_enabled = enable_peer_access;\n\n    GGML_UNUSED(main_device);\n}\n\nstatic cudaError_t ggml_cuda_Memcpy2DPeerAsync(\n    void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {\n\n#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)\n    // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices\n    cudaMemcpy3DPeerParms p = {};\n    p.dstDevice = dstDevice;\n    p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);\n    p.srcDevice = srcDevice;\n    p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);\n    p.extent = make_cudaExtent(width, height, 1);\n    return cudaMemcpy3DPeerAsync(&p, stream);\n#else\n    // HIP does not support cudaMemcpy3DPeerAsync or vmm pools\n    GGML_UNUSED(dstDevice);\n    GGML_UNUSED(srcDevice);\n    return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);\n#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)\n}\n\nstatic void ggml_cuda_op_mul_mat(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,\n    quantize_cuda_t quantize_src1) {\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne01 = src0->ne[1];\n    const int64_t ne02 = src0->ne[2];\n    const int64_t ne03 = src0->ne[3];\n\n    const int64_t ne10 = src1->ne[0];\n    const int64_t ne11 = src1->ne[1];\n    const int64_t ne12 = src1->ne[2];\n    const int64_t ne13 = 
src1->ne[3];\n    const int64_t nrows1 = ggml_nrows(src1);\n\n    const int64_t ne0 = dst->ne[0];\n    const int64_t ne1 = dst->ne[1];\n\n    // const int64_t nb10 = src1->nb[0];\n    const int64_t nb11 = src1->nb[1];\n    const int64_t nb12 = src1->nb[2];\n    const int64_t nb13 = src1->nb[3];\n\n    const int64_t nb2 = dst->nb[2];\n    const int64_t nb3 = dst->nb[3];\n\n    ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;\n    ggml_backend_cuda_buffer_context * dst_ctx  = (ggml_backend_cuda_buffer_context *) dst->buffer->context;\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));\n\n    GGML_ASSERT(ne12 % ne02 == 0);\n    GGML_ASSERT(ne13 % ne03 == 0);\n\n    const int64_t i02_divisor = ne12 / ne02;\n    const int64_t i03_divisor = ne13 / ne03;\n\n    const size_t src0_ts = ggml_type_size(src0->type);\n    const size_t src0_bs = ggml_blck_size(src0->type);\n    const size_t q8_1_ts = sizeof(block_q8_1);\n    const size_t q8_1_bs = QK8_1;\n\n    const bool src0_is_contiguous = ggml_is_contiguous(src0);\n    const bool src1_is_contiguous = ggml_is_contiguous(src1);\n\n    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);\n\n    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);\n    GGML_ASSERT(!(split && ne02 > 1));\n    GGML_ASSERT(!(split && ne03 > 1));\n    GGML_ASSERT(!(split && ne02 < ne12));\n    GGML_ASSERT(!(split && ne03 < ne13));\n\n    ggml_tensor_extra_gpu * src0_extra = split ? 
(ggml_tensor_extra_gpu *) src0->extra : nullptr;\n\n\n    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;\n    if (split) {\n        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;\n        tensor_split = buft_ctx->tensor_split;\n    }\n\n    struct dev_data {\n        int cc;\n\n        ggml_cuda_pool_alloc<char>   src0_dd_alloc;\n        ggml_cuda_pool_alloc<float> src1_ddf_alloc;\n        ggml_cuda_pool_alloc<char>  src1_ddq_alloc;\n        ggml_cuda_pool_alloc<float>   dst_dd_alloc;\n\n        char  *  src0_dd = nullptr;\n        float * src1_ddf = nullptr; // float\n        char  * src1_ddq = nullptr; // q8_1\n        float *   dst_dd = nullptr;\n\n        int64_t  row_low;\n        int64_t row_high;\n    };\n\n    dev_data dev[GGML_CUDA_MAX_DEVICES];\n\n    int used_devices = 0;\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        dev[id].cc = ggml_cuda_info().devices[id].cc;\n\n        // by default, use all rows\n        dev[id].row_low  = 0;\n        dev[id].row_high = ne01;\n\n        // for multi GPU, get the row boundaries from tensor split\n        // and round to mul_mat_q tile sizes\n        if (split) {\n            const int64_t rounding = get_row_rounding(tensor_split);\n\n            if (id != 0) {\n                dev[id].row_low  = ne01*tensor_split[id];\n                if (dev[id].row_low < ne01) {\n                    dev[id].row_low -= dev[id].row_low % rounding;\n                }\n            }\n\n            if (id != ggml_backend_cuda_get_device_count() - 1) {\n                dev[id].row_high  = ne01*tensor_split[id + 1];\n                if (dev[id].row_high < ne01) {\n                    dev[id].row_high -= dev[id].row_high % rounding;\n                }\n            }\n        }\n    }\n\n    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n        if ((!split && id != ctx.device) || 
dev[id].row_low == dev[id].row_high) {\n            continue;\n        }\n\n        used_devices++;\n\n        const bool src1_on_device = id == src1_ctx->device;\n        const bool  dst_on_device = id == dst_ctx->device;\n\n        ggml_cuda_set_device(id);\n        cudaStream_t stream = ctx.stream(id, 0);\n\n        if (src0_is_contiguous) {\n            dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;\n        } else {\n            // If src0 is not contiguous it will be copied to a temporary buffer.\n            // This buffer needs to be cleared entirely because multiple regions will function as padding.\n            const size_t nbytes_data    = ggml_nbytes(src0);\n            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);\n            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);\n            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));\n        }\n\n        // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:\n        if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {\n            GGML_ASSERT(ggml_is_contiguously_allocated(src0));\n            GGML_ASSERT(!src0->view_src);\n            const size_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);\n            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);\n            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));\n        }\n\n        if (src1_on_device && src1_is_contiguous) {\n            dev[id].src1_ddf = (float *) src1->data;\n        } else {\n            dev[id].src1_ddf = 
dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));\n        }\n\n        if (quantize_src1) {\n            size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;\n            if (quantize_src1 == quantize_mmq_q8_1_cuda) {\n                src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);\n            }\n            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);\n\n            if (src1_on_device && src1_is_contiguous) {\n                quantize_src1(\n                    dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10,\n                    nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),\n                    src1_padded_col_size, ne11, ne12, ne13, stream);\n                CUDA_CHECK(cudaGetLastError());\n            }\n        }\n\n        if (dst_on_device) {\n            dev[id].dst_dd = (float *) dst->data;\n        } else {\n            const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);\n            dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf);\n        }\n    }\n\n    // if multiple devices are used they need to wait for the main device\n    // here an event is recorded that signals that the main device has finished calculating the input data\n    if (split && used_devices > 1) {\n        ggml_cuda_set_device(ctx.device);\n        CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream()));\n    }\n\n    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;\n    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {\n        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;\n        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? 
ne11 - src1_col_0 : src1_col_stride;\n\n        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n            if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {\n                continue;\n            }\n\n            const bool src1_on_device = id == src1_ctx->device;\n            const bool  dst_on_device = id == dst_ctx->device;\n            const int64_t row_diff = dev[id].row_high - dev[id].row_low;\n\n            ggml_cuda_set_device(id);\n            cudaStream_t stream = ctx.stream(id, is);\n\n            // wait for main GPU data if necessary\n            if (split && (id != ctx.device || is != 0)) {\n                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0));\n            }\n\n            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {\n                const int64_t i03 = i0 / ne12;\n                const int64_t i02 = i0 % ne12;\n\n                size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;\n                if (quantize_src1 == quantize_mmq_q8_1_cuda) {\n                    src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);\n                } else {\n                    src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;\n                }\n\n                // for split tensors the data begins at i0 == i0_offset_low\n                const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;\n                char  *  src0_dd_i =  dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;\n                float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;\n                char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;\n                float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? 
ne0 : row_diff);\n\n                // the main device memory buffer can be on VRAM scratch, with space for all partial results\n                // in that case an offset on dst_ddf_i is needed\n                if (id == ctx.device) {\n                    dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split\n                }\n\n                // copy src0, src1 to device if necessary\n                if (src1_is_contiguous) {\n                    if (id != ctx.device) {\n                        if (quantize_src1) {\n                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;\n                            if (quantize_src1 == quantize_mmq_q8_1_cuda) {\n                                const size_t pitch = ne11*sizeof(block_q8_1_mmq);\n                                const size_t width = src1_ncols*sizeof(block_q8_1_mmq);\n                                const size_t height = src1_padded_col_size/(4*QK8_1);\n                                CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));\n                            } else {\n                                CUDA_CHECK(cudaMemcpyPeerAsync(\n                                    src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));\n                            }\n                        } else {\n                            float * src1_ddf_i_source = (float *) src1->data;\n                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;\n                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device,\n                                                            src1_ncols*ne10*sizeof(float), stream));\n                        }\n                    }\n                } else if (src1_on_device && !src1_is_contiguous) {\n                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(\n               
                 src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));\n                } else {\n                    GGML_ABORT(\"fatal error\");\n                }\n\n                if (quantize_src1 && !src1_is_contiguous) {\n                    quantize_src1(\n                        src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,\n                        src1_padded_col_size, src1_ncols, 1, 1, stream);\n                    CUDA_CHECK(cudaGetLastError());\n                }\n\n                if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {\n                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(\n                        src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));\n                }\n\n                // do the computation\n                op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,\n                    dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);\n                CUDA_CHECK(cudaGetLastError());\n\n                // copy dst to host or other device if necessary\n                if (!dst_on_device) {\n                    void * dst_off_device = dst->data;\n                    if (split) {\n                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.\n                        // dst is NOT transposed.\n                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.\n                        // Instead they need to be copied to the correct slice in ne0 = dst row index.\n                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.\n                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);\n                        GGML_ASSERT(dst->nb[1] == 
ne0*sizeof(float));\n                        dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;\n                        CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(\n                            dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));\n                    } else {\n                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);\n                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));\n                        dhf_dst_i += src1_col_0*ne0;\n                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));\n                    }\n                }\n\n                // add event for the main device to wait on until other device is done\n                if (split && (id != ctx.device || is != 0)) {\n                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));\n                }\n            }\n        }\n    }\n\n    // main device waits for all other devices to be finished\n    if (split && ggml_backend_cuda_get_device_count() > 1) {\n        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;\n        is_max = is_max <= GGML_CUDA_MAX_STREAMS ? 
is_max : GGML_CUDA_MAX_STREAMS;\n\n        ggml_cuda_set_device(ctx.device);\n        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n            if (dev[id].row_low == dev[id].row_high) {\n                continue;\n            }\n            for (int64_t is = 0; is < is_max; ++is) {\n                CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0));\n            }\n        }\n    }\n}\n\nstatic __global__ void k_compute_batched_ptrs(\n        const void * src0_as_f16, const void * src1_as_f16, char * dst,\n        const void ** ptrs_src, void ** ptrs_dst,\n        int64_t ne12, int64_t ne13,\n        int64_t ne23,\n        size_t  nb02, size_t  nb03,\n        size_t  nb12, size_t  nb13,\n        size_t  nbd2, size_t  nbd3,\n        int64_t r2,   int64_t r3) {\n    const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;\n    const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;\n\n    if (i13 >= ne13 || i12 >= ne12) {\n        return;\n    }\n\n    const int64_t i03 = i13 / r3;\n    const int64_t i02 = i12 / r2;\n\n    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;\n    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;\n    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;\n}\n\n// Type traits for mapping ggml types to CUDA/cuBLAS types\ntemplate<ggml_type T>\nstruct batched_mul_mat_traits;\n\ntemplate<>\nstruct batched_mul_mat_traits<GGML_TYPE_F32> {\n    using cuda_type = float;\n    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;\n    static inline const cudaDataType_t data_type = CUDA_R_32F;\n    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;\n    static inline const float alpha = 1.0f;\n    static inline const float beta = 0.0f;\n    static inline const void* get_alpha() { static const float val = alpha; return &val; }\n    static inline const void* 
get_beta() { static const float val = beta; return &val; }\n    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }\n};\n\ntemplate<>\nstruct batched_mul_mat_traits<GGML_TYPE_BF16> {\n    using cuda_type = nv_bfloat16;\n    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;\n    static inline const cudaDataType_t data_type = CUDA_R_16BF;\n    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;\n    static inline const float alpha = 1.0f;\n    static inline const float beta = 0.0f;\n    static inline const void* get_alpha() { static const float val = alpha; return &val; }\n    static inline const void* get_beta() { static const float val = beta; return &val; }\n    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }\n};\n\ntemplate<>\nstruct batched_mul_mat_traits<GGML_TYPE_F16> {\n    using cuda_type = half;\n    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;\n    static inline const cudaDataType_t data_type = CUDA_R_16F;\n    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;\n    static inline const half alpha = 1.0;\n    static inline const half beta = 0.0;\n    static inline const void* get_alpha() { static const half val = alpha; return &val; }\n    static inline const void* get_beta() { static const half val = beta; return &val; }\n    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }\n};\n\ntemplate<ggml_type src0_type>\nstatic void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    using traits = batched_mul_mat_traits<src0_type>;\n    using cuda_t = typename traits::cuda_type;\n\n    GGML_ASSERT(!ggml_is_transposed(src0));\n    GGML_ASSERT(!ggml_is_transposed(src1));\n    
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));\n    GGML_ASSERT(src0->type == src0_type);\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.\n    // As long as dst is contiguous this does not matter though.\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int64_t ne_dst = ggml_nelements(dst);\n    cudaStream_t main_stream = ctx.stream();\n    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));\n\n    float * dst_ddf = (float *) dst->data;\n    const size_t ts_src1 = ggml_type_size(src1->type);\n    GGML_ASSERT(nb10 == ts_src1);\n    int64_t s11 = nb11 / ts_src1;\n    int64_t s12 = nb12 / ts_src1;\n    int64_t s13 = nb13 / ts_src1;\n\n    const cuda_t * src0_ptr = nullptr;\n    const cuda_t * src1_ptr = nullptr;\n\n    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());\n    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());\n\n    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);\n    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);\n\n    // Handle src0\n    src0_ptr = (const cuda_t *) src0->data;\n\n    // Handle src1 - convert if necessary\n    if (src1->type == src0_type) {\n        src1_ptr = (const cuda_t *) src1->data;\n    } else {\n        // Convert src1 to target type using traits conversion functions\n        const int64_t ne_src1 = ggml_nelements(src1);\n        src1_alloc.alloc(ne_src1);\n\n        const auto convert_func = traits::get_nc_converter(src1->type);\n        GGML_ASSERT(convert_func != nullptr);\n        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);\n        src1_ptr = src1_alloc.get();\n        s11 = ne10;\n        s12 = ne11*s11;\n        s13 = ne12*s12;\n\n        is_src1_cont_2 = true;\n    }\n\n    // Setup destination buffer\n    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());\n    char * dst_t;\n    size_t nbd2 = dst->nb[2];\n    size_t nbd3 = 
dst->nb[3];\n\n    cublasComputeType_t cu_compute_type = traits::compute_type;\n    cudaDataType_t cu_data_type = traits::data_type;\n    cudaDataType_t cu_data_type_a = traits::data_type;\n    cudaDataType_t cu_data_type_b = traits::data_type;\n    const void * alpha = traits::get_alpha();\n    const void * beta = traits::get_beta();\n\n    const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();\n\n    int id = ggml_cuda_get_device();\n    const int cc = ggml_cuda_info().devices[id].cc;\n    static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16;\n\n    // bf16 and fp32 are already being computed in fp32 (ensure it using static_assert),\n    // so checking necessity of forced fp32 only for fp16 src0_type\n    static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F);\n\n    const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)\n                                                                                  || GGML_CUDA_CC_IS_RDNA4(cc)\n                                                                                  || cc == GGML_CUDA_CC_VOLTA\n                                                                                  || force_compute_type.fp32);\n\n    if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) {\n        if constexpr (src0_type == GGML_TYPE_F32) {\n            dst_t = (char *) dst_ddf;  // Direct F32 output\n        } else {\n            dst_t = (char *) dst_temp.alloc(ne_dst);\n            nbd2 /= sizeof(float) / sizeof(cuda_t);\n            nbd3 /= sizeof(float) / sizeof(cuda_t);\n        }\n    } else {\n        dst_t = (char *) dst_ddf;\n        cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type;\n        cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type;\n        alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha();\n        beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta();\n    
}\n\n    GGML_ASSERT(ne12 % ne02 == 0);\n    GGML_ASSERT(ne13 % ne03 == 0);\n\n    // broadcast factors\n    const int64_t r2 = ne12/ne02;\n    const int64_t r3 = ne13/ne03;\n\n    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {\n        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:\n        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;\n        const int64_t smb = ne12 == 1 ? s13       : s12;\n\n        // there is no broadcast and src0, src1 are contiguous across dims 2, 3\n        // use cublasGemmStridedBatchedEx\n        CUBLAS_CHECK(\n        cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,\n                ne01, ne11, ne10,\n                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA\n                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB\n                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0, // strideC\n                ne12*ne13,\n                cu_compute_type,\n                CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n    } else {\n        // use cublasGemmBatchedEx\n        const int64_t ne23 = ne12*ne13;\n\n        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);\n        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);\n\n        size_t src1_stride_size = sizeof(cuda_t);\n\n        const int threads_x = 16;\n        const int threads_y = 16;\n        dim3 block_dims(threads_x, threads_y);\n\n        dim3 grid_dims(\n            (ne13 + threads_x - 1) / threads_x,\n            (ne12 + threads_y - 1) / threads_y\n        );\n        k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(\n                src0_ptr, src1_ptr, dst_t,\n                ptrs_src.get(), ptrs_dst.get(),\n                ne12, ne13,\n                ne23,\n                nb02, nb03,\n                (src1->type == src0_type) ? 
nb12 : s12*src1_stride_size,\n                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,\n                nbd2, nbd3,\n                r2, r3);\n\n        CUDA_CHECK(cudaGetLastError());\n\n        CUBLAS_CHECK(\n        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,\n                ne01, ne11, ne10,\n                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,\n                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,\n                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,\n                ne23,\n                cu_compute_type,\n                CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n    }\n\n    // Convert output back to F32 if needed\n    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {\n        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);\n        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);\n    }\n}\n\nstatic void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);\n            break;\n        case GGML_TYPE_BF16:\n            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);\n            break;\n        case GGML_TYPE_F16:\n            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);\n            break;\n        default:\n            GGML_ABORT(\"Unsupported type\");\n    }\n}\n\nstatic bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,\n                                          const ggml_tensor * ffn_gate,\n                                          const 
ggml_tensor * glu,\n                                          const ggml_tensor * ffn_up_bias = nullptr,\n                                          const ggml_tensor * ffn_gate_bias = nullptr) {\n    const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;\n\n    if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {\n        return false;\n    }\n\n    const bool is_mul_mat     = ffn_up->op == GGML_OP_MUL_MAT     && ffn_gate->op == GGML_OP_MUL_MAT     && glu->op == GGML_OP_GLU;\n    const bool is_mul_mat_id  = ffn_up->op == GGML_OP_MUL_MAT_ID  && ffn_gate->op == GGML_OP_MUL_MAT_ID  && glu->op == GGML_OP_GLU;\n\n    GGML_ASSERT(ffn_up && ffn_gate && glu);\n\n    if (!is_mul_mat && !is_mul_mat_id) {\n        return false;\n    }\n\n    const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;\n\n    if (has_bias) {\n        if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {\n            return false;\n        }\n\n        if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {\n            return false;\n        }\n\n        if (expected_bias_op == GGML_OP_ADD) {\n            const bool up_has_mul   = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;\n            const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;\n            if (!up_has_mul || !gate_has_mul) {\n                return false;\n            }\n        } else { // GGML_OP_ADD_ID\n            if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {\n                return false;\n            }\n            if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {\n                return false;\n            }\n        }\n    } else {\n        if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {\n            return false;\n        }\n    }\n\n    if (ffn_up->src[0]->type != ffn_gate->src[0]->type || 
!ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||\n        !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {\n        return false;\n    }\n\n    if (ffn_up->src[1] != ffn_gate->src[1]) {\n        return false;\n    }\n\n    if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {\n        return false;\n    }\n\n    static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };\n\n    if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {\n        return false;\n    }\n\n    if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {\n        return false;\n    }\n\n    const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||\n                       ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);\n\n    //TODO: add support for fusion for split buffers\n    if (split) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {\n    ggml_tensor *       src0 = tensor->src[0];\n    ggml_tensor *       src1 = tensor->src[1];\n    const ggml_tensor * dst  = tensor;\n\n    const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;\n\n    bool use_mul_mat_vec_f =\n        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&\n        src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;\n\n    const int cc      = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n    use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? 
src1->ne[2] : src1->ne[1]);\n\n    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||\n                       ggml_backend_buft_is_cuda_split(src1->buffer->buft);\n\n    //TODO: add support for fusion for split buffers\n    if (split) {\n        return false;\n    }\n\n    //we only support fusion for ncols_dst = 1\n    if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {\n        return false;\n    }\n\n    if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {\n        return false;\n    }\n\n\n    return use_mul_mat_vec_f;\n}\n\nstatic bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {\n    ggml_tensor *       src0 = tensor->src[0];\n    ggml_tensor *       src1 = tensor->src[1];\n    const ggml_tensor * dst  = tensor;\n\n    const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&\n                                   ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&\n                                   src0->view_src;\n\n    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&\n                             dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;\n\n    // fusion is not universally faster on Pascal\n    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n    if (cc <= GGML_CUDA_CC_PASCAL) {\n        return false;\n    }\n    //we only support fusion for ncols_dst = 1\n    if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {\n        return false;\n    }\n\n    if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {\n        return false;\n    }\n\n\n    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||\n                       ggml_backend_buft_is_cuda_split(src1->buffer->buft);\n\n    //TODO: add support for fusion for split buffers\n    if (split) {\n        return false;\n    }\n\n    return 
use_mul_mat_vec_q;\n}\n\nstatic void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);\n\n    // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.\n    // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data.\n    // Therefore, in such cases use cuBLAS.\n    const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE\n        && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;\n\n    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)\n        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;\n    bool use_mul_mat_f     = !ggml_is_quantized(src0->type)\n        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;\n    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear\n        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32\n        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;\n    bool use_mul_mat_q     = ggml_is_quantized(src0->type) && !bad_padding_clear\n        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;\n\n    bool any_gpus_with_slow_fp16 = false;\n\n    if (split) {\n        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;\n        auto & tensor_split = buft_ctx->tensor_split;\n        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {\n            // skip devices that are not going to do any work:\n            if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? 
tensor_split[id + 1] : 1.0f)) {\n                continue;\n            }\n\n            const int cc            = ggml_cuda_info().devices[id].cc;\n            const int warp_size     = ggml_cuda_info().devices[id].warp_size;\n            use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);\n            use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);\n            use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);\n            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);\n        }\n    } else {\n        const int cc            = ggml_cuda_info().devices[ctx.device].cc;\n        const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size;\n        use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);\n        use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);\n        use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);\n        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);\n    }\n\n    // debug helpers\n    //printf(\"src0: %8d %8d %8d %8d\\n\", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);\n    //printf(\"      %8d %8d %8d %8d\\n\", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);\n    //printf(\"src1: %8d %8d %8d %8d\\n\", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);\n    //printf(\"      %8d %8d %8d %8d\\n\", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);\n    //printf(\"src0 is contiguous %d, transposed %d, type 
= %s, name = %s\\n\", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);\n    //printf(\"src1 is contiguous %d, transposed %d, type = %s, name = %s\\n\", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);\n\n    //TODO update for generic tensor parallelism\n    const int cc                 = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n    bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);\n    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);\n    bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;\n\n    if (!split && use_mul_mat_vec_f) {\n        // the custom F16 vector kernel can be used over batched cuBLAS GEMM\n        // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)\n        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);\n    } else if (!split && use_mul_mat_f) {\n        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);\n    } else if (!split && use_mul_mat_vec_q) {\n        ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);\n    } else if (!split && use_mul_mat_q) {\n        ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);\n    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)\n        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {\n        // general KQ + KQV multi-batch without FlashAttention\n        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);\n    } else if (use_mul_mat_vec_f) {\n        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);\n    } else if (use_mul_mat_vec_q) {\n        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);\n    } else if (use_mul_mat_q) {\n        
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);\n    } else {\n        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);\n    }\n}\n\nstatic void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * ids  = dst->src[2];\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n    GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && \"mul_mat_id does not support split buffers\");\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n\n    // [TAG_MUL_MAT_ID_CUDA_GRAPHS]\n    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n        static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);\n        if (ne2 <= MMVQ_MAX_BATCH_SIZE) {\n            if (ggml_is_quantized(src0->type)) {\n                if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {\n                    ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);\n                    return;\n                }\n            } else {\n                if (GGML_CUDA_CC_IS_AMD(cc)) {\n                    ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);\n                    return;\n                }\n            }\n        }\n\n        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {\n            ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);\n            return;\n        }\n\n        if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {\n            ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);\n            return;\n        }\n    }\n\n    // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization\n    // TODO: add asserts to verify this. 
should work with CUDA, HIP, etc.\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(nb12 % nb11 == 0);\n    GGML_ASSERT(nb2  % nb1  == 0);\n\n    const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc))\n        || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type;\n    const ggml_type type_dst_sorted  = GGML_TYPE_F32;\n    const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted);\n    const size_t ts_dst_sorted  = ggml_type_size(type_dst_sorted);\n\n    const int64_t n_expert_used = ids->ne[0];\n    const int64_t ne_get_rows = ne12 * n_expert_used;\n\n    std::vector<int32_t> ids_to_sorted_host;\n    ids_to_sorted_host.reserve(2*ne_get_rows);\n    std::vector<int32_t> ids_from_sorted_host(ne_get_rows);\n\n    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool(), 2*ne_get_rows);\n\n    std::vector<int32_t> tokens_per_expert(ne02);\n\n    ggml_cuda_pool_alloc<char> src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted);\n    ggml_cuda_pool_alloc<char>  dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted);\n\n    std::vector<char> ids_host(ggml_nbytes(ids));\n    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));\n    CUDA_CHECK(cudaStreamSynchronize(stream));\n\n    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices\n        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens\n            for (int64_t iex = 0; iex < n_expert_used; ++iex) {\n                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);\n                assert(expert_to_use >= 0 && expert_to_use < ne02);\n                if (expert_to_use == i02) {\n                    ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size();\n                    ids_to_sorted_host.push_back(i12*ne11 + iex % ne11);\n                    tokens_per_expert[i02]++;\n                    break;\n      
          }\n            }\n        }\n    }\n    GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows));\n\n    ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end());\n\n    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream));\n    CUDA_CHECK(cudaStreamSynchronize(stream));\n\n    const int32_t * ids_to_sorted   = ids_buf_dev.ptr + 0*ne_get_rows;\n    const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows;\n\n    get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted,\n        ne10, nb11, nb12, nb13,\n        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),\n        ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream);\n    CUDA_CHECK(cudaGetLastError());\n\n    char * src1_data_cur = (char *) src1_sorted.ptr;\n    char *  dst_data_cur = (char *)  dst_sorted.ptr;\n    for (int64_t i02 = 0; i02 < ne02; ++i02) {\n        if (tokens_per_expert[i02] == 0) {\n            continue;\n        }\n\n        ggml_tensor src0_slice = *src0;\n        src0_slice.ne[2]    = 1;\n        src0_slice.nb[3]    = src0_slice.nb[2];\n        src0_slice.op       = GGML_OP_VIEW;\n        src0_slice.view_src = dst->src[0]; // non-const pointer to src0\n        src0_slice.data     = (char *) src0->data + i02*nb02;\n\n        ggml_tensor src1_slice;\n        memset(&src1_slice, 0, sizeof(src1_slice));\n        src1_slice.buffer = src1->buffer;\n        src1_slice.type   = type_src1_sorted;\n        src1_slice.ne[0]  = ne10;\n        src1_slice.ne[1]  = tokens_per_expert[i02];\n        src1_slice.ne[2]  = 1;\n        src1_slice.ne[3]  = 1;\n        src1_slice.nb[0]  = ts_src1_sorted;\n        src1_slice.nb[1]  = src1_slice.ne[0] * src1_slice.nb[0];\n        src1_slice.nb[2]  = src1_slice.ne[1] * src1_slice.nb[1];\n        
src1_slice.nb[3]  = src1_slice.ne[2] * src1_slice.nb[2];\n        src1_slice.data   = src1_data_cur;\n\n        ggml_tensor dst_slice;\n        memset(&dst_slice, 0, sizeof(dst_slice));\n        dst_slice.buffer = dst->buffer;\n        dst_slice.type   = type_dst_sorted;\n        dst_slice.ne[0]  = ne0;\n        dst_slice.ne[1]  = tokens_per_expert[i02];\n        dst_slice.ne[2]  = 1;\n        dst_slice.ne[3]  = 1;\n        dst_slice.nb[0]  = ts_dst_sorted;\n        dst_slice.nb[1]  = dst_slice.ne[0] * dst_slice.nb[0];\n        dst_slice.nb[2]  = dst_slice.ne[1] * dst_slice.nb[1];\n        dst_slice.nb[3]  = dst_slice.ne[2] * dst_slice.nb[2];\n        dst_slice.data   = dst_data_cur;\n\n        ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice);\n        CUDA_CHECK(cudaGetLastError());\n\n        src1_data_cur += src1_slice.nb[2];\n        dst_data_cur  +=  dst_slice.nb[2];\n    }\n\n    get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type,\n        ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted,\n        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),\n        nb1, nb2, nb3, stream);\n}\n\nstatic bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {\n    // why is this here instead of mul_mat?\n    if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {\n        ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);\n    }\n\n    switch (dst->op) {\n        case GGML_OP_ARGMAX:\n            ggml_cuda_argmax(ctx, dst);\n            break;\n        case GGML_OP_COUNT_EQUAL:\n            ggml_cuda_count_equal(ctx, dst);\n            break;\n        case GGML_OP_REPEAT:\n            ggml_cuda_op_repeat(ctx, dst);\n            break;\n        case GGML_OP_REPEAT_BACK:\n            ggml_cuda_op_repeat_back(ctx, dst);\n            break;\n        case GGML_OP_GET_ROWS:\n 
           ggml_cuda_op_get_rows(ctx, dst);\n            break;\n        case GGML_OP_GET_ROWS_BACK:\n            ggml_cuda_op_get_rows_back(ctx, dst);\n            break;\n        case GGML_OP_SET_ROWS:\n            ggml_cuda_op_set_rows(ctx, dst);\n            break;\n        case GGML_OP_SET:\n            ggml_cuda_op_set(ctx, dst);\n            break;\n        case GGML_OP_DUP:\n            ggml_cuda_dup(ctx, dst);\n            break;\n        case GGML_OP_CPY:\n            ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);\n            break;\n        case GGML_OP_CONT:\n            ggml_cuda_dup(ctx, dst);\n            break;\n        case GGML_OP_ADD:\n        case GGML_OP_ADD1: // TODO: more efficient implementation\n            ggml_cuda_op_add(ctx, dst);\n            break;\n        case GGML_OP_ADD_ID:\n            ggml_cuda_op_add_id(ctx, dst);\n            break;\n        case GGML_OP_SUB:\n            ggml_cuda_op_sub(ctx, dst);\n            break;\n        case GGML_OP_ACC:\n            ggml_cuda_op_acc(ctx, dst);\n            break;\n        case GGML_OP_MUL:\n            ggml_cuda_op_mul(ctx, dst);\n            break;\n        case GGML_OP_DIV:\n            ggml_cuda_op_div(ctx, dst);\n            break;\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(dst)) {\n                case GGML_UNARY_OP_ABS:\n                    ggml_cuda_op_abs(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SGN:\n                    ggml_cuda_op_sgn(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_NEG:\n                    ggml_cuda_op_neg(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_STEP:\n                    ggml_cuda_op_step(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_GELU:\n                    ggml_cuda_op_gelu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SILU:\n                    ggml_cuda_op_silu(ctx, dst);\n 
                   break;\n                case GGML_UNARY_OP_GELU_ERF:\n                    ggml_cuda_op_gelu_erf(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_GELU_QUICK:\n                    ggml_cuda_op_gelu_quick(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_TANH:\n                    ggml_cuda_op_tanh(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_RELU:\n                    ggml_cuda_op_relu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SIGMOID:\n                    ggml_cuda_op_sigmoid(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_HARDSIGMOID:\n                    ggml_cuda_op_hardsigmoid(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_HARDSWISH:\n                    ggml_cuda_op_hardswish(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_EXP:\n                    ggml_cuda_op_exp(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_ELU:\n                    ggml_cuda_op_elu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_XIELU:\n                    ggml_cuda_op_xielu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_FLOOR:\n                    ggml_cuda_op_floor(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_CEIL:\n                    ggml_cuda_op_ceil(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_ROUND:\n                    ggml_cuda_op_round(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_TRUNC:\n                    ggml_cuda_op_trunc(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_EXPM1:\n                    ggml_cuda_op_expm1(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SOFTPLUS:\n                    ggml_cuda_op_softplus(ctx, dst);\n                 
   break;\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(dst)) {\n                case GGML_GLU_OP_REGLU:\n                    ggml_cuda_op_reglu(ctx, dst);\n                    break;\n                case GGML_GLU_OP_GEGLU:\n                    ggml_cuda_op_geglu(ctx, dst);\n                    break;\n                case GGML_GLU_OP_SWIGLU:\n                    ggml_cuda_op_swiglu(ctx, dst);\n                    break;\n                case GGML_GLU_OP_SWIGLU_OAI:\n                    ggml_cuda_op_swiglu_oai(ctx, dst);\n                    break;\n                case GGML_GLU_OP_GEGLU_ERF:\n                    ggml_cuda_op_geglu_erf(ctx, dst);\n                    break;\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    ggml_cuda_op_geglu_quick(ctx, dst);\n                    break;\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_NORM:\n            ggml_cuda_op_norm(ctx, dst);\n            break;\n        case GGML_OP_GROUP_NORM:\n            ggml_cuda_op_group_norm(ctx, dst);\n            break;\n        case GGML_OP_L2_NORM:\n            ggml_cuda_op_l2_norm(ctx, dst);\n            break;\n        case GGML_OP_CONCAT:\n            ggml_cuda_op_concat(ctx, dst);\n            break;\n        case GGML_OP_UPSCALE:\n            ggml_cuda_op_upscale(ctx, dst);\n            break;\n        case GGML_OP_PAD:\n            ggml_cuda_op_pad(ctx, dst);\n            break;\n        case GGML_OP_PAD_REFLECT_1D:\n            ggml_cuda_op_pad_reflect_1d(ctx, dst);\n            break;\n        case GGML_OP_ARANGE:\n            ggml_cuda_op_arange(ctx, dst);\n            break;\n        case GGML_OP_TIMESTEP_EMBEDDING:\n            ggml_cuda_op_timestep_embedding(ctx, dst);\n            break;\n        case GGML_OP_LEAKY_RELU:\n            ggml_cuda_op_leaky_relu(ctx, dst);\n   
         break;\n        case GGML_OP_SILU_BACK:\n            ggml_cuda_op_silu_back(ctx, dst);\n            break;\n        case GGML_OP_RMS_NORM:\n            ggml_cuda_op_rms_norm(ctx, dst);\n            break;\n        case GGML_OP_RMS_NORM_BACK:\n            ggml_cuda_op_rms_norm_back(ctx, dst);\n            break;\n        case GGML_OP_MUL_MAT:\n            ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);\n            break;\n        case GGML_OP_MUL_MAT_ID:\n            ggml_cuda_mul_mat_id(ctx, dst);\n            break;\n        case GGML_OP_OUT_PROD:\n            ggml_cuda_out_prod(ctx, dst);\n            break;\n        case GGML_OP_SCALE:\n            ggml_cuda_op_scale(ctx, dst);\n            break;\n        case GGML_OP_SQR:\n            ggml_cuda_op_sqr(ctx, dst);\n            break;\n        case GGML_OP_SQRT:\n            ggml_cuda_op_sqrt(ctx, dst);\n            break;\n        case GGML_OP_SIN:\n            ggml_cuda_op_sin(ctx, dst);\n            break;\n        case GGML_OP_COS:\n            ggml_cuda_op_cos(ctx, dst);\n            break;\n        case GGML_OP_CLAMP:\n            ggml_cuda_op_clamp(ctx, dst);\n            break;\n        case GGML_OP_LOG:\n            ggml_cuda_op_log(ctx, dst);\n            break;\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n                break;\n        case GGML_OP_DIAG:\n            ggml_cuda_op_diag(ctx, dst);\n            break;\n        case GGML_OP_DIAG_MASK_INF:\n            ggml_cuda_op_diag_mask_inf(ctx, dst);\n            break;\n        case GGML_OP_SOFT_MAX:\n            ggml_cuda_op_soft_max(ctx, dst);\n            break;\n        case GGML_OP_SOFT_MAX_BACK:\n            ggml_cuda_op_soft_max_back(ctx, dst);\n            break;\n        case GGML_OP_ROPE:\n            ggml_cuda_op_rope(ctx, dst);\n            break;\n        case GGML_OP_ROPE_BACK:\n            
ggml_cuda_op_rope_back(ctx, dst);\n            break;\n        case GGML_OP_ROLL:\n            ggml_cuda_op_roll(ctx, dst);\n            break;\n        case GGML_OP_IM2COL:\n            ggml_cuda_op_im2col(ctx, dst);\n            break;\n        case GGML_OP_IM2COL_3D:\n            ggml_cuda_op_im2col_3d(ctx, dst);\n            break;\n        case GGML_OP_CONV_2D:\n            ggml_cuda_op_conv2d(ctx, dst);\n            break;\n        case GGML_OP_CONV_2D_DW:\n            ggml_cuda_op_conv2d_dw(ctx, dst);\n            break;\n        case GGML_OP_CONV_TRANSPOSE_2D:\n            ggml_cuda_conv_2d_transpose_p0(ctx, dst);\n            break;\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            ggml_cuda_op_conv_transpose_1d(ctx,dst);\n            break;\n        case GGML_OP_POOL_2D:\n            ggml_cuda_op_pool2d(ctx, dst);\n            break;\n        case GGML_OP_SUM:\n            ggml_cuda_op_sum(ctx, dst);\n            break;\n        case GGML_OP_CUMSUM:\n            ggml_cuda_op_cumsum(ctx, dst);\n            break;\n        case GGML_OP_SUM_ROWS:\n            ggml_cuda_op_sum_rows(ctx, dst);\n            break;\n        case GGML_OP_MEAN:\n            ggml_cuda_op_mean(ctx, dst);\n            break;\n        case GGML_OP_SSM_CONV:\n            ggml_cuda_op_ssm_conv(ctx, dst);\n            break;\n        case GGML_OP_SSM_SCAN:\n            ggml_cuda_op_ssm_scan(ctx, dst);\n            break;\n        case GGML_OP_TOP_K:\n            ggml_cuda_op_top_k(ctx, dst);\n            break;\n        case GGML_OP_ARGSORT:\n            ggml_cuda_op_argsort(ctx, dst);\n            break;\n        case GGML_OP_FLASH_ATTN_EXT:\n            ggml_cuda_flash_attn_ext(ctx, dst);\n            break;\n        case GGML_OP_CROSS_ENTROPY_LOSS:\n            ggml_cuda_cross_entropy_loss(ctx, dst);\n            break;\n        case GGML_OP_TRI:\n            ggml_cuda_op_tri(ctx, dst);\n            break;\n        case GGML_OP_RWKV_WKV6:\n            ggml_cuda_op_rwkv_wkv6(ctx, 
dst);\n            break;\n        case GGML_OP_GATED_LINEAR_ATTN:\n            ggml_cuda_op_gated_linear_attn(ctx, dst);\n            break;\n        case GGML_OP_GATED_DELTA_NET:\n            ggml_cuda_op_gated_delta_net(ctx, dst);\n            break;\n        case GGML_OP_RWKV_WKV7:\n            ggml_cuda_op_rwkv_wkv7(ctx, dst);\n            break;\n        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:\n            ggml_cuda_cross_entropy_loss_back(ctx, dst);\n            break;\n        case GGML_OP_OPT_STEP_ADAMW:\n            ggml_cuda_opt_step_adamw(ctx, dst);\n            break;\n        case GGML_OP_OPT_STEP_SGD:\n            ggml_cuda_opt_step_sgd(ctx, dst);\n            break;\n        case GGML_OP_SOLVE_TRI:\n            ggml_cuda_op_solve_tri(ctx, dst);\n            break;\n        case GGML_OP_FILL:\n            ggml_cuda_op_fill(ctx, dst);\n            break;\n        default:\n            return false;\n    }\n\n    cudaError_t err = cudaGetLastError();\n    if (err != cudaSuccess) {\n        GGML_LOG_ERROR(\"%s: %s failed\\n\", __func__, ggml_op_desc(dst));\n        CUDA_CHECK(err);\n    }\n\n    return true;\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\n// backend\n\nstatic const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;\n\n    return cuda_ctx->name.c_str();\n}\n\nstatic void ggml_backend_cuda_free(ggml_backend_t backend) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;\n\n    delete cuda_ctx;\n    delete backend;\n}\n\nstatic void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;\n    ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer;\n\n    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && \"unsupported buffer type\");\n\n    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));\n}\n\nstatic void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;\n    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;\n\n    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && \"unsupported buffer type\");\n\n    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));\n}\n\nstatic bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {\n    ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;\n    ggml_backend_buffer_t buf_dst = dst->view_src ? 
dst->view_src->buffer : dst->buffer;\n\n    if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {\n        return false;\n    }\n\n    if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {\n        return false;\n    }\n\n    // device -> device copy\n    ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;\n    ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;\n\n    ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;\n    ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;\n\n    if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: backend and buffer devices do not match\\n\", __func__);\n#endif\n        return false;\n    }\n\n    if (backend_src != backend_dst) {\n        // copy on src stream\n        if (cuda_ctx_src->device == cuda_ctx_dst->device) {\n            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));\n        } else {\n#ifdef GGML_CUDA_NO_PEER_COPY\n            return false;\n#else\n            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));\n#endif\n        }\n\n        // record event on src stream after the copy\n        if (!cuda_ctx_src->copy_event) {\n            ggml_cuda_set_device(cuda_ctx_src->device);\n            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));\n        }\n\n        CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));\n\n        // wait on dst stream for the copy to complete\n        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), 
cuda_ctx_src->copy_event, 0));\n    } else {\n        // src and dst are on the same backend\n        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));\n    }\n    return true;\n}\n\nstatic void ggml_backend_cuda_synchronize(ggml_backend_t backend) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;\n\n    CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream()));\n\n    GGML_UNUSED(backend);\n}\n\n#ifdef USE_CUDA_GRAPH\nstatic bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {\n\n    bool use_cuda_graph = true;\n    // Loop over nodes in GGML graph to obtain info needed for CUDA graph\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        ggml_tensor * node = cgraph->nodes[i];\n\n        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {\n            continue;\n        }\n\n        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {\n            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture\n#ifndef NDEBUG\n            GGML_LOG_DEBUG(\"%s: disabling CUDA graphs due to split buffer\\n\", __func__);\n#endif\n        }\n\n        // [TAG_MUL_MAT_ID_CUDA_GRAPHS]\n        if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {\n            // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs\n            // TODO: figure out a way to enable for larger batch sizes, without hurting performance\n            // ref: https://github.com/ggml-org/llama.cpp/pull/18958\n            use_cuda_graph = false;\n#ifndef NDEBUG\n            GGML_LOG_DEBUG(\"%s: disabling CUDA graphs due to unsupported node type\\n\", 
__func__);\n#endif\n        }\n\n        if (!use_cuda_graph) {\n            break;\n        }\n    }\n\n    return use_cuda_graph;\n}\n\nstatic void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {\n    memset(props, 0, sizeof(ggml_cuda_graph_node_properties));\n    props->node_data = node->data;\n    props->node_op = node->op;\n    props->node_type = node->type;\n    props->flags = node->flags;\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        props->ne[i] = node->ne[i];\n        props->nb[i] = node->nb[i];\n    }\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (!node->src[i]) {\n            continue;\n        }\n\n        props->src_data[i] = node->src[i]->data;\n    }\n    memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);\n}\n\nstatic bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {\n    if (node->data != props->node_data && node->op != GGML_OP_VIEW) {\n        return false;\n    }\n\n    if (node->op != props->node_op) {\n        return false;\n    }\n\n    if (node->type != props->node_type) {\n        return false;\n    }\n\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        if (node->ne[i] != props->ne[i]) {\n            return false;\n        }\n        if (node->nb[i] != props->nb[i]) {\n            return false;\n        }\n    }\n\n    if (node->op != GGML_OP_VIEW) {\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            if (!node->src[i]) {\n                if (props->src_data[i] != nullptr) {\n                    return false;\n                }\n                continue;\n            }\n\n            if (node->src[i]->data != props->src_data[i]) {\n                return false;\n            }\n        }\n    }\n\n    if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {\n        return false;\n    }\n\n    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & 
GGML_TENSOR_FLAG_COMPUTE)) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {\n    return cgraph->nodes[0];\n}\n\nstatic bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {\n    bool res = false;\n\n    const void * graph_key = ggml_cuda_graph_get_key(cgraph);\n    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);\n\n    // Check if the graph size has changed\n    if (graph->props.size() != (size_t)cgraph->n_nodes) {\n        res = true;\n        graph->props.resize(cgraph->n_nodes);\n    }\n\n    // Loop over nodes in GGML graph to determine if CUDA graph update is required\n    // and store properties to allow this comparison for the next token\n    std::unordered_set<ggml_tensor *> seen_node;\n    std::vector<ggml_tensor *> srcs_extra;\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        bool props_match = true;\n\n        seen_node.insert(cgraph->nodes[i]);\n\n        if (!res) {\n            props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);\n        }\n        if (!props_match) {\n            res = true;\n        }\n        ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);\n\n        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {\n            ggml_tensor * src = cgraph->nodes[i]->src[src_idx];\n            if (src && seen_node.find(src) == seen_node.end()) {\n                srcs_extra.push_back(src);\n            }\n        }\n    }\n\n    if (graph->extra.size() != (size_t) srcs_extra.size()) {\n        res = true;\n        graph->extra.resize(srcs_extra.size());\n    }\n\n    for (size_t i = 0; i < srcs_extra.size(); ++i) {\n        bool props_match = true;\n\n        if (!res) {\n            props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);\n        }\n\n        if (!props_match) {\n            res = true;\n        
}\n        ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);\n    }\n\n    return res;\n}\n\nstatic void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {\n    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);\n\n#if CUDART_VERSION >= 12000\n    cudaGraphExecUpdateResultInfo result_info;\n    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);\n#else\n    cudaGraphNode_t errorNode;\n    cudaGraphExecUpdateResult result_info;\n    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);\n#endif // CUDART_VERSION >= 12000\n\n    if (stat == cudaErrorGraphExecUpdateFailure) {\n#ifndef NDEBUG\n        GGML_LOG_DEBUG(\"%s: CUDA graph update failed\\n\", __func__);\n#endif\n\n        // The pre-existing graph exec cannot be updated due to violated constraints\n        // so instead clear error and re-instantiate\n        (void)cudaGetLastError();\n        CUDA_CHECK(cudaGraphExecDestroy(graph->instance));\n        graph->instance = nullptr;\n        CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));\n    } else {\n        GGML_ASSERT(stat == cudaSuccess);\n    }\n}\n#endif // USE_CUDA_GRAPH\n\nstatic bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,\n                                                const ggml_tensor * view,\n                                                const ggml_tensor * set_rows) {\n\n    if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {\n        return false;\n    }\n    // ne3 not tested\n    if (rope->src[0]->ne[3] != 1) {\n        return false;\n    }\n\n    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {\n        return false;\n    }\n\n    if (set_rows->src[1]->type != GGML_TYPE_I64) {\n        return false;\n    }\n\n    // The view should flatten two dims of rope into one dim\n  
  if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {\n        return false;\n    }\n\n    // Only norm/neox shaders have the fusion code\n    const int mode = ((const int32_t *) rope->op_params)[2];\n    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {\n    args.sigmoid         = false;\n    args.softmax         = false;\n    args.delayed_softmax = false;\n    args.prob_bias       = false;\n    args.norm            = false;\n\n    const int      n_nodes = cgraph->n_nodes;\n    ggml_tensor ** nodes   = cgraph->nodes;\n\n    if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {\n        args.softmax = true;\n    }\n\n    if (nodes[node_idx]->op == GGML_OP_UNARY) {\n        if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {\n            return false;\n        }\n        args.sigmoid = true;\n    }\n\n    if (nodes[node_idx]->op == GGML_OP_ARGSORT) {\n        args.delayed_softmax = true;\n    }\n\n    node_idx++;\n\n    if (args.sigmoid || args.softmax) {\n        // SOFTMAX -> RESHAPE\n        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||\n                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {\n            return false;\n        }\n        ggml_tensor * probs_reshaped = nodes[node_idx];\n        node_idx++;\n\n        if (node_idx >= n_nodes) {\n            return false;\n        }\n\n        // src of bias add is the unreshaped probs (-2 instead of -1)\n        if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {\n            args.prob_bias = true;\n            node_idx++;\n        }\n        // RESHAPE/ADD -> ARGSORT\n        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {\n            return false;\n        }\n\n        if (args.prob_bias && 
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {\n            return false;\n        } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {\n            return false;\n        }\n\n        node_idx++;\n\n        // ARGSORT-> VIEW\n        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||\n                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {\n            return false;\n        }\n        node_idx++;\n\n        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {\n            return false;\n        }\n\n        // GET_ROWS\n        if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {\n            return false;\n        }\n        node_idx++;\n    } else if (args.delayed_softmax) {\n        if (node_idx - 2 < 0) {\n            return false;\n        }\n        ggml_tensor * probs_reshaped = nodes[node_idx - 2];\n\n        // VIEW->ARGSORT\n        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||\n            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {\n            return false;\n        }\n        node_idx++;\n\n        // GET_ROWS\n        if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||\n                nodes[node_idx]->src[0] != probs_reshaped) {\n            return false;\n        }\n        node_idx++;\n\n        static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };\n\n        for (const ggml_op op : remaining_ops) {\n            if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {\n                return false;\n            }\n            node_idx++;\n        }\n    }\n\n    // At this point we can check for norm + scale. 
Everything is now at least valid till the norm\n    if (node_idx >= n_nodes) {\n        return true;\n    }\n\n    if (nodes[node_idx]->op == GGML_OP_RESHAPE) {\n        //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE\n        static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };\n\n        args.norm = true;\n        for (const ggml_op op : norm_ops) {\n            if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {\n                node_idx++;\n            } else {\n                args.norm = false;\n                return true;\n            }\n        }\n\n        // DIV <- CLAMP, RESHAPE\n        if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||\n            nodes[node_idx]->src[0] != nodes[node_idx - 3]) {\n            args.norm = false;\n            return true;\n        }\n        node_idx++;\n\n        if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {\n            args.norm = false;\n            return true;\n        }\n\n        node_idx++;\n    }\n\n    if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {\n        args.scale = true;\n    }\n\n    return true;\n}\n\nstatic bool ggml_cuda_can_fuse(const struct ggml_cgraph *                cgraph,\n                               int                                       node_idx,\n                               std::initializer_list<enum ggml_op>       ops,\n                               std::initializer_list<enum ggml_unary_op> unary_ops) {\n#ifndef NDEBUG\n    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);\n    GGML_ASSERT(unary_ops.size() == num_unary);\n#endif\n\n    const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,\n                             const std::initializer_list<enum ggml_op> & list2) {\n        return std::equal(list1.begin(), 
list1.end(), list2.begin(), list2.end());\n    };\n\n    std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };\n    std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };\n\n    std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };\n    std::initializer_list<enum ggml_op> mul_mat_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_MUL_MAT,    GGML_OP_GLU };\n\n    if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&\n        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {\n        const ggml_tensor * ffn_gate      = cgraph->nodes[node_idx];\n        const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];\n        const ggml_tensor * ffn_up        = cgraph->nodes[node_idx + 2];\n        const ggml_tensor * ffn_up_bias   = cgraph->nodes[node_idx + 3];\n        const ggml_tensor * glu           = cgraph->nodes[node_idx + 4];\n\n        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {\n            return true;\n        }\n    }\n\n    if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&\n        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {\n        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];\n        const ggml_tensor * ffn_up   = cgraph->nodes[node_idx + 1];\n        const ggml_tensor * glu      = cgraph->nodes[node_idx + 2];\n\n        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {\n            return true;\n        }\n    }\n\n    std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };\n\n    if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {\n        
const ggml_tensor * rope     = cgraph->nodes[node_idx];\n        const ggml_tensor * view     = cgraph->nodes[node_idx + 1];\n        const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];\n\n        if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {\n            return true;\n        }\n    }\n\n    if (!ggml_can_fuse(cgraph, node_idx, ops)) {\n        return false;\n    }\n\n    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {\n        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];\n        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];\n        const ggml_tensor *add      = nullptr;\n\n        if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {\n            add = cgraph->nodes[node_idx+2];\n        }\n\n        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);\n        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);\n\n        //rms norm only supports F32\n        if (mul->src[0]->type != GGML_TYPE_F32 ||\n            mul->src[1]->type != GGML_TYPE_F32 ||\n            mul->type != GGML_TYPE_F32) {\n            return false;\n        }\n\n        if (add && (add->src[0]->type != GGML_TYPE_F32 ||\n            add->src[1]->type != GGML_TYPE_F32 ||\n            add->type != GGML_TYPE_F32) ) {\n            return false;\n        }\n\n        //if rms norm is the B operand, then we don't handle broadcast\n        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {\n            return false;\n        }\n\n        //rms_norm kernel assumes contiguous rows\n        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {\n            return false;\n        }\n\n        if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {\n            return false;\n        }\n\n        return true;\n    }\n\n    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == 
GGML_OP_UNARY\n     && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {\n        const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];\n        const ggml_tensor * silu     = cgraph->nodes[node_idx+1];\n\n        if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {\n            return false;\n        }\n\n        return true;\n    }\n\n    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL\n     && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {\n        const ggml_tensor * unary = cgraph->nodes[node_idx];\n        const ggml_tensor * mul   = cgraph->nodes[node_idx+1];\n\n        if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {\n            return false;\n        }\n\n        if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {\n            return false;\n        }\n\n        if (unary->type != mul->type) {\n            return false;\n        }\n\n        const ggml_tensor * other = (mul->src[0] == unary) ? 
mul->src[1] : mul->src[0];\n        if (other->type != unary->type) {\n            return false;\n        }\n        if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {\n            return false;\n        }\n\n        return true;\n    }\n\n    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE\n     && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {\n        const ggml_tensor *scale  = cgraph->nodes[node_idx];\n        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];\n        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];\n\n        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);\n        GGML_ASSERT(scale->type == GGML_TYPE_F32);\n\n        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {\n            return false;\n        }\n\n        // Check for bias\n        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {\n            return false;\n        }\n\n        return true;\n    }\n\n    return false;\n}\n\n// returns true if the fusion is memory-safe, i.e. no write (out) node overlaps a src read by the fused nodes\n// (srcs with op GGML_OP_NONE and srcs produced earlier inside the fusion are ignored; always true when the first node has one row)\nstatic bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,\n                                                 int           node_idx,\n                                                 int           node_count,\n                                                 int *         out_nodes,\n                                                 int           out_count) {\n    auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {\n        const int64_t a_start = (int64_t) a->data;\n        const int64_t a_end   = a_start + ggml_nbytes(a);\n\n        const int64_t b_start = (int64_t) b->data;\n        const int64_t b_end   = b_start + ggml_nbytes(b);\n\n        if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) 
{\n            return true;\n        }\n\n        return false;\n    };\n\n    bool is_ok = true;\n    // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok\n    if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {\n        return true;\n    }\n\n    for (int i = 0; i < out_count; ++i) {\n        const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];\n\n        for (int j = node_idx; j < node_idx + node_count; ++j) {\n            // Loop over all srcs of all nodes in the fusion. If the src overlaps\n            // the destination and the src is not an intermediate node that's being\n            // elided, then disable fusion.\n\n            for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {\n                const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];\n\n                if (!src || src->op == GGML_OP_NONE) {\n                    continue;\n                }\n\n                if (nodes_overlap(dst, src)) {\n                    bool found = false;\n\n                    for (int k = node_idx; k < j; ++k) {\n                        if (cgraph->nodes[k] == src) {\n                            found = true;\n                            break;\n                        }\n                    }\n\n                    if (!found) {\n                        is_ok = false;\n                        break;\n                    }\n                }\n            }\n        }\n    }\n\n    return is_ok;\n}\n\nstatic void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {\n    bool graph_evaluated_or_captured = false;\n\n    // flag used to determine whether it is an integrated_gpu\n    const bool integrated            = ggml_cuda_info().devices[cuda_ctx->device].integrated;\n\n    ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();\n    bool        
                 is_concurrent_event_active = false;\n    ggml_cuda_concurrent_event * concurrent_event           = nullptr;\n    bool                         should_launch_concurrent_events = false;\n\n    const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {\n        if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {\n            concurrent_event = &stream_ctx.concurrent_events[node];\n\n            is_concurrent_event_active = true;\n\n            GGML_LOG_DEBUG(\"Launching %d streams at %s\\n\", concurrent_event->n_streams, node->name);\n\n            cudaStream_t main_stream = cuda_ctx->stream();  // this should be stream 0\n            GGML_ASSERT(cuda_ctx->curr_stream_no == 0);\n            CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));\n\n            for (int i = 1; i <= concurrent_event->n_streams; ++i) {\n                cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);\n                CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));\n            }\n        }\n    };\n\n    while (!graph_evaluated_or_captured) {\n        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.\n        // With the use of CUDA graphs, the execution will be performed by the graph launch.\n        if (!use_cuda_graph || cuda_graph_update_required) {\n            [[maybe_unused]] int prev_i = 0;\n\n            if (stream_ctx.concurrent_events.size() > 0) {\n                should_launch_concurrent_events = true;\n                for (const auto & [tensor, event] : stream_ctx.concurrent_events) {\n                    should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();\n                }\n            }\n\n            if (should_launch_concurrent_events) {\n                // Restore original node order within each concurrent region to enable fusion within streams\n\n                
std::unordered_map<const ggml_tensor *, int> node_to_idx;\n                node_to_idx.reserve(cgraph->n_nodes);\n                for (int i = 0; i < cgraph->n_nodes; ++i) {\n                    node_to_idx[cgraph->nodes[i]] = i;\n                }\n\n                for (auto & [fork_node, event] : stream_ctx.concurrent_events) {\n                    // Find positions of all nodes from this event in the current graph\n                    std::vector<int> positions;\n                    positions.reserve(event.original_order.size());\n\n                    bool all_found = true;\n                    for (const ggml_tensor * orig_node : event.original_order) {\n                        auto it = node_to_idx.find(orig_node);\n                        if (it != node_to_idx.end()) {\n                            positions.push_back(it->second);\n                        } else {\n                            all_found = false;\n                            break;\n                        }\n                    }\n\n                    if (!all_found || positions.size() != event.original_order.size()) {\n                        continue;\n                    }\n\n                    // Sort positions to get contiguous range\n                    std::vector<int> sorted_positions = positions;\n                    std::sort(sorted_positions.begin(), sorted_positions.end());\n\n                    bool is_contiguous = true;\n                    for (size_t i = 1; i < sorted_positions.size(); ++i) {\n                        if (sorted_positions[i] != sorted_positions[i-1] + 1) {\n                            is_contiguous = false;\n                            break;\n                        }\n                    }\n\n                    if (!is_contiguous) {\n                        continue;\n                    }\n\n                    // Restore original order at the sorted positions\n                    int start_pos = sorted_positions[0];\n                    for (size_t i = 
0; i < event.original_order.size(); ++i) {\n                        cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);\n                    }\n                }\n            } else {\n                stream_ctx.concurrent_events.clear();\n            }\n\n            for (int i = 0; i < cgraph->n_nodes; i++) {\n                ggml_tensor * node = cgraph->nodes[i];\n                if (is_concurrent_event_active) {\n                    GGML_ASSERT(concurrent_event);\n\n                    if (node == concurrent_event->join_node) {\n                        cuda_ctx->curr_stream_no = 0;\n                        for (int i = 1; i <= concurrent_event->n_streams; ++i) {\n                            // Wait on join events of forked streams in the main stream\n                            CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],\n                                                       cuda_ctx->stream(cuda_ctx->device, i)));\n                            CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));\n                        }\n\n                        is_concurrent_event_active = false;\n                        concurrent_event           = nullptr;\n                    } else {\n                        GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());\n                        cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];\n                        GGML_LOG_DEBUG(\"Setting stream no to %d for node %s\\n\", cuda_ctx->curr_stream_no, node->name);\n                    }\n                } else if (i - prev_i > 1) {\n                    //the previous node was fused\n                    const ggml_tensor * prev_node = cgraph->nodes[i - 1];\n                    try_launch_concurrent_event(prev_node);\n\n                    if (is_concurrent_event_active) {\n                        cuda_ctx->curr_stream_no = 
concurrent_event->stream_mapping[node];\n                        GGML_LOG_DEBUG(\"Setting stream no to %d for node %s\\n\", cuda_ctx->curr_stream_no, node->name);\n                    }\n                }\n\n#ifdef GGML_CUDA_DEBUG\n                const int nodes_fused = i - prev_i - 1;\n                if (nodes_fused > 0) {\n                    GGML_LOG_INFO(\"nodes_fused: %d\\n\", nodes_fused);\n                }\n#endif\n                prev_i = i;\n\n                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {\n                    continue;\n                }\n\n                if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n                    continue;\n                }\n\n                // start of fusion operations\n                static bool disable_fusion = (getenv(\"GGML_CUDA_DISABLE_FUSION\") != nullptr);\n                if (!disable_fusion) {\n                    ggml_cuda_topk_moe_args args;\n\n                    if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||\n                        cgraph->nodes[i]->op == GGML_OP_ARGSORT) {\n                        const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);\n\n                        std::vector<ggml_op> ops;\n\n                        if (can_fuse) {\n                            const ggml_tensor * logits  = node->src[0];\n                            ggml_tensor *       weights = nullptr;\n                            ggml_tensor *       ids     = nullptr;\n                            const ggml_tensor * bias    = nullptr;\n                            const ggml_tensor * clamp   = nullptr;\n                            const ggml_tensor * scale   = nullptr;\n\n                            if (!args.delayed_softmax) {\n                                ggml_op gating_op = args.sigmoid ? 
GGML_OP_UNARY : GGML_OP_SOFT_MAX;\n                                int     out_nodes[2];  // nodes which can't be elided\n\n                                if (args.prob_bias) {\n                                    bias = cgraph->nodes[i + 2]->src[1];\n                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,\n                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS });\n                                    out_nodes[0] = i + 4;\n                                    ids          = cgraph->nodes[i + 4];\n                                } else {\n                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,\n                                                            GGML_OP_GET_ROWS });\n                                    out_nodes[0] = i + 3;\n                                    ids          = cgraph->nodes[i + 3];\n                                }\n\n                                if (args.norm) {\n                                    ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,\n                                                            GGML_OP_DIV, GGML_OP_RESHAPE });\n                                    clamp = cgraph->nodes[i + ops.size() - 3];\n                                }\n                                if (args.scale) {\n                                    ops.insert(ops.end(), { GGML_OP_SCALE });\n                                    scale = cgraph->nodes[i + ops.size() - 1];\n                                }\n\n                                weights      = cgraph->nodes[i + ops.size() - 1];\n                                out_nodes[1] = i + ops.size() - 1;\n\n                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&\n                                    ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&\n     
                               ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {\n                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);\n                                    i += ops.size() - 1;\n                                    continue;\n                                }\n                            } else if (!args.norm && !args.prob_bias) {\n                                //special case gpt-oss, no norm, no bias.\n                                ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,\n                                                        GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });\n                                weights                     = cgraph->nodes[i + 5];\n                                ids                         = cgraph->nodes[i + 1];\n                                const ggml_tensor * softmax = cgraph->nodes[i + 4];\n\n                                int out_nodes[2] = { i + 1, i + 5 };\n                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&\n                                    ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&\n                                    ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {\n                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);\n                                    i += ops.size() - 1;\n                                    continue;\n                                }\n                            }\n                        }\n                    }\n\n                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {\n                        ggml_tensor * rope = cgraph->nodes[i];\n                        ggml_tensor * set_rows = cgraph->nodes[i + 2];\n\n                    
    ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);\n                        i += 2;\n                        continue;\n                    }\n\n                    if (node->op == GGML_OP_ADD) {\n                        int n_fuse = 0;\n                        ggml_op ops[8];\n                        std::fill(ops, ops + 8, GGML_OP_ADD);\n\n                        for (; n_fuse <= 6; ++n_fuse){\n                            if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {\n                                break;\n                            }\n                            if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {\n                                break;\n                            }\n                            if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {\n                                break;\n                            }\n                        }\n\n                        n_fuse++;\n\n                        if (n_fuse > 1) {\n                            ggml_tensor fused_add_node;\n                            memcpy(&fused_add_node, node, sizeof(ggml_tensor));\n                            for (int j = 0; j < n_fuse - 1; ++j) {\n                                fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];\n                            }\n                            fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;\n                            ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);\n                            i += n_fuse - 1;\n\n                            continue;\n                        }\n                    }\n\n                    bool fused_mul_mat_vec = false;\n                    int fused_node_count = 0;\n\n                    for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {\n                        const ggml_op bias_op = op == GGML_OP_MUL_MAT ? 
GGML_OP_ADD : GGML_OP_ADD_ID;\n\n                        if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {\n                            ggml_tensor * glu         = cgraph->nodes[i + 4];\n                            ggml_tensor * gate_bias_n = glu->src[0];\n                            ggml_tensor * up_bias_n   = glu->src[1];\n\n                            //we don't assume the order for {gate, up}. Instead infer it from the bias tensor\n                            ggml_tensor * gate_n      = nullptr;\n                            ggml_tensor * up_n        = nullptr;\n\n                            if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {\n                                gate_n = cgraph->nodes[i];\n                                up_n   = cgraph->nodes[i + 2];\n                            } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {\n                                gate_n = cgraph->nodes[i + 2];\n                                up_n   = cgraph->nodes[i];\n                            } else {\n                                continue;\n                            }\n\n                            auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {\n                                if (op_bias == GGML_OP_ADD) {\n                                    if (bias_node->src[0] == mul_node) {\n                                        return bias_node->src[1];\n                                    }\n                                    if (bias_node->src[1] == mul_node) {\n                                        return bias_node->src[0];\n                                    }\n                                    return (ggml_tensor *) nullptr;\n                                }\n                                GGML_ASSERT(op_bias == GGML_OP_ADD_ID);\n                                
GGML_ASSERT(bias_node->src[0] == mul_node);\n                                return bias_node->src[1];\n                            };\n\n                            ggml_tensor * up_bias_tensor   = get_bias_tensor(up_bias_n, up_n, bias_op);\n                            ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);\n\n                            if (!up_bias_tensor || !gate_bias_tensor) {\n                                continue;\n                            }\n\n                            // we don't support repeating adds\n                            if (bias_op == GGML_OP_ADD &&\n                                (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||\n                                 !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {\n                                continue;\n                            }\n\n                            const ggml_tensor * src0 = up_n->src[0];\n                            const ggml_tensor * src1 = up_n->src[1];\n                            const ggml_tensor * ids  = up_n->src[2];\n\n                            if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {\n                                ggml_cuda_mm_fusion_args_host fusion_data{};\n                                fusion_data.gate      = gate_n->src[0];\n                                fusion_data.x_bias    = up_bias_tensor;\n                                fusion_data.gate_bias = gate_bias_tensor;\n                                fusion_data.glu_op    = ggml_get_glu_op(glu);\n\n                                ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);\n                                fused_mul_mat_vec = true;\n                                fused_node_count = 5;\n                                break;\n                            }\n\n                            if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {\n                                
ggml_cuda_mm_fusion_args_host fusion_data{};\n                                fusion_data.gate      = gate_n->src[0];\n                                fusion_data.x_bias    = up_bias_tensor;\n                                fusion_data.gate_bias = gate_bias_tensor;\n                                fusion_data.glu_op    = ggml_get_glu_op(glu);\n\n                                ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);\n                                fused_mul_mat_vec = true;\n                                fused_node_count = 5;\n                                break;\n                            }\n                        } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {\n                            ggml_tensor * glu  = cgraph->nodes[i + 2];\n                            ggml_tensor * gate = glu->src[0];\n                            ggml_tensor * up   = glu->src[1];\n\n                            bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])\n                                || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);\n\n                            if (!ok) continue;\n\n                            const ggml_tensor * src0 = up->src[0];\n                            const ggml_tensor * src1 = up->src[1];\n                            const ggml_tensor * ids  = up->src[2];\n\n                            if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {\n                                ggml_cuda_mm_fusion_args_host fusion_data{};\n                                fusion_data.gate   = gate->src[0];\n                                fusion_data.glu_op = ggml_get_glu_op(glu);\n\n                                ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);\n                                fused_mul_mat_vec = true;\n                                fused_node_count = 3;\n                                break;\n                            }\n\n                        
    if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {\n                                ggml_cuda_mm_fusion_args_host fusion_data{};\n                                fusion_data.gate   = gate->src[0];\n                                fusion_data.glu_op = ggml_get_glu_op(glu);\n\n                                ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);\n                                fused_mul_mat_vec = true;\n                                fused_node_count = 3;\n                                break;\n                            }\n                        }\n                    }\n\n                    if (fused_mul_mat_vec) {\n                        i += fused_node_count - 1;\n                        continue;\n                    }\n\n                    fused_mul_mat_vec = false;\n                    fused_node_count = 0;\n\n                    for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {\n                        const ggml_op bias_op = op == GGML_OP_MUL_MAT ? 
GGML_OP_ADD : GGML_OP_ADD_ID;\n\n                        if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {\n                            continue;\n                        }\n\n                        ggml_tensor * mm_node   = cgraph->nodes[i];\n                        ggml_tensor * bias_node = cgraph->nodes[i + 1];\n\n                        ggml_tensor * bias_tensor = nullptr;\n                        if (bias_op == GGML_OP_ADD) {\n                            if (bias_node->src[0] == mm_node) {\n                                bias_tensor = bias_node->src[1];\n                            } else if (bias_node->src[1] == mm_node) {\n                                bias_tensor = bias_node->src[0];\n                            } else {\n                                continue;\n                            }\n                        } else {\n                            if (bias_node->src[0] != mm_node) {\n                                continue;\n                            }\n                            bias_tensor = bias_node->src[1];\n                        }\n\n                        const ggml_tensor * src0 = mm_node->src[0];\n                        const ggml_tensor * src1 = mm_node->src[1];\n                        const ggml_tensor * ids  = mm_node->src[2];\n\n                        if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {\n                            continue;\n                        }\n\n                        if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {\n                            continue;\n                        }\n\n                        ggml_cuda_mm_fusion_args_host fusion_data{};\n                        fusion_data.x_bias = bias_tensor;\n\n                        if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {\n                            ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);\n                            fused_mul_mat_vec = 
true;\n                            fused_node_count = 2;\n                            break;\n                        }\n\n                        if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {\n                            ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);\n                            fused_mul_mat_vec = true;\n                            fused_node_count = 2;\n                            break;\n                        }\n                    }\n\n                    if (fused_mul_mat_vec) {\n                        i += fused_node_count - 1;\n                        continue;\n                    }\n\n                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {\n                        ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);\n                        i += 2;\n                        continue;\n                    }\n\n                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {\n                        ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);\n                        i++;\n                        continue;\n                    }\n\n                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {\n                        ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);\n                        i++;\n                        continue;\n                    }\n\n                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||\n                        ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||\n                        ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {\n                        ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);\n                  
      i++;\n                        continue;\n                    }\n\n                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {\n                        i += 2;\n                        ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);\n                        continue;\n                    }\n                }\n#ifndef NDEBUG\n                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));\n                for (int j = 0; j < GGML_MAX_SRC; j++) {\n                    if (node->src[j] != nullptr) {\n                        assert(node->src[j]->buffer);\n                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||\n                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));\n                    }\n                }\n#else\n                GGML_UNUSED(integrated);\n#endif  // NDEBUG\n\n                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);\n                if (!ok) {\n                    GGML_LOG_ERROR(\"%s: op not supported %s (%s)\\n\", __func__, node->name, ggml_op_name(node->op));\n                }\n                GGML_ASSERT(ok);\n\n                if (!is_concurrent_event_active) {\n                    try_launch_concurrent_event(node);\n               }\n            }\n        }\n\n#ifdef USE_CUDA_GRAPH\n        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);\n        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture\n            if (graph->graph != nullptr) {\n                CUDA_CHECK(cudaGraphDestroy(graph->graph));\n                graph->graph = nullptr;\n            }\n\n            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));\n            graph_evaluated_or_captured = true; // CUDA graph has been 
captured\n\n            std::lock_guard<std::mutex> lock(ggml_cuda_lock);\n            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {\n                ggml_cuda_lock_cv.notify_all();\n            }\n        } else {\n            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated\n        }\n    }\n\n    if (use_cuda_graph) {\n        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);\n        if (graph->instance == nullptr) { // Create executable graph from captured graph.\n            CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));\n        }\n        if (cuda_graph_update_required) { // Update graph executable\n            ggml_cuda_graph_update_executable(cuda_ctx, graph_key);\n        }\n        // Launch graph\n        CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));\n#else\n        GGML_UNUSED(graph_key);\n        graph_evaluated_or_captured = true;\n#endif  // USE_CUDA_GRAPH\n    }\n}\n\n#ifdef USE_CUDA_GRAPH\nstatic bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {\n    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);\n\n    if (graph->graph == nullptr) {\n        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {\n            if (!graph->disable_due_to_gpu_arch) {\n                GGML_LOG_DEBUG(\"%s: disabling CUDA graphs due to GPU architecture\\n\", __func__);\n            }\n            graph->disable_due_to_gpu_arch = true;\n        }\n    }\n\n    return graph->is_enabled();\n}\n#endif // USE_CUDA_GRAPH\n\nstatic enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;\n\n    ggml_cuda_set_device(cuda_ctx->device);\n\n    bool use_cuda_graph             = false;\n    bool cuda_graph_update_required = false;\n    const void * 
graph_key = nullptr;\n\n#ifdef USE_CUDA_GRAPH\n    graph_key = ggml_cuda_graph_get_key(cgraph);\n\n    ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);\n\n    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);\n    if (graph->is_enabled()) {\n        const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);\n        if (graph_compatible) {\n            const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);\n\n            if (!graph->warmup_complete) {\n                // Warmup: need at least 2 calls with no property change on the 2nd call\n                if (!properties_changed) {\n                    graph->warmup_complete = true;\n                    GGML_LOG_DEBUG(\"%s: CUDA graph warmup complete\\n\", __func__);\n                    use_cuda_graph = true;\n                    cuda_graph_update_required = true;\n                }\n                // else: properties changed or first call - execute directly (use_cuda_graph stays false)\n            } else {\n                // Post-warmup: normal CUDA graph operation\n                if (properties_changed) {\n                    // Properties changed - reset warmup, execute directly until stable again\n                    graph->warmup_complete = false;\n                    GGML_LOG_DEBUG(\"%s: CUDA graph warmup reset\\n\", __func__);\n                } else {\n                    use_cuda_graph = true;\n                    cuda_graph_update_required = graph->instance == nullptr;\n                }\n            }\n        }\n    }\n#endif // USE_CUDA_GRAPH\n\n    if (use_cuda_graph && cuda_graph_update_required) {\n        // Start CUDA graph capture\n        {\n            std::lock_guard<std::mutex> lock(ggml_cuda_lock);\n            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);\n        }\n\n        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));\n    }\n\n    
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);\n\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;\n\n    CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream()));\n}\n\nstatic void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;\n\n    if (ggml_backend_is_cuda(backend)) {\n        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0));\n    } else {\n#if 0\n        // untested\n        auto wait_fn = [](void * user_data) {\n            ggml_backend_event_t event = (ggml_backend_event_t)user_data;\n            ggml_backend_event_synchronize(event);\n        };\n\n        CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));\n#endif\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\nstatic void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;\n\n#ifdef USE_CUDA_GRAPH\n    const void * graph_key = ggml_cuda_graph_get_key(cgraph);\n    const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);\n#else\n    const bool use_cuda_graph = false;\n    GGML_UNUSED(cuda_ctx);\n    GGML_UNUSED(cgraph);\n#endif\n\n    static bool enable_graph_optimization = [] {\n        const char * env     = getenv(\"GGML_CUDA_GRAPH_OPT\");\n        return env != nullptr && atoi(env) == 1;\n    }();\n\n    if (!enable_graph_optimization) {\n        return;\n    }\n\n    ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();\n    stream_context.reset();\n\n    if (!use_cuda_graph || 
ggml_backend_cuda_get_device_count() != 1) {\n        return;\n    }\n\n    // number of out-degrees for a particular node\n    std::unordered_map<const ggml_tensor *, int> fan_out;\n    // reverse mapping of node to index in the cgraph\n    std::unordered_map<const ggml_tensor *, int> node_indices;\n\n    const auto & is_noop = [](const ggml_tensor * node) -> bool {\n        return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||\n               node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;\n    };\n\n    const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {\n        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {\n            if (dst->src[s] == src) {\n                return true;\n            }\n        }\n        // implicit dependency if they view the same tensor\n        const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;\n        const ggml_tensor * src2 = src->view_src ? src->view_src : src;\n        if (dst2 == src2) {\n            return true;\n        }\n        return false;\n    };\n\n    for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {\n        const ggml_tensor * node = cgraph->nodes[node_idx];\n        node_indices[node]       = node_idx;\n\n        if (is_noop(node)) {\n            continue;\n        }\n        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {\n            const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];\n            //TODO: check why nrows > 1 fails\n            if (node && !is_noop(node) && ggml_nrows(node) <= 1) {\n                fan_out[src] += 1;\n            }\n        }\n    }\n\n    // Target Q, K, V for concurrency\n    // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):\n    // 1. 
find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be \"attn-norm\")\n    // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would \"KQ\" or \"flash-attn\")\n    // 3. account for all branches from the fork to the join\n    // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)\n    // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams\n    // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030\n\n    const int min_fan_out = 3;\n    const int max_fan_out = 3;\n\n    // store {fork_idx, join_idx}\n    std::vector<std::pair<int, int>> concurrent_node_ranges;\n\n    for (const auto & [root_node, count] : fan_out) {\n        if (count >= min_fan_out && count <= max_fan_out) {\n            const int root_node_idx = node_indices[root_node];\n\n            // only optimize for attn_norm\n            // TODO: make this more generic\n            if (!strstr(root_node->name, \"attn_norm\")) {\n                continue;\n            }\n\n            bool is_part_of_event = false;\n            for (const auto & [start, end] : concurrent_node_ranges) {\n                if (root_node_idx >= start && root_node_idx <= end) {\n                    is_part_of_event = true;\n                }\n            }\n\n            if (is_part_of_event) {\n                continue;\n            }\n\n            std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;\n            for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {\n                const ggml_tensor * node = cgraph->nodes[i];\n                if (!is_noop(node) && depends_on(node, root_node)) {\n                    nodes_per_branch.push_back({ node });\n                }\n            }\n\n            GGML_ASSERT(nodes_per_branch.size() == (size_t) count);\n\n            //find the join point\n            const ggml_tensor * 
join_node = nullptr;\n\n            const auto & belongs_to_branch = [&](const ggml_tensor *                      node,\n                                                 const std::vector<const ggml_tensor *> & branch) -> bool {\n                for (const ggml_tensor * n : branch) {\n                    if (depends_on(node, n)) {\n                        return true;\n                    }\n                }\n                return false;\n            };\n\n            for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {\n                const ggml_tensor * curr_node = cgraph->nodes[i];\n\n                int num_joins = 0;\n                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {\n                    if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {\n                        num_joins++;\n                    }\n                }\n\n                if (num_joins >= 2) {\n                    join_node = curr_node;\n                    break;\n                }\n\n                bool found_branch = false;\n                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {\n                    std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];\n                    if (belongs_to_branch(curr_node, branch_vec)) {\n                        //continue accumulating\n                        if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {\n                            branch_vec.push_back(curr_node);\n                        }\n                        found_branch = true;\n                    }\n                }\n\n                if (!found_branch && is_noop(curr_node)) {\n                    // we can put it in any branch because it will be ignored\n                    nodes_per_branch[0].push_back({ curr_node });\n                }\n            }\n\n            if (join_node) {\n                //Create 
ggml_cuda_concurrent_event\n                ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());\n                concurrent_event.join_node = join_node;\n\n                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {\n                    for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {\n                        concurrent_event.stream_mapping[n] = branch_idx + 1;\n                    }\n                }\n\n                int fork_node_idx = node_indices[root_node];\n                int join_node_idx = node_indices[join_node];\n\n                int       current_branch_idx = 0;\n                int       current_node_idx   = fork_node_idx + 1;\n                const int n_branches         = nodes_per_branch.size();\n\n                int total_branch_nodes = 0;\n                for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {\n                    total_branch_nodes += branch_nodes.size();\n                }\n\n                // there are other nodes in the middle which are unaccounted for\n                // usually (cpy) nodes, then ignore this fork\n                if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {\n                    GGML_LOG_DEBUG(\n                        \"Skipping %s because the number of nodes in the middle is not equal to the total number of \"\n                        \"branch nodes %d != %d\\n\",\n                        root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);\n                    continue;\n                }\n\n                // Save the original order of nodes in this region before interleaving\n                // This is used later to restore grouping for fusion within streams\n                concurrent_event.original_order.reserve(total_branch_nodes);\n                for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {\n                    
concurrent_event.original_order.push_back(cgraph->nodes[i]);\n                }\n\n                std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;\n                GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());\n                concurrent_events.emplace(root_node, std::move(concurrent_event));\n                GGML_LOG_DEBUG(\"Adding stream at node %s %p\\n\", root_node->name, root_node);\n                concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);\n\n                // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them\n                // example transformation:\n                // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->\n                // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]\n                while (current_node_idx < join_node_idx) {\n                    std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];\n\n                    bool has_node = false;\n                    for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {\n                        has_node |= branch_node.size() > 0;\n                    }\n\n                    GGML_ASSERT(has_node);\n\n                    if (branch_nodes.empty()) {\n                        current_branch_idx = (current_branch_idx + 1) % n_branches;\n                        continue;\n                    }\n\n                    cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());\n                    current_node_idx++;\n                    branch_nodes.erase(branch_nodes.begin());\n\n                    // append all empty nodes\n                    while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {\n                        cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());\n      
                  current_node_idx++;\n                        branch_nodes.erase(branch_nodes.begin());\n                    }\n\n                    current_branch_idx = (current_branch_idx + 1) % n_branches;\n                }\n            }\n        }\n    }\n}\n\nstatic const ggml_backend_i ggml_backend_cuda_interface = {\n    /* .get_name                = */ ggml_backend_cuda_get_name,\n    /* .free                    = */ ggml_backend_cuda_free,\n    /* .set_tensor_async        = */ ggml_backend_cuda_set_tensor_async,\n    /* .get_tensor_async        = */ ggml_backend_cuda_get_tensor_async,\n    /* .cpy_tensor_async        = */ ggml_backend_cuda_cpy_tensor_async,\n    /* .synchronize             = */ ggml_backend_cuda_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_cuda_graph_compute,\n    /* .event_record            = */ ggml_backend_cuda_event_record,\n    /* .event_wait              = */ ggml_backend_cuda_event_wait,\n    /* .graph_optimize          = */ ggml_backend_cuda_graph_optimize,\n};\n\nstatic ggml_guid_t ggml_backend_cuda_guid() {\n    static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 };\n    return &guid;\n}\n\nbool ggml_backend_is_cuda(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());\n}\n\nint ggml_backend_cuda_get_device_count() {\n    return ggml_cuda_info().device_count;\n}\n\nvoid ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) {\n    cudaDeviceProp prop;\n    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));\n    snprintf(description, description_size, \"%s\", prop.name);\n}\n\nvoid ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) {\n 
   ggml_cuda_set_device(device);\n\n    CUDA_CHECK(cudaMemGetInfo(free, total));\n}\n\nbool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {\n    if (getenv(\"GGML_CUDA_REGISTER_HOST\") == nullptr) {\n        return false;\n    }\n\n#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)\n    cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);\n    if (err != cudaSuccess) {\n        // clear the error\n        (void)cudaGetLastError();\n\n        GGML_LOG_DEBUG(\"%s: failed to register %.2f MiB of pinned memory: %s\\n\", __func__,\n                           size / 1024.0 / 1024.0, cudaGetErrorString(err));\n        return false;\n    }\n    return true;\n#else\n    GGML_UNUSED(buffer);\n    GGML_UNUSED(size);\n    return false;\n#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)\n}\n\nvoid ggml_backend_cuda_unregister_host_buffer(void * buffer) {\n    if (getenv(\"GGML_CUDA_REGISTER_HOST\") == nullptr) {\n        return;\n    }\n\n    cudaError_t err = cudaHostUnregister(buffer);\n    if (err != cudaSuccess) {\n        // clear the error\n        (void)cudaGetLastError();\n    }\n}\n\n\n// backend device\n\nstruct ggml_backend_cuda_device_context {\n    int device;\n    std::string name;\n    std::string description;\n    std::string pci_bus_id;\n    int op_offload_min_batch_size;\n};\n\nstatic const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {\n    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;\n    return ctx->name.c_str();\n}\n\nstatic const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) {\n    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;\n    return ctx->description.c_str();\n}\n\n#if defined(__linux__)\n// Helper function to get available memory from /proc/meminfo for UMA systems\nstatic bool 
ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {\n    FILE * meminfo_file = nullptr;\n    // 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough\n    const size_t BUFFER_SIZE = 2048;\n    auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);\n    size_t bytes_read = 0;\n    long huge_tlb_total_pages = -1;\n    long huge_tlb_free_pages = -1;\n    long huge_tlb_page_size = -1;\n\n    if (available_memory_kb == nullptr || free_swap_kb == nullptr) {\n        return false;\n    }\n\n    meminfo_file = fopen(\"/proc/meminfo\", \"r\");\n    if (meminfo_file == nullptr) {\n        GGML_LOG_ERROR(\"%s: failed to open /proc/meminfo\\n\", __func__);\n        return false;\n    }\n\n    // Read file into buffer\n    bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);\n    fclose(meminfo_file);\n\n    if (bytes_read == 0) {\n        GGML_LOG_ERROR(\"%s: failed to read from /proc/meminfo\\n\", __func__);\n        return false;\n    }\n    file_buffer[bytes_read] = '\\0';\n\n    *available_memory_kb = -1;\n    *free_swap_kb = -1;\n\n    // Parse the file buffer line by line\n    char * line = file_buffer.get();\n    char * line_next;\n    while (line < file_buffer.get() + bytes_read) {\n        // Find the end of the current line\n        line_next = strchr(line, '\\n');\n        if (line_next != nullptr) {\n            *line_next = '\\0';\n            line_next++;\n        } else {\n            line_next = file_buffer.get() + bytes_read;\n        }\n\n        long value;\n        if (sscanf(line, \"MemAvailable: %ld kB\", &value) == 1) {\n            *available_memory_kb = value;\n        } else if (sscanf(line, \"SwapFree: %ld kB\", &value) == 1) {\n            *free_swap_kb = value;\n        } else if (sscanf(line, \"HugePages_Total: %ld\", &value) == 1) {\n            huge_tlb_total_pages = value;\n        } else if (sscanf(line, \"HugePages_Free: %ld\", &value) == 1) 
{\n            huge_tlb_free_pages = value;\n        } else if (sscanf(line, \"Hugepagesize: %ld kB\", &value) == 1) {\n            huge_tlb_page_size = value;\n        }\n\n        line = line_next;\n    }\n\n    if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {\n        *available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;\n\n        // Hugetlbfs pages are not swappable.\n        *free_swap_kb = 0;\n    }\n\n    GGML_LOG_DEBUG(\"%s: final available_memory_kb: %ld\\n\", __func__, *available_memory_kb);\n    return true;\n}\n#endif // defined(__linux__)\n\nstatic void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;\n    ggml_cuda_set_device(ctx->device);\n    CUDA_CHECK(cudaMemGetInfo(free, total));\n\n// ref: https://github.com/ggml-org/llama.cpp/pull/17368\n#if defined(__linux__)\n    // Check if this is a UMA (Unified Memory Architecture) system\n    cudaDeviceProp prop;\n    CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));\n\n    // Check if UMA is explicitly enabled via environment variable\n    bool uma_env = getenv(\"GGML_CUDA_ENABLE_UNIFIED_MEMORY\") != nullptr;\n    bool is_uma = prop.integrated > 0 || uma_env;\n\n    if (is_uma) {\n        // For UMA systems (like DGX Spark), use system memory info\n        long available_memory_kb = 0;\n        long free_swap_kb = 0;\n\n        if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {\n            *free = (size_t)available_memory_kb * 1024;\n        } else {\n            GGML_LOG_ERROR(\"%s: /proc/meminfo reading failed, using cudaMemGetInfo\\n\", __func__);\n        }\n    }\n#endif // defined(__linux__)\n\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {\n    GGML_UNUSED(dev);\n    return 
GGML_BACKEND_DEVICE_TYPE_GPU;\n}\n\nstatic void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {\n    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;\n\n    props->name        = ggml_backend_cuda_device_get_name(dev);\n    props->description = ggml_backend_cuda_device_get_description(dev);\n    props->type        = ggml_backend_cuda_device_get_type(dev);\n    props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();\n    ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);\n\n    bool host_buffer = getenv(\"GGML_CUDA_NO_PINNED\") == nullptr;\n#ifdef GGML_CUDA_NO_PEER_COPY\n    bool events = false;\n#else\n    bool events = true;\n#endif\n\n    props->caps = {\n        /* .async                 = */ true,\n        /* .host_buffer           = */ host_buffer,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ events,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {\n    GGML_UNUSED(params);\n    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;\n    return ggml_backend_cuda_init(ctx->device);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) {\n    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;\n    return ggml_backend_cuda_buffer_type(ctx->device);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(ggml_backend_dev_t dev) {\n    GGML_UNUSED(dev);\n    return ggml_backend_cuda_host_buffer_type();\n}\n\n// TODO: move these functions here\nstatic bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;\n\n    // split 
buffers can only be used with GGML_OP_MUL_MAT\n    if (op->op != GGML_OP_MUL_MAT) {\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) {\n                return false;\n            }\n        }\n    }\n\n    // check if all the sources are allocated on this device\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) {\n            ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context;\n            if (buft_ctx->device != dev_ctx->device) {\n                return false;\n            }\n        }\n    }\n\n    switch (op->op) {\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(op)) {\n                case GGML_UNARY_OP_ABS:\n                case GGML_UNARY_OP_SGN:\n                case GGML_UNARY_OP_NEG:\n                case GGML_UNARY_OP_STEP:\n                case GGML_UNARY_OP_GELU:\n                case GGML_UNARY_OP_SILU:\n                case GGML_UNARY_OP_RELU:\n                case GGML_UNARY_OP_SIGMOID:\n                case GGML_UNARY_OP_HARDSIGMOID:\n                case GGML_UNARY_OP_HARDSWISH:\n                case GGML_UNARY_OP_GELU_ERF:\n                case GGML_UNARY_OP_GELU_QUICK:\n                case GGML_UNARY_OP_TANH:\n                case GGML_UNARY_OP_EXP:\n                case GGML_UNARY_OP_EXPM1:\n                case GGML_UNARY_OP_SOFTPLUS:\n                case GGML_UNARY_OP_ELU:\n                case GGML_UNARY_OP_XIELU:\n                case GGML_UNARY_OP_FLOOR:\n                case GGML_UNARY_OP_CEIL:\n                case GGML_UNARY_OP_ROUND:\n                case GGML_UNARY_OP_TRUNC:\n                    // TODO: should become:\n                    //return ggml_is_contiguous_rows(op->src[0]);\n                    return 
ggml_is_contiguous(op->src[0]);\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(op)) {\n                case GGML_GLU_OP_REGLU:\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_SWIGLU:\n                case GGML_GLU_OP_SWIGLU_OAI:\n                case GGML_GLU_OP_GEGLU_ERF:\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    return ggml_is_contiguous_1(op->src[0]);\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_MUL_MAT:\n        case GGML_OP_MUL_MAT_ID:\n            {\n                struct ggml_tensor * a = op->src[0];\n                struct ggml_tensor * b = op->src[1];\n                if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) {\n                    if (a->ne[2] > 1 || a->ne[3] > 1) {\n                        return false;\n                    }\n                    // for small weight matrices the active device can end up without any rows, don't use row split in those cases\n                    // this avoids some edge cases (and the performance would not be good anyways)\n                    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context;\n                    int64_t row_low;\n                    int64_t row_high;\n                    get_row_split(&row_low, &row_high, a, buft_ctx->tensor_split, dev_ctx->device);\n                    if (row_low == row_high) {\n                        return false;\n                    }\n                }\n                if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {\n                    return false;\n                }\n#ifdef GGML_USE_MUSA\n                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;\n                if (b->ne[2]*b->ne[3] > 1 && 
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {\n                    if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&\n                            a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {\n                        return false;\n                    }\n                    if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&\n                            a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {\n                        return false;\n                    }\n                }\n#endif // GGML_USE_MUSA\n                switch (a->type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n                    case GGML_TYPE_MXFP4:\n                    case GGML_TYPE_Q2_K:\n                    case GGML_TYPE_Q3_K:\n                    case GGML_TYPE_Q4_K:\n                    case GGML_TYPE_Q5_K:\n                    case GGML_TYPE_Q6_K:\n                    case GGML_TYPE_Q8_K:\n                    case GGML_TYPE_IQ1_M:\n                    case GGML_TYPE_IQ1_S:\n                    case GGML_TYPE_IQ2_S:\n                    case GGML_TYPE_IQ2_XS:\n                    case GGML_TYPE_IQ2_XXS:\n                    case GGML_TYPE_IQ3_S:\n                    case GGML_TYPE_IQ3_XXS:\n                    case GGML_TYPE_IQ4_NL:\n                    case GGML_TYPE_IQ4_XS:\n                    case GGML_TYPE_BF16:\n                        return true;\n                    default:\n                        return false;\n                }\n            } break;\n        case GGML_OP_OUT_PROD:\n            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;\n        case GGML_OP_GET_ROWS:\n            {\n                switch 
(op->src[0]->type) {\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_BF16:\n                    case GGML_TYPE_I32:\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n                        return true;\n                    default:\n                        return false;\n                }\n            } break;\n        case GGML_OP_GET_ROWS_BACK:\n            {\n                return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;\n            } break;\n        case GGML_OP_SET_ROWS:\n            {\n                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||\n                       op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||\n                       op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&\n                       op->src[0]->type == GGML_TYPE_F32 &&\n                       (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);\n            } break;\n        case GGML_OP_SET:\n            {\n                const ggml_type t = op->type;\n                return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&\n                    t == op->src[0]->type &&\n                    t == op->src[1]->type;\n            } break;\n        case GGML_OP_CPY:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                ggml_type src1_type = op->src[1]->type;\n                if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&\n                    (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)\n                ) {\n            
        return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {\n                    return true;\n                }\n                if (src0_type == src1_type && 
ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {\n                    return true;\n                }\n                return false;\n            } break;\n        case GGML_OP_DUP:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;\n            } break;\n        case GGML_OP_ARGMAX:\n        case GGML_OP_COUNT_EQUAL:\n            {\n                return true;\n            } break;\n        case GGML_OP_REPEAT:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;\n            } break;\n        case GGML_OP_REPEAT_BACK:\n                return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);\n        case GGML_OP_CONCAT:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;\n            } break;\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                ggml_type src1_type = op->src[1]->type;\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                return false;\n            } break;\n        case GGML_OP_SILU_BACK:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n            break;\n        case GGML_OP_NORM:\n        case GGML_OP_RMS_NORM:\n        case GGML_OP_L2_NORM:\n            return true;\n        case GGML_OP_RMS_NORM_BACK:\n            return ggml_is_contiguous(op->src[0]);\n            break;\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_ADD:\n        case GGML_OP_ADD_ID:\n     
   case GGML_OP_ADD1:\n        case GGML_OP_SUB:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n        case GGML_OP_SCALE:\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_SIN:\n        case GGML_OP_COS:\n        case GGML_OP_CLAMP:\n        case GGML_OP_LOG:\n            return true;\n        case GGML_OP_SSM_SCAN: {\n            if (op->src[3]->ne[0] == 1) {\n                // Mamba2\n                // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)\n                return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;\n            } else {\n                // Mamba\n                // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)\n                return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;\n            }\n        }\n        case GGML_OP_SSM_CONV: {\n            // assumes d_inner % threads == 0\n            return op->src[0]->ne[1] % 128 == 0;\n        }\n        case GGML_OP_CONT:\n            return true;\n        case GGML_OP_DIAG_MASK_INF:\n            return true;\n        case GGML_OP_SOFT_MAX:\n            return true;\n        case GGML_OP_SOFT_MAX_BACK: {\n            float max_bias = 0.0f;\n            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));\n            return max_bias == 0.0f;\n        }\n        case GGML_OP_ROLL:\n            if(op->src[0]->type == GGML_TYPE_F32) {\n                return true;\n            }\n            return false;\n        case GGML_OP_ROPE:\n        case GGML_OP_ROPE_BACK: {\n            return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);\n        }\n        case GGML_OP_IM2COL:\n        case GGML_OP_IM2COL_3D:\n        case GGML_OP_CONV_2D:\n        case GGML_OP_CONV_2D_DW:\n        case GGML_OP_CONV_TRANSPOSE_2D:\n        case 
GGML_OP_POOL_2D:\n            return true;\n        case GGML_OP_ACC:\n            // TODO: extend support like so:\n            //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);\n            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);\n        case GGML_OP_SUM:\n            return ggml_is_contiguous_rows(op->src[0]);\n        case GGML_OP_TOP_K:\n        case GGML_OP_ARGSORT:\n#ifndef GGML_CUDA_USE_CUB\n            return op->src[0]->ne[0] <= 1024;\n#else\n            return true;\n#endif\n        case GGML_OP_SUM_ROWS:\n        case GGML_OP_MEAN:\n        case GGML_OP_GROUP_NORM:\n            return ggml_is_contiguous(op->src[0]);\n        case GGML_OP_PAD:\n            return true;\n        case GGML_OP_UPSCALE:\n        case GGML_OP_PAD_REFLECT_1D:\n        case GGML_OP_ARANGE:\n        case GGML_OP_TIMESTEP_EMBEDDING:\n        case GGML_OP_LEAKY_RELU:\n        case GGML_OP_RWKV_WKV6:\n        case GGML_OP_GATED_LINEAR_ATTN:\n        case GGML_OP_RWKV_WKV7:\n            return true;\n        case GGML_OP_GATED_DELTA_NET:\n            //TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327\n#ifdef GGML_USE_MUSA\n            return false;\n#else\n            return true;\n#endif // GGML_USE_MUSA\n        case GGML_OP_FLASH_ATTN_EXT:\n            return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);\n        case GGML_OP_CROSS_ENTROPY_LOSS:\n        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:\n        case GGML_OP_OPT_STEP_ADAMW:\n        case GGML_OP_OPT_STEP_SGD:\n        case GGML_OP_FILL:\n        case GGML_OP_CUMSUM:\n        case GGML_OP_TRI:\n        case GGML_OP_DIAG:\n        case GGML_OP_SOLVE_TRI:\n            return true;\n\n        default:\n            return false;\n    }\n}\n\nstatic bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    ggml_backend_cuda_device_context 
* dev_ctx = (ggml_backend_cuda_device_context *) dev->context;\n    const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;\n    return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft)));\n}\n\nstatic int64_t get_op_batch_size(const ggml_tensor * op) {\n    switch (op->op) {\n        case GGML_OP_GET_ROWS:\n            return 0;\n        case GGML_OP_MUL_MAT:\n            return op->ne[1];\n        case GGML_OP_MUL_MAT_ID:\n        case GGML_OP_ROPE:\n        case GGML_OP_ROPE_BACK:\n            return op->ne[2];\n        default:\n            return ggml_nrows(op);\n    }\n}\n\nstatic bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;\n\n    return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;\n}\n\nstatic ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {\n#ifdef GGML_CUDA_NO_PEER_COPY\n    return nullptr;\n#else\n    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context;\n\n    ggml_cuda_set_device(dev_ctx->device);\n\n    cudaEvent_t event;\n    CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));\n\n    return new ggml_backend_event {\n        /* .device  = */ dev,\n        /* .context = */ event,\n    };\n#endif\n}\n\nstatic void ggml_backend_cuda_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {\n    GGML_UNUSED(dev);\n\n    CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context));\n    delete event;\n}\n\nstatic void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {\n    GGML_UNUSED(dev);\n    CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));\n}\n\nstatic const ggml_backend_device_i 
ggml_backend_cuda_device_interface = {\n    /* .get_name                = */ ggml_backend_cuda_device_get_name,\n    /* .get_description         = */ ggml_backend_cuda_device_get_description,\n    /* .get_memory              = */ ggml_backend_cuda_device_get_memory,\n    /* .get_type                = */ ggml_backend_cuda_device_get_type,\n    /* .get_props               = */ ggml_backend_cuda_device_get_props,\n    /* .init_backend            = */ ggml_backend_cuda_device_init_backend,\n    /* .get_buffer_type         = */ ggml_backend_cuda_device_get_buffer_type,\n    /* .get_host_buffer_type    = */ ggml_backend_cuda_device_get_host_buffer_type,\n    /* .buffer_from_host_ptr    = */ NULL,\n    /* .supports_op             = */ ggml_backend_cuda_device_supports_op,\n    /* .supports_buft           = */ ggml_backend_cuda_device_supports_buft,\n    /* .offload_op              = */ ggml_backend_cuda_device_offload_op,\n    /* .event_new               = */ ggml_backend_cuda_device_event_new,\n    /* .event_free              = */ ggml_backend_cuda_device_event_free,\n    /* .event_synchronize       = */ ggml_backend_cuda_device_event_synchronize,\n};\n\n// backend reg\n\nstruct ggml_backend_cuda_reg_context {\n    std::vector<ggml_backend_dev_t> devices;\n};\n\nstatic const char * ggml_backend_cuda_reg_get_name(ggml_backend_reg_t reg) {\n    GGML_UNUSED(reg);\n    return GGML_CUDA_NAME;\n}\n\nstatic size_t ggml_backend_cuda_reg_get_device_count(ggml_backend_reg_t reg) {\n    ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;\n    return ctx->devices.size();\n}\n\nstatic ggml_backend_dev_t ggml_backend_cuda_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;\n    GGML_ASSERT(index < ctx->devices.size());\n    return ctx->devices[index];\n}\n\nstatic ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t reg) {\n    static 
std::vector<ggml_backend_feature> features = []() {\n        std::vector<ggml_backend_feature> features;\n    #define _STRINGIFY(...) #__VA_ARGS__\n    #define STRINGIFY(...) _STRINGIFY(__VA_ARGS__)\n\n    #ifdef __CUDA_ARCH_LIST__\n        features.push_back({ \"ARCHS\", STRINGIFY(__CUDA_ARCH_LIST__) });\n    #endif\n\n    #ifdef GGML_CUDA_FORCE_MMQ\n        features.push_back({ \"FORCE_MMQ\", \"1\" });\n    #endif\n\n    #ifdef GGML_CUDA_FORCE_CUBLAS\n        features.push_back({ \"FORCE_CUBLAS\", \"1\" });\n    #endif\n\n    #ifndef GGML_USE_VMM\n        features.push_back({ \"NO_VMM\", \"1\" });\n    #endif\n\n    #ifdef GGML_CUDA_NO_PEER_COPY\n        features.push_back({ \"NO_PEER_COPY\", \"1\" });\n    #endif\n\n    #ifdef GGML_CUDA_USE_GRAPHS\n        features.push_back({ \"USE_GRAPHS\", \"1\" });\n    #endif\n\n    #ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE\n        features.push_back({ \"PEER_MAX_BATCH_SIZE\", STRINGIFY(GGML_CUDA_PEER_MAX_BATCH_SIZE) });\n    #endif\n\n    #ifdef GGML_CUDA_FA_ALL_QUANTS\n        features.push_back({ \"FA_ALL_QUANTS\", \"1\" });\n    #endif\n\n    {\n        const auto & info = ggml_cuda_info();\n        for (int id = 0; id < info.device_count; ++id) {\n            if (blackwell_mma_available(info.devices[id].cc)) {\n                features.push_back({ \"BLACKWELL_NATIVE_FP4\", \"1\"});\n                break;\n            }\n        }\n    }\n\n    #undef _STRINGIFY\n    #undef STRINGIFY\n\n        features.push_back({ nullptr, nullptr });\n\n        return features;\n    }();\n\n    return features.data();\n\n    GGML_UNUSED(reg);\n}\n\nstatic void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    GGML_UNUSED(reg);\n    if (strcmp(name, \"ggml_backend_split_buffer_type\") == 0) {\n        return (void *)ggml_backend_cuda_split_buffer_type;\n    }\n    if (strcmp(name, \"ggml_backend_register_host_buffer\") == 0) {\n        return (void *)ggml_backend_cuda_register_host_buffer;\n    }\n 
   if (strcmp(name, \"ggml_backend_unregister_host_buffer\") == 0) {\n        return (void *)ggml_backend_cuda_unregister_host_buffer;\n    }\n    if (strcmp(name, \"ggml_backend_get_features\") == 0) {\n        return (void *)ggml_backend_cuda_get_features;\n    }\n    return nullptr;\n}\n\nstatic const ggml_backend_reg_i ggml_backend_cuda_reg_interface = {\n    /* .get_name          = */ ggml_backend_cuda_reg_get_name,\n    /* .get_device_count  = */ ggml_backend_cuda_reg_get_device_count,\n    /* .get_device        = */ ggml_backend_cuda_reg_get_device,\n    /* .get_proc_address  = */ ggml_backend_cuda_reg_get_proc_address,\n};\n\n// backend registry\nggml_backend_reg_t ggml_backend_cuda_reg() {\n    static ggml_backend_reg reg;\n    static bool initialized = false;\n\n    {\n        static std::mutex mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n        if (!initialized) {\n            ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;\n            const int min_batch_size = getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\") ? 
atoi(getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\")) : 32;\n\n            for (int i = 0; i < ggml_cuda_info().device_count; i++) {\n                ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;\n                dev_ctx->device = i;\n                dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);\n\n                cudaDeviceProp prop;\n                CUDA_CHECK(cudaGetDeviceProperties(&prop, i));\n                dev_ctx->description = prop.name;\n\n                char pci_bus_id[16] = {};\n                snprintf(pci_bus_id, sizeof(pci_bus_id), \"%04x:%02x:%02x.0\", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);\n                dev_ctx->pci_bus_id = pci_bus_id;\n                dev_ctx->op_offload_min_batch_size = min_batch_size;\n\n                ggml_backend_dev_t dev = new ggml_backend_device {\n                    /* .iface   = */ ggml_backend_cuda_device_interface,\n                    /* .reg     = */ &reg,\n                    /* .context = */ dev_ctx\n                };\n                ctx->devices.push_back(dev);\n            }\n\n            reg = ggml_backend_reg {\n                /* .api_version = */ GGML_BACKEND_API_VERSION,\n                /* .iface       = */ ggml_backend_cuda_reg_interface,\n                /* .context     = */ ctx\n            };\n        }\n\n        initialized = true;\n    }\n\n    return &reg;\n}\n\nggml_backend_t ggml_backend_cuda_init(int device) {\n    if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {\n        GGML_LOG_ERROR(\"%s: invalid device %d\\n\", __func__, device);\n        return nullptr;\n    }\n\n    ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);\n    if (ctx == nullptr) {\n        GGML_LOG_ERROR(\"%s: failed to allocate context\\n\", __func__);\n        return nullptr;\n    }\n\n    ggml_backend_t cuda_backend = new ggml_backend {\n        /* .guid    = */ ggml_backend_cuda_guid(),\n        /* .iface   = */ 
ggml_backend_cuda_interface,\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),\n        /* .context = */ ctx,\n    };\n\n    return cuda_backend;\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg)\n"
  },
  {
    "path": "src/ggml-cuda/gla.cu",
    "content": "#include \"common.cuh\"\n#include \"gla.cuh\"\n\ntemplate<int HEAD_SIZE>\nstatic __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,\n     const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {\n    const int tid = threadIdx.x;\n    const int bid = blockIdx.x;\n\n    const int head_size = HEAD_SIZE;\n    const int batch_i = bid / H;\n    const int head_i = bid % H;\n    const int state_size = C * head_size;\n    const int n_seq_tokens = T / B;\n\n    float state[head_size];\n    __shared__ float _k[head_size], _r[head_size], _td[head_size];\n\n    #pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];\n    }\n\n    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {\n        __syncthreads();\n        _k[tid] = k[t];\n        _r[tid] = r[t];\n        _td[tid] = td[t];\n        __syncthreads();\n\n        const float _v = v[t];\n        float y = 0;\n        for (int j = 0; j < head_size; j += 4) {\n            const float4 & k = (float4 &)(_k[j]);\n            const float4 & r = (float4 &)(_r[j]);\n            const float4 & td = (float4 &)(_td[j]);\n            float4 & s = (float4 &)(state[j]);\n            float4 kv;\n\n            kv.x = k.x * _v;\n            kv.y = k.y * _v;\n            kv.z = k.z * _v;\n            kv.w = k.w * _v;\n\n            s.x = s.x * td.x + kv.x;\n            s.y = s.y * td.y + kv.y;\n            s.z = s.z * td.z + kv.z;\n            s.w = s.w * td.w + kv.w;\n\n            y += r.x * s.x;\n            y += r.y * s.y;\n            y += r.z * s.z;\n            y += r.w * s.w;\n        }\n        dst[t] = y * scale;\n    }\n\n    #pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        dst[T * C + batch_i * state_size + 
head_i * head_size * head_size + i * head_size + tid] = state[i];\n    }\n}\n\nvoid ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const float * k_d  = (const float *)dst->src[0]->data;\n    const float * v_d  = (const float *)dst->src[1]->data;\n    const float * r_d  = (const float *)dst->src[2]->data;\n    const float * td_d = (const float *)dst->src[3]->data;\n    const float * s_d  = (const float *)dst->src[4]->data;\n\n    const int64_t B = dst->src[4]->ne[1];\n    const int64_t T = dst->src[0]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t H = dst->src[0]->ne[1];\n\n    float scale;\n    memcpy(&scale, (float*)dst->op_params, sizeof(float));\n\n    float * dst_d = (float *)dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);\n    GGML_ASSERT(C % H == 0);\n    GGML_ASSERT(C / H == 64 || C / H == 128);\n\n\n    if (C / H == 64) {\n        gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);\n    } else {\n        gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/gla.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/im2col.cu",
    "content": "#include \"im2col.cuh\"\n\n#define MAX_GRIDDIM_Z 65535\n\ntemplate <typename T>\nstatic  __global__ void im2col_kernel(\n        const float * x, T * dst,\n        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,\n        int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,\n        int s0, int s1, int p0, int p1, int d0, int d1) {\n    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;\n    if (i >= IC_KH_KW) {\n        return;\n    }\n\n    const int64_t iic = i / (KH_KW);\n    const int64_t rem = i - iic * KH_KW;\n    const int64_t ikh = rem / KW;\n    const int64_t ikw = rem - ikh * KW;\n\n    const int64_t  iow = blockIdx.y;\n    for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {\n        const int64_t  in = iz / OH;\n        const int64_t  ioh = iz - in * OH;\n\n        const int64_t iiw = iow * s0 + ikw * d0 - p0;\n        const int64_t iih = ioh * s1 + ikh * d1 - p1;\n\n        const int64_t offset_dst =\n            ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;\n\n        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {\n            dst[offset_dst] = 0.0f;\n        } else {\n            const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;\n            dst[offset_dst] = x[offset_src + iih * IW + iiw];\n        }\n    }\n\n    GGML_UNUSED(IC);\n    GGML_UNUSED(KH);\n}\n\n// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]\ntemplate <typename T>\nstatic void im2col_cuda(const float * x, T* dst,\n    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,\n    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,\n    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {\n    const int64_t IC_KH_KW = IC * KH * KW;\n    const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;\n    const int64_t N_OH = N * OH;\n    const int64_t KH_KW = KW*KH;\n    dim3 
block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));\n    im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,\n                                                                                     IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,\n                                                                                     s0, s1, p0, p1, d0, d1);\n}\n\nstatic void im2col_cuda_f16(const float * x, half * dst,\n    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,\n    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,\n    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {\n\n    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);\n}\n\nstatic void im2col_cuda_f32(const float * x, float * dst,\n    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,\n    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,\n    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {\n\n    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);\n}\n\nvoid ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const float * src1_d = (const float *)src1->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);\n\n    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];\n    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];\n    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];\n    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];\n    const int32_t d1 = ((const 
int32_t*)(dst->op_params))[5];\n\n    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;\n\n    const int64_t IC = src1->ne[is_2D ? 2 : 1];\n    const int64_t IH = is_2D ? src1->ne[1] : 1;\n    const int64_t IW =         src1->ne[0];\n\n    const int64_t KH = is_2D ? src0->ne[1] : 1;\n    const int64_t KW =         src0->ne[0];\n\n    const int64_t OH = is_2D ? dst->ne[2] : 1;\n    const int64_t OW =         dst->ne[1];\n\n    const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32\n    const int64_t N        = src1->ne[is_2D ? 3 : 2];\n    const int64_t IH_IW    = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32\n\n    if(dst->type == GGML_TYPE_F16) {\n        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);\n    } else {\n        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);\n    }\n}\n\n// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]\ntemplate <typename T>\nstatic  __global__ void im2col_3d_kernel(\n        const float * src, T * dst,\n        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,\n        int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,\n        int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,\n        int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,\n        int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,\n        int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,\n        int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {\n    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;\n    if (i >= IC_KD_KH_KW) {\n        return;\n    }\n    GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); 
GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH);\n    GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);\n\n    const int64_t iic = i / KD_KH_KW;\n    const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;\n    const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;\n    const int64_t ikw = i % KW;\n\n    const int64_t  iow = blockIdx.y;\n    for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {\n        const int64_t in  = iz / OD_OH;\n        const int64_t iod = (iz - in*OD_OH) / OH;\n        const int64_t ioh = iz % OH;\n\n        const int64_t iiw = iow * s0 + ikw * d0 - p0;\n        const int64_t iih = ioh * s1 + ikh * d1 - p1;\n        const int64_t iid = iod * s2 + ikd * d2 - p2;\n\n        const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;\n\n        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {\n            dst[offset_dst] = 0.0f;\n        } else {\n            const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);\n            dst[offset_dst] = src[offset_src];\n        }\n    }\n}\n\n// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]\ntemplate <typename T>\nstatic void im2col_3d_cuda(const float * src, T* dst,\n    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,\n    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,\n    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,\n    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {\n    const int64_t OH_OW = OH*OW;\n    const int64_t KD_KH_KW = KD*KH*KW;\n    const int64_t ID_IH_IW = ID*IH*IW;\n    const int64_t KH_KW = KH*KW;\n    const int64_t IH_IW = IH*IW;\n    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;\n    const int64_t OW_KD_KH_KW = 
OW*KD*KH*KW;\n    const int64_t N_OD_OH = N*OD*OH;\n    const int64_t OD_OH = OD*OH;\n    const int64_t IC_ID_IH_IW = IC*ID*IH*IW;\n    const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;\n    const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;\n    const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;\n    const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;\n    dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));\n    im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,\n                                                                                           OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,\n                                                                                           IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,\n                                                                                           OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,\n                                                                                           stride_q, stride_z, stride_y, stride_x,\n                                                                                           s0, s1, s2, p0, p1, p2, d0, d1, d2);\n}\n\nstatic void im2col_3d_cuda_f16(const float * src, half * dst,\n    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,\n    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,\n    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,\n    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {\n\n    im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,\n                         stride_q, stride_z, stride_y, stride_x,\n                         s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);\n}\n\nstatic void im2col_3d_cuda_f32(const float * src, float * 
dst,\n    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,\n    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,\n    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,\n    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {\n\n    im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,\n                          stride_q, stride_z, stride_y, stride_x,\n                          s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);\n}\n\nvoid ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const float * src1_d = (const float *)src1->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];\n    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];\n    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];\n    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];\n    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];\n    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];\n    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];\n    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];\n    const int32_t IC = ((const int32_t *)(dst->op_params))[9];\n\n    const int64_t N  = ne13 / IC;\n    const int64_t ID = ne12;\n    const int64_t IH = ne11;\n    const int64_t IW = ne10;\n\n    const int64_t OC = ne03 / IC;\n    const int64_t KD = ne02;\n    const int64_t KH = ne01;\n    const int64_t KW = ne00;\n\n    const int64_t OD = ne3 / N;\n    const int64_t OH = ne2;\n  
  const int64_t OW = ne1;\n\n    const size_t  es       = ggml_element_size(src1);\n    const int64_t stride_x = src1->nb[0] / es;\n    const int64_t stride_y = src1->nb[1] / es;\n    const int64_t stride_z = src1->nb[2] / es;\n    const int64_t stride_q = src1->nb[3] / es;\n\n    if(dst->type == GGML_TYPE_F16) {\n        im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,\n                           stride_q, stride_z, stride_y, stride_x,\n                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);\n    } else {\n        im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,\n                           stride_q, stride_z, stride_y, stride_x,\n                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/im2col.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n// Upper bound on threads per block for the im2col kernel launches\n// (e.g. im2col_3d_cuda launches MIN(IC*KD*KH*KW, CUDA_IM2COL_BLOCK_SIZE) threads per block).\n#define CUDA_IM2COL_BLOCK_SIZE 256\n\n// im2col for the tensors attached to dst (presumably the 2D counterpart of ggml_cuda_op_im2col_3d).\nvoid ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\n// 3D im2col: unfolds the F32 input dst->src[1] into dst (F16 or F32); dst->src[0] supplies the kernel extents.\nvoid ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/mean.cu",
    "content": "#include \"mean.cuh\"\n#include \"reduce_rows.cuh\"\n\n#ifdef GGML_CUDA_USE_CUB\n#include <cub/cub.cuh>\nusing namespace cub;\n#endif  // GGML_CUDA_USE_CUB\n\n// Divides the value at *result by count in place; launched as a single thread to turn a sum into a mean.\ntemplate <typename T> __global__ void divide_by_count(T * result, size_t count) {\n    *result /= static_cast<T>(count);\n}\n\n// Row-wise mean: dst receives the mean of every row of dst->src[0].\n// Requires F32 src/dst and a contiguous src. With GGML_CUDA_USE_CUB, a single sufficiently long row is\n// reduced with DeviceReduce::Sum followed by a division; every other case goes through reduce_rows_f32.\nvoid ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0   = dst->src[0];\n    const float *       src0_d = (const float *) src0->data;\n    float *             dst_d  = (float *) dst->data;\n    cudaStream_t        stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    const int64_t ncols = src0->ne[0];\n    const int64_t nrows = ggml_nrows(src0);\n\n#ifdef GGML_CUDA_USE_CUB\n#ifdef USE_CUDA_GRAPH\n    cudaStreamCaptureStatus iscapturing;\n    CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));\n#endif // USE_CUDA_GRAPH\n\n    // Special case for reducing vectors: only a single row can take the device-wide reduction.\n    bool use_device_reduce = false;\n    if (nrows == 1) {\n#ifdef USE_CUDA_GRAPH\n        // The column threshold depends on the CUDA graph state: 65536 when graphs count as effectively\n        // disabled (no graph instance and not capturing, or any_cuda_graph_enabled()), 32768 otherwise.\n        // NOTE(review): the any_cuda_graph_enabled() term reads inverted relative to the 'effectively\n        // disabled' wording; behavior is kept as-is - confirm against its definition.\n        if (ncols > 32768) {\n            const bool graphs_effectively_disabled =\n                (!ctx.any_cuda_graph_has_instance() && iscapturing == cudaStreamCaptureStatusNone) ||\n                ctx.any_cuda_graph_enabled();\n            use_device_reduce = ncols > (graphs_effectively_disabled ? 65536 : 32768);\n        }\n#else\n        use_device_reduce = ncols > 65536;\n#endif // USE_CUDA_GRAPH\n    }\n\n    if (use_device_reduce) {\n        ggml_cuda_pool & pool     = ctx.pool();\n        size_t           tmp_size = 0;\n\n        // The first call only queries the temporary storage size; the second performs the sum.\n        DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);\n        ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);\n        DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);\n\n        // Divide the sum by ncols to obtain the mean.\n        divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);\n        return;\n    }\n#endif // GGML_CUDA_USE_CUB\n\n    // Grid of nrows blocks along x (presumably one block per row inside reduce_rows_f32).\n    const dim3 block_nums(nrows, 1, 1);\n\n    const int id  = ggml_cuda_get_device();\n    const int nsm = ggml_cuda_info().devices[id].nsm;\n\n    // Heuristic for block size selection to optimize occupancy: 512 threads when there are fewer\n    // than two rows per SM, otherwise 32 or 128 depending on the row length.\n    // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132\n    const int  block_size = ((nrows / nsm) < 2) ? 512 : (ncols < 1024 ? 32 : 128);\n    const dim3 block_dims(block_size, 1, 1);\n    reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);\n}\n"
  },
  {
    "path": "src/ggml-cuda/mean.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n// Row-wise mean: writes the mean of each row of dst->src[0] (F32, contiguous) to dst (F32).\nvoid ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/mma.cuh",
    "content": "#pragma once\n// This file contains primitives that expose the tensor core PTX instructions for CUDA code.\n// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.\n// The documentation for the PTX instructions can be found under:\n//   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction\n//\n// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.\n// A is a row-major matrix with shape M x K.\n// B is a column-major matrix with shape K x N.\n// C is a column-major matrix with shape M x N.\n// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.\n// Note that J is measured in physical 32 bit elements instead of logical elements.\n// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.\n// All matrix tiles have ne physical 32 bit elements per warp.\n//\n// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.\n// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.\n\n#include \"common.cuh\"\n\n// On Volta each warp is doing 4 8x8 mma operations in parallel.\n// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.\n// However, the i indices in this file are by default permuted to simplify the index calculations.\n// #define GGML_CUDA_MMA_NO_VOLTA_PERM\n\n#if CUDART_VERSION >= 11080\n\nstatic __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {\n    int ret = 0;\n\n#ifdef TURING_MMA_AVAILABLE\n    asm(\"movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\"\n        : \"=r\"(ret) : \"r\"(x));\n#else\n    GGML_UNUSED(x);\n    
NO_DEVICE_CODE;\n#endif // defined(TURING_MMA_AVAILABLE)\n    return ret;\n}\n\n#else\n\nstatic __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {\n    // Imagine transposing row-major matrix to column-major matrix.\n    const int src_i_low  = 2 * (threadIdx.x % 4);\n    const int src_i_high = src_i_low + 1;\n    const int src_j      = threadIdx.x / 4;\n\n    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;\n    const int src_laneid_high = src_i_high * 4 + src_j / 2;\n\n    const int shift_low  = ((src_j + 0) % 2) * 16;\n    const int shift_high = ((src_j + 1) % 2) * 16;\n\n    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;\n    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;\n\n    return ret_low | ret_high;\n}\n\n#endif // CUDART_VERSION >= 11080\n\nstatic __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {\n    half2 ret;\n    *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));\n    return ret;\n}\n\nnamespace ggml_cuda_mma {\n\n    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,\n    //     effectively the warp is being split into subgroups of threads that each perform a single mma instruction.\n    // In those cases the data can be split in different ways across the warp.\n    enum data_layout {\n        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.\n        // For the A/C matrices this means I major == row major, J major == column major.\n        // For the B matrix this means I major == column major, J major == row major.\n        // MIRRORED == Each data value is held exactly once per thread subgroup.\n        DATA_LAYOUT_I_MAJOR           =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.\n        DATA_LAYOUT_J_MAJOR      
     = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.\n        DATA_LAYOUT_I_MAJOR_MIRRORED  = 20, // Volta, matrix A&B for RDNA3.\n        DATA_LAYOUT_J_MAJOR_MIRRORED  = 30,\n    };\n    // Implemented mma combinations are:\n    //   - (I_MAJOR, I_MAJOR)          -> I_MAJOR\n    //   - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR\n    //   - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR\n\n    static constexpr bool is_i_major(const data_layout dl) {\n        return dl == DATA_LAYOUT_I_MAJOR ||\n               dl == DATA_LAYOUT_I_MAJOR_MIRRORED;\n    }\n\n    static constexpr __device__ data_layout get_input_data_layout() {\n#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n        return DATA_LAYOUT_I_MAJOR_MIRRORED;\n#else\n        return DATA_LAYOUT_I_MAJOR;\n#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n    }\n\n    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>\n    struct tile {};\n\n    template <int I_, int J_, typename T>\n    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {\n        static constexpr int         I  = I_;\n        static constexpr int         J  = J_;\n        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;\n\n#if defined(AMD_MFMA_AVAILABLE)\n        static constexpr int ne = I * J / 64;\n        T x[ne] = {0};\n\n        static constexpr __device__ bool supported() {\n            if (I == 64 && J ==  2) return true;\n            if (I == 16 && J ==  8) return true;\n            if (I == 32 && J ==  4) return true;\n            if (I == 16 && J == 16) return true;\n            if (I == 32 && J == 32) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>\n                return threadIdx.x % 16;\n            } else if constexpr (I == 16 && J == 8) {\n                return threadIdx.x % 16;\n            } else 
if constexpr (I == 32 && J == 4) {\n                return threadIdx.x % 32;\n            } else if constexpr (I == 16 && J == 16) {\n                return threadIdx.x % 16;\n            } else if constexpr (I == 32 && J == 32) {\n                return threadIdx.x % 32;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>\n                return (2 * ((threadIdx.x / 16) % 2) + l);\n            } else if constexpr (I == 16 && J == 8) {\n                return 2 * (threadIdx.x / 16) + l;\n            } else if constexpr (I == 32 && J == 4) {\n                return 2 * (threadIdx.x / 32) + l;\n            } else if constexpr (I == 16 && J == 16) {\n                return 4 * (threadIdx.x / 16) + l;\n            } else if constexpr (I == 32 && J == 32) {\n                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n        static constexpr int ne = I * J / 32;\n        T x[ne] = {0};\n\n        static constexpr __device__ bool supported() {\n            if (I == 32 && J ==  8) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 32 && J == 8) {\n#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM\n                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);\n#else\n                return (l & 2) + (threadIdx.x & ~2);\n#endif // GGML_CUDA_MMA_NO_VOLTA_PERM\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr 
(I == 32 && J == 8) {\n                return (threadIdx.x & 2) + (l & (4 + 1));\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#elif defined(AMD_WMMA_AVAILABLE)\n        static constexpr int ne = I * J / 32;\n        T x[ne] = {0};\n\n        static constexpr __device__ bool supported() {\n            if (I == 16 && J == 16) return true;\n            if (I == 16 && J == 8) return true;\n            if (I == 16 && J == 4) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (supported()) {\n                return threadIdx.x % 16;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 16 && J == 16) {\n#if defined(RDNA3)\n                if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {\n                    // matrix C\n                    return 2 * l + (threadIdx.x / 16);\n                } else {\n                    // matrix A&B\n                    return l;\n                }\n#else\n                // matrix C is the transposed matrix A&B on RDNA4\n                return ne * (threadIdx.x / 16) + l;\n#endif // defined(RDNA3)\n            } else if constexpr (I == 16 && J == 8) {\n                // mmq input for RDNA4\n                return ne * (threadIdx.x / 16) + l;\n            } else if constexpr (I == 16 && J == 4) {\n                return ne * (threadIdx.x / 16) + l;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#else\n        static constexpr int ne = I * J / 32;\n        T x[ne] = {0};\n\n        static constexpr __device__ bool supported() {\n            if (I ==  8 && J ==  4) return true;\n            if (I ==  8 && J ==  8) return true;\n            if (I 
== 16 && J ==  8) return true;\n            if (I == 16 && J == 16) return true;\n            if (I == 32 && J ==  8) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 8 && J == 4) {\n                return threadIdx.x / 4;\n            } else if constexpr (I == 8 && J == 8) {\n                return threadIdx.x / 4;\n            } else if constexpr (I == 16 && J == 8) {\n                return ((l / 2) * 8) + (threadIdx.x / 4);\n            } else if constexpr (I == 16 && J == 16) {\n                return (((l / 2) % 2) * 8) + (threadIdx.x / 4);\n            } else if constexpr (I == 32 && J == 8) {\n                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 8 && J == 4) {\n                return threadIdx.x % 4;\n            } else if constexpr (I == 8 && J == 8) {\n                return (l * 4) + (threadIdx.x % 4);\n            } else if constexpr (I == 16 && J == 8) {\n                return ((threadIdx.x % 4) * 2) + (l % 2);\n            } else if constexpr (I == 16 && J == 16) {\n                return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);\n            } else if constexpr (I == 32 && J == 8) {\n                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#endif // defined(GGML_USE_HIP)\n    };\n\n    template <int I_, int J_>\n    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {\n        static constexpr int         I  = I_;\n        static constexpr int         J  = J_;\n        static constexpr data_layout dl = 
DATA_LAYOUT_I_MAJOR;\n\n#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n        static constexpr int ne = I * J / WARP_SIZE;\n        half2 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            if (I == 32 && J ==  4) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 32 && J == 4) {\n#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM\n                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);\n#else\n                return threadIdx.x;\n#endif // GGML_CUDA_MMA_NO_VOLTA_PERM\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 32 && J == 4) {\n                return l;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#elif defined(AMD_WMMA_AVAILABLE)\n        static constexpr int ne = I * J / 32;\n        half2 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            if (I == 16 && J == 8) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 16 && J == 8) {\n                return threadIdx.x % 16;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 16 && J == 8) {\n                return ne * (threadIdx.x / 16) + l;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#elif defined(AMD_MFMA_AVAILABLE)\n        static constexpr int ne = I * J / 64;\n        half2 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() 
{\n            if (I == 16 && J == 8) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 16 && J == 8) {\n                return threadIdx.x % 16;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 16 && J == 8) {\n                return ne * (threadIdx.x / 16) + l;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#else\n        static constexpr int ne = I * J / WARP_SIZE;\n        half2 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            if (I ==  8 && J ==  4) return true;\n            if (I ==  8 && J ==  8) return true;\n            if (I == 16 && J ==  8) return true;\n            if (I == 16 && J == 16) return true;\n            if (I == 32 && J ==  8) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 8 && J == 8) {\n                return threadIdx.x / 4;\n            } else if constexpr (I == 16 && J == 4) {\n                return (l * 8) + (threadIdx.x / 4);\n            } else if constexpr (I == 16 && J == 8) {\n                return ((l % 2) * 8) + (threadIdx.x / 4);\n            } else if constexpr (I == 32 && J == 8) {\n                return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 8 && J == 8) {\n                return (l * 4) + (threadIdx.x % 4);\n            } else if constexpr (I == 16 && J == 4) {\n                return threadIdx.x % 4;\n            } else if 
constexpr (I == 16 && J == 8) {\n                return ((l / 2) * 4) + (threadIdx.x % 4);\n            } else if constexpr (I == 32 && J == 8) {\n                return ((l & 2) * 2) + (threadIdx.x % 4);\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n    };\n\n    template <int I_, int J_>\n    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {\n        static constexpr int         I  = I_;\n        static constexpr int         J  = J_;\n        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;\n\n#if defined(AMD_WMMA_AVAILABLE)\n        static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;\n        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);\n        }\n#elif defined(AMD_MFMA_AVAILABLE)\n        static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;\n        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);\n        }\n#else\n        static constexpr int ne = I * J / WARP_SIZE;\n        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool 
supported() {\n            if (I ==  8 && J ==  8) return true;\n            if (I == 16 && J ==  4) return true;\n            if (I == 16 && J ==  8) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 8 && J == 8) {\n                return threadIdx.x / 4;\n            } else if constexpr (I == 16 && J == 4) {\n                return (l * 8) + (threadIdx.x / 4);\n            } else if constexpr (I == 16 && J == 8) {\n                return ((l % 2) * 8) + (threadIdx.x / 4);\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 8 && J == 8) {\n                return (l * 4) + (threadIdx.x % 4);\n            } else if constexpr (I == 16 && J == 4) {\n                return threadIdx.x % 4;\n            } else if constexpr (I == 16 && J == 8) {\n                return ((l / 2) * 4) + (threadIdx.x % 4);\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#endif  // defined(AMD_WMMA_AVAILABLE)\n    };\n\n    template <int I_, int J_, typename T>\n    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {\n        static constexpr int         I  = I_;\n        static constexpr int         J  = J_;\n        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;\n\n        static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;\n        T x[ne] = {0};\n\n        static constexpr __device__ bool supported() {\n            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            return tile<I_, J_, T, 
DATA_LAYOUT_I_MAJOR>::get_i(l);\n        }\n    };\n\n    template <int I_, int J_, typename T>\n    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {\n        static constexpr int         I  = I_;\n        static constexpr int         J  = J_;\n        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;\n\n        // RDNA3\n        static constexpr int         ne = I * J / 32 * 2;\n\n        T x[ne] = {0};\n\n        static constexpr __device__ bool supported() {\n            if (I == 16 && J == 16) return true;\n            if (I == 16 && J == 8)  return true;\n            if (I == 16 && J == 4)  return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int /*l*/) {\n            if constexpr (supported()) {\n                return threadIdx.x % 16;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (supported()) {\n                return l;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n    };\n\n    template <int I_, int J_>\n    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {\n        static constexpr int         I  = I_;\n        static constexpr int         J  = J_;\n        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;\n#if defined(RDNA3)\n        static constexpr int         ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;\n\n        half2 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);\n        }\n\n        static __device__ __forceinline__ int get_j(const int 
l) {\n            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);\n        }\n#else // Volta\n        static constexpr int         ne = I * J / (WARP_SIZE/4);\n\n        half2 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            if (I ==  8 && J ==  4) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int /*l*/) {\n            if constexpr (I == 8 && J == 4) {\n                return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 8 && J == 4) {\n                return l;\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n#endif // defined(RDNA3)\n    };\n\n    template <int I_, int J_>\n    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {\n        static constexpr int         I  = I_;\n        static constexpr int         J  = J_;\n        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;\n        static constexpr int         ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;\n\n        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);\n        }\n    };\n\n    template <int I_, int J_>\n    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {\n        static constexpr int         I  = 
I_;\n        static constexpr int         J  = J_;\n        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;\n        static constexpr int         ne = I * J / (WARP_SIZE/4);\n\n        half2 x[ne] = {{0.0f, 0.0f}};\n\n        static constexpr __device__ bool supported() {\n            if (I ==  8 && J ==  4) return true;\n            return false;\n        }\n\n        static __device__ __forceinline__ int get_i(const int l) {\n            if constexpr (I == 8 && J == 4) {\n                return ((l / 2) * 4) + (threadIdx.x % 4);\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n\n        static __device__ __forceinline__ int get_j(const int l) {\n            if constexpr (I == 8 && J == 4) {\n                return ((threadIdx.x / 16) * 2) + (l % 2);\n            } else {\n                NO_DEVICE_CODE;\n                return -1;\n            }\n        }\n    };\n\n#if defined(TURING_MMA_AVAILABLE)\n    template <int I, int J>\n    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {\n        tile<I, J/2, half2> ret;\n#pragma unroll\n        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {\n            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);\n        }\n        return ret;\n    }\n\n    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {\n        tile<8, 8, half2> ret;\n        ret.x[0] = ggml_cuda_movmatrix(t.x[0]);\n        ret.x[1] = ggml_cuda_movmatrix(t.x[1]);\n\n        return ret;\n    }\n#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n    template <int I, int J>\n    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {\n        tile<I, J/2, half2> ret;\n#pragma unroll\n        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {\n            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], 
tile_float.x[l0 + 1]);\n        }\n        return ret;\n    }\n\n    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {\n        NO_DEVICE_CODE;\n        return tile<8, 8, half2>{};\n    }\n#else // Volta\n    template <int I, int J>\n    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {\n        tile<I, J/2, half2> ret;\n#pragma unroll\n        for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {\n            ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);\n            ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);\n\n            // On Volta FP16 and FP32 tiles have a different memory layout,\n            //     for the conversion threads with an offset of 2 need to exchange half their values:\n            ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(\n                0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);\n        }\n        return ret;\n    }\n#endif // defined(TURING_MMA_AVAILABLE)\n\n    static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {\n#if defined(RDNA4)\n        const int row = t.get_i(0);\n        const int left_right = t.get_j(0) / 4;\n        const int up_down = row / 8;\n        const int idx = row % 8;\n        reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 
1.0f : 0.0f;\n#else\n        GGML_UNUSED_VARS(t);\n        NO_DEVICE_CODE;\n#endif // defined(RDNA4)\n    }\n\n    template <int I, int J, typename T, data_layout dl>\n    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {\n#if defined(AMD_MFMA_AVAILABLE)\n        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>\n#pragma unroll\n            for (int l = 0; l < t.ne; ++l) {\n                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];\n            }\n        } else {\n            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));\n        }\n#elif defined(AMD_WMMA_AVAILABLE)\n        // All wmma layout has contiguous data when i-major.\n        if constexpr (is_i_major(dl)) {\n            // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()\n            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();\n            if constexpr (sizeof(t.x) > aligned_copy_bytes) {\n                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, \"bad type size\");\n                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;\n#pragma unroll\n                for (int i = 0; i < aligned_copy_count; ++i) {\n                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));\n                }\n            } else {\n                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));\n            }\n        } else {\n#pragma unroll\n            for (int l = 0; l < t.ne; ++l) {\n                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];\n            }\n        }\n#else\n#pragma unroll\n        for (int l = 0; l < t.ne; ++l) {\n            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];\n        }\n#endif // defined(AMD_MFMA_AVAILABLE)\n    }\n\n    template <typename 
T>\n    static __device__ __forceinline__ void load_ldmatrix(\n            tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {\n#ifdef TURING_MMA_AVAILABLE\n        int * xi = (int *) t.x;\n        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;\n        asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];\"\n            : \"=r\"(xi[0]), \"=r\"(xi[1])\n            : \"l\"(xs));\n#else\n        load_generic(t, xs0, stride);\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    template <typename T>\n    static __device__ __forceinline__ void load_ldmatrix(\n            tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {\n#ifdef TURING_MMA_AVAILABLE\n        int * xi = (int *) t.x;\n        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;\n        asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];\"\n            : \"=r\"(xi[0]), \"=r\"(xi[1])\n            : \"l\"(xs));\n#else\n#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n        GGML_UNUSED_VARS(t, xs0, stride);\n        NO_DEVICE_CODE;\n#else\n        load_generic(t, xs0, stride);\n#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    template <typename T, data_layout dl>\n    static __device__ __forceinline__ void load_ldmatrix(\n            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {\n#if defined(TURING_MMA_AVAILABLE)\n        int * xi = (int * ) t.x;\n        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);\n        asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];\"\n            : \"=r\"(xi[0]), \"=r\"(xi[1]), \"=r\"(xi[2]), \"=r\"(xi[3])\n            : \"l\"(xs));\n#else\n#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n#if 1\n        // TODO: more generic handling\n        static_assert(sizeof(T) == 4, \"bad type size\");\n        
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);\n        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);\n#else\n        load_generic(t, xs0, stride);\n#endif // 1\n#else\n        load_generic(t, xs0, stride);\n#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void load_ldmatrix(\n            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {\n        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);\n    }\n\n    static __device__ __forceinline__ void load_ldmatrix(\n            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {\n#pragma unroll\n        for (int l0 = 0; l0 < t.ne; l0 += 2) {\n            ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));\n        }\n    }\n\n    static __device__ __forceinline__ void load_ldmatrix(\n            tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {\n#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);\n#else\n        GGML_UNUSED_VARS(t, xs0, stride);\n        NO_DEVICE_CODE;\n#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n    }\n\n    template <typename T>\n    static __device__ __forceinline__ void load_ldmatrix_trans(\n            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {\n#ifdef TURING_MMA_AVAILABLE\n        int * xi = (int * ) t.x;\n        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);\n        asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];\"\n            : \"=r\"(xi[0]), \"=r\"(xi[2]), \"=r\"(xi[1]), \"=r\"(xi[3])\n            : \"l\"(xs));\n#else\n        GGML_UNUSED_VARS(t, xs0, stride);\n        NO_DEVICE_CODE;\n#endif // 
TURING_MMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {\n#ifdef TURING_MMA_AVAILABLE\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n        asm(\"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\"\n            : \"+r\"(D.x[0]), \"+r\"(D.x[1]), \"+r\"(D.x[2]), \"+r\"(D.x[3])\n            : \"r\"(A.x[0]), \"r\"(A.x[1]), \"r\"(B.x[0]));\n#else\n        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:\n        asm(\"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};\"\n            : \"+r\"(D.x[0]), \"+r\"(D.x[1])\n            : \"r\"(A.x[0]), \"r\"(B.x[0]));\n        asm(\"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};\"\n            : \"+r\"(D.x[2]), \"+r\"(D.x[3])\n            : \"r\"(A.x[1]), \"r\"(B.x[0]));\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {\n#ifdef TURING_MMA_AVAILABLE\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n        asm(\"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\"\n            : \"+r\"(D.x[0]), \"+r\"(D.x[1]), \"+r\"(D.x[2]), \"+r\"(D.x[3])\n            : \"r\"(A.x[0]), \"r\"(A.x[1]), \"r\"(A.x[2]), \"r\"(A.x[3]), \"r\"(B.x[0]), \"r\"(B.x[1]));\n#else\n        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:\n        asm(\"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};\"\n            : \"+r\"(D.x[0]), \"+r\"(D.x[1])\n            : \"r\"(A.x[0]), \"r\"(B.x[0]));\n        asm(\"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, 
{%3}, {%0, %1};\"\n            : \"+r\"(D.x[2]), \"+r\"(D.x[3])\n            : \"r\"(A.x[1]), \"r\"(B.x[0]));\n        asm(\"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};\"\n            : \"+r\"(D.x[0]), \"+r\"(D.x[1])\n            : \"r\"(A.x[2]), \"r\"(B.x[1]));\n        asm(\"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};\"\n            : \"+r\"(D.x[2]), \"+r\"(D.x[3])\n            : \"r\"(A.x[3]), \"r\"(B.x[1]));\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {\n#ifdef TURING_MMA_AVAILABLE\n        const int * Axi = (const int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        int       * Dxi = (int       *) D.x;\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n        asm(\"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[0]), \"r\"(Bxi[1]));\n#else\n        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Bxi[0]));\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1])\n            : \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[1]));\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<16, 
8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {\n#ifdef TURING_MMA_AVAILABLE\n        const int * Axi = (const int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        int       * Dxi = (int       *) D.x;\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n        asm(\"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[0]), \"r\"(Bxi[2]));\n        asm(\"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};\"\n            : \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[1]), \"r\"(Bxi[3]));\n#else\n        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Bxi[0]));\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1])\n            : \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[2]));\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};\"\n            : \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Bxi[1]));\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};\"\n            : \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[3]));\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n#elif defined(AMD_WMMA_AVAILABLE)\n#if defined(RDNA4)\n        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;\n        halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);\n        const 
halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);\n        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);\n        acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // defined(RDNA4)\n#elif defined(AMD_MFMA_AVAILABLE)\n        // MFMA: FP16 input, FP32 accumulate, convert back to half2.\n        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;\n        using floatx4_t = __attribute__((ext_vector_type(4))) float;\n\n        // Convert existing half2 accumulator to float for MFMA:\n        floatx4_t acc_f32;\n        {\n            const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);\n#pragma unroll\n            for (int i = 0; i < 4; ++i) {\n                acc_f32[i] = (float)acc_h[i];\n            }\n        }\n\n        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);\n        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);\n        acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);\n\n        // Convert back to half2:\n        {\n            halfx4_t result_h;\n#pragma unroll\n            for (int i = 0; i < 4; ++i) {\n                result_h[i] = (_Float16)acc_f32[i];\n            }\n            reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;\n        }\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    template <data_layout dl_ab, data_layout dl_d>\n    static __device__ __forceinline__ void mma(\n            tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {\n#ifdef AMPERE_MMA_AVAILABLE\n        const int * Axi = (const int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        int       * Dxi = (int       *) D.x;\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, 
{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[0]), \"r\"(Bxi[1]));\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // AMPERE_MMA_AVAILABLE\n    }\n\n    template <data_layout dl_ab, data_layout dl_d>\n    static __device__ __forceinline__ void mma(\n            tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {\n#ifdef AMD_MFMA_AVAILABLE\n        using floatx4_t = __attribute__((ext_vector_type(4))) float;\n        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);\n#if defined(CDNA3)\n        using floatx2_t = __attribute__((ext_vector_type(2))) float;\n        const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);\n        const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);\n        acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);\n#elif defined(CDNA2) || defined(CDNA1)\n#pragma unroll\n        for (int i = 0; i < 2; ++i) {\n            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);\n        }\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // defined(CDNA3)\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // AMD_MFMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> &     D,\n                                                            const tile<16, 8, int> & A,\n                                                            const tile<8, 8, int> &  B,\n                                                            uint32_t                 a_scale,\n                                                            uint32_t                 b_scale) {\n#ifdef BLACKWELL_MMA_AVAILABLE\n        const int * Axi = (const 
int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        float *     Dxi = (float *) D.x;\n\n        asm volatile(\n            \"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 \"\n            \"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, \"\n            \"%10, {0, 0}, %11, {0, 0};\"\n            : \"+f\"(Dxi[0]), \"+f\"(Dxi[1]), \"+f\"(Dxi[2]), \"+f\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[0]), \"r\"(Bxi[1]), \"r\"(a_scale), \"r\"(b_scale));\n#else\n        GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);\n#endif  // BLACKWELL_MMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {\n#ifdef TURING_MMA_AVAILABLE\n        const int * Axi = (const int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        int       * Dxi = (int       *) D.x;\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n        asm(\"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[0]), \"r\"(Bxi[1]));\n#else\n        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Bxi[0]));\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[1]));\n#endif // __CUDA_ARCH__ >= 
GGML_CUDA_CC_AMPERE\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {\n#ifdef AMPERE_MMA_AVAILABLE\n        const int * Axi = (const int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        int       * Dxi = (int       *) D.x;\n        asm(\"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[0]), \"r\"(Bxi[1]));\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // AMPERE_MMA_AVAILABLE\n    }\n\n    template <data_layout dl_ab, data_layout dl_d>\n    static __device__ __forceinline__ void mma(\n            tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {\n#ifdef TURING_MMA_AVAILABLE\n        const int * Axi = (const int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        int       * Dxi = (int       *) D.x;\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n        asm(\"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[0]), \"r\"(Bxi[2]));\n        asm(\"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[4]), \"+r\"(Dxi[5]), \"+r\"(Dxi[6]), \"+r\"(Dxi[7])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[1]), \"r\"(Bxi[3]));\n#else\n        // On Turing m16n8k16 mma is not 
available, use 4x m8n8k8 mma instead:\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Bxi[0]));\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[2]));\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[4]), \"+r\"(Dxi[5]), \"+r\"(Dxi[6]), \"+r\"(Dxi[7])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Bxi[1]));\n        asm(\"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[4]), \"+r\"(Dxi[5]), \"+r\"(Dxi[6]), \"+r\"(Dxi[7])\n            : \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[3]));\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE\n#elif defined(AMD_WMMA_AVAILABLE)\n#if defined(RDNA4)\n        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;\n        using floatx8_t = __attribute__((ext_vector_type(8))) float;\n        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);\n        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);\n        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);\n        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);\n#elif defined(RDNA3)\n        using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;\n        using floatx8_t = __attribute__((ext_vector_type(8))) float;\n        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);\n        const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);\n        const halfx16_t& b_frag = 
reinterpret_cast<const halfx16_t&>(B.x[0]);\n        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // RDNA4\n#elif defined(AMD_MFMA_AVAILABLE)\n        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;\n        using floatx4_t = __attribute__((ext_vector_type(4))) float;\n        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);\n        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);\n        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);\n        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // TURING_MMA_AVAILABLE\n    }\n\n    template <data_layout dl_ab, data_layout dl_d>\n    static __device__ __forceinline__ void mma(\n            tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {\n#if defined(AMD_WMMA_AVAILABLE)\n#if defined(RDNA4)\n        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;\n        using floatx8_t = __attribute__((ext_vector_type(8))) float;\n        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);\n        const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);\n        const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);\n        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);\n#elif defined(RDNA3)\n        using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;\n        using floatx8_t = __attribute__((ext_vector_type(8))) float;\n        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);\n        const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);\n        const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);\n        acc_frag = 
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // defined(RDNA4)\n#elif defined(AMD_MFMA_AVAILABLE)\n        using floatx4_t = __attribute__((ext_vector_type(4))) float;\n        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);\n#if defined(CDNA3) || defined(CDNA2)\n        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;\n        const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);\n        const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);\n        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);\n#elif defined(CDNA1)\n#pragma unroll\n        for (int i = 0; i < 2; ++i) {\n            using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;\n            const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);\n            const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);\n            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);\n        }\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // defined(CDNA3) || defined(CDNA2)\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // defined(AMD_WMMA_AVAILABLE)\n    }\n\n    template <data_layout dl_d, data_layout dl_ab>\n    static __device__ __forceinline__ void mma(\n            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {\n#if defined(AMD_MFMA_AVAILABLE)\n        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;\n        int32x4_t * acc = (int32x4_t *) D.x;\n#if defined(CDNA3)\n        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],\n                                                       ((int64_t *) B.x)[0],\n                                                       acc[0],\n                                  
                     0, 0, 0);\n#elif defined(CDNA2) || defined(CDNA)\n        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],\n                                                      B.x[0],\n                                                      acc[0],\n                                                      0, 0, 0);\n        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],\n                                                      B.x[1],\n                                                      acc[0],\n                                                      0, 0, 0);\n#endif // defined(CDNA3)\n\n#elif defined(AMD_WMMA_AVAILABLE)\n\n        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;\n        int32x8_t * acc = (int32x8_t *) D.x;\n\n#if defined(RDNA4)\n        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;\n        int32x2_t * a_vec = (int32x2_t *) A.x;\n        int32x2_t * b_vec = (int32x2_t *) B.x;\n\n        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(\n            true,\n            a_vec[0],\n            true,\n            b_vec[0],\n            acc[0],\n            true\n        );\n\n        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(\n            true,\n            a_vec[1],\n            true,\n            b_vec[1],\n            acc[0],\n            true\n        );\n\n#elif defined(RDNA3)\n        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;\n        int32x4_t * a_vec = (int32x4_t *) A.x;\n        int32x4_t * b_vec = (int32x4_t *) B.x;\n\n        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(\n            true,\n            a_vec[0],\n            true,\n            b_vec[0],\n            acc[0],\n            true\n        );\n\n        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(\n            true,\n            a_vec[1],\n            true,\n            b_vec[1],\n            acc[0],\n            true\n        );\n#endif // 
RDNA4\n\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // AMD_MFMA_AVAILABLE\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {\n#if defined(AMD_MFMA_AVAILABLE)\n        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;\n        int32x16_t * acc = (int32x16_t *) D.x;\n#if defined(CDNA3)\n        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],\n                                                       ((int64_t *) B.x)[0],\n                                                       acc[0],\n                                                       0, 0, 0);\n#elif defined(CDNA2) || defined(CDNA)\n        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],\n                                                     B.x[0],\n                                                     acc[0],\n                                                     0, 0, 0);\n        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],\n                                                     B.x[1],\n                                                     acc[0],\n                                                     0, 0, 0);\n#endif // defined(CDNA3)\n\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // AMD_MFMA_AVAILABLE\n    }\n\n    template <typename T1, typename T2, int J, int K>\n    static __device__ __forceinline__ void mma(\n            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {\n        tile      <16, J, T1> * D16 = reinterpret_cast<      tile<16, J, T1> *>(&D);\n        const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);\n        mma(D16[0], A16[0], B);\n        mma(D16[1], A16[1], B);\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, 
DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {\n#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n        const int * Axi = (const int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        int       * Dxi = (int       *) D.x;\n        asm(\"mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 \"\n            \"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3]), \"+r\"(Dxi[4]), \"+r\"(Dxi[5]), \"+r\"(Dxi[6]), \"+r\"(Dxi[7])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Bxi[0]), \"r\"(Bxi[1]));\n        asm(\"mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 \"\n            \"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3]), \"+r\"(Dxi[4]), \"+r\"(Dxi[5]), \"+r\"(Dxi[6]), \"+r\"(Dxi[7])\n            : \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[2]), \"r\"(Bxi[3]));\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n    }\n\n    static __device__ __forceinline__ void mma(\n            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {\n#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA\n        const int * Axi = (const int *) A.x;\n        const int * Bxi = (const int *) B.x;\n        int       * Dxi = (int       *) D.x;\n        asm(\"mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \"\n            \"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n            : \"r\"(Axi[0]), \"r\"(Axi[1]), \"r\"(Bxi[0]), \"r\"(Bxi[1]));\n        asm(\"mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \"\n            \"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};\"\n            : \"+r\"(Dxi[0]), \"+r\"(Dxi[1]), \"+r\"(Dxi[2]), \"+r\"(Dxi[3])\n           
 : \"r\"(Axi[2]), \"r\"(Axi[3]), \"r\"(Bxi[2]), \"r\"(Bxi[3]));\n#else\n        GGML_UNUSED_VARS(D, A, B);\n        NO_DEVICE_CODE;\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n    }\n\n    template <data_layout dl_d, data_layout dl_ab>\n    static __device__ __forceinline__ void mma(\n            tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {\n#if defined(AMD_WMMA_AVAILABLE)\n        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;\n        int32x8_t * acc = (int32x8_t *) D.x;\n#if defined(RDNA4)\n        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;\n        int32x2_t * a_vec = (int32x2_t *) A.x;\n        int32x2_t * b_vec = (int32x2_t *) B.x;\n\n        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(\n            true,\n            a_vec[0],\n            true,\n            b_vec[0],\n            acc[0],\n            false\n        );\n#elif defined(RDNA3)\n        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;\n        int32x4_t * a_vec = (int32x4_t *) A.x;\n        int32x4_t * b_vec = (int32x4_t *) B.x;\n\n        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(\n            true,\n            a_vec[0],\n            true,\n            b_vec[0],\n            acc[0],\n            false\n        );\n#endif // RDNA4\n#else\n        GGML_UNUSED(D);\n        GGML_UNUSED(A);\n        GGML_UNUSED(B);\n        NO_DEVICE_CODE;\n#endif // AMD_WMMA_AVAILABLE\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/mmf.cu",
    "content": "#include \"ggml.h\"\n#include \"mmf.cuh\"\n#include \"mmid.cuh\"\n\nstatic __forceinline__ int mmf_get_rows_per_block(const int cc) {\n    if (GGML_CUDA_CC_IS_CDNA(cc)) {\n        return MMF_ROWS_PER_BLOCK_CDNA;\n    } else {\n        return MMF_ROWS_PER_BLOCK;\n    }\n}\n\nvoid ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {\n    GGML_ASSERT(        src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);\n    GGML_ASSERT(         dst->type == GGML_TYPE_F32);\n\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const size_t ts_src0 = ggml_type_size(src0->type);\n    const size_t ts_src1 = ggml_type_size(src1->type);\n    const size_t ts_dst  = ggml_type_size(dst->type);\n\n    GGML_ASSERT(ne13 == ne3);\n\n    GGML_ASSERT(        nb00       == ts_src0);\n    GGML_ASSERT(        nb10       == ts_src1);\n    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));\n    GGML_ASSERT(        nb0        == ts_dst);\n\n    const float   * src1_d =       (const float   *) src1->data;\n    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;\n    float         *  dst_d =       (float         *)  dst->data;\n\n    const int64_t s01 = src0->nb[1] / ts_src0;\n    const int64_t s11 = src1->nb[1] / ts_src1;\n    const int64_t s1  =  dst->nb[1] / ts_dst;\n    const int64_t s02 = src0->nb[2] / ts_src0;\n    const int64_t s12 = src1->nb[2] / ts_src1;\n    const int64_t s2  =  dst->nb[2] / ts_dst;\n    const int64_t s03 = src0->nb[3] / ts_src0;\n    const int64_t s13 = src1->nb[3] / ts_src1;\n    const int64_t s3  =  dst->nb[3] / ts_dst;\n\n    const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;\n    const int64_t ids_s1 = ids ? 
ids->nb[1] / ggml_type_size(ids->type) : 0;\n\n    mmf_ids_data ids_info{};\n    mmf_ids_data * ids_info_ptr = nullptr;\n    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;\n    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;\n    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;\n\n    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:\n    const int64_t ncols_dst          = ids ? ne2  : ne1;\n    const int64_t nchannels_dst      = ids ? ne1 : ne2;\n\n    const int64_t stride_col_dst     = ids ? s2   : s1;\n    const int64_t stride_col_y       = ids ? s12  : s11;\n    const int64_t stride_channel_dst = ids ? s1 : s2;\n\n    int64_t stride_channel_y         = ids ? s11  : s12;\n    int64_t nchannels_y              = ids ? ne11 : ne12;\n\n    //mul_mat_id: handle broadcast\n    if (ids && nchannels_y == 1) {\n        stride_channel_y = 0;\n        nchannels_y      = ids->ne[0];\n    }\n\n    if (ids && ncols_dst > 16) {\n        const int64_t n_expert_used = ids->ne[0];\n        const int64_t n_experts     = ne02;\n        const int64_t n_tokens      = ne12;\n        const int64_t ne_get_rows   = n_tokens * n_expert_used;\n\n        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);\n        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);\n        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);\n\n        const int si1  = static_cast<int>(ids_s1);\n        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);\n\n        GGML_ASSERT(sis1 > 0);\n\n        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),\n            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());\n        CUDA_CHECK(cudaGetLastError());\n\n        ids_info.ids_src_compact   = ids_src_compact_dev.get();\n        ids_info.ids_dst_compact   = ids_dst_compact_dev.get();\n        ids_info.expert_bounds_dev = 
expert_bounds_dev.get();\n        ids_info.n_experts         = static_cast<int>(n_experts);\n        ids_info.sis1              = sis1;\n        ids_info_ptr = &ids_info;\n    }\n\n    const int device    = ggml_cuda_get_device();\n    const int cc        = ggml_cuda_info().devices[device].cc;\n    const int rows_per_block = mmf_get_rows_per_block(cc);\n\n    switch (src0->type) {\n        case GGML_TYPE_F32: {\n            const float * src0_d = (const float *) src0->data;\n            constexpr int vals_per_T = 1;\n            mul_mat_f_switch_rows_per_block<float>(\n                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,\n                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,\n                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);\n        } break;\n        case GGML_TYPE_F16: {\n            const half2 * src0_d = (const half2 *) src0->data;\n            constexpr int vals_per_T = 2;\n            mul_mat_f_switch_rows_per_block<half2>(\n                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,\n                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,\n                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);\n        } break;\n        case GGML_TYPE_BF16: {\n            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;\n            constexpr int vals_per_T = 2;\n            mul_mat_f_switch_rows_per_block<nv_bfloat162>(\n                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,\n                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, 
stride_channel_dst,\n                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);\n        } break;\n        default:\n            GGML_ABORT(\"unsupported type: %s\", ggml_type_name(src0->type));\n    }\n}\n\nbool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,\n        const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {\n    if (ggml_is_quantized(type)) {\n        return false;\n    }\n\n    const size_t ts = ggml_type_size(type);\n    if (src0_ne[0] % (warp_size * (4/ts)) != 0) {\n        return false;\n    }\n\n    if (src0_nb[0] != ts) {\n        return false;\n    }\n\n    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:\n    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {\n        if (src0_nb[i] % (2*ts) != 0) {\n            return false;\n        }\n    }\n    if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {\n        return false;\n    }\n\n    if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {\n        return false;\n    }\n\n    if (mul_mat_id) {\n        if (src0_ne[1] <= 1024 && src1_ncols > 512) {\n            return false;\n        } else if(src0_ne[1] > 1024 && src1_ncols > 128) {\n            return false;\n        }\n    } else {\n        if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {\n            return false;\n        } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {\n            //TODO: truse CDNA2 as CDNA1, tune the perf when CDNA2 is available.\n            return false;\n        } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {\n            return false;\n        } else if (src1_ncols > 16) {\n            return false;\n        }\n    }\n\n    switch (type) {\n        case GGML_TYPE_F32:\n            return ampere_mma_available(cc) || amd_mfma_available(cc);\n        case GGML_TYPE_F16:\n            return 
volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);\n        case GGML_TYPE_BF16:\n            return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);\n        default:\n            return false;\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/mmf.cuh",
    "content": "#pragma once\n\n#include \"mma.cuh\"\n#include \"common.cuh\"\n#include \"convert.cuh\"\n\nusing namespace ggml_cuda_mma;\n\n#define MMF_ROWS_PER_BLOCK 32\n#define MMF_ROWS_PER_BLOCK_CDNA 64\n\nstatic __forceinline__ int64_t mmf_get_max_block_size(int cc) {\n    if (GGML_CUDA_CC_IS_CDNA(cc)) {\n        return 512;\n    } else {\n        return 256;\n    }\n}\n\nstatic __forceinline__ int mmf_get_padding(int cc) {\n    if (GGML_CUDA_CC_IS_CDNA(cc)) {\n        return 2;\n    } else {\n        return 4;\n    }\n}\n\nstatic constexpr __device__ int mmf_get_padding() {\n#if defined(AMD_MFMA_AVAILABLE)\n    return 2;\n#else\n    return 4;\n#endif // defined(AMD_MFMA_AVAILABLE)\n}\n\nstruct mmf_ids_data {\n    const int32_t * ids_src_compact = nullptr;\n    const int32_t * ids_dst_compact = nullptr;\n    const int32_t * expert_bounds_dev = nullptr;\n    int n_experts = 0;\n    int sis1 = 0;\n};\n\nvoid ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);\n\nbool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);\n\ntemplate <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>\n__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)\nstatic __global__ void mul_mat_f(\n        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,\n        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,\n        const int stride_col_id, const int stride_row_id,\n        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,\n        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int 
stride_sample_dst) {\n// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added\n#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n#if defined(AMD_WMMA_AVAILABLE)\n    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {\n    typedef tile<16, 8,  T,     get_input_data_layout()> tile_A;\n    typedef tile<16, 8,  T,     get_input_data_layout()> tile_B;\n    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR>     tile_C;\n#elif defined(AMD_MFMA_AVAILABLE)\n    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {\n    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_A;\n    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_B;\n    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;\n#else\n#ifdef VOLTA_MMA_AVAILABLE\n    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {\n    typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;\n    typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;\n    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;\n#else\n    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {\n    typedef tile<16, 8, T>     tile_A;\n    typedef tile<8,  8, T>     tile_B;\n    typedef tile<16, 8, float> tile_C;\n#endif // VOLTA_MMA_AVAILABLE\n#endif // defined(AMD_WMMA_AVAILABLE)\n    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {\n        NO_DEVICE_CODE;\n        return;\n    }\n\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n    constexpr int tile_k_padded = warp_size + mmf_get_padding();\n    constexpr int ntA = rows_per_block / tile_A::I;\n    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;\n\n    const int row0        = 
blockIdx.x * rows_per_block;\n\n    int expert_idx = 0;\n    int col_base = 0;\n\n    const int channel_dst = has_ids ? 0 : blockIdx.y;\n\n    if constexpr (has_ids) {\n        // experts + tiles of ncols_dst are packed in the y dimension\n        int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;\n        const int nchannels_x = gridDim.y / col_tiles;\n        const int tile_idx = blockIdx.y / nchannels_x;\n        expert_idx = blockIdx.y - tile_idx * nchannels_x;\n        col_base = tile_idx * cols_per_block;\n    }\n\n    const int channel_x   = has_ids ? expert_idx : (channel_dst / channel_ratio);\n    const int channel_y   = channel_dst;\n    const int sample_dst  = blockIdx.z;\n    const int sample_x    = sample_dst / sample_ratio;\n    const int sample_y    = sample_dst;\n\n    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x  + row0*stride_row ;\n    y   += int64_t(sample_y)  *stride_sample_y   + (has_ids ? 0 : channel_y  *stride_channel_y);\n    dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);\n\n    if constexpr (has_ids) {\n        constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;\n        const int64_t col_offset = col_base;\n        y   += col_offset * stride_col_y * y_stride_scale;\n        dst += col_offset * stride_col_dst;\n        ids += col_offset * stride_row_id;\n    }\n\n    const float2 * y2 = (const float2 *) y;\n\n    extern __shared__ char data_mmv[];\n\n    char * shmem_base = data_mmv;\n    int  * slot_map   = (int *) shmem_base;\n    char * compute_base = has_ids ? 
(shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;\n\n    tile_C C[ntA][ntB];\n\n    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);\n\n    if constexpr (has_ids) {\n        int found = 0;\n\n        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n            if (threadIdx.x == 0) {\n                slot_map[j] = -1;\n            }\n\n            if (col_base + j >= ncols_dst_total) {\n                continue;\n            }\n\n            const int32_t * __restrict__ id_row = ids + j*stride_row_id;\n\n            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {\n                int match = id_row[k*stride_col_id] == expert_idx;\n\n                if (match) {\n                    slot_map[j] = k;\n                    found = 1;\n                    break;\n                }\n            }\n        }\n\n        if (!__syncthreads_or(found)) {\n            return;\n        }\n    }\n\n\n    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {\n        tile_A A[ntA][warp_size / tile_A::J];\n#pragma unroll\n        for (int itA = 0; itA < ntA; ++itA) {\n#pragma unroll\n            for (int i = 0; i < tile_A::I; ++i) {\n                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];\n            }\n#pragma unroll\n            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {\n                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);\n            }\n        }\n\n#pragma unroll\n        for (int itB = 0; itB < ntB; ++itB) {\n            if constexpr (std::is_same_v<T, float>) {\n#pragma unroll\n                for (int j0 = 0; j0 < tile_B::I; ++j0) {\n                    const int j = j0 + itB*tile_B::I;\n\n                    if constexpr (!has_ids) {\n                        tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? 
y[j*stride_col_y + col] : 0.0f;\n                    } else {\n                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;\n                        tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;\n                    }\n                }\n            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {\n#pragma unroll\n                for (int j0 = 0; j0 < tile_B::I; ++j0) {\n                    const int j = j0 + itB*tile_B::I;\n\n                    if constexpr (!has_ids) {\n                        const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);\n                        tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);\n                    } else {\n                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;\n                        float2 tmp = valid ? 
*(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);\n                        tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);\n                    }\n                }\n            } else {\n                static_assert(std::is_same_v<T, void>, \"unsupported type\");\n            }\n#pragma unroll\n            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {\n                tile_B B;\n                load_ldmatrix(B, tile_xy + k0, tile_k_padded);\n#pragma unroll\n                for (int itA = 0; itA < ntA; ++itA) {\n                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);\n                }\n            }\n        }\n    }\n\n    float * buf_iw = (float *) compute_base;\n    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();\n\n    if (nwarps > 1) {\n        __syncthreads();\n    }\n#pragma unroll\n    for (int itB = 0; itB < ntB; ++itB) {\n#pragma unroll\n        for (int itA = 0; itA < ntA; ++itA) {\n#pragma unroll\n            for (int l = 0; l < tile_C::ne; ++l) {\n                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);\n                const int j = itB*tile_C::J + tile_C::get_j(l);\n                buf_iw[j*kiw + i] = C[itA][itB].x[l];\n            }\n        }\n    }\n\n    if (nwarps > 1) {\n        __syncthreads();\n    }\n\n#pragma unroll\n    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {\n        const int j = j0 + threadIdx.y;\n\n        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {\n            return;\n        }\n\n        float sum[rows_per_block/warp_size] = {0.0f};\n        static_assert((rows_per_block % warp_size) == 0, \"rows_per_block must be a multiple of warp_size.\");\n#pragma unroll\n        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {\n#pragma unroll\n            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {\n                const int i = i0 + i1*warp_size 
+ threadIdx.x;\n\n                sum[i1] += buf_iw[j*kiw + i];\n            }\n        }\n\n        if constexpr (!has_ids) {\n#pragma unroll\n            for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {\n                dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];\n            }\n        } else {\n            const int slot = (j < cols_per_block) ? slot_map[j] : -1;\n            if (slot >= 0 && (col_base + j) < ncols_dst_total) {\n#pragma unroll\n                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {\n                    dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];\n                }\n            }\n        }\n    }\n    }\n#else\n    GGML_UNUSED_VARS(x, y, ids, dst,\n        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n        stride_col_id, stride_row_id,\n        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);\n    NO_DEVICE_CODE;\n#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n}\n\n//This kernel is for larger batch sizes of mul_mat_id\ntemplate <typename T, int rows_per_block, int cols_per_block, int nwarps>\n__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)\nstatic __global__ void mul_mat_f_ids(\n        const T * __restrict__ x, const float * __restrict__ y,\n        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,\n        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,\n        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,\n        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,\n        const int 
sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,\n        const uint3 sis1_fd, const uint3 nch_fd) {\n// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added\n#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n#if defined(AMD_WMMA_AVAILABLE)\n    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {\n    typedef tile<16, 8,  T,     get_input_data_layout()> tile_A;\n    typedef tile<16, 8,  T,     get_input_data_layout()> tile_B;\n    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR>     tile_C;\n#elif defined(AMD_MFMA_AVAILABLE)\n    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {\n    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_A;\n    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_B;\n    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;\n#else\n#ifdef VOLTA_MMA_AVAILABLE\n    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {\n    typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;\n    typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;\n    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;\n#else\n    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {\n    typedef tile<16, 8, T>     tile_A;\n    typedef tile<8,  8, T>     tile_B;\n    typedef tile<16, 8, float> tile_C;\n#endif // VOLTA_MMA_AVAILABLE\n#endif // defined(AMD_WMMA_AVAILABLE)\n    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {\n        NO_DEVICE_CODE;\n        return;\n    }\n\n\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n    constexpr int tile_k_padded = warp_size + mmf_get_padding();\n    constexpr int ntA = 
rows_per_block / tile_A::I;\n    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;\n\n    const int row0        = blockIdx.x * rows_per_block;\n\n    const int expert_idx = blockIdx.y;\n    const int expert_start = expert_bounds[expert_idx];\n    const int expert_end   = expert_bounds[expert_idx + 1];\n    const int ncols_expert = expert_end - expert_start;\n\n    const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;\n    const int tile_idx = blockIdx.z;\n    if (tile_idx >= tiles_for_expert) {\n        return;\n    }\n\n    const int col_base = tile_idx * cols_per_block;\n\n    GGML_UNUSED(channel_ratio);\n\n    const int channel_x   = expert_idx;\n    const int sample_dst  = 0;\n    const int sample_x    = sample_dst / sample_ratio;\n    const int sample_y    = sample_dst;\n\n    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x  + row0*stride_row;\n    y   += int64_t(sample_y)  *stride_sample_y;\n    dst += int64_t(sample_dst)*stride_sample_dst;\n\n    const int32_t * ids_src_expert = ids_src_compact + expert_start;\n    const int32_t * ids_dst_expert = ids_dst_compact + expert_start;\n\n    extern __shared__ char data_mmv[];\n    char * compute_base = data_mmv;\n\n    //const float2 * y2 = (const float2 *) y;\n\n    tile_C C[ntA][ntB];\n\n    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);\n\n    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {\n        tile_A A[ntA][warp_size / tile_A::J];\n#pragma unroll\n        for (int itA = 0; itA < ntA; ++itA) {\n#pragma unroll\n            for (int i = 0; i < tile_A::I; ++i) {\n                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];\n            }\n#pragma unroll\n            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {\n                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);\n            }\n        }\n\n  
      if constexpr (std::is_same_v<T, float>) {\n            float vals_buf[2][tile_B::I];\n            auto gather_tile = [&](int tile_idx_local, float *vals) {\n#pragma unroll\n                for (int j0 = 0; j0 < tile_B::I; ++j0) {\n                    const int j = j0 + tile_idx_local*tile_B::I;\n                    const int global_j = col_base + j;\n                    float val = 0.0f;\n                    if (j < cols_per_block && global_j < ncols_expert) {\n                        const int src_entry = ids_src_expert[global_j];\n                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);\n                        const int token   = (int) qrm.x;\n                        const int channel = (int) qrm.y;\n                        if (token < ncols_dst_total) {\n                            val = y[channel*stride_channel_y + token*stride_col_y + col];\n                        }\n                    }\n                    vals[j0] = val;\n                }\n            };\n\n            gather_tile(0, vals_buf[0]);\n\n            int curr_buf = 0;\n            int next_buf = 1;\n#pragma unroll\n            for (int itB = 0; itB < ntB; ++itB) {\n#pragma unroll\n                for (int j0 = 0; j0 < tile_B::I; ++j0) {\n                    tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];\n                }\n\n                if (itB + 1 < ntB) {\n                    gather_tile(itB + 1, vals_buf[next_buf]);\n                }\n\n#pragma unroll\n                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {\n                    tile_B B;\n                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);\n#pragma unroll\n                    for (int itA = 0; itA < ntA; ++itA) {\n                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);\n                    }\n                }\n\n                if (itB + 1 < ntB) {\n                    curr_buf ^= 1;\n                    next_buf ^= 1;\n                
}\n            }\n        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {\n            float2 vals_buf[2][tile_B::I];\n            auto gather_tile = [&](int tile_idx_local, float2 *vals) {\n#pragma unroll\n                for (int j0 = 0; j0 < tile_B::I; ++j0) {\n                    const int j = j0 + tile_idx_local*tile_B::I;\n                    const int global_j = col_base + j;\n                    float2 tmp = make_float2(0.0f, 0.0f);\n                    if (j < cols_per_block && global_j < ncols_expert) {\n                        const int src_entry = ids_src_expert[global_j];\n                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);\n                        const int token   = (int) qrm.x;\n                        const int channel = (int) qrm.y;\n                        if (token < ncols_dst_total) {\n                            tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];\n                        }\n                    }\n                    vals[j0] = tmp;\n                }\n            };\n\n            if (ntB > 0) {\n                gather_tile(0, vals_buf[0]);\n            }\n\n            int curr_buf = 0;\n            int next_buf = 1;\n#pragma unroll\n            for (int itB = 0; itB < ntB; ++itB) {\n#pragma unroll\n                for (int j0 = 0; j0 < tile_B::I; ++j0) {\n                    const float2 tmp = vals_buf[curr_buf][j0];\n                    tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);\n                }\n\n                if (itB + 1 < ntB) {\n                    gather_tile(itB + 1, vals_buf[next_buf]);\n                }\n\n#pragma unroll\n                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {\n                    tile_B B;\n                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);\n#pragma unroll\n                    for (int itA = 0; itA < ntA; ++itA) {\n                        
mma(C[itA][itB], A[itA][k0/tile_B::J], B);\n                    }\n                }\n\n                if (itB + 1 < ntB) {\n                    curr_buf ^= 1;\n                    next_buf ^= 1;\n                }\n            }\n        } else {\n            static_assert(std::is_same_v<T, void>, \"unsupported type\");\n        }\n    }\n\n    float * buf_iw = (float *) compute_base;\n    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();\n\n    if (nwarps > 1) {\n        __syncthreads();\n    }\n#pragma unroll\n    for (int itB = 0; itB < ntB; ++itB) {\n#pragma unroll\n        for (int itA = 0; itA < ntA; ++itA) {\n#pragma unroll\n            for (int l = 0; l < tile_C::ne; ++l) {\n                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);\n                const int j = itB*tile_C::J + tile_C::get_j(l);\n                buf_iw[j*kiw + i] = C[itA][itB].x[l];\n            }\n        }\n    }\n\n    if (nwarps > 1) {\n        __syncthreads();\n    }\n\n#pragma unroll\n    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {\n        const int j = j0 + threadIdx.y;\n\n        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {\n            return;\n        }\n\n        float sum[rows_per_block/warp_size] = {0.0f};\n        static_assert((rows_per_block % warp_size) == 0, \"rows_per_block must be a multiple of warp_size.\");\n#pragma unroll\n        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {\n#pragma unroll\n            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {\n                const int i = i0 + i1*warp_size + threadIdx.x;\n\n                sum[i1] += buf_iw[j * kiw + i];\n            }\n        }\n\n        const int global_j = col_base + j;\n        if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {\n            const int dst_entry = ids_dst_expert[global_j];\n            const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);\n         
   const int token = (int) qrm.x;\n            if (token < ncols_dst_total) {\n                const int slot = (int) qrm.y;\n#pragma unroll\n                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {\n                    dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];\n                }\n            }\n        }\n    }\n    }\n#else\n    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,\n        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);\n    NO_DEVICE_CODE;\n#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)\n}\n\ntemplate<typename T, int rows_per_block, int cols_per_block, int nwarps>\nstatic inline void mul_mat_f_switch_ids(\n        const T * x, const float * y, const int32_t * ids, float * dst,\n        const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,\n        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,\n        const int64_t stride_col_id, const int64_t stride_row_id,\n        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,\n        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,\n        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,\n        const mmf_ids_data * ids_data) {\n    const bool has_ids_data = ids_data && ids_data->ids_src_compact;\n\n    // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)\n    // we prefer the normal mul_mat_f path with has_ids=true.\n    if 
(has_ids_data && ncols_dst > 16) {\n        const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);\n        if (max_tiles == 0) {\n            return;\n        }\n        dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);\n\n        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);\n        const uint3 nch_fd  = init_fastdiv_values((uint32_t) nchannels_dst);\n\n        mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>\n            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,\n            ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n            channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,\n            sis1_fd, nch_fd);\n    } else if (ids) {\n        const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;\n        dim3 block_nums_ids = block_nums;\n        block_nums_ids.y *= col_tiles;\n\n        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>\n            (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);\n    } else {\n        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>\n            (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n             
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);\n    }\n}\n\ntemplate <typename T, int rows_per_block, int cols_per_block>\nvoid mul_mat_f_cuda(\n        const T * x, const float * y, const int32_t * ids, float * dst,\n        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,\n        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,\n        const int64_t stride_col_id, const int64_t stride_row_id,\n        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,\n        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\n        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,\n        cudaStream_t stream, const mmf_ids_data * ids_data) {\n    typedef tile<16, 8, T>     tile_A_16;\n    typedef tile<32, 8, T>     tile_A_32;\n    typedef tile<16, 8, T>     tile_B_16;\n    typedef tile< 8, 8, T>     tile_B_8;\n\n    GGML_ASSERT(ncols_x      % 2 == 0);\n    GGML_ASSERT(stride_row   % 2 == 0);\n    GGML_ASSERT(stride_col_y % 2 == 0);\n    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);\n    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);\n    const int64_t channel_ratio = nchannels_dst / nchannels_x;\n    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;\n\n    const int device    = ggml_cuda_get_device();\n    const int cc        = ggml_cuda_info().devices[device].cc;\n    const int warp_size = ggml_cuda_info().devices[device].warp_size;\n\n    int64_t nwarps_best     = 1;\n    int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);\n    int64_t max_block_size  = mmf_get_max_block_size(cc);\n    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {\n        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);\n        if (niter < niter_best) 
{\n            niter_best  = niter;\n            nwarps_best = nwarps;\n        }\n    }\n\n    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;\n    const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;\n    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;\n    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);\n    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;\n    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;\n    const int64_t grid_y = ids ? nchannels_x : nchannels_dst;\n\n    const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);\n    const dim3 block_dims(warp_size, nwarps_best, 1);\n\n    switch (nwarps_best) {\n        case 1: {\n            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(\n                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,\n                ids_data);\n        } break;\n        case 2: {\n            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(\n                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,\n                ids_data);\n        } break;\n        case 3: 
{\n            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(\n                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,\n                ids_data);\n        } break;\n        case 4: {\n            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(\n                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,\n                ids_data);\n        } break;\n        case 5: {\n            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(\n                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,\n                ids_data);\n        } break;\n        case 6: {\n            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(\n                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,\n                ids_data);\n        } break;\n    
    case 7: {\n            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(\n                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,\n                ids_data);\n        } break;\n        case 8: {\n            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(\n                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,\n                ids_data);\n        } break;\n        default: {\n            GGML_ABORT(\"fatal error\");\n        } break;\n    }\n\n    GGML_UNUSED_VARS(nchannels_y);\n}\n\ntemplate <typename T, int rows_per_block>\nstatic void mul_mat_f_switch_cols_per_block(\n        const T * x, const float * y, const int32_t * ids, float * dst,\n        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,\n        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,\n        const int64_t stride_col_id, const int stride_row_id,\n        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,\n        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\n        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,\n        cudaStream_t stream, const mmf_ids_data * ids_data) {\n\n    const int ncols_case = (ids && ncols_dst > 
16) ? 16 : ncols_dst;\n\n    GGML_ASSERT(ids || ncols_dst <= 16);\n\n    switch (ncols_case) {\n        case  1: {\n            mul_mat_f_cuda<T,  rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case  2: {\n            mul_mat_f_cuda<T,  rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case  3: {\n            mul_mat_f_cuda<T,  rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case  4: {\n            mul_mat_f_cuda<T,  rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case  5: {\n            mul_mat_f_cuda<T,  rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n    
            stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y,  stride_sample_dst, stream, ids_data);\n        } break;\n        case  6: {\n            mul_mat_f_cuda<T,  rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case  7: {\n            mul_mat_f_cuda<T,  rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case  8: {\n            mul_mat_f_cuda<T,  rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case  9: {\n            mul_mat_f_cuda<T,  rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, 
ids_data);\n        } break;\n        case 10: {\n            mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case 11: {\n            mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case 12: {\n            mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case 13: {\n            mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case 14: {\n            mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, 
nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case 15: {\n            mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case 16: {\n            mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        default: {\n            GGML_ABORT(\"fatal error\");\n        } break;\n    }\n}\n\ntemplate <typename T>\nstatic void mul_mat_f_switch_rows_per_block(\n        const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,\n        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,\n        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,\n        const int64_t stride_col_id, const int stride_row_id,\n        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,\n        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\n        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,\n        cudaStream_t stream, const mmf_ids_data * 
ids_data) {\n    switch (rows_per_block) {\n        case MMF_ROWS_PER_BLOCK: {\n            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(\n                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        case MMF_ROWS_PER_BLOCK_CDNA: {\n            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(\n                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);\n        } break;\n        default:\n            GGML_ABORT(\"unsupported rows_per_block: %i\", rows_per_block);\n    }\n}\n\n#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \\\n    template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \\\n        const T * x, const float * y, const int32_t * ids, float * dst, \\\n        const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \\\n        const int64_t stride_col_id, const int64_t stride_row_id, \\\n        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \\\n        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\\\n        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \\\n        cudaStream_t stream, const mmf_ids_data * ids_data);\n\n#if 
!defined(GGML_USE_MUSA)\n#define DECL_MMF_CASE_EXTERN(ncols_dst) \\\n    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \\\n    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \\\n    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \\\n    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \\\n    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \\\n    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)\n\n#define DECL_MMF_CASE(ncols_dst) \\\n    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \\\n    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \\\n    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \\\n    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \\\n    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \\\n    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)\n\nDECL_MMF_CASE_EXTERN(1);\nDECL_MMF_CASE_EXTERN(2);\nDECL_MMF_CASE_EXTERN(3);\nDECL_MMF_CASE_EXTERN(4);\nDECL_MMF_CASE_EXTERN(5);\nDECL_MMF_CASE_EXTERN(6);\nDECL_MMF_CASE_EXTERN(7);\nDECL_MMF_CASE_EXTERN(8);\nDECL_MMF_CASE_EXTERN(9);\nDECL_MMF_CASE_EXTERN(10);\nDECL_MMF_CASE_EXTERN(11);\nDECL_MMF_CASE_EXTERN(12);\nDECL_MMF_CASE_EXTERN(13);\nDECL_MMF_CASE_EXTERN(14);\nDECL_MMF_CASE_EXTERN(15);\nDECL_MMF_CASE_EXTERN(16);\n#else\n#define DECL_MMF_CASE(ncols_dst)\n#endif\n"
  },
  {
    "path": "src/ggml-cuda/mmid.cu",
    "content": "#include \"common.cuh\"\n#include \"mmid.cuh\"\n\n// To reduce shared memory use, store \"it\" and \"iex_used\" with 22/10 bits each.\nstruct mm_ids_helper_store {\n    uint32_t data;\n\n    __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {\n        data = (it & 0x003FFFFF) | (iex_used << 22);\n    }\n\n    __device__ uint32_t it() const {\n        return data & 0x003FFFFF;\n    }\n\n    __device__ uint32_t iex_used() const {\n        return data >> 22;\n    }\n};\nstatic_assert(sizeof(mm_ids_helper_store) == 4, \"unexpected size for mm_ids_helper_store\");\n\n// Helper function for mul_mat_id, converts ids to a more convenient format.\n// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.\n// ids_dst describes the same mapping but for the dst tensor.\n// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].\ntemplate <int n_expert_used_template>\n__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)\nstatic __global__ void mm_ids_helper(\n        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,\n        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n    const int n_expert_used = n_expert_used_template == 0 ? 
n_expert_used_var : n_expert_used_template;\n    const int expert = blockIdx.x;\n\n    extern __shared__ char data_mm_ids_helper[];\n    mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;\n\n    int nex_prev   = 0; // Number of columns for experts with a lower index.\n    int it_compact = 0; // Running index for the compact slice of this expert.\n\n    if constexpr (n_expert_used_template == 0) {\n        // Generic implementation:\n        for (int it = 0; it < n_tokens; ++it) {\n            int iex_used = -1; // The index at which the expert is used, if any.\n            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {\n                const int expert_used = ids[it*si1 + iex];\n                nex_prev += expert_used < expert;\n                if (expert_used == expert) {\n                    iex_used = iex;\n                }\n            }\n\n            if (iex_used != -1) {\n                store[it_compact] = mm_ids_helper_store(it, iex_used);\n            }\n\n            if (warp_reduce_any<warp_size>(iex_used != -1)) {\n                it_compact++;\n            }\n        }\n    } else {\n        // Implementation optimized for specific numbers of experts used:\n        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, \"bad n_expert_used\");\n        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.\n        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {\n            const int it = it0 + threadIdx.x / neu_padded;\n\n            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.\n            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?\n                ids[it*si1 + iex] : INT_MAX;\n            const int iex_used = expert_used == expert ? 
iex : -1;\n            nex_prev += expert_used < expert;\n\n            // Whether the threads at this token position have used the expert:\n            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);\n\n            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:\n            int it_compact_add_lower = 0;\n#pragma unroll\n            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {\n                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);\n                if (threadIdx.x >= static_cast<unsigned int>(offset)) {\n                    it_compact_add_lower += tmp;\n                }\n            }\n\n            if (iex_used != -1) {\n                store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);\n            }\n\n            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:\n            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);\n        }\n    }\n    nex_prev = warp_reduce_sum<warp_size>(nex_prev);\n\n    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {\n        const mm_ids_helper_store store_it = store[itc];\n        const int it       = store_it.it();\n        const int iex_used = store_it.iex_used();\n        ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;\n        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;\n    }\n\n    if (threadIdx.x != 0) {\n        return;\n    }\n\n    expert_bounds[expert] = nex_prev;\n\n    if (expert < static_cast<int>(gridDim.x) - 1) {\n        return;\n    }\n\n    expert_bounds[gridDim.x] = nex_prev + it_compact;\n}\n\ntemplate <int n_expert_used_template>\nstatic void launch_mm_ids_helper(\n        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, 
int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,\n        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {\n    GGML_ASSERT(n_tokens          < (1 << 22) && \"too few bits in mm_ids_helper_store\");\n    GGML_ASSERT(n_expert_used_var < (1 << 10) && \"too few bits in mm_ids_helper_store\");\n\n    const int id = ggml_cuda_get_device();\n    const int warp_size = ggml_cuda_info().devices[id].warp_size;\n    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;\n    CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);\n\n    const dim3 num_blocks(n_experts, 1, 1);\n    const dim3 block_size(warp_size, 1, 1);\n    const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);\n    GGML_ASSERT(nbytes_shared <= smpbo);\n    mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>\n        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);\n}\n\nvoid ggml_cuda_launch_mm_ids_helper(\n        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,\n        const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {\n    switch (n_expert_used) {\n        case  2:\n            launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);\n            break;\n        case  4:\n            launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);\n            break;\n        case  6:\n            launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);\n            break;\n       
 case  8:\n            launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);\n            break;\n        case 16:\n            launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);\n            break;\n        case 32:\n            launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);\n            break;\n        default:\n            launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);\n            break;\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/mmid.cuh",
    "content": "#pragma once\n\nvoid ggml_cuda_launch_mm_ids_helper(\n        const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,\n        int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);\n"
  },
  {
    "path": "src/ggml-cuda/mmq.cu",
    "content": "#include \"common.cuh\"\n#include \"mmq.cuh\"\n#include \"quantize.cuh\"\n#include \"mmid.cuh\"\n\nstatic void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {\n    switch (args.type_x) {\n        case GGML_TYPE_Q4_0:\n            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q4_1:\n            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q5_0:\n            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q5_1:\n            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q8_0:\n            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);\n            break;\n        case GGML_TYPE_MXFP4:\n            mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q2_K:\n            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q3_K:\n            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q4_K:\n            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q5_K:\n            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);\n            break;\n        case GGML_TYPE_Q6_K:\n            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);\n            break;\n        case GGML_TYPE_IQ2_XS:\n            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);\n            break;\n        case GGML_TYPE_IQ2_S:\n            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);\n            break;\n        case GGML_TYPE_IQ3_XXS:\n            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);\n            break;\n   
     case GGML_TYPE_IQ3_S:\n            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);\n            break;\n        case GGML_TYPE_IQ1_S:\n            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);\n            break;\n        case GGML_TYPE_IQ4_NL:\n            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);\n            break;\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n\nvoid ggml_cuda_mul_mat_q(\n        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {\n    GGML_ASSERT(        src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(        dst->type  == GGML_TYPE_F32);\n    GGML_ASSERT(!ids || ids->type  == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    cudaStream_t stream = ctx.stream();\n    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n\n    const size_t ts_src0 = ggml_type_size(src0->type);\n    const size_t ts_src1 = ggml_type_size(src1->type);\n    const size_t ts_dst  = ggml_type_size(dst->type);\n\n    GGML_ASSERT(        nb00       == ts_src0);\n    GGML_ASSERT(        nb10       == ts_src1);\n    GGML_ASSERT(        nb0        == ts_dst);\n    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));\n\n    const char  * src0_d = (const char  *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float       *  dst_d = (float       *)  dst->data;\n\n    // If src0 is a temporary compute buffer, clear any potential padding.\n    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {\n        const size_t size_data  = ggml_nbytes(src0);\n        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);\n        if (size_alloc > size_data) {\n  
          GGML_ASSERT(ggml_is_contiguously_allocated(src0));\n            GGML_ASSERT(!src0->view_src);\n            CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));\n        }\n    }\n\n    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);\n\n    const int64_t s01 = src0->nb[1] / ts_src0;\n    const int64_t s1  =  dst->nb[1] / ts_dst;\n    const int64_t s02 = src0->nb[2] / ts_src0;\n    const int64_t s2  =  dst->nb[2] / ts_dst;\n    const int64_t s03 = src0->nb[3] / ts_src0;\n    const int64_t s3  =  dst->nb[3] / ts_dst;\n\n    const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)\n                            || GGML_CUDA_CC_IS_CDNA(cc);\n\n    // TODO: tighter pool buffer size vs q8 path\n    const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;\n\n    if (!ids) {\n        const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +\n            get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);\n        ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);\n\n        {\n            const int64_t s11 = src1->nb[1] / ts_src1;\n            const int64_t s12 = src1->nb[2] / ts_src1;\n            const int64_t s13 = src1->nb[3] / ts_src1;\n            if (use_native_mxfp4) {\n                static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));\n                quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,\n                                        ne11, ne12, ne13, stream);\n\n            } else {\n                quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,\n                                       ne11, ne12, ne13, stream);\n            }\n            CUDA_CHECK(cudaGetLastError());\n        }\n\n        // Stride depends on quantization format\n  
      const int64_t s12 = use_native_mxfp4 ?\n                                ne11 * ne10_padded * sizeof(block_fp4_mmq) /\n                                    (8 * QK_MXFP4 * sizeof(int))  // block_fp4_mmq holds 256 values (8 blocks of 32)\n                                :\n                                ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));\n        const int64_t s13 = ne12*s12;\n\n        const mmq_args args = {\n            src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,\n            ne00, ne01, ne1, s01, ne11, s1,\n            ne02, ne12, s02, s12, s2,\n            ne03, ne13, s03, s13, s3,\n            use_stream_k, ne1};\n        ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);\n        return;\n    }\n\n    GGML_ASSERT(ne13 == 1);\n    GGML_ASSERT(nb12 % nb11 == 0);\n    GGML_ASSERT(nb2  % nb1  == 0);\n\n    const int64_t n_expert_used = ids->ne[0];\n    const int64_t ne_get_rows = ne12 * n_expert_used;\n    GGML_ASSERT(ne1 == n_expert_used);\n\n    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);\n    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);\n    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);\n\n    {\n        GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));\n        const int si1  = ids->nb[1] / ggml_element_size(ids);\n        const int sis1 = nb12 / nb11;\n\n        ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),\n            ne02, ne12, n_expert_used, ne11, si1, sis1, stream);\n        CUDA_CHECK(cudaGetLastError());\n    }\n\n    const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +\n        get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);\n    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);\n\n    const int64_t ne11_flat = ne12*n_expert_used;\n    const int64_t ne12_flat = 1;\n    const int64_t ne13_flat = 1;\n\n    
{\n        const int64_t s11 = src1->nb[1] / ts_src1;\n        const int64_t s12 = src1->nb[2] / ts_src1;\n        const int64_t s13 = src1->nb[3] / ts_src1;\n\n        if (use_native_mxfp4) {\n            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,\n                                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);\n        } else {\n            quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,\n                                   ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);\n        }\n        CUDA_CHECK(cudaGetLastError());\n    }\n\n    const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :\n                                           ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));\n    const int64_t s13 = ne12*s12;\n\n    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.\n    const mmq_args args = {\n        src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,\n        ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,\n        ne02, ne02, s02, s12, s2,\n        ne03, ne13, s03, s13, s3,\n        use_stream_k, ne12};\n\n    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);\n}\n\nvoid ggml_cuda_op_mul_mat_q(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,\n    const int64_t src1_padded_row_size, cudaStream_t stream) {\n\n    const int64_t ne00 = src0->ne[0];\n\n    const int64_t ne10 = src1->ne[0];\n    const int64_t ne11 = src1->ne[1];\n    GGML_ASSERT(ne10 % QK8_1 == 0);\n\n    const int64_t ne0 = dst->ne[0];\n\n  
  const int64_t row_diff = row_high - row_low;\n    const int64_t stride01 = ne00 / ggml_blck_size(src0->type);\n\n    const int id = ggml_cuda_get_device();\n    const int cc = ggml_cuda_info().devices[id].cc;\n\n    // the main device has a larger memory buffer to hold the results from all GPUs\n    // nrows_dst == nrows of the matrix that the kernel writes into\n    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;\n\n    // The stream-k decomposition is only faster for recent NVIDIA GPUs.\n    // Also its fixup needs to allocate a temporary buffer in the memory pool.\n    // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.\n    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)\n                            || GGML_CUDA_CC_IS_CDNA(cc))\n                            && src1_ncols == ne11;\n    const mmq_args args = {\n        src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,\n        ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,\n        1, 1, 0, 0, 0,\n        1, 1, 0, 0, 0,\n        use_stream_k, src1_ncols};\n\n    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);\n\n    GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size);\n}\n\nbool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {\n#ifdef GGML_CUDA_FORCE_CUBLAS\n    return false;\n#endif // GGML_CUDA_FORCE_CUBLAS\n\n    bool mmq_supported;\n\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_MXFP4:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case 
GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ4_NL:\n            mmq_supported = true;\n            break;\n        default:\n            mmq_supported = false;\n            break;\n    }\n\n    if (!mmq_supported) {\n        return false;\n    }\n\n    if (turing_mma_available(cc)) {\n        return true;\n    }\n\n    if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {\n        return false;\n    }\n\n#ifdef GGML_CUDA_FORCE_MMQ\n    return true;\n#endif //GGML_CUDA_FORCE_MMQ\n\n    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {\n        return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;\n    }\n\n    if (amd_mfma_available(cc)) {\n        // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)\n        // performs better but is currently suffering from a crash on this architecture.\n        // TODO: Revisit when hipblaslt is fixed on CDNA3\n        if (GGML_CUDA_CC_IS_CDNA3(cc)) {\n            return true;\n        }\n        if (n_experts > 64 || ne11 <= 128) {\n            return true;\n        }\n        if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {\n            return true;\n        }\n        if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {\n            return true;\n        }\n        return false;\n    }\n\n    if (amd_wmma_available(cc)) {\n        if (GGML_CUDA_CC_IS_RDNA3(cc)) {\n            // High expert counts are almost always better on MMQ due to\n            //     the synchronization overhead in the cuBLAS/hipBLAS path:\n            // https://github.com/ggml-org/llama.cpp/pull/18202\n            if (n_experts >= 64) {\n                return true;\n            }\n\n            // For some quantization types MMQ can have lower peak TOPS than hipBLAS\n            //     
so it's only faster for sufficiently small batch sizes:\n            switch (type) {\n                case GGML_TYPE_Q2_K:\n                    return ne11 <= 128;\n                case GGML_TYPE_Q6_K:\n                    return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256);\n                case GGML_TYPE_IQ2_XS:\n                case GGML_TYPE_IQ2_S:\n                    return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128;\n                default:\n                    return true;\n            }\n        }\n\n        // For RDNA4 MMQ is consistently faster than dequantization + hipBLAS:\n        // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301\n        return true;\n    }\n\n    return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;\n\n}\n"
  },
  {
    "path": "src/ggml-cuda/mmq.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n#include \"vecdotq.cuh\"\n#include \"mma.cuh\"\n\n#include <climits>\n#include <cstdint>\n\nusing namespace ggml_cuda_mma;\n\n#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.\n#define MMQ_ITER_K 256\n#define MMQ_ITER_K_MXFP4_FP4    512\n#define MMQ_NWARPS 8\n\ntypedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);\ntypedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);\ntypedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,\n    float * __restrict__ dst, const int stride, const int i_max, const int j_max);\n\nenum mmq_q8_1_ds_layout {\n    MMQ_Q8_1_DS_LAYOUT_D4,\n    MMQ_Q8_1_DS_LAYOUT_DS4,\n    MMQ_Q8_1_DS_LAYOUT_D2S6,\n};\n\nstruct block_q8_1_mmq {\n    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.\n    // The y float data is first grouped as blocks of 128 values.\n    // These blocks are then treated as individual data values and transposed.\n    //\n    // To avoid shared memory bank conflicts each block is padded with 16 bytes.\n    // This padding is also used to store block scales/partial sums.\n    // The scales multiplied with the quantized data are equal to the unquantized values.\n    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)\n    //     and are only needed for performance reasons.\n    //\n    // The exact data stored depends on the x data type.\n    union {\n        float d4[4];    // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3\n        half2 ds4[4];   // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3\n        half  d2s6[8];  // 1 16 bit scale per 
64 values + 1 16 bit partial sum per 16 values for the first 96 values,\n                        //     stored as d0,d1,s1,s2,s3,s4,s5\n    };\n    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each\n};\n\nstruct block_fp4_mmq {\n    uint32_t d4[4];       // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.\n    int8_t   qs[4 * 32];  // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values\n};\n\nstatic_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), \"Unexpected block_q8_1_mmq size\");\nstatic_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      \"Unexpected block_q8_1_mmq size\");\nstatic_assert(sizeof(block_fp4_mmq)  == sizeof(block_q8_1_mmq),    \"Unexpected block_fp4_mmq size\");\n\nstatic mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {\n    switch (type_x) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n            return MMQ_Q8_1_DS_LAYOUT_DS4;\n        case GGML_TYPE_Q5_0:\n            return MMQ_Q8_1_DS_LAYOUT_D4;\n        case GGML_TYPE_Q5_1:\n            return MMQ_Q8_1_DS_LAYOUT_DS4;\n        case GGML_TYPE_Q8_0:\n            return MMQ_Q8_1_DS_LAYOUT_D4;\n        case GGML_TYPE_MXFP4:\n            return MMQ_Q8_1_DS_LAYOUT_D4;\n        case GGML_TYPE_Q2_K:\n            return MMQ_Q8_1_DS_LAYOUT_D2S6;\n        case GGML_TYPE_Q3_K:\n            return MMQ_Q8_1_DS_LAYOUT_D4;\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n            return MMQ_Q8_1_DS_LAYOUT_DS4;\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ3_S:\n            return MMQ_Q8_1_DS_LAYOUT_D4;\n        case GGML_TYPE_IQ1_S:\n            return MMQ_Q8_1_DS_LAYOUT_DS4;\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ4_NL:\n            return MMQ_Q8_1_DS_LAYOUT_D4;\n        default:\n            GGML_ABORT(\"fatal 
error\");\n            break;\n    }\n}\n\nstruct tile_x_sizes {\n    int qs;\n    int dm;\n    int sc;\n};\n\nstatic int get_mmq_x_max_host(const int cc) {\n    return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :\n        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?\n#ifdef GGML_CUDA_FORCE_MMQ\n            128                     : 64;\n#else\n            MMQ_DP4A_MAX_BATCH_SIZE : 64;\n#endif // GGML_CUDA_FORCE_MMQ\n}\n\nstatic constexpr __device__ int get_mmq_x_max_device() {\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    return 128;\n#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)\n\n#if defined(GGML_USE_HIP)\n    return 64;\n#else // defined(GGML_USE_HIP)\n\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n#ifdef GGML_CUDA_FORCE_MMQ\n    return 128;\n#else // GGML_CUDA_FORCE_MMQ\n    return MMQ_DP4A_MAX_BATCH_SIZE;\n#endif // GGML_CUDA_FORCE_MMQ\n#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n    return 64;\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n\n#endif // defined(GGML_USE_HIP)\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n}\n\nstatic int get_mmq_y_host(const int cc) {\n    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :\n        ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);\n}\n\nstatic constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {\n#if defined(BLACKWELL_MMA_AVAILABLE)\n    return type == GGML_TYPE_MXFP4 ? 
MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;\n#else\n    return MMQ_ITER_K;\n#endif // defined(BLACKWELL_MMA_AVAILABLE)\n}\n\nstatic constexpr __device__ int get_mmq_y_device() {\n#if defined(GGML_USE_HIP)\n#if defined(RDNA1)\n    return 64;\n#else\n    return 128;\n#endif // defined RDNA1\n#else\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n    return 128;\n#else\n    return 64;\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n#endif // defined(GGML_USE_HIP)\n}\n\n// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.\n// The K dimension of the tiles has either,\n// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K),\n// 32 bit elements for the quantized data (does not include scales).\n// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.\n// The final tile size in K direction is padded to avoid shared memory bank conflicts,\n// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.\n#define MMQ_TILE_NE_K 32\n\n#define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0   + mmq_y/QI4_0,     0}\n#define MMQ_DP4A_TXS_Q4_1    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1   + mmq_y/QI4_1,     0}\n#define MMQ_DP4A_TXS_Q8_0    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}\n#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}\n#define MMQ_DP4A_TXS_Q8_1    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}\n#define MMQ_DP4A_TXS_Q2_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K         + mmq_y,           0}\n#define MMQ_DP4A_TXS_Q3_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y,                                         mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}\n#define MMQ_DP4A_TXS_Q4_K    
tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K,                     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}\n#define MMQ_DP4A_TXS_Q5_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K   + mmq_y/QI5_K,     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}\n#define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K   + mmq_y/QI6_K,     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}\n\nstatic constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:    return MMQ_DP4A_TXS_Q4_0;\n        case GGML_TYPE_Q4_1:    return MMQ_DP4A_TXS_Q4_1;\n        case GGML_TYPE_Q5_0:    return MMQ_DP4A_TXS_Q8_0;\n        case GGML_TYPE_Q5_1:    return MMQ_DP4A_TXS_Q8_1;\n        case GGML_TYPE_Q8_0:    return MMQ_DP4A_TXS_Q8_0;\n        case GGML_TYPE_MXFP4:   return MMQ_DP4A_TXS_Q8_1;\n        case GGML_TYPE_Q2_K:    return MMQ_DP4A_TXS_Q2_K;\n        case GGML_TYPE_Q3_K:    return MMQ_DP4A_TXS_Q3_K;\n        case GGML_TYPE_Q4_K:    return MMQ_DP4A_TXS_Q4_K;\n        case GGML_TYPE_Q5_K:    return MMQ_DP4A_TXS_Q5_K;\n        case GGML_TYPE_Q6_K:    return MMQ_DP4A_TXS_Q6_K;\n        case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;\n        case GGML_TYPE_IQ2_XS:  return MMQ_DP4A_TXS_Q8_0_16;\n        case GGML_TYPE_IQ2_S:   return MMQ_DP4A_TXS_Q8_0_16;\n        case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;\n        case GGML_TYPE_IQ3_S:   return MMQ_DP4A_TXS_Q8_0;\n        case GGML_TYPE_IQ1_S:   return MMQ_DP4A_TXS_Q8_0;\n        case GGML_TYPE_IQ4_XS:  return MMQ_DP4A_TXS_Q8_0;\n        case GGML_TYPE_IQ4_NL:  return MMQ_DP4A_TXS_Q8_0;\n        default:                return tile_x_sizes{0, 0, 0};\n    }\n}\n\n#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)\n#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 8                                       + 4)\n#define MMQ_MMA_TILE_X_K_Q8_1 
(2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)\n#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K                           + 4)\n#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2                         + 4)\n#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K   + MMQ_TILE_NE_K/8 + 7)\n\nstatic_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, \"Wrong padding.\");\nstatic_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, \"Wrong padding.\");\nstatic_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, \"Wrong padding.\");\nstatic_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, \"Wrong padding.\");\nstatic_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, \"Wrong padding.\");\nstatic_assert(MMQ_MMA_TILE_X_K_FP4  % 8 == 4, \"Wrong padding.\");\nstatic_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, \"Wrong tile size for MXFP4\");\n\nstatic constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:    return MMQ_MMA_TILE_X_K_Q8_0;\n        case GGML_TYPE_Q4_1:    return MMQ_MMA_TILE_X_K_Q8_1;\n        case GGML_TYPE_Q5_0:    return MMQ_MMA_TILE_X_K_Q8_0;\n        case GGML_TYPE_Q5_1:    return MMQ_MMA_TILE_X_K_Q8_1;\n        case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;\n        // tile sizes are the same for Q8_1 and FP4 for blackwell\n        case GGML_TYPE_MXFP4:   return MMQ_MMA_TILE_X_K_Q8_1;\n        case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;\n        case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;\n        case GGML_TYPE_Q4_K:    return MMQ_MMA_TILE_X_K_Q8_1;\n        case GGML_TYPE_Q5_K:    return MMQ_MMA_TILE_X_K_Q8_1;\n        case GGML_TYPE_Q6_K:    return MMQ_MMA_TILE_X_K_Q6_K;\n        case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;\n        case GGML_TYPE_IQ2_XS:  return MMQ_MMA_TILE_X_K_Q3_K;\n        case GGML_TYPE_IQ2_S:   return MMQ_MMA_TILE_X_K_Q3_K;\n        case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;\n        case GGML_TYPE_IQ3_S: 
  return MMQ_MMA_TILE_X_K_Q8_0;\n        case GGML_TYPE_IQ1_S:   return MMQ_MMA_TILE_X_K_Q8_0;\n        case GGML_TYPE_IQ4_XS:  return MMQ_MMA_TILE_X_K_Q8_0;\n        case GGML_TYPE_IQ4_NL:  return MMQ_MMA_TILE_X_K_Q8_0;\n        default:                return 0;\n    }\n}\n\n// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)\n#define MMQ_TILE_Y_K     (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)\n#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K\n\nstatic int mmq_get_granularity_host(const int mmq_x, const int cc) {\n    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {\n        return mmq_x >= 128 ? 32 : 16;\n    } else if (turing_mma_available(cc) && mmq_x >= 48) {\n        return 16;\n    } else {\n        return 8;\n    }\n}\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\nstatic constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {\n    return mmq_x >= 128 ? 32 : 16;\n}\n#elif defined(TURING_MMA_AVAILABLE)\nstatic constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {\n    return mmq_x >= 48 ? 16 : 8;\n}\n#else\nstatic constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {\n    return 8;\n}\n#endif // AMD_MFMA_AVAILABLE\n\n#if defined(GGML_USE_HIP)\nstatic int mmq_get_nwarps_host(const int cc, const int warp_size) {\n    return amd_mfma_available(cc) ? 
8 : 256/warp_size;\n}\n#else\nstatic int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {\n    return 256/warp_size;\n}\n#endif // (GGML_USE_HIP)\n\nstatic constexpr __device__ int mmq_get_nwarps_device() {\n#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    return 8;\n#else\n    return 256/ggml_cuda_get_physical_warp_size();\n#endif // AMD_MFMA_AVAILABLE\n}\n\n// ------------------------------------------------------------\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;\n    const int kbx  = txi / QI4_0;\n    const int kqsx = txi % QI4_0;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;\n        const int qs0 = get_int_b2(bxi->qs, kqsx);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);\n#else\n        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)\n    }\n\n    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;\n    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;\n    const int kbxd = threadIdx.x % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {\n        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;\n#else\n        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = 
mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + txs.qs;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_ds = (const half2 *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);\n\n                int u[2*VDR_Q4_0_Q8_1_MMQ];\n\n#pragma unroll\n                for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {\n                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];\n                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];\n                }\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>\n                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,\n                     x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n        }\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    half2 * 
x_dm = (half2 *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;\n    const int kbx  = txi / QI4_1;\n    const int kqsx = txi % QI4_1;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;\n        const int qs0 = get_int_b4(bxi->qs, kqsx);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;\n#else\n        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;\n    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;\n    const int kbxd = threadIdx.x % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {\n        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;\n#else\n        
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + txs.qs;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_ds = (const half2 *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);\n\n                int u[2*VDR_Q4_1_Q8_1_MMQ];\n\n#pragma unroll\n                for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {\n                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];\n                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];\n                }\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>\n                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,\n                     x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n        }\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(\n    const char * __restrict__ x, int * 
__restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;\n    const int kbx  = txi / QI5_0;\n    const int kqsx = txi % QI5_0;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;\n\n        const int ql = get_int_b2(bxi->qs, kqsx);\n        const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);\n\n        int qs0 = (ql >>  0)   & 0x0F0F0F0F;\n        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4\n        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12\n        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20\n        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28\n        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16\n\n        int qs1 = (ql >>  4)   & 0x0F0F0F0F;\n        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4\n        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12\n        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20\n        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28\n        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;\n#else\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;\n    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;\n    const int kbxd = threadIdx.x % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {\n        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;\n#else\n        x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;\n    const int kbx  = txi / QI5_1;\n    const int kqsx = txi % QI5_1;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;\n\n        const int ql = get_int_b4(bxi->qs, kqsx);\n        const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);\n\n        int qs0 = (ql >>  0) & 0x0F0F0F0F;\n        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4\n        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12\n        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20\n        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28\n\n        int qs1 = (ql >>  4) & 0x0F0F0F0F;\n        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4\n        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12\n        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20\n        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;\n#else\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;\n    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;\n    const int kbxd = threadIdx.x % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {\n        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || 
defined(AMD_WMMA_AVAILABLE)\n        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;\n#else\n        x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp\n    constexpr int threads_per_row = 32;\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;\n    const int kbx  = txi / QI8_0;\n    const int kqsx = txi % QI8_0;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);\n#else\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n    constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;\n    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;\n    const int kbxd = threadIdx.x % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {\n        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_0                 + kbxd] = bxi->d;\n#else\n        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int 
nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;\n    const int kbx  = txi / QI_MXFP4;\n    const int kqsx = txi % QI_MXFP4;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;\n\n        const int aux_q4 = get_int_b1(bxi->qs, kqsx);\n        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);\n        const int k0 = kbx * (2 * QI_MXFP4) + kqsx;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x;\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;\n#else\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x;\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)\n    }\n\n    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;\n    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;\n    
const int kbxd = threadIdx.x % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {\n        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_1                 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;\n#else\n        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check>\nstatic __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,\n                                                            int * __restrict__ x_tile,\n                                                            const int kbx0,\n                                                            const int i_max,\n                                                            const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    int *      x_qs = (int *) x_tile;\n    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);\n\n    const int txi = threadIdx.x;\n\n    constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);\n\n    constexpr int threads_per_row = iter_k / QK_MXFP4;  // each thread processes 1 block\n    constexpr int rows_per_warp   = warp_size / threads_per_row;\n    const int     kbx             = txi % threads_per_row;\n    const int     row_in_warp     = txi / threads_per_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {\n        int i = i0 + threadIdx.y * rows_per_warp + 
row_in_warp;\n\n        if constexpr (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;\n\n        // quantize_mxfp4_mmq permutes nibbles to match the quantized format\n        const int k0 = kbx * 4;\n        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);\n\n        // Load E8M0 scales: pack 2 consecutive scales into one uint32\n        if (kbx % 2 == 0) {\n            uint32_t e = bxi->e;\n            e |= ((bxi + 1)->e << 8);\n            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;\n        }\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + txs.qs;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>\n                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],\n                     x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);\n            }\n        }\n    
}\n}\n\ntemplate <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>\nstatic __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    constexpr data_layout input_layout = get_input_data_layout();\n    typedef tile<16,  8, int, input_layout>        tile_A;\n    typedef tile<16,  8, int, input_layout>        tile_B;\n    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n    const half2 * y_ds = (const half2 *) y;\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {\n        const int k0 = k00 + k01;\n\n        tile_A A[ntx];\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);\n        }\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n            tile_B B;\n            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);\n\n            float dB;\n            const int j = j0 + tile_C::get_j(0);\n            if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {\n                dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];\n            } else {\n                dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n                
mma(C, A[n], B);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);\n                    const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;\n                }\n            }\n        }\n    }\n#else\n    typedef tile<16, 8, int> tile_A;\n    typedef tile< 8, 8, int> tile_B;\n    typedef tile<16, 8, int> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = 2 * granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n    const half2 * y_ds = (const half2 *) y;\n\n    tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];\n    float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];\n\n    const int i0 = (threadIdx.y/ntx)*rows_per_warp;\n\n#pragma unroll\n    for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {\n            const int k0 = k00 + k01;\n\n            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);\n        }\n\n#pragma unroll\n        for (int l = 0; l < tile_C::ne/2; ++l) {\n            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);\n\n#pragma unroll\n            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {\n                const int k0 = k00 + k01;\n\n                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];\n            }\n        }\n    }\n\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 
QI8_0) {\n            tile_B B;\n            float dB[tile_C::ne/2];\n\n            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix\n\n#pragma unroll\n            for (int l = 0; l < tile_C::ne/2; ++l) {\n                const int j = j0 + tile_C::get_j(l);\n\n                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {\n                    dB[l] =             y_df[j*MMQ_TILE_Y_K + k01/QI8_1];\n                } else {\n                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);\n                }\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n                mma(C, A[n][k01/QI8_0], B);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];\n                }\n            }\n        }\n    }\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,\n                                                               const int * __restrict__ y,\n                                                               float * __restrict__ sum,\n                                                               const int k00) {\n    typedef tile<16, 8, int>   tile_A;\n    typedef tile<8, 8, int>    tile_B;\n    typedef tile<16, 8, float> tile_C;  // Output is float for native scaled MMA\n\n    constexpr int granularity   = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = 2 * granularity;\n    constexpr int ntx           = rows_per_warp / tile_C::I;  // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);\n\n    // Match layout from load_tiles_mxfp4_fp4\n    const int *      x_qs = (const int *) x;\n    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 
* MMQ_TILE_NE_K);\n    const int *      y_qs = (const int *) y + 4;\n    const uint32_t * y_sc = (const uint32_t *) y;\n\n    // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4\n    tile_A   A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];\n    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];\n\n    // Block scale\n    // Each thread has to point to a 4 byte scale value\n    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n#pragma unroll\n    for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {\n            const int k0 = k00 + k01;\n\n            load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,\n                          MMQ_MMA_TILE_X_K_FP4);\n\n            // based on block-scaling document, 2 threads in each quad need to supply to the scale value\n            const int tidx         = threadIdx.x / 4 + (threadIdx.x % 2) * 8;\n            scaleA[n][k01 / (2 * QI_MXFP4)] =\n                *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));\n        }\n    }\n\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {\n            tile_B   B;\n            uint32_t scaleB;  // 2xN scales\n\n            load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);\n\n            scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n\n                mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    sum[(j0 / tile_C::J + n) * tile_C::ne + l] += 
C.x[l];\n                }\n            }\n        }\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + txs.qs;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_ds = (const half2 *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>\n                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],\n                    x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n        }\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    constexpr data_layout input_layout = get_input_data_layout();\n    typedef tile<16,  8, int, input_layout>        tile_A;\n    typedef tile<16,  8, int, input_layout>        tile_B;\n    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int 
rows_per_warp = granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_dm = (const half2 *) y;\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {\n        const int k0 = k00 + k01;\n\n        tile_A A[ntx];\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);\n        }\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n            tile_B B;\n            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);\n\n            const int j = j0 + tile_C::get_j(0);\n            const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n                mma(C, A[n], B);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);\n                    float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;\n                }\n            }\n        }\n    }\n#else\n    typedef tile<16,  8, int> tile_A;\n    typedef tile< 8,  8, int> tile_B;\n    typedef tile<16,  8, int> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = 2 * granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * 
(tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_dm = (const half2 *) y;\n\n    tile_A   A[ntx][MMQ_TILE_NE_K/QI8_1];\n    float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];\n\n    const int i0 = (threadIdx.y/ntx)*rows_per_warp;\n\n#pragma unroll\n    for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {\n            const int k0 = k00 + k01;\n\n            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);\n        }\n\n#pragma unroll\n        for (int l = 0; l < tile_C::ne/2; ++l) {\n            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);\n\n#pragma unroll\n            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {\n                const int k0 = k00 + k01;\n\n                dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);\n            }\n        }\n    }\n\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {\n            tile_B   B;\n            float2 dsB[tile_C::ne/2];\n\n            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix\n\n#pragma unroll\n            for (int l = 0; l < tile_C::ne/2; ++l) {\n                const int j = j0 + tile_C::get_j(l);\n\n                dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n                mma(C, A[n][k01/QI8_1], B);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];\n                    sum[(j0/tile_C::J + 
n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;\n                }\n            }\n        }\n    }\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n}\n\n// Used for Q3_K, IQ2_S, and IQ2_XS\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + txs.qs;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(\n                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],\n                    &y_qs[j*MMQ_TILE_Y_K + k01],\n                    &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],\n                    y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n        }\n    }\n}\n\n// Used for Q3_K, IQ2_S, and IQ2_XS:\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n#if defined(AMD_MFMA_AVAILABLE)\n    constexpr data_layout input_layout = get_input_data_layout();\n    typedef tile<16,  8, int, input_layout>        tile_A;\n    typedef tile<16,  8, int, input_layout>   
     tile_B;\n    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;\n    typedef tile<64,  2, int, input_layout>        tile_load;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {\n        const int k0 = k00 + k01;\n\n        tile_A A[ntx];\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);\n        }\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n            tile_B B[1];\n            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);\n\n            const int j = j0 + tile_C::get_j(0);\n            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n                mma(C, A[n], B[0]);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;\n                }\n            }\n        }\n    }\n#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles\n    constexpr data_layout input_layout = get_input_data_layout();\n    typedef tile<16,  4, int, input_layout>        tile_A;\n    typedef tile<16,  
4, int, input_layout>        tile_B;\n    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {\n        const int k0 = k00 + k01;\n\n        tile_A A[ntx];\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);\n        }\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n            tile_B B;\n            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);\n\n            const int j = j0 + tile_C::get_j(0);\n            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n                mma(C, A[n], B);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;\n                }\n            }\n        }\n    }\n#elif defined(TURING_MMA_AVAILABLE)\n\n    typedef tile<16, 4, int> tile_A;\n    typedef tile<16, 8, int> tile_A_8;\n    typedef tile< 8, 4, int> tile_B;\n    typedef tile<16, 8, int> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = 2 * granularity;\n    constexpr int ntx = 
rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n\n    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);\n\n    tile_A  A[ntx][8];\n    float  dA[ntx][tile_C::ne/2][8];\n\n#pragma unroll\n    for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {\n            const int k0 = k00 + k01;\n\n            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);\n        }\n\n#pragma unroll\n        for (int l = 0; l < tile_C::ne/2; ++l) {\n            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);\n\n#pragma unroll\n            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {\n                const int k0 = k00 + k01;\n\n                dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];\n            }\n        }\n    }\n\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {\n            tile_B B[2];\n            float dB[tile_C::ne/2];\n\n            // Here load_generic is faster than load_ldmatrix.\n            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);\n            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);\n\n#pragma unroll\n            for (int l = 0; l < tile_C::ne/2; ++l) {\n                const int j = j0 + tile_C::get_j(l);\n\n                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C[2];\n                mma(C[0], A[n][k01/4 + 0], B[0]);\n                mma(C[1], 
A[n][k01/4 + 1], B[1]);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);\n                }\n            }\n        }\n    }\n#else\n    GGML_UNUSED_VARS(x, y, sum, k00);\n    NO_DEVICE_CODE;\n#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);\n    constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;\n    const int kqsx = threadIdx.x % threads_per_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;\n\n        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);\n\n#pragma unroll\n        for (int l = 0; l < QR2_K; ++l) {\n            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;\n\n            const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;\n#else\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        }\n\n        const int sc_m = bxi->scales[kqsx];\n#ifdef FAST_FP16_AVAILABLE\n        const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));\n#else\n        const float2 bxi_dmf = __half22float2(bxi->dm);\n        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));\n#endif // FAST_FP16_AVAILABLE\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;\n#else\n        x_dm[i*(MMQ_TILE_NE_K + 1)   + kqsx] = x_dm_ik;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + txs.qs;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_ds = (const half2 *) y;\n\n    float2 y_df[mmq_x/nwarps];\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n        const int j = j0 + threadIdx.y;\n\n        y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);\n    }\n\n#pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n          
  for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                constexpr int ns = 2;\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(\n                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],\n                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,\n                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);\n            }\n        }\n    }\n\n    // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.\n    // As a workaround 2 separate loops are used instead.\n#pragma unroll\n    for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                constexpr int ns = 1;\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(\n                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],\n                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? 
y_df[j0/nwarps].x : y_df[j0/nwarps].y,\n                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);\n            }\n        }\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n#if defined(AMD_MFMA_AVAILABLE)\n    constexpr data_layout input_layout = get_input_data_layout();\n    typedef tile<16,  8, int, input_layout>        tile_A;\n    typedef tile<16,  8, int, input_layout>        tile_B;\n    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;\n    typedef tile<64,  2, int, input_layout>        tile_load;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_ds = (const half2 *) y;\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {\n        const int k0 = k00 + k01;\n\n        tile_A A[ntx];\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);\n        }\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n            tile_B B[1];\n            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);\n\n            const int j = j0 + tile_C::get_j(0);\n            const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;\n            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 
0\n                                              : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y\n                                                             : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);\n\n            tile_C Cm;\n            if (k01 >= MMQ_TILE_NE_K * 3/4) {\n                tile_A A1;\n                A1.x[0] = 0x01010101;\n                A1.x[1] = 0x01010101;\n                mma(Cm, A1, B[0]);\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C Cd;\n                mma(Cd, A[n], B[0]);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);\n                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);\n                    float tmp = Cd.x[l]*dm.x;\n                    if (k01 >= MMQ_TILE_NE_K * 3/4) {\n                        tmp -= Cm.x[l]*dm.y;\n                    }\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;\n                }\n            }\n        }\n    }\n#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles\n    constexpr data_layout input_layout = get_input_data_layout();\n    typedef tile<16,  4, int, input_layout>        tile_A;\n    typedef tile<16,  4, int, input_layout>        tile_B;\n    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * y_qs = (const int   *) y + 4;\n   
 const half2 * y_ds = (const half2 *) y;\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {\n        const int k0 = k00 + k01;\n\n        tile_A A[ntx];\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);\n        }\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n            tile_B B;\n            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);\n\n            const int j = j0 + tile_C::get_j(0);\n            const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;\n            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0\n                                              : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y\n                                                             : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);\n\n            tile_C Cm;\n            if (k01 >= MMQ_TILE_NE_K * 3/4) {\n                tile_A A1;\n#pragma unroll\n                for (int l = 0; l < tile_A::ne; ++l) {\n                    A1.x[l] = 0x01010101;\n                }\n                mma(Cm, A1, B);\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C Cd;\n                mma(Cd, A[n], B);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);\n                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);\n                    float tmp = Cd.x[l]*dm.x;\n                    if (k01 >= MMQ_TILE_NE_K * 3/4) {\n                        tmp -= Cm.x[l]*dm.y;\n                    }\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;\n                    sum[(j0/tile_C::J + 
n)*tile_C::ne + l] -= dm.y*sB;\n                }\n            }\n        }\n    }\n#elif defined(TURING_MMA_AVAILABLE)\n\n    typedef tile<16, 4, int> tile_A;\n    typedef tile<16, 8, int> tile_A_8;\n    typedef tile< 8, 4, int> tile_B;\n    typedef tile<16, 8, int> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = 2 * granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_ds = (const half2 *) y;\n\n    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);\n\n    tile_A  A[ntx][8];\n    float  dA[ntx][tile_C::ne/2][8];\n    float  mA[ntx][tile_C::ne/2][8];\n\n#pragma unroll\n    for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {\n            const int k0 = k00 + k01;\n\n            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);\n        }\n    }\n\n#pragma unroll\n    for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n        for (int l = 0; l < tile_C::ne/2; ++l) {\n            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);\n\n#pragma unroll\n            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {\n                const int k0 = k00 + k01;\n\n                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);\n\n                dA[n][l][k01/(QI8_1/2)] = dm.x;\n                mA[n][l][k01/(QI8_1/2)] = dm.y;\n            }\n        }\n    }\n\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n        float2 dB[tile_C::ne/2];\n\n#pragma unroll\n        for (int l = 0; l < tile_C::ne/2; ++l) {\n            const int j = 
j0 + tile_C::get_j(l);\n\n            dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);\n        }\n\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {\n            tile_B B[2];\n\n            // Here load_generic is faster than load_ldmatrix.\n            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);\n            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);\n\n            tile_C Cm[2];\n            if (k01 >= MMQ_TILE_NE_K * 3/4) {\n                tile_A A1;\n                A1.x[0] = 0x01010101;\n                A1.x[1] = 0x01010101;\n                mma(Cm[0], A1, B[0]);\n                mma(Cm[1], A1, B[1]);\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C Cd[2];\n\n                mma(Cd[0], A[n][k01/4 + 0], B[0]);\n                mma(Cd[1], A[n][k01/4 + 1], B[1]);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];\n                    if (k01 >= MMQ_TILE_NE_K * 3/4) {\n                        tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];\n                    }\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? 
dB[l%2].x : dB[l%2].y);\n                }\n            }\n        }\n\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {\n            float2 sB[tile_C::ne/2];\n\n#pragma unroll\n            for (int l = 0; l < tile_C::ne/2; ++l) {\n                const int j = j0 + tile_C::get_j(l);\n\n                sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;\n                }\n            }\n        }\n    }\n#else\n    GGML_UNUSED_VARS(x, y, sum, k00);\n    NO_DEVICE_CODE;\n#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n    int   * x_sc = (int   *) (x_df + txs.dm);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int kqsx = threadIdx.x % threads_per_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;\n\n        const int x_ql_0 = get_int_b2(bxi->qs,    kqsx);\n        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));\n\n#pragma unroll\n        for (int l = 0; l < QR3_K; ++l) {\n            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;\n\n            const int x_ql_k =  (x_ql_0 >> (2*l))       & 0x03030303;\n            const int x_qh_k = ((x_qh_0 >>    l)  << 2) & 0x04040404;\n\n            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;\n#else\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        }\n    }\n\n    constexpr int rows_per_warp = warp_size / 4;\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {\n        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;\n\n        const int ksc = threadIdx.x % 4;\n\n        const int ksc_low = ksc % (QI3_K/8);\n        const int shift_low = 4 * (ksc / (QI3_K/8));\n        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;\n\n        const int ksc_high = QI3_K/8;\n        const int shift_high = 2 * ksc;\n        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;\n\n        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        const int8_t 
* sc8 = (const int8_t *) &sc;\n        const float d = bxi->d;\n\n#pragma unroll\n        for (int l = 0; l < int(sizeof(int)); ++l) {\n            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];\n        }\n#else\n        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {\n        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;\n\n        x_df[i] = bxi->d;\n    }\n#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) || defined(AMD_WMMA_AVAILABLE)\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + txs.qs;\n    const int   * x_sc = (const int   *) x_df + txs.dm;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                
const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(\n                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,\n                    x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n        }\n    }\n}\n\nstatic __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {\n    // scale arrangement after the following two lines:\n    //   - ksc == 0: sc0, sc1, sc2, sc3\n    //   - ksc == 1: sc4, sc5, sc6, sc7\n    //   - ksc == 2:  m0,  m1,  m2,  m3\n    //   - ksc == 3:  m4,  m5,  m6,  m7\n    return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits\n           ((scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030);  // upper 2 bits\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + txs.qs);\n    int   * x_sc = (int   *) (x_dm + txs.dm);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;\n        const int qs0 = get_int_b4(bxi->qs, txi);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;\n#else\n        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    constexpr int rows_per_warp = warp_size / 2;\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {\n#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        // Need if on AMD instead of % because warp_size == 64\n        // This causes double work and throughput loss (MI300X)\n        // H100 loses about 100 t/s with 'if' condition over '%'\n        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;\n        if (i < mmq_y) {\n#else\n        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;\n        {\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            if (need_check) {\n                i = min(i, i_max);\n            }\n\n            const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;\n\n            const int * scales = (const int *) bxi->scales;\n            const int ksc = threadIdx.x % 2;\n\n            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);\n            const int  m32 = 
unpack_scales_q45_K(scales, ksc + 2);\n\n            const uint8_t * sc8 = (const uint8_t *) &sc32;\n            const uint8_t *  m8 = (const uint8_t *)  &m32;\n\n            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);\n\n    #pragma unroll\n            for (int l = 0; l < sizeof(int); ++l) {\n                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);\n            }\n        }\n    }\n#else\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {\n        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;\n\n        x_dm[i] = bxi->dm;\n    }\n    constexpr int rows_per_warp = warp_size / 4;\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {\n        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);\n\n        const int * scales = (const int *) bxi->scales;\n\n        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);\n        const int scales8 = unpack_scales_q45_K(scales, ksc);\n\n        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;\n    }\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);\n    const int   * x_qs = (const int   *) 
x;\n    const half2 * x_dm = (const half2 *) x_qs + txs.qs;\n    const int   * x_sc = (const int   *) x_dm + txs.dm;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_ds = (const half2 *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(\n                    &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,\n                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n        }\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_dm = (half2 *) (x_qs + txs.qs);\n    int   * x_sc = (int   *) (x_dm + txs.dm);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;\n        const int ky = QR5_K*txi;\n\n        const int ql = get_int_b4(bxi->qs, txi);\n        const int ql0 = (ql >> 0) & 0x0F0F0F0F;\n        const int ql1 = (ql >> 4) & 0x0F0F0F0F;\n\n        const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));\n        const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;\n        const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;\n\n        const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;\n        const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;\n#else\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    constexpr int rows_per_warp = warp_size / 2;\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {\n#if defined(AMD_MFMA_AVAILABLE)\n        // Need if on AMD instead of % because warp_size == 64\n        // This causes double work and throughput loss (MI300X)\n        // H100 loses about 100 t/s with 'if' condition over '%'\n        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;\n        if (i < mmq_y) {\n#else\n        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % 
mmq_y;\n        {\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            if (need_check) {\n                i = min(i, i_max);\n            }\n\n            const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;\n\n            const int * scales = (const int *) bxi->scales;\n            const int ksc = threadIdx.x % 2;\n\n            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);\n            const int  m32 = unpack_scales_q45_K(scales, ksc + 2);\n\n            const uint8_t * sc8 = (const uint8_t *) &sc32;\n            const uint8_t *  m8 = (const uint8_t *)  &m32;\n\n            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);\n\n#pragma unroll\n            for (int l = 0; l < int(sizeof(int)); ++l) {\n                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);\n            }\n        }\n    }\n#else\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {\n        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;\n\n        x_dm[i] = bxi->dm;\n    }\n\n    constexpr int rows_per_warp = warp_size / 4;\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {\n        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;\n\n        const int * scales = (const int *) bxi->scales;\n\n        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);\n        const int scales8 = unpack_scales_q45_K(scales, ksc);\n\n        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;\n    }\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n}\n\ntemplate <int mmq_x, int 
mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);\n    const int   * x_qs = (const int   *) x;\n    const half2 * x_dm = (const half2 *) x_qs + txs.qs;\n    const int   * x_sc = (const int   *) x_dm + txs.dm;\n    const int   * y_qs = (const int   *) y + 4;\n    const half2 * y_ds = (const half2 *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(\n                    &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,\n                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n        }\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n    int   * x_sc = (int   *) (x_df + 
MMQ_TILE_NE_K/QI6_K);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n    int   * x_sc = (int   *) (x_df + txs.dm);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;\n\n        const int ql = get_int_b2(bxi->ql, txi);\n        const int ql0 = (ql >> 0) & 0x0F0F0F0F;\n        const int ql1 = (ql >> 4) & 0x0F0F0F0F;\n\n        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));\n        const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;\n        const int qh1 =  (qh >> ((txi & 0x08) >> 2))       & 0x30303030;\n\n        const int kq0 = 2*txi - txi % (QI6_K/2) + 0;\n        const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);\n        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);\n#else\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += 
nwarps*warp_size) {\n        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q6_K]           = bxi->d;\n#else\n        x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n    constexpr int rows_per_warp = warp_size / 4;\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {\n        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));\n#else\n        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + txs.qs;\n    const int   * x_sc = (const int   *) 
x_df + txs.dm;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n\n// #pragma unroll\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {\n        const int k0 = k00 + k01;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);\n\n                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(\n                    &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,\n                    x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);\n            }\n        }\n    }\n}\n\ntemplate <int mmq_x, int mmq_y>\nstatic __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(\n    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {\n#if defined(AMD_MFMA_AVAILABLE)\n    constexpr data_layout input_layout = get_input_data_layout();\n    typedef tile<16,  8, int, input_layout>        tile_A;\n    typedef tile<16,  8, int, input_layout>        tile_B;\n    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;\n    typedef tile<64,  2, int, input_layout>        tile_load;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * x_sc = (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) 
y;\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {\n        const int k0 = k00 + k01;\n\n        tile_A A[ntx];\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);\n        }\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n            tile_B B[1];\n            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);\n\n            const int j = j0 + tile_C::get_j(0);\n            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n                mma(C, A[n], B[0]);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);\n                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;\n                }\n            }\n        }\n    }\n#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles\n    constexpr data_layout input_layout = get_input_data_layout();\n    typedef tile<16,  4, int, input_layout>        tile_A;\n    typedef tile<16,  4, int, input_layout>        tile_B;\n    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * x_sc 
= (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;\n    const int   * y_qs = (const int   *) y + 4;\n    const float * y_df = (const float *) y;\n\n    const int i0 = (threadIdx.y / ntx) * rows_per_warp;\n\n    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {\n        const int k0 = k00 + k01;\n\n        tile_A A[ntx];\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);\n        }\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n            tile_B B;\n            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);\n\n            const int j = j0 + tile_C::get_j(0);\n            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C;\n                mma(C, A[n], B);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);\n                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);\n                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;\n                }\n            }\n        }\n    }\n#elif defined(TURING_MMA_AVAILABLE)\n\n    typedef tile<16, 4, int> tile_A;\n    typedef tile< 8, 4, int> tile_B;\n    typedef tile<16, 8, int> tile_C;\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int rows_per_warp = 2 * granularity;\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);\n\n    const int   * x_qs = (const int   *) x;\n    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;\n    const int   * x_sc = (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;\n    const int   * y_qs = (const int   *) y + 4;\n    const float 
* y_df = (const float *) y;\n\n    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);\n\n    tile_A   A[ntx][8];\n    int    scA[ntx][tile_C::ne/2][8];\n    float   dA[ntx][tile_C::ne/2];\n\n#pragma unroll\n    for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {\n            const int k0 = k00 + k01;\n\n            load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),         MMQ_MMA_TILE_X_K_Q6_K);\n            load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);\n        }\n\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {\n            const int k0 = k00 + k01;\n\n#pragma unroll\n            for (int l = 0; l < tile_C::ne/2; ++l) {\n                const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);\n\n                const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];\n                const int8_t * sc        = (const int8_t *) &sc_packed;\n\n#pragma unroll\n                for (int ksc = 0; ksc < sizeof(int); ++ksc) {\n                    scA[n][l][k01/4 + ksc] = sc[ksc];\n                }\n            }\n        }\n\n#pragma unroll\n        for (int l = 0; l < tile_C::ne/2; ++l) {\n            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);\n\n            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];\n        }\n    }\n\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {\n        float tmp[ntx][tile_C::ne] = {{0.0f}};\n\n#pragma unroll\n        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {\n            tile_B B[2];\n            float dB[tile_C::ne/2];\n\n            // Here load_generic is faster than load_ldmatrix.\n            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0         + k01, MMQ_TILE_Y_K);\n            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);\n\n#pragma unroll\n            for 
(int l = 0; l < tile_C::ne/2; ++l) {\n                const int j = j0 + tile_C::get_j(l);\n\n                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];\n            }\n\n#pragma unroll\n            for (int n = 0; n < ntx; ++n) {\n                tile_C C[2];\n                mma(C[0], A[n][k01/4 + 0], B[0]);\n                mma(C[1], A[n][k01/4 + 1], B[1]);\n\n#pragma unroll\n                for (int l = 0; l < tile_C::ne; ++l) {\n                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];\n                }\n            }\n        }\n\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n            for (int l = 0; l < tile_C::ne; ++l) {\n                sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];\n            }\n        }\n    }\n#else\n    GGML_UNUSED_VARS(x, y, sum, k00);\n    NO_DEVICE_CODE;\n#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x;\n    const int kbx  = txi / QI4_NL;\n    const int kqsx = txi % QI4_NL;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;\n\n        const int aux_q4 = get_int_b2(bxi->qs, kqsx);\n        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);\n        const int k0 = kbx * (2 * QI4_NL) + kqsx;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0]      = v.x;\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;\n#else\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]      = v.x;\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;\n    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;\n    const int kbxd = threadIdx.x % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {\n        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = __half2float(bxi->d);\n#else\n        x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    
}\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;\n    constexpr int nrows = warp_size / threads_per_row;\n    const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {\n        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;\n\n        const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);\n        const uint8_t * aux8 = (const uint8_t *) &q2;\n        const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);\n\n#pragma unroll\n        for (int l = 0; l < QR2_XXS; ++l) {\n            const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];\n            const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));\n\n            const int signs0 = __vcmpne4(signs & 0x08040201, 0);\n            const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);\n\n            const int signs1 = __vcmpne4(signs & 0x80402010, 0);\n            const int grid1 = __vsub4(grid_pos.y ^ signs1, 
signs1);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;\n            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;\n#else\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        }\n\n        const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)\n        const float d = bxi->d;\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4\n#else\n        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;\n    constexpr int nrows = warp_size / threads_per_row;\n    const int kqsx = threadIdx.x % threads_per_row;\n\n#pragma unroll\n    for 
(int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {\n        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;\n\n        const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));\n        const uint16_t * q2 = (const uint16_t *) &q2_packed;\n\n    #pragma unroll\n        for (int l = 0; l < QR2_XS; ++l) {\n            const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];\n            const uint32_t signs = unpack_ksigns(q2[l] >> 9);\n\n            const int signs0 = __vcmpne4(signs & 0x08040201, 0);\n            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);\n\n            const int signs1 = __vcmpne4(signs & 0x80402010, 0);\n            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;\n            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;\n#else\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        }\n\n        const int ls = bxi->scales[kqsx];\n        const float d = bxi->d;\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;\n        x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;\n#else\n        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;\n        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d 
+ d/2)/4;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;\n    constexpr int nrows = warp_size / threads_per_row;\n    const int kqsx = threadIdx.x % threads_per_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {\n        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;\n\n        const int       qs_packed = get_int_b2(bxi->qs, kqsx);\n        const uint8_t * qs        = (const uint8_t *) &qs_packed;\n\n        const int qh = bxi->qh[kqsx];\n\n        const int       signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);\n        const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;\n\n#pragma unroll\n        for (int l = 0; l < QR2_S; ++l) {\n            const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));\n\n            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | 
((signs_packed_8[l] & 0x0C) << 21), 0x00000000);\n            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);\n\n            const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);\n            const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;\n            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;\n#else\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        }\n\n        const int ls = bxi->scales[kqsx];\n        const float d = bxi->d;\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;\n        x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;\n#else\n        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;\n        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * 
x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;\n    constexpr int nrows = warp_size / threads_per_row;\n    const int kqsx = threadIdx.x % threads_per_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {\n        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;\n\n        const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));\n        const uint8_t * q3 = (const uint8_t *) &q3_packed;\n        const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);\n\n#pragma unroll\n        for (int l = 0; l < QR3_XXS; ++l) {\n            const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);\n            const uint32_t signs = unpack_ksigns(aux32 >> (7*l));\n\n            const int signs0 = __vcmpne4(signs & 0x08040201, 0);\n            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);\n\n            const int signs1 = __vcmpne4(signs & 0x80402010, 0);\n            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;\n            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;\n#else\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;\n#endif // defined(AMD_MFMA_AVAILABLE) || 
defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        }\n\n        const int ls = aux32 >> 28;\n        const float d = bxi->d;\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = (ls*d + d/2)/2;\n#else\n        x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = (ls*d + d/2)/2;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;\n    constexpr int nrows = warp_size / threads_per_row;\n    const int kqsx = threadIdx.x % threads_per_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {\n        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;\n\n        const int2      qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));\n        const uint8_t * qs        = (const uint8_t *) &qs_packed;\n\n        const int qh = 
bxi->qh[kqsx];\n\n        const int       signs_packed_32 = get_int_b2(bxi->signs, kqsx);\n        const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;\n\n#pragma unroll\n        for (int l = 0; l < QR3_S; ++l) {\n            const int2 grid_pos = make_int2(\n                iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],\n                iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);\n\n            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);\n            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);\n\n            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);\n            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;\n            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;\n#else\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        }\n\n        const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);\n        const float d = bxi->d;\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = ls*d;\n#else\n        x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = ls*d;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    
constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    half2 * x_ds = (half2 *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int kqsx = threadIdx.x % threads_per_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {\n        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;\n\n        const int       qs_packed = get_int_b2(bxi->qs, kqsx);\n        const uint8_t * qs        = (const uint8_t *) &qs_packed;\n\n        const int qh = bxi->qh[kqsx];\n\n    #pragma unroll\n        for (int l = 0; l < QR1_S/2; ++l) {\n            const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];\n\n            const int grid0 = (grid >> 0) & 0x0F0F0F0F;\n            const int grid1 = (grid >> 4) & 0x0F0F0F0F;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;\n            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;\n#else\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;\n            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n  
      }\n\n        const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);\n        const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1     + kqsx] = make_half2(d1q, d1q*delta);\n#else\n        x_ds[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(\n    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);\n#else\n    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);\n    int   * x_qs = (int   *)  x_tile;\n    float * x_df = (float *) (x_qs + txs.qs);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);\n    constexpr int nrows = warp_size / threads_per_row;\n    const int kqsx = threadIdx.x % threads_per_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {\n        int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;\n\n        const int aux_q4 = get_int_b4(bxi->qs, kqsx);\n        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);\n        const int k0 = 8 * (kqsx / 4) + kqsx % 4;\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;\n        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;\n#else\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;\n        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n\n    constexpr int rows_per_warp = warp_size / 8;\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {\n        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);\n\n        if (need_check) {\n            i = min(i, i_max);\n        }\n\n        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;\n\n        const float d = __half2float(bxi->d);\n\n        const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)\n            | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + threadIdx.x % 8] = d * (ls - 32);\n#else\n        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    }\n}\n\ntemplate<int mmq_x, int mmq_y, bool need_check>\nstatic __device__ __forceinline__ void mmq_write_back_dp4a(\n        const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * 
__restrict__ dst,\n        const int stride, const int i_max, const int j_max) {\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n        const int j = j0 + threadIdx.y;\n\n        if (j > j_max) {\n            return;\n        }\n\n#pragma unroll\n        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n            const int i = i0 + threadIdx.x;\n\n            if (need_check && i > i_max) {\n                continue;\n            }\n\n            dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];\n        }\n    }\n}\n\ntemplate<ggml_type type, int mmq_x, int mmq_y, bool need_check>\nstatic __device__ __forceinline__ void mmq_write_back_mma(\n        const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,\n        const int stride, const int i_max, const int j_max) {\n\n    constexpr int granularity = mmq_get_granularity_device(mmq_x);\n    constexpr int nwarps = mmq_get_nwarps_device();\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    constexpr int tileC_IJ = mmq_get_granularity_device(0);\n    typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;\n    constexpr int rows_per_warp = granularity;\n#else\n    typedef tile<16, 8, int> tile_C;\n    constexpr int rows_per_warp = 2 * granularity;\n#endif // defined(AMD_MFMA_AVAILABLE)\n    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.\n\n    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);\n#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    static_assert(nwarps*tile_C::I == mmq_y, \"nwarps*tile_C::I != mmq_y\");\n#else\n    GGML_UNUSED(nwarps);\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n\n#pragma unroll\n    for (int j0 = 0; j0 < 
mmq_x; j0 += ntx*tile_C::J) {\n#pragma unroll\n        for (int n = 0; n < ntx; ++n) {\n#pragma unroll\n            for (int l = 0; l < tile_C::ne; ++l) {\n                const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);\n\n                if (j > j_max) {\n                    continue;\n                }\n\n                const int i = i0 + n*tile_C::I + tile_C::get_i(l);\n\n                if (need_check && i > i_max) {\n                    continue;\n                }\n\n                dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];\n            }\n        }\n    }\n}\n\n// -------------------------------------------------------------------------------------------------------------------------------------\n\ntemplate <int mmq_x, int mmq_y, bool need_check, ggml_type type>\nstruct mmq_type_traits;\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {\n    static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {\n    static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {\n    static constexpr int              vdr          = 
VDR_Q5_0_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {\n    static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {\n    static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {\n    static constexpr int              vdr          = VDR_MXFP4_Q8_1_MMQ;\n#ifdef BLACKWELL_MMA_AVAILABLE\n    static constexpr load_tiles_mmq_t load_tiles  = load_tiles_mxfp4_fp4<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;\n#else\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;\n#endif // BLACKWELL_MMA_AVAILABLE\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = 
vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {\n    static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {\n    static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {\n    static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {\n    static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct 
mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {\n    static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {\n    static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {\n    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {\n    static constexpr int              vdr          = VDR_IQ2_S_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {\n   
 static constexpr int              vdr          = VDR_IQ3_XXS_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {\n    static constexpr int              vdr          = VDR_IQ3_S_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {\n    static constexpr int              vdr          = VDR_IQ1_S_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {\n    static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <int mmq_x, int mmq_y, bool need_check>\nstruct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {\n    static constexpr int          
    vdr          = VDR_IQ4_XS_Q8_1_MMQ;\n    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs<mmq_y, need_check>;\n    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;\n    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;\n};\n\ntemplate <ggml_type type, int mmq_x, bool need_check, bool fixup>\nstatic __device__ __forceinline__ void mul_mat_q_process_tile(\n        const char * __restrict__ x, const int offset_x, const int * __restrict__ y,\n        const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,\n        const int stride_row_x, const int ncols_y, const int stride_col_dst,\n        const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {\n\n    constexpr int              warp_size  = ggml_cuda_get_physical_warp_size();\n    constexpr int              nwarps     = mmq_get_nwarps_device();\n    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;\n    constexpr int              mmq_y      = get_mmq_y_device();\n    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;\n\n    extern __shared__ int data_mul_mat_q[];\n    int * tile_y = data_mul_mat_q + mmq_x;\n    int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);\n\n#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)\n    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;\n    constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;\n#else\n    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;\n    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;\n#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || 
defined(AMD_WMMA_AVAILABLE)\n\n#if defined(BLACKWELL_MMA_AVAILABLE)\n    // FP4 tile stores 8 blocks\n    constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;\n#else\n    constexpr int ne_block = 4 * QK8_1;\n#endif  // defined(BLACKWELL_MMA_AVAILABLE)\n\n    constexpr int ITER_K          = get_iter_k(type);\n    constexpr int blocks_per_iter = ITER_K / qk;\n\n    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};\n\n    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);\n\n    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {\n        load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);\n        {\n            const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;\n#pragma unroll\n            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {\n                int l = l0 + threadIdx.y*warp_size + threadIdx.x;\n\n                tile_y[l] = by0[l];\n            }\n        }\n\n        __syncthreads();\n\n        vec_dot(tile_x, tile_y, sum, 0);\n\n        __syncthreads();\n\n        {\n            const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);\n#pragma unroll\n            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {\n                int l = l0 + threadIdx.y*warp_size + threadIdx.x;\n\n                tile_y[l] = by0[l];\n            }\n        }\n\n        __syncthreads();\n\n        vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);\n\n        __syncthreads();\n    }\n\n    if (fixup) {\n        write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);\n    } else {\n        write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);\n    }\n}\n\n\n// The mul_mat_q kernel implements \"stream-k\" work partitioning as described in https://arxiv.org/abs/2301.03598\n\ntemplate <ggml_type type, int mmq_x, bool need_check>\n#if defined(GGML_USE_HIP)\n#if defined(RDNA4) || defined(RDNA3) || 
defined(RDNA2) || defined(CDNA) || defined(GCN)\n    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)\n#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)\n#else\n#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)\n#else\n    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)\n#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA\n#endif // defined(GGML_USE_HIP)\nstatic __global__ void mul_mat_q(\n        const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,\n        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,\n        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,\n        const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,\n        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,\n        const int ncols_max) {\n\n    // Skip unused template specializations for faster compilation:\n    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {\n        NO_DEVICE_CODE;\n        return;\n    }\n\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr int qk    = ggml_cuda_type_traits<type>::qk;\n    constexpr int mmq_y = get_mmq_y_device();\n\n    const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x\n    const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y\n\n    // Initialize the ids for writing back data with just the index.\n    // For regular matrix multiplications this is never changed.\n    // For MoE the correct 
indices are loaded from ids_dst.\n    extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {\n        const int j = j0 + threadIdx.y*warp_size + threadIdx.x;\n\n        if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {\n            break;\n        }\n\n        ids_dst_shared[j] = j;\n    }\n    __syncthreads();\n\n    // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:\n#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA\n    {\n        const int wt = blockIdx.z / nchannels_y;\n        const int zt = blockIdx.z - wt*nchannels_y;\n        const int jt = blockIdx.y;\n        const int it = blockIdx.x;\n\n        // Defaults for regular matrix multiplication:\n        int col_low    = 0;\n        int col_high   = ncols_dst;\n        int col_diff   = ncols_dst;\n        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;\n        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;\n\n        if (ids_dst) {\n            col_low  = expert_bounds[zt + 0];\n            col_high = expert_bounds[zt + 1];\n            col_diff = col_high - col_low;\n\n            offset_y   = 0;\n            offset_dst = 0;\n\n            if (jt*mmq_x >= col_diff) {\n                return;\n            }\n\n            // __syncthreads(); // There is no previous tile that could cause a race condition.\n#pragma unroll\n            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {\n                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;\n\n                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {\n                    break;\n                }\n\n                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];\n            }\n            __syncthreads();\n        }\n\n        offset_y   += (col_low + 
jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));\n        offset_dst += it*mmq_y;\n\n        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;\n        const int tile_y_max_j = col_diff - jt*mmq_x - 1;\n\n        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;\n\n        constexpr bool fixup = false;\n        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>\n            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,\n             tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);\n        return;\n    }\n#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA\n\n    constexpr int ITER_K = get_iter_k(type);\n\n    const     int64_t blocks_per_ne00 = ncols_x / qk;\n    constexpr int     blocks_per_iter = ITER_K / qk;\n\n    // kbc == k block continuous, current index in continuous ijk space.\n    int64_t kbc      = (int64_t) blockIdx.x     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;\n    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;\n\n    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;\n    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;\n\n    // kb0 == k index when doing the matrix multiplication for an output tile.\n    int kb0_start = kbc % blocks_per_ne00;\n    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);\n    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {\n        int tmp = kbc;\n        const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);\n        tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);\n        const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);\n        tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);\n        const int zt = tmp / (ntx*blocks_per_ne00);\n        tmp -= zt * (ntx*blocks_per_ne00);\n        const int jt = tmp / 
blocks_per_ne00;\n\n        // Defaults for regular matrix multiplication:\n        int col_low    = 0;\n        int col_high   = ncols_dst;\n        int col_diff   = ncols_dst;\n        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;\n        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;\n\n        if (ids_dst) {\n            col_low  = expert_bounds[zt + 0];\n            col_high = expert_bounds[zt + 1];\n            col_diff = col_high - col_low;\n\n            offset_y   = 0;\n            offset_dst = 0;\n\n            if (jt*mmq_x >= col_diff) {\n                kbc += blocks_per_ne00;\n                kbc -= kbc % blocks_per_ne00;\n\n                kb0_start = 0;\n                kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);\n\n                continue;\n            }\n\n            __syncthreads();\n#pragma unroll\n            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {\n                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;\n\n                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {\n                    break;\n                }\n\n                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];\n            }\n            __syncthreads();\n        }\n\n        offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));\n        offset_dst += it*mmq_y;\n\n        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;\n        const int tile_y_max_j = col_diff - jt*mmq_x - 1;\n\n        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;\n\n        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.\n        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>\n            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,\n             
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);\n\n        kbc += blocks_per_ne00;\n        kbc -= kbc % blocks_per_ne00;\n\n        kb0_start = 0;\n        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);\n    }\n\n    if (kbc >= kbc_stop) {\n        return;\n    }\n\n    int tmp = kbc;\n    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);\n    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);\n    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);\n    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);\n    const int zt = tmp / (ntx*blocks_per_ne00);\n    tmp -= zt * (ntx*blocks_per_ne00);\n    const int jt = tmp / blocks_per_ne00;\n\n    // Defaults for regular matrix multiplication:\n    int col_low    = 0;\n    int col_high   = ncols_dst;\n    int col_diff   = ncols_dst;\n    int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;\n    int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;\n\n    if (ids_dst) {\n        col_low  = expert_bounds[zt + 0];\n        col_high = expert_bounds[zt + 1];\n        col_diff = col_high - col_low;\n\n        offset_y   = 0;\n        offset_dst = 0;\n\n        if (jt*mmq_x >= col_diff) {\n            return;\n        }\n\n        // The memory layout for the fixup buffer is always contiguous, therefore reset ids:\n        __syncthreads();\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {\n            const int j = j0 + threadIdx.y*warp_size + threadIdx.x;\n\n            if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {\n                break;\n            }\n\n            ids_dst_shared[j] = j;\n        }\n        __syncthreads();\n    }\n\n    offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));\n    offset_dst += it*mmq_y;\n\n    const int tile_x_max_i = nrows_x  - it*mmq_y - 1;\n    const int tile_y_max_j = col_diff - jt*mmq_x - 1;\n\n    const int offset_x = (wt/sample_ratio)*stride_sample_x + 
(zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;\n\n    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.\n    mul_mat_q_process_tile<type, mmq_x, need_check, fixup>\n        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,\n         tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);\n}\n\ntemplate <ggml_type type, int mmq_x, bool need_check>\nstatic __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,\n                                                const int32_t * expert_bounds,\n                                                float * __restrict__ dst,\n                                                const float * __restrict__ tmp_last_tile,\n                                                const int    ncols_x,\n                                                const int    nrows_x,\n                                                const int    ncols_dst,\n                                                const size_t stride_col_dst,\n                                                const int    nchannels_y,\n                                                const size_t stride_channel_dst,\n                                                const int    nsamples_y,\n                                                const size_t stride_sample_dst,\n                                                const int    ncols_max) {\n    constexpr int     mmq_y           = get_mmq_y_device();\n    constexpr int     qk              = ggml_cuda_type_traits<type>::qk;\n    constexpr int     ITER_K          = get_iter_k(type);\n\n    constexpr int     blocks_per_iter = ITER_K / qk;\n    const     int64_t blocks_per_ne00 = ncols_x / qk;\n\n    constexpr int nwarps = mmq_get_nwarps_device();\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};\n\n    const int ntx  = 
(ncols_max + mmq_x - 1) / mmq_x;\n    const int nty  = (nrows_x   + mmq_y - 1) / mmq_y;\n\n    const int bidx0 = blockIdx.x;\n\n    // kbc == k block continuous, current index in continuous ijk space.\n    int64_t kbc0      = (int64_t) bidx0     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;\n    int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;\n\n    kbc0      -= (kbc0      % blocks_per_ne00) % blocks_per_iter;\n    kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;\n\n    const bool did_not_have_any_data   = kbc0 == kbc0_stop;\n    const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;\n    const bool did_not_write_last      = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;\n    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {\n        return;\n    }\n\n    bool any_fixup = false;\n\n    // Iterate over previous blocks and sum up partial sums written to fixup buffer.\n    // All CUDA blocks that get here must have a previous block that needs a fixup.\n    int64_t bidx = bidx0 - 1;\n    int64_t kbc_stop = kbc0;\n    while(true) {\n        int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;\n        kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;\n\n        if (kbc == kbc_stop) { // Did not have any data.\n            bidx--;\n            kbc_stop = kbc;\n            continue;\n        }\n\n        any_fixup = true;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];\n            }\n        }\n\n        // If this block started in a previous tile we are done and don't need 
to combine additional partial results.\n        if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {\n            break;\n        }\n        bidx--;\n        kbc_stop = kbc;\n    }\n\n    if (!any_fixup) {\n        return;\n    }\n\n    int tmp = kbc0;\n    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);\n    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);\n    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);\n    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);\n    const int zt = tmp / (ntx*blocks_per_ne00);\n    tmp -= zt * (ntx*blocks_per_ne00);\n    const int jt = tmp / blocks_per_ne00;\n\n    if (!ids_dst) {\n        const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;\n        dst += offset_dst;\n\n        const int i_max = nrows_x   - it*mmq_y - 1;\n        const int j_max = ncols_dst - jt*mmq_x - 1;\n\n#pragma unroll\n        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n            const int j = j0 + threadIdx.y;\n\n            if (j > j_max) {\n                return;\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n                const int i = i0 + threadIdx.x;\n\n                if (need_check && i > i_max) {\n                    continue;\n                }\n\n                dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];\n            }\n        }\n        return;\n    }\n\n    __shared__ int ids_dst_shared[mmq_x];\n    const int col_low  = expert_bounds[zt + 0];\n    const int col_high = expert_bounds[zt + 1];\n    const int col_diff = col_high - col_low;\n\n    for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {\n        ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];\n    }\n    __syncthreads();\n\n    const int offset_dst = it*mmq_y;\n    dst += offset_dst;\n\n    const int i_max = nrows_x  - it*mmq_y - 1;\n    const 
int j_max = col_diff - jt*mmq_x - 1;\n\n#pragma unroll\n    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {\n        const int j = j0 + threadIdx.y;\n\n        if (j > j_max) {\n            return;\n        }\n\n#pragma unroll\n        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {\n            const int i = i0 + threadIdx.x;\n\n            if (need_check && i > i_max) {\n                continue;\n            }\n\n            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];\n        }\n    }\n}\n\nstruct mmq_args {\n    const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;\n    int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;\n    int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;\n    int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;\n    bool use_stream_k; int64_t ncols_max;\n};\n\ntemplate<ggml_type type>\nstatic size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {\n    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);\n    const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);\n    const size_t nbs_ids = mmq_x*sizeof(int);\n    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? 
mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);\n    const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));\n    return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));\n}\n\ntemplate <ggml_type type, int mmq_x>\nstatic void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {\n    const int id = ggml_cuda_get_device();\n    const int cc = ggml_cuda_info().devices[id].cc;\n    const int nsm = ggml_cuda_info().devices[id].nsm;\n    const int warp_size = ggml_cuda_info().devices[id].warp_size;\n    const int nwarps = mmq_get_nwarps_host(cc, warp_size);\n    const int mmq_y = get_mmq_y_host(cc);\n\n    const dim3 block_dims(warp_size, nwarps, 1);\n\n    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);\n\n    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);\n    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x,  true>), nbytes_shared);\n\n    const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;\n    const int ntx  = (args.ncols_max + mmq_x - 1) / mmq_x;\n    const int ntzw = args.nchannels_y * args.nsamples_y;\n    const dim3 block_nums_xy_tiling(nty, ntx, ntzw);\n\n    GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);\n    GGML_ASSERT(args.nsamples_y  % args.nsamples_x  == 0);\n    const int channel_ratio = args.nchannels_y / args.nchannels_x;\n    const int sample_ratio  = args.nsamples_y  / args.nsamples_x;\n\n    if (!args.use_stream_k) {\n        if (args.nrows_x % mmq_y == 0) {\n            constexpr bool need_check = false;\n            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>\n                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,\n                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,\n                 channel_ratio, args.nchannels_y, 
args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,\n                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,\n                 args.ncols_max);\n        } else {\n            constexpr bool need_check = true;\n            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>\n                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,\n                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,\n                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,\n                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,\n                 args.ncols_max);\n        }\n        return;\n    }\n\n    const dim3 block_nums_stream_k(nsm, 1, 1);\n    const bool fixup_needed = ntx*nty*ntzw % nsm != 0;\n\n    ggml_cuda_pool & pool = ctx.pool(id);\n    ggml_cuda_pool_alloc<float> tmp_fixup(pool);\n    if (fixup_needed) {\n        tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);\n    }\n\n    if (args.nrows_x % mmq_y == 0) {\n        constexpr bool need_check = false;\n        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>\n            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,\n             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,\n             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,\n             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,\n             args.ncols_max);\n\n        if (!fixup_needed) {\n            return;\n        }\n\n        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, 
block_dims, 0, stream>>>\n            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,\n             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,\n             args.ncols_max);\n    } else {\n        constexpr bool need_check = true;\n        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>\n            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,\n             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,\n             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,\n             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,\n             args.ncols_max);\n\n        if (!fixup_needed) {\n            return;\n        }\n\n        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>\n            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,\n             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,\n             args.ncols_max);\n    }\n}\n\ntemplate <ggml_type type>\nvoid mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {\n    const int    id     = ggml_cuda_get_device();\n    const int    cc     = ggml_cuda_info().devices[id].cc;\n    const size_t smpbo  = ggml_cuda_info().devices[id].smpbo;\n    const int warp_size = ggml_cuda_info().devices[id].warp_size;\n    const int nwarps    = mmq_get_nwarps_host(cc, warp_size);\n\n    const int mmq_x_max = get_mmq_x_max_host(cc);\n    const int mmq_y = get_mmq_y_host(cc);\n\n    int mmq_x_best  = 0;\n    int ntiles_x_best = INT_MAX;\n\n    for (int mmq_x = 8; mmq_x <= mmq_x_max && 
ntiles_x_best > 1; mmq_x += 8) {\n        const int granularity = mmq_get_granularity_host(mmq_x, cc);\n\n        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {\n            continue;\n        }\n\n        const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;\n\n        if (ntiles_x < ntiles_x_best) {\n            mmq_x_best = mmq_x;\n            ntiles_x_best = ntiles_x;\n        }\n    }\n\n    switch (mmq_x_best) {\n        case   8:\n            launch_mul_mat_q<type,   8>(ctx, args, stream);\n            break;\n        case  16:\n            launch_mul_mat_q<type,  16>(ctx, args, stream);\n            break;\n        case  24:\n            launch_mul_mat_q<type,  24>(ctx, args, stream);\n            break;\n        case  32:\n            launch_mul_mat_q<type,  32>(ctx, args, stream);\n            break;\n        case  40:\n            launch_mul_mat_q<type,  40>(ctx, args, stream);\n            break;\n        case  48:\n            launch_mul_mat_q<type,  48>(ctx, args, stream);\n            break;\n        case  56:\n            launch_mul_mat_q<type,  56>(ctx, args, stream);\n            break;\n        case  64:\n            launch_mul_mat_q<type,  64>(ctx, args, stream);\n            break;\n        case  72:\n            launch_mul_mat_q<type,  72>(ctx, args, stream);\n            break;\n        case  80:\n            launch_mul_mat_q<type,  80>(ctx, args, stream);\n            break;\n        case  88:\n            launch_mul_mat_q<type,  88>(ctx, args, stream);\n            break;\n        case  96:\n            launch_mul_mat_q<type,  96>(ctx, args, stream);\n            break;\n        case 104:\n            launch_mul_mat_q<type, 104>(ctx, args, stream);\n            break;\n        case 112:\n            launch_mul_mat_q<type, 112>(ctx, args, stream);\n            break;\n        case 120:\n            launch_mul_mat_q<type, 120>(ctx, args, stream);\n            break;\n        
case 128:\n            launch_mul_mat_q<type, 128>(ctx, args, stream);\n            break;\n        default:\n            fprintf(stderr, \"mmq_x_best=%d\\n\", mmq_x_best);\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n\n#define DECL_MMQ_CASE(type)                                                        \\\n    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \\\n\nextern DECL_MMQ_CASE(GGML_TYPE_Q4_0);\nextern DECL_MMQ_CASE(GGML_TYPE_Q4_1);\nextern DECL_MMQ_CASE(GGML_TYPE_Q5_0);\nextern DECL_MMQ_CASE(GGML_TYPE_Q5_1);\nextern DECL_MMQ_CASE(GGML_TYPE_Q8_0);\nextern DECL_MMQ_CASE(GGML_TYPE_MXFP4);\nextern DECL_MMQ_CASE(GGML_TYPE_Q2_K);\nextern DECL_MMQ_CASE(GGML_TYPE_Q3_K);\nextern DECL_MMQ_CASE(GGML_TYPE_Q4_K);\nextern DECL_MMQ_CASE(GGML_TYPE_Q5_K);\nextern DECL_MMQ_CASE(GGML_TYPE_Q6_K);\nextern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);\nextern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);\nextern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);\nextern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);\nextern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);\nextern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);\nextern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);\nextern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);\n\n// -------------------------------------------------------------------------------------------------------------------------\n\nvoid ggml_cuda_mul_mat_q(\n        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);\n\nvoid ggml_cuda_op_mul_mat_q(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,\n    const int64_t src1_padded_row_size, cudaStream_t stream);\n\nbool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);\n"
  },
  {
    "path": "src/ggml-cuda/mmvf.cu",
    "content": "#include \"ggml.h\"\n#include \"common.cuh\"\n#include \"unary.cuh\"\n#include \"mmvf.cuh\"\n#include \"convert.cuh\"\n\ntemplate <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>\nstatic __global__ void mul_mat_vec_f(\n        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,\n        const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,\n        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,\n        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,\n        const int ids_stride) {\n    const int row         = blockIdx.x;\n    // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)\n    const int channel_dst = blockIdx.y;\n    const int tid         = threadIdx.x;\n\n    int token_idx;\n    int channel_x;\n    int channel_y;\n    int sample_dst;\n\n    if constexpr (is_multi_token_id) {\n        // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case\n        token_idx  = blockIdx.z;\n        channel_x  = ids[channel_dst + token_idx * ids_stride];\n        channel_y  = fastmodulo(channel_dst, nchannels_y);\n        sample_dst = 0;\n    } else {\n        token_idx  = ids ? blockIdx.z                                          : 0;\n        channel_x  = ids ? ids[blockIdx.y + token_idx * ids_stride]            : fastdiv((uint32_t) channel_dst, channel_ratio);\n        channel_y  = ids ? fastmodulo(blockIdx.y, nchannels_y)                 : channel_dst;\n        sample_dst = ids ? 
0                                                   : blockIdx.z;\n    }\n\n    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);\n    const int sample_y    = sample_dst;\n\n    constexpr int warp_size   = ggml_cuda_get_physical_warp_size();\n\n    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;\n    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;\n    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;\n    if constexpr (is_multi_token_id) {\n        y   += token_idx*stride_col_y2*2;\n        dst += token_idx*stride_col_dst;\n    }\n\n    bool use_gate = false;\n    bool use_bias = false;\n    bool use_gate_bias = false;\n    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;\n    const T * gate_x = nullptr;\n    const float * x_bias = nullptr;\n    const float * gate_bias = nullptr;\n\n    if constexpr (has_fusion) {\n        use_gate = fusion.gate != nullptr;\n        use_bias = fusion.x_bias != nullptr;\n        use_gate_bias = fusion.gate_bias != nullptr;\n        glu_op = fusion.glu_op;\n\n        if (use_gate) {\n            gate_x = static_cast<const T *>(fusion.gate);\n        }\n        if (use_bias) {\n            x_bias = static_cast<const float *>(fusion.x_bias);\n        }\n        if (use_gate_bias) {\n            gate_bias = static_cast<const float *>(fusion.gate_bias);\n            use_gate_bias = use_gate;\n        } else {\n            use_gate_bias = false;\n        }\n    }\n\n    if (use_gate) {\n        gate_x += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;\n    }\n\n    const int channel_bias = ids ? 
channel_x : channel_dst;\n\n    if constexpr (has_fusion) {\n        if (use_bias) {\n            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;\n        }\n        if (use_gate_bias) {\n            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;\n        }\n    }\n\n    const float2 * y2 = (const float2 *) y;\n\n    extern __shared__ char data_mmv[];\n    float * buf_iw = (float *) data_mmv;\n    float * buf_iw_gate = nullptr;\n    if constexpr (has_fusion) {\n        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));\n    }\n\n    if (block_size > warp_size) {\n        if (tid < warp_size) {\n            buf_iw[tid] = 0.0f;\n            if constexpr (has_fusion) {\n                if (use_gate) {\n                    buf_iw_gate[tid] = 0.0f;\n                }\n            }\n        }\n        __syncthreads();\n    }\n\n    float sumf[ncols_dst] = {0.0f};\n    float sumf_gate[ncols_dst];\n    if constexpr (has_fusion) {\n#pragma unroll\n        for (int j = 0; j < ncols_dst; ++j) {\n            sumf_gate[j] = 0.0f;\n        }\n    }\n\n    if constexpr (std::is_same_v<T, float>) {\n        const float2 * x2 = (const float2 *) x;\n        const float2 * gate_x2 = nullptr;\n        if constexpr (has_fusion) {\n            if (use_gate) {\n                gate_x2 = (const float2 *) gate_x;\n            }\n        }\n\n        for (int col2 = tid; col2 < ncols2; col2 += block_size) {\n            const float2 tmpx = x2[col2];\n            float2 tmpx_gate = make_float2(0.0f, 0.0f);\n            if constexpr (has_fusion) {\n                if (use_gate) {\n                    tmpx_gate = gate_x2[col2];\n                }\n            }\n\n#pragma unroll\n            for (int j = 0; j < ncols_dst; ++j) {\n                const float2 tmpy = y2[j*stride_col_y2 + col2];\n                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);\n                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);\n\n  
              if constexpr (has_fusion) {\n                    if (use_gate) {\n                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);\n                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);\n                    }\n                }\n            }\n        }\n    } else if constexpr (std::is_same_v<T, half>) {\n        const half2 * x2 = (const half2 *) x;\n        const half2 * gate_x2 = nullptr;\n        if constexpr (has_fusion) {\n            if (use_gate) {\n                gate_x2 = (const half2 *) gate_x;\n            }\n        }\n\n        if (std::is_same_v<type_acc, float>) {\n            for (int col2 = tid; col2 < ncols2; col2 += block_size) {\n                const float2 tmpx = __half22float2(x2[col2]);\n                float2 tmpx_gate = make_float2(0.0f, 0.0f);\n                if constexpr (has_fusion) {\n                    if (use_gate) {\n                        tmpx_gate = __half22float2(gate_x2[col2]);\n                    }\n                }\n#pragma unroll\n                for (int j = 0; j < ncols_dst; ++j) {\n                    const float2 tmpy = y2[j*stride_col_y2 + col2];\n                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);\n                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);\n\n                    if constexpr (has_fusion) {\n                        if (use_gate) {\n                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);\n                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);\n                        }\n                    }\n                }\n            }\n        } else {\n#ifdef FP16_AVAILABLE\n            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};\n            half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};\n\n            for (int col2 = tid; col2 < ncols2; col2 += block_size) {\n                const half2 tmpx = x2[col2];\n                half2 tmpx_gate = make_half2(0.0f, 0.0f);\n                if constexpr (has_fusion) {\n       
             if (use_gate) {\n                        tmpx_gate = gate_x2[col2];\n                    }\n                }\n#pragma unroll\n                for (int j = 0; j < ncols_dst; ++j) {\n                    const float2 tmpy = y2[j*stride_col_y2 + col2];\n                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);\n\n                    if constexpr (has_fusion) {\n                        if (use_gate) {\n                            sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);\n                        }\n                    }\n                }\n            }\n\n#pragma unroll\n            for (int j = 0; j < ncols_dst; ++j) {\n                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);\n            }\n\n            if constexpr (has_fusion) {\n                if (use_gate) {\n#pragma unroll\n                    for (int j = 0; j < ncols_dst; ++j) {\n                        sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);\n                    }\n                }\n            }\n#else\n            NO_DEVICE_CODE;\n#endif // FP16_AVAILABLE\n        }\n    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {\n//TODO: add support for ggml_cuda_mad for hip_bfloat162\n#if defined(GGML_USE_HIP)\n        const int * x2 = (const int *) x;\n        const int * gate_x2 = nullptr;\n        if constexpr (has_fusion) {\n            if (use_gate) {\n                gate_x2 = (const int *) gate_x;\n            }\n        }\n        for (int col2 = tid; col2 < ncols2; col2 += block_size) {\n            const int tmpx = x2[col2];\n            int tmpx_gate = 0;\n            if constexpr (has_fusion) {\n                if (use_gate) {\n                    tmpx_gate = gate_x2[col2];\n                }\n            }\n#pragma unroll\n            for (int j = 0; j < ncols_dst; ++j) {\n                const float2 tmpy = y2[j*stride_col_y2 + col2];\n                const float tmpx0 = 
ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);\n                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);\n                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);\n                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);\n\n                if constexpr (has_fusion) {\n                    if (use_gate) {\n                        const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);\n                        const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);\n                        ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);\n                        ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);\n                    }\n                }\n            }\n        }\n#else\n        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;\n        const nv_bfloat162 * gate_x2 = nullptr;\n        if constexpr (has_fusion) {\n            if (use_gate) {\n                gate_x2 = (const nv_bfloat162 *) gate_x;\n            }\n        }\n        for (int col2 = tid; col2 < ncols2; col2 += block_size) {\n            const nv_bfloat162 tmpx = x2[col2];\n            nv_bfloat162 tmpx_gate;\n            if constexpr (has_fusion) {\n                if (use_gate) {\n                    tmpx_gate = gate_x2[col2];\n                }\n            }\n#pragma unroll\n            for (int j = 0; j < ncols_dst; ++j) {\n                const float2 tmpy = y2[j*stride_col_y2 + col2];\n                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);\n                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);\n\n                if constexpr (has_fusion) {\n                    if (use_gate) {\n                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);\n                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);\n                    }\n                }\n            }\n        }\n#endif\n    } else {\n      
  static_assert(std::is_same_v<T, void>, \"unsupported type\");\n    }\n\n#pragma unroll\n    for (int j = 0; j < ncols_dst; ++j) {\n        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);\n\n        if constexpr (has_fusion) {\n            if (use_gate) {\n                sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);\n            }\n        }\n\n        if (block_size > warp_size) {\n            buf_iw[tid/warp_size] = sumf[j];\n            if constexpr (has_fusion) {\n                if (use_gate) {\n                    buf_iw_gate[tid/warp_size] = sumf_gate[j];\n                }\n            }\n            __syncthreads();\n            if (tid < warp_size) {\n                sumf[j] = buf_iw[tid];\n                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);\n                if constexpr (has_fusion) {\n                    if (use_gate) {\n                        sumf_gate[j] = buf_iw_gate[tid];\n                        sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);\n                    }\n                }\n            }\n\n            if (j < ncols_dst) {\n                __syncthreads();\n            }\n        }\n    }\n\n    if (tid >= ncols_dst) {\n        return;\n    }\n\n    float value = sumf[tid];\n\n    if constexpr (has_fusion) {\n        if (use_bias) {\n            value += x_bias[tid*stride_col_dst + row];\n        }\n\n        if (use_gate) {\n            float gate_value = sumf_gate[tid];\n            if (use_gate_bias) {\n                gate_value += gate_bias[tid*stride_col_dst + row];\n            }\n            switch (glu_op) {\n                case GGML_GLU_OP_SWIGLU:\n                    value *= ggml_cuda_op_silu_single(gate_value);\n                    break;\n                case GGML_GLU_OP_GEGLU:\n                    value *= ggml_cuda_op_gelu_single(gate_value);\n                    break;\n                case GGML_GLU_OP_SWIGLU_OAI: {\n                    value = ggml_cuda_op_swiglu_oai_single(gate_value, 
value);\n                    break;\n                }\n                default:\n                    break;\n            }\n        }\n    }\n\n    dst[tid*stride_col_dst + row] = value;\n\n    if constexpr (!has_fusion) {\n        GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);\n    }\n}\n\ntemplate<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>\nstatic void mul_mat_vec_f_switch_fusion(\n        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,\n        const int64_t ncols, const uint3 nchannels_y,\n        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,\n        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,\n        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,\n        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {\n\n    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;\n    if constexpr (ncols_dst == 1) {\n        if (has_fusion) {\n            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>\n                (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,\n                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);\n            return;\n       }\n    }\n\n    GGML_ASSERT(!has_fusion && \"fusion only supported for ncols_dst=1\");\n\n    mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>\n        
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,\n        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);\n\n}\n\ntemplate <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>\nvoid launch_mul_mat_vec_f_cuda(\n        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,\n        const int64_t ncols, const int64_t nrows,\n        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,\n        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,\n        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\n        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,\n        const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {\n    GGML_ASSERT(ncols        % 2 == 0);\n    GGML_ASSERT(stride_row   % 2 == 0);\n    GGML_ASSERT(stride_col_y % 2 == 0);\n    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);\n    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);\n    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);\n    const uint3 channel_ratio_fd = ids ? 
make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);\n    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);\n\n    const int device = ggml_cuda_get_device();\n    const int warp_size = ggml_cuda_info().devices[device].warp_size;\n\n    int64_t block_size_best = warp_size;\n    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);\n    int64_t max_block_size  = 256;\n    if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {\n        max_block_size = 128;\n    }\n    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {\n        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);\n        if (niter < niter_best) {\n            niter_best      = niter;\n            block_size_best = block_size;\n        }\n    }\n\n    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;\n\n    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? 
warp_size*sizeof(float) : 0);\n    const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);\n    const dim3 block_dims(block_size_best, 1, 1);\n    switch (block_size_best) {\n        case   32: {\n            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>\n                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);\n        } break;\n        case   64: {\n            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>\n                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);\n        } break;\n        case   96: {\n            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>\n                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);\n        } break;\n        case  128: {\n            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>\n                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, 
stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);\n        } break;\n        case  160: {\n            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>\n                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);\n        } break;\n        case  192: {\n            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>\n                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);\n        } break;\n        case  224: {\n            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>\n                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);\n        } break;\n        case  256: {\n            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>\n                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, 
ids_stride, stream);\n        } break;\n        default: {\n            GGML_ABORT(\"fatal error\");\n        } break;\n    }\n}\n\ntemplate <typename T, typename type_acc>\nstatic void mul_mat_vec_f_cuda_switch_ncols_dst(\n        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,\n        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,\n        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,\n        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,\n        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\n        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,\n        const int64_t ids_stride, cudaStream_t stream) {\n\n    const bool has_ids = ids != nullptr;\n\n    if (has_ids && ncols_dst > 1) {\n        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below\n        constexpr int c_ncols_dst = 1;\n        launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>\n            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n             ncols_dst, ids_stride, stream);\n        return;\n    }\n\n    if (has_ids) {\n        // Single-token MUL_MAT_ID path\n        constexpr int c_ncols_dst = 1;\n        launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>\n            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n             stride_channel_dst, nsamples_x, nsamples_dst, 
stride_sample_x, stride_sample_y, stride_sample_dst,\n             ncols_dst, ids_stride, stream);\n        return;\n    }\n\n    switch (ncols_dst) {\n        case 1:\n            launch_mul_mat_vec_f_cuda<T, type_acc, 1>\n                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 nsamples_dst, ids_stride, stream);\n            break;\n        case 2:\n            launch_mul_mat_vec_f_cuda<T, type_acc, 2>\n                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 nsamples_dst, ids_stride, stream);\n            break;\n        case 3:\n            launch_mul_mat_vec_f_cuda<T, type_acc, 3>\n                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 nsamples_dst, ids_stride, stream);\n            break;\n        case 4:\n            launch_mul_mat_vec_f_cuda<T, type_acc, 4>\n                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 nsamples_dst, ids_stride, stream);\n            break;\n        case 5:\n           
 launch_mul_mat_vec_f_cuda<T, type_acc, 5>\n                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 nsamples_dst, ids_stride, stream);\n            break;\n        case 6:\n            launch_mul_mat_vec_f_cuda<T, type_acc, 6>\n                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 nsamples_dst, ids_stride, stream);\n            break;\n        case 7:\n            launch_mul_mat_vec_f_cuda<T, type_acc, 7>\n                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 nsamples_dst, ids_stride, stream);\n            break;\n        case 8:\n            launch_mul_mat_vec_f_cuda<T, type_acc, 8>\n                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 nsamples_dst, ids_stride, stream);\n            break;\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n\ntemplate<typename T>\nstatic void mul_mat_vec_f_cuda(\n        const T * x, const float * y, const int32_t * ids, 
const ggml_cuda_mm_fusion_args_device fusion, float * dst,\n        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,\n        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,\n        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,\n        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\n        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,\n        const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {\n\n    if constexpr(std::is_same_v<T, half>) {\n        if (prec == GGML_PREC_DEFAULT) {\n            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>\n                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            return;\n        }\n    }\n    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>\n        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,\n        nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,\n        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n}\n\nvoid ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,\n    const ggml_cuda_mm_fusion_args_host * fusion) {\n    GGML_ASSERT(        src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);\n    GGML_ASSERT(         dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const size_t ts_src0 
= ggml_type_size(src0->type);\n    const size_t ts_src1 = ggml_type_size(src1->type);\n    const size_t ts_dst  = ggml_type_size(dst->type);\n\n    GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);\n    GGML_ASSERT(ne13 == ne3);\n\n    GGML_ASSERT(        nb00       == ts_src0);\n    GGML_ASSERT(        nb10       == ts_src1);\n    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));\n    GGML_ASSERT(        nb0        == ts_dst);\n\n    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;\n    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;\n\n    const float   * src1_d =       (const float   *) src1->data;\n    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;\n    float         *  dst_d =       (float         *)  dst->data;\n\n    ggml_cuda_mm_fusion_args_device fusion_local{};\n\n    if (fusion) {\n        GGML_ASSERT( !ids || dst->ne[2] == 1);\n        GGML_ASSERT(  ids || dst->ne[1] == 1);\n        if (fusion->x_bias) {\n            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);\n            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);\n            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);\n            fusion_local.x_bias = fusion->x_bias->data;\n        }\n        if (fusion->gate) {\n            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));\n            fusion_local.gate = fusion->gate->data;\n        }\n        if (fusion->gate_bias) {\n            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);\n            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);\n            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);\n            fusion_local.gate_bias = fusion->gate_bias->data;\n        }\n        fusion_local.glu_op = fusion->glu_op;\n    }\n\n    const int64_t s01 = src0->nb[1] / ts_src0;\n    const int64_t s11 = src1->nb[1] / ts_src1;\n    const int64_t s1  =  
dst->nb[1] / ts_dst;\n    const int64_t s02 = src0->nb[2] / ts_src0;\n    const int64_t s12 = src1->nb[2] / ts_src1;\n    const int64_t s2  =  dst->nb[2] / ts_dst;\n    const int64_t s03 = src0->nb[3] / ts_src0;\n    const int64_t s13 = src1->nb[3] / ts_src1;\n    const int64_t s3  =  dst->nb[3] / ts_dst;\n\n    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:\n    const int64_t ncols_dst          = ids ? ne2  : ne1;\n    const int64_t nchannels_y        = ids ? ne11 : ne12;\n    const int64_t nchannels_dst      = ids ? ne1  : ne2;\n    const int64_t stride_col_dst     = ids ? s2   : s1;\n    const int64_t stride_col_y       = ids ? s12  : s11;\n    const int64_t stride_channel_dst = ids ? s1   : s2;\n    const int64_t stride_channel_y   = ids ? s11  : s12;\n\n    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;\n\n    switch (src0->type) {\n        case GGML_TYPE_F32: {\n            const float * src0_d = (const float *) src0->data;\n            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,\n                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,\n                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());\n        } break;\n        case GGML_TYPE_F16: {\n            const half * src0_d = (const half *) src0->data;\n            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,\n                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,\n                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());\n        } break;\n        case GGML_TYPE_BF16: {\n            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;\n            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, 
fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,\n                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,\n                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());\n        } break;\n        default:\n            GGML_ABORT(\"unsupported type: %s\", ggml_type_name(src0->type));\n    }\n}\n\nvoid ggml_cuda_op_mul_mat_vec_f(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,\n    const int64_t src1_padded_row_size, cudaStream_t stream) {\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne10 = src1->ne[0];\n    const int64_t ne0  =  dst->ne[0];\n    const int64_t row_diff = row_high - row_low;\n\n    const int id = ggml_cuda_get_device();\n    const int cc = ggml_cuda_info().devices[id].cc;\n    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;\n\n    // ggml_cuda_op provides single, contiguous matrices\n    const int64_t stride_row         = ne00;\n    const int64_t stride_col_y       = ne10;\n    const int64_t stride_col_dst     = id == ctx.device ? 
ne0 : row_diff; // main device has larger memory buffer\n    const int64_t nchannels_x        = 1;\n    const int64_t nchannels_y        = 1;\n    const int64_t nchannels_dst      = 1;\n    const int64_t stride_channel_x   = 0;\n    const int64_t stride_channel_y   = 0;\n    const int64_t stride_channel_dst = 0;\n    const int64_t nsamples_x         = 1;\n    const int64_t nsamples_dst       = 1;\n    const int64_t stride_sample_x    = 0;\n    const int64_t stride_sample_y    = 0;\n    const int64_t stride_sample_dst  = 0;\n\n    ggml_cuda_mm_fusion_args_device empty{};\n    switch (src0->type) {\n        case GGML_TYPE_F32: {\n            const float * src0_d = (const float *) src0_dd_i;\n            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,\n                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);\n        } break;\n        case GGML_TYPE_F16: {\n            const half * src0_d = (const half *) src0_dd_i;\n            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,\n                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);\n        } break;\n        case GGML_TYPE_BF16: {\n            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;\n            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,\n                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                nsamples_x, nsamples_dst, 
stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);\n        } break;\n        default:\n            GGML_ABORT(\"unsupported type: %s\", ggml_type_name(src0->type));\n    }\n\n    GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);\n}\n\nbool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {\n    if (src0_ne[0] % 2 != 0) {\n        return false;\n    }\n\n    const size_t ts = ggml_type_size(type);\n    if (src0_nb[0] != ts) {\n        return false;\n    }\n\n    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:\n    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {\n        if (src0_nb[i] % (2*ts) != 0) {\n            return false;\n        }\n    }\n\n    switch (type) {\n        case GGML_TYPE_F32:\n            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {\n                if (ampere_mma_available(cc)) {\n                    return ne11 <= 3;\n                }\n                if (cc >= GGML_CUDA_CC_TURING) {\n                    return ne11 <= 4;\n                }\n                return ne11 <= 3;\n            } else if (GGML_CUDA_CC_IS_AMD(cc)) {\n                if (fp32_mma_hardware_available(cc)) {\n                    return ne11 <= 3;\n                }\n                return ne11 <= 8;\n            }\n            return ne11 <= 8;\n        case GGML_TYPE_F16:\n            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {\n                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);\n                if (ampere_mma_available(cc)) {\n                    return src0_small && ne11 == 1;\n                }\n                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {\n                    return src0_small && ne11 <= 4;\n                }\n                if (fp16_mma_hardware_available(cc)) {\n                    return src0_small && ne11 <= 3;\n                }\n                return ne11 <= 8;\n            
} else if (GGML_CUDA_CC_IS_AMD(cc)) {\n                if (fp16_mma_hardware_available(cc)) {\n                    if (GGML_CUDA_CC_IS_RDNA3(cc)) {\n                        return ne11 <= 3;\n                    }\n                    if (GGML_CUDA_CC_IS_RDNA4(cc)) {\n                        return ne11 <= 5;\n                    }\n                    return ne11 <= 2;\n                }\n                return ne11 <= 8;\n            }\n            return ne11 <= 8;\n        case GGML_TYPE_BF16:\n            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {\n                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);\n                if (ampere_mma_available(cc)) {\n                    return src0_small && ne11 == 1;\n                }\n                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {\n                    return src0_small && ne11 <= 4;\n                }\n                if (bf16_mma_hardware_available(cc)) {\n                    return src0_small && ne11 <= 3;\n                }\n                return ne11 <= 8;\n            } else if (GGML_CUDA_CC_IS_AMD(cc)) {\n                if (bf16_mma_hardware_available(cc)) {\n                    return ne11 <= 3;\n                }\n                return ne11 <= 8;\n            }\n            return ne11 <= 8;\n        default:\n            return false;\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/mmvf.cuh",
    "content": "#include \"common.cuh\"\n\n#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels.\n\nvoid ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,\n    const ggml_cuda_mm_fusion_args_host * fusion = nullptr);\n\nvoid ggml_cuda_op_mul_mat_vec_f(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,\n    const int64_t src1_padded_row_size, cudaStream_t stream);\n\nbool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11);\n"
  },
  {
    "path": "src/ggml-cuda/mmvq.cu",
    "content": "#include \"mmvq.cuh\"\n#include \"quantize.cuh\"\n#include \"unary.cuh\"\n#include \"vecdotq.cuh\"\n\n#include <cstdint>\n\ntypedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);\n\nstatic constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:    return vec_dot_q4_0_q8_1;\n        case GGML_TYPE_Q4_1:    return vec_dot_q4_1_q8_1;\n        case GGML_TYPE_Q5_0:    return vec_dot_q5_0_q8_1;\n        case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;\n        case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;\n        case GGML_TYPE_MXFP4:   return vec_dot_mxfp4_q8_1;\n        case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;\n        case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;\n        case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;\n        case GGML_TYPE_Q5_K:    return vec_dot_q5_K_q8_1;\n        case GGML_TYPE_Q6_K:    return vec_dot_q6_K_q8_1;\n        case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;\n        case GGML_TYPE_IQ2_XS:  return vec_dot_iq2_xs_q8_1;\n        case GGML_TYPE_IQ2_S:   return vec_dot_iq2_s_q8_1;\n        case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;\n        case GGML_TYPE_IQ1_S:   return vec_dot_iq1_s_q8_1;\n        case GGML_TYPE_IQ1_M:   return vec_dot_iq1_m_q8_1;\n        case GGML_TYPE_IQ4_NL:  return vec_dot_iq4_nl_q8_1;\n        case GGML_TYPE_IQ4_XS:  return vec_dot_iq4_xs_q8_1;\n        case GGML_TYPE_IQ3_S:   return vec_dot_iq3_s_q8_1;\n        default:                return nullptr;\n    }\n}\n\nstatic constexpr __device__ int get_vdr_mmvq(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:    return VDR_Q4_0_Q8_1_MMVQ;\n        case GGML_TYPE_Q4_1:    return VDR_Q4_1_Q8_1_MMVQ;\n        case GGML_TYPE_Q5_0:    return VDR_Q5_0_Q8_1_MMVQ;\n        case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;\n        case GGML_TYPE_Q8_0:   
 return VDR_Q8_0_Q8_1_MMVQ;\n        case GGML_TYPE_MXFP4:   return VDR_MXFP4_Q8_1_MMVQ;\n        case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;\n        case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;\n        case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;\n        case GGML_TYPE_Q5_K:    return VDR_Q5_K_Q8_1_MMVQ;\n        case GGML_TYPE_Q6_K:    return VDR_Q6_K_Q8_1_MMVQ;\n        case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;\n        case GGML_TYPE_IQ2_XS:  return VDR_IQ2_XS_Q8_1_MMVQ;\n        case GGML_TYPE_IQ2_S:   return VDR_IQ2_S_Q8_1_MMVQ;\n        case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;\n        case GGML_TYPE_IQ3_S:   return VDR_IQ3_S_Q8_1_MMVQ;\n        case GGML_TYPE_IQ4_NL:  return VDR_IQ4_NL_Q8_1_MMVQ;\n        case GGML_TYPE_IQ4_XS:  return VDR_IQ4_XS_Q8_1_MMVQ;\n        default:                return 1;\n    }\n}\n\nenum mmvq_parameter_table_id {\n    MMVQ_PARAMETERS_GENERIC = 0,\n    MMVQ_PARAMETERS_GCN,\n    MMVQ_PARAMETERS_RDNA2,\n    MMVQ_PARAMETERS_RDNA3_0,\n    MMVQ_PARAMETERS_RDNA4\n};\n\nstatic constexpr __device__ mmvq_parameter_table_id get_device_table_id() {\n#if defined(RDNA4)\n    return MMVQ_PARAMETERS_RDNA4;\n#elif defined(RDNA3_0)\n    return MMVQ_PARAMETERS_RDNA3_0;\n#elif defined(RDNA2) || defined(RDNA3_5)\n    return MMVQ_PARAMETERS_RDNA2;\n#elif defined(GCN) || defined(CDNA)\n    return MMVQ_PARAMETERS_GCN;\n#else\n    return MMVQ_PARAMETERS_GENERIC;\n#endif\n}\n\nstatic __host__ mmvq_parameter_table_id get_device_table_id(int cc) {\n    if (GGML_CUDA_CC_IS_RDNA4(cc)) {\n        return MMVQ_PARAMETERS_RDNA4;\n    }\n    if (GGML_CUDA_CC_IS_RDNA3_0(cc)) {\n        return MMVQ_PARAMETERS_RDNA3_0;\n    }\n    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) {\n        return MMVQ_PARAMETERS_RDNA2;\n    }\n    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {\n        return MMVQ_PARAMETERS_GCN;\n    }\n    return MMVQ_PARAMETERS_GENERIC;\n}\n\nstatic constexpr __host__ 
__device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {\n    if (table_id == MMVQ_PARAMETERS_GENERIC) {\n        switch (ncols_dst) {\n            case 1:\n            case 2:\n            case 3:\n            case 4:\n                return 4;\n            case 5:\n            case 6:\n            case 7:\n            case 8:\n                return 2;\n            default:\n                return 1;\n        }\n    } else if (table_id == MMVQ_PARAMETERS_GCN) {\n        switch (ncols_dst) {\n            case 1:\n            case 2:\n            case 3:\n            case 4:\n                return 2;\n            case 5:\n            case 6:\n            case 7:\n            case 8:\n            default:\n                return 1;\n        }\n    }\n    if (table_id == MMVQ_PARAMETERS_RDNA4) {\n        // nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).\n        // Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register\n        // pressure and lookup table contention at higher thread counts.\n        if (ncols_dst == 1) {\n            switch (type) {\n                case GGML_TYPE_Q4_0:\n                case GGML_TYPE_Q4_1:\n                case GGML_TYPE_Q5_0:\n                case GGML_TYPE_Q5_1:\n                case GGML_TYPE_Q8_0:\n                case GGML_TYPE_Q2_K:\n                case GGML_TYPE_Q4_K:\n                case GGML_TYPE_Q5_K:\n                case GGML_TYPE_Q6_K:\n                case GGML_TYPE_IQ4_NL:\n                case GGML_TYPE_IQ4_XS:\n                    return 8;\n                default:\n                    return 1;\n            }\n        }\n        return 1;\n    }\n    if (table_id == MMVQ_PARAMETERS_RDNA3_0) {\n        // RDNA3 (W7900): stricter whitelist than RDNA4.\n        // Q2_K / Q5_K / IQ4_XS regress in full quant sweeps.\n        if (ncols_dst == 1) {\n            switch (type) {\n                case GGML_TYPE_Q4_0:\n                case 
GGML_TYPE_Q4_1:\n                case GGML_TYPE_Q5_0:\n                case GGML_TYPE_Q5_1:\n                case GGML_TYPE_Q8_0:\n                case GGML_TYPE_Q4_K:\n                case GGML_TYPE_Q6_K:\n                case GGML_TYPE_IQ4_NL:\n                    return 8;\n                default:\n                    return 1;\n            }\n        }\n        return 1;\n    }\n    return 1;\n}\n\nstatic constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {\n    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {\n        switch (ncols_dst) {\n            case 1:\n                return 1;\n            case 2:\n            case 3:\n            case 4:\n            case 5:\n            case 6:\n            case 7:\n            case 8:\n                return 2;\n            default:\n                return 1;\n        }\n    }\n    return 1;\n}\n\ntemplate <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>\n__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)\nstatic __global__ void mul_mat_vec_q(\n        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,\n        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,\n        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,\n        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,\n        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,\n        const uint32_t ids_stride) {\n\n    constexpr int qk  = ggml_cuda_type_traits<type>::qk;\n    constexpr int qi  = ggml_cuda_type_traits<type>::qi;\n    constexpr int vdr = get_vdr_mmvq(type);\n    constexpr 
mmvq_parameter_table_id table_id = get_device_table_id();\n    constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);\n    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);\n    constexpr int warp_size = ggml_cuda_get_physical_warp_size();\n\n    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);\n\n    const     int tid = warp_size*threadIdx.y + threadIdx.x;\n    const     int row0 = rows_per_cuda_block*blockIdx.x;\n    const     int blocks_per_row_x = ncols_x / qk;\n    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;\n\n    const uint32_t channel_dst = blockIdx.y;\n\n    uint32_t token_idx = 0;\n    uint32_t channel_x;\n    uint32_t channel_y;\n    uint32_t sample_dst;\n\n    if constexpr (is_multi_token_id) {\n        // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case\n        token_idx  = blockIdx.z;\n        channel_x  = ids[channel_dst + token_idx * ids_stride];\n        channel_y  = fastmodulo(channel_dst, nchannels_y);\n        sample_dst = 0;\n    } else {\n        channel_x  = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);\n        channel_y  = ncols_dst == 1 && ids ? 
fastmodulo(channel_dst, nchannels_y) : channel_dst;\n        sample_dst = blockIdx.z;\n    }\n\n    const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);\n    const uint32_t sample_y    = sample_dst;\n\n    bool use_gate = false;\n    bool use_bias = false;\n    bool use_gate_bias = false;\n    const void * vgate = nullptr;\n    const float * x_bias = nullptr;\n    const float * gate_bias = nullptr;\n    ggml_glu_op active_glu;\n\n    if constexpr (has_fusion) {\n        use_gate      = fusion.gate      != nullptr;\n        use_bias      = fusion.x_bias    != nullptr;\n        use_gate_bias = fusion.gate_bias != nullptr && use_gate;\n        vgate         = fusion.gate;\n        x_bias        = (const float *) fusion.x_bias;\n        gate_bias     = (const float *) fusion.gate_bias;\n        active_glu    = fusion.glu_op;\n    }\n\n\n    float x_biases[ncols_dst]    = { 0.0f };\n    float gate_biases[ncols_dst] = { 0.0f };\n    if constexpr (has_fusion) {\n        const uint32_t channel_bias = ids ? channel_x : channel_dst;\n        if (use_bias) {\n            x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;\n            // 1. Hide latency by prefetching bias and gate here\n            // 2. 
load only on threads that won't die after partial sum calculation\n            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&\n                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {\n#pragma unroll\n                for (int j = 0; j < ncols_dst; ++j) {\n                    x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];\n                }\n            }\n        }\n        if (use_gate_bias) {\n            gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;\n            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&\n                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {\n#pragma unroll\n                for (int j = 0; j < ncols_dst; ++j) {\n                    gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];\n                }\n            }\n        }\n    }\n\n    // partial sum for each thread\n    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};\n    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};\n\n    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;\n    if constexpr (is_multi_token_id) {\n        y += token_idx*stride_col_y;\n    }\n    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;\n\n    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {\n        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx\n\n        // x block quant index when casting the quants to int\n        const int kqs = vdr * (tid % (qi/vdr));\n\n#pragma unroll\n        for (int j = 0; j < ncols_dst; ++j) {\n#pragma unroll\n            for (int i = 0; i < rows_per_cuda_block; ++i) {\n                tmp[j][i] += vec_dot_q_cuda(\n                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);\n                if constexpr 
(has_fusion) {\n                    if (use_gate) {\n                        tmp_gate[j][i] += vec_dot_q_cuda(\n                            vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);\n                    }\n                }\n            }\n        }\n    }\n\n    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];\n    __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];\n    if constexpr (!has_fusion) {\n        (void) tmp_shared_gate;\n    } else if (!use_gate) {\n        (void) tmp_shared_gate;\n    }\n\n    if (threadIdx.y > 0) {\n#pragma unroll\n        for (int j = 0; j < ncols_dst; ++j) {\n#pragma unroll\n            for (int i = 0; i < rows_per_cuda_block; ++i) {\n                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];\n                if constexpr (has_fusion) {\n                    if (use_gate) {\n                        tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];\n                    }\n                }\n            }\n        }\n    }\n    __syncthreads();\n    if (threadIdx.y > 0) {\n        return;\n    }\n\n    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;\n\n    if constexpr (is_multi_token_id) {\n        dst += token_idx*stride_col_dst;\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int j = 0; j < ncols_dst; ++j) {\n#pragma unroll\n        for (int i = 0; i < rows_per_cuda_block; ++i) {\n#pragma unroll\n            for (int l = 0; l < nwarps-1; ++l) {\n                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];\n                if constexpr (has_fusion) {\n                    if (use_gate) {\n                        tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];\n                    }\n                }\n            }\n            tmp[j][i] = 
warp_reduce_sum<warp_size>(tmp[j][i]);\n            if constexpr (has_fusion) {\n                if (use_gate) {\n                    tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);\n                }\n            }\n        }\n\n        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {\n            float result = tmp[j][threadIdx.x];\n            if constexpr (has_fusion) {\n                if (use_bias) {\n                    result += x_biases[j];\n                }\n                if (use_gate) {\n                    float gate_value = tmp_gate[j][threadIdx.x];\n                    if (use_gate_bias) {\n                        gate_value += gate_biases[j];\n                    }\n                    switch (active_glu) {\n                        case GGML_GLU_OP_SWIGLU:\n                            result *= ggml_cuda_op_silu_single(gate_value);\n                            break;\n                        case GGML_GLU_OP_GEGLU:\n                            result *= ggml_cuda_op_gelu_single(gate_value);\n                            break;\n                        case GGML_GLU_OP_SWIGLU_OAI: {\n                            result = ggml_cuda_op_swiglu_oai_single(gate_value, result);\n                            break;\n                        }\n                        default:\n                            result = result * gate_value;\n                            break;\n                    }\n                }\n            }\n            dst[j*stride_col_dst + threadIdx.x] = result;\n        }\n    }\n\n    if constexpr (!has_fusion) {\n        GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);\n    }\n}\n\ntemplate<ggml_type type>\nstatic std::pair<dim3, dim3> calc_launch_params(\n        const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,\n        const int warp_size, const 
mmvq_parameter_table_id table_id) {\n    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);\n    const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);\n    const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);\n    return {block_nums, block_dims};\n}\n\ntemplate<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>\nstatic void mul_mat_vec_q_switch_fusion(\n        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,\n        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,\n        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,\n        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,\n        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,\n        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,\n        const uint32_t ids_stride, cudaStream_t stream) {\n\n    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;\n    if constexpr (c_ncols_dst == 1) {\n        if (has_fusion) {\n            mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>\n                (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,\n                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);\n            return;\n        }\n    }\n\n    GGML_ASSERT(!has_fusion && \"fusion only supported for ncols_dst=1\");\n\n    mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, 
stream>>>\n        (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,\n        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,\n        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);\n}\n\ntemplate <ggml_type type>\nstatic void mul_mat_vec_q_switch_ncols_dst(\n        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,\n        const int ncols_x, const int nrows_x, const int ncols_dst,\n        const int stride_row_x, const int stride_col_y, const int stride_col_dst,\n        const int nchannels_x, const int nchannels_y, const int nchannels_dst,\n        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,\n        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,\n        const int ids_stride, cudaStream_t stream) {\n\n    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);\n    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);\n\n    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);\n    const uint3 channel_ratio_fd = ids ? 
make_uint3(0, 0, 0)              : init_fastdiv_values(nchannels_dst / nchannels_x);\n    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);\n\n    const int device = ggml_cuda_get_device();\n    const int warp_size = ggml_cuda_info().devices[device].warp_size;\n    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);\n\n    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;\n    const bool has_ids = ids != nullptr;\n\n    if (has_ids && ncols_dst > 1) {\n        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below\n        constexpr int c_ncols_dst = 1;\n        std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);\n        mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,\n             channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n             sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n             dims.first, dims.second, 0, ids_stride, stream);\n        return;\n    }\n\n    switch (ncols_dst) {\n        case 1: {\n            constexpr int c_ncols_dst = 1;\n            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);\n            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 dims.first, dims.second, 0, ids_stride, stream);\n        } break;\n        case 2: {\n            constexpr int c_ncols_dst = 2;\n            
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);\n            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 dims.first, dims.second, 0, ids_stride, stream);\n        } break;\n        case 3: {\n            constexpr int c_ncols_dst = 3;\n            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);\n            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 dims.first, dims.second, 0, ids_stride, stream);\n        } break;\n        case 4: {\n            constexpr int c_ncols_dst = 4;\n            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);\n            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 dims.first, dims.second, 0, ids_stride, stream);\n        } break;\n        case 5: {\n            constexpr int c_ncols_dst = 5;\n            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);\n   
         mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 dims.first, dims.second, 0, ids_stride, stream);\n        } break;\n        case 6: {\n            constexpr int c_ncols_dst = 6;\n            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);\n            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 dims.first, dims.second, 0, ids_stride, stream);\n        } break;\n        case 7: {\n            constexpr int c_ncols_dst = 7;\n            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);\n            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 dims.first, dims.second, 0, ids_stride, stream);\n        } break;\n        case 8: {\n            constexpr int c_ncols_dst = 8;\n            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);\n            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, 
stride_col_y, stride_col_dst,\n                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,\n                 dims.first, dims.second, 0, ids_stride, stream);\n        } break;\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n\n    GGML_UNUSED(has_fusion);\n}\nstatic void mul_mat_vec_q_switch_type(\n        const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,\n        const int ncols_x, const int nrows_x, const int ncols_dst,\n        const int stride_row_x, const int stride_col_y, const int stride_col_dst,\n        const int nchannels_x, const int nchannels_y, const int nchannels_dst,\n        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,\n        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,\n        const int ids_stride, cudaStream_t stream) {\n    switch (type_x) {\n        case GGML_TYPE_Q4_0:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q4_1:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, 
stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q5_0:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q5_1:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q8_0:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_MXFP4:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q2_K:\n            
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q3_K:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q4_K:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q5_K:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_Q6_K:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, 
stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_IQ2_XS:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_IQ2_S:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_IQ3_XXS:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n          
       nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_IQ1_S:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_IQ1_M:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_IQ4_NL:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        case 
GGML_TYPE_IQ3_S:\n            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>\n                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,\n                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,\n                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);\n            break;\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n\nvoid ggml_cuda_mul_mat_vec_q(\n        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,\n        const ggml_cuda_mm_fusion_args_host * fusion) {\n    GGML_ASSERT(        src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(        dst->type  == GGML_TYPE_F32);\n    GGML_ASSERT(!ids || ids->type  == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.\n\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    cudaStream_t stream = ctx.stream();\n\n    const size_t ts_src0 = ggml_type_size(src0->type);\n    const size_t ts_src1 = ggml_type_size(src1->type);\n    const size_t ts_dst  = ggml_type_size(dst->type);\n\n    GGML_ASSERT(        nb00       == ts_src0);\n    GGML_ASSERT(        nb10       == ts_src1);\n    GGML_ASSERT(        nb0        == ts_dst);\n    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));\n\n    GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);\n\n    const float   * src1_d =       (const float   *) src1->data;\n    const int32_t *  ids_d = ids ? 
(const int32_t *)  ids->data : nullptr;\n    float         *  dst_d =       (float         *)  dst->data;\n\n    ggml_cuda_mm_fusion_args_device fusion_local{};\n\n    if (fusion) {\n        GGML_ASSERT( !ids || dst->ne[2] == 1);\n        GGML_ASSERT(  ids || dst->ne[1] == 1);\n\n        if (fusion->x_bias) {\n            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);\n            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);\n            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);\n            fusion_local.x_bias = fusion->x_bias->data;\n        }\n        if (fusion->gate) {\n            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));\n            fusion_local.gate = fusion->gate->data;\n        }\n        if (fusion->gate_bias) {\n            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);\n            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);\n            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);\n            fusion_local.gate_bias = fusion->gate_bias->data;\n        }\n        fusion_local.glu_op = fusion->glu_op;\n    }\n\n    // If src0 is a temporary compute buffer, clear any potential padding.\n    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {\n        const size_t size_data  = ggml_nbytes(src0);\n        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);\n        if (size_alloc > size_data) {\n            GGML_ASSERT(ggml_is_contiguously_allocated(src0));\n            GGML_ASSERT(!src0->view_src);\n            CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));\n        }\n    }\n\n    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);\n    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);\n    {\n        const int64_t s11 = src1->nb[1] / ts_src1;\n        const 
int64_t s12 = src1->nb[2] / ts_src1;\n        const int64_t s13 = src1->nb[3] / ts_src1;\n        quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);\n    }\n\n    const int64_t s01 = src0->nb[1] / ts_src0;\n    const int64_t s11 = ne10_padded / QK8_1;\n    const int64_t s1  =  dst->nb[1] / ts_dst;\n    const int64_t s02 = src0->nb[2] / ts_src0;\n    const int64_t s2  =  dst->nb[2] / ts_dst;\n    const int64_t s03 = src0->nb[3] / ts_src0;\n    const int64_t s3  =  dst->nb[3] / ts_dst;\n\n    const int64_t s12 = ne11*s11;\n    const int64_t s13 = ne12*s12;\n\n    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:\n    const int64_t ncols_dst          = ids ? ne2  : ne1;\n    const int64_t nchannels_y        = ids ? ne11 : ne12;\n    const int64_t nchannels_dst      = ids ? ne1  : ne2;\n    const int64_t stride_col_dst     = ids ? s2   : s1;\n    const int64_t stride_col_y       = ids ? s12  : s11;\n    const int64_t stride_channel_dst = ids ? s1   : s2;\n    const int64_t stride_channel_y   = ids ? s11  : s12;\n\n    const int64_t ids_stride = ids ? 
ids->nb[1] / ggml_type_size(ids->type) : 0;\n\n    mul_mat_vec_q_switch_type(\n        src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,\n        ne01,              ncols_dst,     s01, stride_col_y,     stride_col_dst,\n        ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,\n        ne03,              ne3,           s03, s13,              s3,               ids_stride, stream);\n}\n\nvoid ggml_cuda_op_mul_mat_vec_q(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,\n    const int64_t src1_padded_row_size, cudaStream_t stream) {\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t row_diff = row_high - row_low;\n\n    const int64_t ne10 = src1->ne[0];\n    GGML_ASSERT(ne10 % QK8_1 == 0);\n\n    const int64_t ne0 = dst->ne[0];\n\n    int id = ggml_cuda_get_device();\n\n    // the main device has a larger memory buffer to hold the results from all GPUs\n    // nrows_dst == nrows of the matrix that the kernel writes into\n    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;\n\n    const int stride_row_x = ne00 / ggml_blck_size(src0->type);\n    const int stride_col_y = src1_padded_row_size / QK8_1;\n\n    ggml_cuda_mm_fusion_args_device fusion_local{};\n    mul_mat_vec_q_switch_type(\n        src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,\n        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);\n\n    GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);\n}\n"
  },
  {
    "path": "src/ggml-cuda/mmvq.cuh",
    "content": "#pragma once\n// Include guard is required: including this header twice in one translation unit would\n// redeclare the default argument of ggml_cuda_mul_mat_vec_q, which is ill-formed C++.\n\n#include \"common.cuh\"\n\n#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.\n#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID\n\n// Matrix-vector multiplication of quantized src0 with F32 src1, writing an F32 dst.\n// src1 is quantized to q8_1 internally before the kernel is launched.\n// ids is optional (I32); when non-null the op is handled as batched GGML_MUL_MAT_ID.\n// fusion is optional and carries biases and/or a GLU gate to fuse into the kernel.\nvoid ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);\n\n// Row-slice variant: multiplies rows [row_low, row_high) of src0 (data at src0_dd_i) with\n// src1 that has already been quantized to q8_1 (src1_ddq_i), writing into dst_dd_i.\nvoid ggml_cuda_op_mul_mat_vec_q(\n    ggml_backend_cuda_context & ctx,\n    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,\n    const int64_t src1_padded_row_size, cudaStream_t stream);\n"
  },
  {
    "path": "src/ggml-cuda/norm.cu",
    "content": "#include \"norm.cuh\"\n#include <cstdint>\n\ntemplate <int block_size>\nstatic __global__ void norm_f32(\n        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,\n        const int64_t stride_sample, const float eps) {\n    const int nrows     = gridDim.x;\n    const int nchannels = gridDim.y;\n\n    const int row       = blockIdx.x;\n    const int channel   = blockIdx.y;\n    const int sample    = blockIdx.z;\n    const int tid       = threadIdx.x;\n\n    x   += sample*stride_sample + channel*stride_channel + row*stride_row;\n    dst += ((sample*nchannels + channel)*nrows + row)*ncols;\n\n    float2 mean_var = make_float2(0.0f, 0.0f);\n\n    for (int col = tid; col < ncols; col += block_size) {\n        const float xi = x[col];\n        mean_var.x += xi;\n        mean_var.y += xi * xi;\n    }\n\n    // sum up partial sums\n    extern __shared__ float2 s_sum2[];\n    mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);\n\n    const float mean = mean_var.x / ncols;\n    const float var = mean_var.y / ncols - mean * mean;\n    const float inv_std = rsqrtf(var + eps);\n\n    for (int col = tid; col < ncols; col += block_size) {\n        dst[col] = (x[col] - mean) * inv_std;\n    }\n}\n\ntemplate <int block_size>\nstatic __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {\n    // blockIdx.x: num_groups idx\n    // threadIdx.x: block_size idx\n    const int start =     blockIdx.x*group_size + threadIdx.x;\n    const int end   = min(blockIdx.x*group_size + group_size,  ne_elements);\n\n    float tmp = 0.0f; // partial sum for thread in warp\n\n    for (int j = start; j < end; j += block_size) {\n        tmp += x[j];\n    }\n\n    extern __shared__ float s_sum[];\n    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);\n\n    const float mean = tmp / group_size;\n    tmp = 0.0f;\n\n    
for (int j = start; j < end; j += block_size) {\n        const float xi = x[j] - mean;\n        dst[j] = xi;\n        tmp += xi * xi;\n    }\n\n    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);\n\n    const float variance = tmp / group_size;\n    const float scale = rsqrtf(variance + eps);\n    for (int j = start; j < end; j += block_size) {\n        dst[j] *= scale;\n    }\n}\n\ntemplate <int block_size, bool do_multiply = false, bool do_add = false>\nstatic __global__ void rms_norm_f32(const float * x,\n                                    float *       dst,\n                                    const int     ncols,\n                                    const int64_t stride_row,\n                                    const int64_t stride_channel,\n                                    const int64_t stride_sample,\n                                    const float   eps,\n                                    const float * mul                  = nullptr,\n                                    const int64_t mul_stride_row       = 0,\n                                    const int64_t mul_stride_channel   = 0,\n                                    const int64_t mul_stride_sample    = 0,\n                                    const uint3   mul_ncols_packed     = make_uint3(0, 0, 0),\n                                    const uint3   mul_nrows_packed     = make_uint3(0, 0, 0),\n                                    const uint3   mul_nchannels_packed = make_uint3(0, 0, 0),\n                                    const uint3   mul_nsamples_packed  = make_uint3(0, 0, 0),\n                                    const float * add                  = nullptr,\n                                    const int64_t add_stride_row       = 0,\n                                    const int64_t add_stride_channel   = 0,\n                                    const int64_t add_stride_sample    = 0,\n                                    const uint3   add_ncols_packed     = make_uint3(0, 0, 
0),\n                                    const uint3   add_nrows_packed     = make_uint3(0, 0, 0),\n                                    const uint3   add_nchannels_packed = make_uint3(0, 0, 0),\n                                    const uint3   add_nsamples_packed  = make_uint3(0, 0, 0)) {\n    const int nrows     = gridDim.x;\n    const int nchannels = gridDim.y;\n\n    const int row       = blockIdx.x;\n    const int channel   = blockIdx.y;\n    const int sample    = blockIdx.z;\n    const int tid       = threadIdx.x;\n\n    static_assert(!do_add || do_multiply, \"fusing add is not supported without multiplying\");\n\n    x   += sample*stride_sample + channel*stride_channel + row*stride_row;\n    dst += ((sample*nchannels + channel)*nrows + row)*ncols;\n\n    if constexpr (do_multiply) {\n        const uint32_t mul_row     = fastmodulo(row, mul_nrows_packed);\n        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);\n        const uint32_t mul_sample  = fastmodulo(sample, mul_nsamples_packed);\n        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;\n    }\n\n    if constexpr (do_add) {\n        const int add_row     = fastmodulo(row, add_nrows_packed);\n        const int add_channel = fastmodulo(channel, add_nchannels_packed);\n        const int add_sample  = fastmodulo(sample, add_nsamples_packed);\n        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;\n    }\n\n    float tmp = 0.0f; // partial sum for thread in warp\n\n    for (int col = tid; col < ncols; col += block_size) {\n        const float xi = x[col];\n        tmp += xi * xi;\n    }\n\n    // sum up partial sums\n    extern __shared__ float s_sum[];\n    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);\n\n    const float mean = tmp / ncols;\n    const float scale = rsqrtf(mean + eps);\n\n    for (int col = tid; col < ncols; col += block_size) {\n        
if constexpr (do_multiply && do_add) {\n            const int mul_col = fastmodulo(col, mul_ncols_packed);\n            const int add_col = fastmodulo(col, add_ncols_packed);\n            dst[col]          = scale * x[col] * mul[mul_col] + add[add_col];\n        } else if constexpr (do_multiply) {\n            const int mul_col = fastmodulo(col, mul_ncols_packed);\n            dst[col]          = scale * x[col] * mul[mul_col];\n        } else {\n            dst[col] = scale * x[col];\n        }\n    }\n}\n\ntemplate <int block_size>\nstatic __global__ void rms_norm_back_f32(\n        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {\n    const int row = blockIdx.x*blockDim.y + threadIdx.y;\n    const int tid = threadIdx.x;\n\n    grad += int64_t(row)*ncols;\n    xf   += int64_t(row)*ncols;\n    dst  += int64_t(row)*ncols;\n\n    float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass\n    float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs\n\n    for (int col = tid; col < ncols; col += block_size) {\n        const float xfi = xf[col];\n        sum_xx += xfi * xfi;\n        sum_xg += xfi * grad[col];\n    }\n\n    // sum up partial sums\n    sum_xx = warp_reduce_sum(sum_xx);\n    sum_xg = warp_reduce_sum(sum_xg);\n    if constexpr (block_size > WARP_SIZE) {\n        static_assert(block_size == 1024, \"unexpected block_size\");\n        __shared__ float s_sum_xx[32];\n        __shared__ float s_sum_xg[32];\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        if (lane_id == 0) {\n            s_sum_xx[warp_id] = sum_xx;\n            s_sum_xg[warp_id] = sum_xg;\n        }\n        __syncthreads();\n\n        sum_xx = s_sum_xx[lane_id];\n        sum_xx = warp_reduce_sum(sum_xx);\n\n        sum_xg = s_sum_xg[lane_id];\n        sum_xg = warp_reduce_sum(sum_xg);\n    }\n\n    const float mean_eps = sum_xx / ncols + eps;\n    
const float sum_eps  = sum_xx + ncols*eps;\n\n    const float scale_grad = rsqrtf(mean_eps);\n    const float scale_x    = -scale_grad * sum_xg/sum_eps;\n\n    for (int col = tid; col < ncols; col += block_size) {\n        dst[col] = scale_grad*grad[col] + scale_x*xf[col];\n    }\n}\n\n// template <int block_size>\n// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {\n//     const int row = blockIdx.x*blockDim.y + threadIdx.y;\n//     const int tid = threadIdx.x;\n\n//     float tmp = 0.0f; // partial sum for thread in warp\n\n//     for (int col = tid; col < ncols; col += block_size) {\n//         const float xi = x[row*ncols + col];\n//         tmp += xi * xi;\n//     }\n\n//     // sum up partial sums\n//     tmp = warp_reduce_sum(tmp);\n//     if (block_size > WARP_SIZE) {\n//         __shared__ float s_sum[32];\n//         int warp_id = threadIdx.x / WARP_SIZE;\n//         int lane_id = threadIdx.x % WARP_SIZE;\n//         if (lane_id == 0) {\n//             s_sum[warp_id] = tmp;\n//         }\n//         __syncthreads();\n//         tmp = s_sum[lane_id];\n//         tmp = warp_reduce_sum(tmp);\n//     }\n\n//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html\n//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));\n\n//     for (int col = tid; col < ncols; col += block_size) {\n//         dst[row*ncols + col] = scale * x[row*ncols + col];\n//     }\n// }\n\ntemplate <int block_size>\nstatic __global__ void l2_norm_f32(\n        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,\n        const int64_t stride_sample, const float eps) {\n    const int nrows     = gridDim.x;\n    const int nchannels = gridDim.y;\n\n    const int row       = blockIdx.x;\n    const int channel   = blockIdx.y;\n    const int sample    = blockIdx.z;\n    const int tid       = threadIdx.x;\n\n    x   += sample*stride_sample + 
channel*stride_channel + row*stride_row;\n    dst += ((sample*nchannels + channel)*nrows + row)*ncols;\n\n    float tmp = 0.0f; // partial sum for thread in warp\n\n    for (int col = tid; col < ncols; col += block_size) {\n        const float xi = x[col];\n        tmp += xi * xi;\n    }\n\n    // sum up partial sums\n    extern __shared__ float s_sum[];\n    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);\n\n    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html\n    const float scale = rsqrtf(fmaxf(tmp, eps * eps));\n\n    for (int col = tid; col < ncols; col += block_size) {\n        dst[col] = scale * x[col];\n    }\n}\n\nstatic void norm_f32_cuda(\n        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,\n        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {\n    const dim3 blocks_num(nrows, nchannels, nsamples);\n    if (ncols < 1024) {\n        const dim3 block_dims(WARP_SIZE, 1, 1);\n        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);\n    } else {\n        const dim3 block_dims(1024, 1, 1);\n        norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);\n    }\n}\n\nstatic void group_norm_f32_cuda(\n        const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {\n    if (group_size < 1024) {\n        const dim3 block_dims(WARP_SIZE, 1, 1);\n        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);\n    } else {\n        const dim3 block_dims(1024, 1, 1);\n        group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 
32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);\n    }\n}\n\nstatic void rms_norm_f32_cuda(\n        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,\n        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {\n    const dim3 blocks_num(nrows, nchannels, nsamples);\n    if (ncols < 1024) {\n        const dim3 block_dims(256, 1, 1);\n        rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);\n    } else {\n        const dim3 block_dims(1024, 1, 1);\n        rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);\n    }\n}\n\nstatic void rms_norm_mul_f32_cuda(const float *  x,\n                                  const float *  mul,\n                                  const float *  add,\n                                  float *        dst,\n                                  const int      ncols,\n                                  const int      nrows,\n                                  const int      nchannels,\n                                  const int      nsamples,\n                                  const int64_t  stride_row,\n                                  const int64_t  stride_channel,\n                                  const int64_t  stride_sample,\n                                  const int64_t  mul_stride_row,\n                                  const int64_t  mul_stride_channel,\n                                  const int64_t  mul_stride_sample,\n                                  const uint32_t mul_ncols,\n                                  const uint32_t mul_nrows,\n                                  const uint32_t mul_nchannels,\n                                  
const uint32_t mul_nsamples,\n                                  const int64_t  add_stride_row,\n                                  const int64_t  add_stride_channel,\n                                  const int64_t  add_stride_sample,\n                                  const uint32_t add_ncols,\n                                  const uint32_t add_nrows,\n                                  const uint32_t add_nchannels,\n                                  const uint32_t add_nsamples,\n                                  const float    eps,\n                                  cudaStream_t   stream) {\n    const dim3 blocks_num(nrows, nchannels, nsamples);\n    if (mul == nullptr) {\n        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);\n        return;\n    }\n    if (add == nullptr) {\n        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);\n        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);\n        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);\n        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);\n        if (ncols < 1024) {\n            const dim3 block_dims(256, 1, 1);\n            rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(\n                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,\n                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);\n        } else {\n            const dim3 block_dims(1024, 1, 1);\n            rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 
32 * sizeof(float): 0, stream>>>(\n                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,\n                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);\n        }\n    } else {\n        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);\n        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);\n        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);\n        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);\n\n        const uint3 add_ncols_packed     = init_fastdiv_values(add_ncols);\n        const uint3 add_nrows_packed     = init_fastdiv_values(add_nrows);\n        const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);\n        const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);\n        if (ncols < 1024) {\n            const dim3 block_dims(256, 1, 1);\n            rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(\n                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,\n                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,\n                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,\n                add_nchannels_packed, add_nsamples_packed);\n        } else {\n            const dim3 block_dims(1024, 1, 1);\n            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 
32 * sizeof(float): 0, stream>>>(\n                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,\n                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,\n                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,\n                add_nchannels_packed, add_nsamples_packed);\n        }\n    }\n}\n\nstatic void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {\n    if (ncols < 1024) {\n        const dim3 block_dims(WARP_SIZE, 1, 1);\n        rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);\n    } else {\n        const dim3 block_dims(1024, 1, 1);\n        rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);\n    }\n}\n\nstatic void l2_norm_f32_cuda(\n        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,\n        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {\n    const dim3 blocks_num(nrows, nchannels, nsamples);\n    if (ncols < 1024) {\n        const dim3 block_dims(WARP_SIZE, 1, 1);\n        l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);\n    } else {\n        const dim3 block_dims(1024, 1, 1);\n        l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 
32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);\n    }\n}\n\nvoid ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *) src0->data;\n    float * dst_d = (float *) dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_UNARY_OP_LOCALS;\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n    GGML_ASSERT(eps >= 0.0f);\n\n    const size_t ts0 = ggml_type_size(src0->type);\n    GGML_ASSERT(nb00 == ts0);\n    const int64_t s01 = nb01 / ts0;\n    const int64_t s02 = nb02 / ts0;\n    const int64_t s03 = nb03 / ts0;\n\n    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);\n}\n\nvoid ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    int num_groups = dst->op_params[0];\n\n    float eps;\n    memcpy(&eps, dst->op_params + 1, sizeof(float));\n    GGML_ASSERT(eps >= 0.0f);\n\n    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);\n    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);\n}\n\nvoid ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *) src0->data;\n    float * dst_d = (float *) dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    
GGML_TENSOR_UNARY_OP_LOCALS;\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n    GGML_ASSERT(eps >= 0.0f);\n\n    const size_t ts0 = ggml_type_size(src0->type);\n    GGML_ASSERT(nb00 == ts0);\n    const int64_t s01 = nb01 / ts0;\n    const int64_t s02 = nb02 / ts0;\n    const int64_t s03 = nb03 / ts0;\n\n    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);\n}\n\nvoid ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {\n    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];\n    float eps = 0.0f;\n\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    const float * src0_d = (const float *) rms_norm_src->data;\n    const float * mul_d = nullptr;\n    const ggml_tensor * mul_src = nullptr;\n\n    if (mul_tensor->src[0] == dst) {\n        mul_d = (float *) mul_tensor->src[1]->data;\n        mul_src = mul_tensor->src[1];\n    } else if(mul_tensor->src[1] == dst) {\n        mul_d = (float *) mul_tensor->src[0]->data;\n        mul_src = mul_tensor->src[0];\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    float * dst_d = (float *) mul_tensor->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);\n    GGML_ASSERT(eps >= 0.0f);\n\n    const int64_t ne00 = rms_norm_src->ne[0];\n    const int64_t ne01 = rms_norm_src->ne[1];\n    const int64_t ne02 = rms_norm_src->ne[2];\n    const int64_t ne03 = rms_norm_src->ne[3];\n\n    const size_t ts0 = ggml_type_size(rms_norm_src->type);\n    GGML_ASSERT(rms_norm_src->nb[0] == ts0);\n    const int64_t s01 = rms_norm_src->nb[1] / ts0;\n    const int64_t s02 = rms_norm_src->nb[2] / ts0;\n    const int64_t s03 = rms_norm_src->nb[3] / ts0;\n\n    const size_t ts_mul = ggml_type_size(mul_src->type);\n    GGML_ASSERT(mul_src->nb[0] == ts_mul);\n    const int64_t 
mul_s01 = mul_src->nb[1] / ts_mul;\n    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;\n    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;\n\n    const int mul_ncols     = mul_src->ne[0];\n    const int mul_nrows     = mul_src->ne[1];\n    const int mul_nchannels = mul_src->ne[2];\n    const int mul_nsamples  = mul_src->ne[3];\n\n    rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,\n                          ne00, ne01, ne02, ne03,\n                          /*s00*/ s01, s02, s03,\n                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,\n                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,\n                          /*add_s00*/ 0, 0, 0,\n                          0, 0, 0, 0,\n                          eps, stream);\n}\n\nvoid ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,\n                                     ggml_tensor *               dst,\n                                     ggml_tensor *               mul_tensor,\n                                     ggml_tensor *               add_tensor) {\n    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];\n    float               eps          = 0.0f;\n\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    const float *       src0_d  = (const float *) rms_norm_src->data;\n    const float *       mul_d   = nullptr;\n    const ggml_tensor * mul_src = nullptr;\n\n    if (mul_tensor->src[0] == dst) {\n        mul_d   = (float *) mul_tensor->src[1]->data;\n        mul_src = mul_tensor->src[1];\n    } else if (mul_tensor->src[1] == dst) {\n        mul_d   = (float *) mul_tensor->src[0]->data;\n        mul_src = mul_tensor->src[0];\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    const float *       add_d   = nullptr;\n    const ggml_tensor * add_src = nullptr;\n\n    if (add_tensor->src[0] == mul_tensor) {\n        add_d   = (float *) add_tensor->src[1]->data;\n        add_src = add_tensor->src[1];\n    } else if (add_tensor->src[1] == 
mul_tensor) {\n        add_d   = (float *) add_tensor->src[0]->data;\n        add_src = add_tensor->src[0];\n    } else {\n        GGML_ASSERT(false);\n    }\n\n    float *      dst_d  = (float *) add_tensor->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);\n    GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);\n    GGML_ASSERT(eps >= 0.0f);\n\n    const int64_t ne00 = rms_norm_src->ne[0];\n    const int64_t ne01 = rms_norm_src->ne[1];\n    const int64_t ne02 = rms_norm_src->ne[2];\n    const int64_t ne03 = rms_norm_src->ne[3];\n\n    const size_t ts0 = ggml_type_size(rms_norm_src->type);\n    GGML_ASSERT(rms_norm_src->nb[0] == ts0);\n    const int64_t s01 = rms_norm_src->nb[1] / ts0;\n    const int64_t s02 = rms_norm_src->nb[2] / ts0;\n    const int64_t s03 = rms_norm_src->nb[3] / ts0;\n\n    const size_t ts_mul = ggml_type_size(mul_src->type);\n    GGML_ASSERT(mul_src->nb[0] == ts_mul);\n    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;\n    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;\n    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;\n\n    const int mul_ncols     = mul_src->ne[0];\n    const int mul_nrows     = mul_src->ne[1];\n    const int mul_nchannels = mul_src->ne[2];\n    const int mul_nsamples  = mul_src->ne[3];\n\n    const size_t ts_add = ggml_type_size(add_src->type);\n    GGML_ASSERT(add_src->nb[0] == ts_add);\n    const int64_t add_s01 = add_src->nb[1] / ts_add;\n    const int64_t add_s02 = add_src->nb[2] / ts_add;\n    const int64_t add_s03 = add_src->nb[3] / ts_add;\n\n    const int add_ncols     = add_src->ne[0];\n    const int add_nrows     = add_src->ne[1];\n    const int add_nchannels = add_src->ne[2];\n    const int add_nsamples  = add_src->ne[3];\n\n    rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,\n                          ne00,ne01, ne02, ne03,\n                          
/*s00*/ s01, s02, s03,\n                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,\n                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,\n                          /*add_s00*/ add_s01, add_s02, add_s03,\n                          add_ncols, add_nrows, add_nchannels, add_nsamples,\n                          eps, stream);\n}\n\nvoid ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * grad  = dst->src[0]; // gradients\n    const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass\n\n    const float * grad_d  = (const float *) grad->data;\n    const float * src0f_d = (const float *) src0f->data;\n    float       * dst_d   = (float       *) dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous(grad));\n\n    GGML_ASSERT( grad->type == GGML_TYPE_F32);\n    GGML_ASSERT(src0f->type == GGML_TYPE_F32);\n    GGML_ASSERT(  dst->type == GGML_TYPE_F32);\n\n    const int64_t ne00 = src0f->ne[0];\n    const int64_t nrows = ggml_nrows(src0f);\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n    GGML_ASSERT(eps >= 0.0f);\n\n    rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);\n}\n\nvoid ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *) src0->data;\n    float * dst_d = (float *) dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_UNARY_OP_LOCALS;\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n    GGML_ASSERT(eps >= 0.0f);\n\n    const size_t ts0 = ggml_type_size(src0->type);\n    GGML_ASSERT(nb00 == ts0);\n    const int64_t s01 = nb01 / ts0;\n    const int64_t s02 = nb02 / ts0;\n    const int64_t s03 = nb03 / ts0;\n\n    l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, 
ne02, ne03, s01, s02, s03, eps, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/norm.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);\n\nvoid ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,\n                                     ggml_tensor *               dst,\n                                     ggml_tensor *               mul_tensor,\n                                     ggml_tensor *               add_tensor);\n\nvoid ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/opt-step-adamw.cu",
    "content": "#include \"ggml-impl.h\"\n#include \"opt-step-adamw.cuh\"\n\n#include <cstdint>\n\nstatic __global__ void opt_step_adamw_f32(\n    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v,\n    const float * __restrict__ pars, const int64_t k) {\n\n    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    const float alpha  = pars[0];\n    const float beta1  = pars[1];\n    const float beta2  = pars[2];\n    const float eps    = pars[3];\n    const float wd     = pars[4];\n    const float beta1h = pars[5];\n    const float beta2h = pars[6];\n\n    const float gi = g[i];\n    const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);\n    const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);\n\n    g_m[i] = gmi;\n    g_v[i] = gvi;\n\n    const float mh =       gmi*beta1h;\n    const float vh = sqrtf(gvi*beta2h) + eps;\n\n    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;\n}\n\nstatic void opt_step_adamw_f32_cuda(\n    float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) {\n\n    const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);\n    const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);\n    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, pars, k);\n}\n\nvoid ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0         = dst->src[0];\n    const ggml_tensor * src0_grad    = dst->src[1];\n    const ggml_tensor * src0_grad_m  = dst->src[2];\n    const ggml_tensor * src0_grad_v  = dst->src[3];\n    const ggml_tensor * adamw_params = dst->src[4];\n\n    GGML_ASSERT(src0->type         == GGML_TYPE_F32);\n    GGML_ASSERT(src0_grad->type    == GGML_TYPE_F32);\n    GGML_ASSERT(src0_grad_m->type  == GGML_TYPE_F32);\n    GGML_ASSERT(src0_grad_v->type  == 
GGML_TYPE_F32);\n    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src0_grad));\n    GGML_ASSERT(ggml_is_contiguous(src0_grad_m));\n    GGML_ASSERT(ggml_is_contiguous(src0_grad_v));\n    GGML_ASSERT(ggml_is_contiguous(adamw_params));\n    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));\n    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));\n    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));\n    GGML_ASSERT(ggml_nelements(adamw_params) == 7);\n\n    float       * src0_d         = (float       *) src0->data;\n    const float * src0_grad_d    = (const float *) src0_grad->data;\n    float       * src0_grad_m_d  = (float       *) src0_grad_m->data;\n    float       * src0_grad_v_d  = (float       *) src0_grad_v->data;\n    const float * adamw_params_d = (const float *) adamw_params->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    const int64_t ne = ggml_nelements(src0);\n\n    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/opt-step-adamw.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256\n\nvoid ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/opt-step-sgd.cu",
    "content": "#include \"ggml-impl.h\"\n#include \"opt-step-sgd.cuh\"\n\n#include <cstdint>\n\nstatic __global__ void opt_step_sgd_f32(\n    float * __restrict__ x, const float * __restrict__ g,\n    const float * __restrict__ pars, const int64_t k) {\n\n    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n    x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];\n}\n\nstatic void opt_step_sgd_f32_cuda(\n    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {\n\n    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);\n    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);\n    opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);\n}\n\nvoid ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0      = dst->src[0];\n    const ggml_tensor * src0_grad = dst->src[1];\n    const ggml_tensor * params    = dst->src[2];\n\n    GGML_ASSERT(src0->type      == GGML_TYPE_F32);\n    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);\n    GGML_ASSERT(params->type    == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src0_grad));\n    GGML_ASSERT(ggml_is_contiguous(params));\n    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));\n    GGML_ASSERT(ggml_nelements(params) == 2);\n\n    float       * src0_d      = (float       *) src0->data;\n    const float * src0_grad_d = (const float *) src0_grad->data;\n    const float * params_d    = (const float *) params->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    const int64_t ne = ggml_nelements(src0);\n\n    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/opt-step-sgd.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256\n\nvoid ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/out-prod.cu",
    "content": "#include \"out-prod.cuh\"\n\n#include <cstdint>\n\nvoid ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n\n    GGML_ASSERT(ne01 == ne11);\n    GGML_ASSERT(ne0 == ne00);\n    GGML_ASSERT(ne1 == ne10);\n\n    GGML_ASSERT(ne2 % src0->ne[2] == 0);\n    GGML_ASSERT(ne3 % src0->ne[3] == 0);\n\n    GGML_ASSERT(ne2 == src1->ne[2]);\n    GGML_ASSERT(ne3 == src1->ne[3]);\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float       *  dst_d = (float       *)  dst->data;\n\n    cudaStream_t   stream = ctx.stream();\n    cublasHandle_t handle = ctx.cublas_handle();\n\n    const float alpha = 1.0f;\n    const float beta = 0.0f;\n\n    CUBLAS_CHECK(cublasSetStream(handle, stream));\n\n    const int64_t lda = nb01 / sizeof(float);\n    const int64_t ldc = nb1  / sizeof(float);\n\n    const bool src1_T = ggml_is_transposed(src1);\n    const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;\n    const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);\n    GGML_ASSERT(                             (src1_T ?        
nb11 :        nb10) == sizeof(float));\n\n    // data strides in dimensions 2/3\n    const size_t s02 = nb02 / sizeof(float);\n    const size_t s03 = nb03 / sizeof(float);\n    const size_t s12 = nb12 / sizeof(float);\n    const size_t s13 = nb13 / sizeof(float);\n    const size_t s2  = nb2  / sizeof(float);\n    const size_t s3  = nb3  / sizeof(float);\n\n    // dps == dst per src0, used for group query attention\n    const int64_t dps2 = ne2 / ne02;\n    const int64_t dps3 = ne3 / ne03;\n\n    // TODO batched matrix multiplication\n    for (int64_t i3 = 0; i3 < ne3; ++i3) {\n        for (int64_t i2 = 0; i2 < ne2; ++i2) {\n            CUBLAS_CHECK(\n                cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,\n                        ne0, ne1, ne01,\n                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,\n                                src1_d +  i3      *s13 +  i2      *s12, ldb,\n                        &beta,  dst_d  +  i3      *s3  +  i2      *s2,  ldc));\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/out-prod.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/pad.cu",
    "content": "#include \"pad.cuh\"\n\n#include <stdint.h>\n\n__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {\n    // + size ensures negatives are handled properly\n    return (coord + size) % size;\n}\n\nstatic __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,\n                               const int lp0, const int rp0, const int lp1, const int rp1,\n                               const int lp2, const int rp2, const int lp3, const int rp3,\n                               const int ne0, const int ne1, const int ne2, const int ne3,\n                               const bool circular) {\n    // blockIdx.z: i3*ne2+i2\n    // blockIdx.y: i1\n    // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE\n    // gridDim.y:  ne1\n    int i0 = threadIdx.x + blockIdx.x * blockDim.x;\n    int i1 = blockIdx.y;\n    int i2 = blockIdx.z % ne2;\n    int i3 = blockIdx.z / ne2;\n\n    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {\n        return;\n    }\n\n    const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;\n\n    if (!circular) {\n        if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&\n            (i3 >= lp3 && i3 < ne3 - rp3)) {\n            const int64_t i00  = i0 - lp0;\n            const int64_t i01  = i1 - lp1;\n            const int64_t i02  = i2 - lp2;\n            const int64_t i03  = i3 - lp3;\n\n            const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;\n\n            dst[dst_idx] = src[src_idx];\n        } else {\n            dst[dst_idx] = 0.0f;\n        }\n    }\n    // circular means on a torus, so x and y wrap around\n    else {\n        const int64_t ne00 = ne0 - lp0 - rp0;\n        const int64_t ne01 = ne1 - lp1 - rp1;\n        const int64_t ne02 = ne2 - lp2 - rp2;\n        const int64_t ne03 = ne3 - lp3 - rp3;\n\n        const int64_t i00 = wrap_around(i0 - lp0, ne00);\n   
     const int64_t i01 = wrap_around(i1 - lp1, ne01);\n        const int64_t i02 = wrap_around(i2 - lp2, ne02);\n        const int64_t i03 = wrap_around(i3 - lp3, ne03);\n\n        const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;\n\n        dst[dst_idx] = src[src_idx];\n    }\n}\n\n\nstatic void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,\n    const int lp0, const int rp0, const int lp1, const int rp1,\n    const int lp2, const int rp2, const int lp3, const int rp3,\n    const int ne0, const int ne1, const int ne2, const int ne3,\n    const bool circular, cudaStream_t stream) {\n    int  num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;\n    dim3 gridDim(num_blocks, ne1, ne2 * ne3);\n    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst,\n                                                         lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,\n                                                         ne0, ne1, ne2, ne3, circular);\n}\n\nvoid ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0   = dst->src[0];\n    const float *       src0_d = (const float *) src0->data;\n    float *             dst_d  = (float *) dst->data;\n    cudaStream_t        stream = ctx.stream();\n\n    GGML_TENSOR_UNARY_OP_LOCALS;\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    const int32_t lp0      = ((const int32_t *) (dst->op_params))[0];\n    const int32_t rp0      = ((const int32_t *) (dst->op_params))[1];\n    const int32_t lp1      = ((const int32_t *) (dst->op_params))[2];\n    const int32_t rp1      = ((const int32_t *) (dst->op_params))[3];\n    const int32_t lp2      = ((const int32_t *) (dst->op_params))[4];\n    const int32_t rp2      = ((const int32_t *) (dst->op_params))[5];\n    const int32_t lp3      = ((const int32_t *) (dst->op_params))[6];\n    const int32_t 
rp3      = ((const int32_t *) (dst->op_params))[7];\n    const int32_t circular = ((const int32_t *) (dst->op_params))[8];\n\n    const size_t s00 = nb00 / ggml_type_size(src0->type);\n    const size_t s01 = nb01 / ggml_type_size(src0->type);\n    const size_t s02 = nb02 / ggml_type_size(src0->type);\n    const size_t s03 = nb03 / ggml_type_size(src0->type);\n\n    pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d,\n                 lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,\n                 dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n                 (bool) circular, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/pad.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_PAD_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/pad_reflect_1d.cu",
    "content": "#include \"pad_reflect_1d.cuh\"\n\nstatic __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void\n    pad_reflect_1d_kernel_f32(\n        const void * __restrict__ src0,\n        void * __restrict__       dst,\n        const int64_t             ne0,\n        const int64_t             ne00,\n        const uint3               ne01,\n        const int64_t             ne02,\n        const int64_t             ne03,\n        const int64_t             nb00,\n        const int64_t             nb01,\n        const int64_t             nb02,\n        const int64_t             nb03,\n        const int64_t             nb0,\n        const int64_t             nb1,\n        const int64_t             nb2,\n        const int64_t             nb3,\n        const int                 p0,\n        const int                 p1) {\n    const int64_t i3 = blockIdx.z;\n    const int64_t i2 = blockIdx.y;\n\n    const uint2   div_mod_packed = fast_div_modulo(blockIdx.x, ne01);\n    const int64_t tile1          = div_mod_packed.y;  // i1\n    const int64_t tile0          = div_mod_packed.x;  // nth i0 tile\n    const int64_t i1             = tile1;\n    const int64_t i0             = threadIdx.x + tile0 * blockDim.x;\n\n    // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)\n    if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {\n        return;\n    }\n\n    const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;\n    char *       dst_ptr  = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;\n\n    const int64_t rel_i0 = i0 - p0;  // relative i0 in src0\n    int64_t src_idx;\n\n    if (rel_i0 < 0) {\n        // Left padding - reflect\n        src_idx = -rel_i0;\n    } else if (rel_i0 < ne00) {\n        // Middle - copy\n        src_idx = rel_i0;\n    } else {\n        // Right padding - reflect\n        src_idx = 2 * ne00 - 2 - rel_i0;\n    }\n    const float value               = *(const float *) 
(src0_ptr + src_idx * nb00);\n    *(float *) (dst_ptr + i0 * nb0) = value;\n\n    GGML_UNUSED(p1);\n}\n\nvoid ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0   = dst->src[0];\n    cudaStream_t        stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    const int32_t * opts = (const int32_t *) dst->op_params;\n    const int       p0   = opts[0];\n    const int       p1   = opts[1];\n\n    const int64_t ne00        = src0->ne[0];\n    const int64_t ne01        = src0->ne[1];\n    const uint3   ne01_packed = init_fastdiv_values(ne01);\n    const int64_t ne02        = src0->ne[2];\n    const int64_t ne03        = src0->ne[3];\n\n    const int64_t ne0 = dst->ne[0];\n\n    // sanity: padded length matches\n    GGML_ASSERT(ne0 == ne00 + p0 + p1);\n\n    constexpr int64_t bx     = CUDA_PAD_REFLECT_1D_BLOCK_SIZE;  // threads per block (x)\n    const int64_t     tiles0 = (ne0 + bx - 1) / bx;             // number of tiles along i0\n    // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]\n    // grid.y covers i2: [ne02]\n    // grid.z covers i3: [ne03]\n    const dim3        grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);\n    const dim3        block_dims((unsigned) bx, 1, 1);\n\n    pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(\n        src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);\n}\n"
  },
  {
    "path": "src/ggml-cuda/pad_reflect_1d.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/pool2d.cu",
    "content": "#include \"pool2d.cuh\"\n\ntemplate <typename Ti, typename To>\nstatic  __global__ void pool2d_nchw_kernel(\n        const int ih, const int iw, const int oh, const int ow,\n        const int kh, const int kw, const int sh, const int sw,\n        const int ph, const int pw, const int parallel_elements,\n        const Ti* src, To* dst, const enum ggml_op_pool op) {\n    int idx = threadIdx.x + blockIdx.x * blockDim.x;\n    if (idx >= parallel_elements) {\n        return;\n    }\n\n    const int I_HW = ih * iw;\n    const int O_HW = oh * ow;\n    const int nc = idx / O_HW;\n    const int cur_oh = idx % O_HW / ow;\n    const int cur_ow = idx % O_HW % ow;\n    const Ti* i_ptr = src + nc * I_HW;\n    To* o_ptr = dst + nc * O_HW;\n    const int start_h = cur_oh * sh - ph;\n    const int bh = max(0, start_h);\n    const int eh = min(ih, start_h + kh);\n    const int start_w = cur_ow * sw - pw;\n    const int bw = max(0, start_w);\n    const int ew = min(iw, start_w + kw);\n    const To scale = 1. 
/ (kh * kw);\n    To res = 0;\n\n    switch (op) {\n        case GGML_OP_POOL_AVG: res = 0; break;\n        case GGML_OP_POOL_MAX: res = -FLT_MAX; break;\n        default: assert(false);\n    }\n\n    for (int i = bh; i < eh; i += 1) {\n        for (int j = bw; j < ew; j += 1) {\n#if __CUDA_ARCH__ >= 350\n            Ti cur = __ldg(i_ptr + i * iw + j);\n#else\n            Ti cur = i_ptr[i * iw + j];\n#endif\n            switch (op) {\n                case GGML_OP_POOL_AVG: res += cur * scale; break;\n                case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;\n                default: assert(false);\n            }\n        }\n    }\n    o_ptr[cur_oh * ow + cur_ow] = res;\n}\n\nstatic void pool2d_nchw_kernel_f32_f32_cuda(\n        const int ih, const int iw, const int oh, const int ow,\n        const int kh, const int kw, const int sh, const int sw,\n        const int ph, const int pw, const int parallel_elements,\n        const float * src, float * dst, const enum ggml_op_pool op,\n        cudaStream_t stream) {\n\n    const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;\n    dim3 block_nums(num_blocks);\n    pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);\n}\n\nvoid ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    const int32_t * opts = (const int32_t *)dst->op_params;\n    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);\n    const int k0 = opts[1];\n    const int k1 = opts[2];\n    const int s0 = opts[3];\n    const int s1 = opts[4];\n    const int p0 = opts[5];\n    const int p1 = opts[6];\n\n    
const int64_t IH = src0->ne[1];\n    const int64_t IW = src0->ne[0];\n\n    const int64_t N = dst->ne[3];\n    const int64_t OC = dst->ne[2];\n    const int64_t OH = dst->ne[1];\n    const int64_t OW = dst->ne[0];\n\n    const int parallel_elements = N * OC * OH * OW;\n\n    pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/pool2d.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_POOL2D_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/quantize.cu",
    "content": "#include \"quantize.cuh\"\n#include <cstdint>\n\n__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)\nstatic __global__ void quantize_q8_1(\n        const float * __restrict__ x, void * __restrict__ vy,\n        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,\n        const int64_t ne0, const uint32_t ne1, const uint3 ne2) {\n    const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i0 >= ne0) {\n        return;\n    }\n\n    const int64_t i3 = fastdiv(blockIdx.z, ne2);\n    const int64_t i2 = blockIdx.z - i3*ne2.z;\n    const int64_t i1 = blockIdx.y;\n\n    const int64_t & i00 = i0;\n    const int64_t & i01 = i1;\n    const int64_t & i02 = i2;\n    const int64_t & i03 = i3;\n\n    const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;\n\n    block_q8_1 * y = (block_q8_1 *) vy;\n\n    const int64_t ib  = i_cont / QK8_1; // block index\n    const int64_t iqs = i_cont % QK8_1; // quant index\n\n    const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;\n    float amax = fabsf(xi);\n    float sum = xi;\n\n    amax = warp_reduce_max<QK8_1>(amax);\n    sum  = warp_reduce_sum<QK8_1>(sum);\n\n    const float  d = amax / 127.0f;\n    const int8_t q = amax == 0.0f ? 
0 : roundf(xi / d);\n\n    y[ib].qs[iqs] = q;\n\n    if (iqs > 0) {\n        return;\n    }\n\n    y[ib].ds = make_half2(d, sum);\n}\n\n__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {\n    if (!(amax > 0.0f)) {\n        return 0;\n    }\n\n    // FP4 E2M1: max exponent (unbiased) is 2.\n    constexpr int FP4_E2M1_EMAX = 2;\n\n    const float e = log2f(amax);\n\n    // \"even\" -> round-to-nearest integer, ties-to-even\n    const int e_int = __float2int_rn(e);\n\n    const int shared_exp = e_int - FP4_E2M1_EMAX;\n\n    int biased = shared_exp + 127;\n\n    biased = max(biased, 0);\n    biased = min(biased, 254);\n\n    return static_cast<uint8_t>(biased);\n}\n\n// quantize values in the format mxfp4 is stored which is interleaved nibbles\n// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31\nstatic __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,\n                                          const int32_t * __restrict__ ids,\n                                          void * __restrict__ vy,\n                                          const int64_t ne00,\n                                          const int64_t s01,\n                                          const int64_t s02,\n                                          const int64_t s03,\n                                          const int64_t ne0,\n                                          const int     ne1,\n                                          const int     ne2) {\n    constexpr int vals_per_scale = 32;\n    constexpr int vals_per_warp  = 2 * vals_per_scale;  // Each warp processes 2 blocks of 32 = 64 values\n\n    const int warp_id = threadIdx.y;\n    const int lane_id_32 = threadIdx.x;\n\n    const int nwarps = blockDim.y;\n\n    const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;\n\n    if (warp_start_offset >= ne0) {\n        return;\n    }\n\n    const int64_t i1 = blockIdx.x;\n    const int64_t i2 = blockIdx.z % ne2;\n    const 
int64_t i3 = blockIdx.z / ne2;\n\n    const int64_t i01 = ids ? ids[i1] : i1;\n    const int64_t i02 = i2;\n    const int64_t i03 = i3;\n\n    block_fp4_mmq * y = (block_fp4_mmq *) vy;\n\n    const int64_t block_fp4_mmq_size = 8 * QK_MXFP4;  // 256 values\n    const int64_t ib0                = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));\n    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;\n    const int64_t quad_idx_in_block  = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;\n\n    const int group_id = lane_id_32 / 4;\n    const int lane_in_group = lane_id_32 % 4;\n    const int base = group_id * 2;\n    char2 * yqs2 = (char2 *) y[ib].qs;\n\n    const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;\n\n    uint8_t scales[2];\n\n#pragma unroll\n    for (int b = 0; b < 2; ++b) {\n        const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;\n        const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;\n\n        float amax = fabsf(xi);\n#pragma unroll\n        for (int mask = 16; mask > 0; mask >>= 1) {\n            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));\n        }\n\n        const uint8_t e = compute_e8m0_scale(amax);\n        scales[b] = e;\n        const float inv_s = (amax == 0.0f) ? 
0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));\n\n#if CUDART_VERSION >= 12080\n        const float scaled_val = xi * inv_s;\n\n        const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);\n        const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);\n        const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);\n        const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);\n\n        if (lane_in_group == 0) {\n            __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));\n\n            yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;\n        }\n#else\n        // Fallback: manual FP4 conversion using LUT\n        const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);\n\n        const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base,      WARP_SIZE);\n        const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1,  WARP_SIZE);\n        const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);\n        const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);\n\n        if (lane_in_group == 0) {\n            char2 q;\n            q.x = (q_hi_0 << 4) | q_lo_0;\n            q.y = (q_hi_1 << 4) | q_lo_1;\n            yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;\n        }\n#endif // CUDART_VERSION >= 12080\n    }\n\n    if (lane_id_32 == 0) {\n        // Store 2 scales packed into 1 uint32\n        y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];\n    }\n}\n\ntemplate <mmq_q8_1_ds_layout ds_layout>\nstatic __global__ void quantize_mmq_q8_1(\n        const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,\n        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,\n        const int64_t ne0, const int ne1, const int ne2) {\n\n    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 
64 : 32;\n    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;\n\n    const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;\n\n    if (i0 >= ne0) {\n        return;\n    }\n\n    const int64_t i1 = blockIdx.x;\n    const int64_t i2 = blockIdx.z % ne2;\n    const int64_t i3 = blockIdx.z / ne2;\n\n    const int64_t i00 = i0;\n    const int64_t i01 = ids ? ids[i1] : i1;\n    const int64_t i02 = i2;\n    const int64_t i03 = i3;\n\n    const float4 * x4 = (const float4 *) x;\n\n    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;\n\n    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel\n    const int64_t ib  = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x;                    // block index in channel\n    const int64_t iqs = i0 % (4*QK8_1);                                             // quant index in block\n\n    // Load 4 floats per thread and calculate max. abs. value between them:\n    const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);\n    float amax = fabsf(xi.x);\n    amax = fmaxf(amax, fabsf(xi.y));\n    amax = fmaxf(amax, fabsf(xi.z));\n    amax = fmaxf(amax, fabsf(xi.w));\n\n    // Exchange max. abs. 
value between vals_per_scale/4 threads.\n#pragma unroll\n    for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {\n        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));\n    }\n\n    float sum;\n    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {\n        sum = xi.x + xi.y + xi.z + xi.w;\n\n        // Calculate sums across vals_per_sum/4 threads.\n#pragma unroll\n        for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {\n            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);\n        }\n    }\n\n    const float d_inv = 127.0f / amax;\n    char4 q;\n    q.x = roundf(xi.x*d_inv);\n    q.y = roundf(xi.y*d_inv);\n    q.z = roundf(xi.z*d_inv);\n    q.w = roundf(xi.w*d_inv);\n\n    // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:\n    char4 * yqs4 = (char4 *) y[ib].qs;\n    yqs4[iqs/4] = q;\n\n    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {\n        if (iqs % 16 != 0 || iqs >= 96) {\n            return;\n        }\n\n        y[ib].d2s6[2 + iqs/16] = sum;\n\n        if (iqs % 64 != 0) {\n            return;\n        }\n\n        const float d = 1.0f / d_inv;\n\n        y[ib].d2s6[iqs/64] = d;\n\n        return;\n    }\n\n    if (iqs % 32 != 0) {\n        return;\n    }\n\n    const float d = 1.0f / d_inv;\n\n    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {\n        y[ib].ds4[iqs/32] = make_half2(d, sum);\n    } else {\n        y[ib].d4[iqs/32]  = d;\n    }\n}\n\nvoid quantize_row_q8_1_cuda(\n        const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,\n        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,\n        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {\n    GGML_ASSERT(!ids);\n    GGML_ASSERT(ne0 % QK8_1 == 0);\n\n    const uint3 ne2_fastdiv = init_fastdiv_values(ne2);\n\n    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;\n 
   const dim3 num_blocks(block_num_x, ne1, ne2*ne3);\n    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);\n    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);\n    GGML_UNUSED(type_src0);\n}\n\nvoid quantize_mmq_q8_1_cuda(\n        const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,\n        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,\n        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {\n    GGML_ASSERT(ne00 % 4 == 0);\n    GGML_ASSERT(ne0 % (4*QK8_1) == 0);\n\n    // ne1 tends to assume the highest values, therefore use it as the \"x\" dimension of the CUDA grid:\n    const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);\n    const dim3 num_blocks(ne1, block_num_y, ne2*ne3);\n    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);\n    switch (mmq_get_q8_1_ds_layout(type_src0)) {\n        case MMQ_Q8_1_DS_LAYOUT_D4:\n            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>\n                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);\n            break;\n        case MMQ_Q8_1_DS_LAYOUT_DS4:\n            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>\n                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);\n            break;\n        case MMQ_Q8_1_DS_LAYOUT_D2S6:\n            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>\n                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);\n            break;\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n\nvoid quantize_mmq_mxfp4_cuda(const float *                    x,\n                             const int32_t *                  ids,\n                             void *                           vy,\n                     
        [[maybe_unused]] const ggml_type type_src0,\n                             const int64_t                    ne00,\n                             const int64_t                    s01,\n                             const int64_t                    s02,\n                             const int64_t                    s03,\n                             const int64_t                    ne0,\n                             const int64_t                    ne1,\n                             const int64_t                    ne2,\n                             const int64_t                    ne3,\n                             cudaStream_t                     stream) {\n    GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);\n\n    constexpr int nwarps = 8;\n    constexpr int vals_per_warp  = 2 * QK_MXFP4;\n    constexpr int vals_per_block = nwarps * vals_per_warp;\n\n    const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;\n    const dim3    num_blocks(ne1, block_num_y, ne2 * ne3);\n    const dim3    block_size(WARP_SIZE, nwarps, 1);\n\n    quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);\n}\n"
  },
  {
    "path": "src/ggml-cuda/quantize.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n#include \"mmq.cuh\"\n\n#include <cstdint>\n\n#define CUDA_QUANTIZE_BLOCK_SIZE     256\n#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128\n\nstatic_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, \"Risk of out-of-bounds access.\");\nstatic_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, \"Risk of out-of-bounds access.\");\n\ntypedef void (*quantize_cuda_t)(\n        const float * x, const int32_t * ids, void * vy,\n        ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,\n        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);\n\nvoid quantize_row_q8_1_cuda(\n        const float * x, const int32_t * ids, void * vy,\n        ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,\n        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);\n\nvoid quantize_mmq_q8_1_cuda(\n        const float * x, const int32_t * ids, void * vy,\n        ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,\n        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);\n\nvoid quantize_mmq_mxfp4_cuda(const float *   x,\n                             const int32_t * ids,\n                             void *          vy,\n                             ggml_type       type_src0,\n                             int64_t         ne00,\n                             int64_t         s01,\n                             int64_t         s02,\n                             int64_t         s03,\n                             int64_t         ne0,\n                             int64_t         ne1,\n                             int64_t         ne2,\n                             int64_t         ne3,\n                             cudaStream_t    stream);\n"
  },
  {
    "path": "src/ggml-cuda/reduce_rows.cuh",
    "content": "#include \"common.cuh\"\n\n// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)\ntemplate <bool norm>\nstatic __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {\n    const int row = blockIdx.x;\n    const int col = threadIdx.x;\n\n    float     sum        = 0.0f;\n    const int num_unroll = 8;\n    float     temp[num_unroll];\n    float     sum_temp[num_unroll] = { 0.0f };\n    for (int i = col; i < ncols;) {\n        for (int j = 0; j < num_unroll; ++j) {\n            if (i < ncols) {\n                temp[j] = x[row * ncols + i];\n            } else {\n                temp[j] = 0;\n            }\n            i += blockDim.x;\n        }\n        for (int j = 0; j < num_unroll; ++j) {\n            sum_temp[j] += temp[j];\n        }\n    }\n    for (int j = 0; j < num_unroll; ++j) {\n        sum += sum_temp[j];\n    }\n\n    // sum up partial sums\n    __shared__ float shared_vals[32];\n    sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);\n\n    if (col != 0) {\n        return;\n    }\n\n    dst[row] = norm ? sum / ncols : sum;\n}\n"
  },
  {
    "path": "src/ggml-cuda/roll.cu",
    "content": "#include \"ggml-cuda/common.cuh\"\n#include \"roll.cuh\"\n\nstatic __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) {\n    if (idx < 0) {\n        return idx + ne;\n    }\n    if (idx >= ne) {\n        return idx - ne;\n    }\n    return idx;\n}\n\nstatic __global__ void roll_f32_cuda(const float * __restrict__ src,\n                                     float * __restrict__ dst,\n                                     const int64_t ne00,\n                                     const int64_t ne01,\n                                     const int64_t ne02,\n                                     const int64_t ne03,\n                                     const int     s0,\n                                     const int     s1,\n                                     const int     s2,\n                                     const int     s3) {\n    const int64_t idx        = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;\n    const int64_t n_elements = ne00 * ne01 * ne02 * ne03;\n\n    if (idx >= n_elements) {\n        return;\n    }\n\n    const int64_t i0 = idx % ne00;\n    const int64_t i1 = (idx / ne00) % ne01;\n    const int64_t i2 = (idx / (ne00 * ne01)) % ne02;\n    const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03;\n\n    const int64_t d0 = wrap_index(i0 - s0, ne00);\n    const int64_t d1 = wrap_index(i1 - s1, ne01);\n    const int64_t d2 = wrap_index(i2 - s2, ne02);\n    const int64_t d3 = wrap_index(i3 - s3, ne03);\n\n    dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] =\n        src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0];\n}\n\nvoid ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    int s0 = dst->op_params[0];\n    int s1 = dst->op_params[1];\n    int s2 = dst->op_params[2];\n    int s3 = dst->op_params[3];\n\n    const ggml_tensor * src0   = dst->src[0];\n    const float *       src0_d = (const float *) dst->src[0]->data;\n    float *         
    dst_d  = (float *) dst->data;\n\n    GGML_TENSOR_UNARY_OP_LOCALS;\n\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst));\n\n    cudaStream_t stream = ctx.stream();\n\n    int64_t sz         = (ne00 * ne01 * ne02 * ne03);\n    int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE;\n\n    roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>(\n        src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3);\n}\n"
  },
  {
    "path": "src/ggml-cuda/roll.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_ROLL_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/rope.cu",
    "content": "#include \"convert.cuh\"\n#include \"ggml-cuda/common.cuh\"\n#include \"ggml.h\"\n#include \"rope.cuh\"\n\nstruct rope_corr_dims {\n    float v[2];\n};\n\n\nstruct mrope_sections {\n    int v[4];\n};\n\nstatic __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {\n    const float y = (i0 / 2 - low) / max(0.001f, high - low);\n    return 1.0f - min(1.0f, max(0.0f, y));\n}\n\n// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn\n// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.\ntemplate<bool forward>\nstatic __device__ void rope_yarn(\n        const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,\n        float mscale, float & cos_theta, float & sin_theta) {\n    // Get n-d rotational scaling corrected for extrapolation\n    float theta_interp = freq_scale * theta_extrap;\n    float theta = theta_interp;\n    if (ext_factor != 0.0f) {\n        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;\n        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;\n\n        // Get n-d magnitude scaling corrected for interpolation\n        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);\n    }\n    cos_theta = cosf(theta) * mscale;\n    sin_theta = sinf(theta) * mscale;\n    if (!forward) {\n        sin_theta *= -1.0f;\n    }\n}\n\ntemplate <bool forward, bool has_ff, typename T, typename D>\nstatic __global__ void rope_norm(const T *            x,\n                                 D *                  dst,\n                                 const int            ne00,\n                                 const int            ne01,\n                                 const int            ne02,\n                                 const int            s01,\n                                 const int            s02,\n                                 const int       
     s03,\n                                 const int            s1,\n                                 const int            s2,\n                                 const int            s3,\n                                 const int            n_dims,\n                                 const int32_t *      pos,\n                                 const float          freq_scale,\n                                 const float          ext_factor,\n                                 const float          attn_factor,\n                                 const rope_corr_dims corr_dims,\n                                 const float          theta_scale,\n                                 const float *        freq_factors,\n                                 const int64_t *      row_indices,\n                                 const int            set_rows_stride) {\n    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);\n\n    if (i0 >= ne00) {\n        return;\n    }\n\n    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;\n\n    const uint32_t i3 = row_dst / (ne01 * ne02);\n    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;\n    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;\n\n    int       idst = i0 + i1 * s1  + i2 * s2  + i3 * s3;\n    const int ix   = i0 + i1 * s01 + i2 * s02 + i3 * s03;\n    // Fusion optimization: ROPE + VIEW + SET_ROWS.\n    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.\n    if (set_rows_stride != 0) {\n        idst = i1 * s1 + i0;\n        idst += row_indices[i2] * set_rows_stride;\n    }\n\n    const auto & store_coaelsced = [&](float x0, float x1) {\n        if constexpr (std::is_same_v<float, D>) {\n            float2 v = make_float2(x0, x1);\n            ggml_cuda_memcpy_1<8>(dst + idst, &v);\n        } else if constexpr (std::is_same_v<half, D>) {\n            half2 v = make_half2(x0, x1);\n            ggml_cuda_memcpy_1<4>(dst + idst, &v);\n        }\n    };\n    if (i0 >= 
n_dims) {\n        store_coaelsced(x[ix + 0], x[ix + 1]);\n        return;\n    }\n\n    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);\n\n    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;\n\n    float cos_theta;\n    float sin_theta;\n\n    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);\n\n    const float x0 = x[ix + 0];\n    const float x1 = x[ix + 1];\n\n    store_coaelsced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);\n}\n\ntemplate <bool forward, bool has_ff, typename T, typename D>\nstatic __global__ void rope_neox(const T *            x,\n                                 D *                  dst,\n                                 const int            ne00,\n                                 const int            ne01,\n                                 const int            ne02,\n                                 const int            s01,\n                                 const int            s02,\n                                 const int            s03,\n                                 const int            s1,\n                                 const int            s2,\n                                 const int            s3,\n                                 const int            n_dims,\n                                 const int32_t *      pos,\n                                 const float          freq_scale,\n                                 const float          ext_factor,\n                                 const float          attn_factor,\n                                 const rope_corr_dims corr_dims,\n                                 const float          theta_scale,\n                                 const float *        freq_factors,\n                                 const int64_t *      row_indices,\n                                 const int            set_rows_stride) {\n    const int i0 = 2*(blockDim.y*blockIdx.y + 
threadIdx.y);\n\n    if (i0 >= ne00) {\n        return;\n    }\n\n    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;\n\n    const uint32_t i3 = row_dst / (ne01 * ne02);\n    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;\n    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;\n\n    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;\n    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;\n\n    // Fusion optimization: ROPE + VIEW + SET_ROWS.\n    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.\n    if (set_rows_stride != 0) {\n        idst = i1 * s1 + i0 / 2;\n        idst += row_indices[i2] * set_rows_stride;\n    }\n\n    if (i0 >= n_dims) {\n        dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);\n        dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);\n\n        return;\n    }\n\n    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);\n\n    const float freq_factor = has_ff ? 
freq_factors[i0/2] : 1.0f;\n\n    float cos_theta;\n    float sin_theta;\n\n    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);\n\n    const float x0 = x[ix + 0];\n    const float x1 = x[ix + n_dims/2];\n\n    dst[idst + 0]          = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);\n    dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);\n}\n\ntemplate <bool forward, bool has_ff, typename T>\nstatic __global__ void rope_multi(const T *            x,\n                                  T *                  dst,\n                                  const int            ne00,\n                                  const int            ne01,\n                                  const int            ne02,\n                                  const int            s01,\n                                  const int            s02,\n                                  const int            s03,\n                                  const int            s1,\n                                  const int            s2,\n                                  const int            s3,\n                                  const int            n_dims,\n                                  const int32_t *      pos,\n                                  const float          freq_scale,\n                                  const float          ext_factor,\n                                  const float          attn_factor,\n                                  const rope_corr_dims corr_dims,\n                                  const float          theta_scale,\n                                  const float *        freq_factors,\n                                  const mrope_sections sections,\n                                  const bool           is_imrope) {\n    const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);\n\n    if (i0 >= ne00) {\n        return;\n    }\n\n    const int row_dst = blockDim.x*blockIdx.x + 
threadIdx.x;\n\n    const uint32_t i3 = row_dst / (ne01 * ne02);\n    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;\n    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;\n\n    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;\n    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;\n\n    if (i0 >= n_dims) {\n        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];\n        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];\n\n        return;\n    }\n\n    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];\n    const int sec_w = sections.v[1] + sections.v[0];\n    const int sector = (i0 / 2) % sect_dims;\n\n    float theta_base = 0.0;\n    if (is_imrope) {\n        if (sector % 3 == 1 && sector < 3 * sections.v[1]) {         // h\n            theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);\n        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {  // w\n            theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);\n        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {  // t\n            theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);\n        } else {\n            theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);\n        }\n    } else {\n        if (sector < sections.v[0]) {\n            theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);\n        } else if (sector >= sections.v[0] && sector < sec_w) {\n            theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);\n        } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {\n            theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);\n        } else if (sector >= sec_w + sections.v[2]) {\n            theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);\n        }\n    }\n\n    const float freq_factor = has_ff ? 
freq_factors[i0/2] : 1.0f;\n\n    float cos_theta;\n    float sin_theta;\n\n    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);\n\n    const float x0 = x[ix + 0];\n    const float x1 = x[ix + n_dims/2];\n\n    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;\n    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;\n}\n\ntemplate <bool forward, bool has_ff, typename T>\nstatic __global__ void rope_vision(const T *            x,\n                                   T *                  dst,\n                                   const int            ne00,\n                                   const int            ne01,\n                                   const int            ne02,\n                                   const int            s01,\n                                   const int            s02,\n                                   const int            s03,\n                                   const int            s1,\n                                   const int            s2,\n                                   const int            s3,\n                                   const int            n_dims,\n                                   const int32_t *      pos,\n                                   const float          freq_scale,\n                                   const float          ext_factor,\n                                   const float          attn_factor,\n                                   const rope_corr_dims corr_dims,\n                                   const float          theta_scale,\n                                   const float *        freq_factors,\n                                   const mrope_sections sections) {\n    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);\n\n    if (i0 >= ne00) {\n        return;\n    }\n\n    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;\n\n    const uint32_t i3 = row_dst / (ne01 * ne02);\n    const uint32_t i2 = (row_dst - 
i3 * ne01 * ne02) / ne01;\n    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;\n\n    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;\n    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;\n\n    const int sect_dims = sections.v[0] + sections.v[1];\n    const int sec_w     = sections.v[1] + sections.v[0];\n    const int sector    = (i0 / 2) % sect_dims;\n\n    float theta_base = 0.0;\n    if (sector < sections.v[0]) {\n        const int p = sector;\n        theta_base  = pos[i2] * powf(theta_scale, p);\n    } else if (sector >= sections.v[0] && sector < sec_w) {\n        const int p = sector - sections.v[0];\n        theta_base  = pos[i2 + ne02] * powf(theta_scale, p);\n    }\n\n    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;\n\n    float cos_theta;\n    float sin_theta;\n\n    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);\n\n    const float x0 = x[ix + 0];\n    const float x1 = x[ix + n_dims];\n\n    dst[idst + 0]      = x0*cos_theta - x1*sin_theta;\n    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;\n}\n\ntemplate <bool forward, typename T, typename D>\nstatic void rope_norm_cuda(const T *            x,\n                           D *                  dst,\n                           const int            ne00,\n                           const int            ne01,\n                           const int            ne02,\n                           const int            s01,\n                           const int            s02,\n                           const int            s03,\n                           const int            s1,\n                           const int            s2,\n                           const int            s3,\n                           const int            n_dims,\n                           const int            nr,\n                           const int32_t *      pos,\n                           const float          
freq_scale,\n                           const float          freq_base,\n                           const float          ext_factor,\n                           const float          attn_factor,\n                           const rope_corr_dims corr_dims,\n                           const float *        freq_factors,\n                           const int64_t *      row_indices,\n                           const int            set_rows_stride,\n                           cudaStream_t         stream) {\n    GGML_ASSERT(ne00 % 2 == 0);\n    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);\n    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);\n    const dim3 block_nums(nr, n_blocks_x, 1);\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    if (freq_factors == nullptr) {\n        rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(\n            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,\n            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);\n    } else {\n        rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(\n            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,\n            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);\n    }\n}\n\ntemplate <bool forward, typename T, typename D>\nstatic void rope_neox_cuda(const T *            x,\n                           D *                  dst,\n                           const int            ne00,\n                           const int            ne01,\n                           const int            ne02,\n                           const int            s01,\n                           const int            s02,\n                           const int            s03,\n                           const int            s1,\n                           const int            
s2,\n                           const int            s3,\n                           const int            n_dims,\n                           const int            nr,\n                           const int32_t *      pos,\n                           const float          freq_scale,\n                           const float          freq_base,\n                           const float          ext_factor,\n                           const float          attn_factor,\n                           const rope_corr_dims corr_dims,\n                           const float *        freq_factors,\n                           const int64_t *      row_indices,\n                           const int            set_rows_stride,\n                           cudaStream_t         stream) {\n    GGML_ASSERT(ne00 % 2 == 0);\n    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);\n    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);\n    const dim3 block_nums(nr, n_blocks_x, 1);\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    if (freq_factors == nullptr) {\n        rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(\n            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,\n            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);\n    } else {\n        rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(\n            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,\n            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);\n    }\n}\n\ntemplate <bool forward, typename T>\nstatic void rope_multi_cuda(const T *            x,\n                            T *                  dst,\n                            const int            ne00,\n                            const int            ne01,\n                            const int            ne02,\n   
                         const int            s01,\n                            const int            s02,\n                            const int            s03,\n                            const int            s1,\n                            const int            s2,\n                            const int            s3,\n                            const int            n_dims,\n                            const int            nr,\n                            const int32_t *      pos,\n                            const float          freq_scale,\n                            const float          freq_base,\n                            const float          ext_factor,\n                            const float          attn_factor,\n                            const rope_corr_dims corr_dims,\n                            const float *        freq_factors,\n                            const mrope_sections sections,\n                            const bool           is_imrope,\n                            cudaStream_t         stream) {\n    GGML_ASSERT(ne00 % 2 == 0);\n    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);\n    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);\n    const dim3 block_nums(nr, n_blocks_x, 1);\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    if (freq_factors == nullptr) {\n        rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(\n            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,\n            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);\n    } else {\n        rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(\n            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,\n            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);\n    }\n}\n\ntemplate <bool forward, typename T>\nstatic void 
rope_vision_cuda(const T *            x,\n                             T *                  dst,\n                             const int            ne00,\n                             const int            ne01,\n                             const int            ne02,\n                             const int            s01,\n                             const int            s02,\n                             const int            s03,\n                             const int            s1,\n                             const int            s2,\n                             const int            s3,\n                             const int            n_dims,\n                             const int            nr,\n                             const int32_t *      pos,\n                             const float          freq_scale,\n                             const float          freq_base,\n                             const float          ext_factor,\n                             const float          attn_factor,\n                             const rope_corr_dims corr_dims,\n                             const float *        freq_factors,\n                             const mrope_sections sections,\n                             cudaStream_t         stream) {\n    GGML_ASSERT(ne00 % 2 == 0);\n    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);\n    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);\n    const dim3 block_nums(nr, n_blocks_x, 1);\n    // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)\n    // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);\n\n    const float theta_scale = powf(freq_base, -2.0f/n_dims);\n\n    if (freq_factors == nullptr) {\n        rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(\n            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,\n            attn_factor, corr_dims, theta_scale, freq_factors, 
sections);\n    } else {\n        rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(\n            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,\n            attn_factor, corr_dims, theta_scale, freq_factors, sections);\n    }\n}\n\ntemplate <bool forward>\nvoid ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,\n                            ggml_tensor *               dst,\n                            const ggml_tensor *         set_rows = nullptr) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n\n    const float * src0_d = (const float *)src0->data;\n    const float * src1_d = (const float *)src1->data;\n\n    void *          dst_d           = dst->data;\n    const int64_t * row_indices     = nullptr;\n    ggml_type       dst_type        = dst->type;\n    int             set_rows_stride = 0;\n\n    if (set_rows != nullptr) {\n        GGML_ASSERT(forward);\n        dst_d           = set_rows->data;\n        row_indices     = (const int64_t *) set_rows->src[1]->data;\n        dst_type        = set_rows->type;\n        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);\n    }\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);\n    // When not fused, src0 and dst types must match\n    // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16\n    GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));\n\n    const int64_t ne00 = src0->ne[0]; // head dims\n    const int64_t ne01 = src0->ne[1]; // num heads\n    const int64_t ne02 = src0->ne[2]; // num heads\n    const int64_t nr = ggml_nrows(src0);\n\n    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);\n    const size_t s02 = 
src0->nb[2] / ggml_type_size(src0->type);\n    const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);\n\n    const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);\n    const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);\n    const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);\n\n    //const int n_past     = ((int32_t *) dst->op_params)[0];\n    const int n_dims     = ((int32_t *) dst->op_params)[1];\n    const int mode       = ((int32_t *) dst->op_params)[2];\n    //const int n_ctx      = ((int32_t *) dst->op_params)[3];\n    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];\n    mrope_sections sections;\n\n    // RoPE alteration for extended context\n    float freq_base;\n    float freq_scale;\n    float ext_factor;\n    float attn_factor;\n    float beta_fast;\n    float beta_slow;\n\n    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));\n    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));\n    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));\n    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));\n    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));\n    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));\n    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);\n\n    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;\n    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;\n    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;\n    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n\n    if (is_mrope) {\n        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);\n    }\n\n    if (is_vision) {\n        GGML_ASSERT(n_dims == ne00/2);\n    }\n\n    const int32_t * pos = (const int32_t *) src1_d;\n\n    const float * freq_factors = nullptr;\n    if (src2 != nullptr) {\n        freq_factors = (const float *) src2->data;\n    }\n\n    rope_corr_dims corr_dims;\n    
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);\n\n    // compute\n    if (is_neox) {\n        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {\n            rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,\n                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,\n                                                  set_rows_stride, stream);\n        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {\n            rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,\n                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,\n                                                 set_rows_stride, stream);\n        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {\n            rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,\n                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,\n                                                set_rows_stride, stream);\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    } else if (is_mrope && !is_vision) {\n        if (src0->type == GGML_TYPE_F32) {\n            rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,\n                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,\n          
                           corr_dims, freq_factors, sections, is_imrope, stream);\n        } else if (src0->type == GGML_TYPE_F16) {\n            rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,\n                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,\n                                     corr_dims, freq_factors, sections, is_imrope, stream);\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    } else if (is_vision) {\n        if (src0->type == GGML_TYPE_F32) {\n            rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,\n                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,\n                                      corr_dims, freq_factors, sections, stream);\n        } else if (src0->type == GGML_TYPE_F16) {\n            rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,\n                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,\n                                      corr_dims, freq_factors, sections, stream);\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    } else {\n        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {\n            rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,\n                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,\n                                                  set_rows_stride, stream);\n        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {\n            rope_norm_cuda<forward, float, half>((const float *) src0_d, 
(half *) dst_d, ne00, ne01, ne02, s01, s02,\n                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,\n                                                 set_rows_stride, stream);\n        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {\n            rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,\n                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,\n                                                set_rows_stride, stream);\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    }\n}\n\nvoid ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_rope_impl<true>(ctx, dst);\n}\n\nvoid ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_rope_impl<false>(ctx, dst);\n}\n\nvoid ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {\n    ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);\n}\n"
  },
  {
    "path": "src/ggml-cuda/rope.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_ROPE_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);\n"
  },
  {
    "path": "src/ggml-cuda/scale.cu",
    "content": "#include \"scale.cuh\"\n\n#define MAX_GRIDDIM_X 0x7FFFFFFF\n\nstatic __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {\n    int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;\n    int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;\n\n    for (int64_t i = tid; i < nelements; i += stride) {\n        dst[i] = scale * x[i] + bias;\n    }\n}\n\nstatic void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {\n    const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;\n    scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);\n}\n\nvoid ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    float scale;\n    float bias;\n    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&bias,  (float *) dst->op_params + 1, sizeof(float));\n\n    scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/scale.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_SCALE_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/set-rows.cu",
    "content": "#include \"set-rows.cuh\"\n#include \"cpy-utils.cuh\"\n\ntypedef void (*set_rows_kernel_t)(const char * src, char * dst);\n\n// Generic quantized set_rows kernel template\ntemplate <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>\nstatic __global__ void k_set_rows_quant(const float * __restrict__ src0,\n                                        const idx_t * __restrict__ src1,\n                                        block_type * __restrict__ dst,\n                                        const int64_t ne_total,\n                                        const int64_t ne10,\n                                        const int64_t ne11,\n                                        const int64_t ne12,\n                                        const int64_t ne13,\n                                        const int64_t s01,\n                                        const int64_t s02,\n                                        const int64_t s03,\n                                        const int64_t s10,\n                                        const int64_t s11,\n                                        const int64_t s12,\n                                        const int64_t s1,\n                                        const int64_t s2,\n                                        const int64_t s3,\n                                        const uint3   ne00,\n                                        const uint3   ne01,\n                                        const uint3   ne02,\n                                        const uint3   ne11_fd,\n                                        const uint3   ne12_fd) {\n    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;\n\n    if (i >= ne_total) {\n        return;\n    }\n\n    const int64_t i_base = i * qk;\n    uint32_t      tmp    = (uint32_t) i_base;\n    uint2         div_mod;\n\n    div_mod           = fast_div_modulo(tmp, ne00);\n    const int64_t i00 = 
div_mod.y;\n    tmp               = div_mod.x;\n\n    div_mod           = fast_div_modulo(tmp, ne01);\n    const int64_t i01 = div_mod.y;\n    tmp               = div_mod.x;\n\n    div_mod           = fast_div_modulo(tmp, ne02);\n    const int64_t i02 = div_mod.y;\n    const int64_t i03 = div_mod.x;\n\n    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);\n    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);\n    const int64_t i10 = i01;\n\n    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);\n\n    const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;\n    block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);\n\n    const float * src_block = src0_row + i00;\n    block_type * dst_block = dst_row_ptr + i00 / qk;\n\n    quantize_func(src_block, dst_block);\n\n    GGML_UNUSED(ne10);\n    GGML_UNUSED(ne11);\n    GGML_UNUSED(ne12);\n    GGML_UNUSED(ne13);\n}\n\n// Template dispatch function for quantized set_rows\ntemplate<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>\nstatic void set_rows_cuda_quant(\n        const float * src0_d, const idx_t * src1_d, block_type * dst_d,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,\n        const size_t nb01, const size_t nb02, const size_t nb03,\n        const size_t nb10, const size_t nb11, const size_t nb12,\n        const size_t nb1, const size_t nb2, const size_t nb3,\n        cudaStream_t stream) {\n\n    GGML_ASSERT(ne00 % qk == 0);\n    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;\n    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;\n    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);\n    const dim3 grid_size(num_blocks);\n\n    const int64_t s01 = nb01/sizeof(float);\n    const int64_t s02 = nb02/sizeof(float);\n    
const int64_t s03 = nb03/sizeof(float);\n    const int64_t s10 = nb10/sizeof(idx_t);\n    const int64_t s11 = nb11/sizeof(idx_t);\n    const int64_t s12 = nb12/sizeof(idx_t);\n    const int64_t s1  = nb1;\n    const int64_t s2  = nb2;\n    const int64_t s3  = nb3;\n\n    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {\n        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);\n        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);\n        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);\n        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);\n        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);\n\n        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(\n            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,\n            ne01_fd, ne02_fd, ne11_fd, ne12_fd);\n    }\n}\n\ntemplate <typename src_t, typename idx_t, typename dst_t>\nstatic __global__ void k_set_rows(const src_t * __restrict__ src0,\n                                  const idx_t * __restrict__ src1,\n                                  dst_t * __restrict__ dst,\n                                  const int64_t ne_total,\n                                  const int64_t ne10,\n                                  const int64_t ne11,\n                                  const int64_t ne12,\n                                  const int64_t ne13,\n                                  const int64_t s01,\n                                  const int64_t s02,\n                                  const int64_t s03,\n                                  const int64_t s10,\n                                  const int64_t s11,\n                                  const int64_t s12,\n                                  const int64_t s1,\n                                  const int64_t s2,\n                                  const int64_t 
s3,\n                                  const uint3   ne00,\n                                  const uint3   ne01,\n                                  const uint3   ne02,\n                                  const uint3   ne11_fd,\n                                  const uint3   ne12_fd) {\n    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;\n\n    if (i >= ne_total) {\n        return;\n    }\n\n    uint32_t tmp = (uint32_t) i;\n    uint2    div_mod;\n\n    div_mod           = fast_div_modulo(tmp, ne00);\n    const int64_t i00 = div_mod.y;\n    tmp               = div_mod.x;\n\n    div_mod           = fast_div_modulo(tmp, ne01);\n    const int64_t i01 = div_mod.y;\n    tmp               = div_mod.x;\n\n    div_mod           = fast_div_modulo(tmp, ne02);\n    const int64_t i02 = div_mod.y;\n    const int64_t i03 = div_mod.x;\n\n    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);\n    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);\n    const int64_t i10 = i01;\n\n    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);\n\n    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;\n    dst_t * dst_row_ptr    = dst + dst_row*s1 + i02*s2 + i03*s3;\n\n    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);\n\n    GGML_UNUSED(ne10);\n    GGML_UNUSED(ne11);\n    GGML_UNUSED(ne12);\n    GGML_UNUSED(ne13);\n}\n\ntemplate<typename src_t, typename idx_t, typename dst_t>\nstatic void set_rows_cuda(\n        const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,\n        const size_t nb01, const size_t nb02, const size_t nb03,\n        const size_t nb10, const size_t nb11, const size_t nb12,\n        const size_t nb1, const size_t nb2, const size_t nb3,\n        cudaStream_t stream) {\n\n    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;\n    
const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;\n    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);\n    const dim3 grid_size(num_blocks);\n\n\n    const int64_t s01 = nb01/sizeof(src_t);\n    const int64_t s02 = nb02/sizeof(src_t);\n    const int64_t s03 = nb03/sizeof(src_t);\n    const int64_t s10 = nb10/sizeof(idx_t);\n    const int64_t s11 = nb11/sizeof(idx_t);\n    const int64_t s12 = nb12/sizeof(idx_t);\n    const int64_t s1  = nb1/sizeof(dst_t);\n    const int64_t s2  = nb2/sizeof(dst_t);\n    const int64_t s3  = nb3/sizeof(dst_t);\n\n    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {\n        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);\n        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);\n        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);\n        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);\n        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);\n\n        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,\n                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,\n                                                         ne11_fd, ne12_fd);\n    }\n}\n\ntemplate<typename src_t, typename idx_t>\nstatic void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const src_t * src0_d = (const src_t *)src0->data;\n    const idx_t * src1_d = (const idx_t *)src1->data;\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    cudaStream_t stream = ctx.stream();\n\n\n    if (dst->type == GGML_TYPE_F32) {\n        set_rows_cuda(\n            src0_d, src1_d, (float*)dst->data,\n            ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, 
nb3,\n            stream\n        );\n    } else if (dst->type == GGML_TYPE_F16) {\n        set_rows_cuda(\n            src0_d, src1_d, (half*)dst->data,\n            ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, nb3,\n            stream\n        );\n    } else if (dst->type == GGML_TYPE_BF16) {\n        set_rows_cuda(\n            src0_d, src1_d, (nv_bfloat16*)dst->data,\n            ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, nb3,\n            stream\n        );\n    } else if (dst->type == GGML_TYPE_Q4_0) {\n        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(\n            src0_d, src1_d, (block_q4_0*)dst->data,\n            ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, nb3,\n            stream\n        );\n    } else if (dst->type == GGML_TYPE_Q4_1) {\n        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(\n            src0_d, src1_d, (block_q4_1*)dst->data,\n            ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, nb3,\n            stream\n        );\n    } else if (dst->type == GGML_TYPE_Q5_0) {\n        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(\n            src0_d, src1_d, (block_q5_0*)dst->data,\n            ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, nb3,\n            stream\n        );\n    } else if (dst->type == GGML_TYPE_Q5_1) {\n        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(\n            src0_d, src1_d, (block_q5_1*)dst->data,\n            
ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, nb3,\n            stream\n        );\n    } else if (dst->type == GGML_TYPE_Q8_0) {\n        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(\n            src0_d, src1_d, (block_q8_0*)dst->data,\n            ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, nb3,\n            stream\n        );\n    } else if (dst->type == GGML_TYPE_IQ4_NL) {\n        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(\n            src0_d, src1_d, (block_iq4_nl*)dst->data,\n            ne00, ne01, ne02, ne03,\n            ne10, ne11, ne12, ne13,\n            nb01, nb02, nb03,\n            nb10, nb11, nb12,\n            nb1, nb2, nb3,\n            stream\n        );\n    } else {\n        GGML_ABORT(\"unsupported type %s\", ggml_type_name(dst->type));\n    }\n}\n\n\nvoid ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);\n\n    if (src1->type == GGML_TYPE_I64) {\n        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);\n    } else {\n        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/set-rows.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n#define CUDA_SET_ROWS_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/set.cu",
    "content": "#include \"set.cuh\"\n#include \"cpy.cuh\"\n\nvoid ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));\n    GGML_ASSERT(src1->type == src0->type);\n    GGML_ASSERT(dst ->type == src0->type);\n\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n\n    const size_t nb1    = ((int32_t *) dst->op_params)[0];\n    const size_t nb2    = ((int32_t *) dst->op_params)[1];\n    const size_t nb3    = ((int32_t *) dst->op_params)[2];\n    const size_t offset = ((int32_t *) dst->op_params)[3];\n    const bool   inplace= (bool)     ((int32_t *) dst->op_params)[4];\n\n    if (!inplace) {\n        ggml_cuda_cpy(ctx, src0, dst);\n    }\n\n    ggml_tensor dst_view = *dst;\n    dst_view.data  = (void *)((char *)dst->data + offset);\n    dst_view.ne[0] = src1->ne[0];\n    dst_view.ne[1] = src1->ne[1];\n    dst_view.ne[2] = src1->ne[2];\n    dst_view.ne[3] = src1->ne[3];\n\n    dst_view.nb[0] = ggml_element_size(dst);\n    dst_view.nb[1] = nb1;\n    dst_view.nb[2] = nb2;\n    dst_view.nb[3] = nb3;\n\n    ggml_cuda_cpy(ctx, src1, &dst_view);\n}\n"
  },
  {
    "path": "src/ggml-cuda/set.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n#define CUDA_SET_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/softcap.cu",
    "content": "#include \"softcap.cuh\"\n\nstatic __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {\n    const int i = blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    dst[i] = tanhf(scale * x[i]) * softcap;\n}\n\nstatic void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {\n    const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;\n    softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);\n}\n\n// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE\nvoid ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {\n    const ggml_tensor * src0 = src->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    float scale;\n    float softcap;\n    memcpy(&scale,   (float *) src->op_params + 0, sizeof(float));\n    memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));\n\n    softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/softcap.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_SOFTCAP_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);\n"
  },
  {
    "path": "src/ggml-cuda/softmax.cu",
    "content": "#include \"common.cuh\"\n#include \"ggml.h\"\n#include \"softmax.cuh\"\n\n#ifdef GGML_USE_HIP\n#include <hip/hip_cooperative_groups.h>\n#else\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#endif // GGML_USE_HIP\n\n#include <cstdint>\n#include <utility>\n\ntemplate <typename T>\nstatic __device__ __forceinline__ float t2f32(T val) {\n    return (float) val;\n}\n\ntemplate <>\n__device__ float __forceinline__ t2f32<half>(half val) {\n    return __half2float(val);\n}\n\nstruct soft_max_params {\n\n    int64_t nheads;\n    uint32_t n_head_log2;\n    int64_t ncols;\n    int64_t nrows_x;\n    int64_t nrows_y;\n    int64_t ne00;\n    int64_t ne01;\n    int64_t ne02;\n    int64_t ne03;\n    int64_t nb11;\n    int64_t nb12;\n    int64_t nb13;\n\n    int64_t ne12;\n    int64_t ne13;\n    float scale;\n    float max_bias;\n    float m0;\n    float m1;\n};\n\n// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.\n// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.\n#ifdef __clang__\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wpass-failed\"\n#endif // __clang__\ntemplate <bool use_shared, int ncols_template, int block_size_template, typename T>\nstatic __global__ void soft_max_f32(\n        const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {\n    const int ncols = ncols_template == 0 ? 
p.ncols : ncols_template;\n\n    const int tid  = threadIdx.x;\n\n    const int64_t i03 = blockIdx.z;\n    const int64_t i02 = blockIdx.y;\n    const int64_t i01 = blockIdx.x;\n\n    //TODO: noncontigous inputs/outputs\n    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;\n\n    const int64_t i11 = i01;\n    const int64_t i12 = i02 % p.ne12;\n    const int64_t i13 = i03 % p.ne13;\n\n    x    += int64_t(rowx)*ncols;\n    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);\n    dst  += int64_t(rowx)*ncols;\n\n    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;\n\n    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);\n\n    extern __shared__ float data_soft_max_f32[];\n    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication\n    // shared memory buffer to cache values between iterations:\n    float * vals = use_shared ? buf_iw + WARP_SIZE : dst;\n\n    float max_val = sinks ? sinks[i02] : -INFINITY;\n\n#pragma unroll\n    for (int col0 = 0; col0 < ncols; col0 += block_size) {\n        const int col = col0 + tid;\n\n        if (ncols_template == 0 && col >= ncols) {\n            break;\n        }\n\n        const float val = x[col]*p.scale + (mask ? 
slope*t2f32(mask[col]) : 0.0f);\n\n        vals[col] = val;\n        max_val = max(max_val, val);\n    }\n\n    // find the max value in the block\n    max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);\n\n    float tmp = 0.0f; // partial sum\n\n#pragma unroll\n    for (int col0 = 0; col0 < ncols; col0 += block_size) {\n        const int col = col0 + tid;\n\n        if (ncols_template == 0 && col >= ncols) {\n            break;\n        }\n\n        const float val = expf(vals[col] - max_val);\n        tmp += val;\n        vals[col] = val;\n    }\n\n    // find the sum of exps in the block\n    tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);\n\n    if (sinks) {\n        tmp += expf(sinks[i02] - max_val);\n    }\n\n    const float inv_sum = 1.0f / tmp;\n\n#pragma unroll\n    for (int col0 = 0; col0 < ncols; col0 += block_size) {\n        const int col = col0 + tid;\n\n        if (ncols_template == 0 && col >= ncols) {\n            return;\n        }\n\n        dst[col] = vals[col] * inv_sum;\n    }\n}\n\n// TODO: Template to allow keeping ncols in registers if they fit\nstatic __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,\n                                                                float * __restrict__ dst,\n                                                                float * __restrict__ tmp_maxs,\n                                                                float * __restrict__ tmp_sums,\n                                                                const soft_max_params p) {\n    namespace cg = cooperative_groups;\n\n    const cg::grid_group g = cg::this_grid();\n\n    const int tid               = threadIdx.x;\n    const int col_start         = blockIdx.x * blockDim.x + tid;\n    const int n_elem_per_thread = 4;\n\n    float     local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };\n    float     local_max            
         = -INFINITY;\n    const int step_size                     = gridDim.x * blockDim.x;\n    __shared__ float shared_vals[32];\n\n    // Compute thread-local max\n    for (int col = col_start; col < p.ncols;) {\n#pragma unroll\n        for (int i = 0; i < n_elem_per_thread; i++) {\n            const int idx = col + i * step_size;\n            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;\n        }\n#pragma unroll\n        for (int i = 0; i < n_elem_per_thread; i++) {\n            local_max = fmaxf(local_max, local_vals[i]);\n        }\n        col += step_size * n_elem_per_thread;\n    }\n\n    // Compute CTA-level max\n    local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);\n\n    // Store CTA-level max to GMEM\n    if (tid == 0) {\n        tmp_maxs[blockIdx.x] = local_max;\n    }\n    g.sync();\n\n    // Compute compute global max from CTA-level maxs\n    assert(gridDim.x < blockDim.x);  // currently we only support this case\n    if (tid < gridDim.x) {\n        local_max = tmp_maxs[tid];\n    } else {\n        local_max = -INFINITY;\n    }\n    local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);\n\n    // Compute softmax dividends, accumulate divisor\n    float tmp_expf = 0.0f;\n    for (int col = col_start; col < p.ncols;) {\n#pragma unroll\n        for (int i = 0; i < n_elem_per_thread; i++) {\n            const int idx = col + i * step_size;\n            local_vals[i] = idx < p.ncols ? 
x[idx] : -INFINITY;\n        }\n#pragma unroll\n        for (int i = 0; i < n_elem_per_thread; i++) {\n            const int idx = col + i * step_size;\n            if (idx < p.ncols) {\n                const float tmp = expf(local_vals[i] - local_max);\n                tmp_expf += tmp;\n                dst[idx] = tmp;\n            }\n        }\n        col += step_size * n_elem_per_thread;\n    }\n\n    // Reduce divisor within CTA\n    tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);\n\n    // Store CTA-level sum to GMEM\n    if (tid == 0) {\n        tmp_sums[blockIdx.x] = tmp_expf;\n    }\n    g.sync();\n\n    // Compute global sum from CTA-level sums\n    if (tid < gridDim.x) {\n        tmp_expf = tmp_sums[tid];\n    } else {\n        tmp_expf = 0.0f;\n    }\n    tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);\n\n    // Divide dividend by global sum + store data\n    for (int col = col_start; col < p.ncols;) {\n#pragma unroll\n        for (int i = 0; i < n_elem_per_thread; i++) {\n            const int idx = col + i * step_size;\n            local_vals[i] = idx < p.ncols ? 
dst[idx] : -INFINITY;\n        }\n#pragma unroll\n        for (int i = 0; i < n_elem_per_thread; i++) {\n            const int idx = col + i * step_size;\n            if (idx < p.ncols) {\n                dst[idx] = local_vals[i] / tmp_expf;\n            }\n        }\n        col += step_size * n_elem_per_thread;\n    }\n}\n\n#ifdef __clang__\n#pragma clang diagnostic pop\n#endif // __clang__\n\nstatic __global__ void soft_max_back_f32(\n        const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {\n    const int tid  = threadIdx.x;\n    const int rowx = blockIdx.x;\n\n    grad += int64_t(rowx)*ncols;\n    dstf += int64_t(rowx)*ncols;\n    dst  += int64_t(rowx)*ncols;\n\n    float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients\n\n    for (int col = tid; col < ncols; col += WARP_SIZE) {\n        dgf_dot += dstf[col]*grad[col];\n    }\n\n    dgf_dot = warp_reduce_sum(dgf_dot);\n\n    for (int col = tid; col < ncols; col += WARP_SIZE) {\n        dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];\n    }\n}\n\ntemplate<int... Ns, typename T>\nstatic void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,\n                             const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)\n{\n    const int id       = ggml_cuda_get_device();\n    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;\n\n    auto launch_kernel = [=](auto I) -> bool {\n        constexpr int ncols = decltype(I)::value;\n        constexpr int block = (ncols > 1024 ? 
1024 : ncols);\n\n        if (p.ncols == ncols) {\n            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);\n            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>\n                (x, mask, sinks, dst, p);\n            return true;\n        }\n        return false;\n    };\n\n    // unary fold over launch_kernel\n    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {\n        return;\n    }\n\n    //default case\n    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);\n    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);\n}\n\n__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,\n                                                     float * __restrict__ dst,\n                                                     float * __restrict__ tmp_maxs,\n                                                     float * __restrict__ tmp_sums,\n                                                     const soft_max_params p)\n// We loop over all instead of parallelizing across gridDim.y as cooperative groups\n// currently only support synchronizing the complete grid if not launched as a cluster group\n// (which requires CC > 9.0)\n// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization\n// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group\n{\n    for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {\n        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,\n                                                 tmp_sums, p);\n    }\n}\n\ntemplate <typename T>\nstatic void soft_max_f32_cuda(const float *                                x,\n                              const T *                
                    mask,\n                              const float *                                sinks,\n                              float *                                      dst,\n                              const soft_max_params &                      params,\n                              cudaStream_t                                 stream,\n                              [[maybe_unused]] ggml_backend_cuda_context & ctx) {\n    int nth = WARP_SIZE;\n    const int64_t ncols_x = params.ncols;\n\n    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;\n    const dim3 block_dims(nth,     1, 1);\n    const dim3 block_nums(params.ne01, params.ne02, params.ne03);\n    const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);\n    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, \"These values need to be adjusted.\");\n\n\n    const int id       = ggml_cuda_get_device();\n    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;\n\n\n    if (nbytes_shared <= smpbo) {\n        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);\n    } else {\n        // Parallelize across SMs for top-p/dist-sampling\n        // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and\n        // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.\n        if (ggml_cuda_info().devices[id].supports_cooperative_launch &&\n            ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&\n            params.scale == 1.0f && params.max_bias == 0.0f) {\n            ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));\n            ggml_cuda_pool_alloc<float> 
tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));\n\n            void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,\n                                     (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };\n            CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,\n                                                   dim3(ggml_cuda_info().devices[id].nsm, 1, 1),\n                                                   dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));\n        } else {\n            const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);\n            soft_max_f32<false, 0, 0>\n                <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);\n        }\n    }\n}\n\nstatic void soft_max_back_f32_cuda(\n        const float * grad, const float * dstf, float * dst,\n        const int ncols, const int nrows, const float scale, cudaStream_t stream) {\n    const dim3 block_dims(WARP_SIZE, 1, 1);\n    const dim3 block_nums(nrows,     1, 1);\n\n    soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);\n}\n\nvoid ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n\n    const float * src0_d = (const float *) src0->data;\n    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;\n    const void  * src2_d = src2 ? 
(const void *) src2->data : nullptr;\n    float       *  dst_d = (float *) dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional\n\n    const int64_t nrows_x = ggml_nrows(src0);\n    const int64_t nrows_y = src0->ne[1];\n\n    const int64_t ne00 = src0->ne[0];\n\n    float scale    = 1.0f;\n    float max_bias = 0.0f;\n\n    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));\n\n    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);\n\n    const int64_t nb11 = src1 ? src1->nb[1] : 1;\n    const int64_t nb12 = src1 ? src1->nb[2] : 1;\n    const int64_t nb13 = src1 ? src1->nb[3] : 1;\n\n    const int64_t ne12 = src1 ? src1->ne[2] : 1;\n    const int64_t ne13 = src1 ? src1->ne[3] : 1;\n\n    const uint32_t n_head      = src0->ne[2];\n    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n\n    soft_max_params params = {};\n    params.nheads = src0->ne[2];\n    params.n_head_log2 = n_head_log2;\n    params.ncols = ne00;\n    params.nrows_x = nrows_x;\n    params.nrows_y = nrows_y;\n    params.ne00 = src0->ne[0];\n    params.ne01 = src0->ne[1];\n    params.ne02 = src0->ne[2];\n    params.ne03 = src0->ne[3];\n    params.nb11 = nb11;\n    params.nb12 = nb12;\n    params.nb13 = nb13;\n    params.ne12 = ne12;\n    params.ne13 = ne13;\n    params.scale = scale;\n    params.max_bias = max_bias;\n    params.m0 = m0;\n    params.m1 = m1;\n\n    if (use_f16) {\n        soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);\n    } else {\n        
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);\n    }\n}\n\nvoid ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0]; // grad\n    const ggml_tensor * src1 = dst->src[1]; // forward pass output\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float       * dst_d  = (float       *) dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    const int64_t ncols = src0->ne[0];\n    const int64_t nrows = ggml_nrows(src0);\n\n    float scale    = 1.0f;\n    float max_bias = 0.0f;\n\n    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));\n\n    GGML_ASSERT(max_bias == 0.0f);\n\n    soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/softmax.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_SOFT_MAX_BLOCK_SIZE 1024\n\nvoid ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/solve_tri.cu",
    "content": "#include \"common.cuh\"\n#include \"ggml.h\"\n#include \"solve_tri.cuh\"\n\n#define MAX_N_FAST 64\n#define MAX_K_FAST 32\n\nstatic __global__ void get_batch_pointers(const float *  A,\n                                          float *        X,\n                                          const float ** A_ptrs,\n                                          float **       X_ptrs,\n                                          int64_t        ne02,\n                                          int64_t        total_batches,\n                                          size_t         s02,\n                                          size_t         s03,\n                                          size_t         s2,\n                                          size_t         s3) {\n    const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx >= total_batches) {\n        return;\n    }\n\n    const int64_t i3 = idx / ne02;\n    const int64_t i2 = idx % ne02;\n\n    A_ptrs[idx] = A + i3 * s03 + i2 * s02;\n    X_ptrs[idx] = X + i3 * s3 + i2 * s2;\n}\n\nstatic void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,\n                                 const float *               A,\n                                 const float *               B,\n                                 float *                     X,\n                                 int                         n,\n                                 int                         k,\n                                 int64_t                     ne02,\n                                 int64_t                     ne03,\n                                 size_t                      s02,\n                                 size_t                      s03,\n                                 size_t                      s12,\n                                 size_t                      s13,\n                                 size_t                      s2,\n                                 size_t                      s3,\n     
                            cudaStream_t                stream) {\n    const float   alpha         = 1.0f;\n    const int64_t total_batches = ne02 * ne03;\n    if (total_batches == 0) {\n        return;\n    }\n\n    // Bulk copy B -> X (contiguous tensors)\n    if (X != B) {\n        const int64_t total_elements_BX = n * k * total_batches;\n        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));\n    }\n\n    const int id = ggml_cuda_get_device();\n\n    ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);\n    ggml_cuda_pool_alloc<float *>       X_ptrs_alloc(ctx.pool(id), total_batches);\n\n    const float ** A_ptrs_dev = A_ptrs_alloc.get();\n    float **       X_ptrs_dev = X_ptrs_alloc.get();\n\n    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,\n                                                                        total_batches, s02, s03, s2, s3);\n\n    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));\n\n    // Yes, this is necessary, without this we get RMSE errors\n    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));\n    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,\n                                    CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));\n\n    // revert to standard mode from common.cuh\n    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));\n\n    GGML_UNUSED_VARS(s12, s13);\n}\n\n// ======================\n// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction\n// ======================\n// When ncols_template == 0 the bounds for the loops in this function are not\n// known and can't be unrolled. 
As we want to keep pragma unroll for all other\n// cases we suppress the clang transformation warning here.\n#ifdef __clang__\n#    pragma clang diagnostic push\n#    pragma clang diagnostic ignored \"-Wpass-failed\"\n#endif  // __clang__\ntemplate <int n_template, int k_template>\nstatic __global__ void solve_tri_f32_fast(const float * __restrict__ A,\n                                          const float * __restrict__ B,\n                                          float * __restrict__ X,\n                                          const uint3  ne02,\n                                          const size_t nb02,\n                                          const size_t nb03,\n                                          const size_t nb12,\n                                          const size_t nb13,\n                                          const size_t nb2,\n                                          const size_t nb3,\n                                          const int    n_arg,\n                                          const int    k_arg) {\n    const int n = n_template == 0 ? n_arg : n_template;\n    const int k = k_template == 0 ? 
k_arg : k_template;\n\n    const int batch_idx = blockIdx.x;\n    const int lane      = threadIdx.x;\n    const int col_idx   = threadIdx.y;\n\n    if (col_idx >= k) {\n        return;\n    }\n\n    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);\n    const int64_t i02     = i02_i03.y;\n    const int64_t i03     = i02_i03.x;\n\n    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);\n    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);\n    float *             X_batch = (float *) (X + i02 * nb2 + i03 * nb3);\n\n    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];\n\n    const int offset = threadIdx.x + threadIdx.y * blockDim.x;\n\n#pragma unroll\n    for (int i = 0; i < n * n; i += k * WARP_SIZE) {\n        const int i0 = i + offset;\n        if (i0 < n * n) {\n            sA[i0] = A_batch[i0];\n        }\n    }\n\n    __syncthreads();\n\n    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;\n    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;\n\n    const int half      = WARP_SIZE;\n    const int nrows_low = (n < half) ? 
n : half;\n\n#pragma unroll\n    for (int row = 0; row < nrows_low; ++row) {\n        float sum = 0.0f;\n        if (lane < row) {\n            sum += sA[row * n + lane] * x_low;\n        }\n        sum = warp_reduce_sum(sum);\n\n        if (lane == row) {\n            x_low = (x_low - sum) / sA[row * n + row];\n        }\n    }\n\n#pragma unroll\n    for (int row = half; row < n; ++row) {\n        float     sum = sA[row * n + lane] * x_low;\n        const int j   = half + lane;\n        if (j < row) {\n            sum += sA[row * n + j] * x_high;\n        }\n        sum = warp_reduce_sum(sum);\n\n        if (lane == row - half) {\n            x_high = (x_high - sum) / sA[row * n + row];\n        }\n    }\n\n#pragma unroll\n    for (int rr = 0; rr < 2; ++rr) {\n        const int row = rr * WARP_SIZE + lane;\n        if (row < n) {\n            const float val            = (row < half) ? x_low : x_high;\n            X_batch[row * k + col_idx] = val;\n        }\n    }\n}\n#ifdef __clang__\n#    pragma clang diagnostic pop\n#endif  // __clang__\n\nstatic void solve_tri_f32_cuda(const float * A,\n                               const float * B,\n                               float *       X,\n                               int           n,\n                               int           k,\n                               int64_t       ne02,\n                               int64_t       ne03,\n                               size_t        nb02,\n                               size_t        nb03,\n                               size_t        nb12,\n                               size_t        nb13,\n                               size_t        nb2,\n                               size_t        nb3,\n                               cudaStream_t  stream) {\n    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);\n    dim3        threads(WARP_SIZE, k);\n    dim3        grid(ne02 * ne03);\n    if (n == 64) {\n        switch (k) {\n            case 32:\n                
solve_tri_f32_fast<64, 32>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 16:\n                solve_tri_f32_fast<64, 16>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 14:\n                solve_tri_f32_fast<64, 14>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 12:\n                solve_tri_f32_fast<64, 12>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 10:\n                solve_tri_f32_fast<64, 10>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 8:\n                solve_tri_f32_fast<64, 8>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 6:\n                solve_tri_f32_fast<64, 6>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 4:\n                solve_tri_f32_fast<64, 4>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 2:\n                solve_tri_f32_fast<64, 2>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            case 1:\n                solve_tri_f32_fast<64, 1>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);\n                break;\n            default:\n                
solve_tri_f32_fast<0, 0>\n                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);\n        }\n    } else {  // run general case\n        solve_tri_f32_fast<0, 0>\n            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);\n    }\n}\n\nvoid ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];  // A (n×n, lower triangular)\n    const ggml_tensor * src1 = dst->src[1];  // B (n×k)\n\n    ggml_is_contiguous(src0);\n    ggml_is_contiguous(src1);\n\n    const int64_t n    = src0->ne[0];\n    const int64_t k    = src1->ne[0];\n    const int64_t ne02 = src0->ne[2];\n    const int64_t ne03 = src0->ne[3];\n\n    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {\n        solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,\n                           src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),\n                           src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),\n                           dst->nb[3] / sizeof(float), ctx.stream());\n    } else {\n        solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,\n                             ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),\n                             src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),\n                             dst->nb[3] / sizeof(float), ctx.stream());\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/solve_tri.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/ssm-conv.cu",
    "content": "#include \"ssm-conv.cuh\"\n#include \"unary.cuh\"\n\ntemplate <bool apply_silu, size_t split_d_inner, size_t d_conv>\nstatic __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,\n                                    const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,\n                                    float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,\n                                    const int64_t n_t) {\n    GGML_UNUSED(src0_nb0);\n    const int tid  = threadIdx.x;\n    const int bidx = blockIdx.x;\n    const int bidy = blockIdx.y;\n\n    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);\n    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);\n    float *       y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);\n\n    const int stride_x = src0_nb1 / sizeof(float);\n    const int stride_w = src1_nb1 / sizeof(float);\n    const int stride_y = dst_nb1 / sizeof(float);\n\n    float x[d_conv] = { 0.0f };\n    float w[d_conv] = { 0.0f };\n\n#pragma unroll\n    for (size_t j = 0; j < d_conv; j++) {\n        w[j] = w_block[tid * stride_w + j];\n    }\n\n    for (int64_t i = 0; i < n_t; i++) {\n        float sumf = 0.0f;\n\n        if (i == 0) {\n            for (size_t j = 0; j < d_conv; j++) {\n                x[j] = x_block[tid * stride_x + j];\n            }\n        } else {\n            x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];\n        }\n\n#pragma unroll\n        for (size_t j = 0; j < d_conv; j++) {\n            sumf += x[(i + j) % d_conv] * w[j];\n        }\n        y_block[i * stride_y + tid] = apply_silu ? 
ggml_cuda_op_silu_single(sumf) : sumf;\n    }\n}\n\ntemplate <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>\nstatic __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,\n                                               const int src0_nb0, const int src0_nb1, const int src0_nb2,\n                                               const int src1_nb1, float * __restrict__ dst, const int dst_nb0,\n                                               const int dst_nb1, const int dst_nb2, const int64_t n_t) {\n    const int tid  = threadIdx.x;\n    const int bidx = blockIdx.x;\n    const int bidy = blockIdx.y;\n    const int bidz = blockIdx.z;\n\n    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +\n                                             bidz * split_n_t * src0_nb0);\n    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);\n    float *       y_block =\n        (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);\n\n    const int stride_x = src0_nb1 / sizeof(float);\n    const int stride_w = src1_nb1 / sizeof(float);\n    const int stride_y = dst_nb1 / sizeof(float);\n\n    const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t);\n    const int     n_cols    = d_conv - 1 + split_n_t;\n\n    extern __shared__ float smem[];\n\n    constexpr int load_cols   = d_conv - 1 + split_n_t;\n    constexpr int total_elems = split_d_inner * load_cols;\n    int row = tid / load_cols;\n    int col = tid % load_cols;\n#pragma unroll\n    for (int idx = 0; idx < total_elems; idx += split_d_inner) {\n        if (row < (int)split_d_inner) {\n            smem[row * n_cols + col] = x_block[row * stride_x + col];\n        }\n\n        col += split_d_inner;\n        row += col / load_cols;\n        col  = col % load_cols;\n        if (idx >= total_elems 
- tid - split_d_inner) {\n            break;\n        }\n    }\n    __syncthreads();\n\n    // Load weights into registers (done once, small)\n    float w[d_conv] = { 0.0f };\n#pragma unroll\n    for (size_t j = 0; j < d_conv; j++) {\n        w[j] = w_block[tid * stride_w + j];\n    }\n\n    // Compute from shared memory\n    for (int64_t i = 0; i < local_n_t; i++) {\n        float sumf = 0.0f;\n#pragma unroll\n        for (size_t j = 0; j < d_conv; j++) {\n            sumf += smem[tid * n_cols + i + j] * w[j];\n        }\n        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;\n    }\n}\n\ntemplate <bool apply_silu>\nstatic void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,\n                              const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,\n                              const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,\n                              const int64_t n_s, cudaStream_t stream) {\n    const int threads = 128;\n    GGML_ASSERT(nr % threads == 0);\n\n    auto launch_kernel = [&](auto NC) {\n        constexpr int kNC = decltype(NC)::value;\n        if (n_t <= 32) {\n            const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);\n            ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,\n                                                                       dst, dst_nb0, dst_nb1, dst_nb2, n_t);\n        } else {\n            const int64_t split_n_t = 32;\n            dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);\n            const size_t  smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);\n            ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(\n                src0, src1, src0_nb0, src0_nb1, src0_nb2, 
src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);\n        }\n    };\n\n    switch (nc) {\n        case 3: launch_kernel(std::integral_constant<int, 3>{}); break;\n        case 4: launch_kernel(std::integral_constant<int, 4>{}); break;\n        case 9: launch_kernel(std::integral_constant<int, 9>{}); break;\n        default: GGML_ABORT(\"Only support kernel sizes 3, 4, 9 right now.\");\n    }\n}\n\nvoid ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {\n    const struct ggml_tensor * src0 = dst->src[0];  // conv_x\n    const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight\n    const bool fuse_silu = silu_dst != nullptr;\n\n    // When fusing, write to silu_dst (the node downstream references).\n    const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;\n\n    const int64_t nc  = src1->ne[0];                // d_conv\n    const int64_t nr  = src0->ne[1];                // d_inner\n    const int64_t n_t = out->ne[1];                 // tokens per sequence\n    const int64_t n_s = out->ne[2];                 // number of sequences in the batch\n\n    GGML_ASSERT(out->ne[0] == nr);\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(src1->nb[0] == sizeof(float));\n    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float *       dst_d  = (float *) out->data;\n    cudaStream_t  stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(out->type == GGML_TYPE_F32);\n    if (fuse_silu) {\n        ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],\n                          out->nb[2], nc, nr, n_t, n_s, stream);\n    } else {\n        ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],\n              
            out->nb[2], nc, nr, n_t, n_s, stream);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/ssm-conv.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);\n"
  },
  {
    "path": "src/ggml-cuda/ssm-scan.cu",
    "content": "#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070\n#define USE_CUB\n#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070\n\n#ifdef USE_CUB\n#include <cub/cub.cuh>\nusing namespace cub;\n#endif // USE_CUB\n\n#include \"ssm-scan.cuh\"\n\n// We would like to keep pragma unroll for cases where L_template is not 0,\n// so we suppress the clang transformation warning.\n#ifdef __clang__\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wpass-failed\"\n#endif // __clang__\ntemplate <size_t splitD, size_t N, size_t L_template>\n__global__ void __launch_bounds__(splitD, 1)\n    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,\n                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,\n                 const int32_t * __restrict__ src6, float * __restrict__ dst,\n                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,\n                 const int src2_nb1, const int src2_nb2, const int src3_nb1,\n                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,\n                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)\n{\n    const size_t L = L_template == 0 ? 
L_param : L_template;\n    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);\n    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));\n    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));\n    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);\n    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));\n    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));\n    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));\n    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);\n\n    const int stride_x = src1_nb2 / sizeof(float);\n    const int stride_dt = src2_nb1 / sizeof(float);\n    const int stride_B = src4_nb2 / sizeof(float);\n    const int stride_C = src5_nb2 / sizeof(float);\n    const int stride_y = d_inner;\n\n    float regA[N];\n    float regs0[N];\n\n    __shared__ float smemB[N];\n    __shared__ float smemC[N];\n\n#ifdef USE_CUB\n    using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;\n    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;\n\n    union CubTempStorage {\n        typename BlockLoad::TempStorage load_temp;\n        typename BlockStore::TempStorage store_temp;\n    };\n    __shared__ CubTempStorage cub_temp_storage;\n\n    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);\n    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);\n#else\n    const int stride_s0 = src0_nb2 / sizeof(float);\n    const int stride_A = src3_nb1 / sizeof(float);\n#pragma unroll\n    for (size_t n = 0; n < N; ++n)\n    {\n        regA[n] = 
A_block[threadIdx.x * stride_A + n];\n        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];\n    }\n#endif\n\n#pragma unroll\n    for (size_t i = 0; i < L; i++)\n    {\n        if (threadIdx.x < N)\n        {\n            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];\n            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];\n        }\n        __syncthreads();\n\n        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];\n        if (dt_soft_plus <= 20.0f)\n        {\n            dt_soft_plus = log1pf(expf(dt_soft_plus));\n        }\n        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;\n\n        float sumf = 0.0f;\n#pragma unroll\n        for (size_t n = 0; n < N; n++)\n        {\n            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;\n            sumf += state * smemC[n];\n            regs0[n] = state;\n        }\n        y_block[i * stride_y + threadIdx.x] = sumf;\n    }\n\n#ifdef USE_CUB\n    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);\n#else\n    const int stride_s = stride_s0;\n#pragma unroll\n    for (size_t n = 0; n < N; ++n)\n    {\n        s_block[threadIdx.x * stride_s + n] = regs0[n];\n    }\n#endif\n}\n#ifdef __clang__\n#pragma clang diagnostic pop\n#endif // __clang__\n\n// assumes as many threads as d_state\ntemplate <int c_factor, int d_state>\n__global__ void __launch_bounds__(d_state, 1)\n    ssm_scan_f32_group(\n        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,\n        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,\n        const int32_t * __restrict__ src6, float * __restrict__ dst,\n        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,\n        const int src2_nb1, const int src2_nb2, const int src3_nb1,\n        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,\n       
 const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {\n\n    const int warp     = threadIdx.x / WARP_SIZE;\n    const int lane     = threadIdx.x % WARP_SIZE;\n    const int warp_idx = blockIdx.x  * c_factor + warp;\n\n    const int head_idx =  warp_idx / d_head;\n    const int head_off = (warp_idx % d_head) * sizeof(float);\n    const int seq_idx  = blockIdx.y;\n\n    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);\n\n    // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase\n    const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);\n    const float * x_warp  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));\n    const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));\n    const float * A_warp  = (const float *) ((const char *) src3 + head_idx * src3_nb1);\n    const float * B_warp  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));\n    const float * C_warp  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));\n    float *       y_warp  = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;\n    float *       s_warp  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);\n\n    // strides across n_seq_tokens\n    const int stride_x  = src1_nb2 / sizeof(float);\n    const int stride_dt = src2_nb1 / sizeof(float);\n    const int stride_B  = src4_nb2 / sizeof(float);\n    const int stride_C  = src5_nb2 / sizeof(float);\n    const int stride_y  = n_head * d_head;\n\n    float state[c_factor];\n    float state_sum = 0.0f;\n\n#pragma unroll\n    for (int j = 0; j < c_factor; j++) {\n        state[j] = s0_warp[WARP_SIZE * j + lane];\n    
}\n\n    for (int64_t i = 0; i < n_tok; i++) {\n        // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.\n        // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.\n        const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);\n\n        state_sum = 0.0f;\n        const float dA   = expf(dt_soft_plus * A_warp[0]);\n        const float x_dt = x_warp[i * stride_x] * dt_soft_plus;\n#pragma unroll\n        for (int j = 0; j < c_factor; j++) {\n            const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];\n            const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];\n            state[j] = (state[j] * dA) + (B_val * x_dt);\n            state_sum += state[j] * C_val;\n        }\n\n        // parallel accumulation for output\n        state_sum = warp_reduce_sum(state_sum);\n\n        if (lane == 0) {\n            y_warp[i * stride_y] = state_sum;\n        }\n    }\n\n    // write back the state\n#pragma unroll\n    for (int j = 0; j < c_factor; j++) {\n        s_warp[WARP_SIZE * j + lane] = state[j];\n    }\n}\n\nstatic void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,\n                              const float * src4, const float * src5, const int32_t * src6, float * dst,\n                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,\n                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,\n                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,\n                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,\n                              cudaStream_t stream) {\n    // NOTE: if you change 
conditions here, be sure to update the corresponding supports_op condition!\n    if (src3_nb1 == sizeof(float)) {\n        // Mamba-2\n        if (d_state == 128) {\n            constexpr int threads   = 128;\n            constexpr int num_warps = threads/WARP_SIZE;\n\n            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);\n            ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,\n                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);\n        } else if (d_state == 256) { // Falcon-H1\n            constexpr int threads   = 256;\n            constexpr int num_warps = threads/WARP_SIZE;\n\n            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);\n            ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,\n                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);\n        } else {\n            GGML_ABORT(\"doesn't support d_state!=(128 or 256).\");\n        }\n    } else {\n        // Mamba-1\n        constexpr int threads = 128;\n        GGML_ASSERT(n_head % threads == 0);\n        GGML_ASSERT(head_dim == 1);\n        GGML_ASSERT(n_group == 1);\n        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);\n        const int  smem_size = (threads * (d_state + 1) * 2) * sizeof(float);\n        if (d_state == 16) {\n            switch (n_tok)\n            {\n            case 1:\n                ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n       
         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n                break;\n            case 2:\n                ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n                break;\n            case 3:\n                ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n                break;\n            case 4:\n                ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n                break;\n            case 5:\n                ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n                break;\n            case 6:\n                ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n     
           break;\n            case 7:\n                ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n                break;\n            case 8:\n                ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n                break;\n            default:\n                ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(\n                    src0, src1, src2, src3, src4, src5, src6, dst,\n                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,\n                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);\n                break;\n            }\n        } else {\n            GGML_ABORT(\"doesn't support d_state!=16.\");\n        }\n    }\n}\n\nvoid ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const struct ggml_tensor * src0 = dst->src[0];  // s\n    const struct ggml_tensor * src1 = dst->src[1];  // x\n    const struct ggml_tensor * src2 = dst->src[2];  // dt\n    const struct ggml_tensor * src3 = dst->src[3];  // A\n    const struct ggml_tensor * src4 = dst->src[4];  // B\n    const struct ggml_tensor * src5 = dst->src[5];  // C\n    const struct ggml_tensor * src6 = dst->src[6];  // ids\n\n    const int64_t nc  = src0->ne[0];  // d_state\n    const int64_t nr  = src0->ne[1];  // head_dim or 1\n    const int64_t nh  = src1->ne[1];  // n_head\n    const int64_t ng  = src4->ne[1];  // n_group\n    const int64_t n_t = src1->ne[2];  // number of tokens per 
sequence\n    const int64_t n_s = src1->ne[3];  // number of sequences in the batch\n\n    const int64_t s_off = ggml_nelements(src1) * sizeof(float);\n\n    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(src1->nb[0] == sizeof(float));\n    GGML_ASSERT(src2->nb[0] == sizeof(float));\n    GGML_ASSERT(src3->nb[0] == sizeof(float));\n    GGML_ASSERT(src4->nb[0] == sizeof(float));\n    GGML_ASSERT(src5->nb[0] == sizeof(float));\n    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    const float * src2_d = (const float *) src2->data;\n    const float * src3_d = (const float *) src3->data;\n    const float * src4_d = (const float *) src4->data;\n    const float * src5_d = (const float *) src5->data;\n    const int32_t * src6_d = (const int32_t *) src6->data;\n    float *       dst_d  = (float *) dst->data;\n    cudaStream_t  stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src6->type == GGML_TYPE_I32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,\n                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],\n                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],\n                      s_off, nc, nr, nh, ng, n_t, n_s, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/ssm-scan.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/sum.cu",
    "content": "#include \"sum.cuh\"\n#include \"sumrows.cuh\"\n\n#ifdef GGML_CUDA_USE_CUB\n#include <cub/cub.cuh>\nusing namespace cub;\n#endif  // GGML_CUDA_USE_CUB\n\n#include <cstdint>\n\nvoid sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {\n#ifdef GGML_CUDA_USE_CUB\n    size_t tmp_size = 0;\n    DeviceReduce::Sum(nullptr,       tmp_size, x, dst, ne, stream);\n    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);\n    DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);\n#else\n    // Use (inefficient) sum_rows implementation as a fallback.\n    // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.\n    sum_rows_f32_cuda(x, dst, ne, 1, stream);\n    GGML_UNUSED(pool);\n#endif // GGML_CUDA_USE_CUB\n}\n\nvoid ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguously_allocated(src0));\n\n    const float * src0_d = (const float *) src0->data;\n    float * dst_d = (float *) dst->data;\n\n    const int64_t ne = ggml_nelements(src0);\n\n    ggml_cuda_pool & pool = ctx.pool();\n    cudaStream_t stream = ctx.stream();\n\n    sum_f32_cuda(pool, src0_d, dst_d, ne, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/sum.cuh",
    "content": "#include \"common.cuh\"\n\nvoid sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);\n\nvoid ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/sumrows.cu",
    "content": "#include \"reduce_rows.cuh\"\n#include \"sumrows.cuh\"\n\nvoid sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {\n    const int  id  = ggml_cuda_get_device();\n    const int  nsm = ggml_cuda_info().devices[id].nsm;\n    const dim3 block_nums(nrows, 1, 1);\n    if ((nrows / nsm) < 2) {\n        const dim3 block_dims(512, 1, 1);\n        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);\n    } else {\n        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);\n        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);\n    }\n}\n\nvoid ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    const int64_t ncols = src0->ne[0];\n    const int64_t nrows = ggml_nrows(src0);\n\n    const dim3 block_nums(nrows, 1, 1);\n\n    const int id  = ggml_cuda_get_device();\n    const int nsm = ggml_cuda_info().devices[id].nsm;\n    if ((nrows / nsm) < 2) {\n        // Increase num threads to 512 for small nrows to better hide the latency\n        const dim3 block_dims(512, 1, 1);\n        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);\n    } else {\n        // Enough active SMs to hide latency, use smaller blocks to allow better scheduling\n        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);\n        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/sumrows.cuh",
    "content": "#include \"common.cuh\"\n\nvoid sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);\nvoid ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);\nDECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);\nDECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);\nDECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);\nDECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);\nDECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);\nDECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);\nDECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);\nDECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);\nDECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);\nDECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);\nDECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);\nDECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);\nDECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);\nDECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);\nDECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);\nDECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);\nDECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);\nDECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);\nDECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);\nDECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);\nDECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);\nDECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);\nDECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);\nDECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);\nDECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);\nDECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);\nDECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);\nDECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);\nDECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);\nDECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);\nDECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);\nDECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);\nDECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);\nDECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);\nDECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);\nDECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);\nDECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);\nDECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);\nDECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);\nDECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);\nDECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);\nDECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);\nDECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);\nDECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);\nDECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);\nDECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);\nDECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);\nDECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);\nDECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);\nDECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);\nDECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);\nDECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);\nDECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);\nDECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);\nDECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);\nDECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);\nDECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);\nDECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);\nDECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);\nDECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);\nDECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);\nDECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);\nDECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);\nDECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);\nDECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);\nDECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);\nDECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);\nDECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);\nDECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);\nDECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);\nDECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);\nDECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);\nDECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);\nDECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);\nDECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);\nDECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);\nDECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);\nDECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);\nDECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\nDECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);\nDECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);\nDECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);\nDECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);\nDECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);\nDECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(112, 112);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(128, 128);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(256, 256);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(40, 40);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(576, 512);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(64, 64);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(72, 72);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(80, 80);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE(96, 96);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/generate_cu_files.py",
    "content": "#!/usr/bin/env python3\n\nfrom glob import glob\nimport os\n\nHEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]\n\nTYPES_KV = [\"GGML_TYPE_F16\", \"GGML_TYPE_Q4_0\", \"GGML_TYPE_Q4_1\", \"GGML_TYPE_Q5_0\", \"GGML_TYPE_Q5_1\", \"GGML_TYPE_Q8_0\"]\n\nSOURCE_FATTN_TILE = \"\"\"// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.cuh\"\n\nDECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});\n\"\"\"\n\nSOURCE_FATTN_VEC = \"\"\"// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.cuh\"\n\nDECL_FATTN_VEC_CASE( 64, {type_k}, {type_v});\nDECL_FATTN_VEC_CASE(128, {type_k}, {type_v});\nDECL_FATTN_VEC_CASE(256, {type_k}, {type_v});\n\"\"\"\n\nSOURCE_FATTN_MMA_START = \"\"\"// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-mma-f16.cuh\"\n\n\"\"\"\n\nSOURCE_FATTN_MMA_CASE = \"DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\\n\"\n\nTYPES_MMQ = [\n    \"GGML_TYPE_Q4_0\", \"GGML_TYPE_Q4_1\", \"GGML_TYPE_Q5_0\", \"GGML_TYPE_Q5_1\", \"GGML_TYPE_Q8_0\",\n    \"GGML_TYPE_Q2_K\", \"GGML_TYPE_Q3_K\", \"GGML_TYPE_Q4_K\", \"GGML_TYPE_Q5_K\", \"GGML_TYPE_Q6_K\",\n    \"GGML_TYPE_IQ2_XXS\", \"GGML_TYPE_IQ2_XS\", \"GGML_TYPE_IQ2_S\", \"GGML_TYPE_IQ3_XXS\", \"GGML_TYPE_IQ3_S\",\n    \"GGML_TYPE_IQ1_S\", \"GGML_TYPE_IQ4_NL\", \"GGML_TYPE_IQ4_XS\", \"GGML_TYPE_MXFP4\"\n]\n\nSOURCE_MMQ = \"\"\"// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE({type});\n\"\"\"\n\nSOURCE_MMF = \"\"\"// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE({type});\n\"\"\"\n\n\ndef get_short_name(long_quant_name):\n    return long_quant_name.replace(\"GGML_TYPE_\", \"\").lower()\n\n\nfor filename in glob(\"*.cu\"):\n    os.remove(filename)\n\nfor 
head_size_kq in HEAD_SIZES_KQ:\n    head_size_v = head_size_kq if head_size_kq != 576 else 512\n    with open(f\"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu\", \"w\") as f:\n        f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))\n\nfor type_k in TYPES_KV:\n    for type_v in TYPES_KV:\n        with open(f\"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu\", \"w\") as f:\n            f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))\n\nfor ncols in [8, 16, 32, 64]:\n    for ncols2 in [1, 2, 4, 8, 16, 32]:\n        if ncols2 > ncols:\n            continue\n        ncols1 = ncols // ncols2\n        with open(f\"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu\", \"w\") as f:\n            f.write(SOURCE_FATTN_MMA_START)\n\n            for head_size_kq in HEAD_SIZES_KQ:\n                if head_size_kq == 40:\n                    continue\n                if head_size_kq == 72:\n                    continue\n                if head_size_kq != 576 and ncols2 in (16, 32):\n                    continue\n                if head_size_kq == 576 and ncols2 not in (4, 16, 32):\n                    continue\n                head_size_v = head_size_kq if head_size_kq != 576 else 512\n                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))\n\nfor type in TYPES_MMQ:\n    with open(f\"mmq-instance-{get_short_name(type)}.cu\", \"w\") as f:\n        f.write(SOURCE_MMQ.format(type=type))\n\nfor type in range(1, 17):\n    with open(f\"mmf-instance-ncols_{type}.cu\", \"w\") as f:\n        f.write(SOURCE_MMF.format(type=type))\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(10);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(11);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(12);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(13);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(14);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(15);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(16);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(2);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(3);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(4);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(5);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(6);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(7);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(8);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmf.cuh\"\n\nDECL_MMF_CASE(9);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_IQ1_S);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_IQ2_S);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_IQ2_XS);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_IQ3_S);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_IQ4_NL);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_IQ4_XS);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_MXFP4);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q2_k.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q2_K);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q3_k.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q3_K);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q4_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q4_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q4_k.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q4_K);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q5_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q5_1.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q5_k.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q5_K);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q6_k.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q6_K);\n"
  },
  {
    "path": "src/ggml-cuda/template-instances/mmq-instance-q8_0.cu",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../mmq.cuh\"\n\nDECL_MMQ_CASE(GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-cuda/top-k.cu",
    "content": "#include \"argsort.cuh\"\n#include \"top-k.cuh\"\n\n#ifdef GGML_CUDA_USE_CUB\n#    include <cub/cub.cuh>\n// CUB_TOP_K_AVAILABLE requires CCCL >= 3.2. Compare (major, minor) lexicographically: testing the\n// minor version on its own would wrongly disable this path for later major releases such as 4.0.\n#    if (CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 2))\n#        define CUB_TOP_K_AVAILABLE\nusing namespace cub;\n#    endif  // CCCL >= 3.2\n#endif      // GGML_CUDA_USE_CUB\n\n#ifdef CUB_TOP_K_AVAILABLE\n\n// Single-row top-k: finds the k largest of the ncols values in src and writes their column indices\n// to dst (the selected values themselves are discarded). The order of the k indices is unspecified\n// (output_ordering::unsorted) and determinism is not guaranteed.\nstatic void top_k_cub(ggml_cuda_pool & pool,\n                      const float *    src,\n                      int *            dst,\n                      const int        ncols,\n                      const int        k,\n                      cudaStream_t     stream) {\n    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,\n                                                 cuda::execution::output_ordering::unsorted);\n    auto stream_env   = cuda::stream_ref{ stream };\n    auto env          = cuda::std::execution::env{ stream_env, requirements };\n\n    // Values paired with the keys are the column indices 0, 1, 2, ...\n    auto indexes_in = cuda::make_counting_iterator(0);\n\n    // First call with a null temp buffer only queries the required temporary storage size.\n    size_t temp_storage_bytes = 0;\n    DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,\n                         env);\n\n    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);\n    void *                        d_temp_storage = temp_storage_alloc.get();\n\n    DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,\n                         ncols, k, env);\n}\n\n#elif defined(GGML_CUDA_USE_CUB)  // CUB_TOP_K_AVAILABLE\n\n// Smallest power of two >= x (returns 1 for x <= 1).\nstatic int next_power_of_2(int x) {\n    int n = 1;\n    while (n < x) {\n        n *= 2;\n    }\n    return n;\n}\n\n#endif                            // CUB_TOP_K_AVAILABLE\n\n// For each row of src0 (F32, contiguous), writes the column indices of its k = dst->ne[0] largest\n// values into the matching row of dst (I32).\nvoid ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0   = dst->src[0];\n    const float *       src0_d = (const float *) src0->data;\n    int *               dst_d  = (int *) dst->data;\n    cudaStream_t        stream = ctx.stream();\n\n    // are these asserts truly necessary?\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_I32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    const int64_t    ncols = src0->ne[0];\n    const int64_t    nrows = ggml_nrows(src0);\n    const int64_t    k     = dst->ne[0];\n    ggml_cuda_pool & pool  = ctx.pool();\n#ifdef CUB_TOP_K_AVAILABLE\n    // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented\n    // https://github.com/NVIDIA/cccl/issues/6391\n    // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k\n    for (int64_t i = 0; i < nrows; i++) {\n        top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);\n    }\n#elif defined(GGML_CUDA_USE_CUB)  // CUB_TOP_K_AVAILABLE\n    // Fall back to argsort + copy: sort every row in descending order, then keep the first k indices\n    // of each row via a strided 2D copy (source row pitch ncols, destination row pitch k).\n    const int    ncols_pad      = next_power_of_2(ncols);\n    const size_t shared_mem     = ncols_pad * sizeof(int);\n    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;\n\n    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);\n    int *                     tmp_dst = temp_dst_alloc.get();\n\n    if (shared_mem > max_shared_mem || ncols > 1024) {\n        argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);\n    } else {\n        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);\n    }\n    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,\n                                 cudaMemcpyDeviceToDevice, stream));\n#else                             // GGML_CUDA_USE_CUB\n    // No CUB available: bitonic argsort only, then the same strided copy of the first k indices per row.\n    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);\n    int *                     tmp_dst = temp_dst_alloc.get();\n    argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);\n    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,\n                                 cudaMemcpyDeviceToDevice, stream));\n#endif\n}\n"
  },
  {
    "path": "src/ggml-cuda/top-k.cuh",
    "content": "#include \"common.cuh\"\n\nvoid ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/topk-moe.cu",
    "content": "#include \"ggml-cuda/common.cuh\"\n#include \"ggml.h\"\n#include \"topk-moe.cuh\"\n\n#include <cmath>\n#include <initializer_list>\n\n// Kernel config struct - passed by value to CUDA kernel\nstruct topk_moe_config {\n    bool use_sigmoid;\n    bool with_norm;\n    bool delayed_softmax;\n};\n\n// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.\ntemplate <int experts_per_thread, bool use_limit>\n__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {\n    float max_val = -INFINITY;\n\n#pragma unroll\n    for (int i = 0; i < experts_per_thread; i++) {\n        const int  idx    = lane + i * WARP_SIZE;\n        const bool active = !use_limit || (idx < limit);\n        if (active) {\n            max_val = max(max_val, vals[i]);\n        }\n    }\n\n    max_val = warp_reduce_max(max_val);\n\n    float sum = 0.f;\n\n#pragma unroll\n    for (int i = 0; i < experts_per_thread; i++) {\n        const int  idx    = lane + i * WARP_SIZE;\n        const bool active = !use_limit || (idx < limit);\n        if (active) {\n            const float val = expf(vals[i] - max_val);\n            vals[i]         = val;\n            sum += val;\n        } else {\n            vals[i] = 0.f;\n        }\n    }\n\n    sum = warp_reduce_sum(sum);\n\n    const float inv_sum = 1.0f / sum;\n\n#pragma unroll\n    for (int i = 0; i < experts_per_thread; i++) {\n        const int  idx    = lane + i * WARP_SIZE;\n        const bool active = !use_limit || (idx < limit);\n        if (active) {\n            vals[i] *= inv_sum;\n        }\n    }\n}\n\ntemplate <int experts_per_thread, bool use_limit>\n__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {\n#pragma unroll\n    for (int i = 0; i < experts_per_thread; i++) {\n        const int  idx    = lane + i * WARP_SIZE;\n        const bool active = !use_limit || (idx < limit);\n        vals[i]      
     = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;\n    }\n}\n\n/*\n    This kernel does the following:\n    1. optionally softmax over the logits per token [n_experts, n_tokens]\n    2. argmax reduce over the top-k (n_experts_used) logits\n    3. write weights + ids to global memory\n    4. optionally normalize the weights or apply softmax over the selected logits\n\n    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models\n*/\ntemplate <int n_experts, bool has_bias>\n__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *         logits,\n                                                                  float *               weights,\n                                                                  int32_t *             ids,\n                                                                  float *               bias,\n                                                                  const int             n_rows,\n                                                                  const int             n_expert_used,\n                                                                  const float           clamp_val,\n                                                                  const float           scale_val,\n                                                                  const topk_moe_config config) {\n    const int row = blockIdx.x * blockDim.y + threadIdx.y;\n    if (row >= n_rows) {\n        return;\n    }\n\n    logits += n_experts * row;\n    weights += n_expert_used * row;\n    ids += n_experts * row;\n\n    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? 
n_experts / WARP_SIZE : 1;\n\n    float wt[experts_per_thread];\n\n    // Initialize all slots to -INFINITY\n#pragma unroll\n    for (int i = 0; i < experts_per_thread; i++) {\n        wt[i] = -INFINITY;\n    }\n\n#pragma unroll\n    for (int i = 0; i < n_experts; i += WARP_SIZE) {\n        const int expert  = i + threadIdx.x;\n        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;\n    }\n\n    if (!config.delayed_softmax) {\n        if (config.use_sigmoid) {\n           sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);\n        } else {\n           softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);\n        }\n    }\n\n    // Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs.\n    // NaN comparisons always return false, which would cause the same expert to be\n    // selected repeatedly. -FLT_MAX compares normally and is still excluded by the\n    // -INFINITY sentinel used after each selection round.\n    // More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659\n#pragma unroll\n    for (int i = 0; i < experts_per_thread; i++) {\n        if (__isnanf(wt[i])) {\n            wt[i] = -FLT_MAX;\n        }\n    }\n\n    // selection_wt is only needed when bias is present (selection uses wt + bias)\n    // when no bias, we use wt directly for both selection and weight values\n    float selection_wt[has_bias ? experts_per_thread : 1];\n\n    if constexpr (has_bias) {\n#pragma unroll\n        for (int i = 0; i < experts_per_thread; i++) {\n            selection_wt[i] = -INFINITY;\n        }\n#pragma unroll\n        for (int i = 0; i < n_experts; i += WARP_SIZE) {\n            const int expert = i + threadIdx.x;\n            selection_wt[i / WARP_SIZE] =\n                (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 
wt[i / WARP_SIZE] + bias[expert] : -INFINITY;\n        }\n    }\n\n    //at this point, each thread holds either a portion of the softmax distribution\n    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking\n    //the expert weight as -inf to exclude from the next iteration\n\n    float wt_sum = 0.f;\n\n    float output_weights[experts_per_thread];\n\n#pragma unroll\n    for (int i = 0; i < experts_per_thread; i++) {\n        output_weights[i] = 0.f;\n    }\n\n    for (int k = 0; k < n_expert_used; k++) {\n        float max_val    = wt[0];\n        int   max_expert = threadIdx.x;\n\n        if constexpr (has_bias) {\n            float max_val_s = selection_wt[0];\n\n#pragma unroll\n            for (int i = 1; i < experts_per_thread; i++) {\n                const int expert = threadIdx.x + i * WARP_SIZE;\n                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {\n                    max_val    = wt[i];\n                    max_val_s  = selection_wt[i];\n                    max_expert = expert;\n                }\n            }\n\n#pragma unroll\n            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {\n                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);\n                const float val_s  = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);\n                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);\n                if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {\n                    max_val    = val;\n                    max_val_s  = val_s;\n                    max_expert = expert;\n                }\n            }\n\n            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {\n                selection_wt[max_expert / WARP_SIZE] = -INFINITY;\n            }\n        } else {\n#pragma unroll\n            for (int i = 1; i < experts_per_thread; i++) {\n                
const int expert = threadIdx.x + i * WARP_SIZE;\n                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {\n                    max_val    = wt[i];\n                    max_expert = expert;\n                }\n            }\n\n#pragma unroll\n            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {\n                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);\n                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);\n                if (val > max_val || (val == max_val && expert < max_expert)) {\n                    max_val    = val;\n                    max_expert = expert;\n                }\n            }\n\n            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {\n                wt[max_expert / WARP_SIZE] = -INFINITY;\n            }\n        }\n\n        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {\n            output_weights[k / WARP_SIZE] = max_val;\n        }\n\n        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {\n            ids[k] = max_expert;\n            if (config.with_norm) {\n                wt_sum += max_val;\n            }\n        }\n    }\n\n    if (config.with_norm) {\n        wt_sum              = warp_reduce_sum(wt_sum);\n        wt_sum              = max(wt_sum, clamp_val);\n        const float inv_sum = 1.0f / wt_sum;\n\n        for (int i = 0; i < experts_per_thread; i++) {\n            output_weights[i] *= inv_sum;\n        }\n    }\n\n    if (config.delayed_softmax) {\n        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);\n    }\n\n#pragma unroll\n    for (int i = 0; i < experts_per_thread; i++) {\n        const int idx = i * WARP_SIZE + threadIdx.x;\n        if (idx < n_expert_used) {\n            weights[idx] = output_weights[i] * scale_val;\n        }\n    }\n}\n\ntemplate<bool has_bias>\nstatic void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,\n         
                        const float *               logits,\n                                 float *                     weights,\n                                 int32_t *                   ids,\n                                 float *                     bias,\n                                 const int                   n_rows,\n                                 const int                   n_expert,\n                                 const int                   n_expert_used,\n                                 const float                 clamp_val,\n                                 const float                 scale_val,\n                                 const topk_moe_config       config) {\n    GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&\n                \"delayed softmax is not supported with weight normalization\");\n    const int    rows_per_block = 4;\n    dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);\n    dim3         block_dims(WARP_SIZE, rows_per_block, 1);\n    cudaStream_t stream = ctx.stream();\n\n    switch (n_expert) {\n        case 1:\n            topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                   clamp_val, scale_val, config);\n            break;\n        case 2:\n            topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                   clamp_val, scale_val, config);\n            break;\n        case 4:\n            topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                   clamp_val, scale_val, config);\n            break;\n        case 8:\n            topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, 
weights, ids, bias, n_rows, n_expert_used,\n                                                                   clamp_val, scale_val, config);\n            break;\n        case 16:\n            topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                    clamp_val, scale_val, config);\n            break;\n        case 32:\n            topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                    clamp_val, scale_val, config);\n            break;\n        case 64:\n            topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                    clamp_val, scale_val, config);\n            break;\n        case 128:\n            topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                     clamp_val, scale_val, config);\n            break;\n        case 256:\n            topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                     clamp_val, scale_val, config);\n            break;\n        case 512:\n            topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                     clamp_val, scale_val, config);\n            break;\n        case 576:\n            topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,\n                                                                     
clamp_val, scale_val, config);\n            break;\n        default:\n            GGML_ASSERT(false && \"fatal error\");\n            break;\n    }\n}\n\nvoid ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,\n                           const ggml_tensor *             logits,\n                           ggml_tensor *                   weights,\n                           ggml_tensor *                   ids,\n                           const ggml_tensor *             clamp,\n                           const ggml_tensor *             scale,\n                           const ggml_tensor *             bias,\n                           const ggml_cuda_topk_moe_args & args) {\n    GGML_ASSERT(logits->type == GGML_TYPE_F32);\n    GGML_ASSERT(weights->type == GGML_TYPE_F32);\n    GGML_ASSERT(ids->type == GGML_TYPE_I32);\n\n    const int n_experts = logits->ne[0];\n    const int n_rows    = logits->ne[1];\n\n    const float * logits_d  = (const float *) logits->data;\n    float *       weights_d = (float *) weights->data;\n    int32_t *     ids_d     = (int32_t *) ids->data;\n    float *       bias_d    = bias ? (float *) bias->data : nullptr;\n\n    float scale_val = scale ? 
ggml_get_op_params_f32(scale, 0) : 1.0f;\n\n    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);\n\n    const int n_expert_used = weights->ne[1];\n\n    const bool with_norm = clamp != nullptr;\n\n    float clamp_val = -INFINITY;\n    if (clamp) {\n        clamp_val = ggml_get_op_params_f32(clamp, 0);\n    }\n\n    topk_moe_config config;\n    config.use_sigmoid     = args.sigmoid;\n    config.with_norm       = with_norm;\n    config.delayed_softmax = args.delayed_softmax;\n\n    if (bias) {\n        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,\n                             scale_val, config);\n    } else {\n        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,\n                             scale_val, config);\n    }\n}\n\nbool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,\n                                   const ggml_tensor * weights,\n                                   const ggml_tensor * logits,\n                                   const ggml_tensor * ids) {\n    const int n_expert = ids->nb[1] / ids->nb[0];\n    if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {\n        return false;\n    }\n\n    if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {\n        return false;\n    }\n\n    if (gating_op->op == GGML_OP_SOFT_MAX) {\n        const ggml_tensor * softmax  = gating_op;\n        float               scale    = 1.0f;\n        float               max_bias = 0.0f;\n\n        memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));\n        memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));\n\n        if (!ggml_is_contiguous(softmax->src[0])) {\n            return false;\n        }\n\n        if (scale != 1.0f || max_bias != 0.0f) {\n            return false;\n        }\n\n        // don't fuse when masks or 
sinks are present\n        if (softmax->src[1] || softmax->src[2]) {\n            return false;\n        }\n    } else if (gating_op->op == GGML_OP_UNARY) {\n        ggml_unary_op op = ggml_get_unary_op(gating_op);\n\n        if (op != GGML_UNARY_OP_SIGMOID) {\n            return false;\n        }\n    }\n\n    return true;\n}\n"
  },
  {
    "path": "src/ggml-cuda/topk-moe.cuh",
    "content": "#include \"common.cuh\"\n#include \"ggml.h\"\n\n#include <initializer_list>\n\nstruct ggml_cuda_topk_moe_args {\n    bool sigmoid{};\n    bool softmax{};\n    bool delayed_softmax{};\n    bool prob_bias{};\n    bool norm{};\n    bool scale{};\n};\n\nvoid ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,\n                           const ggml_tensor *             logits,\n                           ggml_tensor *                   weights,\n                           ggml_tensor *                   ids,\n                           const ggml_tensor *             clamp,\n                           const ggml_tensor *             scale,\n                           const ggml_tensor *             bias,\n                           const ggml_cuda_topk_moe_args & args);\n\nbool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,\n                                   const ggml_tensor * weights,\n                                   const ggml_tensor * logits,\n                                   const ggml_tensor * ids);\n"
  },
  {
    "path": "src/ggml-cuda/tri.cu",
    "content": "#include \"common.cuh\"\n#include \"convert.cuh\"\n#include \"tri.cuh\"\n#include \"ggml.h\"\n\ntemplate<typename T, bool prefix_keep, int add_to_split>\nstatic __global__ void tri_kernel(\n        const T * src, T * dst,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3) {\n    const int64_t i3 = blockIdx.z;\n    const int64_t i2 = blockIdx.y;\n    const int64_t i1 = blockIdx.x;\n    const int64_t split_point = i1 + add_to_split;\n\n    GGML_UNUSED_VARS(nb00, nb0);\n\n    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {\n        return;\n    }\n\n    const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;\n    T       * dst_row = dst + i1*nb1  + i2*nb2  + i3*nb3;\n\n    if constexpr (prefix_keep) {\n        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {\n            dst_row[i0] = src_row[i0];\n        }\n        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {\n            dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);\n        }\n    } else {\n        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {\n            dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);\n        }\n        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {\n            dst_row[i0] = src_row[i0];\n        }\n    }\n}\n\ntemplate<typename T>\nstatic void tri_cuda(\n        const T * src, T * dst,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,\n        const ggml_tri_type ttype,\n        cudaStream_t stream) {\n\n    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);\n    
dim3 grid_dims(ne01, ne02, ne03);\n    const size_t type_size = sizeof(T);\n\n    const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;\n    const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);\n\n    if (prefix_keep) {\n        if (add_to_split == 0) {\n            tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(\n                src, dst,\n                ne00, ne01, ne02, ne03,\n                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,\n                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size\n            );\n        } else { // only 0 and 1 supported\n            tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(\n                src, dst,\n                ne00, ne01, ne02, ne03,\n                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,\n                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size\n            );\n        }\n    } else {\n        if (add_to_split == 0) {\n            tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(\n                src, dst,\n                ne00, ne01, ne02, ne03,\n                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,\n                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size\n            );\n        } else {\n            tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(\n                src, dst,\n                ne00, ne01, ne02, ne03,\n                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,\n                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size\n            );\n        }\n    }\n}\n\nvoid ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    cudaStream_t stream = ctx.stream();\n\n    const 
ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));\n\n    GGML_ASSERT(src0->type == dst->type);\n\n    switch(src0->type) {\n        case GGML_TYPE_F32:\n            {\n                tri_cuda(\n                    (const float *)src0->data, (float *)dst->data,\n                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],\n                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],\n                    ttype, stream\n                );\n            } break;\n        case GGML_TYPE_F16:\n            {\n                tri_cuda(\n                    (const half *)src0->data, (half *)dst->data,\n                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],\n                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],\n                    ttype, stream\n                );\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                tri_cuda(\n                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,\n                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],\n                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],\n                    ttype, stream\n                );\n            } break;\n        default:\n            GGML_ABORT(\"fatal error\");\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/tri.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_TRI_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/tsembd.cu",
    "content": "#include \"tsembd.cuh\"\n\nstatic __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {\n    // blockIDx.y: idx of timesteps->ne[0]\n    // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE\n    int i = blockIdx.y;\n    int j = threadIdx.x + blockIdx.x * blockDim.x;\n    float * embed_data = (float *)((char *)dst +  i*nb1);\n\n    int half = dim / 2;\n    if (dim % 2 != 0 && j == half) {\n        embed_data[2 * half] = 0.f;\n    }\n\n    if (j >= half) {\n        return;\n    }\n\n    float timestep = timesteps[i];\n    float freq = (float)expf(-logf(max_period) * j / half);\n    float arg = timestep * freq;\n    embed_data[j] = cosf(arg);\n    embed_data[j + half] = sinf(arg);\n}\n\nstatic void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,\n                                        const int dim, const int max_period, cudaStream_t stream) {\n    int half_ceil = (dim + 1) / 2;\n    int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;\n    dim3 gridDim(num_blocks, ne00, 1);\n    timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);\n}\n\nvoid ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    const int dim = dst->op_params[0];\n    const int max_period = dst->op_params[1];\n\n    timestep_embedding_f32_cuda(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);\n}\n"
  },
  {
    "path": "src/ggml-cuda/tsembd.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/unary.cu",
    "content": "#include \"unary.cuh\"\n#include \"convert.cuh\"\n\nstatic __device__ __forceinline__ float op_abs(float x) {\n    return fabsf(x);\n}\n\nstatic __device__ __forceinline__ float op_sgn(float x) {\n    return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f)));\n}\n\nstatic __device__ __forceinline__ float op_neg(float x) {\n    return -x;\n}\n\nstatic __device__ __forceinline__ float op_step(float x) {\n    return x > 0.0f;\n}\n\nstatic __device__ __forceinline__ float op_gelu(float x) {\n    return ggml_cuda_op_gelu_single(x);\n}\n\nstatic __device__ __forceinline__ float op_gelu_erf(float x) {\n    const float SQRT_2_INV = 0.70710678118654752440084436210484f;\n\n    return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));\n}\n\nstatic __device__ __forceinline__ float op_gelu_quick(float x) {\n    const float GELU_QUICK_COEF = -1.702f;\n\n    return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x)));\n}\n\nstatic __device__ __forceinline__ float op_silu(float x) {\n    return ggml_cuda_op_silu_single(x);\n}\n\nstatic __device__ __forceinline__ float op_tanh(float x) {\n    return tanhf(x);\n}\n\nstatic __device__ __forceinline__ float op_relu(float x) {\n    return fmaxf(x, 0);\n}\n\nstatic __device__ __forceinline__ float op_sigmoid(float x) {\n    return 1.0f / (1.0f + expf(-x));\n}\n\nstatic __device__ __forceinline__ float op_hardsigmoid(float x) {\n    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));\n}\n\nstatic __device__ __forceinline__ float op_hardswish(float x) {\n    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));\n}\n\nstatic __device__ __forceinline__ float op_exp(float x) {\n    return expf(x);\n}\n\nstatic __device__ __forceinline__ float op_sqr(float x) {\n    return x * x;\n}\n\nstatic __device__ __forceinline__ float op_sqrt(float x) {\n    return sqrtf(x);\n}\n\nstatic __device__ __forceinline__ float op_sin(float x) {\n    return sinf(x);\n}\n\nstatic __device__ __forceinline__ float op_cos(float x) {\n    return cosf(x);\n}\n\nstatic 
__device__ __forceinline__ float op_log(float x) {\n    return logf(x);\n}\n\nstatic __device__ __forceinline__ float op_expm1(float x) {\n    return expm1f(x);\n}\n\nstatic __device__ __forceinline__ float op_softplus(float x) {\n    return (x > 20.0f) ? x : logf(1.0f + expf(x));\n}\n\nstatic __device__ __forceinline__ float op_elu(float x) {\n    return (x > 0.f) ? x : expm1f(x);\n}\n\nstatic __device__ __forceinline__ float op_floor(float x) {\n    return floorf(x);\n}\n\nstatic __device__ __forceinline__ float op_ceil(float x) {\n    return ceilf(x);\n}\n\nstatic __device__ __forceinline__ float op_round(float x) {\n    return round(x);\n}\n\nstatic __device__ __forceinline__ float op_trunc(float x) {\n    return trunc(x);\n}\n\ntemplate <float (*op)(float), typename T>\nstatic __global__ void unary_op_kernel(const T * x, T * dst, const int k) {\n    const int i = blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    dst[i] = (T)op((float)x[i]);\n}\n\ntemplate <float (*op)(float), typename T>\nstatic void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {\n    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;\n    unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);\n}\n\ntemplate <float (*op)(float)>\nvoid ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const void * src0_d = src0->data;\n    void * dst_d = dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(src0->type == dst->type);\n\n    if (src0->type == GGML_TYPE_F16) {\n        unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);\n    } else {\n        unary_cuda<op>((const float 
*)src0_d, (float *)dst_d, ggml_nelements(src0), stream);\n    }\n}\n\nvoid ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_abs>(ctx, dst);\n}\n\nvoid ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_sgn>(ctx, dst);\n}\n\nvoid ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_neg>(ctx, dst);\n}\n\nvoid ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_step>(ctx, dst);\n}\n\nvoid ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_gelu>(ctx, dst);\n}\n\nvoid ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);\n}\n\nvoid ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);\n}\n\nvoid ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_silu>(ctx, dst);\n}\n\nvoid ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_tanh>(ctx, dst);\n}\n\nvoid ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_relu>(ctx, dst);\n}\n\nvoid ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_sigmoid>(ctx, dst);\n}\n\nvoid ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst);\n}\n\nvoid ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_hardswish>(ctx, dst);\n}\n\nvoid ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_exp>(ctx, dst);\n}\n\nvoid ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    
ggml_cuda_op_unary<op_sqr>(ctx, dst);\n}\n\nvoid ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_sqrt>(ctx, dst);\n}\n\nvoid ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_sin>(ctx, dst);\n}\n\nvoid ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_cos>(ctx, dst);\n}\n\nvoid ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_log>(ctx, dst);\n}\n\nvoid ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_elu>(ctx, dst);\n}\n\nvoid ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_floor>(ctx, dst);\n}\n\nvoid ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_ceil>(ctx, dst);\n}\n\nvoid ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_round>(ctx, dst);\n}\n\nvoid ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_trunc>(ctx, dst);\n}\n\nvoid ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_expm1>(ctx, dst);\n}\n\nvoid ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary<op_softplus>(ctx, dst);\n}\n/* gated ops */\n\ntemplate <float (*op)(float), typename T>\nstatic __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {\n    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    // perform base op and multiply with gate (either offset in same tensor or a separate one)\n    const int64_t j0 = (i / n) * o0 + (i % n);\n    const int64_t j1 = o0 == o1 ? 
j0 : (i / n) * o1 + (i % n);\n\n    dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);\n}\n\ntemplate <float (*op)(float), typename T>\nstatic void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {\n    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;\n    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);\n}\n\ntemplate <float (*op)(float)>\nvoid ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    void * src0_d = src0->data;\n    void * src1_d = src1 ? src1->data : src0->data;\n    const int64_t src0_o = src0->nb[1];\n    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n    void * dst_d = dst->data;\n    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(src0->type == dst->type);\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));\n        GGML_ASSERT(src1->ne[0] == nc);\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const int32_t swapped = ((const int32_t *) dst->op_params)[1];\n\n    if (src0->type == GGML_TYPE_F16) {\n        half * src0_p = (half *) src0_d;\n        half * src1_p = (half *) src1_d;\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 
0 : nc;\n        }\n\n        unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream);\n    } else {\n        float * src0_p = (float *) src0_d;\n        float * src1_p = (float *) src1_d;\n\n        if (!src1) {\n            src0_p += swapped ? nc : 0;\n            src1_p += swapped ? 0 : nc;\n        }\n\n        unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream);\n    }\n}\n\nvoid ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary_gated<op_relu>(ctx, dst);\n}\n\nvoid ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);\n}\n\nvoid ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary_gated<op_silu>(ctx, dst);\n}\n\nvoid ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);\n}\n\nvoid ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);\n}\n\n// swiglu_oai\n\ntemplate <typename T>\nstatic __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {\n    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    // perform base op and multiply with gate (either offset in same tensor or a separate one)\n    const int64_t j0 = (i / n) * o0 + (i % n);\n    const int64_t j1 = o0 == o1 ? 
j0 : (i / n) * o1 + (i % n);\n\n    float xi = x[j0];\n    float gi = g[j1];\n\n    dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);\n}\n\ntemplate <typename T>\nstatic void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {\n    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;\n    swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);\n}\n\nvoid ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    void * src0_d = src0->data;\n    void * src1_d = src1 ? src1->data : src0->data;\n    const int64_t src0_o = src0->nb[1];\n    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n    void * dst_d = dst->data;\n    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(src0->type == dst->type);\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));\n        GGML_ASSERT(src1->ne[0] == nc);\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n    const float alpha = ggml_get_op_params_f32(dst, 2);\n    const float limit = ggml_get_op_params_f32(dst, 3);\n\n    float * src0_p = (float *) src0_d;\n    float * src1_p = (float *) 
src1_d;\n\n    if (!src1) {\n        src0_p += swapped ? nc : 0;\n        src1_p += swapped ? 0 : nc;\n    }\n\n    swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);\n}\n\n/* CUDA kernel + launcher for xIELU */\n\ntemplate <typename T>\nstatic __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) {\n    const int i = blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    const float xi = ggml_cuda_cast<float>(x[i]);\n\n    const float gate_pos = (xi > 0.0f);\n    const float y_pos = alpha_p * xi * xi + beta * xi;\n    const float min_v_eps = fminf(xi, eps);\n    const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;\n    const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;\n\n    dst[i] = ggml_cuda_cast<T>(out);\n}\n\ntemplate <typename T>\nstatic void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {\n    const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE;\n    xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);\n}\n\nvoid ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const void * src0_d = src0->data;\n    void * dst_d = dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(src0->type == dst->type);\n\n    const float alpha_n = ggml_get_op_params_f32(dst, 1);\n    const float alpha_p = ggml_get_op_params_f32(dst, 2);\n    const float beta    = ggml_get_op_params_f32(dst, 3);\n    const float eps     = 
ggml_get_op_params_f32(dst, 4);\n\n    if (src0->type == GGML_TYPE_F16) {\n        xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);\n    } else {\n        xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);\n    }\n}\n\n\n\n/* silu_back */\n\nstatic __device__ __forceinline__ float op_silu_back(float grad, float x) {\n    const float s = 1.0f / (1.0f + expf(-x));\n    return grad * s * (1.0f + x * (1.0f - s));\n}\n\ntemplate <class T>\nstatic __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) {\n    const int i = blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]);\n}\n\ntemplate <class T>\nstatic void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) {\n    const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;\n    silu_back_kernel<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);\n}\n\nvoid ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0]; // input from forward pass\n    const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float       * dst_d  = (float       *) dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(src0->type == dst->type);\n\n    if (src0->type == GGML_TYPE_F16) {\n        silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream);\n    } else {\n        
silu_back_cuda((const float*)src0_d, (const float*)src1_d, (float *)dst_d, ggml_nelements(src0), stream);\n    }\n}\n\n/* leaky relu */\n\nstatic __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) {\n    return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope;\n}\n\ntemplate <class T>\nstatic __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) {\n    const int i  = blockDim.x*blockIdx.x + threadIdx.x;\n\n    if (i >= k) {\n        return;\n    }\n\n    dst[i] = (T)op_leaky_relu((float)x[i], negative_slope);\n}\n\ntemplate <class T>\nstatic void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {\n    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;\n    leaky_relu_kernel<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);\n}\n\nvoid ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const void * src0_d = src0->data;\n    void * dst_d = dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(src0->type == dst->type);\n\n    float negative_slope;\n    memcpy(&negative_slope, dst->op_params, sizeof(float));\n\n    if (src0->type == GGML_TYPE_F16) {\n        leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream);\n    } else {\n        leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);\n    }\n}\n\n/* fused unary + mul */\n\ntemplate <float (*op)(float)>\nstatic void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {\n    // unary_node: UNARY op 
applied to unary_node->src[0]\n    // mul_node:   MUL(a, b) where one of a/b is unary_node\n    // Output goes to mul_node->data\n\n    const ggml_tensor * unary_src = unary_node->src[0];  // input to the unary op\n    const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0];\n\n    GGML_ASSERT(ggml_is_contiguous_1(unary_src));\n    GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src));\n    GGML_ASSERT(ggml_is_contiguous_1(other_src));\n    GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src));\n    GGML_ASSERT(ggml_are_same_shape(unary_src, other_src));\n\n    GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16);\n    GGML_ASSERT(unary_src->type == other_src->type);\n    GGML_ASSERT(unary_src->type == mul_node->type);\n\n    cudaStream_t stream = ctx.stream();\n\n    const int64_t k  = ggml_nelements(mul_node);\n    const int64_t nc = unary_src->ne[0];\n    const int64_t unary_stride = unary_src->nb[1];\n    const int64_t other_stride = other_src->nb[1];\n\n    if (unary_src->type == GGML_TYPE_F16) {\n        unary_gated_cuda<op>((const half *) unary_src->data, (const half *) other_src->data,\n                             (half *) mul_node->data, k, nc,\n                             unary_stride / sizeof(half), other_stride / sizeof(half), stream);\n    } else {\n        unary_gated_cuda<op>((const float *) unary_src->data, (const float *) other_src->data,\n                             (float *) mul_node->data, k, nc,\n                             unary_stride / sizeof(float), other_stride / sizeof(float), stream);\n    }\n}\n\nvoid ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {\n    switch (ggml_get_unary_op(unary_node)) {\n        case GGML_UNARY_OP_SILU:\n            ggml_cuda_op_unary_mul_impl<op_silu>(ctx, unary_node, mul_node);\n            break;\n        case GGML_UNARY_OP_SIGMOID:\n            
ggml_cuda_op_unary_mul_impl<op_sigmoid>(ctx, unary_node, mul_node);\n            break;\n        case GGML_UNARY_OP_SOFTPLUS:\n            ggml_cuda_op_unary_mul_impl<op_softplus>(ctx, unary_node, mul_node);\n            break;\n        default:\n            GGML_ABORT(\"Unsupported unary op for fused unary+mul\");\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/unary.cuh",
    "content": "#pragma once\n#include \"common.cuh\"\n\n#define CUDA_NEG_BLOCK_SIZE 256\n#define CUDA_STEP_BLOCK_SIZE 256\n#define CUDA_GELU_BLOCK_SIZE 256\n#define CUDA_SILU_BLOCK_SIZE 256\n#define CUDA_SILU_BACK_BLOCK_SIZE 256\n#define CUDA_TANH_BLOCK_SIZE 256\n#define CUDA_RELU_BLOCK_SIZE 256\n#define CUDA_SIGMOID_BLOCK_SIZE 256\n#define CUDA_HARDSIGMOID_BLOCK_SIZE 256\n#define CUDA_EXP_BLOCK_SIZE 256\n#define CUDA_HARDSWISH_BLOCK_SIZE 256\n#define CUDA_SQR_BLOCK_SIZE 256\n#define CUDA_SQRT_BLOCK_SIZE 256\n#define CUDA_SIN_BLOCK_SIZE 256\n#define CUDA_COS_BLOCK_SIZE 256\n#define CUDA_GLU_BLOCK_SIZE 256\n#define CUDA_XIELU_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_sqr(ggml_backend_cuda_context & 
ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);\n\n__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {\n    return x / (1.0f + expf(-x));\n}\n\n__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {\n    const float GELU_COEF_A    = 0.044715f;\n    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;\n\n    return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));\n}\n\n__device__ 
__forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {\n    x = fminf(x, limit);\n    g = fmaxf(fminf(g, limit), -limit);\n\n    float out_glu = x / (1.0f + expf(-x * alpha));\n    out_glu = out_glu * (1.0f + g);\n    return out_glu;\n}\n"
  },
  {
    "path": "src/ggml-cuda/upscale.cu",
    "content": "#include \"upscale.cuh\"\n\n// Every kernel below maps one thread to one dst element. The flat element index, the element\n// count and all byte offsets are computed in 64-bit: with 32-bit arithmetic the index\n// (blockIdx.x * blockDim.x) and the ne10*ne11*ne12*ne13 product overflow once a tensor has\n// more than INT_MAX elements. The byte strides (nb) are taken as int64_t so that large strides\n// are not truncated.\n\n// Nearest-neighbour upscale: dst element (i10, i11, i12, i13) reads the src element at\n// (i10/sf0, i11/sf1, i12/sf2, i13/sf3), truncated to integers.\nstatic __global__ void upscale_f32(const float * x, float * dst,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int ne10, const int ne11, const int ne12, const int ne13,\n        const float sf0, const float sf1, const float sf2, const float sf3) {\n    const int64_t index = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;\n    if (index >= (int64_t) ne10 * ne11 * ne12 * ne13) {\n        return;\n    }\n\n    const int i10 = index % ne10;\n    const int i11 = (index / ne10) % ne11;\n    const int i12 = (index / ((int64_t) ne10 * ne11)) % ne12;\n    const int i13 = (index / ((int64_t) ne10 * ne11 * ne12)) % ne13;\n\n    const int i00 = i10 / sf0;\n    const int i01 = i11 / sf1;\n    const int i02 = i12 / sf2;\n    const int i03 = i13 / sf3;\n\n    dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );\n}\n\n// Bilinear interpolation over dims 0 and 1 using the 2x2 neighbouring src elements (clamped to\n// the src bounds); dims 2 and 3 use nearest-neighbour indexing.\nstatic __global__ void upscale_f32_bilinear(const float * x, float * dst,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int ne00_src, const int ne01_src,\n        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,\n        const float sf0, const float sf1, const float sf2, const float sf3,\n        const float pixel_offset) {\n    const int64_t index              = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;\n    const int64_t dst_total_elements = (int64_t) ne10_dst * ne11_dst * ne12_dst * ne13_dst;\n\n    if (index >= dst_total_elements) {\n        return;\n    }\n\n    const int i10_dst = index % ne10_dst;\n    const int i11_dst = (index / ne10_dst) % ne11_dst;\n    const int i12_dst = (index / ((int64_t) ne10_dst * ne11_dst)) % ne12_dst;\n    const int i13_dst = index / ((int64_t) ne10_dst * ne11_dst * ne12_dst);\n\n    const int i02_src = (int)(i12_dst / sf2);\n    const int i03_src = (int)(i13_dst / sf3);\n\n    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;\n    int y0_src    = (int)floorf(y_src_f);\n    int y1_src    = y0_src + 1;\n\n    y0_src = max(0, min(y0_src, ne01_src - 1));\n    y1_src = max(0, min(y1_src, ne01_src - 1));\n\n    float dy = y_src_f - (float)y0_src;\n    dy       = max(0.0f, min(dy, 1.0f));\n\n    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;\n    int x0_src    = (int)floorf(x_src_f);\n    int x1_src    = x0_src + 1;\n\n    x0_src = max(0, min(x0_src, ne00_src - 1));\n    x1_src = max(0, min(x1_src, ne00_src - 1));\n\n    float dx = x_src_f - (float)x0_src;\n    dx = max(0.0f, min(dx, 1.0f));\n\n    const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);\n    const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);\n    const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);\n    const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);\n\n    const float val_a = *p_a;\n    const float val_b = *p_b;\n    const float val_c = *p_c;\n    const float val_d = *p_d;\n\n    float result = val_a * (1.0f - dx) * (1.0f - dy) +\n                   val_b * dx * (1.0f - dy) +\n                   val_c * (1.0f - dx) * dy +\n                   val_d * dx * dy;\n\n    dst[index] = result;\n}\n\n// Similar to F.interpolate(..., mode=\"bilinear\", align_corners=False, antialias=True)\n// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp\nstatic __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int ne00_src, const int ne01_src,\n        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,\n        const float sf0, const float sf1, const float sf2, const float sf3,\n        const float pixel_offset) {\n    const int64_t index              = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;\n    const int64_t dst_total_elements = (int64_t) ne10_dst * ne11_dst * ne12_dst * ne13_dst;\n\n    if (index >= dst_total_elements) {\n        return;\n    }\n\n    const int i10_dst = index % ne10_dst;\n    const int i11_dst = (index / ne10_dst) % ne11_dst;\n    const int i12_dst = (index / ((int64_t) ne10_dst * ne11_dst)) % ne12_dst;\n    const int i13_dst = index / ((int64_t) ne10_dst * ne11_dst * ne12_dst);\n\n    const int i02_src = (int)(i12_dst / sf2);\n    const int i03_src = (int)(i13_dst / sf3);\n\n    const float y = ((float)i11_dst + pixel_offset) / sf1;\n    const float x = ((float)i10_dst + pixel_offset) / sf0;\n\n    // support and invscale, minimum 1 pixel for bilinear\n    const float support1  = max(1.0f / sf1, 1.0f);\n    const float invscale1 = 1.0f / support1;\n    const float support0  = max(1.0f / sf0, 1.0f);\n    const float invscale0 = 1.0f / support0;\n\n    // the range of source pixels that contribute\n    const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));\n    const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));\n    const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));\n    const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));\n\n    // bilinear filter with antialiasing\n    float val = 0.0f;\n    float total_weight = 0.0f;\n\n    auto triangle_filter = [](float x) -> float {\n        return max(1.0f - fabsf(x), 0.0f);\n    };\n\n    for (int64_t sy = y_min; sy < y_max; sy++) {\n        const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);\n\n        for (int64_t sx = x_min; sx < x_max; sx++) {\n            const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);\n            const float weight = weight_x * weight_y;\n\n            if (weight <= 0.0f) {\n                continue;\n            }\n\n            const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03);\n            val += pixel * weight;\n            total_weight += weight;\n        }\n    }\n\n    if (total_weight > 0.0f) {\n        val /= total_weight;\n    }\n\n    dst[index] = val;\n}\n\nnamespace bicubic_interpolation {\n// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm\n__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)\n\nstatic __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };\nstatic __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };\n\nstatic __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {\n    const float w0 = weight2(x + 1);\n    const float w1 = weight1(x + 0);\n    const float w2 = weight1(1 - x);\n    const float w3 = weight2(2 - x);\n    return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;\n};\n} // namespace bicubic_interpolation\n\n// Bicubic interpolation over dims 0 and 1 using the 4x4 neighbouring src elements (clamped to\n// the src bounds); dims 2 and 3 use nearest-neighbour indexing.\nstatic __global__ void upscale_f32_bicubic(const float * x, float * dst,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int ne00_src, const int ne01_src,\n        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,\n        const float sf0, const float sf1, const float sf2, const float sf3,\n        const float pixel_offset) {\n    using bicubic_interpolation::bicubic;\n\n    const int64_t index              = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;\n    const int64_t dst_total_elements = (int64_t) ne10_dst * ne11_dst * ne12_dst * ne13_dst;\n\n    if (index >= dst_total_elements) {\n        return;\n    }\n\n    const int i10_dst = index % ne10_dst;\n    const int i11_dst = (index / ne10_dst) % ne11_dst;\n    const int i12_dst = (index / ((int64_t) ne10_dst * ne11_dst)) % ne12_dst;\n    const int i13_dst = index / ((int64_t) ne10_dst * ne11_dst * ne12_dst);\n\n    const int i02_src = (int)(i12_dst / sf2);\n    const int i03_src = (int)(i13_dst / sf3);\n\n    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;\n    const int y0_src    = (int)floorf(y_src_f);\n    const float dy      = y_src_f - (float)y0_src;\n\n    const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;\n    const int x0_src    = (int)floorf(x_src_f);\n    const float dx      = x_src_f - (float)x0_src;\n\n    const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;\n\n    auto load = [=](int x_off, int y_off) -> float {\n        int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));\n        int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));\n        return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);\n    };\n\n    const float result = bicubic(\n        bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),\n        bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),\n        bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),\n        bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);\n\n    dst[index] = result;\n}\n\nstatic void upscale_f32_cuda(const float * x, float * dst,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int ne10, const int ne11, const int ne12, const int ne13,\n        const float sf0, const float sf1, const float sf2, const float sf3,\n        cudaStream_t stream) {\n    const int64_t dst_size   = (int64_t) ne10 * ne11 * ne12 * ne13;\n    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;\n\n    upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);\n}\n\nstatic void upscale_f32_bilinear_cuda(const float * x, float * dst,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int ne00_src, const int ne01_src,\n        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,\n        const float sf0, const float sf1, const float sf2, const float sf3,\n        const float pixel_offset, bool antialias, cudaStream_t stream) {\n    const int64_t dst_size   = (int64_t) ne10_dst * ne11_dst * ne12_dst * ne13_dst;\n    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;\n\n    if (antialias) {\n        upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);\n    } else {\n        upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);\n    }\n}\n\nstatic void upscale_f32_bicubic_cuda(const float * x, float * dst,\n        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,\n        const int ne00_src, const int ne01_src,\n        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,\n        const float sf0, const float sf1, const float sf2, const float sf3,\n        const float pixel_offset, cudaStream_t stream) {\n    const int64_t dst_size   = (int64_t) ne10_dst * ne11_dst * ne12_dst * ne13_dst;\n    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;\n\n    upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);\n}\n\n// Upscales src0 into dst (both F32). The low byte of op_params[0] selects the interpolation\n// mode (nearest, bilinear, bicubic); GGML_SCALE_FLAG_ALIGN_CORNERS switches dims 0/1 to\n// align-corners scale factors with a zero pixel offset, and GGML_SCALE_FLAG_ANTIALIAS selects the\n// antialiased bilinear kernel.\nvoid ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    const int mode_flags = dst->op_params[0];\n    const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);\n\n    float sf0 = (float)dst->ne[0]/src0->ne[0];\n    float sf1 = (float)dst->ne[1]/src0->ne[1];\n    float sf2 = (float)dst->ne[2]/src0->ne[2];\n    const float sf3 = (float)dst->ne[3]/src0->ne[3];\n\n    float pixel_offset = 0.5f;\n    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {\n        sf0          = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;\n        sf1          = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;\n        pixel_offset = 0.0f;\n    }\n\n    if (mode == GGML_SCALE_MODE_NEAREST) {\n        upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);\n    } else if (mode == GGML_SCALE_MODE_BILINEAR) {\n        const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);\n        upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                                 src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n                                 sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);\n    } else if (mode == GGML_SCALE_MODE_BICUBIC) {\n        upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],\n                                 src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n                                 sf0, sf1, sf2, sf3, pixel_offset, stream);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/upscale.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n#define CUDA_UPSCALE_BLOCK_SIZE 256\n\nvoid ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-cuda/vecdotq.cuh",
    "content": "#pragma once\n\n#include \"common.cuh\"\n\n#include <cstdint>\n\nstatic __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {\n    const uint8_t * x8 = (const uint8_t *) x;\n\n    int x32  = x8[4*i32 + 0] <<  0;\n    x32     |= x8[4*i32 + 1] <<  8;\n    x32     |= x8[4*i32 + 2] << 16;\n    x32     |= x8[4*i32 + 3] << 24;\n\n    return x32;\n}\n\nstatic __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {\n    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment\n\n    int x32  = x16[2*i32 + 0] <<  0;\n    x32     |= x16[2*i32 + 1] << 16;\n\n    return x32;\n}\n\nstatic __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {\n    return ((const int *) x)[i32]; // assume at least 4 byte alignment\n}\n\n// q4 contains 8 indices with 4 bit each.\n// This function selects those bytes from table that are at those indices and returns them as int2.\n// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.\nstatic __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {\n#if defined(GGML_USE_HIP)\n    // Load the 16-byte table into four 32-bit unsigned integers.\n    const uint32_t *values = (const uint32_t *)table;\n\n    const uint32_t q_even = q4;\n    const uint32_t q_odd  = (q4 >> 4);\n\n    // Perform lookups in the lower half of the table (indices 0-7).\n    uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);\n    uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);\n\n    // Perform lookups in the upper half of the table (indices 8-15).\n    uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);\n    uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);\n\n    // Select between the low and high results based on the MSB of 
each index nibble.\n    uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);\n    uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);\n    uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);\n    uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);\n\n    return make_int2(res_x, res_y);\n#elif !defined(GGML_USE_MUSA)\n    // CUDA does not have an instruction for selecting bytes with 4 bit indices.\n    // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.\n    const uint32_t * table32 = (const uint32_t *) table;\n\n    // __byte_perm selects bytes based on the lower 16 bits in its third argument.\n    // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.\n    // To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits.\n    // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.\n    uint32_t tmp[2];\n    const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));\n#pragma unroll\n    for (uint32_t i = 0; i < 2; ++i) {\n        const uint32_t shift = 16 * i;\n\n        const uint32_t low  = __byte_perm(table32[0], table32[1], q4 >> shift);\n        const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);\n        tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);\n    }\n\n    // tmp contains the bytes from tyble in the same order as the 4 bit indices in q4.\n    // However, for the result we need ints with all even/odd 4 bit indices in q4.\n    // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.\n    return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));\n#else\n    // Generic implementation.\n    const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;\n    const int8_t * q0_8   = (const int8_t *) &q0_32;\n    const 
char4    val0_8 = make_char4(\n        table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);\n\n    const int      q1_32  = (q4 >> 4) & 0x0F0F0F0F;\n    const int8_t * q1_8   = (const int8_t *) &q1_32;\n    const char4    val1_8 = make_char4(\n        table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);\n\n    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));\n#endif\n}\n\nstatic __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {\n    // v is a 7 bit int, with the 8th sign being encodable as popcnt\n    // with xor we can \"correct\" the bit instead of having to mask\n    const uint32_t p = __popc(v) & 1;\n    const uint32_t s = v ^ p << 7;\n    // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors\n    return s * 0x01010101;\n}\n\n// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called\n// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q\n\n#define VDR_Q4_0_Q8_1_MMVQ 2\n#define VDR_Q4_0_Q8_1_MMQ  4\n\ntemplate <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(\n    const int * v, const int * u, const float & d4, const half2 & ds8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;\n        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;\n\n        // SIMD dot product of quantized values\n        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);\n        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);\n    }\n\n    const float2 ds8f = __half22float2(ds8);\n\n    // second part effectively subtracts 8 from each quant value\n    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);\n}\n\n#define VDR_Q4_1_Q8_1_MMVQ 2\n#define VDR_Q4_1_Q8_1_MMQ  4\n\ntemplate <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(\n    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i 
< vdr; ++i) {\n        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;\n        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;\n\n        // SIMD dot product of quantized values\n        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);\n        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);\n    }\n\n#ifdef FAST_FP16_AVAILABLE\n    const float2 tmp = __half22float2(__hmul2(dm4, ds8));\n    const float d4d8 = tmp.x;\n    const float m4s8 = tmp.y;\n#else\n    const float2 dm4f = __half22float2(dm4);\n    const float2 ds8f = __half22float2(ds8);\n    const float d4d8 = dm4f.x * ds8f.x;\n    const float m4s8 = dm4f.y * ds8f.y;\n#endif // FAST_FP16_AVAILABLE\n\n    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it\n    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));\n}\n\n#define VDR_Q5_0_Q8_1_MMVQ 2\n#define VDR_Q5_0_Q8_1_MMQ  4\n\ntemplate <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(\n    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits\n        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4\n        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12\n        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20\n        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28\n        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values\n\n        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits\n        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4\n        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12\n        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20\n        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28\n        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values\n    }\n\n    const float2 ds8f = 
__half22float2(ds8);\n\n    // second part effectively subtracts 16 from each quant value\n    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);\n}\n\n#define VDR_Q5_1_Q8_1_MMVQ 2\n#define VDR_Q5_1_Q8_1_MMQ  4\n\ntemplate <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(\n    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits\n        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4\n        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12\n        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20\n        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28\n        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values\n\n        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits\n        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4\n        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12\n        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20\n        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28\n        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values\n    }\n\n#ifdef FAST_FP16_AVAILABLE\n    const float2 tmp = __half22float2(__hmul2(dm5, ds8));\n    const float d5d8 = tmp.x;\n    const float m5s8 = tmp.y;\n#else\n    const float2 dm5f = __half22float2(dm5);\n    const float2 ds8f = __half22float2(ds8);\n    const float d5d8 = dm5f.x * ds8f.x;\n    const float m5s8 = dm5f.y * ds8f.y;\n#endif // FAST_FP16_AVAILABLE\n\n    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it\n    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);\n}\n\n#define VDR_Q8_0_Q8_1_MMVQ 2\n#define VDR_Q8_0_Q8_1_MMQ 8\n\ntemplate <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(\n    const int * v, 
const int * u, const T & d8_0, const T & d8_1) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        // SIMD dot product of quantized values\n        sumi = ggml_cuda_dp4a(v[i], u[i], sumi);\n    }\n\n    return d8_0*d8_1 * ((T) sumi);\n}\n\ntemplate <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(\n    const int * v, const int * u, const half2 & dm8, const half2 & ds8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        // SIMD dot product of quantized values\n        sumi = ggml_cuda_dp4a(v[i], u[i], sumi);\n    }\n\n#ifdef FAST_FP16_AVAILABLE\n    const float2 tmp = __half22float2(__hmul2(dm8, ds8));\n    const float d8d8 = tmp.x;\n    const float m8s8 = tmp.y;\n#else\n    const float2 dm8f = __half22float2(dm8);\n    const float2 ds8f = __half22float2(ds8);\n    const float d8d8 = dm8f.x * ds8f.x;\n    const float m8s8 = dm8f.y * ds8f.y;\n#endif // FAST_FP16_AVAILABLE\n\n    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it\n    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);\n}\n\ntemplate <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(\n    const int * v, const int * u, const float * d8_0, const float & d8_1) {\n\n    float sumf = 0.0f;\n\n#pragma unroll\n    for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {\n        int sumi = 0;\n\n#pragma unroll\n        for (int i = i0; i < i0 + QI8_0/2; ++i) {\n            // SIMD dot product of quantized values\n            sumi = ggml_cuda_dp4a(v[i], u[i], sumi);\n        }\n\n        sumf += d8_0[i0/(QI8_0/2)]*sumi;\n    }\n\n    return d8_1*sumf;\n}\n\n#define VDR_MXFP4_Q8_1_MMVQ 2\n#define VDR_MXFP4_Q8_1_MMQ  4\n\nstatic __device__ __forceinline__ float vec_dot_mxfp4_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;\n\n    const int * q8 = 
(const int *) bq8_1->qs + iqs;\n\n    int sumi = 0;\n#pragma unroll\n    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {\n        const int aux_q4 = get_int_b1(bq4->qs, iqs + l);\n        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);\n\n        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);\n        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);\n    }\n\n    const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);\n    return d * sumi;\n}\n\n#define VDR_Q2_K_Q8_1_MMVQ 1\n#define VDR_Q2_K_Q8_1_MMQ  4\n\n// contiguous v/x values\nstatic __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(\n    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,\n    const half2 & dm2, const float * __restrict__ d8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR2_K; ++i) {\n        const int sc = scales[2*i];\n\n        const int vi = (v >> (2*i)) & 0x03030303;\n\n        sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product\n\n        // fill int with 4x m\n        int m = sc >> 4;\n        m |= m <<  8;\n        m |= m << 16;\n        sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values\n    }\n\n    const float2 dm2f = __half22float2(dm2);\n\n    return dm2f.x*sumf_d - dm2f.y*sumf_m;\n}\n\n// contiguous v/x + u/y values\ntemplate <int ns8>\nstatic __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(\n    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {\n\n    float sumf    = 0.0f;\n    float sumf_d8 = 0.0f;\n\n#pragma unroll\n    for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {\n        const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);\n        int sumi_d0 = 0;\n\n        const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);\n        int sumi_d1 = 0;\n\n#pragma unroll\n        for 
(int i = i0; i < i0 + QI8_1/2; ++i) {\n            sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);\n        }\n        sumf_d8 += dm2f0.x * sumi_d0;\n\n#pragma unroll\n        for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {\n            sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);\n        }\n        sumf_d8 += dm2f1.x * sumi_d1;\n\n        if (i0/QI8_1 < ns8) {\n            const float2 s8f = __half22float2(s8[i0/QI8_1]);\n            sumf -= dm2f0.y*s8f.x;\n            sumf -= dm2f1.y*s8f.y;\n        } else {\n            int sumi_m0 = 0;\n#pragma unroll\n            for (int i = i0; i < i0 + QI8_1/2; ++i) {\n                sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);\n            }\n            sumf_d8 -= dm2f0.y * sumi_m0;\n\n            int sumi_m1 = 0;\n#pragma unroll\n            for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {\n                sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);\n            }\n            sumf_d8 -= dm2f1.y * sumi_m1;\n        }\n    }\n\n    return sumf + d8*sumf_d8;\n}\n\n#define VDR_Q3_K_Q8_1_MMVQ 1\n#define VDR_Q3_K_Q8_1_MMQ  2\n\n// contiguous v/x values\nstatic __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(\n    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,\n    const int & scale_offset, const float & d3, const float * __restrict__ d8) {\n\n    float sumf = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR3_K; ++i) {\n        const int isc = scale_offset + 2*i;\n\n        const int isc_low = isc % (QK_K/32);\n        const int sc_shift_low = 4 * (isc / (QK_K/32));\n        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;\n\n        const int isc_high = isc % (QK_K/64);\n        const int sc_shift_high = 2 * (isc / (QK_K/64));\n        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;\n\n        const int sc = (sc_low | sc_high) - 32;\n\n        const int vil = (vl >> (2*i)) & 0x03030303;\n\n      
  const int vih = ((vh >> i) << 2) & 0x04040404;\n\n        const int vi = __vsubss4(vil, vih);\n\n        sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product\n    }\n\n    return d3 * sumf;\n}\n\n// contiguous v/x + u/y values\nstatic __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(\n    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,\n    const float & d3, const float & d8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {\n        int sumi_sc = 0;\n\n#pragma unroll\n        for (int i = i0; i < i0 + QI8_1/2; ++i) {\n            sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product\n        }\n\n        sumi += sumi_sc * scales[i0 / (QI8_1/2)];\n    }\n\n    return d3*d8 * sumi;\n}\n\n#define VDR_Q4_K_Q8_1_MMVQ 2\n#define VDR_Q4_K_Q8_1_MMQ  8\n\n// contiguous v/x values\nstatic __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(\n    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,\n    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR4_K; ++i) {\n        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;\n        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;\n\n        const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product\n        const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u\n\n        sumf_d += d8[i] * (dot1 * sc[i]);\n        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values\n    }\n\n    const float2 dm4f = __half22float2(dm4);\n\n    return dm4f.x*sumf_d - dm4f.y*sumf_m;\n}\n\n// contiguous v/x + u/y values\nstatic __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(\n    const 
int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,\n    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {\n        int sumi_d = 0;\n\n#pragma unroll\n        for (int j = 0; j < QI8_1; ++j) {\n            sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product\n        }\n\n        const float2 ds8f = __half22float2(ds8[i]);\n\n        sumf_d += ds8f.x * (sc[i] * sumi_d);\n        sumf_m += ds8f.y *   m[i]; // sum of q8_1 block * q4_K min val\n    }\n\n    const float2 dm4f = __half22float2(dm4);\n\n    return dm4f.x*sumf_d - dm4f.y*sumf_m;\n}\n\n#define VDR_Q5_K_Q8_1_MMVQ 2\n#define VDR_Q5_K_Q8_1_MMQ  8\n\n// contiguous v/x values\nstatic __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(\n    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,\n    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR5_K; ++i) {\n        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;\n        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;\n\n        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;\n        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;\n\n        const int v0i = vl0i | vh0i;\n        const int v1i = vl1i | vh1i;\n\n        const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product\n        const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u\n\n        sumf_d += d8[i] * (dot1 * sc[i]);\n        sumf_m += d8[i] * (dot2 * m[i]);\n\n    }\n\n    const float2 dm5f = __half22float2(dm5);\n\n    return dm5f.x*sumf_d - 
dm5f.y*sumf_m;\n}\n\n// contiguous v/x + u/y values\nstatic __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(\n    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,\n    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {\n        int sumi_d = 0;\n\n#pragma unroll\n        for (int j = 0; j < QI8_1; ++j) {\n            sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product\n        }\n\n        const float2 ds8f = __half22float2(ds8[i]);\n\n        sumf_d += ds8f.x * (sc[i] * sumi_d);\n        sumf_m += ds8f.y *   m[i]; // sum of q8_1 block * q4_K min val\n    }\n\n    const float2 dm4f = __half22float2(dm4);\n\n    return dm4f.x*sumf_d - dm4f.y*sumf_m;\n}\n\n#define VDR_Q6_K_Q8_1_MMVQ 1\n#define VDR_Q6_K_Q8_1_MMQ  8\n\n// contiguous v/x values\nstatic __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(\n    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,\n    const float & d, const float * __restrict__ d8) {\n\n    float sumf = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR6_K; ++i) {\n        const int sc = scales[4*i];\n\n        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;\n\n        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;\n\n        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32\n\n        sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product\n    }\n\n    return d*sumf;\n}\n\n// contiguous v/x + u/y values\nstatic __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(\n    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,\n    const float & d6, const float * __restrict__ d8) {\n\n    float sumf_d = 0.0f;\n\n    const int      sc_packed = get_int_b4(sc, 
0);\n    const int8_t * sc_reg    = (const int8_t *) &sc_packed;\n\n#pragma unroll\n    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {\n        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale\n\n#pragma unroll\n        for (int i = i0; i < i0 + 2; ++i) {\n            sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product\n            sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product\n\n            sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product\n            sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product\n        }\n\n        sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);\n    }\n\n    return d6 * sumf_d;\n}\n\nstatic __device__ __forceinline__ float vec_dot_q4_0_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;\n\n    int v[VDR_Q4_0_Q8_1_MMVQ];\n    int u[2*VDR_Q4_0_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {\n        v[i]     = get_int_b2(bq4_0->qs, iqs + i);\n        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);\n        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_0);\n    }\n\n    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);\n}\n\n\nstatic __device__ __forceinline__ float vec_dot_q4_1_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;\n\n    int v[VDR_Q4_1_Q8_1_MMVQ];\n    int u[2*VDR_Q4_1_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {\n        v[i]     = get_int_b4(bq4_1->qs, iqs + i);\n        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);\n        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_1);\n    }\n\n    return 
vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);\n}\n\nstatic __device__ __forceinline__ float vec_dot_q5_0_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;\n\n    int vl[VDR_Q5_0_Q8_1_MMVQ];\n    int vh[VDR_Q5_0_Q8_1_MMVQ];\n    int  u[2*VDR_Q5_0_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {\n        vl[i]    = get_int_b2(bq5_0->qs, iqs + i);\n        vh[i]    = get_int_b2(bq5_0->qh, 0) >> (4 * (iqs + i));\n        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);\n        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_0);\n    }\n\n    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);\n}\n\nstatic __device__ __forceinline__ float vec_dot_q5_1_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;\n\n    int vl[VDR_Q5_1_Q8_1_MMVQ];\n    int vh[VDR_Q5_1_Q8_1_MMVQ];\n    int  u[2*VDR_Q5_1_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {\n        vl[i]    = get_int_b4(bq5_1->qs, iqs + i);\n        vh[i]    = get_int_b4(bq5_1->qh, 0) >> (4 * (iqs + i));\n        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);\n        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_1);\n    }\n\n    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);\n}\n\nstatic __device__ __forceinline__ float vec_dot_q8_0_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;\n\n    int v[VDR_Q8_0_Q8_1_MMVQ];\n    int u[VDR_Q8_0_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {\n        v[i] = get_int_b2(bq8_0->qs, iqs + i);\n        u[i] = 
get_int_b4(bq8_1->qs, iqs + i);\n    }\n\n    return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));\n}\n\nstatic __device__ __forceinline__ float vec_dot_q2_K_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;\n\n    const int bq8_offset = QR2_K * (iqs / QI8_1);\n    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);\n\n    const uint8_t * scales = bq2_K->scales + scale_offset;\n\n    const int v = get_int_b4(bq2_K->qs, iqs);\n    int    u[QR2_K];\n    float d8[QR2_K];\n\n#pragma unroll\n    for (int i = 0; i < QR2_K; ++ i) {\n        u[i]  = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);\n        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);\n    }\n\n    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);\n}\n\nstatic __device__ __forceinline__ float vec_dot_q3_K_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;\n\n    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));\n    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);\n\n    const float d = bq3_K->d;\n\n    const int vl = get_int_b2(bq3_K->qs, iqs);\n\n    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted\n    const int vh = ~get_int_b2(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;\n\n    int    u[QR3_K];\n    float d8[QR3_K];\n\n#pragma unroll\n    for (int i = 0; i < QR3_K; ++i) {\n        u[i]  = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);\n        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);\n    }\n\n    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);\n}\n\nstatic __device__ __forceinline__ float vec_dot_q4_K_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * 
__restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;\n\n    int    v[2];\n    int    u[2*QR4_K];\n    float d8[QR4_K];\n\n    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6\n    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));\n\n    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12\n    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44\n    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76\n    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108\n\n    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));\n    v[0] = q4[0];\n    v[1] = q4[4];\n\n    const uint16_t * scales = (const uint16_t *)bq4_K->scales;\n    uint16_t aux[2];\n    const int j = bq8_offset/2;\n    if (j < 2) {\n        aux[0] = scales[j+0] & 0x3f3f;\n        aux[1] = scales[j+2] & 0x3f3f;\n    } else {\n        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);\n        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);\n    }\n    const uint8_t * sc = (const uint8_t *)aux;\n    const uint8_t * m  = sc + 2;\n\n    for (int i = 0; i < QR4_K; ++i) {\n        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;\n        d8[i] = __low2float(bq8i->ds);\n\n        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);\n        u[2*i+0] = q8[0];\n        u[2*i+1] = q8[4];\n    }\n\n    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);\n}\n\nstatic __device__ __forceinline__ float vec_dot_q5_K_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;\n\n    int   vl[2];\n    int   vh[2];\n    int    u[2*QR5_K];\n    float d8[QR5_K];\n\n    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));\n    const int * ql = (const int *)(bq5_K->qs + 16 * 
bq8_offset + 4 * ((iqs/2)%4));\n    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));\n\n    vl[0] = ql[0];\n    vl[1] = ql[4];\n\n    vh[0] = qh[0] >> bq8_offset;\n    vh[1] = qh[4] >> bq8_offset;\n\n    const uint16_t * scales = (const uint16_t *)bq5_K->scales;\n    uint16_t aux[2];\n    const int j = bq8_offset/2;\n    if (j < 2) {\n        aux[0] = scales[j+0] & 0x3f3f;\n        aux[1] = scales[j+2] & 0x3f3f;\n    } else {\n        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);\n        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);\n    }\n    const uint8_t * sc = (const uint8_t *)aux;\n    const uint8_t * m  = sc + 2;\n\n#pragma unroll\n    for (int i = 0; i < QR5_K; ++i) {\n        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;\n        d8[i] = __low2float(bq8i->ds);\n\n        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);\n        u[2*i+0] = q8[0];\n        u[2*i+1] = q8[4];\n    }\n\n    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);\n}\n\nstatic __device__ __forceinline__ float vec_dot_q6_K_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;\n\n    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);\n    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);\n    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));\n\n    const int vl = get_int_b2(bq6_K->ql, iqs);\n    const int vh = get_int_b2(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;\n\n    const int8_t * scales = bq6_K->scales + scale_offset;\n\n    int    u[QR6_K];\n    float d8[QR6_K];\n\n#pragma unroll\n    for (int i = 0; i < QR6_K; ++i) {\n        u[i]  = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);\n        d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);\n    }\n\n  
  return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);\n}\n\n#define VDR_IQ2_XXS_Q8_1_MMVQ 2\n#define VDR_IQ2_XXS_Q8_1_MMQ  2\n\nstatic __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;\n\n    const int q2 = get_int_b2(bq2->qs, iqs);\n    const uint8_t * aux8 = (const uint8_t *) &q2;\n    const uint32_t aux32 = get_int_b2(bq2->qs, iqs + 1);\n\n    int sumi = 0;\n#pragma unroll\n    for (int k0 = 0; k0 < 8; k0 += 2) {\n        const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];\n        const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));\n\n        const int signs0 = __vcmpne4(signs & 0x08040201, 0);\n        const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);\n        const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);\n        sumi = ggml_cuda_dp4a(grid0, u0, sumi);\n\n        const int signs1 = __vcmpne4(signs & 0x80402010, 0);\n        const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);\n        const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);\n        sumi = ggml_cuda_dp4a(grid1, u1, sumi);\n    }\n\n    const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)\n    sumi = sumi * ls / 8;           // (sumi * scale + sumi / 2) / 4\n    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);\n    return d * sumi;\n}\n\n#define VDR_IQ2_XS_Q8_1_MMVQ 2\n#define VDR_IQ2_XS_Q8_1_MMQ  2\n\nstatic __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;\n\n    const int2 q2_packed = make_int2(get_int_b2(bq2->qs, iqs + 0), get_int_b2(bq2->qs, iqs + 1));\n    const uint16_t * q2 = (const uint16_t *) &q2_packed;\n    const int ls0 = bq2->scales[iqs/2] & 0x0F;\n    
const int ls1 = bq2->scales[iqs/2] >> 4;\n\n    int sumi0 = 0;\n    int sumi1 = 0;\n#pragma unroll\n    for (int l0 = 0; l0 < 8; l0 += 2) {\n        const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];\n        const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);\n\n        const int signs0 = __vcmpne4(signs & 0x08040201, 0);\n        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);\n        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);\n\n        const int signs1 = __vcmpne4(signs & 0x80402010, 0);\n        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);\n        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);\n\n        if (l0 < 4) {\n            sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);\n            sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);\n        } else {\n            sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);\n            sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);\n        }\n    }\n    const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;\n    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);\n    return d * sumi;\n}\n\n#define VDR_IQ2_S_Q8_1_MMVQ 2\n#define VDR_IQ2_S_Q8_1_MMQ  2\n\nstatic __device__ __forceinline__ float vec_dot_iq2_s_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;\n\n    const int       qs_packed = get_int_b2(bq2->qs, iqs/2);\n    const uint8_t * qs        = (const uint8_t *) &qs_packed;\n\n    const int qh = bq2->qh[iqs/2];\n\n    const int       signs_packed_32 = get_int_b2(bq2->qs, QK_K/32 + iqs/2);\n    const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;\n\n    const int ls0 = bq2->scales[iqs/2] & 0x0F;\n    const int ls1 = bq2->scales[iqs/2] >> 4;\n\n    int sumi0 = 0;\n    int sumi1 = 0;\n#pragma unroll\n    for (int l0 = 0; l0 < 8; l0 += 2) {\n        const int * grid_pos = (const int *)(iq2s_grid + (qs[l0/2] 
| ((qh << (8-l0)) & 0x300)));\n\n        const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);\n        const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);\n\n        const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);\n        const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);\n\n        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);\n        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);\n\n        if (l0 < 4) {\n            sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);\n            sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);\n        } else {\n            sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);\n            sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);\n        }\n    }\n    const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;\n\n    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);\n    return d * sumi;\n}\n\n#define VDR_IQ3_XXS_Q8_1_MMVQ 2\n#define VDR_IQ3_XXS_Q8_1_MMQ  2\n\nstatic __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_iq3_xxs * bq3 = (const block_iq3_xxs *) vbq + kbx;\n\n    const int2 q3_packed = make_int2(get_int_b2(bq3->qs, iqs), get_int_b2(bq3->qs, iqs+1));\n    const uint8_t * q3 = (const uint8_t *) &q3_packed;\n    const uint32_t aux32 = get_int_b2(bq3->qs, QK_K/16 + iqs/2);\n\n    int sumi = 0;\n#pragma unroll\n    for (int l0 = 0; l0 < 8; l0 += 2) {\n        const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);\n        const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));\n\n        const int signs0 = __vcmpne4(signs & 0x08040201, 0);\n        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);\n\n        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);\n\n        const int signs1 = 
__vcmpne4(signs & 0x80402010, 0);\n        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);\n\n        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);\n\n        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);\n        sumi = ggml_cuda_dp4a(grid_h, u1, sumi);\n    }\n\n    const int ls = aux32 >> 28;\n    sumi = (ls*sumi + sumi/2)/2;\n    const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);\n    return d * sumi;\n}\n\n#define VDR_IQ3_S_Q8_1_MMVQ 2\n#define VDR_IQ3_S_Q8_1_MMQ  2\n\n// TODO: don't use lookup table for signs\nstatic __device__ __forceinline__ float vec_dot_iq3_s_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_iq3_s * bq3 = (const block_iq3_s *) vbq + kbx;\n\n    const int2      qs_packed = make_int2(get_int_b2(bq3->qs, iqs + 0), get_int_b2(bq3->qs, iqs + 1));\n    const uint8_t * qs        = (const uint8_t *) &qs_packed;\n\n    const int qh = bq3->qh[iqs/2];\n\n    const int       signs_packed_32 = get_int_b2(bq3->signs, iqs/2);\n    const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;\n\n    int sumi = 0;\n#pragma unroll\n    for (int l0 = 0; l0 < 8; l0 += 2) {\n        const int2 grid_pos = make_int2(\n            iq3s_grid[qs[l0 + 0] | ((qh << (8 - l0)) & 0x100)],\n            iq3s_grid[qs[l0 + 1] | ((qh << (7 - l0)) & 0x100)]);\n\n        const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);\n        const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);\n\n        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);\n        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);\n\n        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);\n        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);\n\n        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);\n        sumi = 
ggml_cuda_dp4a(grid_h, u1, sumi);\n    }\n\n    sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F);\n\n    const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);\n    return d * sumi;\n}\n\n#define VDR_IQ1_S_Q8_1_MMVQ 1\n#define VDR_IQ1_S_Q8_1_MMQ  1\n\nstatic __device__ __forceinline__ float vec_dot_iq1_s_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n    const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;\n\n    const int       qs_packed = get_int_b2(bq1->qs, iqs);\n    const uint8_t * qs        = (const uint8_t *) &qs_packed;\n\n    const int qh = bq1->qh[iqs];\n\n    int sumi = 0;\n#pragma unroll\n    for (int l0 = 0; l0 < 8; l0 += 2) {\n        const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];\n\n        const int grid0 = (grid >> 0) & 0x0F0F0F0F;\n        const int grid1 = (grid >> 4) & 0x0F0F0F0F;\n\n        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);\n        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);\n\n        sumi = ggml_cuda_dp4a(grid0, u0, sumi);\n        sumi = ggml_cuda_dp4a(grid1, u1, sumi);\n    }\n\n    const float  d1q   = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);\n    const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);\n    const float2 ds    = __half22float2(bq8_1[iqs].ds);\n    return d1q * (ds.x*sumi + ds.y*delta);\n}\n\n#define VDR_IQ1_M_Q8_1_MMVQ 1\n#define VDR_IQ1_M_Q8_1_MMQ  1\n\nstatic __device__ __forceinline__ float vec_dot_iq1_m_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;\n\n    const int       qs_packed = get_int_b4(bq1->qs, iqs);\n    const uint8_t * qs        = (const uint8_t *) &qs_packed;\n\n    int   sumi[2] = {0};\n    float sumf[2] = {0.0f};\n#pragma unroll\n    for (int l0 = 0; l0 < 8; l0 += 2) {\n        
const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));\n\n        const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];\n\n        const int grid0 = (grid >> 0) & 0x0F0F0F0F;\n        const int grid1 = (grid >> 4) & 0x0F0F0F0F;\n\n        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);\n        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);\n\n        sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]);\n        sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]);\n\n        const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);\n        int sumy = 0;\n        sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy);\n        sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy);\n        sumf[l0/4] += delta*sumy;\n    }\n\n    const uint16_t * sc = (const uint16_t *) bq1->scales;\n\n    iq1m_scale_t scale;\n    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);\n    const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);\n\n    const int tmp = sc[iqs/2] >> (6*(iqs%2));\n    const int sc0 = 2*((tmp >> 0) & 0x07) + 1;\n    const int sc1 = 2*((tmp >> 3) & 0x07) + 1;\n    return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);\n}\n\n#define VDR_IQ4_NL_Q8_1_MMVQ 2\n#define VDR_IQ4_NL_Q8_1_MMQ  4\n\nstatic __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_iq4_nl * bq4 = (const block_iq4_nl *) vbq + kbx;\n\n    const int * q8 = (const int *) bq8_1->qs + iqs;\n\n    int sumi = 0;\n#pragma unroll\n    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {\n        const int aux_q4 = get_int_b2(bq4->qs, iqs + l);\n        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);\n\n        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);\n        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);\n    }\n\n    const float d = __half2float(bq4->d) * 
__low2float(bq8_1->ds);\n    return d * sumi;\n}\n\n#define VDR_IQ4_XS_Q8_1_MMVQ 4\n#define VDR_IQ4_XS_Q8_1_MMQ  4\n\nstatic __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(\n    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {\n\n    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;\n\n    int sumi = 0;\n#pragma unroll\n    for (int j = 0; j < 4; ++j) {\n        const int aux_q4 = get_int_b4(bq4->qs, iqs + j);\n        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);\n\n        const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);\n        const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);\n\n        sumi = ggml_cuda_dp4a(v.x, u0, sumi);\n        sumi = ggml_cuda_dp4a(v.y, u1, sumi);\n    }\n\n    const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4);\n    sumi *= ls - 32;\n\n    const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds);\n    return d * sumi;\n}\n"
  },
  {
    "path": "src/ggml-cuda/vendors/cuda.h",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n#include <cuda.h>\n#include <cublas_v2.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#if CUDART_VERSION >= 12050\n#include <cuda_fp8.h>\n#endif // CUDART_VERSION >= 12050\n\n#if CUDART_VERSION >= 12080\n#include <cuda_fp4.h>\n#endif // CUDART_VERSION >= 12080\n\n#if CUDART_VERSION < 11020\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#endif // CUDART_VERSION < 11020\n"
  },
  {
    "path": "src/ggml-cuda/vendors/hip.h",
    "content": "#pragma once\n\n#define HIP_DISABLE_WARP_SYNC_BUILTINS 1\n#include <hip/hip_runtime.h>\n#include <hipblas/hipblas.h>\n#include <hip/hip_fp16.h>\n#include <hip/hip_bf16.h>\n\n#if defined(GGML_HIP_ROCWMMA_FATTN)\n#include <rocwmma/rocwmma-version.hpp>\n#endif // defined(GGML_HIP_ROCWMMA_FATTN)\n\n#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N HIPBLAS_OP_N\n#define CUBLAS_OP_T HIPBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH 0\n#define CUDA_R_16F  HIPBLAS_R_16F\n#define CUDA_R_16BF HIPBLAS_R_16B\n#define CUDA_R_32F  HIPBLAS_R_32F\n#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT\n#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER\n#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended\n#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned\n#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite\n#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT(\"HipVMM Failure: %s\\n\", hipGetErrorString(err)); }}\n#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)\n#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)\n#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)\n#define __all_sync(mask, var) __all(var)\n#define __any_sync(mask, var) __any(var)\n#define cublasStrsmBatched hipblasStrsmBatched\n#define cublasCreate hipblasCreate\n#define cublasDestroy hipblasDestroy\n#define cublasGemmEx hipblasGemmEx\n#define cublasGemmBatchedEx hipblasGemmBatchedEx\n#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx\n#define 
cublasHandle_t hipblasHandle_t\n#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS\n#define cublasSetStream hipblasSetStream\n#define cublasSgemm hipblasSgemm\n#define cublasStatus_t hipblasStatus_t\n#define cublasOperation_t hipblasOperation_t\n#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch\n#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess\n#define cudaDeviceGetAttribute hipDeviceGetAttribute\n#define cudaDeviceProp hipDeviceProp_t\n#define cudaDeviceSynchronize hipDeviceSynchronize\n#define cudaError_t hipError_t\n#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags hipEventCreateWithFlags\n#define cudaEventDisableTiming hipEventDisableTiming\n#define cudaEventRecord hipEventRecord\n#define cudaEventSynchronize hipEventSynchronize\n#define cudaEvent_t hipEvent_t\n#define cudaEventDestroy hipEventDestroy\n#define cudaFree hipFree\n#define cudaFreeHost hipHostFree\n#define cudaGetDevice hipGetDevice\n#define cudaGetDeviceCount hipGetDeviceCount\n#define cudaGetDeviceProperties hipGetDeviceProperties\n#define cudaGetErrorString hipGetErrorString\n#define cudaGetLastError hipGetLastError\n#define cudaHostRegister hipHostRegister\n#define cudaHostRegisterPortable hipHostRegisterPortable\n#define cudaHostRegisterReadOnly hipHostRegisterReadOnly\n#define cudaHostUnregister hipHostUnregister\n#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel\n#define cudaLaunchHostFunc hipLaunchHostFunc\n#define cudaMalloc hipMalloc\n#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)\n#define cudaMallocManaged hipMallocManaged\n#define cudaMemAdvise hipMemAdvise\n#define cudaMemcpy hipMemcpy\n#define cudaMemcpyAsync hipMemcpyAsync\n#define 
cudaMemcpyPeerAsync hipMemcpyPeerAsync\n#define cudaMemcpy2DAsync hipMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice hipMemcpyHostToDevice\n#define cudaMemcpyKind hipMemcpyKind\n#define cudaMemset hipMemset\n#define cudaMemsetAsync hipMemsetAsync\n#define cudaMemGetInfo hipMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize\n#define cudaSetDevice hipSetDevice\n#define cuDeviceGet hipDeviceGet\n#define CUdevice hipDevice_t\n#define CUdeviceptr hipDeviceptr_t\n#define cuMemUnmap hipMemUnmap\n#define CUmemAccessDesc hipMemAccessDesc\n#define cuMemAddressFree hipMemAddressFree\n#define cuMemRelease hipMemRelease\n#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t\n#define cuMemCreate hipMemCreate\n#define cuMemAddressReserve hipMemAddressReserve\n#define cuMemMap hipMemMap\n#define cuMemSetAccess hipMemSetAccess\n#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity\n#define CUmemAllocationProp hipMemAllocationProp\n#define cuDeviceGetAttribute hipDeviceGetAttribute\n#define cudaStreamCreateWithFlags hipStreamCreateWithFlags\n#define cudaStreamDestroy hipStreamDestroy\n#define cudaStreamFireAndForget hipStreamFireAndForget\n#define cudaStreamNonBlocking hipStreamNonBlocking\n#define cudaStreamPerThread hipStreamPerThread\n#define cudaStreamSynchronize hipStreamSynchronize\n#define cudaStreamWaitEvent hipStreamWaitEvent\n#define cudaGraphExec_t hipGraphExec_t\n#define cudaGraphNode_t hipGraphNode_t\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaGraphExecDestroy hipGraphExecDestroy\n#define cudaGraphLaunch hipGraphLaunch\n#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure\n#define cudaGraphExecUpdateResult hipGraphExecUpdateResult\n#define cudaGraphNodeType hipGraphNodeType\n#define cudaGraphNodeTypeKernel 
hipGraphNodeTypeKernel\n#define cudaGraphInstantiate hipGraphInstantiate\n#define cudaStreamEndCapture hipStreamEndCapture\n#define cudaGraphDestroy hipGraphDestroy\n#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams\n#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction\n#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams\n#define cudaGraphNodeGetType hipGraphNodeGetType\n#define cudaGraphGetNodes hipGraphGetNodes\n#define cudaGraphExecUpdate hipGraphExecUpdate\n#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed\n#define cudaStreamBeginCapture hipStreamBeginCapture\n#define cudaGraph_t hipGraph_t\n#define cudaStream_t hipStream_t\n#define cudaSuccess hipSuccess\n#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor\n#define cudaFuncSetAttribute hipFuncSetAttribute\n#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize\n#define __trap() do { abort(); __builtin_unreachable(); } while(0)\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED\n#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED\n#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE\n#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH\n#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR\n#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED\n#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR\n#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED\n\n#if HIP_VERSION >= 60500000\n#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F\n#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F\n#define cublasComputeType_t hipblasComputeType_t\n#define cudaDataType_t hipDataType\n#else\n#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F\n#define CUBLAS_COMPUTE_32F 
HIPBLAS_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F\n#define cublasComputeType_t hipblasDatatype_t\n#define cudaDataType_t hipblasDatatype_t\n#endif // HIP_VERSION >= 6050000\n\n#if !defined(__HIP_PLATFORM_AMD__)\n#error \"The HIP backend supports only AMD targets\"\n#endif // !defined(__HIP_PLATFORM_AMD__)\n\n#define __CUDA_ARCH__ 1300\n\n#if defined(__gfx900__) || defined(__gfx906__)\n#define GCN5\n#endif // defined(__gfx900__) || defined(__gfx906__)\n\n#if defined(__gfx803__)\n#define GCN4\n#endif // defined(__gfx803__)\n\n#if defined(GCN5) || defined(GCN4)\n#define GCN\n#endif // defined(GCN5) || defined(GCN4)\n\n#if defined(__gfx942__)\n#define CDNA3\n#endif // defined(__gfx942__)\n\n#if defined(__gfx90a__)\n#define CDNA2\n#endif // defined(__gfx90a__)\n\n#if defined(__gfx908__)\n#define CDNA1\n#endif // defined(__gfx908__)\n\n#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)\n#define CDNA // For the entire family\n#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)\n\n#if defined(__GFX12__)\n#define RDNA4\n#endif // defined(__GFX12__)\n\n#if defined(__GFX11__)\n#define RDNA3\n#endif // defined(__GFX11__)\n\n#if defined(__gfx1150__) || defined(__gfx1151__)\n#define RDNA3_5\n#endif // defined(__gfx1150__) || defined(__gfx1151__)\n\n#if defined(RDNA3) && !defined(RDNA3_5)\n#define RDNA3_0\n#endif // defined(RDNA3) && !defined(RDNA3_5)\n\n#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \\\n    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)\n#define RDNA2\n#endif\n\n#if defined(__gfx1010__) || defined(__gfx1012__)\n#define RDNA1\n#endif // defined(__gfx1010__) || defined(__gfx1012__)\n\n#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)\n#define RDNA // For the entire family\n#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)\n\n#ifndef __has_builtin\n    #define __has_builtin(x) 
0\n#endif\n\ntypedef __hip_bfloat16 nv_bfloat16;\ntypedef __hip_bfloat162 nv_bfloat162;\n\ntypedef int8_t int8x4_t __attribute__((ext_vector_type(4)));\ntypedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));\nstatic __device__ __forceinline__ int __vsubss4(const int a, const int b) {\n    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);\n    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);\n#if __has_builtin(__builtin_elementwise_sub_sat)\n    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);\n    return reinterpret_cast<const int &>(c);\n#else\n    int8x4_t c;\n    int16_t tmp;\n#pragma unroll\n    for (int i = 0; i < 4; i++) {\n        tmp = va[i] - vb[i];\n        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();\n        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();\n        c[i] = tmp;\n    }\n    return reinterpret_cast<int &>(c);\n#endif // __has_builtin(__builtin_elementwise_sub_sat)\n}\n\nstatic __device__ __forceinline__ int __vsub4(const int a, const int b) {\n    return __vsubss4(a, b);\n}\n\nstatic __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {\n    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);\n    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);\n    unsigned int c;\n    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);\n#pragma unroll\n    for (int i = 0; i < 4; ++i) {\n        vc[i] = va[i] == vb[i] ? 0xff : 0x00;\n    }\n    return c;\n}\n\nstatic __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {\n    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);\n    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);\n    unsigned int c;\n    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);\n#pragma unroll\n    for (int i = 0; i < 4; ++i) {\n        vc[i] = va[i] == vb[i] ? 0x00 : 0xff;\n    }\n    return c;\n}\n"
  },
  {
    "path": "src/ggml-cuda/vendors/musa.h",
    "content": "#pragma once\n\n#include <musa_runtime.h>\n#include <musa.h>\n#include <mublas.h>\n#include <musa_bf16.h>\n#include <musa_fp16.h>\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F\n#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N MUBLAS_OP_N\n#define CUBLAS_OP_T MUBLAS_OP_T\n#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH\n#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT\n#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER\n#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT\n#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH\n#define CUDA_R_16F  MUSA_R_16F\n#define CUDA_R_16BF MUSA_R_16BF\n#define CUDA_R_32F  MUSA_R_32F\n#define cublasStrsmBatched mublasStrsmBatched\n#define cublasComputeType_t cudaDataType_t\n#define cublasCreate mublasCreate\n#define cublasDestroy mublasDestroy\n#define cublasGemmEx mublasGemmEx\n#define cublasGemmBatchedEx mublasGemmBatchedEx\n#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx\n#define cublasHandle_t mublasHandle_t\n#define cublasSetMathMode mublasSetMathMode\n#define cublasSetStream mublasSetStream\n#define cublasSgemm mublasSgemm\n#define cublasStatus_t mublasStatus_t\n#define cublasOperation_t mublasOperation_t\n#define cublasGetStatusString mublasGetStatusString\n#define cudaDataType_t musaDataType_t\n#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess\n#define cudaDeviceProp musaDeviceProp\n#define cudaDeviceSynchronize musaDeviceSynchronize\n#define cudaError_t musaError_t\n#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled\n#define 
cudaEventCreateWithFlags musaEventCreateWithFlags\n#define cudaEventDisableTiming musaEventDisableTiming\n#define cudaEventRecord musaEventRecord\n#define cudaEventSynchronize musaEventSynchronize\n#define cudaEvent_t musaEvent_t\n#define cudaEventDestroy musaEventDestroy\n#define cudaFree musaFree\n#define cudaFreeHost musaFreeHost\n#define cudaGetDevice musaGetDevice\n#define cudaGetDeviceCount musaGetDeviceCount\n#define cudaGetDeviceProperties musaGetDeviceProperties\n#define cudaGetErrorString musaGetErrorString\n#define cudaGetLastError musaGetLastError\n#define cudaHostRegister musaHostRegister\n#define cudaHostRegisterPortable musaHostRegisterPortable\n#define cudaHostRegisterReadOnly musaHostRegisterReadOnly\n#define cudaHostUnregister musaHostUnregister\n#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel\n#define cudaLaunchHostFunc musaLaunchHostFunc\n#define cudaMalloc musaMalloc\n#define cudaMallocHost musaMallocHost\n#define cudaMallocManaged musaMallocManaged\n#define cudaMemcpy musaMemcpy\n#define cudaMemcpyAsync musaMemcpyAsync\n#define cudaMemcpyPeerAsync musaMemcpyPeerAsync\n#define cudaMemcpy2DAsync musaMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice musaMemcpyHostToDevice\n#define cudaMemcpyKind musaMemcpyKind\n#define cudaMemset musaMemset\n#define cudaMemsetAsync musaMemsetAsync\n#define cudaMemGetInfo musaMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize\n#define cudaSetDevice musaSetDevice\n#define cudaStreamCreateWithFlags musaStreamCreateWithFlags\n#define cudaStreamDestroy musaStreamDestroy\n#define cudaStreamFireAndForget musaStreamFireAndForget\n#define cudaStreamNonBlocking musaStreamNonBlocking\n#define cudaStreamPerThread musaStreamPerThread\n#define cudaStreamSynchronize musaStreamSynchronize\n#define cudaStreamWaitEvent musaStreamWaitEvent\n#define cudaStream_t 
musaStream_t\n#define cudaSuccess musaSuccess\n\n// Additional mappings for MUSA virtual memory pool\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED\n#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED\n#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE\n#define CUdevice MUdevice\n#define CUdeviceptr MUdeviceptr\n#define CUmemAccessDesc MUmemAccessDesc\n#define CUmemAllocationProp MUmemAllocationProp\n#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle\n#define cuDeviceGet muDeviceGet\n#define cuDeviceGetAttribute muDeviceGetAttribute\n#define cuMemAddressFree muMemAddressFree\n#define cuMemAddressReserve muMemAddressReserve\n#define cuMemCreate muMemCreate\n#define cuMemGetAllocationGranularity muMemGetAllocationGranularity\n#define cuMemMap muMemMap\n#define cuMemRelease muMemRelease\n#define cuMemSetAccess muMemSetAccess\n#define cuMemUnmap muMemUnmap\n#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize\n#define cudaFuncSetAttribute musaFuncSetAttribute\n#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms\n#define make_cudaExtent make_musaExtent\n#define make_cudaPitchedPtr make_musaPitchedPtr\n\n// Additional mappings for MUSA graphs\n#define CUDA_SUCCESS MUSA_SUCCESS\n#define CUresult MUresult\n#define cuGetErrorString muGetErrorString\n#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure\n#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction\n#define cudaGraphDestroy musaGraphDestroy\n#define cudaGraphExecDestroy musaGraphExecDestroy\n#define cudaGraphExec_t musaGraphExec_t\n#define cudaGraphExecUpdate musaGraphExecUpdate\n#define cudaGraphExecUpdateResult musaGraphExecUpdateResult\n#define cudaGraphGetNodes 
musaGraphGetNodes\n#define cudaGraphInstantiate musaGraphInstantiate\n#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams\n#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams\n#define cudaGraphLaunch musaGraphLaunch\n#define cudaGraphNodeGetType musaGraphNodeGetType\n#define cudaGraphNode_t musaGraphNode_t\n#define cudaGraphNodeType musaGraphNodeType\n#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel\n#define cudaGraph_t musaGraph_t\n#define cudaKernelNodeParams musaKernelNodeParams\n#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed\n#define cudaStreamBeginCapture musaStreamBeginCapture\n#define cudaStreamEndCapture musaStreamEndCapture\n#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor\n\ntypedef __mt_bfloat16 nv_bfloat16;\ntypedef __mt_bfloat162 nv_bfloat162;\n"
  },
  {
    "path": "src/ggml-cuda/wkv.cu",
    "content": "#include \"common.cuh\"\n#include \"wkv.cuh\"\n\ntemplate <int block_size>\nstatic __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {\n    const int tid = threadIdx.x;\n    const int bid = blockIdx.x;\n\n    const int head_size = block_size;\n    const int batch_i = bid / H;\n    const int head_i = bid % H;\n    const int state_size = C * head_size;\n    const int n_seq_tokens = T / B;\n\n    float state[head_size];\n    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];\n\n    #pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];\n    }\n\n    __syncthreads();\n    _tf[tid] = tf[head_i * head_size + tid];\n    __syncthreads();\n\n    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {\n        __syncthreads();\n        _k[tid] = k[t];\n        _r[tid] = r[t];\n        _td[tid] = td[t];\n        __syncthreads();\n\n        const float _v = v[t];\n        float y = 0;\n        for (int j = 0; j < head_size; j += 4) {\n            const float4& k = (float4&)(_k[j]);\n            const float4& r = (float4&)(_r[j]);\n            const float4& tf = (float4&)(_tf[j]);\n            const float4& td = (float4&)(_td[j]);\n            float4& s = (float4&)(state[j]);\n            float4 kv;\n\n            kv.x = k.x * _v;\n            kv.y = k.y * _v;\n            kv.z = k.z * _v;\n            kv.w = k.w * _v;\n\n            y += r.x * (tf.x * kv.x + s.x);\n            y += r.y * (tf.y * kv.y + s.y);\n            y += r.z * (tf.z * kv.z + s.z);\n            y += r.w * (tf.w * kv.w + s.w);\n\n            s.x = s.x * td.x + kv.x;\n            s.y = s.y * td.y + kv.y;\n            s.z = s.z * td.z 
+ kv.z;\n            s.w = s.w * td.w + kv.w;\n        }\n        dst[t] = y;\n    }\n\n    #pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];\n    }\n}\n\ntemplate <int block_size>\nstatic __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {\n    const int tid = threadIdx.x;\n    const int bid = blockIdx.x;\n\n    const int head_size = block_size;\n    const int batch_i = bid / H;\n    const int head_i = bid % H;\n    const int state_size = C * head_size;\n    const int n_seq_tokens = T / B;\n\n    float state[head_size];\n    __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];\n\n#ifndef GGML_USE_MUSA\n    #pragma unroll\n#endif\n    for (int i = 0; i < head_size; i++) {\n        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];\n    }\n\n    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {\n        __syncthreads();\n        _r[tid] = r[t];\n        _w[tid] = w[t];\n        _k[tid] = k[t];\n        _a[tid] = a[t];\n        _b[tid] = b[t];\n        __syncthreads();\n\n        float sa = 0;\n        #pragma unroll\n        for (int j = 0; j < head_size; j += 4)\n        {\n            const float4& a = (float4&)(_a[j]);\n            const float4& s = (float4&)(state[j]);\n            sa += a.x * s.x;\n            sa += a.y * s.y;\n            sa += a.z * s.z;\n            sa += a.w * s.w;\n        }\n\n        const float _v = v[t];\n        float y = 0;\n        for (int j = 0; j < head_size; j += 4) {\n            const float4& r = (float4&)(_r[j]);\n            const float4& w = (float4&)(_w[j]);\n            const 
float4& k = (float4&)(_k[j]);\n            const float4& b = (float4&)(_b[j]);\n            float4& s = (float4&)(state[j]);\n            float4 kv;\n\n            kv.x = k.x * _v;\n            kv.y = k.y * _v;\n            kv.z = k.z * _v;\n            kv.w = k.w * _v;\n\n            s.x = s.x * w.x + kv.x + sa * b.x;\n            s.y = s.y * w.y + kv.y + sa * b.y;\n            s.z = s.z * w.z + kv.z + sa * b.z;\n            s.w = s.w * w.w + kv.w + sa * b.w;\n\n            y += s.x * r.x;\n            y += s.y * r.y;\n            y += s.z * r.z;\n            y += s.w * r.w;\n        }\n        dst[t] = y;\n    }\n\n    #pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];\n    }\n}\n\nvoid ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const float * k_d  = (const float *)dst->src[0]->data;\n    const float * v_d  = (const float *)dst->src[1]->data;\n    const float * r_d  = (const float *)dst->src[2]->data;\n    const float * tf_d = (const float *)dst->src[3]->data;\n    const float * td_d = (const float *)dst->src[4]->data;\n    const float * s_d  = (const float *)dst->src[5]->data;\n\n    const int64_t B = dst->src[5]->ne[1];\n    const int64_t T = dst->src[0]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t H = dst->src[0]->ne[1];\n\n    float * dst_d = (float *)dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);\n    GGML_ASSERT(C % H == 0);\n    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);\n\n    if (C / H == CUDA_WKV_BLOCK_SIZE) {\n        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);\n    } else {\n        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);\n    
}\n}\n\nvoid ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {\n    const float * r_d = (const float *)dst->src[0]->data;\n    const float * w_d = (const float *)dst->src[1]->data;\n    const float * k_d = (const float *)dst->src[2]->data;\n    const float * v_d = (const float *)dst->src[3]->data;\n    const float * a_d = (const float *)dst->src[4]->data;\n    const float * b_d = (const float *)dst->src[5]->data;\n    const float * s_d = (const float *)dst->src[6]->data;\n\n    const int64_t B = dst->src[6]->ne[1];\n    const int64_t T = dst->src[0]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t H = dst->src[0]->ne[1];\n\n    float * dst_d = (float *)dst->data;\n\n    cudaStream_t stream = ctx.stream();\n\n    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);\n    GGML_ASSERT(C % H == 0);\n    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);\n\n    if (C / H == CUDA_WKV_BLOCK_SIZE) {\n        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);\n    } else {\n        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);\n    }\n}\n"
  },
  {
    "path": "src/ggml-cuda/wkv.cuh",
    "content": "#include \"common.cuh\"\n\n#define CUDA_WKV_BLOCK_SIZE 64\n\nvoid ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n\nvoid ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-hexagon/CMakeLists.txt",
    "content": "file(TO_CMAKE_PATH \"${HEXAGON_SDK_ROOT}\"   HEXAGON_SDK_ROOT)\nfile(TO_CMAKE_PATH \"${HEXAGON_TOOLS_ROOT}\" HEXAGON_TOOLS_ROOT)\n\nif (NOT IS_DIRECTORY \"${HEXAGON_SDK_ROOT}\")\n    message(FATAL_ERROR \"Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.\")\nendif()\n\nif (NOT IS_DIRECTORY \"${HEXAGON_TOOLS_ROOT}\")\n    message(\"Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json\")\n    file(READ \"${HEXAGON_SDK_ROOT}/hexagon_sdk.json\" HEXAGON_SDK_CONFIG_PATH)\n    string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} \"root\" \"tools\" \"info\" 0 \"path\")\n    message(\"Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}\")\n    set(HEXAGON_TOOLS_ROOT \"${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}\")\n    file(TO_CMAKE_PATH \"${HEXAGON_TOOLS_ROOT}\" HEXAGON_TOOLS_ROOT)\n    if (NOT IS_DIRECTORY \"${HEXAGON_TOOLS_ROOT}\")\n        message(FATAL_ERROR \"Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.\")\n    endif()\nendif()\n\nmessage(STATUS \"hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels\")\n\ninclude(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)\ninclude(ExternalProject)\n\noption(GGML_HEXAGON_HTP_DEBUG \"ggml-hexagon: enable HTP debug output\" OFF)\nset(GGML_HEXAGON_HTP_CERT  \"$ENV{HEXAGON_HTP_CERT}\" CACHE PATH \"ggml-hexagon: enable HTP library signing using certificate\")\nset(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING \"ggml-hexagon: quantize group size (32, 64, or 128)\")\n\nadd_library(htp_iface OBJECT\n    ${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)\n\nset_target_properties(htp_iface PROPERTIES POSITION_INDEPENDENT_CODE ON)\ntarget_include_directories(htp_iface PUBLIC\n    ${HEXAGON_SDK_ROOT}/incs\n    ${HEXAGON_SDK_ROOT}/incs/stddef\n    ${HEXAGON_SDK_ROOT}/utils/examples\n    ${CMAKE_CURRENT_SOURCE_DIR}/htp\n    ${CMAKE_CURRENT_BINARY_DIR})\n\nbuild_idl(htp/htp_iface.idl htp_iface)\n\nif (CMAKE_SYSTEM_NAME MATCHES 
Android)\n    target_link_options(htp_iface PUBLIC -llog -ldl)\nelseif (CMAKE_SYSTEM_NAME MATCHES Windows)\n    target_precompile_headers(htp_iface PUBLIC <sal.h>)\nelse()\n    target_link_options(htp_iface PUBLIC -ldl)\nendif()\n\nset(TARGET_NAME ggml-hexagon)\nggml_add_backend_library(${TARGET_NAME}\n    ggml-hexagon.cpp\n    htp-drv.cpp\n    htp-drv.h\n    libdl.h\n    ../../include/ggml-hexagon.h)\n\ntarget_link_libraries(${TARGET_NAME} PRIVATE htp_iface)\ntarget_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})\n\n# Build HTP skels\nset(HTP_SKELS)\nfunction(build_htp_skel V)\n    ExternalProject_Add(htp-${V}\n        SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON\n        BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so\n        CMAKE_ARGS\n            -DCMAKE_BUILD_TYPE=Release\n            -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake\n            -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}\n            -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}\n            -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}\n            -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}\n            -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}\n            -DDSP_VERSION=${V}\n            -DPREBUILT_LIB_DIR=\"toolv19_${V}\")\n    list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)\n    set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)\nendfunction()\n\nbuild_htp_skel(v68)\nbuild_htp_skel(v69)\nbuild_htp_skel(v73)\nbuild_htp_skel(v75)\nbuild_htp_skel(v79)\nbuild_htp_skel(v81)\n\n# Install Hexagon skels required at runtime\ninstall(FILES ${HTP_SKELS} TYPE LIB)\n\nif (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT)\n    file(TO_CMAKE_PATH \"$ENV{WINDOWS_SDK_BIN}/arm64\"      WINSDK_BIN0_ARM64)\n    file(TO_CMAKE_PATH \"$ENV{WINDOWS_SDK_BIN}/x86\"        WINSDK_BIN0_X86)\n    file(TO_CMAKE_PATH 
\"$ENV{WindowsSdkVerBinPath}/arm64\" WINSDK_BIN1_ARM64)\n    file(TO_CMAKE_PATH \"$ENV{WindowsSdkVerBinPath}/x86\"   WINSDK_BIN1_X86)\n\n    set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86})\n\n    find_program(INF2CAT  NAMES inf2cat.exe  PATHS ${WINSDK_PATHS} REQUIRED)\n    find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED)\n\n    message(STATUS \"hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels\")\n\n    set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat)\n    add_custom_target(libggml-htp-cat\n        BYPRODUCTS ${LIBGGML_HTP_CAT}\n        DEPENDS libggml-htp.inf ${HTP_SKELS}\n        COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR}\n        COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64\n        COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT}\n        COMMENT \"generating and signing libggml-htp.cat file\"\n        VERBATIM\n    )\n\n    add_dependencies(${TARGET_NAME} libggml-htp-cat)\n    install(FILES ${LIBGGML_HTP_CAT} TYPE LIB)\nendif()\n"
  },
  {
    "path": "src/ggml-hexagon/ggml-hexagon.cpp",
    "content": "#include <assert.h>\n#include <inttypes.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n#include <time.h>\n\n#include <atomic>\n#include <chrono>\n#include <cstddef>\n#include <mutex>\n#include <stdexcept>\n#include <string>\n\n#ifdef _WIN32\n#    include <sal.h>\n#else\n#    include <semaphore.h>\n#    include <unistd.h>\n#endif\n\n#pragma clang diagnostic ignored \"-Wnested-anon-types\"\n#pragma clang diagnostic ignored \"-Wgnu-anonymous-struct\"\n\n#include <AEEStdErr.h>\n#include <dspqueue.h>\n#include <rpcmem.h>\n\n#define GGML_COMMON_IMPL_CPP\n#include \"ggml-backend-impl.h\"\n#include \"ggml-common.h\"\n#include \"ggml-hexagon.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-quants.h\"\n#include \"op-desc.h\"\n#include \"htp-msg.h\"\n#include \"htp_iface.h\"\n#include \"htp-drv.h\"\n\nstatic size_t opt_ndev         = 1;\nstatic size_t opt_nhvx         = 0; // use all\nstatic int    opt_arch         = 0; // autodetect\nstatic int    opt_etm          = 0;\nstatic int    opt_verbose      = 0;\nstatic int    opt_profile      = 0;\nstatic int    opt_hostbuf      = 1; // hostbuf ON by default\nstatic int    opt_experimental = 0;\n\n// Enable all stages by default\nstatic int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;\nstatic int opt_opsync = 0;  // synchronous ops\n\n#define HEX_VERBOSE(...) 
\\\n    if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__)\n\nstatic inline uint64_t hex_is_aligned(void * addr, uint32_t align) {\n    return ((size_t) addr & (align - 1)) == 0;\n}\n\nstatic inline size_t hex_round_up(size_t n, size_t m) {\n    return m * ((n + m - 1) / m);\n}\n\nstatic const char * status_to_str(uint32_t status) {\n    switch (status) {\n        case HTP_STATUS_OK:\n            return \"OK\";\n        case HTP_STATUS_NO_SUPPORT:\n            return \"NO-SUPPORT\";\n        case HTP_STATUS_INVAL_PARAMS:\n            return \"INVAL-PARAMS\";\n        case HTP_STATUS_VTCM_TOO_SMALL:\n            return \"VTCM-TOO-SMALL\";\n        case HTP_STATUS_INTERNAL_ERR:\n            return \"INTERNAL-ERROR\";\n        default:\n            return \"UNKNOWN\";\n    }\n}\n\n// ** debug helpers\n\nstatic void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) {\n    if (!opt_verbose) return;\n\n    op_desc desc(op);\n    GGML_LOG_DEBUG(\"ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\\n\", sess_name.c_str(),\n                ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags);\n}\n\nstatic void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) {\n    if (!opt_verbose) return;\n\n    op_desc desc(op);\n    GGML_LOG_DEBUG(\"ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\\n\", sess_name.c_str(),\n                ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? 
\"yes\" : \"no\");\n}\n\nstatic void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op,\n                                      uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) {\n    if (!opt_profile) return;\n\n    op_desc desc(op);\n    GGML_LOG_DEBUG(\"ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\\n\", sess_name.c_str(),\n                ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs,\n                op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec);\n}\n\n// ** backend sessions\n\nstruct ggml_hexagon_session {\n    ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);\n    ~ggml_hexagon_session() noexcept(true);\n\n    void allocate(int dev_id) noexcept(false);\n    void release() noexcept(true);\n\n    void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false);\n    void flush();\n\n    ggml_backend_buffer_type buffer_type        = {};\n    ggml_backend_buffer_type repack_buffer_type = {};\n\n    std::string      name;\n    remote_handle64  handle;\n    dspqueue_t       queue;\n    uint32_t         session_id;\n    uint32_t         domain_id;\n    uint64_t         queue_id;\n    int              dev_id;\n    bool             valid_session;\n    bool             valid_handle;\n    bool             valid_queue;\n    bool             valid_iface;\n    std::atomic<int> op_pending;\n    uint32_t         prof_usecs;\n    uint32_t         prof_cycles;\n    uint32_t         prof_pkts;\n};\n\nvoid ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {\n    // Bump pending flag (cleared in the session::flush once we get the response)\n    this->op_pending++;  // atomic inc\n\n    int err = dspqueue_write(this->queue,\n                             
0,                       // flags - the framework will autoset this\n                             n_bufs,                  // number of buffers\n                             bufs,                    // buffer references\n                             sizeof(req),             // Message length\n                             (const uint8_t *) &req,  // Message\n                             DSPQUEUE_TIMEOUT         // Timeout\n    );\n\n    if (err != 0) {\n        GGML_ABORT(\"ggml-hex: %s dspqueue_write failed: 0x%08x\\n\", this->name.c_str(), (unsigned) err);\n    }\n\n    if (sync) {\n        flush();\n    }\n}\n\n// Flush HTP response queue i.e wait for all outstanding requests to complete\nvoid ggml_hexagon_session::flush() {\n    dspqueue_t q = this->queue;\n\n    // Repeatedly read packets from the queue until it's empty. We don't\n    // necessarily get a separate callback for each packet, and new packets\n    // may arrive while we're processing the previous one.\n\n    while (this->op_pending) {\n        struct htp_general_rsp rsp;\n        uint32_t               rsp_size;\n        uint32_t               flags;\n\n        struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];\n        uint32_t               n_bufs;\n\n        // Read response packet from queue\n        int err = dspqueue_read(q, &flags,\n                                HTP_MAX_PACKET_BUFFERS,  // Maximum number of buffer references\n                                &n_bufs,                 // Number of buffer references\n                                bufs,                    // Buffer references\n                                sizeof(rsp),             // Max message length\n                                &rsp_size,               // Message length\n                                (uint8_t *) &rsp,        // Message\n                                DSPQUEUE_TIMEOUT);       // Timeout\n\n        if (err == AEE_EEXPIRED) {\n            // TODO: might need to bail out if the HTP is stuck on 
something\n            continue;\n        }\n\n        if (err != 0) {\n            GGML_ABORT(\"ggml-hex: dspqueue_read failed: 0x%08x\\n\", (unsigned) err);\n        }\n\n        // Basic sanity checks\n        if (rsp_size != sizeof(rsp)) {\n            GGML_ABORT(\"ggml-hex: dspcall : bad response (size)\\n\");\n        }\n\n        if (rsp.status != HTP_STATUS_OK) {\n            GGML_LOG_ERROR(\"ggml-hex: dspcall : dsp-rsp: %s\\n\", status_to_str(rsp.status));\n            // TODO: handle errors\n        }\n\n        // TODO: update profiling implementation, currently only works for opt_opsync mode\n        this->prof_usecs  = rsp.prof_usecs;\n        this->prof_cycles = rsp.prof_cycles;\n        this->prof_pkts   = rsp.prof_pkts;\n\n        this->op_pending--;  // atomic dec\n    }\n}\n\n// ** backend buffers\n\nstruct ggml_backend_hexagon_buffer_type_context {\n    ggml_backend_hexagon_buffer_type_context(const std::string & name, ggml_hexagon_session * sess) {\n        this->sess = sess;\n        this->name = name;\n    }\n\n    ggml_hexagon_session * sess;\n    std::string            name;\n};\n\nstruct ggml_backend_hexagon_buffer_context {\n    bool mmap_to(ggml_hexagon_session * s) {\n        HEX_VERBOSE(\"ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\\n\",\n                    s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd,\n                    (int) this->repack);\n\n        int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD);\n        if (err != 0) {\n            GGML_LOG_ERROR(\"ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\\n\",\n                    s->domain_id, this->size, this->fd, (unsigned) err);\n            return false;\n        }\n\n        return true;\n    }\n\n    bool mmap() {\n        if (this->mapped) {\n            return true;\n        }\n        if (!mmap_to(this->sess)) {\n 
           return false;\n        }\n        this->mapped = true;\n        return true;\n    }\n\n    void munmap() {\n        if (!this->mapped) {\n            return;\n        }\n\n        fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size);\n        this->mapped = false;\n    }\n\n    ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {\n        size += 4 * 1024;  // extra page for padding\n\n        this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);\n        if (!this->base) {\n            GGML_LOG_ERROR(\"ggml-hex: %s failed to allocate buffer : size %zu\\n\", sess->name.c_str(), size);\n            throw std::runtime_error(\"ggml-hex: rpcmem_alloc failed (see log for details)\");\n        }\n\n        this->fd = rpcmem_to_fd(this->base);\n        if (this->fd < 0) {\n            GGML_LOG_ERROR(\"ggml-hex: %s failed to get FD for buffer %p\\n\", sess->name.c_str(), (void *) this->base);\n            rpcmem_free(this->base);\n            this->base = NULL;\n            throw std::runtime_error(\"ggml-hex: rpcmem_to_fd failed (see log for details)\");\n        }\n\n        HEX_VERBOSE(\"ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\\n\", sess->name.c_str(),\n                    (void *) this->base, size, this->fd, (int) repack);\n\n        this->sess   = sess;\n        this->size   = size;\n        this->mapped = false;\n        this->repack = repack;\n    }\n\n    ~ggml_backend_hexagon_buffer_context() {\n        munmap();\n        if (this->base) {\n            rpcmem_free(this->base);\n            this->base = NULL;\n        }\n    }\n\n    ggml_hexagon_session * sess;  // primary session\n    uint8_t *              base;\n    size_t                 size;\n    int                    fd;\n    bool                   mapped;  // mmap is done\n    bool                   repack;  // repacked buffer\n};\n\nstatic 
ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) {\n    return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer->buft->context)->sess;\n}\n\nstatic void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);\n    delete ctx;\n}\n\nstatic void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) {\n    auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);\n    return ctx->base;\n}\n\nstatic enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    auto ctx  = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);\n    auto sess = ctx->sess;\n\n    HEX_VERBOSE(\"ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\\n\", sess->name.c_str(),\n                tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage,\n                (int) ctx->repack);\n\n    if (tensor->view_src != NULL && tensor->view_offs == 0) {\n        ; // nothing to do for the view\n    } else {\n        if (!ctx->mapped) {\n            ctx->mmap();\n        }\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\n// ======== Q4x4x2 ====================\nstruct x2_q4 {\n    int v[2];\n};\n\nstatic x2_q4 unpack_q4(uint8_t v) {\n    x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 };\n    return x;\n}\n\nstatic void dump_block_q4_0(const block_q4_0 * b, int i) {\n    HEX_VERBOSE(\"ggml-hex: repack q4_0 %d: %d %d %d %d ... 
%d %d %d %d : %.6f\\n\", i, unpack_q4(b->qs[0]).v[0],\n                unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1],\n                unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1],\n                GGML_FP16_TO_FP32(b->d));\n}\n\nstatic void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) {\n    static const int qk        = QK_Q4_0x4x2;\n    const int        dblk_size = 8 * 2;   // 8x __fp16\n    const int        qblk_size = qk / 2;  // int4\n    const int        qrow_size = k / 2;   // int4 (not padded)\n\n    const uint8_t * v_q = v + 0;          // quants first\n    const uint8_t * v_d = v + qrow_size;  // then scales\n\n    const uint8_t *   q = v_q + i * qblk_size;\n    const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);\n\n    HEX_VERBOSE(\"ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\\n\", i,\n                unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0],\n                unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0],\n                unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0],\n                GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));\n\n    HEX_VERBOSE(\"ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... 
%d %d %d %d : %.6f %.6f %.6f %.6f\\n\",\n                i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1],\n                unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1],\n                unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1],\n                GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));\n}\n\nstatic void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) {\n    static const int qk = QK4_0;\n\n    for (unsigned int i = 0; i < qk / 2; ++i) {\n        const int x0             = (x->qs[i] & 0x0F);\n        const int x1             = (x->qs[i] >> 4);\n        qs[bi * qk + i + 0]      = x0;\n        qs[bi * qk + i + qk / 2] = x1;\n    }\n}\n\nstatic void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi) {\n    static const int qk = QK4_0;\n\n    for (unsigned int i = 0; i < qk / 2; ++i) {\n        const uint8_t x0 = qs[bi * qk + i + 0];\n        const uint8_t x1 = qs[bi * qk + i + qk / 2];\n        x->qs[i]         = x0 | (x1 << 4);\n    }\n}\n\nstatic void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {\n    static const int qk = QK_Q4_0x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n    const int        nloe = k % qk;           // leftovers\n\n    const int dblk_size = 8 * 2;              // 8x __fp16\n    const int qblk_size = qk / 2;             // int4\n    const int qrow_size = k / 2;              // int4 (not padded to blocks)\n\n    uint8_t * y_q = y + 0;                    // quants first\n    uint8_t * y_d = y + qrow_size;            // then scales\n\n    if (opt_verbose > 2) {\n        for (int i = 0; i < nb; i++) {\n            dump_block_q4_0(&x[i * 8 + 0], 0);\n            dump_block_q4_0(&x[i * 8 + 1], 1);\n            dump_block_q4_0(&x[i * 8 + 2], 2);\n          
  dump_block_q4_0(&x[i * 8 + 3], 3);\n            dump_block_q4_0(&x[i * 8 + 4], 4);\n            dump_block_q4_0(&x[i * 8 + 5], 5);\n            dump_block_q4_0(&x[i * 8 + 6], 6);\n            dump_block_q4_0(&x[i * 8 + 7], 7);\n        }\n    }\n\n    // Repack the quants\n    for (int i = 0; i < nb; i++) {\n        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants\n        unpack_q4_0_quants(qs, &x[i * 8 + 0], 0);\n        unpack_q4_0_quants(qs, &x[i * 8 + 1], 1);\n        unpack_q4_0_quants(qs, &x[i * 8 + 2], 2);\n        unpack_q4_0_quants(qs, &x[i * 8 + 3], 3);\n        unpack_q4_0_quants(qs, &x[i * 8 + 4], 4);\n        unpack_q4_0_quants(qs, &x[i * 8 + 5], 5);\n        unpack_q4_0_quants(qs, &x[i * 8 + 6], 6);\n        unpack_q4_0_quants(qs, &x[i * 8 + 7], 7);\n\n        bool partial = (nloe && i == nb-1);\n\n        uint8_t * q = y_q + (i * qblk_size);\n        for (int j = 0; j < qk / 2; j++) {\n            q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000];\n        }\n    }\n\n    // Repack the scales\n    // Note: Do not combine with the loop above. 
For tensor sizes not multiple of 256 (QK_Q4_0x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Repack the scales\n        ggml_half * d = (ggml_half *) (y_d + i * dblk_size);\n        d[0]          = x[i * 8 + 0].d;\n        d[1]          = x[i * 8 + 1].d;\n        d[2]          = x[i * 8 + 2].d;\n        d[3]          = x[i * 8 + 3].d;\n        d[4]          = x[i * 8 + 4].d;\n        d[5]          = x[i * 8 + 5].d;\n        d[6]          = x[i * 8 + 6].d;\n        d[7]          = x[i * 8 + 7].d;\n    }\n\n    if (opt_verbose > 1) {\n        for (int i = 0; i < nb; i++) {\n            dump_packed_block_q4x4x2(y, i, k);\n        }\n    }\n}\n\nstatic void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {\n    static const int qk = QK_Q4_0x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n    const int        nloe = k % qk;           // leftovers\n\n    const int dblk_size = 8 * 2;              // 8x __fp16\n    const int qblk_size = qk / 2;             // int4\n    const int qrow_size = k / 2;              // int4 (not padded to blocks)\n\n    const uint8_t * y_q = y + 0;              // quants first\n    const uint8_t * y_d = y + qrow_size;      // then scales\n\n    if (opt_verbose > 1) {\n        for (int i = 0; i < nb; i++) {\n            dump_packed_block_q4x4x2(y, i, k);\n        }\n    }\n\n    // Unpack the quants\n    for (int i = 0; i < nb; i++) {\n        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants\n\n        bool partial = (nloe && i == nb-1);\n\n        const uint8_t * q = y_q + (i * qblk_size);\n        for (int j = 0; j < qk / 2; j++) {\n            if (partial) {\n                qs[j*2+0] = q[j] & 0xf;\n                qs[j*2+1] = q[j] >> 4;\n            } else {\n                qs[j+000] = q[j] & 0xf;\n                qs[j+128] = q[j] >> 4;\n            }\n        }\n\n        pack_q4_0_quants(&x[i * 8 + 0], qs, 0);\n        
pack_q4_0_quants(&x[i * 8 + 1], qs, 1);\n        pack_q4_0_quants(&x[i * 8 + 2], qs, 2);\n        pack_q4_0_quants(&x[i * 8 + 3], qs, 3);\n        pack_q4_0_quants(&x[i * 8 + 4], qs, 4);\n        pack_q4_0_quants(&x[i * 8 + 5], qs, 5);\n        pack_q4_0_quants(&x[i * 8 + 6], qs, 6);\n        pack_q4_0_quants(&x[i * 8 + 7], qs, 7);\n    }\n\n    // Repack the scales\n    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Unpack the scales\n        const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);\n        x[i * 8 + 0].d      = d[0];\n        x[i * 8 + 1].d      = d[1];\n        x[i * 8 + 2].d      = d[2];\n        x[i * 8 + 3].d      = d[3];\n        x[i * 8 + 4].d      = d[4];\n        x[i * 8 + 5].d      = d[5];\n        x[i * 8 + 6].d      = d[6];\n        x[i * 8 + 7].d      = d[7];\n    }\n\n    if (opt_verbose > 2) {\n        for (int i = 0; i < nb; i++) {\n            dump_block_q4_0(&x[i * 8 + 0], 0);\n            dump_block_q4_0(&x[i * 8 + 1], 1);\n            dump_block_q4_0(&x[i * 8 + 2], 2);\n            dump_block_q4_0(&x[i * 8 + 3], 3);\n            dump_block_q4_0(&x[i * 8 + 4], 4);\n            dump_block_q4_0(&x[i * 8 + 5], 5);\n            dump_block_q4_0(&x[i * 8 + 6], 6);\n            dump_block_q4_0(&x[i * 8 + 7], 7);\n        }\n    }\n}\n\nstatic void init_row_q4x4x2(block_q4_0 * x, int64_t k) {\n    static const int qk = QK_Q4_0x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n\n    // Init the quants such that they unpack into zeros\n    uint8_t qs[QK_Q4_0x4x2];  // unpacked quants\n    memset(qs, 8, sizeof(qs));\n\n    for (int i = 0; i < nb; i++) {\n        pack_q4_0_quants(&x[i * 8 + 0], qs, 0);\n        pack_q4_0_quants(&x[i * 8 + 1], qs, 1);\n        pack_q4_0_quants(&x[i * 8 + 2], qs, 2);\n        pack_q4_0_quants(&x[i * 8 + 3], 
qs, 3);\n        pack_q4_0_quants(&x[i * 8 + 4], qs, 4);\n        pack_q4_0_quants(&x[i * 8 + 5], qs, 5);\n        pack_q4_0_quants(&x[i * 8 + 6], qs, 6);\n        pack_q4_0_quants(&x[i * 8 + 7], qs, 7);\n    }\n\n    // Init the scales\n    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Unpack the scales\n        x[i * 8 + 0].d = 0;\n        x[i * 8 + 1].d = 0;\n        x[i * 8 + 2].d = 0;\n        x[i * 8 + 3].d = 0;\n        x[i * 8 + 4].d = 0;\n        x[i * 8 + 5].d = 0;\n        x[i * 8 + 6].d = 0;\n        x[i * 8 + 7].d = 0;\n    }\n}\n\n// repack q4_0 data into q4x4x2 tensor\nstatic void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) {\n    int64_t nrows = ggml_nrows(t);\n\n    size_t row_size    = ggml_row_size(t->type, t->ne[0]);\n    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad\n    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)\n\n    // Ensure we don't try to read more data than is available in the source buffer 'data'\n    // or write more than the tensor can hold.\n    const size_t total_tensor_size = (size_t)nrows * row_size;\n    const size_t n_bytes_to_copy = size < total_tensor_size ? 
size : total_tensor_size;\n\n    // Calculate how many full rows and how many remaining bytes we need to process.\n    const int64_t n_full_rows = n_bytes_to_copy / row_size;\n    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;\n\n    void * buf_pd = ggml_aligned_malloc(row_size_pd);\n    GGML_ASSERT(buf_pd != NULL);\n\n    void * buf_rp = ggml_aligned_malloc(row_size_rp);\n    GGML_ASSERT(buf_rp != NULL);\n\n    HEX_VERBOSE(\"ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\\n\", t->name, data, size,\n                t->ne[0], nrows, row_size);\n\n    init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros\n\n    // 1. Process all the full rows\n    for (int64_t i = 0; i < n_full_rows; i++) {\n        const uint8_t * src = (const uint8_t *) data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);\n\n        memcpy(buf_pd, src, row_size);\n        repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);\n        memcpy(dst, buf_rp, row_size);\n    }\n\n    // 2. 
Process the final, potentially partial, row\n    if (n_rem_bytes > 0) {\n        const int64_t i = n_full_rows;\n        const uint8_t * src = (const uint8_t *) data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);\n\n        // re-init the row because we are potentially copying a partial row\n        init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);\n\n        // Copy only the remaining bytes from the source.\n        memcpy(buf_pd, src, n_rem_bytes);\n\n        // Repack the entire buffer\n        repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);\n\n        // Write only the corresponding remaining bytes to the destination tensor.\n        memcpy(dst, buf_rp, n_rem_bytes);\n    }\n\n    ggml_aligned_free(buf_pd, row_size_pd);\n    ggml_aligned_free(buf_rp, row_size_rp);\n}\n\n// repack q4x4x2 tensor into q4_0 data\nstatic void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) {\n    int64_t nrows = ggml_nrows(t);\n\n    size_t row_size    = ggml_row_size(t->type, t->ne[0]);\n    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad\n    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)\n\n    // Ensure we don't try to copy more data than the tensor actually contains.\n    const size_t total_tensor_size = (size_t)nrows * row_size;\n    const size_t n_bytes_to_copy = size < total_tensor_size ? 
size : total_tensor_size;\n\n    // Calculate how many full rows and how many remaining bytes we need to process.\n    const int64_t n_full_rows = n_bytes_to_copy / row_size;\n    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;\n\n    void * buf_pd = ggml_aligned_malloc(row_size_pd);\n    GGML_ASSERT(buf_pd != NULL);\n\n    void * buf_rp = ggml_aligned_malloc(row_size_rp);\n    GGML_ASSERT(buf_rp != NULL);\n\n    HEX_VERBOSE(\"ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\\n\", t->name, data, size,\n                t->ne[0], nrows, row_size);\n\n    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros\n\n    // 1. Process all the full rows\n    for (int64_t i = 0; i < n_full_rows; i++) {\n        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) data + (i * row_size);\n\n        memcpy(buf_pd, src, row_size);\n        unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);\n        memcpy(dst, buf_rp, row_size);\n    }\n\n    // 2. Process the final, potentially partial, row\n    if (n_rem_bytes > 0) {\n        const int64_t i = n_full_rows;\n        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) data + (i * row_size);\n\n        // We still need to read and unpack the entire source row because quantization is block-based.\n        memcpy(buf_pd, src, row_size);\n        unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);\n\n        // But we only copy the remaining number of bytes to the destination.\n        memcpy(dst, buf_rp, n_rem_bytes);\n    }\n\n    ggml_aligned_free(buf_pd, row_size_pd);\n    ggml_aligned_free(buf_rp, row_size_rp);\n}\n\n// ======== Q8x4x2 ====================\nstatic void dump_block_q8_0(const block_q8_0 * b, int i) {\n    HEX_VERBOSE(\"ggml-hex: repack q8_0 %d: %d %d %d %d ... 
%d %d %d %d : %.6f\\n\", i, b->qs[0], b->qs[1], b->qs[2],\n                b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d));\n}\n\nstatic void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) {\n    static const int qk        = QK_Q8_0x4x2;\n    const int        dblk_size = 8 * 2;   // 8x __fp16\n    const int        qblk_size = qk;      // int8\n    const int        qrow_size = k;       // int8 (not padded)\n\n    const uint8_t * v_q = v + 0;          // quants first\n    const uint8_t * v_d = v + qrow_size;  // then scales\n\n    const uint8_t *   q = v_q + i * qblk_size;\n    const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);\n\n    HEX_VERBOSE(\"ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\\n\", i,\n                q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127],\n                GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));\n\n    HEX_VERBOSE(\"ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... 
%d %d %d %d : %.6f %.6f %.6f %.6f\\n\",\n                i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255],\n                GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));\n}\n\nstatic void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) {\n    static const int qk = QK8_0;\n\n    for (unsigned int i = 0; i < qk; ++i) {\n        qs[bi * qk + i] = x->qs[i];\n    }\n}\n\nstatic void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) {\n    static const int qk = QK8_0;\n\n    for (unsigned int i = 0; i < qk; ++i) {\n        x->qs[i] = qs[bi * qk + i];\n    }\n}\n\nstatic void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {\n    static const int qk = QK_Q8_0x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n\n    const int dblk_size = 8 * 2;              // 8x __fp16\n    const int qblk_size = qk;                 // int8\n    const int qrow_size = k;                  // int8 (not padded to blocks)\n\n    uint8_t * y_q = y + 0;                    // quants first\n    uint8_t * y_d = y + qrow_size;            // then scales\n\n    if (opt_verbose > 2) {\n        for (int i = 0; i < nb; i++) {\n            dump_block_q8_0(&x[i * 8 + 0], 0);\n            dump_block_q8_0(&x[i * 8 + 1], 1);\n            dump_block_q8_0(&x[i * 8 + 2], 2);\n            dump_block_q8_0(&x[i * 8 + 3], 3);\n            dump_block_q8_0(&x[i * 8 + 4], 4);\n            dump_block_q8_0(&x[i * 8 + 5], 5);\n            dump_block_q8_0(&x[i * 8 + 6], 6);\n            dump_block_q8_0(&x[i * 8 + 7], 7);\n        }\n    }\n\n    // Repack the quants\n    for (int i = 0; i < nb; i++) {\n        uint8_t qs[QK_Q8_0x4x2];  // unpacked quants\n\n        unpack_q8_0_quants(qs, &x[i * 8 + 0], 0);\n        unpack_q8_0_quants(qs, &x[i * 8 + 1], 1);\n        unpack_q8_0_quants(qs, &x[i * 8 + 2], 2);\n        
unpack_q8_0_quants(qs, &x[i * 8 + 3], 3);\n        unpack_q8_0_quants(qs, &x[i * 8 + 4], 4);\n        unpack_q8_0_quants(qs, &x[i * 8 + 5], 5);\n        unpack_q8_0_quants(qs, &x[i * 8 + 6], 6);\n        unpack_q8_0_quants(qs, &x[i * 8 + 7], 7);\n\n        uint8_t * q = y_q + (i * qblk_size);\n        for (int j = 0; j < qk; j++) {\n            q[j] = qs[j];\n        }\n    }\n\n    // Repack the scales\n    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Repack the scales\n        ggml_half * d = (ggml_half *) (y_d + i * dblk_size);\n        d[0]          = x[i * 8 + 0].d;\n        d[1]          = x[i * 8 + 1].d;\n        d[2]          = x[i * 8 + 2].d;\n        d[3]          = x[i * 8 + 3].d;\n        d[4]          = x[i * 8 + 4].d;\n        d[5]          = x[i * 8 + 5].d;\n        d[6]          = x[i * 8 + 6].d;\n        d[7]          = x[i * 8 + 7].d;\n    }\n\n    if (opt_verbose > 1) {\n        for (int i = 0; i < nb; i++) {\n            dump_packed_block_q8x4x2(y, i, k);\n        }\n    }\n}\n\nstatic void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {\n    static const int qk = QK_Q8_0x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n\n    const int dblk_size = 8 * 2;              // 8x __fp16\n    const int qblk_size = qk;                 // int8\n    const int qrow_size = k;                  // int8 (not padded to blocks)\n\n    const uint8_t * y_q = y + 0;              // quants first\n    const uint8_t * y_d = y + qrow_size;      // then scales\n\n    if (opt_verbose > 1) {\n        for (int i = 0; i < nb; i++) {\n            dump_packed_block_q8x4x2(y, i, k);\n        }\n    }\n\n    // Unpack the quants\n    for (int i = 0; i < nb; i++) {\n        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants\n\n        const uint8_t * q = y_q + (i * 
qblk_size);\n        for (int j = 0; j < qk; j++) {\n            qs[j] = q[j];\n        }\n\n        pack_q8_0_quants(&x[i * 8 + 0], qs, 0);\n        pack_q8_0_quants(&x[i * 8 + 1], qs, 1);\n        pack_q8_0_quants(&x[i * 8 + 2], qs, 2);\n        pack_q8_0_quants(&x[i * 8 + 3], qs, 3);\n        pack_q8_0_quants(&x[i * 8 + 4], qs, 4);\n        pack_q8_0_quants(&x[i * 8 + 5], qs, 5);\n        pack_q8_0_quants(&x[i * 8 + 6], qs, 6);\n        pack_q8_0_quants(&x[i * 8 + 7], qs, 7);\n    }\n\n    // Repack the scales\n    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Unpack the scales\n        const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);\n        x[i * 8 + 0].d      = d[0];\n        x[i * 8 + 1].d      = d[1];\n        x[i * 8 + 2].d      = d[2];\n        x[i * 8 + 3].d      = d[3];\n        x[i * 8 + 4].d      = d[4];\n        x[i * 8 + 5].d      = d[5];\n        x[i * 8 + 6].d      = d[6];\n        x[i * 8 + 7].d      = d[7];\n    }\n\n    if (opt_verbose > 2) {\n        for (int i = 0; i < nb; i++) {\n            dump_block_q8_0(&x[i * 8 + 0], 0);\n            dump_block_q8_0(&x[i * 8 + 1], 1);\n            dump_block_q8_0(&x[i * 8 + 2], 2);\n            dump_block_q8_0(&x[i * 8 + 3], 3);\n            dump_block_q8_0(&x[i * 8 + 4], 4);\n            dump_block_q8_0(&x[i * 8 + 5], 5);\n            dump_block_q8_0(&x[i * 8 + 6], 6);\n            dump_block_q8_0(&x[i * 8 + 7], 7);\n        }\n    }\n}\n\nstatic void init_row_q8x4x2(block_q8_0 * x, int64_t k) {\n    static const int qk = QK_Q8_0x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n\n    // Init the quants such that they unpack into zeros\n    uint8_t qs[QK_Q8_0x4x2];  // unpacked quants\n    memset(qs, 0, sizeof(qs));\n\n    for (int i = 0; i < nb; i++) {\n        pack_q8_0_quants(&x[i * 8 + 0], 
qs, 0);\n        pack_q8_0_quants(&x[i * 8 + 1], qs, 1);\n        pack_q8_0_quants(&x[i * 8 + 2], qs, 2);\n        pack_q8_0_quants(&x[i * 8 + 3], qs, 3);\n        pack_q8_0_quants(&x[i * 8 + 4], qs, 4);\n        pack_q8_0_quants(&x[i * 8 + 5], qs, 5);\n        pack_q8_0_quants(&x[i * 8 + 6], qs, 6);\n        pack_q8_0_quants(&x[i * 8 + 7], qs, 7);\n    }\n\n    // Init the scales\n    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Unpack the scales\n        x[i * 8 + 0].d = 0;\n        x[i * 8 + 1].d = 0;\n        x[i * 8 + 2].d = 0;\n        x[i * 8 + 3].d = 0;\n        x[i * 8 + 4].d = 0;\n        x[i * 8 + 5].d = 0;\n        x[i * 8 + 6].d = 0;\n        x[i * 8 + 7].d = 0;\n    }\n}\n\n// repack q8_0 data into q8x4x2 tensor\nstatic void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) {\n    int64_t nrows = ggml_nrows(t);\n\n    size_t row_size    = ggml_row_size(t->type, t->ne[0]);\n    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2));  // extra elements for the pad\n    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)\n\n    // Ensure we don't try to read more data than is available in the source buffer 'data'\n    // or write more than the tensor can hold.\n    const size_t total_tensor_size = (size_t)nrows * row_size;\n    const size_t n_bytes_to_copy = size < total_tensor_size ? 
size : total_tensor_size;\n\n    // Calculate how many full rows and how many remaining bytes we need to process.\n    const int64_t n_full_rows = n_bytes_to_copy / row_size;\n    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;\n\n    void * buf_pd = ggml_aligned_malloc(row_size_pd);\n    GGML_ASSERT(buf_pd != NULL);\n\n    void * buf_rp = ggml_aligned_malloc(row_size_rp);\n    GGML_ASSERT(buf_rp != NULL);\n\n    HEX_VERBOSE(\"ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\\n\", t->name, data, size,\n                t->ne[0], nrows, row_size);\n\n    init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros\n\n    // 1. Process all the full rows\n    for (int64_t i = 0; i < n_full_rows; i++) {\n        const uint8_t * src = (const uint8_t *) data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);\n\n        memcpy(buf_pd, src, row_size);\n        repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);\n        memcpy(dst, buf_rp, row_size);\n    }\n\n    // 2. 
Process the final, potentially partial, row\n    if (n_rem_bytes > 0) {\n        const int64_t i = n_full_rows;\n        const uint8_t * src = (const uint8_t *) data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);\n\n        // re-init the row because we are potentially copying a partial row\n        init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);\n\n        // Copy only the remaining bytes from the source.\n        memcpy(buf_pd, src, n_rem_bytes);\n\n        // Repack the entire buffer\n        repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);\n\n        // Write only the corresponding remaining bytes to the destination tensor.\n        memcpy(dst, buf_rp, n_rem_bytes);\n    }\n\n    ggml_aligned_free(buf_pd, row_size_pd);\n    ggml_aligned_free(buf_rp, row_size_rp);\n}\n\n// repack q8x4x2 tensor into q8_0 data\nstatic void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) {\n    int64_t nrows = ggml_nrows(t);\n\n    size_t row_size    = ggml_row_size(t->type, t->ne[0]);\n    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2));  // extra elements for the pad\n    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)\n\n    // Ensure we don't try to copy more data than the tensor actually contains.\n    const size_t total_tensor_size = (size_t)nrows * row_size;\n    const size_t n_bytes_to_copy = size < total_tensor_size ? 
size : total_tensor_size;\n\n    // Calculate how many full rows and how many remaining bytes we need to process.\n    const int64_t n_full_rows = n_bytes_to_copy / row_size;\n    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;\n\n    void * buf_pd = ggml_aligned_malloc(row_size_pd);\n    GGML_ASSERT(buf_pd != NULL);\n\n    void * buf_rp = ggml_aligned_malloc(row_size_rp);\n    GGML_ASSERT(buf_rp != NULL);\n\n    HEX_VERBOSE(\"ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\\n\", t->name, data, size,\n                t->ne[0], nrows, row_size);\n\n    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros\n\n    // 1. Process all the full rows\n    for (int64_t i = 0; i < n_full_rows; i++) {\n        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) data + (i * row_size);\n\n        memcpy(buf_pd, src, row_size);\n        unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);\n        memcpy(dst, buf_rp, row_size);\n    }\n\n    // 2. 
Process the final, potentially partial, row\n    if (n_rem_bytes > 0) {\n        const int64_t i = n_full_rows;\n        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) data + (i * row_size);\n\n        // We still need to read and unpack the entire source row because quantization is block-based.\n        memcpy(buf_pd, src, row_size);\n        unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);\n\n        // But we only copy the remaining number of bytes to the destination.\n        memcpy(dst, buf_rp, n_rem_bytes);\n    }\n\n    ggml_aligned_free(buf_pd, row_size_pd);\n    ggml_aligned_free(buf_rp, row_size_rp);\n}\n\n// ======== MXFP4x4x2 ====================\nstruct x2_mxfp4 {\n    int v[2];\n};\n\nstatic x2_mxfp4 unpack_mxfp4(uint8_t v) {\n    x2_mxfp4 x;\n    x.v[0] = kvalues_mxfp4[(v & 0x0f)];\n    x.v[1] = kvalues_mxfp4[(v >> 4)];\n    return x;\n}\n\nstatic void dump_block_mxfp4(const block_mxfp4 * b, int i) {\n    HEX_VERBOSE(\"ggml-hex: repack mxfp4 %d: %d %d %d %d ... 
%d %d %d %d : %.6f\\n\", i, unpack_mxfp4(b->qs[0]).v[0],\n                unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0],\n                unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1],\n                unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e));\n}\n\nstatic void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) {\n    static const int qk        = QK_MXFP4x4x2;\n    const int        eblk_size = 8 * 1;   // 8x E8M0\n    const int        qblk_size = qk / 2;  // int4\n    const int        qrow_size = k / 2;   // int4 (not padded)\n\n    const uint8_t * v_q = v + 0;          // quants first\n    const uint8_t * v_e = v + qrow_size;  // then scales\n\n    const uint8_t * q = v_q + i * qblk_size;\n    const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size);\n\n    HEX_VERBOSE(\"ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\\n\", i,\n                unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0],\n                unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0],\n                unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0],\n                unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]),\n                GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3]));\n\n    HEX_VERBOSE(\"ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... 
%d %d %d %d : %.6f %.6f %.6f %.6f\\n\",\n                i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1],\n                unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1],\n                unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1],\n                unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]),\n                GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7]));\n}\n\nstatic void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) {\n    static const int qk = QK_MXFP4;\n\n    for (unsigned int i = 0; i < qk / 2; ++i) {\n        const uint8_t x0         = (x->qs[i] & 0x0F);\n        const uint8_t x1         = (x->qs[i] >> 4);\n        qs[bi * qk + i + 0]      = x0;\n        qs[bi * qk + i + qk / 2] = x1;\n    }\n}\n\nstatic void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) {\n    static const int qk = QK4_0;\n\n    for (unsigned int i = 0; i < qk / 2; ++i) {\n        const uint8_t x0 = qs[bi * qk + i + 0];\n        const uint8_t x1 = qs[bi * qk + i + qk / 2];\n        x->qs[i]         = x0 | (x1 << 4);\n    }\n}\n\nstatic void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) {\n    static const int qk = QK_MXFP4x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n    const int        nloe = k % qk;           // leftovers\n\n    const int eblk_size = 8 * 1;              // 8x E8M0\n    const int qblk_size = qk / 2;             // int4\n    const int qrow_size = k / 2;              // int4 (not padded to blocks)\n\n    uint8_t * y_q = y + 0;                    // quants first\n    uint8_t * y_e = y + qrow_size;            // then scales\n\n    if (opt_verbose > 2) {\n        for (int i = 0; i < nb; i++) {\n            dump_block_mxfp4(&x[i * 8 + 0], 0);\n            
dump_block_mxfp4(&x[i * 8 + 1], 1);\n            dump_block_mxfp4(&x[i * 8 + 2], 2);\n            dump_block_mxfp4(&x[i * 8 + 3], 3);\n            dump_block_mxfp4(&x[i * 8 + 4], 4);\n            dump_block_mxfp4(&x[i * 8 + 5], 5);\n            dump_block_mxfp4(&x[i * 8 + 6], 6);\n            dump_block_mxfp4(&x[i * 8 + 7], 7);\n        }\n    }\n\n    // Repack the quants\n    for (int i = 0; i < nb; i++) {\n        uint8_t qs[QK_MXFP4x4x2];  // unpacked quants\n\n        unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0);\n        unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1);\n        unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2);\n        unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3);\n        unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4);\n        unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5);\n        unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6);\n        unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7);\n\n        bool partial = (nloe && i == nb-1);\n\n        uint8_t * q = y_q + (i * qblk_size);\n        for (int j = 0; j < qk / 2; j++) {\n            q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000];\n        }\n    }\n\n    // Repack the scales\n    // Note: Do not combine with the loop above. 
For tensor sizes not multiple of 256 (QK_MXFP4x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Repack the scales\n        uint8_t * e = (uint8_t *) (y_e + i * eblk_size);\n        e[0]        = x[i * 8 + 0].e;\n        e[1]        = x[i * 8 + 1].e;\n        e[2]        = x[i * 8 + 2].e;\n        e[3]        = x[i * 8 + 3].e;\n        e[4]        = x[i * 8 + 4].e;\n        e[5]        = x[i * 8 + 5].e;\n        e[6]        = x[i * 8 + 6].e;\n        e[7]        = x[i * 8 + 7].e;\n    }\n\n    if (opt_verbose > 1) {\n        for (int i = 0; i < nb; i++) {\n            dump_packed_block_mxfp4x4x2(y, i, k);\n        }\n    }\n}\n\nstatic void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) {\n    static const int qk = QK_MXFP4x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n    const int        nloe = k % qk;           // leftovers\n\n    const int eblk_size = 8 * 1;              // 8x E8M0\n    const int qblk_size = qk / 2;             // int4\n    const int qrow_size = k / 2;              // int4 (not padded to blocks)\n\n    const uint8_t * y_q = y + 0;              // quants first\n    const uint8_t * y_e = y + qrow_size;      // then scales\n\n    if (opt_verbose > 1) {\n        for (int i = 0; i < nb; i++) {\n            dump_packed_block_mxfp4x4x2(y, i, k);\n        }\n    }\n\n    // Unpack the quants\n    for (int i = 0; i < nb; i++) {\n        uint8_t qs[QK_MXFP4x4x2];  // unpacked quants\n\n        bool partial = (nloe && i == nb-1);\n\n        const uint8_t * q = y_q + (i * qblk_size);\n        for (int j = 0; j < qk / 2; j++) {\n            if (partial) {\n                qs[j*2+0] = q[j] & 0xf;\n                qs[j*2+1] = q[j] >> 4;\n            } else {\n                qs[j+000] = q[j] & 0xf;\n                qs[j+128] = q[j] >> 4;\n            }\n        }\n\n        pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);\n        
pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);\n        pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);\n        pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);\n        pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);\n        pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);\n        pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);\n        pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);\n    }\n\n    // Repack the scales\n    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Unpack the scales\n        const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size);\n        x[i * 8 + 0].e    = e[0];\n        x[i * 8 + 1].e    = e[1];\n        x[i * 8 + 2].e    = e[2];\n        x[i * 8 + 3].e    = e[3];\n        x[i * 8 + 4].e    = e[4];\n        x[i * 8 + 5].e    = e[5];\n        x[i * 8 + 6].e    = e[6];\n        x[i * 8 + 7].e    = e[7];\n    }\n\n    if (opt_verbose > 2) {\n        for (int i = 0; i < nb; i++) {\n            dump_block_mxfp4(&x[i * 8 + 0], 0);\n            dump_block_mxfp4(&x[i * 8 + 1], 1);\n            dump_block_mxfp4(&x[i * 8 + 2], 2);\n            dump_block_mxfp4(&x[i * 8 + 3], 3);\n            dump_block_mxfp4(&x[i * 8 + 4], 4);\n            dump_block_mxfp4(&x[i * 8 + 5], 5);\n            dump_block_mxfp4(&x[i * 8 + 6], 6);\n            dump_block_mxfp4(&x[i * 8 + 7], 7);\n        }\n    }\n}\n\nstatic void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) {\n    static const int qk = QK_MXFP4x4x2;\n    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)\n\n    // Init the quants such that they unpack into zeros\n    uint8_t qs[QK_MXFP4x4x2];  // unpacked quants\n    memset(qs, 0, sizeof(qs));\n\n    for (int i = 0; i < nb; i++) {\n        pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);\n        pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);\n        pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);\n        pack_mxfp4_quants(&x[i * 8 
+ 3], qs, 3);\n        pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);\n        pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);\n        pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);\n        pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);\n    }\n\n    // Init the scales\n    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)\n    // the last block is truncated and overridden by the scales.\n    for (int i = 0; i < nb; i++) {\n        // Unpack the scales\n        x[i * 8 + 0].e = 0;\n        x[i * 8 + 1].e = 0;\n        x[i * 8 + 2].e = 0;\n        x[i * 8 + 3].e = 0;\n        x[i * 8 + 4].e = 0;\n        x[i * 8 + 5].e = 0;\n        x[i * 8 + 6].e = 0;\n        x[i * 8 + 7].e = 0;\n    }\n}\n\n// repack mxfp4 data into mxfp4x4x2 tensor\nstatic void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) {\n    int64_t nrows = ggml_nrows(t);\n\n    size_t row_size    = ggml_row_size(t->type, t->ne[0]);\n    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2));  // extra elements for the pad\n    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)\n\n    // Ensure we don't try to read more data than is available in the source buffer 'data'\n    // or write more than the tensor can hold.\n    const size_t total_tensor_size = (size_t)nrows * row_size;\n    const size_t n_bytes_to_copy = size < total_tensor_size ? 
size : total_tensor_size;\n\n    // Calculate how many full rows and how many remaining bytes we need to process.\n    const int64_t n_full_rows = n_bytes_to_copy / row_size;\n    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;\n\n    void * buf_pd = ggml_aligned_malloc(row_size_pd);\n    GGML_ASSERT(buf_pd != NULL);\n\n    void * buf_rp = ggml_aligned_malloc(row_size_rp);\n    GGML_ASSERT(buf_rp != NULL);\n\n    HEX_VERBOSE(\"ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\\n\", t->name, data,\n                size, t->ne[0], nrows, row_size);\n\n    init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros\n\n    // 1. Process all the full rows\n    for (int64_t i = 0; i < n_full_rows; i++) {\n        const uint8_t * src = (const uint8_t *) data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);\n\n        memcpy(buf_pd, src, row_size);\n        repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);\n        memcpy(dst, buf_rp, row_size);\n    }\n\n    // 2. 
Process the final, potentially partial, row\n    if (n_rem_bytes > 0) {\n        const int64_t i = n_full_rows;\n        const uint8_t * src = (const uint8_t *) data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);\n\n        // re-init the row because we are potentially copying a partial row\n        init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);\n\n        // Copy only the remaining bytes from the source.\n        memcpy(buf_pd, src, n_rem_bytes);\n\n        // Repack the entire buffer (partial data + zero padding).\n        repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);\n\n        // Write only the corresponding remaining bytes to the destination tensor.\n        memcpy(dst, buf_rp, n_rem_bytes);\n    }\n\n    ggml_aligned_free(buf_pd, row_size_pd);\n    ggml_aligned_free(buf_rp, row_size_rp);\n}\n\n// repack mxfp4x4x2 tensor into mxfp4 data\nstatic void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) {\n    int64_t nrows = ggml_nrows(t);\n\n    size_t row_size    = ggml_row_size(t->type, t->ne[0]);\n    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2));  // extra elements for the pad\n    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)\n\n    // Ensure we don't try to copy more data than the tensor actually contains.\n    const size_t total_tensor_size = (size_t)nrows * row_size;\n    const size_t n_bytes_to_copy = size < total_tensor_size ? 
size : total_tensor_size;\n\n    // Calculate how many full rows and how many remaining bytes we need to process.\n    const int64_t n_full_rows = n_bytes_to_copy / row_size;\n    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;\n\n    void * buf_pd = ggml_aligned_malloc(row_size_pd);\n    GGML_ASSERT(buf_pd != NULL);\n\n    void * buf_rp = ggml_aligned_malloc(row_size_rp);\n    GGML_ASSERT(buf_rp != NULL);\n\n    HEX_VERBOSE(\"ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\\n\", t->name, data,\n                size, t->ne[0], nrows, row_size);\n\n    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros\n\n    // 1. Process all the full rows\n    for (int64_t i = 0; i < n_full_rows; i++) {\n        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) data + (i * row_size);\n\n        memcpy(buf_pd, src, row_size);\n        unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);\n        memcpy(dst, buf_rp, row_size);\n    }\n\n    // 2. 
Process the final, potentially partial, row\n    if (n_rem_bytes > 0) {\n        const int64_t i = n_full_rows;\n        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);\n        uint8_t *       dst = (uint8_t *) data + (i * row_size);\n\n        // We still need to read and unpack the entire source row because the format is block-based.\n        memcpy(buf_pd, src, row_size);\n        unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);\n\n        // But we only copy the remaining number of bytes to the destination to respect the size limit.\n        memcpy(dst, buf_rp, n_rem_bytes);\n    }\n\n    ggml_aligned_free(buf_pd, row_size_pd);\n    ggml_aligned_free(buf_rp, row_size_rp);\n}\n\nstatic void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,\n                                                   ggml_tensor *         tensor,\n                                                   const void *          data,\n                                                   size_t                offset,\n                                                   size_t                size) {\n    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;\n    auto sess = ctx->sess;\n\n    HEX_VERBOSE(\"ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\\n\", sess->name.c_str(), tensor->name, data,\n                offset, size);\n\n    switch (tensor->type) {\n        case GGML_TYPE_Q4_0:\n            GGML_ASSERT(offset == 0);\n            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));\n            repack_q4_0_q4x4x2(tensor, data, size);\n            break;\n\n        case GGML_TYPE_Q8_0:\n            GGML_ASSERT(offset == 0);\n            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));\n            repack_q8_0_q8x4x2(tensor, data, size);\n            break;\n\n        case GGML_TYPE_MXFP4:\n            GGML_ASSERT(offset == 0);\n            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));\n          
  repack_mxfp4_mxfp4x4x2(tensor, data, size);\n            break;\n\n        default:\n            memcpy((char *) tensor->data + offset, data, size);\n            break;\n    }\n}\n\nstatic void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,\n                                                   const ggml_tensor *   tensor,\n                                                   void *                data,\n                                                   size_t                offset,\n                                                   size_t                size) {\n    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;\n    auto sess = ctx->sess;\n\n    HEX_VERBOSE(\"ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\\n\", sess->name.c_str(), tensor->name, data,\n                offset, size);\n\n    switch (tensor->type) {\n        case GGML_TYPE_Q4_0:\n            GGML_ASSERT(offset == 0);\n            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));\n            repack_q4x4x2_q4_0(data, tensor, size);\n            break;\n\n        case GGML_TYPE_Q8_0:\n            GGML_ASSERT(offset == 0);\n            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));\n            repack_q8x4x2_q8_0(data, tensor, size);\n            break;\n\n        case GGML_TYPE_MXFP4:\n            GGML_ASSERT(offset == 0);\n            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));\n            repack_mxfp4x4x2_mxfp4(data, tensor, size);\n            break;\n\n        default:\n            memcpy(data, (const char *) tensor->data + offset, size);\n            break;\n    }\n}\n\nstatic bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t      buffer,\n                                                   const struct ggml_tensor * src,\n                                                   struct ggml_tensor *       dst) {\n    GGML_UNUSED(buffer);\n    GGML_UNUSED(src);\n    GGML_UNUSED(dst);\n    // we might optimize this later, 
for now take the slow path (ie get/set_tensor)\n    return false;\n}\n\nstatic void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;\n    auto sess = ctx->sess;\n    HEX_VERBOSE(\"ggml-hex: %s clear-buff base %p size %zu\\n\", sess->name.c_str(), (void *) ctx->base, ctx->size);\n    memset(ctx->base, value, ctx->size);\n}\n\nstatic ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_hexagon_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_hexagon_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_hexagon_buffer_init_tensor,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ ggml_backend_hexagon_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_hexagon_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_hexagon_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_hexagon_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\n// ** backend buffer type\n\nstatic const char * ggml_backend_hexagon_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {\n    return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->name.c_str();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer(\n            ggml_backend_buffer_type_t buffer_type, size_t size) {\n    auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;\n    try {\n        ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/);\n        return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);\n    } catch (const std::exception & exc) {\n        GGML_LOG_ERROR(\"ggml-hex: %s failed to allocate buffer context: %s\\n\", sess->name.c_str(), exc.what());\n        return nullptr;\n    }\n}\n\nstatic 
ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffer(\n            ggml_backend_buffer_type_t buffer_type, size_t size) {\n    auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;\n    try {\n        ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/);\n        return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);\n    } catch (const std::exception & exc) {\n        GGML_LOG_ERROR(\"ggml-hex: %s failed to allocate buffer context: %s\\n\", sess->name.c_str(), exc.what());\n        return nullptr;\n    }\n}\n\nstatic size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {\n    return 128;  // HVX alignment\n    GGML_UNUSED(buffer_type);\n}\n\nstatic size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) {\n    return ggml_nbytes(t);\n}\n\nstatic size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {\n    return 1 * 1024 * 1024 * 1024;  // 1GB per buffer\n    GGML_UNUSED(buffer_type);\n}\n\nstatic bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) {\n    return opt_hostbuf;\n    GGML_UNUSED(buft);\n}\n\nstatic bool ggml_backend_hexagon_repack_buffer_type_is_host(ggml_backend_buffer_type_t buft) {\n    return false;\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_type_i ggml_backend_hexagon_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_hexagon_buffer_type_name,\n    /* .alloc_buffer     = */ ggml_backend_hexagon_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_hexagon_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_hexagon_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_hexagon_buffer_type_get_alloc_size,\n    /* .is_host          = */ 
ggml_backend_hexagon_buffer_type_is_host,\n};\n\nstatic ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_hexagon_buffer_type_name,\n    /* .alloc_buffer     = */ ggml_backend_hexagon_repack_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_hexagon_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_hexagon_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_hexagon_buffer_type_get_alloc_size,\n    /* .is_host          = */ ggml_backend_hexagon_repack_buffer_type_is_host,\n};\n\nvoid ggml_hexagon_session::allocate(int dev_id) noexcept(false) {\n    this->valid_session = false;\n    this->valid_handle  = false;\n    this->valid_queue   = false;\n    this->valid_iface   = false;\n\n    this->domain_id  = 3;  // Default for CDSP, updated after the session is created\n    this->session_id = 0;  // Default for CDSP, updated after the session is created\n    this->dev_id     = dev_id;\n    this->name       = std::string(\"HTP\") + std::to_string(dev_id);\n\n    this->op_pending  = 0;\n    this->prof_usecs  = 0;\n    this->prof_cycles = 0;\n    this->prof_pkts   = 0;\n\n    GGML_LOG_INFO(\"ggml-hex: allocating new session: %s\\n\", this->name.c_str());\n\n    domain * my_domain = get_domain(this->domain_id);\n    if (my_domain == NULL) {\n        GGML_LOG_ERROR(\"ggml-hex: unable to get domain struct for CDSP\\n\");\n        throw std::runtime_error(\"ggml-hex: failed to get CDSP domain (see log for details)\");\n    }\n\n    // Create new session\n    if (dev_id != 0) {\n        struct remote_rpc_reserve_new_session n;\n        n.domain_name_len  = strlen(CDSP_DOMAIN_NAME);\n        n.domain_name      = const_cast<char *>(CDSP_DOMAIN_NAME);\n        n.session_name     = const_cast<char *>(this->name.c_str());\n        n.session_name_len = this->name.size();\n\n        int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, 
sizeof(n));\n        if (err != AEE_SUCCESS) {\n            GGML_LOG_ERROR(\"ggml-hex: failed to reserve new session %d : error 0x%x\\n\", dev_id, err);\n            throw std::runtime_error(\"ggml-hex: remote_session_control(new-sess) failed (see log for details)\");\n        }\n\n        // Save the IDs\n        this->session_id    = n.session_id;\n        this->domain_id     = n.effective_domain_id;\n        this->valid_session = true;\n    }\n\n    // Get session URI\n\n    char session_uri[256];\n    {\n        char htp_uri[256];\n        snprintf(htp_uri, sizeof(htp_uri), \"file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0\", opt_arch);\n\n        struct remote_rpc_get_uri u = {};\n        u.session_id      = this->session_id;\n        u.domain_name     = const_cast<char *>(CDSP_DOMAIN_NAME);\n        u.domain_name_len = strlen(CDSP_DOMAIN_NAME);\n        u.module_uri      = const_cast<char *>(htp_uri);\n        u.module_uri_len  = strlen(htp_uri);\n        u.uri             = session_uri;\n        u.uri_len         = sizeof(session_uri);\n\n        int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u));\n        if (err != AEE_SUCCESS) {\n            // fallback to single session uris\n            int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN;\n\n            snprintf(session_uri, htp_URI_domain_len, \"%s%s\", htp_uri, my_domain->uri);\n\n            GGML_LOG_WARN(\"ggml-hex: failed to get URI for session %d : error 0x%x. 
Falling back to single session URI: %s\\n\", dev_id, err, session_uri);\n        }\n    }\n\n    // Enable Unsigned PD\n    {\n        struct remote_rpc_control_unsigned_module u;\n        u.domain = this->domain_id;\n        u.enable = 1;\n        int err  = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u));\n        if (err != AEE_SUCCESS) {\n            GGML_LOG_ERROR(\"ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\\n\", dev_id, err);\n            throw std::runtime_error(\"ggml-hex: remote_session_control(unsign) failed (see log for details)\");\n        }\n    }\n\n    // Open session\n    int err = htp_iface_open(session_uri, &this->handle);\n    if (err != AEE_SUCCESS) {\n        GGML_LOG_ERROR(\"ggml-hex: failed to open session %d : error 0x%x\\n\", dev_id, err);\n        throw std::runtime_error(\"ggml-hex: failed to open session (see log for details)\");\n    }\n\n    this->valid_handle = true;\n\n    GGML_LOG_INFO(\"ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\\n\", this->name.c_str(),\n                  this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);\n\n    // Enable FastRPC QoS mode\n    {\n        struct remote_rpc_control_latency l;\n        l.enable = 1;\n\n        int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l));\n        if (err != 0) {\n            GGML_LOG_WARN(\"ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\\n\", (unsigned) err);\n        }\n    }\n\n    // Now let's setup the DSP queue\n    err = dspqueue_create(this->domain_id,\n                          0,              // Flags\n                          128 * 1024,     // Request  queue size (in bytes)\n                          64 * 1024,      // Response queue size (in bytes)\n                          nullptr,        // Read packet callback (we handle reads explicitly)\n                          nullptr,        // Error 
callback (we handle errors during reads)\n                          (void *) this,  // Callback context\n                          &queue);\n    if (err != 0) {\n        GGML_LOG_ERROR(\"ggml-hex: %s dspqueue_create failed: 0x%08x\\n\", this->name.c_str(), (unsigned) err);\n        throw std::runtime_error(\"ggml-hex: failed to create dspqueue (see log for details)\");\n    }\n\n    this->valid_queue = true;\n\n    // Export queue for use on the DSP\n    err = dspqueue_export(queue, &this->queue_id);\n    if (err != 0) {\n        GGML_LOG_ERROR(\"ggml-hex: dspqueue_export failed: 0x%08x\\n\", (unsigned) err);\n        throw std::runtime_error(\"ggml-hex: dspqueue export failed (see log for details)\");\n    }\n\n    if (opt_etm) {\n        err = htp_iface_enable_etm(this->handle);\n        if (err != 0) {\n            GGML_LOG_ERROR(\"ggml-hex: failed to enable ETM tracing: 0x%08x\\n\", (unsigned) err);\n        }\n    }\n\n    // Start the DSP-side service. We need to pass the queue ID to the\n    // DSP in a FastRPC call; the DSP side will import the queue and start\n    // listening for packets in a callback.\n    err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);\n    if (err != 0) {\n        GGML_LOG_ERROR(\"ggml-hex: failed to start session: 0x%08x\\n\", (unsigned) err);\n        throw std::runtime_error(\"ggml-hex: iface start failed (see log for details)\");\n    }\n    this->valid_iface = true;\n}\n\nvoid ggml_hexagon_session::release() noexcept(true) {\n    GGML_LOG_INFO(\"ggml-hex: releasing session: %s\\n\", this->name.c_str());\n\n    int err;\n\n    // Stop the DSP-side service and close the queue\n    if (this->valid_iface) {\n        err = htp_iface_stop(this->handle);\n        if (err != 0) {\n            GGML_ABORT(\"ggml-hex: htp_iface_stop failed: 0x%08x\\n\", (unsigned) err);\n        }\n    }\n\n    if (opt_etm) {\n        err = htp_iface_disable_etm(this->handle);\n        if (err != 0) {\n            
GGML_LOG_ERROR(\"ggml-hex: warn : failed to disable ETM tracing: 0x%08x\\n\", (unsigned) err);\n        }\n    }\n\n    if (this->valid_queue) {\n        err = dspqueue_close(queue);\n        if (err != 0) {\n            GGML_ABORT(\"ggml-hex: dspqueue_close failed: 0x%08x\\n\", (unsigned) err);\n        }\n    }\n\n    if (this->valid_handle) {\n        htp_iface_close(this->handle);\n    }\n}\n\nggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {\n    buffer_type.device        = dev;\n    repack_buffer_type.device = dev;\n\n    try {\n        allocate(dev_id);\n\n        buffer_type.iface   = ggml_backend_hexagon_buffer_type_interface;\n        buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this);\n\n        repack_buffer_type.iface   = ggml_backend_hexagon_repack_buffer_type_interface;\n        repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + \"-REPACK\", this);\n    } catch (const std::exception & exc) {\n        release();\n        throw;\n    }\n}\n\nggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {\n    release();\n\n    delete static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type.context);\n    delete static_cast<ggml_backend_hexagon_buffer_type_context *>(repack_buffer_type.context);\n}\n\n// ** backend interface\n\nstatic bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) {\n    return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment;\n}\n\nstatic inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) {\n    if (!opt_hostbuf) {\n        return ggml_backend_buffer_is_hexagon(b);\n    }\n    return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;\n}\n\nstatic bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct 
ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n    const struct ggml_tensor * src2 = op->src[2];\n    const struct ggml_tensor * src3 = op->src[3];\n    const struct ggml_tensor * src4 = op->src[4];\n    const struct ggml_tensor * dst  = op;\n\n    // Check for F16 support only as requested\n    if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {\n        return false;\n    }\n\n    if (src3 && src3->type != GGML_TYPE_F16) {  // mask\n        return false;\n    }\n\n    if (src4 && src4->type != GGML_TYPE_F32) {  // sinks\n        return false;\n    }\n\n    // For now we support F32 or F16 output as htp backend often converts output on the fly if needed,\n    // but the op implementation writes to F16 or F32.\n    // Let's assume dst can be F32 or F16.\n    if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {\n        return false;\n    }\n\n    return opt_experimental;\n}\n\n\nstatic bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {\n    const struct ggml_tensor * src0 = dst->src[0];\n    const struct ggml_tensor * src1 = dst->src[1];\n\n    if (dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {\n        return false;\n    }\n\n    switch (src0->type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_MXFP4:\n            if (src0->ne[0] % 32) {\n                return false;\n            }\n\n            if (ggml_nrows(src0) > 16 * 1024) {\n                return false;  // typically the lm-head which would be too large for VTCM\n            }\n\n            if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) {\n                return false;  // no huge batches or broadcasting (for now)\n            }\n\n            // src0 (weights) must be 
repacked\n            if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {\n                return false;\n            }\n            break;\n\n        case GGML_TYPE_F16:\n            if (src0->nb[1] < src0->nb[0]) {\n                GGML_LOG_DEBUG(\"ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\\n\");\n                return false;\n            }\n            if (ggml_nrows(src1) > 1024) {\n                return false;  // no huge batches (for now)\n            }\n            break;\n\n        default:\n            return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n    const struct ggml_tensor * src2 = op->src[2];\n    const struct ggml_tensor * dst  = op;\n\n    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32 || src2->type != GGML_TYPE_I32) {\n        return false;\n    }\n\n    switch (src0->type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_MXFP4:\n            if ((src0->ne[0] % 32)) {\n                return false;\n            }\n\n            // src0 (weights) must be repacked\n            if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {\n                return false;\n            }\n            break;\n\n        default:\n            return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n    const struct ggml_tensor * dst  = op;\n\n    if (src0->type == GGML_TYPE_F32) {\n        if (src1->type != GGML_TYPE_F32) {\n            return false;\n        }\n        if (dst->type != GGML_TYPE_F32) {\n            
return false;\n        }\n    }\n    else if (src0->type == GGML_TYPE_F16) {\n        if (src1->type != GGML_TYPE_F16) {\n            return false;\n        }\n        if (dst->type != GGML_TYPE_F16) {\n            return false;\n        }\n    }\n    else {\n        return false;\n    }\n\n    if (!ggml_are_same_shape(src0, dst)) {\n        return false;\n    }\n    if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n    const struct ggml_tensor * dst  = op;\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (src1->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (!ggml_are_same_shape(src0, dst)) {\n        return false;\n    }\n\n    // REVISIT: add support for non-contigiuos tensors\n    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * dst  = op;\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (!ggml_are_same_shape(src0, dst)) {\n        return false;\n    }\n\n    // TODO: add support for non-contigiuos tensors\n    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * 
src0 = op->src[0];\n    const struct ggml_tensor * dst  = op;\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    // TODO: add support for non-contigiuos tensors\n    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,\n                                               const struct ggml_tensor *          op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n    const struct ggml_tensor * dst  = op;\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {\n        return false;\n    }\n\n    if (src1) {\n        if (src1->type != GGML_TYPE_F32) {\n            return false;\n        }\n        if (!ggml_are_same_shape(src0, src1)) {\n            return false;\n        }\n        if (!ggml_is_contiguous(src1)) {\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n    const struct ggml_tensor * src2 = op->src[2];\n    const struct ggml_tensor * dst  = op;\n\n    if (src2) {\n        return false;  // FIXME: add support for sinks\n    }\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    if (src1) {\n        if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {\n            return false;\n        }\n        if (src0->ne[0] != src1->ne[0]) {\n            return false;\n        }\n     
   if (src1->ne[1] < src0->ne[1]) {\n            return false;\n        }\n        if (src0->ne[2] % src1->ne[2] != 0) {\n            return false;\n        }\n        if (src0->ne[3] % src1->ne[3] != 0) {\n            return false;\n        }\n    }\n\n    if (src1) {\n        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {\n            return false;\n        }\n    } else {\n        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0]; // values\n    const struct ggml_tensor * src1 = op->src[1]; // indices\n    const struct ggml_tensor * dst  = op;\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {\n        return false;\n    }\n\n    if (dst->type != GGML_TYPE_F16) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0]; // values\n    const struct ggml_tensor * src1 = op->src[1]; // indices\n    const struct ggml_tensor * dst  = op;\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {\n        return false;\n    }\n\n    if (dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0]; // values\n    const struct ggml_tensor * dst  = op;         // indices\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    if 
(dst->type != GGML_TYPE_I32) {\n        return false;\n    }\n\n    if (src0->ne[0] > (16*1024)) {\n        // reject tensors with huge rows for now\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const int32_t * op_params = &op->op_params[0];\n\n    int mode = op_params[2];\n\n    if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {\n        return false;\n    }\n    if (mode & 1) {\n        return false;\n    }\n\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n    const struct ggml_tensor * src2 = op->src[2];\n    const struct ggml_tensor * dst  = op;\n\n    if (src0->type != GGML_TYPE_F32) {\n        return false;  // FIXME: add support for GGML_TYPE_F16 for src0\n    }\n    if (dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n    if (src1->type != GGML_TYPE_I32) {\n        return false;\n    }\n    if (src2) {\n        if (src2->type != GGML_TYPE_F32) {\n            return false;\n        }\n        int n_dims = op_params[1];\n        if (src2->ne[0] < (n_dims / 2)) {\n            return false;\n        }\n    }\n\n    if (src2) {\n        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(src2) ||\n            !ggml_is_contiguous(dst)) {\n            return false;\n        }\n    } else {\n        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * src1 = op->src[1];\n    const struct ggml_tensor * dst  = op;\n\n    // Only support FP32 for now\n    if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || 
dst->type != GGML_TYPE_F32) {\n        return false;\n    }\n\n    // Check IO tensor shapes and dims\n    if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) {\n        return false; // src0 should be effectively 3D\n    }\n\n    const int d_conv = src1->ne[0];\n    const int d_inner = src0->ne[1];\n    const int n_t = dst->ne[1];\n    const int n_s = dst->ne[2];\n\n    if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) {\n        return false;\n    }\n    if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) {\n        return false;\n    }\n    if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) {\n        return false;\n    }\n\n    // TODO: add support for non-contiguous tensors\n    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {\n        return false;\n    }\n\n    return true;\n}\n\nenum dspqbuf_type {\n    DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,\n    DSPQBUF_TYPE_CPU_WRITE_DSP_READ,\n    DSPQBUF_TYPE_CONSTANT,\n};\n\nstatic void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) {\n    if (opt_verbose < 2) return;\n\n    auto buf  = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);\n    auto sess = buf->sess;\n\n    GGML_LOG_DEBUG(\"ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\\n\", sess->name.c_str(),\n                t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset,\n                (unsigned int) d->size);\n}\n\n// Init hexagon tensor from GGML tensor and Hexagon buffer\nstatic void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) {\n    h->data  = 0;  // updated by the receiver\n    h->type  = t->type;\n    h->ne[0] = t->ne[0];\n    h->ne[1] = t->ne[1];\n    h->ne[2] = t->ne[2];\n    h->ne[3] = t->ne[3];\n    h->nb[0] = t->nb[0];\n    h->nb[1] = t->nb[1];\n    h->nb[2] = t->nb[2];\n    h->nb[3] = 
t->nb[3];\n}\n\nstatic size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) {\n    if (!t) {\n        return 0;\n    }\n\n    auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);\n\n    memset(d, 0, sizeof(*d));\n    d->fd     = buf->fd;\n    d->ptr    = t->data;\n    d->offset = (uint8_t *) t->data - buf->base;\n    d->size   = ggml_nbytes(t);\n\n    if (!d->size) {\n        // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty\n        d->size = 64;\n    }\n\n    switch (type) {\n        case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:\n            // Flush CPU\n            d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER;\n            break;\n        case DSPQBUF_TYPE_CPU_WRITE_DSP_READ:\n            // Flush CPU, Invalidate DSP\n            d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;\n            break;\n        default:\n            // Constant buffer, no cache maintenance\n            d->flags = 0;\n            break;\n    }\n\n    htp_req_tensor_init(h, t);\n\n    dspqbuf_dump(d, t, type);\n\n    return 1;\n}\n\ntypedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op);\n\ntemplate <htp_req_init_func_t _init_req_func>\nstatic inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) {\n    uint64_t t = ggml_time_us();\n\n    // Construct HTP request\n    htp_general_req req;\n    memset(&req, 0, sizeof(req));\n\n    req.flags = flags;\n    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {\n        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;\n    }\n    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {\n        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;\n    }\n\n    ggml_hexagon_dump_op_exec(sess->name, op, req.flags);\n\n    if ((opt_opmask & HTP_OPMASK_QUEUE)) {\n        dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];\n        
size_t n_bufs = _init_req_func(&req, bufs, op);\n        sess->enqueue(req, bufs, n_bufs, opt_opsync);\n    }\n\n    t = ggml_time_us() - t;\n\n    ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t);\n}\n\ntemplate <bool _is_src0_constant>\nstatic inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    switch (t->op) {\n        case GGML_OP_MUL_MAT:\n            req->op = HTP_OP_MUL_MAT;\n            break;\n        case GGML_OP_MUL:\n            req->op = HTP_OP_MUL;\n            break;\n        case GGML_OP_ADD:\n            req->op = HTP_OP_ADD;\n            break;\n        case GGML_OP_SUB:\n            req->op = HTP_OP_SUB;\n            break;\n        case GGML_OP_DIV:\n            req->op = HTP_OP_DIV;\n            break;\n        default:\n            GGML_ABORT(\"ggml-hex: binary : unsupported op: %d\\n\", t->op);\n            break;\n    }\n\n    // src0: Weights (mulmat) or First Operand (binary op).\n    // If constant (e.g. weights), no cache management is needed.\n    // src1: Input Activations (mulmat) or Second Operand (binary op).\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? 
DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    req->op = HTP_OP_CPY;\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    req->op = HTP_OP_GET_ROWS;\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    req->op = HTP_OP_ARGSORT;\n    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\ntemplate <bool _is_src0_constant>\nstatic inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    switch (t->op) {\n        case GGML_OP_MUL_MAT_ID:\n            req->op = HTP_OP_MUL_MAT_ID;\n            break;\n        case GGML_OP_ADD_ID:\n            req->op = 
HTP_OP_ADD_ID;\n            break;\n        default:\n            GGML_ABORT(\"ggml-hex: unsupported op: %d\\n\", t->op);\n    }\n\n    // src0: Weights (mulmat) or Input Activations (other op).\n    // If constant, no cache management is needed.\n    // src1: Input Activations (mulmat) or Second Operand (binary op).\n    // src2: Expert IDs (mulmat) or Activated Experts (other op).\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    req->op = HTP_OP_SET_ROWS;\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));\n\n    bool supported = false;\n\n    switch (t->op) {\n        case GGML_OP_RMS_NORM:\n            req->op   = HTP_OP_RMS_NORM;\n            supported = true;\n            break;\n\n        case GGML_OP_SCALE:\n            req->op   = HTP_OP_SCALE;\n            supported = true;\n            break;\n\n        case GGML_OP_SQR:\n            req->op   = HTP_OP_SQR;\n            supported = true;\n          
  break;\n\n        case GGML_OP_SQRT:\n            req->op   = HTP_OP_SQRT;\n            supported = true;\n            break;\n\n        case GGML_OP_UNARY:\n            if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {\n                req->op   = HTP_OP_UNARY_SILU;\n                supported = true;\n            } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {\n                req->op   = HTP_OP_UNARY_GELU;\n                supported = true;\n            }\n            break;\n\n        case GGML_OP_GLU:\n            if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) {\n                req->op   = HTP_OP_GLU_SWIGLU;\n                supported = true;\n            } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {\n                req->op   = HTP_OP_GLU_SWIGLU_OAI;\n                supported = true;\n            } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {\n                req->op   = HTP_OP_GLU_GEGLU;\n                supported = true;\n            }\n            break;\n\n        case GGML_OP_SOFT_MAX:\n            req->op   = HTP_OP_SOFTMAX;\n            supported = true;\n            break;\n\n        default:\n            break;\n    }\n\n    if (!supported) {\n        GGML_ABORT(\"ggml-hex: unary : unsupported op: %d\\n\", t->op);\n    }\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));\n    req->op = HTP_OP_SUM_ROWS;\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n  
  n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));\n    req->op = HTP_OP_ROPE;\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));\n    req->op = HTP_OP_FLASH_ATTN_EXT;\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {\n    req->op = HTP_OP_SSM_CONV;\n\n    size_t n_bufs = 0;\n    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);\n    n_bufs 
+= htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT);\n    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);\n\n    return n_bufs;\n}\n\nstatic const char * ggml_backend_hexagon_name(ggml_backend_t backend) {\n    auto sess = static_cast<ggml_hexagon_session *>(backend->context);\n    return sess->name.c_str();\n}\n\nstatic void ggml_backend_hexagon_free(ggml_backend_t backend) {\n    // we just need to delete the backend here\n    // the sessions are allocated & freed as part of the registry\n    delete backend;\n}\n\nstatic inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {\n    return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));\n}\n\nstatic inline bool is_compute_op(ggml_tensor *node)\n{\n    return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);\n}\n\n// scan the graph and figure out last compute op index\nstatic inline int last_compute_op(ggml_cgraph * graph) {\n    int last = 0;\n    for (int i = 0; i < graph->n_nodes; ++i) {\n        if (is_compute_op(graph->nodes[i])) {\n            last = i;\n        }\n    }\n\n    return last;\n}\n\nstatic ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {\n    auto sess = static_cast<ggml_hexagon_session *>(backend->context);\n\n    HEX_VERBOSE(\"ggml-hex: %s graph-compute n_nodes %d\\n\", sess->name.c_str(), graph->n_nodes);\n\n    const int last = last_compute_op(graph);\n\n    const struct ggml_tensor * prev_op = nullptr;  // prev executed op\n\n    for (int i = 0; i < graph->n_nodes; ++i) {\n        ggml_tensor * node = graph->nodes[i];\n\n        if (!is_compute_op(node)) {\n            continue;\n        }\n\n        uint32_t flags = 0;\n\n        // skip quantizer if src1 is reused\n        if (op_reuse_src1(node, prev_op)) {\n            flags |= HTP_OPFLAGS_SKIP_QUANTIZE;\n        
}\n\n        prev_op = node;\n\n        // ask for early notification for the last Op\n        if (i == last) {\n            flags |= HTP_OPFLAGS_EARLY_WAKEUP;\n        }\n\n        switch (node->op) {\n            case GGML_OP_MUL_MAT:\n                if (ggml_is_quantized(node->src[0]->type)) {\n                    ggml_hexagon_dispatch_op<init_binary_req<true>>(sess, node, flags);\n                } else {\n                    ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);\n                }\n                break;\n            case GGML_OP_MUL_MAT_ID:\n                if (ggml_is_quantized(node->src[0]->type)) {\n                    ggml_hexagon_dispatch_op<init_binary_id_req<true>>(sess, node, flags);\n                } else {\n                    ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);\n                }\n                break;\n            case GGML_OP_MUL:\n            case GGML_OP_ADD:\n            case GGML_OP_SUB:\n            case GGML_OP_DIV:\n                ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);\n                break;\n            case GGML_OP_ADD_ID:\n                ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);\n                break;\n            case GGML_OP_RMS_NORM:\n            case GGML_OP_SCALE:\n                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);\n                break;\n            case GGML_OP_SQR:\n            case GGML_OP_SQRT:\n                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);\n                break;\n            case GGML_OP_SUM_ROWS:\n                ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);\n                break;\n            case GGML_OP_UNARY:\n                if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||\n                        (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {\n                    
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);\n                }\n                break;\n            case GGML_OP_GLU:\n                if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||\n                        (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||\n                        (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {\n                    ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);\n                }\n                break;\n            case GGML_OP_SOFT_MAX:\n                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);\n                break;\n\n            case GGML_OP_ROPE:\n                ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);\n                break;\n\n            case GGML_OP_FLASH_ATTN_EXT:\n                ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);\n                break;\n\n            case GGML_OP_SET_ROWS:\n                ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);\n                break;\n\n            case GGML_OP_GET_ROWS:\n                ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);\n                break;\n\n            case GGML_OP_CPY:\n                ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);\n                break;\n\n            case GGML_OP_ARGSORT:\n                ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);\n                break;\n\n            case GGML_OP_SSM_CONV:\n                ggml_hexagon_dispatch_op<init_ssm_conv_req>(sess, node, flags);\n                break;\n\n            default:\n                GGML_ABORT(\"\\nggml-hex: graph-compute %s is not supported\\n\", ggml_op_desc(node));\n        }\n    }\n\n    // Wait until all pending ops complete\n    sess->flush();\n\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {\n    auto sess = static_cast<ggml_hexagon_session 
*>(backend->context);\n\n    HEX_VERBOSE(\"ggml-hex: %s synchronize\\n\", sess->name.c_str());\n\n    // Wait until all pending ops complete\n    sess->flush();\n}\n\nstruct node_info {\n    ggml_tensor * node;\n\n    std::vector<ggml_tensor *> fused;\n\n    ggml_op op() const {\n        return node->op;\n    }\n\n    const ggml_tensor * dst() const {\n        return fused.empty() ? node : fused.back();\n    }\n\n    const ggml_tensor * src0() const {\n        return node->src[0];\n    }\n\n    const ggml_tensor * src1() const {\n        return node->src[1];\n    }\n\n    bool is_empty() const {\n        return ggml_op_is_empty(node->op);\n    }\n\n    void add_fused(ggml_tensor * t) {\n        fused.push_back(t);\n    }\n\n    bool stackable() const {\n        switch (this->op()) {\n            case GGML_OP_MUL_MAT:\n            case GGML_OP_MUL_MAT_ID:\n                return ggml_is_quantized(this->src0()->type);\n            default:\n                return false;\n        }\n    }\n\n    bool same_input(const node_info& n) const {\n        return n.src1() == this->src1();\n    }\n};\n\nstatic std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<node_info> & nodes) {\n    const int n = nodes.size();\n\n    std::vector<int> res;\n    res.reserve(n);\n\n    std::vector<bool> used(n, false);\n\n    // The main goal here is to stack the MUL_MAT ops with the same src1 input.\n    // This allows use to reuse dynamically quantized src1 in VTCM.\n\n    // TODO: the current version might do incorrect reordering in cases where quantized src0\n    //       input is an output of another Op.\n\n    for (int i0 = 0; i0 < n; i0++) {\n        if (used[i0]) {\n            continue;\n        }\n\n        res.push_back(i0);\n\n        const auto & node0 = nodes[i0];\n\n        if (!node0.stackable()) {\n            continue;\n        }\n\n        // that many nodes forward to search for stackable nodes that can reuse VTCM\n        constexpr int N_FORWARD = 
16;\n\n        for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {\n            if (used[i1]) {\n                continue;\n            }\n\n            const auto & node1 = nodes[i1];\n\n            if (node1.stackable() && node1.same_input(node0)) {\n                res.push_back(i1);\n                used[i1] = true;\n            }\n        }\n    }\n\n    return res;\n}\n\nstatic void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgraph * gf) {\n    const int n = gf->n_nodes;\n\n    constexpr int MAX_FUSE = 16;\n\n    enum ggml_op ops[MAX_FUSE];\n\n    std::vector<node_info> nodes;\n    nodes.reserve(gf->n_nodes);\n\n    // fuse nodes:\n    // we don't want to make reorders that break fusing, so we first pack all fusable tensors\n    //   and perform the reorder over the fused nodes. after the reorder is done, we unfuse\n    for (int i = 0; i < n; i++) {\n        node_info node = {\n            /*.node =*/gf->nodes[i],\n            /*.fused =*/{},\n        };\n\n        // fuse only ops that start with these operations\n        // can be expanded when needed\n        if (node.op() == GGML_OP_ADD ||\n            node.op() == GGML_OP_NORM ||\n            node.op() == GGML_OP_RMS_NORM) {\n            ops[0] = node.op();\n\n            int f = i + 1;\n            while (f < n && f < i + MAX_FUSE) {\n                // conservatively allow fusing only these ops\n                // can be expanded when needed\n                if (gf->nodes[f]->op != GGML_OP_ADD &&\n                    gf->nodes[f]->op != GGML_OP_MUL &&\n                    gf->nodes[f]->op != GGML_OP_NORM &&\n                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {\n                    break;\n                }\n                ops[f - i] = gf->nodes[f]->op;\n                f++;\n            }\n\n            f -= i;\n            for (; f > 1; f--) {\n                if (ggml_can_fuse(gf, i, ops, f)) {\n                    break;\n                }\n            }\n\n    
        // add the fused tensors into the node info so we can unfuse them later\n            for (int k = 1; k < f; k++) {\n                ++i;\n\n                // the .dst() becomes the last fused tensor\n                node.add_fused(gf->nodes[i]);\n            }\n        }\n\n        nodes.push_back(std::move(node));\n    }\n\n    const auto order = ggml_hexagon_graph_optimize_reorder(nodes);\n\n    // unfuse\n    {\n        int j = 0;\n        for (const auto i : order) {\n            const auto & node = nodes[i];\n\n            gf->nodes[j++] = node.node;\n\n            for (auto * fused : node.fused) {\n                gf->nodes[j++] = fused;\n            }\n        }\n    }\n}\n\nstatic struct ggml_backend_i hexagon_backend_i = {\n    /* .get_name                = */ ggml_backend_hexagon_name,\n    /* .free                    = */ ggml_backend_hexagon_free,\n    /* .set_tensor_async        = */ NULL,\n    /* .get_tensor_async        = */ NULL,\n    /* .cpy_tensor_async        = */ NULL,\n    /* .synchronize             = */ ggml_backend_hexagon_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_hexagon_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ ggml_backend_hexagon_graph_optimize,\n};\n\nstatic ggml_guid_t ggml_backend_hexagon_guid() {\n    static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49,\n                              0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11 };\n    return &guid;\n}\n\nbool ggml_backend_is_hexagon(ggml_backend_t backend) {\n    return backend && backend->iface.get_name == ggml_backend_hexagon_name;\n}\n\n// device interface\n\nstatic ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, const char * params) 
{\n    auto sess = static_cast<ggml_hexagon_session *>(dev->context);\n\n    return new ggml_backend{\n        /* .guid      = */ ggml_backend_hexagon_guid(),\n        /* .interface = */ hexagon_backend_i,\n        /* .device    = */ dev,\n        /* .context   = */ sess,\n    };\n\n    GGML_UNUSED(params);\n}\n\nstatic const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) {\n    auto sess = static_cast<ggml_hexagon_session *>(dev->context);\n    return sess->name.c_str();\n\n    GGML_UNUSED(dev);\n}\n\nstatic const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev_t dev) {\n    return \"Hexagon\";\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    // ~2GB per session for now\n    *free  = 2ULL * 1024 * 1024 * 1024;\n    *total = *free;\n\n    GGML_UNUSED(dev);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_hexagon_device_get_type(ggml_backend_dev_t dev) {\n    return GGML_BACKEND_DEVICE_TYPE_GPU;\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_hexagon_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_hexagon_device_get_name(dev);\n    props->description = ggml_backend_hexagon_device_get_description(dev);\n    props->type        = ggml_backend_hexagon_device_get_type(dev);\n    ggml_backend_hexagon_device_get_memory(dev, &props->memory_free, &props->memory_total);\n    props->caps = {\n        /* .async                 = */ true,\n        /* .host_buffer           = */ (bool) opt_hostbuf,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ false,\n    };\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_buffer_type(ggml_backend_dev_t dev) {\n    auto sess = static_cast<ggml_hexagon_session *>(dev->context);\n    return &sess->buffer_type;\n}\n\nstatic ggml_backend_buffer_type_t 
ggml_backend_hexagon_device_get_repack_buffer_type(ggml_backend_dev_t dev) {\n    auto sess = static_cast<ggml_hexagon_session *>(dev->context);\n    return &sess->repack_buffer_type;\n}\n\nstatic bool ggml_hexagon_supported_buffer(ggml_hexagon_session *sess, const struct ggml_tensor * t) {\n    if (t && t->buffer) {\n        if (ggml_backend_buffer_is_hexagon(t->buffer)      == false) return false; // not our buffer\n        if (ggml_backend_hexagon_buffer_get_sess(t->buffer) != sess) return false; // wrong session\n    }\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const struct ggml_tensor * t) {\n    // all srcs & dsts must be mapped to the same session\n    if (!ggml_hexagon_supported_buffer(sess, t)) {\n        return false;\n    }\n\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (!ggml_hexagon_supported_buffer(sess, t->src[i])) {\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {\n    const struct ggml_tensor * src0 = op->src[0];\n    const struct ggml_tensor * dst  = op;\n\n    // for now we can do f32 -> f16 and f16 -> f32 (without reshaping)\n    if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;\n    if ( dst->type != GGML_TYPE_F32 &&  dst->type != GGML_TYPE_F16) return false;\n\n    const bool sametype   = (src0->type == dst->type);\n    const bool transposed = ggml_is_transposed(src0) || ggml_is_transposed(dst);\n    const bool sameshape  = !transposed && ggml_are_same_shape(src0, dst);\n\n    // can handle any shape and any same-type (pretty slow if reshaping is required)\n    if (sametype) return true;\n\n    // cannot handle re-shaping and type conversion at the same time\n    if (!sameshape) return false;\n\n    return true;\n}\n\nstatic bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * 
op) {\n    auto sess = static_cast<ggml_hexagon_session *>(dev->context);\n\n    // all srcs & dsts must be mapped to the same session\n    if (!ggml_hexagon_supported_buffers(sess, op)) {\n        ggml_hexagon_dump_op_supp(sess->name, op, false);\n        return false;\n    }\n\n    bool supp = false;\n    switch (op->op) {\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n            supp = true;\n            break;\n\n        case GGML_OP_MUL_MAT:\n            supp = ggml_hexagon_supported_mul_mat(sess, op);\n            break;\n\n        case GGML_OP_MUL_MAT_ID:\n            supp = ggml_hexagon_supported_mul_mat_id(sess, op);\n            break;\n\n        case GGML_OP_MUL:\n        case GGML_OP_ADD:\n        case GGML_OP_SUB:\n        case GGML_OP_DIV:\n            supp = ggml_hexagon_supported_binary(sess, op);\n            break;\n\n        case GGML_OP_ADD_ID:\n            supp = ggml_hexagon_supported_add_id(sess, op);\n            break;\n\n        case GGML_OP_RMS_NORM:\n        case GGML_OP_SCALE:\n            supp = ggml_hexagon_supported_unary(sess, op);\n            break;\n\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n            supp = ggml_hexagon_supported_unary(sess, op);\n            break;\n\n        case GGML_OP_SUM_ROWS:\n            supp = ggml_hexagon_supported_sum_rows(sess, op);\n            break;\n\n        case GGML_OP_SOFT_MAX:\n            supp = ggml_hexagon_supported_softmax(sess, op);\n            break;\n\n        case GGML_OP_UNARY:\n            {\n                const auto unary_op = ggml_get_unary_op(op);\n                if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {\n                    supp = ggml_hexagon_supported_activations(sess, op);\n                }\n                break;\n            }\n        case GGML_OP_GLU:\n            {\n                const auto glu_op = 
ggml_get_glu_op(op);\n                if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {\n                    supp = ggml_hexagon_supported_activations(sess, op);\n                }\n                break;\n            }\n        case GGML_OP_ROPE:\n            supp = ggml_hexagon_supported_rope(sess, op);\n            break;\n\n        case GGML_OP_FLASH_ATTN_EXT:\n            supp = ggml_hexagon_supported_flash_attn_ext(sess, op);\n            break;\n\n        case GGML_OP_SET_ROWS:\n            supp = ggml_hexagon_supported_set_rows(sess, op);\n            break;\n\n        case GGML_OP_GET_ROWS:\n            supp = ggml_hexagon_supported_get_rows(sess, op);\n            break;\n\n        case GGML_OP_CPY:\n            supp = ggml_hexagon_supported_cpy(sess, op);\n            break;\n\n        case GGML_OP_ARGSORT:\n            supp = ggml_hexagon_supported_argsort(sess, op);\n            break;\n\n        case GGML_OP_SSM_CONV:\n            supp = ggml_hexagon_supported_ssm_conv(sess, op);\n            break;\n\n        default:\n            break;\n    }\n\n    ggml_hexagon_dump_op_supp(sess->name, op, supp);\n    return supp;\n}\n\nstatic bool ggml_backend_hexagon_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    if (buft->iface.get_alignment != ggml_backend_hexagon_buffer_type_get_alignment) {\n        return false;\n    }\n\n    auto s0 = static_cast<ggml_hexagon_session *>(dev->context);\n    auto s1 = static_cast<ggml_backend_hexagon_buffer_type_context *>(buft->context)->sess;\n\n    // Need session/domain-id for buffers to be compatible\n    bool supp = (s0->session_id == s1->session_id);\n\n    HEX_VERBOSE(\"ggml-hex: %s device-supports-buft %s (%d)\\n\", s0->name.c_str(), s1->name.c_str(), (int) supp);\n\n    return supp;\n}\n\nstatic ggml_backend_buffer_type_t * ggml_backend_hexagon_device_get_extra_buffers_type(ggml_backend_dev_t dev) {\n    auto s0 = 
static_cast<ggml_hexagon_session *>(dev->context);\n    HEX_VERBOSE(\"ggml-hex: device-get-extra-buft : %s \\n\", s0->name.c_str());\n\n    static ggml_backend_buffer_type_t bufts[2];\n    bufts[0] = ggml_backend_hexagon_device_get_repack_buffer_type(dev);\n    bufts[1] = NULL;\n    return bufts;\n}\n\nstatic const struct ggml_backend_device_i ggml_backend_hexagon_device_i = {\n    /* .get_name             = */ ggml_backend_hexagon_device_get_name,\n    /* .get_description      = */ ggml_backend_hexagon_device_get_description,\n    /* .get_memory           = */ ggml_backend_hexagon_device_get_memory,\n    /* .get_type             = */ ggml_backend_hexagon_device_get_type,\n    /* .get_props            = */ ggml_backend_hexagon_device_get_props,\n    /* .init_backend         = */ ggml_backend_hexagon_device_init,\n    /* .get_buffer_type      = */ ggml_backend_hexagon_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,  // ggml_backend_hexagon_device_get_host_buffer_type,\n    /* .buffer_from_host_ptr = */ NULL,  // ggml_backend_hexagon_device_buffer_from_ptr,\n    /* .supports_op          = */ ggml_backend_hexagon_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_hexagon_device_supports_buft,\n    /* .offload_op           = */ NULL,  // ggml_backend_hexagon_device_offload_op,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ NULL,\n};\n\n//** backend registry\n\n#define GGML_HEXAGON_MAX_SESSIONS 16\n\nstruct ggml_hexagon_registry {\n    ggml_hexagon_registry(ggml_backend_reg_t reg);\n    ~ggml_hexagon_registry();\n\n    ggml_backend_device devices[GGML_HEXAGON_MAX_SESSIONS];\n};\n\nggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {\n    GGML_LOG_INFO(\"ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\\n\", opt_ndev);\n\n    if (!opt_arch) {\n        int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);\n        if (err 
!= 0) {\n            GGML_LOG_ERROR(\"ggml-hex: failed to query HTP version (err %d) defaulting to v73\\n\", err);\n            opt_arch = 73;\n        }\n    }\n\n#if defined(__ANDROID__)\n    if (opt_arch < 75) {\n        opt_ndev = 1;\n        GGML_LOG_WARN(\"ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\\n\");\n    }\n#endif\n\n    GGML_LOG_INFO(\"ggml-hex: Hexagon Arch version v%d\\n\", opt_arch);\n\n    // Create devices / sessions\n    for (size_t i = 0; i < opt_ndev; i++) {\n        devices[i].iface = ggml_backend_hexagon_device_i;\n        devices[i].reg   = reg;\n        try {\n            devices[i].context = new ggml_hexagon_session(i, &devices[i]);\n        } catch (const std::exception & exc) {\n            GGML_LOG_ERROR(\"ggml-hex: failed to create device/session %zu\\n\", i);\n            devices[i].context = nullptr;\n        }\n    }\n}\n\nggml_hexagon_registry::~ggml_hexagon_registry() {\n    GGML_LOG_INFO(\"ggml-hex: releasing registry\\n\");\n\n    // Release devices / sessions\n    for (size_t i = 0; i < opt_ndev; i++) {\n        auto sess = static_cast<ggml_hexagon_session *>(devices[i].context);\n        delete sess;\n    }\n}\n\nstatic const char * ggml_backend_hexagon_reg_get_name(ggml_backend_reg_t reg) {\n    return \"HTP\";\n    GGML_UNUSED(reg);\n}\n\nstatic size_t ggml_backend_hexagon_reg_get_device_count(ggml_backend_reg_t reg) {\n    return opt_ndev;\n    GGML_UNUSED(reg);\n}\n\nstatic ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    auto hreg = static_cast<ggml_hexagon_registry *>(reg->context);\n\n    if (index >= opt_ndev || !hreg->devices[index].context) {\n        return nullptr;\n    }\n\n    return &hreg->devices[index];\n}\n\nstatic void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    if (strcmp(name, \"ggml_backend_dev_get_extra_bufts\") == 0 && opt_hostbuf) {\n        ggml_backend_dev_get_extra_bufts_t fct = 
ggml_backend_hexagon_device_get_extra_buffers_type;\n        return (void *) fct;\n    }\n\n    return NULL;\n}\n\nstatic void ggml_hexagon_init(ggml_backend_reg * reg) {\n    // Basic sanity checks to make sure definitions match\n    static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0,\n                  \"please update hexagon_type to match ggml_type\");\n    static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0,\n                  \"please update hexagon_type to match ggml_type\");\n    static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,\n                  \"please update hexagon_type to match ggml_type\");\n\n    const char * str_experimental = getenv(\"GGML_HEXAGON_EXPERIMENTAL\");\n    const char * str_verbose = getenv(\"GGML_HEXAGON_VERBOSE\");\n    const char * str_hostbuf = getenv(\"GGML_HEXAGON_HOSTBUF\");\n    const char * str_opmask  = getenv(\"GGML_HEXAGON_OPMASK\");\n    const char * str_opsync  = getenv(\"GGML_HEXAGON_OPSYNC\");\n    const char * str_profile = getenv(\"GGML_HEXAGON_PROFILE\");\n    const char * str_etm     = getenv(\"GGML_HEXAGON_ETM\");\n    const char * str_nhvx    = getenv(\"GGML_HEXAGON_NHVX\");\n    const char * str_ndev    = getenv(\"GGML_HEXAGON_NDEV\");\n    const char * str_arch    = getenv(\"GGML_HEXAGON_ARCH\");\n\n    opt_experimental = str_experimental ? atoi(str_experimental) : 0;\n    opt_verbose      = str_verbose ? atoi(str_verbose) : 0;\n    opt_hostbuf      = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;\n    opt_opmask       = str_opmask  ? strtoul(str_opmask, NULL, 0) : opt_opmask;\n    opt_opsync       = str_opsync  ? atoi(str_opsync)  : 0;\n    opt_profile      = str_profile ? atoi(str_profile) : 0;\n    opt_etm          = str_etm     ? atoi(str_etm) : 0;\n    opt_nhvx         = str_nhvx    ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;\n    opt_ndev         = str_ndev    ? 
strtoul(str_ndev, NULL, 0) : opt_ndev;\n\n    if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {\n        opt_ndev = GGML_HEXAGON_MAX_SESSIONS;\n    }\n\n    if (str_arch) {\n        if (str_arch[0] == 'v') {\n            str_arch++;\n        }\n        opt_arch = strtoul(str_arch, NULL, 0);\n    }\n\n    opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;\n\n    reg->context = new ggml_hexagon_registry(reg);\n\n    HEX_VERBOSE(\"ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\\n\", sizeof(struct htp_general_req),\n                sizeof(struct htp_general_rsp));\n}\n\nstatic const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = {\n    /* .get_name         = */ ggml_backend_hexagon_reg_get_name,\n    /* .get_device_count = */ ggml_backend_hexagon_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_hexagon_reg_get_device,\n    /* .get_proc_address = */ ggml_backend_hexagon_get_proc_address,\n};\n\nggml_backend_reg_t ggml_backend_hexagon_reg(void) {\n    static bool initialized = false;\n\n    static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION,\n                                    /* .iface       = */ ggml_backend_hexagon_reg_i,\n                                    /* .context     = */ NULL };\n\n    {\n        static std::mutex           mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n        if (!initialized) {\n            auto nErr = htpdrv_init();\n            if (nErr != AEE_SUCCESS) {\n                return NULL;\n            }\n\n            ggml_hexagon_init(&reg);\n        }\n\n        initialized = true;\n    }\n\n    return &reg;\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_hexagon_reg)\n"
  },
  {
    "path": "src/ggml-hexagon/htp/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.22.2)\nproject(ggml-htp C CXX ASM)\n\ninclude(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)\n\ninclude_directories(\n    ${HEXAGON_SDK_ROOT}/incs\n    ${HEXAGON_SDK_ROOT}/incs/stddef\n    ${CMAKE_CURRENT_SOURCE_DIR}/../../../include\n    ${CMAKE_CURRENT_SOURCE_DIR}/../..\n    ${CMAKE_CURRENT_SOURCE_DIR}/..\n    ${CMAKE_CURRENT_SOURCE_DIR}\n    ${CMAKE_CURRENT_BINARY_DIR})\n\nset(HTP_LIB ggml-htp-${DSP_VERSION})\n\nadd_library(${HTP_LIB} SHARED\n    main.c\n    htp_iface_skel.c\n    worker-pool.c\n    hex-dma.c\n    matmul-ops.c\n    binary-ops.c\n    unary-ops.c\n    sum-rows-ops.c\n    softmax-ops.c\n    act-ops.c\n    rope-ops.c\n    flash-attn-ops.c\n    set-rows-ops.c\n    get-rows-ops.c\n    cpy-ops.c\n    argsort-ops.c\n    ssm-conv.c\n)\n\ntarget_compile_definitions(${HTP_LIB} PRIVATE\n    $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>\n    $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>\n    FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})\n\nbuild_idl(htp_iface.idl ${HTP_LIB})\n\nset_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)\n\ninstall(TARGETS ${HTP_LIB})\n"
  },
  {
    "path": "src/ggml-hexagon/htp/act-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n#define htp_act_preamble3              \\\n    const uint32_t ne00 = src0->ne[0]; \\\n    const uint32_t ne01 = src0->ne[1]; \\\n    const uint32_t ne02 = src0->ne[2]; \\\n    const uint32_t ne03 = src0->ne[3]; \\\n                                       \\\n    const uint32_t ne10 = src1->ne[0]; \\\n    const uint32_t ne11 = src1->ne[1]; \\\n    const uint32_t ne12 = src1->ne[2]; \\\n    const uint32_t ne13 = src1->ne[3]; \\\n                                       \\\n    const uint32_t ne0 = dst->ne[0];   \\\n    const uint32_t ne1 = dst->ne[1];   \\\n    const uint32_t ne2 = dst->ne[2];   \\\n    const uint32_t ne3 = dst->ne[3];   \\\n                                       \\\n    const uint32_t nb00 = src0->nb[0]; \\\n    const uint32_t nb01 = src0->nb[1]; \\\n    const uint32_t nb02 = src0->nb[2]; \\\n    const uint32_t nb03 = src0->nb[3]; \\\n                                       \\\n    const uint32_t nb10 = src1->nb[0]; \\\n    const uint32_t nb11 = src1->nb[1]; \\\n    const uint32_t nb12 = src1->nb[2]; \\\n    const uint32_t nb13 = src1->nb[3]; \\\n                                       \\\n    const uint32_t nb0 = dst->nb[0];   \\\n    const uint32_t nb1 = dst->nb[1];   \\\n    const uint32_t nb2 = dst->nb[2];   \\\n    const uint32_t nb3 = dst->nb[3];\n\n#define htp_act_preamble2              \\\n    const uint32_t ne00 = src0->ne[0]; \\\n    const uint32_t ne01 = src0->ne[1]; \\\n    const uint32_t ne02 = src0->ne[2]; \\\n    const uint32_t ne03 = src0->ne[3]; \\\n                                
       \\\n    const uint32_t ne0 = dst->ne[0];   \\\n    const uint32_t ne1 = dst->ne[1];   \\\n    const uint32_t ne2 = dst->ne[2];   \\\n    const uint32_t ne3 = dst->ne[3];   \\\n                                       \\\n    const uint32_t nb00 = src0->nb[0]; \\\n    const uint32_t nb01 = src0->nb[1]; \\\n    const uint32_t nb02 = src0->nb[2]; \\\n    const uint32_t nb03 = src0->nb[3]; \\\n                                       \\\n    const uint32_t nb0 = dst->nb[0];   \\\n    const uint32_t nb1 = dst->nb[1];   \\\n    const uint32_t nb2 = dst->nb[2];   \\\n    const uint32_t nb3 = dst->nb[3];\n\nstruct htp_act_context {\n    struct htp_ops_context *  octx;\n\n    // Precomputed values\n    const uint8_t *           data_src0;\n    const uint8_t *           data_src1;\n    uint8_t *                 data_dst;\n\n    size_t                    src0_row_size;\n    size_t                    src1_row_size;\n    size_t                    dst_row_size;\n\n    size_t                    src0_row_size_aligned;\n    size_t                    src1_row_size_aligned;\n    size_t                    dst_row_size_aligned;\n\n    size_t                    src0_spad_half_size;\n    size_t                    src1_spad_half_size;\n    size_t                    dst_spad_half_size;\n\n    uint32_t                  block;\n    uint32_t                  src0_nrows;\n    uint32_t                  src0_nrows_per_thread;\n    int                       nc;\n};\n\nstatic void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_act_context * actx = (struct htp_act_context *) data;\n    const struct htp_tensor * src0 = &actx->octx->src0;\n    const struct htp_tensor * src1 = &actx->octx->src1;\n    const struct htp_tensor * dst  = &actx->octx->dst;\n    htp_act_preamble3;\n\n    size_t src0_row_size = actx->src0_row_size;\n    size_t src1_row_size = actx->src1_row_size;\n    size_t dst_row_size  = actx->dst_row_size;\n\n    const uint32_t src0_nrows = 
actx->src0_nrows;\n    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;\n    const uint32_t src0_start_row = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const uint8_t * restrict data_src0 = actx->data_src0;\n    const uint8_t * restrict data_src1 = actx->data_src1;\n    uint8_t * restrict data_dst        = actx->data_dst;\n\n    const int  nc = actx->nc;\n\n    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;\n    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;\n    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;\n\n    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);\n    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);\n    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);\n\n    size_t src0_spad_half_size = actx->src0_spad_half_size;\n    size_t src1_spad_half_size = actx->src1_spad_half_size;\n    size_t dst_spad_half_size  = actx->dst_spad_half_size;\n\n    const int BLOCK = actx->block;\n    if (BLOCK == 0) {\n        FARF(ERROR,\n             \"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\\n\",\n             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);\n        return;\n    }\n\n    dma_queue * dma_queue = actx->octx->ctx->dma[ith];\n\n    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379\n    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {\n        const uint32_t block_size = 
MIN(BLOCK, src0_end_row - ir);\n\n        // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)\n        dma_queue_push_vtcm_to_ddr(dma_queue,\n            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),\n            dst_row_size, dst_row_size_aligned, 0);\n\n        dma_queue_push_ddr_to_vtcm(dma_queue,\n            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),\n            src0_row_size_aligned, src0_row_size, block_size);\n        dma_queue_push_ddr_to_vtcm(dma_queue,\n            dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),\n            src1_row_size_aligned, src1_row_size, block_size);\n    }\n\n    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;\n        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;\n        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;\n\n        for (uint32_t ib = 0; ib < block_size; ib++) {\n            const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));\n            const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));\n            float *       dst_spad_ptr  = dst_spad + ib * (dst_row_size_aligned / sizeof(float));\n\n            //swiglu(x) = x1 * sigmoid(x0)\n            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc);\n            hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,\n                                (const uint8_t *) src1_spad_ptr, nc);\n        }\n\n        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,\n                                   dst_row_size_aligned, 
block_size);\n\n        // prefetch N+2 loop iteration if any\n        const uint32_t pref_block = (ir + BLOCK * 2);\n        if (pref_block < src0_end_row) {\n            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),\n                                       src0_row_size_aligned, src0_row_size, pref_block_size);\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),\n                                       src1_row_size_aligned, src1_row_size, pref_block_size);\n        }\n    }\n\n    dma_queue_flush(dma_queue);\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\", ith, nth,\n         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,\n         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\nstatic void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_act_context * actx = (struct htp_act_context *) data;\n    const struct htp_tensor * src0 = &actx->octx->src0;\n    const struct htp_tensor * src1 = &actx->octx->src1;\n    const struct htp_tensor * dst  = &actx->octx->dst;\n    htp_act_preamble3;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    size_t src0_row_size = actx->src0_row_size;\n    size_t src1_row_size = actx->src1_row_size;\n    size_t dst_row_size  = actx->dst_row_size;\n\n    const uint32_t src0_nrows = actx->src0_nrows;\n    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;\n\n    const uint32_t src0_start_row = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        
return;\n    }\n\n    const uint8_t * restrict data_src0 = actx->data_src0;\n    const uint8_t * restrict data_src1 = actx->data_src1;\n    uint8_t * restrict data_dst        = actx->data_dst;\n\n    const int nc = actx->nc;\n\n    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;\n    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;\n    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;\n\n    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);\n    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);\n    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);\n\n    size_t src0_spad_half_size = actx->src0_spad_half_size;\n    size_t src1_spad_half_size = actx->src1_spad_half_size;\n    size_t dst_spad_half_size  = actx->dst_spad_half_size;\n\n    const int BLOCK = actx->block;\n    if (BLOCK == 0) {\n        FARF(ERROR,\n             \"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least \"\n             \"%zu\\n\",\n             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);\n        return;\n    }\n    const float alpha = ((const float *) (actx->octx->op_params))[2];\n    const float limit = ((const float *) (actx->octx->op_params))[3];\n\n    dma_queue * dma_queue = actx->octx->ctx->dma[ith];\n\n    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379\n    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)\n        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),\n           
                        dst_row_size, dst_row_size_aligned, 0);\n\n        dma_queue_push_ddr_to_vtcm(\n            dma_queue,\n            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),\n            src0_row_size_aligned, src0_row_size, block_size);\n        dma_queue_push_ddr_to_vtcm(\n            dma_queue,\n            dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),\n            src1_row_size_aligned, src1_row_size, block_size);\n    }\n\n    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;\n        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;\n        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;\n\n        for (uint32_t ib = 0; ib < block_size; ib++) {\n            const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));\n            const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));\n            float *       dst_spad_ptr  = dst_spad + ib * (dst_row_size_aligned / sizeof(float));\n\n            // x (src0_spad_data) = std::min(src0_p[k], limit);\n            hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc);\n            // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);\n            hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc);\n            // y (src1_spad_data)  = y1 + 1.f\n            hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc);\n            // x1 (dst_spad_data) = alpha * (x)\n            hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc);\n            // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))\n            
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);\n            // out = x * sigmoid(alpha * x) * (y + 1.f)\n            hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,\n                                (const uint8_t *) src1_spad_ptr, nc);\n        }\n\n        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,\n                                   dst_row_size_aligned, block_size);\n\n        // prefetch N+2 loop iteration if any\n        const uint32_t pref_block = (ir + BLOCK * 2);\n        if (pref_block < src0_end_row) {\n            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),\n                                       src0_row_size_aligned, src0_row_size, pref_block_size);\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),\n                                       src1_row_size_aligned, src1_row_size, pref_block_size);\n        }\n    }\n\n    dma_queue_flush(dma_queue);\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"swiglu-oai-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\", ith, nth, src0->ne[0],\n         src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],\n         src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n\nstatic void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_act_context * actx = (struct htp_act_context *) data;\n    const struct htp_tensor * src0 = &actx->octx->src0;\n    const struct htp_tensor * dst  = &actx->octx->dst;\n    htp_act_preamble2;\n\n    uint64_t t1, t2;\n    t1 = 
HAP_perf_get_qtimer_count();\n\n    const size_t src0_row_size = actx->src0_row_size;\n    const size_t dst_row_size  = actx->dst_row_size;\n    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;\n    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;\n\n    const uint32_t src0_nrows = actx->src0_nrows;\n    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;\n\n    const uint32_t src0_start_row = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    const uint8_t * data_src0 = actx->data_src0;\n    uint8_t * data_dst        = actx->data_dst;\n\n    // nc/ne0 matches.\n    const int ne0_val = actx->nc; // == dst->ne[0]\n\n    uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);\n    uint8_t * dst_spad_data  = actx->octx->dst_spad.data  + (ith * actx->octx->dst_spad.size_per_thread);\n\n    size_t src0_spad_half_size = actx->src0_spad_half_size;\n    size_t dst_spad_half_size  = actx->dst_spad_half_size;\n\n    // In gelu = x*sigmoid(x*1.702)\n    const int BLOCK = actx->block;\n\n    if (BLOCK == 0) {\n        FARF(ERROR, \"gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\\n\",\n                actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);\n        return;\n    }\n\n    dma_queue * dma_queue = actx->octx->ctx->dma[ith];\n\n    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379\n    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)\n        dma_queue_push_vtcm_to_ddr(dma_queue,\n        
    dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),\n            dst_row_size, dst_row_size_aligned, 0);\n\n        dma_queue_push_ddr_to_vtcm(dma_queue,\n            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),\n            src0_row_size_aligned, src0_row_size, block_size);\n    }\n\n    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        float* dst_spad  = (float *) dma_queue_pop(dma_queue).src;\n        float* src0_spad = (float *) dma_queue_pop(dma_queue).dst;\n\n        for (uint32_t ib = 0; ib < block_size; ib++) {\n            const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));\n            float* dst_spad_ptr        = dst_spad  + ib * (dst_row_size_aligned  / sizeof(float));\n\n            // gelu = x * sigmoid(1.702 * x) // current implementation\n            hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);\n            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);\n            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);\n        }\n\n        dma_queue_push_vtcm_to_ddr(dma_queue,\n            dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),\n            dst_row_size, dst_row_size_aligned, block_size);\n\n        // prefetch N+2 loop iteration if any\n        const uint32_t pref_block = (ir + BLOCK * 2);\n        if (pref_block < src0_end_row) {\n            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);\n            dma_queue_push_ddr_to_vtcm(dma_queue,\n                dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),\n                src0_row_size_aligned, src0_row_size, pref_block_size);\n        }\n    }\n\n    
dma_queue_flush(dma_queue);\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"gelu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\\n\", ith, nth, ne00, ne01, ne02,\n         ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n\nstatic void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_act_context * actx = (struct htp_act_context *) data;\n    const struct htp_tensor * src0 = &actx->octx->src0;\n    const struct htp_tensor * dst  = &actx->octx->dst;\n    htp_act_preamble2;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const size_t src0_row_size = actx->src0_row_size;\n    const size_t dst_row_size  = actx->dst_row_size;\n    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;\n    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;\n\n    const uint32_t src0_nrows = actx->src0_nrows;\n    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;\n\n    const uint32_t src0_start_row = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    const uint8_t * data_src0 = actx->data_src0;\n    uint8_t * data_dst        = actx->data_dst;\n\n    const int ne0_val = actx->nc; // == dst->ne[0]\n\n    uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);\n    uint8_t * dst_spad_data  = actx->octx->dst_spad.data  + (ith * actx->octx->dst_spad.size_per_thread);\n\n    size_t src0_spad_half_size = actx->src0_spad_half_size;\n    size_t dst_spad_half_size  = actx->dst_spad_half_size;\n\n    const int BLOCK = actx->block;\n\n    if (BLOCK == 0) {\n        FARF(ERROR, \"silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\\n\",\n      
          actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);\n        return;\n    }\n\n    dma_queue * dma_queue = actx->octx->ctx->dma[ith];\n\n    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379\n    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)\n        dma_queue_push_vtcm_to_ddr(dma_queue,\n            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),\n            dst_row_size, dst_row_size_aligned, 0);\n\n        dma_queue_push_ddr_to_vtcm(dma_queue,\n            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),\n            src0_row_size_aligned, src0_row_size, block_size);\n    }\n\n    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        float* dst_spad  = (float *) dma_queue_pop(dma_queue).src;\n        float* src0_spad = (float *) dma_queue_pop(dma_queue).dst;\n\n        for (uint32_t ib = 0; ib < block_size; ib++) {\n            const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));\n            float* dst_spad_ptr        = dst_spad  + ib * (dst_row_size_aligned  / sizeof(float));\n\n            // silu = x * sigmoid(x)\n            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);\n            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);\n        }\n\n        dma_queue_push_vtcm_to_ddr(dma_queue,\n            dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),\n            dst_row_size, dst_row_size_aligned, block_size);\n\n        // prefetch N+2 loop iteration if any\n        
const uint32_t pref_block = (ir + BLOCK * 2);\n        if (pref_block < src0_end_row) {\n            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);\n            dma_queue_push_ddr_to_vtcm(dma_queue,\n                dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),\n                src0_row_size_aligned, src0_row_size, pref_block_size);\n        }\n    }\n\n    dma_queue_flush(dma_queue);\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"silu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\\n\", ith, nth, ne00, ne01, ne02,\n         ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\nstatic const float GELU_COEF_A     = 0.044715f;\nstatic const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;\n\nstatic void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_act_context * actx = (struct htp_act_context *) data;\n    const struct htp_tensor * src0 = &actx->octx->src0;\n    const struct htp_tensor * src1 = &actx->octx->src1;\n    const struct htp_tensor * dst  = &actx->octx->dst;\n    htp_act_preamble3;\n\n    size_t src0_row_size = actx->src0_row_size;\n    size_t src1_row_size = actx->src1_row_size;\n    size_t dst_row_size  = actx->dst_row_size;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const uint32_t src0_nrows = actx->src0_nrows;\n    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;\n\n    const uint32_t src0_start_row = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    const uint8_t * restrict data_src0 = actx->data_src0;\n    const uint8_t * restrict data_src1 = actx->data_src1;\n    uint8_t * restrict data_dst        = actx->data_dst;\n\n    const int nc = 
actx->nc;\n\n    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;\n    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;\n    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;\n\n    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);\n    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);\n    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);\n\n    size_t src0_spad_half_size = actx->src0_spad_half_size;\n    size_t src1_spad_half_size = actx->src1_spad_half_size;\n    size_t dst_spad_half_size  = actx->dst_spad_half_size;\n\n    const int BLOCK = actx->block;\n    if (BLOCK == 0) {\n        FARF(ERROR,\n             \"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\\n\",\n             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);\n        return;\n    }\n\n    dma_queue * dma_queue = actx->octx->ctx->dma[ith];\n\n    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379\n    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)\n        dma_queue_push_vtcm_to_ddr(dma_queue,\n            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),\n            dst_row_size, dst_row_size_aligned, 0);\n\n        dma_queue_push_ddr_to_vtcm(dma_queue,\n            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),\n            src0_row_size_aligned, src0_row_size, block_size);\n        dma_queue_push_ddr_to_vtcm(dma_queue,\n            dma_make_ptr(src1_spad_data + (spad_idx * 
src1_spad_half_size), data_src1 + (ir * src1_row_size)),\n            src1_row_size_aligned, src1_row_size, block_size);\n    }\n\n    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;\n        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;\n        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;\n\n        for (uint32_t ib = 0; ib < block_size; ib++) {\n            const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));\n            const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));\n            uint8_t *       dst_spad_ptr  = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));\n\n            // geglu tanh implementation\n            // geglu(x, g) = gelu(x) * g\n            // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))\n            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc);                       // res = x*x\n            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc);   // res = res * GELU_COEF_A\n            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc);          // res = res + 1.0f\n            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc);       // res = res * x\n            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI\n            hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);         // res = tanh(res)\n            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc);           // res = res + 1.0f\n            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t 
*)dst_spad_ptr, nc);       // res = res * x\n            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc);          // res = res + 0.5f\n            hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc);       // res = res * g\n        }\n\n        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,\n                                   dst_row_size_aligned, block_size);\n\n        // prefetch N+2 loop iteration if any\n        const uint32_t pref_block = (ir + BLOCK * 2);\n        if (pref_block < src0_end_row) {\n            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),\n                                       src0_row_size_aligned, src0_row_size, pref_block_size);\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),\n                                       src1_row_size_aligned, src1_row_size, pref_block_size);\n        }\n    }\n\n    dma_queue_flush(dma_queue);\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\", ith, nth,\n         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,\n         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\nstatic int execute_op_activations_f32(struct htp_ops_context * octx) {\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * src1 = &octx->src1;\n    struct htp_tensor *       dst  = &octx->dst;\n\n    if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) {\n        FARF(ERROR, \"Non-contiguous tensors are not supported at this time \\n\");\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    
worker_callback_t act_op_func;\n    const char *      op_type = NULL;\n\n    switch (octx->op) {\n        case HTP_OP_UNARY_SILU:\n            act_op_func = (worker_callback_t)unary_silu_f32_per_thread;\n            op_type     = \"silu-f32\";\n            break;\n\n        case HTP_OP_GLU_SWIGLU:\n            act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;\n            op_type     = \"swiglu-f32\";\n            break;\n\n        case HTP_OP_GLU_SWIGLU_OAI:\n            act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;\n            op_type     = \"swiglu-oai-f32\";\n            break;\n        case HTP_OP_UNARY_GELU:\n            act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;\n            op_type     = \"gelu-f32\";\n            break;\n\n        case HTP_OP_GLU_GEGLU:\n            act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;\n            op_type     = \"geglu-f32\";\n            break;\n        default:\n            FARF(ERROR, \"Unsupported activations Op %u\\n\", octx->op);\n            return HTP_STATUS_NO_SUPPORT;\n    }\n\n    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];\n    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);\n\n    size_t src0_row_size = src0->nb[1];\n    size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used\n    size_t dst_row_size  = dst->nb[1];\n\n    const bool src1_valid = src1->ne[0];\n    if (!src1_valid) {\n        src1_row_size = src0_row_size;\n    }\n\n    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);\n    const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);\n    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);\n    // VTCM scratchpads for all tensors\n    // N rows per thread, padded to HVX vector size\n\n    size_t spad_size_per_row   = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned;\n    size_t vtcm_row_per_thread = 
(octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row);\n\n    // Make sure the reserved vtcm size is sufficient\n    if(vtcm_row_per_thread ==0){\n        FARF(ERROR, \"act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\\n\", op_type, octx->ctx->vtcm_size,\n             spad_size_per_row * n_threads);\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread;\n    octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread;\n    octx->dst_spad.size_per_thread  = dst_row_size_aligned * vtcm_row_per_thread;\n\n    octx->dst_spad.size  = n_threads* octx->dst_spad.size_per_thread;\n    octx->src0_spad.size = n_threads* octx->src0_spad.size_per_thread;\n    octx->src1_spad.size = n_threads* octx->src1_spad.size_per_thread;\n\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;\n    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;\n\n    if (src1->ne[0]) {\n        FARF(HIGH, \"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\\n\",\n             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],\n             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,\n             octx->dst_spad.size);\n    } else {\n        FARF(HIGH, \"%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\\n\", op_type,\n             src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);\n    }\n\n    if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        return HTP_STATUS_OK;\n    }\n\n    // Prepare context\n    struct htp_act_context actx;\n    actx.octx = 
octx;\n\n    actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;\n\n    actx.src0_row_size = src0_row_size;\n    actx.src1_row_size = src1_row_size;\n    actx.dst_row_size  = dst_row_size;\n\n    actx.src0_row_size_aligned = src0_row_size_aligned;\n    actx.src1_row_size_aligned = src1_row_size_aligned;\n    actx.dst_row_size_aligned  = dst_row_size_aligned;\n\n    actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;\n    actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;\n    actx.dst_spad_half_size  = octx->dst_spad.size_per_thread / 2;\n\n    actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;\n    actx.src0_nrows = src0_nrows;\n\n    actx.nc = dst->ne[0];\n\n    // Pointers and GLU logic\n    const uint8_t * data_src0 = (const uint8_t *) src0->data;\n    const uint8_t * data_src1 = (const uint8_t *) src1->data;\n\n    if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {\n         const int32_t swapped = octx->op_params[1];\n         data_src1 = data_src0;\n         actx.src1_row_size = actx.src0_row_size;\n\n         size_t nc_in_bytes = actx.nc * SIZEOF_FP32;\n         if (swapped) {\n             data_src0 += nc_in_bytes;\n         } else {\n             data_src1 += nc_in_bytes;\n         }\n    }\n\n    actx.data_src0 = data_src0;\n    actx.data_src1 = data_src1;\n    actx.data_dst  = (uint8_t *) dst->data;\n\n    worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads);\n    return HTP_STATUS_OK;\n}\n\nint op_activations(struct htp_ops_context * octx) {\n    int err = HTP_STATUS_OK;\n\n    switch (octx->src0.type) {\n        case HTP_TYPE_F32:\n            err = execute_op_activations_f32(octx);\n            break;\n\n        default:\n            err = HTP_STATUS_NO_SUPPORT;\n            break;\n    }\n\n    return err;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/argsort-ops.c",
    "content": "#include <string.h>\n#include <stdlib.h>\n#include <math.h>\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"ggml.h\"\n\n#include \"hvx-utils.h\"\n#include \"hex-dma.h\"\n\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n#ifndef MIN\n#define MIN(a, b) ((a) < (b) ? (a) : (b))\n#endif\n\nstruct htp_argsort_context {\n    struct htp_ops_context * octx;\n    uint32_t                 nrows_per_thread;\n};\n\nstatic inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)\n{\n    const HVX_Vector one  = Q6_V_vsplat_R(1);\n    const HVX_Vector zero = Q6_V_vzero();\n\n    HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);\n    HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);\n    HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);\n    return hvx_vec_get_i32(sum) == 32;\n}\n\n// Sorts values and mirrors swaps to indices.\nstatic void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {\n    if (left >= right) return;\n\n    int pivot_idx = (left + right) / 2;\n    float pivot = values[pivot_idx];\n    int i = left;\n    int j = right;\n\n    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);\n    while (i <= j) {\n        // Vectorized scan for i\n        while (i <= j) {\n            // Check if we have at least one full vector\n            if (i + 32 <= j) {\n                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);\n                if (all_greater_f32(pivot_vec, vals_vec)) {\n                    // If all elements are < pivot, we can skip this whole block\n                    i += 32;\n                    continue;\n                }\n            }\n\n            // Scalar fallback / cleanup\n            if (values[i] < pivot) {\n                i++;\n            } else {\n                break;\n            }\n        }\n\n        // Vectorized scan for j\n        while (i <= j) {\n            if (j - 32 >= i) {\n         
       // Load 32 elements ending at j.\n                // Since we want `values[j] > pivot`, let's load from j-31 to j.\n                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);\n                if (all_greater_f32(vals_vec, pivot_vec)) {\n                    j -= 32;\n                    continue;\n                }\n            }\n\n            if (values[j] > pivot) {\n                j--;\n            } else {\n                break;\n            }\n        }\n\n        if (i <= j) {\n            float tmp_val = values[i];\n            values[i] = values[j];\n            values[j] = tmp_val;\n\n            int32_t tmp_idx = indices[i];\n            indices[i] = indices[j];\n            indices[j] = tmp_idx;\n            i++;\n            j--;\n        }\n    }\n\n    if (left < j) quicksort_values_indices_asc(values, indices, left, j);\n    if (i < right) quicksort_values_indices_asc(values, indices, i, right);\n}\n\nstatic void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {\n    if (left >= right) return;\n\n    int pivot_idx = (left + right) / 2;\n    float pivot = values[pivot_idx];\n    int i = left;\n    int j = right;\n\n    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);\n\n    while (i <= j) {\n        // Vectorized scan for i (values[i] > pivot)\n        while (i <= j) {\n            if (i + 32 <= j) {\n                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);\n                if (all_greater_f32(vals_vec, pivot_vec)) {\n                    i += 32;\n                    continue;\n                }\n            }\n\n            if (values[i] > pivot) {\n                i++;\n            } else {\n                break;\n            }\n        }\n\n        // Vectorized scan for j (values[j] < pivot)\n        while (i <= j) {\n            if (j - 32 >= i) {\n                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);\n                if (all_greater_f32(pivot_vec, 
vals_vec)) {\n                    j -= 32;\n                    continue;\n                }\n            }\n\n            if (values[j] < pivot) {\n                j--;\n            } else {\n                break;\n            }\n        }\n\n        if (i <= j) {\n            float tmp_val = values[i];\n            values[i] = values[j];\n            values[j] = tmp_val;\n\n            int32_t tmp_idx = indices[i];\n            indices[i] = indices[j];\n            indices[j] = tmp_idx;\n            i++;\n            j--;\n        }\n    }\n\n    if (left < j) quicksort_values_indices_desc(values, indices, left, j);\n    if (i < right) quicksort_values_indices_desc(values, indices, i, right);\n}\n\nstatic void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {\n    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;\n    struct htp_ops_context * octx = actx->octx;\n\n    // Unpack context\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * dst = &octx->dst;\n\n    // Scratchpad memory\n    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;\n\n    // Dimensions\n    uint32_t ne00 = src0->ne[0];\n    uint32_t ne01 = src0->ne[1];\n    uint32_t ne02 = src0->ne[2];\n    uint32_t ne03 = src0->ne[3];\n\n    uint32_t nb01 = src0->nb[1];\n    //uint32_t nb02 = src0->nb[2];\n    //uint32_t nb03 = src0->nb[3];\n\n    uint32_t nb1 = dst->nb[1];\n    //uint32_t nb2 = dst->nb[2];\n    //uint32_t nb3 = dst->nb[3];\n\n    // Sort order\n    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];\n\n    // Rows to process\n    uint32_t total_rows = ne01 * ne02 * ne03;\n    uint32_t rows_per_thread = actx->nrows_per_thread;\n    uint32_t start_row = rows_per_thread * i;\n    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);\n\n    // Scratchpad layout:\n    // We need space for one row of float data (values) and one row of int32 indices.\n    // values: ne00 * 
sizeof(float)\n    // indices: ne00 * sizeof(int32_t)\n    // Padded to 128 bytes.\n\n    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);\n    float * values_buf = (float *) spad;\n    int32_t * indices_buf = (int32_t *) (spad + values_size);\n\n    for (uint32_t r = start_row; r < end_row; r++) {\n        uint32_t src_offset = r * nb01;\n        uint32_t dst_offset = r * nb1;\n\n        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;\n        uint8_t * dst_ptr = (uint8_t *) dst->data  + dst_offset;\n\n        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);\n        hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);\n\n        // Initialize indices\n        for (uint32_t j = 0; j < ne00; j++) {\n            indices_buf[j] = j;\n        }\n\n        // Sort values and mirror swaps to indices\n        if (order == GGML_SORT_ORDER_ASC) {\n            quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);\n        } else {\n            quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);\n        }\n\n        // Copy indices back to DDR\n        hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);\n    }\n}\n\nint op_argsort(struct htp_ops_context * octx) {\n    // Check supported types\n    if (octx->src0.type != HTP_TYPE_F32) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];\n    const uint32_t n_threads = MIN(total_rows, octx->n_threads);\n\n    // Allocate scratchpad\n    // We need 1 row of float + 1 row of int32 per thread.\n    uint32_t ne00 = octx->src0.ne[0];\n    size_t values_size  = hex_round_up(ne00 * sizeof(float), 128);\n    size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);\n    size_t spad_per_thread = values_size + indices_size;\n\n    // Make sure we round up to 256 for alignment requirements\n    spad_per_thread = hex_round_up(spad_per_thread, 256);\n\n    size_t 
total_spad_size = spad_per_thread * n_threads;\n\n    if (octx->ctx->vtcm_size < total_spad_size) {\n        FARF(ERROR, \"argsort: VTCM size too small. Needed %zu, have %zu\", total_spad_size, octx->ctx->vtcm_size);\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->src0_spad.size = total_spad_size;\n    octx->src0_spad.size_per_thread = spad_per_thread;\n\n    FARF(HIGH, \"argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)\",\n         octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],\n         octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],\n         octx->src0.data, octx->dst.data);\n\n    struct htp_argsort_context actx;\n    actx.octx = octx;\n    actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads;\n\n    // Run jobs\n    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads);\n\n    return HTP_STATUS_OK;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/binary-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n#ifndef MIN\n#define MIN(a, b) ((a) < (b) ? (a) : (b))\n#endif\n\n// Context for binary operations\nstruct htp_binary_context {\n    struct htp_ops_context * octx;\n    struct fastdiv_values dim1_div;\n    struct fastdiv_values dim2_div;\n    struct fastdiv_values dim12_div;\n\n    struct fastdiv_values src1_dim1_div; // ne11\n    struct fastdiv_values src1_dim2_div; // ne12\n    struct fastdiv_values src1_dim3_div; // ne13\n\n    uint32_t nrows_per_thread;\n    bool split_at_ne01;\n    bool split_at_ne02;\n\n    // Precomputed values\n    uint32_t block_max;\n    size_t   src0_row_size_aligned;\n    size_t   src1_row_size_aligned;\n    size_t   dst_row_size_aligned;\n    uint32_t src1_fetch_rows; // 1 or block_max\n    uint32_t src1_dma_stride; // 0 or stride\n};\n\n#define htp_binary_preamble            \\\n    const struct htp_tensor * src0 = &octx->src0; \\\n    const struct htp_tensor * src1 = &octx->src1; \\\n    struct htp_tensor *       dst  = &octx->dst;  \\\n                                       \\\n    const uint32_t ne00 = src0->ne[0]; \\\n    const uint32_t ne01 = src0->ne[1]; \\\n    const uint32_t ne02 = src0->ne[2]; \\\n    const uint32_t ne03 = src0->ne[3]; \\\n                                       \\\n    const uint32_t ne10 = src1->ne[0]; \\\n    const uint32_t ne11 = src1->ne[1]; \\\n    const uint32_t ne12 = src1->ne[2]; \\\n    const uint32_t ne13 = src1->ne[3]; \\\n                                       \\\n    const uint32_t nb01 = src0->nb[1]; \\\n    const uint32_t nb02 = 
src0->nb[2]; \\\n    const uint32_t nb03 = src0->nb[3]; \\\n                                       \\\n    const uint32_t nb11 = src1->nb[1]; \\\n    const uint32_t nb12 = src1->nb[2]; \\\n    const uint32_t nb13 = src1->nb[3]; \\\n                                       \\\n    const uint32_t nb1 = dst->nb[1];   \\\n    const uint32_t nb2 = dst->nb[2];   \\\n    const uint32_t nb3 = dst->nb[3];\n\nstatic inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,\n                                uint32_t ne01, uint32_t ne02) {\n    uint32_t i03, i02, i01, rem;\n    i03 = fastdiv(ir, &bctx->dim12_div);\n    rem = ir - i03 * (ne02 * ne01);\n    i02 = fastdiv(rem, &bctx->dim1_div);\n    i01 = rem - i02 * ne01;\n\n    uint32_t rows_left = end_row - ir;\n    uint32_t block_limit = rows_left;\n\n    if (bctx->split_at_ne01) {\n        block_limit = MIN(block_limit, ne01 - i01);\n    }\n    if (bctx->split_at_ne02) {\n         uint32_t rows_in_plane = (ne02 * ne01) - rem;\n         block_limit = MIN(block_limit, rows_in_plane);\n    }\n\n    return MIN(bctx->block_max, block_limit);\n}\n\n// Macro for scalar op switch\n#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \\\n    if(TYPE == HTP_TYPE_F32) { \\\n        switch (octx->op) { \\\n            case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \\\n            case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \\\n            case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \\\n            case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \\\n            default: break; \\\n        } \\\n    } \\\n    else { \\\n        switch (octx->op) { \\\n            case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \\\n            case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \\\n            case HTP_OP_MUL: 
hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \\\n            case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \\\n            default: break; \\\n        } \\\n    }\n\n// Macro for vector op switch (All Aligned)\n#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \\\n    if(TYPE == HTP_TYPE_F32) { \\\n        switch (octx->op) { \\\n            case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \\\n            default: break; \\\n        } \\\n    } \\\n    else { \\\n        switch (octx->op) { \\\n            case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \\\n            default: break; \\\n        } \\\n    }\n\n// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)\n#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \\\n    if(TYPE == HTP_TYPE_F32) { \\\n        switch (octx->op) { \\\n            case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \\\n            default: break; \\\n        } \\\n    } \\\n    else { \\\n        switch (octx->op) { \\\n            case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; 
\\\n            case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \\\n            default: break; \\\n        } \\\n    }\n\n// Macro for vector op switch (All Unaligned - generic loop used in element repeat)\n#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \\\n    if(TYPE == HTP_TYPE_F32) { \\\n        switch (octx->op) { \\\n            case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \\\n            default: break; \\\n        } \\\n    } \\\n    else { \\\n        switch (octx->op) { \\\n            case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \\\n            case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \\\n            default: break; \\\n        } \\\n    }\n\n// 1. Scalar src1 (ne10 == 1)\nstatic void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_binary_context * bctx = (struct htp_binary_context *) data;\n    struct htp_ops_context * octx = bctx->octx;\n    htp_binary_preamble;\n\n    const uint32_t src0_type = octx->src0.type;\n    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? 
ne00 * sizeof(float) : ne00 * sizeof(_Float16);\n    const uint32_t total_rows = ne01 * ne02 * ne03;\n    const uint32_t start_row = bctx->nrows_per_thread * ith;\n    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);\n    if (start_row >= end_row) return;\n\n    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);\n    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);\n    size_t src0_spad_half    = octx->src0_spad.size_per_thread / 2;\n    size_t dst_spad_half     = octx->dst_spad.size_per_thread  / 2;\n\n    dma_queue * q = octx->ctx->dma[ith];\n    uint32_t ir_prefetch = start_row;\n    int spad_idx = 0;\n\n    // Preamble\n    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {\n        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n        rem = ir_prefetch - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;\n        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;\n\n        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;\n        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;\n\n        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);\n        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);\n        ir_prefetch += current_block_size;\n        spad_idx ^= 1;\n    }\n\n    // Main loop\n    for (uint32_t ir = start_row; ir < end_row; ) {\n        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);\n\n        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;\n  
      uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;\n\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir, &bctx->dim12_div);\n        rem = ir - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        // src1 indices (broadcast/repeat)\n        uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);\n        uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);\n        uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);\n\n        uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;\n        uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;\n\n        for (uint32_t r = 0; r < current_block_size; r++) {\n            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;\n            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;\n            COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00);\n            src1_ptr += s1_stride;\n        }\n\n        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;\n        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);\n\n        if (ir_prefetch < end_row) {\n             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n             uint32_t p03, p02, p01, prem;\n             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n             prem = ir_prefetch - p03 * (ne02 * ne01);\n             p02 = fastdiv(prem, &bctx->dim1_div);\n             p01 = prem - p02 * ne01;\n             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;\n\n             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);\n             ir_prefetch += next_block_size;\n        }\n        ir += current_block_size;\n    }\n    dma_queue_flush(q);\n}\n\n// 2. 
Vector Same Shape (ne1x == ne0x) or Simple Broadcast\nstatic void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_binary_context * bctx = (struct htp_binary_context *) data;\n    struct htp_ops_context * octx = bctx->octx;\n    htp_binary_preamble;\n\n    const uint32_t src0_type = octx->src0.type;\n    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);\n    const uint32_t total_rows = ne01 * ne02 * ne03;\n    const uint32_t start_row = bctx->nrows_per_thread * ith;\n    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);\n    if (start_row >= end_row) return;\n\n    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);\n    uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);\n    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);\n\n    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;\n    size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;\n    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;\n\n    dma_queue * q = octx->ctx->dma[ith];\n    uint32_t ir_prefetch = start_row;\n    int spad_idx = 0;\n\n    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {\n        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n        rem = ir_prefetch - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        uint32_t i13 = (ne13 == 1) ? 0 : i03;\n        uint32_t i12 = (ne12 == 1) ? 0 : i02;\n        uint32_t i11 = (ne11 == 1) ? 
0 : i01;\n\n        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;\n        uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;\n        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;\n\n        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;\n        uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;\n        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;\n\n        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);\n        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);\n        dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);\n        ir_prefetch += current_block_size;\n        spad_idx ^= 1;\n    }\n\n    for (uint32_t ir = start_row; ir < end_row; ) {\n        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);\n        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;\n        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;\n        uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;\n\n        for (uint32_t r = 0; r < current_block_size; r++) {\n            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;\n            uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;\n            uint8_t * r_dst  = d_spad  + r * bctx->dst_row_size_aligned;\n            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);\n        }\n\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir, &bctx->dim12_div);\n        rem = ir - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;\n        
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);\n\n        if (ir_prefetch < end_row) {\n             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n             uint32_t p03, p02, p01, prem;\n             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n             prem = ir_prefetch - p03 * (ne02 * ne01);\n             p02 = fastdiv(prem, &bctx->dim1_div);\n             p01 = prem - p02 * ne01;\n\n             uint32_t p13 = (ne13 == 1) ? 0 : p03;\n             uint32_t p12 = (ne12 == 1) ? 0 : p02;\n             uint32_t p11 = (ne11 == 1) ? 0 : p01;\n\n             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;\n             uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;\n\n             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);\n             dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);\n\n             ir_prefetch += next_block_size;\n        }\n        ir += current_block_size;\n    }\n    dma_queue_flush(q);\n}\n\n// 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1)\nstatic void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_binary_context * bctx = (struct htp_binary_context *) data;\n    struct htp_ops_context * octx = bctx->octx;\n    htp_binary_preamble;\n\n    const uint32_t src0_type = octx->src0.type;\n    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? 
ne00 * sizeof(float) : ne00 * sizeof(_Float16);\n    const uint32_t total_rows = ne01 * ne02 * ne03;\n    const uint32_t start_row = bctx->nrows_per_thread * ith;\n    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);\n    if (start_row >= end_row) return;\n\n    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);\n    uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);\n    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);\n\n    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;\n    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;\n\n    dma_queue * q = octx->ctx->dma[ith];\n    uint32_t ir_prefetch = start_row;\n    int spad_idx = 0;\n\n    void * s1_ptr = (void *) src1_spad;\n\n    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {\n        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n        rem = ir_prefetch - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;\n        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;\n\n        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;\n        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;\n\n        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);\n        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);\n        ir_prefetch += current_block_size;\n        spad_idx ^= 1;\n    }\n\n    for (uint32_t ir = start_row; ir < end_row; ) {\n        uint32_t current_block_size = 
calc_block_size(bctx, ir, end_row, ne01, ne02);\n        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;\n        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;\n\n        for (uint32_t r = 0; r < current_block_size; r++) {\n            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;\n            uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant\n            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;\n            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);\n        }\n\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir, &bctx->dim12_div);\n        rem = ir - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;\n        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);\n\n        if (ir_prefetch < end_row) {\n             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n             uint32_t p03, p02, p01, prem;\n             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n             prem = ir_prefetch - p03 * (ne02 * ne01);\n             p02 = fastdiv(prem, &bctx->dim1_div);\n             p01 = prem - p02 * ne01;\n             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;\n             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);\n             ir_prefetch += next_block_size;\n        }\n        ir += current_block_size;\n    }\n    dma_queue_flush(q);\n}\n\n// 4. 
Vector Complex (ne10 == ne00, complex broadcast)\nstatic void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_binary_context * bctx = (struct htp_binary_context *) data;\n    struct htp_ops_context * octx = bctx->octx;\n    htp_binary_preamble;\n\n    const uint32_t src0_type = octx->src0.type;\n    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);\n    const uint32_t total_rows = ne01 * ne02 * ne03;\n    const uint32_t start_row = bctx->nrows_per_thread * ith;\n    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);\n    if (start_row >= end_row) return;\n\n    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);\n    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);\n    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;\n    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;\n\n    dma_queue * q = octx->ctx->dma[ith];\n    uint32_t ir_prefetch = start_row;\n    int spad_idx = 0;\n\n    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {\n        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n        rem = ir_prefetch - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;\n        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;\n\n        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;\n        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;\n\n        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);\n        dma_queue_push(q, dma_make_ptr(s0_spad, 
src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);\n        ir_prefetch += current_block_size;\n        spad_idx ^= 1;\n    }\n\n    for (uint32_t ir = start_row; ir < end_row; ) {\n        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);\n        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;\n        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;\n\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir, &bctx->dim12_div);\n        rem = ir - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        for (uint32_t r = 0; r < current_block_size; r++) {\n            uint32_t r_i01 = i01 + r;\n            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);\n            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);\n            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);\n\n            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;\n            uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;\n            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;\n\n            // Read src1 from DDR (unaligned)\n            COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00);\n        }\n\n        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;\n        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);\n\n        if (ir_prefetch < end_row) {\n             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n             uint32_t p03, p02, p01, prem;\n             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n             prem = ir_prefetch - p03 * (ne02 * ne01);\n             p02 = fastdiv(prem, &bctx->dim1_div);\n             p01 = prem - p02 * ne01;\n             uint8_t * s0_next = (uint8_t *)src0->data 
+ p03 * nb03 + p02 * nb02 + p01 * nb01;\n             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);\n             ir_prefetch += next_block_size;\n        }\n        ir += current_block_size;\n    }\n    dma_queue_flush(q);\n}\n\n// 5. Element Repeat (ne10 != ne00)\nstatic void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_binary_context * bctx = (struct htp_binary_context *) data;\n    struct htp_ops_context * octx = bctx->octx;\n    htp_binary_preamble;\n\n    const uint32_t src0_type = octx->src0.type;\n    const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);\n    const uint32_t row_size_bytes = ne00 * elem_size_bytes;;\n    const uint32_t total_rows = ne01 * ne02 * ne03;\n    const uint32_t start_row = bctx->nrows_per_thread * ith;\n    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);\n    if (start_row >= end_row) return;\n\n    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);\n    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);\n    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;\n    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;\n\n    dma_queue * q = octx->ctx->dma[ith];\n    uint32_t ir_prefetch = start_row;\n    int spad_idx = 0;\n\n    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {\n        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n        rem = ir_prefetch - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;\n        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 
* nb3  + i02 * nb2  + i01 * nb1;\n\n        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;\n        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;\n\n        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);\n        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);\n        ir_prefetch += current_block_size;\n        spad_idx ^= 1;\n    }\n\n    for (uint32_t ir = start_row; ir < end_row; ) {\n        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);\n        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;\n        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;\n\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir, &bctx->dim12_div);\n        rem = ir - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        for (uint32_t r = 0; r < current_block_size; r++) {\n            uint32_t r_i01 = i01 + r;\n            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);\n            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);\n            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);\n\n            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;\n            uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;\n            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;\n\n            // Repeat src1 row\n            for (uint32_t c = 0; c < ne00; c += ne10) {\n                uint32_t len = MIN(ne10, ne00 - c);\n                // Use UUU for speed and simplicity\n                COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len);\n            }\n        }\n\n        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;\n        
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);\n\n        if (ir_prefetch < end_row) {\n             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n             uint32_t p03, p02, p01, prem;\n             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n             prem = ir_prefetch - p03 * (ne02 * ne01);\n             p02 = fastdiv(prem, &bctx->dim1_div);\n             p01 = prem - p02 * ne01;\n             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;\n             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);\n             ir_prefetch += next_block_size;\n        }\n        ir += current_block_size;\n    }\n    dma_queue_flush(q);\n}\n\n// 6. ADD_ID (src1 gathered via src2 indices)\nstatic void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_binary_context * bctx = (struct htp_binary_context *) data;\n    struct htp_ops_context * octx = bctx->octx;\n\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * src1 = &octx->src1;\n    const struct htp_tensor * src2 = &octx->src2;\n    struct htp_tensor *       dst  = &octx->dst;\n\n    const uint32_t ne00 = src0->ne[0];\n    const uint32_t ne01 = src0->ne[1];\n    const uint32_t ne02 = src0->ne[2];\n    const uint32_t ne03 = src0->ne[3];\n    const uint32_t ne11 = src1->ne[1]; // for bounds check\n\n    const uint32_t nb01 = src0->nb[1];\n    const uint32_t nb02 = src0->nb[2];\n    const uint32_t nb03 = src0->nb[3];\n    const uint32_t nb11 = src1->nb[1]; // src1 row stride\n    const uint32_t nb1 = dst->nb[1];\n    const uint32_t nb2 = dst->nb[2];\n    const uint32_t nb3 = dst->nb[3];\n\n    const uint32_t total_rows = ne01 * ne02 * ne03;\n    const uint32_t start_row = bctx->nrows_per_thread * ith;\n    const uint32_t end_row   = 
MIN(start_row + bctx->nrows_per_thread, total_rows);\n    if (start_row >= end_row) return;\n\n    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);\n    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);\n    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;\n    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;\n\n    dma_queue * q = octx->ctx->dma[ith];\n    uint32_t ir_prefetch = start_row;\n    int spad_idx = 0;\n\n    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {\n        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n        rem = ir_prefetch - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, &bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;\n        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;\n\n        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;\n        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;\n\n        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);\n        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);\n        ir_prefetch += current_block_size;\n        spad_idx ^= 1;\n    }\n\n    for (uint32_t ir = start_row; ir < end_row; ) {\n        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);\n        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;\n        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;\n\n        uint32_t i03, i02, i01, rem;\n        i03 = fastdiv(ir, &bctx->dim12_div);\n        rem = ir - i03 * (ne02 * ne01);\n        i02 = fastdiv(rem, 
&bctx->dim1_div);\n        i01 = rem - i02 * ne01;\n\n        for (uint32_t r = 0; r < current_block_size; r++) {\n            uint32_t r_i01 = i01 + r; // linear within block since we split at ne01\n\n            const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);\n\n            uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;\n            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;\n            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;\n\n            hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);\n        }\n\n        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;\n        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);\n\n        if (ir_prefetch < end_row) {\n             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);\n             uint32_t p03, p02, p01, prem;\n             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);\n             prem = ir_prefetch - p03 * (ne02 * ne01);\n             p02 = fastdiv(prem, &bctx->dim1_div);\n             p01 = prem - p02 * ne01;\n             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;\n             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);\n             ir_prefetch += next_block_size;\n        }\n        ir += current_block_size;\n    }\n    dma_queue_flush(q);\n}\n\nstatic int execute_op_binary(struct htp_ops_context * octx) {\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * src1 = &octx->src1;\n    struct htp_tensor *       dst  = &octx->dst;\n\n    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];\n    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);\n\n    // Use packed row sizes for VTCM allocation\n    const 
uint32_t src0_type = octx->src0.type;\n    const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);\n    const size_t src0_row_size = src0->ne[0] * elem_size;\n    const size_t src1_row_size = src1->ne[0] * elem_size;\n    const size_t dst_row_size  = dst->ne[0] * elem_size;\n\n    // Align to VLEN\n    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);\n    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);\n    size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);\n\n    bool is_add_id = (octx->op == HTP_OP_ADD_ID);\n    bool is_scalar = !is_add_id && (src1->ne[0] == 1);\n\n    // Determine which kernel we will use to alloc memory and dispatch\n    bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&\n               (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&\n               (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&\n               (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);\n\n    bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);\n    bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);\n    bool use_repeat  = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);\n\n    size_t spad_row_total;\n    if (is_scalar) {\n        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);\n    } else if (is_row_bcast) {\n        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);\n    } else if (use_vector_same) {\n        spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);\n    } else if (is_add_id) {\n        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly\n    } else {\n        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);\n    }\n\n    size_t 
rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);\n    // Adjust for static src1 in row_bcast case\n    if (is_row_bcast) {\n        size_t needed_static = src1_row_size_aligned;\n        if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;\n        size_t avail = octx->ctx->vtcm_size - needed_static;\n        rows_per_buffer = avail / (n_threads * spad_row_total);\n    }\n\n    if (rows_per_buffer < 1) {\n         FARF(ERROR, \"binary: VTCM too small\\n\");\n         return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;\n    octx->dst_spad.size_per_thread  = rows_per_buffer * 2 * dst_row_size_aligned;\n\n    if (is_scalar || use_complex || use_repeat || is_add_id) {\n        octx->src1_spad.size_per_thread = 0;\n    } else if (is_row_bcast) {\n        octx->src1_spad.size_per_thread = 0;\n    } else {\n        octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;\n    }\n\n    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;\n    if (is_row_bcast) {\n        octx->src1_spad.size = src1_row_size_aligned;\n    } else {\n        octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;\n    }\n    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;\n\n    if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;\n    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;\n\n    if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        return HTP_STATUS_OK;\n    }\n\n    dma_queue * q = octx->ctx->dma[0];\n    if (is_row_bcast) {\n        dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * 
elem_size, 1);\n    }\n\n    struct htp_binary_context bctx;\n    bctx.octx = octx;\n    bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;\n    bctx.block_max = rows_per_buffer;\n    bctx.src0_row_size_aligned = src0_row_size_aligned;\n    bctx.src1_row_size_aligned = src1_row_size_aligned;\n    bctx.dst_row_size_aligned  = dst_row_size_aligned;\n\n    bctx.dim1_div = init_fastdiv_values(src0->ne[1]);\n    bctx.dim2_div = init_fastdiv_values(src0->ne[2]);\n    bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);\n\n    bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);\n    bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);\n    bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);\n\n    bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);\n    bool dst_contig_dim1  = (dst->nb[2] == src0->ne[1] * dst->nb[1]);\n\n    bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);\n    bool dst_contig_dim2  = (dst->nb[3] == src0->ne[2] * dst->nb[2]);\n\n    bctx.split_at_ne01 = (src0->ne[2] > 1) &&\n                         ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);\n\n    bctx.split_at_ne02 = (src0->ne[3] > 1) &&\n                         ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);\n\n    // Precompute specific kernel parameters\n    if (use_vector_same) {\n        bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];\n        bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 
1 : rows_per_buffer;\n    }\n\n    worker_callback_t worker_func;\n    if (is_add_id) worker_func = binary_job_add_id;\n    else if (is_scalar) worker_func = binary_job_scalar;\n    else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;\n    else if (use_vector_same) worker_func = binary_job_vector_same_shape;\n    else if (use_complex) worker_func = binary_job_vector_complex;\n    else worker_func = binary_job_element_repeat;\n\n    if (is_row_bcast) {\n        dma_queue_pop(q);\n    }\n\n    worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads);\n\n    return HTP_STATUS_OK;\n}\n\nint op_binary(struct htp_ops_context * octx) {\n\n    // Does not support permutations of src1\n    const struct htp_tensor * src1 = &octx->src1;\n    if (src1->nb[1] < src1->nb[0]) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    const uint32_t src0_type = octx->src0.type;\n    if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {\n        return execute_op_binary(octx);\n    }\n\n    return HTP_STATUS_NO_SUPPORT;\n}\n\n"
  },
  {
    "path": "src/ggml-hexagon/htp/cmake-toolchain.cmake",
    "content": "if (HEXAGON_TOOLCHAIN_INCLUDED)\n  return()\nendif()\nset(HEXAGON_TOOLCHAIN_INCLUDED true)\n\n#Cross Compiling for Hexagon\nset(HEXAGON TRUE)\nset(CMAKE_SYSTEM_NAME QURT)\nset(CMAKE_SYSTEM_PROCESSOR Hexagon)\nset(CMAKE_SYSTEM_VERSION \"1\") #${HEXAGON_PLATFORM_LEVEL})\nset(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)\nset(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)\nset(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)\nset(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)\nset(CUSTOM_RUNELF_PATH \"\")\n\n#To fix backward compatibility with EAI addon.\nif (NOT HEXAGON_SDK_ROOT)\n    set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})\nendif()\n\nif (NOT HEXAGON_TOOLS_ROOT)\n    if (DEFINED ENV{HEXAGON_TOOLS_ROOT})\n        set(HEXAGON_TOOLS_ROOT $ENV{HEXAGON_TOOLS_ROOT})\n    endif()\n    if(NOT HEXAGON_TOOLS_ROOT)\n        set(HEXAGON_TOOLS_ROOT $ENV{DEFAULT_HEXAGON_TOOLS_ROOT})\n    endif()\nendif()\n\nfile(TO_CMAKE_PATH \"${HEXAGON_TOOLS_ROOT}\" HEXAGON_TOOLS_ROOT)\nfile(TO_CMAKE_PATH \"${HEXAGON_SDK_ROOT}\"   HEXAGON_SDK_ROOT)\n\n#Get the Binary extension of the Hexagon Toolchain\nif(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)\n    set(HEXAGON_TOOLCHAIN_SUFFIX .exe)\nendif()\nmessage(DEBUG \"CMAKE_HOST_SYSTEM_NAME:${CMAKE_HOST_SYSTEM_NAME}\")\n\ninclude(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_arch.cmake)\n\nset(HEXAGON_TOOLCHAIN ${HEXAGON_TOOLS_ROOT})\nset(HEXAGON_LIB_DIR \"${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib\")\nset(HEXAGON_ISS_DIR ${HEXAGON_TOOLCHAIN}/Tools/lib/iss)\n\nset(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES\n    HEXAGON_SDK_ROOT\n    HEXAGON_TOOLS_ROOT\n)\n\n#QURT Related includes and linker flags\nset(V_ARCH ${HEXAGON_ARCH})\nset(_QURT_INSTALL_DIR \"${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}\")\nset(_QURT_INSTALL_DIR \"${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}\")\n\nif( ${TREE} MATCHES PAKMAN )\n    set(_QURT_INSTALL_DIR \"${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}\")\nendif()\nmessage(DEBUG 
\"_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}\")\nset(RTOS_DIR ${_QURT_INSTALL_DIR})\nset(QCC_DIR \"${HEXAGON_QCC_DIR}/${V_ARCH}/G0\")\nset(TARGET_DIR \"${HEXAGON_LIB_DIR}/${V_ARCH}/G0\")\n\ninclude_directories(\n    ${_QURT_INSTALL_DIR}/include\n    ${_QURT_INSTALL_DIR}/include/qurt\n    ${_QURT_INSTALL_DIR}/include/posix\n    )\n\nset(QURT_START_LINK_LIBS)\nset(QURT_START_LINK_LIBS\n    \"${TARGET_DIR}/init.o\"\n    \"${RTOS_DIR}/lib/crt1.o\"\n    \"${RTOS_DIR}/lib/debugmon.o\"\n    \"${RTOS_DIR}/lib/libqurt.a\"\n    \"${TARGET_DIR}/libc.a\"\n    \"${TARGET_DIR}/libqcc.a\"\n    \"${TARGET_DIR}/libhexagon.a\"\n    \"${RTOS_DIR}/lib/libqurtcfs.a\"\n    \"${RTOS_DIR}/lib/libtimer_island.a\"\n    \"${RTOS_DIR}/lib/libtimer_main.a\"\n    \"${RTOS_DIR}/lib/libposix.a\"\n    )\nSTRING(REPLACE \";\" \" \" QURT_START_LINK_LIBS \"${QURT_START_LINK_LIBS}\")\n\nset(QURT_END_LINK_LIBS\n    ${TARGET_DIR}/fini.o\n    )\n\n#Non QURT related includes and linker flags\n\nset(TARGET_DIR_NOOS \"${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}\")\n\nif (NOT NO_WRAP_MEM_API)\n    set(WRAP_MALLOC   -Wl,--wrap=malloc)\n    set(WRAP_CALLOC   -Wl,--wrap=calloc)\n    set(WRAP_FREE     -Wl,--wrap=free)\n    set(WRAP_REALLOC  -Wl,--wrap=realloc)\n    set(WRAP_MEMALIGN -Wl,--wrap=memalign)\nendif()\n\nset(PIC_SHARED_LD_FLAGS\n    -mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}\n    -G0\n    -fpic\n    -Wl,-Bsymbolic\n    -Wl,-L${TARGET_DIR_NOOS}/G0/pic\n    -Wl,-L${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/\n    -Wl,--no-threads ${WRAP_MALLOC} ${WRAP_CALLOC} ${WRAP_FREE} ${WRAP_REALLOC} ${WRAP_MEMALIGN}\n    -shared\n    \"-o <TARGET> <SONAME_FLAG><TARGET_SONAME>\"\n    \"<LINK_FLAGS>\"\n    -Wl,--start-group\n    \"<OBJECTS>\"\n    \"<LINK_LIBRARIES>\"\n    -Wl,--end-group\n    -lc\n    )\nSTRING(REPLACE \";\" \" \" PIC_SHARED_LD_FLAGS \"${PIC_SHARED_LD_FLAGS}\")\n\nset(HEXAGON_PIC_SHARED_LINK_OPTIONS \"${PIC_SHARED_LD_FLAGS}\")\n\n#System include paths\ninclude_directories(SYSTEM 
${HEXAGON_SDK_ROOT}/incs)\ninclude_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)\ninclude_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)\n\n#LLVM toolchain setup\n#Compiler paths, options and architecture\nset(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})\nset(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})\nset(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})\nset(CMAKE_ASM_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})\nset(HEXAGON_LINKER ${CMAKE_C_COMPILER})\nset(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)\n\nset(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG   \"-Wl,-soname,\")\nset(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG \"-Wl,-soname,\")\n\n#Compiler Options\nset(COMMON_FLAGS \"-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}\")\n\nset(CMAKE_CXX_FLAGS_DEBUG          \"${COMMON_FLAGS} -O0 -D_DEBUG -g\")\nset(CMAKE_CXX_FLAGS_RELWITHDEBINFO \"${COMMON_FLAGS} -O3 -g\")\nset(CMAKE_CXX_FLAGS_RELEASE        \"${COMMON_FLAGS} -O3\")\n\nset(CMAKE_C_FLAGS_DEBUG            \"${COMMON_FLAGS} -O0 -D_DEBUG -g\")\nset(CMAKE_C_FLAGS_RELWITHDEBINFO   \"${COMMON_FLAGS} -O3 -g\")\nset(CMAKE_C_FLAGS_RELEASE          \"${COMMON_FLAGS} -O3\")\n\nset(CMAKE_ASM_FLAGS_DEBUG          \"${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}\")\nset(CMAKE_ASM_FLAGS_RELEASE        \"${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}\")\nset(CMAKE_ASM_FLAGS_RELWITHDEBINFO \"${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}\" )\n\n#Linker Options\nset(CMAKE_C_CREATE_SHARED_LIBRARY   \"${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}\")\nset(CMAKE_CXX_CREATE_SHARED_LIBRARY \"${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}\")\n"
  },
  {
    "path": "src/ggml-hexagon/htp/cpy-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n#include \"hvx-utils.h\"\n\nstruct htp_copy_context {\n    struct htp_ops_context * octx;\n\n    uint32_t          src0_type_size;\n    uint32_t          src0_block_size;\n\n    uint32_t          dst_type_size;\n    uint32_t          dst_block_size;\n\n    uint32_t          src0_blocks_per_row;\n    uint32_t          dst_blocks_per_row;\n\n    uint32_t          src0_nrows_per_thread;\n\n    void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);\n};\n\n#define cpy_preamble                       \\\n    struct htp_tensor *src0 = &octx->src0; \\\n    struct htp_tensor *dst  = &octx->dst;  \\\n                                           \\\n    const uint32_t ne00 = src0->ne[0];     \\\n    const uint32_t ne01 = src0->ne[1];     \\\n    const uint32_t ne02 = src0->ne[2];     \\\n    const uint32_t ne03 = src0->ne[3];     \\\n                                           \\\n    const uint32_t nb00 = src0->nb[0];     \\\n    const uint32_t nb01 = src0->nb[1];     \\\n    const uint32_t nb02 = src0->nb[2];     \\\n    const uint32_t nb03 = src0->nb[3];     \\\n                                           \\\n    const uint32_t  ne0 = dst->ne[0];      \\\n    const uint32_t  ne1 = dst->ne[1];      \\\n    const uint32_t  ne2 = dst->ne[2];      \\\n    const uint32_t  ne3 = dst->ne[3];      \\\n                                           \\\n    const uint32_t  nb0 = dst->nb[0];      \\\n    const uint32_t  nb1 = dst->nb[1];      \\\n    const uint32_t  nb2 = dst->nb[2];      \\\n    const uint32_t  nb3 = dst->nb[3];      \\\n             
                              \\\n    const uint32_t   nr = ne01;\n\nstatic void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {\n    cpy_preamble;\n\n    // parallelize by src0 rows\n    const uint32_t dr  = ct->src0_nrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;\n\n    // copy by rows\n    for (uint32_t i03 = 0; i03 < ne03; i03++) {\n        for (uint32_t i02 = 0; i02 < ne02; i02++) {\n            #pragma unroll(2)\n            for (uint32_t i01 = ir0; i01 < ir1; i01++) {\n                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;\n                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;\n                hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2);\n                hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size);\n            }\n        }\n    }\n}\n\nstatic void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) {\n    cpy_preamble;\n\n    // parallelize by src0 rows\n    const uint32_t dr  = ct->src0_nrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = (ir0 + dr) < nr ? 
(ir0 + dr) : nr;\n\n    // dst counters\n    int64_t k10 = 0;\n    int64_t i11 = 0;\n    int64_t i12 = 0;\n    int64_t i13 = 0;\n\n    // number of blocks in a row\n    const int64_t nk00 = ct->src0_blocks_per_row;\n    const int64_t nk0  = ct->dst_blocks_per_row;\n\n    for (int64_t i03 = 0; i03 < ne03; i03++) {\n        for (int64_t i02 = 0; i02 < ne02; i02++) {\n            k10 += nk00 * ir0;\n            while (k10 >= nk0) {\n                k10 -= nk0;\n                if (++i11 == ne1) {\n                    i11 = 0;\n                    if (++i12 == ne2) {\n                        i12 = 0;\n                        if (++i13 == ne3) {\n                            i13 = 0;\n                        }\n                    }\n                }\n            }\n            for (int64_t i01 = ir0; i01 < ir1; i01++) {\n                for (int64_t k00 = 0; k00 < nk00; k00++) {\n                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);\n                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);\n                    memcpy(dst_ptr, src0_ptr, ct->dst_type_size);\n\n                    if (++k10 == nk0) {\n                        k10 = 0;\n                        if (++i11 == ne1) {\n                            i11 = 0;\n                            if (++i12 == ne2) {\n                                i12 = 0;\n                                if (++i13 == ne3) {\n                                    i13 = 0;\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n            k10 += nk00 * (ne01 - ir1);\n            while (k10 >= nk0) {\n                k10 -= nk0;\n                if (++i11 == ne1) {\n                    i11 = 0;\n                    if (++i12 == ne2) {\n                        i12 = 0;\n                        if (++i13 == ne3) {\n                            
i13 = 0;\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nstatic void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {\n    cpy_preamble;\n\n    // parallelize by src0 rows\n    const uint32_t dr  = ct->src0_nrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;\n\n    // copy by rows\n    for (uint32_t i03 = 0; i03 < ne03; i03++) {\n        for (uint32_t i02 = 0; i02 < ne02; i02++) {\n            #pragma unroll(2)\n            for (uint32_t i01 = ir0; i01 < ir1; i01++) {\n                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;\n                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;\n                hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2);\n                hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);\n            }\n        }\n    }\n}\n\nstatic void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {\n    cpy_preamble;\n\n    // parallelize by src0 rows\n    const uint32_t dr  = ct->src0_nrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = (ir0 + dr) < nr ? 
(ir0 + dr) : nr;\n\n    // copy by rows\n    for (uint32_t i03 = 0; i03 < ne03; i03++) {\n        for (uint32_t i02 = 0; i02 < ne02; i02++) {\n            #pragma unroll(2)\n            for (uint32_t i01 = ir0; i01 < ir1; i01++) {\n                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;\n                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;\n                hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2);\n                hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00);\n            }\n        }\n    }\n}\n\nstatic void cpy_work_func(unsigned int n, unsigned int i, void *data) {\n    struct htp_copy_context *ct = (struct htp_copy_context *) data;\n    ct->copy(ct, ct->octx, n, i);\n}\n\nint op_cpy(struct htp_ops_context * octx) {\n    cpy_preamble;\n\n    const uint32_t n_threads = MIN(nr, octx->n_threads);\n\n    struct htp_copy_context ct;\n    ct.octx = octx;\n\n    switch (src0->type) {\n    case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;\n    case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;\n    default:\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    switch (dst->type) {\n    case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;\n    case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;\n    default:\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {\n        return HTP_STATUS_OK;\n    }\n\n    const bool sametype   = (src0->type == dst->type);\n    const bool transposed = (nb00 > nb01) || (nb0 > nb1);\n    const bool sameshape  = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);\n\n    ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;\n\n    if (sametype && sameshape) {\n        
ct.copy = cpy_thread_sametype_sameshape;\n    } else if (sameshape) {\n        /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32)\n            ct.copy = cpy_thread_f16_f32_sameshape;\n        else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16)\n            ct.copy = cpy_thread_f32_f16_sameshape;\n        else\n            return HTP_STATUS_NO_SUPPORT;\n    } else if (sametype) {\n        ct.copy = cpy_thread_sametype_reshape;\n    } else {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads);\n\n    return HTP_STATUS_OK;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/flash-attn-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <assert.h>\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n#include <math.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n#include \"hvx-dump.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n// Must be multiple of 32\n#define FLASH_ATTN_BLOCK_SIZE (32 * 2)\n\n// This is a bit of a hack because the compiler is strugling to properly inline\n// the default hvx_vec_f32_to_f16 with output into the local array.\nstatic void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)\n{\n    *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);\n}\n\n// Dot product of two F16 vectors, accumulating to float\nstatic inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {\n    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16\n    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16\n\n    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors\n    uint32_t nloe = n % VLEN_FP16; // leftover elements\n\n    HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));\n\n    uint32_t i = 0;\n\n    #pragma unroll(4)\n    for (i = 0; i < nvec; i++) {\n        rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);\n    }\n\n    if (nloe) {\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);\n        HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);\n        HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);\n\n        rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);\n    }\n\n    HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));\n    rsum = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));\n    hvx_vec_store_u(r, 4, rsum);\n}\n\nstatic inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,\n                                                const uint8_t * restrict x,\n                                                const size_t stride_x,\n                                                const size_t nvec,\n                                                const size_t nloe) {\n    const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x;                   // fp16\n    const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x);      // fp16\n    const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2);  // fp16\n    const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3);  // fp16\n    const HVX_Vector * restrict vy  = (const HVX_Vector * restrict) y;                   // fp16\n\n    HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));\n    HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));\n    HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));\n    HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));\n\n    uint32_t i = 0;\n\n    for (i = 0; i < nvec; i++) {\n        HVX_Vector y_hf  = vy[i];\n        HVX_Vector x0_hf = vx0[i];\n        HVX_Vector x1_hf = vx1[i];\n        HVX_Vector x2_hf = vx2[i];\n        HVX_Vector x3_hf = vx3[i];\n\n        rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);\n        rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);\n        rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);\n        rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);\n    }\n\n    if (nloe) {\n        // Load x (fp16) and zero-out unused elements\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);\n        HVX_Vector     y_hf  = 
Q6_V_vand_QV(bmask, vy[i]);\n        HVX_Vector     x0_hf = Q6_V_vand_QV(bmask, vx0[i]);\n        HVX_Vector     x1_hf = Q6_V_vand_QV(bmask, vx1[i]);\n        HVX_Vector     x2_hf = Q6_V_vand_QV(bmask, vx2[i]);\n        HVX_Vector     x3_hf = Q6_V_vand_QV(bmask, vx3[i]);\n\n        rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);\n        rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);\n        rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);\n        rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);\n    }\n\n    HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));\n    HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));\n    HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));\n    HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));\n\n    HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };\n    return hvx_vec_reduce_sum_f32x4(rsum0123);\n}\n\nstatic inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,\n                                                 const uint8_t * restrict x,\n                                                 const size_t stride_x,\n                                                 const size_t n,\n                                                 float        s) {\n\n    const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors\n    const size_t nloe = n % VLEN_FP16; // leftover elements\n\n    HVX_Vector   sums;  // initialize at j = 0\n    const size_t stride_x_4 = stride_x * 4;\n    for (uint32_t j = 0; j < VLEN_FP32; j += 4) {\n        HVX_Vector     sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);\n        HVX_VectorPred pred    = Q6_Q_vsetq_R(j * SIZEOF_FP32);\n        sums                   = Q6_V_vmux_QVV(pred, sums, sums_x4);\n        x += stride_x_4;\n    }\n\n    
sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);\n    return Q6_Vsf_equals_Vqf32(sums);\n}\n\n// MAD: y (F32) += x (F16) * s (F16)\nstatic inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) {\n    const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;\n\n    HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;\n    HVX_Vector * restrict vy = (HVX_Vector *) y;\n\n    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors\n    uint32_t nloe = n % VLEN_FP16; // leftover elements\n\n    HVX_Vector S0 = hvx_vec_splat_f16(*s);\n\n    uint32_t i = 0;\n\n    #pragma unroll(2)\n    for (i = 0; i < nvec; ++i) {\n        vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);\n    }\n\n    if (nloe) {\n        HVX_VectorPair xy_p = vy_p[i];\n        xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);\n\n        HVX_Vector xy = Q6_V_lo_W(xy_p);\n        i = 2 * i;  // index for vy\n\n        if (nloe >= VLEN_FP32) {\n            vy[i] = xy;\n            nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);\n        }\n\n        if (nloe) {\n            hvx_vec_store_a(&vy[i], nloe * 4, xy);\n        }\n    }\n}\n\n// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)\nstatic inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,\n                                          const __fp16 * restrict s0, const __fp16 * restrict s1, int n) {\n    const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;\n    const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;\n\n    HVX_VectorPair * restrict vy_p  = (HVX_VectorPair *) y;\n    HVX_Vector * restrict vy        = (HVX_Vector *) y;\n\n    uint32_t nvec = n / VLEN_FP16;  // num full fp16 hvx vectors\n    uint32_t nloe = n % VLEN_FP16;  // leftover elements\n\n    HVX_Vector S0 = hvx_vec_splat_f16(*s0);\n    HVX_Vector S1 = hvx_vec_splat_f16(*s1);\n\n    uint32_t i = 0;\n\n    
#pragma unroll(2)\n    for (i = 0; i < nvec; ++i) {\n        vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);\n        vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);\n    }\n\n    if (nloe) {\n        HVX_VectorPair xy_p = vy_p[i];\n        xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);\n        xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);\n\n        HVX_Vector xy = Q6_V_lo_W(xy_p);\n        i = 2 * i;  // index for vy\n\n        if (nloe >= VLEN_FP32) {\n            vy[i] = xy;\n            nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);\n        }\n\n        if (nloe) {\n            hvx_vec_store_a(&vy[i], nloe * 4, xy);\n        }\n    }\n}\n\nstruct htp_fa_context {\n    const struct htp_ops_context * octx;\n\n    struct fastdiv_values src0_div21;\n    struct fastdiv_values src0_div1;\n\n    struct fastdiv_values broadcast_rk2;\n    struct fastdiv_values broadcast_rk3;\n    struct fastdiv_values broadcast_rv2;\n    struct fastdiv_values broadcast_rv3;\n\n    struct fastdiv_values src3_div2;\n    struct fastdiv_values src3_div3;\n\n    float scale;\n    float max_bias;\n    float logit_softcap;\n\n    uint32_t n_head_log2;\n    float m0;\n    float m1;\n\n    uint32_t n_blocks;\n\n    size_t size_q_row_padded;\n    size_t size_k_row_padded;\n    size_t size_v_row_padded;\n\n    size_t size_k_block;\n    size_t size_v_block;\n    size_t size_m_block;\n\n    uint32_t qrows;\n    uint32_t qrows_per_thread;\n\n    bool is_q_fp32;\n\n    uint64_t t_start;\n};\n\nstatic inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {\n    assert((size_t) dst % 128 == 0);\n    assert((size_t) src % 128 == 0);\n\n    const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;\n    HVX_Vector * restrict vdst       = (HVX_Vector * restrict) dst;\n\n    const uint32_t nvec = n / VLEN_FP32;\n    const uint32_t nloe = n % 
VLEN_FP32;\n\n    uint32_t i = 0;\n    #pragma unroll(4)\n    for (; i < nvec; ++i) {\n        vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));\n    }\n    if (nloe) {\n        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);\n        hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));\n    }\n}\n\nstatic void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_fa_context * factx = (struct htp_fa_context *) data;\n    const struct htp_ops_context * octx = factx->octx;\n    const struct htp_tensor * q = &octx->src0;\n    const struct htp_tensor * k = &octx->src1;\n    const struct htp_tensor * v = &octx->src2;\n    const struct htp_tensor * mask  = (octx->src3.data) ? &octx->src3 : NULL;\n    const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;\n    const struct htp_tensor * dst = &octx->dst;\n\n    const uint32_t neq0 = q->ne[0];\n    const uint32_t neq1 = q->ne[1];\n    const uint32_t neq2 = q->ne[2];\n    const uint32_t neq3 = q->ne[3];\n\n    const uint32_t nek0 = k->ne[0];\n    const uint32_t nek1 = k->ne[1];\n    const uint32_t nek2 = k->ne[2];\n    const uint32_t nek3 = k->ne[3];\n\n    const uint32_t nev0 = v->ne[0];\n    const uint32_t nev1 = v->ne[1];\n    const uint32_t nev2 = v->ne[2];\n    const uint32_t nev3 = v->ne[3];\n\n    const uint32_t nbq1 = q->nb[1];\n    const uint32_t nbq2 = q->nb[2];\n    const uint32_t nbq3 = q->nb[3];\n\n    const uint32_t nbk1 = k->nb[1];\n    const uint32_t nbk2 = k->nb[2];\n    const uint32_t nbk3 = k->nb[3];\n\n    const uint32_t nbv1 = v->nb[1];\n    const uint32_t nbv2 = v->nb[2];\n    const uint32_t nbv3 = v->nb[3];\n\n    const uint32_t ne1 = dst->ne[1];\n    const uint32_t ne2 = dst->ne[2];\n    const uint32_t ne3 = dst->ne[3];\n\n    const uint32_t nb1 = dst->nb[1];\n    const uint32_t nb2 = dst->nb[2];\n    const uint32_t nb3 = dst->nb[3];\n\n    // total rows in q\n    const uint32_t nr = factx->qrows;\n    
const uint32_t dr = factx->qrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = MIN(ir0 + dr, nr);\n\n    if (ir0 >= ir1) return;\n\n    dma_queue * dma = octx->ctx->dma[ith];\n\n    const uint32_t DK = nek0;\n    const uint32_t DV = nev0;\n\n    const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);\n    const size_t size_k_row = DK * sizeof(__fp16);\n    const size_t size_v_row = DV * sizeof(__fp16);\n\n    // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator\n    uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;\n    uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;\n    uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;\n    uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;\n    uint8_t * spad_a = octx->dst_spad.data  + octx->dst_spad.size_per_thread  * ith;\n\n    const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);\n\n    for (uint32_t ir = ir0; ir < ir1; ++ir) {\n        const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);\n        const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);\n        const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);\n\n        const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);\n        const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);\n\n        const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);\n        const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);\n\n        // Fetch Q row\n        const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);\n        dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);\n\n        // FARF(HIGH, \"fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u\", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row,\n        //                 
(unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));\n\n        const __fp16 * mp_base = NULL;\n        if (mask) {\n            const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);\n            const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);\n            mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);\n        }\n\n        // Prefetch first two blocks\n        for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {\n            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;\n            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);\n\n            // K\n            const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);\n            uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;\n            dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);\n\n            // V\n            const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);\n            uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;\n            dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);\n\n            // Mask\n            if (mask) {\n                const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);\n                uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;\n                // Mask is 1D contiguous for this row\n                dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);\n            }\n\n            // FARF(HIGH, \"fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u\",\n            //             ith, ir, ib, iq1, iq2, iq3,\n            //             
size_k_row, size_v_row, current_block_size,\n            //             (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));\n        }\n\n        const uint32_t h = iq2; // head index\n        const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;\n\n        HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);\n        HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);\n\n        // Clear accumulator\n        hvx_splat_f32_a(spad_a, 0, DV);\n        float * VKQ32 = (float *) (spad_a + 0);\n\n        uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;\n        if (factx->is_q_fp32) {\n            hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK);  // inplace convert f32 to f16\n        }\n\n        const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);\n        for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {\n            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;\n            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);\n\n            // Wait for DMA\n            uint8_t * k_base = dma_queue_pop(dma).dst; // K\n            uint8_t * v_base = dma_queue_pop(dma).dst; // V\n            __fp16  * m_base = mask ? 
dma_queue_pop(dma).dst : NULL; // M\n\n            // FARF(HIGH, \"fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u\",\n            //              ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm,\n            //             (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));\n\n            // Inner loop processing the block from VTCM\n            uint32_t ic = 0;\n\n            // Process in sub-blocks of 32 (VLEN_FP32)\n            HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];\n            HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);\n            for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {\n                // 1. Compute scores\n                HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);\n\n                // 2. Softcap\n                if (factx->logit_softcap != 0.0f) {\n                    scores = hvx_vec_tanh_f32(scores);\n                    scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);\n                    scores = Q6_Vsf_equals_Vqf32(scores);\n                }\n\n                // 3. Mask\n                if (mask) {\n                    const __fp16 * mp = m_base + ic;\n                    HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;\n                    HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);\n                    HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);\n                    scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);\n                    scores = Q6_Vsf_equals_Vqf32(scores);\n                }\n\n                sb_scores[iv] = scores;\n                v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max\n            }\n\n            {\n                // 4. 
Online Softmax Update\n                HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);\n                HVX_Vector diff_vec  = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));\n                HVX_Vector ms_vec    = hvx_vec_exp_f32(diff_vec);\n                M_vec = M_new_vec;\n\n                hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);\n\n                HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);\n                for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {\n                    HVX_Vector scores = sb_scores[iv];\n                    HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);\n                    HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));\n\n                    p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));\n\n                    // 5. Accumulate V\n                    __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];\n                    hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));\n\n                    for (uint32_t j = 0; j < VLEN_FP32; j += 2) {\n                        const uint32_t  cur_ic = ic2 + j;\n                        const uint8_t * v_ptr  = v_base + cur_ic * factx->size_v_row_padded;\n                        hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);\n                    }\n                }\n\n                p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);\n                S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));\n            }\n\n            if (ic < current_block_size) {\n                // Sync scalars for leftover/next block if needed\n                float M = hvx_vec_get_f32(M_vec);\n                float S = hvx_vec_get_f32(S_vec);\n\n                // Leftover\n                for (; ic < 
current_block_size; ++ic) {\n                    float s_val;\n                    const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;\n                    hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);\n                    if (factx->logit_softcap != 0.0f) {\n                        s_val = factx->logit_softcap * tanhf(s_val);\n                    }\n\n                    if (mask) {\n                        const float m_val = m_base[ic];\n                        s_val += slope * m_val;\n                    }\n\n                    const float Mold = M;\n                    __fp16 vs = 1.0f;\n\n                    if (s_val > M) {\n                        M = s_val;\n                        HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);\n                        HVX_Vector ms_vec   = hvx_vec_exp_f32(diff_vec);\n                        hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);\n\n                        float ms = hvx_vec_get_f32(ms_vec);\n                        S = S * ms + vs;\n                    } else {\n                        HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);\n                        vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));\n                        S += vs;\n                    }\n\n                    const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;\n\n                    hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);\n                }\n\n                M_vec = hvx_vec_splat_f32(M);\n                S_vec = hvx_vec_splat_f32(S);\n            }\n\n            // Issue DMA for next+1 block (if exists)\n            if (ib + 2 < factx->n_blocks) {\n                const uint32_t next_ib = ib + 2;\n                const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;\n                const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);\n\n                // K\n                const uint8_t * k_src = (const uint8_t 
*) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);\n                dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);\n\n                // V\n                const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);\n                dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);\n\n                // Mask\n                if (mask) {\n                    const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);\n                    dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);\n                }\n\n                // FARF(HIGH, \"fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u\",\n                //         ith, ir, next_ib, iq1, iq2, iq3,\n                //         size_k_row, size_v_row, next_block_size,\n                //         (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));\n            }\n        }\n\n        // sinks\n        float M = hvx_vec_get_f32(M_vec);\n        float S = hvx_vec_get_f32(S_vec);\n\n        if (sinks) {\n            const float s = ((float *)((char *) sinks->data))[h];\n\n            float vs = 1.0f;\n\n            if (s > M) {\n                HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);\n                HVX_Vector ms_vec   = hvx_vec_exp_f32(diff_vec);\n                hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);\n\n                float ms = hvx_vec_get_f32(ms_vec);\n                S = S * ms + vs;\n            } else {\n                HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);\n                vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));\n                S += vs;\n            }\n        }\n\n        const float S_inv = S == 0.0f ? 
0.0f : 1.0f/S;\n        hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);\n\n        // Store result\n        // dst indices\n        const int i1 = iq1;\n        const int i2 = iq2;\n        const int i3 = iq3;\n\n        // dst is permuted\n        uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;\n\n        if (dst->type == HTP_TYPE_F32) {\n            hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);\n        } else if (dst->type == HTP_TYPE_F16) {\n            hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);\n        }\n    }\n}\n\nint op_flash_attn_ext(struct htp_ops_context * octx) {\n    const struct htp_tensor * q = &octx->src0;\n    const struct htp_tensor * k = &octx->src1;\n    const struct htp_tensor * v = &octx->src2;\n    const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;\n    const struct htp_tensor * dst = &octx->dst;\n\n    // Check support\n    if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    struct htp_fa_context factx;\n    factx.octx = octx;\n\n    factx.t_start = HAP_perf_get_qtimer_count();\n\n    factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);\n    factx.src0_div1  = init_fastdiv_values(q->ne[1]);\n\n    factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);\n    factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);\n    factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);\n    factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);\n\n    if (mask) {\n        factx.src3_div2 = init_fastdiv_values(mask->ne[2]);\n        factx.src3_div3 = init_fastdiv_values(mask->ne[3]);\n    }\n\n    factx.is_q_fp32 = (q->type == HTP_TYPE_F32);\n    factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 
4 : 2), 128);\n    factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);\n    factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);\n\n    size_t size_q_block = factx.size_q_row_padded * 1; // single row for now\n    factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;\n    factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;\n    factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);\n\n    factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;\n\n    float scale         = 1.0f;\n    float max_bias      = 0.0f;\n    float logit_softcap = 0.0f;\n\n    memcpy(&scale,         (float *) octx->op_params + 0, sizeof(float));\n    memcpy(&max_bias,      (float *) octx->op_params + 1, sizeof(float));\n    memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));\n\n    if (logit_softcap != 0.0f) {\n        scale /= logit_softcap;\n    }\n\n    factx.scale = scale;\n    factx.max_bias = max_bias;\n    factx.logit_softcap = logit_softcap;\n\n    uint32_t n_head = q->ne[2];\n    factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));\n    factx.m0 = powf(2.0f, -(max_bias       ) / factx.n_head_log2);\n    factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);\n\n    // total rows in q\n    const uint32_t neq0 = q->ne[0];\n    const uint32_t neq1 = q->ne[1];\n    const uint32_t neq2 = q->ne[2];\n    const uint32_t neq3 = q->ne[3];\n\n    factx.qrows = neq1*neq2*neq3;\n    factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads;\n\n    size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32\n\n    octx->src0_spad.size_per_thread = size_q_block * 1;\n    octx->src1_spad.size_per_thread = factx.size_k_block * 2;\n    octx->src2_spad.size_per_thread = factx.size_v_block * 2;\n    octx->src3_spad.size_per_thread = mask ? 
factx.size_m_block * 2 : 0;\n    octx->dst_spad.size_per_thread  = size_vkq_acc;\n\n    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;\n    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;\n    octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;\n    octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;\n    octx->dst_spad.size  = octx->dst_spad.size_per_thread  * octx->n_threads;\n\n    size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;\n\n    if (octx->ctx->vtcm_size < total_spad) {\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;\n    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;\n    octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;\n    octx->dst_spad.data  = octx->src3_spad.data + octx->src3_spad.size;\n\n    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);\n    }\n\n    return HTP_STATUS_OK;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/get-rows-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n#include \"hvx-utils.h\"\n\n// Shared state handed to every get_rows worker thread.\nstruct get_rows_context {\n    struct htp_ops_context * octx;\n    uint32_t src1_nrows_per_thread;                // src1 elements (dst rows) handled per thread\n    struct fastdiv_values get_rows_div_ne10;       // fast division by src1.ne[0]\n    struct fastdiv_values get_rows_div_ne10_ne11;  // fast division by src1.ne[0] * src1.ne[1]\n};\n\n// Loads the dims/strides used by get_rows into locals. nr is the total number of\n// src1 index elements, i.e. the number of dst rows to produce.\n#define get_rows_preamble \\\n    const uint32_t ne00 = octx->src0.ne[0]; \\\n    const uint32_t ne01 = octx->src0.ne[1]; \\\n    const uint32_t ne02 = octx->src0.ne[2]; \\\n    const uint32_t ne03 = octx->src0.ne[3]; \\\n                                            \\\n    const uint32_t ne10 = octx->src1.ne[0]; \\\n    const uint32_t ne11 = octx->src1.ne[1]; \\\n    const uint32_t ne12 = octx->src1.ne[2]; \\\n                                            \\\n    const uint32_t nb01 = octx->src0.nb[1]; \\\n    const uint32_t nb02 = octx->src0.nb[2]; \\\n    const uint32_t nb03 = octx->src0.nb[3]; \\\n                                            \\\n    const uint32_t nb10 = octx->src1.nb[0]; \\\n    const uint32_t nb11 = octx->src1.nb[1]; \\\n    const uint32_t nb12 = octx->src1.nb[2]; \\\n                                            \\\n    const uint32_t nb1 = octx->dst.nb[1];   \\\n    const uint32_t nb2 = octx->dst.nb[2];   \\\n    const uint32_t nb3 = octx->dst.nb[3];   \\\n                                            \\\n    const uint32_t nr = ne10 * ne11 * ne12;\n\n// Worker: for each src1 index element in this thread's range, copies the selected\n// F32 src0 row into the matching dst row.\nstatic void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {\n    struct get_rows_context * grctx = (struct get_rows_context *)data;\n    struct htp_ops_context * octx = grctx->octx;\n    get_rows_preamble;\n\n    // parallelize by src1 elements (which correspond to dst rows)\n    const uint32_t dr  = grctx->src1_nrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;\n\n    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);\n\n    for (uint32_t i = ir0; i < ir1; ++i) {\n        // unflatten i into src1 coordinates (i10, i11, i12)\n        const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);\n        const uint32_t rem = i - i12 * ne11 * ne10;\n        const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);\n        const uint32_t i10 = rem - i11 * ne10;\n\n        const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;\n\n        // Keep the index at full width: truncating an I64 index to 32 bits could\n        // silently map an out-of-range value onto a valid row.\n        const int64_t idx = is_i32 ? (int64_t) *(const int32_t *)src1_addr : *(const int64_t *)src1_addr;\n\n        if (idx < 0 || idx >= (int64_t) ne01) {\n            // invalid index, skip for now to avoid crash\n            continue;\n        }\n\n        const uint32_t i01 = (uint32_t) idx;\n\n        const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;\n        const uintptr_t dst_ptr  = octx->dst.data  + i10*nb1  + i11*nb2  + i12*nb3;\n        hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);\n    }\n}\n\n// GET_ROWS entry point: F32 src0/dst with I32 or I64 indices in src1.\n// Returns HTP_STATUS_NO_SUPPORT for any other type combination.\nint op_get_rows(struct htp_ops_context * octx) {\n    get_rows_preamble;\n\n    if (octx->src0.type != HTP_TYPE_F32) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    if (octx->dst.type != HTP_TYPE_F32) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {\n        return HTP_STATUS_OK;\n    }\n\n    // No indices -> nothing to copy. This also keeps n_threads below from being 0,\n    // which would turn the per-thread split into a division by zero.\n    if (nr == 0) {\n        return HTP_STATUS_OK;\n    }\n\n    const uint32_t n_threads = MIN(nr, octx->n_threads);\n\n    struct get_rows_context grctx;\n    grctx.octx = octx;\n    grctx.get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);\n    grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);\n\n    grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;\n\n    worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads);\n    return HTP_STATUS_OK;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hex-dma.c",
    "content": "#include \"hex-dma.h\"\n\n#include <stdbool.h>\n#include <stdlib.h>\n#include <string.h>\n\n#pragma clang diagnostic ignored \"-Wunused-function\"\n\n// Smallest power of two >= x (returns 1 for x <= 1).\nstatic inline uint32_t pow2_ceil(uint32_t x) {\n    if (x <= 1) {\n        return 1;\n    }\n    uint32_t p = 2;\n    x--;\n    while (x >>= 1) {\n        p <<= 1;\n    }\n    return p;\n}\n\n// Allocate a DMA queue with room for at least `capacity` descriptors, rounded up to a\n// power of two so ring indices can wrap with idx_mask. Returns NULL on allocation\n// failure; nothing is leaked in that case.\ndma_queue * dma_queue_create(size_t capacity) {\n    dma_queue * q = (dma_queue *) memalign(32, sizeof(dma_queue));\n    if (q == NULL) {\n        FARF(ERROR, \"%s: failed to allocate DMA queue\\n\", __FUNCTION__);\n        return NULL;\n    }\n\n    capacity = pow2_ceil(capacity);\n\n    memset(q, 0, sizeof(dma_queue));\n    q->capacity = capacity;\n    q->idx_mask = capacity - 1;\n\n    q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t));\n    q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr));\n\n    // Both arrays are required: check them before the memsets and release\n    // everything (free(NULL) is a no-op) if either allocation failed.\n    if (!q->desc || !q->dptr) {\n        FARF(ERROR, \"%s: failed to allocate DMA queue items\\n\", __FUNCTION__);\n        free(q->desc);\n        free(q->dptr);\n        free(q);\n        return NULL;\n    }\n\n    memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t));\n    memset(q->dptr, 0, capacity * sizeof(dma_ptr));\n\n    // The first dma_queue_push() dmlink()s its descriptor after this zeroed last slot.\n    q->tail = &q->desc[capacity - 1];\n\n    FARF(HIGH, \"dma-queue: capacity %u\\n\", (unsigned) capacity);\n\n    return q;\n}\n\n// Release a queue created by dma_queue_create(). Accepts NULL.\nvoid dma_queue_delete(dma_queue * q) {\n    if (!q) {\n        return;\n    }\n    free(q->desc);\n    free(q->dptr);\n    free(q);\n}\n\n// Pop (waiting for each transfer) until dma_queue_pop() returns a NULL dst,\n// which it does once the queue is empty.\nvoid dma_queue_flush(dma_queue * q) {\n    while (dma_queue_pop(q).dst != NULL) ;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hex-dma.h",
    "content": "#ifndef HTP_DMA_H\n#define HTP_DMA_H\n\n#include <HAP_farf.h>\n#include <hexagon_types.h>\n#include <stdbool.h>\n#include <stdint.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n// Destination/source pointer pair describing one queued transfer.\ntypedef struct {\n    void *dst;\n    const void *src;\n} dma_ptr;\n\n// Ring of UDMA type-1 descriptors. Indices wrap with idx_mask, so capacity is\n// expected to be a power of two (dma_queue_create rounds it up). One slot is\n// always kept free so that a full queue can be told apart from an empty one.\ntypedef struct {\n    hexagon_udma_descriptor_type1_t * desc;  // descriptor pointers\n    hexagon_udma_descriptor_type1_t * tail;  // tail pointer\n    dma_ptr                         * dptr;  // dst/src pointers\n    uint32_t                          push_idx;\n    uint32_t                          pop_idx;\n    uint32_t                          capacity;\n    uint32_t                          idx_mask;\n} dma_queue;\n\ndma_queue * dma_queue_create(size_t capacity);\nvoid        dma_queue_delete(dma_queue * q);\nvoid        dma_queue_flush(dma_queue * q);\n\n// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead\n// but those do not seem to always compile properly.\nstatic inline void dmstart(void * next) {\n    asm volatile(\" release(%0):at\" : : \"r\"(next));\n    asm volatile(\" dmstart(%0)\" : : \"r\"(next));\n}\n\n// Link descriptor `next` after `cur` (dma_queue_push uses this to extend the chain).\nstatic inline void dmlink(void * cur, void * next) {\n    asm volatile(\" release(%0):at\" : : \"r\"(next));\n    asm volatile(\" dmlink(%0, %1)\" : : \"r\"(cur), \"r\"(next));\n}\n\nstatic inline unsigned int dmpoll(void) {\n    unsigned int ret = 0;\n    asm volatile(\" %0 = dmpoll\" : \"=r\"(ret) : : \"memory\");\n    return ret;\n}\n\nstatic inline unsigned int dmwait(void) {\n    unsigned int ret = 0;\n    asm volatile(\" %0 = dmwait\" : \"=r\"(ret) : : \"memory\");\n    return ret;\n}\n\nstatic inline dma_ptr dma_make_ptr(void *dst, const void *src)\n{\n    dma_ptr p = { dst, src };\n    return p;\n}\n\n// Queue a 2D transfer of nrows rows of `width` bytes from dptr.src (row stride\n// src_row_size) to dptr.dst (row stride dst_row_size). Returns false without\n// queuing anything when the queue is full.\nstatic inline bool dma_queue_push(dma_queue * q,\n                                  dma_ptr     dptr,\n                                  size_t      dst_row_size,\n                                  size_t      src_row_size,\n                                  size_t      width, // width in bytes. number of bytes to transfer per row\n                                  size_t      nrows) {\n    if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {\n        FARF(ERROR, \"dma-push: queue full\\n\");\n        return false;\n    }\n\n    hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx];\n\n    desc->next           = NULL;\n    desc->length         = 0;\n    desc->desctype       = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;\n    // dst/src bypass flags are selected per HVX architecture version\n#if __HVX_ARCH__ >= 73\n    desc->dstbypass      = 1;\n    desc->srcbypass      = 1;\n#else\n    desc->dstbypass      = 0;\n    desc->srcbypass      = 1;\n#endif\n    desc->order          = 0;\n    desc->dstate         = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;\n    desc->src            = (void *) dptr.src;\n    desc->dst            = (void *) dptr.dst;\n    desc->allocation     = 0;\n    desc->padding        = 0;\n    desc->roiwidth       = width;\n    desc->roiheight      = nrows;\n    desc->srcstride      = src_row_size;\n    desc->dststride      = dst_row_size;\n    desc->srcwidthoffset = 0;\n    desc->dstwidthoffset = 0;\n\n    q->dptr[q->push_idx] = dptr;\n\n    dmlink(q->tail, desc);\n    q->tail = desc;\n\n    // FARF(ERROR, \"dma-push: i %u width %u nrows %d dst %p src %p\\n\", q->push_idx, width, nrows, dptr.dst, dptr.src);\n    q->push_idx = (q->push_idx + 1) & q->idx_mask;\n    return true;\n}\n\n// dma_queue_push() with width = src_row_size (each source row is copied in full).\nstatic inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q,\n                                              dma_ptr     dptr,\n                                              size_t      dst_row_size,\n                                              size_t      src_row_size,\n                                              size_t      nrows) {\n    return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);\n}\n\n// dma_queue_push() with width = dst_row_size (each destination row is filled in full).\nstatic inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q,\n                                              dma_ptr     dptr,\n                                              size_t      dst_row_size,\n                                              size_t      src_row_size,\n                                              size_t      nrows) {\n    return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);\n}\n\n// Wait for the oldest queued transfer to complete, dequeue it and return its\n// dst/src pair. Returns a NULL pair when the queue is empty.\nstatic inline dma_ptr dma_queue_pop(dma_queue * q) {\n    dma_ptr dptr  = { NULL };\n\n    if (q->push_idx == q->pop_idx) {\n        return dptr;\n    }\n\n    hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx];\n\n    // Wait for desc to complete\n    while (1) {\n        dmpoll();\n        if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) {\n            break;\n        }\n        // FARF(ERROR, \"dma-pop: waiting for DMA : %u\\n\", q->pop_idx);\n    }\n\n    dptr = q->dptr[q->pop_idx];\n\n    // FARF(ERROR, \"dma-pop: i %u dst %p src %p\\n\", q->pop_idx, dptr.dst, dptr.src);\n    q->pop_idx = (q->pop_idx + 1) & q->idx_mask;\n    return dptr;\n}\n\n// Dequeue the oldest entry without checking whether its transfer has completed.\n// Returns a NULL pair when the queue is empty.\nstatic inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {\n    dma_ptr dptr  = { NULL };\n\n    if (q->push_idx == q->pop_idx) {\n        return dptr;\n    }\n\n    dptr = q->dptr[q->pop_idx];\n\n    // FARF(ERROR, \"dma-pop-nowait: i %u dst %p src %p\\n\", q->pop_idx, dptr.dst, dptr.src);\n    q->pop_idx = (q->pop_idx + 1) & q->idx_mask;\n    return dptr;\n}\n\nstatic inline bool dma_queue_empty(dma_queue * q) {\n    return q->push_idx == q->pop_idx;\n}\n\n// Number of entries pushed but not yet popped.\nstatic inline uint32_t dma_queue_depth(dma_queue * q) {\n    return (q->push_idx - q->pop_idx) & q->idx_mask;\n}\n\nstatic inline uint32_t dma_queue_capacity(dma_queue * q) {\n    return q->capacity;\n}\n\n#ifdef __cplusplus\n}  // extern \"C\"\n#endif\n\n#endif /* HTP_DMA_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hex-dump.h",
    "content": "#ifndef HEX_DUMP_H\n#define HEX_DUMP_H\n\n#include <HAP_farf.h>\n\n#include <stdint.h>\n#include <stdio.h>\n\n// Debug helpers that print arrays as comma-separated values via FARF(HIGH).\n//\n// Each *_line helper formats into a fixed 1KB stack buffer. snprintf() returns\n// the length it would have written, so once the buffer fills p can move past\n// p_end; the `p < p_end` loop guard stops before the next call is handed a\n// negative (wrapped) size and overruns str. Values that do not fit are dropped.\n\nstatic inline void hex_dump_int8_line(const char * pref, const int8_t * x, int n) {\n    char str[1024], *p = str, *p_end = str + sizeof(str);\n    p += snprintf(p, p_end - p, \"%s: \", pref);\n    for (int i = 0; i < n && p < p_end; i++) {\n        p += snprintf(p, p_end - p, \"%d, \", x[i]);\n    }\n    FARF(HIGH, \"%s\\n\", str);\n}\n\nstatic inline void hex_dump_uint8_line(const char * pref, const uint8_t * x, uint32_t n) {\n    char str[1024], *p = str, *p_end = str + sizeof(str);\n    p += snprintf(p, p_end - p, \"%s: \", pref);\n    for (uint32_t i = 0; i < n && p < p_end; i++) {\n        p += snprintf(p, p_end - p, \"%d, \", x[i]);\n    }\n    FARF(HIGH, \"%s\\n\", str);\n}\n\nstatic inline void hex_dump_int32_line(const char * pref, const int32_t * x, uint32_t n) {\n    char str[1024], *p = str, *p_end = str + sizeof(str);\n    p += snprintf(p, p_end - p, \"%s: \", pref);\n    for (uint32_t i = 0; i < n && p < p_end; i++) {\n        p += snprintf(p, p_end - p, \"%d, \", (int) x[i]);\n    }\n    FARF(HIGH, \"%s\\n\", str);\n}\n\nstatic inline void hex_dump_f16_line(const char * pref, const __fp16 * x, uint32_t n) {\n    char str[1024], *p = str, *p_end = str + sizeof(str);\n    p += snprintf(p, p_end - p, \"%s: \", pref);\n    for (uint32_t i = 0; i < n && p < p_end; i++) {\n        p += snprintf(p, p_end - p, \"%.6f, \", (float) x[i]);\n    }\n    FARF(HIGH, \"%s\\n\", str);\n}\n\nstatic inline void hex_dump_f32_line(const char * pref, const float * x, uint32_t n) {\n    char str[1024], *p = str, *p_end = str + sizeof(str);\n    p += snprintf(p, p_end - p, \"%s: \", pref);\n    for (uint32_t i = 0; i < n && p < p_end; i++) {\n        p += snprintf(p, p_end - p, \"%.6f, \", x[i]);\n    }\n    FARF(HIGH, \"%s\\n\", str);\n}\n\n// Print n floats, 16 values per FARF line.\nstatic inline void hex_dump_f32(const char * pref, const float * x, uint32_t n) {\n    uint32_t n0 = n / 16;\n    uint32_t n1 = n % 16;\n\n    uint32_t i = 0;\n    for (; i < n0; i++) {\n        hex_dump_f32_line(pref, x + (16 * i), 16);\n    }\n    if (n1) {\n        hex_dump_f32_line(pref, x + (16 * i), n1);\n    }\n}\n\n// Print n fp16 values, 16 values per FARF line.\nstatic inline void hex_dump_f16(const char * pref, const __fp16 * x, uint32_t n) {\n    uint32_t n0 = n / 16;\n    uint32_t n1 = n % 16;\n\n    uint32_t i = 0;\n    for (; i < n0; i++) {\n        hex_dump_f16_line(pref, x + (16 * i), 16);\n    }\n    if (n1) {\n        hex_dump_f16_line(pref, x + (16 * i), n1);\n    }\n}\n\n#endif /* HEX_DUMP_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hex-fastdiv.h",
    "content": "#ifndef HEX_FASTDIV_H\n#define HEX_FASTDIV_H\n\n#include <stdint.h>\n\n// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.\n// Precompute mp (m' in the paper) and L such that division\n// can be computed using a multiply (high 32b of 64b result)\n// and a shift:\n//\n// n/d = (mulhi(n, mp) + n) >> L;\n//\n// The sum mulhi(n, mp) + n needs 33 bits (the paper's N+1-bit add) and L is\n// 32 when d > 2^31, so fastdiv() performs the add and the shift in 64 bits.\n// That keeps the quotient exact for every uint32_t n. d must be non-zero.\nstruct fastdiv_values {\n    uint32_t mp;\n    uint32_t l;\n};\n\nstatic inline struct fastdiv_values init_fastdiv_values(uint32_t d) {\n    struct fastdiv_values result = { 0, 0 };\n    // compute L = ceil(log2(d));\n    while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {\n        ++(result.l);\n    }\n\n    result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);\n    return result;\n}\n\nstatic inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {\n    // Compute high 32 bits of n * mp\n    const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32);  // mulhi(n, mp)\n    // add n, apply bit shift (64-bit: the sum can carry into bit 32 and l may be 32)\n    return (uint32_t) (((uint64_t) hi + n) >> vals->l);\n}\n\nstatic inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {\n    return n - fastdiv(n, vals) * d;\n}\n\n#endif /* HEX_FASTDIV_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hex-utils.h",
    "content": "#ifndef HEX_UTILS_H\n#define HEX_UTILS_H\n\n#include <stdbool.h>\n#include <stdint.h>\n\n#include \"hexagon_types.h\"\n\n#include \"hex-fastdiv.h\"\n#include \"hex-dump.h\"\n\n#ifndef MAX\n#define MAX(a, b) ((a) > (b) ? (a) : (b))\n#endif\n\n#ifndef MIN\n#define MIN(a, b) ((a) < (b) ? (a) : (b))\n#endif\n\nstatic inline uint64_t hex_get_cycles() {\n    uint64_t cycles = 0;\n    asm volatile(\" %0 = c15:14\\n\" : \"=r\"(cycles));\n    return cycles;\n}\n\nstatic inline uint64_t hex_get_pktcnt() {\n    uint64_t pktcnt;\n    asm volatile(\" %0 = c19:18\\n\" : \"=r\"(pktcnt));\n    return pktcnt;\n}\n\nstatic inline int32_t hex_is_aligned(void * addr, uint32_t align) {\n    return ((size_t) addr & (align - 1)) == 0;\n}\n\nstatic inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {\n    uint32_t left_off  = (size_t) addr & (chunk_size - 1);\n    uint32_t right_off = left_off + n;\n    return right_off <= chunk_size;\n}\n\nstatic inline uint32_t hex_round_up(uint32_t n, uint32_t m) {\n    return m * ((n + m - 1) / m);\n}\n\nstatic inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {\n    const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));\n    Q6_l2fetch_AP((void *) p, control);\n}\n\n#endif /* HEX_UTILS_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/htp-ctx.h",
    "content": "#ifndef HTP_CTX_H\n#define HTP_CTX_H\n\n#include \"hex-dma.h\"\n#include \"worker-pool.h\"\n\n#include <assert.h>\n#include <dspqueue.h>\n#include <stdatomic.h>\n#include <stdint.h>\n\n#define HTP_MAX_NTHREADS 10\n\n// Main context for htp DSP backend\nstruct htp_context {\n    dspqueue_t            queue;\n    dma_queue *           dma[HTP_MAX_NTHREADS];\n    worker_pool_context_t worker_pool;\n    uint32_t              n_threads;\n\n    int thread_id;\n    int thread_prio;\n\n    uint8_t * vtcm_base;\n    size_t    vtcm_size;\n    uint32_t  vtcm_rctx;\n\n    atomic_bool vtcm_valid;\n    atomic_bool vtcm_inuse;\n    atomic_bool vtcm_needs_release;\n\n    uint32_t opmask;\n};\n\n#endif /* HTP_CTX_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/htp-msg.h",
    "content": "#ifndef HTP_MSG_H\n#define HTP_MSG_H\n\n#include <assert.h>\n\n// ggml-common.h must be included prio to this header\n\n// Mask to enable various stages of the Ops.\n// Used for debugging and profiling.\nenum {\n    HTP_OPMASK_QUEUE    = (1 << 0),  // Enable Queueing (ie calls into the DSP)\n    HTP_OPMASK_QUANTIZE = (1 << 1),  // Enable Quantize\n    HTP_OPMASK_COMPUTE  = (1 << 2),  // Enable Compute\n};\n\n// Op flags\nenum {\n    HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0),  // Skip dynamic quantization (reuse quantized tensors)\n    HTP_OPFLAGS_SKIP_COMPUTE  = (1 << 1),  // Skip actual computation (used for profiling)\n    HTP_OPFLAGS_EARLY_WAKEUP  = (1 << 2)   // Send early wakeup notification\n};\n\nenum htp_status {\n    HTP_STATUS_OK             = 1,\n    HTP_STATUS_INTERNAL_ERR   = 2,\n    HTP_STATUS_NO_SUPPORT     = 3,\n    HTP_STATUS_INVAL_PARAMS   = 4,\n    HTP_STATUS_VTCM_TOO_SMALL = 5,\n};\n\n// The values must match the ggml_type.\n// Duplicated here because we can't include full ggml.h in the htp build.\n// We have some static_asserts in the cpp code to ensure things are in sync.\nenum htp_data_type {\n    HTP_TYPE_F32   = 0,\n    HTP_TYPE_F16   = 1,\n    HTP_TYPE_Q4_0  = 2,\n    HTP_TYPE_Q8_0  = 8,\n    HTP_TYPE_I32   = 26,\n    HTP_TYPE_I64   = 27,\n    HTP_TYPE_MXFP4 = 39,\n    HTP_TYPE_COUNT\n};\n\n// Do not reorder first 4 (used as an index)\nenum htp_op {\n    HTP_OP_MUL = 0,\n    HTP_OP_ADD = 1,\n    HTP_OP_SUB = 2,\n    HTP_OP_DIV = 3,\n    HTP_OP_MUL_MAT,\n    HTP_OP_MUL_MAT_ID,\n    HTP_OP_RMS_NORM,\n    HTP_OP_UNARY_SILU,\n    HTP_OP_UNARY_GELU,\n    HTP_OP_GLU_SWIGLU,\n    HTP_OP_GLU_SWIGLU_OAI,\n    HTP_OP_GLU_GEGLU,\n    HTP_OP_SOFTMAX,\n    HTP_OP_ADD_ID,\n    HTP_OP_ROPE,\n    HTP_OP_FLASH_ATTN_EXT,\n    HTP_OP_SET_ROWS,\n    HTP_OP_GET_ROWS,\n    HTP_OP_SCALE,\n    HTP_OP_CPY,\n    HTP_OP_ARGSORT,\n    HTP_OP_SQR,\n    HTP_OP_SQRT,\n    HTP_OP_SUM_ROWS,\n    HTP_OP_SSM_CONV,\n    INVALID\n};\n\nstatic inline size_t 
htp_t_block_size(uint32_t t) {\n    switch (t) {\n        case HTP_TYPE_F32:\n            return 1;\n        case HTP_TYPE_F16:\n            return 1;\n        case HTP_TYPE_Q4_0:\n            return QK4_0;\n        case HTP_TYPE_Q8_0:\n            return QK8_0;\n        case HTP_TYPE_MXFP4:\n            return QK_MXFP4;\n        default:\n            assert(0 && \"unsupported HTP data type\");\n    }\n    return 0;\n}\n\nstatic inline size_t htp_type_nbytes(uint32_t t) {\n    switch (t) {\n        case HTP_TYPE_F32:\n            return 4;\n        case HTP_TYPE_F16:\n            return 2;\n        case HTP_TYPE_Q4_0:\n            return sizeof(block_q4_0);\n        case HTP_TYPE_Q8_0:\n            return sizeof(block_q8_0);\n        case HTP_TYPE_MXFP4:\n            return sizeof(block_mxfp4);\n        default:\n            assert(0 && \"unsupported HTP data type\");\n    }\n    return 0;\n}\n\n// Internal types\n#define QK_Q4_0x4x2  256  // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)\n#define QK_Q8_0x4x2  256  // 4x Q8_0 blocks concat with next 4x Q8_0 blocks\n#define QK_MXFP4x4x2 256  // 4x MXFP4 blocks concat with next 4x MXFP4 blocks\n\n#define HTP_MAX_DIMS 4\n\nstruct htp_tensor {\n    uint32_t data;                // Buffer offset in the messages, and data pointer on the NSP\n    uint32_t type;                // Data type\n    uint32_t ne[HTP_MAX_DIMS];    // Number of elements\n    uint32_t nb[HTP_MAX_DIMS];    // Stride in bytes (see ggml.h ggml_tensor)\n};\n\n#define HTP_MAX_OP_PARAMS 64\n\nstruct htp_general_req {\n    uint32_t op;  // GGML/HTP Op\n    int32_t  op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];\n    // Params for the op, e.g. 
epsilon of RMS norm\n    uint32_t flags;          // Request flags\n\n    struct htp_tensor src0;  // Input0 tensor\n    struct htp_tensor src1;  // Input1 tensor\n    struct htp_tensor src2;  // Input2 tensor\n    struct htp_tensor src3;  // Input3 tensor\n    struct htp_tensor src4;  // Input4 tensor\n    struct htp_tensor dst;   // Output tensor\n\n    // should be multiple of 64 bytes (cacheline)\n};\n\nstruct htp_general_rsp {\n    uint32_t op;           // GGML/HTP Op\n    uint32_t status;       // HTP_STATUS_...\n    uint32_t prof_usecs;   // Number of usec per request\n    uint32_t prof_cycles;  // Number of cycles per request\n    uint32_t prof_pkts;    // Number of instruction packets per request\n    uint8_t  unused[44];   // Pad to 64 bytes\n};\n\n#define HTP_MAX_MESSAGE_SIZE   sizeof(struct htp_general_req)\n#define HTP_MAX_PACKET_BUFFERS 8\n\n#endif /* HTP_MSG_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/htp-ops.h",
    "content": "#ifndef HTP_OPS_H\n#define HTP_OPS_H\n\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"worker-pool.h\"\n\n#include <assert.h>\n#include <stdint.h>\n\n#include <hex-fastdiv.h>\n\n// ggml-common.h must be included prior to this header\n\nstruct htp_spad {\n    uint8_t * data;\n    size_t    stride;\n    size_t    size;\n    size_t    size_per_thread;\n};\n\nstruct htp_ops_context {\n    struct htp_context * ctx;\n\n    enum htp_op op;\n    int32_t     op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];\n\n    struct htp_tensor src0;\n    struct htp_tensor src1;\n    struct htp_tensor src2;\n    struct htp_tensor src3;\n    struct htp_tensor src4;\n    struct htp_tensor dst;\n\n    struct htp_spad src0_spad;\n    struct htp_spad src1_spad;\n    struct htp_spad src2_spad;\n    struct htp_spad src3_spad;\n    struct htp_spad dst_spad;\n\n    worker_pool_context_t * wpool;      // worker pool\n    uint32_t                n_threads;  // num threads\n\n    uint32_t flags;\n};\n\nint op_matmul(struct htp_ops_context * octx);\nint op_matmul_id(struct htp_ops_context * octx);\nint op_binary(struct htp_ops_context * octx);\nint op_unary(struct htp_ops_context * octx);\nint op_sum_rows(struct htp_ops_context * octx);\nint op_activations(struct htp_ops_context * octx);\nint op_softmax(struct htp_ops_context * octx);\nint op_add_id(struct htp_ops_context * octx);\nint op_rope(struct htp_ops_context * octx);\nint op_flash_attn_ext(struct htp_ops_context * octx);\nint op_set_rows(struct htp_ops_context * octx);\nint op_get_rows(struct htp_ops_context * octx);\nint op_cpy(struct htp_ops_context * octx);\nint op_argsort(struct htp_ops_context * octx);\nint op_ssm_conv(struct htp_ops_context * octx);\n\n#endif /* HTP_OPS_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/htp_iface.idl",
    "content": "// FastRPC IDL interface for GGML HTP\n\n#ifndef HTP_IDL\n#define HTP_IDL\n\n#include \"AEEStdDef.idl\"\n#include \"remote.idl\"\n\ninterface htp_iface : remote_handle64 {\n    AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);\n    AEEResult stop();\n    AEEResult enable_etm();\n    AEEResult disable_etm();\n};\n\n#endif /* HTP_IDL */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-arith.h",
    "content": "#ifndef HVX_ARITH_H\n#define HVX_ARITH_H\n\n#include <assert.h>\n#include <stddef.h>\n#include <stdint.h>\n#include <math.h>\n\n#include \"hvx-base.h\"\n#include \"hex-utils.h\"\n\n//\n// Binary operations (add, mul, sub)\n//\n\n#define UNUSED(x) (void)(x)\n\n#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \\\n    do {                                                                       \\\n        dst_type * restrict vdst  = (dst_type *) dst;                          \\\n        src0_type * restrict vsrc0 = (src0_type *) src0;                       \\\n        src1_type * restrict vsrc1 = (src1_type *) src1;                       \\\n                                                                               \\\n        const uint32_t epv  = 128 / (elem_size);                               \\\n        const uint32_t nvec = n / epv;                                         \\\n        const uint32_t nloe = n % epv;                                         \\\n                                                                               \\\n        uint32_t i = 0;                                                        \\\n                                                                               \\\n        _Pragma(\"unroll(4)\")                                                   \\\n        for (; i < nvec; i++) {                                                \\\n            vdst[i] = vec_op(vsrc0[i], vsrc1[i]);                              \\\n        }                                                                      \\\n        if (nloe) {                                                            \\\n            HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]);                         \\\n            vec_store((void *) &vdst[i], nloe * (elem_size), v);               \\\n        }                                                                      \\\n    } while(0)\n\n#if __HVX_ARCH__ < 79\n\n#define 
HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))\n#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))\n#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))\n\n#else\n\n#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)\n#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)\n#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)\n\n#endif\n\n#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b)\n#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b)\n#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b)\n\n// Generic macro to define alignment permutations for an op\n#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \\\nstatic inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    assert((uintptr_t) src0 % 128 == 0); \\\n    assert((uintptr_t) src1 % 128 == 0); \\\n    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    assert((uintptr_t) src0 % 128 == 0); \\\n    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    assert((uintptr_t) src1 % 128 == 0); \\\n    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) src0 % 128 == 0); \\\n    assert((uintptr_t) src1 % 128 == 0); \\\n    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) src0 % 128 == 0); \\\n    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) src1 % 128 == 0); \\\n    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \\\n} \\\n\nDEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float)\nDEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float)\nDEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float)\n\nDEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16)\nDEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16)\nDEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16)\n\n// Dispatcher logic\n#define HVX_BINARY_DISPATCHER(OP_NAME) \\\nstatic inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \\\n    if (hex_is_aligned((void *) dst, 128)) { \\\n      
  if (hex_is_aligned((void *) src0, 128)) { \\\n            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \\\n            else                                    OP_NAME##_aau(dst, src0, src1, num_elems); \\\n        } else { \\\n            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \\\n            else                                    OP_NAME##_auu(dst, src0, src1, num_elems); \\\n        } \\\n    } else { \\\n        if (hex_is_aligned((void *) src0, 128)) { \\\n            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \\\n            else                                    OP_NAME##_uau(dst, src0, src1, num_elems); \\\n        } else { \\\n            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \\\n            else                                    OP_NAME##_uuu(dst, src0, src1, num_elems); \\\n        } \\\n    } \\\n}\n\nHVX_BINARY_DISPATCHER(hvx_add_f32)\nHVX_BINARY_DISPATCHER(hvx_sub_f32)\nHVX_BINARY_DISPATCHER(hvx_mul_f32)\n\nHVX_BINARY_DISPATCHER(hvx_add_f16)\nHVX_BINARY_DISPATCHER(hvx_sub_f16)\nHVX_BINARY_DISPATCHER(hvx_mul_f16)\n\n// Mul-Mul Optimized\nstatic inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src0 % 128 == 0);\n    assert((unsigned long) src1 % 128 == 0);\n    assert((unsigned long) src2 % 128 == 0);\n\n    HVX_Vector * restrict vdst  = (HVX_Vector *) dst;\n    HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0;\n    HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1;\n    HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2;\n\n    const uint32_t elem_size = sizeof(float);\n    const uint32_t epv  = 128 / elem_size;\n    const uint32_t nvec = num_elems / epv;\n    const uint32_t nloe = 
num_elems % epv;\n\n    uint32_t i = 0;\n\n    _Pragma(\"unroll(4)\")\n    for (; i < nvec; i++) {\n        HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);\n        vdst[i] = HVX_OP_MUL_F32(v1, vsrc2[i]);\n    }\n\n    if (nloe) {\n        HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);\n        HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]);\n        hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);\n    }\n}\n\n// Scalar Operations\n\n#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro)   \\\n    do {                                                                       \\\n        dst_type * restrict vdst = (dst_type *) dst;                           \\\n        src_type * restrict vsrc = (src_type *) src;                           \\\n                                                                               \\\n        const uint32_t epv  = 128 / (elem_size);                               \\\n        const uint32_t nvec = n / epv;                                         \\\n        const uint32_t nloe = n % epv;                                         \\\n                                                                               \\\n        uint32_t i = 0;                                                        \\\n                                                                               \\\n        _Pragma(\"unroll(4)\")                                                   \\\n        for (; i < nvec; i++) {                                                \\\n            HVX_Vector v = vsrc[i];                                            \\\n            vdst[i] = scalar_op_macro(v);                                      \\\n        }                                                                      \\\n        if (nloe) {                                                            \\\n            HVX_Vector v = vsrc[i];                                            \\\n            v = scalar_op_macro(v);       
                                     \\\n            vec_store((void *) &vdst[i], nloe * (elem_size), v);               \\\n        }                                                                      \\\n    } while(0)\n\n#define HVX_OP_ADD_SCALAR_F32(v) \\\n    ({ \\\n        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \\\n        HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \\\n        Q6_V_vmux_QVV(pred_inf, inf, out); \\\n    })\n\n#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec)\n#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec)\n\n#define HVX_OP_ADD_SCALAR_F16(v) \\\n    ({ \\\n        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \\\n        HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \\\n        Q6_V_vmux_QVV(pred_inf, inf, out); \\\n    })\n\n#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec)\n#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec)\n\n// Scalar Variants\n\n// Generic macro to define alignment permutations for an op\n#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \\\nstatic inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \\\n    const HVX_Vector val_vec = SPLAT_MACRO(val); \\\n    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    assert((uintptr_t) src % 128 == 0); \\\n    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \\\n    const HVX_Vector val_vec = SPLAT_MACRO(val); \\\n    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \\\n} \\\nstatic inline void 
OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \\\n    const HVX_Vector val_vec = SPLAT_MACRO(val); \\\n    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \\\n    assert((uintptr_t) src % 128 == 0); \\\n    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \\\n} \\\nstatic inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \\\n    const HVX_Vector val_vec = SPLAT_MACRO(val); \\\n    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \\\n    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \\\n} \\\n\nDEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float)\nDEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float)\nDEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float)\n\nDEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16)\nDEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16)\nDEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16)\n\n// Dispatcher logic\n#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \\\nstatic inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \\\n    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \\\n        OP_NAME##_aa(dst, src, val, num_elems); \\\n    } else if (hex_is_aligned((void *) dst, 128)) { \\\n        OP_NAME##_au(dst, src, val, num_elems); \\\n    } else if (hex_is_aligned((void *) src, 128)) { \\\n        OP_NAME##_ua(dst, src, val, num_elems); \\\n    } else { \\\n    
    OP_NAME##_uu(dst, src, val, num_elems); \\\n    } \\\n}\n\nHVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float)\nHVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float)\nHVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float)\n\nHVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16)\nHVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16)\nHVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16)\n\n// MIN Scalar variants\n\n#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v)\n\nstatic inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {\n    const HVX_Vector val_vec = hvx_vec_splat_f32(val);\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);\n}\n\nstatic inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {\n    const HVX_Vector val_vec = hvx_vec_splat_f32(val);\n    assert((unsigned long) dst % 128 == 0);\n    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);\n}\n\nstatic inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {\n    const HVX_Vector val_vec = hvx_vec_splat_f32(val);\n    assert((unsigned long) src % 128 == 0);\n    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);\n}\n\nstatic inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {\n    const HVX_Vector val_vec = hvx_vec_splat_f32(val);\n    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);\n}\n\nstatic inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {\n    if 
(hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {\n        hvx_min_scalar_f32_aa(dst, src, val, num_elems);\n    } else if (hex_is_aligned((void *) dst, 128)) {\n        hvx_min_scalar_f32_au(dst, src, val, num_elems);\n    } else if (hex_is_aligned((void *) src, 128)) {\n        hvx_min_scalar_f32_ua(dst, src, val, num_elems);\n    } else {\n        hvx_min_scalar_f32_uu(dst, src, val, num_elems);\n    }\n}\n\n// CLAMP Scalar variants\n\n#define HVX_OP_CLAMP_SCALAR(v) \\\n    ({ \\\n        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \\\n        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \\\n        HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \\\n        Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \\\n    })\n\nstatic inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {\n    const HVX_Vector min_vec = hvx_vec_splat_f32(min);\n    const HVX_Vector max_vec = hvx_vec_splat_f32(max);\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);\n}\n\nstatic inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {\n    const HVX_Vector min_vec = hvx_vec_splat_f32(min);\n    const HVX_Vector max_vec = hvx_vec_splat_f32(max);\n    assert((unsigned long) dst % 128 == 0);\n    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);\n}\n\nstatic inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {\n    const HVX_Vector min_vec = hvx_vec_splat_f32(min);\n    const HVX_Vector max_vec = hvx_vec_splat_f32(max);\n    assert((unsigned long) src % 128 == 0);\n    
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);\n}\n\nstatic inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {\n    const HVX_Vector min_vec = hvx_vec_splat_f32(min);\n    const HVX_Vector max_vec = hvx_vec_splat_f32(max);\n    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);\n}\n\nstatic inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {\n    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {\n        hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems);\n    } else if (hex_is_aligned((void *) dst, 128)) {\n        hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems);\n    } else if (hex_is_aligned((void *) src, 128)) {\n        hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems);\n    } else {\n        hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems);\n    }\n}\n\n//\n// Square\n//\n\n#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store)           \\\n    do {                                                                   \\\n        dst_type * restrict vdst  = (dst_type *) dst;                      \\\n        src_type * restrict vsrc = (src_type *) src;                       \\\n                                                                           \\\n        const uint32_t elem_size = sizeof(float);                          \\\n        const uint32_t epv  = 128 / elem_size;                             \\\n        const uint32_t nvec = n / epv;                                     \\\n        const uint32_t nloe = n % epv;                                     \\\n                                                                           \\\n        uint32_t i = 0;                                                    \\\n            
                                                               \\\n        _Pragma(\"unroll(4)\")                                               \\\n        for (; i < nvec; i++) {                                            \\\n            vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]);                        \\\n        }                                                                  \\\n        if (nloe) {                                                        \\\n            HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]);                   \\\n            vec_store((void *) &vdst[i], nloe * elem_size, v);             \\\n        }                                                                  \\\n    } while(0)\n\nstatic inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) src % 128 == 0);\n    hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {\n    if (hex_is_aligned((void *) dst, 128)) {\n        if (hex_is_aligned((void *) src, 128)) {\n            hvx_sqr_f32_aa(dst, src, num_elems);\n        } else {\n            hvx_sqr_f32_au(dst, src, num_elems);\n        }\n    } else {\n        if 
(hex_is_aligned((void *) src, 128)) {\n            hvx_sqr_f32_ua(dst, src, num_elems);\n        } else {\n            hvx_sqr_f32_uu(dst, src, num_elems);\n        }\n    }\n}\n\n#undef HVX_OP_ADD_F32\n#undef HVX_OP_SUB_F32\n#undef HVX_OP_MUL_F32\n#undef HVX_OP_ADD_F16\n#undef HVX_OP_SUB_F16\n#undef HVX_OP_MUL_F16\n#undef hvx_arith_loop_body\n#undef HVX_OP_ADD_SCALAR_F32\n#undef HVX_OP_SUB_SCALAR_F32\n#undef HVX_OP_MUL_SCALAR_F32\n#undef HVX_OP_ADD_SCALAR_F16\n#undef HVX_OP_SUB_SCALAR_F16\n#undef HVX_OP_MUL_SCALAR_F16\n#undef hvx_scalar_loop_body\n#undef HVX_OP_MIN_SCALAR\n#undef HVX_OP_CLAMP_SCALAR\n#undef DEFINE_HVX_BINARY_OP_VARIANTS\n#undef HVX_BINARY_DISPATCHER\n#undef UNUSED\n\n#endif // HVX_ARITH_H\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-base.h",
    "content": "#ifndef HVX_BASE_H\n#define HVX_BASE_H\n\n#include <stdbool.h>\n#include <stdint.h>\n\n#include \"hex-utils.h\"\n#include \"hvx-types.h\"\n\nstatic inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {\n    // Rotate as needed.\n    v = Q6_V_vlalign_VVR(v, v, (size_t) dst);\n\n    uint32_t left_off  = (size_t) dst & 127;\n    uint32_t right_off = left_off + n;\n\n    HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst);\n    HVX_VectorPred qr     = Q6_Q_vsetq2_R(right_off);\n\n    if (right_off > 128) {\n        Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v);\n        // all 1's\n        qr = Q6_Q_vcmp_eq_VbVb(v, v);\n    }\n\n    ql_not = Q6_Q_or_QQn(ql_not, qr);\n    Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v);\n}\n\nstatic inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) {\n    assert((unsigned long) dst % 128 == 0);\n    HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n));\n    Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v);\n}\n\nstatic inline HVX_Vector hvx_vec_splat_f32(float v) {\n    union { float  f; uint32_t i; } u = { .f = v };\n    return Q6_V_vsplat_R(u.i);\n}\n\nstatic inline HVX_Vector hvx_vec_splat_f16(_Float16 v) {\n    union { __fp16 f; uint16_t i; } u = { .f = v };\n    return Q6_Vh_vsplat_R(u.i);\n}\n\nstatic inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {\n    // vdelta control to replicate first 4 bytes across all elements\n    static const uint8_t __attribute__((aligned(128))) repl[128] = {\n        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,\n        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,\n        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,\n        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,\n        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 
0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,\n        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,\n        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,\n        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,\n    };\n\n    HVX_Vector ctrl = *(HVX_Vector *) repl;\n    return Q6_V_vdelta_VV(v, ctrl);\n}\n\nstatic inline float hvx_vec_get_f32(HVX_Vector v) {\n    float __attribute__((aligned(128))) x;\n    hvx_vec_store_a(&x, 4, v);\n    return x;\n}\n\nstatic inline int32_t hvx_vec_get_i32(HVX_Vector v) {\n    int32_t __attribute__((aligned(128))) x;\n    hvx_vec_store_a(&x, 4, v);\n    return x;\n}\n\nstatic inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {\n    // abs by clearing the fp16 sign bit\n    HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);\n    return Q6_V_vand_VV(v, mask);\n}\n\nstatic inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) {\n    // neg by flipping (xor) the fp16 sign bit\n    HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);\n    return Q6_V_vxor_VV(v, mask);\n}\n\nstatic inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) {\n    // abs by clearing the fp32 sign bit\n    HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);\n    return Q6_V_vand_VV(v, mask);\n}\n\nstatic inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) {\n#if __HVX_ARCH__ > 75\n    return Q6_Vsf_vfneg_Vsf(v);\n#else\n    // neg by flipping (xor) the fp32 sign bit\n    HVX_Vector mask = Q6_V_vsplat_R(0x80000000);\n    return Q6_V_vxor_VV(v, mask);\n#endif  // __HVX_ARCH__ > 75\n}\n\nstatic inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {\n    const HVX_Vector vnan_exp  = Q6_Vh_vsplat_R(0x7C00);\n    const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF);\n\n    // get pred of which are NaN, i.e., exponent bits all 1s and fraction bits non 0s\n    HVX_VectorPred p_exp  = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp);\n    HVX_VectorPred p_frac = 
Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_frac), vnan_exp));\n    return Q6_Q_and_QQ(p_exp, p_frac);\n}\n\nstatic inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {\n    const HVX_Vector zero = Q6_V_vsplat_R(0);\n    HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);\n    HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);\n    HVX_Vector  v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));\n\n#if __HVX_ARCH__ < 79\n    // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)\n    const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);\n    HVX_VectorPred nan = hvx_vec_is_nan_f16(v);\n    v = Q6_V_vmux_QVV(nan, neg_inf, v);\n#endif\n\n    return v;\n}\n\n/* Q6_Vsf_equals_Vw is only available on v73+.*/\n#if __HVX_ARCH__ < 73\nstatic inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)\n{\n    HVX_Vector const vzero = Q6_V_vzero();\n    HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);\n    HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);\n    HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);\n    HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);\n    HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);\n    HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));\n    return ret;\n}\n\nstatic inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)\n{\n    return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));\n}\n#endif\n\nstatic inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {\n    // This looks complicated.\n    // Ideally should just be Q6_Vh_equals_Vhf(vin)\n    // but that instruction does not do proper rounding.\n\n    // convert to qf32, multiplying by 1.0 in the process.\n    HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));\n\n    // 'in-range' values are +/-32752.\n    // add 192K to it, convert to sf\n    HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);\n    HVX_Vector vsf_0 = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));\n    HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));\n\n    // for in-range cases, result is {163856 ... 229360} (196608 +/- 32752) so the exponent is always 144.\n    // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.\n    // Start by <<10 to get the final 'sign' bit in bit 15...\n    vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);\n    vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);\n\n    // now round (with saturation) down to 16 bits\n    return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);\n}\n\n#if __HVX_ARCH__ < 79\n\nstatic inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y)\n{\n    HVX_VectorPair m = Q6_Wqf32_vmpy_VhfVhf(x, y);\n    HVX_Vector a0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(m), Q6_V_lo_W(acc)));\n    HVX_Vector a1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(m), Q6_V_hi_W(acc)));\n    return Q6_W_vcombine_VV(a1, a0);\n}\n\n#else\n\nstatic inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y)\n{\n    return Q6_Wsf_vmpyacc_WsfVhfVhf(acc, x, y);\n}\n\n#endif\n\n#if __HVX_ARCH__ < 79\n\nstatic inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)\n{\n    const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16\n    const HVX_Vector one    = Q6_Vh_vsplat_R(0x3C00); //  1.0 in IEEE FP16\n    HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);\n    HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);\n    HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));\n    HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));\n    return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));\n}\n\nstatic inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)\n{\n    const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16\n    const HVX_Vector one    = Q6_Vh_vsplat_R(0x3C00); //  1.0 in IEEE FP16\n    
HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);\n    HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);\n    HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));\n    HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));\n    return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));\n}\n\nstatic inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)\n{\n    return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));\n}\n\n#else\n\nstatic inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)\n{\n    return Q6_Vhf_vadd_VhfVhf(a, b);\n}\n\nstatic inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)\n{\n    return Q6_Vhf_vsub_VhfVhf(a, b);\n}\n\nstatic inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)\n{\n    return Q6_Vhf_vmpy_VhfVhf(a, b);\n}\n\n#endif // __HVX_ARCH__ < 79\n\n#endif /* HVX_BASE_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-copy.h",
    "content": "#ifndef HVX_COPY_H\n#define HVX_COPY_H\n\n#include <assert.h>\n#include <stddef.h>\n#include <stdint.h>\n\n#include \"hvx-base.h\"\n\n#define hvx_splat_loop_body(dst_type, vec_store)                 \\\n    do {                                                         \\\n        dst_type * restrict vdst = (dst_type *) dst;             \\\n                                                                 \\\n        uint32_t nvec = n / (128 / elem_size);                   \\\n        uint32_t nloe = n % (128 / elem_size);                   \\\n                                                                 \\\n        uint32_t i = 0;                                          \\\n                                                                 \\\n        _Pragma(\"unroll(4)\")                                     \\\n        for (; i < nvec; i++) {                                  \\\n            vdst[i] = src;                                       \\\n        }                                                        \\\n        if (nloe) {                                              \\\n            vec_store((void *) &vdst[i], nloe * elem_size, src); \\\n        }                                                        \\\n    } while(0)\n\nstatic inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {\n    assert((unsigned long) dst % 128 == 0);\n    hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {\n    hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {\n    hvx_splat_a(dst,  hvx_vec_splat_f32(v), n, sizeof(float));\n}\n\nstatic inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {\n    hvx_splat_u(dst,  hvx_vec_splat_f32(v), n, sizeof(float));\n}\n\nstatic inline void hvx_splat_f16_a(uint8_t 
* restrict dst, _Float16 v, uint32_t n) {\n    // dst must be 128-byte aligned: take the aligned-store path, matching hvx_splat_f32_a.\n    hvx_splat_a(dst,  hvx_vec_splat_f16(v), n, sizeof(__fp16));\n}\n\nstatic inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) {\n    hvx_splat_u(dst,  hvx_vec_splat_f16(v), n, sizeof(__fp16));\n}\n\n#define hvx_copy_loop_body(dst_type, src_type, vec_store)            \\\n    do {                                                             \\\n        dst_type * restrict vdst = (dst_type *) dst;                 \\\n        src_type * restrict vsrc = (src_type *) src;                 \\\n                                                                     \\\n        const uint32_t epv  = 128 / elem_size;                       \\\n        const uint32_t nvec = n / epv;                               \\\n        const uint32_t nloe = n % epv;                               \\\n                                                                     \\\n        uint32_t i = 0;                                              \\\n                                                                     \\\n        _Pragma(\"unroll(4)\")                                         \\\n        for (; i < nvec; i++) { vdst[i] = vsrc[i]; }                 \\\n        if (nloe) {                                                  \\\n            vec_store((void *) &vdst[i], nloe * elem_size, vsrc[i]); \\\n        }                                                            \\\n    } while(0)\n\n// Generic copy routines\nstatic inline void hvx_copy_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_copy_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_copy_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {\n    assert((unsigned long) dst % 128 == 0);\n    hvx_copy_loop_body(HVX_Vector, HVX_UVector, 
hvx_vec_store_a);\n}\n\nstatic inline void hvx_copy_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {\n    assert((unsigned long) src % 128 == 0);\n    hvx_copy_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_copy_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {\n    hvx_copy_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\n// copy n fp16 elements : source and destination are aligned to HVX Vector (128)\nstatic inline void hvx_copy_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_aa(dst, src, n, sizeof(__fp16));\n}\n\n// copy n fp16 elements : source is potentially unaligned, destination is aligned\nstatic inline void hvx_copy_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_au(dst, src, n, sizeof(__fp16));\n}\n\n// copy n fp16 elements : source is aligned, destination is potentially unaligned\nstatic inline void hvx_copy_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_ua(dst, src, n, sizeof(__fp16));\n}\n\n// copy n fp16 elements : source and destination are potentially unaligned\nstatic inline void hvx_copy_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_uu(dst, src, n, sizeof(__fp16));\n}\n\n// copy n fp32 elements : source and destination are aligned to HVX Vector (128)\nstatic inline void hvx_copy_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_aa(dst, src, n, sizeof(float));\n}\n\n// copy n fp32 elements : source is aligned, destination is unaligned\nstatic inline void hvx_copy_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_ua(dst, src, n, sizeof(float));\n}\n\n// copy n fp32 elements : source is unaligned, destination is aligned\nstatic inline void hvx_copy_f32_au(uint8_t * restrict 
dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_au(dst, src, n, sizeof(float));\n}\n\n// copy n fp32 elements : source is unaligned, destination unaligned\nstatic inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_uu(dst, src, n, sizeof(float));\n}\n\n//// fp32 -> fp16\n\n#define hvx_copy_f16_f32_loop_body(dst_type, src_type, vec_store)                   \\\n    do {                                                                            \\\n        dst_type * restrict vdst = (dst_type *) dst;                                \\\n        src_type * restrict vsrc = (src_type *) src;                                \\\n                                                                                    \\\n        const uint32_t elem_size = sizeof(__fp16);                                  \\\n        const uint32_t epv  = 128 / elem_size;                                      \\\n        const uint32_t nvec = n / epv;                                              \\\n        const uint32_t nloe = n % epv;                                              \\\n                                                                                    \\\n        uint32_t i = 0;                                                             \\\n                                                                                    \\\n        _Pragma(\"unroll(4)\")                                                        \\\n        for (; i < nvec; i++) {                                                     \\\n            vdst[i] = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]);                 \\\n        }                                                                           \\\n        if (nloe) {                                                                 \\\n            HVX_Vector v = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]);            \\\n            vec_store((void *) &vdst[i], nloe * elem_size, v);     
                 \\\n        }                                                                           \\\n    } while(0)\n\n// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is aligned\nstatic inline void hvx_copy_f16_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\n// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned\nstatic inline void hvx_copy_f16_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);\n}\n\n// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned\nstatic inline void hvx_copy_f16_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) src % 128 == 0);\n    hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\n\n// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned\nstatic inline void hvx_copy_f16_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\n//// fp16 -> fp32\n\n#define hvx_copy_f32_f16_loop_body(dst_type, src_type, vec_store)                   \\\n    do {                                                                            \\\n        dst_type * restrict vdst = (dst_type *) dst;                                \\\n        src_type * restrict vsrc = (src_type *) src;                                \\\n                                                                                    \\\n        const HVX_Vector one = hvx_vec_splat_f16(1.0);                          
    \\\n                                                                                    \\\n        const uint32_t elem_size = sizeof(__fp16);                                  \\\n        const uint32_t epv  = 128 / elem_size;                                      \\\n        const uint32_t nvec = n / epv;                                              \\\n              uint32_t nloe = n % epv;                                              \\\n                                                                                    \\\n        uint32_t i = 0;                                                             \\\n                                                                                    \\\n        _Pragma(\"unroll(4)\")                                                        \\\n        for (i = 0; i < nvec; ++i) {                                                \\\n            HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \\\n            vdst[i*2]   = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p));                        \\\n            vdst[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p));                        \\\n        }                                                                           \\\n                                                                                    \\\n        if (nloe) {                                                                 \\\n            HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \\\n                                                                                    \\\n            HVX_Vector vd = Q6_V_lo_W(p);                                           \\\n            i = 2 * i;                                                              \\\n                                                                                    \\\n            if (nloe >= 32) {                                                       \\\n                vdst[i] = Q6_Vsf_equals_Vqf32(vd);    
                              \\\n                nloe -= 32; ++i; vd = Q6_V_hi_W(p);                                 \\\n            }                                                                       \\\n                                                                                    \\\n            if (nloe) {                                                             \\\n                vd = Q6_Vsf_equals_Vqf32(vd);                                       \\\n                hvx_vec_store_u(&vdst[i], nloe * sizeof(float), vd);                \\\n            }                                                                       \\\n        }                                                                           \\\n    } while(0)\n\n// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is aligned\nstatic inline void hvx_copy_f32_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\n// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is aligned\nstatic inline void hvx_copy_f32_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);\n}\n\n// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is unaligned\nstatic inline void hvx_copy_f32_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) src % 128 == 0);\n    hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\n\n// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is unaligned\nstatic inline void hvx_copy_f32_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, 
uint32_t n) {\n    hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\n#endif // HVX_COPY_H\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-div.h",
    "content": "#ifndef HVX_DIV_H\n#define HVX_DIV_H\n\n#include <HAP_farf.h>\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <stddef.h>\n#include <stdint.h>\n\n#include \"hvx-base.h\"\n#include \"hex-utils.h\"\n#include \"hvx-inverse.h\"\n#include \"hvx-arith.h\"\n\n#if __HVX_ARCH__ < 79\n#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))\n#else\n#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)\n#endif\n\n// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32.\nstatic inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) {\n#if __HVX_ARCH__ < 79\n    HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);\n    HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32));\n    HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32));\n#else\n    HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);\n    HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32);\n    HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32);\n#endif\n\n    HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const);\n    HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const);\n\n#if __HVX_ARCH__ < 79\n    HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);\n#else\n    HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);\n#endif\n    return res;\n}\n\n#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store)                     \\\n    do {                                                                                \\\n        dst_type * restrict vdst = (dst_type *) dst;                                    \\\n        src_type * restrict vsrc = (src_type *) src;                                    \\\n        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                     \\\n                                  
                                                      \\\n        const uint32_t nvec = n / VLEN_FP16;                                            \\\n        const uint32_t nloe = n % VLEN_FP16;                                            \\\n                                                                                        \\\n        uint32_t i = 0;                                                                 \\\n                                                                                        \\\n        _Pragma(\"unroll(4)\")                                                            \\\n        for (; i < nvec; i++) {                                                         \\\n            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \\\n            vdst[i] = res;                                                              \\\n        }                                                                               \\\n        if (nloe) {                                                                     \\\n            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \\\n            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                      \\\n        }                                                                               \\\n    } while(0)\n\nstatic inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {\n    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));\n    assert((uintptr_t) dst % 128 == 0);\n    assert((uintptr_t) src % 128 == 0);\n    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\nstatic inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {\n    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));\n    assert((uintptr_t) dst % 128 == 0);\n    
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);\n}\nstatic inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {\n    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));\n    assert((uintptr_t) src % 128 == 0);\n    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\nstatic inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {\n    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));\n    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\n// Compute the fp16 quotient vec1 / vec2 in fp32: widen both fp16 inputs to fp32, multiply vec1 by the reciprocal of vec2 from hvx_vec_inverse_f32_guard, then convert the result back to fp16.\nstatic inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {\n#if __HVX_ARCH__ < 79\n    // Convert first input to fp32\n    HVX_VectorPair vec1_to_f32   = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0);  // *1.0\n    HVX_Vector     vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32));\n    HVX_Vector     vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32));\n\n    // Convert second input to fp32\n    HVX_VectorPair vec2_to_f32   = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0);  // *1.0\n    HVX_Vector     vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32));\n    HVX_Vector     vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32));\n#else\n    // Convert first input to fp32\n    HVX_VectorPair vec1_to_f32   = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0);  // *1.0\n    HVX_Vector     vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32);\n    HVX_Vector     vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32);\n\n    // Convert second input to fp32\n    HVX_VectorPair vec2_to_f32   = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0);  // *1.0\n    HVX_Vector     vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32);\n  
  HVX_Vector     vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32);\n#endif\n\n    // Inverse second input in fp32\n    HVX_Vector     vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask);\n    HVX_Vector     vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask);\n\n    // Multiply first input by inverse of second, in fp32\n    HVX_Vector     div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0);\n    HVX_Vector     div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1);\n\n    // Convert back to fp16\n#if __HVX_ARCH__ < 79\n    HVX_Vector     recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);\n#else\n    HVX_Vector     recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);\n#endif\n\n    return recip;\n}\n\n#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store)                  \\\n    do {                                                                                  \\\n        dst_type * restrict vdst = (dst_type *) dst;                                      \\\n        src0_type * restrict vsrc0 = (src0_type *) src0;                                  \\\n        src1_type * restrict vsrc1 = (src1_type *) src1;                                  \\\n                                                                                          \\\n        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);                        \\\n        const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                 \\\n                                                                                          \\\n        const uint32_t nvec = n / VLEN_FP16;                                              \\\n        const uint32_t nloe = n % VLEN_FP16;                                              \\\n                                                                                          \\\n        uint32_t i = 0;                                                                   \\\n                         
                                                                 \\\n        _Pragma(\"unroll(4)\")                                                              \\\n        for (; i < nvec; i++) {                                                           \\\n            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \\\n            vdst[i] = res;                                                                \\\n        }                                                                                 \\\n        if (nloe) {                                                                       \\\n            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \\\n            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                        \\\n        }                                                                                 \\\n    } while(0)\n\n#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store)             \\\n    do {                                                                             \\\n        dst_type * restrict vdst = (dst_type *) dst;                                 \\\n        src0_type * restrict vsrc0 = (src0_type *) src0;                             \\\n        src1_type * restrict vsrc1 = (src1_type *) src1;                             \\\n                                                                                     \\\n        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);                   \\\n                                                                                     \\\n        const uint32_t nvec = n / VLEN_FP32;                                         \\\n        const uint32_t nloe = n % VLEN_FP32;                                         \\\n                                                                                     \\\n        uint32_t i = 0;                                               
               \\\n                                                                                     \\\n        _Pragma(\"unroll(4)\")                                                         \\\n        for (; i < nvec; i++) {                                                      \\\n            HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \\\n            HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1);                     \\\n            vdst[i] = res;                                                           \\\n        }                                                                            \\\n        if (nloe) {                                                                  \\\n            HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \\\n            HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1);                     \\\n            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res);                   \\\n        }                                                                            \\\n    } while(0)\n\n// Generic macro to define alignment permutations for an op\n#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \\\nstatic inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    assert((uintptr_t) src0 % 128 == 0); \\\n    assert((uintptr_t) src1 % 128 == 0); \\\n    OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \\\n} \\\nstatic inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    assert((uintptr_t) src0 % 128 == 0); \\\n    OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \\\n} \\\nstatic inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * 
restrict src1, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    assert((uintptr_t) src1 % 128 == 0); \\\n    OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \\\n} \\\nstatic inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \\\n} \\\nstatic inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) src0 % 128 == 0); \\\n    assert((uintptr_t) src1 % 128 == 0); \\\n    OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \\\n} \\\nstatic inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) src0 % 128 == 0); \\\n    OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \\\n} \\\nstatic inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    assert((uintptr_t) src1 % 128 == 0); \\\n    OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \\\n} \\\nstatic inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \\\n    OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \\\n} \\\n\n// Dispatcher logic\n#define HVX_DIV_DISPATCHER(OP_NAME) \\\nstatic inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \\\n    if (hex_is_aligned((void *) dst, 128)) { \\\n        if (hex_is_aligned((void *) src0, 128)) { \\\n            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \\\n            else                                    
OP_NAME##_aau(dst, src0, src1, num_elems); \\\n        } else { \\\n            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \\\n            else                                    OP_NAME##_auu(dst, src0, src1, num_elems); \\\n        } \\\n    } else { \\\n        if (hex_is_aligned((void *) src0, 128)) { \\\n            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \\\n            else                                    OP_NAME##_uau(dst, src0, src1, num_elems); \\\n        } else { \\\n            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \\\n            else                                    OP_NAME##_uuu(dst, src0, src1, num_elems); \\\n        } \\\n    } \\\n}\n\nDEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body)\nDEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body)\n\nHVX_DIV_DISPATCHER(hvx_div_f32)\nHVX_DIV_DISPATCHER(hvx_div_f16)\n\n#undef HVX_OP_MUL_F32\n\n#endif // HVX_DIV_H\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-dump.h",
    "content": "#ifndef HVX_DUMP_H\n#define HVX_DUMP_H\n\n#include <HAP_farf.h>\n\n#include <stdbool.h>\n#include <stdint.h>\n\n#include \"hex-utils.h\"\n#include \"hvx-types.h\"\n\nstatic void hvx_vec_dump_f16_n(char * pref, HVX_Vector v, uint32_t n) {\n    HVX_VectorAlias u = { .v = v };\n\n    const uint32_t n0 = n / 16;\n    const uint32_t n1 = n % 16;\n    int            i  = 0;\n    for (; i < n0; i++) {\n        hex_dump_f16_line(pref, u.fp16 + (16 * i), 16);\n    }\n    if (n1) {\n        hex_dump_f16_line(pref, u.fp16 + (16 * i), n1);\n    }\n}\n\nstatic void hvx_vec_dump_f16(char * pref, HVX_Vector v) {\n    hvx_vec_dump_f16_n(pref, v, 64);\n}\n\nstatic void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {\n    HVX_VectorAlias u = { .v = v };\n\n    const uint32_t n0 = n / 16;\n    const uint32_t n1 = n % 16;\n    int            i  = 0;\n    for (; i < n0; i++) {\n        hex_dump_f32_line(pref, u.fp32 + (16 * i), 16);\n    }\n    if (n1) {\n        hex_dump_f32_line(pref, u.fp32 + (16 * i), n1);\n    }\n}\n\nstatic void hvx_vec_dump_f32_hmt(char * pref, HVX_Vector v) {\n    union {\n        HVX_Vector v;\n        float      d[32];\n    } u = { .v = v };\n\n    FARF(HIGH, \"%s: %.6f %.6f %.6f %.6f ...  %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\\n\", pref, u.d[0], u.d[1],\n         u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);\n}\n\nstatic void hvx_vec_dump_f32(char * pref, HVX_Vector v) {\n    hvx_vec_dump_f32_n(pref, v, 32);\n}\n\nstatic void hvx_vec_dump_int32(char * pref, HVX_Vector v) {\n    union {\n        HVX_Vector v;\n        int32_t    d[32];\n    } u = { .v = v };\n\n    for (int i = 0; i < 32 / 16; i++) {\n        hex_dump_int32_line(pref, u.d + (16 * i), 16);\n    }\n}\n\nstatic void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {\n    union {\n        HVX_Vector v;\n        int32_t    d[32];\n    } u = { .v = v };\n\n    FARF(HIGH, \"%s: %d %d %d %d ... %d %d %d %d ... 
%d %d %d %d\\n\", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],\n         u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);\n}\n\nstatic void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {\n    union {\n        HVX_Vector v;\n        int8_t     d[128];\n    } u = { .v = v };\n\n    FARF(HIGH, \"%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\\n\", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],\n         u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);\n}\n\nstatic void hvx_vec_dump_int8(char * pref, HVX_Vector v) {\n    union {\n        HVX_Vector v;\n        int8_t     d[128];\n    } u = { .v = v };\n\n    for (int i = 0; i < 128 / 16; i++) {\n        hex_dump_int8_line(pref, u.d + (16 * i), 16);\n    }\n}\n\nstatic void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {\n    union {\n        HVX_Vector v;\n        uint8_t    d[128];\n    } u = { .v = v };\n\n    for (int i = 0; i < 128 / 16; i++) {\n        hex_dump_uint8_line(pref, u.d + (16 * i), 16);\n    }\n}\n\nstatic bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {\n    typedef union {\n        HVX_Vector v;\n        int8_t     d[128];\n    } U;\n\n    U u0 = { .v = v0 };\n    U u1 = { .v = v1 };\n\n    for (int i = 0; i < n; i++) {\n        if (u0.d[i] != u1.d[i]) {\n            return false;\n        }\n    }\n\n    return true;\n}\n\n#endif /* HVX_DUMP_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-exp.h",
    "content": "#ifndef HVX_EXP_H\n#define HVX_EXP_H\n\n#include <stdbool.h>\n#include <stdint.h>\n\n#include \"hvx-base.h\"\n#include \"hvx-floor.h\"\n\n#define EXP_COEFF_5 (0x39506967)  // 0.000198757 = 1/(7!)\n#define EXP_COEFF_4 (0x3AB743CE)  // 0.0013982   = 1/(6!)\n#define EXP_COEFF_3 (0x3C088908)  // 0.00833345  = 1/(5!)\n#define EXP_COEFF_2 (0x3D2AA9C1)  // 0.0416658   = 1/(4!)\n#define EXP_COEFF_1 (0x3E2AAAAA)  // 0.16666667  = 1/(3!)\n#define EXP_COEFF_0 (0x3F000000)  // 0.5         = 1/(2!)\n#define EXP_LOGN2   (0x3F317218)  // ln(2)   = 0.6931471805\n#define EXP_LOG2E   (0x3FB8AA3B)  // log2(e) = 1/ln(2) = 1.4426950408\n#define EXP_ONE     (0x3f800000)  // 1.0\n#define EXP_RANGE_R (0x41a00000)  // 20.0\n#define EXP_RANGE_L (0xc1a00000)  // -20.0\n\nstatic inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {\n    HVX_Vector z_qf32_v;\n    HVX_Vector x_v;\n    HVX_Vector x_qf32_v;\n    HVX_Vector y_v;\n    HVX_Vector k_v;\n    HVX_Vector f_v;\n    HVX_Vector epsilon_v;\n    HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);\n    HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);\n    HVX_Vector E_const;\n    HVX_Vector zero_v = Q6_V_vzero();\n\n    // exp(x) is approximated as follows:\n    //   f = floor(x/ln(2)) = floor(x*log2(e))\n    //   epsilon = x - f*ln(2)\n    //   exp(x) = exp(epsilon+f*ln(2))\n    //          = exp(epsilon)*exp(f*ln(2))\n    //          = exp(epsilon)*2^f\n    //\n    //   Since epsilon is close to zero, it can be approximated with its Taylor series:\n    //            exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...\n    //   Preserving the first eight elements, we get:\n    //            exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7\n    //                   =  1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2\n\n    HVX_Vector temp_v = in_vec;\n\n    // Intended: clamp inputs to [-20.0, 20.0].\n    // NOTE(review): both vmux calls below select from temp_v, so the second one\n    // overwrites the upper clamp; only EXP_RANGE_L is applied and inputs above\n    // EXP_RANGE_R pass through unclamped.\n    HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));\n    HVX_VectorPred pred_cap_left  = 
Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);\n\n    in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);\n    in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);\n\n    epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);\n    epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);\n\n    //    f_v is the floating point result and k_v is the integer result\n    f_v = hvx_vec_floor_f32(epsilon_v);\n    k_v = hvx_vec_truncate_f32(f_v);\n\n    x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);\n\n    //  x = x - f_v * logn2;\n    epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);\n    x_qf32_v  = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);\n    // normalize before every QFloat's vmpy\n    x_qf32_v  = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);\n\n    // z = x * x;\n    z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);\n    z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);\n\n    x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);\n\n    // y = E4 + E5 * x;\n    E_const = Q6_V_vsplat_R(EXP_COEFF_5);\n    y_v     = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);\n    E_const = Q6_V_vsplat_R(EXP_COEFF_4);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);\n\n    // y = E3 + y * x;\n    E_const = Q6_V_vsplat_R(EXP_COEFF_3);\n    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);\n\n    // y = E2 + y * x;\n    E_const = Q6_V_vsplat_R(EXP_COEFF_2);\n    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);\n\n    // y = E1 + y * x;\n    E_const = Q6_V_vsplat_R(EXP_COEFF_1);\n    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);\n\n    // y = E0 + y * x;\n    E_const = 
Q6_V_vsplat_R(EXP_COEFF_0);\n    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);\n    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);\n\n    // y = x + y * z;\n    y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);\n    y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);\n    y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);\n\n    // y = y + 1.0;\n    y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));\n\n    // insert exponents\n    //        y = ldexpf(y, k);\n    //    y_v += k_v; // qf32\n    // modify exponent\n\n    y_v = Q6_Vsf_equals_Vqf32(y_v);\n\n    // add k_v to the exponent of y_v\n    HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);\n\n    y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);\n    y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);\n\n    // exponent cannot be negative; if underflow is detected, result is set to zero\n    HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);\n\n    y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);\n\n    y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);\n\n    return y_v;\n}\n\nstatic inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {\n    const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);\n\n    HVX_Vector out = hvx_vec_exp_f32(in_vec);\n\n    return Q6_V_vmux_QVV(pred0, inf, out);\n}\n\nstatic inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {\n    int left_over       = num_elems & (VLEN_FP32 - 1);\n    int num_elems_whole = num_elems - left_over;\n\n    int unaligned_addr = 0;\n    int unaligned_loop = 0;\n    if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) {\n        unaligned_addr = 1;\n    }\n    // assert((0 == unaligned_addr) || (0 == num_elems_whole));\n    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {\n        
unaligned_loop = 1;\n    }\n\n    HVX_Vector vec_out = Q6_V_vzero();\n\n    static const float kInf    = INFINITY;\n    static const float kMaxExp = 88.02f;  // slightly below ln(FLT_MAX) ~= 88.72; larger inputs return +INF\n\n    const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);\n    const HVX_Vector inf     = hvx_vec_splat_f32(kInf);\n\n    if (0 == unaligned_loop) {\n        HVX_Vector * p_vec_in1 = (HVX_Vector *) src;\n        HVX_Vector * p_vec_out = (HVX_Vector *) dst;\n\n        #pragma unroll(4)\n        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {\n            if (true == negate) {\n                HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++);\n                *p_vec_out++          = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);\n            } else {\n                *p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf);\n            }\n        }\n    } else {\n        #pragma unroll(4)\n        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {\n            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);\n\n            if (true == negate) {\n                HVX_Vector neg_vec_in                    = hvx_vec_neg_f32(in);\n                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);\n            } else {\n                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf);\n            }\n        }\n    }\n\n    if (left_over > 0) {\n        const float * srcf = (float *) src + num_elems_whole;\n        float *       dstf = (float *) dst + num_elems_whole;\n\n        HVX_Vector in = *(HVX_UVector *) srcf;\n\n        if (true == negate) {\n            HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);\n\n            vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);\n        } else {\n            vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf);\n        }\n\n        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);\n    }\n}\n\n#endif /* HVX_EXP_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-floor.h",
    "content": "#ifndef HVX_FLOOR_H\n#define HVX_FLOOR_H\n\n#include <stdbool.h>\n#include <stdint.h>\n\n#include \"hvx-base.h\"\n\n#define IEEE_VSF_EXPLEN   (8)\n#define IEEE_VSF_EXPBIAS  (127)\n#define IEEE_VSF_EXPMASK  (0xFF)\n#define IEEE_VSF_MANTLEN  (23)\n#define IEEE_VSF_MANTMASK (0x7FFFFF)\n#define IEEE_VSF_MIMPMASK (0x800000)\n\nstatic inline HVX_Vector hvx_vec_truncate_f32(HVX_Vector in_vec) {\n    HVX_Vector mask_mant_v  = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);\n    HVX_Vector mask_impl_v  = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);\n    HVX_Vector const_zero_v = Q6_V_vzero();\n\n    HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);\n\n    HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;\n    expval_v &= IEEE_VSF_EXPMASK;\n    expval_v -= IEEE_VSF_EXPBIAS;\n\n    // negative exp == fractional value\n    HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);\n\n    HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v;         // fractional bits - exp shift\n\n    HVX_Vector mant_v = in_vec & mask_mant_v;                  // obtain mantissa\n    HVX_Vector vout   = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v);  // add implicit 1.0\n\n    vout = Q6_Vw_vasr_VwVw(vout, rshift_v);                    // shift to obtain truncated integer\n    vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout);        // expval<0 -> 0\n\n    HVX_Vector neg_vout = -vout;\n\n    vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout);  // handle negatives\n\n    return (vout);\n}\n\nstatic inline HVX_Vector hvx_vec_floor_f32(HVX_Vector in_vec) {\n    HVX_Vector mask_mant_v    = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);\n    HVX_Vector mask_impl_v    = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);\n    HVX_Vector const_mnlen_v  = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);\n    HVX_Vector const_zero_v   = Q6_V_vzero();\n    HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000);  // -1 IEEE vsf\n\n    HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);\n\n    HVX_Vector expval_v = in_vec >> 
IEEE_VSF_MANTLEN;\n    expval_v &= IEEE_VSF_EXPMASK;\n    expval_v -= IEEE_VSF_EXPBIAS;\n\n    HVX_VectorPred q_negexp     = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);\n    HVX_VectorPred q_expltmn    = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);\n    HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);\n    HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);\n\n    // if expval < 0 (q_negexp)         // <0, floor is 0\n    //    if vin > 0\n    //       floor = 0\n    //    if vin < 0\n    //       floor = -1\n    // if expval < mant_len (q_expltmn) // >0, but fraction may exist\n    //    get sign (q_negative)\n    //    mask >> expval                // fraction bits to mask off\n    //    vout = ~(mask)                // apply mask to remove fraction\n    //    if (qneg)                     // negative floor is one less (more, sign bit for neg)\n    //      vout += ((impl_mask) >> expval)\n    //    if (mask && vin)\n    //      vout = vin\n    // else                             // already an integer\n    //    ;                             // no change\n\n    // compute floor\n    mask_mant_v >>= expval_v;\n    HVX_Vector neg_addin_v    = mask_impl_v >> expval_v;\n    HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);\n    HVX_Vector vout           = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);\n\n    HVX_Vector     mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v);  // chk if bits set\n    HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);\n\n    HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v);        // frac bits to clear\n    HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v);  // clear frac bits\n\n    vout = in_vec;\n    vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout);         // expval<mant\n    vout = Q6_V_vmux_QVV(q_integral, in_vec, vout);            // integral values\n    vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout);    // 
expval<0 x>0 -> 0\n    vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout);  // expval<0 x<0 -> -1\n\n    return vout;\n}\n\n#endif /* HVX_FLOOR_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-inverse.h",
    "content": "#ifndef HVX_INVERSE_H\n#define HVX_INVERSE_H\n\n#include <HAP_farf.h>\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <stddef.h>\n#include <stdint.h>\n\n#include \"hvx-base.h\"\n\n// ====================================================\n// FUNCTION: 1/(x+1)     y(0) = 1,  y(0.5) = 0.6667, y(1) = 0.5\n// Order:3; continuity: True; Ends forced: True\n// Mode: unsigned;   Result fractional bits: 14\n// Peak Error: 1.1295e-04  Rms Error: 2.8410e-05   Mean Error: 1.1370e-05\n//      32769  -32706   31252  -10589\n//      32590  -30635   22793   -4493\n//      32066  -27505   16481   -2348\n//      31205  -24054   11849   -1306\n\nstatic inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {\n    // input is 0..0xffff representing 0.0  .. 1.0\n    HVX_Vector p;\n    p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);\n    p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);\n    p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);\n    p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);\n    return p;  // signed result, 14 fractional bits\n}\n\n// Find reciprocal of fp16.\n// (1) first, convert to fp32, multiplying by 1.0; this is done to\n//    handle denormals. Ignoring sign and zero, result should be at\n//    least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)\n//    (exponent in range [103,143])\n// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly\n// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32\n// (4) convert that to fp16\n// (5) put sign back in. 
Also, if the original value (w/o sign) was <0x101, replace\n//     the result with the max value.\nstatic inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) {\n    HVX_Vector     em_mask  = Q6_Vh_vsplat_R(0x7FFF);\n    HVX_Vector     avals    = Q6_V_vand_VV(vals, em_mask);\n    HVX_VectorPred is_neg   = Q6_Q_vcmp_gt_VhVh(avals, vals);\n    // is too small to 1/x ? for 'standard' fp16, this would be 0x101\n    HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);\n\n    HVX_VectorPair to_qf32  = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00));  // *1.0\n    HVX_Vector     to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));\n    HVX_Vector     to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));\n\n    // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 15..6 of a 16-bit vector\n    HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));\n    // likewise extract the upper 16 from each, containing the exponents in range 103..143\n    HVX_Vector exp_u16  = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);\n    //Get exponent in IEEE 32-bit representation\n    exp_u16             = Q6_Vuh_vlsr_VuhR(exp_u16, 7);\n\n    // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane\n    // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)\n    // Use poly to transform to 1/x, with 14 fractional bits\n    //\n    HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);\n\n    HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm);  //count leading zeros\n\n    // Get mantissa for 16-bit representation\n    HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));\n\n    //Compute Reciprocal Exponent\n    HVX_Vector exp_recip =\n        Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));\n    //Convert it for 16-bit representation\n    exp_recip = 
Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));\n    exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);\n\n    //Merge exponent and mantissa for reciprocal\n    HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);\n    // map 'small' inputs to standard largest value 0x7bff\n    recip            = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);\n    // add sign back\n    recip            = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);\n    return recip;\n}\n\nstatic inline HVX_Vector hvx_vec_inverse_f32(HVX_Vector v_sf) {\n    HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);\n    HVX_Vector two_sf       = hvx_vec_splat_f32(2.0);\n\n    // First approximation\n    HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);\n\n    HVX_Vector r_qf;\n\n    // Refine\n    r_qf = Q6_Vqf32_vmpy_VsfVsf(\n        i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));\n    r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(\n        r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));\n    r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(\n        r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));\n\n    return Q6_Vsf_equals_Vqf32(r_qf);\n}\n\nstatic inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {\n    HVX_Vector out = hvx_vec_inverse_f32(v_sf);\n\n    HVX_Vector     masked_out = Q6_V_vand_VV(out, nan_inf_mask);\n    const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);\n\n    return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);\n}\n\n#define hvx_inverse_f32_loop_body(dst_type, src_type, vec_store)             \\\n    do {                                                                     \\\n        dst_type * restrict vdst = (dst_type *) dst;                         \\\n        src_type * restrict vsrc = (src_type *) src;                   
      \\\n                                                                             \\\n        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);           \\\n                                                                             \\\n        const uint32_t nvec = n / VLEN_FP32;                                 \\\n        const uint32_t nloe = n % VLEN_FP32;                                 \\\n                                                                             \\\n        uint32_t i = 0;                                                      \\\n                                                                             \\\n        _Pragma(\"unroll(4)\")                                                 \\\n        for (; i < nvec; i++) {                                              \\\n             vdst[i] = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask);     \\\n        }                                                                    \\\n        if (nloe) {                                                          \\\n            HVX_Vector v = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \\\n            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, v);             \\\n        }                                                                    \\\n    } while(0)\n\nstatic inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {\n    HVX_Vector out = hvx_vec_inverse_f16(v_sf);\n\n    HVX_Vector     masked_out = Q6_V_vand_VV(out, nan_inf_mask);\n    const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out);\n\n    return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);\n}\n\n#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store)             \\\n    do {                                                                     \\\n        dst_type * restrict vdst = (dst_type *) dst;                         \\\n        src_type * restrict vsrc = (src_type *) src;                         
\\\n                                                                             \\\n        const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00);              \\\n                                                                             \\\n        const uint32_t nvec = n / VLEN_FP16;                                 \\\n        const uint32_t nloe = n % VLEN_FP16;                                 \\\n                                                                             \\\n        uint32_t i = 0;                                                      \\\n                                                                             \\\n        _Pragma(\"unroll(4)\")                                                 \\\n        for (; i < nvec; i++) {                                              \\\n             vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask);     \\\n        }                                                                    \\\n        if (nloe) {                                                          \\\n            HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \\\n            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v);             \\\n        }                                                                    \\\n    } while(0)\n\n// Generic macro to define alignment permutations for an op\n#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \\\nstatic inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    assert((uintptr_t) src % 128 == 0); \\\n    OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \\\n} \\\nstatic inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \\\n    assert((uintptr_t) dst % 128 == 0); \\\n    OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \\\n} \\\nstatic inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * 
restrict src, uint32_t n) { \\\n    assert((uintptr_t) src % 128 == 0); \\\n    OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \\\n} \\\nstatic inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \\\n    OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \\\n} \\\n\n// Dispatcher logic\n#define HVX_INV_DISPATCHER(OP_NAME) \\\nstatic inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \\\n    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \\\n        OP_NAME##_aa(dst, src, num_elems); \\\n    } else if (hex_is_aligned((void *) dst, 128)) { \\\n        OP_NAME##_au(dst, src, num_elems); \\\n    } else if (hex_is_aligned((void *) src, 128)) { \\\n        OP_NAME##_ua(dst, src, num_elems); \\\n    } else { \\\n        OP_NAME##_uu(dst, src, num_elems); \\\n    } \\\n}\n\nDEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body)\nDEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body)\n\nHVX_INV_DISPATCHER(hvx_inverse_f32)\nHVX_INV_DISPATCHER(hvx_inverse_f16)\n\n#endif // HVX_INVERSE_H\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-reduce.h",
    "content": "#ifndef HVX_REDUCE_H\n#define HVX_REDUCE_H\n\n#include <math.h>\n#include <stdbool.h>\n#include <stdint.h>\n#include <assert.h>\n\n#include \"hex-utils.h\"\n#include \"hvx-base.h\"\n#include \"hvx-types.h\"\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) {\n    unsigned int total = n * 4;  // total vec nbytes\n    unsigned int width = 4;      // int32\n\n    HVX_Vector sum = in, sum_t;\n    while (width < total) {\n        sum_t = Q6_V_vror_VR(sum, width);     // rotate right\n        sum   = Q6_Vw_vadd_VwVw(sum_t, sum);  // elementwise sum\n        width = width << 1;\n    }\n    return sum;\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) {\n    return hvx_vec_reduce_sum_n_i32(in, 32);\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) {\n    unsigned int total = n * 4;  // total vec nbytes\n    unsigned int width = 4;      // fp32 nbytes\n\n    HVX_Vector sum = in, sum_t;\n    while (width < total) {\n        sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width);  // rotate right\n        sum   = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t);             // elementwise sum\n        width = width << 1;\n    }\n    return sum;\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {\n    return hvx_vec_reduce_sum_n_qf32(in, 32);\n}\n\n#if __HVX_ARCH__ > 75\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) {\n    HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4);\n    HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4);\n    HVX_Vector  sum_sf01  = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01));\n    HVX_Vector  sum_sf23  = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23));\n\n    HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(sum_sf23, sum_sf01, 8);\n    HVX_Vector  sum_sf       = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123));\n\n    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, 
Q6_V_vror_VR(sum_sf, VLEN / 2));\n    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));\n    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));\n    return sum_sf;\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {\n    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);\n    HVX_Vector  sum_sf  = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));\n\n    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));\n    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));\n    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));\n    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));\n    return sum_sf;\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {\n    unsigned int total = n * 4;  // total vec nbytes\n    unsigned int width = 4;      // fp32 nbytes\n\n    HVX_Vector sum = in, sum_t;\n    while (width < total) {\n        sum_t = Q6_V_vror_VR(sum, width);       // rotate right\n        sum   = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum\n        width = width << 1;\n    }\n    return sum;\n}\n\n#else\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) {\n    HVX_VectorPair sum_p01  = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4);\n    HVX_VectorPair sum_p23  = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4);\n    HVX_Vector     sum_qf01 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01));\n    HVX_Vector     sum_qf23 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23));\n\n    HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(sum_qf23), Q6_Vsf_equals_Vqf32(sum_qf01), 8);\n    HVX_Vector     sum_qf    = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123));\n\n    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));\n    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, 
Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));\n    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));\n    return Q6_Vsf_equals_Vqf32(sum_qf);\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {\n    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);\n    HVX_Vector  sum_qf  = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));\n\n    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));\n    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));\n    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));\n    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));\n    return Q6_Vsf_equals_Vqf32(sum_qf);\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {\n    unsigned int total = n * 4;  // total vec nbytes\n    unsigned int width = 4;      // fp32 nbytes\n\n    HVX_Vector sum = in, sum_t;\n    while (width < total) {\n        sum_t = Q6_V_vror_VR(sum, width);                               // rotate right\n        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t));  // elementwise sum\n        width = width << 1;\n    }\n    return sum;\n}\n\n#endif\n\nstatic inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {\n    return hvx_vec_reduce_sum_n_f32(in, 32);\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) {\n    unsigned total = 128;  // total vec nbytes\n    unsigned width = 2;    // fp16 nbytes\n\n    HVX_Vector _max = in, _max_t;\n    while (width < total) {\n        _max_t = Q6_V_vror_VR(_max, width);         // rotate right\n        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max\n        width  = width << 1;\n    }\n\n    return _max;\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) {\n   
 unsigned total = 128;  // total vec nbytes\n    unsigned width = 2;    // fp32 nbytes\n\n    HVX_Vector _max_t;\n\n    _max = Q6_Vhf_vmax_VhfVhf(in, _max);\n    while (width < total) {\n        _max_t = Q6_V_vror_VR(_max, width);         // rotate right\n        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max\n        width  = width << 1;\n    }\n\n    return _max;\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) {\n    unsigned total = 128;  // total vec nbytes\n    unsigned width = 4;    // fp32 nbytes\n\n    HVX_Vector _max = in, _max_t;\n    while (width < total) {\n        _max_t = Q6_V_vror_VR(_max, width);         // rotate right\n        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max\n        width  = width << 1;\n    }\n\n    return _max;\n}\n\nstatic inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) {\n    unsigned total = 128;  // total vec nbytes\n    unsigned width = 4;    // fp32 nbytes\n\n    HVX_Vector _max_t;\n\n    _max = Q6_Vsf_vmax_VsfVsf(in, _max);\n    while (width < total) {\n        _max_t = Q6_V_vror_VR(_max, width);         // rotate right\n        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max\n        width  = width << 1;\n    }\n\n    return _max;\n}\n\n#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \\\n    do {                                                                                    \\\n        src_type * restrict vsrc = (src_type *) src;                                        \\\n        HVX_Vector acc = init_vec;                                                          \\\n                                                                                            \\\n        const uint32_t elem_size = sizeof(float);                                           \\\n        const uint32_t epv  = 128 / elem_size;                                              \\\n        const uint32_t nvec = 
num_elems / epv;                                              \\\n        const uint32_t nloe = num_elems % epv;                                              \\\n                                                                                            \\\n        uint32_t i = 0;                                                                     \\\n        _Pragma(\"unroll(4)\")                                                                \\\n        for (; i < nvec; i++) {                                                             \\\n            acc = vec_op(acc, vsrc[i]);                                                     \\\n        }                                                                                   \\\n        if (nloe) {                                                                         \\\n            const float * srcf = (const float *) src + i * epv;                             \\\n            HVX_Vector in = *(HVX_UVector *) srcf;                                          \\\n            HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size);               \\\n            acc = vec_op(acc, temp);                                                        \\\n        }                                                                                   \\\n        HVX_Vector v = reduce_op(acc);                                                      \\\n        return scalar_reduce(v);                                                            \\\n    } while(0)\n\n#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val)\n#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val)\n#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val))\n#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v)\n#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v))\n\n// Max variants\n\nstatic inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, 
const int num_elems) {\n    HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);\n    assert((unsigned long) src % 128 == 0);\n    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);\n}\n\nstatic inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) {\n    HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);\n    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);\n}\n\nstatic inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) {\n    if (hex_is_aligned((void *) src, 128)) {\n        return hvx_reduce_max_f32_a(src, num_elems);\n    } else {\n        return hvx_reduce_max_f32_u(src, num_elems);\n    }\n}\n\n// Sum variants\n\nstatic inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) {\n    HVX_Vector init_vec = Q6_V_vsplat_R(0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);\n}\n\nstatic inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) {\n    HVX_Vector init_vec = Q6_V_vsplat_R(0);\n    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);\n}\n\nstatic inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) {\n    if (hex_is_aligned((void *) src, 128)) {\n        return hvx_reduce_sum_f32_a(src, num_elems);\n    } else {\n        return hvx_reduce_sum_f32_u(src, num_elems);\n    }\n}\n\n// Sum of squares variants\n\nstatic inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) {\n    HVX_Vector init_vec = Q6_V_vsplat_R(0);\n    assert((uintptr_t) src % 128 == 0);\n    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, 
HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);\n}\n\nstatic inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) {\n    HVX_Vector init_vec = Q6_V_vsplat_R(0);\n    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);\n}\n\nstatic inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {\n    if (hex_is_aligned((void *) src, 128)) {\n        return hvx_sum_of_squares_f32_a(src, num_elems);\n    } else {\n        return hvx_sum_of_squares_f32_u(src, num_elems);\n    }\n}\n\n#undef hvx_reduce_loop_body\n#undef HVX_REDUCE_MAX_OP\n#undef HVX_REDUCE_SUM_OP\n#undef HVX_REDUCE_MAX_SCALAR\n#undef HVX_REDUCE_SUM_SCALAR\n#undef HVX_SUM_SQ_OP\n\n#endif /* HVX_REDUCE_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-scale.h",
    "content": "#ifndef HVX_SCALE_H\n#define HVX_SCALE_H\n\n#include <assert.h>\n#include <stddef.h>\n#include <stdint.h>\n\n#include \"hvx-base.h\"\n\n#define hvx_scale_f32_loop_body(dst_type, src_type, vec_store)                       \\\n    do {                                                                             \\\n        dst_type * restrict vdst = (dst_type *) dst;                                 \\\n        src_type * restrict vsrc = (src_type *) src;                                 \\\n                                                                                     \\\n        HVX_Vector vs = hvx_vec_splat_f32(scale);                                    \\\n                                                                                     \\\n        const uint32_t elem_size = sizeof(float);                                    \\\n        const uint32_t epv = 128 / elem_size;                                        \\\n        const uint32_t nvec = n / epv;                                               \\\n        const uint32_t nloe = n % epv;                                               \\\n                                                                                     \\\n        uint32_t i = 0;                                                              \\\n                                                                                     \\\n        _Pragma(\"unroll(4)\")                                                         \\\n        for (; i < nvec; ++i) {                                                      \\\n            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);                        \\\n            vdst[i]      = Q6_Vsf_equals_Vqf32(v);                                   \\\n        }                                                                            \\\n        if (nloe) {                                                                  \\\n            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);     
                   \\\n            vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v));  \\\n        }                                                                            \\\n    } while(0)\n\nstatic inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {\n    assert((size_t) dst % 128 == 0);\n    assert((size_t) src % 128 == 0);\n    hvx_scale_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_scale_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {\n    assert((size_t) dst % 128 == 0);\n    hvx_scale_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_scale_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {\n    assert((size_t) src % 128 == 0);\n    hvx_scale_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {\n    hvx_scale_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {\n    if (((size_t) dst & 127) == 0) {\n        if (((size_t) src & 127) == 0) {\n            hvx_scale_f32_aa(dst, src, n, scale);\n        } else {\n            hvx_scale_f32_au(dst, src, n, scale);\n        }\n    } else {\n        if (((size_t) src & 127) == 0) {\n            hvx_scale_f32_ua(dst, src, n, scale);\n        } else {\n            hvx_scale_f32_uu(dst, src, n, scale);\n        }\n    }\n}\n\n#define hvx_scale_offset_f32_loop_body(dst_type, src_type, vec_store)                \\\n    do {                                                                             \\\n        dst_type * restrict vdst = (dst_type *) dst;                                 \\\n        src_type * 
restrict vsrc = (src_type *) src;                                 \\\n                                                                                     \\\n        HVX_Vector vs = hvx_vec_splat_f32(scale);                                    \\\n        HVX_Vector vo = hvx_vec_splat_f32(offset);                                   \\\n                                                                                     \\\n        const uint32_t elem_size = sizeof(float);                                    \\\n        const uint32_t epv = 128 / elem_size;                                        \\\n        const uint32_t nvec = n / epv;                                               \\\n        const uint32_t nloe = n % epv;                                               \\\n                                                                                     \\\n        uint32_t i = 0;                                                              \\\n                                                                                     \\\n        _Pragma(\"unroll(4)\")                                                         \\\n        for (; i < nvec; ++i) {                                                      \\\n            HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \\\n            vdst[i] = Q6_Vsf_equals_Vqf32(v);                                        \\\n        }                                                                            \\\n        if (nloe) {                                                                  \\\n            HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \\\n            vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v));  \\\n        }                                                                            \\\n    } while(0)\n\nstatic inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const 
float scale, const float offset) {\n    assert((size_t) dst % 128 == 0);\n    assert((size_t) src % 128 == 0);\n    hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_scale_offset_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {\n    assert((size_t) dst % 128 == 0);\n    hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_scale_offset_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {\n    assert((size_t) src % 128 == 0);\n    hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {\n    hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {\n    if (((size_t) dst & 127) == 0) {\n        if (((size_t) src & 127) == 0) {\n            hvx_scale_offset_f32_aa(dst, src, n, scale, offset);\n        } else {\n            hvx_scale_offset_f32_au(dst, src, n, scale, offset);\n        }\n    } else {\n        if (((size_t) src & 127) == 0) {\n            hvx_scale_offset_f32_ua(dst, src, n, scale, offset);\n        } else {\n            hvx_scale_offset_f32_uu(dst, src, n, scale, offset);\n        }\n    }\n}\n\n#endif // HVX_SCALE_H\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-sigmoid.h",
    "content": "#ifndef HVX_SIGMOID_H\n#define HVX_SIGMOID_H\n\n#include \"hvx-base.h\"\n\n#define FAST_SIGMOID_LOG2F (0x3fb8aa3b)  // 1.442695022\n#define FAST_SIGMOID_C1    (0x3d009076)  // 0.03138777\n#define FAST_SIGMOID_C2    (0x3e8d74bd)  // 0.276281267\n#define FAST_SIGMOID_C3    (0x3f000000)  // 0.5\n\nstatic inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) {\n    v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));\n    v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));\n\n    HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v));\n    HVX_Vector x      = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));\n    HVX_Vector xx     = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);\n\n    HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));\n    v1            = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));\n\n    HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));\n    v2            = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);\n    v2            = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);\n\n    HVX_Vector v3          = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));\n    HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);\n    v3_exponent            = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);\n    v3_exponent            = Q6_Vw_vadd_VwVw(in_int, v3_exponent);\n    v3                     = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);\n\n    HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));\n    HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));\n\n    HVX_Vector res = hvx_vec_inverse_f32(v5);\n    res            = Q6_Vqf32_vmpy_VsfVsf(v3, res);\n\n    return Q6_Vsf_equals_Vqf32(res);\n}\n\nstatic inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v,\n                                                         HVX_Vector one,\n                                                         HVX_Vector max_exp,\n 
                                                        HVX_Vector min_exp) {\n    const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);\n    const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);\n\n    HVX_Vector out = hvx_vec_fast_sigmoid_f32(v);\n    out            = Q6_V_vmux_QVV(pred_max, out, one);\n    return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());\n}\n\nstatic inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {\n    // tanh(x) = 2 * sigmoid(2x) - 1\n    HVX_Vector two = hvx_vec_splat_f32(2.0f);\n    HVX_Vector one = hvx_vec_splat_f32(1.0f);\n    HVX_Vector x2  = Q6_Vqf32_vmpy_VsfVsf(x, two);\n\n    HVX_Vector max_exp = hvx_vec_splat_f32(87.f);\n    HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);\n\n    HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);\n\n    HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);\n    res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);\n    return Q6_Vsf_equals_Vqf32(res);\n}\n\n#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store)    \\\n    do {                                                        \\\n        dst_type * restrict vdst = (dst_type *) dst;            \\\n        src_type * restrict vsrc = (src_type *) src;            \\\n                                                                \\\n        const HVX_Vector one     = hvx_vec_splat_f32(1.f);      \\\n        const HVX_Vector max_exp = hvx_vec_splat_f32(87.f);     \\\n        const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);    \\\n                                                                \\\n        const uint32_t epv  = 128 / sizeof(float);              \\\n        const uint32_t nvec = n / epv;                          \\\n        const uint32_t nloe = n % epv;                          \\\n                                                                \\\n        uint32_t i = 0;                                         \\\n                                                          
      \\\n        _Pragma(\"unroll(4)\")                                    \\\n        for (; i < nvec; i++) {                                 \\\n             vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \\\n        }                                                       \\\n        if (nloe) {                                             \\\n             HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \\\n             vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \\\n        }                                                       \\\n    } while(0)\n\n#define hvx_tanh_loop_body(dst_type, src_type, vec_store)       \\\n    do {                                                        \\\n        dst_type * restrict vdst = (dst_type *) dst;            \\\n        src_type * restrict vsrc = (src_type *) src;            \\\n                                                                \\\n        const uint32_t epv  = 128 / sizeof(float);              \\\n        const uint32_t nvec = n / epv;                          \\\n        const uint32_t nloe = n % epv;                          \\\n                                                                \\\n        uint32_t i = 0;                                         \\\n                                                                \\\n        _Pragma(\"unroll(4)\")                                    \\\n        for (; i < nvec; i++) {                                 \\\n             vdst[i] = hvx_vec_tanh_f32(vsrc[i]);               \\\n        }                                                       \\\n        if (nloe) {                                             \\\n             HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]);        \\\n             vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \\\n        }                                                       \\\n    } while(0)\n\nstatic inline void hvx_sigmoid_f32_aa(uint8_t * 
restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) src % 128 == 0);\n    hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\n#endif /* HVX_SIGMOID_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-sqrt.h",
    "content": "#ifndef HVX_SQRT_H\n#define HVX_SQRT_H\n\n#include <stdbool.h>\n#include <stdint.h>\n\n#include \"hex-utils.h\"\n\n#include \"hvx-base.h\"\n\n#define RSQRT_CONST        0x5f3759df  // Constant for fast inverse square root calculation\n#define RSQRT_ONE_HALF     0x3f000000  // 0.5\n#define RSQRT_THREE_HALVES 0x3fc00000  // 1.5\n\n#if __HVX_ARCH__ < 79\n#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))\n#else\n#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)\n#endif\n\nstatic inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {\n    //Algorithm :\n    //  x2 = input*0.5\n    //  y  = * (long *) &input\n    //  y  = 0x5f3759df - (y>>1)\n    //  y  = y*(threehalfs - x2*y*y)\n\n    HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);\n    HVX_Vector onehalf    = Q6_V_vsplat_R(RSQRT_ONE_HALF);\n    HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);\n\n    HVX_Vector x2, y, ypower2, temp;\n\n    x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);\n    x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());\n\n    y = Q6_Vw_vasr_VwR(in_vec, 1);\n    y = Q6_Vw_vsub_VwVw(rsqrtconst, y);\n\n    // 1st iteration\n    ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);\n    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());\n    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);\n    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));\n    temp    = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));\n\n    // 2nd iteration\n    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());\n    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);\n    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());\n    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);\n    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));\n    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);\n\n    // 3rd iteration\n    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());\n    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);\n    ypower2 = 
Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());\n    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);\n    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));\n    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);\n\n    return Q6_Vsf_equals_Vqf32(temp);\n}\n\n// Compute sqrt(x) as x*inv_sqrt(x)\n#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store)                \\\n    do {                                                                     \\\n        dst_type * restrict vdst = (dst_type *) dst;                         \\\n        src_type * restrict vsrc = (src_type *) src;                         \\\n                                                                             \\\n        const uint32_t nvec = n / VLEN_FP32;                                 \\\n        const uint32_t nloe = n % VLEN_FP32;                                 \\\n                                                                             \\\n        uint32_t i = 0;                                                      \\\n                                                                             \\\n        _Pragma(\"unroll(4)\")                                                 \\\n        for (; i < nvec; i++) {                                              \\\n            HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]);                \\\n            HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]);             \\\n            vdst[i] = sqrt_res;                                              \\\n        }                                                                    \\\n        if (nloe) {                                                          \\\n            HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]);                \\\n            HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]);             \\\n            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res);      \\\n        }                                                          
          \\\n    } while(0)\n\nstatic inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    assert((unsigned long) src % 128 == 0);\n    hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) dst % 128 == 0);\n    hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);\n}\n\nstatic inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    assert((unsigned long) src % 128 == 0);\n    hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {\n    hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);\n}\n\nstatic inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {\n    if ((unsigned long) dst % 128 == 0) {\n        if ((unsigned long) src % 128 == 0) {\n            hvx_sqrt_f32_aa(dst, src, num_elems);\n        } else {\n            hvx_sqrt_f32_au(dst, src, num_elems);\n        }\n    } else {\n        if ((unsigned long) src % 128 == 0) {\n            hvx_sqrt_f32_ua(dst, src, num_elems);\n        } else {\n            hvx_sqrt_f32_uu(dst, src, num_elems);\n        }\n    }\n}\n\n#endif /* HVX_SQRT_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-types.h",
    "content": "#ifndef HVX_TYPES_H\n#define HVX_TYPES_H\n\n#include <stdbool.h>\n#include <stdint.h>\n\n#include <hexagon_types.h>\n\n#define SIZEOF_FP32 (4)\n#define SIZEOF_FP16 (2)\n#define VLEN        (128)\n#define VLEN_FP32   (VLEN / SIZEOF_FP32)\n#define VLEN_FP16   (VLEN / SIZEOF_FP16)\n\ntypedef union {\n    HVX_Vector v;\n    uint8_t    b[VLEN];\n    uint16_t   h[VLEN_FP16];\n    uint32_t   w[VLEN_FP32];\n    __fp16     fp16[VLEN_FP16];\n    float      fp32[VLEN_FP32];\n} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;\n\ntypedef struct {\n    HVX_Vector v[2];\n} HVX_Vector_x2;\n\ntypedef struct {\n    HVX_Vector v[4];\n} HVX_Vector_x4;\n\ntypedef struct {\n    HVX_Vector v[8];\n} HVX_Vector_x8;\n\n#endif /* HVX_TYPES_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/hvx-utils.h",
    "content": "#ifndef HVX_UTILS_H\n#define HVX_UTILS_H\n\n#include \"hex-utils.h\"\n\n#include \"hvx-types.h\"\n#include \"hvx-copy.h\"\n#include \"hvx-scale.h\"\n#include \"hvx-exp.h\"\n#include \"hvx-inverse.h\"\n#include \"hvx-reduce.h\"\n#include \"hvx-sigmoid.h\"\n#include \"hvx-sqrt.h\"\n#include \"hvx-arith.h\"\n#include \"hvx-div.h\"\n#include \"hvx-base.h\"\n\n#ifndef GATHER_TYPE\n#    if defined(__hexagon__)\n#        define GATHER_TYPE(_a) (intptr_t) _a\n#    else\n#        define GATHER_TYPE(_a) (HVX_Vector *) _a\n#    endif\n#endif\n\n#endif /* HVX_UTILS_H */\n"
  },
  {
    "path": "src/ggml-hexagon/htp/main.c",
    "content": "#pragma clang diagnostic ignored \"-Wgnu-zero-variadic-macro-arguments\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n#include <AEEStdErr.h>\n#include <dspqueue.h>\n#include <HAP_compute_res.h>\n#include <HAP_etm_config.h>\n#include <HAP_mem.h>\n#include <HAP_power.h>\n#include <HAP_ps.h>\n#include <qurt.h>\n#include <qurt_thread.h>\n#include <remote.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hex-utils.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n#include \"worker-pool.h\"\n\nAEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {\n    struct htp_context * ctx;\n    int                  err = 0;\n\n    ctx = calloc(1, sizeof(*ctx));\n    if (ctx == NULL) {\n        return AEE_ENOMEMORY;\n    }\n\n    // Use the context structure as a handle\n    *handle = (remote_handle64) ctx;\n\n    // Enable FARF logs\n    HAP_setFARFRuntimeLoggingParams(0xffff, NULL, 0);\n\n    // Set client class\n    {\n        HAP_power_request_t request;\n        memset(&request, 0, sizeof(HAP_power_request_t));\n        request.type    = HAP_power_set_apptype;\n        request.apptype = HAP_POWER_COMPUTE_CLIENT_CLASS;\n\n        if ((err = HAP_power_set((void *) ctx, &request)) != 0) {\n            return err;\n        }\n    }\n\n    {\n        HAP_power_request_t request;\n        memset(&request, 0, sizeof(request));\n\n        request.type                              = HAP_power_set_DCVS_v3;\n        request.dcvs_v3.set_dcvs_enable           = TRUE;\n        request.dcvs_v3.dcvs_enable               = TRUE;\n        request.dcvs_v3.dcvs_option               = HAP_DCVS_V2_PERFORMANCE_MODE;\n        request.dcvs_v3.set_bus_params            = TRUE;\n        request.dcvs_v3.bus_params.min_corner     = HAP_DCVS_VCORNER_MAX;\n        request.dcvs_v3.bus_params.max_corner     = 
HAP_DCVS_VCORNER_MAX;\n        request.dcvs_v3.bus_params.target_corner  = HAP_DCVS_VCORNER_MAX;\n        request.dcvs_v3.set_core_params           = TRUE;\n        request.dcvs_v3.core_params.min_corner    = HAP_DCVS_VCORNER_MAX;\n        request.dcvs_v3.core_params.max_corner    = HAP_DCVS_VCORNER_MAX;\n        request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX;\n        request.dcvs_v3.set_sleep_disable         = TRUE;\n        request.dcvs_v3.sleep_disable             = TRUE;\n        if ((err = HAP_power_set((void *) ctx, &request)) != 0) {\n            return err;\n        }\n\n        memset(&request, 0, sizeof(request));\n        request.type         = HAP_power_set_HVX;\n        request.hvx.power_up = TRUE;\n        if ((err = HAP_power_set((void *) ctx, &request)) != 0) {\n            return err;\n        }\n    }\n\n    {\n        // Power on HMX\n        HAP_power_request_t request;\n        memset(&request, 0, sizeof(HAP_power_request_t));\n        request.type         = HAP_power_set_HMX;\n        request.hmx.power_up = TRUE;\n        FARF(ALWAYS, \"Powering HMX on\\n\");\n        err = HAP_power_set((void *) &ctx, &request);\n        if (err != AEE_SUCCESS) {\n            FARF(ERROR, \"Error powering on HMX.\");\n            return err;\n        }\n    }\n\n    return AEE_SUCCESS;\n}\n\nAEEResult htp_iface_close(remote_handle64 handle) {\n    struct htp_context * ctx = (struct htp_context *) handle;\n\n    if (!ctx) {\n        return AEE_EBADPARM;\n    }\n\n    if (ctx->queue) {\n        FARF(ERROR, \"Closing handle with queue still open\");\n        return AEE_EITEMBUSY;\n    }\n\n    free(ctx);\n    return AEE_SUCCESS;\n}\n\nAEEResult htp_iface_enable_etm(remote_handle64 handle) {\n    int err = HAP_user_etm_enable();\n    if (err) {\n        if (err == AEE_EVERSIONNOTSUPPORT) {\n            FARF(ERROR, \"API HAP_user_etm_enable is not supported\\n\");\n        } else {\n            FARF(ERROR, \"Error executing HAP_user_etm_enable 
with error code : 0x%x\\n\", err);\n        }\n    }\n    return err;\n}\n\nAEEResult htp_iface_disable_etm(remote_handle64 handle) {\n    int err = HAP_user_etm_disable();\n    if (err) {\n        if (err == AEE_EVERSIONNOTSUPPORT) {\n            FARF(ERROR, \"API HAP_user_etm_disable is not supported\\n\");\n        } else {\n            FARF(ERROR, \"Error executing HAP_user_etm_disable with error code : 0x%x\\n\", err);\n        }\n    }\n    return err;\n}\n\nstatic int vtcm_acquire(struct htp_context * ctx) {\n    int err;\n    if (!ctx->vtcm_valid) {\n        // Temporarily bump thread priority to make sure it's higher than other sessions.\n        // This way the resource manager will notify the other thread to release VTCM.\n        // Note that we need to reaquire VTCM at normal priority for this to work next time.\n        qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10);\n        err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);\n        if (err != 0) {\n            FARF(ERROR, \"Failed to acquire VTCM: 0x%08x\", (unsigned)err);\n            abort();\n        }\n        HAP_compute_res_release_cached(ctx->vtcm_rctx);\n        qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio);\n\n        err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);\n        if (err != 0) {\n            FARF(ERROR, \"Failed to acquire VTCM: 0x%08x\", (unsigned)err);\n            abort();\n        }\n        ctx->vtcm_valid = true;\n    }\n\n    ctx->vtcm_inuse = true;\n    return 0;\n}\n\nstatic int vtcm_release(struct htp_context * ctx) {\n    ctx->vtcm_inuse = false;\n\n    if (ctx->vtcm_valid && ctx->vtcm_needs_release) {\n        ctx->vtcm_valid         = false;\n        ctx->vtcm_needs_release = false;\n        HAP_compute_res_release_cached(ctx->vtcm_rctx);\n    }\n\n    return 0;\n}\n\nstatic int vtcm_release_callback(unsigned int rctx, void * state) {\n    struct htp_context * ctx = (struct htp_context *) 
state;\n\n    if (!ctx || ctx->vtcm_rctx != rctx) {\n        return AEE_EBADPARM;\n    }\n\n    // If VTCM is not inuse (not processing Ops) release it right here\n    // otherwise we'll release it once we're done with the current Op.\n\n    if (ctx->vtcm_inuse) {\n        ctx->vtcm_needs_release = true;\n        return 0;\n    }\n\n    ctx->vtcm_valid = false;\n    HAP_compute_res_release_cached(ctx->vtcm_rctx);\n\n    return 0;\n}\n\nstatic int vtcm_alloc(struct htp_context * ctx) {\n    unsigned int vtcm_size = 8 * 1024 * 1024;  // 8MB default\n    HAP_compute_res_query_VTCM(0, &vtcm_size, NULL, NULL, NULL);\n\n    compute_res_attr_t attr;\n    HAP_compute_res_attr_init(&attr);\n    HAP_compute_res_attr_set_serialize(&attr, 0);\n    HAP_compute_res_attr_set_cache_mode(&attr, 1);\n    HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size);\n    HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);\n    HAP_compute_res_attr_set_hmx_param(&attr, 1);\n\n    // Allocate VTCM for scratch pads\n    uint32_t rctx = HAP_compute_res_acquire(&attr, 1000000 /* timeout */);\n    if (!rctx) {\n        FARF(ERROR, \"failed to allocate %zu bytes VTCM\\n\", ctx->vtcm_size);\n        return AEE_ENOMEMORY;\n    }\n\n    void * vtcm_ptr;\n    if (HAP_compute_res_attr_get_vtcm_ptr_v2(&attr, &vtcm_ptr, &vtcm_size) != 0) {\n        HAP_compute_res_release(rctx);\n        FARF(ERROR, \"failed to allocate %zu bytes VTCM (new)\\n\", ctx->vtcm_size);\n        return AEE_ENOMEMORY;\n    }\n\n    ctx->vtcm_base          = (uint8_t *) vtcm_ptr;\n    ctx->vtcm_size          = vtcm_size;\n    ctx->vtcm_rctx          = rctx;\n    ctx->vtcm_valid         = false;\n    ctx->vtcm_inuse         = false;\n    ctx->vtcm_needs_release = false;\n\n    return 0;\n}\n\nstatic void vtcm_free(struct htp_context * ctx) {\n    if (ctx->vtcm_rctx) {\n        HAP_compute_res_release(ctx->vtcm_rctx);\n        ctx->vtcm_base = 0;\n        ctx->vtcm_rctx = 0;\n    
}\n}\n\nstatic void htp_packet_callback(dspqueue_t queue, int error, void * context);\nstatic void htp_error_callback(dspqueue_t queue, int error, void * context);\n\nAEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {\n    struct htp_context * ctx = (struct htp_context *) handle;\n\n    if (!ctx) {\n        return AEE_EBADPARM;\n    }\n\n    if (ctx->queue) {\n        FARF(ERROR, \"Queue already open\");\n        return AEE_EITEMBUSY;\n    }\n\n    // Import queue created on the CPU\n    int err = dspqueue_import(dsp_queue_id,         // Queue ID from dspqueue_export\n                              htp_packet_callback,  // Packet callback\n                              htp_error_callback,   // Error callback; no errors expected on the DSP\n                              (void *) ctx,         // Callback context\n                              &ctx->queue);\n\n    if (err) {\n        FARF(ERROR, \"Queue import failed with 0x%08x\", (unsigned) err);\n        return err;\n    }\n\n    ctx->thread_id   = qurt_thread_get_id();\n    ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id);\n\n    // allocate VTCM\n    err = vtcm_alloc(ctx);\n    if (err != AEE_SUCCESS) {\n        FARF(ERROR, \"Unable to allocate VTCM\");\n        return AEE_ENOMEMORY;\n    }\n\n    qurt_sysenv_max_hthreads_t hw_threads;\n    qurt_sysenv_get_max_hw_threads(&hw_threads);\n    uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;\n\n    if (n_hvx == 0) {\n        n_hvx = hw_nhvx;\n    }\n    if (n_hvx > hw_threads.max_hthreads) {\n        n_hvx = hw_threads.max_hthreads;\n    }\n    if (n_hvx > HTP_MAX_NTHREADS) {\n        n_hvx = HTP_MAX_NTHREADS;\n    }\n\n    ctx->n_threads = n_hvx;\n    for (int i = 0; i < ctx->n_threads; i++) {\n        // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541\n        ctx->dma[i] = dma_queue_create(64);\n    }\n\n    // init worker pool\n    err = 
worker_pool_init(&ctx->worker_pool, n_hvx);\n    if (err != AEE_SUCCESS) {\n        FARF(ERROR, \"Unable to create worker pool\");\n        return err;\n    }\n\n    FARF(HIGH, \"session %u started: n-hvx %u vtcm-size %zu vtcm-rctx %u n-threads %u thread-id %d thread-prio %d \\n\",\n         sess_id, hw_nhvx, ctx->vtcm_size, ctx->vtcm_rctx, ctx->n_threads, ctx->thread_id, ctx->thread_prio);\n\n    return AEE_SUCCESS;\n}\n\nAEEResult htp_iface_stop(remote_handle64 handle) {\n    struct htp_context * ctx = (struct htp_context *) handle;\n    if (!ctx) {\n        return AEE_EBADPARM;\n    }\n\n    if (!ctx->queue) {\n        FARF(ERROR, \"Queue not open\");\n        return AEE_EBADSTATE;\n    }\n\n    // Close queue. dspqueue_close() will also wait for callbacks to finish.\n    int err    = dspqueue_close(ctx->queue);\n    ctx->queue = NULL;\n    if (err != 0) {\n        FARF(ERROR, \"Queue close failed with 0x%08x\", (unsigned) err);\n        return err;\n    }\n\n    if (ctx->worker_pool) {\n        // Release worker pool\n        worker_pool_release(&ctx->worker_pool);\n    }\n\n    for (int i = 0; i < ctx->n_threads; i++) {\n        dma_queue_delete(ctx->dma[i]);\n    }\n\n    vtcm_free(ctx);\n\n    return AEE_SUCCESS;\n}\n\nstatic void htp_error_callback(dspqueue_t queue, int error, void * context) {\n    // No errors expected on the DSP.\n    FARF(ERROR, \"Error callback: 0x%08x\", (unsigned) error);\n}\n\nstruct profile_data {\n    uint64_t usecs;\n    uint64_t cycles;\n    uint64_t pkts;\n};\n\nstatic inline void profile_start(struct profile_data * d) {\n    d->usecs  = HAP_perf_get_qtimer_count();\n    d->cycles = hex_get_cycles();\n    d->pkts   = hex_get_pktcnt();\n}\n\nstatic inline void profile_stop(struct profile_data * d) {\n    d->usecs  = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);\n    d->cycles = hex_get_cycles() - d->cycles;\n    d->pkts   = hex_get_pktcnt() - d->pkts;\n}\n\nstatic int send_htp_rsp(struct htp_context *     
c,\n                        uint32_t                 op,\n                        uint32_t                 status,\n                        struct dspqueue_buffer * bufs,\n                        size_t                   n_bufs,\n                        struct profile_data *    prof) {\n    // Prep response struct\n    struct htp_general_rsp rsp;\n    rsp.op          = op;\n    rsp.status      = status;\n    rsp.prof_usecs  = prof->usecs;\n    rsp.prof_cycles = prof->cycles;\n    rsp.prof_pkts   = prof->pkts;\n\n    int err = dspqueue_write(c->queue,\n                             0,                       // Flags\n                             n_bufs,\n                             bufs,                    // Buffer references\n                             sizeof(rsp),\n                             (const uint8_t *) &rsp,  // Message\n                             DSPQUEUE_TIMEOUT_NONE);\n\n    if (err != 0) {\n        FARF(ERROR, \"dspqueue_write failed: 0x%08x\", (unsigned) err);\n    }\n\n    return err;\n}\n\nstatic void proc_matmul_req(struct htp_context *     ctx,\n                            struct htp_general_req * req,\n                            struct dspqueue_buffer * bufs,\n                            size_t                   n_bufs) {\n    struct dspqueue_buffer rsp_bufs[1];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[2].fd;\n    rsp_bufs[0].ptr    = bufs[2].ptr;\n    rsp_bufs[0].size   = bufs[2].size;\n    rsp_bufs[0].offset = bufs[2].offset;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.src1                   = req->src1;\n    octx.dst                    = req->dst;\n    octx.flags                  = 
req->flags;\n    octx.op                     = req->op;\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.src1.data = (uint32_t) bufs[1].ptr;\n    octx.dst.data  = (uint32_t) bufs[2].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_matmul(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[1];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[1].fd;\n    rsp_bufs[0].ptr    = bufs[1].ptr;\n    rsp_bufs[0].offset = bufs[1].offset;\n    rsp_bufs[0].size   = bufs[1].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.dst.data  = (uint32_t) bufs[1].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_argsort(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, 
rsp_bufs, 1, &prof);\n}\n\nstatic void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[1];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[1].fd;\n    rsp_bufs[0].ptr    = bufs[1].ptr;\n    rsp_bufs[0].offset = bufs[1].offset;\n    rsp_bufs[0].size   = bufs[1].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.dst.data  = (uint32_t) bufs[1].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_cpy(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[1];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[2].fd;\n    rsp_bufs[0].ptr    = bufs[2].ptr;\n    rsp_bufs[0].offset = bufs[2].offset;\n    rsp_bufs[0].size   = bufs[2].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 
0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.src1                   = req->src1;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.src1.data = (uint32_t) bufs[1].ptr;\n    octx.dst.data  = (uint32_t) bufs[2].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_get_rows(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_matmul_id_req(struct htp_context *     ctx,\n                               struct htp_general_req * req,\n                               struct dspqueue_buffer * bufs,\n                               size_t                   n_bufs) {\n    struct dspqueue_buffer rsp_bufs[1];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[3].fd;\n    rsp_bufs[0].ptr    = bufs[3].ptr;\n    rsp_bufs[0].size   = bufs[3].size;\n    rsp_bufs[0].offset = bufs[3].offset;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.src1                   = req->src1;\n    octx.src2                   = req->src2;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.src1.data = 
(uint32_t) bufs[1].ptr;\n    octx.src2.data = (uint32_t) bufs[2].ptr;\n    octx.dst.data  = (uint32_t) bufs[3].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_matmul_id(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[1];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[2].fd;\n    rsp_bufs[0].ptr    = bufs[2].ptr;\n    rsp_bufs[0].offset = bufs[2].offset;\n    rsp_bufs[0].size   = bufs[2].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.src1                   = req->src1;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.src1.data = (uint32_t) bufs[1].ptr;\n    octx.dst.data  = (uint32_t) bufs[2].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_binary(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_add_id_req(struct htp_context * ctx, struct 
htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[1];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[3].fd;\n    rsp_bufs[0].ptr    = bufs[3].ptr;\n    rsp_bufs[0].offset = bufs[3].offset;\n    rsp_bufs[0].size   = bufs[3].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.src1                   = req->src1;\n    octx.src2                   = req->src2;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.src1.data = (uint32_t) bufs[1].ptr;\n    octx.src2.data = (uint32_t) bufs[2].ptr;\n    octx.dst.data  = (uint32_t) bufs[3].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_binary(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[1].fd;\n    rsp_bufs[0].ptr    = bufs[1].ptr;\n    rsp_bufs[0].offset = bufs[1].offset;\n    rsp_bufs[0].size   = bufs[1].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.dst.data  = (uint32_t) bufs[1].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_unary(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[1].fd;\n    rsp_bufs[0].ptr    = bufs[1].ptr;\n    rsp_bufs[0].offset = bufs[1].offset;\n    rsp_bufs[0].size   = bufs[1].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.dst.data  = (uint32_t) bufs[1].ptr;\n    
octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_sum_rows(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];\n\n    // We've written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[2].fd;\n    rsp_bufs[0].ptr    = bufs[2].ptr;\n    rsp_bufs[0].offset = bufs[2].offset;\n    rsp_bufs[0].size   = bufs[2].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup OP context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.src1                   = req->src1;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.src1.data = (uint32_t) bufs[1].ptr;\n    octx.dst.data  = (uint32_t) bufs[2].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_ssm_conv(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_activations_req(struct htp_context *     ctx,\n                                
 struct htp_general_req * req,\n                                 struct dspqueue_buffer * bufs,\n                                 uint32_t                 n_bufs) {\n    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];\n\n    int write_idx = (n_bufs == 3) ? 2 : 1;\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[write_idx].fd;\n    rsp_bufs[0].ptr    = bufs[write_idx].ptr;\n    rsp_bufs[0].offset = bufs[write_idx].offset;\n    rsp_bufs[0].size   = bufs[write_idx].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    if (3 == n_bufs) {\n        octx.src1 = req->src1;\n    }\n    octx.dst   = req->dst;\n    octx.flags = req->flags;\n    octx.op    = req->op;\n\n    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    if (3 == n_bufs) {\n        octx.src1.data = (uint32_t) bufs[1].ptr;\n        octx.dst.data  = (uint32_t) bufs[2].ptr;\n    } else {\n        octx.dst.data = (uint32_t) bufs[1].ptr;\n    }\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        if (octx.op == HTP_OP_SOFTMAX) {\n            rsp_status = op_softmax(&octx);\n        } else {\n            rsp_status = op_activations(&octx);\n        }\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_rope_req(struct htp_context *     ctx,\n                          struct htp_general_req * req,\n                          struct 
dspqueue_buffer * bufs,\n                          uint32_t                 n_bufs) {\n    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];\n\n    int write_idx = n_bufs - 1;\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = bufs[write_idx].fd;\n    rsp_bufs[0].ptr    = bufs[write_idx].ptr;\n    rsp_bufs[0].offset = bufs[write_idx].offset;\n    rsp_bufs[0].size   = bufs[write_idx].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.src1                   = req->src1;\n    if (4 == n_bufs) {\n        octx.src2 = req->src2;\n    }\n    octx.dst   = req->dst;\n    octx.flags = req->flags;\n    octx.op    = req->op;\n\n    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.src1.data = (uint32_t) bufs[1].ptr;\n    if (4 == n_bufs) {\n        octx.src2.data = (uint32_t) bufs[2].ptr;\n        octx.dst.data  = (uint32_t) bufs[3].ptr;\n    } else {\n        octx.dst.data = (uint32_t) bufs[2].ptr;\n    }\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_rope(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {\n    struct dspqueue_buffer rsp_bufs[1];\n\n    // We had written to the output buffer, we'd also need to flush it\n    rsp_bufs[0].fd     = 
bufs[2].fd;\n    rsp_bufs[0].ptr    = bufs[2].ptr;\n    rsp_bufs[0].offset = bufs[2].offset;\n    rsp_bufs[0].size   = bufs[2].size;\n    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU\n\n    // Setup Op context\n    struct htp_ops_context octx = { 0 };\n    octx.ctx                    = ctx;\n    octx.src0                   = req->src0;\n    octx.src1                   = req->src1;\n    octx.dst                    = req->dst;\n    octx.flags                  = req->flags;\n    octx.op                     = req->op;\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n    octx.src1.data = (uint32_t) bufs[1].ptr;\n    octx.dst.data  = (uint32_t) bufs[2].ptr;\n    octx.n_threads = ctx->n_threads;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_set_rows(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);\n}\n\nstatic void proc_flash_attn_ext_req(struct htp_context *     ctx,\n                                    struct htp_general_req * req,\n                                    struct dspqueue_buffer * bufs,\n                                    uint32_t                 n_bufs) {\n    // Setup Op context\n    struct htp_ops_context octx;\n    memset(&octx, 0, sizeof(octx));\n\n    octx.ctx   = ctx;\n    octx.n_threads = ctx->n_threads;\n\n    octx.src0  = req->src0;\n    octx.src1  = req->src1;\n    octx.src2  = req->src2;\n    octx.src3  = req->src3;\n    octx.src4  = req->src4;\n    octx.dst   = req->dst;\n    octx.flags = req->flags;\n    octx.op    = req->op;\n\n    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));\n\n    // Update data pointers\n    octx.src0.data = (uint32_t) bufs[0].ptr;\n  
  octx.src1.data = (uint32_t) bufs[1].ptr;\n    octx.src2.data = (uint32_t) bufs[2].ptr;\n\n    int last_buf = 3;\n\n    if (octx.src3.ne[0]) {\n        octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid\n    }\n\n    if (octx.src4.ne[0]) {\n        octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid\n    }\n\n    octx.dst.data = (uint32_t) bufs[last_buf].ptr;\n\n    struct profile_data prof;\n    profile_start(&prof);\n\n    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;\n    if (vtcm_acquire(ctx) == AEE_SUCCESS) {\n        rsp_status = op_flash_attn_ext(&octx);\n        vtcm_release(ctx);\n    }\n\n    profile_stop(&prof);\n\n    struct dspqueue_buffer rsp_buf = bufs[last_buf];\n    rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP\n                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU\n\n    // Respond with rsp_buf: bufs[last_buf] still holds the flags it was read with.\n    send_htp_rsp(ctx, req->op, rsp_status, &rsp_buf, 1, &prof);\n}\n\nstatic void htp_packet_callback(dspqueue_t queue, int error, void * context) {\n    struct htp_context * ctx = (struct htp_context *) context;\n\n    // Repeatedly read packets from the queue until it's empty. We don't\n    // necessarily get a separate callback for each packet, and new packets\n    // may arrive while we're processing the previous one. 
This ensures we\n    // keep the DSP busy as much as possible and avoid waiting for the CPU.\n\n    while (1) {\n        struct htp_general_req req;\n        uint32_t               req_size;\n\n        struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];\n        uint32_t               n_bufs;\n        uint32_t               flags;\n\n        // Read packet from queue\n        int err = dspqueue_read_noblock(queue, &flags,\n                                        HTP_MAX_PACKET_BUFFERS,  // Maximum number of buffer references\n                                        &n_bufs,                 // Number of buffer references\n                                        bufs,                    // Buffer references\n                                        sizeof(req),             // Max message length\n                                        &req_size,               // Message length\n                                        (uint8_t *) &req);       // Message\n\n        if (err == AEE_EWOULDBLOCK) {\n            // Consumed all packets available for now\n            return;\n        }\n\n        if (err != 0) {\n            FARF(ERROR, \"dspqueue_read_noblock failed: 0x%08x\", (unsigned) err);\n            return;\n        }\n\n        if (req_size != sizeof(req)) {\n            FARF(ERROR, \"Invalid request size\");\n            continue;\n        }\n\n        if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) {\n            // Host wants early notification\n            dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);\n        }\n\n        // Process packet based on its message type\n        switch (req.op) {\n            case HTP_OP_MUL_MAT:\n                if (n_bufs != 3) {\n                    FARF(ERROR, \"Bad matmul-req buffer list\");\n                    continue;\n                }\n                proc_matmul_req(ctx, &req, bufs, n_bufs);\n                break;\n\n            case HTP_OP_MUL_MAT_ID:\n                if (n_bufs != 4) {\n                    
FARF(ERROR, \"Bad matmul-id-req buffer list\");\n                    continue;\n                }\n                proc_matmul_id_req(ctx, &req, bufs, n_bufs);\n                break;\n\n            case HTP_OP_MUL:\n            case HTP_OP_ADD:\n            case HTP_OP_SUB:\n            case HTP_OP_DIV:\n                if (n_bufs != 3) {\n                    FARF(ERROR, \"Bad binary-req buffer list\");\n                    continue;\n                }\n                proc_binary_req(ctx, &req, bufs);\n                break;\n\n            case HTP_OP_RMS_NORM:\n            case HTP_OP_SCALE:\n                if (n_bufs != 2) {\n                    FARF(ERROR, \"Bad unary-req buffer list\");\n                    continue;\n                }\n\n                proc_unary_req(ctx, &req, bufs);\n                break;\n\n            case HTP_OP_SQR:\n            case HTP_OP_SQRT:\n                if (n_bufs != 2) {\n                    FARF(ERROR, \"Bad unary-req buffer list\");\n                    continue;\n                }\n\n                proc_unary_req(ctx, &req, bufs);\n                break;\n\n            case HTP_OP_SUM_ROWS:\n                if (n_bufs != 2) {\n                    FARF(ERROR, \"Bad unary-req buffer list\");\n                    continue;\n                }\n\n                proc_sum_rows_req(ctx, &req, bufs);\n                break;\n\n            case HTP_OP_UNARY_SILU:\n            case HTP_OP_UNARY_GELU:\n                if (n_bufs != 2) {\n                    FARF(ERROR, \"Bad act-req buffer list\");\n                    continue;\n                }\n                proc_activations_req(ctx, &req, bufs, n_bufs);\n                break;\n\n            case HTP_OP_GLU_SWIGLU:\n            case HTP_OP_GLU_SWIGLU_OAI:\n            case HTP_OP_SOFTMAX:\n            case HTP_OP_GLU_GEGLU:\n                if ((n_bufs != 2) && (n_bufs != 3)) {\n                    FARF(ERROR, \"Bad act-req buffer list\");\n                    continue;\n 
               }\n                proc_activations_req(ctx, &req, bufs, n_bufs);\n                break;\n\n            case HTP_OP_ADD_ID:\n                if (n_bufs != 4) {\n                    FARF(ERROR, \"Bad add-id-req buffer list\");\n                    continue;\n                }\n                proc_add_id_req(ctx, &req, bufs);\n                break;\n\n            case HTP_OP_ROPE:\n                if ((n_bufs != 3) && (n_bufs != 4)) {\n                    FARF(ERROR, \"Bad rope-req buffer list\");\n                    continue;\n                }\n                proc_rope_req(ctx, &req, bufs, n_bufs);\n                break;\n\n            case HTP_OP_FLASH_ATTN_EXT:\n                if (!(n_bufs >= 4 && n_bufs <= 6)) {\n                    FARF(ERROR, \"Bad flash-attn-ext-req buffer list\");\n                    continue;\n                }\n                proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);\n                break;\n\n            case HTP_OP_SET_ROWS:\n                if (n_bufs != 3) {\n                    FARF(ERROR, \"Bad set-rows-req buffer list\");\n                    continue;\n                }\n                proc_set_rows_req(ctx, &req, bufs);\n                break;\n\n            case HTP_OP_GET_ROWS:\n                if (n_bufs != 3) {\n                    FARF(ERROR, \"Bad get-rows-req buffer list\");\n                    continue;\n                }\n                proc_get_rows_req(ctx, &req, bufs);\n                break;\n\n            case HTP_OP_CPY:\n                if (n_bufs != 2) {\n                    FARF(ERROR, \"Bad cpy-req buffer list\");\n                    continue;\n                }\n                proc_cpy_req(ctx, &req, bufs);\n                break;\n\n            case HTP_OP_ARGSORT:\n                if (n_bufs != 2) {\n                    FARF(ERROR, \"Bad argsort-req buffer list\");\n                    continue;\n                }\n                proc_argsort_req(ctx, &req, bufs);\n       
         break;\n\n            case HTP_OP_SSM_CONV:\n                if (n_bufs != 3) {\n                    FARF(ERROR, \"Bad ssm-conv-req buffer list\");\n                    continue;\n                }\n                proc_ssm_conv_req(ctx, &req, bufs);\n                break;\n\n            default:\n                FARF(ERROR, \"Unknown Op %u\", req.op);\n                break;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/matmul-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wgnu-zero-variadic-macro-arguments\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n#include \"hvx-dump.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n#define MM_SPAD_SRC0_NROWS 16\n#define MM_SPAD_SRC1_NROWS 16\n#define MM_SPAD_DST_NROWS  2\n\nstruct htp_matmul_context {\n    const char * type;\n    struct htp_ops_context * octx;\n\n    void (*vec_dot_1x1)(const int n, float * restrict s0,\n         const void * restrict vx0,\n         const void * restrict vy0);\n\n    void (*vec_dot_2x1)(const int n, float * restrict s0,\n         const void * restrict vx0, const void * restrict vx1,\n         const void * restrict vy0);\n\n    void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,\n         const void * restrict vx0, const void * restrict vx1,\n         const void * restrict vy0, const void * restrict vy1);\n\n    // Precomputed values\n    uint32_t src0_nrows_per_thread;\n    uint32_t src1_nrows_per_thread;\n\n    struct fastdiv_values mm_div_ne12_ne1;\n    struct fastdiv_values mm_div_ne1;\n    struct fastdiv_values mm_div_r2;\n    struct fastdiv_values mm_div_r3;\n};\n\n// vdelta control to expand first 32 e8m0 values into 32 uint32 elements\nstatic const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {\n    0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,\n    0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04,\n    0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 
0x00, 0x09, 0x08, 0x00, 0x00, 0x02,\n    0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08,\n    0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48,\n    0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00,\n    0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,\n};\n\nstatic const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {\n    0,    0, 1,    0, 2,    0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,\n    0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,\n    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,\n    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,\n    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0,\n};\n\n// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales\n\nstatic inline size_t q8x4x2_row_size(uint32_t ne) {\n    // ensures perfect alignment of quants and full row\n    const uint32_t qk = QK_Q8_0x4x2;\n    const uint32_t nb = (ne + qk - 1) / qk;\n    return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);\n}\n\nstatic inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) {\n    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;\n\n    HVX_Vector v0_1 = vptr[0];  // first 256 elements (128 bytes)\n    HVX_Vector v2_3 = vptr[1];  // ...\n    HVX_Vector v4_5 = vptr[2];  // ...\n    HVX_Vector v6_7 = vptr[3];  // ...\n\n    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);\n    const HVX_Vector i8 = Q6_Vb_vsplat_R(8);\n\n    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F : first  128 elements\n    HVX_Vector 
v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4   : second 128 elements\n    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F ...\n    HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4\n    HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F\n    HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4\n    HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F\n    HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4\n\n    // Convert uint4 to int4 (i.e. x - 8)\n    v0 = Q6_Vb_vsub_VbVb(v0, i8);\n    v1 = Q6_Vb_vsub_VbVb(v1, i8);\n    v2 = Q6_Vb_vsub_VbVb(v2, i8);\n    v3 = Q6_Vb_vsub_VbVb(v3, i8);\n    v4 = Q6_Vb_vsub_VbVb(v4, i8);\n    v5 = Q6_Vb_vsub_VbVb(v5, i8);\n    v6 = Q6_Vb_vsub_VbVb(v6, i8);\n    v7 = Q6_Vb_vsub_VbVb(v7, i8);\n\n    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };\n    return r;\n}\n\nstatic HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {\n    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;\n\n    const uint32_t qk   = QK_Q4_0x4x2; // 256\n    const uint32_t nb   = n / qk;\n    const uint32_t nloe = n % qk;\n\n    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);\n    const HVX_Vector i8      = Q6_Vb_vsplat_R(8);\n\n    HVX_Vector_x8 r;\n    uint32_t i = 0;\n\n    #pragma unroll(2)\n    for (i=0; i < nb; i++) {\n        HVX_Vector v = vptr[i];                    // 256 elements (128 bytes)\n        HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4);  // & 0x0F : first  128 elements\n        HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4);    // >> 4   : second 128 elements\n        r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8);\n        r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8);\n    }\n\n    if (nloe) {\n        HVX_Vector v = vptr[i];                    // 256 elements (128 bytes)\n        HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4);  // & 0x0F : even 128 elements\n        HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4);    // >> 4   : odd  128 elements\n        HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // 
zip even:odd:...\n        r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8);\n        r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8);\n    }\n\n    return r;\n}\n\nstatic inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) {\n    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;\n\n    HVX_Vector v0_1 = vptr[0];  // first 256 elements (128 bytes)\n    HVX_Vector v2_3 = vptr[1];  // ...\n    HVX_Vector v4_5 = vptr[2];  // ...\n    HVX_Vector v6_7 = vptr[3];  // ...\n\n    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);\n    const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;\n\n    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F\n    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4\n    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F\n    HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4\n    HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F\n    HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4\n    HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F\n    HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4\n\n    v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);\n    v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);\n    v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);\n    v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);\n    v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);\n    v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);\n    v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);\n    v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);\n\n    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };\n    return r;\n}\n\nstatic inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {\n    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;\n\n    const uint32_t qk   = QK_Q4_0x4x2; // 256\n    const uint32_t nb   = n / qk;\n    const uint32_t nloe = n % qk;\n\n    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);\n    const HVX_Vector lut     = *(const HVX_Vector *) kvalues_mxfp4_lut;\n\n    HVX_Vector_x8 r;\n    uint32_t i = 
0;\n\n    #pragma unroll(2)\n    for (i=0; i < nb; i++) {\n        HVX_Vector v = vptr[i];                    // 256 elements (128 bytes)\n        HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4);  // & 0x0F : first  128 elements\n        HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4);    // >> 4   : second 128 elements\n        r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);\n        r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);\n    }\n\n    if (nloe) {\n        HVX_Vector v = vptr[i];                    // 256 elements (128 bytes)\n        HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4);  // & 0x0F : even 128 elements\n        HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4);    // >> 4   : odd  128 elements\n        HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...\n        r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);\n        r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);\n    }\n\n    return r;\n}\n\nstatic inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) {\n    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;\n\n    HVX_Vector v0 = vptr[0];  // first  128 vals\n    HVX_Vector v1 = vptr[1];  // ...\n    HVX_Vector v2 = vptr[2];  // ...\n    HVX_Vector v3 = vptr[3];  // ...\n    HVX_Vector v4 = vptr[4];  // ...\n    HVX_Vector v5 = vptr[5];  // ...\n    HVX_Vector v6 = vptr[6];  // ...\n    HVX_Vector v7 = vptr[7];  // ...\n\n    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };\n    return r;\n}\n\nstatic inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) {\n    return hvx_vec_load_q8x4x8_full(ptr);\n}\n\n// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).\n// Accumulate each block into a single int32 value.\n// Return a single HVX vector with 32x int32 accumulators.\n// This version is parameterized to support less than 1024 elements.\n// if() checks are optimized out at compile time -- make sure to pass N as a 
constexpr.\n\nstatic inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {\n    HVX_Vector r0 = Q6_V_vzero();\n    HVX_Vector r1 = Q6_V_vzero();\n    HVX_Vector r2 = Q6_V_vzero();\n    HVX_Vector r3 = Q6_V_vzero();\n    HVX_Vector r4 = Q6_V_vzero();\n    HVX_Vector r5 = Q6_V_vzero();\n    HVX_Vector r6 = Q6_V_vzero();\n    HVX_Vector r7 = Q6_V_vzero();\n\n    HVX_VectorPair p3;\n    HVX_VectorPair p2;\n    HVX_VectorPair p1;\n    HVX_VectorPair p0;\n\n    if (n >=  128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }\n    if (n >=  256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }\n    if (n >=  384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }\n    if (n >=  512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }\n    if (n >=  640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }\n    if (n >=  768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }\n    if (n >=  896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }\n    if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }\n\n    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }\n    if (n >=  384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }\n    if (n >=  640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }\n    if (n >=  896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }\n\n    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }\n    if (n >=  384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }\n    if (n >=  640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }\n    if (n >=  896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }\n\n    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }\n    if (n >=  640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }\n\n    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }\n    if (n >=  640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }\n\n    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }\n    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }\n\n    return r0;\n}\n\nstatic inline HVX_Vector 
hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {\n    HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]);\n    HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]);\n    HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]);\n    HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]);\n    HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]);\n    HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]);\n    HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]);\n    HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]);\n\n    HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4);\n    HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4);\n    HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4);\n    HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4);\n\n    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));\n    r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));\n    r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2));\n    r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3));\n\n    p0 = Q6_W_vdeal_VVR(r1, r0, -4);\n    p1 = Q6_W_vdeal_VVR(r3, r2, -4);\n\n    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));\n    r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));\n\n    p0 = Q6_W_vdeal_VVR(r1, r0, -4);\n    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));\n\n    return r0;\n}\n\nstatic inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {\n    if (n >= 512)\n        return hvx_vec_rmpy_x8_full(x, y);\n\n    return hvx_vec_rmpy_x8_partial(x, y, 512);\n}\n\nstatic void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {\n    assert(n % 32 == 0);  // min sub-block size\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n\n    const uint32_t qk = QK_Q4_0x4x2 * 4;\n\n    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t x_qblk_size = qk / 2;                                      // int4\n    const uint32_t 
x_qrow_size = n / 2;                                       // int4 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                          // int8\n    const uint32_t y_qrow_size = n;                                           // int8 (not padded)\n\n    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);            // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size);  // then scales\n\n    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first\n    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales\n\n    // Row sum (sf)\n    HVX_Vector r0_sum = Q6_V_vzero();\n\n    // Multiply and accumulate into int32.\n    // Compute combined scale (fp32).\n    // Apply scale to acc and accumulate into the row sum (qf32).\n\n    const uint32_t nb   = n / qk;  // num full blocks\n    const uint32_t nloe = n % qk;  // num leftover elemements\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q    + i * y_qblk_size);\n        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));\n\n        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));\n        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q    + i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q = 
hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));\n\n        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));\n        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));\n\n        // Zero out unused elements\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);\n        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n    }\n\n    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);\n\n    hvx_vec_store_u(s0, 4, r0_sum);\n}\n\nstatic void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,\n                                      const void * restrict vx0, const void * restrict vx1,\n                                      const void * restrict vy0) {\n    assert(n % 32 == 0);  // min sub-block size\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vx1 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n\n    const uint32_t qk = QK_Q4_0x4x2 * 4;\n\n    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t x_qblk_size = qk / 2;                                      // int4\n    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                          // int8\n    const uint32_t y_qrow_size = n;                                           // int8 (not padded)\n\n    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;     
       // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales\n    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first\n    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales\n\n    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first\n    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales\n\n    // Row sum (sf)\n    HVX_Vector r0_sum = Q6_V_vzero();\n    HVX_Vector r1_sum = Q6_V_vzero();\n\n    // Multiply and accumulate into int32.\n    // Compute combined scale (fp32).\n    // Apply scale to acc and accumulate into the row sum (qf32).\n\n    const uint32_t nb   = n / qk;  // num full blocks\n    const uint32_t nloe = n % qk;  // num leftover elemements\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q    + i * y_qblk_size);\n        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);\n        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));\n        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));\n\n        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));\n        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));\n        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);\n\n        r0_sum = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q    + i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));\n        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));\n\n        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));\n        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));\n        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));\n\n        // Zero out unused elements\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);\n        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);\n        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);\n        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));\n    }\n\n    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);\n    hvx_vec_store_u(s0, 8, rsum);\n}\n\nstatic void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,\n                                        const 
void * restrict vx0, const void * restrict vx1,\n                                        const void * restrict vy0, const void * restrict vy1) {\n    assert(n % 32 == 0);\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vx1 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n    assert((unsigned long) vy1 % 128 == 0);\n\n    const uint32_t qk = QK_Q4_0x4x2 * 4;\n\n    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t x_qblk_size = qk / 2;                                      // int4\n    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                          // int8\n    const uint32_t y_qrow_size = n;                                           // int8 (not padded)\n\n    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales\n    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first\n    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales\n\n    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first\n    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales\n    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first\n    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales\n\n    // Row sums (sf) - 4 accumulators for 2×2 tile\n    HVX_Vector r0_c0_sum = Q6_V_vzero();\n    HVX_Vector r0_c1_sum = Q6_V_vzero();\n    HVX_Vector r1_c0_sum = Q6_V_vzero();\n    HVX_Vector r1_c1_sum = Q6_V_vzero();\n\n    const uint32_t nb   = n / qk;  // num full blocks\n 
   const uint32_t nloe = n % qk;  // num leftover elements\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        // Load src1 columns (reused across both src0 rows)\n        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);\n        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);\n\n        // Load src0 rows (reused across both src1 columns)\n        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);\n        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);\n\n        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1\n        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));\n        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));\n        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));\n        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));\n\n        // Load scales\n        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d   + i * y_dblk_size));\n        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d   + i * y_dblk_size));\n        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));\n\n        // Compute combined scales\n        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));\n        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));\n        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));\n        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));\n\n        // Apply scales and accumulate\n        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);\n        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);\n        
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);\n        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);\n\n        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));\n        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));\n        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));\n        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q   + i * y_qblk_size, nloe);\n        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q   + i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q  = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n        HVX_Vector_x8 r1_q  = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);\n\n        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));\n        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));\n        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));\n        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));\n\n        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d   + i * y_dblk_size));\n        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d   + i * y_dblk_size));\n        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));\n        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));\n        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));\n        HVX_Vector r1_c1_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));\n\n        // Zero out unused scales\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);\n        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);\n        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);\n        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);\n        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);\n        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);\n        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);\n        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);\n\n        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);\n        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);\n        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);\n        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);\n\n        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));\n        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));\n        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));\n        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));\n    }\n\n    // Reduce and store results\n    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);\n    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);\n\n    hvx_vec_store_u(s0, 8, r0_r1_c0_sum);  // row0,col0 row1,col0\n    hvx_vec_store_u(s1, 8, r0_r1_c1_sum);  // row0,col1 row1,col1\n}\n\nstatic void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {\n    assert(n % 32 == 0);  // min sub-block size\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n\n    const uint32_t qk = QK_Q4_0x4x2 * 4;\n\n    const uint32_t x_dblk_size = 8 * 4 * 2;                                  // 32x __fp16\n    const uint32_t x_qblk_size = qk;    
                                     // int8\n    const uint32_t x_qrow_size = n;                                          // int8 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                         // int8\n    const uint32_t y_qrow_size = n;                                          // int8 (not padded)\n\n    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);           // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales\n\n    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);              // quants first\n    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);    // then scales\n\n    // Row sum (sf)\n    HVX_Vector r0_sum = Q6_V_vzero();\n\n    // Multiply and accumulate into int32.\n    // Compute combined scale (fp32).\n    // Apply scale to acc and accumulate into the row sum (qf32).\n\n    const uint32_t nb   = n / qk;  // num full blocks\n    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q    + i * y_qblk_size);\n        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));\n\n        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));\n        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q    
+ i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));\n\n        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));\n        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));\n\n        // Zero out unused elements\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);\n        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n    }\n\n    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);\n\n    hvx_vec_store_u(s0, 4, r0_sum);\n}\n\nstatic void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,\n                                      const void * restrict vx0, const void * restrict vx1,\n                                      const void * restrict vy0) {\n    assert(n % 32 == 0);  // min sub-block size\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vx1 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n\n    const uint32_t qk = QK_Q4_0x4x2 * 4;\n\n    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t x_qblk_size = qk;                                          // int8\n    const uint32_t x_qrow_size = n;                                           // int8 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                          // int8\n    const uint32_t y_qrow_size = n;                                           // int8 (not padded)\n\n    const 
uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales\n    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first\n    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales\n\n    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first\n    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales\n\n    // Row sum (qf32)\n    HVX_Vector r0_sum = Q6_V_vzero();\n    HVX_Vector r1_sum = Q6_V_vzero();\n\n    // Multiply and accumulate into int32.\n    // Compute combined scale (fp32).\n    // Apply scale to acc and accumulate into the row sum (qf32).\n\n    const uint32_t nb   = n / qk;  // num full blocks\n    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q    + i * y_qblk_size);\n        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);\n        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));\n        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));\n\n        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));\n        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));\n        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n        HVX_Vector r1_fa = 
Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q    + i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));\n        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));\n\n        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));\n        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));\n        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));\n\n        // Zero out unused elements\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);\n        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);\n        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);\n        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));\n    }\n\n    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);\n    hvx_vec_store_u(s0, 8, rsum);\n}\n\nstatic void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict 
s1,\n                                        const void * restrict vx0, const void * restrict vx1,\n                                        const void * restrict vy0, const void * restrict vy1) {\n    assert(n % 32 == 0);\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vx1 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n    assert((unsigned long) vy1 % 128 == 0);\n\n    const uint32_t qk = QK_Q8_0x4x2 * 4;\n\n    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t x_qblk_size = qk;                                          // int8\n    const uint32_t x_qrow_size = n;                                           // int8 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                          // int8\n    const uint32_t y_qrow_size = n;                                           // int8 (not padded)\n\n    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales\n    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first\n    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales\n\n    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first\n    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales\n    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first\n    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales\n\n    // Row sums (sf) - 4 accumulators for 2×2 tile\n    HVX_Vector r0_c0_sum = Q6_V_vzero();\n    HVX_Vector r0_c1_sum = Q6_V_vzero();\n    HVX_Vector r1_c0_sum = Q6_V_vzero();\n    HVX_Vector r1_c1_sum = Q6_V_vzero();\n\n    
const uint32_t nb   = n / qk;  // num full blocks\n    const uint32_t nloe = n % qk;  // num leftover elements\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        // Load src1 columns (reused across both src0 rows)\n        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);\n        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);\n\n        // Load src0 rows (reused across both src1 columns)\n        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);\n        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);\n\n        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1\n        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));\n        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));\n        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));\n        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));\n\n        // Load scales\n        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d   + i * y_dblk_size));\n        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d   + i * y_dblk_size));\n        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));\n\n        // Compute combined scales\n        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));\n        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));\n        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));\n        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));\n\n        // Apply scales and accumulate\n        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);\n        HVX_Vector r0_c1_fa = 
Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);\n        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);\n        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);\n\n        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));\n        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));\n        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));\n        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q   + i * y_qblk_size, nloe);\n        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q   + i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q  = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n        HVX_Vector_x8 r1_q  = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);\n\n        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));\n        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));\n        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));\n        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));\n\n        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d   + i * y_dblk_size));\n        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d   + i * y_dblk_size));\n        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));\n        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));\n\n        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));\n        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));\n        HVX_Vector r1_c0_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));\n        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));\n\n        // Zero out unused elements\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);\n        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);\n        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);\n        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);\n        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);\n        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);\n        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);\n        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);\n\n        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);\n        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);\n        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);\n        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);\n\n        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));\n        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));\n        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));\n        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));\n    }\n\n    // Reduce and store results\n    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);\n    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);\n\n    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0\n    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1\n}\n\nstatic void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {\n    assert(n % 32 == 0);  // min sub-block size\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n\n    const uint32_t qk = QK_MXFP4x4x2 * 4;\n\n    const uint32_t 
x_dblk_size = 8 * 4 * 1;                                  // 32x e8m0\n    const uint32_t x_qblk_size = qk / 2;                                     // fp4\n    const uint32_t x_qrow_size = n / 2;                                      // fp4 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                         // int8\n    const uint32_t y_qrow_size = n;                                          // int8 (not padded)\n\n    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);           // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales\n\n    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);              // quants first\n    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);    // then scales\n\n    // Row sum (sf)\n    HVX_Vector r0_sum = Q6_V_vzero();\n\n    // Multiply and accumulate into int32.\n    // Compute combined scale (fp32).\n    // Apply scale to acc and accumulate into the row sum (qf32).\n\n    const uint32_t nb   = n / qk;  // num full blocks\n    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(   y_q    + i * y_qblk_size);\n        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));\n\n        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);\n        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);\n\n        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving\n        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16\n        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));\n        vy_d          
  = Q6_Vsf_equals_Vqf32(vy_d);\n\n        // Convert rX_d scales from e8m0 to fp32\n        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...\n        // Left shift with zero fill to create FP32\n        // FIXME: might need to handle zero as a special case (see ggml-cpu code)\n        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;\n        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);\n        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);\n        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);\n        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(   y_q    + i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));\n\n        HVX_Vector vy_d = *(const HVX_UVector *) (y_d    + i * y_dblk_size);\n        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);\n\n        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving\n        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16\n        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));\n        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);\n\n        // Convert rX_d scales from e8m0 to fp32\n        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...\n        // Left shift with zero fill to create FP32\n        // FIXME: might need to handle zero as a special case (see ggml-cpu code)\n        HVX_Vector expand    = *(const HVX_Vector 
*) expand_x32_e8m0;\n        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);\n        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);\n        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);\n        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));\n\n        // Zero-out unused scales\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);\n        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n    }\n\n    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);\n\n    hvx_vec_store_u(s0, 4, r0_sum);\n}\n\nstatic void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,\n                                      const void * restrict vx0, const void * restrict vx1,\n                                      const void * restrict vy0) {\n    assert(n % 32 == 0);  // min sub-block size\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vx1 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n\n    const uint32_t qk = QK_MXFP4x4x2 * 4;\n\n    const uint32_t x_dblk_size = 8 * 4 * 1;                                   // 32x e8m0\n    const uint32_t x_qblk_size = qk / 2;                                      // fp4\n    const uint32_t x_qrow_size = n / 2;                                       // fp4 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                          // int8\n    const uint32_t y_qrow_size = n;                                           // int8 (not padded)\n\n    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + 
x_qrow_size;  // then scales\n    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first\n    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales\n\n    const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0;               // quants first\n    const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size;     // then scales\n\n    // Row sum (sf)\n    HVX_Vector r0_sum = Q6_V_vzero();\n    HVX_Vector r1_sum = Q6_V_vzero();\n\n    // Multiply and accumulate into int32.\n    // Compute combined scale (fp32).\n    // Apply scale to acc and accumulate into the row sum (f32).\n\n    const uint32_t nb   = n / qk;  // num full blocks\n    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(   y_q    + i * y_qblk_size);\n        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);\n        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));\n        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));\n\n        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);\n        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);\n        HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);\n\n        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving\n        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16\n        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));\n        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);\n\n        // Convert rX_d scales from e8m0 to fp32\n        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...\n        // Left shift with zero fill 
to create FP32\n        // FIXME: might need to handle zero as a special case (see ggml-cpu code)\n        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;\n        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);\n        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);\n        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);\n        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);\n        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);\n        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);\n        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));\n        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(   y_q    + i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);\n\n        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));\n        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));\n\n        HVX_Vector vy_d = *(const HVX_UVector *) (y_d    + i * y_dblk_size);\n        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);\n        HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);\n\n        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving\n        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16\n        vy_d            = 
Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));\n        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);\n\n        // Convert rX_d scales from e8m0 to fp32\n        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...\n        // Left shift with zero fill to create FP32\n        // FIXME: might need to handle zero as a special case (see ggml-cpu code)\n        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;\n        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);\n        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);\n        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);\n        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);\n        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);\n        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);\n        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);\n\n        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));\n        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));\n\n        // Zero-out unused values\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);\n        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);\n        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);\n        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);\n\n        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);\n        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);\n\n        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));\n        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));\n    }\n\n    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);\n    hvx_vec_store_u(s0, 8, rsum);\n}\n\nstatic void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,\n                                        const void * restrict vx0, const void * restrict vx1,\n      
                                  const void * restrict vy0, const void * restrict vy1) {\n    assert(n % 32 == 0);\n    assert((unsigned long) vx0 % 128 == 0);\n    assert((unsigned long) vx1 % 128 == 0);\n    assert((unsigned long) vy0 % 128 == 0);\n    assert((unsigned long) vy1 % 128 == 0);\n\n    const uint32_t qk = QK_MXFP4x4x2 * 4;\n\n    const uint32_t x_dblk_size = 8 * 4 * 1;                                   // 32x e8m0\n    const uint32_t x_qblk_size = qk / 2;                                      // fp4\n    const uint32_t x_qrow_size = n / 2;                                       // fp4 (not padded)\n\n    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16\n    const uint32_t y_qblk_size = qk;                                          // int8\n    const uint32_t y_qrow_size = n;                                           // int8 (not padded)\n\n    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first\n    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales\n    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first\n    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales\n\n    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first\n    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales\n    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first\n    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales\n\n    // Row sums (sf) - 4 accumulators for 2×2 tile\n    HVX_Vector r0_c0_sum = Q6_V_vzero();\n    HVX_Vector r0_c1_sum = Q6_V_vzero();\n    HVX_Vector r1_c0_sum = Q6_V_vzero();\n    HVX_Vector r1_c1_sum = Q6_V_vzero();\n\n    const uint32_t nb   = n / qk;  // num full blocks\n    const uint32_t nloe = n % qk;  // num leftover 
elements\n\n    uint32_t i = 0;\n    for (; i < nb; i++) {\n        // Load src1 columns (reused across both src0 rows)\n        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);\n        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);\n\n        // Load src0 rows (reused across both src1 columns)\n        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);\n        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);\n\n        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1\n        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));\n        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));\n        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));\n        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));\n\n        // Load scales\n        HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d   + i * y_dblk_size);\n        HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d   + i * y_dblk_size);\n        HVX_Vector r0_d  = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);\n        HVX_Vector r1_d  = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);\n\n        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving\n        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16\n        vy0_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));\n        vy0_d           = Q6_Vsf_equals_Vqf32(vy0_d);\n        vy1_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));\n        vy1_d           = Q6_Vsf_equals_Vqf32(vy1_d);\n\n        // Convert rX_d scales from e8m0 to fp32\n        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...\n        // Left shift with zero fill to create FP32\n        // FIXME: might need to handle zero as a special case (see ggml-cpu 
code)\n        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;\n        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);\n        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);\n        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);\n        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);\n        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);\n        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);\n        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);\n\n        // Compute combined scales\n        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));\n        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));\n        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));\n        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));\n\n        // Apply scales and accumulate\n        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);\n        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);\n        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);\n        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);\n\n        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));\n        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));\n        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));\n        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));\n    }\n\n    // Process leftovers\n    if (nloe) {\n        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(   y0_q   + i * y_qblk_size, nloe);\n        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(   y1_q   + i * y_qblk_size, nloe);\n        HVX_Vector_x8 r0_q  = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);\n        HVX_Vector_x8 r1_q  = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, 
nloe);\n\n        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));\n        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));\n        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));\n        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));\n\n        HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d   + i * y_dblk_size);\n        HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d   + i * y_dblk_size);\n        HVX_Vector r0_d  = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);\n        HVX_Vector r1_d  = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);\n\n        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving\n        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16\n        vy0_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));\n        vy0_d           = Q6_Vsf_equals_Vqf32(vy0_d);\n        vy1_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));\n        vy1_d           = Q6_Vsf_equals_Vqf32(vy1_d);\n\n        // Convert rX_d scales from e8m0 to fp32\n        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...\n        // Left shift with zero fill to create FP32\n        // FIXME: might need to handle zero as a special case (see ggml-cpu code)\n        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;\n        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);\n        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);\n        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);\n        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);\n        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);\n        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);\n        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);\n\n        HVX_Vector r0_c0_dd = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));\n        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));\n        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));\n        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));\n\n        // Zero out unused scales\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);\n        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);\n        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);\n        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);\n        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);\n        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);\n        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);\n        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);\n        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);\n\n        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);\n        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);\n        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);\n        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);\n\n        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));\n        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));\n        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));\n        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));\n    }\n\n    // Reduce and store results\n    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);\n    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);\n\n    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0\n    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1\n}\n\nstatic void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {\n    const HVX_Vector * restrict x = (const HVX_Vector *) vx;\n   
 const HVX_Vector * restrict y = (const HVX_Vector *) vy;\n\n    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors\n    uint32_t nloe = n % VLEN_FP16; // leftover elements\n\n    HVX_VectorPair rsum_p = Q6_W_vzero();\n\n    uint32_t i = 0;\n\n    #pragma unroll(4)\n    for (i = 0; i < nvec; i++) {\n        rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);\n    }\n\n    if (nloe) {\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);\n        HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);\n        HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);\n        rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);\n    }\n\n    HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));\n    hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));\n}\n\nstatic void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,\n                                const void * restrict vx0, const void * restrict vx1,\n                                const void * restrict vy0) {\n    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;\n    const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;\n    const HVX_Vector * restrict y  = (const HVX_Vector *) vy0;\n\n    uint32_t nvec = n / VLEN_FP16;\n    uint32_t nloe = n % VLEN_FP16;\n\n    HVX_VectorPair rsum0_p = Q6_W_vzero();\n    HVX_VectorPair rsum1_p = Q6_W_vzero();\n\n    uint32_t i = 0;\n\n    #pragma unroll(2)\n    for (i = 0; i < nvec; i++) {\n        HVX_Vector y_hf = y[i];\n        rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);\n        rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);\n    }\n\n    if (nloe) {\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);\n        HVX_Vector y_hf  = Q6_V_vand_QV(bmask, y[i]);\n        HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);\n        HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);\n        rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);\n        rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);\n 
   }\n\n    HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));\n    HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));\n    HVX_Vector rsum  = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);\n    hvx_vec_store_u(s0, 8, rsum);\n}\n\nstatic void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,\n                                const void * restrict vx0, const void * restrict vx1,\n                                const void * restrict vy0, const void * restrict vy1) {\n    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;\n    const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;\n    const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;\n    const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;\n\n    uint32_t nvec = n / VLEN_FP16;\n    uint32_t nloe = n % VLEN_FP16;\n\n    // Row sums (sf) - 4 accumulators for 2×2 tile\n    HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();\n    HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();\n    HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();\n    HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();\n\n    uint32_t i = 0;\n\n    #pragma unroll(2)\n    for (i = 0; i < nvec; i++) {\n        HVX_Vector r0_hf = x0[i];\n        HVX_Vector r1_hf = x1[i];\n        HVX_Vector c0_hf = y0[i];\n        HVX_Vector c1_hf = y1[i];\n\n        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1\n        r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);\n        r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);\n        r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);\n        r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);\n    }\n\n    if (nloe) {\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);\n\n        HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);\n        HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);\n        HVX_Vector c0_hf = Q6_V_vand_QV(bmask, 
y0[i]);\n        HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);\n\n        r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);\n        r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);\n        r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);\n        r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);\n    }\n\n    HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));\n    HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));\n    HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));\n    HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));\n\n    // Reduce and store results\n    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);\n    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);\n\n    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0\n    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1\n}\n\nstatic void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {\n    const HVX_UVector * restrict x = (const HVX_UVector *) vx;\n    const HVX_UVector * restrict y = (const HVX_UVector *) vy;\n\n    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors\n    uint32_t nloe = n % VLEN_FP16; // leftover elements\n\n    HVX_Vector rsum = Q6_V_vzero();\n\n    uint32_t i = 0;\n\n    #pragma unroll(4)\n    for (i = 0; i < nvec; i++) {\n        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);\n        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));\n    }\n\n    if (nloe) {\n        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);\n        HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);\n 
       HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);\n\n        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);\n        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));\n    }\n\n    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));\n    hvx_vec_store_u(&s[0], 4, rsum);\n}\n\nstatic void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {\n    const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;\n    const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;\n\n    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors\n    uint32_t nloe = n % VLEN_FP16; // leftover elements\n\n    const HVX_Vector zero = Q6_V_vzero();\n\n    HVX_Vector       rsum = Q6_V_vzero();\n\n    uint32_t i = 0;\n\n    #pragma unroll(2)\n    for (i = 0; i < nvec; i++) {\n        // Load y (fp32) and convert into fp16\n        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements\n        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements\n        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));\n\n        // Load x (fp16)\n        HVX_Vector x_hf  = vx[i];\n\n        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);\n\n        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));\n    }\n\n    if (nloe) {\n        // Load y (fp32) and convert into fp16\n        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements\n        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements\n        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));\n\n        // Load x (fp16)\n        HVX_Vector x_hf  = vx[i];\n\n        // Zero-out unused elements\n        // Note that we need to clear both x and y because they may contain NANs\n        
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);\n        x_hf = Q6_V_vand_QV(bmask, x_hf);\n        y_hf = Q6_V_vand_QV(bmask, y_hf);\n\n        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);\n\n        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));\n    }\n\n    // Convert into fp32 and reduce\n    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));\n    hvx_vec_store_u(&s[0], 4, rsum);\n}\n\n#define htp_matmul_tensors_preamble    \\\n    struct htp_tensor * restrict src0    = &octx->src0;      \\\n    struct htp_tensor * restrict src1    = &octx->src1;      \\\n    struct htp_tensor * restrict src2    = &octx->src2;      \\\n    struct htp_tensor * restrict dst     = &octx->dst;       \\\n    struct htp_spad * restrict src0_spad = &octx->src0_spad; \\\n    struct htp_spad * restrict src1_spad = &octx->src1_spad; \\\n    struct htp_spad * restrict dst_spad  = &octx->dst_spad;  \\\n                                                             \\\n    const uint32_t ne00 = src0->ne[0]; \\\n    const uint32_t ne01 = src0->ne[1]; \\\n    const uint32_t ne02 = src0->ne[2]; \\\n    const uint32_t ne03 = src0->ne[3]; \\\n                                       \\\n    const uint32_t ne10 = src1->ne[0]; \\\n    const uint32_t ne11 = src1->ne[1]; \\\n    const uint32_t ne12 = src1->ne[2]; \\\n    const uint32_t ne13 = src1->ne[3]; \\\n                                       \\\n    const uint32_t ne20 = src2->ne[0]; \\\n    const uint32_t ne21 = src2->ne[1]; \\\n    const uint32_t ne22 = src2->ne[2]; \\\n    const uint32_t ne23 = src2->ne[3]; \\\n                                       \\\n    const uint32_t ne0 = dst->ne[0];   \\\n    const uint32_t ne1 = dst->ne[1];   \\\n    const uint32_t ne2 = dst->ne[2];   \\\n    const uint32_t ne3 = dst->ne[3];   \\\n                                       \\\n    const uint32_t nb00 = src0->nb[0]; \\\n    const uint32_t nb01 = src0->nb[1]; \\\n    const uint32_t nb02 
= src0->nb[2]; \\\n    const uint32_t nb03 = src0->nb[3]; \\\n                                       \\\n    const uint32_t nb10 = src1->nb[0]; \\\n    const uint32_t nb11 = src1->nb[1]; \\\n    const uint32_t nb12 = src1->nb[2]; \\\n    const uint32_t nb13 = src1->nb[3]; \\\n                                       \\\n    const uint32_t nb0 = dst->nb[0];   \\\n    const uint32_t nb1 = dst->nb[1];   \\\n    const uint32_t nb2 = dst->nb[2];   \\\n    const uint32_t nb3 = dst->nb[3];\n\n#define htp_matmul_preamble                                     \\\n    struct htp_matmul_context * mmctx = data;                   \\\n    struct htp_ops_context * octx  = mmctx->octx;               \\\n    htp_matmul_tensors_preamble;                                \\\n    dma_queue *dma_queue           = octx->ctx->dma[ith];       \\\n    uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;\n\n// *** matmul with support for 4d tensors and full broadcasting\n\nstatic void matmul_4d(unsigned int nth, unsigned int ith, void * data) {\n    htp_matmul_preamble;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    assert(ne12 % ne02 == 0);\n    assert(ne13 % ne03 == 0);\n\n    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)\n    const uint32_t nr0 = ne0;\n\n    // This is the size of the rest of the dimensions of the result\n    const uint32_t nr1 = ne1 * ne2 * ne3;\n\n    // distribute the thread work across the inner or outer loop based on which one is larger\n    uint32_t nchunk0 = nr0 > nr1 ? nth : 1;  // parallelize by src0 rows\n    uint32_t nchunk1 = nr0 > nr1 ? 
1 : nth;  // parallelize by src1 rows\n\n    // The number of elements in each chunk\n    const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;\n    const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;\n\n    uint32_t current_chunk = ith;\n\n    const uint32_t ith0 = current_chunk % nchunk0;\n    const uint32_t ith1 = current_chunk / nchunk0;\n\n    const uint32_t ir0_start = dr0 * ith0;\n    const uint32_t ir0_end   = MIN(ir0_start + dr0, nr0);\n\n    const uint32_t ir1_start = dr1 * ith1;\n    const uint32_t ir1_end   = MIN(ir1_start + dr1, nr1);\n\n    // no work for this thread\n    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {\n        return;\n    }\n\n    // block-tiling attempt\n    const uint32_t blck_0 = 64;\n    const uint32_t blck_1 = 64;\n\n    for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {\n        for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {\n            for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {\n                const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);\n                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);\n                const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);\n\n                // broadcast src0 into src1\n                const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);\n                const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);\n\n                const uint32_t i1 = i11;\n                const uint32_t i2 = i12;\n                const uint32_t i3 = i13;\n\n                const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);\n                const uint8_t * restrict src1_col  = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);\n                float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));\n\n                const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);\n       
         for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {\n                    const uint8_t * restrict src0_row = src0_base + ir0 * nb01;\n                    mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);\n                }\n            }\n        }\n    }\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\", ith, nth,\n         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],\n         src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n// src1 tensor is already in VTCM spad\nstatic void matmul_2d(unsigned int nth, unsigned int ith, void * data) {\n    htp_matmul_preamble;\n\n    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows\n    const uint32_t src1_nrows = ne11 * ne12 * ne13;  // src1 rows\n\n    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    const size_t dst_row_size  = nb1;\n    const size_t src0_row_size = nb01;\n    const size_t src1_row_size = nb11;\n\n    const size_t src0_stride = src0_spad->stride;\n    const size_t src1_stride = src1_spad->stride;\n\n    // Per-thread VTCM scratchpads for all tensors\n    // Note that the entire src1 tensor is already in VTCM\n    // For other tensors we allocate N rows per thread, padded to HVX vector size\n    uint8_t * restrict spad_dst  = dst_spad->data  + dst_spad->size_per_thread  * ith;\n    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;\n    uint8_t * restrict src1_data = 
src1_spad->data;\n\n    volatile uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const uint8_t * restrict src0_row = (const uint8_t *) src0->data;\n\n    // Prefill spad with src0 rows\n    #pragma unroll(4)\n    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {\n        const int is0 = (ir0 - src0_start_row);\n        if (is0 >= MM_SPAD_SRC0_NROWS) {\n            break;\n        }\n        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),\n                       src0_stride, src0_row_size, 2);\n    }\n\n    // Process src0 rows\n    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {\n        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;\n\n        // Process src1 columns in pairs (2×2 tiling)\n        uint32_t ir1 = 0;\n        for (; ir1 + 1 < src1_nrows; ir1 += 2) {\n            const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);\n            const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);\n            float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));\n            float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));\n            mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);\n        }\n\n        // Handle remaining src1 rows (fallback to 2×1)\n        for (; ir1 < src1_nrows; ++ir1) {\n            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);\n            float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));\n            mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);\n        }\n\n        // Prefetch next (n + spad_nrows) row\n        const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);\n        const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;\n    
    if (pr0 < src0_end_row_x2) {\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),\n                           src0_stride, src0_row_size, 2);\n        }\n    }\n\n    // Process the last row (if any)\n    if (src0_end_row != src0_end_row_x2) {\n        uint32_t  ir0 = src0_end_row_x2;\n        const int is0 = (ir0 - src0_start_row);\n        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),\n                       src0_stride, src0_row_size, 1);\n        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;\n\n        #pragma unroll(2)\n        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {\n            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);\n            float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));\n            mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);\n        }\n    }\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\", mmctx->type, ith, nth,\n         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],\n         src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n// q8x4x2 src1 tensor is already in VTCM spad\nstatic void matvec_2d(unsigned int nth, unsigned int ith, void * data) {\n    htp_matmul_preamble;\n\n    const uint32_t src0_nrows = ne01;\n\n    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n 
   }\n\n    const size_t dst_row_size  = nb1;\n    const size_t src0_row_size = nb01;\n    const size_t src1_row_size = nb11;\n\n    const size_t src0_stride = src0_spad->stride;\n    const size_t src1_stride = src1_spad->stride;\n\n    // Per-thread VTCM scratchpads for all tensors\n    // Note that the entire src1 tensor is already in VTCM\n    // For other tensors we allocate N rows per thread, padded to HVX vector size\n    uint8_t * spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;\n    uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;\n    uint8_t * src1_data = src1_spad->data;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    float * tmp = (float *) spad_dst;\n\n    const uint8_t * restrict src0_row = (const uint8_t *) src0->data;\n    const uint8_t * restrict src1_col = (const uint8_t *) src1_data;\n    float * restrict dst_col          = (float *) dst->data;\n\n    // Prefill spad with 2x src0 rows\n    #pragma unroll(2)\n    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {\n        const uint32_t is0 = (ir0 - src0_start_row);\n        if (is0 >= MM_SPAD_SRC0_NROWS) {\n            break;\n        }\n        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),\n                       src0_stride, src0_row_size, 2);\n    }\n\n    // Process src0 rows\n    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {\n        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;\n        mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);\n\n        // Prefetch next (n + spad_nrows) row\n        const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);\n        const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;\n        if (pr0 < src0_end_row_x2) {\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * 
src0_row_size),\n                           src0_stride, src0_row_size, 2);\n        }\n    }\n\n    // Process the last row (if any)\n    if (src0_end_row != src0_end_row_x2) {\n        const uint32_t ir0 = src0_end_row_x2;\n        const uint32_t is0 = (ir0 - src0_start_row);\n        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),\n                       src0_stride, src0_row_size, 1);\n        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;\n        mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);\n    }\n\n    hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\", mmctx->type, ith, nth,\n         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],\n         src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)]\n\nstruct mmid_row_mapping {\n    uint32_t i1;\n    uint32_t i2;\n};\n\n// src1 tensor is already in VTCM spad\nstatic void matmul_id(unsigned int nth, unsigned int ith, void * data) {\n    htp_matmul_preamble;\n\n    struct htp_tensor * restrict     ids = &octx->src2;\n    struct htp_spad * restrict src2_spad = &octx->src2_spad;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const uint32_t src0_nrows = ne01;  // src0 rows per expert\n    const uint32_t src1_nrows = ne11;\n\n    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & 
~1U);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    const uint32_t n_ids = ids->ne[0];  // n_expert_used\n    const uint32_t n_as  = ne02;        // n_expert\n\n    const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);\n    const size_t matrix_row_map_size    = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);\n\n    const uint32_t *                matrix_row_counts = (const uint32_t *) src2_spad->data + 0;\n    const struct mmid_row_mapping * matrix_rows       = (const void *) src2_spad->data + matrix_row_counts_size;\n\n    const size_t dst_row_size  = nb1;\n    const size_t src0_row_size = nb01;\n    const size_t src1_row_size = q8x4x2_row_size(ne10);\n\n    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);\n\n    // Per-thread VTCM scratchpads for all tensors\n    // Note that the entire src1 tensor is already in VTCM\n    // For other tensors we allocate N rows per thread, padded to HVX vector size\n    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;\n    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;\n    uint8_t * restrict src1_data = src1_spad->data;\n\n    for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {\n        const int32_t cne1 = matrix_row_counts[cur_a];\n\n        if (cne1 == 0) {\n            continue;\n        }\n\n        const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);\n\n        // Prefill spad with src0 rows\n        #pragma unroll(4)\n        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {\n            const int is0 = (ir0 - src0_start_row);\n            if (is0 >= MM_SPAD_SRC0_NROWS) {\n                break;\n            }\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),\n                           src0_row_size_padded, 
src0_row_size, 2);\n        }\n\n        // Process src0 rows\n        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {\n            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;\n\n            for (uint32_t cid = 0; cid < cne1; ++cid) {\n                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);\n                const int               rm1         = row_mapping.i1;  // expert idx\n                const int               rm2         = row_mapping.i2;  // token idx\n\n                const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx\n                const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);\n                float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));\n\n                mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);\n            }\n\n            // Prefetch next (n + spad_nrows) row\n            const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);\n            const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;\n            if (pr0 < src0_end_row_x2) {\n                dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),\n                               src0_row_size_padded, src0_row_size, 2);\n            }\n        }\n\n        // Process the last row (if any)\n        if (src0_end_row != src0_end_row_x2) {\n            uint32_t       ir0 = src0_end_row_x2;\n            const uint32_t is0 = (ir0 - src0_start_row);\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),\n                           src0_row_size_padded, src0_row_size, 1);\n            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;\n\n            for (uint32_t cid = 0; cid < cne1; ++cid) {\n                struct mmid_row_mapping row_mapping = 
MMID_MATRIX_ROW(cur_a, cid);\n                const int               rm1         = row_mapping.i1;  // expert idx\n                const int               rm2         = row_mapping.i2;  // token idx\n\n                const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx\n                const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);\n                float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));\n\n                mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);\n            }\n        }\n    }\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\\n\", mmctx->type,\n         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],\n         src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],\n         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n// src1 tensor is already in VTCM spad\nstatic void matvec_id(unsigned int nth, unsigned int ith, void * data) {\n    htp_matmul_preamble;\n\n    struct htp_tensor * restrict     ids = &octx->src2;\n    struct htp_spad * restrict src2_spad = &octx->src2_spad;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const uint32_t src0_nrows = ne01;  // src0 rows per expert\n\n    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    assert(ne13 % ne03 == 0);\n\n    const size_t dst_row_size  = nb1;\n    const size_t src0_row_size = nb01;\n    const size_t 
src1_row_size = q8x4x2_row_size(ne10);\n\n    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);\n\n    const uint32_t n_aids = src2->ne[0];  // num activated experts\n    const uint32_t n_ids  = ne02;         // num experts\n\n    // Per-thread VTCM scratchpads for all tensors\n    // Note that the entire src1 tensor is already in VTCM\n    // For other tensors we allocate N rows per thread, padded to HVX vector size\n    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;\n    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;\n    uint8_t * restrict src1_data = src1_spad->data;\n\n    for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) {  // for each expert\n        const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]);\n        assert(eid < n_ids);\n\n        const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02;\n        const uint8_t * restrict src1_col = (const uint8_t *) src1_data;\n        float * restrict dst_row          = (float *) (dst->data + ie1 * nb1);\n\n        // Prefill spad with src0 rows\n        #pragma unroll(4)\n        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {\n            const int is0 = (ir0 - src0_start_row);\n            if (is0 >= MM_SPAD_SRC0_NROWS) {\n                break;\n            }\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),\n                           src0_row_size_padded, src0_row_size, 2);\n        }\n\n        // Process src0 rows\n        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {\n            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;\n            mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);\n\n            // Prefetch next (n + spad_nrows) row\n            const int pr0 = (ir0 + 
MM_SPAD_SRC0_NROWS);\n            const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;\n            if (pr0 < src0_end_row_x2) {\n                dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),\n                               src0_row_size_padded, src0_row_size, 2);\n            }\n        }\n\n        // Process the last row (if any)\n        if (src0_end_row != src0_end_row_x2) {\n            uint32_t       ir0 = src0_end_row_x2;\n            const uint32_t is0 = (ir0 - src0_start_row);\n            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),\n                           src0_row_size_padded, src0_row_size, 1);\n            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;\n            mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);\n        }\n    }\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\\n\", mmctx->type,\n         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],\n         src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],\n         dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n// *** dynamic quant\n\nstatic inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {\n    assert((unsigned long) x % 128 == 0);\n    assert((unsigned long) y_q % 128 == 0);\n\n    HVX_Vector * vx = (HVX_Vector *) x;\n    HVX_Vector zero   = Q6_V_vzero();\n\n    // Use reduce max fp32 to find max(abs(e)) first\n    HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));\n    HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));\n    HVX_Vector vmax2_sf = 
hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));\n    HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));\n    // Load and convert into QF32\n    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements\n    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements\n    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements\n    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements\n\n    // Convert to QF32\n    HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes\n    HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes\n    HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes\n    HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes\n\n    // Combine and convert to fp16\n    HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));\n    HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));\n\n    // Convert into fp16\n    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));\n    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));\n\n    HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0\n    HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0\n    HVX_Vector vd01_hf   = Q6_Vhf_equals_Vqf16(vd01_qf16);\n    HVX_Vector vd23_hf   = Q6_Vhf_equals_Vqf16(vd23_qf16);\n\n    hvx_vec_store_u(y_d + 0, 2, vd01_hf);\n    HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);\n    hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);\n\n    hvx_vec_store_u(y_d + 4, 2, vd23_hf);\n    rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);\n    hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);\n\n    // Divide input by the scale\n    
HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);\n    HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);\n    vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));\n    vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));\n\n    // Convert to int8\n    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);\n    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);\n    HVX_Vector vx_i8    = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);\n\n    *(HVX_Vector *) y_q = vx_i8;\n}\n\nstatic inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {\n    assert((unsigned long) x % 128 == 0);\n    assert((unsigned long) y_q % 128 == 0);\n\n    HVX_Vector * vx = (HVX_Vector *) x;\n\n    // Load and convert into QF32\n    HVX_Vector zero   = Q6_V_vzero();\n    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements\n    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements\n    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements\n    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements\n\n    // Convert into fp16\n    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));\n    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));\n\n    // Compute max and scale\n    HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes\n    HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes\n\n    HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0\n    HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0\n    HVX_Vector vd01_hf   = Q6_Vhf_equals_Vqf16(vd01_qf16);\n    HVX_Vector vd23_hf   = Q6_Vhf_equals_Vqf16(vd23_qf16);\n\n    
hvx_vec_store_u(y_d + 0, 4, vd01_hf);\n    hvx_vec_store_u(y_d + 4, 4, vd23_hf);\n\n    // Divide input by the scale\n    HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);\n    HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);\n    vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));\n    vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));\n\n    // Convert to int8\n    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);\n    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);\n    HVX_Vector vx_i8    = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);\n\n    *(HVX_Vector *) y_q = vx_i8;\n}\n\nstatic inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {\n    assert((unsigned long) x % 128 == 0);\n    assert((unsigned long) y_q % 128 == 0);\n\n    HVX_Vector * vx = (HVX_Vector *) x;\n\n    // Load and convert into QF32\n    HVX_Vector zero   = Q6_V_vzero();\n    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements\n    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements\n    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements\n    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements\n\n    // Convert into fp16\n    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));\n    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));\n\n    // Compute max and scale\n    HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));\n    vmax_hf            = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes\n\n    HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0\n    HVX_Vector vd_hf   = Q6_Vhf_equals_Vqf16(vd_qf16);\n\n    *(HVX_UVector *) y_d = vd_hf;\n\n    // Divide input by the scale\n    
HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf);\n    vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));\n    vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));\n\n    // Convert to int8\n    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);\n    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);\n    HVX_Vector vx_i8    = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);\n\n    *(HVX_Vector *) y_q = vx_i8;\n}\n\n// Overrides input x\nstatic void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {\n    assert(k % 32 == 0);\n    const uint32_t qk = QK_Q8_0x4x2;\n    const uint32_t nb = (k + qk - 1) / qk;\n\n    const uint32_t qrow_size = k;              // int8\n\n    const uint32_t dblk_size = 8 * 2;          // 8x __fp16\n    const uint32_t qblk_size = QK_Q8_0x4x2;    // int8\n\n    uint8_t * restrict y_q = (y + 0);          // quants first\n    uint8_t * restrict y_d = (y + qrow_size);  // then scales\n\n    // Temp scales override input since we're working off of the aligned temp buffer in VTCM\n    uint8_t * restrict t_d = (uint8_t *) x;\n\n    for (uint32_t i = 0; i < nb; i++) {\n#if FP32_QUANTIZE_GROUP_SIZE == 32\n        quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);\n        quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);\n#elif FP32_QUANTIZE_GROUP_SIZE == 64\n        quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);\n        quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);\n#elif FP32_QUANTIZE_GROUP_SIZE == 128\n        quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);\n        quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * 
qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);\n#else\n#error \"FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128\"\n#endif\n    }\n\n    // now copy the scales into final location\n    hvx_copy_f16_ua(y_d, t_d, nb * 8);\n}\n\nstatic void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_matmul_context * mmctx = data;\n    struct htp_ops_context * octx = mmctx->octx;\n\n    const struct htp_tensor * src = &octx->src1;\n    uint8_t * restrict dst = octx->src1_spad.data;\n    struct htp_spad * spad = &octx->src0_spad;\n    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;\n\n    uint64_t t1 = HAP_perf_get_qtimer_count();\n\n    const uint32_t ne0 = src->ne[0];\n    const uint32_t ne1 = src->ne[1];\n    const uint32_t ne2 = src->ne[2];\n    const uint32_t ne3 = src->ne[3];\n\n    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows\n\n    const uint32_t ir_first = nrows_per_thread * ith;                   // first row\n    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row\n\n    const size_t src_row_size = src->nb[1];\n    const size_t dst_row_size = q8x4x2_row_size(ne0);\n\n    uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);\n    uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);\n    uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);\n\n    const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));\n    memset(tmp_data, 0, src_row_size_padded);  // zero-out temp row data for padding\n\n    for (uint32_t i = ir_first; i < ir_last; ++i) {\n        hex_l2fetch(src_data, src_row_size, src_row_size, 2);\n        hvx_copy_f32_aa(tmp_data, src_data, ne0);\n\n        // FARF(HIGH, \"quantize-q8x4-row: %u\\n\", i);\n        quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0);\n        dst_data += dst_row_size;\n        src_data += 
src_row_size;\n    }\n\n    uint64_t t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\\n\", ith, nth, nrows, ir_first,\n         ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\nstatic void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_matmul_context * mmctx = data;\n    struct htp_ops_context * octx = mmctx->octx;\n\n    const struct htp_tensor * src = &octx->src1;\n    uint8_t * restrict dst = octx->src1_spad.data;\n    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;\n    uint32_t dst_stride = octx->src1_spad.stride;\n\n    uint64_t t1 = HAP_perf_get_qtimer_count();\n\n    const uint32_t ne0 = src->ne[0];\n    const uint32_t ne1 = src->ne[1];\n    const uint32_t ne2 = src->ne[2];\n    const uint32_t ne3 = src->ne[3];\n\n    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows\n\n    const uint32_t ir_first = nrows_per_thread * ith;                   // first row\n    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row\n\n    const size_t src_row_size = ne0 * sizeof(float);\n    const size_t src_stride   = src->nb[1];\n\n    uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);\n    uint8_t * restrict dst_data = (uint8_t *) dst       + (dst_stride * ir_first);\n\n    for (uint32_t i = ir_first; i < ir_last; ++i) {\n        hex_l2fetch(src_data, src_row_size, src_stride, 2);\n        hvx_copy_f16_f32_au(dst_data, src_data, ne0);\n\n        dst_data += dst_stride;\n        src_data += src_stride;\n    }\n\n    uint64_t t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\\n\", ith, nth, nrows, ir_first,\n        ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n// TODO just a plain copy that 
should be done via the DMA during the Op setup\nstatic void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_matmul_context * mmctx = data;\n    struct htp_ops_context * octx = mmctx->octx;\n\n    const struct htp_tensor * src = &octx->src1;\n    uint8_t * restrict dst = octx->src1_spad.data;\n    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;\n    uint32_t dst_stride = octx->src1_spad.stride;\n\n    uint64_t t1 = HAP_perf_get_qtimer_count();\n\n    const uint32_t ne0 = src->ne[0];\n    const uint32_t ne1 = src->ne[1];\n    const uint32_t ne2 = src->ne[2];\n    const uint32_t ne3 = src->ne[3];\n\n    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows\n\n    const uint32_t ir_first = nrows_per_thread * ith;                   // first row\n    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row\n\n    const size_t src_row_size = ne0 * sizeof(float);\n    const size_t src_stride   = src->nb[1];\n\n    uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);\n    uint8_t * restrict dst_data = (uint8_t *) dst       + (dst_stride * ir_first);\n\n    for (uint32_t i = ir_first; i < ir_last; ++i) {\n        hex_l2fetch(src_data, src_row_size, src_stride, 2);\n        hvx_copy_f16_au(dst_data, src_data, ne0);\n\n        dst_data += dst_stride;\n        src_data += src_stride;\n    }\n\n    uint64_t t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\\n\", ith, nth, nrows, ir_first,\n        ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n\nstatic inline bool htp_is_permuted(const struct htp_tensor * t) {\n    return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];\n}\n\nstatic int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {\n    switch (type) {\n        case 
HTP_TYPE_Q4_0:\n            mmctx->type        = \"q4x4x2-f32\";\n            mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;\n            mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;\n            mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;\n            return 0;\n        case HTP_TYPE_Q8_0:\n            mmctx->type        = \"q8x4x2-f32\";\n            mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;\n            mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;\n            mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;\n            return 0;\n        case HTP_TYPE_MXFP4:\n            mmctx->type        = \"mxfp4x4x2-f32\";\n            mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;\n            mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;\n            mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;\n            return 0;\n        default:\n            return -1;\n    }\n}\n\nstatic void htp_mminit_spad(struct htp_ops_context * octx,\n                                 size_t dst_row_size,\n                                 size_t src0_row_size_padded,\n                                 size_t src1_row_size,\n                                 uint32_t src1_nrows,\n                                 size_t src2_spad_size_per_thread) {\n    octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);\n    octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);\n    octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);\n\n    if (src2_spad_size_per_thread > 0) {\n        octx->src2_spad.size_per_thread = src2_spad_size_per_thread;\n        octx->src2_spad.size            = octx->src2_spad.size_per_thread;\n    }\n\n    // src0 spad is also used in dynamic quantizer to store padded src1 rows\n    size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));\n    if (octx->src0_spad.size_per_thread < src1_row_size_padded) {\n        
octx->src0_spad.size_per_thread = src1_row_size_padded;\n    }\n\n    octx->src1_spad.size = octx->src1_spad.size_per_thread;\n    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;\n    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;\n}\n\nint op_matmul(struct htp_ops_context * octx) {\n    htp_matmul_tensors_preamble;\n\n    struct htp_matmul_context mmctx_struct = {0};\n    struct htp_matmul_context * mmctx = &mmctx_struct;\n    mmctx->octx = octx;\n\n    const uint32_t src0_nrows = ne01 * ne02 * ne03;\n    const uint32_t src1_nrows = ne11 * ne12 * ne13;\n\n    // Compute src0_nrows_per_thread\n    mmctx->src0_nrows_per_thread  = (src0_nrows + octx->n_threads - 1) / octx->n_threads;\n    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even\n\n    const size_t src0_row_size = nb01;\n    const size_t dst_row_size  = nb1;\n    size_t       src1_row_size = nb11;\n\n    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);\n    size_t       src1_row_size_padded;\n\n    worker_callback_t quant_job_func;\n    worker_callback_t matmul_job_func = src1_nrows > 1 ? 
matmul_2d : matvec_2d;\n\n    bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);\n\n    if (src0->type == HTP_TYPE_F16) {\n        // Try optimized f16-f16 path first (src1 in VTCM)\n        const size_t f16_src1_row_size  = hex_round_up(ne10 * 2, 128);\n        const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);\n        const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;\n        const size_t f16_dst_spad_size  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;\n\n        const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;\n\n        // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).\n        // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.\n        const bool is_batched  = (ne02 > 1) || (ne03 > 1);\n        const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);\n\n        if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {\n            // Optimized path\n            quant_job_func     = (src1->type == HTP_TYPE_F32) ? 
quantize_f32_f16 : quantize_f16_f16;\n            mmctx->type        = \"f16-f16\";\n            mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;\n            mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;\n            mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;\n\n            src1_row_size = f16_src1_row_size;  // row size post quantization\n\n            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);\n            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);\n            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);\n\n            octx->src1_spad.size = octx->src1_spad.size_per_thread;\n            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;\n            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;\n        } else {\n            // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required\n            quant_job_func = NULL;\n            if (src1->type == HTP_TYPE_F32) {\n                mmctx->type        = \"f16-f32\";\n                mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;\n                matmul_job_func    = matmul_4d;\n            } else {\n                mmctx->type        = \"f16-f16\";\n                mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;\n                matmul_job_func    = matmul_4d;\n            }\n\n            src1_row_size = nb11;  // original row size in DDR\n\n            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);\n            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);\n            octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);\n\n            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;\n            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;\n   
         octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;\n\n            // Init fastdiv for matmul_4d (supports broadcasting)\n            mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);\n            mmctx->mm_div_ne1      = init_fastdiv_values(dst->ne[1]);\n            mmctx->mm_div_r2       = init_fastdiv_values(src1->ne[2] / src0->ne[2]);\n            mmctx->mm_div_r3       = init_fastdiv_values(src1->ne[3] / src0->ne[3]);\n\n            need_quant = false;\n        }\n    } else {\n        if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {\n            return HTP_STATUS_NO_SUPPORT;\n        }\n\n        quant_job_func = quantize_f32_q8x4x2;\n        src1_row_size  = q8x4x2_row_size(ne10);\n        htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);\n    }\n\n    // VTCM scratchpads for all tensors\n    size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;\n\n    FARF(HIGH, \"matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\\n\", mmctx->type,\n         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);\n\n    FARF(HIGH, \"matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\\n\", mmctx->type, src0->ne[0],\n         src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],\n         dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);\n\n    // Make sure the reserved vtcm size is sufficient\n    if (octx->ctx->vtcm_size < spad_size) {\n        FARF(ERROR, \"matmul-%s : current VTCM reservation %zu is too small, needed %zu\\n\", mmctx->type,\n             octx->ctx->vtcm_size, spad_size);\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;\n    octx->dst_spad.data  = octx->src1_spad.data + 
octx->src1_spad.size;\n\n    octx->src0_spad.stride = src0_row_size_padded;\n    octx->src1_spad.stride = src1_row_size;\n\n    if (need_quant) {\n        const uint32_t n_quant_jobs  = MIN(src1_nrows, octx->n_threads);\n        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;\n        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);\n    }\n\n    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        const uint32_t n_matmul_jobs = octx->n_threads;\n        worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);\n    }\n\n    return HTP_STATUS_OK;\n}\n\nint op_matmul_id(struct htp_ops_context * octx) {\n    htp_matmul_tensors_preamble;\n\n    struct htp_matmul_context mmctx_struct = {0};\n    struct htp_matmul_context * mmctx = &mmctx_struct;\n    mmctx->octx = octx;\n\n    struct htp_tensor * restrict ids = &octx->src2;\n\n    const size_t src0_row_size = nb01;\n    const size_t dst_row_size  = nb1;\n\n    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);\n\n    const uint32_t src0_nrows = ne01;  // per expert\n    const uint32_t src1_nrows = ne11 * ne12 * ne13;\n\n    worker_callback_t quant_job_func;\n    worker_callback_t matmul_id_job_func = src1_nrows > 1 ? 
matmul_id : matvec_id;\n\n    // Compute src0_nrows_per_thread\n    mmctx->src0_nrows_per_thread  = (src0_nrows + octx->n_threads - 1) / octx->n_threads;\n    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even\n\n    size_t src1_row_size;\n    size_t src1_row_size_padded;\n\n    // row groups\n    const int n_ids = ids->ne[0];  // n_expert_used\n    const int n_as  = ne02;        // n_expert\n\n    size_t matrix_row_counts_size = n_as * sizeof(uint32_t);\n    size_t matrix_row_map_size    = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);\n\n    if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    quant_job_func = quantize_f32_q8x4x2;\n    src1_row_size  = q8x4x2_row_size(ne10);\n\n    const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);\n    htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);\n\n    size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;\n\n    FARF(HIGH, \"matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\\n\", mmctx->type,\n         octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);\n\n    FARF(HIGH, \"matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\\n\", mmctx->type,\n         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],\n         ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,\n         src1->data, dst->data);\n\n    // Make sure the reserved vtcm size is sufficient\n    if (octx->ctx->vtcm_size < spad_size) {\n        FARF(ERROR, \"matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\\n\", mmctx->type, 
octx->ctx->vtcm_size, spad_size);\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;\n    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;\n    octx->dst_spad.data  = octx->src2_spad.data + octx->src2_spad.size;\n\n    octx->src0_spad.stride = src0_row_size_padded;\n    octx->src1_spad.stride = src1_row_size;\n\n    if (src1_nrows > 1) {\n        // initialize matrix_row_counts and map\n        uint32_t *                matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;\n        struct mmid_row_mapping * matrix_rows       = (void *) octx->src2_spad.data + matrix_row_counts_size;\n\n        memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));\n\n        // group rows by src0 matrix\n        for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {  // token idx\n            for (uint32_t id = 0; id < n_ids; ++id) {         // expert idx\n                const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);\n\n                assert(i02 >= 0 && i02 < n_as);\n\n                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };\n                matrix_row_counts[i02] += 1;\n            }\n        }\n    }\n\n    // Setup worker pool callbacks\n    if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {\n        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);\n        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;\n        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);\n    }\n\n    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        const uint32_t n_matmul_jobs = octx->n_threads;\n        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);\n    }\n\n    return HTP_STATUS_OK;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/rope-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n#include \"hex-fastdiv.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h\n#define HTP_ROPE_TYPE_NORMAL 0\n#define HTP_ROPE_TYPE_NEOX   2\n\n#define HTP_ROPE_SPAD_NROWS  16\n#define HTP_ROPE_SPAD_BLOCK  (HTP_ROPE_SPAD_NROWS/2)\n\n#define htp_rope_preamble              \\\n    const uint32_t ne00 = src0->ne[0]; \\\n    const uint32_t ne01 = src0->ne[1]; \\\n    const uint32_t ne02 = src0->ne[2]; \\\n    const uint32_t ne03 = src0->ne[3]; \\\n                                       \\\n    const uint32_t ne0 = dst->ne[0];   \\\n    const uint32_t ne1 = dst->ne[1];   \\\n    const uint32_t ne2 = dst->ne[2];   \\\n    const uint32_t ne3 = dst->ne[3];   \\\n                                       \\\n    const uint32_t nb00 = src0->nb[0]; \\\n    const uint32_t nb01 = src0->nb[1]; \\\n    const uint32_t nb02 = src0->nb[2]; \\\n    const uint32_t nb03 = src0->nb[3]; \\\n                                       \\\n    const uint32_t nb0 = dst->nb[0];   \\\n    const uint32_t nb1 = dst->nb[1];   \\\n    const uint32_t nb2 = dst->nb[2];   \\\n    const uint32_t nb3 = dst->nb[3];\n\nstruct htp_rope_context {\n    int32_t n_dims;\n    int32_t mode;\n    int32_t n_ctx_orig;\n    int32_t sections[4];\n\n    float freq_base;\n    float freq_scale;\n    float ext_factor;\n    float attn_factor;\n    float beta_fast;\n    float beta_slow;\n    float theta_scale;\n    float corr_dims[2];\n\n    uint32_t src0_nrows_per_thread;\n    size_t spad_stride;\n\n    struct htp_ops_context 
* octx;\n\n    size_t src0_row_size;\n    size_t dst_row_size;\n    size_t src0_row_size_aligned;\n    size_t dst_row_size_aligned;\n    size_t theta_cache_offset;\n    uint32_t src0_nrows;\n\n    uint64_t t_start;\n};\n\nstatic float rope_yarn_ramp(const float low, const float high, const int i0) {\n    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);\n\n    return (1 - MIN(1, MAX(0, y)));\n}\n\nstatic void rope_cache_init(const float    theta_base,\n                            const float    freq_scale,\n                            const float *  freq_factors,\n                            float *        corr_dims,\n                            const uint32_t ne0,\n                            const float    ext_factor,\n                            const float    mscale,\n                            float *        cache,\n                            const float    theta_scale) {\n    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py\n    float theta = theta_base;\n\n    for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {\n        const float ff = freq_factors ? 
freq_factors[i0 / 2] : 1.0f;\n\n        float theta_extrap = theta / ff;\n\n        // Get n-d rotational scaling corrected for extrapolation\n        float theta_interp = freq_scale * theta_extrap;\n        float theta_final  = theta_interp;\n        float mscale_final = mscale;\n\n        if (ext_factor != 0.0f) {\n            float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;\n            theta_final    = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;\n\n            // Get n-d magnitude scaling corrected for interpolation\n            mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);\n        }\n\n        cache[i0 + 0] = cosf(theta_final) * mscale_final;\n        cache[i0 + 1] = sinf(theta_final) * mscale_final;\n\n        theta *= theta_scale;\n    }\n}\n\n#define M_PI 3.1415926535897932384626433\n\nstatic void rope_corr_dims(int     n_dims,\n                           int     n_ctx_orig,\n                           float   freq_base,\n                           float   beta_fast,\n                           float   beta_slow,\n                           float * dims) {\n    float start = floorf(n_dims * logf(n_ctx_orig / (beta_fast * 2 * (float) M_PI)) / (2 * logf(freq_base)));\n    float end   = ceilf(n_dims * logf(n_ctx_orig / (beta_slow * 2 * (float) M_PI)) / (2 * logf(freq_base)));\n    dims[0]     = MAX(0, start);\n    dims[1]     = MIN(n_dims - 1, end);\n}\n\nstatic inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {\n    const HVX_Vector * restrict vsrc   = (const HVX_Vector *) src0;\n    const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;\n    HVX_Vector       * restrict vdst   = (HVX_Vector *) dst;\n\n    uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2\n\n    uint32_t he = ne / 2;         // half_dims offset in elements\n    uint32_t hv = he / VLEN_FP32; // half_dims offset in 
vectors\n\n    #pragma unroll(2)\n    for (uint32_t i = 0; i < nvec; i += 2) {\n        HVX_Vector v0 = vsrc[i/2+0];\n        HVX_Vector v1 = vsrc[i/2+hv];\n\n        HVX_Vector v2 = vtheta[i+0];\n        HVX_Vector v3 = vtheta[i+1];\n\n        HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta\n\n        HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));\n        HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));\n        HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin));\n        HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin));\n\n        HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);\n        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);\n\n        vdst[i/2+0]  = Q6_Vsf_equals_Vqf32(v4);\n        vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);\n    }\n\n    for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {\n        const float cos_theta = theta_cache[i+0];\n        const float sin_theta = theta_cache[i+1];\n        float x0 = src0[i/2];\n        float x1 = src0[i/2 + he];\n        dst[i/2]      = x0 * cos_theta - x1 * sin_theta;\n        dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;\n    }\n}\n\nstatic inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {\n    const HVX_Vector * restrict vsrc   = (const HVX_Vector *) src0;\n    const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;\n    HVX_Vector       * restrict vdst   = (HVX_Vector *) dst;\n\n    uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two\n\n    #pragma unroll(2)\n    for (uint32_t i = 0; i < nvec; i+=2) {\n        HVX_Vector v0 = vsrc[i+0];\n        HVX_Vector v1 = vsrc[i+1];\n\n        HVX_Vector v2 = vtheta[i+0];\n        HVX_Vector v3 = vtheta[i+1];\n\n        HVX_VectorPair vx0_x1   = Q6_W_vdeal_VVR(v1, v0, -4);  // vx0_x1[0] = x0, vx0_x1[1] = x1\n 
       HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta\n\n        HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));\n        HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));\n        HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin));\n        HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin));\n\n        HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);\n        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);\n\n        HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);\n\n        vdst[i+0] = Q6_V_lo_W(vstore);\n        vdst[i+1] = Q6_V_hi_W(vstore);\n    }\n\n    for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {\n        const float cos_theta = theta_cache[i+0];\n        const float sin_theta = theta_cache[i+1];\n        float x0 = src0[i+0];\n        float x1 = src0[i+1];\n        dst[i+0] = x0 * cos_theta - x1 * sin_theta;\n        dst[i+1] = x0 * sin_theta + x1 * cos_theta;\n    }\n}\n\nstatic void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,\n                   uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {\n    #pragma unroll(4)\n    for (uint32_t i = 0; i < nr; i++) {\n        float * d = (float *) (dst + i * rctx->dst_row_size_aligned);\n        float * s = (float *) (src + i * rctx->src0_row_size_aligned);\n\n        hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);\n\n        // fill the remain channels with data from src tensor\n        if (rctx->n_dims < ne0) {\n            hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);\n        }\n    }\n}\n\nstatic void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,\n                   uint32_t nr, 
uint32_t ne0, const float * restrict theta_cache) {\n    #pragma unroll(4)\n    for (uint32_t i = 0; i < nr; i++) {\n        float * d = (float *) (dst + i * rctx->dst_row_size_aligned);\n        float * s = (float *) (src + i * rctx->src0_row_size_aligned);\n\n        hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);\n\n        // fill the remain channels with data from src tensor\n        if (rctx->n_dims < ne0) {\n            hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);\n        }\n    }\n}\n\nstatic void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {\n    struct htp_rope_context * rctx = (struct htp_rope_context *) data;\n    struct htp_ops_context * octx = rctx->octx;\n\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * src1 = &octx->src1;\n    const struct htp_tensor * src2 = &octx->src2;\n    struct htp_tensor *       dst  = &octx->dst;\n\n    htp_rope_preamble;\n\n    const uint32_t src0_nrows = rctx->src0_nrows;\n    const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;\n\n    const uint32_t src0_start_row = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    uint64_t tt = HAP_perf_get_qtimer_count();\n\n    const int32_t mode    = rctx->mode;\n    const bool    is_neox = mode & HTP_ROPE_TYPE_NEOX;\n\n    // VTCM setup\n    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);\n    float *   theta_cache    = (float *) (src0_spad_base);\n              src0_spad_base = src0_spad_base + rctx->theta_cache_offset;\n    uint8_t * dst_spad_base  = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);\n\n    dma_queue * dma_queue = octx->ctx->dma[ith];\n    const int32_t * pos = (const int32_t *) src1->data;\n    const float 
* freq_factors = src2->data ? (const float *) src2->data : NULL;\n\n    uint32_t ir = 0;\n    uint32_t prev_i2 = (uint32_t) -1;\n\n    for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch\n        for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len\n            for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads\n                if (ir < src0_start_row) { ir++; i1++; continue; }\n                if (ir >= src0_end_row) goto done;\n\n                // Rows in this block\n                const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);\n\n                // Depth before prefetch\n                uint32_t dma_depth = dma_queue_depth(dma_queue);\n\n                // FARF(HIGH, \"rope-block %u: ir %u n-rows %u dma-depth %u : usec %u\", ith, ir, nrows, dma_depth,\n                //             (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));\n\n                // Prefetch loop\n                for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {\n                    pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);\n\n                    uint32_t pi1 = i1 + pr;\n                    uint32_t pir = ir + pr;\n\n                    // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)\n                    dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);\n\n                    const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;\n                          uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;\n                    dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),\n                        rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);\n\n                    // FARF(HIGH, \"rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u\", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);\n                
}\n\n                // Update theta cache\n                if (i2 != prev_i2) {\n                    prev_i2 = i2;\n\n                    const int32_t p = pos[i2];\n                    rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);\n\n                    // FARF(HIGH, \"rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u\", ith, ir, i1, i2, i3, theta_cache,\n                    //         (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));\n                }\n\n                // Skip DMA transactions from prev block (if any)\n                // No need to wait for these since the DMA is setup for in-order processing\n                for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }\n\n                // Compute loop\n                for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {\n                    // Number of rows to compute\n                    cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);\n\n                    uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;\n                    uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;\n\n                    // FARF(HIGH, \"rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u\", ith, ir, i1, i2, i3, src_spad, cnr,\n                    //         (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));\n\n                    if (is_neox) {\n                        rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);\n                    } else {\n                        rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);\n                    }\n\n                    uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;\n                    dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), 
rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);\n\n                    // Prefetch more rows (if any)\n                    if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {\n                        uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);\n                        uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;\n                        uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;\n\n                        const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;\n                        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),\n                            rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);\n\n                        // FARF(HIGH, \"rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u\", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);\n                    }\n                }\n            }\n        }\n    }\n\ndone:\n    dma_queue_flush(dma_queue);\n    tt = HAP_perf_get_qtimer_count() - tt;\n\n    FARF(HIGH, \"rope-f32: %d/%d: (%u:%u) usec %u\\n\", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));\n}\n\nstatic int execute_op_rope_f32(struct htp_ops_context * octx) {\n    int err = HTP_STATUS_OK;\n\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * src1 = &octx->src1;\n    const struct htp_tensor * src2 = &octx->src2;\n    struct htp_tensor *       dst  = &octx->dst;\n\n    const char * op_type = \"rope-f32\";\n\n    switch (octx->op) {\n        case HTP_OP_ROPE:\n            break;\n\n        default:\n            FARF(ERROR, \"Unsupported Op %u\\n\", octx->op);\n            return HTP_STATUS_NO_SUPPORT;\n    }\n\n    const uint32_t ne0 = dst->ne[0];\n    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];\n    const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);\n\n    const size_t src0_row_size = src0->nb[1];\n    const size_t dst_row_size  = 
dst->nb[1];\n\n    // Aligned row sizes for VTCM\n    const size_t src0_row_size_aligned    = hex_round_up(src0_row_size, VLEN);\n    const size_t dst_row_size_aligned     = hex_round_up(dst_row_size, VLEN);\n    const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);\n\n    // Calculate spad sizes per thread\n    size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;\n    size_t dst_spad_per_thread  = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;\n    size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;\n\n    // Check if we fit in VTCM\n    size_t total_vtcm_needed = spad_per_thread * n_threads;\n    if (octx->ctx->vtcm_size < total_vtcm_needed) {\n        FARF(ERROR, \"%s : current VTCM reservation %zu is too small, needed %zu\\n\", op_type, octx->ctx->vtcm_size, total_vtcm_needed);\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    // Assign sizes\n    octx->src0_spad.size_per_thread = src0_spad_per_thread;\n    octx->dst_spad.size_per_thread  = dst_spad_per_thread;\n    octx->src0_spad.size = n_threads * src0_spad_per_thread;\n    octx->dst_spad.size  = n_threads * dst_spad_per_thread;\n    octx->src1_spad.size = 0;\n\n    // Assign pointers\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->src1_spad.data = NULL;\n    octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;\n\n    // Fill context\n    struct htp_rope_context rctx;\n    memset(&rctx, 0, sizeof(struct htp_rope_context));\n\n    rctx.t_start = HAP_perf_get_qtimer_count();\n\n    rctx.octx = octx;\n\n    const int32_t * op_params = &octx->op_params[0];\n    rctx.n_dims     = ((const int32_t *) op_params)[1];\n    rctx.mode       = ((const int32_t *) op_params)[2];\n    rctx.n_ctx_orig = ((const int32_t *) op_params)[4];\n\n    memcpy(&rctx.freq_base,   (int32_t *) op_params + 5,  sizeof(float));\n    memcpy(&rctx.freq_scale,  (int32_t *) op_params + 6,  sizeof(float));\n    
memcpy(&rctx.ext_factor,  (int32_t *) op_params + 7,  sizeof(float));\n    memcpy(&rctx.attn_factor, (int32_t *) op_params + 8,  sizeof(float));\n    memcpy(&rctx.beta_fast,   (int32_t *) op_params + 9,  sizeof(float));\n    memcpy(&rctx.beta_slow,   (int32_t *) op_params + 10, sizeof(float));\n    memcpy(&rctx.sections,    (int32_t *) op_params + 11, sizeof(int) * 4);\n\n    rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);\n\n    rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);\n\n    rctx.src0_row_size = src0_row_size;\n    rctx.dst_row_size  = dst_row_size;\n    rctx.src0_row_size_aligned = src0_row_size_aligned;\n    rctx.dst_row_size_aligned  = dst_row_size_aligned;\n    rctx.theta_cache_offset    = theta_cache_size_aligned;\n\n    rctx.src0_nrows = src0_nrows;\n    rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;\n\n    FARF(HIGH, \"rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\\n\", rctx.src0_nrows, rctx.n_dims, ne0,\n         rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);\n\n    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads);\n    }\n\n    return err;\n}\n\nint op_rope(struct htp_ops_context * octx) {\n    int err = HTP_STATUS_OK;\n\n    switch (octx->src0.type) {\n        case HTP_TYPE_F32:\n            err = execute_op_rope_f32(octx);\n            break;\n\n        default:\n            err = HTP_STATUS_NO_SUPPORT;\n            break;\n    }\n\n    return err;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/set-rows-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n#define set_rows_preamble \\\n    const uint32_t ne00 = octx->src0.ne[0]; \\\n    const uint32_t ne01 = octx->src0.ne[1]; \\\n    const uint32_t ne02 = octx->src0.ne[2]; \\\n    const uint32_t ne03 = octx->src0.ne[3]; \\\n                                            \\\n    const uint32_t ne10 = octx->src1.ne[0]; \\\n    const uint32_t ne11 = octx->src1.ne[1]; \\\n    const uint32_t ne12 = octx->src1.ne[2]; \\\n                                            \\\n    const uint32_t nb01 = octx->src0.nb[1]; \\\n    const uint32_t nb02 = octx->src0.nb[2]; \\\n    const uint32_t nb03 = octx->src0.nb[3]; \\\n                                            \\\n    const uint32_t nb10 = octx->src1.nb[0]; \\\n    const uint32_t nb11 = octx->src1.nb[1]; \\\n    const uint32_t nb12 = octx->src1.nb[2]; \\\n                                            \\\n    const uint32_t nb1 = octx->dst.nb[1];   \\\n    const uint32_t nb2 = octx->dst.nb[2];   \\\n    const uint32_t nb3 = octx->dst.nb[3];   \\\n                                            \\\n    const uint32_t ne1 = octx->dst.ne[1];   \\\n                                            \\\n    const uint32_t nr  = ne01;\n\nstruct htp_set_rows_context {\n    struct htp_ops_context * octx;\n    struct fastdiv_values div_ne12;\n    struct fastdiv_values div_ne11;\n    uint32_t src0_nrows_per_thread;\n};\n\nstatic void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {\n    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;\n    struct 
htp_ops_context * octx = srctx->octx;\n\n    set_rows_preamble;\n\n    // parallelize by rows of src0\n    const uint32_t dr  = srctx->src0_nrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;\n\n    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);\n\n    for (uint32_t i03 = 0; i03 < ne03; ++i03) {\n        for (uint32_t i02 = 0; i02 < ne02; ++i02) {\n            for (uint32_t i = ir0; i < ir1; ++i) {\n                const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);\n                const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);\n                const uint32_t i10 = i;\n\n                const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;\n\n                uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;\n                if (i1 >= ne1) {\n                    // ignore invalid indices\n                    continue;\n                }\n\n                const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;\n                const uintptr_t dst_ptr  = octx->dst.data  + i1*nb1 + i02*nb2  + i03*nb3;\n\n                // copy row\n                hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);\n            }\n        }\n    }\n}\n\nstatic void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {\n    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;\n    struct htp_ops_context * octx = srctx->octx;\n\n    set_rows_preamble;\n\n    // parallelize by rows of src0\n    const uint32_t dr  = srctx->src0_nrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = (ir0 + dr < nr) ? 
(ir0 + dr) : nr;\n\n    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);\n\n    for (uint32_t i03 = 0; i03 < ne03; ++i03) {\n        for (uint32_t i02 = 0; i02 < ne02; ++i02) {\n            for (uint32_t i = ir0; i < ir1; ++i) {\n                const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);\n                const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);\n                const uint32_t i10 = i;\n\n                const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;\n\n                uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;\n                if (i1 >= ne1) {\n                    // ignore invalid indices\n                    continue;\n                }\n\n                const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;\n                uint8_t*       dst_ptr  = (uint8_t *)       octx->dst.data  + i1*nb1 + i02*nb2  + i03*nb3;\n\n                hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);\n            }\n        }\n    }\n}\n\nint op_set_rows(struct htp_ops_context * octx) {\n    set_rows_preamble;\n\n    const uint32_t n_threads = MIN(nr, octx->n_threads);\n\n    if (octx->src0.type != HTP_TYPE_F32) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {\n        return HTP_STATUS_OK;\n    }\n\n    struct htp_set_rows_context srctx;\n    srctx.octx = octx;\n    srctx.div_ne12 = init_fastdiv_values(ne12);\n    srctx.div_ne11 = init_fastdiv_values(ne11);\n\n    srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;\n\n    switch(octx->dst.type) {\n    case HTP_TYPE_F32:\n        worker_pool_run_func(octx->ctx->worker_pool, 
set_rows_thread_f32_f32, &srctx, n_threads);\n        break;\n    case HTP_TYPE_F16:\n        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads);\n        break;\n    default:\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    return HTP_STATUS_OK;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/softmax-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n#include \"hex-fastdiv.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n#define htp_softmax_preamble3                              \\\n    const uint32_t ne00 = src0->ne[0];                     \\\n    const uint32_t ne01 = src0->ne[1];                     \\\n    const uint32_t ne02 = src0->ne[2];                     \\\n    const uint32_t ne03 = src0->ne[3];                     \\\n                                                           \\\n    const uint32_t nb00 = src0->nb[0];                     \\\n    const uint32_t nb01 = src0->nb[1];                     \\\n    const uint32_t nb02 = src0->nb[2];                     \\\n    const uint32_t nb03 = src0->nb[3];                     \\\n                                                           \\\n    const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \\\n    const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \\\n    const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \\\n    const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \\\n                                                           \\\n    const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \\\n    const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \\\n    const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \\\n    const uint32_t nb13 = (src1->ne[0]) ? 
src1->nb[3] : 1; \\\n                                                           \\\n    const uint32_t ne0 = dst->ne[0];                       \\\n    const uint32_t ne1 = dst->ne[1];                       \\\n    const uint32_t ne2 = dst->ne[2];                       \\\n    const uint32_t ne3 = dst->ne[3];                       \\\n                                                           \\\n    const uint32_t nb0 = dst->nb[0];                       \\\n    const uint32_t nb1 = dst->nb[1];                       \\\n    const uint32_t nb2 = dst->nb[2];                       \\\n    const uint32_t nb3 = dst->nb[3];\n\nstruct htp_softmax_context {\n    bool     use_f16;\n    bool     use_src1;\n    uint32_t n_head;\n    uint32_t n_head_log2;\n\n    float scale;\n    float max_bias;\n    float m0;\n    float m1;\n\n    uint32_t src0_nrows_per_thread;\n    struct fastdiv_values fastdiv_ne01;\n    struct fastdiv_values fastdiv_ne02;\n    struct fastdiv_values fastdiv_ne12; // For mask broadcasting\n    struct fastdiv_values fastdiv_ne13; // For mask broadcasting\n    size_t spad_stride;\n\n    struct htp_ops_context * octx;\n};\n\nstatic void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * src1 = &octx->src1;\n\n    memset(smctx, 0, sizeof(struct htp_softmax_context));\n\n    memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));\n    memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));\n\n    smctx->n_head      = src0->ne[2];\n    smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));\n\n    smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);\n    smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);\n\n    smctx->use_src1 = (src1->ne[0] != 0);\n    smctx->use_f16  = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);\n\n    smctx->octx = octx;\n\n    // Initialize fastdiv 
values\n    const uint32_t ne01 = src0->ne[1];\n    const uint32_t ne02 = src0->ne[2];\n\n    if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);\n    if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);\n\n    const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;\n    const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;\n\n    if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);\n    if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);\n}\n\nstatic void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,\n                                      uint8_t * restrict dst,\n                                      const int num_elems,\n                                      float     scale,\n                                      const uint8_t * restrict mask,\n                                      float slope) {\n    const uint8_t * restrict src_curr  = src;\n    uint8_t * restrict dst_curr        = dst;\n    const uint8_t * restrict mask_curr = mask;\n\n    HVX_Vector scale_vec = hvx_vec_splat_f32(scale);\n    HVX_Vector slope_vec = hvx_vec_splat_f32(slope);\n\n    int step_of_1 = num_elems >> 5;\n\n    #pragma unroll(4)\n    for (int i = 0; i < step_of_1; i++) {\n        HVX_Vector v1 = *(HVX_Vector *) src_curr;\n\n        HVX_Vector v3 = *(HVX_Vector *) mask_curr;\n\n        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);\n\n        HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v3, slope_vec);\n\n        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, v4);\n\n        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v5);\n\n        src_curr += VLEN;\n        dst_curr += VLEN;\n        mask_curr += VLEN;\n    }\n}\n\nstatic void hvx_fast_softmax_f32(const uint8_t * restrict src,\n                                 uint8_t * restrict dst,\n                                 uint8_t * restrict pad,\n                                 const int num_elems) {\n    const HVX_Vector * restrict v_src = (HVX_Vector *) src;\n    HVX_Vector * restrict 
v_pad       = (HVX_Vector *) pad;\n    HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;\n\n    HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);\n    HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]);\n    HVX_Vector zero_v  = Q6_V_vzero();\n    HVX_Vector one_v   = hvx_vec_splat_f32(1.0);\n\n    int step_of_1 = num_elems >> 5;\n\n    #pragma unroll(4)\n    for (int i = 0; i < step_of_1; i++) {\n        HVX_Vector v1 = v_src[i];\n        max_vec       = Q6_Vsf_vmax_VsfVsf(max_vec, v1);\n    }\n\n    max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes\n\n    #pragma unroll(4)\n    for (int i = 0; i < step_of_1; i++) {\n        HVX_Vector v1 = v_src[i];\n        HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);\n\n        HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2));\n\n        sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);\n\n        v_pad[i] = v3;\n    }\n\n    sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes\n\n    HVX_VectorPred pos_sum   = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);\n    HVX_Vector     v4        = hvx_vec_inverse_f32(sum_vec);\n    HVX_Vector     scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);\n\n    #pragma unroll(4)\n    for (int i = 0; i < step_of_1; i++) {\n        HVX_Vector v1 = v_pad[i];\n        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);\n        v_dst[i]      = Q6_Vsf_equals_Vqf32(v2);\n    }\n}\n\nstatic float hvx_softmax_f32(const uint8_t * restrict src,\n                             uint8_t * restrict dst,\n                             uint8_t * restrict spad,\n                             const int   num_elems,\n                             const float max) {\n    hvx_sub_scalar_f32(spad, src, max, num_elems);\n\n    hvx_exp_f32(spad, dst, num_elems, false);\n\n    float sum = hvx_reduce_sum_f32(dst, num_elems);\n\n    return sum;\n}\n\nstatic void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {\n   
 struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;\n    struct htp_ops_context * octx = smctx->octx;\n\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * src1 = &octx->src1;\n    struct htp_tensor *       dst  = &octx->dst;\n\n    htp_softmax_preamble3;\n\n    const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows\n    const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;\n\n    const uint32_t src0_start_row = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    int is_aligned = 1;\n    int opt_path   = 0;\n    if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) {\n        is_aligned = 0;\n        FARF(HIGH, \"softmax-f32: unaligned addresses in elementwise op, possibly slower execution\\n\");\n    }\n    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {\n        opt_path = 1;\n    }\n\n    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);\n    uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);\n    uint8_t * dst_spad_data  = octx->dst_spad.data + (ith * smctx->spad_stride);\n\n    float * wp0 = (float *) src0_spad_data;\n    float * wp1 = (float *) src1_spad_data;\n    float * wp2 = (float *) dst_spad_data;\n\n    uint32_t prev_i2 = (uint32_t)-1;\n    float slope = 1.0f;\n\n    for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {\n        uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);\n        uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);\n        uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);\n        uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);\n\n        // Map to original logic indices\n  
      // i01 = i1\n        // i02 = i2\n        // i03 = i3\n\n        const uint32_t i11 = i1;\n        // const uint32_t i12 = i2 % ne12;\n        // const uint32_t i13 = i3 % ne13;\n\n        uint32_t i12, i13;\n        if (ne12 == ne02) {\n             i12 = i2;\n        } else {\n             i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);\n        }\n\n        if (ne13 == ne03) {\n             i13 = i3;\n        } else {\n             i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);\n        }\n\n        // ALiBi\n        if (i2 != prev_i2) {\n            const uint32_t h = i2;  // head\n\n            slope = (smctx->max_bias > 0.0f) ?\n                        h < smctx->n_head_log2 ?\n                        powf(smctx->m0, h + 1) :\n                        powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :\n                        1.0f;\n            prev_i2 = i2;\n        }\n\n        float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);\n        float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);\n\n        // broadcast the mask across rows\n        __fp16 * mp_f16 = (smctx->use_src1) ?\n                              (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :\n                              NULL;\n        float *  mp_f32 = (smctx->use_src1) ?\n                              (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :\n                              NULL;\n\n        if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {\n            hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,\n                                      (const uint8_t *) mp_f32, slope);\n        } else {\n            hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);\n            if (mp_f32) {\n                if (smctx->use_f16) {\n                    for (int i = 0; i < ne00; ++i) {\n               
         wp0[i] += slope * (float) mp_f16[i];\n                    }\n                } else {\n                    for (int i = 0; i < ne00; ++i) {\n                        wp0[i] += slope * mp_f32[i];\n                    }\n                }\n            }\n        }\n\n        if (1 == opt_path) {\n            hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);\n        } else {\n            float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);\n            float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);\n            sum       = sum > 0.0 ? (1.0 / sum) : 1;\n            hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);\n        }\n    }\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\", ith, nth,\n         smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,\n         ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\nstatic int execute_op_softmax_f32(struct htp_ops_context * octx) {\n    int err = HTP_STATUS_OK;\n\n    const struct htp_tensor * src0 = &octx->src0;\n    const struct htp_tensor * src1 = &octx->src1;\n    struct htp_tensor *       dst  = &octx->dst;\n\n    struct htp_softmax_context smctx;\n    const char * op_type = \"softmax-f32\";\n\n    switch (octx->op) {\n        case HTP_OP_SOFTMAX:\n            init_softmax_ctx(&smctx, octx);\n            break;\n\n        default:\n            FARF(ERROR, \"Unsupported Op %u\\n\", octx->op);\n            return HTP_STATUS_NO_SUPPORT;\n    }\n\n    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];\n    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);\n\n    const size_t src0_row_size = src0->nb[1];\n    const size_t src1_row_size = src0_row_size;\n    const size_t dst_row_size  = dst->nb[1];\n\n    // 
VTCM scratchpads for all tensors\n    // N rows per thread, padded to HVX vector size\n    octx->dst_spad.size  = hex_round_up(dst_row_size, 128) * n_threads;\n    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;\n    octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;\n\n    // Use stride for calculating offset\n    smctx.spad_stride = hex_round_up(src0_row_size, 128);\n\n    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;\n\n    if (src1->ne[0]) {\n        FARF(HIGH,\n             \"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\\n\",\n             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],\n             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,\n             octx->dst_spad.size);\n    } else {\n        FARF(HIGH, \"%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\\n\", op_type,\n             src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);\n    }\n\n    // Make sure the reserved vtcm size is sufficient\n    if (octx->ctx->vtcm_size < spad_size) {\n        FARF(ERROR, \"%s : current VTCM reservation %zu is too small, needed %zu\\n\", op_type, octx->ctx->vtcm_size,\n             spad_size);\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;\n    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;\n\n    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;\n        worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, 
n_threads);\n    }\n\n    return err;\n}\n\nint op_softmax(struct htp_ops_context * octx) {\n    int err = HTP_STATUS_OK;\n\n    switch (octx->src0.type) {\n        case HTP_TYPE_F32:\n            err = execute_op_softmax_f32(octx);\n            break;\n\n        default:\n            err = HTP_STATUS_NO_SUPPORT;\n            break;\n    }\n\n    return err;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/ssm-conv.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_mem.h>\n#include <HAP_perf.h>\n#include <HAP_ps.h>\n#include <hexagon_protos.h>\n#include <hexagon_types.h>\n#include <math.h>\n#include <qurt_thread.h>\n#include <string.h>\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"hex-dma.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n#include \"hvx-utils.h\"\n\n#define htp_ssm_conv_tensors_preamble                        \\\n    struct htp_tensor * restrict src0    = &octx->src0;      \\\n    struct htp_tensor * restrict src1    = &octx->src1;      \\\n    struct htp_tensor * restrict dst     = &octx->dst;       \\\n    struct htp_spad * restrict src0_spad = &octx->src0_spad; \\\n    struct htp_spad * restrict src1_spad = &octx->src1_spad; \\\n    struct htp_spad * restrict dst_spad  = &octx->dst_spad;  \\\n                                                             \\\n    const uint32_t ne00 = src0->ne[0];                       \\\n    const uint32_t ne01 = src0->ne[1];                       \\\n    const uint32_t ne02 = src0->ne[2];                       \\\n    const uint32_t ne03 = src0->ne[3];                       \\\n                                                             \\\n    const uint32_t ne10 = src1->ne[0];                       \\\n    const uint32_t ne11 = src1->ne[1];                       \\\n    const uint32_t ne12 = src1->ne[2];                       \\\n    const uint32_t ne13 = src1->ne[3];                       \\\n                                                             \\\n    const uint32_t ne0 = dst->ne[0];                         \\\n    const uint32_t ne1 = dst->ne[1];                         \\\n    const uint32_t ne2 = dst->ne[2];                         \\\n    const uint32_t ne3 = dst->ne[3];           
              \\\n                                                             \\\n    const uint32_t nb00 = src0->nb[0];                       \\\n    const uint32_t nb01 = src0->nb[1];                       \\\n    const uint32_t nb02 = src0->nb[2];                       \\\n    const uint32_t nb03 = src0->nb[3];                       \\\n                                                             \\\n    const uint32_t nb10 = src1->nb[0];                       \\\n    const uint32_t nb11 = src1->nb[1];                       \\\n    const uint32_t nb12 = src1->nb[2];                       \\\n    const uint32_t nb13 = src1->nb[3];                       \\\n                                                             \\\n    const uint32_t nb0 = dst->nb[0];                         \\\n    const uint32_t nb1 = dst->nb[1];                         \\\n    const uint32_t nb2 = dst->nb[2];                         \\\n    const uint32_t nb3 = dst->nb[3];\n\nstruct htp_ssm_conv_context {\n    struct htp_ops_context * octx;\n    uint32_t nrows_per_thread;\n    uint64_t t_start;\n};\n\n#define htp_ssm_conv_preamble                            \\\n    struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \\\n    struct htp_ops_context * octx = scctx->octx;         \\\n    htp_ssm_conv_tensors_preamble;                       \\\n    dma_queue * dma_queue         = octx->ctx->dma[ith];\n\n// Scalar FP32 SSM_CONV implementation\nstatic void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {\n    htp_ssm_conv_preamble;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const uint32_t d_conv  = src1->ne[0];\n    const uint32_t d_inner = src0->ne[1];\n    const uint32_t n_t     = dst->ne[1];\n    const uint32_t n_s     = dst->ne[2];\n\n    const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension\n    const uint32_t src0_stride_seq   = src0->nb[2] / sizeof(float); // stride for 
sequence dimension\n    const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension\n    const uint32_t dst_stride_token  = dst->nb[1]  / sizeof(float); // stride for token dimension\n    const uint32_t dst_stride_seq    = dst->nb[2]  / sizeof(float); // stride for sequence dimension\n\n    const float * src0_data = (const float *) src0->data;\n    const float * src1_data = (const float *) src1->data;\n    float *       dst_data  = (float *) dst->data;\n\n    // Calculate row range for this thread\n    const uint32_t d_inner_per_thread = scctx->nrows_per_thread;\n    const uint32_t d_inner_start = d_inner_per_thread * ith;\n    const uint32_t d_inner_end   = MIN(d_inner_start + d_inner_per_thread, d_inner);\n\n    // No work for this thread\n    if (d_inner_start >= d_inner_end) {\n        return;\n    }\n\n    for (uint32_t i3 = 0; i3 < n_s; ++i3) {\n        for (uint32_t i2 = 0; i2 < n_t; ++i2) {\n            for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) {\n                float sumf = 0.0f;\n\n                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {\n                    const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq;\n                    const uint32_t src1_idx = i0 + i1 * src1_stride_inner;\n\n                    sumf += src0_data[src0_idx] * src1_data[src1_idx];\n                }\n\n                const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq;\n                dst_data[dst_idx] = sumf;\n            }\n        }\n    }\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\",\n         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end,\n         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],\n         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\n// HVX FP32 
SSM_CONV implementation - vectorizes across d_inner dimension\nstatic void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) {\n    htp_ssm_conv_preamble;\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const int nc  = src1->ne[0]; // d_conv\n    const int ncs = src0->ne[0]; // d_conv - 1 + n_t\n\n    const uint32_t d_conv  = src1->ne[0];\n    const uint32_t d_inner = src0->ne[1];\n    const uint32_t n_t     = dst->ne[1];\n    const uint32_t n_s     = dst->ne[2];\n\n    const float * src0_data = (const float *) src0->data;\n    const float * src1_data = (const float *) src1->data;\n    float *       dst_data  = (float *) dst->data;\n\n    // Calculate row range for this thread\n    const int dr = scctx->nrows_per_thread;\n    const uint32_t ir0 = dr * ith;\n    const uint32_t ir1 = MIN(ir0 + dr, d_inner);\n    const int      ir  = ir1 - ir0;\n\n    if (ir0 >= ir1) {\n        return;  // No work for this thread\n    }\n\n    // src0 and src1 gather offsets\n    uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 };\n    uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 };\n\n    for (uint32_t i = 0; i < VLEN_FP32; ++i) {\n        src0_offsets[i] = i * (ncs)    * sizeof(float);\n        src1_offsets[i] = i * (d_conv) * sizeof(float);\n    }\n\n    const uint32_t src0_gather_len = VLEN * ncs;\n    const uint32_t src1_gather_len = VLEN * d_conv;\n\n    // gather scratchpads\n    HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0);\n    HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN);\n\n    float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]);\n    float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]);\n\n    uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;\n    uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;\n\n   
 // copy src1 workload to VTCM\n    dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);\n\n    // FARF(HIGH, \"ssm-conv-src1-fetch %d: ir0 %u size %u\\n\", ith, ir0, nb11 * ir);\n\n    for (uint32_t i3 = 0; i3 < n_s; ++i3) {\n        float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));\n\n        // copy src0 workload to VTCM\n        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);\n\n        // FARF(HIGH, \"ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\\n\", ith, ir0, i3, nb01 * ir);\n\n        dma_queue_flush(dma_queue);\n\n        for (uint32_t i2 = 0; i2 < n_t; ++i2) {\n            float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));\n\n            const uint32_t nvec = ir / VLEN_FP32;\n            const uint32_t nloe = ir % VLEN_FP32;\n            uint32_t i1 = 0;\n\n            for (uint32_t vi1 = 0; vi1 < nvec; vi1++) {\n                HVX_Vector acc_vec = Q6_V_vsplat_R(0);\n\n                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {\n                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),\n                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));\n                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),\n                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));\n\n                    HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);\n                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);\n                }\n\n                *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec);\n                i1 += VLEN_FP32;\n            }\n\n            if (nloe) {\n                HVX_Vector acc_vec = Q6_V_vsplat_R(0);\n\n                for 
(uint32_t i0 = 0; i0 < d_conv; ++i0) {\n                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),\n                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));\n                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),\n                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));\n\n                    HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);\n                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);\n                }\n\n                hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec));\n            }\n        }\n    }\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\\n\",\n         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,\n         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],\n         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\nint op_ssm_conv_f32(struct htp_ops_context * octx) {\n    htp_ssm_conv_tensors_preamble;\n\n    if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {\n        FARF(ERROR, \"ssm_conv: only (F32 x F32 -> F32) OPs supported\");\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    struct htp_ssm_conv_context scctx = { 0 };\n    scctx.octx = octx;\n\n    const uint32_t d_conv  = src1->ne[0];\n    const uint32_t d_inner = src0->ne[1];\n    const uint32_t n_t     = dst->ne[1];  // tokens per sequence\n    const uint32_t n_s     = dst->ne[2];  // number of sequences in the batch\n\n    const uint32_t n_threads = MIN(octx->n_threads, d_inner);\n\n    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        uint32_t use_hvx = 0;\n        if 
(d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {\n            int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&\n                             hex_is_aligned((void *) src1->data, VLEN) &&\n                             hex_is_aligned((void *) dst->data, VLEN);\n\n            if (is_aligned) {\n                use_hvx = 1;\n            }\n        }\n\n        if (use_hvx) {\n            scctx.nrows_per_thread  = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread\n            scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even\n\n            octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256);\n            octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256);\n            octx->dst_spad.size_per_thread  = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256);\n\n            octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;\n            octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;\n            octx->dst_spad.size  = octx->dst_spad.size_per_thread  * n_threads;\n\n            // Compute gather scratchpad size for src0 and src1\n            const size_t gather_spad_size = n_threads * VLEN * 2;\n\n            octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;\n            octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;\n            octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;\n\n            FARF(HIGH, \"ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\\n\",\n                gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,\n                octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size,\n                octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data);\n\n            const size_t total_spad_size =\n             
   gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;\n\n            if (total_spad_size > octx->ctx->vtcm_size) {\n                FARF(HIGH, \"ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu\", total_spad_size,\n                     octx->ctx->vtcm_size);\n                use_hvx = 0;\n            }\n        }\n\n        FARF(HIGH, \"ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\\n\", src0->ne[0],\n             src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],\n             dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);\n\n        if (use_hvx) {\n            worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads);\n        } else {\n            worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads);\n        }\n    }\n\n    return HTP_STATUS_OK;\n}\n\nint op_ssm_conv(struct htp_ops_context * octx) {\n    int                 err = HTP_STATUS_OK;\n    struct htp_tensor * dst = &octx->dst;\n\n    switch (dst->type) {\n        case HTP_TYPE_F32:\n            err = op_ssm_conv_f32(octx);\n            break;\n        default:\n            err = HTP_STATUS_NO_SUPPORT;\n            break;\n    }\n\n    return err;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/sum-rows-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <string.h>\n#include <math.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\n#define sum_rows_preamble                       \\\n    struct htp_tensor *src0 =  &octx->src0;\\\n    struct htp_tensor *dst  = &octx->dst;  \\\n                                           \\\n    const uint32_t ne00 = src0->ne[0];     \\\n    const uint32_t ne01 = src0->ne[1];     \\\n    const uint32_t ne02 = src0->ne[2];     \\\n    const uint32_t ne03 = src0->ne[3];     \\\n                                           \\\n    const uint32_t nb00 = src0->nb[0];     \\\n    const uint32_t nb01 = src0->nb[1];     \\\n    const uint32_t nb02 = src0->nb[2];     \\\n    const uint32_t nb03 = src0->nb[3];     \\\n                                           \\\n    const uint32_t  ne0 = dst->ne[0];      \\\n    const uint32_t  ne1 = dst->ne[1];      \\\n    const uint32_t  ne2 = dst->ne[2];      \\\n    const uint32_t  ne3 = dst->ne[3];      \\\n                                           \\\n    const uint32_t  nb0 = dst->nb[0];      \\\n    const uint32_t  nb1 = dst->nb[1];      \\\n    const uint32_t  nb2 = dst->nb[2];      \\\n    const uint32_t  nb3 = dst->nb[3];      \\\n\nstruct sum_rows_context {\n    const uint8_t * src_data;\n    uint8_t       * dst_data;\n    uint32_t        ne00;\n    size_t          src_stride;\n    size_t          dst_stride;\n    uint32_t        rows_per_thread;\n    uint32_t        total_rows;\n    bool            opt_path;\n};\n\nstatic void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) {\n    const struct sum_rows_context * smctx = (const struct 
sum_rows_context *) data;\n\n    const uint32_t rows_per_thread = smctx->rows_per_thread;\n    const uint32_t total_rows      = smctx->total_rows;\n\n    const uint32_t start_row = rows_per_thread * ith;\n    const uint32_t end_row   = MIN(start_row + rows_per_thread, total_rows);\n\n    if (start_row >= end_row) {\n        return;\n    }\n\n    const size_t   src_stride = smctx->src_stride;\n    const size_t   dst_stride = smctx->dst_stride;\n    const uint32_t ne00       = smctx->ne00;\n    const bool     opt_path   = smctx->opt_path;\n\n    const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride));\n    float       * restrict dst_th = (float *)       (smctx->dst_data + (start_row * dst_stride));\n\n    // Calculate actual number of rows for this thread\n    const uint32_t n_rows = end_row - start_row;\n\n    for (uint32_t ir = 0; ir < n_rows; ir++) {\n        const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float)));\n\n        if (ir + 1 < n_rows) {\n            hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1);\n        }\n\n        if (opt_path) {\n            dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);\n        } else {\n            dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);\n        }\n    }\n}\n\nint op_sum_rows(struct htp_ops_context * octx) {\n    sum_rows_preamble;\n\n    if (octx->src0.type != HTP_TYPE_F32) {\n        return HTP_STATUS_NO_SUPPORT;\n    }\n\n    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {\n        return HTP_STATUS_OK;\n    }\n\n    const uint32_t src0_nrows = ne01 * ne02 * ne03;\n    const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);\n    const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads;\n\n    bool opt_path = false;\n    if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {\n        opt_path = true;\n    }\n\n    struct sum_rows_context 
smctx = {\n        .src_data        = (const uint8_t *) src0->data,\n        .dst_data        = (uint8_t *) dst->data,\n        .ne00            = ne00,\n        .src_stride      = nb01,\n        .dst_stride      = nb1,\n        .rows_per_thread = rows_per_thread,\n        .total_rows      = src0_nrows,\n        .opt_path        = opt_path,\n    };\n\n    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads);\n\n    return HTP_STATUS_OK;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/unary-ops.c",
    "content": "#pragma clang diagnostic ignored \"-Wunused-variable\"\n#pragma clang diagnostic ignored \"-Wunused-function\"\n#pragma clang diagnostic ignored \"-Wunused-but-set-variable\"\n\n#include <HAP_farf.h>\n#include <HAP_perf.h>\n\n#include <math.h>\n#include <string.h>\n\n#include \"hex-dma.h\"\n#include \"hvx-utils.h\"\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n#include \"htp-ctx.h\"\n#include \"htp-msg.h\"\n#include \"htp-ops.h\"\n\nstruct htp_unary_context {\n    struct htp_ops_context * octx;\n\n    // Precomputed values\n    const uint8_t *           data_src0;\n    uint8_t *                 data_dst;\n\n    size_t                    src0_row_size;\n    size_t                    dst_row_size;\n\n    size_t                    src0_row_size_aligned;\n    size_t                    dst_row_size_aligned;\n\n    size_t                    src0_spad_half_size;\n    size_t                    dst_spad_half_size;\n\n    uint32_t                  block;\n    uint32_t                  src0_nrows;\n    uint32_t                  src0_nrows_per_thread;\n    uint32_t                  nc;\n};\n\n#define htp_unary_preamble            \\\n    const uint32_t ne00 = src->ne[0]; \\\n    const uint32_t ne01 = src->ne[1]; \\\n    const uint32_t ne02 = src->ne[2]; \\\n    const uint32_t ne03 = src->ne[3]; \\\n                                      \\\n    const uint32_t ne0 = dst->ne[0];  \\\n    const uint32_t ne1 = dst->ne[1];  \\\n    const uint32_t ne2 = dst->ne[2];  \\\n    const uint32_t ne3 = dst->ne[3];  \\\n                                      \\\n    const uint32_t nb00 = src->nb[0]; \\\n    const uint32_t nb01 = src->nb[1]; \\\n    const uint32_t nb02 = src->nb[2]; \\\n    const uint32_t nb03 = src->nb[3]; \\\n                                      \\\n    const uint32_t nb0 = dst->nb[0];  \\\n    const uint32_t nb1 = dst->nb[1];  \\\n    const uint32_t nb2 = dst->nb[2];  \\\n    const uint32_t nb3 = dst->nb[3];\n\nstatic void 
hvx_fast_rms_norm_f32(const uint8_t * restrict src,\n                                  uint8_t * restrict dst,\n                                  uint8_t * restrict pad,\n                                  const int num_elems,\n                                  float     epsilon) {\n    const HVX_Vector * restrict v_src = (HVX_Vector *) src;\n    HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;\n\n    HVX_Vector sum_v     = Q6_V_vsplat_R(0x00000000);\n    HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);\n\n    int step_of_1 = num_elems >> 5;\n    #pragma unroll(4)\n    for (int i = 0; i < step_of_1; i++) {\n        HVX_Vector v1 = v_src[i];\n        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);\n        sum_v         = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);\n    }\n\n    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes\n\n    HVX_Vector t_v            = hvx_vec_splat_f32((float) num_elems);\n    HVX_Vector denom_v        = hvx_vec_inverse_f32(t_v);\n    HVX_Vector mean_v         = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);\n    HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);\n\n    HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));\n\n    #pragma unroll(4)\n    for (int i = 0; i < step_of_1; i++) {\n        HVX_Vector v1 = v_src[i];\n        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);\n        v_dst[i]      = Q6_Vsf_equals_Vqf32(v2);\n    }\n}\n\nstatic void scale_f32(const float * restrict src,\n                      float * restrict dst,\n                      uint8_t * restrict spad,\n                      const uint32_t num_rows,\n                      const uint32_t row_elems,\n                      const size_t   row_size,\n                      int32_t *      op_params) {\n    float scale = 0.f;\n    float bias  = 0.f;\n    memcpy(&scale, &op_params[0], sizeof(float));\n    memcpy(&bias,  &op_params[1], sizeof(float));\n\n    for (uint32_t ir = 0; ir < num_rows; 
ir++) {\n        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);\n        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);\n\n        hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);\n    }\n}\n\nstatic void rms_norm_f32(const float * restrict src,\n                         float * restrict dst,\n                         uint8_t * restrict spad,\n                         const uint32_t num_rows,\n                         const uint32_t row_elems,\n                         const size_t   row_size,\n                         int32_t *      op_params) {\n    float epsilon = 0.f;\n    memcpy(&epsilon, op_params, sizeof(float));\n\n    for (uint32_t ir = 0; ir < num_rows; ir++) {\n        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);\n        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);\n\n        hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);\n    }\n}\n\nstatic void sqr_f32(const float * restrict src,\n                    float * restrict dst,\n                    uint8_t * restrict spad,\n                    const uint32_t num_rows,\n                    const uint32_t row_elems,\n                    const size_t   row_size,\n                    int32_t *      op_params) {\n\n    for (uint32_t ir = 0; ir < num_rows; ir++) {\n        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);\n        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);\n\n        hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);\n    }\n}\n\nstatic void sqrt_f32(const float * restrict src,\n                     float * restrict dst,\n                     uint8_t * restrict spad,\n                     const uint32_t num_rows,\n                     const uint32_t row_elems,\n                     const size_t   
row_size,\n                     int32_t *      op_params) {\n\n    for (uint32_t ir = 0; ir < num_rows; ir++) {\n        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);\n        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);\n\n        hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);\n    }\n}\n\nstatic void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {\n    const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;\n    struct htp_ops_context * octx = uctx->octx;\n    const struct htp_tensor * src = &octx->src0;\n    const struct htp_tensor * dst = &octx->dst;\n\n    htp_unary_preamble;\n\n    int                       htp_op = octx->op;\n    int32_t *                 op_params = octx->op_params;\n    uint32_t                  src0_nrows_per_thread = uctx->src0_nrows_per_thread;\n\n    const size_t src0_row_size = uctx->src0_row_size;\n    const size_t dst_row_size  = uctx->dst_row_size;\n\n    const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;\n    const size_t dst_row_size_aligned  = uctx->dst_row_size_aligned;\n\n    const uint32_t src0_nrows = uctx->src0_nrows;\n    const uint32_t src0_start_row = src0_nrows_per_thread * ith;\n    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);\n\n    // no work for this thread\n    if (src0_start_row >= src0_end_row) {\n        return;\n    }\n\n    uint64_t t1, t2;\n    t1 = HAP_perf_get_qtimer_count();\n\n    const uint8_t * restrict data_src = uctx->data_src0;\n    uint8_t * restrict       data_dst = uctx->data_dst;\n\n    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);\n    uint8_t * dst_spad_data  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);\n\n    size_t src0_spad_half_size = uctx->src0_spad_half_size;\n    size_t dst_spad_half_size  = uctx->dst_spad_half_size;\n\n    
const int BLOCK = uctx->block;\n    if (BLOCK == 0) {\n        FARF(ERROR, \"unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\\n\",\n             octx->src0_spad.size_per_thread, src0_row_size_aligned);\n        return;\n    }\n\n    dma_queue * dma_queue = octx->ctx->dma[ith];\n\n    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)\n        dma_queue_push_vtcm_to_ddr(dma_queue,\n            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),\n            dst_row_size, dst_row_size_aligned, 0);\n\n        dma_queue_push_ddr_to_vtcm(dma_queue,\n            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),\n            src0_row_size_aligned, src0_row_size, block_size);\n    }\n\n    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {\n        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);\n\n        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;\n        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;\n\n        // Process block in VTCM\n        switch (htp_op) {\n            case HTP_OP_RMS_NORM:\n                rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);\n                break;\n            case HTP_OP_SCALE:\n                scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);\n                break;\n            case HTP_OP_SQR:\n                sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);\n                break;\n            case HTP_OP_SQRT:\n                sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);\n             
   break;\n            default:\n                break;\n        }\n\n        dma_queue_push_vtcm_to_ddr(dma_queue,\n            dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),\n            dst_row_size, dst_row_size_aligned, block_size);\n\n        // prefetch N+2 loop iteration if any\n        const uint32_t pref_block = (ir + BLOCK * 2);\n        if (pref_block < src0_end_row) {\n            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);\n            dma_queue_push_ddr_to_vtcm(dma_queue,\n                dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),\n                src0_row_size_aligned, src0_row_size, pref_block_size);\n        }\n    }\n\n    dma_queue_flush(dma_queue);\n\n    t2 = HAP_perf_get_qtimer_count();\n\n    FARF(HIGH, \"unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\\n\", ith, nth, src->ne[0],\n         src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],\n         dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));\n}\n\nstatic int execute_op_unary_f32(struct htp_ops_context * octx) {\n    int err = HTP_STATUS_OK;\n\n    const struct htp_tensor * src0 = &octx->src0;\n    struct htp_tensor *       dst  = &octx->dst;\n\n    const char * op_type = NULL;\n\n    switch (octx->op) {\n        case HTP_OP_RMS_NORM:\n            op_type = \"rmsnorm-f32\";\n            break;\n        case HTP_OP_SCALE:\n            op_type = \"scale-f32\";\n            break;\n        case HTP_OP_SQR:\n            op_type = \"sqr-f32\";\n            break;\n        case HTP_OP_SQRT:\n            op_type = \"sqrt-f32\";\n            break;\n\n        default:\n            FARF(ERROR, \"Unsupported unary Op %u\\n\", octx->op);\n            return HTP_STATUS_NO_SUPPORT;\n    }\n\n    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];\n    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);\n\n    const size_t src0_row_size = 
src0->nb[1];\n    const size_t dst_row_size  = dst->nb[1];\n\n    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);\n    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);\n\n    // VTCM scratchpads for all tensors\n    // N rows per thread, padded to HVX vector size\n    // Double buffering requires 2x size per buffer\n\n    size_t spad_size_per_row   = 2 * (src0_row_size_aligned + dst_row_size_aligned);\n    size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);\n\n    // Make sure the reserved vtcm size is sufficient\n    if (vtcm_row_per_thread == 0) {\n        FARF(ERROR, \"unary-%s : current VTCM reservation %zu is too small, needed %zu\\n\", op_type, octx->ctx->vtcm_size,\n             spad_size_per_row * n_threads);\n        return HTP_STATUS_VTCM_TOO_SMALL;\n    }\n\n    octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;\n    octx->dst_spad.size_per_thread  = dst_row_size_aligned * vtcm_row_per_thread * 2;\n\n    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;\n    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;\n\n    octx->src0_spad.data = octx->ctx->vtcm_base;\n    octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;\n\n    FARF(HIGH, \"%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\\n\", op_type,\n         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],\n         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);\n\n    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {\n        struct htp_unary_context uctx = {\n            .octx                  = octx,\n            .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads,\n            .src0_nrows            = src0_nrows,\n\n            .data_src0             = (const uint8_t *)src0->data,\n            .data_dst              = 
(uint8_t *)dst->data,\n\n            .src0_row_size         = src0_row_size,\n            .dst_row_size          = dst_row_size,\n\n            .src0_row_size_aligned = src0_row_size_aligned,\n            .dst_row_size_aligned  = dst_row_size_aligned,\n\n            .src0_spad_half_size   = octx->src0_spad.size_per_thread / 2,\n            .dst_spad_half_size    = octx->dst_spad.size_per_thread / 2,\n\n            .block                 = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,\n            .nc                    = src0->ne[0],\n        };\n\n        worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);\n    }\n\n    return err;\n}\n\nint op_unary(struct htp_ops_context * octx) {\n    int err = HTP_STATUS_OK;\n\n    switch (octx->src0.type) {\n        case HTP_TYPE_F32:\n            err = execute_op_unary_f32(octx);\n            break;\n\n        default:\n            err = HTP_STATUS_NO_SUPPORT;\n            break;\n    }\n\n    return err;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/worker-pool.c",
    "content": "#include \"worker-pool.h\"\n\n#include <qurt.h>\n#include <stdatomic.h>\n#include <stdint.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n\n#include \"HAP_farf.h\"\n\n#define WORKER_THREAD_STACK_SZ  (2 * 16384)\n#define LOWEST_USABLE_QURT_PRIO (254)\n\nstruct worker_pool_s;\n\n// internal structure kept in thread-local storage per instance of worker pool\ntypedef struct {\n    struct worker_pool_s * pool;\n    unsigned int           id;\n} worker_context_t;\n\n// internal structure kept in thread-local storage per instance of worker pool\ntypedef struct worker_pool_s {\n    worker_pool_job_t job[MAX_NUM_WORKERS];      // list of job descriptors\n    qurt_thread_t     thread[MAX_NUM_WORKERS];   // thread ID's of the workers\n    worker_context_t  context[MAX_NUM_WORKERS];  // worker contexts\n    void *            stack[MAX_NUM_WORKERS];    // thread stack pointers\n    unsigned int      n_threads;                 // number of workers in this pool\n\n    atomic_uint seqn;                            // seqno used to detect new jobs\n    atomic_uint next_job;                        // next job index\n    atomic_uint n_pending;                       // number of pending jobs\n    atomic_uint n_jobs;                          // number of current jobs\n    atomic_bool killed;                          // threads need to exit\n} worker_pool_t;\n\nstatic void worker_pool_main(void * context) {\n    worker_context_t * me   = (worker_context_t *) context;\n    worker_pool_t *    pool = me->pool;\n\n    FARF(HIGH, \"worker-pool: thread %u started\", me->id);\n\n    unsigned int prev_seqn = 0;\n    while (!atomic_load(&pool->killed)) {\n        unsigned int seqn = atomic_load(&pool->seqn);\n        if (seqn == prev_seqn) {\n            // Nothing to do\n            qurt_futex_wait(&pool->seqn, prev_seqn);\n            continue;\n        }\n\n        // New job\n        prev_seqn = seqn;\n\n        unsigned int n = atomic_load(&pool->n_jobs);\n  
      unsigned int i = atomic_fetch_add(&pool->next_job, 1);\n        if (i >= n) {\n            // Spurious wakeup\n            continue;\n        }\n\n        pool->job[i].func(n, i, pool->job[i].data);\n\n        atomic_fetch_sub(&pool->n_pending, 1);\n    }\n\n    FARF(HIGH, \"worker-pool: thread %u stopped\", me->id);\n}\n\nAEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, uint32_t n_threads, uint32_t stack_size) {\n    int err = 0;\n\n    if (NULL == context) {\n        FARF(ERROR, \"NULL context passed to worker_pool_init().\");\n        return AEE_EBADPARM;\n    }\n\n    // Allocations\n    int size = (stack_size * n_threads) + (sizeof(worker_pool_t));\n\n    unsigned char * mem_blob = (unsigned char *) malloc(size);\n    if (!mem_blob) {\n        FARF(ERROR, \"Could not allocate memory for worker pool!!\");\n        return AEE_ENOMEMORY;\n    }\n\n    worker_pool_t * me = (worker_pool_t *) (mem_blob + stack_size * n_threads);\n\n    // name for the first worker, useful in debugging threads\n    char name[19];\n    snprintf(name, 12, \"0x%8x:\", (int) me);\n    strcat(name, \"worker0\");\n    me->n_threads = n_threads;\n\n    // initializations\n    for (unsigned int i = 0; i < me->n_threads; i++) {\n        me->stack[i]  = NULL;\n        me->thread[i] = 0;\n\n        me->context[i].id   = i;\n        me->context[i].pool = me;\n    }\n\n    // initialize job queue\n    me->n_pending = 0;\n    me->n_jobs    = 0;\n    me->next_job  = 0;\n    me->seqn      = 0;\n    me->killed    = 0;\n\n    // launch the workers\n    qurt_thread_attr_t attr;\n    qurt_thread_attr_init(&attr);\n\n    for (unsigned int i = 0; i < me->n_threads; i++) {\n        // set up stack\n        me->stack[i] = mem_blob;\n        mem_blob += stack_size;\n        qurt_thread_attr_set_stack_addr(&attr, me->stack[i]);\n        qurt_thread_attr_set_stack_size(&attr, stack_size);\n\n        // set up name\n        qurt_thread_attr_set_name(&attr, name);\n        
name[17] = (name[17] + 1);\n        // name threads context:worker0, context:worker1, .. (recycle at 9, but num threads should be less than that anyway)\n        if (name[17] > '9') {\n            name[17] = '0';\n        }\n\n        // set up priority - by default, match the creating thread's prio\n        int prio = qurt_thread_get_priority(qurt_thread_get_id());\n\n        if (prio < 1) {\n            prio = 1;\n        }\n        if (prio > LOWEST_USABLE_QURT_PRIO) {\n            prio = LOWEST_USABLE_QURT_PRIO;\n        }\n\n        qurt_thread_attr_set_priority(&attr, prio);\n\n        // launch\n        err = qurt_thread_create(&me->thread[i], &attr, worker_pool_main, (void *) &me->context[i]);\n        if (err) {\n            FARF(ERROR, \"Could not launch worker threads!\");\n            worker_pool_release((worker_pool_context_t *) &me);\n            return AEE_EQURTTHREADCREATE;\n        }\n    }\n    *context = (worker_pool_context_t *) me;\n    return AEE_SUCCESS;\n}\n\nAEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads) {\n    return worker_pool_init_with_stack_size(context, n_threads, WORKER_THREAD_STACK_SZ);\n}\n\n// clean up worker pool\nvoid worker_pool_release(worker_pool_context_t * context) {\n    worker_pool_t * me = (worker_pool_t *) *context;\n\n    // if no worker pool exists, return error.\n    if (NULL == me) {\n        return;\n    }\n\n    atomic_store(&me->killed, 1);\n    atomic_fetch_add(&me->seqn, 1);\n    qurt_futex_wake(&me->seqn, me->n_threads);\n\n    // de-initializations\n    for (unsigned int i = 0; i < me->n_threads; i++) {\n        if (me->thread[i]) {\n            int status;\n            (void) qurt_thread_join(me->thread[i], &status);\n        }\n    }\n\n    // free allocated memory (were allocated as a single buffer starting at stack[0])\n    if (me->stack[0]) {\n        free(me->stack[0]);\n    }\n\n    *context = NULL;\n}\n\n// run jobs\nAEEResult worker_pool_run_jobs(worker_pool_context_t 
context, worker_pool_job_t * job, unsigned int n) {\n    worker_pool_t * me = (worker_pool_t *) context;\n    if (NULL == me) {\n        FARF(ERROR, \"worker-pool: invalid context\");\n        return AEE_EBADPARM;\n    }\n\n    if (n > me->n_threads) {\n        FARF(ERROR, \"worker-pool: invalid number of jobs %u for n-threads %u\", n, me->n_threads);\n        return AEE_EBADPARM;\n    }\n\n    memcpy(me->job, job, sizeof(worker_pool_job_t) * n);\n\n    if (n > 1) {\n        atomic_store(&me->next_job, 1);\n        atomic_store(&me->n_jobs, n);\n        atomic_store(&me->n_pending, n - 1);\n\n        // wake up workers\n        atomic_fetch_add(&me->seqn, 1);\n        qurt_futex_wake(&me->seqn, n - 1);\n    }\n\n    // main thread runs job #0\n    me->job[0].func(n, 0, me->job[0].data);\n\n    if (n > 1) {\n        while (atomic_load(&me->n_pending))\n            ;\n    }\n\n    return 0;\n}\n\n// run func\nAEEResult worker_pool_run_func(worker_pool_context_t context, worker_callback_t func, void * data, unsigned int n) {\n    worker_pool_job_t job[n];\n\n    for (unsigned int i = 0; i < n; i++) {\n        job[i].func = func;\n        job[i].data = data;\n    }\n\n    return worker_pool_run_jobs(context, job, n);\n}\n\nAEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio) {\n    worker_pool_t * me = (worker_pool_t *) context;\n\n    // if no worker pool exists, return error.\n    if (!me) {\n        return AEE_ENOMORE;\n    }\n\n    int result = AEE_SUCCESS;\n    if (prio < 1) {\n        prio = 1;\n    }\n    if (prio > LOWEST_USABLE_QURT_PRIO) {\n        prio = LOWEST_USABLE_QURT_PRIO;\n    }\n\n    for (unsigned int i = 0; i < me->n_threads; i++) {\n        int res = qurt_thread_set_priority(me->thread[i], (unsigned short) prio);\n        if (0 != res) {\n            result = AEE_EBADPARM;\n            FARF(ERROR, \"QURT failed to set priority of thread %d, ERROR = %d\", me->thread[i], res);\n        }\n    }\n\n    return 
result;\n}\n\nAEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids) {\n    worker_pool_t * me = (worker_pool_t *) context;\n    if (!me) {\n        FARF(ERROR, \"worker-pool: invalid context\");\n        return AEE_EBADPARM;\n        ;\n    }\n\n    for (int i = 0; i < me->n_threads; i++) {\n        tids[i] = me->thread[i];\n    }\n\n    return AEE_SUCCESS;\n}\n\nAEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio) {\n    worker_pool_t * me = (worker_pool_t *) context;\n    if (!me) {\n        FARF(ERROR, \"worker-pool: invalid context\");\n        return AEE_EBADPARM;\n    }\n\n    int priority = qurt_thread_get_priority(me->thread[0]);\n    if (priority > 0) {\n        *prio = priority;\n        return 0;\n    } else {\n        *prio = 0;\n        return AEE_EBADSTATE;\n    }\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp/worker-pool.h",
    "content": "#ifndef HTP_WORKER_POOL_H\n#define HTP_WORKER_POOL_H\n\n// MACRO enables function to be visible in shared-library case.\n#define WORKERPOOL_API __attribute__((visibility(\"default\")))\n\n#include <AEEStdDef.h>\n#include <AEEStdErr.h>\n#include <stdint.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/// signature of callbacks to be invoked by worker threads\ntypedef void (*worker_callback_t)(unsigned int n, unsigned int i, void *);\n\n/// Typedef of worker_pool context\ntypedef void * worker_pool_context_t;\n\n/// descriptor for requested callback\ntypedef struct {\n    worker_callback_t func;\n    void *            data;\n} worker_pool_job_t;\n\n/// Maximum supported number of worker threads.\n#define MAX_NUM_WORKERS 10\n\n// Initialize worker pool.\nWORKERPOOL_API AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads);\n\n// Initialize worker pool with custom stack size\nWORKERPOOL_API AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context,\n                                                          uint32_t                n_threads,\n                                                          uint32_t                stack_size);\n\n// Kill worker threads and release worker pool resources\nWORKERPOOL_API void worker_pool_release(worker_pool_context_t * context);\n\n// Run jobs with the worker pool.\nWORKERPOOL_API AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n);\n\nWORKERPOOL_API AEEResult worker_pool_run_func(worker_pool_context_t context,\n                                              worker_callback_t     func,\n                                              void *                data,\n                                              unsigned int          n);\n\nWORKERPOOL_API AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio);\nWORKERPOOL_API AEEResult worker_pool_get_thread_priority(worker_pool_context_t 
context, unsigned int * prio);\nWORKERPOOL_API AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids);\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif  // #ifndef HTP_WORKER_POOL_H\n"
  },
  {
    "path": "src/ggml-hexagon/htp-drv.cpp",
    "content": "// sample drv interface\n\n#pragma clang diagnostic ignored \"-Wgnu-anonymous-struct\"\n#pragma clang diagnostic ignored \"-Wmissing-prototypes\"\n#pragma clang diagnostic ignored \"-Wsign-compare\"\n\n#include <filesystem>\n#include <set>\n#include <sstream>\n#include <string>\n#ifdef _WIN32\n#   define WIN32_LEAN_AND_MEAN\n#   ifndef NOMINMAX\n#       define NOMINMAX\n#   endif\n#   include <windows.h>\n#   include <winevt.h>\n#else\n#    include <dlfcn.h>\n#    include <unistd.h>\n#endif\n#include \"ggml-impl.h\"\n#include \"htp-drv.h\"\n#include \"libdl.h\"\n\n#include <domain.h>\n\n//\n// Driver API types\n//\n\ntypedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);\ntypedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);\ntypedef void   (*rpcmem_free_pfn_t)(void * po);\ntypedef int    (*rpcmem_to_fd_pfn_t)(void * po);\n\ntypedef AEEResult (*dspqueue_create_pfn_t)(int                 domain,\n                                           uint32_t            flags,\n                                           uint32_t            req_queue_size,\n                                           uint32_t            resp_queue_size,\n                                           dspqueue_callback_t packet_callback,\n                                           dspqueue_callback_t error_callback,\n                                           void *              callback_context,\n                                           dspqueue_t *        queue);\ntypedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);\ntypedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);\ntypedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,\n                                          uint32_t num_buffers,\n                                          struct dspqueue_buffer *buffers,\n                                          uint32_t message_length,\n                                          const 
uint8_t *message,\n                                          uint32_t timeout_us);\ntypedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,\n                                         uint32_t max_buffers, uint32_t *num_buffers,\n                                         struct dspqueue_buffer *buffers,\n                                         uint32_t max_message_length,\n                                         uint32_t *message_length, uint8_t *message,\n                                         uint32_t timeout_us);\n\ntypedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);\ntypedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);\n\ntypedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);\ntypedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);\ntypedef int (*remote_handle64_close_pfn_t)(remote_handle h);\ntypedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);\ntypedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);\ntypedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);\n\n//\n// Driver API pfns\n//\n\nrpcmem_alloc_pfn_t  rpcmem_alloc_pfn  = nullptr;\nrpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;\nrpcmem_free_pfn_t   rpcmem_free_pfn   = nullptr;\nrpcmem_to_fd_pfn_t  rpcmem_to_fd_pfn  = nullptr;\n\nfastrpc_mmap_pfn_t   fastrpc_mmap_pfn   = nullptr;\nfastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;\n\ndspqueue_create_pfn_t dspqueue_create_pfn = nullptr;\ndspqueue_close_pfn_t  dspqueue_close_pfn  = nullptr;\ndspqueue_export_pfn_t dspqueue_export_pfn = nullptr;\ndspqueue_write_pfn_t  dspqueue_write_pfn  = nullptr;\ndspqueue_read_pfn_t   dspqueue_read_pfn   = nullptr;\n\nremote_handle64_open_pfn_t    remote_handle64_open_pfn    = 
nullptr;\nremote_handle64_invoke_pfn_t  remote_handle64_invoke_pfn  = nullptr;\nremote_handle64_close_pfn_t   remote_handle64_close_pfn   = nullptr;\nremote_handle_control_pfn_t   remote_handle_control_pfn   = nullptr;\nremote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;\nremote_session_control_pfn_t  remote_session_control_pfn  = nullptr;\n\n//\n// Driver API\n//\n\nvoid * rpcmem_alloc(int heapid, uint32_t flags, int size) {\n    return rpcmem_alloc_pfn(heapid, flags, size);\n}\n\nvoid * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {\n    if (rpcmem_alloc2_pfn) {\n        return rpcmem_alloc2_pfn(heapid, flags, size);\n    } else {\n        GGML_LOG_INFO(\"ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\\n\");\n        return rpcmem_alloc_pfn(heapid, flags, size);\n    }\n}\n\nvoid rpcmem_free(void * po) {\n    return rpcmem_free_pfn(po);\n}\n\nint rpcmem_to_fd(void * po) {\n    return rpcmem_to_fd_pfn(po);\n}\n\nHTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {\n    return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);\n}\n\nHTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {\n    return fastrpc_munmap_pfn(domain, fd, addr, length);\n}\n\nAEEResult dspqueue_create(int                 domain,\n                          uint32_t            flags,\n                          uint32_t            req_queue_size,\n                          uint32_t            resp_queue_size,\n                          dspqueue_callback_t packet_callback,\n                          dspqueue_callback_t error_callback,\n                          void *              callback_context,\n                          dspqueue_t *        queue) {\n    return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,\n                               callback_context, queue);\n}\n\nAEEResult 
dspqueue_close(dspqueue_t queue) {\n    return dspqueue_close_pfn(queue);\n}\n\nAEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {\n    return dspqueue_export_pfn(queue, queue_id);\n}\n\nAEEResult dspqueue_write(dspqueue_t               queue,\n                         uint32_t                 flags,\n                         uint32_t                 num_buffers,\n                         struct dspqueue_buffer * buffers,\n                         uint32_t                 message_length,\n                         const uint8_t *          message,\n                         uint32_t                 timeout_us) {\n    return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);\n}\n\nAEEResult dspqueue_read(dspqueue_t               queue,\n                        uint32_t *               flags,\n                        uint32_t                 max_buffers,\n                        uint32_t *               num_buffers,\n                        struct dspqueue_buffer * buffers,\n                        uint32_t                 max_message_length,\n                        uint32_t *               message_length,\n                        uint8_t *                message,\n                        uint32_t                 timeout_us) {\n    return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,\n                             message, timeout_us);\n}\n\nHTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {\n    return remote_handle64_open_pfn(name, ph);\n}\n\nHTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {\n    return remote_handle64_invoke_pfn(h, dwScalars, pra);\n}\n\nHTPDRV_API int remote_handle64_close(remote_handle64 h) {\n    return remote_handle64_close_pfn(h);\n}\n\nHTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {\n    return 
remote_handle_control_pfn(req, data, datalen);\n}\n\nHTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {\n    return remote_handle64_control_pfn(h, req, data, datalen);\n}\n\nHTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {\n    return remote_session_control_pfn(req, data, datalen);\n}\n\n#ifdef _WIN32\n\nstatic std::string wstr_to_str(std::wstring_view wstr) {\n    std::string result;\n    if (wstr.empty()) {\n        return result;\n    }\n    auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,\n                                            wstr.data(), (int) wstr.size(),\n                                            nullptr, 0, nullptr, nullptr);\n    if (bytes_needed == 0) {\n        GGML_LOG_ERROR(\"ggml-hex: WideCharToMultiByte failed. Error %lu\\n\", GetLastError());\n        throw std::runtime_error(\"Invalid wstring input\");\n    }\n\n    result.resize(bytes_needed, '\\0');\n    int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,\n                                            wstr.data(), (int) wstr.size(),\n                                            result.data(), bytes_needed,\n                                            nullptr, nullptr);\n    if (bytes_written == 0) {\n        GGML_LOG_ERROR(\"ggml-hex: WideCharToMultiByte failed. Error %lu\\n\", GetLastError());\n        throw std::runtime_error(\"Wstring conversion failed\");\n    }\n    return result;\n}\n\nstatic std::string get_driver_path() {\n    std::wstring serviceName = L\"qcnspmcdm\";\n    std::string result;\n\n    // Get a handle to the SCM database.\n    SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ);\n    if (nullptr == schSCManager) {\n        GGML_LOG_ERROR(\"ggml-hex: Failed to open SCManager. 
Error: %lu\\n\", GetLastError());\n        return result;\n    }\n\n    // Get a handle to the service.\n    SC_HANDLE schService = OpenServiceW(schSCManager,           // SCM database\n                                        serviceName.c_str(),    // name of service\n                                        SERVICE_QUERY_CONFIG);  // need query config access\n\n    if (nullptr == schService) {\n        GGML_LOG_ERROR(\"ggml-hex: Failed to open qcnspmcdm service. Error: %lu\\n\", GetLastError());\n        CloseServiceHandle(schSCManager);\n        return result;\n    }\n\n    // Store the size of buffer used as an output.\n    DWORD bufferSize;\n    if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) &&\n        (GetLastError() != ERROR_INSUFFICIENT_BUFFER)) {\n        GGML_LOG_ERROR(\"ggml-hex: Failed to query service config. Error: %lu\\n\", GetLastError());\n        CloseServiceHandle(schService);\n        CloseServiceHandle(schSCManager);\n        return result;\n    }\n    // Get the configuration of the service.\n    LPQUERY_SERVICE_CONFIGW serviceConfig =\n        static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize));\n    if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) {\n        fprintf(stderr, \"ggml-hex: Failed to query service config. 
Error: %lu\\n\", GetLastError());\n        LocalFree(serviceConfig);\n        CloseServiceHandle(schService);\n        CloseServiceHandle(schSCManager);\n        return result;\n    }\n\n    // Read the driver file path get its parent directory\n    std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName);\n    driverPath = driverPath.substr(0, driverPath.find_last_of(L\"\\\\\"));\n\n    // Clean up resources\n    LocalFree(serviceConfig);\n    CloseServiceHandle(schService);\n    CloseServiceHandle(schSCManager);\n\n    // Driver path would contain invalid path string, like:\n    // \\SystemRoot\\System32\\DriverStore\\FileRepository\\qcadsprpc8280.inf_arm64_c2b9460c9a072f37\n    // \"\\SystemRoot\" should be replace with a correct one (e.g. C:\\Windows)\n    const std::wstring systemRootPlaceholder = L\"\\\\SystemRoot\";\n    if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {\n        GGML_LOG_ERROR(\"ggml-hex: String pattern not found in driver path.\\n\");\n        return result;\n    }\n\n    // Replace \\SystemRoot with an absolute path from system ENV windir\n    const std::wstring systemRootEnv = L\"windir\";\n\n    // Query the number of wide characters this variable requires\n    DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0);\n    if (numWords == 0) {\n        GGML_LOG_ERROR(\"ggml-hex: Failed get systemRoot environment variable\\n\");\n        return result;\n    }\n\n    // Query the actual system root name from environment variable\n    std::vector<wchar_t> systemRoot(numWords + 1);\n    numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1);\n    if (numWords == 0) {\n        GGML_LOG_ERROR(\"ggml-hex: Failed to read windir environment variable\\n\");\n        return result;\n    }\n    driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data()));\n\n    return wstr_to_str(driverPath);\n}\n\n#endif\n\nusing 
dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;\n\nint htpdrv_init() {\n    static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;\n    static bool initialized = false;\n#ifdef _WIN32\n    std::string drv_path = get_driver_path() + \"\\\\\" + \"libcdsprpc.dll\";\n#else\n    std::string drv_path = \"libcdsprpc.so\";\n#endif\n    if (initialized) {\n        GGML_LOG_INFO(\"ggml-hex: Driver already loaded\\n\");\n        return AEE_SUCCESS;\n    }\n    GGML_LOG_INFO(\"ggml-hex: Loading driver %s\\n\", drv_path.c_str());\n\n    fs::path path{ drv_path.c_str() };\n    dl_handle_ptr handle { dl_load_library(path) };\n    if (!handle) {\n        GGML_LOG_ERROR(\"ggml-hex: failed to load %s: %s\\n\", path.u8string().c_str(), dl_error());\n        return AEE_EUNABLETOLOAD;\n    }\n\n#define dlsym(drv, type, pfn, symbol, ignore)                               \\\n    do {                                                                    \\\n        pfn = (type) dl_get_sym(drv, #symbol);                              \\\n        if (!ignore && nullptr == pfn) {                                    \\\n            GGML_LOG_ERROR(\"ggml-hex: failed to dlsym %s\\n\", #symbol);      \\\n            return AEE_EUNABLETOLOAD;                                       \\\n        }                                                                   \\\n    } while (0)\n\n    dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);\n    dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);\n    dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);\n    dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);\n    dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);\n    dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);\n    dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);\n    
dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);\n    dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);\n    dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);\n    dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);\n    dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);\n    dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);\n    dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);\n    dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);\n    dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);\n    dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);\n\n    lib_cdsp_rpc_handle = std::move(handle);\n    initialized         = true;\n\n    return AEE_SUCCESS;\n}\n\ndomain * get_domain(int domain_id) {\n    int i    = 0;\n    int size = sizeof(supported_domains) / sizeof(domain);\n\n    for (i = 0; i < size; i++) {\n        if (supported_domains[i].id == domain_id) {\n            return &supported_domains[i];\n        }\n    }\n\n    return NULL;\n}\n\nint get_hex_arch_ver(int domain, int * arch) {\n    if (!remote_handle_control_pfn) {\n        GGML_LOG_ERROR(\"ggml-hex: remote_handle_control is not supported on this device\\n\");\n        return AEE_EUNSUPPORTEDAPI;\n    }\n\n    struct remote_dsp_capability arch_ver;\n    arch_ver.domain       = (uint32_t) domain;\n    arch_ver.attribute_ID = ARCH_VER;\n    arch_ver.capability   = (uint32_t) 0;\n\n    int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));\n    if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {\n   
     GGML_LOG_ERROR(\"ggml-hex: FastRPC capability API is not supported on this device\\n\");\n        return AEE_EUNSUPPORTEDAPI;\n    }\n\n    if (err != AEE_SUCCESS) {\n        GGML_LOG_ERROR(\"ggml-hex: FastRPC capability query failed (err %d)\\n\", err);\n        return err;\n    }\n\n    switch (arch_ver.capability & 0xff) {\n        case 0x68:\n            *arch = 68;\n            return 0;\n        case 0x69:\n            *arch = 69;\n            return 0;\n        case 0x73:\n            *arch = 73;\n            return 0;\n        case 0x75:\n            *arch = 75;\n            return 0;\n        case 0x79:\n            *arch = 79;\n            return 0;\n        case 0x81:\n            *arch = 81;\n            return 0;\n    }\n    return -1;\n}\n"
  },
  {
    "path": "src/ggml-hexagon/htp-drv.h",
    "content": "#pragma once\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n#ifdef _WIN32\n#    pragma clang diagnostic ignored \"-Wignored-attributes\"\n#endif\n\n#include <AEEStdErr.h>\n#include <rpcmem.h>\n#include <remote.h>\n#include <dspqueue.h>\n\n#if defined(_WIN32) && !defined(__MINGW32__)\n#    ifdef GGML_BACKEND_BUILD\n#        define HTPDRV_API __declspec(dllexport) extern\n#    else\n#        define HTPDRV_API __declspec(dllimport) extern\n#    endif\n#else\n#    define HTPDRV_API __attribute__ ((visibility (\"default\"))) extern\n#endif\n\n/* Offset to differentiate HLOS and Hexagon error codes.\n   Stores the value of AEE_EOFFSET for Hexagon. */\n#ifndef DSP_OFFSET\n#    define DSP_OFFSET 0x80000400\n#endif\n\n/* Errno for connection reset by peer. */\n#ifndef ECONNRESET\n#    ifdef __hexagon__\n#        define ECONNRESET 104\n#    endif\n#endif\n\n/* Abstraction of different OS specific sleep APIs.\n   SLEEP accepts input in seconds. */\n#ifndef SLEEP\n#    ifdef __hexagon__\n#        define SLEEP(x)                      \\\n            { /* Do nothing for simulator. */ \\\n            }\n#    else\n#        ifdef _WIN32\n#            define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */\n#        else\n#            define SLEEP(x) sleep(x)        /* sleep accepts input in seconds. */\n#        endif\n#    endif\n#endif\n\n/* Include windows specific header files. */\n#ifdef _WIN32\n#    include <windows.h>\n#    include <sysinfoapi.h>\n#    define _CRT_SECURE_NO_WARNINGS         1\n#    define _WINSOCK_DEPRECATED_NO_WARNINGS 1\n#endif\n\n/* Includes and defines for all HLOS except windows */\n#if !defined(__hexagon__) && !defined(_WIN32)\n#    include \"unistd.h\"\n\n#    include <sys/time.h>\n#endif\n\n/* Includes and defines for Hexagon and all HLOS except Windows. */\n#if !defined(_WIN32)\n/* Weak reference to remote symbol for compilation. 
*/\n#    pragma weak remote_session_control\n#    pragma weak remote_handle_control\n#    pragma weak remote_handle64_control\n#    pragma weak fastrpc_mmap\n#    pragma weak fastrpc_munmap\n#    pragma weak rpcmem_alloc2\n#endif\n\n#if !defined(_WIN32)\n#    pragma weak remote_system_request\n#endif\n\n#ifdef _WIN32\n#     define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE\n#else\n#     define DSPQUEUE_TIMEOUT 1000000\n#endif\n\n/**\n * htpdrv_init API: driver interface entry point\n *\n * @return      Return AEE error codes as defined in Hexagon SDK.\n */\nHTPDRV_API int htpdrv_init(void);\n\n/**\n * get_domain API: get domain struct from domain value.\n *\n * @param[in]  domain value of a domain\n * @return     Returns domain struct of the domain if it is supported or else\n *             returns NULL.\n *\n */\nHTPDRV_API domain * get_domain(int domain_id);\n\n/**\n * get_hex_arch_ver API: query the Hexagon processor architecture version information\n *\n * @param[in]   domain_id value of a domain\n * @param[out]  Arch version (73, 75, ...)\n * @return      0 if query is successful.\n *              non-zero if error, return value points to the error.\n *\n */\nHTPDRV_API int get_hex_arch_ver(int domain, int * arch);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-hexagon/libdl.h",
    "content": "#pragma once\n\n#ifdef _WIN32\n#   define WIN32_LEAN_AND_MEAN\n#   ifndef NOMINMAX\n#       define NOMINMAX\n#   endif\n#   include <windows.h>\n#   include <winevt.h>\n#else\n#    include <dlfcn.h>\n#    include <unistd.h>\n#endif\n#include <filesystem>\n\nnamespace fs = std::filesystem;\n\n#ifdef _WIN32\n\nusing dl_handle = std::remove_pointer_t<HMODULE>;\n\nstruct dl_handle_deleter {\n    void operator()(HMODULE handle) {\n        FreeLibrary(handle);\n    }\n};\n\nstatic inline dl_handle * dl_load_library(const fs::path & path) {\n    // suppress error dialogs for missing DLLs\n    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);\n    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);\n\n    HMODULE handle = LoadLibraryW(path.wstring().c_str());\n\n    SetErrorMode(old_mode);\n\n    return handle;\n}\n\nstatic inline void * dl_get_sym(dl_handle * handle, const char * name) {\n    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);\n    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);\n\n    void * p = (void *) GetProcAddress(handle, name);\n\n    SetErrorMode(old_mode);\n\n    return p;\n}\n\nstatic inline const char * dl_error() {\n    return \"\";\n}\n\n#else\n\nusing dl_handle = void;\n\nstruct dl_handle_deleter {\n    void operator()(void * handle) {\n        dlclose(handle);\n    }\n};\n\nstatic inline dl_handle * dl_load_library(const fs::path & path) {\n    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);\n    return handle;\n}\n\nstatic inline void * dl_get_sym(dl_handle * handle, const char * name) {\n    return dlsym(handle, name);\n}\n\nstatic inline const char * dl_error() {\n    const char *rslt = dlerror();\n    return rslt != nullptr ? rslt : \"\";\n}\n\n#endif\n"
  },
  {
    "path": "src/ggml-hexagon/libggml-htp.inf",
    "content": "[Version]\nSignature   = \"$WINDOWS NT$\"\nClass       = ComputeAccelerator\nClassGuid   = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C}\nProvider    = %GGML%\nDriverVer   = 01/01/2026,1.0.0.0\nCatalogFile = libggml-htp.cat\nPnpLockDown = 1\n\n[DestinationDirs]\nDrivers_Dir = 6\n\n[SourceDisksNames]\n1 = %DiskId%\n\n[SourceDisksFiles]\nlibggml-htp-v68.so = 1\nlibggml-htp-v69.so = 1\nlibggml-htp-v73.so = 1\nlibggml-htp-v75.so = 1\nlibggml-htp-v81.so = 1\n\n[ControlFlags]\nExcludeFromSelect = *\n\n[DefaultInstall.NTarm64]\nCopyFiles=Drivers_Dir\n\n[Drivers_Dir]\nlibggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE\nlibggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE\nlibggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE\nlibggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE\nlibggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE\n\n[Strings]\nGGML   = 'GGML'\nDiskId = 'GGML HTP library'\n"
  },
  {
    "path": "src/ggml-hexagon/op-desc.h",
    "content": "#ifndef OP_DESC_H\n#define OP_DESC_H\n\n#define GGML_COMMON_IMPL_CPP\n#include \"ggml-backend-impl.h\"\n#include \"ggml-common.h\"\n\n#include <string>\n#include <stdio.h>\n\nstruct op_desc {\n    char strides[64 * GGML_MAX_SRC];\n    char dims[64 * GGML_MAX_SRC];\n    char types[16 * GGML_MAX_SRC];\n    char buffs[64 * GGML_MAX_SRC];\n    char names[64 * GGML_MAX_SRC];\n\n    int format_tensor_dims(char * str, const struct ggml_tensor * t) {\n        if (t->ne[2] == 1 && t->ne[3] == 1) {\n            return sprintf(str, \"%d:%d\", (int) t->ne[0], (int) t->ne[1]);\n        } else {\n            return sprintf(str, \"%d:%d:%d:%d\", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);\n        }\n    }\n\n    void format_op_dims(char * str, const struct ggml_tensor * t) {\n        char * p = str;\n\n        // append src0 and src1 (if any)\n        if (t->src[0]) {\n            p += format_tensor_dims(p, t->src[0]);\n\n            for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {\n                p += sprintf(p, \" x \");\n                p += format_tensor_dims(p, t->src[i]);\n            }\n\n            p += sprintf(p, \" -> \");\n        }\n\n        // format self dims separately for better visual alignment\n        char self[64];\n        format_tensor_dims(self, t);\n\n        p += sprintf(p, \"%s\", self);\n    }\n\n    int format_tensor_strides(char * str, const struct ggml_tensor * t) {\n        const char * c = ggml_is_contiguous(t) ? 
\"\" : \"!\";\n\n        if (t->ne[2] == 1 && t->ne[3] == 1) {\n            return sprintf(str, \"%zu:%zu%s\", (size_t) t->nb[0], (size_t) t->nb[1], c);\n        } else {\n            return sprintf(str, \"%zu:%zu:%zu:%zu%s\", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);\n        }\n    }\n\n    void format_op_strides(char * str, const struct ggml_tensor * t) {\n        char * p = str;\n\n        // append src0 and src1 (if any)\n        if (t->src[0]) {\n            p += format_tensor_strides(p, t->src[0]);\n\n            for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {\n                p += sprintf(p, \" x \");\n                p += format_tensor_strides(p, t->src[i]);\n            }\n\n            p += sprintf(p, \" -> \");\n        }\n\n        // format self dims separately for better visual alignment\n        char self[64];\n        format_tensor_strides(self, t);\n\n        p += sprintf(p, \"%s\", self);\n    }\n\n    void format_op_types(char * str, const struct ggml_tensor * t) {\n        char * p = str;\n\n        // append src0 and src1 (if any)\n        if (t->src[0]) {\n            p += sprintf(p, \"%s\", ggml_type_name(t->src[0]->type));\n\n            for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {\n                p += sprintf(p, \" x \");\n                p += sprintf(p, \"%s\", ggml_type_name(t->src[i]->type));\n            }\n\n            p += sprintf(p, \" -> \");\n        }\n\n        p += sprintf(p, \"%s\", ggml_type_name(t->type));\n    }\n\n    const char * tensor_buff_name(const struct ggml_tensor * t) {\n        if (t->buffer) {\n            return ggml_backend_buffer_name(t->buffer);\n        }\n        return \"NONE\";\n    }\n\n    void format_op_buffs(char * str, const struct ggml_tensor * t) {\n        char * p = str;\n\n        // append src0 and src1 (if any)\n        if (t->src[0]) {\n            p += sprintf(p, \"%s\", tensor_buff_name(t->src[0]));\n\n            for (int i = 1; i < 
GGML_MAX_SRC && t->src[i]; i++) {\n                p += sprintf(p, \" x \");\n                p += sprintf(p, \"%s\", tensor_buff_name(t->src[i]));\n            }\n\n            p += sprintf(p, \" -> \");\n        }\n\n        p += sprintf(p, \"%s\", tensor_buff_name(t));\n    }\n\n    void format_op_names(char * str, const struct ggml_tensor * t) {\n        char * p = str;\n\n        // append src0 and src1 (if any)\n        if (t->src[0]) {\n            p += sprintf(p, \"%s\", t->src[0]->name);\n\n            for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {\n                p += sprintf(p, \" x \");\n                p += sprintf(p, \"%s\", t->src[i]->name);\n            }\n\n            p += sprintf(p, \" -> \");\n        }\n\n        p += sprintf(p, \"%s\", t->name);\n    }\n\n    void format(const ggml_tensor * op) {\n        format_op_dims(dims, op);\n        format_op_strides(strides, op);\n        format_op_types(types, op);\n        format_op_buffs(buffs, op);\n        format_op_names(names, op);\n    }\n\n    op_desc() {}\n    op_desc(const ggml_tensor * op) { format(op); }\n};\n\n#endif // OP_DESC_H\n"
  },
  {
    "path": "src/ggml-hip/CMakeLists.txt",
    "content": "if (NOT EXISTS $ENV{ROCM_PATH})\n    if (NOT EXISTS /opt/rocm)\n        set(ROCM_PATH /usr)\n    else()\n        set(ROCM_PATH /opt/rocm)\n    endif()\nelse()\n    set(ROCM_PATH $ENV{ROCM_PATH})\nendif()\n\nlist(APPEND CMAKE_PREFIX_PATH  ${ROCM_PATH})\nlist(APPEND CMAKE_PREFIX_PATH \"${ROCM_PATH}/lib64/cmake\")\n\nif (NOT DEFINED CMAKE_HIP_FLAGS_DEBUG)\n    set(CMAKE_HIP_FLAGS_DEBUG \"-g -O2\")\nendif()\n\n# CMake on Windows doesn't support the HIP language yet\nif (WIN32)\n    set(CXX_IS_HIPCC TRUE)\nelse()\n    string(REGEX MATCH \"hipcc(\\.bat)?$\" CXX_IS_HIPCC \"${CMAKE_CXX_COMPILER}\")\nendif()\n\nif (CXX_IS_HIPCC)\n    if (LINUX)\n        if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES \"Clang\")\n            message(WARNING \"Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++\")\n        endif()\n\n        message(WARNING \"Setting hipcc as the C++ compiler is legacy behavior.\"\n                \" Prefer setting the HIP compiler directly. See README for details.\")\n    endif()\nelse()\n    # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.\n    if(AMDGPU_TARGETS AND NOT GPU_TARGETS)\n        set(GPU_TARGETS ${AMDGPU_TARGETS})\n    endif()\n    if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)\n        set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})\n    endif()\n    cmake_minimum_required(VERSION 3.21)\n    enable_language(HIP)\nendif()\n\nfind_package(hip     REQUIRED)\nfind_package(hipblas REQUIRED)\nfind_package(rocblas REQUIRED)\n\nif (${hip_VERSION} VERSION_LESS 6.1)\n    message(FATAL_ERROR \"At least ROCM/HIP V6.1 is required\")\nendif()\n\nmessage(STATUS \"HIP and hipBLAS found\")\n\n# Workaround old compilers\nset(CMAKE_HIP_FLAGS \"${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024\")\n\nfile(GLOB   GGML_HEADERS_ROCM \"../ggml-cuda/*.cuh\")\nlist(APPEND GGML_HEADERS_ROCM \"../../include/ggml-cuda.h\")\n\nfile(GLOB   GGML_SOURCES_ROCM \"../ggml-cuda/*.cu\")\nfile(GLOB   SRCS 
\"../ggml-cuda/template-instances/fattn-tile*.cu\")\nlist(APPEND GGML_SOURCES_ROCM ${SRCS})\nfile(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-mma*.cu\")\nlist(APPEND GGML_SOURCES_ROCM ${SRCS})\nfile(GLOB   SRCS \"../ggml-cuda/template-instances/mmq*.cu\")\nlist(APPEND GGML_SOURCES_ROCM ${SRCS})\nfile(GLOB   SRCS \"../ggml-cuda/template-instances/mmf*.cu\")\nlist(APPEND GGML_SOURCES_ROCM ${SRCS})\n\nif (GGML_CUDA_FA_ALL_QUANTS)\n    file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-vec*.cu\")\n    list(APPEND GGML_SOURCES_ROCM ${SRCS})\n    add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)\nelse()\n    file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu\")\n    list(APPEND GGML_SOURCES_ROCM ${SRCS})\n    file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu\")\n    list(APPEND GGML_SOURCES_ROCM ${SRCS})\n    file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-vec*f16-f16.cu\")\n    list(APPEND GGML_SOURCES_ROCM ${SRCS})\nendif()\n\nggml_add_backend_library(ggml-hip\n                         ${GGML_HEADERS_ROCM}\n                         ${GGML_SOURCES_ROCM}\n                        )\n\n# TODO: do not use CUDA definitions for HIP\nif (NOT GGML_BACKEND_DL)\n    target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)\nendif()\n\nadd_compile_definitions(GGML_USE_HIP)\n\nif (GGML_CUDA_FORCE_MMQ)\n    add_compile_definitions(GGML_CUDA_FORCE_MMQ)\nendif()\n\nif (GGML_CUDA_FORCE_CUBLAS)\n    add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)\nendif()\n\nif (GGML_CUDA_NO_PEER_COPY)\n    add_compile_definitions(GGML_CUDA_NO_PEER_COPY)\nendif()\n\nif (GGML_HIP_GRAPHS)\n    add_compile_definitions(GGML_HIP_GRAPHS)\nendif()\n\nif (GGML_HIP_NO_VMM)\n    add_compile_definitions(GGML_HIP_NO_VMM)\nendif()\n\nif (GGML_HIP_ROCWMMA_FATTN)\n    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)\nendif()\n\nif (NOT GGML_HIP_MMQ_MFMA)\n    add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)\nendif()\n\nif (GGML_HIP_EXPORT_METRICS)\n    
set(CMAKE_HIP_FLAGS \"${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps\")\nendif()\n\nif (NOT GGML_CUDA_FA)\n    add_compile_definitions(GGML_CUDA_NO_FA)\nendif()\n\nif (CXX_IS_HIPCC)\n    set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)\n    target_link_libraries(ggml-hip PRIVATE hip::device)\nelse()\n    set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)\nendif()\n\nif (GGML_STATIC)\n    message(FATAL_ERROR \"Static linking not supported for HIP/ROCm\")\nendif()\n\ntarget_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas)\n"
  },
  {
    "path": "src/ggml-impl.h",
    "content": "#pragma once\n\n// GGML internal header\n\n#include \"ggml.h\"\n#include \"gguf.h\"\n\n#include <assert.h>\n#include <math.h>\n#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/\n#include <stdbool.h>\n#include <stdint.h>\n#include <string.h>\n\n#ifdef __ARM_FEATURE_SVE\n#include <arm_sve.h>\n#endif // __ARM_FEATURE_SVE\n\n#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)\n// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:\n//\n//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/\n//\n#include <arm_neon.h>\n#endif\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nvoid ggml_print_backtrace(void);\n\n#ifndef MIN\n#    define MIN(a, b) ((a) < (b) ? (a) : (b))\n#endif\n\n#ifndef MAX\n#    define MAX(a, b) ((a) > (b) ? (a) : (b))\n#endif\n\n// required for mmap as gguf only guarantees 32-byte alignment\n#define TENSOR_ALIGNMENT 32\n\n// static_assert should be a #define, but if it's not,\n// fall back to the _Static_assert C11 keyword.\n// if C99 - static_assert is noop\n// ref: https://stackoverflow.com/a/53923785/4039976\n#ifndef __cplusplus\n    #ifndef static_assert\n        #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)\n            #define static_assert(cond, msg) _Static_assert(cond, msg)\n        #else\n            #define static_assert(cond, msg) struct global_scope_noop_trick\n        #endif\n    #endif\n#endif\n\nstatic inline int ggml_up32(int n) {\n    return (n + 31) & ~31;\n}\n\n//static inline int ggml_up64(int n) {\n//    return (n + 63) & ~63;\n//}\n\nstatic inline int ggml_up(int n, int m) {\n    // assert m is a power of 2\n    GGML_ASSERT((m & (m - 1)) == 0);\n    return (n + m - 1) & ~(m - 1);\n}\n\n// TODO: move to ggml.h? 
(won't be able to inline)\nstatic bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {\n    if (a->type != b->type) {\n        return false;\n    }\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        if (a->ne[i] != b->ne[i]) {\n            return false;\n        }\n        if (a->nb[i] != b->nb[i]) {\n            return false;\n        }\n    }\n    return true;\n}\n\nstatic bool ggml_op_is_empty(enum ggml_op op) {\n    switch (op) {\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n            return true;\n        default:\n            return false;\n    }\n}\n\nstatic inline bool ggml_impl_is_view(const struct ggml_tensor * t) {\n    return t->view_src != NULL;\n}\n\nstatic inline float ggml_compute_softplus_f32(float input) {\n    return (input > 20.0f) ? input : logf(1 + expf(input));\n}\n//\n// logging\n//\n\nGGML_ATTRIBUTE_FORMAT(2, 3)\nGGML_API void ggml_log_internal        (enum ggml_log_level level, const char * format, ...);\nGGML_API void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data);\n\n#define GGML_LOG(...)       ggml_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__)\n#define GGML_LOG_INFO(...)  ggml_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)\n#define GGML_LOG_WARN(...)  ggml_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)\n#define GGML_LOG_ERROR(...) ggml_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)\n#define GGML_LOG_DEBUG(...) ggml_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)\n#define GGML_LOG_CONT(...)  ggml_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)\n\n#define GGML_DEBUG 0\n\n#if (GGML_DEBUG >= 1)\n#define GGML_PRINT_DEBUG(...) GGML_LOG_DEBUG(__VA_ARGS__)\n#else\n#define GGML_PRINT_DEBUG(...)\n#endif\n\n#if (GGML_DEBUG >= 5)\n#define GGML_PRINT_DEBUG_5(...) 
GGML_LOG_DEBUG(__VA_ARGS__)\n#else\n#define GGML_PRINT_DEBUG_5(...)\n#endif\n\n#if (GGML_DEBUG >= 10)\n#define GGML_PRINT_DEBUG_10(...) GGML_LOG_DEBUG(__VA_ARGS__)\n#else\n#define GGML_PRINT_DEBUG_10(...)\n#endif\n\n// tensor params\n\nstatic void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {\n    GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings\n    assert(params_size <= GGML_MAX_OP_PARAMS);\n    memcpy(tensor->op_params, params, params_size);\n}\n\nstatic int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {\n    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));\n    return ((const int32_t *)(tensor->op_params))[i];\n}\n\nstatic float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {\n    assert(i < GGML_MAX_OP_PARAMS / sizeof(float));\n    return ((const float *)(tensor->op_params))[i];\n}\n\nstatic void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {\n    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));\n    ((int32_t *)(tensor->op_params))[i] = value;\n}\n\nstatic void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) {\n    assert(i < GGML_MAX_OP_PARAMS / sizeof(float));\n    ((float *)(tensor->op_params))[i] = value;\n}\n\nstruct ggml_map_custom1_op_params {\n    ggml_custom1_op_t  fun;\n    int                n_tasks;\n    void             * userdata;\n};\n\nstruct ggml_map_custom2_op_params {\n    ggml_custom2_op_t   fun;\n    int                 n_tasks;\n    void              * userdata;\n};\n\nstruct ggml_map_custom3_op_params {\n    ggml_custom3_op_t fun;\n    int               n_tasks;\n    void            * userdata;\n};\n\nstruct ggml_custom_op_params {\n    ggml_custom_op_t fun;\n    int              n_tasks;\n    void           * userdata;\n};\n\n// bitset\n\ntypedef uint32_t ggml_bitset_t;\n\nstatic_assert(sizeof(ggml_bitset_t) == 4, \"bitset_t constants must be 
updated\");\n#define BITSET_SHR 5 // log2(sizeof(ggml_bitset_t)*8)\n#define BITSET_MASK (sizeof(ggml_bitset_t)*8 - 1)\n\nstatic size_t ggml_bitset_size(size_t n) {\n    return (n + BITSET_MASK) >> BITSET_SHR;\n}\n\nstatic inline bool ggml_bitset_get(const ggml_bitset_t * bitset, size_t i) {\n    return !!(bitset[i >> BITSET_SHR] & (1u << (i & BITSET_MASK)));\n}\n\nstatic inline void ggml_bitset_set(ggml_bitset_t * bitset, size_t i) {\n    bitset[i >> BITSET_SHR] |= (1u << (i & BITSET_MASK));\n}\n\nstatic inline void ggml_bitset_clear(ggml_bitset_t * bitset, size_t i) {\n    bitset[i >> BITSET_SHR] &= ~(1u << (i & BITSET_MASK));\n}\n\n// hash set\n\n#define GGML_HASHSET_FULL ((size_t)-1)\n#define GGML_HASHSET_ALREADY_EXISTS ((size_t)-2)\n\nstruct ggml_hash_set {\n    size_t size;\n    ggml_bitset_t * used;       // whether or not the keys are in use i.e. set\n    struct ggml_tensor ** keys; // actual tensors in the set, keys[i] is only defined if ggml_bitset_get(used, i)\n};\n\nstruct ggml_hash_set ggml_hash_set_new(size_t size);\nvoid                 ggml_hash_set_free(struct ggml_hash_set * hash_set);\n\n// returns the minimum size for a hash set that can hold min_sz elements\nsize_t ggml_hash_size(size_t min_sz);\n\n// remove all elements from the hash set\nvoid ggml_hash_set_reset(struct ggml_hash_set * hash_set);\n\n// returns true if key is in the hash set\nstatic bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key);\n\n// returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted\nstatic size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key);\n\n// returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full\nstatic size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key);\n\n// return index, asserts if table is full\nstatic size_t ggml_hash_find_or_insert(struct 
ggml_hash_set * hash_set, struct ggml_tensor * key);\n\n// hash function for ggml_tensor\nstatic inline size_t ggml_hash(const struct ggml_tensor * p) {\n    // the last 4 bits are always zero due to alignment\n    return (size_t)(uintptr_t)p >> 4;\n}\n\nstatic size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) {\n    size_t h = ggml_hash(key) % hash_set->size;\n\n    // linear probing\n    size_t i = h;\n    while (ggml_bitset_get(hash_set->used, i) && hash_set->keys[i] != key) {\n        i = (i + 1) % hash_set->size;\n        if (i == h) {\n            // visited all hash table entries -> not found\n            return GGML_HASHSET_FULL;\n        }\n    }\n    return i;\n}\n\nstatic bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) {\n    size_t i = ggml_hash_find(hash_set, key);\n    return i != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, i);\n}\n\nstatic size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) {\n    size_t h = ggml_hash(key) % hash_set->size;\n\n    // linear probing\n    size_t i = h;\n    do {\n        if (!ggml_bitset_get(hash_set->used, i)) {\n            ggml_bitset_set(hash_set->used, i);\n            hash_set->keys[i] = key;\n            return i;\n        }\n        if (hash_set->keys[i] == key) {\n            return GGML_HASHSET_ALREADY_EXISTS;\n        }\n        i = (i + 1) % hash_set->size;\n    } while (i != h);\n\n    // visited all hash table entries -> not found\n    GGML_ABORT(\"fatal error\");\n}\n\nstatic size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) {\n    size_t h = ggml_hash(key) % hash_set->size;\n\n    // linear probing\n    size_t i = h;\n    do {\n        if (!ggml_bitset_get(hash_set->used, i)) {\n            ggml_bitset_set(hash_set->used, i);\n            hash_set->keys[i] = key;\n            return i;\n        }\n        if (hash_set->keys[i] == key) {\n    
        return i;\n        }\n        i = (i + 1) % hash_set->size;\n    } while (i != h);\n\n    // visited all hash table entries -> not found\n    GGML_ABORT(\"fatal error\");\n}\n\n// computation graph\n\nenum ggml_cgraph_eval_order {\n    GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,\n    GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,\n    GGML_CGRAPH_EVAL_ORDER_COUNT\n};\n\nstruct ggml_cgraph {\n    int size;    // maximum number of nodes/leafs/grads/grad_accs\n    int n_nodes; // number of nodes currently in use\n    int n_leafs; // number of leafs currently in use\n\n    struct ggml_tensor ** nodes;     // tensors with data that can change if the graph is evaluated\n    struct ggml_tensor ** grads;     // the outputs of these tensors are the gradients of the nodes\n    struct ggml_tensor ** grad_accs; // accumulators for node gradients\n    struct ggml_tensor ** leafs;     // tensors with constant data\n    int32_t             * use_counts;// number of uses of each tensor, indexed by hash table slot\n\n    struct ggml_hash_set visited_hash_set;\n\n    enum ggml_cgraph_eval_order order;\n};\n\n// returns a slice of cgraph with nodes [i0, i1)\n// the slice does not have leafs or gradients\n// if you need the gradients, get them from the original graph\nstruct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);\n\n// ggml-alloc.c: true if the operation can reuse memory from its sources\nGGML_API bool ggml_op_can_inplace(enum ggml_op op);\n\n\n// Memory allocation\n\nGGML_API void * ggml_aligned_malloc(size_t size);\nGGML_API void ggml_aligned_free(void * ptr, size_t size);\n\n// FP16 <-> FP32\n// ref: https://github.com/Maratyszcza/FP16\n\nstatic inline float fp32_from_bits(uint32_t w) {\n    union {\n        uint32_t as_bits;\n        float as_value;\n    } fp32;\n    fp32.as_bits = w;\n    return fp32.as_value;\n}\n\nstatic inline uint32_t fp32_to_bits(float f) {\n    union {\n        float as_value;\n        uint32_t as_bits;\n    } fp32;\n    
fp32.as_value = f;\n    return fp32.as_bits;\n}\n\nstatic inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {\n    const uint32_t w = (uint32_t) h << 16;\n    const uint32_t sign = w & UINT32_C(0x80000000);\n    const uint32_t two_w = w + w;\n\n    const uint32_t exp_offset = UINT32_C(0xE0) << 23;\n#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)\n    const float exp_scale = 0x1.0p-112f;\n#else\n    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));\n#endif\n    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;\n\n    const uint32_t magic_mask = UINT32_C(126) << 23;\n    const float magic_bias = 0.5f;\n    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;\n\n    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;\n    const uint32_t result = sign |\n        (two_w < denormalized_cutoff ? 
fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));\n    return fp32_from_bits(result);\n}\n\nstatic inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {\n#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)\n    const float scale_to_inf = 0x1.0p+112f;\n    const float scale_to_zero = 0x1.0p-110f;\n#else\n    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));\n    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));\n#endif\n    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;\n\n    const uint32_t w = fp32_to_bits(f);\n    const uint32_t shl1_w = w + w;\n    const uint32_t sign = w & UINT32_C(0x80000000);\n    uint32_t bias = shl1_w & UINT32_C(0xFF000000);\n    if (bias < UINT32_C(0x71000000)) {\n        bias = UINT32_C(0x71000000);\n    }\n\n    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;\n    const uint32_t bits = fp32_to_bits(base);\n    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);\n    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);\n    const uint32_t nonsign = exp_bits + mantissa_bits;\n    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? 
UINT16_C(0x7E00) : nonsign);\n}\n\n#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)\n#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)\n\n#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)\n#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)\n\nstatic inline float ggml_e8m0_to_fp32(uint8_t x) {\n    uint32_t bits;  // Stores the raw bit representation of the float\n\n    // Handle special case for minimum exponent (denormalized float)\n    if (x == 0) {\n        // Bit pattern for 2^(-127):\n        // - Sign bit: 0 (positive)\n        // - Exponent: 0 (denormalized number)\n        // - Mantissa: 0x400000 (0.5 in fractional form)\n        // Value = 0.5 * 2^(-126) = 2^(-127)\n        bits = 0x00400000;\n    }\n    // note: disabled as we don't need to handle NaNs\n    //// Handle special case for NaN (all bits set)\n    //else if (x == 0xFF) {\n    //    // Standard quiet NaN pattern:\n    //    // - Sign bit: 0\n    //    // - Exponent: all 1s (0xFF)\n    //    // - Mantissa: 0x400000 (quiet NaN flag)\n    //    bits = 0x7FC00000;\n    //}\n    // Normalized values (most common case)\n    else {\n        // Construct normalized float by shifting exponent into position:\n        // - Exponent field: 8 bits (positions 30-23)\n        // - Mantissa: 0 (implicit leading 1)\n        // Value = 2^(x - 127)\n        bits = (uint32_t) x << 23;\n    }\n\n    float result;  // Final float value\n                   // Safely reinterpret bit pattern as float without type-punning issues\n    memcpy(&result, &bits, sizeof(float));\n    return result;\n}\n\n// Equal to ggml_e8m0_to_fp32/2\n// Useful with MXFP4 quantization since the E0M2 values are doubled\nstatic inline float ggml_e8m0_to_fp32_half(uint8_t x) {\n    uint32_t bits;\n\n    // For x < 2: use precomputed denormal patterns\n    if (x < 2) {\n        // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)\n        bits = 0x00200000 << x;\n    }\n    // For x >= 2: normalized 
exponent adjustment\n    else {\n        // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)\n        bits = (uint32_t)(x - 1) << 23;\n    }\n    // Note: NaNs are not handled here\n\n    float result;\n    memcpy(&result, &bits, sizeof(float));\n    return result;\n}\n\n#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)\n#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)\n\n// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits\n// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float)\nstatic inline float ggml_ue4m3_to_fp32(uint8_t x) {\n    if (x == 0 || x == 0x7F) {\n        return 0.0f;\n    }\n    int   exp = (x >> 3) & 0xF;\n    int   man = x & 0x7;\n    float raw;\n    if (exp == 0) {\n        raw = ldexpf((float) man, -9);\n    } else {\n        raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);\n    }\n    return raw * 0.5f;\n}\n\nstatic inline uint8_t ggml_fp32_to_ue4m3(float x) {\n    if (!(x > 0.0f)) {\n        return 0;\n    }\n    if (x > 448.0f) {\n        x = 448.0f;\n    }\n    uint32_t bits;\n    memcpy(&bits, &x, 4);\n    int fp32_exp  = ((bits >> 23) & 0xFF) - 127;\n    int fp32_man  = (bits >> 20) & 0x7;\n    int ue4m3_exp = fp32_exp + 7;\n    if (ue4m3_exp <= 0) {\n        // subnormal: value = man * 2^-9, man = round(x * 2^9)\n        int man = (int) (x * 512.0f + 0.5f);\n        if (man > 7) {\n            man = 7;\n        }\n        if (man < 1) {\n            return 0;\n        }\n        return (uint8_t) man;\n    }\n    if (ue4m3_exp >= 15) {\n        return 0x7E;\n    }\n    int round_bit = (bits >> 19) & 1;\n    int ue4m3_man = fp32_man + round_bit;\n    if (ue4m3_man > 7) {\n        ue4m3_man = 0;\n        ue4m3_exp++;\n        if (ue4m3_exp >= 15) {\n            return 0x7E;\n        }\n    }\n    return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man);\n}\n\n/**\n * Converts brain16 to float32.\n *\n * The bfloat16 floating point format has the following structure:\n *\n *       ┌sign\n *     
  │\n *       │   ┌exponent\n *       │   │\n *       │   │      ┌mantissa\n *       │   │      │\n *       │┌──┴───┐┌─┴───┐\n *     0b0000000000000000 brain16\n *\n * Since bf16 has the same number of exponent bits as a 32bit float,\n * encoding and decoding numbers becomes relatively straightforward.\n *\n *       ┌sign\n *       │\n *       │   ┌exponent\n *       │   │\n *       │   │      ┌mantissa\n *       │   │      │\n *       │┌──┴───┐┌─┴───────────────────┐\n *     0b00000000000000000000000000000000 IEEE binary32\n *\n * For comparison, the standard fp16 format has fewer exponent bits.\n *\n *       ┌sign\n *       │\n *       │  ┌exponent\n *       │  │\n *       │  │    ┌mantissa\n *       │  │    │\n *       │┌─┴─┐┌─┴──────┐\n *     0b0000000000000000 IEEE binary16\n *\n * @see IEEE 754-2008\n */\nstatic inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {\n    union {\n        float f;\n        uint32_t i;\n    } u;\n    u.i = (uint32_t)h.bits << 16;\n    return u.f;\n}\n\n/**\n * Converts float32 to brain16.\n *\n * This is binary identical with Google Brain float conversion.\n * Floats shall round to nearest even, and NANs shall be quiet.\n * Subnormals aren't flushed to zero, except perhaps when used.\n * This code should vectorize nicely if using modern compilers.\n */\nstatic inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {\n    ggml_bf16_t h;\n    union {\n        float f;\n        uint32_t i;\n    } u;\n    u.f = s;\n    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */\n        h.bits = (u.i >> 16) | 64; /* force to quiet */\n        return h;\n    }\n    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;\n    return h;\n}\n\n#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)\n#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)\n\nstatic inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {\n    const struct ggml_tensor * node = cgraph->nodes[node_idx];\n\n    size_t hash_pos = 
ggml_hash_find(&cgraph->visited_hash_set, node);\n    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {\n        return 0;\n    }\n    return cgraph->use_counts[hash_pos];\n}\n\n// return true if the node's results are only used by N other nodes\n// and can be fused into their calculations.\nstatic inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {\n    const struct ggml_tensor * node = cgraph->nodes[node_idx];\n\n    // check the use count against how many we're replacing\n    if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {\n        return false;\n    }\n\n    // if node is a view, some other node might be using the intermediate result\n    // via the view source.\n    if (node->view_src) {\n        return false;\n    }\n\n    // If the user requested output for the node, can't fuse\n    if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {\n        return false;\n    }\n\n    return true;\n}\n\n// Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[]\n// and are fusable. 
Nodes are considered fusable according to this function if:\n// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).\n// - all nodes except the last are a src of the following node.\n// - all nodes are the same shape.\n// TODO: Consider allowing GGML_OP_NONE nodes in between\nstatic inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) {\n    for (int i = 0; i < num_ops; ++i) {\n        if (node_idxs[i] >= cgraph->n_nodes) {\n            return false;\n        }\n\n        struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];\n        if (node->op != ops[i]) {\n            return false;\n        }\n        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n            return false;\n        }\n        if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {\n            return false;\n        }\n        if (i > 0) {\n            struct ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];\n            if (node->src[0] != prev && node->src[1] != prev) {\n                return false;\n            }\n            if (!ggml_are_same_shape(node, prev)) {\n                return false;\n            }\n        }\n    }\n    return true;\n}\n\n// same as above, for sequential indices starting at node_idx\nstatic inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {\n    assert(num_ops < 32);\n\n    if (node_idx + num_ops > cgraph->n_nodes) {\n        return false;\n    }\n\n    int idxs[32];\n    for (int i = 0; i < num_ops; ++i) {\n        idxs[i] = node_idx + i;\n    }\n\n    return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);\n}\n\nGGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,\n                                         const int *                node_idxs,\n                                         int                        count,\n        
                                 const enum ggml_op *       ops,\n                                         const int *                outputs,\n                                         int                        num_outputs);\n\n// Returns true if the subgraph formed by {node_idxs} can be fused\n// checks whethers all nodes which are not part of outputs can be elided\n// by checking if their num_uses are confined to the subgraph\nstatic inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,\n                                          int                        node_idx,\n                                          int                        count,\n                                          const enum ggml_op *       ops,\n                                          const int *                outputs,\n                                          int                        num_outputs) {\n    GGML_ASSERT(count < 32);\n    if (node_idx + count > cgraph->n_nodes) {\n        return false;\n    }\n\n    int idxs[32];\n\n    for (int i = 0; i < count; ++i) {\n        idxs[i] = node_idx + i;\n    }\n\n    return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);\n}\n\n#ifdef __cplusplus\n}\n#endif\n\n#ifdef __cplusplus\n#include <array>\n#include <initializer_list>\n#include <vector>\n\n// nicer C++ syntax for ggml_can_fuse\ninline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {\n    return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());\n}\n\ninline bool ggml_can_fuse_subgraph(const struct ggml_cgraph *          cgraph,\n                                   int                                 start_idx,\n                                   std::initializer_list<enum ggml_op> ops,\n                                   std::initializer_list<int>          outputs = {}) {\n    return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), 
outputs.size());\n}\n\n// Return true if the edges in the graph match expectations.\ninline bool ggml_check_edges(const struct ggml_cgraph *                cgraph,\n                             int                                       start_idx,\n                             std::initializer_list<std::array<int, 3>> edges) {\n    for (const auto & edge : edges) {\n        int dst_node = edge[0];\n        int src_idx  = edge[1];\n        int src_node = edge[2];\n        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {\n            return false;\n        }\n    }\n    return true;\n}\n\n// expose GGUF internals for test code\nGGML_API size_t gguf_type_size(enum gguf_type type);\nGGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);\nGGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta);\n#endif // __cplusplus\n"
  },
  {
    "path": "src/ggml-metal/CMakeLists.txt",
    "content": "find_library(FOUNDATION_LIBRARY Foundation REQUIRED)\nfind_library(METAL_FRAMEWORK    Metal      REQUIRED)\nfind_library(METALKIT_FRAMEWORK MetalKit   REQUIRED)\n\nmessage(STATUS \"Metal framework found\")\n\nggml_add_backend_library(ggml-metal\n                         ggml-metal.cpp\n                         ggml-metal-device.m\n                         ggml-metal-device.cpp\n                         ggml-metal-common.cpp\n                         ggml-metal-context.m\n                         ggml-metal-ops.cpp\n                        )\n\ntarget_link_libraries(ggml-metal PRIVATE\n                      ${FOUNDATION_LIBRARY}\n                      ${METAL_FRAMEWORK}\n                      ${METALKIT_FRAMEWORK}\n                      )\n\nif (GGML_METAL_NDEBUG)\n    add_compile_definitions(GGML_METAL_NDEBUG)\nendif()\n\nset(METALLIB_COMMON \"${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h\")\nif (GGML_METAL_EMBED_LIBRARY)\n    enable_language(ASM)\n\n    add_compile_definitions(GGML_METAL_EMBED_LIBRARY)\n\n    set(METALLIB_SOURCE \"${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal\")\n    set(METALLIB_IMPL   \"${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h\")\n\n    file(MAKE_DIRECTORY \"${CMAKE_CURRENT_BINARY_DIR}/autogenerated\")\n\n    # merge ggml-common.h and ggml-metal.metal into a single file\n    set(METALLIB_EMBED_ASM        \"${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s\")\n    set(METALLIB_SOURCE_EMBED     \"${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal\")\n    set(METALLIB_SOURCE_EMBED_TMP \"${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp\")\n\n    add_custom_command(\n        OUTPUT \"${METALLIB_EMBED_ASM}\"\n        COMMAND echo \"Embedding Metal library\"\n        COMMAND sed -e \"/__embed_ggml-common.h__/r ${METALLIB_COMMON}\"       -e \"/__embed_ggml-common.h__/d\"         < \"${METALLIB_SOURCE}\"           > \"${METALLIB_SOURCE_EMBED_TMP}\"\n        COMMAND sed -e \"/\\#include 
\\\"ggml-metal-impl.h\\\"/r ${METALLIB_IMPL}\" -e \"/\\#include \\\"ggml-metal-impl.h\\\"/d\" < \"${METALLIB_SOURCE_EMBED_TMP}\" > \"${METALLIB_SOURCE_EMBED}\"\n        COMMAND echo \".section __DATA,__ggml_metallib\"          >  \"${METALLIB_EMBED_ASM}\"\n        COMMAND echo \".globl _ggml_metallib_start\"              >> \"${METALLIB_EMBED_ASM}\"\n        COMMAND echo \"_ggml_metallib_start:\"                    >> \"${METALLIB_EMBED_ASM}\"\n        COMMAND echo .incbin \"\\\"${METALLIB_SOURCE_EMBED}\\\"\"     >> \"${METALLIB_EMBED_ASM}\"\n        COMMAND echo \".globl _ggml_metallib_end\"                >> \"${METALLIB_EMBED_ASM}\"\n        COMMAND echo \"_ggml_metallib_end:\"                      >> \"${METALLIB_EMBED_ASM}\"\n        DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h\n        COMMENT \"Generate assembly for embedded Metal library\"\n        VERBATIM\n    )\n\n    target_sources(ggml-metal PRIVATE \"${METALLIB_EMBED_ASM}\")\nelse()\n    # copy metal files to bin directory\n    configure_file(../ggml-common.h  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h     COPYONLY)\n    configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)\n    configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)\n\n    if (GGML_METAL_SHADER_DEBUG)\n        # custom command to do the following:\n        #   xcrun -sdk macosx metal    -fno-fast-math -c ggml-metal.metal -o ggml-metal.air\n        #   xcrun -sdk macosx metallib                   ggml-metal.air   -o default.metallib\n        #\n        # note: this is the only way I found to disable fast-math in Metal. 
it's ugly, but at least it works\n        #       disabling fast math is needed in order to pass tests/test-backend-ops\n        # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1\n        # note: unfortunately, we have to call it default.metallib instead of ggml.metallib\n        #       ref: https://github.com/ggml-org/whisper.cpp/issues/1720\n        # note: adding -g causes segmentation fault during compile\n        #set(XC_FLAGS -fno-fast-math -fno-inline -g)\n        set(XC_FLAGS -fno-fast-math -fno-inline)\n    else()\n        set(XC_FLAGS -O3)\n    endif()\n\n    # Append macOS metal versioning flags\n    if (GGML_METAL_MACOSX_VERSION_MIN)\n        message(STATUS \"Adding  -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation\")\n        list   (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})\n    endif()\n\n    if (GGML_METAL_STD)\n        message(STATUS \"Adding  -std=${GGML_METAL_STD} flag to metal compilation\")\n        list   (APPEND XC_FLAGS -std=${GGML_METAL_STD})\n    endif()\n\n    add_custom_command(\n        OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib\n        COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |\n                xcrun -sdk macosx metallib        - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib\n        COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h\n        COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal\n        DEPENDS ggml-metal.metal ${METALLIB_COMMON}\n        COMMENT \"Compiling Metal kernels\"\n        )\n\n    # FIXME: only add to the ggml-metal target?\n    add_custom_target(\n        ggml-metal-lib ALL\n        DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib\n        )\nendif() # GGML_METAL_EMBED_LIBRARY\n\nif (NOT GGML_METAL_EMBED_LIBRARY)\n    install(\n        FILES src/ggml-metal/ggml-metal.metal\n        PERMISSIONS\n         
   OWNER_READ\n            OWNER_WRITE\n            GROUP_READ\n            WORLD_READ\n        DESTINATION ${CMAKE_INSTALL_BINDIR})\n\n        install(\n            FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib\n            DESTINATION ${CMAKE_INSTALL_BINDIR}\n        )\nendif()\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-common.cpp",
    "content": "#include \"ggml-metal-common.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n\n#include <vector>\n\n// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)\n// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)\nstruct ggml_mem_range {\n    uint64_t pb; // buffer id\n\n    uint64_t p0; // begin\n    uint64_t p1; // end\n\n    ggml_mem_range_type pt;\n};\n\nstruct ggml_mem_ranges {\n    std::vector<ggml_mem_range> ranges;\n\n    int debug = 0;\n};\n\nggml_mem_ranges_t ggml_mem_ranges_init(int debug) {\n    auto * res = new ggml_mem_ranges;\n\n    res->ranges.reserve(256);\n    res->debug = debug;\n\n    return res;\n}\n\nvoid ggml_mem_ranges_free(ggml_mem_ranges_t mrs) {\n    delete mrs;\n}\n\nvoid ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) {\n    mrs->ranges.clear();\n}\n\nstatic bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) {\n    mrs->ranges.push_back(mr);\n\n    return true;\n}\n\nstatic ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {\n    // always use the base tensor\n    tensor = tensor->view_src ? 
tensor->view_src : tensor;\n\n    GGML_ASSERT(!tensor->view_src);\n\n    ggml_mem_range mr;\n\n    if (tensor->buffer) {\n        // when the tensor is allocated, use the actual memory address range in the buffer\n        //\n        // take the actual allocated size with ggml_backend_buft_get_alloc_size()\n        // this can be larger than the tensor size if the buffer type allocates extra memory\n        // ref: https://github.com/ggml-org/llama.cpp/pull/15966\n        mr = {\n            /*.pb =*/ (uint64_t) tensor->buffer,\n            /*.p0 =*/ (uint64_t) tensor->data,\n            /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),\n            /*.pt =*/ pt,\n        };\n    } else {\n        // otherwise, the pointer address is used as an unique id of the memory ranges\n        //   that the tensor will be using when it is allocated\n        mr = {\n            /*.pb =*/ (uint64_t) tensor,\n            /*.p0 =*/ 0,    //\n            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used\n            /*.pt =*/ pt,\n        };\n    };\n\n    return mr;\n}\n\nstatic ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {\n    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);\n}\n\nstatic ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {\n    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);\n}\n\nstatic bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {\n    GGML_ASSERT(tensor);\n\n    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);\n\n    if (mrs->debug > 2) {\n        GGML_LOG_DEBUG(\"%s: add src range buf=%lld, [%lld, %lld)\\n\", __func__, mr.pb, mr.p0, mr.p1);\n    }\n\n    return ggml_mem_ranges_add(mrs, mr);\n}\n\nstatic bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {\n    GGML_ASSERT(tensor);\n\n    ggml_mem_range mr = 
ggml_mem_range_from_tensor_dst(tensor);\n\n    if (mrs->debug > 2) {\n        GGML_LOG_DEBUG(\"%s: add dst range buf=%lld, [%lld, %lld)\\n\", __func__, mr.pb, mr.p0, mr.p1);\n    }\n\n    return ggml_mem_ranges_add(mrs, mr);\n}\n\nbool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (tensor->src[i]) {\n            ggml_mem_ranges_add_src(mrs, tensor->src[i]);\n        }\n    }\n\n    return ggml_mem_ranges_add_dst(mrs, tensor);\n}\n\nstatic bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {\n    for (size_t i = 0; i < mrs->ranges.size(); i++) {\n        const auto & cmp = mrs->ranges[i];\n\n        // two memory ranges cannot intersect if they are in different buffers\n        if (mr.pb != cmp.pb) {\n            continue;\n        }\n\n        // intersecting source ranges are allowed\n        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {\n            continue;\n        }\n\n        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {\n            if (mrs->debug > 2) {\n                GGML_LOG_DEBUG(\"%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\\n\",\n                        __func__,\n                        mr.pt == MEM_RANGE_TYPE_SRC ? \"src\" : \"dst\",\n                        mr.pb, mr.p0, mr.p1,\n                        cmp.pt == MEM_RANGE_TYPE_SRC ? 
\"src\" : \"dst\",\n                        cmp.pb, cmp.p0, cmp.p1);\n            }\n\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {\n    GGML_ASSERT(tensor);\n\n    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);\n\n    const bool res = ggml_mem_ranges_check(mrs, mr);\n\n    return res;\n}\n\nstatic bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {\n    GGML_ASSERT(tensor);\n\n    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);\n\n    const bool res = ggml_mem_ranges_check(mrs, mr);\n\n    return res;\n}\n\nbool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (tensor->src[i]) {\n            if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {\n                return false;\n            }\n        }\n    }\n\n    return ggml_mem_ranges_check_dst(mrs, tensor);\n}\n\nstruct node_info {\n    ggml_tensor * node;\n\n    std::vector<ggml_tensor *> fused;\n\n    ggml_op op() const {\n        return node->op;\n    }\n\n    const ggml_tensor * dst() const {\n        return fused.empty() ? 
node : fused.back();\n    }\n\n    bool is_empty() const {\n        return ggml_op_is_empty(node->op);\n    }\n\n    void add_fused(ggml_tensor * t) {\n        fused.push_back(t);\n    }\n};\n\nstatic std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {\n    // helper to add node src and dst ranges\n    const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) {\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            if (node.node->src[i]) {\n                if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {\n                    return false;\n                }\n            }\n        }\n\n        // keep track of the sources of the fused nodes as well\n        for (const auto * fused : node.fused) {\n            for (int i = 0; i < GGML_MAX_SRC; i++) {\n                if (fused->src[i]) {\n                    if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {\n                        return false;\n                    }\n                }\n            }\n        }\n\n        return ggml_mem_ranges_add_dst(mrs, node.dst());\n    };\n\n    // helper to check if a node can run concurrently with the existing set of nodes\n    const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) {\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            if (node.node->src[i]) {\n                if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {\n                    return false;\n                }\n            }\n        }\n\n        for (const auto * fused : node.fused) {\n            for (int i = 0; i < GGML_MAX_SRC; i++) {\n                if (fused->src[i]) {\n                    if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {\n                        return false;\n                    }\n                }\n            }\n        }\n\n        return ggml_mem_ranges_check_dst(mrs, node.dst());\n    };\n\n    // perform reorders only across these types of ops\n    // can be 
expanded when needed\n    const auto & h_safe = [](ggml_op op) {\n        switch (op) {\n            case GGML_OP_MUL_MAT:\n            case GGML_OP_MUL_MAT_ID:\n            case GGML_OP_ROPE:\n            case GGML_OP_NORM:\n            case GGML_OP_RMS_NORM:\n            case GGML_OP_GROUP_NORM:\n            case GGML_OP_L2_NORM:\n            case GGML_OP_SUM_ROWS:\n            case GGML_OP_SSM_CONV:\n            case GGML_OP_SSM_SCAN:\n            case GGML_OP_CLAMP:\n            case GGML_OP_TRI:\n            case GGML_OP_DIAG:\n            case GGML_OP_MUL:\n            case GGML_OP_ADD:\n            case GGML_OP_SUB:\n            case GGML_OP_DIV:\n            case GGML_OP_GLU:\n            case GGML_OP_SCALE:\n            case GGML_OP_UNARY:\n            case GGML_OP_GET_ROWS:\n            case GGML_OP_SET_ROWS:\n            case GGML_OP_SET:\n            case GGML_OP_CPY:\n            case GGML_OP_CONT:\n            case GGML_OP_REPEAT:\n                return true;\n            default:\n                return ggml_op_is_empty(op);\n        }\n    };\n\n    const int n = nodes.size();\n\n    std::vector<int> res;\n    res.reserve(n);\n\n    std::vector<bool> used(n, false);\n\n    // the memory ranges for the set of currently concurrent nodes\n    ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0);\n\n    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder\n    ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0);\n\n    for (int i0 = 0; i0 < n; i0++) {\n        if (used[i0]) {\n            continue;\n        }\n\n        const auto & node0 = nodes[i0];\n\n        // the node is not concurrent with the existing concurrent set, so we have to \"put a barrier\" (i.e reset mrs0)\n        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0\n        //\n        // note: we can always add empty nodes to the concurrent set as they don't read nor write 
anything\n        if (!node0.is_empty() && !h_check(mrs0, node0)) {\n            // this will hold the set of memory ranges from the nodes that haven't been processed yet\n            // if a node is not concurrent with this set, we cannot reorder it\n            ggml_mem_ranges_reset(mrs1);\n\n            // initialize it with the current node\n            h_add(mrs1, node0);\n\n            // that many nodes forward to search for a concurrent node\n            constexpr int N_FORWARD = 64;\n\n            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {\n                if (used[i1]) {\n                    continue;\n                }\n\n                const auto & node1 = nodes[i1];\n\n                // disallow reordering of certain ops\n                if (!h_safe(node1.op())) {\n                    break;\n                }\n\n                const bool is_empty = node1.is_empty();\n\n                // to reorder a node and add it to the concurrent set, it has to be:\n                //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)\n                //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)\n                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {\n                    // add the node to the existing concurrent set (i.e. 
reorder it for early execution)\n                    h_add(mrs0, node1);\n                    res.push_back(i1);\n\n                    // mark as used, so we skip re-processing it later\n                    used[i1] = true;\n                } else {\n                    // expand the set of nodes that haven't been processed yet\n                    h_add(mrs1, node1);\n                }\n            }\n\n            // finalize the concurrent set and begin a new one\n            ggml_mem_ranges_reset(mrs0);\n        }\n\n        // expand the concurrent set with the current node\n        {\n            h_add(mrs0, node0);\n            res.push_back(i0);\n        }\n    }\n\n    ggml_mem_ranges_free(mrs0);\n    ggml_mem_ranges_free(mrs1);\n\n    return res;\n}\n\nvoid ggml_graph_optimize(ggml_cgraph * gf) {\n    constexpr int MAX_FUSE = 16;\n\n    const int n = gf->n_nodes;\n\n    enum ggml_op ops[MAX_FUSE];\n\n    std::vector<node_info> nodes;\n    nodes.reserve(gf->n_nodes);\n\n    // fuse nodes:\n    // we don't want to make reorders that break fusing, so we first pack all fusable tensors\n    //   and perform the reorder over the fused nodes. 
after the reorder is done, we unfuse\n    for (int i = 0; i < n; i++) {\n        node_info node = {\n            /*.node =*/ gf->nodes[i],\n            /*.fused =*/ {},\n        };\n\n        // fuse only ops that start with these operations\n        // can be expanded when needed\n        if (node.op() == GGML_OP_ADD ||\n            node.op() == GGML_OP_NORM ||\n            node.op() == GGML_OP_RMS_NORM) {\n            ops[0] = node.op();\n\n            int f = i + 1;\n            while (f < n && f < i + MAX_FUSE) {\n                // conservatively allow fusing only these ops\n                // can be expanded when needed\n                if (gf->nodes[f]->op != GGML_OP_ADD &&\n                    gf->nodes[f]->op != GGML_OP_MUL &&\n                    gf->nodes[f]->op != GGML_OP_NORM &&\n                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {\n                    break;\n                }\n                ops[f - i] = gf->nodes[f]->op;\n                f++;\n            }\n\n            f -= i;\n            for (; f > 1; f--) {\n                if (ggml_can_fuse(gf, i, ops, f)) {\n                    break;\n                }\n            }\n\n            // add the fused tensors into the node info so we can unfuse them later\n            for (int k = 1; k < f; k++) {\n                ++i;\n\n                // the .dst() becomes the last fused tensor\n                node.add_fused(gf->nodes[i]);\n            }\n        }\n\n        nodes.push_back(std::move(node));\n    }\n\n#if 1\n    // reorder to improve concurrency\n    const auto order = ggml_metal_graph_optimize_reorder(nodes);\n#else\n    std::vector<int> order(nodes.size());\n    for (size_t i = 0; i < nodes.size(); i++) {\n        order[i] = i;\n    }\n#endif\n\n    // unfuse\n    {\n        int j = 0;\n        for (const auto i : order) {\n            const auto & node = nodes[i];\n\n            gf->nodes[j++] = node.node;\n\n            for (auto * fused : node.fused) {\n                
gf->nodes[j++] = fused;\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-common.h",
    "content": "// helper functions for ggml-metal that are too difficult to implement in Objective-C\n\n#pragma once\n\n#include <stdbool.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nstruct ggml_tensor;\nstruct ggml_cgraph;\n\nenum ggml_mem_range_type {\n    MEM_RANGE_TYPE_SRC = 0,\n    MEM_RANGE_TYPE_DST = 1,\n};\n\n// a helper object that can be used for reordering operations to improve concurrency\n//\n// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if none of them\n//   writes to memory that is being read or written by another task in the set\n//\n// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task\n//   can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the\n//   tasks already in the set)\n//\ntypedef struct ggml_mem_ranges * ggml_mem_ranges_t;\n\nggml_mem_ranges_t ggml_mem_ranges_init(int debug);\nvoid ggml_mem_ranges_free(ggml_mem_ranges_t mrs);\n\n// remove all ranges from the set\nvoid ggml_mem_ranges_reset(ggml_mem_ranges_t mrs);\n\n// add the memory ranges of the tensor's sources (as src ranges) and of the tensor itself (as a dst range) to the set\nbool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);\n\n// check if the tensor can be added to the set without conflicts - return false if:\n// - a src range of the tensor overlaps with any existing dst range\n// - the dst range of the tensor overlaps with any existing range (src or dst)\nbool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);\n\n// reorder the nodes in the graph to improve concurrency, while respecting fusion\n//\n// note: this implementation is generic and not specific to metal\n//       if it proves to work well, we can start using it for other backends in the future\nvoid ggml_graph_optimize(struct ggml_cgraph * gf);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-context.h",
    "content": "#pragma once\n\n#include \"ggml-metal-device.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n//\n// backend context\n//\n\ntypedef struct ggml_metal * ggml_metal_t;\n\nggml_metal_t ggml_metal_init(ggml_metal_device_t dev);\nvoid ggml_metal_free(ggml_metal_t ctx);\n\nconst char * ggml_metal_get_name(ggml_metal_t ctx);\n\nvoid ggml_metal_synchronize(ggml_metal_t ctx);\n\nvoid ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);\nvoid ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);\nbool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);\n\nenum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf);\nvoid             ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf);\n\nvoid ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev);\nvoid ggml_metal_event_wait  (ggml_metal_t ctx, ggml_metal_event_t ev);\n\nggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx);\n\nvoid ggml_metal_set_n_cb            (ggml_metal_t ctx, int n_cb);\nvoid ggml_metal_set_abort_callback  (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data);\nbool ggml_metal_supports_family     (ggml_metal_t ctx, int family);\nvoid ggml_metal_capture_next_compute(ggml_metal_t ctx);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-context.m",
    "content": "#import \"ggml-metal-context.h\"\n\n#import \"ggml-impl.h\"\n#import \"ggml-backend-impl.h\"\n\n#import \"ggml-metal-impl.h\"\n#import \"ggml-metal-common.h\"\n#import \"ggml-metal-ops.h\"\n\n#import <Foundation/Foundation.h>\n\n#import <Metal/Metal.h>\n\n#undef MIN\n#undef MAX\n#define MIN(a, b) ((a) < (b) ? (a) : (b))\n#define MAX(a, b) ((a) > (b) ? (a) : (b))\n\n// max number of MTLCommandBuffer used to submit a graph for processing\n#define GGML_METAL_MAX_COMMAND_BUFFERS 8\n\nstruct ggml_metal_command_buffer {\n    id<MTLCommandBuffer> obj;\n};\n\nstruct ggml_metal {\n    char name[128];\n\n    ggml_metal_device_t  dev;\n    ggml_metal_library_t lib;\n\n    ggml_metal_event_t ev_cpy; // for async copies\n\n    dispatch_queue_t d_queue;\n\n    // additional, inference-time compiled pipelines\n    ggml_metal_pipelines_t pipelines_ext;\n\n    bool use_fusion;\n    bool use_concurrency;\n    bool use_graph_optimize;\n\n    int debug_graph;\n    int debug_fusion;\n\n    // how many times a given op was fused\n    uint64_t fuse_cnt[GGML_OP_COUNT];\n\n    // capture state\n    int capture_compute;\n    bool capture_started;\n\n    id<MTLCaptureScope> capture_scope;\n\n    // command buffer state\n    int n_cb;           // number of extra threads used to submit the command buffers\n    int n_nodes_0;      // number of nodes submitted by the main thread\n    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads\n    int n_nodes_per_cb;\n\n    struct ggml_cgraph * gf;\n\n    // the callback given to the thread pool\n    void (^encode_async)(size_t ith);\n\n    // n_cb command buffers + 1 used by the main thread\n    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];\n\n    // extra command buffers for things like getting, setting and copying tensors\n    NSMutableArray * cmd_bufs_ext;\n\n    // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend\n   
 id<MTLCommandBuffer> cmd_buf_last;\n\n    // abort ggml_metal_graph_compute if callback returns true\n    ggml_abort_callback abort_callback;\n    void *              abort_callback_data;\n\n    // error state - set when a command buffer fails during synchronize\n    // once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated\n    bool has_error;\n};\n\nggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {\n    GGML_LOG_INFO(\"%s: allocating\\n\", __func__);\n\n#if TARGET_OS_OSX && !GGML_METAL_NDEBUG\n    // Show all the Metal device instances in the system\n    NSArray * devices = MTLCopyAllDevices();\n    for (id<MTLDevice> device in devices) {\n        GGML_LOG_INFO(\"%s: found device: %s\\n\", __func__, [[device name] UTF8String]);\n    }\n    [devices release]; // since it was created by a *Copy* C method\n#endif\n\n    // init context\n    ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));\n\n    id<MTLDevice> device = ggml_metal_device_get_obj(dev);\n\n    GGML_LOG_INFO(\"%s: picking default device: %s\\n\", __func__, [[device name] UTF8String]);\n\n    // TODO: would it be better to have one queue for the backend and one queue for the device?\n    //       the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?\n    //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]\n    id<MTLCommandQueue> queue = ggml_metal_device_get_queue(dev);\n    if (queue == nil) {\n        GGML_LOG_ERROR(\"%s: error: failed to create command queue\\n\", __func__);\n        return NULL;\n    }\n\n    res->dev = dev;\n    res->lib = ggml_metal_device_get_library(dev);\n    if (res->lib == NULL) {\n        GGML_LOG_WARN(\"%s: the device does not have a precompiled Metal library - this is unexpected\\n\", __func__);\n        GGML_LOG_WARN(\"%s: will try to compile it on the fly\\n\", __func__);\n\n        res->lib = ggml_metal_library_init(dev);\n        if (res->lib == NULL) {\n     
       GGML_LOG_ERROR(\"%s: error: failed to initialize the Metal library\\n\", __func__);\n\n            free(res);\n\n            return NULL;\n        }\n    }\n\n    res->ev_cpy = ggml_metal_device_event_init(dev);\n\n    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);\n\n    snprintf(res->name, sizeof(res->name), \"%s\", props_dev->name);\n\n    res->d_queue = dispatch_queue_create(\"ggml-metal\", DISPATCH_QUEUE_CONCURRENT);\n\n    res->use_fusion      = getenv(\"GGML_METAL_FUSION_DISABLE\") == nil;\n    res->use_concurrency = getenv(\"GGML_METAL_CONCURRENCY_DISABLE\") == nil;\n\n    {\n        const char * val = getenv(\"GGML_METAL_GRAPH_DEBUG\");\n        res->debug_graph = val ? atoi(val) : 0;\n    }\n\n    {\n        const char * val = getenv(\"GGML_METAL_FUSION_DEBUG\");\n        res->debug_fusion = val ? atoi(val) : 0;\n    }\n\n    res->use_graph_optimize = true;\n\n    if (getenv(\"GGML_METAL_GRAPH_OPTIMIZE_DISABLE\") != NULL) {\n        res->use_graph_optimize = false;\n    }\n\n    memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));\n\n    GGML_LOG_INFO(\"%s: use fusion         = %s\\n\", __func__, res->use_fusion         ? \"true\" : \"false\");\n    GGML_LOG_INFO(\"%s: use concurrency    = %s\\n\", __func__, res->use_concurrency    ? \"true\" : \"false\");\n    GGML_LOG_INFO(\"%s: use graph optimize = %s\\n\", __func__, res->use_graph_optimize ? 
\"true\" : \"false\");\n\n    res->capture_compute = 0;\n    res->capture_started = false;\n    res->capture_scope = nil;\n\n    {\n        const char * val = getenv(\"GGML_METAL_CAPTURE_COMPUTE\");\n        if (val) {\n            res->capture_compute = atoi(val);\n        }\n    }\n\n    res->has_error = false;\n\n    res->gf = nil;\n    res->encode_async = nil;\n    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {\n        res->cmd_bufs[i].obj = nil;\n    }\n\n    res->cmd_bufs_ext = [[NSMutableArray alloc] init];\n\n    res->cmd_buf_last = nil;\n\n    res->pipelines_ext = ggml_metal_pipelines_init();\n\n    return res;\n}\n\nvoid ggml_metal_free(ggml_metal_t ctx) {\n    GGML_LOG_INFO(\"%s: deallocating\\n\", __func__);\n\n    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {\n        if (ctx->cmd_bufs[i].obj) {\n            [ctx->cmd_bufs[i].obj release];\n        }\n    }\n\n    for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) {\n        if (ctx->cmd_bufs_ext[i]) {\n            [ctx->cmd_bufs_ext[i] release];\n        }\n    }\n\n    [ctx->cmd_bufs_ext removeAllObjects];\n    [ctx->cmd_bufs_ext release];\n\n    if (ctx->pipelines_ext) {\n        ggml_metal_pipelines_free(ctx->pipelines_ext);\n        ctx->pipelines_ext = nil;\n    }\n\n    if (ctx->debug_fusion > 0) {\n        GGML_LOG_DEBUG(\"%s: fusion stats:\\n\", __func__);\n        for (int i = 0; i < GGML_OP_COUNT; i++) {\n            if (ctx->fuse_cnt[i] == 0) {\n                continue;\n            }\n\n            // note: cannot use ggml_log here\n            GGML_LOG_DEBUG(\"%s: - %s: %\" PRIu64 \"\\n\", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);\n        }\n    }\n\n    Block_release(ctx->encode_async);\n\n    //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]\n\n    dispatch_release(ctx->d_queue);\n\n    ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy);\n\n    free(ctx);\n}\n\nconst char * ggml_metal_get_name(ggml_metal_t ctx) {\n    return 
ctx->name;\n}\n\nvoid ggml_metal_synchronize(ggml_metal_t ctx) {\n    // wait for any backend operations to finish\n    if (ctx->cmd_buf_last) {\n        [ctx->cmd_buf_last waitUntilCompleted];\n        ctx->cmd_buf_last = nil;\n    }\n\n    // check status of all command buffers\n    {\n        const int n_cb = ctx->n_cb;\n\n        for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) {\n            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;\n            if (!cmd_buf) {\n                continue;\n            }\n\n            MTLCommandBufferStatus status = [cmd_buf status];\n            if (status != MTLCommandBufferStatusCompleted) {\n                GGML_LOG_ERROR(\"%s: error: command buffer %d failed with status %d\\n\", __func__, cb_idx, (int) status);\n                if (status == MTLCommandBufferStatusError) {\n                    GGML_LOG_ERROR(\"error: %s\\n\", [[cmd_buf error].localizedDescription UTF8String]);\n                }\n                ctx->has_error = true;\n                return;\n            }\n        }\n    }\n\n    // release any completed extra command buffers\n    if (ctx->cmd_bufs_ext.count > 0) {\n        for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {\n            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];\n\n            MTLCommandBufferStatus status = [cmd_buf status];\n            if (status != MTLCommandBufferStatusCompleted) {\n                GGML_LOG_ERROR(\"%s: error: command buffer %d failed with status %d\\n\", __func__, (int) i, (int) status);\n                if (status == MTLCommandBufferStatusError) {\n                    GGML_LOG_ERROR(\"error: %s\\n\", [[cmd_buf error].localizedDescription UTF8String]);\n                }\n\n                // release this and all remaining command buffers before returning\n                for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) {\n                    [ctx->cmd_bufs_ext[j] release];\n                }\n                [ctx->cmd_bufs_ext 
removeAllObjects];\n\n                ctx->has_error = true;\n                return;\n            }\n\n            [cmd_buf release];\n        }\n\n        [ctx->cmd_bufs_ext removeAllObjects];\n    }\n}\n\nstatic struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) {\n    if (!t) {\n        return (struct ggml_metal_buffer_id) { nil, 0 };\n    }\n\n    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;\n\n    return ggml_metal_buffer_get_id(buffer->context, t);\n}\n\nvoid ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    @autoreleasepool {\n        // wrap the source data into a Metal buffer\n        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);\n        id<MTLBuffer> buf_src = [device newBufferWithBytes:data\n                                                    length:size\n                                                   options:MTLResourceStorageModeShared];\n\n        GGML_ASSERT(buf_src);\n\n        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor);\n        if (bid_dst.metal == nil) {\n            GGML_ABORT(\"%s: failed to find buffer for tensor '%s'\\n\", __func__, tensor->name);\n        }\n\n        bid_dst.offs += offset;\n\n        // queue the copy operation into the queue of the Metal context\n        // this will be queued at the end, after any currently ongoing GPU operations\n        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);\n        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];\n        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];\n\n        [encoder copyFromBuffer:buf_src\n                   sourceOffset:0\n                       toBuffer:bid_dst.metal\n              destinationOffset:bid_dst.offs\n                           size:size];\n\n        [encoder endEncoding];\n        [cmd_buf commit];\n        [buf_src 
release];\n\n        // do not wait here for completion\n        //[cmd_buf waitUntilCompleted];\n\n        // instead, remember a reference to the command buffer and wait for it later if needed\n        [ctx->cmd_bufs_ext addObject:cmd_buf];\n        ctx->cmd_buf_last = cmd_buf;\n\n        [cmd_buf retain];\n    }\n}\n\nvoid ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    @autoreleasepool {\n        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);\n        id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data\n                                                          length:size\n                                                         options:MTLResourceStorageModeShared\n                                                     deallocator:nil];\n\n        GGML_ASSERT(buf_dst);\n\n        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor);\n        if (bid_src.metal == nil) {\n            GGML_ABORT(\"%s: failed to find buffer for tensor '%s'\\n\", __func__, tensor->name);\n        }\n\n        bid_src.offs += offset;\n\n        // queue the copy operation into the queue of the Metal context\n        // this will be queued at the end, after any currently ongoing GPU operations\n        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);\n        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];\n        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];\n\n        [encoder copyFromBuffer:bid_src.metal\n                   sourceOffset:bid_src.offs\n                       toBuffer:buf_dst\n              destinationOffset:0\n                           size:size];\n\n        [encoder endEncoding];\n        [cmd_buf commit];\n        [buf_dst release];\n\n        // do not wait here for completion\n        //[cmd_buf waitUntilCompleted];\n\n        // instead, remember a reference to the command buffer and wait for it 
later if needed\n        [ctx->cmd_bufs_ext addObject:cmd_buf];\n        ctx->cmd_buf_last = cmd_buf;\n\n        [cmd_buf retain];\n    }\n}\n\nbool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {\n    @autoreleasepool {\n        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src);\n        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst);\n\n        if (bid_src.metal == nil || bid_dst.metal == nil) {\n            return false;\n        }\n\n        // queue the copy operation into the Metal context\n        // this will be queued at the end, after any currently ongoing GPU operations\n        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx_src->dev);\n        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];\n        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];\n\n        [encoder copyFromBuffer:bid_src.metal\n                   sourceOffset:bid_src.offs\n                       toBuffer:bid_dst.metal\n              destinationOffset:bid_dst.offs\n                           size:ggml_nbytes(src)];\n\n        [encoder endEncoding];\n\n        ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);\n        ggml_metal_event_encode_signal(ev_cpy, cmd_buf);\n\n        [cmd_buf commit];\n\n        // do not wait here for completion\n        //[cmd_buf waitUntilCompleted];\n\n        // instead, remember a reference to the command buffer and wait for it later if needed\n        [ctx_src->cmd_bufs_ext addObject:cmd_buf];\n        ctx_src->cmd_buf_last = cmd_buf;\n\n        [cmd_buf retain];\n\n        ggml_metal_event_wait(ctx_dst, ev_cpy);\n\n        return true;\n    }\n}\n\nenum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {\n    if (ctx->has_error) {\n        GGML_LOG_ERROR(\"%s: backend is in error state from a previous command buffer failure - recreate the backend to 
recover\\n\", __func__);\n        return GGML_STATUS_FAILED;\n    }\n\n    // number of nodes encoded by the main thread (empirically determined)\n    const int n_main = MAX(64, 0.1*gf->n_nodes);\n\n    // number of threads in addition to the main thread\n    const int n_cb = ctx->n_cb;\n\n    // keep the memory wired\n    ggml_metal_device_rsets_keep_alive(ctx->dev);\n\n    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them\n    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread\n    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes\n    // each thread creates it's own command buffer and enqueues the ops in parallel\n    //\n    // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2\n\n    @autoreleasepool {\n        ctx->gf = gf;\n\n        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);\n        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;\n\n        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;\n\n        if (ctx->capture_compute >= 0) {\n            ctx->capture_compute--;\n        }\n\n        const bool use_capture = ctx->capture_compute == 0;\n        if (use_capture) {\n            ctx->capture_compute = -1;\n\n            // make sure all previous computations have finished before starting the capture\n            if (ctx->cmd_buf_last) {\n                [ctx->cmd_buf_last waitUntilCompleted];\n                ctx->cmd_buf_last = nil;\n            }\n\n            if (!ctx->capture_started) {\n                NSString * path = [NSString stringWithFormat:@\"/tmp/perf-metal-%d.gputrace\", getpid()];\n\n                GGML_LOG_WARN(\"%s: capturing graph in %s\\n\", __func__, [path UTF8String]);\n\n                // create capture scope\n                id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);\n                ctx->capture_scope = 
[[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];\n\n                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];\n                descriptor.captureObject = ctx->capture_scope;\n                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;\n                descriptor.outputURL = [NSURL fileURLWithPath:path];\n\n                NSError * error = nil;\n                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {\n                    GGML_LOG_ERROR(\"%s: error: unable to start capture '%s'\\n\", __func__, [[error localizedDescription] UTF8String]);\n                } else {\n                    [ctx->capture_scope beginScope];\n                    ctx->capture_started = true;\n                }\n            }\n        }\n\n        // short-hand\n        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);\n\n        // the main thread commits the first few commands immediately\n        // cmd_buf[n_cb]\n        {\n            id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];\n            [cmd_buf retain];\n\n            if (ctx->cmd_bufs[n_cb].obj) {\n                [ctx->cmd_bufs[n_cb].obj release];\n            }\n            ctx->cmd_bufs[n_cb].obj = cmd_buf;\n\n            [cmd_buf enqueue];\n\n            ctx->encode_async(n_cb);\n        }\n\n        // remember the command buffer for the next iteration\n        ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;\n\n        // prepare the rest of the command buffers asynchronously (optional)\n        // cmd_buf[0.. 
n_cb)\n        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {\n            id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];\n            [cmd_buf retain];\n\n            if (ctx->cmd_bufs[cb_idx].obj) {\n                [ctx->cmd_bufs[cb_idx].obj release];\n            }\n            ctx->cmd_bufs[cb_idx].obj = cmd_buf;\n\n            // always enqueue the first two command buffers\n            // enqueue all of the command buffers if we don't need to abort\n            if (cb_idx < 2 || ctx->abort_callback == NULL) {\n                [cmd_buf enqueue];\n\n                // update the pointer to the last queued command buffer\n                // this is needed to implement synchronize()\n                ctx->cmd_buf_last = cmd_buf;\n            }\n        }\n\n        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);\n\n        // for debugging: block until graph is computed\n        //[ctx->cmd_buf_last waitUntilCompleted];\n\n        // enter here only when capturing in order to wait for all computation to finish\n        // otherwise, we leave the graph to compute asynchronously\n        if (use_capture && ctx->capture_started) {\n            // wait for completion and check status of each command buffer\n            // needed to detect if the device ran out-of-memory for example (#1881)\n            {\n                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;\n                [cmd_buf waitUntilCompleted];\n\n                MTLCommandBufferStatus status = [cmd_buf status];\n                if (status != MTLCommandBufferStatusCompleted) {\n                    GGML_LOG_INFO(\"%s: command buffer %d failed with status %lu\\n\", __func__, n_cb, status);\n                    if (status == MTLCommandBufferStatusError) {\n                        GGML_LOG_INFO(\"error: %s\\n\", [[cmd_buf error].localizedDescription UTF8String]);\n                    }\n\n                    return GGML_STATUS_FAILED;\n                }\n 
           }\n\n            for (int i = 0; i < n_cb; ++i) {\n                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;\n                [cmd_buf waitUntilCompleted];\n\n                MTLCommandBufferStatus status = [cmd_buf status];\n                if (status != MTLCommandBufferStatusCompleted) {\n                    GGML_LOG_INFO(\"%s: command buffer %d failed with status %lu\\n\", __func__, i, status);\n                    if (status == MTLCommandBufferStatusError) {\n                        GGML_LOG_INFO(\"error: %s\\n\", [[cmd_buf error].localizedDescription UTF8String]);\n                    }\n\n                    return GGML_STATUS_FAILED;\n                }\n\n                id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);\n                if (!next_buffer) {\n                    continue;\n                }\n\n                const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);\n                if (next_queued) {\n                    continue;\n                }\n\n                if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {\n                    GGML_LOG_INFO(\"%s: command buffer %d aborted\", __func__, i);\n                    return GGML_STATUS_ABORTED;\n                }\n\n                [next_buffer commit];\n            }\n\n            [ctx->capture_scope endScope];\n            [[MTLCaptureManager sharedCaptureManager] stopCapture];\n\n            ctx->capture_started = false;\n        }\n    }\n\n    return GGML_STATUS_SUCCESS;\n}\n\nvoid ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {\n    //const int64_t t_start = ggml_time_us();\n\n    if (ctx->use_graph_optimize) {\n        ggml_graph_optimize(gf);\n    }\n\n    //printf(\"%s: graph optimize took %.3f ms\\n\", __func__, (ggml_time_us() - t_start) / 1000.0);\n}\n\nvoid ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) {\n    @autoreleasepool 
{\n        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);\n        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];\n\n        ggml_metal_event_encode_signal(ev, cmd_buf);\n\n        [cmd_buf commit];\n\n        [ctx->cmd_bufs_ext addObject:cmd_buf];\n        ctx->cmd_buf_last = cmd_buf;\n\n        [cmd_buf retain];\n    }\n}\n\nvoid ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) {\n    @autoreleasepool {\n        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);\n        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];\n\n        ggml_metal_event_encode_wait(ev, cmd_buf);\n\n        [cmd_buf commit];\n\n        [ctx->cmd_bufs_ext addObject:cmd_buf];\n        ctx->cmd_buf_last = cmd_buf;\n\n        [cmd_buf retain];\n    }\n}\n\nggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) {\n    return ctx->ev_cpy;\n}\n\nvoid ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {\n    if (ctx->n_cb != n_cb) {\n        ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);\n\n        if (ctx->n_cb > 2) {\n            GGML_LOG_WARN(\"%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\\n\", __func__, n_cb);\n        }\n    }\n\n    if (ctx->encode_async) {\n        Block_release(ctx->encode_async);\n    }\n\n    ctx->encode_async = Block_copy(^(size_t iter) {\n        const int cb_idx = iter;\n        const int n_cb_l = ctx->n_cb;\n\n        const int n_nodes_0 = ctx->n_nodes_0;\n        const int n_nodes_1 = ctx->n_nodes_1;\n\n        const int n_nodes_per_cb = ctx->n_nodes_per_cb;\n\n        int idx_start = 0;\n        int idx_end   = n_nodes_0;\n\n        if (cb_idx < n_cb_l) {\n            idx_start = n_nodes_0 + (                                         (cb_idx + 0) * n_nodes_per_cb);\n            idx_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? 
n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));\n        }\n\n        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;\n\n        ggml_metal_op_t ctx_op = ggml_metal_op_init(\n            ctx->dev,\n            cmd_buf,\n            ctx->gf,\n            idx_start,\n            idx_end,\n            ctx->use_fusion,\n            ctx->use_concurrency,\n            ctx->capture_compute,\n            ctx->debug_graph,\n            ctx->debug_fusion);\n\n        for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {\n            const int res = ggml_metal_op_encode(ctx_op, idx);\n            if (res == 0) {\n                break;\n            }\n\n            idx += res - 1;\n        }\n\n        ggml_metal_op_free(ctx_op);\n\n        if (cb_idx < 2 || ctx->abort_callback == NULL) {\n            [cmd_buf commit];\n        }\n    });\n}\n\nvoid ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) {\n    ctx->abort_callback = abort_callback;\n    ctx->abort_callback_data = user_data;\n}\n\nbool ggml_metal_supports_family(ggml_metal_t ctx, int family) {\n    GGML_ASSERT(ctx->dev != nil);\n\n    id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);\n\n    return [device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];\n}\n\nvoid ggml_metal_capture_next_compute(ggml_metal_t ctx) {\n    ctx->capture_compute = 1;\n}\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-device.cpp",
    "content": "#include \"ggml-metal-device.h\"\n\n#include \"ggml-metal-impl.h\"\n\n#include \"ggml-impl.h\"\n\n#include <cassert>\n#include <memory>\n#include <string>\n#include <unordered_map>\n\nstruct ggml_metal_device_deleter {\n    void operator()(ggml_metal_device_t ctx) {\n        ggml_metal_device_free(ctx);\n    }\n};\n\ntypedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;\n\nggml_metal_device_t ggml_metal_device_get(int device) {\n    static std::vector<ggml_metal_device_ptr> devs;\n\n    devs.emplace_back(ggml_metal_device_init(device));\n\n    return devs.back().get();\n}\n\nstruct ggml_metal_pipelines {\n    std::unordered_map<std::string, ggml_metal_pipeline_t> data;\n};\n\nggml_metal_pipelines_t ggml_metal_pipelines_init(void) {\n    ggml_metal_pipelines_t res = new ggml_metal_pipelines();\n\n    return res;\n}\n\nvoid ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {\n    if (!ppls) {\n        return;\n    }\n\n    for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {\n        ggml_metal_pipeline_free(it->second);\n    }\n\n    delete ppls;\n}\n\nvoid ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) {\n    ppls->data[name] = pipeline;\n}\n\nggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {\n    if (ppls->data.find(name) == ppls->data.end()) {\n        return nullptr;\n    }\n\n    return ppls->data[name];\n}\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {\n    char base[256];\n    char name[256];\n\n    const char * op_str = \"undefined\";\n    switch (op) {\n        case GGML_OP_ADD_ID: op_str = \"add_id\"; break;\n        case GGML_OP_CONCAT: op_str = \"concat\"; break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    snprintf(base, 256, \"kernel_%s\", op_str);\n    snprintf(name, 256, \"%s\", 
base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_cpy_%s_%s\", ggml_type_name(tsrc), ggml_type_name(tdst));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);\n\n    const char * pool_str = \"undefined\";\n    switch (op_pool) {\n        case GGML_OP_POOL_AVG: pool_str = \"avg\"; break;\n        case GGML_OP_POOL_MAX: pool_str = \"max\"; break;\n        default: GGML_ASSERT(false && \"not implemented\");\n    };\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, sizeof(base), \"kernel_pool_1d_%s_%s\", pool_str, ggml_type_name(op->src[0]->type));\n    snprintf(name, sizeof(name), \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == 
op->type);\n\n    const char * pool_str = \"undefined\";\n    switch (op_pool) {\n        case GGML_OP_POOL_AVG: pool_str = \"avg\"; break;\n        case GGML_OP_POOL_MAX: pool_str = \"max\"; break;\n        default: GGML_ASSERT(false && \"not implemented\");\n    };\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_pool_2d_%s_%s\", pool_str, ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_get_rows_%s\", ggml_type_name(tsrc));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_set_rows_%s_%s\", ggml_type_name(tdst), ggml_type_name(tidx));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {\n    char base[256];\n    char name[256];\n\n    const int n = op->src[0]->ne[0];\n\n    snprintf(base, 256, \"kernel_diag_%s\", ggml_type_name(op->src[0]->type));\n  
  snprintf(name, 256, \"%s_n=%d\", base, n);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    res.nsg  = 1;\n    res.smem = 0;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_repeat_%s\", ggml_type_name(tsrc));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {\n    char base[256];\n    char name[256];\n\n    int op_num = -1;\n\n    switch (op->op) {\n        case GGML_OP_SCALE:      op_num = OP_UNARY_NUM_SCALE;      break;\n        case GGML_OP_FILL:       op_num = OP_UNARY_NUM_FILL;       break;\n        case GGML_OP_CLAMP:      op_num = OP_UNARY_NUM_CLAMP;      break;\n        case GGML_OP_SQR:        op_num = OP_UNARY_NUM_SQR;        break;\n        case GGML_OP_SQRT:       op_num = OP_UNARY_NUM_SQRT;       break;\n        case GGML_OP_SIN:        op_num = OP_UNARY_NUM_SIN;        break;\n        case GGML_OP_COS:        op_num = OP_UNARY_NUM_COS;        break;\n        case GGML_OP_LOG:        op_num = OP_UNARY_NUM_LOG;        break;\n        case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(op)) {\n                case GGML_UNARY_OP_TANH:        op_num = OP_UNARY_NUM_TANH;        break;\n                case GGML_UNARY_OP_RELU:        op_num = OP_UNARY_NUM_RELU;        break;\n                case GGML_UNARY_OP_SIGMOID:     
op_num = OP_UNARY_NUM_SIGMOID;     break;\n                case GGML_UNARY_OP_GELU:        op_num = OP_UNARY_NUM_GELU;        break;\n                case GGML_UNARY_OP_GELU_ERF:    op_num = OP_UNARY_NUM_GELU_ERF;    break;\n                case GGML_UNARY_OP_GELU_QUICK:  op_num = OP_UNARY_NUM_GELU_QUICK;  break;\n                case GGML_UNARY_OP_SILU:        op_num = OP_UNARY_NUM_SILU;        break;\n                case GGML_UNARY_OP_ELU:         op_num = OP_UNARY_NUM_ELU;         break;\n                case GGML_UNARY_OP_NEG:         op_num = OP_UNARY_NUM_NEG;         break;\n                case GGML_UNARY_OP_ABS:         op_num = OP_UNARY_NUM_ABS;         break;\n                case GGML_UNARY_OP_SGN:         op_num = OP_UNARY_NUM_SGN;         break;\n                case GGML_UNARY_OP_STEP:        op_num = OP_UNARY_NUM_STEP;        break;\n                case GGML_UNARY_OP_HARDSWISH:   op_num = OP_UNARY_NUM_HARDSWISH;   break;\n                case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;\n                case GGML_UNARY_OP_EXP:         op_num = OP_UNARY_NUM_EXP;         break;\n                case GGML_UNARY_OP_SOFTPLUS:    op_num = OP_UNARY_NUM_SOFTPLUS;    break;\n                case GGML_UNARY_OP_EXPM1:       op_num = OP_UNARY_NUM_EXPM1;       break;\n                default: GGML_ABORT(\"fatal error\");\n            } break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    const char * t0_str = ggml_type_name(op->src[0]->type);\n    const char * t_str  = ggml_type_name(op->type);\n\n    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;\n    const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;\n\n    snprintf(base, 256, \"kernel_unary_%s_%s%s\", t0_str, t_str, is_c4 ? 
\"_4\" : \"\");\n    snprintf(name, 256, \"%s_op=%d_cnt=%d\", base, op_num, is_cnt);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);\n        ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.c4  = is_c4;\n    res.cnt = is_cnt;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));\n\n    char base[256];\n    char name[256];\n\n    const char * op_str = \"undefined\";\n    switch (op->op) {\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(op)) {\n                case GGML_GLU_OP_REGLU:        op_str = \"reglu\";        break;\n                case GGML_GLU_OP_GEGLU:        op_str = \"geglu\";        break;\n                case GGML_GLU_OP_SWIGLU:       op_str = \"swiglu\";       break;\n                case GGML_GLU_OP_SWIGLU_OAI:   op_str = \"swiglu_oai\";   break;\n                case GGML_GLU_OP_GEGLU_ERF:    op_str = \"geglu_erf\";    break;\n                case GGML_GLU_OP_GEGLU_QUICK:  op_str = \"geglu_quick\";  break;\n                default: GGML_ABORT(\"fatal error\");\n            } break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    snprintf(base, 256, \"kernel_%s_%s\", op_str, ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const 
ggml_tensor * op) {\n    assert(op->op == GGML_OP_SUM);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_op_sum_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    char base[256];\n    char name[256];\n\n    int op_num = -1;\n\n    switch (op->op) {\n        case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;\n        case GGML_OP_MEAN:     op_num = OP_SUM_ROWS_NUM_MEAN;     break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    const char * t0_str = ggml_type_name(op->src[0]->type);\n    const char * t_str  = ggml_type_name(op->type);\n\n    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;\n\n    snprintf(base, 256, \"kernel_sum_rows_%s_%s%s\", t0_str, t_str, is_c4 ? 
\"_4\" : \"\");\n    snprintf(name, 256, \"%s_op=%d\", base, op_num);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.smem = 32*sizeof(float);\n\n    if (is_c4) {\n        res.smem *= 4;\n    }\n\n    res.c4  = is_c4;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_ASSERT(op->op == GGML_OP_CUMSUM);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_cumsum_blk_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_ASSERT(op->op == GGML_OP_CUMSUM);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_cumsum_add_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_ASSERT(op->op == GGML_OP_TRI);\n    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));\n\n    char base[256];\n    char name[256];\n\n    const char * op_str = 
\"tri\";\n    const int ttype = op->op_params[0];\n\n    snprintf(base, 256, \"kernel_%s_%s_%d\", op_str, ggml_type_name(op->src[0]->type), ttype);\n\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);\n\n    char base[256];\n    char name[256];\n\n    const char * suffix = \"\";\n\n    if (op->src[0]->ne[0] % 4 == 0) {\n        suffix = \"_4\";\n    }\n\n    const ggml_type tsrc1 = op->src[1] ? op->src[1]->type : GGML_TYPE_F32;\n\n    snprintf(base, 256, \"kernel_soft_max_%s%s\", ggml_type_name(tsrc1), suffix);\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    res.smem = 32*sizeof(float);\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n    GGML_ASSERT(ggml_is_contiguous(op->src[1]));\n\n    char base[256];\n    char name[256];\n\n    const char * suffix = \"\";\n\n    if (op->src[1]->ne[0] % 4 == 0) {\n        suffix = \"_4\";\n    }\n\n    snprintf(base, 256, \"kernel_ssm_conv_%s_%s%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = 
ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n    GGML_ASSERT(ggml_is_contiguous(op->src[1]));\n\n    char base[256];\n    char name[256];\n\n    const char * suffix = \"\";\n    if (op->src[1]->ne[0] % 4 == 0) {\n        suffix = \"_4\";\n    }\n\n    snprintf(base, 256, \"kernel_ssm_conv_%s_%s_batched%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);\n    snprintf(name, 256, \"%s_ssm_conv_bs=%d\", base, ssm_conv_bs);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op)  {\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n\n    char base[256];\n    char name[256];\n\n    const int nsg = (ne00 + 31)/32;\n\n    snprintf(base, 256, \"kernel_ssm_scan_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s_nsg=%d\", base, nsg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    // Shared memory layout:\n    // - sgptg * NW floats for partial sums (nsg * 32)\n    // - sgptg floats for 
shared_x_dt (nsg)\n    // - sgptg floats for shared_dA (nsg)\n    // Total: nsg * (32 + 2) floats\n    res.smem = (32 + 2)*sizeof(float)*nsg;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {\n    char base[256];\n    char name[256];\n\n    const int64_t C = op->ne[0];\n    const int64_t H = op->src[0]->ne[1];\n\n    switch (op->op) {\n        case GGML_OP_RWKV_WKV6:\n            {\n                GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);\n                GGML_ASSERT(C % H == 0);\n                GGML_ASSERT(C / H == 64);\n\n                snprintf(base, 256, \"kernel_rwkv_wkv6_%s\", ggml_type_name(op->src[0]->type));\n            } break;\n        case GGML_OP_RWKV_WKV7:\n            {\n                GGML_ASSERT(op->src[6]->type == GGML_TYPE_F32);\n                GGML_ASSERT(C % H == 0);\n                GGML_ASSERT(C / H == 64);\n\n                snprintf(base, 256, \"kernel_rwkv_wkv7_%s\", ggml_type_name(op->src[0]->type));\n            } break;\n        default:\n            GGML_ABORT(\"fatal error\");\n    }\n\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {\n    char base[256];\n    char name[256];\n\n    // v is src[2], dimensions: S_v = ne[0], H = ne[1]\n    const int ne20 = op->src[2]->ne[0]; // S_v\n    const int ne21 = op->src[2]->ne[1]; // H\n    const int ne30 = op->src[3]->ne[0]; // G\n\n    const int nsg = op->src[2]->ne[0]/32;\n\n    GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->ne[0] == ne20 * ne21);\n    GGML_ASSERT(ne20 % 32 == 0);\n\n    snprintf(base, 256, 
\"kernel_gated_delta_net_%s_%d\", ggml_type_name(op->src[0]->type), nsg);\n    snprintf(name, 256, \"%s_ne20=%d_ne30=%d\", base, ne20, ne30);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);\n        ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.nsg = nsg;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {\n    char base[256];\n    char name[256];\n\n    const int nsg = 8;\n    const int n   = op->src[1]->ne[1];\n    const int k   = op->src[1]->ne[0];\n\n    snprintf(base, 256, \"kernel_solve_tri_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s_nsg=%d_n=%d_k=%d\", base, nsg, n, k);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);\n        ggml_metal_cv_set_int16(cv, n,   FC_SOLVE_TRI + 1);\n        ggml_metal_cv_set_int16(cv, k,   FC_SOLVE_TRI + 2);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.nsg  = nsg;\n    res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_mul_mv_ext_%s_%s_r1_%d\", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);\n    snprintf(name, 256, \"%s_nsg=%d_nxpsg=%d\", 
base, nsg, nxpsg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);\n        ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {\n    char base[256];\n    char name[256];\n\n    const ggml_type tsrc0 = op->src[0]->type;\n    const ggml_type tsrc1 = op->src[1]->type;\n\n    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;\n    const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;\n\n    snprintf(base, 256, \"kernel_mul_mm_%s_%s\", ggml_type_name(tsrc0), ggml_type_name(tsrc1));\n    snprintf(name, 256, \"%s_bci=%d_bco=%d\", base, bc_inp, bc_out);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);\n        ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes\n    res.smem = bc_out ? 
8192 : 4096 + 2048;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n\n    char base[256];\n    char name[256];\n\n    int nsg = 0; // number of simdgroups\n    int nr0 = 0; // number of src0 rows per simdgroup\n    int nr1 = 1; // number of src1 rows per threadgroup\n\n    size_t smem = 0; // shared memory\n\n    const ggml_type tsrc0 = op->src[0]->type;\n    const ggml_type tsrc1 = op->src[1]->type;\n\n    const char * suffix = \"\";\n\n    // use custom matrix x vector kernel\n    switch (tsrc0) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n            {\n                if (ne00 < 32) {\n                    nsg = 1;\n                    nr0 = 32;\n                    nr1 = 1;\n                    suffix = \"_short\";\n                } else {\n                    nsg = std::min(4, (ne00 + 127) / 128);\n                    nr0 = 2;\n                    nr1 = 1;\n                    smem = 32*sizeof(float)*nr0;\n                    suffix = ne00 % 4 == 0 ? 
\"_4\" : \"\";\n                }\n            } break;\n        case GGML_TYPE_Q4_0:\n            {\n                nsg = N_SG_Q4_0;\n                nr0 = N_R0_Q4_0;\n            } break;\n        case GGML_TYPE_Q4_1:\n            {\n                nsg = N_SG_Q4_1;\n                nr0 = N_R0_Q4_1;\n            } break;\n        case GGML_TYPE_Q5_0:\n            {\n                nsg = N_SG_Q5_0;\n                nr0 = N_R0_Q5_0;\n            } break;\n        case GGML_TYPE_Q5_1:\n            {\n                nsg = N_SG_Q5_1;\n                nr0 = N_R0_Q5_1;\n            } break;\n        case GGML_TYPE_Q8_0:\n            {\n                nsg = N_SG_Q8_0;\n                nr0 = N_R0_Q8_0;\n                smem = 32*sizeof(float)*N_R0_Q8_0;\n            } break;\n        case GGML_TYPE_MXFP4:\n            {\n                nsg = N_SG_MXFP4;\n                nr0 = N_R0_MXFP4;\n                smem = 32*sizeof(float);\n            } break;\n        case GGML_TYPE_Q2_K:\n            {\n                nsg = N_SG_Q2_K;\n                nr0 = N_R0_Q2_K;\n            } break;\n        case GGML_TYPE_Q3_K:\n            {\n                nsg = N_SG_Q3_K;\n                nr0 = N_R0_Q3_K;\n            } break;\n        case GGML_TYPE_Q4_K:\n            {\n                nsg = N_SG_Q4_K;\n                nr0 = N_R0_Q4_K;\n            } break;\n        case GGML_TYPE_Q5_K:\n            {\n                nsg = N_SG_Q5_K;\n                nr0 = N_R0_Q5_K;\n            } break;\n        case GGML_TYPE_Q6_K:\n            {\n                nsg = N_SG_Q6_K;\n                nr0 = N_R0_Q6_K;\n            } break;\n        case GGML_TYPE_IQ2_XXS:\n            {\n                nsg = N_SG_IQ2_XXS;\n                nr0 = N_R0_IQ2_XXS;\n                smem = 256*8+128;\n            } break;\n        case GGML_TYPE_IQ2_XS:\n            {\n                nsg = N_SG_IQ2_XS;\n                nr0 = N_R0_IQ2_XS;\n                smem = 512*8+128;\n            } break;\n      
  case GGML_TYPE_IQ3_XXS:\n            {\n                nsg = N_SG_IQ3_XXS;\n                nr0 = N_R0_IQ3_XXS;\n                smem = 256*4+128;\n            } break;\n        case GGML_TYPE_IQ3_S:\n            {\n                nsg = N_SG_IQ3_S;\n                nr0 = N_R0_IQ3_S;\n                smem = 512*4;\n            } break;\n        case GGML_TYPE_IQ2_S:\n            {\n                nsg = N_SG_IQ2_S;\n                nr0 = N_R0_IQ2_S;\n            } break;\n        case GGML_TYPE_IQ1_S:\n            {\n                nsg = N_SG_IQ1_S;\n                nr0 = N_R0_IQ1_S;\n            } break;\n        case GGML_TYPE_IQ1_M:\n            {\n                nsg = N_SG_IQ1_M;\n                nr0 = N_R0_IQ1_M;\n            } break;\n        case GGML_TYPE_IQ4_NL:\n            {\n                nsg = N_SG_IQ4_NL;\n                nr0 = N_R0_IQ4_NL;\n                smem = 32*sizeof(float);\n            } break;\n        case GGML_TYPE_IQ4_XS:\n            {\n                nsg = N_SG_IQ4_XS;\n                nr0 = N_R0_IQ4_XS;\n                smem = 32*sizeof(float);\n            } break;\n        default:\n            {\n                GGML_LOG_ERROR(\"Asserting on type %d\\n\", (int) tsrc0);\n                GGML_ABORT(\"not implemented\");\n            }\n    };\n\n    snprintf(base, 256, \"kernel_mul_mv_%s_%s%s\", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);\n    snprintf(name, 256, \"%s_nsg=%d\", base, nsg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.nr0  = nr0;\n    res.nr1  = nr1;\n    res.nsg  = nsg;\n    res.smem = smem;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params 
ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_mul_mm_id_map0_ne20_%d\", ne20);\n    snprintf(name, 256, \"%s_ne02=%d\", base, ne02);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    res.smem = (size_t) ne02*ne20*sizeof(uint16_t);\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {\n    char base[256];\n    char name[256];\n\n    const ggml_type tsrc0 = op->src[0]->type;\n    const ggml_type tsrc1 = op->src[1]->type;\n\n    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;\n\n    snprintf(base, 256, \"kernel_mul_mm_id_%s_%s\", ggml_type_name(tsrc0), ggml_type_name(tsrc1));\n    snprintf(name, 256, \"%s_bci=%d\", base, bc_inp);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.smem = 8192;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n\n    char base[256];\n    char name[256];\n\n    int nsg = 0; // number of simdgroups\n    int nr0 = 0; // number of src0 rows per simdgroup\n    int nr1 = 1; // number of src1 rows per threadgroup\n\n    size_t smem = 0; // shared memory\n\n    const ggml_type tsrc0 = op->src[0]->type;\n    const ggml_type tsrc1 = op->src[1]->type;\n\n    const 
char * suffix = \"\";\n\n        // use custom matrix x vector kernel\n    switch (tsrc0) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n            {\n                nsg = std::min(4, (ne00 + 127) / 128);\n                nr0 = 2;\n                nr1 = 1;\n                smem = 32*sizeof(float)*nr0;\n                suffix = ne00 % 4 == 0 ? \"_4\" : \"\";\n            } break;\n        case GGML_TYPE_Q4_0:\n            {\n                nsg = N_SG_Q4_0;\n                nr0 = N_R0_Q4_0;\n            } break;\n        case GGML_TYPE_Q4_1:\n            {\n                nsg = N_SG_Q4_1;\n                nr0 = N_R0_Q4_1;\n            } break;\n        case GGML_TYPE_Q5_0:\n            {\n                nsg = N_SG_Q5_0;\n                nr0 = N_R0_Q5_0;\n            } break;\n        case GGML_TYPE_Q5_1:\n            {\n                nsg = N_SG_Q5_1;\n                nr0 = N_R0_Q5_1;\n            } break;\n        case GGML_TYPE_Q8_0:\n            {\n                nsg = N_SG_Q8_0;\n                nr0 = N_R0_Q8_0;\n                smem = 32*sizeof(float)*N_R0_Q8_0;\n            } break;\n        case GGML_TYPE_MXFP4:\n            {\n                nsg = N_SG_MXFP4;\n                nr0 = N_R0_MXFP4;\n                smem = 32*sizeof(float);\n            } break;\n        case GGML_TYPE_Q2_K:\n            {\n                nsg = N_SG_Q2_K;\n                nr0 = N_R0_Q2_K;\n            } break;\n        case GGML_TYPE_Q3_K:\n            {\n                nsg = N_SG_Q3_K;\n                nr0 = N_R0_Q3_K;\n            } break;\n        case GGML_TYPE_Q4_K:\n            {\n                nsg = N_SG_Q4_K;\n                nr0 = N_R0_Q4_K;\n            } break;\n        case GGML_TYPE_Q5_K:\n            {\n                nsg = N_SG_Q5_K;\n                nr0 = N_R0_Q5_K;\n            } break;\n        case GGML_TYPE_Q6_K:\n            {\n                nsg = N_SG_Q6_K;\n                nr0 = N_R0_Q6_K;\n           
 } break;\n        case GGML_TYPE_IQ2_XXS:\n            {\n                nsg = N_SG_IQ2_XXS;\n                nr0 = N_R0_IQ2_XXS;\n                smem = 256*8+128;\n            } break;\n        case GGML_TYPE_IQ2_XS:\n            {\n                nsg = N_SG_IQ2_XS;\n                nr0 = N_R0_IQ2_XS;\n                smem = 512*8+128;\n            } break;\n        case GGML_TYPE_IQ3_XXS:\n            {\n                nsg = N_SG_IQ3_XXS;\n                nr0 = N_R0_IQ3_XXS;\n                smem = 256*4+128;\n            } break;\n        case GGML_TYPE_IQ3_S:\n            {\n                nsg = N_SG_IQ3_S;\n                nr0 = N_R0_IQ3_S;\n                smem = 512*4;\n            } break;\n        case GGML_TYPE_IQ2_S:\n            {\n                nsg = N_SG_IQ2_S;\n                nr0 = N_R0_IQ2_S;\n            } break;\n        case GGML_TYPE_IQ1_S:\n            {\n                nsg = N_SG_IQ1_S;\n                nr0 = N_R0_IQ1_S;\n            } break;\n        case GGML_TYPE_IQ1_M:\n            {\n                nsg = N_SG_IQ1_M;\n                nr0 = N_R0_IQ1_M;\n            } break;\n        case GGML_TYPE_IQ4_NL:\n            {\n                nsg = N_SG_IQ4_NL;\n                nr0 = N_R0_IQ4_NL;\n                smem = 32*sizeof(float);\n            } break;\n        case GGML_TYPE_IQ4_XS:\n            {\n                nsg = N_SG_IQ4_XS;\n                nr0 = N_R0_IQ4_XS;\n                smem = 32*sizeof(float);\n            } break;\n        default:\n            {\n                GGML_LOG_ERROR(\"Asserting on type %d\\n\", (int)op->src[2]->type);\n                GGML_ABORT(\"not implemented\");\n            }\n    };\n\n    snprintf(base, 256, \"kernel_mul_mv_id_%s_%s%s\", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);\n    snprintf(name, 256, \"%s_nsg=%d\", base, nsg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = 
ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.nr0  = nr0;\n    res.nr1  = nr1;\n    res.nsg  = nsg;\n    res.smem = smem;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));\n    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_argmax_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    res.smem = 32*(sizeof(float) + sizeof(int32_t));\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_ARGSORT);\n\n    char base[256];\n    char name[256];\n\n    ggml_sort_order order = (ggml_sort_order) op->op_params[0];\n\n    const char * order_str = \"undefined\";\n    switch (order) {\n        case GGML_SORT_ORDER_ASC:  order_str = \"asc\";  break;\n        case GGML_SORT_ORDER_DESC: order_str = \"desc\"; break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    snprintf(base, 256, \"kernel_argsort_%s_%s_%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return 
res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_ARGSORT);\n\n    char base[256];\n    char name[256];\n\n    ggml_sort_order order = (ggml_sort_order) op->op_params[0];\n\n    const char * order_str = \"undefined\";\n    switch (order) {\n        case GGML_SORT_ORDER_ASC:  order_str = \"asc\";  break;\n        case GGML_SORT_ORDER_DESC: order_str = \"desc\"; break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    snprintf(base, 256, \"kernel_argsort_merge_%s_%s_%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\n// note: reuse the argsort kernel for top_k\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_TOP_K);\n\n    char base[256];\n    char name[256];\n\n    // note: the top_k kernel is always descending order\n    ggml_sort_order order = GGML_SORT_ORDER_DESC;\n\n    const char * order_str = \"undefined\";\n    switch (order) {\n        case GGML_SORT_ORDER_ASC:  order_str = \"asc\";  break;\n        case GGML_SORT_ORDER_DESC: order_str = \"desc\"; break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    snprintf(base, 256, \"kernel_argsort_%s_%s_%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params 
ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_TOP_K);\n\n    char base[256];\n    char name[256];\n\n    ggml_sort_order order = GGML_SORT_ORDER_DESC;\n\n    const char * order_str = \"undefined\";\n    switch (order) {\n        case GGML_SORT_ORDER_ASC:  order_str = \"asc\";  break;\n        case GGML_SORT_ORDER_DESC: order_str = \"desc\"; break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    snprintf(base, 256, \"kernel_argsort_merge_%s_%s_%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(\n        ggml_metal_library_t lib,\n        const struct ggml_tensor * op,\n        bool    has_mask,\n        int32_t ncpsg) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n    GGML_UNUSED(op);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_%s\",\n            \"flash_attn_ext_pad\");\n\n    snprintf(name, 256, \"%s_mask=%d_ncpsg=%d\",\n            base,\n            has_mask,\n            ncpsg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_PAD + 0);\n        //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);\n        //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);\n        //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);\n\n        //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);\n        //ggml_metal_cv_set_int32(cv, ns20, 
FC_FLASH_ATTN_EXT_PAD + 21);\n        //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_PAD + 22);\n        //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_PAD + 23);\n        //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);\n        ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(\n        ggml_metal_library_t lib,\n        const struct ggml_tensor * op,\n        int32_t nqptg,\n        int32_t ncpsg) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n    GGML_UNUSED(op);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_%s\",\n            \"flash_attn_ext_blk\");\n\n    snprintf(name, 256, \"%s_nqptg=%d_ncpsg=%d\",\n            base,\n            nqptg,\n            ncpsg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        //ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_BLK + 0);\n        //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);\n        //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_BLK + 2);\n        //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_BLK + 3);\n\n        //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);\n        //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);\n        //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_BLK + 22);\n        //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_BLK + 23);\n        ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);\n        ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        
ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(\n        ggml_metal_library_t lib,\n        const ggml_tensor * op,\n        bool    has_mask,\n        bool    has_sinks,\n        bool    has_bias,\n        bool    has_scap,\n        bool    has_kvpad,\n        int32_t nsg) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n\n    char base[256];\n    char name[256];\n\n    const int32_t dk = (int32_t) op->src[1]->ne[0];\n    const int32_t dv = (int32_t) op->src[2]->ne[0];\n\n    const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];\n    const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];\n\n    // do bounds checks for the mask?\n    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);\n\n    snprintf(base, 256, \"kernel_%s_%s_dk%d_dv%d\",\n            \"flash_attn_ext\",\n            ggml_type_name(op->src[1]->type),\n            dk,\n            dv);\n\n    snprintf(name, 256, \"%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d\",\n            base,\n            has_mask,\n            has_sinks,\n            has_bias,\n            has_scap,\n            has_kvpad,\n            bc_mask,\n            ns10,\n            ns20,\n            nsg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT + 0);\n        ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);\n        ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);\n        ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);\n        ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);\n\n        ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);\n\n        ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);\n        ggml_metal_cv_set_int32(cv, 
ns20, FC_FLASH_ATTN_EXT + 21);\n        ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT + 22);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(\n        ggml_metal_library_t lib,\n        const ggml_tensor * op,\n        bool    has_mask,\n        bool    has_sinks,\n        bool    has_bias,\n        bool    has_scap,\n        bool    has_kvpad,\n        int32_t nsg,\n        int32_t nwg) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n\n    char base[256];\n    char name[256];\n\n    const int32_t dk = (int32_t) op->src[1]->ne[0];\n    const int32_t dv = (int32_t) op->src[2]->ne[0];\n\n    const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];\n    const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];\n\n    snprintf(base, 256, \"kernel_%s_%s_dk%d_dv%d\",\n            \"flash_attn_ext_vec\",\n            ggml_type_name(op->src[1]->type),\n            dk,\n            dv);\n\n    snprintf(name, 256, \"%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d\",\n            base,\n            has_mask,\n            has_sinks,\n            has_bias,\n            has_scap,\n            has_kvpad,\n            ns10,\n            ns20,\n            nsg, nwg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_VEC + 0);\n        ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);\n        ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);\n        ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);\n        ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);\n\n        ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 
20);\n        ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);\n        ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_VEC + 22);\n        ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_VEC + 23);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(\n        ggml_metal_library_t lib,\n        const ggml_tensor * op,\n        int32_t dv,\n        int32_t nwg) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_flash_attn_ext_vec_reduce\");\n    snprintf(name, 256, \"%s_dv=%d_nwg=%d\", base, dv, nwg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int32(cv, dv,  FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);\n        ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n\n    GGML_UNUSED(op);\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {\n    char base[256];\n    char name[256];\n\n    int op_num = -1;\n\n    switch (op->op) {\n        case GGML_OP_ADD: op_num = 0; break;\n        case GGML_OP_SUB: op_num = 1; break;\n        case GGML_OP_MUL: op_num = 2; break;\n        case GGML_OP_DIV: op_num = 3; break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    const char * t0_str = ggml_type_name(op->src[0]->type);\n    const char * t1_str = ggml_type_name(op->src[1]->type);\n    const char * t_str  = ggml_type_name(op->type);\n\n    const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 
== 0);\n\n    const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];\n    const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;\n\n    snprintf(base, 256, \"kernel_bin_fuse_%s_%s_%s%s\", t0_str, t1_str, t_str, is_c4 ? \"_4\" : \"\");\n    snprintf(name, 256, \"%s_op=%d_nf=%d_rb=%d_cb=%d\", base, op_num, n_fuse, is_rb, is_cb);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);\n        ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);\n        ggml_metal_cv_set_bool (cv, is_rb,  FC_BIN + 2);\n        ggml_metal_cv_set_bool (cv, is_cb,  FC_BIN + 3);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.c4  = is_c4;\n    res.cnt = is_rb;\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {\n    char base[256];\n    char name[256];\n\n    int op_num = -1;\n\n    switch (op) {\n        case GGML_OP_ADD: op_num = 0; break;\n        case GGML_OP_SUB: op_num = 1; break;\n        case GGML_OP_MUL: op_num = 2; break;\n        case GGML_OP_DIV: op_num = 3; break;\n        default: GGML_ABORT(\"fatal error\");\n    };\n\n    snprintf(base, 256, \"kernel_bin_fuse_%s_%s_%s\", \"f32\", \"f32\", \"f32\");\n    snprintf(name, 256, \"%s_op=%d_nf=%d\", base, op_num, 1);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);\n        ggml_metal_cv_set_int16(cv, 1,      FC_BIN + 1);\n        ggml_metal_cv_set_bool (cv, false,  FC_BIN + 2);\n\n        res = 
ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_L2_NORM);\n\n    char base[256];\n    char name[256];\n\n    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;\n\n    const char * t0_str = ggml_type_name(op->src[0]->type);\n    const char * t_str  = ggml_type_name(op->type);\n\n    snprintf(base, 256, \"kernel_l2_norm_%s_%s%s\", t0_str, t_str, is_c4 ? \"_4\" : \"\");\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    res.c4   = is_c4;\n    res.smem = 32*sizeof(float);\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_GROUP_NORM);\n\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_group_norm_f32\");\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    res.smem = 32*sizeof(float);\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {\n    assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    char base[256];\n    char name[256];\n\n    const char * suffix = \"\";\n    if (op->ne[0] % 4 == 0) {\n        suffix = \"_4\";\n    }\n\n    switch (op->op) {\n        case GGML_OP_NORM:\n   
         switch (n_fuse) {\n                case 1: snprintf(base, 256, \"kernel_norm_f32%s\", suffix);         break;\n                case 2: snprintf(base, 256, \"kernel_norm_mul_f32%s\", suffix);     break;\n                case 3: snprintf(base, 256, \"kernel_norm_mul_add_f32%s\", suffix); break;\n                default: GGML_ABORT(\"fatal error\");\n            } break;\n        case GGML_OP_RMS_NORM:\n            switch (n_fuse) {\n                case 1: snprintf(base, 256, \"kernel_rms_norm_f32%s\", suffix);         break;\n                case 2: snprintf(base, 256, \"kernel_rms_norm_mul_f32%s\", suffix);     break;\n                case 3: snprintf(base, 256, \"kernel_rms_norm_mul_add_f32%s\", suffix); break;\n                default: GGML_ABORT(\"fatal error\");\n            } break;\n        default: GGML_ABORT(\"fatal error\");\n    }\n\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    res.smem = 32*sizeof(float);\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_ROPE);\n\n    char base[256];\n    char name[256];\n\n    const int mode = ((const int32_t *) op->op_params)[2];\n\n    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;\n    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;\n    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;\n    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n\n    if (is_neox) {\n        snprintf(base, 256, \"kernel_rope_neox_%s\", ggml_type_name(op->src[0]->type));\n    } else if ((is_mrope || is_imrope) && !is_vision) {\n        GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token\n        snprintf(base, 256, \"kernel_rope_multi_%s\", 
ggml_type_name(op->src[0]->type));\n    } else if (is_vision) {\n        GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token\n        snprintf(base, 256, \"kernel_rope_vision_%s\", ggml_type_name(op->src[0]->type));\n    } else {\n        snprintf(base, 256, \"kernel_rope_norm_%s\", ggml_type_name(op->src[0]->type));\n    }\n\n    snprintf(name, 256, \"%s_imrope=%d\", base, is_imrope ? 1 : 0);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_IM2COL);\n\n    GGML_ASSERT(ggml_is_contiguous(op->src[1]));\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->type         == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_im2col_%s\", ggml_type_name(op->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);\n\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n    GGML_ASSERT(ggml_is_contiguous(op->src[1]));\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->type       
  == GGML_TYPE_F32);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_conv_transpose_1d_%s_%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);\n\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n    GGML_ASSERT(ggml_is_contiguous(op->src[1]));\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->type         == GGML_TYPE_F32);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_conv_transpose_2d_%s_%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_CONV_2D);\n\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->type         == GGML_TYPE_F32);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_conv_2d_%s_%s\", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));\n    snprintf(name, 256, \"%s\", 
base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_UPSCALE);\n\n    char base[256];\n    char name[256];\n\n    const int32_t mode_flags = ggml_get_op_params_i32(op, 0);\n    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);\n\n    const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);\n\n    if (mode == GGML_SCALE_MODE_BILINEAR) {\n        snprintf(base, 256, \"kernel_upscale_bilinear_%s\", ggml_type_name(op->src[0]->type));\n    } else if (mode == GGML_SCALE_MODE_BICUBIC) {\n        snprintf(base, 256, \"kernel_upscale_bicubic_%s\", ggml_type_name(op->src[0]->type));\n    } else {\n        snprintf(base, 256, \"kernel_upscale_nearest_%s\", ggml_type_name(op->src[0]->type));\n    }\n    snprintf(name, 256, \"%s_aa=%d\", base, antialias);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_PAD);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_pad_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (res.pipeline) {\n        return res;\n    }\n\n    res = 
ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_PAD_REFLECT_1D);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_pad_reflect_1d_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_ARANGE);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_arange_%s\", ggml_type_name(op->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_timestep_embedding_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_OPT_STEP_ADAMW);\n\n    char 
base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_opt_step_adamw_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {\n    assert(op->op == GGML_OP_OPT_STEP_SGD);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_opt_step_sgd_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor *  op) {\n    GGML_ASSERT(op->type == GGML_TYPE_I64);\n\n    char base[256];\n    char name[256];\n\n    snprintf(base, 256, \"kernel_memset_%s\", ggml_type_name(op->type));\n    snprintf(name, 256, \"%s\", base);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);\n    }\n\n    return res;\n}\n\nggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor *  op) {\n    assert(op->op == GGML_OP_COUNT_EQUAL);\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);\n\n    GGML_ASSERT(op->src[0]->type == op->src[1]->type);\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);\n    GGML_ASSERT(op->type == GGML_TYPE_I64);\n\n    // note: the kernel only supports i32 output due to metal atomic add only supporting 
atomic_int\n    GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));\n\n    char base[256];\n    char name[256];\n\n    int nsg = 1;\n    while (32*nsg < ne00 && nsg < 32) {\n        nsg *= 2;\n    }\n\n    snprintf(base, 256, \"kernel_count_equal_%s\", ggml_type_name(op->src[0]->type));\n    snprintf(name, 256, \"%s_nsg=%d\", base, nsg);\n\n    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);\n    if (!res.pipeline) {\n        ggml_metal_cv_t cv = ggml_metal_cv_init();\n\n        ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);\n\n        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);\n\n        ggml_metal_cv_free(cv);\n    }\n\n    res.smem = 32 * sizeof(int32_t);\n    res.nsg  = nsg;\n\n    return res;\n}\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-device.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nstruct ggml_metal_buffer_id {\n    void * metal; // id<MTLBuffer>\n    size_t offs;\n};\n\ntypedef struct ggml_metal_device * ggml_metal_device_t;\n\n//\n// MTLFunctionConstantValues wrapper\n//\n\ntypedef struct ggml_metal_cv * ggml_metal_cv_t;\n\nggml_metal_cv_t ggml_metal_cv_init(void);\nvoid ggml_metal_cv_free(ggml_metal_cv_t cv);\n\nvoid ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);\nvoid ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);\nvoid ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool    value, int32_t idx);\n\n//\n// MTLComputePipelineState wrapper\n//\n\ntypedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;\n\nggml_metal_pipeline_t ggml_metal_pipeline_init(void);\nvoid ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);\n\n// a collection of pipelines\ntypedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;\n\nggml_metal_pipelines_t ggml_metal_pipelines_init(void);\nvoid ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);\n\nvoid                  ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);\nggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);\n\nstruct ggml_metal_pipeline_with_params {\n    ggml_metal_pipeline_t pipeline;\n\n    int nsg;\n\n    int nr0;\n    int nr1;\n\n    size_t smem;\n\n    bool c4;\n    bool cnt;\n};\n\nint ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);\n\n//\n// MTLCommandBuffer wrapper\n//\n\ntypedef void * ggml_metal_cmd_buf_t;\n\n//\n// MTLComputeCommandEncoder wrapper\n//\n\ntypedef struct ggml_metal_encoder * ggml_metal_encoder_t;\n\nggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent);\nvoid ggml_metal_encoder_free(ggml_metal_encoder_t encoder);\n\nvoid 
ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);\nvoid ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);\n\nvoid ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline);\n\nvoid ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);\nvoid ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);\n\nvoid ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx);\n\nvoid ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2);\n\nvoid ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder);\n\nvoid ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);\n\n//\n// MTLLibrary wrapper\n//\n\ntypedef struct ggml_metal_library * ggml_metal_library_t;\n\nggml_metal_library_t ggml_metal_library_init            (ggml_metal_device_t dev);\nggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose);\n\nvoid ggml_metal_library_free(ggml_metal_library_t lib);\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline    (ggml_metal_library_t lib, const char * name);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base              (ggml_metal_library_t lib, enum ggml_op op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy               (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool 
op_pool);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag              (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum               (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows          (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk        (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add        (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri               (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max          (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct 
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv          (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched  (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net   (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv            (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0    (ggml_metal_library_t lib, int ne02, int ne20);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct 
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k             (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one           (ggml_metal_library_t lib, enum ggml_op op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope              (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col            (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params 
ggml_metal_library_get_pipeline_conv_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale           (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad               (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset            (ggml_metal_library_t lib, const struct ggml_tensor * op);\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal       (ggml_metal_library_t lib, const struct ggml_tensor * op);\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(\n        ggml_metal_library_t lib,\n        const struct ggml_tensor * op,\n        bool    has_mask,\n        int32_t ncpsg);\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(\n        ggml_metal_library_t lib,\n        const struct ggml_tensor * op,\n        int32_t nqptg,\n        int32_t ncpsg);\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(\n        
ggml_metal_library_t lib,\n        const struct ggml_tensor * op,\n        bool    has_mask,\n        bool    has_sinks,\n        bool    has_bias,\n        bool    has_scap,\n        bool    has_kvpad,\n        int32_t nsg);\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(\n        ggml_metal_library_t lib,\n        const struct ggml_tensor * op,\n        bool    has_mask,\n        bool    has_sinks,\n        bool    has_bias,\n        bool    has_scap,\n        bool    has_kvpad,\n        int32_t nsg,\n        int32_t nwg);\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(\n        ggml_metal_library_t lib,\n        const struct ggml_tensor * op,\n        int32_t dv,\n        int32_t nwg);\n\n// MTLResidencySet wrapper\n\ntypedef void * ggml_metal_rset_t;\n\n// a collection of residency sets (non-owning)\ntypedef struct ggml_metal_rsets * ggml_metal_rsets_t;\n\nggml_metal_rsets_t ggml_metal_rsets_init(void);\nvoid ggml_metal_rsets_free(ggml_metal_rsets_t rsets);\n\n//\n// device\n//\n\nstruct ggml_metal_device_props {\n    int device;\n    char name[128];\n    char desc[128];\n\n    size_t max_buffer_size;\n    size_t max_working_set_size;\n    size_t max_theadgroup_memory_size;\n\n    bool has_simdgroup_reduction;\n    bool has_simdgroup_mm;\n    bool has_unified_memory;\n    bool has_bfloat;\n    bool has_tensor;\n    bool use_residency_sets;\n    bool use_shared_buffers;\n\n    bool supports_gpu_family_apple7;\n\n    int op_offload_min_batch_size;\n};\n\ntypedef struct ggml_metal_event * ggml_metal_event_t;\n\nvoid ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);\nvoid ggml_metal_event_encode_wait  (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);\n\nggml_metal_device_t ggml_metal_device_init(int device);\nvoid ggml_metal_device_free(ggml_metal_device_t dev);\n\nggml_metal_device_t ggml_metal_device_get(int device);\n\nvoid * 
ggml_metal_device_get_obj  (ggml_metal_device_t dev); // id<MTLDevice>\nvoid * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue>\n\nggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev);\n\nvoid ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset);\nvoid ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset);\n\nvoid ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev);\n\nggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev);\nvoid ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev);\nvoid ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev);\n\nvoid ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);\nbool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);\n\nconst struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev);\n\n//\n// device buffers\n//\n\ntypedef struct ggml_metal_buffer * ggml_metal_buffer_t;\n\nggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared);\nggml_metal_buffer_t ggml_metal_buffer_map (ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size);\n\nvoid   ggml_metal_buffer_free     (ggml_metal_buffer_t buf);\nvoid * ggml_metal_buffer_get_base (ggml_metal_buffer_t buf);\nbool   ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);\n\nvoid   ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);\nvoid   ggml_metal_buffer_set_tensor   (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);\nvoid   ggml_metal_buffer_get_tensor   (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);\nvoid   ggml_metal_buffer_clear        (ggml_metal_buffer_t buf, uint8_t 
value);\n\n// finds the Metal buffer that contains the tensor data on the GPU device\n// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the\n// Metal buffer based on the host memory pointer\n//\nstruct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-device.m",
    "content": "#import \"ggml-metal-device.h\"\n\n#import \"ggml-impl.h\"\n\n#include <Foundation/Foundation.h>\n\n#include <Metal/Metal.h>\n\n#include <stdatomic.h>\n\n#ifndef TARGET_OS_VISION\n#define TARGET_OS_VISION 0\n#endif\n\n// create residency sets only on macOS >= 15.0\n#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \\\n    TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \\\n    TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \\\n    TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000\n#define GGML_METAL_HAS_RESIDENCY_SETS 1\n#endif\n\n// overload of MTLGPUFamilyMetalX (not available in some environments)\nstatic const NSInteger MTLGPUFamilyMetal3_GGML = 5001;\nstatic const NSInteger MTLGPUFamilyMetal4_GGML = 5002;\n\n#if !GGML_METAL_EMBED_LIBRARY\n// Here to assist with NSBundle Path Hack\n@interface GGMLMetalClass : NSObject\n@end\n@implementation GGMLMetalClass\n@end\n#endif\n\n//\n// MTLFunctionConstantValues wrapper\n//\n\nstruct ggml_metal_cv {\n    MTLFunctionConstantValues * obj;\n};\n\nggml_metal_cv_t ggml_metal_cv_init(void) {\n    ggml_metal_cv_t res = calloc(1, sizeof(struct ggml_metal_cv));\n\n    res->obj = [[MTLFunctionConstantValues alloc] init];\n\n    return res;\n}\n\nvoid ggml_metal_cv_free(ggml_metal_cv_t cv) {\n    [cv->obj release];\n    free(cv);\n}\n\nvoid ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {\n    [cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];\n}\n\nvoid ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {\n    [cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];\n}\n\nvoid ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {\n    [cv->obj setConstantValue:&value type:MTLDataTypeBool atIndex:idx];\n}\n\n//\n// MTLComputePipelineState wrapper\n//\n\nstruct ggml_metal_pipeline {\n    id<MTLComputePipelineState> obj;\n};\n\nggml_metal_pipeline_t 
ggml_metal_pipeline_init(void) {\n    ggml_metal_pipeline_t res = calloc(1, sizeof(struct ggml_metal_pipeline));\n\n    *res = (struct ggml_metal_pipeline) {\n        /*.obj  =*/ nil,\n    };\n\n    return res;\n}\n\nvoid ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {\n    [pipeline->obj release];\n\n    free(pipeline);\n}\n\nint ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) {\n    return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;\n}\n\nstruct ggml_metal_library {\n    id<MTLLibrary> obj;\n    id<MTLDevice> device;\n\n    ggml_metal_pipelines_t pipelines; // cache of compiled pipelines\n\n    NSLock * lock;\n};\n\nggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {\n    id<MTLLibrary> library = nil;\n    id<MTLDevice> device = ggml_metal_device_get_obj(dev);\n\n    // load library\n    //\n    // - first check if the library is embedded\n    // - then check if the library is in the bundle\n    // - if not found, load the source and compile it\n    // - if that fails, return NULL\n    //\n    // TODO: move to a function\n    {\n        const int64_t t_start = ggml_time_us();\n\n        NSError * error = nil;\n        NSString * src = nil;\n\n#if GGML_METAL_EMBED_LIBRARY\n        GGML_LOG_INFO(\"%s: using embedded metal library\\n\", __func__);\n\n        extern const char ggml_metallib_start[];\n        extern const char ggml_metallib_end[];\n\n        src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];\n#else\n\n#ifdef SWIFT_PACKAGE\n        NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;\n#else\n        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];\n#endif\n\n        NSString * path_lib = [bundle pathForResource:@\"default\" ofType:@\"metallib\"];\n        if (path_lib == nil) {\n            // Try to find the resource in the directory where the current binary located.\n 
           NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];\n            NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];\n\n            NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @\"default.metallib\"]];\n            if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {\n                GGML_LOG_INFO(\"%s: found '%s'\\n\", __func__, [path_lib_default UTF8String]);\n\n                NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];\n                if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {\n                    // Optionally, if this is a symlink, try to resolve it.\n                    path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];\n                    if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@\"/\"]) {\n                        // It is a relative path, adding the binary directory as directory prefix.\n                        path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];\n                    }\n                    if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {\n                        // Link to the resource could not be resolved.\n                        path_lib_default = nil;\n                    } else {\n                        GGML_LOG_INFO(\"%s: symlink resolved '%s'\\n\", __func__, [path_lib_default UTF8String]);\n                    }\n                }\n            } else {\n                // The resource couldn't be found in the binary's directory.\n                path_lib_default = nil;\n            }\n\n            path_lib = path_lib_default;\n        }\n\n        if (path_lib != nil) {\n            // pre-compiled library found\n            NSURL * libURL = [NSURL 
fileURLWithPath:path_lib];\n            GGML_LOG_INFO(\"%s: loading '%s'\\n\", __func__, [path_lib UTF8String]);\n\n            library = [device newLibraryWithURL:libURL error:&error];\n            if (error) {\n                GGML_LOG_ERROR(\"%s: error: %s\\n\", __func__, [[error description] UTF8String]);\n                return nil;\n            }\n        } else {\n            GGML_LOG_INFO(\"%s: default.metallib not found, loading from source\\n\", __func__);\n\n            NSString * path_source;\n            NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@\"GGML_METAL_PATH_RESOURCES\"];\n\n            GGML_LOG_INFO(\"%s: GGML_METAL_PATH_RESOURCES = %s\\n\", __func__, path_resource ? [path_resource UTF8String] : \"nil\");\n\n            if (path_resource) {\n                path_source = [path_resource stringByAppendingPathComponent:@\"ggml-metal.metal\"];\n            } else {\n                path_source = [bundle pathForResource:@\"ggml-metal\" ofType:@\"metal\"];\n            }\n\n            if (path_source == nil) {\n                GGML_LOG_WARN(\"%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\\n\", __func__);\n                path_source = @\"ggml-metal.metal\";\n            }\n\n            GGML_LOG_INFO(\"%s: loading '%s'\\n\", __func__, [path_source UTF8String]);\n\n            src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];\n            if (error) {\n                GGML_LOG_ERROR(\"%s: error: %s\\n\", __func__, [[error description] UTF8String]);\n                return nil;\n            }\n        }\n#endif\n\n        if (!library) {\n            @autoreleasepool {\n                // dictionary of preprocessor macros\n                NSMutableDictionary * prep = [NSMutableDictionary dictionary];\n\n                if (ggml_metal_device_get_props(dev)->has_bfloat) {\n                    [prep setObject:@\"1\" 
forKey:@\"GGML_METAL_HAS_BF16\"];\n                }\n\n                if (ggml_metal_device_get_props(dev)->has_tensor) {\n                    [prep setObject:@\"1\" forKey:@\"GGML_METAL_HAS_TENSOR\"];\n                }\n\n#if GGML_METAL_EMBED_LIBRARY\n                [prep setObject:@\"1\" forKey:@\"GGML_METAL_EMBED_LIBRARY\"];\n#endif\n\n                MTLCompileOptions * options = [MTLCompileOptions new];\n                options.preprocessorMacros = prep;\n\n                //[options setFastMathEnabled:false];\n\n                library = [device newLibraryWithSource:src options:options error:&error];\n                if (error) {\n                    GGML_LOG_ERROR(\"%s: error: %s\\n\", __func__, [[error description] UTF8String]);\n                    return nil;\n                }\n\n#if !__has_feature(objc_arc)\n                [options release];\n#endif\n            }\n        }\n\n#if GGML_METAL_EMBED_LIBRARY\n        [src release];\n#endif // GGML_METAL_EMBED_LIBRARY\n\n        GGML_LOG_INFO(\"%s: loaded in %.3f sec\\n\", __func__, (ggml_time_us() - t_start) / 1e6);\n    }\n\n    ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));\n\n    res->obj       = library;\n    res->device    = device;\n    res->pipelines = ggml_metal_pipelines_init();\n    res->lock      = [NSLock new];\n\n    return res;\n}\n\nggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {\n    if (source == NULL) {\n        GGML_LOG_ERROR(\"%s: source is NULL\\n\", __func__);\n        return NULL;\n    }\n\n    id<MTLDevice> device = ggml_metal_device_get_obj(dev);\n    id<MTLLibrary> library = nil;\n    NSError * error = nil;\n\n    const int64_t t_start = ggml_time_us();\n\n    NSString * src = [[NSString alloc] initWithBytes:source\n                                              length:strlen(source)\n                                            encoding:NSUTF8StringEncoding];\n    if (!src) {\n    
    GGML_LOG_ERROR(\"%s: failed to create NSString from source\\n\", __func__);\n        return NULL;\n    }\n\n    @autoreleasepool {\n        NSMutableDictionary * prep = [NSMutableDictionary dictionary];\n\n        MTLCompileOptions * options = [MTLCompileOptions new];\n        options.preprocessorMacros = prep;\n\n        library = [device newLibraryWithSource:src options:options error:&error];\n        if (error) {\n            if (verbose) {\n                GGML_LOG_ERROR(\"%s: error compiling source: %s\\n\", __func__, [[error description] UTF8String]);\n            } else {\n                GGML_LOG_ERROR(\"%s: error compiling source\\n\", __func__);\n            }\n            library = nil;\n        }\n\n        [options release];\n    }\n\n    [src release];\n\n    if (!library) {\n        if (verbose) {\n            GGML_LOG_ERROR(\"%s: failed to create Metal library from source\\n\", __func__);\n        }\n\n        return NULL;\n    }\n\n    if (verbose) {\n        GGML_LOG_INFO(\"%s: compiled in %.3f sec\\n\", __func__, (ggml_time_us() - t_start) / 1e6);\n    }\n\n    ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));\n    if (!res) {\n        GGML_LOG_ERROR(\"%s: calloc failed\\n\", __func__);\n        return NULL;\n    }\n\n    res->obj       = library;\n    res->device    = device;\n    res->pipelines = ggml_metal_pipelines_init();\n    res->lock      = [NSLock new];\n\n    return res;\n}\n\nvoid ggml_metal_library_free(ggml_metal_library_t lib) {\n    if (!lib) {\n        return;\n    }\n\n    if (lib->obj) {\n        [lib->obj release];\n    }\n\n    ggml_metal_pipelines_free(lib->pipelines);\n\n    [lib->lock release];\n\n    free(lib);\n}\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {\n    [lib->lock lock];\n\n    struct ggml_metal_pipeline_with_params res = {\n        /*.pipeline =*/ nil,\n        /*.nsg      =*/ 0,\n        /*.nr0      =*/ 0,\n   
     /*.nr1      =*/ 0,\n        /*.smem     =*/ 0,\n        /*.c4       =*/ false,\n        /*.cnt      =*/ false,\n    };\n\n    res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);\n\n    [lib->lock unlock];\n\n    return res;\n}\n\nstruct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {\n    struct ggml_metal_pipeline_with_params res = {\n        /*.pipeline =*/ nil,\n        /*.nsg      =*/ 0,\n        /*.nr0      =*/ 0,\n        /*.nr1      =*/ 0,\n        /*.smem     =*/ 0,\n        /*.c4       =*/ false,\n        /*.cnt      =*/ false,\n    };\n\n    [lib->lock lock];\n\n    res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);\n    if (res.pipeline) {\n        [lib->lock unlock];\n\n        return res;\n    }\n\n    @autoreleasepool {\n        NSError * error = nil;\n\n        NSString * base_func = [NSString stringWithUTF8String:base];\n\n        GGML_LOG_DEBUG(\"%s: compiling pipeline: base = '%s', name = '%s'\\n\", __func__, base, name);\n\n        id<MTLFunction> mtl_function;\n        if (!cv) {\n            mtl_function = [lib->obj newFunctionWithName:base_func];\n        } else {\n            mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];\n        }\n        if (!mtl_function) {\n            [lib->lock unlock];\n\n            GGML_LOG_ERROR(\"%s: failed to compile pipeline: base = '%s', name = '%s'\\n\", __func__, base, name);\n            if (error) {\n                GGML_LOG_ERROR(\"%s: %s\\n\", __func__, [[error description] UTF8String]);\n            }\n\n            return res;\n        }\n\n        id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];\n\n        [mtl_function release];\n\n        if (!obj) {\n            [lib->lock unlock];\n\n            GGML_LOG_ERROR(\"%s: failed to create pipeline state: base = 
'%s', name = '%s'\\n\", __func__, base, name);\n            if (error) {\n                GGML_LOG_ERROR(\"%s: %s\\n\", __func__, [[error description] UTF8String]);\n            }\n\n            return res;\n        }\n\n        GGML_LOG_DEBUG(\"%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\\n\", __func__, name,\n                (void *) obj,\n                (int)    obj.maxTotalThreadsPerThreadgroup,\n                (int)    obj.threadExecutionWidth);\n\n        if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) {\n            [obj release];\n\n            [lib->lock unlock];\n\n            GGML_LOG_ERROR(\"%s: incompatible pipeline %s\\n\", __func__, name);\n\n            return res;\n        }\n\n        res.pipeline = ggml_metal_pipeline_init();\n        res.pipeline->obj = obj;\n\n        ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);\n    }\n\n    [lib->lock unlock];\n\n    return res;\n}\n\n//\n// MTLComputeCommandEncoder wrapper\n//\n\nstruct ggml_metal_encoder {\n    id<MTLComputeCommandEncoder> obj;\n};\n\nggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent) {\n    ggml_metal_encoder_t res = calloc(1, sizeof(struct ggml_metal_encoder));\n\n    id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw;\n\n    if (concurrent) {\n        res->obj = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];\n    } else {\n        res->obj = [cmd_buf computeCommandEncoder];\n    }\n\n    [res->obj retain];\n\n    return res;\n}\n\nvoid ggml_metal_encoder_free(ggml_metal_encoder_t encoder) {\n    [encoder->obj release];\n    free(encoder);\n}\n\nvoid ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name) {\n    [encoder->obj pushDebugGroup:[NSString stringWithCString:name encoding:NSUTF8StringEncoding]];\n}\n\nvoid ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) {\n    [encoder->obj 
popDebugGroup];\n}\n\nvoid ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) {\n    [encoder->obj setComputePipelineState:pipeline.pipeline->obj];\n}\n\nvoid ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) {\n    [encoder->obj setBytes:data length:size atIndex:idx];\n}\n\nvoid ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx) {\n    [encoder->obj setBuffer:buffer.metal offset:buffer.offs atIndex:idx];\n}\n\nvoid ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx) {\n    [encoder->obj setThreadgroupMemoryLength:size atIndex:idx];\n}\n\nvoid ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2) {\n    [encoder->obj dispatchThreadgroups:MTLSizeMake(tg0, tg1, tg2) threadsPerThreadgroup:MTLSizeMake(tptg0, tptg1, tptg2)];\n}\n\nvoid ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder) {\n    [encoder->obj memoryBarrierWithScope:MTLBarrierScopeBuffers];\n}\n\nvoid ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) {\n    [encoder->obj endEncoding];\n}\n\nstruct ggml_metal_device {\n    id<MTLDevice> mtl_device;\n\n    // a single global queue shared by all Metal backends\n    // technically not needed for devices with unified memory, but enables discrete GPUs support\n    // ref: https://github.com/ggml-org/llama.cpp/pull/15906\n    id<MTLCommandQueue> mtl_queue;\n\n    ggml_metal_rsets_t rsets;\n\n    ggml_metal_library_t library;\n\n    struct ggml_metal_device_props props;\n\n    // virtual address for GPU memory allocations\n    atomic_uintptr_t addr_virt;\n};\n\n//\n// MTLResidenceSet wrapper\n//\n\nstruct ggml_metal_rsets {\n    NSLock * lock;\n\n    NSMutableArray * data;\n\n    // number of seconds since the last graph computation\n    // keep the residency sets 
wired for that amount of time to avoid being collected by the OS\n    int keep_alive_s;\n\n    // background heartbeat thread to keep the residency sets alive\n    atomic_bool d_stop;\n    atomic_int  d_loop;\n\n    dispatch_group_t d_group;\n};\n\nggml_metal_rsets_t ggml_metal_rsets_init(void) {\n    ggml_metal_rsets_t res = calloc(1, sizeof(struct ggml_metal_rsets));\n\n    res->lock = [[NSLock alloc] init];\n    res->data = [[NSMutableArray alloc] init];\n\n    // by default keep the memory wired for 3 minutes\n    res->keep_alive_s = 3*60;\n\n    const char * GGML_METAL_RESIDENCY_KEEP_ALIVE_S = getenv(\"GGML_METAL_RESIDENCY_KEEP_ALIVE_S\");\n    if (GGML_METAL_RESIDENCY_KEEP_ALIVE_S) {\n        res->keep_alive_s = atoi(GGML_METAL_RESIDENCY_KEEP_ALIVE_S);\n    }\n\n    if (res->keep_alive_s <= 0) {\n        res->keep_alive_s = 3*60;\n    }\n\n    GGML_LOG_INFO(\"%s: creating a residency set collection (keep_alive = %d s)\\n\", __func__, res->keep_alive_s);\n\n    atomic_store_explicit(&res->d_stop, false, memory_order_relaxed);\n    atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed);\n\n    res->d_group = dispatch_group_create();\n\n    // start a background thread that periodically requests residency for all the currently active sets in the collection\n    // the requests stop after a certain amount of time (keep_alive_s) of inactivity\n    dispatch_queue_t d_queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0);\n    dispatch_group_async(res->d_group, d_queue, ^{\n#if defined(GGML_METAL_HAS_RESIDENCY_SETS)\n        if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {\n              while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) {\n                  if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) {\n                      [res->lock lock];\n\n                      for (int i = 0; i < (int) res->data.count; ++i) {\n                          [res->data[i] requestResidency];\n         
             }\n\n                      atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed);\n\n                      [res->lock unlock];\n                  }\n\n                  // half a second\n                  usleep(500 * 1000);\n              }\n        }\n#endif\n    });\n\n    return res;\n}\n\nvoid ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {\n    if (rsets == NULL) {\n        return;\n    }\n\n    // note: if you hit this assert, most likely you haven't deallocated all Metal resources before exiting\n    GGML_ASSERT([rsets->data count] == 0);\n\n    atomic_store_explicit(&rsets->d_stop, true, memory_order_relaxed);\n\n    dispatch_group_wait(rsets->d_group, DISPATCH_TIME_FOREVER);\n    dispatch_release(rsets->d_group);\n\n    [rsets->data release];\n    [rsets->lock release];\n\n    free(rsets);\n}\n\nggml_metal_device_t ggml_metal_device_init(int device) {\n    ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device));\n\n    assert(dev != NULL);\n\n    if (dev->mtl_device == nil) {\n        dev->mtl_device = MTLCreateSystemDefaultDevice();\n\n        if (dev->mtl_device) {\n            dev->mtl_queue = [dev->mtl_device newCommandQueue];\n            if (dev->mtl_queue == nil) {\n                GGML_LOG_ERROR(\"%s: error: failed to create command queue\\n\", __func__);\n            }\n\n            dev->addr_virt = 0x000000400ULL;\n\n            dev->props.device = device;\n            dev->props.has_simdgroup_reduction  = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];\n            dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];\n\n            dev->props.has_simdgroup_mm = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];\n            dev->props.has_unified_memory = dev->mtl_device.hasUnifiedMemory;\n\n            dev->props.has_bfloat  = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];\n            dev->props.has_bfloat |= [dev->mtl_device 
supportsFamily:MTLGPUFamilyApple6];\n            if (getenv(\"GGML_METAL_BF16_DISABLE\") != NULL) {\n                dev->props.has_bfloat = false;\n            }\n\n            dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];\n            if (getenv(\"GGML_METAL_TENSOR_DISABLE\") != NULL) {\n                dev->props.has_tensor = false;\n            }\n\n            // note: disable the tensor API by default for old chips because with the current implementation it is not useful\n            // - M2 Ultra:   ~5% slower\n            // - M4, M4 Max: no significant difference\n            //\n            // TODO: try to update the tensor API kernels to at least match the simdgroup performance\n            if (getenv(\"GGML_METAL_TENSOR_ENABLE\") == NULL &&\n                ![[dev->mtl_device name] containsString:@\"M5\"] &&\n                ![[dev->mtl_device name] containsString:@\"M6\"] &&\n                ![[dev->mtl_device name] containsString:@\"A19\"] &&\n                ![[dev->mtl_device name] containsString:@\"A20\"]) {\n                GGML_LOG_WARN(\"%s: tensor API disabled for pre-M5 and pre-A19 devices\\n\", __func__);\n                dev->props.has_tensor = false;\n            }\n\n            // double-check that the tensor API compiles\n            if (dev->props.has_tensor) {\n                const char * src_tensor_f16 = \"\\n\"\n                    \"#include <metal_stdlib> \\n\"\n                    \"#include <metal_tensor> \\n\"\n                    \"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \\n\"\n                    \" \\n\"\n                    \"using namespace metal; \\n\"\n                    \"using namespace mpp::tensor_ops; \\n\"\n                    \" \\n\"\n                    \"kernel void dummy_kernel( \\n\"\n                    \"    tensor<device  half, dextents<int32_t, 2>> A [[buffer(0)]], \\n\"\n                    \"    tensor<device  half, dextents<int32_t, 2>> B 
[[buffer(1)]], \\n\"\n                    \"    device float * C [[buffer(2)]], \\n\"\n                    \"    uint2 tgid [[threadgroup_position_in_grid]]) \\n\"\n                    \"{ \\n\"\n                    \"    auto tA = A.slice(0, (int)tgid.y); \\n\"\n                    \"    auto tB = B.slice((int)tgid.x, 0); \\n\"\n                    \" \\n\"\n                    \"    matmul2d< \\n\"\n                    \"        matmul2d_descriptor(8, 8, dynamic_extent), \\n\"\n                    \"        execution_simdgroups<4>> mm; \\n\"\n                    \" \\n\"\n                    \"    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \\n\"\n                    \" \\n\"\n                    \"    auto sA = tA.slice(0, 0); \\n\"\n                    \"    auto sB = tB.slice(0, 0); \\n\"\n                    \"    mm.run(sB, sA, cT); \\n\"\n                    \" \\n\"\n                    \"    auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \\n\"\n                    \" \\n\"\n                    \"    cT.store(tC); \\n\"\n                    \"}\";\n\n                GGML_LOG_INFO(\"%s: testing tensor API for f16 support\\n\", __func__);\n                ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_f16, false);\n                if (lib == NULL) {\n                    GGML_LOG_WARN(\"%s: - the tensor API is not supported in this environment - disabling\\n\", __func__);\n                    dev->props.has_tensor = false;\n                } else {\n                    struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, \"dummy_kernel\", \"dummy_kernel\", nil);\n                    if (!ppl.pipeline) {\n                        GGML_LOG_WARN(\"%s: - the tensor API is not supported in this environment - disabling\\n\", __func__);\n                        dev->props.has_tensor = false;\n                  
  }\n\n                    ggml_metal_library_free(lib);\n                }\n            }\n\n            // try to compile a dummy kernel to determine if the tensor API is supported for bfloat\n            if (dev->props.has_tensor && dev->props.has_bfloat) {\n                const char * src_tensor_bf16 = \"\\n\"\n                    \"#include <metal_stdlib> \\n\"\n                    \"#include <metal_tensor> \\n\"\n                    \"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \\n\"\n                    \" \\n\"\n                    \"using namespace metal; \\n\"\n                    \"using namespace mpp::tensor_ops; \\n\"\n                    \" \\n\"\n                    \"kernel void dummy_kernel( \\n\"\n                    \"    tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \\n\"\n                    \"    tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \\n\"\n                    \"    device float * C [[buffer(2)]], \\n\"\n                    \"    uint2 tgid [[threadgroup_position_in_grid]]) \\n\"\n                    \"{ \\n\"\n                    \"    auto tA = A.slice(0, (int)tgid.y); \\n\"\n                    \"    auto tB = B.slice((int)tgid.x, 0); \\n\"\n                    \" \\n\"\n                    \"    matmul2d< \\n\"\n                    \"        matmul2d_descriptor(8, 8, dynamic_extent), \\n\"\n                    \"        execution_simdgroups<4>> mm; \\n\"\n                    \" \\n\"\n                    \"    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \\n\"\n                    \" \\n\"\n                    \"    auto sA = tA.slice(0, 0); \\n\"\n                    \"    auto sB = tB.slice(0, 0); \\n\"\n                    \"    mm.run(sB, sA, cT); \\n\"\n                    \" \\n\"\n                    \"    auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \\n\"\n                    
\" \\n\"\n                    \"    cT.store(tC); \\n\"\n                    \"}\";\n\n                GGML_LOG_INFO(\"%s: testing tensor API for bfloat support\\n\", __func__);\n                ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);\n                if (lib == NULL) {\n                    GGML_LOG_WARN(\"%s: - the tensor API does not support bfloat - disabling bfloat support\\n\", __func__);\n                    dev->props.has_bfloat = false;\n                } else {\n                    struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, \"dummy_kernel\", \"dummy_kernel\", nil);\n                    if (!ppl.pipeline) {\n                        GGML_LOG_WARN(\"%s: - the tensor API does not support bfloat - disabling bfloat support\\n\", __func__);\n                        dev->props.has_bfloat = false;\n                    }\n\n                    ggml_metal_library_free(lib);\n                }\n            }\n\n            dev->props.use_residency_sets = true;\n#if defined(GGML_METAL_HAS_RESIDENCY_SETS)\n            dev->props.use_residency_sets = getenv(\"GGML_METAL_NO_RESIDENCY\") == nil;\n#endif\n\n            dev->props.use_shared_buffers = dev->props.has_unified_memory;\n#if TARGET_OS_OSX\n            // In case of eGPU, shared memory may be preferable.\n            dev->props.use_shared_buffers |= [dev->mtl_device location] == MTLDeviceLocationExternal;\n#endif\n            if (getenv(\"GGML_METAL_SHARED_BUFFERS_DISABLE\") != NULL) {\n                dev->props.use_shared_buffers = false;\n            }\n            if (getenv(\"GGML_METAL_SHARED_BUFFERS_ENABLE\") != NULL) {\n                dev->props.use_shared_buffers = true;\n            }\n\n            dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];\n\n            dev->props.op_offload_min_batch_size  = getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\") ? 
atoi(getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\")) : 32;\n\n            dev->props.max_buffer_size            = dev->mtl_device.maxBufferLength;\n            dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;\n            if (@available(macOS 10.12, iOS 16.0, *)) {\n                dev->props.max_working_set_size   = dev->mtl_device.recommendedMaxWorkingSetSize;\n            } else {\n                dev->props.max_working_set_size   = dev->mtl_device.maxBufferLength;\n            }\n\n            snprintf(dev->props.name, sizeof(dev->props.name), \"%s%d\", \"MTL\", device);\n            snprintf(dev->props.desc, sizeof(dev->props.desc), \"%s\", [[dev->mtl_device name] UTF8String]);\n\n            dev->library = ggml_metal_library_init(dev);\n            if (!dev->library) {\n                GGML_LOG_ERROR(\"%s: error: failed to create library\\n\", __func__);\n            }\n\n            if (dev->props.use_residency_sets) {\n                dev->rsets = ggml_metal_rsets_init();\n            } else {\n                dev->rsets = nil;\n            }\n\n            // print MTL GPU family:\n            GGML_LOG_INFO(\"%s: GPU name:   %s\\n\", __func__, dev->props.name);\n\n            // determine max supported GPU family\n            // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf\n            // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf\n            {\n                for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {\n                    if ([dev->mtl_device supportsFamily:i]) {\n                        GGML_LOG_INFO(\"%s: GPU family: MTLGPUFamilyApple%d  (%d)\\n\", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);\n                        break;\n                    }\n                }\n\n                for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {\n                    if ([dev->mtl_device supportsFamily:i]) {\n                        
GGML_LOG_INFO(\"%s: GPU family: MTLGPUFamilyCommon%d (%d)\\n\", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);\n                        break;\n                    }\n                }\n\n                for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {\n                    if ([dev->mtl_device supportsFamily:i]) {\n                        GGML_LOG_INFO(\"%s: GPU family: MTLGPUFamilyMetal%d  (%d)\\n\", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);\n                        break;\n                    }\n                }\n            }\n\n            GGML_LOG_INFO(\"%s: simdgroup reduction   = %s\\n\", __func__, dev->props.has_simdgroup_reduction ? \"true\" : \"false\");\n            GGML_LOG_INFO(\"%s: simdgroup matrix mul. = %s\\n\", __func__, dev->props.has_simdgroup_mm        ? \"true\" : \"false\");\n            GGML_LOG_INFO(\"%s: has unified memory    = %s\\n\", __func__, dev->props.has_unified_memory      ? \"true\" : \"false\");\n            GGML_LOG_INFO(\"%s: has bfloat            = %s\\n\", __func__, dev->props.has_bfloat              ? \"true\" : \"false\");\n            GGML_LOG_INFO(\"%s: has tensor            = %s\\n\", __func__, dev->props.has_tensor              ? \"true\" : \"false\");\n            GGML_LOG_INFO(\"%s: use residency sets    = %s\\n\", __func__, dev->props.use_residency_sets      ? \"true\" : \"false\");\n            GGML_LOG_INFO(\"%s: use shared buffers    = %s\\n\", __func__, dev->props.use_shared_buffers      ? 
\"true\" : \"false\");\n\n#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)\n            if (@available(macOS 10.12, iOS 16.0, *)) {\n                GGML_LOG_INFO(\"%s: recommendedMaxWorkingSetSize  = %8.2f MB\\n\", __func__, dev->props.max_working_set_size / 1e6);\n            }\n#endif\n        }\n    }\n\n    return dev;\n}\n\nvoid ggml_metal_device_free(ggml_metal_device_t dev) {\n    assert(dev != NULL);\n\n    ggml_metal_rsets_free(dev->rsets);\n\n    ggml_metal_library_free(dev->library);\n    dev->library = NULL;\n\n    if (dev->mtl_queue) {\n        [dev->mtl_queue release];\n        dev->mtl_queue = nil;\n    }\n\n    if (dev->mtl_device) {\n        [dev->mtl_device release];\n        dev->mtl_device = nil;\n    }\n\n    free(dev);\n}\n\nvoid * ggml_metal_device_get_obj(ggml_metal_device_t dev) {\n    return dev->mtl_device;\n}\n\nvoid * ggml_metal_device_get_queue(ggml_metal_device_t dev) {\n    return dev->mtl_queue;\n}\n\nggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev) {\n    return dev->library;\n}\n\nvoid ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset) {\n    if (rset == nil) {\n        return;\n    }\n\n    GGML_ASSERT(dev->rsets);\n\n    [dev->rsets->lock lock];\n\n    [dev->rsets->data addObject:rset];\n\n    [dev->rsets->lock unlock];\n}\n\nvoid ggml_metal_device_rsets_rm(ggml_metal_device_t dev, ggml_metal_rset_t rset) {\n    if (rset == nil) {\n        return;\n    }\n\n    GGML_ASSERT(dev->rsets);\n\n    [dev->rsets->lock lock];\n\n    [dev->rsets->data removeObject:rset];\n\n    [dev->rsets->lock unlock];\n}\n\nvoid ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) {\n    if (dev->rsets == NULL) {\n        return;\n    }\n\n    atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed);\n}\n\nstruct ggml_metal_event {\n    void * obj; // id<MTLEvent>\n\n    atomic_int value;\n};\n\nvoid 
ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) {\n    id<MTLEvent> event = (id<MTLEvent>)ev->obj;\n\n    id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw;\n\n    [cmd_buf encodeSignalEvent:event value:atomic_fetch_add_explicit(&ev->value, 1, memory_order_relaxed) + 1];\n}\n\nvoid ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) {\n    id<MTLEvent> event = (id<MTLEvent>)ev->obj;\n\n    id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw;\n\n    [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)];\n}\n\nggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) {\n    id<MTLEvent> event = [dev->mtl_device newEvent];\n\n    ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event));\n\n    ev->obj = (__bridge void *)event;\n    ev->value = 0;\n\n    return ev;\n}\n\nvoid ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) {\n    id<MTLEvent> event = ev->obj;\n    [event release];\n\n    free(ev);\n\n    GGML_UNUSED(dev);\n}\n\nvoid ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) {\n    @autoreleasepool {\n        id<MTLEvent> event = ev->obj;\n\n        id<MTLCommandBuffer> cmd_buf = [dev->mtl_queue commandBuffer];\n        [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)];\n        [cmd_buf commit];\n        [cmd_buf waitUntilCompleted];\n    }\n}\n\nvoid ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) {\n    if (@available(macOS 10.12, iOS 16.0, *)) {\n        *total = dev->mtl_device.recommendedMaxWorkingSetSize;\n        *free  = *total - dev->mtl_device.currentAllocatedSize;\n    } else {\n        *free = 0;\n        *total = 0;\n    }\n}\n\nbool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) {\n    const bool 
has_simdgroup_mm        = dev->props.has_simdgroup_mm;\n    const bool has_simdgroup_reduction = dev->props.has_simdgroup_reduction;\n    const bool has_bfloat              = dev->props.has_bfloat;\n\n    if (!has_bfloat) {\n        if (op->type == GGML_TYPE_BF16) {\n            return false;\n        }\n\n        for (size_t i = 0, n = 3; i < n; ++i) {\n            if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {\n                return false;\n            }\n        }\n    }\n\n    switch (op->op) {\n        case GGML_OP_SCALE:\n        case GGML_OP_FILL:\n        case GGML_OP_CLAMP:\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_SIN:\n        case GGML_OP_COS:\n        case GGML_OP_LOG:\n            return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(op)) {\n                case GGML_UNARY_OP_TANH:\n                case GGML_UNARY_OP_RELU:\n                case GGML_UNARY_OP_SIGMOID:\n                case GGML_UNARY_OP_GELU:\n                case GGML_UNARY_OP_GELU_ERF:\n                case GGML_UNARY_OP_GELU_QUICK:\n                case GGML_UNARY_OP_SILU:\n                case GGML_UNARY_OP_ELU:\n                case GGML_UNARY_OP_NEG:\n                case GGML_UNARY_OP_ABS:\n                case GGML_UNARY_OP_SGN:\n                case GGML_UNARY_OP_STEP:\n                case GGML_UNARY_OP_HARDSWISH:\n                case GGML_UNARY_OP_HARDSIGMOID:\n                case GGML_UNARY_OP_EXP:\n                case GGML_UNARY_OP_SOFTPLUS:\n                case GGML_UNARY_OP_EXPM1:\n                    return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);\n                default:\n                    return false;\n            }\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(op)) {\n                
case GGML_GLU_OP_REGLU:\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_SWIGLU:\n                case GGML_GLU_OP_SWIGLU_OAI:\n                case GGML_GLU_OP_GEGLU_ERF:\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n               default:\n                    return false;\n            }\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_CONCAT:\n            return true;\n        case GGML_OP_ADD:\n        case GGML_OP_SUB:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n        case GGML_OP_ADD_ID:\n        case GGML_OP_ACC:\n            return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_REPEAT:\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            return true;\n        case GGML_OP_CONV_TRANSPOSE_2D:\n            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&\n                (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&\n                op->src[1]->type == GGML_TYPE_F32 &&\n                op->type == GGML_TYPE_F32;\n        case GGML_OP_SUM:\n            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);\n        case GGML_OP_TRI:\n            return ggml_is_contiguous_rows(op->src[0]);\n        case GGML_OP_SUM_ROWS:\n        case GGML_OP_CUMSUM:\n        case GGML_OP_MEAN:\n        case GGML_OP_SOFT_MAX:\n        case GGML_OP_GROUP_NORM:\n        case GGML_OP_L2_NORM:\n            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);\n        case GGML_OP_COUNT_EQUAL:\n            return has_simdgroup_reduction &&\n                op->src[0]->type == GGML_TYPE_I32 &&\n                op->src[1]->type == GGML_TYPE_I32 &&\n       
         op->type == GGML_TYPE_I64;\n        case GGML_OP_ARGMAX:\n            return has_simdgroup_reduction;\n        case GGML_OP_NORM:\n        case GGML_OP_RMS_NORM:\n            return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));\n        case GGML_OP_ROPE:\n            return true;\n        case GGML_OP_IM2COL:\n            return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);\n        case GGML_OP_CONV_2D:\n            return ggml_is_contiguous(op->src[0]) &&\n                   op->src[1]->type == GGML_TYPE_F32 &&\n                   op->type == GGML_TYPE_F32 &&\n                   (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);\n        case GGML_OP_UPSCALE:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_POOL_1D:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_POOL_2D:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_PAD:\n            // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985\n            if (ggml_get_op_params_i32(op, 8) != 0) {\n                return false;\n            }\n\n            return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&\n                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);\n        case GGML_OP_PAD_REFLECT_1D:\n        case GGML_OP_TIMESTEP_EMBEDDING:\n        case GGML_OP_LEAKY_RELU:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_ARGSORT:\n        case GGML_OP_TOP_K:\n        case GGML_OP_ARANGE:\n            return true;\n        case GGML_OP_FLASH_ATTN_EXT:\n            // for new head sizes, add checks here\n            if (op->src[0]->ne[0] != 32 &&\n                op->src[0]->ne[0] != 40 &&\n                
op->src[0]->ne[0] != 48 &&\n                op->src[0]->ne[0] != 64 &&\n                op->src[0]->ne[0] != 72 &&\n                op->src[0]->ne[0] != 80 &&\n                op->src[0]->ne[0] != 96 &&\n                op->src[0]->ne[0] != 112 &&\n                op->src[0]->ne[0] != 128 &&\n                op->src[0]->ne[0] != 192 &&\n                op->src[0]->ne[0] != 256 &&\n                op->src[0]->ne[0] != 320 &&\n                op->src[0]->ne[0] != 576) {\n                return false;\n            }\n            if (op->src[1]->type != op->src[2]->type) {\n                return false;\n            }\n            return has_simdgroup_mm; // TODO: over-restricted for vec-kernels\n        case GGML_OP_SSM_CONV:\n        case GGML_OP_SSM_SCAN:\n            return has_simdgroup_reduction;\n        case GGML_OP_RWKV_WKV6:\n        case GGML_OP_RWKV_WKV7:\n            return true;\n        case GGML_OP_GATED_DELTA_NET:\n            return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0;\n        case GGML_OP_SOLVE_TRI:\n        case GGML_OP_MUL_MAT:\n        case GGML_OP_MUL_MAT_ID:\n            return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4;\n        case GGML_OP_SET:\n        case GGML_OP_CPY:\n        case GGML_OP_DUP:\n        case GGML_OP_CONT:\n            {\n                switch (op->src[0]->type) {\n                    case GGML_TYPE_F32:\n                        switch (op->type) {\n                           case GGML_TYPE_F32:\n                           case GGML_TYPE_F16:\n                           case GGML_TYPE_BF16:\n                           case GGML_TYPE_Q8_0:\n                           case GGML_TYPE_Q4_0:\n                           case GGML_TYPE_Q4_1:\n                           case GGML_TYPE_Q5_0:\n                           case GGML_TYPE_Q5_1:\n                           case GGML_TYPE_IQ4_NL:\n                           case GGML_TYPE_I32:\n                                return true;\n          
                 default:\n                                return false;\n                        }\n                    case GGML_TYPE_F16:\n                        switch (op->type) {\n                            case GGML_TYPE_F32:\n                            case GGML_TYPE_F16:\n                                return true;\n                            default:\n                                return false;\n                        }\n                    case GGML_TYPE_BF16:\n                        switch (op->type) {\n                            case GGML_TYPE_F32:\n                            case GGML_TYPE_BF16:\n                                return true;\n                            default:\n                                return false;\n                        }\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n                        switch (op->type) {\n                            case GGML_TYPE_F32:\n                            case GGML_TYPE_F16:\n                                return true;\n                            default:\n                                return false;\n                        }\n                    case GGML_TYPE_I32:\n                        return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32;\n                    default:\n                        return false;\n                };\n            }\n        case GGML_OP_GET_ROWS:\n            return op->src[0]->type != GGML_TYPE_NVFP4;\n        case GGML_OP_SET_ROWS:\n            {\n                if (op->src[0]->type != GGML_TYPE_F32) {\n                    return false;\n                }\n\n                switch (op->type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_BF16:\n                    case GGML_TYPE_Q8_0:\n                
    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_IQ4_NL:\n                        return true;\n                    default:\n                        return false;\n                };\n            }\n        case GGML_OP_DIAG:\n            return true;\n        case GGML_OP_OPT_STEP_ADAMW:\n        case GGML_OP_OPT_STEP_SGD:\n            return has_simdgroup_reduction;\n        default:\n            return false;\n    }\n}\n\nconst struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev) {\n    return &dev->props;\n}\n\n//\n// device buffers\n//\n\n// max memory buffers that can be mapped to the device\n#define GGML_METAL_MAX_BUFFERS 64\n\nstruct ggml_metal_buffer_wrapper {\n    void   * data;\n    size_t   size;\n\n    id<MTLBuffer> metal;\n};\n\nstruct ggml_metal_buffer {\n    void * all_data;\n    size_t all_size;\n\n    // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host\n    bool is_shared;\n    bool owned;\n\n    // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap\n    int n_buffers;\n    struct ggml_metal_buffer_wrapper buffers[GGML_METAL_MAX_BUFFERS];\n\n    bool use_residency_sets;\n\n    // optional MTLResidencySet\n    // note: cannot use explicitly \"id<MTLResidencySet>\" here because it is not available on certain OSes\n    id rset;\n\n    // pointers to global device\n    ggml_metal_device_t dev;\n};\n\nstatic void ggml_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {\n#ifndef GGML_METAL_NDEBUG\n#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)\n    if (@available(macOS 10.12, iOS 16.0, *)) {\n        GGML_LOG_DEBUG(\"%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\\n\",\n                __func__,\n                size_aligned / 1024.0 / 1024.0,\n           
     device.currentAllocatedSize / 1024.0 / 1024.0,\n                device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);\n\n        if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {\n            GGML_LOG_WARN(\"%s: warning: current allocated size is greater than the recommended max working set size\\n\", __func__);\n        }\n    } else {\n        GGML_LOG_INFO(\"%s: allocated buffer, size = %8.2f MiB, (%8.2f)\\n\",\n                __func__,\n                size_aligned / 1024.0 / 1024.0,\n                device.currentAllocatedSize / 1024.0 / 1024.0);\n    }\n#endif\n#endif\n    GGML_UNUSED(device);\n    GGML_UNUSED(size_aligned);\n}\n\n// rset init\nstatic bool ggml_metal_buffer_rset_init(ggml_metal_buffer_t buf) {\n    buf->rset = nil;\n\n    if (!buf->use_residency_sets) {\n        return true;\n    }\n\n#if defined(GGML_METAL_HAS_RESIDENCY_SETS)\n    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {\n        MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];\n        desc.label = @\"ggml_metal\";\n        desc.initialCapacity = buf->n_buffers;\n\n        NSError * error;\n        buf->rset = [buf->dev->mtl_device newResidencySetWithDescriptor:desc error:&error];\n        if (error) {\n            GGML_LOG_ERROR(\"%s: error: %s\\n\", __func__, [[error description] UTF8String]);\n            [desc release];\n            return false;\n        }\n\n        [desc release];\n\n        for (int i = 0; i < buf->n_buffers; i++) {\n            [buf->rset addAllocation:buf->buffers[i].metal];\n        }\n\n        [buf->rset commit];\n        [buf->rset requestResidency];\n\n        return true;\n    }\n#endif\n\n    return true;\n}\n\n// rset free\nstatic void ggml_metal_buffer_rset_free(ggml_metal_buffer_t buf) {\n#if defined(GGML_METAL_HAS_RESIDENCY_SETS)\n    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {\n        if (buf->rset) {\n            [buf->rset endResidency];\n   
         [buf->rset removeAllAllocations];\n            [buf->rset release];\n        }\n    }\n#else\n    GGML_UNUSED(buf);\n#endif\n}\n\nstatic void * ggml_metal_host_malloc(size_t n) {\n    void * data = NULL;\n\n#if TARGET_OS_OSX\n    kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);\n    if (err != KERN_SUCCESS) {\n        GGML_LOG_ERROR(\"%s: error: vm_allocate failed\\n\", __func__);\n        return NULL;\n    }\n#else\n    const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);\n    if (result != 0) {\n        GGML_LOG_ERROR(\"%s: error: posix_memalign failed\\n\", __func__);\n        return NULL;\n    }\n#endif\n\n    return data;\n}\n\nggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) {\n    ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));\n\n    res->dev = dev;\n\n    const size_t size_page = sysconf(_SC_PAGESIZE);\n\n    size_t size_aligned = size;\n    if ((size_aligned % size_page) != 0) {\n        size_aligned += (size_page - (size_aligned % size_page));\n    }\n\n    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);\n\n    shared = shared && props_dev->use_shared_buffers;\n\n    // allocate shared buffer if the device supports it and it is required by the buffer type\n    if (shared) {\n        res->all_data = ggml_metal_host_malloc(size_aligned);\n        res->is_shared = true;\n    } else {\n        // use virtual address\n        res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed);\n        res->is_shared = false;\n    }\n    res->all_size = size_aligned;\n\n    res->owned = true;\n\n    res->n_buffers = 1;\n\n    if (res->all_data != NULL) {\n        res->buffers[0].size  = size;\n        res->buffers[0].metal = nil;\n\n        if (size_aligned > 0) {\n            if (props_dev->use_shared_buffers && shared) {\n              
  res->buffers[0].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:res->all_data\n                                                                  length:size_aligned\n                                                                 options:MTLResourceStorageModeShared\n                                                             deallocator:nil];\n            } else {\n                res->buffers[0].metal = [res->dev->mtl_device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];\n            }\n        }\n\n        res->buffers[0].data = res->all_data;\n    }\n\n    if (size_aligned > 0 && (res->all_data == NULL || res->buffers[0].metal == nil)) {\n        GGML_LOG_ERROR(\"%s: error: failed to allocate buffer, size = %8.2f MiB\\n\", __func__, size_aligned / 1024.0 / 1024.0);\n        free(res);\n        return NULL;\n    }\n\n    res->use_residency_sets = props_dev->use_residency_sets;\n\n    if (!ggml_metal_buffer_rset_init(res)) {\n        GGML_LOG_ERROR(\"%s: error: failed to initialize residency set\\n\", __func__);\n        free(res);\n        return NULL;\n    }\n\n    ggml_metal_device_rsets_add(dev, res->rset);\n\n    //ggml_metal_log_allocated_size(device, size_aligned);\n\n    return res;\n}\n\nggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) {\n    ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));\n\n    res->dev = dev;\n\n    res->all_data = ptr;\n    res->all_size = size;\n\n    res->is_shared = true;\n    res->owned = false;\n\n    res->n_buffers = 0;\n\n    const size_t size_page = sysconf(_SC_PAGESIZE);\n\n    // page-align the data ptr\n    {\n        const uintptr_t offs = (uintptr_t) ptr % size_page;\n        ptr  = (void *) ((char *) ptr - offs);\n        size += offs;\n    }\n\n    size_t size_aligned = size;\n    if ((size_aligned % size_page) != 0) {\n        size_aligned += (size_page - (size_aligned % size_page));\n    
}\n\n    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);\n\n    // the buffer fits into the max buffer size allowed by the device\n    if (size_aligned <= props_dev->max_buffer_size) {\n        res->buffers[res->n_buffers].data  = ptr;\n        res->buffers[res->n_buffers].size  = size;\n        res->buffers[res->n_buffers].metal = nil;\n\n        if (size_aligned > 0) {\n            res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];\n\n            if (res->buffers[res->n_buffers].metal == nil) {\n                GGML_LOG_ERROR(\"%s: error: failed to allocate buffer, size = %8.2f MiB\\n\", __func__, size_aligned / 1024.0 / 1024.0);\n                free(res);\n                return NULL;\n            }\n        }\n\n        ggml_metal_log_allocated_size(res->dev->mtl_device, size_aligned);\n\n        ++res->n_buffers;\n    } else {\n        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into\n        // one of the views\n        const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case\n        const size_t size_step = props_dev->max_buffer_size - size_ovlp;\n        const size_t size_view = props_dev->max_buffer_size;\n\n        for (size_t i = 0; i < size; i += size_step) {\n            const size_t size_step_aligned = (i + size_view <= size) ? 
size_view : (size_aligned - i);\n\n            res->buffers[res->n_buffers].data  = (void *) ((uint8_t *) ptr + i);\n            res->buffers[res->n_buffers].size  = size_step_aligned;\n            res->buffers[res->n_buffers].metal = nil;\n\n            if (size_step_aligned > 0) {\n                res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];\n\n                if (res->buffers[res->n_buffers].metal == nil) {\n                    GGML_LOG_ERROR(\"%s: error: failed to allocate buffer, size = %8.2f MiB\\n\", __func__, size_step_aligned / 1024.0 / 1024.0);\n                    free(res);\n                    return NULL;\n                }\n            }\n\n            ggml_metal_log_allocated_size(res->dev->mtl_device, size_step_aligned);\n\n            if (i + size_step < size) {\n                GGML_LOG_INFO(\"\\n\");\n            }\n\n            ++res->n_buffers;\n        }\n    }\n\n    res->use_residency_sets = props_dev->use_residency_sets;\n\n    if (!ggml_metal_buffer_rset_init(res)) {\n        GGML_LOG_ERROR(\"%s: error: failed to initialize residency set\\n\", __func__);\n        free(res);\n        return NULL;\n    }\n\n    ggml_metal_device_rsets_add(dev, res->rset);\n\n    return res;\n}\n\nvoid ggml_metal_buffer_free(ggml_metal_buffer_t buf) {\n    ggml_metal_device_rsets_rm(buf->dev, buf->rset);\n\n    for (int i = 0; i < buf->n_buffers; i++) {\n        [buf->buffers[i].metal release];\n    }\n\n    ggml_metal_buffer_rset_free(buf);\n\n    if (buf->is_shared && buf->owned) {\n#if TARGET_OS_OSX\n        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)buf->all_data, buf->all_size);\n#else\n        free(buf->all_data);\n#endif\n    }\n\n    free(buf);\n}\n\nvoid * ggml_metal_buffer_get_base(ggml_metal_buffer_t buf) {\n    return buf->all_data;\n}\n\nbool 
ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {\n    return buf->is_shared;\n}\n\nvoid ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {\n    if (buf->is_shared) {\n        memset((char *) tensor->data + offset, value, size);\n        return;\n    }\n\n    @autoreleasepool {\n        // dst\n        struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor);\n        bid_dst.offs += offset;\n\n        id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];\n\n        {\n            id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];\n\n            [encoder fillBuffer:bid_dst.metal\n                          range:NSMakeRange(bid_dst.offs, bid_dst.offs + size)\n                          value:value];\n\n            [encoder endEncoding];\n        }\n\n        [cmd_buf commit];\n        [cmd_buf waitUntilCompleted];\n    }\n}\n\nvoid ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    if (buf->is_shared) {\n        memcpy((char *) tensor->data + offset, data, size);\n        return;\n    }\n\n    @autoreleasepool {\n        // src\n        void * data_ptr = (void *)(uintptr_t) data; // \"const cast\" the src data\n        id<MTLBuffer> buf_src = [buf->dev->mtl_device newBufferWithBytesNoCopy:data_ptr\n                                                               length:size\n                                                              options:MTLResourceStorageModeShared\n                                                          deallocator:nil];\n\n        GGML_ASSERT(buf_src);\n\n        // dst\n        struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor);\n        bid_dst.offs += offset;\n\n        // note: for experimentation purposes, here we use a semaphore to wait for the copy to complete\n        // 
      this is alternative to waitUntilCompleted, which should be faster, but don't seem to make much difference\n        dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0);\n\n        id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];\n\n        {\n            id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];\n\n            [encoder copyFromBuffer:buf_src\n                       sourceOffset:0\n                           toBuffer:bid_dst.metal\n                  destinationOffset:bid_dst.offs\n                               size:size];\n\n            [encoder endEncoding];\n        }\n\n        [cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) {\n                             // TODO: can check for errors here\n            GGML_UNUSED(cb);\n\n            dispatch_semaphore_signal(completion_semaphore);\n        }];\n\n        [cmd_buf commit];\n\n        dispatch_semaphore_wait(completion_semaphore, DISPATCH_TIME_FOREVER);\n        dispatch_release(completion_semaphore);\n\n        //[cmd_buf waitUntilCompleted];\n    }\n}\n\nvoid ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    if (buf->is_shared) {\n        memcpy(data, (const char *) tensor->data + offset, size);\n        return;\n    }\n\n    @autoreleasepool {\n        // src\n        struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf, tensor);\n        bid_src.offs += offset;\n\n        // dst\n        id<MTLBuffer> buf_dst = [buf->dev->mtl_device newBufferWithBytesNoCopy:data\n                                                               length:size\n                                                              options:MTLResourceStorageModeShared\n                                                          deallocator:nil];\n\n        GGML_ASSERT(buf_dst);\n\n        id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue 
commandBufferWithUnretainedReferences];\n\n        {\n            id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];\n\n            [encoder copyFromBuffer:bid_src.metal\n                       sourceOffset:bid_src.offs\n                           toBuffer:buf_dst\n                  destinationOffset:0\n                               size:size];\n\n            [encoder endEncoding];\n        }\n\n        [cmd_buf commit];\n        [cmd_buf waitUntilCompleted];\n    }\n}\n\nvoid ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {\n    if (buf->is_shared) {\n        memset(buf->all_data, value, buf->all_size);\n        return;\n    }\n\n    @autoreleasepool {\n        id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];\n\n        {\n            id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];\n\n            [encoder fillBuffer:buf->buffers[0].metal\n                          range:NSMakeRange(0, buf->buffers[0].size)\n                          value:value];\n\n            [encoder endEncoding];\n        }\n\n        [cmd_buf commit];\n        [cmd_buf waitUntilCompleted];\n    }\n}\n\nstruct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t) {\n    struct ggml_metal_buffer_id res = { nil, 0 };\n\n    const int64_t tsize = ggml_nbytes(t);\n\n    // find the view that contains the tensor fully\n    for (int i = 0; i < buf->n_buffers; ++i) {\n        const int64_t ioffs = (int64_t) t->data - (int64_t) buf->buffers[i].data;\n\n        //GGML_LOG_INFO(\"ioffs = %10ld, tsize = %10ld, sum = %10ld, buf->buffers[%d].size = %10ld\\n\", ioffs, tsize, ioffs + tsize, i, buf->buffers[i].size);\n        if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf->buffers[i].size) {\n            res.metal = buf->buffers[i].metal;\n            res.offs  = (size_t) ioffs;\n\n            //GGML_LOG_INFO(\"%s: tensor '%16s', offs = %8ld\\n\", __func__, t->name, 
*offs);\n\n            return res;\n        }\n    }\n\n    GGML_LOG_ERROR(\"%s: error: tensor '%s' buffer is nil\\n\", __func__, t->name);\n\n    return res;\n}\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-impl.h",
    "content": "#ifndef GGML_METAL_IMPL\n#define GGML_METAL_IMPL\n\n// kernel parameters for mat-vec threadgroups\n//\n// N_R0: number of src0 rows to process per simdgroup\n// N_SG: number of simdgroups per threadgroup\n//\n// TODO: for optimal performance, become function of the device and work size\n\n#define N_R0_Q4_0 4\n#define N_SG_Q4_0 2\n\n#define N_R0_Q4_1 4\n#define N_SG_Q4_1 2\n\n#define N_R0_Q5_0 4\n#define N_SG_Q5_0 2\n\n#define N_R0_Q5_1 4\n#define N_SG_Q5_1 2\n\n#define N_R0_Q8_0 2\n#define N_SG_Q8_0 4\n\n#define N_R0_MXFP4 2\n#define N_SG_MXFP4 2\n\n#define N_R0_Q2_K 4\n#define N_SG_Q2_K 2\n\n#define N_R0_Q3_K 2\n#define N_SG_Q3_K 2\n\n#define N_R0_Q4_K 2\n#define N_SG_Q4_K 2\n\n#define N_R0_Q5_K 1\n#define N_SG_Q5_K 2\n\n#define N_R0_Q6_K 2\n#define N_SG_Q6_K 2\n\n#define N_R0_IQ1_S 4\n#define N_SG_IQ1_S 2\n\n#define N_R0_IQ1_M 4\n#define N_SG_IQ1_M 2\n\n#define N_R0_IQ2_XXS 4\n#define N_SG_IQ2_XXS 2\n\n#define N_R0_IQ2_XS 4\n#define N_SG_IQ2_XS 2\n\n#define N_R0_IQ2_S 4\n#define N_SG_IQ2_S 2\n\n#define N_R0_IQ3_XXS 4\n#define N_SG_IQ3_XXS 2\n\n#define N_R0_IQ3_S 4\n#define N_SG_IQ3_S 2\n\n#define N_R0_IQ4_NL 2\n#define N_SG_IQ4_NL 2\n\n#define N_R0_IQ4_XS 2\n#define N_SG_IQ4_XS 2\n\n// function constants offsets\n#define FC_FLASH_ATTN_EXT_PAD          100\n#define FC_FLASH_ATTN_EXT_BLK          200\n#define FC_FLASH_ATTN_EXT              300\n#define FC_FLASH_ATTN_EXT_VEC          400\n#define FC_FLASH_ATTN_EXT_VEC_REDUCE   500\n#define FC_MUL_MV                      600\n#define FC_MUL_MM                      700\n#define FC_ROPE                        800\n#define FC_SSM_CONV                    900\n#define FC_SOLVE_TRI                   1000\n#define FC_COUNT_EQUAL                 1100\n#define FC_UNARY                       1200\n#define FC_BIN                         1300\n#define FC_SUM_ROWS                    1400\n#define FC_UPSCALE                     1500\n#define FC_GATED_DELTA_NET             1600\n\n// op-specific constants\n#define 
OP_FLASH_ATTN_EXT_NQPSG 8\n#define OP_FLASH_ATTN_EXT_NCPSG 64\n\n#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1\n#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32\n\n#define OP_UNARY_NUM_SCALE      10\n#define OP_UNARY_NUM_FILL       11\n#define OP_UNARY_NUM_CLAMP      12\n#define OP_UNARY_NUM_SQR        13\n#define OP_UNARY_NUM_SQRT       14\n#define OP_UNARY_NUM_SIN        15\n#define OP_UNARY_NUM_COS        16\n#define OP_UNARY_NUM_LOG        17\n#define OP_UNARY_NUM_LEAKY_RELU 18\n\n#define OP_UNARY_NUM_TANH        100\n#define OP_UNARY_NUM_RELU        101\n#define OP_UNARY_NUM_SIGMOID     102\n#define OP_UNARY_NUM_GELU        103\n#define OP_UNARY_NUM_GELU_ERF    104\n#define OP_UNARY_NUM_GELU_QUICK  105\n#define OP_UNARY_NUM_SILU        106\n#define OP_UNARY_NUM_ELU         107\n#define OP_UNARY_NUM_NEG         108\n#define OP_UNARY_NUM_ABS         109\n#define OP_UNARY_NUM_SGN         110\n#define OP_UNARY_NUM_STEP        111\n#define OP_UNARY_NUM_HARDSWISH   112\n#define OP_UNARY_NUM_HARDSIGMOID 113\n#define OP_UNARY_NUM_EXP         114\n#define OP_UNARY_NUM_SOFTPLUS    115\n#define OP_UNARY_NUM_EXPM1       116\n\n#define OP_SUM_ROWS_NUM_SUM_ROWS 10\n#define OP_SUM_ROWS_NUM_MEAN     11\n\n// kernel argument structs\n//\n// - element counters (e.g. ne00) typically use int32_t to reduce register usage\n//   however, be careful from int overflows when using those in the kernel implementation\n//\n// - strides (e.g. 
nb00) use uint64_t\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne10;\n    int32_t  ne11;\n    int32_t  ne12;\n    int32_t  ne13;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    int32_t  dim;\n} ggml_metal_kargs_concat;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    float    slope;\n    float    scale;\n    float    bias;\n    float    val;\n    float    min;\n    float    max;\n} ggml_metal_kargs_unary;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne10;\n    int32_t  ne11;\n    int32_t  ne12;\n    int32_t  ne13;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    uint64_t offs;\n    uint64_t o1[8];\n} ggml_metal_kargs_bin;\n\ntypedef struct {\n    int64_t ne0;\n    int64_t ne1;\n    size_t nb01;\n    size_t nb02;\n    size_t nb11;\n    size_t nb21;\n} ggml_metal_kargs_add_id;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t 
 ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_repeat;\n\ntypedef struct {\n    int64_t  nk0;\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    int64_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int64_t  ne0;\n    int64_t  ne1;\n    int64_t  ne2;\n    int64_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_cpy;\n\ntypedef struct {\n    int64_t  ne10;\n    int64_t  ne11;\n    int64_t  ne12;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    uint64_t offs;\n    bool     inplace;\n} ggml_metal_kargs_set;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    int32_t  n_past;\n    int32_t  n_dims;\n    int32_t  n_ctx_orig;\n    float    freq_base;\n    float    freq_scale;\n    float    ext_factor;\n    float    attn_factor;\n    float    beta_fast;\n    float    beta_slow;\n    int32_t  sect_0;\n    int32_t  sect_1;\n    int32_t  sect_2;\n    int32_t  sect_3;\n    bool     src2;\n} ggml_metal_kargs_rope;\n\ntypedef struct {\n    int32_t  ne11;\n    int32_t  ne_12_2; // assume K and V are same shape\n    int32_t  ne_12_3;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    uint64_t nb21;\n    uint64_t nb22;\n    uint64_t nb23;\n    int32_t  ne31;\n    int32_t  ne32;\n    int32_t  ne33;\n    uint64_t nb31;\n    uint64_t nb32;\n    uint64_t nb33;\n} ggml_metal_kargs_flash_attn_ext_pad;\n\ntypedef struct {\n    int32_t  ne01;\n    int32_t  ne30;\n    int32_t  ne31;\n    int32_t  ne32;\n    int32_t  ne33;\n    uint64_t 
nb31;\n    uint64_t nb32;\n    uint64_t nb33;\n} ggml_metal_kargs_flash_attn_ext_blk;\n\ntypedef struct {\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne11;\n    int32_t  ne_12_2; // assume K and V are same shape\n    int32_t  ne_12_3;\n    int32_t  ns10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ns20;\n    uint64_t nb21;\n    uint64_t nb22;\n    uint64_t nb23;\n    int32_t  ne31;\n    int32_t  ne32;\n    int32_t  ne33;\n    uint64_t nb31;\n    uint64_t nb32;\n    uint64_t nb33;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    float    scale;\n    float    max_bias;\n    float    m0;\n    float    m1;\n    int32_t  n_head_log2;\n    float    logit_softcap;\n} ggml_metal_kargs_flash_attn_ext;\n\ntypedef struct {\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne11;\n    int32_t  ne_12_2; // assume K and V are same shape\n    int32_t  ne_12_3;\n    int32_t  ns10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ns20;\n    uint64_t nb21;\n    uint64_t nb22;\n    uint64_t nb23;\n    int32_t  ne31;\n    int32_t  ne32;\n    int32_t  ne33;\n    uint64_t nb31;\n    uint64_t nb32;\n    uint64_t nb33;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    float    scale;\n    float    max_bias;\n    float    m0;\n    float    m1;\n    int32_t  n_head_log2;\n    float    logit_softcap;\n} ggml_metal_kargs_flash_attn_ext_vec;\n\ntypedef struct {\n    int32_t  nrows;\n} ggml_metal_kargs_flash_attn_ext_vec_reduce;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne02;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne12;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ne0;\n    int32_t  ne1;\n    int16_t  r2;\n    int16_t  r3;\n} 
ggml_metal_kargs_mul_mm;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne10;\n    int32_t  ne11;\n    int32_t  ne12;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  nr0;\n    int16_t  r2;\n    int16_t  r3;\n} ggml_metal_kargs_mul_mv;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne10;\n    int32_t  ne11;\n    int32_t  ne12;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ne0;\n    int32_t  ne1;\n    int16_t  r2;\n    int16_t  r3;\n} ggml_metal_kargs_mul_mv_ext;\n\ntypedef struct {\n    int32_t  ne02;\n    int32_t  ne10;\n    int32_t  ne11;  // n_expert_used (bcast)\n    uint64_t nb11;\n    uint64_t nb12;\n    int32_t  ne21; // n_tokens\n    int32_t  ne20;  // n_expert_used\n    uint64_t nb21;\n} ggml_metal_kargs_mul_mm_id_map0;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne02;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne11;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ne20;\n    int32_t  ne21;\n    int32_t  ne0;\n    int32_t  ne1;\n    int16_t  r2;\n    int16_t  r3;\n} ggml_metal_kargs_mul_mm_id;\n\ntypedef struct {\n    int32_t  nei0;\n    int32_t  nei1;\n    uint64_t nbi1;\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    int32_t  ne10;\n    int32_t  ne11;\n    int32_t  ne12;\n    int32_t  ne13;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    int32_t  ne0;\n    int32_t  ne1;\n    uint64_t nb1;\n    int32_t  nr0;\n} ggml_metal_kargs_mul_mv_id;\n\n// NORM\n// RMS_NORM\ntypedef struct {\n    int32_t  ne00;\n 
   int32_t  ne00_t;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    float    eps;\n    int32_t  nef1[3];\n    int32_t  nef2[3];\n    int32_t  nef3[3];\n    uint64_t nbf1[3];\n    uint64_t nbf2[3];\n    uint64_t nbf3[3];\n} ggml_metal_kargs_norm;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    float    eps;\n} ggml_metal_kargs_l2_norm;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    int32_t  ngrp;\n    float    eps;\n} ggml_metal_kargs_group_norm;\n\ntypedef struct {\n    int32_t  IC;\n    int32_t  IL;\n    int32_t  K;\n    int32_t  s0;\n    uint64_t nb0;\n    uint64_t nb1;\n} ggml_metal_kargs_conv_transpose_1d;\n\ntypedef struct {\n    int32_t  IC;\n    int32_t  IH;\n    int32_t  IW;\n    int32_t  KH;\n    int32_t  KW;\n    int32_t  OC;\n    int32_t  s0;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n} ggml_metal_kargs_conv_transpose_2d;\n\ntypedef struct {\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    int32_t  IW;\n    int32_t  IH;\n    int32_t  KW;\n    int32_t  KH;\n    int32_t  IC;\n    int32_t  OC;\n    int32_t  OW;\n    int32_t  OH;\n    int32_t  N;\n    int32_t  s0;\n    int32_t  s1;\n    int32_t  p0;\n    int32_t  p1;\n    int32_t  d0;\n    int32_t  d1;\n} ggml_metal_kargs_conv_2d;\n\ntypedef struct {\n    uint64_t  ofs0;\n    uint64_t  ofs1;\n    int32_t  IW;\n    int32_t  IH;\n    int32_t  CHW;\n    int32_t  s0;\n    int32_t  s1;\n    int32_t  p0;\n    int32_t  
p1;\n    int32_t  d0;\n    int32_t  d1;\n    int32_t  N;\n    int32_t  KH;\n    int32_t  KW;\n    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources\n} ggml_metal_kargs_im2col;\n\ntypedef struct{\n    int32_t  ne00;\n    uint64_t nb01;\n    int32_t  ne10;\n    uint64_t nb11;\n    int32_t  ne0;\n    uint64_t nb1;\n    int32_t  i00;\n    int32_t  i10;\n    float    alpha;\n    float    limit;\n} ggml_metal_kargs_glu;\n\ntypedef struct {\n    uint64_t np;\n} ggml_metal_kargs_sum;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    int64_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int64_t  ne0;\n    int64_t  ne1;\n    int64_t  ne2;\n    int64_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_sum_rows;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    int64_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int64_t  net0;\n    int64_t  net1;\n    int64_t  net2;\n    int64_t  net3;\n    uint64_t nbt0;\n    uint64_t nbt1;\n    uint64_t nbt2;\n    uint64_t nbt3;\n    bool     outb;\n} ggml_metal_kargs_cumsum_blk;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    int64_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int64_t  net0;\n    int64_t  net1;\n    int64_t  net2;\n    int64_t  net3;\n    uint64_t nbt0;\n    uint64_t nbt1;\n    uint64_t nbt2;\n    uint64_t nbt3;\n} ggml_metal_kargs_cumsum_add;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne11;\n    int32_t  ne12;\n    int32_t  ne13;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    float    scale;\n    float    max_bias;\n    float   
 m0;\n    float    m1;\n    int32_t  n_head_log2;\n} ggml_metal_kargs_soft_max;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    int64_t  ne10;\n    int64_t  ne11;\n    uint64_t nb10;\n    uint64_t nb11;\n    int64_t  ne0;\n    int64_t  ne1;\n    int64_t  ne2;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n} ggml_metal_kargs_ssm_conv;\n\ntypedef struct {\n    int64_t  d_state;\n    int64_t  d_inner;\n    int64_t  n_head;\n    int64_t  n_group;\n    int64_t  n_seq_tokens;\n    int64_t  n_seqs;\n    uint64_t s_off;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t ns12;\n    uint64_t nb13;\n    uint64_t nb20;\n    uint64_t nb21;\n    uint64_t ns21;\n    uint64_t nb22;\n    int64_t  ne30;\n    uint64_t nb31;\n    uint64_t nb41;\n    uint64_t nb42;\n    uint64_t ns42;\n    uint64_t nb43;\n    uint64_t nb51;\n    uint64_t nb52;\n    uint64_t ns52;\n    uint64_t nb53;\n    uint64_t nb0;\n} ggml_metal_kargs_ssm_scan;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne10;\n    int32_t  ne11;\n    int32_t  ne12;\n    int32_t  ne13;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ne20;\n    int32_t  ne21;\n    int32_t  ne22;\n    int32_t  ne23;\n    uint64_t nb20;\n    uint64_t nb21;\n    uint64_t nb22;\n    uint64_t nb23;\n    int32_t  ns02;\n    int32_t  ns12;\n    int32_t  ns22;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_gated_delta_net;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t 
nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne10;\n    int32_t  ne11;\n    int32_t  ne12;\n    int32_t  ne13;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_solve_tri;\n\ntypedef struct {\n    int32_t  ne00t;\n    int32_t  ne00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne10;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_get_rows;\n\ntypedef struct {\n    int32_t  nk0;\n    int32_t  ne01;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne11;\n    int32_t  ne12;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_set_rows;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_diag;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    int64_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int64_t  ne0;\n    int64_t  ne1;\n    int64_t  ne2;\n    int64_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    float    sf0;\n    float    sf1;\n    float    sf2;\n    float    sf3;\n    float    poffs;\n} ggml_metal_kargs_upscale;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    int64_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t 
nb03;\n    int64_t  ne0;\n    int64_t  ne1;\n    int64_t  ne2;\n    int64_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_pad;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    int64_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int64_t  ne0;\n    int64_t  ne1;\n    int64_t  ne2;\n    int64_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n    int32_t  p0;\n    int32_t  p1;\n} ggml_metal_kargs_pad_reflect_1d;\n\ntypedef struct {\n    uint64_t nb1;\n    int      dim;\n    int      max_period;\n} ggml_metal_kargs_timestep_embedding;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    uint64_t nb0;\n    uint64_t nb1;\n    uint64_t nb2;\n    uint64_t nb3;\n} ggml_metal_kargs_tri;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    int32_t  top_k;\n} ggml_metal_kargs_argsort;\n\ntypedef struct {\n    int64_t  ne00;\n    int64_t  ne01;\n    int64_t  ne02;\n    int64_t  ne03;\n    uint64_t nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    int32_t  ne0;\n    int32_t  ne1;\n    int32_t  ne2;\n    int32_t  ne3;\n    int32_t  top_k;\n    int32_t  len;\n} ggml_metal_kargs_argsort_merge;\n\ntypedef struct {\n    int64_t  ne0;\n    float    start;\n    float    step;\n} ggml_metal_kargs_arange;\n\ntypedef struct {\n    int64_t val;\n} ggml_metal_kargs_memset;\n\ntypedef struct {\n    int32_t  ne00;\n    int32_t  ne01;\n    int32_t  ne02;\n    int32_t  ne03;\n    uint64_t 
nb00;\n    uint64_t nb01;\n    uint64_t nb02;\n    uint64_t nb03;\n    uint64_t nb10;\n    uint64_t nb11;\n    uint64_t nb12;\n    uint64_t nb13;\n} ggml_metal_kargs_count_equal;\n\ntypedef struct {\n    int32_t  k0;\n    int32_t  k1;\n    int32_t  s0;\n    int32_t  s1;\n    int32_t  p0;\n    int32_t  p1;\n    int64_t  IH;\n    int64_t  IW;\n    int64_t  OH;\n    int64_t  OW;\n    int64_t  np;\n} ggml_metal_kargs_pool_2d;\n\ntypedef struct {\n    int32_t  k0;\n    int32_t  s0;\n    int32_t  p0;\n    int64_t  IW;\n    int64_t  OW;\n    int64_t  np;\n} ggml_metal_kargs_pool_1d;\n\ntypedef struct {\n     int64_t ne00;\n    uint64_t nb01;\n} ggml_metal_kargs_argmax;\n\ntypedef struct {\n    int64_t  np;\n} ggml_metal_kargs_opt_step_adamw;\n\ntypedef struct {\n    int64_t  np;\n} ggml_metal_kargs_opt_step_sgd;\n\n#endif // GGML_METAL_IMPL\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-ops.cpp",
    "content": "#include \"ggml-metal-ops.h\"\n\n#include \"ggml.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-metal-impl.h\"\n#include \"ggml-metal-common.h\"\n#include \"ggml-metal-device.h\"\n\n#include <cassert>\n#include <algorithm>\n#include <limits>\n#include <cmath>\n\nstatic ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {\n    if (!t) {\n        return { nullptr, 0 };\n    }\n\n    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;\n\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context;\n\n    return ggml_metal_buffer_get_id(ctx, t);\n}\n\nstruct ggml_metal_op {\n    ggml_metal_op(\n        ggml_metal_device_t dev,\n        ggml_metal_cmd_buf_t cmd_buf,\n        ggml_cgraph * gf,\n        int  idx_start,\n        int  idx_end,\n        bool use_fusion,\n        bool use_concurrency,\n        bool use_capture,\n        int  debug_graph,\n        int  debug_fusion) {\n        this->dev             = dev;\n        this->lib             = ggml_metal_device_get_library(dev);\n        this->enc             = ggml_metal_encoder_init(cmd_buf, use_concurrency);\n        this->mem_ranges      = ggml_mem_ranges_init(debug_graph);\n        this->idx_start       = idx_start;\n        this->idx_end         = idx_end;\n        this->use_fusion      = use_fusion;\n        this->use_concurrency = use_concurrency;\n        this->use_capture     = use_capture;\n        this->debug_graph     = debug_graph;\n        this->debug_fusion    = debug_fusion;\n        this->gf              = gf;\n\n        idxs.reserve(gf->n_nodes);\n\n        // filter empty nodes\n        // TODO: this can be removed when the allocator starts filtering them earlier\n        //       https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830\n        for (int i = idx_start; i < idx_end; i++) {\n            if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {\n   
             idxs.push_back(i);\n            }\n        }\n    }\n\n    ~ggml_metal_op() {\n        ggml_metal_encoder_end_encoding(this->enc);\n        ggml_metal_encoder_free(this->enc);\n        ggml_mem_ranges_free(this->mem_ranges);\n    }\n\n    int n_nodes() const {\n        return idxs.size();\n    }\n\n    ggml_tensor * node(int i) const {\n        assert(i >= 0 && i < (int) idxs.size());\n        return ggml_graph_node(gf, idxs[i]);\n    }\n\n    bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {\n        assert(use_fusion);\n        assert(i0 >= 0 && i0 < n_nodes());\n\n        if (i0 + n_ops > n_nodes()) {\n            return false;\n        }\n\n        return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);\n    }\n\n    ggml_metal_device_t  dev;\n    ggml_metal_library_t lib;\n    ggml_metal_encoder_t enc;\n    ggml_mem_ranges_t    mem_ranges;\n\n    bool use_fusion;\n    bool use_concurrency;\n    bool use_capture;\n\n    int debug_graph;\n    int debug_fusion;\n\nprivate:\n    ggml_cgraph * gf;\n\n    int idx_start;\n    int idx_end;\n\n    // non-empty node indices\n    std::vector<int> idxs;\n};\n\nggml_metal_op_t ggml_metal_op_init(\n        ggml_metal_device_t dev,\n        ggml_metal_cmd_buf_t cmd_buf,\n        ggml_cgraph * gf,\n        int idx_start,\n        int idx_end,\n        bool use_fusion,\n        bool use_concurrency,\n        bool use_capture,\n        int debug_graph,\n        int debug_fusion) {\n    ggml_metal_op_t res = new ggml_metal_op(\n        dev,\n        cmd_buf,\n        gf,\n        idx_start,\n        idx_end,\n        use_fusion,\n        use_concurrency,\n        use_capture,\n        debug_graph,\n        debug_fusion);\n\n    return res;\n}\n\nvoid ggml_metal_op_free(ggml_metal_op_t ctx) {\n    delete ctx;\n}\n\nint ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {\n    return ctx->n_nodes();\n}\n\nstatic bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {\n    if (!ctx->mem_ranges) {\n        
return true;\n    }\n\n    ggml_metal_encoder_memory_barrier(ctx->enc);\n\n    ggml_mem_ranges_reset(ctx->mem_ranges);\n\n    return true;\n}\n\nstatic bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) {\n    if (!ctx->mem_ranges) {\n        return false;\n    }\n\n    return ggml_mem_ranges_check(ctx->mem_ranges, node);\n}\n\nstatic bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) {\n    if (!ctx->mem_ranges) {\n        return true;\n    }\n\n    return ggml_mem_ranges_add(ctx->mem_ranges, node);\n}\n\nstatic int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {\n    struct ggml_tensor * node = ctx->node(idx);\n\n    //GGML_LOG_INFO(\"%s: encoding node %3d, op = %8s\\n\", __func__, idx, ggml_op_name(node->op));\n\n    if (ggml_is_empty(node)) {\n        return 1;\n    }\n\n    switch (node->op) {\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_PERMUTE:\n            {\n                // noop -> next node\n                if (ctx->debug_graph > 0) {\n                    GGML_LOG_DEBUG(\"%s: node[%5d] - %-12s %s\\n\", __func__, idx, ggml_op_name(node->op), \"(noop)\");\n                }\n            } return 1;\n        default:\n            {\n            } break;\n    }\n\n    if (!ggml_metal_device_supports_op(ctx->dev, node)) {\n        GGML_LOG_ERROR(\"%s: error: unsupported op '%s'\\n\", __func__, ggml_op_desc(node));\n        GGML_ABORT(\"unsupported op\");\n    }\n\n    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n        return 1;\n    }\n\n    int n_fuse = 1;\n\n    // check if the current node can run concurrently with other nodes before it\n    // the condition is that:\n    //  - the current node cannot write to any previous src or dst ranges\n    //  - the current node cannot read from any previous dst ranges\n    //\n    // if the condition is not satisfied, we put a memory 
barrier and clear all ranges\n    // otherwise, we add the new ranges to the encoding context and process the node concurrently\n    //\n    {\n        const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node);\n\n        if (!is_concurrent) {\n            ggml_metal_op_concurrency_reset(ctx);\n        }\n\n        if (ctx->debug_graph > 0) {\n            GGML_LOG_DEBUG(\"%s: node[%5d] - %-12s %-12s %s\\n\", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? \"(concurrent)\" : \"\");\n        }\n        if (ctx->debug_graph > 1) {\n            GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);\n            GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);\n            GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);\n            GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);\n            GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);\n            GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);\n            GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);\n            GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);\n            GGML_TENSOR_LOCALS( int64_t, ne,  node,         ne);\n            GGML_TENSOR_LOCALS(uint64_t, nb,  node,         nb);\n\n            if (node->src[0]) {\n                GGML_LOG_DEBUG(\"%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\\n\", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,\n                        ggml_is_contiguous(node->src[0]), node->src[0]->name);\n            }\n            if (node->src[1]) {\n                GGML_LOG_DEBUG(\"%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\\n\", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,\n                        ggml_is_contiguous(node->src[1]), node->src[1]->name);\n            }\n            if (node->src[2]) {\n                GGML_LOG_DEBUG(\"%s: src2 - 
%4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\\n\", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,\n                        ggml_is_contiguous(node->src[2]), node->src[2]->name);\n            }\n            if (node->src[3]) {\n                GGML_LOG_DEBUG(\"%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\\n\", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,\n                        ggml_is_contiguous(node->src[3]), node->src[3]->name);\n            }\n            if (node) {\n                GGML_LOG_DEBUG(\"%s: node  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\\n\", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,\n                        node->name);\n            }\n        }\n    }\n\n    switch (node->op) {\n        case GGML_OP_CONCAT:\n            {\n                n_fuse = ggml_metal_op_concat(ctx, idx);\n            } break;\n        case GGML_OP_ADD:\n        case GGML_OP_SUB:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n            {\n                n_fuse = ggml_metal_op_bin(ctx, idx);\n            } break;\n        case GGML_OP_ADD_ID:\n            {\n                n_fuse = ggml_metal_op_add_id(ctx, idx);\n            } break;\n        case GGML_OP_REPEAT:\n            {\n                n_fuse = ggml_metal_op_repeat(ctx, idx);\n            } break;\n        case GGML_OP_ACC:\n            {\n                n_fuse = ggml_metal_op_acc(ctx, idx);\n            } break;\n        case GGML_OP_SCALE:\n        case GGML_OP_FILL:\n        case GGML_OP_CLAMP:\n        case GGML_OP_LEAKY_RELU:\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_SIN:\n        case GGML_OP_COS:\n        case GGML_OP_LOG:\n        case GGML_OP_UNARY:\n            {\n                n_fuse = ggml_metal_op_unary(ctx, idx);\n            } 
break;\n        case GGML_OP_GLU:\n            {\n                n_fuse = ggml_metal_op_glu(ctx, idx);\n            } break;\n        case GGML_OP_SUM:\n            {\n                n_fuse = ggml_metal_op_sum(ctx, idx);\n            } break;\n        case GGML_OP_SUM_ROWS:\n        case GGML_OP_MEAN:\n            {\n                n_fuse = ggml_metal_op_sum_rows(ctx, idx);\n            } break;\n        case GGML_OP_CUMSUM:\n            {\n                n_fuse = ggml_metal_op_cumsum(ctx, idx);\n            } break;\n        case GGML_OP_SOFT_MAX:\n            {\n                n_fuse = ggml_metal_op_soft_max(ctx, idx);\n            } break;\n        case GGML_OP_SSM_CONV:\n            {\n                n_fuse = ggml_metal_op_ssm_conv(ctx, idx);\n            } break;\n        case GGML_OP_SSM_SCAN:\n            {\n                n_fuse = ggml_metal_op_ssm_scan(ctx, idx);\n            } break;\n        case GGML_OP_RWKV_WKV6:\n        case GGML_OP_RWKV_WKV7:\n            {\n                n_fuse = ggml_metal_op_rwkv(ctx, idx);\n            } break;\n        case GGML_OP_GATED_DELTA_NET:\n            {\n                n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);\n            } break;\n        case GGML_OP_SOLVE_TRI:\n            {\n                n_fuse = ggml_metal_op_solve_tri(ctx, idx);\n            } break;\n        case GGML_OP_MUL_MAT:\n            {\n                n_fuse = ggml_metal_op_mul_mat(ctx, idx);\n            } break;\n        case GGML_OP_MUL_MAT_ID:\n            {\n                n_fuse = ggml_metal_op_mul_mat_id(ctx, idx);\n            } break;\n        case GGML_OP_GET_ROWS:\n            {\n                n_fuse = ggml_metal_op_get_rows(ctx, idx);\n            } break;\n        case GGML_OP_SET_ROWS:\n            {\n                n_fuse = ggml_metal_op_set_rows(ctx, idx);\n            } break;\n        case GGML_OP_DIAG:\n            {\n                n_fuse = ggml_metal_op_diag(ctx, idx);\n            } break;\n        case 
GGML_OP_L2_NORM:\n            {\n                n_fuse = ggml_metal_op_l2_norm(ctx, idx);\n            } break;\n        case GGML_OP_GROUP_NORM:\n            {\n                n_fuse = ggml_metal_op_group_norm(ctx, idx);\n            } break;\n        case GGML_OP_NORM:\n        case GGML_OP_RMS_NORM:\n            {\n                n_fuse = ggml_metal_op_norm(ctx, idx);\n            } break;\n        case GGML_OP_ROPE:\n            {\n                n_fuse = ggml_metal_op_rope(ctx, idx);\n            } break;\n        case GGML_OP_IM2COL:\n            {\n                n_fuse = ggml_metal_op_im2col(ctx, idx);\n            } break;\n        case GGML_OP_CONV_2D:\n            {\n                n_fuse = ggml_metal_op_conv_2d(ctx, idx);\n            } break;\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            {\n                n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);\n            } break;\n        case GGML_OP_CONV_TRANSPOSE_2D:\n            {\n                n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);\n            } break;\n        case GGML_OP_UPSCALE:\n            {\n                n_fuse = ggml_metal_op_upscale(ctx, idx);\n            } break;\n        case GGML_OP_PAD:\n            {\n                n_fuse = ggml_metal_op_pad(ctx, idx);\n            } break;\n        case GGML_OP_PAD_REFLECT_1D:\n            {\n                n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);\n            } break;\n        case GGML_OP_ARANGE:\n            {\n                n_fuse = ggml_metal_op_arange(ctx, idx);\n            } break;\n        case GGML_OP_TIMESTEP_EMBEDDING:\n            {\n                n_fuse = ggml_metal_op_timestep_embedding(ctx, idx);\n            } break;\n        case GGML_OP_ARGSORT:\n            {\n                n_fuse = ggml_metal_op_argsort(ctx, idx);\n            } break;\n        case GGML_OP_TOP_K:\n            {\n                n_fuse = ggml_metal_op_top_k(ctx, idx);\n            } break;\n        case 
GGML_OP_TRI:\n            {\n                n_fuse = ggml_metal_op_tri(ctx, idx);\n            } break;\n        case GGML_OP_FLASH_ATTN_EXT:\n            {\n                n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);\n            } break;\n        case GGML_OP_SET:\n            {\n                n_fuse = ggml_metal_op_set(ctx, idx);\n            } break;\n        case GGML_OP_DUP:\n        case GGML_OP_CPY:\n        case GGML_OP_CONT:\n            {\n                n_fuse = ggml_metal_op_cpy(ctx, idx);\n            } break;\n        case GGML_OP_POOL_1D:\n            {\n                n_fuse = ggml_metal_op_pool_1d(ctx, idx);\n            } break;\n        case GGML_OP_POOL_2D:\n            {\n                n_fuse = ggml_metal_op_pool_2d(ctx, idx);\n            } break;\n        case GGML_OP_ARGMAX:\n            {\n                n_fuse = ggml_metal_op_argmax(ctx, idx);\n            } break;\n        case GGML_OP_OPT_STEP_ADAMW:\n            {\n                n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);\n            } break;\n        case GGML_OP_OPT_STEP_SGD:\n            {\n                n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);\n            } break;\n        case GGML_OP_COUNT_EQUAL:\n            {\n                n_fuse = ggml_metal_op_count_equal(ctx, idx);\n            } break;\n        default:\n            {\n                GGML_LOG_ERROR(\"%s: error: node %3d, op = %8s not implemented\\n\", __func__, idx, ggml_op_name(node->op));\n                GGML_ABORT(\"fatal error\");\n            }\n    }\n\n    if (ctx->debug_graph > 0) {\n        if (n_fuse > 1) {\n            GGML_LOG_DEBUG(\"%s:               fuse %d ops\\n\", __func__, n_fuse);\n        }\n    }\n\n    // update the mem ranges in the encoding context\n    for (int i = 0; i < n_fuse; ++i) {\n        if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {\n            ggml_metal_op_concurrency_reset(ctx);\n        }\n    }\n\n    return n_fuse;\n}\n\nint 
ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {\n    if (ctx->use_capture) {\n        ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));\n    }\n\n    int res = ggml_metal_op_encode_impl(ctx, idx);\n    if (idx + res > ctx->n_nodes()) {\n        GGML_ABORT(\"fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s\",\n                \"https://github.com/ggml-org/llama.cpp/pull/14849\");\n    }\n\n    if (ctx->use_capture) {\n        ggml_metal_encoder_debug_group_pop(ctx->enc);\n    }\n\n    return res;\n}\n\nint ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int32_t dim = ((const int32_t *) op->op_params)[0];\n\n    ggml_metal_kargs_concat args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne10 =*/ ne10,\n        /*.ne11 =*/ ne11,\n        /*.ne12 =*/ ne12,\n        /*.ne13 =*/ ne13,\n        /*.nb10 =*/ nb10,\n        /*.nb11 =*/ nb11,\n        /*.nb12 =*/ nb12,\n        /*.nb13 =*/ nb13,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.ne3  =*/ ne3,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n        /*.dim  =*/ dim,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);\n\n    
ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n    const int nth = std::min(1024, ne0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);\n\n    ggml_metal_kargs_repeat args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.ne3  =*/ ne3,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n    };\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_acc(ggml_metal_op_t 
ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->type         == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));\n\n    const size_t pnb1 = ((const int32_t *) op->op_params)[0];\n    const size_t pnb2 = ((const int32_t *) op->op_params)[1];\n    const size_t pnb3 = ((const int32_t *) op->op_params)[2];\n    const size_t offs = ((const int32_t *) op->op_params)[3];\n\n    const bool inplace = (bool) ((const int32_t *) op->op_params)[4];\n\n    if (!inplace) {\n        // run a separate kernel to cpy src->dst\n        // not sure how to avoid this\n        // TODO: make a simpler cpy_bytes kernel\n\n        //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;\n        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);\n\n        ggml_metal_kargs_cpy args = {\n            /*.nk0  =*/ ne00,\n            /*.ne00 =*/ ne00,\n            /*.ne01 =*/ ne01,\n            /*.ne02 =*/ ne02,\n            /*.ne03 =*/ ne03,\n            /*.nb00 =*/ nb00,\n            /*.nb01 =*/ nb01,\n            /*.nb02 =*/ nb02,\n            /*.nb03 =*/ nb03,\n            /*.ne0  =*/ ne0,\n            /*.ne1  =*/ ne1,\n            /*.ne2  =*/ ne2,\n            /*.ne3  =*/ ne3,\n            /*.nb0  =*/ nb0,\n            /*.nb1  =*/ nb1,\n            /*.nb2  =*/ nb2,\n            
/*.nb3  =*/ nb3,\n        };\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n        const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n\n        ggml_metal_op_concurrency_reset(ctx);\n    }\n\n    ggml_metal_kargs_bin args = {\n        /*.ne00 =*/ ne10,\n        /*.ne01 =*/ ne11,\n        /*.ne02 =*/ ne12,\n        /*.ne03 =*/ ne13,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ pnb1,\n        /*.nb02 =*/ pnb2,\n        /*.nb03 =*/ pnb3,\n        /*.ne10 =*/ ne10,\n        /*.ne11 =*/ ne11,\n        /*.ne12 =*/ ne12,\n        /*.ne13 =*/ ne13,\n        /*.nb10 =*/ nb10,\n        /*.nb11 =*/ nb11,\n        /*.nb12 =*/ nb12,\n        /*.nb13 =*/ nb13,\n        /*.ne0  =*/ ne10,\n        /*.ne1  =*/ ne11,\n        /*.ne2  =*/ ne12,\n        /*.ne3  =*/ ne13,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ pnb1,\n        /*.nb2  =*/ pnb2,\n        /*.nb3  =*/ pnb3,\n        /*.offs =*/ offs,\n        /*.o1   =*/ { 0 },\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n    const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n    int nth = 1;\n\n    while (2*nth < args.ne0 && nth < nth_max) {\n        nth *= 2;\n    }\n\n    
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    ggml_metal_kargs_unary args = {\n        /*.ne00  =*/ ne00,\n        /*.ne01  =*/ ne01,\n        /*.ne02  =*/ ne02,\n        /*.ne03  =*/ ne03,\n        /*.nb00  =*/ nb00,\n        /*.nb01  =*/ nb01,\n        /*.nb02  =*/ nb02,\n        /*.nb03  =*/ nb03,\n        /*.ne0   =*/ ne0,\n        /*.ne1   =*/ ne1,\n        /*.ne2   =*/ ne2,\n        /*.ne3   =*/ ne3,\n        /*.nb0   =*/ nb0,\n        /*.nb1   =*/ nb1,\n        /*.nb2   =*/ nb2,\n        /*.nb3   =*/ nb3,\n        /*.slope =*/ 0.0,\n        /*.scale =*/ 0.0,\n        /*.bias  =*/ 0.0,\n        /*.val   =*/ 0.0,\n        /*.min   =*/ 0.0,\n        /*.max   =*/ 0.0,\n    };\n\n    if (op->op == GGML_OP_LEAKY_RELU) {\n        args.slope = ggml_get_op_params_f32(op, 0);\n    }\n\n    if (op->op == GGML_OP_SCALE) {\n        args.scale = ggml_get_op_params_f32(op, 0);\n        args.bias  = ggml_get_op_params_f32(op, 1);\n    }\n\n    if (op->op == GGML_OP_FILL) {\n        args.val = ggml_get_op_params_f32(op, 0);\n    }\n\n    if (op->op == GGML_OP_CLAMP) {\n        args.min = ggml_get_op_params_f32(op, 0);\n        args.max = ggml_get_op_params_f32(op, 1);\n    }\n\n    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);\n\n    if (pipeline.c4) {\n        args.ne00 = ne00/4;\n        args.ne0  = 
ne0/4;\n    }\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n\n    if (pipeline.cnt) {\n        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);\n    } else {\n        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n        const int nth = MIN(args.ne00, nth_max);\n\n        const int nk0 = (args.ne00 + nth - 1)/nth;\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);\n    }\n\n    return 1;\n}\n\nint ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    if (op->src[1]) {\n        GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));\n    }\n\n    auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);\n\n    const int32_t swp = ggml_get_op_params_i32(op, 1);\n    const float alpha = ggml_get_op_params_f32(op, 2);\n    const float limit = ggml_get_op_params_f32(op, 3);\n\n    const int32_t i00 = swp ? ne0 : 0;\n    const int32_t i10 = swp ? 0 : ne0;\n\n    ggml_metal_kargs_glu args = {\n        /*.ne00 =*/ ne00,\n        /*.nb01 =*/ nb01,\n        /*.ne10 =*/ op->src[1] ? ne10 : ne00,\n        /*.nb11 =*/ op->src[1] ? nb11 : nb01,\n        /*.ne0  =*/ ne0,\n        /*.nb1  =*/ nb1,\n        /*.i00  =*/ op->src[1] ? 
0 : i00,\n        /*.i10  =*/ op->src[1] ? 0 : i10,\n        /*.alpha=*/ alpha,\n        /*.limit=*/ limit\n    };\n\n    const int64_t nrows = ggml_nrows(op->src[0]);\n\n    const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    if (op->src[1]) {\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    } else {\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 2);\n    }\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op  = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);\n\n    ggml_metal_kargs_sum args = {\n        /*.np =*/ n,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);\n\n    int nth = 32; // SIMD width\n\n    while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n        nth *= 2;\n    }\n\n    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n    nth = std::min(nth, (int) n);\n\n    const int nsg = (nth + 31) / 32;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);\n\n    
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    ggml_metal_kargs_sum_rows args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.ne3  =*/ ne3,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);\n\n    if (pipeline.c4) {\n        args.ne00 = ne00/4;\n        args.ne0  = ne0/4;\n    }\n\n    int nth = 32; // SIMD width\n\n    while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n        nth *= 2;\n    }\n\n    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n    nth = std::min(nth, (int) args.ne00);\n\n    const size_t smem = pipeline.smem;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);\n\n    int nth = 1;\n    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {\n        nth *= 2;\n    }\n\n    GGML_ASSERT(ne00 <= nth*nth);\n\n    const int64_t net0 = (ne00 + nth - 1) / nth;\n    const int64_t net1 = ne01;\n    const int64_t net2 = ne02;\n    const int64_t net3 = ne03;\n\n    const uint64_t nbt0 = sizeof(float);\n    const uint64_t nbt1 = net0*nbt0;\n    const uint64_t nbt2 = net1*nbt1;\n    const uint64_t nbt3 = net2*nbt2;\n\n    const size_t smem = GGML_PAD(32*sizeof(float), 16);\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    ggml_metal_buffer_id bid_tmp = bid_dst;\n    bid_tmp.offs += ggml_nbytes(op);\n\n    {\n        ggml_metal_kargs_cumsum_blk args = {\n            /*.ne00 =*/ ne00,\n            /*.ne01 =*/ ne01,\n            /*.ne02 =*/ ne02,\n            /*.ne03 =*/ ne03,\n            /*.nb00 =*/ nb00,\n            /*.nb01 =*/ nb01,\n            /*.nb02 =*/ nb02,\n            /*.nb03 =*/ nb03,\n            /*.net0 =*/ net0,\n            /*.net1 =*/ net1,\n            /*.net2 =*/ net2,\n            /*.net3 =*/ net3,\n            /*.nbt0 =*/ nbt0,\n            /*.nbt1 =*/ nbt1,\n            /*.nbt2 =*/ nbt2,\n            /*.nbt3 =*/ nbt3,\n            /*.outb =*/ 
ne00 > nth,\n        };\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline_blk);\n        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  2);\n        ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);\n\n        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);\n    }\n\n    if (ne00 > nth) {\n        ggml_metal_op_concurrency_reset(ctx);\n\n        {\n            ggml_metal_kargs_cumsum_blk args = {\n                /*.ne00 =*/ net0,\n                /*.ne01 =*/ net1,\n                /*.ne02 =*/ net2,\n                /*.ne03 =*/ net3,\n                /*.nb00 =*/ nbt0,\n                /*.nb01 =*/ nbt1,\n                /*.nb02 =*/ nbt2,\n                /*.nb03 =*/ nbt3,\n                /*.net0 =*/ net0,\n                /*.net1 =*/ net1,\n                /*.net2 =*/ net2,\n                /*.net3 =*/ net3,\n                /*.nbt0 =*/ nbt0,\n                /*.nbt1 =*/ nbt1,\n                /*.nbt2 =*/ nbt2,\n                /*.nbt3 =*/ nbt3,\n                /*.outb =*/ false,\n            };\n\n            ggml_metal_encoder_set_pipeline(enc, pipeline_blk);\n            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);\n            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 2);\n            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 3);\n\n            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n            ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);\n        }\n\n        ggml_metal_op_concurrency_reset(ctx);\n\n        {\n            auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);\n\n            ggml_metal_kargs_cumsum_add args = {\n                /*.ne00 =*/ 
ne00,\n                /*.ne01 =*/ ne01,\n                /*.ne02 =*/ ne02,\n                /*.ne03 =*/ ne03,\n                /*.nb00 =*/ nb00,\n                /*.nb01 =*/ nb01,\n                /*.nb02 =*/ nb02,\n                /*.nb03 =*/ nb03,\n                /*.net0 =*/ net0,\n                /*.net1 =*/ net1,\n                /*.net2 =*/ net2,\n                /*.net3 =*/ net3,\n                /*.nbt0 =*/ nbt0,\n                /*.nbt1 =*/ nbt1,\n                /*.nbt2 =*/ nbt2,\n                /*.nbt3 =*/ nbt3,\n            };\n\n            ggml_metal_encoder_set_pipeline(enc, pipeline_add);\n            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);\n            ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);\n\n            ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);\n        }\n    }\n\n    return 1;\n}\n\nint ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);\n\n    ggml_metal_kargs_get_rows args = {\n        /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? 
ne00/16 : ne00,\n        /*.ne00  =*/ ne00,\n        /*.nb01  =*/ nb01,\n        /*.nb02  =*/ nb02,\n        /*.nb03  =*/ nb03,\n        /*.ne10  =*/ ne10,\n        /*.nb10  =*/ nb10,\n        /*.nb11  =*/ nb11,\n        /*.nb12  =*/ nb12,\n        /*.nb1   =*/ nb1,\n        /*.nb2   =*/ nb2,\n        /*.nb3   =*/ nb3,\n    };\n\n    const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n    const int nw0 = (args.ne00t + nth - 1)/nth;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);\n\n    const int32_t nk0 = ne0/ggml_blck_size(op->type);\n\n    int nth = 32; // SIMD width\n\n    while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n        nth *= 2;\n    }\n\n    int nrptg = 1;\n    if (nth > nk0) {\n        nrptg = (nth + nk0 - 1)/nk0;\n        nth   = nk0;\n\n        if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n         
   nrptg--;\n        }\n    }\n\n    nth = std::min(nth, nk0);\n\n    ggml_metal_kargs_set_rows args = {\n        /*.nk0  =*/ nk0,\n        /*.ne01 =*/ ne01,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne11 =*/ ne11,\n        /*.ne12 =*/ ne12,\n        /*.nb10 =*/ nb10,\n        /*.nb11 =*/ nb11,\n        /*.nb12 =*/ nb12,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n    };\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS(int32_t,  ne, op, ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);\n\n    ggml_metal_kargs_diag args = {\n        /*.ne00 =*/ne00,\n        /*.ne01 =*/ne01,\n        /*.ne02 =*/ne02,\n        /*.ne03 =*/ne03,\n        /*.nb00 =*/nb00,\n        /*.nb01 =*/nb01,\n        /*.nb02 =*/nb02,\n        /*.nb03 =*/nb03,\n        /*.ne0  =*/ne0,\n        /*.ne1  =*/ne1,\n        /*.ne2  =*/ne2,\n        /*.ne3  =*/ne3,\n        /*.nb0  =*/nb0,\n        /*.nb1  =*/nb1,\n        /*.nb2  =*/nb2,\n        /*.nb3  =*/nb3,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);\n   
 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    float scale;\n    float max_bias;\n\n    memcpy(&scale,    ((const int32_t *) op->op_params) + 0, sizeof(scale));\n    memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));\n\n    const uint32_t n_head      = op->src[0]->ne[2];\n    const  int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    // softmax\n\n    ggml_metal_kargs_soft_max args = {\n        /*.ne00        =*/ ne00,\n        /*.ne01        =*/ ne01,\n        /*.ne02        =*/ ne02,\n        /*.nb01        =*/ nb01,\n        /*.nb02        =*/ nb02,\n        /*.nb03        =*/ nb03,\n        /*.ne11        =*/ ne11,\n        /*.ne12        =*/ ne12,\n        /*.ne13        =*/ ne13,\n        /*.nb11        =*/ nb11,\n        /*.nb12        =*/ nb12,\n        /*.nb13        =*/ nb13,\n        /*.nb1         =*/ nb1,\n        /*.nb2         =*/ nb2,\n        /*.nb3         =*/ nb3,\n        /*.scale       =*/ scale,\n        /*.max_bias    
=*/ max_bias,\n        /*.m0          =*/ m0,\n        /*.m1          =*/ m1,\n        /*.n_head_log2 =*/ n_head_log2,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);\n\n    int nth = 32; // SIMD width\n\n    if (ne00%4 == 0) {\n        while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {\n            nth *= 2;\n        }\n    } else {\n        while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {\n            nth *= 2;\n        }\n    }\n\n    const size_t smem = pipeline.smem;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    if (op->src[1]) {\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    } else {\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);\n    }\n    if (op->src[2]) {\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3);\n    } else {\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);\n    }\n    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    ggml_metal_kargs_ssm_conv args = {\n        /*.ne00 =*/ ne00,\n        
/*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.ne10 =*/ ne10,\n        /*.ne11 =*/ ne11,\n        /*.nb10 =*/ nb10,\n        /*.nb11 =*/ nb11,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n    };\n\n    // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead\n    const bool use_batched = (ne1 > 1);\n\n    if (use_batched) {\n        // Determine the smallest power of 2 that's >= ne1, but <= 256\n        int BATCH_SIZE;\n        if      (ne1 > 128) BATCH_SIZE = 256;\n        else if (ne1 > 64 ) BATCH_SIZE = 128;\n        else if (ne1 > 32 ) BATCH_SIZE = 64;\n        else if (ne1 > 16 ) BATCH_SIZE = 32;\n        else if (ne1 > 8  ) BATCH_SIZE = 16;\n        else if (ne1 > 4  ) BATCH_SIZE = 8;\n        else                BATCH_SIZE = 2;\n\n        auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         3);\n\n        // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences\n        // Each threadgroup has BATCH_SIZE threads, each handling one token\n        const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;\n        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);\n    } else {\n        auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes(enc, &args, 
sizeof(args), 0);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         3);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);\n    }\n\n    return 1;\n}\n\nint ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const ggml_tensor * src3 = op->src[3];\n    const ggml_tensor * src4 = op->src[4];\n    const ggml_tensor * src5 = op->src[5];\n    const ggml_tensor * src6 = op->src[6];\n\n    GGML_ASSERT(src3);\n    GGML_ASSERT(src4);\n    GGML_ASSERT(src5);\n    GGML_ASSERT(src6);\n\n    const int64_t d_state      = ne00;\n    const int64_t d_inner      = ne01;\n    const int64_t n_head       = ne02;\n    const int64_t n_group      = ne41;\n    const int64_t n_seq_tokens = ne12;\n    const int64_t n_seqs       = ne13;\n\n    
ggml_metal_kargs_ssm_scan args = {\n        /*.d_state      =*/ d_state,\n        /*.d_inner      =*/ d_inner,\n        /*.n_head       =*/ n_head,\n        /*.n_group      =*/ n_group,\n        /*.n_seq_tokens =*/ n_seq_tokens,\n        /*.n_seqs       =*/ n_seqs,\n        /*.s_off        =*/ ggml_nelements(op->src[1]) * sizeof(float),\n        /*.nb00         =*/ nb00,\n        /*.nb01         =*/ nb01,\n        /*.nb02         =*/ nb02,\n        /*.nb03         =*/ nb03,\n        /*.nb10         =*/ nb10,\n        /*.nb11         =*/ nb11,\n        /*.nb12         =*/ nb12,\n        /*.ns12         =*/ nb12/nb10,\n        /*.nb13         =*/ nb13,\n        /*.nb20         =*/ nb20,\n        /*.nb21         =*/ nb21,\n        /*.ns21         =*/ nb21/nb20,\n        /*.nb22         =*/ nb22,\n        /*.ne30         =*/ ne30,\n        /*.nb31         =*/ nb31,\n        /*.nb41         =*/ nb41,\n        /*.nb42         =*/ nb42,\n        /*.ns42         =*/ nb42/nb40,\n        /*.nb43         =*/ nb43,\n        /*.nb51         =*/ nb51,\n        /*.nb52         =*/ nb52,\n        /*.ns52         =*/ nb52/nb50,\n        /*.nb53         =*/ nb53,\n        /*.nb0          =*/ nb0,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);\n\n    GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n    const size_t smem = pipeline.smem;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), 4);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), 5);\n    ggml_metal_encoder_set_buffer  (enc, 
ggml_metal_get_buffer_id(op->src[5]), 6);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[6]), 7);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         8);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];\n    const int64_t T = op->src[0]->ne[2];\n    const int64_t C = op->ne[0];\n    const int64_t H = op->src[0]->ne[1];\n\n    auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);\n\n    int ida = 0;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);\n    if (op->op == GGML_OP_RWKV_WKV7) {\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);\n    }\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++);\n    ggml_metal_encoder_set_bytes   (enc, (void *) &B, sizeof(B), ida++);\n    
ggml_metal_encoder_set_bytes   (enc, (void *) &T, sizeof(T), ida++);\n    ggml_metal_encoder_set_bytes   (enc, (void *) &C, sizeof(C), ida++);\n    ggml_metal_encoder_set_bytes   (enc, (void *) &H, sizeof(H), ida++);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);\n\n    int ida = 0;\n\n    ggml_metal_kargs_gated_delta_net args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne10 =*/ ne10,\n        /*.ne11 =*/ ne11,\n        /*.ne12 =*/ ne12,\n        /*.ne13 =*/ ne13,\n        /*.nb10 =*/ nb10,\n        /*.nb11 =*/ nb11,\n        /*.nb12 =*/ nb12,\n        /*.nb13 =*/ nb13,\n        /*.ne20 =*/ ne20,\n        /*.ne21 =*/ ne21,\n        /*.ne22 =*/ ne22,\n        /*.ne23 =*/ ne23,\n        /*.nb20 =*/ nb20,\n        /*.nb21 =*/ nb21,\n        /*.nb22 =*/ nb22,\n        /*.nb23 =*/ nb23,\n        /*.ns02 =*/ (int32_t) (nb02/sizeof(float)),\n        /*.ns12 =*/ (int32_t) (nb12/sizeof(float)),\n        /*.ns22 =*/ (int32_t) (nb22/sizeof(float)),\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n  
      /*.ne3  =*/ ne3,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n    };\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args),                  ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++); // dst\n\n    const int nsg = pipeline.nsg;\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    ggml_metal_kargs_solve_tri args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne10 =*/ ne10,\n        /*.ne11 =*/ ne11,\n        /*.ne12 =*/ ne12,\n       
 /*.ne13 =*/ ne13,\n        /*.nb10 =*/ nb10,\n        /*.nb11 =*/ nb11,\n        /*.nb12 =*/ nb12,\n        /*.nb13 =*/ nb13,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.ne3  =*/ ne3,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n    const int nsg = pipeline.nsg;\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    const size_t pnb1 = ((const int32_t *) op->op_params)[0];\n    const size_t pnb2 = ((const int32_t *) op->op_params)[1];\n    const size_t pnb3 = ((const int32_t *) op->op_params)[2];\n    const size_t offs = ((const int32_t 
*) op->op_params)[3];\n\n    const bool inplace = (bool) ((const int32_t *) op->op_params)[4];\n\n    if (!inplace) {\n        // run a separate kernel to cpy src->dst\n        // not sure how to avoid this\n        // TODO: make a simpler cpy_bytes kernel\n\n        //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;\n        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);\n\n        ggml_metal_kargs_cpy args = {\n            /*.nk0  =*/ ne00,\n            /*.ne00 =*/ ne00,\n            /*.ne01 =*/ ne01,\n            /*.ne02 =*/ ne02,\n            /*.ne03 =*/ ne03,\n            /*.nb00 =*/ nb00,\n            /*.nb01 =*/ nb01,\n            /*.nb02 =*/ nb02,\n            /*.nb03 =*/ nb03,\n            /*.ne0  =*/ ne0,\n            /*.ne1  =*/ ne1,\n            /*.ne2  =*/ ne2,\n            /*.ne3  =*/ ne3,\n            /*.nb0  =*/ nb0,\n            /*.nb1  =*/ nb1,\n            /*.nb2  =*/ nb2,\n            /*.nb3  =*/ nb3,\n        };\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n\n        const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n\n        ggml_metal_op_concurrency_reset(ctx);\n    }\n\n    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);\n\n    GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);\n\n    int64_t nk0 = ne10;\n    if (ggml_is_quantized(op->src[1]->type)) {\n        nk0 = ne10/16;\n    } else if (ggml_is_quantized(op->type)) {\n        nk0 = ne10/ggml_blck_size(op->type);\n    }\n\n    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n    // 
when rows are small, we can batch them together in a single threadgroup\n    int nrptg = 1;\n\n    // TODO: relax this constraint in the future\n    if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {\n        if (nth > nk0) {\n            nrptg = (nth + nk0 - 1)/nk0;\n            nth   = nk0;\n\n            if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n                nrptg--;\n            }\n        }\n    }\n\n    nth = std::min<int>(nth, nk0);\n\n    ggml_metal_kargs_cpy args = {\n        /*.nk0  =*/ nk0,\n        /*.ne00 =*/ ne10,\n        /*.ne01 =*/ ne11,\n        /*.ne02 =*/ ne12,\n        /*.ne03 =*/ ne13,\n        /*.nb00 =*/ nb10,\n        /*.nb01 =*/ nb11,\n        /*.nb02 =*/ nb12,\n        /*.nb03 =*/ nb13,\n        /*.ne0  =*/ ne10,\n        /*.ne1  =*/ ne11,\n        /*.ne2  =*/ ne12,\n        /*.ne3  =*/ ne13,\n        /*.nb0  =*/ ggml_element_size(op),\n        /*.nb1  =*/ pnb1,\n        /*.nb2  =*/ pnb2,\n        /*.nb3  =*/ pnb3,\n    };\n\n    const int nw0 = nrptg == 1 ? 
(nk0 + nth - 1)/nth : 1;\n\n    bid_dst.offs += offs;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);\n    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);\n\n    GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);\n\n    int64_t nk0 = ne00;\n    if (ggml_is_quantized(op->src[0]->type)) {\n        nk0 = ne00/16;\n    } else if (ggml_is_quantized(op->type)) {\n        nk0 = ne00/ggml_blck_size(op->type);\n    }\n\n    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n    // when rows are small, we can batch them together in a single threadgroup\n    int nrptg = 1;\n\n    // TODO: relax this constraint in the future\n    if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {\n        if (nth > nk0) {\n            nrptg = (nth + nk0 - 1)/nk0;\n            nth   = nk0;\n\n            if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n                nrptg--;\n            }\n        }\n    }\n\n    nth = std::min<int>(nth, nk0);\n\n    ggml_metal_kargs_cpy args = {\n        /*.nk0  =*/ nk0,\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n 
       /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.ne3  =*/ ne3,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n    };\n\n    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int32_t * opts = op->op_params;\n    ggml_op_pool op_pool = (ggml_op_pool) opts[0];\n\n    const int32_t k0 = opts[1];\n    const int32_t s0 = opts[2];\n    const int32_t p0 = opts[3];\n\n    const int64_t IW = op->src[0]->ne[0];\n    const int64_t OW = op->ne[0];\n\n    const int64_t np = ggml_nelements(op);\n\n    ggml_metal_kargs_pool_1d args_pool_1d = {\n        /* .k0 = */  k0,\n        /* .s0 = */  s0,\n        /* .p0 = */  p0,\n        /* .IW = */  IW,\n        /* .OW = */  OW,\n        /* .np = */  np\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);\n\n    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);\n    const int ntg = (np + nth - 1) / nth;\n\n    ggml_metal_encoder_set_pipeline(enc, 
pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args_pool_1d, sizeof(args_pool_1d),  0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\n\nint ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int32_t * opts = op->op_params;\n    ggml_op_pool op_pool = (ggml_op_pool) opts[0];\n\n    const int32_t k0 = opts[1];\n    const int32_t k1 = opts[2];\n    const int32_t s0 = opts[3];\n    const int32_t s1 = opts[4];\n    const int32_t p0 = opts[5];\n    const int32_t p1 = opts[6];\n\n    const int64_t IH = op->src[0]->ne[1];\n    const int64_t IW = op->src[0]->ne[0];\n\n    const int64_t N  = op->ne[3];\n    const int64_t OC = op->ne[2];\n    const int64_t OH = op->ne[1];\n    const int64_t OW = op->ne[0];\n\n    const int64_t np = N * OC * OH * OW;\n\n    ggml_metal_kargs_pool_2d args_pool_2d = {\n        /* .k0 = */ k0,\n        /* .k1 = */ k1,\n        /* .s0 = */ s0,\n        /* .s1 = */ s1,\n        /* .p0 = */ p0,\n        /* .p1 = */ p1,\n        /* .IH = */ IH,\n        /* .IW = */ IW,\n        /* .OH = */ OH,\n        /* .OW = */ OW,\n        /* .np = */ np\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);\n\n    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);\n    const int ntg = (np + nth - 1) / nth;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, 
&args_pool_2d, sizeof(args_pool_2d), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    GGML_ASSERT(ne00 == ne10);\n\n    GGML_ASSERT(ne12 % ne02 == 0);\n    GGML_ASSERT(ne13 % ne03 == 0);\n\n    const int16_t r2 = ne12/ne02;\n    const int16_t r3 = ne13/ne03;\n\n    // find the break-even point where the matrix-matrix kernel becomes more efficient compared\n    // to the matrix-vector kernel\n    const int ne11_mm_min = 8;\n\n    // first try to use small-batch mat-mv kernels\n    // these should be efficient for BS [2, ~8]\n    if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) &&\n        (\n         (\n          (\n           op->src[0]->type == GGML_TYPE_F32  || // TODO: helper function\n           op->src[0]->type == GGML_TYPE_F16  ||\n           op->src[0]->type == GGML_TYPE_BF16 ||\n           op->src[0]->type == GGML_TYPE_Q4_0 ||\n           op->src[0]->type == GGML_TYPE_Q4_1 ||\n           op->src[0]->type == GGML_TYPE_Q5_0 ||\n           op->src[0]->type == GGML_TYPE_Q5_1 ||\n           op->src[0]->type == GGML_TYPE_Q8_0 ||\n           op->src[0]->type == GGML_TYPE_MXFP4 ||\n           op->src[0]->type == GGML_TYPE_IQ4_NL 
||\n           false) && (ne11 >= 2 && ne11 <= 8)\n         ) ||\n         (\n          (\n           op->src[0]->type == GGML_TYPE_Q4_K ||\n           op->src[0]->type == GGML_TYPE_Q5_K ||\n           op->src[0]->type == GGML_TYPE_Q6_K ||\n           op->src[0]->type == GGML_TYPE_Q2_K ||\n           op->src[0]->type == GGML_TYPE_Q3_K ||\n           false) && (ne11 >= 4 && ne11 <= 8)\n         )\n        )\n       ) {\n        // TODO: determine the optimal parameters based on grid utilization\n        //       I still don't know why we should not always use the maximum available threads:\n        //\n        //       nsg = pipeline.maxTotalThreadsPerThreadgroup / 32\n        //\n        //       my current hypothesis is that the work grid is not evenly divisible for different nsg\n        //       values and there can be some tail effects when nsg is high. need to confirm this\n        //\n        const int nsg    = 2;                 // num simdgroups per threadgroup\n\n        // num threads along row per simdgroup\n        int16_t nxpsg = 0;\n        if (ne00 % 256 == 0 && ne11 < 3) {\n            nxpsg = 16;\n        } else if (ne00 % 128 == 0) {\n            nxpsg = 8;\n        } else {\n            nxpsg = 4;\n        }\n\n        const int16_t nypsg  = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)\n        const int16_t r0ptg  = nypsg*nsg;         // num src0 rows per threadgroup\n              int16_t r1ptg  = 4;                 // num src1 rows per threadgroup\n\n        // note: not sure how optimal are those across all different hardware. 
there might be something cleverer\n        switch (ne11) {\n            case 2:\n                r1ptg = 2; break;\n            case 3:\n            case 6:\n                r1ptg = 3; break;\n            case 4:\n            case 7:\n            case 8:\n                r1ptg = 4; break;\n            case 5:\n                r1ptg = 5; break;\n            default:\n                GGML_ABORT(\"unsupported ne11\");\n        };\n\n        auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);\n\n        ggml_metal_kargs_mul_mv_ext args = {\n            /*.ne00  =*/ ne00,\n            /*.ne01  =*/ ne01,\n            /*.ne02  =*/ ne02,\n            /*.nb00  =*/ nb00,\n            /*.nb01  =*/ nb01,\n            /*.nb02  =*/ nb02,\n            /*.nb03  =*/ nb03,\n            /*.ne10  =*/ ne10,\n            /*.ne11  =*/ ne11,\n            /*.ne12  =*/ ne12,\n            /*.nb10  =*/ nb10,\n            /*.nb11  =*/ nb11,\n            /*.nb12  =*/ nb12,\n            /*.nb13  =*/ nb13,\n            /*.ne0   =*/ ne0,\n            /*.ne1   =*/ ne1,\n            /*.r2    =*/ r2,\n            /*.r3    =*/ r3,\n        };\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1);\n    } else if (\n        !ggml_is_transposed(op->src[0]) &&\n        !ggml_is_transposed(op->src[1]) &&\n        // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs\n        // AMD GPU and older A-chips will reuse matrix-vector multiplication 
kernel\n        props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {\n        //GGML_LOG_INFO(\"matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\\n\", ne00, ne01, ne02, ne11, ne12);\n\n        // some Metal matrix data types require aligned pointers\n        // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)\n        //switch (op->src[0]->type) {\n        //    case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;\n        //    case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;\n        //    case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;\n        //    default: break;\n        //}\n\n        auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);\n\n        ggml_metal_kargs_mul_mm args = {\n            /*.ne00 =*/ ne00,\n            /*.ne02 =*/ ne02,\n            /*.nb01 =*/ nb01,\n            /*.nb02 =*/ nb02,\n            /*.nb03 =*/ nb03,\n            /*.ne12 =*/ ne12,\n            /*.nb10 =*/ nb10,\n            /*.nb11 =*/ nb11,\n            /*.nb12 =*/ nb12,\n            /*.nb13 =*/ nb13,\n            /*.ne0  =*/ ne0,\n            /*.ne1  =*/ ne1,\n            /*.r2   =*/ r2,\n            /*.r3   =*/ r3,\n        };\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n        const size_t smem = pipeline.smem;\n\n        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n        ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);\n    } else {\n        auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);\n\n        const int nr0 = 
pipeline.nr0;\n        const int nr1 = pipeline.nr1;\n        const int nsg = pipeline.nsg;\n\n        const size_t smem = pipeline.smem;\n\n        ggml_metal_kargs_mul_mv args = {\n            /*.ne00 =*/ ne00,\n            /*.ne01 =*/ ne01,\n            /*.ne02 =*/ ne02,\n            /*.nb00 =*/ nb00,\n            /*.nb01 =*/ nb01,\n            /*.nb02 =*/ nb02,\n            /*.nb03 =*/ nb03,\n            /*.ne10 =*/ ne10,\n            /*.ne11 =*/ ne11,\n            /*.ne12 =*/ ne12,\n            /*.nb10 =*/ nb10,\n            /*.nb11 =*/ nb11,\n            /*.nb12 =*/ nb12,\n            /*.nb13 =*/ nb13,\n            /*.ne0  =*/ ne0,\n            /*.ne1  =*/ ne1,\n            /*.nr0  =*/ nr0,\n            /*.r2   =*/ r2,\n            /*.r3   =*/ r3,\n        };\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n        if (op->src[0]->type == GGML_TYPE_F32 ||\n            op->src[0]->type == GGML_TYPE_F16 ||\n            op->src[0]->type == GGML_TYPE_BF16 ||\n            op->src[0]->type == GGML_TYPE_Q8_0) {\n            ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);\n        } else {\n            ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);\n        }\n    }\n\n    return 1;\n}\n\nsize_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) {\n    assert(op->op == GGML_OP_MUL_MAT_ID);\n\n    const int64_t ne02 = op->src[0]->ne[2]; // n_expert\n\n    return 
ggml_type_size(GGML_TYPE_I32)*ne02;\n}\n\nsize_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {\n    assert(op->op == GGML_OP_MUL_MAT_ID);\n\n    const int64_t ne02 = op->src[0]->ne[2]; // n_expert\n    const int64_t ne21 = op->src[2]->ne[1]; // n_token\n\n    return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;\n}\n\nint ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    // src2 = ids\n    GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);\n\n    GGML_ASSERT(!ggml_is_transposed(op->src[0]));\n    GGML_ASSERT(!ggml_is_transposed(op->src[1]));\n\n    GGML_ASSERT(ne03 == 1);\n    GGML_ASSERT(ne13 == 1);\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);\n    ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    const uint32_t r2 = 1;\n    const uint32_t r3 = 1;\n\n    // find the break-even point where the matrix-matrix kernel becomes more efficient compared\n    // to the matrix-vector kernel\n    // ne20 = n_used_experts\n    // ne21 = n_rows (batch size)\n    const int ne21_mm_id_min = 32;\n\n    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {\n        // some Metal matrix 
data types require aligned pointers\n        // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)\n        //switch (op->src[0]->type) {\n        //    case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;\n        //    case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;\n        //    case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;\n        //    default: break;\n        //}\n\n        // extra buffers for intermediate id mapping\n        ggml_metal_buffer_id bid_tpe = bid_dst;\n        bid_tpe.offs += ggml_nbytes(op);\n\n        ggml_metal_buffer_id bid_ids = bid_tpe;\n        bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op);\n\n        {\n            ggml_metal_kargs_mul_mm_id_map0 args = {\n                ne02,\n                ne10,\n                ne11, // n_expert_used (bcast)\n                nb11,\n                nb12,\n                ne21, // n_tokens\n                ne20, // n_expert_used\n                nb21,\n            };\n\n            auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);\n\n            const size_t smem = pipeline.smem;\n\n            GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n            GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);\n\n            ggml_metal_encoder_set_pipeline(enc, pipeline);\n            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n            ggml_metal_encoder_set_buffer  (enc, bid_src2, 1);\n            ggml_metal_encoder_set_buffer  (enc, bid_tpe,  2);\n            ggml_metal_encoder_set_buffer  (enc, bid_ids,  3);\n\n            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n            ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1);\n        }\n\n        // this barrier is always needed because the next kernel has to wait for the id maps to be computed\n        
ggml_metal_op_concurrency_reset(ctx);\n\n        {\n            auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);\n\n            ggml_metal_kargs_mul_mm_id args = {\n                /*.ne00  =*/ ne00,\n                /*.ne02  =*/ ne02,\n                /*.nb01  =*/ nb01,\n                /*.nb02  =*/ nb02,\n                /*.nb03  =*/ nb03,\n                /*.ne11  =*/ ne11, // n_expert_used (bcast)\n                /*.nb10  =*/ nb10,\n                /*.nb11  =*/ nb11,\n                /*.nb12  =*/ nb12,\n                /*.nb13  =*/ nb13,\n                /*.ne20  =*/ ne20, // n_expert_used\n                /*.ne21  =*/ ne21, // n_tokens\n                /*.ne0   =*/ ne0,\n                /*.ne1   =*/ ne1,\n                /*.r2    =*/ r2,\n                /*.r3    =*/ r3,\n            };\n\n            ggml_metal_encoder_set_pipeline(enc, pipeline);\n            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n            ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n            ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);\n            ggml_metal_encoder_set_buffer  (enc, bid_tpe,  3);\n            ggml_metal_encoder_set_buffer  (enc, bid_ids,  4);\n            ggml_metal_encoder_set_buffer  (enc, bid_dst,  5);\n\n            const size_t smem = pipeline.smem;\n\n            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n            ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);\n        }\n    } else {\n        auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);\n\n        const int nr0 = pipeline.nr0;\n        const int nr1 = pipeline.nr1;\n        const int nsg = pipeline.nsg;\n\n        const size_t smem = pipeline.smem;\n\n        ggml_metal_kargs_mul_mv_id args = {\n            /*.nei0 =*/ ne20,\n            /*.nei1 =*/ ne21,\n            /*.nbi1 =*/ nb21,\n            /*.ne00 =*/ ne00,\n            /*.ne01 =*/ ne01,\n           
 /*.ne02 =*/ ne02,\n            /*.nb00 =*/ nb00,\n            /*.nb01 =*/ nb01,\n            /*.nb02 =*/ nb02,\n            /*.ne10 =*/ ne10,\n            /*.ne11 =*/ ne11,\n            /*.ne12 =*/ ne12,\n            /*.ne13 =*/ ne13,\n            /*.nb10 =*/ nb10,\n            /*.nb11 =*/ nb11,\n            /*.nb12 =*/ nb12,\n            /*.ne0  =*/ ne0,\n            /*.ne1  =*/ ne1,\n            /*.nb1  =*/ nb1,\n            /*.nr0  =*/ nr0,\n        };\n\n        if (ggml_is_quantized(op->src[0]->type)) {\n            GGML_ASSERT(ne00 >= nsg*nr0);\n        }\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer(enc, bid_src0, 1);\n        ggml_metal_encoder_set_buffer(enc, bid_src1, 2);\n        ggml_metal_encoder_set_buffer(enc, bid_dst,  3);\n        ggml_metal_encoder_set_buffer(enc, bid_src2, 4);\n\n        const int64_t _ne1 = 1;\n        const int64_t ne123 = ne20*ne21;\n\n        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n        if (op->src[0]->type == GGML_TYPE_F32 ||\n            op->src[0]->type == GGML_TYPE_F16 ||\n            op->src[0]->type == GGML_TYPE_BF16 ||\n            op->src[0]->type == GGML_TYPE_Q8_0) {\n            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);\n        } else {\n            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);\n        }\n    }\n\n    return 1;\n}\n\nint ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, 
op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);\n    GGML_ASSERT(op->type         == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    ggml_metal_kargs_add_id args = {\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb11 =*/ nb11,\n        /*.nb21 =*/ nb21,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         4);\n\n    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1);\n\n    return 1;\n}\n\nbool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n\n    const int64_t ne00 = op->src[0]->ne[0]; // head size\n    const int64_t ne01 = op->src[0]->ne[1]; // batch size\n\n    // use vec kernel if the batch size is small and if the head size is supported\n    return (ne01 < 20) && (ne00 % 32 == 0);\n}\n\nsize_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, 
op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);\n\n    size_t res = 0;\n\n    const bool has_mask = op->src[3] != nullptr;\n\n    // note: the non-vec kernel requires more extra memory, so always reserve for it\n    GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);\n\n    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {\n    if (false) {\n        // note: always reserve the padding space to avoid graph reallocations\n        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;\n        const bool has_kvpad = true;\n\n        if (has_kvpad) {\n            res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(\n                nb11*ne12*ne13 +\n                nb21*ne22*ne23 +\n                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));\n        }\n    } else {\n        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;\n        const bool has_kvpad = true;\n\n        if (has_kvpad) {\n            res += OP_FLASH_ATTN_EXT_NCPSG*(\n                nb11*ne12*ne13 +\n                nb21*ne22*ne23 +\n                (has_mask ? 
ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));\n        }\n    }\n\n    return res;\n}\n\nsize_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n  //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n  //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n  //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);\n\n    size_t res = 0;\n\n    const bool has_mask = op->src[3] != nullptr;\n\n    if (!has_mask) {\n        return res;\n    }\n\n    const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);\n\n    // this optimization is not useful for the vector kernels\n    // note: always reserve the blk buffer to avoid graph reallocations\n    //if (is_vec) {\n    //    return res;\n    //}\n\n    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;\n    const int ncpsg = is_vec ? 
OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;\n\n    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;\n    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;\n\n    res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);\n\n    return res;\n}\n\nsize_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {\n    assert(op->op == GGML_OP_FLASH_ATTN_EXT);\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);\n  //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);\n  //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);\n\n    size_t res = 0;\n\n    // note: always reserve the temp buffer to avoid graph reallocations\n    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {\n    if (true) {\n        const int64_t nwg = 32;\n        const int64_t ne01_max = std::min(ne01, 32);\n\n        // temp buffer for writing the results from each workgroup\n        // - ne20: the size of the Value head\n        // -  + 2: the S and M values for each intermediate result\n        res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));\n    }\n\n    return res;\n}\n\nint ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb2, 
op->src[2], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS( int32_t, nb,  op,         nb);\n\n    GGML_ASSERT(ne00 % 4 == 0);\n\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == op->src[2]->type);\n\n    //GGML_ASSERT(ggml_are_same_shape (src1, src2));\n    GGML_ASSERT(ne11 == ne21);\n    GGML_ASSERT(ne12 == ne22);\n\n    GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);\n    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&\n            \"the Flash-Attention Metal kernel requires the mask to be at least n_queries big\");\n\n    float scale;\n    float max_bias;\n    float logit_softcap;\n\n    memcpy(&scale,         ((const int32_t *) op->op_params) + 0, sizeof(scale));\n    memcpy(&max_bias,      ((const int32_t *) op->op_params) + 1, sizeof(max_bias));\n    memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));\n\n    if (logit_softcap != 0.0f) {\n        scale /= logit_softcap;\n    }\n\n    const bool has_mask  = op->src[3] != NULL;\n    const bool has_sinks = op->src[4] != NULL;\n    const bool has_bias  = max_bias != 0.0f;\n    const bool has_scap  = logit_softcap != 0.0f;\n\n    const uint32_t n_head      = op->src[0]->ne[2];\n    const  int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    GGML_ASSERT(ne01 < 65536);\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);\n    ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);\n    ggml_metal_buffer_id bid_src3 = has_mask  ? 
ggml_metal_get_buffer_id(op->src[3]) : bid_src0;\n    ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;\n\n    ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);\n\n    ggml_metal_buffer_id bid_pad = bid_dst;\n    bid_pad.offs += ggml_nbytes(op);\n\n    ggml_metal_buffer_id bid_blk = bid_pad;\n    bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);\n\n    ggml_metal_buffer_id bid_tmp = bid_blk;\n    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);\n\n    if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {\n        // half8x8 kernel\n        const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup\n        const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup\n\n        GGML_ASSERT(nqptg <= 32);\n        GGML_ASSERT(nqptg  % 8  == 0);\n        GGML_ASSERT(ncpsg  % 32 == 0);\n\n        bool need_sync = false;\n\n        const bool has_kvpad = ne11 % ncpsg != 0;\n\n        if (has_kvpad) {\n            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);\n\n            ggml_metal_kargs_flash_attn_ext_pad args0 = {\n                /*.ne11    =*/ne11,\n                /*.ne_12_2 =*/ne12,\n                /*.ne_12_3 =*/ne13,\n                /*.nb11    =*/nb11,\n                /*.nb12    =*/nb12,\n                /*.nb13    =*/nb13,\n                /*.nb21    =*/nb21,\n                /*.nb22    =*/nb22,\n                /*.nb23    =*/nb23,\n                /*.ne31    =*/ne31,\n                /*.ne32    =*/ne32,\n                /*.ne33    =*/ne33,\n                /*.nb31    =*/nb31,\n                /*.nb32    =*/nb32,\n                /*.nb33    =*/nb33,\n            };\n\n            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);\n\n            ggml_metal_encoder_set_pipeline(enc, pipeline0);\n            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);\n            ggml_metal_encoder_set_buffer  
(enc, bid_src1, 1);\n            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);\n            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);\n            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);\n\n            assert(ne12 == ne22);\n            assert(ne13 == ne23);\n\n            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);\n\n            need_sync = true;\n        }\n\n        if (has_mask) {\n            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);\n\n            ggml_metal_kargs_flash_attn_ext_blk args0 = {\n                /*.ne01 =*/ ne01,\n                /*.ne30 =*/ ne30,\n                /*.ne31 =*/ ne31,\n                /*.ne32 =*/ ne32,\n                /*.ne33 =*/ ne33,\n                /*.nb31 =*/ nb31,\n                /*.nb32 =*/ nb32,\n                /*.nb33 =*/ nb33,\n            };\n\n            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);\n\n            ggml_metal_encoder_set_pipeline(enc, pipeline0);\n            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);\n            ggml_metal_encoder_set_buffer  (enc, bid_src3, 1);\n            ggml_metal_encoder_set_buffer  (enc, bid_blk,  2);\n\n            const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);\n            const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);\n\n            ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);\n\n            need_sync = true;\n        }\n\n        if (need_sync) {\n            ggml_metal_op_concurrency_reset(ctx);\n        }\n\n        const int is_q = ggml_is_quantized(op->src[1]->type) ? 
1 : 0;\n\n        // 2*(2*ncpsg)\n        // ncpsg soft_max values + ncpsg mask values\n        //\n        // 16*32*(nsg)\n        // the shared memory needed for the simdgroups to load the KV cache\n        // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG\n        //\n#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))\n\n        //int64_t nsgmax = 4;\n        //\n        //if (is_q) {\n        //    nsgmax = 2;\n        //    while (true) {\n        //        const size_t smem = FATTN_SMEM(nsgmax);\n        //        if (smem > props_dev->max_theadgroup_memory_size) {\n        //            break;\n        //        }\n        //        nsgmax *= 2;\n        //    }\n        //    nsgmax /= 2;\n        //}\n\n        // simdgroups per threadgroup (a.k.a. warps)\n        //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;\n        int32_t nsg = ne00 >= 512 ? 
8 : 4;\n\n        const size_t smem = FATTN_SMEM(nsg);\n\n        ggml_metal_kargs_flash_attn_ext args = {\n            /*.ne01          =*/ ne01,\n            /*.ne02          =*/ ne02,\n            /*.ne03          =*/ ne03,\n            /*.nb01          =*/ nb01,\n            /*.nb02          =*/ nb02,\n            /*.nb03          =*/ nb03,\n            /*.ne11          =*/ ne11,\n            /*.ne_12_2       =*/ ne12,\n            /*.ne_12_3       =*/ ne13,\n            /*.ns10          =*/ int32_t(nb11/nb10),\n            /*.nb11          =*/ nb11,\n            /*.nb12          =*/ nb12,\n            /*.nb13          =*/ nb13,\n            /*.ns20          =*/ int32_t(nb21/nb20),\n            /*.nb21          =*/ nb21,\n            /*.nb22          =*/ nb22,\n            /*.nb23          =*/ nb23,\n            /*.ne31          =*/ ne31,\n            /*.ne32          =*/ ne32,\n            /*.ne33          =*/ ne33,\n            /*.nb31          =*/ nb31,\n            /*.nb32          =*/ nb32,\n            /*.nb33          =*/ nb33,\n            /*.ne1           =*/ ne1,\n            /*.ne2           =*/ ne2,\n            /*.ne3           =*/ ne3,\n            /*.scale         =*/ scale,\n            /*.max_bias      =*/ max_bias,\n            /*.m0            =*/ m0,\n            /*.m1            =*/ m1,\n            /*.n_head_log2   =*/ n_head_log2,\n            /*.logit_softcap =*/ logit_softcap,\n        };\n\n        auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);\n        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);\n        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);\n        ggml_metal_encoder_set_buffer  (enc, 
bid_src4, 5);\n        ggml_metal_encoder_set_buffer  (enc, bid_pad,  6);\n        ggml_metal_encoder_set_buffer  (enc, bid_blk,  7);\n        ggml_metal_encoder_set_buffer  (enc, bid_dst,  8);\n\n        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);\n#undef FATTN_SMEM\n    } else {\n        // half4x4 kernel\n        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup\n        const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!\n        const int nhptg = 1;                           // heads per threadgroup\n\n        GGML_ASSERT(nqptg <= 32);\n        GGML_ASSERT(nqptg  % 1  == 0);\n        GGML_ASSERT(ncpsg  % 32 == 0);\n\n        bool need_sync = false;\n\n        const bool has_kvpad = ne11 % ncpsg != 0;\n\n        if (has_kvpad) {\n            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);\n\n            ggml_metal_kargs_flash_attn_ext_pad args0 = {\n                /*.ne11    =*/ne11,\n                /*.ne_12_2 =*/ne12,\n                /*.ne_12_3 =*/ne13,\n                /*.nb11    =*/nb11,\n                /*.nb12    =*/nb12,\n                /*.nb13    =*/nb13,\n                /*.nb21    =*/nb21,\n                /*.nb22    =*/nb22,\n                /*.nb23    =*/nb23,\n                /*.ne31    =*/ne31,\n                /*.ne32    =*/ne32,\n                /*.ne33    =*/ne33,\n                /*.nb31    =*/nb31,\n                /*.nb32    =*/nb32,\n                /*.nb33    =*/nb33,\n            };\n\n            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);\n\n            ggml_metal_encoder_set_pipeline(enc, pipeline0);\n            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);\n            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);\n            
ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);\n            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);\n            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);\n\n            assert(ne12 == ne22);\n            assert(ne13 == ne23);\n\n            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);\n\n            need_sync = true;\n        }\n\n        if (need_sync) {\n            ggml_metal_op_concurrency_reset(ctx);\n        }\n\n        // note: for simplicity assume the K is larger or equal than V\n        GGML_ASSERT(ne10 >= ne20);\n\n        // ne00 + 2*ncpsg*(nsg)\n        // for each query, we load it as f16 in shared memory (ne00)\n        // and store the soft_max values and the mask\n        //\n        // ne20*(nsg)\n        // each simdgroup has a full f32 head vector in shared mem to accumulate results\n        //\n#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))\n\n        int64_t nsg = 1;\n\n        // workgroups\n        // each workgroup handles nsg*nkpsg cache values\n        int32_t nwg = 1;\n        if (false) {\n            // for small KV caches, we could launch a single workgroup and write the results directly to dst/\n            // however, this does not lead to significant improvement, so disabled\n            nwg = 1;\n            nsg = 4;\n        } else {\n            nwg = 32;\n            nsg = 1;\n            while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {\n                nsg *= 2;\n            }\n        }\n\n        ggml_metal_kargs_flash_attn_ext_vec args = {\n            /*.ne01          =*/ ne01,\n            /*.ne02          =*/ ne02,\n            /*.ne03          =*/ ne03,\n            /*.nb01          =*/ nb01,\n            /*.nb02          =*/ nb02,\n            /*.nb03          =*/ nb03,\n            /*.ne11          =*/ ne11,\n            /*.ne_12_2       =*/ ne12,\n            
/*.ne_12_3       =*/ ne13,\n            /*.ns10          =*/ int32_t(nb11/nb10),\n            /*.nb11          =*/ nb11,\n            /*.nb12          =*/ nb12,\n            /*.nb13          =*/ nb13,\n            /*.ns20          =*/ int32_t(nb21/nb20),\n            /*.nb21          =*/ nb21,\n            /*.nb22          =*/ nb22,\n            /*.nb23          =*/ nb23,\n            /*.ne31          =*/ ne31,\n            /*.ne32          =*/ ne32,\n            /*.ne33          =*/ ne33,\n            /*.nb31          =*/ nb31,\n            /*.nb32          =*/ nb32,\n            /*.nb33          =*/ nb33,\n            /*.ne1           =*/ ne1,\n            /*.ne2           =*/ ne2,\n            /*.ne3           =*/ ne3,\n            /*.scale         =*/ scale,\n            /*.max_bias      =*/ max_bias,\n            /*.m0            =*/ m0,\n            /*.m1            =*/ m1,\n            /*.n_head_log2   =*/ n_head_log2,\n            /*.logit_softcap =*/ logit_softcap,\n        };\n\n        auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);\n\n        GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);\n        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);\n        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);\n        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);\n\n        const size_t smem = FATTN_SMEM(nsg);\n\n        //printf(\"smem: %zu, max: %zu, nsg = %d, nsgmax = %d\\n\", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);\n        GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);\n\n        if (nwg == 1) {\n            
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);\n\n            // using 1 workgroup -> write the result directly into dst\n            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);\n            ggml_metal_encoder_set_buffer(enc, bid_dst, 7);\n\n            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);\n        } else {\n            // sanity checks\n            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);\n\n            GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);\n            GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));\n\n            // write the results from each workgroup into a temp buffer\n            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);\n            ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);\n\n            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);\n\n            // sync the 2 kernels\n            ggml_metal_op_concurrency_reset(ctx);\n\n            // reduce the results from the workgroups\n            {\n                const int32_t nrows = ne1*ne2*ne3;\n\n                ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {\n                    nrows,\n                };\n\n                auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);\n\n                ggml_metal_encoder_set_pipeline(enc, pipeline0);\n                ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);\n                ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);\n                ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);\n\n                ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1);\n            }\n        }\n#undef FATTN_SMEM\n    }\n\n    
return 1;\n}\n\nint ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    const bool use_fusion = ctx->use_fusion;\n\n    const int debug_fusion = ctx->debug_fusion;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    ggml_metal_kargs_bin args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne10 =*/ ne10,\n        /*.ne11 =*/ ne11,\n        /*.ne12 =*/ ne12,\n        /*.ne13 =*/ ne13,\n        /*.nb10 =*/ nb10,\n        /*.nb11 =*/ nb11,\n        /*.nb12 =*/ nb12,\n        /*.nb13 =*/ nb13,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.ne3  =*/ ne3,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n        /*.offs =*/ 0,\n        /*.o1   =*/ { bid_src1.offs },\n    };\n\n    ggml_op fops[8];\n\n    int n_fuse = 1;\n\n    // c[0] = add(a,    b[0])\n    // c[1] = add(c[0], b[1])\n    // c[2] = add(c[1], b[2])\n    // ...\n    if (use_fusion) {\n 
       fops[0] = GGML_OP_ADD;\n        fops[1] = GGML_OP_ADD;\n        fops[2] = GGML_OP_ADD;\n        fops[3] = GGML_OP_ADD;\n        fops[4] = GGML_OP_ADD;\n        fops[5] = GGML_OP_ADD;\n        fops[6] = GGML_OP_ADD;\n        fops[7] = GGML_OP_ADD;\n\n        // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops\n        //       across splits. idx_end indicates the last node in the current split\n        for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {\n            if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {\n                break;\n            }\n\n            ggml_tensor * f0 = ctx->node(idx + n_fuse);\n            ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);\n\n            if (f0 != f1->src[0]) {\n                break;\n            }\n\n            // b[0] === b[1] === ...\n            if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {\n                break;\n            }\n\n            // only fuse ops if src1 is in the same Metal buffer\n            ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);\n            if (bid_fuse.metal != bid_src1.metal) {\n                break;\n            }\n\n            //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;\n\n            args.o1[n_fuse + 1] = bid_fuse.offs;\n        }\n\n        ++n_fuse;\n\n        if (debug_fusion > 1 && n_fuse > 1) {\n            GGML_LOG_DEBUG(\"%s: fuse: ADD x %d\\n\", __func__, n_fuse);\n        }\n    }\n\n    // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer\n    bid_src1.offs = 0;\n\n    struct ggml_metal_pipeline_with_params pipeline;\n\n    pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);\n\n    if (n_fuse > 1) {\n        bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));\n\n        for (int i = 1; i < n_fuse; ++i) {\n            if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {\n                ggml_metal_op_concurrency_reset(ctx);\n\n    
            break;\n            }\n        }\n    }\n\n    if (pipeline.c4) {\n        args.ne00 = ne00/4;\n        args.ne10 = ne10/4;\n        args.ne0  = ne0/4;\n    }\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n    ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);\n    ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);\n\n    if (pipeline.cnt) {\n        ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);\n    } else {\n        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n        int nth = 1;\n\n        while (2*nth < args.ne0 && nth < nth_max) {\n            nth *= 2;\n        }\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n    }\n\n    return n_fuse;\n}\n\nint ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    float eps;\n    memcpy(&eps, op->op_params, sizeof(float));\n\n    ggml_metal_kargs_l2_norm args = {\n        /*.ne00  =*/ ne00,\n        /*.ne01  =*/ ne01,\n        /*.ne02  =*/ ne02,\n        /*.ne03  =*/ ne03,\n        /*.nb00  =*/ nb00,\n        /*.nb01  =*/ nb01,\n        /*.nb02  =*/ nb02,\n        /*.nb03  =*/ nb03,\n        /*.ne0   =*/ ne0,\n        /*.ne1   =*/ ne1,\n        /*.ne2   =*/ ne2,\n        /*.ne3   =*/ ne3,\n        /*.nb0   
=*/ nb0,\n        /*.nb1   =*/ nb1,\n        /*.nb2   =*/ nb2,\n        /*.nb3   =*/ nb3,\n        /*.eps   =*/ eps,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);\n\n    if (pipeline.c4) {\n        args.ne00 = ne00/4;\n        args.ne0  = ne0/4;\n    }\n\n    int nth = 32; // SIMD width\n\n    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n        nth *= 2;\n    }\n\n    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n    const size_t smem = pipeline.smem;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int32_t ngrp = ((const int32_t *) op->op_params)[0];\n\n    float eps;\n    memcpy(&eps, op->op_params + 1, sizeof(float));\n\n    ggml_metal_kargs_group_norm args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.ngrp =*/ ngrp,\n        /*.eps  =*/ eps,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);\n\n    int nth = 32; // SIMD width\n    //while (nth < ne00/4 && nth < 
ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n    //    nth *= 2;\n    //}\n\n    //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n    //nth = std::min(nth, ne00/4);\n\n    const size_t smem = pipeline.smem;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    const bool use_fusion = ctx->use_fusion;\n\n    const int debug_fusion = ctx->debug_fusion;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    float eps;\n    memcpy(&eps, op->op_params, sizeof(float));\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    ggml_metal_kargs_norm args = {\n        /*.ne00   =*/ ne00,\n        /*.ne00_t =*/ ne00 % 4 == 0 ? 
ne00/4 : ne00,\n        /*.nb1    =*/ nb1,\n        /*.nb2    =*/ nb2,\n        /*.nb3    =*/ nb3,\n        /*.eps    =*/ eps,\n        /*.nef1   =*/ { ne01 },\n        /*.nef2   =*/ { ne02 },\n        /*.nef3   =*/ { ne03 },\n        /*.nbf1   =*/ { nb01 },\n        /*.nbf2   =*/ { nb02 },\n        /*.nbf3   =*/ { nb03 },\n    };\n\n    ggml_op fops[8];\n\n    int n_fuse = 1;\n\n    ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };\n\n    // d[0] = norm(a)\n    // d[1] = mul(d[0], b)\n    // d[2] = add(d[1], c)\n    if (use_fusion) {\n        fops[0] = op->op;\n        fops[1] = GGML_OP_MUL;\n        fops[2] = GGML_OP_ADD;\n\n        for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {\n            if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {\n                break;\n            }\n\n            ggml_tensor * f0 = ctx->node(idx + n_fuse);\n            ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);\n\n            if (f0 != f1->src[0]) {\n                break;\n            }\n\n            if (f1->src[1]->ne[0] != op->ne[0]) {\n                break;\n            }\n\n            if (!ggml_is_contiguous_rows(f1->src[1])) {\n                break;\n            }\n\n            if (f1->type != GGML_TYPE_F32) {\n                break;\n            }\n\n            //ctx->fuse_cnt[f1->op]++;\n\n            bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);\n\n            args.nef1[n_fuse + 1] = f1->src[1]->ne[1];\n            args.nef2[n_fuse + 1] = f1->src[1]->ne[2];\n            args.nef3[n_fuse + 1] = f1->src[1]->ne[3];\n\n            args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];\n            args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];\n            args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];\n        }\n\n        ++n_fuse;\n\n        if (debug_fusion > 1 && n_fuse > 1) {\n            if (n_fuse == 2) {\n                GGML_LOG_DEBUG(\"%s: fuse: %s + MUL\\n\", __func__, ggml_op_name(op->op));\n            }\n            if (n_fuse == 3) {\n                
GGML_LOG_DEBUG(\"%s: fuse: %s + MUL + ADD\\n\", __func__, ggml_op_name(op->op));\n            }\n        }\n    }\n\n    if (n_fuse > 1) {\n        bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));\n\n        for (int i = 1; i < n_fuse; ++i) {\n            if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {\n                ggml_metal_op_concurrency_reset(ctx);\n\n                break;\n            }\n        }\n    }\n\n    auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);\n\n    int nth = 32; // SIMD width\n\n    while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n        nth *= 2;\n    }\n\n    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n    nth = std::min(nth, args.ne00_t);\n\n    const size_t smem = pipeline.smem;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, bid_src0,    1);\n    ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 2);\n    ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 3);\n    ggml_metal_encoder_set_buffer  (enc, bid_dst,     4);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n\n    return n_fuse;\n}\n\nint ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    // make sure we have one or more position id(ne10) per token(ne02)\n    
GGML_ASSERT(ne10 % ne02 == 0);\n    GGML_ASSERT(ne10 >= ne02);\n\n    const int nth = std::min(1024, ne00);\n\n    const int n_past     = ((const int32_t *) op->op_params)[0];\n    const int n_dims     = ((const int32_t *) op->op_params)[1];\n  //const int mode       = ((const int32_t *) op->op_params)[2];\n    // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal\n    const int n_ctx_orig = ((const int32_t *) op->op_params)[4];\n\n    float freq_base;\n    float freq_scale;\n    float ext_factor;\n    float attn_factor;\n    float beta_fast;\n    float beta_slow;\n\n    memcpy(&freq_base,   (const int32_t *) op->op_params +  5, sizeof(float));\n    memcpy(&freq_scale,  (const int32_t *) op->op_params +  6, sizeof(float));\n    memcpy(&ext_factor,  (const int32_t *) op->op_params +  7, sizeof(float));\n    memcpy(&attn_factor, (const int32_t *) op->op_params +  8, sizeof(float));\n    memcpy(&beta_fast,   (const int32_t *) op->op_params +  9, sizeof(float));\n    memcpy(&beta_slow,   (const int32_t *) op->op_params + 10, sizeof(float));\n\n    // mrope\n    const int sect_0 = ((const int32_t *) op->op_params)[11];\n    const int sect_1 = ((const int32_t *) op->op_params)[12];\n    const int sect_2 = ((const int32_t *) op->op_params)[13];\n    const int sect_3 = ((const int32_t *) op->op_params)[14];\n\n    ggml_metal_kargs_rope args = {\n        /*.ne00        =*/ ne00,\n        /*.ne01        =*/ ne01,\n        /*.ne02        =*/ ne02,\n        /*.ne03        =*/ ne03,\n        /*.nb00        =*/ nb00,\n        /*.nb01        =*/ nb01,\n        /*.nb02        =*/ nb02,\n        /*.nb03        =*/ nb03,\n        /*.ne0         =*/ ne0,\n        /*.ne1         =*/ ne1,\n        /*.ne2         =*/ ne2,\n        /*.ne3         =*/ ne3,\n        /*.nb0         =*/ nb0,\n        /*.nb1         =*/ nb1,\n        /*.nb2         =*/ nb2,\n        /*.nb3         =*/ nb3,\n        /*.n_past      =*/ n_past,\n        /*.n_dims      =*/ n_dims,\n        /*.n_ctx_orig  =*/ 
n_ctx_orig,\n        /*.freq_base   =*/ freq_base,\n        /*.freq_scale  =*/ freq_scale,\n        /*.ext_factor  =*/ ext_factor,\n        /*.attn_factor =*/ attn_factor,\n        /*.beta_fast   =*/ beta_fast,\n        /*.beta_slow   =*/ beta_slow,\n        /* sect_0      =*/ sect_0,\n        /* sect_1      =*/ sect_1,\n        /* sect_2      =*/ sect_2,\n        /* sect_3      =*/ sect_3,\n        /* src2        =*/ op->src[2] != nullptr,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    if (op->src[2]) {\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);\n    } else {\n        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 3);\n    }\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         4);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int32_t s0 = ((const int32_t *)(op->op_params))[0];\n    const int32_t s1 = ((const int32_t *)(op->op_params))[1];\n    const int32_t p0 = ((const int32_t *)(op->op_params))[2];\n    const int32_t p1 = ((const int32_t *)(op->op_params))[3];\n    const int32_t d0 = ((const int32_t *)(op->op_params))[4];\n    const int32_t d1 = ((const int32_t 
*)(op->op_params))[5];\n\n    const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;\n\n    const int32_t N  = op->src[1]->ne[is_2D ? 3 : 2];\n    const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1];\n    const int32_t IH = is_2D ? op->src[1]->ne[1] : 1;\n    const int32_t IW =         op->src[1]->ne[0];\n\n    const int32_t KH = is_2D ? op->src[0]->ne[1] : 1;\n    const int32_t KW =         op->src[0]->ne[0];\n\n    const int32_t OH = is_2D ? op->ne[2] : 1;\n    const int32_t OW =         op->ne[1];\n\n    const int32_t CHW = IC * KH * KW;\n\n    const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;\n    const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;\n\n    ggml_metal_kargs_im2col args = {\n        /*.ofs0 =*/ ofs0,\n        /*.ofs1 =*/ ofs1,\n        /*.IW   =*/ IW,\n        /*.IH   =*/ IH,\n        /*.CHW  =*/ CHW,\n        /*.s0   =*/ s0,\n        /*.s1   =*/ s1,\n        /*.p0   =*/ p0,\n        /*.p1   =*/ p1,\n        /*.d0   =*/ d0,\n        /*.d1   =*/ d1,\n        /*.N    =*/ N,\n        /*.KH   =*/ KH,\n        /*.KW   =*/ KW,\n        /*.KHW  =*/ KH * KW,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);\n\n    GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n    const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);\n\n    return 1;\n}\n\nint ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, 
op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    GGML_ASSERT(ggml_is_contiguous(op->src[0]));\n    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->type == GGML_TYPE_F32);\n    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);\n\n    const int32_t s0 = ((const int32_t *) op->op_params)[0];\n    const int32_t s1 = ((const int32_t *) op->op_params)[1];\n    const int32_t p0 = ((const int32_t *) op->op_params)[2];\n    const int32_t p1 = ((const int32_t *) op->op_params)[3];\n    const int32_t d0 = ((const int32_t *) op->op_params)[4];\n    const int32_t d1 = ((const int32_t *) op->op_params)[5];\n\n    ggml_metal_kargs_conv_2d args = {\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.nb10 =*/ nb10,\n        /*.nb11 =*/ nb11,\n        /*.nb12 =*/ nb12,\n        /*.nb13 =*/ nb13,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n        /*.IW   =*/ ne10,\n        /*.IH   =*/ ne11,\n        /*.KW   =*/ ne00,\n        /*.KH   =*/ ne01,\n        /*.IC   =*/ ne02,\n        /*.OC   =*/ ne03,\n        /*.OW   =*/ ne0,\n        /*.OH   =*/ ne1,\n        /*.N    =*/ ne3,\n        /*.s0   =*/ s0,\n        /*.s1   =*/ s1,\n        /*.p0   =*/ p0,\n        /*.p1   =*/ p1,\n        /*.d0   =*/ d0,\n        /*.d1   =*/ d1,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);\n\n    int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);\n    nth = std::min(nth, 256);\n    nth = std::max(nth, 1);\n\n    const uint64_t n_out = ggml_nelements(op);\n\n    uint64_t tg = (n_out + nth - 1)/nth;\n    tg = 
std::max<uint64_t>(tg, 1);\n    tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int32_t s0 = ((const int32_t *)(op->op_params))[0];\n\n    const int32_t IC = op->src[1]->ne[1];\n    const int32_t IL = op->src[1]->ne[0];\n\n    const int32_t K  = op->src[0]->ne[0];\n\n    const int32_t OL = op->ne[0];\n    const int32_t OC = op->ne[1];\n\n    ggml_metal_kargs_conv_transpose_1d args = {\n        /*.IC  =*/ IC,\n        /*.IL  =*/ IL,\n        /*.K   =*/ K,\n        /*.s0  =*/ s0,\n        /*.nb0 =*/ nb0,\n        /*.nb1 =*/ nb1,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, 
ggml_metal_get_buffer_id(op),         3);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int32_t s0 = ((const int32_t *)(op->op_params))[0];\n\n    const int32_t IC = op->src[1]->ne[2];\n    const int32_t IH = op->src[1]->ne[1];\n    const int32_t IW = op->src[1]->ne[0];\n\n    const int32_t KH = op->src[0]->ne[1];\n    const int32_t KW = op->src[0]->ne[0];\n\n    const int32_t OW = op->ne[0];\n    const int32_t OH = op->ne[1];\n    const int32_t OC = op->ne[2];\n\n    ggml_metal_kargs_conv_transpose_2d args = {\n        /*.IC  =*/ IC,\n        /*.IH  =*/ IH,\n        /*.IW  =*/ IW,\n        /*.KH  =*/ KH,\n        /*.KW  =*/ KW,\n        /*.OC  =*/ OC,\n        /*.s0  =*/ s0,\n        /*.nb0 =*/ nb0,\n        /*.nb1 =*/ nb1,\n        /*.nb2 =*/ nb2,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);\n\n    // Metal requires buffer size to be multiple of 16 bytes\n    const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);\n    
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    float sf0 = (float)ne0/op->src[0]->ne[0];\n    float sf1 = (float)ne1/op->src[0]->ne[1];\n    float sf2 = (float)ne2/op->src[0]->ne[2];\n    float sf3 = (float)ne3/op->src[0]->ne[3];\n\n    const int32_t mode_flags = ggml_get_op_params_i32(op, 0);\n\n    float poffs = 0.5f;\n\n    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {\n        poffs = 0.0f;\n        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;\n        sf1 = ne1 > 1 && ne01 > 1 ? 
(float)(ne1 - 1) / (ne01 - 1) : sf1;\n    }\n\n    ggml_metal_kargs_upscale args = {\n        /*.ne00  =*/ ne00,\n        /*.ne01  =*/ ne01,\n        /*.ne02  =*/ ne02,\n        /*.ne03  =*/ ne03,\n        /*.nb00  =*/ nb00,\n        /*.nb01  =*/ nb01,\n        /*.nb02  =*/ nb02,\n        /*.nb03  =*/ nb03,\n        /*.ne0   =*/ ne0,\n        /*.ne1   =*/ ne1,\n        /*.ne2   =*/ ne2,\n        /*.ne3   =*/ ne3,\n        /*.nb0   =*/ nb0,\n        /*.nb1   =*/ nb1,\n        /*.nb2   =*/ nb2,\n        /*.nb3   =*/ nb3,\n        /*.sf0   =*/ sf0,\n        /*.sf1   =*/ sf1,\n        /*.sf2   =*/ sf2,\n        /*.sf3   =*/ sf3,\n        /*.poffs =*/ poffs,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);\n\n    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    ggml_metal_kargs_pad args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.ne3  =*/ ne3,\n    
    /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);\n\n    const int nth = std::min(1024, ne0);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    ggml_metal_kargs_pad_reflect_1d args = {\n        /*.ne00 =*/ ne00,\n        /*.ne01 =*/ ne01,\n        /*.ne02 =*/ ne02,\n        /*.ne03 =*/ ne03,\n        /*.nb00 =*/ nb00,\n        /*.nb01 =*/ nb01,\n        /*.nb02 =*/ nb02,\n        /*.nb03 =*/ nb03,\n        /*.ne0  =*/ ne0,\n        /*.ne1  =*/ ne1,\n        /*.ne2  =*/ ne2,\n        /*.ne3  =*/ ne3,\n        /*.nb0  =*/ nb0,\n        /*.nb1  =*/ nb1,\n        /*.nb2  =*/ nb2,\n        /*.nb3  =*/ nb3,\n        /*.p0 =*/ ((const int32_t *)(op->op_params))[0],\n        /*.p1 =*/ ((const int32_t *)(op->op_params))[1]\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);\n\n    const int nth = std::min(1024, ne0);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, 
ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    float start;\n    float step;\n\n    memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float));\n    memcpy(&step,  ((const int32_t *) op->op_params) + 2, sizeof(float));\n\n    ggml_metal_kargs_arange args = {\n        /*.ne0   =*/ ne0,\n        /*.start =*/ start,\n        /*.step  =*/ step\n    };\n\n    const int nth = std::min(1024, ne0);\n\n    auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 1);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    const int dim        = op->op_params[0];\n    const int max_period = op->op_params[1];\n\n    ggml_metal_kargs_timestep_embedding args = {\n        /*.nb1 =*/ nb1,\n        /*.dim =*/ dim,\n        /*.max_period =*/ max_period,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);\n\n    const int nth = std::max(1, std::min(1024, dim/2));\n\n    
ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    ggml_metal_kargs_argmax args = {\n        /*.ne00 = */ ne00,\n        /*.nb01 = */ nb01,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);\n\n    const int64_t nrows = ggml_nrows(op->src[0]);\n\n    int nth = 32; // SIMD width\n    while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {\n        nth *= 2;\n    }\n\n    const size_t smem = pipeline.smem;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    
GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);\n\n    // bitonic sort requires the number of elements to be power of 2\n    int nth = 1;\n    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n        nth *= 2;\n    }\n\n    const int npr = (ne00 + nth - 1)/nth;\n\n    // Metal kernels require the buffer size to be multiple of 16 bytes\n    // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength\n    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    ggml_metal_buffer_id bid_tmp = bid_dst;\n    bid_tmp.offs += ggml_nbytes(op);\n\n    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {\n        std::swap(bid_dst, bid_tmp);\n    }\n\n    ggml_metal_kargs_argsort args = {\n        /*.ne00  =*/ ne00,\n        /*.ne01  =*/ ne01,\n        /*.ne02  =*/ ne02,\n        /*.ne03  =*/ ne03,\n        /*.nb00  =*/ nb00,\n        /*.nb01  =*/ nb01,\n        /*.nb02  =*/ nb02,\n        /*.nb03  =*/ nb03,\n        /*.ne0   =*/ ne0,\n        /*.ne1   =*/ ne1,\n        /*.ne2   =*/ ne2,\n        /*.ne3   =*/ ne3,\n        /*.top_k =*/ nth,\n    };\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);\n\n    auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);\n\n    int len = nth;\n\n    while (len < ne00) {\n        
ggml_metal_op_concurrency_reset(ctx);\n\n        ggml_metal_kargs_argsort_merge args_merge = {\n            /*.ne00  =*/ ne00,\n            /*.ne01  =*/ ne01,\n            /*.ne02  =*/ ne02,\n            /*.ne03  =*/ ne03,\n            /*.nb00  =*/ nb00,\n            /*.nb01  =*/ nb01,\n            /*.nb02  =*/ nb02,\n            /*.nb03  =*/ nb03,\n            /*.ne0   =*/ ne0,\n            /*.ne1   =*/ ne1,\n            /*.ne2   =*/ ne2,\n            /*.ne3   =*/ ne3,\n            /*.top_k =*/ ne00,\n            /*.len   =*/ len,\n        };\n\n        // merges per row\n        const int nm = (ne00 + 2*len - 1) / (2*len);\n\n        const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);\n        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);\n        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);\n\n        std::swap(bid_dst, bid_tmp);\n\n        len <<= 1;\n    }\n\n    return 1;\n}\n\nint ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);\n\n    // bitonic sort requires the number of elements to be power of 2\n    int nth = 1;\n    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n        nth *= 2;\n    }\n\n 
   // blocks per row\n    const int npr = (ne00 + nth - 1)/nth;\n\n    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);\n\n    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);\n    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);\n\n    ggml_metal_buffer_id bid_tmp = bid_dst;\n    bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);\n\n    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {\n        std::swap(bid_dst, bid_tmp);\n    }\n\n    const int top_k = ne0;\n\n    ggml_metal_kargs_argsort args = {\n        /*.ne00  =*/ ne00,\n        /*.ne01  =*/ ne01,\n        /*.ne02  =*/ ne02,\n        /*.ne03  =*/ ne03,\n        /*.nb00  =*/ nb00,\n        /*.nb01  =*/ nb01,\n        /*.nb02  =*/ nb02,\n        /*.nb03  =*/ nb03,\n        /*.ne0   =*/ ne0,\n        /*.ne1   =*/ ne1,\n        /*.ne2   =*/ ne2,\n        /*.ne3   =*/ ne3,\n        /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices\n    };\n\n    if (npr > 1) {\n        args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);\n    }\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n\n    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);\n\n    auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);\n\n    int len = args.top_k;\n\n    while (len < args.ne0) {\n        ggml_metal_op_concurrency_reset(ctx);\n\n        // merges per row\n        const int nm = (args.ne0 + 2*len - 1) / (2*len);\n\n        const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));\n\n        ggml_metal_kargs_argsort_merge args_merge = {\n            /*.ne00  =*/ ne00,\n            
/*.ne01  =*/ ne01,\n            /*.ne02  =*/ ne02,\n            /*.ne03  =*/ ne03,\n            /*.nb00  =*/ nb00,\n            /*.nb01  =*/ nb01,\n            /*.nb02  =*/ nb02,\n            /*.nb03  =*/ nb03,\n            /*.ne0   =*/ args.ne0,\n            /*.ne1   =*/ ne1,\n            /*.ne2   =*/ ne2,\n            /*.ne3   =*/ ne3,\n            /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements\n            /*.len   =*/ len,\n        };\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);\n        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);\n        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);\n        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);\n        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);\n\n        std::swap(bid_dst, bid_tmp);\n\n        len <<= 1;\n    }\n\n    return 1;\n}\n\nint ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    ggml_metal_kargs_tri args = {\n        /*.ne00  =*/ ne00,\n        /*.ne01  =*/ ne01,\n        /*.ne02  =*/ ne02,\n        /*.ne03  =*/ ne03,\n        /*.nb00  =*/ nb00,\n        /*.nb01  =*/ nb01,\n        /*.nb02  =*/ nb02,\n        /*.nb03  =*/ nb03,\n        /*.ne0   =*/ ne0,\n        /*.ne1   =*/ ne1,\n        /*.ne2   =*/ ne2,\n        /*.ne3   =*/ ne3,\n        /*.nb0   =*/ nb0,\n        /*.nb1   =*/ nb1,\n        /*.nb2   =*/ nb2,\n        /*.nb3   =*/ nb3,\n    };\n\n    auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);\n\n    int nth = 32; // SIMD width\n\n    while 
(nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {\n        nth *= 2;\n    }\n\n    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n    nth = std::min(nth, ne00);\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);\n\n    const int64_t np = ggml_nelements(op->src[0]);\n    ggml_metal_kargs_opt_step_adamw args = {\n        /*.np =*/ np,\n    };\n\n    int ida = 0;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);\n\n    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);\n    const int64_t n = (np + nth - 1) / nth;\n\n    
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);\n\n    auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);\n\n    const int64_t np = ggml_nelements(op->src[0]);\n    ggml_metal_kargs_opt_step_sgd args = {\n        /*.np =*/ np,\n    };\n\n    int ida = 0;\n\n    ggml_metal_encoder_set_pipeline(enc, pipeline);\n    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);\n    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);\n\n    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);\n    const int64_t n = (np + nth - 1) / nth;\n\n    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);\n\n    return 1;\n}\n\nint ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {\n    ggml_tensor * op = ctx->node(idx);\n\n    ggml_metal_library_t lib = ctx->lib;\n    ggml_metal_encoder_t enc = ctx->enc;\n\n    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);\n    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);\n    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);\n\n    {\n        ggml_metal_kargs_memset args = { /*.val =*/ 0 };\n\n        auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);\n        
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);\n\n        ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);\n    }\n\n    ggml_metal_op_concurrency_reset(ctx);\n\n    {\n        ggml_metal_kargs_count_equal args = {\n            /*.ne00 =*/ ne00,\n            /*.ne01 =*/ ne01,\n            /*.ne02 =*/ ne02,\n            /*.ne03 =*/ ne03,\n            /*.nb00 =*/ nb00,\n            /*.nb01 =*/ nb01,\n            /*.nb02 =*/ nb02,\n            /*.nb03 =*/ nb03,\n            /*.nb10 =*/ nb10,\n            /*.nb11 =*/ nb11,\n            /*.nb12 =*/ nb12,\n            /*.nb13 =*/ nb13,\n        };\n\n        auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);\n\n        const size_t smem = pipeline.smem;\n\n        const int nth = 32*pipeline.nsg;\n\n        GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));\n\n        ggml_metal_encoder_set_pipeline(enc, pipeline);\n        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);\n        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);\n\n        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);\n        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);\n    }\n\n    return 1;\n}\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal-ops.h",
    "content": "#pragma once\n\n#include \"ggml-metal-device.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\ntypedef struct ggml_metal_op * ggml_metal_op_t;\n\nggml_metal_op_t ggml_metal_op_init(\n        ggml_metal_device_t dev,\n        ggml_metal_cmd_buf_t cmd_buf,\n        struct ggml_cgraph * gf,\n        int  idx_start,\n        int  idx_end,\n        bool use_fusion,\n        bool use_concurrency,\n        bool use_capture,\n        int  debug_graph,\n        int  debug_fusion);\n\nvoid ggml_metal_op_free(ggml_metal_op_t ctx);\n\nint ggml_metal_op_n_nodes(ggml_metal_op_t ctx);\n\nint ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);\n\n//\n// available ops:\n//\n\n// tokens per expert\nsize_t ggml_metal_op_mul_mat_id_extra_tpe(const struct ggml_tensor * op);\n\n// id map [n_tokens, n_expert]\nsize_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);\n\n// return true if we should use the FA vector kernel for this op\nbool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);\n\nsize_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);\nsize_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);\nsize_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);\n\nint ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_repeat            (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_acc               (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_glu               (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_sum               (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_sum_rows          (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_cumsum            (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_diag           
   (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_gated_delta_net   (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_set               (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_pool_1d           (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_pool_2d           (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_mul_mat           (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_mul_mat_id        (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_add_id            (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_flash_attn_ext    (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_bin               (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_l2_norm           (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_group_norm        (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_norm              (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_rope              (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_im2col            (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_conv_2d           (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_upscale           (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_pad               (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_pad_reflect_1d    (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_arange            (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_argmax        
    (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_top_k             (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_tri               (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);\nint ggml_metal_op_count_equal       (ggml_metal_op_t ctx, int idx);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal.cpp",
    "content": "#include \"ggml-metal.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-metal-device.h\"\n#include \"ggml-metal-context.h\"\n#include \"ggml-metal-ops.h\"\n\n#include <mutex>\n#include <string>\n\n#define GGML_METAL_NAME \"MTL\"\n#define GGML_METAL_MAX_DEVICES 16\n\n// number of Metal devices\n// note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices\nstatic int g_devices = 1;\n\n////////////////////////////////////////////////////////////////////////////////\n// backend interface\n////////////////////////////////////////////////////////////////////////////////\n\n// shared buffer\n\nstatic void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_free(ctx);\n}\n\nstatic void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));\n\n    return ggml_metal_buffer_get_base(ctx);\n}\n\nstatic void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);\n}\n\nstatic void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);\n}\n\nstatic void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const 
ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);\n}\n\nstatic bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));\n\n    GGML_UNUSED(buffer);\n    GGML_UNUSED(src);\n    GGML_UNUSED(dst);\n\n    return false;\n}\n\nstatic void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_clear(ctx, value);\n}\n\nstatic ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {\n    /* .free_buffer     = */ ggml_backend_metal_buffer_shared_free_buffer,\n    /* .get_base        = */ ggml_backend_metal_buffer_shared_get_base,\n    /* .init_tensor     = */ NULL,\n    /* .memset_tensor   = */ ggml_backend_metal_buffer_shared_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_metal_buffer_shared_set_tensor,\n    /* .get_tensor      = */ ggml_backend_metal_buffer_shared_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_metal_buffer_shared_cpy_tensor,\n    /* .clear           = */ ggml_backend_metal_buffer_shared_clear,\n    /* .reset           = */ NULL,\n};\n\n// private buffer\n\nstatic void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_free(ctx);\n}\n\nstatic void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n  
  GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));\n\n    return ggml_metal_buffer_get_base(ctx);\n}\n\nstatic void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);\n}\n\nstatic void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);\n}\n\nstatic void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);\n}\n\nstatic bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));\n\n    GGML_UNUSED(buffer);\n    GGML_UNUSED(src);\n    GGML_UNUSED(dst);\n\n    return false;\n}\n\nstatic void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;\n\n    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));\n\n    ggml_metal_buffer_clear(ctx, value);\n}\n\nstatic ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {\n    /* .free_buffer     = */ ggml_backend_metal_buffer_private_free_buffer,\n    /* .get_base        = */ 
ggml_backend_metal_buffer_private_get_base,\n    /* .init_tensor     = */ NULL,\n    /* .memset_tensor   = */ ggml_backend_metal_buffer_private_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_metal_buffer_private_set_tensor,\n    /* .get_tensor      = */ ggml_backend_metal_buffer_private_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_metal_buffer_private_cpy_tensor,\n    /* .clear           = */ ggml_backend_metal_buffer_private_clear,\n    /* .reset           = */ NULL,\n};\n\nstatic bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {\n    return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer ||\n           buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer;\n}\n\n//\n// buffer types\n//\n\nstruct ggml_backend_metal_buffer_type {\n    int device;\n    std::string name;\n};\n\nstruct ggml_backend_metal_buffer_type_deleter {\n    void operator()(ggml_backend_metal_buffer_type * ctx) const {\n        delete ctx;\n    }\n};\n\ntypedef std::unique_ptr<ggml_backend_metal_buffer_type, ggml_backend_metal_buffer_type_deleter> ggml_backend_metal_buffer_type_ptr;\n\n// common method for allocating shared or private Metal buffers\nstatic ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;\n    ggml_metal_buffer_t res = ggml_metal_buffer_init(ctx_dev, size, shared);\n\n    ggml_backend_buffer_i buf_i = ggml_metal_buffer_is_shared(res)\n        ? 
ggml_backend_metal_buffer_shared_i\n        : ggml_backend_metal_buffer_private_i;\n\n    return ggml_backend_buffer_init(buft, buf_i, res, size);\n}\n\nstatic size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    size_t res = ggml_nbytes(tensor);\n\n    // some operations require additional memory for fleeting data:\n    switch (tensor->op) {\n        case GGML_OP_MUL_MAT_ID:\n            {\n                res += ggml_metal_op_mul_mat_id_extra_tpe(tensor);\n                res += ggml_metal_op_mul_mat_id_extra_ids(tensor);\n            } break;\n        case GGML_OP_FLASH_ATTN_EXT:\n            {\n                res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);\n                res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);\n                res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);\n            } break;\n        case GGML_OP_CUMSUM:\n        case GGML_OP_ARGSORT:\n            {\n                res *= 2;\n            } break;\n        case GGML_OP_TOP_K:\n            {\n                res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);\n            } break;\n        default:\n            break;\n    }\n\n    return res;\n\n    GGML_UNUSED(buft);\n}\n\n// default (shared) buffer type\n\nstatic const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;\n\n    return ctx->name.c_str();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);\n}\n\nstatic size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 32;\n\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) {\n 
   ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;\n\n    return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;\n}\n\nstatic size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);\n}\n\nstatic bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {\n    return false;\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    static std::vector<ggml_backend_buffer_type> bufts;\n    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;\n\n    static bool initialized = false;\n    if (!initialized) {\n        bufts.reserve(g_devices);\n        ctxs.reserve(g_devices);\n\n        for (int i = 0; i < g_devices; ++i) {\n            ggml_backend_metal_buffer_type * raw_ctx =\n                new ggml_backend_metal_buffer_type {\n                    /* .device = */ i,\n                    /* .name   = */ GGML_METAL_NAME + std::to_string(i),\n                };\n            ctxs.emplace_back(raw_ctx);\n\n            ggml_backend_buffer_type buft = {\n                /* .iface = */ {\n                    /* .get_name         = */ ggml_backend_metal_buffer_type_shared_get_name,\n                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,\n                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_shared_get_alignment,\n                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_shared_get_max_size,\n                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,\n                    /* .is_host          = */ ggml_backend_metal_buffer_type_shared_is_host,\n                },\n                /* .device  = */ 
ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),\n                /* .context = */ raw_ctx,\n            };\n\n            bufts.emplace_back(buft);\n        }\n\n        initialized = true;\n    }\n\n    return &bufts[device];\n}\n\n// default (private) buffer type\n\nstatic const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;\n\n    return ctx->name.c_str();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false);\n}\n\nstatic size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 32;\n\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;\n\n    return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;\n}\n\nstatic size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);\n}\n\nstatic bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {\n    return false;\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    static std::vector<ggml_backend_buffer_type> bufts;\n    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;\n\n    static bool initialized = false;\n    if (!initialized) {\n        bufts.reserve(g_devices);\n        ctxs.reserve(g_devices);\n\n        for (int i = 0; i < g_devices; ++i) {\n            ggml_backend_metal_buffer_type * 
raw_ctx = new ggml_backend_metal_buffer_type{\n                /* .device = */ i,\n                /* .name   = */ GGML_METAL_NAME + std::to_string(i) + \"_Private\"\n            };\n            ctxs.emplace_back(raw_ctx);\n\n            ggml_backend_buffer_type buft = {\n                /* .iface = */ {\n                    /* .get_name         = */ ggml_backend_metal_buffer_type_private_get_name,\n                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_private_alloc_buffer,\n                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_private_get_alignment,\n                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_private_get_max_size,\n                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_private_get_alloc_size,\n                    /* .is_host          = */ ggml_backend_metal_buffer_type_private_is_host,\n                },\n                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),\n                /* .context = */ raw_ctx,\n            };\n\n            bufts.emplace_back(buft);\n        }\n\n        initialized = true;\n    }\n\n    return &bufts[device];\n}\n\n// mapped buffer type\n\nstatic const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;\n\n    return ctx->name.c_str();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    // for mapped buffers, prefer shared memory\n    return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);\n}\n\nstatic size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 32;\n\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) {\n    ggml_metal_device_t ctx_dev = 
(ggml_metal_device_t)buft->device->context;\n\n    return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;\n}\n\nstatic size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);\n}\n\nstatic bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {\n    return false;\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    static std::vector<ggml_backend_buffer_type> bufts;\n    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;\n\n    static bool initialized = false;\n    if (!initialized) {\n        bufts.reserve(g_devices);\n        ctxs.reserve(g_devices);\n\n        for (int i = 0; i < g_devices; ++i) {\n            ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{\n                /* .device = */ i,\n                /* .name   = */ GGML_METAL_NAME + std::to_string(i) + \"_Mapped\"\n            };\n            ctxs.emplace_back(raw_ctx);\n\n            // note: not obvious, but this buffer type still needs to implement .alloc_buffer:\n            //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099\n            ggml_backend_buffer_type buft = {\n                /* .iface = */ {\n                    /* .get_name         = */ ggml_backend_metal_buffer_type_mapped_get_name,\n                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,\n                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_mapped_get_alignment,\n                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_mapped_get_max_size,\n                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,\n                    
/* .is_host          = */ ggml_backend_metal_buffer_type_mapped_is_host,\n                },\n                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),\n                /* .context = */ raw_ctx,\n            };\n\n            bufts.emplace_back(buft);\n        }\n\n        initialized = true;\n    }\n\n    return &bufts[device];\n}\n\n// backend\n\nstatic const char * ggml_backend_metal_name(ggml_backend_t backend) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    return ggml_metal_get_name(ctx);\n}\n\nstatic void ggml_backend_metal_free(ggml_backend_t backend) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    // wait for any ongoing async operations to finish\n    ggml_metal_synchronize(ctx);\n\n    ggml_metal_free(ctx);\n\n    free(backend);\n}\n\nstatic void ggml_backend_metal_synchronize(ggml_backend_t backend) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    ggml_metal_synchronize(ctx);\n}\n\nstatic void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    ggml_metal_set_tensor_async(ctx, tensor, data, offset, size);\n}\n\nstatic void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    ggml_metal_get_tensor_async(ctx, tensor, data, offset, size);\n}\n\nstatic bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {\n    if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) {\n        return false;\n    }\n\n    if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) {\n        return false;\n    }\n\n    ggml_metal_t ctx_src = 
(ggml_metal_t)backend_src->context;\n    ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context;\n\n    //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;\n    //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;\n\n    //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context;\n    //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context;\n\n    return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst);\n}\n\nstatic enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    return ggml_metal_graph_compute(ctx, cgraph);\n}\n\nstatic void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;\n\n    ggml_metal_event_record(ctx, ev);\n}\n\nstatic void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;\n\n    ggml_metal_event_wait(ctx, ev);\n}\n\nstatic void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    ggml_metal_graph_optimize(ctx, cgraph);\n}\n\nstatic void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {\n    GGML_ASSERT(ggml_backend_is_metal(backend));\n\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    ggml_metal_set_n_cb(ctx, n_cb);\n}\n\nstatic ggml_backend_i ggml_backend_metal_i = {\n    /* .get_name                = */ ggml_backend_metal_name,\n    /* .free                    = */ ggml_backend_metal_free,\n    /* .set_tensor_async        = */ ggml_backend_metal_set_tensor_async,\n    /* .get_tensor_async        = */ 
ggml_backend_metal_get_tensor_async,\n    /* .cpy_tensor_async        = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups\n    /* .synchronize             = */ ggml_backend_metal_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_metal_graph_compute,\n    /* .event_record            = */ ggml_backend_metal_event_record,\n    /* .event_wait              = */ ggml_backend_metal_event_wait,\n    /* .graph_optimize          = */ ggml_backend_metal_graph_optimize,\n};\n\nstatic ggml_guid_t ggml_backend_metal_guid(void) {\n    static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };\n    return &guid;\n}\n\nggml_backend_t ggml_backend_metal_init(void) {\n    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    ggml_metal_t ctx = ggml_metal_init(ctx_dev);\n    if (ctx == NULL) {\n        GGML_LOG_ERROR(\"%s: error: failed to allocate context\\n\", __func__);\n        return NULL;\n    }\n\n    ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));\n\n    *backend = {\n        /* .guid      = */ ggml_backend_metal_guid(),\n        /* .interface = */ ggml_backend_metal_i,\n        /* .device    = */ dev,\n        /* .context   = */ ctx,\n    };\n\n    ggml_backend_metal_set_n_cb(backend, 1);\n\n    return backend;\n}\n\nbool ggml_backend_is_metal(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());\n}\n\nvoid ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {\n    GGML_ASSERT(ggml_backend_is_metal(backend));\n\n    ggml_metal_t ctx = 
(ggml_metal_t)backend->context;\n\n    ggml_metal_set_abort_callback(ctx, abort_callback, user_data);\n}\n\nbool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {\n    GGML_ASSERT(ggml_backend_is_metal(backend));\n\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    return ggml_metal_supports_family(ctx, family);\n}\n\nvoid ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {\n    GGML_ASSERT(ggml_backend_is_metal(backend));\n\n    ggml_metal_t ctx = (ggml_metal_t)backend->context;\n\n    ggml_metal_capture_next_compute(ctx);\n}\n\n// backend device\n\nstatic const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);\n\n    return props_dev->name;\n}\n\nstatic const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    return ggml_metal_device_get_props(ctx_dev)->desc;\n}\n\nstatic void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    ggml_metal_device_get_memory(ctx_dev, free, total);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {\n    return GGML_BACKEND_DEVICE_TYPE_GPU;\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_metal_device_get_name(dev);\n    props->description = ggml_backend_metal_device_get_description(dev);\n    props->type        = ggml_backend_metal_device_get_type(dev);\n\n    ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);\n\n    props->caps = {\n        /* .async                = */ true,\n        
/* .host_buffer          = */ false,\n        /* .buffer_from_host_ptr = */ true,\n        /* .events               = */ true,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    ggml_metal_t ctx = ggml_metal_init(ctx_dev);\n    if (ctx == NULL) {\n        GGML_LOG_ERROR(\"%s: error: failed to allocate context\\n\", __func__);\n        return NULL;\n    }\n\n    ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));\n\n    *backend = {\n        /* .guid      = */ ggml_backend_metal_guid(),\n        /* .interface = */ ggml_backend_metal_i,\n        /* .device    = */ dev,\n        /* .context   = */ ctx,\n    };\n\n    ggml_backend_metal_set_n_cb(backend, 1);\n\n    return backend;\n\n    GGML_UNUSED(params);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);\n\n    return props_dev->use_shared_buffers ? 
ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size);\n\n    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);\n\n    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size);\n}\n\nstatic bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    return ggml_metal_device_supports_op(ctx_dev, op);\n}\n\nstatic bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    return\n        buft->device == dev && (\n        buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name ||\n        buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name ||\n        buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name);\n\n    GGML_UNUSED(dev);\n}\n\nstatic int64_t get_op_batch_size(const ggml_tensor * op) {\n    switch (op->op) {\n        case GGML_OP_MUL_MAT:\n            return op->ne[1];\n        case GGML_OP_MUL_MAT_ID:\n            return op->ne[2];\n        default:\n            return ggml_nrows(op);\n    }\n}\n\nstatic bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    return (op->op == GGML_OP_MUL_MAT ||\n            op->op == GGML_OP_MUL_MAT_ID) &&\n            get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size;\n}\n\nstatic 
ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev);\n    GGML_ASSERT(event);\n\n    ggml_backend_event_t ev = new ggml_backend_event {\n        /* .device  = */ dev,\n        /* .context = */ event,\n    };\n\n    return ev;\n}\n\nstatic void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;\n\n    ggml_metal_device_event_free(ctx_dev, ev);\n\n    delete event;\n}\n\nstatic void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {\n    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;\n\n    ggml_metal_event_t evt = (ggml_metal_event_t)event->context;\n\n    ggml_metal_device_event_synchronize(ctx_dev, evt);\n}\n\nstatic ggml_backend_device_i ggml_backend_metal_device_i = {\n    /* .get_name             = */ ggml_backend_metal_device_get_name,\n    /* .get_description      = */ ggml_backend_metal_device_get_description,\n    /* .get_memory           = */ ggml_backend_metal_device_get_memory,\n    /* .get_type             = */ ggml_backend_metal_device_get_type,\n    /* .get_props            = */ ggml_backend_metal_device_get_props,\n    /* .init_backend         = */ ggml_backend_metal_device_init_backend,\n    /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,\n    /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped,\n    /* .supports_op          = */ ggml_backend_metal_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,\n    /* .offload_op           = */ ggml_backend_metal_device_offload_op,\n    /* .event_new            = */ 
ggml_backend_metal_device_event_new,\n    /* .event_free           = */ ggml_backend_metal_device_event_free,\n    /* .event_synchronize    = */ ggml_backend_metal_device_event_synchronize,\n};\n\n// backend registry\n\nstruct ggml_backend_metal_reg {\n    std::vector<ggml_backend_dev_t> devices;\n};\n\ntypedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t;\n\nstatic ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) {\n    ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg;\n\n    return ctx;\n}\n\nstatic void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) {\n    delete ctx;\n}\n\nstruct ggml_backend_metal_reg_deleter {\n    void operator()(ggml_backend_metal_reg_t ctx) {\n        ggml_backend_metal_reg_free(ctx);\n    }\n};\n\ntypedef std::unique_ptr<struct ggml_backend_metal_reg, ggml_backend_metal_reg_deleter> ggml_backend_metal_reg_ptr;\n\nstatic const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {\n    return GGML_METAL_NAME;\n\n    GGML_UNUSED(reg);\n}\n\nstatic size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {\n    ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;\n    return ctx->devices.size();\n}\n\nstatic ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {\n    ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;\n    GGML_ASSERT(index < ctx->devices.size());\n    return ctx->devices[index];\n}\n\nstatic ggml_backend_feature g_ggml_backend_metal_features[] = {\n#if defined(GGML_METAL_EMBED_LIBRARY)\n    { \"EMBED_LIBRARY\", \"1\" },\n#endif\n    { NULL, NULL },\n};\n\nstatic ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) {\n    return g_ggml_backend_metal_features;\n\n    GGML_UNUSED(reg);\n}\n\nstatic void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    if (strcmp(name, \"ggml_backend_get_features\") == 0) {\n        return 
(void *)ggml_backend_metal_get_features;\n    }\n\n    return NULL;\n\n    GGML_UNUSED(reg);\n}\n\nstatic ggml_backend_reg_i ggml_backend_metal_reg_i = {\n    /* .get_name         = */ ggml_backend_metal_reg_get_name,\n    /* .get_device_count = */ ggml_backend_metal_reg_device_count,\n    /* .get_device       = */ ggml_backend_metal_reg_device_get,\n    /* .get_proc_address = */ ggml_backend_metal_get_proc_address,\n};\n\nstatic ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) {\n    return new ggml_backend_device {\n        /* .iface   = */ ggml_backend_metal_device_i,\n        /* .reg     = */ reg,\n        /* .context = */ ggml_metal_device_get(device),\n    };\n}\n\nstatic void ggml_backend_metal_device_free(ggml_backend_dev_t dev) {\n    delete dev;\n}\n\nstruct ggml_backend_device_deleter {\n    void operator()(ggml_backend_dev_t ctx) {\n        ggml_backend_metal_device_free(ctx);\n    }\n};\n\ntypedef std::unique_ptr<ggml_backend_device, ggml_backend_device_deleter> ggml_backend_device_ptr;\n\nggml_backend_reg_t ggml_backend_metal_reg(void) {\n    static ggml_backend_reg reg;\n    static bool initialized = false;\n\n    {\n        static std::mutex mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n\n        const char * env = getenv(\"GGML_METAL_DEVICES\");\n        if (env) {\n            g_devices = atoi(env);\n        }\n\n        static std::vector<ggml_backend_device_ptr> devs;\n\n        if (!initialized) {\n            static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init());\n\n            for (int i = 0; i < g_devices; ++i) {\n                auto * dev = ggml_backend_metal_device_init(&reg, i);\n                devs.emplace_back(dev);\n\n                reg_ctx->devices.push_back(dev);\n            }\n\n            reg = {\n                /* .api_version = */ GGML_BACKEND_API_VERSION,\n                /* .iface       = */ ggml_backend_metal_reg_i,\n                /* .context     = */ 
reg_ctx.get(),\n            };\n        }\n\n        initialized = true;\n    }\n\n    return &reg;\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)\n"
  },
  {
    "path": "src/ggml-metal/ggml-metal.metal",
    "content": "#define GGML_COMMON_DECL_METAL\n#define GGML_COMMON_IMPL_METAL\n#if defined(GGML_METAL_EMBED_LIBRARY)\n__embed_ggml-common.h__\n#else\n#include \"ggml-common.h\"\n#endif\n#include \"ggml-metal-impl.h\"\n\n#include <metal_stdlib>\n\n#ifdef GGML_METAL_HAS_TENSOR\n#include <metal_tensor>\n\n#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>\n#endif\n\nusing namespace metal;\n\n#define MAX(x, y) ((x) > (y) ? (x) : (y))\n#define MIN(x, y) ((x) < (y) ? (x) : (y))\n#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }\n\n#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))\n\n#define FOR_UNROLL(x) _Pragma(\"clang loop unroll(full)\") for (x)\n\n#define N_SIMDWIDTH 32 // assuming SIMD group size is 32\n\n// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf\n//\n// cmd:\n//   .../usr/bin/metal -dM -E -c                             ggml/src/ggml-metal/ggml-metal.metal\n//   .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal\n//\n#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)\n#undef GGML_METAL_HAS_BF16\n#endif\n\n#if defined(GGML_METAL_HAS_BF16)\ntypedef matrix<bfloat, 4, 4> bfloat4x4;\ntypedef matrix<bfloat, 2, 4> bfloat2x4;\n#endif\n\nconstexpr constant static float kvalues_iq4nl_f[16] = {\n    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f\n};\n\nconstexpr constant static float kvalues_mxfp4_f[16] = {\n    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f\n};\n\nstatic inline int best_index_int8(int n, constant float * val, float x) {\n    if (x <= val[0]) return 0;\n    if (x >= val[n-1]) return n-1;\n    int ml = 0, mu = n-1;\n    while (mu-ml > 1) {\n        int mav = (ml+mu)/2;\n        if (x < val[mav]) mu = mav; else ml = mav;\n    }\n    return x - val[mu-1] < val[mu] - x ? 
mu-1 : mu;\n}\n\nstatic inline float e8m0_to_fp32(uint8_t x) {\n    uint32_t bits;\n\n    if (x == 0) {\n        bits = 0x00400000;\n    } else {\n        bits = (uint32_t) x << 23;\n    }\n\n    return as_type<float>(bits);\n}\n\nstatic inline float dot(float x, float y) {\n    return x*y;\n}\n\nstatic inline float sum(float x) {\n    return x;\n}\n\nstatic inline float sum(float4 x) {\n    return x[0] + x[1] + x[2] + x[3];\n}\n\n// NOTE: this is not dequantizing - we are simply fitting the template\ntemplate <typename type4x4>\nvoid dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {\n    reg = (type4x4)(*src);\n}\n\ntemplate <typename type4>\nvoid dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {\n    reg = (type4)(*src);\n}\n\ntemplate <typename type4x4>\nvoid dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {\n    reg = (type4x4)(*src);\n}\n\ntemplate <typename type4>\nvoid dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {\n    reg = (type4)(*(src));\n}\n\n#if defined(GGML_METAL_HAS_BF16)\ntemplate <typename type4x4>\nvoid dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {\n    reg = (type4x4)(*src);\n}\n\ntemplate <typename type4>\nvoid dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {\n    reg = (type4)(*(src));\n}\n#endif\n\ntemplate <typename type4x4>\nvoid dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {\n    device const uint16_t * qs = ((device const uint16_t *)xb + 1);\n    const float d1 = il ? (xb->d / 16.h) : xb->d;\n    const float d2 = d1 / 256.f;\n    const float md = -8.h * xb->d;\n    const ushort mask0 = il ? 
0x00F0 : 0x000F;\n    const ushort mask1 = mask0 << 8;\n\n    float4x4 reg_f;\n\n    for (int i = 0; i < 8; i++) {\n        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;\n        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;\n    }\n\n    reg = (type4x4) reg_f;\n}\n\ntemplate <typename type4>\nvoid dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {\n    device const uint16_t * qs = ((device const uint16_t *)xb + 1);\n    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;\n    const float d2 = d1 / 256.f;\n    const float md = -8.h * xb->d;\n    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;\n    const ushort mask1 = mask0 << 8;\n\n    for (int i = 0; i < 2; i++) {\n        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;\n        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;\n    }\n}\n\nvoid quantize_q4_0(device const float * src, device block_q4_0 & dst) {\n#pragma METAL fp math_mode(safe)\n    float amax = 0.0f; // absolute max\n    float max  = 0.0f;\n\n    for (int j = 0; j < QK4_0; j++) {\n        const float v = src[j];\n        if (amax < fabs(v)) {\n            amax = fabs(v);\n            max  = v;\n        }\n    }\n\n    const float d = max / -8;\n    const float id = d ? 
1.0f/d : 0.0f;\n\n    dst.d = d;\n\n    for (int j = 0; j < QK4_0/2; ++j) {\n        const float x0 = src[0       + j]*id;\n        const float x1 = src[QK4_0/2 + j]*id;\n\n        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));\n        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));\n\n        dst.qs[j]  = xi0;\n        dst.qs[j] |= xi1 << 4;\n    }\n}\n\nvoid quantize_q4_1(device const float * src, device block_q4_1 & dst) {\n#pragma METAL fp math_mode(safe)\n    float min = FLT_MAX;\n    float max = -FLT_MAX;\n\n    for (int j = 0; j < QK4_1; j++) {\n        const float v = src[j];\n        if (min > v) min = v;\n        if (max < v) max = v;\n    }\n\n    const float d = (max - min) / ((1 << 4) - 1);\n    const float id = d ? 1.0f/d : 0.0f;\n\n    dst.d = d;\n    dst.m = min;\n\n    for (int j = 0; j < QK4_1/2; ++j) {\n        const float x0 = (src[0       + j] - min)*id;\n        const float x1 = (src[QK4_1/2 + j] - min)*id;\n\n        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));\n        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));\n\n        dst.qs[j]  = xi0;\n        dst.qs[j] |= xi1 << 4;\n    }\n}\n\nvoid quantize_q5_0(device const float * src, device block_q5_0 & dst) {\n#pragma METAL fp math_mode(safe)\n    float amax = 0.0f; // absolute max\n    float max  = 0.0f;\n\n    for (int j = 0; j < QK5_0; j++) {\n        const float v = src[j];\n        if (amax < fabs(v)) {\n            amax = fabs(v);\n            max  = v;\n        }\n    }\n\n    const float d = max / -16;\n    const float id = d ? 
1.0f/d : 0.0f;\n\n    dst.d = d;\n\n    uint32_t qh = 0;\n    for (int j = 0; j < QK5_0/2; ++j) {\n        const float x0 = src[0       + j]*id;\n        const float x1 = src[QK5_0/2 + j]*id;\n\n        const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));\n        const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));\n\n        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);\n        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);\n    }\n\n    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;\n\n    for (int j = 0; j < 4; ++j) {\n        dst.qh[j] = qh8[j];\n    }\n}\n\nvoid quantize_q5_1(device const float * src, device block_q5_1 & dst) {\n#pragma METAL fp math_mode(safe)\n    float max = src[0];\n    float min = src[0];\n\n    for (int j = 1; j < QK5_1; j++) {\n        const float v = src[j];\n        min = v < min ? v : min;\n        max = v > max ? v : max;\n    }\n\n    const float d = (max - min) / 31;\n    const float id = d ? 1.0f/d : 0.0f;\n\n    dst.d = d;\n    dst.m = min;\n\n    uint32_t qh = 0;\n    for (int j = 0; j < QK5_1/2; ++j) {\n        const float x0 = (src[0       + j] - min)*id;\n        const float x1 = (src[QK5_1/2 + j] - min)*id;\n\n        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);\n        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);\n\n        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);\n        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);\n    }\n\n    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;\n\n    for (int j = 0; j < 4; ++j) {\n        dst.qh[j] = qh8[j];\n    }\n}\n\nvoid quantize_q8_0(device const float * src, device block_q8_0 & dst) {\n#pragma METAL fp math_mode(safe)\n    float amax = 0.0f; // absolute max\n\n    for (int j = 0; j < QK8_0; j++) {\n        const float v = src[j];\n        amax = MAX(amax, fabs(v));\n    }\n\n    const float d = amax / ((1 << 7) - 1);\n    const float id = d ? 
1.0f/d : 0.0f;\n\n    dst.d = d;\n\n    for (int j = 0; j < QK8_0; ++j) {\n        const float x0 = src[j]*id;\n\n        dst.qs[j] = round(x0);\n    }\n}\n\nvoid quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {\n#pragma METAL fp math_mode(safe)\n    float amax = 0.0f; // absolute max\n    float max  = 0.0f;\n\n    for (int j = 0; j < QK4_NL; j++) {\n        const float v = src[j];\n        if (amax < fabs(v)) {\n            amax = fabs(v);\n            max  = v;\n        }\n    }\n\n    const float d = max / kvalues_iq4nl_f[0];\n    const float id = d ? 1.0f/d : 0.0f;\n\n    float sumqx = 0, sumq2 = 0;\n    for (int j = 0; j < QK4_NL/2; ++j) {\n        const float x0 = src[0        + j]*id;\n        const float x1 = src[QK4_NL/2 + j]*id;\n\n        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);\n        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);\n\n        dst.qs[j] = xi0 | (xi1 << 4);\n\n        const float v0 = kvalues_iq4nl_f[xi0];\n        const float v1 = kvalues_iq4nl_f[xi1];\n        const float w0 = src[0        + j]*src[0        + j];\n        const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];\n        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];\n        sumq2 += w0*v0*v0 + w1*v1*v1;\n\n    }\n\n    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {\n    device const uint16_t * qs = ((device const uint16_t *)xb + 2);\n    const float d1 = il ? (xb->d / 16.h) : xb->d;\n    const float d2 = d1 / 256.f;\n    const float  m = xb->m;\n    const ushort mask0 = il ? 
0x00F0 : 0x000F;\n    const ushort mask1 = mask0 << 8;\n\n    float4x4 reg_f;\n\n    for (int i = 0; i < 8; i++) {\n        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;\n        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;\n    }\n\n    reg = (type4x4) reg_f;\n}\n\ntemplate <typename type4>\nvoid dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {\n    device const uint16_t * qs = ((device const uint16_t *)xb + 2);\n    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;\n    const float d2 = d1 / 256.f;\n    const float  m = xb->m;\n    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;\n    const ushort mask1 = mask0 << 8;\n\n    for (int i = 0; i < 2; i++) {\n        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;\n        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {\n    device const uint16_t * qs = ((device const uint16_t *)xb + 3);\n    const float d = xb->d;\n    const float md = -16.h * xb->d;\n    const ushort mask = il ? 0x00F0 : 0x000F;\n\n    const uint32_t qh = *((device const uint32_t *)xb->qh);\n\n    const int x_mv = il ? 4 : 0;\n\n    const int gh_mv = il ? 12 : 0;\n    const int gh_bk = il ?  
0 : 4;\n\n    float4x4 reg_f;\n\n    for (int i = 0; i < 8; i++) {\n        // extract the 5-th bits for x0 and x1\n        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;\n        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;\n\n        // combine the 4-bits from qs with the 5th bit\n        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);\n        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);\n\n        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;\n        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;\n    }\n\n    reg = (type4x4) reg_f;\n}\n\ntemplate <typename type4>\nvoid dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {\n    device const uint16_t * qs = ((device const uint16_t *)xb + 3);\n    const float d = xb->d;\n    const float md = -16.h * xb->d;\n    const ushort mask = (il/4) ? 0x00F0 : 0x000F;\n\n    const uint32_t qh = *((device const uint32_t *)xb->qh);\n\n    const int x_mv = (il/4) ? 4 : 0;\n\n    const int gh_mv = (il/4) ? 12 : 0;\n    const int gh_bk = (il/4) ?  0 : 4;\n\n    for (int ii = 0; ii < 2; ii++) {\n        int i = 2*(il%4) + ii;\n\n        // extract the 5-th bits for x0 and x1\n        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;\n        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;\n\n        // combine the 4-bits from qs with the 5th bit\n        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);\n        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);\n\n        reg[2*ii + 0] = d * x0 + md;\n        reg[2*ii + 1] = d * x1 + md;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {\n    device const uint16_t * qs = ((device const uint16_t *)xb + 4);\n    const float d = xb->d;\n    const float m = xb->m;\n    const ushort mask = il ? 
0x00F0 : 0x000F;\n\n    const uint32_t qh = *((device const uint32_t *)xb->qh);\n\n    const int x_mv = il ? 4 : 0;\n\n    const int gh_mv = il ? 12 : 0;\n    const int gh_bk = il ?  0 : 4;\n\n    float4x4 reg_f;\n\n    for (int i = 0; i < 8; i++) {\n        // extract the 5-th bits for x0 and x1\n        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;\n        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;\n\n        // combine the 4-bits from qs with the 5th bit\n        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);\n        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);\n\n        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;\n        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;\n    }\n\n    reg = (type4x4) reg_f;\n}\n\ntemplate <typename type4>\nvoid dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {\n    device const uint16_t * qs = ((device const uint16_t *)xb + 4);\n    const float d = xb->d;\n    const float m = xb->m;\n    const ushort mask = (il/4) ? 0x00F0 : 0x000F;\n\n    const uint32_t qh = *((device const uint32_t *)xb->qh);\n\n    const int x_mv = (il/4) ? 4 : 0;\n\n    const int gh_mv = (il/4) ? 12 : 0;\n    const int gh_bk = (il/4) ?  
0 : 4;\n\n    for (int ii = 0; ii < 2; ii++) {\n        int i = 2*(il%4) + ii;\n\n        // extract the 5-th bits for x0 and x1\n        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;\n        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;\n\n        // combine the 4-bits from qs with the 5th bit\n        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);\n        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);\n\n        reg[2*ii + 0] = d * x0 + m;\n        reg[2*ii + 1] = d * x1 + m;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {\n    device const int8_t * qs = ((device const int8_t *)xb->qs);\n    const float d = xb->d;\n\n    float4x4 reg_f;\n\n    for (int i = 0; i < 16; i++) {\n        reg_f[i/4][i%4] = (qs[i + 16*il] * d);\n    }\n\n    reg = (type4x4) reg_f;\n}\n\ntemplate <typename type4>\nvoid dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {\n    device const int8_t * qs = ((device const int8_t *)xb->qs);\n    const float d = xb->d;\n\n    for (int i = 0; i < 4; i++) {\n        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {\n    device const uint8_t * q2 = (device const uint8_t *)xb->qs;\n\n    const float d = e8m0_to_fp32(xb->e);\n    const uint8_t shr = il >= 1 ? 
4 : 0;\n\n    for (int i = 0; i < 4; ++i) {\n        reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];\n        reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];\n        reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];\n        reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];\n    }\n}\n\ntemplate <typename type4>\nvoid dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {\n    device const uint8_t * q2 = (device const uint8_t *)xb->qs;\n\n    const float d = e8m0_to_fp32(xb->e);\n    const short il4 = il%4;\n\n    const uint8_t shr = il >= 4 ? 4 : 0;\n\n    reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];\n    reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];\n    reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];\n    reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {\n    const float d = xb->d;\n    const float min = xb->dmin;\n    device const uint8_t * q = (device const uint8_t *)xb->qs;\n    float dl, ml;\n    uint8_t sc = xb->scales[il];\n\n    q = q + 32*(il/8) + 16*(il&1);\n    il = (il/2)%4;\n\n    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);\n    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 
12    : 3);\n    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);\n    for (int i = 0; i < 16; ++i) {\n        reg[i/4][i%4] = dl * (q[i] & mask) - ml;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {\n    const half d_all = xb->d;\n    device const uint8_t * q = (device const uint8_t *)xb->qs;\n    device const uint8_t * h = (device const uint8_t *)xb->hmask;\n    device const int8_t * scales = (device const int8_t *)xb->scales;\n\n    q = q + 32 * (il/8) + 16 * (il&1);\n    h = h + 16 * (il&1);\n    uint8_t m = 1 << (il/2);\n    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \\\n                                 ((il/4)>0 ? 12  : 3);\n    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;\n    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];\n    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)\n                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);\n    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);\n    const float ml = 4.f * dl;\n\n    il = (il/2) & 3;\n    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);\n    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);\n    dl *= coef;\n\n    for (int i = 0; i < 16; ++i) {\n        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);\n    }\n}\n\nstatic inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {\n    return j < 4 ? 
uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}\n                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {\n    device const uchar * q = xb->qs;\n\n    short is = (il/4) * 2;\n    q = q + (il/4) * 32 + 16 * (il&1);\n    il = il & 3;\n    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);\n    const float d   = il < 2 ? xb->d : xb->d / 16.h;\n    const float min = xb->dmin;\n    const float dl = d * sc[0];\n    const float ml = min * sc[1];\n\n    const ushort mask = il < 2 ? 0x0F : 0xF0;\n    for (int i = 0; i < 16; ++i) {\n        reg[i/4][i%4] = dl * (q[i] & mask) - ml;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {\n    device const uint8_t * q  = xb->qs;\n    device const uint8_t * qh = xb->qh;\n\n    short is = (il/4) * 2;\n    q  = q + 32 * (il/4) + 16 * (il&1);\n    qh = qh + 16 * (il&1);\n    uint8_t ul = 1 << (il/2);\n    il = il & 3;\n    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);\n    const float d = il < 2 ? xb->d : xb->d / 16.f;\n    const float min = xb->dmin;\n    const float dl = d * sc[0];\n    const float ml = min * sc[1];\n\n    const ushort mask  = il<2 ? 0x0F : 0xF0;\n    const float qh_val = il<2 ? 16.f : 256.f;\n    for (int i = 0; i < 16; ++i) {\n        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? 
qh_val : 0)) - ml;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {\n    const half d_all = xb->d;\n    device const uint16_t * ql = (device const uint16_t *)xb->ql;\n    device const uint16_t * qh = (device const uint16_t *)xb->qh;\n    device const int8_t * scales = (device const int8_t *)xb->scales;\n\n    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);\n    qh = qh + 16*(il/8) + 8*(il&1);\n    float sc = scales[(il%2) + 2 * ((il/2))];\n    il = (il/2) & 3;\n\n    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);\n    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0                       : 0x0F0F0F0F;\n    const float ml = d_all * sc * 32.f;\n    const float dl0 = d_all * sc;\n    const float dl1 = dl0 / 256.f;\n    const float dl2 = dl0 / (256.f * 256.f);\n    const float dl3 = dl0 / (256.f * 256.f * 256.f);\n    const uint8_t shr_h = il>2 ? 2 : 0;\n    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);\n    const uint8_t shr_l = il>1 ? 4 : 0;\n    for (int i = 0; i < 4; ++i) {\n        const uint32_t  low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;\n        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;\n        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);\n        reg[i][0] = dl0 *  ((half)(q & 0xFF))       - ml;\n        reg[i][1] = dl1 * ((float)(q & 0xFF00))     - ml;\n        reg[i][2] = dl2 * ((float)(q & 0xFF0000))   - ml;\n        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {\n    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2\n    const float d = xb->d;\n    const int ib32 = il/2;\n    il = il%2;\n    // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16\n    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.\n    device const uint16_t * q2 = xb->qs + 4*ib32;\n    const uint32_t aux32_g = q2[0] | (q2[1] << 16);\n    const uint32_t aux32_s = q2[2] | (q2[3] << 16);\n    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;\n    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;\n    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);\n    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];\n    for (int i = 0; i < 8; ++i) {\n        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);\n    }\n    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);\n    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];\n    for (int i = 0; i < 8; ++i) {\n        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {\n    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2\n    const float d = xb->d;\n    const int ib32 = il/2;\n    il = il%2;\n    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16\n    device const uint16_t * q2 = xb->qs + 4*ib32;\n    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;\n    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));\n    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];\n    for (int i = 0; i < 8; ++i) {\n        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);\n    }\n    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));\n    signs = ksigns_iq2xs[q2[2*il+1] >> 9];\n    for (int i = 0; i < 8; ++i) {\n        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? 
-1.f : 1.f);\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {\n    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2\n    const float d = xb->d;\n    const int ib32 = il/2;\n    il = il%2;\n    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16\n    device const uint8_t * q3 = xb->qs + 8*ib32;\n    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;\n    const uint32_t aux32 = gas[0] | (gas[1] << 16);\n    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;\n    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);\n    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);\n    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];\n    for (int i = 0; i < 4; ++i) {\n        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);\n        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);\n    }\n    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);\n    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);\n    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];\n    for (int i = 0; i < 4; ++i) {\n        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);\n        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {\n    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2\n    const float d = xb->d;\n    const int ib32 = il/2;\n    il = il%2;\n    // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16\n    device const uint8_t * qs = xb->qs + 8*ib32;\n    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;\n    const uint8_t qh = xb->qh[ib32] >> 4*il;\n    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));\n    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));\n    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));\n    for (int i = 0; i < 4; ++i) {\n        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);\n        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);\n    }\n    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));\n    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));\n    for (int i = 0; i < 4; ++i) {\n        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);\n        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {\n    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2\n    const float d = xb->d;\n    const int ib32 = il/2;\n    il = il%2;\n    // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16\n    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;\n    device const uint8_t * signs = qs + QK_K/8;\n    const uint8_t qh = xb->qh[ib32] >> 4*il;\n    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;\n    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));\n    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));\n    for (int i = 0; i < 8; ++i) {\n        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);\n        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {\n    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2\n    const int ib32 = il/2;\n    il = il%2;\n    const float d = xb->d;\n    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;\n    device const uint16_t * qh = xb->qh;\n    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);\n    const float ml = dl * (qh[ib32] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA);\n    const uint16_t h = qh[ib32] >> 6*il;\n    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));\n    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));\n    for (int i = 0; i < 4; ++i) {\n        reg[0][i] = dl * (grid1[i] & 0xf) + ml;\n        reg[1][i] = dl * (grid1[i] >>  4) + ml;\n        reg[2][i] = dl * (grid2[i] & 0xf) + ml;\n        reg[3][i] = dl * (grid2[i] >>  4) + ml;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {\n    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2\n    const int ib32 = il/2;\n    il = il%2;\n    device const uint16_t * sc = (device const uint16_t *)xb->scales;\n\n    iq1m_scale_t scale;\n    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n    const float d = scale.f16;\n\n    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;\n    device const uint8_t * qh = xb->qh + 2*ib32 + il;\n\n    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);\n    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);\n    const float ml2 = dl * (qh[0] & 0x80 ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA);\n    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));\n    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));\n    for (int i = 0; i < 4; ++i) {\n        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;\n        reg[1][i] = dl * (grid1[i] >>  4) + ml1;\n        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;\n        reg[3][i] = dl * (grid2[i] >>  4) + ml2;\n    }\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {\n    device const uint16_t * q4 = (device const uint16_t *)xb->qs;\n    const float d = xb->d;\n    uint32_t aux32;\n    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;\n    for (int i = 0; i < 4; ++i) {\n        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;\n        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];\n        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];\n        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];\n        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];\n    }\n}\n\ntemplate <typename type4>\nvoid dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {\n    device const uint16_t * q4 = (device const uint16_t *)xb->qs;\n    const float d = xb->d;\n    uint32_t aux32;\n    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;\n    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;\n    reg[0] = d * kvalues_iq4nl_f[q8[0]];\n    reg[1] = d * kvalues_iq4nl_f[q8[1]];\n    reg[2] = d * kvalues_iq4nl_f[q8[2]];\n    reg[3] = d * kvalues_iq4nl_f[q8[3]];\n}\n\ntemplate <typename type4x4>\nvoid dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {\n    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2\n    const int ib32 = il/2;\n    il = il%2;\n    // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16\n    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;\n    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);\n    const float d = (float)xb->d * (ls - 32);\n    uint32_t aux32;\n    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;\n    for (int i = 0; i < 4; ++i) {\n        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;\n        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];\n        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];\n        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];\n        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];\n    }\n}\n\nenum ggml_sort_order {\n    GGML_SORT_ORDER_ASC,\n    GGML_SORT_ORDER_DESC,\n};\n\nconstant float GELU_COEF_A     = 0.044715f;\nconstant float GELU_QUICK_COEF = -1.702f;\nconstant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;\nconstant float SQRT_2_INV      = 0.70710678118654752440084436210484f;\n\n// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation\n// ref: https://www.johndcook.com/blog/python_erf/\nconstant float p_erf  = 0.3275911f;\nconstant float a1_erf = 0.254829592f;\nconstant float a2_erf = -0.284496736f;\nconstant float a3_erf = 1.421413741f;\nconstant float a4_erf = -1.453152027f;\nconstant float a5_erf = 1.061405429f;\n\ntemplate<typename T>\ninline T erf_approx(T x) {\n    T sign_x = sign(x);\n    x = fabs(x);\n    T t = 1.0f / (1.0f + p_erf * x);\n    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);\n    return sign_x * y;\n}\n\ntemplate<typename T> T elu_approx(T x);\n\ntemplate<> inline float elu_approx<float>(float x) {\n    return (x > 0.f) ? x : (exp(x) - 1);\n}\n\ntemplate<> inline float4 elu_approx<float4>(float4 x) {\n    float4 res;\n\n    res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);\n    res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);\n    res[2] = (x[2] > 0.0f) ? 
x[2] : (exp(x[2]) - 1.0f);\n    res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);\n\n    return res;\n}\n\nconstant short FC_unary_op [[function_constant(FC_UNARY + 0)]];\nconstant bool  FC_unary_cnt[[function_constant(FC_UNARY + 1)]];\n\ntemplate <typename T0, typename T, typename TC>\nkernel void kernel_unary_impl(\n        constant ggml_metal_kargs_unary & args,\n        device const char * src0,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n#define FC_OP  FC_unary_op\n#define FC_CNT FC_unary_cnt\n\n    device const T0 * src0_ptr;\n    device       T  * dst_ptr;\n\n    int i0;\n\n    if (FC_CNT) {\n        i0 = tgpig.x;\n\n        src0_ptr = (device const T0 *) (src0);\n        dst_ptr  = (device       T  *) (dst);\n    } else {\n        const int i03 = tgpig.z;\n        const int i02 = tgpig.y;\n        const int k0  = tgpig.x/args.ne01;\n        const int i01 = tgpig.x - k0*args.ne01;\n\n        i0 = k0*ntg.x + tpitg.x;\n\n        src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);\n        dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1 );\n    }\n\n    {\n        //threadgroup_barrier(mem_flags::mem_none);\n\n        if (!FC_CNT) {\n            if (i0 >= args.ne0) {\n                return;\n            }\n        }\n\n        const TC x = (TC) src0_ptr[i0];\n\n        if (FC_OP == OP_UNARY_NUM_SCALE) {\n            dst_ptr[i0] = (T) (args.scale * x + args.bias);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_FILL) {\n            dst_ptr[i0] = (T) args.val;\n        }\n\n        if (FC_OP == OP_UNARY_NUM_CLAMP) {\n            dst_ptr[i0] = (T) clamp(x, args.min, args.max);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_SQR) {\n            dst_ptr[i0] = (T) (x * x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_SQRT) {\n           
 dst_ptr[i0] = (T) sqrt(x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_SIN) {\n            dst_ptr[i0] = (T) sin(x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_COS) {\n            dst_ptr[i0] = (T) cos(x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_LOG) {\n            dst_ptr[i0] = (T) log(x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {\n            dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));\n        }\n\n        if (FC_OP == OP_UNARY_NUM_TANH) {\n            dst_ptr[i0] = (T) precise::tanh(x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_RELU) {\n            dst_ptr[i0] = (T) fmax(0, x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_SIGMOID) {\n            dst_ptr[i0] = (T) (1 / (1 + exp(-x)));\n        }\n\n        if (FC_OP == OP_UNARY_NUM_GELU) {\n            dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));\n        }\n\n        if (FC_OP == OP_UNARY_NUM_GELU_ERF) {\n            dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));\n        }\n\n        if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {\n            dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));\n        }\n\n        if (FC_OP == OP_UNARY_NUM_SILU) {\n            dst_ptr[i0] = (T) (x / (1 + exp(-x)));\n        }\n\n        if (FC_OP == OP_UNARY_NUM_ELU) {\n            dst_ptr[i0] = (T) elu_approx(x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_NEG) {\n            dst_ptr[i0] = (T) -x;\n        }\n\n        if (FC_OP == OP_UNARY_NUM_ABS) {\n            dst_ptr[i0] = (T) fabs(x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_SGN) {\n            dst_ptr[i0] = T(x > 0) - T(x < 0);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_STEP) {\n            dst_ptr[i0] = T(x > 0);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_HARDSWISH) {\n            dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));\n        }\n\n        if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {\n            dst_ptr[i0] = (T) 
fmax(0, fmin(1, x/6 + 0.5));\n        }\n\n        if (FC_OP == OP_UNARY_NUM_EXP) {\n            dst_ptr[i0] = (T) exp(x);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {\n            dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);\n        }\n\n        if (FC_OP == OP_UNARY_NUM_EXPM1) {\n            // TODO: precise implementation\n            dst_ptr[i0] = (T) (exp(x) - 1);\n        }\n    }\n\n#undef FC_OP\n#undef FC_CNT\n}\n\ntypedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;\n\ntemplate [[host_name(\"kernel_unary_f32_f32\")]]   kernel kernel_unary_t kernel_unary_impl<float,  float,  float>;\ntemplate [[host_name(\"kernel_unary_f32_f32_4\")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;\ntemplate [[host_name(\"kernel_unary_f16_f16\")]]   kernel kernel_unary_t kernel_unary_impl<half,   half,   float>;\ntemplate [[host_name(\"kernel_unary_f16_f16_4\")]] kernel kernel_unary_t kernel_unary_impl<half4,  half4,  float4>;\n\n// OP: 0 - add, 1 - sub, 2 - mul, 3 - div\nconstant short FC_bin_op [[function_constant(FC_BIN + 0)]];\nconstant short FC_bin_f  [[function_constant(FC_BIN + 1)]];\nconstant bool  FC_bin_rb [[function_constant(FC_BIN + 2)]];\nconstant bool  FC_bin_cb [[function_constant(FC_BIN + 3)]];\n\ntemplate <typename T0, typename T1, typename T>\nkernel void kernel_bin_fuse_impl(\n        constant ggml_metal_kargs_bin & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n#define FC_OP FC_bin_op\n#define FC_F  FC_bin_f\n#define FC_RB FC_bin_rb\n#define FC_CB FC_bin_cb\n\n    if (FC_RB) {\n        // row broadcast\n        const uint i0 = tgpig.y*args.ne00 + tgpig.x;\n        const uint i1 = FC_CB ? 
tgpig.x%args.ne10 : tgpig.x;\n\n        device const T0 * src0_row = (device const T0 *) (src0);\n        device       T  * dst_row  = (device       T  *) (dst);\n\n        if (FC_F == 1) {\n            device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);\n\n            if (FC_OP == 0) {\n                dst_row[i0] = src0_row[i0] + src1_row[i1];\n            }\n\n            if (FC_OP == 1) {\n                dst_row[i0] = src0_row[i0] - src1_row[i1];\n            }\n\n            if (FC_OP == 2) {\n                dst_row[i0] = src0_row[i0] * src1_row[i1];\n            }\n\n            if (FC_OP == 3) {\n                dst_row[i0] = src0_row[i0] / src1_row[i1];\n            }\n        } else {\n            T0 res = src0_row[i0];\n\n            if (FC_OP == 0) {\n                FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                    res += ((device const T1 *) (src1 + args.o1[j]))[i1];\n                }\n            }\n\n            if (FC_OP == 1) {\n                FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                    res -= ((device const T1 *) (src1 + args.o1[j]))[i1];\n                }\n            }\n\n            if (FC_OP == 2) {\n                FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                    res *= ((device const T1 *) (src1 + args.o1[j]))[i1];\n                }\n            }\n\n            if (FC_OP == 3) {\n                FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                    res /= ((device const T1 *) (src1 + args.o1[j]))[i1];\n                }\n            }\n\n            dst_row[i0] = res;\n        }\n    } else {\n        const int i03 = tgpig.z;\n        const int i02 = tgpig.y;\n        const int i01 = tgpig.x;\n\n        if (i01 >= args.ne01) {\n            return;\n        }\n\n        const int i13 = i03%args.ne13;\n        const int i12 = i02%args.ne12;\n        const int i11 = i01%args.ne11;\n\n        device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + 
i02*args.nb02 + i01*args.nb01 + args.offs);\n        device       T  * dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);\n\n        if (FC_F == 1) {\n            device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);\n\n            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n                const int i10 = FC_CB ? i0%args.ne10 : i0;\n\n                if (FC_OP == 0) {\n                    dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];\n                }\n\n                if (FC_OP == 1) {\n                    dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];\n                }\n\n                if (FC_OP == 2) {\n                    dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];\n                }\n\n                if (FC_OP == 3) {\n                    dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];\n                }\n            }\n        } else {\n            device const T1 * src1_ptr[8];\n            FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);\n            }\n\n            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n                const int i10 = FC_CB ? 
i0%args.ne10 : i0;\n\n                T res = src0_ptr[i0];\n\n                if (FC_OP == 0) {\n                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                        res += src1_ptr[j][i10];\n                    }\n                }\n\n                if (FC_OP == 1) {\n                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                        res -= src1_ptr[j][i10];\n                    }\n                }\n\n                if (FC_OP == 2) {\n                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                        res *= src1_ptr[j][i10];\n                    }\n                }\n\n                if (FC_OP == 3) {\n                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {\n                        res /= src1_ptr[j][i10];\n                    }\n                }\n\n                dst_ptr[i0] = res;\n            }\n        }\n    }\n\n#undef FC_OP\n#undef FC_F\n#undef FC_RB\n#undef FC_CB\n}\n\ntypedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;\n\ntemplate [[host_name(\"kernel_bin_fuse_f32_f32_f32\")]]   kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float,  float,  float>;\ntemplate [[host_name(\"kernel_bin_fuse_f32_f32_f32_4\")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;\n\nkernel void kernel_add_id(\n        constant ggml_metal_kargs_add_id & args,\n        device const char * src0,\n        device const char * src1,\n        device const char * src2,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int i1 = tgpig.x;\n    const int i2 = tgpig.y;\n\n    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));\n\n    const size_t nb1 = args.ne0 * sizeof(float);\n    const size_t nb2 = args.ne1 * nb1;\n\n    device       float * dst_row  = (device       float 
*)((device char *)dst  +  i1*nb1       + i2*nb2);\n    device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);\n    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);\n\n    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n        dst_row[i0] = src0_row[i0] + src1_row[i0];\n    }\n}\n\ntemplate<typename T>\nkernel void kernel_repeat(\n        constant ggml_metal_kargs_repeat & args,\n        device const char * src0,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int i3 = tgpig.z;\n    const int i2 = tgpig.y;\n    const int i1 = tgpig.x;\n\n    const int i03 = i3%args.ne03;\n    const int i02 = i2%args.ne02;\n    const int i01 = i1%args.ne01;\n\n    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;\n    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;\n\n    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n        const int i00 = i0%args.ne00;\n        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));\n    }\n}\n\ntypedef decltype(kernel_repeat<float>) kernel_repeat_t;\n\ntemplate [[host_name(\"kernel_repeat_f32\")]] kernel kernel_repeat_t kernel_repeat<float>;\ntemplate [[host_name(\"kernel_repeat_f16\")]] kernel kernel_repeat_t kernel_repeat<half>;\ntemplate [[host_name(\"kernel_repeat_i32\")]] kernel kernel_repeat_t kernel_repeat<int>;\ntemplate [[host_name(\"kernel_repeat_i16\")]] kernel kernel_repeat_t kernel_repeat<short>;\n\nkernel void kernel_reglu_f32(\n        constant ggml_metal_kargs_glu & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint tgpig[[threadgroup_position_in_grid]],\n        uint 
tpitg[[thread_position_in_threadgroup]],\n        uint   ntg[[threads_per_threadgroup]]) {\n    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;\n    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;\n    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);\n\n    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        dst_row[i0] = x0*x1*(x0 > 0.0f);\n    }\n}\n\nkernel void kernel_geglu_f32(\n        constant ggml_metal_kargs_glu & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint tgpig[[threadgroup_position_in_grid]],\n        uint tpitg[[thread_position_in_threadgroup]],\n        uint   ntg[[threads_per_threadgroup]]) {\n    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;\n    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;\n    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);\n\n    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));\n\n        dst_row[i0] = gelu*x1;\n    }\n}\n\nkernel void kernel_swiglu_f32(\n        constant ggml_metal_kargs_glu & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint tgpig[[threadgroup_position_in_grid]],\n        uint tpitg[[thread_position_in_threadgroup]],\n        uint   ntg[[threads_per_threadgroup]]) {\n    device const float * src0_row = (device const 
float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;\n    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;\n    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);\n\n    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        const float silu = x0 / (1.0f + exp(-x0));\n\n        dst_row[i0] = silu*x1;\n    }\n}\n\nkernel void kernel_swiglu_oai_f32(\n        constant ggml_metal_kargs_glu & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint tgpig[[threadgroup_position_in_grid]],\n        uint tpitg[[thread_position_in_threadgroup]],\n        uint   ntg[[threads_per_threadgroup]]) {\n    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;\n    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;\n    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);\n\n    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {\n        float x0 = src0_row[i0];\n        float x1 = src1_row[i0];\n\n        x0 = min(x0, args.limit);\n        x1 = max(min(x1, args.limit), -args.limit);\n\n        float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));\n        out_glu = out_glu * (1.0f + x1);\n\n        dst_row[i0] = out_glu;\n    }\n}\n\nkernel void kernel_geglu_erf_f32(\n        constant ggml_metal_kargs_glu & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint tgpig[[threadgroup_position_in_grid]],\n        uint tpitg[[thread_position_in_threadgroup]],\n        uint   ntg[[threads_per_threadgroup]]) {\n    device const float * src0_row = (device const float *) ((device 
const char *) src0 + tgpig*args.nb01) + args.i00;\n    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;\n    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);\n\n    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));\n\n        dst_row[i0] = gelu_erf*x1;\n    }\n}\n\nkernel void kernel_geglu_quick_f32(\n        constant ggml_metal_kargs_glu & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint tgpig[[threadgroup_position_in_grid]],\n        uint tpitg[[thread_position_in_threadgroup]],\n        uint   ntg[[threads_per_threadgroup]]) {\n    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;\n    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;\n    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);\n\n    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));\n\n        dst_row[i0] = gelu_quick*x1;\n    }\n}\n\nkernel void kernel_op_sum_f32(\n        constant ggml_metal_kargs_sum & args,\n        device const float * src0,\n        device       float * dst,\n        threadgroup  float * shmem_f32 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n\n    if 
(args.np == 0) {\n        return;\n    }\n\n    // TODO: become function constant\n    const uint nsg = (ntg.x + 31) / 32;\n\n    float sumf = 0;\n\n    for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {\n        sumf += src0[i0];\n    }\n\n    sumf = simd_sum(sumf);\n\n    if (tiisg == 0) {\n        shmem_f32[sgitg] = sumf;\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    float total = 0;\n\n    if (sgitg == 0) {\n        float v = 0;\n\n        if (tpitg.x < nsg) {\n            v = shmem_f32[tpitg.x];\n        }\n\n        total = simd_sum(v);\n\n        if (tpitg.x == 0) {\n            dst[0] = total;\n        }\n    }\n}\n\nconstant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];\n\ntemplate <typename T0, typename T>\nkernel void kernel_sum_rows_impl(\n        constant ggml_metal_kargs_sum_rows & args,\n        device const char * src0,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n#define FC_OP  FC_sum_rows_op\n\n    const int i3 = tgpig.z;\n    const int i2 = tgpig.y;\n    const int i1 = tgpig.x;\n\n    threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;\n\n    if (sgitg == 0) {\n        shmem_t[tiisg] = 0.0f;\n    }\n\n    device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);\n    device       T  * dst_row = (device       T  *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);\n\n    T0 sumf = T0(0.0f);\n\n    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {\n        sumf += src_row[i0];\n    }\n\n    sumf = simd_sum(sumf);\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    if (tiisg == 0) {\n        shmem_t[sgitg] = sumf;\n    }\n\n 
   threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    sumf = shmem_t[tiisg];\n    sumf = simd_sum(sumf);\n\n    if (tpitg.x == 0) {\n        if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {\n            if (is_same<float4, T0>::value) {\n                dst_row[0] = sum(sumf) / (4*args.ne00);\n            } else {\n                dst_row[0] = sum(sumf) / args.ne00;\n            }\n        } else {\n            dst_row[0] = sum(sumf);\n        }\n    }\n\n#undef FC_OP\n}\n\ntypedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;\n\ntemplate [[host_name(\"kernel_sum_rows_f32_f32\")]]   kernel kernel_sum_rows_t kernel_sum_rows_impl<float,  float>;\ntemplate [[host_name(\"kernel_sum_rows_f32_f32_4\")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;\n\ntemplate<typename T>\nkernel void kernel_cumsum_blk(\n        constant ggml_metal_kargs_cumsum_blk & args,\n        device const char * src0,\n        device       char * tmp,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int ib = tgpig[0]/args.ne01;\n\n    const int i00 = ib*ntg.x;\n    const int i01 = tgpig[0]%args.ne01;\n    const int i02 = tgpig[1];\n    const int i03 = tgpig[2];\n\n    device const float * src0_row = (device const float *) (src0 +\n            args.nb01*i01 +\n            args.nb02*i02 +\n            args.nb03*i03);\n\n    threadgroup float * shmem_f32 = (threadgroup float *) shmem;\n\n    float v = 0.0f;\n\n    if (i00 + tpitg.x < args.ne00) {\n        v = src0_row[i00 + tpitg.x];\n    }\n\n    float s = simd_prefix_inclusive_sum(v);\n\n    if (tiisg == N_SIMDWIDTH - 1) {\n        shmem_f32[sgitg] = s;\n    }\n\n    
threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    if (sgitg == 0) {\n        shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    s += shmem_f32[sgitg];\n\n    device float * dst_row = (device float *) dst +\n        args.ne00*i01 +\n        args.ne00*args.ne01*i02 +\n        args.ne00*args.ne01*args.ne02*i03;\n\n    if (i00 + tpitg.x < args.ne00) {\n        dst_row[i00 + tpitg.x] = s;\n    }\n\n    if (args.outb && tpitg.x == ntg.x - 1) {\n        device float * tmp_row = (device float *) tmp +\n            args.net0*i01 +\n            args.net0*args.net1*i02 +\n            args.net0*args.net1*args.net2*i03;\n\n        tmp_row[ib] = s;\n    }\n}\n\ntypedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;\n\ntemplate [[host_name(\"kernel_cumsum_blk_f32\")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;\n\ntemplate<typename T>\nkernel void kernel_cumsum_add(\n        constant ggml_metal_kargs_cumsum_add & args,\n        device const char * tmp,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int ib = tgpig[0]/args.ne01;\n\n    if (ib == 0) {\n        return;\n    }\n\n    const int i00 = ib*ntg.x;\n    const int i01 = tgpig[0]%args.ne01;\n    const int i02 = tgpig[1];\n    const int i03 = tgpig[2];\n\n    device const float * tmp_row = (device const float *) (tmp +\n            args.nbt1*i01 +\n            args.nbt2*i02 +\n            args.nbt3*i03);\n\n    device float * dst_row = (device float *) dst +\n        args.ne00*i01 +\n        args.ne00*args.ne01*i02 +\n        args.ne00*args.ne01*args.ne02*i03;\n\n    if (i00 + tpitg.x < args.ne00) {\n        dst_row[i00 + tpitg.x] += tmp_row[ib - 
1];\n    }\n}\n\ntypedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;\n\ntemplate [[host_name(\"kernel_cumsum_add_f32\")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;\n\n\ntemplate<uint32_t ttype>\nbool _ggml_vec_tri_cmp(const int i, const int r);\n\ntemplate<>\nbool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {\n    return i < r;\n}\n\ntemplate<>\nbool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {\n    return i <= r;\n}\n\ntemplate<>\nbool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {\n    return i > r;\n}\n\ntemplate<>\nbool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {\n    return i >= r;\n}\n\ntemplate<typename T, int ttype>\nkernel void kernel_tri(\n        constant ggml_metal_kargs_tri & args,\n        device const char * src0,\n        device const char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int i3 = tgpig.z;\n    const int i2 = tgpig.y;\n    const int i1 = tgpig.x;\n\n    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {\n        return;\n    }\n\n    device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);\n    device       T * dst_row = (device       T *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);\n\n    // Each thread is a single element of the row if ne00 < max threads per\n    // threadgroup, so this will loop once for each index that this thread is\n    // responsible for\n    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {\n        // Use the comparison as a mask for branchless\n        dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];\n    }\n}\n\ntypedef decltype(kernel_tri<float, 0>) kernel_tri_t;\n\ntemplate 
[[host_name(\"kernel_tri_f32_0\")]] kernel kernel_tri_t kernel_tri<float, 0>;\ntemplate [[host_name(\"kernel_tri_f32_1\")]] kernel kernel_tri_t kernel_tri<float, 1>;\ntemplate [[host_name(\"kernel_tri_f32_2\")]] kernel kernel_tri_t kernel_tri<float, 2>;\ntemplate [[host_name(\"kernel_tri_f32_3\")]] kernel kernel_tri_t kernel_tri<float, 3>;\ntemplate [[host_name(\"kernel_tri_f16_0\")]] kernel kernel_tri_t kernel_tri<half, 0>;\ntemplate [[host_name(\"kernel_tri_f16_1\")]] kernel kernel_tri_t kernel_tri<half, 1>;\ntemplate [[host_name(\"kernel_tri_f16_2\")]] kernel kernel_tri_t kernel_tri<half, 2>;\ntemplate [[host_name(\"kernel_tri_f16_3\")]] kernel kernel_tri_t kernel_tri<half, 3>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_tri_bf16_0\")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;\ntemplate [[host_name(\"kernel_tri_bf16_1\")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;\ntemplate [[host_name(\"kernel_tri_bf16_2\")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;\ntemplate [[host_name(\"kernel_tri_bf16_3\")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;\n#endif\n\ntemplate<typename T>\nkernel void kernel_soft_max(\n        constant ggml_metal_kargs_soft_max & args,\n        device const  char * src0,\n        device const  char * src1,\n        device const  char * src2,\n        device        char * dst,\n        threadgroup  float * buf [[threadgroup(0)]],\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint  sgitg[[simdgroup_index_in_threadgroup]],\n        uint  tiisg[[thread_index_in_simdgroup]],\n        uint3  tptg[[threads_per_threadgroup]]) {\n    const int32_t i03 = tgpig.z;\n    const int32_t i02 = tgpig.y;\n    const int32_t i01 = tgpig.x;\n\n    const int32_t i13 = i03%args.ne13;\n    const int32_t i12 = i02%args.ne12;\n    const int32_t i11 = i01;\n\n    device const float * psrc0 =                (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + 
i03*args.nb03);\n    device const     T * pmask = src1 != src0 ? (device const T *    ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;\n    device const float * psrc2 = src2 != src0 ? (device const float *) (src2)                                                 : nullptr;\n    device       float * pdst  =                (device       float *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);\n\n    float slope = 1.0f;\n\n    // ALiBi\n    if (args.max_bias > 0.0f) {\n        const int32_t h = i02;\n\n        const float base = h < args.n_head_log2 ? args.m0 : args.m1;\n        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;\n\n        slope = pow(base, exp);\n    }\n\n    // parallel max\n    float lmax = psrc2 ? psrc2[i02] : -INFINITY;\n\n    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {\n        lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));\n    }\n\n    // find the max value in the block\n    float max_val = simd_max(lmax);\n    if (tptg.x > N_SIMDWIDTH) {\n        if (sgitg == 0) {\n            buf[tiisg] = -INFINITY;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (tiisg == 0) {\n            buf[sgitg] = max_val;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        max_val = buf[tiisg];\n        max_val = simd_max(max_val);\n    }\n\n    // parallel sum\n    float lsum = 0.0f;\n    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {\n        const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? 
slope*pmask[i00] : 0.0f)) - max_val);\n        lsum += exp_psrc0;\n        pdst[i00] = exp_psrc0;\n    }\n\n    // This barrier fixes a failing test\n    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335\n    threadgroup_barrier(mem_flags::mem_none);\n\n    float sum = simd_sum(lsum);\n\n    if (tptg.x > N_SIMDWIDTH) {\n        if (sgitg == 0) {\n            buf[tiisg] = 0.0f;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (tiisg == 0) {\n            buf[sgitg] = sum;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        sum = buf[tiisg];\n        sum = simd_sum(sum);\n    }\n\n    if (psrc2) {\n        sum += exp(psrc2[i02] - max_val);\n    }\n\n    const float inv_sum = 1.0f/sum;\n\n    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {\n        pdst[i00] *= inv_sum;\n    }\n}\n\ntemplate<typename T>\nkernel void kernel_soft_max_4(\n        constant ggml_metal_kargs_soft_max & args,\n        device const  char * src0,\n        device const  char * src1,\n        device const  char * src2,\n        device        char * dst,\n        threadgroup  float * buf [[threadgroup(0)]],\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint  sgitg[[simdgroup_index_in_threadgroup]],\n        uint  tiisg[[thread_index_in_simdgroup]],\n        uint3  tptg[[threads_per_threadgroup]]) {\n    const int32_t i03 = tgpig.z;\n    const int32_t i02 = tgpig.y;\n    const int32_t i01 = tgpig.x;\n\n    const int32_t i13 = i03%args.ne13;\n    const int32_t i12 = i02%args.ne12;\n    const int32_t i11 = i01;\n\n    device const float4 * psrc4 =                (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);\n    device const      T * pmask = src1 != src0 ? (device const T *     ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;\n    device const float *  psrc2 = src2 != src0 ? 
(device const float * ) (src2)                                                 : nullptr;\n    device       float4 * pdst4 =                (device       float4 *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);\n\n    float slope = 1.0f;\n\n    if (args.max_bias > 0.0f) {\n        const int32_t h = i02;\n\n        const float base = h < args.n_head_log2 ? args.m0 : args.m1;\n        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;\n\n        slope = pow(base, exp);\n    }\n\n    // parallel max\n    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;\n\n    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {\n        lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));\n    }\n\n    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));\n\n    float max_val = simd_max(lmax);\n    if (tptg.x > N_SIMDWIDTH) {\n        if (sgitg == 0) {\n            buf[tiisg] = -INFINITY;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (tiisg == 0) {\n            buf[sgitg] = max_val;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        max_val = buf[tiisg];\n        max_val = simd_max(max_val);\n    }\n\n    // parallel sum\n    float4 lsum4 = 0.0f;\n    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {\n        const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? 
slope*pmask[i00] : 0.0f))) - max_val);\n        lsum4 += exp_psrc4;\n        pdst4[i00] = exp_psrc4;\n    }\n\n    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];\n\n    // This barrier fixes a failing test\n    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335\n    threadgroup_barrier(mem_flags::mem_none);\n\n    float sum = simd_sum(lsum);\n\n    if (tptg.x > N_SIMDWIDTH) {\n        if (sgitg == 0) {\n            buf[tiisg] = 0.0f;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (tiisg == 0) {\n            buf[sgitg] = sum;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        sum = buf[tiisg];\n        sum = simd_sum(sum);\n    }\n\n    if (psrc2) {\n        sum += exp(psrc2[i02] - max_val);\n    }\n\n    const float inv_sum = 1.0f/sum;\n\n    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {\n        pdst4[i00] *= inv_sum;\n    }\n}\n\ntypedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;\ntypedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;\n\ntemplate [[host_name(\"kernel_soft_max_f16\")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;\ntemplate [[host_name(\"kernel_soft_max_f32\")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;\ntemplate [[host_name(\"kernel_soft_max_f16_4\")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;\ntemplate [[host_name(\"kernel_soft_max_f32_4\")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;\n\n// ref: ggml.c:ggml_compute_forward_ssm_conv_f32\nkernel void kernel_ssm_conv_f32_f32(\n        constant ggml_metal_kargs_ssm_conv & args,\n        device const  void * src0,\n        device const  void * src1,\n        device       float * dst,\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint3   ntg[[threads_per_threadgroup]]) {\n    const int64_t ir = tgpig.x;\n    const int64_t i2 = tgpig.y;\n    const 
int64_t i3 = tgpig.z;\n\n    const int64_t nc  = args.ne10;\n  //const int64_t ncs = args.ne00;\n  //const int64_t nr  = args.ne01;\n  //const int64_t n_t = args.ne1;\n  //const int64_t n_s = args.ne2;\n\n    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);\n    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);\n    device       float * x = (device       float *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);\n\n    float sumf = 0.0f;\n\n    for (int64_t i0 = 0; i0 < nc; ++i0) {\n        sumf += s[i0] * c[i0];\n    }\n\n    x[0] = sumf;\n}\n\nkernel void kernel_ssm_conv_f32_f32_4(\n        constant ggml_metal_kargs_ssm_conv & args,\n        device const  void * src0,\n        device const  void * src1,\n        device       float * dst,\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint3   ntg[[threads_per_threadgroup]]) {\n    const int64_t ir = tgpig.x;\n    const int64_t i2 = tgpig.y;\n    const int64_t i3 = tgpig.z;\n\n    const int64_t nc  = args.ne10;\n  //const int64_t ncs = args.ne00;\n  //const int64_t nr  = args.ne01;\n  //const int64_t n_t = args.ne1;\n  //const int64_t n_s = args.ne2;\n\n    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);\n    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);\n    device       float  * x = (device       float  *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);\n\n    float sumf = 0.0f;\n\n    for (int64_t i0 = 0; i0 < nc/4; ++i0) {\n        sumf += dot(s[i0], c[i0]);\n    }\n\n    x[0] = sumf;\n}\n\nconstant short FC_ssm_conv_bs   [[function_constant(FC_SSM_CONV + 0)]];\n\n// Batched version: each threadgroup processes multiple tokens for better efficiency\n// 
Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens\nkernel void kernel_ssm_conv_f32_f32_batched(\n        constant ggml_metal_kargs_ssm_conv & args,\n        device const  void * src0,\n        device const  void * src1,\n        device       float * dst,\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint3   ntg[[threads_per_threadgroup]]) {\n    // tgpig.x = row index (ir)\n    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)\n    // tgpig.z = sequence index (i3)\n    // tpitg.x = thread within batch (0..BATCH_SIZE-1)\n    const short BATCH_SIZE = FC_ssm_conv_bs;\n\n    const int64_t ir      = tgpig.x;\n    const int64_t i2_base = tgpig.y * BATCH_SIZE;\n    const int64_t i3      = tgpig.z;\n    const int64_t i2_off  = tpitg.x;\n    const int64_t i2      = i2_base + i2_off;\n\n    const int64_t nc  = args.ne10;  // conv kernel size (typically 4)\n    const int64_t n_t = args.ne1;   // number of tokens\n\n    // Bounds check for partial batches at the end\n    if (i2 >= n_t) {\n        return;\n    }\n\n    // Load conv weights (shared across all tokens for this row)\n    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);\n\n    // Load source for this specific token\n    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);\n\n    // Output location for this token\n    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);\n\n    float sumf = 0.0f;\n    for (int64_t i0 = 0; i0 < nc; ++i0) {\n        sumf += s[i0] * c[i0];\n    }\n\n    x[0] = sumf;\n}\n\nkernel void kernel_ssm_conv_f32_f32_batched_4(\n        constant ggml_metal_kargs_ssm_conv & args,\n        device const  void * src0,\n        device const  void * src1,\n        device       float * dst,\n        uint3 
tgpig[[threadgroup_position_in_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint3   ntg[[threads_per_threadgroup]]) {\n    // tgpig.x = row index (ir)\n    // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)\n    // tgpig.z = sequence index (i3)\n    // tpitg.x = thread within batch (0..BATCH_SIZE-1)\n    const short BATCH_SIZE = FC_ssm_conv_bs;\n\n    const int64_t ir      = tgpig.x;\n    const int64_t i2_base = tgpig.y * BATCH_SIZE;\n    const int64_t i3      = tgpig.z;\n    const int64_t i2_off  = tpitg.x;\n    const int64_t i2      = i2_base + i2_off;\n\n    const int64_t nc  = args.ne10;  // conv kernel size (typically 4)\n    const int64_t n_t = args.ne1;   // number of tokens\n\n    // Bounds check for partial batches at the end\n    if (i2 >= n_t) {\n        return;\n    }\n\n    // Load conv weights (shared across all tokens for this row)\n    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);\n\n    // Load source for this specific token\n    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);\n\n    // Output location for this token\n    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);\n\n    float sumf = 0.0f;\n    for (int64_t i0 = 0; i0 < nc/4; ++i0) {\n        sumf += dot(s[i0], c[i0]);\n    }\n\n    x[0] = sumf;\n}\n\n// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part\n// Optimized version: reduces redundant memory loads by having one thread load shared values\nkernel void kernel_ssm_scan_f32(\n        constant ggml_metal_kargs_ssm_scan & args,\n        device const void * src0,\n        device const void * src1,\n        device const void * src2,\n        device const void * src3,\n        device const void * src4,\n        device const void * src5,\n        device const void * src6,\n        device      float * dst,\n        threadgroup float 
* shared [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort  sgptg[[simdgroups_per_threadgroup]],\n        uint3    tgpg[[threadgroups_per_grid]]) {\n    constexpr short NW = N_SIMDWIDTH;\n\n    // Shared memory layout:\n    // [0..sgptg*NW-1]: partial sums for reduction (existing)\n    // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch\n    // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch\n    threadgroup float * shared_sums = shared;\n    threadgroup float * shared_x_dt = shared + sgptg * NW;\n    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;\n\n    shared_sums[tpitg.x] = 0.0f;\n\n    const int32_t i0 = tpitg.x;\n    const int32_t i1 = tgpig.x;\n    const int32_t ir = tgpig.y; // current head\n    const int32_t i3 = tgpig.z; // current seq\n\n    const int32_t nc  = args.d_state;\n    const int32_t nr  = args.d_inner;\n    const int32_t nh  = args.n_head;\n    const int32_t ng  = args.n_group;\n    const int32_t n_t = args.n_seq_tokens;\n\n    const int32_t s_off = args.s_off;\n\n    device const int32_t * ids = (device const int32_t *) src6;\n\n    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);\n    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);\n\n    const int32_t i = i0 + i1*nc;\n    const int32_t g = ir / (nh / ng); // repeat_interleave\n\n    float s0 = s0_buff[i];\n    float s  = 0.0f;\n\n    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}\n\n    const float A0 = A[i0%args.ne30];\n\n    device const float * x  = (device const float *)((device const char *) 
src1 + i1*args.nb10  + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}\n    device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20  + i3*args.nb22);                // {nh, nt, ns}\n    device const float * B  = (device const float *)((device const char *) src4 +  g*args.nb41  + i3*args.nb43);                // {d_state, ng, nt, ns}\n    device const float * C  = (device const float *)((device const char *) src5 +  g*args.nb51  + i3*args.nb53);                // {d_state, ng, nt, ns}\n\n    device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}\n\n    for (int i2 = 0; i2 < n_t; i2 += sgptg) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Pre-compute x_dt and dA for this batch of tokens\n        // Only first sgptg threads do the loads and expensive math\n        if (i0 < sgptg && i2 + i0 < n_t) {\n            // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)\n            device const float * x_t  = x  + i0 * args.ns12;\n            device const float * dt_t = dt + i0 * args.ns21;\n\n            const float dt0  = dt_t[0];\n            const float dtsp = dt0 <= 20.0f ? 
log(1.0f + exp(dt0)) : dt0;\n            shared_x_dt[i0] = x_t[0] * dtsp;\n            shared_dA[i0]   = dtsp;  // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {\n            const float x_dt = shared_x_dt[t];\n            const float dA   = exp(shared_dA[t] * A0);\n\n            s = (s0 * dA) + (B[i0] * x_dt);\n\n            const float sumf = simd_sum(s * C[i0]);\n\n            if (tiisg == 0) {\n                shared_sums[t*NW + sgitg] = sumf;\n            }\n\n            // recurse\n            s0 = s;\n\n            B  += args.ns42;\n            C  += args.ns52;\n        }\n\n        // Advance pointers for next batch\n        x  += sgptg * args.ns12;\n        dt += sgptg * args.ns21;\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);\n\n        if (tiisg == 0 && i2 + sgitg < n_t) {\n            y[sgitg*nh*nr] = sumf;\n        }\n\n        y += sgptg*nh*nr;\n    }\n\n    s_buff[i] = s;\n}\n\nkernel void kernel_rwkv_wkv6_f32(\n    device const float * k,\n    device const float * v,\n    device const float * r,\n    device const float * tf,\n    device const float * td,\n    device const float * state_in,\n    device       float * dst,\n    constant    uint & B,\n    constant    uint & T,\n    constant    uint & C,\n    constant    uint & H,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]])  {\n\n    const uint head_size = 64; // TODO: support head_size = 128\n    const uint batch_id = tgpig.x / H;\n    const uint head_id = tgpig.x % H;\n    const uint tid = tpitg.x;\n\n    if (batch_id >= B || head_id >= H) {\n        return;\n    }\n\n    const uint state_size = C * head_size;\n    const uint n_seq_tokens = T / B;\n\n    threadgroup float 
_k[head_size];\n    threadgroup float _r[head_size];\n    threadgroup float _tf[head_size];\n    threadgroup float _td[head_size];\n\n    float state[head_size];\n\n    for (uint i = 0; i < head_size; i++) {\n        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size\n                          + i * head_size + tid];\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    _tf[tid] = tf[head_id * head_size + tid];\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;\n    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;\n\n    for (uint t = start_t; t < end_t; t += C) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        _k[tid] = k[t];\n        _r[tid] = r[t];\n        _td[tid] = td[t];\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        const float v_val = v[t];\n        float y = 0.0;\n\n        for (uint j = 0; j < head_size; j += 4) {\n            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);\n            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);\n            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);\n            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);\n            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);\n\n            float4 kv = k_vec * v_val;\n\n            float4 temp = tf_vec * kv + s_vec;\n            y += dot(r_vec, temp);\n\n            s_vec = s_vec * td_vec + kv;\n            state[j]   = s_vec[0];\n            state[j+1] = s_vec[1];\n            state[j+2] = s_vec[2];\n            state[j+3] = s_vec[3];\n        }\n\n        dst[t] = y;\n    }\n\n    for (uint i = 0; i < head_size; i++) {\n        dst[T * C + batch_id * state_size + head_id * head_size * head_size\n            + i * head_size + tid] = state[i];\n    }\n}\n\nkernel void kernel_rwkv_wkv7_f32(\n    device 
const float * r,\n    device const float * w,\n    device const float * k,\n    device const float * v,\n    device const float * a,\n    device const float * b,\n    device const float * state_in,\n    device       float * dst,\n    constant    uint & B,\n    constant    uint & T,\n    constant    uint & C,\n    constant    uint & H,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]])  {\n\n    const uint head_size = 64; // TODO: support head_size = 128\n    const uint batch_id = tgpig.x / H;\n    const uint head_id = tgpig.x % H;\n    const uint tid = tpitg.x;\n\n    if (batch_id >= B || head_id >= H) {\n        return;\n    }\n\n    const uint state_size = C * head_size;\n    const uint n_seq_tokens = T / B;\n\n    threadgroup float _r[head_size];\n    threadgroup float _w[head_size];\n    threadgroup float _k[head_size];\n    threadgroup float _a[head_size];\n    threadgroup float _b[head_size];\n\n    float state[head_size];\n\n    for (uint i = 0; i < head_size; i++) {\n        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size\n                          + tid * head_size + i];\n    }\n\n    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;\n    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;\n\n    for (uint t = start_t; t < end_t; t += C) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        _r[tid] = r[t];\n        _w[tid] = w[t];\n        _k[tid] = k[t];\n        _a[tid] = a[t];\n        _b[tid] = b[t];\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        const float v_val = v[t];\n        float y = 0.0, sa = 0.0;\n\n        float4 sa_vec(0.0);\n\n        for (uint j = 0; j < head_size; j += 4) {\n            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);\n            float4 s_vec = float4(state[j], state[j+1], state[j+2], 
state[j+3]);\n            sa_vec += a_vec * s_vec;\n        }\n        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];\n\n        for (uint j = 0; j < head_size; j += 4) {\n            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);\n            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);\n            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);\n            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);\n            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);\n\n            float4 kv = k_vec * v_val;\n\n            s_vec = s_vec * w_vec + kv + sa * b_vec;\n            y += dot(s_vec, r_vec);\n\n            state[j]   = s_vec[0];\n            state[j+1] = s_vec[1];\n            state[j+2] = s_vec[2];\n            state[j+3] = s_vec[3];\n        }\n\n        dst[t] = y;\n    }\n\n    for (uint i = 0; i < head_size; i++) {\n        dst[T * C + batch_id * state_size + head_id * head_size * head_size\n            + tid * head_size + i] = state[i];\n    }\n}\n\nconstant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];\nconstant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];\n\n#if 1\ntemplate<short NSG>\nkernel void kernel_gated_delta_net_impl(\n        constant ggml_metal_kargs_gated_delta_net & args,\n        device const char * q,\n        device const char * k,\n        device const char * v,\n        device const char * g,\n        device const char * b,\n        device const char * s,\n        device       char * dst,\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint3   ntg[[threads_per_threadgroup]])  {\n#define S_v FC_gated_delta_net_ne20\n#define G   FC_gated_delta_net_ne30\n\n    const uint tx = tpitg.x;\n    const uint ty = tpitg.y;\n\n    const uint i23 = tgpig.z; // B\n    const uint i21 = tgpig.y; // H\n    const uint i20 = tgpig.x*NSG + ty;\n\n    const 
uint i01 = i21 % args.ne01;\n    const uint i11 = i21 % args.ne11;\n\n    const float scale = 1.0f / sqrt((float)S_v);\n\n    // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous\n    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;\n\n    float ls[NSG];\n\n    FOR_UNROLL (short j = 0; j < NSG; j++) {\n        const short is = tx*NSG + j;\n        ls[j] = s_ptr[is];\n    }\n\n    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;\n\n    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);\n    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);\n    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);\n\n    device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);\n    device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;\n\n    for (short t = 0; t < args.ne22; t++) {\n        float s_k = 0.0f;\n\n        if (G == 1) {\n            const float g_exp = exp(g_ptr[0]);\n\n            FOR_UNROLL (short j = 0; j < NSG; j++) {\n                const short is = tx*NSG + j;\n                ls[j] *= g_exp;\n\n                s_k += ls[j]*k_ptr[is];\n            }\n        } else {\n            // KDA\n            FOR_UNROLL (short j = 0; j < NSG; j++) {\n                const short is = tx*NSG + j;\n                ls[j] *= exp(g_ptr[is]);\n\n                s_k += ls[j]*k_ptr[is];\n            }\n        }\n\n        s_k = simd_sum(s_k);\n\n        const float d = (v_ptr[i20] - s_k)*b_ptr[0];\n\n        float y = 0.0f;\n\n        FOR_UNROLL (short j = 0; j < NSG; j++) {\n            const short is = tx*NSG + j;\n            ls[j] += k_ptr[is]*d;\n\n            y += ls[j]*q_ptr[is];\n        }\n\n        y = simd_sum(y);\n\n        if (tx == 0) {\n            
dst_attn[t*args.ne21*S_v] = y*scale;\n        }\n\n        q_ptr += args.ns02;\n        k_ptr += args.ns12;\n        v_ptr += args.ns22;\n\n        b_ptr += args.ne21;\n        g_ptr += args.ne21*G;\n    }\n\n    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;\n\n    FOR_UNROLL (short j = 0; j < NSG; j++) {\n        const short is = tx*NSG + j;\n        dst_state[is] = ls[j];\n    }\n\n#undef S_v\n#undef G\n}\n\ntypedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;\n\ntemplate [[host_name(\"kernel_gated_delta_net_f32_1\")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;\ntemplate [[host_name(\"kernel_gated_delta_net_f32_2\")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;\ntemplate [[host_name(\"kernel_gated_delta_net_f32_4\")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;\n\n#else\n// a simplified version of the above\n// no performance improvement, so keep the above version for now\n\ntemplate<typename T, short NSG>\nkernel void kernel_gated_delta_net_impl(\n        constant ggml_metal_kargs_gated_delta_net & args,\n        device const char * q,\n        device const char * k,\n        device const char * v,\n        device const char * g,\n        device const char * b,\n        device const char * s,\n        device       char * dst,\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint3   ntg[[threads_per_threadgroup]])  {\n#define S_v FC_gated_delta_net_ne20\n#define G   FC_gated_delta_net_ne30\n\n    const uint tx = tpitg.x;\n    const uint ty = tpitg.y;\n\n    const uint i23 = tgpig.z; // B\n    const uint i21 = tgpig.y; // H\n    const uint i20 = tgpig.x*NSG + ty;\n\n    const uint i01 = i21 % args.ne01;\n    const uint i11 = i21 % args.ne11;\n\n    const float scale = 1.0f / sqrt((float)S_v);\n\n    device const float * s_ptr = 
(device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;\n\n    float lsf[NSG];\n\n    FOR_UNROLL (short j = 0; j < NSG; j++) {\n        const short is = tx*NSG + j;\n        lsf[j] = s_ptr[is*S_v];\n    }\n\n    thread T * ls = (thread T *) (lsf);\n\n    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;\n\n    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);\n    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);\n    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);\n\n    device const float * b_ptr  = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);\n    device const float * g_ptr  = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;\n\n    for (short t = 0; t < args.ne22; t++) {\n        device const T * qt_ptr = (device const T *) (q_ptr);\n        device const T * kt_ptr = (device const T *) (k_ptr);\n        device const T * gt_ptr = (device const T *) (g_ptr);\n\n        if (G == 1) {\n            *ls *= exp(g_ptr[0]);\n        } else {\n            // KDA\n            *ls *= exp(gt_ptr[tx]);\n        }\n\n        const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));\n\n        const float d = (v_ptr[i20] - s_k)*b_ptr[0];\n\n        *ls += kt_ptr[tx]*d;\n\n        const float y = simd_sum(dot(*ls, qt_ptr[tx]));\n\n        if (tx == 0) {\n            *dst_attn = y*scale;\n        }\n\n        q_ptr += args.ns02;\n        k_ptr += args.ns12;\n        v_ptr += args.ns22;\n\n        b_ptr += args.ne21;\n        g_ptr += args.ne21*G;\n\n        dst_attn += args.ne21*S_v;\n    }\n\n    device float * dst_state  = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;\n    device T     * dstt_state = (device T     *) (dst_state);\n\n    FOR_UNROLL (short j = 0; j < NSG; j++) {\n        const short is = tx*NSG + j;\n        
dst_state[is*S_v] = lsf[j];\n    }\n\n#undef S_v\n#undef G\n}\n\ntypedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;\n\ntemplate [[host_name(\"kernel_gated_delta_net_f32_1\")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float,  1>;\ntemplate [[host_name(\"kernel_gated_delta_net_f32_2\")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;\ntemplate [[host_name(\"kernel_gated_delta_net_f32_4\")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;\n#endif\n\nconstant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];\nconstant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];\nconstant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];\n\nkernel void kernel_solve_tri_f32(\n        constant ggml_metal_kargs_solve_tri & args,\n        device   const char * src0,\n        device   const char * src1,\n        device         char * dst,\n        threadgroup    char * shmem [[threadgroup(0)]],\n        ushort3 tgpig[[threadgroup_position_in_grid]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    constexpr short NW = N_SIMDWIDTH;\n\n    const short NSG = FC_solve_tri_nsg;\n    const short N   = FC_solve_tri_n;\n    const short K   = FC_solve_tri_k;\n    const short NP  = PAD2(N, NW);\n\n    const int32_t i03 = tgpig.z;\n    const int32_t i02 = tgpig.y;\n    const int32_t i01 = tgpig.x*NSG + sgitg;\n\n    threadgroup float * sh0 = (threadgroup float *) shmem;\n\n    device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;\n    device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;\n    device       float * dst_ptr  = (device       float *)(dst  + i02 * args.nb2  + i03 * args.nb3)  + i01;\n\n    for (short rr = 0; rr < N; rr += 
NSG) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        {\n            threadgroup float * sh0_cur = sh0 + sgitg*NP;\n\n            for (short t = 0; t*NW < N; ++t) {\n                const short idx = t*NW + tiisg;\n                sh0_cur[idx] = src0_ptr[idx];\n            }\n\n            src0_ptr += NSG*N;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (i01 >= args.ne10) {\n            continue;\n        }\n\n        for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {\n            const short r = rr + ir;\n\n            threadgroup float * sh0_cur = sh0 + ir*NP;\n\n            float sum = 0.0f;\n\n            for (short t = 0; t*NW < r; ++t) {\n                const short idx = t*NW + tiisg;\n                sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);\n            }\n\n            sum = simd_sum(sum);\n\n            if (tiisg == 0) {\n                const float diag = sh0_cur[r];\n\n                dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;\n            }\n        }\n    }\n}\n\nkernel void kernel_argmax_f32(\n        constant ggml_metal_kargs_argmax & args,\n        device   const char * src0,\n        device         char * dst,\n        threadgroup    char * shmem [[threadgroup(0)]],\n        uint  tgpig[[threadgroup_position_in_grid]],\n        uint  tpitg[[thread_position_in_threadgroup]],\n        uint  sgitg[[simdgroup_index_in_threadgroup]],\n        uint  tiisg[[thread_index_in_simdgroup]],\n        uint    ntg[[threads_per_threadgroup]]) {\n    device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);\n\n    float   lmax = -INFINITY;\n    int32_t larg = -1;\n\n    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {\n        if (x_row[i00] > lmax) {\n            lmax = x_row[i00];\n            larg = i00;\n        }\n    }\n\n    // find the argmax value in the block\n    float max_val = simd_max(lmax);\n    int32_t arg_val = 
simd_max(select(-1, larg, lmax == max_val));\n\n    device int32_t * dst_i32 = (device int32_t *) dst;\n\n    threadgroup   float * shared_maxval = (threadgroup   float *) shmem;\n    threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;\n\n    if (ntg > N_SIMDWIDTH) {\n        if (sgitg == 0) {\n            shared_maxval[tiisg] = -INFINITY;\n            shared_argmax[tiisg] = -1;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (tiisg == 0) {\n            shared_maxval[sgitg] = max_val;\n            shared_argmax[sgitg] = arg_val;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        max_val = shared_maxval[tiisg];\n        arg_val = shared_argmax[tiisg];\n\n        float max_val_reduced   = simd_max(max_val);\n        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));\n\n        dst_i32[tgpig] = arg_val_reduced;\n\n        return;\n    }\n\n    dst_i32[tgpig] = arg_val;\n}\n\n// F == 1 : norm (no fuse)\n// F == 2 : norm + mul\n// F == 3 : norm + mul + add\ntemplate <typename T, short F>\nkernel void kernel_norm_fuse_impl(\n        constant ggml_metal_kargs_norm & args,\n        device const char * src0,\n        device const char * src1_0,\n        device const char * src1_1,\n        device       char * dst,\n        threadgroup float * shmem_f32 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    if (sgitg == 0) {\n        shmem_f32[tiisg] = 0.0f;\n    }\n\n    const int i01 = tgpig.x;\n    const int i02 = tgpig.y;\n    const int i03 = tgpig.z;\n\n    device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);\n\n    device const T * f0 = (device 
const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);\n    device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);\n\n    T sumft(0.0f);\n\n    float sumf = 0.0f;\n\n    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {\n        sumft += x[i00];\n    }\n    sumf = dot(sumft, T(1.0f));\n    sumf = simd_sum(sumf);\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    if (tiisg == 0) {\n        shmem_f32[sgitg] = sumf;\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    sumf = shmem_f32[tiisg];\n    sumf = simd_sum(sumf);\n\n    const float mean = sumf/args.ne00;\n\n    device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);\n\n    sumf = 0.0f;\n    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {\n        y[i00] = x[i00] - mean;\n        sumf += dot(y[i00], y[i00]);\n    }\n    sumf = simd_sum(sumf);\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    if (tiisg == 0) {\n        shmem_f32[sgitg] = sumf;\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    sumf = shmem_f32[tiisg];\n    sumf = simd_sum(sumf);\n\n    const float variance = sumf/args.ne00;\n\n    const float scale = 1.0f/sqrt(variance + args.eps);\n    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {\n        if (F == 1) {\n            y[i00] = (y[i00]*scale);\n        }\n        if (F == 2) {\n            y[i00] = (y[i00]*scale)*f0[i00];\n        }\n        if (F == 3) {\n            y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];\n        }\n    }\n}\n\ntypedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;\n\ntemplate [[host_name(\"kernel_norm_f32\")]]         kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;\ntemplate [[host_name(\"kernel_norm_mul_f32\")]]     kernel kernel_norm_fuse_t 
kernel_norm_fuse_impl<float, 2>;\ntemplate [[host_name(\"kernel_norm_mul_add_f32\")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;\n\ntemplate [[host_name(\"kernel_norm_f32_4\")]]         kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;\ntemplate [[host_name(\"kernel_norm_mul_f32_4\")]]     kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;\ntemplate [[host_name(\"kernel_norm_mul_add_f32_4\")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;\n\n// F == 1 : rms_norm (no fuse)\n// F == 2 : rms_norm + mul\n// F == 3 : rms_norm + mul + add\ntemplate <typename T, short F>\nkernel void kernel_rms_norm_fuse_impl(\n        constant ggml_metal_kargs_norm & args,\n        device const char * src0,\n        device const char * src1_0,\n        device const char * src1_1,\n        device       char * dst,\n        threadgroup float * shmem_f32 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    if (sgitg == 0) {\n        shmem_f32[tiisg] = 0.0f;\n    }\n\n    const int i01 = tgpig.x;\n    const int i02 = tgpig.y;\n    const int i03 = tgpig.z;\n\n    device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);\n\n    device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);\n    device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);\n\n    float sumf = 0.0f;\n\n    // parallel sum\n    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {\n        sumf += dot(x[i00], x[i00]);\n    }\n    sumf = simd_sum(sumf);\n\n    
threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    if (tiisg == 0) {\n        shmem_f32[sgitg] = sumf;\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    sumf = shmem_f32[tiisg];\n    sumf = simd_sum(sumf);\n\n    const float mean  = sumf/args.ne00;\n    const float scale = 1.0f/sqrt(mean + args.eps);\n\n    device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);\n    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {\n        if (F == 1) {\n            y[i00] = (x[i00]*scale);\n        }\n        if (F == 2) {\n            y[i00] = (x[i00]*scale)*f0[i00];\n        }\n        if (F == 3) {\n            y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];\n        }\n    }\n}\n\ntypedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;\n\ntemplate [[host_name(\"kernel_rms_norm_f32\")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;\ntemplate [[host_name(\"kernel_rms_norm_mul_f32\")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;\ntemplate [[host_name(\"kernel_rms_norm_mul_add_f32\")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;\n\ntemplate [[host_name(\"kernel_rms_norm_f32_4\")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;\ntemplate [[host_name(\"kernel_rms_norm_mul_f32_4\")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;\ntemplate [[host_name(\"kernel_rms_norm_mul_add_f32_4\")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;\n\ntemplate <typename T0, typename T>\nkernel void kernel_l2_norm_impl(\n        constant ggml_metal_kargs_l2_norm & args,\n        device const char * src0,\n        device       char * dst,\n        threadgroup float * shmem_f32 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n       
 ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int i03 = tgpig.z;\n    const int i02 = tgpig.y;\n    const int i01 = tgpig.x;\n\n    if (sgitg == 0) {\n        shmem_f32[tiisg] = 0.0f;\n    }\n\n    device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);\n    device       T  * y = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1);\n\n    float sumf = 0.0f;\n\n    // parallel sum\n    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {\n        sumf += dot(x[i00], x[i00]);\n    }\n    sumf = simd_sum(sumf);\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    if (tiisg == 0) {\n        shmem_f32[sgitg] = sumf;\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    sumf = shmem_f32[tiisg];\n    sumf = simd_sum(sumf);\n\n    const float scale = 1.0f/max(sqrt(sumf), args.eps);\n\n    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {\n        y[i00] = x[i00] * scale;\n    }\n}\n\ntypedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;\n\ntemplate [[host_name(\"kernel_l2_norm_f32_f32\")]]   kernel kernel_l2_norm_t kernel_l2_norm_impl<float,  float>;\ntemplate [[host_name(\"kernel_l2_norm_f32_f32_4\")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;\n\nkernel void kernel_group_norm_f32(\n        constant ggml_metal_kargs_group_norm & args,\n        device const float * src0,\n        device       float * dst,\n        threadgroup float  * buf [[threadgroup(0)]],\n        uint tgpig[[threadgroup_position_in_grid]],\n        uint tpitg[[thread_position_in_threadgroup]],\n        uint sgitg[[simdgroup_index_in_threadgroup]],\n        uint tiisg[[thread_index_in_simdgroup]],\n        uint   ntg[[threads_per_threadgroup]]) {\n    const int64_t ne = args.ne00*args.ne01*args.ne02;\n    const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);\n\n    int start = 
tgpig * gs;\n    int end   = start + gs;\n\n    start += tpitg;\n\n    if (end >= ne) {\n        end = ne;\n    }\n\n    float tmp = 0.0f; // partial sum for thread in warp\n\n    for (int j = start; j < end; j += ntg) {\n        tmp += src0[j];\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    tmp = simd_sum(tmp);\n    if (ntg > N_SIMDWIDTH) {\n        if (sgitg == 0) {\n            buf[tiisg] = 0.0f;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (tiisg == 0) {\n            buf[sgitg] = tmp;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        tmp = buf[tiisg];\n        tmp = simd_sum(tmp);\n    }\n\n    const float mean = tmp / gs;\n    tmp = 0.0f;\n\n    for (int j = start; j < end; j += ntg) {\n        float xi = src0[j] - mean;\n        dst[j] = xi;\n        tmp += xi * xi;\n    }\n\n    tmp = simd_sum(tmp);\n    if (ntg > N_SIMDWIDTH) {\n        if (sgitg == 0) {\n            buf[tiisg] = 0.0f;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (tiisg == 0) {\n            buf[sgitg] = tmp;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        tmp = buf[tiisg];\n        tmp = simd_sum(tmp);\n    }\n\n    const float variance = tmp / gs;\n    const float scale = 1.0f/sqrt(variance + args.eps);\n    for (int j = start; j < end; j += ntg) {\n        dst[j] *= scale;\n    }\n}\n\n// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])\n// il indicates where the q4 quants begin (0 or QK4_0/4)\n// we assume that the yl's have been multiplied with the appropriate scale factor\n// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)\ninline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {\n    float d = qb_curr->d;\n\n    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };\n\n    device const uint16_t * qs = ((device 
const uint16_t *) qb_curr + 1 + il/2);\n\n    for (int i = 0; i < 8; i += 2) {\n        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);\n        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);\n        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);\n        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);\n    }\n\n    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);\n}\n\n// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])\n// il indicates where the q4 quants begin (0 or QK4_0/4)\n// we assume that the yl's have been multiplied with the appropriate scale factor\n// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)\ninline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {\n    float d = qb_curr->d;\n    float m = qb_curr->m;\n\n    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };\n\n    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);\n\n    for (int i = 0; i < 8; i+=2) {\n        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);\n        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);\n        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);\n        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);\n    }\n\n    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;\n}\n\n// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])\n// il indicates where the q5 quants begin (0 or QK5_0/4)\n// we assume that the yl's have been multiplied with the appropriate scale factor\n// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)\ninline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {\n    float d = qb_curr->d;\n\n    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };\n\n    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);\n           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);\n\n    for (int i = 0; i 
< 8; i+=2) {\n        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));\n        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));\n        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));\n        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));\n    }\n\n    return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);\n}\n\n// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])\n// il indicates where the q5 quants begin (0 or QK5_1/4)\n// we assume that the yl's have been multiplied with the appropriate scale factor\n// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)\ninline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {\n    float d = qb_curr->d;\n    float m = qb_curr->m;\n\n    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };\n\n    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);\n           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);\n\n    for (int i = 0; i < 8; i+=2) {\n        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));\n        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));\n        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));\n        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));\n    }\n\n    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;\n}\n\ntemplate<short NR0>\nstatic inline void helper_mv_reduce_and_write(\n        device float * dst_f32,\n        float sumf[NR0],\n        const int r0,\n        const int ne01,\n        ushort tiisg,\n        ushort sgitg,\n        threadgroup char * shmem) {\n    constexpr short NW = 
N_SIMDWIDTH;\n\n    threadgroup float * shmem_f32[NR0];\n\n    for (short row = 0; row < NR0; ++row) {\n        shmem_f32[row] = (threadgroup float *) shmem + NW*row;\n\n        if (sgitg == 0) {\n            shmem_f32[row][tiisg] = 0.0f;\n        }\n\n        sumf[row] = simd_sum(sumf[row]);\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    for (short row = 0; row < NR0; ++row) {\n        if (tiisg == 0) {\n            shmem_f32[row][sgitg] = sumf[row];\n        }\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    for (short row = 0; row < NR0 && r0 + row < ne01; ++row) {\n        float tot = simd_sum(shmem_f32[row][tiisg]);\n\n        if (tiisg == 0 && sgitg == 0) {\n            dst_f32[r0 + row] = tot;\n        }\n    }\n}\n\nconstant short FC_mul_mv_nsg   [[function_constant(FC_MUL_MV + 0)]];\nconstant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];\n\ntemplate<typename block_q_type, short NR0, typename args_t>\nvoid mul_vec_q_n_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    constexpr short NW = N_SIMDWIDTH;\n    constexpr short NQ = 16;\n\n    const int nb = args.ne00/QK4_0;\n\n    const int r0 = (tgpig.x*NSG + sgitg)*NR0;\n  //const int r0 =  tgpig.x*NR0;\n    const int r1 =  tgpig.y;\n    const int im =  tgpig.z;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n  //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n  //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);\n    device const float        * y = (device const float        *) (src1 + offset1);\n\n    // pointers to src0 
rows\n    device const block_q_type * ax[NR0];\n    FOR_UNROLL (int row = 0; row < NR0; ++row) {\n        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n\n        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);\n    }\n\n    float sumf[NR0] = {0.f};\n\n    const short ix = (tiisg/(NW/NQ));\n    const short il = (tiisg%(NW/NQ))*8;\n\n    //const int ib0 = sgitg*NQ + ix;\n    const int ib0 = ix;\n\n    float yl[16]; // src1 vector cache\n\n    //device const float * yb = y + ix*QK4_0 + il;\n    device const float * yb = y + ib0*QK4_0 + il;\n\n    // each thread in a SIMD group deals with half a block.\n    //for (int ib = ib0; ib < nb; ib += NSG*NQ) {\n    for (int ib = ib0; ib < nb; ib += NQ) {\n        float sumy[2] = { 0.f, 0.f };\n\n        FOR_UNROLL (short i = 0; i < 8; i += 2) {\n            sumy[0]  += yb[i +  0] + yb[i +  1];\n            yl[i + 0] = yb[i +  0];\n            yl[i + 1] = yb[i +  1]/256.f;\n\n            sumy[1]  += yb[i + 16] + yb[i + 17];\n            yl[i + 8] = yb[i + 16]/16.f;\n            yl[i + 9] = yb[i + 17]/4096.f;\n        }\n\n        FOR_UNROLL (short row = 0; row < NR0; row++) {\n            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);\n        }\n\n        yb += QK4_0 * 16;\n        //yb += NSG*NQ*QK4_0;\n    }\n\n    device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;\n\n    //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);\n\n    for (int row = 0; row < NR0; ++row) {\n        const float tot = simd_sum(sumf[row]);\n\n        if (tiisg == 0 && r0 + row < args.ne01) {\n            dst_f32[r0 + row] = tot;\n        }\n    }\n}\n\nkernel void kernel_mul_mv_q4_0_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem 
[[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\nkernel void kernel_mul_mv_q4_1_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\nkernel void kernel_mul_mv_q5_0_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\nkernel void kernel_mul_mv_q5_1_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, 
sgitg);\n}\n\ntemplate<short NR0, typename args_t>\nvoid kernel_mul_mv_q8_0_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    constexpr short NW = N_SIMDWIDTH;\n    constexpr short NQ = 8;\n\n    const int nb = args.ne00/QK8_0;\n\n    const int r0 = tgpig.x*NR0;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n  //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n  //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);\n    device const float      * y = (device const float      *) (src1 + offset1);\n\n    // pointers to src0 rows\n    device const block_q8_0 * ax[NR0];\n    FOR_UNROLL (short row = 0; row < NR0; ++row) {\n        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n\n        ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);\n    }\n\n    float sumf[NR0] = { 0.f };\n\n    const short ix = tiisg/(NW/NQ);\n    const short il = tiisg%(NW/NQ);\n\n    const int ib0 = sgitg*NQ + ix;\n\n    float yl[NQ];\n\n    device const float * yb = y + ib0*QK8_0 + il*NQ;\n\n    // each thread in a SIMD group deals with NQ quants at a time\n    for (int ib = ib0; ib < nb; ib += NSG*NQ) {\n        for (short i = 0; i < NQ; ++i) {\n            yl[i] = yb[i];\n        }\n\n        for (short row = 0; row < NR0; row++) {\n            device const int8_t * qs = ax[row][ib].qs + il*NQ;\n\n            float sumq = 0.f;\n            FOR_UNROLL (short i = 0; i < NQ; ++i) {\n                sumq += qs[i] * yl[i];\n            }\n\n    
        sumf[row] += sumq*ax[row][ib].d;\n        }\n\n        yb += NSG*NQ*QK8_0;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);\n}\n\n[[host_name(\"kernel_mul_mv_q8_0_f32\")]]\nkernel void kernel_mul_mv_q8_0_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\n// mat-vec kernel processing in chunks of float4\n// chpb - chunks per quantization block\ntemplate<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >\nvoid kernel_mul_mv_ext_q4_f32_impl(\n        constant ggml_metal_kargs_mul_mv_ext & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {\n    const short NSG   = FC_mul_mv_nsg;\n    const short nxpsg = FC_mul_mv_nxpsg;\n\n    const short chpt = 4; // chunks per thread\n\n  //const short nxpsg = (32);\n    const short nypsg = (32/nxpsg);\n\n    const short tx = tiisg%nxpsg;\n    const short ty = tiisg/nxpsg;\n\n    const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;\n    const int i11 = tgpig.y*r1ptg;\n    const int i1m = tgpig.z;\n\n    const int i12 = i1m%args.ne12;\n    const int i13 = i1m/args.ne12;\n\n    const uint64_t offset0 = i01*args.nb01 + 
(i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;\n\n    device const float4 * y4[r1ptg];\n\n    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {\n        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;\n    }\n\n    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };\n\n    short cch = tx%chpb; // current chunk index\n\n    for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {\n        float4 lx[chpt];\n\n#pragma unroll(chpt)\n        for (short ch = 0; ch < chpt; ++ch) {\n            deq_t4(xq, cch, lx[ch]);\n\n            cch += nxpsg;\n            if (cch >= chpb) {\n                xq  += cch/chpb;\n                cch %= chpb;\n            }\n        }\n\n#pragma unroll(chpt)\n        for (short ch = 0; ch < chpt; ++ch) {\n#pragma unroll(r1ptg)\n            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {\n                sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);\n            }\n        }\n\n#pragma unroll(r1ptg)\n        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {\n            y4[ir1] += chpt*nxpsg;\n        }\n    }\n\n    // reduce only the threads in each row\n    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {\n        if (nxpsg >= 32) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);\n        }\n        if (nxpsg >= 16) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);\n        }\n        if (nxpsg >= 8) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);\n        }\n        if (nxpsg >= 4) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);\n        }\n        if (nxpsg >= 2) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);\n        }\n\n        //sumf[ir1] = simd_sum(sumf[ir1]);\n    }\n\n    if (tx == 0) {\n 
       for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {\n            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;\n\n            if (i01 < args.ne01) {\n                dst_f32[i01] = sumf[ir1];\n            }\n        }\n    }\n}\n\n// mat-vec kernel processing in chunks of float4x4\ntemplate<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >\nvoid kernel_mul_mv_ext_q4x4_f32_impl(\n        constant ggml_metal_kargs_mul_mv_ext & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {\n    const short NSG   = FC_mul_mv_nsg;\n    const short nxpsg = FC_mul_mv_nxpsg;\n\n    const short chpt = 1;\n\n  //const short nxpsg = (32);\n    const short nypsg = (32/nxpsg);\n\n    const short tx = tiisg%nxpsg;\n    const short ty = tiisg/nxpsg;\n\n    const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;\n    const int i11 = tgpig.y*r1ptg;\n    const int i1m = tgpig.z;\n\n    const int i12 = i1m%args.ne12;\n    const int i13 = i1m/args.ne12;\n\n    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;\n\n    device const float4x4 * y4x4[r1ptg];\n\n    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {\n        y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;\n    }\n\n    float sumf[r1ptg] = { [ 0 ... 
r1ptg - 1 ] = 0.0f };\n\n    short cch = tx%chpb;\n\n    for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {\n        float4x4 lx[chpt];\n\n#pragma unroll(chpt)\n        for (short ch = 0; ch < chpt; ++ch) {\n            deq_t4x4(xq, cch, lx[ch]);\n\n            cch += nxpsg;\n            if (cch >= chpb) {\n                xq  += cch/chpb;\n                cch %= chpb;\n            }\n        }\n\n#pragma unroll(chpt)\n        for (short ch = 0; ch < chpt; ++ch) {\n#pragma unroll(r1ptg)\n            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {\n                sumf[ir1] +=\n                    dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +\n                    dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +\n                    dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +\n                    dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);\n\n            }\n        }\n\n#pragma unroll(r1ptg)\n        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {\n            y4x4[ir1] += chpt*nxpsg;\n        }\n    }\n\n    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {\n        if (nxpsg >= 32) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);\n        }\n        if (nxpsg >= 16) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);\n        }\n        if (nxpsg >= 8) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);\n        }\n        if (nxpsg >= 4) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);\n        }\n        if (nxpsg >= 2) {\n            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);\n        }\n\n        //sumf[ir1] = simd_sum(sumf[ir1]);\n    }\n\n    if (tx == 0) {\n        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {\n            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;\n\n            if (i01 < args.ne01) {\n                dst_f32[i01] = sumf[ir1];\n            }\n        }\n    }\n}\n\n// dispatchers needed for compile-time nxpsg\n// epb - elements per 
quantization block\ntemplate<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>\nkernel void kernel_mul_mv_ext_q4_f32_disp(\n        constant ggml_metal_kargs_mul_mv_ext & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {\n    kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);\n}\n\ntemplate<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>\nkernel void kernel_mul_mv_ext_q4x4_f32_disp(\n        constant ggml_metal_kargs_mul_mv_ext & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {\n    kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);\n}\n\ntypedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32,  dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;\ntypedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)    mul_mv_ext_q4x4_f32_t;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_f32_f32_r1_2\")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4,       4,  dequantize_f32_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_f32_f32_r1_3\")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4,       4,  dequantize_f32_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_f32_f32_r1_4\")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4,       4,  dequantize_f32_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_f32_f32_r1_5\")]]    kernel 
mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4,       4,  dequantize_f32_t4>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_f16_f32_r1_2\")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4,        4,  dequantize_f16_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_f16_f32_r1_3\")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4,        4,  dequantize_f16_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_f16_f32_r1_4\")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4,        4,  dequantize_f16_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_f16_f32_r1_5\")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4,        4,  dequantize_f16_t4>;\n\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_mul_mv_ext_bf16_f32_r1_2\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4,      4,  dequantize_bf16_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_bf16_f32_r1_3\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4,      4,  dequantize_bf16_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_bf16_f32_r1_4\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4,      4,  dequantize_bf16_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_bf16_f32_r1_5\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4,      4,  dequantize_bf16_t4>;\n#endif\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_0_f32_r1_2\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0,   32, dequantize_q4_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_0_f32_r1_3\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0,   32, dequantize_q4_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_0_f32_r1_4\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0,   32, dequantize_q4_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_0_f32_r1_5\")]]   kernel 
mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0,   32, dequantize_q4_0_t4>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_1_f32_r1_2\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1,   32, dequantize_q4_1_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_1_f32_r1_3\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1,   32, dequantize_q4_1_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_1_f32_r1_4\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1,   32, dequantize_q4_1_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_1_f32_r1_5\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1,   32, dequantize_q4_1_t4>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_0_f32_r1_2\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0,   32, dequantize_q5_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_0_f32_r1_3\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0,   32, dequantize_q5_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_0_f32_r1_4\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0,   32, dequantize_q5_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_0_f32_r1_5\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0,   32, dequantize_q5_0_t4>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_1_f32_r1_2\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1,   32, dequantize_q5_1_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_1_f32_r1_3\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1,   32, dequantize_q5_1_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_1_f32_r1_4\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1,   32, dequantize_q5_1_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_1_f32_r1_5\")]]   kernel mul_mv_ext_q4_f32_t 
kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1,   32, dequantize_q5_1_t4>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q8_0_f32_r1_2\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0,   32, dequantize_q8_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q8_0_f32_r1_3\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0,   32, dequantize_q8_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q8_0_f32_r1_4\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q8_0_f32_r1_5\")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_mxfp4_f32_r1_2\")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4,  32, dequantize_mxfp4_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_mxfp4_f32_r1_3\")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4,  32, dequantize_mxfp4_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_mxfp4_f32_r1_4\")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4,  32, dequantize_mxfp4_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_mxfp4_f32_r1_5\")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4,  32, dequantize_mxfp4_t4>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_iq4_nl_f32_r1_2\")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_iq4_nl_f32_r1_3\")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_iq4_nl_f32_r1_4\")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;\ntemplate [[host_name(\"kernel_mul_mv_ext_iq4_nl_f32_r1_5\")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, 
block_iq4_nl, 32, dequantize_iq4_nl_t4>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_K_f32_r1_2\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_K_f32_r1_3\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_K_f32_r1_4\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q4_K_f32_r1_5\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_K_f32_r1_2\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_K_f32_r1_3\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_K_f32_r1_4\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q5_K_f32_r1_5\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q6_K_f32_r1_2\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q6_K_f32_r1_3\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q6_K_f32_r1_4\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q6_K_f32_r1_5\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;\n\ntemplate 
[[host_name(\"kernel_mul_mv_ext_q2_K_f32_r1_2\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q2_K_f32_r1_3\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q2_K_f32_r1_4\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q2_K_f32_r1_5\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;\n\ntemplate [[host_name(\"kernel_mul_mv_ext_q3_K_f32_r1_2\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q3_K_f32_r1_3\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q3_K_f32_r1_4\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;\ntemplate [[host_name(\"kernel_mul_mv_ext_q3_K_f32_r1_5\")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;\n\ntemplate<typename T0, typename T1, short NR0, typename args_t>\nvoid kernel_mul_mv_t_t_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    constexpr short NW = N_SIMDWIDTH;\n    constexpr short NB = 32;\n    constexpr short NF = 8;\n\n    const int nb = args.ne00/NB;\n\n    const int r0 = tgpig.x*NR0;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n  //const uint64_t offset0 = r0*args.nb01 + 
(i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n  //device const T0 * x = (device const T0 *) (src0 + offset0);\n    device const T1 * y = (device const T1 *) (src1 + offset1);\n\n    // pointers to src0 rows\n    device const T0 * ax [NR0];\n    FOR_UNROLL (short row = 0; row < NR0; ++row) {\n        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n\n        ax[row] = (device const T0 *) ((device char *) src0 + offset0);\n    }\n\n    float sumf[NR0] = { 0.f };\n\n    const short ix = tiisg/(NW/NF);\n    const short il = tiisg%(NW/NF);\n\n    const int ib0 = sgitg*NF + ix;\n\n    T1 yl[NF];\n\n    device const T1 * yb = y + (ib0*NB + il*NF);\n\n    for (int ib = ib0; ib < nb; ib += NSG*NF) {\n        for (short i = 0; i < NF; ++i) {\n            yl[i] = yb[i];\n        }\n\n        for (short row = 0; row < NR0; row++) {\n            device const T0 * xb = ax[row] + (ib*NB + il*NF);\n\n            float sumq = 0.f;\n            FOR_UNROLL (short i = 0; i < NF; ++i) {\n                sumq += xb[i] * yl[i];\n            }\n\n            sumf[row] += sumq;\n        }\n\n        yb += NSG*NF*NW;\n    }\n\n    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {\n        for (short row = 0; row < NR0; row++) {\n            sumf[row] += ax[row][i] * y[i];\n        }\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);\n}\n\ntemplate<typename T0, typename T1, typename args_t>\nvoid kernel_mul_mv_t_t_disp(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    switch 
(args.nr0) {\n      //case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;\n        case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;\n      //case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;\n      //case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;\n    }\n}\n\ntemplate<typename T0, typename T1>\nkernel void kernel_mul_mv_t_t(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntypedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;\n\ntemplate [[host_name(\"kernel_mul_mv_f32_f32\")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;\ntemplate [[host_name(\"kernel_mul_mv_f16_f32\")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float>;\ntemplate [[host_name(\"kernel_mul_mv_f16_f16\")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_mul_mv_bf16_f32\")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;\ntemplate [[host_name(\"kernel_mul_mv_bf16_bf16\")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;\n#endif\n\ntemplate<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>\nvoid kernel_mul_mv_t_t_4_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  
tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    constexpr short NW = N_SIMDWIDTH;\n    constexpr short NB  = 32;\n    constexpr short NF  = 16;\n    constexpr short NF4 = NF/4;\n\n    const int nb = args.ne00/NB;\n\n    const int r0 = tgpig.x*NR0;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n  //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const T1  * y  = (device const T1  *) (src1 + offset1);\n    device const T14 * y4 = (device const T14 *) (src1 + offset1);\n\n    // pointers to src0 rows\n    device const T0  * ax [NR0];\n    device const T04 * ax4[NR0];\n    FOR_UNROLL (short row = 0; row < NR0; ++row) {\n        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n\n        ax [row] = (device const T0  *) ((device char *) src0 + offset0);\n        ax4[row] = (device const T04 *) ((device char *) src0 + offset0);\n    }\n\n    float sumf[NR0] = { 0.f };\n\n    const short ix = tiisg/(NW/NF);\n    const short il = tiisg%(NW/NF);\n\n    const int ib0 = sgitg*NF + ix;\n\n    T14 yl4[NF4];\n\n    device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;\n\n    for (int ib = ib0; ib < nb; ib += NSG*NF) {\n        for (short i = 0; i < NF4; ++i) {\n            yl4[i] = yb4[i];\n        }\n\n        for (short row = 0; row < NR0; row++) {\n            device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;\n\n            float sumq = 0.f;\n            FOR_UNROLL (short i = 0; i < NF4; ++i) {\n                sumq += dot(float4(xb4[i]), float4(yl4[i]));\n            }\n\n            sumf[row] += sumq;\n        }\n\n        yb4 += NSG*NF*NW/4;\n    }\n\n    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {\n        for 
(short row = 0; row < NR0; row++) {\n            sumf[row] += ax[row][i] * y[i];\n        }\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);\n}\n\ntemplate<typename T0, typename T04, typename T1, typename T14, typename args_t>\nvoid kernel_mul_mv_t_t_4_disp(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    switch (args.nr0) {\n      //case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;\n        case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;\n      //case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;\n      //case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;\n    };\n}\n\ntemplate<typename T0, typename T04, typename T1, typename T14>\nkernel void kernel_mul_mv_t_t_4(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntypedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;\n\ntemplate [[host_name(\"kernel_mul_mv_f32_f32_4\")]]   kernel mul_mv_t_t_4 
kernel_mul_mv_t_t_4<float, float4, float, float4>;\ntemplate [[host_name(\"kernel_mul_mv_f16_f32_4\")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4>;\ntemplate [[host_name(\"kernel_mul_mv_f16_f16_4\")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_mul_mv_bf16_f32_4\")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4>;\ntemplate [[host_name(\"kernel_mul_mv_bf16_bf16_4\")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;\n#endif\n\ntemplate<typename T0, typename T1, typename args_t>\nvoid kernel_mul_mv_t_t_short_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig,\n        ushort tiisg) {\n    const int r0 = tgpig.x*32 + tiisg;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    if (r0 >= args.ne01) {\n        return;\n    }\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n\n    device const T0 * x = (device const T0 *) (src0 + offset0);\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;\n\n    const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;\n\n    device const T1 * y = (device const T1 *) (src1 + offset1);\n\n    float res = 0.0f;\n\n    for (int i = 0; i < args.ne00; ++i) {\n        res += (float) x[i] * (float) y[i];\n    }\n\n    dst_f32[(uint64_t)r1*args.ne0 + r0] = res;\n}\n\ntemplate<typename T0, typename T1>\nkernel void kernel_mul_mv_t_t_short(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort 
tiisg[[thread_index_in_simdgroup]]) {\n    kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(\n        args,\n        src0,\n        src1,\n        dst,\n        tgpig,\n        tiisg);\n}\n\ntypedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;\n\ntemplate [[host_name(\"kernel_mul_mv_f32_f32_short\")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>;\ntemplate [[host_name(\"kernel_mul_mv_f16_f32_short\")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,  float>;\ntemplate [[host_name(\"kernel_mul_mv_f16_f16_short\")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,  half>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_mul_mv_bf16_f32_short\")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;\ntemplate [[host_name(\"kernel_mul_mv_bf16_bf16_short\")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;\n#endif\n\nconstant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];\n\nstatic float rope_yarn_ramp(const float low, const float high, const int i0) {\n    const float y = (i0 / 2 - low) / max(0.001f, high - low);\n    return 1.0f - min(1.0f, max(0.0f, y));\n}\n\n// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn\n// MIT licensed. 
Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.\nstatic void rope_yarn(\n    float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,\n    thread float * cos_theta, thread float * sin_theta) {\n    // Get n-d rotational scaling corrected for extrapolation\n    float theta_interp = freq_scale * theta_extrap;\n    float theta = theta_interp;\n    if (ext_factor != 0.0f) {\n        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;\n        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;\n\n        // Get n-d magnitude scaling corrected for interpolation\n        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);\n    }\n    *cos_theta = cos(theta) * mscale;\n    *sin_theta = sin(theta) * mscale;\n}\n\n// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get\n// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`\nstatic float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {\n    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));\n}\n\nstatic void rope_yarn_corr_dims(\n    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]\n) {\n    // start and end correction dims\n    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));\n    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));\n}\n\ntemplate<typename T>\nkernel void kernel_rope_norm(\n        constant ggml_metal_kargs_rope & args,\n        device const char * src0,\n        device const char * src1,\n        device const char * src2,\n        device       char * dst,\n        ushort  tiitg[[thread_index_in_threadgroup]],\n        ushort3 tptg [[threads_per_threadgroup]],\n        uint3   tgpig[[threadgroup_position_in_grid]]) {\n    const int i3 = tgpig[2];\n    const int i2 = 
tgpig[1];\n    const int i1 = tgpig[0];\n\n    float corr_dims[2];\n    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);\n\n    device const int32_t * pos = (device const int32_t *) src1;\n\n    const float theta_base = (float) pos[i2];\n    const float inv_ndims = -1.f/args.n_dims;\n\n    float cos_theta;\n    float sin_theta;\n\n    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {\n        if (i0 < args.n_dims) {\n            const int ic = i0/2;\n\n            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);\n\n            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;\n\n            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);\n\n            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);\n            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);\n\n            const float x0 = src[0];\n            const float x1 = src[1];\n\n            dst_data[0] = x0*cos_theta - x1*sin_theta;\n            dst_data[1] = x0*sin_theta + x1*cos_theta;\n        } else {\n            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);\n            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\ntemplate<typename T>\nkernel void kernel_rope_neox(\n        constant ggml_metal_kargs_rope & args,\n        device const char * src0,\n        device const char * src1,\n        device const char * src2,\n        device       char * dst,\n        ushort  tiitg[[thread_index_in_threadgroup]],\n        ushort3 tptg [[threads_per_threadgroup]],\n        uint3   
tgpig[[threadgroup_position_in_grid]]) {\n    const int i3 = tgpig[2];\n    const int i2 = tgpig[1];\n    const int i1 = tgpig[0];\n\n    float corr_dims[2];\n    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);\n\n    device const int32_t * pos = (device const int32_t *) src1;\n\n    const float theta_base = (float) pos[i2];\n    const float inv_ndims = -1.f/args.n_dims;\n\n    float cos_theta;\n    float sin_theta;\n\n    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {\n        if (i0 < args.n_dims) {\n            const int ic = i0/2;\n\n            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);\n\n            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;\n\n            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);\n\n            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);\n            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);\n\n            const float x0 = src[0];\n            const float x1 = src[args.n_dims/2];\n\n            dst_data[0]             = x0*cos_theta - x1*sin_theta;\n            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;\n        } else {\n            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);\n            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\ntemplate<typename T>\nkernel void kernel_rope_multi(\n        constant ggml_metal_kargs_rope & args,\n        device const char * src0,\n        device const char * src1,\n        device const char * src2,\n        device       char * dst,\n   
     ushort  tiitg[[thread_index_in_threadgroup]],\n        ushort3 tptg [[threads_per_threadgroup]],\n        uint3   tgpig[[threadgroup_position_in_grid]]) {\n    const int i3 = tgpig[2];\n    const int i2 = tgpig[1];\n    const int i1 = tgpig[0];\n\n    float corr_dims[2];\n    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);\n\n    device const int32_t * pos = (device const int32_t *) src1;\n\n    const float inv_ndims = -1.f/args.n_dims;\n\n    float cos_theta;\n    float sin_theta;\n\n    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {\n        if (i0 < args.n_dims) {\n            const int ic = i0/2;\n\n            // mrope theta calculations\n            // note: the rest is the same as kernel_rope_neox\n            const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;\n            const int sec_w01   = args.sect_0 + args.sect_1;               // end of section 1\n            const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2\n            const int sector    = ic % sect_dims;\n\n            float theta_base;\n            if (FC_rope_is_imrope) {\n                if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h\n                    theta_base = (float) pos[i2 + args.ne02 * 1];\n                } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w\n                    theta_base = (float) pos[i2 + args.ne02 * 2];\n                } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t\n                    theta_base = (float) pos[i2 + args.ne02 * 0];\n                } else { // e\n                    theta_base = (float) pos[i2 + args.ne02 * 3];\n                }\n            } else {\n                if (sector < args.sect_0) {\n                    theta_base = (float) pos[i2];\n                } else if (sector < sec_w01) {\n                    theta_base = (float) pos[i2 + args.ne02 * 1];\n                } 
else if (sector < sec_w012) {\n                    theta_base = (float) pos[i2 + args.ne02 * 2];\n                } else {\n                    theta_base = (float) pos[i2 + args.ne02 * 3];\n                }\n            }\n            // end of mrope\n\n            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);\n\n            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;\n\n            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);\n\n            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);\n            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);\n\n            const float x0 = src[0];\n            const float x1 = src[args.n_dims/2];\n\n            dst_data[0]             = x0*cos_theta - x1*sin_theta;\n            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;\n        } else {\n            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);\n            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\ntemplate<typename T>\nkernel void kernel_rope_vision(\n        constant ggml_metal_kargs_rope & args,\n        device const char * src0,\n        device const char * src1,\n        device const char * src2,\n        device       char * dst,\n        ushort  tiitg[[thread_index_in_threadgroup]],\n        ushort3 tptg [[threads_per_threadgroup]],\n        uint3   tgpig[[threadgroup_position_in_grid]]) {\n    const int i3 = tgpig[2];\n    const int i2 = tgpig[1];\n    const int i1 = tgpig[0];\n\n    float corr_dims[2];\n    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, 
args.beta_fast, args.beta_slow, corr_dims);\n\n    device const int32_t * pos = (device const int32_t *) src1;\n\n    const float inv_ndims = -1.f/args.n_dims;\n\n    float cos_theta;\n    float sin_theta;\n\n    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {\n        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi\n            const int ic = i0/2;\n\n            // mrope theta calculations (only support 2 dimensions)\n            const int sect_dims = args.sect_0 + args.sect_1;\n            const int sector    = ic % sect_dims;\n\n            // sectors [0, sect_0) rotate with the first position block pos[i2],\n            // sectors [sect_0, sect_dims) with the second block pos[i2 + ne02].\n            // the split point must be sect_0: the else branch re-bases p by\n            // subtracting sect_0, so each section's exponent starts again at 0\n            float p;\n            float theta_base;\n            if (sector < args.sect_0) {\n                p = (float) sector;\n                theta_base = (float) pos[i2];\n            } else {\n                p = (float) sector - args.sect_0;\n                theta_base = (float) pos[i2 + args.ne02];\n            }\n\n            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);\n            // end of mrope\n\n            const float freq_factor = args.src2 ? 
((device const float *) src2)[ic] : 1.0f;\n\n            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);\n\n            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);\n            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);\n\n            const float x0 = src[0];\n            const float x1 = src[args.n_dims]; // different from kernel_rope_multi\n\n            dst_data[0]           = x0*cos_theta - x1*sin_theta;\n            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi\n        } else {\n            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);\n            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\ntypedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;\ntypedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;\ntypedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;\ntypedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;\n\ntemplate [[host_name(\"kernel_rope_norm_f32\")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;\ntemplate [[host_name(\"kernel_rope_norm_f16\")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;\n\ntemplate [[host_name(\"kernel_rope_neox_f32\")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;\ntemplate [[host_name(\"kernel_rope_neox_f16\")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;\n\ntemplate [[host_name(\"kernel_rope_multi_f32\")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;\ntemplate [[host_name(\"kernel_rope_multi_f16\")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;\n\ntemplate [[host_name(\"kernel_rope_vision_f32\")]] 
kernel kernel_rope_vision_t kernel_rope_vision<float>;\ntemplate [[host_name(\"kernel_rope_vision_f16\")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;\n\ntypedef void (im2col_t)(\n        constant ggml_metal_kargs_im2col & args,\n        device const float * x,\n        device        char * dst,\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3  tgpg[[threadgroups_per_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint3   ntg[[threads_per_threadgroup]]);\n\ntemplate <typename T>\nkernel void kernel_im2col(\n        constant ggml_metal_kargs_im2col & args,\n        device const float * x,\n        device        char * dst,\n        uint3 tgpig[[threadgroup_position_in_grid]],\n        uint3  tgpg[[threadgroups_per_grid]],\n        uint3 tpitg[[thread_position_in_threadgroup]],\n        uint3   ntg[[threads_per_threadgroup]]) {\n//    const int64_t IC = tgpg[0];\n    const int64_t OH = tgpg[1];\n    const int64_t OW = tgpg[2];\n\n    const int64_t KH = ntg[1];\n    const int64_t KW = ntg[2];\n\n          int64_t in  = tpitg[0];\n    const int64_t ikh = tpitg[1];\n    const int64_t ikw = tpitg[2];\n\n    const int64_t iic = tgpig[0];\n    const int64_t ioh = tgpig[1];\n    const int64_t iow = tgpig[2];\n\n    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;\n    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;\n\n    int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);\n\n    device T * pdst = (device T *) (dst);\n\n    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {\n        while (in < args.N) {\n            pdst[offset_dst] = 0.0f;\n            offset_dst += ntg[0]*args.CHW*OH*OW;\n\n            in += ntg[0];\n        }\n    } else {\n        int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;\n\n        while (in < args.N) {\n            pdst[offset_dst] = x[offset_src];\n\n            offset_dst += ntg[0]*args.CHW*OH*OW;\n     
       offset_src += ntg[0]*args.ofs0;\n\n            in += ntg[0];\n        }\n    }\n}\n\ntemplate [[host_name(\"kernel_im2col_f32\")]] kernel im2col_t kernel_im2col<float>;\ntemplate [[host_name(\"kernel_im2col_f16\")]] kernel im2col_t kernel_im2col<half>;\n\n// TODO: obsolete -- remove\n//typedef void (im2col_ext_t)(\n//        constant ggml_metal_kargs_im2col & args,\n//        device const float * x,\n//        device        char * dst,\n//        uint3 tgpig[[threadgroup_position_in_grid]],\n//        uint3  tgpg[[threadgroups_per_grid]],\n//        uint3 tpitg[[thread_position_in_threadgroup]],\n//        uint3   ntg[[threads_per_threadgroup]]);\n//\n//template <typename T>\n//kernel void kernel_im2col_ext(\n//        constant ggml_metal_kargs_im2col & args,\n//        device const float * x,\n//        device        char * dst,\n//        uint3 tgpig[[threadgroup_position_in_grid]],\n//        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW\n//        uint3 tpitg[[thread_position_in_threadgroup]],\n//        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]\n//    const int64_t KHW = (int64_t)args.KHW;\n//\n//    const int64_t d   = tgpig[0] / args.CHW;\n//    const int64_t chw = tgpig[0] % args.CHW;\n//    const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)\n//    const int64_t HW = tgpig[0] % KHW;\n//\n//    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];\n//    if (tpitg_0 >= args.N) {\n//        return;\n//    }\n//\n//    const int64_t tpitg_1 = HW / args.KW;\n//    const int64_t tpitg_2 = HW % args.KW;\n//\n//    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;\n//    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;\n//\n//    const int64_t offset_dst =\n//        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +\n//        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);\n//\n//    device T * pdst = (device T *) (dst);\n//\n//    if 
(iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {\n//        pdst[offset_dst] = 0.0f;\n//    } else {\n//        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;\n//        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];\n//    }\n//}\n//\n//template [[host_name(\"kernel_im2col_ext_f32\")]] kernel im2col_ext_t kernel_im2col_ext<float>;\n//template [[host_name(\"kernel_im2col_ext_f16\")]] kernel im2col_ext_t kernel_im2col_ext<half>;\n\ntemplate <typename TK>\nkernel void kernel_conv_2d(\n        constant ggml_metal_kargs_conv_2d & args,\n        device const char * weights,\n        device const char * src,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        uint3    tgpg[[threadgroups_per_grid]],\n        uint3   tpitg[[thread_position_in_threadgroup]],\n        uint3     ntg[[threads_per_threadgroup]]) {\n\n    const uint threads_per_tg = ntg.x * ntg.y * ntg.z;\n    const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;\n    const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;\n    const uint thread_index = tg_index * threads_per_tg + local_thread;\n    const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;\n    const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;\n\n    for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {\n        uint64_t tmp = index;\n\n        const int32_t ow = tmp % args.OW; tmp /= args.OW;\n        const int32_t oh = tmp % args.OH; tmp /= args.OH;\n        const int32_t oc = tmp % args.OC; tmp /= args.OC;\n        const int32_t  n = tmp;\n\n        float acc = 0.0f;\n\n        const int32_t base_x = ow*args.s0 - args.p0;\n        const int32_t base_y = oh*args.s1 - args.p1;\n\n        int32_t ky_start = 0;\n        if (base_y < 0) {\n            ky_start = (-base_y + args.d1 - 1)/args.d1;\n        }\n        int32_t ky_end = 
args.KH;\n        const int32_t y_max = args.IH - 1 - base_y;\n        if (y_max < 0) {\n            ky_end = ky_start;\n        } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {\n            ky_end = min(ky_end, y_max/args.d1 + 1);\n        }\n\n        int32_t kx_start = 0;\n        if (base_x < 0) {\n            kx_start = (-base_x + args.d0 - 1)/args.d0;\n        }\n        int32_t kx_end = args.KW;\n        const int32_t x_max = args.IW - 1 - base_x;\n        if (x_max < 0) {\n            kx_end = kx_start;\n        } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {\n            kx_end = min(kx_end, x_max/args.d0 + 1);\n        }\n\n        if (ky_start < ky_end && kx_start < kx_end) {\n            const uint64_t src_base_n = (uint64_t) n  * args.nb13;\n            const uint64_t w_base_oc  = (uint64_t) oc * args.nb03;\n\n            for (int32_t ic = 0; ic < args.IC; ++ic) {\n                const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;\n                const uint64_t w_base_ocic = w_base_oc  + (uint64_t) ic * args.nb02;\n\n                for (int32_t ky = ky_start; ky < ky_end; ++ky) {\n                    const int32_t iy = base_y + ky*args.d1;\n                    const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;\n                    const uint64_t w_base_row   = w_base_ocic + (uint64_t) ky * args.nb01;\n\n                    for (int32_t kx = kx_start; kx < kx_end; ++kx) {\n                        const int32_t ix = base_x + kx*args.d0;\n                        const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;\n                        const uint64_t w_offs   = w_base_row   + (uint64_t) kx * args.nb00;\n\n                        const float x = *(device const float *)(src + src_offs);\n                        const float w = (float) (*(device const TK *)(weights + w_offs));\n\n                        acc += x * w;\n                    }\n                }\n            }\n        
}\n\n        const uint64_t dst_offs =\n            (uint64_t) n  * args.nb3 +\n            (uint64_t) oc * args.nb2 +\n            (uint64_t) oh * args.nb1 +\n            (uint64_t) ow * args.nb0;\n\n        *(device float *)(dst + dst_offs) = acc;\n    }\n}\n\ntemplate [[host_name(\"kernel_conv_2d_f32_f32\")]]\nkernel void kernel_conv_2d<float>(\n        constant ggml_metal_kargs_conv_2d & args,\n        device const char * weights,\n        device const char * src,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        uint3    tgpg[[threadgroups_per_grid]],\n        uint3   tpitg[[thread_position_in_threadgroup]],\n        uint3     ntg[[threads_per_threadgroup]]);\n\ntemplate [[host_name(\"kernel_conv_2d_f16_f32\")]]\nkernel void kernel_conv_2d<half>(\n        constant ggml_metal_kargs_conv_2d & args,\n        device const char * weights,\n        device const char * src,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        uint3    tgpg[[threadgroups_per_grid]],\n        uint3   tpitg[[thread_position_in_threadgroup]],\n        uint3     ntg[[threads_per_threadgroup]]);\n\ntypedef void (conv_transpose_1d_t)(\n        constant ggml_metal_kargs_conv_transpose_1d & args,\n        device const float * src0,\n        device const float * src1,\n        device        char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        uint3    tgpg[[threadgroups_per_grid]]);\n\ntemplate <typename T>\nkernel void kernel_conv_transpose_1d(\n        constant ggml_metal_kargs_conv_transpose_1d & args,\n        device const     T * src0,\n        device const float * src1,\n        device        char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        uint3   tgpg[[threadgroups_per_grid]]) {\n\n    float v = 0.0f;\n\n    for (int64_t c = 0; c < args.IC; c++) {\n        const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];\n        const 
int32_t input_offset = c * args.IL;\n\n        for (int64_t i = 0; i < args.IL; i++) {\n            if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {\n                v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];\n            }\n        }\n    }\n\n    device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);\n\n    dst_ptr[0] = v;\n}\n\ntemplate [[host_name(\"kernel_conv_transpose_1d_f32_f32\")]]\nkernel void kernel_conv_transpose_1d<float>(\n    constant ggml_metal_kargs_conv_transpose_1d & args,\n    device const float * src0,\n    device const float * src1,\n    device        char * dst,\n    uint3   tgpig[[threadgroup_position_in_grid]],\n    uint3    tgpg[[threadgroups_per_grid]]);\n\ntemplate [[host_name(\"kernel_conv_transpose_1d_f16_f32\")]]\nkernel void kernel_conv_transpose_1d<half>(\n    constant ggml_metal_kargs_conv_transpose_1d & args,\n    device const half  * src0,\n    device const float * src1,\n    device        char * dst,\n    uint3   tgpig[[threadgroup_position_in_grid]],\n    uint3    tgpg[[threadgroups_per_grid]]);\n\n\ntypedef void (conv_transpose_2d_t)(\n        constant ggml_metal_kargs_conv_transpose_2d & args,\n        device const float * src0,\n        device const float * src1,\n        device        char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        uint3    tgpg[[threadgroups_per_grid]]);\n\ntemplate <typename T>\nkernel void kernel_conv_transpose_2d(\n        constant ggml_metal_kargs_conv_transpose_2d & args,\n        device const T * src0,\n        device const float * src1,\n        device        char * dst,\n        threadgroup float * shared_sum [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        uint3   tpitg[[thread_position_in_threadgroup]],\n        uint3     ntg[[threads_per_threadgroup]]) {\n\n    const int64_t out_x = tgpig[0];\n    const int64_t out_y = tgpig[1];\n    const 
int64_t out_c = tgpig[2];\n\n    const int64_t kw = tpitg[0];\n    const int64_t kh = tpitg[1];\n\n    float v = 0.0f;\n\n    for (int64_t in_c = 0; in_c < args.IC; in_c++) {\n        int64_t in_y = out_y - kh;\n\n        if (in_y < 0 || in_y % args.s0) continue;\n\n        in_y /= args.s0;\n\n        if (in_y >= args.IH) continue;\n\n        int64_t in_x = out_x - kw;\n\n        if (in_x < 0 || in_x % args.s0) continue;\n\n        in_x /= args.s0;\n\n        if (in_x >= args.IW) continue;\n\n        const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;\n        const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;\n\n        v += (float)src0[kernel_idx] * src1[input_idx];\n    }\n\n    const uint tid = tpitg.y * ntg.x + tpitg.x;\n    shared_sum[tid] = v;\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    if (tid == 0) {\n        float total = 0.0f;\n        const uint num_threads = ntg.x * ntg.y;\n        for (uint i = 0; i < num_threads; i++) {\n            total += shared_sum[i];\n        }\n\n        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);\n        dst_ptr[0] = total;\n    }\n}\n\ntemplate [[host_name(\"kernel_conv_transpose_2d_f32_f32\")]]\nkernel void kernel_conv_transpose_2d<float>(\n    constant ggml_metal_kargs_conv_transpose_2d & args,\n    device const float * src0,\n    device const float * src1,\n    device        char * dst,\n    threadgroup float * shared_sum [[threadgroup(0)]],\n    uint3   tgpig[[threadgroup_position_in_grid]],\n    uint3   tpitg[[thread_position_in_threadgroup]],\n    uint3     ntg[[threads_per_threadgroup]]);\n\ntemplate [[host_name(\"kernel_conv_transpose_2d_f16_f32\")]]\nkernel void kernel_conv_transpose_2d<half>(\n    constant ggml_metal_kargs_conv_transpose_2d & args,\n    device const half  * src0,\n    device const float * src1,\n    device        char * dst,\n    
threadgroup float * shared_sum [[threadgroup(0)]],\n    uint3   tgpig[[threadgroup_position_in_grid]],\n    uint3   tpitg[[thread_position_in_threadgroup]],\n    uint3     ntg[[threads_per_threadgroup]]);\n\nconstant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];\n\nkernel void kernel_upscale_nearest_f32(\n    constant ggml_metal_kargs_upscale & args,\n    device  const char * src0,\n    device        char * dst,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]]) {\n\n    const int64_t i3 = tgpig.z;\n    const int64_t i2 = tgpig.y;\n    const int64_t i1 = tgpig.x;\n\n    const int64_t i03 = i3/args.sf3;\n    const int64_t i02 = i2/args.sf2;\n    const int64_t i01 = i1/args.sf1;\n\n    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n        const int64_t i00 = i0/args.sf0;\n\n        device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);\n        device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1  +  i0*args.nb0);\n\n        dst_ptr[0] = src0_ptr[0];\n    }\n}\n\nstatic inline float bilinear_tri(float x) {\n    return MAX(0.0f, 1.0f - fabs(x));\n}\n\nkernel void kernel_upscale_bilinear_f32(\n    constant ggml_metal_kargs_upscale & args,\n    device  const char * src0,\n    device        char * dst,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]]) {\n\n    const int64_t i3 = tgpig.z;\n    const int64_t i2 = tgpig.y;\n    const int64_t i1 = tgpig.x;\n\n    const int64_t i03 = i3 / args.sf3;\n    const int64_t i02 = i2 / args.sf2;\n\n    const float   f01  = ((float)i1 + args.poffs) / args.sf1 - args.poffs;\n    const int64_t i01  = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));\n    const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 
+ 1));\n    const float   fd1  = MAX(0.0f, MIN(1.0f, f01 - (float)i01));\n\n    src0 += i03*args.nb03 + i02*args.nb02;\n\n    device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);\n\n    if (FC_upscale_aa) {\n        const float support0  = MAX(1.0f, 1.0f / args.sf0);\n        const float invscale0 = 1.0f / support0;\n        const float support1  = MAX(1.0f, 1.0f / args.sf1);\n        const float invscale1 = 1.0f / support1;\n\n        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n            const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;\n\n            int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));\n            int64_t x_max = MIN(args.ne00,  (int64_t)ceil (f00 + support0 + args.poffs));\n\n            int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));\n            int64_t y_max = MIN(args.ne01,  (int64_t)ceil (f01 + support1 + args.poffs));\n\n            float sum = 0.0f;\n            float wsum = 0.0f;\n\n            for (int64_t sy = y_min; sy < y_max; ++sy) {\n                const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);\n                for (int64_t sx = x_min; sx < x_max; ++sx) {\n                    const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);\n                    const float w  = wx * wy;\n                    const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);\n                    sum  += (*src_ptr) * w;\n                    wsum += w;\n                }\n            }\n\n            const float v = (wsum > 0.0f) ? 
(sum / wsum) : 0.0f;\n            dst_ptr[i0] = v;\n        }\n    } else {\n        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n            const float   f00  = ((float)i0 + args.poffs) / args.sf0 - args.poffs;\n            const int64_t i00  = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));\n            const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));\n            const float   fd0  = MAX(0.0f, MIN(1.0f, f00 - (float)i00));\n\n            device const float * src00 = (device const float *)(src0 + i01*args.nb01  + i00*args.nb00);\n            device const float * src10 = (device const float *)(src0 + i01*args.nb01  + i00p*args.nb00);\n            device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);\n            device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);\n\n            const float v =\n                (*src00) * (1.0f - fd0) * (1.0f - fd1) +\n                (*src10) * fd0          * (1.0f - fd1) +\n                (*src01) * (1.0f - fd0) * fd1 +\n                (*src11) * fd0          * fd1;\n\n            dst_ptr[i0] = v;\n        }\n    }\n}\n\nstatic inline float bicubic_weight1(float x) {\n    const float a = -0.75f;\n    return ((a + 2) * x - (a + 3)) * x * x + 1;\n}\n\nstatic inline float bicubic_weight2(float x) {\n    const float a = -0.75f;\n    return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;\n}\n\nkernel void kernel_upscale_bicubic_f32(\n    constant ggml_metal_kargs_upscale & args,\n    device  const char * src0,\n    device        char * dst,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]]) {\n\n    const int64_t i3 = tgpig.z;\n    const int64_t i2 = tgpig.y;\n    const int64_t i1 = tgpig.x;\n\n    const int64_t i03 = i3 / args.sf3;\n    const int64_t i02 = i2 / args.sf2;\n\n    const float   f01 = ((float)i1 + args.poffs) / args.sf1 - 
args.poffs;\n    const int64_t i01 = (int64_t)floor(f01);\n    const float   fd1 = f01 - (float)i01;\n\n    const float w_y0 = bicubic_weight2(fd1 + 1.0f);\n    const float w_y1 = bicubic_weight1(fd1);\n    const float w_y2 = bicubic_weight1(1.0f - fd1);\n    const float w_y3 = bicubic_weight2(2.0f - fd1);\n\n    const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;\n\n    device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);\n\n    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n        const float   f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;\n        const int64_t i00 = (int64_t)floor(f00);\n        const float   fd0 = f00 - (float)i00;\n\n        const float w_x0 = bicubic_weight2(fd0 + 1.0f);\n        const float w_x1 = bicubic_weight1(fd0);\n        const float w_x2 = bicubic_weight1(1.0f - fd0);\n        const float w_x3 = bicubic_weight2(2.0f - fd0);\n\n        float sum = 0.0f;\n\n        for (int dy = -1; dy <= 2; ++dy) {\n            const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));\n            const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;\n\n            for (int dx = -1; dx <= 2; ++dx) {\n                const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));\n                const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? 
w_x2 : w_x3;\n\n                const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);\n                sum += (*src_ptr) * wx * wy;\n            }\n        }\n\n        dst_ptr[i0] = sum;\n    }\n}\n\nkernel void kernel_pad_f32(\n    constant ggml_metal_kargs_pad & args,\n    device  const char * src0,\n    device        char * dst,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]]) {\n\n    const int64_t i3 = tgpig.z;\n    const int64_t i2 = tgpig.y;\n    const int64_t i1 = tgpig.x;\n\n    const int64_t i03 = i3;\n    const int64_t i02 = i2;\n    const int64_t i01 = i1;\n\n    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);\n    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);\n\n    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {\n        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n            if (i0 < args.ne00) {\n                dst_ptr[i0] = src0_ptr[i0];\n            } else {\n                dst_ptr[i0] = 0.0f;\n            }\n        }\n\n        return;\n    }\n\n    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n        dst_ptr[i0] = 0.0f;\n    }\n}\n\nkernel void kernel_pad_reflect_1d_f32(\n    constant   ggml_metal_kargs_pad_reflect_1d & args,\n    device  const char * src0,\n    device        char * dst,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3  tgpg[[threadgroups_per_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]]) {\n\n    const int64_t i3 = tgpig.z;\n    const int64_t i2 = tgpig.y;\n    const int64_t i1 = tgpig.x;\n\n    const int64_t i03 = i3;\n    const int64_t i02 = i2;\n    const int64_t i01 = i1;\n\n    device const float * src0_ptr = (device const float *) (src0 + 
i03*args.nb03 + i02*args.nb02 + i01*args.nb01);\n    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);\n\n    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {\n        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n            if (i0 < args.p0) {\n                dst_ptr[i0] = src0_ptr[args.p0 - i0];\n            } else if (i0 < args.ne0 - args.p1) {\n                dst_ptr[i0] = src0_ptr[i0 - args.p0];\n            } else {\n                dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];\n            }\n        }\n    }\n}\n\nkernel void kernel_arange_f32(\n    constant   ggml_metal_kargs_arange & args,\n    device        char * dst,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]]) {\n\n    device float * dst_ptr = (device float *) dst;\n\n    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n        dst_ptr[i0] = args.start + args.step * i0;\n    }\n}\n\nkernel void kernel_timestep_embedding_f32(\n    constant  ggml_metal_kargs_timestep_embedding & args,\n    device  const char * src0,\n    device        char * dst,\n    uint3 tgpig[[threadgroup_position_in_grid]],\n    uint3 tpitg[[thread_position_in_threadgroup]],\n    uint3   ntg[[threads_per_threadgroup]]) {\n\n    int i = tgpig.x;\n    device float * embed_data = (device float *)(dst + i*args.nb1);\n\n    int half_ = args.dim / 2;\n    for (int j = tpitg.x; j < half_; j += ntg.x) {\n        float timestep = ((device float *)src0)[i];\n        float freq = (float)exp(-log((float)args.max_period) * j / half_);\n        float arg = timestep * freq;\n        embed_data[j        ] = cos(arg);\n        embed_data[j + half_] = sin(arg);\n    }\n\n    if (args.dim % 2 != 0 && tpitg.x == 0) {\n        embed_data[2 * half_] = 0.f;\n    }\n}\n\n// bitonic sort implementation following the CUDA 
kernels as reference\ntypedef void (argsort_t)(\n        constant   ggml_metal_kargs_argsort & args,\n        device   const char * src0,\n        device      int32_t * dst,\n        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]);\n\ntemplate<ggml_sort_order order>\nkernel void kernel_argsort_f32_i32(\n        constant   ggml_metal_kargs_argsort & args,\n        device   const char * src0,\n        device      int32_t * dst,\n        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    // bitonic sort\n    const int col = tpitg[0];\n    const int ib  = tgpig[0] / args.ne01;\n\n    const int i00 = ib*ntg.x;\n    const int i01 = tgpig[0] % args.ne01;\n    const int i02 = tgpig[1];\n    const int i03 = tgpig[2];\n\n    device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);\n\n    // initialize indices\n    shmem_i32[col] = i00 + col;\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    for (int k = 2; k <= ntg.x; k *= 2) {\n        for (int j = k / 2; j > 0; j /= 2) {\n            int ixj = col ^ j;\n            if (ixj > col) {\n                if ((col & k) == 0) {\n                    if (shmem_i32[col] >= args.ne00 ||\n                       (shmem_i32[ixj] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?\n                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :\n                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))\n                    ) {\n                        SWAP(shmem_i32[col], shmem_i32[ixj]);\n                    }\n                } else {\n                    if (shmem_i32[ixj] >= args.ne00 ||\n        
               (shmem_i32[col] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?\n                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :\n                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))\n                    ) {\n                        SWAP(shmem_i32[col], shmem_i32[ixj]);\n                    }\n                }\n            }\n\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n        }\n    }\n\n    const int64_t i0 = ib*args.top_k;\n\n    // copy the result to dst without the padding\n    if (i0 + col < args.ne0 && col < args.top_k) {\n        dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;\n\n        dst[col] = shmem_i32[col];\n    }\n}\n\ntemplate [[host_name(\"kernel_argsort_f32_i32_asc\")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;\ntemplate [[host_name(\"kernel_argsort_f32_i32_desc\")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;\n\ntypedef void (argsort_merge_t)(\n        constant   ggml_metal_kargs_argsort_merge & args,\n        device const char    * src0,\n        device const int32_t * tmp,\n        device       int32_t * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]);\n\ntemplate<ggml_sort_order order>\nkernel void kernel_argsort_merge_f32_i32(\n        constant   ggml_metal_kargs_argsort_merge & args,\n        device const char    * src0,\n        device const int32_t * tmp,\n        device       int32_t * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n\n    const int im  = tgpig[0] / args.ne01;\n    const int i01 = tgpig[0] % args.ne01;\n    const int i02 = tgpig[1];\n    const int i03 = tgpig[2];\n\n    const int start = im * (2 * args.len);\n\n    const int 
len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));\n    const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));\n\n    const int total = len0 + len1;\n\n    device const int32_t * tmp0 = tmp + start\n        + i01*args.ne0\n        + i02*args.ne0*args.ne01\n        + i03*args.ne0*args.ne01*args.ne02;\n\n    device const int32_t * tmp1 = tmp0 + args.len;\n\n    dst += start\n        + i01*args.top_k\n        + i02*args.top_k*args.ne01\n        + i03*args.top_k*args.ne01*args.ne02;\n\n    device const float * src0_row = (device const float *)(src0\n        + args.nb01*i01\n        + args.nb02*i02\n        + args.nb03*i03);\n\n    if (total == 0) {\n        return;\n    }\n\n    const int chunk = (total + ntg.x - 1) / ntg.x;\n\n    const int k0 = tpitg.x * chunk;\n    const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);\n\n    if (k0 >= args.top_k) {\n        return;\n    }\n\n    if (k0 >= total) {\n        return;\n    }\n\n    int low  = k0 > len1 ? k0 - len1 : 0;\n    int high = MIN(k0, len0);\n\n    // binary-search partition (i, j) such that i + j = k\n    while (low < high) {\n        const int mid = (low + high) >> 1;\n\n        const int32_t idx0 = tmp0[mid];\n        const int32_t idx1 = tmp1[k0 - mid - 1];\n\n        const float val0 = src0_row[idx0];\n        const float val1 = src0_row[idx1];\n\n        bool take_left;\n        if (order == GGML_SORT_ORDER_ASC) {\n            take_left = (val0 <= val1);\n        } else {\n            take_left = (val0 >= val1);\n        }\n\n        if (take_left) {\n            low = mid + 1;\n        } else {\n            high = mid;\n        }\n    }\n\n    int i = low;\n    int j = k0 - i;\n\n    // keep the merge fronts into registers\n    int32_t idx0 = 0;\n    float   val0 = 0.0f;\n    if (i < len0) {\n        idx0 = tmp0[i];\n        val0 = src0_row[idx0];\n    }\n\n    int32_t idx1 = 0;\n    float   val1 = 0.0f;\n    if (j < len1) {\n        idx1 = tmp1[j];\n        val1 = 
src0_row[idx1];\n    }\n\n    for (int k = k0; k < k1; ++k) {\n        int32_t out_idx;\n\n        if (i >= len0) {\n            while (k < k1) {\n                dst[k++] = tmp1[j++];\n            }\n            break;\n        } else if (j >= len1) {\n            while (k < k1) {\n                dst[k++] = tmp0[i++];\n            }\n            break;\n        } else {\n            bool take_left;\n\n            if (order == GGML_SORT_ORDER_ASC) {\n                take_left = (val0 <= val1);\n            } else {\n                take_left = (val0 >= val1);\n            }\n\n            if (take_left) {\n                out_idx = idx0;\n                ++i;\n                if (i < len0) {\n                    idx0 = tmp0[i];\n                    val0 = src0_row[idx0];\n                }\n            } else {\n                out_idx = idx1;\n                ++j;\n                if (j < len1) {\n                    idx1 = tmp1[j];\n                    val1 = src0_row[idx1];\n                }\n            }\n        }\n\n        dst[k] = out_idx;\n    }\n}\n\ntemplate [[host_name(\"kernel_argsort_merge_f32_i32_asc\")]]  kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;\ntemplate [[host_name(\"kernel_argsort_merge_f32_i32_desc\")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;\n\nconstant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];\n\nconstant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];\n\n// pad the last chunk of C elements of k and v into a an extra pad buffer\nkernel void kernel_flash_attn_ext_pad(\n        constant ggml_metal_kargs_flash_attn_ext_pad & args,\n        device const char * k,\n        device const char * v,\n        device const char * mask,\n        device       char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiitg[[thread_index_in_threadgroup]],\n        ushort3   
ntg[[threads_per_threadgroup]]) {\n    const int32_t C = FC_flash_attn_ext_pad_ncpsg;\n\n    device char * k_pad    = dst;\n    device char * v_pad    = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;\n    device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;\n\n    const int32_t icp = args.ne11 % C;\n    const int32_t ic0 = args.ne11 - icp;\n\n    const int32_t i1 = tgpig[0];\n    const int32_t i2 = tgpig[1];\n    const int32_t i3 = tgpig[2];\n\n    if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {\n        device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;\n        device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;\n\n        device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;\n        device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;\n\n        if (i1 >= icp) {\n            // here it is not important the exact value that will be used as we rely on masking out the scores in the attention\n            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {\n                k_dst[i] = 0;\n            }\n            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {\n                v_dst[i] = 0;\n            }\n        } else {\n            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {\n                k_dst[i] = k_src[i];\n            }\n            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {\n                v_dst[i] = v_src[i];\n            }\n        }\n    }\n\n    if (FC_flash_attn_ext_pad_has_mask) {\n        if (i2 < args.ne32 && i3 < args.ne33) {\n            for (int ib = i1; ib < args.ne31; ib += C) {\n                device const half * mask_src = (device const half *)(mask      + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;\n                device       half * mask_dst = (device       half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;\n\n               
 for (int i = tiitg; i < C; i += ntg.x) {\n                    if (i >= icp) {\n                        mask_dst[i] = -MAXHALF;\n                    } else {\n                        mask_dst[i] = mask_src[i];\n                    }\n                }\n            }\n        }\n    }\n}\n\nconstant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];\nconstant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];\n\n// scan the blocks of the mask that are not masked\n// 0 -     masked (i.e. full of -INF, skip)\n// 1 - not masked (i.e. at least one element of the mask is not -INF)\n// 2 - all zero\nkernel void kernel_flash_attn_ext_blk(\n        constant ggml_metal_kargs_flash_attn_ext_blk & args,\n        device const char * mask,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]]) {\n    // block size C x Q\n    const int32_t Q = FC_flash_attn_ext_blk_nqptg;\n    const int32_t C = FC_flash_attn_ext_blk_ncpsg;\n\n    constexpr short NW  = N_SIMDWIDTH;\n\n    const int32_t i3 = tgpig[2]/args.ne32;\n    const int32_t i2 = tgpig[2]%args.ne32;\n    const int32_t i1 = tgpig[1];\n    const int32_t i0 = tgpig[0];\n\n    char res = i0*C + C > args.ne30 ? 
1 : 0;\n\n    device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;\n\n    // detailed check of the elements of the block\n    if ((C > NW || Q > 1) && res == 0) {\n        half mmin =  MAXHALF;\n        half mmax = -MAXHALF;\n\n        FOR_UNROLL (short j = 0; j < Q; ++j) {\n            FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {\n                mmin = min(mmin, mask_src[ii*NW]);\n                mmax = max(mmax, mask_src[ii*NW]);\n            }\n\n            mask_src += args.nb31/2;\n        }\n\n        mmin = simd_min(mmin);\n        mmax = simd_max(mmax);\n\n        if (mmax > -MAXHALF) {\n            if (mmin == 0.0 && mmax == 0.0) {\n                res = 2;\n            } else {\n                res = 1;\n            }\n        }\n    }\n\n    const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);\n    const int32_t nblk0 = ((args.ne30 + C - 1)/C);\n\n    if (tiisg == 0) {\n        dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;\n    }\n}\n\nconstant bool FC_flash_attn_ext_has_mask  [[function_constant(FC_FLASH_ATTN_EXT + 0)]];\nconstant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];\nconstant bool FC_flash_attn_ext_has_bias  [[function_constant(FC_FLASH_ATTN_EXT + 2)]];\nconstant bool FC_flash_attn_ext_has_scap  [[function_constant(FC_FLASH_ATTN_EXT + 3)]];\nconstant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];\n\nconstant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];\n\n//constant float FC_flash_attn_ext_scale         [[function_constant(FC_FLASH_ATTN_EXT + 10)]];\n//constant float FC_flash_attn_ext_max_bias      [[function_constant(FC_FLASH_ATTN_EXT + 11)]];\n//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];\n\nconstant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]];\nconstant int32_t FC_flash_attn_ext_ns20 
[[function_constant(FC_FLASH_ATTN_EXT + 21)]];\nconstant int32_t FC_flash_attn_ext_nsg  [[function_constant(FC_FLASH_ATTN_EXT + 22)]];\n\n// ref: https://arxiv.org/pdf/2307.08691.pdf\ntemplate<\n    typename q_t,     // query types in shared memory\n    typename q4_t,\n    typename q8x8_t,\n    typename k_t,     // key types in shared memory\n    typename k4x4_t,\n    typename k8x8_t,\n    typename v_t,     // value types in shared memory\n    typename v4x4_t,\n    typename v8x8_t,\n    typename qk_t,    // Q*K types\n    typename qk8x8_t,\n    typename s_t,     // soft-max types\n    typename s2_t,\n    typename s8x8_t,\n    typename o_t,     // attention accumulation types\n    typename o4_t,\n    typename o8x8_t,\n    typename kd4x4_t, // key type in device memory\n    short nl_k,\n    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),\n    typename vd4x4_t, // value type in device memory\n    short nl_v,\n    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),\n    short DK,         // K head size\n    short DV,         // V head size\n    short Q,          // queries per threadgroup\n    short C,          // cache items per threadgroup\n    short NSG>        // number of simd groups\nvoid kernel_flash_attn_ext_impl(\n        constant ggml_metal_kargs_flash_attn_ext & args,\n        device const char * q,\n        device const char * k,\n        device const char * v,\n        device const char * mask,\n        device const char * sinks,\n        device const char * pad,\n        device const char * blk,\n        device       char * dst,\n        threadgroup  half * shmem_f16,\n        uint3   tgpig,\n        ushort  tiisg,\n        ushort  sgitg) {\n    const ushort iq3 = tgpig[2];\n    const ushort iq2 = tgpig[1];\n    const ushort iq1 = tgpig[0]*Q;\n\n#define NS10 (FC_flash_attn_ext_ns10)\n#define NS20 (FC_flash_attn_ext_ns20)\n\n    // note: I had some concerns that using this instead of the ugly macros above was affecting 
performance\n    //       need to re-check carefully and if no regressions are observerd - remove the macros\n    //       the concerns is that maybe using const variables requires extra registers? but not sure if the compiler\n    //         is clever enough to avoid this. unfortunately, using constexpr is not possible with FC\n    //const short NS10 = FC_flash_attn_ext_ns10;\n    //const short NS20 = FC_flash_attn_ext_ns20;\n\n    constexpr short KV   = 8;\n\n    constexpr short DK4  = DK/4;\n    constexpr short DK8  = DK/8;\n    constexpr short DK16 = DK/16;\n    constexpr short DV4  = DV/4;\n  //constexpr short DV8  = DV/8;\n    constexpr short DV16 = DV/16;\n\n    constexpr short PV   = PAD2(DV, 64);\n    constexpr short PV4  = PV/4;\n    constexpr short PV8  = PV/8;\n  //constexpr short PV16 = PV/16;\n\n    constexpr short NW  = N_SIMDWIDTH;\n    constexpr short NQ  = Q/NSG;\n    constexpr short SH  = 2*C; // shared memory per simdgroup (s_t == float)\n\n    constexpr short TS = 2*SH;\n    constexpr short T  = DK + 2*PV; // shared memory size per query in (half)\n\n    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*T); // holds the query data\n    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t\n    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper)\n    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK);\n    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix\n    threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t\n\n    threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory\n    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as 
above but in k4x4_t\n\n    threadgroup v_t    * sv    = (threadgroup v_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory\n    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t\n\n    // mask storage in shared mem\n    threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);\n\n    // per-query mask pointers\n    device const half2 * pm2[NQ];\n\n    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n        const short j = jj*NSG + sgitg;\n\n        pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);\n    }\n\n    {\n        const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);\n        const int32_t nblk0 = ((args.ne11 + C - 1)/C);\n\n        blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;\n    }\n\n    {\n        q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;\n\n        const short ikv2 = iq2/(args.ne02/args.ne_12_2);\n        const short ikv3 = iq3/(args.ne03/args.ne_12_3);\n\n        k += ikv2*args.nb12 + ikv3*args.nb13;\n        v += ikv2*args.nb22 + ikv3*args.nb23;\n    }\n\n    // load heads from Q to shared memory\n    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n        const short j = jj*NSG + sgitg;\n\n        device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01);\n\n        for (short i = tiisg; i < DK4; i += NW) {\n            if (iq1 + j < args.ne01) {\n                sq4[j*DK4 + i] = (q4_t) q4[i];\n            } else {\n                sq4[j*DK4 + i] = 0;\n            }\n        }\n    }\n\n    // zero out\n    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n        const short j = jj*NSG + sgitg;\n\n        for (short i = tiisg; i < DV4; i += NW) {\n            so4[j*PV4 + i] = 0;\n        }\n\n        for (short i = tiisg; i < SH; i += NW) {\n            ss[j*SH + i] = 
0.0f;\n        }\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    float S[NQ] = { [0 ... NQ-1] = 0.0f };\n\n    {\n        float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 };\n\n        float slope = 1.0f;\n\n        // ALiBi\n        if (FC_flash_attn_ext_has_bias) {\n            const short h = iq2;\n\n            const float base = h < args.n_head_log2 ? args.m0 : args.m1;\n            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;\n\n            slope = pow(base, exph);\n        }\n\n        // loop over the KV cache\n        // each simdgroup handles blocks of Q rows and C columns\n        for (int ic0 = 0; ; ++ic0) {\n            int ic = ic0*C;\n            if (ic >= args.ne11) {\n                break;\n            }\n\n            // the last partial chunk uses the pad buffer as source\n            if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {\n                k    = pad;\n                v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;\n                mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;\n\n                const short ikv2 = iq2/(args.ne02/args.ne_12_2);\n                const short ikv3 = iq3/(args.ne03/args.ne_12_3);\n\n                k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;\n                v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;\n\n                if (!FC_flash_attn_ext_has_mask) {\n                    threadgroup half * sm = (threadgroup half *) (sm2);\n\n                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n                        const short j = jj*NSG + sgitg;\n\n                        for (short i = tiisg; i < C; i += NW) {\n                            if (ic + i >= args.ne11) {\n                                sm[2*j*SH + i] = -MAXHALF;\n                            }\n                        }\n                    }\n                } else {\n                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n                        const short j = jj*NSG + 
sgitg;\n\n                        pm2[jj] = (device const half2 *) ((device const half *) mask +\n                                (iq1 + j)*C +\n                                (iq2%args.ne32)*(C*args.ne31) +\n                                (iq3%args.ne33)*(C*args.ne31*args.ne32));\n                    }\n                }\n\n                ic = 0;\n            }\n\n            char blk_cur = 1;\n\n            // read the mask into shared mem\n            if (FC_flash_attn_ext_has_mask) {\n                blk_cur = blk[ic0];\n\n                if (blk_cur == 0) {\n                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n                        pm2[jj] += NW;\n                    }\n\n                    continue;\n                }\n\n                if (blk_cur == 1) {\n                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n                        const short j = jj*NSG + sgitg;\n\n                        if (FC_flash_attn_ext_bc_mask) {\n                            sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? 
pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);\n                        } else {\n                            sm2[j*SH + tiisg] = pm2[jj][tiisg];\n                        }\n\n                        pm2[jj] += NW;\n                    }\n                } else if (blk_cur == 2) {\n                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n                        pm2[jj] += NW;\n                    }\n                }\n\n#if 0\n                // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks\n\n                threadgroup_barrier(mem_flags::mem_threadgroup);\n\n                // used to detect blocks full of -INF\n                // skip only when the entire threadgroup is masked\n                half2 smax2(-MAXHALF/2, -MAXHALF/2);\n\n                FOR_UNROLL (short j = 0; j < Q; ++j) {\n                    smax2 = max(smax2, sm2[j*SH + tiisg]);\n                }\n\n                smax2 = simd_max(smax2);\n\n                if (max(smax2[0], smax2[1]) <= -MAXHALF/2) {\n                    // this barrier is important\n                    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n                    continue;\n                }\n#endif\n            }\n\n            // Q*K^T\n            // this is compile-time check, so it does not have runtime overhead\n            if (is_same<kd4x4_t, k4x4_t>::value) {\n                // we can read directly from global memory\n                device      const k_t * pk = (device const k_t *) (k + ic*args.nb11);\n                threadgroup const q_t * pq = sq;\n                threadgroup       s_t * ps = ss;\n\n                pk += sgitg*(8*NS10);\n                ps += sgitg*(8*1);\n\n                static_assert((C/8) % NSG == 0, \"\");\n\n                constexpr short NC = (C/8)/NSG;\n\n                FOR_UNROLL (short cc = 0; cc < NC; ++cc) {\n                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);\n\n                    if (DK % 16 != 
0) {\n                        k8x8_t mk;\n                        q8x8_t mq;\n\n                        FOR_UNROLL (short i = 0; i < DK8; ++i) {\n                            simdgroup_barrier(mem_flags::mem_none);\n\n                            simdgroup_load(mk, pk + 8*i, NS10, 0, true);\n                            simdgroup_load(mq, pq + 8*i, DK);\n\n                            simdgroup_barrier(mem_flags::mem_none);\n\n                            simdgroup_multiply_accumulate(mqk, mq, mk, mqk);\n                        }\n                    } else {\n                        k8x8_t mk[2];\n                        q8x8_t mq[2];\n\n                        // note: too much unroll can tank the performance for large heads\n                        #pragma unroll (MIN(DK8/2, 4*NSG))\n                        for (short i = 0; i < DK8/2; ++i) {\n                            simdgroup_barrier(mem_flags::mem_none);\n\n                            simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);\n                            simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);\n\n                            simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);\n                            simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);\n\n                            simdgroup_barrier(mem_flags::mem_none);\n\n                            simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);\n                            simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);\n                        }\n                    }\n\n                    simdgroup_store(mqk, ps, SH, 0, false);\n\n                    pk += 8*(NSG*NS10);\n                    ps += 8*(NSG);\n                }\n            } else {\n                // TODO: this is the quantized K cache branch - not optimized yet\n                for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) {\n                    const short cc = ccc*NSG + sgitg;\n\n                    const short tx = tiisg%4;\n                    const short ty 
= tiisg/4;\n\n                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);\n\n                    for (short ii = 0; ii < DK16; ii += 4) {\n                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));\n\n                        if (DK16%4 == 0) {\n                            // the head is evenly divisible by 4*16 = 64, so no need for bound checks\n                            {\n                                k4x4_t tmp;\n                                deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);\n                                sk4x4[4*ty + tx] = tmp;\n                            }\n\n                            simdgroup_barrier(mem_flags::mem_threadgroup);\n\n                            FOR_UNROLL (short k = 0; k < 4; ++k) {\n                                k8x8_t mk;\n                                q8x8_t mq;\n\n                                simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose\n                                simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);\n                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);\n\n                                simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose\n                                simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);\n                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);\n                            }\n                        } else {\n                            if (ii + tx < DK16) {\n                                k4x4_t tmp;\n                                deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);\n                                sk4x4[4*ty + tx] = tmp;\n                            }\n\n                            simdgroup_barrier(mem_flags::mem_threadgroup);\n\n                            for (short k = 0; k < 4 && ii + k < DK16; ++k) {\n                                k8x8_t mk;\n                    
            q8x8_t mq;\n\n                                simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose\n                                simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);\n                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);\n\n                                simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose\n                                simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);\n                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);\n                            }\n                        }\n                    }\n\n                    simdgroup_store(mqk, ss + 8*cc, SH, 0, false);\n                }\n            }\n\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            // online softmax\n            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n                const short j = jj*NSG + sgitg;\n\n                const float m = M[jj];\n\n                // scale and apply the logitcap / mask\n                float2 s2 = ss2[j*SH/2 + tiisg]*args.scale;\n\n                if (FC_flash_attn_ext_has_scap) {\n                    s2 = args.logit_softcap*precise::tanh(s2);\n                }\n\n                // mqk = mqk + slope*mask\n                if (blk_cur != 2) {\n                    if (FC_flash_attn_ext_has_bias) {\n                        s2 += s2_t(sm2[j*SH + tiisg])*slope;\n                    } else {\n                        s2 += s2_t(sm2[j*SH + tiisg]);\n                    }\n                }\n\n                M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));\n\n                const float  ms  = exp(m  - M[jj]);\n                const float2 vs2 = exp(s2 - M[jj]);\n\n                S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]);\n\n                // the P matrix from the paper (Q rows, C columns)\n                ss2[j*SH/2 + tiisg] = vs2;\n\n                if (DV4 % NW == 0) {\n                    FOR_UNROLL (short ii 
= 0; ii < DV4/NW; ++ii) {\n                        const short i = ii*NW + tiisg;\n\n                        so4[j*PV4 + i] *= ms;\n                    }\n                } else {\n                    for (short i = tiisg; i < DV4; i += NW) {\n                        so4[j*PV4 + i] *= ms;\n                    }\n                }\n            }\n\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            // O = O + (Q*K^T)*V\n            {\n                // we can read directly from global memory\n                if (is_same<vd4x4_t, v4x4_t>::value) {\n                    static_assert(PV8 % NSG == 0, \"\");\n\n                    constexpr short NO = PV8/NSG;\n\n                    o8x8_t lo[NO];\n\n                    {\n                        auto sot = so + 8*sgitg;\n\n                        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {\n                            simdgroup_load(lo[ii], sot, PV, 0, false);\n\n                            sot += 8*NSG;\n                        }\n                    }\n\n                    {\n                        device const v_t * pv = (device const v_t *) (v + ic*args.nb21);\n\n                        pv += 8*sgitg;\n\n                        if (DV <= 64) {\n                            FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {\n                                s8x8_t vs;\n                                simdgroup_load(vs, ss + 8*cc, SH, 0, false);\n\n                                FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {\n                                    v8x8_t mv[2];\n\n                                    simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);\n                                    simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);\n\n                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);\n                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);\n                
                }\n\n                                pv  += 8*NS20;\n                            }\n                        } else {\n                            constexpr short NC = (C/8)/2;\n\n                            FOR_UNROLL (short cc = 0; cc < NC; ++cc) {\n                                s8x8_t vs[2];\n\n                                simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);\n                                simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);\n\n                                FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {\n                                    v8x8_t mv[4];\n\n                                    simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);\n                                    simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);\n                                    simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);\n                                    simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);\n\n                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);\n                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);\n                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);\n                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);\n                                }\n\n                                pv  += 2*8*NS20;\n                            }\n                        }\n                    }\n\n                    {\n                        auto sot = so + 8*sgitg;\n\n                        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {\n                            simdgroup_store(lo[ii], sot, PV, 0, false);\n\n                            sot += 8*NSG;\n                        }\n                    }\n            
    } else {\n                    // TODO: this is the quantized V cache branch - not optimized yet\n\n                    const short tx = tiisg%4;\n                    const short ty = tiisg/4;\n\n                    for (short cc = 0; cc < C/8; ++cc) {\n                        s8x8_t vs;\n                        simdgroup_load(vs, ss + 8*cc, SH, 0, false);\n\n                        for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {\n                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));\n\n                            if (DV16%4 == 0) {\n                                // no need for bound checks\n                                {\n                                    v4x4_t tmp;\n                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);\n                                    sv4x4[4*ty + tx] = tmp;\n                                }\n\n                                simdgroup_barrier(mem_flags::mem_threadgroup);\n\n                                FOR_UNROLL (short k = 0; k < 4; ++k) {\n                                    v8x8_t mv[2];\n                                    o8x8_t lo[2];\n\n                                    simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);\n                                    simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);\n                                    simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);\n                                    simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);\n\n                                    simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);\n                                    simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);\n\n                                    simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);\n                                    simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);\n                        
        }\n                            } else {\n                                if (ii + tx < DV16) {\n                                    v4x4_t tmp;\n                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);\n                                    sv4x4[4*ty + tx] = tmp;\n                                }\n\n                                simdgroup_barrier(mem_flags::mem_threadgroup);\n\n                                for (short k = 0; k < 4 && ii + k < DV16; ++k) {\n                                    v8x8_t mv[2];\n                                    o8x8_t lo[2];\n\n                                    simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);\n                                    simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);\n                                    simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);\n                                    simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);\n\n                                    simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);\n                                    simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);\n\n                                    simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);\n                                    simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n        }\n\n        if (FC_flash_attn_ext_has_sinks) {\n            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {\n                const short j = jj*NSG + sgitg;\n\n                const float m = M[jj];\n                const float s = tiisg == 0 ? 
((device const float *) sinks)[iq2] : -FLT_MAX/2;\n\n                M[jj] = simd_max(max(M[jj], s));\n\n                const float ms = exp(m - M[jj]);\n                const float vs = exp(s - M[jj]);\n\n                S[jj] = S[jj]*ms + simd_sum(vs);\n\n                for (short i = tiisg; i < DV4; i += NW) {\n                    so4[j*PV4 + i] *= ms;\n                }\n            }\n        }\n    }\n\n    // store to global memory\n    for (short jj = 0; jj < NQ; ++jj) {\n        const short j = jj*NSG + sgitg;\n        if (iq1 + j >= args.ne01) {\n            break;\n        }\n\n        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;\n\n        const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];\n\n        if (DV4 % NW == 0) {\n            FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {\n                const short i = ii*NW + tiisg;\n\n                dst4[i] = (float4) so4[j*PV4 + i]*scale;\n            }\n        } else {\n            for (short i = tiisg; i < DV4; i += NW) {\n                dst4[i] = (float4) so4[j*PV4 + i]*scale;\n            }\n        }\n    }\n\n#undef NS10\n#undef NS20\n}\n\ntemplate<\n    typename q_t,     // query types in shared memory\n    typename q4_t,\n    typename q8x8_t,\n    typename k_t,     // key types in shared memory\n    typename k4x4_t,\n    typename k8x8_t,\n    typename v_t,     // value types in shared memory\n    typename v4x4_t,\n    typename v8x8_t,\n    typename qk_t,    // Q*K types\n    typename qk8x8_t,\n    typename s_t,     // soft-max types\n    typename s2_t,\n    typename s8x8_t,\n    typename o_t,     // attention accumulation types\n    typename o4_t,\n    typename o8x8_t,\n    typename kd4x4_t, // key type in device memory\n    short nl_k,\n    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),\n    typename vd4x4_t, // value type in device memory\n    short nl_v,\n    void (*deq_v)(device const vd4x4_t *, 
short, thread v4x4_t &),\n    short DK,         // K head size\n    short DV,         // V head size\n    short Q  = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup\n    short C  = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup\nkernel void kernel_flash_attn_ext(\n        constant ggml_metal_kargs_flash_attn_ext & args,\n        device const char * q,\n        device const char * k,\n        device const char * v,\n        device const char * mask,\n        device const char * sinks,\n        device const char * pad,\n        device const char * blk,\n        device       char * dst,\n        threadgroup  half * shmem_f16 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {\n#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C\n#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg\n    switch (FC_flash_attn_ext_nsg) {\n      // note: disabled cases to reduce library load time\n      //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;\n      //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;\n        case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;\n        case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;\n    }\n#undef FWD_TMPL\n#undef FWD_ARGS\n}\n\n// TODO: this is quite ugly. 
in the future these types will be hardcoded in the kernel, but for now keep them as\n//       template to be able to explore different combinations\n//\n#define FA_TYPES \\\n    half,   half4,     simdgroup_half8x8,  \\\n    half,   half4x4,   simdgroup_half8x8,  \\\n    half,   half4x4,   simdgroup_half8x8,  \\\n    float,             simdgroup_float8x8, \\\n    float,  float2,    simdgroup_float8x8, \\\n    float,  float4,    simdgroup_float8x8\n    //half,   half4,     simdgroup_half8x8\n\n#define FA_TYPES_BF \\\n    bfloat, bfloat4,   simdgroup_bfloat8x8, \\\n    bfloat, bfloat4x4, simdgroup_bfloat8x8, \\\n    bfloat, bfloat4x4, simdgroup_bfloat8x8, \\\n    float,             simdgroup_float8x8,  \\\n    float,  float2,    simdgroup_float8x8,  \\\n    half,   half4,     simdgroup_half8x8\n    //float,  float4,    simdgroup_float8x8\n\n#define FA_TYPES_F32 \\\n    half,   half4,     simdgroup_half8x8,  \\\n    float,  float4x4,  simdgroup_float8x8, \\\n    float,  float4x4,  simdgroup_float8x8, \\\n    float,             simdgroup_float8x8, \\\n    float,  float2,    simdgroup_float8x8, \\\n    float,  float4,    simdgroup_float8x8\n    //half,   half4,     simdgroup_half8x8\n\ntypedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk32_dv32\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  32,  32>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk40_dv40\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  40,  40>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk48_dv48\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  48,  48>;\ntemplate 
[[host_name(\"kernel_flash_attn_ext_f32_dk64_dv64\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  64,  64>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk72_dv72\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  72,  72>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk80_dv80\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  80,  80>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk96_dv96\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  96,  96>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk112_dv112\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  112, 112>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk128_dv128\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  128, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk192_dv192\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  192, 192>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk192_dv128\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  192, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk256_dv256\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  256, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk320_dv256\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, 
dequantize_f32,  320, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f32_dk576_dv512\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  576, 512>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk32_dv32\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  32,  32>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk40_dv40\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  40,  40>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk48_dv48\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  48,  48>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk64_dv64\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64,  64>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk72_dv72\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  72,  72>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk80_dv80\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80,  80>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk96_dv96\"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96,  96>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk112_dv112\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112, 112>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk128_dv128\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16, 
 half4x4,    1, dequantize_f16,  128, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk192_dv192\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 192>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk192_dv128\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk256_dv256\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk320_dv256\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  320, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_f16_dk576_dv512\")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  576, 512>;\n\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk32_dv32\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 32,  32>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk40_dv40\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 40,  40>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk48_dv48\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 48,  48>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk64_dv64\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64,  64>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk72_dv72\"  )]] kernel flash_attn_ext_t 
kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 72,  72>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk80_dv80\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80,  80>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk96_dv96\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96,  96>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk112_dv112\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112, 112>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk128_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk192_dv192\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 192>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk192_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk256_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk320_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 320, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_bf16_dk576_dv512\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 576, 512>;\n#endif\n\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk32_dv32\"  
)]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32,  32>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk40_dv40\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40,  40>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk48_dv48\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 48,  48>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk64_dv64\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64,  64>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk72_dv72\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72,  72>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk80_dv80\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80,  80>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk96_dv96\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96,  96>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk112_dv112\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk128_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk192_dv192\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;\ntemplate 
[[host_name(\"kernel_flash_attn_ext_q4_0_dk192_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk256_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk320_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_0_dk576_dv512\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk32_dv32\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32,  32>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk40_dv40\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40,  40>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk48_dv48\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 48,  48>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk64_dv64\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64,  64>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk72_dv72\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72,  72>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk80_dv80\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 
 80>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk96_dv96\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96,  96>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk112_dv112\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk128_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk192_dv192\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk192_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk256_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk320_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q4_1_dk576_dv512\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk32_dv32\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32,  32>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk40_dv40\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, 
dequantize_q5_0, 40,  40>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk48_dv48\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 48,  48>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk64_dv64\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64,  64>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk72_dv72\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72,  72>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk80_dv80\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80,  80>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk96_dv96\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96,  96>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk112_dv112\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk128_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk192_dv192\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk192_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk256_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, 
block_q5_0, 2, dequantize_q5_0, 256, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk320_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_0_dk576_dv512\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk32_dv32\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32,  32>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk40_dv40\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40,  40>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk48_dv48\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 48,  48>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk64_dv64\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64,  64>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk72_dv72\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72,  72>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk80_dv80\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80,  80>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk96_dv96\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96,  96>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk112_dv112\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, 
dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk128_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk192_dv192\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk192_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk256_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk320_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q5_1_dk576_dv512\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk32_dv32\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32,  32>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk40_dv40\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40,  40>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk48_dv48\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 48,  48>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk64_dv64\"  )]] kernel flash_attn_ext_t 
kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64,  64>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk72_dv72\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72,  72>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk80_dv80\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80,  80>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk96_dv96\"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96,  96>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk112_dv112\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk128_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk192_dv192\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk192_dv128\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk256_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk320_dv256\")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;\ntemplate [[host_name(\"kernel_flash_attn_ext_q8_0_dk576_dv512\")]] kernel 
flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;\n\n#undef FA_TYPES\n#undef FA_TYPES_BF\n#undef FA_TYPES_F32\n\nconstant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];\nconstant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];\nconstant bool FC_flash_attn_ext_vec_has_bias  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];\nconstant bool FC_flash_attn_ext_vec_has_scap  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];\nconstant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];\n\n//constant float FC_flash_attn_ext_vec_scale         [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];\n//constant float FC_flash_attn_ext_vec_max_bias      [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];\n//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]];\n\nconstant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]];\nconstant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]];\nconstant int32_t FC_flash_attn_ext_vec_nsg  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]];\nconstant int32_t FC_flash_attn_ext_vec_nwg  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]];\n\ntemplate<\n    typename q4_t,  // query types in shared memory\n    typename k4_t,  // key types in shared memory\n    typename v4_t,  // value types in shared memory\n    typename qk_t,  // Q*K types\n    typename s_t,   // soft-max types\n    typename s4_t,\n    typename o4_t,  // attention accumulation types\n    typename kd4_t, // key type in device memory\n    short nl_k,\n    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),\n    typename vd4_t, // value type in device memory\n    short nl_v,\n    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),\n    short DK,       // K head 
size\n    short DV,       // V head size\n    short NE = 4,   // head elements per thread\n    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPSG,  // queries per threadgroup\n    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup\nkernel void kernel_flash_attn_ext_vec(\n        constant ggml_metal_kargs_flash_attn_ext_vec & args,\n        device const char * q,\n        device const char * k,\n        device const char * v,\n        device const char * mask,\n        device const char * sinks,\n        device const char * pad,\n        device       char * dst,\n        threadgroup  half * shmem_f16 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {\n    static_assert(DK % 32 == 0, \"DK must be divisible by 32\");\n    static_assert(DV % 32 == 0, \"DV must be divisible by 32\");\n\n#define NWG  (FC_flash_attn_ext_vec_nwg)\n#define NSG  (FC_flash_attn_ext_vec_nsg)\n\n#define NS10 (FC_flash_attn_ext_vec_ns10)\n#define NS20 (FC_flash_attn_ext_vec_ns20)\n\n    const short iwg = tgpig[2]%NWG;\n\n    const ushort iq3 = tgpig[2]/NWG;\n    const ushort iq2 = tgpig[1];\n    const ushort iq1 = tgpig[0];\n\n    constexpr short DK4 = DK/4;\n    constexpr short DV4 = DV/4;\n\n    constexpr short PK  = PAD2(DK, 128);\n    constexpr short PK4 = PK/4;\n\n    constexpr short PV  = PAD2(DV, 128);\n    constexpr short PV4 = PV/4;\n\n    constexpr short NW  = N_SIMDWIDTH;\n    constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads\n    constexpr short SH  = 4*C;   // shared memory per simdgroup\n\n    static_assert(DK4 % NL == 0, \"DK4 must be divisible by NL\");\n    static_assert(DV4 % NL == 0, \"DV4 must be divisible by NL\");\n\n  //const short T = PK + NSG*SH; // shared memory size per query in (half)\n\n  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +         
             0*PK); // holds the query data\n    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                      0*PK); // same as above but in q4_t\n    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + NSG*PK); // scratch buffer for attention\n    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + NSG*PK); // same as above but in s4_t\n    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask\n    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + NSG*PK + NSG*SH); // scratch buffer for the results\n\n    // store the result for all queries in shared memory (the O matrix from the paper)\n    so4 += tiisg;\n\n    {\n        q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;\n\n        const short ikv2 = iq2/(args.ne02/args.ne_12_2);\n        const short ikv3 = iq3/(args.ne03/args.ne_12_3);\n\n        k += ikv2*args.nb12 + ikv3*args.nb13;\n        v += ikv2*args.nb22 + ikv3*args.nb23;\n    }\n\n    // load heads from Q to shared memory\n    device const float4 * q4 = (device const float4 *) ((device const char *) q);\n\n    if (iq1 < args.ne01) {\n        for (short i = tiisg; i < PK4; i += NW) {\n            if (i < DK4) {\n                sq4[i] = (q4_t) q4[i];\n            } else {\n                sq4[i] = (q4_t) 0.0f;\n            }\n        }\n    }\n\n    // zero out so\n    for (short i = 0; i < DV4/NL; ++i) {\n        so4[i*NL] = (o4_t) 0.0f;\n    }\n\n    // zero out shared memory SH\n    for (short i = tiisg; i < SH/4; i += NW) {\n        ss4[i] = (s4_t) 0.0f;\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    {\n        float S = 0.0f;\n        float M = -FLT_MAX/2;\n\n        // thread indices inside the simdgroup\n        const short tx = tiisg%NL;\n        const short ty = tiisg/NL;\n\n        // pointer to the mask\n        device const half * pm = (device const 
half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);\n\n        float slope = 1.0f;\n\n        // ALiBi\n        if (FC_flash_attn_ext_vec_has_bias) {\n            const short h = iq2;\n\n            const float base = h < args.n_head_log2 ? args.m0 : args.m1;\n            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;\n\n            slope = pow(base, exph);\n        }\n\n        // loop over the KV cache\n        // each simdgroup handles blocks of Q rows and C columns\n        for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {\n            int ic = ic0*C;\n            if (ic >= args.ne11) {\n                break;\n            }\n\n            // the last partial chunk uses the pad buffer as source\n            if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {\n                k    = pad;\n                v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;\n                mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;\n\n                const short ikv2 = iq2/(args.ne02/args.ne_12_2);\n                const short ikv3 = iq3/(args.ne03/args.ne_12_3);\n\n                k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;\n                v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;\n\n                if (!FC_flash_attn_ext_vec_has_mask) {\n                    if (ic + tiisg >= args.ne11) {\n                        sm[tiisg] = -MAXHALF;\n                    }\n                } else {\n                    pm = (device const half *) (mask) +\n                        iq1*C +\n                        (iq2%args.ne32)*(C*args.ne31) +\n                        (iq3%args.ne33)*(C*args.ne31*args.ne32);\n                }\n\n                ic = 0;\n            }\n\n            if (FC_flash_attn_ext_vec_has_mask) {\n                sm[tiisg] = pm[ic + tiisg];\n            }\n\n            // skip -INF blocks\n            if (simd_max(sm[tiisg]) <= -MAXHALF) {\n                continue;\n            
}\n\n            // Q*K^T\n            {\n                device      const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);\n                threadgroup const q4_t * pq4 = sq4;\n\n                pk4 += ty*NS10/4 + tx;\n                pq4 += tx;\n\n                qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f };\n\n                // each simdgroup processes 1 query and NE (NW/NL) cache elements\n                FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {\n                    if (is_same<kd4_t, k4_t>::value) {\n                        FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {\n                            mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 +  ii*NL], (float4) pq4[ii*NL]);\n                        }\n                    } else {\n                        device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));\n\n                        k4_t mk;\n\n                        FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {\n                            const short i = ii*NL + tx;\n\n                            deq_k_t4(pk + i/nl_k, i%nl_k, mk);\n\n                            mqk[cc] += dot((float4) mk, (float4) sq4[i]);\n                        }\n                    }\n\n                    if (NE == 1) {\n                        mqk[cc] = simd_sum(mqk[cc]);\n                    } else {\n                        // simdgroup reduce (NE = 4)\n                        // [ 0 ..  7] -> [ 0]\n                        // [ 8 .. 15] -> [ 8]\n                        // [16 .. 23] -> [16]\n                        // [24 .. 
31] -> [24]\n                        if (NE <= 1) {\n                            mqk[cc] += simd_shuffle_down(mqk[cc], 16);\n                        }\n                        if (NE <= 2) {\n                            mqk[cc] += simd_shuffle_down(mqk[cc],  8);\n                        }\n                        if (NE <= 4) {\n                            mqk[cc] += simd_shuffle_down(mqk[cc],  4);\n                        }\n                        if (NE <= 8) {\n                            mqk[cc] += simd_shuffle_down(mqk[cc],  2);\n                        }\n                        if (NE <= 16) {\n                            mqk[cc] += simd_shuffle_down(mqk[cc],  1);\n                        }\n\n                        // broadcast\n                        mqk[cc] = simd_shuffle(mqk[cc], NL*ty);\n                    }\n                }\n\n                if (FC_flash_attn_ext_vec_has_mask &&\n                   !FC_flash_attn_ext_vec_has_scap &&\n                   !FC_flash_attn_ext_vec_has_bias) {\n                    ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]);\n                } else {\n                    mqk[tx] *= args.scale;\n\n                    if (FC_flash_attn_ext_vec_has_scap) {\n                        mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]);\n                    }\n\n                    if (FC_flash_attn_ext_vec_has_bias) {\n                        mqk[tx] += (qk_t) sm[NE*tx + ty]*slope;\n                    } else {\n                        mqk[tx] += (qk_t) sm[NE*tx + ty];\n                    }\n\n                    ss[NE*tx + ty] = mqk[tx];\n                }\n            }\n\n            simdgroup_barrier(mem_flags::mem_threadgroup);\n\n            // online softmax\n            {\n                const float m = M;\n                const float s = ss[tiisg];\n\n                M = simd_max(max(M, s));\n\n                const float ms = exp(m - M);\n                const float vs = exp(s - M);\n\n       
         S = S*ms + simd_sum(vs);\n\n                // the P matrix from the paper (Q rows, C columns)\n                ss[tiisg] = vs;\n\n                // O = diag(ms)*O\n                if ((DV4/NL % NW == 0) || ty == 0) {\n                    FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {\n                        so4[ii*NL] *= ms;\n                    }\n                }\n            }\n\n            simdgroup_barrier(mem_flags::mem_threadgroup);\n\n            // O = O + (Q*K^T)*V\n            {\n                o4_t lo[DV4/NL];\n                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {\n                    lo[ii] = 0.0f;\n                }\n\n                if (is_same<vd4_t, v4_t>::value) {\n                    device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);\n\n                    pv4 += ty*NS20/4 + tx;\n\n                    const auto sst = ss + ty;\n\n                    FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {\n                        FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {\n                            lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE]));\n                        }\n                    }\n                } else {\n                    FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {\n                        device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));\n\n                        FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {\n                            const short i = ii*NL + tx;\n\n                            v4_t mv;\n                            deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);\n\n                            lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty]));\n                        }\n                    }\n                }\n\n                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {\n                    if (NE > 1) {\n                        lo[ii][0] += simd_shuffle_down(lo[ii][0], 16);\n                        lo[ii][1] += 
simd_shuffle_down(lo[ii][1], 16);\n                        lo[ii][2] += simd_shuffle_down(lo[ii][2], 16);\n                        lo[ii][3] += simd_shuffle_down(lo[ii][3], 16);\n                    }\n\n                    if (NE > 2) {\n                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  8);\n                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  8);\n                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  8);\n                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  8);\n                    }\n\n                    if (NE > 4) {\n                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  4);\n                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  4);\n                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  4);\n                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  4);\n                    }\n\n                    if (NE > 8) {\n                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  2);\n                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  2);\n                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  2);\n                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  2);\n                    }\n\n                    if (NE > 16) {\n                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  1);\n                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  1);\n                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  1);\n                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  1);\n                    }\n                }\n\n                if ((DV4/NL % NW == 0) || ty == 0) {\n                    FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {\n                        so4[ii*NL] += lo[ii];\n                    }\n                }\n            }\n        }\n\n        if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) {\n            const float m = M;\n            const float s = tiisg == 0 
? ((device const float *) sinks)[iq2] : -FLT_MAX/2;\n\n            M = simd_max(max(M, s));\n\n            const float ms = exp(m - M);\n            const float vs = exp(s - M);\n\n            S = S*ms + simd_sum(vs);\n\n            if ((DV4/NL % NW == 0) || ty == 0) {\n                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {\n                    so4[ii*NL] *= ms;\n                }\n            }\n        }\n\n        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)\n        if (tiisg == 0) {\n            ss[0] = (s_t) S;\n            ss[1] = (s_t) M;\n        }\n    }\n\n    so4 -= tiisg;\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // parallel reduce\n    for (short r = NSG/2; r > 0; r >>= 1) {\n        if (sgitg < r) {\n            const float S0 = ss[           0];\n            const float S1 = ss[r*(SH/2) + 0];\n\n            const float M0 = ss[           1];\n            const float M1 = ss[r*(SH/2) + 1];\n\n            const float M = max(M0, M1);\n\n            const float ms0 = exp(M0 - M);\n            const float ms1 = exp(M1 - M);\n\n            const float S = S0*ms0 + S1*ms1;\n\n            if (tiisg == 0) {\n                ss[0] = S;\n                ss[1] = M;\n            }\n\n            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1\n            for (short i = tiisg; i < DV4; i += NW) {\n                so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1;\n            }\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n    }\n\n    // final rescale with 1/S and store to global memory\n    if (sgitg == 0) {\n        const int64_t nrows = args.ne3*args.ne2*args.ne1;\n        const int64_t rid   = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;\n\n        device float4 * dst4 = (device float4 *) dst;\n        device float  * dst1 = (device float  *) dst + nrows*DV*NWG; // the S and M are stored after the results\n\n        const float S = NWG == 1 ? (ss[0] == 0.0f ? 
0.0f : 1.0f/ss[0]) : 1.0f;\n\n        // interleave the workgroup data\n        for (short i = tiisg; i < DV4; i += NW) {\n            dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S;\n        }\n\n        // store S and M\n        if (NWG > 1) {\n            if (tiisg == 0) {\n                dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0];\n                dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1];\n            }\n        }\n    }\n\n#undef NWG\n#undef NSG\n#undef NS10\n#undef NS20\n}\n\n// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem\n//       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max\n//\n#define FA_TYPES \\\n           half4,  \\\n           half4,  \\\n           half4,  \\\n    float,         \\\n    float, float4, \\\n           float4\n\n#define FA_TYPES_F32 \\\n           half4,  \\\n           float4, \\\n           float4, \\\n    float,         \\\n    float, float4, \\\n           float4\n\ntypedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk32_dv32\")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  32, 32, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk32_dv32\")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  32, 32, 4>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk32_dv32\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 32, 32, 4>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk32_dv32\")]]   kernel flash_attn_ext_vec_t 
kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 32, 32, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk32_dv32\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 32, 32, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk32_dv32\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 32, 32, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk32_dv32\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 32, 32, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk32_dv32\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 32, 32, 4>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk64_dv64\")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  64, 64, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk64_dv64\")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  64, 64, 2>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk64_dv64\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 64, 64, 2>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk64_dv64\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 64, 64, 2>;\ntemplate 
[[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk64_dv64\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 64, 64, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk64_dv64\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 64, 64, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk64_dv64\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 64, 64, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk64_dv64\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 64, 64, 2>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk96_dv96\")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  96, 96, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk96_dv96\")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  96, 96, 4>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk96_dv96\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 96, 96, 4>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk96_dv96\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 96, 96, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk96_dv96\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, 
block_q4_1,  8, dequantize_q4_1_t4, 96, 96, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk96_dv96\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 96, 96, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk96_dv96\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 96, 96, 4>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk96_dv96\")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 96, 96, 4>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk128_dv128\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  128, 128, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk128_dv128\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 1>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk128_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 128, 128, 1>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk128_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 128, 128, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk128_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 128, 128, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk128_dv128\")]] kernel flash_attn_ext_vec_t 
kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 128, 128, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk128_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 128, 128, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk128_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 128, 128, 1>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk192_dv192\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  192, 192, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk192_dv192\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 192, 2>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk192_dv192\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 192, 2>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk192_dv192\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 192, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk192_dv192\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 192, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk192_dv192\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 192, 2>;\ntemplate 
[[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk192_dv192\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 192, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk192_dv192\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 192, 2>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk192_dv128\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  192, 128, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk192_dv128\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 128, 2>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk192_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 128, 2>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk192_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 128, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk192_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 128, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk192_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 128, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk192_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, 
dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 128, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk192_dv128\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 128, 2>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk256_dv256\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  256, 256, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk256_dv256\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  256, 256, 1>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk256_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 256, 256, 1>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk256_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 256, 256, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk256_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 256, 256, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk256_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 256, 256, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk256_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 256, 256, 1>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk256_dv256\")]] kernel 
flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 256, 256, 1>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk320_dv256\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  320, 256, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk320_dv256\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  320, 256, 2>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk320_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 320, 256, 2>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk320_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 320, 256, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk320_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 320, 256, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk320_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 320, 256, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk320_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 320, 256, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk320_dv256\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 320, 256, 
2>;\n\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f32_dk576_dv512\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  576, 512, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_f16_dk576_dv512\")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  576, 512, 2>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_bf16_dk576_dv512\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 576, 512, 2>;\n#endif\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_0_dk576_dv512\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 576, 512, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q4_1_dk576_dv512\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 576, 512, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_0_dk576_dv512\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 576, 512, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q5_1_dk576_dv512\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 576, 512, 2>;\ntemplate [[host_name(\"kernel_flash_attn_ext_vec_q8_0_dk576_dv512\")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 576, 512, 2>;\n\n#undef FA_TYPES\n#undef FA_TYPES_F32\n\nconstant int32_t FC_flash_attn_ext_vec_reduce_DV  [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 
0)]];\nconstant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];\n\nkernel void kernel_flash_attn_ext_vec_reduce(\n        constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args,\n        device  const char * htmp,\n        device        char * dst,\n        uint   tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n#define NWG (FC_flash_attn_ext_vec_reduce_NWG)\n#define DV  (FC_flash_attn_ext_vec_reduce_DV)\n\n    const uint64_t rid = tgpig;\n\n    const short iwg = tiisg;\n\n    device const float  * ss    = (device const float  *) htmp + (uint64_t)args.nrows*DV*NWG;\n\n    float S = ss[rid*(2*NWG) + 2*iwg + 0];\n    float M = ss[rid*(2*NWG) + 2*iwg + 1];\n\n    const float m  = simd_max(M);\n    const float ms = exp(M - m);\n\n    S = simd_sum(S*ms);\n    S = S == 0.0f ? 0.0f : 1.0f/S;\n\n    const short DV4 = DV/4;\n\n    device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG;\n    device       float4 * dst4  = (device       float4 *) dst  + rid*DV4;\n\n    for (short i = sgitg; i < DV4; i += NWG) {\n        const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms);\n\n        if (iwg == 0) {\n            dst4[i] = v*S;\n        }\n    }\n\n#undef NWG\n#undef DV\n}\n\ntemplate<typename T0, typename T1>\nkernel void kernel_cpy_t_t(\n        constant ggml_metal_kargs_cpy & args,\n        device  const char * src0,\n        device        char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiitg[[thread_index_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int i03 = tgpig[2];\n    const int i02 = tgpig[1];\n    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];\n    const int iw0 = ntg[1] == 1 ? 
tgpig[0]/args.ne01 : 0;\n\n    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;\n\n    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);\n    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);\n    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;\n    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);\n\n    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);\n\n    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {\n        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);\n        dst_data[i00] = (T1) src[0];\n        break;\n    }\n}\n\ntypedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;\n\ntemplate [[host_name(\"kernel_cpy_f32_f32\")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   float>;\ntemplate [[host_name(\"kernel_cpy_f32_f16\")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   half>;\ntemplate [[host_name(\"kernel_cpy_f32_i32\")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   int32_t>;\ntemplate [[host_name(\"kernel_cpy_i32_f32\")]]   kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;\ntemplate [[host_name(\"kernel_cpy_i32_i32\")]]   kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_cpy_f32_bf16\")]]  kernel kernel_cpy_t kernel_cpy_t_t<float,   bfloat>;\n#endif\ntemplate [[host_name(\"kernel_cpy_f16_f32\")]]   kernel kernel_cpy_t kernel_cpy_t_t<half,    float>;\ntemplate [[host_name(\"kernel_cpy_f16_f16\")]]   kernel kernel_cpy_t kernel_cpy_t_t<half,    half>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_cpy_bf16_f32\")]]  kernel kernel_cpy_t kernel_cpy_t_t<bfloat,  float>;\ntemplate [[host_name(\"kernel_cpy_bf16_bf16\")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat,  
bfloat>;\n#endif\n\ntemplate<short QK,\n         typename block_q,\n         void (*quantize_func)(device const float *, device block_q &)>\nkernel void kernel_cpy_f32_q(\n        constant ggml_metal_kargs_cpy & args,\n        device const char * src0,\n        device char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiitg[[thread_index_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int i03 = tgpig[2];\n    const int i02 = tgpig[1];\n    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];\n    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;\n\n    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;\n\n    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);\n    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);\n    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;\n    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;\n\n    device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);\n\n    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {\n        device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);\n\n        quantize_func(src, dst_data[i00]);\n\n        break;\n    }\n}\n\ntypedef decltype(kernel_cpy_f32_q<QK8_0,  block_q8_0,  quantize_q8_0>)  cpy_f_q_t;\n\ntemplate [[host_name(\"kernel_cpy_f32_q8_0\")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0,  block_q8_0,   quantize_q8_0>;\ntemplate [[host_name(\"kernel_cpy_f32_q4_0\")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0,  block_q4_0,   quantize_q4_0>;\ntemplate [[host_name(\"kernel_cpy_f32_q4_1\")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1,  block_q4_1,   quantize_q4_1>;\ntemplate [[host_name(\"kernel_cpy_f32_q5_0\")]]  
 kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0,  block_q5_0,   quantize_q5_0>;\ntemplate [[host_name(\"kernel_cpy_f32_q5_1\")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1,  block_q5_1,   quantize_q5_1>;\ntemplate [[host_name(\"kernel_cpy_f32_iq4_nl\")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;\n\ntemplate<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>\nkernel void kernel_cpy_q_f32(\n        constant ggml_metal_kargs_cpy & args,\n        device  const char * src0,\n        device        char * dst,\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort  tiitg[[thread_index_in_threadgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const int i03 = tgpig[2];\n    const int i02 = tgpig[1];\n    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];\n    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;\n\n    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;\n\n    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);\n    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);\n    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;\n    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);\n\n    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);\n    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);\n\n    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {\n        T4x4 temp;\n        dequantize_func(src_data + i00/nl, i00%nl, temp);\n        dst_data[i00] = temp;\n\n        break;\n    }\n}\n\ntypedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;\n\ntemplate 
[[host_name(\"kernel_cpy_q4_0_f32\")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;\ntemplate [[host_name(\"kernel_cpy_q4_1_f32\")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;\ntemplate [[host_name(\"kernel_cpy_q5_0_f32\")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;\ntemplate [[host_name(\"kernel_cpy_q5_1_f32\")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;\ntemplate [[host_name(\"kernel_cpy_q8_0_f32\")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;\n\ntemplate [[host_name(\"kernel_cpy_q4_0_f16\")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;\ntemplate [[host_name(\"kernel_cpy_q4_1_f16\")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;\ntemplate [[host_name(\"kernel_cpy_q5_0_f16\")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;\ntemplate [[host_name(\"kernel_cpy_q5_1_f16\")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;\ntemplate [[host_name(\"kernel_cpy_q8_0_f16\")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;\n\nkernel void kernel_concat(\n    constant ggml_metal_kargs_concat & args,\n    device  const char * src0,\n    device  const char * src1,\n    device        char * dst,\n    uint3   tgpig[[threadgroup_position_in_grid]],\n    ushort3 tpitg[[thread_position_in_threadgroup]],\n    ushort3   ntg[[threads_per_threadgroup]]) {\n\n    const int i3 = tgpig.z;\n    const int i2 = tgpig.y;\n    const int i1 = tgpig.x;\n\n    int o[4] = {0, 0, 0, 0};\n    o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? 
args.ne02 : args.ne03));\n\n    device const float * x;\n\n    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {\n        if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {\n            x = (device const float *)(src0 + (i3       )*args.nb03 + (i2       )*args.nb02 + (i1       )*args.nb01 + (i0       )*args.nb00);\n        } else {\n            x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);\n        }\n\n        device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);\n\n        *y = *x;\n    }\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_q2_K_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);\n    device const float      * y = (device const float      *) (src1 + offset1);\n\n    float yl[32];\n    float sumf[nr0]={0.f};\n\n    const short ix = tiisg/8;  // 0...3\n    const short it = tiisg%8;  // 0...7\n    const short iq = it/4;     // 0 or 1\n    const short ir = it%4;     // 0...3\n    const short is = (8*ir)/16;// 0 or 1\n\n    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;\n\n    for (int ib = ix; ib < nb; ib += 4) {\n 
       float4 sumy = {0.f, 0.f, 0.f, 0.f};\n        for (short i = 0; i < 8; ++i) {\n            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];\n            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];\n            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];\n            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];\n        }\n\n        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;\n        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;\n        device const half     * dh = &x[ib].d;\n\n        for (short row = 0; row < nr0; row++) {\n            float4 acc1 = {0.f, 0.f, 0.f, 0.f};\n            float4 acc2 = {0.f, 0.f, 0.f, 0.f};\n            for (int i = 0; i < 8; i += 2) {\n                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);\n                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);\n                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);\n                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);\n                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);\n                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);\n                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);\n                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);\n            }\n            float dall = dh[0];\n            float dmin = dh[1] * 1.f/16.f;\n            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +\n                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +\n                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +\n                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -\n                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));\n\n            qs += args.nb01/2;\n            sc += args.nb01;\n            dh += args.nb01/2;\n        }\n\n        y4 += 4 * QK_K;\n    }\n\n    device float * dst_f32 
= (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_q2_K_f32\")]]\nkernel void kernel_mul_mv_q2_K_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_q3_K_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);\n    device const float     * yy = (device const float      *) (src1 + offset1);\n\n    float yl[32];\n\n    //const uint16_t kmask1 = 0x3030;\n    //const uint16_t kmask2 = 0x0f0f;\n\n    const short tid = tiisg/4;\n    const short ix  = tiisg%4;\n    const short ip  
= tid/4;          // 0 or 1\n    const short il  = 2*((tid%4)/2);  // 0 or 2\n    const short ir  = tid%2;\n    const short l0  = 8*ir;\n\n    // One would think that the Metal compiler would figure out that ip and il can only have\n    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it\n    // with these two tales.\n    //\n    // Possible masks for the high bit\n    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0\n                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2\n                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0\n                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2\n\n    // Possible masks for the low 2 bits\n    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};\n\n    const ushort4 hm = mm[2*ip + il/2];\n\n    const short shift = 2*il;\n\n    const float v1 = il == 0 ? 4.f : 64.f;\n    const float v2 = 4.f * v1;\n\n    const uint16_t s_shift1 = 4*ip;\n    const uint16_t s_shift2 = s_shift1 + il;\n\n    const short q_offset = 32*ip + l0;\n    const short y_offset = 128*ip + 32*il + l0;\n\n    device const float * y1 = yy + ix*QK_K + y_offset;\n\n    uint32_t scales32, aux32;\n    thread uint16_t * scales16 = (thread uint16_t *)&scales32;\n    thread const int8_t * scales = (thread const int8_t *)&scales32;\n\n    float sumf1[nr0] = {0.f};\n    float sumf2[nr0] = {0.f};\n\n    for (int i = ix; i < nb; i += 4) {\n        for (short l = 0; l < 8; ++l) {\n            yl[l+ 0] = y1[l+ 0];\n            yl[l+ 8] = y1[l+16];\n            yl[l+16] = y1[l+32];\n            yl[l+24] = y1[l+48];\n        }\n\n        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);\n        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);\n        device const uint16_t * a = (device const uint16_t *)(x[i].scales);\n        device const half * dh = 
&x[i].d;\n\n        for (short row = 0; row < nr0; ++row) {\n            const float d_all = (float)dh[0];\n\n            scales16[0] = a[4];\n            scales16[1] = a[5];\n            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;\n            scales16[0] = a[il+0];\n            scales16[1] = a[il+1];\n            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;\n\n            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;\n            for (short l = 0; l < 8; l += 2) {\n                const int32_t qs = q[l/2];\n                s1 += yl[l+0] * (qs & qm[il/2][0]);\n                s2 += yl[l+1] * (qs & qm[il/2][1]);\n                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);\n                s4 += yl[l+16] * (qs & qm[il/2][2]);\n                s5 += yl[l+17] * (qs & qm[il/2][3]);\n                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);\n            }\n            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);\n            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);\n            sumf1[row] += d1 * (scales[0] - 32);\n            sumf2[row] += d2 * (scales[2] - 32);\n\n            s1 = s2 = s3 = s4 = s5 = s6 = 0;\n            for (short l = 0; l < 8; l += 2) {\n                const int32_t qs = q[l/2+8];\n                s1 += yl[l+8] * (qs & qm[il/2][0]);\n                s2 += yl[l+9] * (qs & qm[il/2][1]);\n                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);\n                s4 += yl[l+24] * (qs & qm[il/2][2]);\n                s5 += yl[l+25] * (qs & qm[il/2][3]);\n                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 
0.f : yl[l+25]);\n            }\n            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);\n            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);\n            sumf1[row] += d1 * (scales[1] - 32);\n            sumf2[row] += d2 * (scales[3] - 32);\n\n            q  += args.nb01/2;\n            h  += args.nb01/2;\n            a  += args.nb01/2;\n            dh += args.nb01/2;\n        }\n\n        y1 += 4 * QK_K;\n    }\n\n    for (int row = 0; row < nr0; ++row) {\n        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);\n        sumf1[row] = simd_sum(sumf);\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    if (tiisg == 0) {\n        for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n            dst_f32[first_row + row] = sumf1[row];\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_q3_K_f32\")]]\nkernel void kernel_mul_mv_q3_K_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_q4_K_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    constexpr uint16_t kmask1 = 0x3f3f;\n    constexpr uint16_t kmask2 = 0x0f0f;\n    constexpr uint16_t kmask3 = 0xc0c0;\n\n    const short ix = tiisg/8;  // 0...3\n    const short it = tiisg%8;  // 0...7\n    const short iq = it/4;    
 // 0 or 1\n    const short ir = it%4;     // 0...3\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);\n    device const float      * y = (device const float      *) (src1 + offset1);\n\n    float yl[16];\n    float yh[16];\n\n    float sumf[nr0]={0.f};\n\n    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;\n\n    uint16_t sc16[4];\n    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;\n\n    for (int ib = ix; ib < nb; ib += 4) {\n        float4 sumy = {0.f, 0.f, 0.f, 0.f};\n\n        for (short i = 0; i < 8; ++i) {\n            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];\n            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];\n            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];\n            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];\n        }\n\n        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;\n        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;\n        device const half     * dh = &x[ib].d;\n\n        for (short row = 0; row < nr0; row++) {\n            sc16[0] = sc[0] & kmask1;\n            sc16[1] = sc[2] & kmask1;\n            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);\n            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);\n\n            device const uint16_t * q2 = q1 + 32;\n\n            float4 acc1 = {0.f, 0.f, 0.f, 0.f};\n            float4 acc2 = {0.f, 0.f, 0.f, 0.f};\n\n            FOR_UNROLL (short i = 0; i < 4; ++i) {\n                acc1[0] 
+= yl[2*i + 0] * (q1[i] & 0x000F);\n                acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);\n                acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);\n                acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);\n                acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);\n                acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);\n                acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);\n                acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);\n            }\n\n            sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +\n                                  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +\n                                  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +\n                                  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -\n                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);\n\n            q1 += args.nb01/2;\n            sc += args.nb01/2;\n            dh += args.nb01/2;\n        }\n\n        y4 += 4 * QK_K;\n    }\n\n    device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_q4_K_f32\")]]\nkernel void kernel_mul_mv_q4_K_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid 
kernel_mul_mv_q5_K_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);\n    device const float     * yy = (device const float      *) (src1 + offset1);\n\n    float sumf[nr0]={0.f};\n\n    float yl[16], yh[16];\n\n    constexpr uint16_t kmask1 = 0x3f3f;\n    constexpr uint16_t kmask2 = 0x0f0f;\n    constexpr uint16_t kmask3 = 0xc0c0;\n\n    const short tid = tiisg/4;\n    const short ix  = tiisg%4;\n    const short iq  = tid/4;\n    const short ir  = tid%4;\n\n    const short l0 = 8*ir;\n    const short q_offset = 32*iq + l0;\n    const short y_offset = 64*iq + l0;\n\n    const uint8_t hm1 = 1u << (2*iq);\n    const uint8_t hm2 = hm1 << 1;\n    const uint8_t hm3 = hm1 << 4;\n    const uint8_t hm4 = hm2 << 4;\n\n    uint16_t sc16[4];\n    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;\n\n    device const float * y1 = yy + ix*QK_K + y_offset;\n\n    for (int i = ix; i < nb; i += 4) {\n        device const uint8_t * q1 = x[i].qs + q_offset;\n        device const uint8_t * qh = x[i].qh + l0;\n        device const half * dh = &x[i].d;\n        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;\n\n        device const float * y2 = y1 + 128;\n        float4 sumy = {0.f, 0.f, 0.f, 
0.f};\n        for (short l = 0; l < 8; ++l) {\n            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];\n            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];\n            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];\n            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];\n        }\n\n        for (short row = 0; row < nr0; ++row) {\n            device const uint8_t * q2 = q1 + 64;\n\n            sc16[0] = a[0] & kmask1;\n            sc16[1] = a[2] & kmask1;\n            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);\n            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);\n\n            float4 acc1 = {0.f};\n            float4 acc2 = {0.f};\n            FOR_UNROLL (short l = 0; l < 8; ++l) {\n                uint8_t h = qh[l];\n                acc1[0] += yl[l+0] * (q1[l] & 0x0F);\n                acc1[1] += yl[l+8] * (q1[l] & 0xF0);\n                acc1[2] += yh[l+0] * (q2[l] & 0x0F);\n                acc1[3] += yh[l+8] * (q2[l] & 0xF0);\n                acc2[0] += h & hm1 ? yl[l+0] : 0.f;\n                acc2[1] += h & hm2 ? yl[l+8] : 0.f;\n                acc2[2] += h & hm3 ? yh[l+0] : 0.f;\n                acc2[3] += h & hm4 ? 
yh[l+8] : 0.f;\n            }\n\n            sumf[row] += dh[0] * (sc8[0] * (acc1[0]      + 16.f*acc2[0]) +\n                                  sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +\n                                  sc8[4] * (acc1[2]      + 16.f*acc2[2]) +\n                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -\n                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);\n\n            q1 += args.nb01;\n            qh += args.nb01;\n            dh += args.nb01/2;\n            a  += args.nb01/2;\n        }\n\n        y1 += 4 * QK_K;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        const float tot = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = tot;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_q5_K_f32\")]]\nkernel void kernel_mul_mv_q5_K_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_q6_K_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    constexpr uint8_t kmask1 = 0x03;\n    constexpr uint8_t kmask2 = 0x0C;\n    constexpr uint8_t kmask3 = 0x30;\n    constexpr uint8_t kmask4 = 0xC0;\n\n    const 
int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);\n    device const float     * yy = (device const float      *) (src1 + offset1);\n\n    float sumf[nr0] = { 0.f };\n\n    float yl[16];\n\n    const short tid = tiisg/2;\n    const short ix  = tiisg%2;\n    const short ip  = tid/8;         // 0 or 1\n    const short il  = tid%8;\n    const short l0  = 4*il;\n    const short is  = 8*ip + l0/16;\n\n    const short y_offset   = 128*ip + l0;\n    const short q_offset_l =  64*ip + l0;\n    const short q_offset_h =  32*ip + l0;\n\n    for (int i = ix; i < nb; i += 2) {\n        device const uint8_t * q1 = x[i].ql + q_offset_l;\n        device const uint8_t * q2 = q1 + 32;\n        device const uint8_t * qh = x[i].qh + q_offset_h;\n        device const int8_t  * sc = x[i].scales + is;\n        device const half    * dh = &x[i].d;\n\n        device const float * y = yy + i * QK_K + y_offset;\n\n        for (short l = 0; l < 4; ++l) {\n            yl[4*l + 0] = y[l +  0];\n            yl[4*l + 1] = y[l + 32];\n            yl[4*l + 2] = y[l + 64];\n            yl[4*l + 3] = y[l + 96];\n        }\n\n        for (short row = 0; row < nr0; ++row) {\n            float4 sums = {0.f, 0.f, 0.f, 0.f};\n\n            FOR_UNROLL (short l = 0; l < 4; ++l) {\n                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);\n                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);\n                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l]  
>> 4) | ((qh[l] & kmask3) << 0)) - 32);\n                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);\n            }\n\n            sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);\n\n            q1 += args.nb01;\n            q2 += args.nb01;\n            qh += args.nb01;\n            sc += args.nb01;\n            dh += args.nb01/2;\n        }\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_q6_K_f32\")]]\nkernel void kernel_mul_mv_q6_K_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);\n}\n\n// ======================= \"True\" 2-bit\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_iq2_xxs_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = 
first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);\n    device const float         * y = (device const float         *) (src1 + offset1);\n\n    float yl[32];\n    float sumf[nr0]={0.f};\n\n    const int nb32 = nb * (QK_K / 32);\n\n    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);\n    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);\n    {\n        int nval = 4;\n        int pos  = (32*sgitg + tiisg)*nval;\n        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];\n        nval = 2;\n        pos  = (32*sgitg + tiisg)*nval;\n        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n    }\n\n    const int ix = tiisg;\n\n    device const float * y4 = y + 32 * ix;\n\n    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {\n        for (short i = 0; i < 32; ++i) {\n            yl[i] = y4[i];\n        }\n\n        const int ibl = ib32 / (QK_K / 32);\n        const int ib  = ib32 % (QK_K / 32);\n\n        device const block_iq2_xxs * xr = x + ibl;\n        device const uint16_t * q2 = xr->qs + 4 * ib;\n        device const half * dh = &xr->d;\n\n        for (short row = 0; row < nr0; row++) {\n            const float db = dh[0];\n            device const uint8_t * aux8 = (device const uint8_t *)q2;\n            const uint32_t aux32 = q2[2] | (q2[3] << 16);\n            const float d = db * (0.5f + (aux32 >> 28));\n\n            float sum = 0;\n            for (short l = 0; l < 4; ++l) {\n                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);\n                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];\n                for (short j = 0; j < 8; ++j) {\n                    
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);\n                }\n            }\n            sumf[row] += d * sum;\n\n            dh += args.nb01/2;\n            q2 += args.nb01/2;\n        }\n\n        y4 += 32 * 32;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all * 0.25f;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq2_xxs_f32\")]]\nkernel void kernel_mul_mv_iq2_xxs_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_iq2_xs_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        
)*args.nb13;\n\n    device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);\n    device const float        * y = (device const float        *) (src1 + offset1);\n\n    float yl[32];\n    float sumf[nr0]={0.f};\n\n    const int nb32 = nb * (QK_K / 32);\n\n    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);\n    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 512);\n    {\n        int nval = 8;\n        int pos  = (32*sgitg + tiisg)*nval;\n        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];\n        nval = 2;\n        pos  = (32*sgitg + tiisg)*nval;\n        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n    }\n\n    const int ix = tiisg;\n\n    device const float * y4 = y + 32 * ix;\n\n    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {\n        for (short i = 0; i < 32; ++i) {\n            yl[i] = y4[i];\n        }\n\n        const int ibl = ib32 / (QK_K / 32);\n        const int ib  = ib32 % (QK_K / 32);\n\n        device const block_iq2_xs * xr = x + ibl;\n        device const uint16_t * q2 = xr->qs + 4 * ib;\n        device const uint8_t  * sc = xr->scales + ib;\n        device const half * dh = &xr->d;\n\n        for (short row = 0; row < nr0; row++) {\n            const float db = dh[0];\n            const uint8_t ls1 = sc[0] & 0xf;\n            const uint8_t ls2 = sc[0] >>  4;\n            const float d1 = db * (0.5f + ls1);\n            const float d2 = db * (0.5f + ls2);\n\n            float sum1 = 0, sum2 = 0;\n            for (short l = 0; l < 2; ++l) {\n                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));\n                const uint8_t signs = ssigns[(q2[l] >> 9)];\n                for (short j = 0; j < 8; ++j) {\n                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f);\n                }\n            }\n            for (short l = 2; l < 4; ++l) {\n                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));\n                const uint8_t signs = ssigns[(q2[l] >> 9)];\n                for (short j = 0; j < 8; ++j) {\n                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);\n                }\n            }\n            sumf[row] += d1 * sum1 + d2 * sum2;\n\n            dh += args.nb01/2;\n            q2 += args.nb01/2;\n            sc += args.nb01;\n        }\n\n        y4 += 32 * 32;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all * 0.25f;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq2_xs_f32\")]]\nkernel void kernel_mul_mv_iq2_xs_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_iq3_xxs_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = 
tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);\n    device const float         * y = (device const float         *) (src1 + offset1);\n\n    float yl[32];\n    float sumf[nr0]={0.f};\n\n    const int nb32 = nb * (QK_K / 32);\n\n    threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);\n    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);\n    {\n        int nval = 4;\n        int pos  = (32*sgitg + tiisg)*nval;\n        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];\n        nval = 2;\n        pos  = (32*sgitg + tiisg)*nval;\n        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n    }\n\n    const int ix = tiisg;\n\n    device const float * y4 = y + 32 * ix;\n\n    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {\n        for (short i = 0; i < 32; ++i) {\n            yl[i] = y4[i];\n        }\n\n        const int ibl = ib32 / (QK_K / 32);\n        const int ib  = ib32 % (QK_K / 32);\n\n        device const block_iq3_xxs * xr = x + ibl;\n        device const uint8_t  * q3 = xr->qs + 8 * ib;\n        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;\n        device const half * dh = &xr->d;\n\n        for (short row = 0; row < nr0; row++) {\n            const float db = dh[0];\n            const uint32_t aux32 = gas[0] | (gas[1] << 16);\n            const float d = db * (0.5f + (aux32 >> 28));\n\n            float2 sum = {0};\n            for (short l = 0; l < 4; ++l) {\n                
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);\n                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);\n                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];\n                for (short j = 0; j < 4; ++j) {\n                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);\n                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);\n                }\n            }\n            sumf[row] += d * (sum[0] + sum[1]);\n\n            dh  += args.nb01/2;\n            q3  += args.nb01;\n            gas += args.nb01/2;\n        }\n\n        y4 += 32 * 32;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all * 0.5f;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq3_xxs_f32\")]]\nkernel void kernel_mul_mv_iq3_xxs_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_iq3_s_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const 
short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);\n    device const float       * y = (device const float       *) (src1 + offset1);\n\n    float yl[32];\n    float sumf[nr0]={0.f};\n\n    const int nb32 = nb * (QK_K / 32);\n\n    threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;\n    {\n        int nval = 8;\n        int pos  = (32*sgitg + tiisg)*nval;\n        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n    }\n\n    const int ix = tiisg;\n\n    device const float * y4 = y + 32 * ix;\n\n    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {\n        for (short i = 0; i < 32; ++i) {\n            yl[i] = y4[i];\n        }\n\n        const int ibl = ib32 / (QK_K / 32);\n        const int ib  = ib32 % (QK_K / 32);\n\n        device const block_iq3_s * xr = x + ibl;\n        device const uint8_t * qs = xr->qs + 8 * ib;\n        device const uint8_t * qh = xr->qh + ib;\n        device const uint8_t * sc = xr->scales + (ib/2);\n        device const uint8_t * signs = xr->signs + 4 * ib;\n        device const half * dh = &xr->d;\n\n        for (short row = 0; row < nr0; row++) {\n            const float db = dh[0];\n            const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));\n\n            float2 sum = {0};\n            for (short l = 0; l < 4; ++l) {\n                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? 
svalues + 256 : svalues;\n                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;\n                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);\n                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);\n                for (short j = 0; j < 4; ++j) {\n                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);\n                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);\n                }\n            }\n            sumf[row] += d * (sum[0] + sum[1]);\n\n            dh    += args.nb01/2;\n            qs    += args.nb01;\n            qh    += args.nb01;\n            sc    += args.nb01;\n            signs += args.nb01;\n        }\n\n        y4 += 32 * 32;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq3_s_f32\")]]\nkernel void kernel_mul_mv_iq3_s_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_iq2_s_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        
device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);\n    device const float       * y = (device const float       *) (src1 + offset1);\n\n    float yl[32];\n    float sumf[nr0]={0.f};\n\n    const int nb32 = nb * (QK_K / 32);\n\n    //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;\n    //{\n    //    int nval = 32;\n    //    int pos  = (32*sgitg + tiisg)*nval;\n    //    for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];\n    //    threadgroup_barrier(mem_flags::mem_threadgroup);\n    //}\n\n    const short ix = tiisg;\n\n    device const float * y4 = y + 32 * ix;\n\n    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {\n        for (short i = 0; i < 32; ++i) {\n            yl[i] = y4[i];\n        }\n\n        const int ibl = ib32 / (QK_K / 32);\n        const int ib  = ib32 % (QK_K / 32);\n\n        device const block_iq2_s * xr = x + ibl;\n        device const uint8_t * qs = xr->qs + 4 * ib;\n        device const uint8_t * qh = xr->qh + ib;\n        device const uint8_t * sc = xr->scales + ib;\n        device const uint8_t * signs = qs + QK_K/8;\n        device const half * dh = &xr->d;\n\n        for (short row = 0; row < nr0; row++) {\n            const float db = dh[0];\n            const float d1 = db * (0.5f + (sc[0] & 0xf));\n            const float d2 = db * (0.5f + 
(sc[0] >>  4));\n\n            float2 sum = {0};\n            for (short l = 0; l < 2; ++l) {\n                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));\n                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));\n                constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));\n                constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));\n                for (short j = 0; j < 8; ++j) {\n                    sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);\n                    sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);\n                }\n            }\n            sumf[row] += d1 * sum[0] + d2 * sum[1];\n\n            dh    += args.nb01/2;\n            qs    += args.nb01;\n            qh    += args.nb01;\n            sc    += args.nb01;\n            signs += args.nb01;\n        }\n\n        y4 += 32 * 32;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all * 0.25f;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq2_s_f32\")]]\nkernel void kernel_mul_mv_iq2_s_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_iq1_s_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);\n    device const float       * y = (device const float       *) (src1 + offset1);\n\n    float yl[32];\n    float sumf[nr0]={0.f};\n\n    const int nb32 = nb * (QK_K / 32);\n\n    const short ix = tiisg;\n\n    device const float * y4 = y + 32 * ix;\n\n    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {\n        float sumy = 0;\n        for (short i = 0; i < 32; ++i) {\n            yl[i] = y4[i];\n            sumy += yl[i];\n        }\n\n        const int ibl = ib32 / (QK_K / 32);\n        const int ib  = ib32 % (QK_K / 32);\n\n        device const block_iq1_s * xr = x + ibl;\n        device const uint8_t  * qs = xr->qs + 4 * ib;\n        device const uint16_t * qh = xr->qh + ib;\n        device const half     * dh = &xr->d;\n\n        for (short row = 0; row < nr0; row++) {\n            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));\n            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu 
+ (qs[1] | ((qh[0] << 5) & 0x700)));\n            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));\n            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));\n\n            float sum = 0;\n            for (short j = 0; j < 4; ++j) {\n                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)\n                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)\n                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)\n                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);\n            }\n            sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);\n\n            dh += args.nb01/2;\n            qs += args.nb01;\n            qh += args.nb01/2;\n        }\n\n        y4 += 32 * 32;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq1_s_f32\")]]\nkernel void kernel_mul_mv_iq1_s_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);\n}\n\ntemplate<int nr0, typename args_t>\nvoid kernel_mul_mv_iq1_m_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n   
     device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    const int nb = args.ne00/QK_K;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * nr0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);\n    device const float       * y = (device const float       *) (src1 + offset1);\n\n    float yl[32];\n    float sumf[nr0]={0.f};\n\n    const int nb32 = nb * (QK_K / 32);\n\n    const short ix = tiisg;\n\n    device const float * y4 = y + 32 * ix;\n\n    iq1m_scale_t scale;\n\n    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {\n        float4 sumy = {0.f};\n        for (short i = 0; i < 8; ++i) {\n            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];\n            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];\n            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];\n            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];\n        }\n\n        const int ibl = ib32 / (QK_K / 32);\n        const int ib  = ib32 % (QK_K / 32);\n\n        device const block_iq1_m * xr = x + ibl;\n        device const uint8_t  * qs = xr->qs + 4 * ib;\n        device const uint8_t  * qh = xr->qh + 2 * ib;\n        device const uint16_t * sc = (device const uint16_t *)xr->scales;\n\n        for (short row = 0; row < nr0; row++) {\n            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n\n            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));\n            constant uint8_t * 
grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));\n            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));\n            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));\n\n            float2 sum = {0.f};\n            for (short j = 0; j < 4; ++j) {\n                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)\n                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);\n                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)\n                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);\n            }\n            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);\n            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA);\n\n            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +\n                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));\n\n            sc += args.nb01/2;\n            qs += args.nb01;\n            qh += args.nb01;\n        }\n\n        y4 += 32 * 32;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq1_m_f32\")]]\nkernel void kernel_mul_mv_iq1_m_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);\n}\n\ntemplate<int NR0, typename args_t>\nvoid kernel_mul_mv_iq4_nl_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    threadgroup float * shmem_f32 = (threadgroup float *) shmem;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * NR0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + 
(i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;\n\n    device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);\n    device const float        * y = (device const float        *) (src1 + offset1);\n\n    const int nb   = args.ne00/QK4_NL;\n    const int ns01 = args.nb01/args.nb00;\n\n    const short ix = tiisg/2;  // 0...15\n    const short it = tiisg%2;  // 0 or 1\n\n    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    float4 yl[4];\n    float sumf[NR0]={0.f};\n\n    device const float * yb = y + ix*QK4_NL + it*8;\n\n    uint32_t aux32[2];\n    thread const uint8_t * q8 = (thread const uint8_t *)aux32;\n\n    float4 qf1, qf2;\n\n    // [TAG_MUL_MV_WEIRD]\n    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {\n        device const float4 * y4 = (device const float4 *)yb;\n        yl[0] = y4[0];\n        yl[1] = y4[4];\n        yl[2] = y4[1];\n        yl[3] = y4[5];\n\n        for (short row = 0; row < NR0; row++) {\n            device const block_iq4_nl & xb = x[row*ns01 + ib];\n            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);\n\n            float4 acc1 = {0.f}, acc2 = {0.f};\n\n            aux32[0] = q4[0] | (q4[1] << 16);\n            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;\n            aux32[0] &= 0x0f0f0f0f;\n            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};\n            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};\n            acc1 += yl[0] * qf1;\n            acc2 += yl[1] * qf2;\n\n            aux32[0] = q4[2] | (q4[3] << 16);\n            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;\n            aux32[0] &= 0x0f0f0f0f;\n            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};\n            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], 
shmem_f32[q8[7]]};\n            acc1 += yl[2] * qf1;\n            acc2 += yl[3] * qf2;\n\n            acc1 += acc2;\n\n            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);\n        }\n\n        yb += 16 * QK4_NL;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq4_nl_f32\")]]\nkernel void kernel_mul_mv_iq4_nl_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntemplate<int NR0, typename args_t>\nvoid kernel_mul_mv_iq4_xs_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    threadgroup float * shmem_f32 = (threadgroup float *) shmem;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n    const int first_row = (r0 * NSG + sgitg) * NR0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13     
   )*args.nb13;\n\n    device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);\n    device const float        * y = (device const float        *) (src1 + offset1);\n\n    const int nb   = args.ne00/QK_K;\n    const int ns01 = args.nb01/args.nb00;\n\n    const short ix = tiisg/16;  // 0 or 1\n    const short it = tiisg%16;  // 0...15\n    const short ib = it/2;\n    const short il = it%2;\n\n    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    float4 yl[4];\n    float sumf[NR0]={0.f};\n\n    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;\n\n    uint32_t aux32[2];\n    thread const uint8_t * q8 = (thread const uint8_t *)aux32;\n\n    float4 qf1, qf2;\n\n    // [TAG_MUL_MV_WEIRD]\n    for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {\n        device const float4 * y4 = (device const float4 *)yb;\n        yl[0] = y4[0];\n        yl[1] = y4[4];\n        yl[2] = y4[1];\n        yl[3] = y4[5];\n\n        for (short row = 0; row < NR0; ++row) {\n            device const block_iq4_xs & xb = x[row*ns01 + ibl];\n            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);\n\n            float4 acc1 = {0.f}, acc2 = {0.f};\n\n            aux32[0] = (q4[0]     ) & 0x0f0f0f0f;\n            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;\n            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};\n            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};\n            acc1 += yl[0] * qf1;\n            acc2 += yl[1] * qf2;\n\n            aux32[0] = (q4[1]     ) & 0x0f0f0f0f;\n            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;\n            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};\n            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};\n            acc1 += yl[2] * qf1;\n            acc2 += yl[3] * qf2;\n\n            acc1 += 
acc2;\n\n            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;\n            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);\n        }\n\n        yb += 2 * QK_K;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_iq4_xs_f32\")]]\nkernel void kernel_mul_mv_iq4_xs_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntemplate<int NR0, typename args_t>\nvoid kernel_mul_mv_mxfp4_f32_impl(\n        args_t args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg) {\n    const short NSG = FC_mul_mv_nsg;\n\n    threadgroup float * shmem_f32 = (threadgroup float *) shmem;\n\n    const int r0 = tgpig.x;\n    const int r1 = tgpig.y;\n    const int im = tgpig.z;\n\n    const int first_row = (r0 * NSG + sgitg) * NR0;\n\n    const uint i12 = im%args.ne12;\n    const uint i13 = im/args.ne12;\n\n    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + 
(i13        )*args.nb13;\n\n    device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);\n    device const float       * y = (device const float       *) (src1 + offset1);\n\n    const int nb   = args.ne00/QK_MXFP4;\n    const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors\n\n    const short ix = tiisg/2;  // 0...15\n    const short it = tiisg%2;  // 0 or 1\n\n    shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    float4 yl[4];\n    float sumf[NR0]={0.f};\n\n    device const float * yb = y + ix*QK_MXFP4 + it*8;\n\n    // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster\n    //       no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]\n    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {\n        device const float4 * y4 = (device const float4 *) yb;\n\n        yl[0] = y4[0];\n        yl[1] = y4[4];\n        yl[2] = y4[1];\n        yl[3] = y4[5];\n\n        FOR_UNROLL (short row = 0; row < NR0; row++) {\n            device const block_mxfp4 & xb = x[row*ns01 + ib];\n            device const uint8_t     * q2 = (device const uint8_t *)(xb.qs + 8*it);\n\n            float4 acc1 = yl[0]*float4(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);\n            float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);\n            float4 acc3 = yl[2]*float4(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  0x0F], shmem_f32[q2[7] &  0x0F]);\n            float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);\n\n            acc1 = (acc1 + acc3) + (acc2 + acc4);\n\n            sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + 
acc1[3]));\n        }\n\n        yb += 16 * QK_MXFP4;\n    }\n\n    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;\n\n    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {\n        float sum_all = simd_sum(sumf[row]);\n        if (tiisg == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n[[host_name(\"kernel_mul_mv_mxfp4_f32\")]]\nkernel void kernel_mul_mv_mxfp4_f32(\n        constant ggml_metal_kargs_mul_mv & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntemplate<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>\nkernel void kernel_get_rows_q(\n        constant ggml_metal_kargs_get_rows & args,\n        device const void * src0,\n        device const void * src1,\n        device       void * dst,\n        uint3               tgpig[[threadgroup_position_in_grid]],\n        ushort              tiitg[[thread_index_in_threadgroup]],\n        ushort3             ntg  [[threads_per_threadgroup]]) {\n    const int32_t iw0 = tgpig.x/args.ne10;\n    const int32_t i10 = tgpig.x%args.ne10;\n    const int32_t i11 = tgpig.y;\n    const int32_t i12 = tgpig.z;\n\n    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];\n\n    const int32_t i02 = i11;\n    const int32_t i03 = i12;\n\n    auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 +   r*args.nb01);\n    auto pdst = (device      
float4x4 *) ((      device char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);\n\n    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {\n        float4x4 temp;\n        dequantize_func(psrc + ind/nl, ind%nl, temp);\n        pdst[ind] = temp;\n\n        break;\n    }\n}\n\ntemplate<typename T0, typename T>\nkernel void kernel_get_rows_f(\n        constant ggml_metal_kargs_get_rows & args,\n        device const void * src0,\n        device const void * src1,\n        device       void * dst,\n        uint3               tgpig[[threadgroup_position_in_grid]],\n        ushort              tiitg[[thread_index_in_threadgroup]],\n        ushort3             ntg [[threads_per_threadgroup]]) {\n    const int32_t iw0 = tgpig.x/args.ne10;\n    const int32_t i10 = tgpig.x%args.ne10;\n    const int32_t i11 = tgpig.y;\n    const int32_t i12 = tgpig.z;\n\n    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];\n\n    const int32_t i02 = i11;\n    const int32_t i03 = i12;\n\n    auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 +   r*args.nb01);\n    auto pdst = (      device T  *) ((      device char *)  dst + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);\n\n    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {\n        pdst[ind] = psrc[ind];\n\n        break;\n    }\n}\n\ntemplate<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>\nkernel void kernel_set_rows_q32(\n        constant ggml_metal_kargs_set_rows & args,\n        device const  void * src0,\n        device const  void * src1,\n        device       float * dst,\n        uint3                tgpig[[threadgroup_position_in_grid]],\n        uint                 tiitg[[thread_index_in_threadgroup]],\n        uint3                tptg [[threads_per_threadgroup]]) {\n    const int32_t i03 = tgpig.z;\n    const int32_t i02 = tgpig.y;\n\n    const 
int32_t i12 = i03%args.ne12;\n    const int32_t i11 = i02%args.ne11;\n\n    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;\n    if (i01 >= args.ne01) {\n        return;\n    }\n\n    const int32_t i10 = i01;\n    const TI      i1  = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];\n\n          device block_q * dst_row = (      device block_q *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);\n    const device float   * src_row = (const device float   *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);\n\n    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {\n        quantize_func(src_row + 32*ind, dst_row[ind]);\n    }\n}\n\ntemplate<typename T, typename TI>\nkernel void kernel_set_rows_f(\n        constant ggml_metal_kargs_set_rows & args,\n        device const  void * src0,\n        device const  void * src1,\n        device       float * dst,\n        uint3                tgpig[[threadgroup_position_in_grid]],\n        uint                 tiitg[[thread_index_in_threadgroup]],\n        uint3                tptg [[threads_per_threadgroup]]) {\n    const int32_t i03 = tgpig.z;\n    const int32_t i02 = tgpig.y;\n\n    const int32_t i12 = i03%args.ne12;\n    const int32_t i11 = i02%args.ne11;\n\n    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;\n    if (i01 >= args.ne01) {\n        return;\n    }\n\n    const int32_t i10 = i01;\n    const TI      i1  = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];\n\n          device T     * dst_row = (      device T     *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);\n    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);\n\n    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {\n        dst_row[ind] = (T) src_row[ind];\n   
 }\n}\n\nkernel void kernel_diag_f32(\n        constant ggml_metal_kargs_diag & args,\n        device   const char * src0,\n        device         char * dst,\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiitg[[thread_index_in_threadgroup]]) {\n    constexpr short NW = N_SIMDWIDTH;\n\n    const int32_t i3 = tgpig.z;\n    const int32_t i2 = tgpig.y;\n    const int32_t i1 = tgpig.x;\n\n    device const float * src0_ptr = (device const float *)(src0 +                i2*args.nb02 + i3*args.nb03);\n    device       float * dst_ptr  = (device       float *)(dst  + i1*args.nb01 + i2*args.nb2  + i3*args.nb3);\n\n    for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {\n        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;\n    }\n}\n\nconstant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];\nconstant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];\n\n// each block_q contains 16*nl weights\ntemplate<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>\nkernel void kernel_mul_mm(\n        constant ggml_metal_kargs_mul_mm & args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiitg[[thread_index_in_threadgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n\n    threadgroup S0 * sa = (threadgroup S0 *)(shmem);\n    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);\n\n#ifdef GGML_METAL_HAS_TENSOR\n    threadgroup float * sc = (threadgroup float *)(shmem);\n#endif\n\n    constexpr int NR0 = 64;\n    constexpr int NR1 = 32;\n\n    constexpr int NK  = 32;\n    constexpr int NL0 = NK/16;\n    constexpr int NL1 = NK/8;\n\n    const int im 
= tgpig.z;\n    const int r0 = tgpig.y*NR0;\n    const int r1 = tgpig.x*NR1;\n\n    // if this block is of 64x32 shape or smaller\n    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;\n    const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;\n\n    // a thread shouldn't load data outside of the matrix\n    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63\n    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31\n\n    const short il0 = (tiitg % NL0);\n\n    short il = il0;\n\n    const int i12 = im%args.ne12;\n    const int i13 = im/args.ne12;\n\n    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;\n    const short    offset1 = il0/nl;\n\n    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;\n\n    const short iy = 8*(tiitg % NL1);\n\n    device const T1 * y = (device const T1 *)(src1\n        + args.nb13*i13\n        + args.nb12*i12\n        + args.nb11*(r1 + lr1)\n        + args.nb10*iy);\n\n#ifndef GGML_METAL_HAS_TENSOR\n    S0_8x8 ma[4];\n    S1_8x8 mb[2];\n\n    simdgroup_float8x8 mc[8];\n\n    for (short i = 0; i < 8; i++){\n        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);\n    }\n#else\n    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK,  NR0));\n    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));\n\n    mpp::tensor_ops::matmul2d<\n        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),\n        execution_simdgroups<4>> mm;\n\n    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();\n#endif\n\n    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {\n#ifndef GGML_METAL_HAS_TENSOR\n        // load data and store to threadgroup memory\n        
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            // no need for dequantization\n            for (short i = 0; i < 16; i++) {\n                const short sx = 2*il0 + i/8;\n                const short sy = (tiitg/NL0)/8;\n\n              //const short lx = i%8;\n              //const short ly = (tiitg/NL0)%8;\n                const short lx = (tiitg/NL0)%8;\n                const short ly = i%8;\n\n                const short ib = 8*sx + sy;\n\n                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;\n            }\n        } else {\n            S0_4x4 temp_a;\n            dequantize_func(x, il, temp_a);\n\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            FOR_UNROLL (short i = 0; i < 16; i++) {\n                const short sx = 2*il0 + i/8;\n                const short sy = (tiitg/NL0)/8;\n\n              //const short lx = i%8;\n              //const short ly = (tiitg/NL0)%8;\n                const short lx = (tiitg/NL0)%8;\n                const short ly = i%8;\n\n                const short ib = 8*sx + sy;\n\n                // NOTE: this is massively slower.. WTF?\n                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];\n\n                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];\n            }\n        }\n\n        if (FC_mul_mm_bc_inp) {\n            for (short i = 0; i < 8; ++i) {\n                const short sx = (tiitg%NL1);\n                const short sy = (tiitg/NL1)/8;\n\n                const short lx = i;\n                const short ly = (tiitg/NL1)%8;\n              //const short lx = (tiitg/NL1)%8;\n              //const short ly = i;\n\n                const short ib = 4*sx + sy;\n\n                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? 
(S1) *((device T1 *) y + i) : 0;\n            }\n        } else {\n            const short sx = (tiitg%NL1);\n            const short sy = (tiitg/NL1)/8;\n\n          //const short dx = sx;\n          //const short dy = sy;\n\n            const short ly = (tiitg/NL1)%8;\n\n            const short ib = 4*sx + sy;\n\n            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));\n        }\n#else\n        // load data and store to threadgroup memory\n        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            // no need for dequantization\n            for (short i = 0; i < 16; i++) {\n                const short sx = 2*il0 + i/8;\n                const short sy = (tiitg/NL0)/8;\n\n                const short lx = i%8;\n                const short ly = (tiitg/NL0)%8;\n                //const short lx = (tiitg/NL0)%8;\n                //const short ly = i%8;\n\n                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? 
*((device T0 *) x + i) : 0;\n            }\n        } else {\n            S0_4x4 temp_a;\n            dequantize_func(x, il, temp_a);\n\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            FOR_UNROLL (short i = 0; i < 16; i++) {\n                const short sx = 2*il0 + i/8;\n                const short sy = (tiitg/NL0)/8;\n\n                const short lx = i%8;\n                const short ly = (tiitg/NL0)%8;\n                //const short lx = (tiitg/NL0)%8;\n                //const short ly = i%8;\n\n                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];\n            }\n        }\n\n        if (FC_mul_mm_bc_inp) {\n            for (short i = 0; i < 8; ++i) {\n                const short sx = (tiitg%NL1);\n                const short sy = (tiitg/NL1)/8;\n\n                const short lx = i;\n                const short ly = (tiitg/NL1)%8;\n                //const short lx = (tiitg/NL1)%8;\n                //const short ly = i;\n\n                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;\n            }\n        } else {\n            const short sx = (tiitg%NL1);\n            const short sy = (tiitg/NL1)/8;\n\n            //const short lx = i;\n            const short ly = (tiitg/NL1)%8;\n            //const short lx = (tiitg/NL1)%8;\n            //const short ly = i;\n\n            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));\n        }\n#endif\n\n        il = (il + 2 < nl) ? il + 2 : il % 2;\n        x  = (il < 2) ? 
x + (2 + nl - 1)/nl : x;\n\n        y += NK;\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n#ifndef GGML_METAL_HAS_TENSOR\n        // load matrices from threadgroup memory and conduct outer products\n        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));\n        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));\n\n        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {\n            simdgroup_barrier(mem_flags::mem_none);\n\n            FOR_UNROLL (short i = 0; i < 4; i++) {\n                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);\n            }\n\n            simdgroup_barrier(mem_flags::mem_none);\n\n            FOR_UNROLL (short i = 0; i < 2; i++) {\n                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);\n            }\n\n            simdgroup_barrier(mem_flags::mem_none);\n\n            FOR_UNROLL (short i = 0; i < 8; i++){\n                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);\n            }\n\n            lsma += 8*64;\n            lsmb += 4*64;\n        }\n#else\n        auto sA = tA.slice(0, 0);\n        auto sB = tB.slice(0, 0);\n\n        mm.run(sB, sA, cT);\n#endif\n    }\n\n    if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {\n        // if no bounds checks on the output are needed, we can directly write to device memory\n#ifdef GGML_METAL_HAS_TENSOR\n        device float * C = (device float *) dst +\n            r0 + \\\n            r1 * args.ne0 + im*args.ne1*args.ne0;\n\n        auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));\n        cT.store(tC);\n#else\n        device float * C = (device float *) dst +\n            (r0 + 32*(sgitg &  1)) + \\\n            (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;\n\n        for (short i = 0; i < 8; i++) {\n            simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);\n        }\n#endif\n    } else {\n        // block is 
smaller than 64x32, we should avoid writing data outside of the matrix\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;\n\n#ifdef GGML_METAL_HAS_TENSOR\n        auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));\n        cT.store(tC);\n#else\n        for (short i = 0; i < 8; i++) {\n            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);\n        }\n#endif\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        if (sgitg == 0) {\n            for (int j = tiitg; j < nr1; j += NR1) {\n                device float  * D  = (device float  *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;\n                device float4 * D4 = (device float4 *) D;\n\n                threadgroup float  * C  = temp_str + (j*NR0);\n                threadgroup float4 * C4 = (threadgroup float4 *) C;\n\n                int i = 0;\n                for (; i < nr0/4; i++) {\n                    *(D4 + i) = *(C4 + i);\n                }\n\n                i *= 4;\n                for (; i < nr0; i++) {\n                    *(D + i) = *(C + i);\n                }\n            }\n        }\n    }\n}\n\ntemplate<short ne20> // n_expert_used\nkernel void kernel_mul_mm_id_map0(\n        constant ggml_metal_kargs_mul_mm_id_map0 & args,\n        device  const char * src2,\n        device        char * htpe,\n        device        char * hids,\n        threadgroup   char * shmem [[threadgroup(0)]],\n        ushort tpitg[[thread_position_in_threadgroup]],\n        ushort   ntg[[threads_per_threadgroup]]) {\n    const short ide = tpitg; // expert id\n\n    uint32_t n_all = 0;\n\n    device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;\n\n    for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens\n        if (i21 + tpitg < args.ne21) {\n            device const 
int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);\n\n            threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;\n\n            #pragma unroll(ne20)\n            for (short i20 = 0; i20 < ne20; i20++) {\n                sids[i20] = src2_i32[i20];\n            }\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        for (short t = 0; t < ntg; t++) {\n            if (i21 + t >= args.ne21) {\n                break;\n            }\n\n            threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;\n\n            short sel = 0;\n            #pragma unroll(ne20)\n            for (short i20 = 0; i20 < ne20; i20++) {\n                sel += (sids[i20] == ide)*(i20 + 1);\n            }\n\n            ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;\n\n            n_all += sel > 0;\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n    }\n\n    device uint32_t * tpe_u32 = (device uint32_t *) (htpe);\n    tpe_u32[ide] = n_all;\n}\n\ntypedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;\n\ntemplate [[host_name(\"kernel_mul_mm_id_map0_ne20_1\" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;\ntemplate [[host_name(\"kernel_mul_mm_id_map0_ne20_2\" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;\ntemplate [[host_name(\"kernel_mul_mm_id_map0_ne20_4\" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;\ntemplate [[host_name(\"kernel_mul_mm_id_map0_ne20_5\" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;\ntemplate [[host_name(\"kernel_mul_mm_id_map0_ne20_6\" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;\ntemplate [[host_name(\"kernel_mul_mm_id_map0_ne20_8\" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;\ntemplate [[host_name(\"kernel_mul_mm_id_map0_ne20_10\")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;\ntemplate 
[[host_name(\"kernel_mul_mm_id_map0_ne20_16\")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;\ntemplate [[host_name(\"kernel_mul_mm_id_map0_ne20_22\")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;\n\ntemplate<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>\nkernel void kernel_mul_mm_id(\n        constant ggml_metal_kargs_mul_mm_id & args,\n        device const char * src0,\n        device const char * src1,\n        device const char * htpe,\n        device const char * hids,\n        device       char * dst,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiitg[[thread_index_in_threadgroup]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    threadgroup S0 * sa = (threadgroup S0 *)(shmem);\n    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);\n\n#ifdef GGML_METAL_HAS_TENSOR\n    threadgroup float * sc = (threadgroup float *)(shmem);\n#endif\n\n    constexpr int NR0 = 64;\n    constexpr int NR1 = 32;\n\n    constexpr int NK  = 32;\n    constexpr int NL0 = NK/16;\n    constexpr int NL1 = NK/8;\n\n    const int im = tgpig.z; // expert\n    const int r0 = tgpig.y*NR0;\n    const int r1 = tgpig.x*NR1;\n\n    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);\n    device const int32_t  * ids_i32 = (device const int32_t  *) (hids);\n\n    const int32_t neh1 = tpe_u32[im];\n\n    if (r1 >= neh1) {\n        return;\n    }\n\n    // if this block is of 64x32 shape or smaller\n    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;\n    const short nr1 = (    neh1 - r1 < NR1) ? 
(    neh1 - r1) : NR1;\n\n    // a thread shouldn't load data outside of the matrix\n    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63\n    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31\n\n    const short il0 = (tiitg % NL0);\n\n    short il = il0;\n\n    const int id = ids_i32[im*args.ne21 + r1 + lr1];\n\n    const short i11 = (id % args.ne20) % args.ne11;\n    const short i12 = (id / args.ne20);\n    const short i13 = 0;\n\n    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;\n    const short    offset1 = il0/nl;\n\n    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;\n\n    const short iy = 8*(tiitg % NL1);\n\n    device const T1 * y = (device const T1 *)(src1\n        + args.nb13*i13\n        + args.nb12*i12\n        + args.nb11*i11\n        + args.nb10*iy);\n\n#ifndef GGML_METAL_HAS_TENSOR\n    S0_8x8 ma[4];\n    S1_8x8 mb[2];\n\n    simdgroup_float8x8 mc[8];\n\n    for (short i = 0; i < 8; i++){\n        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);\n    }\n#else\n    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK,  NR0));\n    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));\n\n    mpp::tensor_ops::matmul2d<\n        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),\n        execution_simdgroups<4>> mm;\n\n    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();\n#endif\n\n    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {\n#ifndef GGML_METAL_HAS_TENSOR\n        // load data and store to threadgroup memory\n        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            // no need for dequantization\n  
          for (short i = 0; i < 16; i++) {\n                const short sx = 2*il0 + i/8;\n                const short sy = (tiitg/NL0)/8;\n\n              //const short lx = i%8;\n              //const short ly = (tiitg/NL0)%8;\n                const short lx = (tiitg/NL0)%8;\n                const short ly = i%8;\n\n                const short ib = 8*sx + sy;\n\n                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;\n            }\n        } else {\n            S0_4x4 temp_a;\n            dequantize_func(x, il, temp_a);\n\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            FOR_UNROLL (short i = 0; i < 16; i++) {\n                const short sx = 2*il0 + i/8;\n                const short sy = (tiitg/NL0)/8;\n\n              //const short lx = i%8;\n              //const short ly = (tiitg/NL0)%8;\n                const short lx = (tiitg/NL0)%8;\n                const short ly = i%8;\n\n                const short ib = 8*sx + sy;\n\n                // NOTE: this is massively slower.. WTF?\n                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];\n\n                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];\n            }\n        }\n\n        if (FC_mul_mm_bc_inp) {\n            for (short i = 0; i < 8; ++i) {\n                const short sx = (tiitg%NL1);\n                const short sy = (tiitg/NL1)/8;\n\n                const short lx = i;\n                const short ly = (tiitg/NL1)%8;\n              //const short lx = (tiitg/NL1)%8;\n              //const short ly = i;\n\n                const short ib = 4*sx + sy;\n\n                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? 
(S1) *((device T1 *) y + i) : 0;\n            }\n        } else {\n            const short sx = (tiitg%NL1);\n            const short sy = (tiitg/NL1)/8;\n\n          //const short dx = sx;\n          //const short dy = sy;\n\n            const short ly = (tiitg/NL1)%8;\n\n            const short ib = 4*sx + sy;\n\n            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));\n        }\n#else\n        // load data and store to threadgroup memory\n        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            // no need for dequantization\n            for (short i = 0; i < 16; i++) {\n                const short sx = 2*il0 + i/8;\n                const short sy = (tiitg/NL0)/8;\n\n                const short lx = i%8;\n                const short ly = (tiitg/NL0)%8;\n                //const short lx = (tiitg/NL0)%8;\n                //const short ly = i%8;\n\n                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? 
*((device T0 *) x + i) : 0;\n            }\n        } else {\n            S0_4x4 temp_a;\n            dequantize_func(x, il, temp_a);\n\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n\n            FOR_UNROLL (short i = 0; i < 16; i++) {\n                const short sx = 2*il0 + i/8;\n                const short sy = (tiitg/NL0)/8;\n\n                const short lx = i%8;\n                const short ly = (tiitg/NL0)%8;\n                //const short lx = (tiitg/NL0)%8;\n                //const short ly = i%8;\n\n                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];\n            }\n        }\n\n        if (FC_mul_mm_bc_inp) {\n            for (short i = 0; i < 8; ++i) {\n                const short sx = (tiitg%NL1);\n                const short sy = (tiitg/NL1)/8;\n\n                const short lx = i;\n                const short ly = (tiitg/NL1)%8;\n                //const short lx = (tiitg/NL1)%8;\n                //const short ly = i;\n\n                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;\n            }\n        } else {\n            const short sx = (tiitg%NL1);\n            const short sy = (tiitg/NL1)/8;\n\n            //const short lx = i;\n            const short ly = (tiitg/NL1)%8;\n            //const short lx = (tiitg/NL1)%8;\n            //const short ly = i;\n\n            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));\n        }\n#endif\n\n        il = (il + 2 < nl) ? il + 2 : il % 2;\n        x  = (il < 2) ? 
x + (2 + nl - 1)/nl : x;\n\n        y += NK;\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n#ifndef GGML_METAL_HAS_TENSOR\n        // load matrices from threadgroup memory and conduct outer products\n        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));\n        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));\n\n        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {\n            simdgroup_barrier(mem_flags::mem_none);\n\n            FOR_UNROLL (short i = 0; i < 4; i++) {\n                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);\n            }\n\n            simdgroup_barrier(mem_flags::mem_none);\n\n            FOR_UNROLL (short i = 0; i < 2; i++) {\n                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);\n            }\n\n            simdgroup_barrier(mem_flags::mem_none);\n\n            FOR_UNROLL (short i = 0; i < 8; i++){\n                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);\n            }\n\n            lsma += 8*64;\n            lsmb += 4*64;\n        }\n#else\n        auto sA = tA.slice(0, 0);\n        auto sB = tB.slice(0, 0);\n\n        mm.run(sB, sA, cT);\n#endif\n    }\n\n    // block is smaller than 64x32, we should avoid writing data outside of the matrix\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n#ifdef GGML_METAL_HAS_TENSOR\n    auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));\n    cT.store(tC);\n#else\n    threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;\n\n    for (short i = 0; i < 8; i++) {\n        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);\n    }\n#endif\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    for (short j = sgitg; j < nr1; j += 4) {\n        const int id = ids_i32[im*args.ne21 + r1 + j];\n\n        const short ide = id % args.ne20;\n        const short idt = id / args.ne20;\n\n        device float  * 
D  = (device float  *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;\n        device float4 * D4 = (device float4 *) D;\n\n        threadgroup float  * C  = (threadgroup float  *) shmem + j*NR0;\n        threadgroup float4 * C4 = (threadgroup float4 *) C;\n\n        int i = tiisg;\n        for (; i < nr0/4; i += 32) {\n            *(D4 + i) = *(C4 + i);\n        }\n\n        i = (4*(nr0/4)) + tiisg;\n        for (; i < nr0; i += 32) {\n            *(D + i) = *(C + i);\n        }\n    }\n}\n\n#define QK_NL 16\n\n//\n// get rows\n//\n\ntypedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;\n\ntemplate [[host_name(\"kernel_get_rows_f32\")]]  kernel get_rows_f_t kernel_get_rows_f<float, float>;\ntemplate [[host_name(\"kernel_get_rows_f16\")]]  kernel get_rows_f_t kernel_get_rows_f<half,  float>;\ntemplate [[host_name(\"kernel_get_rows_i32\")]]  kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_get_rows_bf16\")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;\n#endif\n\ntypedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;\n\ntemplate [[host_name(\"kernel_get_rows_q4_0\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_0,    2, dequantize_q4_0>;\ntemplate [[host_name(\"kernel_get_rows_q4_1\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_1,    2, dequantize_q4_1>;\ntemplate [[host_name(\"kernel_get_rows_q5_0\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_0,    2, dequantize_q5_0>;\ntemplate [[host_name(\"kernel_get_rows_q5_1\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_1,    2, dequantize_q5_1>;\ntemplate [[host_name(\"kernel_get_rows_q8_0\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q8_0,    2, dequantize_q8_0>;\ntemplate [[host_name(\"kernel_get_rows_mxfp4\")]]   kernel get_rows_q_t kernel_get_rows_q<block_mxfp4,   2, dequantize_mxfp4>;\ntemplate [[host_name(\"kernel_get_rows_q2_K\")]]    kernel get_rows_q_t 
kernel_get_rows_q<block_q2_K,    QK_NL, dequantize_q2_K>;\ntemplate [[host_name(\"kernel_get_rows_q3_K\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q3_K,    QK_NL, dequantize_q3_K>;\ntemplate [[host_name(\"kernel_get_rows_q4_K\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_K,    QK_NL, dequantize_q4_K>;\ntemplate [[host_name(\"kernel_get_rows_q5_K\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_K,    QK_NL, dequantize_q5_K>;\ntemplate [[host_name(\"kernel_get_rows_q6_K\")]]    kernel get_rows_q_t kernel_get_rows_q<block_q6_K,    QK_NL, dequantize_q6_K>;\ntemplate [[host_name(\"kernel_get_rows_iq2_xxs\")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;\ntemplate [[host_name(\"kernel_get_rows_iq2_xs\")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;\ntemplate [[host_name(\"kernel_get_rows_iq3_xxs\")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;\ntemplate [[host_name(\"kernel_get_rows_iq3_s\")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq3_s,   QK_NL, dequantize_iq3_s>;\ntemplate [[host_name(\"kernel_get_rows_iq2_s\")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq2_s,   QK_NL, dequantize_iq2_s>;\ntemplate [[host_name(\"kernel_get_rows_iq1_s\")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_s,   QK_NL, dequantize_iq1_s>;\ntemplate [[host_name(\"kernel_get_rows_iq1_m\")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_m,   QK_NL, dequantize_iq1_m>;\ntemplate [[host_name(\"kernel_get_rows_iq4_nl\")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>;\ntemplate [[host_name(\"kernel_get_rows_iq4_xs\")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;\n\n//\n// set rows\n//\n\ntypedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;\n\ntemplate [[host_name(\"kernel_set_rows_f32_i64\")]]  kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;\ntemplate 
[[host_name(\"kernel_set_rows_f32_i32\")]]  kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;\ntemplate [[host_name(\"kernel_set_rows_f16_i64\")]]  kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;\ntemplate [[host_name(\"kernel_set_rows_f16_i32\")]]  kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_set_rows_bf16_i64\")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;\ntemplate [[host_name(\"kernel_set_rows_bf16_i32\")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;\n#endif\n\ntypedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;\n\ntemplate [[host_name(\"kernel_set_rows_q8_0_i64\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0,   quantize_q8_0>;\ntemplate [[host_name(\"kernel_set_rows_q8_0_i32\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0,   quantize_q8_0>;\ntemplate [[host_name(\"kernel_set_rows_q4_0_i64\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0,   quantize_q4_0>;\ntemplate [[host_name(\"kernel_set_rows_q4_0_i32\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0,   quantize_q4_0>;\ntemplate [[host_name(\"kernel_set_rows_q4_1_i64\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1,   quantize_q4_1>;\ntemplate [[host_name(\"kernel_set_rows_q4_1_i32\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1,   quantize_q4_1>;\ntemplate [[host_name(\"kernel_set_rows_q5_0_i64\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0,   quantize_q5_0>;\ntemplate [[host_name(\"kernel_set_rows_q5_0_i32\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0,   quantize_q5_0>;\ntemplate [[host_name(\"kernel_set_rows_q5_1_i64\")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1,   quantize_q5_1>;\ntemplate [[host_name(\"kernel_set_rows_q5_1_i32\")]]   kernel set_rows_q32_t 
kernel_set_rows_q32<int32_t, block_q5_1,   quantize_q5_1>;\ntemplate [[host_name(\"kernel_set_rows_iq4_nl_i64\")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;\ntemplate [[host_name(\"kernel_set_rows_iq4_nl_i32\")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;\n\n//\n// matrix-matrix multiplication\n//\n\ntypedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;\n\ntemplate [[host_name(\"kernel_mul_mm_f32_f32\")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32,     float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_f16_f32\")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16,     half,   half4x4,   float, float2x4>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_mul_mm_bf16_f32\")]]    kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16,    bfloat, bfloat4x4, float, float2x4>;\n#endif\ntemplate [[host_name(\"kernel_mul_mm_q4_0_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q4_1_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q5_0_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_0,    
2,     dequantize_q5_0,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q5_1_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q8_0_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_mxfp4_f32\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q2_K_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q3_K_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q4_K_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q5_K_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q6_K_f32\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K,    float,  float4x4,  
float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq2_xxs_f32\")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq2_xs_f32\")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs,  float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq3_xxs_f32\")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq3_s_f32\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq2_s_f32\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq1_s_f32\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq1_m_f32\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq4_nl_f32\")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl,  float,  float4x4,  float, float2x4>;\ntemplate 
[[host_name(\"kernel_mul_mm_iq4_xs_f32\")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs,  float,  float4x4,  float, float2x4>;\n\ntemplate [[host_name(\"kernel_mul_mm_f32_f16\")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32,     float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_f16_f16\")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16,     half,   half4x4,   half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q4_0_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q4_1_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q5_0_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q5_1_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q8_0_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_mxfp4_f16\")]]   kernel 
mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q2_K_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q3_K_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q4_K_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q5_K_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_q6_K_f16\")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq2_xxs_f16\")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq2_xs_f16\")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs,  float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq3_xxs_f16\")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   
half,   half2x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq3_s_f16\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq2_s_f16\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq1_s_f16\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq1_m_f16\")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq4_nl_f16\")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl,  float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_iq4_xs_f16\")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs,  float,  float4x4,  half, half2x4>;\n\n//\n// indirect matrix-matrix multiplication\n//\n\ntypedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;\n\ntemplate [[host_name(\"kernel_mul_mm_id_f32_f32\")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   float4x4,      1,     
dequantize_f32,     float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_f16_f32\")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16,     half,   half4x4,   float, float2x4>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_mul_mm_id_bf16_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16,    bfloat, bfloat4x4, float, float2x4>;\n#endif\ntemplate [[host_name(\"kernel_mul_mm_id_q4_0_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q4_1_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q5_0_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q5_1_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q8_0_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_mxfp4_f32\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   
half2x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q2_K_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q3_K_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q4_K_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q5_K_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q6_K_f32\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K,    float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq2_xxs_f32\")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq2_xs_f32\")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs,  float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq3_xxs_f32\")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   
half,   half2x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq3_s_f32\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq2_s_f32\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq1_s_f32\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq1_m_f32\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m,   float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq4_nl_f32\")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl,  float,  float4x4,  float, float2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq4_xs_f32\")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs,  float,  float4x4,  float, float2x4>;\n\ntemplate [[host_name(\"kernel_mul_mm_id_f32_f16\")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32,     float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_f16_f16\")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   
simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16,     half,   half4x4,   half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q4_0_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q4_1_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q5_0_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q5_1_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q8_0_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_mxfp4_f16\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q2_K_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q3_K_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   
simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q4_K_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q5_K_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_q6_K_f16\")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K,    float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq2_xxs_f16\")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq2_xs_f16\")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs,  float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq3_xxs_f16\")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq3_s_f16\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq2_s_f16\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   
simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq1_s_f16\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq1_m_f16\")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m,   float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq4_nl_f16\")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl,  float,  float4x4,  half, half2x4>;\ntemplate [[host_name(\"kernel_mul_mm_id_iq4_xs_f16\")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs,  float,  float4x4,  half, half2x4>;\n\n//\n// matrix-vector multiplication\n//\n\ntypedef void (kernel_mul_mv_disp_t)(\n        ggml_metal_kargs_mul_mv args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        uint3  tgpig,\n        ushort tiisg);\n\ntypedef void (kernel_mul_mv2_disp_t)(\n        ggml_metal_kargs_mul_mv args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiisg,\n        ushort sgitg);\n\ntemplate<kernel_mul_mv_disp_t disp_fn>\nvoid mmv_fn(\n        ggml_metal_kargs_mul_mv args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiitg,\n        ushort 
tiisg,\n        ushort sgitg) {\n    disp_fn(args, src0, src1, dst, tgpig, tiisg);\n}\n\ntemplate<kernel_mul_mv2_disp_t disp_fn>\nvoid mmv_fn(\n        ggml_metal_kargs_mul_mv args,\n        device const char * src0,\n        device const char * src1,\n        device       char * dst,\n        threadgroup  char * shmem,\n        uint3  tgpig,\n        ushort tiitg,\n        ushort tiisg,\n        ushort sgitg) {\n    disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);\n}\n\ntypedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;\n\ntemplate<mul_mv_disp_fn_t disp_fn>\nkernel void kernel_mul_mv_id(\n        constant ggml_metal_kargs_mul_mv_id & args,\n        device const char * src0s,\n        device const char * src1,\n        device       char * dst,\n        device const char * ids,\n        threadgroup  char * shmem [[threadgroup(0)]],\n        uint3  tgpig[[threadgroup_position_in_grid]],\n        ushort tiitg[[thread_index_in_threadgroup]],\n        ushort tiisg[[thread_index_in_simdgroup]],\n        ushort sgitg[[simdgroup_index_in_threadgroup]]) {\n    const int iid1 = tgpig.z/args.nei0;\n    const int idx  = tgpig.z%args.nei0;\n\n    tgpig.z = 0;\n\n    const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];\n\n    const int64_t i11 = idx % args.ne11;\n    const int64_t i12 = iid1;\n\n    const int64_t i1 = idx;\n    const int64_t i2 = i12;\n\n    device const char * src0_cur = src0s + i02*args.nb02;\n    device const char * src1_cur = src1  + i11*args.nb11 + i12*args.nb12;\n\n    device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);\n\n    ggml_metal_kargs_mul_mv args0 = {\n        /*.ne00 =*/ args.ne00,\n        /*.ne01 =*/ args.ne01,\n        /*.ne02 =*/ 1, // args.ne02,\n        /*.nb00 =*/ args.nb00,\n        /*.nb01 =*/ args.nb01,\n        /*.nb02 =*/ args.nb02,\n        /*.nb03 =*/ args.nb02, // args.ne02 == 1\n        /*.ne10 =*/ args.ne10,\n      
  /*.ne11 =*/ 1, // args.ne11,\n        /*.ne12 =*/ 1, // args.ne12,\n        /*.nb10 =*/ args.nb10,\n        /*.nb11 =*/ args.nb11,\n        /*.nb12 =*/ args.nb12,\n        /*.nb13 =*/ args.nb12, // ne12 == 1\n        /*.ne0  =*/ args.ne0,\n        /*.ne1  =*/ 1, // args.ne1,\n        /*.nr0  =*/ args.nr0,\n        /*.r2   =*/ 1,\n        /*.r3   =*/ 1,\n    };\n\n    disp_fn(\n        args0,\n        /* src0 */ src0_cur,\n        /* src1 */ src1_cur,\n        /* dst  */ dst_cur,\n        shmem,\n        tgpig,\n        tiitg,\n        tiisg,\n        sgitg);\n}\n\ntypedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;\n\ntypedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;\n\ntemplate [[host_name(\"kernel_mul_mv_id_f32_f32\")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_f16_f32\")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half,  float>>>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_mul_mv_id_bf16_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;\n#endif\ntemplate [[host_name(\"kernel_mul_mv_id_f32_f32_4\")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_f16_f32_4\")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half,  half4,  float, float4>>>;\n#if defined(GGML_METAL_HAS_BF16)\ntemplate [[host_name(\"kernel_mul_mv_id_bf16_f32_4\")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;\n#endif\n\ntemplate [[host_name(\"kernel_mul_mv_id_q8_0_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;\n\ntemplate 
[[host_name(\"kernel_mul_mv_id_q4_0_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_q4_1_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_q5_0_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_q5_1_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;\n\ntemplate [[host_name(\"kernel_mul_mv_id_mxfp4_f32\")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;\n\ntemplate [[host_name(\"kernel_mul_mv_id_q2_K_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_q3_K_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_q4_K_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_q5_K_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl   <N_R0_Q5_K>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_q6_K_f32\")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl   <N_R0_Q6_K>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq1_s_f32\")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl  <N_R0_IQ1_S>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq1_m_f32\")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl  <N_R0_IQ1_M>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq2_xxs_f32\")]] kernel kernel_mul_mv_id_t 
kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq2_xs_f32\")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq3_xxs_f32\")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq3_s_f32\")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl  <N_R0_IQ3_S>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq2_s_f32\")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl  <N_R0_IQ2_S>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq4_nl_f32\")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>;\ntemplate [[host_name(\"kernel_mul_mv_id_iq4_xs_f32\")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>;\n\nkernel void kernel_pool_2d_max_f32(\n        constant    ggml_metal_kargs_pool_2d & args,\n        device  const float * src0,\n        device        float * dst,\n        uint        gid[[thread_position_in_grid]]) {\n\n    if (gid >= args.np) {\n        return;\n    }\n\n    const int idx = gid;\n    const int I_HW = args.IH * args.IW;\n    const int O_HW = args.OH * args.OW;\n    const int nc = idx / O_HW;\n    const int cur_oh = idx % O_HW / args.OW;\n    const int cur_ow = idx % O_HW % args.OW;\n\n    device const float * i_ptr = src0 + nc * I_HW;\n    device       float * o_ptr = dst  + nc * O_HW;\n\n    const int start_h = cur_oh * args.s1 - args.p1;\n    const int bh = MAX(0,  start_h);\n    const int eh = MIN(args.IH, start_h + args.k1);\n    const int start_w = cur_ow * args.s0 - args.p0;\n    const int bw = MAX(0,  start_w);\n    const int ew = MIN(args.IW, start_w + args.k0);\n\n    float res = -INFINITY;\n\n    for (int i = bh; i < eh; i += 1) {\n        for (int j = 
bw; j < ew; j += 1) {\n            res = MAX(res, i_ptr[i * args.IW + j]);\n        }\n    }\n\n    o_ptr[cur_oh * args.OW + cur_ow] = res;\n}\n\nkernel void kernel_pool_2d_avg_f32(\n        constant    ggml_metal_kargs_pool_2d & args,\n        device  const float * src0,\n        device        float * dst,\n        uint        gid[[thread_position_in_grid]]) {\n\n    if (gid >= args.np) {\n        return;\n    }\n\n    const int idx = gid;\n    const int I_HW = args.IH * args.IW;\n    const int O_HW = args.OH * args.OW;\n    const int nc = idx / O_HW;\n    const int cur_oh = idx % O_HW / args.OW;\n    const int cur_ow = idx % O_HW % args.OW;\n\n    device const float * i_ptr = src0 + nc * I_HW;\n    device       float * o_ptr = dst  + nc * O_HW;\n\n    const int start_h = cur_oh * args.s1 - args.p1;\n    const int bh = MAX(0,  start_h);\n    const int eh = MIN(args.IH, start_h + args.k1);\n    const int start_w = cur_ow * args.s0 - args.p0;\n    const int bw = MAX(0,  start_w);\n    const int ew = MIN(args.IW, start_w + args.k0);\n    // const float scale = 1. / ((eh - bh) * (ew - bw));\n    const float scale = 1. 
/ (args.k0 * args.k1);\n\n    float res = 0;\n\n    for (int i = bh; i < eh; i += 1) {\n        for (int j = bw; j < ew; j += 1) {\n            float cur = i_ptr[i * args.IW + j];\n            res += cur * scale;\n        }\n    }\n\n    o_ptr[cur_oh * args.OW + cur_ow] = res;\n}\n\n\nkernel void kernel_pool_1d_max_f32(\n        constant        ggml_metal_kargs_pool_1d & args,\n        device  const   float * src,\n        device          float * dst,\n        uint            gid [[thread_position_in_grid]]\n) {\n\n    if (gid >= args.np) {\n        return;\n    }\n\n    const int ow  = (int)gid % args.OW;\n    const int row = (int)gid / args.OW;\n\n    const int base = ow * args.s0 - args.p0;\n\n    float acc = -INFINITY;\n\n    const int src_off = row * args.IW;\n    const int dst_off = row * args.OW;\n\n    for (int ki = 0; ki < args.k0; ++ki) {\n        int j = base + ki;\n        if (j < 0 || j >= args.IW){\n            continue;\n        }\n        float v = src[src_off + j];\n        acc = max(acc, v);\n    }\n\n    dst[dst_off + ow] = acc;\n}\n\nkernel void kernel_pool_1d_avg_f32(\n        constant        ggml_metal_kargs_pool_1d & args,\n        device  const   float * src,\n        device          float * dst,\n        uint            gid [[thread_position_in_grid]]\n) {\n\n    if (gid >= args.np) {\n        return;\n    }\n\n    const int ow  = (int)gid % args.OW;\n    const int row = (int)gid / args.OW;\n\n    const int base = ow * args.s0 - args.p0;\n\n    float acc = 0.0f;\n    int   cnt = 0;\n\n    const int src_off = row * args.IW;\n    const int dst_off = row * args.OW;\n\n    for (int ki = 0; ki < args.k0; ++ki) {\n        const int j = base + ki;\n        if (j < 0 || j >= args.IW) {\n            continue;\n        }\n        acc += src[src_off + j];\n        cnt += 1;\n    }\n\n    dst[dst_off + ow] = (cnt > 0) ? 
(acc / (float)cnt) : 0.0f;\n}\n\nkernel void kernel_opt_step_adamw_f32(\n        constant    ggml_metal_kargs_opt_step_adamw & args,\n        device       float * x,\n        device const float * g,\n        device       float * g_m,\n        device       float * g_v,\n        device const float * pars,\n        uint        gid[[thread_position_in_grid]]) {\n\n    if (gid >= args.np) {\n        return;\n    }\n\n    const float alpha  = pars[0];\n    const float beta1  = pars[1];\n    const float beta2  = pars[2];\n    const float eps    = pars[3];\n    const float wd     = pars[4];\n    const float beta1h = pars[5];\n    const float beta2h = pars[6];\n\n    const float gi = g[gid];\n    const float gmi = g_m[gid] * beta1 +      gi * (1.0f - beta1);\n    const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);\n\n    g_m[gid] = gmi;\n    g_v[gid] = gvi;\n\n    const float mh =      gmi * beta1h;\n    const float vh = sqrt(gvi * beta2h) + eps;\n\n    x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;\n}\n\nkernel void kernel_opt_step_sgd_f32(\n        constant    ggml_metal_kargs_opt_step_sgd & args,\n        device       float * x,\n        device const float * g,\n        device const float * pars,\n        uint        gid[[thread_position_in_grid]]) {\n\n    if (gid >= args.np) {\n        return;\n    }\n\n    x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];\n}\n\ntemplate<typename T>\nkernel void kernel_memset(\n        constant ggml_metal_kargs_memset & args,\n        device T * dst,\n        uint tpig[[thread_position_in_grid]]) {\n    dst[tpig] = args.val;\n}\n\ntypedef decltype(kernel_memset<int64_t>) kernel_memset_t;\n\ntemplate [[host_name(\"kernel_memset_i64\")]] kernel kernel_memset_t kernel_memset<int64_t>;\n\nconstant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];\n\ntemplate<typename T>\nkernel void kernel_count_equal(\n        constant ggml_metal_kargs_count_equal & args,\n        device   const char * 
src0,\n        device   const char * src1,\n        device   atomic_int * dst,\n        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],\n        uint3   tgpig[[threadgroup_position_in_grid]],\n        ushort3 tpitg[[thread_position_in_threadgroup]],\n        ushort  sgitg[[simdgroup_index_in_threadgroup]],\n        ushort  tiisg[[thread_index_in_simdgroup]],\n        ushort3   ntg[[threads_per_threadgroup]]) {\n    const short NSG = FC_count_equal_nsg;\n\n    const int i3 = tgpig.z;\n    const int i2 = tgpig.y;\n    const int i1 = tgpig.x;\n\n    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {\n        return;\n    }\n\n    int sum = 0;\n\n    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;\n    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;\n\n    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {\n        const T v0 = *(device const T *)(base0 + i0*args.nb00);\n        const T v1 = *(device const T *)(base1 + i0*args.nb10);\n        sum += (v0 == v1);\n    }\n\n    sum = simd_sum(sum);\n\n    if (tiisg == 0) {\n        shmem_i32[sgitg] = sum;\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    if (sgitg == 0) {\n        float v = 0.0f;\n        if (tpitg.x < NSG) {\n            v = shmem_i32[tpitg.x];\n        }\n\n        float total = simd_sum(v);\n        if (tpitg.x == 0) {\n            atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);\n        }\n    }\n}\n\ntypedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;\n\ntemplate [[host_name(\"kernel_count_equal_i32\")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;\n"
  },
  {
    "path": "src/ggml-musa/CMakeLists.txt",
    "content": "if (NOT EXISTS $ENV{MUSA_PATH})\n    if (NOT EXISTS /opt/musa)\n        set(MUSA_PATH /usr/local/musa)\n    else()\n        set(MUSA_PATH /opt/musa)\n    endif()\nelse()\n    set(MUSA_PATH $ENV{MUSA_PATH})\nendif()\n\nset(CMAKE_C_COMPILER \"${MUSA_PATH}/bin/clang\")\nset(CMAKE_C_EXTENSIONS OFF)\nset(CMAKE_CXX_COMPILER \"${MUSA_PATH}/bin/clang++\")\nset(CMAKE_CXX_EXTENSIONS OFF)\n\nlist(APPEND CMAKE_MODULE_PATH \"${MUSA_PATH}/cmake\")\n\nfind_package(MUSAToolkit)\n\nif (MUSAToolkit_FOUND)\n    message(STATUS \"MUSA Toolkit found\")\n\n    if (NOT DEFINED MUSA_ARCHITECTURES)\n        set(MUSA_ARCHITECTURES \"21;22;31\")\n    endif()\n    message(STATUS \"Using MUSA architectures: ${MUSA_ARCHITECTURES}\")\n\n    file(GLOB   GGML_HEADERS_MUSA \"../ggml-cuda/*.cuh\")\n    list(APPEND GGML_HEADERS_MUSA \"../../include/ggml-cuda.h\")\n    list(APPEND GGML_HEADERS_MUSA \"../ggml-musa/mudnn.cuh\")\n\n    file(GLOB   GGML_SOURCES_MUSA \"../ggml-cuda/*.cu\")\n    file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-tile*.cu\")\n    list(APPEND GGML_SOURCES_MUSA ${SRCS})\n    file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-mma*.cu\")\n    list(APPEND GGML_SOURCES_MUSA ${SRCS})\n    file(GLOB   SRCS \"../ggml-cuda/template-instances/mmq*.cu\")\n    list(APPEND GGML_SOURCES_MUSA ${SRCS})\n\n    if (GGML_MUSA_MUDNN_COPY)\n        file(GLOB   SRCS \"../ggml-musa/*.cu\")\n        list(APPEND GGML_SOURCES_MUSA ${SRCS})\n        add_compile_definitions(GGML_MUSA_MUDNN_COPY)\n    endif()\n\n    if (GGML_CUDA_FA_ALL_QUANTS)\n        file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-vec*.cu\")\n        list(APPEND GGML_SOURCES_MUSA ${SRCS})\n        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)\n    else()\n        file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu\")\n        list(APPEND GGML_SOURCES_MUSA ${SRCS})\n        file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu\")\n        list(APPEND 
GGML_SOURCES_MUSA ${SRCS})\n        file(GLOB   SRCS \"../ggml-cuda/template-instances/fattn-vec*f16-f16.cu\")\n        list(APPEND GGML_SOURCES_MUSA ${SRCS})\n    endif()\n\n    set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)\n    foreach(SOURCE ${GGML_SOURCES_MUSA})\n        set(COMPILE_FLAGS \"-Od3 -fno-strict-aliasing -ffast-math -fsigned-char -x musa -mtgpu -fmusa-flush-denormals-to-zero\")\n        foreach(ARCH ${MUSA_ARCHITECTURES})\n            set(COMPILE_FLAGS \"${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}\")\n        endforeach()\n        set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS ${COMPILE_FLAGS})\n    endforeach()\n\n    ggml_add_backend_library(ggml-musa\n                             ${GGML_HEADERS_MUSA}\n                             ${GGML_SOURCES_MUSA}\n                            )\n\n    # TODO: do not use CUDA definitions for MUSA\n    if (NOT GGML_BACKEND_DL)\n        target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)\n    endif()\n\n    add_compile_definitions(GGML_USE_MUSA)\n    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})\n\n    if (GGML_MUSA_GRAPHS)\n        add_compile_definitions(GGML_MUSA_GRAPHS)\n    endif()\n\n    if (GGML_CUDA_FORCE_MMQ)\n        add_compile_definitions(GGML_CUDA_FORCE_MMQ)\n    endif()\n\n    if (GGML_CUDA_FORCE_CUBLAS)\n        add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)\n    endif()\n\n    if (GGML_CUDA_NO_VMM)\n        add_compile_definitions(GGML_CUDA_NO_VMM)\n    endif()\n\n    if (NOT GGML_CUDA_FA)\n        add_compile_definitions(GGML_CUDA_NO_FA)\n    endif()\n\n    if (GGML_CUDA_NO_PEER_COPY)\n        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)\n    endif()\n\n    if (GGML_STATIC)\n        target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)\n        # TODO: mudnn has not provided static libraries yet\n        # if (GGML_MUSA_MUDNN_COPY)\n        #     target_link_libraries(ggml-musa PRIVATE 
mudnn_static)\n        # endif()\n    else()\n        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)\n        if (GGML_MUSA_MUDNN_COPY)\n            target_link_libraries(ggml-musa PRIVATE mudnn)\n        endif()\n    endif()\n\n    if (GGML_CUDA_NO_VMM)\n        # No VMM requested, no need to link directly with the musa driver lib (libmusa.so)\n    else()\n        target_link_libraries(ggml-musa PRIVATE MUSA::musa_driver)\n    endif()\nelse()\n    message(FATAL_ERROR \"MUSA Toolkit not found\")\nendif()\n"
  },
  {
    "path": "src/ggml-musa/mudnn.cu",
    "content": "#include <mutex>\n#include <mudnn.h>\n\n#include \"mudnn.cuh\"\n\nnamespace mudnn = musa::dnn;\n\n// Returns a human-readable error string for mudnn::Status\nconst char* mudnnGetErrorString(mudnn::Status err) {\n    switch (err) {\n        case mudnn::Status::SUCCESS:\n            return \"Success\";\n        case mudnn::Status::INVALID_PARAMETER:\n            return \"Invalid parameter\";\n        case mudnn::Status::NOT_INITIALIZED:\n            return \"Not initialized\";\n        case mudnn::Status::ALLOC_FAILED:\n            return \"Allocation failed\";\n        case mudnn::Status::NOT_SUPPORTED:\n            return \"Not supported\";\n        case mudnn::Status::INTERNAL_ERROR:\n            return \"Internal error\";\n        case mudnn::Status::ARCH_MISMATCH:\n            return \"Architecture mismatch\";\n        case mudnn::Status::EXECUTION_FAILED:\n            return \"Execution failed\";\n        default:\n            return \"Unknown mudnn status\";\n    }\n}\n\n// Error checking macro for MUDNN calls\n#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)\n\nnamespace {\n    // Thread-safe cache for mudnn::Handle objects per device\n    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;\n    std::mutex handle_cache_mutex;\n\n    mudnn::Handle* get_cached_handle(int device_id) {\n        std::lock_guard<std::mutex> lock(handle_cache_mutex);\n        auto it = handle_cache.find(device_id);\n        if (it != handle_cache.end()) {\n            return it->second.get();\n        }\n        auto handle = std::make_unique<mudnn::Handle>(device_id);\n        mudnn::Handle* handle_ptr = handle.get();\n        handle_cache[device_id] = std::move(handle);\n        return handle_ptr;\n    }\n}\n\n// Extracts dimensions and strides from a ggml_tensor\nint get_ggml_dims_and_strides(const ggml_tensor* tensor,\n                              std::vector<int64_t>& dims,\n                             
 std::vector<int64_t>& strides) {\n    const int ndims = ggml_n_dims(tensor);\n    const size_t element_size = ggml_element_size(tensor);\n\n    dims.resize(ndims);\n    strides.resize(ndims);\n\n    for (int i = 0; i < ndims; ++i) {\n        dims[i] = tensor->ne[i];\n        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);\n    }\n    return ndims;\n}\n\n// Converts ggml_type to mudnn::Tensor::Type\nmudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_F32:\n            return mudnn::Tensor::Type::FLOAT;\n        case GGML_TYPE_F16:\n            return mudnn::Tensor::Type::HALF;\n\n        // TODO: Add support for other types\n\n        default:\n            MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);\n    }\n\n    return mudnn::Tensor::Type::FLOAT; // Default fallback\n}\n\n// Asynchronous memory copy using mudnn::Unary::IDENTITY\nmusaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {\n    mudnn::Tensor tensor_dst, tensor_src;\n\n    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));\n    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));\n\n    std::vector<int64_t> dims, strides;\n    const int ndims = get_ggml_dims_and_strides(src, dims, strides);\n\n    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));\n    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));\n    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));\n    MUDNN_CHECK(tensor_src.SetAddr(src->data));\n\n    mudnn::Unary op;\n    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));\n    MUDNN_CHECK(op.SetAlpha(0.0f));\n    MUDNN_CHECK(op.SetBeta(0.0f));\n\n    mudnn::Handle* handle = get_cached_handle(ctx.device);\n    MUDNN_CHECK(handle->SetStream(ctx.stream()));\n    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));\n\n    return musaSuccess;\n}\n"
  },
  {
    "path": "src/ggml-musa/mudnn.cuh",
    "content": "#pragma once\n\n#include \"ggml-cuda/common.cuh\"\n#include \"ggml.h\"\n\n// Asynchronously copies data from src tensor to dst tensor using the provided context.\n// Returns a musaError_t indicating success or failure.\nmusaError_t mudnnMemcpyAsync(\n    ggml_backend_cuda_context &ctx,\n    const ggml_tensor *dst,\n    const ggml_tensor *src\n);\n"
  },
  {
    "path": "src/ggml-opencl/CMakeLists.txt",
    "content": "find_package(OpenCL REQUIRED)\nfind_package(Python3 REQUIRED)\n\nset(TARGET_NAME ggml-opencl)\n\nggml_add_backend_library(${TARGET_NAME}\n                         ggml-opencl.cpp\n                         ../../include/ggml-opencl.h)\ntarget_link_libraries(${TARGET_NAME} PRIVATE ${OpenCL_LIBRARIES})\ntarget_include_directories(${TARGET_NAME} PRIVATE ${OpenCL_INCLUDE_DIRS})\n\nif (GGML_OPENCL_PROFILING)\n    message(STATUS \"OpenCL profiling enabled (increases CPU overhead)\")\n    add_compile_definitions(GGML_OPENCL_PROFILING)\nendif ()\n\nadd_compile_definitions(GGML_OPENCL_SOA_Q)\nadd_compile_definitions(GGML_OPENCL_TARGET_VERSION=${GGML_OPENCL_TARGET_VERSION})\n\nif (GGML_OPENCL_USE_ADRENO_KERNELS)\n    message(STATUS \"OpenCL will use matmul kernels optimized for Adreno\")\n    add_compile_definitions(GGML_OPENCL_USE_ADRENO_KERNELS)\nendif ()\n\nif (GGML_OPENCL_EMBED_KERNELS)\n    add_compile_definitions(GGML_OPENCL_EMBED_KERNELS)\n\n    set(EMBED_KERNEL_SCRIPT \"${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py\")\n    file(MAKE_DIRECTORY     \"${CMAKE_CURRENT_BINARY_DIR}/autogenerated\")\n\n    target_include_directories(${TARGET_NAME} PRIVATE \"${CMAKE_CURRENT_BINARY_DIR}/autogenerated\")\nendif ()\n\nfunction(ggml_opencl_add_kernel KNAME)\n    set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h)\n    set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl)\n\n    if (GGML_OPENCL_EMBED_KERNELS)\n        message(STATUS \"opencl: embedding kernel ${KNAME}\")\n\n        # Python must be accessible from command line\n        add_custom_command(\n            OUTPUT ${KERN_HDR}\n            COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} ${KERN_SRC} ${KERN_HDR}\n            DEPENDS ${KERN_SRC} ${EMBED_KERNEL_SCRIPT}\n            COMMENT \"Generate ${KERN_HDR}\"\n        )\n\n        target_sources(${TARGET_NAME} PRIVATE ${KERN_HDR})\n    else ()\n        message(STATUS \"opencl: adding kernel ${KNAME}\")\n        
configure_file(${KERN_SRC} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${KNAME}.cl COPYONLY)\n    endif ()\nendfunction()\n\nset(GGML_OPENCL_KERNELS\n    add\n    add_id\n    argsort\n    tri\n    fill\n    clamp\n    cpy\n    cvt\n    diag_mask_inf\n    diag\n    div\n    gelu\n    gemv_noshuffle_general\n    gemv_noshuffle\n    get_rows\n    glu\n    group_norm\n    solve_tri\n    im2col_f32\n    im2col_f16\n    mean\n    mul_mat_Ab_Bi_8x4\n    mul_mv_f16_f16\n    mul_mv_f16_f32_1row\n    mul_mv_f16_f32_l4\n    mul_mv_f16_f32\n    mul_mv_f32_f32\n    mul_mv_q4_0_f32\n    mul_mv_q4_0_f32_v\n    mul_mv_q4_0_f32_8x_flat\n    mul_mv_q4_0_f32_1d_8x_flat\n    mul_mv_q4_0_f32_1d_16x_flat\n    mul_mv_q4_1_f32\n    mul_mv_q4_1_f32_flat\n    mul_mv_q4_k_f32\n    mul_mv_q6_k_f32\n    mul_mv_q6_k_f32_flat\n    mul_mv_q8_0_f32\n    mul_mv_q8_0_f32_flat\n    mul_mv_mxfp4_f32\n    mul_mv_mxfp4_f32_flat\n    mul_mv_id_q4_0_f32_8x_flat\n    mul_mv_id_q8_0_f32\n    mul_mv_id_q8_0_f32_flat\n    mul_mv_id_mxfp4_f32\n    mul_mv_id_mxfp4_f32_flat\n    gemm_moe_mxfp4_f32\n    gemv_moe_mxfp4_f32\n    mul_mm_f32_f32_l4_lm\n    mul_mm_f16_f32_l4_lm\n    mul_mm_q4_0_f32_l4_lm\n    mul_mm_q4_1_f32_l4_lm\n    mul_mm_q8_0_f32_l4_lm\n    mul_mm_q6_k_f32_l4_lm\n    mul_mm_q8_0_f32_8x4\n    gemv_noshuffle_q4_1_f32\n    gemm_noshuffle_q4_1_f32\n    gemv_noshuffle_general_q8_0_f32\n    mul\n    neg\n    norm\n    relu\n    l2_norm\n    rms_norm\n    rope\n    scale\n    set_rows\n    sigmoid\n    silu\n    softmax_4_f32\n    softmax_4_f16\n    softmax_f32\n    softmax_f16\n    sqr\n    sqrt\n    ssm_conv\n    sub\n    sum_rows\n    cumsum\n    transpose\n    concat\n    tsembd\n    upscale\n    tanh\n    exp\n    expm1\n    softplus\n    pad\n    repeat\n    mul_mat_f16_f32\n    mul_mm_f16_f32_kq_kqv\n    conv2d\n    conv2d_f16_f32\n    flash_attn_f32_f16\n    flash_attn_f16\n    flash_attn_f32\n)\n\nforeach (K ${GGML_OPENCL_KERNELS})\n    ggml_opencl_add_kernel(${K})\nendforeach()\n"
  },
  {
    "path": "src/ggml-opencl/ggml-opencl.cpp",
    "content": "#define CL_TARGET_OPENCL_VERSION GGML_OPENCL_TARGET_VERSION\n#define CL_USE_DEPRECATED_OPENCL_1_2_APIS\n\n// suppress warnings in CL headers for GCC and Clang\n#pragma GCC diagnostic ignored \"-Woverlength-strings\"\n#ifdef __clang__\n#pragma GCC diagnostic ignored \"-Wgnu-anonymous-struct\"\n#endif\n\n#include \"ggml-opencl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml.h\"\n\n#include <CL/cl.h>\n\n#include <inttypes.h>\n#include <string.h>\n\n#include <cstddef>\n#include <cstdint>\n#include <fstream>\n#include <vector>\n#include <string>\n#include <cmath>\n#include <map>\n#include <memory>\n#include <charconv>\n#include <mutex>\n\n#undef MIN\n#undef MAX\n#define MIN(a, b) ((a) < (b) ? (a) : (b))\n#define MAX(a, b) ((a) > (b) ? (a) : (b))\n#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))\n\n#define UNUSED(x) (void)(x)\n\n#define CL_CHECK(err)                                               \\\n    do {                                                            \\\n        cl_int err_ = (err);                                        \\\n        if (err_ != CL_SUCCESS) {                                   \\\n            GGML_LOG_ERROR(\"ggml_opencl: %s error %d at %s:%d\\n\",  \\\n                #err, err_, __FILE__, __LINE__);                    \\\n            GGML_ASSERT(0);                                         \\\n        }                                                           \\\n    } while (0)\n\n//------------------------------------------------------------------------------\n// OpenCL\n//------------------------------------------------------------------------------\n\nbool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);\n\n// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.\n// Precompute mp (m' in the paper) and L such that division\n// can be computed using a multiply (high 32b of 64b result)\n// and a shift:\n//\n// n/d = (mulhi(n, mp) 
+ n) >> L;\nstruct fastdiv_vals {\n    uint32_t mp;\n    uint32_t L;\n    uint32_t d;\n    uint32_t pad;\n};\nstatic_assert(sizeof(fastdiv_vals) == 16, \"fastdiv_vals size incorrect\");\n\nstatic fastdiv_vals init_fastdiv_values(uint64_t d_64) {\n    GGML_ASSERT(d_64 != 0);\n    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());\n\n    uint32_t d = (uint32_t)d_64;\n\n    // compute L = ceil(log2(d));\n    uint32_t L = 0;\n    while (L < 32 && (uint32_t{ 1 } << L) < d) {\n        L++;\n    }\n\n    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);\n    // pack divisor as well to reduce error surface\n    return { mp, L, d, 0 };\n}\n\nenum GPU_FAMILY {\n    ADRENO,\n    INTEL,\n    UNKNOWN,\n};\n\nenum ADRENO_GPU_GEN {\n    ADRENO_UNKNOWN,\n    A7X,\n    A8X,\n    X1E,\n};\n\nenum ADRENO_CL_COMPILER_TYPE {\n    E031,\n    DX,\n};\n\nstruct ggml_cl_version {\n    cl_uint major = 0;\n    cl_uint minor = 0;\n};\n\n\nstruct ggml_cl_compiler_version {\n    ADRENO_CL_COMPILER_TYPE type;\n    int major = -1;\n    int minor = -1;\n    int patch = -1;\n\n    bool same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const {\n        return major == x && minor == y && patch == z && type == t;\n    }\n    bool newer_than(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const {\n        return major*10000 + minor*100 + patch > x*10000 + y*100 + z && type == t;\n    }\n    bool newer_than_or_same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const {\n        return same(t, x, y, z) || newer_than(t, x, y, z);\n    }\n};\n\nstatic size_t align_to(size_t value, size_t to_alignment) {\n    GGML_ASSERT(to_alignment && \"Invalid alignment (must be non-zero)\");\n    GGML_ASSERT((to_alignment & (to_alignment - 1)) == 0 && \"to_alignment must be power-of-two\");\n\n    return ((value + to_alignment - 1) / to_alignment) * to_alignment;\n}\n\n\n// Parses a version string of form \"XX.YY \". 
On an error returns ggml_cl_version with all zeroes.\nstatic ggml_cl_version parse_cl_version(std::string_view str) {\n    size_t major_str_begin = 0;\n    size_t major_str_end   = str.find(\".\", major_str_begin);\n    if (major_str_end == std::string::npos) {\n        return {};\n    }\n\n    size_t minor_str_begin = major_str_end + 1;\n    size_t minor_str_end   = str.find(\" \", minor_str_begin);\n    if (minor_str_end == std::string::npos) {\n        return {};\n    }\n\n    cl_uint version_major;\n    if (std::from_chars(str.data() + major_str_begin, str.data() + major_str_end, version_major).ec != std::errc{}) {\n        return {};\n    }\n\n    cl_uint version_minor;\n    if (std::from_chars(str.data() + minor_str_begin, str.data() + minor_str_end, version_minor).ec != std::errc{}) {\n        return {};\n    }\n    return { version_major, version_minor };\n}\n\n// Returns OpenCL platform's version. On an error returns ggml_cl_version with all zeroes.\nstatic ggml_cl_version get_opencl_platform_version(cl_platform_id platform) {\n    size_t param_size;\n    CL_CHECK(clGetPlatformInfo(platform, CL_PLATFORM_VERSION, 0, nullptr, &param_size));\n    std::unique_ptr<char[]> param_storage(new char[param_size]);\n    CL_CHECK(clGetPlatformInfo(platform, CL_PLATFORM_VERSION, param_size, param_storage.get(), nullptr));\n\n    auto              param_value    = std::string_view(param_storage.get(), param_size);\n    const std::string version_prefix = \"OpenCL \";  // Suffix: \"XX.YY <platform-specific-info>\"\n    if (param_value.find(version_prefix) != 0) {\n        return {};\n    }\n    param_value.remove_prefix(version_prefix.length());\n    return parse_cl_version(param_value);\n}\n\n// Return a version to use in OpenCL C compilation. 
On an error returns ggml_cl_version with all zeroes.\nstatic ggml_cl_version get_opencl_c_version(ggml_cl_version platform_version, cl_device_id device) {\n    size_t param_size;\n\n#if CL_TARGET_OPENCL_VERSION >= 300\n    if (platform_version.major >= 3) {\n        CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_ALL_VERSIONS, 0, nullptr, &param_size));\n        if (!param_size) {\n            return {};\n        }\n\n        std::unique_ptr<cl_name_version[]> versions(new cl_name_version[param_size]);\n        CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_ALL_VERSIONS, param_size, versions.get(), nullptr));\n        unsigned versions_count = param_size / sizeof(cl_name_version);\n\n        cl_version version_max = 0;\n        for (unsigned i = 0; i < versions_count; i++) {\n            version_max = std::max<cl_version>(versions[i].version, version_max);\n        }\n\n        return { CL_VERSION_MAJOR(version_max), CL_VERSION_MINOR(version_max) };\n    }\n#else\n    GGML_UNUSED(platform_version);\n#endif  // CL_TARGET_OPENCL_VERSION >= 300\n\n    CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_VERSION, 0, nullptr, &param_size));\n    if (!param_size) {\n        return {};\n    }\n\n    std::unique_ptr<char[]> param_storage(new char[param_size]);\n    CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_VERSION, param_size, param_storage.get(), nullptr));\n    auto param_value = std::string_view(param_storage.get(), param_size);\n\n    const std::string version_prefix = \"OpenCL C \";  // Suffix: \"XX.YY <platform-specific-info>\"\n    if (param_value.find(version_prefix) != 0) {\n        return {};\n    }\n    param_value.remove_prefix(version_prefix.length());\n\n    return parse_cl_version(param_value);\n}\n\nstatic ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) {\n    if (strstr(device_name, \"730\") ||\n        strstr(device_name, \"740\") ||\n        strstr(device_name, \"750\")) {\n        return ADRENO_GPU_GEN::A7X;\n    }\n\n    if 
(strstr(device_name, \"830\") ||\n        strstr(device_name, \"840\")) {\n        return ADRENO_GPU_GEN::A8X;\n    }\n\n    if (strstr(device_name, \"X1\")) {\n        return ADRENO_GPU_GEN::X1E;\n    }\n\n    return ADRENO_GPU_GEN::ADRENO_UNKNOWN;\n}\n\nstatic ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *driver_version) {\n    std::string driver_ver_str(driver_version);\n    ADRENO_CL_COMPILER_TYPE type = ADRENO_CL_COMPILER_TYPE::E031;\n    size_t compiler_ver_pos = driver_ver_str.find(\"E031\");\n    size_t compiler_ver_len = 13;\n    size_t compiler_major_offset = 5;\n    size_t compiler_minor_offset = 8;\n    size_t compiler_patch_offset = 11;\n\n    if (compiler_ver_pos == std::string::npos) {\n        compiler_ver_pos = driver_ver_str.find(\"DX\");\n        if (compiler_ver_pos == std::string::npos) {\n            return {};\n        }\n        type = ADRENO_CL_COMPILER_TYPE::DX;\n        compiler_ver_len = 11;\n        compiler_major_offset = 3;\n    }\n\n    std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len);\n    int major = std::atoi(compiler_ver_str.substr(compiler_major_offset, 2).c_str());\n    int minor = std::atoi(compiler_ver_str.substr(compiler_minor_offset, 2).c_str());\n    int patch = std::atoi(compiler_ver_str.substr(compiler_patch_offset, 2).c_str());\n    return { type, major, minor, patch };\n}\n\n// cl buffer wrapper\nstruct ggml_cl_buffer {\n    cl_mem buffer;\n    size_t size;\n\n    ggml_cl_buffer()\n        : buffer(nullptr), size(0) {}\n\n    ~ggml_cl_buffer() {\n        if (buffer) {\n            CL_CHECK(clReleaseMemObject(buffer));\n        }\n    }\n\n    void allocate(cl_context context, size_t new_size) {\n        if (new_size > size) {\n            size = new_size;\n            if (buffer) {\n                CL_CHECK(clReleaseMemObject(buffer));\n            }\n            cl_int err;\n            CL_CHECK((buffer = clCreateBuffer(context, CL_MEM_READ_WRITE, size, 
NULL, &err), err));\n        }\n    }\n};\n\n// Profiling\nstruct ProfilingInfo {\n    std::string op_name;\n    std::string kernel_name;\n\n    cl_kernel kernel;\n    cl_event evt;\n\n    cl_ulong cmd_queued;\n    cl_ulong cmd_submit;\n    cl_ulong cmd_start;\n    cl_ulong cmd_end;\n    cl_ulong overhead_start;\n    cl_ulong overhead_end;\n    // For the times below, see spec for clGetEventProfilingInfo\n    // The time kernel spent in cmd queue - SUBMIT - QUEUED\n    cl_ulong cmd_queued_duration_ns;\n    // The time kernel spent for submission - START - SUBMIT\n    cl_ulong cmd_submit_duration_ns;\n    // Kernel execution time in nanoseconds - END - START\n    cl_ulong cmd_duration_ns;\n    // The time for the kernel to complete - COMPLETE - END\n    cl_ulong cmd_complete_duration_ns;\n    // Total time to finish the kernel - COMPLETE - QUEUED\n    cl_ulong cmd_total_duration_ns;\n    // Global and local work sizes.\n    size_t global_size[3];\n    size_t local_size[3];\n    // Op output size.\n    size_t output_size[4];\n};\n\nstatic void populateProfilingInfo(\n        ProfilingInfo& info, cl_event evt, cl_kernel kernel, cl_uint work_dim,\n        size_t global_size[3], size_t local_size[3],\n        const ggml_tensor * tensor) {\n    info.op_name     = tensor->name;\n    info.kernel      = kernel;\n    info.evt         = evt;\n\n    // 0 means not specified, e.g., 2D workgroup, or NULL for driver to choose\n    info.local_size[0] = 0;\n    info.local_size[1] = 0;\n    info.local_size[2] = 0;\n\n    info.global_size[0] = 0;\n    info.global_size[1] = 0;\n    info.global_size[2] = 0;\n\n    if (local_size) {\n        for (cl_uint i = 0; i < work_dim; ++i) {\n            info.local_size[i] = local_size[i];\n        }\n    }\n\n    for (cl_uint i = 0; i < work_dim; ++i) {\n        info.global_size[i] = global_size[i];\n    }\n\n    info.output_size[0] = tensor->ne[0];\n    info.output_size[1] = tensor->ne[1];\n    info.output_size[2] = tensor->ne[2];\n    
info.output_size[3] = tensor->ne[3];\n}\n\nstruct ggml_backend_opencl_context;\n\n// backend device context\nstruct ggml_backend_opencl_device_context {\n    cl_platform_id platform;\n    std::string platform_name;\n\n    cl_device_id   device;\n    std::string    device_name;\n    cl_device_type device_type;\n    std::string    device_version;\n\n    // Initialized by ggml_cl2_init().\n    ggml_backend_opencl_context * backend_ctx = nullptr;\n\n    // Initialized by ggml_backend_opencl_device_get_buffer_type()\n    ggml_backend_buffer_type buffer_type;\n\n    cl_context context = nullptr;\n};\n\n// backend context\nstruct ggml_backend_opencl_context {\n    int ref_count;\n\n    cl_device_id device;\n    std::string device_name;\n\n    std::string driver_version;\n\n    GPU_FAMILY gpu_family;\n    ADRENO_GPU_GEN adreno_gen;\n\n    cl_int alignment;\n    size_t max_alloc_size;\n    size_t max_workgroup_size;\n    bool fp16_support;\n    bool has_vector_subgroup_broadcast;\n    bool disable_fusion;\n    ggml_cl_compiler_version adreno_cl_compiler_version;\n\n    int adreno_wave_size;\n\n    cl_bool non_uniform_workgroups;\n    size_t  image_max_buffer_size;\n\n    cl_context context;\n    cl_command_queue queue;\n\n    // prealloc buffers for transposing weights and activations\n    ggml_cl_buffer prealloc_quant_trans;\n    ggml_cl_buffer prealloc_scales_trans;\n    ggml_cl_buffer prealloc_act_trans;\n\n    // prealloc buffers for src0 and src1\n    ggml_cl_buffer prealloc_src0;\n    ggml_cl_buffer prealloc_src1;\n\n    cl_program program_add;\n    cl_program program_add_id;\n    cl_program program_clamp;\n    cl_program program_cvt;\n    cl_program program_diag_mask_inf;\n    cl_program program_gelu;\n    cl_program program_gemv_noshuffle_general;\n    cl_program program_gemv_noshuffle;\n    cl_program program_get_rows;\n    cl_program program_set_rows;\n    cl_program program_glu;\n    cl_program program_im2col_f16;\n    cl_program program_im2col_f32;\n    
cl_program program_mul_mat_Ab_Bi_8x4;\n    cl_program program_mul_mv_q4_0_f32;\n    cl_program program_mul_mv_q4_0_f32_v;\n    cl_program program_mul_mv_q4_0_f32_8x_flat;\n    cl_program program_mul_mv_q4_0_f32_1d_8x_flat;\n    cl_program program_mul_mv_q4_0_f32_1d_16x_flat;\n    cl_program program_mul_mv_q6_K;\n    cl_program program_mul_mv_q8_0_f32, program_mul_mv_q8_0_f32_flat;\n    cl_program program_mul_mv_mxfp4_f32;\n    cl_program program_mul_mv_mxfp4_f32_flat;\n    cl_program program_mul_mv_f16_f16;\n    cl_program program_mul_mv_f16_f32_1row;\n    cl_program program_mul_mv_f16_f32_l4;\n    cl_program program_mul_mv_f16_f32;\n    cl_program program_mul_mv_f32_f32;\n    cl_program program_mul;\n    cl_program program_mul_mat_f16_f32_tiled;\n    cl_program program_mul_mm_f16_f32_kqv;\n    cl_program program_mul_mm_f16_f32_kq;\n    cl_program program_div;\n    cl_program program_sub;\n    cl_program program_norm;\n    cl_program program_relu;\n    cl_program program_rms_norm;\n    cl_program program_group_norm;\n    cl_program program_rope;\n    cl_program program_silu;\n    cl_program program_sigmoid;\n    cl_program program_softmax_f32;\n    cl_program program_softmax_f16;\n    cl_program program_softmax_4_f32;\n    cl_program program_softmax_4_f16;\n    cl_program program_argsort_f32_i32;\n    cl_program program_sum_rows_f32;\n    cl_program program_pad;\n    cl_program program_upscale;\n    cl_program program_conv_2d_f16;\n    cl_program program_conv_2d_f32;\n    cl_program program_conv_2d_f16_f32;\n    cl_program program_tsembd;\n    cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;\n    cl_program program_mul_mv_id_q4_0_f32_8x_flat;\n    cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;\n    cl_program program_mul_mv_id_mxfp4_f32;\n    cl_program program_mul_mv_id_mxfp4_f32_flat;\n    cl_program program_mul_mm_f32_f32_l4_lm;\n    cl_program program_mul_mm_f16_f32_l4_lm;\n    cl_program 
program_mul_mm_q8_0_f32_l4_lm;\n\n    cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;\n    cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;\n    cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;\n    cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;\n    cl_kernel kernel_add_id;\n    cl_kernel kernel_scale_f32, kernel_scale_f32_4;\n    cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;\n    cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;\n    cl_kernel kernel_mean_f32, kernel_mean_f32_4;\n    cl_kernel kernel_silu, kernel_silu_4;\n    cl_kernel kernel_gelu, kernel_gelu_4;\n    cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;\n    cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;\n    cl_kernel kernel_relu;\n    cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;\n    cl_kernel kernel_tri;\n    cl_kernel kernel_fill;\n    cl_kernel kernel_clamp;\n    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,\n              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;\n    cl_kernel kernel_norm, kernel_norm_mul_add;\n    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;\n    cl_kernel kernel_l2_norm_f32;\n    cl_kernel kernel_group_norm, kernel_group_norm_mul_add;\n    cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;\n    cl_kernel kernel_diag_f32;\n    cl_kernel kernel_soft_max, kernel_soft_max_4;\n    cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;\n    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;\n    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;\n    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;\n    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;\n    std::map<std::pair<int, 
int>, cl_kernel> kernels_flash_attn_f32_f16;\n    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;\n    std::map<std::pair<int, int>, int>       kernels_flash_attn_bm;\n    std::map<std::pair<int, int>, int>       kernels_flash_attn_bn;\n    cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;\n    cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32;\n    cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;\n    cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;\n    cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_i32_i32;\n    cl_kernel kernel_mul_mat_f32_f32;\n    cl_kernel kernel_mul_mat_f16_f16;\n    cl_kernel kernel_mul_mat_f16_f32_1row;\n    cl_kernel kernel_mul_mat_f16_f32;\n    cl_kernel kernel_mul_mat_f16_f32_l4;\n    cl_kernel kernel_mul_mat_f16_f32_tiled;\n    cl_kernel kernel_mul_mm_f16_f32_kqv;\n    cl_kernel kernel_mul_mm_f16_f32_kq;\n    cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;\n    cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;\n    cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;\n    cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;\n    cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;\n    cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;\n    cl_kernel kernel_convert_block_q4_0_noshuffle;\n    cl_kernel kernel_restore_block_q4_0_noshuffle;\n    cl_kernel kernel_convert_block_q4_1_noshuffle;\n    cl_kernel kernel_restore_block_q4_1_noshuffle;\n    cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;\n    cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;\n    cl_kernel 
kernel_mul_mv_q4_1_f32;\n    cl_kernel kernel_mul_mv_q4_1_f32_flat;\n    cl_kernel kernel_mul_mv_q4_K_f32;\n    cl_kernel kernel_mul_mv_q6_K_f32;\n    cl_kernel kernel_mul_mv_q6_K_f32_flat;\n    cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;\n    cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat;\n    cl_kernel kernel_solve_tri_f32;\n    cl_kernel kernel_im2col_f32, kernel_im2col_f16;\n    cl_kernel kernel_argsort_f32_i32;\n    cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4;\n    cl_kernel kernel_cumsum_blk, kernel_cumsum_add;\n    cl_kernel kernel_repeat_f32;\n    cl_kernel kernel_pad;\n    cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;\n    cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc;\n    cl_kernel kernel_neg_f32, kernel_neg_f32_4, kernel_neg_f32_nc;\n    cl_kernel kernel_neg_f16, kernel_neg_f16_4, kernel_neg_f16_nc;\n    cl_kernel kernel_exp_f32, kernel_exp_f32_4, kernel_exp_f32_nc;\n    cl_kernel kernel_exp_f16, kernel_exp_f16_4, kernel_exp_f16_nc;\n    cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc;\n    cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc;\n    cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc;\n    cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc;\n    cl_kernel kernel_upscale;\n    cl_kernel kernel_upscale_bilinear;\n    cl_kernel kernel_concat_f32;\n    cl_kernel kernel_conv_2d_f16;\n    cl_kernel kernel_conv_2d_f32;\n    cl_kernel kernel_conv_2d_f16_f32;\n    cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;\n    cl_kernel kernel_timestep_embedding;\n    cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;\n    cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;\n    cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;\n    cl_kernel kernel_mul_mv_id_mxfp4_f32;\n    cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;\n    cl_kernel 
kernel_mul_mm_f32_f32_l4_lm;\n    cl_kernel kernel_mul_mm_f16_f32_l4_lm;\n    cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;\n    cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;\n    cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;\n    cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;\n\n    std::vector<ProfilingInfo> profiling_info;\n\n    void write_profiling_info() {\n        FILE * fperf = fopen(\"cl_profiling.csv\", \"w\");\n        if (!fperf) {\n            GGML_LOG_ERROR(\"Failed to open cl_profiling.csv\\n\");\n            return;\n        }\n\n        // Populate profiling info\n        for (ProfilingInfo & info : profiling_info) {\n            cl_ulong cmd_queued;\n            cl_ulong cmd_submit;\n            cl_ulong cmd_start;\n            cl_ulong cmd_end;\n            cl_ulong cmd_complete;\n\n            CL_CHECK(clWaitForEvents(1, &info.evt));\n            CL_CHECK(clGetEventProfilingInfo(\n                info.evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &cmd_queued, NULL));\n            CL_CHECK(clGetEventProfilingInfo(\n                info.evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &cmd_submit, NULL));\n            CL_CHECK(clGetEventProfilingInfo(\n                info.evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &cmd_start, NULL));\n            CL_CHECK(clGetEventProfilingInfo(\n                info.evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &cmd_end, NULL));\n            CL_CHECK(clGetEventProfilingInfo(\n                info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL));\n            CL_CHECK(clReleaseEvent(info.evt));\n\n            char kernel_name[512];\n            CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME,\n                sizeof(kernel_name), kernel_name, NULL));\n            info.kernel_name = kernel_name;\n\n            info.cmd_queued = cmd_queued;\n            info.cmd_submit = cmd_submit;\n            info.cmd_start  = cmd_start;\n            info.cmd_end    = cmd_end;\n\n            
info.cmd_queued_duration_ns     = cmd_submit    - cmd_queued;\n            info.cmd_submit_duration_ns     = cmd_start     - cmd_submit;\n            info.cmd_duration_ns            = cmd_end       - cmd_start;\n            info.cmd_complete_duration_ns   = cmd_complete  - cmd_end;\n            info.cmd_total_duration_ns      = cmd_complete  - cmd_queued;\n        }\n\n        // Dump a csv\n        fprintf(fperf, \"op name, kernel name, exec duration (ms), global size, local size, output size\\n\");\n        for (const ProfilingInfo & info : profiling_info) {\n            fprintf(fperf, \"%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\\n\",\n                info.op_name.c_str(), info.kernel_name.c_str(),\n                info.cmd_duration_ns/1.e6f,\n                info.global_size[0], info.global_size[1], info.global_size[2],\n                info.local_size[0], info.local_size[1], info.local_size[2],\n                info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);\n        }\n        fclose(fperf);\n\n        // Dump a simple chrome trace\n        FILE* ftrace = fopen(\"cl_trace.json\", \"w\");\n        if (!ftrace) {\n            GGML_LOG_ERROR(\"Failed to open cl_trace.json\\n\");\n            return;\n        }\n\n        fprintf(ftrace, \"[\\n\");\n        for (const ProfilingInfo & info : profiling_info) {\n            fprintf(ftrace, \"{\\\"name\\\": \\\"%s\\\", \\\"cat\\\": \\\"OpenCL\\\", \\\"ph\\\": \\\"B\\\", \\\"ts\\\": %\" PRIu64 \", \\\"pid\\\": \\\"\\\", \\\"tid\\\": \\\"Host\\\"},\\n\",\n                info.kernel_name.c_str(), info.cmd_queued/1000);\n            fprintf(ftrace, \"{\\\"name\\\": \\\"%s\\\", \\\"cat\\\": \\\"OpenCL\\\", \\\"ph\\\": \\\"E\\\", \\\"ts\\\": %\" PRIu64 \", \\\"pid\\\": \\\"\\\", \\\"tid\\\": \\\"Host\\\"},\\n\",\n                info.kernel_name.c_str(), info.cmd_submit/1000);\n\n            fprintf(ftrace, \"{\\\"name\\\": \\\"%s\\\", \\\"cat\\\": \\\"OpenCL\\\", \\\"ph\\\": 
\\\"B\\\", \\\"ts\\\": %\" PRIu64 \", \\\"pid\\\": \\\"\\\", \\\"tid\\\": \\\"Device\\\"},\\n\",\n                info.kernel_name.c_str(), info.cmd_start/1000);\n            fprintf(ftrace, \"{\\\"name\\\": \\\"%s\\\", \\\"cat\\\": \\\"OpenCL\\\", \\\"ph\\\": \\\"E\\\", \\\"ts\\\": %\" PRIu64 \", \\\"pid\\\": \\\"\\\", \\\"tid\\\": \\\"Device\\\"},\\n\",\n                info.kernel_name.c_str(), info.cmd_end/1000);\n        }\n        fclose(ftrace);\n    }\n\n    size_t get_kernel_workgroup_size(cl_kernel kernel) const {\n        size_t workgroup_size = 0;\n        size_t ret_size = 0;\n        CL_CHECK(\n            clGetKernelWorkGroupInfo(kernel, device, CL_KERNEL_WORK_GROUP_SIZE,\n                sizeof(size_t), &workgroup_size, &ret_size));\n        GGML_ASSERT(sizeof(size_t) == ret_size);\n        return workgroup_size;\n    }\n\n    void enqueue_ndrange_kernel(cl_kernel kernel, cl_uint work_dim, size_t *global_work_size, size_t *local_work_size, const ggml_tensor * tensor) {\n#ifdef GGML_OPENCL_PROFILING\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, &evt));\n\n        profiling_info.emplace_back();\n        populateProfilingInfo(profiling_info.back(), evt, kernel, work_dim, global_work_size, local_work_size, tensor);\n#else\n        GGML_UNUSED(tensor);\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, NULL));\n#endif\n    }\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n    // Transpose kernels\n    cl_program program_transpose;\n\n    cl_kernel kernel_transpose_32;\n    cl_kernel kernel_transpose_32_16;\n    cl_kernel kernel_transpose_16;\n    cl_kernel kernel_transpose_8_buf;\n    cl_kernel kernel_transpose_16_buf;\n    cl_kernel kernel_transpose_32_buf;\n    cl_kernel kernel_transpose_16_4x1;\n\n    // Gemm and Gemv related programs, kernels, etc\n    cl_program program_CL_gemm;\n    cl_program 
program_CL_gemv_general;\n    cl_program program_CL_gemv_4096_1_11008;\n    cl_program program_CL_gemv_4096_1_4096;\n    cl_program program_CL_gemv_11008_1_4096;\n    cl_program program_CL_gemv_32000_1_4096;\n    cl_kernel CL_mul_mat_Ab_Bi_8x4;\n    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general;\n    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008;\n    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;\n    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;\n    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;\n    cl_kernel kernel_gemv_noshuffle_q4_1_f32;\n    cl_kernel kernel_gemm_noshuffle_q4_1_f32;\n    cl_kernel kernel_mul_mm_q8_0_f32_8x4;\n    cl_kernel CL_mul_mat_vec_q8_0_f32;\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n\n    void free() {\n        ref_count--;\n        if (ref_count == 0) {\n#ifdef GGML_OPENCL_PROFILING\n            write_profiling_info();\n            profiling_info.clear();\n#endif\n        }\n    }\n};\n\n// All registered devices with a default device in the front.\nstatic std::vector<ggml_backend_device> g_ggml_backend_opencl_devices;\n\ninline std::string read_file(const std::string &path) {\n  std::ifstream ifs(path);\n  if (!ifs) {\n    return \"\";\n  }\n  std::string text;\n  ifs.seekg(0, std::ios::end);\n  text.resize(ifs.tellg());\n  ifs.seekg(0, std::ios::beg);\n  ifs.read(&text[0], text.size());\n  return text;\n}\n\nstatic cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) {\n    cl_program p;\n    char *program_log;\n    size_t program_size;\n    size_t log_size;\n    int err;\n\n    program_size = strlen(program_buffer);\n\n    p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);\n    if(err < 0) {\n        GGML_LOG_ERROR(\"OpenCL error creating program\");\n        exit(1);\n    }\n\n    err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);\n    
if(err < 0) {\n        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);\n        program_log = (char*) malloc(log_size + 1);\n        program_log[log_size] = '\\0';\n        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);\n        GGML_LOG_ERROR(\"ggml_opencl: kernel compile error:\\n\\n%s\\n\", program_log);\n        free(program_log);\n        exit(1);\n    }\n\n    return p;\n}\n\nstatic void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) {\n    cl_int err;\n\n    // compiler options for general kernels\n    auto opencl_c_std =\n        std::string(\"CL\") + std::to_string(opencl_c_version.major) + \".\" + std::to_string(opencl_c_version.minor);\n    std::string compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n                               \" -cl-mad-enable -cl-unsafe-math-optimizations\"\n                               \" -cl-finite-math-only -cl-fast-relaxed-math\";\n\n    GGML_LOG_INFO(\"ggml_opencl: loading OpenCL kernels\");\n\n    // add\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"add.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"add.cl\");\n#endif\n        backend_ctx->program_add =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_add         = clCreateKernel(backend_ctx->program_add, \"kernel_add\", &err), err));\n        CL_CHECK((backend_ctx->kernel_add_row     = clCreateKernel(backend_ctx->program_add, \"kernel_add_row\", &err), err));\n        CL_CHECK((backend_ctx->kernel_add_f16     = clCreateKernel(backend_ctx->program_add, \"kernel_add_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_add_row_f16 = clCreateKernel(backend_ctx->program_add, \"kernel_add_row_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // 
add_id\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"add_id.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"add_id.cl\");\n#endif\n        backend_ctx->program_add_id =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, \"kernel_add_id\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // tri\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"tri.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"tri.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, \"kernel_tri_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n\n        CL_CHECK(clReleaseProgram(prog));\n    }\n\n    // fill\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"fill.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"fill.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_fill = clCreateKernel(prog, \"kernel_fill_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n\n        CL_CHECK(clReleaseProgram(prog));\n    }\n\n    // clamp\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"clamp.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"clamp.cl\");\n#endif\n        backend_ctx->program_clamp =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, 
kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program_clamp, \"kernel_clamp\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // cpy\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"cpy.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"cpy.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(prog, \"kernel_cpy_f16_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, \"kernel_cpy_f16_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, \"kernel_cpy_f32_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, \"kernel_cpy_f32_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, \"kernel_cpy_i32_i32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // cvt\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"cvt.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"cvt.cl\");\n#endif\n        backend_ctx->program_cvt =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, \"kernel_convert_block_q4_0_noshuffle\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_q4_0_noshuffle\", &err), err));\n        CL_CHECK((backend_ctx->kernel_convert_block_q4_0  = clCreateKernel(backend_ctx->program_cvt, 
\"kernel_convert_block_q4_0\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_q4_0  = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_q4_0\", &err), err));\n        CL_CHECK((backend_ctx->kernel_convert_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, \"kernel_convert_block_q4_1_noshuffle\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_q4_1_noshuffle\", &err), err));\n        CL_CHECK((backend_ctx->kernel_convert_block_q4_1  = clCreateKernel(backend_ctx->program_cvt, \"kernel_convert_block_q4_1\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_q4_1  = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_q4_1\", &err), err));\n        CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, \"kernel_convert_block_mxfp4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, \"kernel_convert_block_mxfp4_trans\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_mxfp4_trans\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_mxfp4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_convert_block_q8_0  = clCreateKernel(backend_ctx->program_cvt, \"kernel_convert_block_q8_0\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_q8_0  = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_q8_0\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans  = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_q8_0_trans\", &err), err));\n        CL_CHECK((backend_ctx->kernel_convert_block_q6_K  = clCreateKernel(backend_ctx->program_cvt, 
\"kernel_convert_block_q6_K\", &err), err));\n        CL_CHECK((backend_ctx->kernel_restore_block_q6_K  = clCreateKernel(backend_ctx->program_cvt, \"kernel_restore_block_q6_K\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // diag_mask_inf\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"diag_mask_inf.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"diag_mask_inf.cl\");\n#endif\n        backend_ctx->program_diag_mask_inf =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program_diag_mask_inf, \"kernel_diag_mask_inf_8\", &err), err));\n        CL_CHECK((backend_ctx->kernel_diag_mask_inf   = clCreateKernel(backend_ctx->program_diag_mask_inf, \"kernel_diag_mask_inf\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // diag\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"diag.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"diag.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_diag_f32 = clCreateKernel(prog, \"kernel_diag_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // gelu\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"gelu.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"gelu.cl\");\n#endif\n        backend_ctx->program_gelu =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_gelu         = 
clCreateKernel(backend_ctx->program_gelu, \"kernel_gelu\", &err), err));\n        CL_CHECK((backend_ctx->kernel_gelu_4       = clCreateKernel(backend_ctx->program_gelu, \"kernel_gelu_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_gelu_erf     = clCreateKernel(backend_ctx->program_gelu, \"kernel_gelu_erf\", &err), err));\n        CL_CHECK((backend_ctx->kernel_gelu_erf_4   = clCreateKernel(backend_ctx->program_gelu, \"kernel_gelu_erf_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_gelu_quick   = clCreateKernel(backend_ctx->program_gelu, \"kernel_gelu_quick\", &err), err));\n        CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, \"kernel_gelu_quick_4\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // glu\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"glu.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"glu.cl\");\n#endif\n        backend_ctx->program_glu =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_geglu           = clCreateKernel(backend_ctx->program_glu, \"kernel_geglu\", &err), err));\n        CL_CHECK((backend_ctx->kernel_reglu           = clCreateKernel(backend_ctx->program_glu, \"kernel_reglu\", &err), err));\n        CL_CHECK((backend_ctx->kernel_swiglu          = clCreateKernel(backend_ctx->program_glu, \"kernel_swiglu\", &err), err));\n        CL_CHECK((backend_ctx->kernel_swiglu_oai      = clCreateKernel(backend_ctx->program_glu, \"kernel_swiglu_oai\", &err), err));\n        CL_CHECK((backend_ctx->kernel_geglu_erf       = clCreateKernel(backend_ctx->program_glu, \"kernel_geglu_erf\", &err), err));\n        CL_CHECK((backend_ctx->kernel_geglu_quick     = clCreateKernel(backend_ctx->program_glu, \"kernel_geglu_quick\", &err), err));\n        CL_CHECK((backend_ctx->kernel_geglu_f16       = 
clCreateKernel(backend_ctx->program_glu, \"kernel_geglu_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_reglu_f16       = clCreateKernel(backend_ctx->program_glu, \"kernel_reglu_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_swiglu_f16      = clCreateKernel(backend_ctx->program_glu, \"kernel_swiglu_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_geglu_erf_f16   = clCreateKernel(backend_ctx->program_glu, \"kernel_geglu_erf_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, \"kernel_geglu_quick_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // get_rows\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"get_rows.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"get_rows.cl\");\n#endif\n        backend_ctx->program_get_rows =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_get_rows_f32  = clCreateKernel(backend_ctx->program_get_rows, \"kernel_get_rows_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_get_rows_f16  = clCreateKernel(backend_ctx->program_get_rows, \"kernel_get_rows_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program_get_rows, \"kernel_get_rows_q4_0\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // solve_tri_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"solve_tri.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"solve_tri.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_solve_tri_f32 = clCreateKernel(prog, 
\"kernel_solve_tri_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n        CL_CHECK(clReleaseProgram(prog));\n    }\n\n    // im2col_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"im2col_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"im2col_f32.cl\");\n#endif\n        backend_ctx->program_im2col_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col_f32, \"kernel_im2col_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // im2col_f16\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"im2col_f16.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"im2col_f16.cl\");\n#endif\n        backend_ctx->program_im2col_f16 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col_f16, \"kernel_im2col_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q4_0_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q4_0_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q4_0_f32.cl\");\n#endif\n        backend_ctx->program_mul_mv_q4_0_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32, \"kernel_mul_mat_q4_0_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q4_0_f32_v\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src 
{\n            #include \"mul_mv_q4_0_f32_v.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q4_0_f32_v.cl\");\n#endif\n        backend_ctx->program_mul_mv_q4_0_f32_v =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_v, \"kernel_mul_mat_q4_0_f32_v\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q4_0_f32_8x_flat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q4_0_f32_8x_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q4_0_f32_8x_flat.cl\");\n#endif\n        backend_ctx->program_mul_mv_q4_0_f32_8x_flat =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_8x_flat, \"kernel_mul_mat_q4_0_f32_8x_flat\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q4_0_f32_1d_8x_flat\n    // This kernel does not compiler on Adreno cl compiler 38.01. 
Skip it for\n    // those compiler versions since it is anyway not used for Adreno.\n    if (backend_ctx->gpu_family != ADRENO ||\n        backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) ||\n        backend_ctx->adreno_cl_compiler_version.type == DX) {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q4_0_f32_1d_8x_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q4_0_f32_1d_8x_flat.cl\");\n#endif\n        backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat, \"kernel_mul_mat_q4_0_f32_1d_8x_flat\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q4_0_f32_1d_16x_flat\n    // This kernel does not compile on Adreno cl compiler 38.01. 
Skip it for\n    // those compiler versions since it is anyway not used for Adreno.\n    if (backend_ctx->gpu_family != ADRENO ||\n        backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) ||\n    backend_ctx->adreno_cl_compiler_version.type == DX) {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q4_0_f32_1d_16x_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q4_0_f32_1d_16x_flat.cl\");\n#endif\n        backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat, \"kernel_mul_mat_q4_0_f32_1d_16x_flat\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q4_1_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q4_1_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q4_1_f32.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, \"kernel_mul_mv_q4_1_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q4_1_f32_flat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q4_1_f32_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q4_1_f32_flat.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        
CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, \"kernel_mul_mv_q4_1_f32_flat\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q4_k_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q4_k_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q4_k_f32.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, \"kernel_mul_mv_q4_K_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q6_k_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q6_k_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q6_k_f32.cl\");\n#endif\n        backend_ctx->program_mul_mv_q6_K =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_mul_mv_q6_K, \"kernel_mul_mv_q6_K_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q6_k_f32_flat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q6_k_f32_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q6_k_f32_flat.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32_flat = clCreateKernel(prog, \"kernel_mul_mv_q6_K_f32_flat\", &err), err));\n        
CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q8_0_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q8_0_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q8_0_f32.cl\");\n#endif\n        backend_ctx->program_mul_mv_q8_0_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_q8_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q8_0_f32, \"kernel_mul_mv_q8_0_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_q8_0_f32_flat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_q8_0_f32_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_q8_0_f32_flat.cl\");\n#endif\n        backend_ctx->program_mul_mv_q8_0_f32_flat =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_q8_0_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_q8_0_f32_flat, \"kernel_mul_mv_q8_0_f32_flat\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_mxfp4_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_mxfp4_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_mxfp4_f32.cl\");\n#endif\n        backend_ctx->program_mul_mv_mxfp4_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32, \"kernel_mul_mv_mxfp4_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_mxfp4_f32_flat\n    {\n#ifdef 
GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_mxfp4_f32_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_mxfp4_f32_flat.cl\");\n#endif\n        backend_ctx->program_mul_mv_mxfp4_f32_flat =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32_flat, \"kernel_mul_mv_mxfp4_f32_flat\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_f16_f16\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_f16_f16.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_f16_f16.cl\");\n#endif\n        backend_ctx->program_mul_mv_f16_f16 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program_mul_mv_f16_f16, \"kernel_mul_mat_f16_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_f16_f32_1row\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_f16_f32_1row.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_f16_f32_1row.cl\");\n#endif\n        backend_ctx->program_mul_mv_f16_f32_1row =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_1row, \"kernel_mul_mat_f16_f32_1row\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_f16_f32_l4\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include 
\"mul_mv_f16_f32_l4.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_f16_f32_l4.cl\");\n#endif\n        backend_ctx->program_mul_mv_f16_f32_l4 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4   = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_l4, \"kernel_mul_mat_f16_f32_l4\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_f16_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_f16_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_f16_f32.cl\");\n#endif\n        backend_ctx->program_mul_mv_f16_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32, \"kernel_mul_mat_f16_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_f32_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_f32_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_f32_f32.cl\");\n#endif\n        backend_ctx->program_mul_mv_f32_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program_mul_mv_f32_f32, \"kernel_mul_mat_f32_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mat_f16_f32_tiled\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mat_f16_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mat_f16_f32.cl\");\n#endif\n        
backend_ctx->program_mul_mat_f16_f32_tiled =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, \"mul_mat_f16_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mm_f32_f32_l4_lm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mm_f32_f32_l4_lm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mm_f32_f32_l4_lm.cl\");\n#endif\n        backend_ctx->program_mul_mm_f32_f32_l4_lm =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, \"kernel_mul_mm_f32_f32_l4_lm\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mm_f16_f32_l4_lm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mm_f16_f32_l4_lm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mm_f16_f32_l4_lm.cl\");\n#endif\n        backend_ctx->program_mul_mm_f16_f32_l4_lm =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, \"kernel_mul_mm_f16_f32_l4_lm\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mm_q4_0_f32_l4_lm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mm_q4_0_f32_l4_lm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mm_q4_0_f32_l4_lm.cl\");\n#endif\n        cl_program prog =\n            
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, \"kernel_mul_mm_q4_0_f32_l4_lm\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mm_q4_1_f32_l4_lm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mm_q4_1_f32_l4_lm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mm_q4_1_f32_l4_lm.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, \"kernel_mul_mm_q4_1_f32_l4_lm\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mm_q8_0_f32_l4_lm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mm_q8_0_f32_l4_lm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mm_q8_0_f32_l4_lm.cl\");\n#endif\n        backend_ctx->program_mul_mm_q8_0_f32_l4_lm =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, \"kernel_mul_mm_q8_0_f32_l4_lm\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mm_q6_k_f32_l4_lm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mm_q6_k_f32_l4_lm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mm_q6_k_f32_l4_lm.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        
CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, \"kernel_mul_mm_q6_k_f32_l4_lm\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mm_f16_f32_kq_kqv\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mm_f16_f32_kq_kqv.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mm_f16_f32_kq_kqv.cl\");\n#endif\n        backend_ctx->program_mul_mm_f16_f32_kqv =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+\" -DKQV \");\n        backend_ctx->program_mul_mm_f16_f32_kq =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, \"mul_mm_f16_f32_kqv\", &err), err));\n        CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, \"mul_mm_f16_f32_kq\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul.cl\");\n#endif\n        backend_ctx->program_mul =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul         = clCreateKernel(backend_ctx->program_mul, \"kernel_mul\", &err), err));\n        CL_CHECK((backend_ctx->kernel_mul_row     = clCreateKernel(backend_ctx->program_mul, \"kernel_mul_row\", &err), err));\n        CL_CHECK((backend_ctx->kernel_mul_f16     = clCreateKernel(backend_ctx->program_mul, \"kernel_mul_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_mul_row_f16 = 
clCreateKernel(backend_ctx->program_mul, \"kernel_mul_row_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // norm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"norm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"norm.cl\");\n#endif\n        backend_ctx->program_norm =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_norm         = clCreateKernel(backend_ctx->program_norm, \"kernel_norm\", &err), err));\n        CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, \"kernel_norm_mul_add\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // relu\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"relu.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"relu.cl\");\n#endif\n        backend_ctx->program_relu =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program_relu, \"kernel_relu\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // rms_norm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"rms_norm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"rms_norm.cl\");\n#endif\n        backend_ctx->program_rms_norm =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_rms_norm     = clCreateKernel(backend_ctx->program_rms_norm, \"kernel_rms_norm\", &err), err));\n        CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, \"kernel_rms_norm_mul\", &err), 
err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // l2_norm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"l2_norm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"l2_norm.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_l2_norm_f32     = clCreateKernel(prog, \"kernel_l2_norm_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // rope\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"rope.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"rope.cl\");\n#endif\n        backend_ctx->program_rope =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_rope_norm_f32   = clCreateKernel(backend_ctx->program_rope, \"kernel_rope_norm_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_rope_norm_f16   = clCreateKernel(backend_ctx->program_rope, \"kernel_rope_norm_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_rope_neox_f32   = clCreateKernel(backend_ctx->program_rope, \"kernel_rope_neox_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_rope_neox_f16   = clCreateKernel(backend_ctx->program_rope, \"kernel_rope_neox_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_rope_multi_f32  = clCreateKernel(backend_ctx->program_rope, \"kernel_rope_multi_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_rope_multi_f16  = clCreateKernel(backend_ctx->program_rope, \"kernel_rope_multi_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program_rope, \"kernel_rope_vision_f32\", &err), err));\n        
CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program_rope, \"kernel_rope_vision_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // scale\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"scale.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"scale.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_scale_f32   = clCreateKernel(prog, \"kernel_scale_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_scale_f32_4 = clCreateKernel(prog, \"kernel_scale_f32_4\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // silu\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"silu.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"silu.cl\");\n#endif\n        backend_ctx->program_silu =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_silu   = clCreateKernel(backend_ctx->program_silu, \"kernel_silu\", &err), err));\n        CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program_silu, \"kernel_silu_4\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // softmax_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"softmax_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"softmax_f32.cl\");\n#endif\n        backend_ctx->program_softmax_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_soft_max = 
clCreateKernel(backend_ctx->program_softmax_f32, \"kernel_soft_max\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // softmax_f16\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"softmax_f16.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"softmax_f16.cl\");\n#endif\n        backend_ctx->program_softmax_f16 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program_softmax_f16, \"kernel_soft_max_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // softmax_4_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"softmax_4_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"softmax_4_f32.cl\");\n#endif\n        backend_ctx->program_softmax_4_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program_softmax_4_f32, \"kernel_soft_max_4\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // softmax_4_f16\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"softmax_4_f16.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"softmax_4_f16.cl\");\n#endif\n        backend_ctx->program_softmax_4_f16 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program_softmax_4_f16, \"kernel_soft_max_4_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // flash_attn\n    {\n        #ifdef GGML_OPENCL_EMBED_KERNELS\n                const 
std::string kernel_src_f16 {\n                    #include \"flash_attn_f16.cl.h\"\n                };\n                const std::string kernel_src_f32 {\n                    #include \"flash_attn_f32.cl.h\"\n                };\n                const std::string kernel_src_f32_f16 {\n                    #include \"flash_attn_f32_f16.cl.h\"\n                };\n        #else\n                const std::string kernel_src_f16 = read_file(\"flash_attn_f16.cl\");\n                const std::string kernel_src_f32 = read_file(\"flash_attn_f32.cl\");\n                const std::string kernel_src_f32_f16 = read_file(\"flash_attn_f32_f16.cl\");\n        #endif\n\n        if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {\n            const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {\n                { 40,  40, 32, 32}, { 64,  64, 64, 64}, { 80,  80, 64, 32}, { 96,  96, 64, 32},\n                {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},\n                {192, 192, 16, 16}, {256, 256, 16, 16},\n            };\n\n            for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {\n                const int dk = fa_dims[i].dk;\n                const int dv = fa_dims[i].dv;\n                const int bm = fa_dims[i].bm;\n                const int bn = fa_dims[i].bn;\n                std::string OPTS = compile_opts +\n                    \" -D DK=\" + std::to_string(dk) +\n                    \" -D DV=\" + std::to_string(dv) +\n                    \" -D BLOCK_M=\" + std::to_string(bm) +\n                    \" -D BLOCK_N=\" + std::to_string(bn);\n\n                cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);\n                cl_kernel k_f16, k_f16_q1;\n                CL_CHECK((k_f16 = clCreateKernel(prog_f16, \"flash_attn_f16\", &err), err));\n                CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, 
\"flash_attn_f16_q1\", &err), err));\n                backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;\n                backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;\n                CL_CHECK(clReleaseProgram(prog_f16));\n\n                cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);\n                cl_kernel k_f32, k_f32_q1;\n                CL_CHECK((k_f32 = clCreateKernel(prog_f32, \"flash_attn_f32\", &err), err));\n                CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, \"flash_attn_f32_q1\", &err), err));\n                backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;\n                backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;\n                CL_CHECK(clReleaseProgram(prog_f32));\n\n                cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);\n                cl_kernel k_f32_f16, k_f32_f16_q1;\n                CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, \"flash_attn_f32_f16\", &err), err));\n                CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, \"flash_attn_f32_f16_q1\", &err), err));\n                backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;\n                backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;\n                CL_CHECK(clReleaseProgram(prog_f32_f16));\n\n                backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;\n                backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;\n            }\n            GGML_LOG_CONT(\".\");\n        }\n    }\n\n    // argsort\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"argsort.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"argsort.cl\");\n#endif\n        backend_ctx->program_argsort_f32_i32 =\n            
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, \"kernel_argsort_f32_i32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // div\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"div.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"div.cl\");\n#endif\n        std::string compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n                               \" -cl-mad-enable -cl-finite-math-only \";\n\n        backend_ctx->program_div =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_div         = clCreateKernel(backend_ctx->program_div, \"kernel_div\", &err), err));\n        CL_CHECK((backend_ctx->kernel_div_row     = clCreateKernel(backend_ctx->program_div, \"kernel_div_row\", &err), err));\n        CL_CHECK((backend_ctx->kernel_div_f16     = clCreateKernel(backend_ctx->program_div, \"kernel_div_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_div_row_f16 = clCreateKernel(backend_ctx->program_div, \"kernel_div_row_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // sqr\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"sqr.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"sqr.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_sqr_cont_f32     = clCreateKernel(prog, \"kernel_sqr_cont_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4   = clCreateKernel(prog, \"kernel_sqr_cont_f32_4\", &err), err));\n        
CL_CHECK((backend_ctx->kernel_sqr_cont_f16     = clCreateKernel(prog, \"kernel_sqr_cont_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4   = clCreateKernel(prog, \"kernel_sqr_cont_f16_4\", &err), err));\n\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // sqrt\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"sqrt.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"sqrt.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_sqrt_cont_f32     = clCreateKernel(prog, \"kernel_sqrt_cont_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4   = clCreateKernel(prog, \"kernel_sqrt_cont_f32_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sqrt_cont_f16     = clCreateKernel(prog, \"kernel_sqrt_cont_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4   = clCreateKernel(prog, \"kernel_sqrt_cont_f16_4\", &err), err));\n\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mean\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mean.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mean.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, \"kernel_mean_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, \"kernel_mean_f32_4\", &err), err));\n\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // sub\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string 
kernel_src {\n            #include \"sub.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"sub.cl\");\n#endif\n        backend_ctx->program_sub =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_sub         = clCreateKernel(backend_ctx->program_sub, \"kernel_sub\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sub_row     = clCreateKernel(backend_ctx->program_sub, \"kernel_sub_row\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sub_f16     = clCreateKernel(backend_ctx->program_sub, \"kernel_sub_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sub_row_f16 = clCreateKernel(backend_ctx->program_sub, \"kernel_sub_row_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // sum_rows\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"sum_rows.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"sum_rows.cl\");\n#endif\n        backend_ctx->program_sum_rows_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, \"kernel_sum_rows_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, \"kernel_sum_rows_f32_4\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // cumsum\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"cumsum.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"cumsum.cl\");\n#endif\n        cl_program prog;\n        prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_cumsum_blk = 
clCreateKernel(prog, \"kernel_cumsum_blk\", &err), err));\n        CL_CHECK((backend_ctx->kernel_cumsum_add = clCreateKernel(prog, \"kernel_cumsum_add\", &err), err));\n        GGML_LOG_CONT(\".\");\n        CL_CHECK(clReleaseProgram(prog));\n    }\n\n    // sigmoid\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"sigmoid.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"sigmoid.cl\");\n#endif\n        backend_ctx->program_sigmoid =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_sigmoid_f32 = clCreateKernel(backend_ctx->program_sigmoid, \"kernel_sigmoid_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_sigmoid_f16 = clCreateKernel(backend_ctx->program_sigmoid, \"kernel_sigmoid_f16\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // group_norm\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"group_norm.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"group_norm.cl\");\n#endif\n        backend_ctx->program_group_norm =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_group_norm         = clCreateKernel(backend_ctx->program_group_norm, \"kernel_group_norm\", &err), err));\n        CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, \"kernel_group_norm_mul_add\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // repeat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"repeat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"repeat.cl\");\n#endif\n        cl_program prog =\n            
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->kernel_repeat_f32 = clCreateKernel(prog, \"kernel_repeat_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // pad\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"pad.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"pad.cl\");\n#endif\n        if (!kernel_src.empty()) {\n            backend_ctx->program_pad =\n                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n            CL_CHECK((backend_ctx->kernel_pad = clCreateKernel(backend_ctx->program_pad, \"kernel_pad\", &err), err));\n            GGML_LOG_CONT(\".\");\n        } else {\n            GGML_LOG_WARN(\"ggml_opencl: pad kernel source not found or empty. Pad operations will not be available.\\n\");\n            backend_ctx->program_pad = nullptr;\n            backend_ctx->kernel_pad = nullptr;\n        }\n    }\n\n    // tanh\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"tanh.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"tanh.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->kernel_tanh_f32    = clCreateKernel(prog, \"kernel_tanh_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_tanh_f32_4  = clCreateKernel(prog, \"kernel_tanh_f32_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_tanh_f32_nc = clCreateKernel(prog, \"kernel_tanh_f32_nc\", &err), err));\n        CL_CHECK((backend_ctx->kernel_tanh_f16    = clCreateKernel(prog, \"kernel_tanh_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_tanh_f16_4  = 
clCreateKernel(prog, \"kernel_tanh_f16_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_tanh_f16_nc = clCreateKernel(prog, \"kernel_tanh_f16_nc\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // neg\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"neg.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"neg.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->kernel_neg_f32    = clCreateKernel(prog, \"kernel_neg_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_neg_f32_4  = clCreateKernel(prog, \"kernel_neg_f32_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_neg_f32_nc = clCreateKernel(prog, \"kernel_neg_f32_nc\", &err), err));\n        CL_CHECK((backend_ctx->kernel_neg_f16    = clCreateKernel(prog, \"kernel_neg_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_neg_f16_4  = clCreateKernel(prog, \"kernel_neg_f16_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_neg_f16_nc = clCreateKernel(prog, \"kernel_neg_f16_nc\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // exp\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"exp.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"exp.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->kernel_exp_f32    = clCreateKernel(prog, \"kernel_exp_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_exp_f32_4  = clCreateKernel(prog, \"kernel_exp_f32_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_exp_f32_nc = 
clCreateKernel(prog, \"kernel_exp_f32_nc\", &err), err));\n        CL_CHECK((backend_ctx->kernel_exp_f16    = clCreateKernel(prog, \"kernel_exp_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_exp_f16_4  = clCreateKernel(prog, \"kernel_exp_f16_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_exp_f16_nc = clCreateKernel(prog, \"kernel_exp_f16_nc\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // expm1\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"expm1.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"expm1.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->kernel_expm1_f32    = clCreateKernel(prog, \"kernel_expm1_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_expm1_f32_4  = clCreateKernel(prog, \"kernel_expm1_f32_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_expm1_f32_nc = clCreateKernel(prog, \"kernel_expm1_f32_nc\", &err), err));\n        CL_CHECK((backend_ctx->kernel_expm1_f16    = clCreateKernel(prog, \"kernel_expm1_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_expm1_f16_4  = clCreateKernel(prog, \"kernel_expm1_f16_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_expm1_f16_nc = clCreateKernel(prog, \"kernel_expm1_f16_nc\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // softplus\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"softplus.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"softplus.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n        
CL_CHECK((backend_ctx->kernel_softplus_f32    = clCreateKernel(prog, \"kernel_softplus_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_softplus_f32_4  = clCreateKernel(prog, \"kernel_softplus_f32_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, \"kernel_softplus_f32_nc\", &err), err));\n        CL_CHECK((backend_ctx->kernel_softplus_f16    = clCreateKernel(prog, \"kernel_softplus_f16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_softplus_f16_4  = clCreateKernel(prog, \"kernel_softplus_f16_4\", &err), err));\n        CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, \"kernel_softplus_f16_nc\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // upscale\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"upscale.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"upscale.cl\");\n#endif\n        if (!kernel_src.empty()) {\n            backend_ctx->program_upscale =\n                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n            CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, \"kernel_upscale\", &err), err));\n            if (backend_ctx->program_upscale) {\n                 cl_int err_bilinear;\n                 backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, \"kernel_upscale_bilinear\", &err_bilinear);\n                 if (err_bilinear != CL_SUCCESS) {\n                    GGML_LOG_WARN(\"ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. 
Error: %d\\n\", err_bilinear);\n                    backend_ctx->kernel_upscale_bilinear = nullptr;\n                 }\n            } else {\n                backend_ctx->kernel_upscale_bilinear = nullptr;\n            }\n            GGML_LOG_CONT(\".\");\n        } else {\n            GGML_LOG_WARN(\"ggml_opencl: upscale kernel source not found or empty. Upscale operations will not be available.\\n\");\n            backend_ctx->program_upscale = nullptr;\n            backend_ctx->kernel_upscale = nullptr;\n            backend_ctx->kernel_upscale_bilinear = nullptr;\n        }\n    }\n\n    // concat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"concat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"concat.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, \"kernel_concat_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // timestep_embedding\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"tsembd.cl.h\"\n        };\n#else\n\n        const std::string kernel_src = read_file(\"tsembd.cl\");\n#endif\n        if (!kernel_src.empty()) {\n            backend_ctx->program_tsembd =\n                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n            CL_CHECK((backend_ctx->kernel_timestep_embedding = clCreateKernel(backend_ctx->program_tsembd, \"kernel_timestep_embedding\", &err), err));\n            GGML_LOG_CONT(\".\");\n        } else {\n            GGML_LOG_WARN(\"ggml_opencl: timestep_embedding kernel source not found or empty. 
This op will not be available.\\n\");\n            backend_ctx->program_tsembd = nullptr;\n            backend_ctx->kernel_timestep_embedding = nullptr;\n        }\n    }\n\n    // set_rows\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"set_rows.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"set_rows.cl\");\n#endif\n        backend_ctx->program_set_rows =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_set_rows_f32_i64 = clCreateKernel(backend_ctx->program_set_rows, \"kernel_set_rows_f32_i64\", &err), err));\n        CL_CHECK((backend_ctx->kernel_set_rows_f32_i32 = clCreateKernel(backend_ctx->program_set_rows, \"kernel_set_rows_f32_i32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_set_rows_f16_i64 = clCreateKernel(backend_ctx->program_set_rows, \"kernel_set_rows_f16_i64\", &err), err));\n        CL_CHECK((backend_ctx->kernel_set_rows_f16_i32 = clCreateKernel(backend_ctx->program_set_rows, \"kernel_set_rows_f16_i32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n     // conv2d\n     {\n        #ifdef GGML_OPENCL_EMBED_KERNELS\n                const std::string kernel_src {\n                    #include \"conv2d.cl.h\"\n                };\n                const std::string kernel_src_f16_f32 {\n                    #include \"conv2d_f16_f32.cl.h\"\n                };\n        #else\n                const std::string kernel_src = read_file(\"conv2d.cl\");\n                const std::string kernel_src_f16_f32 = read_file(\"conv2d_f16_f32.cl\");\n        #endif\n                if (!kernel_src.empty()) {\n                    backend_ctx->program_conv_2d_f16 =\n                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + \" -DUSE_FP16=1\").c_str());\n                    
CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, \"kernel_conv_2d\", &err), err));\n                    GGML_LOG_CONT(\".\");\n                    backend_ctx->program_conv_2d_f32 =\n                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n                    CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, \"kernel_conv_2d\", &err), err));\n                    GGML_LOG_CONT(\".\");\n                } else {\n                    GGML_LOG_WARN(\"ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\\n\");\n                    backend_ctx->program_conv_2d_f16 = nullptr;\n                    backend_ctx->kernel_conv_2d_f16 = nullptr;\n                    backend_ctx->program_conv_2d_f32 = nullptr;\n                    backend_ctx->kernel_conv_2d_f32 = nullptr;\n                }\n                if (!kernel_src_f16_f32.empty()) {\n                    backend_ctx->program_conv_2d_f16_f32 =\n                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);\n                    CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, \"kernel_conv_2d\", &err), err));\n                    GGML_LOG_CONT(\".\");\n                } else {\n                    GGML_LOG_WARN(\"ggml_opencl: conv2d_f16_f32 kernel source not found or empty. 
This op will not be available.\\n\");\n                    backend_ctx->program_conv_2d_f16_f32 = nullptr;\n                    backend_ctx->kernel_conv_2d_f16_f32 = nullptr;\n                }\n    }\n\n    // ssm_conv\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"ssm_conv.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"ssm_conv.cl\");\n#endif\n        cl_program prog =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32   = clCreateKernel(prog, \"kernel_ssm_conv_f32_f32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, \"kernel_ssm_conv_f32_f32_4\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_id_q4_0_f32_8x_flat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_id_q4_0_f32_8x_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_id_q4_0_f32_8x_flat.cl\");\n#endif\n        backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat, \"kernel_mul_mv_id_q4_0_f32_8x_flat\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_id_q8_0_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_id_q8_0_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_id_q8_0_f32.cl\");\n#endif\n        backend_ctx->program_mul_mv_id_q8_0_f32 =\n            build_program_from_source(backend_ctx->context, 
backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_id_q8_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_q8_0_f32, \"kernel_mul_mv_id_q8_0_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_id_q8_0_f32_flat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_id_q8_0_f32_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_id_q8_0_f32_flat.cl\");\n#endif\n        backend_ctx->program_mul_mv_id_q8_0_f32_flat =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_id_q8_0_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q8_0_f32_flat, \"kernel_mul_mv_id_q8_0_f32_flat\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_id_mxfp4_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_id_mxfp4_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_id_mxfp4_f32.cl\");\n#endif\n        backend_ctx->program_mul_mv_id_mxfp4_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32, \"kernel_mul_mv_id_mxfp4_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mv_id_mxfp4_f32_flat\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"mul_mv_id_mxfp4_f32_flat.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"mul_mv_id_mxfp4_f32_flat.cl\");\n#endif\n        backend_ctx->program_mul_mv_id_mxfp4_f32_flat =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, 
kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32_flat, \"kernel_mul_mv_id_mxfp4_f32_flat\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // Adreno kernels\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n    // transpose\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"transpose.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"transpose.cl\");\n#endif\n        backend_ctx->program_transpose =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, \"kernel_transpose_32_16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_transpose_32    = clCreateKernel(backend_ctx->program_transpose, \"kernel_transpose_32\", &err), err));\n        CL_CHECK((backend_ctx->kernel_transpose_16    = clCreateKernel(backend_ctx->program_transpose, \"kernel_transpose_16\", &err), err));\n        CL_CHECK((backend_ctx->kernel_transpose_8_buf  = clCreateKernel(backend_ctx->program_transpose, \"kernel_transpose_8_buf\", &err), err));\n        CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, \"kernel_transpose_16_buf\", &err), err));\n        CL_CHECK((backend_ctx->kernel_transpose_32_buf = clCreateKernel(backend_ctx->program_transpose, \"kernel_transpose_32_buf\", &err), err));\n        CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, \"kernel_transpose_16_4x1\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // gemv_noshuffle_general\n    {\n        std::string CL_gemv_compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n                                       \" -cl-mad-enable \"\n                                   
    \" -DSIMDGROUP_WIDTH=\" +\n                                       std::to_string(backend_ctx->adreno_wave_size);\n        if (backend_ctx->has_vector_subgroup_broadcast) {\n            CL_gemv_compile_opts += \" -DVECTOR_SUB_GROUP_BROADCAT \";\n        }\n\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src_CL_gemv_general {\n            #include \"gemv_noshuffle_general.cl.h\"\n        };\n#else\n        const std::string kernel_src_CL_gemv_general = read_file(\"gemv_noshuffle_general.cl\");\n#endif\n\n        backend_ctx->program_CL_gemv_general = build_program_from_source(\n            backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts);\n\n        CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, \"kernel_gemv_noshuffle\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // gemv_noshuffle\n    {\n        // Gemv 2048, 16384\n        std::string CL_gemv_compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n            \" -cl-mad-enable \"\n            \" -DLINE_STRIDE_A=2048 \"\n            \" -DBLOCK_STRIDE_A=16384 \"\n            \" -DSIMDGROUP_WIDTH=\" +\n            std::to_string(backend_ctx->adreno_wave_size);\n        if (backend_ctx->has_vector_subgroup_broadcast) {\n            CL_gemv_compile_opts += \" -DVECTOR_SUB_GROUP_BROADCAT \";\n        }\n\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src_CL_gemv {\n            #include \"gemv_noshuffle.cl.h\"\n        };\n#else\n        const std::string kernel_src_CL_gemv = read_file(\"gemv_noshuffle.cl\");\n#endif\n\n        backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source(\n            backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);\n        CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = 
clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, \"kernel_gemv_noshuffle\", &err), err));\n        GGML_LOG_CONT(\".\");\n\n        // Gemv 2048, 16384\n        CL_gemv_compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n            \" -cl-mad-enable \"\n            \" -DLINE_STRIDE_A=2048 \"\n            \" -DBLOCK_STRIDE_A=16384 \"\n            \" -DSIMDGROUP_WIDTH=\" +\n            std::to_string(backend_ctx->adreno_wave_size);\n        if (backend_ctx->has_vector_subgroup_broadcast) {\n            CL_gemv_compile_opts += \" -DVECTOR_SUB_GROUP_BROADCAT \";\n        }\n\n        backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source(\n            backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);\n        CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, \"kernel_gemv_noshuffle\", &err), err));\n        GGML_LOG_CONT(\".\");\n\n        // Gemv 5504, 44032\n        CL_gemv_compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n            \" -cl-mad-enable \"\n            \" -DLINE_STRIDE_A=5504 \"\n            \" -DBLOCK_STRIDE_A=44032 \"\n            \" -DSIMDGROUP_WIDTH=\" +\n            std::to_string(backend_ctx->adreno_wave_size);\n        if (backend_ctx->has_vector_subgroup_broadcast) {\n            CL_gemv_compile_opts += \" -DVECTOR_SUB_GROUP_BROADCAT \";\n        }\n\n        backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source(\n            backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);\n        CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, \"kernel_gemv_noshuffle\", &err), err));\n        GGML_LOG_CONT(\".\");\n\n        // Gemv 16000, 128000\n        CL_gemv_compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n            \" -cl-mad-enable 
\"\n            \" -DLINE_STRIDE_A=16000 \"\n            \" -DBLOCK_STRIDE_A=128000 \"\n            \" -DSIMDGROUP_WIDTH=\" +\n            std::to_string(backend_ctx->adreno_wave_size);\n\n        if (backend_ctx->has_vector_subgroup_broadcast) {\n            CL_gemv_compile_opts += \" -DVECTOR_SUB_GROUP_BROADCAT \";\n        }\n\n        backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(\n            backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);\n        CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, \"kernel_gemv_noshuffle\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mat_Ab_Bi_8x4\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src_CL_gemm {\n            #include \"mul_mat_Ab_Bi_8x4.cl.h\"\n        };\n#else\n        const std::string kernel_src_CL_gemm = read_file(\"mul_mat_Ab_Bi_8x4.cl\");\n#endif\n        backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, \"kernel_mul_mat_Ab_Bi_8x4\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // gemm_noshuffle_q4_1_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"gemm_noshuffle_q4_1_f32.cl.h\"\n       };\n#else\n        const std::string kernel_src = read_file(\"gemm_noshuffle_q4_1_f32.cl\");\n#endif\n        cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_1_f32 = clCreateKernel(prog, \"kernel_gemm_noshuffle_q4_1_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // 
gemv_noshuffle_q4_1_f32\n    {\n        std::string CL_gemv_compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n                                       \" -cl-mad-enable \";\n        if (backend_ctx->has_vector_subgroup_broadcast) {\n            CL_gemv_compile_opts += \" -DVECTOR_SUB_GROUP_BROADCAT \";\n        }\n\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"gemv_noshuffle_q4_1_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"gemv_noshuffle_q4_1_f32.cl\");\n#endif\n\n        cl_program prog = build_program_from_source(\n            backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_1_f32 = clCreateKernel(prog, \"kernel_gemv_noshuffle_q4_1_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // mul_mm_q8_0_f32_8x4\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src_q8_8x4_gemm {\n            #include \"mul_mm_q8_0_f32_8x4.cl.h\"\n       };\n#else\n        const std::string kernel_src_q8_8x4_gemm = read_file(\"mul_mm_q8_0_f32_8x4.cl\");\n#endif\n        backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_q8_8x4_gemm.c_str(), compile_opts);\n        CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, \"kernel_mul_mm_q8_0_f32_8x4\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // gemv_noshuffle_general_q8_0_f32\n    {\n        std::string CL_gemv_compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n                                       \" -cl-mad-enable \"\n                                       \" -DSIMDGROUP_WIDTH=\" +\n                                       std::to_string(backend_ctx->adreno_wave_size);\n        if (backend_ctx->has_vector_subgroup_broadcast) {\n            
CL_gemv_compile_opts += \" -DVECTOR_SUB_GROUP_BROADCAT \";\n        }\n\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src_CL_gemv_general {\n            #include \"gemv_noshuffle_general_q8_0_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src_CL_gemv_general = read_file(\"gemv_noshuffle_general_q8_0_f32.cl\");\n#endif\n\n        cl_program prog = build_program_from_source(\n            backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts);\n\n        CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, \"kernel_gemv_noshuffle_q8_0_f32\", &err), err));\n        CL_CHECK(clReleaseProgram(prog));\n        GGML_LOG_CONT(\".\");\n    }\n\n    std::string CL_moe_compile_opts = std::string(\"-cl-std=\") + opencl_c_std +\n            \" -cl-mad-enable \"\n            \" -cl-fast-relaxed-math\";\n\n    // gemv_moe_mxfp4_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"gemv_moe_mxfp4_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"gemv_moe_mxfp4_f32.cl\");\n#endif\n        backend_ctx->program_gemv_moe_mxfp4_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);\n\n        CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, \"kernel_gemv_moe_mxfp4_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n\n    // gemm_moe_mxfp4_f32\n    {\n#ifdef GGML_OPENCL_EMBED_KERNELS\n        const std::string kernel_src {\n            #include \"gemm_moe_mxfp4_f32.cl.h\"\n        };\n#else\n        const std::string kernel_src = read_file(\"gemm_moe_mxfp4_f32.cl\");\n#endif\n        backend_ctx->program_gemm_moe_mxfp4_f32 =\n            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);\n\n        
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, \"kernel_gemm_moe_mxfp4_f32\", &err), err));\n        GGML_LOG_CONT(\".\");\n    }\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n    GGML_LOG_CONT(\"\\n\");\n}\n\n// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {\n// XXX    static bool initialized = false;\n// XXX    static ggml_backend_opencl_context *backend_ctx = nullptr;\n\nstatic ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev);\n\nnamespace /* anonymous */ {\nextern struct ggml_backend_device_i ggml_backend_opencl_device_i;\n}\n\n// Look for available and suitable devices.\nstatic std::vector<ggml_backend_device> ggml_opencl_probe_devices(ggml_backend_reg * reg) {\n    std::vector<ggml_backend_device> found_devices;\n\n#ifdef GGML_OPENCL_PROFILING\n    GGML_LOG_INFO(\"ggml_opencl: OpenCL profiling enabled\\n\");\n#endif\n\n    struct cl_device;\n    struct cl_platform {\n        cl_platform_id id;\n        unsigned number;\n        char name[128];\n        char vendor[128];\n        struct cl_device * devices;\n        unsigned n_devices;\n        struct cl_device * default_device;\n    };\n\n    struct cl_device {\n        struct cl_platform * platform;\n        cl_device_id id;\n        unsigned number;\n        cl_device_type type;\n        char name[128];\n        char version[128];\n    };\n\n    enum { NPLAT = 16, NDEV = 16 };\n\n    struct cl_platform platforms[NPLAT];\n    unsigned n_platforms = 0;\n    struct cl_device devices[NDEV];\n    unsigned n_devices = 0;\n    struct cl_device * default_device = NULL;\n    unsigned           default_platform_number = 0;\n\n    cl_platform_id platform_ids[NPLAT];\n    if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) {\n        GGML_LOG_ERROR(\"ggml_opencl: platform IDs not available.\\n\");\n        return found_devices;\n    }\n\n    for (unsigned i = 0; i < n_platforms; i++) {\n      
  // clGetPlatformIDs reports the total number of platforms, which can exceed\n        // the NPLAT entries it wrote into platform_ids; ignore any extra ones so\n        // platforms[] and platform_ids[] are never indexed out of bounds.\n        if (i >= NPLAT) {\n            GGML_LOG_WARN(\"ggml_opencl: only the first %u of %u platforms are considered.\\n\", (unsigned) NPLAT, n_platforms);\n            n_platforms = NPLAT;\n            break;\n        }\n        struct cl_platform * p = &platforms[i];\n        p->number = i;\n        p->id = platform_ids[i];\n        CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));\n        CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));\n\n        cl_device_id device_ids[NDEV];\n        cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);\n        if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {\n            p->n_devices = 0;\n        } else {\n            CL_CHECK(clGetDeviceIDsError);\n        }\n        // clGetDeviceIDs also reports the total number of matching devices, which can\n        // exceed both the NDEV entries written into device_ids and the slots still\n        // free in devices[]; clamp so neither array is indexed out of bounds.\n        const unsigned n_free_device_slots = NDEV - n_devices;\n        if (p->n_devices > n_free_device_slots) {\n            GGML_LOG_WARN(\"ggml_opencl: platform '%s' reports %u devices, only %u are considered.\\n\",\n                          p->name, p->n_devices, n_free_device_slots);\n            p->n_devices = n_free_device_slots;\n        }\n        p->devices = p->n_devices > 0 ? 
&devices[n_devices] : NULL;\n        p->default_device = NULL;\n\n        for (unsigned j = 0; j < p->n_devices; j++) {\n            struct cl_device * d = &devices[n_devices];\n            d->number = n_devices++;\n            d->id = device_ids[j];\n            d->platform = p;\n            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));\n            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));\n            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_VERSION, sizeof(d->version), &d->version, NULL));\n\n            if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {\n                p->default_device = d;\n            }\n        }\n\n        if (default_device == NULL && p->default_device != NULL) {\n            default_device          = p->default_device;\n            default_platform_number = i;\n        }\n    }\n\n    if (n_devices == 0) {\n        GGML_LOG_ERROR(\"ggml_opencl: could not find any OpenCL devices.\\n\");\n        return found_devices;\n    }\n\n    char *      user_platform_string = getenv(\"GGML_OPENCL_PLATFORM\");\n    char *      user_device_string   = getenv(\"GGML_OPENCL_DEVICE\");\n    int         user_platform_number = -1;\n    int         user_device_number   = -1;\n    cl_device * candidate_devices    = 
nullptr;\n    unsigned    n_candidate_devices  = 0;\n\n    unsigned n;\n    if (user_platform_string != NULL && sscanf(user_platform_string, \" %u\", &n) == 1 && n < n_platforms) {\n        user_platform_number = (int)n;\n    }\n    if (user_device_string != NULL && sscanf(user_device_string, \" %u\", &n) == 1 && n < n_devices) {\n        user_device_number = (int)n;\n    }\n    if (user_platform_number != -1 && user_device_number != -1) {\n        cl_platform* platform = &platforms[user_platform_number];\n        if ((unsigned)user_device_number >= platform->n_devices) {\n            GGML_LOG_ERROR(\"ggml_opencl: invalid device number %d\\n\", user_device_number);\n            exit(1);\n        }\n        default_device      = &platform->devices[user_device_number];\n        candidate_devices   = platform->devices;\n        n_candidate_devices = platform->n_devices;\n    } else {\n        // Choose a platform by matching a substring.\n        if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {\n            for (unsigned i = 0; i < n_platforms; i++) {\n                struct cl_platform * p = &platforms[i];\n                if (strstr(p->name, user_platform_string) != NULL ||\n                    strstr(p->vendor, user_platform_string) != NULL) {\n                    user_platform_number = (int)i;\n                    break;\n                }\n            }\n            if (user_platform_number == -1) {\n                GGML_LOG_ERROR(\"ggml_opencl: no platform matching '%s' was found.\\n\", user_platform_string);\n                exit(1);\n            }\n        }\n\n        int                  platform_idx = user_platform_number != -1 ? 
user_platform_number : default_platform_number;\n        struct cl_platform * p            = &platforms[platform_idx];\n        candidate_devices                 = p->devices;\n        n_candidate_devices               = p->n_devices;\n        default_device                    = p->default_device;\n        if (n_candidate_devices == 0) {\n            GGML_LOG_ERROR(\"ggml_opencl: selected platform '%s' does not have any devices.\\n\", p->name);\n            exit(1);\n        }\n\n        if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {\n            for (unsigned i = 0; i < n_candidate_devices; i++) {\n                struct cl_device * d = &candidate_devices[i];\n                if (strstr(d->name, user_device_string) != NULL) {\n                    user_device_number = d->number;\n                    break;\n                }\n            }\n            if (user_device_number == -1) {\n                GGML_LOG_ERROR(\"ggml_opencl: no device matching '%s' was found.\\n\", user_device_string);\n                exit(1);\n            }\n        }\n        if (user_device_number != -1) {\n            candidate_devices   = &devices[user_device_number];\n            n_candidate_devices = 1;\n            default_device      = &candidate_devices[0];\n        }\n\n        GGML_ASSERT(n_candidate_devices > 0);\n\n        if (default_device == NULL) {\n            default_device = &candidate_devices[0];\n        }\n    }\n\n    GGML_ASSERT(n_candidate_devices != 0 && candidate_devices);\n\n    // Put the default device in front.\n    for (unsigned i = 1; i < n_candidate_devices; i++) {\n        if (&candidate_devices[i] == default_device) {\n            std::swap(candidate_devices[0], candidate_devices[i]);\n            default_device = &candidate_devices[0];\n            break;\n        }\n    }\n\n    GGML_LOG_INFO(\"ggml_opencl: selected platform: '%s'\\n\", default_device->platform->name);\n\n    std::vector<cl_device_id> 
device_ids;\n    for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) {\n        device_ids.push_back(dev->id);\n    }\n\n    cl_int                err;\n    cl_context            shared_context;\n    cl_context_properties properties[] = { (intptr_t) CL_CONTEXT_PLATFORM, (intptr_t) default_device->platform->id, 0 };\n\n    CL_CHECK(\n        (shared_context = clCreateContext(properties, device_ids.size(), device_ids.data(), NULL, NULL, &err), err));\n\n    for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) {\n        GGML_LOG_INFO(\"\\nggml_opencl: device: '%s (%s)'\\n\", dev->name, dev->version);\n\n        auto dev_ctx = std::unique_ptr<ggml_backend_opencl_device_context>(new ggml_backend_opencl_device_context{\n            /*.platform         =*/dev->platform->id,\n            /*.platform_nane    =*/dev->platform->name,\n            /*.device           =*/dev->id,\n            /*.device_name      =*/dev->name,\n            /*.device_type      =*/dev->type,\n            /*.device_version   =*/dev->version,\n            /*.backend_ctx      =*/nullptr,\n            /*.buffer_type      =*/{},\n            /*.context          =*/shared_context,\n        });\n\n        found_devices.push_back(ggml_backend_device{\n            /* .iface   = */ ggml_backend_opencl_device_i,\n            /* .reg     = */ reg,\n            /* .context = */ dev_ctx.get(),\n        });\n\n        if (!ggml_cl2_init(&found_devices.back())) {\n            found_devices.pop_back();\n            GGML_LOG_INFO(\"ggml_opencl: drop unsupported device.\\n\");\n            continue;\n        }\n\n        dev_ctx.release();\n    }\n\n    if (found_devices.size()) {\n        auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(found_devices.front().context);\n        GGML_LOG_INFO(\"ggml_opencl: default device: '%s (%s)'\\n\", dev_ctx->device_name.c_str(),\n                   
   dev_ctx->device_version.c_str());\n\n        if (dev_ctx->device_type != CL_DEVICE_TYPE_GPU) {\n            GGML_LOG_WARN(\"ggml_opencl: warning, the default device is not a GPU: '%s'.\\n\",\n                          dev_ctx->device_name.c_str());\n        }\n    }\n\n    return found_devices;\n}\n\n// Initialize device if it is supported (returns nullptr if it is not).\nstatic ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {\n    GGML_ASSERT(dev);\n    GGML_ASSERT(dev->context);\n\n    ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context;\n    GGML_ASSERT(dev_ctx->platform);\n    GGML_ASSERT(dev_ctx->device);\n\n    if (dev_ctx->backend_ctx) {\n        return dev_ctx->backend_ctx;\n    }\n\n    auto backend_ctx        = std::make_unique<ggml_backend_opencl_context>();\n    backend_ctx->device     = dev_ctx->device;\n    backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;\n\n    // ref_count get increased in ggml_backend_opencl_device_init\n    // This function is also used to retrieve backend context, so we don't want\n    // to increase ref_count for each call. 
We only want to increase ref_count\n    // when the associated device is initialized\n    backend_ctx->ref_count  = 0;\n\n    if (strstr(dev_ctx->device_name.c_str(), \"Adreno\") ||\n        strstr(dev_ctx->device_name.c_str(), \"Qualcomm\") ||\n        strstr(dev_ctx->device_version.c_str(), \"Adreno\")) {\n        backend_ctx->gpu_family = GPU_FAMILY::ADRENO;\n        // Usually device version contains the detailed device name\n        backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str());\n        if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) {\n            backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str());\n        }\n\n        // Use wave size of 64 for all Adreno GPUs.\n        backend_ctx->adreno_wave_size = 64;\n    } else if (strstr(dev_ctx->device_name.c_str(), \"Intel\")) {\n        backend_ctx->gpu_family = GPU_FAMILY::INTEL;\n    } else {\n        GGML_LOG_ERROR(\"Unsupported GPU: %s\\n\", dev_ctx->device_name.c_str());\n        backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;\n        return nullptr;\n    }\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n    if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) {\n        GGML_LOG_ERROR(\"ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; \"\n            \"run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\\n\");\n        return nullptr;\n    }\n#endif\n\n    // Populate backend device name\n    backend_ctx->device_name = dev_ctx->device_name;\n\n    // A local ref of cl_device_id for convenience\n    cl_device_id device = backend_ctx->device;\n\n    ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform);\n\n    // Check device OpenCL version, OpenCL 2.0 or above is required\n    ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device);\n    if (opencl_c_version.major < 2) {\n        GGML_LOG_ERROR(\"ggml_opencl: OpenCL 2.0 or 
above is required\\n\");\n        return nullptr;\n    }\n\n    // Check driver version\n    size_t driver_version_str_size;\n    clGetDeviceInfo(device, CL_DRIVER_VERSION, 0, NULL, &driver_version_str_size);\n    char *driver_version = (char *)alloca(driver_version_str_size + 1);\n    clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL);\n    driver_version[driver_version_str_size] = '\\0';\n    GGML_LOG_INFO(\"ggml_opencl: OpenCL driver: %s\\n\", driver_version);\n    backend_ctx->driver_version = driver_version;\n\n    backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version);\n    backend_ctx->has_vector_subgroup_broadcast =\n        (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) ||\n        (backend_ctx->adreno_cl_compiler_version.type == DX   && backend_ctx->adreno_cl_compiler_version.major >= 17);\n    GGML_LOG_INFO(\"ggml_opencl: vector subgroup broadcast support: %s\\n\",\n        backend_ctx->has_vector_subgroup_broadcast ? \"true\" : \"false\");\n\n    size_t ext_str_size;\n    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size);\n    char *ext_buffer = (char *)alloca(ext_str_size + 1);\n    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL);\n    ext_buffer[ext_str_size] = '\\0'; // ensure it is null terminated\n    // Check if ext_buffer contains cl_khr_fp16\n    backend_ctx->fp16_support = strstr(ext_buffer, \"cl_khr_fp16\") != NULL;\n    GGML_LOG_INFO(\"ggml_opencl: device FP16 support: %s\\n\", backend_ctx->fp16_support ? 
\"true\" : \"false\");\n\n    // fp16 is required\n    if (!backend_ctx->fp16_support) {\n        GGML_LOG_ERROR(\"ggml_opencl: device does not support FP16\\n\");\n        return nullptr;\n    }\n\n    // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes\n    // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x)\n    if (opencl_c_version.major == 3 && strstr(ext_buffer, \"cl_khr_subgroups\") == NULL &&\n        strstr(ext_buffer, \"cl_intel_subgroups\") == NULL) {\n        GGML_LOG_ERROR(\"ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) \"\n            \"(note that subgroups is an optional feature in OpenCL 3.0)\\n\");\n        return nullptr;\n    }\n\n    cl_uint base_align_in_bits;\n    CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL));\n    GGML_ASSERT(base_align_in_bits % 8u == 0);\n    backend_ctx->alignment = base_align_in_bits / 8u;\n    GGML_LOG_INFO(\"ggml_opencl: mem base addr align: %u\\n\", backend_ctx->alignment);\n\n    clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);\n    GGML_LOG_INFO(\"ggml_opencl: max mem alloc size: %zu MB\\n\", backend_ctx->max_alloc_size/1024/1024);\n\n    // The next two limits are queried with sizeof(size_t) (presumably into size_t\n    // fields), so print them with %zu like the line above; %lu does not match\n    // size_t where unsigned long is only 32 bits wide (LLP64, e.g. 
Windows).\n    clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL);\n    GGML_LOG_INFO(\"ggml_opencl: device max image buffer size (pixels): %zu\\n\", backend_ctx->image_max_buffer_size);\n\n    clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);\n    GGML_LOG_INFO(\"ggml_opencl: device max workgroup size: %zu\\n\", backend_ctx->max_workgroup_size);\n\n    // Check SVM.\n    cl_device_svm_capabilities svm_caps;\n    CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0));\n    GGML_LOG_INFO(\"ggml_opencl: SVM coarse 
grain buffer support: %s\\n\",\n        svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? \"true\" : \"false\");\n    GGML_LOG_INFO(\"ggml_opencl: SVM fine grain buffer support: %s\\n\",\n        svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? \"true\" : \"false\");\n    GGML_LOG_INFO(\"ggml_opencl: SVM fine grain system support: %s\\n\",\n        svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? \"true\" : \"false\");\n    GGML_LOG_INFO(\"ggml_opencl: SVM atomics support: %s\\n\",\n        svm_caps & CL_DEVICE_SVM_ATOMICS ? \"true\" : \"false\");\n\n    if (opencl_c_version.major >= 3) {\n        // Assume it is not available for 3.0, since it is optional in 3.0.\n        // If compiling against 3.0, then we can query.\n        backend_ctx->non_uniform_workgroups = false;\n#if CL_TARGET_OPENCL_VERSION >= 300\n        CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool),\n                                 &backend_ctx->non_uniform_workgroups, 0));\n#endif\n    } else {\n        GGML_ASSERT(opencl_c_version.major == 2);\n        // Non-uniform workgroup sizes is mandatory feature in v2.x.\n        backend_ctx->non_uniform_workgroups = true;\n    }\n\n    // Print out configurations\n#ifdef GGML_OPENCL_SOA_Q\n    GGML_LOG_INFO(\"ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\\n\");\n#endif // GGML_OPENCL_SOA_Q\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n    GGML_LOG_INFO(\"ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\\n\");\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n\n    cl_int err;\n\n    // A local ref of cl_context for convenience\n    cl_context context = backend_ctx->context = dev_ctx->context;\n\n    //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),\n    //    (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? 
err :\n    //    (queue = clCreateCommandQueue(context, device, 0, &err), err)\n    //)));\n    cl_command_queue_properties command_queue_props = 0;\n#ifdef GGML_OPENCL_PROFILING\n    command_queue_props |= CL_QUEUE_PROFILING_ENABLE;\n#endif\n    CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err));\n\n    // Load kernels\n    load_cl_kernels(backend_ctx.get(), opencl_c_version);\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n    // Allocate intermediate buffers and images\n    size_t required_A_q_d_bytes = 311164928;\n    size_t required_A_s_d_bytes = 38895616;\n    size_t required_B_d_bytes = 45088768;\n\n    // Ensure buffer sizes do not exceed the maximum allocation size\n    size_t max_A_q_d_bytes = MIN(required_A_q_d_bytes, backend_ctx->max_alloc_size);\n    size_t max_A_s_d_bytes = MIN(required_A_s_d_bytes, backend_ctx->max_alloc_size);\n    size_t max_B_d_bytes   = MIN(required_B_d_bytes, backend_ctx->max_alloc_size);\n    if (required_A_q_d_bytes > backend_ctx->max_alloc_size) {\n        GGML_LOG_WARN(\"ggml_opencl: A_q_d buffer size reduced from %zu to %zu due to device limitations.\\n\",\n                      required_A_q_d_bytes, max_A_q_d_bytes);\n    }\n    if (required_A_s_d_bytes > backend_ctx->max_alloc_size) {\n        GGML_LOG_WARN(\"ggml_opencl: A_s_d buffer size reduced from %zu to %zu due to device limitations.\\n\",\n                      required_A_s_d_bytes, max_A_s_d_bytes);\n    }\n    if (required_B_d_bytes > backend_ctx->max_alloc_size) {\n        GGML_LOG_WARN(\"ggml_opencl: B_d buffer size reduced from %zu to %zu due to device limitations.\\n\",\n                      required_B_d_bytes, max_B_d_bytes);\n    }\n\n    backend_ctx->prealloc_quant_trans.allocate(context, max_A_q_d_bytes);\n    backend_ctx->prealloc_scales_trans.allocate(context, max_A_s_d_bytes);\n    backend_ctx->prealloc_act_trans.allocate(context, max_B_d_bytes);\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n\n    
backend_ctx->disable_fusion = getenv(\"GGML_OPENCL_DISABLE_FUSION\") != nullptr;\n\n    dev_ctx->backend_ctx = backend_ctx.release();\n    return dev_ctx->backend_ctx;\n}\n\nstatic void ggml_cl2_free(ggml_backend_t backend) {\n    ggml_backend_opencl_context * ctx = (ggml_backend_opencl_context *) backend->context;\n    ctx->free();\n\n    // The CL context is shared by all backends, release it if all backends have been released\n    bool should_release_opencl = true;\n    for (auto device : g_ggml_backend_opencl_devices) {\n        ggml_backend_opencl_device_context * ctx_dev = (ggml_backend_opencl_device_context *) device.context;\n        if (ctx_dev->backend_ctx->ref_count > 0) {\n            should_release_opencl = false;\n        }\n    }\n\n    if (should_release_opencl) {\n        CL_CHECK(clReleaseContext(ctx->context));\n    }\n}\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\nstatic void transpose_2d(\n    ggml_backend_opencl_context * backend_ctx,\n    cl_kernel kernel,\n    cl_mem src, cl_mem dst, size_t size,\n    cl_int stride, cl_int rows,\n    bool blocking = true\n) {\n    static ggml_cl_buffer buf;\n\n    cl_event evt;\n    cl_int err;\n\n    buf.allocate(backend_ctx->context, size);\n\n    cl_mem trans;\n    cl_buffer_region region;\n\n    region.origin = 0;\n    region.size = size;\n    CL_CHECK((trans = clCreateSubBuffer(\n        buf.buffer, CL_MEM_READ_WRITE,\n        CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &src));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &trans));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &stride));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &rows));\n\n    size_t local_size[3] = {64, 1, 1};\n    size_t global_size[3] = {(size_t)stride, (size_t)rows, 1};;\n    CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL,\n        global_size, local_size, 0, NULL, NULL));\n\n    if (blocking) {\n        
CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clReleaseEvent(evt));\n    } else {\n        CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, NULL));\n    }\n\n    CL_CHECK(clReleaseMemObject(trans));\n}\n\nstatic void transpose_2d_as_8b(\n    ggml_backend_opencl_context * backend_ctx,\n    cl_mem src, cl_mem dst, size_t size,\n    cl_int stride, cl_int rows,\n    bool blocking = true\n) {\n    transpose_2d(backend_ctx, backend_ctx->kernel_transpose_8_buf,\n        src, dst, size, stride, rows, blocking);\n}\n\nstatic void transpose_2d_as_16b(\n    ggml_backend_opencl_context * backend_ctx,\n    cl_mem src, cl_mem dst, size_t size,\n    cl_int stride, cl_int rows,\n    bool blocking = true\n) {\n    transpose_2d(backend_ctx, backend_ctx->kernel_transpose_16_buf,\n        src, dst, size, stride, rows, blocking);\n}\n\nstatic void transpose_2d_as_32b(\n    ggml_backend_opencl_context * backend_ctx,\n    cl_mem src, cl_mem dst, size_t size,\n    cl_int stride, cl_int rows,\n    bool blocking = true\n) {\n    transpose_2d(backend_ctx, backend_ctx->kernel_transpose_32_buf,\n        src, dst, size, stride, rows, blocking);\n}\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n\n//------------------------------------------------------------------------------\n// Tensor extra management\n//------------------------------------------------------------------------------\nstruct ggml_tensor_extra_cl {\n    // The buffer object that holds the data.\n    cl_mem data_device;\n    // The offset into the buffer object. This is primarily for scratch buffer\n    // and view operation.\n    // NB: this offset no longer includes view offset (view_offs). Whenever this\n    // offset is used, view_offs should be considered.\n    cl_ulong offset;\n    // The actual size of the cl_mem object. 
This is needed when returning the\n    // block to the pool.\n    size_t actual_size;\n\n    void reset() {\n        data_device = nullptr;\n        offset = 0;\n        actual_size = 0;\n    }\n};\n\n// Additional tensor extra structs for quantized tensors.\n// These tensors are loaded from files and should not be allocated in scratch --\n// they should always be allocated from the pool. Hence, they do not have an\n// `offset`, which indicate their locations in the scratch buffer.\nstruct ggml_tensor_extra_cl_q4_0 {\n    // Quantized values.\n    cl_mem q = nullptr;\n    // Quantized values in image1d_buffer_t.\n    cl_mem q_img = nullptr;\n    // Scales.\n    cl_mem d = nullptr;\n    // Scales in image1d_buffer_t.\n    cl_mem d_img = nullptr;\n    // Size of quantized values.\n    size_t size_q = 0;\n    // Size of scales.\n    size_t size_d = 0;\n\n    ~ggml_tensor_extra_cl_q4_0() {\n        reset();\n    }\n\n    void reset() {\n        // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.\n        // They must be properly released so that the original buffer can be\n        // properly released to avoid memory leak.\n        if (q != nullptr) {\n            CL_CHECK(clReleaseMemObject(q));\n            q = nullptr;\n        }\n        if (d != nullptr) {\n            CL_CHECK(clReleaseMemObject(d));\n            d = nullptr;\n        }\n        // Currently, q_img and d_img are only initialized when SMALL_ALLOC is\n        // enabled. 
They point to the images in ggml_backend_opencl_buffer_context.\n        // So, there is no need to release them here.\n        // TODO: initialize them for non SMALL_PATH path, or remove them.\n        q_img = nullptr;\n        d_img = nullptr;\n        size_q = 0;\n        size_d = 0;\n    }\n};\n\nstruct ggml_tensor_extra_cl_q4_1 {\n    // Quantized values.\n    cl_mem q = nullptr;\n    // Quantized values in image1d_buffer_t.\n    cl_mem q_img = nullptr;\n    // Scales.\n    cl_mem d = nullptr;\n    // Scales in image1d_buffer_t.\n    cl_mem d_img = nullptr;\n    // Min\n    cl_mem m = nullptr;\n    // Min in image1d_buffer_t.\n    cl_mem m_img = nullptr;\n    // Size of quantized values.\n    size_t size_q = 0;\n    // Size of scales.\n    size_t size_d = 0;\n    // Size of min values.\n    size_t size_m = 0;\n\n    ~ggml_tensor_extra_cl_q4_1() {\n        reset();\n    }\n\n    void reset() {\n        // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.\n        // They must be properly released so that the original buffer can be\n        // properly released to avoid memory leak.\n        if (q != nullptr) {\n            CL_CHECK(clReleaseMemObject(q));\n            q = nullptr;\n        }\n        if (d != nullptr) {\n            CL_CHECK(clReleaseMemObject(d));\n            d = nullptr;\n        }\n        if (m != nullptr) {\n            CL_CHECK(clReleaseMemObject(m));\n            m = nullptr;\n        }\n        // Currently, q_img and d_img are only initialized when SMALL_ALLOC is\n        // enabled. 
They point to the images in ggml_backend_opencl_buffer_context.\n        // So, there is no need to release them here.\n        // TODO: initialize them for non SMALL_PATH path, or remove them.\n        q_img = nullptr;\n        d_img = nullptr;\n        m_img = nullptr;\n        size_q = 0;\n        size_d = 0;\n        size_m = 0;\n    }\n};\n\nstruct ggml_tensor_extra_cl_mxfp4 {\n    // Quantized values.\n    cl_mem q = nullptr;\n    // Quantized values in image1d_buffer_t.\n    cl_mem q_img = nullptr;\n    // Scales in E8M0.\n    cl_mem e = nullptr;\n    // Scales in image1d_buffer_t.\n    cl_mem e_img = nullptr;\n    // Size of quantized values.\n    size_t size_q = 0;\n    // Size of scales.\n    size_t size_e = 0;\n\n    ~ggml_tensor_extra_cl_mxfp4() {\n        reset();\n    }\n\n    // Releases the OpenCL memory objects held in q, e and q_img, then clears\n    // all handles and sizes so the extra can be reused.\n    void reset() {\n        // q and e are subbuffers into the bigger buffer allocated in ggml_backend_buffer.\n        // They must be properly released so that the original buffer can be\n        // properly released to avoid memory leak.\n        if (q != nullptr) {\n            CL_CHECK(clReleaseMemObject(q));\n            q = nullptr;\n        }\n        if (e != nullptr) {\n            CL_CHECK(clReleaseMemObject(e));\n            e = nullptr;\n        }\n        // Guard on q_img itself: q has already been cleared above, so testing q\n        // here would always skip this release and leak q_img.\n        if (q_img != nullptr) {\n            CL_CHECK(clReleaseMemObject(q_img));\n            q_img = nullptr;\n        }\n        // Both image handles are cleared below. 
They can be image1d_buffer_t\n        // that wraps around q and d to utilize image access path.\n        q_img = nullptr;\n        e_img = nullptr;\n        size_q = 0;\n        size_e = 0;\n    }\n};\n\nstruct ggml_tensor_extra_cl_q8_0 {\n    cl_mem q = nullptr;\n    cl_mem q_img = nullptr;\n\n    cl_mem d = nullptr;\n    cl_mem d_img = nullptr;\n\n    size_t size_q = 0;\n    size_t size_d = 0;\n\n    ~ggml_tensor_extra_cl_q8_0() {\n        reset();\n    }\n\n    void reset() {\n        // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.\n        // They must be properly released so that the original buffer can be\n        // properly released to avoid memory leak.\n        if (q != nullptr) {\n            CL_CHECK(clReleaseMemObject(q));\n            q = nullptr;\n        }\n        if (d != nullptr) {\n            CL_CHECK(clReleaseMemObject(d));\n            d = nullptr;\n        }\n        // Currently, q_img and d_img are not used. They can be image1d_buffer_t\n        // that wraps around q and d to utilize image access path.\n        q_img = nullptr;\n        d_img = nullptr;\n        size_q = 0;\n        size_d = 0;\n    }\n};\n\nstruct ggml_tensor_extra_cl_q6_K {\n    // Lower 4 bits of quantized weights.\n    cl_mem ql = nullptr;\n    // Upper 2 bits of quantized weights.\n    cl_mem qh = nullptr;\n    // Scales for each block.\n    cl_mem s  = nullptr;\n    // Scales for each super block.\n    cl_mem d  = nullptr;\n\n    size_t size_ql = 0;\n    size_t size_qh = 0;\n    size_t size_s  = 0;\n    size_t size_d  = 0;\n\n    ~ggml_tensor_extra_cl_q6_K() {\n        reset();\n    }\n\n    void reset() {\n        if (ql != nullptr) {\n            CL_CHECK(clReleaseMemObject(ql));\n            ql = nullptr;\n        }\n        if (qh != nullptr) {\n            CL_CHECK(clReleaseMemObject(qh));\n            qh = nullptr;\n        }\n        if (s != nullptr) {\n            CL_CHECK(clReleaseMemObject(s));\n            s = 
nullptr;\n        }\n        if (d != nullptr) {\n            CL_CHECK(clReleaseMemObject(d));\n            d = nullptr;\n        }\n\n        size_ql = 0;\n        size_qh = 0;\n        size_s  = 0;\n        size_d  = 0;\n    }\n};\n\n//------------------------------------------------------------------------------\n// Backend API\n//------------------------------------------------------------------------------\n\n//\n// backend\n//\nstatic const char * ggml_backend_opencl_name(ggml_backend_t backend) {\n    return \"OpenCL\";\n\n    UNUSED(backend);\n}\n\nstatic void ggml_backend_opencl_free(ggml_backend_t backend) {\n    ggml_cl2_free(backend);\n}\n\nstatic void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    GGML_UNUSED(backend);\n    GGML_UNUSED(tensor);\n    GGML_UNUSED(data);\n    GGML_UNUSED(offset);\n    GGML_UNUSED(size);\n}\n\nstatic void ggml_backend_opencl_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    GGML_UNUSED(backend);\n    GGML_UNUSED(tensor);\n    GGML_UNUSED(data);\n    GGML_UNUSED(offset);\n    GGML_UNUSED(size);\n}\n\nstatic bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {\n    GGML_UNUSED(backend);\n    GGML_UNUSED(src);\n    GGML_UNUSED(dst);\n    return false;\n}\n\nstatic void ggml_backend_opencl_synchronize(ggml_backend_t backend) {\n    auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context);\n\n    cl_event evt;\n    CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, 0, nullptr, &evt));\n    CL_CHECK(clWaitForEvents(1, &evt));\n    CL_CHECK(clReleaseEvent(evt));\n}\n\n// Synchronizes the 'backend_ctx's device with others so that commands\n// enqueued to it won't start until commands in the other devices have\n// completed.\nstatic void 
sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) {\n    if (g_ggml_backend_opencl_devices.size() < 2)\n      return; // No other devices to synchronize with.\n\n    std::vector<cl_event> events;\n    events.reserve(g_ggml_backend_opencl_devices.size());\n\n    for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) {\n        auto * other_backend_ctx = ggml_cl2_init(&backend_dev);\n        if (backend_ctx != other_backend_ctx) {\n            cl_event ev;\n            CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev));\n            CL_CHECK(clFlush(other_backend_ctx->queue));\n            events.push_back(ev);\n        }\n    }\n\n    CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, events.size(), events.data(), nullptr));\n    for (auto ev : events) {\n        CL_CHECK(clReleaseEvent(ev));\n    }\n}\n\nstatic void sync_with_other_backends(ggml_backend_t backend) {\n    auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context);\n    sync_with_other_backends(backend_ctx);\n}\n\nstatic bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {\n    if (!ggml_can_fuse(cgraph, node_idx, ops)) {\n        return false;\n    }\n\n    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {\n        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];\n        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];\n\n        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);\n        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);\n\n        // rms_norm only supports f32\n        if (mul->src[0]->type != GGML_TYPE_F32 ||\n            mul->src[1]->type != GGML_TYPE_F32 ||\n            mul->type != GGML_TYPE_F32) {\n            return false;\n        }\n\n        // if rms_norm is the B operand, then we don't handle broadcast\n        if (rms_norm == mul->src[1] &&\n            
!ggml_are_same_shape(mul->src[0], rms_norm)) {\n            return false;\n        }\n\n        // rms_norm assumes contiguous rows\n        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {\n            return false;\n        }\n    } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {\n        const ggml_tensor *norm = cgraph->nodes[node_idx];\n        const ggml_tensor *mul  = cgraph->nodes[node_idx+1];\n        const ggml_tensor *add  = cgraph->nodes[node_idx+2];\n        const ggml_tensor *w    = mul->src[0] == norm ? mul->src[1] : mul->src[0];\n        const ggml_tensor *b    = add->src[0] == mul  ? add->src[1] : add->src[0];\n\n        // norm fusion only supports F32\n        if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {\n            return false;\n        }\n\n        if (norm->src[0]->ne[0] % 4 != 0) {\n            return false;\n        }\n\n        if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {\n            return false;\n        }\n    } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {\n        const ggml_tensor *gn = cgraph->nodes[node_idx];\n        const ggml_tensor *mul = cgraph->nodes[node_idx+1];\n        const ggml_tensor *add = cgraph->nodes[node_idx+2];\n        const ggml_tensor *w   = mul->src[0] == gn ? mul->src[1] : mul->src[0];\n        const ggml_tensor *b   = add->src[0] == mul ? 
add->src[1] : add->src[0];\n\n        if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {\n            return false;\n        }\n\n        if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);\nstatic void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);\nstatic void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);\n\nstatic ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        ggml_tensor * node = cgraph->nodes[i];\n\n        // NOTE: this may oversynchronize by synchronizing with\n        //       backends/devices which don't compute 'cgraph's\n        //       dependencies.\n        sync_with_other_backends(backend);\n\n        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {\n            continue;\n        }\n\n        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n            continue;\n        }\n\n        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {\n            ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);\n            i += 2;\n            continue;\n        }\n        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {\n    
        ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);\n            i += 2;\n            continue;\n        }\n        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {\n            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);\n            i++;\n            continue;\n        }\n\n        bool ok = ggml_cl_compute_forward(backend, node);\n        if (!ok) {\n            GGML_LOG_ERROR(\"%s: error: op not supported %s (%s)\\n\", __func__, node->name, ggml_op_name(node->op));\n        }\n        GGML_ASSERT(ok);\n    }\n\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {\n    ggml_backend_opencl_device_context * dev_ctx     = (ggml_backend_opencl_device_context *)dev->context;\n    ggml_backend_opencl_context *        backend_ctx = dev_ctx->backend_ctx;\n\n    switch (op->op) {\n        case GGML_OP_NONE:\n            return true;\n        case GGML_OP_GET_ROWS:\n            switch (op->src[0]->type) {\n                case GGML_TYPE_F32:\n                case GGML_TYPE_F16:\n                    return true;\n                case GGML_TYPE_Q4_0:\n#ifdef GGML_OPENCL_SOA_Q\n                    // We do not support flattened Q4_0 (and possibly other Q's)\n                    return false;\n#else // GGML_OPENCL_SOA_Q\n                    return true;\n#endif // GGML_OPENCL_SOA_Q\n                default:\n                    return false;\n            }\n        case GGML_OP_SET_ROWS:\n            {\n                // TODO: add support\n                // ref: https://github.com/ggml-org/llama.cpp/pull/14274\n#pragma message(\"TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)\")\n                if (op->src[0]->type != GGML_TYPE_F32) {\n                    return false;\n                }\n                
switch (op->type) {\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_F32:\n                        return (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);\n                    default:\n                        return false;\n                }\n            }\n        case GGML_OP_CPY:\n        case GGML_OP_DUP:\n        case GGML_OP_CONT:\n            switch (op->src[0]->type) {\n                case GGML_TYPE_F32:\n                    switch (op->type) {\n                        case GGML_TYPE_F16:\n                        case GGML_TYPE_F32:\n                            return true;\n                        default:\n                            return false;\n                    }\n                case GGML_TYPE_F16:\n                    switch (op->type) {\n                        case GGML_TYPE_F16:\n                        case GGML_TYPE_F32:\n                            return true;\n                        default:\n                            return false;\n                    }\n                case GGML_TYPE_I32:\n                    switch (op->type) {\n                        case GGML_TYPE_I32:\n                            return true;\n                        default:\n                            return false;\n                    }\n                default:\n                    return false;\n            }\n        case GGML_OP_SET: {\n            return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) &&\n                    op->type == op->src[0]->type &&\n                    op->type == op->src[1]->type;\n        }\n        case GGML_OP_SCALE:\n            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);\n        case GGML_OP_ADD:\n            if (op->type == GGML_TYPE_F16) {\n                const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;\n                const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || 
op->src[1]->type == GGML_TYPE_F32;\n                if (src0_ok && src1_ok) {\n                    return true;\n                }\n            }\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n        case GGML_OP_SUB:\n            return (op->src[0]->type == op->src[1]->type) &&\n                   (op->src[0]->type == op->type) &&\n                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);\n        case GGML_OP_ADD_ID:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&\n                    ggml_is_contiguous(op->src[0]);\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(op)) {\n                case GGML_UNARY_OP_GELU:\n                case GGML_UNARY_OP_SILU:\n                case GGML_UNARY_OP_RELU:\n                case GGML_UNARY_OP_GELU_ERF:\n                case GGML_UNARY_OP_GELU_QUICK:\n                   return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n                case GGML_UNARY_OP_SIGMOID:\n                    return ggml_is_contiguous(op->src[0]);\n                case GGML_UNARY_OP_TANH:\n                case GGML_UNARY_OP_NEG:\n                case GGML_UNARY_OP_EXP:\n                   return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;\n                case GGML_UNARY_OP_EXPM1:\n                   return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;\n                case GGML_UNARY_OP_SOFTPLUS:\n                   return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;\n                default:\n                    return false;\n            }\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(op)) {\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_REGLU:\n                case 
GGML_GLU_OP_SWIGLU:\n                case GGML_GLU_OP_SWIGLU_OAI:\n                case GGML_GLU_OP_GEGLU_ERF:\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);\n                default:\n                    return false;\n            }\n        case GGML_OP_TRI:\n            return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);\n        case GGML_OP_FILL:\n            return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);\n        case GGML_OP_CLAMP:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_SOFT_MAX:\n        case GGML_OP_NORM:\n            return true;\n        case GGML_OP_RMS_NORM:\n            return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);\n        case GGML_OP_L2_NORM:\n            return ggml_is_contiguous_rows(op->src[0]);\n        case GGML_OP_REPEAT:\n            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded\n        case GGML_OP_PAD:\n            // TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985\n            if (ggml_get_op_params_i32(op, 8) != 0) {\n                return false;\n            }\n            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;\n        case GGML_OP_UPSCALE: {\n            ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);\n            const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS);\n            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&\n                   (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias;\n        }\n        case GGML_OP_CONV_2D:\n            return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == 
GGML_TYPE_F16) ||\n                   (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||\n                   (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);\n        case GGML_OP_SSM_CONV:\n            return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);\n        case GGML_OP_CONCAT:\n            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;\n        case GGML_OP_TIMESTEP_EMBEDDING:\n            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;\n        case GGML_OP_GROUP_NORM:\n            return ggml_is_contiguous(op->src[0]);\n        case GGML_OP_MUL_MAT:\n            if (op->src[0]->type == GGML_TYPE_F16) {\n                return true;\n            } else if (op->src[0]->type == GGML_TYPE_F32) {\n                return op->src[1]->type == GGML_TYPE_F32;\n            } else if (op->src[0]->type == GGML_TYPE_Q4_0  || op->src[0]->type == GGML_TYPE_Q4_1 ||\n                       op->src[0]->type == GGML_TYPE_MXFP4 ||\n                       op->src[0]->type == GGML_TYPE_Q4_K  ||\n                       op->src[0]->type == GGML_TYPE_Q6_K) {\n                return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);\n            } else if (op->src[0]->type == GGML_TYPE_Q8_0) {\n                return op->src[1]->type == GGML_TYPE_F32;\n            }\n            return false;\n        case GGML_OP_MUL_MAT_ID:\n            if (op->src[0]->type == GGML_TYPE_Q4_0 ||\n                op->src[0]->type == GGML_TYPE_Q8_0 ||\n                op->src[0]->type == GGML_TYPE_MXFP4) {\n                if (op->src[1]->type == GGML_TYPE_F32) {\n                    return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);\n                }\n            }\n            
return false;\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n            return true;\n        case GGML_OP_DIAG:\n            return true;\n        case GGML_OP_DIAG_MASK_INF:\n            return op->ne[3] == 1;\n        case GGML_OP_ROPE: {\n            const int mode = ((const int32_t *) op->op_params)[2];\n            const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;\n            const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n            if (is_mrope && !is_vision) {\n                if (op->src[0]->type == GGML_TYPE_F32 ||\n                    op->src[0]->type == GGML_TYPE_F16) {\n                    return true;\n                }\n                return false;\n            }\n            if (is_vision) {\n                if (op->src[0]->type == GGML_TYPE_F32 ||\n                    op->src[0]->type == GGML_TYPE_F16) {\n                    return true;\n                }\n                return false;\n            }\n            return true;\n        }\n        case GGML_OP_SOLVE_TRI:\n            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);\n        case GGML_OP_IM2COL:\n            return true;\n        case GGML_OP_ARGSORT: {\n            cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32;\n            int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);\n\n            int cols = 1;\n            while (cols < op->ne[0]) {\n                cols *= 2;\n            }\n\n            return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;\n        }\n        case GGML_OP_SUM_ROWS:\n        case GGML_OP_CUMSUM:\n            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);\n        case GGML_OP_MEAN:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_FLASH_ATTN_EXT:\n            {\n                const ggml_tensor * q = op->src[0];\n                const 
ggml_tensor * k = op->src[1];\n                const ggml_tensor * v = op->src[2];\n\n                const int dk = q->ne[0];\n                const int dv = v->ne[0];\n\n                const struct { int dk; int dv; } supported_dims[] = {\n                    { 40,  40}, { 64,  64}, { 80,  80}, { 96,  96},\n                    {112, 112}, {128, 128}, {192, 128},\n                    {192, 192}, {256, 256},\n                };\n\n                bool dims_supported = false;\n                for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {\n                    if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {\n                        dims_supported = true;\n                        break;\n                    }\n                }\n                if (!dims_supported) {\n                    return false;\n                }\n\n                const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&\n                                        v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;\n                const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&\n                                        v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;\n                const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&\n                                        v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;\n\n                return is_f32_f32 || is_f16_f16 || is_f32_f16;\n            }\n        default:\n            return false;\n    }\n}\n\n// Forward declaration - implementation appears later in the file.\nstatic const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type);\n\nstatic ggml_guid_t ggml_backend_opencl_guid() {\n    static ggml_guid guid = { 0xde, 0xe0, 0x70, 0xa2, 0x73, 0x4e, 0x4d, 0xbc, 0xb0, 0xc7, 0x4f, 0xd4, 0x6d, 0x4e, 0x90, 0xfe };\n    return &guid;\n}\n\nstatic ggml_backend_i 
ggml_backend_opencl_i = {\n    /* .get_name                = */ ggml_backend_opencl_name,\n    /* .free                    = */ ggml_backend_opencl_free,\n    /* .set_tensor_async        = */ NULL,  /* ggml_backend_opencl_set_tensor_async */\n    /* .get_tensor_async        = */ NULL,  /* ggml_backend_opencl_get_tensor_async */\n    /* .cpy_tensor_async        = */ NULL,  /* ggml_backend_opencl_cpy_tensor_async */\n    /* .synchronize             = */ ggml_backend_opencl_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_opencl_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ NULL,\n};\n\nggml_backend_t ggml_backend_opencl_init(void) {\n    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0);\n    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);\n\n    ggml_backend_t backend = new ggml_backend {\n        /* .guid    = */ ggml_backend_opencl_guid(),\n        /* .iface   = */ ggml_backend_opencl_i,\n        /* .device  = */ dev,\n        /* .context = */ backend_ctx\n    };\n\n    return backend;\n}\n\nbool ggml_backend_is_opencl(ggml_backend_t backend) {\n    return backend && backend->iface.get_name == ggml_backend_opencl_name;\n}\n\n//\n// buffer\n//\nstruct ggml_backend_opencl_buffer_context {\n    // A buffer context can hold multiple cl_mem objects. This is for flattening\n    // quantized weights and should be used with GGML_OPENCL_SMALL_ALLOC where\n    // each tensor is allocated a separate buffer. 
When flattening is enabled\n    // with small allocation, each tensor is backed by two cl_mem objects (for\n    // quants and scales) packed into a backend_opencl_buffer.\n    ggml_backend_opencl_buffer_context(cl_mem buf)\n        : name(\"OpenCL\") {\n        buffer.push_back(buf);\n    }\n\n    ~ggml_backend_opencl_buffer_context() {\n        for (cl_mem buf : buffer) {\n            CL_CHECK(clReleaseMemObject(buf));\n        }\n        for (cl_mem im : img) {\n            CL_CHECK(clReleaseMemObject(im));\n        }\n\n        // Delete all extras to trigger their destructors\n        for (ggml_tensor_extra_cl * e : temp_tensor_extras) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) {\n            delete e;\n        }\n        for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {\n            delete e;\n        }\n    }\n\n    ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {\n        ggml_tensor_extra_cl * extra;\n        if (temp_tensor_extras.empty()) {\n            extra = new ggml_tensor_extra_cl();\n        } else {\n            extra = temp_tensor_extras.back();\n            temp_tensor_extras.pop_back();\n     
   }\n\n        temp_tensor_extras_in_use.push_back(extra);\n\n        extra->reset();\n        return extra;\n    }\n\n    ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() {\n        ggml_tensor_extra_cl_q4_0 * extra;\n        if (temp_tensor_extras_q4_0.empty()) {\n            extra = new ggml_tensor_extra_cl_q4_0();\n        } else {\n            extra = temp_tensor_extras_q4_0.back();\n            temp_tensor_extras_q4_0.pop_back();\n        }\n\n        temp_tensor_extras_q4_0_in_use.push_back(extra);\n\n        extra->reset();\n        return extra;\n    }\n\n    ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() {\n        ggml_tensor_extra_cl_q4_1 * extra;\n        if (temp_tensor_extras_q4_1.empty()) {\n            extra = new ggml_tensor_extra_cl_q4_1();\n        } else {\n            extra = temp_tensor_extras_q4_1.back();\n            temp_tensor_extras_q4_1.pop_back();\n        }\n\n        temp_tensor_extras_q4_1_in_use.push_back(extra);\n\n        extra->reset();\n        return extra;\n    }\n\n    ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() {\n        ggml_tensor_extra_cl_mxfp4 * extra;\n        if (temp_tensor_extras_mxfp4.empty()) {\n            extra = new ggml_tensor_extra_cl_mxfp4();\n        } else {\n            extra = temp_tensor_extras_mxfp4.back();\n            temp_tensor_extras_mxfp4.pop_back();\n        }\n\n        temp_tensor_extras_mxfp4_in_use.push_back(extra);\n\n        extra->reset();\n        return extra;\n    }\n\n    ggml_tensor_extra_cl_q8_0 * ggml_opencl_alloc_temp_tensor_extra_q8_0() {\n        ggml_tensor_extra_cl_q8_0 * extra;\n        if (temp_tensor_extras_q8_0.empty()) {\n            extra = new ggml_tensor_extra_cl_q8_0();\n        } else {\n            extra = temp_tensor_extras_q8_0.back();\n            temp_tensor_extras_q8_0.pop_back();\n        }\n\n        temp_tensor_extras_q8_0_in_use.push_back(extra);\n\n        extra->reset();\n        
return extra;\n    }\n\n    ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() {\n        ggml_tensor_extra_cl_q6_K * extra;\n        if (temp_tensor_extras_q6_K.empty()) {\n            extra = new ggml_tensor_extra_cl_q6_K();\n        } else {\n            extra = temp_tensor_extras_q6_K.back();\n            temp_tensor_extras_q6_K.pop_back();\n        }\n\n        temp_tensor_extras_q6_K_in_use.push_back(extra);\n\n        extra->reset();\n        return extra;\n    }\n\n    void reset() {\n        for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {\n            temp_tensor_extras.push_back(e);\n        }\n        temp_tensor_extras_in_use.clear();\n\n        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {\n            temp_tensor_extras_q4_0.push_back(e);\n        }\n        temp_tensor_extras_q4_0_in_use.clear();\n\n        for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) {\n            temp_tensor_extras_q4_1.push_back(e);\n        }\n        temp_tensor_extras_q4_1_in_use.clear();\n\n        for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {\n            temp_tensor_extras_mxfp4.push_back(e);\n        }\n        temp_tensor_extras_mxfp4_in_use.clear();\n\n        for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {\n            temp_tensor_extras_q8_0.push_back(e);\n        }\n        temp_tensor_extras_q8_0_in_use.clear();\n\n        for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {\n            temp_tensor_extras_q6_K.push_back(e);\n        }\n        temp_tensor_extras_q6_K_in_use.clear();\n    }\n\n    // Pools for extras. Available extras are in `temp_tensor_extras`. Extras\n    // being used are in `temp_tensor_extras_in_use`. At the first run, new\n    // extras get created and put in `in_use`. 
When the buffer is reset via\n    // the `reset` callback, all extras in `in_use` get moved to available extras\n    // for reuse.\n    std::vector<ggml_tensor_extra_cl *> temp_tensor_extras;\n    std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;\n    std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;\n    std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;\n    std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1;\n    std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1_in_use;\n    std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4;\n    std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;\n    std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;\n    std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;\n    std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K;\n    std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K_in_use;\n\n    // The buffer_context is initially created by ggml_backend_buft_alloc_buffer\n    // before any tensor is initialized (at the beginning of alloc_tensor_range).\n    // Hence, there is always a buffer object in this vector. When each tensor is\n    // being initialized, this original buffer object will be released if both\n    // flattening and small allocation are enabled, and additional buffer\n    // objects will be created in init_tensor to represent flattened quantized\n    // weights.\n    std::vector<cl_mem> buffer;\n    // These are image1d_buffer_t objects that wrap around the quants and scales.\n    // For Q4_0 quantization, there should be two of them - one for quants and\n    // one for scales. 
They should be populated only when flattening and small\n    // allocation are enabled.\n    std::vector<cl_mem> img;\n    std::string name;\n};\n\nstatic void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n    delete ctx;\n}\n\nstatic void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {\n    ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer->buft->device);\n    return (void *) (uintptr_t) backend_ctx->alignment;\n}\n\nstatic enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n\n    ggml_cl2_init(buffer->buft->device);\n\n    if (tensor->view_src != nullptr) {\n        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);\n\n        ggml_tensor_extra_cl * view_extra = (ggml_tensor_extra_cl *) tensor->view_src->extra;\n        GGML_ASSERT(view_extra && \"view_extra is nullptr?\");\n\n        // Reuse extra of the parent tensor. The offset of this view tensor\n        // becomes `extra->offset + view_offs` and needs to be calculated when\n        // it is used. This changes is needed because of the change to\n        // ggml_alloc.c in https://github.com/ggml-org/llama.cpp/pull/7640.\n        // `buffer` passed in here will always be `tensor->buffer`. It is OK\n        // to allocate extras from the same buffer context for ordinary\n        // intermediate tensors. 
But for views into kv cache tensors, doing so\n        // would mess up the extras used by kv cache.\n        // Before #7640, `buffer` is for intermediate tensors, which is always\n        // different from that of kv cache tensors.\n        //\n        // NB: now extra->offset no longer accounts for view_offs.\n        // NB: this should not apply to weight tensors (for end-to-end runs, but\n        //     may apply for test-backend-ops).\n        // FIXME: if any unexpected results are seen, double check the offset -\n        // there could be other places that need fix.\n        tensor->extra = view_extra;\n    } else {\n        {\n            size_t offset = (char *) tensor->data - (char *) ggml_backend_opencl_buffer_get_base(buffer);\n\n            ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra();\n            extra->offset = offset;\n            extra->data_device = ctx->buffer[0];\n            extra->actual_size = ggml_nbytes(tensor);\n\n            tensor->extra = extra;\n        }\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\n// The optimized gemm and gemv kernels are used for large matrices without batch.\n// tensor is the quantized weights matrix.\ninline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {\n    int64_t threshold_ne0 = 512;\n    int64_t threshold_ne1 = 512;\n    if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) &&\n         backend_ctx->adreno_cl_compiler_version.type != DX) {\n        threshold_ne0 = 128;\n        threshold_ne1 = 128;\n    }\n    return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 &&\n            tensor->ne[2] == 1 && tensor->ne[3] == 1;\n}\n\ninline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {\n    GGML_UNUSED(backend_ctx);\n    int ne01 = tensor->ne[1];\n    return ((strstr(tensor->name, \"ffn\") != NULL) || (strstr(tensor->name, 
\"as\") != NULL)) && (ne01 % 64 == 0);\n}\n\ninline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {\n\n    bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor);\n\n    size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];\n\n    return ((elem_num < 128 * 1024 * 1024) && adreno_kernel);  // max element num: 2**27\n}\n\nstatic void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);\n\n    cl_context context = backend_ctx->context;\n    cl_command_queue queue = backend_ctx->queue;\n\n#ifdef GGML_OPENCL_SOA_Q\n    // We separate the quantized bits and scale from block_q4_0 by using an\n    // additional kernel, where each thread handles a block. We first read the\n    // original weights into a temporary buffer, then create two separate\n    // buffers for quantized bits and scales, which are then populated by the\n    // conversion kernel.\n    if (tensor->type == GGML_TYPE_Q4_0) {\n        // Tensors should have been preallocated, therefore they should\n        // already have ggml_tensor_extra_cl as extra.\n        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;\n        GGML_ASSERT(extra_orig && \"Tesnors in OpenCL backend should have been allocated and initialized\");\n\n        // Allocate the new extra and create aliases from the original.\n        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n        ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_0();\n\n        size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);\n        size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;\n        GGML_ASSERT(size_d + 
size_q == ggml_nbytes(tensor) && \"Incorrect tensor size\");\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n        CL_CHECK(clEnqueueWriteBuffer(\n            queue, data_device, CL_TRUE, 0,\n            ggml_nbytes(tensor), data, 0, NULL, NULL));\n\n        // We consider the specified offset arg as always, although For weights\n        // the offset arg should be 0 (we do not assert this).\n        //GGML_ASSERT(offset == 0);\n\n        // We create subbuffers from the original tensor buffer for scales and\n        // quants - i.e., scales and quants are aliases into the buffer object\n        // that backs the original tensor. This is a cleaner way to adapt to the\n        // new memory management.\n        // In the old code, we allocate new buffers for scales and quants\n        // respectively, which could still be done but would result in double\n        // allocation; properly deallocating the preallocated buffer that backs\n        // the tensors is tricky and would leak the backend specific information\n        // into the general backend code.\n        // Does this create misaligned subbuffers (alignment is 1024) in certain\n        // cases ?\n        cl_buffer_region region;\n\n        // The original tensor memory is divided into scales and quants, i.e.,\n        // we first store scales, then quants.\n        // Create subbuffer for scales.\n        region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);\n        region.size = size_d;\n        extra->d = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        auto previous_origin = region.origin;\n\n        // Create subbuffer for quants.\n        region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);\n      
  region.size = size_q;\n        extra->q = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n\n        //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;\n    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;\n\n        // The optimized kernels need weights in natural order, so unshuffle.\n        if (use_adreno_kernels(backend_ctx, tensor)) {\n            kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle;\n        }\n    #else\n        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;\n    #endif // GGML_OPENCL_USE_ADRENO_KERNELS\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clReleaseMemObject(data_device));\n\n        tensor->extra = extra;\n\n        // transpose the weights and scales\n    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        // Only do transpose for large, non batched matrix\n        // TODO: use preallocated images instead of sub-buffer then image\n        if (use_adreno_kernels(backend_ctx, tensor)) {\n        // <----------------------------------------------------------------------------------> //\n        // start transpose\n        // <----------------------------------------------------------------------------------> //\n        int M = tensor->ne[1];   // ne01\n        int K = tensor->ne[0];   // ne00\n\n        //For 
matrix-vector multiplication kernel, we assume K is a multiple of 32\n        GGML_ASSERT(K % 32 == 0);\n        //For transpose kernels, we assume K is a multiple of 4 (satisfied by prior assert), and M is a multiple of 4\n        GGML_ASSERT(M % 4 == 0);\n\n        // transpose is out of place, so we need to allocate transposed buffers\n        // <----------------------------------------------------------------------------------> //\n        // use sub_buffer of max buffer size instead\n\n        size_t q_size_bytes = K * M / 8 * sizeof(float);\n        backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes);\n\n        cl_buffer_region region;\n        region.origin = 0;\n        region.size = q_size_bytes;\n        cl_mem qT_d = clCreateSubBuffer(\n            backend_ctx->prealloc_quant_trans.buffer,\n            0,\n            CL_BUFFER_CREATE_TYPE_REGION,\n            &region,\n            &err);\n        CL_CHECK(err);\n\n        bool K_tile_trans = true;\n        if ((K / 32) % 4 != 0){\n            K_tile_trans =false;\n        }\n\n        size_t d_size_bytes = M * (K / 32) * 2;\n        backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes);\n\n        region.origin = 0;\n        region.size = d_size_bytes;\n        cl_mem dT_d = clCreateSubBuffer(\n            backend_ctx->prealloc_scales_trans.buffer,\n            0,\n            CL_BUFFER_CREATE_TYPE_REGION,\n            &region,\n            &err);\n        CL_CHECK(err);\n\n        // <----------------------------------------------------------------------------------> //\n\n\n        // create images from the buffers\n        // <----------------------------------------------------------------------------------> //\n        cl_mem q_d_image1D;\n        cl_mem d_d_image1D;\n        cl_mem qT_d_image1D;\n        cl_mem dT_d_image1D;\n\n        cl_image_format img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };\n        cl_image_desc img_desc_1d;\n\n        memset(&img_desc_1d, 0, 
sizeof(img_desc_1d));\n        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n        img_desc_1d.image_width = M * K / 4 / 4;\n        img_desc_1d.buffer = extra->q;\n        q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);\n        CL_CHECK(err);\n\n        img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };\n        memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n        img_desc_1d.image_width = M * K / 4 / 4;\n        img_desc_1d.buffer = qT_d;\n        qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);\n        CL_CHECK(err);\n\n        memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n        if (K_tile_trans) {\n            img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };\n            img_desc_1d.image_width = M * K / 32 / 4;\n        } else {\n            img_fmt_1d = { CL_R, CL_HALF_FLOAT };\n            img_desc_1d.image_width = M * K / 32;\n        }\n        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n        img_desc_1d.buffer = extra->d;\n        d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);\n        CL_CHECK(err);\n\n        img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };\n        memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n        img_desc_1d.image_width = M * K / 32 / 4;\n        img_desc_1d.buffer = dT_d;\n        dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);\n        CL_CHECK(err);\n        // <----------------------------------------------------------------------------------> //\n\n        // set up and call the transpose kernels\n        // <----------------------------------------------------------------------------------> //\n        // weights\n        int height_q = M / 4;\n        int width_q = K / 4 / 4;\n        kernel = backend_ctx->kernel_transpose_16;\n\n        
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_q));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_q));\n\n        size_t local_size_q[3] = {4, 16, 1};\n        size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1};\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n\n        // scales\n        int height_s = M / 4;\n        int width_s = K / 32 / 4;\n\n        kernel = backend_ctx->kernel_transpose_16;\n        if (!K_tile_trans) {\n            kernel = backend_ctx->kernel_transpose_16_4x1;\n            width_s = K / 32;\n        }\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s));\n\n        size_t local_size_s[3] = {4, 16, 1};\n        size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1};\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        // <----------------------------------------------------------------------------------> //\n\n        // copy transposed buffer contents to original buffers\n        // <----------------------------------------------------------------------------------> //\n        // weights\n        CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n\n        // scales\n        CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, 
NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        // <----------------------------------------------------------------------------------> //\n\n        // deallocate transpose buffers\n        // <----------------------------------------------------------------------------------> //\n        CL_CHECK(clReleaseMemObject(qT_d));\n        CL_CHECK(clReleaseMemObject(dT_d));\n\n        // deallocate temporary images\n        CL_CHECK(clReleaseMemObject(q_d_image1D));\n        CL_CHECK(clReleaseMemObject(d_d_image1D));\n        CL_CHECK(clReleaseMemObject(qT_d_image1D));\n        CL_CHECK(clReleaseMemObject(dT_d_image1D));\n        // <----------------------------------------------------------------------------------> //\n        // end transpose\n        // <----------------------------------------------------------------------------------> //\n        }\n    #endif // GGML_OPENCL_USE_ADRENO_KERNELS\n\n        return;\n\n    }\n    if (tensor->type == GGML_TYPE_Q4_1) {\n        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;\n        GGML_ASSERT(extra_orig && \"Tesnors in OpenCL backend should have been allocated and initialized\");\n\n        // Allocate the new extra and create aliases from the original.\n        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n        ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1();\n\n        size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);\n        size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);\n        size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;\n        GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && \"Incorrect tensor size\");\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            
ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n        CL_CHECK(clEnqueueWriteBuffer(\n            queue, data_device, CL_TRUE, 0,\n            ggml_nbytes(tensor), data, 0, NULL, NULL));\n\n        cl_buffer_region region;\n\n        // The original tensor memory is divided into scales and quants, i.e.,\n        // we first store scales, mins, then quants.\n        // Create subbuffer for scales.\n        region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);\n        region.size = size_d;\n        extra->d = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        auto previous_origin = region.origin;\n\n        // Create subbuffer for mins.\n        region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);\n        region.size = size_m;\n        extra->m = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        previous_origin = region.origin;\n\n        // Create subbuffer for quants.\n        region.origin = align_to(previous_origin + size_m, backend_ctx->alignment);\n        region.size = size_q;\n        extra->q = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n\n    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;\n\n        if (use_adreno_kernels(backend_ctx, tensor)) {\n            kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle;\n        }\n    #else\n        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;\n    #endif // GGML_OPENCL_USE_ADRENO_KERNELS\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, 
sizeof(cl_mem), &extra->q));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clReleaseMemObject(data_device));\n\n        tensor->extra = extra;\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        if (use_adreno_kernels(backend_ctx, tensor)) {\n\n            int M = tensor->ne[1];\n            int K = tensor->ne[0];\n\n            GGML_ASSERT(K % 32 == 0);\n\n            // Transpose q as ushort\n            transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M);\n            // Transpose d as ushort\n            transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M);\n            // Transpose m as ushort\n            transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M);\n        }\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n        return;\n    }\n    if (tensor->type == GGML_TYPE_MXFP4) {\n        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;\n        GGML_ASSERT(extra_orig && \"Tesnors in OpenCL backend should have been allocated and initialized\");\n\n        // Allocate the new extra and create aliases from the original.\n        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n        ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4();\n\n        size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char);\n        size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;\n        
GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && \"Incorrect tensor size\");\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n        CL_CHECK(clEnqueueWriteBuffer(\n            queue, data_device, CL_TRUE, 0,\n            ggml_nbytes(tensor), data, 0, NULL, NULL));\n\n        // The original tensor memory is divided into scales and quants, i.e.,\n        // we first store scales, then quants.\n        cl_buffer_region region;\n\n        // Create subbuffer for scales.\n        region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);\n        region.size = size_e;\n        extra->e = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        auto previous_origin = region.origin;\n\n        // Create subbuffer for quants.\n        region.origin = align_to(previous_origin + size_e, backend_ctx->alignment);\n        region.size = size_q;\n        extra->q = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        if (use_adreno_moe_kernels(backend_ctx, tensor)) {\n            cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;\n\n            int ne00 = tensor->ne[0];\n            int ne01 = tensor->ne[1];\n            int ne02 = tensor->ne[2];\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));\n\n            size_t 
global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};\n            size_t local_work_size[3] = {64, 2, 1};\n\n            cl_event evt;\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));\n            CL_CHECK(clWaitForEvents(1, &evt));\n            CL_CHECK(clReleaseMemObject(data_device));\n            tensor->extra = extra;\n\n            return;\n        }\n#endif\n        cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));\n\n        size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[3] = {64, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clReleaseMemObject(data_device));\n\n        // Create image for Q\n        cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32};\n        cl_image_desc img_desc_q = {\n            CL_MEM_OBJECT_IMAGE1D_BUFFER,\n            static_cast<size_t>(ggml_nelements(tensor)/32*2),\n            0, 0, 0, 0, 0, 0, 0,\n            { extra->q }\n        };\n        extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);\n        tensor->extra = extra;\n\n        return;\n    }\n    if (tensor->type == GGML_TYPE_Q8_0) {\n        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;\n        GGML_ASSERT(extra_orig && \"Tesnors in OpenCL backend should have been allocated and initialized\");\n\n        // Allocate the new extra and create aliases from the original.\n 
       ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n        ggml_tensor_extra_cl_q8_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q8_0();\n\n        size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);\n        size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)*sizeof(char));\n        GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && \"Incorrect tensor size\");\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n        CL_CHECK(clEnqueueWriteBuffer(\n            queue, data_device, CL_TRUE, 0,\n            ggml_nbytes(tensor), data, 0, NULL, NULL));\n\n        // The original tensor memory is divided into scales and quants, i.e.,\n        // we first store scales, then quants.\n        cl_buffer_region region;\n\n        // Create subbuffer for scales.\n        region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);\n        region.size = size_d;\n        extra->d = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        auto previous_origin = region.origin;\n\n        // Create subbuffer for quants.\n        region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);\n        region.size = size_q;\n        extra->q = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n\n        cl_kernel kernel = backend_ctx->kernel_convert_block_q8_0;\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));\n        CL_CHECK(clSetKernelArg(kernel, 2, 
sizeof(cl_mem), &extra->d));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clReleaseMemObject(data_device));\n\n        tensor->extra = extra;\n\n        // Transpose the weights and scales\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        if (enable_adreno_trans_weight(backend_ctx, tensor)) {\n\n            int M = tensor->ne[1];   // ne01\n            int K = tensor->ne[0];   // ne00\n\n            GGML_ASSERT(K % 32 == 0);\n            GGML_ASSERT(M % 4 == 0);\n            GGML_ASSERT(tensor->ne[2] == 1);\n            GGML_ASSERT(tensor->ne[3] == 1);\n\n            // Transpose weights\n            size_t q_size_bytes = K * M / 4 * sizeof(float);\n            cl_buffer_region region;\n            region.origin = 0;\n            region.size = q_size_bytes;\n            cl_mem qT_d = clCreateSubBuffer(\n                backend_ctx->prealloc_quant_trans.buffer,\n                0,\n                CL_BUFFER_CREATE_TYPE_REGION,\n                &region,\n                &err);\n            CL_CHECK(err);\n\n            cl_mem q_d_image1D;\n            cl_mem qT_d_image1D;\n\n            cl_image_format img_fmt_1d;\n            cl_image_desc img_desc_1d;\n\n            img_fmt_1d = { CL_RGBA, CL_FLOAT };\n            memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n            img_desc_1d.image_width = M * K / 4 / 4;\n            img_desc_1d.buffer = extra->q;\n            q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);\n            CL_CHECK(err);\n\n            img_fmt_1d = { CL_RGBA, CL_FLOAT };\n            memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n        
    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n            img_desc_1d.image_width = M * K / 4 / 4;\n            img_desc_1d.buffer = qT_d;\n            qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);\n            CL_CHECK(err);\n\n            int height_q = M / 4;\n            int width_q = K / 4 / 4;\n            kernel = backend_ctx->kernel_transpose_32;\n\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_q));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_q));\n\n            size_t local_size_q[3] = {4, 16, 1};\n            size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1};\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt));\n            CL_CHECK(clWaitForEvents(1, &evt));\n\n            // Transpose scales\n            size_t d_size_bytes = M * (K / 32) * 2;\n            region.origin = 0;\n            region.size = d_size_bytes;\n            cl_mem dT_d = clCreateSubBuffer(\n                backend_ctx->prealloc_scales_trans.buffer,\n                0,\n                CL_BUFFER_CREATE_TYPE_REGION,\n                &region,\n                &err);\n            CL_CHECK(err);\n\n            cl_mem d_d_image1D;\n            cl_mem dT_d_image1D;\n\n            memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n            img_fmt_1d = { CL_R, CL_HALF_FLOAT };\n            img_desc_1d.image_width = M * K / 32;\n            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n            img_desc_1d.buffer = extra->d;\n            d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);\n            CL_CHECK(err);\n\n            img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };\n            
memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n            img_desc_1d.image_width = M * K / 32 / 4;\n            img_desc_1d.buffer = dT_d;\n            dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);\n            CL_CHECK(err);\n\n            int height_s = M / 4;\n            int width_s = K / 32;\n\n            kernel = backend_ctx->kernel_transpose_16_4x1;\n\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s));\n\n            size_t local_size_s[3] = {4, 16, 1};\n            size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1};\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt));\n            CL_CHECK(clWaitForEvents(1, &evt));\n\n            // copy transposed buffer contents to original buffers\n            CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt));\n            CL_CHECK(clWaitForEvents(1, &evt));\n\n            CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt));\n            CL_CHECK(clWaitForEvents(1, &evt));\n\n            CL_CHECK(clReleaseMemObject(qT_d));\n            CL_CHECK(clReleaseMemObject(dT_d));\n\n            CL_CHECK(clReleaseMemObject(q_d_image1D));\n            CL_CHECK(clReleaseMemObject(d_d_image1D));\n            CL_CHECK(clReleaseMemObject(qT_d_image1D));\n            CL_CHECK(clReleaseMemObject(dT_d_image1D));\n        } // end transpose\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n\n        return;\n    }\n    if (tensor->type == GGML_TYPE_Q6_K) {\n        ggml_tensor_extra_cl * extra_orig = 
(ggml_tensor_extra_cl *)tensor->extra;\n        GGML_ASSERT(extra_orig && \"Tesnors in OpenCL backend should have been allocated and initialized\");\n\n        // Allocate the new extra and create aliases from the original.\n        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n        ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K();\n\n        size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;\n        size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4;\n        size_t size_s  = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16;\n        size_t size_d  = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);\n        GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) &&\n            \"Incorrect tensor size\");\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n        CL_CHECK(clEnqueueWriteBuffer(\n            queue, data_device, CL_TRUE, 0,\n            ggml_nbytes(tensor), data, 0, NULL, NULL));\n\n        cl_buffer_region region;\n\n        // Subbuffer for ql\n        region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);\n        region.size = size_ql;\n        extra->ql = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        auto previous_origin = region.origin;\n\n        // Subbuffer for qh\n        region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment);\n        region.size = size_qh;\n        extra->qh = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        previous_origin = region.origin;\n\n        // Subbuffer for scales\n        region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment);\n        region.size = size_s;\n        extra->s = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        previous_origin = region.origin;\n\n        // Create subbuffer for d.\n        region.origin = align_to(previous_origin + size_s, backend_ctx->alignment);\n        region.size = size_d;\n        extra->d = clCreateSubBuffer(\n            extra_orig->data_device, CL_MEM_READ_WRITE,\n            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);\n        CL_CHECK(err);\n        previous_origin = region.origin;\n\n        // Flatten the weights\n        cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K;\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clReleaseMemObject(data_device));\n\n        extra->size_ql = size_ql;\n        extra->size_qh = size_qh;\n        extra->size_s  = size_s;\n        extra->size_d  = size_d;\n\n        tensor->extra  = extra;\n        return;\n    }\n#endif // GGML_OPENCL_SOA_Q\n\n    ggml_tensor_extra_cl * extra = 
(ggml_tensor_extra_cl *) tensor->extra;\n    GGML_ASSERT(extra);\n\n    CL_CHECK(clEnqueueWriteBuffer(\n        queue, extra->data_device, CL_TRUE, extra->offset + offset,\n        size, data, 0, NULL, NULL));\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    GGML_ASSERT(tensor->extra);\n\n    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);\n\n    cl_context context = backend_ctx->context;\n    cl_command_queue queue = backend_ctx->queue;\n\n    // Make sure all previously submitted commands in other devices are finished.\n    sync_with_other_backends(backend_ctx);\n\n#ifdef GGML_OPENCL_SOA_Q\n    // In end-to-end runs, get_tensor is usually used to get back the logits,\n    // where we can simply do clEnqueueReadBuffer since they are f32.\n    // However, in test-backend-ops, the GPU graph is copied to the CPU backend,\n    // which requires reading back quantized weight tensors.\n    // To properly support this, we need to restore block_q4_0 struct arrays\n    // from the flattened buffers.\n    if (tensor->type == GGML_TYPE_Q4_0) {\n        ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra;\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        if (use_adreno_kernels(backend_ctx, tensor)) {\n            cl_int err;\n            cl_kernel kernel;\n\n            cl_int M = tensor->ne[1];   // ne01\n            cl_int K = tensor->ne[0];   // ne00\n\n            GGML_ASSERT(K % 32 == 0);\n            GGML_ASSERT(M % 4 == 0);\n\n            size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2;\n            size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);\n            GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && \"Incorrect tensor size\");\n\n            cl_mem buf_trans_q;\n    
        cl_mem buf_trans_d;\n\n            CL_CHECK((buf_trans_q = clCreateBuffer(context, CL_MEM_READ_WRITE,\n                size_q, NULL, &err), err));\n            CL_CHECK((buf_trans_d = clCreateBuffer(context, CL_MEM_READ_WRITE,\n                size_d, NULL, &err), err));\n\n            kernel = backend_ctx->kernel_transpose_16_buf;\n\n            // transpose q back\n            cl_int stride_k_q = K/4;\n            size_t local_size_q[3] = {64, 1, 1};\n            size_t global_size_q[3] = {(size_t)M, (size_t)stride_k_q, 1};\n\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_q));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_q));\n\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n                global_size_q, local_size_q, 0, NULL, NULL));\n\n            // transpose scales back\n            cl_int stride_k_d = K/32;\n            size_t local_size_d[3] = {64, 1, 1};\n            size_t global_size_d[3] = {(size_t)M, (size_t)stride_k_d, 1};\n\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->d));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_d));\n\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n                global_size_d, local_size_d, 0, NULL, NULL));\n\n            // unpack\n            cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n                ggml_nbytes(tensor), NULL, &err);\n            CL_CHECK(err);\n\n            cl_uchar mask_0F = 0x0F;\n            cl_uchar mask_F0 = 0xF0;\n\n            size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 
1};\n            size_t local_work_size[] = {1, 1, 1};\n\n            kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle;\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &buf_trans_q));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &buf_trans_d));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &data_device));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F));\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0));\n\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n                global_work_size, local_work_size, 0, NULL, NULL));\n\n            // read back to host\n            CL_CHECK(clEnqueueReadBuffer(\n                queue, data_device, CL_TRUE, offset,\n                size, data, 0, NULL, NULL));\n\n            CL_CHECK(clReleaseMemObject(data_device));\n            CL_CHECK(clReleaseMemObject(buf_trans_q));\n            CL_CHECK(clReleaseMemObject(buf_trans_d));\n\n            return;\n        }\n#endif\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n\n        cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0;\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[] = {1, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n            global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clEnqueueReadBuffer(\n            queue, data_device, CL_TRUE, offset,\n            size, data, 0, NULL, 
NULL));\n        CL_CHECK(clReleaseMemObject(data_device));\n        return;\n    }\n    if (tensor->type == GGML_TYPE_Q4_1) {\n        ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra;\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        if (use_adreno_kernels(backend_ctx, tensor)) {\n            static ggml_cl_buffer buf_trans_q;\n            static ggml_cl_buffer buf_trans_m;\n            static ggml_cl_buffer buf_trans_d;\n            static ggml_cl_buffer buf_unpacked;\n\n            cl_int M = tensor->ne[1];\n            cl_int K = tensor->ne[0];\n\n            GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0);\n\n            size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2;\n            size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);\n            size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);\n            GGML_ASSERT(size_d + size_q + size_m == ggml_nbytes(tensor) && \"Incorrect tensor size\");\n\n            buf_trans_q.allocate(backend_ctx->context, size_q);\n            buf_trans_m.allocate(backend_ctx->context, size_m);\n            buf_trans_d.allocate(backend_ctx->context, size_d);\n            buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor));\n\n            // transpose q, d, m back\n            transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4);\n            transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32);\n            transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32);\n\n            cl_uchar mask_0F = 0x0F;\n            cl_uchar mask_F0 = 0xF0;\n\n            size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n            size_t local_work_size[] = {1, 1, 1};\n\n            cl_kernel kernel = 
backend_ctx->kernel_restore_block_q4_1_noshuffle;\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &buf_trans_q.buffer));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &buf_trans_d.buffer));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &buf_trans_m.buffer));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &buf_unpacked.buffer));\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F));\n            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0));\n\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));\n            CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL));\n            return;\n        }\n#endif\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n\n        cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1;\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[] = {1, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n            global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clEnqueueReadBuffer(\n            queue, data_device, CL_TRUE, offset,\n            size, data, 0, NULL, NULL));\n        CL_CHECK(clReleaseMemObject(data_device));\n        return;\n    }\n    if (tensor->type == GGML_TYPE_MXFP4) {\n        
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        if (use_adreno_moe_kernels(backend_ctx, tensor)) {\n            cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;\n\n            int ne00 = tensor->ne[0];\n            int ne01 = tensor->ne[1];\n            int ne02 = tensor->ne[2];\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));\n\n            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};\n            size_t local_work_size[3] = {64, 2, 1};\n\n            cl_event evt;\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n                global_work_size, local_work_size, 0, NULL, &evt));\n            CL_CHECK(clWaitForEvents(1, &evt));\n            CL_CHECK(clEnqueueReadBuffer(\n                queue, data_device, CL_TRUE, offset,\n                size, data, 0, NULL, NULL));\n            CL_CHECK(clReleaseMemObject(data_device));\n            return;\n        }\n#endif\n        cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        
size_t local_work_size[] = {1, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n            global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clEnqueueReadBuffer(\n            queue, data_device, CL_TRUE, offset,\n            size, data, 0, NULL, NULL));\n        CL_CHECK(clReleaseMemObject(data_device));\n        return;\n    }\n    if (tensor->type == GGML_TYPE_Q8_0) {\n        ggml_tensor_extra_cl_q8_0 * extra = (ggml_tensor_extra_cl_q8_0 *)tensor->extra;\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n        if (enable_adreno_trans_weight(backend_ctx, tensor)) {\n            cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0_trans;\n\n            int ne00 = tensor->ne[0];\n            int ne01 = tensor->ne[1];\n            GGML_ASSERT(tensor->ne[2] == 1);\n            GGML_ASSERT(tensor->ne[3] == 1);\n\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));\n\n            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), 1, 1};\n            size_t local_work_size[3] = {64, 1, 1};\n\n            cl_event evt;\n            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n                global_work_size, local_work_size, 0, NULL, &evt));\n            CL_CHECK(clWaitForEvents(1, &evt));\n\n            CL_CHECK(clEnqueueReadBuffer(\n                queue, data_device, CL_TRUE, offset,\n                size, data, 0, NULL, NULL));\n      
      CL_CHECK(clReleaseMemObject(data_device));\n            return;\n        }\n#endif\n        cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0;\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[] = {1, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n            global_work_size, local_work_size, 0, NULL, &evt));\n        CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clEnqueueReadBuffer(\n            queue, data_device, CL_TRUE, offset,\n            size, data, 0, NULL, NULL));\n        CL_CHECK(clReleaseMemObject(data_device));\n        return;\n    }\n    if (tensor->type == GGML_TYPE_Q6_K) {\n        ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra;\n\n        cl_int err;\n        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,\n            ggml_nbytes(tensor), NULL, &err);\n        CL_CHECK(err);\n\n        cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K;\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));\n\n        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};\n        size_t local_work_size[] = {1, 1, 1};\n\n        cl_event evt;\n        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,\n            global_work_size, local_work_size, 0, NULL, &evt));\n        
CL_CHECK(clWaitForEvents(1, &evt));\n        CL_CHECK(clEnqueueReadBuffer(\n            queue, data_device, CL_TRUE, offset,\n            size, data, 0, NULL, NULL));\n        CL_CHECK(clReleaseMemObject(data_device));\n        return;\n    }\n#endif // GGML_OPENCL_SOA_Q\n\n    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;\n\n    CL_CHECK(clEnqueueReadBuffer(\n        queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset,\n        size, data, 0, NULL, NULL));\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    ggml_backend_dev_t dev = buffer->buft->device;\n    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);\n    cl_command_queue queue = backend_ctx->queue;\n\n    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n    for (cl_mem buf : ctx->buffer) {\n        CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL));\n    }\n    CL_CHECK(clFinish(queue));\n}\n\nstatic void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) {\n    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;\n    ctx->reset();\n}\n\nstatic ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_opencl_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_opencl_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_opencl_buffer_init_tensor,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ ggml_backend_opencl_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_opencl_buffer_get_tensor,\n    /* .cpy_tensor      = */ NULL,\n    /* .clear           = */ ggml_backend_opencl_buffer_clear,\n    /* .reset           = */ ggml_backend_opencl_buffer_reset,\n};\n\n//\n// buffer type\n//\n\nstatic const char * 
ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type) {\n    return \"OpenCL\";\n\n    GGML_UNUSED(buffer_type);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) {\n    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device);\n\n    // clCreateBuffer returns -61 for size 0\n    size = std::max(size, (size_t)1);\n\n    cl_int err;\n    cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err);\n    if (err != CL_SUCCESS) {\n        GGML_LOG_INFO(\"%s: failed to allocate %.2f MiB\\n\", __func__, size / 1024.0 / 1024.0);\n        return nullptr;\n    }\n\n    ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context(mem);\n\n    return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size);\n}\n\nstatic size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {\n    ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);\n    return backend_ctx->alignment;\n}\n\nstatic size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {\n    static size_t max_size = -1;\n    if (max_size == (size_t)-1) {\n        ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);\n        max_size = backend_ctx->max_alloc_size;\n    }\n    return max_size;\n}\n\nstatic bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {\n    return ggml_backend_is_opencl(backend);\n\n    UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_opencl_buffer_type_get_name,\n    /* .alloc_buffer     = */ ggml_backend_opencl_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_opencl_buffer_type_get_alignment,\n    
/* .get_max_size     = */ ggml_backend_opencl_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ NULL,\n    /* .is_host          = */ NULL,\n};\n\n//\n// backend device\n//\n\nstatic const char * ggml_backend_opencl_device_get_name(ggml_backend_dev_t dev) {\n    return \"GPUOpenCL\";\n\n    GGML_UNUSED(dev);\n}\n\nstatic const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_t dev) {\n    ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *) dev->context;\n    return dev_ctx->device_name.c_str();\n}\n\nstatic void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    // no memory to report\n    *free  = 0;\n    *total = 0;\n\n    GGML_UNUSED(dev);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) {\n    return GGML_BACKEND_DEVICE_TYPE_GPU;\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_opencl_device_get_name(dev);\n    props->description = ggml_backend_opencl_device_get_description(dev);\n    props->type        = ggml_backend_opencl_device_get_type(dev);\n    ggml_backend_opencl_device_get_memory(dev, &props->memory_free, &props->memory_total);\n    props->caps = ggml_backend_dev_caps {\n        /* .async                 = */ false,\n        /* .host_buffer           = */ false,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ false,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) {\n    ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev);\n    // Getting a new reference to the backend, increase ref_count\n    backend_ctx->ref_count++;\n\n    ggml_backend_t backend = new ggml_backend {\n        /* .guid      = */ ggml_backend_opencl_guid(),\n        /* .interface = */ 
ggml_backend_opencl_i,\n        /* .device    = */ dev,\n        /* .context   = */ backend_ctx,\n    };\n\n    return backend;\n\n    GGML_UNUSED(params);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) {\n    auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(dev->context);\n\n    dev_ctx->buffer_type = ggml_backend_buffer_type{\n        /* .iface   = */ ggml_backend_opencl_buffer_type_interface,\n        /* .device  = */ dev,\n        /* .context = */ nullptr,\n    };\n\n    return &dev_ctx->buffer_type;\n}\n\nstatic ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {\n    GGML_UNUSED(dev);\n    GGML_UNUSED(ptr);\n    GGML_UNUSED(size);\n    GGML_UNUSED(max_tensor_size);\n    return nullptr;\n}\n\nstatic bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {\n    return ggml_opencl_supports_op(dev, op);\n}\n\nstatic bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    // Check 'dev' and 'buffer_type' are not objects belonging to this backend.\n    if (dev->iface.get_name != ggml_backend_opencl_device_get_name ||\n        buft->iface.get_name != ggml_backend_opencl_buffer_type_get_name) {\n        return false;\n    }\n\n    // Check cl_context is the same. 
clEnqueue* commands may not use\n    // buffers from another cl_context.\n    ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev);\n    ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device);\n    return backend_ctx0->context == backend_ctx1->context;\n}\n\nnamespace /* anonymous */ {\nstruct ggml_backend_device_i ggml_backend_opencl_device_i = {\n    /* .get_name             = */ ggml_backend_opencl_device_get_name,\n    /* .get_description      = */ ggml_backend_opencl_device_get_description,\n    /* .get_memory           = */ ggml_backend_opencl_device_get_memory,\n    /* .get_type             = */ ggml_backend_opencl_device_get_type,\n    /* .get_props            = */ ggml_backend_opencl_device_get_props,\n    /* .init_backend         = */ ggml_backend_opencl_device_init,\n    /* .get_buffer_type      = */ ggml_backend_opencl_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,\n    /* .buffer_from_host_ptr = */ ggml_backend_opencl_device_buffer_from_ptr,\n    /* .supports_op          = */ ggml_backend_opencl_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_opencl_device_supports_buft,\n    /* .offload_op           = */ NULL,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ NULL,\n};\n}\n\n// Backend registry\n\nstatic const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) {\n    return \"OpenCL\";\n\n    GGML_UNUSED(reg);\n}\n\nstatic size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) {\n    return g_ggml_backend_opencl_devices.size();\n\n    GGML_UNUSED(reg);\n}\n\nstatic ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) {\n    GGML_ASSERT(index < ggml_backend_opencl_reg_device_count(reg));\n\n    return &g_ggml_backend_opencl_devices[index];\n\n    GGML_UNUSED(reg);\n    GGML_UNUSED(index);\n}\n\nstatic struct ggml_backend_reg_i ggml_backend_opencl_reg_i = 
{\n    /* .get_name         = */ ggml_backend_opencl_reg_get_name,\n    /* .device_count     = */ ggml_backend_opencl_reg_device_count,\n    /* .device_get       = */ ggml_backend_opencl_reg_device_get,\n    /* .get_proc_address = */ NULL,\n};\n\nggml_backend_reg_t ggml_backend_opencl_reg(void) {\n    static std::mutex mutex;\n    static ggml_backend_reg reg;\n    static bool initialized = false;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    if (initialized) {\n        return &reg;\n    }\n    initialized = true;\n\n    g_ggml_backend_opencl_devices = ggml_opencl_probe_devices(&reg);\n\n    reg = ggml_backend_reg{\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_opencl_reg_i,\n        /* .context     = */ NULL,\n    };\n\n    return &reg;\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_opencl_reg)\n\n//------------------------------------------------------------------------------\n// Debugging utils\n//------------------------------------------------------------------------------\n#if 0\n#define QK4_0 32\ntypedef struct {\n    ggml_fp16_t d;          // delta\n    uint8_t qs[QK4_0 / 2];  // nibbles / quants\n} block_q4_0;\nstatic_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2,\n    \"wrong q4_0 block size/padding\");\n\n#include <math.h>\n#ifdef __cplusplus\n#include \"half.hpp\"\n#endif\n\nstatic void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor) {\n    void * buf = malloc(ggml_nbytes(tensor));\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n    cl_command_queue queue = backend_ctx->queue;\n#ifdef GGML_OPENCL_SOA_Q\n    void * buf_q;\n    void * buf_d;\n#endif\n\n    // Make sure everything is done.\n    CL_CHECK(clFinish(queue));\n\n#ifdef GGML_OPENCL_SOA_Q\n    if (tensor->type == GGML_TYPE_Q4_0) {\n        ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *) tensor->extra;\n        GGML_ASSERT(extra);\n\n       
 size_t size_q = ggml_nelements(tensor)/QK4_0 * QK4_0/2;\n        size_t size_d = ggml_nelements(tensor)/QK4_0 * sizeof(ggml_fp16_t);\n        GGML_ASSERT(size_q + size_d == ggml_nbytes(tensor));\n        buf_q = malloc(size_q);\n        buf_d = malloc(size_d);\n\n        CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));\n        CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));\n        CL_CHECK(clFinish(queue));\n    } else if (tensor->type == GGML_TYPE_MXFP4) {\n        ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra;\n        GGML_ASSERT(extra);\n\n        size_t size_q = ggml_nelements(tensor)/QK_MXFP4 * QK_MXFP4/2;\n        size_t size_e = ggml_nelements(tensor)/QK_MXFP4 * sizeof(char);\n        GGML_ASSERT(size_q + size_e == ggml_nbytes(tensor));\n        buf_q = malloc(size_q);\n        buf_d = malloc(size_e);\n\n        CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));\n        CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL));\n        CL_CHECK(clFinish(queue));\n    } else {\n        // Read out the tensor from GPU memory.\n        ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;\n        GGML_ASSERT(extra);\n\n        CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE,\n        extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL));\n        CL_CHECK(clFinish(queue));\n    }\n#else\n    // Read out the tensor from GPU memory.\n    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;\n    GGML_ASSERT(extra);\n\n    CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE,\n        extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL));\n    CL_CHECK(clFinish(queue));\n#endif // GGML_OPENCL_SOA_Q\n\n    // Open file and dump.\n    char fname[512];\n    snprintf(fname, sizeof(fname), 
\"./tensor-dumps/%s.txt\", tensor->name);\n    FILE * f = fopen(fname, \"w\");\n    if (!f) {\n        printf(\"Failed to open %s\\n\", fname);\n        return;\n    }\n\n    if (tensor->type == GGML_TYPE_F32) {\n        float * data = (float *) buf;\n        for (int i = 0; i < ggml_nelements(tensor); ++i) {\n            if (isnan(data[i])) {\n                printf(\"NaN found: %s\\n\", tensor->name);\n                break;\n            }\n            fprintf(f, \"%f\\n\", data[i]);\n        }\n    } else if (tensor->type == GGML_TYPE_I32) {\n        int * data = (int *) buf;\n        for (int i = 0; i < ggml_nelements(tensor); ++i) {\n            if (isnan(data[i])) {\n                printf(\"NaN found: %s\\n\", tensor->name);\n                break;\n            }\n            fprintf(f, \"%d\\n\", data[i]);\n        }\n    } else if (tensor->type == GGML_TYPE_F16) {\n#ifdef __cplusplus\n        half_float::half * data = (half_float::half *) buf;\n        for (int i = 0; i < ggml_nelements(tensor); ++i) {\n            if (std::isnan(data[i])) {\n                printf(\"NaN found: %s\\n\", tensor->name);\n                break;\n            }\n            fprintf(f, \"%f\\n\", float(data[i]));\n        }\n#endif\n    } else if (tensor->type == GGML_TYPE_Q4_0) {\n#ifdef GGML_OPENCL_SOA_Q\n        ggml_fp16_t * data_d = (ggml_fp16_t *)buf_d;\n        unsigned char * data_q = (unsigned char *)buf_q;\n\n        for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) {\n            fprintf(f, \"%04x, \", data_d[i]);\n            for (int k = 0; k < QK4_0/2; ++k) {\n                fprintf(f, \"%02x, \", data_q[k]);\n            }\n            fprintf(f, \"\\n\");\n            data_q += QK4_0/2;\n        }\n        free(buf_d);\n        free(buf_q);\n#else\n        block_q4_0 * data = (block_q4_0 *) buf;\n        for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) {\n            fprintf(f, \"%04x, \", data[i].d);\n            for (int k = 0; k < QK4_0/2; ++k) {\n 
               fprintf(f, \"%02x, \", data[i].qs[k]);\n            }\n            fprintf(f, \"\\n\");\n        }\n#endif // GGML_OPENCL_SOA_Q\n    }\n    free(buf);\n    fflush(f);\n    fclose(f);\n}\n#else\n#define dump_tensor(tensor)\n#endif\n\n//------------------------------------------------------------------------------\n// Ops\n//------------------------------------------------------------------------------\n\nstatic bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {\n    const int64_t ne10 = src1->ne[0];\n\n    const int64_t ne0 = dst->ne[0];\n    const int64_t ne1 = dst->ne[1];\n\n    // TODO: find the optimal values for these\n    return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&\n            src1->type == GGML_TYPE_F32 &&\n             dst->type == GGML_TYPE_F32 &&\n            (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);\n}\n\n// Copy a noncontiguous tensor to contiguous tensor. 
ne[] remains the same but\n// nb[] is recalculated such that tensor is contiguous.\nstatic void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst,\n                                       cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) {\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    const int tensor_type_size = ggml_type_size(src->type);\n\n    const int ne00 = src->ne[0];\n    const int ne01 = src->ne[1];\n    const int ne02 = src->ne[2];\n    const int ne03 = src->ne[3];\n\n    const cl_ulong nb00 = src->nb[0];\n    const cl_ulong nb01 = src->nb[1];\n    const cl_ulong nb02 = src->nb[2];\n    const cl_ulong nb03 = src->nb[3];\n\n    const int ne0 = src->ne[0];\n    const int ne1 = src->ne[1];\n    const int ne2 = src->ne[2];\n    const int ne3 = src->ne[3];\n\n    nb0 = tensor_type_size;\n    nb1 = tensor_type_size*ne00;\n    nb2 = tensor_type_size*ne00*ne01;\n    nb3 = tensor_type_size*ne00*ne01*ne02;\n\n    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;\n\n    cl_ulong offset0 = extra->offset + src->view_offs;\n    cl_ulong offsetd = 0;\n\n    cl_kernel kernel;\n\n    switch (src->type) {\n        case GGML_TYPE_F32:\n            kernel = backend_ctx->kernel_cpy_f32_f32;\n            break;\n        case GGML_TYPE_F16:\n            kernel = backend_ctx->kernel_cpy_f16_f16;\n            break;\n        default:\n            GGML_ASSERT(false && \"not implemented\");\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &dst));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  6, 
sizeof(int),      &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb00));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne1));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne2));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne3));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));\n\n    const int nth = MIN(64, ne00);\n\n    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src);\n}\n\nstatic void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    UNUSED(backend);\n    UNUSED(src0);\n    UNUSED(src1);\n    UNUSED(dst);\n}\n\nstatic void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);\n    GGML_TENSOR_LOCALS(int,      ne1, src1, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb);\n    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);\n\n    
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    switch (src0->type) {\n        case GGML_TYPE_F32:\n            kernel = backend_ctx->kernel_get_rows_f32;\n            break;\n        case GGML_TYPE_F16:\n            kernel = backend_ctx->kernel_get_rows_f16;\n            break;\n        case GGML_TYPE_Q4_0:\n            kernel = backend_ctx->kernel_get_rows_q4_0;\n            break;\n        default:\n            GGML_ASSERT(false && \"not implemented\");\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne10));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));\n    CL_CHECK(clSetKernelArg(kernel, 14, 
sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));\n\n    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);\n    int nth = 1;\n    while (nth < ne00 && 2*nth <= max_workgroup_size) {\n        nth *= 2;\n    }\n\n    size_t global_work_size[] = {(size_t)ne10*nth, (size_t)ne11, (size_t)ne12};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);\n\n    // ne0 = ne00\n    // ne2 = ne02\n    // ne3 = ne03\n\n    const int      ne01 = src0->ne[1];\n    const int      ne02 = src0->ne[2];\n    const int      ne03 = src0->ne[3];\n\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int      ne11 = src1->ne[1];\n    const int      ne12 = src1->ne[2];\n\n    const cl_ulong nb10 = src1->nb[0];\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n\n    const int      ne0  = dst->ne[0];\n\n    const cl_ulong nb1  = dst->nb[1];\n    const cl_ulong nb2  = dst->nb[2];\n    const cl_ulong nb3  = dst->nb[3];\n\n    const int nblk0 = ne0/ggml_blck_size(dst->type);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    
cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    switch (dst->type) {\n        case GGML_TYPE_F32:\n            if (src1->type == GGML_TYPE_I64) {\n                kernel = backend_ctx->kernel_set_rows_f32_i64;\n            } else {\n                kernel = backend_ctx->kernel_set_rows_f32_i32;\n            }\n            break;\n        case GGML_TYPE_F16:\n            if (src1->type == GGML_TYPE_I64) {\n                kernel = backend_ctx->kernel_set_rows_f16_i64;\n            } else {\n                kernel = backend_ctx->kernel_set_rows_f16_i32;\n            }\n            break;\n        default:\n            GGML_ABORT(\"not implemented\");\n    }\n\n    fastdiv_vals ne11_ = init_fastdiv_values(ne11);\n    fastdiv_vals ne12_ = init_fastdiv_values(ne12);\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 14, 
sizeof(cl_ulong), &nb12));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &nblk0));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb3));\n\n    int nth0 = 64;\n    if (backend_ctx->gpu_family == INTEL) {\n        nth0 = 32;\n    } else if (backend_ctx->gpu_family == ADRENO) {\n        nth0 = 64;\n    }\n\n    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);\n    while (nth0 < nblk0 && nth0 < max_workgroup_size) {\n        nth0 *= 2;\n    }\n\n    int rows_per_workgroup = 1;\n    if (nth0 > nblk0) {\n        rows_per_workgroup = nth0 / nblk0;\n        nth0 = nblk0;\n    }\n\n    size_t global_work_size[] = {\n        (size_t)(ne01 + rows_per_workgroup - 1)/rows_per_workgroup*nth0,\n        (size_t)ne02*rows_per_workgroup,\n        (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth0, (size_t)rows_per_workgroup, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int ne10 = src1->ne[0];\n    const int ne11 = src1->ne[1];\n    const int ne12 = src1->ne[2];\n    const int ne13 = src1->ne[3];\n\n    const cl_ulong nb10 = src1->nb[0];\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n    const cl_ulong nb13 = 
src1->nb[3];\n\n    const int ne0  = dst->ne[0];\n    const int ne1  = dst->ne[1];\n    const int ne2  = dst->ne[2];\n    const int ne3  = dst->ne[3];\n\n    const cl_ulong nb0  = dst->nb[0];\n    const cl_ulong nb1  = dst->nb[1];\n    const cl_ulong nb2  = dst->nb[2];\n    const cl_ulong nb3  = dst->nb[3];\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;\n\n    if (bcast_row) {\n        GGML_ASSERT(ggml_is_contiguous(src0));\n        GGML_ASSERT(ne11 == 1);\n    }\n\n    if (dst->type == GGML_TYPE_F32) {\n        GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);\n        if (bcast_row) {\n            kernel = backend_ctx->kernel_add_row;\n            const int ne = ne00 / 4;\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));\n        } else {\n            kernel = backend_ctx->kernel_add;\n            CL_CHECK(clSetKernelArg(kernel,  0, 
sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));\n            CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));\n            
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));\n            CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));\n            CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));\n            CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));\n        }\n    } else if (dst->type == GGML_TYPE_F16) {\n        GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);\n        GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);\n        const int type_src0 = (src0->type == GGML_TYPE_F32);\n        const int type_src1 = (src1->type == GGML_TYPE_F32);\n        if (bcast_row) {\n            kernel = backend_ctx->kernel_add_row_f16;\n            const int ne = ne00 / 4;\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));\n            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &type_src0));\n            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int),      &type_src1));\n        } else {\n            kernel = backend_ctx->kernel_add_f16;\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n           
 CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));\n            CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));\n            CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));\n            CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));\n            CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));\n            CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));\n            CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int),      &type_src0));\n            
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int),      &type_src1));\n        }\n    } else {\n        GGML_ASSERT(false && \"unsupported data types for add\");\n    }\n\n    if (bcast_row) {\n        int n = ggml_nelements(dst)/4;\n        size_t global_work_size[] = {(size_t)n, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        size_t * local_work_size_ptr = local_work_size;\n        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n            local_work_size_ptr = nullptr;\n        }\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);\n    } else {\n        unsigned int nth = MIN(64, ne0);\n        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {nth, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    const ggml_tensor * src2 = dst->src[2];\n    GGML_ASSERT(src2);\n    GGML_ASSERT(src2->extra);\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(src2->type == GGML_TYPE_I32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(src0));\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n\n    const cl_ulong nb11 = src1->nb[1];\n\n    const cl_ulong nb21 = src2->nb[1];\n\n    const int ne0 = dst->ne[0];\n    const int ne1 = dst->ne[1];\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offset2 = extra2->offset + src2->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel = backend_ctx->kernel_add_id;\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne1));\n\n    int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));\n    size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };\n    size_t local_work_size[] = { (size_t)nth, 1, 1 };\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const 
ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    GGML_ASSERT(src0->type == src1->type);\n    GGML_ASSERT(src0->type == dst->type);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int ne10 = src1->ne[0];\n    const int ne11 = src1->ne[1];\n    const int ne12 = src1->ne[2];\n    const int ne13 = src1->ne[3]; UNUSED(ne13);\n\n    const cl_ulong nb10 = src1->nb[0];\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n    const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);\n\n    const int ne0  = dst->ne[0];\n    const int ne1  = dst->ne[1];\n    const int ne2  = dst->ne[2];\n    const int ne3  = dst->ne[3];\n\n    const cl_ulong nb0  = dst->nb[0];\n    const cl_ulong nb1  = dst->nb[1];\n    const cl_ulong nb2  = dst->nb[2];\n    const cl_ulong nb3  = dst->nb[3];\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    bool bcast_row = false;\n    cl_kernel kernel;\n\n    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {\n        
GGML_ASSERT(ggml_is_contiguous(src0));\n\n        // src1 is a row\n        GGML_ASSERT(ne11 == 1);\n\n        bcast_row = true;\n        int ne = ne00 / 4;\n\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_mul_row;\n        } else {\n            kernel = backend_ctx->kernel_mul_row_f16;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));\n    } else {\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_mul;\n        } else {\n            kernel = backend_ctx->kernel_mul_f16;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));\n     
   CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));\n        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));\n        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));\n        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));\n        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));\n        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));\n        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));\n        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));\n        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));\n        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));\n        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));\n        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));\n        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));\n        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));\n        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));\n        CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));\n        CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));\n        CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));\n    }\n\n    if (bcast_row) {\n        int n = ggml_nelements(dst)/4;\n        size_t global_work_size[] = {(size_t)n, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        size_t * local_work_size_ptr = local_work_size;\n        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n            local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.\n        }\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n    } else {\n        unsigned int nth = MIN(64, ne0);\n        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {nth, 1, 1};\n\n        
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    GGML_ASSERT(src0->type == src1->type);\n    GGML_ASSERT(src0->type == dst->type);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int ne10 = src1->ne[0];\n    const int ne11 = src1->ne[1];\n    const int ne12 = src1->ne[2];\n    const int ne13 = src1->ne[3];\n\n    const cl_ulong nb10 = src1->nb[0];\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n    const cl_ulong nb13 = src1->nb[3];\n\n    const int ne0  = dst->ne[0];\n\n    const cl_ulong nb0  = dst->nb[0];\n    const cl_ulong nb1  = dst->nb[1];\n    const cl_ulong nb2  = dst->nb[2];\n    const cl_ulong nb3  = dst->nb[3];\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    bool bcast_row = false;\n    cl_kernel kernel;\n\n    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 
4 == 0) {\n        GGML_ASSERT(ggml_is_contiguous(src0));\n\n        // src1 is a row\n        GGML_ASSERT(ne11 == 1);\n\n        bcast_row = true;\n        int ne = ne00 / 4;\n\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_div_row;\n        } else {\n            kernel = backend_ctx->kernel_div_row_f16;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));\n    } else {\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_div;\n        } else {\n            kernel = backend_ctx->kernel_div_f16;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb00));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));\n        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne10));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),  
    &ne11));\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));\n        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne13));\n        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));\n        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));\n        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));\n        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));\n        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne0));\n        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0));\n        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));\n        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));\n        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));\n    }\n\n    if (bcast_row) {\n        int n = ggml_nelements(dst)/4;\n        size_t global_work_size[] = {(size_t)n, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    } else {\n        unsigned int nth = MIN(64, ne0);\n        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {nth, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    GGML_ASSERT(src0->type == src1->type);\n    GGML_ASSERT(src0->type == dst->type);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = 
src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int ne10 = src1->ne[0];\n    const int ne11 = src1->ne[1];\n    const int ne12 = src1->ne[2];\n    const int ne13 = src1->ne[3];\n\n    const cl_ulong nb10 = src1->nb[0];\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n    const cl_ulong nb13 = src1->nb[3];\n\n    const int ne0  = dst->ne[0];\n\n    const cl_ulong nb0  = dst->nb[0];\n    const cl_ulong nb1  = dst->nb[1];\n    const cl_ulong nb2  = dst->nb[2];\n    const cl_ulong nb3  = dst->nb[3];\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    bool bcast_row = false;\n    cl_kernel kernel;\n\n    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {\n        GGML_ASSERT(ggml_is_contiguous(src0));\n\n        // src1 is a row\n        GGML_ASSERT(ne11 == 1);\n\n        bcast_row = true;\n        int ne = ne00 / 4;\n\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_sub_row;\n        } else {\n            kernel = backend_ctx->kernel_sub_row_f16;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));\n        CL_CHECK(clSetKernelArg(kernel, 4, 
sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));\n    } else {\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_sub;\n        } else {\n            kernel = backend_ctx->kernel_sub_f16;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb00));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));\n        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne10));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne11));\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));\n        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne13));\n        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));\n        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));\n        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));\n        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));\n        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne0));\n        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0));\n        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));\n        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), 
&nb2));\n        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));\n    }\n\n    if (bcast_row) {\n        int n = ggml_nelements(dst)/4;\n        size_t global_work_size[] = {(size_t)n, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    } else {\n        unsigned int nth = MIN(64, ne0);\n        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {nth, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    // Currently assumes src0 is contiguous\n    int n = ggml_nelements(dst);\n    if (n % 4 == 0) {\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_sqr_cont_f32_4;\n        } else {\n            kernel = backend_ctx->kernel_sqr_cont_f16_4;\n        }\n        n /= 4;\n    } else {\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_sqr_cont_f32;\n        } else {\n            kernel = backend_ctx->kernel_sqr_cont_f16;\n        }\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, 
sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n    size_t global_work_size[] = {(size_t)n, 1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    // Currently assumes src0 is contiguous\n    int n = ggml_nelements(dst);\n    if (n % 4 == 0) {\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_sqrt_cont_f32_4;\n        } else {\n            kernel = backend_ctx->kernel_sqrt_cont_f16_4;\n        }\n        n /= 4;\n    } else {\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_sqrt_cont_f32;\n        } else {\n            kernel = backend_ctx->kernel_sqrt_cont_f16;\n        }\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n    size_t global_work_size[] = {(size_t)n, 1, 
1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_UNUSED(src1);\n\n    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const cl_ulong nb1  = dst->nb[1];\n    const cl_ulong nb2  = dst->nb[2];\n    const cl_ulong nb3  = dst->nb[3];\n\n    cl_kernel kernel;\n\n    const bool is_c4 = ne00 % 4 == 0;\n    if (is_c4) {\n        kernel = backend_ctx->kernel_mean_f32_4;\n    } else {\n        kernel = backend_ctx->kernel_mean_f32;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n    
CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));\n\n    size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)64, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    int ne01 = src0->ne[1];\n    cl_ulong nb00 = src0->nb[0];\n    cl_ulong nb01 = src0->nb[1];\n    cl_ulong nb02 = src0->nb[2];\n\n    int ne10 = src1->ne[0];\n    cl_ulong nb11 = src1->nb[1];\n\n    int ne1  = dst->ne[1];\n    int ne2  = dst->ne[2];\n    cl_ulong nb0 = dst->nb[0];\n    cl_ulong nb1 = dst->nb[1];\n    cl_ulong nb2 = dst->nb[2];\n\n    cl_kernel 
kernel = backend_ctx->kernel_ssm_conv_f32_f32;\n\n    if (ne10 % 4 == 0) {\n        kernel = backend_ctx->kernel_ssm_conv_f32_f32_4;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb00));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));\n\n    size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2};\n    size_t local_work_size[]  = {64, 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * 
extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    int n = ggml_nelements(dst);\n\n    if (n % 4 == 0) {\n        kernel = backend_ctx->kernel_gelu_4;\n        n /= 4;\n    } else {\n        kernel = backend_ctx->kernel_gelu;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n    size_t global_work_size[] = {(size_t)n, 1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    int n = ggml_nelements(dst);\n\n    if (n % 4 == 0) {\n        kernel = backend_ctx->kernel_gelu_erf_4;\n        n /= 4;\n    } else {\n        kernel = backend_ctx->kernel_gelu_erf;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, 
sizeof(cl_ulong), &offsetd));\n\n    size_t global_work_size[] = {(size_t)n, 1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    int n = ggml_nelements(dst);\n\n    if (n % 4 == 0) {\n        kernel = backend_ctx->kernel_gelu_quick_4;\n        n /= 4;\n    } else {\n        kernel = backend_ctx->kernel_gelu_quick;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n    size_t global_work_size[] = {(size_t)n, 1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl 
*)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    int n = ggml_nelements(dst);\n\n    if (n % 4 == 0) {\n        kernel = backend_ctx->kernel_silu_4;\n        n /= 4;\n    } else {\n        kernel = backend_ctx->kernel_silu;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n    size_t global_work_size[] = {(size_t)n, 1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel = backend_ctx->kernel_relu;\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, 
sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n    const int64_t n = ggml_nelements(dst);\n\n    size_t global_work_size[] = {(size_t)n, 1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n        kernel = backend_ctx->kernel_sigmoid_f32;\n    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n        kernel = backend_ctx->kernel_sigmoid_f16;\n    } else {\n        GGML_ASSERT(false && \"Unsupported data types for sigmoid (input and output must be both f32 or f16)\");\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n    const int64_t n = ggml_nelements(dst);\n\n    size_t global_work_size[] = {(size_t)n, 
1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int tri_type = ggml_get_op_params_i32(dst, 0);\n    const int64_t n = ggml_nelements(dst);\n    const int     ne0  = dst->ne[0];\n    const int     ne1  = dst->ne[1];\n\n    cl_kernel kernel = backend_ctx->kernel_tri;\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &n));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne0));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne1));\n    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &tri_type));\n\n    size_t local_work_size[1] = { 256 };\n    size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 
1, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src0);\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    float v = 0.0f;\n    memcpy(&v, ((int32_t *) dst->op_params), sizeof(float));\n\n    const int64_t n = ggml_nelements(dst);\n\n    cl_kernel kernel = backend_ctx->kernel_fill;\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(float),    &v));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(float),    &n));\n\n    size_t local_work_size[1] = { 256 };\n    size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    float min;\n    float max;\n    memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));\n    memcpy(&max, ((int32_t *) dst->op_params) + 1, 
sizeof(float));\n\n    cl_kernel kernel = backend_ctx->kernel_clamp;\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float),    &min));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float),    &max));\n\n    const int64_t n = ggml_nelements(dst);\n\n    size_t global_work_size[] = {(size_t)n, 1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    const int ne00 = src0 ? src0->ne[0] : 0;\n    const int ne01 = src0 ? src0->ne[1] : 0;\n    const int ne02 = src0 ? src0->ne[2] : 0;\n    const int ne03 = src0 ? src0->ne[3] : 0;\n\n    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;\n    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;\n    const cl_ulong nb03 = src0 ? 
src0->nb[3] : 0;\n\n    const int nth = MIN(64, ne00);\n\n    cl_kernel kernel = backend_ctx->kernel_norm;\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),    &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),  &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),    &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),  &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),       &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),       &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),       &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),       &ne03));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),  &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),  &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),  &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float),     &eps));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));\n\n    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    //ggml_backend_opencl_device_context * dev_ctx =\n    //    (ggml_backend_opencl_device_context *)backend->device->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = 
extrad->offset + dst->view_offs;\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    const int ne00 = src0 ? src0->ne[0] : 0;\n    const int ne01 = src0 ? src0->ne[1] : 0;\n    const int ne02 = src0 ? src0->ne[2] : 0;\n    const int ne03 = src0 ? src0->ne[3] : 0;\n\n    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;\n    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;\n    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;\n\n    GGML_ASSERT(ne00 % 4 == 0);\n\n    const int nth = MIN(64, ne00);\n\n    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    cl_kernel kernel = backend_ctx->kernel_rms_norm;\n\n    // Note, this kernel declares local memory in kernel args and the size\n    // depends on subgroup size.\n    // Note, this requires OpenCL 2.1 and above\n    // For now we use fixed subgroup size to simplify support for OpenCL 2.0.\n    size_t sgs;\n    //CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device,\n    //    CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,\n    //    sizeof(local_work_size), local_work_size,\n    //    sizeof(size_t), &sgs, NULL));\n    if (backend_ctx->gpu_family == ADRENO) {\n        sgs = 64;\n    } else if (backend_ctx->gpu_family == INTEL) {\n        sgs = 32;\n    } else {\n        GGML_ASSERT(false && \"Unsupported GPU\");\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),    &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),  &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),    &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),  &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),       &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),       &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),       &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),       &ne03));\n    
CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),  &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),  &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),  &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float),     &eps));\n    // This is local memory - the size depends on subgroup size.\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs,  NULL));\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {\n    GGML_ASSERT(mul_tensor);\n    GGML_ASSERT(rms_norm_tensor);\n\n    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)\n    const ggml_tensor * src0 = rms_norm_tensor->src[0];\n    const ggml_tensor * src1;\n    if (mul_tensor->src[0] == rms_norm_tensor) {\n        src1 = mul_tensor->src[1];\n    } else if (mul_tensor->src[1] == rms_norm_tensor) {\n        src1 = mul_tensor->src[0];\n    } else {\n        GGML_ASSERT(false && \"Invalid args for rms_norm and mul\");\n    }\n    const ggml_tensor * dst = mul_tensor;\n\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    float eps;\n    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = 
src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int ne10 = src1->ne[0];\n    const int ne11 = src1->ne[1];\n    const int ne12 = src1->ne[2];\n    const int ne13 = src1->ne[3];\n\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n    const cl_ulong nb13 = src1->nb[3];\n\n    const cl_ulong nb1 = dst->nb[1];\n    const cl_ulong nb2 = dst->nb[2];\n    const cl_ulong nb3 = dst->nb[3];\n\n    GGML_ASSERT(ne00 % 4 == 0);\n\n    size_t sgs;\n    if (backend_ctx->gpu_family == ADRENO) {\n        sgs = 64;\n    } else if (backend_ctx->gpu_family == INTEL) {\n        sgs = 32;\n    } else {\n        GGML_ASSERT(false && \"Unsupported GPU\");\n    }\n\n    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;\n\n    int nth = sgs;\n    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);\n    while (nth < ne00 && nth < max_workgroup_size) {\n        nth *= 2;\n    }\n    nth = MIN(nth, max_workgroup_size);\n    nth = MIN(nth, ne00);\n\n    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),        &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),      &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),        &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),      &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),        &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong),      &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),           &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),           &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),           
&ne02));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),           &ne03));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),      &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),      &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),      &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),           &ne10));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),           &ne11));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),           &ne12));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),           &ne13));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),      &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),      &nb12));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),      &nb13));\n    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong),      &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong),      &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong),      &nb3));\n    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),         &eps));\n    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,     NULL));\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {\n    GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);\n\n    const ggml_tensor * src0 = norm_tensor->src[0];\n    const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];\n    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? 
add_tensor->src[1] : add_tensor->src[0];\n    const ggml_tensor * dst = add_tensor;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offset2 = extra2->offset + src2->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    float eps;\n    memcpy(&eps, norm_tensor->op_params, sizeof(float));\n\n    const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];\n    const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];\n    const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];\n    const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];\n    const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];\n    const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];\n    const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];\n\n    size_t sgs;\n    if (backend_ctx->gpu_family == ADRENO) sgs = 64;\n    else if (backend_ctx->gpu_family == INTEL) sgs = 32;\n    else GGML_ASSERT(false && \"Unsupported GPU\");\n\n    cl_kernel kernel = backend_ctx->kernel_norm_mul_add;\n\n    int nth = sgs;\n    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);\n    while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;\n    nth = MIN(nth, max_workgroup_size);\n    nth = MIN(nth, ne00/4);\n\n    size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t lws[] = {(size_t)nth, 1, 1};\n    size_t 
num_subgroups = (nth + sgs - 1) / sgs;\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));\n    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));\n    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));\n    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20));\n    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21));\n    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22));\n    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23));\n    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));\n    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));\n    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), 
&nb23));\n    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));\n    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));\n    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));\n    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps));\n    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);\n}\n\nstatic void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {\n    GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);\n\n    const ggml_tensor * src0 = gn_tensor->src[0];\n    const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];\n    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];\n    const ggml_tensor * dst = add_tensor;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offset2 = extra2->offset + src2->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    int groups;\n    float eps;\n    memcpy(&groups, gn_tensor->op_params, sizeof(int));\n    memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));\n\n    cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;\n    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);\n    int ne = ggml_nelements(src0);\n    int group_size = ne / groups;\n\n    size_t lws[] = { 
(size_t)MIN(max_workgroup_size, group_size) };\n    size_t gws[] = { (size_t)groups * lws[0] };\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne));\n    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps));\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);\n}\n\nstatic void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    int32_t n_groups   = ((const int32_t *) dst->op_params)[0];\n    int32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + n_groups - 1) / n_groups);\n    float   eps        = ((const float *) dst->op_params)[1];\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne = ne00*ne01*ne02;\n\n    cl_kernel kernel = 
backend_ctx->kernel_group_norm;\n\n    size_t sgs = 64;\n    if (backend_ctx->gpu_family == ADRENO) {\n        sgs = 64;\n    } else if (backend_ctx->gpu_family == INTEL) {\n        sgs = 32;\n    } else {\n        GGML_ASSERT(false && \"Unsupported GPU\");\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &group_size));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float),    &eps));\n\n    size_t global_work_size[] = {(size_t)n_groups*sgs, 1, 1};\n    size_t local_work_size[] = {(size_t)sgs, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_l2_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);\n\n    size_t sgs;\n    if (backend_ctx->gpu_family == ADRENO) {\n        sgs = 64;\n    } else if (backend_ctx->gpu_family == INTEL) {\n        sgs = 32;\n    } else {\n        GGML_ASSERT(false && \"Unsupported 
GPU\");\n    }\n\n    cl_kernel kernel = backend_ctx->kernel_l2_norm_f32;\n\n    int nth = sgs;\n    while (nth < ne00 && nth < (int)backend_ctx->get_kernel_workgroup_size(kernel)) {\n        nth *= 2;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),    &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),  &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),    &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),  &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),       &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),       &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),       &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),       &ne03));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),  &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),  &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),  &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float),     &eps));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs,  NULL));\n\n    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + 
dst->view_offs;\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const cl_ulong nb0  = dst->nb[0];\n    const cl_ulong nb1  = dst->nb[1];\n    const cl_ulong nb2  = dst->nb[2];\n    const cl_ulong nb3  = dst->nb[3];\n\n    cl_kernel kernel;\n\n    if (ggml_is_contiguous(src0)) {\n        // Handle contiguous input\n        int n = ggml_nelements(dst);\n        if (n % 4 == 0) {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_tanh_f32_4;\n            } else {\n                kernel = backend_ctx->kernel_tanh_f16_4;\n            }\n            n /= 4;\n        } else {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_tanh_f32;\n            } else {\n                kernel = backend_ctx->kernel_tanh_f16;\n            }\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n        size_t global_work_size[] = {(size_t)n, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        size_t * local_work_size_ptr = local_work_size;\n        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n            local_work_size_ptr = nullptr;\n        }\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n    } else {\n        // Handle non-contiguous input\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_tanh_f32_nc;\n        } else {\n            kernel = 
backend_ctx->kernel_tanh_f16_nc;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));\n        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));\n\n        int nth = 64;\n\n        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_neg(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, 
nb0, src0, nb);\n    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);\n\n    cl_kernel kernel;\n\n    if (ggml_is_contiguous(src0)) {\n        // Handle contiguous input\n        int n = ggml_nelements(dst);\n        if (n % 4 == 0) {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_neg_f32_4;\n            } else {\n                kernel = backend_ctx->kernel_neg_f16_4;\n            }\n            n /= 4;\n        } else {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_neg_f32;\n            } else {\n                kernel = backend_ctx->kernel_neg_f16;\n            }\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int),   &n));\n\n        size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    } else {\n        // Handle non-contiguous input\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_neg_f32_nc;\n        } else {\n            kernel = backend_ctx->kernel_neg_f16_nc;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n        
CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));\n        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));\n\n        int nth = 64;\n\n        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_exp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);\n    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);\n\n    cl_kernel kernel;\n\n    if (ggml_is_contiguous(src0)) {\n        // Handle contiguous input\n        int n = ggml_nelements(dst);\n        if (n % 4 == 0) {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_exp_f32_4;\n            } else {\n                kernel = 
backend_ctx->kernel_exp_f16_4;\n            }\n            n /= 4;\n        } else {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_exp_f32;\n            } else {\n                kernel = backend_ctx->kernel_exp_f16;\n            }\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int),   &n));\n\n        size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    } else {\n        // Handle non-contiguous input\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_exp_f32_nc;\n        } else {\n            kernel = backend_ctx->kernel_exp_f16_nc;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));\n        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));\n        
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));\n\n        int nth = 64;\n\n        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const cl_ulong nb0 = dst->nb[0];\n    const cl_ulong nb1 = dst->nb[1];\n    const cl_ulong nb2 = dst->nb[2];\n    const cl_ulong nb3 = dst->nb[3];\n\n    cl_kernel kernel;\n\n    if (ggml_is_contiguous(src0)) {\n        // Handle contiguous input\n        int n = ggml_nelements(dst);\n        if (n % 4 == 0) {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_expm1_f32_4;\n            } else {\n                kernel = backend_ctx->kernel_expm1_f16_4;\n            }\n            n /= 4;\n        } else {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = 
backend_ctx->kernel_expm1_f32;\n            } else {\n                kernel = backend_ctx->kernel_expm1_f16;\n            }\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n        size_t global_work_size[] = {(size_t)n, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        size_t * local_work_size_ptr = local_work_size;\n        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n            local_work_size_ptr = nullptr;\n        }\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n    } else {\n        // Handle non-contiguous input\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_expm1_f32_nc;\n        } else {\n            kernel = backend_ctx->kernel_expm1_f16_nc;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));\n        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));\n        
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));\n\n        int nth = 64;\n\n        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const cl_ulong nb0 = dst->nb[0];\n    const cl_ulong nb1 = dst->nb[1];\n    const cl_ulong nb2 = dst->nb[2];\n    const cl_ulong nb3 = dst->nb[3];\n\n    cl_kernel kernel;\n\n    if (ggml_is_contiguous(src0)) {\n        // Handle contiguous input\n        int n = ggml_nelements(dst);\n        if (n % 4 == 0) {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_softplus_f32_4;\n            } else {\n                kernel = backend_ctx->kernel_softplus_f16_4;\n            }\n            n /= 4;\n        } else {\n            if (src0->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_softplus_f32;\n            } else {\n                kernel = 
backend_ctx->kernel_softplus_f16;\n            }\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n\n        size_t global_work_size[] = {(size_t)n, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        size_t * local_work_size_ptr = local_work_size;\n        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n            local_work_size_ptr = nullptr;\n        }\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n    } else {\n        // Handle non-contiguous input\n        if (src0->type == GGML_TYPE_F32) {\n            kernel = backend_ctx->kernel_softplus_f32_nc;\n        } else {\n            kernel = backend_ctx->kernel_softplus_f16_nc;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));\n        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));\n\n        
int nth = 64;\n\n        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n        size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_ASSERT(dst->type == src0->type);\n\n    UNUSED(src1_shape_def);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad  = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd  = extrad->offset + dst->view_offs;\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int ne0 = dst->ne[0];\n    const int ne1 = dst->ne[1];\n    const int ne2 = dst->ne[2];\n    const int ne3 = dst->ne[3];\n\n    const cl_ulong nb0 = dst->nb[0];\n    const cl_ulong nb1 = dst->nb[1];\n    const cl_ulong nb2 = dst->nb[2];\n    const cl_ulong nb3 = dst->nb[3];\n\n    cl_kernel kernel = backend_ctx->kernel_repeat_f32;\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n    
CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb00));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb0));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));\n\n    int nth = 64;\n\n    size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    if (backend_ctx->kernel_pad == nullptr) {\n        GGML_LOG_WARN(\"%s: pad kernel not available, skipping OpenCL execution.\\n\", __func__);\n        return;\n    }\n\n    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra_dst  = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;\n    cl_ulong off_dst  = extra_dst->offset  + dst->view_offs;\n\n    const int s_ne0 = src0->ne[0];\n    const int s_ne1 = src0->ne[1];\n    const 
int s_ne2 = src0->ne[2];\n    const int s_ne3 = src0->ne[3];\n\n    const int s_nb0 = src0->nb[0];\n    const int s_nb1 = src0->nb[1];\n    const int s_nb2 = src0->nb[2];\n    const int s_nb3 = src0->nb[3];\n\n    const int d_ne0 = dst->ne[0];\n    const int d_ne1 = dst->ne[1];\n    const int d_ne2 = dst->ne[2];\n    const int d_ne3 = dst->ne[3];\n\n    const int d_nb0 = dst->nb[0];\n    const int d_nb1 = dst->nb[1];\n    const int d_nb2 = dst->nb[2];\n    const int d_nb3 = dst->nb[3];\n\n    const int lp0 = ((const int*)(dst->op_params))[0];\n    const int rp0 = ((const int*)(dst->op_params))[1];\n    const int lp1 = ((const int*)(dst->op_params))[2];\n    const int rp1 = ((const int*)(dst->op_params))[3];\n    const int lp2 = ((const int*)(dst->op_params))[4];\n    const int rp2 = ((const int*)(dst->op_params))[5];\n    const int lp3 = ((const int*)(dst->op_params))[6];\n    const int rp3 = ((const int*)(dst->op_params))[7];\n\n    cl_kernel kernel = backend_ctx->kernel_pad;\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),    &extra_src0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),  &off_src0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),    &extra_dst->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),  &off_dst));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),       &s_ne0));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),       &s_ne1));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),       &s_ne2));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),       &s_ne3));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),  &s_nb0));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),  &s_nb1));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),  &s_nb2));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),  &s_nb3));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),       &d_ne0));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),   
    &d_ne1));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),       &d_ne2));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),       &d_ne3));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),  &d_nb0));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),  &d_nb1));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),  &d_nb2));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),  &d_nb3));\n    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),       &lp0));\n    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),       &rp0));\n    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),       &lp1));\n    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),       &rp1));\n    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),       &lp2));\n    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),       &rp2));\n    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int),       &lp3));\n    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int),       &rp3));\n\n    size_t lws0 = 64;\n    size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;\n\n    size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 };\n    size_t local_work_size[]  = { lws0, 1, 1 };\n\n    size_t * local_work_size_ptr = local_work_size;\n     if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    const int mode_flags        = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);\n    const ggml_scale_mode mode  = 
(ggml_scale_mode) (mode_flags & 0xFF);\n    cl_kernel kernel = nullptr;\n\n    if (mode == GGML_SCALE_MODE_NEAREST) {\n        kernel = backend_ctx->kernel_upscale;\n        if (kernel == nullptr) {\n            GGML_LOG_WARN(\"%s: nearest upscale kernel not available, skipping OpenCL execution.\\n\", __func__);\n            return;\n        }\n    } else if (mode == GGML_SCALE_MODE_BILINEAR) {\n        kernel = backend_ctx->kernel_upscale_bilinear;\n        if (kernel == nullptr) {\n            GGML_LOG_WARN(\"%s: bilinear upscale kernel not available, skipping OpenCL execution.\\n\", __func__);\n            return;\n        }\n    } else {\n        GGML_LOG_WARN(\"%s: unsupported upscale mode %d, skipping OpenCL execution.\\n\", __func__, mode);\n        return;\n    }\n\n    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra_dst  = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;\n    cl_ulong off_dst  = extra_dst->offset  + dst->view_offs;\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const int ne0 = dst->ne[0];\n    const int ne1 = dst->ne[1];\n    const int ne2 = dst->ne[2];\n    const int ne3 = dst->ne[3];\n\n    float sf0 = (float)ne0 / ne00;\n    float sf1 = (float)ne1 / ne01;\n    float sf2 = (float)ne2 / ne02;\n    float sf3 = (float)ne3 / ne03;\n\n    float pixel_offset = 0.5f;\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra_src0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &off_src0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extra_dst->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  
&off_dst));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong),  &nb00));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong),  &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong),  &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong),  &nb03));\n\n    if (mode == GGML_SCALE_MODE_NEAREST) {\n        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int),       &ne0));\n        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int),       &ne1));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne2));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne3));\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float),    &sf0));\n        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float),    &sf1));\n        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float),    &sf2));\n        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float),    &sf3));\n    } else if (mode == GGML_SCALE_MODE_BILINEAR) {\n        if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {\n            sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;\n            sf1 = ne1 > 1 && ne01 > 1 ? 
(float)(ne1 - 1) / (ne01 - 1) : sf1;\n            pixel_offset = 0.0f;\n        }\n\n        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int),       &ne00));\n        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int),       &ne01));\n        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne0));\n        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne1));\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne2));\n        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne3));\n        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float),    &sf0));\n        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float),    &sf1));\n        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float),    &sf2));\n        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float),    &sf3));\n        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float),    &pixel_offset));\n    }\n\n\n    size_t dst_total_elements = (size_t)ne0 * ne1 * ne2 * ne3;\n    if (dst_total_elements == 0) {\n        return;\n    }\n    size_t global_work_size[] = { dst_total_elements, 1, 1 };\n    size_t local_work_size_pref = 256;\n    size_t local_work_size[] = { MIN(local_work_size_pref, dst_total_elements), 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (dst_total_elements % local_work_size[0] != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    ggml_backend_opencl_context *backend_ctx = 
(ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd  = extrad->offset + dst->view_offs;\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const cl_ulong nb10 = src1->nb[0];\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n    const cl_ulong nb13 = src1->nb[3];\n\n    const int ne0 = dst->ne[0];\n    const int ne1 = dst->ne[1];\n    const int ne2 = dst->ne[2];\n    const int ne3 = dst->ne[3];\n\n    const cl_ulong nb0 = dst->nb[0];\n    const cl_ulong nb1 = dst->nb[1];\n    const cl_ulong nb2 = dst->nb[2];\n    const cl_ulong nb3 = dst->nb[3];\n\n    const cl_int dim = ((const int32_t *) dst->op_params)[0];\n    GGML_ASSERT(dim >= 0 && dim <= 3);\n\n    int nth = MIN(64, ne0);\n\n    cl_kernel kernel = backend_ctx->kernel_concat_f32;\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n   
 CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne0));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0));\n    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));\n    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int),   &dim));\n\n    size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    if (backend_ctx->kernel_timestep_embedding == nullptr) {\n        GGML_LOG_WARN(\"%s: timestep_embedding kernel not available, skipping OpenCL execution.\\n\", __func__);\n        return;\n    }\n\n    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl 
*)src0->extra;\n    ggml_tensor_extra_cl * extra_dst  = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;\n    cl_ulong off_dst  = extra_dst->offset  + dst->view_offs;\n\n    const int logical_dim = dst->op_params[0];\n    const int max_period  = dst->op_params[1];\n    const int dst_nb1_bytes = dst->nb[1];\n\n    cl_kernel kernel = backend_ctx->kernel_timestep_embedding;\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra_src0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &off_src0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extra_dst->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &off_dst));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),       &dst_nb1_bytes));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),       &logical_dim));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),       &max_period));\n\n    size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1);\n\n    size_t gws1 = (size_t)src0->ne[0];\n\n    size_t global_work_size[] = {gws0, gws1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);\n}\n\nstatic void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {\n    const ggml_tensor * v = dst->src[2];\n    const ggml_tensor * mask = dst->src[3];\n    const ggml_tensor * sinks = dst->src[4];\n    GGML_ASSERT(q->extra);\n    GGML_ASSERT(k->extra);\n    GGML_ASSERT(v->extra);\n    GGML_ASSERT(dst->extra);\n    if (mask) {\n        GGML_ASSERT(mask->extra);\n    }\n    if (sinks) {\n        GGML_ASSERT(sinks->extra);\n    }\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    const int n_q = q->ne[1];\n    const int n_kv = k->ne[1];\n    const int d_head_q = q->ne[0];\n    const int d_head_v = v->ne[0];\n    const int n_head = q->ne[2];\n    const int n_head_kv = 
k->ne[2];\n    const int n_batch = q->ne[3];\n\n    cl_kernel kernel = NULL;\n\n    const bool is_f16 = q->type == GGML_TYPE_F16;\n    const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;\n    const std::pair<int, int> dk_dv = {d_head_q, d_head_v};\n\n    if (n_q == 1) {\n        if (is_mixed) {\n            kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);\n        } else if (is_f16) {\n            kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);\n        } else {\n            kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);\n        }\n    } else {\n        if (is_mixed) {\n            kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);\n        } else if (is_f16) {\n            kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);\n        } else {\n            kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);\n        }\n    }\n    GGML_ASSERT(kernel != NULL);\n\n    ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;\n    ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;\n    ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;\n    ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;\n    ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;\n    ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;\n\n    cl_ulong offset_q = extra_q->offset + q->view_offs;\n    cl_ulong offset_k = extra_k->offset + k->view_offs;\n    cl_ulong offset_v = extra_v->offset + v->view_offs;\n    cl_ulong offset_o = extra_o->offset + dst->view_offs;\n    cl_mem   mask_buffer = extra_mask ? extra_mask->data_device : NULL;\n    cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;\n    cl_mem   sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;\n    cl_ulong offset_sinks = extra_sinks ? 
extra_sinks->offset + sinks->view_offs : 0;\n\n    const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];\n    const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];\n    const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];\n    const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];\n    const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;\n    const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;\n    const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;\n    const int mask_ne2 = mask ? mask->ne[2] : 0;\n    const int mask_ne3 = mask ? mask->ne[3] : 0;\n\n    float scale, max_bias, logit_softcap;\n    const float * params = (const float *)dst->op_params;\n    scale         = params[0];\n    max_bias      = params[1];\n    logit_softcap = params[2];\n\n    const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);\n\n    const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;\n    const float n_head_log2_f = n_head_log2_val > 0 ? 
(float)n_head_log2_val : 1.0f;\n    const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra_q->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra_k->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extra_v->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem),   &extra_o->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));\n    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float),    &scale));\n    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int),      &n_q));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),     &n_kv));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),     &is_causal));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),     &n_head));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));\n    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));\n    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float),    &max_bias));\n    
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float),    &m0));\n    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float),    &m1));\n    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int),      &n_head_log2_val));\n    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float),    &logit_softcap));\n    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int),      &n_head_kv));\n    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem),   &mask_buffer));\n    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));\n    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));\n    CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));\n    CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));\n    CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int),      &mask_ne2));\n    CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int),      &mask_ne3));\n    CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem),   &sinks_buffer));\n    CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));\n\n    if (n_q == 1) {\n        const size_t wg_size = 64;\n        size_t local_work_size[] = { wg_size, 1 };\n        size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };\n        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);\n    } else {\n        const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);\n        const size_t wg_size = block_m;\n        size_t local_work_size[] = { wg_size, 1 };\n        size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };\n        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = 
(ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int M = src0->ne[1];\n    const int N = src1->ne[1];\n    const int K = src0->ne[0];\n\n    cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled;\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int),      &M));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int),      &N));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),      &K));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));\n\n    // Tiling parameters. These need to be tuned for optimal performance.\n    // They must match the #defines in the kernel mul_mat_f16_f32.cl.\n    //\n    // OPWM / OPWN: Output tile size per Work-Group. A work-group computes a tile of size OPWM x OPWN.\n    // TPWM / TPWN: Threads per Work-group. This is the work-group size.\n    // OPTM / OPTN: Output elements per Thread. 
Each thread computes OPTM x OPTN elements.\n    //\n    // The following relationships must hold:\n    //   OPWM = TPWM * OPTM\n    //   OPWN = TPWN * OPTN\n    //\n    const int OPWM = 64;\n    const int OPWN = 64;\n    const int TPWM = 16;\n    const int TPWN = 8;\n\n    size_t local_work_size[2] = { TPWM, TPWN };\n    size_t global_work_size[2] = {\n        (size_t) ((M + OPWM - 1) / OPWM) * TPWM,\n        (size_t) ((N + OPWN - 1) / OPWN) * TPWN,\n    };\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_TENSOR_BINARY_OP_LOCALS;\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;\n    const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;\n\n    const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];\n    const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];\n    const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];\n\n    const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);\n    const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); 
const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);\n    const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);\n\n    const int64_t NPQ = (int64_t)N * OW * OH;\n\n    const uint32_t BS_K = 64;\n    const uint32_t BS_NPQ = 64;\n    const uint32_t BS_CRS = 16;\n    const uint32_t VEC_SIZE = 4;\n\n    const uint32_t TS_K = 4;\n    const uint32_t TS_NPQ = 8;\n\n    const uint32_t WG_K = BS_K / TS_K;\n    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;\n\n    auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };\n    const uint32_t NB_K = splitWork(Cout, BS_K);\n    const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);\n\n    cl_kernel kernel;\n    size_t shmem_size;\n\n    if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {\n        kernel = backend_ctx->kernel_conv_2d_f16;\n        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {\n        kernel = backend_ctx->kernel_conv_2d_f32;\n        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {\n        kernel = backend_ctx->kernel_conv_2d_f16_f32;\n        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));\n    } else {\n        GGML_ASSERT(false && \"Unsupported data type combination for conv2d\");\n    }\n\n    cl_uint idx = 0;\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), 
&offset1));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));\n    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));\n\n    size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };\n    size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };\n\n    
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    const int  ne00 = src0->ne[0];\n    const int  ne01 = src0->ne[1];\n    const int  ne02 = src0->ne[2];\n\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n\n    const int  ne10 = src1->ne[0];\n    const int  ne11 = src1->ne[1];\n    const int  ne12 = src1->ne[2];\n\n    const cl_ulong nb10 = src1->nb[0];\n\n    const int  ne0 = dst->ne[0];\n    const int  ne1 = dst->ne[1];\n\n    GGML_ASSERT(ne00 == ne10);\n\n    cl_kernel kernel;\n    cl_context context = backend_ctx->context;\n\n    cl_int              status;\n    cl_image_format     img_fmt_1d;\n    cl_image_desc       img_desc_1d;\n    cl_buffer_region    region;\n    cl_mem              A_image1d;\n    cl_mem              A_sub_buffer;\n    cl_mem              B_sub_buffer;\n    cl_mem              D_image1d;\n    cl_mem              D_sub_buffer;\n\n    int M = ne01;\n    int N = ne1;\n    int K = ne00;\n\n    if (nb01 > nb02) {\n        // KQ\n        kernel = backend_ctx->kernel_mul_mm_f16_f32_kq;\n    } else {\n        // KQV\n        kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv;\n    }\n    // create sub-buffer for A\n    // <--------------------------------------------> //\n    extra0 = src0->view_src ? 
(ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra;\n\n    region.origin = (extra0->offset);\n    if (nb01 > nb02) {\n        // KQ\n        region.size = nb01 * ne01;\n    } else {\n        // KQV\n        region.size = nb02 * ne02;\n    }\n\n    A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);\n    CL_CHECK(status);\n\n    // <--------------------------------------------> //\n\n    // create sub-buffer for B\n    // <--------------------------------------------> //\n    region.origin = (extra1->offset);\n    region.size = nb10 * ne10 * ne11 * ne12;\n    B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);\n    CL_CHECK(status);\n    // <--------------------------------------------> //\n\n    img_fmt_1d = {CL_RGBA, CL_FLOAT};\n    memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n    if (nb01 > nb02) {\n        img_desc_1d.image_width = (nb01 * ne01 / 4)/4;\n    }\n    else {\n        img_desc_1d.image_width = (nb02 * ne02 / 4)/4;\n    }\n    img_desc_1d.buffer = A_sub_buffer;\n    A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);\n    CL_CHECK(status);\n\n    // create sub-buffer for output C\n    // <--------------------------------------------> //\n    region.origin = (extrad->offset);\n    region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes\n    D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);\n    CL_CHECK(status);\n    // <--------------------------------------------> //\n\n    // create image for C output\n    // <--------------------------------------------> //\n    img_fmt_1d = {CL_R, CL_FLOAT};\n    memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n    
img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4;\n    img_desc_1d.buffer = D_sub_buffer;\n    D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);\n    CL_CHECK(status);\n    // <--------------------------------------------> //\n\n    int offset_src0 = 0;\n    int offset_src1 = 0;\n\n    // set kernel args\n    // <--------------------------------------------> //\n    cl_uint k_arg = 0;\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem), &A_image1d));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &offset_src0));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem), &B_sub_buffer));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &offset_src1));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem), &D_image1d));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &extrad->offset));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &M));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &K));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &N));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &ne12));\n    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &nb01));\n\n    size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)};\n    size_t local_work_size[3] = {64, 1, 2};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n\n    // deallocate sub buffers and images\n    // <--------------------------------------------> //\n    CL_CHECK(clReleaseMemObject(A_image1d));\n    CL_CHECK(clReleaseMemObject(D_image1d));\n    CL_CHECK(clReleaseMemObject(A_sub_buffer));\n    CL_CHECK(clReleaseMemObject(B_sub_buffer));\n    CL_CHECK(clReleaseMemObject(D_sub_buffer));\n}\n\nstatic void 
ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n    ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;\n\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int  ne00 = src0->ne[0];\n    const int  ne01 = src0->ne[1];\n\n    const int  ne1 = dst->ne[1];\n\n    GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);\n\n    cl_context context = backend_ctx->context;\n    cl_kernel kernel;\n\n    cl_int              err;\n    cl_image_format     img_fmt;\n    cl_image_desc       img_desc;\n    cl_buffer_region    region;\n\n    int M = ne01;\n    int N = ne1;\n    int K = ne00;\n\n    if (ne1 == 1) {\n        cl_mem q_img = nullptr;\n        cl_mem b_sub_buf = nullptr;\n        cl_mem b_img = nullptr;\n\n        // image for q\n        img_fmt = { CL_R, CL_UNSIGNED_INT32};\n        memset(&img_desc, 0, sizeof(img_desc));\n        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n        img_desc.image_width = M * K / 2 / 4;\n        img_desc.buffer = extra0_q4_1->q;\n        CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));\n\n        // subbuffer for activations\n        region.origin = offset1;\n        region.size = K * N * sizeof(float);\n        CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));\n\n        // image for 
activations\n        img_fmt = {CL_RGBA, CL_FLOAT};\n        memset(&img_desc, 0, sizeof(img_desc));\n        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n        img_desc.image_width = K * N / 4;\n        img_desc.buffer = b_sub_buf;\n        CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));\n\n        kernel = backend_ctx->kernel_gemv_noshuffle_q4_1_f32;\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &q_img));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &extra0_q4_1->d));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra0_q4_1->m));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &b_img));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int),   &ne00));\n        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int),   &ne01));\n\n        size_t local_work_size[3] = {64, 4, 1};\n        size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n\n        CL_CHECK(clReleaseMemObject(q_img));\n        CL_CHECK(clReleaseMemObject(b_sub_buf));\n        CL_CHECK(clReleaseMemObject(b_img));\n    } else {\n        cl_mem b_sub_buf = nullptr;\n        cl_mem b_sub_buf_trans = nullptr;\n        cl_mem b_img = nullptr;\n        cl_mem b_img_trans = nullptr;\n\n        // subbuffer for activations\n        region.origin = offset1;\n        region.size = K * N * sizeof(float);\n        CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));\n\n        // image for activations\n        img_fmt = {CL_RGBA, CL_FLOAT};\n        memset(&img_desc, 0, sizeof(img_desc));\n        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n 
       img_desc.image_width = K * N / 4;\n        img_desc.buffer = b_sub_buf;\n        CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));\n\n        // pad N to multiple of 8\n        int extra_elements = N % 8;\n        int padding = 0;\n        if (extra_elements > 0){\n            padding = 8 - extra_elements;\n        }\n\n        // subbuffer for transposed activations\n        region.origin = 0;\n        region.size = K * (N + padding) * sizeof(float)/2;\n        backend_ctx->prealloc_act_trans.allocate(context, region.size);\n        CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));\n\n        // image for transposed activations\n        img_fmt = {CL_RGBA, CL_HALF_FLOAT};\n        memset(&img_desc, 0, sizeof(img_desc));\n        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n        img_desc.image_width = K * (N + padding) / 4;\n        img_desc.buffer = b_sub_buf_trans;\n        CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err));\n\n        // transpose activations\n        int height_B = N/4;\n        if (height_B == 0) {\n            height_B = 1;\n        }\n        int width_B = K/4;\n        int padded_height_B = (N + padding)/4;\n\n        kernel = backend_ctx->kernel_transpose_32_16;\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_B));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_B));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &padded_height_B));\n\n        size_t local_work_size_t[2] = { 1, 16 };\n        size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B };\n        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, 
local_work_size_t, dst);\n\n        // gemm\n        kernel = backend_ctx->kernel_gemm_noshuffle_q4_1_f32;\n        int padded_N = N + padding;\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0_q4_1->q));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &extra0_q4_1->d));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra0_q4_1->m));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &b_img_trans));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int),   &ne01));\n        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int),   &padded_N));\n        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int),   &ne00));\n        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int),   &ne1));\n\n        size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1};\n        size_t local_work_size[3] = {1, 128, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n\n        CL_CHECK(clReleaseMemObject(b_sub_buf));\n        CL_CHECK(clReleaseMemObject(b_sub_buf_trans));\n        CL_CHECK(clReleaseMemObject(b_img));\n        CL_CHECK(clReleaseMemObject(b_img_trans));\n    }\n#else\n    GGML_UNUSED(backend);\n    GGML_UNUSED(src0);\n    GGML_UNUSED(src1);\n    GGML_UNUSED(dst);\n#endif\n}\n\nstatic void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    const enum ggml_type src0t = src0->type;\n    const enum ggml_type src1t = src1->type;\n\n    GGML_ASSERT(src0t == GGML_TYPE_Q8_0);\n    GGML_ASSERT(src1t 
== GGML_TYPE_F32);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;\n\n    GGML_ASSERT(src1->view_offs == 0);\n    GGML_ASSERT(dst->view_offs == 0);\n\n    const int  ne00 = src0->ne[0];\n    const int  ne01 = src0->ne[1];\n    const int  ne02 = src0->ne[2];\n\n    const int  ne10 = src1->ne[0];\n    const int  ne12 = src1->ne[2];\n\n    const int  ne0 = dst->ne[0];\n    const int  ne1 = dst->ne[1];\n\n    GGML_ASSERT(ne00 == ne10);\n    GGML_ASSERT((ne00 % 32) == 0);\n    GGML_ASSERT(ne0 == ne01);\n\n    cl_context context = backend_ctx->context;\n    cl_kernel kernel;\n\n    // init CL objects\n    cl_int              status;\n    cl_image_format     img_fmt_1d;\n    cl_image_desc       img_desc_1d;\n    cl_buffer_region    region;\n    cl_mem              A_image1d;\n    cl_mem              B_image1d;\n    cl_mem              B_sub_buffer;\n    cl_mem              S_image1d;\n\n    cl_mem              D_image1d;\n    cl_mem              D_sub_buffer;\n\n    int M = ne01;\n    int N = ne1;\n    int K = ne00;\n\n    // create an image for A\n    img_fmt_1d = { CL_R, CL_FLOAT};\n    memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n    img_desc_1d.image_width = M * K / 4;    // Divide by 4 for char -> float\n    img_desc_1d.buffer = extra0_q8_0->q;\n    A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);\n    CL_CHECK(status);\n\n    // create an image for Scale\n    img_fmt_1d = { CL_R, CL_HALF_FLOAT};\n    memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n    img_desc_1d.image_width = M * K / 32;    // Block size is 32\n    
img_desc_1d.buffer = extra0_q8_0->d;\n    S_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);\n    CL_CHECK(status);\n\n    // create a sub_buffer for B\n    region.origin = (extra1->offset); // + src1->view_offs);\n    region.size = K * N * sizeof(float);\n    B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);\n    CL_CHECK(status);\n\n    // create an image for B from sub_buffer: RGBA (OCL)\n    img_fmt_1d = {CL_RGBA, CL_FLOAT};\n    memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n    img_desc_1d.image_width = K * N / 4;\n    img_desc_1d.buffer = B_sub_buffer;\n    B_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);\n    CL_CHECK(status);\n\n    // Create subbuffer and image1d_buffer for dst\n    region.origin = (extrad->offset); // + dst->view_offs;\n    region.size = M * N * sizeof(float);\n    D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);\n    CL_CHECK(status);\n\n    img_fmt_1d = {CL_R, CL_FLOAT};\n    memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n    img_desc_1d.image_width = M * N;\n    img_desc_1d.buffer = D_sub_buffer;\n    D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);\n    CL_CHECK(status);\n\n    size_t local_work_size[3] = {1, 1, 1};\n    size_t global_work_size[3] = {1, 1, 1};\n\n    if (N == 1) {\n        kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32;\n\n        int r2 = 1;\n        int r3 = 1;\n        cl_uint k_arg = 0;\n\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &A_image1d));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extra0_q8_0->d));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &B_image1d));\n      
  CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extra1->offset));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extrad->offset));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne01));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne02));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne10));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne12));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne0));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne1));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r2));\n        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r3));\n\n        size_t wavesize = backend_ctx->adreno_wave_size;\n        local_work_size[0] = wavesize;\n        local_work_size[1] = 4; // reduce factor\n        local_work_size[2] = 1;\n\n        global_work_size[0] = ((M + wavesize - 1) / wavesize) * wavesize;\n        global_work_size[1] = 4; // reduce factor\n        global_work_size[2] = 1;\n    } else {\n        cl_ulong offsetd = extrad->offset + dst->view_offs;\n        cl_mem              B_image1d_trans = nullptr;\n        // for B transpose\n        cl_mem B_d = nullptr;\n        int padding;\n\n        //how many extra elements beyond multiple of 8\n        int extra_elements = N % 8;\n\n        //how much padding to add\n        padding = 0;\n        if (extra_elements > 0){\n            padding = 8 - extra_elements;\n        }\n\n        // Specify the starting offset (in bytes)\n        region.origin = 0;\n        // Specify the size of the sub-buffer (divide by 2 for FP16)\n        region.size = K * (N + padding) * sizeof(float)/2;\n        
backend_ctx->prealloc_act_trans.allocate(context, region.size);\n        B_d = clCreateSubBuffer(\n            backend_ctx->prealloc_act_trans.buffer,\n            0,\n            CL_BUFFER_CREATE_TYPE_REGION,\n            &region,\n            &status);\n        CL_CHECK(status);\n\n        cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16)\n        cl_image_desc image_desc_B_d_output = {\n            CL_MEM_OBJECT_IMAGE1D_BUFFER,\n            static_cast<size_t>(K * (N + padding)/4),\n            0, 0, 0, 0, 0, 0, 0, { B_d }\n        };\n        B_image1d_trans = clCreateImage(\n            context,\n            0,\n            &image_format_B_d_output,\n            &image_desc_B_d_output,\n            NULL,\n            &status);\n        CL_CHECK(status);\n\n        int height_B = N/4;\n        if (height_B == 0) {\n            height_B = 1;\n        }\n        int width_B = K/4;\n        int padded_height_B = (N + padding)/4;\n\n        kernel = backend_ctx->kernel_transpose_32_16;\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_image1d));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d_trans));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_B));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_B));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &padded_height_B));\n\n        size_t local_size_t[2] = { 1, 16 };\n        size_t global_size_t[2] = {\n            static_cast<size_t>(width_B),\n            static_cast<size_t>(padded_height_B)\n        };\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst);\n\n        kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4;\n\n        int N_with_padding = N + padding;\n\n        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q8_0->q));\n        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q8_0->d));\n        
CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &B_image1d_trans));\n        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &K));\n        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &M));\n        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &N_with_padding));\n        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &N));\n        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &offsetd));\n\n        global_work_size[0] = (size_t)(N + 7) / 8;\n        global_work_size[1] = (size_t)(M + 3) / 4;\n        global_work_size[2] = 1;\n\n        local_work_size[0] = 2;\n        local_work_size[1] = 128;\n        local_work_size[2] = 1;\n    }\n\n    // enqueue kernel with profiling\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n\n    // deallocate sub buffers and images\n    CL_CHECK(clReleaseMemObject(A_image1d));\n    CL_CHECK(clReleaseMemObject(B_sub_buffer));\n    CL_CHECK(clReleaseMemObject(B_image1d));\n    CL_CHECK(clReleaseMemObject(S_image1d));\n    CL_CHECK(clReleaseMemObject(D_sub_buffer));\n    CL_CHECK(clReleaseMemObject(D_image1d));\n#else\n    GGML_UNUSED(backend);\n    GGML_UNUSED(src0);\n    GGML_UNUSED(src1);\n    GGML_UNUSED(dst);\n#endif\n}\n\nstatic void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;\n    const enum ggml_type src1t = src1 ? 
src1->type : GGML_TYPE_COUNT;\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n#ifdef GGML_OPENCL_SOA_Q\n    ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;\n    ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;\n    ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;\n    ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;\n    ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;\n#endif\n\n    const int  ne00 = src0 ? src0->ne[0] : 0;\n    const int  ne01 = src0 ? src0->ne[1] : 0;\n    const int  ne02 = src0 ? src0->ne[2] : 0;\n    const int  ne03 = src0 ? src0->ne[3] : 0;\n\n    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;\n    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;\n    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;\n    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;\n\n    const int  ne10 = src1 ? src1->ne[0] : 0;\n    const int  ne11 = src1 ? src1->ne[1] : 0;\n    const int  ne12 = src1 ? src1->ne[2] : 0;\n    const int  ne13 = src1 ? src1->ne[3] : 0;\n\n    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;\n    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;\n    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;\n    const cl_ulong nb13 = src1 ? src1->nb[3] : 0;\n\n    const int  ne0 = dst ? dst->ne[0] : 0;\n    const int  ne1 = dst ? 
dst->ne[1] : 0;\n\n    int r2 = ne12/ne02;\n    int r3 = ne13/ne03;\n\n    GGML_ASSERT(ne00 == ne10);\n\n    int nth0 = 32;\n    int nth1 = 1;\n    int nrows = 1;\n    // The number of values produced by each subgroup\n    int ndst = 4;\n\n    cl_kernel kernel;\n\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n    cl_context context = backend_ctx->context;\n\n    if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){\n        if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0  &&\n            // dst is wrapped with image1d_buffer, the size limit applies, also src0\n            (ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) {\n            // For KQ\n            if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&\n                ((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) &&\n                nb00 <= nb02 &&\n                nb02 <= nb01 &&\n                nb01 <= nb03 &&\n                nb10 <= nb12 &&\n                nb12 <= nb11 &&\n                nb11 <= nb13) {\n                ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);\n                return;\n            }\n            // For KQV\n            if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&\n                ((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) {\n                ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);\n                return;\n            }\n        }\n    }\n\n    if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) {\n\n    // init CL objects\n    // <--------------------------------------------> //\n    cl_int              status;\n    cl_image_format     img_fmt_1d;\n    cl_image_desc       img_desc_1d;\n    cl_buffer_region    region;\n    cl_mem              A_image1d = nullptr;\n    cl_mem              B_image1d = nullptr;\n    cl_mem              B_sub_buffer = nullptr;\n    cl_mem              C_d = nullptr;\n    // for B transpose\n    cl_mem B_d = nullptr;\n    
cl_mem B_d_input_image = nullptr;\n    // <--------------------------------------------> //\n\n    // define matrix dimensions\n    // <--------------------------------------------> //\n    int M = ne01;\n    int N = ne1;\n    int K = ne00;\n    int padding;\n    // <--------------------------------------------> //\n\n    // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require\n    // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that\n    // limit, so the check is omitted.\n\n    // q4_1 x fp32\n    if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) {\n            ggml_cl_mul_mat_q4_1_f32_adreno(backend, src0, src1, dst);\n            return;\n    }\n\n    // q8_0 x fp32\n    if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 &&\n        enable_adreno_trans_weight(backend_ctx, src0)) {\n            ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst);\n            return;\n    }\n\n    // q4_0 x fp32\n    if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) {\n        // TODO: remove duplicate definitions of image description + format -- move to top\n\n        // create an image for A\n        // <--------------------------------------------> //\n        if (N == 1) {\n            img_fmt_1d = { CL_R, CL_UNSIGNED_INT32};\n        } else {\n            img_fmt_1d = { CL_R, CL_FLOAT};\n        }\n        memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n        img_desc_1d.image_width = M * K / 2 / 4;    // Divide by 4 for char -> float\n        img_desc_1d.buffer = extra0_q4_0->q;\n        A_image1d = clCreateImage(\n            context,\n            CL_MEM_READ_ONLY,\n            &img_fmt_1d,\n            &img_desc_1d,\n            NULL,\n            &status);\n        CL_CHECK(status);\n        // <--------------------------------------------> //\n\n\n        // create a sub_buffer for B\n        // <--------------------------------------------> 
//\n        region.origin = (extra1->offset);\n        region.size = K * N * sizeof(float);\n        B_sub_buffer = clCreateSubBuffer(\n            extra1->data_device,\n            0,\n            CL_BUFFER_CREATE_TYPE_REGION,\n            &region,\n            &status);\n        CL_CHECK(status);\n        // <--------------------------------------------> //\n\n        // transpose activation for Skyler's gemm\n        if (N != 1) {\n            //how many extra elements beyond multiple of 8\n            int extra_elements = N % 8;\n\n            //how much padding to add\n            padding = 0;\n            if (extra_elements > 0){\n                padding = 8 - extra_elements;\n            }\n\n            // Specify the starting offset (in bytes)\n            region.origin = 0;\n            // Specify the size of the sub-buffer (divide by 2 for FP16)\n            region.size = K * (N + padding) * sizeof(float)/2;\n            backend_ctx->prealloc_act_trans.allocate(context, region.size);\n\n            B_d = clCreateSubBuffer(\n                backend_ctx->prealloc_act_trans.buffer,\n                0,\n                CL_BUFFER_CREATE_TYPE_REGION,\n                &region,\n                &status);\n            CL_CHECK(status);\n\n            cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT };\n            cl_image_desc image_desc_B_d_input = {\n                CL_MEM_OBJECT_IMAGE1D_BUFFER,\n                static_cast<size_t>(K * N / 4),\n                0, 0, 0, 0, 0, 0, 0, { B_sub_buffer }\n            };\n            B_d_input_image = clCreateImage(\n                context,\n                0,\n                &image_format_B_d_input,\n                &image_desc_B_d_input,\n                NULL,\n                &status);\n            CL_CHECK(status);\n\n            cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16)\n            cl_image_desc image_desc_B_d_output = {\n                
CL_MEM_OBJECT_IMAGE1D_BUFFER,\n                static_cast<size_t>(K * (N + padding)/4),\n                0, 0, 0, 0, 0, 0, 0, { B_d }\n            };\n            B_image1d = clCreateImage(\n                context,\n                0,\n                &image_format_B_d_output,\n                &image_desc_B_d_output,\n                NULL,\n                &status);\n            CL_CHECK(status);\n\n            int height_B = N/4;\n            if (height_B == 0) {\n                height_B = 1;\n            }\n            int width_B = K/4;\n            int padded_height_B = (N + padding)/4;\n\n            kernel = backend_ctx->kernel_transpose_32_16;\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image));\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d));\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_B));\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_B));\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &padded_height_B));\n\n            size_t local_size_t[2] = { 1, 16 };\n            //WGS tuning\n            if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) {\n                local_size_t[0]=4;\n                local_size_t[1]=8;\n            } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) {\n                local_size_t[0]=2;\n                local_size_t[1]=8;\n            } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) {\n                local_size_t[0]=1;\n                local_size_t[1]=8;\n            } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) {\n                local_size_t[0]=2;\n                local_size_t[1]=8;\n            }\n\n            size_t global_size_t[2] = {\n                static_cast<size_t>(width_B),\n                static_cast<size_t>(padded_height_B)\n            };\n\n            backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst);\n        } else {\n            
// no need to transpose B in other cases\n            // create an image for B from sub_buffer\n            // <--------------------------------------------> //\n            img_fmt_1d = {CL_RGBA, CL_FLOAT};\n\n            memset(&img_desc_1d, 0, sizeof(img_desc_1d));\n            img_desc_1d.image_width = K * N / 4;\n            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;\n            img_desc_1d.buffer = B_sub_buffer;\n            B_image1d = clCreateImage(\n                context,\n                CL_MEM_READ_ONLY,\n                &img_fmt_1d,\n                &img_desc_1d,\n                NULL,\n                &status);\n            CL_CHECK(status);\n            // <--------------------------------------------> //\n        }\n\n        // choose gemm or gemv kernel\n        // <--------------------------------------------> //\n        if (N == 1) {\n            kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general;\n            if (M == 4096 && K == 4096) {\n                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;\n            } else if (M == 4096 && K == 11008) {\n                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008;\n            } else if (M == 11008 && K == 4096) {\n                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;\n            } else if (M == 32000 && K == 4096) {\n                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;\n            }\n        } else {\n            kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4;\n        }\n        // <--------------------------------------------> //\n\n        // set kernel args\n        // <--------------------------------------------> //\n        cl_uint k_arg = 0;\n\n        if (N == 1) {\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &A_image1d));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extra0_q4_0->d));\n     
       CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &B_image1d));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extra1->offset));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extrad->offset));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r3));\n        } else {\n            region.origin = extrad->offset; // Specify the starting offset (in bytes)\n            region.size = M * N * sizeof(float); // Specify the size of the sub-buffer\n            C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);\n            CL_CHECK(status);\n\n            int padded_N = ne1 + padding;\n\n            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_dextra0_q4_0->q\n            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d\n            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d\n            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d\n            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &ne01)); //M\n            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),    
&padded_N)); //N with padding\n            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),    &ne00)); //K\n            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),    &ne1)); //N without padding\n        }\n        // <--------------------------------------------> //\n\n        // choose workgroup size\n        // <--------------------------------------------> //\n        size_t global_work_size[3] = {\n            64, static_cast<size_t>((M+63)/64), static_cast<size_t>((N+31)/32)};\n        size_t local_work_size[3] = {64, 2, 4};\n\n        global_work_size[0] = (size_t)(ceil((float)ne1/8));\n        global_work_size[1] = (size_t)(ne01/4);\n        global_work_size[2] = (size_t)(1);\n\n        local_work_size[0]  = (size_t)(1); //4x32 for FP32\n        local_work_size[1]  = (size_t)(128);\n        local_work_size[2]  = (size_t)(1);\n\n        //WGS tuning\n        if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) {\n            local_work_size[0] = 1;\n            local_work_size[1] = 128;\n        } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) {\n            local_work_size[0] = 2;\n            local_work_size[1] = 64;\n        } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) {\n            local_work_size[0] = 2;\n            local_work_size[1] = 64;\n        } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) {\n            local_work_size[0] = 2;\n            local_work_size[1] = 64;\n        }\n\n        if (N == 1) {\n            size_t wavesize = backend_ctx->adreno_wave_size;\n            local_work_size[0] = wavesize; // localsize\n            local_work_size[1] = 4; // reduce factor\n            local_work_size[2] = 1;\n\n            global_work_size[0] = (((M / 2) + wavesize - 1) / wavesize) * wavesize;\n            global_work_size[1] = 4; // reduce factor\n            global_work_size[2] = 1;\n        }\n        // <--------------------------------------------> //\n\n        // enqueue kernel with profiling\n        // 
<--------------------------------------------> //\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n        // <--------------------------------------------> //\n\n        // deallocate sub buffers and images\n        // <--------------------------------------------> //\n        CL_CHECK(clReleaseMemObject(A_image1d));\n        CL_CHECK(clReleaseMemObject(B_sub_buffer));\n        CL_CHECK(clReleaseMemObject(B_image1d));\n\n        if (N != 1) {\n            CL_CHECK(clReleaseMemObject(B_d));\n            CL_CHECK(clReleaseMemObject(B_d_input_image));\n            CL_CHECK(clReleaseMemObject(C_d));\n        }\n        // <--------------------------------------------> //\n\n        return;\n    }\n    } // if (ne01 && ne1)\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n\n    // GEMM using local memory\n    // Current BK = 16, so ne00 % 16 == 0\n    if (src1t == GGML_TYPE_F32 &&\n        ne00 % 16 == 0 &&\n        ne11 > 1) {\n        switch(src0t) {\n            case GGML_TYPE_F32: {\n                kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm;\n                nth0 = 128; // calculated as (BM*BN)/(TM*TN)\n\n                int batch_stride_a = ne00*ne01;\n                int batch_stride_b = ne10*ne11;\n                int batch_stride_d = ne0*ne1;\n\n                cl_mem mem_src0 = extra0->data_device;\n                cl_mem mem_src1 = extra1->data_device;\n\n                cl_ulong nb00_cont = nb00;\n                cl_ulong nb01_cont = nb01;\n                cl_ulong nb02_cont = nb02;\n                cl_ulong nb03_cont = nb03;\n\n                cl_ulong nb10_cont = nb10;\n                cl_ulong nb11_cont = nb11;\n                cl_ulong nb12_cont = nb12;\n                cl_ulong nb13_cont = nb13;\n\n                cl_ulong offset0_cont = offset0;\n                cl_ulong offset1_cont = offset1;\n\n                if (!ggml_is_contiguous(src0)) {\n                    
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));\n                    ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,\n                        nb00_cont, nb01_cont, nb02_cont, nb03_cont);\n                    mem_src0 = backend_ctx->prealloc_src0.buffer;\n                    offset0_cont = 0;\n                }\n\n                if (!ggml_is_contiguous(src1)) {\n                    backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));\n                    ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,\n                        nb10_cont, nb11_cont, nb12_cont, nb13_cont);\n                    mem_src1 = backend_ctx->prealloc_src1.buffer;\n                    offset1_cont = 0;\n                }\n\n                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &mem_src0));\n                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0_cont));\n                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &mem_src1));\n                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1_cont));\n                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));\n                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a\n                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b\n                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      
&ne01)); // stride_d\n                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));\n                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));\n                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));\n                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));\n                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));\n\n                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.\n                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};\n                size_t local_work_size[] = {(size_t)nth0, 1, 1};\n\n                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n                return;\n            }\n            case GGML_TYPE_F16: {\n                kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm;\n                nth0 = 128; // calculated as (BM*BN)/(TM*TN)\n\n                int batch_stride_a = ne00*ne01;\n                int batch_stride_b = ne10*ne11;\n                int batch_stride_d = ne0*ne1;\n\n                cl_mem mem_src0 = extra0->data_device;\n                cl_mem mem_src1 = extra1->data_device;\n\n                cl_ulong nb00_cont = nb00;\n                cl_ulong nb01_cont = nb01;\n                cl_ulong nb02_cont = nb02;\n                cl_ulong nb03_cont = nb03;\n\n                cl_ulong nb10_cont = nb10;\n                cl_ulong nb11_cont = nb11;\n                cl_ulong nb12_cont = nb12;\n                cl_ulong nb13_cont = nb13;\n\n                cl_ulong offset0_cont = offset0;\n                cl_ulong offset1_cont = offset1;\n\n                if (!ggml_is_contiguous(src0)) {\n                    backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));\n                    ggml_cl_copy_to_contiguous(backend, src0, 
backend_ctx->prealloc_src0.buffer,\n                        nb00_cont, nb01_cont, nb02_cont, nb03_cont);\n                    mem_src0 = backend_ctx->prealloc_src0.buffer;\n                    offset0_cont = 0;\n                }\n\n                if (!ggml_is_contiguous(src1)) {\n                    backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));\n                    ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,\n                            nb10_cont, nb11_cont, nb12_cont, nb13_cont);\n                    mem_src1 = backend_ctx->prealloc_src1.buffer;\n                    offset1_cont = 0;\n                }\n\n                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &mem_src0));\n                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0_cont));\n                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &mem_src1));\n                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1_cont));\n                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));\n                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a\n                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b\n                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d\n                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));\n                
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));\n                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));\n                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));\n                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));\n\n                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.\n                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};\n                size_t local_work_size[] = {(size_t)nth0, 1, 1};\n\n                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n                return;\n            }\n            case GGML_TYPE_Q4_0: {\n                if (ne11 < 32) {\n                    break;\n                }\n                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {\n                    break;\n                }\n\n                kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm;\n                nth0 = 128; // calculated as (BM*BN)/(TM*TN)\n\n                int batch_stride_a = ne00*ne01;\n                int batch_stride_b = ne10*ne11;\n                int batch_stride_d = ne0*ne1;\n\n                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));\n                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));\n                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n     
           CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));\n                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a\n                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b\n                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d\n                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));\n                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));\n                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));\n                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));\n                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));\n\n                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.\n                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};\n                size_t local_work_size[] = {(size_t)nth0, 1, 1};\n\n                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n                return;\n            }\n            case GGML_TYPE_Q4_1: {\n                if (ne11 < 32) {\n                    break;\n                }\n                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {\n                    break;\n                }\n\n                kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm;\n                nth0 = 128; // calculated as (BM*BN)/(TM*TN)\n\n                int batch_stride_a = ne00*ne01;\n                int batch_stride_b = ne10*ne11;\n                int batch_stride_d = ne0*ne1;\n\n                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   
&extra0_q4_1->q));\n                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_1->d));\n                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q4_1->m));\n                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra1->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_ulong), &offset1));\n                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_mem),   &extrad->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &offsetd));\n                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne00));\n                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne01));\n                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne02));\n                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne11));\n                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));\n                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_a\n                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10)); // stride_b\n                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne01)); // stride_d\n                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_a));\n                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_b));\n                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &batch_stride_d));\n                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r2));\n                CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &r3));\n\n                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.\n                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};\n                size_t local_work_size[] = {(size_t)nth0, 1, 1};\n\n                
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n                return;\n            }\n            case GGML_TYPE_Q8_0: {\n                if (ne11 < 32) {\n                    break;\n                }\n                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {\n                    break;\n                }\n\n                kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;\n                nth0 = 128; // calculated as (BM*BN)/(TM*TN)\n\n                int batch_stride_a = ne00*ne01;\n                int batch_stride_b = ne10*ne11;\n                int batch_stride_d = ne0*ne1;\n\n                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q8_0->q));\n                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q8_0->d));\n                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));\n                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a\n                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b\n                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d\n                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));\n                
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));\n                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));\n                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));\n                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));\n\n                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.\n                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};\n                size_t local_work_size[] = {(size_t)nth0, 1, 1};\n\n                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n                return;\n            }\n            case GGML_TYPE_Q6_K: {\n                if (ne11 < 32) {\n                    break;\n                }\n                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {\n                    break;\n                }\n\n                kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm;\n                nth0 = 128; // calculated as (BM*BN)/(TM*TN)\n\n                int batch_stride_a = ne00*ne01;\n                int batch_stride_b = ne10*ne11;\n                int batch_stride_d = ne0*ne1;\n\n                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q6_K->ql));\n                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q6_K->qh));\n                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q6_K->s));\n                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra0_q6_K->d));\n                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra1->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset1));\n                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  7, 
sizeof(cl_ulong), &offsetd));\n                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));\n                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));\n                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne11));\n                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));\n                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10)); // stride_a\n                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10)); // stride_b\n                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne01)); // stride_d\n                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_a));\n                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &batch_stride_b));\n                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &batch_stride_d));\n                CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &r2));\n                CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &r3));\n\n                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.\n                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};\n                size_t local_work_size[] = {(size_t)nth0, 1, 1};\n\n                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n                return;\n            }\n            default:\n                break;\n        }\n    }\n\n    if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&\n        src0->ne[1] > 32 &&   // M > 32\n        src1->ne[1] > 32 &&   // N > 32\n        src0->ne[0] > 32 &&   // K > 32\n        src0->ne[2] == 1 && src0->ne[3] == 1 &&\n        src1->ne[2] == 1 && src1->ne[3] == 1 &&\n        ggml_is_contiguous(src0) && ggml_is_contiguous(src1) 
&&\n        backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {\n        ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);\n        return;\n    }\n\n    if (!ggml_is_transposed(src0) &&\n        !ggml_is_transposed(src1) &&\n        src1t == GGML_TYPE_F32 &&\n        ne00%32 == 0 &&\n        ne11 > 2) {\n#ifdef GGML_OPENCL_SOA_Q\n        // Set up kernel.\n        switch(src0t) {\n            case GGML_TYPE_Q4_0:\n                // This should have been satisfied.\n                GGML_ASSERT(ne11 == ne1);\n                GGML_ASSERT(ne01 == ne0);\n\n                if (backend_ctx->gpu_family == INTEL) {\n                    nth0 = 16;\n                    nth1 = 1;\n\n                    kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat;\n                } else if (backend_ctx->gpu_family == ADRENO) {\n                    nth0 = 64;\n                    nth1 = 1;\n\n                    kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat;\n                } else {\n                    GGML_ASSERT(false && \"TODO: Unknown GPU\");\n                }\n\n                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));\n                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));\n                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));\n                CL_CHECK(clSetKernelArg(kernel, 10, 
sizeof(int),      &ne12));\n                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));\n                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));\n                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));\n                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));\n                break;\n            default:\n                break;\n        }\n\n        // Launch kernel.\n        if (src0t == GGML_TYPE_Q4_0) {\n            size_t global_work_size[] = {(size_t)(ne01 + 7)/8*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};\n            size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};\n\n            if (backend_ctx->gpu_family == INTEL) {\n                // Set global size for Intel. It uses 16x output values.\n                global_work_size[0] = (size_t)(ne01 + 15)/16*nth0;\n                global_work_size[1] = (size_t)ne11*nth1;\n                global_work_size[2] = (size_t)ne12*ne13;\n            }\n\n            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n            return;\n        }\n#else // GGML_OPENCL_SOA_Q\n        // TODO: add block_q4_0 variant.\n#endif // GGML_OPENCL_SOA_Q\n    }\n\n    // use custom matrix x vector kernel\n    switch (src0t) {\n        case GGML_TYPE_F32:\n            //GGML_ASSERT(ne02 == ne12);\n            GGML_ASSERT(src1t == GGML_TYPE_F32);\n            kernel = backend_ctx->kernel_mul_mat_f32_f32;\n            nrows = 4;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 32;\n                nth1 = 1;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 1;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), 
&offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb00));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));\n            break;\n        case GGML_TYPE_F16:\n            //GGML_ASSERT(ne02 == ne12);\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 32;\n                nth1 = 1;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 
64;\n                nth1 = 1;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            if (src1t == GGML_TYPE_F32) {\n                if (ne11 * ne12 < 4) {\n                    kernel = backend_ctx->kernel_mul_mat_f16_f32_1row;\n                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {\n                    kernel = backend_ctx->kernel_mul_mat_f16_f32_l4;\n                    nrows = ne11;\n                } else {\n                    kernel = backend_ctx->kernel_mul_mat_f16_f32;\n                    nrows = 4;\n                }\n            } else {\n                kernel = backend_ctx->kernel_mul_mat_f16_f16;\n                nrows = 4;\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb00));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      
&ne12));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));\n            break;\n        case GGML_TYPE_Q4_0:\n            // This should have been satisfied.\n            GGML_ASSERT(ne11 == ne1);\n            GGML_ASSERT(ne01 == ne0);\n\n#ifdef GGML_OPENCL_SOA_Q\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 1;\n\n                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat;\n                ndst = 8;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 1;\n\n                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat;\n                ndst =8;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      
&ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));\n#else // GGML_OPENCL_SOA_Q\n            if (backend_ctx->gpu_family == INTEL) {\n                // Use 1D local size. Each workgroup is a SIMD group. Each SIMD\n                // group produces N_DST (4 for Q4_0 kernel) values in the result.\n                // The number of workgroups on dim 0 (the leading dimension) is\n                // the nearest multiple of 4 that covers ne0 (equals ne01).\n                nth0 = 16;\n                nth1 = 1;\n\n                kernel = backend_ctx->kernel_mul_mat_q4_0_f32;\n                ndst = 4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 1;\n\n                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_v;\n                ndst = 4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            
CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));\n#endif // GGML_OPENCL_SOA_Q\n            break;\n        case GGML_TYPE_Q4_1: {\n#ifdef GGML_OPENCL_SOA_Q\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 1;\n                ndst = 4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 1;\n                ndst = 4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat;\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_1->q));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_1->d));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q4_1->m));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne02));\n            
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &r3));\n#else\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 1;\n                ndst = 4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 1;\n                ndst = 4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            kernel = backend_ctx->kernel_mul_mv_q4_1_f32;\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));\n 
           CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));\n#endif // GGML_OPENCL_SOA_Q\n            break;\n        }\n        case GGML_TYPE_Q8_0: {\n#ifdef GGML_OPENCL_SOA_Q\n            kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat;\n\n            // nth0 - subgroup size\n            // nth1 - number of subgroups per workgroup\n            // ndst - number of output values per workgroup = output per subgroup * number of subgroups\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 2;\n                ndst = nth1*4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 2;\n                ndst = nth1*4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q8_0->q));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q8_0->d));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));\n            
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));\n#else\n            kernel = backend_ctx->kernel_mul_mv_q8_0_f32;\n\n            // nth0 - subgroup size\n            // nth1 - number of subgroups per workgroup\n            // ndst - number of output values per workgroup = output per subgroup * number of subgroups\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 2;\n                ndst = nth1*4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 2;\n                ndst = nth1*4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));\n            
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));\n#endif // GGML_OPENCL_SOA_Q\n            break;\n        }\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K: {\n            kernel = backend_ctx->kernel_mul_mv_q4_K_f32;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 1;\n                ndst = 4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 1;\n                ndst = 4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),     &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(int),        &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),     &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(int),        &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),     &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),        &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),        &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),        &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),   &nb01));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),   &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 10, 
sizeof(cl_ulong),   &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),        &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),   &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong),   &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong),   &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),        &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),        &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),        &r2));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),        &r3));\n            break;\n        }\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n#ifdef GGML_OPENCL_SOA_Q\n            kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 2;\n                ndst = 4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 2;\n                ndst = 4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q6_K->ql));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q6_K->qh));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q6_K->s));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra0_q6_K->d));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n            
CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &r3));\n#else\n            kernel = backend_ctx->kernel_mul_mv_q6_K_f32;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 2;\n                ndst = 1;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 2;\n                ndst = 1;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      
&ne0));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));\n#endif // GGML_OPENCL_SOA_Q\n            break;\n        case GGML_TYPE_MXFP4: {\n#ifdef GGML_OPENCL_SOA_Q\n            kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat;\n\n            cl_mem q;\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 2;\n                ndst = nth1*2;\n\n                q = extra0_mxfp4->q;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 2;\n                ndst = nth1*2;\n\n                q = extra0_mxfp4->q_img;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &q));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_mxfp4->e));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 13, 
sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r3));\n#else\n            kernel = backend_ctx->kernel_mul_mv_mxfp4_f32;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                nth0 = 16;\n                nth1 = 2;\n                ndst = nth1*2;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                nth0 = 64;\n                nth1 = 2;\n                ndst = nth1*2;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne0));\n            
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r3));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr));\n#endif\n            break;\n        }\n        default:\n            GGML_ASSERT(false && \"not implemented\");\n    }\n\n    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||\n        src0t == GGML_TYPE_Q4_1 ||\n        src0t == GGML_TYPE_Q8_0 ||\n        src0t == GGML_TYPE_Q2_K) {\n        // Each SIMD group produces N_DST values in the result. Assuming each\n        // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will\n        // produce N_DST*N_SIMDGROUP values in the result. Hence, the grid size\n        // (number of workgroups) will be a nearest multiple of\n        // N_DST*N_SIMDGROUP to cover the size of the dimension. Below, 4 is\n        // N_DST*N_SIMDGROUP (see the kernel for Q4_0 matmul).\n        size_t global_work_size[] = {(size_t)(ne01 + ndst-1)/ndst*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};\n        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    } else if (src0t == GGML_TYPE_Q4_K) {\n        size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};\n        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    } else if (src0t == GGML_TYPE_Q3_K) {\n        GGML_ASSERT(false && \"not implemented\");\n    } else if (src0t == GGML_TYPE_Q5_K) {\n        GGML_ASSERT(false && \"not implemented\");\n    } else if (src0t == GGML_TYPE_Q6_K) {\n        size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};\n     
   size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    } else {\n        int64_t ny = (ne11 + nrows - 1)/nrows;\n\n        size_t global_work_size[] = {(size_t)ne01*nth0, (size_t)ny*nth1, (size_t)ne12*ne13};\n        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    }\n}\n\nstatic void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    const ggml_tensor * src2 = dst->src[2];\n    GGML_ASSERT(src2);\n    GGML_ASSERT(src2->extra);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offset2 = extra2->offset + src2->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    GGML_UNUSED(offset0);\n\n#ifdef GGML_OPENCL_SOA_Q\n    ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;\n    ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;\n    ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;\n#endif\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = 
src0->ne[3];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const int ne10 = src1->ne[0];\n    const int ne11 = src1->ne[1];\n    const int ne12 = src1->ne[2];\n    const int ne13 = src1->ne[3];\n\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n    const cl_ulong nb13 = src1->nb[3];\n\n    const int ne20 = src2->ne[0];\n    const int ne21 = src2->ne[1];\n\n    const cl_ulong nb21 = src2->nb[1];\n    const cl_ulong nb20 = src2->nb[0];\n\n    UNUSED(nb20);\n\n    const int ne0 = dst->ne[0];\n    const int ne1 = dst->ne[1];\n\n    const int r2 = ne12/ne02;\n    const int r3 = ne13/ne03;\n    const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows\n\n    GGML_ASSERT(ne00 == ne10);\n\n    int sgs   = 32; // subgroup size\n    int nsg   = 1;  // number of subgroups\n    int nrows = 1;  // number of row in src1\n    int ndst  = 4;  // number of values produced by each subgroup\n\n    cl_kernel kernel;\n\n    // subgroup mat vec\n    switch (src0->type) {\n        case GGML_TYPE_Q4_0: {\n            kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                sgs  = 16;\n                nsg  = 1;\n                ndst = 8;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                sgs  = 64;\n                nsg  = 1;\n                ndst = 8;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            
CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb00));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne20));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &ne21));\n            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb21));\n            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &r3));\n\n            break;\n        }\n        case GGML_TYPE_Q8_0: {\n#ifdef GGML_OPENCL_SOA_Q\n            kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32_flat;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                sgs  = 16;\n                nsg  = 2;\n                ndst = 4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                sgs  = 64;\n  
              nsg  = 2;\n                ndst = 4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q8_0->q));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q8_0->d));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne20));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne21));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb21));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne1));\n#else\n            kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                sgs  = 16;\n                nsg  = 
2;\n                ndst = 4;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                sgs  = 64;\n                nsg  = 2;\n                ndst = 4;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne20));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne21));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb21));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne1));\n#endif // GGML_OPENCL_SOA_Q\n            break;\n        
}\n        case GGML_TYPE_MXFP4: {\n#ifdef GGML_OPENCL_USE_ADRENO_KERNELS\n            if (use_adreno_moe_kernels(backend_ctx, src0)) {\n                cl_int status;\n\n                size_t local_size[3] = {64, 2, 1};\n                size_t global_size[3] = {64, 2, 1};\n\n                cl_mem src1_sub_buffer, buf_src1_image, buf_src2;\n\n                int tile_size = 320;\n                if (ne12 == 1) { // for gemv\n                    kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;\n\n                    // create a sub_buffer for src2\n                    cl_buffer_region region;\n                    region.origin = offset2;\n                    region.size = ne20 * ne21 * sizeof(int);\n                    buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);\n                    CL_CHECK(status);\n\n                    // set thread grid\n                    global_size[0] = static_cast<size_t>(ne01);\n                    global_size[1] = 4;\n                    global_size[2] = static_cast<size_t>(ne20);\n                    local_size[1] = 4;\n                } else { // for gemm\n                    kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;\n\n                    // preprocess router table\n                    int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;\n                    void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));\n                    void * host_src2 = malloc(ne21 * nb21);\n                    CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));\n                    int total_experts = nb21 / nb20;\n                    int out_idx = 0;\n                    for (int i_expert = 0; i_expert < ne02; i_expert++) {\n                        for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {\n                            for (int j = 0; j < ne21; 
j++) {\n                                for (int i = 0; i < ne20; i++) {\n                                    int expert = ((int *)host_src2)[j * total_experts + i];\n                                    if (i_expert == expert) {\n                                        ((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);\n                                        ((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));\n                                        ((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);\n                                        ((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);\n                                        out_idx += 4;\n                                    }\n                                }\n                            }\n                        }\n                    }\n                    buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);\n                    CL_CHECK(status);\n\n                    // set thread grid\n                    global_size[0] = static_cast<size_t>(tile_size);\n                    global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);\n                }\n\n                // create a sub_buffer for src1\n                cl_buffer_region region;\n                region.origin = offset1;\n                region.size = ne10 * ne11 * ne12 * sizeof(float);\n                src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);\n                CL_CHECK(status);\n\n                // create image for src1\n                cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};\n                cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, 
{src1_sub_buffer}};\n                buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);\n                CL_CHECK(status);\n\n                // Set kernel args\n                int arg_idx = 0;\n                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &extra0_mxfp4->q));\n                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &extra0_mxfp4->e));\n                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &buf_src1_image));\n                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &buf_src2));\n                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &extrad->data_device));\n                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong),  &offsetd));\n                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int),       &ne00));\n                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int),       &ne01));\n                if (ne12 == 1) {\n                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int),       &ne11));\n                } else {\n                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int),       &tile_size));\n                }\n\n                // launch kernel\n                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);\n\n                // deallocate sub buffers and images\n                CL_CHECK(clReleaseMemObject(src1_sub_buffer));\n                CL_CHECK(clReleaseMemObject(buf_src1_image));\n                CL_CHECK(clReleaseMemObject(buf_src2));\n                return;\n            } // else fallback to generic kernel\n#endif // GGML_OPENCL_USE_ADRENO_KERNELS\n\n#ifdef GGML_OPENCL_SOA_Q\n            kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;\n\n            cl_mem q;\n            if (backend_ctx->gpu_family == INTEL) {\n                sgs  = 16;\n      
          nsg  = 2;\n                ndst = 2;\n\n                q = extra0_mxfp4->q;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                sgs  = 64;\n                nsg  = 1;\n                ndst = 4;\n\n                q = extra0_mxfp4->q_img;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &q));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_mxfp4->e));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));\n            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne20));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne21));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));\n            
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));\n#else // GGML_OPENCL_SOA_Q\n            kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32;\n\n            if (backend_ctx->gpu_family == INTEL) {\n                sgs  = 16;\n                nsg  = 2;\n                ndst = 2;\n            } else if (backend_ctx->gpu_family == ADRENO) {\n                sgs  = 64;\n                nsg  = 2;\n                ndst = 2;\n            } else {\n                GGML_ASSERT(false && \"TODO: Unknown GPU\");\n            }\n\n            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));\n            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));\n            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));\n            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));\n            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne11));\n            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne12));\n            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));\n            
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));\n            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));\n            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne20));\n            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne21));\n            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));\n            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));\n            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));\n            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));\n            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));\n            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,nullptr));\n#endif // GGML_OPENCL_SOA_Q\n            break;\n        }\n        default:\n            GGML_ASSERT(false && \"not implemented\");;\n    }\n\n    int _ne1 = 1;\n    int ne123 = dst_rows;\n\n    size_t global_work_size[] = {(size_t)(ne01+ndst*nsg-1)/(ndst*nsg)*sgs, (size_t)(_ne1+nrows-1)/nrows*nsg, (size_t)ne123};\n    size_t local_work_size[] = {(size_t)sgs, (size_t)nsg, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_UNUSED(src1);\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    float scale;\n    float bias;\n    memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));\n    memcpy(&bias,  ((int32_t *) dst->op_params) + 1, sizeof(float));\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + 
src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    int n = ggml_nelements(dst);\n\n    if (n % 4 == 0) {\n        kernel = backend_ctx->kernel_scale_f32_4;\n        n /= 4;\n    } else {\n        kernel = backend_ctx->kernel_scale_f32;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float),    &scale));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float),    &bias));\n\n    size_t global_work_size[] = {(size_t)n, 1, 1};\n    size_t local_work_size[] = {64, 1, 1};\n\n    size_t * local_work_size_ptr = local_work_size;\n    if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n        local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.\n    }\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n}\n\nstatic void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n\n    // GGML_OP_CPY happens between src0 and src1.\n    // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst.\n    UNUSED(dst);\n\n    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);\n    GGML_TENSOR_LOCALS(int,      ne1, src1, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb);\n\n    const enum ggml_type src0t = src0->type;\n    const enum ggml_type src1t = src1->type;\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n\n    cl_kernel kernel;\n\n    switch (src0t) {\n        case GGML_TYPE_F32:\n            switch (src1t) {\n                case GGML_TYPE_F16:\n                    kernel = backend_ctx->kernel_cpy_f32_f16;\n                    break;\n                case GGML_TYPE_F32:\n                    kernel = backend_ctx->kernel_cpy_f32_f32;\n                    break;\n                default:\n                    GGML_ASSERT(false && \"not implemented\");\n            }\n            break;\n        case GGML_TYPE_F16:\n            switch (src1t) {\n                case GGML_TYPE_F16:\n                    kernel = backend_ctx->kernel_cpy_f16_f16;\n                    break;\n                case GGML_TYPE_F32:\n                    kernel = backend_ctx->kernel_cpy_f16_f32;\n                    break;\n                default:\n                    GGML_ASSERT(false && \"not implemented\");\n            }\n            break;\n        case GGML_TYPE_I32:\n            switch (src1t) {\n                case GGML_TYPE_I32:\n                    kernel = backend_ctx->kernel_cpy_i32_i32;\n                    break;\n                default:\n                    GGML_ASSERT(false && \"not implemented\");\n            }\n            break;\n        default:\n            GGML_ASSERT(false && \"not implemented\");\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));\n    CL_CHECK(clSetKernelArg(kernel,  6, 
sizeof(int),      &ne02));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb00));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne11));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne12));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne13));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));\n\n    const int nth = MIN(64, ne00);\n\n    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1);\n}\n\nstatic void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    ggml_cl_cpy(backend, src0, dst, nullptr);\n    UNUSED(src1);\n}\n\nstatic void ggml_cl_set(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32) &&\n        src1->type == src0->type && dst->type == src0->type);\n\n    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);\n    GGML_TENSOR_LOCALS(int,      ne1, src1, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, 
nb1, src1, nb);\n    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const cl_ulong pnb1    = ((const int32_t *)dst->op_params)[0];\n    const cl_ulong pnb2    = ((const int32_t *)dst->op_params)[1];\n    const cl_ulong pnb3    = ((const int32_t *)dst->op_params)[2];\n    const cl_ulong offs    = ((const int32_t *)dst->op_params)[3];\n    const bool     inplace = (bool)((const int32_t *)dst->op_params)[4];\n\n    cl_kernel kernel = nullptr;\n\n    // for inplace case, dst is a view of src0 and is updated on top of it\n    // so for non-inplace case, copy src0 to dst first\n    if (!inplace) {\n        ggml_cl_cpy(backend, src0, dst, nullptr);\n    }\n\n    // then copy src1 to dst with specified offset\n    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n        kernel = backend_ctx->kernel_cpy_f32_f32;\n    } else if (src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {\n        kernel = backend_ctx->kernel_cpy_i32_i32;\n    } else {\n        GGML_ASSERT(false && \"not implemented\");\n    }\n\n    offsetd += offs;\n    cl_ulong nb = ggml_element_size(dst);\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne10));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne11));\n    
CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne12));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne13));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb10));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb12));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb13));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne11));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne12));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne13));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &pnb1));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &pnb2));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &pnb3));\n\n    int max_local_size = backend_ctx->get_kernel_workgroup_size(kernel);\n\n    const int nth = MIN(max_local_size, ne00);\n\n    size_t global_work_size[] = {(size_t)ne11*nth, (size_t)ne12, (size_t)ne13};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    int n_past = ((int32_t *)(dst->op_params))[0];\n\n    const int  ne00 = src0 ? src0->ne[0] : 0;\n    const int  ne01 = src0 ? src0->ne[1] : 0;\n    const int  ne02 = src0 ? 
src0->ne[2] : 0;\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel;\n\n    if (ne00%8 == 0) {\n        kernel = backend_ctx->kernel_diag_mask_inf_8;\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));\n        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &n_past));\n\n        size_t global_work_size[] = {(size_t)ne00*ne01*ne02/8, 1, 1};\n        size_t local_work_size[] = {64, 1, 1};\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n    } else {\n        kernel = backend_ctx->kernel_diag_mask_inf;\n\n        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));\n        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &n_past));\n\n        size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02};\n        size_t local_work_size[] = {64, 1, 1};\n\n        size_t * 
local_work_size_ptr = local_work_size;\n        if (ne00 % 64 != 0 && !backend_ctx->non_uniform_workgroups) {\n            local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.\n        }\n\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);\n    }\n}\n\nstatic void ggml_cl_diag(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    UNUSED(src1);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);\n    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);\n\n    cl_kernel kernel = backend_ctx->kernel_diag_f32;\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_int),   &ne0));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb0));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb3));\n\n    int nth = 
64;\n\n    size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    // Softmax can now fuse KQ mask and KQ scale, which used to be two additional\n    // ops before softmax. It now also fuses alibi if `max_bias > 0`. For llama,\n    // alibi is not used; however, for some other models, it is used.\n    // KQ_mask\n    if (src1) {\n        GGML_ASSERT(src1);\n        GGML_ASSERT(src1->extra);\n    }\n\n    const ggml_tensor * src2 = dst->src[2];\n    if (src2) {\n        GGML_ASSERT(src2->extra);\n    }\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;\n    ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;\n    cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_long nb01 = src0->nb[1];\n    const cl_long nb02 = src0->nb[2];\n    const cl_long nb03 = src0->nb[3];\n\n    const int ne12 = src1 ? src1->ne[2] : 0;\n    const int ne13 = src1 ? 
src1->ne[3] : 0;\n\n    const cl_long nb11 = src1 ? src1->nb[1] : 0;\n    const cl_long nb12 = src1 ? src1->nb[2] : 0;\n    const cl_long nb13 = src1 ? src1->nb[3] : 0;\n\n    const cl_long nb1 = dst->nb[1];\n    const cl_long nb2 = dst->nb[2];\n    const cl_long nb3 = dst->nb[3];\n\n    float scale, max_bias;\n    memcpy(&scale,    dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias, dst->op_params + 1, sizeof(float));\n\n    const int n_head      = src0->ne[2];\n    const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);\n\n    // Local size must be wave size. Each workgroup is a wave, working on a row,\n    // where a row corresponds to leading dimension.\n    int nth = MIN(32, ne00);\n\n    if (backend_ctx->gpu_family == INTEL) {\n        // This is the same as the initial value.\n        nth = MIN(32, ne00);\n    }\n    else if (backend_ctx->gpu_family == ADRENO) {\n        nth = 64;\n    } else {\n        GGML_ASSERT(false && \"TODO: Unknown GPU\");\n    }\n\n    cl_kernel kernel;\n\n    if (ne00%4 == 0) {\n        if (use_f16) {\n            kernel = backend_ctx->kernel_soft_max_4_f16;\n        } else {\n            kernel = backend_ctx->kernel_soft_max_4;\n        }\n    } else {\n        if (use_f16) {\n            kernel = backend_ctx->kernel_soft_max_f16;\n        } else {\n            kernel = backend_ctx->kernel_soft_max;\n        }\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   extra1 ? 
&extra1->data_device : &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   extra2 ? &extra2->data_device : &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne13));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));\n    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &scale));\n    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &max_bias));\n    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float),    &m0));\n    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &m1));\n    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &n_head_log2));\n\n    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    
GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    ggml_tensor * src2 = dst->src[2];\n    ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;\n\n    cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;\n\n    const int  ne00 = src0 ? src0->ne[0] : 0;\n    const int  ne01 = src0 ? src0->ne[1] : 0;\n    const int  ne02 = src0 ? src0->ne[2] : 0;\n    const int  ne03 = src0 ? src0->ne[3] : 0;\n\n    const cl_ulong  nb00 = src0 ? src0->nb[0] : 0;\n    const cl_ulong  nb01 = src0 ? src0->nb[1] : 0;\n    const cl_ulong  nb02 = src0 ? src0->nb[2] : 0;\n    const cl_ulong  nb03 = src0 ? src0->nb[3] : 0;\n\n    const int ne10 = src1 ? src1->ne[0] : 0;\n    const int ne11 = src1 ? src1->ne[1] : 0; UNUSED(ne11);\n    const int ne12 = src1 ? src1->ne[2] : 0; UNUSED(ne12);\n    const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);\n\n    const int  ne0 = dst ? dst->ne[0] : 0;\n    const int  ne1 = dst ? dst->ne[1] : 0;\n    const int  ne2 = dst ? dst->ne[2] : 0;\n    const int  ne3 = dst ? dst->ne[3] : 0;\n\n    const cl_ulong  nb0 = dst ? dst->nb[0] : 0;\n    const cl_ulong  nb1 = dst ? dst->nb[1] : 0;\n    const cl_ulong  nb2 = dst ? dst->nb[2] : 0;\n    const cl_ulong  nb3 = dst ? 
dst->nb[3] : 0;\n\n    GGML_ASSERT(ne10 % ne02 == 0);\n    GGML_ASSERT(ne10 >= ne02);\n\n    int nth = MIN(64, ne00);\n\n    const int n_past     = ((int *) dst->op_params)[0];\n    const int n_dims     = ((int *) dst->op_params)[1];\n    const int mode       = ((int *) dst->op_params)[2];\n    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];\n\n    float freq_base;\n    float freq_scale;\n    float ext_factor;\n    float attn_factor;\n    float beta_fast;\n    float beta_slow;\n    int32_t sections[4];\n\n    memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float));\n    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float));\n    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7, sizeof(float));\n    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));\n    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));\n    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));\n    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int32_t)*4);\n\n    const bool is_neox = mode & 2;\n    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;\n    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n    const int  is_imrope = mode == GGML_ROPE_TYPE_IMROPE;\n\n    if (is_mrope) {\n        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);\n    }\n\n    if (is_vision) {\n        GGML_ASSERT(n_dims == ne00/2);\n    }\n\n    cl_kernel kernel;\n\n    if (is_neox) {\n        switch (src0->type) {\n            case GGML_TYPE_F32:\n                kernel = backend_ctx->kernel_rope_neox_f32;\n                break;\n            case GGML_TYPE_F16:\n                kernel = backend_ctx->kernel_rope_neox_f16;\n                break;\n            default:\n                GGML_ASSERT(false);\n        };\n    } else if (is_mrope && !is_vision) {\n        switch (src0->type) {\n            case GGML_TYPE_F32:\n                kernel = backend_ctx->kernel_rope_multi_f32;\n              
  break;\n            case GGML_TYPE_F16:\n                kernel = backend_ctx->kernel_rope_multi_f16;\n                break;\n            default:\n                GGML_ASSERT(false);\n        };\n    } else if (is_vision) {\n        switch (src0->type) {\n            case GGML_TYPE_F32:\n                kernel = backend_ctx->kernel_rope_vision_f32;\n                break;\n            case GGML_TYPE_F16:\n                kernel = backend_ctx->kernel_rope_vision_f16;\n                break;\n            default:\n                GGML_ASSERT(false);\n        }\n    } else {\n        switch (src0->type) {\n            case GGML_TYPE_F32:\n                kernel = backend_ctx->kernel_rope_norm_f32;\n                break;\n            case GGML_TYPE_F16:\n                kernel = backend_ctx->kernel_rope_norm_f16;\n                break;\n            default:\n                GGML_ASSERT(false);\n        };\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   extra2 ? 
&extra2->data_device : &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne03));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb00));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne0));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne1));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne2));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &ne3));\n    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb0));\n    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &nb3));\n    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &n_past));\n    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &n_dims));\n    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int),      &n_ctx_orig));\n    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float),    &freq_base));\n    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(float),    &freq_scale));\n    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float),    &ext_factor));\n    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float),    &attn_factor));\n    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float),    &beta_fast));\n    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    
&beta_slow));\n    // both mrope and vision kernels have sections\n    if (is_mrope || is_vision) {\n        CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));\n    }\n    // only mrope has is_imrope\n    if (is_mrope && !is_vision) {\n        CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));\n    }\n\n    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_solve_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_kernel kernel = backend_ctx->kernel_solve_tri_f32;\n    GGML_ASSERT(kernel != nullptr);\n\n    const int n = src0->ne[0];\n    const int k = src1->ne[0];\n\n    const cl_ulong nb00 = src0->nb[0];\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const cl_ulong nb10 = src1->nb[0];\n    const cl_ulong nb11 = src1->nb[1];\n    const cl_ulong nb12 = src1->nb[2];\n    const cl_ulong nb13 = src1->nb[3];\n\n    const cl_ulong nb0 = dst->nb[0];\n    const cl_ulong nb1 = dst->nb[1];\n    const cl_ulong nb2 = dst->nb[2];\n    const cl_ulong nb3 = 
dst->nb[3];\n\n    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &n));\n    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &k));\n    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));\n    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));\n    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),&nb10));\n    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong),&nb11));\n    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong),&nb12));\n    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong),&nb13));\n    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb0));\n    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb1));\n    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb2));\n    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb3));\n\n    size_t global_work_size[3]= { (size_t)k, (size_t)dst->ne[2], (size_t)dst->ne[3]};\n    size_t local_work_size[] = {16, 4, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src1);\n    GGML_ASSERT(src1->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    // src0 - filter, src1 - input\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type 
== GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset1 = extra1->offset + src1->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];\n    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];\n    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];\n    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];\n    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];\n\n    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;\n\n    const cl_long IC = src1->ne[is_2D ? 2 : 1];\n    const cl_long IH = is_2D ? src1->ne[1] : 1;\n    const cl_long IW =         src1->ne[0];\n\n    const cl_long KH = is_2D ? src0->ne[1] : 1;\n    const cl_long KW =         src0->ne[0];\n\n    const cl_long OH = is_2D ? dst->ne[2] : 1;\n    const cl_long OW =         dst->ne[1];\n\n    // nb is byte offset, src is type float32\n    const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4;\n    const cl_long  batch        = src1->ne[is_2D ? 3 : 2];\n    const cl_ulong batch_offset = src1->nb[is_2D ? 
3 : 2]/4;\n\n    const cl_long pelements = OW*KW*KH;\n    const cl_long CHW       = IC*KH*KW;\n\n    cl_kernel kernel;\n\n    if(dst->type == GGML_TYPE_F16) {\n        kernel = backend_ctx->kernel_im2col_f16;\n    } else {\n        kernel = backend_ctx->kernel_im2col_f32;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &extra1->data_device));\n    CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,   4, sizeof(cl_ulong), &batch_offset));\n    CL_CHECK(clSetKernelArg(kernel,   5, sizeof(cl_ulong), &delta_offset));\n    CL_CHECK(clSetKernelArg(kernel,   6, sizeof(cl_long),  &IW));\n    CL_CHECK(clSetKernelArg(kernel,   7, sizeof(cl_long),  &IH));\n    CL_CHECK(clSetKernelArg(kernel,   8, sizeof(cl_long),  &IC));\n    CL_CHECK(clSetKernelArg(kernel,   9, sizeof(cl_long),  &OW));\n    CL_CHECK(clSetKernelArg(kernel,  10, sizeof(cl_long),  &OH));\n    CL_CHECK(clSetKernelArg(kernel,  11, sizeof(cl_long),  &KW));\n    CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_long),  &KH));\n    CL_CHECK(clSetKernelArg(kernel,  13, sizeof(cl_long),  &pelements));\n    CL_CHECK(clSetKernelArg(kernel,  14, sizeof(cl_long),  &CHW));\n    CL_CHECK(clSetKernelArg(kernel,  15, sizeof(int),      &s0));\n    CL_CHECK(clSetKernelArg(kernel,  16, sizeof(int),      &s1));\n    CL_CHECK(clSetKernelArg(kernel,  17, sizeof(int),      &p0));\n    CL_CHECK(clSetKernelArg(kernel,  18, sizeof(int),      &p1));\n    CL_CHECK(clSetKernelArg(kernel,  19, sizeof(int),      &d0));\n    CL_CHECK(clSetKernelArg(kernel,  20, sizeof(int),      &d1));\n\n    const int num_blocks = (pelements + 256 - 1) / 256;\n    size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC};\n    size_t local_work_size[] = {256, 1, 1};\n\n    
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_UNUSED(src1);\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_I32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int ne00  = src0->ne[0];\n    const int nrows = ggml_nrows(src0);\n\n    int ne00_padded = 1;\n    while (ne00_padded < ne00) {\n        ne00_padded *= 2;\n    }\n\n    int order = (enum ggml_sort_order) dst->op_params[0];\n\n    cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32;\n\n    CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),            &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong),          &offset0));\n    CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),            &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_ulong),          &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,   4, sizeof(int),               &ne00));\n    CL_CHECK(clSetKernelArg(kernel,   5, sizeof(int),               &ne00_padded));\n    CL_CHECK(clSetKernelArg(kernel,   6, sizeof(int),               &order));\n    CL_CHECK(clSetKernelArg(kernel,   7, ne00_padded*sizeof(int),   NULL));\n\n    size_t global_work_size[] = {(size_t)ne00_padded, (size_t)nrows, (size_t)1};\n    size_t local_work_size[] = {(size_t)ne00_padded, 1, 1};\n\n    
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_UNUSED(src1);\n\n    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    const int ne00 = src0->ne[0];\n    const int ne01 = src0->ne[1];\n    const int ne02 = src0->ne[2];\n    const int ne03 = src0->ne[3];\n\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb02 = src0->nb[2];\n    const cl_ulong nb03 = src0->nb[3];\n\n    const cl_ulong nb1  = dst->nb[1];\n    const cl_ulong nb2  = dst->nb[2];\n    const cl_ulong nb3  = dst->nb[3];\n\n    cl_kernel kernel;\n\n    const bool is_c4 = ne00 % 4 == 0;\n    if (is_c4) {\n        kernel = backend_ctx->kernel_sum_rows_f32_4;\n    } else {\n        kernel = backend_ctx->kernel_sum_rows_f32;\n    }\n\n    CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,   4, sizeof(int),      &ne00));\n    CL_CHECK(clSetKernelArg(kernel,   5, sizeof(int),      &ne01));\n    CL_CHECK(clSetKernelArg(kernel,   6, sizeof(int),      &ne02));\n    CL_CHECK(clSetKernelArg(kernel,   7, sizeof(int),      &ne03));\n    
CL_CHECK(clSetKernelArg(kernel,   8, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,   9, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel,  10, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel,  11, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_ulong), &nb2));\n    CL_CHECK(clSetKernelArg(kernel,  13, sizeof(cl_ulong), &nb3));\n\n    size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = {(size_t)64, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\nstatic void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n    GGML_UNUSED(src1);\n\n    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);\n    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);\n\n    cl_kernel kernel = backend_ctx->kernel_cumsum_blk;\n\n    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);\n    int nth = 1;\n    while (nth < ne00 && 2*nth <= max_workgroup_size) {\n        nth *= 2;\n    }\n\n    GGML_ASSERT(ne00 <= nth*nth);\n\n    const int net0 = CEIL_DIV(ne00, nth);\n    const int net1 = ne01;\n    const int net2 = ne02;\n    const int net3 = ne03;\n\n    const cl_ulong nbt0 = sizeof(float);\n    const cl_ulong nbt1 = net0*nbt0;\n    const cl_ulong nbt2 = net1*nbt1;\n 
   const cl_ulong nbt3 = net2*nbt2;\n\n    static ggml_cl_buffer tmp_buffer;\n    tmp_buffer.allocate(backend_ctx->context, net0*ne01*ne02*ne03*sizeof(float));\n\n    CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),   &tmp_buffer.buffer));\n    CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,   4, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,   5, sizeof(int),      &ne00));\n    CL_CHECK(clSetKernelArg(kernel,   6, sizeof(int),      &ne01));\n    CL_CHECK(clSetKernelArg(kernel,   7, sizeof(int),      &ne02));\n    CL_CHECK(clSetKernelArg(kernel,   8, sizeof(int),      &ne03));\n    CL_CHECK(clSetKernelArg(kernel,   9, sizeof(cl_ulong), &nb00));\n    CL_CHECK(clSetKernelArg(kernel,  10, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  11, sizeof(cl_ulong), &nb02));\n    CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_ulong), &nb03));\n    CL_CHECK(clSetKernelArg(kernel,  13, sizeof(int),      &net0));\n    CL_CHECK(clSetKernelArg(kernel,  14, sizeof(int),      &net1));\n    CL_CHECK(clSetKernelArg(kernel,  15, sizeof(int),      &net2));\n\n    size_t global_work_size[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03};\n    size_t local_work_size[] = { (size_t)nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n\n    if(ne00 > nth) {\n        // if a single workgroup cannot handle an entire row, each workgroup\n        // computes a partial sum and stores to dst, tmp_buffer contains the sum\n        // of the each workgroup; cumsum this buffer and add to the partial sums in dst\n        cl_ulong offsett = 0;\n        kernel = backend_ctx->kernel_cumsum_blk;\n        CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &tmp_buffer.buffer));\n    
    CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offsett));\n        CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),   &tmp_buffer.buffer));\n        CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_mem),   &tmp_buffer.buffer));\n        CL_CHECK(clSetKernelArg(kernel,   4, sizeof(cl_ulong), &offsett));\n        CL_CHECK(clSetKernelArg(kernel,   5, sizeof(int),      &net0));\n        CL_CHECK(clSetKernelArg(kernel,   6, sizeof(int),      &ne01));\n        CL_CHECK(clSetKernelArg(kernel,   7, sizeof(int),      &ne02));\n        CL_CHECK(clSetKernelArg(kernel,   8, sizeof(int),      &ne03));\n        CL_CHECK(clSetKernelArg(kernel,   9, sizeof(cl_ulong), &nbt0));\n        CL_CHECK(clSetKernelArg(kernel,  10, sizeof(cl_ulong), &nbt1));\n        CL_CHECK(clSetKernelArg(kernel,  11, sizeof(cl_ulong), &nbt2));\n        CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_ulong), &nbt3));\n        CL_CHECK(clSetKernelArg(kernel,  13, sizeof(int),      &net0));\n        CL_CHECK(clSetKernelArg(kernel,  14, sizeof(int),      &net1));\n        CL_CHECK(clSetKernelArg(kernel,  15, sizeof(int),      &net2));\n\n        size_t global_work_size_1[] = { (size_t)net1*nth, (size_t)net2, (size_t)net3};\n        size_t local_work_size_1[] = { (size_t)nth, 1, 1};\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_1, local_work_size_1, dst);\n\n        kernel = backend_ctx->kernel_cumsum_add;\n        CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &tmp_buffer.buffer));\n        CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_mem),   &extrad->data_device));\n        CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_ulong), &offsetd));\n        CL_CHECK(clSetKernelArg(kernel,   3, sizeof(int),      &ne00));\n        CL_CHECK(clSetKernelArg(kernel,   4, sizeof(int),      &ne01));\n        CL_CHECK(clSetKernelArg(kernel,   5, sizeof(int),      &ne02));\n        CL_CHECK(clSetKernelArg(kernel,   6, sizeof(int),      &ne03));\n        
CL_CHECK(clSetKernelArg(kernel,   7, sizeof(int),      &nbt0));\n        CL_CHECK(clSetKernelArg(kernel,   8, sizeof(int),      &nbt1));\n        CL_CHECK(clSetKernelArg(kernel,   9, sizeof(int),      &nbt2));\n        CL_CHECK(clSetKernelArg(kernel,  10, sizeof(int),      &nbt3));\n\n        size_t global_work_size_2[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03};\n        size_t local_work_size_2[] = { (size_t)nth, 1, 1};\n        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_2, local_work_size_2, dst);\n    }\n}\n\nstatic void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->extra);\n    GGML_ASSERT(dst);\n    GGML_ASSERT(dst->extra);\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n\n    if (src1) {\n        GGML_ASSERT(src1);\n        GGML_ASSERT(src1->extra);\n        GGML_ASSERT(ggml_are_same_shape(src0, src1));\n    }\n\n    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;\n\n    cl_kernel kernel;\n    switch (ggml_get_glu_op(dst)) {\n        case GGML_GLU_OP_GEGLU:\n            if (dst->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_geglu;\n            } else {\n                kernel = backend_ctx->kernel_geglu_f16;\n            }\n            break;\n        case GGML_GLU_OP_REGLU:\n            if (dst->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_reglu;\n            } else {\n                kernel = backend_ctx->kernel_reglu_f16;\n            }\n            break;\n        case GGML_GLU_OP_SWIGLU:\n            if (dst->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_swiglu;\n            } else {\n                kernel = backend_ctx->kernel_swiglu_f16;\n            }\n            break;\n        case GGML_GLU_OP_SWIGLU_OAI:\n            kernel = backend_ctx->kernel_swiglu_oai;\n            break;\n       
 case GGML_GLU_OP_GEGLU_ERF:\n            if (dst->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_geglu_erf;\n            } else {\n                kernel = backend_ctx->kernel_geglu_erf_f16;\n            }\n            break;\n        case GGML_GLU_OP_GEGLU_QUICK:\n            if (dst->type == GGML_TYPE_F32) {\n                kernel = backend_ctx->kernel_geglu_quick;\n            } else {\n                kernel = backend_ctx->kernel_geglu_quick_f16;\n            }\n            break;\n        default:\n            GGML_ABORT(\"Unsupported glu op\");\n    }\n\n    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;\n    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;\n\n    ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;\n\n    cl_ulong offset0 = extra0->offset + src0->view_offs;\n    cl_ulong offsetd = extrad->offset + dst->view_offs;\n\n    cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;\n\n    const int ne0       = dst->ne[0];\n\n    const cl_ulong nb01 = src0->nb[1];\n    const cl_ulong nb11 = src1 ? src1->nb[1] : nb01;\n\n    const cl_ulong nb1  = dst->nb[1];\n\n    const int   swp   = ggml_get_op_params_i32(dst, 1);\n    const float alpha = ggml_get_op_params_f32(dst, 2);\n    const float limit = ggml_get_op_params_f32(dst, 3);\n\n    const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);\n    const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);\n\n    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));\n    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   src1 ? 
&extra1->data_device : &extra0->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));\n    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));\n    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));\n    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));\n    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb11));\n    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne0));\n    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb1));\n    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne00_off));\n    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10_off));\n\n    if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {\n        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit));\n        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha));\n    }\n\n    const size_t nrows = ggml_nrows(src0);\n    size_t nth = 512;\n    size_t global_work_size[] = {nrows*nth, 1, 1};\n    size_t local_work_size[] = {nth, 1, 1};\n\n    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);\n}\n\n//------------------------------------------------------------------------------\n// Op offloading\n//------------------------------------------------------------------------------\n\ntypedef void (*ggml_cl_func_t)(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);\n\nbool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor) {\n    ggml_cl_func_t func = nullptr;\n\n    ggml_tensor * src0 = tensor->src[0];\n    ggml_tensor * src1 = tensor->src[1];\n\n    const bool any_on_device = tensor->extra\n        || (src0 != nullptr && src0->extra)\n        || (src1 != nullptr && src1->extra);\n\n    switch (tensor->op) {\n        case GGML_OP_GET_ROWS:\n            if (!any_on_device) {\n                return false;\n            }\n            func = 
ggml_cl_get_rows;\n            break;\n        case GGML_OP_SET_ROWS:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_set_rows;\n            break;\n        case GGML_OP_CPY:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_cpy;\n            break;\n        case GGML_OP_SET:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_set;\n            break;\n        case GGML_OP_DUP:\n        case GGML_OP_CONT:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_dup;\n            break;\n        case GGML_OP_ADD:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_add;\n            break;\n        case GGML_OP_ADD_ID:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_add_id;\n            break;\n        case GGML_OP_MUL:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_mul;\n            break;\n        case GGML_OP_DIV:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_div;\n            break;\n        case GGML_OP_SUB:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_sub;\n            break;\n        case GGML_OP_SQR:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_sqr;\n            break;\n        case GGML_OP_SQRT:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_sqrt;\n            break;\n        case GGML_OP_MEAN:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_mean;\n 
           break;\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(tensor)) {\n                case GGML_UNARY_OP_GELU:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_gelu;\n                    break;\n                case GGML_UNARY_OP_GELU_ERF:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_gelu_erf;\n                    break;\n                case GGML_UNARY_OP_GELU_QUICK:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_gelu_quick;\n                    break;\n                case GGML_UNARY_OP_SILU:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_silu;\n                    break;\n                case GGML_UNARY_OP_RELU:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_relu;\n                    break;\n                case GGML_UNARY_OP_SIGMOID:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_sigmoid;\n                    break;\n                case GGML_UNARY_OP_TANH:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_tanh;\n                    break;\n                case GGML_UNARY_OP_NEG:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_neg;\n                    break;\n                case GGML_UNARY_OP_EXP:\n                    if (!any_on_device) {\n                        return false;\n                   
 }\n                    func = ggml_cl_exp;\n                    break;\n                case GGML_UNARY_OP_EXPM1:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_expm1;\n                    break;\n                case GGML_UNARY_OP_SOFTPLUS:\n                    if (!any_on_device) {\n                        return false;\n                    }\n                    func = ggml_cl_softplus;\n                    break;\n                default:\n                    return false;\n            } break;\n        case GGML_OP_GLU:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_glu;\n            break;\n        case GGML_OP_TRI:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_tri;\n            break;\n        case GGML_OP_FILL:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_fill;\n            break;\n        case GGML_OP_CLAMP:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_clamp;\n            break;\n        case GGML_OP_NORM:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_norm;\n            break;\n        case GGML_OP_RMS_NORM:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_rms_norm;\n            break;\n        case GGML_OP_L2_NORM:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_l2_norm;\n            break;\n        case GGML_OP_GROUP_NORM:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_group_norm;\n            break;\n                case GGML_OP_REPEAT:\n             if (!any_on_device) 
{\n                return false;\n            }\n            func = ggml_cl_repeat;\n            break;\n        case GGML_OP_PAD:\n            if (!any_on_device) {\n                return false;\n            }\n            ggml_cl_pad(backend, tensor->src[0], tensor);\n            return true;\n        case GGML_OP_UPSCALE:\n            if (!any_on_device) {\n                return false;\n            }\n            ggml_cl_upscale(backend, tensor->src[0], tensor);\n            return true;\n        case GGML_OP_CONV_2D:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_conv_2d;\n            break;\n        case GGML_OP_SSM_CONV:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_ssm_conv;\n            break;\n        case GGML_OP_CONCAT:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_concat;\n            break;\n        case GGML_OP_TIMESTEP_EMBEDDING:\n            if (!any_on_device) {\n                return false;\n            }\n            ggml_cl_timestep_embedding(backend, tensor->src[0], tensor);\n            return true;\n        case GGML_OP_MUL_MAT:\n            if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {\n                return false;\n            }\n            func = ggml_cl_mul_mat;\n            break;\n        case GGML_OP_MUL_MAT_ID:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_mul_mat_id;\n            break;\n        case GGML_OP_SCALE:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_scale;\n            break;\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n            if (!any_on_device) {\n                return false;\n         
   }\n            func = ggml_cl_nop;\n            break;\n        case GGML_OP_DIAG:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_diag;\n            break;\n        case GGML_OP_DIAG_MASK_INF:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_diag_mask_inf;\n            break;\n        case GGML_OP_SOFT_MAX:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_soft_max;\n            break;\n        case GGML_OP_ROPE:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_rope;\n            break;\n        case GGML_OP_SOLVE_TRI:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_solve_tri;\n            break;\n        case GGML_OP_IM2COL:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_im2col;\n            break;\n        case GGML_OP_ARGSORT:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_argsort;\n            break;\n        case GGML_OP_SUM_ROWS:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_sum_rows;\n            break;\n        case GGML_OP_CUMSUM:\n            if (!any_on_device) {\n                return false;\n            }\n            func = ggml_cl_cumsum;\n            break;\n        case GGML_OP_FLASH_ATTN_EXT:\n            if (!any_on_device) {\n                return false;\n            }\n            ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);\n            return true;\n        default:\n            return false;\n    }\n\n    func(backend, tensor->src[0], tensor->src[1], tensor);\n    return true;\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/add.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// add\n//------------------------------------------------------------------------------\n\n// general-purpose kernel for addition of two tensors\n// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3\n// cons: not very efficient\nkernel void kernel_add(\n        global char * src0,\n        ulong  offset0,\n        global char * src1,\n        ulong  offset1,\n        global char * dst,\n        ulong  offsetd,\n        int   ne00,\n        int   ne01,\n        int   ne02,\n        int   ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int   ne10,\n        int   ne11,\n        int   ne12,\n        int   ne13,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int   ne0,\n        int   ne1,\n        int   ne2,\n        int   ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst = dst + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03 % ne13;\n    int i12 = i02 % ne12;\n    int i11 = i01 % ne11;\n\n    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;\n    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;\n    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const int i10 = i0 % ne10;\n        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10));\n    }\n}\n\n// assumption: src1 is a row\n// broadcast src1 into src0\nkernel void kernel_add_row(\n        global float4 * src0,\n        ulong  offset0,\n        global float4 * src1,\n   
     ulong  offset1,\n        global float4 * dst,\n        ulong  offsetd,\n        int ne\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst = (global float4*)((global char*)dst + offsetd);\n\n    // This performs better than using %.\n    uint gid = get_global_id(0);\n    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne\n    dst[gid] = src0[gid] + src1[idx1];\n}\n\nkernel void kernel_add_f16(\n        global char * src0,\n        ulong  offset0,\n        global char * src1,\n        ulong  offset1,\n        global char * dst,\n        ulong  offsetd,\n        int   ne00,\n        int   ne01,\n        int   ne02,\n        int   ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int   ne10,\n        int   ne11,\n        int   ne12,\n        int   ne13,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int   ne0,\n        int   ne1,\n        int   ne2,\n        int   ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        int type_src0,\n        int type_src1\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst = dst + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03 % ne13;\n    int i12 = i02 % ne12;\n    int i11 = i01 % ne11;\n\n    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;\n    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;\n    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const int i10 = i0 % ne10;\n\n        half v0, v1;\n        if (type_src0 == 1) {\n            v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));\n        } else {\n            v0 = *((global half *)(src0_ptr + i0*nb00));\n        }\n\n  
      if (type_src1 == 1) {\n            v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));\n        } else {\n            v1 = *((global half *)(src1_ptr + i10*nb10));\n        }\n\n        *((global half *)(dst_ptr + i0*nb0)) = v0 + v1;\n    }\n}\n\nkernel void kernel_add_row_f16(\n        global char * src0,\n        ulong  offset0,\n        global char * src1,\n        ulong  offset1,\n        global half4 * dst,\n        ulong  offsetd,\n        int ne,\n        int type_src0,\n        int type_src1\n) {\n    dst = (global half4*)((global char*)dst + offsetd);\n\n    // This performs better than using %.\n    uint gid = get_global_id(0);\n    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne\n\n    half4 v0, v1;\n    if (type_src0 == 1) {\n        global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);\n        v0 = convert_half4(src0_f32[gid]);\n    } else {\n        global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);\n        v0 = src0_f16[gid];\n    }\n\n    if (type_src1 == 1) {\n        global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);\n        v1 = convert_half4(src1_f32[idx1]);\n    } else {\n        global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);\n        v1 = src1_f16[idx1];\n    }\n\n    dst[gid] = v0 + v1;\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/add_id.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// add_id\n//------------------------------------------------------------------------------\nkernel void kernel_add_id(\n    global char * src0,\n    ulong         offset0,\n    global char * src1,\n    ulong         offset1,\n    global char * src2,\n    ulong         offset2,\n    global char * dst,\n    ulong         offsetd,\n    ulong         nb01,\n    ulong         nb02,\n    ulong         nb11,\n    ulong         nb21,\n    int           ne0,\n    int           ne1\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    src2 = (global char*)((global char*)src2 + offset2);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    int i1 = get_group_id(0);\n    int i2 = get_group_id(1);\n\n    const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21));\n\n    const size_t nb1 = ne0 * sizeof(float);\n    const size_t nb2 = ne1 * nb1;\n\n    global float * dst_row  = (global float *)((global char *)dst  + i1*nb1 + i2*nb2);\n    global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02);\n    global float * src1_row = (global float *)((global char *)src1 + i11*nb11);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        dst_row[i0] = src0_row[i0] + src1_row[i0];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/argsort.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define SWAP(x, y, T) { T tmp = (x); (x) = (y); (y) = tmp; }\n\nenum ggml_sort_order {\n    GGML_SORT_ORDER_ASC,\n    GGML_SORT_ORDER_DESC,\n};\n\n// Bitonic argsort of one f32 row per work-group (row = get_group_id(1)).\n// Each work-item owns one slot of the padded local index array dst_row.\n// Assumes (host side, not visible here): ne00_pad is a power of two >= ne00\n// and dst_row holds at least ne00_pad ints. Padding indices (>= ne00) are\n// ordered after all real indices and are dropped when copying out to dst.\nkernel void kernel_argsort_f32_i32(\n    global float * src0,\n    ulong          offset0,\n    global int   * dst,\n    ulong          offsetd,\n    const int      ne00,\n    const int      ne00_pad,\n    const int      order,\n    local int    * dst_row\n) {\n    // bitonic sort\n    int col = get_local_id(0);\n    int row = get_group_id(1);\n\n    // Work-items beyond the padded length must NOT return early: every\n    // work-item of the group has to reach each barrier below, otherwise the\n    // behavior is undefined. They stay alive and just skip local-memory work.\n    const bool active = col < ne00_pad;\n\n    // Apply byte offsets, keeping each pointer's declared element type\n    // (src0 is the f32 input row data, dst receives i32 indices).\n    src0 = (global float *)((global char *)src0 + offset0);\n    dst  = (global int   *)((global char *)dst  + offsetd);\n\n    global float * x_row = src0 + row * ne00;\n\n    // initialize indices\n    if (active) {\n        dst_row[col] = col;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    for (int k = 2; k <= ne00_pad; k *= 2) {\n        for (int j = k / 2; j > 0; j /= 2) {\n            int ixj = col ^ j;\n            if (active && ixj > col) {\n                if ((col & k) == 0) {\n                    if (dst_row[col] >= ne00 ||\n                        (dst_row[ixj] < ne00 && (order == GGML_SORT_ORDER_ASC ?\n                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :\n                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))\n                    ) {\n                        SWAP(dst_row[col], dst_row[ixj], int);\n                    }\n                } else {\n                    if (dst_row[ixj] >= ne00 ||\n                        (dst_row[col] < ne00 && (order == GGML_SORT_ORDER_ASC ?\n                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :\n                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))\n                    ) {\n                        SWAP(dst_row[col], dst_row[ixj], int);\n                    }\n                }\n            }\n            // executed unconditionally by every work-item in the group\n            barrier(CLK_LOCAL_MEM_FENCE);\n        }\n    }\n\n    // copy the result to dst without the padding\n    if (col < ne00) {\n        dst[row * ne00 + col] = dst_row[col];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/clamp.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// clamp\n//------------------------------------------------------------------------------\nkernel void kernel_clamp(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        float min,\n        float max\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = src0[get_global_id(0)] < min ?\n        min :\n        (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/concat.cl",
    "content": "kernel void kernel_concat_f32(\n    global  const char * src0,\n    ulong                offset0,\n    global  const char * src1,\n    ulong                offset1,\n    global        char * dst,\n    ulong                offsetd,\n    int             ne00,\n    int             ne01,\n    int             ne02,\n    int             ne03,\n    ulong           nb00,\n    ulong           nb01,\n    ulong           nb02,\n    ulong           nb03,\n    ulong           nb10,\n    ulong           nb11,\n    ulong           nb12,\n    ulong           nb13,\n    int             ne0,\n    ulong           nb0,\n    ulong           nb1,\n    ulong           nb2,\n    ulong           nb3,\n    int             dim\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    int o[4] = {0, 0, 0, 0};\n    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));\n\n    global const float * x;\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {\n            x = (global const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);\n        } else {\n            x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);\n        }\n\n        global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);\n\n        *y = *x;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/conv2d.cl",
    "content": "#ifdef USE_FP16\n#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#define T_FLOAT half\n#define T_FLOAT4 half4\n#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)\n#else\n#define T_FLOAT float\n#define T_FLOAT4 float4\n#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)\n#endif\n\n#if defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#else\n#define REQD_SUBGROUP_SIZE_128\n#endif\n\n#define T_ACCUM float4\n#define VEC_SIZE 4\n\n#define BS_K 64\n#define BS_NPQ 64\n#define BS_CRS 16\n\n#define TS_K 4\n#define TS_NPQ 8\n\n#define WG_K (BS_K / TS_K)\n#define WG_NPQ (BS_NPQ / TS_NPQ)\n\n#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)\n#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)\n\nstatic inline uint splitWork(uint work_size, uint block_size){\n    return (work_size + block_size - 1) / block_size;\n}\n\nREQD_SUBGROUP_SIZE_128\nkernel void kernel_conv_2d(\n    global void* p_knl,\n    ulong off_knl,\n    global void* p_src,\n    ulong off_src,\n    global void* p_dst,\n    ulong off_dst,\n    local void* shared,\n    uint Cout, uint Cin, uint N,\n    uint KW, uint KH, uint W, uint H, uint OW, uint OH,\n    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,\n    uint nb01, uint nb02, uint nb03,\n    uint nb11, uint nb12, uint nb13,\n    uint nb1, uint nb2, uint nb3\n) {\n    global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);\n    global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);\n    global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);\n\n    const uint K = Cout;\n    const uint CRS = Cin*KH*KW;\n    const uint NPQ = N*OH*OW;\n\n    const uint lid_k = get_local_id(0);\n    const uint lid_npq = get_local_id(1);\n    const uint tid = lid_npq * WG_K + lid_k;\n\n    const uint B_idx_K = get_group_id(0);\n    const uint 
B_idx_NPQ = get_group_id(1);\n\n    const uint offset_k = B_idx_K * BS_K;\n    const uint offset_npq = B_idx_NPQ * BS_NPQ;\n\n    local T_FLOAT* Ash = (local T_FLOAT*)shared;\n    local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];\n\n    T_ACCUM regC[TS_K][TS_NPQ_VEC];\n    for (int i = 0; i < TS_K; ++i) {\n        for (int j = 0; j < TS_NPQ_VEC; ++j) {\n            regC[i][j] = (T_ACCUM)(0.0f);\n        }\n    }\n\n    const uint NB_CRS = splitWork(CRS, BS_CRS);\n\n    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {\n        const uint offset_crs = B_idx_CRS * BS_CRS;\n\n        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {\n            const uint k_l = i / BS_CRS;\n            const uint crs_l = i % BS_CRS;\n            const uint k_g = offset_k + k_l;\n            const uint crs_g = offset_crs + crs_l;\n\n            if (k_g < K && crs_g < CRS) {\n                const uint Cin_idx = crs_g / (KW*KH);\n                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;\n                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;\n                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;\n                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];\n            } else {\n                Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;\n            }\n        }\n\n        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {\n            const uint crs_l = i / BS_NPQ_VEC;\n            const uint npq_l_vec = i % BS_NPQ_VEC;\n            const uint crs_g = offset_crs + crs_l;\n\n            T_FLOAT4 val = (T_FLOAT4)(0.0f);\n            if (crs_g < CRS) {\n                const uint Cin_idx = crs_g / (KW * KH);\n                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;\n                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;\n                for (int v = 0; v < VEC_SIZE; ++v) {\n                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + 
v;\n                    if (npq_g < NPQ) {\n                        const uint N_idx = npq_g / (OH * OW);\n                        const uint pq_idx = npq_g % (OH * OW);\n                        const uint OH_idx = pq_idx / OW;\n                        const uint OW_idx = pq_idx % OW;\n                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);\n                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);\n\n                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {\n                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;\n                            ((T_FLOAT*)&val)[v] = src_data[src_idx];\n                        }\n                    }\n                }\n            }\n            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        #pragma unroll\n        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {\n            T_FLOAT regA[TS_K];\n            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {\n                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];\n            }\n\n            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {\n                T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];\n                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {\n                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {\n        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;\n        if (k_g >= K) continue;\n\n        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {\n            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * 
VEC_SIZE;\n\n            const uint N_idx = npq_g_base / (OH * OW);\n            const uint pq_idx = npq_g_base % (OH * OW);\n            const uint OH_idx = pq_idx / OW;\n            const uint OW_idx = pq_idx % OW;\n\n            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {\n                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;\n                VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);\n            } else {\n                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];\n                for (int v = 0; v < VEC_SIZE; ++v) {\n                    const uint npq_g = npq_g_base + v;\n                    if (npq_g < NPQ) {\n                        const uint N_idx_s = npq_g / (OH*OW);\n                        const uint pq_idx_s = npq_g % (OH*OW);\n                        const uint OH_idx_s = pq_idx_s / OW;\n                        const uint OW_idx_s = pq_idx_s % OW;\n                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;\n                        dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);\n                    }\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/conv2d_f16_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#if defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#else\n#define REQD_SUBGROUP_SIZE_128\n#endif\n\n#define T_ACCUM float4\n#define VEC_SIZE 4\n\n#define BS_K 64\n#define BS_NPQ 64\n#define BS_CRS 16\n\n#define TS_K 4\n#define TS_NPQ 8\n\n#define WG_K (BS_K / TS_K)\n#define WG_NPQ (BS_NPQ / TS_NPQ)\n\n#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)\n#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)\n\nstatic inline uint splitWork(uint work_size, uint block_size){\n    return (work_size + block_size - 1) / block_size;\n}\n\nREQD_SUBGROUP_SIZE_128\nkernel void kernel_conv_2d(\n    global void* p_knl,\n    ulong off_knl,\n    global void* p_src,\n    ulong off_src,\n    global void* p_dst,\n    ulong off_dst,\n    local void* shared,\n    uint Cout, uint Cin, uint N,\n    uint KW, uint KH, uint W, uint H, uint OW, uint OH,\n    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,\n    uint nb01, uint nb02, uint nb03,\n    uint nb11, uint nb12, uint nb13,\n    uint nb1, uint nb2, uint nb3\n) {\n    global half* knl_data = (global half*) ((global char*)p_knl + off_knl);\n    global float* src_data = (global float*) ((global char*)p_src + off_src);\n    global float* dst_data = (global float*) ((global char*)p_dst + off_dst);\n\n    const uint K = Cout;\n    const uint CRS = Cin*KH*KW;\n    const uint NPQ = N*OH*OW;\n\n    const uint lid_k = get_local_id(0);\n    const uint lid_npq = get_local_id(1);\n    const uint tid = lid_npq * WG_K + lid_k;\n\n    const uint B_idx_K = get_group_id(0);\n    const uint B_idx_NPQ = get_group_id(1);\n\n    const uint offset_k = B_idx_K * BS_K;\n    const uint offset_npq = B_idx_NPQ * BS_NPQ;\n\n    local half* Ash = (local half*)shared;\n    local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS];\n\n    T_ACCUM regC[TS_K][TS_NPQ_VEC];\n    for (int i = 0; 
i < TS_K; ++i) {\n        for (int j = 0; j < TS_NPQ_VEC; ++j) {\n            regC[i][j] = (T_ACCUM)(0.0f);\n        }\n    }\n\n    const uint NB_CRS = splitWork(CRS, BS_CRS);\n\n    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {\n        const uint offset_crs = B_idx_CRS * BS_CRS;\n\n        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {\n            const uint k_l = i / BS_CRS;\n            const uint crs_l = i % BS_CRS;\n            const uint k_g = offset_k + k_l;\n            const uint crs_g = offset_crs + crs_l;\n\n            if (k_g < K && crs_g < CRS) {\n                const uint Cin_idx = crs_g / (KW*KH);\n                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;\n                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;\n                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;\n                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];\n            } else {\n                Ash[k_l * BS_CRS + crs_l] = (half)0.0f;\n            }\n        }\n\n        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {\n            const uint crs_l = i / BS_NPQ_VEC;\n            const uint npq_l_vec = i % BS_NPQ_VEC;\n            const uint crs_g = offset_crs + crs_l;\n\n            float4 val = (float4)(0.0f);\n            if (crs_g < CRS) {\n                const uint Cin_idx = crs_g / (KW * KH);\n                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;\n                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;\n                for (int v = 0; v < VEC_SIZE; ++v) {\n                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;\n                    if (npq_g < NPQ) {\n                        const uint N_idx = npq_g / (OH * OW);\n                        const uint pq_idx = npq_g % (OH * OW);\n                        const uint OH_idx = pq_idx / OW;\n                        const uint OW_idx = pq_idx % OW;\n                        
const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);\n                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);\n\n                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {\n                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;\n                            ((float*)&val)[v] = src_data[src_idx];\n                        }\n                    }\n                }\n            }\n            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        #pragma unroll\n        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {\n            half regA[TS_K];\n            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {\n                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];\n            }\n\n            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {\n                float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];\n                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {\n                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {\n        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;\n        if (k_g >= K) continue;\n\n        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {\n            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;\n\n            const uint N_idx = npq_g_base / (OH * OW);\n            const uint pq_idx = npq_g_base % (OH * OW);\n            const uint OH_idx = pq_idx / OW;\n            const uint OW_idx = pq_idx % OW;\n\n            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {\n                const uint 
dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;\n                vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);\n            } else {\n                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];\n                for (int v = 0; v < VEC_SIZE; ++v) {\n                    const uint npq_g = npq_g_base + v;\n                    if (npq_g < NPQ) {\n                        const uint N_idx_s = npq_g / (OH*OW);\n                        const uint pq_idx_s = npq_g % (OH*OW);\n                        const uint OH_idx_s = pq_idx_s / OW;\n                        const uint OW_idx_s = pq_idx_s % OW;\n                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;\n                        dst_data[dst_idx_s] = ((float*)&res)[v];\n                    }\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/cpy.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// cpy\n//------------------------------------------------------------------------------\n\nkernel void kernel_cpy_f16_f16(\n        global half * src0,\n        ulong offset0,\n        global half * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = (global half*)((global char*)src0 + offset0);\n    dst = (global half*)((global char*)dst + offsetd);\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;\n\n    int i3 = n / (ne2*ne1*ne0);\n    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);\n    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;\n    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);\n\n    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);\n\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);\n        dst_data[i00] = src[0];\n    }\n}\n\nkernel void kernel_cpy_f16_f32(\n        global half * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n\n    src0 = (global 
half*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;\n\n    int i3 = n / (ne2*ne1*ne0);\n    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);\n    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;\n    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);\n\n    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);\n\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);\n        dst_data[i00] = src[0];\n    }\n}\n\nkernel void kernel_cpy_f32_f16(\n        global float * src0,\n        ulong offset0,\n        global half * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global half*)((global char*)dst + offsetd);\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;\n\n    int i3 = n / (ne2*ne1*ne0);\n    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);\n    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;\n    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);\n\n    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);\n\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        global const float * src = (global float *)((global char *) src0 + i03*nb03 + 
i02*nb02 + i01*nb01 + i00*nb00);\n\n        dst_data[i00] = src[0];\n    }\n}\n\nkernel void kernel_cpy_f32_f32(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;\n\n    int i3 = n / (ne2*ne1*ne0);\n    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);\n    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;\n    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);\n\n    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);\n\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);\n\n        dst_data[i00] = src[0];\n    }\n}\n\nkernel void kernel_cpy_i32_i32(\n        global int * src0,\n        ulong offset0,\n        global int * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = (global int*)((global char*)src0 + offset0);\n    dst = (global int*)((global char*)dst + offsetd);\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    
int i01 = get_group_id(0);\n\n    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;\n\n    int i3 = n / (ne2*ne1*ne0);\n    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);\n    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;\n    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);\n\n    global int * dst_data = (global int *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);\n\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        global const int * src = (global int *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);\n\n        dst_data[i00] = src[0];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/cumsum.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n// max workgroup size is usually 1024, this covers various subgroups sizes\n#define MAX_SUBGROUPS 128\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_32\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_cumsum_blk(\n        global char * src0,\n        ulong offset0,\n        global char * tmp,\n        global char * dst,\n        ulong offsetd,\n        int   ne00,\n        int   ne01,\n        int   ne02,\n        int   ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        uint net0,\n        uint net1,\n        uint net2\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    const int nth = get_local_size(0);\n    const int tid = get_local_id(0);\n\n    const uint sg_size = get_sub_group_size();\n    const uint sg_id = get_sub_group_id();\n    const uint sg_lid = get_sub_group_local_id();\n\n    const int ib = i1 / ne01;\n    const int i00 = ib * nth;\n    const int i01 = i1 % ne01;\n    const int i02 = i2;\n    const int i03 = i3;\n\n    global const float * src0_row = (global const float *)(src0 + i03*nb03 + i02*nb02 + i01*nb01);\n    global       float * tmp_row  = (global float *)tmp + net0 * i01 + net0 * net1 * i02 + net0 * 
net1 * net2 * i03;\n    global       float * dst_row  = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;\n\n    __local float partial[MAX_SUBGROUPS];\n\n    float v = 0.0f;\n    if (i00 + tid < ne00) {\n        v = src0_row[i00 + tid];\n    }\n\n    float s = sub_group_scan_inclusive_add(v);\n    if (sg_lid == sg_size - 1) {\n        partial[sg_id] = s;\n    }\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    // NB: subgroup size should be larger than number of subgroups\n    // assuming max workgroup size of 1024, subgroup size should be >= 32\n    if (sg_id == 0) {\n        float x = 0.0f;\n        if (sg_lid < get_num_sub_groups()) {\n            x = partial[sg_lid];\n        }\n        float ex = sub_group_scan_exclusive_add(x);\n        if (sg_lid < get_num_sub_groups()) {\n            partial[sg_lid] = ex;\n        }\n    }\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    s += partial[sg_id];\n\n    if (i00 + tid < ne00) {\n        dst_row[i00 + tid] = s;\n    }\n    if (ne00 > nth && tid == nth - 1) {\n        tmp_row[ib] = s;\n    }\n}\n\nkernel void kernel_cumsum_add(\n        global char * tmp,\n        global char * dst,\n        ulong offsetd,\n        int   ne00,\n        int   ne01,\n        int   ne02,\n        int   ne03,\n        uint nbt0,\n        uint nbt1,\n        uint nbt2,\n        uint nbt3\n) {\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    const int nth = get_local_size(0);\n    const int tid = get_local_id(0);\n\n    const int ib = i1 / ne01;\n    if (ib == 0) {\n        return;\n    }\n    const int i00 = ib * nth;\n    const int i01 = i1 % ne01;\n    const int i02 = i2;\n    const int i03 = i3;\n\n    global float * tmp_row  = (global float *)(tmp + nbt1 * i01 + nbt2 * i02 + nbt3 * i03);\n    global float * dst_row  = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;\n\n    if (i00 + tid < ne00) {\n        
dst_row[i00 + tid] += tmp_row[ib - 1];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/cvt.cl",
    "content": "//------------------------------------------------------------------------------\n// This file contains kernels for data conversion.\n// These kernels are used when loading the model, so their performance is less\n// important.\n//------------------------------------------------------------------------------\n#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_0                   32\n#define QR4_0                   2\n#define QK4_1                   32\n#define QR4_1                   2\n#define QK5_0                   32\n#define QR5_0                   2\n#define QK5_1                   32\n#define QR5_1                   2\n#define QK8_0                   32\n#define QR8_0                   1\n#define QK_K                    256\n#define K_QUANTS_PER_ITERATION  2\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n//------------------------------------------------------------------------------\n// block_q4_0\n//------------------------------------------------------------------------------\nstruct block_q4_0\n{\n    half d;\n    uint8_t qs[QK4_0 / 2];\n};\n\n//------------------------------------------------------------------------------\n// block_q4_1\n//------------------------------------------------------------------------------\nstruct block_q4_1 {\n    half 
d; // delta\n    half m; // min\n    uchar qs[QK4_1 / 2]; // nibbles / quants\n};\n\n//------------------------------------------------------------------------------\n// block_q6_K\n//------------------------------------------------------------------------------\nstruct block_q6_K {\n    uint8_t ql[QK_K/2];      // quants, lower 4 bits\n    uint8_t qh[QK_K/4];      // quants, upper 2 bits\n    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits\n    half d;                  // super-block scale\n};\n\n//------------------------------------------------------------------------------\n// kernel_convert_block_q4_0\n// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).\n// This kernel does not deshuffle the bits.\n//------------------------------------------------------------------------------\nkernel void kernel_convert_block_q4_0(\n    global struct block_q4_0 * src0,\n    global uchar * dst_q,\n    global half  * dst_d\n) {\n    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);\n    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);\n    global half  * d = (global half *) dst_d + get_global_id(0);\n\n    *d = b->d;\n\n    for (int i = 0; i < QK4_0/2; ++i) {\n        q[i] = b->qs[i];\n    }\n}\n\nkernel void kernel_restore_block_q4_0(\n    global uchar * src_q,\n    global half  * src_d,\n    global struct block_q4_0 * dst\n) {\n    global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);\n    global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);\n    global half  * d = (global half *) src_d + get_global_id(0);\n\n    b->d = *d;\n    for (int i = 0; i < QK4_0/2; ++i) {\n        b->qs[i] = q[i];\n    }\n}\n\n//------------------------------------------------------------------------------\n// kernel_convert_block_q4_0_noshuffle\n// Flatten q4_0 weights and unshuffle the 
bits\n//------------------------------------------------------------------------------\n\nkernel void kernel_convert_block_q4_0_noshuffle(\n    global struct block_q4_0 * src0,\n    global uchar * dst_q,\n    global half  * dst_d\n) {\n    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);\n    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);\n    global half  * d = (global half *) dst_d + get_global_id(0);\n\n    *d = b->d;\n    for (int i = 0; i < QK4_0/4; ++i) {\n        uchar x0 = b->qs[2*i + 0];\n        uchar x1 = b->qs[2*i + 1];\n\n        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);\n        q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);\n\n#ifdef ADRENO_GPU\n        // Workaround for adreno - must have the following printf statement for\n        // the kernel to work properly. Otherwise it produces incorrect result.\n        // convert_uchar above also seems necessary.\n        // Compare against a large number so that it does not print anything.\n        // get_sub_group_local_id() also works.\n        if (get_global_id(0) == 65536*4096) {\n            printf(\"%04x - %02x\\n\", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));\n        }\n#endif\n    }\n}\n\nkernel void kernel_restore_block_q4_0_noshuffle(\n    global uchar * src_q,\n    global half  * src_d,\n    global struct block_q4_0 * dst,\n    uchar mask_0F,\n    uchar mask_F0\n) {\n    global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);\n    global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);\n    global half  * d = (global half *) src_d + get_global_id(0);\n\n    b->d = *d;\n    for (int i = 0; i < QK4_0/4; ++i) {\n        uchar x0 = q[i + 0      ] ;\n        uchar x1 = q[i + QK4_0/4];\n\n        b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));\n        b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 
4) | (x1 & mask_F0));\n    }\n}\n\n//------------------------------------------------------------------------------\n// kernel_convert_block_q4_1\n// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).\n// This kernel does not deshuffle the bits.\n//------------------------------------------------------------------------------\nkernel void kernel_convert_block_q4_1(\n    global struct block_q4_1 * src0,\n    global uchar * dst_q,\n    global half  * dst_d,\n    global half  * dst_m\n) {\n    global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);\n    global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);\n    global half  * d = (global half *) dst_d + get_global_id(0);\n    global half  * m = (global half *) dst_m + get_global_id(0);\n\n    *d = b->d;\n    *m = b->m;\n\n    for (int i = 0; i < QK4_1/2; ++i) {\n        q[i] = b->qs[i];\n    }\n}\n\nkernel void kernel_restore_block_q4_1(\n    global uchar * src_q,\n    global half  * src_d,\n    global half  * src_m,\n    global struct block_q4_1 * dst\n) {\n    global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);\n    global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);\n    global half  * d = (global half *) src_d + get_global_id(0);\n    global half  * m = (global half *) src_m + get_global_id(0);\n\n    b->d = *d;\n    b->m = *m;\n    for (int i = 0; i < QK4_1/2; ++i) {\n        b->qs[i] = q[i];\n    }\n}\n\nkernel void kernel_convert_block_q4_1_noshuffle(\n    global struct block_q4_1 * src0,\n    global uchar * dst_q,\n    global half  * dst_d,\n    global half  * dst_m\n) {\n    global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);\n    global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);\n    global half  * d = (global half *) dst_d + get_global_id(0);\n    global half  * m = (global half *) dst_m + get_global_id(0);\n\n    *d = b->d;\n    *m = b->m;\n    
for (int i = 0; i < QK4_1/4; ++i) {\n        uchar x0 = b->qs[2*i + 0];\n        uchar x1 = b->qs[2*i + 1];\n\n        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);\n        q[i + QK4_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);\n\n#ifdef ADRENO_GPU\n        if (get_global_id(0) == 65536*4096) {\n            printf(\"%04x - %02x\\n\", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));\n        }\n#endif\n    }\n}\n\nkernel void kernel_restore_block_q4_1_noshuffle(\n    global uchar * src_q,\n    global half  * src_d,\n    global half  * src_m,\n    global struct block_q4_1 * dst,\n    uchar mask_0F,\n    uchar mask_F0\n) {\n    global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);\n    global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);\n    global half  * d = (global half *) src_d + get_global_id(0);\n    global half  * m = (global half *) src_m + get_global_id(0);\n\n    b->d = *d;\n    b->m = *m;\n    for (int i = 0; i < QK4_1/4; ++i) {\n        uchar x0 = q[i + 0      ] ;\n        uchar x1 = q[i + QK4_1/4];\n\n        b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));\n        b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));\n    }\n}\n\n//------------------------------------------------------------------------------\n// block_mxfp4\n//------------------------------------------------------------------------------\n#define QK_MXFP4 32\nstruct block_mxfp4 {\n    uchar e; // E8M0\n    uchar qs[QK_MXFP4 / 2];\n};\n\n//------------------------------------------------------------------------------\n// kernel_convert_block_mxfp4\n// Convert the block_mxfp4 format to 2 separate arrays (AOS -> SOA).\n// This kernel does not deshuffle the bits.\n//------------------------------------------------------------------------------\nkernel void kernel_convert_block_mxfp4(\n    global struct block_mxfp4 * src0,\n    global uchar * 
dst_q,\n    global uchar * dst_e\n) {\n    global struct block_mxfp4 * b = (global struct block_mxfp4 *) src0 + get_global_id(0);\n    global uchar * q = (global uchar *) dst_q + QK_MXFP4 / 2 * get_global_id(0);\n    global uchar * e = (global uchar *) dst_e + get_global_id(0);\n\n    *e = b->e;\n\n    for (int i = 0; i < QK_MXFP4 / 2; ++i) {\n        q[i] = b->qs[i];\n    }\n}\n\nkernel void kernel_convert_block_mxfp4_trans(\n    global struct block_mxfp4 * src0,\n    __global uint4 * dst_q,\n    __global uchar * dst_e,\n    uint ne00,\n    uint ne01\n) {\n    int i00 = get_global_id(1);\n    uint i01 = get_global_id(0);\n    uint i02 = get_global_id(2);\n\n    uint ne00_blk = ne00 / QK_MXFP4;\n    uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;\n    uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;\n\n    global struct block_mxfp4 * b = src0 + src_blk_offset;\n\n    dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0];\n    dst_e[dst_blk_offset] = b->e;\n}\n\nkernel void kernel_restore_block_mxfp4(\n    global uchar * src_q,\n    global half  * src_e,\n    global struct block_mxfp4 * dst\n) {\n    global struct block_mxfp4 * b = (global struct block_mxfp4 *) dst + get_global_id(0);\n    global uchar * q = (global uchar *) src_q + QK_MXFP4 / 2 * get_global_id(0);\n    global uchar * e = (global uchar *) src_e + get_global_id(0);\n\n    b->e = *e;\n    for (int i = 0; i < QK_MXFP4 / 2; ++i) {\n        b->qs[i] = q[i];\n    }\n}\n\nkernel void kernel_restore_block_mxfp4_trans(\n    __global uint4 * src_q,\n    __global uchar * src_e,\n    global struct block_mxfp4 * dst,\n    uint ne00,\n    uint ne01\n) {\n    int i00 = get_global_id(1);\n    uint i01 = get_global_id(0);\n    uint i02 = get_global_id(2);\n\n    uint ne00_blk = ne00 / QK_MXFP4;\n    uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;\n    uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;\n\n    global struct block_mxfp4 * b = dst 
+ dst_blk_offset;\n\n    ((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];\n    b->e = src_e[src_blk_offset];\n}\n\n//------------------------------------------------------------------------------\n// block_q8_0\n//------------------------------------------------------------------------------\ntypedef struct {\n    half d;       // delta\n    char qs[QK8_0]; // quants\n} block_q8_0;\n\nkernel void kernel_convert_block_q8_0(\n    global block_q8_0 * src0,\n    global uchar * dst_q,\n    global half  * dst_d\n) {\n    global block_q8_0 * b = (global block_q8_0 *) src0 + get_global_id(0);\n    global uchar      * q = (global uchar *) dst_q + QK8_0*get_global_id(0);\n    global half       * d = (global half *) dst_d + get_global_id(0);\n\n    *d = b->d;\n\n    for (int i = 0; i < QK8_0; ++i) {\n        q[i] = b->qs[i];\n    }\n}\n\nkernel void kernel_restore_block_q8_0(\n    global uchar * src_q,\n    global half  * src_d,\n    global block_q8_0 * dst\n) {\n    global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0);\n    global uchar      * q = (global uchar *) src_q + QK8_0*get_global_id(0);\n    global half       * d = (global half *) src_d + get_global_id(0);\n\n    b->d = *d;\n    for (int i = 0; i < QK8_0; ++i) {\n        b->qs[i] = q[i];\n    }\n}\n\nkernel void kernel_restore_block_q8_0_trans(\n    global uchar * src_q,\n    global half  * src_d,\n    global block_q8_0 * dst,\n    uint ne00,\n    uint ne01\n){\n    uint num_blk_per_row = ne00 / QK8_0;\n\n    global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0) * num_blk_per_row;\n    global uchar      * q = (global uchar *) src_q + get_global_id(0) * 4; // 4 8-bit packed\n    global half       * d = (global half *) src_d + get_global_id(0);\n\n    for (uint blk = 0; blk < num_blk_per_row; blk++) {\n        b->d = *d;\n\n        for (uint i = 0; i < QK8_0; i+=4) {\n            b->qs[i]   = q[0];\n            b->qs[i+1] = q[1];\n            b->qs[i+2] = q[2];\n            
b->qs[i+3] = q[3];\n\n            q += 4 * ne01; // M stride\n        }\n\n        d += ne01;\n\n        b++;\n    }\n}\n\n//------------------------------------------------------------------------------\n// kernel_convert_block_q6_K\n// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).\n// This kernel does not deshuffle the bits.\n// Each thread processes a super block.\n//------------------------------------------------------------------------------\nkernel void kernel_convert_block_q6_K(\n    global struct block_q6_K * src0,\n    global uchar * dst_ql,\n    global uchar * dst_qh,\n    global char  * dst_s,\n    global half  * dst_d\n) {\n    global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0);\n    global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);\n    global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);\n    global char  * s  = (global char  *) dst_s  + QK_K/16*get_global_id(0);\n    global half  * d  = (global half  *) dst_d  + get_global_id(0);\n\n    *d = b->d;\n\n    for (int i = 0; i < QK_K/2; ++i) {\n        ql[i] = b->ql[i];\n    }\n    for (int i = 0; i < QK_K/4; ++i) {\n        qh[i] = b->qh[i];\n    }\n    for (int i = 0; i < QK_K/16; ++i) {\n        s[i] = b->scales[i];\n    }\n}\n\n// Restore block_q6_K from flattened arrays.\n// Each thread processes a super block.\nkernel void kernel_restore_block_q6_K(\n    global uchar * dst_ql,\n    global uchar * dst_qh,\n    global char  * dst_s,\n    global half  * dst_d,\n    global struct block_q6_K * dst\n) {\n    global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0);\n    global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);\n    global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);\n    global char  * s  = (global char  *) dst_s  + QK_K/16*get_global_id(0);\n    global half  * d  = (global half  *) dst_d  + get_global_id(0);\n\n    b->d = *d;\n\n    for 
(int i = 0; i < QK_K/2; ++i) {\n        b->ql[i] = ql[i];\n    }\n    for (int i = 0; i < QK_K/4; ++i) {\n        b->qh[i] = qh[i];\n    }\n    for (int i = 0; i < QK_K/16; ++i) {\n        b->scales[i] = s[i];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/diag.cl",
    "content": "kernel void kernel_diag_f32(\n    global const char * src0,\n    ulong               offset0,\n    global       char * dst,\n    ulong               offsetd,\n    ulong               nb01,\n    ulong               nb02,\n    ulong               nb03,\n    int                 ne0,\n    ulong               nb0,\n    ulong               nb2,\n    ulong               nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    int i3 = get_group_id(2);\n    int i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    global const float * src0_ptr = (global const float *)(src0 +           i2*nb02 + i3*nb03);\n    global       float * dst_ptr  = (global       float *)(dst  + i1*nb01 + i2*nb2  + i3*nb3);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/diag_mask_inf.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// diag_mask_inf kernels\n//------------------------------------------------------------------------------\n\n// Copies src0 to dst and replaces every element whose column index i00 is\n// greater than n_past + i01 (its row index) with -INFINITY.\n// Indexing treats src0/dst as contiguous with ne00 columns and ne01 rows per\n// plane, one work-item per element: dims (0, 1, 2) = (i00, i01, i02).\nkernel void kernel_diag_mask_inf(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int n_past\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i02 = get_global_id(2);\n    int i01 = get_global_id(1);\n    int i00 = get_global_id(0);\n\n    if (i00 > n_past + i01) {\n        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;\n    } else {\n        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];\n    }\n}\n\n// Vectorized variant: each work-item copies two consecutive float4 (8 floats)\n// and masks the lanes whose column lies past the diagonal.\n// NOTE(review): the (i02, i01, i00) decomposition below is done once for the\n// first float, so it assumes the 8 floats never straddle a row (ne00 a\n// multiple of 8) - confirm on the host side.\nkernel void kernel_diag_mask_inf_8(\n        global float4 * src0,\n        ulong offset0,\n        global float4 * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int n_past\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst = (global float4*)((global char*)dst + offsetd);\n\n    int i = 2*get_global_id(0);\n\n    int i4 = 4*i;\n    int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;\n    int i01 = i4/(ne00);      i4 -= i01*ne00;\n    int i00 = i4;\n\n    // Lane c of dst[i] holds column i00 + c and lane c of dst[i+1] holds column\n    // i00 + 4 + c; a lane is masked when its column exceeds n_past + i01.\n    // select() takes -INFINITY in the lanes where the comparison is true. The\n    // mask must be applied per lane: indexing the float4 pointer instead, as in\n    // (&dst[i+1])[k], writes whole vectors dst[i+1+k] that belong to other\n    // work-items.\n    const int4 lim = (int4)(n_past + i01 - i00);\n    dst[i+0] = select(src0[i+0], (float4)(-INFINITY), (int4)(0, 1, 2, 3) > lim);\n    dst[i+1] = select(src0[i+1], (float4)(-INFINITY), (int4)(4, 5, 6, 7) > lim);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/div.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// div\n//------------------------------------------------------------------------------\n\n// dst = src0 / src1 (f32), all tensors addressed by byte strides.\n// One work-group per (i01, i02, i03) row; src1 is broadcast by reducing its\n// row indices modulo ne11/ne12/ne13 and its column index modulo ne10.\n// Work-items stride across the ne0 columns of the row.\nkernel void kernel_div(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * dst,\n        ulong offsetd,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        int ne13,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst + offsetd;\n\n    const int r1 = get_group_id(0);\n    const int r2 = get_group_id(1);\n    const int r3 = get_group_id(2);\n\n    global char * a_row = src0 + r3*nb03 + r2*nb02 + r1*nb01;\n    global char * b_row = src1 + (r3 % ne13)*nb13 + (r2 % ne12)*nb12 + (r1 % ne11)*nb11;\n    global char * d_row = dst  + r3*nb3  + r2*nb2  + r1*nb1;\n\n    for (int c = get_local_id(0); c < ne0; c += get_local_size(0)) {\n        const float num = *((global float *)(a_row + c*nb00));\n        const float den = *((global float *)(b_row + (c % ne10)*nb10));\n        *((global float *)(d_row + c*nb0)) = num / den;\n    }\n}\n\n// dst = src0 / src1 (float4-vectorized) where src1 is assumed to be a single\n// row of ne float4 values broadcast over all of src0.\nkernel void kernel_div_row(\n        global float4 * src0,\n        ulong offset0,\n        global float4 * src1,\n        ulong offset1,\n        global float4 * dst,\n        ulong offsetd,\n        int ne\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst  = (global float4*)((global char*)dst + offsetd);\n\n    const uint gid = get_global_id(0);\n    // gid % ne, written as a multiply-subtract because it performs better than %.\n    const uint col = gid - (gid/ne)*ne;\n    dst[gid] = src0[gid] / src1[col];\n}\n\n// Same as kernel_div, for f16 tensors.\nkernel void kernel_div_f16(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * dst,\n        ulong offsetd,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        int ne13,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst + offsetd;\n\n    const int r1 = get_group_id(0);\n    const int r2 = get_group_id(1);\n    const int r3 = get_group_id(2);\n\n    global char * a_row = src0 + r3*nb03 + r2*nb02 + r1*nb01;\n    global char * b_row = src1 + (r3 % ne13)*nb13 + (r2 % ne12)*nb12 + (r1 % ne11)*nb11;\n    global char * d_row = dst  + r3*nb3  + r2*nb2  + r1*nb1;\n\n    for (int c = get_local_id(0); c < ne0; c += get_local_size(0)) {\n        const half num = *((global half *)(a_row + c*nb00));\n        const half den = *((global half *)(b_row + (c % ne10)*nb10));\n        *((global half *)(d_row + c*nb0)) = num / den;\n    }\n}\n\n// Same as kernel_div_row, for half4 data.\nkernel void kernel_div_row_f16(\n        global half4 * src0,\n        ulong offset0,\n        global half4 * src1,\n        ulong offset1,\n        global half4 * dst,\n        ulong offsetd,\n        int ne\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    src1 = (global half4*)((global char*)src1 + offset1);\n    dst  = (global half4*)((global char*)dst + offsetd);\n\n    const uint gid = get_global_id(0);\n    // gid % ne, written as a multiply-subtract because it performs better than %.\n    const uint col = gid - (gid/ne)*ne;\n    dst[gid] = src0[gid] / src1[col];\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/embed_kernel.py",
    "content": "#\n\nimport sys\nimport logging\nlogger = logging.getLogger(\"opencl-embed-kernel\")\n\n\ndef main():\n    \"\"\"Wrap an OpenCL kernel source file in C++ raw string literals.\n\n    Usage: python embed_kernel.py <input_file> <output_file>\n\n    Every line of <input_file>, including its trailing newline, is written to\n    <output_file> as its own raw string literal R\"(<line>)\" on a separate line.\n    Exits with status 1 when the argument count is wrong.\n    \"\"\"\n    logging.basicConfig(level=logging.INFO)\n\n    if len(sys.argv) != 3:\n        logger.error(\"Usage: python embed_kernel.py <input_file> <output_file>\")\n        sys.exit(1)\n\n    # Use UTF-8 explicitly instead of the locale-dependent default encoding\n    # (e.g. cp1252 on Windows), which can fail on or mangle non-ASCII text in\n    # kernel comments. The with-blocks also close both files on error.\n    with open(sys.argv[1], \"r\", encoding=\"utf-8\") as ifile:\n        with open(sys.argv[2], \"w\", encoding=\"utf-8\") as ofile:\n            for line in ifile:\n                ofile.write('R\"({})\"\\n'.format(line))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/ggml-opencl/kernels/exp.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// exp: dst = exp(src0)\n//------------------------------------------------------------------------------\n\n// Contiguous kernels: one element (or vector) per work-item. Work-items whose\n// global id is >= n return without touching memory.\nkernel void kernel_exp_f32(\n        global const float * src0,\n        ulong                offset0,\n        global       float * dst,\n        ulong                offsetd,\n        int                  n\n) {\n    const size_t gid = get_global_id(0);\n    if (gid < n) {\n        global const float * x = (global const float *)((global const char *)src0 + offset0);\n        global       float * y = (global       float *)((global       char *)dst  + offsetd);\n        y[gid] = exp(x[gid]);\n    }\n}\n\nkernel void kernel_exp_f32_4(\n        global const float4 * src0,\n        ulong                 offset0,\n        global       float4 * dst,\n        ulong                 offsetd,\n        int                   n\n) {\n    const size_t gid = get_global_id(0);\n    if (gid < n) {\n        global const float4 * x = (global const float4 *)((global const char *)src0 + offset0);\n        global       float4 * y = (global       float4 *)((global       char *)dst  + offsetd);\n        y[gid] = exp(x[gid]);\n    }\n}\n\nkernel void kernel_exp_f16(\n        global const half * src0,\n        ulong               offset0,\n        global       half * dst,\n        ulong               offsetd,\n        int                 n\n) {\n    const size_t gid = get_global_id(0);\n    if (gid < n) {\n        global const half * x = (global const half *)((global const char *)src0 + offset0);\n        global       half * y = (global       half *)((global       char *)dst  + offsetd);\n        y[gid] = exp(x[gid]);\n    }\n}\n\nkernel void kernel_exp_f16_4(\n        global const half4 * src0,\n        ulong                offset0,\n        global       half4 * dst,\n        ulong                offsetd,\n        int                  n\n) {\n    const size_t gid = get_global_id(0);\n    if (gid < n) {\n        global const half4 * x = (global const half4 *)((global const char *)src0 + offset0);\n        global       half4 * y = (global       half4 *)((global       char *)dst  + offsetd);\n        y[gid] = exp(x[gid]);\n    }\n}\n\n// Non-contiguous kernels: one work-group per (i1, i2, i3) row, addressed by\n// byte strides; work-items stride across the ne00 elements of the row.\nkernel void kernel_exp_f32_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    const int i1 = get_group_id(0);\n    const int i2 = get_group_id(1);\n    const int i3 = get_group_id(2);\n\n    global const char * x_row = src0 + offset0 + i3*nb03 + i2*nb02 + i1*nb01;\n    global       char * y_row = dst  + offsetd + i3*nb3  + i2*nb2  + i1*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        const float v = *((global const float *)(x_row + i0*nb00));\n        *((global float *)(y_row + i0*nb0)) = exp(v);\n    }\n}\n\nkernel void kernel_exp_f16_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    const int i1 = get_group_id(0);\n    const int i2 = get_group_id(1);\n    const int i3 = get_group_id(2);\n\n    global const char * x_row = src0 + offset0 + i3*nb03 + i2*nb02 + i1*nb01;\n    global       char * y_row = dst  + offsetd + i3*nb3  + i2*nb2  + i1*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        const half v = *((global const half *)(x_row + i0*nb00));\n        *((global half *)(y_row + i0*nb0)) = exp(v);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/expm1.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// expm1: dst = exp(src0) - 1\n//------------------------------------------------------------------------------\n// All kernels use the expm1() built-in rather than exp(x) - 1: for x near zero\n// exp(x) rounds to (nearly) 1 and the subtraction cancels most or all of the\n// significant digits, whereas expm1() stays accurate.\n//\n// NOTE(review): the contiguous kernels take no element count and write\n// dst[get_global_id(0)] unconditionally, so the launch must not exceed one\n// work-item per element (or vector).\n\nkernel void kernel_expm1_f32(\n        global const float * src0,\n        ulong                offset0,\n        global       float * dst,\n        ulong                offsetd\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst  = (global float*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = expm1(src0[get_global_id(0)]);\n}\n\nkernel void kernel_expm1_f32_4(\n        global const float4 * src0,\n        ulong                 offset0,\n        global       float4 * dst,\n        ulong                 offsetd\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst  = (global float4*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = expm1(src0[get_global_id(0)]);\n}\n\nkernel void kernel_expm1_f16(\n        global const half * src0,\n        ulong               offset0,\n        global       half * dst,\n        ulong               offsetd\n) {\n    src0 = (global half*)((global char*)src0 + offset0);\n    dst  = (global half*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = expm1(src0[get_global_id(0)]);\n}\n\nkernel void kernel_expm1_f16_4(\n        global const half4 * src0,\n        ulong                offset0,\n        global       half4 * dst,\n        ulong                offsetd\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    dst  = (global half4*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = expm1(src0[get_global_id(0)]);\n}\n\n// Non-contiguous variants: one work-group per (i1, i2, i3) row, addressed by\n// byte strides; work-items stride across the ne00 elements of the row.\nkernel void kernel_expm1_f32_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n        *y = expm1(*x);\n    }\n}\n\nkernel void kernel_expm1_f16_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n        global       half * y = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n        *y = expm1(*x);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/fill.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// fill\n//------------------------------------------------------------------------------\n// Each work-item with global id < n stores v at that index of dst, where dst\n// is first advanced by offsetd bytes. Work-items at or past n do nothing.\nkernel void kernel_fill_f32(\n        global float * dst,\n        ulong          offsetd,\n        float          v,\n        int            n\n) {\n    const size_t gid = get_global_id(0);\n    if (gid >= n) {\n        return;\n    }\n\n    global float * out = (global float *)((global char *)dst + offsetd);\n    out[gid] = v;\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/flash_attn_f16.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define ACC_TYPE float\n#define ACC_TYPE4 float4\n#define DATA_TYPE half\n#define DATA_TYPE4 half4\n#define CONVERT_ACC4(x) convert_float4(x)\n#define CONVERT_DATA4(x) convert_half4(x)\n\n// DK, DV, BLOCK_M and BLOCK_N are not defined in this file; presumably they are\n// supplied as build options when the program is compiled - confirm on the host.\n#define DK_VEC (DK/4)\n#define DV_VEC (DV/4)\n#define WG_SIZE (BLOCK_M)\n#define Q1_WG_SIZE 64\n\n// ALiBi slope for head h; returns 1.0f when max_bias <= 0.\ninline float get_alibi_slope(\n    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1\n) {\n    if (max_bias <= 0.0f) {\n        return 1.0f;\n    }\n    const float base = h < n_head_log2 ? m0 : m1;\n    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n    return pow(base, exph);\n}\n\n// Tiled attention over BLOCK_M query rows per work-group (one row per\n// work-item, WG_SIZE == BLOCK_M) for the (batch, head) pair selected by\n// get_global_id(1). K/V are staged through local memory BLOCK_N rows at a time\n// and combined with an online softmax (running max m_i, running sum l_i).\n// The j loop consumes two K/V rows per step, so BLOCK_N is assumed even.\n__kernel void flash_attn_f16(\n    const global void * q_void, ulong q_offset,\n    const global void * k_void, ulong k_offset,\n    const global void * v_void, ulong v_offset,\n    global void * o_void, ulong o_offset,\n    const float scale,\n    const int n_q,\n    const int n_kv,\n    const int is_causal,\n    const int n_head,\n    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,\n    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,\n    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,\n    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,\n    const float max_bias,\n    const float m0,\n    const float m1,\n    const int n_head_log2,\n    const float logit_softcap,\n    const int n_head_kv,\n    const global void* mask_void,\n    const ulong mask_offset,\n    const ulong mask_nb1,\n    const ulong mask_nb2,\n    const ulong mask_nb3,\n    const int mask_ne2,\n    const int mask_ne3,\n    const global void* sinks_void,\n    const ulong sinks_offset\n) {\n    const int tid = get_local_id(0);\n    const int block_q_idx = get_group_id(0);\n    const int head_batch_idx = get_global_id(1);\n\n    const int my_query_row = block_q_idx * BLOCK_M + tid;\n\n    const int batch_idx = head_batch_idx / n_head;\n    const int head_idx = head_batch_idx % n_head;\n\n    const int gqa_ratio = n_head / n_head_kv;\n    const int head_kv_idx = head_idx / gqa_ratio;\n\n    const global char* q_base = (const global char*)q_void + q_offset;\n    const global char* k_base = (const global char*)k_void + k_offset;\n    const global char* v_base = (const global char*)v_void + v_offset;\n    global char* o_base = (global char*)o_void + o_offset;\n\n    const global char* mask_base = NULL;\n    if (mask_void != NULL) {\n        const int mask_head_idx = head_idx % mask_ne2;\n        const int mask_batch_idx = batch_idx % mask_ne3;\n        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;\n    }\n\n    ACC_TYPE4 q_priv[DK_VEC];\n    if (my_query_row < n_q) {\n        const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;\n        const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);\n        #pragma unroll\n        for (int i = 0; i < DK_VEC; ++i) {\n            q_priv[i] = CONVERT_ACC4(q_ptr[i]);\n        }\n    }\n\n    ACC_TYPE4 o_acc[DV_VEC];\n    #pragma unroll\n    for (int i = 0; i < DV_VEC; ++i) {\n        o_acc[i] = (ACC_TYPE4)(0.0f);\n    }\n    ACC_TYPE m_i = -INFINITY;\n    ACC_TYPE l_i = 0.0f;\n\n    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);\n\n    __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];\n    __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];\n\n    for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {\n        for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {\n            const int row = i / DK_VEC;\n            const int col = i % DK_VEC;\n            const int k_row_idx = k_start + row;\n            if (k_row_idx < n_kv) {\n                const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;\n                l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];\n            }\n        }\n        for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {\n            const int row = i / DV_VEC;\n            const int col = i % DV_VEC;\n            const int v_row_idx = k_start + row;\n            if (v_row_idx < n_kv) {\n                const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;\n                l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        // Work-items past n_q still help load the tile and must reach both\n        // barriers, so they skip only the math below.\n        if (my_query_row < n_q) {\n            for (int j = 0; j < BLOCK_N; j += 2) {\n                const int k_row0 = k_start + j;\n                const int k_row1 = k_start + j + 1;\n\n                ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);\n                ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);\n                #pragma unroll\n                for (int k = 0; k < DK_VEC; k++) {\n                    dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);\n                    dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);\n                }\n                ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;\n                ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;\n\n                if (is_causal) {\n                    if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;\n                    if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;\n                }\n\n                if (k_row0 >= n_kv) score0 = -INFINITY;\n                if (k_row1 >= n_kv) score1 = -INFINITY;\n\n                if (mask_base != NULL) {\n                    const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);\n                    if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];\n                    if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];\n                }\n\n                if (logit_softcap > 0.0f) {\n                    score0 = logit_softcap * tanh(score0 / logit_softcap);\n                    score1 = logit_softcap * tanh(score1 / logit_softcap);\n                }\n\n                const ACC_TYPE m_new = max(m_i, max(score0, score1));\n                // Nothing to accumulate while every score so far is -INFINITY\n                // (e.g. fully masked leading keys); exp(-INFINITY - -INFINITY)\n                // would be NaN and poison o_acc and l_i for the whole row.\n                if (m_new == -INFINITY) {\n                    continue;\n                }\n                const ACC_TYPE p0 = exp(score0 - m_new);\n                const ACC_TYPE p1 = exp(score1 - m_new);\n                const ACC_TYPE scale_prev = exp(m_i - m_new);\n\n                #pragma unroll\n                for (int i = 0; i < DV_VEC; ++i) {\n                    o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);\n                }\n                l_i = l_i * scale_prev + p0 + p1;\n                m_i = m_new;\n            }\n        }\n\n        // Every work-item must finish reading this tile of l_k/l_v before the\n        // next iteration overwrites it.\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    if (my_query_row < n_q) {\n        if (sinks_void != NULL) {\n            const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);\n            const ACC_TYPE m_sink = sinks_ptr[head_idx];\n            const ACC_TYPE m_final = max(m_i, m_sink);\n\n            const ACC_TYPE scale_o = exp(m_i - m_final);\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_acc[i] *= scale_o;\n            }\n\n            l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);\n        }\n\n        const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;\n        global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);\n        if (l_i > 0.0f) {\n            const ACC_TYPE l_inv = 1.0f / l_i;\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);\n            }\n        } else {\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_row[i] = (DATA_TYPE4)(0.0f);\n            }\n        }\n    }\n}\n\n// Single-query variant: reads query row 0 (and mask row 0) only. Q1_WG_SIZE\n// work-items split the n_kv keys and reduce, through local memory, first the\n// score max, then the softmax sum and finally each output vector.\n// NOTE(review): the reductions halve Q1_WG_SIZE, so it presumably must be a\n// power of two and equal the launched local size - confirm on the host.\n__kernel void flash_attn_f16_q1(\n    const global void * q_void, ulong q_offset,\n    const global void * k_void, ulong k_offset,\n    const global void * v_void, ulong v_offset,\n    global void * o_void, ulong o_offset,\n    const float scale,\n    const int n_q,\n    const int n_kv,\n    const int is_causal,\n    const int n_head,\n    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,\n    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,\n    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,\n    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,\n    const float max_bias,\n    const float m0,\n    const float m1,\n    const int n_head_log2,\n    const float logit_softcap,\n    const int n_head_kv,\n    const global void* mask_void,\n    const ulong mask_offset,\n    const ulong mask_nb1,\n    const ulong mask_nb2,\n    const ulong mask_nb3,\n    const int mask_ne2,\n    const int mask_ne3,\n    const global void* sinks_void,\n    const ulong sinks_offset\n) {\n    const int tid = get_local_id(0);\n    const int head_batch_idx = get_global_id(1);\n\n    const int batch_idx = head_batch_idx / n_head;\n    const int head_idx = head_batch_idx % n_head;\n\n    const int gqa_ratio = n_head / n_head_kv;\n    const int head_kv_idx = head_idx / gqa_ratio;\n\n    const global char* q_base = (const global char*)q_void + q_offset;\n    const global char* k_base = (const global char*)k_void + k_offset;\n    const global char* v_base = (const global char*)v_void + v_offset;\n    global char* o_base = (global char*)o_void + o_offset;\n\n    const global char* mask_base = NULL;\n    if (mask_void != NULL) {\n        const int mask_head_idx = head_idx % mask_ne2;\n        const int mask_batch_idx = batch_idx % mask_ne3;\n        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;\n    }\n\n    ACC_TYPE4 q_priv[DK_VEC];\n    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;\n    const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);\n    #pragma unroll\n    for (int i = 0; i < DK_VEC; ++i) {\n        q_priv[i] = CONVERT_ACC4(q_ptr[i]);\n    }\n\n    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);\n\n    const global ACC_TYPE* sinks_ptr = NULL;\n    if (sinks_void != NULL) {\n        sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);\n    }\n\n    ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;\n    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {\n        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;\n        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);\n        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);\n        #pragma unroll\n        for (int k = 0; k < DK_VEC; k++) {\n            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);\n        }\n        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;\n        if (mask_base != NULL) {\n            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);\n            score += slope * (ACC_TYPE)mask_ptr[k_idx];\n        }\n        if (logit_softcap > 0.0f) {\n            score = logit_softcap * tanh(score / logit_softcap);\n        }\n        m_i = max(m_i, score);\n    }\n\n    __local ACC_TYPE local_m[Q1_WG_SIZE];\n    local_m[tid] = m_i;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    #pragma unroll\n    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n    const ACC_TYPE m_final = local_m[0];\n\n    ACC_TYPE4 o_acc[DV_VEC];\n    #pragma unroll\n    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);\n    ACC_TYPE l_i = 0.0f;\n\n    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {\n        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;\n        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;\n        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);\n        const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);\n        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);\n        #pragma unroll\n        for (int k = 0; k < DK_VEC; k++) {\n            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);\n        }\n        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;\n        if (mask_base != NULL) {\n            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);\n            score += slope * (ACC_TYPE)mask_ptr[k_idx];\n        }\n        if (logit_softcap > 0.0f) {\n            score = logit_softcap * tanh(score / logit_softcap);\n        }\n        const ACC_TYPE p = exp(score - m_final);\n        l_i += p;\n        #pragma unroll\n        for (int i = 0; i < DV_VEC; i++) {\n            o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);\n        }\n    }\n\n    __local ACC_TYPE local_l[Q1_WG_SIZE];\n    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];\n    local_l[tid] = l_i;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    #pragma unroll\n    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) local_l[tid] += local_l[tid + s];\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;\n    global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);\n    ACC_TYPE l_final = local_l[0];\n\n    if (sinks_ptr != NULL) {\n        l_final += exp(sinks_ptr[head_idx] - m_final);\n    }\n\n    if (l_final > 0.0f) {\n        const ACC_TYPE l_inv = 1.0f / l_final;\n        for (int i = 0; i < DV_VEC; i++) {\n            local_o_comp[tid] = o_acc[i];\n            barrier(CLK_LOCAL_MEM_FENCE);\n            #pragma unroll\n            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];\n                barrier(CLK_LOCAL_MEM_FENCE);\n            }\n            if (tid == 0) {\n                o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);\n            }\n        }\n    } else if (tid == 0) {\n        #pragma unroll\n        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/flash_attn_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define ACC_TYPE float\n#define ACC_TYPE4 float4\n#define DATA_TYPE float\n#define DATA_TYPE4 float4\n#define MASK_DATA_TYPE half\n#define CONVERT_ACC4(x) (x)\n#define CONVERT_DATA4(x) (x)\n\n#define DK_VEC (DK/4)\n#define DV_VEC (DV/4)\n#define WG_SIZE (BLOCK_M)\n#define Q1_WG_SIZE 64\n\ninline float get_alibi_slope(\n    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1\n) {\n    if (max_bias <= 0.0f) {\n        return 1.0f;\n    }\n    const float base = h < n_head_log2 ? m0 : m1;\n    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n    return pow(base, exph);\n}\n__kernel void flash_attn_f32(\n    const global void * q_void, ulong q_offset,\n    const global void * k_void, ulong k_offset,\n    const global void * v_void, ulong v_offset,\n    global void * o_void, ulong o_offset,\n    const float scale,\n    const int n_q,\n    const int n_kv,\n    const int is_causal,\n    const int n_head,\n    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,\n    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,\n    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,\n    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,\n    const float max_bias,\n    const float m0,\n    const float m1,\n    const int n_head_log2,\n    const float logit_softcap,\n    const int n_head_kv,\n    const global void* mask_void,\n    const ulong mask_offset,\n    const ulong mask_nb1,\n    const ulong mask_nb2,\n    const ulong mask_nb3,\n    const int mask_ne2,\n    const int mask_ne3,\n    const global void* sinks_void,\n    const ulong sinks_offset\n) {\n    const int tid = get_local_id(0);\n    const int block_q_idx = get_group_id(0);\n    const int head_batch_idx = get_global_id(1);\n\n    const int my_query_row = block_q_idx * BLOCK_M + tid;\n\n    const int batch_idx = head_batch_idx / n_head;\n    const int head_idx = head_batch_idx 
% n_head;\n\n    const int gqa_ratio = n_head / n_head_kv;\n    const int head_kv_idx = head_idx / gqa_ratio;\n\n    const global char* q_base = (const global char*)q_void + q_offset;\n    const global char* k_base = (const global char*)k_void + k_offset;\n    const global char* v_base = (const global char*)v_void + v_offset;\n    global char* o_base = (global char*)o_void + o_offset;\n\n    const global char* mask_base = NULL;\n    if (mask_void != NULL) {\n        const int mask_head_idx = head_idx % mask_ne2;\n        const int mask_batch_idx = batch_idx % mask_ne3;\n        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;\n    }\n\n    ACC_TYPE4 q_priv[DK_VEC];\n    if (my_query_row < n_q) {\n        const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;\n        const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);\n        #pragma unroll\n        for (int i = 0; i < DK_VEC; ++i) {\n            q_priv[i] = CONVERT_ACC4(q_ptr[i]);\n        }\n    }\n\n    ACC_TYPE4 o_acc[DV_VEC];\n    #pragma unroll\n    for (int i = 0; i < DV_VEC; ++i) {\n        o_acc[i] = (ACC_TYPE4)(0.0f);\n    }\n    ACC_TYPE m_i = -INFINITY;\n    ACC_TYPE l_i = 0.0f;\n\n    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);\n\n    __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];\n    __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];\n\n    for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {\n        for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {\n            const int row = i / DK_VEC;\n            const int col = i % DK_VEC;\n            const int k_row_idx = k_start + row;\n            if (k_row_idx < n_kv) {\n                const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;\n                l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];\n            }\n        }\n        for (int i 
= tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {\n            const int row = i / DV_VEC;\n            const int col = i % DV_VEC;\n            const int v_row_idx = k_start + row;\n            if (v_row_idx < n_kv) {\n                const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;\n                l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        if (my_query_row >= n_q) {\n            continue;\n        }\n\n        for (int j = 0; j < BLOCK_N; j += 2) {\n            const int k_row0 = k_start + j;\n            const int k_row1 = k_start + j + 1;\n\n            ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);\n            ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);\n            #pragma unroll\n            for (int k = 0; k < DK_VEC; k++) {\n                dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);\n                dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);\n            }\n            ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;\n            ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;\n\n            if (is_causal) {\n                if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;\n                if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;\n            }\n\n            if (k_row0 >= n_kv) score0 = -INFINITY;\n            if (k_row1 >= n_kv) score1 = -INFINITY;\n\n            if (mask_base != NULL) {\n                const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);\n                if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];\n                if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];\n            }\n\n            if (logit_softcap > 0.0f) {\n                score0 = logit_softcap * tanh(score0 / 
logit_softcap);\n                score1 = logit_softcap * tanh(score1 / logit_softcap);\n            }\n\n            const ACC_TYPE m_new = max(m_i, max(score0, score1));\n            const ACC_TYPE p0 = exp(score0 - m_new);\n            const ACC_TYPE p1 = exp(score1 - m_new);\n            const ACC_TYPE scale_prev = exp(m_i - m_new);\n\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);\n            }\n            l_i = l_i * scale_prev + p0 + p1;\n            m_i = m_new;\n        }\n    }\n\n    if (my_query_row < n_q) {\n        if (sinks_void != NULL) {\n            const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);\n            const ACC_TYPE m_sink = sinks_ptr[head_idx];\n            const ACC_TYPE m_final = max(m_i, m_sink);\n\n            const ACC_TYPE scale_o = exp(m_i - m_final);\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_acc[i] *= scale_o;\n            }\n\n            l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);\n        }\n\n        const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;\n        global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);\n        if (l_i > 0.0f) {\n            const ACC_TYPE l_inv = 1.0f / l_i;\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);\n            }\n        } else {\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_row[i] = (DATA_TYPE4)(0.0f);\n            }\n        }\n    }\n}\n\n__kernel void flash_attn_f32_q1(\n    const global void * q_void, ulong q_offset,\n    const global void * k_void, ulong k_offset,\n    const global void * v_void, ulong v_offset,\n    global void * 
o_void, ulong o_offset,\n    const float scale,\n    const int n_q,\n    const int n_kv,\n    const int is_causal,\n    const int n_head,\n    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,\n    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,\n    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,\n    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,\n    const float max_bias,\n    const float m0,\n    const float m1,\n    const int n_head_log2,\n    const float logit_softcap,\n    const int n_head_kv,\n    const global void* mask_void,\n    const ulong mask_offset,\n    const ulong mask_nb1,\n    const ulong mask_nb2,\n    const ulong mask_nb3,\n    const int mask_ne2,\n    const int mask_ne3,\n    const global void* sinks_void,\n    const ulong sinks_offset\n) {\n    const int tid = get_local_id(0);\n    const int head_batch_idx = get_global_id(1);\n\n    const int batch_idx = head_batch_idx / n_head;\n    const int head_idx = head_batch_idx % n_head;\n\n    const int gqa_ratio = n_head / n_head_kv;\n    const int head_kv_idx = head_idx / gqa_ratio;\n\n    const global char* q_base = (const global char*)q_void + q_offset;\n    const global char* k_base = (const global char*)k_void + k_offset;\n    const global char* v_base = (const global char*)v_void + v_offset;\n    global char* o_base = (global char*)o_void + o_offset;\n\n    const global char* mask_base = NULL;\n    if (mask_void != NULL) {\n        const int mask_head_idx = head_idx % mask_ne2;\n        const int mask_batch_idx = batch_idx % mask_ne3;\n        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;\n    }\n\n    ACC_TYPE4 q_priv[DK_VEC];\n    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;\n    const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);\n    #pragma unroll\n    for (int i = 0; i < DK_VEC; ++i) {\n        q_priv[i] = CONVERT_ACC4(q_ptr[i]);\n    
}\n\n    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);\n\n    const global ACC_TYPE* sinks_ptr = NULL;\n    if (sinks_void != NULL) {\n        sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);\n    }\n\n    ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;\n    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {\n        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;\n        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);\n        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);\n        #pragma unroll\n        for (int k = 0; k < DK_VEC; k++) {\n            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);\n        }\n        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;\n        if (mask_base != NULL) {\n            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);\n            score += slope * (ACC_TYPE)mask_ptr[k_idx];\n        }\n        if (logit_softcap > 0.0f) {\n            score = logit_softcap * tanh(score / logit_softcap);\n        }\n        m_i = max(m_i, score);\n    }\n\n    __local ACC_TYPE local_m[Q1_WG_SIZE];\n    local_m[tid] = m_i;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    #pragma unroll\n    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n    const ACC_TYPE m_final = local_m[0];\n\n    ACC_TYPE4 o_acc[DV_VEC];\n    #pragma unroll\n    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);\n    ACC_TYPE l_i = 0.0f;\n\n    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {\n        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;\n        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;\n        const global DATA_TYPE4* 
k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);\n        const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);\n        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);\n        #pragma unroll\n        for (int k = 0; k < DK_VEC; k++) {\n            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);\n        }\n        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;\n        if (mask_base != NULL) {\n            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);\n            score += slope * (ACC_TYPE)mask_ptr[k_idx];\n        }\n        if (logit_softcap > 0.0f) {\n            score = logit_softcap * tanh(score / logit_softcap);\n        }\n        const ACC_TYPE p = exp(score - m_final);\n        l_i += p;\n        #pragma unroll\n        for (int i = 0; i < DV_VEC; i++) {\n            o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);\n        }\n    }\n\n    __local ACC_TYPE local_l[Q1_WG_SIZE];\n    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];\n    local_l[tid] = l_i;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    #pragma unroll\n    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) local_l[tid] += local_l[tid + s];\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;\n    global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);\n    ACC_TYPE l_final = local_l[0];\n\n    if (sinks_ptr != NULL) {\n        l_final += exp(sinks_ptr[head_idx] - m_final);\n    }\n\n    if (l_final > 0.0f) {\n        const ACC_TYPE l_inv = 1.0f / l_final;\n        for (int i = 0; i < DV_VEC; i++) {\n            local_o_comp[tid] = o_acc[i];\n            barrier(CLK_LOCAL_MEM_FENCE);\n            #pragma unroll\n            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];\n                barrier(CLK_LOCAL_MEM_FENCE);\n   
         }\n            if (tid == 0) {\n                o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);\n            }\n        }\n    } else if (tid == 0) {\n        #pragma unroll\n        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/flash_attn_f32_f16.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n// Flash attention with f32 Q/O, f16 K/V and an f16 mask; scores and the softmax\n// state are accumulated in f32 (ACC_TYPE). DK, DV, BLOCK_M and BLOCK_N are not\n// defined in this file and must be provided at build time.\n\n#define ACC_TYPE float\n#define ACC_TYPE4 float4\n#define Q_DATA_TYPE4 float4\n#define KV_DATA_TYPE4 half4\n#define O_DATA_TYPE4 float4\n#define MASK_DATA_TYPE half\n#define CONVERT_Q_ACC4(x) (x)\n#define CONVERT_KV_ACC4(x) convert_float4(x)\n#define CONVERT_O_DATA4(x) (x)\n\n#define DK_VEC (DK/4)\n#define DV_VEC (DV/4)\n#define WG_SIZE (BLOCK_M)\n#define Q1_WG_SIZE 64\n\n// ALiBi slope for head h; 1.0f when ALiBi is disabled (max_bias <= 0).\ninline float get_alibi_slope(\n    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1\n) {\n    if (max_bias <= 0.0f) {\n        return 1.0f;\n    }\n    const float base = h < n_head_log2 ? m0 : m1;\n    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n    return pow(base, exph);\n}\n\n// Tiled kernel: each work-item owns one query row (my_query_row) and streams K/V\n// through local memory BLOCK_N rows at a time, keeping an online softmax per row\n// (m_i: running max, l_i: running sum of exponentials, o_acc: unnormalized output).\n// Each score is scaled, softcapped, offset by the ALiBi-scaled mask, and finally\n// forced to -INFINITY for causal and out-of-range positions.\n// NOTE(review): two K/V rows are consumed per inner step, so BLOCK_N is assumed even.\n__kernel void flash_attn_f32_f16(\n    const global void * q_void, ulong q_offset,\n    const global void * k_void, ulong k_offset,\n    const global void * v_void, ulong v_offset,\n    global void * o_void, ulong o_offset,\n    const float scale,\n    const int n_q,\n    const int n_kv,\n    const int is_causal,\n    const int n_head,\n    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,\n    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,\n    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,\n    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,\n    const float max_bias,\n    const float m0,\n    const float m1,\n    const int n_head_log2,\n    const float logit_softcap,\n    const int n_head_kv,\n    const global void* mask_void,\n    const ulong mask_offset,\n    const ulong mask_nb1,\n    const ulong mask_nb2,\n    const ulong mask_nb3,\n    const int mask_ne2,\n    const int mask_ne3,\n    const global void* sinks_void,\n    const ulong sinks_offset\n) {\n    const int tid = get_local_id(0);\n    const int block_q_idx = get_group_id(0);\n    const int head_batch_idx = get_global_id(1);\n\n    const int my_query_row = block_q_idx * BLOCK_M + tid;\n\n    const int batch_idx = head_batch_idx / n_head;\n    const int head_idx = head_batch_idx % n_head;\n\n    const int gqa_ratio = n_head / n_head_kv;\n    const int head_kv_idx = head_idx / gqa_ratio;\n\n    const global char* q_base = (const global char*)q_void + q_offset;\n    const global char* k_base = (const global char*)k_void + k_offset;\n    const global char* v_base = (const global char*)v_void + v_offset;\n    global char* o_base = (global char*)o_void + o_offset;\n\n    const global char* mask_base = NULL;\n    if (mask_void != NULL) {\n        const int mask_head_idx = head_idx % mask_ne2;\n        const int mask_batch_idx = batch_idx % mask_ne3;\n        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;\n    }\n\n    ACC_TYPE4 q_priv[DK_VEC];\n    if (my_query_row < n_q) {\n        const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;\n        const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);\n        #pragma unroll\n        for (int i = 0; i < DK_VEC; ++i) {\n            q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);\n        }\n    }\n\n    ACC_TYPE4 o_acc[DV_VEC];\n    #pragma unroll\n    for (int i = 0; i < DV_VEC; ++i) {\n        o_acc[i] = (ACC_TYPE4)(0.0f);\n    }\n    ACC_TYPE m_i = -INFINITY;\n    ACC_TYPE l_i = 0.0f;\n\n    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);\n\n    __local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];\n    __local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];\n\n    for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {\n        // Cooperative tile load. Rows past n_kv are zero-filled so the tail tile never\n        // feeds stale or uninitialized local memory into o_acc (0 * NaN is NaN).\n        for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {\n            const int row = i / DK_VEC;\n            const int col = i % DK_VEC;\n            const int k_row_idx = k_start + row;\n            if (k_row_idx < n_kv) {\n                const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;\n                l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];\n            } else {\n                l_k[row][col] = (KV_DATA_TYPE4)(0.0f);\n            }\n        }\n        for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {\n            const int row = i / DV_VEC;\n            const int col = i % DV_VEC;\n            const int v_row_idx = k_start + row;\n            if (v_row_idx < n_kv) {\n                const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;\n                l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];\n            } else {\n                l_v[row][col] = (KV_DATA_TYPE4)(0.0f);\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        // Work-items without a query row skip the math but must still reach both barriers.\n        if (my_query_row < n_q) {\n            for (int j = 0; j < BLOCK_N; j += 2) {\n                const int k_row0 = k_start + j;\n                const int k_row1 = k_start + j + 1;\n\n                ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);\n                ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);\n                #pragma unroll\n                for (int k = 0; k < DK_VEC; k++) {\n                    dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);\n                    dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);\n                }\n                ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;\n                ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;\n\n                // Softcap the raw score before any masking: applied afterwards it would map\n                // -INFINITY to -logit_softcap and give masked positions non-zero weight.\n                if (logit_softcap > 0.0f) {\n                    score0 = logit_softcap * tanh(score0 / logit_softcap);\n                    score1 = logit_softcap * tanh(score1 / logit_softcap);\n                }\n\n                if (mask_base != NULL) {\n                    const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);\n                    if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];\n                    if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];\n                }\n\n                if (is_causal) {\n                    if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;\n                    if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;\n                }\n\n                if (k_row0 >= n_kv) score0 = -INFINITY;\n                if (k_row1 >= n_kv) score1 = -INFINITY;\n\n                const ACC_TYPE m_new = max(m_i, max(score0, score1));\n                // Nothing visible yet (both scores and m_i are -INFINITY): exp(-inf - -inf)\n                // would be NaN, and the current state is already the correct empty state.\n                // NOTE(review): relies on IEEE infinities, like the rest of this kernel.\n                if (m_new == -INFINITY) {\n                    continue;\n                }\n                const ACC_TYPE p0 = exp(score0 - m_new);\n                const ACC_TYPE p1 = exp(score1 - m_new);\n                const ACC_TYPE scale_prev = exp(m_i - m_new);\n\n                #pragma unroll\n                for (int i = 0; i < DV_VEC; ++i) {\n                    o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);\n                }\n                l_i = l_i * scale_prev + p0 + p1;\n                m_i = m_new;\n            }\n        }\n\n        // The next iteration overwrites l_k/l_v, so every work-item must be done reading them.\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    if (my_query_row < n_q) {\n        // Attention sink: one extra logit per head that only adds to the softmax denominator.\n        if (sinks_void != NULL) {\n            const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);\n            const ACC_TYPE m_sink = sinks_ptr[head_idx];\n            const ACC_TYPE m_final = max(m_i, m_sink);\n\n            const ACC_TYPE scale_o = exp(m_i - m_final);\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_acc[i] *= scale_o;\n            }\n\n            l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);\n        }\n\n        const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;\n        global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);\n        if (l_i > 0.0f) {\n            const ACC_TYPE l_inv = 1.0f / l_i;\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);\n            }\n        } else {\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; ++i) {\n                o_row[i] = (O_DATA_TYPE4)(0.0f);\n            }\n        }\n    }\n}\n\n// Single-query kernel (n_q == 1): Q1_WG_SIZE work-items cooperate on one\n// (batch, head) pair. Pass 1 computes the maximum score (seeded with the sink logit\n// when present); pass 2 accumulates exp(score - max)-weighted V rows. Partial results\n// are combined with local-memory tree reductions, which assume Q1_WG_SIZE is a power\n// of two equal to the launched local size.\n__kernel void flash_attn_f32_f16_q1(\n    const global void * q_void, ulong q_offset,\n    const global void * k_void, ulong k_offset,\n    const global void * v_void, ulong v_offset,\n    global void * o_void, ulong o_offset,\n    const float scale,\n    const int n_q,\n    const int n_kv,\n    const int is_causal,\n    const int n_head,\n    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,\n    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,\n    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,\n    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,\n    const float max_bias,\n    const float m0,\n    const float m1,\n    const int n_head_log2,\n    const float logit_softcap,\n    const int n_head_kv,\n    const global void* mask_void,\n    const ulong mask_offset,\n    const ulong mask_nb1,\n    const ulong mask_nb2,\n    const ulong mask_nb3,\n    const int mask_ne2,\n    const int mask_ne3,\n    const global void* sinks_void,\n    const ulong sinks_offset\n) {\n    const int tid = get_local_id(0);\n    const int head_batch_idx = get_global_id(1);\n\n    const int batch_idx = head_batch_idx / n_head;\n    const int head_idx = head_batch_idx % n_head;\n\n    const int gqa_ratio = n_head / n_head_kv;\n    const int head_kv_idx = head_idx / gqa_ratio;\n\n    const global char* q_base = (const global char*)q_void + q_offset;\n    const global char* k_base = (const global char*)k_void + k_offset;\n    const global char* v_base = (const global char*)v_void + v_offset;\n    global char* o_base = (global char*)o_void + o_offset;\n\n    const global char* mask_base = NULL;\n    if (mask_void != NULL) {\n        const int mask_head_idx = head_idx % mask_ne2;\n        const int mask_batch_idx = batch_idx % mask_ne3;\n        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;\n    }\n\n    ACC_TYPE4 q_priv[DK_VEC];\n    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;\n    const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);\n    #pragma unroll\n    for (int i = 0; i < DK_VEC; ++i) {\n        q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);\n    }\n\n    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);\n\n    const global ACC_TYPE* sinks_ptr = NULL;\n    if (sinks_void != NULL) {\n        sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);\n    }\n\n    // Pass 1: each work-item scans K rows tid, tid + Q1_WG_SIZE, ... for its local max.\n    ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;\n    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {\n        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;\n        const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);\n        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);\n        #pragma unroll\n        for (int k = 0; k < DK_VEC; k++) {\n            dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);\n        }\n        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;\n        // Softcap before masking so masked (-INFINITY) positions stay masked.\n        if (logit_softcap > 0.0f) {\n            score = logit_softcap * tanh(score / logit_softcap);\n        }\n        if (mask_base != NULL) {\n            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);\n            score += slope * (ACC_TYPE)mask_ptr[k_idx];\n        }\n        m_i = max(m_i, score);\n    }\n\n    __local ACC_TYPE local_m[Q1_WG_SIZE];\n    local_m[tid] = m_i;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    #pragma unroll\n    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n    const ACC_TYPE m_final = local_m[0];\n\n    ACC_TYPE4 o_acc[DV_VEC];\n    #pragma unroll\n    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);\n    ACC_TYPE l_i = 0.0f;\n\n    // Pass 2. If every position is masked and there is no sink, m_final is -INFINITY and\n    // exp(score - m_final) would be NaN; skip the pass so the output becomes zeros.\n    // m_final is identical for all work-items, and the loop contains no barrier.\n    if (m_final != -INFINITY) {\n        for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {\n            const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;\n            const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;\n            const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);\n            const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);\n            ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);\n            #pragma unroll\n            for (int k = 0; k < DK_VEC; k++) {\n                dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);\n            }\n            ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;\n            if (logit_softcap > 0.0f) {\n                score = logit_softcap * tanh(score / logit_softcap);\n            }\n            if (mask_base != NULL) {\n                const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);\n                score += slope * (ACC_TYPE)mask_ptr[k_idx];\n            }\n            const ACC_TYPE p = exp(score - m_final);\n            l_i += p;\n            #pragma unroll\n            for (int i = 0; i < DV_VEC; i++) {\n                o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);\n            }\n        }\n    }\n\n    __local ACC_TYPE local_l[Q1_WG_SIZE];\n    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];\n    local_l[tid] = l_i;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    #pragma unroll\n    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) local_l[tid] += local_l[tid + s];\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;\n    global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);\n    ACC_TYPE l_final = local_l[0];\n\n    if (sinks_ptr != NULL) {\n        l_final += exp(sinks_ptr[head_idx] - m_final);\n    }\n\n    // l_final is uniform across the work-group, so the barriers below are not divergent.\n    if (l_final > 0.0f) {\n        const ACC_TYPE l_inv = 1.0f / l_final;\n        for (int i = 0; i < DV_VEC; i++) {\n            local_o_comp[tid] = o_acc[i];\n            barrier(CLK_LOCAL_MEM_FENCE);\n            #pragma unroll\n            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {\n                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];\n                barrier(CLK_LOCAL_MEM_FENCE);\n            }\n            if (tid == 0) {\n                o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv);\n            }\n        }\n    } else if (tid == 0) {\n        #pragma unroll\n        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/gelu.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// gelu\n//\n// Element-wise GELU activations. Each work-item handles one element (one float4\n// in the *_4 kernels); offset0 and offsetd are byte offsets added to the source\n// and destination base pointers before indexing.\n//------------------------------------------------------------------------------\n#define GELU_COEF_A     0.044715f\n#define GELU_QUICK_COEF -1.702f\n#define SQRT_2_OVER_PI  0.79788456080286535587989211986876f\n#define SQRT_2_INV      0.70710678118654752440084436210484f\n\n// Formulas shared by the scalar and float4 kernels; pass a plain variable, since the\n// argument is evaluated several times.\n// tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2)))\n#define GELU_TANH_APPROX(x)   (0.5f*(x)*(1.0f + tanh(SQRT_2_OVER_PI*(x)*(1.0f + GELU_COEF_A*(x)*(x)))))\n// erf form: 0.5*x*(1 + erf(x/sqrt(2)))\n#define GELU_ERF_EXACT(x)     (0.5f*(x)*(1.0f + erf((x)*SQRT_2_INV)))\n// quick form: x * sigmoid(1.702*x)\n#define GELU_QUICK_SIGMOID(x) ((x)*(1.0f/(1.0f+exp(GELU_QUICK_COEF*(x)))))\n\nkernel void kernel_gelu(\n    global float * src0,\n    ulong offset0,\n    global float * dst,\n    ulong offsetd\n) {\n    const size_t gid = get_global_id(0);\n    global const float * in  = (global const float *)((global const char *)src0 + offset0);\n    global float       * out = (global float *)((global char *)dst + offsetd);\n\n    const float v = in[gid];\n    out[gid] = GELU_TANH_APPROX(v);\n}\n\nkernel void kernel_gelu_4(\n    global float4 * src0,\n    ulong offset0,\n    global float4 * dst,\n    ulong offsetd\n) {\n    const size_t gid = get_global_id(0);\n    global const float4 * in  = (global const float4 *)((global const char *)src0 + offset0);\n    global float4       * out = (global float4 *)((global char *)dst + offsetd);\n\n    const float4 v = in[gid];\n    out[gid] = GELU_TANH_APPROX(v);\n}\n\nkernel void kernel_gelu_erf(\n    global float * src0,\n    ulong offset0,\n    global float * dst,\n    ulong offsetd\n) {\n    const size_t gid = get_global_id(0);\n    global const float * in  = (global const float *)((global const char *)src0 + offset0);\n    global float       * out = (global float *)((global char *)dst + offsetd);\n\n    const float v = in[gid];\n    out[gid] = GELU_ERF_EXACT(v);\n}\n\nkernel void kernel_gelu_erf_4(\n    global float4 * src0,\n    ulong offset0,\n    global float4 * dst,\n    ulong offsetd\n) {\n    const size_t gid = get_global_id(0);\n    global const float4 * in  = (global const float4 *)((global const char *)src0 + offset0);\n    global float4       * out = (global float4 *)((global char *)dst + offsetd);\n\n    const float4 v = in[gid];\n    out[gid] = GELU_ERF_EXACT(v);\n}\n\nkernel void kernel_gelu_quick(\n    global float * src0,\n    ulong offset0,\n    global float * dst,\n    ulong offsetd\n) {\n    const size_t gid = get_global_id(0);\n    global const float * in  = (global const float *)((global const char *)src0 + offset0);\n    global float       * out = (global float *)((global char *)dst + offsetd);\n\n    const float v = in[gid];\n    out[gid] = GELU_QUICK_SIGMOID(v);\n}\n\nkernel void kernel_gelu_quick_4(\n    global float4 * src0,\n    ulong offset0,\n    global float4 * dst,\n    ulong offsetd\n) {\n    const size_t gid = get_global_id(0);\n    global const float4 * in  = (global const float4 *)((global const char *)src0 + offset0);\n    global float4       * out = (global float4 *)((global char *)dst + offsetd);\n\n    const float4 v = in[gid];\n    out[gid] = GELU_QUICK_SIGMOID(v);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n\n#define QK_MXFP4 32\n#define N_SIMDGROUP 2\n#define SIMDGROUP_WIDTH 64\n\n// Decode eight 4-bit MXFP4 codes into eight fp16 values with integer bit manipulation.\n// The four nibbles of fp4x8.s0 (lowest first) give elements 0-3 of the result and the\n// four nibbles of fp4x8.s1 give elements 4-7. Per nibble, the low 3 bits are moved into\n// the fp16 exponent/mantissa field (plus a 0x3800 bias when non-zero) and bit 3 becomes\n// the fp16 sign bit.\nstatic inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {\n    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;\n    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;\n    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;\n    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;\n    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;\n\n    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;\n    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;\n    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;\n    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;\n\n    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;\n    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;\n    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;\n    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;\n\n    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;\n    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;\n    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;\n    sign_b.hi = fp4x8.s0 & 0x8000;\n\n    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;\n    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;\n\n    ushort2 fp16_packed_a_1, fp16_packed_b_1;\n    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;\n    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;\n    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;\n    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;\n\n    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;\n    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;\n    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;\n    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;\n\n    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;\n    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;\n    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;\n    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;\n\n    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;\n    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;\n    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;\n    sign_b.hi = fp4x8.s1 & 0x8000;\n\n    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;\n    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;\n\n    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));\n}\n\n// E8M0 shared block scale to fp32: places x in the fp32 exponent field, i.e. 2^(x-127);\n// x == 0 maps to the bit pattern 0x00400000 (2^-127).\nstatic inline float e8m0_to_fp32(uchar x) {\n    int bits;\n    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);\n    return as_float(bits);\n}\n\n\n// MoE GEMM with MXFP4 weights. Router entry src2[i20] = (expert_id, i11, i1, tile_id):\n// row i11 of src1 is multiplied with weight rows tile_id * tile_size + i01 of expert\n// expert_id, and the result goes to row i1 of dst. N_SIMDGROUP subgroups of\n// SIMDGROUP_WIDTH work-items split the ne00 blocks and are reduced through local memory.\n__attribute__((qcom_reqd_sub_group_size(\"half\")))\n__kernel void kernel_gemm_moe_mxfp4_f32(\n    __global uint4 * src0_q,\n    __global uchar * src0_e,\n    __read_only image1d_buffer_t src1,\n    __global ushort4 * src2,\n    __global float * dst,\n    ulong         offsetd,\n    int           ne00,\n    int           ne01,\n    int           tile_size\n) {\n    uint i01  = get_global_id(0);\n    uint i20  = get_global_id(2);\n    uint sgid = get_local_id(1);\n    uint slid = get_sub_group_local_id();\n\n    ushort4 router = src2[i20];\n    ushort expert_id = router.x;\n    ushort i11 = router.y;\n    ushort i1 = router.z;\n    ushort tile_id = router.w;\n\n    // Edge case when ne01 is not a multiple of tile_size. Out-of-range work-items must not\n    // return early: every work-item of the work-group has to reach the barrier below, so\n    // they skip the loads and the store instead.\n    const bool valid = tile_id * tile_size + i01 < ne01;\n\n    // uint4 entries (one per QK_MXFP4-value block) per expert = (ne00 / QK_MXFP4) * ne01,\n    // computed in uint with the division first so large experts do not overflow int.\n    // Assumes ne00 is a multiple of QK_MXFP4, as the block loop below already does.\n    uint expert_offset = (uint)expert_id * (uint)(ne00 / QK_MXFP4) * (uint)ne01;\n    uint tile_offset = expert_offset + tile_id * tile_size + i01;\n\n    __private float sum = 0.0f; // each work-item accumulates the partial sum of one output\n\n    // Loop along ne00 in block granularity; subgroups interleave, each advancing by\n    // N_SIMDGROUP blocks per iteration.\n    const uint n_blocks = valid ? (uint)(ne00 / QK_MXFP4) : 0u;\n    for (uint ib00 = sgid; ib00 < n_blocks; ib00 += N_SIMDGROUP) {\n        // load one block of q\n        uint4 regQ = src0_q[tile_offset + ib00 * ne01];\n        // convert 8 fp4 to fp16\n        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));\n\n        uint offset = i11 * ne00 / 4 + ib00 * 8;\n        float4 shared_y4;\n        shared_y4 = read_imagef(src1, (offset + 0));\n        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);\n\n        shared_y4 = read_imagef(src1, (offset + 4));\n        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);\n\n        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));\n\n        shared_y4 = read_imagef(src1, (offset + 1));\n        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);\n\n        shared_y4 = read_imagef(src1, (offset + 5));\n        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);\n\n        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));\n\n        shared_y4 = read_imagef(src1, (offset + 2));\n        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);\n\n        shared_y4 = read_imagef(src1, (offset + 6));\n        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);\n\n        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));\n\n        shared_y4 = read_imagef(src1, (offset + 3));\n        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);\n\n        shared_y4 = read_imagef(src1, (offset + 7));\n        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);\n\n        uchar regE = src0_e[tile_offset + ib00 * ne01];\n        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));\n    }\n\n    // Reduction in local memory: subgroups 1 .. N_SIMDGROUP-1 publish their partial sums\n    // and subgroup 0 adds them.\n    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];\n    if (sgid > 0 && sgid < N_SIMDGROUP) reduceLM[SIMDGROUP_WIDTH * (sgid - 1) + slid] = sum;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    if (sgid == 0) {\n        for (uint s = 1; s < N_SIMDGROUP; ++s) {\n            sum += reduceLM[SIMDGROUP_WIDTH * (s - 1) + slid];\n        }\n    }\n\n    // One output per valid work-item of subgroup 0.\n    if (valid && sgid == 0) {\n        dst = dst + (offsetd >> 2);\n        dst[i01 + tile_id * tile_size + i1 * ne01] = sum;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n\n#ifdef cl_qcom_reqd_sub_group_size\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_128\n#endif\n\nkernel void kernel_gemm_noshuffle_q4_1_f32(\n    global const ushort * src0_q,\n    global const half  * src0_d,\n    global const half  * src0_m,\n    read_only image1d_buffer_t src1,\n    global float * dst,\n    ulong offsetd,\n    int m,\n    int n,\n    int k,\n    int n_no_padding\n) {\n    dst = (global float *)((global char *)dst + offsetd);\n\n    int m_4 = m >> 2;\n    int n_4 = n >> 2;\n\n    int gy = get_global_id(0);\n    int gx = get_global_id(1);\n    int gx_2 = gx << 2;\n\n    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;\n    half8 B;\n    half4 dequantized_weights;\n\n    global const ushort* weight_ptr = src0_q + gx_2;\n    global const half*   scale_ptr  = src0_d + gx_2;\n    global const half*   min_ptr    = src0_m + gx_2;\n\n    for(int i = 0; i < k; i += 4) {\n        B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));\n        B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);\n\n        ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m));\n\n        half4 scale = vload4(0, scale_ptr + (i/32)*(m));\n        half4 minv  = vload4(0,   min_ptr + (i/32)*(m));\n\n        // j=0\n        dequantized_weights.s0 = (bits4.s0 & (0x000F)) * scale.s0 + minv.s0;\n        dequantized_weights.s1 = (bits4.s1 & (0x000F)) * scale.s1 + minv.s1;\n        dequantized_weights.s2 = (bits4.s2 & (0x000F)) * scale.s2 + minv.s2;\n        dequantized_weights.s3 = (bits4.s3 & (0x000F)) * scale.s3 + minv.s3;\n        c0 += B * dequantized_weights.s0;\n        c1 += B * dequantized_weights.s1;\n        c2 += B * dequantized_weights.s2;\n        c3 += B * dequantized_weights.s3;\n\n        // j=1\n    
    B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));\n        B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);\n        dequantized_weights.s0 = ((bits4.s0 & (0x00F0)) >> 4) * scale.s0 + minv.s0;\n        dequantized_weights.s1 = ((bits4.s1 & (0x00F0)) >> 4) * scale.s1 + minv.s1;\n        dequantized_weights.s2 = ((bits4.s2 & (0x00F0)) >> 4) * scale.s2 + minv.s2;\n        dequantized_weights.s3 = ((bits4.s3 & (0x00F0)) >> 4) * scale.s3 + minv.s3;\n        c0 += B * dequantized_weights.s0;\n        c1 += B * dequantized_weights.s1;\n        c2 += B * dequantized_weights.s2;\n        c3 += B * dequantized_weights.s3;\n\n        // j=2\n        B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));\n        B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);\n        dequantized_weights.s0 = ((bits4.s0 & (0x0F00)) >> 8) * scale.s0 + minv.s0;\n        dequantized_weights.s1 = ((bits4.s1 & (0x0F00)) >> 8) * scale.s1 + minv.s1;\n        dequantized_weights.s2 = ((bits4.s2 & (0x0F00)) >> 8) * scale.s2 + minv.s2;\n        dequantized_weights.s3 = ((bits4.s3 & (0x0F00)) >> 8) * scale.s3 + minv.s3;\n        c0 += B * dequantized_weights.s0;\n        c1 += B * dequantized_weights.s1;\n        c2 += B * dequantized_weights.s2;\n        c3 += B * dequantized_weights.s3;\n\n        // j=3\n        B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));\n        B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);\n        dequantized_weights.s0 = ((bits4.s0 & (0xF000)) >> 12) * scale.s0 + minv.s0;\n        dequantized_weights.s1 = ((bits4.s1 & (0xF000)) >> 12) * scale.s1 + minv.s1;\n        dequantized_weights.s2 = ((bits4.s2 & (0xF000)) >> 12) * scale.s2 + minv.s2;\n        dequantized_weights.s3 = ((bits4.s3 & (0xF000)) >> 12) * scale.s3 + minv.s3;\n        c0 += B * dequantized_weights.s0;\n        c1 += B * dequantized_weights.s1;\n        c2 += B * dequantized_weights.s2;\n        c3 += B * dequantized_weights.s3;\n    }\n\n    int idx = (gy<<3)*m + (gx<<2);\n\n    if(idx+3 < 
m*n_no_padding){\n        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n\n#define QK_MXFP4 32\n#define N_SIMDGROUP 4\n#define SIMDGROUP_WIDTH 64\n\nstatic inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {\n    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;\n    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;\n    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;\n    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;\n    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;\n\n    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;\n    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;\n    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;\n    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;\n\n    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;\n    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;\n    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;\n    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;\n\n    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;\n    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;\n    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;\n    sign_b.hi = fp4x8.s0 & 0x8000;\n\n    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;\n    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;\n\n    ushort2 fp16_packed_a_1, fp16_packed_b_1;\n    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;\n    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;\n    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;\n    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;\n\n    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;\n    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;\n    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;\n    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 
0x3800 : 0x0;\n\n    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;\n    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;\n    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;\n    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;\n\n    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;\n    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;\n    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;\n    sign_b.hi = fp4x8.s1 & 0x8000;\n\n    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;\n    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;\n\n    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));\n}\n\nstatic inline float e8m0_to_fp32(uchar x) {\n    int bits;\n    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);\n    return as_float(bits);\n}\n\n\n__attribute__((qcom_reqd_sub_group_size(\"half\")))\n__kernel void kernel_gemv_moe_mxfp4_f32(\n    __global uint4 * src0_q,\n    __global uchar * src0_e,\n    __read_only image1d_buffer_t src1,\n    __global uint * src2,\n    __global float * dst,\n    ulong         offsetd,\n    int           ne00,\n    int           ne01,\n    int           ne11\n) {\n    uint i01  = get_global_id(0);\n    uint i20  = get_global_id(2);\n    uint sgid = get_local_id(1);\n    uint slid = get_sub_group_local_id();\n\n    uint i11 = i20 % ne11;\n\n    uint expert_id = src2[i20];\n    uint expert_offset = expert_id * ne00 * ne01 / 32;\n\n    __private float sum = 0.0f; // each thread calculate partial sum of one output\n\n    // loop along ne00 in block granularity, skip 4 blocks every iter\n    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {\n\n        // load one block of q\n        uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01];\n\n        uint offset = i11 * ne00 / 4 + ib00 * 8;\n\n        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));\n\n        float4 
shared_y4;\n        shared_y4 = read_imagef(src1, (offset + 0));\n        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);\n\n        shared_y4 = read_imagef(src1, (offset + 4));\n        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);\n\n\n        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));\n\n        shared_y4 = read_imagef(src1, (offset + 1));\n        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);\n\n        shared_y4 = read_imagef(src1, (offset + 5));\n        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);\n\n\n        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));\n\n        shared_y4 = read_imagef(src1, (offset + 2));\n        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);\n\n        shared_y4 = read_imagef(src1, (offset + 6));\n        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);\n\n\n        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));\n\n        shared_y4 = read_imagef(src1, (offset + 3));\n        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);\n\n        shared_y4 = read_imagef(src1, (offset + 7));\n        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);\n\n        uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];\n        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));\n    }\n\n    // reduction in local memory, assumes #subgroups=4\n    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];\n    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;\n    if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;\n    if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];\n    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];\n    if (sgid == 0) sum += 
reduceLM[SIMDGROUP_WIDTH * 2 + slid];\n\n    // 1 outputs per thread in subgroup 0\n    if (sgid == 0) {\n        dst = dst + (offsetd >> 2);\n        dst[i01 + i20 * ne01] = sum;\n    }\n\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/gemv_noshuffle.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n\n#ifdef cl_qcom_reqd_sub_group_size\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#endif\n\n// assume\n#define QK4_0 32\n#define N_SIMDGROUP 4\n\n#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \\\n    float shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 0); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 0); \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 0); \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 0); \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 0); \\\n    
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 1); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 1); \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \\\n    shared_y = sub_group_broadcast(y.s0, 2); \\\n    total_sums.s0 += ((bits4.s0 & 
0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 2); \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 3); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    
shared_y = sub_group_broadcast(y.s2, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 3); \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \\\n    float8 shared_y; \\\n    shared_y = sub_group_broadcast(y, 0); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \\\n    
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \\\n    shared_y = sub_group_broadcast(y, 1); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * 
scale.s1 * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \\\n    shared_y = sub_group_broadcast(y, 2); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \\\n    shared_y = sub_group_broadcast(y, 3); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 
* shared_y.s4; \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \\\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\n__kernel void kernel_gemv_noshuffle(\n        __read_only  image1d_buffer_t src0_q,  // quantized A\n        global half2  * src0_d,  // A scales\n        __read_only  image1d_buffer_t src1,    // B\n        ulong offset1,            // offset to B (0)\n        global float * dst,     // C\n        ulong offsetd,            // offset to C (0)\n        uint K,               // K\n        int ne01,               // M\n        int ne02,               // 1\n        int ne10,               // K\n        int ne12,               // 1\n        int ne0,                // M\n        int ne1,                // N\n        int r2,                 // 1\n        int r3)\n{\n    uint groupId = get_local_id(1);\n    uint gid     = get_global_id(0);\n    ushort slid    = get_sub_group_local_id();\n\n    __private uint4     regA;\n    __private half2     regS;\n    __private float8    regB;\n\n    __private float2 totalSum = (float2)(0.0f);\n\n    // loop along K in block granularity, skip 4 blocks every iter\n    for 
(uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {\n        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows\n        // first 4 fibers in each wave load 8 B values to its private scope\n        if (slid < 4) {\n            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));\n            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));\n        }\n\n        // load half weights for two blocks in consecutive rows\n        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;\n        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;\n        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;\n        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;\n#ifdef VECTOR_SUB_GROUP_BROADCAT\n        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);\n#else\n        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);\n#endif // VECTOR_SUB_GROUP_BROADCAT\n\n        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;\n        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;\n        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;\n        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;\n#ifdef VECTOR_SUB_GROUP_BROADCAT\n        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);\n#else\n        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);\n#endif // VECTOR_SUB_GROUP_BROADCAT\n    }\n\n    // reduction in local memory, assumes #wave=4\n    __local float2 reduceLM[SIMDGROUP_WIDTH * 3];\n    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;\n    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;\n    if (groupId == 3) 
reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];\n\n    // 2 outputs per fiber in wave 0\n    if (groupId == 0) {\n        dst = (global float*)((global char*)dst + offsetd);\n        vstore2(totalSum, 0, &(dst[gid * 2]));\n    }\n\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/gemv_noshuffle_general.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n\n#ifdef cl_qcom_reqd_sub_group_size\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#endif\n\n// assume\n#define QK4_0 32\n#define N_SIMDGROUP 4\n\n#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \\\n    float shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 0); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 0); \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 0); \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 0); \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 0); \\\n    
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 1); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 1); \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \\\n    shared_y = sub_group_broadcast(y.s0, 2); \\\n    total_sums.s0 += ((bits4.s0 & 
0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 2); \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 3); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    
shared_y = sub_group_broadcast(y.s2, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 3); \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \\\n    float8 shared_y; \\\n    shared_y = sub_group_broadcast(y, 0); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \\\n    
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \\\n    shared_y = sub_group_broadcast(y, 1); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * 
scale.s1 * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \\\n    shared_y = sub_group_broadcast(y, 2); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \\\n    shared_y = sub_group_broadcast(y, 3); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 
* shared_y.s4; \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \\\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\n__kernel void kernel_gemv_noshuffle(\n        __read_only  image1d_buffer_t src0_q,  // quantized A\n        global half2  * src0_d,  // A scales\n        __read_only  image1d_buffer_t src1,    // B\n        ulong offset1,            // offset to B (0)\n        global float * dst,     // C\n        ulong offsetd,            // offset to C (0)\n        int ne00,               // K\n        int ne01,               // M\n        int ne02,               // 1\n        int ne10,               // K\n        int ne12,               // 1\n        int ne0,                // M\n        int ne1,                // N\n        int r2,                 // 1\n        int r3)\n{\n    uint groupId = get_local_id(1);\n    uint gid     = get_global_id(0);\n    ushort slid    = get_sub_group_local_id();\n\n    uint K = ne00;\n    uint M = ne01;\n\n    uint LINE_STRIDE_A = M / 2;\n    uint BLOCK_STRIDE_A = N_SIMDGROUP * M;\n\n    __private uint4     regA;\n    __private half2     regS;\n    __private float8    regB;\n\n    
__private float2 totalSum = (float2)(0.0f);\n\n    // loop along K in block granularity, skip 4 blocks every iter\n    for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {\n        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows\n        // first 4 fibers in each wave load 8 B values to its private scope\n        if (slid < 4) {\n            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));\n            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));\n        }\n\n        // load half weights for two blocks in consecutive rows\n        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;\n        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;\n        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;\n        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;\n#ifdef VECTOR_SUB_GROUP_BROADCAT\n        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);\n#else\n        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);\n#endif // VECTOR_SUB_GROUP_BROADCAT\n\n        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;\n        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;\n        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;\n        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;\n#ifdef VECTOR_SUB_GROUP_BROADCAT\n        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);\n#else\n        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);\n#endif // VECTOR_SUB_GROUP_BROADCAT\n    }\n\n    // reduction in local memory, assumes #wave=4\n    __local float2 reduceLM[SIMDGROUP_WIDTH * 3];\n    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + 
slid] = totalSum;\n    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;\n    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];\n\n    // 2 outputs per fiber in wave 0\n    if (groupId == 0) {\n        dst = (global float*)((global char*)dst + offsetd);\n        vstore2(totalSum, 0, &(dst[gid * 2]));\n    }\n\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n\n#ifdef cl_qcom_reqd_sub_group_size\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#endif\n\n#define QK8_0 32\n#define N_SIMDGROUP 4\n\n#define dequantizeBlockAccum_ns_sgbroadcast_1(total_sums, bits8, scale, y) \\\n    float shared_y; \\\n    char elem; \\\n                                             \\\n    shared_y = sub_group_broadcast(y.s0, 0); \\\n    elem = (char)(bits8.s0 & 0x000000FF); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 0); \\\n    elem = (char)((bits8.s0 & 0x0000FF00) >> 8); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 0); \\\n    elem = (char)((bits8.s0 & 0x00FF0000) >> 16); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 0); \\\n    elem = (char)((bits8.s0 & 0xFF000000) >> 24); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n                                             \\\n    shared_y = sub_group_broadcast(y.s4, 0); \\\n    elem = (char)(bits8.s1 & 0x000000FF); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 0); \\\n    elem = (char)((bits8.s1 & 0x0000FF00) >> 8); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 0); \\\n    elem = (char)((bits8.s1 & 0x00FF0000) >> 16); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 0); \\\n    elem = (char)((bits8.s1 & 0xFF000000) >> 24); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n                                             \\\n    shared_y = sub_group_broadcast(y.s0, 1); \\\n    elem = 
(char)(bits8.s2 & 0x000000FF); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 1); \\\n    elem = (char)((bits8.s2 & 0x0000FF00) >> 8); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 1); \\\n    elem = (char)((bits8.s2 & 0x00FF0000) >> 16); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 1); \\\n    elem = (char)((bits8.s2 & 0xFF000000) >> 24); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n                                             \\\n    shared_y = sub_group_broadcast(y.s4, 1); \\\n    elem = (char)(bits8.s3 & 0x000000FF); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 1); \\\n    elem = (char)((bits8.s3 & 0x0000FF00) >> 8); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 1); \\\n    elem = (char)((bits8.s3 & 0x00FF0000) >> 16); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 1); \\\n    elem = (char)((bits8.s3 & 0xFF000000) >> 24); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n                                             \\\n    shared_y = sub_group_broadcast(y.s0, 2); \\\n    elem = (char)(bits8.s4 & 0x000000FF); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 2); \\\n    elem = (char)((bits8.s4 & 0x0000FF00) >> 8); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 2); \\\n    elem = (char)((bits8.s4 & 0x00FF0000) >> 16); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 2); \\\n    elem = (char)((bits8.s4 & 0xFF000000) >> 24); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n                                 
            \\\n    shared_y = sub_group_broadcast(y.s4, 2); \\\n    elem = (char)(bits8.s5 & 0x000000FF); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 2); \\\n    elem = (char)((bits8.s5 & 0x0000FF00) >> 8); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 2); \\\n    elem = (char)((bits8.s5 & 0x00FF0000) >> 16); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 2); \\\n    elem = (char)((bits8.s5 & 0xFF000000) >> 24); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n                                             \\\n    shared_y = sub_group_broadcast(y.s0, 3); \\\n    elem = (char)(bits8.s6 & 0x000000FF); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 3); \\\n    elem = (char)((bits8.s6 & 0x0000FF00) >> 8); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 3); \\\n    elem = (char)((bits8.s6 & 0x00FF0000) >> 16); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 3); \\\n    elem = (char)((bits8.s6 & 0xFF000000) >> 24); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n                                             \\\n    shared_y = sub_group_broadcast(y.s4, 3); \\\n    elem = (char)(bits8.s7 & 0x000000FF); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 3); \\\n    elem = (char)((bits8.s7 & 0x0000FF00) >> 8); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 3); \\\n    elem = (char)((bits8.s7 & 0x00FF0000) >> 16); \\\n    total_sums += convert_int(elem) * scale * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 3); \\\n    elem = (char)((bits8.s7 & 0xFF000000) >> 24); \\\n    total_sums 
+= convert_int(elem) * scale * shared_y; \\\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\n__kernel void kernel_gemv_noshuffle_q8_0_f32(\n        __read_only  image1d_buffer_t src0_q,  // quantized A\n        global half  * src0_d,  // A scales\n        __read_only  image1d_buffer_t src1,    // B\n        ulong offset1,            // offset to B (0)\n        global float * dst,     // C\n        ulong offsetd,            // offset to C\n        int ne00,               // K\n        int ne01,               // M\n        int ne02,               // 1\n        int ne10,               // K\n        int ne12,               // 1\n        int ne0,                // M\n        int ne1,                // N\n        int r2,                 // 1\n        int r3)\n{\n    uint groupId = get_local_id(1);\n    uint gid     = get_global_id(0);\n    ushort slid    = get_sub_group_local_id();\n\n    uint K = ne00;\n    uint M = ne01;\n\n    uint LINE_STRIDE_A = M;\n    uint BLOCK_STRIDE_A = 8 * M;   // 32 / 4 = 8\n\n    __private uint8     regA;\n    __private half      regS;\n    __private float8    regB;\n\n    __private float totalSum = (float)(0.0f);\n\n    // loop along K in block granularity, skip 4 blocks every iter\n    #pragma unroll 1 /* tell compiler not to unroll */\n    for (uint k = groupId; k < (K / QK8_0); k += N_SIMDGROUP) {\n        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of one row\n        // first 4 fibers in each wave load 8 B values to its private scope\n        if (slid < 4) {\n            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));\n            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));\n        }\n\n        // load weights for one block in consecutive rows\n        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;\n        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;\n        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + 
LINE_STRIDE_A * 2)).x;\n        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;\n        regA.s4 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;\n        regA.s5 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;\n        regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;\n        regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;\n\n        dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);\n    }\n\n    // reduction in local memory, assumes #wave=4\n    __local float reduceLM[SIMDGROUP_WIDTH * 3];\n    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;\n    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;\n    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;\n    barrier(CLK_LOCAL_MEM_FENCE);\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];\n    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];\n\n    // 1 output per fiber in wave 0\n    if (groupId == 0) {\n        dst = (global float*)((global char*)dst + offsetd);\n        dst[gid] = totalSum;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n\n#ifdef cl_qcom_reqd_sub_group_size\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#endif\n\n#define QK4_0 32\n#define NSUBGROUPS 4\n#define SUBGROUP_SIZE 64\n\n#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \\\n    float shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 0); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 0); \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 0); \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 0); \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 0); \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 
0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 0); \\\n    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 1); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 1); \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 1); \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 1); \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) 
* scale.s1 + minv.s1) * shared_y; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \\\n    shared_y = sub_group_broadcast(y.s0, 2); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 2); \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 2); \\\n    total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 2); \\\n    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s0, 3); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + 
minv.s0) * shared_y; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s1, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s2, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s3, 3); \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s4, 3); \\\n    total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s5, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s6, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \\\n    shared_y = sub_group_broadcast(y.s7, 3); \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \\\n    float8 shared_y; \\\n    shared_y = sub_group_broadcast(y, 0); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s1; 
\\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s2 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \\\n    shared_y = sub_group_broadcast(y, 1); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s6 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s6 & 
0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \\\n\n\n#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \\\n    shared_y = sub_group_broadcast(y, 2); \\\n    total_sums.s0 += ((bits4.s0 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s2 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s1 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * 
shared_y.s3; \\\n    total_sums.s1 += ((bits4.s3 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \\\n    shared_y = sub_group_broadcast(y, 3); \\\n    total_sums.s0 += ((bits4.s4 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s0; \\\n    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s1; \\\n    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s2; \\\n    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \\\n    total_sums.s0 += ((bits4.s6 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s4; \\\n    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s5; \\\n    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s6; \\\n    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \\\n    total_sums.s1 += ((bits4.s5 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s0; \\\n    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s1; \\\n    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s2; \\\n    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \\\n    total_sums.s1 += ((bits4.s7 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s4; \\\n    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s5; \\\n    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s6; \\\n    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \\\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_gemv_noshuffle_q4_1_f32(\n        read_only  
image1d_buffer_t src0_q,\n        global half2  * src0_d,\n        global half2  * src0_m,\n        read_only  image1d_buffer_t src1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01)\n{\n    uint groupId = get_local_id(1);\n    uint gid     = get_global_id(0);\n    ushort slid    = get_sub_group_local_id();\n\n    uint K = ne00;\n    uint M = ne01;\n\n    uint LINE_STRIDE_A = M / 2;\n    uint BLOCK_STRIDE_A = NSUBGROUPS * M;\n\n    private uint4     regA;\n    private half2     regS;\n    private half2     regM;\n    private float8    regB;\n\n    private float2 totalSum = (float2)(0.0f);\n\n    // loop along K in block granularity, skip 4 blocks every iter\n    for (uint k = groupId; k < (K / QK4_0); k += NSUBGROUPS) {\n        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows\n        regM = src0_m[gid + k * LINE_STRIDE_A]; // each fiber loads min of two rows\n        // first 4 fibers in each wave load 8 B values to its private scope\n        if (slid < 4) {\n            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));\n            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));\n        }\n\n        // load half weights for two blocks in consecutive rows\n        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;\n        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;\n        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;\n        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;\n#ifdef VECTOR_SUB_GROUP_BROADCAT\n        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB);\n#else\n        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB);\n#endif // VECTOR_SUB_GROUP_BROADCAT\n\n        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;\n        
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;\n        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;\n        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;\n#ifdef VECTOR_SUB_GROUP_BROADCAT\n        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB);\n#else\n        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB);\n#endif // VECTOR_SUB_GROUP_BROADCAT\n    }\n\n    // reduction in local memory, assumes #wave=4\n    local float2 reduceLM[SUBGROUP_SIZE * 3];\n    if (groupId == 1) {\n        reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum;\n    }\n    if (groupId == 2) {\n        reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum;\n    }\n    if (groupId == 3) {\n        reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    if (groupId == 0) {\n        totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid];\n    }\n    if (groupId == 0) {\n        totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid];\n    }\n    if (groupId == 0) {\n        totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid];\n    }\n\n    // 2 outputs per fiber in wave 0\n    if (groupId == 0) {\n        dst = (global float*)((global char*)dst + offsetd);\n        vstore2(totalSum, 0, &(dst[gid * 2]));\n    }\n\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/get_rows.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n#define QK4_0                   32\n\n//------------------------------------------------------------------------------\n// block_q4_0\n//------------------------------------------------------------------------------\nstruct block_q4_0\n{\n    half d;\n    uint8_t qs[QK4_0 / 2];\n};\n\n\n//------------------------------------------------------------------------------\n// dequantize_q4_0_f32, dequantize_q4_0_f16\n//------------------------------------------------------------------------------\nvoid dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {\n    global ushort * qs = ((global ushort *)xb + 1);\n    float d1 = il ? (xb->d / 16.h) : xb->d;\n    float d2 = d1 / 256.f;\n    float md = -8.h * xb->d;\n    ushort mask0 = il ? 0x00F0 : 0x000F;\n    ushort mask1 = mask0 << 8;\n\n    reg->s0 = d1 * (qs[0] & mask0) + md;\n    reg->s1 = d2 * (qs[0] & mask1) + md;\n\n    reg->s2 = d1 * (qs[1] & mask0) + md;\n    reg->s3 = d2 * (qs[1] & mask1) + md;\n\n    reg->s4 = d1 * (qs[2] & mask0) + md;\n    reg->s5 = d2 * (qs[2] & mask1) + md;\n\n    reg->s6 = d1 * (qs[3] & mask0) + md;\n    reg->s7 = d2 * (qs[3] & mask1) + md;\n\n    reg->s8 = d1 * (qs[4] & mask0) + md;\n    reg->s9 = d2 * (qs[4] & mask1) + md;\n\n    reg->sa = d1 * (qs[5] & mask0) + md;\n    reg->sb = d2 * (qs[5] & mask1) + md;\n\n    reg->sc = d1 * (qs[6] & mask0) + md;\n    reg->sd = d2 * (qs[6] & mask1) + md;\n\n    reg->se = d1 * (qs[7] & mask0) + md;\n    reg->sf = d2 * (qs[7] & mask1) + md;\n}\n\n\n//------------------------------------------------------------------------------\n// get_rows\n//------------------------------------------------------------------------------\nkernel void kernel_get_rows_f32(\n        global void * src0,\n        ulong offset0,\n        global int * 
src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i10 = get_group_id(0);\n    int i11 = get_group_id(1);\n    int i12 = get_group_id(2);\n\n    int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];\n\n    int i02 = i11;\n    int i03 = i12;\n\n    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {\n        if (ind >= ne00) {\n            return;\n        }\n        ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =\n            ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];\n    }\n}\n\nkernel void kernel_get_rows_f16(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i10 = get_group_id(0);\n    int i11 = get_group_id(1);\n    int i12 = get_group_id(2);\n\n    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];\n\n    int i02 = i11;\n    int i03 = i12;\n\n    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {\n        if (ind >= ne00) {\n            
return;\n        }\n        ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =\n            ((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];\n    }\n}\n\nkernel void kernel_get_rows_q4_0(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    const int NL = 2;\n\n    int i10 = get_group_id(0);\n    int i11 = get_group_id(1);\n    int i12 = get_group_id(2);\n\n    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];\n\n    int i02 = i11;\n    int i03 = i12;\n\n    for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {\n        float16 temp;\n        if (ind >= ne00) {\n            return;\n        }\n        dequantize_q4_0_f32(\n            ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);\n        *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/glu.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define GELU_COEF_A     0.044715f\n#define GELU_QUICK_COEF -1.702f\n#define SQRT_2_OVER_PI  0.79788456080286535587989211986876f\n#define SQRT_2_INV      0.70710678118654752440084436210484f\n\n//------------------------------------------------------------------------------\n// geglu\n//------------------------------------------------------------------------------\nkernel void kernel_geglu(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));\n\n        dst_row[i0] = gelu*x1;\n    }\n}\n\nkernel void kernel_geglu_f16(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global half * src0_row = (global half *) ((global char *) src0 
+ get_group_id(0)*nb01) + ne00_off;\n    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const half x0 = src0_row[i0];\n        const half x1 = src1_row[i0];\n\n        const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));\n\n        dst_row[i0] = gelu*x1;\n    }\n}\n\n//------------------------------------------------------------------------------\n// reglu\n//------------------------------------------------------------------------------\nkernel void kernel_reglu(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        dst_row[i0] = x0*x1*(x0 > 0.0f);\n    }\n}\n\nkernel void kernel_reglu_f16(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n 
   src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const half x0 = src0_row[i0];\n        const half x1 = src1_row[i0];\n\n        dst_row[i0] = x0*x1*(x0 > 0.0f);\n    }\n}\n\n//------------------------------------------------------------------------------\n// swiglu\n//------------------------------------------------------------------------------\nkernel void kernel_swiglu(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        const float silu = x0 / (1.0f + exp(-x0));\n\n        dst_row[i0] = silu*x1;\n    }\n}\n\nkernel void kernel_swiglu_f16(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    
ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const half x0 = src0_row[i0];\n        const half x1 = src1_row[i0];\n\n        const half silu = x0 / (1.0f + exp(-x0));\n\n        dst_row[i0] = silu*x1;\n    }\n}\n\n//------------------------------------------------------------------------------\n// swiglu_oai\n//------------------------------------------------------------------------------\nkernel void kernel_swiglu_oai(\n    global char * src0,\n    ulong         offset0,\n    global char * src1,\n    ulong         offset1,\n    global char * dst,\n    ulong         offsetd,\n    ulong         nb01,\n    ulong         nb11,\n    int           ne0,\n    ulong         nb1,\n    int           ne00_off,\n    int           ne10_off,\n    float         limit,\n    float         alpha\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        float x0 = src0_row[i0];\n        float 
x1 = src1_row[i0];\n\n        x0 = min(x0, limit);\n        x1 = max(min(x1, limit), -limit);\n\n        float out_glu = x0 / (1.0f + exp(-x0 * alpha));\n        out_glu = out_glu * (1.0f + x1);\n\n        dst_row[i0] = out_glu;\n    }\n}\n\n//------------------------------------------------------------------------------\n// geglu_erf\n//------------------------------------------------------------------------------\nkernel void kernel_geglu_erf(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));\n\n        dst_row[i0] = gelu_erf*x1;\n    }\n}\n\nkernel void kernel_geglu_erf_f16(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + 
ne00_off;\n    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const half x0 = src0_row[i0];\n        const half x1 = src1_row[i0];\n\n        const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));\n\n        dst_row[i0] = gelu_erf*x1;\n    }\n}\n\n//------------------------------------------------------------------------------\n// geglu_quick\n//------------------------------------------------------------------------------\nkernel void kernel_geglu_quick(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const float x0 = src0_row[i0];\n        const float x1 = src1_row[i0];\n\n        const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));\n\n        dst_row[i0] = gelu_quick*x1;\n    }\n}\n\nkernel void kernel_geglu_quick_f16(\n    global char * src0,\n    ulong  offset0,\n    global char * src1,\n    ulong  offset1,\n    global char * dst,\n    ulong  offsetd,\n    ulong nb01,\n    ulong nb11,\n    int ne0,\n    ulong nb1,\n    int ne00_off,\n    int ne10_off\n) {\n    src0 = 
(global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;\n    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;\n    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const half x0 = src0_row[i0];\n        const half x1 = src1_row[i0];\n\n        const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));\n\n        dst_row[i0] = gelu_quick*x1;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/group_norm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n// Workgroup must be a subgroup\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_32\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_group_norm(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        int ne,\n        int group_size,\n        float eps\n) {\n    src0 = (global float  *)((global char *)src0 + offset0);\n    dst  = (global float *)((global char *)dst  + offsetd);\n\n    int start = get_group_id(0) * group_size;\n    int end   = start + group_size;\n\n    start += get_local_id(0);\n\n    if (end >= ne) {\n        end = ne;\n    }\n\n    float tmp = 0.0f;\n\n    for (int j = start; j < end; j += get_local_size(0)) {\n        tmp += src0[j];\n    }\n\n    tmp = sub_group_reduce_add(tmp);\n\n    const float mean = tmp / group_size;\n    tmp = 0.0f;\n\n    for (int j = start; j < end; j += get_local_size(0)) {\n        float xi = src0[j] - mean;\n        dst[j] = xi;\n        tmp += xi * xi;\n    }\n\n    tmp = sub_group_reduce_add(tmp);\n\n    const float variance = tmp / group_size;\n    const float scale = 1.0f/sqrt(variance + eps);\n    for (int j = start; j < end; j += 
get_local_size(0)) {\n        dst[j] *= scale;\n    }\n}\n\n//------------------------------------------------------------------------------\n// group_norm_mul_add\n//------------------------------------------------------------------------------\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_32\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_group_norm_mul_add(\n        global float * src0, ulong offset0,\n        global float * src1, ulong offset1,\n        global float * src2, ulong offset2,\n        global float * dst, ulong offsetd,\n        int ne,\n        int group_size,\n        float eps\n) {\n    src0 = (global float *)((global char *)src0 + offset0);\n    src1 = (global float *)((global char *)src1 + offset1);\n    src2 = (global float *)((global char *)src2 + offset2);\n    dst  = (global float *)((global char *)dst  + offsetd);\n\n    int start = get_group_id(0) * group_size;\n    int end = start + group_size;\n    if (end > ne) {\n        end = ne;\n    }\n\n    float sum = 0.0f;\n    float sum_sq = 0.0f;\n\n    for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {\n        float val = src0[j];\n        sum += val;\n        sum_sq += val*val;\n    }\n\n    sum = sub_group_reduce_add(sum);\n    sum_sq = sub_group_reduce_add(sum_sq);\n\n    const float mean = sum / group_size;\n    const float var = sum_sq / group_size - mean * mean;\n    const float scale = rsqrt(var + eps);\n\n    for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {\n        dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/im2col_f16.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\nkernel void kernel_im2col_f16(\n        global float * src1,\n        ulong offset1,\n        global half  * dst,\n        ulong offsetd,\n        ulong batch_offset,\n        ulong delta_offset,\n        long IW,\n        long IH,\n        long IC,\n        long OW,\n        long OH,\n        long KW,\n        long KH,\n        long pelements,\n        long CHW,\n        int  s0,\n        int  s1,\n        int  p0,\n        int  p1,\n        int  d0,\n        int  d1\n) {\n    long i = get_global_id(0);\n    if (i >= pelements) {\n        return;\n    }\n\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global half*)((global char*)dst + offsetd);\n\n    long  ksize = OW * KH;\n    long  kx = i / ksize;\n    long  kd = kx * ksize;\n    long  ky = (i - kd) / OW;\n    long  ix = i % OW;\n\n    long  oh = get_group_id(1);\n    long  batch = get_group_id(2) / IC;\n    long  ic = get_group_id(2) % IC;\n\n    long iiw = ix * s0 + kx * d0 - p0;\n    long iih = oh * s1 + ky * d1 - p1;\n\n    long offset_dst =\n        ((batch * OH + oh) * OW + ix) * CHW +\n        (ic * (KW * KH) + ky * KW + kx);\n\n    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {\n        dst[offset_dst] = 0.0f;\n    } else {\n        long offset_src = ic * delta_offset + batch * batch_offset;\n        dst[offset_dst] = src1[offset_src + iih * IW + iiw];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/im2col_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\nkernel void kernel_im2col_f32(\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        ulong batch_offset,\n        ulong delta_offset,\n        long IW,\n        long IH,\n        long IC,\n        long OW,\n        long OH,\n        long KW,\n        long KH,\n        long pelements,\n        long CHW,\n        int  s0,\n        int  s1,\n        int  p0,\n        int  p1,\n        int  d0,\n        int  d1\n) {\n    long i = get_global_id(0);\n    if (i >= pelements) {\n        return;\n    }\n\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    long  ksize = OW * KH;\n    long  kx = i / ksize;\n    long  kd = kx * ksize;\n    long  ky = (i - kd) / OW;\n    long  ix = i % OW;\n\n    long  oh = get_group_id(1);\n    long  batch = get_group_id(2) / IC;\n    long  ic = get_group_id(2) % IC;\n\n    long iiw = ix * s0 + kx * d0 - p0;\n    long iih = oh * s1 + ky * d1 - p1;\n\n    long offset_dst =\n        ((batch * OH + oh) * OW + ix) * CHW +\n        (ic * (KW * KH) + ky * KW + kx);\n\n    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {\n        dst[offset_dst] = 0.0f;\n    } else {\n        long offset_src = ic * delta_offset + batch * batch_offset;\n        dst[offset_dst] = src1[offset_src + iih * IW + iiw];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/l2_norm.cl",
    "content": "#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_32\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_l2_norm_f32(\n        global void * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        float eps,\n        local float * sum\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);\n    global float * y = (global float *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);\n\n    float sumf = 0;\n\n    // parallel sum\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        sumf += x[i00] * x[i00];\n    }\n    sumf = sub_group_reduce_add(sumf);\n\n    if (get_sub_group_local_id() == 0) {\n        sum[get_sub_group_id()] = sumf;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    // broadcast\n    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {\n       if (get_local_id(0) < i) {\n           sum[get_local_id(0)] += sum[get_local_id(0) + i];\n       }\n    
}\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    const float scale = 1.0f/max(sqrt(sum[0]), eps);\n\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        y[i00] = x[i00] * scale;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mean.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n\n// Most devices have max workgroup size of 1024, so this is enough for subgroup\n// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes\n#define MAX_SUBGROUPS 64\nkernel void kernel_mean_f32(\n    global char *  src0,\n    ulong           offset0,\n    global char *  dst,\n    ulong           offsetd,\n    int             ne00,\n    int             ne01,\n    int             ne02,\n    int             ne03,\n    ulong           nb01,\n    ulong           nb02,\n    ulong           nb03,\n    ulong           nb1,\n    ulong           nb2,\n    ulong           nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst  + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    const int lid = get_local_id(0);\n    const int lsize = get_local_size(0);\n\n    const uint sg_size = get_sub_group_size();\n    const uint sg_id = get_sub_group_id();\n    const uint sg_lid = get_sub_group_local_id();\n\n    __local float lmem[MAX_SUBGROUPS];\n\n    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {\n        return;\n    }\n\n    if(sg_id == 0){\n        lmem[sg_lid] = 0.0f;\n    }\n\n    global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);\n    global float * dst_row = (global float *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);\n\n    float sumf = 0.0f;\n\n    for (int i0 = lid; i0 < ne00; i0 += lsize) {\n        sumf += src_row[i0];\n    }\n\n    sumf = sub_group_reduce_add(sumf);\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    if(sg_lid == 0){\n        lmem[sg_id] = sumf;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    sumf = lmem[sg_lid];\n    sumf = sub_group_reduce_add(sumf);\n\n    if (lid == 0) {\n        dst_row[0] = sumf / ne00;\n    }\n}\n\nkernel void kernel_mean_f32_4(\n    global char *  src0,\n    ulong           offset0,\n    global char *  
dst,\n    ulong           offsetd,\n    int             ne00,\n    int             ne01,\n    int             ne02,\n    int             ne03,\n    ulong           nb01,\n    ulong           nb02,\n    ulong           nb03,\n    ulong           nb1,\n    ulong           nb2,\n    ulong           nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst  + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    const int lid = get_local_id(0);\n    const int lsize = get_local_size(0);\n\n    const uint sg_size = get_sub_group_size();\n    const uint sg_id = get_sub_group_id();\n    const uint sg_lid = get_sub_group_local_id();\n\n    __local float lmem[MAX_SUBGROUPS];\n\n    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {\n        return;\n    }\n\n    if(sg_id == 0){\n        lmem[sg_lid] = 0.0f;\n    }\n\n    global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);\n    global float  * dst_row = (global float  *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);\n\n    float4 sum_vec = (float4)0.0f;\n\n    for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {\n        sum_vec += src_row[i0];\n    }\n\n    float sumf = dot(sum_vec, (float4)(1.0f));\n    sumf = sub_group_reduce_add(sumf);\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    if(sg_lid == 0){\n        lmem[sg_id] = sumf;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    sumf = lmem[sg_lid];\n    sumf = sub_group_reduce_add(sumf);\n\n    if (lid == 0) {\n        dst_row[0] = sumf / ne00;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// mul\n//------------------------------------------------------------------------------\nkernel void kernel_mul(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        int ne13,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03 % ne13;\n    int i12 = i02 % ne12;\n    int i11 = i01 % ne11;\n\n    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;\n    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;\n    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const int i10 = i0 % ne10;\n        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10));\n    }\n}\n\n// assumption: src1 is a row\n// broadcast src1 into src0\nkernel void kernel_mul_row(\n        global float4 * src0,\n        ulong offset0,\n        global float4 * src1,\n        ulong offset1,\n        global float4 * dst,\n        ulong offsetd,\n        int ne\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    src1 = (global float4*)((global 
char*)src1 + offset1);\n    dst = (global float4*)((global char*)dst + offsetd);\n\n    // This performs better than using %.\n    uint gid = get_global_id(0);\n    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne\n    dst[gid] = src0[gid] * src1[idx1];\n}\n\nkernel void kernel_mul_f16(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        int ne13,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03 % ne13;\n    int i12 = i02 % ne12;\n    int i11 = i01 % ne11;\n\n    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;\n    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;\n    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const int i10 = i0 % ne10;\n        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) * *((global half *)(src1_ptr + i10*nb10));\n    }\n}\n\nkernel void kernel_mul_row_f16(\n        global half4 * src0,\n        ulong offset0,\n        global half4 * src1,\n        ulong offset1,\n        global half4 * dst,\n        ulong offsetd,\n        int ne\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    src1 = (global half4*)((global char*)src1 + offset1);\n    dst = 
(global half4*)((global char*)dst + offsetd);\n\n    // This performs better than using %.\n    uint gid = get_global_id(0);\n    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne\n    dst[gid] = src0[gid] * src1[idx1];\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl",
    "content": "// src0_q, src0_d, src1 are transposed as a preprocessing step\n// 4-bit weights are transposed in groups of 4 (unsigned short int)\n// consider weights originally \"next to each other\", now \"on top of each other\"\n// each fiber computes a 8x4 tile of output elements\n// using unshuffled weights\n\n#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n\n#ifdef cl_qcom_reqd_sub_group_size\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_128\n#endif\n\nkernel void kernel_mul_mat_Ab_Bi_8x4(\n        global const ushort * src0_q,       // quantized A\n        global const half  * src0_d,        // A scales\n        __read_only image1d_buffer_t src1,  // B (1d image)\n        global float * dst,                 // C\n        int m,                              // M\n        int n,                              // N with padding\n        int k,                              // K\n        int n_no_padding                    // N without padding\n) {\n\n    int m_4 = m >> 2;\n    int n_4 = n >> 2;\n\n    int gy = get_global_id(0);\n    int gx = get_global_id(1);\n    int gx_2 = gx << 2;\n\n    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements\n    half8 B; // registers for activations\n    half4 dequantized_weights; // registers for dequantized weights\n    __global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights\n    __global const half* scale_ptr = src0_d + gx_2; // pointer for scales\n\n    for(int i=0; i<k; i+=4){ //loop through K dimension\n\n        B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));\n        B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);\n\n        // keep (i/4) and (i/32) in parenthesis, rounds down\n        // load 4 consecutive groups of 4 weights\n        ushort4 bits4 = vload4(0, 
weight_ptr + (i/4)*(m)); // (i/4) because weights grouped in 4s\n\n        // load 4 consecutive scales\n        half4 scale = vload4(0, scale_ptr + (i/32)*(m));// (i/32) because 1 scale per 32 elements\n\n        // j=0\n        dequantized_weights.s0 = ((bits4.s0 & (0x000F)) - 8) * scale.s0; // dequantize a row of the 16 weights\n        dequantized_weights.s1 = ((bits4.s1 & (0x000F)) - 8) * scale.s1;\n        dequantized_weights.s2 = ((bits4.s2 & (0x000F)) - 8) * scale.s2;\n        dequantized_weights.s3 = ((bits4.s3 & (0x000F)) - 8) * scale.s3;\n        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate\n        c1 += B * dequantized_weights.s1;\n        c2 += B * dequantized_weights.s2;\n        c3 += B * dequantized_weights.s3;\n\n        // j=1\n        B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));\n        B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);\n        dequantized_weights.s0 = (((bits4.s0 & (0x00F0)) >> 4) - 8) * scale.s0; // dequantize a row of the 16 weights\n        dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1;\n        dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2;\n        dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3;\n        c0 += B * dequantized_weights.s0; //vector-scalar multiplication to accumulate\n        c1 += B * dequantized_weights.s1;\n        c2 += B * dequantized_weights.s2;\n        c3 += B * dequantized_weights.s3;\n\n        // j=2\n        B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));\n        B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);\n        dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights\n        dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1;\n        dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2;\n        dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3;\n        c0 
+= B * dequantized_weights.s0; // vector-scalar multiplication to accumulate\n        c1 += B * dequantized_weights.s1;\n        c2 += B * dequantized_weights.s2;\n        c3 += B * dequantized_weights.s3;\n\n        // j=3\n        B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));\n        B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);\n        dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights\n        dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1;\n        dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2;\n        dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3;\n        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate\n        c1 += B * dequantized_weights.s1;\n        c2 += B * dequantized_weights.s2;\n        c3 += B * dequantized_weights.s3;\n    }\n\n    int idx = (gy<<3)*m + (gx<<2); // vectorized store 16 elements\n\n    // conditional check if store is to a valid location. 
Required when N is not a multiple of 8\n    // if statements allow registers to be reused for each store\n    // provides a performance boost due to reduced register footprint, which increases number of concurrent waves\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mat_f16_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#if defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#else\n#define REQD_SUBGROUP_SIZE_128\n#endif\n\n#define OPWM 64\n#define OPWN 64\n#define CPWK 8\n#define OPTM 4\n#define OPTN 8\n\n#define WG_M (OPWM / OPTM)\n#define WG_N (OPWN / OPTN)\n#define VEC_K (CPWK / 4)\n\nREQD_SUBGROUP_SIZE_128\n__kernel void mul_mat_f16_f32(\n    const int M, const int N, const int K,\n    __global const void* A_void, ulong A_offset,\n    __global const void* B_void, ulong B_offset,\n    __global       void* C_void, ulong C_offset) {\n\n    __global const half*  A = (__global const half* )((__global const char*)A_void + A_offset);\n    __global const float* B = (__global const float*)((__global const char*)B_void + B_offset);\n    __global       float* C = (__global       float*)((__global       char*)C_void + C_offset);\n\n    const int lidm = get_local_id(0);\n    const int lidn = get_local_id(1);\n    const int lid = lidn * WG_M + lidm;\n\n    const int offsetM = get_group_id(0) * OPWM;\n    const int offsetN = get_group_id(1) * OPWN;\n\n    __local half4  Alocal[OPWM][VEC_K];\n    __local float4 Blocal[OPWN][VEC_K];\n\n    float sum[OPTM][OPTN];\n\n    for (int wm = 0; wm < OPTM; wm++) {\n        for (int wn = 0; wn < OPTN; wn++) {\n            sum[wm][wn] = 0.0f;\n        }\n    }\n\n    const int numTiles = (K + CPWK - 1) / CPWK;\n\n    const int load_row_a = lid % OPWM;\n    const int load_vec_k_a = lid / OPWM;\n    const int global_row_a = offsetM + load_row_a;\n\n    const int load_row_b = lid % OPWN;\n    const int load_vec_k_b = lid / OPWN;\n    const int global_row_b = offsetN + load_row_b;\n\n    for (int t = 0; t < numTiles; t++) {\n        const int k_start = t * CPWK;\n        const int k_vec_start_a = k_start + load_vec_k_a * 4;\n        const int k_vec_start_b = 
k_start + load_vec_k_b * 4;\n\n        if (global_row_a < M && k_vec_start_a < K) {\n            if (k_vec_start_a + 3 < K) {\n                Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a);\n            } else {\n                half4 tempA = (half4)(0.0h);\n                if (k_vec_start_a < K) tempA.s0 = A[global_row_a * K + k_vec_start_a];\n                if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1];\n                if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2];\n                Alocal[load_row_a][load_vec_k_a] = tempA;\n            }\n        } else {\n            Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h);\n        }\n\n        if (global_row_b < N && k_vec_start_b < K) {\n            if (k_vec_start_b + 3 < K) {\n                Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b);\n            } else {\n                float4 tempB = (float4)(0.0f);\n                if (k_vec_start_b < K) tempB.s0 = B[global_row_b * K + k_vec_start_b];\n                if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1];\n                if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2];\n                Blocal[load_row_b][load_vec_k_b] = tempB;\n            }\n        } else {\n            Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f);\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        #pragma unroll\n        for (int k_vec = 0; k_vec < VEC_K; k_vec++) {\n            float4 a_fvecs[OPTM];\n            int current_row_a = lidm;\n            for (int wm = 0; wm < OPTM; wm++) {\n                a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]);\n                current_row_a += WG_M;\n            }\n\n            float4 b_fvecs[OPTN];\n            int current_row_b = lidn;\n            for (int wn = 0; wn < OPTN; wn++) {\n                b_fvecs[wn] = 
Blocal[current_row_b][k_vec];\n                current_row_b += WG_N;\n            }\n\n            for (int wm = 0; wm < OPTM; wm++) {\n                for (int wn = 0; wn < OPTN; wn++) {\n                    sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    for (int wm = 0; wm < OPTM; wm++) {\n        int globalRow = offsetM + lidm + wm * WG_M;\n        if (globalRow < M) {\n            for (int wn = 0; wn < OPTN; wn++) {\n                int globalCol = offsetN + lidn + wn * WG_N;\n                if (globalCol < N) {\n                    C[globalCol * M + globalRow] = sum[wm][wn];\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n\n#define LM_FIRST_256B   0\n#define LM_SECOND_256B  64\n#define LM_THIRD_256B   128\n#define LM_FOURTH_256B  192\n\n\ninline float16 mm_load_a(\n    image1d_buffer_t matrix_A,\n    uint subMatrixAStartInElements,\n    int nb01,\n    int line_stride_matrix_A_in_bytes\n) {\n    __private float8 regA;\n    size_t sub_block_id_m = get_local_id(0);\n\n#ifdef KQV\n    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);\n#else // KQ\n    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);\n#endif\n\n    regA.s0123  = read_imagef(matrix_A, a_texCoord/4);\n    regA.s4567  = read_imagef(matrix_A, (a_texCoord+4)/4);\n\n    return convert_float16(as_half16(regA));\n}\n\ninline float4 alu_32(\n    float16 regA,\n    __local float4* matrix_B_vec\n) {\n\n    __private float4 rC = 0;\n    int i = get_sub_group_id() * 64;\n\n    rC += regA.s0  * matrix_B_vec[i];\n    rC += regA.s1  * matrix_B_vec[i + 16];\n    rC += regA.s4  * matrix_B_vec[i + 1];\n    rC += regA.s5  * matrix_B_vec[i + 17];\n    rC += regA.s8  * matrix_B_vec[i + 2];\n    rC += regA.s9  * matrix_B_vec[i + 18];\n    rC += regA.sc  * matrix_B_vec[i + 3];\n    rC += regA.sd  * matrix_B_vec[i + 19];\n\n    i += 32;\n\n    rC += regA.s2  * matrix_B_vec[i];\n     rC += regA.s3  * matrix_B_vec[i + 16];\n    rC += regA.s6  * matrix_B_vec[i + 1];\n    rC += regA.s7  * matrix_B_vec[i + 17];\n    rC += regA.sa  * matrix_B_vec[i + 2];\n    rC += regA.sb  * matrix_B_vec[i + 18];\n    rC += regA.se  * matrix_B_vec[i + 3];\n    rC += regA.sf  * matrix_B_vec[i + 19];\n\n    return rC;\n}\n\ninline float16 alu_16(\n    float16 regA,\n    __local float* matrix_B_local\n) {\n    float16 out;\n    __local float4* matrix_B_vec = (__local float4*)matrix_B_local;\n\n    out.s0123 = alu_32(regA, matrix_B_vec);\n    out.s4567 = alu_32(regA, matrix_B_vec + 
4);\n    out.s89ab = alu_32(regA, matrix_B_vec + 8);\n    out.scdef = alu_32(regA, matrix_B_vec + 12);\n\n    return out;\n}\n\ninline void mm_mad(\n    __local float* matrix_B_local,\n    float16 regA,\n    float8 regB,\n    uint b_localOffsetInWords,\n    float16* regC0_ptr,\n    float16* regC1_ptr\n) {\n    int offset = b_localOffsetInWords + get_sub_group_id() * 256;\n\n    matrix_B_local[offset + LM_FIRST_256B] = regB.s0;\n    matrix_B_local[offset + LM_SECOND_256B] = regB.s1;\n    matrix_B_local[offset + LM_THIRD_256B] = regB.s2;\n    matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;\n\n    float16 add0 = alu_16(regA, matrix_B_local);\n    *regC0_ptr += add0;\n\n    matrix_B_local[offset + LM_FIRST_256B] = regB.s4;\n    matrix_B_local[offset + LM_SECOND_256B] = regB.s5;\n    matrix_B_local[offset + LM_THIRD_256B] = regB.s6;\n    matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;\n\n    float16 add1 = alu_16(regA, matrix_B_local);\n    *regC1_ptr += add1;\n}\n\ninline void mm_store_c_N(\n    __write_only image1d_buffer_t matrix_C,\n    float16 regC0,\n    float16 regC1,\n    uint subMatrixCStartInElements,\n    int line_stride_matrix_C_in_bytes,\n    int mask\n) {\n    size_t sub_block_id_m = get_local_id(0);\n\n    uint strideInWords     = line_stride_matrix_C_in_bytes/4;\n    uint c_coordInWords_0  = (subMatrixCStartInElements + sub_block_id_m);\n\n    uint c_coordInWords_1  = c_coordInWords_0 + 1  * strideInWords;\n    uint c_coordInWords_2  = c_coordInWords_0 + 2  * strideInWords;\n    uint c_coordInWords_3  = c_coordInWords_0 + 3  * strideInWords;\n    uint c_coordInWords_4  = c_coordInWords_0 + 4  * strideInWords;\n    uint c_coordInWords_5  = c_coordInWords_0 + 5  * strideInWords;\n    uint c_coordInWords_6  = c_coordInWords_0 + 6  * strideInWords;\n    uint c_coordInWords_7  = c_coordInWords_0 + 7  * strideInWords;\n    uint c_coordInWords_8  = c_coordInWords_0 + 8  * strideInWords;\n    uint c_coordInWords_9  = c_coordInWords_0 + 9  * 
strideInWords;\n    uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;\n    uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;\n    uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;\n    uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;\n    uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;\n    uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;\n    uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;\n    uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;\n    uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;\n    uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;\n    uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;\n    uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;\n    uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;\n    uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;\n    uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;\n    uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;\n    uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;\n    uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;\n    uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;\n    uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;\n    uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;\n    uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;\n\n    if (mask > 0)  { write_imagef(matrix_C, c_coordInWords_0, regC0.s0);  }\n    if (mask > 1)  { write_imagef(matrix_C, c_coordInWords_1, regC0.s1);  }\n    if (mask > 2)  { write_imagef(matrix_C, c_coordInWords_2, regC0.s2);  }\n    if (mask > 3)  { write_imagef(matrix_C, c_coordInWords_3, regC0.s3);  }\n    if (mask > 4)  { write_imagef(matrix_C, c_coordInWords_4, regC0.s4);  }\n    if (mask > 5)  { write_imagef(matrix_C, c_coordInWords_5, regC0.s5);  }\n  
  if (mask > 6)  { write_imagef(matrix_C, c_coordInWords_6, regC0.s6);  }\n    if (mask > 7)  { write_imagef(matrix_C, c_coordInWords_7, regC0.s7);  }\n    if (mask > 8)  { write_imagef(matrix_C, c_coordInWords_8, regC0.s8);  }\n    if (mask > 9)  { write_imagef(matrix_C, c_coordInWords_9, regC0.s9);  }\n    if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }\n    if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }\n    if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }\n    if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }\n    if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }\n    if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }\n    if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }\n    if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }\n    if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }\n    if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }\n    if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }\n    if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }\n    if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }\n    if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }\n    if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }\n    if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }\n    if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }\n    if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }\n    if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }\n    if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }\n    if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }\n    if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); 
}\n}\n\n#define TILESIZE_K 16\n#define TILESIZE_M 64\n#define TILESIZE_N 32\n#ifdef KQV\n__kernel void mul_mm_f16_f32_kqv(\n#else\n__kernel void mul_mm_f16_f32_kq(\n#endif\n        __read_only  image1d_buffer_t matrix_A,\n        int offset0,\n        __global float* matrix_B,\n        int offset1,\n        __write_only image1d_buffer_t matrix_C,\n        int offsetd,\n        int M, int K, int N,\n        int D_A,\n        int D_B,\n        int nb01\n) {\n\n    uint block_id_m = get_global_id(1);\n    uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);\n    uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);\n\n    __private float16  regA;\n    __private float8   regB;\n    __private float16 regC0;\n    __private float16 regC1;\n\n    const uint col   = block_id_m * TILESIZE_M;\n    const uint row   = block_id_n * TILESIZE_N;\n    const uint depth_A = block_id_d / (D_B/D_A);\n    const uint depth_B = block_id_d;\n\n#ifdef KQV\n    int line_stride_matrix_A_in_bytes = nb01 * M;\n    int line_stride_matrix_B_in_bytes = K * N * 4;\n#else\n    int line_stride_matrix_A_in_bytes = K * D_A * 2;\n    int line_stride_matrix_B_in_bytes = K * D_B * 4;\n#endif\n\n    int line_stride_matrix_C_in_bytes = M * 4;\n\n    const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;\n    const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;\n\n    size_t sub_block_id_m = get_local_id(0);\n\n    uint b_localOffsetInWords = (sub_block_id_m/16)*16\n                           + ((((sub_block_id_m)>>0)&1)<<2)\n                           + ((((sub_block_id_m)>>1)&1)<<3)\n                           + ((((sub_block_id_m)>>2)&1)<<0)\n                           + ((((sub_block_id_m)>>3)&1)<<1);\n\n    uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};\n    uint b_globalOffsetInWords00, b_globalOffsetInWords16;\n#ifdef KQV\n    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;\n    
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);\n    uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;\n    uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;\n#else\n    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;\n    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);\n    uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;\n    uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;\n#endif\n\n    __local float matrix_B_local[1024];\n\n    for (uint step=0; step < K; step+=TILESIZE_K) {\n        size_t sub_block_id_m = get_local_id(0);\n        regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);\n\n        uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;\n        uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;\n\n        regB.s0123 = vload4(b_coordInWords00/4, matrix_B);\n        regB.s4567 = vload4(b_coordInWords16/4, matrix_B);\n\n        mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);\n\n        subMatrixAStartInElements += TILESIZE_K;\n        subMatrixBStartInElements += TILESIZE_K;\n    }\n\n    uint subMatrixCStartInElements = depth_B * N * M + row * M + col;\n    mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));\n}\n\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define LOAD_VEC_A 4\n#define LOAD_VEC_B 4\n\n#define BM 64\n#define BN 64\n#define BK 16\n#define TM 4\n#define TN 8\n\nkernel void kernel_mul_mm_f16_f32_l4_lm(\n    global half4 * src0,\n    ulong offset0,\n    global float4 * src1,\n    ulong offset1,\n    global float * dst,\n    ulong offsetd,\n\n    int ne00,\n    int ne01,\n    int ne02,\n    int ne11,\n    int ne12,\n\n    int stride_a,\n    int stride_b,\n    int stride_d,\n\n    int batch_stride_a,\n    int batch_stride_b,\n    int batch_stride_d,\n\n    int r2,\n    int r3\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    local half  buf_a[BM * BK];\n    local float buf_b[BN * BK];\n\n    const int batch_idx = get_global_id(2);\n\n    const int i13 = batch_idx / ne12;\n    const int i12 = batch_idx % ne12;\n\n    const int i03 = i13 / r3;\n    const int i02 = i12 / r2;\n\n    const int batch_idx_a = i03 * ne02 + i02;\n\n    const int ir = get_group_id(0);\n    const int ic = get_group_id(1);\n\n    const int tid = get_local_id(0);\n    const int th_r  = tid % (BM / TM);\n    const int th_c  = tid / (BM / TM);\n\n    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);\n    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);\n    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);\n    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);\n\n    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;\n    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;\n\n    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;\n    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;\n\n    float sums[TM * TN];\n    half  cache_a[TM];\n    float cache_b[TN];\n\n    for (int i = 0; i < TM * TN; i++) {\n        sums[i] = 0.0f;\n    }\n\n  
  for (int block = 0; block < ne00; block += BK) {\n        for (int l = 0; l < BM; l += loadstride_a) {\n            if (ir*BM + loadc_a + l < ne01) {\n                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;\n                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;\n                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;\n                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;\n                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;\n            } else {\n                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h;\n                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h;\n                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h;\n                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h;\n            }\n        }\n\n        for (int l = 0; l < BN; l += loadstride_b) {\n            if (ic*BN + loadc_b + l < ne11) {\n                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;\n            } else {\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h;\n            }\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        pos_a += BK / LOAD_VEC_A;\n        pos_b += BK / LOAD_VEC_B;\n\n        for (int i = 0; i < BK; i++) {\n 
           for (int j = 0; j < TM; j++) {\n                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];\n            }\n            for (int j = 0; j < TN; j++) {\n                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];\n            }\n\n            for (int cc = 0; cc < TN; cc++) {\n                for (int cr = 0; cr < TM; cr++) {\n                    const int sums_idx = cc*TM + cr;\n                    sums[sums_idx] = mad(convert_float(cache_a[cr]), cache_b[cc], sums[sums_idx]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const int dr = ir * BM + th_r * TM;\n    const int dc = ic * BN + th_c * TN;\n\n    const int offsets = batch_idx * batch_stride_d;\n\n    for (int cc = 0; cc < TN; cc++) {\n        for (int cr = 0; cr < TM; cr++) {\n            if (dr + cr < ne01 && dc + cc < ne11) {\n                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define LOAD_VEC_A 4\n#define LOAD_VEC_B 4\n\n#define BM 64\n#define BN 64\n#define BK 16\n#define TM 4\n#define TN 8\n\nkernel void kernel_mul_mm_f32_f32_l4_lm(\n    global float4 * src0,\n    ulong offset0,\n    global float4 * src1,\n    ulong offset1,\n    global float * dst,\n    ulong offsetd,\n\n    int ne00,\n    int ne01,\n    int ne02,\n    int ne11,\n    int ne12,\n\n    int stride_a,\n    int stride_b,\n    int stride_d,\n\n    int batch_stride_a,\n    int batch_stride_b,\n    int batch_stride_d,\n\n    int r2,\n    int r3\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    local float buf_a[BM * BK];\n    local float buf_b[BN * BK];\n\n    const int batch_idx = get_global_id(2);\n\n    const int i13 = batch_idx / ne12;\n    const int i12 = batch_idx % ne12;\n\n    const int i03 = i13 / r3;\n    const int i02 = i12 / r2;\n\n    const int batch_idx_a = i03 * ne02 + i02;\n\n    const int ir = get_group_id(0);\n    const int ic = get_group_id(1);\n\n    const int tid = get_local_id(0);\n    const int th_r  = tid % (BM / TM);\n    const int th_c  = tid / (BM / TM);\n\n    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);\n    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);\n    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);\n    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);\n\n    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;\n    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;\n\n    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;\n    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;\n\n    float sums[TM * TN];\n    float cache_a[TM];\n    float cache_b[TN];\n\n    for (int i = 0; i < TM * TN; i++) {\n        sums[i] = 0.0f;\n    
}\n\n    for (int block = 0; block < ne00; block += BK) {\n        for (int l = 0; l < BM; l += loadstride_a) {\n            if (ir*BM + loadc_a + l < ne01) {\n                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;\n                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;\n                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;\n                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;\n                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;\n            } else {\n                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;\n            }\n        }\n\n        for (int l = 0; l < BN; l += loadstride_b) {\n            if (ic*BN + loadc_b + l < ne11) {\n                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;\n            } else {\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;\n            }\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        pos_a += BK / LOAD_VEC_A;\n        pos_b += BK / LOAD_VEC_B;\n\n        for (int i = 0; i < BK; 
i++) {\n            for (int j = 0; j < TM; j++) {\n                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];\n            }\n\n            for (int j = 0; j < TN; j++) {\n                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];\n            }\n\n            for (int cc = 0; cc < TN; cc++) {\n                for (int cr = 0; cr < TM; cr++) {\n                    const int sums_idx = cc*TM + cr;\n                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const int dr = ir * BM + th_r * TM;\n    const int dc = ic * BN + th_c * TN;\n\n    const int offsets = batch_idx * batch_stride_d;\n\n    for (int cc = 0; cc < TN; cc++) {\n        for (int cr = 0; cr < TM; cr++) {\n            if (dr + cr < ne01 && dc + cc < ne11) {\n                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define LOAD_VEC_A 8\n#define LOAD_VEC_B 4\n\n#define BM 64\n#define BN 64\n#define BK 32\n#define TM 4\n#define TN 8\n\nkernel void kernel_mul_mm_q4_0_f32_l4_lm(\n    global uchar4 * src0_q,\n    global half   * src0_d,\n    global float4 * src1,\n    ulong offset1,\n    global float  * dst,\n    ulong offsetd,\n\n    int ne00,\n    int ne01,\n    int ne02,\n    int ne11,\n    int ne12,\n\n    int stride_a,\n    int stride_b,\n    int stride_d,\n\n    int batch_stride_a,\n    int batch_stride_b,\n    int batch_stride_d,\n\n    int r2,\n    int r3\n) {\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst  = (global float *)((global char*)dst  + offsetd);\n\n    local float buf_a[BM * BK];\n    local float buf_b[BN * BK];\n\n    const int batch_idx = get_global_id(2);\n\n    const int i13 = batch_idx / ne12;\n    const int i12 = batch_idx % ne12;\n\n    const int i03 = i13 / r3;\n    const int i02 = i12 / r2;\n\n    const int batch_idx_a = i03 * ne02 + i02;\n\n    const int ir = get_group_id(0);\n    const int ic = get_group_id(1);\n\n    const int tid = get_local_id(0);\n    const int th_r  = tid % (BM / TM);\n    const int th_c  = tid / (BM / TM);\n\n    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);\n    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);\n    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);\n    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);\n\n    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;\n    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;\n\n    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;\n    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;\n\n    float sums[TM * TN];\n    float cache_a[TM];\n    float cache_b[TN];\n\n    for (int i = 0; i < TM * TN; i++) {\n        sums[i] = 0.0f;\n    }\n\n    for (int block = 0; block < ne00; block 
+= BK) {\n        for (int l = 0; l < BM; l += loadstride_a) {\n            if (ir*BM + loadc_a + l < ne01) {\n                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;\n                int ib  = idx / 4;\n                int iqs = idx % 4;\n\n                float d = (float)src0_d[ib];\n                global uchar4 * qs = src0_q + ib*4 + iqs;\n                uchar4 q = *qs;\n                float4 v1 = (convert_float4((uchar4)((q.s0   )&0x0F, (q.s1   )&0x0F, (q.s2   )&0x0F, (q.s3   )&0x0F)) - 8.0f)*d;\n                float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;\n\n                buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = v1.s0;\n                buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = v1.s1;\n                buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = v1.s2;\n                buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = v1.s3;\n                buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;\n                buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;\n                buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;\n                buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;\n            } else {\n                buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;\n            }\n        }\n\n        for (int l = 0; l < BN; l += loadstride_b) {\n            if (ic*BN + loadc_b + l < ne11) {\n                int idx = pos_b + (loadc_b + l) * 
stride_b / LOAD_VEC_B + loadr_b;\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;\n            } else {\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;\n            }\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        pos_a += BK / LOAD_VEC_A;\n        pos_b += BK / LOAD_VEC_B;\n\n        for (int i = 0; i < BK; i++) {\n            for (int j = 0; j < TM; j++) {\n                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];\n            }\n\n            for (int j = 0; j < TN; j++) {\n                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];\n            }\n\n            for (int cc = 0; cc < TN; cc++) {\n                for (int cr = 0; cr < TM; cr++) {\n                    const int sums_idx = cc*TM + cr;\n                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const int dr = ir * BM + th_r * TM;\n    const int dc = ic * BN + th_c * TN;\n\n    const int offsets = batch_idx * batch_stride_d;\n\n    for (int cc = 0; cc < TN; cc++) {\n        for (int cr = 0; cr < TM; cr++) {\n            if (dr + cr < ne01 && dc + cc < ne11) {\n                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define LOAD_VEC_A 8\n#define LOAD_VEC_B 4\n\n#define BM 64\n#define BN 64\n#define BK 32\n#define TM 4\n#define TN 8\n\nkernel void kernel_mul_mm_q4_1_f32_l4_lm(\n    global uchar4 * src0_q,\n    global half   * src0_d,\n    global half   * src0_m,\n    global float4 * src1,\n    ulong offset1,\n    global float  * dst,\n    ulong offsetd,\n\n    int ne00,\n    int ne01,\n    int ne02,\n    int ne11,\n    int ne12,\n\n    int stride_a,\n    int stride_b,\n    int stride_d,\n\n    int batch_stride_a,\n    int batch_stride_b,\n    int batch_stride_d,\n\n    int r2,\n    int r3\n) {\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst  = (global float *)((global char*)dst  + offsetd);\n\n    local float buf_a[BM * BK];\n    local float buf_b[BN * BK];\n\n    const int batch_idx = get_global_id(2);\n\n    const int i13 = batch_idx / ne12;\n    const int i12 = batch_idx % ne12;\n\n    const int i03 = i13 / r3;\n    const int i02 = i12 / r2;\n\n    const int batch_idx_a = i03 * ne02 + i02;\n\n    const int ir = get_group_id(0);\n    const int ic = get_group_id(1);\n\n    const int tid = get_local_id(0);\n    const int th_r  = tid % (BM / TM);\n    const int th_c  = tid / (BM / TM);\n\n    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);\n    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);\n    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);\n    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);\n\n    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;\n    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;\n\n    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;\n    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;\n\n    float sums[TM * TN];\n    float cache_a[TM];\n    float cache_b[TN];\n\n    for (int i = 0; i < TM * TN; i++) {\n        sums[i] = 0.0f;\n    }\n\n    for (int 
block = 0; block < ne00; block += BK) {\n        for (int l = 0; l < BM; l += loadstride_a) {\n            if (ir*BM + loadc_a + l < ne01) {\n                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;\n                int ib  = idx / 4;\n                int iqs = idx % 4;\n\n                float d = (float)src0_d[ib];\n                float m = (float)src0_m[ib];\n                global uchar4 * qs = src0_q + ib*4 + iqs;\n                uchar4 q = *qs;\n                float4 v1 = (convert_float4((uchar4)((q.s0   )&0x0F, (q.s1   )&0x0F, (q.s2   )&0x0F, (q.s3   )&0x0F)))*d + m;\n                float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;\n\n                buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = v1.s0;\n                buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = v1.s1;\n                buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = v1.s2;\n                buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = v1.s3;\n                buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;\n                buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;\n                buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;\n                buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;\n            } else {\n                buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;\n            }\n        }\n\n        for (int l = 0; l < BN; l += loadstride_b) {\n            if (ic*BN + 
loadc_b + l < ne11) {\n                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;\n            } else {\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;\n            }\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        pos_a += BK / LOAD_VEC_A;\n        pos_b += BK / LOAD_VEC_B;\n\n        for (int i = 0; i < BK; i++) {\n            for (int j = 0; j < TM; j++) {\n                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];\n            }\n\n            for (int j = 0; j < TN; j++) {\n                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];\n            }\n\n            for (int cc = 0; cc < TN; cc++) {\n                for (int cr = 0; cr < TM; cr++) {\n                    const int sums_idx = cc*TM + cr;\n                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const int dr = ir * BM + th_r * TM;\n    const int dc = ic * BN + th_c * TN;\n\n    const int offsets = batch_idx * batch_stride_d;\n\n    for (int cc = 0; cc < TN; cc++) {\n        for (int cr = 0; cr < TM; cr++) {\n            if (dr + cr < ne01 && dc + cc < ne11) {\n                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define LOAD_VEC_A 2\n#define LOAD_VEC_B 4\n\n#define BM 64\n#define BN 64\n#define BK 32\n#define TM 4\n#define TN 8\n\nkernel void kernel_mul_mm_q6_k_f32_l4_lm(\n    global uchar * src0_ql,\n    global uchar * src0_qh,\n    global char  * src0_s,\n    global half  * src0_d,\n    global float4 * src1,\n    ulong offset1,\n    global float  * dst,\n    ulong offsetd,\n\n    int ne00,\n    int ne01,\n    int ne02,\n    int ne11,\n    int ne12,\n\n    int stride_a,\n    int stride_b,\n    int stride_d,\n\n    int batch_stride_a,\n    int batch_stride_b,\n    int batch_stride_d,\n\n    int r2,\n    int r3\n) {\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst  = (global float *)((global char*)dst  + offsetd);\n\n    local float buf_a[BM * BK];\n    local float buf_b[BN * BK];\n\n    const int batch_idx = get_global_id(2);\n\n    const int i13 = batch_idx / ne12;\n    const int i12 = batch_idx % ne12;\n\n    const int i03 = i13 / r3;\n    const int i02 = i12 / r2;\n\n    const int batch_idx_a = i03 * ne02 + i02;\n\n    const int ir = get_group_id(0);\n    const int ic = get_group_id(1);\n\n    const int tid = get_local_id(0);\n    const int th_r  = tid % (BM / TM);\n    const int th_c  = tid / (BM / TM);\n\n    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);\n    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);\n    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);\n    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);\n\n    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;\n    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;\n\n    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;\n    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;\n\n    float sums[TM * TN];\n    float cache_a[TM];\n    float cache_b[TN];\n\n    for (int i = 0; i < TM * TN; i++) {\n        sums[i] = 
0.0f;\n    }\n\n    for (int block = 0; block < ne00; block += BK) {\n        for (int l = 0; l < BM; l += loadstride_a) {\n            if (ir*BM + loadc_a + l < ne01) {\n                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;\n\n                int ib = idx / 128;                  // 2 values per idx\n                int iqs = idx % 128;                 // 0..127\n\n                int n = iqs / 64;                    // 0,1\n                int b = (iqs % 64) / 32;             // 0,1\n                int is_b = (iqs % 16) / 8;           // 0,1\n                int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6\n                int is = 8 * n + qhshift + is_b;     // 0..15\n                int qsi = n * 64 + (iqs % 32) * 2;   // 0,2,4..126\n                int qhi = n * 32 + (iqs % 16) * 2;   // 0,2,4..62\n\n                float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];\n\n                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);\n                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);\n            } else {\n                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;\n            }\n        }\n\n        for (int l = 0; l < BN; l += loadstride_b) {\n            if (ic*BN + loadc_b + l < ne11) {\n                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 
src1[idx].s2;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;\n            } else {\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;\n            }\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        pos_a += BK / LOAD_VEC_A;\n        pos_b += BK / LOAD_VEC_B;\n\n        for (int i = 0; i < BK; i++) {\n            for (int j = 0; j < TM; j++) {\n                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];\n            }\n\n            for (int j = 0; j < TN; j++) {\n                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];\n            }\n\n            for (int cc = 0; cc < TN; cc++) {\n                for (int cr = 0; cr < TM; cr++) {\n                    const int sums_idx = cc*TM + cr;\n                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const int dr = ir * BM + th_r * TM;\n    const int dc = ic * BN + th_c * TN;\n\n    const int offsets = batch_idx * batch_stride_d;\n\n    for (int cc = 0; cc < TN; cc++) {\n        for (int cr = 0; cr < TM; cr++) {\n            if (dr + cr < ne01 && dc + cc < ne11) {\n                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n\n#ifdef cl_qcom_reqd_sub_group_size\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_128\n#endif\n\nkernel void kernel_mul_mm_q8_0_f32_8x4(\n        global const uint * src0_q,\n        global const half  * src0_d,\n        __read_only image1d_buffer_t src1,\n        global float * dst,\n        int k,\n        int m,\n        int n,\n        int n_no_padding,\n        ulong offsetd\n) {\n\n    int m_4 = m >> 2;\n    int n_4 = n >> 2;\n\n    int gy   = get_global_id(0);\n    int gx   = get_global_id(1);\n    int gx_2 = gx << 2;\n    dst  = (global float *)((global char*)dst  + offsetd);\n\n\n    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;\n    half8 B;\n    half4 deq;\n\n    __global const uint* wptr = src0_q + gx_2;\n    __global const half* sptr = src0_d + gx_2;\n\n      for (int i = 0; i < k; i += 4) {\n        uint4 pack4 = vload4(0, wptr + (i / 4) * m);\n        half4 scale = vload4(0, sptr + (i / 32) * m);\n\n        char4 p0 = as_char4(pack4.s0);\n        char4 p1 = as_char4(pack4.s1);\n        char4 p2 = as_char4(pack4.s2);\n        char4 p3 = as_char4(pack4.s3);\n\n        // ------------------- j = 0 (k = i+0) -------------------\n        B.s0123 = read_imageh(src1, gy * 2 + (i + 0) * n_4);\n        B.s4567 = read_imageh(src1, gy * 2 + (i + 0) * n_4 + 1);\n\n        half4 wj0 = convert_half4((char4)(p0.s0, p1.s0, p2.s0, p3.s0)) * scale;\n\n        c0 += B * wj0.s0;\n        c1 += B * wj0.s1;\n        c2 += B * wj0.s2;\n        c3 += B * wj0.s3;\n\n        // ------------------- j = 1 (k = i+1) -------------------\n        B.s0123 = read_imageh(src1, gy * 2 + (i + 1) * n_4);\n        B.s4567 = read_imageh(src1, gy * 2 + (i + 1) * n_4 + 1);\n\n        half4 wj1 = 
convert_half4((char4)(p0.s1, p1.s1, p2.s1, p3.s1)) * scale;\n\n        c0 += B * wj1.s0;\n        c1 += B * wj1.s1;\n        c2 += B * wj1.s2;\n        c3 += B * wj1.s3;\n\n        // ------------------- j = 2 (k = i+2) -------------------\n        B.s0123 = read_imageh(src1, gy * 2 + (i + 2) * n_4);\n        B.s4567 = read_imageh(src1, gy * 2 + (i + 2) * n_4 + 1);\n\n        half4 wj2 = convert_half4((char4)(p0.s2, p1.s2, p2.s2, p3.s2)) * scale;\n\n        c0 += B * wj2.s0;\n        c1 += B * wj2.s1;\n        c2 += B * wj2.s2;\n        c3 += B * wj2.s3;\n\n        // ------------------- j = 3 (k = i+3) -------------------\n        B.s0123 = read_imageh(src1, gy * 2 + (i + 3) * n_4);\n        B.s4567 = read_imageh(src1, gy * 2 + (i + 3) * n_4 + 1);\n\n        half4 wj3 = convert_half4((char4)(p0.s3, p1.s3, p2.s3, p3.s3)) * scale;\n\n        c0 += B * wj3.s0;\n        c1 += B * wj3.s1;\n        c2 += B * wj3.s2;\n        c3 += B * wj3.s3;\n    }\n\n    int idx = (gy << 3) * m + (gx << 2);\n\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);\n        idx += m;\n    }\n    if(idx+3 < m*n_no_padding){\n        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, 
dst + idx);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#define LOAD_VEC_A 4\n#define LOAD_VEC_B 4\n\n#define BM 64\n#define BN 64\n#define BK 32\n#define TM 4\n#define TN 8\n\nkernel void kernel_mul_mm_q8_0_f32_l4_lm(\n    global char4  * src0_q,\n    global half   * src0_d,\n    global float4 * src1,\n    ulong offset1,\n    global float  * dst,\n    ulong offsetd,\n\n    int ne00,\n    int ne01,\n    int ne02,\n    int ne11,\n    int ne12,\n\n    int stride_a,\n    int stride_b,\n    int stride_d,\n\n    int batch_stride_a,\n    int batch_stride_b,\n    int batch_stride_d,\n\n    int r2,\n    int r3\n) {\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst  = (global float *)((global char*)dst  + offsetd);\n\n    local float buf_a[BM * BK];\n    local float buf_b[BN * BK];\n\n    const int batch_idx = get_global_id(2);\n\n    const int i13 = batch_idx / ne12;\n    const int i12 = batch_idx % ne12;\n\n    const int i03 = i13 / r3;\n    const int i02 = i12 / r2;\n\n    const int batch_idx_a = i03 * ne02 + i02;\n\n    const int ir = get_group_id(0);\n    const int ic = get_group_id(1);\n\n    const int tid = get_local_id(0);\n    const int th_r  = tid % (BM / TM);\n    const int th_c  = tid / (BM / TM);\n\n    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);\n    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);\n    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);\n    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);\n\n    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;\n    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;\n\n    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;\n    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;\n\n    float sums[TM * TN];\n    float cache_a[TM];\n    float cache_b[TN];\n\n    for (int i = 0; i < TM * TN; i++) {\n        sums[i] = 0.0f;\n    }\n\n    for (int block = 0; block < ne00; block 
+= BK) {\n        for (int l = 0; l < BM; l += loadstride_a) {\n            if (ir*BM + loadc_a + l < ne01) {\n                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;\n                int ib  = idx / 8;\n                int iqs = idx % 8;\n\n                float d = (float)src0_d[ib];\n                global char4 * qs = src0_q + ib*8 + iqs;\n                char4 q = *qs;\n                float4 v = convert_float4(q)*d;\n\n                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;\n                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;\n                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;\n                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;\n            } else {\n                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;\n                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;\n            }\n        }\n\n        for (int l = 0; l < BN; l += loadstride_b) {\n            if (ic*BN + loadc_b + l < ne11) {\n                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;\n            } else {\n                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;\n                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;\n       
     }\n        }\n\n        barrier(CLK_LOCAL_MEM_FENCE);\n\n        pos_a += BK / LOAD_VEC_A;\n        pos_b += BK / LOAD_VEC_B;\n\n        for (int i = 0; i < BK; i++) {\n            for (int j = 0; j < TM; j++) {\n                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];\n            }\n\n            for (int j = 0; j < TN; j++) {\n                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];\n            }\n\n            for (int cc = 0; cc < TN; cc++) {\n                for (int cr = 0; cr < TM; cr++) {\n                    const int sums_idx = cc*TM + cr;\n                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);\n                }\n            }\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n\n    const int dr = ir * BM + th_r * TM;\n    const int dc = ic * BN + th_c * TN;\n\n    const int offsets = batch_idx * batch_stride_d;\n\n    for (int cc = 0; cc < TN; cc++) {\n        for (int cr = 0; cr < TM; cr++) {\n            if (dr + cr < ne01 && dc + cc < ne11) {\n                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_f16_f16.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define N_F16_F16 4\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_f16_f16(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3)\n{\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int r0 = get_group_id(0);\n    int rb = get_group_id(1)*N_F16_F16;\n    int im = get_group_id(2);\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n\n    global half * x = (global half *) (src0 + offset_src0);\n\n    if (ne00 < 128) {\n        for (int row = 0; row < N_F16_F16; ++row) {\n            int r1 = rb + row;\n 
           if (r1 >= ne11) {\n                break;\n            }\n\n            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n            global half * y = (global half *) (src1 + offset_src1);\n\n            float sumf = 0;\n            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {\n                sumf += (half) x[i] * (half) y[i];\n            }\n\n            float all_sum = sub_group_reduce_add(sumf);\n            if (get_sub_group_local_id() == 0) {\n                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n            }\n        }\n    } else {\n        global half4 * x4 = (global half4 *)x;\n        for (int row = 0; row < N_F16_F16; ++row) {\n            int r1 = rb + row;\n            if (r1 >= ne11) {\n                break;\n            }\n\n            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n            global half  * y  = (global half  *) (src1 + offset_src1);\n            global half4 * y4 = (global half4 *) y;\n\n            float sumf = 0;\n            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {\n                sumf += (half) x4[i].s0 * y4[i].s0;\n                sumf += (half) x4[i].s1 * y4[i].s1;\n                sumf += (half) x4[i].s2 * y4[i].s2;\n                sumf += (half) x4[i].s3 * y4[i].s3;\n            }\n\n            float all_sum = sub_group_reduce_add(sumf);\n            if (get_sub_group_local_id() == 0) {\n                for (int i = 4*(ne00/4); i < ne00; ++i) {\n                    all_sum += (half) x[i] * y[i];\n                }\n                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_f16_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define N_F16_F32 4\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_f16_f32(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int r0 = get_group_id(0);\n    int rb = get_group_id(1)*N_F16_F32;\n    int im = get_group_id(2);\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n\n    global half * x = (global half *) (src0 + offset_src0);\n\n    if (ne00 < 128) {\n        for (int row = 0; row < N_F16_F32; ++row) {\n            int r1 = rb + 
row;\n            if (r1 >= ne11) {\n                break;\n            }\n\n            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n            global float * y = (global float *) (src1 + offset_src1);\n\n            float sumf = 0;\n            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {\n                sumf += convert_float(x[i]) * y[i];\n            }\n\n            float all_sum = sub_group_reduce_add(sumf);\n            if (get_sub_group_local_id() == 0) {\n                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n            }\n        }\n    } else {\n        global half4 * x4 = (global half4 *)x;\n        for (int row = 0; row < N_F16_F32; ++row) {\n            int r1 = rb + row;\n            if (r1 >= ne11) {\n                break;\n            }\n\n            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n            global float  * y  = (global float  *) (src1 + offset_src1);\n            global float4 * y4 = (global float4 *) y;\n\n            float sumf = 0;\n            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {\n                sumf += convert_float(x4[i].s0) * y4[i].s0;\n                sumf += convert_float(x4[i].s1) * y4[i].s1;\n                sumf += convert_float(x4[i].s2) * y4[i].s2;\n                sumf += convert_float(x4[i].s3) * y4[i].s3;\n            }\n\n            float all_sum = sub_group_reduce_add(sumf);\n            if (get_sub_group_local_id() == 0) {\n                for (int i = 4*(ne00/4); i < ne00; ++i) {\n                    all_sum += (float) x[i] * y[i];\n                }\n                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_f16_f32_1row(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n    ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n    global half  * x = (global half  *) (src0 + offset_src0);\n    global float * y = (global float *) (src1 + offset_src1);\n\n    
float sumf = 0;\n    if (ne00 < 128) {\n        for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {\n            sumf += (float) x[i] * (float) y[i];\n        }\n        float all_sum = sub_group_reduce_add(sumf);\n        if (get_sub_group_local_id() == 0) {\n            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n        }\n    } else {\n        global half4  * x4 = (global half4  *) x;\n        global float4 * y4 = (global float4 *) y;\n        for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {\n            sumf += (float) x4[i].s0 * y4[i].s0;\n            sumf += (float) x4[i].s1 * y4[i].s1;\n            sumf += (float) x4[i].s2 * y4[i].s2;\n            sumf += (float) x4[i].s3 * y4[i].s3;\n        }\n        float all_sum = sub_group_reduce_add(sumf);\n        if (get_sub_group_local_id() == 0) {\n            for (int i = 4*(ne00/4); i < ne00; ++i) {\n                all_sum += (float) x[i] * y[i];\n            }\n            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n        }\n    }\n\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n// Assumes row size (ne00) is a multiple of 4\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_f16_f32_l4(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int nrows = ne11;\n    int r0 = get_group_id(0);\n    int im = get_group_id(2);\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n\n    global half4 * x4 = (global half4 *) (src0 + offset_src0);\n\n    for (int r1 = 0; r1 < nrows; ++r1) {\n        ulong offset_src1 = r1*nb11 + (i12   )*nb12 + 
(i13   )*nb13;\n\n        global float4 * y4 = (global float4 *) (src1 + offset_src1);\n\n        float sumf = 0;\n        for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {\n            sumf += convert_float(x4[i].s0) * y4[i].s0;\n            sumf += convert_float(x4[i].s1) * y4[i].s1;\n            sumf += convert_float(x4[i].s2) * y4[i].s2;\n            sumf += convert_float(x4[i].s3) * y4[i].s3;\n        }\n\n        float all_sum = sub_group_reduce_add(sumf);\n        if (get_sub_group_local_id() == 0) {\n            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_f32_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define N_F32_F32 4\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_f32_f32(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int r0 = get_group_id(0);\n    int rb = get_group_id(1)*N_F32_F32;\n    int im = get_group_id(2);\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n\n    global float * x = (global float *) (src0 + offset_src0);\n\n    if (ne00 < 128) {\n        for (int row = 0; row < N_F32_F32; ++row) {\n            int r1 = rb + 
row;\n            if (r1 >= ne11) {\n                break;\n            }\n\n            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n            global float * y = (global float *) (src1 + offset_src1);\n\n            float sumf = 0;\n            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {\n                sumf += (float) x[i] * (float) y[i];\n            }\n\n            float all_sum = sub_group_reduce_add(sumf);\n            if (get_sub_group_local_id() == 0) {\n                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n            }\n        }\n    } else {\n        global float4 * x4 = (global float4 *)x;\n        for (int row = 0; row < N_F32_F32; ++row) {\n            int r1 = rb + row;\n            if (r1 >= ne11) {\n                break;\n            }\n\n            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n            global float  * y  = (global float  *) (src1 + offset_src1);\n            global float4 * y4 = (global float4 *) y;\n\n            float sumf = 0;\n            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {\n                sumf += (float) x4[i].s0 * y4[i].s0;\n                sumf += (float) x4[i].s1 * y4[i].s1;\n                sumf += (float) x4[i].s2 * y4[i].s2;\n                sumf += (float) x4[i].s3 * y4[i].s3;\n            }\n\n            float all_sum = sub_group_reduce_add(sumf);\n            if (get_sub_group_local_id() == 0) {\n                for (int i = 4*(ne00/4); i < ne00; ++i) {\n                    all_sum += (float) x[i] * y[i];\n                }\n                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK_MXFP4 32\ntypedef struct {\n    uchar e; // E8M0\n    uchar qs[QK_MXFP4/2];\n} block_mxfp4;\n\nconstant static float kvalues_mxfp4_f[16] = {\n    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f\n};\n\nstatic inline float e8m0_to_fp32(uchar x) {\n    int bits;\n\n    if (x == 0) {\n        bits = 0x00400000;\n    } else {\n        bits = (uint) x << 23;\n    }\n\n    return as_float(bits);\n}\n\n#ifdef INTEL_GPU\n#define N_R0_MXFP4 2 // number of rows each subgroup works on\n#define N_SG_MXFP4 2 // number of subgroups in a work group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_R0_MXFP4 2\n#define N_SG_MXFP4 2\n#define N_SIMDWIDTH 64\n#endif\n\ninline void mul_mv_mxfp4_f32(\n    global char * src0,\n    global char * src1,\n    global char * dst,\n    int ne00,\n    ulong nb01,\n    ulong nb02,\n    ulong nb03,\n    int ne12,\n    ulong nb11,\n    ulong nb12,\n    ulong nb13,\n    int ne0,\n    int ne1,\n    int r2,\n    int r3,\n    local  char * shmem\n) {\n    local float * shmem_f32 = (local float *) shmem;\n    int nb = ne00/QK_MXFP4;\n\n    int r0 = get_group_id(0);\n    int r1 = 
get_group_id(1);\n    int im = 0;\n\n    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;\n\n    uint i12 = im%ne12;\n    uint i13 = im/ne12;\n\n    ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n    ulong offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n    global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);\n    global float       * y = (global float       *) (src1 + offset_src1);\n\n    const short ix = get_sub_group_local_id()/2;  // 0...15\n    const short it = get_sub_group_local_id()%2;  // 0 or 1\n\n    shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    float4 yl[4];\n    float sumf[N_R0_MXFP4] = {0.f};\n\n    global float * yb = y + ix * QK_MXFP4 + it * 8;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        global float4 * y4 = (global float4 *)yb;\n        yl[0] = y4[0];\n        yl[1] = y4[4];\n        yl[2] = y4[1];\n        yl[3] = y4[5];\n\n        for (short row = 0; row < N_R0_MXFP4; row++) {\n            global block_mxfp4 * xb = x + row*nb + ib;\n            global uchar       * q2 = (global uchar *)(xb->qs + 8*it);\n\n            float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);\n            float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);\n            float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  0x0F], shmem_f32[q2[7] &  0x0F]);\n            float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);\n\n            acc1 = (acc1 + acc3) + (acc2 + acc4);\n\n            sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));\n        }\n\n        yb += (N_SIMDWIDTH/2) * 
QK_MXFP4;\n    }\n\n    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;\n\n    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {\n        float sum_all = sub_group_reduce_add(sumf[row]);\n        if (get_sub_group_local_id() == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_id_mxfp4_f32(\n    global char * src0,\n    ulong         offset0,\n    global char * src1,\n    ulong         offset1,\n    global char * src2,\n    ulong         offset2,\n    global char * dst,\n    ulong         offsetd,\n    int           ne00,\n    ulong         nb01,\n    ulong         nb02,\n    ulong         nb03,\n    int           ne11,\n    int           ne12,\n    ulong         nb11,\n    ulong         nb12,\n    ulong         nb13,\n    int           ne20,\n    int           ne21,\n    ulong         nb21,\n    int           ne0,\n    int           ne1,\n    int           r2,\n    int           r3,\n    local  char * shmem\n) {\n    src0 = (global char *)((global char *)src0 + offset0);\n    src1 = (global char *)((global char *)src1 + offset1);\n    src2 = (global char *)((global char *)src2 + offset2);\n    dst  = (global char *)((global char *)dst  + offsetd);\n\n    const int iid1 = get_group_id(2)/ne20;\n    const int idx  = get_group_id(2)%ne20;\n\n    int i02 = ((global int *) (src2 + iid1*nb21))[idx];\n\n    int i11 = idx % ne11;\n    int i12 = iid1;\n\n    int i1 = idx;\n    int i2 = i12;\n\n    global char * src0_cur = src0 + i02*nb02;\n    global char * src1_cur = src1 + i11*nb11 + i12*nb12;\n\n    global char * dst_cur = dst + (i1*ne0 + i2*ne1*ne0)*sizeof(float);\n\n    mul_mv_mxfp4_f32(src0_cur, src1_cur, dst_cur,\n        ne00, nb01, nb02, nb03, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shmem);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK_MXFP4 32\n\nstatic inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {\n    ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;\n    fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;\n    fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;\n    fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;\n    fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;\n\n    bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;\n    bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;\n    bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;\n    bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;\n\n    fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;\n    fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;\n    fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;\n    fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 
0x0 : fp16_packed_b.hi;\n\n    sign_a.lo = (fp4x4 << 12) & 0x8000;\n    sign_a.hi = (fp4x4 << 8) & 0x8000;\n    sign_b.lo = (fp4x4 << 4) & 0x8000;\n    sign_b.hi = fp4x4 & 0x8000;\n\n    fp16_packed_a = sign_a + bias_a + fp16_packed_a;\n    fp16_packed_b = sign_b + bias_b + fp16_packed_b;\n\n    return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));\n}\n\nstatic inline float e8m0_to_fp32(uchar x) {\n    int bits;\n    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);\n    return as_float(bits);\n}\n\n#ifdef INTEL_GPU\n#define N_R0_MXFP4 2 // number of rows each subgroup works on\n#define N_SG_MXFP4 2 // number of subgroups in a work group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_R0_MXFP4 4\n#define N_SG_MXFP4 1\n#define N_SIMDWIDTH 64\n#define SRC0Q_IMG\n#endif\n\nkernel void kernel_mul_mv_id_mxfp4_f32_flat(\n#ifdef SRC0Q_IMG\n    __read_only image1d_buffer_t src0_q,\n#else\n    global uchar * src0_q,\n#endif\n    global uchar * src0_e,\n    global uchar * src1,\n    ulong         offset1,\n    global uchar * src2,\n    ulong         offset2,\n    global uchar * dst,\n    ulong         offsetd,\n    int           ne00,\n    ulong         nb01,\n    ulong         nb02,\n    ulong         nb03,\n    int           ne11,\n    int           ne12,\n    ulong         nb11,\n    ulong         nb12,\n    ulong         nb13,\n    int           ne20,\n    int           ne21,\n    ulong         nb21,\n    int           ne0,\n    int           ne1,\n    int           r2,\n    int           r3\n) {\n    dst  = dst  + offsetd;\n\n    const int iid1 = get_group_id(2) / ne20;\n    const int idx  = get_group_id(2) % ne20;\n\n    uint i02 = ((global uint *) (src2 + offset2 + iid1 * nb21))[idx];\n\n    int i11 = idx % ne11;\n\n    int nb = ne00 / QK_MXFP4;\n\n    uint src0_off = i02*nb02;\n    src0_off /= 17; // 17 = sizeof(block_mxfp4)\n\n    src0_e = src0_e + src0_off;\n\n    dst = dst + (idx * ne0 + iid1 * ne1 * ne0) * sizeof(float);\n\n    
int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n\n    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;\n\n    uint offset_src0 = first_row*nb01;\n    offset_src0 /= 17; // 17 = sizeof(block_mxfp4)\n#ifdef SRC0Q_IMG\n    ulong offset_q = src0_off + offset_src0;\n#else\n    src0_q = src0_q + src0_off*16;\n    global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;\n#endif\n    global uchar * x_e = src0_e + offset_src0;\n\n    const short ix = get_sub_group_local_id() >> 1;\n    const short it = get_sub_group_local_id() & 1;\n\n    float sumf[N_R0_MXFP4] = {0.f};\n\n    src1 = src1 + offset1 + i11 * nb11 + iid1 * nb12;\n    global float * y   = (global float *) (src1 + r1 * nb11);\n    global float * yb = y + ix * QK_MXFP4 + it * 8;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / 2) {\n        global float4 * y4 = (global float4 *)yb;\n\n        #pragma unroll\n        for (short row = 0; row < N_R0_MXFP4; row++) {\n            uchar xb_e = x_e[row * nb + ib];\n#ifdef SRC0Q_IMG\n            ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);\n#else\n            ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));\n#endif\n\n            half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);\n            half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);\n            float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);\n            acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);\n\n            fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);\n            fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);\n            acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);\n            acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);\n\n            sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));\n        }\n\n        yb += (N_SIMDWIDTH 
/ 2) * QK_MXFP4;\n    }\n\n    global float * dst_f32 = (global float *)dst + (ulong)r1 * ne0;\n\n    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {\n        float sum_all = sub_group_reduce_add(sumf[row]);\n        if (get_sub_group_local_id() == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_0                   32\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n//------------------------------------------------------------------------------\n// block_q4_0\n//------------------------------------------------------------------------------\nstruct block_q4_0\n{\n    half d;\n    uint8_t qs[QK4_0 / 2];\n};\n\n// This function requires the original shuffled weights.\n// As a reminder, the original weights are shuffled so that (q[0], q[16]) are\n// packed together in a byte, so are (q[1], q[17]) and so on.\ninline float block_q_4_0_dot_y_flat(\n        global uchar * x,\n        global half  * dh,\n        float sumy,\n        float16 yl,\n        int il\n) {\n    float           d   = *dh;\n    global ushort * qs  = ((global ushort *)x + il/2);\n    float           acc = 0.f;\n\n    acc += yl.s0 * (qs[0] & 0x000F);\n    acc += yl.s1 * (qs[0] & 0x0F00);\n    acc += yl.s8 * (qs[0] & 0x00F0);\n    acc += yl.s9 * (qs[0] & 0xF000);\n\n    acc += yl.s2 * (qs[1] & 0x000F);\n    acc += yl.s3 * (qs[1] & 0x0F00);\n    acc += yl.sa * (qs[1] & 0x00F0);\n    acc += yl.sb 
* (qs[1] & 0xF000);\n\n    acc += yl.s4 * (qs[2] & 0x000F);\n    acc += yl.s5 * (qs[2] & 0x0F00);\n    acc += yl.sc * (qs[2] & 0x00F0);\n    acc += yl.sd * (qs[2] & 0xF000);\n\n    acc += yl.s6 * (qs[3] & 0x000F);\n    acc += yl.s7 * (qs[3] & 0x0F00);\n    acc += yl.se * (qs[3] & 0x00F0);\n    acc += yl.sf * (qs[3] & 0xF000);\n\n    return d * (sumy * -8.f + acc);\n}\n\n//\n// This variant outputs 8 values.\n//\n#undef N_DST\n#undef N_SIMDGROUP\n#undef N_SIMDWIDTH\n\n#ifdef INTEL_GPU\n#define N_DST 8 // each SIMD group works on 8 rows\n#define N_SIMDGROUP 1 // number of SIMD groups in a thread group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_DST 8\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n\ninline void mul_vec_q_n_f32_8x_flat(\n        global char  * src0_q,\n        global half  * src0_d,\n        global float * src1,\n        global float * dst,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    const ulong nb = ne00/QK4_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = 0;\n\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    // The number of scales is the same as the number of blocks.\n    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);\n    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.\n    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;\n\n    global uchar * x = (global uchar *) src0_q + offset0_q;\n    global half  * d = (global half  *) src0_d + offset0_d;\n    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;\n\n    float16 yl;\n    float8 sumf = 0.f;\n\n    int ix = get_sub_group_local_id()/2;\n    int il = 8*(get_sub_group_local_id()%2);\n\n    global float * 
yb = y + ix*QK4_0 + il;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        float sumy = 0.f;\n\n        sumy += yb[0];\n        sumy += yb[1];\n        sumy += yb[2];\n        sumy += yb[3];\n        sumy += yb[4];\n        sumy += yb[5];\n        sumy += yb[6];\n        sumy += yb[7];\n\n        sumy += yb[16];\n        sumy += yb[17];\n        sumy += yb[18];\n        sumy += yb[19];\n        sumy += yb[20];\n        sumy += yb[21];\n        sumy += yb[22];\n        sumy += yb[23];\n\n        yl.s0 = yb[0];\n        yl.s1 = yb[1]/256.f;\n\n        yl.s2 = yb[2];\n        yl.s3 = yb[3]/256.f;\n\n        yl.s4 = yb[4];\n        yl.s5 = yb[5]/256.f;\n\n        yl.s6 = yb[6];\n        yl.s7 = yb[7]/256.f;\n\n        yl.s8 = yb[16]/16.f;\n        yl.s9 = yb[17]/4096.f;\n\n        yl.sa = yb[18]/16.f;\n        yl.sb = yb[19]/4096.f;\n\n        yl.sc = yb[20]/16.f;\n        yl.sd = yb[21]/4096.f;\n\n        yl.se = yb[22]/16.f;\n        yl.sf = yb[23]/4096.f;\n\n        sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);\n        sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);\n        sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);\n        sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);\n\n        sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);\n        sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);\n        sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);\n        sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);\n\n        yb += QK4_0 * (N_SIMDWIDTH/2);\n    }\n\n    float8 tot = (float8)(\n        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),\n        
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),\n        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),\n        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;\n        }\n\n        if (first_row + 4 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;\n        }\n        if (first_row + 5 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;\n        }\n        if (first_row + 6 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;\n        }\n        if (first_row + 7 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_id_q4_0_f32_8x_flat(\n        global char  *  src0_q,\n        global half  *  src0_d,\n        global float *  src1,\n        ulong           offset1,\n        global char  *  src2,\n        ulong           offset2,\n        global float *  dst,\n        ulong           offsetd,\n        int             ne00,\n        int             ne01,\n        int             ne02,\n        ulong           nb00,\n        ulong           nb02,\n        int             ne10,\n        int             ne11,\n        int             ne12,\n        ulong           nb11,\n        ulong           nb12,\n        int             ne20,\n        int             ne21,\n        ulong           nb21,\n        int            
 ne0,\n        int             ne1,\n        int             r2,\n        int             r3\n) {\n    src1 = (global float *)((global char *)src1 + offset1);\n    src2 = (global char  *)((global char *)src2 + offset2);\n    dst  = (global float *)((global char *)dst  + offsetd);\n\n    const int iid1 = get_group_id(2)/ne20;\n    const int idx  = get_group_id(2)%ne20;\n\n    const int i02 = ((global int *)(src2 + iid1*nb21))[idx];\n\n    const int i11 = idx%ne11;\n    const int i12 = iid1;\n\n    const int i1 = idx;\n    const int i2 = i12;\n\n    global char  * src0_q_cur = src0_q + (i02*nb02/nb00)*(QK4_0/2);\n    global half  * src0_d_cur = src0_d + (i02*nb02/nb00);\n    global float * src1_cur   = (global float *)((global char *) src1  + i11*nb11 + i12*nb12);\n    global float * dst_cur    = dst + i1*ne0 + i2*ne1*ne0;\n\n    mul_vec_q_n_f32_8x_flat(src0_q_cur, src0_d_cur, src1_cur, dst_cur, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK8_0 32\ntypedef struct {\n    half d;       // delta\n    char qs[QK8_0]; // quants\n} block_q8_0;\n\n#define NB_Q8_0 8\n\n#ifdef INTEL_GPU\n#define N_R0_Q8_0 4 // number of rows each subgroup works on\n#define N_SG_Q8_0 2 // number of subgroups in a work group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_R0_Q8_0 4\n#define N_SG_Q8_0 2\n#define N_SIMDWIDTH 64\n#endif\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_id_q8_0_f32(\n    global char * src0,\n    ulong         offset0,\n    global char * src1,\n    ulong         offset1,\n    global char * src2,\n    ulong         offset2,\n    global char * dst,\n    ulong         offsetd,\n    int           ne00,\n    int           ne01,\n    ulong         nb01,\n    ulong         nb02,\n    int           ne11,\n    int           ne12,\n    ulong         nb11,\n    ulong         nb12,\n    int           ne20,\n    int           ne21,\n    ulong         nb21,\n    int           ne0,\n    int           ne1\n) {\n    src0 = (global char *)((global char *)src0 + offset0);\n    src1 = (global char *)((global 
char *)src1 + offset1);\n    src2 = (global char *)((global char *)src2 + offset2);\n    dst  = (global char *)((global char *)dst  + offsetd);\n\n    int iid1 = get_group_id(2)/ne20;\n    int idx  = get_group_id(2)%ne20;\n\n    int i02 = ((global int *) (src2 + iid1*nb21))[idx];\n\n    int i11_ = idx % ne11;\n    int i12_ = iid1;\n\n    int i1 = idx;\n    int i2 = i12_;\n\n    global char * src0_cur = src0 + i02*nb02;\n    global char * src1_cur = src1 + i11_*nb11 + i12_*nb12;\n\n    global char * dst_cur = dst + (i1*ne0 + i2*ne1*ne0)*sizeof(float);\n\n    int nb = ne00/QK8_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n\n    int first_row = (r0*N_SG_Q8_0 + get_sub_group_id()) * N_R0_Q8_0;\n\n    ulong offset_src1 = r1*nb11;\n    global float * y  = (global float *) (src1_cur + offset_src1);\n\n    // pointers to src0 rows\n    global block_q8_0 * ax[N_R0_Q8_0];\n    for (int row = 0; row < N_R0_Q8_0; ++row) {\n        ulong offset_src0 = (first_row + row)*nb01;\n        ax[row] = (global block_q8_0 *) ((global char *) src0_cur + offset_src0);\n    }\n\n    float yl[NB_Q8_0];\n    float sumf[N_R0_Q8_0] = { 0.f };\n\n    const short ix = get_sub_group_local_id()/4;\n    const short il = get_sub_group_local_id()%4;\n\n    global float * yb = y + ix*QK8_0 + il*NB_Q8_0;\n\n    // each thread handles NB_Q8_0 quants at a time\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {\n        for (short i = 0; i < NB_Q8_0; ++i) {\n            yl[i] = yb[i];\n        }\n\n        for (short row = 0; row < N_R0_Q8_0; row++) {\n            global char * qs = ax[row][ib].qs + il*NB_Q8_0;\n            float sumq = 0.f;\n            for (short iq = 0; iq < NB_Q8_0; ++iq) {\n                sumq += qs[iq] * yl[iq];\n            }\n            sumf[row] += sumq*ax[row][ib].d;\n        }\n\n        yb += N_SIMDWIDTH*NB_Q8_0;\n    }\n\n    global float * dst_f32 = (global float *) dst_cur + (ulong)r1*ne0;\n\n    for (int row = 0; row < N_R0_Q8_0; ++row) {\n       
 float tot = sub_group_reduce_add(sumf[row]);\n\n        if (get_sub_group_local_id() == 0 && first_row + row < ne01) {\n            dst_f32[first_row + row] = tot;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK8_0 32\ntypedef struct {\n    half d;       // delta\n    char qs[QK8_0]; // quants\n} block_q8_0;\n\n#define NB_Q8_0 8\n\n#ifdef INTEL_GPU\n#define N_R0_Q8_0 4 // number of rows each subgroup works on\n#define N_SG_Q8_0 2 // number of subgroups in a work group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_R0_Q8_0 4\n#define N_SG_Q8_0 2\n#define N_SIMDWIDTH 64\n#endif\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_id_q8_0_f32_flat(\n    global char * src0_q,\n    global half * src0_d,\n    global char * src1,\n    ulong         offset1,\n    global char * src2,\n    ulong         offset2,\n    global char * dst,\n    ulong         offsetd,\n    int           ne00,\n    int           ne01,\n    ulong         nb01,\n    ulong         nb02,\n    int           ne11,\n    int           ne12,\n    ulong         nb11,\n    ulong         nb12,\n    int           ne20,\n    int           ne21,\n    ulong         nb21,\n    int           ne0,\n    int           ne1\n) {\n    src1 = (global char *)((global char *)src1 + offset1);\n    src2 = (global char 
*)((global char *)src2 + offset2);\n    dst  = (global char *)((global char *)dst  + offsetd);\n\n    int iid1 = (int)get_group_id(2)/ne20;\n    int idx  = (int)get_group_id(2)%ne20;\n\n    int i02 = ((global int *) (src2 + iid1*nb21))[idx];\n\n    int i11_ = idx % ne11;\n    int i12_ = iid1;\n\n    int i1 = idx;\n    int i2 = i12_;\n\n    // 34 == sizeof(block_q8_0)\n    uint src0_off = i02*nb02;\n    src0_off /= 34;\n\n    global char * src0_q_cur = src0_q + src0_off*sizeof(char)*QK8_0;\n    global half * src0_d_cur = src0_d + src0_off;\n    global char * src1_cur   = src1 + i11_*nb11 + i12_*nb12;\n\n    global char * dst_cur = dst + (i1*ne0 + i2*ne1*ne0)*sizeof(float);\n\n    int nb = ne00/QK8_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n\n    int first_row = (r0*N_SG_Q8_0 + get_sub_group_id()) * N_R0_Q8_0;\n\n    ulong offset_src1 = r1*nb11;\n    global float * y  = (global float *) (src1_cur + offset_src1);\n\n    // pointers to src0 rows\n    uint offset_src0_base = first_row*nb01;\n\n    global char * ax0, * ax1, * ax2, * ax3;\n    global half * ad0, * ad1, * ad2, * ad3;\n    uint offset_src0;\n\n    offset_src0 = offset_src0_base + 0*nb01;\n    offset_src0 = offset_src0/34;\n    ax0 = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);\n    ad0 = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));\n\n    offset_src0 = offset_src0_base + 1*nb01;\n    offset_src0 = offset_src0/34;\n    ax1 = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);\n    ad1 = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));\n\n    offset_src0 = offset_src0_base + 2*nb01;\n    offset_src0 = offset_src0/34;\n    ax2 = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);\n    ad2 = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));\n\n    offset_src0 = offset_src0_base + 3*nb01;\n    offset_src0 = offset_src0/34;\n    
ax3 = (global char *) ((global char *) src0_q_cur + offset_src0*sizeof(char)*QK8_0);\n    ad3 = (global half *) ((global char *) src0_d_cur + offset_src0*sizeof(half));\n\n    const short ix = get_sub_group_local_id()/4;\n    const short il = get_sub_group_local_id()%4;\n\n    global float * yb = y + ix*QK8_0 + il*NB_Q8_0;\n\n    float8 yl;\n    float8 qv;\n    float4 sumf = 0.f;\n    float  sumq = 0.f;\n    global char * qs;\n\n    // each thread handles NB_Q8_0 quants at a time\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {\n        yl = vload8(0, yb);\n\n        qs = ax0 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;\n        qv = convert_float8(vload8(0, qs));\n        sumq = 0;\n        sumq += qv.s0*yl.s0;\n        sumq += qv.s1*yl.s1;\n        sumq += qv.s2*yl.s2;\n        sumq += qv.s3*yl.s3;\n        sumq += qv.s4*yl.s4;\n        sumq += qv.s5*yl.s5;\n        sumq += qv.s6*yl.s6;\n        sumq += qv.s7*yl.s7;\n        sumf.s0 += sumq*ad0[ib];\n\n        qs = ax1 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;\n        qv = convert_float8(vload8(0, qs));\n        sumq = 0;\n        sumq += qv.s0*yl.s0;\n        sumq += qv.s1*yl.s1;\n        sumq += qv.s2*yl.s2;\n        sumq += qv.s3*yl.s3;\n        sumq += qv.s4*yl.s4;\n        sumq += qv.s5*yl.s5;\n        sumq += qv.s6*yl.s6;\n        sumq += qv.s7*yl.s7;\n        sumf.s1 += sumq*ad1[ib];\n\n        qs = ax2 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;\n        qv = convert_float8(vload8(0, qs));\n        sumq = 0;\n        sumq += qv.s0*yl.s0;\n        sumq += qv.s1*yl.s1;\n        sumq += qv.s2*yl.s2;\n        sumq += qv.s3*yl.s3;\n        sumq += qv.s4*yl.s4;\n        sumq += qv.s5*yl.s5;\n        sumq += qv.s6*yl.s6;\n        sumq += qv.s7*yl.s7;\n        sumf.s2 += sumq*ad2[ib];\n\n        qs = ax3 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;\n        qv = convert_float8(vload8(0, qs));\n        sumq = 0;\n        sumq += qv.s0*yl.s0;\n        sumq += qv.s1*yl.s1;\n        sumq += qv.s2*yl.s2;\n        sumq += 
qv.s3*yl.s3;\n        sumq += qv.s4*yl.s4;\n        sumq += qv.s5*yl.s5;\n        sumq += qv.s6*yl.s6;\n        sumq += qv.s7*yl.s7;\n        sumf.s3 += sumq*ad3[ib];\n\n        yb += N_SIMDWIDTH*NB_Q8_0;\n    }\n\n    global float * dst_f32 = (global float *) dst_cur + (ulong)r1*ne0;\n\n    float4 tot = (float4)(\n        sub_group_reduce_add(sumf.s0),\n        sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2),\n        sub_group_reduce_add(sumf.s3)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst_f32[first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst_f32[first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst_f32[first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst_f32[first_row + 3] = tot.s3;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK_MXFP4 32\ntypedef struct {\n    uchar e; // E8M0\n    uchar qs[QK_MXFP4/2];\n} block_mxfp4;\n\nconstant static float kvalues_mxfp4_f[16] = {\n    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f\n};\n\nstatic inline float e8m0_to_fp32(uchar x) {\n    int bits;\n\n    if (x == 0) {\n        bits = 0x00400000;\n    } else {\n        bits = (uint) x << 23;\n    }\n\n    return as_float(bits);\n}\n\n#ifdef INTEL_GPU\n#define N_R0_MXFP4 2 // number of rows each subgroup works on\n#define N_SG_MXFP4 2 // number of subgroups in a work group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_R0_MXFP4 2\n#define N_SG_MXFP4 2\n#define N_SIMDWIDTH 64\n#endif\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_mxfp4_f32(\n    global char * src0,\n    ulong         offset0,\n    global char * src1,\n    ulong         offset1,\n    global char * dst,\n    ulong         offsetd,\n    int ne00,\n    ulong nb01,\n    ulong nb02,\n    ulong nb03,\n    int ne12,\n    ulong nb11,\n    ulong nb12,\n    ulong nb13,\n    int ne0,\n    int ne1,\n   
 int r2,\n    int r3,\n    local  char * shmem\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    local float * shmem_f32 = (local float *) shmem;\n    int nb = ne00/QK_MXFP4;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;\n\n    uint i12 = im%ne12;\n    uint i13 = im/ne12;\n\n    ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n    ulong offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n    global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);\n    global float       * y = (global float       *) (src1 + offset_src1);\n\n    const short ix = get_sub_group_local_id()/2;  // 0...15\n    const short it = get_sub_group_local_id()%2;  // 0 or 1\n\n    shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    float4 yl[4];\n    float sumf[N_R0_MXFP4] = {0.f};\n\n    global float * yb = y + ix * QK_MXFP4 + it * 8;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        global float4 * y4 = (global float4 *)yb;\n        yl[0] = y4[0];\n        yl[1] = y4[4];\n        yl[2] = y4[1];\n        yl[3] = y4[5];\n\n        for (short row = 0; row < N_R0_MXFP4; row++) {\n            global block_mxfp4 * xb = x + row*nb + ib;\n            global uchar       * q2 = (global uchar *)(xb->qs + 8*it);\n\n            float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);\n            float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);\n            float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  
0x0F], shmem_f32[q2[7] &  0x0F]);\n            float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);\n\n            acc1 = (acc1 + acc3) + (acc2 + acc4);\n\n            sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));\n        }\n\n        yb += (N_SIMDWIDTH/2) * QK_MXFP4;\n    }\n\n    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;\n\n    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {\n        float sum_all = sub_group_reduce_add(sumf[row]);\n        if (get_sub_group_local_id() == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK_MXFP4 32\n\nstatic inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {\n    ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;\n    fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;\n    fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;\n    fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;\n    fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;\n\n    bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;\n    bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;\n    bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;\n    bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;\n\n    fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;\n    fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;\n    fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;\n    fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 
0x0 : fp16_packed_b.hi;\n\n    sign_a.lo = (fp4x4 << 12) & 0x8000;\n    sign_a.hi = (fp4x4 << 8) & 0x8000;\n    sign_b.lo = (fp4x4 << 4) & 0x8000;\n    sign_b.hi = fp4x4 & 0x8000;\n\n    fp16_packed_a = sign_a + bias_a + fp16_packed_a;\n    fp16_packed_b = sign_b + bias_b + fp16_packed_b;\n\n    return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));\n}\n\nstatic inline float e8m0_to_fp32(uchar x) {\n    int bits;\n    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);\n    return as_float(bits);\n}\n\n#ifdef INTEL_GPU\n#define N_R0_MXFP4 2 // number of rows each subgroup works on\n#define N_SG_MXFP4 2 // number of subgroups in a work group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_R0_MXFP4 2\n#define N_SG_MXFP4 2\n#define N_SIMDWIDTH 64\n#define SRC0Q_IMG\n#endif\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_mxfp4_f32_flat(\n#ifdef SRC0Q_IMG\n    __read_only image1d_buffer_t src0_q,\n#else\n    global uchar * src0_q,\n#endif\n    global uchar * src0_e,\n    global uchar * src1,\n    ulong          offset1,\n    global uchar * dst,\n    ulong          offsetd,\n    int ne00,\n    ulong nb01,\n    ulong nb02,\n    ulong nb03,\n    int ne12,\n    ulong nb11,\n    ulong nb12,\n    ulong nb13,\n    int ne0,\n    int ne1,\n    int r2,\n    int r3\n) {\n    src1 = src1 + offset1;\n    dst = dst + offsetd;\n\n    int nb = ne00 / QK_MXFP4;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;\n\n    uint i12 = im % ne12;\n    uint i13 = im / ne12;\n\n    uint offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n    // 17 = sizeof(block_mxfp4)\n    offset_src0 /= 17;\n#ifdef SRC0Q_IMG\n    ulong offset_q = offset_src0;\n#else\n    global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;\n#endif\n    global uchar * x_e = 
src0_e + offset_src0;\n\n    ulong offset_src1 = r1 * nb11 + i12 * nb12 + i13 * nb13;\n    global float * y = (global float *)(src1 + offset_src1);\n\n    const short ix = get_sub_group_local_id() >> 1;  // 0...15\n    const short it = get_sub_group_local_id() & 1;  // 0 or 1\n\n    float sumf[N_R0_MXFP4] = {0.f};\n\n    global float * yb = y + ix * QK_MXFP4 + it * 8;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        global float4 * y4 = (global float4 *)yb;\n\n        #pragma unroll\n        for (short row = 0; row < N_R0_MXFP4; row++) {\n            uchar xb_e = x_e[row * nb + ib];\n#ifdef SRC0Q_IMG\n            ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);\n#else\n            ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));\n#endif\n\n            half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);\n            half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);\n            float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);\n            acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);\n\n            fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);\n            fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);\n            acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);\n            acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);\n\n            sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));\n        }\n\n        yb += (N_SIMDWIDTH/2) * QK_MXFP4;\n    }\n\n    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;\n\n    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {\n        float sum_all = sub_group_reduce_add(sumf[row]);\n        if (get_sub_group_local_id() == 0) {\n            dst_f32[first_row + row] = sum_all;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_0                   32\n#define QR4_0                   2\n#define QK4_1                   32\n#define QR4_1                   2\n#define QK5_0                   32\n#define QR5_0                   2\n#define QK5_1                   32\n#define QR5_1                   2\n#define QK8_0                   32\n#define QR8_0                   1\n#define QK_K                    256\n#define K_QUANTS_PER_ITERATION  2\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n//------------------------------------------------------------------------------\n// block_q4_0\n//------------------------------------------------------------------------------\nstruct block_q4_0\n{\n    half d;\n    uint8_t qs[QK4_0 / 2];\n};\n\n//------------------------------------------------------------------------------\n// mul_vec_q_n_f32\n//------------------------------------------------------------------------------\n// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])\n// il indicates where the q4 quants begin (0 or QK4_0/4)\n// we assume that the 
yl's have been multiplied with the appropriate scale factor\n// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)\ninline float block_q_4_0_dot_y(\n        global struct block_q4_0 * qb_curr,\n        float sumy,\n        private float * yl,\n        int il\n) {\n    float d = qb_curr->d;\n    float2 acc = 0.f;\n    global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);\n    for (int i = 0; i < 8; i+=2) {\n        acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F)\n                + yl[i + 1] * (qs[i / 2] & 0x0F00);\n        acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0)\n                + yl[i + 9] * (qs[i / 2] & 0xF000);\n    }\n    return d * (sumy * -8.f + acc.s0 + acc.s1);\n}\n\n#ifdef INTEL_GPU\n#define N_DST 4 // each SIMD group works on 4 rows\n#define N_SIMDGROUP 1 // number of SIMD groups in a thread group\n#define N_SIMDWIDTH 16 // assuming SIMD group size is 16\n#elif defined (ADRENO_GPU)\n#define N_DST 4\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n\ninline void mul_vec_q_n_f32(\n        global void * src0,\n        global float * src1,\n        global float * dst,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n\n    const ulong nb = ne00/QK4_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global\n    // id of a SIMD group in the grid.\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);\n\n    global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;\n    global float             * y = (global float             *) src1 + r1*ne10 + im*ne00*ne1;\n\n    float yl[16];       // src1 vector cache\n    float 
sumf[N_DST]={0.f};\n\n    int ix = get_sub_group_local_id()/2;\n    int il = 8*(get_sub_group_local_id()%2);\n\n    global float * yb = y + ix * QK4_0 + il;\n\n    // each thread in a SIMD group deals with half a block.\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        float sumy = 0;\n        for (int i = 0; i < 8; i += 2) {\n            sumy += yb[i] + yb[i+1];\n            yl[i+0] = yb[i+ 0];\n            yl[i+1] = yb[i+ 1]/256.f;\n            sumy += yb[i+16] + yb[i+17];\n            yl[i+8] = yb[i+16]/16.f;\n            yl[i+9] = yb[i+17]/4096.f;\n        }\n\n        for (int row = 0; row < N_DST; row++) {\n            sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il);\n        }\n\n        // One thread in a SIMD group (i.e., subgroup) handles a half block,\n        // hence then entire SIMD group handles SIMDWIDTH/2 blocks.\n        // y points to the activation matrix (of type float). Therefore for\n        // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because\n        // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of\n        // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.\n        yb += QK4_0 * (N_SIMDWIDTH/2);\n    }\n\n    // The above does not work for Adreno - it produces incorrect results for\n    // row = 1, 2, 3 and only row = 0 gives the correct result.\n    // If N_DST is changed, the below array must be initialized accordingly.\n    // This also seems to perform better on Intel.\n    float tot[N_DST] = {\n        sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]),\n        sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])};\n    for (int row = 0; row < N_DST; ++row) {\n        if (get_sub_group_local_id() == 0 && first_row + row < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row];\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void 
kernel_mul_mat_q4_0_f32(\n        global void * src0,\n        ulong offset0,\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_0                   32\n#define QR4_0                   2\n#define QK4_1                   32\n#define QR4_1                   2\n#define QK5_0                   32\n#define QR5_0                   2\n#define QK5_1                   32\n#define QR5_1                   2\n#define QK8_0                   32\n#define QR8_0                   1\n#define QK_K                    256\n#define K_QUANTS_PER_ITERATION  2\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n//------------------------------------------------------------------------------\n// block_q4_0\n//------------------------------------------------------------------------------\nstruct block_q4_0\n{\n    half d;\n    uint8_t qs[QK4_0 / 2];\n};\n\ninline float mm_block_q_4_0_dot_y_flat(\n        global uchar * x,\n        global half  * dh,\n        float sumy,\n        float16 yl,\n        int il\n) {\n    float           d   = *dh;\n    global ushort * qs  = ((global ushort *)x + il/2);\n    float           acc = 0.f;\n\n    acc += yl.s0 * (qs[0] & 0x000F);\n    acc += yl.s1 * (qs[0] & 0x0F00);\n    acc += 
yl.s8 * (qs[0] & 0x00F0);\n    acc += yl.s9 * (qs[0] & 0xF000);\n\n    acc += yl.s2 * (qs[1] & 0x000F);\n    acc += yl.s3 * (qs[1] & 0x0F00);\n    acc += yl.sa * (qs[1] & 0x00F0);\n    acc += yl.sb * (qs[1] & 0xF000);\n\n    acc += yl.s4 * (qs[2] & 0x000F);\n    acc += yl.s5 * (qs[2] & 0x0F00);\n    acc += yl.sc * (qs[2] & 0x00F0);\n    acc += yl.sd * (qs[2] & 0xF000);\n\n    acc += yl.s6 * (qs[3] & 0x000F);\n    acc += yl.s7 * (qs[3] & 0x0F00);\n    acc += yl.se * (qs[3] & 0x00F0);\n    acc += yl.sf * (qs[3] & 0xF000);\n\n    return d * (sumy * -8.f + acc);\n}\n\n#ifdef INTEL_GPU\n#define N_DST 16 // each SIMD group works on 16 rows (in weights matrix)\n#define N_SIMDGROUP 1 // number of SIMD groups in a thread group\n#define N_SIMDWIDTH 16 // assuming SIMD group size is 16\n#elif defined (ADRENO_GPU)\n#define N_DST 16\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n//\n// This variant performs 1d blocking with 16x output.\n// Each simdgroup outputs 16 values on `n0` dim (row in the output matrix).\n//\ninline void mul_mat_q_n_f32_1d_16x_flat(\n        global uchar * src0_q,\n        global half  * src0_d,\n        global float * src1,\n        global float * dst,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    const int nb = ne00/QK4_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of\n    // a SIMD group in the grid. 
Each SIMD group produces N_DST values in the\n    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.\n    // Currently with llama2 7B, im is always 0.\n    // TODO: how to handle im/gqa*(nb*ne0)?\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    // The number of scales is the same as the number of blocks.\n    // Compute in ulong: nb, first_row, ne01 and ne02 are int, so the products\n    // would otherwise be evaluated in 32-bit and can overflow for large weights.\n    ulong offset0_d = (ulong)first_row * nb + (ulong)(i12/r2)*nb*ne01 + (ulong)(i13/r3)*nb*ne01*ne02;\n    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.\n    ulong offset0_q = offset0_d * (QK4_0/2);\n\n    global uchar * x = (global uchar *) src0_q + offset0_q;\n    global half  * d = (global half  *) src0_d + offset0_d;\n    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;\n\n    float16 yl;\n    float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,\n                             0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);\n\n    int ix = get_sub_group_local_id()/2;\n    int il = 8*(get_sub_group_local_id()%2);\n\n    global float * yb = y + ix*QK4_0 + il;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        float sumy = 0.f;\n\n        sumy += yb[0];\n        sumy += yb[1];\n        sumy += yb[2];\n        sumy += yb[3];\n        sumy += yb[4];\n        sumy += yb[5];\n        sumy += yb[6];\n        sumy += yb[7];\n\n        sumy += yb[16];\n        sumy += yb[17];\n        sumy += yb[18];\n        sumy += yb[19];\n        sumy += yb[20];\n        sumy += yb[21];\n        sumy += yb[22];\n        sumy += yb[23];\n\n        yl.s0 = yb[0];\n        yl.s1 = yb[1]/256.f;\n\n        yl.s2 = yb[2];\n        yl.s3 = yb[3]/256.f;\n\n        yl.s4 = yb[4];\n        yl.s5 = yb[5]/256.f;\n\n        yl.s6 = yb[6];\n        yl.s7 = yb[7]/256.f;\n\n        yl.s8 = yb[16]/16.f;\n        yl.s9 = 
yb[17]/4096.f;\n\n        yl.sa = yb[18]/16.f;\n        yl.sb = yb[19]/4096.f;\n\n    
    yl.sc = yb[20]/16.f;\n        yl.sd = yb[21]/4096.f;\n\n        yl.se = yb[22]/16.f;\n        yl.sf = yb[23]/4096.f;\n\n        sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  0*nb*QK4_0/2, d + ib +  0*nb, sumy, yl, il);\n        sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  1*nb*QK4_0/2, d + ib +  1*nb, sumy, yl, il);\n        sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  2*nb*QK4_0/2, d + ib +  2*nb, sumy, yl, il);\n        sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  3*nb*QK4_0/2, d + ib +  3*nb, sumy, yl, il);\n\n        sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  4*nb*QK4_0/2, d + ib +  4*nb, sumy, yl, il);\n        sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  5*nb*QK4_0/2, d + ib +  5*nb, sumy, yl, il);\n        sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  6*nb*QK4_0/2, d + ib +  6*nb, sumy, yl, il);\n        sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  7*nb*QK4_0/2, d + ib +  7*nb, sumy, yl, il);\n\n        sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  8*nb*QK4_0/2, d + ib +  8*nb, sumy, yl, il);\n        sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  9*nb*QK4_0/2, d + ib +  9*nb, sumy, yl, il);\n        sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il);\n        sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il);\n\n        sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il);\n        sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il);\n        sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il);\n        sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il);\n\n        yb += QK4_0 * (N_SIMDWIDTH/2);\n    }\n\n    float16 tot = (float16)(\n        sub_group_reduce_add(sumf.s0), 
sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),\n        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),\n        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7),\n\n        sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9),\n        sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb),\n        sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd),\n        sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;\n        }\n\n        if (first_row + 4 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;\n        }\n        if (first_row + 5 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;\n        }\n        if (first_row + 6 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;\n        }\n        if (first_row + 7 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;\n        }\n\n        if (first_row + 8 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8;\n        }\n        if (first_row + 9 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9;\n        }\n        if (first_row + 10 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa;\n        }\n        if (first_row + 11 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb;\n        }\n\n        if (first_row + 12 < ne01) {\n            dst[r1*ne0 + 
im*ne0*ne1 + first_row + 12] = tot.sc;\n        }\n        if (first_row + 13 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd;\n        }\n        if (first_row + 14 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se;\n        }\n        if (first_row + 15 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf;\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_q4_0_f32_1d_16x_flat(\n        global uchar * src0_q,\n        global half  * src0_d,\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_0                   32\n#define QR4_0                   2\n#define QK4_1                   32\n#define QR4_1                   2\n#define QK5_0                   32\n#define QR5_0                   2\n#define QK5_1                   32\n#define QR5_1                   2\n#define QK8_0                   32\n#define QR8_0                   1\n#define QK_K                    256\n#define K_QUANTS_PER_ITERATION  2\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n//------------------------------------------------------------------------------\n// block_q4_0\n//------------------------------------------------------------------------------\nstruct block_q4_0\n{\n    half d;\n    uint8_t qs[QK4_0 / 2];\n};\n\ninline float mm_block_q_4_0_dot_y_flat(\n        global uchar * x,\n        global half  * dh,\n        float sumy,\n        float16 yl,\n        int il\n) {\n    float           d   = *dh;\n    global ushort * qs  = ((global ushort *)x + il/2);\n    float           acc = 0.f;\n\n    acc += yl.s0 * (qs[0] & 0x000F);\n    acc += yl.s1 * (qs[0] & 0x0F00);\n    acc += 
yl.s8 * (qs[0] & 0x00F0);\n    acc += yl.s9 * (qs[0] & 0xF000);\n\n    acc += yl.s2 * (qs[1] & 0x000F);\n    acc += yl.s3 * (qs[1] & 0x0F00);\n    acc += yl.sa * (qs[1] & 0x00F0);\n    acc += yl.sb * (qs[1] & 0xF000);\n\n    acc += yl.s4 * (qs[2] & 0x000F);\n    acc += yl.s5 * (qs[2] & 0x0F00);\n    acc += yl.sc * (qs[2] & 0x00F0);\n    acc += yl.sd * (qs[2] & 0xF000);\n\n    acc += yl.s6 * (qs[3] & 0x000F);\n    acc += yl.s7 * (qs[3] & 0x0F00);\n    acc += yl.se * (qs[3] & 0x00F0);\n    acc += yl.sf * (qs[3] & 0xF000);\n\n    return d * (sumy * -8.f + acc);\n}\n\n#ifdef INTEL_GPU\n#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix)\n#define N_SIMDGROUP 1 // number of SIMD groups in a thread group\n#define N_SIMDWIDTH 16 // assuming SIMD group size is 16\n#elif defined (ADRENO_GPU)\n#define N_DST 8\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n//\n// This variant performs 1d blocking with 8x output.\n// Each simdgroup outputs 8 values on `n0` dim (row in the output matrix).\n//\ninline void mul_mat_q_n_f32_1d_8x_flat(\n        global uchar * src0_q,\n        global half  * src0_d,\n        global float * src1,\n        global float * dst,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    const int nb = ne00/QK4_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of\n    // a SIMD group in the grid. 
Each SIMD group produces N_DST values in the\n    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.\n    // Currently with llama2 7B, im is always 0.\n    // TODO: how to handle im/gqa*(nb*ne0)?\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    // The number of scales is the same as the number of blocks.\n    // Compute in ulong: nb, first_row, ne01 and ne02 are int, so the products\n    // would otherwise be evaluated in 32-bit and can overflow for large weights.\n    ulong offset0_d = (ulong)first_row * nb + (ulong)(i12/r2)*nb*ne01 + (ulong)(i13/r3)*nb*ne01*ne02;\n    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.\n    ulong offset0_q = offset0_d * (QK4_0/2);\n\n    global uchar * x = (global uchar *) src0_q + offset0_q;\n    global half  * d = (global half  *) src0_d + offset0_d;\n    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;\n\n    float16 yl;\n    float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);\n\n    int ix = get_sub_group_local_id()/2;\n    int il = 8*(get_sub_group_local_id()%2);\n\n    global float * yb = y + ix*QK4_0 + il;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        float sumy = 0.f;\n\n        sumy += yb[0];\n        sumy += yb[1];\n        sumy += yb[2];\n        sumy += yb[3];\n        sumy += yb[4];\n        sumy += yb[5];\n        sumy += yb[6];\n        sumy += yb[7];\n\n        sumy += yb[16];\n        sumy += yb[17];\n        sumy += yb[18];\n        sumy += yb[19];\n        sumy += yb[20];\n        sumy += yb[21];\n        sumy += yb[22];\n        sumy += yb[23];\n\n        yl.s0 = yb[0];\n        yl.s1 = yb[1]/256.f;\n\n        yl.s2 = yb[2];\n        yl.s3 = yb[3]/256.f;\n\n        yl.s4 = yb[4];\n        yl.s5 = yb[5]/256.f;\n\n        yl.s6 = yb[6];\n        yl.s7 = yb[7]/256.f;\n\n        yl.s8 = yb[16]/16.f;\n        yl.s9 = yb[17]/4096.f;\n\n        yl.sa = yb[18]/16.f;\n        yl.sb = 
yb[19]/4096.f;\n\n        yl.sc = yb[20]/16.f;\n        yl.sd = yb[21]/4096.f;\n\n        
yl.se = yb[22]/16.f;\n        yl.sf = yb[23]/4096.f;\n\n        sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);\n        sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);\n        sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);\n        sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);\n\n        sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);\n        sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);\n        sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);\n        sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);\n\n        yb += QK4_0 * (N_SIMDWIDTH/2);\n    }\n\n    float8 tot = (float8)(\n        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),\n        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),\n        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;\n        }\n\n        if (first_row + 4 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;\n        }\n        if (first_row + 5 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;\n        
}\n        if (first_row + 6 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;\n        }\n        if (first_row + 7 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_q4_0_f32_1d_8x_flat(\n        global uchar * src0_q,\n        global half  * src0_d,\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_0                   32\n#define QR4_0                   2\n#define QK4_1                   32\n#define QR4_1                   2\n#define QK5_0                   32\n#define QR5_0                   2\n#define QK5_1                   32\n#define QR5_1                   2\n#define QK8_0                   32\n#define QR8_0                   1\n#define QK_K                    256\n#define K_QUANTS_PER_ITERATION  2\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n//------------------------------------------------------------------------------\n// block_q4_0\n//------------------------------------------------------------------------------\nstruct block_q4_0\n{\n    half d;\n    uint8_t qs[QK4_0 / 2];\n};\n\n// This function requires the original shuffled weights.\n// As a reminder, the original weights are shuffled so that (q[0], q[16]) are\n// packed together in a byte, so are (q[1], q[17]) and so on.\ninline float block_q_4_0_dot_y_flat(\n        global uchar * x,\n        global half  * dh,\n        float sumy,\n        float16 yl,\n        int il\n) {\n    float        
   d   = *dh;\n    global ushort * qs  = ((global ushort *)x + il/2);\n    float           acc = 0.f;\n\n    acc += yl.s0 * (qs[0] & 0x000F);\n    acc += yl.s1 * (qs[0] & 0x0F00);\n    acc += yl.s8 * (qs[0] & 0x00F0);\n    acc += yl.s9 * (qs[0] & 0xF000);\n\n    acc += yl.s2 * (qs[1] & 0x000F);\n    acc += yl.s3 * (qs[1] & 0x0F00);\n    acc += yl.sa * (qs[1] & 0x00F0);\n    acc += yl.sb * (qs[1] & 0xF000);\n\n    acc += yl.s4 * (qs[2] & 0x000F);\n    acc += yl.s5 * (qs[2] & 0x0F00);\n    acc += yl.sc * (qs[2] & 0x00F0);\n    acc += yl.sd * (qs[2] & 0xF000);\n\n    acc += yl.s6 * (qs[3] & 0x000F);\n    acc += yl.s7 * (qs[3] & 0x0F00);\n    acc += yl.se * (qs[3] & 0x00F0);\n    acc += yl.sf * (qs[3] & 0xF000);\n\n    return d * (sumy * -8.f + acc);\n}\n\n//\n// This variant outputs 8 values.\n//\n#undef N_DST\n#undef N_SIMDGROUP\n#undef N_SIMDWIDTH\n\n#ifdef INTEL_GPU\n#define N_DST 8 // each SIMD group works on 8 rows\n#define N_SIMDGROUP 1 // number of SIMD groups in a thread group\n#define N_SIMDWIDTH 16 // assuming SIMD group size is 16\n#elif defined (ADRENO_GPU)\n#define N_DST 8\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n\ninline void mul_vec_q_n_f32_8x_flat(\n        global uchar * src0_q,\n        global half  * src0_d,\n        global float * src1,\n        global float * dst,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    const ulong nb = ne00/QK4_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of\n    // a SIMD group in the grid. 
Each SIMD group produces N_DST values in the\n    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.\n    // Currently with llama2 7B, im is always 0.\n    // TODO: how to handle im/gqa*(nb*ne0)?\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    // The number of scales is the same as the number of blocks.\n    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);\n    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.\n    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;\n\n    global uchar * x = (global uchar *) src0_q + offset0_q;\n    global half  * d = (global half  *) src0_d + offset0_d;\n    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;\n\n    float16 yl;\n    float8 sumf = 0.f;\n\n    int ix = get_sub_group_local_id()/2;\n    int il = 8*(get_sub_group_local_id()%2);\n\n    global float * yb = y + ix*QK4_0 + il;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        float sumy = 0.f;\n\n        sumy += yb[0];\n        sumy += yb[1];\n        sumy += yb[2];\n        sumy += yb[3];\n        sumy += yb[4];\n        sumy += yb[5];\n        sumy += yb[6];\n        sumy += yb[7];\n\n        sumy += yb[16];\n        sumy += yb[17];\n        sumy += yb[18];\n        sumy += yb[19];\n        sumy += yb[20];\n        sumy += yb[21];\n        sumy += yb[22];\n        sumy += yb[23];\n\n        yl.s0 = yb[0];\n        yl.s1 = yb[1]/256.f;\n\n        yl.s2 = yb[2];\n        yl.s3 = yb[3]/256.f;\n\n        yl.s4 = yb[4];\n        yl.s5 = yb[5]/256.f;\n\n        yl.s6 = yb[6];\n        yl.s7 = yb[7]/256.f;\n\n        yl.s8 = yb[16]/16.f;\n        yl.s9 = yb[17]/4096.f;\n\n        yl.sa = yb[18]/16.f;\n        yl.sb = yb[19]/4096.f;\n\n        yl.sc = yb[20]/16.f;\n        yl.sd = yb[21]/4096.f;\n\n        yl.se = yb[22]/16.f;\n        yl.sf = 
yb[23]/4096.f;\n\n        sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);\n        sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);\n        sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);\n        sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);\n\n        sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);\n        sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);\n        sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);\n        sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);\n\n        yb += QK4_0 * (N_SIMDWIDTH/2);\n    }\n\n    float8 tot = (float8)(\n        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),\n        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),\n        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;\n        }\n\n        if (first_row + 4 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;\n        }\n        if (first_row + 5 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;\n        }\n        if (first_row + 6 < ne01) {\n            dst[r1*ne0 
+ im*ne0*ne1 + first_row + 6] = tot.s6;\n        }\n        if (first_row + 7 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_q4_0_f32_8x_flat(\n        global uchar * src0_q,\n        global half  * src0_d,\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_0                   32\n#define QR4_0                   2\n#define QK4_1                   32\n#define QR4_1                   2\n#define QK5_0                   32\n#define QR5_0                   2\n#define QK5_1                   32\n#define QR5_1                   2\n#define QK8_0                   32\n#define QR8_0                   1\n#define QK_K                    256\n#define K_QUANTS_PER_ITERATION  2\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n//------------------------------------------------------------------------------\n// block_q4_0\n//------------------------------------------------------------------------------\nstruct block_q4_0\n{\n    half d;\n    uint8_t qs[QK4_0 / 2];\n};\n\n//\n// This variant unrolls the loops and uses vector types instead of pointers.\n// It improves performance on Adreno but not so much on Intel.\n//\ninline float block_q_4_0_dot_y_v(\n        global struct block_q4_0 * qb_curr,\n        float sumy,\n        float16 yl,\n        int il\n) {\n    float d = qb_curr->d;\n    float acc = 0.f;\n    global ushort * qs = 
((global ushort *)qb_curr + 1 + il/2);\n\n    acc += yl.s0 * (qs[0] & 0x000F);\n    acc += yl.s1 * (qs[0] & 0x0F00);\n    acc += yl.s8 * (qs[0] & 0x00F0);\n    acc += yl.s9 * (qs[0] & 0xF000);\n\n    acc += yl.s2 * (qs[1] & 0x000F);\n    acc += yl.s3 * (qs[1] & 0x0F00);\n    acc += yl.sa * (qs[1] & 0x00F0);\n    acc += yl.sb * (qs[1] & 0xF000);\n\n    acc += yl.s4 * (qs[2] & 0x000F);\n    acc += yl.s5 * (qs[2] & 0x0F00);\n    acc += yl.sc * (qs[2] & 0x00F0);\n    acc += yl.sd * (qs[2] & 0xF000);\n\n    acc += yl.s6 * (qs[3] & 0x000F);\n    acc += yl.s7 * (qs[3] & 0x0F00);\n    acc += yl.se * (qs[3] & 0x00F0);\n    acc += yl.sf * (qs[3] & 0xF000);\n\n    return d * (sumy * -8.f + acc);\n}\n\n#undef N_DST\n#undef N_SIMDGROUP\n#undef N_SIMDWIDTH\n\n#ifdef INTEL_GPU\n#define N_DST 4 // each SIMD group works on 4 rows\n#define N_SIMDGROUP 1 // number of SIMD groups in a thread group\n#define N_SIMDWIDTH 16 // assuming SIMD group size is 16\n#elif defined (ADRENO_GPU)\n#define N_DST 4\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n\ninline void mul_vec_q_n_f32_v(\n        global void * src0,\n        global float * src1,\n        global float * dst,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    const ulong nb = ne00/QK4_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global\n    // id of a SIMD group in the grid.\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);\n\n    global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;\n    global float             * y = (global float             *) src1 + r1*ne10 + im*ne00*ne1;\n\n    float16 yl;     
  // src1 vector cache\n    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);\n\n    int ix = get_sub_group_local_id()/2;\n    int il = 8*(get_sub_group_local_id()%2);\n\n    global float * yb = y + ix * QK4_0 + il;\n\n    // each thread in a SIMD group deals with half a block.\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        float sumy = 0;\n\n        sumy += yb[0];\n        sumy += yb[1];\n        sumy += yb[2];\n        sumy += yb[3];\n        sumy += yb[4];\n        sumy += yb[5];\n        sumy += yb[6];\n        sumy += yb[7];\n\n        sumy += yb[16];\n        sumy += yb[17];\n        sumy += yb[18];\n        sumy += yb[19];\n        sumy += yb[20];\n        sumy += yb[21];\n        sumy += yb[22];\n        sumy += yb[23];\n\n\n        yl.s0 = yb[0];\n        yl.s1 = yb[1]/256.f;\n\n        yl.s2 = yb[2];\n        yl.s3 = yb[3]/256.f;\n\n        yl.s4 = yb[4];\n        yl.s5 = yb[5]/256.f;\n\n        yl.s6 = yb[6];\n        yl.s7 = yb[7]/256.f;\n\n        yl.s8 = yb[16]/16.f;\n        yl.s9 = yb[17]/4096.f;\n\n        yl.sa = yb[18]/16.f;\n        yl.sb = yb[19]/4096.f;\n\n        yl.sc = yb[20]/16.f;\n        yl.sd = yb[21]/4096.f;\n\n        yl.se = yb[22]/16.f;\n        yl.sf = yb[23]/4096.f;\n\n        sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il);\n        sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il);\n        sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il);\n        sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il);\n\n        // One thread in a SIMD group (i.e., subgroup) handles a half block,\n        // hence the entire SIMD group handles SIMDWIDTH/2 blocks.\n        // y points to the activation matrix (of type float). 
Therefore for\n        // one thread, the # of blocks yb should advance is SIMDWIDTH/2 (because\n        // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of\n        // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.\n        yb += QK4_0 * (N_SIMDWIDTH/2);\n    }\n\n    // Reduce each row's partial sum across the subgroup, one float4 component\n    // per row. An earlier formulation of this reduction produced incorrect\n    // results on Adreno (only row 0 was correct; rows 1, 2, 3 were wrong), and\n    // this explicit per-component form also seems to perform better on Intel.\n    // If N_DST is changed, tot and the stores below must be updated accordingly.\n    float4 tot = (float4)(\n        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mat_q4_0_f32_v(\n        global void * src0,\n        ulong offset0,\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_1                   32\n\nstruct block_q4_1 {\n    half d; // delta\n    half m; // min\n    uchar qs[QK4_1 / 2]; // nibbles / quants\n};\n\ninline float block_q4_1_dot_y(\n    global const struct block_q4_1 * qb_curr,\n    float sumy,\n    float16 yl,\n    int il\n) {\n    float d = qb_curr->d;\n    float m = qb_curr->m;\n\n    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);\n\n    global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);\n\n    acc.s0 += yl.s0 * (qs[0] & 0x000F);\n    acc.s0 += yl.s1 * (qs[0] & 0x0F00);\n    acc.s0 += yl.s8 * (qs[0] & 0x00F0);\n    acc.s3 += yl.s9 * (qs[0] & 0xF000);\n\n    acc.s0 += yl.s2 * (qs[1] & 0x000F);\n    acc.s1 += yl.s3 * (qs[1] & 0x0F00);\n    acc.s2 += yl.sa * (qs[1] & 0x00F0);\n    acc.s3 += yl.sb * (qs[1] & 0xF000);\n\n    acc.s0 += yl.s4 * (qs[2] & 0x000F);\n    acc.s1 += yl.s5 * (qs[2] & 0x0F00);\n    acc.s2 += yl.sc * (qs[2] & 0x00F0);\n    acc.s3 += yl.sd * (qs[2] & 0xF000);\n\n    acc.s0 += yl.s6 * (qs[3] & 0x000F);\n    acc.s1 += yl.s7 * (qs[3] & 0x0F00);\n    acc.s2 += yl.se * (qs[3] & 0x00F0);\n    acc.s3 += yl.sf * (qs[3] & 0xF000);\n\n    return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + 
sumy * m;\n}\n\n#undef N_DST\n#undef N_SIMDGROUP\n#undef N_SIMDWIDTH\n\n#ifdef INTEL_GPU\n#define N_DST 4 // each subgroup works on 4 rows\n#define N_SIMDGROUP 1 // number of subgroups in a thread group\n#define N_SIMDWIDTH 16 // assuming subgroup size is 16\n#elif defined (ADRENO_GPU)\n#define N_DST 4\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n\ninline void mul_vec_q_n_f32(\n        global void * src0,\n        global float * src1,\n        global float * dst,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    const ulong nb = ne00/QK4_1;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);\n\n    global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;\n    global float             * y = (global float             *) src1 + r1*ne10 + im*ne00*ne1;\n\n    float16 yl;\n    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);\n\n    int ix = get_sub_group_local_id()/2;\n    int il = 8*(get_sub_group_local_id()%2);\n\n    global float * yb = y + ix * QK4_1 + il;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        float sumy = 0;\n\n        sumy += yb[0];\n        sumy += yb[1];\n        sumy += yb[2];\n        sumy += yb[3];\n        sumy += yb[4];\n        sumy += yb[5];\n        sumy += yb[6];\n        sumy += yb[7];\n\n        sumy += yb[16];\n        sumy += yb[17];\n        sumy += yb[18];\n        sumy += yb[19];\n        sumy += yb[20];\n        sumy += yb[21];\n        sumy += yb[22];\n        sumy += yb[23];\n\n\n        yl.s0 = yb[0];\n        yl.s1 = yb[1]/256.f;\n\n        yl.s2 = yb[2];\n        yl.s3 = yb[3]/256.f;\n\n        yl.s4 = yb[4];\n        
yl.s5 = yb[5]/256.f;\n\n        yl.s6 = yb[6];\n        yl.s7 = yb[7]/256.f;\n\n        yl.s8 = yb[16]/16.f;\n        yl.s9 = yb[17]/4096.f;\n\n        yl.sa = yb[18]/16.f;\n        yl.sb = yb[19]/4096.f;\n\n        yl.sc = yb[20]/16.f;\n        yl.sd = yb[21]/4096.f;\n\n        yl.se = yb[22]/16.f;\n        yl.sf = yb[23]/4096.f;\n\n        sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);\n        sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);\n        sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);\n        sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);\n\n        yb += QK4_1 * (N_SIMDWIDTH/2);\n    }\n\n    float4 tot = (float4)(\n        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_q4_1_f32(\n        global void * src0,\n        ulong offset0,\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    mul_vec_q_n_f32(src0, src1, dst, 
ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_1                   32\n\nstruct block_q4_1 {\n    half d; // delta\n    half m; // min\n    uchar qs[QK4_1 / 2]; // nibbles / quants\n};\n\ninline float block_q4_1_dot_y_flat(\n    global const uchar * x,\n    global const half  * dh,\n    global const half  * mh,\n    float sumy,\n    float16 yl,\n    int il\n) {\n    float                 d   = *dh;\n    float                 m   = *mh;\n    global const ushort * qs = ((global const ushort *) x + il/2);\n\n    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);\n\n    acc.s0 += yl.s0 * (qs[0] & 0x000F);\n    acc.s0 += yl.s1 * (qs[0] & 0x0F00);\n    acc.s0 += yl.s8 * (qs[0] & 0x00F0);\n    acc.s3 += yl.s9 * (qs[0] & 0xF000);\n\n    acc.s0 += yl.s2 * (qs[1] & 0x000F);\n    acc.s1 += yl.s3 * (qs[1] & 0x0F00);\n    acc.s2 += yl.sa * (qs[1] & 0x00F0);\n    acc.s3 += yl.sb * (qs[1] & 0xF000);\n\n    acc.s0 += yl.s4 * (qs[2] & 0x000F);\n    acc.s1 += yl.s5 * (qs[2] & 0x0F00);\n    acc.s2 += yl.sc * (qs[2] & 0x00F0);\n    acc.s3 += yl.sd * (qs[2] & 0xF000);\n\n    acc.s0 += yl.s6 * (qs[3] & 0x000F);\n    acc.s1 += yl.s7 * (qs[3] & 0x0F00);\n    acc.s2 += yl.se * (qs[3] & 0x00F0);\n    acc.s3 += yl.sf * (qs[3] & 
0xF000);\n\n    return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;\n}\n\n#undef N_DST\n#undef N_SIMDGROUP\n#undef N_SIMDWIDTH\n\n#ifdef INTEL_GPU\n#define N_DST 4 // each subgroup works on 4 rows\n#define N_SIMDGROUP 1 // number of subgroups in a thread group\n#define N_SIMDWIDTH 16 // assuming subgroup size is 16\n#elif defined (ADRENO_GPU)\n#define N_DST 4\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n\ninline void mul_vec_q_n_f32_flat(\n        global void * src0_q,\n        global void * src0_d,\n        global void * src0_m,\n        global float * src1,\n        global float * dst,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    const ulong nb = ne00/QK4_1;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);\n\n    // The number of scales/mins is the same as the number of blocks.\n    ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));\n    // Each block contains QK4_1/2 uchars, hence offset for qs is as follows.\n    ulong offset0_q  = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;\n\n    global uchar * x = (global uchar *) src0_q + offset0_q;\n    global half  * d = (global half  *) src0_d + offset0_dm;\n    global half  * m = (global half  *) src0_m + offset0_dm;\n    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;\n\n    float16 yl;\n    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);\n\n    int ix = get_sub_group_local_id()/2;\n    int il = 8*(get_sub_group_local_id()%2);\n\n    global float * yb = y + ix * QK4_1 + il;\n\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {\n        float 
sumy = 0;\n\n        sumy += yb[0];\n        sumy += yb[1];\n        sumy += yb[2];\n        sumy += yb[3];\n        sumy += yb[4];\n        sumy += yb[5];\n        sumy += yb[6];\n        sumy += yb[7];\n\n        sumy += yb[16];\n        sumy += yb[17];\n        sumy += yb[18];\n        sumy += yb[19];\n        sumy += yb[20];\n        sumy += yb[21];\n        sumy += yb[22];\n        sumy += yb[23];\n\n\n        yl.s0 = yb[0];\n        yl.s1 = yb[1]/256.f;\n\n        yl.s2 = yb[2];\n        yl.s3 = yb[3]/256.f;\n\n        yl.s4 = yb[4];\n        yl.s5 = yb[5]/256.f;\n\n        yl.s6 = yb[6];\n        yl.s7 = yb[7]/256.f;\n\n        yl.s8 = yb[16]/16.f;\n        yl.s9 = yb[17]/4096.f;\n\n        yl.sa = yb[18]/16.f;\n        yl.sb = yb[19]/4096.f;\n\n        yl.sc = yb[20]/16.f;\n        yl.sd = yb[21]/4096.f;\n\n        yl.se = yb[22]/16.f;\n        yl.sf = yb[23]/4096.f;\n\n        sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);\n        sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);\n        sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);\n        sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);\n\n        yb += QK4_1 * (N_SIMDWIDTH/2);\n    }\n\n    float4 tot = (float4)(\n        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < 
ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;\n        }\n    }\n}\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_q4_1_f32_flat(\n        global void * src0_q,\n        global void * src0_d,\n        global void * src0_m,\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl",
    "content": "#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n//------------------------------------------------------------------------------\n// block_q4_K\n//------------------------------------------------------------------------------\n#define QK_K            256\n#define K_SCALE_SIZE    12\n\n// 8 blocks of 32 elements each\n// weight is represented as x = a * q + b\ntypedef struct {\n    half d;    // super-block scale for quantized scales\n    half dmin; // super-block scale for quantized mins\n\n    uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits\n    uchar qs[QK_K/2];           // 4-bit quants\n} block_q4_K;\n\n#undef N_DST\n#undef N_SIMDGROUP\n#undef N_SIMDWIDTH\n\n#ifdef INTEL_GPU\n#define N_DST 4 // number of rows each SIMD group works on\n#define N_SIMDGROUP 1 // number of SIMD groups in a thread group\n#define N_SIMDWIDTH 16 // SIMD group size\n#elif defined (ADRENO_GPU)\n#define N_DST 4\n#define N_SIMDGROUP 1\n#define N_SIMDWIDTH 64\n#endif\n\n#undef  BLOCK_STRIDE\n// number of (super) blocks each subgroup processes\n// each thread in a subgroup processes a block (32 weights)\n#define BLOCK_STRIDE (N_SIMDWIDTH/8)\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_q4_K_f32(\n        global char * src0,\n        int offset0,\n        global char * src1,\n        int offset1,\n        global char * dst,\n        int 
offsetd,\n        int ne00,\n        int ne01,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne12,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    ushort kmask1 = 0x3f3f;\n    ushort kmask2 = 0x0f0f;\n    ushort kmask3 = 0xc0c0;\n\n    int ix = get_sub_group_local_id()/8;  // super block index\n    int it = get_sub_group_local_id()%8;  // block index (inside super block)\n    int iq = it/4;     // 0 or 1 - first or second half of the super block\n    int ir = it%4;     // 0...3 - block index in the half super block\n\n    int nb = ne00/QK_K;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n    int offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;\n\n    global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);\n    global float      * y = (global float      *) (src1 + offset_src1);\n\n    float yl[16];\n    float yh[16];\n    float sumf[N_DST] = {0.f};\n    float all_sum;\n\n    global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;\n\n    ushort  sc16[4];\n    uchar * sc8 = (uchar *)sc16;\n\n    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {\n        float4 sumy = {0.f, 0.f, 0.f, 0.f};\n        for (int i = 0; i < 8; ++i) {\n            yl[i+0] = y4[i+0];\n            sumy.s0 += yl[i+0];\n\n            yl[i+8] = y4[i+32];\n            sumy.s1 += yl[i+8];\n\n            yh[i+0] = y4[i+128];\n            sumy.s2 += yh[i+0];\n\n            yh[i+8] = y4[i+160];\n            sumy.s3 += yh[i+8];\n        }\n\n        global ushort * sc = (global ushort *)x[ib].scales + iq;\n        global ushort * q1 = 
(global ushort *)x[ib].qs + 16 * iq + 4 * ir;\n        global half     * dh = &x[ib].d;\n\n        for (int row = 0; row < N_DST; row++) {\n            sc16[0] = sc[0] & kmask1;\n            sc16[1] = sc[2] & kmask1;\n            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);\n            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);\n\n            global ushort * q2 = q1 + 32;\n\n            float4 acc1 = {0.f, 0.f, 0.f, 0.f};\n            float4 acc2 = {0.f, 0.f, 0.f, 0.f};\n            for (int i = 0; i < 8; i += 2) {\n                acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);\n                acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);\n                acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);\n                acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);\n                acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);\n                acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);\n                acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);\n                acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);\n            }\n\n            float dall = dh[0];\n            float dmin = dh[1];\n            sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +\n                                 (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +\n                                 (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +\n                                 (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -\n                         dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);\n\n            q1 += nb01/2;\n            sc += nb01/2;\n            dh += nb01/2;\n        }\n\n        y4 += BLOCK_STRIDE * QK_K;\n    }\n\n    global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;\n\n    for (int row = 0; row < N_DST; ++row) {\n        all_sum = sub_group_reduce_add(sumf[row]);\n        if (first_row + row < ne01) {\n            if (get_sub_group_local_id() == 0) {\n                dst_f32[first_row + row] = 
all_sum;\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK4_0                   32\n#define QR4_0                   2\n#define QK4_1                   32\n#define QR4_1                   2\n#define QK5_0                   32\n#define QR5_0                   2\n#define QK5_1                   32\n#define QR5_1                   2\n#define QK8_0                   32\n#define QR8_0                   1\n#define QK_K                    256\n#define K_QUANTS_PER_ITERATION  2\n\ntypedef char int8_t;\ntypedef uchar uint8_t;\ntypedef short int16_t;\ntypedef ushort uint16_t;\ntypedef int int32_t;\ntypedef uint uint32_t;\n\n//------------------------------------------------------------------------------\n// block_q6_K\n//------------------------------------------------------------------------------\n// 6-bit quantization\n// weight is represented as x = a * q\n// 16 blocks of 16 elements each\n// Effectively 6.5625 bits per weight\ntypedef struct {\n    uint8_t ql[QK_K/2];      // quants, lower 4 bits\n    uint8_t qh[QK_K/4];      // quants, upper 2 bits\n    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits\n    half d;             // super-block scale\n} 
block_q6_K;\n\n//------------------------------------------------------------------------------\n// kernel_mul_mv_q6_K_f32\n//------------------------------------------------------------------------------\n\n#undef N_DST\n#undef N_SIMDGROUP\n#undef N_SIMDWIDTH\n\n#ifdef INTEL_GPU\n#define N_DST 1 // number of rows each SIMD group works on\n#define N_SIMDGROUP 2 // number of SIMD groups in a thread group\n#define N_SIMDWIDTH 16 // SIMD group size\n#elif defined (ADRENO_GPU)\n#define N_DST 1\n#define N_SIMDGROUP 2\n#define N_SIMDWIDTH 64\n#endif\n\n#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_q6_K_f32(\n        global void * src0,\n        ulong offset0,\n        global float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    uchar kmask1 = 0x03;\n    uchar kmask2 = 0x0C;\n    uchar kmask3 = 0x30;\n    uchar kmask4 = 0xC0;\n\n    int nb = ne00/QK_K;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int row = N_SIMDGROUP * r0 + get_sub_group_id();\n\n    if (row >= ne01) {\n        return;\n    }\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);\n\n    global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0;\n    global float      * yy = (global float     *) src1 + r1*ne10 + im*ne00*ne1;\n\n    float sumf = 0;\n\n    // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms 
a\n    // block. Values in a subblock share a scale that is quantized with 8 bits;\n    // the entire block shares a single floating point scale.\n    // For work distribution, each thread processes a subblock (16 weights), hence\n    // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16\n    // (super) blocks -- this is the block stride.\n    // The 16 threads that process a (super) block are split into 2 portions, each has\n    // 8 threads; each portion works on 8 subblocks.\n    // For a subgroup of 16 threads, the entire subgroup works on a single (super) block\n    // before moving to the next (super) block. Thread0 - thread7 work on the\n    // first 8 subblocks; thread8 - thread15 work on the last 8 subblocks.\n    // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on\n    // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but\n    // works on a total of 16 weight values.\n    int tid  = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0\n    int ix   = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1\n    int ip   = tid/8;   // first or second half of (super) block (0 or 1)\n    int il   = tid%8;   // each half has 8 parts, one per scale\n    int n    = 4;       // 4 scales at a time (and 4 sums)\n    int l0   = n*il;    // offset into half-block, 0..28\n    int is   = 8*ip + l0/16; // 0, 1, 8, 9\n\n    int y_offset = 128*ip + l0;\n    int q_offset_l = 64*ip + l0;\n    int q_offset_h = 32*ip + l0;\n\n    for (int i = ix; i < nb; i += BLOCK_STRIDE) {\n\n        global uint8_t * q1 = x[i].ql + q_offset_l;\n        global uint8_t * q2 = q1 + QK_K/8;\n        global uint8_t * qh = x[i].qh + q_offset_h;\n        global int8_t  * sc = x[i].scales + is;\n\n        global float * y = yy + i * QK_K + y_offset;\n\n        float dall = x[i].d;\n\n        float4 sums = {0.f, 0.f, 0.f, 0.f};\n\n        sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | 
((qh[0] & kmask1) << 4)) - 32.f);\n        sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f);\n        sums.s2 += y[0+64] * ((float)((q1[0]  >> 4) | ((qh[0] & kmask3) << 0)) - 32.f);\n        sums.s3 += y[0+96] * ((float)((q2[0]  >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f);\n\n        sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f);\n        sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f);\n        sums.s2 += y[1+64] * ((float)((q1[1]  >> 4) | ((qh[1] & kmask3) << 0)) - 32.f);\n        sums.s3 += y[1+96] * ((float)((q2[1]  >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f);\n\n        sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f);\n        sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f);\n        sums.s2 += y[2+64] * ((float)((q1[2]  >> 4) | ((qh[2] & kmask3) << 0)) - 32.f);\n        sums.s3 += y[2+96] * ((float)((q2[2]  >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f);\n\n        sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f);\n        sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f);\n        sums.s2 += y[3+64] * ((float)((q1[3]  >> 4) | ((qh[3] & kmask3) << 0)) - 32.f);\n        sums.s3 += y[3+96] * ((float)((q2[3]  >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f);\n\n        sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);\n    }\n\n    float tot = sub_group_reduce_add(sumf);\n    if (get_sub_group_local_id() == 0) {\n        dst[r1*ne0 + im*ne0*ne1 + row] = tot;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n//------------------------------------------------------------------------------\n// kernel_mul_mv_q6_K_f32_flat\n//------------------------------------------------------------------------------\n#define Q6_K_MASK1 0x03\n#define Q6_K_MASK2 0x0C\n#define Q6_K_MASK3 0x30\n#define Q6_K_MASK4 0xC0\n\n#define QK_K       256\n\ninline float block_q_6_K_dot_y_flat(\n    global uchar * blk_ql,\n    global uchar * blk_qh,\n    global char  * blk_scales,\n    global half  * blk_d,\n    global float * yy,\n    int ib,\n    int ip,\n    int is,\n    int l0\n) {\n    int y_offset   = 128*ip + l0;\n    int q_offset_l =  64*ip + l0;\n    int q_offset_h =  32*ip + l0;\n\n    global uchar * q1 = blk_ql     + ib*128 + q_offset_l;\n    global uchar * q2 = q1         + QK_K/8;\n    global uchar * qh = blk_qh     + ib*64 + q_offset_h;\n    global char  * sc = blk_scales + ib*16 + is;\n\n    global float * y = yy + ib * QK_K + y_offset;\n\n    float dall = blk_d[ib];\n\n    float  sumf = 0;\n    float4 sums = {0.f, 0.f, 0.f, 0.f};\n\n    sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f);\n    sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 
2)) - 32.f);\n    sums.s2 += y[0+64] * ((float)((q1[0]  >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f);\n    sums.s3 += y[0+96] * ((float)((q2[0]  >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f);\n\n    sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f);\n    sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f);\n    sums.s2 += y[1+64] * ((float)((q1[1]  >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f);\n    sums.s3 += y[1+96] * ((float)((q2[1]  >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f);\n\n    sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f);\n    sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f);\n    sums.s2 += y[2+64] * ((float)((q1[2]  >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f);\n    sums.s3 += y[2+96] * ((float)((q2[2]  >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f);\n\n    sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f);\n    sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f);\n    sums.s2 += y[3+64] * ((float)((q1[3]  >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f);\n    sums.s3 += y[3+96] * ((float)((q2[3]  >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f);\n\n    sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);\n\n    return sumf;\n}\n\n#undef N_DST\n#undef N_SIMDGROUP\n#undef N_SIMDWIDTH\n\n#ifdef INTEL_GPU\n#define N_DST 4\n#define N_SIMDGROUP 2\n#define N_SIMDWIDTH 16\n#elif defined (ADRENO_GPU)\n#define N_DST 4\n#define N_SIMDGROUP 2\n#define N_SIMDWIDTH 64\n#endif\n\n#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_q6_K_f32_flat(\n        global uchar * src0_ql,\n        global uchar * src0_qh,\n        global char  * src0_s,\n        global half  * src0_d,\n        global 
float * src1,\n        ulong offset1,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne10,\n        int ne12,\n        int ne0,\n        int ne1,\n        int r2,\n        int r3\n) {\n    src1 = (global float*)((global char*)src1 + offset1);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int nb = ne00/QK_K;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int i12 = im%ne12;\n    int i13 = im/ne12;\n\n    int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST;\n\n    ulong offset_src0    = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);\n    ulong offset_src0_ql = offset_src0 * 128;\n    ulong offset_src0_qh = offset_src0 * 64;\n    ulong offset_src0_s  = offset_src0 * 16;\n    ulong offset_src0_d  = offset_src0;\n\n    global uchar * blk_ql     = (global uchar *) src0_ql + offset_src0_ql;\n    global uchar * blk_qh     = (global uchar *) src0_qh + offset_src0_qh;\n    global char  * blk_scales = (global char  *) src0_s  + offset_src0_s;\n    global half  * blk_d      = (global half  *) src0_d  + offset_src0_d;\n    global float * yy         = (global float *) src1    + r1*ne10 + im*ne00*ne1;\n\n    int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0\n    int ix  = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1\n    int ip  = tid/8;   // first or second half of (super) block (0 or 1)\n    int il  = tid%8;   // each half has 8 parts, one per scale\n    int n   = 4;       // 4 scales at a time (and 4 sums)\n    int l0  = n*il;    // offset into half-block, 0..28\n    int is  = 8*ip + l0/16; // 0, 1, 8, 9\n\n    float4 sumf = 0;\n\n    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {\n        if (first_row + 0 < ne01) {\n            sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, 
yy, ib, ip, is, l0);\n        }\n        if (first_row + 1 < ne01) {\n            sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0);\n        }\n        if (first_row + 2 < ne01) {\n            sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0);\n        }\n        if (first_row + 3 < ne01) {\n            sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0);\n        }\n    }\n\n    float4 tot = (float4)(\n        sub_group_reduce_add(sumf.s0),\n        sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2),\n        sub_group_reduce_add(sumf.s3)\n    );\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;\n        }\n        if (first_row + 1 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK8_0 32\ntypedef struct {\n    half d;       // delta\n    char qs[QK8_0]; // quants\n} block_q8_0;\n\n#define NB_Q8_0 8\n\n#ifdef INTEL_GPU\n#define N_R0_Q8_0 4 // number of rows each subgroup works on\n#define N_SG_Q8_0 2 // number of subgroups in a work group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_R0_Q8_0 4\n#define N_SG_Q8_0 2\n#define N_SIMDWIDTH 64\n#endif\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_q8_0_f32(\n    global char * src0,\n    ulong         offset0,\n    global char * src1,\n    ulong         offset1,\n    global char * dst,\n    ulong         offsetd,\n    int           ne00,\n    int           ne01,\n    ulong         nb01,\n    ulong         nb02,\n    ulong         nb03,\n    int           ne12,\n    ulong         nb11,\n    ulong         nb12,\n    ulong         nb13,\n    int           ne0,\n    int           ne1,\n    int           r2,\n    int           r3\n) {\n    src0 = (global char*)((global char*)src0 + offset0);\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global 
char*)dst  + offsetd);\n\n    int nb = ne00/QK8_0;\n\n    int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int first_row = (r0*N_SG_Q8_0 + get_sub_group_id()) * N_R0_Q8_0;\n\n    uint i12 = im%ne12;\n    uint i13 = im/ne12;\n\n    ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;\n    global float * y  = (global float *) (src1 + offset_src1);\n\n    // pointers to src0 rows\n    global block_q8_0 * ax[N_R0_Q8_0];\n    for (int row = 0; row < N_R0_Q8_0; ++row) {\n        ulong offset_src0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n        ax[row] = (global block_q8_0 *) ((global char *) src0 + offset_src0);\n    }\n\n    float yl[NB_Q8_0];\n    float sumf[N_R0_Q8_0] = { 0.f };\n\n    const short ix = get_sub_group_local_id()/4;\n    const short il = get_sub_group_local_id()%4;\n\n    global float * yb = y + ix*QK8_0 + il*NB_Q8_0;\n\n    // each thread handles NB_Q8_0 quants at a time\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {\n        for (short i = 0; i < NB_Q8_0; ++i) {\n            yl[i] = yb[i];\n        }\n\n        for (short row = 0; row < N_R0_Q8_0; row++) {\n            global char * qs = ax[row][ib].qs + il*NB_Q8_0;\n            float sumq = 0.f;\n            for (short iq = 0; iq < NB_Q8_0; ++iq) {\n                sumq += qs[iq] * yl[iq];\n            }\n            sumf[row] += sumq*ax[row][ib].d;\n        }\n\n        yb += N_SIMDWIDTH*NB_Q8_0;\n    }\n\n    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;\n\n    for (int row = 0; row < N_R0_Q8_0; ++row) {\n        float tot = sub_group_reduce_add(sumf[row]);\n\n        if (get_sub_group_local_id() == 0 && first_row + row < ne01) {\n            dst_f32[first_row + row] = tot;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#define QK8_0 32\ntypedef struct {\n    half d;       // delta\n    char qs[QK8_0]; // quants\n} block_q8_0;\n\n#define NB_Q8_0 8\n\n#ifdef INTEL_GPU\n#define N_R0_Q8_0 4 // number of rows each subgroup works on\n#define N_SG_Q8_0 2 // number of subgroups in a work group\n#define N_SIMDWIDTH 16 // subgroup size\n#elif defined (ADRENO_GPU)\n#define N_R0_Q8_0 4\n#define N_SG_Q8_0 2\n#define N_SIMDWIDTH 64\n#endif\n\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_16\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_mul_mv_q8_0_f32_flat(\n    global char * src0_q,\n    global half * src0_d,\n    global char * src1,\n    ulong         offset1,\n    global char * dst,\n    ulong         offsetd,\n    int           ne00,\n    int           ne01,\n    ulong         nb01,\n    ulong         nb02,\n    ulong         nb03,\n    int           ne12,\n    ulong         nb11,\n    ulong         nb12,\n    ulong         nb13,\n    int           ne0,\n    int           ne1,\n    int           r2,\n    int           r3\n) {\n    src1 = (global char*)((global char*)src1 + offset1);\n    dst  = (global char*)((global char*)dst  + offsetd);\n\n    int nb = ne00/QK8_0;\n\n    
int r0 = get_group_id(0);\n    int r1 = get_group_id(1);\n    int im = get_group_id(2);\n\n    int first_row = (r0*N_SG_Q8_0 + get_sub_group_id()) * N_R0_Q8_0;\n\n    uint i12 = im%ne12;\n    uint i13 = im/ne12;\n\n    ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;\n    global float * y  = (global float *) (src1 + offset_src1);\n\n    // pointers to src0 rows\n    uint offset_src0_base = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;\n\n    global char * ax0, * ax1, * ax2, * ax3;\n    global half * ad0, * ad1, * ad2, * ad3;\n    uint offset_src0;\n\n    offset_src0 = offset_src0_base + 0*nb01;\n    offset_src0 = offset_src0/34;\n    ax0 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);\n    ad0 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));\n\n    offset_src0 = offset_src0_base + 1*nb01;\n    offset_src0 = offset_src0/34;\n    ax1 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);\n    ad1 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));\n\n    offset_src0 = offset_src0_base + 2*nb01;\n    offset_src0 = offset_src0/34;\n    ax2 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);\n    ad2 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));\n\n    offset_src0 = offset_src0_base + 3*nb01;\n    offset_src0 = offset_src0/34;\n    ax3 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);\n    ad3 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));\n\n    const short ix = get_sub_group_local_id()/4;\n    const short il = get_sub_group_local_id()%4;\n\n    global float * yb = y + ix*QK8_0 + il*NB_Q8_0;\n\n    float8 yl;\n    float8 qv;\n    float4 sumf = 0.f;\n    float  sumq = 0.f;\n    global char * qs;\n\n    // each thread handles NB_Q8_0 quants at a time\n    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {\n        yl = vload8(0, yb);\n\n        qs = ax0 + ib*sizeof(char)*QK8_0 + 
il*NB_Q8_0;\n        qv = convert_float8(vload8(0, qs));\n        sumq = 0;\n        sumq += qv.s0*yl.s0;\n        sumq += qv.s1*yl.s1;\n        sumq += qv.s2*yl.s2;\n        sumq += qv.s3*yl.s3;\n        sumq += qv.s4*yl.s4;\n        sumq += qv.s5*yl.s5;\n        sumq += qv.s6*yl.s6;\n        sumq += qv.s7*yl.s7;\n        sumf.s0 += sumq*ad0[ib];\n\n        qs = ax1 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;\n        qv = convert_float8(vload8(0, qs));\n        sumq = 0;\n        sumq += qv.s0*yl.s0;\n        sumq += qv.s1*yl.s1;\n        sumq += qv.s2*yl.s2;\n        sumq += qv.s3*yl.s3;\n        sumq += qv.s4*yl.s4;\n        sumq += qv.s5*yl.s5;\n        sumq += qv.s6*yl.s6;\n        sumq += qv.s7*yl.s7;\n        sumf.s1 += sumq*ad1[ib];\n\n        qs = ax2 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;\n        qv = convert_float8(vload8(0, qs));\n        sumq = 0;\n        sumq += qv.s0*yl.s0;\n        sumq += qv.s1*yl.s1;\n        sumq += qv.s2*yl.s2;\n        sumq += qv.s3*yl.s3;\n        sumq += qv.s4*yl.s4;\n        sumq += qv.s5*yl.s5;\n        sumq += qv.s6*yl.s6;\n        sumq += qv.s7*yl.s7;\n        sumf.s2 += sumq*ad2[ib];\n\n        qs = ax3 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;\n        qv = convert_float8(vload8(0, qs));\n        sumq = 0;\n        sumq += qv.s0*yl.s0;\n        sumq += qv.s1*yl.s1;\n        sumq += qv.s2*yl.s2;\n        sumq += qv.s3*yl.s3;\n        sumq += qv.s4*yl.s4;\n        sumq += qv.s5*yl.s5;\n        sumq += qv.s6*yl.s6;\n        sumq += qv.s7*yl.s7;\n        sumf.s3 += sumq*ad3[ib];\n\n        yb += N_SIMDWIDTH*NB_Q8_0;\n    }\n\n    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;\n\n    float4 tot = (float4)(\n        sub_group_reduce_add(sumf.s0),\n        sub_group_reduce_add(sumf.s1),\n        sub_group_reduce_add(sumf.s2),\n        sub_group_reduce_add(sumf.s3)\n    );\n\n    if (get_sub_group_local_id() == 0) {\n        if (first_row + 0 < ne01) {\n            dst_f32[first_row + 0] = tot.s0;\n  
      }\n        if (first_row + 1 < ne01) {\n            dst_f32[first_row + 1] = tot.s1;\n        }\n        if (first_row + 2 < ne01) {\n            dst_f32[first_row + 2] = tot.s2;\n        }\n        if (first_row + 3 < ne01) {\n            dst_f32[first_row + 3] = tot.s3;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/neg.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\nkernel void kernel_neg_f32(\n        global const float * src0,\n        ulong                offset0,\n        global       float * dst,\n        ulong                offsetd,\n        int                  n\n) {\n    if (get_global_id(0) >= n) {\n        return;\n    }\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst  = (global float*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = -src0[get_global_id(0)];\n}\n\nkernel void kernel_neg_f32_4(\n        global const float4 * src0,\n        ulong                 offset0,\n        global       float4 * dst,\n        ulong                 offsetd,\n        int                   n\n) {\n    if (get_global_id(0) >= n) {\n        return;\n    }\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst  = (global float4*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = -src0[get_global_id(0)];\n}\n\nkernel void kernel_neg_f16(\n        global const half * src0,\n        ulong               offset0,\n        global       half * dst,\n        ulong               offsetd,\n        int                 n\n) {\n    if (get_global_id(0) >= n) {\n        return;\n    }\n    src0 = (global half*)((global char*)src0 + offset0);\n    dst  = (global half*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = -src0[get_global_id(0)];\n}\n\nkernel void kernel_neg_f16_4(\n        global const half4 * src0,\n        ulong                offset0,\n        global       half4 * dst,\n        ulong                offsetd,\n        int                  n\n) {\n    if (get_global_id(0) >= n) {\n        return;\n    }\n    src0 = (global half4*)((global char*)src0 + offset0);\n    dst  = (global half4*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = -src0[get_global_id(0)];\n}\n\nkernel void kernel_neg_f32_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       
char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n        *y = -*x;\n    }\n}\n\nkernel void kernel_neg_f16_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n        global       half * y = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n        *y = -*x;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/norm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n//------------------------------------------------------------------------------\n// norm\n//------------------------------------------------------------------------------\nkernel void kernel_norm(\n        global void * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        float eps,\n        local float * sum\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    dst = (global void*)((global char*)dst + offsetd);\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);\n\n    // MEAN\n    // parallel sum\n    sum[get_local_id(0)] = 0.0f;\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        sum[get_local_id(0)] += x[i00];\n    }\n    // reduce\n    barrier(CLK_LOCAL_MEM_FENCE);\n    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {\n        if (get_local_id(0) < i) {\n            sum[get_local_id(0)] += sum[get_local_id(0) + i];\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n    float mean  = sum[0] / ne00;\n\n    // recenter and VARIANCE\n    
barrier(CLK_LOCAL_MEM_FENCE);\n    global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;\n    sum[get_local_id(0)] = 0.0f;\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        y[i00] = x[i00] - mean;\n        sum[get_local_id(0)] += y[i00] * y[i00];\n    }\n\n    // reduce\n    barrier(CLK_LOCAL_MEM_FENCE);\n    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {\n        if (get_local_id(0) < i) {\n            sum[get_local_id(0)] += sum[get_local_id(0) + i];\n        }\n        barrier(CLK_LOCAL_MEM_FENCE);\n    }\n    float variance = sum[0] / ne00;\n\n    float scale = 1.0f/sqrt(variance + eps);\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        y[i00] = y[i00] * scale;\n    }\n}\n\n//------------------------------------------------------------------------------\n// norm_mul_add\n//------------------------------------------------------------------------------\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_32\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_norm_mul_add(\n        global char * src0_ptr, ulong src0_offset,\n        global char * src1_ptr, ulong src1_offset,\n        global char * src2_ptr, ulong src2_offset,\n        global char * dst_ptr,  ulong dst_offset,\n        int ne00, int ne01, int ne02, int ne03,\n        ulong nb01, ulong nb02, ulong nb03,\n        int ne10, int ne11, int ne12, int ne13,\n        ulong nb11, ulong nb12, ulong nb13,\n        int ne20, int ne21, int ne22, int ne23,\n        ulong nb21, ulong nb22, ulong nb23,\n        ulong nbd1, ulong nbd2, ulong nbd3,\n        float eps,\n        local float2 * sums\n) {\n    const int i03 = get_group_id(2);\n    const int i02 = get_group_id(1);\n    const int i01 = get_group_id(0);\n\n    global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);\n    global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + 
(i02%ne12)*nb12 + (i03%ne13)*nb13);\n    global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);\n    global float4 * y = (global float4 *)(dst_ptr  + dst_offset  + i01*nbd1 + i02*nbd2 + i03*nbd3);\n\n    float p_sum = 0.0f;\n    float p_sum_sq = 0.0f;\n\n    const int n_chunks = ne00 / 4;\n    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {\n        float4 val = x[i00];\n        p_sum += val.x + val.y + val.z + val.w;\n        p_sum_sq += dot(val, val);\n    }\n\n    p_sum = sub_group_reduce_add(p_sum);\n    p_sum_sq = sub_group_reduce_add(p_sum_sq);\n\n    if (get_sub_group_local_id() == 0) {\n        sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);\n    }\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    if (get_local_id(0) == 0) {\n        float sum = 0.0f;\n        float sum_sq = 0.0f;\n        for (uint i = 0; i < get_num_sub_groups(); ++i) {\n            float2 s = sums[i];\n            sum += s.x;\n            sum_sq += s.y;\n        }\n\n        const float inv_ne00 = 1.0f / (float)ne00;\n        const float mean = sum * inv_ne00;\n        const float variance = mad(-mean, mean, sum_sq * inv_ne00);\n\n        sums[0] = (float2)(mean, rsqrt(variance + eps));\n    }\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    const float2 mean_scale = sums[0];\n    const float mean = mean_scale.x;\n    const float scale = mean_scale.y;\n    const float neg_mean_scale = -mean * scale;\n\n    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {\n        const int w_idx = ne10 > 1 ? i00 : 0;\n        const int b_idx = ne20 > 1 ? i00 : 0;\n        const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);\n        y[i00] = mad(norm_x, w[w_idx], b[b_idx]);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/pad.cl",
    "content": "kernel void kernel_pad(\n        global void * src0,\n        ulong offset0,\n        global void * dst,\n        ulong offsetd,\n        int ne00, int ne01, int ne02, int ne03,\n        ulong nb00, ulong nb01, ulong nb02, ulong nb03,\n        int ne0, int ne1, int ne2, int ne3,\n        ulong nb0, ulong nb1, ulong nb2, ulong nb3,\n        int lp0, int rp0,\n        int lp1, int rp1,\n        int lp2, int rp2,\n        int lp3, int rp3\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst  = (global float*)((global char*)dst  + offsetd);\n\n    int i0 = get_global_id(0);\n    int i1 = get_group_id(1);\n    int i2 = get_group_id(2) % ne2;\n    int i3 = get_group_id(2) / ne2;\n\n    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {\n        return;\n    }\n\n    uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;\n    uint dst_idx  =         i3*nb3  +         i2*nb2  +         i1*nb1  +         i0*nb0;\n\n    global float * src0_ptr = (global float *)((global char *)src0 + src0_idx);\n    global float * dst_ptr  = (global float *)((global char *)dst  + dst_idx);\n\n    bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&\n                         (i1 >= lp1 && i1 < ne1 - rp1) &&\n                         (i2 >= lp2 && i2 < ne2 - rp2) &&\n                         (i3 >= lp3 && i3 < ne3 - rp3);\n\n    *dst_ptr = in_src_bounds ? *src0_ptr : 0.0f;\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/relu.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// relu\n//------------------------------------------------------------------------------\nkernel void kernel_relu(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/repeat.cl",
    "content": "kernel void kernel_repeat_f32(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int     ne00,\n        int     ne01,\n        int     ne02,\n        int     ne03,\n        ulong   nb00,\n        ulong   nb01,\n        ulong   nb02,\n        ulong   nb03,\n        int     ne0,\n        ulong   nb0,\n        ulong   nb1,\n        ulong   nb2,\n        ulong   nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst  + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    const int i03 = i3%ne03;\n    const int i02 = i2%ne02;\n    const int i01 = i1%ne01;\n\n    global const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;\n    global       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const int i00 = i0%ne00;\n        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i00*nb00));\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/rms_norm.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n//------------------------------------------------------------------------------\n// rms_norm\n//------------------------------------------------------------------------------\n// This kernel depends on subgroup size.\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_32\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_rms_norm(\n        global void * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        float eps,\n        local float * sum // Note, the size depends on number of subgroups\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);\n    global float * x_scalar = (global float *) x;\n    float4 sumf = 0;\n    float all_sum = 0;\n\n    // parallel sum\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += 
get_local_size(0)) {\n        sumf += x[i00] * x[i00];\n    }\n    all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;\n    all_sum = sub_group_reduce_add(all_sum);\n    if (get_sub_group_local_id() == 0) {\n        sum[get_sub_group_id()] = all_sum;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n    // broadcast\n    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {\n       if (get_local_id(0) < i) {\n           sum[get_local_id(0)] += sum[get_local_id(0) + i];\n       }\n    }\n    if (get_local_id(0) == 0) {\n        for (int i = 4 * (ne00 / 4); i < ne00; i++) {\n            sum[0] += x_scalar[i];\n        }\n        sum[0] /= ne00;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    const float mean  = sum[0];\n    const float scale = 1.0f/sqrt(mean + eps);\n\n    global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);\n    global float * y_scalar = (global float *) y;\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        y[i00] = x[i00] * scale;\n    }\n    if (get_local_id(0) == 0) {\n        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {\n            y_scalar[i00] = x_scalar[i00] * scale;\n        }\n    }\n}\n\n//------------------------------------------------------------------------------\n// rms_norm_mul\n//------------------------------------------------------------------------------\n#ifdef INTEL_GPU\nREQD_SUBGROUP_SIZE_32\n#elif defined (ADRENO_GPU)\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_rms_norm_mul(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        int ne13,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        
ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        float eps,\n        local float * sum\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    // The size of sum is sizeof(float)*subgroup_size.\n    // Each subgroup writes its partial sum to this array.\n    // So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.\n    // This is generally true -\n    // for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024).\n    if (get_sub_group_id() == 0) {\n        sum[get_sub_group_local_id()] = 0.0f;\n    }\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);\n    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);\n\n    float sumf = 0;\n\n    // parallel sum\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        sumf += dot(x[i00], x[i00]);\n    }\n    sumf = sub_group_reduce_add(sumf);\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    if (get_sub_group_local_id() == 0) {\n        sum[get_sub_group_id()] = sumf;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    //for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {\n    //   if (get_local_id(0) < i) {\n    //       sum[get_local_id(0)] += sum[get_local_id(0) + i];\n    //   }\n    //}\n    //if (get_local_id(0) == 0) {\n    //    sum[0] /= ne00;\n    //}\n\n    //barrier(CLK_LOCAL_MEM_FENCE);\n\n    sumf = sum[get_sub_group_local_id()];\n    sumf = sub_group_reduce_add(sumf);\n\n    float mean  = sumf / ne00;\n    float scale = 1.0f/sqrt(mean + eps);\n\n    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/rope.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// kernel_rope\n//------------------------------------------------------------------------------\nfloat rope_yarn_ramp(float low, float high, int i0) {\n    const float y = (i0 / 2 - low) / max(0.001f, high - low);\n    return 1.0f - min(1.0f, max(0.0f, y));\n}\n\n// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn\n// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.\nfloat2 rope_yarn(\n    float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale\n) {\n    // Get n-d rotational scaling corrected for extrapolation\n    float theta_interp = freq_scale * theta_extrap;\n    float theta = theta_interp;\n    if (ext_factor != 0.0f) {\n        float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor;\n        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;\n\n        // Get n-d magnitude scaling corrected for interpolation\n        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);\n    }\n    return (float2)(cos(theta) * mscale, sin(theta) * mscale);\n}\n\n// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get\n// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`\nfloat rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {\n    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));\n}\n\nfloat2 rope_yarn_corr_dims(\n    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow\n) {\n    // start and end correction dims\n    return (float2)(\n        max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))),\n        min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)))\n    );\n}\n\nkernel void 
kernel_rope_norm_f32(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * src2,\n        ulong offset2,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        int n_past,\n        int n_dims,\n        int n_ctx_orig,\n        float freq_base,\n        float freq_scale,\n        float ext_factor,\n        float attn_factor,\n        float beta_fast,\n        float beta_slow\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    src2 = (global float*)((global char*)src2 + offset2);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i3 = get_group_id(2);\n    int i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);\n\n    global int * pos = src1;\n\n    float theta_base = (float) pos[i2];\n    float inv_ndims = -1.f/n_dims;\n\n    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {\n        if (i0 < n_dims) {\n            int ic = i0/2;\n\n            float theta = theta_base * pow(freq_base, inv_ndims*i0);\n\n            float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f;\n\n            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);\n\n            global float * src       = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n            global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n            float x0 = src[0];\n            float x1 = src[1];\n\n            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;\n            dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;\n        } else {\n            global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n            global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\nkernel void kernel_rope_norm_f16(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * src2,\n        ulong offset2,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        int n_past,\n        int n_dims,\n        int n_ctx_orig,\n        float freq_base,\n        float freq_scale,\n        float ext_factor,\n        float attn_factor,\n        float beta_fast,\n        float beta_slow\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    src2 = (global float*)((global char*)src2 + offset2);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i3 = get_group_id(2);\n    int 
i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);\n\n    global int * pos = src1;\n\n    float theta_base = (float) pos[i2];\n    float inv_ndims = -1.f/n_dims;\n\n    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {\n        if (i0 < n_dims) {\n            int ic = i0/2;\n\n            float theta = theta_base * pow(freq_base, inv_ndims*i0);\n\n            float freq_factor = src2 != src0 ? src2[ic] : 1.0f;\n\n            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);\n\n            global half * src       = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n            float x0 = src[0];\n            float x1 = src[1];\n\n            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;\n            dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;\n        } else {\n            global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n            global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\nkernel void kernel_rope_neox_f32(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * src2,\n        ulong offset2,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        
ulong nb3,\n        int n_past,\n        int n_dims,\n        int n_ctx_orig,\n        float freq_base,\n        float freq_scale,\n        float ext_factor,\n        float attn_factor,\n        float beta_fast,\n        float beta_slow\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    src2 = (global float*)((global char*)src2 + offset2);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i3 = get_group_id(2);\n    int i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);\n\n    global int * pos = src1;\n\n    float theta_base = (float) pos[i2];\n    float inv_ndims = -1.f/n_dims;\n\n    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {\n        if (i0 < n_dims) {\n            int ic = i0/2;\n\n            const float theta = theta_base * pow(freq_base, inv_ndims*i0);\n\n            const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f;\n\n            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);\n\n            global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);\n            global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);\n\n            const float x0 = src[0];\n            const float x1 = src[n_dims/2];\n\n            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;\n            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;\n        } else {\n            global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n            global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\nkernel void kernel_rope_neox_f16(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * src2,\n        ulong offset2,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        int n_past,\n        int n_dims,\n        int n_ctx_orig,\n        float freq_base,\n        float freq_scale,\n        float ext_factor,\n        float attn_factor,\n        float beta_fast,\n        float beta_slow\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    src2 = (global float*)((global char*)src2 + offset2);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    
int i3 = get_group_id(2);\n    int i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);\n\n    global int * pos = src1;\n\n    float theta_base = (float) pos[i2];\n    float inv_ndims = -1.f/n_dims;\n\n    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {\n        if (i0 < n_dims) {\n            int ic = i0/2;\n\n            const float theta = theta_base * pow(freq_base, inv_ndims*i0);\n\n            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;\n\n            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);\n\n            global half * src       = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);\n            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);\n\n            const float x0 = src[0];\n            const float x1 = src[n_dims/2];\n\n            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;\n            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;\n        } else {\n            global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\nkernel void kernel_rope_multi_f32(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * src2,\n        ulong offset2,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n      
  int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        int n_past,\n        int n_dims,\n        int n_ctx_orig,\n        float freq_base,\n        float freq_scale,\n        float ext_factor,\n        float attn_factor,\n        float beta_fast,\n        float beta_slow,\n        int4 sections,\n        int  is_imrope\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    src2 = (global float*)((global char*)src2 + offset2);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i3 = get_group_id(2);\n    int i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);\n\n    global int * pos = src1;\n\n    const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;\n    const int sec_w = sections.s1 + sections.s0;\n\n    float inv_ndims = -1.f/n_dims;\n\n    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {\n        if (i0 < n_dims) {\n            int ic = i0/2;\n\n            const int sector = (i0 / 2) % sect_dims;\n            float theta_base = 0.0f;\n\n            if (is_imrope) {\n                if (sector % 3 == 1 && sector < 3 * sections.s1) { // h\n                    theta_base = (float) pos[i2 + ne02 * 1];\n                } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w\n                    theta_base = (float) pos[i2 + ne02 * 2];\n                } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t\n                    theta_base = (float) pos[i2 + ne02 * 0];\n                } else { // e\n                    theta_base = (float) pos[i2 + ne02 * 3];\n                }\n            } else {\n                if (sector < sections.s0) {\n                    theta_base = pos[i2];\n                }\n                else if (sector >= sections.s0 && sector < sec_w) {\n        
            theta_base = pos[i2 + ne2 * 1];\n                }\n                else if (sector >= sec_w && sector < sec_w + sections.s2) {\n                    theta_base = pos[i2 + ne2 * 2];\n                }\n                else if (sector >= sec_w + sections.s2) {\n                    theta_base = pos[i2 + ne2 * 3];\n                }\n            }\n\n            const float theta = theta_base * pow(freq_base, inv_ndims*i0);\n\n            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;\n\n            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);\n\n            global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);\n            global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);\n\n            const float x0 = src[0];\n            const float x1 = src[n_dims/2];\n\n            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;\n            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;\n        } else {\n            global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n            global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\nkernel void kernel_rope_multi_f16(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * src2,\n        ulong offset2,\n        global half * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong 
nb2,\n        ulong nb3,\n        int n_past,\n        int n_dims,\n        int n_ctx_orig,\n        float freq_base,\n        float freq_scale,\n        float ext_factor,\n        float attn_factor,\n        float beta_fast,\n        float beta_slow,\n        int4 sections,\n        int  is_imrope\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    src2 = (global float*)((global char*)src2 + offset2);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i3 = get_group_id(2);\n    int i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);\n\n    global int * pos = src1;\n\n    const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;\n    const int sec_w = sections.s1 + sections.s0;\n\n    float inv_ndims = -1.f/n_dims;\n\n    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {\n        if (i0 < n_dims) {\n            int ic = i0/2;\n\n            const int sector = (i0 / 2) % sect_dims;\n            float theta_base = 0.0f;\n\n            if (is_imrope) {\n                if (sector % 3 == 1 && sector < 3 * sections.s1) { // h\n                    theta_base = (float) pos[i2 + ne02 * 1];\n                } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w\n                    theta_base = (float) pos[i2 + ne02 * 2];\n                } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t\n                    theta_base = (float) pos[i2 + ne02 * 0];\n                } else { // e\n                    theta_base = (float) pos[i2 + ne02 * 3];\n                }\n            } else {\n                if (sector < sections.s0) {\n                    theta_base = pos[i2];\n                }\n                else if (sector >= sections.s0 && sector < sec_w) {\n                    theta_base = pos[i2 + ne2 * 1];\n                }\n  
              else if (sector >= sec_w && sector < sec_w + sections.s2) {\n                    theta_base = pos[i2 + ne2 * 2];\n                }\n                else if (sector >= sec_w + sections.s2) {\n                    theta_base = pos[i2 + ne2 * 3];\n                }\n            }\n\n            const float theta = theta_base * pow(freq_base, inv_ndims*i0);\n\n            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;\n\n            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);\n\n            global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);\n            global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);\n\n            const float x0 = src[0];\n            const float x1 = src[n_dims/2];\n\n            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;\n            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;\n        } else {\n            global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n            dst_data[0] = src[0];\n            dst_data[1] = src[1];\n        }\n    }\n}\n\nkernel void kernel_rope_vision_f32(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * src2,\n        ulong offset2,\n        global float * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        int n_past,\n        int n_dims,\n        
int n_ctx_orig,\n        float freq_base,\n        float freq_scale,\n        float ext_factor,\n        float attn_factor,\n        float beta_fast,\n        float beta_slow,\n        int4 sections\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    src2 = (global float*)((global char*)src2 + offset2);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i3 = get_group_id(2);\n    int i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);\n\n    global int * pos = src1;\n\n    const int sect_dims = sections.s0 + sections.s1;\n    const int sec_w = sections.s1 + sections.s0;\n\n    float inv_ndims = -1.f/n_dims;\n\n    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {\n        int ic = i0/2;\n\n        const int sector = (i0/2) % sect_dims;\n        float theta_base = 0.0f;\n\n        if (sector < sections.s0) {\n            const int p = sector;\n            theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);\n        } else if (sector >= sections.s0 && sector < sec_w) {\n            const int p = sector - sections.s0;\n            theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);\n        }\n\n        const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f;\n\n        float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);\n\n        global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);\n        global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);\n\n        const float x0 = src[0];\n        const float x1 = src[n_dims];\n\n        dst_data[0]      = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;\n        dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;\n    }\n}\n\nkernel void kernel_rope_vision_f16(\n        global void * src0,\n        ulong offset0,\n        global int * src1,\n        ulong offset1,\n        global float * src2,\n        ulong offset2,\n        global half * dst,\n        ulong offsetd,\n        int ne00,\n        int ne01,\n        int ne02,\n        int ne03,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne0,\n        int ne1,\n        int ne2,\n        int ne3,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        int n_past,\n        int n_dims,\n        int n_ctx_orig,\n        float freq_base,\n        float freq_scale,\n        float ext_factor,\n        float attn_factor,\n        float beta_fast,\n        float beta_slow,\n        int4 sections\n) {\n    src0 = (global void*)((global char*)src0 + offset0);\n    src1 = (global int*)((global char*)src1 + offset1);\n    src2 = (global float*)((global char*)src2 + offset2);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int i3 = get_group_id(2);\n    int i2 = get_group_id(1);\n    int i1 = get_group_id(0);\n\n    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);\n\n    global int * pos = src1;\n\n    const int sect_dims = sections.s0 + sections.s1;\n    const int sec_w = sections.s1 + sections.s0;\n\n    float 
inv_ndims = -1.f/n_dims;\n\n    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {\n        int ic = i0/2;\n\n        const int sector = (i0/2) % sect_dims;\n        float theta_base = 0.0f;\n\n        if (sector < sections.s0) {\n            const int p = sector;\n            theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);\n        } else if (sector >= sections.s0 && sector < sec_w) {\n            const int p = sector - sections.s0;\n            theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);\n        }\n\n        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;\n\n        float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);\n\n        global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);\n        global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);\n\n        const float x0 = src[0];\n        const float x1 = src[n_dims];\n\n        dst_data[0]      = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;\n        dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/scale.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\nkernel void kernel_scale_f32(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        float scale,\n        float bias\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n    dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;\n}\n\nkernel void kernel_scale_f32_4(\n        global float4 * src0,\n        ulong offset0,\n        global float4 * dst,\n        ulong offsetd,\n        float scale,\n        float bias\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst = (global float4*)((global char*)dst + offsetd);\n    dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/set_rows.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n// v = { mp, L, d }\ninline uint fastdiv(uint n, uint4 v) {\n    uint msbs;\n    msbs = mul_hi(n, v.s0);\n    return (msbs + n) >> v.s1;\n}\ninline uint fastmod(uint n, uint4 v) {\n    uint q = fastdiv(n, v);\n    return n - q * v.s2;\n}\n\nkernel void kernel_set_rows_f32_i64(\n        global char * src0,\n        ulong         offset0,\n        global char * src1,\n        ulong         offset1,\n        global char * dst,\n        ulong         offsetd,\n        int           ne01,\n        ulong         nb01,\n        ulong         nb02,\n        ulong         nb03,\n        uint4         ne11,\n        uint4         ne12,\n        ulong         nb10,\n        ulong         nb11,\n        ulong         nb12,\n        int           nblk0,\n        ulong         nb1,\n        ulong         nb2,\n        ulong         nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);\n\n    if (i01 >= ne01) {\n        return;\n    }\n\n    //int i12 = i03%ne12;\n    //int i11 = i02%ne11;\n    int i12 = fastmod(i03, ne12);\n    int i11 = fastmod(i02, ne11);\n\n    int i10 = i01;\n    long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];\n\n    global float * dst_row = (global float *) (dst  +  i1*nb1  + i02*nb2  + i03*nb3);\n    global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);\n\n    for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {\n        dst_row[ind] = (float)src_row[ind];\n    }\n}\n\nkernel void kernel_set_rows_f16_i64(\n        global char * src0,\n        ulong         offset0,\n        global char * src1,\n        ulong         offset1,\n        global char * dst,\n        ulong         offsetd,\n        int           ne01,\n        ulong         nb01,\n        
ulong         nb02,\n        ulong         nb03,\n        uint4         ne11,\n        uint4         ne12,\n        ulong         nb10,\n        ulong         nb11,\n        ulong         nb12,\n        int           nblk0,\n        ulong         nb1,\n        ulong         nb2,\n        ulong         nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);\n\n    if (i01 >= ne01) {\n        return;\n    }\n\n    //int i12 = i03%ne12;\n    //int i11 = i02%ne11;\n    int i12 = fastmod(i03, ne12);\n    int i11 = fastmod(i02, ne11);\n\n    int i10 = i01;\n    long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];\n\n    global half  * dst_row = (global half  *) (dst  +  i1*nb1  + i02*nb2  + i03*nb3);\n    global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);\n\n    for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {\n        dst_row[ind] = src_row[ind];\n    }\n}\n\nkernel void kernel_set_rows_f32_i32(\n        global char * src0,\n        ulong         offset0,\n        global char * src1,\n        ulong         offset1,\n        global char * dst,\n        ulong         offsetd,\n        int           ne01,\n        ulong         nb01,\n        ulong         nb02,\n        ulong         nb03,\n        uint4         ne11,\n        uint4         ne12,\n        ulong         nb10,\n        ulong         nb11,\n        ulong         nb12,\n        int           nblk0,\n        ulong         nb1,\n        ulong         nb2,\n        ulong         nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);\n\n    if (i01 >= ne01) {\n        return;\n    }\n\n    //int i12 = 
i03%ne12;\n    //int i11 = i02%ne11;\n    int i12 = fastmod(i03, ne12);\n    int i11 = fastmod(i02, ne11);\n\n    int i10 = i01;\n    int i1  = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];\n\n    global float * dst_row = (global float *) (dst  +  i1*nb1  + i02*nb2  + i03*nb3);\n    global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);\n\n    for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {\n        dst_row[ind] = (float)src_row[ind];\n    }\n}\n\nkernel void kernel_set_rows_f16_i32(\n        global char * src0,\n        ulong         offset0,\n        global char * src1,\n        ulong         offset1,\n        global char * dst,\n        ulong         offsetd,\n        int           ne01,\n        ulong         nb01,\n        ulong         nb02,\n        ulong         nb03,\n        uint4         ne11,\n        uint4         ne12,\n        ulong         nb10,\n        ulong         nb11,\n        ulong         nb12,\n        int           nblk0,\n        ulong         nb1,\n        ulong         nb2,\n        ulong         nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);\n\n    if (i01 >= ne01) {\n        return;\n    }\n\n    //int i12 = i03%ne12;\n    //int i11 = i02%ne11;\n    int i12 = fastmod(i03, ne12);\n    int i11 = fastmod(i02, ne11);\n\n    int i10 = i01;\n    int i1  = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];\n\n    global half  * dst_row = (global half  *) (dst  +  i1*nb1  + i02*nb2  + i03*nb3);\n    global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);\n\n    for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {\n        dst_row[ind] = src_row[ind];\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/sigmoid.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// sigmoid\n//------------------------------------------------------------------------------\n\nkernel void kernel_sigmoid_f32(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = 1.0f / (1.0f + exp(-src0[get_global_id(0)]));\n}\n\nkernel void kernel_sigmoid_f16(\n        global half * src0,\n        ulong offset0,\n        global half * dst,\n        ulong offsetd\n) {\n    src0 = (global half*)((global char*)src0 + offset0);\n    dst = (global half*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = 1.0f / (1.0f + exp(-src0[get_global_id(0)]));\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/silu.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// silu\n//------------------------------------------------------------------------------\nkernel void kernel_silu(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    float x = src0[get_global_id(0)];\n    dst[get_global_id(0)] = x / (1.0f + exp(-x));\n}\n\nkernel void kernel_silu_4(\n        global float4 * src0,\n        ulong offset0,\n        global float4 * dst,\n        ulong offsetd\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst = (global float4*)((global char*)dst + offsetd);\n\n    float4 x = src0[get_global_id(0)];\n    dst[get_global_id(0)] = x / (1.0f + exp(-x));\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/softmax_4_f16.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_soft_max_4_f16(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * src2,\n        ulong offset2,\n        global char * dst,\n        ulong offsetd,\n        int ne00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne12,\n        int ne13,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        float scale,\n        float max_bias,\n        float m0,\n        float m1,\n        int n_head_log2\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    src2 = src2 + offset2;\n    dst  = dst  + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03%ne13;\n    int i12 = i02%ne12;\n    int i11 = i01;\n\n    global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);\n    global half4  * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;\n    global float  * psrc2 = src2 != src0 ? 
(global float *)(src2) : 0;\n    global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);\n\n    float slope = 1.0f;\n\n    // ALiBi\n    if (max_bias > 0.0f) {\n        int h = i02;\n\n        float base = h < n_head_log2 ? m0 : m1;\n        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n        slope = pow(base, exp);\n    }\n\n    // parallel max\n    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));\n    }\n    float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));\n\n    const float max = sub_group_reduce_max(lmax);\n\n    // parallel sum\n    float4 lsum4 = 0.0f;\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max);\n        lsum4 += exp_psrc4;\n        pdst4[i00] = exp_psrc4;\n    }\n    float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;\n\n    float sum = sub_group_reduce_add(lsum);\n\n    if (psrc2) {\n        sum += exp(psrc2[i02] - max);\n    }\n\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        pdst4[i00] /= sum;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/softmax_4_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_soft_max_4(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * src2,\n        ulong offset2,\n        global char * dst,\n        ulong offsetd,\n        int ne00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne12,\n        int ne13,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        float scale,\n        float max_bias,\n        float m0,\n        float m1,\n        int n_head_log2\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    src2 = src2 + offset2;\n    dst  = dst  + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03%ne13;\n    int i12 = i02%ne12;\n    int i11 = i01;\n\n    global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);\n    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;\n    global float  * psrc2 = src2 != src0 ? 
(global float  *)(src2) : 0;\n    global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);\n\n    float slope = 1.0f;\n\n    // ALiBi\n    if (max_bias > 0.0f) {\n        int h = i02;\n\n        float base = h < n_head_log2 ? m0 : m1;\n        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n        slope = pow(base, exp);\n    }\n\n    // parallel max\n    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));\n    }\n    float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));\n\n    const float max = sub_group_reduce_max(lmax);\n\n    // parallel sum\n    float4 lsum4 = 0.0f;\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);\n        lsum4 += exp_psrc4;\n        pdst4[i00] = exp_psrc4;\n    }\n    float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;\n\n    float sum = sub_group_reduce_add(lsum);\n\n    if (psrc2) {\n        sum += exp(psrc2[i02] - max);\n    }\n\n    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {\n        pdst4[i00] /= sum;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/softmax_f16.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_soft_max_f16(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * src2,\n        ulong offset2,\n        global char * dst,\n        ulong offsetd,\n        int ne00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne12,\n        int ne13,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        float scale,\n        float max_bias,\n        float m0,\n        float m1,\n        int n_head_log2\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    src2 = src2 + offset2;\n    dst  = dst  + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03%ne13;\n    int i12 = i02%ne12;\n    int i11 = i01;\n\n    global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);\n    global half  * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;\n    global float * psrc2 = src2 != src0 ? 
(global float *)(src2) : 0;\n    global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);\n\n    float slope = 1.0f;\n\n    // ALiBi\n    if (max_bias > 0.0f) {\n        int h = i02;\n\n        float base = h < n_head_log2 ? m0 : m1;\n        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n        slope = pow(base, exp);\n    }\n\n    // parallel max\n    float lmax = psrc2 ? psrc2[i02] : -INFINITY;\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));\n    }\n    float max = sub_group_reduce_max(lmax);\n\n    // parallel sum\n    float lsum = 0.0f;\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);\n        lsum += exp_psrc0;\n        // Remember the result of exp here. exp is expensive, so we really do not\n        // wish to compute it twice.\n        pdst[i00] = exp_psrc0;\n    }\n\n    float sum = sub_group_reduce_add(lsum);\n\n    if (psrc2) {\n        sum += exp(psrc2[i02] - max);\n    }\n\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        pdst[i00] /= sum;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/softmax_f32.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n#ifdef cl_intel_subgroups\n#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n#else\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n#endif\n\n#ifdef cl_intel_required_subgroup_size\n#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable\n#define INTEL_GPU 1\n#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))\n#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))\n#elif defined(cl_qcom_reqd_sub_group_size)\n#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable\n#define ADRENO_GPU 1\n#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size(\"half\")))\n#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size(\"full\")))\n#endif\n\n#ifdef ADRENO_GPU\nREQD_SUBGROUP_SIZE_64\n#endif\nkernel void kernel_soft_max(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * src2,\n        ulong offset2,\n        global char * dst,\n        ulong offsetd,\n        int ne00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne12,\n        int ne13,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3,\n        float scale,\n        float max_bias,\n        float m0,\n        float m1,\n        int n_head_log2\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    src2 = src2 + offset2;\n    dst  = dst  + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03%ne13;\n    int i12 = i02%ne12;\n    int i11 = i01;\n\n    global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);\n    global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;\n    global float * psrc2 = src2 != src0 ? 
(global float *)(src2) : 0;\n    global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);\n\n    float slope = 1.0f;\n\n    // ALiBi\n    if (max_bias > 0.0f) {\n        int h = i02;\n\n        float base = h < n_head_log2 ? m0 : m1;\n        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n        slope = pow(base, exp);\n    }\n\n    // parallel max\n    float lmax = psrc2 ? psrc2[i02] : -INFINITY;\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));\n    }\n    float max = sub_group_reduce_max(lmax);\n\n    // parallel sum\n    float lsum = 0.0f;\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);\n        lsum += exp_psrc0;\n        // Remember the result of exp here. exp is expensive, so we really do not\n        // wish to compute it twice.\n        pdst[i00] = exp_psrc0;\n    }\n\n    float sum = sub_group_reduce_add(lsum);\n\n    if (psrc2) {\n        sum += exp(psrc2[i02] - max);\n    }\n\n    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {\n        pdst[i00] /= sum;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/softplus.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// softplus\n//------------------------------------------------------------------------------\n\nkernel void kernel_softplus_f32(\n        global const float * src0,\n        ulong                offset0,\n        global       float * dst,\n        ulong                offsetd\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst  = (global float*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));\n}\n\nkernel void kernel_softplus_f32_4(\n        global const float4 * src0,\n        ulong                 offset0,\n        global       float4 * dst,\n        ulong                 offsetd\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst  = (global float4*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));\n}\n\nkernel void kernel_softplus_f16(\n        global const half * src0,\n        ulong               offset0,\n        global       half * dst,\n        ulong               offsetd\n) {\n    src0 = (global half*)((global char*)src0 + offset0);\n    dst  = (global half*)((global char*)dst + offsetd);\n\n    const float x = convert_float(src0[get_global_id(0)]);\n    dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? 
x : log(1.0f + exp(x)));\n}\n\nkernel void kernel_softplus_f16_4(\n        global const half4 * src0,\n        ulong                offset0,\n        global       half4 * dst,\n        ulong                offsetd\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    dst  = (global half4*)((global char*)dst + offsetd);\n\n    const float4 x = convert_float4(src0[get_global_id(0)]);\n    dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));\n}\n\nkernel void kernel_softplus_f32_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n        *y = (*x > 20.0f) ? 
*x : log(1.0f + exp(*x));\n    }\n}\n\nkernel void kernel_softplus_f16_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n        global       half * hy = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n        const float x = convert_float(*hx);\n        *hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/solve_tri.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// solve_tri\n//------------------------------------------------------------------------------\nkernel void kernel_solve_tri_f32(\n        global uchar * src0,\n        ulong offset0,\n        global uchar * src1,\n        ulong offset1,\n        global uchar * dst,\n        ulong offsetd,\n        int n,\n        int k,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    int col = get_global_id(0);\n    int i2 = get_global_id(1);\n    int i3 = get_global_id(2);\n\n    global const uchar * Lb = src0 + offset0 + i2 * nb02 + i3 * nb03;\n    global const uchar * Bb = src1 + offset1 + i2 * nb12 + i3 * nb13;\n    global       uchar * Xb = dst + offsetd + i2 * nb2 + i3 * nb3;\n\n    for(int row = 0; row < n; ++row){\n        global const float *pB = (global const float *)(Bb + row * nb11 + col * nb10);\n\n        float sum = 0.0f;\n        for(int j = 0; j < row; ++j){\n            global const float *pL = (global const float *)(Lb + row * nb01 + j * nb00);\n            global const float *pX = (global const float *)(Xb + j * nb1 + col * nb0);\n            sum += (*pL) * (*pX);\n        }\n\n        global const float * pDiag = (global const float *)(Lb + row * nb01 + row *nb00);\n        global float * pOut = (global float *)(Xb + row * nb1 + col *nb0);\n\n        *pOut = ((* pB) - sum) / (*pDiag);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/sqr.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\nkernel void kernel_sqr_cont_f32(\n    global float * src0,\n    ulong          offset0,\n    global float * dst,\n    ulong          offsetd\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst  = (global float*)((global char*)dst + offsetd);\n\n    uint gid = get_global_id(0);\n    dst[gid] = src0[gid] * src0[gid];\n}\n\nkernel void kernel_sqr_cont_f32_4(\n    global float4 * src0,\n    ulong           offset0,\n    global float4 * dst,\n    ulong           offsetd\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst  = (global float4*)((global char*)dst + offsetd);\n\n    uint gid = get_global_id(0);\n    dst[gid] = src0[gid] * src0[gid];\n}\n\nkernel void kernel_sqr_cont_f16(\n    global half * src0,\n    ulong         offset0,\n    global half * dst,\n    ulong         offsetd\n) {\n    src0 = (global half*)((global char*)src0 + offset0);\n    dst  = (global half*)((global char*)dst + offsetd);\n\n    uint gid = get_global_id(0);\n    dst[gid] = src0[gid] * src0[gid];\n}\n\nkernel void kernel_sqr_cont_f16_4(\n    global half4 * src0,\n    ulong          offset0,\n    global half4 * dst,\n    ulong          offsetd\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    dst  = (global half4*)((global char*)dst + offsetd);\n\n    uint gid = get_global_id(0);\n    dst[gid] = src0[gid] * src0[gid];\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/sqrt.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\nkernel void kernel_sqrt_cont_f32(\n    global float * src0,\n    ulong          offset0,\n    global float * dst,\n    ulong          offsetd\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst  = (global float*)((global char*)dst + offsetd);\n\n    uint gid = get_global_id(0);\n    dst[gid] = sqrt(src0[gid]);\n}\n\nkernel void kernel_sqrt_cont_f32_4(\n    global float4 * src0,\n    ulong           offset0,\n    global float4 * dst,\n    ulong           offsetd\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst  = (global float4*)((global char*)dst + offsetd);\n\n    uint gid = get_global_id(0);\n    dst[gid] = sqrt(src0[gid]);\n}\n\nkernel void kernel_sqrt_cont_f16(\n    global half * src0,\n    ulong         offset0,\n    global half * dst,\n    ulong         offsetd\n) {\n    src0 = (global half*)((global char*)src0 + offset0);\n    dst  = (global half*)((global char*)dst + offsetd);\n\n    uint gid = get_global_id(0);\n    dst[gid] = convert_half(sqrt(convert_float(src0[gid])));\n}\n\nkernel void kernel_sqrt_cont_f16_4(\n    global half4 * src0,\n    ulong          offset0,\n    global half4 * dst,\n    ulong          offsetd\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    dst  = (global half4*)((global char*)dst + offsetd);\n\n    uint gid = get_global_id(0);\n    dst[gid] = convert_half4(sqrt(convert_float4(src0[gid])));\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/ssm_conv.cl",
    "content": "kernel void kernel_ssm_conv_f32_f32(\n    global char * src0,\n    ulong         offset0,\n    global char * src1,\n    ulong         offset1,\n    global char * dst,\n    ulong         offsetd,\n    ulong         nb00,\n    ulong         nb01,\n    ulong         nb02,\n    int           ne10,\n    ulong         nb11,\n    ulong         nb0,\n    ulong         nb1,\n    ulong         nb2\n){\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    int ir = get_global_id(0);\n    int i2 = get_global_id(1);\n    int i3 = get_global_id(2);\n\n    int nc  = ne10;\n\n    global float * s = (global float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);\n    global float * c = (global float *) (src1 + ir*nb11);\n    global float * d = (global float *) (dst  + ir*nb0  + i2*nb1  + i3*nb2);\n\n    float sumf = 0.0f;\n\n    for (int i0 = 0; i0 < nc; ++i0) {\n        sumf += s[i0] * c[i0];\n    }\n\n    d[0] = sumf;\n}\n\nkernel void kernel_ssm_conv_f32_f32_4(\n    global char * src0,\n    ulong         offset0,\n    global char * src1,\n    ulong         offset1,\n    global char * dst,\n    ulong         offsetd,\n    ulong         nb00,\n    ulong         nb01,\n    ulong         nb02,\n    int           ne10,\n    ulong         nb11,\n    ulong         nb0,\n    ulong         nb1,\n    ulong         nb2\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst  + offsetd;\n\n    int ir = get_global_id(0);\n    int i2 = get_global_id(1);\n    int i3 = get_global_id(2);\n\n    int nc = ne10;\n\n    global float4 * s = (global float4 *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);\n    global float4 * c = (global float4 *) (src1 + ir*nb11);\n    global float  * d = (global float  *) (dst  + ir*nb0  + i2*nb1  + i3*nb2);\n\n    float sumf = 0.0f;\n\n    for (int i0 = 0; i0 < nc/4; ++i0) {\n        sumf += dot(s[i0], c[i0]);\n    }\n\n    d[0] = sumf;\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/sub.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// sub\n//------------------------------------------------------------------------------\nkernel void kernel_sub(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * dst,\n        ulong offsetd,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        int ne13,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03 % ne13;\n    int i12 = i02 % ne12;\n    int i11 = i01 % ne11;\n\n    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;\n    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;\n    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const int i10 = i0 % ne10;\n        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) - *((global float *)(src1_ptr + i10*nb10));\n    }\n}\n\n// assumption: src1 is a row\n// broadcast src1 into src0\nkernel void kernel_sub_row(\n        global float4 * src0,\n        ulong offset0,\n        global float4 * src1,\n        ulong offset1,\n        global float4 * dst,\n        ulong offsetd,\n        int ne\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    src1 = (global float4*)((global char*)src1 + offset1);\n    dst = (global float4*)((global char*)dst + offsetd);\n\n    // This performs better than using %.\n    uint 
gid = get_global_id(0);\n    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne\n    dst[gid] = src0[gid] - src1[idx1];\n}\n\nkernel void kernel_sub_f16(\n        global char * src0,\n        ulong offset0,\n        global char * src1,\n        ulong offset1,\n        global char * dst,\n        ulong offsetd,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        int ne10,\n        int ne11,\n        int ne12,\n        int ne13,\n        ulong nb10,\n        ulong nb11,\n        ulong nb12,\n        ulong nb13,\n        int ne0,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    src1 = src1 + offset1;\n    dst  = dst + offsetd;\n\n    int i03 = get_group_id(2);\n    int i02 = get_group_id(1);\n    int i01 = get_group_id(0);\n\n    int i13 = i03 % ne13;\n    int i12 = i02 % ne12;\n    int i11 = i01 % ne11;\n\n    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;\n    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;\n    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;\n\n    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {\n        const int i10 = i0 % ne10;\n        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) - *((global half *)(src1_ptr + i10*nb10));\n    }\n}\n\nkernel void kernel_sub_row_f16(\n        global half4 * src0,\n        ulong offset0,\n        global half4 * src1,\n        ulong offset1,\n        global half4 * dst,\n        ulong offsetd,\n        int ne\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    src1 = (global half4*)((global char*)src1 + offset1);\n    dst = (global half4*)((global char*)dst + offsetd);\n\n    // This performs better than using %.\n    uint gid = get_global_id(0);\n    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne\n    dst[gid] = src0[gid] - src1[idx1];\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/sum_rows.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n\n// Most devices have max workgroup size of 1024, so this is enough for subgroup\n// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes\n#define MAX_SUBGROUPS 64\nkernel void kernel_sum_rows_f32(\n    global char *  src0,\n    ulong           offset0,\n    global char *  dst,\n    ulong           offsetd,\n    int             ne00,\n    int             ne01,\n    int             ne02,\n    int             ne03,\n    ulong           nb01,\n    ulong           nb02,\n    ulong           nb03,\n    ulong           nb1,\n    ulong           nb2,\n    ulong           nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst  + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    const int lid = get_local_id(0);\n    const int lsize = get_local_size(0);\n\n    const uint sg_size = get_sub_group_size();\n    const uint sg_id = get_sub_group_id();\n    const uint sg_lid = get_sub_group_local_id();\n\n    __local float lmem[MAX_SUBGROUPS];\n\n    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {\n        return;\n    }\n\n    if(sg_id == 0){\n        lmem[sg_lid] = 0.0f;\n    }\n\n    global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);\n    global float * dst_row = (global float *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);\n\n    float sumf = 0.0f;\n\n    for (int i0 = lid; i0 < ne00; i0 += lsize) {\n        sumf += src_row[i0];\n    }\n\n    sumf = sub_group_reduce_add(sumf);\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    if(sg_lid == 0){\n        lmem[sg_id] = sumf;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    sumf = lmem[sg_lid];\n    sumf = sub_group_reduce_add(sumf);\n\n    if (lid == 0) {\n        dst_row[0] = sumf;\n    }\n}\n\nkernel void kernel_sum_rows_f32_4(\n    global char *  src0,\n    ulong           offset0,\n    global char *  
dst,\n    ulong           offsetd,\n    int             ne00,\n    int             ne01,\n    int             ne02,\n    int             ne03,\n    ulong           nb01,\n    ulong           nb02,\n    ulong           nb03,\n    ulong           nb1,\n    ulong           nb2,\n    ulong           nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst  + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    const int lid = get_local_id(0);\n    const int lsize = get_local_size(0);\n\n    const uint sg_size = get_sub_group_size();\n    const uint sg_id = get_sub_group_id();\n    const uint sg_lid = get_sub_group_local_id();\n\n    __local float lmem[MAX_SUBGROUPS];\n\n    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {\n        return;\n    }\n\n    if(sg_id == 0){\n        lmem[sg_lid] = 0.0f;\n    }\n\n    global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);\n    global float  * dst_row = (global float  *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);\n\n    float4 sum_vec = (float4)0.0f;\n\n    for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {\n        sum_vec += src_row[i0];\n    }\n\n    float sumf = dot(sum_vec, (float4)(1.0f));\n    sumf = sub_group_reduce_add(sumf);\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    if(sg_lid == 0){\n        lmem[sg_id] = sumf;\n    }\n\n    barrier(CLK_LOCAL_MEM_FENCE);\n\n    sumf = lmem[sg_lid];\n    sumf = sub_group_reduce_add(sumf);\n\n    if (lid == 0) {\n        dst_row[0] = sumf;\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/tanh.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\nkernel void kernel_tanh_f32(\n        global const float * src0,\n        ulong                offset0,\n        global       float * dst,\n        ulong                offsetd\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst  = (global float*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);\n}\n\nkernel void kernel_tanh_f32_4(\n        global const float4 * src0,\n        ulong                 offset0,\n        global       float4 * dst,\n        ulong                 offsetd\n) {\n    src0 = (global float4*)((global char*)src0 + offset0);\n    dst  = (global float4*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);\n}\n\nkernel void kernel_tanh_f16(\n        global const half * src0,\n        ulong               offset0,\n        global       half * dst,\n        ulong               offsetd\n) {\n    src0 = (global half*)((global char*)src0 + offset0);\n    dst  = (global half*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);\n}\n\nkernel void kernel_tanh_f16_4(\n        global const half4 * src0,\n        ulong                offset0,\n        global       half4 * dst,\n        ulong                offsetd\n) {\n    src0 = (global half4*)((global char*)src0 + offset0);\n    dst  = (global half4*)((global char*)dst + offsetd);\n\n    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);\n}\n\nkernel void kernel_tanh_f32_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    
const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n        *y = tanh(*x);\n    }\n}\n\nkernel void kernel_tanh_f16_nc(\n        global const char * src0,\n        ulong               offset0,\n        global       char * dst,\n        ulong               offsetd,\n        int   ne00,\n        ulong nb00,\n        ulong nb01,\n        ulong nb02,\n        ulong nb03,\n        ulong nb0,\n        ulong nb1,\n        ulong nb2,\n        ulong nb3\n) {\n    src0 = src0 + offset0;\n    dst  = dst + offsetd;\n\n    const int i3 = get_group_id(2);\n    const int i2 = get_group_id(1);\n    const int i1 = get_group_id(0);\n\n    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {\n        global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);\n        global       half * y = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);\n\n        *y = tanh(*x);\n    }\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/transpose.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n// 16-bit transpose, loading/storing a 4x4 tile of elements\nkernel void kernel_transpose_16(\n    __read_only image1d_buffer_t input,\n    __write_only image1d_buffer_t output,\n    const uint rows,\n    const uint cols\n) {\n\n    const int i = get_global_id(0);\n    const int j = get_global_id(1);\n    const int i_2 = i<<2;\n    const int j_2 = j<<2;\n\n    half4 temp0 = read_imageh(input, (j_2+0)*cols+i);\n    half4 temp1 = read_imageh(input, (j_2+1)*cols+i);\n    half4 temp2 = read_imageh(input, (j_2+2)*cols+i);\n    half4 temp3 = read_imageh(input, (j_2+3)*cols+i);\n\n    write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));\n    write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));\n    write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));\n    write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));\n}\n\n// Padded kernel for irregular shape\nkernel void kernel_transpose_16_4x1(\n    __read_only image1d_buffer_t input,\n    __write_only image1d_buffer_t output,\n    const uint rows,\n    const uint cols\n) {\n\n    const int i = get_global_id(0);\n    const int j = get_global_id(1);\n    const int j_2 = j << 2;\n\n    half temp0 = read_imageh(input, (j_2 + 0) * cols + i).x;\n    half temp1 = read_imageh(input, (j_2 + 1) * cols + i).x;\n    half temp2 = read_imageh(input, (j_2 + 2) * cols + i).x;\n    half temp3 = read_imageh(input, (j_2 + 3) * cols + i).x;\n\n    write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3));\n}\n\n// Transpose treating each element as 8-bit using buffer\nkernel void kernel_transpose_8_buf(\n    global const uchar * input,\n    global uchar * output,\n    const int ldi,\n    const int ldo\n) {\n    const int x = get_global_id(0);\n    const int y = get_global_id(1);\n\n    output[x*ldo + y] = input[y*ldi + x];\n}\n\n// 
Transpose treating each element as 16-bit using buffer\nkernel void kernel_transpose_16_buf(\n    global const ushort * input,\n    global ushort * output,\n    const int ldi,\n    const int ldo\n) {\n    const int x = get_global_id(0);\n    const int y = get_global_id(1);\n\n    output[x*ldo + y] = input[y*ldi + x];\n}\n\n// Transpose treating each element as 32-bit using buffer\nkernel void kernel_transpose_32_buf(\n    global const uint * input,\n    global uint * output,\n    const int ldi,\n    const int ldo\n) {\n    const int x = get_global_id(0);\n    const int y = get_global_id(1);\n\n    output[x*ldo + y] = input[y*ldi + x];\n}\n\n// 32-bit transpose, loading/storing a 4x4 tile of elements\nkernel void kernel_transpose_32(\n    __read_only image1d_buffer_t input,\n    __write_only image1d_buffer_t output,\n    const uint rows,\n    const uint cols\n) {\n\n    const int i = get_global_id(0);\n    const int j = get_global_id(1);\n    const int i_2 = i<<2;\n    const int j_2 = j<<2;\n\n    float4 temp0 = read_imagef(input, (j_2+0)*cols+i);\n    float4 temp1 = read_imagef(input, (j_2+1)*cols+i);\n    float4 temp2 = read_imagef(input, (j_2+2)*cols+i);\n    float4 temp3 = read_imagef(input, (j_2+3)*cols+i);\n\n    write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));\n    write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));\n    write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));\n    write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));\n\n}\n\n// 32-bit transpose, loading/storing a 4x4 tile of elements\n// Only used for activations\n// converts to FP16\n// also adds zero padding for non multiple of 8 prompt lengths\nkernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) {\n\n    const int i = get_global_id(0);\n  
  const int j = get_global_id(1);\n    const int i_2 = i<<2;\n    const int j_2 = j<<2;\n    half4 temp0 = {0,0,0,0}; // initialize outputs to 0\n    half4 temp1 = {0,0,0,0};\n    half4 temp2 = {0,0,0,0};\n    half4 temp3 = {0,0,0,0};\n\n    if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0\n        temp0 = read_imageh(input, (j_2+0)*cols+i);\n    }\n    if((j_2+1)*cols+i*4+3 < rows*cols*16){\n        temp1 = read_imageh(input, (j_2+1)*cols+i);\n    }\n    if((j_2+2)*cols+i*4+3 < rows*cols*16){\n        temp2 = read_imageh(input, (j_2+2)*cols+i);\n    }\n    if((j_2+3)*cols+i*4+3 < rows*cols*16){\n        temp3 = read_imageh(input, (j_2+3)*cols+i);\n    }\n\n    write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding\n    write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));\n    write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));\n    write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/tri.cl",
    "content": "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n\n//------------------------------------------------------------------------------\n// tri\n//------------------------------------------------------------------------------\n__kernel void kernel_tri_f32(\n        global float * src0,\n        ulong offset0,\n        global float * dst,\n        ulong offsetd,\n        int n,\n        int ne0,\n        int ne1,\n        int tri_type\n) {\n    src0 = (global float*)((global char*)src0 + offset0);\n    dst = (global float*)((global char*)dst + offsetd);\n\n    int idx = get_global_id(0);\n    if (idx >= n) return;\n\n    int i0 = idx % ne0;\n    int i1 = (idx / ne0) % ne1;\n\n    int keep = 0;\n    if (tri_type == 0) keep = (i0 >= i1);\n    else if (tri_type == 1) keep = (i0 >  i1);\n    else if (tri_type == 2) keep = (i0 <= i1);\n    else                    keep = (i0 <  i1);\n\n    dst[idx] = keep ? src0[idx] : 0.0f;\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/tsembd.cl",
    "content": "kernel void kernel_timestep_embedding(\n    global const void * p_timesteps,\n    ulong off_timesteps,\n    global void * p_dst,\n    ulong off_dst,\n    int dst_nb1_bytes,\n    int logical_dim,\n    int max_period\n) {\n    int local_i;\n    int local_j;\n    int local_half_dim;\n    float local_timestep_val;\n    float local_freq;\n    float local_arg;\n    global float * local_embed_data_ptr;\n    global const float * local_timesteps_input_ptr;\n    global float * local_dst_output_base_ptr;\n\n    local_timesteps_input_ptr = (global const float *)((global char *)p_timesteps + off_timesteps);\n    local_dst_output_base_ptr = (global float *)((global char *)p_dst + off_dst);\n\n    local_i = get_global_id(1);\n    local_j = get_global_id(0);\n\n    local_half_dim = logical_dim / 2;\n    local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes);\n\n    if (logical_dim % 2 != 0 && local_j == local_half_dim) {\n        local_embed_data_ptr[2 * local_half_dim] = 0.0f;\n    }\n\n    if (local_j >= local_half_dim) {\n        return;\n    }\n\n    local_timestep_val = local_timesteps_input_ptr[local_i];\n\n    if (local_half_dim == 0) {\n        local_freq = 1.0f;\n    } else {\n        local_freq = exp(-log((float)max_period) * (float)local_j / (float)local_half_dim);\n    }\n\n    local_arg = local_timestep_val * local_freq;\n    local_embed_data_ptr[local_j] = cos(local_arg);\n    local_embed_data_ptr[local_j + local_half_dim] = sin(local_arg);\n}\n"
  },
  {
    "path": "src/ggml-opencl/kernels/upscale.cl",
    "content": "kernel void kernel_upscale(\n    global const void * p_src0,\n    ulong off_src0,\n    global void * p_dst,\n    ulong off_dst,\n    ulong nb00,\n    ulong nb01,\n    ulong nb02,\n    ulong nb03,\n    int ne10,\n    int ne11,\n    int ne12,\n    int ne13,\n    float sf0,\n    float sf1,\n    float sf2,\n    float sf3\n) {\n    global const char * src_base = (global const char *)p_src0 + off_src0;\n    global float * dst_base = (global float *)((global char *)p_dst + off_dst);\n\n    int index = get_global_id(0);\n    int dst_total_elements = ne10 * ne11 * ne12 * ne13;\n\n    if (index >= dst_total_elements) {\n        return;\n    }\n\n    int i10 = index % ne10;\n    int i11 = (index / ne10) % ne11;\n    int i12 = (index / (ne10 * ne11)) % ne12;\n    int i13 = index / (ne10 * ne11 * ne12);\n\n    int i00 = (int)(i10 / sf0);\n    int i01 = (int)(i11 / sf1);\n    int i02 = (int)(i12 / sf2);\n    int i03 = (int)(i13 / sf3);\n\n    ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00;\n    global const float * src_element_ptr = (global const float *)(src_base + offset_src_element);\n\n    dst_base[index] = *src_element_ptr;\n}\n\nkernel void kernel_upscale_bilinear(\n    global const void * p_src0,\n    ulong off_src0,\n    global void * p_dst,\n    ulong off_dst,\n    ulong nb00,\n    ulong nb01,\n    ulong nb02,\n    ulong nb03,\n    int ne00_src,\n    int ne01_src,\n    int ne10_dst,\n    int ne11_dst,\n    int ne12_dst,\n    int ne13_dst,\n    float sf0,\n    float sf1,\n    float sf2,\n    float sf3,\n    float pixel_offset\n) {\n    global const char * src_base = (global const char *)p_src0 + off_src0;\n    global float * dst_base = (global float *)((global char *)p_dst + off_dst);\n\n    int index = get_global_id(0);\n    int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;\n\n    if (index >= dst_total_elements) {\n        return;\n    }\n\n    int i10_dst = index % ne10_dst;\n  
  int i11_dst = (index / ne10_dst) % ne11_dst;\n    int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;\n    int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);\n\n    int i02_src = (int)(i12_dst / sf2);\n    int i03_src = (int)(i13_dst / sf3);\n\n    float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;\n    long y0_src = (long)floor(y_src_f);\n    long y1_src = y0_src + 1;\n\n    y0_src = max(0L, min(y0_src, (long)ne01_src - 1));\n    y1_src = max(0L, min(y1_src, (long)ne01_src - 1));\n\n    float dy = y_src_f - (float)y0_src;\n    dy = max(0.0f, min(dy, 1.0f));\n\n    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;\n    long x0_src = (long)floor(x_src_f);\n    long x1_src = x0_src + 1;\n\n    x0_src = max(0L, min(x0_src, (long)ne00_src - 1));\n    x1_src = max(0L, min(x1_src, (long)ne00_src - 1));\n\n    float dx = x_src_f - (float)x0_src;\n    dx = max(0.0f, min(dx, 1.0f));\n\n    global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);\n    global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);\n    global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);\n    global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);\n\n    const float val_a = *p_a;\n    const float val_b = *p_b;\n    const float val_c = *p_c;\n    const float val_d = *p_d;\n\n    float result = val_a * (1.0f - dx) * (1.0f - dy) +\n                   val_b * dx * (1.0f - dy) +\n                   val_c * (1.0f - dx) * dy +\n                   val_d * dx * dy;\n\n    dst_base[index] = result;\n}\n"
  },
  {
    "path": "src/ggml-openvino/.clang-format",
    "content": "---\n# Override root .clang-format\nAlignConsecutiveAssignments: false\nAlignConsecutiveDeclarations: false\nCpp11BracedListStyle: true\nSpacesInContainerLiterals: false\nBreakBeforeBraces: Attach\nAccessModifierOffset: -4\nIndentCaseBlocks: false\nIndentCaseLabels: false\n\nLanguage:        Cpp\nAlignAfterOpenBracket: Align\nAlignArrayOfStructures: Left\nAlignConsecutiveBitFields: AcrossComments\nAlignConsecutiveMacros: AcrossComments\n# AlignConsecutiveShortCaseStatements: AcrossComments\nAlignEscapedNewlines: Left # LeftWithLastLine\nAlignOperands:   Align\nAlignTrailingComments:\n  Kind: Always\n  OverEmptyLines: 1\nAllowAllArgumentsOnNextLine: true\nAllowAllParametersOfDeclarationOnNextLine: false\n# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen\nAllowShortBlocksOnASingleLine: Never\nAllowShortCaseLabelsOnASingleLine: false\nAllowShortFunctionsOnASingleLine: Inline\nAllowShortIfStatementsOnASingleLine: Never\nAllowShortLambdasOnASingleLine: Inline\nAllowShortLoopsOnASingleLine: false\nAlwaysBreakBeforeMultilineStrings: true\n# Treat CUDA keywords/attributes as \"attribute macros\" and avoid breaking lines inside them\nAttributeMacros:\n  - __host__\n  - __device__\n  - __global__\n  - __forceinline__\n  - __launch_bounds__\nBinPackArguments: true\nBinPackParameters: false # OnePerLine\nBitFieldColonSpacing: Both\n# BreakAdjacentStringLiterals: true\nBreakAfterAttributes: Never\nBreakBeforeBinaryOperators: None\nBreakBeforeInlineASMColon: OnlyMultiline\nBreakBeforeTernaryOperators: false\n# BreakBinaryOperations: Never\nBreakConstructorInitializers: AfterColon\n# BreakFunctionDefinitionParameters: false\nBreakInheritanceList: AfterComma\nBreakStringLiterals: true\n# BreakTemplateDeclarations: Yes\nColumnLimit:     120\nCommentPragmas:  '^ IWYU pragma:'\nCompactNamespaces: false\nConstructorInitializerIndentWidth: 4\nContinuationIndentWidth: 4\nDerivePointerAlignment: false\nDisableFormat:   false\nEmptyLineBeforeAccessModifier: 
Leave\nEmptyLineAfterAccessModifier: Never\nExperimentalAutoDetectBinPacking: false\nFixNamespaceComments: true\nIncludeBlocks:   Regroup\nIncludeCategories:\n  - Regex:           '\".*\"'\n    Priority:        1\n    SortPriority:    0\n  - Regex:           '^<.*\\.h>'\n    Priority:        2\n    SortPriority:    0\n  - Regex:           '^<.*'\n    Priority:        3\n    SortPriority:    0\n  - Regex:           '.*'\n    Priority:        4\n    SortPriority:    0\nIncludeIsMainRegex: '([-_](test|unittest))?$'\nIncludeIsMainSourceRegex: ''\nIndentAccessModifiers: false\nIndentExternBlock: NoIndent\nIndentGotoLabels: false\nIndentPPDirectives: AfterHash\nIndentWidth:     4\nIndentWrappedFunctionNames: false\nInsertBraces:    true # NOTE: may lead to incorrect formatting\nInsertNewlineAtEOF: true\nJavaScriptQuotes: Leave\nJavaScriptWrapImports: true\nKeepEmptyLinesAtTheStartOfBlocks: false\nLambdaBodyIndentation: Signature\nLineEnding: LF\nMacroBlockBegin: ''\nMacroBlockEnd:   ''\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\nObjCBinPackProtocolList: Auto\nObjCBlockIndentWidth: 4\nObjCSpaceAfterProperty: true\nObjCSpaceBeforeProtocolList: true\nPPIndentWidth: -1\nPackConstructorInitializers: CurrentLine\nPenaltyBreakAssignment: 2\nPenaltyBreakBeforeFirstCallParameter: 1\nPenaltyBreakComment: 300\nPenaltyBreakFirstLessLess: 120\nPenaltyBreakString: 1000\nPenaltyBreakTemplateDeclaration: 10\nPenaltyExcessCharacter: 1000000\nPenaltyReturnTypeOnItsOwnLine: 200\nPointerAlignment: Middle\nQualifierAlignment: Left\n#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict']\nRawStringFormats:\n  - Language:        Cpp\n    Delimiters:\n      - cc\n      - CC\n      - cpp\n      - Cpp\n      - CPP\n      - 'c++'\n      - 'C++'\n    CanonicalDelimiter: ''\nReferenceAlignment: Middle\nReflowComments:  false # IndentOnly\nSeparateDefinitionBlocks: Always\nSortIncludes:    CaseInsensitive\nSortUsingDeclarations: 
LexicographicNumeric\nSpaceAfterCStyleCast: true\nSpaceAfterLogicalNot: false\nSpaceAfterTemplateKeyword: true\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeCpp11BracedList: false\nSpaceBeforeCtorInitializerColon: true\nSpaceBeforeInheritanceColon: true\nSpaceBeforeParens: ControlStatements\nSpaceBeforeRangeBasedForLoopColon: true\nSpaceInEmptyBlock: false\nSpaceInEmptyParentheses: false\nSpacesBeforeTrailingComments: 2\nSpacesInAngles:  Never\nSpacesInLineCommentPrefix:\n  Minimum: 1\n  Maximum: -1\nSpacesInParentheses: false\nSpacesInSquareBrackets: false\nSpaceBeforeSquareBrackets: false\nStandard:        c++17\nTabWidth:        4\nUseTab:          Never\nWhitespaceSensitiveMacros: ['STRINGIZE']\n...\n"
  },
  {
    "path": "src/ggml-openvino/CMakeLists.txt",
    "content": "find_package(OpenVINO REQUIRED)\nfind_package(OpenCL REQUIRED)\n\ninclude(\"${OpenVINO_DIR}/../3rdparty/tbb/lib/cmake/TBB/TBBConfig.cmake\")\n\nfile(GLOB_RECURSE GGML_HEADERS_OPENVINO \"*.h\" \"*.hpp\")\nfile(GLOB_RECURSE GGML_SOURCES_OPENVINO \"*.cpp\")\n\nggml_add_backend_library(ggml-openvino\n    ${GGML_SOURCES_OPENVINO}\n    ${GGML_HEADERS_OPENVINO}\n)\n\ntarget_link_libraries(ggml-openvino PRIVATE openvino::runtime TBB::tbb OpenCL::OpenCL)\n\nif (GGML_OPENVINO)\n    if (CMAKE_SYSTEM_PROCESSOR STREQUAL \"aarch64\")\n    elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL \"x86_64\" OR CMAKE_SYSTEM_PROCESSOR STREQUAL \"amd64\" OR CMAKE_SYSTEM_PROCESSOR STREQUAL \"AMD64\")\n    else()\n        message(FATAL_ERROR \"OpenVINO: OpenVINO toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}\")\n    endif()\nendif()\n"
  },
  {
    "path": "src/ggml-openvino/ggml-decoder.cpp",
    "content": "#include \"ggml-decoder.h\"\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-openvino-extra.h\"\n#include \"ggml-openvino.h\"\n#include \"ggml-quants.h\"\n\n#include <ggml-impl.h>\n#include <ggml.h>\n\n#include <algorithm>\n#include <cassert>\n#include <cstddef>\n#include <cstdint>\n#include <cstdlib>\n#include <execution>\n#include <fstream>\n#include <iomanip>\n#include <map>\n#include <memory>\n#include <mutex>\n#include <openvino/core/dimension.hpp>\n#include <openvino/core/except.hpp>\n#include <openvino/core/node.hpp>\n#include <openvino/core/partial_shape.hpp>\n#include <openvino/core/type/bfloat16.hpp>\n#include <openvino/core/type/element_type.hpp>\n#include <openvino/core/type/float16.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/parameter.hpp>\n#include <openvino/runtime/tensor.hpp>\n#include <optional>\n#include <ostream>\n#include <set>\n#include <stdexcept>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\nGgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,\n                             ModelParams & model_params,\n                             ComputeParams & compute_params,\n                             std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,\n                             bool is_static,\n                             bool is_stateful,\n                             bool is_prefill,\n                             int prefill_chunk_size) :\n    m_is_static(is_static),\n    m_is_stateful(is_stateful),\n    m_is_prefill(is_prefill),\n    m_naive(false),\n    m_prefill_chunk_size(prefill_chunk_size),\n    m_cgraph(cgraph),\n    m_model_weights(model_weights),\n    m_model_params(model_params),\n    m_compute_params(compute_params) {\n    if (auto * env = getenv(\"GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS\"); env && std::string(env) != \"0\") {\n#ifdef _WIN32\n        
_putenv_s(\"GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS\", \"\");\n#else\n        unsetenv(\"GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS\");\n#endif\n        print_tensor_address_map(cgraph);\n    }\n\n    validate_cgraph();\n\n    set_input_output();\n    compute_model_inputs();\n    compute_model_outputs();\n\n    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {\n        m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node);\n        m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node);\n    }\n\n    add_extra_inputs();\n}\n\nvoid GgmlOvDecoder::update_io(ggml_cgraph * cgraph) {\n    m_cgraph = cgraph;\n    m_model_inputs.clear();\n    m_model_outputs.clear();\n    m_node_info_list.clear();\n    set_input_output();\n    compute_model_inputs();\n    compute_model_outputs();\n}\n\nGgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights) {\n    m_cgraph = cgraph;\n    m_model_weights = model_weights;\n    m_naive = true;\n    set_input_output();\n    compute_model_inputs();\n    compute_model_outputs();\n    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {\n        m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node);\n        m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node);\n    }\n}\n\nvoid GgmlOvDecoder::set_input_output() {\n    for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) {\n        auto node = m_cgraph->nodes[node_n];\n\n        NodeInfo current_node_info;\n        auto node_name = std::string(node->name);\n        auto node_output_name = node_name;\n        auto * node_output = node;\n        if (node->op == GGML_OP_SET_ROWS) {\n            // SET_ROWS updates the tensor in place. 
For later ov op that uses the\n            // the view_src of SET_ROWS, we need to make sure they get the updated tensor\n            // by putting the view_src name in the tensor_map in\n            // <openvino>/src/frontends/ggml/src/translate_session.cpp\n            node_output_name = std::string(node->view_src->name);\n            node_output = node->view_src;\n        }\n\n        current_node_info.node = node;\n        current_node_info.node_name = node_name;\n        current_node_info.node_output = node_output;\n        current_node_info.node_output_name = node_output_name;\n        current_node_info.node_op_case = 0;\n        current_node_info.data_addr = node->data;\n\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            auto * src = node->src[i];\n            if (src == nullptr) {\n                continue;\n            }\n            auto src_name = std::string(src->name);\n            if (src->flags & GGML_TENSOR_FLAG_INPUT) {\n                src_name = get_graph_input_ov_name(src, node);\n            }\n            current_node_info.node_inputs[src_name] = src;\n            current_node_info.node_inputs_names.push_back(src_name);\n        }\n\n        m_node_info_list.push_back(current_node_info);\n    }\n}\n\nint GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const {\n    int op_case = 0;\n    switch (node->op) {\n    case GGML_OP_RESHAPE: {\n        auto * src = node->src[0];\n        if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) {\n            op_case = 4;\n        } else if (node->ne[0] * node->ne[1] == src->ne[0]) {\n            op_case = 1;\n        } else if (src->ne[0] * src->ne[1] == node->ne[0]) {\n            op_case = 2;\n            if (src->ne[2] * src->ne[3] == node->ne[1]) {\n                op_case = 5;\n            }\n        } else if (src->ne[0] * src->ne[1] == node->ne[1]) {\n            op_case = 3;\n        } else if (src->ne[1] * src->ne[2] == 
node->ne[1]) {\n            op_case = 6;\n        }\n        break;\n    }\n    case GGML_OP_CONT: {\n        if (node->src[0]->op == GGML_OP_PERMUTE) {\n            op_case = 1;\n        } else if (node->src[0]->op == GGML_OP_TRANSPOSE) {\n            op_case = 2;\n        } else if (node->src[0]->op == GGML_OP_VIEW) {\n            op_case = 3;\n        }\n        break;\n    }\n    case GGML_OP_PERMUTE: {\n        if (node->src[0]->op != GGML_OP_VIEW) {\n            op_case = 1;\n        } else if (node->src[0]->src[0]->op == GGML_OP_NONE) {\n            // kv cache tensor\n            std::string src_name(node->view_src->name);\n            int layer = extract_layer_from_name(src_name);\n            if (!is_swa_layer(layer)) {\n                op_case = 2;\n            } else {\n                op_case = 3;\n            }\n        } else {\n            // rope'ed query tensor\n            op_case = 4;\n        }\n        break;\n    }\n    case GGML_OP_MUL_MAT: {\n        if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) {\n            op_case = 2;\n        } else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) {\n            op_case = 3;\n        }\n        break;\n    }\n    case GGML_OP_GET_ROWS: {\n        if (node->src[1]->op == GGML_OP_VIEW) {\n            op_case = 2;\n        }\n        break;\n    }\n    case GGML_OP_ROPE: {\n        if (node->src[0]->op == GGML_OP_VIEW) {\n            op_case = 2;\n        }\n        break;\n    }\n    case GGML_OP_VIEW: {\n        if (node->src[0]->op == GGML_OP_VIEW) {\n            auto * src = node->src[0];\n            if (ggml_nelements(node) != ggml_nelements(src)) {\n                throw std::runtime_error(\"Unsupported VIEW case\");\n            }\n            op_case = 2;\n        }\n        {\n            auto * src = node->src[0];\n            if ((ggml_nelements(node) != ggml_nelements(src)) && m_naive) {\n                // Compare each 
dimension of node and src, if only one dimension differs then op_case=3\n                int diff_count = 0;\n                for (int i = 0; i < GGML_MAX_DIMS; i++) {\n                    if (node->ne[i] != src->ne[i]) {\n                        diff_count++;\n                    }\n                }\n                if (diff_count == 1) {\n                    op_case = 3;\n                }\n            }\n        }\n        break;\n    }\n    default:\n        break;\n    }\n    return op_case;\n}\n\nint extract_layer_from_name(const std::string & name) {\n    size_t pos1 = name.find(\"_l\");\n    assert(pos1 != std::string::npos);\n    pos1 += 2;\n    size_t pos2 = name.find(' ', pos1);\n    if (pos2 == std::string::npos) {\n        pos2 = name.length();\n    }\n    std::string layer_str = name.substr(pos1, pos2 - pos1);\n    int layer = std::stoi(layer_str);\n    return layer;\n}\n\nstd::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgraph * cgraph, bool is_static) {\n    ModelParams model_params;\n    ComputeParams compute_params;\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        auto * node = cgraph->nodes[i];\n        std::string name = std::string(node->name);\n        if (node->op == GGML_OP_FLASH_ATTN_EXT) {\n            model_params.n_heads = node->src[0]->ne[2];\n            model_params.n_heads_kv = node->src[1]->ne[2];\n            model_params.head_size = node->src[0]->ne[0];\n            compute_params.input_len = node->src[0]->ne[1];\n\n            auto * cache_k_perm = node->src[1];\n            if (cache_k_perm->op == GGML_OP_CPY) {\n                cache_k_perm = cache_k_perm->src[0];\n            }\n            assert(cache_k_perm->op == GGML_OP_PERMUTE);\n            auto * cache_k_view = cache_k_perm->src[0];\n            assert(cache_k_view->op == GGML_OP_VIEW);\n\n            auto * cache_k = cache_k_view->src[0];\n            int layer = extract_layer_from_name(cache_k->name);\n            auto * mask = 
node->src[3];\n            std::string mask_name(mask->name);\n\n            model_params.kv_buffer_ctx_id = ggml_backend_openvino_buffer_get_ctx_id(cache_k->buffer);\n            if (mask_name.find(\"swa\") != std::string::npos) {\n                model_params.swa_layers.push_back(layer);\n                model_params.ctx_per_seq_swa = cache_k->ne[1];\n            } else {\n                model_params.ctx_per_seq = cache_k->ne[1];\n                model_params.n_seq = cache_k->ne[2];\n            }\n\n            compute_params.n_seq_active = mask->ne[3];\n            auto seq_size = cache_k->ne[0] * cache_k->ne[1] * ggml_type_size(cache_k->type);\n            size_t offset;\n            memcpy(&offset, cache_k_view->op_params, sizeof(size_t));\n            compute_params.seq_active_start = offset / seq_size;\n            compute_params.token_len_per_seq = node->ne[2];\n\n            if (mask_name.find(\"swa\") != std::string::npos) {\n                compute_params.attention_size_swa = mask->ne[0];\n            } else {\n                compute_params.attention_size = mask->ne[0];\n            }\n            if (is_static) {\n                compute_params.attention_size = model_params.ctx_per_seq;\n                compute_params.attention_size_swa = model_params.ctx_per_seq_swa;\n                compute_params.token_len_per_seq = 1;\n            }\n            break;\n        }\n        if (node->op == GGML_OP_ROPE) {\n            memcpy(model_params.rope_params, node->op_params, sizeof(int32_t) * 15);\n        }\n    }\n    auto * output_tensor = cgraph->nodes[cgraph->n_nodes - 1];\n    compute_params.output_len = output_tensor->ne[1];\n    // for NPU, output_len is always 1 except for llama-perplexity\n    if (is_static && compute_params.output_len == 0) {\n        compute_params.output_len = 1;\n    }\n    model_params.ctx = model_params.ctx_per_seq * model_params.n_seq;\n    model_params.ctx_swa = model_params.ctx_per_seq_swa * model_params.n_seq;\n    
return {model_params, compute_params};\n}\n\nvoid GgmlOvDecoder::validate_cgraph() const {\n    if (m_model_params.n_seq > 1 && m_is_static == true) {\n        throw std::runtime_error(\"n_seq > 1 is not supported on NPU. Try setting -np 1.\");\n    }\n}\n\nov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const {\n    if (m_naive) {\n        return input!= nullptr ? ov::PartialShape{get_shape(input)} : ov::PartialShape{get_shape(op)};\n    }\n    auto name = std::string(input->name);\n    ov::PartialShape input_shape;\n\n    if (is_inp_tok(input, op) || is_inp_pos(input, op)) {\n        // tokens or positions\n        int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;\n        input_shape = ov::PartialShape{1, 1, 1, len};\n\n    } else if (is_output_idx(input, op)) {\n        // output index\n        input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1};\n\n    } else if (is_inp_mask(input, op)) {\n        // mask\n        if (m_is_static) {\n            input_shape = ov::PartialShape{1, 1, m_is_prefill ? 
m_prefill_chunk_size : 1, m_model_params.ctx};\n        } else if (m_is_stateful) {\n            input_shape = ov::PartialShape{1, 1, -1, -1};\n        } else {\n            input_shape = ov::PartialShape{-1, 1, -1, -1};\n        }\n\n    } else if (is_kvcache(input, op)) {\n        // kvcache\n        input_shape = ov::PartialShape{get_shape(input)};\n        if (!m_is_static) {\n            // do not fix ctx size to make llama-bench work across test params\n            input_shape[2] = -1;\n        }\n        if (is_stateful()) {\n            // Convert stateless KV cache layout [1, 1, seq, n_heads_kv * head_size]\n            // to stateful layout [1, seq, n_heads_kv, head_size].\n            assert(input_shape.size() == 4 && input_shape[0] == 1 && input_shape[1] == 1 &&\n                   input_shape[2].is_dynamic() &&\n                   input_shape[3] == (m_model_params.n_heads_kv * m_model_params.head_size));\n            input_shape = {input_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv,\n                           m_model_params.head_size};\n        }\n\n    } else if (is_kv_idx(input, op)) {\n        // kv update index\n        int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;\n        input_shape = ov::PartialShape{1, 1, 1, len};\n\n    } else {\n        input_shape = ov::PartialShape{get_shape(input)};\n    }\n    return input_shape;\n}\n\nvoid GgmlOvDecoder::add_extra_inputs() {\n    // Extra inputs:\n    // 1. `attention_size`, used in FLASH_ATTN where the shape of the matmul's are 256 aligned,\n    //     see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.\n    // 2. 
`n_seq_active` and `seq_active_start`, used in FLASH_ATTN_EXT to indicate the active sequences in the batch\n\n    auto create_1d_input = [this](const std::string & name, int64_t value) {\n        if (m_is_static) {\n            auto constant =\n                std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{value});\n            constant->set_friendly_name(name);\n            m_model_extra_inputs[name] = constant;\n        } else {\n            auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});\n            param_node->set_friendly_name(name);\n            param_node->output(0).get_tensor().set_names({name});\n            m_model_extra_inputs[name] = param_node;\n\n            auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});\n            *tensor->data<int64_t>() = value;\n            m_model_extra_input_values[name] = tensor;\n        }\n    };\n\n    create_1d_input(\"attention_size\", m_compute_params.attention_size);\n    if (m_compute_params.attention_size_swa != -1) {\n        create_1d_input(\"attention_size_swa\", m_compute_params.attention_size_swa);\n    }\n    create_1d_input(\"n_seq_active\", m_compute_params.n_seq_active);\n    create_1d_input(\"seq_active_start\", m_compute_params.seq_active_start);\n    create_1d_input(\"seq_active_end\", m_compute_params.seq_active_start + m_compute_params.n_seq_active);\n    create_1d_input(\"token_len_per_seq\", m_compute_params.token_len_per_seq);\n    // create_1d_input(\"token_len\", m_token_len_per_seq * m_n_seq_active);\n}\n\nbool GgmlOvDecoder::node_is_used_as_src(const int node_idx) {\n    ggml_tensor * node = m_cgraph->nodes[node_idx];\n    for (int i = node_idx; i < m_cgraph->n_nodes; i++) {\n        ggml_tensor * other_node = m_cgraph->nodes[i];\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            if (other_node->src[j] == node) {\n                return true;\n            }\n        
}\n    }\n    return false;\n}\n\nvoid GgmlOvDecoder::compute_model_inputs() {\n    m_model_inputs.clear();\n    m_inputs.clear();\n    for (int i = 0; i < m_cgraph->n_nodes; i++) {\n        ggml_tensor * node = m_cgraph->nodes[i];\n        // the node op is NONE means this node maybe as input of later nodes, we should add it to model inputs for this node.\n        if (node->op == GGML_OP_NONE && node_is_used_as_src(i)) {\n            std::string node_name(node->name);\n            if (m_model_weights.find(node_name) == m_model_weights.end()) {\n                m_inputs[node_name] = node;\n                auto param_node =\n                    std::make_shared<ov::op::v0::Parameter>(get_ov_type(node), get_graph_input_shape(node, nullptr));\n                param_node->set_friendly_name(node_name);\n                param_node->output(0).get_tensor().set_names({node_name});\n                m_model_inputs[node_name] = param_node;\n            }\n            continue;\n        }\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            auto * src = node->src[i];\n            if (src == nullptr) {\n                continue;\n            }\n            std::string src_name = std::string(src->name);\n            if (src->flags & GGML_TENSOR_FLAG_INPUT) {\n                src_name = get_graph_input_ov_name(src, node);\n            }\n            if (m_model_weights.find(src_name) != m_model_weights.end()) {\n                continue;\n            }\n\n            bool is_intermediate_node = false;\n            for (const auto & node_info : m_node_info_list) {\n                if (node_info.node == src) {\n                    is_intermediate_node = true;\n                    break;\n                }\n            }\n            if (is_intermediate_node) {\n                continue;\n            }\n            if (m_model_inputs.find(src_name) != m_model_inputs.end()) {\n                continue;\n            }\n\n            m_inputs[src_name] = src;\n\n            
ggml_backend_buffer * buffer = src->buffer;\n            // GGML_BACKEND_BUFFER_USAGE_ANY are kv caches\n            if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) {\n                if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name);\n                    it == m_model_params.kv_names.end()) {\n                    m_model_params.kv_names.push_back(src_name);\n                }\n            }\n            ov::PartialShape param_shape = get_graph_input_shape(node, src);\n            auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), param_shape);\n            param_node->set_friendly_name(src_name);\n            param_node->output(0).get_tensor().set_names({src_name});\n            m_model_inputs[src_name] = param_node;\n        }\n    }\n}\n\nvoid GgmlOvDecoder::compute_model_outputs() {\n    m_model_outputs.clear();\n    m_model_output_names.clear();\n    for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) {\n        auto * cur_node = m_cgraph->nodes[node_n];\n        // if the node op is NONE means this node is not used at all, we can skip it directly without adding to model outputs.\n        if (cur_node->op == GGML_OP_NONE) {\n            continue;\n        }\n        auto cur_node_use_count = m_cgraph->use_counts[ggml_hash_find(&m_cgraph->visited_hash_set, cur_node)];\n        if (cur_node_use_count == 0) {\n            // The output of SET_ROWS is the view_src tensor, which is updated in place. 
We should use the view_src name as the output name to make sure it can be correctly matched with the later ops that use the view_src.\n            if (cur_node != nullptr && cur_node->op == GGML_OP_SET_ROWS) {\n                cur_node = cur_node->view_src;\n            }\n        } else {\n            int input_use_count = 0;\n            for (int i = 0; i < m_cgraph->n_nodes; i++) {\n                ggml_tensor * node = m_cgraph->nodes[i];\n                for (int j = 0; j < GGML_MAX_SRC; j++) {\n                    if (node->src[j] != NULL && node->src[j] == cur_node) {\n                        input_use_count++;\n                    }\n                }\n            }\n            if (input_use_count == cur_node_use_count) {\n                cur_node = nullptr;\n            }\n        }\n        if (cur_node != nullptr) {\n            std::string node_output_name(cur_node->name);\n            m_model_outputs[node_output_name] = cur_node;\n            m_model_output_names.push_back(node_output_name);\n        }\n    }\n}\n\nconst ggml_tensor * GgmlOvDecoder::get_tensor_used_op(const ggml_tensor * tensor) const {\n    if (tensor == nullptr) {\n        return nullptr;\n    }\n    for (int i = 0; i < m_cgraph->n_nodes; i++) {\n        const auto * node = m_cgraph->nodes[i];\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            if (node->src[j] == tensor) {\n                return node;\n            }\n        }\n    }\n    return nullptr;\n}\n\nconst ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name) const {\n    for (int i = 0; i < m_cgraph->n_nodes; i++) {\n        const auto * node = m_cgraph->nodes[i];\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            const auto * src = node->src[j];\n            if (src == nullptr) {\n                break;\n            }\n            if (std::string(src->name) == name) {\n                return src;\n            }\n        }\n    }\n    return 
nullptr;\n}\n\nstd::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {\n    std::map<std::string, std::string> kv_param_res_names;\n    for (const auto & name : m_model_params.kv_names) {\n        kv_param_res_names[name] = name;\n    }\n    return kv_param_res_names;\n}\n\nstd::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) {\n    static std::mutex weights_mutex;\n    std::lock_guard<std::mutex> lock(weights_mutex);\n\n    std::map<std::string, std::shared_ptr<ov::Node>> model_weights;\n    auto * nodes = cgraph->nodes;\n    auto n_nodes = cgraph->n_nodes;\n    for (int node_i = 0; node_i < n_nodes; node_i++) {\n        auto * node = nodes[node_i];\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            auto * src = node->src[i];\n            if (src == nullptr) {\n                continue;\n            }\n\n            std::string src_name(src->name);\n            if (is_rope_freqs_weight(src, node)) {\n                src_name = \"rope_freqs.weight\";\n            }\n            if (!src->view_src) {\n                ggml_backend_buffer * buffer = src->buffer;\n                if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS || ggml_is_quantized(src->type)) {\n                    if (model_weights.find(src_name) == model_weights.end()) {\n                        auto weight_node = create_weight_node(src, naive);\n                        weight_node->set_friendly_name(src_name);\n                        model_weights[src_name] = weight_node;\n                    }\n                }\n            }\n        }\n    }\n    return model_weights;\n}\n\nstd::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor * tensor, bool naive) {\n    const bool is_ov_buffer = ggml_backend_buffer_is_openvino(tensor->buffer);\n\n    // Check if we have a pre-built constant from the OpenVINO backend buffer\n    // This is set during 
ggml_backend_openvino_buffer_set_tensor\n    if (tensor->extra) {\n        OPENVINO_ASSERT(is_ov_buffer, \"Unsupported weight tensor: \" + std::string(tensor->name) +\n                                          \" Possibly this is a cpu backend repacked quantized weights\");\n        // Cast to our extra base type and check the type\n        auto * extra_base = static_cast<ggml_openvino_extra_base *>(tensor->extra);\n\n        if (extra_base->type == ggml_openvino_extra_base::Type::WEIGHT) {\n            // F16/F32/BF16 weight with shared-memory constant\n            auto * weight_extra = static_cast<ggml_openvino_weight_extra *>(tensor->extra);\n            if (weight_extra->weight_node) {\n                // GGML_LOG_DEBUG(\"%s: using pre-built weight node for %s\\n\", __func__, tensor->name);\n                return weight_extra->weight_node;\n            }\n        } else if (extra_base->type == ggml_openvino_extra_base::Type::QUANTIZED_WEIGHT) {\n            // Quantized weight with pre-extracted data\n            auto * quant_extra = static_cast<ggml_openvino_quantized_weight_extra *>(tensor->extra);\n            if (quant_extra->weight_node) {\n                // GGML_LOG_DEBUG(\"%s: using pre-extracted quantized weight node for %s\\n\", __func__, tensor->name);\n                return quant_extra->weight_node;\n            }\n        }\n    }\n\n    // There are three cases where we need to create a new weight node:\n    // 1. weights are in openvino_host_buffer. Weight loading to host buffer will not trigger backend_buffer_set_tensor\n    // 2. weights are in cpu/cpu_mapped buffer. On token_embd.weight goes to case 1 or 2, depending on whether mmap or direct_io is used\n    // 3. test-backend-ops. 
buffers in test-backend-ops does not set USAGE_WEIGHT so backend_buffer_set_tensor will not create weight node\n\n    // GGML_LOG_DEBUG(\"%s: creating new weight node for %s\\n\", __func__, tensor->name);\n    static const std::set<ggml_type> weight_types = {GGML_TYPE_F32,  GGML_TYPE_F16,  GGML_TYPE_BF16,\n                                                     GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,\n                                                     GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K};\n    if (weight_types.find(tensor->type) == weight_types.end()) {\n        throw std::runtime_error(\"Unexpected weight tensor type: \" + std::string(tensor->name) + \" with type \" +\n                                 ggml_type_name(tensor->type));\n    }\n\n    OvWeight ov_weight;\n    if (ggml_is_quantized(tensor->type)) {\n        auto use_bias = naive;\n        if (is_ov_buffer) {\n            // For quantized weights, copy raw data to a temp buffer first because\n            // process_weight_tensor reads from data and writes extracted results\n            // (weights/scales/zp) to output_base_ptr — they would overlap if both\n            // point to tensor->data.\n            size_t raw_size = ggml_nbytes(tensor);\n            std::vector<uint8_t> tmp(raw_size);\n            memcpy(tmp.data(), tensor->data, raw_size);\n            ov_weight = process_weight_tensor(tensor, tmp.data(), tensor->data, use_bias);\n        } else {\n            ov_weight = process_weight_tensor(tensor, tensor->data, nullptr, use_bias);\n        }\n    } else {\n        // For non-quantized weights (F16/F32/BF16), data is already in tensor->data.\n        // process_weight_tensor will create an ov::Tensor wrapping tensor->data directly.\n        ov_weight = process_weight_tensor(tensor, tensor->data, tensor->data);\n    }\n\n    ov_weight.weight_node->set_friendly_name(tensor->name);\n    if (!is_ov_buffer) {\n        return ov_weight.weight_node;\n    }\n\n    
ggml_openvino_extra_base * extra;\n    if (ov_weight.is_quantized()) {\n        extra = new ggml_openvino_quantized_weight_extra(std::move(ov_weight.weights), std::move(ov_weight.scales),\n                                                         std::move(ov_weight.zp), ov_weight.weight_node);\n    } else {\n        extra = new ggml_openvino_weight_extra(std::move(ov_weight.weights), ov_weight.weight_node);\n    }\n    ggml_openvino_buffer_register_extra(tensor, extra);\n\n    return ov_weight.weight_node;\n}\n\nvoid GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filename) {\n    std::ofstream file(filename);\n    if (!file.is_open()) {\n        std::cerr << \"Failed to open file\" << std::endl;\n        return;\n    }\n\n    file << \"=== GRAPH ===\\n\";\n\n    // clang-format off\n    file << \"n_nodes = \" << cgraph->n_nodes << \"\\n\";\n    file << \" \" << std::setw(3) << \"nodes\"\n                <<  std::setw(15) << \"shape\"\n                << std::setw(20) << \"op\"\n                << std::setw(20) << \"name\"\n                << std::setw(3) << \"    \"\n                << std::setw(62) << \"stride\"\n                << std::setw(20) << \"buffer_type\"\n                << \"\\n\";\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        ggml_tensor * node = cgraph->nodes[i];\n\n        // Get buffer type name\n        const char * buf_name = \"none\";\n        ggml_backend_buffer_t buf = node->view_src ? 
node->view_src->buffer : node->buffer;\n        if (buf) {\n            buf_name = ggml_backend_buffer_name(buf);\n        }\n\n        file << \" - \" << std::setw(3) << i << \": [ \"\n             << std::setw(5) << node->ne[0] << \", \"\n             << std::setw(5) << node->ne[1] << \", \"\n             << std::setw(5) << node->ne[2] << \", \"\n             << std::setw(5) << node->ne[3] << \"] \"\n             << std::left << std::setw(20) << ggml_op_name(node->op) << std::right << \" \"\n             << std::left << std::setw(45) << node->name << std::right\n             << std::setw(2) << \"[ \"\n             << std::setw(0) << node->nb[0] << \", \"\n             << std::setw(5) << node->nb[1] << \", \"\n             << std::setw(5) << node->nb[2] << \", \"\n             << std::setw(5) << node->nb[3] << \"] \"\n             << std::right << std::setw(15) << buf_name << std::right\n             << \"\\n\";\n\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            if (auto* src = node->src[i]) {\n                // Get buffer type name for source\n                const char * src_buf_name = \"none\";\n                ggml_backend_buffer_t src_buf = src->view_src ? 
src->view_src->buffer : src->buffer;\n                if (src_buf) {\n                    src_buf_name = ggml_backend_buffer_name(src_buf);\n                }\n\n                file << std::setw(10) << \" [ \"\n                << std::setw(5) << src->ne[0] << \", \"\n                << std::setw(5) << src->ne[1] << \", \"\n                << std::setw(5) << src->ne[2] << \", \"\n                << std::setw(5) << src->ne[3] << \"] \"\n                << std::setw(12)\n                << i << \": \" << std::left << std::setw(12) << ggml_op_name(src->op) << std::right;\n                file << std::left << std::setw(30) << src->name << std::right\n                << std::setw(16) << \"[ \"\n                << std::setw(0) << src->nb[0] << \", \"\n                << std::setw(5) << src->nb[1] << \", \"\n                << std::setw(5) << src->nb[2] << \", \"\n                << std::setw(5) << src->nb[3] << \"] \"\n                << std::right << std::setw(15) << src_buf_name << std::right\n                << \"\\n\";\n            }\n        }\n    }\n\n    file << \"n_leafs = \" << cgraph->n_leafs << \"\\n\";\n    for (int i = 0; i < cgraph->n_leafs; i++) {\n        ggml_tensor * node = cgraph->leafs[i];\n\n        // Get buffer type name for leaf\n        const char * leaf_buf_name = \"none\";\n        ggml_backend_buffer_t leaf_buf = node->view_src ? 
node->view_src->buffer : node->buffer;\n        if (leaf_buf) {\n            leaf_buf_name = ggml_backend_buffer_name(leaf_buf);\n        }\n\n        file << \" - \" << std::setw(3) << i << \": [ \"\n             << std::setw(5) << node->ne[0] << \", \"\n             << std::setw(5) << node->ne[1] << \"] \"\n             << std::setw(8) << ggml_op_name(node->op) << \" \"\n             << std::setw(16) << ggml_get_name(node)\n             << std::setw(20) << leaf_buf_name << \"\\n\";\n    }\n    // clang-format on\n    file << \"========================================\\n\";\n\n    file.close();\n}\n\nvoid print_tensor_address_map(const ggml_cgraph * cgraph) {\n    std::map<void *, std::vector<std::string>> address_map;\n    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {\n        auto * node = cgraph->nodes[node_n];\n        if (node->data) {\n            auto it = address_map.find(node->data);\n            if (it == address_map.end()) {\n                address_map[node->data] = std::vector<std::string>();\n            }\n            address_map[node->data].push_back(node->name);\n        }\n    }\n    for (const auto & pair : address_map) {\n        std::cout << \"Address: \" << pair.first << std::endl;\n        for (const auto & name : pair.second) {\n            std::cout << name << \" ; \";\n        }\n        std::cout << std::endl << std::endl;\n    }\n}\n\nov::Shape GgmlOvDecoder::get_shape(const ggml_tensor * tensor) {\n    std::vector<size_t> shape;\n    for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {\n        shape.push_back(static_cast<size_t>(tensor->ne[i]));\n    }\n    return shape;\n}\n\nstd::vector<size_t> GgmlOvDecoder::get_stride(const ggml_tensor * tensor) {\n    std::vector<size_t> stride;\n    for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {\n        stride.push_back(static_cast<size_t>(tensor->nb[i]));\n    }\n    return stride;\n}\n\nov::element::Type GgmlOvDecoder::get_ov_type(const ggml_tensor * tensor) {\n    switch 
(tensor->type) {\n    case GGML_TYPE_F64:\n        return ov::element::f64;\n    case GGML_TYPE_F32:\n        return ov::element::f32;\n    case GGML_TYPE_F16:\n        return ov::element::f16;\n    case GGML_TYPE_BF16:\n        return ov::element::bf16;\n    case GGML_TYPE_I8:\n        return ov::element::i8;\n    case GGML_TYPE_I16:\n        return ov::element::i16;\n    case GGML_TYPE_I32:\n        return ov::element::i32;\n    case GGML_TYPE_I64:\n        return ov::element::i64;\n    default:\n        return ov::element::dynamic;\n    }\n}\n\nov::PartialShape GgmlOvDecoder::get_input_shape(int node_idx, const std::string & name) const {\n    return ov::PartialShape(get_shape(m_node_info_list[node_idx].node_inputs.at(name)));\n}\n\nstd::vector<size_t> GgmlOvDecoder::get_input_stride(int node_idx, const std::string & name) const {\n    return get_stride(m_node_info_list[node_idx].node_inputs.at(name));\n}\n\nov::element::Type GgmlOvDecoder::get_input_type(int node_idx, const std::string & name) const {\n    return get_ov_type(m_node_info_list[node_idx].node_inputs.at(name));\n}\n\nsize_t GgmlOvDecoder::get_input_size() const {\n    return m_model_inputs.size();\n}\n\nsize_t GgmlOvDecoder::get_input_size(int node_idx) const {\n    return m_node_info_list[node_idx].node_inputs_names.size();\n}\n\nstd::vector<std::string> GgmlOvDecoder::get_input_names(int node_idx) const {\n    return m_node_info_list[node_idx].node_inputs_names;\n}\n\nov::PartialShape GgmlOvDecoder::get_output_shape(int node_idx) const {\n    auto * ggml_tensor = m_node_info_list[node_idx].node_output;\n    return ov::PartialShape(get_shape(ggml_tensor));\n}\n\nov::element::Type GgmlOvDecoder::get_output_type(const int node_idx) const {\n    return get_ov_type(m_node_info_list[node_idx].node);\n}\n\nstd::vector<std::string> GgmlOvDecoder::get_output_names(int node_idx) const {\n    return {m_node_info_list[node_idx].node_output_name};\n}\n\nconst std::string & GgmlOvDecoder::get_op_name() const 
{\n    static const std::string unknown_name = \"UNKNOWN_OP_NAME\";\n    return unknown_name;\n}\n\nconst std::string & GgmlOvDecoder::get_op_name(int node_idx) const {\n    return m_node_info_list[node_idx].node_name;\n}\n\nint32_t * GgmlOvDecoder::get_input_op_params(int node_idx, const std::string & name) const {\n    return m_node_info_list[node_idx].node_inputs.at(name)->op_params;\n}\n\nint32_t * GgmlOvDecoder::get_output_op_params(int node_idx) const {\n    return m_node_info_list[node_idx].node->op_params;\n}\n\nvoid GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const {\n    for (int node_idx = 0; node_idx < m_cgraph->n_nodes; node_idx++) {\n        if (m_cgraph->nodes[node_idx]->op == GGML_OP_NONE) {\n            continue;\n        }\n        node_visitor(std::make_shared<GgmlOvDecoder>(*this), node_idx);\n    }\n}\n\nstd::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) {\n    static const std::map<ggml_op, std::string> ops = {\n        {GGML_OP_NONE,           \"GGML_OP_NONE\"          },\n        {GGML_OP_ACC,            \"GGML_OP_ACC\"           },\n        {GGML_OP_ADD,            \"GGML_OP_ADD\"           },\n        {GGML_OP_ADD1,           \"GGML_OP_ADD1\"          },\n        {GGML_OP_CONT,           \"GGML_OP_CONT\"          },\n        {GGML_OP_DIV,            \"GGML_OP_DIV\"           },\n        {GGML_OP_DUP,            \"GGML_OP_DUP\"           },\n        {GGML_OP_GET_ROWS,       \"GGML_OP_GET_ROWS\"      },\n        {GGML_OP_MUL,            \"GGML_OP_MUL\"           },\n        {GGML_OP_MUL_MAT,        \"GGML_OP_MUL_MAT\"       },\n        {GGML_OP_PERMUTE,        \"GGML_OP_PERMUTE\"       },\n        {GGML_OP_RESHAPE,        \"GGML_OP_RESHAPE\"       },\n        {GGML_OP_RMS_NORM,       \"GGML_OP_RMS_NORM\"      },\n        {GGML_OP_ROPE,           \"GGML_OP_ROPE\"          },\n        {GGML_OP_SCALE,          \"GGML_OP_SCALE\"         },\n        
{GGML_OP_SOFT_MAX,       \"GGML_OP_SOFT_MAX\"      },\n        {GGML_OP_SUB,            \"GGML_OP_SUB\"           },\n        {GGML_OP_TRANSPOSE,      \"GGML_OP_TRANSPOSE\"     },\n        {GGML_OP_VIEW,           \"GGML_OP_VIEW\"          },\n        {GGML_OP_SET_ROWS,       \"GGML_OP_SET_ROWS\"      },\n        {GGML_OP_CPY,            \"GGML_OP_CPY\"           },\n        {GGML_OP_FLASH_ATTN_EXT, \"GGML_OP_FLASH_ATTN_EXT\"},\n    };\n    static const std::map<ggml_unary_op, std::string> unary_ops = {\n        {GGML_UNARY_OP_ABS,         \"GGML_UNARY_OP_ABS\"        },\n        {GGML_UNARY_OP_SGN,         \"GGML_UNARY_OP_SGN\"        },\n        {GGML_UNARY_OP_NEG,         \"GGML_UNARY_OP_NEG\"        },\n        {GGML_UNARY_OP_STEP,        \"GGML_UNARY_OP_STEP\"       },\n        {GGML_UNARY_OP_TANH,        \"GGML_UNARY_OP_TANH\"       },\n        {GGML_UNARY_OP_ELU,         \"GGML_UNARY_OP_ELU\"        },\n        {GGML_UNARY_OP_RELU,        \"GGML_UNARY_OP_RELU\"       },\n        {GGML_UNARY_OP_SIGMOID,     \"GGML_UNARY_OP_SIGMOID\"    },\n        {GGML_UNARY_OP_GELU,        \"GGML_UNARY_OP_GELU\"       },\n        {GGML_UNARY_OP_GELU_QUICK,  \"GGML_UNARY_OP_GELU_QUICK\" },\n        {GGML_UNARY_OP_SILU,        \"GGML_UNARY_OP_SILU\"       },\n        {GGML_UNARY_OP_HARDSWISH,   \"GGML_UNARY_OP_HARDSWISH\"  },\n        {GGML_UNARY_OP_HARDSIGMOID, \"GGML_UNARY_OP_HARDSIGMOID\"},\n        {GGML_UNARY_OP_EXP,         \"GGML_UNARY_OP_EXP\"        },\n        {GGML_UNARY_OP_COUNT,       \"GGML_UNARY_OP_COUNT\"      }\n    };\n    static const std::map<ggml_glu_op, std::string> glu_ops = {\n        {GGML_GLU_OP_SWIGLU, \"GGML_GLU_OP_SWIGLU\"},\n        {GGML_GLU_OP_GEGLU,  \"GGML_GLU_OP_GEGLU\" },\n        {GGML_GLU_OP_REGLU,  \"GGML_GLU_OP_REGLU\" }\n    };\n\n    switch (node->op) {\n    case GGML_OP_UNARY:\n        return unary_ops.at(ggml_get_unary_op(node));\n    case GGML_OP_GLU:\n        return glu_ops.at(ggml_get_glu_op(node));\n    default:\n        return 
ops.at(node->op);\n    }\n    static const std::string unknown_op = \"UNKNOWN_GGML_OP\";\n    return unknown_op;\n}\n\nconst std::string & GgmlOvDecoder::get_op_type(int node_idx) const {\n    return m_node_info_list[node_idx].node_op_type;\n}\n\nconst std::string & GgmlOvDecoder::get_op_type() const {\n    static const std::string unknown_op = \"UNKNOWN_GGML_OP\";\n    return unknown_op;\n}\n"
  },
  {
    "path": "src/ggml-openvino/ggml-decoder.h",
    "content": "#pragma once\n\n#include \"ggml-quants.h\"\n#include \"ggml.h\"\n#include \"openvino/decoder.h\"\n\n#include <cstdint>\n#include <cstring>\n#include <map>\n#include <memory>\n#include <openvino/core/partial_shape.hpp>\n#include <optional>\n#include <vector>\n\nstruct ModelParams {\n    int ctx = -1;\n    int ctx_swa = -1;\n    int ctx_per_seq = -1;\n    int ctx_per_seq_swa = -1;\n    int n_seq = 1;\n    int n_heads = -1;\n    int n_heads_kv = -1;\n    int head_size = -1;\n    int32_t rope_params[15];\n    std::vector<int> swa_layers;\n\n    std::vector<std::string> kv_names;\n    size_t kv_buffer_ctx_id = 0;\n\n    bool same_rope_params(const ModelParams & other) const {\n        return memcmp(rope_params, other.rope_params, sizeof(int32_t) * 15) == 0;\n    }\n\n    bool can_reuse_dynamically(const ModelParams & other) const { return same_rope_params(other); }\n\n    bool can_reuse_statically(const ModelParams & other) const { return same_rope_params(other) && ctx == other.ctx; }\n\n    bool kv_buffer_changed(const ModelParams & other) const { return kv_buffer_ctx_id != other.kv_buffer_ctx_id; }\n};\n\nstruct ComputeParams {\n    int n_seq_active = 1;\n    int seq_active_start = 0;\n    int attention_size = -1;\n    int attention_size_swa = -1;\n    int input_len = -1;\n    int token_len_per_seq = -1;\n    int past_kv_len = -1;\n    int output_len = 1;\n};\n\nclass GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {\npublic:\n    struct NodeInfo {\n        ggml_tensor * node;\n        std::string node_name;\n        std::string node_op_type;\n        std::map<std::string, ggml_tensor *> node_inputs;\n        std::vector<std::string> node_inputs_names;\n        ggml_tensor * node_output;\n        std::string node_output_name;\n        int node_op_case = 0;\n        void * data_addr;\n    };\n    // Graph decoder\n    GgmlOvDecoder(ggml_cgraph * cgraph,\n                  ModelParams & model_params,\n                  ComputeParams & 
compute_params,\n                  std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,\n                  bool is_static,\n                  bool is_stateful = false,\n                  bool is_prefill = false,\n                  int prefill_chunk_size = 256);\n\n    // Naive graph decoder\n    GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights);\n\n    virtual ov::Any get_attribute(const std::string & name) const override {\n        return nullptr;\n        GGML_UNUSED(name);\n    }\n\n    virtual ov::PartialShape get_input_shape(int node_idx, const std::string & name) const override;\n\n    virtual std::vector<size_t> get_input_stride(int node_idx, const std::string & name) const override;\n\n    virtual ov::element::Type get_input_type(int node_idx, const std::string & name) const override;\n\n    virtual size_t get_input_size() const override;\n\n    virtual size_t get_input_size(int node_idx) const override;\n\n    virtual void get_input_node(size_t input_port_idx,\n                                std::string & producer_name,\n                                std::string & producer_output_port_name,\n                                size_t & producer_output_port_index) const override {\n        GGML_UNUSED(input_port_idx);\n        GGML_UNUSED(producer_name);\n        GGML_UNUSED(producer_output_port_name);\n        GGML_UNUSED(producer_output_port_index);\n    }\n\n    virtual std::vector<std::string> get_input_names(int node_idx) const override;\n\n    virtual ov::PartialShape get_output_shape(int node_idx) const override;\n\n    virtual ov::element::Type get_output_type(int node_idx) const override;\n\n    virtual int32_t * get_input_op_params(int node_idx, const std::string & name) const override;\n\n    virtual int32_t * get_output_op_params(int node_idx) const override;\n\n    virtual std::vector<std::string> get_output_names(int node_idx) const override;\n\n    virtual const std::string & 
get_op_type() const override;\n\n    virtual const std::string & get_op_type(int node_idx) const override;\n\n    virtual const std::string & get_op_name() const override;\n\n    virtual const std::string & get_op_name(int node_idx) const override;\n\n    virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const override;\n\n    ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); }\n\n    virtual int get_op_case(int node_idx) const override { return m_node_info_list[node_idx].node_op_case; }\n\n    virtual const std::map<std::string, std::shared_ptr<ov::Node>> & get_model_inputs() const override {\n        return m_model_inputs;\n    }\n\n    virtual const std::map<std::string, std::shared_ptr<ov::Node>> & get_model_extra_inputs() const override {\n        return m_model_extra_inputs;\n    }\n\n    virtual const std::map<std::string, std::shared_ptr<ov::Tensor>> & get_model_extra_input_values() const {\n        return m_model_extra_input_values;\n    }\n\n    virtual const std::map<std::string, std::shared_ptr<ov::Node>> & get_model_weights() const override {\n        return m_model_weights;\n    }\n\n    virtual std::vector<std::string> get_model_output_names() const override {\n        return m_model_output_names;\n    }\n\n    const std::map<std::string, ggml_tensor *> & get_model_outputs() const { return m_model_outputs; }\n\n    virtual int get_ctx_size() const { return m_model_params.ctx; }\n\n    virtual int get_ctx_swa_size() const { return m_model_params.ctx_swa; }\n\n    virtual int get_ctx_per_seq() const { return m_model_params.ctx_per_seq; }\n\n    virtual int get_ctx_per_seq_swa() const { return m_model_params.ctx_per_seq_swa; }\n\n    virtual int get_n_seq() const { return m_model_params.n_seq; }\n\n    virtual int is_swa_layer(int layer) const override {\n        return std::find(m_model_params.swa_layers.begin(), m_model_params.swa_layers.end(), layer) 
!=\n               m_model_params.swa_layers.end();\n    }\n\n    int get_past_kv_len() const { return m_compute_params.past_kv_len; }\n\n    int get_input_len() const { return m_compute_params.input_len; }\n\n    virtual int32_t * get_rope_params() const override { return const_cast<int32_t *>(m_model_params.rope_params); }\n\n    virtual std::map<std::string, std::string> get_kv_param_res_names() const override;\n\n    virtual bool is_static() const override { return m_is_static; }\n\n    virtual bool is_stateful() const override { return m_is_stateful; }\n\n    ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const;\n\n    static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);\n\n    static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor * tensor, bool naive = false);\n\n    static std::map<std::string, std::shared_ptr<ov::Node>> create_weight_nodes(ggml_cgraph * cgraph,\n                                                                                bool naive = false);\n\n    const ggml_tensor * get_tensor_used_op(const ggml_tensor * tensor) const;\n\n    const ggml_tensor * get_tensor_from_name(const std::string & name) const;\n\n    void clear_model_weights() { m_model_weights.clear(); }\n\n    static std::pair<ModelParams, ComputeParams> compute_llm_params(ggml_cgraph * cgraph, bool is_static);\n\n    ModelParams get_model_params() const { return m_model_params; }\n\n    ComputeParams get_compute_params() const { return m_compute_params; }\n\n    void set_model_params(const ModelParams & model_params) { m_model_params = model_params; }\n\n    void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; }\n\n    bool m_is_static = false;\n    bool m_is_stateful = false;\n    bool m_is_prefill = false;\n    bool m_naive = false;\n    int m_prefill_chunk_size = 0;\n\n    static ov::Shape get_shape(const ggml_tensor * tensor);\n    static 
std::vector<size_t> get_stride(const ggml_tensor * tensor);\n    static ov::element::Type get_ov_type(const ggml_tensor * tensor);\n    static std::string compute_op_type(const ggml_tensor * node);\n    void add_extra_inputs();\n\n    void update_io(ggml_cgraph * cgraph);\n\n    inline static bool is_inp_tok(const ggml_tensor * tensor, const ggml_tensor * op) {\n        return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op == GGML_OP_NONE;\n    }\n\n    inline static bool is_inp_pos(const ggml_tensor * tensor, const ggml_tensor * op) {\n        return op->op == GGML_OP_ROPE && tensor == op->src[1];\n    }\n\n    inline static bool is_inp_emb(const ggml_tensor * tensor, const ggml_tensor * op) {\n        return tensor->op == GGML_OP_GET_ROWS && op->op == GGML_OP_RMS_NORM;\n    }\n\n    inline static bool is_inp_mask(const ggml_tensor * tensor, const ggml_tensor * op) {\n        return op->op == GGML_OP_CPY || (op->op == GGML_OP_FLASH_ATTN_EXT && tensor == op->src[3]);\n    }\n\n    inline static bool is_rope_freqs_weight(const ggml_tensor * tensor, const ggml_tensor * op) {\n        return op->op == GGML_OP_ROPE && tensor == op->src[2];\n    }\n\n    inline static bool is_kvcache(const ggml_tensor * tensor, const ggml_tensor * op) {\n        return op->op == GGML_OP_SET_ROWS && op->src[2] == tensor;\n    }\n\n    inline static bool is_kv_idx(const ggml_tensor * tensor, const ggml_tensor * op) {\n        return op->op == GGML_OP_SET_ROWS && op->src[1] == tensor;\n    }\n\n    inline static bool is_output_idx(const ggml_tensor * tensor, const ggml_tensor * op) {\n        return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op != GGML_OP_NONE;\n    }\n\n    static std::string get_graph_input_ov_name(const ggml_tensor * tensor, const ggml_tensor * op) {\n        if (is_inp_tok(tensor, op)) {\n            return \"inp_tokens\";\n        }\n        if (is_inp_pos(tensor, op)) {\n            return \"inp_pos\";\n        }\n        
if (is_inp_emb(tensor, op)) {\n            return \"embd\";\n        }\n        if (is_output_idx(tensor, op)) {\n            return \"inp_out_ids\";\n        }\n        if (is_inp_mask(tensor, op)) {\n            return std::string(tensor->name).find(\"swa\") == std::string::npos ? \"self_kq_mask\" : \"self_kq_mask_swa\";\n        }\n        return tensor->name;\n    }\n\nprivate:\n    void set_input_output();\n    int compute_op_case(const ggml_tensor * node) const;\n    bool node_is_used_as_src(const int node_idx);\n    void compute_model_inputs();\n    void compute_model_outputs();\n\n    void validate_cgraph() const;\n\n    ggml_cgraph * m_cgraph = nullptr;\n    std::map<std::string, ggml_tensor *> m_inputs;\n\n    std::map<std::string, std::shared_ptr<ov::Node>> m_model_inputs;\n    std::map<std::string, std::shared_ptr<ov::Node>> m_model_extra_inputs;\n    std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;\n    std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;\n    std::map<std::string, ggml_tensor *> m_model_outputs;\n    std::vector<std::string> m_model_output_names;\n    std::vector<NodeInfo> m_node_info_list;\n\n    ModelParams m_model_params;\n    ComputeParams m_compute_params;\n};\n\nvoid print_tensor_address_map(const ggml_cgraph * cgraph);\n\nint extract_layer_from_name(const std::string & name);\n"
  },
  {
    "path": "src/ggml-openvino/ggml-openvino-extra.cpp",
    "content": "#include \"ggml-openvino-extra.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml.h\"\n\n#include <cstring>\n#include <openvino/runtime/intel_gpu/ocl/ocl.hpp>\n#include <openvino/runtime/intel_npu/level_zero/level_zero.hpp>\n#include <optional>\n\nov::Core & ov_singleton_core() {\n    static ov::Core core;\n    return core;\n}\n\n// =====================================================\n// Device Configuration Implementations\n// =====================================================\n\nvoid ggml_openvino_device_config::init() {\n    if (initialized) {\n        return;\n    }\n    device_name = getenv(\"GGML_OPENVINO_DEVICE\") ? getenv(\"GGML_OPENVINO_DEVICE\") : \"CPU\";\n    auto available_devices = ov_singleton_core().get_available_devices();\n    if (std::find(available_devices.begin(), available_devices.end(), device_name) == available_devices.end()) {\n        GGML_LOG_WARN(\"GGML OpenVINO Backend: device %s is not available, fallback to CPU\\n\", device_name.c_str());\n        device_name = \"CPU\";\n    }\n    is_npu = (device_name == \"NPU\");\n\n    auto * cache_dir = getenv(\"GGML_OPENVINO_CACHE_DIR\");\n    if (device_name == \"NPU\") {\n        compile_config = {\n            {\"NPU_COMPILER_DYNAMIC_QUANTIZATION\", \"YES\"   },\n            {\"NPU_USE_NPUW\",                      \"YES\"   },\n            {\"NPUW_DEVICES\",                      \"NPU\"   },\n            {\"NPUW_FOLD\",                         \"YES\"   },\n            {\"NPUW_WEIGHTS_BANK\",                 \"shared\"},\n            {\"NPUW_FUNCALL_FOR_ALL\",              \"YES\"   },\n            {\"NPUW_FUNCALL_ASYNC\",                \"YES\"   },\n            {\"NPUW_DQ\",                           \"YES\"   },\n            {\"NPUW_DQ_FULL\",                      \"NO\"    },\n        };\n        if (cache_dir) {\n            compile_config[\"NPUW_CACHE_DIR\"] = cache_dir;\n        }\n    } else if (cache_dir) {\n        
ov_singleton_core().set_property(ov::cache_dir(cache_dir));\n    }\n\n    // Initialize remote context with queue sharing for GPU\n    if (device_name == \"GPU\") {\n        // Create OpenCL context and queue\n        cl_int err;\n        cl_platform_id platform;\n        err = clGetPlatformIDs(1, &platform, nullptr);\n        if (err != CL_SUCCESS) {\n            GGML_LOG_ERROR(\"Failed to get OpenCL platform: %d\\n\", err);\n            return;\n        }\n\n        cl_device_id cl_device;\n        err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &cl_device, nullptr);\n        if (err != CL_SUCCESS) {\n            GGML_LOG_ERROR(\"Failed to get OpenCL device: %d\\n\", err);\n            return;\n        }\n\n        cl_context cl_ctx = clCreateContext(nullptr, 1, &cl_device, nullptr, nullptr, &err);\n        if (err != CL_SUCCESS) {\n            GGML_LOG_ERROR(\"Failed to create OpenCL context: %d\\n\", err);\n            return;\n        }\n\n        cl_queue = clCreateCommandQueueWithProperties(cl_ctx, cl_device, nullptr, &err);\n        if (err != CL_SUCCESS) {\n            GGML_LOG_ERROR(\"Failed to create OpenCL command queue: %d\\n\", err);\n            clReleaseContext(cl_ctx);\n            return;\n        }\n\n        // Create OpenVINO remote context with queue sharing\n        remote_context = ov::intel_gpu::ocl::ClContext(ov_singleton_core(), cl_queue);\n\n        // Release the context (queue keeps a reference)\n        clReleaseContext(cl_ctx);\n    } else if (device_name == \"NPU\") {\n        // remote tensor is not used for NPU yet\n        // remote_context = ov_singleton_core().get_default_context(device_name);\n    }\n\n    initialized = true;\n}\n\nggml_openvino_device_config::~ggml_openvino_device_config() {\n    if (cl_queue != nullptr) {\n        clReleaseCommandQueue(cl_queue);\n        cl_queue = nullptr;\n    }\n}\n\n// Get the global device config singleton\nggml_openvino_device_config & ggml_openvino_get_device_config() {\n    
static ggml_openvino_device_config config;\n    return config;\n}\n\n// Initialize device config (call during backend init)\nvoid ggml_openvino_init_device_config() {\n    ggml_openvino_get_device_config().init();\n}\n\n// Get the device name\nconst std::string & ggml_openvino_get_device_name() {\n    return ggml_openvino_get_device_config().device_name;\n}\n\n// Check if running on NPU\nbool ggml_openvino_is_npu() {\n    return ggml_openvino_get_device_config().is_npu;\n}\n\n// Get the remote context for the current device (returns empty optional for CPU)\nstd::optional<ov::RemoteContext> ggml_openvino_get_remote_context() {\n    return ggml_openvino_get_device_config().remote_context;\n}\n\n// Get the compile config for the current device\nconst ov::AnyMap & ggml_openvino_get_compile_config() {\n    return ggml_openvino_get_device_config().compile_config;\n}\n\n// Get the OpenCL command queue for GPU operations\ncl_command_queue ggml_openvino_get_cl_queue() {\n    return ggml_openvino_get_device_config().cl_queue;\n}\n\n// Get the clEnqueueMemFillINTEL function pointer (lazy load)\nclEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL() {\n    static clEnqueueMemFillINTEL_fn fn = nullptr;\n    static bool loaded = false;\n    if (!loaded) {\n        loaded = true;\n        cl_platform_id platform;\n        if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) {\n            fn = (clEnqueueMemFillINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, \"clEnqueueMemFillINTEL\");\n        }\n    }\n    return fn;\n}\n\n// Get the clEnqueueMemcpyINTEL function pointer (lazy load)\nclEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL() {\n    static clEnqueueMemcpyINTEL_fn fn = nullptr;\n    static bool loaded = false;\n    if (!loaded) {\n        loaded = true;\n        cl_platform_id platform;\n        if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) {\n            fn = (clEnqueueMemcpyINTEL_fn) 
clGetExtensionFunctionAddressForPlatform(platform, \"clEnqueueMemcpyINTEL\");\n        }\n    }\n    return fn;\n}\n\n// Get requantization type for a tensor type (returns nullopt if no requant needed)\nstd::optional<ExtraQuantType> ggml_openvino_get_requant_type(const ggml_tensor * tensor, bool no_requant) {\n    if (no_requant) {\n        return std::nullopt;\n    }\n    if (strncmp(tensor->name, \"token_embd.weight\", 17) == 0) {\n        return ((ggml_openvino_is_npu() && tensor->type == GGML_TYPE_Q6_K) ? ExtraQuantType::F16 : ExtraQuantType::Q8_0_C);\n    }\n    if (strncmp(tensor->name, \"output.weight\", 13) == 0) {\n        return ExtraQuantType::Q8_0_C;\n    }\n    if (ggml_openvino_is_npu()) {\n        return ExtraQuantType::Q4_0_128;\n    }\n    switch (tensor->type) {\n    case GGML_TYPE_Q6_K:\n    case GGML_TYPE_Q5_K:\n        return ExtraQuantType::Q8_0_C;\n    default:\n        return std::nullopt;\n    }\n}\n\n// =====================================================\n// Extracted Layout Calculation\n// =====================================================\n\nggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor, bool use_bias) {\n    ggml_openvino_extracted_layout layout = {};\n    layout.is_symmetric = false;\n\n    if (!ggml_is_quantized(tensor->type)) {\n        return layout;\n    }\n\n    // Only handle 2D weight tensors\n    if (tensor->ne[2] != 1 || tensor->ne[3] != 1) {\n        return layout;\n    }\n\n    int64_t n_elements = ggml_nelements(tensor);\n    const size_t alignment = 64;  // Good for SIMD\n\n    // Check if requantization is needed (NPU-specific)\n    auto requant_type = ggml_openvino_get_requant_type(tensor, use_bias);\n    if (requant_type.has_value()) {\n        layout.is_requant = true;\n        layout.requant_type = requant_type;\n\n        // Special case: requant to F16 - just store F16 weights, no scales/zp\n        if (requant_type.value() == ExtraQuantType::F16) {\n            
layout.weights_size = n_elements * sizeof(uint16_t);  // F16 = 2 bytes\n            layout.total_size = layout.weights_size;\n            layout.weights_offset = 0;\n            // No scales/zp for F16\n            return layout;\n        }\n\n        // Requant to different quantized format (e.g., Q4_0_128)\n        switch (requant_type.value()) {\n        case ExtraQuantType::Q4_0_128:\n            layout.is_u4 = true;\n            layout.weights_per_block = 128;\n            layout.is_symmetric = true;\n            break;\n        case ExtraQuantType::Q4_0_C:\n            layout.is_u4 = true;\n            layout.weights_per_block = tensor->ne[0];\n            layout.is_symmetric = true;\n            break;\n        case ExtraQuantType::Q8_0_32:\n            layout.is_u4 = false;\n            layout.weights_per_block = 32;\n            layout.is_symmetric = true;\n            break;\n        case ExtraQuantType::Q8_0_C:\n            layout.is_u4 = false;\n            layout.weights_per_block = tensor->ne[0];\n            layout.is_symmetric = true;\n            break;\n        case ExtraQuantType::Q8_1_C:\n            layout.is_u4 = false;\n            layout.weights_per_block = tensor->ne[0];\n            break;\n        default:\n            layout.weights_per_block = -1;\n            GGML_ABORT(\"Code of re-quantizing to channel-wise is not updated\");\n            break;\n        }\n\n        if (layout.is_requant) {\n            // Calculate sizes for requantized format\n            layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements;\n            int64_t n_blocks = n_elements / layout.weights_per_block;\n            layout.scales_size = n_blocks * sizeof(uint16_t);\n            // For symmetric quantization, we only need one zp value (not one per block)\n            // Zero points are stored in U4 or U8 format matching the weight type\n            size_t n_zp_elements = layout.is_symmetric ? 
1 : n_blocks;\n            layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements;\n\n            layout.weights_offset = 0;\n            layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment;\n            layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment;\n            layout.total_size = layout.zp_offset + layout.zp_size;\n            layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor));\n            return layout;\n        }\n    }\n\n    // Normal extraction (no requant) - determine format based on tensor type\n    layout.is_u4 = false;\n    layout.weights_per_block = 32;\n    layout.is_symmetric = false;\n\n    switch (tensor->type) {\n    case GGML_TYPE_Q4_0:\n        layout.is_u4 = true;\n        layout.is_symmetric = true;\n        break;\n\n    case GGML_TYPE_Q4_1:\n    case GGML_TYPE_Q4_K:\n        layout.is_u4 = true;\n        break;\n\n    case GGML_TYPE_Q8_0:\n        layout.is_symmetric = true;\n        break;\n\n    case GGML_TYPE_Q6_K:\n        layout.weights_per_block = 16;\n        layout.is_symmetric = true;\n        break;\n\n    case GGML_TYPE_Q5_K:\n        break;\n\n    default:\n        // Unsupported quantization type\n        return layout;\n    }\n\n    // Calculate sizes\n    // Weights: U4 = n_elements/2 bytes, U8 = n_elements bytes\n    layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements;\n\n    // Scales: F16 per block\n    int64_t n_blocks = n_elements / layout.weights_per_block;\n    layout.scales_size = n_blocks * sizeof(uint16_t);  // F16 = 2 bytes\n    // Zero points: U4 or U8 matching weight type\n    // For symmetric quantization, we only need one zp value (not one per block)\n    size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks;\n    layout.zp_size = layout.is_u4 ? 
((n_zp_elements + 1) / 2) : n_zp_elements;\n\n    // Layout in buffer: [weights | scales | zp] with alignment\n    layout.weights_offset = 0;\n    layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment;\n    layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment;\n    layout.total_size = layout.zp_offset + layout.zp_size;\n    layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor));\n\n    return layout;\n}\n\nggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote) {\n    ov::Shape shape;\n    for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {\n        shape.push_back(static_cast<size_t>(tensor->ne[i]));\n    }\n\n    ov::element::Type element_type;\n    switch (tensor->type) {\n    case GGML_TYPE_F32:\n        element_type = ov::element::f32;\n        break;\n    case GGML_TYPE_F16:\n        element_type = ov::element::f16;\n        break;\n    case GGML_TYPE_BF16:\n        element_type = ov::element::bf16;\n        break;\n    case GGML_TYPE_I32:\n        element_type = ov::element::i32;\n        break;\n    case GGML_TYPE_I64:\n        element_type = ov::element::i64;\n        break;\n    default:\n        // GGML_LOG_WARN(\"%s: unsupported tensor type for ov::Tensor: %s\\n\", __func__, ggml_type_name(tensor->type));\n        return nullptr;\n    }\n\n    const auto & device_name = ggml_openvino_get_device_name();\n    auto remote_context = ggml_openvino_get_remote_context();\n\n    std::shared_ptr<ov::Tensor> ov_tensor;\n    if (is_remote) {\n        GGML_ASSERT(device_name == \"GPU\");\n        auto gpu_context = remote_context->as<ov::intel_gpu::ocl::ClContext>();\n        auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data);\n        ov_tensor = std::make_shared<ov::intel_gpu::ocl::USMTensor>(std::move(usm_tensor));\n    } else {\n        ov_tensor = std::make_shared<ov::Tensor>(element_type, 
shape, tensor->data);\n    }\n\n    return new ggml_openvino_tensor_extra(ov_tensor);\n}\n"
  },
  {
    "path": "src/ggml-openvino/ggml-openvino-extra.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n#include \"openvino/runtime/core.hpp\"\n\n#define CL_TARGET_OPENCL_VERSION 300\n#include <CL/cl.h>\n\n#include <cstdlib>\n#include <memory>\n#include <openvino/core/node.hpp>\n#include <openvino/runtime/remote_context.hpp>\n#include <openvino/runtime/tensor.hpp>\n#include <optional>\n#include <string>\n\n// ExtraQuantType enum - defines requantization target formats\nenum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 };\n\nov::Core & ov_singleton_core();\n\n// Get the remote context for the current device (returns empty optional for CPU)\nstd::optional<ov::RemoteContext> ggml_openvino_get_remote_context();\n\n// Get the compile config for the current device\nconst ov::AnyMap & ggml_openvino_get_compile_config();\n\n// Get the OpenCL command queue for GPU operations (returns nullptr for CPU/NPU)\ncl_command_queue ggml_openvino_get_cl_queue();\n\n// Intel USM extension function type\ntypedef cl_int(CL_API_CALL * clEnqueueMemFillINTEL_fn)(cl_command_queue queue,\n                                                       void * dst_ptr,\n                                                       const void * pattern,\n                                                       size_t pattern_size,\n                                                       size_t size,\n                                                       cl_uint num_events_in_wait_list,\n                                                       const cl_event * event_wait_list,\n                                                       cl_event * event);\n\ntypedef cl_int(CL_API_CALL * clEnqueueMemcpyINTEL_fn)(cl_command_queue queue,\n                                                      cl_bool blocking,\n                                                      void * dst_ptr,\n                                                      const void * src_ptr,\n                                                      size_t size,\n                            
                          cl_uint num_events_in_wait_list,\n                                                      const cl_event * event_wait_list,\n                                                      cl_event * event);\n\n// Get the clEnqueueMemFillINTEL function pointer (returns nullptr if not available)\nclEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL();\n\n// Get the clEnqueueMemcpyINTEL function pointer (returns nullptr if not available)\nclEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL();\n\n// =====================================================\n// Global Device Configuration (singleton)\n// =====================================================\n// Initialized once during backend init from GGML_OPENVINO_DEVICE env var\n\nstruct ggml_openvino_device_config {\n    std::string device_name = \"CPU\";\n    bool is_npu = false;\n    bool initialized = false;\n    std::optional<ov::RemoteContext> remote_context;\n    ov::AnyMap compile_config;\n    cl_command_queue cl_queue = nullptr;\n\n    void init();\n    ~ggml_openvino_device_config();\n};\n\n// Get the global device config singleton\nggml_openvino_device_config & ggml_openvino_get_device_config();\n\n// Initialize device config (call during backend init)\nvoid ggml_openvino_init_device_config();\n\n// Get the device name\nconst std::string & ggml_openvino_get_device_name();\n\n// Check if running on NPU\nbool ggml_openvino_is_npu();\n\n// Get requantization type for a tensor type (returns nullopt if no requant needed)\nstd::optional<ExtraQuantType> ggml_openvino_get_requant_type(const ggml_tensor * tensor, bool no_requant = false);\n\n// =====================================================\n// OpenVINO Tensor Extra Types\n// =====================================================\n// These types are stored in tensor->extra by the OpenVINO backend buffer.\n// They allow:\n// 1. Pre-built ov::Constant nodes for weights (avoiding memcpy during graph construction)\n// 2. 
ov::Tensor wrappers for KV cache / compute tensors (for direct use with infer_request)\n\n// Base class for OpenVINO tensor extra data\nstruct ggml_openvino_extra_base {\n    enum class Type { WEIGHT, QUANTIZED_WEIGHT, TENSOR };\n    Type type;\n    virtual ~ggml_openvino_extra_base() = default;\nprotected:\n    explicit ggml_openvino_extra_base(Type t) : type(t) {}\n};\n\n// Extra data for F16/F32/BF16 weight tensors - stores the pre-built weight node\nstruct ggml_openvino_weight_extra : public ggml_openvino_extra_base {\n    ov::Tensor weights;                     // The underlying weight data tensor\n    std::shared_ptr<ov::Node> weight_node;  // Pre-built OpenVINO weight node\n\n    ggml_openvino_weight_extra(ov::Tensor w, std::shared_ptr<ov::Node> n) :\n        ggml_openvino_extra_base(Type::WEIGHT),\n        weights(std::move(w)),\n        weight_node(std::move(n)) {}\n};\n\n// Extra data for quantized weight tensors - stores extracted weights/scales/zp and weight node\nstruct ggml_openvino_quantized_weight_extra : public ggml_openvino_extra_base {\n    ov::Tensor weights;   // U4 or U8 extracted weights\n    ov::Tensor scales;    // F16 scales\n    ov::Tensor zp;        // U4 or U8 zero points (same type as weights)\n    std::shared_ptr<ov::Node> weight_node;  // Pre-built OpenVINO weight subgraph\n\n    ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor z, std::shared_ptr<ov::Node> n) :\n        ggml_openvino_extra_base(Type::QUANTIZED_WEIGHT),\n        weights(std::move(w)),\n        scales(std::move(s)),\n        zp(std::move(z)),\n        weight_node(std::move(n)) {}\n};\n\n// Extra data for KV cache / compute tensors - stores ov::Tensor for infer_request\nstruct ggml_openvino_tensor_extra : public ggml_openvino_extra_base {\n    std::shared_ptr<ov::Tensor> tensor;  // For direct use with infer_request\n\n    explicit ggml_openvino_tensor_extra(std::shared_ptr<ov::Tensor> t)\n        : ggml_openvino_extra_base(Type::TENSOR), 
tensor(std::move(t)) {}\n};\n\n// =====================================================\n// Extracted Size Calculation for Quantized Tensors\n// =====================================================\n// For quantized tensors, we need extra space to store extracted weights, scales, and zero points.\n// ggml_openvino_get_extracted_layout() describes where that data is placed in the tensor's buffer;\n// total_size is the total number of bytes needed for the extracted data.\n\nstruct ggml_openvino_extracted_layout {\n    size_t total_size = 0;      // Total bytes needed\n    size_t weights_offset = 0;  // Offset to weights in buffer\n    size_t weights_size = 0;    // Size of weights in bytes\n    size_t scales_offset = 0;   // Offset to scales in buffer\n    size_t scales_size = 0;     // Size of scales in bytes\n    size_t zp_offset = 0;       // Offset to zero points in buffer\n    size_t zp_size = 0;         // Size of zero points in bytes (U4 or U8)\n\n    // Quantization parameters. Default-initialized like the fields above so that a layout whose\n    // quantization fields are not populated never exposes indeterminate values to readers.\n    bool is_u4 = false;             // true for U4 weights, false for U8\n    int64_t weights_per_block = 0;  // weights per scale/zp block\n    bool is_symmetric = false;      // true for symmetric quantization\n\n    // Requantization info\n    bool is_requant = false;                      // true if this tensor needs requantization\n    std::optional<ExtraQuantType> requant_type;   // target requant type if is_requant\n};\n\n// Calculate the buffer layout for extracted quantized data\nggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor, bool use_bias = false);\n\n// Create a ggml_openvino_tensor_extra wrapping the tensor's data. Callers check for a nullptr return.\nggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote);\n\n// Register an extra with the tensor's OpenVINO buffer context for proper lifetime management.\n// This sets tensor->extra and tracks the extra in the buffer context for cleanup.\nvoid ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra);\n\n// =====================================================\n// OpenVINO Backend Context and Interface\n// 
=====================================================\nstruct ggml_backend_openvino_context {\n    int device = 0;\n    std::string name = \"OpenVINO\";\n    std::string description = \"OpenVINO Backend Context\";\n\n    std::shared_ptr<void> runtime_context = nullptr;\n\n    ggml_backend_openvino_context() = default;\n};\n"
  },
  {
    "path": "src/ggml-openvino/ggml-openvino.cpp",
    "content": "#include \"ggml-openvino.h\"\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-openvino-extra.h\"\n#include \"ggml-openvino/utils.h\"\n#include \"ggml-quants.h\"\n#include \"ggml.h\"\n\n#include <atomic>\n#include <cstdlib>\n#include <cstdint>\n#include <cstring>\n#include <memory>\n#include <mutex>\n#include <openvino/core/type/element_type.hpp>\n#include <openvino/openvino.hpp>\n#include <openvino/runtime/allocator.hpp>\n#include <openvino/runtime/intel_gpu/ocl/ocl.hpp>\n#include <openvino/runtime/intel_npu/level_zero/level_zero.hpp>\n#include <openvino/runtime/tensor.hpp>\n#include <set>\n#include <string>\n#include <vector>\n\n#if defined(_WIN32)\n#    define WIN32_LEAN_AND_MEAN\n#    ifndef NOMINMAX\n#        define NOMINMAX\n#    endif\n#    include <windows.h>\n#else\n#    include <unistd.h>\n#endif\n\n// =====================================================\n// OpenVINO Buffer Implementation using ov::Tensor\n// =====================================================\n//\n// Design: This implementation uses a hybrid approach:\n// 1. For weight tensors: Store a pre-built ov::op::v0::Constant in tensor->extra\n//    - This avoids the memcpy during graph construction\n//    - For quantized weights, the constant is already converted to OpenVINO format\n// 2. 
For KV cache / compute tensors: Store an ov::Tensor in tensor->extra\n//    - This can be directly passed to infer_request\n//    - Future: can be changed to ov::RemoteTensor for GPU/NPU\n//\n// This design is similar to:\n// - CUDA split buffer: tensor->extra stores device pointers\n// - CPU repack buffer: tensor->extra stores tensor_traits with repacked data\n// =====================================================\n\n// Buffer context that manages per-tensor allocations (no contiguous buffer for weights)\nstruct ggml_backend_openvino_buffer_context {\n    int device;\n    std::string name;\n    size_t id;\n\n    // For non-weight buffers (KV cache, compute), we still use contiguous allocation\n    void * data;\n    size_t size;\n    bool is_remote;\n\n    // Wrapping of the buffer\n    std::shared_ptr<ov::Tensor> ov_buffer;\n\n    // Track all extras for cleanup\n    std::map<ggml_tensor *, ggml_openvino_extra_base *> tensor_extras;\n\n    // Used for re-allocation on device for kvcache\n    void * data_prev;\n\n    ggml_backend_openvino_buffer_context(int device, size_t size, bool is_remote = false) :\n        device(device),\n        name(std::string(GGML_OPENVINO_NAME) + std::to_string(device)),\n        id([]() {\n            static std::atomic<size_t> next_id{1};\n            return next_id.fetch_add(1);\n        }()),\n        data(nullptr),\n        size(size),\n        is_remote(is_remote) {\n        if (size == 0) {\n            return;\n        }\n\n        const auto & device_name = ggml_openvino_get_device_name();\n\n        if (is_remote) {\n            GGML_ASSERT(device_name == \"GPU\");\n            auto remote_context = ggml_openvino_get_remote_context();\n            auto gpu_context = remote_context->as<ov::intel_gpu::ocl::ClContext>();\n            ov::intel_gpu::ocl::USMTensor usm_tensor =\n                gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size});\n            data = usm_tensor.get();\n            ov_buffer = 
std::make_shared<ov::intel_gpu::ocl::USMTensor>(std::move(usm_tensor));\n        } else {\n            data = ggml_aligned_malloc(size);\n            ov_buffer = std::make_shared<ov::Tensor>(ov::element::u8, ov::Shape{size}, data);\n        }\n\n        if (data == nullptr) {\n            GGML_LOG_ERROR(\"%s: failed to allocate %zu bytes\\n\", __func__, size);\n            return;\n        }\n\n        if (reinterpret_cast<uintptr_t>(data) % TENSOR_ALIGNMENT != 0) {\n            GGML_LOG_ERROR(\"%s: %s buffer is not aligned to %d bytes\\n\", __func__, device_name.c_str(),\n                           TENSOR_ALIGNMENT);\n            GGML_ABORT(\"fatal error\");\n        }\n    }\n\n    ~ggml_backend_openvino_buffer_context() {\n        // Clean up all tensor extras\n        // GGML_LOG_DEBUG(\"Deleting OpenVINO buffer context #%zu for device %d, size %zu MB\\n\", id, device,\n        //                size / 1024 / 1024);\n        for (auto & pair : tensor_extras) {\n            delete pair.second;\n        }\n        tensor_extras.clear();\n        if (!is_remote && data != nullptr) {\n            ggml_aligned_free(data, size);\n        }\n    }\n};\n\n// Buffer type context (per-device)\nstruct ggml_backend_openvino_buffer_type_context {\n    int device;\n    std::string name;\n};\n\n// Buffer interface functions\nstatic void ggml_backend_openvino_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n    delete ctx;\n}\n\nstatic void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer) {\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n    return ctx->data;\n}\n\nstatic enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    // GGML_LOG_DEBUG(\"%s: buffer usage=%d, tensor name=%s\\n\", __func__, buffer->usage, 
tensor->name);\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n\n    // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache)\n    if (strncmp(tensor->name, \"cache_\", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == \"GPU\" &&\n        !getenv(\"GGML_OPENVINO_STATEFUL_EXECUTION\")) {\n        GGML_ASSERT(ctx->tensor_extras.empty());\n        auto device = ctx->device;\n        auto size = ctx->size;\n        auto * data_prev = ctx->data;\n        delete ctx;\n        ctx = new ggml_backend_openvino_buffer_context(device, size, true);\n        buffer->context = ctx;\n        tensor->data = (char *) ctx->data + ((char *) tensor->data - (char *) data_prev);\n    }\n\n    // Views share the extra from view_src\n    if (tensor->view_src != nullptr) {\n        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);\n        if (tensor->view_src->extra != nullptr) {\n            tensor->extra = tensor->view_src->extra;\n        }\n        return GGML_STATUS_SUCCESS;\n    }\n\n    ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n\n    if (tensor->data != nullptr && !ggml_is_quantized(tensor->type)) {\n        ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote);\n        if (extra != nullptr) {\n            auto it = ctx->tensor_extras.find(tensor);\n            if (it != ctx->tensor_extras.end()) {\n                delete it->second;\n            }\n            ctx->tensor_extras[tensor] = extra;\n            tensor->extra = extra;\n        }\n    }\n\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_openvino_buffer_memset_tensor(ggml_backend_buffer_t buffer,\n                                                       ggml_tensor * tensor,\n                                                       uint8_t value,\n                                                       size_t offset,\n                  
                                     size_t size) {\n    // GGML_LOG_DEBUG(\"%s: buffer usage=%d, tensor name=%s\\n\", __func__, buffer->usage, tensor->name);\n    GGML_ASSERT(tensor != nullptr && tensor->data != nullptr);\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n\n    if (ctx->is_remote) {\n        // For remote (device) buffers, use OpenCL USM memfill\n        cl_command_queue queue = ggml_openvino_get_cl_queue();\n        auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL();\n        if (queue != nullptr && mem_fill_fn != nullptr) {\n            uint8_t pattern = value;\n            cl_int err = mem_fill_fn(queue, (char *) tensor->data + offset, &pattern, sizeof(pattern), size, 0, nullptr,\n                                     nullptr);\n            if (err != CL_SUCCESS) {\n                GGML_LOG_ERROR(\"%s: clEnqueueMemFillINTEL failed with error %d\\n\", __func__, err);\n            }\n            clFinish(queue);\n        } else {\n            GGML_LOG_ERROR(\"%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer\\n\", __func__);\n        }\n    } else {\n        memset((char *) tensor->data + offset, value, size);\n    }\n}\n\nstatic void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer,\n                                                    ggml_tensor * tensor,\n                                                    const void * data,\n                                                    size_t offset,\n                                                    size_t size) {\n    // GGML_LOG_DEBUG(\"%s: buffer usage=%d, tensor name=%s\\n\", __func__, buffer->usage, tensor->name);\n    GGML_ASSERT(tensor != nullptr && tensor->data != nullptr);\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n\n    // Check if this is a weight buffer (usage is set BEFORE set_tensor is called, except in 
test-backend-ops)\n    bool is_weight_buffer = (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS);\n    // Full tensor set: offset=0, full size, not a view\n    bool is_full_tensor_set = (offset == 0 && size == ggml_nbytes(tensor) && tensor->view_src == nullptr);\n    // 2D tensor (typical weight shape)\n    bool is_2d = (tensor->ne[2] == 1 && tensor->ne[3] == 1);\n\n    if (is_weight_buffer && is_full_tensor_set && is_2d) {\n        try {\n            auto result = process_weight_tensor(tensor, data, tensor->data);\n            result.weight_node->set_friendly_name(tensor->name);\n\n            // const auto & layout = result.layout;\n            ggml_openvino_extra_base * extra;\n\n            // Quantized path with extracted weight/scale/zp tensors\n            if (result.is_quantized()) {\n                extra = new ggml_openvino_quantized_weight_extra(std::move(result.weights), std::move(result.scales),\n                                                                 std::move(result.zp), result.weight_node);\n\n                // if (layout.is_requant) {\n                //     GGML_LOG_DEBUG(\"%s: requantized %s to %s (u%d, block_size=%ld)\\n\", __func__, tensor->name,\n                //                    extra_quant_type_name(layout.requant_type.value()), layout.is_u4 ? 4 : 8,\n                //                    layout.weights_per_block);\n                // } else {\n                //     int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block;\n                //     GGML_LOG_DEBUG(\"%s: extracted quantized weight node for %s (u%d, %zu weights, %ld blocks)\\n\",\n                //                    __func__, tensor->name, layout.is_u4 ? 
4 : 8, layout.weights_size, n_blocks);\n                // }\n            } else {\n                // F16/F32/BF16 weight or F16-requant\n                extra = new ggml_openvino_weight_extra(std::move(result.weights), result.weight_node);\n\n                // if (layout.total_size > 0) {\n                //     GGML_LOG_DEBUG(\"%s: requantized %s to F16\\n\", __func__, tensor->name);\n                // } else {\n                //     GGML_LOG_DEBUG(\"%s: created shared-memory weight node for %s\\n\", __func__, tensor->name);\n                // }\n            }\n\n            ctx->tensor_extras[tensor] = extra;\n            tensor->extra = extra;\n\n        } catch (const std::exception & e) {\n            GGML_LOG_ERROR(\"%s: failed to process weight tensor for %s: %s\\n\", __func__, tensor->name, e.what());\n            memcpy((char *) tensor->data + offset, data, size);\n        }\n    } else {\n        // Non-weight tensor (KV cache, activations, etc.) - copy data. test-backend-ops also goes here\n        if (ctx->is_remote) {\n            cl_command_queue queue = ggml_openvino_get_cl_queue();\n            auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL();\n            if (queue != nullptr && mem_cpy_fn != nullptr) {\n                cl_int err =\n                    mem_cpy_fn(queue, CL_TRUE, (char *) tensor->data + offset, data, size, 0, nullptr, nullptr);\n                if (err != CL_SUCCESS) {\n                    GGML_LOG_ERROR(\"%s: clEnqueueMemcpyINTEL failed with error %d\\n\", __func__, err);\n                }\n            } else {\n                GGML_LOG_ERROR(\"%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\\n\", __func__);\n            }\n        } else {\n            memcpy((char *) tensor->data + offset, data, size);\n        }\n\n        ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote);\n        if (extra == nullptr) {\n            // GGML_LOG_ERROR(\"%s: 
failed to create tensor extra for %s\\n\", __func__, tensor->name);\n            return;\n        }\n\n        auto it = ctx->tensor_extras.find(tensor);\n        if (it != ctx->tensor_extras.end()) {\n            delete it->second;\n        }\n        ctx->tensor_extras[tensor] = extra;\n        tensor->extra = extra;\n    }\n}\n\nstatic void ggml_backend_openvino_buffer_get_tensor(ggml_backend_buffer_t buffer,\n                                                    const ggml_tensor * tensor,\n                                                    void * data,\n                                                    size_t offset,\n                                                    size_t size) {\n    // GGML_LOG_DEBUG(\"%s: buffer usage=%d, tensor name=%s\\n\", __func__, buffer->usage, tensor->name);\n    GGML_ASSERT(tensor != nullptr && tensor->data != nullptr);\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n\n    if (ctx->is_remote) {\n        // For remote (device) buffers, use OpenCL USM memcpy (device-to-host)\n        cl_command_queue queue = ggml_openvino_get_cl_queue();\n        auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL();\n        if (queue != nullptr && mem_cpy_fn != nullptr) {\n            cl_int err =\n                mem_cpy_fn(queue, CL_TRUE, data, (const char *) tensor->data + offset, size, 0, nullptr, nullptr);\n            if (err != CL_SUCCESS) {\n                GGML_LOG_ERROR(\"%s: clEnqueueMemcpyINTEL failed with error %d\\n\", __func__, err);\n            }\n        } else {\n            GGML_LOG_ERROR(\"%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\\n\", __func__);\n        }\n    } else {\n        memcpy(data, (const char *) tensor->data + offset, size);\n    }\n}\n\nstatic bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer,\n                                                    const ggml_tensor * src,\n                     
                               ggml_tensor * dst) {\n    // GGML_LOG_DEBUG(\"%s: src tensor name=%s, dst tensor name=%s\\n\", __func__, src->name, dst->name);\n    GGML_ASSERT(src != nullptr && dst != nullptr);\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n\n    if (ctx->is_remote) {\n        // For remote (device) buffers, use OpenCL USM memcpy\n        cl_command_queue queue = ggml_openvino_get_cl_queue();\n        auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL();\n        if (queue == nullptr || mem_cpy_fn == nullptr) {\n            GGML_LOG_ERROR(\"%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\\n\", __func__);\n            return false;\n        }\n        // Can copy from host to device\n        if (ggml_backend_buffer_is_host(src->buffer)) {\n            cl_int err = mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr);\n            if (err != CL_SUCCESS) {\n                GGML_LOG_ERROR(\"%s: clEnqueueMemcpyINTEL (host-to-device) failed with error %d\\n\", __func__, err);\n                return false;\n            }\n            return true;\n        }\n        // Can also copy from device to device if both are OpenVINO remote buffers\n        if (ggml_backend_buffer_is_openvino(src->buffer)) {\n            ggml_backend_openvino_buffer_context * src_ctx =\n                (ggml_backend_openvino_buffer_context *) src->buffer->context;\n            if (src_ctx->is_remote) {\n                cl_int err =\n                    mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr);\n                if (err != CL_SUCCESS) {\n                    GGML_LOG_ERROR(\"%s: clEnqueueMemcpyINTEL (device-to-device) failed with error %d\\n\", __func__,\n                                   err);\n                    return false;\n                }\n                return true;\n            }\n        }\n       
 return false;\n    }\n\n    // Host buffer - can copy from any host buffer\n    if (ggml_backend_buffer_is_host(src->buffer)) {\n        memcpy(dst->data, src->data, ggml_nbytes(src));\n        return true;\n    }\n    return false;\n}\n\nstatic void ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n    GGML_ASSERT(ctx->data != nullptr);\n    if (ctx->is_remote) {\n        cl_command_queue queue = ggml_openvino_get_cl_queue();\n        auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL();\n        if (queue != nullptr && mem_fill_fn != nullptr) {\n            uint8_t pattern = value;\n            cl_int err = mem_fill_fn(queue, ctx->data, &pattern, sizeof(pattern), ctx->size, 0, nullptr, nullptr);\n            if (err != CL_SUCCESS) {\n                GGML_LOG_WARN(\"%s: clEnqueueMemFillINTEL failed with error %d\\n\", __func__, err);\n            }\n            clFinish(queue);\n        } else {\n            GGML_LOG_WARN(\"%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer clear\\n\",\n                          __func__);\n        }\n    } else {\n        memset(ctx->data, value, ctx->size);\n    }\n}\n\nstatic const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_openvino_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_openvino_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_openvino_buffer_init_tensor,\n    /* .memset_tensor   = */ ggml_backend_openvino_buffer_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_openvino_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_openvino_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_openvino_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_openvino_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\n// Buffer type 
interface functions\nstatic const char * ggml_backend_openvino_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context;\n    return ctx->name.c_str();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_openvino_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,\n                                                                            size_t size) {\n    ggml_backend_openvino_buffer_type_context * buft_ctx = (ggml_backend_openvino_buffer_type_context *) buft->context;\n\n    // Create buffer context with contiguous memory allocation\n    ggml_backend_openvino_buffer_context * ctx = new ggml_backend_openvino_buffer_context(buft_ctx->device, size);\n\n    if (ctx->data == nullptr && size > 0) {\n        GGML_LOG_ERROR(\"%s: failed to allocate buffer of size %zu\\n\", __func__, size);\n        delete ctx;\n        return nullptr;\n    }\n\n    return ggml_backend_buffer_init(buft, ggml_backend_openvino_buffer_interface, ctx, size);\n}\n\nstatic size_t ggml_backend_openvino_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    GGML_UNUSED(buft);\n    return TENSOR_ALIGNMENT;\n}\n\nstatic size_t ggml_backend_openvino_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {\n    GGML_UNUSED(buft);\n    return SIZE_MAX;\n}\n\nstatic size_t ggml_backend_openvino_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,\n                                                               const ggml_tensor * tensor) {\n    GGML_UNUSED(buft);\n\n    // For quantized 2D tensors (weights), we need extra space for extracted data\n    if (ggml_is_quantized(tensor->type) && tensor->ne[2] == 1 && tensor->ne[3] == 1) {\n        ggml_openvino_extracted_layout layout = ggml_openvino_get_extracted_layout(tensor);\n        if (layout.total_size > 0) {\n            // GGML_LOG_DEBUG(\"%s: tensor %s needs %zu bytes (original %zu, extracted: weights=%zu 
scales=%zu zp=%zu)\\n\",\n            //                __func__, tensor->name, layout.total_size, ggml_nbytes(tensor), layout.weights_size,\n            //                layout.scales_size, layout.zp_size);\n            return layout.total_size;\n        }\n    }\n\n    return ggml_nbytes(tensor);\n}\n\nstatic const ggml_backend_buffer_type_i ggml_backend_openvino_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_openvino_buffer_type_get_name,\n    /* .alloc_buffer     = */ ggml_backend_openvino_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_openvino_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_openvino_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_openvino_buffer_type_get_alloc_size,\n    /* .is_host          = */ nullptr,\n};\n\n// Get buffer type for a specific device\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device) {\n    GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count());\n\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    static std::vector<ggml_backend_buffer_type> buffer_types;\n    static std::vector<ggml_backend_openvino_buffer_type_context> buffer_type_contexts;\n\n    if (buffer_types.empty()) {\n        int device_count = ggml_backend_openvino_get_device_count();\n        buffer_types.resize(device_count);\n        buffer_type_contexts.resize(device_count);\n\n        for (int i = 0; i < device_count; i++) {\n            buffer_type_contexts[i].device = i;\n            buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i);\n\n            buffer_types[i] = ggml_backend_buffer_type{\n                /* .iface   = */ ggml_backend_openvino_buffer_type_interface,\n                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i),\n                /* .context = */ &buffer_type_contexts[i],\n            };\n       
 }\n    }\n\n    return &buffer_types[device];\n}\n\n// =====================================================\n// OpenVINO Host Buffer Implementation\n// =====================================================\n\nstatic const char * ggml_backend_openvino_host_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context;\n    static std::string name;\n    name = ctx->name + \"_HOST\";\n    return name.c_str();\n}\n\nstatic bool ggml_backend_openvino_host_buffer_type_is_host(ggml_backend_buffer_type_t buft) {\n    GGML_UNUSED(buft);\n    return true;\n}\n\nstatic const ggml_backend_buffer_type_i ggml_backend_openvino_host_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_openvino_host_buffer_type_get_name,\n    /* .alloc_buffer     = */ ggml_backend_openvino_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_openvino_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_openvino_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_openvino_buffer_type_get_alloc_size,\n    /* .is_host          = */ ggml_backend_openvino_host_buffer_type_is_host,\n};\n\nGGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device) {\n    GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count());\n\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    static std::vector<ggml_backend_buffer_type> buffer_types;\n    static std::vector<ggml_backend_openvino_buffer_type_context> buffer_type_contexts;\n\n    if (buffer_types.empty()) {\n        int device_count = ggml_backend_openvino_get_device_count();\n        buffer_types.resize(device_count);\n        buffer_type_contexts.resize(device_count);\n\n        for (int i = 0; i < device_count; i++) {\n            buffer_type_contexts[i].device = i;\n            
buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i);\n\n            buffer_types[i] = ggml_backend_buffer_type{\n                /* .iface   = */ ggml_backend_openvino_host_buffer_type_interface,\n                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i),\n                /* .context = */ &buffer_type_contexts[i],\n            };\n        }\n    }\n\n    return &buffer_types[device];\n}\n\nbool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) {\n    return buffer->iface.free_buffer == ggml_backend_openvino_buffer_free_buffer;\n}\n\nsize_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer) {\n    if (!ggml_backend_buffer_is_openvino(buffer)) {\n        return 0;\n    }\n    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;\n    return ctx->id;\n}\n\nvoid ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra) {\n    GGML_ASSERT(tensor != nullptr);\n    GGML_ASSERT(tensor->buffer != nullptr);\n    GGML_ASSERT(ggml_backend_buffer_is_openvino(tensor->buffer));\n\n    auto * ctx = static_cast<ggml_backend_openvino_buffer_context *>(tensor->buffer->context);\n\n    auto it = ctx->tensor_extras.find(tensor);\n    if (it != ctx->tensor_extras.end()) {\n        delete it->second;\n    }\n\n    ctx->tensor_extras[tensor] = extra;\n    tensor->extra = extra;\n}\n\nbool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft) {\n    return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name;\n}\n\nbool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) {\n    return buft->iface.get_name == ggml_backend_openvino_host_buffer_type_get_name;\n}\n\nstatic void ggml_backend_openvino_free(ggml_backend_t backend) {\n    ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context;\n    delete ctx;\n    delete backend;\n}\n\nstatic const 
char * ggml_backend_openvino_get_name(ggml_backend_t backend) {\n    return GGML_OPENVINO_NAME;\n    GGML_UNUSED(backend);\n}\n\nstatic enum ggml_status ggml_backend_openvino_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    return ov_graph_compute(cgraph, backend);\n    GGML_UNUSED(backend);\n}\n\nstatic const ggml_backend_i ggml_backend_openvino_interface = {\n    /* .get_name                = */ ggml_backend_openvino_get_name,\n    /* .free                    = */ ggml_backend_openvino_free,\n    /* .set_tensor_async        = */ NULL,\n    /* .get_tensor_async        = */ NULL,\n    /* .cpy_tensor_async        = */ NULL,\n    /* .synchronize             = */ NULL,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_openvino_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ NULL,\n};\n\nint ggml_backend_openvino_get_device_count() {\n    return 1;\n}\n\nstatic ggml_guid_t ggml_backend_openvino_guid(void) {\n    static ggml_guid guid = {0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97,\n                             0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d};\n    return &guid;\n}\n\nstatic std::shared_ptr<ov_runtime_context> get_ov_runtime_context_ptr() {\n    static std::shared_ptr<ov_runtime_context> r_ctx = std::make_shared<ov_runtime_context>();\n    return r_ctx;\n}\n\n// backend API\nGGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) {\n    if (device < 0 || device >= ggml_backend_openvino_get_device_count()) {\n        GGML_LOG_ERROR(\"%s: invalid device %d\\n\", __func__, device);\n        return nullptr;\n    }\n\n    ggml_backend_openvino_context * ctx = new ggml_backend_openvino_context;\n    if (ctx == nullptr) {\n        GGML_LOG_ERROR(\"%s: failed to 
allocate context\\n\", __func__);\n        return nullptr;\n    }\n\n    ctx->runtime_context = get_ov_runtime_context_ptr();\n    if (ctx->runtime_context == nullptr) {\n        GGML_LOG_ERROR(\"%s: failed to allocate runtime context\\n\", __func__);\n        delete ctx;\n        return nullptr;\n    }\n\n    std::shared_ptr<ov_runtime_context> r_ctx = std::static_pointer_cast<ov_runtime_context>(ctx->runtime_context);\n    r_ctx->device = ggml_openvino_get_device_name();\n    r_ctx->stateful = getenv(\"GGML_OPENVINO_STATEFUL_EXECUTION\") && !ggml_openvino_is_npu();\n\n    ggml_backend_t openvino_backend = new ggml_backend{\n        /* .guid      = */ ggml_backend_openvino_guid(),\n        /* .interface = */ ggml_backend_openvino_interface,\n        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), device),\n        /* .context   = */ ctx,\n    };\n\n    return openvino_backend;\n}\n\nGGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_openvino_guid());\n}\n\nstruct ggml_backend_openvino_device_context {\n    int device;\n    std::string name;\n    std::string description;\n};\n\nstatic const char * ggml_backend_openvino_device_get_name(ggml_backend_dev_t dev) {\n    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;\n    return ctx->name.c_str();\n}\n\nstatic const char * ggml_backend_openvino_device_get_description(ggml_backend_dev_t dev) {\n    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;\n    return ctx->description.c_str();\n}\n\nstatic void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n#ifdef _WIN32\n    MEMORYSTATUSEX status;\n    status.dwLength = sizeof(status);\n    GlobalMemoryStatusEx(&status);\n    *total = status.ullTotalPhys;\n    *free = status.ullAvailPhys;\n#else\n    long 
pages = sysconf(_SC_PHYS_PAGES);\n    long page_size = sysconf(_SC_PAGE_SIZE);\n    *total = pages * page_size;\n\n    // \"free\" system memory is ill-defined, for practical purposes assume that all of it is free:\n    *free = *total;\n#endif  // _WIN32\n\n    GGML_UNUSED(dev);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) {\n    GGML_UNUSED(dev);\n    return GGML_BACKEND_DEVICE_TYPE_GPU;\n}\n\nstatic void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {\n    props->name = ggml_backend_openvino_device_get_name(dev);\n    props->description = ggml_backend_openvino_device_get_description(dev);\n    props->type = ggml_backend_openvino_device_get_type(dev);\n    ggml_backend_openvino_device_get_memory(dev, &props->memory_free, &props->memory_total);\n\n    props->caps = {\n        /* .async                 = */ false,\n        /* .host_buffer           = */ false,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ false,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_openvino_device_init(ggml_backend_dev_t dev, const char * params) {\n    GGML_UNUSED(params);\n    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;\n    return ggml_backend_openvino_init(ctx->device);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_openvino_device_get_buffer_type(ggml_backend_dev_t dev) {\n    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;\n    return ggml_backend_openvino_buffer_type(ctx->device);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_openvino_device_get_host_buffer_type(ggml_backend_dev_t dev) {\n    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;\n    return ggml_backend_openvino_host_buffer_type(ctx->device);\n}\n\nstatic bool has_view_op_input(const ggml_tensor * op) {\n  
  for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (op->src[i] == nullptr) {\n            break;\n        }\n        if (op->src[i]->op == GGML_OP_VIEW) {\n            return true;\n        }\n    }\n    return false;\n}\n\nstatic bool is_supported_flash_attn_pattern(const ggml_tensor * op) {\n    // pattern of q,k,v should be q->op==PERMUTE, q->src[0]->op==VIEW, q->src[0]->src[0]->view_src==nullptr\n    for (int i = 0; i < 3; i++) {\n        const ggml_tensor * src = op->src[i];\n        if (src->op != GGML_OP_PERMUTE || src->src[0] == nullptr || src->src[0]->op != GGML_OP_VIEW ||\n            src->src[0]->src[0] == nullptr || src->src[0]->src[0]->view_src != nullptr) {\n            return false;\n        }\n    }\n    return true;\n}\n\nstatic bool is_op_unsupported_case(const ggml_tensor * op) {\n    switch (op->op) {\n    case GGML_OP_GET_ROWS:\n    case GGML_OP_SET_ROWS: {\n        if (op->ne[3] != 1) {\n            return true;\n        }\n        break;\n    }\n    case GGML_OP_ADD:\n    case GGML_OP_MUL: {\n        if (op->src[1]->op == GGML_OP_PERMUTE) {\n            return true;\n        }\n        for (int i = 0; i < 4; i++) {\n            if (op->src[0]->ne[i] != op->src[1]->ne[i] && (op->src[0]->ne[i] != 1 && op->src[1]->ne[i] != 1)) {\n                return true;\n            }\n        }\n        break;\n    }\n    case GGML_OP_SOFT_MAX: {\n        if (op->src[2] != nullptr) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support SOFT_MAX with sinks\\n\");\n            return true;\n        }\n        float scale = 1.0f;\n        float max_bias = 0.0f;\n        const auto * op_params = op->op_params;\n        memcpy(&scale, (const float *) op_params + 0, sizeof(float));\n        memcpy(&max_bias, (const float *) op_params + 1, sizeof(float));\n        if (max_bias > 0) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support SOFT_MAX with max_bias > 0\\n\");\n            return true;\n        }\n        break;\n    }\n    
case GGML_OP_FLASH_ATTN_EXT: {\n        if (op->src[4] != nullptr) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support FLASH_ATTN_EXT with sinks\\n\");\n            return true;\n        }\n        if (!is_supported_flash_attn_pattern(op)) {\n            return true;\n        }\n        float scale = 1.0f;\n        float max_bias = 0.0f;\n        float logit_softcap = 0.0f;\n        const auto * op_params = op->op_params;\n        memcpy(&scale, (const float *) op_params + 0, sizeof(float));\n        memcpy(&max_bias, (const float *) op_params + 1, sizeof(float));\n        memcpy(&logit_softcap, (const float *) op_params + 2, sizeof(float));\n        if (max_bias > 0) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\\n\");\n            return true;\n        }\n        if (logit_softcap != 0) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\\n\");\n            return true;\n        }\n        break;\n    }\n    case GGML_OP_PERMUTE: {\n        if (op->type == GGML_TYPE_BF16) {\n            // err msg: [GPU] Could not find a suitable kernel for transpose\n            // GGML_LOG_WARN(\"OpenVINO backend does not support PERMUTE with BF16 type\\n\");\n            return true;\n        }\n        break;\n    }\n    case GGML_OP_CPY: {\n        if (op->src[1] != op) {\n            // GGML_LOG_WARN(\"OpenVINO backend only supports CPY that is a cast\\n\");\n            return true;\n        }\n        break;\n    }\n    case GGML_OP_MUL_MAT: {\n        if (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16) {\n            // Has accuracy issue, try enabling this and see `test-backend-ops -o \"MUL_MAT\"`\n            // GGML_LOG_WARN(\"OpenVINO backend does not support MUL_MAT with two F16 tensors\\n\");\n            return true;\n        }\n        if (op->src[0]->ne[3] != op->src[1]->ne[3] && op->src[0]->ne[3] != 1 && 
op->src[1]->ne[3] != 1) {\n            return true;\n        }\n        if (op->src[0]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_PERMUTE) {\n            return true;\n        }\n        if (ggml_is_quantized(op->src[0]->type) && op->src[0]->ne[1] == 1) {\n            // MUL_MAT(type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1)\n            // triggers a bug in ov matmul_shape_inference.hpp\n            return true;\n        }\n        if (op->src[0]->op == GGML_OP_VIEW && op->src[1]->op == GGML_OP_VIEW) {\n            return true;\n        }\n        break;\n    }\n    case GGML_OP_ROPE: {\n        const int32_t * op_params = op->op_params;\n        const int n_dims = op_params[1];\n        const int mode = op_params[2];\n        if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support ROPE with mode %d\\n\", mode);\n            return true;\n        }\n        if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\\n\", n_dims,\n            //               op->src[0]->ne[0]);\n            return true;\n        }\n        if (op->type != GGML_TYPE_F32) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support ROPE with type %s\\n\", ggml_type_name(op->type));\n            return true;\n        }\n        float freq_scale;\n        float ext_factor;\n        memcpy(&freq_scale, op_params + 6, sizeof(float));\n        memcpy(&ext_factor, op_params + 7, sizeof(float));\n        if (ext_factor != 0.0f) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\\n\", ext_factor);\n            return true;\n        }\n        if (op->src[0]->op == GGML_OP_VIEW) {\n            if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) {\n                // GGML_LOG_WARN(\n                //     
\"OpenVINO backend does not support ROPE with src[0]->view_src->ne[1] %ld != src[0]->ne[2] \"\n                //     \"%ld\\n\",\n                //     op->src[0]->view_src->ne[1], op->src[0]->ne[2]);\n                return true;\n            }\n        }\n        break;\n    }\n    default:\n        break;\n    }\n    if (op->op == GGML_OP_GET_ROWS) {\n        if (op->ne[0] == 256 && (op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q5_K)) {\n            // ERR = 0.000000306 > 0.000000100   GET_ROWS(type=q4_K,n=256,m=5,r=4,be1=1,be2=1,v=0)\n            // ERR = 0.000000197 > 0.000000100   GET_ROWS(type=q5_K,n=256,m=5,r=4,be1=1,be2=1,v=0)\n            return true;\n        }\n    }\n    return false;\n}\n\nstatic bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    GGML_ASSERT(dev->reg != nullptr);\n\n    static std::set<ggml_type> supported_types{GGML_TYPE_F32,  GGML_TYPE_F16,  GGML_TYPE_BF16, GGML_TYPE_I64,\n                                               GGML_TYPE_I32,  GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K,\n                                               GGML_TYPE_Q5_K, GGML_TYPE_Q8_0, GGML_TYPE_Q6_K};\n\n    static const std::set<ggml_op> supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT, GGML_OP_VIEW,\n                                                 /*GGML_OP_CONT,*/ GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE,\n                                                 GGML_OP_GET_ROWS, GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_SCALE,\n                                                 // softmax is not updated due to replaced by flash_attn_ext\n                                                 // GGML_OP_SOFT_MAX,\n                                                 GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY};\n    static const std::set<ggml_unary_op> supported_unary_ops{\n        GGML_UNARY_OP_SILU,\n    };\n    static const std::set<ggml_glu_op> 
supported_glu_ops{\n        GGML_GLU_OP_SWIGLU,\n        GGML_GLU_OP_GEGLU,\n    };\n\n    switch (op->op) {\n    case GGML_OP_UNARY: {\n        auto supported = supported_unary_ops.find(ggml_get_unary_op(op)) != supported_unary_ops.end();\n        if (!supported) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support unary op %s\\n\", ggml_unary_op_name(ggml_get_unary_op(op)));\n            return false;\n        }\n        if (has_view_op_input(op)) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support unary op %s with view input\\n\",\n            //               ggml_unary_op_name(ggml_get_unary_op(op)));\n            return false;\n        }\n        break;\n    }\n    case GGML_OP_GLU: {\n        auto supported = supported_glu_ops.find(ggml_get_glu_op(op)) != supported_glu_ops.end();\n        if (!supported) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support GLU op %s\\n\", ggml_glu_op_name(ggml_get_glu_op(op)));\n            return false;\n        }\n        if (has_view_op_input(op)) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support unary op %s with view input\\n\",\n            //               ggml_glu_op_name(ggml_get_glu_op(op)));\n            return false;\n        }\n        if (op->src[1] == nullptr && op->src[0]->ne[0] % 2 != 0) {\n            // triggers bug in ov gpu\n            return false;\n        }\n        break;\n    }\n    default: {\n        auto supported = supported_ops.find(op->op) != supported_ops.end();\n        if (!supported) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support op %s\\n\", ggml_op_name(op->op));\n            return false;\n        }\n        static std::set<ggml_op> ops_not_support_view_input{\n            GGML_OP_GET_ROWS,\n            GGML_OP_RMS_NORM,\n        };\n        if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_op_input(op)) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not 
support op %s with view input\\n\", ggml_op_name(op->op));\n            return false;\n        }\n    }\n    }\n\n    if (supported_types.find(op->type) == supported_types.end()) {\n        // GGML_LOG_WARN(\"OpenVINO backend does not support tensor type %s\\n\", ggml_type_name(op->type));\n        return false;\n    }\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        auto * src = op->src[i];\n        if (src == nullptr) {\n            break;\n        }\n        if (supported_types.find(src->type) == supported_types.end()) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support tensor type %s\\n\", ggml_type_name(src->type));\n            return false;\n        }\n        if (ggml_is_quantized(src->type) && src->ne[2] != 1) {\n            // GGML_LOG_WARN(\"OpenVINO backend does not support 3D quantized tensors\\n\");\n            return false;\n        }\n    }\n\n    if (is_op_unsupported_case(op)) {\n        return false;\n    }\n    return true;\n}\n\nstatic bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    return ggml_backend_buft_is_openvino(buft) || ggml_backend_buft_is_host(buft);\n    GGML_UNUSED(dev);\n}\n\nstatic const struct ggml_backend_device_i ggml_backend_openvino_device_interface = {\n    /* .get_name             = */ ggml_backend_openvino_device_get_name,\n    /* .get_description      = */ ggml_backend_openvino_device_get_description,\n    /* .get_memory           = */ ggml_backend_openvino_device_get_memory,\n    /* .get_type             = */ ggml_backend_openvino_device_get_type,\n    /* .get_props            = */ ggml_backend_openvino_device_get_props,\n    /* .init_backend         = */ ggml_backend_openvino_device_init,\n    /* .get_buffer_type      = */ ggml_backend_openvino_device_get_buffer_type,\n    /* .get_host_buffer_type = */ ggml_backend_openvino_device_get_host_buffer_type,\n    /* .buffer_from_host_ptr = */ NULL,\n    /* .supports_op          = */ 
ggml_backend_openvino_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_openvino_device_supports_buft,\n    /* .offload_op           = */ NULL,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ NULL,\n};\n\nstruct ggml_backend_openvino_reg_context {\n    std::vector<ggml_backend_dev_t> devices;\n};\n\nstatic const char * ggml_backend_openvino_reg_get_name(ggml_backend_reg_t reg) {\n    return GGML_OPENVINO_NAME;\n    GGML_UNUSED(reg);\n}\n\nstatic size_t ggml_backend_openvino_reg_get_device_count(ggml_backend_reg_t reg) {\n    GGML_UNUSED(reg);\n    return (size_t) ggml_backend_openvino_get_device_count();\n}\n\nstatic ggml_backend_dev_t ggml_backend_openvino_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    ggml_backend_openvino_reg_context * ctx = (ggml_backend_openvino_reg_context *) reg->context;\n    GGML_ASSERT(index < ctx->devices.size());\n    return ctx->devices[index];\n}\n\nstatic const struct ggml_backend_reg_i ggml_backend_openvino_reg_interface = {\n    /* .get_name         = */ ggml_backend_openvino_reg_get_name,\n    /* .get_device_count = */ ggml_backend_openvino_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_openvino_reg_get_device,\n    /* .get_proc_address = */ NULL,\n};\n\nstatic void ggml_openvino_init() {\n    // Initialize device config singleton from env var\n    ggml_openvino_init_device_config();\n    GGML_LOG_INFO(\"OpenVINO: using device %s\\n\", ggml_openvino_get_device_name().c_str());\n}\n\nGGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void) {\n    static ggml_backend_reg reg;\n\n    static bool initialized = false;\n    {\n        static std::mutex mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n        if (!initialized) {\n            ggml_openvino_init();\n\n            ggml_backend_openvino_reg_context * ctx = new ggml_backend_openvino_reg_context;\n\n            for (int i = 0; i < 
ggml_backend_openvino_get_device_count(); i++) {\n                ggml_backend_openvino_device_context * dev_ctx = new ggml_backend_openvino_device_context;\n                dev_ctx->device = i;\n                dev_ctx->name = GGML_OPENVINO_NAME + std::to_string(i);\n\n                dev_ctx->description = ov::get_openvino_version().description;\n\n                ggml_backend_dev_t dev =\n                    new ggml_backend_device{/* .interface = */ ggml_backend_openvino_device_interface,\n                                            /* .reg       = */ &reg,\n                                            /* .context   = */ dev_ctx};\n                ctx->devices.push_back(dev);\n            }\n\n            reg = ggml_backend_reg{/* .api_version = */ GGML_BACKEND_API_VERSION,\n                                   /* .iface       = */ ggml_backend_openvino_reg_interface,\n                                   /* .context     = */ ctx};\n        }\n\n        initialized = true;\n    }\n\n    return &reg;\n}\n"
  },
  {
    "path": "src/ggml-openvino/ggml-quants.cpp",
    "content": "#include \"ggml-quants.h\"\n\n#include \"ggml-common.h\"\n#include \"ggml-impl.h\"\n#include \"ggml.h\"\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <limits>\n#include <memory>\n#include <openvino/core/except.hpp>\n#include <openvino/core/node.hpp>\n#include <openvino/core/node_output.hpp>\n#include <openvino/core/parallel.hpp>\n#include <openvino/core/shape.hpp>\n#include <openvino/core/type/element_type.hpp>\n#include <openvino/core/type/element_type_traits.hpp>\n#include <openvino/core/type/float16.hpp>\n#include <openvino/op/add.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/reshape.hpp>\n#include <openvino/op/subtract.hpp>\n#include <openvino/op/util/attr_types.hpp>\n#include <openvino/runtime/tensor.hpp>\n#include <string>\n#include <vector>\n\nvoid unpack_32_4(const uint8_t * data, uint8_t * dst) {\n    std::fill_n(dst, 16, 0);\n    for (int j = 0; j < 16; ++j) {\n        uint8_t x = (data[j] & 0x0F);\n        uint8_t y = (data[j] >> 4);\n        if (j % 2 != 0) {\n            x <<= 4;\n            y <<= 4;\n        }\n        dst[j / 2] |= x;\n        dst[8 + j / 2] |= y;  // Last 16 weights are in the higher bits\n    }\n}\n\n// Extracts (weight, scales, zp) from Q4_0 tensors.\n// Data layout is: |16 bit scale|32 x 4bit weights|.\nvoid extract_q4_0_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr) {\n    const uint64_t bytes_per_block = 18;  // 2 bytes scale, 32x0.5 byte weights\n\n    auto * data = static_cast<uint8_t *>(tensor->data);\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n    auto * zp = static_cast<uint8_t *>(zp_arr.data());\n\n    bool 
is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization\n\n    // For Q4_0, zero point is always 8\n    if (is_scalar_zp) {\n        zp[0] = 8 | (8 << 4);  // Pack two 4-bit values\n    }\n\n    ov::parallel_for(scales_arr.get_size(), [&](size_t i) {\n        scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block)));\n        // For asymmetric quantization, compute per-block zero points\n        if (!is_scalar_zp) {\n            // Pack two 4-bit zero points per byte\n            if (i % 2 == 0) {\n                zp[i / 2] = 8;          // Lower nibble\n            } else {\n                zp[i / 2] |= (8 << 4);  // Upper nibble\n            }\n        }\n        unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16);\n    });\n}\n\n// Extracts (weight, scales, zp) from Q4_1 tensors.\n// Data layout is: |16 bit scale|16 bit min|32 x 4bit weights|.\nvoid extract_q4_1_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr,\n                       bool use_bias) {\n    const uint64_t bytes_per_block = 20;  // 2 bytes scale, 2 bytes min, 32x0.5 byte weights\n\n    auto * data = static_cast<uint8_t *>(tensor->data);\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n\n    if (use_bias) {\n        // Store bias (min) directly as f16 instead of computing u4 zero points\n        auto * bias = zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n        ov::parallel_for(scales_arr.get_size(), [&](size_t i) {\n            float scale = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))));\n            float min = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2))));\n            scales[i] = ov::float16(scale);\n 
           bias[i] = ov::float16(min);  // bias = min, dequant: w*s + bias\n            unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16);\n        });\n    } else {\n        auto * zp = static_cast<uint8_t *>(zp_arr.data());\n        ov::parallel_for(scales_arr.get_size(), [&](size_t i) {\n            float scale = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))));\n            float min = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2))));\n            scales[i] = ov::float16(scale);\n            // zp = -min / scale (bias = min, so zp = -bias/scale)\n            uint8_t zp_val = (scale != 0.0f) ? (uint8_t) std::round(-min / scale) : 0;\n            // Pack two 4-bit zero points per byte\n            if (i % 2 == 0) {\n                zp[i / 2] = zp_val & 0x0F;   // Lower nibble\n            } else {\n                zp[i / 2] |= (zp_val << 4);  // Upper nibble\n            }\n            unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16);\n        });\n    }\n}\n\n// Extracts (weight, scales, zp) from Q8_0 tensors.\n// Data layout is: |16 bit scale|32 x 8bit weights|.\nvoid extract_q8_0_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr) {\n    const uint64_t weights_per_block = 32;\n    const uint64_t bytes_per_block = 34;  // 2 bytes scale, 32x1 byte weights\n\n    auto * data = static_cast<uint8_t *>(tensor->data);\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n    auto * zp = static_cast<uint8_t *>(zp_arr.data());\n\n    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization\n\n    // For Q8_0, zero point is always 128\n    if (is_scalar_zp) {\n        zp[0] = 128;\n    }\n\n    
ov::parallel_for(scales_arr.get_size(), [&](size_t i) {\n        uint8_t * block_data = data + i * bytes_per_block;\n        scales[i] = ov::float16::from_bits(*(uint16_t *) block_data);\n        // For asymmetric quantization, store per-block zero points\n        if (!is_scalar_zp) {\n            zp[i] = 128;\n        }\n        for (size_t j = 0; j < weights_per_block; ++j) {\n            uint8_t x = block_data[j + 2];  // j+2 to skip the scale bytes.\n            // Original data is in int8_t, so we add a bias of -128 and invert the first bit.\n            x ^= 1 << 7;\n            weights[i * weights_per_block + j] = x;\n        }\n    });\n}\n\nvoid unpack_256_4(const uint8_t * data, uint8_t * dst) {\n    // Initialize the output array with zeros\n    std::fill_n(dst, 128, 0);\n\n    for (size_t i = 0; i < 4; ++i) {\n        for (int j = 0; j < 32; ++j) {\n            uint8_t x = (data[i * 32 + j] & 0x0F);\n            uint8_t y = (data[i * 32 + j] >> 4);\n            if (j % 2 != 0) {\n                x <<= 4;\n                y <<= 4;\n            }\n            dst[i * 32 + j / 2] |= x;\n            dst[i * 32 + 16 + j / 2] |= y;  // Last 16 weights are in the higher bits\n        }\n    }\n}\n\nvoid extract_q4_k_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr,\n                       bool use_bias) {\n    const uint64_t bytes_per_block = 2 + 2 + 12 + 128;\n    const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;\n\n    auto * data = static_cast<uint8_t *>(tensor->data);\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n\n    // For bias path, zp_arr holds f16 bias values; for zp path, it holds packed u4 zero points\n    auto * zp_u4 = use_bias ? 
nullptr : static_cast<uint8_t *>(zp_arr.data());\n    auto * bias_f16 = use_bias ? zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>() : nullptr;\n\n    ov::parallel_for(n_super_block, [&](size_t i) {\n        uint8_t * block_data = data + i * bytes_per_block;\n\n        // Extract scale factors and offsets\n        float scale_scales = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data)));\n        float scale_mins = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 1)));\n\n        // Extract qs1 and qs2\n        uint8_t * qs1 = block_data + 4;\n\n        // Calculate scales\n        float scale_vals[8];\n        scale_vals[0] = scale_scales * static_cast<float>((*(qs1) & 0b111111));\n        scale_vals[1] = scale_scales * static_cast<float>((*(qs1 + 1) & 0b111111));\n        scale_vals[2] = scale_scales * static_cast<float>((*(qs1 + 2) & 0b111111));\n        scale_vals[3] = scale_scales * static_cast<float>((*(qs1 + 3) & 0b111111));\n        scale_vals[4] = scale_scales * static_cast<float>((*(qs1 + 8) & 0b00001111) | ((*(qs1) >> 6) << 4));\n        scale_vals[5] = scale_scales * static_cast<float>((*(qs1 + 9) & 0b00001111) | ((*(qs1 + 1) >> 6) << 4));\n        scale_vals[6] = scale_scales * static_cast<float>((*(qs1 + 10) & 0b00001111) | ((*(qs1 + 2) >> 6) << 4));\n        scale_vals[7] = scale_scales * static_cast<float>((*(qs1 + 11) & 0b00001111) | ((*(qs1 + 3) >> 6) << 4));\n\n        // Calculate min values (bias = -min)\n        float min_vals[8];\n        min_vals[0] = scale_mins * static_cast<float>((*(qs1 + 4) & 0b111111));\n        min_vals[1] = scale_mins * static_cast<float>((*(qs1 + 5) & 0b111111));\n        min_vals[2] = scale_mins * static_cast<float>((*(qs1 + 6) & 0b111111));\n        min_vals[3] = scale_mins * static_cast<float>((*(qs1 + 7) & 0b111111));\n        min_vals[4] = scale_mins * static_cast<float>((*(qs1 + 8) >> 4) | ((*(qs1 + 4) >> 6) << 4));\n        min_vals[5] = scale_mins * 
static_cast<float>((*(qs1 + 9) >> 4) | ((*(qs1 + 5) >> 6) << 4));\n        min_vals[6] = scale_mins * static_cast<float>((*(qs1 + 10) >> 4) | ((*(qs1 + 6) >> 6) << 4));\n        min_vals[7] = scale_mins * static_cast<float>((*(qs1 + 11) >> 4) | ((*(qs1 + 7) >> 6) << 4));\n\n        // Store scales and compute zero points or bias\n        for (int j = 0; j < 8; j++) {\n            scales[i * 8 + j] = ov::float16(scale_vals[j]);\n            if (use_bias) {\n                // Store bias = -min directly as f16, dequant: w*s + bias\n                bias_f16[i * 8 + j] = ov::float16(-min_vals[j]);\n            } else {\n                // zp = min / scale (since bias = -min and zp = -bias/scale)\n                uint8_t zp_val = (scale_vals[j] != 0.0f) ? (uint8_t) std::round(min_vals[j] / scale_vals[j]) : 0;\n                // Pack two 4-bit zero points per byte\n                size_t idx = i * 8 + j;\n                if (idx % 2 == 0) {\n                    zp_u4[idx / 2] = zp_val & 0x0F;\n                } else {\n                    zp_u4[idx / 2] |= (zp_val << 4);\n                }\n            }\n        }\n        unpack_256_4(block_data + 16, weights + i * 128);\n    });\n}\n\nvoid extract_q6_k_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr) {\n    const uint64_t bytes_per_block = 128 + 64 + 16 + 2;\n    const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;\n\n    auto * data = static_cast<uint8_t *>(tensor->data);\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n    auto * zp = static_cast<uint8_t *>(zp_arr.data());\n\n    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization\n\n    // For Q6_K, zero point is always 32\n    if (is_scalar_zp) {\n        zp[0] = 32;\n    }\n\n    
ov::parallel_for(n_super_block, [&](size_t i) {\n        uint8_t * block_data = data + i * bytes_per_block;\n\n        float scale_factor =\n            static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 104)));  // (128+64+16)/2\n\n        for (size_t j = 0; j < 16; j++) {\n            scales[j + i * 16] =\n                ov::float16(scale_factor * static_cast<float>(*((int8_t *) (block_data + 128 + 64 + j))));\n            // For asymmetric quantization, store per-block zero points\n            if (!is_scalar_zp) {\n                zp[j + i * 16] = 32;\n            }\n        }\n\n        uint8_t * ql = block_data;\n        uint8_t * qh = block_data + 128;\n\n        for (int64_t j = 0; j < 32; ++j) {\n            weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4);\n            weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4);\n            weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4);\n            weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4);\n            weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4);\n            weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4);\n            weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4);\n            weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4);\n        }\n    });\n}\n\nstatic inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {\n    if (j < 4) {\n        *d = q[j] & 63;\n        *m = q[j + 4] & 63;\n    } else {\n        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);\n        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);\n    }\n}\n\nvoid extract_q5_k_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr,\n                       
bool use_bias) {\n    const uint64_t bytes_per_block = 4 + 12 + 32 + 128;\n    const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;\n\n    auto * data = static_cast<uint8_t *>(tensor->data);\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n\n    // For bias path, zp_arr holds f16 bias values; for zp path, it holds u8 zero points\n    auto * zp_u8 = use_bias ? nullptr : static_cast<uint8_t *>(zp_arr.data());\n    auto * bias_f16 = use_bias ? zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>() : nullptr;\n\n    ov::parallel_for(n_super_block, [&](size_t i) {\n        uint8_t * block_data = data + i * bytes_per_block;\n\n        const float d = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data)));\n        const float min_factor = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 1)));\n\n        const uint8_t * scales_data = block_data + 4;   // 12 bytes of scales\n        const uint8_t * qh = block_data + 4 + 12;       // 32 bytes of high bits\n        const uint8_t * ql = block_data + 4 + 12 + 32;  // 128 bytes of low bits\n\n        int is = 0;\n        uint8_t u1 = 1;\n        uint8_t u2 = 2;\n\n        // Process 2 blocks in one iteration\n        for (int j = 0; j < 256; j += 64) {  // 256 = QK_K, so 4 iterations of 64\n            uint8_t sc;\n            uint8_t m;\n\n            // Get scale and min for first 32 elements\n            get_scale_min_k4(is + 0, scales_data, &sc, &m);\n            const float d1 = d * sc;\n            const float m1 = min_factor * m;\n\n            // Get scale and min for second 32 elements\n            get_scale_min_k4(is + 1, scales_data, &sc, &m);\n            const float d2 = d * sc;\n            const float m2 = min_factor * m;\n\n            scales[i * 8 + is] = ov::float16(d1);\n            scales[i * 8 + is + 1] = ov::float16(d2);\n            
if (use_bias) {\n                // Store bias = -min directly as f16, dequant: w*s + bias\n                bias_f16[i * 8 + is] = ov::float16(-m1);\n                bias_f16[i * 8 + is + 1] = ov::float16(-m2);\n            } else {\n                // zp = min / scale (since bias = -min and zp = -bias/scale)\n                zp_u8[i * 8 + is] = (d1 != 0.0f) ? (uint8_t) std::round(m1 / d1) : 0;\n                zp_u8[i * 8 + is + 1] = (d2 != 0.0f) ? (uint8_t) std::round(m2 / d2) : 0;\n            }\n\n            // Extract weights for first 32 elements (matching deq formula exactly)\n            for (int l = 0; l < 32; ++l) {\n                weights[i * 256 + j + l] = (ql[l] & 0xF) + ((qh[l] & u1) ? 16 : 0);\n            }\n\n            // Extract weights for second 32 elements\n            for (int l = 0; l < 32; ++l) {\n                weights[i * 256 + j + l + 32] = (ql[l] >> 4) + ((qh[l] & u2) ? 16 : 0);\n            }\n\n            ql += 32;\n            is += 2;\n            u1 <<= 2;\n            u2 <<= 2;\n        }\n    });\n}\n\n// TODO Reorder for make_intX_weights\n\nov::Output<ov::Node> make_int8_weights(ov::Tensor & weight,\n                                       ov::Tensor & scales,\n                                       ov::Tensor & zp,\n                                       size_t group_size,\n                                       bool use_bias) {\n    ov::Shape orig_shape = weight.get_shape();\n\n    // Expand dimensions for scales and zp/bias\n    auto scale_shape = scales.get_shape();\n    auto zp_shape = zp.get_shape();\n    bool is_scalar_zp = zp_shape.empty();  // Symmetric quantization\n\n    ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size};\n\n    if (packed_shape[1] == 1) {\n        // Requantized channel-wise case\n        packed_shape.erase(packed_shape.begin() + 1);\n    } else {\n        scale_shape.push_back(1);\n        scales.set_shape(scale_shape);\n        // For symmetric quantization, zp 
remains scalar (don't resize)\n        if (!is_scalar_zp) {\n            zp_shape.push_back(1);\n            zp.set_shape(zp_shape);\n        }\n    }\n\n    // Create graph nodes\n    auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::u8, packed_shape,\n                                                               static_cast<uint8_t *>(weight.data()), nullptr);\n    weights_node->get_rt_info()[\"__gguf_tensor_holder\"] = weight;\n    auto scales_f16 = std::make_shared<ov::op::v0::Constant>(scales);\n    auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16);\n\n    ov::Output<ov::Node> result;\n    if (use_bias && !is_scalar_zp) {\n        // Bias path: w * s + b (zp tensor holds f16 bias values)\n        auto bias_f16 = std::make_shared<ov::op::v0::Constant>(zp);\n        auto w_s = std::make_shared<ov::op::v1::Multiply>(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY);\n        result = std::make_shared<ov::op::v1::Add>(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY);\n    } else {\n        // Zero point path: (w - zp) * s\n        auto zero_point = std::make_shared<ov::op::v0::Constant>(zp);\n        float zp_value;\n        if (ov::op::util::get_single_value(zero_point, zp_value)) {\n            zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value});\n        }\n        auto zero_point_f16 = std::make_shared<ov::op::v0::Convert>(zero_point, ov::element::f16);\n        auto w_zp =\n            std::make_shared<ov::op::v1::Subtract>(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY);\n        result = std::make_shared<ov::op::v1::Multiply>(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY);\n    }\n\n    if (packed_shape.size() != 2) {\n        // If not requantized channel-wise case, reshape back to original shape\n        auto final_shape =\n            std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{orig_shape.size()}, 
orig_shape);\n        result = std::make_shared<ov::op::v1::Reshape>(result, final_shape, false);\n    }\n\n    return std::make_shared<ov::op::v0::Convert>(result, ov::element::f32);\n}\n\nov::Output<ov::Node> make_int4_weights(ov::Tensor & weight,\n                                       ov::Tensor & scales,\n                                       ov::Tensor & zp,\n                                       size_t group_size,\n                                       bool use_bias) {\n    ov::Shape orig_weight_shape = weight.get_shape();\n\n    // Expand dimensions for scales and zp/bias\n    ov::Shape scale_shape = scales.get_shape();\n    auto zp_shape = zp.get_shape();\n    bool is_scalar_zp = zp_shape.empty();  // Symmetric quantization\n\n    // Create INT4 weight tensor\n    ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size};\n\n    if (packed_shape[1] == 1) {\n        // Requantized channel-wise case\n        packed_shape.erase(packed_shape.begin() + 1);\n    } else {\n        scale_shape.push_back(1);\n        scales.set_shape(scale_shape);\n        // For symmetric quantization, zp remains scalar (don't resize)\n        if (!is_scalar_zp) {\n            zp_shape.push_back(1);\n            zp.set_shape(zp_shape);\n        }\n    }\n\n    auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::u4, packed_shape,\n                                                               static_cast<uint8_t *>(weight.data()), nullptr);\n    weights_node->get_rt_info()[\"__gguf_tensor_holder\"] = weight;\n    auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16);\n    auto scales_f16 = std::make_shared<ov::op::v0::Constant>(scales);\n\n    ov::Output<ov::Node> result;\n    if (use_bias && !is_scalar_zp) {\n        // Bias path: w * s + b (zp tensor holds f16 bias values)\n        auto bias_f16 = std::make_shared<ov::op::v0::Constant>(zp);\n        auto w_s = 
std::make_shared<ov::op::v1::Multiply>(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY);\n        result = std::make_shared<ov::op::v1::Add>(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY);\n    } else {\n        // Zero point path: (w - zp) * s\n        auto zero_points_node = std::make_shared<ov::op::v0::Constant>(zp);\n        float zp_value;\n        if (ov::op::util::get_single_value(zero_points_node, zp_value)) {\n            zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value});\n        }\n        auto zero_points_f16 = std::make_shared<ov::op::v0::Convert>(zero_points_node, ov::element::f16);\n        auto w_zp =\n            std::make_shared<ov::op::v1::Subtract>(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY);\n        result = std::make_shared<ov::op::v1::Multiply>(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY);\n    }\n\n    if (packed_shape.size() != 2) {\n        // If not requantized channel-wise case, reshape back to original shape\n        auto final_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{orig_weight_shape.size()},\n                                                                  orig_weight_shape);\n        result = std::make_shared<ov::op::v1::Reshape>(result, final_shape, false);\n    }\n\n    return std::make_shared<ov::op::v0::Convert>(result, ov::element::f32);\n}\n\n// Extract quantized weights from tensor and create weight subgraph\nstd::shared_ptr<ov::Node> extract_quantized_weights(const ggml_tensor * tensor,\n                                                    const void * data,\n                                                    ov::Tensor & weights,\n                                                    ov::Tensor & scales,\n                                                    ov::Tensor & zp,\n                                                    bool use_bias) {\n    // Create a temporary tensor for extraction functions 
that read from tensor->data\n    ggml_tensor temp_tensor = *tensor;\n    temp_tensor.data = const_cast<void *>(data);\n\n    // Determine block size based on tensor type\n    int64_t weights_per_block;\n    bool is_u4;\n    switch (tensor->type) {\n    case GGML_TYPE_Q4_0:\n    case GGML_TYPE_Q4_1:\n    case GGML_TYPE_Q4_K:\n        is_u4 = true;\n        weights_per_block = 32;\n        break;\n    case GGML_TYPE_Q8_0:\n    case GGML_TYPE_Q5_K:\n        is_u4 = false;\n        weights_per_block = 32;\n        break;\n    case GGML_TYPE_Q6_K:\n        is_u4 = false;\n        weights_per_block = 16;\n        break;\n    default:\n        throw std::runtime_error(\"Unsupported quantized type for extraction: \" +\n                                 std::string(ggml_type_name(tensor->type)));\n    }\n\n    // Extract quantized data\n    switch (tensor->type) {\n    case GGML_TYPE_Q4_0:\n        extract_q4_0_data(&temp_tensor, weights, scales, zp);\n        break;\n    case GGML_TYPE_Q4_1:\n        extract_q4_1_data(&temp_tensor, weights, scales, zp, use_bias);\n        break;\n    case GGML_TYPE_Q4_K:\n        extract_q4_k_data(&temp_tensor, weights, scales, zp, use_bias);\n        break;\n    case GGML_TYPE_Q8_0:\n        extract_q8_0_data(&temp_tensor, weights, scales, zp);\n        break;\n    case GGML_TYPE_Q6_K:\n        extract_q6_k_data(&temp_tensor, weights, scales, zp);\n        break;\n    case GGML_TYPE_Q5_K:\n        extract_q5_k_data(&temp_tensor, weights, scales, zp, use_bias);\n        break;\n    default:\n        throw std::runtime_error(\"Unsupported quantized type: \" + std::string(ggml_type_name(tensor->type)));\n    }\n\n    // Create the OpenVINO weight subgraph\n    ov::Output<ov::Node> weight_node;\n    if (is_u4) {\n        weight_node = make_int4_weights(weights, scales, zp, weights_per_block, use_bias);\n    } else {\n        weight_node = make_int8_weights(weights, scales, zp, weights_per_block, use_bias);\n    }\n\n    auto result = 
weight_node.get_node_shared_ptr();\n    result->set_friendly_name(tensor->name);\n    return result;\n}\n\n// Requantize weights to target format, writing to provided buffers\nstd::shared_ptr<ov::Node> requantize_to_buffers(const ggml_tensor * tensor,\n                                                const void * data,\n                                                ExtraQuantType requant_type,\n                                                int64_t block_size,\n                                                ov::Tensor & weights,\n                                                ov::Tensor & scales,\n                                                ov::Tensor & zp) {\n    int64_t n_elements = ggml_nelements(tensor);\n\n    // First dequantize to F32\n    std::vector<float> weights_f32(n_elements);\n    ggml_get_type_traits(tensor->type)->to_float(data, weights_f32.data(), n_elements);\n\n    // Handle F16 case - just convert and create constant\n    if (requant_type == ExtraQuantType::F16) {\n        ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), n_elements);\n        auto result = std::make_shared<ov::op::v0::Constant>(weights);\n        result->set_friendly_name(tensor->name);\n        return result;\n    }\n\n    // Requantize to target quantized format\n    bool is_u4 = (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128);\n\n    if (is_u4) {\n        quantize_q4_0(weights_f32.data(), weights, scales, zp, n_elements, block_size);\n    } else if (requant_type == ExtraQuantType::Q8_1_C) {\n        quantize_q8_1(weights_f32.data(), weights, scales, zp, n_elements, block_size);\n    } else {\n        quantize_q8_0(weights_f32.data(), weights, scales, zp, n_elements, block_size);\n    }\n\n    // Create the OpenVINO weight subgraph\n    ov::Output<ov::Node> weight_node;\n    if (is_u4) {\n        weight_node = make_int4_weights(weights, scales, zp, block_size);\n    } else {\n        
weight_node = make_int8_weights(weights, scales, zp, block_size);\n    }\n\n    auto result = weight_node.get_node_shared_ptr();\n    result->set_friendly_name(tensor->name);\n    return result;\n}\n\nOvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, void * output_base_ptr, bool use_bias) {\n    GGML_ASSERT(tensor != nullptr);\n    GGML_ASSERT(data != nullptr);\n\n    OvWeight result;\n\n    // Get 2D shape for weights [rows, cols]\n    ov::Shape node_shape = {static_cast<size_t>(tensor->ne[1]), static_cast<size_t>(tensor->ne[0])};\n\n    // Handle F16/F32/BF16 weights\n    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {\n        ov::element::Type element_type;\n        switch (tensor->type) {\n        case GGML_TYPE_F32:\n            element_type = ov::element::f32;\n            break;\n        case GGML_TYPE_F16:\n            element_type = ov::element::f16;\n            break;\n        case GGML_TYPE_BF16:\n            element_type = ov::element::bf16;\n            break;\n        default:\n            OPENVINO_THROW(\"Unexpected tensor type in F16/F32/BF16 path\");\n        }\n\n        if (output_base_ptr && output_base_ptr != data) {\n            // Using external buffer - copy data and create shared-memory constant\n            size_t tensor_bytes = ggml_nbytes(tensor);\n            memcpy(output_base_ptr, data, tensor_bytes);\n            result.weights = ov::Tensor(element_type, node_shape, output_base_ptr);\n        } else {\n            result.weights = ov::Tensor(element_type, node_shape, data);\n        }\n        result.weight_node = std::make_shared<ov::op::v0::Constant>(result.weights);\n        return result;\n    }\n\n    // Handle quantized weights\n    if (!ggml_is_quantized(tensor->type)) {\n        OPENVINO_THROW(\"Unsupported weight tensor type: \", ggml_type_name(tensor->type));\n    }\n\n    result.layout = ggml_openvino_get_extracted_layout(tensor, 
use_bias);\n    const auto & layout = result.layout;\n    if (layout.total_size == 0) {\n        OPENVINO_THROW(\"Unsupported quantized type: \", ggml_type_name(tensor->type));\n    }\n\n    if (use_bias) {\n        OPENVINO_ASSERT(!layout.is_requant,\n                        \"use_bias is only used for test-backend-ops, which should not have requantization\");\n        // bias node will be created on the fly and not use backend buffer\n        output_base_ptr = nullptr;\n    }\n\n    // F16 requant path - no separate scales/zp needed in result\n    if (layout.is_requant && layout.requant_type.has_value() && layout.requant_type.value() == ExtraQuantType::F16) {\n        if (output_base_ptr) {\n            result.weights = ov::Tensor(ov::element::f16, node_shape,\n                                        static_cast<uint8_t *>(output_base_ptr) + layout.weights_offset);\n        } else {\n            result.weights = ov::Tensor(ov::element::f16, node_shape);\n        }\n        ov::Tensor dummy_scales, dummy_zp;  // Not used for F16\n        result.weight_node =\n            requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, result.weights, dummy_scales, dummy_zp);\n        return result;\n    }\n\n    // Quantized path (normal extraction or quantized requant)\n    // Create weight/scale/zp tensors - shared between both paths\n    ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8;\n    ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block};\n    ov::Shape zp_shape = layout.is_symmetric ? 
ov::Shape{} : scale_shape;\n\n    if (output_base_ptr) {\n        uint8_t * buf_base = static_cast<uint8_t *>(output_base_ptr);\n        result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset);\n        result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset);\n        result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset);\n    } else {\n        result.weights = ov::Tensor(weight_type, node_shape);\n        result.scales = ov::Tensor(ov::element::f16, scale_shape);\n        if (use_bias && !layout.is_symmetric) {\n            // bias only has effect for asymmetric quant\n            result.zp = ov::Tensor(ov::element::f16, zp_shape);\n        } else {\n            result.zp = ov::Tensor(weight_type, zp_shape);\n        }\n    }\n\n    if (layout.is_requant && layout.requant_type.has_value()) {\n        result.weight_node = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block,\n                                                   result.weights, result.scales, result.zp);\n    } else {\n        result.weight_node =\n            extract_quantized_weights(tensor, data, result.weights, result.scales, result.zp, use_bias);\n    }\n\n    return result;\n}\n\nvoid quantize_q4_0(const float * x,\n                   ov::Tensor & weights_arr,\n                   ov::Tensor & scales_arr,\n                   ov::Tensor & zp_arr,\n                   int64_t k,\n                   int64_t qk) {\n    assert(k % qk == 0);\n    const int nb = k / qk;\n\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n    auto * zp = static_cast<uint8_t *>(zp_arr.data());\n    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization\n\n    // For Q4_0, zero point is always 8\n    if (is_scalar_zp) {\n        zp[0] = 8 | (8 << 4);  // Pack two 4-bit 
values\n    }\n\n    for (int i = 0; i < nb; i++) {\n        float amax = 0.0f;  // absolute max\n        float max = 0.0f;\n\n        for (int j = 0; j < qk; j++) {\n            const float v = x[i * qk + j];\n            if (amax < fabsf(v)) {\n                amax = fabsf(v);\n                max = v;\n            }\n        }\n\n        const float d = max / -8;\n\n        if (d == 0) {\n            scales[i] = ov::float16(1.0f);\n            // zp is already set to 8 for symmetric, or set per-block for asymmetric\n            if (!is_scalar_zp) {\n                if (i % 2 == 0) {\n                    zp[i / 2] = 8;\n                } else {\n                    zp[i / 2] |= (8 << 4);\n                }\n            }\n            memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2);\n            continue;\n        }\n\n        const float id = 1.0f / d;\n        scales[i] = ov::float16(d);\n        // For asymmetric quantization, store per-block zero points\n        if (!is_scalar_zp) {\n            if (i % 2 == 0) {\n                zp[i / 2] = 8;\n            } else {\n                zp[i / 2] |= (8 << 4);\n            }\n        }\n\n        for (int j = 0; j < qk / 2; ++j) {\n            const float x0 = x[i * qk + 2 * j] * id;\n            const float x1 = x[i * qk + 2 * j + 1] * id;\n            const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f));\n            const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f));\n            weights[i * qk / 2 + j] = xi0 | (xi1 << 4);\n        }\n    }\n}\n\nvoid quantize_q8_0(const float * x,\n                   ov::Tensor & weights_arr,\n                   ov::Tensor & scales_arr,\n                   ov::Tensor & zp_arr,\n                   int64_t k,\n                   int64_t qk) {\n    assert(k % qk == 0);\n    const int nb = k / qk;\n\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n    auto * zp = 
static_cast<uint8_t *>(zp_arr.data());\n    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization\n\n    // For Q8_0, zero point is always 128\n    if (is_scalar_zp) {\n        zp[0] = 128;\n    }\n\n    for (int i = 0; i < nb; i++) {\n        float amax = 0.0f;  // absolute max\n\n        for (int j = 0; j < qk; j++) {\n            const float v = x[i * qk + j];\n            if (amax < fabsf(v)) {\n                amax = fabsf(v);\n            }\n        }\n\n        const float d = amax / 127.0f;\n        const float id = d ? 1.0f / d : 0.0f;\n        scales[i] = ov::float16(d);\n        // For asymmetric quantization, store per-block zero points\n        if (!is_scalar_zp) {\n            zp[i] = 128;\n        }\n\n        for (int j = 0; j < qk; ++j) {\n            const float x0 = x[i * qk + j] * id;\n            const int8_t xi0 = roundf(x0);\n            weights[i * qk + j] = (uint8_t) (xi0 + 128);\n        }\n    }\n}\n\nvoid quantize_q8_1(const float * x,\n                   ov::Tensor & weights_arr,\n                   ov::Tensor & scales_arr,\n                   ov::Tensor & zp_arr,\n                   int64_t k,\n                   int64_t qk) {\n    assert(k % qk == 0);\n    const int nb = k / qk;\n\n    auto * weights = static_cast<uint8_t *>(weights_arr.data());\n    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();\n    auto * zp = static_cast<uint8_t *>(zp_arr.data());\n    for (int i = 0; i < nb; i++) {\n        float min = std::numeric_limits<float>::max();\n        float max = std::numeric_limits<float>::lowest();\n\n        for (int j = 0; j < qk; j++) {\n            const float v = x[i * qk + j];\n            if (v < min) {\n                min = v;\n            }\n            if (v > max) {\n                max = v;\n            }\n        }\n\n        const float d = (max - min) / ((1 << 8) - 1);\n        const float id = d ? 
1.0f / d : 0.0f;\n        scales[i] = ov::float16(d);\n        // zp = -min / scale (Q8_1 is asymmetric)\n        zp[i] = (d != 0.0f) ? (uint8_t) std::round(-min / d) : 0;\n\n        for (int j = 0; j < qk; ++j) {\n            const float x0 = (x[i * qk + j] - min) * id;\n            const uint8_t xi0 = roundf(x0);\n            weights[i * qk + j] = xi0;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-openvino/ggml-quants.h",
    "content": "#pragma once\n#include \"ggml-openvino-extra.h\"  // For ExtraQuantType\n#include \"ggml.h\"\n\n#include <cstdint>\n#include <openvino/op/constant.hpp>\n#include <openvino/runtime/tensor.hpp>\n\nvoid unpack_32_4(const uint8_t* data, uint8_t* dst);\n\nvoid extract_q4_0_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr);\n\nvoid extract_q4_1_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr,\n                       bool use_bias = false);\n\nvoid extract_q8_0_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr);\n\nvoid unpack_256_4(const uint8_t* data, uint8_t* dst);\n\nvoid extract_q4_k_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr,\n                       bool use_bias = false);\n\nvoid extract_q5_k_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr,\n                       bool use_bias = false);\n\nvoid extract_q6_k_data(const ggml_tensor * tensor,\n                       ov::Tensor & weights_arr,\n                       ov::Tensor & scales_arr,\n                       ov::Tensor & zp_arr);\n\nstatic constexpr size_t GGML_QUANTIZATION_GROUP_SIZE = 32;\n\nov::Output<ov::Node> make_int8_weights(ov::Tensor & weight,\n                                       ov::Tensor & scales,\n                                       ov::Tensor & zp,\n                                       size_t group_size = GGML_QUANTIZATION_GROUP_SIZE,\n                                       bool use_bias = false);\n\nov::Output<ov::Node> make_int4_weights(ov::Tensor & weight,\n                                       ov::Tensor & scales,\n                                       ov::Tensor & zp,\n                                       size_t group_size = GGML_QUANTIZATION_GROUP_SIZE,\n                                       bool use_bias = false);\n\n// Extract quantized weights from tensor and create weight subgraph\n// Writes into the caller-provided weights/scales/zp tensors; nothing is allocated here, so they\n// must already be allocated with the shapes and element types the tensor's quant type expects\n// (process_weight_tensor does this). scales/zp may be reshaped in place.\n// Returns the weight node (make_int4_weights or make_int8_weights result)\nstd::shared_ptr<ov::Node> extract_quantized_weights(\n    const ggml_tensor * tensor,\n    const void * data,  // Source data pointer (may differ from tensor->data)\n    ov::Tensor & weights,\n    ov::Tensor & scales,\n    ov::Tensor & zp,\n    bool use_bias = false);  // Use fp bias instead of quantized zero_point (for test-backend-ops)\n\n// Requantize weights from tensor to target format, writing to provided buffers\n// For F16 target, only weights buffer is used (scales/zp ignored)\n// Returns the weight node\nstd::shared_ptr<ov::Node> requantize_to_buffers(const ggml_tensor * tensor,\n                                                const void * data,  // Source data pointer\n                                                ExtraQuantType requant_type,\n                                                int64_t block_size,\n                                                ov::Tensor & weights,\n                                                ov::Tensor & scales,\n                                                ov::Tensor & zp);\n\ninline const char * extra_quant_type_name(ExtraQuantType t) {\n    switch (t) {\n    case ExtraQuantType::F16:\n        return \"F16\";\n    case ExtraQuantType::Q4_0_C:\n        return \"Q4_0_C\";\n    case ExtraQuantType::Q4_0_128:\n        return \"Q4_0_128\";\n    case ExtraQuantType::Q8_0_C:\n        return \"Q8_0_C\";\n    case ExtraQuantType::Q8_0_32:\n        return \"Q8_0_32\";\n    case ExtraQuantType::Q8_1_C:\n        return \"Q8_1_C\";\n    default:\n        return \"unknown\";\n    }\n}\n\n// Result from process_weight_tensor containing the weight node and tensors.\n// For quantized weights, also contains the extracted layout and scale/zp tensors.\nstruct OvWeight {\n    std::shared_ptr<ov::Node> weight_node;\n    ggml_openvino_extracted_layout layout;  // Only meaningful for quantized (layout.total_size > 0)\n    ov::Tensor weights;\n    ov::Tensor scales;\n    ov::Tensor zp;\n\n    bool is_quantized() const { return layout.scales_size > 0; }\n};\n\n// Process weight tensor and create an OpenVINO weight node\n// Handles F16/F32/BF16 and quantized weights, with optional requantization\n// If output_base_ptr is nullptr, allocates internal buffers (for decoder use)\n// If output_base_ptr is provided, uses pre-allocated buffers at specified offsets (for backend buffer use)\n// Returns OvWeight with the weight node and optional quantized tensors\nOvWeight process_weight_tensor(\n    const ggml_tensor * tensor,\n    const void * data,                 // Source data pointer (may differ from tensor->data)\n    void * output_base_ptr = nullptr,  // Base pointer for output buffers (or nullptr for internal allocation)\n    bool use_bias = false);            // Use fp bias instead of quantized zero_point, only used in test-backend-ops\n\nvoid quantize_q4_0(const float * x,\n                   ov::Tensor & weights_arr,\n                   ov::Tensor & scales_arr,\n                   ov::Tensor & zp_arr,\n                   int64_t k,\n                   int64_t qk);\nvoid quantize_q8_1(const float * x,\n                   ov::Tensor & weights_arr,\n                   ov::Tensor & scales_arr,\n                   ov::Tensor & zp_arr,\n                   int64_t k,\n                   int64_t qk);\nvoid quantize_q8_0(const float * x,\n                   ov::Tensor & weights_arr,\n                   ov::Tensor & scales_arr,\n                   ov::Tensor & zp_arr,\n                   int64_t k,\n                   int64_t qk);\n\nnamespace ov {\nnamespace op {\nnamespace util {\n// From <openvino>/src/common/transformations/include/transformations/utils/utils.hpp\nbool get_single_value(const std::shared_ptr<ov::op::v0::Constant>& const_node,\n                      float& value,\n                      bool check_value_range = true);\n}  // namespace util\n}  // namespace op\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/decoder.h",
    "content": "#pragma once\n\n#include <cstdint>\n#include <map>\n#include <openvino/core/node.hpp>\n#include <openvino/frontend/decoder.hpp>\n#include <string>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nclass GgmlDecoder : public DecoderBase {\npublic:\n    virtual ov::Any get_attribute(const std::string& name) const = 0;\n\n    virtual PartialShape get_input_shape(int node_idx, const std::string& name) const = 0;\n\n    virtual std::vector<size_t> get_input_stride(int node_idx, const std::string& name) const = 0;\n\n    virtual element::Type get_input_type(int node_idx, const std::string& name) const = 0;\n\n    virtual size_t get_input_size() const = 0;\n\n    virtual size_t get_input_size(int node_idx) const = 0;\n\n    virtual void get_input_node(size_t input_port_idx,\n                                std::string& producer_name,\n                                std::string& producer_output_port_name,\n                                size_t& producer_output_port_index) const = 0;\n\n    virtual std::vector<std::string> get_input_names(int node_idx) const = 0;\n\n    virtual PartialShape get_output_shape(int node_idx) const = 0;\n\n    virtual element::Type get_output_type(const int node_idx) const = 0;\n\n    virtual int32_t* get_input_op_params(int node_idx, const std::string& name) const = 0;\n\n    virtual int32_t * get_output_op_params(int node_idx) const = 0;\n\n    virtual std::vector<std::string> get_output_names(int node_idx) const = 0;\n\n    virtual const std::string& get_op_type() const = 0;\n\n    virtual const std::string& get_op_type(int node_idx) const = 0;\n\n    virtual const std::string& get_op_name() const = 0;\n\n    virtual const std::string& get_op_name(int node_idx) const = 0;\n\n    virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const = 0;\n\n    virtual int get_op_case(int node_idx) const = 0;\n\n    virtual const std::map<std::string, 
std::shared_ptr<ov::Node>>& get_model_inputs() const = 0;\n    virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_extra_inputs() const = 0;\n    virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_weights() const = 0;\n    virtual std::vector<std::string> get_model_output_names() const = 0;\n\n    virtual int32_t* get_rope_params() const = 0;\n\n    virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;\n\n    virtual bool is_static() const = 0;\n\n    virtual bool is_stateful() const = 0;\n\n    virtual int is_swa_layer(int layer) const = 0;\n};\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/frontend.cpp",
    "content": "#include \"frontend.h\"\n\n#include \"input_model.h\"\n#include \"op_table.h\"\n#include \"translate_session.h\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nFrontEnd::FrontEnd() {}\n\nstd::shared_ptr<Model> FrontEnd::convert(const InputModel::Ptr & model, bool naive) {\n    auto ggml_model = std::dynamic_pointer_cast<ggml::InputModel>(model);\n    FRONT_END_GENERAL_CHECK(ggml_model, \"Invalid input model\");\n    std::shared_ptr<Model> converted_model;\n    const auto & supported_ops = get_supported_ops();\n    {\n        TranslateSession translate_session(model, supported_ops, naive);\n        converted_model = translate_session.get_converted_model();\n    }\n    return converted_model;\n}\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/frontend.h",
    "content": "// Copyright (C) 2018-2024 Intel Corporation\n// SPDX-License-Identifier: Apache-2.0\n//\n\n#pragma once\n\n#include <openvino/frontend/frontend.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nclass FrontEnd {\npublic:\n    using Ptr = std::shared_ptr<FrontEnd>;\n    FrontEnd();\n\n    static std::shared_ptr<Model> convert(const InputModel::Ptr& model, bool naive = false);\n};\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/input_model.cpp",
    "content": "#include \"input_model.h\"\n\n#include \"decoder.h\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nInputModel::InputModel(const std::shared_ptr<GgmlDecoder> & gdecoder) : m_decoder(gdecoder) {}\n\nconst std::shared_ptr<GgmlDecoder> & InputModel::get_model_decoder() const {\n    return m_decoder;\n}\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/input_model.h",
    "content": "#pragma once\n\n#include <openvino/frontend/input_model.hpp>\n\n#include \"decoder.h\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nclass FrontEnd;\nclass GgmlDecoder;\nusing ov::frontend::ggml::GgmlDecoder;\n\nclass InputModel : public ov::frontend::InputModel {\n    friend class ::ov::frontend::ggml::FrontEnd;\n\npublic:\n    explicit InputModel(const std::shared_ptr<GgmlDecoder>& gdecoder);\n\n    const std::shared_ptr<GgmlDecoder>& get_model_decoder() const;\n\nprivate:\n    std::shared_ptr<GgmlDecoder> m_decoder;\n};\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/node_context.h",
    "content": "#pragma once\n\n#include <cstdint>\n#include <openvino/frontend/node_context.hpp>\n#include <string>\n\n#include \"decoder.h\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nclass TranslateSession;\n\ntypedef std::map<std::string, Output<Node>> TensorMap;\n\nclass NodeContext : public frontend::NodeContext {\npublic:\n    NodeContext(const std::shared_ptr<GgmlDecoder>& decoder,\n                std::shared_ptr<TensorMap>& tensor_map,\n                int node_idx,\n                TranslateSession* translate_session = nullptr)\n        : ov::frontend::NodeContext(decoder->get_op_type(node_idx)),\n          m_decoder(decoder),\n          m_tensor_map(tensor_map),\n          m_node_idx(node_idx),\n          m_translate_session(translate_session) {\n        m_input_names = decoder->get_input_names(m_node_idx);\n        m_output_names = decoder->get_output_names(m_node_idx);\n    }\n\n    TranslateSession* get_translate_session() const {\n        return m_translate_session;\n    }\n\n    const std::vector<std::string>& get_input_names() const { return m_input_names; }\n\n    size_t get_input_size() const override {\n        return m_decoder->get_input_size(m_node_idx);\n    }\n\n    ov::element::Type get_input_type(size_t index) const {\n        return m_decoder->get_input_type(m_node_idx, m_input_names[index]);\n    }\n\n    PartialShape get_input_shape(size_t input_index) const {\n        return m_decoder->get_input_shape(m_node_idx, m_input_names[input_index]);\n    }\n\n    std::vector<size_t> get_input_stride(size_t index) const {\n        return m_decoder->get_input_stride(m_node_idx, m_input_names[index]);\n    }\n\n    std::string get_output_name() const { return m_output_names[0]; }\n\n    PartialShape get_output_shape() const { return m_decoder->get_output_shape(m_node_idx); }\n\n    int32_t* get_input_op_params(size_t index) const {\n        return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]);\n    }\n\n    
int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); }\n\n    ov::element::Type get_output_type() const {\n        return m_decoder->get_output_type(m_node_idx);\n    }\n\n    Output<Node> get_input(int idx) const override {\n        return m_tensor_map->at(m_input_names[idx]);\n    }\n\n    Output<Node> get_input(const std::string& name) const override {\n        if (m_tensor_map->find(name) == m_tensor_map->end()) {\n            throw std::runtime_error(\"'\" + name + \"' not found in tensor map.\");\n        }\n        return m_tensor_map->at(name);\n    }\n\n    bool has_input(const std::string& name) const {\n        return m_tensor_map->find(name) != m_tensor_map->end();\n    }\n\n    const std::string& get_name() const override {\n        return m_decoder->get_op_name(m_node_idx);\n    }\n\n    ov::Any get_attribute_as_any(const std::string& name) const override {\n        return m_decoder->get_attribute(name);\n    }\n\n    int get_op_case() const {\n        return m_decoder->get_op_case(m_node_idx);\n    }\n\n    bool is_static() const { return m_decoder->is_static(); }\n\n    bool is_stateful() const { return m_decoder->is_stateful(); }\n\nprivate:\n    std::shared_ptr<GgmlDecoder> m_decoder;\n    std::shared_ptr<TensorMap>& m_tensor_map;\n    int m_node_idx;\n    TranslateSession* m_translate_session;\n    std::vector<std::string> m_input_names;\n    std::vector<std::string> m_output_names;\n};\n\nusing CreatorFunction = std::function<ov::OutputVector(const ov::frontend::ggml::NodeContext&)>;\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/cont.cpp",
    "content": "\n#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <climits>\n#include <cstdint>\n#include <memory>\n#include <openvino/op/reshape.hpp>\n#include <openvino/op/slice.hpp>\n#include <vector>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_cont(const NodeContext & context) {\n    num_inputs_check(context, 1, 1);\n\n    int op_case = context.get_op_case();\n    FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, \"Unsupported CONT case\");\n\n    auto src_shape = context.get_input_shape(0).to_shape();\n    auto dst_shape = context.get_output_shape().to_shape();\n    ov::Output<Node> res;\n\n    if (op_case == 1) {\n        // The input comes from a PERMUTE\n        throw std::runtime_error(\"Code of this case might be outdated\");\n        dst_shape[1] = -1;\n        res = std::make_shared<ov::op::v1::Reshape>(\n            context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false);\n    } else if (op_case == 2) {\n        // The input comes from a TRANSPOSE\n        return {context.get_input(0)};\n    } else {\n        // The input comes from a VIEW\n        res = process_view_input(context, 0);\n    }\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/cpy.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <memory>\n#include <openvino/op/convert.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_cpy(const NodeContext & context) {\n    auto res = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_output_type());\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/flash_attn_ext.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <cstdint>\n#include <memory>\n#include <openvino/op/broadcast.hpp>\n#include <openvino/op/concat.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/reshape.hpp>\n#include <openvino/op/scaled_dot_product_attention.hpp>\n#include <openvino/op/transpose.hpp>\n#include <openvino/op/unsqueeze.hpp>\n#include <string>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_flash_attn_ext(const NodeContext & context) {\n    num_inputs_check(context, 4, 4);\n    auto q_f32 = context.get_input(0);\n    auto k = context.get_input(1);\n    auto v = context.get_input(2);\n    auto mask = context.get_input(3);\n\n    float * params = reinterpret_cast<float *>(context.get_output_op_params());\n    float scale = params[0];\n    // float max_bias      = params[1];\n    // float logit_softcap = params[2];\n\n    auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);\n    auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});\n\n    ov::Output<ov::Node> mask_sliced, res;\n    std::string mask_name = \"KQ_mask_sliced\";\n    if (context.get_input_names()[3].find(\"swa\") != std::string::npos) {\n        mask_name = \"KQ_mask_swa_sliced\";\n    }\n    if (context.has_input(mask_name)) {\n        mask_sliced = context.get_input(mask_name);\n    } else {\n        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n        auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});\n        auto token_len = get_dimensions(q, {2});\n        mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, two);\n    }\n\n    if (mask_sliced.get_element_type() != ov::element::f16) {\n        
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);\n    }\n\n    auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<Node> kv) {\n        int64_t factor = num_heads / num_heads_kv;\n        if (factor > 1 && num_heads_kv > 1) {\n            ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;\n            auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});\n            kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);\n\n            kv_broadcast_shape = ov::op::v0::Constant::create(\n                ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});\n            new_kv_shape =\n                ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 0, num_heads, (int64_t) -1, head_size});\n\n            kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape,\n                                                         ov::op::BroadcastType::BIDIRECTIONAL);\n            kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, true);\n        }\n        return kv;\n    };\n\n    auto q_shape = context.get_input_shape(0).to_shape();\n    auto k_shape = context.get_input_shape(1).to_shape();\n    k = tile_kv(q_shape[1], k_shape[1], q_shape[3], k);\n    v = tile_kv(q_shape[1], k_shape[1], q_shape[3], v);\n\n    auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);\n    res = std::make_shared<ov::op::v1::Transpose>(sdpa,\n                                                  ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));\n    res = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/get_rows.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <openvino/core/node.hpp>\n#include <openvino/core/node_output.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/gather.hpp>\n#include <openvino/op/squeeze.hpp>\n#include <openvino/op/unsqueeze.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_get_rows(const NodeContext & context) {\n    num_inputs_check(context, 2, 2);\n\n    int op_case = context.get_op_case();\n\n    Output<Node> res;\n    auto data = context.get_input(0);\n    auto indices = context.get_input(1);\n\n    if (op_case == 2) {\n        // The input comes from a VIEW\n        indices = process_view_input(context, 1);\n    }\n\n    // data[1,b,x,y] ind[1,1,b,x'] test-backend-ops case\n    // data[x,y] ind[1,1,1,x'] normal case\n    indices =\n        std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));\n    if (data.get_partial_shape().rank() == 4) {\n        if (!(data.get_partial_shape()[1].is_dynamic()) && data.get_partial_shape()[1].get_length() == 1) {\n            // Work-around for a bug in ov cpu plugin for test-backend-ops\n            data = std::make_shared<ov::op::v0::Squeeze>(data,\n                                                         ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));\n            auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});\n            res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);\n        } else {\n            auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});\n            data =\n                std::make_shared<ov::op::v0::Squeeze>(data, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));\n            res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);\n        }\n    } else if 
(context.is_stateful() && data.get_partial_shape().rank() == 3) {\n        auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});\n        res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);\n    } else {\n        auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});\n        res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);\n    }\n\n    if (res.get_element_type() != context.get_output_type()) {\n        res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type());\n    }\n    if (!(context.is_stateful())) {\n        res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));\n    }\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/glu_geglu.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <memory>\n#include <openvino/core/node_output.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/gelu.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/sigmoid.hpp>\n#include <openvino/op/slice.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_glu_geglu(const NodeContext & context) {\n    num_inputs_check(context, 1, 2);\n\n    ov::Output<ov::Node> src0;\n    ov::Output<ov::Node> src1;\n    if (context.get_input_size() == 2) {\n        src0 = context.get_input(0);\n        src1 = context.get_input(1);\n    } else {\n        // GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2.\n        // Both halves are nc elements; if the dimension is odd, the last element is dropped.\n        // Use Slice instead of Split to handle odd dimensions correctly.\n        auto combined = context.get_input(0);\n        auto combined_shape = combined.get_partial_shape();\n        int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length();\n        int64_t nc = last_dim_val / 2;\n\n        auto axis   = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});\n        auto step   = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n        auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n        auto stop0  = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});\n        auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});\n        auto stop1  = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc});\n\n        src0 = std::make_shared<ov::op::v8::Slice>(combined, start0, stop0, step, axis);\n        src1 = std::make_shared<ov::op::v8::Slice>(combined, start1, stop1, step, axis);\n    }\n\n    int32_t * params = context.get_output_op_params();\n    const int32_t swapped = 
params[1];\n    if (swapped) {\n        std::swap(src0, src1);\n    }\n\n    auto gelu = std::make_shared<ov::op::v7::Gelu>(src0);\n    auto res = std::make_shared<ov::op::v1::Multiply>(gelu, src1);\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/glu_swiglu.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <cstdint>\n#include <memory>\n#include <openvino/core/node_output.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/sigmoid.hpp>\n#include <openvino/op/slice.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_glu_swiglu(const NodeContext & context) {\n    num_inputs_check(context, 1, 2);\n\n    ov::Output<ov::Node> src0;\n    ov::Output<ov::Node> src1;\n    if (context.get_input_size() == 2) {\n        src0 = context.get_input(0);\n        src1 = context.get_input(1);\n    } else {\n        // GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2.\n        // Both halves are nc elements; if the dimension is odd, the last element is dropped.\n        // Use Slice instead of Split to handle odd dimensions correctly.\n        auto combined = context.get_input(0);\n        auto combined_shape = combined.get_partial_shape();\n        int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length();\n        int64_t nc = last_dim_val / 2;\n\n        auto axis   = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});\n        auto step   = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n        auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n        auto stop0  = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});\n        auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});\n        auto stop1  = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc});\n\n        src0 = std::make_shared<ov::op::v8::Slice>(combined, start0, stop0, step, axis);\n        src1 = std::make_shared<ov::op::v8::Slice>(combined, start1, stop1, step, axis);\n    }\n\n    int32_t * params = context.get_output_op_params();\n    const int32_t swapped = params[1];\n   
 if (swapped) {\n        std::swap(src0, src1);\n    }\n\n    auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);\n    auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);\n    auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/mulmat.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <climits>\n#include <cstdint>\n#include <memory>\n#include <openvino/core/node.hpp>\n#include <openvino/core/node_output.hpp>\n#include <openvino/op/broadcast.hpp>\n#include <openvino/op/concat.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/matmul.hpp>\n#include <openvino/op/reshape.hpp>\n#include <openvino/op/slice.hpp>\n#include <openvino/op/transpose.hpp>\n#include <openvino/op/unsqueeze.hpp>\n#include <openvino/op/util/op_types.hpp>\n#include <vector>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_mulmat(const NodeContext & context) {\n    num_inputs_check(context, 2, 2);\n\n    int op_case = context.get_op_case();\n\n    ov::Output<Node> res;\n    ov::Output<ov::Node> B = context.get_input(0);\n    ov::Output<ov::Node> A = context.get_input(1);\n\n    bool transpose_b = true;\n    if (op_case == 2) {\n        B = B.get_node_shared_ptr()->input_value(0);\n        transpose_b = false;\n    } else if (op_case == 3) {\n        B = process_view_input(context, 0);\n        A = process_view_input(context, 1);\n    }\n    if (A.get_element_type() != B.get_element_type()) {\n        B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));\n    }\n\n    auto B_shape = context.get_input_shape(0).to_shape();\n    auto A_shape = context.get_input_shape(1).to_shape();\n    int64_t A_batch = A_shape[1];\n    int64_t B_batch = B_shape[1];\n\n    auto A_batch_larger = A_batch > B_batch;\n    auto batch_large = A_batch_larger ? A_batch : B_batch;\n    auto batch_small = A_batch_larger ? B_batch : A_batch;\n\n    Output<Node> Z = A_batch_larger ? 
B : A;\n    int64_t factor = batch_large / batch_small;\n    if (factor > 1 && batch_small > 1) {\n        auto batch_large_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{batch_large});\n        auto batch_small_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{batch_small});\n        auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});\n\n        auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});\n        auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);\n\n        auto broadcast_shape = ov::op::v0::Constant::create(\n            ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});\n        auto new_Z_shape = ov::op::v0::Constant::create(ov::element::i64, {4},\n                                                        {(int64_t) 0, batch_large, (int64_t) -1, (int64_t) A_shape[3]});\n\n        auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape,\n                                                                     ov::op::BroadcastType::BIDIRECTIONAL);\n        Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, true);\n    }\n    if (A_batch_larger) {\n        B = Z;\n    } else {\n        A = Z;\n    }\n\n    res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/permute.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <climits>\n#include <cstdint>\n#include <memory>\n#include <openvino/core/node.hpp>\n#include <openvino/op/add.hpp>\n#include <openvino/op/concat.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/reshape.hpp>\n#include <openvino/op/slice.hpp>\n#include <openvino/op/transpose.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_permute(const NodeContext & context) {\n    num_inputs_check(context, 1, 1);\n\n    int op_case = context.get_op_case();\n    FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4,\n                                \"Unsupported PERMUTE case\");\n\n    ov::Output<Node> res;\n    auto src = context.get_input(0);\n    auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3});\n\n    if (op_case == 1 || context.is_stateful()) {\n        res = std::make_shared<ov::op::v1::Transpose>(src, perm);\n    } else if (op_case == 4) {\n        auto output_shape = context.get_output_shape().to_shape();\n        auto n_heads = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[1]});\n        auto head_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});\n        auto n_seq_active = context.has_input(\"n_seq_active\") ?\n                                context.get_input(\"n_seq_active\") :\n                                ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[0]});\n        auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});\n\n        auto new_shape =\n            std::make_shared<ov::op::v0::Concat>(ov::OutputVector{n_seq_active, neg_one, n_heads, head_size}, 0);\n\n        // // Alternative\n        // auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n        // auto new_shape = 
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{n_seq_active, neg_one, zero, zero}, 0);\n\n        auto reshaped = std::make_shared<ov::op::v1::Reshape>(src, new_shape, true);\n        res = std::make_shared<ov::op::v1::Transpose>(reshaped, perm);\n    } else {\n        auto cache_shape = src.get_partial_shape();\n        auto output_shape = context.get_output_shape().to_shape();\n        int64_t head_size = output_shape[3];\n        int64_t n_heads = output_shape[1];\n        int64_t ctx_per_seq = cache_shape[2].is_static() ? cache_shape[2].get_length() : -1;\n        int64_t n_seq = cache_shape[1].get_length();\n\n        Output<Node> attention_size;\n        if (!context.has_input(\"attention_size\")) {\n            attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[2]});\n        } else if (op_case == 2) {\n            attention_size = context.get_input(\"attention_size\");\n        } else {\n            attention_size = context.get_input(\"attention_size_swa\");\n        }\n\n        Output<Node> seq_active_start;\n        Output<Node> seq_active_end;\n        if (context.has_input(\"seq_active_start\")) {\n            seq_active_start = context.get_input(\"seq_active_start\");\n            seq_active_end = context.get_input(\"seq_active_end\");\n        } else {\n            int64_t n_seq_active = output_shape[0];\n            size_t offset = *((size_t *) context.get_input_op_params(0));\n            int64_t seq_active_start_val = offset / context.get_input_stride(0)[0];\n            int64_t seq_active_end_val = seq_active_start_val + n_seq_active;\n            seq_active_start = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_start_val});\n            seq_active_end = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_end_val});\n        }\n\n        // 1. reshape to [n_seq, ctx_per_seq, n_heads, head_size]\n        // 2. slice out the active sequences\n        // 3. 
slice out the attention part in each sequence\n        // 4. permute\n        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n\n        auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(\n            src, ov::op::v0::Constant::create(ov::element::i64, {4}, {n_seq, ctx_per_seq, n_heads, head_size}), false);\n        auto slice1 = std::make_shared<ov::op::v8::Slice>(src_reshaped, seq_active_start, seq_active_end, one, zero);\n        auto slice2 = std::make_shared<ov::op::v8::Slice>(slice1, zero, attention_size, one, one);\n        res = std::make_shared<ov::op::v1::Transpose>(slice2, perm);\n    }\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/reshape.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <cstdint>\n#include <memory>\n#include <openvino/core/node.hpp>\n#include <openvino/core/node_output.hpp>\n#include <openvino/frontend/exception.hpp>\n#include <openvino/op/concat.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/reshape.hpp>\n#include <stdexcept>\n#include <vector>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_reshape(const NodeContext & context) {\n    num_inputs_check(context, 1, 1);\n    if (context.get_input_shape(0) == context.get_output_shape()) {\n        return {context.get_input(0)};\n    }\n\n    int op_case = context.get_op_case();\n    FRONT_END_CHECK_IMPLEMENTED(\n        op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4 || op_case == 5 || op_case == 6,\n        \"Unsupported RESHAPE case\");\n\n    auto output_shape = context.get_output_shape().to_shape();\n    std::shared_ptr<ov::Node> new_shape_node;\n    if (op_case == 1) {\n        if (context.is_stateful()) {\n            new_shape_node = ov::op::v0::Constant::create(\n                ov::element::i64, {3},\n                std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});\n        } else {\n            new_shape_node = ov::op::v0::Constant::create(\n                ov::element::i64, {4},\n                std::vector<int64_t>{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});\n        }\n    } else if (op_case == 2) {\n        new_shape_node = ov::op::v0::Constant::create(\n            ov::element::i64, {4},\n            std::vector<int64_t>{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, (int64_t) output_shape[3]});\n\n    } else if (op_case == 3) {\n        throw std::runtime_error(\"might be outdated RESHAPE case\");\n        new_shape_node = ov::op::v0::Constant::create(\n            ov::element::i64, {4}, 
std::vector<int64_t>{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, 1});\n\n    } else if (op_case == 4) {\n        return {context.get_input(0).get_node_shared_ptr()->input_value(0)};\n\n    } else if (op_case == 5) {\n        if (context.is_stateful()) {\n            std::vector<int64_t> shape_vec = {1, -1, (int64_t) context.get_output_shape().to_shape()[3]};\n            new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {3}, shape_vec);\n        } else {\n            std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]};\n            new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);\n        }\n\n        // // Alternative\n        // auto token_len = context.get_input(\"token_len\");\n        // auto emb_size =\n        //     ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape().to_shape()[3]});\n        // auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n        // new_shape_node = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, one, token_len, emb_size}, 0);\n\n    } else if (op_case == 6) {\n        new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape().to_shape());\n    }\n    auto res = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), new_shape_node, false);\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/rms_norm.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <memory>\n#include <openvino/op/add.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/divide.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/power.hpp>\n#include <openvino/op/reduce_mean.hpp>\n#include <openvino/op/sqrt.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_rms_norm(const NodeContext & context) {\n    num_inputs_check(context, 1, 1);\n\n    auto input_node = context.get_input(0);\n    auto square = std::make_shared<ov::op::v1::Power>(\n        input_node, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {2.0f}));\n\n    auto mean = std::make_shared<ov::op::v1::ReduceMean>(\n        square, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true);\n\n    float eps;\n    memcpy(&eps, context.get_output_op_params(), sizeof(float));\n\n    auto rms = std::make_shared<ov::op::v0::Sqrt>(\n        std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {eps})));\n\n    auto reciprocal =\n        std::make_shared<ov::op::v1::Divide>(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}), rms);\n\n    auto res = std::make_shared<ov::op::v1::Multiply>(input_node, reciprocal);\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/rope.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <cstdint>\n#include <memory>\n#include <openvino/core/node.hpp>\n#include <openvino/core/node_output.hpp>\n#include <openvino/op/add.hpp>\n#include <openvino/op/concat.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/reshape.hpp>\n#include <openvino/op/shape_of.hpp>\n#include <openvino/op/slice.hpp>\n#include <openvino/op/split.hpp>\n#include <openvino/op/subtract.hpp>\n#include <openvino/op/unsqueeze.hpp>\n#include <vector>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_rope(const NodeContext & context) {\n    num_inputs_check(context, 2, 3);\n\n    int op_case = context.get_op_case();\n\n    ov::Output<Node> res;\n\n    auto data_node = context.get_input(0).get_node_shared_ptr();\n    auto output_shape = context.get_output_shape().to_shape();\n    int32_t * op_params = context.get_output_op_params();\n\n    Output<Node> cos_theta_node;\n    Output<Node> sin_theta_node;\n    if (context.has_input(\"rope_cos\")) {\n        cos_theta_node = context.get_input(\"rope_cos\");\n        sin_theta_node = context.get_input(\"rope_sin\");\n    } else {\n        auto inp_pos = context.get_input(1).get_node_shared_ptr();\n        std::shared_ptr<ov::Node> rope_freqs_weight;\n        if (context.get_input_size() == 3) {\n            rope_freqs_weight = context.get_input(2).get_node_shared_ptr();\n        }\n        auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight);\n        sin_theta_node = sin_cos.first;\n        cos_theta_node = sin_cos.second;\n    }\n\n    if (op_case == 2) {\n        // The input comes from a VIEW\n        int slice_len = output_shape[2] * output_shape[3];\n        data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr();\n        if (context.is_stateful()) {\n            auto data_shape = 
ov::op::v0::Constant::create(\n                ov::element::i64, {3}, std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});\n            data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);\n        } else {\n            auto data_shape = ov::op::v0::Constant::create(\n                ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});\n            data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);\n        }\n    }\n\n    const int mode = op_params[2];\n    constexpr int ROPE_TYPE_NORMAL = 0;\n    constexpr int ROPE_TYPE_NEOX = 2;\n\n    if (mode == ROPE_TYPE_NORMAL) {\n        auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});\n        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n        auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});\n        auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});\n        Output<Node> even_slice;\n        Output<Node> odd_slice;\n        int32_t unsqueeze_dim = context.is_stateful() ? 
3 : 4;\n        even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, neg_one);\n        odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, neg_one);\n\n        Output<Node> first_half =\n            std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),\n                                                   std::make_shared<ov::op::v1::Multiply>(odd_slice, sin_theta_node));\n        Output<Node> second_half =\n            std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(even_slice, sin_theta_node),\n                                              std::make_shared<ov::op::v1::Multiply>(odd_slice, cos_theta_node));\n\n        first_half = std::make_shared<ov::op::v0::Unsqueeze>(first_half,\n                                                             ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim}));\n        second_half = std::make_shared<ov::op::v0::Unsqueeze>(second_half,\n                                                              ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim}));\n        auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, unsqueeze_dim);\n\n        auto data_shape = ov::op::v0::Constant::create(\n            ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});\n        res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);\n    } else if (mode == ROPE_TYPE_NEOX) {\n        auto data_split = std::make_shared<ov::op::v1::Split>(\n            data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2);\n        Output<Node> slice_data_node_0 = data_split->outputs()[0];\n        Output<Node> slice_data_node_1 = data_split->outputs()[1];\n\n        auto first_half_node = std::make_shared<ov::op::v1::Subtract>(\n            
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, cos_theta_node),\n            std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, sin_theta_node));\n\n        auto second_half_node = std::make_shared<ov::op::v1::Add>(\n            std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),\n            std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));\n\n        res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1);\n    }\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/scale.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <openvino/op/add.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/multiply.hpp>\n#include <vector>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_scale(const NodeContext & context) {\n    num_inputs_check(context, 1, 1);\n\n    float scale;\n    float bias;\n    memcpy(&scale, (float *) context.get_output_op_params() + 0, sizeof(float));\n    memcpy(&bias, (float *) context.get_output_op_params() + 1, sizeof(float));\n\n    auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});\n    auto scaled = std::make_shared<ov::op::v1::Multiply>(context.get_input(0), scale_node);\n\n    std::shared_ptr<ov::Node> res;\n    if (bias != 0.0f) {\n        auto bias_node =\n            std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{bias});\n        res = std::make_shared<ov::op::v1::Add>(scaled, bias_node);\n    } else {\n        res = scaled;\n    }\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/set_rows.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <cassert>\n#include <cstdint>\n#include <memory>\n#include <openvino/core/node.hpp>\n#include <openvino/core/node_output.hpp>\n#include <openvino/frontend/exception.hpp>\n#include <openvino/op/concat.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/gather.hpp>\n#include <openvino/op/reshape.hpp>\n#include <openvino/op/scatter_update.hpp>\n#include <openvino/op/shape_of.hpp>\n#include <openvino/op/slice.hpp>\n#include <openvino/op/squeeze.hpp>\n#include <openvino/op/transpose.hpp>\n#include <vector>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_set_rows(const NodeContext & context) {\n    num_inputs_check(context, 3, 3);\n\n    auto data = context.get_input(0);\n    auto indices = context.get_input(1);\n    auto dst = context.get_input(2);\n\n    data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type());\n\n    auto dst_shape = context.get_output_shape().to_shape();\n\n    auto ind_squeezed =\n        std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 1, 2}));\n    auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(\n        data,\n        ov::op::v0::Constant::create(ov::element::i64, {4},\n                                     {(int64_t) 1, (int64_t) 1, (int64_t) -1, (int64_t) dst_shape[3]}),\n        false);\n    auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});\n\n    Output<Node> res;\n    if (context.is_stateful()) {\n        int concat_axis = 1;\n        int64_t dim2 = dst.get_partial_shape()[2].get_length();\n        int64_t dim3 = dst.get_partial_shape()[3].get_length();\n        data = std::make_shared<ov::op::v1::Reshape>(\n            data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), 
false);\n        res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, concat_axis);\n    } else {\n        res = std::make_shared<ov::op::v3::ScatterUpdate>(dst, ind_squeezed, data_reshaped, axes);\n    }\n\n    if (auto dst_reshape = std::dynamic_pointer_cast<ov::op::v1::Reshape>(dst.get_node_shared_ptr())) {\n        // Fix the case of multiple sequences, reshape back to original shape [1, n_seq, ctx_per_seq, emb]\n        // ctx_per_seq is not fixed due to llama-bench compatibility\n        auto dst_shape_partial = dst_reshape->get_input_partial_shape(0);\n        std::vector<int64_t> dst_shape = {dst_shape_partial[0].get_length(), dst_shape_partial[1].get_length(),\n                                          dst_shape_partial[2].is_static() ? dst_shape_partial[2].get_length() : -1,\n                                          dst_shape_partial[3].get_length()};\n        res = std::make_shared<ov::op::v1::Reshape>(res, ov::op::v0::Constant::create(ov::element::i64, {4}, dst_shape),\n                                                    false);\n    }\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/softmax.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <climits>\n#include <cstdint>\n#include <memory>\n#include <openvino/core/node.hpp>\n#include <openvino/core/node_output.hpp>\n#include <openvino/op/add.hpp>\n#include <openvino/op/concat.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/matmul.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/slice.hpp>\n#include <openvino/op/softmax.hpp>\n#include <vector>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_soft_max(const NodeContext & context) {\n    // TODO code is outdated\n    num_inputs_check(context, 1, 2);\n\n    auto input_node = context.get_input(0).get_node_shared_ptr();\n    ov::Output<Node> res;\n\n    float scale = 1.0f;\n    float max_bias = 0.0f;\n    auto * op_params = context.get_output_op_params();\n    memcpy(&scale, (float *) op_params + 0, sizeof(float));\n    memcpy(&max_bias, (float *) op_params + 1, sizeof(float));\n    auto src0_shape = context.get_input_shape(0).get_shape();\n    const uint32_t h = src0_shape[2];\n    const uint32_t n_head = src0_shape[0];\n    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));\n\n    const float m0 = powf(2.0f, -(max_bias) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n    const float slope =\n        (max_bias > 0.0f) ? h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;\n\n    auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});\n    auto scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);\n\n    if (context.get_input_size() < 2) {\n        res = std::make_shared<ov::op::v8::Softmax>(scaled_input, 2);\n        return rename_outputs_with_suffix({res}, context.get_name());\n    }\n\n    ov::Output<ov::Node> mask_node_sliced;\n    if (context.has_input(\"KQ_mask_sliced\")) {\n        mask_node_sliced = context.get_input(\"KQ_mask_sliced\");\n    } else {\n        auto token_len = get_dimensions(input_node, {1});\n        auto mask_node = context.get_input(1);\n        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n        mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);\n    }\n\n    if (mask_node_sliced.get_element_type() != context.get_output_type()) {\n        mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type());\n    }\n\n    Output<Node> slope_mask;\n    if (slope != 1.0f) {\n        auto slope_node =\n            std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{slope});\n        slope_mask = std::make_shared<ov::op::v1::Multiply>(mask_node_sliced, slope_node);\n        throw std::runtime_error(\"Slope != 1.0f in softmax has not been tested, verify it before use.\");\n    }\n    slope_mask = mask_node_sliced;\n\n    auto input_slope_mask_node = std::make_shared<ov::op::v1::Add>(scaled_input, slope_mask);\n\n    res = std::make_shared<ov::op::v8::Softmax>(input_slope_mask_node, 2);\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/transpose.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <openvino/op/transpose.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_transpose(const NodeContext & context) {\n    num_inputs_check(context, 1, 1);\n\n    auto res = std::make_shared<ov::op::v1::Transpose>(\n        context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 1, 3, 2}));\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/unary_silu.cpp",
    "content": "#include \"../node_context.h\"\n#include \"../op_table.h\"\n#include \"../utils.h\"\n\n#include <openvino/core/node_output.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/sigmoid.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_unary_silu(const NodeContext & context) {\n    num_inputs_check(context, 1, 1);\n\n    auto input = context.get_input(0);\n    auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(input);\n    auto res = std::make_shared<ov::op::v1::Multiply>(input, sigmoid);\n\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op/view.cpp",
    "content": "#include \"../op_table.h\"\n#include \"../utils.h\"\n#include <openvino/op/reshape.hpp>\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace op {\n\nOutputVector translate_view(const NodeContext & context) {\n    num_inputs_check(context, 1, 1);\n\n    if (context.get_op_case() == 2) {\n        auto dst_shape = context.get_output_shape().to_shape();\n        return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[2] * dst_shape[3])},\n                                          context.get_name());\n    }\n    // op_case 3\n    if (context.get_op_case() == 3) {\n        auto input = context.get_input(0);\n        auto input_ov_shape = input.get_partial_shape();\n\n        auto input_llama_shape = context.get_input_shape(0).to_shape();\n\n        // if the input ov shape size is different from the input llama shape size, it means the input is already reshaped and we need to reshape it back to the original shape before slicing\n        if (input_ov_shape.size() != input_llama_shape.size()) {\n            input = std::make_shared<ov::op::v1::Reshape>(input, ov::op::v0::Constant::create(ov::element::i64, {input_llama_shape.size()}, input_llama_shape), false);\n        }\n\n        auto dst_shape = context.get_output_shape().to_shape();\n\n        // find the index of dst_shape that is different from input shape, and use that index to slice the input\n        int slice_dim = -1;\n        for (size_t i = 0; i < dst_shape.size(); ++i) {\n            if (dst_shape[i] != input_llama_shape[i]) {\n                slice_dim = i;\n                break;\n            }\n        }\n\n        auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n        auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {dst_shape[slice_dim]});\n        auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n        auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_dim});\n        
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);\n        return {sliced};\n    }\n    return {context.get_input(0)};\n}\n\n}  // namespace op\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op_table.cpp",
    "content": "#include \"op_table.h\"\n\n#include \"utils.h\"\n\n#include <openvino/op/add.hpp>\n#include <openvino/op/divide.hpp>\n#include <openvino/op/gather.hpp>\n#include <openvino/op/matmul.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/subtract.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nstd::unordered_map<std::string, CreatorFunction> get_supported_ops() {\n    using namespace ov::op;\n    return {\n        {\"GGML_OP_ADD\",            op::translate_1to1_match_2_inputs<v1::Add>     },\n        {\"GGML_OP_ADD1\",           op::translate_1to1_match_2_inputs<v1::Add>     },\n        {\"GGML_OP_CONT\",           op::translate_cont                             },\n        {\"GGML_OP_DIV\",            op::translate_1to1_match_2_inputs<v1::Divide>  },\n        {\"GGML_OP_GET_ROWS\",       op::translate_get_rows                         },\n        {\"GGML_OP_MUL\",            op::translate_1to1_match_2_inputs<v1::Multiply>},\n        {\"GGML_OP_MUL_MAT\",        op::translate_mulmat                           },\n        {\"GGML_OP_PERMUTE\",        op::translate_permute                          },\n        {\"GGML_OP_RESHAPE\",        op::translate_reshape                          },\n        {\"GGML_OP_RMS_NORM\",       op::translate_rms_norm                         },\n        {\"GGML_OP_ROPE\",           op::translate_rope                             },\n        {\"GGML_OP_SCALE\",          op::translate_scale                            },\n        {\"GGML_OP_SOFT_MAX\",       op::translate_soft_max                         },\n        {\"GGML_OP_SUB\",            op::translate_1to1_match_2_inputs<v1::Subtract>},\n        {\"GGML_OP_TRANSPOSE\",      op::translate_transpose                        },\n        {\"GGML_UNARY_OP_SILU\",     op::translate_unary_silu                       },\n        {\"GGML_OP_VIEW\",           op::translate_view                             },\n        {\"GGML_GLU_OP_SWIGLU\",     
op::translate_glu_swiglu                       },\n        {\"GGML_GLU_OP_GEGLU\",      op::translate_glu_geglu                        },\n        {\"GGML_OP_SET_ROWS\",       op::translate_set_rows                         },\n        {\"GGML_OP_CPY\",            op::translate_cpy                              },\n        {\"GGML_OP_FLASH_ATTN_EXT\", op::translate_flash_attn_ext                   },\n    };\n}\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/op_table.h",
    "content": "#pragma once\n\n#include \"node_context.h\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nnamespace op {\n\n#define GGML_OP_CONVERTER(op) OutputVector op(const NodeContext& context)\n\nGGML_OP_CONVERTER(translate_add);\nGGML_OP_CONVERTER(translate_cont);\nGGML_OP_CONVERTER(translate_get_rows);\nGGML_OP_CONVERTER(translate_mul);\nGGML_OP_CONVERTER(translate_mulmat);\nGGML_OP_CONVERTER(translate_permute);\nGGML_OP_CONVERTER(translate_reshape);\nGGML_OP_CONVERTER(translate_rms_norm);\nGGML_OP_CONVERTER(translate_rope);\nGGML_OP_CONVERTER(translate_scale);\nGGML_OP_CONVERTER(translate_unary_silu);\nGGML_OP_CONVERTER(translate_soft_max);\nGGML_OP_CONVERTER(translate_transpose);\nGGML_OP_CONVERTER(translate_view);\nGGML_OP_CONVERTER(translate_glu_swiglu);\nGGML_OP_CONVERTER(translate_glu_geglu);\nGGML_OP_CONVERTER(translate_set_rows);\nGGML_OP_CONVERTER(translate_cpy);\nGGML_OP_CONVERTER(translate_flash_attn_ext);\n\n} // namespace op\n\nstd::unordered_map<std::string, CreatorFunction> get_supported_ops();\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/pass/eliminate_zp.cpp",
    "content": "#include \"eliminate_zp.h\"\n\n#include <openvino/core/graph_util.hpp>\n#include <openvino/core/parallel.hpp>\n#include <openvino/core/rt_info.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/subtract.hpp>\n#include <openvino/pass/pattern/op/label.hpp>\n#include <openvino/pass/pattern/op/pattern.hpp>\n#include <openvino/pass/pattern/op/wrap_type.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace pass {\n\nEliminateZeroPoints::EliminateZeroPoints() {\n    // Find pattern:\n    // (Multiply Any(scale)\n    //           (Subtract (Convert Constant(data)))\n    //                     (Convert Constant(zero_point)))\n    // where zero_point is a scalar\n    // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val\n    // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant\n\n    auto m_data_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();\n    auto m_data_convert = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_data_constant});\n\n    auto m_zp_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();\n    auto m_zp_convert = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_zp_constant});\n\n    auto m_subtract = ov::pass::pattern::wrap_type<ov::op::v1::Subtract>({m_data_convert, m_zp_convert});\n    auto m_scale = ov::pass::pattern::any_input();\n    auto m_multiply = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_scale, m_subtract});\n\n    const auto callback = [=](ov::pass::pattern::Matcher & m) {\n        const auto & pattern_map = m.get_pattern_value_map();\n\n        auto multiply_node =\n            std::dynamic_pointer_cast<ov::op::v1::Multiply>(pattern_map.at(m_multiply).get_node_shared_ptr());\n        auto subtract_node =\n            
std::dynamic_pointer_cast<ov::op::v1::Subtract>(pattern_map.at(m_subtract).get_node_shared_ptr());\n        auto data_constant =\n            std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(m_data_constant).get_node_shared_ptr());\n        auto zp_constant =\n            std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(m_zp_constant).get_node_shared_ptr());\n\n        if (!multiply_node || !subtract_node || !data_constant || !zp_constant) {\n            return false;\n        }\n\n        if (ov::shape_size(zp_constant->get_shape()) != 1) {\n            return false;\n        }\n\n        auto data_type = data_constant->get_element_type();\n        auto zp_data = zp_constant->cast_vector<int>();\n\n        if (zp_data.empty()) {\n            return false;\n        }\n\n        int zp_value = zp_data[0];\n\n        bool should_eliminate = false;\n        ov::element::Type target_type;\n\n        if (data_type == ov::element::u4 && zp_value == 8) {\n            should_eliminate = true;\n            target_type = ov::element::i4;\n        } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) {\n            should_eliminate = true;\n            target_type = ov::element::i8;\n        }\n\n        if (!should_eliminate) {\n            return false;\n        }\n\n        auto data_shape = data_constant->get_shape();\n        size_t total_elements = ov::shape_size(data_shape);\n\n        std::shared_ptr<ov::op::v0::Constant> new_constant;\n\n        // TODO improve performance\n        if (data_type == ov::element::u4) {\n            auto data_values = data_constant->cast_vector<uint8_t>();\n            std::vector<int8_t> adjusted_values(total_elements);\n\n            ov::parallel_for(total_elements, [&](size_t i) {\n                adjusted_values[i] = static_cast<int8_t>(static_cast<int>(data_values[i]) - 8);\n            });\n\n            new_constant = std::make_shared<ov::op::v0::Constant>(target_type, 
data_shape, adjusted_values);\n        } else if (data_type == ov::element::u8) {\n            auto data_values = data_constant->cast_vector<uint8_t>();\n            std::vector<int8_t> adjusted_values(total_elements);\n\n            ov::parallel_for(total_elements, [&, zp_value](size_t i) {\n                adjusted_values[i] = static_cast<int8_t>(static_cast<int>(data_values[i]) - zp_value);\n            });\n\n            new_constant = std::make_shared<ov::op::v0::Constant>(target_type, data_shape, adjusted_values);\n        }\n\n        auto new_convert =\n            std::make_shared<ov::op::v0::Convert>(new_constant, subtract_node->get_output_element_type(0));\n        ov::replace_node(subtract_node, new_convert);\n\n        return true;\n    };\n\n    register_matcher(\n        std::make_shared<ov::pass::pattern::Matcher>(m_multiply, \"ov::frontend::ggml::pass::EliminateZeroPoints\"),\n        callback);\n}\n\n}  // namespace pass\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/pass/eliminate_zp.h",
    "content": "#include \"openvino/pass/matcher_pass.hpp\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace pass {\n\nclass EliminateZeroPoints : public ov::pass::MatcherPass {\npublic:\n    OPENVINO_MATCHER_PASS_RTTI(\"ov::frontend::ggml::pass::EliminateZeroPoints\")\n    EliminateZeroPoints();\n};\n\n}  // namespace pass\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp",
    "content": "#include \"fuse_to_sdpa.h\"\n\n#include <openvino/core/graph_util.hpp>\n#include <openvino/core/rt_info.hpp>\n#include <openvino/op/add.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/matmul.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/scaled_dot_product_attention.hpp>\n#include <openvino/op/softmax.hpp>\n#include <openvino/op/transpose.hpp>\n#include <openvino/pass/pattern/op/label.hpp>\n#include <openvino/pass/pattern/op/pattern.hpp>\n#include <openvino/pass/pattern/op/wrap_type.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace pass {\n\nFuseToSDPA::FuseToSDPA() {\n    // Not maintained since FLASH_ATTN_EXT has replaced this pattern\n    const auto m_k = ov::pass::pattern::any_input();\n    const auto m_q = ov::pass::pattern::any_input();\n    const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});\n    const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});\n    const auto m_scale = ov::pass::pattern::any_input();\n    const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});\n    const auto m_mask = ov::pass::pattern::any_input();\n    const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});\n    const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});\n    const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});\n    const auto m_v = ov::pass::pattern::any_input();\n    const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});\n\n    const auto callback = [=](ov::pass::pattern::Matcher & m) {\n        auto & pattern_to_output = m.get_pattern_value_map();\n        auto k = pattern_to_output[m_k];\n        auto q = pattern_to_output[m_q];\n        auto v = pattern_to_output[m_v];\n        auto mask = pattern_to_output[m_mask];\n        auto scale = 
pattern_to_output[m_scale];\n\n        auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);\n        auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);\n        auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_f16, scale_f16, false);\n\n        ov::replace_node(m.get_match_root(), sdpa);\n        ov::copy_runtime_info(m.get_matched_nodes(), sdpa);\n\n        return true;\n    };\n    register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_qkv, \"ov::frontend::ggml::pass::FuseToSDPA\"),\n                     callback);\n}\n\n}  // namespace pass\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/pass/fuse_to_sdpa.h",
    "content": "#pragma once\n\n#include \"openvino/pass/matcher_pass.hpp\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace pass {\n\n// Matcher pass that rewrites MatMul(Q, K) -> Convert -> Multiply(scale) -> Add(mask) -> Softmax -> Convert ->\n// MatMul(V) into a single ov::op::v13::ScaledDotProductAttention node.\n// Not maintained since FLASH_ATTN_EXT has replaced this pattern.\nclass FuseToSDPA : public ov::pass::MatcherPass {\npublic:\n    OPENVINO_MATCHER_PASS_RTTI(\"ov::frontend::ggml::pass::FuseToSDPA\")\n    FuseToSDPA();\n};\n\n}  // namespace pass\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h",
    "content": "#pragma once\n\n#include \"openvino/pass/matcher_pass.hpp\"\n#include \"openvino/core/visibility.hpp\"\n\n// Export/import macro for the pass declared below: empty for static builds, OPENVINO_CORE_EXPORTS when\n// IMPLEMENT_OPENVINO_API is defined, OPENVINO_CORE_IMPORTS otherwise.\n#ifdef OPENVINO_STATIC_LIBRARY\n#    define TRANSFORMATIONS_API\n#else\n#    ifdef IMPLEMENT_OPENVINO_API\n#        define TRANSFORMATIONS_API OPENVINO_CORE_EXPORTS\n#    else\n#        define TRANSFORMATIONS_API OPENVINO_CORE_IMPORTS\n#    endif  // IMPLEMENT_OPENVINO_API\n#endif      // OPENVINO_STATIC_LIBRARY\n\nnamespace ov {\nnamespace pass {\n\nclass TRANSFORMATIONS_API MarkCompressedFloatConstants;\n\n}  // namespace pass\n}  // namespace ov\n\nclass ov::pass::MarkCompressedFloatConstants : public MatcherPass {\npublic:\n    OPENVINO_MATCHER_PASS_RTTI(\"MarkCompressedFloatConstants\")\n    MarkCompressedFloatConstants();\n};\n"
  },
  {
    "path": "src/ggml-openvino/openvino/pass/squeeze_matmul.cpp",
    "content": "#include \"squeeze_matmul.h\"\n\n#include <openvino/core/graph_util.hpp>\n#include <openvino/core/rt_info.hpp>\n#include <openvino/op/constant.hpp>\n#include <openvino/op/matmul.hpp>\n#include <openvino/op/squeeze.hpp>\n#include <openvino/op/unsqueeze.hpp>\n#include <openvino/pass/pattern/op/label.hpp>\n#include <openvino/pass/pattern/op/pattern.hpp>\n#include <openvino/pass/pattern/op/wrap_type.hpp>\n\nnamespace opp = ov::pass::pattern;\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace pass {\n\n// For quantized models, NPUW expects the activation to be 3d in DQ(DynamicQuantization) opt, e.g. DQMatMulGQ2i\n// This pass rewrites MatMul(act, wei), where act is 4d with a leading dimension statically equal to 1 and wei is\n// 2d, as Unsqueeze(MatMul(Squeeze(act, 0), wei), 0): the MatMul sees a 3d activation while consumers still see\n// the original 4d output.\nSqueezeMatmul::SqueezeMatmul() {\n    auto m_act = opp::any_input();\n    auto m_wei = opp::any_input();\n    auto m_matmul = opp::wrap_type<ov::op::v0::MatMul>({m_act, m_wei});\n\n    const auto callback = [=](ov::pass::pattern::Matcher & m) {\n        const auto & pattern_map = m.get_pattern_value_map();\n        auto matmul_node =\n            std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_map.at(m_matmul).get_node_shared_ptr());\n        if (!matmul_node) {\n            return false;\n        }\n        auto act = pattern_map.at(m_act);\n        auto wei = pattern_map.at(m_wei);\n        auto act_shape = act.get_partial_shape();\n        auto wei_shape = wei.get_partial_shape();\n        if (act_shape.rank().is_dynamic() || wei_shape.rank().is_dynamic()) {\n            return false;\n        }\n        if (act_shape.rank().get_length() != 4 || wei_shape.rank().get_length() != 2) {\n            return false;\n        }\n        // Squeezing axis 0 is only valid when that dimension is known to be 1; otherwise the Squeeze node\n        // would fail validation, so leave such MatMuls untouched.\n        if (act_shape[0].is_dynamic() || act_shape[0].get_length() != 1) {\n            return false;\n        }\n\n        auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});\n        auto squeezed_act = std::make_shared<ov::op::v0::Squeeze>(act, axis);\n        auto new_matmul = std::make_shared<ov::op::v0::MatMul>(squeezed_act, wei, matmul_node->get_transpose_a(),\n                                                               matmul_node->get_transpose_b());\n        auto unsqueezed_output = std::make_shared<ov::op::v0::Unsqueeze>(new_matmul, axis);\n        // Preserve the original node's friendly name and runtime info on the replacement subgraph.\n        unsqueezed_output->set_friendly_name(matmul_node->get_friendly_name());\n        ov::copy_runtime_info(matmul_node, {squeezed_act, new_matmul, unsqueezed_output});\n        ov::replace_node(matmul_node, unsqueezed_output);\n        return true;\n    };\n\n    register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul, \"ov::frontend::ggml::pass::SqueezeMatmul\"),\n                     callback);\n}\n\n}  // namespace pass\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/pass/squeeze_matmul.h",
    "content": "#pragma once\n\n#include \"openvino/pass/matcher_pass.hpp\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\nnamespace pass {\n\n// Matcher pass that rewrites a MatMul with a 4d activation and a 2d weight so the MatMul itself runs on a 3d\n// activation (Squeeze on axis 0 before it, Unsqueeze on axis 0 after it).\nclass SqueezeMatmul : public ov::pass::MatcherPass {\npublic:\n    OPENVINO_MATCHER_PASS_RTTI(\"ov::frontend::ggml::pass::SqueezeMatmul\")\n    SqueezeMatmul();\n};\n\n}  // namespace pass\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/translate_session.cpp",
    "content": "#include \"translate_session.h\"\n\n#include \"ggml-openvino/openvino/node_context.h\"\n#include \"ggml-openvino/openvino/utils.h\"\n#include \"input_model.h\"\n#include \"pass/eliminate_zp.h\"\n#include \"pass/mark_decompression_convert_constant_folding.h\"\n#include \"pass/squeeze_matmul.h\"\n\n#include <algorithm>\n#include <cstdint>\n#include <cstdlib>\n#include <map>\n#include <memory>\n#include <openvino/core/node.hpp>\n#include <openvino/op/add.hpp>\n#include <openvino/op/broadcast.hpp>\n#include <openvino/op/concat.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/convert_like.hpp>\n#include <openvino/op/cos.hpp>\n#include <openvino/op/divide.hpp>\n#include <openvino/op/gather.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/parameter.hpp>\n#include <openvino/op/range.hpp>\n#include <openvino/op/reshape.hpp>\n#include <openvino/op/result.hpp>\n#include <openvino/op/sin.hpp>\n#include <openvino/op/slice.hpp>\n#include <openvino/op/squeeze.hpp>\n#include <openvino/op/strided_slice.hpp>\n#include <openvino/op/transpose.hpp>\n#include <openvino/op/unsqueeze.hpp>\n#include <openvino/pass/constant_folding.hpp>\n#include <openvino/pass/make_stateful.hpp>\n#include <openvino/pass/manager.hpp>\n#include <openvino/core/preprocess/pre_post_process.hpp>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nusing namespace ov::op;\n\nnamespace {\n\n// Resolves every (Parameter name -> Result name) entry of `kv_param_res_names` to the model Parameter and\n// Result with those friendly names, producing the pairs consumed by ov::pass::MakeStateful.\n// Asserts if any name cannot be found in `model`.\nov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(\n    const std::shared_ptr<ov::Model> & model,\n    const std::map<std::string, std::string> & kv_param_res_names) {\n    ov::pass::MakeStateful::ParamResPairs pairs;\n    const auto & params = model->get_parameters();\n    const auto & results = model->get_results();\n\n    for (const auto & param_res : kv_param_res_names) {\n        const auto & param_name = param_res.first;\n        const auto & res_name = param_res.second;\n\n        auto param_it = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr<v0::Parameter> & node) {\n            return node->get_friendly_name() == param_name;\n        });\n\n        OPENVINO_ASSERT(param_it != params.end(), \"The tensor name \", param_name,\n                        \" is not associated with any of \"\n                        \"Parameters in the network.\");\n\n        auto res_it = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr<v0::Result> & node) {\n            return node->get_friendly_name() == res_name;\n        });\n\n        OPENVINO_ASSERT(res_it != results.end(), \"The tensor name \", res_name,\n                        \" is not associated with any of \"\n                        \"Results in the network.\");\n\n        std::shared_ptr<ov::op::v0::Parameter> param = *param_it;\n        std::shared_ptr<ov::op::v0::Result> res = *res_it;\n        pairs.emplace_back(param, res);\n    }\n    return pairs;\n}\n\n// For each attention mask input (\"self_kq_mask\", \"self_kq_mask_swa\") present together with \"token_len_per_seq\",\n// registers a derived mask under \"KQ_mask_sliced\" / \"KQ_mask_swa_sliced\". In static mode the mask node is\n// registered unchanged; otherwise it is sliced (using token_len_per_seq and, in stateful mode, the last inp_pos)\n// and converted to f16.\nvoid add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {\n\n    auto create_sliced_mask = [&](const std::string & mask_name, const std::string & sliced_name, bool is_static) {\n        if ((tensor_map.find(mask_name) != tensor_map.end()) &&\n            (tensor_map.find(\"token_len_per_seq\") != tensor_map.end())) {\n            auto token_len_per_seq = tensor_map.at(\"token_len_per_seq\").get_node_shared_ptr();\n            auto mask = tensor_map.at(mask_name).get_node_shared_ptr();\n            std::shared_ptr<ov::Node> mask_sliced;\n            if (is_static) {\n                mask_sliced = mask;\n            } else if (ggml_model_decoder.is_stateful()) {\n                auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});\n                auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});\n                auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n                auto three_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});\n                auto neg_one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});\n                auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2,-1});\n                auto inp_pos = tensor_map.at(\"inp_pos\").get_node_shared_ptr();\n                auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(inp_pos, neg_one_1d, three_1d);\n                auto reshaped_inp_pos = std::make_shared<ov::op::v1::Reshape>(gather_inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {1}), false);\n                auto inp_pos_incremented = std::make_shared<ov::op::v1::Add>(reshaped_inp_pos, ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1}));\n                auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len_per_seq, std::make_shared<v1::ConvertLike>(inp_pos_incremented, token_len_per_seq)}, 0);\n                mask_sliced =\n                    std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);\n                mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);\n                mask_sliced->set_friendly_name(sliced_name);\n            } else {\n                auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});\n                auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n                auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});\n                mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len_per_seq, one, two);\n                mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);\n                mask_sliced->set_friendly_name(sliced_name);\n            }\n            tensor_map.insert({sliced_name, mask_sliced->output(0)});\n        }\n    };\n\n    create_sliced_mask(\"self_kq_mask\", \"KQ_mask_sliced\", ggml_model_decoder.is_static());\n    create_sliced_mask(\"self_kq_mask_swa\", \"KQ_mask_swa_sliced\", ggml_model_decoder.is_static());\n}\n\n// When \"inp_pos\" and rope params are available, builds the RoPE sin/cos tensors (dividing the frequencies by\n// \"rope_freqs.weight\" when that weight exists) and registers them as \"rope_sin\" / \"rope_cos\".\nvoid add_rope_sin_cos(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {\n    int32_t * rope_params = ggml_model_decoder.get_rope_params();\n    if (tensor_map.find(\"inp_pos\") == tensor_map.end() || rope_params == nullptr) {\n        return;\n    }\n    auto inp_pos = tensor_map.at(\"inp_pos\").get_node_shared_ptr();\n    std::shared_ptr<ov::Node> rope_freqs_weight;\n    if (tensor_map.find(\"rope_freqs.weight\") != tensor_map.end()) {\n        rope_freqs_weight = tensor_map.at(\"rope_freqs.weight\").get_node_shared_ptr();\n    }\n\n    auto sin_cos = make_sin_cos(rope_params, inp_pos, rope_freqs_weight);\n    auto sin_theta = sin_cos.first;\n    auto cos_theta = sin_cos.second;\n\n    cos_theta.get_node_shared_ptr()->set_friendly_name(\"rope_cos\");\n    sin_theta.get_node_shared_ptr()->set_friendly_name(\"rope_sin\");\n    tensor_map.insert({\"rope_cos\", cos_theta});\n    tensor_map.insert({\"rope_sin\", sin_theta});\n}\n\n// Create common patterns\nvoid preprocess(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {\n    add_sliced_mask(tensor_map, ggml_model_decoder);\n    add_rope_sin_cos(tensor_map, ggml_model_decoder);\n}\n\n}  // namespace\n\nTranslateSession::TranslateSession(const frontend::InputModel::Ptr & input_model,\n                                   const std::unordered_map<std::string, CreatorFunction> & translator_map,\n                                   bool naive) :\n    m_input_model(input_model),\n    m_translator_map(translator_map),\n    m_ov_model(nullptr),\n    m_naive(naive) {}\n\nstd::shared_ptr<Model> TranslateSession::get_converted_model() {\n    if (m_ov_model) {\n        return m_ov_model;\n    }\n    m_ov_model = translate_graph(m_input_model);\n    return m_ov_model;\n}\n\n// Builds the ov::Model: registers decoder inputs and weights in the tensor map, runs every ggml node through its\n// translator, creates one Result per decoder output, drops unused Parameters and applies the model passes.\nstd::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputModel::Ptr & input_model) {\n    ov::ParameterVector params;\n    ov::ResultVector results;\n    auto tensor_map = std::make_shared<TensorMap>();\n    std::shared_ptr<Model> resulting_model;\n\n    const auto & ggml_model = std::dynamic_pointer_cast<InputModel>(input_model);\n    std::shared_ptr<GgmlDecoder> ggml_model_decoder = ggml_model->get_model_decoder();\n\n    // Every input is registered by name, but only Parameter nodes become model parameters; this also keeps\n    // null entries out of `params`, whose elements are dereferenced below.\n    for (const auto & it : ggml_model_decoder->get_model_inputs()) {\n        if (auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(it.second)) {\n            params.push_back(param);\n        }\n        (*tensor_map)[it.first] = it.second;\n    }\n\n    for (const auto & it : ggml_model_decoder->get_model_extra_inputs()) {\n        if (auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(it.second)) {\n            params.push_back(param);\n        }\n        (*tensor_map)[it.first] = it.second;\n    }\n\n    for (const auto & it : ggml_model_decoder->get_model_weights()) {\n        (*tensor_map)[it.first] = it.second;\n    }\n\n    auto node_visitor = [&](std::shared_ptr<GgmlDecoder> decoder, int node_idx) {\n        auto operation_type = decoder->get_op_type(node_idx);\n        if (operation_type == \"GGML_OP_NONE\") {\n            return;\n        }\n\n        ov::OutputVector converted_outputs;\n        auto it = m_translator_map.find(operation_type);\n        FRONT_END_OP_CONVERSION_CHECK(it != m_translator_map.end(), \"Translation for operation type \", operation_type,\n                                      \" is not implemented.\");\n        NodeContext node_context(decoder, tensor_map, node_idx, this);\n        converted_outputs = it->second(node_context);\n\n        const auto & node_output_names = decoder->get_output_names(node_idx);\n        FRONT_END_OP_CONVERSION_CHECK(node_output_names.size() == converted_outputs.size(), \"Number of \",\n                                      operation_type, \" outputs does not match number of converted outputs, which are \",\n                                      node_output_names.size(), \" and \", converted_outputs.size(), \" respectively.\");\n\n        for (size_t i = 0; i < node_output_names.size(); ++i) {\n            auto output_name = node_output_names[i];\n            if (i < converted_outputs.size() && converted_outputs[i].get_node_shared_ptr() != nullptr) {\n                (*tensor_map)[output_name] = converted_outputs[i];\n            }\n        }\n    };\n\n    if (!m_naive) {\n        preprocess(*tensor_map, *ggml_model_decoder);\n    }\n    ggml_model_decoder->visit_subgraph(node_visitor);\n\n    for (const auto & name : ggml_model_decoder->get_model_output_names()) {\n        FRONT_END_GENERAL_CHECK(tensor_map->find(name) != tensor_map->end(),\n                                \"Output name not found in tensor map: \", name);\n        auto result = std::make_shared<v0::Result>(tensor_map->at(name));\n        result->set_friendly_name(name);\n        results.push_back(result);\n    }\n\n    ov::ParameterVector used_params;\n    for (const auto & param : params) {\n        if (!param->output(0).get_target_inputs().empty()) {\n            used_params.push_back(param);\n        }\n    }\n    resulting_model = std::make_shared<Model>(results, used_params);\n\n    // apply_transformations() returns the model to use afterwards (in stateful mode it is the result of\n    // PrePostProcessor::build()), so continue with its return value.\n    resulting_model = apply_transformations(resulting_model);\n    return resulting_model;\n}\n\n// Runs the model-level passes (compressed-constant marking, MakeStateful over the KV param/result pairs in\n// stateful mode, zero-point elimination and MatMul squeezing in static mode) and, in stateful mode, restores a\n// leading size-1 dimension on outputs that are one rank short of the decoder's expected shape.\nstd::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<Model> model) {\n    auto ggml_model_decoder = std::dynamic_pointer_cast<InputModel>(m_input_model)->get_model_decoder();\n    {\n        ov::pass::Manager manager;\n        manager.set_per_pass_validation(true);\n        manager.register_pass<ov::pass::MarkCompressedFloatConstants>();\n\n        if (ggml_model_decoder->is_stateful()) {\n            const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();\n            const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);\n            manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);\n        }\n\n        if (ggml_model_decoder->is_static()) {\n            manager.register_pass<pass::EliminateZeroPoints>();\n            manager.register_pass<pass::SqueezeMatmul>();\n        }\n        manager.run_passes(model);\n        if (ggml_model_decoder->is_stateful()) {\n            // Results are named after the decoder outputs they come from (see translate_graph), so the\n            // friendly name maps each model output back to its decoder output index.\n            auto output_names = ggml_model_decoder->get_model_output_names();\n            std::map<std::string, int> model_output_indexes;\n            for (size_t i = 0; i < output_names.size(); i++) {\n                model_output_indexes.emplace(output_names[i], static_cast<int>(i));\n            }\n            ov::preprocess::PrePostProcessor ppp(model);\n            for (size_t i = 0; i < model->get_output_size(); i++) {\n                auto output_friendly_name = model->output(i).get_node_shared_ptr()->get_friendly_name();\n                // Use find() instead of operator[]: an unknown name must not silently fall back to output 0\n                // and be compared against the wrong decoder shape.\n                auto index_it = model_output_indexes.find(output_friendly_name);\n                if (index_it == model_output_indexes.end()) {\n                    continue;\n                }\n                auto output_id = index_it->second;\n                auto model_output_shape = model->output(i).get_partial_shape();\n                auto decoder_output_shape = ggml_model_decoder->get_output_shape(output_id);\n                if (model_output_shape.rank().is_static() && decoder_output_shape.rank().is_static()\n                    && model_output_shape.rank().get_length() + 1 == decoder_output_shape.rank().get_length()\n                    && decoder_output_shape[0].is_static() && decoder_output_shape[0].get_length() == 1) {\n                    // The decoder expects one extra leading dimension of size 1; add it back on axis 0.\n                    ppp.output(i).postprocess().custom([](const ov::Output<ov::Node>& node) {\n                        auto axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {0});\n                        return std::make_shared<ov::op::v0::Unsqueeze>(node, axes);\n                    });\n                }\n            }\n            model = ppp.build();\n        }\n    }\n    return model;\n}\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/translate_session.h",
    "content": "#pragma once\n\n#include \"input_model.h\"\n#include \"node_context.h\"\n\n#include <memory>\n#include <string>\n#include <unordered_map>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\n// Converts the ggml graph held by an InputModel into an ov::Model by dispatching each ggml op to the\n// matching translator in `translator_map`.\nclass TranslateSession {\npublic:\n    // `translator_map` is stored by reference and must outlive the session. When `naive` is true, the\n    // shared preprocessing step (sliced attention masks, RoPE sin/cos) is skipped.\n    TranslateSession(const frontend::InputModel::Ptr& input_model,\n                     const std::unordered_map<std::string, CreatorFunction>& translator_map, bool naive = false);\n\n    // Returns the converted model, translating the session's input model on the first call and caching it.\n    std::shared_ptr<Model> get_converted_model();\n    // Translates `input_model` (without caching) and applies the post-translation passes.\n    std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model);\n\nprivate:\n    std::shared_ptr<Model> apply_transformations(std::shared_ptr<Model> model);\n    const frontend::InputModel::Ptr m_input_model;\n    const std::unordered_map<std::string, CreatorFunction>& m_translator_map;\n    std::shared_ptr<Model> m_ov_model;\n    bool m_naive;\n};\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/utils.cpp",
    "content": "#include \"utils.h\"\n\n#include \"ggml-impl.h\"\n\n#include <algorithm>\n#include <cmath>\n#include <cstddef>\n#include <cstring>\n#include <ctime>\n#include <memory>\n#include <numeric>\n#include <openvino/op/add.hpp>\n#include <openvino/op/clamp.hpp>\n#include <openvino/op/convert.hpp>\n#include <openvino/op/cos.hpp>\n#include <openvino/op/divide.hpp>\n#include <openvino/op/gather.hpp>\n#include <openvino/op/maximum.hpp>\n#include <openvino/op/multiply.hpp>\n#include <openvino/op/shape_of.hpp>\n#include <openvino/op/sin.hpp>\n#include <openvino/op/squeeze.hpp>\n#include <openvino/op/subtract.hpp>\n#include <openvino/op/transpose.hpp>\n#include <string>\n#include <vector>\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nstd::string getCurrentTime() {\n    std::time_t now = std::time(nullptr);\n    char buf[100];\n    std::strftime(buf, sizeof(buf), \"%Y-%m-%d %H:%M:%S\", std::localtime(&now));\n    return buf;\n}\n\nvoid num_inputs_check(const NodeContext & context, size_t min_inputs, size_t max_inputs) {\n    auto input_size = context.get_input_size();\n    FRONT_END_OP_CONVERSION_CHECK(input_size >= min_inputs, \"Got less inputs than expected\");\n    FRONT_END_OP_CONVERSION_CHECK(input_size <= max_inputs, \"Got more inputs than expected\");\n}\n\n// Scans from the last dimension towards the first and returns the first index i (> 0) at which the running\n// product nb[last] * ne[last] * ... * ne[i] differs from nb[i - 1]; returns 0 if there is none or `nb` is empty.\nint non_cont_dim(std::vector<size_t> ne, std::vector<size_t> nb) {\n    if (nb.empty()) {\n        return 0;\n    }\n    int dim = static_cast<int>(nb.size()) - 1;\n    size_t bytes = nb[dim];\n    for (int i = dim; i > 0; i--) {\n        bytes *= ne[i];\n        if (bytes != nb[i - 1]) {\n            return i;\n        }\n    }\n    return 0;\n}\n\n// Gathers the entries listed in `dims` from a ShapeOf output.\nstd::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf> & shape,\n                                         const std::vector<int> & dims) {\n    using namespace ov::op;\n    const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});\n    const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);\n    return std::make_shared<v8::Gather>(shape, dims_const, zero);\n}\n\nstd::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node> & node, const std::vector<int> & dims) {\n    return get_dimensions(std::make_shared<ov::op::v3::ShapeOf>(node), dims);\n}\n\n// Appends \"_<suffix>\" to the friendly name of each output's producing node and returns `outputs`.\nOutputVector rename_outputs_with_suffix(const OutputVector & outputs, const std::string & suffix) {\n    for (const auto & output : outputs) {\n        auto node = output.get_node_shared_ptr();\n        std::string name = node->get_friendly_name();\n        name += \"_\";\n        name += suffix;\n        node->set_friendly_name(name);\n    }\n    return outputs;\n}\n\nnamespace {\n// YaRN ramp blend factor per rotary dimension pair i = 0 .. n_dims/2 - 1, matching ggml's rope_yarn_ramp():\n//   ramp_mix(i) = (1 - clamp((i - corr_dims[0]) / max(0.001, corr_dims[1] - corr_dims[0]), 0, 1)) * ext_factor\n// The result has rank 3 when `stateful` is true and rank 4 otherwise, matching the rank of theta built in\n// make_sin_cos so that broadcasting does not change theta's rank.\nov::Output<ov::Node> rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], float ext_factor, bool stateful) {\n    const size_t half_n_dims = static_cast<size_t>(n_dims / 2);\n    std::vector<float> dim_ids_vec(half_n_dims);\n    std::iota(dim_ids_vec.begin(), dim_ids_vec.end(), 0.0f);\n    const ov::Shape unit_shape = stateful ? ov::Shape{1, 1, 1} : ov::Shape{1, 1, 1, 1};\n    const ov::Shape dim_ids_shape = stateful ? ov::Shape{1, 1, half_n_dims} : ov::Shape{1, 1, 1, half_n_dims};\n    auto dim_ids = ov::op::v0::Constant::create(ov::element::f32, dim_ids_shape, dim_ids_vec);\n    auto corr_low = ov::op::v0::Constant::create(ov::element::f32, unit_shape, {corr_dims[0]});\n    auto corr_high = ov::op::v0::Constant::create(ov::element::f32, unit_shape, {corr_dims[1]});\n    auto denom = std::make_shared<ov::op::v1::Maximum>(\n        std::make_shared<ov::op::v1::Subtract>(corr_high, corr_low),\n        ov::op::v0::Constant::create(ov::element::f32, unit_shape, {0.001f}));\n    auto ramp_y =\n        std::make_shared<ov::op::v1::Divide>(std::make_shared<ov::op::v1::Subtract>(dim_ids, corr_low), denom);\n    auto ramp_clamped = std::make_shared<ov::op::v0::Clamp>(ramp_y, 0.0f, 1.0f);\n    // ggml returns 1 - clamp(y): dimensions below corr_dims[0] lean towards the extrapolated theta and\n    // dimensions above corr_dims[1] towards the interpolated one.\n    auto one = ov::op::v0::Constant::create(ov::element::f32, unit_shape, {1.0f});\n    auto ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_clamped);\n    auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor});\n    auto ramp_mix = std::make_shared<ov::op::v1::Multiply>(ramp, ext_factor_node);\n    return ramp_mix;\n}\n\nfloat ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {\n#ifndef M_PI\n#    define M_PI 3.14159265358979323846\n#endif\n    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base));\n}\n\nvoid ggml_rope_yarn_corr_dims(int n_dims,\n                              int n_ctx_orig,\n                              float freq_base,\n                              float beta_fast,\n                              float beta_slow,\n                              float dims[2]) {\n    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));\n    float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));\n    dims[0] = std::max(0.0f, start);\n    dims[1] = std::min(static_cast<float>(n_dims - 1), end);\n}\n}  // namespace\n\n// Builds RoPE sin/cos tensors from ggml rope op params (n_dims = rope_params[1], n_ctx_orig = rope_params[4],\n// freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow = float bits of rope_params[5..10]),\n// following ggml's rope_yarn(). Frequencies are divided by `rope_freqs_weight` when given. Returns {sin, cos}.\nstd::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params,\n                                                           std::shared_ptr<ov::Node> inp_pos,\n                                                           std::shared_ptr<ov::Node> rope_freqs_weight,\n                                                           bool stateful) {\n    if (stateful) {\n        inp_pos = std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));\n        inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);\n        auto pos_perm =\n            std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{2, 1, 0});\n        inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);\n    } else {\n        inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);\n        auto pos_perm =\n            std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});\n        inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);\n    }\n\n    float freq_base;\n    float freq_scale;\n    float ext_factor;\n    float attn_factor;\n    float beta_fast;\n    float beta_slow;\n    const int n_dims = rope_params[1];\n    const int n_ctx_orig = rope_params[4];\n    memcpy(&freq_base, rope_params + 5, sizeof(float));\n    memcpy(&freq_scale, rope_params + 6, sizeof(float));\n    memcpy(&ext_factor, rope_params + 7, sizeof(float));\n    memcpy(&attn_factor, rope_params + 8, sizeof(float));\n    memcpy(&beta_fast, rope_params + 9, sizeof(float));\n    memcpy(&beta_slow, rope_params + 10, sizeof(float));\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    float corr_dims[2];\n    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);\n\n    std::vector<float> factor(n_dims / 2);\n    factor[0] = 1.0f;\n    for (size_t i = 1; i < factor.size(); i++) {\n        factor[i] = theta_scale * factor[i - 1];\n    }\n\n    Output<Node> freq_factors;\n    if (stateful) {\n        freq_factors =\n            std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor);\n    } else {\n        freq_factors =\n            std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);\n    }\n    if (rope_freqs_weight) {\n        freq_factors = std::make_shared<ov::op::v1::Divide>(freq_factors, rope_freqs_weight);\n    }\n\n    auto theta_extrap = std::make_shared<ov::op::v1::Multiply>(freq_factors, inp_pos);\n    auto theta_interp = std::make_shared<ov::op::v1::Multiply>(\n        theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale}));\n\n    Output<Node> theta;\n    float mscale = attn_factor;\n    if (ext_factor == 0.0f) {\n        theta = theta_interp;\n    } else {\n        auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor, stateful);\n        Output<Node> one;\n        if (stateful) {\n            one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f});\n        } else {\n            one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});\n        }\n        auto one_minus_ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_mix);\n\n        // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix, as in ggml's rope_yarn().\n        theta = std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(theta_interp, one_minus_ramp),\n                                                  std::make_shared<ov::op::v1::Multiply>(theta_extrap, ramp_mix));\n        mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale));\n    }\n\n    Output<Node> cos_theta = std::make_shared<ov::op::v0::Cos>(theta);\n    Output<Node> sin_theta = std::make_shared<ov::op::v0::Sin>(theta);\n\n    auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale});\n\n    cos_theta = std::make_shared<ov::op::v1::Multiply>(cos_theta, mscale_node);\n    sin_theta = std::make_shared<ov::op::v1::Multiply>(sin_theta, mscale_node);\n    return std::make_pair(sin_theta, cos_theta);\n}\n\nov::Output<ov::Node> process_view_input(const NodeContext & context, int input_index, int slice_len) {\n    // Only works for VIEW operations that slice at the lowest dimension\n    // If the VIEW also reshapes the result, `slice_len` should be provided\n    auto input = context.get_input(input_index);\n    auto * op_params = (size_t *) context.get_input_op_params(input_index);\n    auto src1_stride = context.get_input_stride(input_index);\n\n    int64_t split_addr = op_params[0] / src1_stride[3];\n    if (slice_len == 0) {\n        slice_len = context.get_input_shape(input_index)[3].get_length();\n    }\n    int64_t slice_end = split_addr + slice_len;\n\n    auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});\n    auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});\n    auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});\n    auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3});\n    auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);\n    return sliced;\n}\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/openvino/utils.h",
    "content": "#pragma once\n\n#include <algorithm>\n#include <cstdint>\n#include <memory>\n#include <numeric>\n#include <openvino/core/node.hpp>\n#include <openvino/op/shape_of.hpp>\n#include <openvino/op/slice.hpp>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"node_context.h\"\n\nnamespace ov {\nnamespace frontend {\nnamespace ggml {\n\nstd::string getCurrentTime();\n\nvoid dump_ov_model(std::shared_ptr<ov::Model> model);\n\nvoid num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs);\n\nint non_cont_dim(std::vector<size_t> ne, std::vector<size_t> nb);\n\n// Returns the indices of `v` ordered so that the values they refer to are in descending order.\ntemplate <typename T>\nstd::vector<int> argsort_descend(const std::vector<T>& v) {\n    std::vector<int> idx(v.size());\n    std::iota(idx.begin(), idx.end(), 0);\n    std::sort(idx.begin(), idx.end(), [&v](int i1, int i2) {\n        return v[i1] > v[i2];\n    });\n    return idx;\n}\n\n// Returns a copy of `v` sorted in descending order.\ntemplate <typename T>\nstd::vector<T> sorted_descend(std::vector<T> v) {\n    std::sort(v.begin(), v.end(), [](T a, T b) {\n        return a > b;\n    });\n    return v;\n}\n\n// Returns true if some stride is smaller than the stride after it (i.e. `strides` is not non-increasing).\n// Vectors with fewer than two elements are never considered permuted.\ntemplate <typename T>\nbool is_permuted(const std::vector<T>& strides) {\n    // Written as `i + 1 < size()` because `size() - 1` underflows for an empty vector.\n    for (size_t i = 0; i + 1 < strides.size(); ++i) {\n        if (strides[i] < strides[i + 1]) {\n            return true;\n        }\n    }\n    return false;\n}\n\n// Returns {x[perm[0]], x[perm[1]], ...}.\ntemplate <typename T>\nstd::vector<T> permute(const std::vector<T>& x, const std::vector<int>& perm) {\n    std::vector<T> result;\n    result.reserve(perm.size());\n    for (int i : perm) {\n        result.push_back(x[i]);\n    }\n    return result;\n}\n\nstd::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf>& shape,\n                                         const std::vector<int>& dims);\nstd::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims);\n\nOutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::string& suffix);\n\nstd::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t* rope_params,\n                                                           std::shared_ptr<ov::Node> inp_pos,\n                                                           std::shared_ptr<ov::Node> rope_freqs_weight = nullptr,\n                                                           bool stateful = false);\n\nov::Output<ov::Node> process_view_input(const NodeContext& context, int input_index, int slice_len = 0);\n\nnamespace op {\n// Translates a two-input ggml op into a single node of type T built from inputs 0 and 1.\ntemplate <typename T>\nOutputVector translate_1to1_match_2_inputs(const NodeContext& context) {\n    num_inputs_check(context, 2, 2);\n    auto res = std::make_shared<T>(context.get_input(0), context.get_input(1));\n    return rename_outputs_with_suffix({res}, context.get_name());\n}\n}  // namespace op\n\n}  // namespace ggml\n}  // namespace frontend\n}  // namespace ov\n"
  },
  {
    "path": "src/ggml-openvino/utils.cpp",
    "content": "#include \"utils.h\"\n\n#include \"ggml-impl.h\"\n#include \"ggml-openvino-extra.h\"\n#include \"ggml-openvino/ggml-decoder.h\"\n#include \"ggml.h\"\n#include \"openvino/frontend.h\"\n#include \"openvino/input_model.h\"\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstdlib>\n#include <cstring>\n#include <iomanip>\n#include <iostream>\n#include <memory>\n#include <openvino/core/any.hpp>\n#include <openvino/core/graph_util.hpp>\n#include <openvino/core/shape.hpp>\n#include <openvino/core/type/float16.hpp>\n#include <openvino/frontend/manager.hpp>\n#include <openvino/openvino.hpp>\n#include <openvino/runtime/compiled_model.hpp>\n#include <openvino/runtime/infer_request.hpp>\n#include <openvino/runtime/intel_npu/properties.hpp>\n#include <openvino/runtime/properties.hpp>\n#include <openvino/runtime/tensor.hpp>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\n// Suppress  deprecation warning for ov::Tensor::data()\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wdeprecated-declarations\"\n\nenum ggml_status ov_graph_compute(ggml_cgraph * cgraph, ggml_backend_t backend) {\n    ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context;\n    try {\n        if (getenv(\"GGML_OPENVINO_DUMP_CGRAPH\")) {\n            std::string filename = \"cgraph_ov.txt\";\n            GgmlOvDecoder::dump_cgraph(cgraph, filename);\n        }\n\n        const auto is_static = ggml_openvino_is_npu();\n\n        GGML_ASSERT(ctx->runtime_context != nullptr);\n        std::shared_ptr<ov_runtime_context> r_ctx = std::static_pointer_cast<ov_runtime_context>(ctx->runtime_context);\n\n        return is_static ? 
ov_graph_compute_static(cgraph, r_ctx) : ov_graph_compute_dynamic(cgraph, r_ctx);\n    } catch (const ov::Exception & e) {\n        GGML_LOG_ERROR(\"GGML OpenVINO backend ov::Exception: %s\\n\", e.what());\n        return GGML_STATUS_FAILED;\n    } catch (const std::exception & e) {\n        GGML_LOG_ERROR(\"GGML OpenVINO backend std::exception: %s\\n\", e.what());\n        return GGML_STATUS_FAILED;\n    } catch (...) {\n        GGML_LOG_ERROR(\"GGML OpenVINO backend unknown exception\\n\");\n        return GGML_STATUS_FAILED;\n    }\n}\n\nov::Tensor create_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder,\n                                   std::shared_ptr<ov::InferRequest> infer_request,\n                                   int output_index,\n                                   const ggml_tensor * ggml_tensor) {\n    auto output_type = ggml_decoder->get_ov_type(ggml_tensor);\n    ov::Shape output_shape;\n    if (ggml_decoder->is_static()) {\n        output_shape = infer_request->get_output_tensor(output_index).get_shape();\n    } else {\n        output_shape = ggml_decoder->get_shape(ggml_tensor);\n    }\n\n    ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);\n    return output_tensor;\n}\n\nenum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx) {\n    auto & core = ov_singleton_core();\n    const auto & config = ggml_openvino_get_compile_config();\n    auto device = r_ctx->device;\n    bool stateful = r_ctx->stateful;\n    static auto is_static = false;\n\n    if (is_naive(cgraph)) {\n        return naive_compute(cgraph, core, device, config);\n    }\n\n    auto start_time = ggml_time_us();\n\n    std::shared_ptr<GgmlOvDecoder> ggml_decoder;\n    std::shared_ptr<ov::InferRequest> infer_request;\n    ModelParams m_params;\n    ComputeParams c_params;\n    std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static);\n\n    graph_key key(cgraph);\n    bool 
cache_hit;\n\n    int64_t decoder_end_time;\n    int64_t conversion_end_time;\n    int64_t compile_end_time;\n    int64_t infer_end_time;\n\n    {\n        std::lock_guard<std::mutex> lock(r_ctx->ov_compute_mutex);\n\n        auto it = r_ctx->decoder_cache.find(key);\n\n        cache_hit = it != r_ctx->decoder_cache.end();\n        ModelParams old_m_params;\n        if (cache_hit) {\n            ggml_decoder = it->second;\n            old_m_params = ggml_decoder->get_model_params();\n            cache_hit = old_m_params.can_reuse_dynamically(m_params);\n        }\n\n        if (cache_hit) {\n            std::map<std::string, std::shared_ptr<ov::Node>> model_weights;\n            ggml_decoder->set_compute_params(c_params);\n            ggml_decoder->set_model_params(m_params);\n            if (old_m_params.kv_buffer_changed(m_params)) {\n                ggml_decoder->update_io(cgraph);\n            }\n            ggml_decoder->add_extra_inputs();\n            infer_request = r_ctx->infer_request_cache.at(key);\n\n            if (stateful) {\n                const auto * inp_pos = get_inp_pos_tensor(cgraph);\n                int32_t * pos_data = (int32_t *) inp_pos->data;\n                auto pos_shape = ggml_decoder->get_shape(inp_pos);\n                if (pos_data[0] == 0) {\n                    infer_request->reset_state();\n                    r_ctx->stateful_kv_size = pos_shape[3];\n                } else if (r_ctx->stateful_kv_size == static_cast<size_t>(pos_data[0])) {\n                    r_ctx->stateful_kv_size += pos_shape[3];\n                } else {\n                    auto states = infer_request->query_state();\n                    for (auto state : states) {\n                        auto state_tensor = state.get_state();\n                        auto state_tensor_shape = state_tensor.get_shape();\n                        if (static_cast<uint32_t>(pos_data[0]) > r_ctx->stateful_kv_size) {\n                            std::string state_name;\n         
                   try {\n                                state_name = r_ctx->kv_state_input_name_map.at(state.get_name());\n                            } catch (...) {\n                                GGML_LOG_ERROR(\"GGML OpenVINO backend stateful inference failed: no input found for the state\\n\");\n                                return GGML_STATUS_FAILED;\n                            }\n                            auto kv_tensor = get_ov_input_tensor(ggml_decoder, state_name);\n                            kv_tensor.set_shape({state_tensor_shape[0], kv_tensor.get_shape()[2],\n                                                 state_tensor_shape[2], state_tensor_shape[3]});\n                           state_tensor = kv_tensor;\n                           state_tensor_shape = state_tensor.get_shape();\n                        }\n                        ov::Coordinate begin = {0, 0, 0, 0};\n                        ov::Coordinate end = {state_tensor_shape[0], static_cast<uint32_t>(pos_data[0]),\n                                              state_tensor_shape[2], state_tensor_shape[3]};\n                        ov::Tensor new_state_tensor(state_tensor, begin, end);\n                        state.set_state(new_state_tensor);\n                    }\n                    r_ctx->stateful_kv_size = pos_data[0] + 1;\n                }\n            }\n\n            decoder_end_time = ggml_time_us();\n            conversion_end_time = decoder_end_time;\n            compile_end_time = decoder_end_time;\n        } else {\n            r_ctx->infer_request_cache.erase(key);\n\n            std::shared_ptr<ov::Model> model;\n            auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);\n\n            ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static, stateful);\n            decoder_end_time = ggml_time_us();\n\n            auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);\n            
model = ov::frontend::ggml::FrontEnd::convert(input_model);\n            ggml_decoder->clear_model_weights();\n            conversion_end_time = ggml_time_us();\n\n            if (getenv(\"GGML_OPENVINO_DUMP_IR\")) {\n                char timestamped_filename[64];\n                auto timestamp = (long long) ggml_time_us();\n                snprintf(timestamped_filename, sizeof(timestamped_filename), \"model_%lld.xml\", timestamp);\n                ov::serialize(model, timestamped_filename);\n            }\n\n            ov::CompiledModel compiled_model;\n            auto remote_context = ggml_openvino_get_remote_context();\n            if (remote_context.has_value()) {\n                compiled_model = core.compile_model(model, remote_context.value(), config);\n            } else {\n                compiled_model = core.compile_model(model, device, config);\n            }\n            compile_end_time = ggml_time_us();\n            infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());\n            r_ctx->infer_request_cache[key] = infer_request;\n            r_ctx->decoder_cache[key] = ggml_decoder;\n\n            std::vector<std::string> ov_input_names;\n            std::vector<std::string> ov_output_names;\n            for (const auto & ov_param : model->get_parameters()) {\n                ov_input_names.push_back(ov_param->get_friendly_name());\n            }\n            for (const auto & ov_output : model->get_results()) {\n                ov_output_names.push_back(ov_output->get_friendly_name());\n            }\n            r_ctx->ov_input_names_cache[key] = std::move(ov_input_names);\n            r_ctx->ov_output_names_cache[key] = std::move(ov_output_names);\n\n            if (stateful) {\n                const auto * inp_pos = get_inp_pos_tensor(cgraph);\n                auto pos_shape = ggml_decoder->get_shape(inp_pos);\n                r_ctx->stateful_kv_size = pos_shape[3];\n                const auto 
kv_param_res_names = ggml_decoder->get_kv_param_res_names();\n                for (const auto& pair : kv_param_res_names) {\n                    r_ctx->kv_state_input_name_map[pair.first+pair.second] = pair.first;\n                }\n            }\n        }\n\n        auto ov_input_names = r_ctx->ov_input_names_cache[key];\n        auto ov_output_names = r_ctx->ov_output_names_cache[key];\n\n        for (size_t i = 0; i < ov_input_names.size(); i++) {\n            auto param_name = ov_input_names[i];\n            auto input_tensor = get_ov_input_tensor(ggml_decoder, param_name);\n            infer_request->set_input_tensor(i, input_tensor);\n\n            if (getenv(\"GGML_OPENVINO_DEBUG_INPUT\")) {\n                print_input_tensor_info(param_name, input_tensor);\n            }\n        }\n\n        for (size_t i = 0; i < ov_output_names.size(); i++) {\n            auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]);\n            auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor);\n            infer_request->set_output_tensor(i, output_tensor);\n        }\n\n        infer_request->infer();\n        infer_end_time = ggml_time_us();\n\n        if (getenv(\"GGML_OPENVINO_DEBUG_OUTPUT\")) {\n            for (size_t i = 0; i < ov_output_names.size(); i++) {\n                const auto output_tensor = infer_request->get_output_tensor(i);\n                print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data());\n            }\n        }\n\n        if (getenv(\"GGML_OPENVINO_PROFILING\")) {\n            GGML_LOG_INFO(\"\\nGGML OpenVINO Backend: \\n\");\n            GGML_LOG_INFO(\"  - Graph decoder time: %ld ms \\n\", (decoder_end_time - start_time) / 1000);\n            if (!cache_hit) {\n                GGML_LOG_INFO(\"  - Graph conversion time: %ld ms \\n\", (conversion_end_time - decoder_end_time) / 1000);\n                GGML_LOG_INFO(\"  - Graph compile time: %ld ms \\n\", 
(compile_end_time - conversion_end_time) / 1000);\n            }\n            GGML_LOG_INFO(\"  - Graph inference time: %ld ms \\n\", (infer_end_time - compile_end_time) / 1000);\n        }\n    }\n\n    return GGML_STATUS_SUCCESS;\n}\n\nenum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx) {\n    auto & core = ov_singleton_core();\n\n    auto get_prefill_chunk_size = [] {\n        const char * chunk_size_str = getenv(\"GGML_OPENVINO_PREFILL_CHUNK_SIZE\");\n        if (chunk_size_str && atoi(chunk_size_str) > 0) {\n            return atoi(chunk_size_str);\n        }\n        return 256;\n    };\n\n    static std::string device = \"NPU\";\n    static auto is_static = true;\n    static auto stateful = false;\n    static auto prefill_chunk_size = get_prefill_chunk_size();\n    const auto & config = ggml_openvino_get_compile_config();\n\n    if (is_naive(cgraph)) {\n        return naive_compute(cgraph, core, device, config);\n    }\n\n    auto start_time = ggml_time_us();\n\n    std::shared_ptr<GgmlOvDecoder> ggml_decoder;\n    std::shared_ptr<ov::InferRequest> infer_request;\n    ModelParams m_params;\n    ComputeParams c_params;\n    std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static);\n\n    const auto * inp_pos = get_inp_pos_tensor(cgraph);\n    const auto is_prefill = get_is_prefill(inp_pos);\n    graph_key key(cgraph);\n    bool cache_hit;\n\n    int64_t decoder_end_time;\n    int64_t conversion_end_time;\n    int64_t compile_end_time;\n    int64_t infer_end_time;\n\n    auto it = r_ctx->decoder_cache.find(key);\n\n    cache_hit = it != r_ctx->decoder_cache.end();\n    ModelParams old_m_params;\n    if (cache_hit) {\n        ggml_decoder = it->second;\n        old_m_params = ggml_decoder->get_model_params();\n        cache_hit = old_m_params.can_reuse_statically(m_params);\n    }\n\n    if (cache_hit) {\n        std::map<std::string, std::shared_ptr<ov::Node>> model_weights;\n     
   ggml_decoder->m_is_prefill = is_prefill;\n        ggml_decoder->set_model_params(m_params);\n        ggml_decoder->set_compute_params(c_params);\n        if (old_m_params.kv_buffer_changed(m_params)) {\n            ggml_decoder->update_io(cgraph);\n        }\n        ggml_decoder->add_extra_inputs();\n        infer_request = is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key);\n\n        decoder_end_time = ggml_time_us();\n        conversion_end_time = decoder_end_time;\n        compile_end_time = decoder_end_time;\n    } else {\n        r_ctx->infer_request_cache.erase(key);\n        r_ctx->infer_request_cache_prefill.erase(key);\n\n        std::shared_ptr<ov::Model> model;\n        auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);\n\n        auto ggml_decoder_prefill = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,\n                                                                    is_static, stateful, true, prefill_chunk_size);\n        auto ggml_decoder_decode = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static,\n                                                                   stateful, false, prefill_chunk_size);\n        decoder_end_time = ggml_time_us();\n\n        auto input_model_prefill = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_prefill);\n        auto input_model_decode = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_decode);\n\n        auto model_prefill = ov::frontend::ggml::FrontEnd::convert(input_model_prefill);\n        ggml_decoder_prefill->clear_model_weights();\n        auto model_decode = ov::frontend::ggml::FrontEnd::convert(input_model_decode);\n        ggml_decoder_decode->clear_model_weights();\n        conversion_end_time = ggml_time_us();\n\n        if (getenv(\"GGML_OPENVINO_DUMP_IR\")) {\n            char timestamped_filename[64];\n            auto timestamp = (long long) 
ggml_time_us();\n            snprintf(timestamped_filename, sizeof(timestamped_filename), \"model_prefill_%lld.xml\", timestamp);\n            ov::serialize(model_prefill, timestamped_filename);\n            snprintf(timestamped_filename, sizeof(timestamped_filename), \"model_decode_%lld.xml\", timestamp);\n            ov::serialize(model_decode, timestamped_filename);\n        }\n\n        ov::CompiledModel compiled_model_prefill;\n        ov::CompiledModel compiled_model_decode;\n        auto remote_context = ggml_openvino_get_remote_context();\n        if (remote_context.has_value()) {\n            compiled_model_prefill = core.compile_model(model_prefill, remote_context.value(), config);\n            compiled_model_decode = core.compile_model(model_decode, remote_context.value(), config);\n        } else {\n            compiled_model_prefill = core.compile_model(model_prefill, device, config);\n            compiled_model_decode = core.compile_model(model_decode, device, config);\n        }\n\n        r_ctx->infer_request_cache_prefill[key] =\n            std::make_shared<ov::InferRequest>(compiled_model_prefill.create_infer_request());\n        r_ctx->infer_request_cache[key] =\n            std::make_shared<ov::InferRequest>(compiled_model_decode.create_infer_request());\n        compile_end_time = ggml_time_us();\n\n        model = is_prefill ? model_prefill : model_decode;\n        ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode;\n        infer_request = is_prefill ? 
r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key];\n        r_ctx->decoder_cache[key] = ggml_decoder;\n\n        std::vector<std::string> ov_input_names;\n        std::vector<std::string> ov_output_names;\n        for (const auto & ov_param : model->get_parameters()) {\n            ov_input_names.push_back(ov_param->get_friendly_name());\n        }\n        for (const auto & ov_output : model->get_results()) {\n            ov_output_names.push_back(ov_output->get_friendly_name());\n        }\n        r_ctx->ov_input_names_cache[key] = std::move(ov_input_names);\n        r_ctx->ov_output_names_cache[key] = std::move(ov_output_names);\n    }\n\n    auto ov_input_names = r_ctx->ov_input_names_cache[key];\n    auto ov_output_names = r_ctx->ov_output_names_cache[key];\n\n    if (is_prefill) {\n        auto inp_len = inp_pos->ne[0];\n        for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) {\n            for (size_t i = 0; i < ov_input_names.size(); i++) {\n                auto param_name = ov_input_names[i];\n                auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index);\n                infer_request->set_input_tensor(i, input_tensor);\n\n                if (getenv(\"GGML_OPENVINO_DEBUG_INPUT\")) {\n                    const auto input_tensor = infer_request->get_input_tensor(i);\n                    print_input_tensor_info(param_name, input_tensor);\n                }\n            }\n\n            for (size_t i = 0; i < ov_output_names.size(); i++) {\n                auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]);\n                auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor);\n                infer_request->set_output_tensor(i, output_tensor);\n            }\n\n            infer_request->infer();\n\n            if (getenv(\"GGML_OPENVINO_DEBUG_OUTPUT\")) {\n                for (size_t i 
= 0; i < ov_output_names.size(); i++) {\n                    const auto output_tensor = infer_request->get_output_tensor(i);\n                    print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data());\n                }\n            }\n        }\n        infer_end_time = ggml_time_us();\n    } else {\n        for (size_t i = 0; i < ov_input_names.size(); i++) {\n            auto param_name = ov_input_names[i];\n            auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name);\n            infer_request->set_input_tensor(i, input_tensor);\n\n            if (getenv(\"GGML_OPENVINO_DEBUG_INPUT\")) {\n                const auto input_tensor = infer_request->get_input_tensor(i);\n                print_input_tensor_info(param_name, input_tensor);\n            }\n        }\n\n        for (size_t i = 0; i < ov_output_names.size(); i++) {\n            auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]);\n            auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor);\n            infer_request->set_output_tensor(i, output_tensor);\n        }\n\n        infer_request->infer();\n        infer_end_time = ggml_time_us();\n\n        if (getenv(\"GGML_OPENVINO_DEBUG_OUTPUT\")) {\n            for (size_t i = 0; i < ov_output_names.size(); i++) {\n                const auto output_tensor = infer_request->get_output_tensor(i);\n                print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data());\n            }\n        }\n    }\n\n    if (getenv(\"GGML_OPENVINO_PROFILING\")) {\n        GGML_LOG_INFO(\"\\nGGML OpenVINO Backend: \\n\");\n        GGML_LOG_INFO(\"  - Graph decoder time: %ld ms \\n\", (decoder_end_time - start_time) / 1000);\n        if (!cache_hit) {\n            GGML_LOG_INFO(\"  - Graph conversion time: %ld ms \\n\", (conversion_end_time - decoder_end_time) / 1000);\n            GGML_LOG_INFO(\"  - Graph compile time: 
%ld ms \\n\", (compile_end_time - conversion_end_time) / 1000);\n        }\n        GGML_LOG_INFO(\"  - Graph inference time: %ld ms \\n\", (infer_end_time - compile_end_time) / 1000);\n    }\n\n    return GGML_STATUS_SUCCESS;\n}\n\nbool is_naive(ggml_cgraph * cgraph) {\n    constexpr int naive_graph_size_threshold = 20;\n    int count = 0;\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        if (cgraph->nodes[i]->op != GGML_OP_NONE) {\n            count++;\n        }\n    }\n    return count < naive_graph_size_threshold;\n}\n\nenum ggml_status naive_compute(ggml_cgraph * cgraph,\n                               ov::Core & core,\n                               const std::string & device,\n                               const ov::AnyMap & config) {\n    if (cgraph->n_nodes == 1 && (cgraph->nodes[0]->op == GGML_OP_NONE || cgraph->nodes[0]->op == GGML_OP_VIEW)) {\n        return GGML_STATUS_SUCCESS;\n    }\n\n    bool naive = true;\n    auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, naive);\n    auto decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights);\n    auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(decoder);\n    auto model = ov::frontend::ggml::FrontEnd::convert(input_model, naive);\n    if (getenv(\"GGML_OPENVINO_DUMP_IR\")) {\n        ov::serialize(model, \"IR_naive.xml\");\n    }\n\n    std::shared_ptr<ov::InferRequest> infer_request;\n    auto remote_context = ggml_openvino_get_remote_context();\n    if (cgraph->nodes[0]->op == GGML_OP_MUL_MAT) {\n        // TODO ACCURACY hint triggers a bug in GPU plugin/driver on Lunar Lake. 
Remove once CVS-182166 is resolved\n        core.set_property(device, ov::hint::execution_mode(ov::hint::ExecutionMode::PERFORMANCE));\n    } else {\n        core.set_property(device, ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY));\n    }\n    if (remote_context.has_value()) {\n        infer_request = std::make_shared<ov::InferRequest>(\n            core.compile_model(model, remote_context.value(), config).create_infer_request());\n    } else {\n        infer_request =\n            std::make_shared<ov::InferRequest>(core.compile_model(model, device, config).create_infer_request());\n    }\n\n    auto ov_params = model->get_parameters();\n    for (size_t i = 0; i < ov_params.size(); i++) {\n        auto param_name = ov_params[i]->get_friendly_name();\n        auto input_tensor = get_ov_input_tensor(decoder, param_name);\n        infer_request->set_input_tensor(i, input_tensor);\n    }\n\n    auto ov_results = model->get_results();\n    for (size_t i = 0; i < ov_results.size(); i++) {\n        auto * ggml_tensor = decoder->get_model_outputs().at(ov_results[i]->get_friendly_name());\n        auto output_tensor = create_ov_output_tensor(decoder, infer_request, i, ggml_tensor);\n        infer_request->set_output_tensor(i, output_tensor);\n    }\n\n    infer_request->infer();\n    return GGML_STATUS_SUCCESS;\n}\n\nnamespace {\nov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & name) {\n    const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(name);\n\n    if (ggml_tensor->extra != nullptr) {\n        // GGML_LOG_DEBUG(\"Using ggml_tensor->extra as ov::Tensor for input: %s\\n\", name.c_str());\n        auto * extra_base = static_cast<ggml_openvino_extra_base *>(ggml_tensor->extra);\n        if (extra_base->type != ggml_openvino_extra_base::Type::TENSOR) {\n            throw std::runtime_error(\"ggml tensor extra is not of type TENSOR for input: \" + name);\n        }\n        auto * tensor_extra = 
static_cast<ggml_openvino_tensor_extra *>(extra_base);\n        return *tensor_extra->tensor;\n    }\n\n    // GGML_LOG_DEBUG(\"Converting ggml tensor to ov::Tensor for input: %s\\n\", name.c_str());\n    auto * input_data = ggml_tensor->data;\n    ov::Shape input_shape;\n    if (ggml_tensor->op == GGML_OP_VIEW) {\n        // This case is added to make test-backend-ops work\n        input_shape = ggml_decoder->get_shape(ggml_tensor->view_src);\n    } else {\n        input_shape = ggml_decoder->get_shape(ggml_tensor);\n    }\n    auto input_tensor = ov::Tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape, input_data);\n    return input_tensor;\n}\n}  // namespace\n\nov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & param_name) {\n    ov::Tensor input_tensor;\n    if (ggml_decoder->get_model_extra_inputs().find(param_name) != ggml_decoder->get_model_extra_inputs().end()) {\n        input_tensor = *ggml_decoder->get_model_extra_input_values().at(param_name);\n    } else {\n        input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);\n    }\n    return input_tensor;\n}\n\nov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,\n                                             const std::string & param_name) {\n    // NPU decoding stage\n    const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);\n    const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);\n\n    if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) ||\n        GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) {\n        assert(ggml_tensor->ne[0] == 1);\n        ov::Shape input_shape = {1, 1, 1, 1};\n        ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape);\n        if (ggml_tensor->type == GGML_TYPE_I32) {\n            *input_tensor.data<int32_t>() = *((int32_t *) ggml_tensor->data);\n        } else if 
(ggml_tensor->type == GGML_TYPE_I64) {\n            *input_tensor.data<int64_t>() = *((int64_t *) ggml_tensor->data);\n        } else {\n            throw std::runtime_error(\"Unexpected tensor type for \" + param_name);\n        }\n        return input_tensor;\n    }\n\n    if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) {\n        ov::Shape input_shape = {1, 1, 1, 1};\n        ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape);\n        int32_t inp_out_id = *((int32_t *) ggml_tensor->data);\n        assert(ggml_tensor->ne[0] == 1);\n        assert(inp_out_id == 0);\n        *input_tensor.data<int32_t>() = inp_out_id;\n        return input_tensor;\n    }\n\n    if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) {\n        size_t context_size = ggml_decoder->get_ctx_size();\n        std::vector<float> padded_data = pad_input<float>(ggml_tensor, 1, context_size, -INFINITY);\n        ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, 1, context_size});\n        auto * data_ptr = input_tensor.data<float>();\n        std::copy(padded_data.begin(), padded_data.begin() + context_size, data_ptr);\n        return input_tensor;\n    }\n\n    return get_ov_input_tensor(ggml_decoder, param_name);\n}\n\nov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,\n                                              const std::string & param_name,\n                                              int chunk_index) {\n    // NPU prompt processing stage\n    const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);\n    const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);\n\n    const size_t input_len = ggml_decoder->get_input_len();\n    const size_t chunk_size = ggml_decoder->m_prefill_chunk_size;\n    const size_t chunk_valid_size = std::min(chunk_size, input_len - chunk_index * chunk_size);\n    const size_t chunk_pad_size = chunk_size - chunk_valid_size;\n\n    if 
(GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) ||\n        GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) {\n        ov::Shape input_shape = {1, 1, 1, chunk_size};\n        ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape);\n        // copy the chunk_index-th chunk from ggml_tensor\n        size_t element_size = ggml_type_size(ggml_tensor->type);\n        void * input_data = (char *) ggml_tensor->data + chunk_index * chunk_size * element_size;\n        std::memcpy(input_tensor.data(), input_data, chunk_valid_size * element_size);\n        // pad the rest with last_value + 1, so that kv's of padded positions are inserted\n        // to the next row after the valids row in the kvcache\n        if (chunk_pad_size > 0) {\n            if (ggml_tensor->type == GGML_TYPE_I32) {\n                int32_t last_value =\n                    *((int32_t *) ggml_tensor->data + (chunk_index * chunk_size + chunk_valid_size - 1));\n                int32_t * output_data = input_tensor.data<int32_t>();\n                std::fill(output_data + chunk_valid_size, output_data + chunk_size, last_value + 1);\n            } else if (ggml_tensor->type == GGML_TYPE_I64) {\n                int64_t last_value =\n                    *((int64_t *) ggml_tensor->data + (chunk_index * chunk_size + chunk_valid_size - 1));\n                int64_t * output_data = input_tensor.data<int64_t>();\n                std::fill(output_data + chunk_valid_size, output_data + chunk_size, last_value + 1);\n            } else {\n                throw std::runtime_error(\"Unexpected tensor type for \" + param_name);\n            }\n        }\n        return input_tensor;\n    }\n\n    if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) {\n        size_t output_len = ggml_decoder->get_compute_params().output_len;\n        ov::Shape input_shape = {1, 1, 1, output_len};\n        ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), 
input_shape);\n        if (ggml_tensor->ne[0] == 0) {\n            *input_tensor.data<int32_t>() = 0;\n        } else {\n            auto * data_addr = input_tensor.data<int32_t>();\n            for (size_t i = 0; i < output_len; i++) {\n                data_addr[i] = ((int32_t *) ggml_tensor->data)[i] % chunk_size;\n            }\n        }\n        return input_tensor;\n    }\n\n    if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) {\n        size_t cols = ggml_tensor->ne[0];\n        size_t rows = ggml_tensor->ne[1];\n        float * ggml_data = (float *) ggml_tensor->data + chunk_index * chunk_size * cols;\n        size_t chunk_valid_rows = std::min(chunk_size, rows - chunk_index * chunk_size);\n        size_t context_size = ggml_decoder->get_ctx_size();\n        std::vector<float> padded_data =\n            pad_input<float>(ggml_data, chunk_valid_rows, cols, chunk_size, context_size, -INFINITY);\n        set_zero_diagonal(padded_data, chunk_size, context_size);\n        ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, chunk_size, context_size});\n        auto * data_ptr = input_tensor.data<float>();\n        std::copy(padded_data.begin(), padded_data.begin() + chunk_size * context_size, data_ptr);\n        return input_tensor;\n    }\n\n    return get_ov_input_tensor(ggml_decoder, param_name);\n}\n\nsize_t checksum(const void * data, size_t size) {\n    const uint8_t * bytes = static_cast<const uint8_t *>(data);\n    size_t sum = 0;\n    for (size_t i = 0; i < size; ++i) {\n        sum += (uint8_t) i;\n        sum += bytes[i];\n    }\n    return sum;\n}\n\nvoid print_input_tensor_info(const std::string & name, const ov::Tensor & tensor) {\n    std::cout << \"Input name: \" << name << \", Input shape: \" << tensor.get_shape() << \", Address: \" << tensor.data()\n              << std::endl;\n    switch (tensor.get_element_type()) {\n    case ov::element::f32: {\n        if (name.find(\"self_kq_mask\") == std::string::npos) {\n            std::cout << 
*(tensor.data<float>()) << std::endl;\n        } else {\n            size_t rows = tensor.get_shape()[2];\n            size_t cols = tensor.get_shape()[3];\n            auto * data = tensor.data<float>();\n            for (size_t i = 0; i < rows; ++i) {\n                for (size_t j = 0; j < cols; ++j) {\n                    float val = data[i * cols + j];\n                    if (std::isinf(val) && val < 0) {\n                        std::cout << std::setw(5) << \"-inf\";\n                    } else {\n                        std::cout << std::setw(5) << val;\n                    }\n                }\n                std::cout << std::endl;\n            }\n        }\n\n        break;\n    }\n    case ov::element::f16:\n        std::cout << *(tensor.data<ov::float16>()) << std::endl;\n        break;\n    case ov::element::i32:\n        for (size_t i = 0; i < tensor.get_size(); ++i) {\n            std::cout << tensor.data<int32_t>()[i] << \" \";\n        }\n        std::cout << std::endl;\n        break;\n    case ov::element::i64:\n        for (size_t i = 0; i < tensor.get_size(); ++i) {\n            std::cout << tensor.data<int64_t>()[i] << \" \";\n        }\n        std::cout << std::endl;\n        break;\n    default:\n        break;\n    }\n}\n\nvoid print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, const void * output_dst) {\n    std::cout << \"Output name: \" << name << \", Output shape: \" << tensor.get_shape() << \", Address: \" << output_dst\n              << std::endl;\n\n    auto print_float_stats = [](const std::string & type_name, size_t size, auto get_value) {\n        if (size == 0) {\n            return;\n        }\n\n        float first = get_value(0);\n        float min = first;\n        float max = first;\n        double sum = first;\n\n        for (size_t i = 1; i < size; ++i) {\n            float v = get_value(i);\n            if (v < min) {\n                min = v;\n            }\n            if (v > max) {\n      
          max = v;\n            }\n            sum += v;\n        }\n        double mean = sum / size;\n\n        std::cout << std::right << std::setw(6) << type_name << std::right << std::setw(12) << \"First\" << std::setw(12)\n                  << \"Min\" << std::setw(12) << \"Max\" << std::setw(12) << \"Mean\" << std::endl;\n        std::cout << std::right << std::setw(6) << \"\" << std::right << std::setw(12) << first << std::setw(12) << min\n                  << std::setw(12) << max << std::setw(12) << mean << std::endl;\n    };\n\n    switch (tensor.get_element_type()) {\n    case ov::element::f32: {\n        const float * data = tensor.data<float>();\n        size_t size = tensor.get_size();\n        print_float_stats(\"[f32]\", size, [data](size_t i) { return data[i]; });\n        break;\n    }\n    case ov::element::f16: {\n        const ov::float16 * data = tensor.data<ov::float16>();\n        size_t size = tensor.get_size();\n        print_float_stats(\"[f16]\", size, [data](size_t i) { return static_cast<float>(data[i]); });\n        break;\n    }\n    default:\n        break;\n    }\n}\n\nvoid set_zero_diagonal(std::vector<float> & matrix, size_t rows, size_t cols) {\n    for (size_t i = 0; i < rows; ++i) {\n        size_t diag_col = std::min(i, cols - 1);\n        matrix[i * cols + diag_col] = 0.0f;\n    }\n}\n\nconst ggml_tensor * get_inp_pos_tensor(ggml_cgraph * cgraph) {\n    for (int i = 0; i < cgraph->n_nodes; ++i) {\n        auto * op = cgraph->nodes[i];\n        for (int j = 0; j < GGML_MAX_SRC; ++j) {\n            auto * src = op->src[j];\n            if (src == nullptr) {\n                break;\n            }\n            if (GgmlOvDecoder::is_inp_pos(src, op)) {\n                return src;\n            }\n        }\n    }\n    GGML_LOG_ERROR(\"get_inp_pos_tensor: inp_pos not found in cgraph\");\n    throw std::runtime_error(\"get_inp_pos_tensor: inp_pos not found in cgraph\");\n}\n\nbool get_is_prefill(const ggml_tensor * inp_pos) {\n    
return inp_pos->ne[0] > 1;\n}\n\n#pragma GCC diagnostic pop\n"
  },
  {
    "path": "src/ggml-openvino/utils.h",
    "content": "#include \"ggml-backend-impl.h\"\n#include \"ggml-decoder.h\"\n#include \"ggml-impl.h\"\n\n#include <algorithm>\n#include <cstddef>\n#include <memory>\n#include <openvino/runtime/core.hpp>\n#include <openvino/runtime/infer_request.hpp>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\nstruct graph_key {\n    int n_nodes;\n    std::string first_node_name;\n    std::string last_node_name;\n\n    graph_key(const ggml_cgraph * cgraph) : n_nodes(cgraph->n_nodes) {\n        if (n_nodes > 0) {\n            first_node_name = cgraph->nodes[0]->name;\n            last_node_name = cgraph->nodes[n_nodes - 1]->name;\n        }\n    }\n\n    bool operator==(const graph_key & other) const {\n        return n_nodes == other.n_nodes && first_node_name == other.first_node_name &&\n               last_node_name == other.last_node_name;\n    }\n};\n\nstruct graph_key_hash {\n    size_t operator()(const graph_key & key) const {\n        size_t h = std::hash<int>{}(key.n_nodes);\n        if (key.n_nodes > 0) {\n            h ^= std::hash<std::string>{}(key.first_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2);\n            h ^= std::hash<std::string>{}(key.last_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2);\n        }\n        return h;\n    }\n};\n\nstruct ov_runtime_context {\n    std::mutex ov_compute_mutex;\n    std::string device;\n    bool stateful;\n    std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;\n    std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;\n    std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;\n    std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;\n    std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_output_names_cache;\n    //TODO: Stateful is only supported for single request at a time.\n    //      
Simultaneous stateful inference request support to be added.\n    size_t stateful_kv_size;\n    std::map<std::string, std::string> kv_state_input_name_map;\n\n    ov_runtime_context() :\n        device(\"CPU\"),\n        stateful(false),\n        stateful_kv_size(0) {}\n};\n\nenum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend);\n\nenum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx);\nenum ggml_status ov_graph_compute_static(struct ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx);\n\nsize_t checksum(const void * data, size_t size);\n\nvoid print_input_tensor_info(const std::string & name, const ov::Tensor & tensor);\n\nvoid print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, const void * output_dst);\n\ntemplate <typename T>\nstd::vector<T> pad_input(const T * data,\n                         size_t rows,\n                         size_t cols,\n                         size_t padded_rows,\n                         size_t padded_cols,\n                         T pad_value) {\n    std::vector<T> padded(padded_rows * padded_cols, pad_value);\n\n    for (size_t i = 0; i < std::min(rows, padded_rows); ++i) {\n        for (size_t j = 0; j < std::min(cols, padded_cols); ++j) {\n            padded[i * padded_cols + j] = data[i * cols + j];\n        }\n    }\n\n    return padded;\n}\n\ntemplate <typename T>\nstd::vector<T> pad_input(const ggml_tensor * tensor, size_t padded_rows, size_t padded_cols, T pad_value) {\n    return pad_input<T>(reinterpret_cast<const T *>(tensor->data),\n                        static_cast<size_t>(tensor->ne[1]),  // rows\n                        static_cast<size_t>(tensor->ne[0]),  // cols\n                        padded_rows, padded_cols, pad_value);\n}\n\nvoid set_zero_diagonal(std::vector<float> & matrix, size_t rows, size_t cols);\n\nconst ggml_tensor * get_inp_pos_tensor(struct ggml_cgraph * 
cgraph);\n\nbool get_is_prefill(const ggml_tensor * inp_pos);\n\nov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & param_name);\nov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,\n                                             const std::string & param_name);\nov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,\n                                              const std::string & param_name,\n                                              int chunk_index);\n\nov::Tensor create_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder,\n                                   std::shared_ptr<ov::InferRequest> infer_request,\n                                   int output_index,\n                                   const ggml_tensor * ggml_tensor);\n\nbool is_naive(struct ggml_cgraph * cgraph);\n\nenum ggml_status naive_compute(struct ggml_cgraph * cgraph,\n                               ov::Core & core,\n                               const std::string & device,\n                               const ov::AnyMap & config);\n"
  },
  {
    "path": "src/ggml-opt.cpp",
    "content": "#include \"ggml-opt.h\"\n\n#include \"ggml.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n\n#include <algorithm>\n#include <cmath>\n#include <cstdint>\n#include <cinttypes>\n#include <map>\n#include <random>\n#include <vector>\n\nstruct ggml_opt_dataset {\n    struct ggml_context   * ctx    = nullptr;\n    ggml_backend_buffer_t   buf    = nullptr;\n    struct ggml_tensor    * data   = nullptr;\n    struct ggml_tensor    * labels = nullptr;\n\n    int64_t ndata       = -1;\n    int64_t ndata_shard = -1;\n    size_t  nbs_data    = -1;\n    size_t  nbs_labels  = -1;\n\n    std::vector<int64_t> permutation;\n};\n\nstruct ggml_opt_context {\n    ggml_backend_sched_t       backend_sched        = nullptr;\n    ggml_cgraph              * allocated_graph      = nullptr;\n    ggml_cgraph              * allocated_graph_copy = nullptr;\n    struct ggml_context      * ctx_static           = nullptr;\n    struct ggml_context      * ctx_cpu              = nullptr;\n    struct ggml_context      * ctx_compute          = nullptr;\n    struct ggml_context      * ctx_copy             = nullptr;\n    ggml_backend_buffer_t      buf_static           = nullptr;\n    ggml_backend_buffer_t      buf_cpu              = nullptr;\n    std::mt19937               rng;\n    enum ggml_opt_loss_type    loss_type;\n    enum ggml_opt_build_type   build_type;\n    enum ggml_opt_build_type   build_type_alloc;\n\n    struct ggml_tensor * inputs  = nullptr;\n    struct ggml_tensor * outputs = nullptr;\n    struct ggml_tensor * labels  = nullptr;\n\n    struct ggml_tensor * loss     = nullptr;\n    struct ggml_tensor * pred     = nullptr;\n    struct ggml_tensor * ncorrect = nullptr;\n\n    struct ggml_cgraph * gf      = nullptr;\n    struct ggml_cgraph * gb_grad = nullptr;\n    struct ggml_cgraph * gb_opt  = nullptr;\n    bool static_graphs           = false;\n    bool eval_ready              = false;\n    std::vector<struct ggml_tensor *> 
grad_accs;\n    std::vector<struct ggml_tensor *> grad_m;\n    std::vector<struct ggml_tensor *> grad_v;\n\n    int64_t iter               = 1;\n    int32_t opt_period         = 1;\n    int32_t opt_i              = 0;\n    bool    loss_per_datapoint = false;\n\n    ggml_opt_get_optimizer_params get_opt_pars    = nullptr;\n    void *                        get_opt_pars_ud = nullptr;\n    struct ggml_tensor *          opt_step_params = nullptr; // Stores output of get_opt_pars.\n\n    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;\n};\n\nstruct ggml_opt_result {\n    int64_t              ndata    = 0;\n    std::vector<float>   loss;\n    std::vector<int32_t> pred;\n    int64_t              ncorrect = 0;\n\n    int64_t opt_period         = -1;\n    bool    loss_per_datapoint = false;\n};\n\n// ====== Dataset ======\n\nggml_opt_dataset_t ggml_opt_dataset_init(\n        enum ggml_type type_data,\n        enum ggml_type type_label,\n        int64_t        ne_datapoint,\n        int64_t        ne_label,\n        int64_t        ndata,\n        int64_t        ndata_shard) {\n    GGML_ASSERT(ne_datapoint >  0);\n    GGML_ASSERT(ne_label     >= 0);\n    GGML_ASSERT(ndata        >  0);\n    GGML_ASSERT(ndata_shard  >  0);\n\n    ggml_opt_dataset_t result = new ggml_opt_dataset;\n    result->ndata       = ndata;\n    result->ndata_shard = ndata_shard;\n\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ 2*ggml_tensor_overhead(),\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true,\n        };\n        result->ctx = ggml_init(params);\n    }\n\n    result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);\n    result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;\n\n    if (ne_label > 0) {\n        result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);\n        result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;\n    } else 
{\n        result->labels = nullptr;\n        result->nbs_labels = 0;\n    }\n\n    result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type());\n\n    const int64_t nshards = ndata/ndata_shard;\n    result->permutation.resize(nshards);\n    for (int64_t i = 0; i < nshards; ++i) {\n        result->permutation[i] = i;\n    }\n    return result;\n}\n\nvoid ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {\n    ggml_backend_buffer_free(dataset->buf);\n    ggml_free(dataset->ctx);\n    delete dataset;\n}\n\nint64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {\n    return dataset->ndata;\n}\n\nstruct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {\n    return dataset->data;\n}\n\nstruct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) {\n    return dataset->labels;\n}\n\nvoid ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) {\n    GGML_ASSERT(idata <= dataset->ndata);\n\n    if (idata < 0) {\n        std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng);\n        return;\n    }\n\n    GGML_ASSERT(idata % dataset->ndata_shard == 0);\n    const int64_t ishard_max = idata / dataset->ndata_shard;\n    std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng);\n}\n\nvoid ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) {\n    GGML_ASSERT(   data_batch && ggml_is_contiguous(data_batch));\n    GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));\n    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));\n    GGML_ASSERT(                   data_batch->type == dataset->data->type);\n    GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);\n\n    const size_t nb_data_batch = ggml_nbytes(data_batch);\n    GGML_ASSERT(nb_data_batch % 
dataset->nbs_data == 0);\n    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;\n\n    if (labels_batch) {\n        const size_t nb_labels_batch = ggml_nbytes(labels_batch);\n        GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);\n    }\n\n    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));\n\n    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {\n        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];\n\n        const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;\n        ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);\n\n        if (!labels_batch) {\n            continue;\n        }\n\n        const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;\n        ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);\n    }\n}\n\nvoid ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {\n    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));\n    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);\n\n    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;\n\n    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));\n\n    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {\n        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];\n\n        const char * ptr_data       = (const char *) dataset->data->data + ishard      *dataset->nbs_data;\n        char       * ptr_data_batch = (char       *) data_batch          + ishard_batch*dataset->nbs_data;\n        memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);\n\n        if (!labels_batch) {\n   
         continue;\n        }\n\n        const char * ptr_labels       = (const char *) dataset->labels->data + ishard      *dataset->nbs_labels;\n        char       * ptr_labels_batch = (char       *) labels_batch          + ishard_batch*dataset->nbs_labels;\n        memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);\n    }\n}\n\n// ====== Model / Context ======\n\nstruct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {\n    GGML_UNUSED(userdata);\n\n    ggml_opt_optimizer_params result;\n\n    result.adamw.alpha = 0.001f;\n    result.adamw.beta1 = 0.9f;\n    result.adamw.beta2 = 0.999f;\n    result.adamw.eps   = 1e-8f;\n    result.adamw.wd    = 0.0f;\n\n    result.sgd.alpha   = 1e-3f;\n    result.sgd.wd      = 0.0f;\n\n    return result;\n}\n\n\nstruct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {\n    return *((struct ggml_opt_optimizer_params *) userdata);\n}\n\nstruct ggml_opt_params ggml_opt_default_params(\n        ggml_backend_sched_t      backend_sched,\n        enum ggml_opt_loss_type   loss_type) {\n    return {\n        /*backend_sched   =*/ backend_sched,\n        /*ctx_compute     =*/ nullptr,\n        /*inputs          =*/ nullptr,\n        /*logits          =*/ nullptr,\n        /*loss_type       =*/ loss_type,\n        /*build_type      =*/ GGML_OPT_BUILD_TYPE_OPT,\n        /*opt_period      =*/ 1,\n        /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,\n        /*get_opt_pars_ud =*/ nullptr,\n        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,\n    };\n}\n\nstatic ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_map, ggml_context * ctx, ggml_tensor * tensor) {\n    if (!tensor) {\n        return nullptr;\n    }\n\n    if (tensor_map.find(tensor) != tensor_map.end()) {\n        return tensor_map[tensor];\n    }\n\n    ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor);\n    tensor_map[tensor] = new_tensor;\n\n    
new_tensor->op = tensor->op;\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        new_tensor->nb[i] = tensor->nb[i];\n    }\n    new_tensor->flags = tensor->flags;\n    memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params));\n    strcpy(new_tensor->name, tensor->name);\n    new_tensor->data = tensor->data;\n    new_tensor->buffer = tensor->buffer;\n    new_tensor->extra = tensor->extra;\n    new_tensor->view_offs = tensor->view_offs;\n    new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src);\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]);\n    }\n\n    return new_tensor;\n}\n\nstatic ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {\n    std::map<ggml_tensor *, ggml_tensor *> tensor_map;\n\n    ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);\n\n    for (int i = 0; i < src->n_leafs; i++) {\n        ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i]));\n    }\n    GGML_ASSERT(dst->n_leafs == src->n_leafs);\n    for (int i = 0; i < src->n_nodes; i++) {\n        ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i]));\n    }\n    GGML_ASSERT(dst->n_nodes == src->n_nodes);\n    for (int i = 0; i < src->n_nodes; ++i) {\n        const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);\n        const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);\n\n        GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);\n        GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));\n        GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);\n        GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));\n\n        dst->grads[igrad_dst]     = src->grads[igrad_src];\n        dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];\n    }\n\n    return dst;\n}\n\nstatic void ggml_opt_build(ggml_opt_context_t opt_ctx) {\n    
GGML_ASSERT(opt_ctx->ctx_compute && \"no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc\");\n    GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && \"when using static graphs the inputs must be allocated statically\");\n\n    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;\n\n    const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&\n        !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);\n\n    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&\n        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;\n\n    ggml_set_input(opt_ctx->inputs);\n    ggml_set_output(opt_ctx->outputs);\n\n    int n_param = 0;\n    for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {\n        const struct ggml_tensor * node = opt_ctx->gf->nodes[i];\n        if (node->flags & GGML_TENSOR_FLAG_PARAM) {\n            n_param++;\n        }\n        GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && \"support for extra loss terms not implemented\");\n    }\n\n    if (!opt_ctx->ctx_static) {\n        // The static context is used for:\n        //   - gradients (1 per loss, 1 tensor per param if using gradient accumulation)\n        //   - optimizer momenta (2 tensors per param)\n        //   - labels (if using static graphs)\n        //   - loss (if using static graphs, up to 5 tensors)\n        //   - pred (if using static graphs)\n        //   - ncorrect (if using static graphs, 2 tensors).\n        constexpr size_t n_loss = 1;\n        const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);\n        const size_t tensors_const = opt_ctx->static_graphs ? 
9 : 0;\n        const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ size_meta,\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true,\n        };\n        opt_ctx->ctx_static = ggml_init(params);\n    }\n    GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);\n\n    {\n        // The cpu context is allocated statically if using static graphs, dynamically otherwise.\n        // It is used for:\n        //   - optimizer parameters (1 shared for all optimizer invocations)\n        const size_t size_meta = 1 * ggml_tensor_overhead();\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ size_meta,\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true,\n        };\n        ggml_free(opt_ctx->ctx_cpu);\n        opt_ctx->ctx_cpu = ggml_init(params);\n\n        ggml_backend_buffer_free(opt_ctx->buf_cpu);\n        opt_ctx->buf_cpu = nullptr;\n    }\n\n    struct ggml_context * ctx_results = opt_ctx->static_graphs ? 
opt_ctx->ctx_static : opt_ctx->ctx_compute;\n\n    switch (opt_ctx->loss_type) {\n        case GGML_OPT_LOSS_TYPE_MEAN: {\n            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);\n            ggml_set_name(opt_ctx->loss, \"loss_sum\");\n            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));\n            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);\n            ggml_set_name(opt_ctx->loss, \"loss_mean\");\n            opt_ctx->loss_per_datapoint = true;\n            break;\n        }\n        case GGML_OPT_LOSS_TYPE_SUM: {\n            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);\n            ggml_set_name(opt_ctx->loss, \"loss_sum\");\n            opt_ctx->loss_per_datapoint = false;\n            break;\n        }\n        case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {\n            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);\n            ggml_set_input(opt_ctx->labels);\n            ggml_set_name(opt_ctx->labels, \"labels\");\n            opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);\n            ggml_set_name(opt_ctx->loss, \"loss_cross_entropy\");\n            if (opt_ctx->opt_period > 1) {\n                opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);\n                ggml_set_name(opt_ctx->loss, \"loss_cross_entropy_scaled\");\n            }\n            opt_ctx->loss_per_datapoint = true;\n            break;\n        }\n        case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {\n            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);\n            ggml_set_input(opt_ctx->labels);\n            ggml_set_name(opt_ctx->labels, \"labels\");\n            opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);\n            ggml_set_name(opt_ctx->loss, \"loss_error\");\n            opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);\n            
ggml_set_name(opt_ctx->loss, \"loss_squared_error\");\n            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);\n            ggml_set_name(opt_ctx->loss, \"loss_sum_squared_error\");\n            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));\n            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);\n            ggml_set_name(opt_ctx->loss, \"loss_mean_squared_error\");\n            opt_ctx->loss_per_datapoint = true;\n            break;\n        }\n    }\n    ggml_set_output(opt_ctx->loss);\n    ggml_set_loss(opt_ctx->loss);\n    ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);\n\n    if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {\n        opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);\n        ggml_set_name(opt_ctx->pred, \"pred\");\n        ggml_set_output(opt_ctx->pred);\n        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);\n\n        opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));\n        ggml_set_name(opt_ctx->ncorrect, \"ncorrect\");\n        ggml_set_output(opt_ctx->ncorrect);\n        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);\n    }\n\n    if (opt_ctx->buf_static) {\n        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {\n            return;\n        }\n    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {\n        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(\n            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));\n        return;\n    }\n\n    if (opt_ctx->grad_accs.empty()) {\n        GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);\n\n        const int n_nodes = opt_ctx->gf->n_nodes;\n        opt_ctx->grad_accs.resize(n_nodes);\n        for (int i = 0; i < n_nodes; ++i) {\n            ggml_tensor * node = opt_ctx->gf->nodes[i];\n            if ((accumulate && (node->flags 
& GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {\n                opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);\n            } else {\n                opt_ctx->grad_accs[i] = nullptr;\n            }\n        }\n\n        if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {\n            opt_ctx->grad_m.resize(n_nodes);\n            opt_ctx->grad_v.resize(n_nodes);\n            for (int i = 0; i < n_nodes; ++i) {\n                ggml_tensor * node = opt_ctx->gf->nodes[i];\n                if (node->flags & GGML_TENSOR_FLAG_PARAM) {\n                    opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);\n                    opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);\n                } else {\n                    opt_ctx->grad_m[i] = nullptr;\n                    opt_ctx->grad_v[i] = nullptr;\n                }\n            }\n        }\n    }\n\n    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.\n    opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);\n    ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());\n\n    if (opt_ctx->buf_static) {\n        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {\n            return;\n        }\n    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {\n        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));\n        ggml_graph_reset(opt_ctx->gb_grad);\n    }\n\n    GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);\n\n    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.\n    opt_ctx->gb_opt = 
ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);\n\n    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);\n    ggml_tensor * adamw_params = opt_ctx->opt_step_params;\n    ggml_set_input(adamw_params);\n    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);\n    ggml_format_name(adamw_params, \"%s_params\", optimizer_name);\n    for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {\n        struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];\n        struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);\n\n        if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {\n            struct ggml_tensor * m = nullptr;\n            struct ggml_tensor * v = nullptr;\n            if (need_momenta) {\n                m = opt_ctx->grad_m[i];\n                v = opt_ctx->grad_v[i];\n                ggml_format_name(m, \"AdamW m for %s\", node->name);\n                ggml_format_name(v, \"AdamW v for %s\", node->name);\n            }\n            struct ggml_tensor * opt_step;\n            switch (optimizer) {\n                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:\n                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);\n                    break;\n                case GGML_OPT_OPTIMIZER_TYPE_SGD:\n                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);\n                    break;\n                default:\n                    GGML_ABORT(\"fatal error\");\n            }\n            ggml_format_name(opt_step, \"%s step for %s\", optimizer_name, node->name);\n            ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);\n        }\n    }\n\n    if (!opt_ctx->buf_static) {\n        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(\n            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));\n        
ggml_graph_reset(opt_ctx->gb_opt);\n    }\n\n    opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());\n}\n\nggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {\n    ggml_opt_context_t result = new struct ggml_opt_context;\n    result->backend_sched    = params.backend_sched;\n    result->ctx_compute      = params.ctx_compute;\n    result->loss_type        = params.loss_type;\n    result->build_type       = params.build_type;\n    result->build_type_alloc = params.build_type;\n    result->inputs           = params.inputs;\n    result->outputs          = params.outputs;\n    result->opt_period       = params.opt_period;\n    result->get_opt_pars     = params.get_opt_pars;\n    result->get_opt_pars_ud  = params.get_opt_pars_ud;\n    result->optimizer        = params.optimizer;\n\n    GGML_ASSERT(result->opt_period >= 1);\n\n    result->static_graphs = result->ctx_compute;\n\n    if (!result->static_graphs) {\n        GGML_ASSERT(!result->inputs);\n        GGML_ASSERT(!result->outputs);\n        return result;\n    }\n\n    GGML_ASSERT(result->inputs);\n    GGML_ASSERT(result->outputs);\n\n    result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.\n    ggml_build_forward_expand(result->gf, result->outputs);\n\n    ggml_opt_build(result);\n\n    return result;\n}\n\nvoid ggml_opt_free(ggml_opt_context_t opt_ctx) {\n    if (opt_ctx == nullptr) {\n        return;\n    }\n    ggml_backend_buffer_free(opt_ctx->buf_static);\n    ggml_backend_buffer_free(opt_ctx->buf_cpu);\n    ggml_free(opt_ctx->ctx_static);\n    ggml_free(opt_ctx->ctx_cpu);\n    delete opt_ctx;\n}\n\nvoid ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {\n    if (optimizer) {\n        ggml_graph_reset(opt_ctx->gb_opt);\n        opt_ctx->iter = 1;\n    } else {\n        ggml_graph_reset(opt_ctx->gb_grad);\n    }\n}\n\nbool ggml_opt_static_graphs(ggml_opt_context_t 
opt_ctx) {\n    return opt_ctx->static_graphs;\n}\n\nstruct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {\n    return opt_ctx->inputs;\n}\n\nstruct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) {\n    return opt_ctx->outputs;\n}\n\nstruct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) {\n    return opt_ctx->labels;\n}\n\nstruct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) {\n    return opt_ctx->loss;\n}\n\nstruct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) {\n    return opt_ctx->pred;\n}\n\nstruct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) {\n    return opt_ctx->ncorrect;\n}\n\nstruct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) {\n    return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node);\n}\n\n// ====== Optimization Result ======\n\nggml_opt_result_t ggml_opt_result_init() {\n    return new ggml_opt_result;\n}\n\nvoid ggml_opt_result_free(ggml_opt_result_t result) {\n    delete result;\n}\n\nvoid ggml_opt_result_reset(ggml_opt_result_t result) {\n    result->ndata = 0;\n    result->loss.clear();\n    result->pred.clear();\n    result->ncorrect = 0;\n}\n\nvoid ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) {\n    *ndata = result->ndata;\n}\n\nvoid ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) {\n    const int64_t nbatches = result->loss.size(); // Number of physical batches.\n\n    if (nbatches == 0) {\n        *loss = 0.0;\n        *unc  = NAN;\n        return;\n    }\n\n    double sum         = 0.0;\n    double sum_squared = 0.0;\n\n    for (const float & loss : result->loss) {\n        // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.\n        const float loss_scaled = result->loss_per_datapoint ? 
loss*result->opt_period : loss;\n        sum         += loss_scaled;\n        sum_squared += loss_scaled*loss_scaled;\n    }\n\n    const double mean = sum/nbatches;\n    *loss = result->loss_per_datapoint ? mean : sum;\n\n    if (!unc) {\n        return;\n    }\n\n    if (nbatches < 2) {\n        *unc = NAN;\n        return;\n    }\n\n    const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1)\n    *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));\n}\n\nvoid ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) {\n    for (size_t i = 0; i < result->pred.size(); ++i) {\n        pred[i] = result->pred[i];\n    }\n}\n\nvoid ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) {\n    *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN;\n\n    if (!unc) {\n        return;\n    }\n\n    *unc = result->ncorrect >= 0 && result->ndata >= 2 ?\n        sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN;\n}\n\n// ====== Computation ======\n\nvoid ggml_opt_prepare_alloc(\n        ggml_opt_context_t    opt_ctx,\n        struct ggml_context * ctx_compute,\n        struct ggml_cgraph  * gf,\n        struct ggml_tensor  * inputs,\n        struct ggml_tensor  * outputs) {\n    GGML_ASSERT(!opt_ctx->static_graphs);\n    opt_ctx->ctx_compute = ctx_compute;\n    opt_ctx->gf          = gf;\n    opt_ctx->inputs      = inputs;\n    opt_ctx->outputs     = outputs;\n}\n\nvoid ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {\n    GGML_ASSERT(!opt_ctx->eval_ready);\n    if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {\n        ggml_graph_reset(opt_ctx->gb_grad);\n    }\n    if (backward) {\n        const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;\n        opt_ctx->build_type = 
opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;\n    } else {\n        opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;\n    }\n\n    if (!opt_ctx->static_graphs) {\n        ggml_opt_build(opt_ctx);\n    }\n\n    struct ggml_cgraph * graph = nullptr;\n    switch (opt_ctx->build_type) {\n        case GGML_OPT_BUILD_TYPE_FORWARD: {\n            graph = opt_ctx->gf;\n        } break;\n        case GGML_OPT_BUILD_TYPE_GRAD: {\n            graph = opt_ctx->gb_grad;\n        } break;\n        case GGML_OPT_BUILD_TYPE_OPT: {\n            graph = opt_ctx->gb_opt;\n        } break;\n    }\n    GGML_ASSERT(graph);\n\n    if (opt_ctx->allocated_graph == graph) {\n        opt_ctx->eval_ready = true;\n        return;\n    }\n\n    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph\n\n    if (opt_ctx->static_graphs) {\n        ggml_init_params params = {\n            /*.mem_size   =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true,\n        };\n        ggml_free(opt_ctx->ctx_copy);\n        opt_ctx->ctx_copy = ggml_init(params);\n\n        opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);\n    } else {\n        opt_ctx->allocated_graph_copy = graph;\n    }\n\n    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);\n    opt_ctx->allocated_graph = graph;\n\n    opt_ctx->eval_ready = true;\n}\n\nvoid ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {\n    GGML_ASSERT(opt_ctx->eval_ready);\n    if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {\n        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);\n\n        switch (opt_ctx->optimizer) {\n            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {\n                GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);\n                
GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);\n                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);\n                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);\n                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);\n                GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);\n                GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);\n                GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);\n\n                // beta1, beta2 after applying warmup\n                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));\n                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));\n\n                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);\n                adamw_par_data[0] = opt_pars.adamw.alpha;\n                adamw_par_data[1] = opt_pars.adamw.beta1;\n                adamw_par_data[2] = opt_pars.adamw.beta2;\n                adamw_par_data[3] = opt_pars.adamw.eps;\n                adamw_par_data[4] = opt_pars.adamw.wd;\n                adamw_par_data[5] = beta1h;\n                adamw_par_data[6] = beta2h;\n            } break;\n            case GGML_OPT_OPTIMIZER_TYPE_SGD: {\n                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);\n                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);\n                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);\n                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);\n                sgd[0] = opt_pars.sgd.alpha;\n                sgd[1] = opt_pars.sgd.wd;\n            } break;\n            default:\n                GGML_ABORT(\"fatal error\");\n        }\n    }\n\n    ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);\n    opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;\n    opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;\n\n    if (!opt_ctx->static_graphs) {\n        opt_ctx->gf                   = nullptr;\n        opt_ctx->gb_grad              = nullptr;\n      
  opt_ctx->gb_opt               = nullptr;\n        opt_ctx->allocated_graph      = nullptr;\n        opt_ctx->allocated_graph_copy = nullptr;\n    }\n\n    opt_ctx->eval_ready = false;\n\n    if (!result) {\n        return;\n    }\n\n    if (result->ndata == 0) {\n        result->loss_per_datapoint = opt_ctx->loss_per_datapoint;\n        result->opt_period         = opt_ctx->opt_period;\n    } else {\n        GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);\n        GGML_ASSERT(result->opt_period         == opt_ctx->opt_period);\n    }\n\n    const int64_t ndata = opt_ctx->outputs->ne[1];\n    GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && \"varying batch size not supported\");\n    result->ndata += ndata;\n\n    GGML_ASSERT(ggml_is_scalar(opt_ctx->loss));\n    GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32);\n    float loss;\n    ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));\n    result->loss.push_back(loss);\n\n    if (opt_ctx->pred) {\n        GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);\n        std::vector<int32_t> pred(ndata);\n        ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));\n        result->pred.insert(result->pred.end(), pred.begin(), pred.end());\n    }\n\n    if (!opt_ctx->ncorrect || result->ncorrect < 0) {\n        result->ncorrect = -1;\n        return;\n    }\n\n    GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect));\n    GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64);\n    int64_t ncorrect;\n    ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect));\n    result->ncorrect += ncorrect;\n}\n\n// ====== High-Level Functions ======\n\nvoid ggml_opt_epoch(\n        ggml_opt_context_t      opt_ctx,\n        ggml_opt_dataset_t      dataset,\n        ggml_opt_result_t       result_train,\n        ggml_opt_result_t       result_eval,\n        int64_t                 idata_split,\n        
ggml_opt_epoch_callback callback_train,\n        ggml_opt_epoch_callback callback_eval) {\n    GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && \"ggml_opt_epoch requires static graphs\");\n    struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);\n    struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);\n    struct ggml_tensor * data   = ggml_opt_dataset_data(dataset);\n    GGML_ASSERT(data->ne[0] == inputs->ne[0]);\n\n    const int64_t ndata       =   data->ne[1];\n    const int64_t ndata_batch = inputs->ne[1];\n\n    GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);\n    const int64_t nbatches = ndata/ndata_batch;\n\n    idata_split = idata_split < 0 ? ndata : idata_split;\n    GGML_ASSERT(idata_split % ndata_batch == 0);\n    const int64_t ibatch_split = idata_split / ndata_batch;\n\n    int64_t ibatch = 0;\n    int64_t t_loop_start = ggml_time_us();\n    for (; ibatch < ibatch_split; ++ibatch) {\n        ggml_opt_alloc(opt_ctx, /*backward =*/ true);\n        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);\n        ggml_opt_eval(opt_ctx, result_train);\n        if (callback_train) {\n            callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);\n        }\n    }\n    t_loop_start = ggml_time_us();\n    for (; ibatch < nbatches; ++ibatch) {\n        ggml_opt_alloc(opt_ctx, /*backward =*/ false);\n        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);\n        ggml_opt_eval(opt_ctx, result_eval);\n        if (callback_eval) {\n            callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);\n        }\n    }\n}\n\nvoid ggml_opt_epoch_callback_progress_bar(\n        bool               train,\n        ggml_opt_context_t opt_ctx,\n        ggml_opt_dataset_t dataset,\n        ggml_opt_result_t  result,\n        int64_t            ibatch,\n        int64_t            ibatch_max,\n        int64_t            t_start_us) {\n    fprintf(stderr, 
\"%s[\", train ? \"train: \" : \"val:   \");\n\n    // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.\n    constexpr int64_t bar_length = 8;\n    const int64_t ibatch8 = 8 * ibatch;\n    for (int64_t j = 0; j < bar_length; ++j) {\n        if        (ibatch_max * (8*j + 8) / bar_length < ibatch8) {\n            fprintf(stderr, \"\\u2588\"); // full block\n        } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {\n            fprintf(stderr, \"\\u2589\"); // 7/8 filled\n        } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {\n            fprintf(stderr, \"\\u258A\"); // 6/8 filled\n        } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {\n            fprintf(stderr, \"\\u258B\"); // 5/8 filled\n        } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {\n            fprintf(stderr, \"\\u258C\"); // 4/8 filled\n        } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {\n            fprintf(stderr, \"\\u258D\"); // 3/8 filled\n        } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {\n            fprintf(stderr, \"\\u258E\"); // 2/8 filled\n        } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {\n            fprintf(stderr, \"\\u258F\"); // 1/8 filled\n        } else {\n            fprintf(stderr, \" \");\n        }\n    }\n\n    const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1];\n    const int64_t idata      = ibatch*batch_size;\n    const int64_t idata_max  = ibatch_max*batch_size;\n\n    double loss;\n    double loss_unc;\n    ggml_opt_result_loss(result, &loss, &loss_unc);\n\n    double accuracy;\n    double accuracy_unc;\n    ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);\n\n    const int64_t t_ibatch_us = ggml_time_us() - t_start_us;\n    int64_t t_ibatch_s = t_ibatch_us / 1000000;\n    const int64_t t_ibatch_h = t_ibatch_s / 3600;\n    t_ibatch_s -= t_ibatch_h * 3600;\n    const int64_t t_ibatch_m = t_ibatch_s / 60;\n    
t_ibatch_s -= t_ibatch_m * 60;\n\n    const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;\n    int64_t t_eta_s = t_eta_us / 1000000;\n    const int64_t t_eta_h = t_eta_s / 3600;\n    t_eta_s -= t_eta_h * 3600;\n    const int64_t t_eta_m = t_eta_s / 60;\n    t_eta_s -= t_eta_m * 60;\n\n    fprintf(stderr, \"] data=%07\" PRId64 \"/%07\" PRId64 \" loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% \"\n            \"t=%02\" PRId64 \":%02\" PRId64 \":%02\" PRId64 \" ETA=%02\" PRId64 \":%02\" PRId64 \":%02\" PRId64 \" \\r\",\n            idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,\n            t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);\n    if (ibatch == ibatch_max) {\n        fprintf(stderr, \"\\n\");\n    }\n    fflush(stderr);\n\n    GGML_UNUSED(dataset);\n}\n\nvoid ggml_opt_fit(\n        ggml_backend_sched_t            backend_sched,\n        ggml_context                  * ctx_compute,\n        ggml_tensor                   * inputs,\n        ggml_tensor                   * outputs,\n        ggml_opt_dataset_t              dataset,\n        enum ggml_opt_loss_type         loss_type,\n        enum ggml_opt_optimizer_type    optimizer,\n        ggml_opt_get_optimizer_params   get_opt_pars,\n        int64_t                         nepoch,\n        int64_t                         nbatch_logical,\n        float                           val_split,\n        bool                            silent) {\n    ggml_time_init();\n    const int64_t t_start_us = ggml_time_us();\n\n    const int64_t ndata           = ggml_opt_dataset_data(dataset)->ne[1];\n    const int64_t nbatch_physical = inputs->ne[1];\n    GGML_ASSERT(ndata          % nbatch_logical  == 0);\n    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);\n\n    const int64_t opt_period       = nbatch_logical / nbatch_physical;\n    const int64_t nbatches_logical = ndata / nbatch_logical;\n\n    GGML_ASSERT(val_split >= 0.0f);\n    GGML_ASSERT(val_split <  1.0f);\n    
const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical)\n    const int64_t idata_split  = ibatch_split * nbatch_physical;\n\n    int64_t epoch = 1;\n\n    ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);\n    params.ctx_compute     = ctx_compute;\n    params.inputs          = inputs;\n    params.outputs         = outputs;\n    params.opt_period      = opt_period;\n    params.get_opt_pars    = get_opt_pars;\n    params.get_opt_pars_ud = &epoch;\n    params.optimizer       = optimizer;\n    ggml_opt_context_t opt_ctx = ggml_opt_init(params);\n\n    // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.\n    if (nbatch_logical < ndata) {\n        ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).\n    }\n\n    ggml_opt_result_t result_train = ggml_opt_result_init();\n    ggml_opt_result_t result_val   = ggml_opt_result_init();\n\n    ggml_opt_epoch_callback epoch_callback = silent ? 
nullptr : ggml_opt_epoch_callback_progress_bar;\n\n    for (; epoch <= nepoch; ++epoch) {\n        if (nbatch_logical < idata_split) {\n            ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);\n        }\n\n        ggml_opt_result_reset(result_train);\n        ggml_opt_result_reset(result_val);\n\n        if (!silent) {\n            fprintf(stderr, \"%s: epoch %04\" PRId64 \"/%04\" PRId64 \":\\n\", __func__, epoch, nepoch);\n        }\n        ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);\n        if (!silent) {\n            fprintf(stderr, \"\\n\");\n        }\n    }\n\n    if (!silent) {\n        int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000;\n        const int64_t t_total_h = t_total_s / 3600;\n        t_total_s -= t_total_h * 3600;\n        const int64_t t_total_m = t_total_s / 60;\n        t_total_s -= t_total_m * 60;\n        fprintf(stderr, \"%s: training took %02\" PRId64 \":%02\" PRId64 \":%02\" PRId64 \"\\n\", __func__, t_total_h, t_total_m, t_total_s);\n    }\n\n    ggml_opt_free(opt_ctx);\n    ggml_opt_result_free(result_train);\n    ggml_opt_result_free(result_val);\n}\n\nenum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {\n    return c->optimizer;\n}\n\nGGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {\n    switch (o) {\n        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:\n            return \"adamw\";\n        case GGML_OPT_OPTIMIZER_TYPE_SGD:\n            return \"sgd\";\n        default:\n            return \"undefined\";\n    };\n}\n"
  },
  {
    "path": "src/ggml-quants.c",
    "content": "#define GGML_COMMON_IMPL_C\n#include \"ggml-common.h\"\n\n#include \"ggml-quants.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-cpu/ggml-cpu-impl.h\"\n#include \"ggml-cpu.h\"\n\n#include <math.h>\n#include <string.h>\n#include <assert.h>\n#include <float.h>\n#include <stdlib.h> // for qsort\n#include <stdio.h>  // for GGML_ASSERT\n\n#define GROUP_MAX_EPS 1e-15f\n#define GROUP_MAX_EPS_IQ3_XXS 1e-8f\n#define GROUP_MAX_EPS_IQ2_S 1e-8f\n#define GROUP_MAX_EPS_IQ1_M 1e-7f\n#define GROUP_MAX_EPS_IQ1_S 1e-12f\n\n#define UNUSED GGML_UNUSED\n\nstatic inline int best_index_int8(int n, const int8_t * val, float x) {\n    if (x <= val[0]) return 0;\n    if (x >= val[n-1]) return n-1;\n    int ml = 0, mu = n-1;\n    while (mu-ml > 1) {\n        int mav = (ml+mu)/2;\n        if (x < val[mav]) mu = mav; else ml = mav;\n    }\n    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;\n}\n\n// reference implementation for deterministic creation of model files\nvoid quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK4_0;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        float amax = 0.0f; // absolute max\n        float max  = 0.0f;\n\n        for (int j = 0; j < qk; j++) {\n            const float v = x[i*qk + j];\n            if (amax < fabsf(v)) {\n                amax = fabsf(v);\n                max  = v;\n            }\n        }\n\n        const float d  = max / -8;\n        const float id = d ? 
1.0f/d : 0.0f;\n\n        y[i].d = GGML_FP32_TO_FP16(d);\n\n        for (int j = 0; j < qk/2; ++j) {\n            const float x0 = x[i*qk + 0    + j]*id;\n            const float x1 = x[i*qk + qk/2 + j]*id;\n\n            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));\n            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));\n\n            y[i].qs[j]  = xi0;\n            y[i].qs[j] |= xi1 << 4;\n        }\n    }\n}\n\nvoid quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k) {\n    const int qk = QK4_1;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        float min = FLT_MAX;\n        float max = -FLT_MAX;\n\n        for (int j = 0; j < qk; j++) {\n            const float v = x[i*qk + j];\n\n            if (v < min) min = v;\n            if (v > max) max = v;\n        }\n\n        const float d  = (max - min) / ((1 << 4) - 1);\n        const float id = d ? 1.0f/d : 0.0f;\n\n        y[i].d = GGML_FP32_TO_FP16(d);\n        y[i].m = GGML_FP32_TO_FP16(min);\n\n        for (int j = 0; j < qk/2; ++j) {\n            const float x0 = (x[i*qk + 0    + j] - min)*id;\n            const float x1 = (x[i*qk + qk/2 + j] - min)*id;\n\n            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));\n            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));\n\n            y[i].qs[j]  = xi0;\n            y[i].qs[j] |= xi1 << 4;\n        }\n    }\n}\n\nvoid quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK5_0;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        float amax = 0.0f; // absolute max\n        float max  = 0.0f;\n\n        for (int j = 0; j < qk; j++) {\n            const float v = x[i*qk + j];\n            if (amax < fabsf(v)) {\n                amax = fabsf(v);\n                max  = v;\n            }\n        }\n\n        const float d  = max 
/ -16;\n        const float id = d ? 1.0f/d : 0.0f;\n\n        y[i].d = GGML_FP32_TO_FP16(d);\n\n        uint32_t qh = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const float x0 = x[i*qk + 0    + j]*id;\n            const float x1 = x[i*qk + qk/2 + j]*id;\n\n            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));\n            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));\n\n            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);\n\n            // get the 5-th bit and store it in qh at the right position\n            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);\n        }\n\n        memcpy(&y[i].qh, &qh, sizeof(qh));\n    }\n}\n\nvoid quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k) {\n    const int qk = QK5_1;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        float min = FLT_MAX;\n        float max = -FLT_MAX;\n\n        for (int j = 0; j < qk; j++) {\n            const float v = x[i*qk + j];\n\n            if (v < min) min = v;\n            if (v > max) max = v;\n        }\n\n        const float d  = (max - min) / ((1 << 5) - 1);\n        const float id = d ? 
1.0f/d : 0.0f;\n\n        y[i].d = GGML_FP32_TO_FP16(d);\n        y[i].m = GGML_FP32_TO_FP16(min);\n\n        uint32_t qh = 0;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const float x0 = (x[i*qk + 0    + j] - min)*id;\n            const float x1 = (x[i*qk + qk/2 + j] - min)*id;\n\n            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);\n            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);\n\n            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);\n\n            // get the 5-th bit and store it in qh at the right position\n            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);\n        }\n\n        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));\n    }\n}\n\n// reference implementation for deterministic creation of model files\nvoid quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK8_0 == 0);\n    const int nb = k / QK8_0;\n\n    for (int i = 0; i < nb; i++) {\n        float amax = 0.0f; // absolute max\n\n        for (int j = 0; j < QK8_0; j++) {\n            const float v = x[i*QK8_0 + j];\n            amax = MAX(amax, fabsf(v));\n        }\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 
1.0f/d : 0.0f;\n\n        y[i].d = GGML_FP32_TO_FP16(d);\n\n        for (int j = 0; j < QK8_0; ++j) {\n            const float x0 = x[i*QK8_0 + j]*id;\n\n            y[i].qs[j] = roundf(x0);\n        }\n    }\n}\n\n// reference implementation for deterministic creation of model files\nvoid quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k) {\n    assert(QK8_1 == 32);\n    assert(k % QK8_1 == 0);\n    const int nb = k / QK8_1;\n\n    for (int i = 0; i < nb; i++) {\n        float amax = 0.0f; // absolute max\n\n        for (int j = 0; j < QK8_1; j++) {\n            const float v = x[i*QK8_1 + j];\n            amax = MAX(amax, fabsf(v));\n        }\n\n        const float d = amax / ((1 << 7) - 1);\n        const float id = d ? 1.0f/d : 0.0f;\n\n        y[i].d = GGML_FP32_TO_FP16(d);\n\n        int sum = 0;\n\n        for (int j = 0; j < QK8_1/2; ++j) {\n            const float v0 = x[i*QK8_1           + j]*id;\n            const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;\n\n            y[i].qs[          j] = roundf(v0);\n            y[i].qs[QK8_1/2 + j] = roundf(v1);\n\n            sum += y[i].qs[          j];\n            sum += y[i].qs[QK8_1/2 + j];\n        }\n\n        y[i].s = GGML_FP32_TO_FP16(sum*d);\n    }\n}\n\nstatic inline int best_index_mxfp4(float x, float e) {\n    int best_index = 0;\n    float best_err = fabsf(kvalues_mxfp4[0]*e - x);\n    for (int i = 1; i < 16; i++) {\n        float err = fabsf(kvalues_mxfp4[i]*e - x);\n        if (err < best_err) {\n            best_index = i;\n            best_err = err;\n        }\n    }\n    return best_index;\n}\n\nvoid quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK_MXFP4;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        float amax = 0.0f; // absolute max\n\n        for (int j = 0; j < qk; j++) {\n            const float v = x[i*qk + 
j];\n\n            if (amax < fabsf(v)) {\n                amax = fabsf(v);\n            }\n        }\n\n        const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;\n\n        const float d = GGML_E8M0_TO_FP32_HALF(e);\n\n        y[i].e = e;\n\n        for (int j = 0; j < qk/2; ++j) {\n            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);\n            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);\n\n            y[i].qs[j]  = x0;\n            y[i].qs[j] |= x1 << 4;\n        }\n    }\n}\n\nvoid quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK_NVFP4;\n    static const int qk_sub = QK_NVFP4_SUB;\n    static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        for (int s = 0; s < n_sub; s++) {\n            const float * xb = x + i*qk + s*qk_sub;\n\n            float amax = 0.0f;\n            for (int j = 0; j < qk_sub; j++) {\n                if (amax < fabsf(xb[j])) {\n                    amax = fabsf(xb[j]);\n                }\n            }\n\n            // UE4M3 scale: amax / 6.0 maps the max E2M1 value (6.0) to amax\n            const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f);\n            y[i].d[s] = ue;\n            const float d = ggml_ue4m3_to_fp32(ue);\n\n            for (int j = 0; j < qk_sub/2; ++j) {\n                const uint8_t x0 = best_index_mxfp4(xb[0        + j], d);\n                const uint8_t x1 = best_index_mxfp4(xb[qk_sub/2 + j], d);\n\n                y[i].qs[s*(qk_sub/2) + j] = x0 | (x1 << 4);\n            }\n        }\n    }\n}\n\nvoid dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK4_0;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        const float d = 
GGML_FP16_TO_FP32(x[i].d);\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int x0 = (x[i].qs[j] & 0x0F) - 8;\n            const int x1 = (x[i].qs[j] >>   4) - 8;\n\n            y[i*qk + j + 0   ] = x0*d;\n            y[i*qk + j + qk/2] = x1*d;\n        }\n    }\n}\n\nvoid dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK4_1;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        const float m = GGML_FP16_TO_FP32(x[i].m);\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int x0 = (x[i].qs[j] & 0x0F);\n            const int x1 = (x[i].qs[j] >>   4);\n\n            y[i*qk + j + 0   ] = x0*d + m;\n            y[i*qk + j + qk/2] = x1*d + m;\n        }\n    }\n}\n\nvoid dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK5_0;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n\n        uint32_t qh;\n        memcpy(&qh, x[i].qh, sizeof(qh));\n\n        for (int j = 0; j < qk/2; ++j) {\n            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;\n            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;\n\n            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;\n            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;\n\n            y[i*qk + j + 0   ] = x0*d;\n            y[i*qk + j + qk/2] = x1*d;\n        }\n    }\n}\n\nvoid dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK5_1;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        const float m = GGML_FP16_TO_FP32(x[i].m);\n\n        uint32_t qh;\n    
    memcpy(&qh, x[i].qh, sizeof(qh));\n\n        for (int j = 0; j < qk/2; ++j) {\n            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;\n            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;\n\n            const int x0 = (x[i].qs[j] & 0x0F) | xh_0;\n            const int x1 = (x[i].qs[j] >>   4) | xh_1;\n\n            y[i*qk + j + 0   ] = x0*d + m;\n            y[i*qk + j + qk/2] = x1*d + m;\n        }\n    }\n}\n\nvoid dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK8_0;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n\n        for (int j = 0; j < qk; ++j) {\n            y[i*qk + j] = x[i].qs[j]*d;\n        }\n    }\n}\n\nvoid dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK_MXFP4;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);\n\n        for (int j = 0; j < qk/2; ++j) {\n            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];\n            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >>   4];\n\n            y[i*qk + j + 0   ] = x0*d;\n            y[i*qk + j + qk/2] = x1*d;\n        }\n    }\n}\n\nvoid dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    static const int qk = QK_NVFP4;\n    static const int qk_sub = QK_NVFP4_SUB;\n    static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n        for (int s = 0; s < n_sub; s++) {\n            const float d = ggml_ue4m3_to_fp32(x[i].d[s]);\n            float * yb = y + i*qk + s*qk_sub;\n\n            for (int j = 0; j < qk_sub/2; ++j) {\n                const int8_t v0 = 
kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] & 0x0F];\n                const int8_t v1 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] >>   4];\n\n                yb[j + 0       ] = v0*d;\n                yb[j + qk_sub/2] = v1*d;\n            }\n        }\n    }\n}\n\n//\n// 2-6 bit quantization in super-blocks\n//\n\n//\n// ===================== Helper functions\n//\nstatic inline int nearest_int(float fval) {\n    assert(fabsf(fval) <= 4194303.f);\n    float val = fval + 12582912.f;\n    int i; memcpy(&i, &val, sizeof(int));\n    return (i & 0x007fffff) - 0x00400000;\n}\n\nstatic float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type,\n        const float * GGML_RESTRICT qw) {\n    float max = 0;\n    float amax = 0;\n    for (int i = 0; i < n; ++i) {\n        float ax = fabsf(x[i]);\n        if (ax > amax) { amax = ax; max = x[i]; }\n    }\n    if (amax < GROUP_MAX_EPS) { // all zero\n        for (int i = 0; i < n; ++i) {\n            L[i] = 0;\n        }\n        return 0.f;\n    }\n    float iscale = -nmax / max;\n    if (rmse_type == 0) {\n        for (int i = 0; i < n; ++i) {\n            int l = nearest_int(iscale * x[i]);\n            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));\n        }\n        return 1/iscale;\n    }\n    bool return_early = false;\n    if (rmse_type < 0) {\n        rmse_type = -rmse_type;\n        return_early = true;\n    }\n    float sumlx = 0;\n    float suml2 = 0;\n#ifdef HAVE_BUGGY_APPLE_LINKER\n    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7\n    for (volatile int i = 0; i < n; ++i) {\n#else\n    for (int i = 0; i < n; ++i) {\n#endif\n        int l = nearest_int(iscale * x[i]);\n        l = MAX(-nmax, MIN(nmax-1, l));\n        L[i] = l + nmax;\n        float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? 
fabsf(x[i]) : sqrtf(fabsf(x[i]));\n        sumlx += w*x[i]*l;\n        suml2 += w*l*l;\n    }\n    float scale = suml2 ? sumlx/suml2 : 0.0f;\n    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;\n    float best = scale * sumlx;\n    for (int is = -9; is <= 9; ++is) {\n        if (is == 0) {\n            continue;\n        }\n        iscale = -(nmax + 0.1f*is) / max;\n        sumlx = suml2 = 0;\n        for (int i = 0; i < n; ++i) {\n            int l = nearest_int(iscale * x[i]);\n            l = MAX(-nmax, MIN(nmax-1, l));\n            float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));\n            sumlx += w*x[i]*l;\n            suml2 += w*l*l;\n        }\n        if (suml2 > 0 && sumlx*sumlx > best*suml2) {\n            for (int i = 0; i < n; ++i) {\n                int l = nearest_int(iscale * x[i]);\n                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));\n            }\n            scale = sumlx/suml2; best = scale*sumlx;\n        }\n    }\n    return scale;\n}\n\nstatic float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, bool do_rmse) {\n    float max = 0;\n    float amax = 0;\n    for (int i = 0; i < n; ++i) {\n        float ax = fabsf(x[i]);\n        if (ax > amax) { amax = ax; max = x[i]; }\n    }\n    if (amax < GROUP_MAX_EPS) { // all zero\n        for (int i = 0; i < n; ++i) { L[i] = 0; }\n        return 0.f;\n    }\n    float iscale = -nmax / max;\n    if (do_rmse) {\n        float sumlx = 0;\n        float suml2 = 0;\n        for (int i = 0; i < n; ++i) {\n            int l = nearest_int(iscale * x[i]);\n            l = MAX(-nmax, MIN(nmax-1, l));\n            L[i] = l;\n            float w = x[i]*x[i];\n            sumlx += w*x[i]*l;\n            suml2 += w*l*l;\n        }\n        for (int itry = 0; itry < 5; ++itry) {\n            int n_changed = 0;\n            for (int i = 0; i < n; ++i) {\n              
  float w = x[i]*x[i];\n                float slx = sumlx - w*x[i]*L[i];\n                if (slx > 0) {\n                    float sl2 = suml2 - w*L[i]*L[i];\n                    int new_l = nearest_int(x[i] * sl2 / slx);\n                    new_l = MAX(-nmax, MIN(nmax-1, new_l));\n                    if (new_l != L[i]) {\n                        slx += w*x[i]*new_l;\n                        sl2 += w*new_l*new_l;\n                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {\n                            L[i] = new_l; sumlx = slx; suml2 = sl2;\n                            ++n_changed;\n                        }\n                    }\n                }\n            }\n            if (!n_changed) {\n                break;\n            }\n        }\n        for (int i = 0; i < n; ++i) {\n            L[i] += nmax;\n        }\n        return suml2 > 0.0f ? sumlx / suml2 : 0.0f;\n    }\n    for (int i = 0; i < n; ++i) {\n        int l = nearest_int(iscale * x[i]);\n        l = MAX(-nmax, MIN(nmax-1, l));\n        L[i] = l + nmax;\n    }\n    return 1/iscale;\n}\n\nstatic float make_qkx1_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min,\n        int ntry, float alpha) {\n    float min = x[0];\n    float max = x[0];\n    for (int i = 1; i < n; ++i) {\n        if (x[i] < min) min = x[i];\n        if (x[i] > max) max = x[i];\n    }\n    if (max == min) {\n        for (int i = 0; i < n; ++i) L[i] = 0;\n        *the_min = 0;\n        return 0.f;\n    }\n    if (min > 0) min = 0;\n    float iscale = nmax/(max - min);\n    float scale = 1/iscale;\n    for (int itry = 0; itry < ntry; ++itry) {\n        float sumlx = 0; int suml2 = 0;\n        bool did_change = false;\n        for (int i = 0; i < n; ++i) {\n            int l = nearest_int(iscale*(x[i] - min));\n            l = MAX(0, MIN(nmax, l));\n            if (l != L[i]) {\n                L[i] = l;\n                did_change = true;\n            
}\n            sumlx += (x[i] - min)*l;\n            suml2 += l*l;\n        }\n        scale = sumlx/suml2;\n        float sum = 0;\n        for (int i = 0; i < n; ++i) {\n            sum += x[i] - scale*L[i];\n        }\n        min = alpha*min + (1 - alpha)*sum/n;\n        if (min > 0) min = 0;\n        iscale = 1/scale;\n        if (!did_change) break;\n    }\n    *the_min = -min;\n    return scale;\n}\n\nstatic float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights,\n        uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux,\n        float rmin, float rdelta, int nstep, bool use_mad) {\n    float min = x[0];\n    float max = x[0];\n    float sum_w = weights[0];\n    float sum_x = sum_w * x[0];\n#ifdef HAVE_BUGGY_APPLE_LINKER\n    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7\n    for (volatile int i = 1; i < n; ++i) {\n#else\n    for (int i = 1; i < n; ++i) {\n#endif\n        if (x[i] < min) min = x[i];\n        if (x[i] > max) max = x[i];\n        float w = weights[i];\n        sum_w += w;\n        sum_x += w * x[i];\n    }\n    if (min > 0) min = 0;\n    if (max == min) {\n        for (int i = 0; i < n; ++i) L[i] = 0;\n        *the_min = -min;\n        return 0.f;\n    }\n    float iscale = nmax/(max - min);\n    float scale = 1/iscale;\n    float best_error = 0;\n    for (int i = 0; i < n; ++i) {\n        int l = nearest_int(iscale*(x[i] - min));\n        L[i] = MAX(0, MIN(nmax, l));\n        float diff = scale * L[i] + min - x[i];\n        diff = use_mad ? 
fabsf(diff) : diff * diff;\n        float w = weights[i];\n        best_error += w * diff;\n    }\n    if (nstep < 1) {\n        *the_min = -min;\n        return scale;\n    }\n    for (int is = 0; is <= nstep; ++is) {\n        iscale = (rmin + rdelta*is + nmax)/(max - min);\n        float sum_l = 0, sum_l2 = 0, sum_xl = 0;\n        for (int i = 0; i < n; ++i) {\n            int l = nearest_int(iscale*(x[i] - min));\n            l = MAX(0, MIN(nmax, l));\n            Laux[i] = l;\n            float w = weights[i];\n            sum_l += w*l;\n            sum_l2 += w*l*l;\n            sum_xl += w*l*x[i];\n        }\n        float D = sum_w * sum_l2 - sum_l * sum_l;\n        if (D > 0) {\n            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;\n            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;\n            if (this_min > 0) {\n                this_min = 0;\n                this_scale = sum_xl / sum_l2;\n            }\n            float cur_error = 0;\n            for (int i = 0; i < n; ++i) {\n                float diff = this_scale * Laux[i] + this_min - x[i];\n                diff = use_mad ? 
fabsf(diff) : diff * diff;\n                float w = weights[i];\n                cur_error += w * diff;\n            }\n            if (cur_error < best_error) {\n                for (int i = 0; i < n; ++i) {\n                    L[i] = Laux[i];\n                }\n                best_error = cur_error;\n                scale = this_scale;\n                min = this_min;\n            }\n        }\n    }\n    *the_min = -min;\n    return scale;\n}\n\nstatic inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) {\n    if (j < 4) {\n        *d = q[j] & 63; *m = q[j + 4] & 63;\n    } else {\n        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);\n        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);\n    }\n}\n\n//========================- 2-bit (de)-quantization\n\nvoid quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    uint8_t L[QK_K];\n    uint8_t Laux[16];\n    float   weights[16];\n    float mins[QK_K/16];\n    float scales[QK_K/16];\n\n    const float q4scale = 15.f;\n\n    for (int i = 0; i < nb; i++) {\n        float max_scale = 0; // as we are deducting the min, scales are always positive\n        float max_min = 0;\n        for (int j = 0; j < QK_K/16; ++j) {\n            for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);\n            scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);\n            float scale = scales[j];\n            if (scale > max_scale) {\n                max_scale = scale;\n            }\n            float min = mins[j];\n            if (min > max_min) {\n                max_min = min;\n            }\n        }\n\n        if (max_scale > 0) {\n            float iscale = q4scale/max_scale;\n            for (int j = 0; j < QK_K/16; ++j) {\n                int l = nearest_int(iscale*scales[j]);\n       
         y[i].scales[j] = l;\n            }\n            y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale);\n        } else {\n            for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;\n            y[i].d = GGML_FP32_TO_FP16(0.f);\n        }\n        if (max_min > 0) {\n            float iscale = q4scale/max_min;\n            for (int j = 0; j < QK_K/16; ++j) {\n                int l = nearest_int(iscale*mins[j]);\n                y[i].scales[j] |= (l << 4);\n            }\n            y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale);\n        } else {\n            y[i].dmin = GGML_FP32_TO_FP16(0.f);\n        }\n        for (int j = 0; j < QK_K/16; ++j) {\n            const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);\n            if (!d) continue;\n            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);\n            for (int ii = 0; ii < 16; ++ii) {\n                int l = nearest_int((x[16*j + ii] + dm)/d);\n                l = MAX(0, MIN(3, l));\n                L[16*j + ii] = l;\n            }\n        }\n\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) {\n                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);\n            }\n        }\n\n        x += QK_K;\n    }\n}\n\nvoid dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        const float min = GGML_FP16_TO_FP32(x[i].dmin);\n\n        const uint8_t * q = x[i].qs;\n\n        int is = 0;\n        float dl, ml;\n        for (int n = 0; n < QK_K; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                uint8_t sc = x[i].scales[is++];\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 
16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;\n\n                sc = x[i].scales[is++];\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\nstatic float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights,\n        uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux,\n        float rmin, float rdelta, int nstep, bool use_mad) {\n    float min = x[0];\n    float max = x[0];\n    float sum_w = weights ? weights[0] : x[0]*x[0];\n    float sum_x = sum_w * x[0];\n#ifdef HAVE_BUGGY_APPLE_LINKER\n    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7\n    for (volatile int i = 1; i < n; ++i) {\n#else\n    for (int i = 1; i < n; ++i) {\n#endif\n        if (x[i] < min) min = x[i];\n        if (x[i] > max) max = x[i];\n        float w = weights ? weights[i] : x[i]*x[i];\n        sum_w += w;\n        sum_x += w * x[i];\n    }\n    if (min > 0) {\n        min = 0;\n    }\n    if (max <= min) {\n        memset(L, 0, n);\n        *the_min = -min;\n        return 0.f;\n    }\n    float iscale = nmax/(max - min);\n    float scale = 1/iscale;\n    float best_mad = 0;\n    for (int i = 0; i < n; ++i) {\n        int l = nearest_int(iscale*(x[i] - min));\n        L[i] = MAX(0, MIN(nmax, l));\n        float diff = scale * L[i] + min - x[i];\n        diff = use_mad ? fabsf(diff) : diff*diff;\n        float w = weights ? 
weights[i] : x[i]*x[i];\n        best_mad += w * diff;\n    }\n    if (nstep < 1) {\n        *the_min = -min;\n        return scale;\n    }\n    for (int is = 0; is <= nstep; ++is) {\n        iscale = (rmin + rdelta*is + nmax)/(max - min);\n        float sum_l = 0, sum_l2 = 0, sum_xl = 0;\n        for (int i = 0; i < n; ++i) {\n            int l = nearest_int(iscale*(x[i] - min));\n            l = MAX(0, MIN(nmax, l));\n            Laux[i] = l;\n            float w = weights ? weights[i] : x[i]*x[i];\n            sum_l  += w*l;\n            sum_l2 += w*l*l;\n            sum_xl += w*l*x[i];\n        }\n        float D = sum_w * sum_l2 - sum_l * sum_l;\n        if (D > 0) {\n            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;\n            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;\n            if (this_min > 0) {\n                this_min = 0;\n                this_scale = sum_xl / sum_l2;\n            }\n            float mad = 0;\n            for (int i = 0; i < n; ++i) {\n                float diff = this_scale * Laux[i] + this_min - x[i];\n                diff = use_mad ? fabsf(diff) : diff*diff;\n                float w = weights ? 
weights[i] : x[i]*x[i];\n                mad += w * diff;\n            }\n            if (mad < best_mad) {\n                for (int i = 0; i < n; ++i) {\n                    L[i] = Laux[i];\n                }\n                best_mad = mad;\n                scale = this_scale;\n                min = this_min;\n            }\n        }\n    }\n    *the_min = -min;\n    return scale;\n}\n\nstatic float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, const float * quant_weights) {\n    float max = 0;\n    for (int i = 0; i < n; ++i) {\n        max = MAX(max, x[i]);\n    }\n    if (max < GROUP_MAX_EPS) { // all zero\n        for (int i = 0; i < n; ++i) { L[i] = 0; }\n        return 0.f;\n    }\n    float iscale = nmax / max;\n    for (int i = 0; i < n; ++i) {\n        L[i] = nearest_int(iscale * x[i]);\n    }\n    float scale = 1/iscale;\n    float best_mse = 0;\n    for (int i = 0; i < n; ++i) {\n        float diff = x[i] - scale*L[i];\n        float w = quant_weights[i];\n        best_mse += w*diff*diff;\n    }\n    for (int is = -4; is <= 4; ++is) {\n        if (is == 0) continue;\n        float iscale_is = (0.1f*is + nmax)/max;\n        float scale_is = 1/iscale_is;\n        float mse = 0;\n        for (int i = 0; i < n; ++i) {\n            int l = nearest_int(iscale_is*x[i]);\n            l = MIN(nmax, l);\n            float diff = x[i] - scale_is*l;\n            float w = quant_weights[i];\n            mse += w*diff*diff;\n        }\n        if (mse < best_mse) {\n            best_mse = mse;\n            iscale = iscale_is;\n        }\n    }\n    float sumlx = 0;\n    float suml2 = 0;\n    for (int i = 0; i < n; ++i) {\n        int l = nearest_int(iscale * x[i]);\n        l = MIN(nmax, l);\n        L[i] = l;\n        float w = quant_weights[i];\n        sumlx += w*x[i]*l;\n        suml2 += w*l*l;\n    }\n    for (int itry = 0; itry < 5; ++itry) {\n        int n_changed = 0;\n        for (int i = 0; i < n; ++i) {\n      
      float w = quant_weights[i];\n            float slx = sumlx - w*x[i]*L[i];\n            float sl2 = suml2 - w*L[i]*L[i];\n            if (slx > 0 && sl2 > 0) {\n                int new_l = nearest_int(x[i] * sl2 / slx);\n                new_l = MIN(nmax, new_l);\n                if (new_l != L[i]) {\n                    slx += w*x[i]*new_l;\n                    sl2 += w*new_l*new_l;\n                    if (slx*slx*suml2 > sumlx*sumlx*sl2) {\n                        L[i] = new_l; sumlx = slx; suml2 = sl2;\n                        ++n_changed;\n                    }\n                }\n            }\n        }\n        if (!n_changed) {\n            break;\n        }\n    }\n    return suml2 > 0.0f ? sumlx / suml2 : 0.0f;\n}\n\nstatic void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k, const float * GGML_RESTRICT quant_weights) {\n    GGML_ASSERT(quant_weights);\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n    const bool requantize = true;\n\n    uint8_t L[QK_K];\n    uint8_t Laux[16];\n    float mins[QK_K/16];\n    float scales[QK_K/16];\n    float sw[QK_K/16];\n    float weight[16];\n    uint8_t Ls[QK_K/16], Lm[QK_K/16];\n\n    for (int i = 0; i < nb; i++) {\n        memset(sw, 0, QK_K/16*sizeof(float));\n        float sumx2 = 0;\n        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];\n        float sigma2 = sumx2/QK_K;\n        for (int j = 0; j < QK_K/16; ++j) {\n            const float * GGML_RESTRICT qw = quant_weights + QK_K * i + 16*j;\n            for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);\n            for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];\n            scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);\n        }\n\n        float dm, mm;\n        dm  = make_qp_quants(QK_K/16, 15, scales, Ls, sw);\n        mm  = make_qp_quants(QK_K/16, 15, mins,   Lm, sw);\n\n        y[i].d    
= GGML_FP32_TO_FP16(dm);\n        y[i].dmin = GGML_FP32_TO_FP16(mm);\n        dm        = GGML_FP16_TO_FP32(y[i].d);\n        mm        = GGML_FP16_TO_FP32(y[i].dmin);\n\n        for (int j = 0; j < QK_K/16; ++j) {\n            y[i].scales[j] = Ls[j] | (Lm[j] << 4);\n        }\n\n        if (requantize) {\n            for (int j = 0; j < QK_K/16; ++j) {\n                const float d = dm * (y[i].scales[j] & 0xF);\n                if (!d) continue;\n                const float m = mm * (y[i].scales[j] >> 4);\n                for (int ii = 0; ii < 16; ++ii) {\n                    int l = nearest_int((x[16*j + ii] + m)/d);\n                    l = MAX(0, MIN(3, l));\n                    L[16*j + ii] = l;\n                }\n            }\n        }\n\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) {\n                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);\n            }\n        }\n\n        x += QK_K;\n    }\n}\n\nsize_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);\n    if (!quant_weights) {\n        quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);\n    }\n    else {\n        char * qrow = (char *)dst;\n        for (int64_t row = 0; row < nrow; ++row) {\n            quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);\n            src += n_per_row;\n            qrow += row_size;\n        }\n    }\n    return nrow * row_size;\n}\n\n//========================= 3-bit (de)-quantization\n\nvoid quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    int8_t L[QK_K];\n    float scales[QK_K / 16];\n\n    for (int i = 0; i < nb; i++) {\n\n        float max_scale = 0;\n        float amax = 
0;\n        for (int j = 0; j < QK_K/16; ++j) {\n            scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);\n            float scale = fabsf(scales[j]);\n            if (scale > amax) {\n                amax = scale; max_scale = scales[j];\n            }\n        }\n\n        memset(y[i].scales, 0, 12);\n        if (max_scale) {\n            float iscale = -32.f/max_scale;\n            for (int j = 0; j < QK_K/16; ++j) {\n                int8_t l = nearest_int(iscale*scales[j]);\n                l = MAX(-32, MIN(31, l)) + 32;\n                if (j < 8) {\n                    y[i].scales[j] = l & 0xF;\n                } else {\n                    y[i].scales[j-8] |= ((l & 0xF) << 4);\n                }\n                l >>= 4;\n                y[i].scales[j%4 + 8] |= (l << (2*(j/4)));\n            }\n            y[i].d = GGML_FP32_TO_FP16(1/iscale);\n        } else {\n            y[i].d = GGML_FP32_TO_FP16(0.f);\n        }\n\n        int8_t sc;\n        for (int j = 0; j < QK_K/16; ++j) {\n            sc = j < 8 ? 
y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;\n            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;\n            float d = GGML_FP16_TO_FP32(y[i].d) * sc;\n            if (!d) {\n                continue;\n            }\n            for (int ii = 0; ii < 16; ++ii) {\n                int l = nearest_int(x[16*j + ii]/d);\n                l = MAX(-4, MIN(3, l));\n                L[16*j + ii] = l + 4;\n            }\n        }\n\n        memset(y[i].hmask, 0, QK_K/8);\n        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.\n        int m = 0;\n        uint8_t hm = 1;\n        for (int j = 0; j < QK_K; ++j) {\n            if (L[j] > 3) {\n                y[i].hmask[m] |= hm;\n                L[j] -= 4;\n            }\n            if (++m == QK_K/8) {\n                m = 0; hm <<= 1;\n            }\n        }\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) {\n                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);\n            }\n        }\n\n        x += QK_K;\n    }\n}\n\nvoid dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n\n    uint32_t aux[4];\n    const int8_t * scales = (const int8_t*)aux;\n\n    for (int i = 0; i < nb; i++) {\n\n        const float d_all = GGML_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT q = x[i].qs;\n        const uint8_t * GGML_RESTRICT hm = x[i].hmask;\n        uint8_t m = 1;\n\n        memcpy(aux, x[i].scales, 12);\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = 
(aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < QK_K; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n\n    }\n}\n\nstatic void quantize_row_q3_K_impl(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t n_per_row, const float * GGML_RESTRICT quant_weights) {\n    assert(n_per_row % QK_K == 0);\n    const int nb = n_per_row / QK_K;\n\n    int8_t L[QK_K];\n    float scales[QK_K / 16];\n    float weight[16];\n    float sw[QK_K / 16];\n    int8_t Ls[QK_K / 16];\n\n    for (int i = 0; i < nb; i++) {\n\n        float sumx2 = 0;\n        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];\n        float sigma2 = 2*sumx2/QK_K;\n\n        for (int j = 0; j < QK_K/16; ++j) {\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K * i + 16*j;\n                for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]);\n            } else {\n                for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l];\n            }\n            float sumw = 0;\n            for (int l = 0; l < 16; ++l) sumw += weight[l];\n            sw[j] = sumw;\n\n            scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);\n\n        }\n\n        memset(y[i].scales, 0, 12);\n\n        float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);\n        for 
(int j = 0; j < QK_K/16; ++j) {\n            int l = Ls[j];\n            if (j < 8) {\n                y[i].scales[j] = l & 0xF;\n            } else {\n                y[i].scales[j-8] |= ((l & 0xF) << 4);\n            }\n            l >>= 4;\n            y[i].scales[j%4 + 8] |= (l << (2*(j/4)));\n        }\n        y[i].d = GGML_FP32_TO_FP16(d_block);\n\n        int8_t sc;\n        for (int j = 0; j < QK_K/16; ++j) {\n            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;\n            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;\n            float d = GGML_FP16_TO_FP32(y[i].d) * sc;\n            if (!d) {\n                continue;\n            }\n            for (int ii = 0; ii < 16; ++ii) {\n                int l = nearest_int(x[16*j + ii]/d);\n                l = MAX(-4, MIN(3, l));\n                L[16*j + ii] = l + 4;\n            }\n        }\n\n        memset(y[i].hmask, 0, QK_K/8);\n        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.\n        int m = 0;\n        uint8_t hm = 1;\n        for (int j = 0; j < QK_K; ++j) {\n            if (L[j] > 3) {\n                y[i].hmask[m] |= hm;\n                L[j] -= 4;\n            }\n            if (++m == QK_K/8) {\n                m = 0; hm <<= 1;\n            }\n        }\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) {\n                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);\n            }\n        }\n\n        x += QK_K;\n    }\n}\n\nsize_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);\n    if (!quant_weights) {\n        quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);\n    }\n    else {\n        char * qrow = (char *)dst;\n        for (int64_t row = 0; row < nrow; 
++row) {\n            quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);\n            src += n_per_row;\n            qrow += row_size;\n        }\n    }\n    return nrow * row_size;\n}\n\n// ====================== 4-bit (de)-quantization\n\nvoid quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    uint8_t L[QK_K];\n    uint8_t Laux[32];\n    float   weights[32];\n    float mins[QK_K/32];\n    float scales[QK_K/32];\n\n    for (int i = 0; i < nb; i++) {\n        float max_scale = 0; // as we are deducting the min, scales are always positive\n        float max_min = 0;\n        for (int j = 0; j < QK_K/32; ++j) {\n            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);\n            float sum_x2 = 0;\n            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];\n            float av_x = sqrtf(sum_x2/32);\n            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);\n            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);\n            float scale = scales[j];\n            if (scale > max_scale) {\n                max_scale = scale;\n            }\n            float min = mins[j];\n            if (min > max_min) {\n                max_min = min;\n            }\n        }\n\n        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;\n        float inv_min   = max_min   > 0 ? 
63.f/max_min   : 0.f;\n        for (int j = 0; j < QK_K/32; ++j) {\n            uint8_t ls = nearest_int(inv_scale*scales[j]);\n            uint8_t lm = nearest_int(inv_min*mins[j]);\n            ls = MIN(63, ls);\n            lm = MIN(63, lm);\n            if (j < 4) {\n                y[i].scales[j] = ls;\n                y[i].scales[j+4] = lm;\n            } else {\n                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);\n                y[i].scales[j-4] |= ((ls >> 4) << 6);\n                y[i].scales[j-0] |= ((lm >> 4) << 6);\n            }\n        }\n        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);\n        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);\n\n        uint8_t sc, m;\n        for (int j = 0; j < QK_K/32; ++j) {\n            get_scale_min_k4(j, y[i].scales, &sc, &m);\n            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;\n            if (!d) continue;\n            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;\n            for (int ii = 0; ii < 32; ++ii) {\n                int l = nearest_int((x[32*j + ii] + dm)/d);\n                l = MAX(0, MIN(15, l));\n                L[32*j + ii] = l;\n            }\n        }\n\n        uint8_t * q = y[i].qs;\n        for (int j = 0; j < QK_K; j += 64) {\n            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);\n            q += 32;\n        }\n\n        x += QK_K;\n    }\n}\n\nvoid dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n        const uint8_t * q = x[i].qs;\n\n        const float d   = GGML_FP16_TO_FP32(x[i].d);\n        const float min = GGML_FP16_TO_FP32(x[i].dmin);\n\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < QK_K; j += 64) {\n            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;\n            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;\n            q += 32; is += 2;\n        }\n    }\n}\n\nstatic void quantize_row_q4_K_impl(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {\n    assert(n_per_row % QK_K == 0);\n    const int64_t nb = n_per_row / QK_K;\n\n    uint8_t L[QK_K];\n    uint8_t Laux[32];\n    uint8_t Ls[QK_K/32];\n    uint8_t Lm[QK_K/32];\n    float   weights[32];\n    float   sw[QK_K/32];\n    float   mins[QK_K/32];\n    float   scales[QK_K/32];\n\n    for (int i = 0; i < nb; i++) {\n\n        float sum_x2 = 0;\n        for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];\n        float sigma2 = 2*sum_x2/QK_K;\n        float av_x = sqrtf(sigma2);\n\n        for (int j = 0; j < QK_K/32; ++j) {\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K*i + 32*j;\n                for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);\n            } else {\n                for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);\n            }\n            float sumw = 0;\n            for (int l = 0; l < 32; ++l) sumw += weights[l];\n            sw[j] = sumw;\n            scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);\n        }\n\n        float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);\n        float m_block = make_qp_quants(QK_K/32, 63, mins,   Lm, sw);\n        for (int j = 0; j < QK_K/32; ++j) {\n            uint8_t ls = Ls[j];\n            uint8_t lm = Lm[j];\n            if (j < 4) {\n                y[i].scales[j] = ls;\n                y[i].scales[j+4] = lm;\n            } else {\n                y[i].scales[j+4] = (ls 
& 0xF) | ((lm & 0xF) << 4);\n                y[i].scales[j-4] |= ((ls >> 4) << 6);\n                y[i].scales[j-0] |= ((lm >> 4) << 6);\n            }\n        }\n        y[i].d = GGML_FP32_TO_FP16(d_block);\n        y[i].dmin = GGML_FP32_TO_FP16(m_block);\n\n        uint8_t sc, m;\n        for (int j = 0; j < QK_K/32; ++j) {\n            get_scale_min_k4(j, y[i].scales, &sc, &m);\n            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;\n            if (!d) continue;\n            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;\n            for (int ii = 0; ii < 32; ++ii) {\n                int l = nearest_int((x[32*j + ii] + dm)/d);\n                l = MAX(0, MIN(15, l));\n                L[32*j + ii] = l;\n            }\n        }\n        uint8_t * q = y[i].qs;\n        for (int j = 0; j < QK_K; j += 64) {\n            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);\n            q += 32;\n        }\n\n        x += QK_K;\n\n    }\n}\n\nsize_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);\n    if (!quant_weights) {\n        quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);\n    }\n    else {\n        char * qrow = (char *)dst;\n        for (int64_t row = 0; row < nrow; ++row) {\n            quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);\n            src += n_per_row;\n            qrow += row_size;\n        }\n    }\n    return nrow * row_size;\n}\n\n// ====================== 5-bit (de)-quantization\n\nvoid quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    uint8_t L[QK_K];\n    float mins[QK_K/32];\n    float scales[QK_K/32];\n    float weights[32];\n    uint8_t Laux[32];\n\n    for (int i = 0; i < nb; i++) {\n        float max_scale = 
0; // as we are deducting the min, scales are always positive\n        float max_min = 0;\n        for (int j = 0; j < QK_K/32; ++j) {\n            //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);\n            float sum_x2 = 0;\n            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];\n            float av_x = sqrtf(sum_x2/32);\n            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);\n            scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);\n            float scale = scales[j];\n            if (scale > max_scale) {\n                max_scale = scale;\n            }\n            float min = mins[j];\n            if (min > max_min) {\n                max_min = min;\n            }\n        }\n\n        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;\n        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;\n        for (int j = 0; j < QK_K/32; ++j) {\n            uint8_t ls = nearest_int(inv_scale*scales[j]);\n            uint8_t lm = nearest_int(inv_min*mins[j]);\n            ls = MIN(63, ls);\n            lm = MIN(63, lm);\n            if (j < 4) {\n                y[i].scales[j] = ls;\n                y[i].scales[j+4] = lm;\n            } else {\n                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);\n                y[i].scales[j-4] |= ((ls >> 4) << 6);\n                y[i].scales[j-0] |= ((lm >> 4) << 6);\n            }\n        }\n        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);\n        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);\n\n        uint8_t sc, m;\n        for (int j = 0; j < QK_K/32; ++j) {\n            get_scale_min_k4(j, y[i].scales, &sc, &m);\n            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;\n            if (!d) continue;\n            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;\n            for (int ii = 0; ii < 32; ++ii) {\n                int l = 
nearest_int((x[32*j + ii] + dm)/d);\n                l = MAX(0, MIN(31, l));\n                L[32*j + ii] = l;\n            }\n        }\n\n        uint8_t * GGML_RESTRICT qh = y[i].qh;\n        uint8_t * GGML_RESTRICT ql = y[i].qs;\n        memset(qh, 0, QK_K/8);\n\n        uint8_t m1 = 1, m2 = 2;\n        for (int n = 0; n < QK_K; n += 64) {\n            for (int j = 0; j < 32; ++j) {\n                int l1 = L[n + j];\n                if (l1 > 15) {\n                    l1 -= 16; qh[j] |= m1;\n                }\n                int l2 = L[n + j + 32];\n                if (l2 > 15) {\n                    l2 -= 16; qh[j] |= m2;\n                }\n                ql[j] = l1 | (l2 << 4);\n            }\n            m1 <<= 2; m2 <<= 2;\n            ql += 32;\n        }\n\n        x += QK_K;\n    }\n}\n\nvoid dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n        const uint8_t * ql = x[i].qs;\n        const uint8_t * qh = x[i].qh;\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        const float min = GGML_FP16_TO_FP32(x[i].dmin);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        for (int j = 0; j < QK_K; j += 64) {\n            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;\n            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2;\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\nstatic void quantize_row_q5_K_impl(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {\n    assert(n_per_row % QK_K == 0);\n    const int64_t nb = n_per_row / QK_K;\n\n    uint8_t L[QK_K];\n    uint8_t Laux[32];\n    uint8_t Ls[QK_K/32];\n    uint8_t Lm[QK_K/32];\n    float   mins[QK_K/32];\n    float   scales[QK_K/32];\n    float   sw[QK_K/32];\n    float   weights[32];\n\n    for (int i = 0; i < nb; i++) {\n\n        float sum_x2 = 0;\n        for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];\n        float sigma2 = 2*sum_x2/QK_K;\n        float av_x = sqrtf(sigma2);\n\n        for (int j = 0; j < QK_K/32; ++j) {\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K*i + 32*j;\n                for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);\n            } else {\n                for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);\n            }\n            float sumw = 0;\n            for (int l = 0; l < 32; ++l) sumw += weights[l];\n            sw[j] = sumw;\n\n            scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);\n        }\n\n        float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);\n        float m_block = make_qp_quants(QK_K/32, 63, mins,   Lm, sw);\n\n        for (int j = 0; j < QK_K/32; ++j) {\n            uint8_t ls = Ls[j];\n            uint8_t lm = Lm[j];\n            ls = MIN(63, ls);\n            lm = MIN(63, lm);\n            if (j < 4) {\n                y[i].scales[j] = ls;\n                y[i].scales[j+4] = lm;\n            } else {\n                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);\n                y[i].scales[j-4] |= ((ls >> 4) << 6);\n                y[i].scales[j-0] |= ((lm >> 4) << 6);\n    
        }\n        }\n        y[i].d = GGML_FP32_TO_FP16(d_block);\n        y[i].dmin = GGML_FP32_TO_FP16(m_block);\n\n        uint8_t sc, m;\n        for (int j = 0; j < QK_K/32; ++j) {\n            get_scale_min_k4(j, y[i].scales, &sc, &m);\n            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;\n            if (!d) continue;\n            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;\n            for (int ii = 0; ii < 32; ++ii) {\n                int l = nearest_int((x[32*j + ii] + dm)/d);\n                l = MAX(0, MIN(31, l));\n                L[32*j + ii] = l;\n            }\n        }\n\n        uint8_t * GGML_RESTRICT qh = y[i].qh;\n        uint8_t * GGML_RESTRICT ql = y[i].qs;\n        memset(qh, 0, QK_K/8);\n\n        uint8_t m1 = 1, m2 = 2;\n        for (int n = 0; n < QK_K; n += 64) {\n            for (int j = 0; j < 32; ++j) {\n                int l1 = L[n + j];\n                if (l1 > 15) {\n                    l1 -= 16; qh[j] |= m1;\n                }\n                int l2 = L[n + j + 32];\n                if (l2 > 15) {\n                    l2 -= 16; qh[j] |= m2;\n                }\n                ql[j] = l1 | (l2 << 4);\n            }\n            m1 <<= 2; m2 <<= 2;\n            ql += 32;\n        }\n\n        x += QK_K;\n\n    }\n}\n\nsize_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);\n    if (!quant_weights) {\n        quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);\n    }\n    else {\n        char * qrow = (char *)dst;\n        for (int64_t row = 0; row < nrow; ++row) {\n            quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);\n            src += n_per_row;\n            qrow += row_size;\n        }\n    }\n    return nrow * row_size;\n}\n\n// ====================== 6-bit (de)-quantization\n\nvoid quantize_row_q6_K_ref(const float * 
GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    int8_t L[QK_K];\n    float   scales[QK_K/16];\n\n    for (int i = 0; i < nb; i++) {\n\n        float max_scale = 0;\n        float max_abs_scale = 0;\n\n        for (int ib = 0; ib < QK_K/16; ++ib) {\n\n            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);\n            scales[ib] = scale;\n\n            const float abs_scale = fabsf(scale);\n            if (abs_scale > max_abs_scale) {\n                max_abs_scale = abs_scale;\n                max_scale = scale;\n            }\n\n        }\n\n        if (max_abs_scale < GROUP_MAX_EPS) {\n            memset(&y[i], 0, sizeof(block_q6_K));\n            y[i].d = GGML_FP32_TO_FP16(0.f);\n            x += QK_K;\n            continue;\n        }\n\n        float iscale = -128.f/max_scale;\n        y[i].d = GGML_FP32_TO_FP16(1/iscale);\n        for (int ib = 0; ib < QK_K/16; ++ib) {\n            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));\n        }\n\n        for (int j = 0; j < QK_K/16; ++j) {\n            float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];\n            if (!d) {\n                continue;\n            }\n            for (int ii = 0; ii < 16; ++ii) {\n                int l = nearest_int(x[16*j + ii]/d);\n                l = MAX(-32, MIN(31, l));\n                L[16*j + ii] = l + 32;\n            }\n        }\n\n        uint8_t * GGML_RESTRICT ql = y[i].ql;\n        uint8_t * GGML_RESTRICT qh = y[i].qh;\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) {\n                const uint8_t q1 = L[j + l +  0] & 0xF;\n                const uint8_t q2 = L[j + l + 32] & 0xF;\n                const uint8_t q3 = L[j + l + 64] & 0xF;\n                const uint8_t q4 = L[j + l + 96] & 0xF;\n                ql[l+ 0] = q1 | (q3 << 4);\n                ql[l+32] = q2 | (q4 << 4);\n                qh[l] 
= (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);\n            }\n            ql += 64;\n            qh += 32;\n        }\n\n        x += QK_K;\n    }\n}\n\nvoid dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n\n        const uint8_t * GGML_RESTRICT ql = x[i].ql;\n        const uint8_t * GGML_RESTRICT qh = x[i].qh;\n        const int8_t  * GGML_RESTRICT sc = x[i].scales;\n\n        for (int n = 0; n < QK_K; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                y[l +  0] = d * sc[is + 0] * q1;\n                y[l + 32] = d * sc[is + 2] * q2;\n                y[l + 64] = d * sc[is + 4] * q3;\n                y[l + 96] = d * sc[is + 6] * q4;\n            }\n            y  += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\nstatic void quantize_row_q6_K_impl(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {\n    assert(n_per_row % QK_K == 0);\n    const int64_t nb = n_per_row / QK_K;\n\n    int8_t L[QK_K];\n    float   scales[QK_K/16];\n    //float   weights[16];\n\n    for (int i = 0; i < nb; i++) {\n\n        //float sum_x2 = 0;\n        //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j];\n        //float sigma2 = sum_x2/QK_K;\n\n        float max_scale = 0;\n        
float max_abs_scale = 0;\n\n        for (int ib = 0; ib < QK_K/16; ++ib) {\n\n            float scale;\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K*i + 16*ib;\n                //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]);\n                //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights);\n                scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw);\n            } else {\n                scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);\n            }\n            scales[ib] = scale;\n\n            const float abs_scale = fabsf(scale);\n            if (abs_scale > max_abs_scale) {\n                max_abs_scale = abs_scale;\n                max_scale = scale;\n            }\n\n        }\n\n        if (max_abs_scale < GROUP_MAX_EPS) {\n            memset(&y[i], 0, sizeof(block_q6_K));\n            y[i].d = GGML_FP32_TO_FP16(0.f);\n            x += QK_K;\n            continue;\n        }\n\n        float iscale = -128.f/max_scale;\n        y[i].d = GGML_FP32_TO_FP16(1/iscale);\n        for (int ib = 0; ib < QK_K/16; ++ib) {\n            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));\n        }\n\n        for (int j = 0; j < QK_K/16; ++j) {\n            float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];\n            if (!d) {\n                continue;\n            }\n            for (int ii = 0; ii < 16; ++ii) {\n                int l = nearest_int(x[16*j + ii]/d);\n                l = MAX(-32, MIN(31, l));\n                L[16*j + ii] = l + 32;\n            }\n        }\n\n        uint8_t * GGML_RESTRICT ql = y[i].ql;\n        uint8_t * GGML_RESTRICT qh = y[i].qh;\n        for (int j = 0; j < QK_K; j += 128) {\n            for (int l = 0; l < 32; ++l) {\n                const uint8_t q1 = L[j + l +  0] & 0xF;\n                const uint8_t q2 = L[j + l + 32] & 0xF;\n                const uint8_t q3 = L[j + l + 
64] & 0xF;\n                const uint8_t q4 = L[j + l + 96] & 0xF;\n                ql[l+ 0] = q1 | (q3 << 4);\n                ql[l+32] = q2 | (q4 << 4);\n                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);\n            }\n            ql += 64;\n            qh += 32;\n        }\n\n        x += QK_K;\n\n    }\n}\n\nsize_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);\n    if (!quant_weights) {\n        quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);\n    }\n    else {\n        char * qrow = (char *)dst;\n        for (int64_t row = 0; row < nrow; ++row) {\n            quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);\n            src += n_per_row;\n            qrow += row_size;\n        }\n    }\n    return nrow * row_size;\n}\n\nstatic void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {\n    static_assert(QK4_0 == 32, \"QK4_0 must be 32\");\n\n    if (!quant_weights) {\n        quantize_row_q4_0_ref(x, y, n_per_row);\n        return;\n    }\n\n    float weight[QK4_0];\n    int8_t L[QK4_0];\n\n    float sum_x2 = 0;\n    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];\n    float sigma2 = sum_x2/n_per_row;\n\n    const int64_t nb = n_per_row/QK4_0;\n    for (int ib = 0; ib < nb; ++ib) {\n        const float * xb = x + QK4_0 * ib;\n        const float * qw = quant_weights + QK4_0 * ib;\n        for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);\n        float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);\n        y[ib].d = GGML_FP32_TO_FP16(d);\n        for (int j = 0; j < 16; ++j) {\n            y[ib].qs[j] = L[j] | (L[j+16] << 4);\n        }\n    }\n}\n\nsize_t 
quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    if (!quant_weights) {\n        quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);\n        return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);\n    }\n    size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);\n        src += n_per_row;\n        qrow += row_size;\n    }\n    return nrow * row_size;\n}\n\nstatic void quantize_row_q4_1_impl(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {\n    static_assert(QK4_1 == 32, \"QK4_1 must be 32\");\n\n    if (!quant_weights) {\n        quantize_row_q4_1_ref(x, y, n_per_row);\n        return;\n    }\n\n    float weight[QK4_1];\n    uint8_t L[QK4_1], Laux[QK4_1];\n\n    float sum_x2 = 0;\n    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];\n    float sigma2 = sum_x2/n_per_row;\n\n    const int64_t nb = n_per_row/QK4_1;\n    for (int ib = 0; ib < nb; ++ib) {\n        const float * xb = x + QK4_1 * ib;\n        const float * qw = quant_weights + QK4_1 * ib;\n        for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);\n        float min;\n        float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);\n        y[ib].d = GGML_FP32_TO_FP16(d);\n        y[ib].m = GGML_FP32_TO_FP16(-min);\n        for (int j = 0; j < 16; ++j) {\n            y[ib].qs[j] = L[j] | (L[j+16] << 4);\n        }\n    }\n}\n\nsize_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    if (!quant_weights) {\n        quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);\n        return nrow * 
ggml_row_size(GGML_TYPE_Q4_1, n_per_row);\n    }\n    size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);\n        src += n_per_row;\n        qrow += row_size;\n    }\n    return nrow * row_size;\n}\n\nstatic void quantize_row_q5_0_impl(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {\n    static_assert(QK5_0 == 32, \"QK5_0 must be 32\");\n\n    if (!quant_weights) {\n        quantize_row_q5_0_ref(x, y, n_per_row);\n        return;\n    }\n\n    float weight[QK5_0];\n    int8_t L[QK5_0];\n\n    float sum_x2 = 0;\n    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];\n    float sigma2 = sum_x2/n_per_row;\n\n    const int64_t nb = n_per_row/QK5_0;\n    for (int ib = 0; ib < nb; ++ib) {\n        const float * xb = x + QK5_0 * ib;\n        const float * qw = quant_weights + QK5_0 * ib;\n        for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);\n        float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight);\n        y[ib].d = GGML_FP32_TO_FP16(d);\n\n        uint32_t qh = 0;\n\n        for (int j = 0; j < 16; ++j) {\n            const uint8_t xi0 = L[j];\n            const uint8_t xi1 = L[j+16];\n            y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);\n\n            // get the 5-th bit and store it in qh at the right position\n            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);\n        }\n\n        memcpy(&y[ib].qh, &qh, sizeof(qh));\n    }\n}\n\nsize_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    if (!quant_weights) {\n        quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);\n        return nrow * ggml_row_size(GGML_TYPE_Q5_0, 
n_per_row);\n    }\n    size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);\n        src += n_per_row;\n        qrow += row_size;\n    }\n    return nrow * row_size;\n}\n\nstatic void quantize_row_q5_1_impl(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {\n    static_assert(QK5_1 == 32, \"QK5_1 must be 32\");\n\n    if (!quant_weights) {\n        quantize_row_q5_1_ref(x, y, n_per_row);\n        return;\n    }\n\n    float weight[QK5_1];\n    uint8_t L[QK5_1], Laux[QK5_1];\n\n    float sum_x2 = 0;\n    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];\n    float sigma2 = sum_x2/n_per_row;\n\n    const int64_t nb = n_per_row/QK5_1;\n    for (int ib = 0; ib < nb; ++ib) {\n        const float * xb = x + QK5_1 * ib;\n        const float * qw = quant_weights + QK5_1 * ib;\n        for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);\n        float min;\n        float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);\n        y[ib].d = GGML_FP32_TO_FP16(d);\n        y[ib].m = GGML_FP32_TO_FP16(-min);\n\n        uint32_t qh = 0;\n        for (int j = 0; j < 16; ++j) {\n            const uint8_t xi0 = L[j];\n            const uint8_t xi1 = L[j+16];\n            y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);\n            // get the 5-th bit and store it in qh at the right position\n            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);\n        }\n        memcpy(&y[ib].qh, &qh, sizeof(qh));\n    }\n}\n\nsize_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    if (!quant_weights) {\n        quantize_row_q5_1_ref(src, dst, 
(int64_t)nrow*n_per_row);\n        return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);\n    }\n    size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);\n        src += n_per_row;\n        qrow += row_size;\n    }\n    return nrow * row_size;\n}\n\nsize_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    (void)quant_weights; // not used\n    const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);\n    quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);\n    return nrow * row_size;\n}\n\nsize_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_UNUSED(quant_weights);\n    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);\n    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);\n}\n\nsize_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_UNUSED(quant_weights);\n    quantize_row_nvfp4_ref(src, dst, (int64_t)nrow*n_per_row);\n    return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);\n}\n\n// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)\n\nvoid quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int64_t i = 0; i < nb; i++) {\n        float amax = 0.0f; // absolute max\n\n        for (int j = 0; j < QK_K; j++) {\n            const float v = x[j];\n            amax = MAX(amax, fabsf(v));\n        }\n\n        const float d = amax;\n        const float id = d ? 
1.0f/d : 0.0f;\n\n        y[i].d = GGML_FP32_TO_FP16(d);\n\n        // 5 elements per byte, along 32 bytes\n        for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {\n            for (size_t m = 0; m < 32; ++m) {\n                uint8_t q = 0;\n                for (size_t n = 0; n < 5; ++n) {\n                    int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2\n                    q *= 3;\n                    q += xi;\n                }\n                // ceiling division (243 == pow(3, 5))\n                q = ((uint16_t)q * 256 + (243 - 1)) / 243;\n                y[i].qs[j + m] = q;\n            }\n            x += 5*32;\n        }\n        // along 16 bytes\n        for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {\n            for (size_t m = 0; m < 16; ++m) {\n                uint8_t q = 0;\n                for (size_t n = 0; n < 5; ++n) {\n                    int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2\n                    q *= 3;\n                    q += xi;\n                }\n                // ceiling division (243 == pow(3, 5))\n                q = ((uint16_t)q * 256 + (243 - 1)) / 243;\n                y[i].qs[j + m] = q;\n            }\n            x += 5*16;\n        }\n        // 4 elements per byte\n        for (size_t j = 0; j < sizeof(y->qh); ++j) {\n            uint8_t q = 0;\n            for (size_t m = 0; m < 4; ++m) {\n                // -1, 0, 1 -> 0, 1, 2\n                int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;\n                q *= 3;\n                q += xi;\n            }\n            // shift the first value to the most significant trit\n            q *= 3;\n            // ceiling division (243 == pow(3, 5))\n            q = ((uint16_t)q * 256 + (243 - 1)) / 243;\n            y[i].qh[j] = q;\n        }\n        x += 4*sizeof(y->qh);\n    }\n}\n\nvoid quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, 
int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int64_t i = 0; i < nb; i++) {\n        float amax = 0.0f; // absolute max\n\n        for (int j = 0; j < QK_K; j++) {\n            const float v = x[j];\n            amax = MAX(amax, fabsf(v));\n        }\n\n        const float d = amax;\n        const float id = d ? 1.0f/d : 0.0f;\n\n        y[i].d = GGML_FP32_TO_FP16(d);\n\n        for (size_t j = 0; j < sizeof(y->qs); j += 32) {\n            for (size_t m = 0; m < 32; ++m) {\n                uint8_t q = 0;\n                for (size_t n = 0; n < 4; ++n) {\n                    // -1, 0, 1 -> 0, 1, 2\n                    int xi = lroundf(x[m + n*32] * id) + 1;\n                    q += (xi & 3) << (2*n);\n                }\n                y[i].qs[j + m] = q;\n            }\n            x += 4*32;\n        }\n    }\n}\n\nsize_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    (void)quant_weights; // not used\n    const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);\n    quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row);\n    return nrow * row_size;\n}\n\nsize_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    (void)quant_weights; // not used\n    const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);\n    quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row);\n    return nrow * row_size;\n}\n\nvoid dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};\n\n    for (int64_t i = 0; i < nb; ++i) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n\n        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {\n            for (size_t n 
= 0; n < 5; ++n) {\n                for (size_t m = 0; m < 32; ++m) {\n                    uint8_t q = x[i].qs[j + m] * pow3[n];\n                    int16_t xi = ((uint16_t) q * 3) >> 8;\n                    *y++ = (float) (xi - 1) * d;\n                }\n            }\n        }\n        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {\n            for (size_t n = 0; n < 5; ++n) {\n                for (size_t m = 0; m < 16; ++m) {\n                    uint8_t q = x[i].qs[j + m] * pow3[n];\n                    int16_t xi = ((uint16_t) q * 3) >> 8;\n                    *y++ = (float) (xi - 1) * d;\n                }\n            }\n        }\n\n        for (size_t n = 0; n < 4; ++n) {\n            for (size_t j = 0; j < sizeof(x->qh); ++j) {\n                uint8_t q = x[i].qh[j] * pow3[n];\n                int16_t xi = ((uint16_t) q * 3) >> 8;\n                *y++ = (float) (xi - 1) * d;\n            }\n        }\n    }\n}\n\nvoid dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int64_t i = 0; i < nb; ++i) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n\n        for (size_t j = 0; j < sizeof(x->qs); j += 32) {\n            for (size_t l = 0; l < 4; ++l) {\n                for (size_t m = 0; m < 32; ++m) {\n                    int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;\n                    *y++ = (float) (q - 1) * d;\n                }\n            }\n        }\n    }\n}\n\n// ====================== \"True\" 2-bit (de)-quantization\n\nvoid dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    uint32_t aux32[2];\n    const uint8_t * aux8 = (const uint8_t *)aux32;\n\n    for (int i = 0; i < nb; i++) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n\n        for (int ib32 = 0; ib32 < 
QK_K/32; ++ib32) {\n            memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));\n            const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);\n                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];\n                for (int j = 0; j < 8; ++j) {\n                    y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);\n                }\n                y += 8;\n            }\n        }\n    }\n}\n\n// ====================== 2.3125 bpw (de)-quantization\n\nvoid dequantize_row_iq2_xs(const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    float db[2];\n\n    for (int i = 0; i < nb; i++) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n\n        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {\n            db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;\n            db[1] = d * (0.5f + (x[i].scales[ib32] >>  4)) * 0.25f;\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));\n                const uint8_t  signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];\n                for (int j = 0; j < 8; ++j) {\n                    y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f);\n                }\n                y += 8;\n            }\n        }\n    }\n}\n\n// ====================== 2.5625 bpw (de)-quantization\n\nvoid dequantize_row_iq2_s(const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    float db[2];\n\n    for (int i = 0; i < nb; i++) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        const uint8_t * qs = x[i].qs;\n        const uint8_t * qh = x[i].qh;\n        const uint8_t * signs = qs + QK_K/8;\n\n        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {\n            db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;\n            db[1] = d * (0.5f + (x[i].scales[ib32] >>  4)) * 0.25f;\n            for (int l = 0; l < 4; ++l) {\n                const float dl = db[l/2];\n                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));\n                for (int j = 0; j < 8; ++j) {\n                    y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? 
-1.f : 1.f);\n                }\n                y += 8;\n            }\n            qs += 4;\n            signs += 4;\n        }\n    }\n}\n\n// ====================== 3.0625 bpw (de)-quantization\n\nvoid dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    uint32_t aux32;\n\n    for (int i = 0; i < nb; i++) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        const uint8_t * qs = x[i].qs;\n        const uint8_t * scales_and_signs = qs + QK_K/4;\n\n        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {\n            memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));\n            const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];\n                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);\n                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);\n                for (int j = 0; j < 4; ++j) {\n                    y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);\n                    y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
-1.f : 1.f);\n                }\n                y += 8;\n            }\n            qs += 8;\n        }\n    }\n}\n\n// ====================== 3.3125 bpw (de)-quantization\n\nvoid dequantize_row_iq3_s(const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        const uint8_t * qs = x[i].qs;\n        const uint8_t * qh = x[i].qh;\n        const uint8_t * signs = x[i].signs;\n\n        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {\n            const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));\n            const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >>  4));\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));\n                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));\n                for (int j = 0; j < 4; ++j) {\n                    y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);\n                    y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);\n                }\n                y += 8;\n            }\n            qs += 8;\n            signs += 4;\n            for (int l = 0; l < 4; ++l) {\n                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));\n                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));\n                for (int j = 0; j < 4; ++j) {\n                    y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);\n                    y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? 
-1.f : 1.f);\n                }\n                y += 8;\n            }\n            qh += 2;\n            qs += 8;\n            signs += 4;\n        }\n    }\n}\n\n// ====================== 1.5625 bpw (de)-quantization\n\nvoid dequantize_row_iq1_s(const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        const uint8_t  * qs = x[i].qs;\n        const uint16_t * qh = x[i].qh;\n\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            const float dl = d * (2*((qh[ib] >> 12) & 7) + 1);\n            const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA;\n            for (int l = 0; l < 4; ++l) {\n                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));\n                for (int j = 0; j < 8; ++j) {\n                    y[j] = dl * (grid[j] + delta);\n                }\n                y += 8;\n            }\n            qs += 4;\n        }\n    }\n}\n\nvoid dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    float delta[4];\n    uint16_t idx[4];\n\n    iq1m_scale_t scale;\n\n    for (int i = 0; i < nb; i++) {\n\n        const uint16_t * sc = (const uint16_t *)x[i].scales;\n        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n        const float d = GGML_FP16_TO_FP32(scale.f16);\n\n        const uint8_t * qs = x[i].qs;\n        const uint8_t * qh = x[i].qh;\n\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);\n            const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);\n\n            idx[0] = qs[0] | ((qh[0] << 8) & 0x700);\n            idx[1] = qs[1] | ((qh[0] << 4) & 
0x700);\n            idx[2] = qs[2] | ((qh[1] << 8) & 0x700);\n            idx[3] = qs[3] | ((qh[1] << 4) & 0x700);\n            delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;\n            delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;\n            delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;\n            delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;\n            for (int l = 0; l < 2; ++l) {\n                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);\n                for (int j = 0; j < 8; ++j) {\n                    y[j] = dl1 * (grid[j] + delta[l]);\n                }\n                y += 8;\n            }\n            for (int l = 2; l < 4; ++l) {\n                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);\n                for (int j = 0; j < 8; ++j) {\n                    y[j] = dl2 * (grid[j] + delta[l]);\n                }\n                y += 8;\n            }\n            qs += 4;\n            qh += 2;\n        }\n    }\n}\n\nvoid dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK4_NL == 0);\n    const int64_t nb = k / QK4_NL;\n\n    for (int i = 0; i < nb; i++) {\n\n        const uint8_t * qs = x[i].qs;\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n        for (int j = 0; j < QK4_NL/2; ++j) {\n            y[j+       0] = d * kvalues_iq4nl[qs[j] & 0xf];\n            y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >>  4];\n        }\n        y  += QK4_NL;\n        qs += QK4_NL/2;\n    }\n}\n\nvoid dequantize_row_iq4_xs(const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n\n        const uint8_t * qs = x[i].qs;\n\n        const float d = GGML_FP16_TO_FP32(x[i].d);\n\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 
3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];\n                y[j+16] = dl * kvalues_iq4nl[qs[j] >>  4];\n            }\n            y  += 32;\n            qs += 16;\n        }\n    }\n}\n\n//===================================== Q8_K ==============================================\n\nvoid quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n\n        float max = 0;\n        float amax = 0;\n        for (int j = 0; j < QK_K; ++j) {\n            float ax = fabsf(x[j]);\n            if (ax > amax) {\n                amax = ax; max = x[j];\n            }\n        }\n        if (!amax) {\n            y[i].d = 0;\n            memset(y[i].qs, 0, QK_K);\n            x += QK_K;\n            continue;\n        }\n        //const float iscale = -128.f/max;\n        // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward\n        const float iscale = -127.f/max;\n        for (int j = 0; j < QK_K; ++j) {\n            int v = nearest_int(iscale*x[j]);\n            y[i].qs[j] = MIN(127, v);\n        }\n        for (int j = 0; j < QK_K/16; ++j) {\n            int sum = 0;\n            for (int ii = 0; ii < 16; ++ii) {\n                sum += y[i].qs[j*16 + ii];\n            }\n            y[i].bsums[j] = sum;\n        }\n        y[i].d = 1/iscale;\n        x += QK_K;\n    }\n}\n\nvoid dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    const int64_t nb = k / QK_K;\n\n    for (int i = 0; i < nb; i++) {\n        for (int j = 0; j < QK_K; ++j) {\n            *y++ = x[i].d * x[i].qs[j];\n        }\n    }\n}\n\n// ================================ IQ2 quantization =============================================\n\ntypedef struct {\n  
  uint64_t * grid;\n    int      * map;\n    uint16_t * neighbours;\n} iq2_entry_t;\n\nstatic iq2_entry_t iq2_data[4] = {\n    {NULL, NULL, NULL},\n    {NULL, NULL, NULL},\n    {NULL, NULL, NULL},\n    {NULL, NULL, NULL},\n};\n\nstatic inline int iq2_data_index(enum ggml_type type) {\n    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);\n    return type == GGML_TYPE_IQ2_XXS ? 0 :\n           type == GGML_TYPE_IQ2_XS  ? 1 :\n           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3;\n}\n\nstatic inline int iq2_grid_size(enum ggml_type type) {\n    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);\n    return type == GGML_TYPE_IQ2_XXS ? 256 :\n           type == GGML_TYPE_IQ2_XS  ? 512 :\n           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024;\n}\n\nstatic int iq2_compare_func(const void * left, const void * right) {\n    const int * l = (const int *)left;\n    const int * r = (const int *)right;\n    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 
1 : 0;\n}\n\nvoid iq2xs_init_impl(enum ggml_type type) {\n    const int gindex = iq2_data_index(type);\n    const int grid_size = iq2_grid_size(type);\n    if (iq2_data[gindex].grid) {\n        return;\n    }\n    static const uint16_t kgrid_2bit_256[256] = {\n            0,     2,     5,     8,    10,    17,    20,    32,    34,    40,    42,    65,    68,    80,    88,    97,\n          100,   128,   130,   138,   162,   257,   260,   272,   277,   320,   388,   408,   512,   514,   546,   642,\n         1025,  1028,  1040,  1057,  1060,  1088,  1090,  1096,  1120,  1153,  1156,  1168,  1188,  1280,  1282,  1288,\n         1312,  1350,  1385,  1408,  1425,  1545,  1552,  1600,  1668,  1700,  2048,  2053,  2056,  2068,  2088,  2113,\n         2116,  2128,  2130,  2184,  2308,  2368,  2562,  2580,  4097,  4100,  4112,  4129,  4160,  4192,  4228,  4240,\n         4245,  4352,  4360,  4384,  4432,  4442,  4480,  4644,  4677,  5120,  5128,  5152,  5157,  5193,  5248,  5400,\n         5474,  5632,  5654,  6145,  6148,  6160,  6208,  6273,  6400,  6405,  6560,  6737,  8192,  8194,  8202,  8260,\n         8289,  8320,  8322,  8489,  8520,  8704,  8706,  9217,  9220,  9232,  9280,  9302,  9472,  9537,  9572,  9872,\n        10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,\n        16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,\n        17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,\n        20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,\n        22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,\n        25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,\n        33888, 34048, 34118, 34196, 34313, 
34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,\n        37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,\n    };\n    static const uint16_t kgrid_2bit_512[512] = {\n            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,\n           73,    80,    82,    85,    88,    97,   100,   128,   130,   133,   136,   145,   148,   153,   160,   257,\n          260,   262,   265,   272,   274,   277,   280,   282,   289,   292,   320,   322,   325,   328,   337,   340,\n          352,   360,   385,   388,   400,   512,   514,   517,   520,   529,   532,   544,   577,   580,   592,   597,\n          640,   650,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1088,  1090,  1093,  1096,\n         1105,  1108,  1110,  1120,  1153,  1156,  1168,  1280,  1282,  1285,  1288,  1297,  1300,  1312,  1345,  1348,\n         1360,  1377,  1408,  1537,  1540,  1552,  1574,  1600,  1602,  1668,  2048,  2050,  2053,  2056,  2058,  2065,\n         2068,  2080,  2085,  2113,  2116,  2128,  2136,  2176,  2208,  2218,  2305,  2308,  2320,  2368,  2433,  2441,\n         2560,  2592,  2600,  2710,  2720,  4097,  4100,  4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4160,\n         4162,  4165,  4168,  4177,  4180,  4192,  4202,  4225,  4228,  4240,  4352,  4354,  4357,  4360,  4369,  4372,\n         4384,  4417,  4420,  4432,  4480,  4500,  4502,  4609,  4612,  4614,  4624,  4672,  4704,  5120,  5122,  5125,\n         5128,  5137,  5140,  5152,  5185,  5188,  5193,  5200,  5220,  5248,  5377,  5380,  5392,  5440,  5632,  5652,\n         5705,  6145,  6148,  6160,  6162,  6208,  6228,  6278,  6400,  6405,  6502,  6737,  6825,  8192,  8194,  8197,\n         8200,  8202,  8209,  8212,  8224,  8257,  8260,  8272,  8320,  8352,  8449,  8452,  8464,  8512,  8520,  8549,\n         8704,  8738,  8832,  8872,  9217,  
9220,  9232,  9257,  9280,  9472,  9537,  9554,  9625,  9729,  9754,  9894,\n        10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,\n        16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,\n        16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,\n        16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,\n        17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,\n        18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,\n        20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,\n        21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,\n        22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,\n        24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,\n        32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,\n        33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,\n        33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,\n        35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,\n        37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,\n        40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 
41632, 42048,\n        42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,\n    };\n    static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = {\n            0,     2,     5,     8,    10,    17,    21,    32,    34,    40,    42,    69,    81,    84,    86,   101,\n          128,   130,   136,   138,   149,   160,   162,   168,   170,   260,   261,   273,   276,   278,   281,   282,\n          293,   321,   326,   329,   338,   341,   346,   353,   356,   358,   360,   389,   401,   404,   406,   421,\n          512,   514,   520,   522,   533,   544,   546,   552,   554,   581,   593,   601,   612,   617,   640,   642,\n          648,   650,   657,   661,   665,   672,   674,   680,   682,  1041,  1044,  1046,  1061,  1089,  1097,  1109,\n         1114,  1124,  1125,  1169,  1177,  1189,  1281,  1284,  1285,  1286,  1301,  1304,  1306,  1321,  1344,  1349,\n         1354,  1360,  1361,  1364,  1365,  1366,  1369,  1376,  1378,  1381,  1384,  1386,  1409,  1425,  1429,  1432,\n         1434,  1441,  1444,  1445,  1446,  1449,  1556,  1561,  1601,  1604,  1616,  1618,  1621,  1624,  1632,  1633,\n         1638,  1641,  1669,  1681,  1684,  1689,  2048,  2050,  2056,  2058,  2069,  2080,  2082,  2088,  2090,  2117,\n         2129,  2134,  2149,  2176,  2178,  2184,  2186,  2197,  2208,  2210,  2216,  2218,  2309,  2321,  2324,  2329,\n         2340,  2341,  2369,  2384,  2385,  2389,  2401,  2404,  2409,  2449,  2452,  2454,  2457,  2469,  2560,  2562,\n         2568,  2570,  2581,  2592,  2594,  2600,  2602,  2629,  2641,  2649,  2657,  2661,  2688,  2690,  2693,  2696,\n         2698,  2709,  2720,  2722,  2728,  2730,  4112,  4113,  4116,  4121,  4132,  4133,  4161,  4164,  4176,  4181,\n         4184,  4193,  4196,  4197,  4201,  4241,  4244,  4246,  4257,  4261,  4353,  4356,  4358,  4361,  4368,  4370,\n         4373,  4376,  4385,  4388,  4393,  4421,  4426,  4432,  4433,  4434,  4436,  4437,  4438,  
4441,  4448,  4453,\n         4484,  4498,  4501,  4513,  4516,  4625,  4628,  4630,  4645,  4672,  4678,  4681,  4690,  4693,  4696,  4698,\n         4708,  4710,  4741,  4753,  4756,  4758,  4773,  5121,  5126,  5129,  5140,  5141,  5144,  5145,  5153,  5158,\n         5185,  5189,  5190,  5192,  5194,  5201,  5204,  5205,  5206,  5209,  5218,  5221,  5224,  5252,  5257,  5264,\n         5268,  5269,  5272,  5273,  5274,  5281,  5284,  5285,  5289,  5378,  5381,  5386,  5393,  5396,  5397,  5398,\n         5401,  5408,  5410,  5413,  5416,  5418,  5441,  5444,  5445,  5446,  5457,  5458,  5460,  5461,  5462,  5465,\n         5466,  5473,  5476,  5477,  5478,  5481,  5504,  5506,  5508,  5509,  5512,  5514,  5520,  5521,  5524,  5525,\n         5526,  5529,  5530,  5536,  5538,  5541,  5633,  5636,  5637,  5638,  5653,  5654,  5656,  5658,  5665,  5670,\n         5696,  5698,  5700,  5701,  5704,  5706,  5713,  5717,  5718,  5720,  5721,  5729,  5732,  5733,  5736,  5737,\n         5738,  5766,  5770,  5778,  5781,  5796,  5801,  6161,  6166,  6181,  6209,  6212,  6214,  6217,  6224,  6229,\n         6232,  6234,  6240,  6241,  6244,  6246,  6249,  6277,  6289,  6292,  6309,  6416,  6418,  6421,  6426,  6433,\n         6437,  6466,  6468,  6469,  6472,  6481,  6484,  6485,  6486,  6489,  6490,  6496,  6501,  6506,  6537,  6545,\n         6546,  6549,  6552,  6561,  6566,  6569,  6665,  6678,  6692,  6694,  6724,  6726,  6729,  6736,  6738,  6741,\n         6744,  6753,  6758,  6761,  6789,  6801,  6806,  6810,  8192,  8194,  8200,  8202,  8213,  8224,  8226,  8229,\n         8232,  8234,  8261,  8273,  8281,  8289,  8293,  8320,  8322,  8328,  8330,  8341,  8352,  8354,  8357,  8360,\n         8362,  8453,  8465,  8468,  8473,  8485,  8514,  8516,  8521,  8533,  8536,  8538,  8545,  8548,  8549,  8550,\n         8581,  8592,  8598,  8601,  8613,  8705,  8712,  8714,  8721,  8725,  8736,  8738,  8744,  8746,  8773,  8785,\n         8790,  8793,  8805,  8833,  8840, 
 8842,  8849,  8853,  8864,  8866,  8872,  8874,  9221,  9236,  9238,  9241,\n         9253,  9284,  9285,  9286,  9289,  9298,  9301,  9304,  9306,  9318,  9349,  9361,  9364,  9369,  9377,  9381,\n         9481,  9493,  9505,  9513,  9536,  9541,  9544,  9553,  9556,  9557,  9561,  9570,  9573,  9576,  9609,  9616,\n         9620,  9621,  9624,  9626,  9633,  9636,  9638,  9641,  9733,  9744,  9746,  9753,  9765,  9793,  9801,  9813,\n         9824,  9825,  9833,  9860,  9862,  9872,  9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282,\n        10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521,\n        10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752,\n        10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890,\n        10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484,\n        16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673,\n        16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772,\n        16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986,\n        16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494,\n        17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666,\n        17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744,\n        17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809,\n        17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 
17946, 17953,\n        17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049,\n        18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517,\n        18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704,\n        18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784,\n        18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012,\n        19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501,\n        20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617,\n        20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761,\n        20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822,\n        20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896,\n        20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078,\n        21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526,\n        21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589,\n        21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653,\n        21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780,\n        21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832,\n        21833, 21840, 21841, 21842, 21844, 
21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864,\n        21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924,\n        21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048,\n        22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098,\n        22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154,\n        22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561,\n        22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665,\n        22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821,\n        22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884,\n        22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061,\n        23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144,\n        23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656,\n        24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850,\n        24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970,\n        24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221,\n        25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674,\n        25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 
25748, 25749,\n        25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926,\n        25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001,\n        26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176,\n        26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250,\n        26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721,\n        26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949,\n        26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044,\n        27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270,\n        27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852,\n        32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046,\n        33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161,\n        33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369,\n        33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877,\n        33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117,\n        34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192,\n        34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394,\n        34401, 34406, 34410, 34437, 34449, 
34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858,\n        34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986,\n        35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172,\n        35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412,\n        35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901,\n        36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124,\n        37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205,\n        37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396,\n        37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889,\n        37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985,\n        37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161,\n        38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226,\n        38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290,\n        38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432,\n        38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538,\n        38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998,\n        39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 
39192, 39194,\n        39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269,\n        39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497,\n        39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994,\n        41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130,\n        41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349,\n        41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561,\n        41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068,\n        42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278,\n        42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386,\n        42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592,\n        42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048,\n        43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284,\n        43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530,\n        43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690,\n    };\n    static const uint16_t kgrid_2bit_1024[1024] = {\n            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,\n           73,    80,    82,    85,    88,    97,   100,   102,   105,   128,   130,   133,   136,   145,   
148,   160,\n          165,   170,   257,   260,   262,   265,   272,   274,   277,   280,   289,   292,   320,   322,   325,   328,\n          337,   340,   342,   345,   352,   357,   360,   385,   388,   400,   402,   405,   417,   420,   512,   514,\n          517,   520,   529,   532,   544,   554,   577,   580,   582,   585,   592,   597,   640,   645,   650,   660,\n          674,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1062,  1065,  1088,  1090,  1093,\n         1096,  1098,  1105,  1108,  1110,  1113,  1120,  1122,  1125,  1153,  1156,  1158,  1161,  1168,  1173,  1176,\n         1185,  1188,  1280,  1282,  1285,  1288,  1290,  1297,  1300,  1302,  1305,  1312,  1317,  1320,  1345,  1348,\n         1350,  1353,  1360,  1362,  1365,  1368,  1377,  1380,  1408,  1410,  1413,  1416,  1425,  1428,  1440,  1537,\n         1540,  1542,  1545,  1552,  1557,  1600,  1605,  1608,  1617,  1620,  1632,  1665,  1668,  1680,  2048,  2050,\n         2053,  2056,  2065,  2068,  2070,  2073,  2080,  2085,  2090,  2113,  2116,  2118,  2121,  2128,  2130,  2133,\n         2136,  2145,  2148,  2176,  2181,  2196,  2218,  2305,  2308,  2320,  2322,  2325,  2328,  2337,  2368,  2373,\n         2376,  2385,  2388,  2400,  2433,  2448,  2560,  2577,  2580,  2594,  2600,  2602,  2640,  2713,  4097,  4100,\n         4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4134,  4160,  4162,  4165,  4168,  4177,  4180,  4182,\n         4185,  4192,  4194,  4197,  4200,  4225,  4228,  4230,  4240,  4245,  4248,  4257,  4260,  4352,  4354,  4357,\n         4360,  4362,  4369,  4372,  4374,  4377,  4384,  4386,  4389,  4392,  4417,  4420,  4422,  4425,  4432,  4434,\n         4437,  4440,  4449,  4452,  4480,  4482,  4485,  4488,  4497,  4500,  4609,  4612,  4617,  4624,  4629,  4641,\n         4644,  4672,  4677,  4689,  4692,  4737,  4740,  4752,  5120,  5122,  5125,  5128,  5137,  5140,  5142,  5145,\n         5152,  5157,  5160,  5185,  5188,  5190,  
5193,  5200,  5202,  5205,  5208,  5217,  5220,  5248,  5250,  5253,\n         5256,  5265,  5268,  5280,  5377,  5380,  5382,  5385,  5392,  5394,  5397,  5400,  5409,  5412,  5440,  5442,\n         5445,  5448,  5457,  5460,  5472,  5505,  5508,  5520,  5632,  5637,  5640,  5649,  5652,  5664,  5697,  5700,\n         5712,  5760,  5802,  6145,  6148,  6150,  6153,  6160,  6165,  6168,  6177,  6208,  6210,  6213,  6216,  6225,\n         6228,  6240,  6273,  6276,  6400,  6402,  6405,  6408,  6417,  6420,  6432,  6465,  6468,  6480,  6505,  6562,\n         6660,  6672,  6720,  6742,  8192,  8194,  8197,  8200,  8209,  8212,  8214,  8217,  8224,  8229,  8234,  8257,\n         8260,  8272,  8274,  8277,  8292,  8320,  8330,  8340,  8362,  8449,  8452,  8464,  8466,  8469,  8481,  8512,\n         8514,  8517,  8529,  8532,  8544,  8577,  8580,  8592,  8704,  8714,  8738,  8744,  8746,  8772,  8784,  8840,\n         8842,  8872,  9217,  9220,  9222,  9225,  9232,  9237,  9240,  9249,  9252,  9280,  9282,  9285,  9288,  9297,\n         9300,  9312,  9345,  9348,  9360,  9472,  9477,  9480,  9489,  9492,  9504,  9537,  9540,  9552,  9574,  9600,\n         9729,  9732,  9744,  9792,  9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500,\n        10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410,\n        16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513,\n        16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674,\n        16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785,\n        16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025,\n        17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 
17476,\n        17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665,\n        17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760,\n        17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085,\n        18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528,\n        18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948,\n        18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548,\n        20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740,\n        20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865,\n        20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510,\n        21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636,\n        21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054,\n        22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800,\n        22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645,\n        24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912,\n        24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680,\n        25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880,\n        26922, 27202, 27297, 32768, 32770, 32773, 
32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850,\n        32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060,\n        33088, 33090, 33093, 33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345,\n        33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873,\n        33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176,\n        34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076,\n        35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928,\n        36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200,\n        37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968,\n        38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976,\n        39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130,\n        41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121,\n        42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690,\n    };\n\n    const int kmap_size = 43692;\n    //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;\n    const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;\n    const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :\n                             type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 :\n                             type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 
kgrid_1bit_2048 : kgrid_2bit_1024;\n    uint64_t * kgrid_q2xs;\n    int      * kmap_q2xs;\n    uint16_t * kneighbors_q2xs;\n\n    //printf(\"================================================================= %s(grid_size = %d)\\n\", __func__, grid_size);\n    uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));\n    for (int k = 0; k < grid_size; ++k) {\n        int8_t * pos = (int8_t *)(the_grid + k);\n        for (int i = 0; i < 8; ++i) {\n            int l = (kgrid[k] >> 2*i) & 0x3;\n            pos[i] = 2*l + 1;\n        }\n    }\n    kgrid_q2xs = the_grid;\n    iq2_data[gindex].grid = the_grid;\n    kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));\n    iq2_data[gindex].map = kmap_q2xs;\n    for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;\n    uint64_t aux64;\n    uint8_t * aux8 = (uint8_t *)&aux64;\n    for (int i = 0; i < grid_size; ++i) {\n        aux64 = kgrid_q2xs[i];\n        uint16_t index = 0;\n        for (int k=0; k<8; ++k) {\n            uint16_t q = (aux8[k] - 1)/2;\n            index |= (q << 2*k);\n        }\n        kmap_q2xs[index] = i;\n    }\n    int8_t pos[8];\n    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));\n    int num_neighbors = 0, num_not_in_map = 0;\n    for (int i = 0; i < kmap_size; ++i) {\n        if (kmap_q2xs[i] >= 0) continue;\n        ++num_not_in_map;\n        for (int k = 0; k < 8; ++k) {\n            int l = (i >> 2*k) & 0x3;\n            pos[k] = 2*l + 1;\n        }\n        for (int j = 0; j < grid_size; ++j) {\n            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);\n            int d2 = 0;\n            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);\n            dist2[2*j+0] = d2;\n            dist2[2*j+1] = j;\n        }\n        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);\n        int n = 0; int d2 = dist2[0];\n        int nhave = 1;\n        for (int j = 0; j < grid_size; ++j) {\n            if (dist2[2*j] > d2) {\n                if (nhave 
== nwant) break;\n                d2 = dist2[2*j];\n                ++nhave;\n            }\n            ++n;\n        }\n        num_neighbors += n;\n    }\n    //printf(\"%s: %d neighbours in total\\n\", __func__, num_neighbors);\n    kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));\n    iq2_data[gindex].neighbours = kneighbors_q2xs;\n    int counter = 0;\n    for (int i = 0; i < kmap_size; ++i) {\n        if (kmap_q2xs[i] >= 0) continue;\n        for (int k = 0; k < 8; ++k) {\n            int l = (i >> 2*k) & 0x3;\n            pos[k] = 2*l + 1;\n        }\n        for (int j = 0; j < grid_size; ++j) {\n            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);\n            int d2 = 0;\n            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);\n            dist2[2*j+0] = d2;\n            dist2[2*j+1] = j;\n        }\n        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);\n        kmap_q2xs[i] = -(counter + 1);\n        int d2 = dist2[0];\n        uint16_t * start = &kneighbors_q2xs[counter++];\n        int n = 0, nhave = 1;\n        for (int j = 0; j < grid_size; ++j) {\n            if (dist2[2*j] > d2) {\n                if (nhave == nwant) break;\n                d2 = dist2[2*j];\n                ++nhave;\n            }\n            kneighbors_q2xs[counter++] = dist2[2*j+1];\n            ++n;\n        }\n        *start = n;\n    }\n    free(dist2);\n}\n\nvoid iq2xs_free_impl(enum ggml_type type) {\n    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);\n    const int gindex = iq2_data_index(type);\n    if (iq2_data[gindex].grid) {\n        free(iq2_data[gindex].grid);       iq2_data[gindex].grid = NULL;\n        free(iq2_data[gindex].map);        iq2_data[gindex].map  = NULL;\n        free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;\n    }\n}\n\nstatic int 
iq2_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,\n        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, int8_t * GGML_RESTRICT L) {\n    int num_neighbors = neighbours[0];\n    GGML_ASSERT(num_neighbors > 0);\n    float best_d2 = FLT_MAX;\n    int grid_index = -1;\n    for (int j = 1; j <= num_neighbors; ++j) {\n        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);\n        float d2 = 0;\n        for (int i = 0; i < 8; ++i) {\n            float q = pg[i];\n            float diff = scale*q - xval[i];\n            d2 += weight[i]*diff*diff;\n        }\n        if (d2 < best_d2) {\n            best_d2 = d2; grid_index = neighbours[j];\n        }\n    }\n    GGML_ASSERT(grid_index >= 0);\n    const int8_t * pg = (const int8_t *)(grid + grid_index);\n    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;\n    return grid_index;\n}\n\nstatic void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {\n\n    const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);\n\n    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;\n    const int      * kmap_q2xs       = iq2_data[gindex].map;\n    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;\n\n    GGML_ASSERT(quant_weights   && \"missing quantization weights\");\n    GGML_ASSERT(kgrid_q2xs      && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kmap_q2xs       && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kneighbors_q2xs && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(n%QK_K == 0);\n\n    const int kMaxQ = 3;\n\n    const int64_t nbl = n/QK_K;\n\n    block_iq2_xxs * y = vy;\n\n    float scales[QK_K/32];\n    float weight[32];\n    float xval[32];\n    int8_t L[32];\n    int8_t Laux[32];\n    float  waux[32];\n    uint8_t block_signs[4];\n    uint32_t q2[2*(QK_K/32)];\n\n    for 
(int ibl = 0; ibl < nbl; ++ibl) {\n\n        y[ibl].d = GGML_FP32_TO_FP16(0.f);\n        memset(q2, 0, QK_K/4);\n\n        float max_scale = 0;\n\n        const float * xbl = x + QK_K*ibl;\n        float sumx2 = 0;\n        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];\n        float sigma2 = sumx2/QK_K;\n\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            const float * xb = xbl + 32*ib;\n            const float * qw = quant_weights + QK_K*ibl + 32*ib;\n            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);\n            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);\n            for (int k = 0; k < 4; ++k) {\n                int nflip = 0;\n                uint8_t s = 0;\n                for (int i = 0; i < 8; ++i) {\n                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];\n                    else {\n                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);\n                    }\n                }\n                if (nflip%2) {\n                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];\n                    for (int i = 1; i < 8; ++i) {\n                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];\n                        if (ax < min) {\n                            min = ax; imin = i;\n                        }\n                    }\n                    xval[8*k+imin] = -xval[8*k+imin];\n                    s ^= (1 << imin);\n                }\n                block_signs[k] = s & 127;\n            }\n            float max = xval[0];\n            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);\n            if (max < GROUP_MAX_EPS) {\n                scales[ib] = 0;\n                memset(L, 0, 32);\n                continue;\n            }\n            float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);\n            float eff_max = scale*kMaxQ;\n            if (eff_max <= 0) {\n                
scales[ib] = 0;\n                memset(L, 0, 32);\n                continue;\n            }\n            float best = 0;\n            for (int is = -6; is <= 6; ++is) {\n                float id = (2*kMaxQ-1+is*0.1f)/eff_max;\n                float this_scale = 1/id;\n                for (int k = 0; k < 4; ++k) {\n                    for (int i = 0; i < 8; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));\n                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));\n                    }\n                    uint16_t u = 0;\n                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);\n                    int grid_index = kmap_q2xs[u];\n                    if (grid_index < 0) {\n                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;\n                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);\n                    }\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < 32; ++i) {\n                    float w = weight[i];\n                    float q = 2*Laux[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {\n                    scale = sumqx/sumq2; best = scale*sumqx;\n                    memcpy(L, Laux, 32);\n                }\n            }\n            if (scale > 0) {\n                float id = 1/scale;\n                for (int k = 0; k < 4; ++k) {\n                    uint16_t u = 0;\n                    for (int i = 0; i < 8; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));\n                        l = MAX(0, MIN(kMaxQ-1, l));\n                        u |= (l << 2*i);\n                    }\n                    int grid_index = kmap_q2xs[u];\n                    if (grid_index < 0) {\n                        const 
uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;\n                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);\n                    }\n                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);\n                    for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < 32; ++i) {\n                    float w = weight[i];\n                    float q = 2*L[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0) scale = sumqx/sumq2;\n            }\n            if (scale < 0) {\n                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)\n                // and correspondingly flip quant signs.\n                scale = -scale;\n                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;\n            }\n            for (int k = 0; k < 4; ++k) {\n                uint16_t u = 0;\n                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);\n                int grid_index = kmap_q2xs[u];\n                if (grid_index < 0) {\n                    printf(\"Oops: found point %u not on grid:\", u);\n                    for (int i = 0; i < 8; ++i) printf(\" %d\", L[8*k+i]);\n                    printf(\"\\n\");\n                    GGML_ABORT(\"fatal error\");\n                }\n                q2[2*ib+0] |= ((uint32_t) grid_index << 8*k);\n                q2[2*ib+1] |= (block_signs[k] << 7*k);\n            }\n            GGML_ASSERT(scale >= 0);\n            scales[ib] = scale;\n            max_scale = MAX(max_scale, scale);\n        }\n\n        if (!max_scale) {\n            memset(y[ibl].qs, 0, QK_K/4);\n            continue;\n        }\n\n        float d = max_scale/31;\n        y[ibl].d 
= GGML_FP32_TO_FP16(d);\n        float id = 1/d;\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            int l = nearest_int(0.5f*(id*scales[ib]-1));\n            l = MAX(0, MIN(15, l));\n            q2[2*ib+1] |= ((uint32_t)l << 28);\n        }\n        memcpy(y[ibl].qs, q2, QK_K/4);\n    }\n}\n\nstatic void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {\n\n    const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);\n\n    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;\n    const int      * kmap_q2xs       = iq2_data[gindex].map;\n    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;\n\n    GGML_ASSERT(quant_weights   && \"missing quantization weights\");\n    GGML_ASSERT(kmap_q2xs       && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kgrid_q2xs      && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kneighbors_q2xs && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(n%QK_K == 0);\n\n    const int kMaxQ = 3;\n\n    const int64_t nbl = n/QK_K;\n\n    block_iq2_xs * y = vy;\n\n    float scales[QK_K/16];\n    float weight[16];\n    float xval[16];\n    int8_t L[16];\n    int8_t Laux[16];\n    float  waux[16];\n    bool   is_on_grid[2];\n    bool   is_on_grid_aux[2];\n    uint8_t block_signs[2];\n    uint16_t q2[2*(QK_K/16)];\n\n    for (int ibl = 0; ibl < nbl; ++ibl) {\n\n        y[ibl].d = GGML_FP32_TO_FP16(0.f);\n        memset(q2, 0, QK_K/4);\n        memset(y[ibl].scales, 0, QK_K/32);\n\n        float max_scale = 0;\n\n        const float * xbl = x + QK_K*ibl;\n        float sumx2 = 0;\n        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];\n        float sigma2 = sumx2/QK_K;\n\n        for (int ib = 0; ib < QK_K/16; ++ib) {\n            const float * xb = xbl + 16*ib;\n            const float * qw = quant_weights + QK_K*ibl + 16*ib;\n            for (int i = 0; i < 16; ++i) weight[i] = qw[i] * 
sqrtf(sigma2 + xb[i]*xb[i]);\n            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);\n            for (int k = 0; k < 2; ++k) {\n                int nflip = 0;\n                uint8_t s = 0;\n                for (int i = 0; i < 8; ++i) {\n                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];\n                    else {\n                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);\n                    }\n                }\n                if (nflip%2) {\n                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];\n                    for (int i = 1; i < 8; ++i) {\n                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];\n                        if (ax < min) {\n                            min = ax; imin = i;\n                        }\n                    }\n                    xval[8*k+imin] = -xval[8*k+imin];\n                    s ^= (1 << imin);\n                }\n                block_signs[k] = s & 127;\n            }\n            float max = xval[0];\n            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);\n            memset(L, 0, 16);\n            if (max < GROUP_MAX_EPS) {\n                scales[ib] = 0;\n                continue;\n            }\n            float best = 0;\n            float scale = max/(2*kMaxQ-1);\n            is_on_grid[0] = is_on_grid[1] = true;\n            for (int is = -9; is <= 9; ++is) {\n                float id = (2*kMaxQ-1+is*0.1f)/max;\n                float this_scale = 1/id;\n                for (int k = 0; k < 2; ++k) {\n                    for (int i = 0; i < 8; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));\n                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));\n                    }\n                    uint16_t u = 0;\n                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);\n                    int grid_index = kmap_q2xs[u];\n                    
is_on_grid_aux[k] = true;\n                    if (grid_index < 0) {\n                        is_on_grid_aux[k] = false;\n                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;\n                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);\n                    }\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < 16; ++i) {\n                    float w = weight[i];\n                    float q = 2*Laux[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {\n                    scale = sumqx/sumq2; best = scale*sumqx;\n                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];\n                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];\n                }\n            }\n            int n_not_ongrid = 0;\n            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;\n            if (n_not_ongrid > 0 && scale > 0) {\n                float id = 1/scale;\n                for (int k = 0; k < 2; ++k) {\n                    if (is_on_grid[k]) continue;\n                    uint16_t u = 0;\n                    for (int i = 0; i < 8; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));\n                        l = MAX(0, MIN(kMaxQ-1, l));\n                        u |= (l << 2*i);\n                        L[8*k + i] = l;\n                    }\n                    int grid_index = kmap_q2xs[u];\n                    if (grid_index < 0) {\n                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;\n                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);\n                    }\n                }\n                float sumqx = 0, sumq2 = 0;\n      
          for (int i = 0; i < 16; ++i) {\n                    float w = weight[i];\n                    float q = 2*L[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0) scale = sumqx/sumq2;\n            }\n            if (scale < 0) {\n                scale = -scale;\n                for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;\n            }\n            for (int k = 0; k < 2; ++k) {\n                uint16_t u = 0;\n                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);\n                int grid_index = kmap_q2xs[u];\n                if (grid_index < 0) {\n                    printf(\"Oops: found point %u not on grid:\", u);\n                    for (int i = 0; i < 8; ++i) printf(\" %d\", L[8*k+i]);\n                    printf(\"\\n\");\n                    GGML_ABORT(\"fatal error\");\n                }\n                q2[2*ib+k] = grid_index | (block_signs[k] << 9);\n            }\n            GGML_ASSERT(scale >= 0);\n            scales[ib] = scale;\n            max_scale = MAX(max_scale, scale);\n        }\n\n        if (!max_scale) {\n            memset(y[ibl].qs, 0, QK_K/4);\n            continue;\n        }\n\n        float d = max_scale/31;\n        y[ibl].d = GGML_FP32_TO_FP16(d);\n        float id = 1/d;\n        for (int ib = 0; ib < QK_K/16; ++ib) {\n            int l = nearest_int(0.5f*(id*scales[ib]-1));\n            l = MAX(0, MIN(15, l));\n            if (ib%2 == 0) y[ibl].scales[ib/2] = l;\n            else y[ibl].scales[ib/2] |= (l << 4);\n        }\n        memcpy(y[ibl].qs, q2, QK_K/4);\n\n    }\n}\n\nsize_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK_K == 0);\n    int64_t nblock = n_per_row/QK_K;\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        
quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq2_xxs);\n    }\n    return nrow * nblock * sizeof(block_iq2_xxs);\n}\n\nsize_t quantize_iq2_xs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK_K == 0);\n    int64_t nblock = n_per_row/QK_K;\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq2_xs);\n    }\n    return nrow * nblock * sizeof(block_iq2_xs);\n}\n\n//\n// ============================================= 3-bit using D4 lattice\n//\n\ntypedef struct {\n    uint32_t * grid;\n    int      * map;\n    uint16_t * neighbours;\n} iq3_entry_t;\n\nstatic iq3_entry_t iq3_data[2] = {\n    {NULL, NULL, NULL},\n    {NULL, NULL, NULL},\n};\n\nstatic inline int iq3_data_index(int grid_size) {\n    (void)grid_size;\n    GGML_ASSERT(grid_size == 256 || grid_size == 512);\n    return grid_size == 256 ? 0 : 1;\n}\n\nstatic int iq3_compare_func(const void * left, const void * right) {\n    const int * l = (const int *)left;\n    const int * r = (const int *)right;\n    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 
1 : 0;\n}\n\nvoid iq3xs_init_impl(int grid_size) {\n    const int gindex = iq3_data_index(grid_size);\n    if (iq3_data[gindex].grid) {\n        return;\n    }\n    static const uint16_t kgrid_256[256] = {\n            0,     2,     4,     9,    11,    15,    16,    18,    25,    34,    59,    61,    65,    67,    72,    74,\n           81,    85,    88,    90,    97,   108,   120,   128,   130,   132,   137,   144,   146,   153,   155,   159,\n          169,   175,   189,   193,   199,   200,   202,   213,   248,   267,   287,   292,   303,   315,   317,   321,\n          327,   346,   362,   413,   436,   456,   460,   462,   483,   497,   513,   515,   520,   522,   529,   531,\n          536,   538,   540,   551,   552,   576,   578,   585,   592,   594,   641,   643,   648,   650,   657,   664,\n          698,   704,   706,   720,   729,   742,   758,   769,   773,   808,   848,   852,   870,   889,   901,   978,\n          992,  1024,  1026,  1033,  1035,  1040,  1042,  1046,  1049,  1058,  1089,  1091,  1093,  1096,  1098,  1105,\n         1112,  1139,  1143,  1144,  1152,  1154,  1161,  1167,  1168,  1170,  1183,  1184,  1197,  1217,  1224,  1228,\n         1272,  1276,  1309,  1323,  1347,  1367,  1377,  1404,  1473,  1475,  1486,  1509,  1537,  1544,  1546,  1553,\n         1555,  1576,  1589,  1594,  1600,  1602,  1616,  1625,  1636,  1638,  1665,  1667,  1672,  1685,  1706,  1722,\n         1737,  1755,  1816,  1831,  1850,  1856,  1862,  1874,  1901,  1932,  1950,  1971,  2011,  2032,  2052,  2063,\n         2077,  2079,  2091,  2095,  2172,  2192,  2207,  2208,  2224,  2230,  2247,  2277,  2308,  2345,  2356,  2389,\n         2403,  2424,  2501,  2504,  2506,  2520,  2570,  2593,  2616,  2624,  2630,  2646,  2669,  2700,  2714,  2746,\n         2754,  2795,  2824,  2835,  2839,  2874,  2882,  2905,  2984,  3028,  3042,  3092,  3108,  3110,  3124,  3153,\n         3185,  3215,  3252,  3288,  3294,  3364,  3397,  3434,  3483,  3523,  3537,  3587,  3589, 
 3591,  3592,  3610,\n         3626,  3670,  3680,  3722,  3749,  3754,  3776,  3789,  3803,  3824,  3857,  3873,  3904,  3906,  3924,  3992,\n    };\n    static const uint16_t kgrid_512[512] = {\n            0,     1,     2,     5,     7,     8,     9,    10,    12,    14,    16,    17,    21,    27,    32,    34,\n           37,    39,    41,    43,    48,    50,    57,    60,    63,    64,    65,    66,    68,    72,    73,    77,\n           80,    83,    87,    89,    93,   100,   113,   117,   122,   128,   129,   133,   135,   136,   139,   142,\n          145,   149,   152,   156,   162,   165,   167,   169,   171,   184,   187,   195,   201,   205,   208,   210,\n          217,   219,   222,   228,   232,   234,   247,   249,   253,   256,   267,   271,   273,   276,   282,   288,\n          291,   297,   312,   322,   324,   336,   338,   342,   347,   353,   357,   359,   374,   379,   390,   393,\n          395,   409,   426,   441,   448,   450,   452,   464,   466,   470,   475,   488,   492,   512,   513,   514,\n          516,   520,   521,   523,   525,   527,   528,   530,   537,   540,   542,   556,   558,   561,   570,   576,\n          577,   579,   582,   584,   588,   593,   600,   603,   609,   616,   618,   632,   638,   640,   650,   653,\n          655,   656,   660,   666,   672,   675,   685,   688,   698,   705,   708,   711,   712,   715,   721,   727,\n          728,   732,   737,   754,   760,   771,   773,   778,   780,   793,   795,   802,   806,   808,   812,   833,\n          840,   843,   849,   856,   858,   873,   912,   916,   919,   932,   934,   961,   963,   968,   970,   977,\n          989,   993,  1010,  1016,  1024,  1025,  1027,  1029,  1031,  1032,  1034,  1036,  1038,  1041,  1043,  1047,\n         1048,  1050,  1057,  1059,  1061,  1064,  1066,  1079,  1080,  1083,  1085,  1088,  1090,  1096,  1099,  1103,\n         1106,  1109,  1113,  1116,  1122,  1129,  1153,  1156,  1159,  1169,  1171,  1176,  1183,  1185,  
1195,  1199,\n         1209,  1212,  1216,  1218,  1221,  1225,  1234,  1236,  1241,  1243,  1250,  1256,  1270,  1281,  1287,  1296,\n         1299,  1306,  1309,  1313,  1338,  1341,  1348,  1353,  1362,  1375,  1376,  1387,  1400,  1408,  1410,  1415,\n         1425,  1453,  1457,  1477,  1481,  1494,  1496,  1507,  1512,  1538,  1545,  1547,  1549,  1551,  1554,  1561,\n         1563,  1565,  1570,  1572,  1575,  1577,  1587,  1593,  1601,  1603,  1605,  1612,  1617,  1619,  1632,  1648,\n         1658,  1662,  1664,  1674,  1680,  1690,  1692,  1704,  1729,  1736,  1740,  1745,  1747,  1751,  1752,  1761,\n         1763,  1767,  1773,  1787,  1795,  1801,  1806,  1810,  1817,  1834,  1840,  1844,  1857,  1864,  1866,  1877,\n         1882,  1892,  1902,  1915,  1934,  1953,  1985,  1987,  2000,  2002,  2013,  2048,  2052,  2058,  2064,  2068,\n         2071,  2074,  2081,  2088,  2104,  2114,  2119,  2121,  2123,  2130,  2136,  2141,  2147,  2153,  2157,  2177,\n         2179,  2184,  2189,  2193,  2203,  2208,  2223,  2226,  2232,  2244,  2249,  2251,  2256,  2258,  2265,  2269,\n         2304,  2306,  2324,  2335,  2336,  2361,  2373,  2375,  2385,  2418,  2443,  2460,  2480,  2504,  2509,  2520,\n         2531,  2537,  2562,  2568,  2572,  2578,  2592,  2596,  2599,  2602,  2614,  2620,  2625,  2627,  2629,  2634,\n         2641,  2650,  2682,  2688,  2697,  2707,  2712,  2718,  2731,  2754,  2759,  2760,  2775,  2788,  2793,  2805,\n         2811,  2817,  2820,  2832,  2842,  2854,  2890,  2902,  2921,  2923,  2978,  3010,  3012,  3026,  3081,  3083,\n         3085,  3097,  3099,  3120,  3136,  3152,  3159,  3188,  3210,  3228,  3234,  3245,  3250,  3256,  3264,  3276,\n         3281,  3296,  3349,  3363,  3378,  3392,  3395,  3420,  3440,  3461,  3488,  3529,  3531,  3584,  3588,  3591,\n         3600,  3602,  3614,  3616,  3628,  3634,  3650,  3657,  3668,  3683,  3685,  3713,  3716,  3720,  3726,  3729,\n         3736,  3753,  3778,  3802,  3805,  3819, 
 3841,  3845,  3851,  3856,  3880,  3922,  3938,  3970,  3993,  4032,\n    };\n\n    const int kmap_size = 4096;\n    const int nwant = grid_size == 256 ? 2 : 3;\n    const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;\n    uint32_t * kgrid_q3xs;\n    int      * kmap_q3xs;\n    uint16_t * kneighbors_q3xs;\n\n    //printf(\"================================================================= %s(grid_size = %d)\\n\", __func__, grid_size);\n    uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));\n    for (int k = 0; k < grid_size; ++k) {\n        int8_t * pos = (int8_t *)(the_grid + k);\n        for (int i = 0; i < 4; ++i) {\n            int l = (kgrid[k] >> 3*i) & 0x7;\n            pos[i] = 2*l + 1;\n        }\n    }\n    kgrid_q3xs = the_grid;\n    iq3_data[gindex].grid = the_grid;\n    kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));\n    iq3_data[gindex].map = kmap_q3xs;\n    for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;\n    uint32_t aux32;\n    uint8_t * aux8 = (uint8_t *)&aux32;\n    for (int i = 0; i < grid_size; ++i) {\n        aux32 = kgrid_q3xs[i];\n        uint16_t index = 0;\n        for (int k=0; k<4; ++k) {\n            uint16_t q = (aux8[k] - 1)/2;\n            index |= (q << 3*k);\n        }\n        kmap_q3xs[index] = i;\n    }\n    int8_t pos[4];\n    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));\n    int num_neighbors = 0, num_not_in_map = 0;\n    for (int i = 0; i < kmap_size; ++i) {\n        if (kmap_q3xs[i] >= 0) continue;\n        ++num_not_in_map;\n        for (int k = 0; k < 4; ++k) {\n            int l = (i >> 3*k) & 0x7;\n            pos[k] = 2*l + 1;\n        }\n        for (int j = 0; j < grid_size; ++j) {\n            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);\n            int d2 = 0;\n            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);\n            dist2[2*j+0] = d2;\n            dist2[2*j+1] = j;\n        }\n        qsort(dist2, grid_size, 
2*sizeof(int), iq3_compare_func);\n        int n = 0; int d2 = dist2[0];\n        int nhave = 1;\n        for (int j = 0; j < grid_size; ++j) {\n            if (dist2[2*j] > d2) {\n                if (nhave == nwant) break;\n                d2 = dist2[2*j];\n                ++nhave;\n            }\n            ++n;\n        }\n        num_neighbors += n;\n    }\n    //printf(\"%s: %d neighbours in total\\n\", __func__, num_neighbors);\n    kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));\n    iq3_data[gindex].neighbours = kneighbors_q3xs;\n    int counter = 0;\n    for (int i = 0; i < kmap_size; ++i) {\n        if (kmap_q3xs[i] >= 0) continue;\n        for (int k = 0; k < 4; ++k) {\n            int l = (i >> 3*k) & 0x7;\n            pos[k] = 2*l + 1;\n        }\n        for (int j = 0; j < grid_size; ++j) {\n            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);\n            int d2 = 0;\n            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);\n            dist2[2*j+0] = d2;\n            dist2[2*j+1] = j;\n        }\n        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);\n        kmap_q3xs[i] = -(counter + 1);\n        int d2 = dist2[0];\n        uint16_t * start = &kneighbors_q3xs[counter++];\n        int n = 0, nhave = 1;\n        for (int j = 0; j < grid_size; ++j) {\n            if (dist2[2*j] > d2) {\n                if (nhave == nwant) break;\n                d2 = dist2[2*j];\n                ++nhave;\n            }\n            kneighbors_q3xs[counter++] = dist2[2*j+1];\n            ++n;\n        }\n        *start = n;\n    }\n    free(dist2);\n}\n\nvoid iq3xs_free_impl(int grid_size) {\n    GGML_ASSERT(grid_size == 256 || grid_size == 512);\n    const int gindex = iq3_data_index(grid_size);\n    if (iq3_data[gindex].grid) {\n        free(iq3_data[gindex].grid);       iq3_data[gindex].grid = NULL;\n        free(iq3_data[gindex].map);        iq3_data[gindex].map  = NULL;\n   
     free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;\n    }\n}\n\nstatic int iq3_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint32_t * GGML_RESTRICT grid,\n        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, int8_t * GGML_RESTRICT L) {\n    int num_neighbors = neighbours[0];\n    GGML_ASSERT(num_neighbors > 0);\n    float best_d2 = FLT_MAX;\n    int grid_index = -1;\n    for (int j = 1; j <= num_neighbors; ++j) {\n        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);\n        float d2 = 0;\n        for (int i = 0; i < 4; ++i) {\n            float q = pg[i];\n            float diff = scale*q - xval[i];\n            d2 += weight[i]*diff*diff;\n        }\n        if (d2 < best_d2) {\n            best_d2 = d2; grid_index = neighbours[j];\n        }\n    }\n    GGML_ASSERT(grid_index >= 0);\n    const int8_t * pg = (const int8_t *)(grid + grid_index);\n    for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2;\n    return grid_index;\n}\n\nstatic void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n,\n        const float * GGML_RESTRICT quant_weights) {\n\n    const int gindex = iq3_data_index(grid_size);\n\n    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;\n    const int      * kmap_q3xs       = iq3_data[gindex].map;\n    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;\n\n    //GGML_ASSERT(quant_weights   && \"missing quantization weights\");\n    GGML_ASSERT(kgrid_q3xs      && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kmap_q3xs       && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kneighbors_q3xs && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(n%QK_K == 0);\n\n    const int kMaxQ = 8;\n\n    const int64_t nbl = n/QK_K;\n\n    ggml_fp16_t * dh;\n    uint8_t * qs;\n    int block_size;\n    if (grid_size == 256) {\n        block_iq3_xxs * y = 
vy;\n        dh = &y->d;\n        qs = y->qs;\n        block_size = sizeof(block_iq3_xxs);\n    } else {\n        block_iq3_s * y = vy;\n        dh = &y->d;\n        qs = y->qs;\n        block_size = sizeof(block_iq3_s);\n    }\n    int quant_size = block_size - sizeof(ggml_fp16_t);\n\n    float scales[QK_K/32];\n    float weight[32];\n    float xval[32];\n    int8_t L[32];\n    int8_t Laux[32];\n    float  waux[32];\n    bool   is_on_grid[8];\n    bool   is_on_grid_aux[8];\n    uint8_t block_signs[8];\n    uint8_t q3[3*(QK_K/8)+QK_K/32];\n    uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);\n    uint8_t  * qh = q3 + 3*(QK_K/8);\n\n    for (int ibl = 0; ibl < nbl; ++ibl) {\n\n        dh[0] = GGML_FP32_TO_FP16(0.f);\n        memset(q3, 0, 3*QK_K/8+QK_K/32);\n\n        float max_scale = 0;\n\n        const float * xbl = x + QK_K*ibl;\n        float sumx2 = 0;\n        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];\n        float sigma2 = 2*sumx2/QK_K;\n\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            const float * xb = xbl + 32*ib;\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K*ibl + 32*ib;\n                for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);\n            } else {\n                for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];\n            }\n            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);\n            for (int k = 0; k < 4; ++k) {\n                int nflip = 0;\n                uint8_t s = 0;\n                for (int i = 0; i < 8; ++i) {\n                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];\n                    else {\n                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);\n                    }\n                }\n                if (nflip%2) {\n                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];\n                    for (int i = 1; i < 8; ++i) 
{\n                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];\n                        if (ax < min) {\n                            min = ax; imin = i;\n                        }\n                    }\n                    xval[8*k+imin] = -xval[8*k+imin];\n                    s ^= (1 << imin);\n                }\n                block_signs[k] = s & 127;\n            }\n            float max = xval[0];\n            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);\n            memset(L, 0, 32);\n            if (max < GROUP_MAX_EPS_IQ3_XXS) {\n                scales[ib] = 0;\n                continue;\n            }\n            float best = 0;\n            float scale = max/(2*kMaxQ-1);\n            for (int k = 0; k < 8; ++k) is_on_grid[k] = true;\n            for (int is = -15; is <= 15; ++is) {\n                float id = (2*kMaxQ-1+is*0.2f)/max;\n                float this_scale = 1/id;\n                for (int k = 0; k < 8; ++k) {\n                    for (int i = 0; i < 4; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));\n                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));\n                    }\n                    uint16_t u = 0;\n                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);\n                    int grid_index = kmap_q3xs[u];\n                    is_on_grid_aux[k] = true;\n                    if (grid_index < 0) {\n                        is_on_grid_aux[k] = false;\n                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;\n                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);\n                    }\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < 32; ++i) {\n                    float w = weight[i];\n                    float q = 2*Laux[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    
sumq2 += w*q*q;\n                }\n                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {\n                    scale = sumqx/sumq2; best = scale*sumqx;\n                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];\n                    for (int k = 0; k <  8; ++k) is_on_grid[k] = is_on_grid_aux[k];\n                }\n            }\n            int n_not_ongrid = 0;\n            for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;\n            if (n_not_ongrid > 0 && scale > 0) {\n                float id = 1/scale;\n                for (int k = 0; k < 8; ++k) {\n                    if (is_on_grid[k]) continue;\n                    uint16_t u = 0;\n                    for (int i = 0; i < 4; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));\n                        l = MAX(0, MIN(kMaxQ-1, l));\n                        u |= (l << 3*i);\n                    }\n                    int grid_index = kmap_q3xs[u];\n                    if (grid_index < 0) {\n                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;\n                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);\n                    }\n                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);\n                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < 32; ++i) {\n                    float w = weight[i];\n                    float q = 2*L[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0) scale = sumqx/sumq2;\n            }\n            if (scale < 0) {\n                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)\n                // and correspondingly flip quant 
signs.\n                scale = -scale;\n                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;\n            }\n            for (int k = 0; k < 8; ++k) {\n                uint16_t u = 0;\n                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);\n                int grid_index = kmap_q3xs[u];\n                if (grid_index < 0) {\n                    printf(\"Oops: found point %u not on grid:\", u);\n                    for (int i = 0; i < 4; ++i) printf(\" %d\", L[4*k+i]);\n                    printf(\"\\n\");\n                    GGML_ABORT(\"fatal error\");\n                }\n                if (grid_size == 256) {\n                    q3[8*ib+k] = grid_index;\n                } else {\n                    q3[8*ib+k] = grid_index & 255;\n                    qh[ib] |= ((grid_index >> 8) << k);\n                }\n\n            }\n            scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);\n            GGML_ASSERT(scale >= 0);\n            scales[ib] = scale;\n            max_scale = MAX(max_scale, scale);\n        }\n\n        if (!max_scale) {\n            memset(qs, 0, quant_size);\n            dh += block_size/sizeof(ggml_fp16_t);\n            qs += block_size;\n            continue;\n        }\n\n        float d = max_scale/31;\n        dh[0] = GGML_FP32_TO_FP16(d * 1.0125f);  // small improvement via this fudge factor\n        float id = 1/d;\n        for (int ib = 0; ib < QK_K/32; ++ib) {\n            int l = nearest_int(0.5f*(id*scales[ib]-1));\n            l = MAX(0, MIN(15, l));\n            scales_and_signs[ib] |= ((uint32_t)l << 28);\n        }\n        memcpy(qs, q3, quant_size);\n\n        dh += block_size/sizeof(ggml_fp16_t);\n        qs += block_size;\n\n    }\n}\n\nsize_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK_K 
== 0);\n    int64_t nblock = n_per_row/QK_K;\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq3_xxs);\n    }\n    return nrow * nblock * sizeof(block_iq3_xxs);\n}\n\nvoid quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    quantize_row_iq3_xxs_impl(256, x, y, k, NULL);\n}\n\nstatic void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int n,\n        const float * GGML_RESTRICT quant_weights,\n        float   * scales,\n        float   * weight,\n        float   * xval,\n        int8_t  * L,\n        int8_t  * Laux,\n        float   * waux,\n        bool    * is_on_grid,\n        bool    * is_on_grid_aux,\n        uint8_t * block_signs) {\n\n    const int gindex = iq3_data_index(512);\n\n    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;\n    const int      * kmap_q3xs       = iq3_data[gindex].map;\n    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;\n\n    //GGML_ASSERT(quant_weights   && \"missing quantization weights\");\n    GGML_ASSERT(kgrid_q3xs      && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kmap_q3xs       && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kneighbors_q3xs && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(n%QK_K == 0);\n\n    const int kMaxQ = 8;\n\n    const int64_t nbl = n/QK_K;\n\n    block_iq3_s * y = vy;\n\n    const int bs4 = block_size/4;\n    const int bs8 = block_size/8;\n\n    for (int ibl = 0; ibl < nbl; ++ibl) {\n\n        memset(&y[ibl], 0, sizeof(block_iq3_s));\n        y[ibl].d = GGML_FP32_TO_FP16(0.f);\n\n        uint8_t * qs = y[ibl].qs;\n        uint8_t * qh = y[ibl].qh;\n        uint8_t * signs = y[ibl].signs;\n\n        float max_scale = 0;\n\n        
const float * xbl = x + QK_K*ibl;\n        float sumx2 = 0;\n        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];\n        float sigma2 = 2*sumx2/QK_K;\n\n        for (int ib = 0; ib < QK_K/block_size; ++ib) {\n            const float * xb = xbl + block_size*ib;\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K*ibl + block_size*ib;\n                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);\n            } else {\n                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];\n            }\n            for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]);\n            for (int k = 0; k < bs8; ++k) {\n                uint8_t s = 0;\n                for (int i = 0; i < 8; ++i) {\n                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];\n                    else {\n                        xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);\n                    }\n                }\n                block_signs[k] = s;\n            }\n            float max = xval[0];\n            for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);\n            memset(L, 0, block_size);\n            if (!max) {\n                scales[ib] = 0;\n                continue;\n            }\n            float best = 0;\n            float scale = max/(2*kMaxQ-1);\n            for (int k = 0; k < bs4; ++k) is_on_grid[k] = false;\n            for (int is = -9; is <= 9; ++is) {\n                float id = (2*kMaxQ-1+is*0.2f)/max;\n                float this_scale = 1/id;\n                for (int k = 0; k < bs4; ++k) {\n                    for (int i = 0; i < 4; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));\n                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));\n                    }\n                    uint16_t u = 0;\n                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);\n                    int 
grid_index = kmap_q3xs[u];\n                    is_on_grid_aux[k] = true;\n                    if (grid_index < 0) {\n                        is_on_grid_aux[k] = false;\n                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;\n                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);\n                    }\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < block_size; ++i) {\n                    float w = weight[i];\n                    float q = 2*Laux[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {\n                    scale = sumqx/sumq2; best = scale*sumqx;\n                    for (int i = 0; i < block_size; ++i) L[i] = Laux[i];\n                    for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k];\n                }\n            }\n            int n_not_ongrid = 0;\n            for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;\n            if (n_not_ongrid > 0 && scale > 0) {\n                float id = 1/scale;\n                for (int k = 0; k < bs4; ++k) {\n                    //if (is_on_grid[k]) continue;\n                    uint16_t u = 0;\n                    for (int i = 0; i < 4; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));\n                        l = MAX(0, MIN(kMaxQ-1, l));\n                        u |= (l << 3*i);\n                    }\n                    int grid_index = kmap_q3xs[u];\n                    if (grid_index < 0) {\n                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;\n                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);\n                    }\n                    const int8_t * pg = 
(const int8_t *)(kgrid_q3xs + grid_index);\n                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < block_size; ++i) {\n                    float w = weight[i];\n                    float q = 2*L[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0) scale = sumqx/sumq2;\n            }\n            if (scale < 0) {\n                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)\n                // and correspondingly flip quant signs.\n                scale = -scale;\n                for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k];\n            }\n            for (int k = 0; k < bs4; ++k) {\n                uint16_t u = 0;\n                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);\n                int grid_index = kmap_q3xs[u];\n                if (grid_index < 0) {\n                    printf(\"Oops: found point %u not on grid:\", u);\n                    for (int i = 0; i < 4; ++i) printf(\" %d\", L[4*k+i]);\n                    printf(\"\\n\");\n                    GGML_ABORT(\"fatal error\");\n                }\n                qs[k] = grid_index & 255;\n                qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));\n            }\n            qs += bs4;\n            for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k];\n            signs += bs8;\n            GGML_ASSERT(scale >= 0);\n            scales[ib] = scale;\n            max_scale = MAX(max_scale, scale);\n        }\n\n        if (!max_scale) {\n            continue;\n        }\n\n        float d = max_scale/31;\n        y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f);\n        float id = 1/d;\n        for (int ib = 0; ib < QK_K/block_size; ib += 2) {\n            int l1 = 
nearest_int(0.5f*(id*scales[ib+0]-1));\n            l1 = MAX(0, MIN(15, l1));\n            int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));\n            l2 = MAX(0, MIN(15, l2));\n            y[ibl].scales[ib/2] = l1 | (l2 << 4);\n        }\n\n    }\n}\n\n#define IQ3S_BLOCK_SIZE 32\nsize_t quantize_iq3_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK_K == 0);\n    int64_t nblock = n_per_row/QK_K;\n    float scales[QK_K/IQ3S_BLOCK_SIZE];\n    float weight[IQ3S_BLOCK_SIZE];\n    float xval[IQ3S_BLOCK_SIZE];\n    int8_t L[IQ3S_BLOCK_SIZE];\n    int8_t Laux[IQ3S_BLOCK_SIZE];\n    float  waux[IQ3S_BLOCK_SIZE];\n    bool   is_on_grid[IQ3S_BLOCK_SIZE/4];\n    bool   is_on_grid_aux[IQ3S_BLOCK_SIZE/4];\n    uint8_t block_signs[IQ3S_BLOCK_SIZE/8];\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,\n                scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq3_s);\n    }\n    return nrow * nblock * sizeof(block_iq3_s);\n}\n\nvoid quantize_row_iq3_s_ref(const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    quantize_iq3_s(x, y, 1, k, NULL);\n}\n\n\n// =================================== 1.5 bpw ===================================================\n\nstatic int iq1_find_best_neighbour(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,\n        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float * scale, int8_t * GGML_RESTRICT L, int ngrid) {\n    int num_neighbors = neighbours[0];\n    GGML_ASSERT(num_neighbors > 0);\n    float best_score = -FLT_MAX;\n    int grid_index = -1;\n    for (int j = 1; j <= num_neighbors; ++j) {\n        const int8_t * pg = (const 
int8_t *)(grid + neighbours[j]);\n        float sumqx = 0, sumq2 = 0;\n        for (int i = 0; i < 8; ++i) {\n            float q = (pg[i] - 3)/2;\n            float w = weight[i];\n            sumqx += w*q*xval[i];\n            sumq2 += w*q*q;\n        }\n        if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {\n            *scale = sumqx/sumq2; best_score = *scale * sumqx;\n            grid_index = neighbours[j];\n        }\n    }\n    if (grid_index < 0) {\n        for (int i = 0; i < ngrid; ++i) {\n            const int8_t * grid_i = (const int8_t *)(grid + i);\n            float sumqx = 0, sumq2 = 0;\n            for (int j = 0; j < 8; ++j) {\n                float w = weight[j];\n                float q = (grid_i[j] - 3)/2;\n                sumqx += w*q*xval[j];\n                sumq2 += w*q*q;\n            }\n            if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {\n                *scale = sumqx/sumq2; best_score = *scale*sumqx;\n                grid_index = i;\n            }\n        }\n    }\n    if (grid_index < 0) {\n        printf(\"Oops, did not find grid point\\n\");\n        printf(\"Have %d neighbours\\n\", num_neighbors);\n        for (int j = 1; j <= num_neighbors; ++j) {\n            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);\n            float sumqx = 0, sumq2 = 0;\n            for (int i = 0; i < 8; ++i) {\n                float q = (pg[i] - 3)/2;\n                float w = weight[i];\n                sumqx += w*q*xval[i];\n                sumq2 += w*q*q;\n            }\n            printf(\"    neighbour %d: sumqx = %g sumq2 = %g\\n\", j, (double)sumqx, (double)sumq2);\n        }\n    }\n    GGML_ASSERT(grid_index >= 0);\n    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n    *scale *= 1.05f;  // This is a fudge factor. 
Don't ask me why it improves the result.\n    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n    const int8_t * pg = (const int8_t *)(grid + grid_index);\n    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;\n    return grid_index;\n}\n\nstatic int iq1_find_best_neighbour2(const uint16_t * GGML_RESTRICT neighbours, const uint64_t * GGML_RESTRICT grid,\n        const float * GGML_RESTRICT xval, const float * GGML_RESTRICT weight, float scale, const float * GGML_RESTRICT xg, int8_t * GGML_RESTRICT L, int ngrid) {\n    int num_neighbors = neighbours[0];\n    GGML_ASSERT(num_neighbors > 0);\n    float best_score = FLT_MAX;\n    int grid_index = -1;\n    for (int j = 1; j <= num_neighbors; ++j) {\n        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);\n        float d2 = 0;\n        for (int i = 0; i < 8; ++i) {\n            float q = xg[(pg[i] - 1)/2];\n            float w = weight[i];\n            float diff = scale*q - xval[i];\n            d2 += w*diff*diff;\n        }\n        if (d2 < best_score) {\n            best_score = d2;\n            grid_index = neighbours[j];\n        }\n    }\n    if (grid_index < 0) {\n        for (int i = 0; i < ngrid; ++i) {\n            const int8_t * grid_i = (const int8_t *)(grid + i);\n            float d2 = 0;\n            for (int j = 0; j < 8; ++j) {\n                float w = weight[j];\n                float q = xg[(grid_i[j] - 1)/2];\n                float diff = scale*q - xval[i];\n                d2 += w*diff*diff;\n            }\n            if (d2 < best_score) {\n                best_score = d2;\n                grid_index = i;\n            }\n        }\n    }\n    if (grid_index < 0) {\n        printf(\"Oops, did not find grid point\\n\");\n        printf(\"Have %d neighbours\\n\", num_neighbors);\n        for (int j = 1; j <= num_neighbors; ++j) {\n            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);\n            float sumqx = 0, sumq2 = 0;\n     
       for (int i = 0; i < 8; ++i) {\n                float q = xg[(pg[i] - 1)/2];\n                float w = weight[i];\n                sumqx += w*q*xval[i];\n                sumq2 += w*q*q;\n            }\n            printf(\"    neighbour %d: sumqx = %g sumq2 = %g\\n\", j, (double)sumqx, (double)sumq2);\n        }\n    }\n    GGML_ASSERT(grid_index >= 0);\n    const int8_t * pg = (const int8_t *)(grid + grid_index);\n    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;\n    return grid_index;\n}\n\nstatic int iq1_sort_helper(const void * left, const void * right) {\n    const float * l = left;\n    const float * r = right;\n    return *l < *r ? -1 : *l > *r ? 1 : 0;\n}\n\n#define IQ1S_BLOCK_SIZE 32\n#define IQ1M_BLOCK_SIZE 16\nstatic void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights,\n        float    * scales,\n        float    * weight,\n        float    * sumx,\n        float    * sumw,\n        float    * pairs,\n        int8_t   * L,\n        uint16_t * index,\n        int8_t   * shifts) {\n\n    const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);\n\n    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;\n    const int      * kmap_q2xs       = iq2_data[gindex].map;\n    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;\n\n    GGML_ASSERT(quant_weights   && \"missing quantization weights\");\n    GGML_ASSERT(kgrid_q2xs      && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kmap_q2xs       && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kneighbors_q2xs && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(n%QK_K == 0);\n\n    block_iq1_s * y = vy;\n\n    const int64_t nbl = n/QK_K;\n\n    const int block_size = IQ1S_BLOCK_SIZE;\n\n    const float x_p[3] = {-1 + IQ1S_DELTA,  IQ1S_DELTA, 1 + IQ1S_DELTA};\n    const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA};\n\n\n    int * idx = (int *)(pairs + 
1);\n\n    for (int ibl = 0; ibl < nbl; ++ibl) {\n\n        y[ibl].d = GGML_FP32_TO_FP16(0.f);\n        memset(y[ibl].qs, 0, QK_K/8);\n        memset(y[ibl].qh, 0, QK_K/16);\n\n        float max_scale = 0;\n\n        const float * xbl = x + QK_K*ibl;\n        float sumx2 = 0;\n        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];\n        float sigma2 = 2*sumx2/QK_K;\n\n        for (int ib = 0; ib < QK_K/block_size; ++ib) {\n            const float * xb = xbl + block_size*ib;\n            const float * qw = quant_weights + QK_K*ibl + block_size*ib;\n            for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);\n            float max = fabsf(xb[0]);\n            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));\n            if (max < GROUP_MAX_EPS_IQ1_S) {\n                scales[ib] = 0;\n                shifts[ib] = 1;\n                memset(L, 1, block_size);\n                continue;\n            }\n            // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.\n            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two\n            // boundaries that split the weights xb[i] into 3 groups. 
To do so, we sort the weights\n            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and\n            // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale\n            // for each possible and score for each split.\n            for (int j = 0; j < block_size; ++j) {\n                pairs[2*j] = xb[j];\n                idx[2*j] = j;\n            }\n            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);\n            {\n                sumx[0] = sumw[0] = 0;\n                for (int j = 0; j < block_size; ++j) {\n                    int i = idx[2*j];\n                    sumx[j+1] = sumx[j] + weight[i]*xb[i];\n                    sumw[j+1] = sumw[j] + weight[i];\n                }\n            }\n            float best_score = -FLT_MAX, scale = max;\n            int besti1 = -1, besti2 = -1, best_shift = 0;\n            for (int i1 = 0; i1 <= block_size; ++i1) {\n                for (int i2 = i1; i2 <= block_size; ++i2) {\n                    float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2];\n                    float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2];\n                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {\n                        scale = sumqx/sumq2; best_score = scale*sumqx;\n                        besti1 = i1; besti2 = i2; best_shift = 1;\n                    }\n                    sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2];\n                    sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2];\n                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {\n                        scale = sumqx/sumq2; best_score = scale*sumqx;\n                        besti1 = i1; besti2 
= i2; best_shift = -1;\n                    }\n                }\n            }\n            if (besti1 < 0 || besti2 < 0 || best_shift == 0) {\n                scales[ib] = 0;\n                shifts[ib] = 1;\n                memset(L, 1, block_size);\n                continue;\n            }\n            for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;\n            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;\n            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;\n            if (scale < 0) {\n                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];\n                scale = -scale; best_shift = -best_shift;\n            }\n            bool all_on_grid = true;\n            const float * xx = best_shift == 1 ? x_p : x_m;\n            for (int k = 0; k < block_size/8; ++k) {\n                uint16_t u = 0;\n                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);\n                int grid_index = kmap_q2xs[u];\n                if (grid_index < 0) {\n                    all_on_grid = false;\n                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;\n                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);\n                    GGML_ASSERT(grid_index >= 0);\n                }\n                index[k] = grid_index;\n            }\n            if (!all_on_grid) {\n                float sumqx = 0, sumq2 = 0;\n                for (int k = 0; k < block_size/8; ++k) {\n                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);\n                    for (int j = 0; j < 8; ++j) {\n                        float w = weight[8*k + j];\n                        float q = xx[(pg[j] - 1)/2];\n                        sumqx += w*q*xb[8*k+j];\n                        sumq2 += w*q*q;\n                    }\n                }\n                if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;\n            
}\n            uint16_t h = 0;\n            for (int k = 0; k < block_size/8; ++k) {\n                y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255;\n                h |= (index[k] >> 8) << 3*k;\n            }\n            y[ibl].qh[ib] = h;\n            GGML_ASSERT(scale >= 0);\n            scales[ib] = scale;\n            shifts[ib] = best_shift;\n            max_scale = MAX(max_scale, scale);\n        }\n\n        if (!max_scale) {\n            continue;\n        }\n\n        float d = max_scale/15;\n        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed.\n        float id = 1/d;\n        for (int ib = 0; ib < QK_K/block_size; ++ib) {\n            int l = nearest_int(0.5f*(id*scales[ib]-1));\n            l = MAX(0, MIN(7, l));\n            if (shifts[ib] == -1) l |= 8;\n            y[ibl].qh[ib] |= (l << 12);\n        }\n    }\n}\n\nsize_t quantize_iq1_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK_K == 0);\n    float  scales[QK_K/IQ1S_BLOCK_SIZE];\n    float  weight[IQ1S_BLOCK_SIZE];\n    int8_t L[IQ1S_BLOCK_SIZE];\n    float  sumx[IQ1S_BLOCK_SIZE+1];\n    float  sumw[IQ1S_BLOCK_SIZE+1];\n    float  pairs[2*IQ1S_BLOCK_SIZE];\n    uint16_t index[IQ1S_BLOCK_SIZE/8];\n    int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];\n    int64_t nblock = n_per_row/QK_K;\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts);\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq1_s);\n    }\n    return nrow * nblock * sizeof(block_iq1_s);\n}\n\nstatic void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights,\n        float    * scales,\n        float    * weight,\n        float    * 
pairs,\n        int8_t   * L,\n        uint16_t * index,\n        int8_t   * shifts) {\n\n    const int gindex = iq2_data_index(GGML_TYPE_IQ1_M);\n\n    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;\n    const int      * kmap_q2xs       = iq2_data[gindex].map;\n    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;\n\n    //GGML_ASSERT(quant_weights   && \"missing quantization weights\");\n    GGML_ASSERT(kgrid_q2xs      && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kmap_q2xs       && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kneighbors_q2xs && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(n%QK_K == 0);\n\n    block_iq1_m * y = vy;\n\n    const int64_t nbl = n/QK_K;\n\n    const int block_size = IQ1M_BLOCK_SIZE;\n\n    const float x_p[3] = {-1 + IQ1M_DELTA,  IQ1M_DELTA, 1 + IQ1M_DELTA};\n    const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};\n    const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};\n\n    int * idx = (int *)(pairs + 1);\n\n    float sumqx[4], sumq2[4];\n\n    iq1m_scale_t s;\n    const float * xx;\n\n    for (int ibl = 0; ibl < nbl; ++ibl) {\n        memset(y[ibl].qs, 0, QK_K/8);\n        memset(y[ibl].qh, 0, QK_K/16);\n        memset(y[ibl].scales, 0, QK_K/32);\n\n        float max_scale = 0;\n\n        const float * xbl = x + QK_K*ibl;\n        float sumx2 = 0;\n        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];\n        float sigma2 = 2*sumx2/QK_K;\n\n        for (int ib = 0; ib < QK_K/block_size; ++ib) {\n            const float * xb = xbl + block_size*ib;\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K*ibl + block_size*ib;\n                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);\n            } else {\n                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];\n            }\n            float max = fabsf(xb[0]);\n            for (int i = 
1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));\n            if (max < GROUP_MAX_EPS_IQ1_M) {\n                scales[ib] = 0;\n                shifts[ib] = 0;\n                memset(L, 1, block_size);\n                continue;\n            }\n            // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.\n            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two\n            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights\n            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and\n            // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale\n            // for each possible and score for each split.\n            for (int j = 0; j < block_size; ++j) {\n                pairs[2*j] = xb[j];\n                idx[2*j] = j;\n            }\n            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);\n            float best_score = -FLT_MAX, scale = max;\n            int besti1 = -1, besti2 = -1, best_k = -1;\n            // 0: +, +\n            // 1: +, -\n            // 2: -, +\n            // 3: -, -\n            for (int i1 = 0; i1 <= block_size; ++i1) {\n                for (int i2 = i1; i2 <= block_size; ++i2) {\n                    memset(sumqx, 0, 4*sizeof(float));\n                    memset(sumq2, 0, 4*sizeof(float));\n                    for (int j = 0; j < i1; ++j) {\n                        int i = idx[2*j];\n                        if (i < block_size/2) {\n                            sumqx[0] += weight[i]*x_p[0]*xb[i];\n                            sumqx[1] += weight[i]*x_p[0]*xb[i];\n                            sumqx[2] += weight[i]*x_m[0]*xb[i];\n                            sumqx[3] += weight[i]*x_m[0]*xb[i];\n                            sumq2[0] += weight[i]*x_p[0]*x_p[0];\n                            sumq2[1] += weight[i]*x_p[0]*x_p[0];\n                   
         sumq2[2] += weight[i]*x_m[0]*x_m[0];\n                            sumq2[3] += weight[i]*x_m[0]*x_m[0];\n                        } else {\n                            sumqx[0] += weight[i]*x_p[0]*xb[i];\n                            sumqx[2] += weight[i]*x_p[0]*xb[i];\n                            sumqx[1] += weight[i]*x_m[0]*xb[i];\n                            sumqx[3] += weight[i]*x_m[0]*xb[i];\n                            sumq2[0] += weight[i]*x_p[0]*x_p[0];\n                            sumq2[2] += weight[i]*x_p[0]*x_p[0];\n                            sumq2[1] += weight[i]*x_m[0]*x_m[0];\n                            sumq2[3] += weight[i]*x_m[0]*x_m[0];\n                        }\n                    }\n                    for (int j = i1; j < i2; ++j) {\n                        int i = idx[2*j];\n                        if (i < block_size/2) {\n                            sumqx[0] += weight[i]*x_p[1]*xb[i];\n                            sumqx[1] += weight[i]*x_p[1]*xb[i];\n                            sumqx[2] += weight[i]*x_m[1]*xb[i];\n                            sumqx[3] += weight[i]*x_m[1]*xb[i];\n                            sumq2[0] += weight[i]*x_p[1]*x_p[1];\n                            sumq2[1] += weight[i]*x_p[1]*x_p[1];\n                            sumq2[2] += weight[i]*x_m[1]*x_m[1];\n                            sumq2[3] += weight[i]*x_m[1]*x_m[1];\n                        } else {\n                            sumqx[0] += weight[i]*x_p[1]*xb[i];\n                            sumqx[2] += weight[i]*x_p[1]*xb[i];\n                            sumqx[1] += weight[i]*x_m[1]*xb[i];\n                            sumqx[3] += weight[i]*x_m[1]*xb[i];\n                            sumq2[0] += weight[i]*x_p[1]*x_p[1];\n                            sumq2[2] += weight[i]*x_p[1]*x_p[1];\n                            sumq2[1] += weight[i]*x_m[1]*x_m[1];\n                            sumq2[3] += weight[i]*x_m[1]*x_m[1];\n                        }\n                    }\n  
                  for (int j = i2; j < block_size; ++j) {\n                        int i = idx[2*j];\n                        if (i < block_size/2) {\n                            sumqx[0] += weight[i]*x_p[2]*xb[i];\n                            sumqx[1] += weight[i]*x_p[2]*xb[i];\n                            sumqx[2] += weight[i]*x_m[2]*xb[i];\n                            sumqx[3] += weight[i]*x_m[2]*xb[i];\n                            sumq2[0] += weight[i]*x_p[2]*x_p[2];\n                            sumq2[1] += weight[i]*x_p[2]*x_p[2];\n                            sumq2[2] += weight[i]*x_m[2]*x_m[2];\n                            sumq2[3] += weight[i]*x_m[2]*x_m[2];\n                        } else {\n                            sumqx[0] += weight[i]*x_p[2]*xb[i];\n                            sumqx[2] += weight[i]*x_p[2]*xb[i];\n                            sumqx[1] += weight[i]*x_m[2]*xb[i];\n                            sumqx[3] += weight[i]*x_m[2]*xb[i];\n                            sumq2[0] += weight[i]*x_p[2]*x_p[2];\n                            sumq2[2] += weight[i]*x_p[2]*x_p[2];\n                            sumq2[1] += weight[i]*x_m[2]*x_m[2];\n                            sumq2[3] += weight[i]*x_m[2]*x_m[2];\n                        }\n                    }\n                    for (int k = 0; k < 4; ++k) {\n                        if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {\n                            scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];\n                            besti1 = i1; besti2 = i2; best_k = k;\n                        }\n                    }\n                }\n            }\n            if (besti1 < 0 || besti2 < 0 || best_k < 0) {\n                scales[ib] = 0;\n                shifts[ib] = 0;\n                memset(L, 1, block_size);\n                continue;\n            }\n            for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;\n            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 
1;\n            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;\n            if (scale < 0) {\n                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];\n                scale = -scale;\n                best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0;\n            }\n            bool all_on_grid = true;\n            for (int k = 0; k < block_size/8; ++k) {\n                if (k == 0) xx = best_k < 2 ? x_p : x_m;\n                else xx = best_k%2 == 0 ? x_p : x_m;\n                uint16_t u = 0;\n                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);\n                int grid_index = kmap_q2xs[u];\n                if (grid_index < 0) {\n                    all_on_grid = false;\n                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;\n                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);\n                    GGML_ASSERT(grid_index >= 0);\n                }\n                index[k] = grid_index;\n            }\n            if (!all_on_grid) {\n                float sumqx_f = 0, sumq2_f = 0;\n                for (int k = 0; k < block_size/8; ++k) {\n                    if (k == 0) xx = best_k < 2 ? x_p : x_m;\n                    else xx = best_k%2 == 0 ? 
x_p : x_m;\n                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);\n                    for (int j = 0; j < 8; ++j) {\n                        float w = weight[8*k + j];\n                        float q = xx[(pg[j] - 1)/2];\n                        sumqx_f += w*q*xb[8*k+j];\n                        sumq2_f += w*q*q;\n                    }\n                }\n                if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;\n            }\n            y[ibl].qs[2*ib + 0] = index[0] & 255;\n            y[ibl].qs[2*ib + 1] = index[1] & 255;\n            y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4);\n            GGML_ASSERT(scale >= 0);\n            scales[ib] = scale;\n            shifts[ib] = best_k;\n            max_scale = MAX(max_scale, scale);\n        }\n\n        if (!max_scale) {\n            continue;\n        }\n\n        uint16_t * sc = (uint16_t *)y[ibl].scales;\n        float d = max_scale/15;\n        float id = 1/d;\n        float sumqx_f = 0, sumq2_f = 0;\n        for (int ib = 0; ib < QK_K/block_size; ++ib) {\n            int l = nearest_int(0.5f*(id*scales[ib+0]-1));\n            l = MAX(0, MIN(7, l));\n            sc[ib/4] |= (l << 3*(ib%4));\n            y[ibl].qh[ib] |= masks[shifts[ib]];\n            const float * xb = xbl + block_size*ib;\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K*ibl + block_size*ib;\n                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);\n            } else {\n                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];\n            }\n            for (int k = 0; k < block_size/8; ++k) {\n                if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m;\n                else xx = shifts[ib]%2 == 0 ? 
x_p : x_m;\n                const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700));\n                for (int j = 0; j < 8; ++j) {\n                    float w = weight[8*k + j];\n                    float q = xx[(pg[j] - 1)/2]*(2*l+1);\n                    sumqx_f += w*q*xb[8*k+j];\n                    sumq2_f += w*q*q;\n                }\n            }\n        }\n        if (sumq2_f > 0) d = sumqx_f/sumq2_f;\n        s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.\n        sc[0] |= ((s.u16 & 0x000f) << 12);\n        sc[1] |= ((s.u16 & 0x00f0) <<  8);\n        sc[2] |= ((s.u16 & 0x0f00) <<  4);\n        sc[3] |= ((s.u16 & 0xf000) <<  0);\n    }\n}\n\nsize_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK_K == 0);\n    float  scales[QK_K/IQ1M_BLOCK_SIZE];\n    float  weight[IQ1M_BLOCK_SIZE];\n    int8_t L[IQ1M_BLOCK_SIZE];\n    float  pairs[2*IQ1M_BLOCK_SIZE];\n    uint16_t index[IQ1M_BLOCK_SIZE/8];\n    int8_t shifts[QK_K/IQ1M_BLOCK_SIZE];\n    int64_t nblock = n_per_row/QK_K;\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts);\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq1_m);\n    }\n    return nrow * nblock * sizeof(block_iq1_m);\n}\n\n// ============================ 4-bit non-linear quants\n\nstatic void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,\n        ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,\n        float * scales, float * weight, uint8_t * L,\n        const int8_t * values,\n        const float * quant_weights,\n        const int ntry) {\n\n    float sigma2 = 0;\n    for 
(int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];\n    sigma2 *= 2.f/super_block_size;\n\n    memset(q4, 0, super_block_size/2);\n    dh[0] = GGML_FP32_TO_FP16(0.f);\n\n    float max_scale = 0, amax_scale = 0;\n    for (int ib = 0; ib < super_block_size/block_size; ++ib) {\n        const float * xb = x + ib*block_size;\n        uint8_t * Lb = L + ib*block_size;\n        if (quant_weights) {\n            const float * qw = quant_weights + ib*block_size;\n            for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);\n        } else {\n            for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];\n        }\n        float amax = 0, max = 0;\n        for (int j = 0; j < block_size; ++j) {\n            float ax = fabsf(xb[j]);\n            if (ax > amax) {\n                amax = ax; max = xb[j];\n            }\n        }\n        if (amax < GROUP_MAX_EPS) {\n            scales[ib] = 0;\n            continue;\n        }\n        float d = ntry > 0 ? -max/values[0] : max/values[0];\n        float id = 1/d;\n        float sumqx = 0, sumq2 = 0;\n        for (int j = 0; j < block_size; ++j) {\n            float al = id*xb[j];\n            int l = best_index_int8(16, values, al);\n            Lb[j] = l;\n            float q = values[l];\n            float w = weight[j];\n            sumqx += w*q*xb[j];\n            sumq2 += w*q*q;\n        }\n        d = sumq2 > 0 ? 
sumqx/sumq2 : 0.f;\n        float best = d*sumqx;\n        for (int itry = -ntry; itry <= ntry; ++itry) {\n            id = (itry + values[0])/max;\n            sumqx = sumq2 = 0;\n            for (int j = 0; j < block_size; ++j) {\n                float al = id*xb[j];\n                int l = best_index_int8(16, values, al);\n                float q = values[l];\n                float w = weight[j];\n                sumqx += w*q*xb[j];\n                sumq2 += w*q*q;\n            }\n            if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {\n                d = sumqx/sumq2; best = d * sumqx;\n            }\n        }\n        scales[ib] = d;\n        float abs_d = fabsf(d);\n        if (abs_d > amax_scale) {\n            amax_scale = abs_d; max_scale = d;\n        }\n    }\n\n    if (super_block_size/block_size > 1) {\n        int nb = super_block_size/block_size;\n        memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t));\n        float d = -max_scale/32;\n        dh[0] = GGML_FP32_TO_FP16(d);\n        float id = d ? 1/d : 0.f;\n        for (int ib = 0; ib < super_block_size/block_size; ++ib) {\n            int l = nearest_int(id*scales[ib]);\n            l = MAX(-32, MIN(31, l));\n            float dl = d * l;\n            float idl = dl ? 1/dl : 0.f;\n            uint8_t * Lb = L + ib*block_size;\n            const float * xb = x + ib*block_size;\n            for (int j = 0; j < block_size; ++j) {\n                Lb[j] = best_index_int8(16, values, idl*xb[j]);\n            }\n            l += 32;\n            uint8_t l_l = l & 0xf;\n            uint8_t l_h = l >>  4;\n            if (ib%2 == 0) scales_l[ib/2] = l_l;\n            else scales_l[ib/2] |= (l_l << 4);\n            scales_h[ib/8] |= (l_h << 2*(ib%8));\n        }\n    } else {\n        dh[0] = GGML_FP32_TO_FP16(scales[0]);\n        if (ntry > 0) {\n            float id = scales[0] ? 
1/scales[0] : 0;\n            for (int j = 0; j < super_block_size; ++j) {\n                L[j] = best_index_int8(16, values, id*x[j]);\n            }\n        }\n    }\n\n    for (int i = 0; i < super_block_size/32; ++i) {\n        for (int j = 0; j < 16; ++j) {\n            q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);\n        }\n    }\n}\n\nsize_t quantize_iq4_nl(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK4_NL == 0);\n    int64_t nblock = n_per_row/QK4_NL;\n    char * qrow = (char *)dst;\n    uint8_t L[QK4_NL];\n    float weight[QK4_NL];\n    uint16_t unused_h;\n    uint8_t * unused_l = NULL;\n    float scale;\n    for (int64_t row = 0; row < nrow; ++row) {\n        block_iq4_nl * iq4 = (block_iq4_nl *)qrow;\n        for (int ibl = 0; ibl < nblock; ++ibl) {\n            const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;\n            quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,\n                    &scale, weight, L, kvalues_iq4nl, qw, 7);\n        }\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq4_nl);\n    }\n    return nrow * nblock * sizeof(block_iq4_nl);\n}\n\n//void quantize_row_iq4_nl_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {\nvoid quantize_row_iq4_nl_ref(const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k) {\n    GGML_ASSERT(k%QK4_NL == 0);\n    int64_t nblock = k/QK4_NL;\n    uint8_t L[QK4_NL];\n    float weight[QK4_NL];\n    uint16_t unused_h;\n    uint8_t * unused_l = NULL;\n    float scale;\n    block_iq4_nl * iq4 = y;\n    for (int ibl = 0; ibl < nblock; ++ibl) {\n        quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,\n                &scale, weight, L, kvalues_iq4nl, NULL, -1);\n    }\n}\n\nsize_t quantize_iq4_xs(const float * 
GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK_K == 0);\n    int64_t nblock = n_per_row/QK_K;\n    char * qrow = (char *)dst;\n    uint8_t L[QK_K];\n    float weight[32];\n    float scales[QK_K/32];\n    for (int64_t row = 0; row < nrow; ++row) {\n        block_iq4_xs * iq4 = (block_iq4_xs *)qrow;\n        for (int ibl = 0; ibl < nblock; ++ibl) {\n            const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;\n            quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,\n                    scales, weight, L, kvalues_iq4nl, qw, 7);\n        }\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq4_xs);\n    }\n    return nrow * nblock * sizeof(block_iq4_xs);\n}\n\nvoid quantize_row_iq4_xs_ref(const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    quantize_iq4_xs(x, y, 1, k, NULL);\n}\n\n// =============================== 2.5625 bpw\n\nstatic void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights) {\n\n    const int gindex = iq2_data_index(GGML_TYPE_IQ2_S);\n\n    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;\n    const int      * kmap_q2xs       = iq2_data[gindex].map;\n    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;\n\n    GGML_ASSERT(kmap_q2xs       && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kgrid_q2xs      && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(kneighbors_q2xs && \"forgot to call ggml_quantize_init()?\");\n    GGML_ASSERT(n%QK_K == 0);\n\n    const int kMaxQ = 3;\n\n    const int64_t nbl = n/QK_K;\n\n    block_iq2_s * y = vy;\n\n    float scales[QK_K/16];\n    float weight[16];\n    float xval[16];\n    int8_t L[16];\n    int8_t Laux[16];\n    float  waux[16];\n    
bool   is_on_grid[2];\n    bool   is_on_grid_aux[2];\n    uint8_t block_signs[2];\n\n    for (int ibl = 0; ibl < nbl; ++ibl) {\n\n        memset(&y[ibl], 0, sizeof(block_iq2_s));\n        y[ibl].d = GGML_FP32_TO_FP16(0.f);\n\n        float max_scale = 0;\n\n        const float * xbl = x + QK_K*ibl;\n        float sumx2 = 0;\n        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];\n        float sigma2 = 2*sumx2/QK_K;\n\n        for (int ib = 0; ib < QK_K/16; ++ib) {\n            const float * xb = xbl + 16*ib;\n            if (quant_weights) {\n                const float * qw = quant_weights + QK_K*ibl + 16*ib;\n                for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);\n            } else {\n                for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i];\n            }\n            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);\n            for (int k = 0; k < 2; ++k) {\n                uint8_t s = 0;\n                for (int i = 0; i < 8; ++i) {\n                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];\n                    else {\n                        xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);\n                    }\n                }\n                block_signs[k] = s;\n            }\n            float max = xval[0];\n            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);\n            memset(L, 0, 16);\n            if (max < GROUP_MAX_EPS_IQ2_S) {\n                scales[ib] = 0;\n                continue;\n            }\n            float best = 0;\n            float scale = max/(2*kMaxQ-1);\n            is_on_grid[0] = is_on_grid[1] = true;\n            for (int is = -9; is <= 9; ++is) {\n                float id = (2*kMaxQ-1+is*0.1f)/max;\n                float this_scale = 1/id;\n                for (int k = 0; k < 2; ++k) {\n                    for (int i = 0; i < 8; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));\n         
               Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));\n                    }\n                    uint16_t u = 0;\n                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);\n                    int grid_index = kmap_q2xs[u];\n                    is_on_grid_aux[k] = true;\n                    if (grid_index < 0) {\n                        is_on_grid_aux[k] = false;\n                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;\n                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);\n                    }\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < 16; ++i) {\n                    float w = weight[i];\n                    float q = 2*Laux[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {\n                    scale = sumqx/sumq2; best = scale*sumqx;\n                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];\n                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];\n                }\n            }\n            int n_not_ongrid = 0;\n            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;\n            if (n_not_ongrid > 0 && scale > 0) {\n                float id = 1/scale;\n                for (int k = 0; k < 2; ++k) {\n                    if (is_on_grid[k]) continue;\n                    uint16_t u = 0;\n                    for (int i = 0; i < 8; ++i) {\n                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));\n                        l = MAX(0, MIN(kMaxQ-1, l));\n                        u |= (l << 2*i);\n                        L[8*k + i] = l;\n                    }\n                    int grid_index = kmap_q2xs[u];\n                    if (grid_index < 0) {\n                        const uint16_t * 
neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;\n                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);\n                    }\n                }\n                float sumqx = 0, sumq2 = 0;\n                for (int i = 0; i < 16; ++i) {\n                    float w = weight[i];\n                    float q = 2*L[i] + 1;\n                    sumqx += w*xval[i]*q;\n                    sumq2 += w*q*q;\n                }\n                if (sumq2 > 0) scale = sumqx/sumq2;\n            }\n            if (scale < 0) {\n                scale = -scale;\n                for (int k = 0; k < 2; ++k) block_signs[k] = ~block_signs[k];\n            }\n            for (int k = 0; k < 2; ++k) {\n                uint16_t u = 0;\n                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);\n                int grid_index = kmap_q2xs[u];\n                if (grid_index < 0) {\n                    printf(\"Oops: found point %u not on grid:\", u);\n                    for (int i = 0; i < 8; ++i) printf(\" %d\", L[8*k+i]);\n                    printf(\"\\n\");\n                    GGML_ABORT(\"fatal error\");\n                }\n                const int i8 = 2*ib + k;\n                y[ibl].qs[i8] = grid_index & 255;\n                y[ibl].qh[i8/4] |= ((grid_index >> 8) << 2*(i8%4));\n                y[ibl].qs[QK_K/8 + i8] = block_signs[k];\n            }\n            GGML_ASSERT(scale >= 0);\n            scales[ib] = scale;\n            max_scale = MAX(max_scale, scale);\n        }\n\n        if (!max_scale) {\n            continue;\n        }\n\n        float d = max_scale/31;\n        y[ibl].d = GGML_FP32_TO_FP16(d * 0.9875f);\n        float id = 1/d;\n        for (int ib = 0; ib < QK_K/16; ++ib) {\n            int l = nearest_int(0.5f*(id*scales[ib]-1));\n            l = MAX(0, MIN(15, l));\n            if (ib%2 == 0) y[ibl].scales[ib/2] = l;\n            else y[ibl].scales[ib/2] |= (l << 4);\n      
  }\n    }\n}\n\nsize_t quantize_iq2_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {\n    GGML_ASSERT(n_per_row%QK_K == 0);\n    int64_t nblock = n_per_row/QK_K;\n    char * qrow = (char *)dst;\n    for (int64_t row = 0; row < nrow; ++row) {\n        quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights);\n        src += n_per_row;\n        qrow += nblock*sizeof(block_iq2_s);\n    }\n    return nrow * nblock * sizeof(block_iq2_s);\n}\n\nvoid quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k) {\n    assert(k % QK_K == 0);\n    quantize_iq2_s(x, y, 1, k, NULL);\n}\n\n// =============================== data validation\n\nstatic bool validate_float(float f, size_t i) {\n    if (isinf(f)) {\n        fprintf(stderr, \"ggml_validate_row_data: found inf value at block %zu\\n\", i);\n        return false;\n    }\n\n    if (isnan(f)) {\n        fprintf(stderr, \"ggml_validate_row_data: found nan value at block %zu\\n\", i);\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool isinf_fp16(ggml_fp16_t f) {\n    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;\n}\n\nstatic bool isnan_fp16(ggml_fp16_t f) {\n    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;\n}\n\nstatic bool validate_fp16(ggml_fp16_t f, size_t i) {\n    if (isinf_fp16(f)) {\n        fprintf(stderr, \"ggml_validate_row_data: found inf value at block %zu\\n\", i);\n        return false;\n    }\n\n    if (isnan_fp16(f)) {\n        fprintf(stderr, \"ggml_validate_row_data: found nan value at block %zu\\n\", i);\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool validate_e_e8m0(uint8_t e, size_t i) {\n    if (e == 0xff) {\n        fprintf(stderr, \"ggml_validate_row_data: found invalid e value %d at block %zu\\n\", e, i);\n        return false;\n    }\n\n    return true;\n}\n\n#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \\\n    const type * q = 
(const type *) (data); \\\n    for (size_t i = 0; i < (nb); ++i) { \\\n        if (!validate_fp16(q[i].d, i)) { \\\n            return false; \\\n        } \\\n    }\n\n#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \\\n    const type * q = (const type *) (data); \\\n    for (size_t i = 0; i < (nb); ++i) { \\\n        if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \\\n            return false; \\\n        } \\\n    }\n\n#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \\\n    const type * q = (const type *) (data); \\\n    for (size_t i = 0; i < (nb); ++i) { \\\n        if (!validate_e_e8m0(q[i].e, i)) { \\\n            return false; \\\n        } \\\n    }\n\n#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \\\n    const type * q = (const type *) (data); \\\n    for (size_t i = 0; i < (nb); ++i) { \\\n        for (size_t j = 0; j < (nr); ++j) { \\\n            if (!validate_fp16(q[i].d[j], i)) { \\\n                return false; \\\n            } \\\n        } \\\n    }\n\nbool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {\n    if (type < 0 || type >= GGML_TYPE_COUNT) {\n        fprintf(stderr, \"%s: invalid type %d\\n\", __func__, type);\n        return false;\n    }\n\n    if (nbytes % ggml_type_size(type) != 0) {\n        fprintf(stderr, \"%s: invalid size %zu for type %s (type size = %zu)\\n\", __func__, nbytes, ggml_type_name(type), ggml_type_size(type));\n        return false;\n    }\n\n    const size_t nb = nbytes/ggml_type_size(type);\n\n    switch (type) {\n        case GGML_TYPE_BF16:\n            {\n                int nans = 0;\n                int infs = 0;\n                const unsigned short * f = (const unsigned short *) data;\n                for (size_t i = 0; i < nb; ++i) {\n                    nans += (f[i] & 0x7fff) > 0x7f80;\n                    infs += (f[i] & 0x7fff) == 0x7f80;\n                }\n                if (nans) {\n                    fprintf(stderr, 
\"%s: found %d NaNs in row of %zu BF16 values\\n\", __func__, nans, nb);\n                    return false;\n                }\n                if (infs) {\n                    fprintf(stderr, \"%s: found %d infinities in row of %zu BF16 values\\n\", __func__, infs, nb);\n                    return false;\n                }\n            } break;\n        case GGML_TYPE_F16:\n            {\n                const ggml_fp16_t * f = (const ggml_fp16_t *) data;\n                size_t i = 0;\n#if defined(__AVX2__)\n                for (; i + 15 < nb; i += 16) {\n                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));\n                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));\n                    __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));\n                    int mask = _mm256_movemask_epi8(cmp);\n                    if (mask) {\n                        for (size_t j = 0; j < 16; ++j) {\n                            if (!validate_fp16(f[i + j], i + j)) {\n                                return false;\n                            }\n                        }\n                        GGML_UNREACHABLE();\n                    }\n                }\n#elif defined(__ARM_NEON)\n                for (; i + 7 < nb; i += 8) {\n                    uint16x8_t v = vld1q_u16(f + i);\n                    uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));\n                    uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));\n                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);\n                    if (mask) {\n                        for (size_t j = 0; j < 8; ++j) {\n                            if (!validate_fp16(f[i + j], i + j)) {\n                                return false;\n                            }\n                        }\n                        GGML_UNREACHABLE();\n                    }\n                }\n#endif\n                for (; i < nb; ++i) {\n  
                  if (!validate_fp16(f[i], i)) {\n                        return false;\n                    }\n                }\n            } break;\n        case GGML_TYPE_F32:\n            {\n                const float * f = (const float *) data;\n                size_t i = 0;\n#if defined(__AVX2__)\n                for (; i + 7 < nb; i += 8) {\n                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));\n                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));\n                    __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));\n                    int mask = _mm256_movemask_epi8(cmp);\n                    if (mask) {\n                        for (size_t j = 0; j < 8; ++j) {\n                            if (!validate_float(f[i + j], i + j)) {\n                                return false;\n                            }\n                        }\n                        GGML_UNREACHABLE();\n                    }\n                }\n#elif defined(__ARM_NEON)\n                for (; i + 3 < nb; i += 4) {\n                    uint32x4_t v = vld1q_u32((const uint32_t *)f + i);\n                    uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));\n                    uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));\n                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);\n                    if (mask) {\n                        for (size_t j = 0; j < 4; ++j) {\n                            if (!validate_float(f[i + j], i + j)) {\n                                return false;\n                            }\n                        }\n                        GGML_UNREACHABLE();\n                    }\n                }\n#endif\n                for (; i < nb; ++i) {\n                    if (!validate_float(f[i], i)) {\n                        return false;\n                    }\n                }\n            } break;\n        case 
GGML_TYPE_F64:\n            {\n                const double * f = (const double *) data;\n                for (size_t i = 0; i < nb; ++i) {\n                    if (!validate_float(f[i], i)) {\n                        return false;\n                    }\n                }\n            } break;\n        case GGML_TYPE_Q4_0:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);\n            } break;\n        case GGML_TYPE_Q4_1:\n            {\n                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);\n            } break;\n        case GGML_TYPE_Q5_0:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);\n            } break;\n        case GGML_TYPE_Q5_1:\n            {\n                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);\n            } break;\n        case GGML_TYPE_Q8_0:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);\n            } break;\n        case GGML_TYPE_MXFP4:\n            {\n                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);\n            } break;\n        case GGML_TYPE_NVFP4:\n            {\n                // UE4M3 scales are uint8_t — all byte values are valid\n                GGML_UNUSED(data);\n                GGML_UNUSED(nb);\n            } break;\n        case GGML_TYPE_Q2_K:\n            {\n                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);\n            } break;\n        case GGML_TYPE_Q3_K:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);\n            } break;\n        case GGML_TYPE_Q4_K:\n            {\n                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);\n            } break;\n        case GGML_TYPE_Q5_K:\n            {\n                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);\n            } break;\n        case GGML_TYPE_Q6_K:\n            {\n                
VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);\n            } break;\n        case GGML_TYPE_Q8_K:\n            {\n                const block_q8_K * q = (const block_q8_K *) data;\n                for (size_t i = 0; i < nb; ++i) {\n                    if (!validate_float(q[i].d, i)) {\n                        return false;\n                    }\n                }\n            } break;\n        case GGML_TYPE_TQ1_0:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);\n            } break;\n        case GGML_TYPE_TQ2_0:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);\n            } break;\n        case GGML_TYPE_IQ1_S:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);\n            } break;\n        case GGML_TYPE_IQ1_M:\n            {\n                const block_iq1_m * q = (const block_iq1_m *) data;\n                for (size_t i = 0; i < nb; ++i) {\n                    iq1m_scale_t scale;\n                    const uint16_t * sc = (const uint16_t *)q[i].scales;\n                    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n                    if (!validate_fp16(scale.f16, i)) {\n                        return false;\n                    }\n                }\n            } break;\n        case GGML_TYPE_IQ2_XXS:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);\n            } break;\n        case GGML_TYPE_IQ2_XS:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);\n            } break;\n        case GGML_TYPE_IQ2_S:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);\n            } break;\n        case GGML_TYPE_IQ3_XXS:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);\n            } break;\n\n        case GGML_TYPE_IQ3_S:\n            {\n                
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);\n            } break;\n        case GGML_TYPE_IQ4_XS:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);\n            } break;\n        case GGML_TYPE_IQ4_NL:\n            {\n                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);\n            } break;\n\n        case GGML_TYPE_I8:\n        case GGML_TYPE_I16:\n        case GGML_TYPE_I32:\n        case GGML_TYPE_I64:\n            // nothing to validate\n            break;\n        default:\n            {\n                fprintf(stderr, \"%s: invalid type %d\\n\", __func__, type);\n                return false;\n            }\n    }\n\n    return true;\n}\n"
  },
  {
    "path": "src/ggml-quants.h",
    "content": "#pragma once\n\n#define GGML_COMMON_DECL_C\n#include \"ggml-common.h\"\n\n#include \"ggml.h\"\n\n// GGML internal header\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n// NOTE: these functions are defined as GGML_API because they used by the CPU backend\n\n// Quantization\nGGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);\n\nGGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);\n\nGGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);\n\nGGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_tq2_0_ref(const 
float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k);\n\nGGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl  * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs  * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_iq3_s_ref  (const float * GGML_RESTRICT x, block_iq3_s   * GGML_RESTRICT y, int64_t k);\nGGML_API void quantize_row_iq2_s_ref  (const float * GGML_RESTRICT x, block_iq2_s   * GGML_RESTRICT y, int64_t k);\n\n// Dequantization\nGGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\n//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\n\nGGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\n\nGGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void 
dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\n\nGGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\n\nGGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_iq2_s  (const block_iq2_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_iq1_m  (const block_iq1_m   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_iq4_xs (const block_iq4_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\nGGML_API void dequantize_row_iq3_s  (const block_iq3_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);\n\n// Quantization utilizing an importance matrix (a.k.a. 
\"Activation aWare Quantization\")\nGGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_iq2_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_iq1_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_iq1_m  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_iq3_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\n\nGGML_API size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\n\nGGML_API size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * 
imatrix);\nGGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\n\nGGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\nGGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);\n\nGGML_API void iq2xs_init_impl(enum ggml_type type);\nGGML_API void iq2xs_free_impl(enum ggml_type type);\nGGML_API void iq3xs_init_impl(int grid_size);\nGGML_API void iq3xs_free_impl(int grid_size);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-rpc/CMakeLists.txt",
    "content": "message(STATUS \"Using RPC backend\")\n\nggml_add_backend_library(ggml-rpc\n                         ggml-rpc.cpp\n                        )\n\nif (WIN32)\n    target_link_libraries(ggml-rpc PRIVATE ws2_32)\nendif()\n"
  },
  {
    "path": "src/ggml-rpc/ggml-rpc.cpp",
    "content": "#include \"ggml-rpc.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-cpp.h\"\n\n#include <cinttypes>\n#include <string>\n#include <vector>\n#include <memory>\n#include <mutex>\n#include <unordered_map>\n#include <unordered_set>\n#ifdef _WIN32\n#  define WIN32_LEAN_AND_MEAN\n#  ifndef NOMINMAX\n#     define NOMINMAX\n#  endif\n#  include <windows.h>\n#  include <winsock2.h>\n#else\n#  include <arpa/inet.h>\n#  include <sys/socket.h>\n#  include <sys/types.h>\n#  include <netinet/in.h>\n#  include <netinet/tcp.h>\n#  include <netdb.h>\n#  include <unistd.h>\n#endif\n#include <cstring>\n#include <fstream>\n#include <filesystem>\n#include <algorithm>\n\nstatic const char * RPC_DEBUG = std::getenv(\"GGML_RPC_DEBUG\");\n\n#define LOG_DBG(...) \\\n    do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)\n\n\nnamespace fs = std::filesystem;\n\nstatic constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB\n\n#ifdef _WIN32\ntypedef SOCKET sockfd_t;\nusing ssize_t = __int64;\n#else\ntypedef int sockfd_t;\n#endif\n\n// cross-platform socket\nstruct socket_t {\n    sockfd_t fd;\n    socket_t(sockfd_t fd) : fd(fd) {}\n    ~socket_t() {\n        LOG_DBG(\"[%s] closing socket %d\\n\", __func__, this->fd);\n#ifdef _WIN32\n        closesocket(this->fd);\n#else\n        close(this->fd);\n#endif\n    }\n};\n\n// macro for nicer error messages on server crash\n#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT(\"Remote RPC server crashed or returned malformed response\")\n\n// all RPC structures must be packed\n#pragma pack(push, 1)\n// ggml_tensor is serialized into rpc_tensor\nstruct rpc_tensor {\n    uint64_t id;\n    uint32_t type;\n    uint64_t buffer;\n    uint32_t ne[GGML_MAX_DIMS];\n    uint32_t nb[GGML_MAX_DIMS];\n    uint32_t op;\n    int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];\n    int32_t  flags;\n    uint64_t src[GGML_MAX_SRC];\n    uint64_t view_src;\n    uint64_t view_offs;\n    
uint64_t data;\n    char name[GGML_MAX_NAME];\n\n    char padding[4];\n};\n\nstatic_assert(sizeof(rpc_tensor) % 8 == 0, \"rpc_tensor size must be multiple of 8\");\n\n// RPC commands\nenum rpc_cmd {\n    RPC_CMD_ALLOC_BUFFER = 0,\n    RPC_CMD_GET_ALIGNMENT,\n    RPC_CMD_GET_MAX_SIZE,\n    RPC_CMD_BUFFER_GET_BASE,\n    RPC_CMD_FREE_BUFFER,\n    RPC_CMD_BUFFER_CLEAR,\n    RPC_CMD_SET_TENSOR,\n    RPC_CMD_SET_TENSOR_HASH,\n    RPC_CMD_GET_TENSOR,\n    RPC_CMD_COPY_TENSOR,\n    RPC_CMD_GRAPH_COMPUTE,\n    RPC_CMD_GET_DEVICE_MEMORY,\n    RPC_CMD_INIT_TENSOR,\n    RPC_CMD_GET_ALLOC_SIZE,\n    RPC_CMD_HELLO,\n    RPC_CMD_DEVICE_COUNT,\n    RPC_CMD_GRAPH_RECOMPUTE,\n    RPC_CMD_COUNT,\n};\n\nstatic_assert(RPC_CMD_HELLO == 14, \"RPC_CMD_HELLO must be always 14\");\n\n// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold\nconst size_t HASH_THRESHOLD = 10 * 1024 * 1024;\n\nstruct rpc_msg_hello_rsp {\n    uint8_t major;\n    uint8_t minor;\n    uint8_t patch;\n};\n\nstruct rpc_msg_device_count_rsp {\n    uint32_t device_count;\n};\n\nstruct rpc_msg_get_alloc_size_req {\n    uint32_t   device;\n    rpc_tensor tensor;\n    rpc_tensor srcs[GGML_MAX_SRC];\n};\n\nstruct rpc_msg_get_alloc_size_rsp {\n    uint64_t alloc_size;\n};\n\nstruct rpc_msg_init_tensor_req {\n    rpc_tensor tensor;\n};\n\nstruct rpc_msg_alloc_buffer_req {\n    uint32_t device;\n    uint64_t size;\n};\n\nstruct rpc_msg_alloc_buffer_rsp {\n    uint64_t remote_ptr;\n    uint64_t remote_size;\n};\n\nstruct rpc_msg_get_alignment_req {\n    uint32_t device;\n};\n\nstruct rpc_msg_get_alignment_rsp {\n    uint64_t alignment;\n};\n\nstruct rpc_msg_get_max_size_req {\n    uint32_t device;\n};\n\nstruct rpc_msg_get_max_size_rsp {\n    uint64_t max_size;\n};\n\nstruct rpc_msg_buffer_get_base_req {\n    uint64_t remote_ptr;\n};\n\nstruct rpc_msg_buffer_get_base_rsp {\n    uint64_t base_ptr;\n};\n\nstruct rpc_msg_free_buffer_req {\n    uint64_t remote_ptr;\n};\n\nstruct rpc_msg_buffer_clear_req 
{\n    uint64_t remote_ptr;\n    uint8_t value;\n};\n\nstruct rpc_msg_set_tensor_hash_req {\n    rpc_tensor tensor;\n    uint64_t offset;\n    uint64_t hash;\n};\n\nstruct rpc_msg_set_tensor_hash_rsp {\n    uint8_t result;\n};\n\nstruct rpc_msg_get_tensor_req {\n    rpc_tensor tensor;\n    uint64_t offset;\n    uint64_t size;\n};\n\nstruct rpc_msg_copy_tensor_req {\n    rpc_tensor src;\n    rpc_tensor dst;\n};\n\nstruct rpc_msg_copy_tensor_rsp {\n    uint8_t result;\n};\n\nstruct rpc_msg_get_device_memory_req {\n    uint32_t device;\n};\n\nstruct rpc_msg_get_device_memory_rsp {\n    uint64_t free_mem;\n    uint64_t total_mem;\n};\n\nstruct rpc_msg_graph_recompute_req {\n    uint32_t device;\n};\n\n#pragma pack(pop)\n\n// RPC data structures\n\nstatic ggml_guid_t ggml_backend_rpc_guid() {\n    static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};\n    return &guid;\n}\n\nstruct ggml_backend_rpc_buffer_type_context {\n    std::string endpoint;\n    uint32_t    device;\n    std::string name;\n    size_t      alignment;\n    size_t      max_size;\n};\n\nstruct graph_cache {\n\n    bool is_cached(const ggml_cgraph * cgraph) {\n        if ((int)last_graph.size() != cgraph->n_nodes) {\n            return false;\n        }\n        for (int i = 0; i < cgraph->n_nodes; i++) {\n            if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {\n                return false;\n            }\n        }\n        return true;\n    }\n\n    void add(const ggml_cgraph * cgraph) {\n        last_graph.resize(cgraph->n_nodes);\n        for (int i = 0; i < cgraph->n_nodes; i++) {\n            memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));\n        }\n    }\n\n    std::vector<ggml_tensor> last_graph;\n};\n\nstruct ggml_backend_rpc_context {\n    std::string endpoint;\n    uint32_t    device;\n    std::string name;\n    graph_cache gc;\n};\n\nstruct ggml_backend_rpc_buffer_context {\n   
 std::shared_ptr<socket_t> sock;\n    void * base_ptr;\n    uint64_t remote_ptr;\n};\n\n// RPC helper functions\n\n// Computes FNV-1a hash of the data\nstatic uint64_t fnv_hash(const uint8_t * data, size_t len) {\n    const uint64_t fnv_prime = 0x100000001b3ULL;\n    uint64_t hash = 0xcbf29ce484222325ULL;\n\n    for (size_t i = 0; i < len; ++i) {\n        hash ^= data[i];\n        hash *= fnv_prime;\n    }\n    return hash;\n}\n\nstatic std::shared_ptr<socket_t> make_socket(sockfd_t fd) {\n#ifdef _WIN32\n    if (fd == INVALID_SOCKET) {\n        return nullptr;\n    }\n#else\n    if (fd < 0) {\n        return nullptr;\n    }\n#endif\n    return std::make_shared<socket_t>(fd);\n}\n\nstatic bool set_no_delay(sockfd_t sockfd) {\n    int flag = 1;\n    // set TCP_NODELAY to disable Nagle's algorithm\n    int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));\n    return ret == 0;\n}\n\nstatic bool set_reuse_addr(sockfd_t sockfd) {\n    int flag = 1;\n    int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));\n    return ret == 0;\n}\n\nstatic std::shared_ptr<socket_t> socket_connect(const char * host, int port) {\n    struct sockaddr_in addr;\n    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);\n    auto sock_ptr = make_socket(sockfd);\n    if (sock_ptr == nullptr) {\n        return nullptr;\n    }\n    if (!set_no_delay(sockfd)) {\n        GGML_LOG_ERROR(\"Failed to set TCP_NODELAY\\n\");\n        return nullptr;\n    }\n    addr.sin_family = AF_INET;\n    addr.sin_port = htons(port);\n    struct hostent * server = gethostbyname(host);\n    if (server == NULL) {\n        GGML_LOG_ERROR(\"Cannot resolve host '%s'\\n\", host);\n        return nullptr;\n    }\n    memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);\n    if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {\n        return nullptr;\n    }\n    return sock_ptr;\n}\n\nstatic std::shared_ptr<socket_t> 
socket_accept(sockfd_t srv_sockfd) {\n    auto client_socket_fd = accept(srv_sockfd, NULL, NULL);\n    auto client_socket = make_socket(client_socket_fd);\n    if (client_socket == nullptr) {\n        return nullptr;\n    }\n    if (!set_no_delay(client_socket_fd)) {\n        GGML_LOG_ERROR(\"Failed to set TCP_NODELAY\\n\");\n        return nullptr;\n    }\n    return client_socket;\n}\n\nstatic std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {\n    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);\n    auto sock = make_socket(sockfd);\n    if (sock == nullptr) {\n        return nullptr;\n    }\n    if (!set_reuse_addr(sockfd)) {\n        GGML_LOG_ERROR(\"Failed to set SO_REUSEADDR\\n\");\n        return nullptr;\n    }\n    if (inet_addr(host) == INADDR_NONE) {\n        GGML_LOG_ERROR(\"Invalid host address: %s\\n\", host);\n        return nullptr;\n    }\n    struct sockaddr_in serv_addr;\n    serv_addr.sin_family = AF_INET;\n    serv_addr.sin_addr.s_addr = inet_addr(host);\n    serv_addr.sin_port = htons(port);\n\n    if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {\n        return nullptr;\n    }\n    if (listen(sockfd, 1) < 0) {\n        return nullptr;\n    }\n    return sock;\n}\n\nstatic bool send_data(sockfd_t sockfd, const void * data, size_t size) {\n    size_t bytes_sent = 0;\n    while (bytes_sent < size) {\n        size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);\n        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0);\n        if (n < 0) {\n            GGML_LOG_ERROR(\"send failed (bytes_sent=%zu, size_to_send=%zu)\\n\",\n                           bytes_sent, size_to_send);\n            return false;\n        }\n        bytes_sent += (size_t)n;\n    }\n    return true;\n}\n\nstatic bool recv_data(sockfd_t sockfd, void * data, size_t size) {\n    size_t bytes_recv = 0;\n    while (bytes_recv < size) {\n        size_t size_to_recv = std::min(size - 
bytes_recv, MAX_CHUNK_SIZE);\n        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0);\n        if (n < 0) {\n            GGML_LOG_ERROR(\"recv failed (bytes_recv=%zu, size_to_recv=%zu)\\n\",\n                           bytes_recv, size_to_recv);\n            return false;\n        }\n        if (n == 0) {\n            LOG_DBG(\"recv returned 0 (peer closed?)\\n\");\n            return false;\n        }\n        bytes_recv += (size_t)n;\n    }\n    return true;\n}\n\nstatic bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {\n    if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {\n        return false;\n    }\n    return send_data(sockfd, msg, msg_size);\n}\n\nstatic bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {\n    uint64_t size;\n    if (!recv_data(sockfd, &size, sizeof(size))) {\n        return false;\n    }\n    if (size != msg_size) {\n        return false;\n    }\n    return recv_data(sockfd, msg, msg_size);\n}\n\nstatic bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {\n    uint64_t size;\n    if (!recv_data(sockfd, &size, sizeof(size))) {\n        return false;\n    }\n    try {\n        input.resize(size);\n    } catch (const std::bad_alloc & e) {\n        GGML_LOG_ERROR(\"Failed to allocate input buffer of size %\" PRIu64 \"\\n\", size);\n        return false;\n    }\n    return recv_data(sockfd, input.data(), size);\n}\n\nstatic bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {\n    size_t pos = endpoint.find(':');\n    if (pos == std::string::npos) {\n        return false;\n    }\n    host = endpoint.substr(0, pos);\n    port = std::stoi(endpoint.substr(pos + 1));\n    return true;\n}\n\n// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |\n// No response\nstatic bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {\n    uint8_t cmd_byte = 
cmd;\n    if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {\n        return false;\n    }\n    if (!send_data(sock->fd, &input_size, sizeof(input_size))) {\n        return false;\n    }\n    if (!send_data(sock->fd, input, input_size)) {\n        return false;\n    }\n    return true;\n}\n\n// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |\n// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |\nstatic bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {\n    if (!send_rpc_cmd(sock, cmd, input, input_size)) {\n        return false;\n    }\n    // TODO: currently the output_size is always known, do we need support for commands with variable output size?\n    // even if we do, we can skip sending output_size from the server for commands with known output size\n    uint64_t out_size;\n    if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {\n        return false;\n    }\n    if (out_size != output_size) {\n        return false;\n    }\n    if (!recv_data(sock->fd, output, output_size)) {\n        return false;\n    }\n    return true;\n}\n\n// RPC client-side implementation\n\nstatic bool check_server_version(const std::shared_ptr<socket_t> & sock) {\n    rpc_msg_hello_rsp response;\n    bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));\n    RPC_STATUS_ASSERT(status);\n    if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {\n        GGML_LOG_ERROR(\"RPC server version mismatch: %d.%d.%d\\n\", response.major, response.minor, response.patch);\n        return false;\n    }\n    if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {\n        GGML_LOG_INFO(\"WARNING: RPC server version mismatch: %d.%d.%d\\n\", response.major, response.minor, response.patch);\n    }\n    
return true;\n}\n\nstatic std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;\n    static bool initialized = false;\n\n    auto it = sockets.find(endpoint);\n    if (it != sockets.end()) {\n        if (auto sock = it->second.lock()) {\n            return sock;\n        }\n    }\n    std::string host;\n    int port;\n    if (!parse_endpoint(endpoint, host, port)) {\n        GGML_LOG_ERROR(\"Failed to parse endpoint: %s\\n\", endpoint.c_str());\n        return nullptr;\n    }\n#ifdef _WIN32\n    if (!initialized) {\n        WSADATA wsaData;\n        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);\n        if (res != 0) {\n            return nullptr;\n        }\n        initialized = true;\n    }\n#else\n    GGML_UNUSED(initialized);\n#endif\n    auto sock = socket_connect(host.c_str(), port);\n    if (sock == nullptr) {\n        return nullptr;\n    }\n    if (!check_server_version(sock)) {\n        return nullptr;\n    }\n    LOG_DBG(\"[%s] connected to %s, sockfd=%d\\n\", __func__, endpoint.c_str(), sock->fd);\n    sockets[endpoint] = sock;\n    return sock;\n}\n\nstatic void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;\n    rpc_msg_free_buffer_req request = {ctx->remote_ptr};\n    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);\n    RPC_STATUS_ASSERT(status);\n    delete ctx;\n}\n\nstatic void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {\n    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;\n    if (ctx->base_ptr != nullptr) {\n        return ctx->base_ptr;\n    }\n    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};\n    rpc_msg_buffer_get_base_rsp 
response;\n    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));\n    RPC_STATUS_ASSERT(status);\n    ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);\n    return ctx->base_ptr;\n}\n\nstatic bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {\n    return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;\n}\n\nstatic rpc_tensor serialize_tensor(const ggml_tensor * tensor) {\n    rpc_tensor result;\n    if (!tensor) {\n        memset(&result, 0, sizeof(result));\n        return result;\n    }\n\n    result.id = reinterpret_cast<uint64_t>(tensor);\n    result.type = tensor->type;\n    if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {\n        ggml_backend_buffer_t buffer = tensor->buffer;\n        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;\n        result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;\n    } else {\n        result.buffer = 0;\n    }\n    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {\n        result.ne[i] = tensor->ne[i];\n        result.nb[i] = tensor->nb[i];\n    }\n    result.op = tensor->op;\n    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {\n        result.op_params[i] = tensor->op_params[i];\n    }\n    result.flags = tensor->flags;\n    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {\n        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);\n    }\n    result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);\n    result.view_offs = tensor->view_offs;\n    result.data = reinterpret_cast<uint64_t>(tensor->data);\n\n    // Avoid sending uninitialized data over the wire\n    memset(result.name, 0, sizeof(result.name));\n    memset(result.padding, 0, sizeof(result.padding));\n\n    snprintf(result.name, GGML_MAX_NAME, \"%s\", tensor->name);\n    return result;\n}\n\nstatic enum ggml_status 
ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;\n\n    // CUDA backend on the server pads everything to 512 due to CUDA limitations.\n    // Due to bandwidth constraints, we only call the server init tensor functions if necessary.\n    // In particular, only quantized tensors need padding\n    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {\n        rpc_msg_init_tensor_req request;\n\n        request.tensor = serialize_tensor(tensor);\n\n        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);\n        RPC_STATUS_ASSERT(status);\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;\n    rpc_tensor rpc_tensor = serialize_tensor(tensor);\n    if (size > HASH_THRESHOLD) {\n        rpc_msg_set_tensor_hash_req request;\n        request.tensor = rpc_tensor;\n        request.offset = offset;\n        request.hash = fnv_hash((const uint8_t*)data, size);\n        rpc_msg_set_tensor_hash_rsp response;\n        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));\n        RPC_STATUS_ASSERT(status);\n        if (response.result) {\n            // the server has the same data, no need to send it\n            return;\n        }\n    }\n    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)\n    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;\n    std::vector<uint8_t> input(input_size, 0);\n    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));\n    memcpy(input.data() + sizeof(rpc_tensor), 
&offset, sizeof(offset));\n    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);\n    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());\n    RPC_STATUS_ASSERT(status);\n}\n\nstatic void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;\n    rpc_msg_get_tensor_req request;\n    request.tensor = serialize_tensor(tensor);\n    request.offset = offset;\n    request.size = size;\n    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);\n    RPC_STATUS_ASSERT(status);\n}\n\nstatic bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {\n    if (ggml_backend_buffer_is_rpc(src->buffer)) {\n        // check if src and dst are on the same server\n        ggml_backend_buffer_t src_buffer = src->buffer;\n        ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;\n        ggml_backend_buffer_t dst_buffer = dst->buffer;\n        ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;\n        if (src_ctx->sock != dst_ctx->sock) {\n            return false;\n        }\n        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;\n        rpc_msg_copy_tensor_req request;\n        request.src = serialize_tensor(src);\n        request.dst = serialize_tensor(dst);\n        rpc_msg_copy_tensor_rsp response;\n        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));\n        RPC_STATUS_ASSERT(status);\n        return response.result;\n    }\n    return false;\n}\n\nstatic void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) 
{\n    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;\n    rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};\n    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);\n    RPC_STATUS_ASSERT(status);\n}\n\nstatic ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_rpc_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_rpc_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_rpc_buffer_init_tensor,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ ggml_backend_rpc_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_rpc_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_rpc_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_rpc_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\nstatic const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;\n    return buft_ctx->name.c_str();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;\n    rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};\n    rpc_msg_alloc_buffer_rsp response;\n    auto sock = get_socket(buft_ctx->endpoint);\n    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));\n    RPC_STATUS_ASSERT(status);\n    if (response.remote_ptr != 0) {\n        ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,\n            ggml_backend_rpc_buffer_interface,\n            new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},\n            response.remote_size);\n        return buffer;\n    } else 
{\n        return nullptr;\n    }\n}\n\nstatic size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {\n    rpc_msg_get_alignment_req request = {device};\n    rpc_msg_get_alignment_rsp response;\n    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));\n    RPC_STATUS_ASSERT(status);\n    return response.alignment;\n}\n\nstatic size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;\n    return buft_ctx->alignment;\n}\n\nstatic size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {\n    rpc_msg_get_max_size_req request = {device};\n    rpc_msg_get_max_size_rsp response;\n    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));\n    RPC_STATUS_ASSERT(status);\n    return response.max_size;\n}\n\nstatic size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {\n    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;\n    return buft_ctx->max_size;\n}\n\nstatic size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    // should we query the remote server for the actual size\n    bool rpc_get = false;\n\n    // See comments in init_tensor.\n    rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);\n\n    // ops that require additional memory for fleeting data on certain backends\n    // ref: https://github.com/ggml-org/llama.cpp/pull/15966\n    rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;\n    rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;\n\n    if (rpc_get) {\n        ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;\n        auto 
sock = get_socket(buft_ctx->endpoint);\n\n        rpc_msg_get_alloc_size_req request = {\n            /*.device =*/ buft_ctx->device,\n            /*.tensor =*/ serialize_tensor(tensor),\n            /*.srcs   =*/ {},\n        };\n\n        // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            request.srcs[i] = serialize_tensor(tensor->src[i]);\n        }\n\n        // TODO: cache the alloc responses to avoid extra RPC calls?\n        rpc_msg_get_alloc_size_rsp response;\n        bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));\n        RPC_STATUS_ASSERT(status);\n\n        return response.alloc_size;\n    }\n\n    return ggml_nbytes(tensor);\n}\n\nstatic ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_rpc_buffer_type_name,\n    /* .alloc_buffer     = */ ggml_backend_rpc_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_rpc_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_rpc_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_rpc_buffer_type_get_alloc_size,\n    /* .is_host          = */ NULL,\n};\n\nstatic const char * ggml_backend_rpc_name(ggml_backend_t backend) {\n    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;\n\n    return rpc_ctx->name.c_str();\n}\n\nstatic void ggml_backend_rpc_free(ggml_backend_t backend) {\n    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;\n    delete rpc_ctx;\n    delete backend;\n}\n\nstatic void ggml_backend_rpc_synchronize(ggml_backend_t backend) {\n    GGML_UNUSED(backend);\n    // this is no-op because we don't have any async operations\n}\n\nstatic void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {\n    if 
(tensor == nullptr) {\n        return;\n    }\n    if (visited.find(tensor) != visited.end()) {\n        return;\n    }\n    visited.insert(tensor);\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        add_tensor(tensor->src[i], tensors, visited);\n    }\n    add_tensor(tensor->view_src, tensors, visited);\n    tensors.push_back(serialize_tensor(tensor));\n}\n\nstatic void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {\n    uint32_t n_nodes = cgraph->n_nodes;\n    std::vector<rpc_tensor> tensors;\n    std::unordered_set<ggml_tensor*> visited;\n    for (uint32_t i = 0; i < n_nodes; i++) {\n        add_tensor(cgraph->nodes[i], tensors, visited);\n    }\n    // serialization format:\n    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |\n    uint32_t n_tensors = tensors.size();\n    int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);\n    output.resize(output_size, 0);\n    uint8_t * dest = output.data();\n    memcpy(dest, &device, sizeof(device));\n    dest += sizeof(device);\n    memcpy(dest, &n_nodes, sizeof(n_nodes));\n    dest += sizeof(n_nodes);\n    for (uint32_t i = 0; i < n_nodes; i++) {\n        memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));\n    }\n    dest += n_nodes * sizeof(uint64_t);\n    memcpy(dest, &n_tensors, sizeof(n_tensors));\n    dest += sizeof(n_tensors);\n    rpc_tensor * out_tensors = (rpc_tensor *)dest;\n    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));\n}\n\nstatic enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;\n\n    GGML_ASSERT(cgraph->n_nodes > 0);\n    bool reuse = rpc_ctx->gc.is_cached(cgraph);\n    if (reuse) {\n        
rpc_msg_graph_recompute_req request;\n        request.device = rpc_ctx->device;\n        auto sock = get_socket(rpc_ctx->endpoint);\n        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));\n        RPC_STATUS_ASSERT(status);\n    } else {\n        rpc_ctx->gc.add(cgraph);\n        std::vector<uint8_t> input;\n        serialize_graph(rpc_ctx->device, cgraph, input);\n        auto sock = get_socket(rpc_ctx->endpoint);\n        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());\n        RPC_STATUS_ASSERT(status);\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic ggml_backend_i ggml_backend_rpc_interface = {\n    /* .get_name                = */ ggml_backend_rpc_name,\n    /* .free                    = */ ggml_backend_rpc_free,\n    /* .set_tensor_async        = */ NULL,\n    /* .get_tensor_async        = */ NULL,\n    /* .cpy_tensor_async        = */ NULL,\n    /* .synchronize             = */ ggml_backend_rpc_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_rpc_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ NULL,\n};\n\nggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n    std::string buft_name = \"RPC\" + std::to_string(device) + \"[\" + std::string(endpoint) + \"]\";\n    // NOTE: buffer types are allocated and never freed; this is by design\n    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;\n    auto it = buft_map.find(buft_name);\n    if (it != buft_map.end()) {\n        return it->second;\n    }\n    auto sock = get_socket(endpoint);\n    if (sock == 
nullptr) {\n        GGML_LOG_ERROR(\"Failed to connect to %s\\n\", endpoint);\n        return nullptr;\n    }\n    size_t alignment = get_alignment(sock, device);\n    size_t max_size = get_max_size(sock, device);\n    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {\n        /* .endpoint  = */ endpoint,\n        /* .device    = */ device,\n        /* .name      = */ buft_name,\n        /* .alignment = */ alignment,\n        /* .max_size  = */ max_size\n    };\n    auto reg = ggml_backend_rpc_add_server(endpoint);\n    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {\n        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,\n        /* .device  = */ ggml_backend_reg_dev_get(reg, device),\n        /* .context = */ buft_ctx\n    };\n    buft_map[buft_name] = buft;\n    return buft;\n}\n\nggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {\n    std::string dev_name = \"RPC\" + std::to_string(device) + \"[\" + std::string(endpoint) + \"]\";\n    ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {\n        /* .endpoint = */ endpoint,\n        /* .device   = */ device,\n        /* .name     = */ dev_name,\n        /* .gc       = */ {},\n    };\n    auto reg = ggml_backend_rpc_add_server(endpoint);\n    ggml_backend_t backend = new ggml_backend {\n        /* .guid    = */ ggml_backend_rpc_guid(),\n        /* .iface   = */ ggml_backend_rpc_interface,\n        /* .device  = */ ggml_backend_reg_dev_get(reg, device),\n        /* .context = */ ctx\n    };\n    return backend;\n}\n\nbool ggml_backend_is_rpc(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());\n}\n\nstatic void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {\n    rpc_msg_get_device_memory_req request;\n    request.device = device;\n    rpc_msg_get_device_memory_rsp response;\n    bool 
status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));\n    RPC_STATUS_ASSERT(status);\n    *free = response.free_mem;\n    *total = response.total_mem;\n}\n\nvoid ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {\n    auto sock = get_socket(endpoint);\n    if (sock == nullptr) {\n        *free = 0;\n        *total = 0;\n        return;\n    }\n    get_device_memory(sock, device, free, total);\n}\n\n// RPC server-side implementation\n\nclass rpc_server {\npublic:\n    rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)\n        : backends(std::move(all_backends)), cache_dir(cache_dir) {\n        stored_graphs.resize(backends.size());\n    }\n    ~rpc_server();\n\n    void hello(rpc_msg_hello_rsp & response);\n    bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);\n    bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);\n    bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);\n    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);\n    bool free_buffer(const rpc_msg_free_buffer_req & request);\n    bool buffer_clear(const rpc_msg_buffer_clear_req & request);\n    bool set_tensor(const std::vector<uint8_t> & input);\n    bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);\n    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);\n    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);\n    bool graph_compute(const std::vector<uint8_t> & input);\n    bool graph_recompute(const rpc_msg_graph_recompute_req & request);\n    bool init_tensor(const rpc_msg_init_tensor_req & request);\n    bool 
get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);\n    bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);\n\n    struct stored_graph {\n        ggml_context_ptr ctx_ptr;\n        ggml_cgraph *    graph;\n    };\n\nprivate:\n    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);\n    ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);\n    ggml_tensor * create_node(uint64_t id,\n                              struct ggml_context * ctx,\n                              const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,\n                              std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);\n\n\n    std::vector<ggml_backend_t> backends;\n    const char * cache_dir;\n    std::unordered_set<ggml_backend_buffer_t> buffers;\n    // store the last computed graph for each backend\n    std::vector<stored_graph> stored_graphs;\n};\n\nvoid rpc_server::hello(rpc_msg_hello_rsp & response) {\n    response.major = RPC_PROTO_MAJOR_VERSION;\n    response.minor = RPC_PROTO_MINOR_VERSION;\n    response.patch = RPC_PROTO_PATCH_VERSION;\n    LOG_DBG(\"[%s] version: %d.%d.%d\\n\", __func__, response.major, response.minor, response.patch);\n}\n\nbool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {\n    uint32_t dev_id = request.device;\n    if (dev_id >= backends.size()) {\n        return false;\n    }\n    ggml_backend_buffer_type_t buft;\n    struct ggml_init_params params {\n        /*.mem_size   =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n\n    ggml_context_ptr ctx_ptr { ggml_init(params) };\n    GGML_ASSERT(ctx_ptr != nullptr);\n    ggml_context * ctx = ctx_ptr.get();\n\n    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);\n    if 
(tensor == nullptr) {\n        GGML_LOG_ERROR(\"Null tensor pointer passed to server get_alloc_size function.\\n\");\n        return false;\n    }\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (request.srcs[i].id != 0) {\n            tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);\n        }\n    }\n\n    LOG_DBG(\"[%s] device: %d, buffer: %p, data: %p\\n\", __func__, dev_id, (void*)tensor->buffer, tensor->data);\n    if (tensor->buffer == nullptr) {\n        //No buffer allocated.\n        buft = ggml_backend_get_default_buffer_type(backends[dev_id]);\n    } else {\n        buft = tensor->buffer->buft;\n    }\n\n    response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);\n\n    return true;\n}\n\nbool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {\n    uint32_t dev_id = request.device;\n    if (dev_id >= backends.size()) {\n        return false;\n    }\n    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);\n    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);\n    response.remote_ptr = 0;\n    response.remote_size = 0;\n    if (buffer != nullptr) {\n        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);\n        response.remote_size = buffer->size;\n        LOG_DBG(\"[%s] device: %d, size: %\" PRIu64 \" -> remote_ptr: %\" PRIx64 \", remote_size: %\" PRIu64 \"\\n\",\n            __func__, dev_id, request.size, response.remote_ptr, response.remote_size);\n        buffers.insert(buffer);\n    } else {\n        LOG_DBG(\"[%s] device: %d, size: %\" PRIu64 \" -> failed\\n\", __func__, dev_id, request.size);\n    }\n    return true;\n}\n\nbool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {\n    uint32_t dev_id = request.device;\n    if (dev_id >= backends.size()) {\n        return false;\n    }\n    ggml_backend_buffer_type_t buft 
= ggml_backend_get_default_buffer_type(backends[dev_id]);\n    size_t alignment = ggml_backend_buft_get_alignment(buft);\n    LOG_DBG(\"[%s] device: %d, alignment: %lu\\n\", __func__, dev_id, alignment);\n    response.alignment = alignment;\n    return true;\n}\n\nbool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {\n    uint32_t dev_id = request.device;\n    if (dev_id >= backends.size()) {\n        return false;\n    }\n    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);\n    size_t max_size = ggml_backend_buft_get_max_size(buft);\n    LOG_DBG(\"[%s] device: %d, max_size: %lu\\n\", __func__, dev_id, max_size);\n    response.max_size = max_size;\n    return true;\n}\n\nbool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {\n    LOG_DBG(\"[%s] remote_ptr: %\" PRIx64 \"\\n\", __func__, request.remote_ptr);\n    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);\n    if (buffers.find(buffer) == buffers.end()) {\n        GGML_LOG_ERROR(\"[%s] buffer not found\\n\", __func__);\n        return false;\n    }\n    void * base = ggml_backend_buffer_get_base(buffer);\n    response.base_ptr = reinterpret_cast<uint64_t>(base);\n    return true;\n}\n\nbool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {\n    LOG_DBG(\"[%s] remote_ptr: %\" PRIx64 \"\\n\", __func__, request.remote_ptr);\n    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);\n    if (buffers.find(buffer) == buffers.end()) {\n        GGML_LOG_ERROR(\"[%s] buffer not found\\n\", __func__);\n        return false;\n    }\n    ggml_backend_buffer_free(buffer);\n    buffers.erase(buffer);\n    return true;\n}\n\nbool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {\n    LOG_DBG(\"[%s] remote_ptr: %\" PRIx64 \", value: %u\\n\", __func__, 
request.remote_ptr, request.value);\n    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);\n    if (buffers.find(buffer) == buffers.end()) {\n        GGML_LOG_ERROR(\"[%s] buffer not found\\n\", __func__);\n        return false;\n    }\n    ggml_backend_buffer_clear(buffer, request.value);\n    return true;\n}\n\nggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {\n    // Validate tensor type before using it\n    if (tensor->type >= GGML_TYPE_COUNT) {\n        GGML_LOG_ERROR(\"[%s] invalid tensor type received: %u\\n\", __func__, tensor->type);\n        return nullptr;\n    }\n\n    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,\n        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);\n\n    // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type\n    if (result == nullptr) {\n        GGML_LOG_ERROR(\"[%s] ggml_new_tensor_4d failed for type %u\\\\n\", __func__, tensor->type);\n        return nullptr;\n    }\n\n    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {\n        result->nb[i] = tensor->nb[i];\n    }\n    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);\n    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {\n        result->buffer = nullptr;\n    }\n\n    if (result->buffer) {\n        // require that the tensor data does not go beyond the buffer end\n        uint64_t tensor_size = (uint64_t) ggml_nbytes(result);\n        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);\n        uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);\n        GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow\n        GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);\n    }\n\n    result->op = (ggml_op) tensor->op;\n  
  for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {\n        result->op_params[i] = tensor->op_params[i];\n    }\n    result->flags = tensor->flags;\n    result->data = reinterpret_cast<void *>(tensor->data);\n    ggml_set_name(result, tensor->name);\n    return result;\n}\n\n\nbool rpc_server::set_tensor(const std::vector<uint8_t> & input) {\n    // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |\n    if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {\n        return false;\n    }\n    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();\n    uint64_t offset;\n    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));\n    const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);\n\n    struct ggml_init_params params {\n        /*.mem_size   =*/ ggml_tensor_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n    ggml_context_ptr ctx_ptr { ggml_init(params) };\n    GGML_ASSERT(ctx_ptr != nullptr);\n    ggml_context * ctx = ctx_ptr.get();\n    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);\n    if (tensor == nullptr || tensor->buffer == nullptr) {\n        GGML_LOG_ERROR(\"[%s] error deserializing tensor\\n\", __func__);\n        return false;\n    }\n    LOG_DBG(\"[%s] buffer: %p, data: %p, offset: %\" PRIu64 \", size: %zu\\n\", __func__, (void*)tensor->buffer, tensor->data, offset, size);\n\n    // sanitize tensor->data\n    {\n        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);\n        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);\n\n        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {\n            GGML_LOG_ERROR(\"[%s] tensor data region (data=0x%\" PRIx64 \", offset=%\" PRIu64 \", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\\n\",\n                           __func__, in_tensor->data, offset, size, 
p0, p1);\n            return false;\n        }\n    }\n\n    const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);\n    if (cache_dir && size > HASH_THRESHOLD) {\n        uint64_t hash = fnv_hash((const uint8_t*)data, size);\n        char hash_str[17];\n        snprintf(hash_str, sizeof(hash_str), \"%016\" PRIx64, hash);\n        // save to cache_dir/hash_str\n        fs::path cache_file = fs::path(cache_dir) / hash_str;\n        std::ofstream ofs(cache_file, std::ios::binary);\n        ofs.write((const char *)data, size);\n        GGML_LOG_INFO(\"[%s] saved to '%s'\\n\", __func__, cache_file.c_str());\n    }\n    ggml_backend_tensor_set(tensor, data, offset, size);\n    return true;\n}\n\nbool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {\n    if (!cache_dir) {\n        return false;\n    }\n    char hash_str[17];\n    snprintf(hash_str, sizeof(hash_str), \"%016\" PRIx64, hash);\n    fs::path cache_file = fs::path(cache_dir) / hash_str;\n    std::error_code ec;\n    if (!fs::exists(cache_file, ec)) {\n        return false;\n    }\n    std::ifstream ifs(cache_file, std::ios::binary);\n    ifs.seekg(0, std::ios::end);\n    size_t size = ifs.tellg();\n    ifs.seekg(0, std::ios::beg);\n    data.resize(size);\n    ifs.read((char *)data.data(), size);\n    return true;\n}\n\nbool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)\n{\n    std::vector<uint8_t> cached_file;\n    if (!get_cached_file(request.hash, cached_file)) {\n        response.result = 0;\n        return true;\n    }\n    size_t size = cached_file.size();\n    struct ggml_init_params params {\n        /*.mem_size   =*/ ggml_tensor_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n    ggml_context_ptr ctx_ptr { ggml_init(params) };\n    GGML_ASSERT(ctx_ptr != nullptr);\n    ggml_context * ctx = ctx_ptr.get();\n    ggml_tensor * tensor = deserialize_tensor(ctx, 
&request.tensor);\n    if (tensor == nullptr || tensor->buffer == nullptr) {\n        GGML_LOG_ERROR(\"[%s] error deserializing tensor\\n\", __func__);\n        return false;\n    }\n    LOG_DBG(\"[%s] buffer: %p, data: %p, offset: %\" PRIu64 \", size: %zu, hash: %\" PRIx64 \"\\n\",\n            __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);\n\n    // sanitize tensor->data\n    {\n        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);\n        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);\n\n        if (request.tensor.data + request.offset < p0\n         || request.tensor.data + request.offset >= p1\n         || size > (p1 - request.tensor.data - request.offset)) {\n            GGML_LOG_ERROR(\"[%s] tensor data region (data=0x%\" PRIx64 \", offset=%\" PRIu64 \", size=%zu, hash=0x%\" PRIx64 \") out of buffer bounds [0x%zx, 0x%zx)\\n\",\n                           __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);\n            return false;\n        }\n    }\n    ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);\n    response.result = 1;\n    return true;\n}\n\nbool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {\n    struct ggml_init_params params {\n        /*.mem_size   =*/ ggml_tensor_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n    ggml_context_ptr ctx_ptr { ggml_init(params) };\n    GGML_ASSERT(ctx_ptr != nullptr);\n    ggml_context * ctx = ctx_ptr.get();\n    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);\n    if (tensor == nullptr) {\n        GGML_LOG_ERROR(\"Null tensor pointer passed to server init_tensor function.\\n\");\n        return false;\n    }\n    LOG_DBG(\"[%s] buffer: %p, data: %p\\n\", __func__, (void*)tensor->buffer, tensor->data);\n    // Call the backend's buffer_init_tensor function\n    ggml_backend_buffer_t buffer = 
tensor->buffer;\n    if (buffer && buffer->iface.init_tensor) {\n        buffer->iface.init_tensor(buffer, tensor);\n    } else {\n        GGML_LOG_ERROR(\"Null buffer for tensor passed to init_tensor function\\n\");\n    }\n\n    if (tensor->extra != nullptr) {\n        // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.\n        // Currently unimplemented.\n        GGML_LOG_ERROR(\"tensor->extra populated by the backend, this is currently unsupported.\\n\");\n        return false;\n    }\n\n    return true;\n}\n\nbool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {\n    struct ggml_init_params params {\n        /*.mem_size   =*/ ggml_tensor_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n    ggml_context_ptr ctx_ptr { ggml_init(params) };\n    GGML_ASSERT(ctx_ptr != nullptr);\n    ggml_context * ctx = ctx_ptr.get();\n    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);\n    if (tensor == nullptr || tensor->buffer == nullptr) {\n        GGML_LOG_ERROR(\"[%s] error deserializing tensor\\n\", __func__);\n        return false;\n    }\n    LOG_DBG(\"[%s] buffer: %p, data: %p, offset: %\" PRIu64 \", size: %\" PRIu64 \"\\n\", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);\n\n    // sanitize tensor->data\n    {\n        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);\n        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);\n\n        if (request.tensor.data + request.offset < p0 ||\n            request.tensor.data + request.offset >= p1 ||\n            request.size > (p1 - request.tensor.data - request.offset)) {\n                GGML_LOG_ERROR(\"[%s] requested tensor region (data=0x%\" PRIx64 \", offset=%\" PRIu64 \", size=%\" PRIu64 \") out of buffer bounds [0x%zx, 0x%zx)\\n\",\n                               __func__, 
request.tensor.data, request.offset, request.size, p0, p1);\n                return false;\n        }\n    }\n\n    response.resize(request.size, 0);\n    ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);\n    return true;\n}\n\nbool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {\n    struct ggml_init_params params {\n        /*.mem_size   =*/ 2*ggml_tensor_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n    ggml_context_ptr ctx_ptr { ggml_init(params) };\n    GGML_ASSERT(ctx_ptr != nullptr);\n    ggml_context * ctx = ctx_ptr.get();\n\n    ggml_tensor * src = deserialize_tensor(ctx, &request.src);\n    ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);\n    if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {\n        GGML_LOG_ERROR(\"[%s] error deserializing tensors\\n\", __func__);\n        return false;\n    }\n\n    uint64_t src_size   = (uint64_t) ggml_nbytes(src);\n    uint64_t dst_data   = (uint64_t) dst->data;\n    uint64_t dst_base   = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);\n    uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);\n\n    if (dst_data + src_size > dst_base + dst_buf_sz) {\n        GGML_LOG_ERROR(\"[%s] out-of-bounds write in rpc_server::copy_tensor:\\n\"\n                         \"    write range : [0x%\" PRIx64 \", 0x%\" PRIx64 \"]\\n\"\n                         \"    buffer base: [0x%\" PRIx64 \", 0x%\" PRIx64 \"]\\n\",\n                         __func__,\n                         dst_data,\n                         dst_data + src_size,\n                         dst_base,\n                         dst_base + dst_buf_sz);\n        return false;\n    }\n\n    LOG_DBG(\"[%s] src->buffer: %p, dst->buffer: %p\\n\",\n            __func__, (void*) src->buffer, (void*) dst->buffer);\n\n    response.result = 
ggml_backend_buffer_copy_tensor(src, dst);\n    return true;\n}\n\nggml_tensor * rpc_server::create_node(uint64_t id,\n                                      struct ggml_context * ctx,\n                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,\n                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {\n    if (tensor_map.find(id) != tensor_map.end()) {\n        return tensor_map[id];\n    }\n    // Safely find the tensor pointer\n    auto it_ptr = tensor_ptrs.find(id);\n    if (it_ptr == tensor_ptrs.end()) {\n        return nullptr;\n    }\n    const rpc_tensor * tensor = it_ptr->second;\n\n    struct ggml_tensor * result = deserialize_tensor(ctx, tensor);\n    if (result == nullptr) {\n        return nullptr;\n    }\n    tensor_map[id] = result;\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        // Check if the source ID is 0 before calling create_node recursively\n        if (tensor->src[i] == 0) {\n            result->src[i] = nullptr;\n        } else {\n            result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);\n            // If the recursive call failed for a non-zero ID, propagate the error\n            if (result->src[i] == nullptr) {\n                GGML_LOG_ERROR(\"[%s] failed to create source node %d (src_id=%\" PRIu64 \") for node id %\" PRIu64 \"\\n\",\n                               __func__, i, tensor->src[i], id);\n                // Must return nullptr to signal failure up the call stack\n                return nullptr;\n            }\n        }\n    }\n\n    // Handle view_src similarly\n    if (tensor->view_src == 0) {\n        result->view_src = nullptr;\n    } else {\n        result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);\n        // If the recursive call failed for a non-zero ID, propagate the error\n        if (result->view_src == nullptr) {\n            GGML_LOG_ERROR(\"[%s] failed 
to create view_src node (view_src_id=%\" PRIu64 \") for node id %\" PRIu64 \"\\n\",\n                           __func__, tensor->view_src, id);\n            // Must return nullptr to signal failure up the call stack\n            return nullptr;\n        }\n    }\n    result->view_offs = tensor->view_offs;\n    return result;\n}\n\nbool rpc_server::graph_compute(const std::vector<uint8_t> & input) {\n    // serialization format:\n    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |\n    if (input.size() < 2*sizeof(uint32_t)) {\n        return false;\n    }\n    const uint8_t * src = input.data();\n    uint32_t device;\n    memcpy(&device, src, sizeof(device));\n    src += sizeof(device);\n    if (device >= backends.size()) {\n        return false;\n    }\n    uint32_t n_nodes;\n    memcpy(&n_nodes, src, sizeof(n_nodes));\n    src += sizeof(n_nodes);\n    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {\n        return false;\n    }\n    const uint64_t * nodes = (const uint64_t *)src;\n    src += n_nodes*sizeof(uint64_t);\n    uint32_t n_tensors;\n    memcpy(&n_tensors, src, sizeof(n_tensors));\n    src += sizeof(n_tensors);\n    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {\n        return false;\n    }\n    const rpc_tensor * tensors = (const rpc_tensor *)src;\n    LOG_DBG(\"[%s] device: %u, n_nodes: %u, n_tensors: %u\\n\", __func__, device, n_nodes, n_tensors);\n\n    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);\n\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n    ggml_context_ptr ctx_ptr { ggml_init(params) };\n    GGML_ASSERT(ctx_ptr != nullptr);\n    ggml_context * ctx = ctx_ptr.get();\n    
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);\n    graph->n_nodes = n_nodes;\n    std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;\n    tensor_ptrs.reserve(n_tensors);\n    for (uint32_t i = 0; i < n_tensors; i++) {\n        tensor_ptrs.emplace(tensors[i].id, &tensors[i]);\n    }\n    std::unordered_map<uint64_t, ggml_tensor*> tensor_map;\n    tensor_map.reserve(n_nodes);\n    for (uint32_t i = 0; i < n_nodes; i++) {\n        int64_t id;\n        memcpy(&id, &nodes[i], sizeof(id));\n        graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);\n\n        // Check if create_node failed for a *non-zero* ID.\n        // If id was 0, create_node returning nullptr is expected.\n        // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.\n        if (graph->nodes[i] == nullptr && id != 0) {\n            GGML_LOG_ERROR(\"[%s] failed to create graph node %d (id=%\" PRId64 \")\\n\", __func__, i, id);\n            return false;\n        }\n    }\n    ggml_status status = ggml_backend_graph_compute(backends[device], graph);\n    GGML_ASSERT(status == GGML_STATUS_SUCCESS && \"Unsuccessful graph computations are not supported with RPC\");\n    stored_graphs[device].ctx_ptr.swap(ctx_ptr);\n    stored_graphs[device].graph = graph;\n    return true;\n}\n\nbool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {\n    uint32_t device = request.device;\n    if (device >= backends.size()) {\n        return false;\n    }\n    if (stored_graphs[device].graph == nullptr) {\n        return false;\n    }\n    ggml_cgraph * graph = stored_graphs[device].graph;\n    LOG_DBG(\"[%s] device: %u\\n\", __func__, device);\n    ggml_status status = ggml_backend_graph_compute(backends[device], graph);\n    GGML_ASSERT(status == GGML_STATUS_SUCCESS && \"Unsuccessful graph computations are not supported with RPC\");\n    return true;\n}\n\nbool rpc_server::get_device_memory(const 
rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {\n    uint32_t dev_id = request.device;\n    if (dev_id >= backends.size()) {\n        return false;\n    }\n    size_t free, total;\n    ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);\n    ggml_backend_dev_memory(dev, &free, &total);\n    response.free_mem = free;\n    response.total_mem = total;\n    LOG_DBG(\"[%s] device: %u, free_mem: %\" PRIu64 \", total_mem: %\" PRIu64 \"\\n\", __func__, dev_id, response.free_mem, response.total_mem);\n    return true;\n}\n\nrpc_server::~rpc_server() {\n    for (auto buffer : buffers) {\n        ggml_backend_buffer_free(buffer);\n    }\n}\n\nstatic void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,\n                             sockfd_t sockfd) {\n    rpc_server server(backends, cache_dir);\n    uint8_t cmd;\n    if (!recv_data(sockfd, &cmd, 1)) {\n        return;\n    }\n    // the first command sent by the client must be HELLO\n    if (cmd != RPC_CMD_HELLO) {\n        GGML_LOG_ERROR(\"Expected HELLO command, update client\\n\");\n        return;\n    }\n    if (!recv_msg(sockfd, nullptr, 0)) {\n        return;\n    }\n    rpc_msg_hello_rsp response;\n    server.hello(response);\n    if (!send_msg(sockfd, &response, sizeof(response))) {\n        return;\n    }\n    while (true) {\n        if (!recv_data(sockfd, &cmd, 1)) {\n            break;\n        }\n        if (cmd >= RPC_CMD_COUNT) {\n            // fail fast if the command is invalid\n            GGML_LOG_ERROR(\"Unknown command: %d\\n\", cmd);\n            break;\n        }\n        switch (cmd) {\n            case RPC_CMD_HELLO: {\n                // HELLO command is handled above\n                return;\n            }\n            case RPC_CMD_DEVICE_COUNT: {\n                if (!recv_msg(sockfd, nullptr, 0)) {\n                    return;\n                }\n                rpc_msg_device_count_rsp response;\n          
      response.device_count = backends.size();\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_ALLOC_BUFFER: {\n                rpc_msg_alloc_buffer_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                rpc_msg_alloc_buffer_rsp response;\n                if (!server.alloc_buffer(request, response)) {\n                    return;\n                }\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_GET_ALLOC_SIZE: {\n                rpc_msg_get_alloc_size_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                rpc_msg_get_alloc_size_rsp response;\n                if (!server.get_alloc_size(request, response)) {\n                    return;\n                }\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_GET_ALIGNMENT: {\n                rpc_msg_get_alignment_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                rpc_msg_get_alignment_rsp response;\n                if (!server.get_alignment(request, response)) {\n                    return;\n                }\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_GET_MAX_SIZE: {\n                rpc_msg_get_max_size_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    
return;\n                }\n                rpc_msg_get_max_size_rsp response;\n                if (!server.get_max_size(request, response)) {\n                    return;\n                }\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_BUFFER_GET_BASE: {\n                rpc_msg_buffer_get_base_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                rpc_msg_buffer_get_base_rsp response;\n                if (!server.buffer_get_base(request, response)) {\n                    return;\n                }\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_FREE_BUFFER: {\n                rpc_msg_free_buffer_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                if (!server.free_buffer(request)) {\n                    return;\n                }\n                if (!send_msg(sockfd, nullptr, 0)) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_BUFFER_CLEAR: {\n                rpc_msg_buffer_clear_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                if (!server.buffer_clear(request)) {\n                    return;\n                }\n                if (!send_msg(sockfd, nullptr, 0)) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_SET_TENSOR: {\n                std::vector<uint8_t> input;\n                if (!recv_msg(sockfd, input)) {\n                    return;\n                }\n                if 
(!server.set_tensor(input)) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_SET_TENSOR_HASH: {\n                rpc_msg_set_tensor_hash_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                rpc_msg_set_tensor_hash_rsp response;\n                if (!server.set_tensor_hash(request, response)) {\n                    return;\n                }\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_INIT_TENSOR: {\n                rpc_msg_init_tensor_req request;\n                if (!recv_msg(sockfd, &request,sizeof(request))) {\n                    return;\n                }\n                if (!server.init_tensor(request)) {\n                    return;\n                }\n                if (!send_msg(sockfd, nullptr, 0)) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_GET_TENSOR: {\n                rpc_msg_get_tensor_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                std::vector<uint8_t> response;\n                if (!server.get_tensor(request, response)) {\n                    return;\n                }\n                if (!send_msg(sockfd, response.data(), response.size())) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_COPY_TENSOR: {\n                rpc_msg_copy_tensor_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                rpc_msg_copy_tensor_rsp response;\n                if (!server.copy_tensor(request, response)) {\n                    return;\n           
     }\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_GRAPH_COMPUTE: {\n                std::vector<uint8_t> input;\n                if (!recv_msg(sockfd, input)) {\n                    return;\n                }\n                if (!server.graph_compute(input)) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_GRAPH_RECOMPUTE: {\n                rpc_msg_graph_recompute_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                if (!server.graph_recompute(request)) {\n                    return;\n                }\n                break;\n            }\n            case RPC_CMD_GET_DEVICE_MEMORY: {\n                rpc_msg_get_device_memory_req request;\n                if (!recv_msg(sockfd, &request, sizeof(request))) {\n                    return;\n                }\n                rpc_msg_get_device_memory_rsp response;\n                if (!server.get_device_memory(request, response)) {\n                    return;\n                }\n                if (!send_msg(sockfd, &response, sizeof(response))) {\n                    return;\n                }\n                break;\n            }\n            default: {\n                GGML_LOG_ERROR(\"Unknown command: %d\\n\", cmd);\n                return;\n            }\n        }\n    }\n}\n\nvoid ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,\n                                   size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {\n    if (n_devices == 0 || devices == nullptr) {\n        fprintf(stderr, \"Invalid arguments to ggml_backend_rpc_start_server\\n\");\n        return;\n    }\n    std::vector<ggml_backend_t> backends;\n    printf(\"Starting RPC server v%d.%d.%d\\n\",\n      
  RPC_PROTO_MAJOR_VERSION,\n        RPC_PROTO_MINOR_VERSION,\n        RPC_PROTO_PATCH_VERSION);\n    printf(\"  endpoint       : %s\\n\", endpoint);\n    printf(\"  local cache    : %s\\n\", cache_dir ? cache_dir : \"n/a\");\n    printf(\"Devices:\\n\");\n    for (size_t i = 0; i < n_devices; i++) {\n        auto dev = devices[i];\n        size_t free, total;\n        ggml_backend_dev_memory(dev, &free, &total);\n        printf(\"  %s: %s (%zu MiB, %zu MiB free)\\n\", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),\n               total / 1024 / 1024, free / 1024 / 1024);\n        auto backend = ggml_backend_dev_init(dev, nullptr);\n        if (!backend) {\n            fprintf(stderr, \"Failed to create backend for device %s\\n\", dev->iface.get_name(dev));\n            return;\n        }\n        backends.push_back(backend);\n        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;\n        if (reg) {\n            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, \"ggml_backend_set_n_threads\");\n            if (ggml_backend_set_n_threads_fn) {\n                ggml_backend_set_n_threads_fn(backend, n_threads);\n            }\n        }\n    }\n\n    std::string host;\n    int port;\n    if (!parse_endpoint(endpoint, host, port)) {\n        return;\n    }\n#ifdef _WIN32\n    {\n        WSADATA wsaData;\n        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);\n        if (res != 0) {\n            fprintf(stderr, \"WSAStartup failed: %d\\n\", res);\n            return;\n        }\n    }\n#endif\n    auto server_socket = create_server_socket(host.c_str(), port);\n    if (server_socket == nullptr) {\n        fprintf(stderr, \"Failed to create server socket\\n\");\n        return;\n    }\n    while (true) {\n        auto client_socket = socket_accept(server_socket->fd);\n        if (client_socket == nullptr) {\n            fprintf(stderr, \"Failed to accept client 
connection\\n\");\n            return;\n        }\n        printf(\"Accepted client connection\\n\");\n        fflush(stdout);\n        rpc_serve_client(backends, cache_dir, client_socket->fd);\n        printf(\"Client connection closed\\n\");\n        fflush(stdout);\n    }\n#ifdef _WIN32\n    WSACleanup();\n#endif\n    for (auto backend : backends) {\n        ggml_backend_free(backend);\n    }\n}\n\n// device interface\n\nstruct ggml_backend_rpc_device_context {\n    std::string endpoint;\n    uint32_t    device;\n    std::string name;\n    std::string description;\n};\n\nstatic const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {\n    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;\n\n    return ctx->name.c_str();\n}\n\nstatic const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {\n    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;\n\n    return ctx->description.c_str();\n}\n\nstatic void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;\n\n    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {\n    // TODO: obtain value from the server\n    return GGML_BACKEND_DEVICE_TYPE_GPU;\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_rpc_device_get_name(dev);\n    props->description = ggml_backend_rpc_device_get_description(dev);\n    props->type        = ggml_backend_rpc_device_get_type(dev);\n    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);\n    props->caps = {\n        /* .async                 = */ false,\n     
   /* .host_buffer           = */ false,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ false,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {\n    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;\n\n    return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);\n\n    GGML_UNUSED(params);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {\n    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;\n\n    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);\n\n    GGML_UNUSED(dev);\n}\n\nstatic bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {\n    GGML_UNUSED(dev);\n    GGML_UNUSED(op);\n    //TODO: call the remote backend and cache the results\n    return true;\n}\n\nstatic bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {\n        return false;\n    }\n    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;\n    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;\n    return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;\n}\n\nstatic const struct ggml_backend_device_i ggml_backend_rpc_device_i = {\n    /* .get_name             = */ ggml_backend_rpc_device_get_name,\n    /* .get_description      = */ ggml_backend_rpc_device_get_description,\n    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,\n    /* .get_type             = */ ggml_backend_rpc_device_get_type,\n    /* .get_props            = */ ggml_backend_rpc_device_get_props,\n    /* .init_backend         = */ 
ggml_backend_rpc_device_init,\n    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,\n    /* .buffer_from_host_ptr = */ NULL,\n    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,\n    /* .offload_op           = */ NULL,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ NULL,\n};\n\n// backend reg interface\n\nstruct ggml_backend_rpc_reg_context {\n    std::string                     name;\n    std::vector<ggml_backend_dev_t> devices;\n};\n\nstatic const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {\n    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;\n    return ctx ? ctx->name.c_str() : \"RPC\";\n}\n\nstatic size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {\n    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;\n    return ctx ? 
ctx->devices.size() : 0;\n}\n\nstatic ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;\n    if (ctx == nullptr) {\n        GGML_ABORT(\"The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead\");\n    } else {\n        GGML_ASSERT(index < ctx->devices.size());\n        return ctx->devices[index];\n    }\n}\n\nstatic void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    if (std::strcmp(name, \"ggml_backend_rpc_add_server\") == 0) {\n        return (void *)ggml_backend_rpc_add_server;\n    }\n    if (std::strcmp(name, \"ggml_backend_rpc_start_server\") == 0) {\n        return (void *)ggml_backend_rpc_start_server;\n    }\n    return NULL;\n\n    GGML_UNUSED(reg);\n}\n\nstatic const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {\n    /* .get_name         = */ ggml_backend_rpc_reg_get_name,\n    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_rpc_reg_get_device,\n    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,\n};\n\nggml_backend_reg_t ggml_backend_rpc_reg(void) {\n    static struct ggml_backend_reg ggml_backend_rpc_reg = {\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_rpc_reg_i,\n        /* .context     = */ NULL,\n    };\n\n    return &ggml_backend_rpc_reg;\n}\n\nstatic uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {\n    auto sock = get_socket(endpoint);\n    if (sock == nullptr) {\n        GGML_LOG_ERROR(\"Failed to connect to %s\\n\", endpoint);\n        return 0;\n    }\n    rpc_msg_device_count_rsp response;\n    bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));\n    RPC_STATUS_ASSERT(status);\n    return response.device_count;\n}\n\nstatic const 
ggml_backend_reg_i ggml_backend_rpc_reg_interface = {\n    /* .get_name          = */ ggml_backend_rpc_reg_get_name,\n    /* .get_device_count  = */ ggml_backend_rpc_reg_get_device_count,\n    /* .get_device        = */ ggml_backend_rpc_reg_get_device,\n    /* .get_proc_address  = */ ggml_backend_rpc_get_proc_address,\n};\n\nggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {\n    static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;\n    static std::mutex mutex;\n    static uint32_t dev_id = 0;\n    std::lock_guard<std::mutex> lock(mutex);\n    if (reg_map.find(endpoint) != reg_map.end()) {\n        return reg_map[endpoint];\n    }\n    uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);\n    if (dev_count == 0) {\n        return nullptr;\n    }\n    ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;\n    ctx->name = \"RPC[\" + std::string(endpoint) + \"]\";\n    for (uint32_t ind = 0; ind < dev_count; ind++) {\n        std::string dev_name = \"RPC\" + std::to_string(dev_id);\n        std::string dev_desc = std::string(endpoint);\n        ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {\n            /* .endpoint    = */ endpoint,\n            /* .device      = */ ind,\n            /* .name        = */ dev_name,\n            /* .description = */ dev_desc\n        };\n\n        ggml_backend_dev_t dev = new ggml_backend_device {\n            /* .iface   = */ ggml_backend_rpc_device_i,\n            /* .reg     = */ ggml_backend_rpc_reg(),\n            /* .context = */ dev_ctx,\n        };\n        ctx->devices.push_back(dev);\n        dev_id++;\n    }\n    ggml_backend_reg_t reg = new ggml_backend_reg {\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_rpc_reg_interface,\n        /* .context     = */ ctx\n    };\n    reg_map[endpoint] = reg;\n    return reg;\n}\n\n\nGGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)\n"
  },
  {
    "path": "src/ggml-sycl/CMakeLists.txt",
    "content": "message(STATUS  \"GGML_SYCL_TARGET=${GGML_SYCL_TARGET}\")\n\nif (NOT GGML_SYCL_TARGET MATCHES \"^(INTEL)$\")\n    message(FATAL_ERROR \"GGML_SYCL_TARGET: Invalid target, the supported options are [INTEL]\")\nendif()\n\ncheck_cxx_compiler_flag(\"-fsycl\" SUPPORTS_SYCL)\n\nif (DEFINED ENV{ONEAPI_ROOT})\n    message(STATUS \"Using oneAPI Release SYCL compiler (icpx).\")\nelseif(SUPPORTS_SYCL)\n    message(WARNING \"Using open-source SYCL compiler (clang++). Didn't detect ENV {ONEAPI_ROOT}.\n        If you expected the oneAPI Release compiler, please install oneAPI & source it, like:\n        source /opt/intel/oneapi/setvars.sh\")\nelse()\n    message(FATAL_ERROR \"C++ compiler lacks SYCL support.\")\nendif()\nmessage(STATUS \"SYCL found\")\n#todo: AOT\n\nggml_add_backend_library(ggml-sycl\n                         ggml-sycl.cpp\n                         ../../include/ggml-sycl.h\n                        )\n\nfile(GLOB   GGML_HEADERS_SYCL \"*.hpp\")\nfile(GLOB   GGML_SOURCES_SYCL \"*.cpp\")\nfile(GLOB   SRCS \"template-instances/fattn-tile*.cpp\")\nlist(APPEND GGML_SOURCES_SYCL ${SRCS})\nfile(GLOB   SRCS \"template-instances/fattn-vec*.cpp\")\nlist(APPEND GGML_SOURCES_SYCL ${SRCS})\n\ntarget_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})\n\nif (WIN32)\n    # To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory\n    if( ${CMAKE_GENERATOR} MATCHES \"Visual Studio\" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES \"Intel C\"))\n        set_target_properties(ggml-sycl PROPERTIES VS_PLATFORM_TOOLSET \"Intel C++ Compiler 2025\")\n        set(CMAKE_CXX_COMPILER \"icx\")\n        set(CMAKE_CXX_COMPILER_ID \"IntelLLVM\")\n    endif()\nendif()\n\nmacro(detect_and_find_package package_name)\n    set(test_source \"\n    cmake_minimum_required(VERSION ${CMAKE_VERSION})\n    project(check_package LANGUAGES CXX)\n    find_package(${package_name} QUIET)\n    \")\n\n    set(test_dir 
\"${CMAKE_CURRENT_BINARY_DIR}/check_package_${package_name}\")\n    file(WRITE \"${test_dir}/CMakeLists.txt\" \"${test_source}\")\n\n    set(cmake_args \"\")\n    if(CMAKE_GENERATOR)\n        list(APPEND cmake_args \"-G\" \"${CMAKE_GENERATOR}\")\n    endif()\n    if(CMAKE_GENERATOR_PLATFORM)\n        list(APPEND cmake_args \"-A\" \"${CMAKE_GENERATOR_PLATFORM}\")\n    endif()\n    if(CMAKE_GENERATOR_TOOLSET)\n        list(APPEND cmake_args \"-T\" \"${CMAKE_GENERATOR_TOOLSET}\")\n    endif()\n    if(CMAKE_CXX_COMPILER)\n        list(APPEND cmake_args \"-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}\")\n    endif()\n\n    execute_process(\n        COMMAND ${CMAKE_COMMAND} ${cmake_args} .\n        WORKING_DIRECTORY \"${test_dir}\"\n        RESULT_VARIABLE result\n        OUTPUT_QUIET\n        ERROR_QUIET\n    )\n\n    if(result EQUAL 0)\n        find_package(${package_name} ${ARGN})\n    else()\n        message(WARNING \"Detection of ${package_name} failed. The package might be broken or incompatible.\")\n        set(${package_name}_FOUND FALSE)\n    endif()\nendmacro()\n\ndetect_and_find_package(IntelSYCL)\nif (IntelSYCL_FOUND)\n    # Use oneAPI CMake when possible\n    target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX)\nelse()\n    # Fallback to the simplest way of enabling SYCL when using intel/llvm nightly for instance\n    target_compile_options(ggml-sycl PRIVATE \"-fsycl\")\n    target_link_options(ggml-sycl PRIVATE \"-fsycl\")\nendif()\n\ntarget_compile_options(ggml-sycl PRIVATE \"-Wno-narrowing\")\n\n# Link against oneDNN\nset(GGML_SYCL_DNNL 0)\nif(GGML_SYCL_DNN)\n    find_package(DNNL)\n    if(DNNL_FOUND)\n        if (NOT DEFINED DNNL_GPU_VENDOR)\n            # default to intel target\n            set(DNNL_GPU_VENDOR \"INTEL\")\n            if(NOT \"${GGML_SYCL_TARGET}\" STREQUAL \"INTEL\")\n                message(WARNING \"oneDNN builds bundled with oneapi release only support INTEL target\")\n            endif()\n        endif()\n\n        # Verify 
oneDNN was compiled for the same target as llama\n        if(\"${GGML_SYCL_TARGET}\" STREQUAL \"${DNNL_GPU_VENDOR}\")\n            target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)\n            set(GGML_SYCL_DNNL 1)\n            get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)\n            foreach(CONFIG ${CONFIGS})\n                get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})\n                message(STATUS \"Found oneDNN: ${DNNL_LIB}\")\n            endforeach()\n        else()\n            message(WARNING\n                \"oneDNN must be compiled for the same target as llama.cpp.\n                 llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.\n                 Disabling oneDNN support.\")\n        endif()\n    else()\n        message(STATUS \"oneDNN not found, disabling oneDNN support\")\n    endif()\nelse()\n    message(STATUS \"oneDNN support disabled by the user\")\nendif()\ntarget_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})\n\nif (GGML_SYCL_F16)\n    add_compile_definitions(GGML_SYCL_F16)\nendif()\n\nif (GGML_SYCL_TARGET STREQUAL \"INTEL\")\n    add_compile_definitions(GGML_SYCL_WARP_SIZE=16)\n    target_link_options(ggml-sycl PRIVATE  -Xs   -ze-intel-greater-than-4GB-buffer-required)\n\n    # Link against Intel oneMKL\n    if (CMAKE_CXX_COMPILER_ID STREQUAL \"Clang\")\n        set(SYCL_COMPILER ON)\n    endif()\n    find_package(MKL REQUIRED)\n    target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)\nelse()\n    # default for other target\n    message(FATAL_ERROR \"GGML_SYCL_TARGET is not supported\")\n    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)\nendif()\n\nif (GGML_SYCL_GRAPH)\n    message(STATUS \"find GGML_SYCL_GRAPH\")\n    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)\nendif()\n\nif (GGML_SYCL_DEVICE_ARCH)\n    target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})\n    
target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})\nendif()\n\n"
  },
  {
    "path": "src/ggml-sycl/add-id.cpp",
    "content": "#include <sycl/sycl.hpp>\n#include \"common.hpp\"\n#include \"add-id.hpp\"\n\nstatic void add_id_kernel(\n    const float* src0,\n    const float* src1,\n    const int32_t* src2,\n    float* dst,\n    int64_t ne0,\n    int64_t ne1,\n    size_t nb01,\n    size_t nb02,\n    size_t nb11,\n    size_t nb21,\n    sycl::nd_item<3> item_ct1) {\n  const int64_t i1 = item_ct1.get_group(2);\n  const int64_t i2 = item_ct1.get_group(1);\n\n  const int i11 =\n      *(const int32_t*)((const char*)src2 + i1 * sizeof(int32_t) + i2 * nb21);\n\n  const size_t nb1 = ne0 * sizeof(float);\n  const size_t nb2 = ne1 * nb1;\n\n  float* dst_row = (float*)((char*)dst + i1 * nb1 + i2 * nb2);\n  const float* src0_row =\n      (const float*)((const char*)src0 + i1 * nb01 + i2 * nb02);\n  const float* src1_row = (const float*)((const char*)src1 + i11 * nb11);\n\n  for (int64_t i0 = item_ct1.get_local_id(2); i0 < ne0;\n       i0 += item_ct1.get_local_range(2)) {\n    dst_row[i0] = src0_row[i0] + src1_row[i0];\n  }\n}\n\nvoid ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {\n  const ggml_tensor* src0 = dst->src[0];\n  const ggml_tensor* src1 = dst->src[1];\n  const ggml_tensor* src2 = dst->src[2];\n\n  GGML_TENSOR_TERNARY_OP_LOCALS\n\n  GGML_ASSERT(dst->type == GGML_TYPE_F32);\n  GGML_ASSERT(src0->type == GGML_TYPE_F32);\n  GGML_ASSERT(src1->type == GGML_TYPE_F32);\n  GGML_ASSERT(src2->type == GGML_TYPE_I32);\n\n  GGML_ASSERT(nb00 == sizeof(float));\n  GGML_ASSERT(nb10 == sizeof(float));\n  GGML_ASSERT(nb20 == sizeof(int32_t));\n\n  const float* src0_d = (const float*)src0->data;\n  const float* src1_d = (const float*)src1->data;\n  const int32_t* src2_d = (const int32_t*)src2->data;\n  float* dst_d = (float*)dst->data;\n\n  const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];\n  assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);\n\n  int threads = std::min((unsigned int)ne00, max_work_group_size);  // cols\n\n  
ctx.stream()->parallel_for(\n      sycl::nd_range<3>(\n          sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),\n          sycl::range<3>(1, 1, threads)),\n      [=](sycl::nd_item<3> item_ct1) {\n        add_id_kernel(\n            src0_d,\n            src1_d,\n            src2_d,\n            dst_d,\n            ne0,\n            ne1,\n            nb01,\n            nb02,\n            nb11,\n            nb21,\n            item_ct1);\n      });\n}\n"
  },
  {
    "path": "src/ggml-sycl/add-id.hpp",
    "content": "#ifndef GGML_SYCL_ADD_ID_HPP\n#define GGML_SYCL_ADD_ID_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_add_id(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif // GGML_SYCL_ADD_ID_HPP\n"
  },
  {
    "path": "src/ggml-sycl/backend.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_BACKEND_HPP\n#define GGML_SYCL_BACKEND_HPP\n\n#include \"binbcast.hpp\"\n#include \"common.hpp\"\n#include \"concat.hpp\"\n#include \"conv.hpp\"\n#include \"convert.hpp\"\n#include \"count-equal.hpp\"\n#include \"cpy.hpp\"\n#include \"dequantize.hpp\"\n#include \"dmmv.hpp\"\n#include \"element_wise.hpp\"\n#include \"fattn.hpp\"\n#include \"gla.hpp\"\n#include \"im2col.hpp\"\n#include \"mmq.hpp\"\n#include \"mmvq.hpp\"\n#include \"norm.hpp\"\n#include \"outprod.hpp\"\n#include \"pad.hpp\"\n#include \"quantize.hpp\"\n#include \"quants.hpp\"\n#include \"roll.hpp\"\n#include \"rope.hpp\"\n#include \"set_rows.hpp\"\n#include \"ssm_conv.hpp\"\n#include \"softmax.hpp\"\n#include \"tsembd.hpp\"\n#include \"wkv.hpp\"\n#include \"pad_reflect_1d.hpp\"\n\n\n#endif  // GGML_SYCL_BACKEND_HPP\n"
  },
  {
    "path": "src/ggml-sycl/binbcast.cpp",
    "content": "#include \"binbcast.hpp\"\n\n#include <cstddef>\n#include <cstdint>\n#include <sycl/sycl.hpp>\n\n#include \"ggml.h\"\n\ntemplate<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>\nstatic void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,\n        int ne0, int ne1, int ne2, int ne3,\n        int ne10, int ne11, int ne12, int ne13,\n        /*int s0, */ int s1,  int s2,  int s3,\n        int s00, int s01, int s02, int s03,\n        int s10, int s11, int s12, int s13,\n        const sycl::nd_item<3> &item_ct1) {\n    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                    item_ct1.get_local_id(2);\n    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                    item_ct1.get_local_id(1));\n    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +\n                    item_ct1.get_local_id(0)) /\n                   ne3;\n    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +\n                    item_ct1.get_local_id(0)) %\n                   ne3;\n\n    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {\n        return;\n    }\n\n    const int i11 = i1 % ne11;\n    const int i12 = i2 % ne12;\n    const int i13 = i3 % ne13;\n\n    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;\n    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;\n    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;\n\n    const src0_t * src0_row = src0 + i_src0;\n    const src1_t * src1_row = src1 + i_src1;\n    dst_t * dst_row = dst + i_dst;\n\n    for (int i0 = i0s; i0 < ne0;\n         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {\n        const int i10 = i0 % ne10;\n        dst_row[i0] = (dst_t)bin_op(src0 ? 
(float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);\n    }\n}\n\ntemplate<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>\nstatic void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,\n        int ne0, int ne1, int ne2, int ne3,\n        int ne10, int ne11, int ne12, int ne13,\n        /*int s0, */ int s1,  int s2,  int s3,\n        int s00, int s01, int s02, int s03,\n        int s10, int s11, int s12, int s13,\n        const sycl::nd_item<3> &item_ct1) {\n\n    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                  item_ct1.get_local_id(2);\n\n    const int i3 = i/(ne2*ne1*ne0);\n    const int i2 = (i/(ne1*ne0)) % ne2;\n    const int i1 = (i/ne0) % ne1;\n    const int i0 = i % ne0;\n\n    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {\n        return;\n    }\n\n    const int i11 = i1 % ne11;\n    const int i12 = i2 % ne12;\n    const int i13 = i3 % ne13;\n\n    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;\n    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;\n    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;\n\n    const src0_t * src0_row = src0 + i_src0;\n    const src1_t * src1_row = src1 + i_src1;\n    dst_t * dst_row = dst + i_dst;\n\n    const int i10 = i0 % ne10;\n    dst_row[i0] = (dst_t)bin_op(src0 ? 
(float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);\n}\n\n\ntemplate<float (*bin_op)(const float, const float)>\nstruct bin_bcast_sycl {\n    template <typename src0_t, typename src1_t, typename dst_t>\n    void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,\n                    const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,\n                    const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2,\n                    const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,\n                    const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,\n                    const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,\n                    const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted,\n                    queue_ptr stream) {\n        int nr0 = ne10 / ne0;\n        int nr1 = ne11/ne1;\n        int nr2 = ne12/ne2;\n        int nr3 = ne13/ne3;\n\n        int nr[4] = { nr0, nr1, nr2, nr3 };\n\n        // collapse dimensions until first broadcast dimension\n        int64_t cne[] = {ne0, ne1, ne2, ne3};\n        int64_t cne0[] = {ne00, ne01, ne02, ne03};\n        int64_t cne1[] = {ne10, ne11, ne12, ne13};\n        size_t cnb[] = {nb0, nb1, nb2, nb3};\n        size_t cnb0[] = {nb00, nb01, nb02, nb03};\n        size_t cnb1[] = {nb10, nb11, nb12, nb13};\n        auto collapse = [](int64_t cne[]) {\n            cne[0] *= cne[1];\n            cne[1] = cne[2];\n            cne[2] = cne[3];\n            cne[3] = 1;\n        };\n\n        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {\n            cnb[1] *= cne[1];\n            cnb[2] *= cne[2];\n            cnb[3] *= cne[3];\n        };\n\n        if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && 
!src1_is_permuted) {\n            for (int i = 0; i < 4; i++) {\n                if (nr[i] != 1) {\n                    break;\n                }\n                if (i > 0) {\n                    collapse_nb(cnb, cne);\n                    collapse_nb(cnb0, cne0);\n                    collapse_nb(cnb1, cne1);\n                    collapse(cne);\n                    collapse(cne0);\n                    collapse(cne1);\n                }\n            }\n        }\n        {\n            int64_t ne0 = cne[0];\n            int64_t ne1 = cne[1];\n            int64_t ne2 = cne[2];\n            int64_t ne3 = cne[3];\n\n            int64_t ne10 = cne1[0];\n            int64_t ne11 = cne1[1];\n            int64_t ne12 = cne1[2];\n            int64_t ne13 = cne1[3];\n\n            size_t nb0 = cnb[0];\n            size_t nb1 = cnb[1];\n            size_t nb2 = cnb[2];\n            size_t nb3 = cnb[3];\n\n            size_t nb00 = cnb0[0];\n            size_t nb01 = cnb0[1];\n            size_t nb02 = cnb0[2];\n            size_t nb03 = cnb0[3];\n\n            size_t nb10 = cnb1[0];\n            size_t nb11 = cnb1[1];\n            size_t nb12 = cnb1[2];\n            size_t nb13 = cnb1[3];\n\n            // size_t s0 = nb0 / sizeof(dst_t);\n            size_t s1 = nb1 / sizeof(dst_t);\n            size_t s2 = nb2 / sizeof(dst_t);\n            size_t s3 = nb3 / sizeof(dst_t);\n\n            size_t s10 = nb10 / sizeof(src1_t);\n            size_t s11 = nb11 / sizeof(src1_t);\n            size_t s12 = nb12 / sizeof(src1_t);\n            size_t s13 = nb13 / sizeof(src1_t);\n\n            size_t s00 = nb00 / sizeof(src0_t);\n            size_t s01 = nb01 / sizeof(src0_t);\n            size_t s02 = nb02 / sizeof(src0_t);\n            size_t s03 = nb03 / sizeof(src0_t);\n\n            GGML_UNUSED(s00);\n\n            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);\n            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);\n            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);\n            
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);\n\n            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);\n            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);\n            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);\n            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);\n\n            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);\n            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);\n            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);\n            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);\n\n            const int block_size = 128;\n\n            int64_t hne0 = std::max(ne0/2LL, 1LL);\n\n            sycl::range<3> block_dims(1, 1, 1);\n            block_dims[2] = std::min<unsigned int>(hne0, block_size);\n            block_dims[1] = std::min<unsigned int>(\n                ne1, block_size / (unsigned int)block_dims[2]);\n            block_dims[0] = std::min(\n                std::min<unsigned int>(\n                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /\n                                   (unsigned int)block_dims[1]),\n                64U);\n\n            sycl::range<3> block_nums(\n                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],\n                (ne1 + block_dims[1] - 1) / block_dims[1],\n                (hne0 + block_dims[2] - 1) / block_dims[2]);\n\n            if (block_nums[0] > 65535) {\n                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel\n                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;\n                {\n                    dpct::has_capability_or_fail(stream->get_device(),\n                                                 {sycl::aspect::fp16});\n\n                    stream->parallel_for(\n                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *\n                                              sycl::range<3>(1, 1, block_size),\n                                          sycl::range<3>(1, 1, block_size)),\n                        
[=](sycl::nd_item<3> item_ct1) {\n                            k_bin_bcast_unravel<bin_op>(\n                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,\n                                ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02,\n                                s03, s10, s11, s12, s13, item_ct1);\n                        });\n                }\n            } else {\n                /*\n                DPCT1049:16: The work-group size passed to the SYCL kernel may\n                exceed the limit. To get the device limit, query\n                info::device::max_work_group_size. Adjust the work-group size if\n                needed.\n                */\n                dpct::has_capability_or_fail(stream->get_device(),\n                                             {sycl::aspect::fp16});\n\n                stream->parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,\n                                            ne2, ne3, ne10, ne11, ne12, ne13,\n                                            s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13,\n                                            item_ct1);\n                    });\n            }\n        }\n    }\n};\n\ntemplate <class op>\ninline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,\n                                   ggml_tensor * dst) {\n    dpct::queue_ptr main_stream = ctx.stream();\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n        op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,\n             ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,\n 
            ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n        op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,\n             ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,\n             nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),\n             main_stream);\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n        op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,\n             ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,\n             nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),\n             main_stream);\n    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {\n        op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,\n             ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,\n             nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),\n             main_stream);\n    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {\n        op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,\n             ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,\n             nb3, 
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),\n             main_stream);\n    } else {\n        fprintf(stderr, \"%s: unsupported types: dst: %s, src0: %s, src1: %s\\n\", __func__, ggml_type_name(dst->type),\n                ggml_type_name(src0->type), ggml_type_name(src1->type));\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\ninline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n\n    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst);\n}\n\ninline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n\n    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);\n}\n\ninline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n\n    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);\n}\n\ninline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n\n    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst);\n}\n\ninline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);\n}\n\n\nvoid ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    ggml_sycl_op_add(ctx, dst);\n}\n\nvoid ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    ggml_sycl_op_sub(ctx, dst);\n}\n\nvoid ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    ggml_sycl_op_mul(ctx, dst);\n}\n\nvoid ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    
ggml_sycl_op_div(ctx, dst);\n}\n\nvoid ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_repeat(ctx, dst);\n}\n\n"
  },
  {
    "path": "src/ggml-sycl/binbcast.hpp",
    "content": "#ifndef GGML_SYCL_BINBCAST_HPP\n#define GGML_SYCL_BINBCAST_HPP\n#include \"common.hpp\"\n\n\nstatic __dpct_inline__ float op_repeat(const float a, const float b) {\n    return b;\n    GGML_UNUSED(a);\n}\n\nstatic __dpct_inline__ float op_add(const float a, const float b) {\n    return a + b;\n}\n\nstatic __dpct_inline__ float op_sub(const float a, const float b) {\n    return a - b;\n}\n\nstatic __dpct_inline__ float op_mul(const float a, const float b) {\n    return a * b;\n}\n\nstatic __dpct_inline__ float op_div(const float a, const float b) {\n    return a / b;\n}\n\nvoid ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n\n#endif //GGML_SYCL_BINBCAST_HPP\n\n"
  },
  {
    "path": "src/ggml-sycl/common.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#include \"common.hpp\"\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-impl.h\"\n\nint get_current_device_id() {\n  return dpct::dev_mgr::instance().current_device_id();\n}\n\nvoid* ggml_sycl_host_malloc(size_t size) try {\n  if (getenv(\"GGML_SYCL_NO_PINNED\") != nullptr) {\n    return nullptr;\n  }\n\n  void* ptr = nullptr;\n  // allow to use dpct::get_in_order_queue() for host malloc\n  dpct::err0 err = CHECK_TRY_ERROR(\n      ptr = (void*)sycl::malloc_host(size, dpct::get_in_order_queue()));\n\n  if (err != 0) {\n    // clear the error\n    GGML_LOG_ERROR(\"WARNING: failed to allocate %.2f MB of pinned memory: %s\\n\", size / 1024.0 / 1024.0,    \"syclGetErrorString is not supported\");\n    return nullptr;\n  }\n\n  return ptr;\n} catch (sycl::exception const& exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nvoid ggml_sycl_host_free(void* ptr) try {\n  // allow to use dpct::get_in_order_queue() for host malloc\n  SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));\n} catch (sycl::exception const& exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nbool gpu_has_xmx(sycl::device &dev) {\n    return dev.has(sycl::aspect::ext_intel_matrix);\n}\n\nint64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {\n  const int64_t max_range = std::numeric_limits<int>::max();\n  int64_t sycl_down_blk_size = block_size;\n  int64_t global_range = accumulate_block_num * 
sycl_down_blk_size;\n  while(global_range > max_range) {\n      sycl_down_blk_size /= 2;\n      global_range = accumulate_block_num * sycl_down_blk_size;\n  }\n  return sycl_down_blk_size;\n}\n\nvoid release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) {\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {\n            if (extra->events[i][is] != nullptr) {\n                SYCL_CHECK(CHECK_TRY_ERROR(dpct::destroy_event(extra->events[i][is])));\n            }\n        }\n        if (extra->data_device[i] != nullptr && streams.size()>0) {\n            ggml_sycl_set_device(i);\n            SYCL_CHECK(\n                CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i]))));\n        }\n    }\n    delete extra;\n}\n"
  },
  {
    "path": "src/ggml-sycl/common.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_COMMON_HPP\n#define GGML_SYCL_COMMON_HPP\n\n#include <cstddef>\n#include <fstream>\n#include <iostream>\n#include <string>\n\n#include \"dpct/helper.hpp\"\n#include \"ggml.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-sycl.h\"\n#include \"presets.hpp\"\n#include \"sycl_hw.hpp\"\n\nnamespace syclexp = sycl::ext::oneapi::experimental;\n\n#if GGML_SYCL_DNNL\n#include \"dnnl.hpp\"\n#include \"dnnl_sycl.hpp\"\n#endif\n\n#define GGML_COMMON_DECL_SYCL\n#define GGML_COMMON_IMPL_SYCL\n#define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building.\n#define SYCL_FAST_FP16  //don't change. remove it will break fattn-tile.hpp building\n\n/* suppress warning spam */\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wnested-anon-types\"\n#include \"ggml-common.h\"\n#pragma clang diagnostic pop\n#include \"ggml-impl.h\"\n\nvoid* ggml_sycl_host_malloc(size_t size);\nvoid ggml_sycl_host_free(void* ptr);\n\n\nextern int g_ggml_sycl_debug;\nextern int g_ggml_sycl_disable_optimize;\nextern int g_ggml_sycl_prioritize_dmmv;\nextern int g_ggml_sycl_enable_flash_attention;\n\n\n#if defined(__clang__) && __has_builtin(__builtin_expect)\n// Hint the optimizer to pipeline the more likely following instruction in branches\n#    define LIKELY(expr)   __builtin_expect(expr, true)\n#    define UNLIKELY(expr) __builtin_expect(expr, false)\n#else\n#    define LIKELY(expr)   (expr)\n#    define UNLIKELY(expr) (expr)\n#endif\n\n#define GGML_SYCL_DEBUG(...)              
\\\n    do {                                  \\\n        if (UNLIKELY(g_ggml_sycl_debug))  \\\n            fprintf(stderr, __VA_ARGS__); \\\n    } while (0)\n\n#define CHECK_TRY_ERROR(expr)                                            \\\n  [&]() {                                                                \\\n    try {                                                                \\\n      expr;                                                              \\\n      return dpct::success;                                              \\\n    } catch (std::exception const& e) {                                  \\\n      std::cerr << e.what() << \"\\nException caught at file:\" << __FILE__ \\\n                << \", line:\" << __LINE__ << \", func:\" << __func__        \\\n                << std::endl;                                            \\\n      return dpct::default_error;                                        \\\n    }                                                                    \\\n  }()\n\n\n#define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP\n#define VER_4VEC 610 // todo for hardware optimize.\n#define VER_GEN9 700 // todo for hardware optimize.\n#define VER_GEN12 1000000 // todo for hardware optimize.\n#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardware optimize.\n\n#define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardwares\n\n// define for XMX in Intel GPU\n// TODO: currently, it's not used for XMX really.\n#if !defined(GGML_SYCL_FORCE_MMQ)\n    #define SYCL_USE_XMX\n#endif\n\n// max batch size to use MMQ kernels when tensor cores are available\n#define MMQ_MAX_BATCH_SIZE 32\n\n// dmmv = dequantize_mul_mat_vec\n#ifndef GGML_SYCL_DMMV_X\n#define GGML_SYCL_DMMV_X 32\n#endif\n#ifndef GGML_SYCL_MMV_Y\n#define GGML_SYCL_MMV_Y 1\n#endif\n\ntypedef sycl::queue *queue_ptr;\n\nenum ggml_sycl_backend_gpu_mode {\n  SYCL_UNSET_GPU_MODE = -1,\n  SYCL_SINGLE_GPU_MODE = 0,\n  SYCL_MUL_GPU_MODE\n};\n\nstatic_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), 
\"wrong fp16 size\");\n\nstatic void crash() {\n  int* ptr = NULL;\n  *ptr = 0;\n}\n\n[[noreturn]] static void ggml_sycl_error(\n    const char* stmt,\n    const char* func,\n    const char* file,\n    const int line,\n    const char* msg) {\n  fprintf(stderr, \"SYCL error: %s: %s\\n\", stmt, msg);\n  fprintf(stderr, \"  in function %s at %s:%d\\n\", func, file, line);\n  GGML_ABORT(\"SYCL error\");\n}\n\n#define SYCL_CHECK(err)                                                                                    \\\n    do {                                                                                                   \\\n        auto err_ = (err);                                                                                 \\\n        if (err_ != 0)                                                                                     \\\n            ggml_sycl_error(#err, __func__, __FILE__, __LINE__, \"Exception caught in this line of code.\"); \\\n    } while (0)\n\n#if DPCT_COMPAT_RT_VERSION >= 11100\n#define GGML_SYCL_ASSUME(x) __builtin_assume(x)\n#else\n#define GGML_SYCL_ASSUME(x)\n#endif // DPCT_COMPAT_RT_VERSION >= 11100\n\n#ifdef GGML_SYCL_F16\ntypedef sycl::half dfloat; // dequantize float\ntypedef sycl::half2 dfloat2;\n#else\ntypedef float dfloat; // dequantize float\ntypedef sycl::float2 dfloat2;\n#endif // GGML_SYCL_F16\n\n#define MMVQ_MAX_BATCH_SIZE  8\n\nstatic int g_all_sycl_device_count = -1;\nstatic bool g_ggml_backend_sycl_buffer_type_initialized = false;\n\nstatic ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode =\n    SYCL_UNSET_GPU_MODE;\n\nstatic void* g_scratch_buffer = nullptr;\nstatic size_t g_scratch_size = 0; // disabled by default\nstatic size_t g_scratch_offset = 0;\n\n[[noreturn]] static inline void bad_arch(const sycl::stream& stream_ct1) {\n  stream_ct1 << \"ERROR: ggml-sycl was compiled without support for the \"\n                \"current GPU architecture.\\n\";\n  // __trap();\n  std::exit(1);\n\n  (void)bad_arch; // 
suppress unused function warning\n}\n\nint get_current_device_id();\n\ninline int ggml_sycl_get_device() {\n    return get_current_device_id();\n}\n\ninline dpct::err0 ggml_sycl_set_device(const int device) try {\n  int current_device_id;\n  SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));\n\n  // GGML_SYCL_DEBUG(\"ggml_sycl_set_device device_id=%d,\n  // current_device_id=%d\\n\", device, current_device);\n  if (device == current_device_id) {\n    return 0;\n  }\n\n  return CHECK_TRY_ERROR(dpct::select_device(device));\n} catch (sycl::exception const& exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  crash();\n  std::exit(1);\n}\n\n//////////////////////\nstruct optimize_feature {\n    bool reorder=false;\n};\n\nstruct sycl_device_info {\n    int cc;  // compute capability\n    int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum\n             // number of compute units on a SYCL device.\n    // size_t  smpb;               // max. shared memory per block\n    size_t  smpbo;              // max. shared memory per block (with opt-in)\n    int warp_size;     // WARP_SIZE(16)|WARP_32_SIZE(32)|WARP_16_SIZE(16). For Intel GPU, 16 is better in most cases. 
Some OP support 32 only.\n    int max_wg_per_cu; // max work groups per compute unit - refer to\n                       // cudaOccupancyMaxActiveBlocksPerMultiprocessor\n    bool    vmm;                // virtual memory support\n    size_t  total_vram;\n    //sycl_hw_info hw_info;     \\\\ device id and aarch, currently not used\n    optimize_feature opt_feature;\n};\n\n\nstruct ggml_sycl_device_info {\n    int device_count;\n\n    sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};\n\n    std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};\n\n    int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0};\n};\n\nconst ggml_sycl_device_info & ggml_sycl_info();\n\nstruct ggml_sycl_pool {\n    virtual ~ggml_sycl_pool() = default;\n\n    virtual void * alloc(size_t size, size_t * actual_size) = 0;\n    virtual void free(void * ptr, size_t size) = 0;\n};\n\ntemplate<typename T>\nstruct ggml_sycl_pool_alloc {\n    ggml_sycl_pool * pool = nullptr;\n    T * ptr = nullptr;\n    size_t actual_size = 0;\n\n    explicit ggml_sycl_pool_alloc(ggml_sycl_pool & pool) : pool(&pool) {\n    }\n\n    ggml_sycl_pool_alloc(ggml_sycl_pool & pool, size_t size) : pool(&pool) {\n        alloc(size);\n    }\n\n    ~ggml_sycl_pool_alloc() {\n        if (ptr != nullptr) {\n            pool->free(ptr, actual_size);\n        }\n    }\n\n    T * realloc(size_t size) {\n        GGML_ASSERT(pool != nullptr);\n        if (ptr)\n            pool->free(ptr, actual_size);\n        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);\n        return ptr;\n    }\n\n    // size is in number of elements\n    T * alloc(size_t size) {\n        GGML_ASSERT(pool != nullptr);\n        GGML_ASSERT(ptr == nullptr);\n        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);\n        return ptr;\n    }\n\n    T * alloc(ggml_sycl_pool & pool, size_t size) {\n        this->pool = &pool;\n        return alloc(size);\n    }\n\n    T * get() {\n        return ptr;\n    }\n\n    
ggml_sycl_pool_alloc() = default;\n    ggml_sycl_pool_alloc(const ggml_sycl_pool_alloc &) = delete;\n    ggml_sycl_pool_alloc(ggml_sycl_pool_alloc &&) = delete;\n    ggml_sycl_pool_alloc& operator=(const ggml_sycl_pool_alloc &) = delete;\n    ggml_sycl_pool_alloc& operator=(ggml_sycl_pool_alloc &&) = delete;\n};\n\n// backend interface\n\nstruct ggml_tensor_extra_gpu {\n  void* data_device[GGML_SYCL_MAX_DEVICES]; // 1 pointer for each device for split\n                                       // tensors\n  dpct::event_ptr events[GGML_SYCL_MAX_DEVICES]\n                        [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs\n  optimize_feature optimized_feature;\n};\n\nvoid release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});\n\nnamespace sycl_ex = sycl::ext::oneapi::experimental;\nstruct ggml_backend_sycl_context {\n    int device;\n    std::string name;\n    optimize_feature opt_feature;\n\n    queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };\n\n    explicit ggml_backend_sycl_context(int device) :\n        device(device),\n        name(GGML_SYCL_NAME + std::to_string(device)) {\n        opt_feature = ggml_sycl_info().devices[device].opt_feature;\n    }\n\n    queue_ptr stream(int device, int stream) {\n        if (qptrs[device][stream] == nullptr) {\n            qptrs[device][stream] = &(dpct::get_device(device).default_queue());\n        }\n        return qptrs[device][stream];\n    }\n\n    queue_ptr stream() {\n        return stream(device, 0);\n    }\n\n#if GGML_SYCL_DNNL\n    dnnl::engine make_engine(sycl::queue* q) {\n        // Get the device associated with the queue\n        sycl::device dev = q->get_device();\n        // Get the context associated with the queue\n        sycl::context ctx = q->get_context();\n        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);\n        return eng;\n    }\n\n    std::unordered_map<sycl::queue*, dnnl::stream> 
stream_map;\n    std::unordered_map<sycl::queue*, dnnl::engine> engine_map;\n    dnnl::stream stream_dnnl(int device, int _stream) {\n        auto q = stream(device, _stream);\n        return stream_dnnl(q);\n    }\n    dnnl::engine engine_dnnl(sycl::queue* qptr) {\n        auto it = engine_map.find(qptr);\n        if (it == engine_map.end()) {\n            auto eng = make_engine(qptr);\n            engine_map[qptr] = eng;\n            return eng;\n        }\n        else\n        {\n            return it->second;\n        }\n    }\n    dnnl::stream stream_dnnl(sycl::queue* qptr) {\n        auto it = stream_map.find(qptr);\n        if (it == stream_map.end()) {\n            auto eng = engine_dnnl(qptr);\n            auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);\n            stream_map[qptr] = stream;\n            return stream;\n        }\n        else\n        {\n            return it->second;\n        }\n    }\n    dnnl::stream stream_dnnl() {\n        return stream_dnnl(device, 0);\n    }\n    dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,\n                                    const dnnl::engine & eng, const queue_ptr q) {\n        ggml_sycl_pool_alloc<uint8_t> * pool;\n        auto it = scratchpad_map.find(q);\n        if (it == scratchpad_map.end()) {\n            scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());\n            pool = scratchpad_map[q].get();\n        } else {\n            pool = it->second.get();\n        }\n\n        size_t scratchpad_size = scratchpad_md.get_size();\n        if (scratchpad_size > pool->actual_size) {\n            pool->realloc(scratchpad_size);\n        }\n        void * mem_ptr = pool->get();\n        return dnnl::memory(scratchpad_md, eng, mem_ptr);\n    }\n#endif\n\n    // pool\n    std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];\n    std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;\n\n    
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];\n\n    static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);\n\n    static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);\n\n    ggml_sycl_pool & pool(int device) {\n        if (pools[device] == nullptr) {\n            pools[device] = new_pool_for_device(stream(device,0), device);\n        }\n        return *pools[device];\n    }\n\n    ggml_sycl_pool & pool() {\n        return pool(device);\n    }\n\n#ifdef GGML_SYCL_GRAPH\n    std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;\n#endif\n\n    ggml_sycl_pool & host_pool(int device) {\n        if (host_pools[device] == nullptr) {\n            host_pools[device] = new_pool_for_host(stream(device, 0), device);\n        }\n        return *host_pools[device];\n    }\n\n    ggml_sycl_pool & host_pool() { return host_pool(device); }\n};\n\n// common device functions\n\nstatic __dpct_inline__ float warp_reduce_sum(float x,\n    const sycl::nd_item<3>& item_ct1) {\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);\n    }\n    return x;\n}\n\nstatic __dpct_inline__ sycl::float2\nwarp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(),\n            mask);\n        a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(),\n            mask);\n    }\n    return a;\n}\n\n/* use WARP_SIZE or WARP_32_SIZE*/\ntemplate <int width>\nstatic __dpct_inline__ int warp_reduce_sum(int x) {\n  return sycl::reduce_over_group(\n      sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());\n}\n\n/* use WARP_SIZE or WARP_32_SIZE*/\ntemplate <int width>\nstatic 
__dpct_inline__ float warp_reduce_sum(float x) {\n#pragma unroll\n  for (int offset = width / 2; offset > 0; offset >>= 1) {\n    x += dpct::permute_sub_group_by_xor(\n        sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width);\n  }\n  return x;\n}\n\n/* use WARP_SIZE or WARP_32_SIZE*/\ntemplate <int width>\nstatic __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) {\n#pragma unroll\n  for (int offset = width / 2; offset > 0; offset >>= 1) {\n    x += dpct::permute_sub_group_by_xor(\n        item_ct1.get_sub_group(), x, offset);\n  }\n  return x;\n}\n\n/* use WARP_SIZE or WARP_32_SIZE*/\ntemplate <int width>\nstatic __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {\n#pragma unroll\n  for (int offset = width / 2; offset > 0; offset >>= 1) {\n    a.x() += dpct::permute_sub_group_by_xor(\n        sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset,\n        width);\n    a.y() += dpct::permute_sub_group_by_xor(\n        sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset,\n        width);\n  }\n  return a;\n}\n\n/* use WARP_SIZE or WARP_32_SIZE*/\ntemplate <int width>\nstatic __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {\n#pragma unroll\n  for (int offset = width / 2; offset > 0; offset >>= 1) {\n    a = a + dpct::permute_sub_group_by_xor(\n                sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset,\n                width);\n  }\n  return a;\n}\n\nstatic constexpr int ggml_sycl_get_physical_warp_size() {\n  // todo: for old iGPU + dGPU case, need to be changed.\n  return WARP_SIZE;\n}\n\n/* use WARP_SIZE or WARP_32_SIZE*/\ntemplate <int width>\nstatic __dpct_inline__ int warp_reduce_all(int x) {\n    if (width == ggml_sycl_get_physical_warp_size()) {\n        return sycl::all_of_group(\n            sycl::ext::oneapi::this_work_item::get_sub_group(),\n            (~0xffffffff &\n             (0x1 << 
sycl::ext::oneapi::this_work_item::get_sub_group()\n                         .get_local_linear_id())) ||\n                x);\n    } else {\n#pragma unroll\n        for (int offset = width / 2; offset > 0; offset >>= 1) {\n            x = dpct::permute_sub_group_by_xor(\n                    sycl::ext::oneapi::this_work_item::get_sub_group(), x,\n                    offset, width) &&\n                x;\n        }\n        return x;\n    }\n}\n\n/* use WARP_SIZE or WARP_32_SIZE*/\ntemplate <int width>\nstatic __dpct_inline__ int warp_reduce_any(int x) {\n    if (width == ggml_sycl_get_physical_warp_size()) {\n        return sycl::any_of_group(\n            sycl::ext::oneapi::this_work_item::get_sub_group(),\n            (0xffffffff &\n             (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()\n                         .get_local_linear_id())) &&\n                x);\n    } else {\n#pragma unroll\n        for (int offset = width / 2; offset > 0; offset >>= 1) {\n            x = dpct::permute_sub_group_by_xor(\n                    sycl::ext::oneapi::this_work_item::get_sub_group(), x,\n                    offset, width) ||\n                x;\n        }\n        return x;\n    }\n}\n\n/* use WARP_SIZE or WARP_32_SIZE*/\ntemplate <int width>\nstatic __dpct_inline__ float warp_reduce_max(float x) {\n#pragma unroll\n  for (int offset = width / 2; offset > 0; offset >>= 1) {\n    x = sycl::fmax(x, dpct::permute_sub_group_by_xor(\n                          sycl::ext::oneapi::this_work_item::get_sub_group(), x,\n                          offset, width));\n  }\n  return x;\n}\n\nstatic __dpct_inline__ float warp_reduce_max(float x,\n    const sycl::nd_item<3>& item_ct1) {\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        x = sycl::fmax(x, dpct::permute_sub_group_by_xor(\n            item_ct1.get_sub_group(), x, mask));\n    }\n    return x;\n}\n\n/* Helper for Computing the linear offset of a ggml_tensor given\nper-dimension 
sizes, strides, and indices */\ntemplate<int N>\n__dpct_inline__ size_t calculate_offset(const std::array<int, N> & strides, const std::array<int, N> & indices) {\n    size_t offset = 0;\n#pragma unroll\n    for (int i = 0; i < N; i++) {\n        auto index_i = indices[i];\n        offset += strides[i] * index_i;\n    }\n    return offset;\n}\n\n// Helper for vec loading aligned data\ntemplate <typename Tp, int n>\ninline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {\n    return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);\n}\n\n// Helper for accessing pointers with no warnings\ntemplate <typename Tp, int dim>\nstatic __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {\n    return acc.template get_multi_ptr<sycl::access::decorated::no>().get();\n}\n\nint64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);\n\nconstexpr size_t ceil_div(const size_t m, const size_t n) {\n    return (m + n - 1) / n;\n}\n\nbool gpu_has_xmx(sycl::device &dev);\n\ntemplate <int N, class T> std::string debug_get_array_str(const std::string & prefix, const T array[N]) {\n    if (LIKELY(!g_ggml_sycl_debug)) {\n        return \"\";\n    }\n    std::stringstream ss;\n    ss << prefix << \"=[\";\n    for (std::size_t i = 0; i < N - 1; ++i) {\n        ss << array[i] << \", \";\n    }\n    if constexpr (N > 0) {\n        ss << array[N - 1];\n    }\n    ss << \"]\";\n    return ss.str();\n}\n\ninline std::string debug_get_tensor_str(const std::string &prefix,\n        const ggml_tensor *tensor, const std::string &suffix = \"\") {\n    std::stringstream ss;\n    if (LIKELY(!g_ggml_sycl_debug)) { return ss.str(); }\n    ss << prefix.c_str() << \"=\";\n    if (tensor) {\n        ss << \"'\" << tensor->name << \"':type=\" << ggml_type_name(tensor->type);\n        ss << debug_get_array_str<GGML_MAX_DIMS>(\";ne\", tensor->ne);\n        ss << debug_get_array_str<GGML_MAX_DIMS>(\";nb\", tensor->nb);\n\n        if 
(!ggml_is_contiguous(tensor)) { ss << \";strided\"; }\n        if (ggml_is_permuted(tensor)) { ss << \";permuted\"; }\n    } else {\n        ss << \"nullptr\";\n    }\n    ss << suffix;\n    return ss.str();\n}\n\n// Use scope_op_debug_print to log operations coming from running a model\nstruct scope_op_debug_print {\n    // Use string_views to avoid the cost of creating a string and concatenating them\n    // string_views must be alive for as long as the object is alive\n    // scope_op_debug_print are used with string literals in practice which are stored in constant space so always accessible\n    scope_op_debug_print(const std::string_view & func, const std::string_view & func_suffix, const ggml_tensor * dst,\n                         std::size_t num_src, const std::string_view & suffix = \"\") :\n        func(func),\n        func_suffix(func_suffix) {\n        if (LIKELY(!g_ggml_sycl_debug)) {\n            return;\n        }\n        GGML_SYCL_DEBUG(\"[SYCL][OP] call %s%s:\", func.data(), func_suffix.data());\n        GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\" dst\", dst).c_str());\n        if (dst) {\n            for (std::size_t i = 0; i < num_src; ++i) {\n                GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\"\\tsrc\" + std::to_string(i), dst->src[i]).c_str());\n            }\n        }\n        GGML_SYCL_DEBUG(\"%s\\n\", suffix.data());\n    }\n\n    scope_op_debug_print(const std::string_view & func, const ggml_tensor * dst, std::size_t num_src,\n                         const std::string_view & suffix = \"\") :\n        scope_op_debug_print(func, \"\", dst, num_src, suffix) {}\n\n    ~scope_op_debug_print() { GGML_SYCL_DEBUG(\"[SYCL][OP] call %s%s done\\n\", func.data(), func_suffix.data()); }\n\n  private:\n    std::string_view func;\n    std::string_view func_suffix;\n};\n\nstatic __dpct_inline__ float get_alibi_slope(const float    max_bias,\n                                             const uint32_t h,\n                                  
           const uint32_t n_head_log2,\n                                             const float    m0,\n                                             const float    m1) {\n    if (max_bias <= 0.0f) {\n        return 1.0f;\n    }\n    const float base = h < n_head_log2 ? m0 : m1;\n    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;\n\n    return dpct::pow(base, exph);\n}\n\nstatic const sycl::uint3 init_fastdiv_values(uint32_t d) {\n    GGML_ASSERT(d != 0);\n\n    uint32_t L = 0;\n    while (L < 32 && (uint32_t{ 1 } << L) < d) {\n        L++;\n    }\n\n    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);\n    return sycl::uint3(mp, L, d);\n}\n\n// Maximum number of bytes that can be copied in a single instruction.\n// Set by test result.\nstatic constexpr int ggml_sycl_get_max_cpy_bytes() {\n    return 16;\n}\n\n// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes.\ntemplate <int nbytes, int alignment = 0>\nstatic __dpct_inline__ void ggml_sycl_memcpy_1(void * dst, const void * src) {\n    if constexpr (alignment != 0) {\n        static_assert(nbytes % alignment == 0, \"bad alignment\");\n    }\n    constexpr int nb_per_cpy = alignment == 0 ? 
nbytes : alignment;\n\n#pragma unroll\n    for (int i = 0; i < nbytes/nb_per_cpy; ++i) {\n        if constexpr (nb_per_cpy == 1) {\n            ((char *) dst)[i] = ((const char *) src)[i];\n        } else if constexpr (nb_per_cpy == 2) {\n            ((short *) dst)[i] = ((const short *) src)[i];\n        } else if constexpr (nb_per_cpy == 4) {\n            ((int *) dst)[i] = ((const int *) src)[i];\n        } else if constexpr (nb_per_cpy == 8) {\n            ((sycl::int2 *) dst)[i] = ((const sycl::int2 *) src)[i];\n        } else if constexpr (nb_per_cpy == 16) {\n            ((sycl::int4 *) dst)[i] = ((const sycl::int4 *) src)[i];\n        } else {\n            static_assert(nbytes == 0 && nbytes == -1, \"bad nbytes\");\n        }\n    }\n}\ntemplate <typename T>\nsycl::half2 __dpct_inline__ make_half2( T x, T y) {\n    sycl::half2 res(static_cast<sycl::half>(x),static_cast<sycl::half>(y));\n    return res;\n}\n\nstatic __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {\n    const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());\n    return (hi + n) >> fastdiv_values.y();\n}\n\n\ntemplate <typename T>\nsycl::float2 __dpct_inline__ make_float2( T x, T y) {\n    sycl::float2 res(static_cast<float>(x),static_cast<float>(y));\n    return res;\n}\n\nsycl::float2 __dpct_inline__ __half22float2(sycl::half2 &H) {\n    sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));\n    return float2_value;\n}\n\nstatic __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {\n    const uint32_t div_val = fastdiv(n, fastdiv_values);\n    const uint32_t mod_val = n - div_val * fastdiv_values.z();\n    return sycl::uint2(div_val, mod_val);\n}\n\nstatic __dpct_inline__ int ggml_sycl_dp4a(const int a, const int b, int c) {\n    return dpct::dp4a(a, b, c);\n}\n\nstatic __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {\n    uint32_t bits;\n    if (x == 0) {\n        bits = 
0x00400000;\n    } else {\n        bits = (uint32_t) x << 23;\n    }\n\n    float result;\n    memcpy(&result, &bits, sizeof(float));\n    return result;\n}\n\nsycl::float2 __dpct_inline__ __half22float2(const sycl::half2 &H) {\n    sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));\n    return float2_value;\n}\n\nfloat __dpct_inline__ __half2float(sycl::half H) {\n    return static_cast<float>(H);\n}\n\nstatic __dpct_inline__ void ggml_sycl_mad(float & acc, const float v, const float u) {\n    acc += v*u;\n}\n\nstatic __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::float2 v, const sycl::float2 u) {\n    acc += v.x() * u.x();\n    acc += v.y() * u.y();\n}\n\nstatic __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::half2 v, const sycl::half2 u) {\n#ifdef GGML_SYCL_F16\n    const sycl::float2 tmp = (v * u).template convert<float, sycl::rounding_mode::automatic>();\n    acc += tmp.x() + tmp.y();\n#else\n    const sycl::float2 tmpv = __half22float2(v);\n    const sycl::float2 tmpu = __half22float2(u);\n    acc += tmpv.x() * tmpu.x();\n    acc += tmpv.y() * tmpu.y();\n#endif // GGML_SYCL_F16\n}\n\nstatic __dpct_inline__ void ggml_sycl_mad(sycl::half2 & acc, const sycl::half2 v, const sycl::half2 u) {\n#ifdef GGML_SYCL_F16\n    acc += v*u;\n#else\n    const sycl::float2 tmpv = __half22float2(v);\n    const sycl::float2 tmpu = __half22float2(u);\n    sycl::float2 tmpacc = __half22float2(acc);\n    // tmpacc.x += tmpv.x() * tmpu.x();\n    // tmpacc.y += tmpv.y() * tmpu.y();\n    sycl::float2 tmp1(tmpacc.x() + tmpv.x() * tmpu.x(), tmpacc.y() + tmpv.y() * tmpu.y());\n    acc = make_half2(tmp1.x(), tmp1.y());\n#endif // GGML_SYCL_F16\n}\n\ntemplate <int n>\nstruct ggml_sycl_unroll {\n    template <typename Func, typename... Args>\n    void operator()(const Func & f, Args... 
args) const {\n        f(n - 1, args...);\n        ggml_sycl_unroll<n - 1>{}(f, args...);\n    }\n};\n\ntemplate <>\nstruct ggml_sycl_unroll<1> {\n    template <typename Func, typename... Args>\n    void operator()(const Func & f, Args... args) const {\n        f(0, args...);\n    }\n};\n\nstatic __dpct_inline__ sycl::half2 ggml_sycl_hmax2(const sycl::half2 a, const sycl::half2 b) {\n    sycl::half2 ret;\n    reinterpret_cast<sycl::half &>(ret.x()) =\n        sycl::vec<float, 1>(sycl::fmax(a[0], b[0])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];\n    reinterpret_cast<sycl::half &>(ret.y()) =\n        sycl::vec<float, 1>(sycl::fmax(a[1], b[1])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];\n    return ret;\n}\n\nstatic __dpct_inline__ sycl::half ggml_sycl_hmax(const sycl::half a, const sycl::half b) {\n    return sycl::vec<float, 1>(\n               sycl::fmax(sycl::vec<sycl::half, 1>(a).convert<float, sycl::rounding_mode::automatic>()[0],\n                          sycl::vec<sycl::half, 1>(b).convert<float, sycl::rounding_mode::automatic>()[0]))\n        .convert<sycl::half, sycl::rounding_mode::automatic>()[0];\n}\n\nstatic __dpct_inline__ uint32_t __hgt2_mask(const sycl::half2 a, const sycl::half2 b) {\n    const uint32_t mask_low  = 0x0000FFFF * (float(a[0]) > float(b[0]));\n    const uint32_t mask_high = 0xFFFF0000 * (float(a[1]) > float(b[1]));\n    return mask_low | mask_high;\n}\n\nstatic __dpct_inline__ uint32_t fastmodulo(uint32_t n, const sycl::uint3 fastdiv_values) {\n    // expects  fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)\n    return n - fastdiv(n, fastdiv_values) * fastdiv_values.z();\n}\n\nstatic bool fast_fp16_available(const int cc) {\n    GGML_UNUSED(cc);\n    return true;   //Intel GPUs always support FP16.\n}\n\nenum class block_reduce_method {\n    MAX,\n    SUM,\n};\n\ntemplate<block_reduce_method method_t, typename T, int warp_size>\nstruct block_reduce_policy;\n\ntemplate 
<typename T, typename... Ts>\ninline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);\n\ntemplate<typename...>\ninline constexpr bool ggml_sycl_dependent_false_v = false;\n\n#define WARP_32_SIZE 32\n\ntemplate <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::SUM, T, warp_size> {\n    static T reduce(T val) {\n        if constexpr (is_any<T, float, sycl::float2, sycl::half2, int>) {\n            return warp_reduce_sum<warp_size>(val);\n        } else {\n            static_assert(ggml_sycl_dependent_false_v<T>, \"Unsupported type for block reduce sum\");\n        }\n    }\n\n    static T sentinel() {\n        if constexpr (std::is_same_v<T, float>) {\n            return 0.0f;\n        } else if constexpr (std::is_same_v<T, sycl::float2>) {\n            return sycl::float2(0.0f, 0.0f);\n        } else if constexpr (std::is_same_v<T, sycl::half2>) {\n            return sycl::half2(0.0f, 0.0f);\n        } else if constexpr (std::is_same_v<T, int>) {\n            return 0;\n        } else {\n            static_assert(ggml_sycl_dependent_false_v<T>, \"Unsupported type for block reduce sum\");\n        }\n    }\n};\n\ntemplate <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::MAX, T, warp_size> {\n    static T reduce(T val) {\n        if constexpr (is_any<T, float, sycl::half2>) {\n            return warp_reduce_max<warp_size>(val);\n        } else {\n            static_assert(ggml_sycl_dependent_false_v<T>, \"Unsupported type for block reduce max\");\n        }\n    }\n\n    static T sentinel() {\n        if constexpr (std::is_same_v<T, float>) {\n            return -INFINITY;\n        } else if constexpr (std::is_same_v<T, sycl::half2>) {\n            return sycl::half2(-INFINITY, -INFINITY);\n        } else {\n            static_assert(ggml_sycl_dependent_false_v<T>, \"Unsupported type for block reduce max\");\n        }\n    }\n};\n\n\ntemplate <block_reduce_method reduce_method_t, int warp_size, typename 
T>\nstatic T block_reduce(T val, T * shared_vals, int block_size_template) {\n    // Two-level reduction: every sub-group first reduces its own values, lane 0 of each sub-group stores the\n    // partial result in shared_vals, then all sub-groups fold the partials and reduce them once more.\n    // shared_vals must provide at least block_size / warp_size elements.\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    val           = block_reduce_policy<reduce_method_t, T, warp_size>::reduce(val);\n    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;\n\n    if (block_size > warp_size) {\n        assert((block_size <= 1024) && (block_size % warp_size) == 0);\n        const int warp_id = item_ct1.get_local_id(2) / warp_size;\n        const int lane_id = item_ct1.get_local_id(2) % warp_size;\n        // Number of partials, derived from warp_size (not WARP_SIZE) so it matches the warp_id indexing above.\n        const int nwarps  = block_size / warp_size;\n        if (lane_id == 0) {\n            shared_vals[warp_id] = val;\n        }\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n\n        // Start from the neutral element of the operation (0 for SUM, -inf for MAX) and combine with the\n        // matching operator. A lane owns several partials when nwarps > warp_size; lanes owning none keep\n        // the neutral element so they do not affect the final result.\n        T tmp = block_reduce_policy<reduce_method_t, T, warp_size>::sentinel();\n        for (int i = lane_id; i < nwarps; i += warp_size) {\n            if constexpr (reduce_method_t == block_reduce_method::SUM) {\n                tmp += shared_vals[i];\n            } else if constexpr (std::is_same_v<T, sycl::half2>) {\n                tmp = ggml_sycl_hmax2(tmp, shared_vals[i]);\n            } else {\n                tmp = sycl::fmax(tmp, shared_vals[i]);\n            }\n        }\n        return block_reduce_policy<reduce_method_t, T, warp_size>::reduce(tmp);\n    }\n    return val;\n}\n\n#endif // GGML_SYCL_COMMON_HPP\n"
  },
  {
    "path": "src/ggml-sycl/concat.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#include \"concat.hpp\"\n\nstatic inline size_t elem_size(ggml_type t) {\n    return ggml_type_size(t) / ggml_blck_size(t);\n}\n\ntemplate <typename T>\nstatic void concat_T_dim0(const T *x, const T *y, T *dst,\n                            const int ne0, const int ne00,\n                            const sycl::nd_item<3> &item_ct1) {\n  int nidx = item_ct1.get_local_id(2) +\n             item_ct1.get_group(2) * item_ct1.get_local_range(2);\n  if (nidx >= ne0) {\n    return;\n  }\n  // operation\n  int offset_dst = nidx + item_ct1.get_group(1) * ne0 +\n                   item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);\n  if (nidx < ne00) { // src0\n    int offset_src = nidx + item_ct1.get_group(1) * ne00 +\n                     item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1);\n    dst[offset_dst] = x[offset_src];\n  } else {\n    int offset_src =\n        nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) +\n        item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1);\n    dst[offset_dst] = y[offset_src];\n  }\n}\n\ntemplate <typename T>\nstatic void concat_T_dim1(const T *x, const T *y, T *dst,\n                            const int ne0, const int ne01,\n                            const sycl::nd_item<3> &item_ct1) {\n  int nidx = item_ct1.get_local_id(2) +\n             item_ct1.get_group(2) * item_ct1.get_local_range(2);\n  if (nidx >= ne0) {\n    return;\n  }\n  // operation\n  int offset_dst = nidx + item_ct1.get_group(1) * ne0 +\n                   item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);\n  if (item_ct1.get_group(1) < (size_t) ne01) { // src0\n    int offset_src =\n        nidx + 
item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;\n    dst[offset_dst] = x[offset_src];\n  } else {\n    int offset_src =\n        nidx + (item_ct1.get_group(1) - ne01) * ne0 +\n        item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01);\n    dst[offset_dst] = y[offset_src];\n  }\n}\n\ntemplate <typename T>\nstatic void concat_T_dim2(const T *x, const T *y, T *dst,\n                            const int ne0, const int ne02,\n                            const sycl::nd_item<3> &item_ct1) {\n  int nidx = item_ct1.get_local_id(2) +\n             item_ct1.get_group(2) * item_ct1.get_local_range(2);\n  if (nidx >= ne0) {\n    return;\n  }\n  // operation\n  int offset_dst = nidx + item_ct1.get_group(1) * ne0 +\n                   item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);\n  if (item_ct1.get_group(0) < (size_t) ne02) { // src0\n    int offset_src = nidx + item_ct1.get_group(1) * ne0 +\n                     item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);\n    dst[offset_dst] = x[offset_src];\n  } else {\n    int offset_src =\n        nidx + item_ct1.get_group(1) * ne0 +\n        (item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);\n    dst[offset_dst] = y[offset_src];\n  }\n}\n\ntemplate <typename T>\nstatic void concat_T_sycl(const T *x, const T *y, T *dst,\n                            int ne00, int ne01, int ne02, int ne0, int ne1,\n                            int ne2, int dim, queue_ptr stream) {\n  int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;\n  sycl::range<3> gridDim(ne2, ne1, num_blocks);\n  switch (dim) {\n  case 0:\n      stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),\n                                          sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),\n                        [=](sycl::nd_item<3> item_ct1) { concat_T_dim0<T>(x, y, dst, ne0, ne00, item_ct1); });\n      break;\n  case 1:\n      
stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),\n                                          sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),\n                        [=](sycl::nd_item<3> item_ct1) { concat_T_dim1<T>(x, y, dst, ne0, ne01, item_ct1); });\n      break;\n  // dim >=2 will be dispatched to the default path\n  default:\n      stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),\n                                          sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),\n                        [=](sycl::nd_item<3> item_ct1) { concat_T_dim2<T>(x, y, dst, ne0, ne02, item_ct1); });\n      break;\n  }\n}\n\n// non-contiguous kernel (slow)\ntemplate<typename T>\nstatic void concat_T_sycl_non_cont(\n    queue_ptr stream, const char *src0, const char *src1, char *dst,\n    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,\n    uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,\n    int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10,\n    uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1,\n    int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,\n    uint64_t nb3, int32_t dim) {\n  sycl::range<3> gridDim(ne3, ne2, ne1);\n  stream->parallel_for(sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {\n      int64_t i3 = item_ct1.get_group(0);\n      int64_t i2 = item_ct1.get_group(1);\n      int64_t i1 = item_ct1.get_group(2);\n\n      int64_t o[4] = { 0, 0, 0, 0 };\n      o[dim]       = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? 
ne02 : ne03));\n\n      const T * x;\n\n      for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {\n          if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {\n              x = (const T *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);\n          } else {\n              x = (const T *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +\n                                   (i0 - o[0]) * nb10);\n          }\n\n          T *y = (T *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);\n\n          *y = *x;\n      }\n  });\n}\n\ntemplate <typename T>\nvoid concat_impl_sycl(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    const ggml_tensor *  src0   = dst->src[0];\n    const ggml_tensor *  src1   = dst->src[1];\n    queue_ptr            stream = ctx.stream();\n\n    const int32_t dim = ((int32_t *) dst->op_params)[0];\n\n    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {\n        const T * src0_d = (const T *) src0->data;\n        const T * src1_d = (const T *) src1->data;\n        T * dst_d = (T *) dst->data;\n        size_t type_size = elem_size(dst->type);\n        if (dim != 3) {\n            for (int i3 = 0; i3 < dst->ne[3]; i3++) {\n                concat_T_sycl<T>(src0_d + i3 * (src0->nb[3] / type_size), src1_d + i3 * (src1->nb[3] / type_size),\n                                dst_d + i3 * (dst->nb[3] / type_size), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0],\n                                dst->ne[1], dst->ne[2], dim, stream);\n            }\n        } else {\n            const size_t size0 = ggml_nbytes(src0);\n            const size_t size1 = ggml_nbytes(src1);\n\n            SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));\n            SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / type_size, src1_d, size1).wait()));\n        }\n    } else {\n        
concat_T_sycl_non_cont<T>(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data,\n                                 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1],\n                                 src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],\n                                 src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],\n                                 dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);\n    }\n}\n\nvoid ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n\n    switch (dst->type) {\n    case GGML_TYPE_F32:\n        concat_impl_sycl<float>(ctx, dst);\n        break;\n    case GGML_TYPE_I32:\n        concat_impl_sycl<int32_t>(ctx, dst);\n        break;\n    default:\n    GGML_ASSERT(false && \"ggml_sycl_op_concat: unsupported type\");\n    break;\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/concat.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_CONCAT_HPP\n#define GGML_SYCL_CONCAT_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst);\n\n#endif // GGML_SYCL_CONCAT_HPP\n"
  },
  {
    "path": "src/ggml-sycl/conv.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#include \"conv.hpp\"\n\nstatic  void conv_transpose_1d_kernel(\n        const int s0, const int output_size,\n        const int src0_ne0, const int src0_ne1, const int src0_ne2,\n        const int src1_ne0, const int dst_ne0,\n        const float * src0, const float * src1,  float * dst,\n        const sycl::nd_item<3> &item_ct1) {\n    int global_index = item_ct1.get_local_id(2) +\n                       item_ct1.get_group(2) * item_ct1.get_local_range(2);\n    if (global_index >= output_size) {\n        return;\n    }\n\n    int out_index = global_index / dst_ne0;\n\n    float accumulator = 0;\n\n    for (int c = 0; c < src0_ne2; c++) {\n        int idx = global_index % dst_ne0;\n\n        int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);\n        int input_offset = src1_ne0 * c;\n\n        for (int i = 0; i < src1_ne0; i++) {\n            if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {\n                continue;\n            }\n            int weight_idx = idx - i*s0;\n\n            float kernel_weight = src0[kernel_offset + weight_idx];\n            float input_value =  src1[input_offset+i];\n\n            accumulator += kernel_weight * input_value;\n        }\n    }\n    dst[global_index] = accumulator;\n}\n\nstatic void conv_transpose_1d_f32_f32_sycl(\n    const int s0, const int output_size,\n    const int src0_ne0, const int src0_ne1, const int src0_ne2,\n    const int src1_ne0, const int dst_ne0,\n    const float *src0, const float *src1, float *dst,\n    const queue_ptr& stream) {\n\n    const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;\n    const 
sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);\n    const sycl::range<3> block_nums(1, 1, num_blocks);\n    stream->parallel_for(\n        sycl::nd_range<3>(\n            block_nums * block_dims, block_dims),\n        [=](sycl::nd_item<3> item_ct1) {\n            conv_transpose_1d_kernel(\n                s0, output_size,\n                src0_ne0, src0_ne1, src0_ne2,\n                src1_ne0, dst_ne0,\n                src0, src1, dst, item_ct1);\n        });\n}\n\nvoid ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    const ggml_tensor *src0 = dst->src[0];\n    const ggml_tensor *src1 = dst->src[1];\n    const float * src0_d = (const float *)src0->data;\n    const float * src1_d = (const float *)src1->data;\n\n    float * dst_d = (float *)dst->data;\n    dpct::queue_ptr stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n\n    const int32_t * opts = (const int32_t *)dst->op_params;\n\n    const int s0 = opts[0];\n\n    const int64_t output_size = ggml_nelements(dst);\n\n    conv_transpose_1d_f32_f32_sycl(s0, output_size,\n        src0->ne[0], src0->ne[1], src0->ne[2],\n        src1->ne[0], dst->ne[0],\n        src0_d, src1_d, dst_d, stream);\n}\n\n"
  },
  {
    "path": "src/ggml-sycl/conv.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_CONV_HPP\n#define GGML_SYCL_CONV_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst);\n\n#endif // GGML_SYCL_CONV_HPP\n"
  },
  {
    "path": "src/ggml-sycl/convert.cpp",
    "content": "#include \"convert.hpp\"\n#include \"dequantize.hpp\"\n#include \"presets.hpp\"\n\n#if defined(__INTEL_LLVM_COMPILER)\n    #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)\n        #include <sycl/ext/oneapi/bfloat16.hpp>\n        #define GGML_SYCL_HAS_BF16\n    #endif\n#endif\n\ntemplate <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,\n                             const sycl::nd_item<3> &item_ct1) {\n    const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                       item_ct1.get_local_id(2));\n\n    if (i >= k) {\n        return;\n    }\n\n    const int64_t ib = i/qk; // block index\n    const int64_t iqs = (i%qk)/qr; // quant index\n    const int64_t iybs = i - i%qk; // y block start index\n    const int64_t y_offset = qr == 1 ? 1 : qk/2;\n\n    // dequantize\n    dfloat2 v;\n    dequantize_kernel(vx, ib, iqs, v);\n\n    y[iybs + iqs + 0] = v.x();\n    y[iybs + iqs + y_offset] = v.y();\n}\n\ntemplate <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic void dequantize_block_sycl(const void *__restrict__ vx,\n                                  dst_t *__restrict__ y, const int64_t k,\n                                  dpct::queue_ptr stream) {\n    const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n        stream->parallel_for(\n            sycl::nd_range<3>(\n                sycl::range<3>(1, 1, num_blocks) *\n                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),\n                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),\n            [=](sycl::nd_item<3> item_ct1) {\n                dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);\n            
});\n    }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,\n                                     dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n#if QK_K == 256\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 64),\n                                               sycl::range<3>(1, 1, 64)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q2_K(vx, y, item_ct1);\n                             });\n    }\n#else\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q2_K(vx, y, item_ct1);\n                             });\n    }\n\n#endif\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,\n                                     dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n#if QK_K == 256\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 64),\n                                               sycl::range<3>(1, 1, 64)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                      
           dequantize_block_q3_K(vx, y, item_ct1);\n                             });\n    }\n#else\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q3_K(vx, y, item_ct1);\n                             });\n    }\n#endif\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,\n                                     dpct::queue_ptr stream) {\n    const int64_t nb32 = k / 32;\n    const int64_t nb = (k + 255) / 256;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q4_0(vx, y, nb32, item_ct1);\n                             });\n    }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k,\n                                     dpct::queue_ptr stream) {\n\n    dpct::has_capability_or_fail(stream->get_device(),\n                                    {sycl::aspect::fp16});\n\n    int constexpr WARP_K = WARP_SIZE * QK4_0;\n    const int n_warp = (k + WARP_K - 1) / WARP_K;\n    GGML_ASSERT(k % 2 == 0);\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *\n        sycl::range<3>(1, 1, WARP_SIZE),\n        
sycl::range<3>(1, 1, WARP_SIZE)),\n        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{\n            dequantize_block_q4_0_reorder(vx, y, k, item_ct1);\n        });\n\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,\n                                     dpct::queue_ptr stream) {\n    const int64_t nb32 = k / 32;\n    const int64_t nb = (k + 255) / 256;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q4_1(vx, y, nb32, item_ct1);\n                             });\n    }\n}\n\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,\n                                     dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->submit([&](sycl::handler &cgh) {\n            sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);\n            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);\n                             });\n        });\n    }\n}\n\ntemplate <typename dst_t>\nstatic void 
dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    const size_t  local_size  = 32;\n    const size_t  global_size = nb * local_size;\n\n    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });\n\n    stream->submit([&](sycl::handler & cgh) {\n        sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);\n\n        cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),\n                         [=](sycl::nd_item<1> item_ct1) {\n                             dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);\n                         });\n    });\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,\n                                     dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n#if QK_K == 256\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 64),\n                                               sycl::range<3>(1, 1, 64)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q5_K(vx, y, item_ct1);\n                             });\n    }\n#else\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 
dequantize_block_q5_K(vx, y, item_ct1);\n                             });\n    }\n\n#endif\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,\n                                     dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n#if QK_K == 256\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 64),\n                                               sycl::range<3>(1, 1, 64)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q6_K(vx, y, item_ct1);\n                             });\n    }\n#else\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_q6_K(vx, y, item_ct1);\n                             });\n    }\n\n#endif\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n\n    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });\n\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),\n        [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq1_s_sycl(const void *vx, 
dst_t *y, const int64_t k,\n                                        dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_iq1_s(\n                                     vx, y, item_ct1, iq1s_grid_gpu\n                                     );\n                             });\n        });\n    }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,\n                                        dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_iq1_m(\n                                     vx, y, item_ct1, iq1s_grid_gpu\n                                     );\n                             });\n        });\n    }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,\n                                        dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    {\n        
dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_iq2_xxs(\n                                     vx, y, item_ct1, iq2xxs_grid,\n                                     ksigns_iq2xs, kmask_iq2xs);\n                             });\n        });\n    }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,\n                                       dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_iq2_xs(\n                                     vx, y, item_ct1, iq2xs_grid,\n                                     ksigns_iq2xs, kmask_iq2xs);\n                             });\n        });\n    }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,\n                                      dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        
stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_iq2_s(vx, y, item_ct1);\n                             });\n        });\n    }\n}\n\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,\n                                        dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_iq3_xxs(\n                                     vx, y, item_ct1, iq3xxs_grid,\n                                     ksigns_iq2xs, kmask_iq2xs);\n                             });\n        });\n    }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,\n                                        dpct::queue_ptr stream) {\n    const int64_t nb = k / QK_K;\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                                   sycl::range<3>(1, 1, 32),\n                                               
sycl::range<3>(1, 1, 32)),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 dequantize_block_iq3_s(\n                                     vx, y, item_ct1, kmask_iq2xs, iq3s_grid);\n                             });\n        });\n    }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,\n                                       dpct::queue_ptr stream) {\n    const int64_t nb = (k + QK_K - 1) / QK_K;\n#if QK_K == 64\n    dequantize_row_iq4_nl_sycl(vx, y, k, stream);\n#else\n      {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                  cgh.parallel_for(\n                      sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                            sycl::range<3>(1, 1, 32),\n                                        sycl::range<3>(1, 1, 32)),\n                      [=](sycl::nd_item<3> item_ct1) {\n                            dequantize_block_iq4_xs(vx, y, item_ct1);\n                      });\n            });\n      }\n#endif\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,\n                                       dpct::queue_ptr stream) {\n    const int64_t nb = (k + QK_K - 1) / QK_K;\n      {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                  cgh.parallel_for(\n                      sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *\n                                            sycl::range<3>(1, 1, 32),\n                                        sycl::range<3>(1, 1, 32)),\n                      [=](sycl::nd_item<3> item_ct1) {\n                            dequantize_block_iq4_nl(vx, y, 
item_ct1);\n                      });\n            });\n      }\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {\n    const int nb = (k + QK_K - 1) / QK_K;\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),\n        [=](sycl::nd_item<3> item_ct1) {\n            dequantize_block_mxfp4(vx, y, item_ct1);\n        });\n}\n\ntemplate <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02,\n        const int64_t s01, const int64_t s02, const int64_t s03) {\n    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int64_t i00 = 2 * (int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2));\n\n    if (i00 >= ne00) {\n        return;\n    }\n\n    const int64_t i01 = item_ct1.get_group(1);\n    const int64_t i02 = item_ct1.get_group(0) % ne02;\n    const int64_t i03 = item_ct1.get_group(0) / ne02;\n\n    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;\n\n    const int64_t ib = ibx0 + i00/qk; // block index\n    const int64_t iqs = (i00%qk)/qr; // quant index\n    const int64_t iybs = i00 - i00%qk; // y block start index\n    const int64_t y_offset = qr == 1 ? 
1 : qk/2;\n\n    // dequantize\n    #ifdef GGML_SYCL_F16\n        sycl::half2 v;\n    #else\n        sycl::float2 v;\n    #endif\n\n    dequantize_kernel(vx, ib, iqs, v);\n\n    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;\n    y[iy0 + 0]        = ggml_sycl_cast<dst_t>(v.x());\n    y[iy0 + y_offset] = ggml_sycl_cast<dst_t>(v.y());\n}\n\n\ntemplate <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic void dequantize_block_nc_sycl(const void *    vx,\n                                  dst_t *         y,\n                                  const int64_t   ne00,\n                                  const int64_t   ne01,\n                                  const int64_t   ne02,\n                                  const int64_t   ne03,\n                                  const int64_t   s01,\n                                  const int64_t   s02,\n                                  const int64_t   s03,\n                                  dpct::queue_ptr stream) {\n    const dpct::dim3 num_blocks((ne00 + 2 * SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * SYCL_DEQUANTIZE_BLOCK_SIZE), ne01,\n                                ne02 * ne03);\n    stream->parallel_for(sycl::nd_range<3>(num_blocks * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),\n                                           sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             GGML_UNUSED(item_ct1);\n                             dequantize_block_nc<qk, qr, dequantize_kernel>(vx, y, ne00, ne01, ne02, s01, s02, s03);\n                         });\n}\ntemplate <typename src_t, typename dst_t>\nstatic void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,\n                          const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,\n                          const sycl::nd_item<3> & item_ct1) {\n\n    const 
int64_t work_group_size = item_ct1.get_local_range(2);\n    const int64_t global_id       = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);\n\n    const int64_t i01 = item_ct1.get_group(1);\n    const int64_t i02 = item_ct1.get_group(0) % ne02;\n    const int64_t i03 = item_ct1.get_group(0) / ne02;\n\n    // make each work-item deal with more elements since sycl global range can not exceed max int\n    const src_t * x = static_cast<const src_t *>(vx);\n    const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;\n    const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;\n\n#pragma unroll\n    for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {\n        y[iy + i00] = static_cast<dst_t>(x[ix + i00]);\n    }\n}\n\ntemplate <typename src_t, typename dst_t>\nstatic void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,\n                                  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n                                  const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {\n    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });\n\n    sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));\n\n    // decrease global range when it exceeds the max int\n    // TODO: Downsample logic is separated from the kernel, a rewrite is desirable\n    int64_t        downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);\n    sycl::range<3> workgroup_size(1, 1, downsized_workgroup);\n\n    queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {\n        convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);\n    });\n}\n\ntemplate <typename src_t, typename dst_t>\nstatic void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, 
dpct::queue_ptr queue) {\n    convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);\n}\n\n\nto_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n            if (dst->src[0]->extra &&\n                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {\n                return dequantize_row_q4_0_sycl_reorder;\n            } else {\n                return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;\n            }\n        case GGML_TYPE_Q4_1:\n            return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;\n        case GGML_TYPE_Q5_0:\n            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;\n        case GGML_TYPE_Q5_1:\n            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;\n        case GGML_TYPE_Q8_0:\n            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;\n        case GGML_TYPE_Q2_K:\n            return dequantize_row_q2_K_sycl;\n        case GGML_TYPE_Q3_K:\n            return dequantize_row_q3_K_sycl;\n        case GGML_TYPE_Q4_K:\n            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {\n                return dequantize_row_q4_K_sycl_reorder;\n            } else {\n                return dequantize_row_q4_K_sycl;\n            }\n        case GGML_TYPE_Q5_K:\n            return dequantize_row_q5_K_sycl;\n        case GGML_TYPE_Q6_K:\n            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {\n                return dequantize_row_q6_K_sycl_reorder;\n            } else {\n                return dequantize_row_q6_K_sycl;\n            }\n        case GGML_TYPE_IQ1_S:\n            return dequantize_row_iq1_s_sycl;\n        case GGML_TYPE_IQ1_M:\n            return dequantize_row_iq1_m_sycl;\n        case GGML_TYPE_IQ2_XXS:\n            return dequantize_row_iq2_xxs_sycl;\n   
     case GGML_TYPE_IQ2_XS:\n            return dequantize_row_iq2_xs_sycl;\n        case GGML_TYPE_IQ2_S:\n            return dequantize_row_iq2_s_sycl;\n        case GGML_TYPE_IQ3_XXS:\n            return dequantize_row_iq3_xxs_sycl;\n        case GGML_TYPE_IQ3_S:\n            return dequantize_row_iq3_s_sycl;\n        case GGML_TYPE_IQ4_XS:\n            return dequantize_row_iq4_xs_sycl;\n        case GGML_TYPE_IQ4_NL:\n            return dequantize_row_iq4_nl_sycl;\n        case GGML_TYPE_MXFP4:\n            return dequantize_row_mxfp4_sycl;\n        case GGML_TYPE_F32:\n            return convert_unary_sycl<float>;\n#ifdef GGML_SYCL_HAS_BF16\n        case GGML_TYPE_BF16:\n            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;\n#endif\n        default:\n            return nullptr;\n    }\n}\n\nto_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n            if (dst->src[0]->extra &&\n                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {\n                return dequantize_row_q4_0_sycl_reorder;\n            } else {\n                return dequantize_row_q4_0_sycl;\n            }\n        case GGML_TYPE_Q4_1:\n            return dequantize_row_q4_1_sycl;\n        case GGML_TYPE_Q5_0:\n            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;\n        case GGML_TYPE_Q5_1:\n            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;\n        case GGML_TYPE_Q8_0:\n            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;\n        case GGML_TYPE_Q2_K:\n            return dequantize_row_q2_K_sycl;\n        case GGML_TYPE_Q3_K:\n            return dequantize_row_q3_K_sycl;\n        case GGML_TYPE_Q4_K:\n            if (dst->src[0]->extra &&\n                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {\n                return dequantize_row_q4_K_sycl_reorder;\n            } 
else {\n                return dequantize_row_q4_K_sycl;\n            }\n        case GGML_TYPE_Q5_K:\n            return dequantize_row_q5_K_sycl;\n        case GGML_TYPE_Q6_K:\n            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {\n                return dequantize_row_q6_K_sycl_reorder;\n            } else {\n                return dequantize_row_q6_K_sycl;\n            }\n        case GGML_TYPE_IQ1_S:\n            return dequantize_row_iq1_s_sycl;\n        case GGML_TYPE_IQ1_M:\n            return dequantize_row_iq1_m_sycl;\n        case GGML_TYPE_IQ2_XXS:\n            return dequantize_row_iq2_xxs_sycl;\n        case GGML_TYPE_IQ2_XS:\n            return dequantize_row_iq2_xs_sycl;\n        case GGML_TYPE_IQ2_S:\n            return dequantize_row_iq2_s_sycl;\n        case GGML_TYPE_IQ3_XXS:\n            return dequantize_row_iq3_xxs_sycl;\n        case GGML_TYPE_IQ3_S:\n            return dequantize_row_iq3_s_sycl;\n        case GGML_TYPE_IQ4_XS:\n            return dequantize_row_iq4_xs_sycl;\n        case GGML_TYPE_IQ4_NL:\n            return dequantize_row_iq4_nl_sycl;\n        case GGML_TYPE_MXFP4:\n            return dequantize_row_mxfp4_sycl;\n        case GGML_TYPE_F16:\n            return convert_unary_sycl<sycl::half>;\n#ifdef GGML_SYCL_HAS_BF16\n        case GGML_TYPE_BF16:\n            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;\n#endif\n        default:\n            return nullptr;\n    }\n}\n\n\nto_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_F32:\n            return convert_unary_nc_sycl<float>;\n#ifdef GGML_SYCL_HAS_BF16\n        case GGML_TYPE_BF16:\n            return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;\n#endif\n        case GGML_TYPE_Q4_0:\n            return dequantize_block_nc_sycl<QK4_0, QR4_0, dequantize_q4_0>;\n        case GGML_TYPE_Q4_1:\n            return dequantize_block_nc_sycl<QK4_1, 
QR4_1, dequantize_q4_1>;\n        case GGML_TYPE_Q5_0:\n            return dequantize_block_nc_sycl<QK5_0, QR5_0, dequantize_q5_0>;\n        case GGML_TYPE_Q5_1:\n            return dequantize_block_nc_sycl<QK5_1, QR5_1, dequantize_q5_1>;\n        case GGML_TYPE_Q8_0:\n            return dequantize_block_nc_sycl<QK8_0, QR8_0, dequantize_q8_0>;\n        default:\n            return nullptr;\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/convert.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2025 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_CONVERT_HPP\n#define GGML_SYCL_CONVERT_HPP\n\n#include \"common.hpp\"\n\ntemplate <typename T>\nusing to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream);\ntypedef to_t_sycl_t<float>      to_fp32_sycl_t;\ntypedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;\n\nto_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);\nto_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);\n\n// Nc = Non-contiguous\ntemplate <typename T>\nusing to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,\n                                   int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);\n\ntypedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;\nto_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type);\n\ntemplate<typename dst_t, typename src_t>\n inline dst_t ggml_sycl_cast(src_t x) {\n    if constexpr (std::is_same_v<dst_t, src_t>) {\n        return x;\n    } else if constexpr (std::is_same_v<dst_t, sycl::ext::oneapi::bfloat16>) {\n        return sycl::ext::oneapi::bfloat16(float(x));\n    } else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {\n        return static_cast<float>(x);\n    } else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) {\n        return x.template convert<sycl::half, sycl::rounding_mode::rte>();\n    } else if constexpr (std::is_same_v<src_t, sycl::float2> &&\n                         std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {\n        return {x.x, x.y};\n    } else if 
constexpr(std::is_same_v<dst_t, int32_t>) {\n        return int32_t(x);\n    } else {\n        return float(x);\n    }\n}\n\n\n#endif  // GGML_SYCL_CONVERT_HPP\n"
  },
  {
    "path": "src/ggml-sycl/count-equal.cpp",
    "content": "#include \"count-equal.hpp\"\n\n#include <cstdint>\n\ntemplate <typename T>\nstatic void count_equal(const T *__restrict__ x, const T *__restrict__ y,\n                        int64_t *__restrict__ dst, const int64_t dk,\n                        const int64_t k) {\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;\n    const int64_t i1 = sycl::min(i0 + dk, k);\n\n    int nequal = 0;\n\n    for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {\n        const T xi = x[i];\n        const T yi = y[i];\n        nequal += xi == yi;\n    }\n\n    nequal = warp_reduce_sum<WARP_SIZE>(nequal);\n\n    if (item_ct1.get_local_id(2) != 0) {\n        return;\n    }\n\n    dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(\n        (int *)dst, nequal);\n}\n\nvoid ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == src1->type);\n    GGML_ASSERT( dst->type == GGML_TYPE_I64);\n\n    GGML_ASSERT(ggml_are_same_shape(src0, src1));\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    int64_t * dst_d  = (int64_t *) dst->data;\n\n    dpct::queue_ptr stream = ctx.stream();\n    const int id       = get_current_device_id();\n    const int nsm = ggml_sycl_info().devices[id].nsm;\n\n    const int64_t ne = ggml_nelements(src0);\n    GGML_ASSERT(ne < (1 << 30) && \"atomicAdd implementation only supports int\");\n    const int64_t dne =\n        GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);\n\n    SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));\n\n    const dpct::dim3 block_dims(WARP_SIZE, 1, 1);\n    const 
dpct::dim3 block_nums(\n        std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /\n                                       SYCL_COUNT_EQUAL_CHUNK_SIZE),\n        1, 1);\n\n    switch (src0->type) {\n    case GGML_TYPE_I32: {\n        const int *src0_d = (const int *)src0->data;\n        const int *src1_d = (const int *)src1->data;\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                count_equal(src0_d, src1_d, dst_d, dne, ne);\n                GGML_UNUSED(item_ct1);\n            });\n\n    } break;\n    default:\n        GGML_ASSERT(false);\n        break;\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/count-equal.hpp",
    "content": "#ifndef GGML_SYCL_COUNT_EQUAL_HPP\n#define GGML_SYCL_COUNT_EQUAL_HPP\n#include \"common.hpp\"\n\n#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128\n\nvoid ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif //GGML_SYCL_COUNT_EQUAL_HPP\n"
  },
  {
    "path": "src/ggml-sycl/cpy.cpp",
    "content": "#include \"cpy.hpp\"\n\n#include <float.h>\n\n#include \"dequantize.hpp\"\n#include \"ggml-sycl/common.hpp\"\n#include \"ggml-sycl/presets.hpp\"\n#include \"ggml.h\"\n\n\nstatic void cpy_1_f32_f32(const char * cxi, char * cdsti) {\n    const float * xi   = (const float *) cxi;\n    float *       dsti = (float *) cdsti;\n\n    *dsti = *xi;\n}\n\nstatic void cpy_1_f32_f16(const char * cxi, char * cdsti) {\n    const float * xi   = (const float *) cxi;\n    sycl::half *  dsti = (sycl::half *) cdsti;\n\n    *dsti = sycl::vec<float, 1>(*xi).convert<sycl::half, sycl::rounding_mode::automatic>()[0];\n}\n\nstatic void cpy_1_f16_f16(const char * cxi, char * cdsti) {\n    const sycl::half * xi   = (const sycl::half *) cxi;\n    sycl::half *       dsti = (sycl::half *) cdsti;\n\n    *dsti = *xi;\n}\n\nstatic void cpy_1_f16_f32(const char * cxi, char * cdsti) {\n    const sycl::half * xi   = (const sycl::half *) cxi;\n    float *            dsti = (float *) cdsti;\n\n    *dsti = *xi;\n}\n\nstatic void cpy_1_i16_i16(const char * cxi, char * cdsti) {\n    const int16_t * xi   = (const int16_t *) cxi;\n    int16_t *       dsti = (int16_t *) cdsti;\n\n    *dsti = *xi;\n}\n\nstatic void cpy_1_i32_i32(const char * cxi, char * cdsti) {\n    const int32_t * xi   = (const int32_t *) cxi;\n    int32_t *       dsti = (int32_t *) cdsti;\n\n    *dsti = *xi;\n}\n\ntemplate <cpy_kernel_t cpy_1>\nstatic void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,\n                        const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,\n                        const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,\n                        const sycl::nd_item<3> & item_ct1) {\n    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);\n\n    if (i >= ne) {\n        return;\n    }\n\n    // determine indices i02/i12, 
i01/i11, i00/i10 as a function of index i of flattened tensor\n    // then combine those indices with the corresponding byte offsets to get the total offsets\n    const int i03      = i / (ne00 * ne01 * ne02);\n    const int i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);\n    const int i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;\n    const int i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;\n    const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;\n\n    const int i13        = i / (ne10 * ne11 * ne12);\n    const int i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);\n    const int i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;\n    const int i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;\n    const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;\n\n    cpy_1(cx + x_offset, cdst + dst_offset);\n}\n\n\n/* quantized type same copy */\ntemplate<typename T>\nstatic void cpy_blck_q_q(const char * cxi, char * cdsti) {\n    const T * xi = (const T *) cxi;\n    T * dsti = (T *) cdsti;\n    *dsti = *xi;\n}\n\n\nstatic void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {\n    float * cdstf = (float *) (cdsti);\n\n    for (int j = 0; j < QK8_0; j += 2) {\n        dfloat2 dq;\n        dequantize_q8_0(cxi, 0, j, dq);\n        *(cdstf + j)     = dq.x();\n        *(cdstf + j + 1) = dq.y();\n    }\n}\n\n\n\ntemplate <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const char * cxi, char * cdsti) {\n    float * cdstf = (float *) (cdsti);\n\n    for (int j = 0; j < qk / 2; j++) {\n        dfloat2 dq;\n        dequant(cxi, 0, j, dq);\n        *(cdstf + j)          = dq.x();\n        *(cdstf + j + qk / 2) = dq.y();\n    }\n}\n\n\ntemplate <typename T, int qk>\nstatic void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,\n                      const 
int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,\n                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,\n                      const sycl::nd_item<3> & item_ct1) {\n    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;\n\n    if (i >= ne) {\n        return;\n    }\n\n    const int i03      = i / (ne00 * ne01 * ne02);\n    const int i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);\n    const int i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;\n    const int i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;\n    const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;\n\n\n    const int i13        = i / (ne10 * ne11 * ne12);\n    const int i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);\n    const int i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;\n    const int i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;\n    const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;\n\n    cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);\n}\n\ntemplate <cpy_kernel_t cpy_blck, int qk>\nstatic void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,\n                      const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,\n                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,\n                      const sycl::nd_item<3> & item_ct1) {\n    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;\n\n    if (i >= ne) {\n        return;\n    }\n\n\n    const int i03      = i / (ne00 * ne01 * ne02);\n    const int i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);\n    const int i01      = 
(i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;\n    const int i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;\n    const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;\n\n    const int i13        = i / (ne10 * ne11 * ne12);\n    const int i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);\n    const int i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;\n    const int i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;\n    const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;\n\n    cpy_blck(cx + x_offset, cdst + dst_offset);\n}\n\ntemplate <cpy_kernel_t cpy_blck, int qk>\nstatic void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,\n                      const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,\n                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,\n                      const sycl::nd_item<3> & item_ct1) {\n    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;\n\n    if (i >= ne) {\n        return;\n    }\n\n    const int i03      = i / (ne00 * ne01 * ne02);\n    const int i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);\n    const int i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;\n    const int i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;\n    const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;\n\n    const int i13        = i / (ne10 * ne11 * ne12);\n    const int i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);\n    const int i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;\n    const int i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;\n    const int dst_offset = i10 * 
nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;\n\n    cpy_blck(cx + x_offset, cdst + dst_offset);\n}\n\nstatic void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                  const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;\n    {\n        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });\n\n        stream->parallel_for(\n            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),\n            [=](sycl::nd_item<3> item_ct1) {\n                cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,\n                                           nb10, nb11, nb12, nb13, item_ct1);\n            });\n    }\n}\n\nstatic void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                  const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;\n    {\n        // Plain f32 -> f32 copy: no sycl::half values are read or written, so the device\n        // does not need sycl::aspect::fp16 (checking it here would reject fp16-less devices).\n        stream->parallel_for(\n            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, 
SYCL_CPY_BLOCK_SIZE)),\n            [=](sycl::nd_item<3> item_ct1) {\n                cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,\n                                           nb10, nb11, nb12, nb13, item_ct1);\n            });\n    }\n}\n\nstatic void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                  const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;\n    {\n        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });\n\n        stream->parallel_for(\n            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),\n            [=](sycl::nd_item<3> item_ct1) {\n                cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,\n                                           nb10, nb11, nb12, nb13, item_ct1);\n            });\n    }\n}\n\nstatic void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    GGML_ASSERT(ne % QK8_0 == 0);\n    const int num_blocks = ne / QK8_0;\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),\n  
                       [=](sycl::nd_item<3> item_ct1) {\n                             cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n                         });\n}\n\nstatic void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ne;\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n                         });\n}\n\nstatic void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    GGML_ASSERT(ne % QK4_0 == 0);\n    const int num_blocks = ne / QK4_0;\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, 
ne02, nb00, nb01, nb02, nb03,\n                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n                         });\n}\n\nstatic void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ne;\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,\n                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,\n                                                                     item_ct1);\n        });\n}\n\nstatic void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    GGML_ASSERT(ne % QK4_1 == 0);\n    const int num_blocks = ne / QK4_1;\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n                                                                 ne10, 
ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n                         });\n}\n\nstatic void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ne;\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,\n                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,\n                                                                     item_ct1);\n        });\n}\n\nstatic void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    GGML_ASSERT(ne % QK5_0 == 0);\n    const int num_blocks = ne / QK5_0;\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n                         });\n}\n\nstatic void 
ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ne;\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,\n                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,\n                                                                     item_ct1);\n        });\n}\n\nstatic void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    GGML_ASSERT(ne % QK5_1 == 0);\n    const int num_blocks = ne / QK5_1;\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,\n                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n                         });\n}\n\nstatic void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n      
                             const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ne;\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,\n                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,\n                                                                     item_ct1);\n        });\n}\n\nstatic void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                     const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                     const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                     const int nb12, const int nb13, queue_ptr stream) {\n    GGML_ASSERT(ne % QK4_NL == 0);\n    const int num_blocks = ne / QK4_NL;\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,\n                                                   ne12, nb10, nb11, nb12, nb13, item_ct1);\n        });\n}\n\nstatic void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                  const int 
ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                  const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;\n    {\n        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });\n\n        stream->parallel_for(\n            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),\n            [=](sycl::nd_item<3> item_ct1) {\n                cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,\n                                           nb10, nb11, nb12, nb13, item_ct1);\n            });\n    }\n}\n\nstatic void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                  const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;\n    {\n        // dpct::has_capability_or_fail(stream->get_device(),\n        //                              {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),\n            [=](sycl::nd_item<3> item_ct1) {\n                cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,\n                                           nb10, nb11, nb12, nb13, item_ct1);\n            });\n    }\n}\n\nstatic void ggml_cpy_i32_i32_sycl(const char 
* cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                  const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;\n    {\n        // dpct::has_capability_or_fail(stream->get_device(),\n        //                              {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),\n            [=](sycl::nd_item<3> item_ct1) {\n                cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,\n                                           nb10, nb11, nb12, nb13, item_ct1);\n            });\n    }\n}\n\nstatic void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n        });\n}\n\n\nstatic void 
ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n        });\n}\n\n\nstatic void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);\n\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),\n                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n        });\n}\n\n\nstatic void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, 
const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n    const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n        });\n}\n\n\nstatic void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,\n                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,\n                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,\n                                   const int nb12, const int nb13, queue_ptr stream) {\n\n   const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);\n   stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {\n            cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);\n        });\n}\n\nvoid ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {\n    // Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field\n    scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0, debug_get_tensor_str(\"\\tsrc0\", src0));\n    const int64_t ne = ggml_nelements(src0);\n    GGML_ASSERT(ne == ggml_nelements(src1));\n\n    
GGML_TENSOR_BINARY_OP_LOCALS01;\n\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    queue_ptr main_stream = ctx.stream();\n\n    char * src0_ddc = (char *) src0->data;\n    char * src1_ddc = (char *) src1->data;\n    if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) {\n        GGML_SYCL_DEBUG(\"%s: memcpy path\\n\", __func__);\n        main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0));\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                              nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {\n        ggml_cpy_f32_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                              nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {\n        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {\n        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {\n        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_f16_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                              nb11, 
nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {\n        ggml_cpy_f16_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                              nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {\n        ggml_cpy_i16_i16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                              nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {\n        ggml_cpy_i32_i32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                              nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q4_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q4_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q8_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {\n        ggml_cpy_f32_q5_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q5_0_f32_sycl(src0_ddc, 
src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {\n        ggml_cpy_f32_q5_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {\n        ggml_cpy_q5_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,\n                               nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {\n        ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,\n                                 nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {\n        ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) {\n        ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) {\n        ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) {\n        ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) {\n        ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, 
ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);\n    } else {\n        GGML_LOG_ERROR(\"%s: unsupported type combination (%s to %s)\\n\", __func__, ggml_type_name(src0->type),\n                       ggml_type_name(src1->type));\n        GGML_ABORT(\"fatal error\");\n    }\n} catch (const sycl::exception & exc) {\n    std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__ << \", line:\" << __LINE__ << std::endl;\n    std::exit(1);\n}\n\nvoid ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_cpy(ctx, dst->src[0], dst);\n}\n"
  },
  {
    "path": "src/ggml-sycl/cpy.hpp",
    "content": "#ifndef GGML_SYCL_CPY_HPP\n#define GGML_SYCL_CPY_HPP\n\n#include \"common.hpp\"\n#include <float.h>\n\ntypedef void (*cpy_kernel_t)(const char * cx, char * cdst);\n\n__dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {\n    if (x <= val[0]) {\n        return 0;\n    }\n    if (x >= val[n - 1]) {\n        return n - 1;\n    }\n    int ml = 0, mu = n - 1;\n    while (mu - ml > 1) {\n        int mav = (ml + mu) / 2;\n        if (x < val[mav]) {\n            mu = mav;\n        } else {\n            ml = mav;\n        }\n    }\n    return x - val[mu - 1] < val[mu] - x ? mu - 1 : mu;\n}\n\ninline void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {\n    const float * xi   = (const float *) cxi;\n    block_q8_0 *  dsti = (block_q8_0 *) cdsti;\n\n    float amax = 0.0f;  // absolute max\n\n    for (int j = 0; j < QK8_0; j++) {\n        const float v = xi[j];\n        amax          = sycl::fmax(amax, sycl::fabs((float) v));\n    }\n\n    const float d  = amax / ((1 << 7) - 1);\n    const float id = d ? 1.0f / d : 0.0f;\n\n    dsti->d = d;\n\n    for (int j = 0; j < QK8_0; ++j) {\n        const float x0 = xi[j] * id;\n\n        dsti->qs[j] = sycl::round((float) x0);\n    }\n}\n\ninline void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {\n    const float * xi   = (const float *) cxi;\n    block_q4_0 *  dsti = (block_q4_0 *) cdsti;\n\n    float amax = 0.0f;\n    float vmax = 0.0f;\n\n    for (int j = 0; j < QK4_0; ++j) {\n        const float v = xi[j];\n        if (amax < sycl::fabs((float) v)) {\n            amax = sycl::fabs((float) v);\n            vmax = v;\n        }\n    }\n\n    const float d  = vmax / -8;\n    const float id = d ? 
1.0f / d : 0.0f;\n\n    dsti->d = d;\n\n    for (int j = 0; j < QK4_0 / 2; ++j) {\n        const float x0 = xi[0 + j] * id;\n        const float x1 = xi[QK4_0 / 2 + j] * id;\n\n        const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 8.5f));\n        const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 8.5f));\n\n        dsti->qs[j] = xi0;\n        dsti->qs[j] |= xi1 << 4;\n    }\n}\n\ninline void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {\n    const float * xi   = (const float *) cxi;\n    block_q4_1 *  dsti = (block_q4_1 *) cdsti;\n\n    float vmin = FLT_MAX;\n    float vmax = -FLT_MAX;\n\n    for (int j = 0; j < QK4_1; ++j) {\n        const float v = xi[j];\n\n        vmin = sycl::min(v, vmin);\n        vmax = sycl::max(v, vmax);\n    }\n\n    const float d  = (vmax - vmin) / ((1 << 4) - 1);\n    const float id = d ? 1.0f / d : 0.0f;\n\n    dsti->dm.x() = d;\n    dsti->dm.y() = vmin;\n\n    for (int j = 0; j < QK4_1 / 2; ++j) {\n        const float x0 = (xi[0 + j] - vmin) * id;\n        const float x1 = (xi[QK4_1 / 2 + j] - vmin) * id;\n\n        const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 0.5f));\n        const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 0.5f));\n\n        dsti->qs[j] = xi0;\n        dsti->qs[j] |= xi1 << 4;\n    }\n}\n\ninline void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {\n    const float * xi   = (const float *) cxi;\n    block_q5_0 *  dsti = (block_q5_0 *) cdsti;\n\n    float amax = 0.0f;\n    float vmax = 0.0f;\n\n    for (int j = 0; j < QK5_0; ++j) {\n        const float v = xi[j];\n        if (amax < sycl::fabs((float) v)) {\n            amax = sycl::fabs((float) v);\n            vmax = v;\n        }\n    }\n\n    const float d  = vmax / -16;\n    const float id = d ? 
1.0f / d : 0.0f;\n\n    dsti->d = d;\n\n    uint32_t qh = 0;\n    for (int j = 0; j < QK5_0 / 2; ++j) {\n        const float x0 = xi[0 + j] * id;\n        const float x1 = xi[QK5_0 / 2 + j] * id;\n\n        const uint8_t xi0 = dpct::min(31, (int8_t) (x0 + 16.5f));\n        const uint8_t xi1 = dpct::min(31, (int8_t) (x1 + 16.5f));\n\n        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);\n        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0 / 2);\n    }\n    memcpy(dsti->qh, &qh, sizeof(qh));\n}\n\ninline void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {\n    const float * xi   = (const float *) cxi;\n    block_q5_1 *  dsti = (block_q5_1 *) cdsti;\n\n    float min = xi[0];\n    float max = xi[0];\n\n    for (int j = 1; j < QK5_1; ++j) {\n        const float v = xi[j];\n        min           = v < min ? v : min;\n        max           = v > max ? v : max;\n    }\n\n    const float d  = (max - min) / 31;\n    const float id = d ? 1.0f / d : 0.0f;\n\n    dsti->dm.x() = d;\n    dsti->dm.y() = min;\n\n    uint32_t qh = 0;\n    for (int j = 0; j < QK5_1 / 2; ++j) {\n        const float x0 = (xi[0 + j] - min) * id;\n        const float x1 = (xi[QK5_1 / 2 + j] - min) * id;\n\n        const uint8_t xi0 = (uint8_t) (x0 + 0.5f);\n        const uint8_t xi1 = (uint8_t) (x1 + 0.5f);\n\n        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);\n        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1 / 2);\n    }\n    memcpy(dsti->qh, &qh, sizeof(qh));\n}\n\ninline void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {\n    const float *  xi   = (const float *) cxi;\n    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;\n\n    float amax = 0.0f;\n    float vmax = 0.0f;\n\n    for (int j = 0; j < QK4_NL; ++j) {\n        const float v = xi[j];\n        if (amax < sycl::fabs((float) v)) {\n            amax = sycl::fabs((float) v);\n            vmax = v;\n        }\n    }\n\n    float       d  = 
vmax / kvalues_iq4nl[0];\n    const float id = d ? 1.0f / d : 0.0f;\n\n    float sumqx = 0, sumq2 = 0;\n    for (int j = 0; j < QK4_NL / 2; ++j) {\n        const float   x0  = xi[0 + j] * id;\n        const float   x1  = xi[QK4_NL / 2 + j] * id;\n        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);\n        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);\n        dsti->qs[j]       = xi0 | (xi1 << 4);\n        const float v0    = kvalues_iq4nl[xi0];\n        const float v1    = kvalues_iq4nl[xi1];\n        const float w0    = xi[0 + j] * xi[0 + j];\n        const float w1    = xi[QK4_NL / 2 + j] * xi[QK4_NL / 2 + j];\n        sumqx += w0 * v0 * xi[j] + w1 * v1 * xi[QK4_NL / 2 + j];\n        sumq2 += w0 * v0 * v0 + w1 * v1 * v1;\n    }\n\n    dsti->d = sumq2 > 0 ? sumqx / sumq2 : d;\n}\n\nvoid ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1);\nvoid ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif  // GGML_SYCL_CPY_HPP\n"
  },
  {
    "path": "src/ggml-sycl/dequantize.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_DEQUANTIZE_HPP\n#define GGML_SYCL_DEQUANTIZE_HPP\n\n#include \"common.hpp\"\n\ntypedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);\ntypedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,\n                                            const int iqs, dfloat2 &v);\n\nstatic __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,\n                                            const int iqs, dfloat2 &v) {\n    const block_q4_0 * x = (const block_q4_0 *) vx;\n\n    const dfloat d = x[ib].d;\n\n    const int vui = x[ib].qs[iqs];\n\n    v.x() = vui & 0xF;\n    v.y() = vui >> 4;\n\n#ifdef GGML_SYCL_F16\n    // v = v - {8.0f, 8.0f};\n    // v = v * {d, d};\n    v.s0() = (v.s0() - 8.0f) * d;\n    v.s1() = (v.s1() - 8.0f) * d;\n\n#else\n    v.x() = (v.x() - 8.0f) * d;\n    v.y() = (v.y() - 8.0f) * d;\n#endif // GGML_SYCL_F16\n}\n\nstatic __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,\n                                            const int iqs, dfloat2 &v) {\n    // const block_q4_0 * x = (const block_q4_0 *) vx;\n\n    const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib);\n\n    const int vui = *((const uint8_t *)qs+iqs);\n\n    v.x() = vui & 0xF;\n    v.y() = vui >> 4;\n\n#ifdef GGML_SYCL_F16\n    // v = v - {8.0f, 8.0f};\n    // v = v * {d, d};\n    v.s0() = (v.s0() - 8.0f) * d;\n    v.s1() = (v.s1() - 8.0f) * d;\n\n#else\n    v.x() = (v.x() - 8.0f) * d;\n    v.y() = (v.y() - 8.0f) * d;\n#endif // GGML_SYCL_F16\n}\n\nstatic __dpct_inline__ void dequantize_q4_1(const void *vx, 
const int64_t ib,\n                                            const int iqs, dfloat2 &v) {\n    const block_q4_1 * x = (const block_q4_1 *) vx;\n\n    const dfloat d = x[ib].dm[0];\n    const dfloat m = x[ib].dm[1];\n\n    const int vui = x[ib].qs[iqs];\n\n    v.x() = vui & 0xF;\n    v.y() = vui >> 4;\n\n#ifdef GGML_SYCL_F16\n    // v = v * {d, d};\n    // v = v + {m, m};\n    v.s0() = sycl::fma(v.s0(), d, m);\n    v.s1() = sycl::fma(v.s1(), d, m);\n\n#else\n    v.x() = sycl::fma(v.x(), d, m);\n    v.y() = sycl::fma(v.y(), d, m);\n#endif // GGML_SYCL_F16\n}\n\nstatic __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,\n                                            const int iqs, dfloat2 &v) {\n    const block_q5_0 * x = (const block_q5_0 *) vx;\n\n    const dfloat d = x[ib].d;\n\n    uint32_t qh;\n    memcpy(&qh, x[ib].qh, sizeof(qh));\n\n    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;\n    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;\n\n    v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);\n    v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);\n\n#ifdef GGML_SYCL_F16\n    // v = v - {16.0f, 16.0f};\n    // v = v * {d, d};\n    v.s0() = (v.s0() - 16.0f) * d;\n    v.s1() = (v.s1() - 16.0f) * d;\n\n#else\n    v.x() = (v.x() - 16.0f) * d;\n    v.y() = (v.y() - 16.0f) * d;\n#endif // GGML_SYCL_F16\n}\n\nstatic __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,\n                                            const int iqs, dfloat2 &v) {\n    const block_q5_1 * x = (const block_q5_1 *) vx;\n\n    const dfloat d = x[ib].dm[0];\n    const dfloat m = x[ib].dm[1];\n\n    uint32_t qh;\n    memcpy(&qh, x[ib].qh, sizeof(qh));\n\n    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;\n    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;\n\n    v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);\n    v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);\n\n#ifdef GGML_SYCL_F16\n    // v = v * {d, d};\n    // v = v + {m, m};\n    v.s0() = sycl::fma(v.s0(), d, m);\n    v.s1() = 
sycl::fma(v.s1(), d, m);\n#else\n    v.x() = sycl::fma(v.x(), d, m);\n    v.y() = sycl::fma(v.y(), d, m);\n#endif // GGML_SYCL_F16\n}\n\nstatic __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,\n                                            const int iqs, dfloat2 &v) {\n    const block_q8_0 * x = (const block_q8_0 *) vx;\n\n    const dfloat d = x[ib].d;\n\n    v.x() = x[ib].qs[iqs + 0];\n    v.y() = x[ib].qs[iqs + 1];\n\n#ifdef GGML_SYCL_F16\n    // v = v * {d, d};\n    v.s0() *= d;\n    v.s1() *= d;\n#else\n    v.x() *= d;\n    v.y() *= d;\n#endif // GGML_SYCL_F16\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,\n                                  const sycl::nd_item<3> &item_ct1) {\n\n    const int64_t i = item_ct1.get_group(2);\n\n    // assume 32 threads\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t il  = tid/8;\n    const int64_t ir  = tid%8;\n    const int64_t ib = 8*i + ir;\n    if (ib >= nb32) {\n        return;\n    }\n\n    dst_t * y = yy + 256*i + 32*ir + 4*il;\n\n    const block_q4_0 * x = (const block_q4_0 *)vx + ib;\n    const float d = sycl::vec<sycl::half, 1>(x->d)\n                        .convert<float, sycl::rounding_mode::automatic>()[0];\n    const float dm = -8*d;\n\n    const uint8_t * q = x->qs + 4*il;\n\n    for (int l = 0; l < 4; ++l) {\n        y[l+ 0] = d * (q[l] & 0xF) + dm;\n        y[l+16] = d * (q[l] >>  4) + dm;\n    }\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,\n                                  const sycl::nd_item<3> &item_ct1) {\n\n    const int64_t i = item_ct1.get_group(2);\n    auto k=nb32;\n    // assume 32 threads\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int lane_ib = i * WARP_SIZE + tid;\n\n    if (lane_ib >= k / QK4_0) {\n        return;\n    }\n\n    dst_t * y_ptr = yy + lane_ib 
* QK4_0;\n\n    auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2;\n    auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;\n\n    const float d = float(*s_ptr);\n\n#pragma unroll\n    for (int l = 0; l < QK4_0 / 2; ++l) {\n        int vq = qs[l];\n        y_ptr[l + 0] = d * ((vq & 0xF) - 8);\n        y_ptr[l + 16] = d * ((vq >> 4) - 8);\n    }\n\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,\n                                  const sycl::nd_item<3> &item_ct1) {\n\n    const int64_t i = item_ct1.get_group(2);\n\n    // assume 32 threads\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t il  = tid/8;\n    const int64_t ir  = tid%8;\n    const int64_t ib = 8*i + ir;\n    if (ib >= nb32) {\n        return;\n    }\n\n    dst_t * y = yy + 256*i + 32*ir + 4*il;\n\n    const block_q4_1 * x = (const block_q4_1 *)vx + ib;\n    const sycl::float2 d =\n        x->dm.convert<float, sycl::rounding_mode::automatic>();\n\n    const uint8_t * q = x->qs + 4*il;\n\n    for (int l = 0; l < 4; ++l) {\n        y[l + 0] = d.x() * (q[l] & 0xF) + d.y();\n        y[l + 16] = d.x() * (q[l] >> 4) + d.y();\n    }\n}\n\n\n//================================== k-quants\n\ntemplate<typename dst_t>\nstatic void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                  const sycl::nd_item<3> &item_ct1) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_q2_K * x = (const block_q2_K *) vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n#if QK_K == 256\n    const int64_t n   = tid/32;\n    const int64_t l   = tid - 32*n;\n    const int64_t is  = 8*n + l/16;\n\n    const uint8_t q = x[i].qs[32*n + l];\n    dst_t * y = yy + i*QK_K + 128*n;\n\n    float dall = x[i].dm[0];\n    float dmin = x[i].dm[1];\n    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);\n   
 y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);\n    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);\n    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);\n#else\n    const int64_t is = tid/16;  // 0 or 1\n    const int64_t il = tid%16;  // 0...15\n    const uint8_t q = x[i].qs[il] >> (2*is);\n    dst_t * y = yy + i*QK_K + 16*is + il;\n\n    float dall = x[i].dm[0];\n    float dmin = x[i].dm[1];\n    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);\n    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);\n#endif\n\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                  const sycl::nd_item<3> &item_ct1) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_q3_K * x = (const block_q3_K *) vx;\n\n#if QK_K == 256\n    const int64_t r = item_ct1.get_local_id(2) / 4;\n    const int64_t tid = r/2;\n    const int64_t is0 = r%2;\n    const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);\n    const int64_t n = tid / 4;\n    const int64_t j = tid - 4*n;\n\n    uint8_t m = 1 << (4*n + j);\n    int64_t is = 8*n + 2*j + is0;\n    int shift = 2*j;\n\n    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :\n                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :\n                is < 12 ? 
(x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :\n                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);\n    float d_all = x[i].d;\n    float dl = d_all * (us - 32);\n\n    dst_t * y = yy + i*QK_K + 128*n + 32*j;\n    const uint8_t * q = x[i].qs + 32*n;\n    const uint8_t * hm = x[i].hmask;\n\n    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));\n#else\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t is  = tid/16;  // 0 or 1\n    const int64_t il  = tid%16;  // 0...15\n    const int64_t im  = il/8;    // 0...1\n    const int64_t in  = il%8;    // 0...7\n\n    dst_t * y = yy + i*QK_K + 16*is + il;\n\n    const uint8_t q = x[i].qs[il] >> (2*is);\n    const uint8_t h = x[i].hmask[in] >> (2*is + im);\n    const float   d = (float)x[i].d;\n\n    if (is == 0) {\n        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));\n        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));\n    } else {\n        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));\n        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 
0 : 4));\n    }\n#endif\n\n}\n\n#if QK_K == 256\nstatic inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {\n    if (j < 4) {\n        d = q[j] & 63;\n        m = q[j + 4] & 63;\n    } else {\n        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);\n        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);\n    }\n}\n#endif\n\ntemplate <typename dst_t>\ninline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,\n                                   const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {\n    const int is = 2 * il;\n    constexpr int n  = 4;\n\n    uint8_t sc, m;\n    get_scale_min_k4(is + 0, scales_local, sc, m);\n    const float d1 = dall * sc;\n    const float m1 = dmin * m;\n\n    get_scale_min_k4(is + 1, scales_local, sc, m);\n    const float d2 = dall * sc;\n    const float m2 = dmin * m;\n\n    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);\n    for (int l = 0; l < n; ++l) {\n        y[l + 0]  = d1 * (q_vec[l] & 0xF) - m1;\n        y[l + 32] = d2 * (q_vec[l] >> 4) - m2;\n    }\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                  uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {\n    const block_q4_K * x = (const block_q4_K *) vx;\n\n    const int64_t i = item_ct1.get_group(2);\n\n#if QK_K == 256\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t il  = tid / 8;\n    const int64_t ir  = tid % 8;\n\n    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;\n\n    const sycl::half2 dm = x[i].dm;\n    const float dall = dm[0];\n    const float dmin = dm[1];\n\n    if (tid < 12) {\n        scales_local[tid] = x[i].scales[tid];\n    }\n\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n    dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);\n#else\n    const int64_t tid 
= item_ct1.get_local_id(2);\n    const uint8_t * q = x[i].qs;\n    dst_t * y = yy + i*QK_K;\n    const float d = (float)x[i].dm[0];\n    const float m = (float)x[i].dm[1];\n    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);\n    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);\n#endif\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,\n                                          const sycl::nd_item<1> & item_ct1, int64_t nb) {\n    const int64_t i   = item_ct1.get_group(0);     // block index\n    const int64_t tid = item_ct1.get_local_id(0);  // thread index within block\n    const int64_t il  = tid / 8;\n    const int64_t ir  = tid % 8;\n\n    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;\n\n    const uint8_t * base          = static_cast<const uint8_t *>(vx);\n    const size_t    qs_offset     = i * (QK_K / 2);\n    const size_t    scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;\n    const size_t    dm_offset     = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);\n\n    const uint8_t *    qs_ptr     = base + qs_offset;\n    const uint8_t *    scales_ptr = base + scales_offset;\n    ggml_half2         dm_values  = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);\n\n    const float dall = dm_values.x();\n    const float dmin = dm_values.y();\n\n    if (tid < 12) {\n        scales_local[tid] = scales_ptr[tid];\n    }\n\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n    dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                  const sycl::nd_item<3> &item_ct1) {\n    const block_q5_K * x = (const block_q5_K *) vx;\n\n    const int64_t i = item_ct1.get_group(2);\n\n#if QK_K == 256\n    // 
assume 64 threads - this is very slightly better than the one below\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t il  = tid/16;   // il is in 0...3\n    const int64_t ir  = tid%16;   // ir is in 0...15\n    const int64_t is  = 2*il;     // is is in 0...6\n\n    dst_t * y = yy + i*QK_K + 64*il + 2*ir;\n\n    const float dall = x[i].dm[0];\n    const float dmin = x[i].dm[1];\n\n    const uint8_t * ql = x[i].qs + 32*il + 2*ir;\n    const uint8_t * qh = x[i].qh + 2*ir;\n\n    uint8_t sc, m;\n    get_scale_min_k4(is + 0, x[i].scales, sc, m);\n    const float d1 = dall * sc; const float m1 = dmin * m;\n    get_scale_min_k4(is + 1, x[i].scales, sc, m);\n    const float d2 = dall * sc; const float m2 = dmin * m;\n\n    uint8_t   hm  = 1 << (2*il);\n    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;\n    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;\n    hm <<= 1;\n    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;\n    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;\n#else\n    const int64_t tid = item_ct1.get_local_id(2);\n    const uint8_t q = x[i].qs[tid];\n    const int64_t im = tid/8;  // 0...3\n    const int64_t in = tid%8;  // 0...7\n    const int64_t is = tid/16; // 0 or 1\n    const uint8_t h = x[i].qh[in] >> im;\n    const float d = x[i].d;\n    dst_t * y = yy + i*QK_K + tid;\n    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));\n    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 
0 : 16));\n#endif\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                  const sycl::nd_item<3> &item_ct1) {\n    const block_q6_K * x = (const block_q6_K *) vx;\n\n    const int64_t i = item_ct1.get_group(2);\n#if QK_K == 256\n\n    // assume 64 threads - this is very slightly better than the one below\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t ip  = tid/32;   // ip is 0 or 1\n    const int64_t il  = tid - 32*ip; // 0...32\n    const int64_t is  = 8*ip + il/16;\n\n    dst_t * y = yy + i*QK_K + 128*ip + il;\n\n    const float d = x[i].d;\n\n    const uint8_t * ql = x[i].ql + 64*ip + il;\n    const uint8_t   qh = x[i].qh[32*ip + il];\n    const int8_t  * sc = x[i].scales + is;\n\n    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);\n    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);\n    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);\n    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);\n#else\n\n    // assume 32 threads\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t ip  = tid/16;         // 0 or 1\n    const int64_t il  = tid - 16*ip;    // 0...15\n\n    dst_t * y = yy + i*QK_K + 16*ip + il;\n\n    const float d = x[i].d;\n\n    const uint8_t   ql = x[i].ql[16*ip + il];\n    const uint8_t   qh = x[i].qh[il] >> (2*ip);\n    const int8_t  * sc = x[i].scales;\n\n    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);\n    y[32] = d * sc[ip+2] * ((int8_t)((ql  >> 4) | (((qh >> 4) & 3) << 4)) - 32);\n#endif\n}\n\ntemplate <typename dst_t>\nstatic void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                          const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {\n    const int64_t ib = 
item_ct1.get_group(2);\n\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t ip  = tid / 32;       // ip is 0 or 1\n    const int64_t il  = tid - 32 * ip;  // 0...32\n    const int64_t is  = 8 * ip + il / 16;\n\n    const uint8_t *   base_ptr           = static_cast<const uint8_t *>(vx);\n    const auto        ql_offset          = ib * (QK_K / 2);\n    const auto        qh_offset          = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;\n    const auto        base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;\n    const auto        base_d_offset      = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;\n    const uint8_t *   ql_ptr             = base_ptr + ql_offset;\n    const uint8_t *   qh_ptr             = base_ptr + qh_offset;\n    const uint8_t *   scales_ptr         = base_ptr + base_scales_offset;\n    const ggml_half * d                  = (const ggml_half *) (base_ptr + base_d_offset) + ib;\n\n    dst_t * y = yy + ib * QK_K + 128 * ip + il;\n\n    const uint8_t * ql = ql_ptr + 64 * ip + il;\n    const uint8_t   qh = *(qh_ptr + 32 * ip + il);\n    const int8_t *  sc = reinterpret_cast<const int8_t *>(scales_ptr + is);\n\n    y[0]  = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);\n    y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);\n    y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);\n    y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                     const sycl::nd_item<3> &item_ct1,\n                                     const uint64_t *iq2xxs_grid_ptr,\n                                     const uint8_t *ksigns_iq2xs_ptr,\n                                     const uint8_t *kmask_iq2xs_ptr) {\n\n    const int64_t i = item_ct1.get_group(2);\n    
const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n#if QK_K == 256\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint16_t * q2 = x[i].qs + 4*ib;\n    const uint8_t  * aux8 = (const uint8_t *)q2;\n    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid_ptr + aux8[il]);\n    const uint32_t aux32 = q2[2] | (q2[3] << 16);\n    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;\n    const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127];\n    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f);\n#else\n    assert(false);\n#endif\n\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                    const sycl::nd_item<3> &item_ct1,\n                                    const uint64_t *iq2xs_grid,\n                                    const uint8_t *ksigns_iq2xs,\n                                    const uint8_t *kmask_iq2xs) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_iq2_xs * x = (const block_iq2_xs *) vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n#if QK_K == 256\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint16_t * q2 = x[i].qs + 4*ib;\n    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));\n    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;\n    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];\n    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f);\n#else\n    assert(false);\n#endif\n\n}\n\ntemplate <typename dst_t>\n__dpct_inline__ static void\ndequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,\n                       const sycl::nd_item<3> &item_ct1) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_iq2_s * x = (const block_iq2_s *) vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n#if QK_K == 256\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));\n    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;\n    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];\n#pragma unroll\n    for (int j = 0; j < 8; ++j)\n        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);\n#else\n    assert(false);\n\n#endif\n\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                     const sycl::nd_item<3> &item_ct1,\n                                     const uint32_t *iq3xxs_grid,\n                                     const uint8_t *ksigns_iq2xs,\n                                     const uint8_t *kmask_iq2xs) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n#if QK_K == 256\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint8_t  * q3 = x[i].qs + 8*ib;\n    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;\n    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);\n    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);\n    const uint32_t aux32 = gas[0] | (gas[1] << 16);\n    const 
float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;\n    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];\n    for (int j = 0; j < 4; ++j) {\n        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);\n        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);\n    }\n#else\n    assert(false);\n#endif\n\n}\n\ntemplate <typename dst_t>\n__dpct_inline__ static void\ndequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,\n                       const sycl::nd_item<3> &item_ct1,\n                       const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_iq3_s * x = (const block_iq3_s *) vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n#if QK_K == 256\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint8_t * qs = x[i].qs + 8*ib;\n    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));\n    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));\n    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));\n    const uint8_t signs = x[i].signs[4*ib + il];\n#pragma unroll\n    for (int j = 0; j < 4; ++j) {\n        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);\n        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
-1.f : 1.f);\n    }\n#else\n    assert(false);\n#endif\n\n}\n\ntemplate <typename dst_t>\n__dpct_inline__ static void\ndequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,\n                       const sycl::nd_item<3> &item_ct1,\n                       const uint32_t *iq1s_grid_gpu) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_iq1_s * x = (const block_iq1_s  *) vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n#if QK_K == 256\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;\n    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);\n    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;\n    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];\n    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;\n    grid32[0] &= 0x0f0f0f0f;\n#pragma unroll\n    for (int j = 0; j < 8; ++j) {\n        y[j] = d * (q[j] + delta);\n    }\n#else\n    assert(false);\n#endif\n\n}\n\ntemplate <typename dst_t>\n__dpct_inline__ static void\ndequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,\n                       const sycl::nd_item<3> &item_ct1,\n                       const uint32_t *iq1s_grid_gpu) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_iq1_m * x = (const block_iq1_m  *) vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n#if QK_K == 256\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 8*il;\n    const uint16_t * sc = (const uint16_t *)x[i].scales;\n    iq1m_scale_t scale;\n    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);\n    const float d = 
(float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);\n    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;\n    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;\n    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];\n    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;\n    grid32[0] &= 0x0f0f0f0f;\n#pragma unroll\n    for (int j = 0; j < 8; ++j) {\n        y[j] = d * (q[j] + delta);\n    }\n#else\n    assert(false);\n#endif\n\n}\n\ntemplate <typename dst_t>\n__dpct_inline__ static void\ndequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,\n                        const sycl::nd_item<3> &item_ct1) {\n\n    const int64_t i = item_ct1.get_group(2);\n    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);\n\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 4*il;\n    const uint8_t  * q4 = x[ib].qs + 4*il;\n    const float d = (float)x[ib].d;\n#pragma unroll\n    for (int j = 0; j < 4; ++j) {\n        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];\n        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];\n    }\n\n}\n\n\ntemplate <typename dst_t>\n__dpct_inline__ static void\ndequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,\n                        const sycl::nd_item<3> &item_ct1) {\n    const int64_t i = item_ct1.get_group(2);\n    const block_iq4_xs * x = (const block_iq4_xs *)vx;\n\n    const int64_t tid = item_ct1.get_local_id(2);\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 4*il;\n    const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;\n    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);\n#pragma unroll\n    for (int j = 0; j < 4; ++j) {\n   
     y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];\n        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];\n    }\n}\n\ntemplate<typename dst_t>\nstatic void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy,\n                                   const sycl::nd_item<3> &item_ct1) {\n    // auto                item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int64_t       i        = item_ct1.get_group(2);\n    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);\n\n    const int64_t    tid = item_ct1.get_local_id(2);\n    const int64_t il = tid/8; // 0...3\n    const int64_t ib = tid%8; // 0...7\n    dst_t * y = yy + i*QK_K + 32*ib + 4*il;\n    const uint8_t  * q4 = x[ib].qs + 4*il;\n    const float d = ggml_sycl_e8m0_to_fp32(x[ib].e);\n    for (int j = 0; j < 4; ++j) {\n        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;\n        y[j+16] = d * kvalues_mxfp4[q4[j] >>  4]*0.5f;\n    }\n}\n\n#endif // GGML_SYCL_DEQUANTIZE_HPP\n"
  },
  {
    "path": "src/ggml-sycl/dmmv.cpp",
    "content": "#include \"convert.hpp\"\n#include \"dmmv.hpp\"\n#include \"dequantize.hpp\"\n#include \"presets.hpp\"\n\nstatic void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){\n    const sycl::half *x = (const sycl::half *)vx;\n\n    // automatic half -> float type cast if dfloat == float\n    v.x() = x[ib + iqs + 0];\n    v.y() = x[ib + iqs + 1];\n}\n\nstatic void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){\n    const float * x = (const float *) vx;\n\n    // automatic half -> float type cast if dfloat == float\n    v.x() = x[ib + iqs + 0];\n    v.y() = x[ib + iqs + 1];\n}\n\ntemplate <int qk, int qr, dequantize_kernel_t dequantize_kernel>\nstatic void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,\n                                   const sycl::nd_item<3> &item_ct1) {\n    // qk = quantized weights per x block\n    // qr = number of quantized weights per data value in x block\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int tid = item_ct1.get_local_id(2);\n\n    const int iter_stride = 2*GGML_SYCL_DMMV_X;\n    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter\n    const int y_offset = qr == 1 ? 
1 : qk/2;\n\n// partial sum for each thread\n#ifdef GGML_SYCL_F16\n    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics\n#else\n    float tmp = 0.0f;\n#endif // GGML_SYCL_F16\n\n    for (int i = 0; i < ncols; i += iter_stride) {\n        const int col = i + vals_per_iter*tid;\n        const int ib = (row*ncols + col)/qk; // x block index\n        const int iqs = (col%qk)/qr; // x quant index\n        const int iybs = col - col%qk; // y block start index\n\n// processing >2 values per i iter is faster for fast GPUs\n#pragma unroll\n        for (int j = 0; j < vals_per_iter; j += 2) {\n            // process 2 vals per j iter\n\n            // dequantize\n            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val\n            dfloat2 v;\n            dequantize_kernel(vx, ib, iqs + j/qr, v);\n\n            // matrix multiplication\n            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2\n#ifdef GGML_SYCL_F16\n            dfloat2 t1{y[iybs + iqs + j / qr + 0],\n                        y[iybs + iqs + j / qr + y_offset]};\n\n            tmp += v * t1;\n#else\n            tmp += v.x() * y[iybs + iqs + j / qr + 0];\n            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];\n#endif // GGML_SYCL_F16\n        }\n    }\n\n    // sum up partial sums and write back result\n    const int mask_start = ncols > GGML_SYCL_DMMV_X ? 
WARP_SIZE >> 1 : WARP_SIZE >> 2;\n    for (int mask = mask_start; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (tid == 0) {\n#ifdef GGML_SYCL_F16\n        dst[row] = tmp.x() + tmp.y();\n#else\n        dst[row] = tmp;\n#endif // GGML_SYCL_F16\n    }\n}\n\ntemplate <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_reorder>\nstatic void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,\n                                   const sycl::nd_item<3> &item_ct1) {\n    // qk = quantized weights per x block\n    // qr = number of quantized weights per data value in x block\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int tid = item_ct1.get_local_id(2);\n\n\n    const int ncols_left = ncols % (QK4_0*WARP_SIZE);\n    const int ncols_align = ncols - ncols_left;\n    const int iter_stride = 8*2*GGML_SYCL_DMMV_X;\n    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16\n    const int y_offset = qr == 1 ? 
1 : qk/2;\n\n// partial sum for each thread\n#ifdef GGML_SYCL_F16\n    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics\n#else\n    float tmp = 0.0f;\n#endif // GGML_SYCL_F16\n    const char *d_ptr = (const char*)vx+ncols*nrows/2;\n    int i=0;\n    for (i = 0; i < ncols_align; i += iter_stride) {\n        const int col = i + vals_per_iter*tid;\n        const int ib = (row*ncols + col)/qk; // x block index\n        const int iqs = (col%qk)/qr; // x quant index\n        const int iybs = col - col%qk; // y block start index\n\n// processing >2 values per i iter is faster for fast GPUs\n#pragma unroll\n        for (int j = 0; j < vals_per_iter; j += 2) {\n            // process 2 vals per j iter\n\n            // dequantize\n            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val\n            dfloat2 v;\n            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);\n\n            // matrix multiplication\n            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2\n#ifdef GGML_SYCL_F16\n            dfloat2 t1{y[iybs + iqs + j / qr + 0],\n                        y[iybs + iqs + j / qr + y_offset]};\n\n            tmp += v * t1;\n#else\n            tmp += v.x() * y[iybs + iqs + j / qr + 0];\n            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];\n#endif // GGML_SYCL_F16\n        }\n    }\n\n    for (; i < ncols; i += iter_stride) {\n        if (tid>=ncols_left/QK4_0) continue;\n        const int col = i + vals_per_iter*tid;\n        const int ib = (row*ncols + col)/qk; // x block index\n        const int iqs = (col%qk)/qr; // x quant index\n        const int iybs = col - col%qk; // y block start index\n\n// processing >2 values per i iter is faster for fast GPUs\n#pragma unroll\n        for (int j = 0; j < vals_per_iter; j += 2) {\n            // process 2 vals per j iter\n\n            // 
dequantize\n            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val\n            dfloat2 v;\n            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);\n\n            // matrix multiplication\n            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2\n#ifdef GGML_SYCL_F16\n            dfloat2 t1{y[iybs + iqs + j / qr + 0],\n                        y[iybs + iqs + j / qr + y_offset]};\n\n            tmp += v * t1;\n#else\n            tmp += v.x() * y[iybs + iqs + j / qr + 0];\n            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];\n#endif // GGML_SYCL_F16\n        }\n    }\n\n    // sum up partial sums and write back result\n    const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;\n    for (int mask = mask_start; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (tid == 0) {\n#ifdef GGML_SYCL_F16\n        dst[row] = tmp.x() + tmp.y();\n#else\n        dst[row] = tmp;\n#endif // GGML_SYCL_F16\n    }\n}\n\nstatic void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,\n                                         float *dst, const int ncols,\n                                         const int nrows,\n                                         dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) 
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,\n                                                          nrows, item_ct1);\n            });\n    }\n}\n\n/*\nDPCT1110:4: The total declared local variable size in device function\ndequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register\npressure. Consult with your hardware vendor to find the total register size\navailable and adjust the code, or use smaller sub-group size to avoid high\nregister pressure.\n*/\nstatic void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,\n                                        const float *__restrict__ yy,\n                                        float *__restrict__ dst,\n                                        const int ncols, int nrows,\n                                        const sycl::nd_item<3> &item_ct1) {\n\n    static_assert(16%K_QUANTS_PER_ITERATION == 0, \"16 must be divisible by K_QUANTS_PER_ITERATION\");\n\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n    if (row > nrows) return;\n\n    const int num_blocks_per_row = ncols / QK_K;\n    const int ib0 = row*num_blocks_per_row;\n\n    const block_q2_K * x = (const block_q2_K *)vx + ib0;\n\n    float tmp = 0; // partial sum for thread in warp\n\n#if QK_K == 256\n    const int tid =\n        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15\n    const int ix =\n        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1\n\n    const int step = 16/K_QUANTS_PER_ITERATION;\n\n    const int im = tid/step;                             // 0 or 1. 
0 computes 0..., 1 computes 128...\n    const int in = tid - step*im;                        // 0...15 or 0...7\n\n    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2\n    const int q_offset = 32*im + l0;\n    const int s_offset = 8*im;\n    const int y_offset = 128*im + l0;\n\n    uint32_t aux[4];\n    const uint8_t * d = (const uint8_t *)aux;\n    const uint8_t * m = (const uint8_t *)(aux + 2);\n\n    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {\n\n        const float   * y = yy + i * QK_K + y_offset;\n        const uint8_t * q = x[i].qs + q_offset;\n\n        const float dall = x[i].dm[0];\n        const float dmin = x[i].dm[1];\n\n        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);\n        aux[0] = a[0] & 0x0f0f0f0f;\n        aux[1] = a[1] & 0x0f0f0f0f;\n        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;\n        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;\n\n        float sum1 = 0, sum2 = 0;\n        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {\n            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)\n                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)\n                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)\n                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)\n                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)\n                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)\n                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)\n                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);\n            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]\n                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];\n\n        }\n        tmp += dall * sum1 - dmin * sum2;\n\n    }\n#else\n    const int tid = item_ct1.get_local_id(2) /\n                    (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7\n    const int ix = item_ct1.get_local_id(2) %\n                   (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3\n  
  const int offset = tid * K_QUANTS_PER_ITERATION;\n\n    uint32_t uaux[2];\n    const uint8_t * d = (const uint8_t *)uaux;\n\n\n    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {\n\n        const float   * y = yy + i * QK_K + offset;\n        const uint8_t * q = x[i].qs + offset;\n        const uint32_t * s = (const uint32_t *)x[i].scales;\n\n        uaux[0] = s[0] & 0x0f0f0f0f;\n        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;\n\n        const sycl::float2 dall =\n            x[i].dm.convert<float, sycl::rounding_mode::automatic>();\n\n        float sum1 = 0, sum2 = 0;\n        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {\n            const uint8_t ql = q[l];\n            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)\n                  + y[l+16] * d[1] * ((ql >> 2) & 3)\n                  + y[l+32] * d[2] * ((ql >> 4) & 3)\n                  + y[l+48] * d[3] * ((ql >> 6) & 3);\n            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];\n        }\n        tmp += dall.x() * sum1 - dall.y() * sum2;\n    }\n\n#endif\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\n/*\nDPCT1110:5: The total declared local variable size in device function\ndequantize_mul_mat_vec_q3_k exceeds 128 bytes and may cause high register\npressure. 
Consult with your hardware vendor to find the total register size\navailable and adjust the code, or use smaller sub-group size to avoid high\nregister pressure.\n*/\nstatic void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,\n                                        const float *__restrict__ yy,\n                                        float *__restrict__ dst,\n                                        const int ncols, int nrows,\n                                        const sycl::nd_item<3> &item_ct1) {\n\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n    if (row > nrows) return;\n\n    const int num_blocks_per_row = ncols / QK_K;\n    const int ib0 = row*num_blocks_per_row;\n\n    const block_q3_K * x = (const block_q3_K *)vx + ib0;\n\n    float tmp = 0; // partial sum for thread in warp\n\n#if QK_K == 256\n\n    const uint16_t kmask1 = 0x0303;\n    const uint16_t kmask2 = 0x0f0f;\n\n    const int tid =\n        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16\n    const int ix =\n        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1\n\n    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop\n    const int step = 16/K_QUANTS_PER_ITERATION;\n    const int im = tid/step;                             // 0 or 1. 
0 computes 0..., 1 computes 128...\n    const int in = tid - step*im;                        // 0....15 or 0...7\n\n    const uint8_t m = 1 << (4*im);\n\n    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2\n    const int q_offset =  32*im + l0;\n    const int y_offset = 128*im + l0;\n\n    uint16_t utmp[4];\n    const int8_t * s = (const int8_t *)utmp;\n\n    const uint16_t s_shift = 4*im;\n\n    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {\n\n        const float   * y  = yy + i * QK_K + y_offset;\n        const uint8_t * q = x[i].qs + q_offset;\n        const uint8_t * h = x[i].hmask + l0;\n\n        const uint16_t * a = (const uint16_t *)x[i].scales;\n        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);\n        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);\n        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);\n        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);\n\n        const float d = x[i].d;\n\n        float sum = 0;\n        for (int l = 0; l < n; ++l) {\n            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))\n                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))\n                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))\n                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));\n            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))\n                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))\n                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))\n                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 
0 : 4));\n        }\n        tmp += d * sum;\n\n    }\n#else\n\n    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7\n    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3\n    const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14\n    const int in = offset/8;                                 // 0 or 1\n    const int im = offset%8;                                 // 0...7\n\n    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {\n\n        const float   * y = yy + i * QK_K + offset;\n        const uint8_t * q = x[i].qs + offset;\n        const uint8_t * s = x[i].scales;\n\n        const float dall = (float)x[i].d;\n\n        float sum = 0;\n        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {\n            const uint8_t hl = x[i].hmask[im+l] >> in;\n            const uint8_t ql = q[l];\n            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))\n                 + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))\n                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))\n                 + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));\n        }\n        tmp += sum;\n    }\n#endif\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\n/*\nDPCT1110:6: The total declared local variable size in device function\ndequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register\npressure. 
Consult with your hardware vendor to find the total register size\navailable and adjust the code, or use smaller sub-group size to avoid high\nregister pressure.\n*/\nstatic void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,\n                                        const float *__restrict__ yy,\n                                        float *__restrict__ dst,\n                                        const int ncols, int nrows,\n                                        const sycl::nd_item<3> &item_ct1) {\n\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n    if (row > nrows) return;\n    const int num_blocks_per_row = ncols / QK_K;\n    const int ib0 = row*num_blocks_per_row;\n\n    const block_q4_K * x = (const block_q4_K *)vx + ib0;\n\n#if QK_K == 256\n    const uint16_t kmask1 = 0x3f3f;\n    const uint16_t kmask2 = 0x0f0f;\n    const uint16_t kmask3 = 0xc0c0;\n\n    const int tid =\n        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16\n    const int ix =\n        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1\n\n    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4\n\n    const int il  = tid/step;                            // 0...3\n    const int ir  = tid - step*il;                       // 0...7 or 0...3\n    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4\n\n    const int im = il/2;  // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224\n    const int in = il%2;\n\n    const int l0 = n*(2*ir + in);\n    const int q_offset = 32*im + l0;\n    const int y_offset = 64*im + l0;\n\n    uint16_t aux[4];\n    const uint8_t * sc = (const uint8_t *)aux;\n\n#if K_QUANTS_PER_ITERATION == 2\n    uint32_t q32[4];\n    const uint8_t * q4 = (const uint8_t *)q32;\n#else\n    uint16_t q16[4];\n    const uint8_t * q4 = (const uint8_t *)q16;\n#endif\n\n    float tmp = 0; // partial sum for thread in warp\n\n    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {\n\n        const float   * y1 = yy + i*QK_K + y_offset;\n        const float   * y2 = y1 + 128;\n\n        const float dall = x[i].dm[0];\n        const float dmin = x[i].dm[1];\n\n        const uint16_t * a = (const uint16_t *)x[i].scales;\n        aux[0] = a[im+0] & kmask1;\n        aux[1] = a[im+2] & kmask1;\n        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);\n        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);\n\n#if K_QUANTS_PER_ITERATION == 2\n        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);\n        const uint32_t * q2 = q1 + 16;\n\n        q32[0] = q1[0] & 0x0f0f0f0f;\n        q32[1] = q1[0] & 0xf0f0f0f0;\n        q32[2] = q2[0] & 0x0f0f0f0f;\n        q32[3] = q2[0] & 0xf0f0f0f0;\n\n        sycl::float4 s = {0.f, 0.f, 0.f, 0.f};\n        float smin = 0;\n        for (int l = 0; l < 4; ++l) {\n            s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4];\n            s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12];\n            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];\n        }\n        tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +\n                       s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -\n               dmin * smin;\n#else\n        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);\n        const uint16_t * q2 = q1 + 
32;\n\n        q16[0] = q1[0] & 0x0f0f;\n        q16[1] = q1[0] & 0xf0f0;\n        q16[2] = q2[0] & 0x0f0f;\n        q16[3] = q2[0] & 0xf0f0;\n\n        float4 s = {0.f, 0.f, 0.f, 0.f};\n        float smin = 0;\n        for (int l = 0; l < 2; ++l) {\n            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];\n            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];\n            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];\n        }\n        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;\n#endif\n\n    }\n#else\n    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15\n    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);\n\n    const int step = tid * K_QUANTS_PER_ITERATION;\n\n    uint16_t aux16[2];\n    const uint8_t * s = (const uint8_t *)aux16;\n\n    float tmp = 0;\n\n    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {\n        const uint8_t * q = x[i].qs + step;\n        const float   * y = yy + i*QK_K + step;\n        const uint16_t * a = (const uint16_t *)x[i].scales;\n        aux16[0] = a[0] & 0x0f0f;\n        aux16[1] = (a[0] >> 4) & 0x0f0f;\n        const float d = (float)x[i].dm[0];\n        const float m = (float)x[i].dm[1];\n        float sum = 0.f;\n        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {\n            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])\n                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])\n                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])\n                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);\n        }\n        tmp += sum;\n    }\n\n#endif\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (tid == 
0) {\n        dst[row] = tmp;\n    }\n}\n\n/*\nDPCT1110:7: The total declared local variable size in device function\ndequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register\npressure. Consult with your hardware vendor to find the total register size\navailable and adjust the code, or use smaller sub-group size to avoid high\nregister pressure.\n*/\nstatic void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,\n                                        const float *__restrict__ yy,\n                                        float *__restrict__ dst,\n                                        const int ncols,\n                                        const sycl::nd_item<3> &item_ct1) {\n\n    const int row = item_ct1.get_group(2);\n    const int num_blocks_per_row = ncols / QK_K;\n    const int ib0 = row*num_blocks_per_row;\n\n    const block_q5_K * x = (const block_q5_K *)vx + ib0;\n\n    float tmp = 0; // partial sum for thread in warp\n\n#if QK_K == 256\n    const uint16_t kmask1 = 0x3f3f;\n    const uint16_t kmask2 = 0x0f0f;\n    const uint16_t kmask3 = 0xc0c0;\n\n    const int tid = item_ct1.get_local_id(2) / 2; // 0...15\n    const int ix = item_ct1.get_local_id(2) % 2;\n\n    const int il  = tid/4;     // 0...3\n    const int ir  = tid - 4*il;// 0...3\n    const int n   = 2;\n\n    const int im = il/2;  // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224\n    const int in = il%2;\n\n    const int l0 = n*(2*ir + in);\n    const int q_offset = 32*im + l0;\n    const int y_offset = 64*im + l0;\n\n    const uint8_t hm1  = 1 << (2*im);\n    const uint8_t hm2  = hm1 << 4;\n\n    uint16_t aux[4];\n    const uint8_t * sc = (const uint8_t *)aux;\n\n    uint16_t q16[8];\n    const uint8_t * q4 = (const uint8_t *)q16;\n\n    for (int i = ix; i < num_blocks_per_row; i += 2) {\n\n        const uint8_t * ql1 = x[i].qs + q_offset;\n        const uint8_t * qh  = x[i].qh + l0;\n        const float   * y1  = yy + i*QK_K + y_offset;\n        const float   * y2  = y1 + 128;\n\n        const float dall = x[i].dm[0];\n        const float dmin = x[i].dm[1];\n\n        const uint16_t * a = (const uint16_t *)x[i].scales;\n        aux[0] = a[im+0] & kmask1;\n        aux[1] = a[im+2] & kmask1;\n        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);\n        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);\n\n        sycl::float4 sum = {0.f, 0.f, 0.f, 0.f};\n        float smin = 0;\n        const uint16_t * q1 = (const uint16_t *)ql1;\n        const uint16_t * q2 = q1 + 32;\n        q16[0] = q1[0] & 0x0f0f;\n        q16[1] = q1[8] & 0x0f0f;\n        q16[2] = (q1[0] >> 4) & 0x0f0f;\n        q16[3] = (q1[8] >> 4) & 0x0f0f;\n        q16[4] = q2[0] & 0x0f0f;\n        q16[5] = q2[8] & 0x0f0f;\n        q16[6] = (q2[0] >> 4) & 0x0f0f;\n        q16[7] = (q2[8] >> 4) & 0x0f0f;\n        for (int l = 0; l < n; ++l) {\n            sum.x() +=\n                y1[l + 0] * (q4[l + 0] + (qh[l + 0] & (hm1 << 0) ? 16 : 0)) +\n                y1[l + 16] * (q4[l + 2] + (qh[l + 16] & (hm1 << 0) ? 16 : 0));\n            sum.y() +=\n                y1[l + 32] * (q4[l + 4] + (qh[l + 0] & (hm1 << 1) ? 16 : 0)) +\n                y1[l + 48] * (q4[l + 6] + (qh[l + 16] & (hm1 << 1) ? 16 : 0));\n            sum.z() +=\n                y2[l + 0] * (q4[l + 8] + (qh[l + 0] & (hm2 << 0) ? 
16 : 0)) +\n                y2[l + 16] * (q4[l + 10] + (qh[l + 16] & (hm2 << 0) ? 16 : 0));\n            sum.w() +=\n                y2[l + 32] * (q4[l + 12] + (qh[l + 0] & (hm2 << 1) ? 16 : 0)) +\n                y2[l + 48] * (q4[l + 14] + (qh[l + 16] & (hm2 << 1) ? 16 : 0));\n            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]\n                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];\n        }\n        tmp += dall * (sum.x() * sc[0] + sum.y() * sc[1] + sum.z() * sc[4] +\n                       sum.w() * sc[5]) -\n               dmin * smin;\n    }\n\n#else\n    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15\n    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);\n    const int step = tid * K_QUANTS_PER_ITERATION;\n    const int im = step/8;\n    const int in = step%8;\n\n    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {\n        const uint8_t * q = x[i].qs + step;\n        const int8_t  * s = x[i].scales;\n        const float   * y = yy + i*QK_K + step;\n        const float     d = x[i].d;\n        float sum = 0.f;\n        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {\n            const uint8_t h = x[i].qh[in+j] >> im;\n            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))\n                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))\n                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))\n                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 
0 : 16));\n        }\n        tmp += sum;\n    }\n#endif\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\nstatic void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,\n                                        const sycl::nd_item<3> &item_ct1) {\n\n    static_assert(16%K_QUANTS_PER_ITERATION == 0, \"16 must be divisible by K_QUANTS_PER_ITERATION\");\n\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n    if (row > nrows) return;\n\n    const int num_blocks_per_row = ncols / QK_K;\n    const int ib0 = row*num_blocks_per_row;\n\n    const block_q6_K * x = (const block_q6_K *)vx + ib0;\n\n#if QK_K == 256\n\n    const int tid =\n        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16\n    const int ix =\n        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1\n\n    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8\n\n    const int im = tid/step;                             // 0 or 1. 
0 computes 0..., 1 computes 128...\n    const int in = tid - step*im;                        // 0...15 or 0...7\n\n#if K_QUANTS_PER_ITERATION == 1\n    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15\n    const int is = 0;\n#else\n    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28\n    const int is = in / 4;\n#endif\n    const int ql_offset = 64*im + l0;\n    const int qh_offset = 32*im + l0;\n    const int s_offset  =  8*im + is;\n    const int y_offset = 128*im + l0;\n\n    float tmp = 0; // partial sum for thread in warp\n\n    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {\n\n        const float   * y  = yy + i * QK_K + y_offset;\n        const uint8_t * ql = x[i].ql + ql_offset;\n        const uint8_t * qh = x[i].qh + qh_offset;\n        const int8_t  * s  = x[i].scales + s_offset;\n\n        const float d = x[i].d;\n\n#if K_QUANTS_PER_ITERATION == 1\n        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)\n                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)\n                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)\n                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)\n                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)\n                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)\n                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)\n                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);\n        tmp += sum;\n#else\n        float sum = 0;\n        for (int l = 0; l < 4; ++l) {\n            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)\n                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 
4)) - 32)\n                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)\n                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);\n        }\n        tmp += sum;\n#endif\n\n    }\n\n#else\n\n    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...7\n    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0...3\n\n    const int step = tid * K_QUANTS_PER_ITERATION;\n\n    float tmp = 0; // partial sum for thread in warp\n\n    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {\n\n        const float   * y  = yy + i * QK_K + step;\n        const uint8_t * ql = x[i].ql + step;\n        const uint8_t * qh = x[i].qh + step;\n        const int8_t  * s  = x[i].scales;\n\n        const float d = x[i+0].d;\n\n        float sum = 0;\n        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {\n            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)\n                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)\n                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)\n                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);\n        }\n        tmp += sum;\n\n    }\n\n#endif\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (tid == 0) {\n        dst[row] = tmp;\n    }\n}\n\nstatic void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(\n                    vx, y, dst, ncols, nrows, item_ct1);\n            });\n    }\n}\n\n\nstatic void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(\n                    vx, y, dst, ncols, nrows, item_ct1);\n            });\n    }\n}\n\nstatic void 
dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(\n                    vx, y, dst, ncols, nrows, item_ct1);\n            });\n    }\n}\n\nstatic void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(\n                    vx, y, dst, ncols, nrows, item_ct1);\n     
       });\n    }\n}\n\nstatic void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(\n                    vx, y, dst, ncols, nrows, item_ct1);\n            });\n    }\n}\n\nstatic void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(\n                    vx, y, 
dst, ncols, nrows, item_ct1);\n            });\n    }\n}\n\nstatic void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2\n    const int block_num_y = (nrows + ny - 1) / ny;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<3>(block_nums * block_dims, block_dims),\n        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {\n            dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);\n        });\n}\n\nstatic void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int ny = 2 / K_QUANTS_PER_ITERATION;\n    const int block_num_y = (nrows + ny - 1) / ny;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<3>(block_nums * block_dims, block_dims),\n        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {\n            dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);\n        });\n}\n\nstatic void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             
dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int ny = 2 / K_QUANTS_PER_ITERATION;\n    const int block_num_y = (nrows + ny - 1) / ny;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<3>(block_nums * block_dims, block_dims),\n        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {\n            dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);\n        });\n}\n\nstatic void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),\n        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {\n            dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);\n        });\n}\n\nstatic void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,\n                                             float *dst, const int ncols,\n                                             const int nrows,\n                                             dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int ny = 2 / K_QUANTS_PER_ITERATION;\n    const int block_num_y = (nrows + ny - 1) / ny;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<3>(block_nums * block_dims, block_dims),\n        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {\n            dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, 
nrows, item_ct1);\n        });\n}\n\nvoid ggml_sycl_op_dequantize_mul_mat_vec(\n    ggml_backend_sycl_context & ctx,\n    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,\n    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,\n    float *dst_dd_i, const int64_t row_low, const int64_t row_high,\n    const int64_t src1_ncols, const int64_t src1_padded_row_size,\n    const dpct::queue_ptr &stream) {\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t row_diff = row_high - row_low;\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics\n#ifdef GGML_SYCL_F16\n    ggml_sycl_pool_alloc<sycl::half> src1_dfloat_a(ctx.pool());\n    sycl::half *src1_dfloat = nullptr; // dfloat == half\n\n    bool src1_convert_f16 =\n        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||\n        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||\n        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;\n\n    if (src1_convert_f16) {\n        scope_op_debug_print scope_dbg_print(__func__, \"/to_fp16_sycl\", dst, /*num_src=*/2,\n                                             \" : converting src1 to fp16\");\n        src1_dfloat = src1_dfloat_a.alloc(ne00);\n        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);\n        GGML_ASSERT(to_fp16_sycl != nullptr);\n        to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);\n    }\n#else\n    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion\n#endif // GGML_SYCL_F16\n\n    switch (src0->type) {\n        case GGML_TYPE_Q4_0:\n            if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&\n                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {\n                dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);\n            } 
else {\n                dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);\n            }\n            break;\n        case GGML_TYPE_Q4_1:\n            dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);\n            break;\n        case GGML_TYPE_Q5_0:\n            dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);\n            break;\n        case GGML_TYPE_Q5_1:\n            dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);\n            break;\n        case GGML_TYPE_Q8_0:\n            dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);\n            break;\n        case GGML_TYPE_Q2_K:\n            dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);\n            break;\n        case GGML_TYPE_Q3_K:\n            dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);\n            break;\n        case GGML_TYPE_Q4_K:\n            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&\n                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {\n                // reorder is currently not supported for dmmv\n                GGML_ABORT(\"Unimplemented dequantize case case for q4_k reorder\");\n            } else {\n                dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);\n            }\n            break;\n        case GGML_TYPE_Q5_K:\n            dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);\n            break;\n        case GGML_TYPE_Q6_K:\n            dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);\n            break;\n        case GGML_TYPE_F16:\n            convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, 
ne00, row_diff, stream);\n            break;\n        default:\n            printf(\"ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\\n\", src0->type);\n            GGML_ABORT(\"fatal error\");\n    }\n\n    GGML_UNUSED(src1);\n    GGML_UNUSED(dst);\n    GGML_UNUSED(src1_ddq_i);\n    GGML_UNUSED(src1_ncols);\n    GGML_UNUSED(src1_padded_row_size);\n    GGML_UNUSED(ctx);\n}\n"
  },
  {
    "path": "src/ggml-sycl/dmmv.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_DMMV_HPP\n#define GGML_SYCL_DMMV_HPP\n\n#include \"common.hpp\"\n\n\nvoid ggml_sycl_op_dequantize_mul_mat_vec(\n    ggml_backend_sycl_context & ctx,\n    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,\n    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,\n    float *dst_dd_i, const int64_t row_low, const int64_t row_high,\n    const int64_t src1_ncols, const int64_t src1_padded_row_size,\n    const dpct::queue_ptr &stream);\n\n#endif // GGML_SYCL_DMMV_HPP\n"
  },
  {
    "path": "src/ggml-sycl/dpct/helper.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_DPCT_HELPER_HPP\n#define GGML_SYCL_DPCT_HELPER_HPP\n\n#include <sycl/sycl.hpp>\n#include <sycl/half_type.hpp>\n#include <oneapi/mkl.hpp>\n\n#include <map>\n\n#include \"ggml.h\"\n\n#if defined(__linux__)\n#include <sys/mman.h>\n#elif defined(_WIN64)\n#ifndef NOMINMAX\n#define NOMINMAX\n#endif\n#include <windows.h>\n#else\n#error \"Only support Windows and Linux.\"\n#endif\n\n#if defined(__linux__)\n#include <unistd.h>\n#include <sys/syscall.h>\n#endif\n#if defined(_WIN64)\n#ifndef NOMINMAX\n#define NOMINMAX\n#endif\n#include <windows.h>\n#endif\n\n#define DPCT_COMPATIBILITY_TEMP (900)\n\n#if defined(_MSC_VER)\n#define __dpct_align__(n) __declspec(align(n))\n#define __dpct_inline__ __forceinline\n#else\n#define __dpct_align__(n) __attribute__((aligned(n)))\n#define __dpct_inline__ __inline__ __attribute__((always_inline))\n#endif\n\n#if defined(_MSC_VER)\n#define __dpct_noinline__ __declspec(noinline)\n#else\n#define __dpct_noinline__ __attribute__((noinline))\n#endif\n\ninline std::string get_device_type_name(const sycl::device &Device) {\n    auto DeviceType = Device.get_info<sycl::info::device::device_type>();\n    switch (DeviceType) {\n    case sycl::info::device_type::cpu:\n        return \"cpu\";\n    case sycl::info::device_type::gpu:\n        return \"gpu\";\n    case sycl::info::device_type::host:\n        return \"host\";\n    case sycl::info::device_type::accelerator:\n        return \"acc\";\n    default:\n        return \"unknown\";\n    }\n}\n\ninline std::string get_device_backend_and_type(const sycl::device &device) {\n    std::stringstream device_type;\n    sycl::backend backend = device.get_backend();\n    
device_type <<  backend << \":\" << get_device_type_name(device);\n    return device_type.str();\n}\n\ntemplate <typename Ts> struct matrix_info_t {\n    oneapi::mkl::transpose transpose_info[2];\n    Ts                     value_info[2];\n    std::int64_t           size_info[3];\n    std::int64_t           ld_info[3];\n    std::int64_t           groupsize_info;\n};\n\nnamespace dpct\n{\n    typedef sycl::queue *queue_ptr;\n    typedef sycl::event *event_ptr;\n    typedef char *device_ptr;\n    typedef uint8_t byte_t;\n    typedef sycl::buffer<byte_t> buffer_t;\n\n    /// SYCL default exception handler\n    inline auto exception_handler = [](sycl::exception_list exceptions)\n    {\n        for (std::exception_ptr const &e : exceptions)\n        {\n            try\n            {\n                std::rethrow_exception(e);\n            }\n            catch (sycl::exception const &e)\n            {\n                std::cerr << \"Caught asynchronous SYCL exception:\" << std::endl\n                          << e.what() << std::endl\n                          << \"Exception caught at file:\" << __FILE__\n                          << \", line:\" << __LINE__ << std::endl;\n            }\n        }\n    };\n\n    enum error_code\n    {\n        success = 0,\n        default_error = 999\n    };\n\n    enum memcpy_direction\n    {\n        host_to_host,\n        host_to_device,\n        device_to_host,\n        device_to_device,\n        automatic\n    };\n\n    enum memory_region\n    {\n        global = 0, // device global memory\n        constant,   // device constant memory\n        local,      // device local memory\n        shared,     // memory which can be accessed by host and device\n    };\n\n    enum class library_data_t : unsigned char\n    {\n        real_float = 0,\n        complex_float,\n        real_double,\n        complex_double,\n        real_half,\n        complex_half,\n        real_bfloat16,\n        complex_bfloat16,\n        real_int4,\n        
complex_int4,\n        real_uint4,\n        complex_uint4,\n        real_int8,\n        complex_int8,\n        real_uint8,\n        complex_uint8,\n        real_int16,\n        complex_int16,\n        real_uint16,\n        complex_uint16,\n        real_int32,\n        complex_int32,\n        real_uint32,\n        complex_uint32,\n        real_int64,\n        complex_int64,\n        real_uint64,\n        complex_uint64,\n        real_int8_4,\n        real_int8_32,\n        real_uint8_4,\n        library_data_t_size\n    };\n\n    template <typename T>\n    struct DataType\n    {\n        using T2 = T;\n    };\n    template <typename T>\n    struct DataType<sycl::vec<T, 2>>\n    {\n        using T2 = std::complex<T>;\n    };\n\n    static void destroy_event(event_ptr event)\n    {\n        delete event;\n    }\n\n    static inline unsigned int get_tid()\n    {\n#if defined(__linux__)\n        return syscall(SYS_gettid);\n#elif defined(_WIN64)\n        return GetCurrentThreadId();\n#else\n#error \"Only support Windows and Linux.\"\n#endif\n    }\n\n    namespace detail\n    {\n        static void get_version(const sycl::device &dev, int &major, int &minor)\n        {\n            // Version string has the following format:\n            // a. OpenCL<space><major.minor><space><vendor-specific-information>\n            // b. <major.minor>\n            // c. <AmdGcnArchName> e.g gfx1030\n            std::string ver;\n            ver = dev.get_info<sycl::info::device::version>();\n            std::string::size_type i = 0;\n            while (i < ver.size()) {\n              if (isdigit(ver[i]))\n                break;\n              i++;\n            }\n            major = std::stoi(&(ver[i]));\n            while (i < ver.size()) {\n              if (ver[i] == '.')\n                break;\n              i++;\n            }\n            if (i < ver.size()) {\n              // a. 
and b.\n              i++;\n              minor = std::stoi(&(ver[i]));\n            } else {\n              // c.\n              minor = 0;\n            }\n        }\n\n        template <typename tag, typename T>\n        class generic_error_type\n        {\n        public:\n            generic_error_type() = default;\n            generic_error_type(T value) : value{value} {}\n            operator T() const { return value; }\n\n        private:\n            T value;\n        };\n\n    } // namespace detail\n\n    // COPY from DPCT head files\n    /// dim3 is used to store 3 component dimensions.\n    class dim3 {\n        public:\n        unsigned x, y, z;\n\n        constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)\n            : x(x), y(y), z(z) {}\n\n        dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}\n\n        operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }\n    }; // namespace dim3\n\n    inline dim3 operator*(const dim3 &a, const dim3 &b) {\n    return dim3{a.x * b.x, a.y * b.y, a.z * b.z};\n    }\n    // COPY from DPCT head files\n\n\n    /// Pitched 2D/3D memory data.\n    class pitched_data\n    {\n    public:\n        pitched_data() : pitched_data(nullptr, 0, 0, 0) {}\n        pitched_data(void *data, size_t pitch, size_t x, size_t y)\n            : _data(data), _pitch(pitch), _x(x), _y(y) {}\n\n        void *get_data_ptr() { return _data; }\n        void set_data_ptr(void *data) { _data = data; }\n\n        size_t get_pitch() { return _pitch; }\n        void set_pitch(size_t pitch) { _pitch = pitch; }\n\n        size_t get_x() { return _x; }\n        void set_x(size_t x) { _x = x; }\n\n        size_t get_y() { return _y; }\n        void set_y(size_t y) { _y = y; }\n\n    private:\n        void *_data;\n        size_t _pitch, _x, _y;\n    };\n\n    class device_info\n    {\n    public:\n        // get interface\n        const char *get_name() const { return _name; }\n        char *get_name() { return _name; 
}\n        template <typename WorkItemSizesTy = sycl::range<3>,\n                  std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||\n                                       std::is_same_v<WorkItemSizesTy, int *>,\n                                   int> = 0>\n        auto get_max_work_item_sizes() const\n        {\n            if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)\n                return sycl::range<3>(_max_work_item_sizes_i[0],\n                                      _max_work_item_sizes_i[1],\n                                      _max_work_item_sizes_i[2]);\n            else\n            {\n                return _max_work_item_sizes_i;\n            }\n        }\n        template <typename WorkItemSizesTy = sycl::range<3>,\n                  std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||\n                                       std::is_same_v<WorkItemSizesTy, int *>,\n                                   int> = 0>\n        auto get_max_work_item_sizes()\n        {\n            if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)\n                return sycl::range<3>(_max_work_item_sizes_i[0],\n                                      _max_work_item_sizes_i[1],\n                                      _max_work_item_sizes_i[2]);\n            else\n            {\n                return _max_work_item_sizes_i;\n            }\n        }\n        bool get_host_unified_memory() const { return _host_unified_memory; }\n        int get_major_version() const { return _major; }\n        int get_minor_version() const { return _minor; }\n        int get_integrated() const { return _integrated; }\n        int get_max_clock_frequency() const { return _frequency; }\n        int get_max_compute_units() const { return _max_compute_units; }\n        int get_max_work_group_size() const { return _max_work_group_size; }\n        int get_max_sub_group_size() const { return _max_sub_group_size; }\n        int 
get_max_work_items_per_compute_unit() const\n        {\n            return _max_work_items_per_compute_unit;\n        }\n        int get_max_register_size_per_work_group() const\n        {\n            return _max_register_size_per_work_group;\n        }\n        template <typename NDRangeSizeTy = size_t *,\n                  std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||\n                                       std::is_same_v<NDRangeSizeTy, int *>,\n                                   int> = 0>\n        auto get_max_nd_range_size() const\n        {\n            if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)\n                return _max_nd_range_size;\n            else\n                return _max_nd_range_size_i;\n        }\n        template <typename NDRangeSizeTy = size_t *,\n                  std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||\n                                       std::is_same_v<NDRangeSizeTy, int *>,\n                                   int> = 0>\n        auto get_max_nd_range_size()\n        {\n            if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)\n                return _max_nd_range_size;\n            else\n                return _max_nd_range_size_i;\n        }\n        size_t get_global_mem_size() const { return _global_mem_size; }\n        size_t get_local_mem_size() const { return _local_mem_size; }\n        size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }\n        /// Returns the maximum clock rate of device's global memory in kHz. If\n        /// compiler does not support this API then returns default value 3200000 kHz.\n        unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }\n        /// Returns the maximum bus width between device and memory in bits. 
If\n        /// compiler does not support this API then returns default value 64 bits.\n        unsigned int get_memory_bus_width() const { return _memory_bus_width; }\n        uint32_t get_device_id() const { return _device_id; }\n        std::array<unsigned char, 16> get_uuid() const { return _uuid; }\n        /// Returns global memory cache size in bytes.\n        unsigned int get_global_mem_cache_size() const\n        {\n            return _global_mem_cache_size;\n        }\n\n        // set interface\n        void set_name(const char *name)\n        {\n            size_t length = strlen(name);\n            if (length < 256)\n            {\n                std::memcpy(_name, name, length + 1);\n            }\n            else\n            {\n                std::memcpy(_name, name, 255);\n                _name[255] = '\\0';\n            }\n        }\n        void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes)\n        {\n            for (int i = 0; i < 3; ++i)\n                _max_work_item_sizes_i[i] = max_work_item_sizes[i];\n        }\n        [[deprecated]] void\n        set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes)\n        {\n            for (int i = 0; i < 3; ++i)\n            {\n                _max_work_item_sizes_i[i] = max_work_item_sizes[i];\n            }\n        }\n        void set_host_unified_memory(bool host_unified_memory)\n        {\n            _host_unified_memory = host_unified_memory;\n        }\n        void set_major_version(int major) { _major = major; }\n        void set_minor_version(int minor) { _minor = minor; }\n        void set_integrated(int integrated) { _integrated = integrated; }\n        void set_max_clock_frequency(int frequency) { _frequency = frequency; }\n        void set_max_compute_units(int max_compute_units)\n        {\n            _max_compute_units = max_compute_units;\n        }\n        void set_global_mem_size(size_t global_mem_size)\n        {\n            _global_mem_size 
= global_mem_size;\n        }\n        void set_local_mem_size(size_t local_mem_size)\n        {\n            _local_mem_size = local_mem_size;\n        }\n        void set_max_mem_alloc_size(size_t max_mem_alloc_size)\n        {\n            _max_mem_alloc_size = max_mem_alloc_size;\n        }\n        void set_max_work_group_size(int max_work_group_size)\n        {\n            _max_work_group_size = max_work_group_size;\n        }\n        void set_max_sub_group_size(int max_sub_group_size)\n        {\n            _max_sub_group_size = max_sub_group_size;\n        }\n        void\n        set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit)\n        {\n            _max_work_items_per_compute_unit = max_work_items_per_compute_unit;\n        }\n        void set_max_nd_range_size(int max_nd_range_size[])\n        {\n            for (int i = 0; i < 3; i++)\n            {\n                _max_nd_range_size[i] = max_nd_range_size[i];\n                _max_nd_range_size_i[i] = max_nd_range_size[i];\n            }\n        }\n        void set_memory_clock_rate(unsigned int memory_clock_rate)\n        {\n            _memory_clock_rate = memory_clock_rate;\n        }\n        void set_memory_bus_width(unsigned int memory_bus_width)\n        {\n            _memory_bus_width = memory_bus_width;\n        }\n        void\n        set_max_register_size_per_work_group(int max_register_size_per_work_group)\n        {\n            _max_register_size_per_work_group = max_register_size_per_work_group;\n        }\n        void set_device_id(uint32_t device_id)\n        {\n            _device_id = device_id;\n        }\n        void set_uuid(std::array<unsigned char, 16> uuid)\n        {\n            _uuid = std::move(uuid);\n        }\n        void set_global_mem_cache_size(unsigned int global_mem_cache_size)\n        {\n            _global_mem_cache_size = global_mem_cache_size;\n        }\n\n    private:\n        char _name[256];\n        int 
_max_work_item_sizes_i[3];\n        bool _host_unified_memory = false;\n        int _major;\n        int _minor;\n        int _integrated = 0;\n        int _frequency;\n        // Set estimated value 3200000 kHz as default value.\n        unsigned int _memory_clock_rate = 3200000;\n        // Set estimated value 64 bits as default value.\n        unsigned int _memory_bus_width = 64;\n        unsigned int _global_mem_cache_size;\n        int _max_compute_units;\n        int _max_work_group_size;\n        int _max_sub_group_size;\n        int _max_work_items_per_compute_unit;\n        int _max_register_size_per_work_group;\n        size_t _global_mem_size;\n        size_t _local_mem_size;\n        size_t _max_mem_alloc_size;\n        size_t _max_nd_range_size[3];\n        int _max_nd_range_size_i[3];\n        uint32_t _device_id;\n        std::array<unsigned char, 16> _uuid;\n    };\n\n    static int get_major_version(const sycl::device &dev)\n    {\n        int major, minor;\n        detail::get_version(dev, major, minor);\n        return major;\n    }\n\n    static int get_minor_version(const sycl::device &dev)\n    {\n        int major, minor;\n        detail::get_version(dev, major, minor);\n        return minor;\n    }\n\n    static void get_device_info(device_info &out, const sycl::device &dev)\n    {\n        device_info prop;\n        prop.set_name(dev.get_info<sycl::info::device::name>().c_str());\n\n        int major, minor;\n        detail::get_version(dev, major, minor);\n        prop.set_major_version(major);\n        prop.set_minor_version(minor);\n\n        prop.set_max_work_item_sizes(\n#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902)\n            // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes\n            // is an enum class element\n            dev.get_info<sycl::info::device::max_work_item_sizes>());\n#else\n            // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by\n         
   // an int\n            dev.get_info<sycl::info::device::max_work_item_sizes<3>>());\n#endif\n        prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations));\n\n        prop.set_max_clock_frequency(\n            dev.get_info<sycl::info::device::max_clock_frequency>() * 1000);\n\n        prop.set_max_compute_units(\n            dev.get_info<sycl::info::device::max_compute_units>());\n        prop.set_max_work_group_size(\n            dev.get_info<sycl::info::device::max_work_group_size>());\n        prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());\n        prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());\n        prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());\n\n#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)\n        if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))\n        {\n            unsigned int tmp =\n                dev.get_info<sycl::ext::intel::info::device::memory_clock_rate>();\n            if (tmp != 0)\n                prop.set_memory_clock_rate(1000 * tmp);\n        }\n        if (dev.has(sycl::aspect::ext_intel_memory_bus_width))\n        {\n            prop.set_memory_bus_width(\n                dev.get_info<sycl::ext::intel::info::device::memory_bus_width>());\n        }\n        if (dev.has(sycl::aspect::ext_intel_device_id))\n        {\n            prop.set_device_id(\n                dev.get_info<sycl::ext::intel::info::device::device_id>());\n        }\n        if (dev.has(sycl::aspect::ext_intel_device_info_uuid))\n        {\n            prop.set_uuid(dev.get_info<sycl::ext::intel::info::device::uuid>());\n        }\n#elif defined(_MSC_VER) && !defined(__clang__)\n#pragma message(\"get_device_info: querying memory_clock_rate and \\\n        memory_bus_width are not supported by the compiler used. \\\n        Use 3200000 kHz as memory_clock_rate default value. 
\\\n        Use 64 bits as memory_bus_width default value.\")\n#else\n#warning \"get_device_info: querying memory_clock_rate and \\\n        memory_bus_width are not supported by the compiler used. \\\n        Use 3200000 kHz as memory_clock_rate default value. \\\n        Use 64 bits as memory_bus_width default value.\"\n#endif\n\n        size_t max_sub_group_size = 1;\n        std::vector<size_t> sub_group_sizes =\n            dev.get_info<sycl::info::device::sub_group_sizes>();\n\n        for (const auto &sub_group_size : sub_group_sizes)\n        {\n            if (max_sub_group_size < sub_group_size)\n                max_sub_group_size = sub_group_size;\n        }\n\n        prop.set_max_sub_group_size(max_sub_group_size);\n\n        prop.set_max_work_items_per_compute_unit(\n            dev.get_info<sycl::info::device::max_work_group_size>());\n        int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF};\n        prop.set_max_nd_range_size(max_nd_range_size);\n\n        // Estimates max register size per work group, feel free to update the value\n        // according to device properties.\n        prop.set_max_register_size_per_work_group(65536);\n\n        prop.set_global_mem_cache_size(\n            dev.get_info<sycl::info::device::global_mem_cache_size>());\n        out = prop;\n    }\n\n    /// dpct device extension\n    class device_ext : public sycl::device {\n      typedef std::mutex mutex_type;\n\n     public:\n      device_ext() : sycl::device() {}\n      ~device_ext() {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        clear_queues();\n      }\n      device_ext(const sycl::device &base) : sycl::device(base) {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        init_queues();\n      }\n\n      int is_native_atomic_supported() { return 0; }\n      int get_major_version() const { return dpct::get_major_version(*this); }\n\n      int get_minor_version() const { return dpct::get_minor_version(*this); }\n\n      int 
get_max_compute_units() const {\n        return get_device_info().get_max_compute_units();\n      }\n\n      /// Return the maximum clock frequency of this device in KHz.\n      int get_max_clock_frequency() const {\n        return get_device_info().get_max_clock_frequency();\n      }\n\n      int get_integrated() const { return get_device_info().get_integrated(); }\n\n      int get_max_sub_group_size() const {\n        return get_device_info().get_max_sub_group_size();\n      }\n\n      int get_max_register_size_per_work_group() const {\n        return get_device_info().get_max_register_size_per_work_group();\n      }\n\n      int get_max_work_group_size() const {\n        return get_device_info().get_max_work_group_size();\n      }\n\n      int get_mem_base_addr_align() const {\n        return get_info<sycl::info::device::mem_base_addr_align>();\n      }\n\n      size_t get_global_mem_size() const {\n        return get_device_info().get_global_mem_size();\n      }\n\n      size_t get_max_mem_alloc_size() const {\n        return get_device_info().get_max_mem_alloc_size();\n      }\n\n      /// Get the number of bytes of free and total memory on the SYCL device.\n      /// \\param [out] free_memory The number of bytes of free memory on the\n      /// SYCL device. 
\\param [out] total_memory The number of bytes of total\n      /// memory on the SYCL device.\n      void get_memory_info(size_t &free_memory, size_t &total_memory) {\n        total_memory = get_device_info().get_global_mem_size();\n        const char *warning_info =\n            \"get_memory_info: [warning] ext_intel_free_memory is not \"\n            \"supported (export/set ZES_ENABLE_SYSMAN=1 to support), \"\n            \"use total memory as free memory\";\n#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)\n        if (!has(sycl::aspect::ext_intel_free_memory)) {\n          std::cerr << warning_info << std::endl;\n          free_memory = total_memory;\n        } else {\n          free_memory = get_info<sycl::ext::intel::info::device::free_memory>();\n        }\n#else\n        std::cerr << warning_info << std::endl;\n        free_memory = total_memory;\n#if defined(_MSC_VER) && !defined(__clang__)\n#pragma message(\"Querying the number of bytes of free memory is not supported\")\n#else\n#warning \"Querying the number of bytes of free memory is not supported\"\n#endif\n#endif\n      }\n\n      void get_device_info(device_info &out) const {\n        dpct::get_device_info(out, *this);\n      }\n\n      device_info get_device_info() const {\n        device_info prop;\n        dpct::get_device_info(prop, *this);\n        return prop;\n      }\n\n      void reset() {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        clear_queues();\n        init_queues();\n      }\n\n      sycl::queue &in_order_queue() { return _q_in_order; }\n\n      sycl::queue &out_of_order_queue() { return _q_out_of_order; }\n\n      sycl::queue &default_queue() { return in_order_queue(); }\n\n      void queues_wait_and_throw() {\n        std::unique_lock<mutex_type> lock(m_mutex);\n        lock.unlock();\n        for (auto &q : _queues) {\n            q.wait_and_throw();\n        }\n        // Guard the destruct of current_queues to make sure the ref count 
is\n        // safe.\n        lock.lock();\n      }\n\n      sycl::queue create_queue(bool enable_exception_handler = false) {\n        return create_in_order_queue(enable_exception_handler);\n      }\n\n      sycl::queue create_queue(sycl::device device,\n                               bool enable_exception_handler = false) {\n        return create_in_order_queue(device, enable_exception_handler);\n      }\n\n      sycl::queue create_in_order_queue(bool enable_exception_handler = false) {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        return create_queue_impl(enable_exception_handler,\n                                 sycl::property::queue::in_order());\n      }\n\n      sycl::queue create_in_order_queue(sycl::device device,\n                                        bool enable_exception_handler = false) {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        return create_queue_impl(device, enable_exception_handler,\n                                 sycl::property::queue::in_order());\n      }\n\n      sycl::queue create_out_of_order_queue(\n          bool enable_exception_handler = false) {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        return create_queue_impl(enable_exception_handler);\n      }\n\n      void destroy_queue(sycl::queue queue) {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        _queues.erase(std::remove_if(_queues.begin(), _queues.end(),\n                                    [=](const sycl::queue &q) -> bool\n                                    {\n                                        return q == queue;\n                                    }),\n                    _queues.end());\n      }\n      void set_saved_queue(sycl::queue q) {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        _saved_queue = q;\n      }\n      sycl::queue get_saved_queue() const {\n        std::lock_guard<mutex_type> lock(m_mutex);\n        return _saved_queue;\n      }\n\n     private:\n      void clear_queues() { 
_queues.clear(); }\n\n      void init_queues() {\n        _q_in_order =\n            create_queue_impl(true, sycl::property::queue::in_order());\n        _q_out_of_order = create_queue_impl(true);\n        _saved_queue = default_queue();\n      }\n\n      /// Caller should acquire resource \\p m_mutex before calling this\n      /// function.\n      template <class... Properties>\n      sycl::queue create_queue_impl(bool enable_exception_handler,\n                                    Properties... properties) {\n        sycl::async_handler eh = {};\n        if (enable_exception_handler) {\n          eh = exception_handler;\n        }\n        _queues.push_back(sycl::queue(\n            *this, eh,\n            sycl::property_list(\n#ifdef DPCT_PROFILING_ENABLED\n                sycl::property::queue::enable_profiling(),\n#endif\n                properties...)));\n\n        return _queues.back();\n      }\n\n      template <class... Properties>\n      sycl::queue create_queue_impl(sycl::device device,\n                                    bool enable_exception_handler,\n                                    Properties... 
properties) {\n        sycl::async_handler eh = {};\n        if (enable_exception_handler) {\n          eh = exception_handler;\n        }\n        _queues.push_back(sycl::queue(\n            device, eh,\n                        sycl::property_list(\n#ifdef DPCT_PROFILING_ENABLED\n                            sycl::property::queue::enable_profiling(),\n#endif\n                            properties...)));\n\n        return _queues.back();\n      }\n\n      void get_version(int &major, int &minor) const {\n        detail::get_version(*this, major, minor);\n      }\n      sycl::queue _q_in_order, _q_out_of_order;\n      sycl::queue _saved_queue;\n      std::vector<sycl::queue> _queues;\n      mutable mutex_type m_mutex;\n    };\n\n\n    /// device manager\n    class dev_mgr\n    {\n    public:\n        device_ext &current_device()\n        {\n            unsigned int dev_id = current_device_id();\n            check_id(dev_id);\n            return *_devs[dev_id];\n        }\n        device_ext &cpu_device() const\n        {\n            std::lock_guard<std::recursive_mutex> lock(m_mutex);\n            if (_cpu_device == -1)\n            {\n                throw std::runtime_error(\"no valid cpu device\");\n            }\n            else\n            {\n                return *_devs[_cpu_device];\n            }\n        }\n        device_ext &get_device(unsigned int id) const\n        {\n            std::lock_guard<std::recursive_mutex> lock(m_mutex);\n            check_id(id);\n            return *_devs[id];\n        }\n        unsigned int current_device_id() const\n        {\n            std::lock_guard<std::recursive_mutex> lock(m_mutex);\n            auto it = _thread2dev_map.find(get_tid());\n            if (it != _thread2dev_map.end())\n                return it->second;\n            return DEFAULT_DEVICE_ID;\n        }\n\n        /// Select device with a device ID.\n        /// \\param [in] id The id of the device which can\n        /// be obtained through 
get_device_id(const sycl::device).\n        void select_device(unsigned int id)\n        {\n            std::lock_guard<std::recursive_mutex> lock(m_mutex);\n            check_id(id);\n            _thread2dev_map[get_tid()] = id;\n        }\n        unsigned int device_count() { return _devs.size(); }\n\n        unsigned int get_device_id(const sycl::device &dev)\n        {\n            unsigned int id = 0;\n            for (auto &dev_item : _devs)\n            {\n                if (*dev_item == dev)\n                {\n                    return id;\n                }\n                id++;\n            }\n            return -1;\n        }\n\n        inline std::string get_preferred_gpu_platform_name() {\n            std::string result;\n\n            std::string filter = \"\";\n            char* env = getenv(\"ONEAPI_DEVICE_SELECTOR\");\n            if (env) {\n                if (std::strstr(env, \"level_zero\")) {\n                    filter = \"level-zero\";\n                }\n                else if (std::strstr(env, \"opencl\")) {\n                    filter = \"opencl\";\n                }\n                else if (std::strstr(env, \"cuda\")) {\n                    filter = \"cuda\";\n                }\n                else if (std::strstr(env, \"hip\")) {\n                    filter = \"hip\";\n                }\n                else {\n                    throw std::runtime_error(\"invalid device filter: \" + std::string(env));\n                }\n            } else {\n                auto default_device = sycl::device(sycl::default_selector_v);\n                auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();\n\n                if (std::strstr(default_platform_name.c_str(), \"Level-Zero\") || default_device.is_cpu()) {\n                    filter = \"level-zero\";\n                }\n                else if (std::strstr(default_platform_name.c_str(), \"CUDA\")) {\n                    filter = 
\"cuda\";\n                }\n                else if (std::strstr(default_platform_name.c_str(), \"HIP\")) {\n                    filter = \"hip\";\n                }\n            }\n\n            auto platform_list = sycl::platform::get_platforms();\n\n            for (const auto& platform : platform_list) {\n                auto devices = platform.get_devices();\n                auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {\n                    return d.is_gpu();\n                });\n\n                if (gpu_dev == devices.end()) {\n                    // cout << \"platform [\" << platform_name\n                    //      << \"] does not contain GPU devices, skipping\\n\";\n                    continue;\n                }\n\n                auto platform_name = platform.get_info<sycl::info::platform::name>();\n                std::string platform_name_low_case;\n                platform_name_low_case.resize(platform_name.size());\n\n                std::transform(\n                    platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);\n\n                if (platform_name_low_case.find(filter) == std::string::npos) {\n                    // cout << \"platform [\" << platform_name\n                    //      << \"] does not match with requested \"\n                    //      << filter << \", skipping\\n\";\n                    continue;\n                }\n\n                result = platform_name;\n            }\n\n            if (result.empty())\n                throw std::runtime_error(\"can not find preferred GPU platform\");\n\n            return result;\n        }\n\n        template <class DeviceSelector>\n        std::enable_if_t<\n            std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>\n        select_device(const DeviceSelector &selector = sycl::gpu_selector_v)\n        {\n            sycl::device selected_device = sycl::device(selector);\n            
unsigned int selected_device_id = get_device_id(selected_device);\n            select_device(selected_device_id);\n        }\n\n        /// Returns the instance of device manager singleton.\n        static dev_mgr &instance()\n        {\n            static dev_mgr d_m;\n            return d_m;\n        }\n        dev_mgr(const dev_mgr &) = delete;\n        dev_mgr &operator=(const dev_mgr &) = delete;\n        dev_mgr(dev_mgr &&) = delete;\n        dev_mgr &operator=(dev_mgr &&) = delete;\n\n    private:\n        mutable std::recursive_mutex m_mutex;\n        static bool compare_dev(sycl::device &device1, sycl::device &device2)\n        {\n            sycl::backend backend1 = device1.get_backend();\n            sycl::backend backend2 = device2.get_backend();\n            // levelzero backends always come first\n            if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true;\n            if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false;\n            dpct::device_info prop1;\n            dpct::get_device_info(prop1, device1);\n            dpct::device_info prop2;\n            dpct::get_device_info(prop2, device2);\n            return prop1.get_max_compute_units() > prop2.get_max_compute_units();\n        }\n        static int convert_backend_index(std::string & backend) {\n            if (backend == \"ext_oneapi_level_zero:gpu\") return 0;\n            if (backend == \"opencl:gpu\") return 1;\n            if (backend == \"ext_oneapi_cuda:gpu\") return 2;\n            if (backend == \"ext_oneapi_hip:gpu\") return 3;\n            if (backend == \"opencl:cpu\") return 4;\n            if (backend == \"opencl:acc\") return 5;\n            printf(\"convert_backend_index: can't handle backend=%s\\n\", backend.c_str());\n            GGML_ABORT(\"fatal error\");\n        }\n        static bool compare_backend(std::string &backend1, std::string 
&backend2) {\n            return convert_backend_index(backend1) < convert_backend_index(backend2);\n        }\n        dev_mgr()\n        {\n            sycl::device default_device =\n                sycl::device(sycl::default_selector_v);\n            _devs.push_back(std::make_shared<device_ext>(default_device));\n\n            std::vector<sycl::device> sycl_all_devs;\n            // Collect other devices except for the default device.\n            if (default_device.is_cpu())\n                _cpu_device = 0;\n\n            auto Platforms = sycl::platform::get_platforms();\n            // Keep track of the number of devices per backend\n            std::map<sycl::backend, size_t> DeviceNums;\n            std::map<std::string, std::vector<sycl::device>> backend_devices;\n            auto preferred_platform_name = get_preferred_gpu_platform_name();\n\n            while (!Platforms.empty()) {\n                auto Platform = Platforms.back();\n                Platforms.pop_back();\n                auto platform_name = Platform.get_info<sycl::info::platform::name>();\n                if (platform_name.compare(preferred_platform_name) != 0) {\n                    continue;\n                }\n                auto devices = Platform.get_devices();\n                std::string backend_type = get_device_backend_and_type(devices[0]);\n                for (const auto &device : devices) {\n                    backend_devices[backend_type].push_back(device);\n                }\n            }\n\n            std::vector<std::string> keys;\n            for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {\n                keys.push_back(it->first);\n            }\n            std::sort(keys.begin(), keys.end(), compare_backend);\n\n            for (auto &key : keys) {\n                std::vector<sycl::device> devs = backend_devices[key];\n                std::sort(devs.begin(), devs.end(), compare_dev);\n                for (const auto &dev : devs) {\n   
                 sycl_all_devs.push_back(dev);\n                }\n            }\n\n            for (auto &dev : sycl_all_devs)\n            {\n                if (dev == default_device)\n                {\n                    continue;\n                }\n                _devs.push_back(std::make_shared<device_ext>(dev));\n                if (_cpu_device == -1 && dev.is_cpu())\n                {\n                    _cpu_device = _devs.size() - 1;\n                }\n            }\n        }\n        void check_id(unsigned int id) const\n        {\n            if (id >= _devs.size())\n            {\n                throw std::runtime_error(\"invalid device id\");\n            }\n        }\n        std::vector<std::shared_ptr<device_ext>> _devs;\n        /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current\n        /// thread id in _thread2dev_map, which means default device should be used\n        /// for the current thread.\n        const unsigned int DEFAULT_DEVICE_ID = 0;\n        /// thread-id to device-id map.\n        std::map<unsigned int, unsigned int> _thread2dev_map;\n        int _cpu_device = -1;\n    };\n\n    static inline sycl::queue &get_default_queue()\n    {\n        return dev_mgr::instance().current_device().default_queue();\n    }\n\n    namespace detail\n    {\n        enum class pointer_access_attribute\n        {\n            host_only = 0,\n            device_only,\n            host_device,\n            end\n        };\n\n        static pointer_access_attribute get_pointer_attribute(sycl::queue &q,\n                                                              const void *ptr)\n        {\n            switch (sycl::get_pointer_type(ptr, q.get_context()))\n            {\n            case sycl::usm::alloc::unknown:\n                return pointer_access_attribute::host_only;\n            case sycl::usm::alloc::device:\n                return pointer_access_attribute::device_only;\n            case 
sycl::usm::alloc::shared:\n            case sycl::usm::alloc::host:\n                return pointer_access_attribute::host_device;\n            }\n        }\n\n        template <typename ArgT>\n        inline constexpr std::uint64_t get_type_combination_id(ArgT Val)\n        {\n            static_assert((unsigned char)library_data_t::library_data_t_size <=\n                              std::numeric_limits<unsigned char>::max() &&\n                          \"library_data_t size exceeds limit.\");\n            static_assert(std::is_same_v<ArgT, library_data_t>, \"Unsupported ArgT\");\n            return (std::uint64_t)Val;\n        }\n\n        template <typename FirstT, typename... RestT>\n        inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal,\n                                                               RestT... RestVal)\n        {\n            static_assert((std::uint8_t)library_data_t::library_data_t_size <=\n                              std::numeric_limits<unsigned char>::max() &&\n                          \"library_data_t size exceeds limit.\");\n            static_assert(sizeof...(RestT) <= 8 && \"Too many parameters\");\n            static_assert(std::is_same_v<FirstT, library_data_t>, \"Unsupported FirstT\");\n            return get_type_combination_id(RestVal...) 
<< 8 | ((std::uint64_t)FirstVal);\n        }\n\n        class mem_mgr\n        {\n            mem_mgr()\n            {\n                // Reserved address space, no real memory allocation happens here.\n#if defined(__linux__)\n                mapped_address_space =\n                    (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE,\n                                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);\n#elif defined(_WIN64)\n                mapped_address_space = (byte_t *)VirtualAlloc(\n                    NULL,               // NULL specified as the base address parameter\n                    mapped_region_size, // Size of allocation\n                    MEM_RESERVE,        // Allocate reserved pages\n                    PAGE_NOACCESS);     // Protection = no access\n#else\n#error \"Only support Windows and Linux.\"\n#endif\n                next_free = mapped_address_space;\n            }\n\n        public:\n            using buffer_id_t = int;\n\n            struct allocation\n            {\n                buffer_t buffer;\n                byte_t *alloc_ptr;\n                size_t size;\n            };\n\n            ~mem_mgr()\n            {\n#if defined(__linux__)\n                munmap(mapped_address_space, mapped_region_size);\n#elif defined(_WIN64)\n                VirtualFree(mapped_address_space, 0, MEM_RELEASE);\n#else\n#error \"Only support Windows and Linux.\"\n#endif\n            }\n\n            mem_mgr(const mem_mgr &) = delete;\n            mem_mgr &operator=(const mem_mgr &) = delete;\n            mem_mgr(mem_mgr &&) = delete;\n            mem_mgr &operator=(mem_mgr &&) = delete;\n\n            /// Allocate\n            void *mem_alloc(size_t size)\n            {\n                if (!size)\n                    return nullptr;\n                std::lock_guard<std::mutex> lock(m_mutex);\n                if (next_free + size > mapped_address_space + mapped_region_size)\n                {\n                    throw 
std::runtime_error(\"dpct_malloc: out of memory for virtual memory pool\");\n                }\n                // Allocation\n                sycl::range<1> r(size);\n                buffer_t buf(r);\n                allocation A{buf, next_free, size};\n                // Map allocation to device pointer\n                void *result = next_free;\n                m_map.emplace(next_free + size, A);\n                // Update pointer to the next free space.\n                next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1);\n\n                return result;\n            }\n\n            /// Deallocate\n            void mem_free(const void *ptr)\n            {\n                if (!ptr)\n                    return;\n                std::lock_guard<std::mutex> lock(m_mutex);\n                auto it = get_map_iterator(ptr);\n                m_map.erase(it);\n            }\n\n            /// map: device pointer -> allocation(buffer, alloc_ptr, size)\n            allocation translate_ptr(const void *ptr)\n            {\n                std::lock_guard<std::mutex> lock(m_mutex);\n                auto it = get_map_iterator(ptr);\n                return it->second;\n            }\n\n            /// Check if the pointer represents device pointer or not.\n            bool is_device_ptr(const void *ptr) const\n            {\n                std::lock_guard<std::mutex> lock(m_mutex);\n                return (mapped_address_space <= ptr) &&\n                       (ptr < mapped_address_space + mapped_region_size);\n            }\n\n            /// Returns the instance of memory manager singleton.\n            static mem_mgr &instance()\n            {\n                static mem_mgr m;\n                return m;\n            }\n\n        private:\n            std::map<byte_t *, allocation> m_map;\n            mutable std::mutex m_mutex;\n            byte_t *mapped_address_space;\n            byte_t *next_free;\n            const size_t mapped_region_size = 
128ull * 1024 * 1024 * 1024;\n            const size_t alignment = 256;\n            /// This padding may be defined to some positive value to debug\n            /// out of bound accesses.\n            const size_t extra_padding = 0;\n\n            std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)\n            {\n                auto it = m_map.upper_bound(const_cast<byte_t *>(reinterpret_cast<const byte_t *>(ptr)));\n                if (it == m_map.end())\n                {\n                    // Not a virtual pointer.\n                    throw std::runtime_error(\"can not get buffer from non-virtual pointer\");\n                }\n                const allocation &alloc = it->second;\n                if (ptr < alloc.alloc_ptr)\n                {\n                    // Out of bound.\n                    // This may happen if there's a gap between allocations due to alignment\n                    // or extra padding and pointer points to this gap.\n                    throw std::runtime_error(\"invalid virtual pointer\");\n                }\n                return it;\n            }\n        };\n\n        template <class T, memory_region Memory, size_t Dimension>\n        class accessor;\n        template <memory_region Memory, class T = byte_t>\n        class memory_traits\n        {\n        public:\n            static constexpr sycl::access::target target =\n                sycl::access::target::device;\n            static constexpr sycl::access_mode mode =\n                (Memory == constant) ? 
sycl::access_mode::read\n                                     : sycl::access_mode::read_write;\n            static constexpr size_t type_size = sizeof(T);\n            using element_t =\n                typename std::conditional<Memory == constant, const T, T>::type;\n            using value_t = typename std::remove_cv<T>::type;\n            template <size_t Dimension = 1>\n            using accessor_t = typename std::conditional<\n                Memory == local, sycl::local_accessor<value_t, Dimension>,\n                sycl::accessor<T, Dimension, mode, target>>::type;\n            using pointer_t = T *;\n        };\n\n        static inline void *dpct_malloc(size_t size, sycl::queue &q)\n        {\n            return sycl::malloc_device(size, q.get_device(), q.get_context());\n        }\n\n#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F))\n        static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z,\n                                        sycl::queue &q)\n        {\n            pitch = PITCH_DEFAULT_ALIGN(x);\n            return dpct_malloc(pitch * y * z, q);\n        }\n\n        /**\n         * @brief Sets \\p value to the first \\p size elements starting from \\p dev_ptr in \\p q.\n         * @tparam valueT The type of the element to be set.\n         * @param [in] q The queue in which the operation is done.\n         * @param [in] dev_ptr Pointer to the virtual device memory address.\n         * @param [in] value The value to be set.\n         * @param [in] size Number of elements to be set to the value.\n         * @return An event representing the memset operation.\n         */\n        template <typename valueT>\n        static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr,\n                                              valueT value, size_t size)\n        {\n            return q.fill(dev_ptr, value, size);\n        }\n\n        /**\n         * @brief Sets \\p value to the 3D memory region pointed by \\p data in 
\\p q.\n         * @tparam valueT The type of the element to be set.\n         * @param [in] q The queue in which the operation is done.\n         * @param [in] data Pointer to the pitched device memory region.\n         * @param [in] value The value to be set.\n         * @param [in] size 3D memory region by number of elements.\n         * @return An event list representing the memset operations.\n         */\n        template <typename valueT>\n        static inline std::vector<sycl::event>\n        dpct_memset(sycl::queue &q, pitched_data data, valueT value,\n                    sycl::range<3> size)\n        {\n            std::vector<sycl::event> event_list;\n            size_t slice = data.get_pitch() * data.get_y();\n            unsigned char *data_surface = (unsigned char *)data.get_data_ptr();\n            for (size_t z = 0; z < size.get(2); ++z)\n            {\n                unsigned char *data_ptr = data_surface;\n                for (size_t y = 0; y < size.get(1); ++y)\n                {\n                    event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0)));\n                    data_ptr += data.get_pitch();\n                }\n                data_surface += slice;\n            }\n            return event_list;\n        }\n\n        /**\n         * @brief Sets \\p val to the pitched 2D memory region pointed by \\p ptr in \\p q.\n         * @tparam valueT The type of the element to be set.\n         * @param [in] q The queue in which the operation is done.\n         * @param [in] ptr Pointer to the virtual device memory.\n         * @param [in] pitch The pitch size by number of elements, including padding.\n         * @param [in] val The value to be set.\n         * @param [in] x The width of memory region by number of elements.\n         * @param [in] y The height of memory region by number of elements.\n         * @return An event list representing the memset operations.\n         */\n        template <typename valueT>\n        
static inline std::vector<sycl::event>\n        dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x,\n                    size_t y)\n        {\n            return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val,\n                               sycl::range<3>(x, y, 1));\n        }\n\n        static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr,\n                                                        const void *from_ptr,\n                                                        memcpy_direction dir)\n        {\n            switch (dir)\n            {\n            case memcpy_direction::host_to_host:\n            case memcpy_direction::host_to_device:\n            case memcpy_direction::device_to_host:\n            case memcpy_direction::device_to_device:\n                return dir;\n            case memcpy_direction::automatic:\n            {\n                // table[to_attribute][from_attribute]\n                static const memcpy_direction\n                    direction_table[static_cast<unsigned>(pointer_access_attribute::end)]\n                                   [static_cast<unsigned>(pointer_access_attribute::end)] =\n                                       {{memcpy_direction::host_to_host,\n                                         memcpy_direction::device_to_host,\n                                         memcpy_direction::host_to_host},\n                                        {memcpy_direction::host_to_device,\n                                         memcpy_direction::device_to_device,\n                                         memcpy_direction::device_to_device},\n                                        {memcpy_direction::host_to_host,\n                                         memcpy_direction::device_to_device,\n                                         memcpy_direction::device_to_device}};\n                return direction_table[static_cast<unsigned>(get_pointer_attribute(\n                    q, 
to_ptr))][static_cast<unsigned>(get_pointer_attribute(q, from_ptr))];\n            }\n            default:\n                throw std::runtime_error(\"dpct_memcpy: invalid direction value\");\n            }\n        }\n\n        static sycl::event\n        dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,\n                    memcpy_direction direction,\n                    const std::vector<sycl::event> &dep_events = {})\n        {\n            if (!size)\n                return sycl::event{};\n            return q.memcpy(to_ptr, from_ptr, size, dep_events);\n            GGML_UNUSED(direction);\n        }\n\n        // Get actual copy range and make sure it will not exceed range.\n        static inline size_t get_copy_range(sycl::range<3> size, size_t slice,\n                                            size_t pitch)\n        {\n            return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);\n        }\n\n        static inline size_t get_offset(sycl::id<3> id, size_t slice,\n                                        size_t pitch)\n        {\n            return slice * id.get(2) + pitch * id.get(1) + id.get(0);\n        }\n\n        /// copy 3D matrix specified by \\p size from 3D matrix specified by \\p from_ptr\n        /// and \\p from_range to another specified by \\p to_ptr and \\p to_range.\n        static inline std::vector<sycl::event>\n        dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,\n                    sycl::range<3> to_range, sycl::range<3> from_range,\n                    sycl::id<3> to_id, sycl::id<3> from_id,\n                    sycl::range<3> size, memcpy_direction direction,\n                    const std::vector<sycl::event> &dep_events = {})\n        {\n            // RAII for host pointer\n            class host_buffer\n            {\n                void *_buf;\n                size_t _size;\n                sycl::queue &_q;\n                const std::vector<sycl::event> 
&_deps; // free operation depends\n\n            public:\n                host_buffer(size_t size, sycl::queue &q,\n                            const std::vector<sycl::event> &deps)\n                    : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}\n                void *get_ptr() const { return _buf; }\n                size_t get_size() const { return _size; }\n                ~host_buffer()\n                {\n                    if (_buf)\n                    {\n                        _q.submit([&](sycl::handler &cgh)\n                                  {\n        cgh.depends_on(_deps);\n        cgh.host_task([buf = _buf] { std::free(buf); }); });\n                    }\n                }\n            };\n            std::vector<sycl::event> event_list;\n\n            size_t to_slice = to_range.get(1) * to_range.get(0),\n                   from_slice = from_range.get(1) * from_range.get(0);\n            unsigned char *to_surface =\n                (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));\n            const unsigned char *from_surface =\n                (const unsigned char *)from_ptr +\n                get_offset(from_id, from_slice, from_range.get(0));\n\n            if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))\n            {\n                return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),\n                                    direction, dep_events)};\n            }\n            direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);\n            size_t size_slice = size.get(1) * size.get(0);\n            switch (direction)\n            {\n            case host_to_host:\n                for (size_t z = 0; z < size.get(2); ++z)\n                {\n                    unsigned char *to_ptr = to_surface;\n                    const unsigned char *from_ptr = from_surface;\n                    if (to_range.get(0) == from_range.get(0) &&\n                        
to_range.get(0) == size.get(0))\n                    {\n                        event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,\n                                                         direction, dep_events));\n                    }\n                    else\n                    {\n                        for (size_t y = 0; y < size.get(1); ++y)\n                        {\n                            event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),\n                                                             direction, dep_events));\n                            to_ptr += to_range.get(0);\n                            from_ptr += from_range.get(0);\n                        }\n                    }\n                    to_surface += to_slice;\n                    from_surface += from_slice;\n                }\n                break;\n            case host_to_device:\n            {\n                host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,\n                                event_list);\n                std::vector<sycl::event> host_events;\n                if (to_slice == size_slice)\n                {\n                    // Copy host data to a temp host buffer with the shape of target.\n                    host_events =\n                        dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,\n                                    sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,\n                                    host_to_host, dep_events);\n                }\n                else\n                {\n                    // Copy host data to a temp host buffer with the shape of target.\n                    host_events = dpct_memcpy(\n                        q, buf.get_ptr(), from_surface, to_range, from_range,\n                        sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,\n                        // If has padding data, not sure whether it is useless. 
So fill temp\n                        // buffer with it.\n                        std::vector<sycl::event>{\n                            dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),\n                                        device_to_host, dep_events)});\n                }\n                // Copy from temp host buffer to device with only one submit.\n                event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),\n                                                 buf.get_size(), host_to_device,\n                                                 host_events));\n                break;\n            }\n            case device_to_host:\n            {\n                host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,\n                                event_list);\n                // Copy from host temp buffer to host target with reshaping.\n                event_list = dpct_memcpy(\n                    q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),\n                    sycl::id<3>(0, 0, 0), size, host_to_host,\n                    // Copy from device to temp host buffer with only one submit.\n                    std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,\n                                                         buf.get_size(),\n                                                         device_to_host, dep_events)});\n                break;\n            }\n            case device_to_device:\n                event_list.push_back(q.submit([&](sycl::handler &cgh){\n                cgh.depends_on(dep_events);\n                cgh.parallel_for<class dpct_memcpy_3d_detail>(\n                    size,\n                    [=](sycl::id<3> id) {\n                        to_surface[get_offset(id, to_slice, to_range.get(0))] =\n                            from_surface[get_offset(id, from_slice, from_range.get(0))];\n                    }); }));\n                break;\n            
default:\n                throw std::runtime_error(\"dpct_memcpy: invalid direction value\");\n            }\n            return event_list;\n        }\n\n        /// memcpy 2D/3D matrix specified by pitched_data.\n        static inline std::vector<sycl::event>\n        dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,\n                    pitched_data from, sycl::id<3> from_id, sycl::range<3> size,\n                    memcpy_direction direction = automatic)\n        {\n            return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),\n                               sycl::range<3>(to.get_pitch(), to.get_y(), 1),\n                               sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,\n                               size, direction);\n        }\n\n        /// memcpy 2D matrix with pitch.\n        static inline std::vector<sycl::event>\n        dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,\n                    size_t to_pitch, size_t from_pitch, size_t x, size_t y,\n                    memcpy_direction direction = automatic)\n        {\n            return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),\n                               sycl::range<3>(from_pitch, y, 1),\n                               sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),\n                               sycl::range<3>(x, y, 1), direction);\n        }\n\n        namespace deprecated\n        {\n\n            template <typename T, sycl::usm::alloc AllocKind>\n            class usm_allocator\n            {\n            private:\n                using Alloc = sycl::usm_allocator<T, AllocKind>;\n                Alloc _impl;\n\n            public:\n                using value_type = typename std::allocator_traits<Alloc>::value_type;\n                using pointer = typename std::allocator_traits<Alloc>::pointer;\n                using const_pointer = typename std::allocator_traits<Alloc>::const_pointer;\n                
using void_pointer = typename std::allocator_traits<Alloc>::void_pointer;\n                using const_void_pointer =\n                    typename std::allocator_traits<Alloc>::const_void_pointer;\n                using reference = typename std::allocator_traits<Alloc>::value_type &;\n                using const_reference =\n                    const typename std::allocator_traits<Alloc>::value_type &;\n                using difference_type =\n                    typename std::allocator_traits<Alloc>::difference_type;\n                using size_type = typename std::allocator_traits<Alloc>::size_type;\n                using propagate_on_container_copy_assignment = typename std::allocator_traits<\n                    Alloc>::propagate_on_container_copy_assignment;\n                using propagate_on_container_move_assignment = typename std::allocator_traits<\n                    Alloc>::propagate_on_container_move_assignment;\n                using propagate_on_container_swap =\n                    typename std::allocator_traits<Alloc>::propagate_on_container_swap;\n                using is_always_equal =\n                    typename std::allocator_traits<Alloc>::is_always_equal;\n\n                template <typename U>\n                struct rebind\n                {\n                    typedef usm_allocator<U, AllocKind> other;\n                };\n\n                usm_allocator() : _impl(dpct::get_default_queue()) {}\n                ~usm_allocator() {}\n                usm_allocator(const usm_allocator &other) : _impl(other._impl) {}\n                usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {}\n                pointer address(reference r) { return &r; }\n                const_pointer address(const_reference r) { return &r; }\n                pointer allocate(size_type cnt, const_void_pointer hint = nullptr)\n                {\n                    return std::allocator_traits<Alloc>::allocate(_impl, cnt, hint);\n                
}\n                void deallocate(pointer p, size_type cnt)\n                {\n                    std::allocator_traits<Alloc>::deallocate(_impl, p, cnt);\n                }\n                size_type max_size() const\n                {\n                    return std::allocator_traits<Alloc>::max_size(_impl);\n                }\n                bool operator==(const usm_allocator &other) const { return _impl == other._impl; }\n                bool operator!=(const usm_allocator &other) const { return _impl != other._impl; }\n            };\n\n        } // namespace deprecated\n\n        inline void dpct_free(void *ptr,\n                              const sycl::queue &q)\n        {\n            if (ptr)\n            {\n                sycl::free(ptr, q.get_context());\n            }\n        }\n\n        template <typename T>\n        inline auto get_memory(const void *x)\n        {\n            T *new_x = reinterpret_cast<T *>(const_cast<void *>(x));\n            return new_x;\n        }\n\n        template <typename T>\n        inline typename DataType<T>::T2 get_value(const T *s, sycl::queue &q)\n        {\n            using Ty = typename DataType<T>::T2;\n            Ty s_h;\n            if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only)\n                detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host)\n                    .wait();\n            else\n                s_h = *reinterpret_cast<const Ty *>(s);\n            return s_h;\n        }\n\n    } // namespace detail\n\n    template <typename T>\n    inline auto get_value(const T *s, sycl::queue &q)\n    {\n        return detail::get_value(s, q);\n    }\n\n    namespace detail\n    {\n    template <class Ta, class Tb, class Tc, class Ts>\n    inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,\n                          int n, int k, const void * alpha, const void * a, int lda, const void * b, 
int ldb,\n                          const void * beta, void * c, int ldc) {\n        Ts   alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);\n        Ts   beta_value  = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);\n        auto data_a      = get_memory<const Ta>(a);\n        auto data_b      = get_memory<const Tb>(b);\n        auto data_c      = get_memory<Tc>(c);\n        oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,\n                                               lda, data_b, ldb, beta_value, data_c, ldc);\n    }\n\n        template <typename VecT, class BinaryOperation, class = void>\n        class vectorized_binary\n        {\n        public:\n            inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)\n            {\n                VecT v4;\n                for (size_t i = 0; i < v4.size(); ++i)\n                {\n                    v4[i] = binary_op(a[i], b[i]);\n                }\n                return v4;\n            }\n        };\n\n        template <typename VecT, class BinaryOperation>\n        class vectorized_binary<\n            VecT, BinaryOperation,\n            std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>>\n        {\n        public:\n            inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)\n            {\n                return binary_op(a, b).template as<VecT>();\n            }\n        };\n\n        template <class Ta, class Tb, class Tc, class Ts>\n        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,\n                                    int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,\n                                    int ldb, const void * beta, void ** c, int ldc, int batch_size,\n                                    matrix_info_t<float> * matrix_info) {\n            Ts alpha_value = 
dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);\n            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);\n\n            matrix_info->transpose_info[0] = a_trans;\n            matrix_info->transpose_info[1] = b_trans;\n            matrix_info->value_info[0] = alpha_value;\n            matrix_info->value_info[1] = beta_value;\n            matrix_info->size_info[0] = m;\n            matrix_info->size_info[1] = n;\n            matrix_info->size_info[2] = k;\n            matrix_info->ld_info[0] = lda;\n            matrix_info->ld_info[1] = ldb;\n            matrix_info->ld_info[2] = ldc;\n            matrix_info->groupsize_info = batch_size;\n\n            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(\n                q, matrix_info->transpose_info, matrix_info->transpose_info + 1,\n                matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,\n                reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,\n                reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,\n                reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),\n                matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));\n        }\n\n        template <class Ta, class Tb, class Tc, class Ts>\n        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,\n                                    int m, int n, int k, const void * alpha, const void * a, int lda,\n                                    long long int stride_a, const void * b, int ldb, long long int stride_b,\n                                    const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {\n            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);\n            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts 
*>(beta), q);\n            auto data_a = get_memory<const Ta>(a);\n            auto data_b = get_memory<const Tb>(b);\n            auto data_c = get_memory<Tc>(c);\n            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,\n                                                         data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,\n                                                         data_c, ldc, stride_c, batch_size);\n        }\n\n    } // namespace detail\n\n    template <typename VecT, class BinaryOperation>\n    inline unsigned vectorized_binary(unsigned a, unsigned b,\n                                      const BinaryOperation binary_op)\n    {\n        sycl::vec<unsigned, 1> v0{a}, v1{b};\n        auto v2 = v0.as<VecT>();\n        auto v3 = v1.as<VecT>();\n        auto v4 =\n            detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);\n        v0 = v4.template as<sycl::vec<unsigned, 1>>();\n        return v0;\n    }\n\n    static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size,\n                                  memcpy_direction direction = automatic,\n                                  sycl::queue &q = dpct::get_default_queue())\n    {\n        detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction);\n    }\n\n    static inline unsigned int select_device(unsigned int id)\n    {\n        dev_mgr::instance().select_device(id);\n        return id;\n    }\n\n    template <typename T>\n    T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,\n                               unsigned int logical_sub_group_size = 32)\n    {\n        unsigned int id = g.get_local_linear_id();\n        unsigned int start_index =\n            id / logical_sub_group_size * logical_sub_group_size;\n        unsigned int target_offset = (id % logical_sub_group_size) ^ mask;\n        return sycl::select_from_group(g, x,\n                                       target_offset < 
logical_sub_group_size\n                                           ? start_index + target_offset\n                                           : id);\n    }\n\n    template <typename T1, typename T2>\n    using dot_product_acc_t = std::conditional_t<\n        std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,\n        uint32_t,\n        int32_t>;\n\n    template <typename T>\n    sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {\n      return sycl::vec<T, 1>(val)\n          .template as<sycl::vec<\n              std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>,\n              4>>()\n          .template convert<T>();\n    }\n\n    template <typename T1, typename T2, typename T3>\n    inline auto dp4a(T1 a, T2 b, T3 c) {\n      dot_product_acc_t<T1, T2> res = c;\n      auto va = extract_and_sign_or_zero_extend4(a);\n      auto vb = extract_and_sign_or_zero_extend4(b);\n      res += va[0] * vb[0];\n      res += va[1] * vb[1];\n      res += va[2] * vb[2];\n      res += va[3] * vb[3];\n      return res;\n    }\n\n    struct sub_sat\n    {\n        template <typename T>\n        auto operator()(const T x, const T y) const\n        {\n            return sycl::sub_sat(x, y);\n        }\n    };\n\n    template <typename S, typename T>\n    inline T vectorized_min(T a, T b)\n    {\n        sycl::vec<T, 1> v0{a}, v1{b};\n        auto v2 = v0.template as<S>();\n        auto v3 = v1.template as<S>();\n        auto v4 = sycl::min(v2, v3);\n        v0 = v4.template as<sycl::vec<T, 1>>();\n        return v0;\n    }\n\n    inline float pow(const float a, const int b) { return sycl::pown(a, b); }\n    inline double pow(const double a, const int b) { return sycl::pown(a, b); }\n    inline float pow(const float a, const float b) { return sycl::pow(a, b); }\n    inline double pow(const double a, const double b) { return sycl::pow(a, b); }\n    template <typename T, typename U>\n    inline typename std::enable_if_t<std::is_floating_point_v<T>, T>\n    pow(const T a, const U 
b)\n    {\n        return sycl::pow(a, static_cast<T>(b));\n    }\n    template <typename T, typename U>\n    inline typename std::enable_if_t<!std::is_floating_point_v<T>, double>\n    pow(const T a, const U b)\n    {\n        return sycl::pow(static_cast<double>(a), static_cast<double>(b));\n    }\n\n    inline double min(const double a, const float b)\n    {\n        return sycl::fmin(a, static_cast<double>(b));\n    }\n    inline double min(const float a, const double b)\n    {\n        return sycl::fmin(static_cast<double>(a), b);\n    }\n    inline float min(const float a, const float b) { return sycl::fmin(a, b); }\n    inline double min(const double a, const double b) { return sycl::fmin(a, b); }\n    inline std::uint32_t min(const std::uint32_t a, const std::int32_t b)\n    {\n        return sycl::min(a, static_cast<std::uint32_t>(b));\n    }\n    inline std::uint32_t min(const std::int32_t a, const std::uint32_t b)\n    {\n        return sycl::min(static_cast<std::uint32_t>(a), b);\n    }\n    inline std::int32_t min(const std::int32_t a, const std::int32_t b)\n    {\n        return sycl::min(a, b);\n    }\n    inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b)\n    {\n        return sycl::min(a, b);\n    }\n    inline std::uint64_t min(const std::uint64_t a, const std::int64_t b)\n    {\n        return sycl::min(a, static_cast<std::uint64_t>(b));\n    }\n    inline std::uint64_t min(const std::int64_t a, const std::uint64_t b)\n    {\n        return sycl::min(static_cast<std::uint64_t>(a), b);\n    }\n    inline std::int64_t min(const std::int64_t a, const std::int64_t b)\n    {\n        return sycl::min(a, b);\n    }\n    inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b)\n    {\n        return sycl::min(a, b);\n    }\n    inline std::uint64_t min(const std::uint64_t a, const std::int32_t b)\n    {\n        return sycl::min(a, static_cast<std::uint64_t>(b));\n    }\n    inline std::uint64_t min(const std::int32_t 
a, const std::uint64_t b)\n    {\n        return sycl::min(static_cast<std::uint64_t>(a), b);\n    }\n    inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b)\n    {\n        return sycl::min(a, static_cast<std::uint64_t>(b));\n    }\n    inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b)\n    {\n        return sycl::min(static_cast<std::uint64_t>(a), b);\n    }\n    // max function overloads.\n    // For floating-point types, `float` or `double` arguments are acceptable.\n    // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or\n    // `std::int64_t` type arguments are acceptable.\n    inline double max(const double a, const float b)\n    {\n        return sycl::fmax(a, static_cast<double>(b));\n    }\n    inline double max(const float a, const double b)\n    {\n        return sycl::fmax(static_cast<double>(a), b);\n    }\n    inline float max(const float a, const float b) { return sycl::fmax(a, b); }\n    inline double max(const double a, const double b) { return sycl::fmax(a, b); }\n    inline std::uint32_t max(const std::uint32_t a, const std::int32_t b)\n    {\n        return sycl::max(a, static_cast<std::uint32_t>(b));\n    }\n    inline std::uint32_t max(const std::int32_t a, const std::uint32_t b)\n    {\n        return sycl::max(static_cast<std::uint32_t>(a), b);\n    }\n    inline std::int32_t max(const std::int32_t a, const std::int32_t b)\n    {\n        return sycl::max(a, b);\n    }\n    inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b)\n    {\n        return sycl::max(a, b);\n    }\n    inline std::uint64_t max(const std::uint64_t a, const std::int64_t b)\n    {\n        return sycl::max(a, static_cast<std::uint64_t>(b));\n    }\n    inline std::uint64_t max(const std::int64_t a, const std::uint64_t b)\n    {\n        return sycl::max(static_cast<std::uint64_t>(a), b);\n    }\n    inline std::int64_t max(const std::int64_t a, const std::int64_t b)\n    {\n        return 
sycl::max(a, b);\n    }\n    inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b)\n    {\n        return sycl::max(a, b);\n    }\n    inline std::uint64_t max(const std::uint64_t a, const std::int32_t b)\n    {\n        return sycl::max(a, static_cast<std::uint64_t>(b));\n    }\n    inline std::uint64_t max(const std::int32_t a, const std::uint64_t b)\n    {\n        return sycl::max(static_cast<std::uint64_t>(a), b);\n    }\n    inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b)\n    {\n        return sycl::max(a, static_cast<std::uint64_t>(b));\n    }\n    inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b)\n    {\n        return sycl::max(static_cast<std::uint64_t>(a), b);\n    }\n\n    inline void\n    has_capability_or_fail(const sycl::device &dev,\n                           const std::initializer_list<sycl::aspect> &props)\n    {\n        for (const auto &it : props)\n        {\n            if (dev.has(it))\n                continue;\n            switch (it)\n            {\n            case sycl::aspect::fp64:\n                throw std::runtime_error(\"'double' is not supported in '\" +\n                                         dev.get_info<sycl::info::device::name>() +\n                                         \"' device\");\n                break;\n            case sycl::aspect::fp16:\n                throw std::runtime_error(\"'half' is not supported in '\" +\n                                         dev.get_info<sycl::info::device::name>() +\n                                         \"' device\");\n                break;\n            default:\n#define __SYCL_ASPECT(ASPECT, ID) \\\n    case sycl::aspect::ASPECT:    \\\n        return #ASPECT;\n#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)\n#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)\n                auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string\n                {\n                
    switch (AspectNum)\n                    {\n#include <sycl/info/aspects.def>\n#include <sycl/info/aspects_deprecated.def>\n                    default:\n                        return \"unknown aspect\";\n                    }\n                };\n#undef __SYCL_ASPECT_DEPRECATED_ALIAS\n#undef __SYCL_ASPECT_DEPRECATED\n#undef __SYCL_ASPECT\n                throw std::runtime_error(\n                    \"'\" + getAspectNameStr(it) + \"' is not supported in '\" +\n                    dev.get_info<sycl::info::device::name>() + \"' device\");\n            }\n            break;\n        }\n    }\n\n    static inline unsigned int get_current_device_id()\n    {\n        return dev_mgr::instance().current_device_id();\n    }\n\n    static inline device_ext &get_current_device()\n    {\n        return dev_mgr::instance().current_device();\n    }\n\n    static inline device_ext &get_device(unsigned int id)\n    {\n        return dev_mgr::instance().get_device(id);\n    }\n\n    static inline sycl::queue &get_in_order_queue()\n    {\n        return dev_mgr::instance().current_device().in_order_queue();\n    }\n\n    static sycl::event\n    dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,\n                memcpy_direction direction,\n                const std::vector<sycl::event> &dep_events = {})\n    {\n        if (!size)\n            return sycl::event{};\n        return q.memcpy(to_ptr, from_ptr, size, dep_events);\n        GGML_UNUSED(direction);\n    }\n\n    // Get actual copy range and make sure it will not exceed range.\n    static inline size_t get_copy_range(sycl::range<3> size, size_t slice,\n                                        size_t pitch)\n    {\n        return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);\n    }\n\n    static inline size_t get_offset(sycl::id<3> id, size_t slice,\n                                    size_t pitch)\n    {\n        return slice * id.get(2) + pitch * id.get(1) + id.get(0);\n 
   }\n\n    /// copy 3D matrix specified by \\p size from 3D matrix specified by \\p from_ptr\n    /// and \\p from_range to another specified by \\p to_ptr and \\p to_range.\n    static inline std::vector<sycl::event>\n    dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,\n                sycl::range<3> to_range, sycl::range<3> from_range,\n                sycl::id<3> to_id, sycl::id<3> from_id,\n                sycl::range<3> size, memcpy_direction direction,\n                const std::vector<sycl::event> &dep_events = {})\n    {\n        // RAII for host pointer\n        class host_buffer\n        {\n            void *_buf;\n            size_t _size;\n            sycl::queue &_q;\n            const std::vector<sycl::event> &_deps; // free operation depends\n\n        public:\n            host_buffer(size_t size, sycl::queue &q,\n                        const std::vector<sycl::event> &deps)\n                : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}\n            void *get_ptr() const { return _buf; }\n            size_t get_size() const { return _size; }\n            ~host_buffer()\n            {\n                if (_buf)\n                {\n                    _q.submit([&](sycl::handler &cgh)\n                              {\n            cgh.depends_on(_deps);\n            cgh.host_task([buf = _buf] { std::free(buf); }); });\n                }\n            }\n        };\n        std::vector<sycl::event> event_list;\n\n        size_t to_slice = to_range.get(1) * to_range.get(0),\n               from_slice = from_range.get(1) * from_range.get(0);\n        unsigned char *to_surface =\n            (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));\n        const unsigned char *from_surface =\n            (const unsigned char *)from_ptr +\n            get_offset(from_id, from_slice, from_range.get(0));\n\n        if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))\n        {\n            return 
{dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),\n                                direction, dep_events)};\n        }\n        direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction);\n        size_t size_slice = size.get(1) * size.get(0);\n        switch (direction)\n        {\n        case host_to_host:\n            for (size_t z = 0; z < size.get(2); ++z)\n            {\n                unsigned char *to_ptr = to_surface;\n                const unsigned char *from_ptr = from_surface;\n                if (to_range.get(0) == from_range.get(0) &&\n                    to_range.get(0) == size.get(0))\n                {\n                    event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,\n                                                     direction, dep_events));\n                }\n                else\n                {\n                    for (size_t y = 0; y < size.get(1); ++y)\n                    {\n                        event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),\n                                                         direction, dep_events));\n                        to_ptr += to_range.get(0);\n                        from_ptr += from_range.get(0);\n                    }\n                }\n                to_surface += to_slice;\n                from_surface += from_slice;\n            }\n            break;\n        case host_to_device:\n        {\n            host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,\n                            event_list);\n            std::vector<sycl::event> host_events;\n            if (to_slice == size_slice)\n            {\n                // Copy host data to a temp host buffer with the shape of target.\n                host_events =\n                    dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,\n                                sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,\n                         
       host_to_host, dep_events);\n            }\n            else\n            {\n                // Copy host data to a temp host buffer with the shape of target.\n                host_events = dpct_memcpy(\n                    q, buf.get_ptr(), from_surface, to_range, from_range,\n                    sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,\n                    // If has padding data, not sure whether it is useless. So fill temp\n                    // buffer with it.\n                    std::vector<sycl::event>{\n                        dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),\n                                    device_to_host, dep_events)});\n            }\n            // Copy from temp host buffer to device with only one submit.\n            event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),\n                                             buf.get_size(), host_to_device,\n                                             host_events));\n            break;\n        }\n        case device_to_host:\n        {\n            host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,\n                            event_list);\n            // Copy from host temp buffer to host target with reshaping.\n            event_list = dpct_memcpy(\n                q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),\n                sycl::id<3>(0, 0, 0), size, host_to_host,\n                // Copy from device to temp host buffer with only one submit.\n                std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,\n                                                     buf.get_size(),\n                                                     device_to_host, dep_events)});\n            break;\n        }\n        case device_to_device:\n            event_list.push_back(q.submit([&](sycl::handler &cgh)\n                                          {\n        cgh.depends_on(dep_events);\n   
     cgh.parallel_for<class dpct_memcpy_3d_detail>(\n            size,\n            [=](sycl::id<3> id) {\n                to_surface[get_offset(id, to_slice, to_range.get(0))] =\n                    from_surface[get_offset(id, from_slice, from_range.get(0))];\n            }); }));\n        break;\n        default:\n            throw std::runtime_error(\"dpct_memcpy: invalid direction value\");\n        }\n        return event_list;\n    }\n\n    /// memcpy 2D/3D matrix specified by pitched_data.\n    static inline std::vector<sycl::event>\n    dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,\n                pitched_data from, sycl::id<3> from_id, sycl::range<3> size,\n                memcpy_direction direction = automatic)\n    {\n        return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),\n                           sycl::range<3>(to.get_pitch(), to.get_y(), 1),\n                           sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,\n                           size, direction);\n    }\n\n    /// memcpy 2D matrix with pitch.\n    static inline std::vector<sycl::event>\n    dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,\n                size_t to_pitch, size_t from_pitch, size_t x, size_t y,\n                memcpy_direction direction = automatic)\n    {\n        return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),\n                           sycl::range<3>(from_pitch, y, 1),\n                           sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),\n                           sycl::range<3>(x, y, 1), direction);\n    }\n\n    inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,\n                     int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,\n                     library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,\n                     
library_data_t scaling_type) {\n        if (scaling_type == library_data_t::real_float &&\n            c_type == library_data_t::complex_float)\n        {\n            scaling_type = library_data_t::complex_float;\n        }\n        else if (scaling_type == library_data_t::real_double &&\n                 c_type == library_data_t::complex_double)\n        {\n            scaling_type = library_data_t::complex_double;\n        }\n\n        std::uint64_t key =\n            detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);\n        switch (key)\n        {\n        case detail::get_type_combination_id(\n            library_data_t::real_float, library_data_t::real_float,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_impl<float, float, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_double, library_data_t::real_double,\n            library_data_t::real_double, library_data_t::real_double):\n        {\n            detail::gemm_impl<double, double, double, double>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::complex_float, library_data_t::complex_float,\n            library_data_t::complex_float, library_data_t::complex_float):\n        {\n            detail::gemm_impl<std::complex<float>, std::complex<float>,\n                              std::complex<float>, std::complex<float>>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::complex_double, library_data_t::complex_double,\n            library_data_t::complex_double, 
library_data_t::complex_double):\n        {\n            detail::gemm_impl<std::complex<double>, std::complex<double>,\n                              std::complex<double>, std::complex<double>>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_half, library_data_t::real_half):\n        {\n            detail::gemm_impl<sycl::half, sycl::half, sycl::half,\n                              sycl::half>(q, a_trans, b_trans, m, n, k, alpha, a,\n                                          lda, b, ldb, beta, c, ldc);\n            break;\n        }\n#ifdef __INTEL_MKL__\n        case detail::get_type_combination_id(\n            library_data_t::real_bfloat16, library_data_t::real_bfloat16,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_impl<sycl::half, sycl::half, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_half, library_data_t::real_float):\n        {\n            float alpha_value =\n                dpct::get_value(reinterpret_cast<const float *>(alpha), q);\n            float beta_value =\n                dpct::get_value(reinterpret_cast<const float *>(beta), q);\n            
sycl::half alpha_half(alpha_value);\n            sycl::half beta_half(beta_value);\n            detail::gemm_impl<sycl::half, sycl::half, sycl::half,\n                              sycl::half>(q, a_trans, b_trans, m, n, k, &alpha_half,\n                                          a, lda, b, ldb, &beta_half, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_int8, library_data_t::real_int8,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_impl<std::int8_t, std::int8_t, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_bfloat16, library_data_t::real_bfloat16,\n            library_data_t::real_bfloat16, library_data_t::real_float):\n        {\n            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_int8, library_data_t::real_int8,\n            library_data_t::real_int32, library_data_t::real_int32):\n        {\n            float alpha_float =\n                dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);\n            float beta_float =\n                dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);\n            detail::gemm_impl<std::int8_t, std::int8_t, std::int32_t, float>(\n                q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);\n            break;\n        }\n#endif // __INTEL_MKL__\n        default:\n            throw std::runtime_error(\"the combination of data type is unsupported\");\n        }\n    }  // gemm()\n\n    /// Computes a batch of 
matrix-matrix product with general matrices.\n    /// \\param [in] q The queue where the routine should be executed.\n    /// \\param [in] a_trans Specifies the operation applied to A.\n    /// \\param [in] b_trans Specifies the operation applied to B.\n    /// \\param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.\n    /// \\param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.\n    /// \\param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).\n    /// \\param [in] alpha Scaling factor for the matrix-matrix product.\n    /// \\param [in] a Input matrix A.\n    /// \\param [in] a_type Data type of the matrix A.\n    /// \\param [in] lda Leading dimension of A.\n    /// \\param [in] b Input matrix B.\n    /// \\param [in] b_type Data type of the matrix B.\n    /// \\param [in] ldb Leading dimension of B.\n    /// \\param [in] beta Scaling factor for matrix C.\n    /// \\param [in, out] c Input/Output matrix C.\n    /// \\param [in] c_type Data type of the matrix C.\n    /// \\param [in] ldc Leading dimension of C.\n    /// \\param [in] batch_size Specifies the number of matrix multiply operations to perform.\n    /// \\param [in] scaling_type Data type of the scaling factors.\n    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,\n                           int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,\n                           const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],\n                           library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,\n                           matrix_info_t<float> * matrix_info) {\n        std::uint64_t key =\n            detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);\n        switch (key)\n        {\n        case 
detail::get_type_combination_id(\n            library_data_t::real_float, library_data_t::real_float,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,\n                                                                beta, c, ldc, batch_size, matrix_info);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_double, library_data_t::real_double,\n            library_data_t::real_double, library_data_t::real_double):\n        {\n            detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,\n                                                                    beta, c, ldc, batch_size, matrix_info);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_half, library_data_t::real_half):\n        {\n            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);\n            break;\n        }\n#ifdef __INTEL_MKL__\n        case detail::get_type_combination_id(\n            library_data_t::real_bfloat16, library_data_t::real_bfloat16,\n            library_data_t::real_bfloat16, library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_bfloat16, library_data_t::real_bfloat16,\n            library_data_t::real_float, library_data_t::real_float):\n    
    {\n            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);\n            break;\n        }\n#endif\n        case detail::get_type_combination_id(\n            library_data_t::real_int8, library_data_t::real_int8,\n            library_data_t::real_int32, library_data_t::real_int32):\n        {\n            float alpha_float =\n                dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);\n            float beta_float =\n                dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);\n            detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(\n                q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,\n                matrix_info);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_int8, library_data_t::real_int8,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_half, library_data_t::real_float):\n        {\n            float alpha_value =\n         
       dpct::get_value(reinterpret_cast<const float *>(alpha), q);\n            float beta_value =\n                dpct::get_value(reinterpret_cast<const float *>(beta), q);\n            sycl::half alpha_half(alpha_value);\n            sycl::half beta_half(beta_value);\n            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(\n                q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);\n            break;\n        }\n        default:\n            throw std::runtime_error(\"the combination of data type is unsupported\");\n        }\n    }\n\n    /// Computes a batch of matrix-matrix product with general matrices.\n    /// \\param [in] q The queue where the routine should be executed.\n    /// \\param [in] a_trans Specifies the operation applied to A.\n    /// \\param [in] b_trans Specifies the operation applied to B.\n    /// \\param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.\n    /// \\param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.\n    /// \\param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).\n    /// \\param [in] alpha Scaling factor for the matrix-matrix product.\n    /// \\param [in] a Input matrix A.\n    /// \\param [in] a_type Data type of the matrix A.\n    /// \\param [in] lda Leading dimension of A.\n    /// \\param [in] stride_a Stride between the different A matrices.\n    /// \\param [in] b Input matrix B.\n    /// \\param [in] b_type Data type of the matrix B.\n    /// \\param [in] ldb Leading dimension of B.\n    /// \\param [in] stride_b Stride between the different B matrices.\n    /// \\param [in] beta Scaling factor for matrix C.\n    /// \\param [in, out] c Input/Output matrix C.\n    /// \\param [in] c_type Data type of the matrix C.\n    /// \\param [in] ldc Leading dimension of C.\n    /// \\param [in] stride_c Stride between 
the different C matrices.\n    /// \\param [in] batch_size Specifies the number of matrix multiply operations to perform.\n    /// \\param [in] scaling_type Data type of the scaling factors.\n    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,\n                           int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,\n                           long long int stride_a, const void * b, library_data_t b_type, int ldb,\n                           long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,\n                           long long int stride_c, int batch_size, library_data_t scaling_type) {\n        if (scaling_type == library_data_t::real_float &&\n            c_type == library_data_t::complex_float)\n        {\n            scaling_type = library_data_t::complex_float;\n        }\n        else if (scaling_type == library_data_t::real_double &&\n                 c_type == library_data_t::complex_double)\n        {\n            scaling_type = library_data_t::complex_double;\n        }\n\n        std::uint64_t key =\n            detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);\n        switch (key)\n        {\n        case detail::get_type_combination_id(\n            library_data_t::real_float, library_data_t::real_float,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<float, float, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,\n                beta, c, ldc, stride_c, batch_size);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_double, library_data_t::real_double,\n            library_data_t::real_double, library_data_t::real_double):\n        {\n            detail::gemm_batch_impl<double, double, double, double>(\n            
    q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,\n                beta, c, ldc, stride_c, batch_size);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::complex_float, library_data_t::complex_float,\n            library_data_t::complex_float, library_data_t::complex_float):\n        {\n            detail::gemm_batch_impl<std::complex<float>, std::complex<float>,\n                                    std::complex<float>, std::complex<float>>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,\n                beta, c, ldc, stride_c, batch_size);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::complex_double, library_data_t::complex_double,\n            library_data_t::complex_double, library_data_t::complex_double):\n        {\n            detail::gemm_batch_impl<std::complex<double>, std::complex<double>,\n                                    std::complex<double>, std::complex<double>>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,\n                beta, c, ldc, stride_c, batch_size);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_half, library_data_t::real_half):\n        {\n            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,\n                                    sycl::half>(q, a_trans, b_trans, m, n, k, alpha,\n                                                a, lda, stride_a, b, ldb, stride_b,\n                                                beta, c, ldc, stride_c, batch_size);\n            break;\n        }\n#ifdef __INTEL_MKL__\n        case detail::get_type_combination_id(\n            library_data_t::real_bfloat16, library_data_t::real_bfloat16,\n            library_data_t::real_bfloat16, 
library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,\n                batch_size);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_bfloat16, library_data_t::real_bfloat16,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,\n                batch_size);\n            break;\n        }\n#endif\n        case detail::get_type_combination_id(\n            library_data_t::real_int8, library_data_t::real_int8,\n            library_data_t::real_int32, library_data_t::real_int32):\n        {\n            detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,\n                                    std::int32_t>(q, a_trans, b_trans, m, n, k, alpha,\n                                                  a, lda, stride_a, b, ldb, stride_b,\n                                                  beta, c, ldc, stride_c, batch_size);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_int8, library_data_t::real_int8,\n            library_data_t::real_float, library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,\n                beta, c, ldc, stride_c, batch_size);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_float, 
library_data_t::real_float):\n        {\n            detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(\n                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,\n                beta, c, ldc, stride_c, batch_size);\n            break;\n        }\n        case detail::get_type_combination_id(\n            library_data_t::real_half, library_data_t::real_half,\n            library_data_t::real_half, library_data_t::real_float):\n        {\n            float alpha_value =\n                dpct::get_value(reinterpret_cast<const float *>(alpha), q);\n            float beta_value =\n                dpct::get_value(reinterpret_cast<const float *>(beta), q);\n            sycl::half alpha_half(alpha_value);\n            sycl::half beta_half(beta_value);\n            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(\n                q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b,\n                &beta_half, c, ldc, stride_c, batch_size);\n            break;\n        }\n        default:\n            throw std::runtime_error(\"the combination of data type is unsupported\");\n        }\n    }\n\n    static inline void\n    async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr,\n                      size_t from_pitch, size_t x, size_t y,\n                      memcpy_direction direction = automatic,\n                      sycl::queue &q = get_default_queue())\n    {\n        detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y,\n                            direction);\n    }\n\n    using err0 = detail::generic_error_type<struct err0_tag, int>;\n    using err1 = detail::generic_error_type<struct err1_tag, int>;\n\n    static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) {\n        detail::dpct_free(ptr, q);\n    }\n\n    /// dpct accessor used as device function parameter.\n    template <class T, memory_region Memory, size_t 
Dimension> class accessor;\n    template <class T, memory_region Memory> class accessor<T, Memory, 3> {\n    public:\n        using memory_t = detail::memory_traits<Memory, T>;\n        using element_t = typename memory_t::element_t;\n        using pointer_t = typename memory_t::pointer_t;\n        using accessor_t = typename memory_t::template accessor_t<3>;\n        accessor(pointer_t data, const sycl::range<3> &in_range)\n            : _data(data), _range(in_range) {}\n        template <memory_region M = Memory>\n        accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)\n            : accessor(acc, acc.get_range()) {}\n        accessor(const accessor_t &acc, const sycl::range<3> &in_range)\n            : accessor(acc.get_pointer(), in_range) {}\n        accessor<T, Memory, 2> operator[](size_t index) const {\n            sycl::range<2> sub(_range.get(1), _range.get(2));\n            return accessor<T, Memory, 2>(_data + index * sub.size(), sub);\n        }\n\n        pointer_t get_ptr() const { return _data; }\n\n    private:\n        pointer_t _data;\n        sycl::range<3> _range;\n    };\n    template <class T, memory_region Memory> class accessor<T, Memory, 2> {\n    public:\n        using memory_t = detail::memory_traits<Memory, T>;\n        using element_t = typename memory_t::element_t;\n        using pointer_t = typename memory_t::pointer_t;\n        using accessor_t = typename memory_t::template accessor_t<2>;\n        accessor(pointer_t data, const sycl::range<2> &in_range)\n            : _data(data), _range(in_range) {}\n        template <memory_region M = Memory>\n        accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)\n            : accessor(acc, acc.get_range()) {}\n        accessor(const accessor_t &acc, const sycl::range<2> &in_range)\n            : accessor(acc.get_pointer(), in_range) {}\n\n        pointer_t operator[](size_t index) const {\n            return _data + _range.get(1) * index;\n  
      }\n\n        pointer_t get_ptr() const { return _data; }\n\n    private:\n        pointer_t _data;\n        sycl::range<2> _range;\n    };\n\n    namespace detail {\n        /// Device variable with address space of shared, global or constant.\n        template <class T, memory_region Memory, size_t Dimension> class device_memory {\n        public:\n            using accessor_t =\n                typename detail::memory_traits<Memory,\n                                            T>::template accessor_t<Dimension>;\n            using value_t = typename detail::memory_traits<Memory, T>::value_t;\n            using dpct_accessor_t = dpct::accessor<T, Memory, Dimension>;\n\n            device_memory() : device_memory(sycl::range<Dimension>(1)) {}\n\n            /// Constructor of 1-D array with initializer list\n            device_memory(const sycl::range<Dimension> &in_range,\n                        std::initializer_list<value_t> &&init_list)\n                : device_memory(in_range) {\n                assert(init_list.size() <= in_range.size());\n                _host_ptr = (value_t *)std::malloc(_size);\n                std::memset(_host_ptr, 0, _size);\n                std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T));\n            }\n\n            /// Constructor of 2-D array with initializer list\n            template <size_t D = Dimension>\n            device_memory(\n                const typename std::enable_if<D == 2, sycl::range<2>>::type &in_range,\n                std::initializer_list<std::initializer_list<value_t>> &&init_list)\n                : device_memory(in_range) {\n                assert(init_list.size() <= in_range[0]);\n                _host_ptr = (value_t *)std::malloc(_size);\n                std::memset(_host_ptr, 0, _size);\n                auto tmp_data = _host_ptr;\n                for (auto sub_list : init_list) {\n                    assert(sub_list.size() <= in_range[1]);\n                    
std::memcpy(tmp_data, sub_list.begin(),\n                                sub_list.size() * sizeof(T));\n                    tmp_data += in_range[1];\n                }\n            }\n\n            /// Constructor with range\n            device_memory(const sycl::range<Dimension> &range_in)\n                : _size(range_in.size() * sizeof(T)), _range(range_in),\n                _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) {\n                static_assert(\n                    (Memory == global) || (Memory == constant) || (Memory == shared),\n                    \"device memory region should be global, constant or shared\");\n                // Make sure that singleton class mem_mgr and dev_mgr will destruct\n                // later than this.\n                detail::mem_mgr::instance();\n                dev_mgr::instance();\n            }\n\n            /// Constructor with range\n            template <class... Args>\n            device_memory(Args... Arguments)\n                : device_memory(sycl::range<Dimension>(Arguments...)) {}\n\n            ~device_memory() {\n                if (_device_ptr && !_reference)\n                    dpct::dpct_free(_device_ptr);\n                if (_host_ptr)\n                    std::free(_host_ptr);\n            }\n\n            /// Allocate memory with default queue, and init memory if has initial\n            /// value.\n            void init() { init(dpct::get_default_queue()); }\n            /// Allocate memory with specified queue, and init memory if has initial\n            /// value.\n            void init(sycl::queue &q) {\n                if (_device_ptr)\n                    return;\n                if (!_size)\n                    return;\n                allocate_device(q);\n                if (_host_ptr)\n                    detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size,\n                                        host_to_device);\n            }\n\n            /// The variable is assigned 
to a device pointer.\n            void assign(value_t *src, size_t size) {\n                this->~device_memory();\n                new (this) device_memory(src, size);\n            }\n\n            /// Get memory pointer of the memory object, which is virtual pointer when\n            /// usm is not used, and device pointer when usm is used.\n            value_t *get_ptr() { return get_ptr(get_default_queue()); }\n            /// Get memory pointer of the memory object, which is virtual pointer when\n            /// usm is not used, and device pointer when usm is used.\n            value_t *get_ptr(sycl::queue &q) {\n                init(q);\n                return _device_ptr;\n            }\n\n            /// Get the device memory object size in bytes.\n            size_t get_size() { return _size; }\n\n            template <size_t D = Dimension>\n            typename std::enable_if<D == 1, T>::type &operator[](size_t index) {\n                init();\n                return _device_ptr[index];\n            }\n\n            /// Get dpct::accessor with dimension info for the device memory object\n            /// when usm is used and dimension is greater than 1.\n            template <size_t D = Dimension>\n            typename std::enable_if<D != 1, dpct_accessor_t>::type\n            get_access([[maybe_unused]] sycl::handler &cgh) {\n                return dpct_accessor_t((T *)_device_ptr, _range);\n            }\n\n        private:\n            device_memory(value_t *memory_ptr, size_t size)\n                : _size(size), _range(size / sizeof(T)), _reference(true),\n                _device_ptr(memory_ptr) {}\n\n            void allocate_device(sycl::queue &q) {\n        #ifndef DPCT_USM_LEVEL_NONE\n                if (Memory == shared) {\n                    _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(),\n                                                                q.get_context());\n                    return;\n                }\n     
   #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY\n                if (Memory == constant) {\n                    _device_ptr = (value_t *)sycl::malloc_device(\n                        _size, q.get_device(), q.get_context(),\n                        sycl::ext::oneapi::property::usm::device_read_only());\n                    return;\n                }\n        #endif\n        #endif\n                _device_ptr = (value_t *)detail::dpct_malloc(_size, q);\n            }\n\n            size_t _size;\n            sycl::range<Dimension> _range;\n            bool _reference;\n            value_t *_host_ptr;\n            value_t *_device_ptr;\n        };\n        template <class T, memory_region Memory>\n        class device_memory<T, Memory, 0> : public device_memory<T, Memory, 1> {\n        public:\n            using base = device_memory<T, Memory, 1>;\n            using value_t = typename base::value_t;\n            using accessor_t =\n                typename detail::memory_traits<Memory, T>::template accessor_t<0>;\n\n            /// Constructor with initial value.\n            device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {}\n\n            /// Default constructor\n            device_memory() : base(1) {}\n        };\n        } // namespace detail\n\n    template <class T, size_t Dimension>\n    using global_memory = detail::device_memory<T, global, Dimension>;\n    template <class T, size_t Dimension>\n    using constant_memory = detail::device_memory<T, constant, Dimension>;\n    template <class T, size_t Dimension>\n    using shared_memory = detail::device_memory<T, shared, Dimension>;\n\n\n    template <typename T,\n            sycl::access::address_space addressSpace =\n                sycl::access::address_space::global_space,\n            sycl::memory_order memoryOrder = sycl::memory_order::relaxed,\n            sycl::memory_scope memoryScope = sycl::memory_scope::device>\n    inline T atomic_fetch_add(T *addr, T operand) {\n    auto atm =\n  
      sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);\n    return atm.fetch_add(operand);\n    }\n\n    template <sycl::access::address_space addressSpace =\n                sycl::access::address_space::global_space,\n            sycl::memory_order memoryOrder = sycl::memory_order::relaxed,\n            sycl::memory_scope memoryScope = sycl::memory_scope::device,\n            typename T1, typename T2>\n    inline T1 atomic_fetch_add(T1 *addr, T2 operand) {\n    auto atm =\n        sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);\n    return atm.fetch_add(operand);\n    }\n\n    template <typename T, sycl::access::address_space addressSpace =\n                            sycl::access::address_space::global_space>\n    inline T atomic_fetch_add(T *addr, T operand,\n                            sycl::memory_order memoryOrder) {\n    switch (memoryOrder) {\n        case sycl::memory_order::relaxed:\n            return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,\n                                    sycl::memory_scope::device>(addr, operand);\n        case sycl::memory_order::acq_rel:\n            return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,\n                                    sycl::memory_scope::device>(addr, operand);\n        case sycl::memory_order::seq_cst:\n            return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,\n                                    sycl::memory_scope::device>(addr, operand);\n        default:\n            assert(false && \"Invalid memory_order for atomics. 
Valid memory_order for \"\n                            \"atomics are: sycl::memory_order::relaxed, \"\n                            \"sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!\");\n        }\n    }\n\n    template <sycl::access::address_space addressSpace =\n                sycl::access::address_space::global_space,\n            typename T1, typename T2>\n    inline T1 atomic_fetch_add(T1 *addr, T2 operand,\n                            sycl::memory_order memoryOrder) {\n    atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);\n    }\n\n    inline unsigned int byte_level_permute(\n        unsigned int a, unsigned int b, unsigned int s) {\n      unsigned int ret;\n      ret = ((((std::uint64_t)b << 32 | a) >> (s & 0x7) * 8) & 0xff) |\n            (((((std::uint64_t)b << 32 | a) >> ((s >> 4) & 0x7) * 8) & 0xff)\n             << 8) |\n            (((((std::uint64_t)b << 32 | a) >> ((s >> 8) & 0x7) * 8) & 0xff)\n             << 16) |\n            (((((std::uint64_t)b << 32 | a) >> ((s >> 12) & 0x7) * 8) & 0xff)\n             << 24);\n      return ret;\n    }\n\n    inline uint32_t byte_level_permute_custom(\n        uint32_t low32, uint32_t high32, uint32_t sel, int mode = 0) {\n      constexpr uint16_t lookup[6][4] = {\n          {0x3210, 0x4321, 0x5432, 0x6543},  // Forward 4-byte extract\n          {0x5670, 0x6701, 0x7012, 0x0123},  // Backward 4-byte extract\n          {0x0000, 0x1111, 0x2222, 0x3333},  // Replicate 8-bit values\n          {0x3210, 0x3211, 0x3222, 0x3333},  // Edge clamp left\n          {0x0000, 0x1110, 0x2210, 0x3210},  // Edge clamp right\n          {0x1010, 0x3232, 0x1010, 0x3232}   // Replicate 16-bit values\n      };\n\n      if (mode >= 1 && mode <= 6) {\n        return byte_level_permute(low32, high32, lookup[mode - 1][sel & 0x3]);\n      } else if (!mode) {\n        return byte_level_permute(low32, high32, sel);\n      }\n      return 0;\n    }\n\n    template <int n_nondefault_params, int n_default_params, typename 
T>\n    class args_selector;\n\n    /// args_selector is a helper class for extracting arguments from an\n    /// array of pointers to arguments or buffer of arguments to pass to a\n    /// kernel function.\n    ///\n    /// \\param R(Ts...) The type of the kernel\n    /// \\param n_nondefault_params The number of nondefault parameters of the\n    /// kernel (excluding parameters that like sycl::nd_item, etc.) \\param\n    /// n_default_params The number of default parameters of the kernel\n    ///\n    /// Example usage:\n    /// With the following kernel:\n    ///   void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float\n    ///   f=.1) {}\n    /// and with the declaration:\n    ///   args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);\n    /// we have:\n    ///   selector.get<0>() returns a reference to sycl::float*,\n    ///   selector.get<1>() returns a reference to int,\n    ///   selector.get<2>() returns a reference to float\n    template <int n_nondefault_params, int n_default_params, typename R,\n              typename... 
Ts>\n    class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {\n      private:\n        void **kernel_params;\n        char *args_buffer;\n\n        template <int i> static constexpr int account_for_default_params() {\n            constexpr int n_total_params = sizeof...(Ts);\n            if constexpr (i >= n_nondefault_params) {\n                return n_total_params - n_default_params +\n                       (i - n_nondefault_params);\n            } else {\n                return i;\n            }\n        }\n\n      public:\n        /// Get the type of the ith argument of R(Ts...)\n        /// \\param [in] i Index of parameter to get\n        /// \\returns Type of ith parameter\n        template <int i>\n        using arg_type = std::tuple_element_t<account_for_default_params<i>(),\n                                              std::tuple<Ts...>>;\n        static constexpr int params_num = sizeof...(Ts);\n\n      private:\n        template <int i> static constexpr int get_offset() {\n            if constexpr (i == 0) {\n                // we can assume args_buffer is properly aligned to the\n                // first argument\n                return 0;\n            } else {\n                constexpr int prev_off = get_offset<i - 1>();\n                constexpr int prev_past_end =\n                    prev_off + sizeof(arg_type<i - 1>);\n                using T = arg_type<i>;\n                // is the past-the-end of the i-1st element properly aligned\n                // with the ith element's alignment?\n                if constexpr (prev_past_end % alignof(T) == 0) {\n                    return prev_past_end;\n                }\n                // otherwise bump prev_past_end to match alignment\n                else {\n                    return prev_past_end +\n                           (alignof(T) - (prev_past_end % alignof(T)));\n                }\n            }\n        }\n\n        static char *get_args_buffer(void **extra) {\n        
    if (!extra)\n                return nullptr;\n            for (; (std::size_t)*extra != 0; ++extra) {\n                if ((std::size_t)*extra == 1) {\n                    return static_cast<char *>(*(extra + 1));\n                }\n            }\n            return nullptr;\n        }\n\n      public:\n        /// If kernel_params is nonnull, then args_selector will\n        /// extract arguments from kernel_params. Otherwise, it\n        /// will extract them from extra.\n        /// \\param [in] kernel_params Array of pointers to arguments\n        /// a or null pointer.\n        /// \\param [in] extra Array containing pointer to argument buffer.\n        args_selector(void **kernel_params, void **extra)\n            : kernel_params(kernel_params),\n              args_buffer(get_args_buffer(extra)) {}\n\n        /// Get a reference to the ith argument extracted from kernel_params\n        /// or extra.\n        /// \\param [in] i Index of argument to get\n        /// \\returns Reference to the ith argument\n        template <int i> arg_type<i> &get() {\n            if (kernel_params) {\n                return *static_cast<arg_type<i> *>(kernel_params[i]);\n            } else {\n                return *reinterpret_cast<arg_type<i> *>(args_buffer +\n                                                        get_offset<i>());\n            }\n        }\n    }; // COPY from DPCT head file\n       // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp\n\n    /// Utility class for launching SYCL kernels through kernel\n    /// function wrapper.\n    /// For example:\n    /// A SYCL kernel function:\n    ///   void kernel_func(int *ptr, sycl::nd_item<3> item);\n    /// Kernel function wrapper:\n    ///   void kernel_func_wrapper(int *ptr) {\n    ///     sycl::queue queue = *dpct::kernel_launcher::_que;\n    ///     unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;\n    ///     sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;\n    ///     
queue.parallel_for(\n    ///       nr,\n    ///       [=](sycl::nd_item<3> item_ct1) {\n    ///         kernel_func(ptr, item_ct1);\n    ///       });\n    ///   }\n    /// Then launch the kernel through wrapper like:\n    ///   typedef void(*fpt)(int *);\n    ///   fpt fp = kernel_func_wrapper;\n    ///   dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,\n    ///   device_ptr);\n    /// If the origin function type is erased, then need to register it first:\n    ///   void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();\n    ///   dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,\n    ///   0, 0);\n    class kernel_launcher {\n        template <typename FuncT, typename ArgSelector, std::size_t... Index>\n        static void launch_helper(FuncT &&func, ArgSelector &selector,\n                                  std::index_sequence<Index...>) {\n            func(selector.template get<Index>()...);\n        }\n        static void set_execution_config(dim3 group_range, dim3 local_range,\n                                         unsigned int local_mem_size,\n                                         queue_ptr que) {\n            if (que) {\n                _que = que;\n            } else {\n                _que = &get_default_queue();\n            }\n            _nr = sycl::nd_range<3>(\n                static_cast<sycl::range<3>>(group_range * local_range),\n                static_cast<sycl::range<3>>(local_range));\n            _local_mem_size = local_mem_size;\n\n\n        };\n        static inline std::mutex kernel_function_ptr_map_mutex;\n\n      public:\n        /// Variables for storing execution configuration.\n        static inline thread_local sycl::queue *_que = nullptr;\n        static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();\n        static inline thread_local unsigned int _local_mem_size = 0;\n        /// Map for retrieving launchable functor from a raw pointer.\n        static inline 
std::map<\n            const void *,\n            std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>\n            kernel_function_ptr_map = {};\n\n        /// Registers a kernel function pointer with a corresponding launchable\n        /// functor.\n        /// \\param [in] func Pointer to the kernel function.\n        /// \\param [in] launcher Functor to handle kernel invocation.\n        static void register_kernel_ptr(\n            const void *func,\n            std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>\n                launcher) {\n            std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);\n            kernel_function_ptr_map[func] = std::move(launcher);\n        }\n        /// Launches a kernel function with arguments provided directly through\n        /// kernel function wrapper.\n        /// \\tparam FuncT Type of the kernel function wrapper.\n        /// \\tparam ArgsT Types of kernel arguments.\n        /// \\param [in] func Pointer to the kernel function wrapper.\n        /// \\param [in] group_range SYCL group range.\n        /// \\param [in] local_range SYCL local range.\n        /// \\param [in] local_mem_size The size of local memory required by the\n        /// kernel function. \\param [in] que SYCL queue used to execute kernel.\n        /// \\param [in] args Kernel arguments.\n        template <typename FuncT, typename... ArgsT>\n        static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>\n        launch(FuncT *func, dim3 group_range, dim3 local_range,\n               unsigned int local_mem_size, queue_ptr que, ArgsT... args) {\n            set_execution_config(group_range, local_range, local_mem_size, que);\n            func(args...);\n        }\n        /// Launches a kernel function through registered kernel function\n        /// wrapper. \\param [in] func Pointer to the registered kernel function\n        /// wrapper. \\param [in] group_range SYCL group range. 
\\param [in]\n        /// local_range SYCL local range. \\param [in] args Array of pointers to\n        /// kernel arguments. \\param [in] local_mem_size The size of local\n        /// memory required by the kernel function. \\param [in] que SYCL queue\n        /// used to execute kernel.\n        static void launch(const void *func, dim3 group_range, dim3 local_range,\n                           void **args, unsigned int local_mem_size,\n                           queue_ptr que) {\n            std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);\n            auto Iter = kernel_function_ptr_map.find(func);\n            if (Iter == kernel_function_ptr_map.end()) {\n                throw std::runtime_error(\"dpct::launch() : no registered \"\n                                         \"kernel function wrapper found.\");\n            }\n            (Iter->second)(group_range, local_range, args, local_mem_size, que);\n        }\n        /// Launches a kernel function with packed arguments through kernel\n        /// function wrapper.\n        /// \\tparam FuncT Type of the kernel function wrapper.\n        /// \\param [in] func Pointer to the kernel function wrapper.\n        /// \\param [in] group_range SYCL group range.\n        /// \\param [in] local_range SYCL local range.\n        /// \\param [in] args Array of pointers to kernel arguments.\n        /// \\param [in] local_mem_size The size of local memory required by the\n        /// kernel function. 
\\param [in] que SYCL queue used to execute kernel.\n        template <typename FuncT>\n        static std::enable_if_t<std::is_function_v<FuncT>, void>\n        launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,\n               unsigned int local_mem_size, queue_ptr que) {\n            constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;\n            set_execution_config(group_range, local_range, local_mem_size, que);\n            args_selector<p_num, p_num, FuncT> selector(args, nullptr);\n            launch_helper(func, selector, std::make_index_sequence<p_num>{});\n        }\n    }; // COPY from DPCT head file\n       // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp\n\n    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp\n    template <typename T>\n    T select_from_sub_group(\n        sycl::sub_group g,\n        T x,\n        int remote_local_id,\n        int logical_sub_group_size = 32) {\n      unsigned int start_index = g.get_local_linear_id() /\n                                 logical_sub_group_size *\n                                 logical_sub_group_size;\n      return sycl::select_from_group(\n          g, x, start_index + remote_local_id % logical_sub_group_size);\n    }\n\n    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp\n    template <typename T>\n    void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {\n      auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();\n      int lane = sg.get_local_linear_id();\n\n      int lane_group8_row = lane / 8;\n      int lane_group8_col = lane % 8;\n\n      if (!trans) {\n        // calculate the source lane\n        int src_lane = 2 * lane_group8_row;\n        if (lane_group8_col >= 4)\n          src_lane += 1;\n\n        // Broadcast the address from the source lane\n        auto recv_addr_uintp =\n            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);\n\n        // Cast the received address from 
uintptr_t to the type of 'm'\n        auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);\n\n        // Non-transposed load\n        *m = recv_addr[lane_group8_col % 4];\n      } else {\n        // calculate the source lane\n        int src_lane = (lane % 4) * 2;\n\n        // Broadcast the address from the source lane\n        auto recv_addr_uintp_1 =\n            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);\n        auto recv_addr_uintp_2 =\n            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);\n\n        // Cast the received address from uintptr_t to 'half *'\n        auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);\n        auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);\n\n        // Transposed load\n        int index = lane / 4;\n        sycl::half val0 = recv_addr_1[index];\n        sycl::half val1 = recv_addr_2[index];\n\n        // Combine the two 16-bits into one 32-bit value\n        sycl::half2 val = sycl::half2(val0, val1);\n        *m = *reinterpret_cast<T*>(&val);\n      }\n    }\n\n    template <typename T>\n    void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {\n      // Load 1st matrix\n      ldmatrix(addr, m1, trans, 0);\n      // Load 2nd matrix\n      ldmatrix(addr, m2, trans, 1);\n    }\n\n    template <typename T>\n    void ldmatrix(\n        uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {\n      // Load 1st matrix\n      ldmatrix(addr, m1, trans, 0);\n      // Load 2nd matrix\n      ldmatrix(addr, m2, trans, 1);\n      // Load 3rd matrix\n      ldmatrix(addr, m3, trans, 2);\n      // Load 4th matrix\n      ldmatrix(addr, m4, trans, 3);\n    }\n\n    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp\n\n    /// A helper struct that defines the pack type for the input matrix\n    /// fragments\n    /// of mma() function based on the type of input matrix fragments.\n    /// The MMAType struct is specialized for different 
types of input matrices.\n    /// Currently, the specialization for f16, bf16 and s8 types is defined\n    /// below. \\tparam [in] T The type of the input matrix fragments\n    template <typename T>\n    struct MMAType {\n      using PackType = uint32_t;\n    };\n\n    /// Each work item of a sub-group (limited to size 32) calling this function\n    /// calculates a subset fragment for the output matrix D using MAD operation\n    /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &\n    /// types:\n    /// - m8n8k4 (f32.f16.f16.f32)\n    /// - m8n8k16 (s32.s8.s8.s32)\n    /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)\n    /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)\n    /// - m16n8k32 (s32.s8.s8.s32)\n    /// Here, m, n & k define the shapes of A, B & C matrices respectively\n    /// (A = [m x k], B = [k x n], C = [m x n]).\n    /// \\tparam [in] M The rows of A, C & D matrices\n    /// \\tparam [in] N The columns of B, C, D matrices\n    /// \\tparam [in] K The columns & rows of A & B matrices respectively\n    /// \\tparam [in] ABType The type of the input matrix (A & B) fragment\n    /// \\tparam [in] CDType The type of the output matrix (C & D) fragment\n    /// \\param [out] d_mat_frag The fragment of the output matrix D to store the\n    /// result of A * B + C\n    /// \\param [in] a_mat_frag The fragment of the input matrix A to be\n    /// multiplied with B matrix fragment \\param [in] b_mat_frag The fragment of\n    /// the input matrix B to be multiplied with A matrix fragment \\param [in]\n    /// c_mat_frag The fragment of the input matrix C to be added with the\n    /// result of A * B fragments\n    template <int M, int N, int K, typename ABType, typename CDType>\n    void mma(\n        volatile void** d_mat_frag,\n        void* a_mat_frag,\n        void* b_mat_frag,\n        void* c_mat_frag) {\n      auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);\n      auto a =\n          reinterpret_cast<typename 
MMAType<ABType>::PackType*>(a_mat_frag);\n      auto b =\n          reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);\n      auto c = reinterpret_cast<CDType*>(c_mat_frag);\n\n      auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();\n      int lane = sg.get_local_linear_id();\n\n      static_assert(\n          (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||\n              (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||\n              (M == 16 && N == 8 && K == 32),\n          \"Unsupported MMA shape!\");\n\n      short row_load_offset = 4 * (lane >> 2);\n      short col_load_offset = 8 * (lane % 4);\n\n      if constexpr (M == 8 && N == 8 && K == 4) {\n        if constexpr (std::is_floating_point_v<CDType>) {\n          col_load_offset = row_load_offset % 16;\n\n          // Init D matrix with fragments of C matrix\n          *d[0] = c[0];\n          *d[1] = c[1];\n          *d[2] = c[2];\n          *d[3] = c[3];\n          *d[4] = c[4];\n          *d[5] = c[5];\n          *d[6] = c[6];\n          *d[7] = c[7];\n\n          // Calculate the row and col offset indices to iterate through the row\n          // & col fragments of A & B matrices\n          int r_ind = (lane % 2) ? 1 : 0;\n          int c_ind = ((lane % 4) / 2) ? 
2 : 0;\n\n          // Each sub-group is responsible for computing a fragment size of 8*8\n          // elements of matrix D for each of 4 MMA computations.\n          // Each work item computes 8 elements of matrix D by gathering\n          // their corresponding col & row matrix fragments of length k (4)\n          // from A & B matrices respectively using below mapping logic:\n          // row0 = (i % 4) if (lane < 16) else (i % 4) + 4\n          // col0 = (lane % 4)\n          // As each row & col fragment of A & B matrices is distributed across\n          // 4 work items, each iteration of below loop loads a partial fragment\n          // of matrix A (row) and matrix B (col) using the row & col offsets.\n          typename MMAType<ABType>::PackType recv_a[2], recv_b[2];\n\n          for (int i = 0; i < 4; i++) {\n            // Load partial fragment from col0 of matrix A ({a0, a1})\n            recv_a[0] =\n                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);\n            // Load partial fragment from col0 of matrix A ({a2, a3})\n            recv_a[1] =\n                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);\n\n            // Load partial fragment from row0 of matrix B ({b0, b1})\n            recv_b[0] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);\n            // Load partial fragment from row0 of matrix B ({b2, b3})\n            recv_b[1] =\n                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);\n\n            auto ra = reinterpret_cast<ABType*>(recv_a);\n            auto rb = reinterpret_cast<ABType*>(recv_b);\n\n            // Each work item calculates a partial product of A & B matrix\n            // fragments and adds it to the corresponding D matrix fragment (for\n            // even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{\n            // a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }\n            // * row0{ b1 } (for 
odd work item indices) d0 += col0{ a1 } * row0{\n            // b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }\n            // d3 += col1{ a3 } * row0{ b3 }\n            *d[0] +=\n                static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);\n            *d[1] += static_cast<float>(ra[r_ind]) *\n                     static_cast<float>(rb[c_ind + 1]);\n            *d[2] += static_cast<float>(ra[r_ind + 2]) *\n                     static_cast<float>(rb[c_ind]);\n            *d[3] += static_cast<float>(ra[r_ind + 2]) *\n                     static_cast<float>(rb[c_ind + 1]);\n\n            // Load partial fragment from row1 of matrix B ({b0, b1})\n            recv_b[0] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);\n            // Load partial fragment from row1 of matrix B ({b2, b3})\n            recv_b[1] =\n                dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);\n\n            // (for even work item indices)\n            // d0 += col0{ a0 } * row1{ b0 }\n            // d1 += col0{ a0 } * row1{ b1 }\n            // d2 += col1{ a2 } * row1{ b0 }\n            // d3 += col1{ a2 } * row1{ b1 }\n            // (for odd work item indices)\n            // d0 += col0{ a1 } * row1{ b2 }\n            // d1 += col0{ a1 } * row1{ b3 }\n            // d2 += col1{ a3 } * row1{ b2 }\n            // d3 += col1{ a3 } * row1{ b3 }\n            *d[4] +=\n                static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);\n            *d[5] += static_cast<float>(ra[r_ind]) *\n                     static_cast<float>(rb[c_ind + 1]);\n            *d[6] += static_cast<float>(ra[r_ind + 2]) *\n                     static_cast<float>(rb[c_ind]);\n            *d[7] += static_cast<float>(ra[r_ind + 2]) *\n                     static_cast<float>(rb[c_ind + 1]);\n          }\n        }\n      } else if constexpr (M == 8 && N == 8 && K == 16) {\n        if constexpr 
(std::is_integral_v<ABType>) {\n          // Init D matrix with fragments of C matrix\n          *d[0] = c[0];\n          *d[1] = c[1];\n\n          // Each sub-group is responsible for computing a fragment size of 16*8\n          // elements of matrix D.\n          // Each work item computes 2 elements of matrix D by gathering\n          // their corresponding row & col matrix fragments of length k (16)\n          // from A & B matrices respectively using below mapping logic:\n          // row0 = ((lane % 4) * 4) + i\n          // col0 = (lane >> 2)\n          // As each row & col fragment of A & B matrices is distributed across\n          // 4 work items, each iteration of below loop loads a partial fragment\n          // of matrix A (row) and matrix B (col) using the row & col offsets.\n          for (int i = 0; i < 4; i++) {\n            typename MMAType<ABType>::PackType recv_a, recv_b[2];\n\n            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})\n            recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);\n            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})\n            recv_b[0] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);\n            // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})\n            recv_b[1] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);\n\n            auto a = reinterpret_cast<ABType*>(&recv_a);\n            auto b = reinterpret_cast<ABType*>(recv_b);\n\n            // Each work item calculates a partial product of A & B matrix\n            // fragments and adds it to the corresponding D matrix fragment d0\n            // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{\n            // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2,\n            // a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } *\n            // col1{ b0, b1, b2, b3 }\n          
  for (int j = 0; j < 4; j++) {\n              *d[0] += a[j] * b[j];\n              *d[1] += a[j] * b[j + 4];\n            }\n          }\n        }\n      } else if constexpr (M == 16 && N == 8 && K == 8) {\n        if constexpr (std::is_floating_point_v<CDType>) {\n          // Init D matrix fragment with C matrix fragment\n          *d[0] = c[0];\n          *d[1] = c[1];\n          *d[2] = c[2];\n          *d[3] = c[3];\n\n          // Each sub-group is responsible for computing a fragment size of 16*8\n          // elements of matrix D.\n          // Each work item computes 4 elements of matrix D by gathering\n          // their corresponding row & col matrix fragments of length k (8)\n          // from A & B matrices respectively using below mapping logic:\n          // row0 = (lane >> 2) & row1 = (lane >> 2) + 8\n          // col0 = (lane % 4) * 2 + (i & 0x1)\n          // As each row & col fragment of A & B matrices is distributed across\n          // 4 work items, each iteration of below loop loads a partial fragment\n          // of matrix A (row) and matrix B (col) using the row & col offsets.\n          for (int i = 0; i < 4; i++) {\n            typename MMAType<ABType>::PackType recv_a[2], recv_b[2];\n\n            // Load partial fragment from row0 of matrix A ({a0, a1})\n            recv_a[0] =\n                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);\n            // Load partial fragment from row1 of matrix A ({a2, a3})\n            recv_a[1] =\n                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);\n            // Load partial fragment from col0 of matrix B ({b0, b1})\n            recv_b[0] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);\n            // Load partial fragment from col1 of matrix B ({b0, b1})\n            recv_b[1] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);\n\n            auto ra = reinterpret_cast<ABType*>(recv_a);\n         
   auto rb = reinterpret_cast<ABType*>(recv_b);\n\n            // Each work item calculates a partial product of A & B matrix\n            // fragments and adds it to the corresponding D matrix fragment d0\n            // += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{\n            // b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3\n            // } * col1{ b0, b1 }\n            for (int j = 0; j < 2; j++) {\n              *d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);\n              *d[1] +=\n                  static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);\n              *d[2] +=\n                  static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);\n              *d[3] +=\n                  static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);\n            }\n          }\n        }\n      } else if constexpr (M == 16 && N == 8 && K == 16) {\n        if constexpr (std::is_floating_point_v<CDType>) {\n          // Init D matrix fragment with C matrix fragment\n          *d[0] = c[0];\n          *d[1] = c[1];\n          *d[2] = c[2];\n          *d[3] = c[3];\n\n          // Each sub-group is responsible for computing a fragment size of 16*8\n          // elements of matrix D.\n          // Each work item computes 4 elements of matrix D by gathering\n          // their corresponding row & col matrix fragments of length k (8)\n          // from A & B matrices respectively using below mapping logic:\n          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8\n          // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1\n          // As each row & col fragment of A & B matrices is distributed across\n          // 4 work items, each iteration of below loop loads a partial fragment\n          // of matrix A (row) and matrix B (col) using the row & col offsets.\n          for (int i = 0; i < 4; i++) {\n            typename MMAType<ABType>::PackType recv_a[4], recv_b[4];\n\n            // Load 
partial fragment from row0 of matrix A ({a0, a1})\n            recv_a[0] =\n                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);\n            // Load partial fragment from row0 of matrix A ({a2, a3})\n            recv_a[1] =\n                dpct::select_from_sub_group(sg, a[2], row_load_offset + i);\n            // Load partial fragment from row1 of matrix A ({a0, a1})\n            recv_a[2] =\n                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);\n            // Load partial fragment from row1 of matrix A ({a2, a3})\n            recv_a[3] =\n                dpct::select_from_sub_group(sg, a[3], row_load_offset + i);\n\n            // Load partial fragment from col0 of matrix B ({b0, b1})\n            recv_b[0] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);\n            // Load partial fragment from col0 of matrix B ({b2, b3})\n            recv_b[1] =\n                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);\n            // Load partial fragment from col1 of matrix B ({b0, b1})\n            recv_b[2] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);\n            // Load partial fragment from col1 of matrix B ({b2, b3})\n            recv_b[3] =\n                dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);\n\n            auto ra = reinterpret_cast<ABType*>(recv_a);\n            auto rb = reinterpret_cast<ABType*>(recv_b);\n\n            // Each work item calculates a partial product of A & B matrix\n            // fragments and adds it to the corresponding D matrix fragment d0\n            // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{\n            // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2,\n            // a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *\n            // col1{ b0, b1, b2, b3 }\n            for (int j = 0; j < 4; j++) {\n              *d[0] += 
static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);\n              *d[1] +=\n                  static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);\n              *d[2] +=\n                  static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);\n              *d[3] += static_cast<CDType>(ra[j + 4]) *\n                       static_cast<CDType>(rb[j + 4]);\n            }\n          }\n        } else if constexpr (std::is_integral_v<ABType>) {\n          // Init D matrix with fragments of C matrix\n          *d[0] = c[0];\n          *d[1] = c[1];\n          *d[2] = c[2];\n          *d[3] = c[3];\n\n          // Each sub-group is responsible for computing a fragment size of 16*8\n          // elements of matrix D.\n          // Each work item computes 4 elements of matrix D by gathering\n          // their corresponding row & col matrix fragments of length k (8)\n          // from A & B matrices respectively using below mapping logic:\n          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8\n          // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1\n          // As each row & col fragment of A & B matrices is distributed across\n          // 4 work items, each iteration of below loop loads a partial fragment\n          // of matrix A (row) and matrix B (col) using the row & col offsets.\n          for (int i = 0; i < 4; i++) {\n            typename MMAType<ABType>::PackType recv_a[2], recv_b[2];\n\n            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})\n            recv_a[0] =\n                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);\n            // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})\n            recv_a[1] =\n                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);\n            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})\n            recv_b[0] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);\n         
   // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})\n            recv_b[1] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);\n\n            auto ra = reinterpret_cast<ABType*>(recv_a);\n            auto rb = reinterpret_cast<ABType*>(recv_b);\n\n            // Each work item calculates a partial product of A & B matrix\n            // fragments and adds it to the corresponding D matrix fragment d0\n            // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{\n            // a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6,\n            // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *\n            // col1{ b4, b5, b6, b7 }\n            for (int i = 0; i < 4; i++) {\n              *d[0] += ra[i] * rb[i];\n              *d[1] += ra[i] * rb[i + 4];\n              *d[2] += ra[i + 4] * rb[i];\n              *d[3] += ra[i + 4] * rb[i + 4];\n            }\n          }\n        }\n      } else if constexpr (M == 16 && N == 8 && K == 32) {\n        if constexpr (std::is_integral_v<ABType>) {\n          // Init D matrix with fragments of C matrix\n          *d[0] = c[0];\n          *d[1] = c[1];\n          *d[2] = c[2];\n          *d[3] = c[3];\n\n          // Each sub-group is responsible for computing a fragment size of 16*8\n          // elements of matrix D.\n          // Each work item computes 4 elements of matrix D by gathering\n          // their corresponding row & col matrix fragments of length k (32)\n          // from A & B matrices respectively using below mapping logic:\n          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8\n          // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i\n          // & 0x3) As each row & col fragment of A & B matrices is distributed\n          // across 4 work items, each iteration of below loop loads a partial\n          // fragment of matrix A (row) and matrix B (col) using the row & col\n          // 
offsets.\n          for (int i = 0; i < 4; i++) {\n            typename MMAType<ABType>::PackType recv_a[2], recv_b[2];\n\n            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})\n            recv_a[0] =\n                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);\n            // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})\n            recv_a[1] =\n                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);\n            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})\n            recv_b[0] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);\n            // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})\n            recv_b[1] =\n                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);\n\n            auto a = reinterpret_cast<ABType*>(recv_a);\n            auto b = reinterpret_cast<ABType*>(recv_b);\n\n            // Each work item calculates a partial product of A & B matrix\n            // fragments and adds it to the corresponding D matrix fragment d0\n            // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{\n            // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6,\n            // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *\n            // col1{ b0, b1, b2, b3 }\n            for (int j = 0; j < 4; j++) {\n              *d[0] += a[j] * b[j];\n              *d[1] += a[j] * b[j + 4];\n              *d[2] += a[j + 4] * b[j];\n              *d[3] += a[j + 4] * b[j + 4];\n            }\n          }\n\n          for (int i = 0; i < 4; i++) {\n            typename MMAType<ABType>::PackType recv_a[2], recv_b[2];\n\n            // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})\n            recv_a[0] =\n                dpct::select_from_sub_group(sg, a[2], row_load_offset + i);\n            // Load partial fragment from row1 of matrix A 
({a12, a13, a14,\n            // a15})\n            recv_a[1] =\n                dpct::select_from_sub_group(sg, a[3], row_load_offset + i);\n            // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})\n            recv_b[0] =\n                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);\n            // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})\n            recv_b[1] =\n                dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);\n\n            auto a = reinterpret_cast<ABType*>(recv_a);\n            auto b = reinterpret_cast<ABType*>(recv_b);\n\n            // Each work item calculates a partial product of A & B matrix\n            // fragments and adds it to the corresponding D matrix fragment d0\n            // += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{\n            // a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13,\n            // a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14,\n            // a15 } * col1{ b4, b5, b6, b7 }\n            for (int j = 0; j < 4; j++) {\n              *d[0] += a[j] * b[j];\n              *d[1] += a[j] * b[j + 4];\n              *d[2] += a[j + 4] * b[j];\n              *d[3] += a[j + 4] * b[j + 4];\n            }\n          }\n        }\n      }\n    }\n} // COPY from DPCT head files\n\n#endif // GGML_SYCL_DPCT_HELPER_HPP\n"
  },
  {
    "path": "src/ggml-sycl/element_wise.cpp",
    "content": "#include \"common.hpp\"\n#include \"ggml-sycl/presets.hpp\"\n#include \"ggml.h\"\n#include \"element_wise.hpp\"\n\n#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \\\n    for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))\n\n#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \\\n    (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))\n\nstatic void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,\n        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,\n        const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2);\n\n    if (i >= ne) {\n        return;\n    }\n\n    int64_t src1_idx = i - offset;\n\n    int64_t tmp = src1_idx;\n    const int64_t i13 = tmp / s13;\n    tmp -= i13 * s13;\n    const int64_t i12 = tmp / s12;\n    tmp -= i12 * s12;\n    const int64_t i11 = tmp / s11;\n    tmp -= i11 * s11;\n    const int64_t i10 = tmp;\n\n    float val = x[i];\n    if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {\n        val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];\n    }\n    dst[i] = val;\n}\n\n/* Unary OP funcs */\ntemplate<typename T>\nstatic __dpct_inline__ T op_sgn(T x) {\n    return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_abs(T x) {\n    return sycl::fabs(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_elu(T x) {\n    return (x > static_cast<T>(0.f)) ? 
x : sycl::expm1(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_gelu(T x) {\n    const T GELU_COEF_A    = static_cast<T>(0.044715f);\n    const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);\n    return static_cast<T>(0.5f) * x *\n           (static_cast<T>(1.0f) +\n            sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_silu(T x) {\n    return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_gelu_quick(T x) {\n    const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);\n    return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_gelu_erf(T x) {\n    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);\n    return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_tanh(T x) {\n    return sycl::tanh(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_relu(T x) {\n    return sycl::fmax(x, static_cast<T>(0));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_sigmoid(T x) {\n    return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_sqrt(T x) {\n    return sycl::sqrt(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_sin(T x) {\n    return sycl::sin(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_cos(T x) {\n    return sycl::cos(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_hardsigmoid(T x) {\n    return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_hardswish(T x) {\n    return x * sycl::fmin(static_cast<T>(1.0f), 
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_exp(T x) {\n    return sycl::exp(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_log(T x) {\n    if (x <= static_cast<T>(0)) {\n        return neg_infinity<T>();\n    }\n    return sycl::log(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_softplus(T x) {\n    const float xf = (float) x;\n    const float ax = sycl::fabs(xf);\n    const float m  = sycl::fmax(xf, 0.0f);\n    const float y  = m + sycl::log1p(sycl::exp(-ax));\n    return (T) y;\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_neg(T x) {\n    return -x;\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_step(T x) {\n    return (x > static_cast<T>(0.0f)) ? static_cast<T>(1.0f) : static_cast<T>(0.0f);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {\n    T neg_slope_T = static_cast<T>(negative_slope);\n    return sycl::fmax(x, static_cast<T>(0)) +\n           sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_sqr(T x) {\n    return x * x;\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {\n    return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? 
static_cast<T>(max_val) : x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_floor(T x) {\n    return sycl::floor(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_ceil(T x) {\n    return sycl::ceil(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_round(T x) {\n    return sycl::round(x);\n}\n\ntemplate<typename T>\nstatic __dpct_inline__ T op_trunc(T x) {\n    return sycl::trunc(x);\n}\n\ntemplate<typename T, typename F>\nstatic void unary_op_generic_kernel(\n        const T * x,\n        T * dst,\n        const int k,\n        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,\n        const size_t nb0,  const size_t nb1,  const size_t nb2,  const size_t nb3,\n        const size_t nbd0, const size_t nbd1, const size_t nbd2, const size_t nbd3,\n        const sycl::nd_item<1> & item_ct1,\n        F func) {\n\n        (void) ne3;\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        const int64_t i0 =  i % ne0;\n        const int64_t i1 = (i / ne0)        % ne1;\n        const int64_t i2 = (i / (ne0*ne1))  % ne2;\n        const int64_t i3 =  i / (ne0*ne1*ne2);\n\n        const char * src_base = (const char *) x;\n        char       * dst_base = (char *) dst;\n\n        const T * srcp = (const T *)(src_base + i0*nb0  + i1*nb1  + i2*nb2  + i3*nb3 );\n        T *       dstp = (T *)(dst_base + i0*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3);\n\n        *dstp = func(*srcp);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_sqrt_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_sqrt(x[i]);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_sin_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_sin(x[i]);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_cos(x[i]);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_log(x[i]);\n    }\n}\n\n\ntemplate<typename T>\nstatic void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_leaky_relu(x[i], negative_slope);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_sqr(x[i]);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_clamp(x[i], min_val, max_val);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_floor(x[i]);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_ceil(x[i]);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_round(x[i]);\n    }\n}\n\ntemplate<typename T>\nstatic void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = op_trunc(x[i]);\n    }\n}\n\ntemplate<typename  T>\nstatic void upscale(const T  *x, T *dst, const int nb00, const int nb01,\n                       
 const int nb02, const int nb03, const int ne10, const int ne11,\n                        const int ne12, const int ne13, const float sf0, const float sf1,\n                        const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {\n    int index = item_ct1.get_local_id(0) +\n               item_ct1.get_group(0) * item_ct1.get_local_range(0);\n    if (index >= ne10 * ne11 * ne12 * ne13) {\n        return;\n    }\n    // operation\n    int i10 = index % ne10;\n    int i11 = (index / ne10) % ne11;\n    int i12 = (index / (ne10 * ne11)) % ne12;\n    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;\n\n    int i00 = static_cast<int>(i10 / sf0);\n    int i01 = static_cast<int>(i11 / sf1);\n    int i02 = static_cast<int>(i12 / sf2);\n    int i03 = static_cast<int>(i13 / sf3);\n\n    dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);\n}\n\ntemplate<typename T>\nstatic void clamp(const T * x, T * dst, const float min, const float max, const int k,\n                      const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);\n    }\n}\n\ntemplate<typename T>\nstatic void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        const int64_t j0 = (i / n) * o0 + (i % n);\n        const int64_t j1 = o0 == o1 ? 
j0 : (i / n) * o1 + (i % n);\n        dst[i] = op_gelu(x[j0]) * g[j1];\n    }\n}\n\ntemplate<typename T>\nstatic void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        const int64_t j0 = (i / n) * o0 + (i % n);\n        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);\n        dst[i] = op_relu(x[j0]) * g[j1];\n    }\n}\n\ntemplate<typename T>\nstatic void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1)  {\n        const int64_t j0 = (i / n) * o0 + (i % n);\n        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);\n        dst[i] = op_silu(x[j0]) * g[j1];\n    }\n}\n\ntemplate<typename T>\nstatic void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        const int64_t j0 = (i / n) * o0 + (i % n);\n        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);\n        dst[i] = op_gelu_erf(x[j0]) * g[j1];\n    }\n}\n\ntemplate<typename T>\nstatic void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        const int64_t j0 = (i / n) * o0 + (i % n);\n        const int64_t j1 = o0 == o1 ? 
j0 : (i / n) * o1 + (i % n);\n        dst[i] = op_gelu_quick(x[j0]) * g[j1];\n    }\n}\n\nnamespace ggml_sycl_detail {\nstatic void acc_f32_sycl(const float *x, const float *y, float *dst,\n                         const int64_t n_elements, const int64_t ne10, const int64_t ne11,\n                         const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3,\n                         const int64_t offset, queue_ptr stream) {\n    const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),\n                                           sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);\n                         });\n}\n\ntemplate<typename T>\nstatic void arange_kernel(T * dst, const int k, T start, T step,\n                         const sycl::nd_item<1> &item_ct1) {\n    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {\n        dst[i] = start + static_cast<T>(i) * step;\n    }\n}\n\ntemplate<typename T>\nstatic void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,\n                             const int nb02, const int nb03, const int ne10, const int ne11,\n                             const int ne12, const int ne13, const float sf0, const float sf1,\n                             const float sf2, const float sf3, queue_ptr stream) {\n    int dst_size = ne10 * ne11 * ne12 * ne13;\n    int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);\n    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {\n            upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, 
sf1, sf2, sf3, item_ct1);\n        });\n}\n\ntemplate<typename KernelInvoker, typename... Args>\nstatic inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(dst->src[0]->type == dst->type);\n\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    switch (dst->type) {\n        case GGML_TYPE_F16:\n            {\n                auto data_pts = cast_data<sycl::half>(dst);\n                kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);\n                break;\n            }\n        case GGML_TYPE_F32:\n            {\n                auto data_pts = cast_data<float>(dst);\n                kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);\n                break;\n            }\n        default:\n            GGML_ABORT(\"GGML tensor type not supported!\\n\");\n    }\n}\n\ntemplate<typename KernelInvoker, typename... Args>\nstatic inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(dst->src[0]->type == dst->type);\n\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const int64_t nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2;;\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    const int32_t swapped = ((const int32_t *) dst->op_params)[1];\n    void * src0_d = src0->data;\n    void * src1_d = src1 ? src1->data : src0->data;\n    const int64_t src0_o = src0->nb[1];\n    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n    void * dst_d = dst->data;\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));\n        GGML_ASSERT(src1->ne[0] == nc);\n        GGML_ASSERT(src0->type == src1->type);\n    }\n    switch (dst->type) {\n        case GGML_TYPE_F16:\n            {\n                sycl::half * src0_p = (sycl::half *) src0_d;\n                sycl::half * src1_p = (sycl::half *) src1_d;\n\n                    if (!src1) {\n                        src0_p += swapped ? nc : 0;\n                        src1_p += swapped ? 0 : nc;\n                    }\n                kernel_invoker(src0_p,\n                               src1_p,\n                               (sycl::half *) dst_d,\n                               ggml_nelements(dst),\n                               nc,\n                               src0_o / sizeof(sycl::half),\n                               src1_o / sizeof(sycl::half),\n                               main_stream,\n                               std::forward<Args>(args)...);\n                break;\n            }\n        case GGML_TYPE_F32:\n            {\n                float * src0_p = (float *) src0_d;\n                float * src1_p = (float *) src1_d;\n\n                    if (!src1) {\n                        src0_p += swapped ? nc : 0;\n                        src1_p += swapped ? 
0 : nc;\n                    }\n\n                kernel_invoker(src0_p,\n                               src1_p,\n                               (float *) dst_d,\n                               ggml_nelements(dst),\n                               nc,\n                               src0_o / sizeof(float),\n                               src1_o / sizeof(float),\n                               main_stream,\n                               std::forward<Args>(args)...);\n                break;\n            }\n        default:\n            GGML_ABORT(\"GGML tensor type not supported!\\n\");\n    }\n}\n\ntemplate<typename KernelInvoker, typename... Args>\nstatic inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);\n\n    GGML_ASSERT(dst->src[0]->type == dst->type);\n\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n\n    const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];\n    const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];\n    const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];\n    const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];\n    switch (dst->type) {\n        case GGML_TYPE_F16:\n            {\n                auto data_pts = cast_data<sycl::half>(dst);\n                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],\n                               (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,\n                               main_stream, std::forward<Args>(args)...);\n                break;\n            }\n        case GGML_TYPE_F32:\n            {\n                auto data_pts = 
cast_data<float>(dst);\n                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],\n                               (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,\n                               main_stream, std::forward<Args>(args)...);\n                break;\n            }\n        default:\n            GGML_ABORT(\"GGML tensor type not supported!\\n\");\n    }\n}\n\ntemplate<typename F>\nstatic inline void ggml_sycl_op_unary(\n        ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {\n\n    ggml_tensor * src0 = dst->src[0];\n\n    const int64_t ne0  = dst->ne[0];\n    const int64_t ne1  = dst->ne[1];\n    const int64_t ne2  = dst->ne[2];\n    const int64_t ne3  = dst->ne[3];\n\n    const size_t  nb0  = src0->nb[0];\n    const size_t  nb1  = src0->nb[1];\n    const size_t  nb2  = src0->nb[2];\n    const size_t  nb3  = src0->nb[3];\n\n    const size_t  nbd0 = dst->nb[0];\n    const size_t  nbd1 = dst->nb[1];\n    const size_t  nbd2 = dst->nb[2];\n    const size_t  nbd3 = dst->nb[3];\n\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [=](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {\n\n            const int num_blocks = ceil_div(k_elements, 256);\n\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),\n                                  sycl::range<1>(256)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_generic_kernel(\n                        src, dst_ptr, k_elements,\n                        ne0, ne1, ne2, ne3,\n                        nb0, nb1, nb2, nb3,\n                        nbd0, nbd1, nbd2, nbd3,\n                        item_ct1,\n                        func\n                    );\n                });\n        });\n}\n\n\nstatic inline void 
ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n    float start, stop, step;\n    memcpy(&start, dst->op_params, sizeof(float));\n    memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));\n    memcpy(&step, (float *) dst->op_params + 2, sizeof(float));\n    dpct::queue_ptr stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    float * dst_ptr = (float *)dst->data;\n    const int k = (int)ggml_nelements(dst);\n    const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);\n    stream->parallel_for(\n        sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),\n                          sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),\n        [=](sycl::nd_item<1> item_ct1) {\n            arange_kernel(dst_ptr, k, start, step, item_ct1);\n        });\n}\n\n} // namespace ggml_sycl_detail\n\n\n\nstatic inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_sgn(x);\n    });\n}\n\n\nstatic inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_abs(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_elu(x);\n    });\n}\nstatic inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_silu(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_gelu(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, 
ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_gelu_quick(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_gelu_erf(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_tanh(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_relu(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_hardsigmoid(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_hardswish(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_exp(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {\n            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),\n                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_log_kernel(src, 
dst_ptr, k_elements, item_ct1);\n                });\n        });\n}\n\nstatic inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_softplus(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_neg(x);\n    });\n}\n\n\nstatic inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_step(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_sigmoid(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {\n            const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),\n                                  sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);\n                });\n        });\n}\n\nstatic inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {\n            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * 
sycl::range<1>(SYCL_SIN_BLOCK_SIZE),\n                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);\n                });\n        });\n}\n\nstatic inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {\n            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),\n                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);\n                });\n        });\n}\n\nstatic inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    float negative_slope;\n    memcpy(&negative_slope, dst->op_params, sizeof(float));\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {\n            const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),\n                                  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);\n                });\n        }, negative_slope);\n}\n\nstatic inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, 
auto* dst_ptr, int k_elements, queue_ptr stream) {\n            const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),\n                                  sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);\n                });\n        });\n}\n\nstatic inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,\n           int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,\n           queue_ptr stream) {\n            ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);\n        });\n}\n\nstatic inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    float min_val;\n    float max_val;\n    memcpy(&min_val, dst->op_params, sizeof(float));\n    memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {\n            const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),\n                                  sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);\n                });\n        }, min_val, max_val);\n}\n\nstatic inline void 
ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {\n            const int num_blocks = ceil_div(k_elements, 256);\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),\n                                  sycl::range<1>(256)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);\n                });\n        });\n}\n\nstatic inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {\n        return op_ceil(x);\n    });\n}\n\nstatic inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {\n            const int num_blocks = ceil_div(k_elements, 256);\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),\n                                  sycl::range<1>(256)),\n                [=](sycl::nd_item<1> item_ct1) {\n                    unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);\n                });\n        });\n}\n\nstatic inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,\n        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {\n            const int num_blocks = ceil_div(k_elements, 256);\n            stream->parallel_for(\n                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),\n                                  sycl::range<1>(256)),\n                [=](sycl::nd_item<1> item_ct1) {\n          
          unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);\n                });\n        });\n}\n\nstatic inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float       * dst_d  = (float       *)  dst->data;\n\n    dpct::queue_ptr stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(src1));\n    GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));\n    GGML_ASSERT(ggml_is_contiguously_allocated(dst));\n\n    const int64_t s1     = dst->op_params[0] / sizeof(float);\n    const int64_t s2     = dst->op_params[1] / sizeof(float);\n    const int64_t s3     = dst->op_params[2] / sizeof(float);\n    const int64_t offset = dst->op_params[3] / sizeof(float);\n\n    ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst),\n        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],\n        s1, s2, s3, offset, stream);\n}\n\nstatic inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,\n        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {\n            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);\n            main_stream->parallel_for(\n                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {\n                gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);\n            });\n        });\n}\n\nstatic inline void 
ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,\n        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {\n            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu\n            main_stream->parallel_for(\n                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {\n                gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);\n            });\n        });\n}\n\nstatic inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,\n        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {\n            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu\n            main_stream->parallel_for(\n                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {\n                gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);\n            });\n        });\n}\n\n__dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {\n    x = sycl::fmin(x, limit);\n    g = sycl::fmax(sycl::fmin(g, limit), -limit);\n\n    float out_glu = x / (1.0f + sycl::native::exp(-x * alpha));\n    out_glu = out_glu * (1.0f + g);\n    return out_glu;\n}\n\n\ntemplate <typename T>\nstatic void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,\n                              const int64_t n, const int64_t o0, const int64_t o1,\n        
                      float alpha, float limit, sycl::nd_item<3> item_ct1) {\n    const int64_t i = int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2);\n\n    if (i >= k) {\n        return;\n    }\n\n    const int64_t j0 = (i / n) * o0 + (i % n);\n    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);\n\n    float xi = x[j0];\n    float gi = g[j1];\n\n    dst[i] = ggml_sycl_op_swiglu_oai_single(xi, gi, alpha, limit);\n}\n\ntemplate <typename T>\nstatic void swiglu_oai_sycl(const T *       x,\n                            const T *       g,\n                            T *             dst,\n                            const int64_t   k,\n                            const int64_t   n,\n                            const int64_t   o0,\n                            const int64_t   o1,\n                            const float     alpha,\n                            const float     limit,\n                            dpct::queue_ptr stream) {\n    const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;\n    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),\n                                           sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);\n                         });\n}\n\nvoid ggml_sycl_op_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    void * src0_d = src0->data;\n    void * src1_d = src1 ? src1->data : src0->data;\n    const int64_t src0_o = src0->nb[1];\n    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];\n    void * dst_d = dst->data;\n    const int64_t nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2;\n    dpct::queue_ptr     stream = ctx.stream();\n\n    GGML_ASSERT(ggml_is_contiguous_1(src0));\n    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(src0->type == dst->type);\n    GGML_ASSERT(dst->ne[0] == nc);\n    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));\n\n    if (src1) {\n        GGML_ASSERT(ggml_is_contiguous_1(src1));\n        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));\n        GGML_ASSERT(src1->ne[0] == nc);\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];\n    const int32_t swapped = ggml_get_op_params_i32(dst, 1);\n    const float alpha = ggml_get_op_params_f32(dst, 2);\n    const float limit = ggml_get_op_params_f32(dst, 3);\n\n    float * src0_p = (float *) src0_d;\n    float * src1_p = (float *) src1_d;\n\n    if (!src1) {\n        src0_p += swapped ? nc : 0;\n        src1_p += swapped ? 
0 : nc;\n    }\n\n    swiglu_oai_sycl(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);\n}\n\nstatic inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,\n        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {\n            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);\n            main_stream->parallel_for(\n                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {\n                gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);\n            });\n        });\n}\n\nstatic inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,\n        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {\n            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);\n            main_stream->parallel_for(\n                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {\n                gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);\n            });\n        });\n}\n\n\nvoid ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_sqrt(ctx, dst);\n}\n\nvoid ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_sin(ctx, dst);\n}\n\nvoid 
ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_cos(ctx, dst);\n}\n\nvoid ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    ggml_sycl_op_acc(ctx, dst);\n}\n\nvoid ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_gelu(ctx, dst);\n}\n\nvoid ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_silu(ctx, dst);\n}\n\nvoid ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_gelu_quick(ctx, dst);\n}\n\nvoid ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_gelu_erf(ctx, dst);\n}\n\nvoid ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_tanh(ctx, dst);\n}\n\nvoid ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_relu(ctx, dst);\n}\n\nvoid ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_sigmoid(ctx, dst);\n}\n\nvoid ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_hardsigmoid(ctx, dst);\n}\n\nvoid ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_hardswish(ctx, dst);\n}\n\nvoid ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_exp(ctx, dst);\n}\n\nvoid ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_log(ctx, dst);\n}\n\nvoid ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_softplus(ctx, dst);\n}\n\nvoid ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_neg(ctx, dst);\n}\n\nvoid ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_step(ctx, dst);\n}\n\nvoid ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_leaky_relu(ctx, dst);\n}\n\nvoid ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_sqr(ctx, dst);\n}\n\nvoid ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_upscale(ctx, dst);\n}\n\n\nvoid ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_clamp(ctx, dst);\n}\n\nvoid ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_sgn(ctx, 
dst);\n}\n\nvoid ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_abs(ctx, dst);\n}\n\nvoid ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_elu(ctx, dst);\n}\n\nvoid ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_geglu(ctx, dst);\n}\n\nvoid ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_reglu(ctx, dst);\n}\n\nvoid ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_swiglu(ctx, dst);\n}\n\nvoid ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_swiglu_oai(ctx, dst);\n}\n\nvoid ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_geglu_erf(ctx, dst);\n}\n\nvoid ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_geglu_quick(ctx, dst);\n}\n\nvoid ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);\n    ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);\n}\n\nvoid ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_floor(ctx, dst);\n}\n\nvoid ggml_sycl_ceil(ggml_backend_sycl_context & ctx, 
ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_ceil(ctx, dst);\n}\n\nvoid ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_round(ctx, dst);\n}\n\nvoid ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_trunc(ctx, dst);\n}\n"
  },
  {
    "path": "src/ggml-sycl/element_wise.hpp",
    "content": "#ifndef GGML_SYCL_ELEMENTWISE_HPP\n#define GGML_SYCL_ELEMENTWISE_HPP\n\n#include \"common.hpp\"\n#include \"ggml.h\"\n#include <limits> // For std::numeric_limits\n\n#define SYCL_GLU_BLOCK_SIZE 256\n\ntemplate <typename T>\nT neg_infinity() {\n    return -std::numeric_limits<T>::infinity();\n}\n\ntemplate<typename T_Dst, typename T_Src = T_Dst>\nstruct typed_data {\n    const T_Src * src;\n    T_Dst * dst;\n};\n\ntemplate<typename T_Dst, typename T_Src = T_Dst>\ntyped_data<T_Dst, T_Src> cast_data(ggml_tensor * dst) {\n    return {\n        /* .src = */ static_cast<const T_Src *>(dst->src[0]->data),\n        /* .dst = */ static_cast<T_Dst *>(dst->data)\n    };\n}\n\nconst float GELU_QUICK_COEF = -1.702f;\n\n\nvoid ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid 
ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\nvoid ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\nvoid ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\nvoid ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\nvoid ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\nvoid ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\nvoid ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\nvoid ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\nvoid ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif // GGML_SYCL_ELEMENTWISE_HPP\n"
  },
  {
    "path": "src/ggml-sycl/fattn-common.hpp",
    "content": "#pragma once\n\n#include <sycl/sycl.hpp>\n#include \"dpct/helper.hpp\"\n#include \"common.hpp\"\n#include \"convert.hpp\"\n#include \"vecdotq.hpp\"\n\n#include \"ggml.h\"\n\n#include <cstdint>\n#include <cmath>\n#include <float.h>\n\n\n#define FATTN_KQ_STRIDE       256\n#define HALF_MAX_HALF         sycl::half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.\n#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.\n#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)\n\ntypedef void (*fattn_kernel_t)(\n    const char* Q,\n    const char* K,\n    const char* V,\n    const char* mask,\n    const char* sinks,\n    const int* KV_max,\n    float* dst,\n    sycl::float2* dst_meta,\n    const float scale,\n    const float max_bias,\n    const float m0,\n    const float m1,\n    const uint32_t n_head_log2,\n    const float logit_softcap,\n    const int32_t ne00,\n    const sycl::uint3 ne01,\n    const int32_t ne02,\n    const int32_t ne03,\n    const int32_t nb01,\n    const int32_t nb02,\n    const int32_t nb03,\n    const int32_t ne10,\n    const int32_t ne11,\n    const int32_t ne12,\n    const int32_t ne13,\n    const int32_t nb11,\n    const int32_t nb12,\n    const int64_t nb13,\n    const int32_t nb21,\n    const int32_t nb22,\n    const int64_t nb23,\n    const int32_t ne31,\n    const int32_t ne32,\n    const int32_t ne33,\n    const int32_t nb31,\n    const int32_t nb32,\n    const int64_t nb33);\n\ntypedef float (*vec_dot_KQ_t)(\n    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);\n\ntemplate <int D, int nthreads>\nstatic __dpct_inline__ float vec_dot_fattn_vec_KQ_f16(const char * __restrict__ K_c,\n                                                      const void * __restrict__ Q_v,\n                                                      const 
int * __restrict__ Q_q8,\n                                                      const void * __restrict__ Q_ds_v) {\n    const sycl::half2 * K_h2 = (const sycl::half2 *) K_c;\n    GGML_UNUSED(Q_q8);\n    GGML_UNUSED(Q_ds_v);\n\n    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {\n        sycl::half2 tmp[cpy_ne];\n        ggml_sycl_memcpy_1<sizeof(tmp)>(\n            tmp,\n            K_h2 + k_KQ_0 + (sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2) % nthreads) * cpy_ne);\n#pragma unroll\n        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {\n#ifdef GGML_SYCL_F16\n            ggml_sycl_mad(sum,                tmp[k_KQ_1] , ((const sycl::half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);\n#else\n            ggml_sycl_mad(sum, __half22float2(tmp[k_KQ_1]), ((const sycl::float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);\n#endif // GGML_SYCL_F16\n        }\n    }\n\n    return sum;\n}\n\ntemplate <int D, int nthreads, int warp_size>\nstatic __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_0(const char * __restrict__ K_c,\n                                                       const void * __restrict__ Q_v,\n                                                       const int * __restrict__ Q_q8,\n                                                       const void * __restrict__ Q_ds_v) {\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n\n    const block_q4_0 * K_q4_0   = (const block_q4_0 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ =\n            k_KQ_0 + (nthreads == warp_size ? 
item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);\n\n        const int ib    = k_KQ /  QI8_1;\n        const int iqs4  = k_KQ %  QI4_0;\n        const int shift = k_KQ & (QI8_1/2);\n\n        int v;\n        ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);\n        v = (v >> shift) & 0x0F0F0F0F;\n        const int u = Q_q8[k_KQ_0/nthreads];\n\n        const int sumi = ggml_sycl_dp4a(v, u, 0);\n\n        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];\n        sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x() - (8/QI8_1)*Q_ds.y());\n    }\n\n    return sum;\n}\n\ntemplate <int D, int nthreads , int warp_size>\nstatic __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_1(const char * __restrict__ K_c,\n                                                       const void * __restrict__ Q_v,\n                                                       const int * __restrict__ Q_q8,\n                                                       const void * __restrict__ Q_ds_v) {\n    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const block_q4_1 * K_q4_1   = (const block_q4_1 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ =\n            k_KQ_0 + (nthreads == warp_size ? 
item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);\n\n        const int ib    = k_KQ /  QI8_1;\n        const int iqs4  = k_KQ %  QI4_1;\n        const int shift = k_KQ & (QI8_1/2);\n\n        int v;\n        ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);\n        v = (v >> shift) & 0x0F0F0F0F;\n        const int u = Q_q8[k_KQ_0/nthreads];\n\n        const int sumi = ggml_sycl_dp4a(v, u, 0);\n\n        const sycl::float2 K_dm = (K_q4_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();\n        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];\n\n        sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;\n    }\n\n    return sum;\n}\n\ntemplate <int D, int nthreads, int warp_size>\nstatic __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_0(const char * __restrict__ K_c,\n                                                       const void * __restrict__ Q_v,\n                                                       const int * __restrict__ Q_q8,\n                                                       const void * __restrict__ Q_ds_v) {\n    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const block_q5_0 * K_q5_0   = (const block_q5_0 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ =\n            k_KQ_0 + (nthreads == warp_size ? 
item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);\n\n        const int ib    = k_KQ /  QI8_1;\n        const int iqs4  = k_KQ %  QI5_0;\n        const int iqs8  = k_KQ %  QI8_1;\n        const int shift = k_KQ & (QI8_1/2);\n\n        int v;\n        ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);\n        v = (v >> shift) & 0x0F0F0F0F;\n\n        {\n            int vh;\n            ggml_sycl_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);\n            vh >>= iqs8 * QI5_0;\n\n            v |= (vh <<  4) & 0x00000010; // 0 ->  4\n            v |= (vh << 11) & 0x00001000; // 1 -> 12\n            v |= (vh << 18) & 0x00100000; // 2 -> 20\n            v |= (vh << 25) & 0x10000000; // 3 -> 28\n        }\n\n        const int u = Q_q8[k_KQ_0/nthreads];\n\n        const int sumi = ggml_sycl_dp4a(v, u, 0);\n\n        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];\n\n        sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x() - (16/QI8_1)*Q_ds.y());\n    }\n\n    return sum;\n}\n\ntemplate <int D, int nthreads, int warp_size>\nstatic __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_1(const char * __restrict__ K_c,\n                                                       const void * __restrict__ Q_v,\n                                                       const int * __restrict__ Q_q8,\n                                                       const void * __restrict__ Q_ds_v) {\n    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const block_q5_1 * K_q5_1   = (const block_q5_1 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ =\n            k_KQ_0 + (nthreads == warp_size ? 
item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);\n\n        const int ib    = k_KQ /  QI8_1;\n        const int iqs4  = k_KQ %  QI5_1;\n        const int iqs8  = k_KQ %  QI8_1;\n        const int shift = k_KQ & (QI8_1/2);\n\n        int v;\n        ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);\n        v = (v >> shift) & 0x0F0F0F0F;\n\n        {\n            int vh;\n            ggml_sycl_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);\n            vh >>= iqs8 * QI5_0;\n\n            v |= (vh <<  4) & 0x00000010; // 0 ->  4\n            v |= (vh << 11) & 0x00001000; // 1 -> 12\n            v |= (vh << 18) & 0x00100000; // 2 -> 20\n            v |= (vh << 25) & 0x10000000; // 3 -> 28\n        }\n\n        const int u = Q_q8[k_KQ_0/nthreads];\n\n        const int sumi = ggml_sycl_dp4a(v, u, 0);\n\n        const sycl::float2 K_dm = (K_q5_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();\n        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];\n\n        sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;\n    }\n\n    return sum;\n}\n\ntemplate <int D, int nthreads, int warp_size>\nstatic __dpct_inline__ float vec_dot_fattn_vec_KQ_q8_0(const char * __restrict__ K_c,\n                                                       const void * __restrict__ Q_v,\n                                                       const int * __restrict__ Q_q8,\n                                                       const void * __restrict__ Q_ds_v) {\n    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const block_q8_0 * K_q8_0   = (const block_q8_0 *) K_c;\n    GGML_UNUSED(Q_v);\n\n    float sum = 0.0f;\n\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {\n        const int k_KQ =\n            k_KQ_0 + (nthreads == warp_size ? 
item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);\n\n        const int ib  = k_KQ / QI8_0;\n        const int iqs = k_KQ % QI8_0;\n\n        int v;\n        ggml_sycl_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);\n\n        const sycl::float2 * Q_ds = (const sycl::float2 *) Q_ds_v;\n        const float          Q_d  = Q_ds[k_KQ_0 / nthreads].x();\n\n        sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);\n    }\n\n    return sum;\n}\n\ntemplate <typename Tds, int ni, int warp_size>\nstatic __dpct_inline__ void quantize_q8_1_to_shared(const float * __restrict__ x,\n                                                    const float scale,\n                                                    int * __restrict__ yq32,\n                                                    void * __restrict__ yds) {\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n\n    float vals[sizeof(int)] = { 0.0f };\n#pragma unroll\n    for (int l = 0; l < int(sizeof(int)); ++l) {\n        vals[l] =\n            (ni == warp_size || item_ct1.get_local_id(2) < ni) ? 
scale * x[4 * item_ct1.get_local_id(2) + l] : 0.0f;\n    }\n\n    float amax = sycl::fabs(vals[0]);\n    float sum  = vals[0];\n#pragma unroll\n    for (int l = 1; l < int(sizeof(int)); ++l) {\n        amax = sycl::fmax(amax, sycl::fabs(vals[l]));\n        sum += vals[l];\n    }\n#pragma unroll\n    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {\n        amax = sycl::fmax(\n            amax, dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), amax, mask));\n        sum += dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), sum, mask);\n    }\n\n    const float d = amax / 127;\n    int q32 = 0;\n    int8_t * q8 = (int8_t *) &q32;\n\n    if (d != 0.0f) {\n#pragma unroll\n        for (int l = 0; l < int(sizeof(int)); ++l) {\n            q8[l] = sycl::round(vals[l] / d);\n        }\n    }\n\n    yq32[item_ct1.get_local_id(2)] = q32;\n    if (item_ct1.get_local_id(2) % QI8_1 == 0 && (ni == warp_size || item_ct1.get_local_id(2) < ni)) {\n        if (std::is_same<Tds, sycl::half2>::value) {\n            ((sycl::half2  *) yds)[item_ct1.get_local_id(2)/QI8_1] =  make_half2(d, sum);\n        } else {\n            ((sycl::float2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_float2(d, sum);\n        }\n    }\n}\n\ntypedef void (*dequantize_V_t)(const void *, void *, const int64_t);\n\ntemplate <typename T, int ne>\nstatic __dpct_inline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    if constexpr (std::is_same_v<T, sycl::half>) {\n        ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(dst, (const sycl::half *) vx + i0);\n    } else if constexpr (std::is_same_v<T, float>) {\n        static_assert(ne % 2 == 0, \"bad ne\");\n        sycl::half2 tmp[ne / 2];\n        ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(tmp, (const sycl::half *) vx + i0);\n        sycl::float2 * dst_f2 = (sycl::float2 *) dst;\n#pragma unroll\n        for (int l = 0; l < ne/2; ++l) {\n         
   dst_f2[l] = tmp[l].template convert<float, sycl::rounding_mode::automatic>();\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"unsupported type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __dpct_inline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q4_0 * x = (const block_q4_0 *) vx;\n\n    const int64_t ib    =  i0          /  QK4_0;\n    const int     iqs   =  i0          % (QK4_0/2);\n    const int     shift = (i0 % QK4_0) / (QK4_0/2);\n\n    int q;\n    static_assert(ne == 2 || ne == 4, \"bad ne\");\n    ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);\n    q >>= 4*shift;\n    q &= 0x0F0F0F0F;\n    q = dpct::vectorized_binary<sycl::char4>(q, 0x08080808, dpct::sub_sat());\n\n    const int8_t * q8 = (const int8_t *) &q;\n\n#ifdef GGML_SYCL_F16\n    if constexpr (std::is_same_v<T, sycl::half>) {\n        const sycl::half2 d = sycl::half2(x[ib].d);\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);\n        }\n    } else\n#endif // GGML_SYCL_F16\n    if constexpr (std::is_same_v<T, float>) {\n        const float d = x[ib].d;\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = d * q8[l];\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"bad type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __dpct_inline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q4_1 * x = (const block_q4_1 *) vx;\n\n    const int64_t ib    =  i0          /  QK4_1;\n    const int     iqs   =  i0          % (QK4_1/2);\n    const int     shift = (i0 % QK4_1) / (QK4_1/2);\n\n    int q;\n    static_assert(ne == 2 || ne == 4, \"bad ne\");\n    ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);\n    q >>= 4*shift;\n    q &= 0x0F0F0F0F;\n\n    const int8_t * q8 = 
(const int8_t *) &q;\n\n#ifdef GGML_SYCL_F16\n    if constexpr (std::is_same_v<T, sycl::half>) {\n        const sycl::half2 dm = x[ib].dm;\n        const sycl::half2 d  = sycl::half2(dm[0]);\n        const sycl::half2 m  = sycl::half2(dm[1]);\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;\n        }\n    } else\n#endif // GGML_SYCL_F16\n    if constexpr (std::is_same_v<T, float>) {\n        const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = dm.x() * q8[l] + dm.y();\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"bad type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __dpct_inline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q5_0 * x = (const block_q5_0 *) vx;\n\n    const int64_t ib    =  i0          /  QK5_0;\n    const int     idq   =  i0          %  QK5_0;\n    const int     iqs   =  i0          % (QK5_0/2);\n    const int     shift = (i0 % QK5_0) / (QK5_0/2);\n\n    int q;\n    static_assert(ne == 2 || ne == 4, \"bad ne\");\n    ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);\n    q >>= 4*shift;\n    q &= 0x0F0F0F0F;\n\n    {\n        int qh;\n        ggml_sycl_memcpy_1<ne, 2>(&qh, x[ib].qh);\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);\n        }\n    }\n\n    q = dpct::vectorized_binary<sycl::char4>(q, 0x10101010, dpct::sub_sat());\n\n    const int8_t * q8 = (const int8_t *) &q;\n\n#ifdef GGML_SYCL_F16\n    if constexpr (std::is_same_v<T, sycl::half>) {\n        const sycl::half2 d = sycl::half2(x[ib].d);\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], 
q8[l0 + 1]);\n        }\n    } else\n#endif // GGML_SYCL_F16\n    if constexpr (std::is_same_v<T, float>) {\n        const float d = x[ib].d;\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = d * q8[l];\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"bad type\");\n    }\n}\n\ntemplate <typename T, int ne>\nstatic __dpct_inline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q5_1 * x = (const block_q5_1 *) vx;\n\n    const int64_t ib    =  i0          /  QK5_1;\n    const int     idq   =  i0          %  QK5_1;\n    const int     iqs   =  i0          % (QK5_1/2);\n    const int     shift = (i0 % QK5_1) / (QK5_1/2);\n\n    int q;\n    static_assert(ne == 2 || ne == 4, \"bad ne\");\n    ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);\n    q >>= 4*shift;\n    q &= 0x0F0F0F0F;\n\n    {\n        int qh;\n        ggml_sycl_memcpy_1<ne>(&qh, x[ib].qh);\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);\n        }\n    }\n\n    const int8_t * q8 = (const int8_t *) &q;\n\n#ifdef GGML_SYCL_F16\n    if constexpr (std::is_same_v<T, sycl::half>) {\n        const sycl::half2 dm = x[ib].dm;\n        const sycl::half2 d  = sycl::half2(dm[0]);\n        const sycl::half2 m  = sycl::half2(dm[1]);\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;\n        }\n    } else\n#endif // GGML_SYCL_F16\n    if constexpr (std::is_same_v<T, float>) {\n        const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = dm.x() * q8[l] + dm.y();\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"bad type\");\n    }\n}\n\ntemplate <typename T, int 
ne>\nstatic __dpct_inline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {\n    const block_q8_0 * x = (const block_q8_0 *) vx;\n\n    const int64_t ib  = i0 / QK8_0;\n    const int     iqs = i0 % QK8_0;\n\n    static_assert(ne % 2 == 0, \"bad ne\");\n    int8_t qs[ne];\n    ggml_sycl_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);\n\n#ifdef GGML_SYCL_F16\n    if constexpr (std::is_same<T, sycl::half>::value) {\n        const sycl::half2 d = sycl::half2(x[ib].d);\n\n#pragma unroll\n        for (int l0 = 0; l0 < ne; l0 += 2) {\n            ((sycl::half2 *) dst)[l0 / 2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);\n        }\n    } else\n#endif // GGML_SYCL_F16\n    if constexpr (std::is_same<T, float>::value) {\n        const float d = x[ib].d;\n\n#pragma unroll\n        for (int l = 0; l < ne; ++l) {\n            ((float *) dst)[l] = d * qs[l];\n        }\n    } else {\n        static_assert(std::is_same_v<T, void>, \"unsupported type\");\n    }\n}\n\ntemplate <int type_K, int D, int nthreads, int warp_size>\nconstexpr vec_dot_KQ_t get_vec_dot_KQ() {\n    if constexpr (type_K == GGML_TYPE_F16) {\n        return vec_dot_fattn_vec_KQ_f16<D, nthreads>;\n    } else if constexpr (type_K == GGML_TYPE_Q4_0) {\n        return vec_dot_fattn_vec_KQ_q4_0<D, nthreads, warp_size>;\n    } else if constexpr (type_K == GGML_TYPE_Q4_1) {\n        return vec_dot_fattn_vec_KQ_q4_1<D, nthreads, warp_size>;\n    } else if constexpr (type_K == GGML_TYPE_Q5_0) {\n        return vec_dot_fattn_vec_KQ_q5_0<D, nthreads, warp_size>;\n    } else if constexpr (type_K == GGML_TYPE_Q5_1) {\n        return vec_dot_fattn_vec_KQ_q5_1<D, nthreads, warp_size>;\n    } else if constexpr (type_K == GGML_TYPE_Q8_0) {\n        return vec_dot_fattn_vec_KQ_q8_0<D, nthreads, warp_size>;\n    } else {\n        static_assert(type_K == -1, \"bad type\");\n        return nullptr;\n    }\n}\n\ntemplate <int type_V, typename T, int ne>\nconstexpr dequantize_V_t 
get_dequantize_V() {\n    if constexpr (type_V == GGML_TYPE_F16) {\n        return dequantize_V_f16<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q4_0) {\n        return dequantize_V_q4_0<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q4_1) {\n        return dequantize_V_q4_1<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q5_0) {\n        return dequantize_V_q5_0<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q5_1) {\n        return dequantize_V_q5_1<T, ne>;\n    } else if constexpr (type_V == GGML_TYPE_Q8_0) {\n        return dequantize_V_q8_0<T, ne>;\n    } else {\n        static_assert(type_V == -1, \"bad type\");\n        return nullptr;\n    }\n}\n\ntemplate <int ncols1, int warp_size>\nstatic void flash_attn_mask_to_KV_max(const sycl::half2 * __restrict__ mask,\n                                      int * __restrict__ KV_max,\n                                      const int ne30,\n                                      const int s31,\n                                      const int s33,\n                                      int *     buf_iw) {\n    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int ne31     = item_ct1.get_group_range(2);\n    const int tid      = item_ct1.get_local_id(2);\n    const int sequence = item_ct1.get_group(1);\n    const int jt       = item_ct1.get_group(2);\n\n    mask += sequence*s33 + jt*ncols1*s31;\n\n    if (tid < warp_size) {\n        buf_iw[tid] = 1;\n    }\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n\n    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;\n    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {\n        int all_inf = 1;\n\n#pragma unroll\n        for (int j = 0; j < ncols1; ++j) {\n            const sycl::float2 tmp =\n                mask[j * s31 + KV_max_sj / 2 + tid].template convert<float, sycl::rounding_mode::automatic>();\n            all_inf = all_inf && int(sycl::isinf((float) (tmp.x()))) && int(sycl::isinf((float) 
(tmp.y())));\n        }\n\n        all_inf = warp_reduce_all<warp_size>(all_inf);\n        if (tid % warp_size == 0) {\n            buf_iw[tid / warp_size] = all_inf;\n        }\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n        all_inf = buf_iw[tid % warp_size];\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n        all_inf = warp_reduce_all<warp_size>(all_inf);\n\n        if (!all_inf) {\n            break;\n        }\n    }\n\n    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.\n    // If the break was triggered it's the lower edge of the tile with the first non-masked values.\n    // In either case, walk back the decrementation by FATTN_KQ_STRIDE.\n    KV_max_sj += FATTN_KQ_STRIDE;\n\n    if (item_ct1.get_local_id(2) != 0) {\n        return;\n    }\n\n    KV_max[sequence*ne31 + jt] = KV_max_sj;\n}\n\ntemplate <int D, int ncols1, int ncols2>  // D == head size\n\nstatic void flash_attn_stream_k_fixup(float * __restrict__ dst,\n                                      const sycl::float2 * __restrict__ dst_fixup,\n                                      const int ne01,\n                                      const int ne02,\n                                      const int ne03,\n                                      const int ne11,\n                                      const int ne12,\n                                      const int nbatch_fa) {\n    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    constexpr int ncols    = ncols1 * ncols2;\n\n    const int bidx0 = item_ct1.get_group(2);\n    const int j     = item_ct1.get_group(1);\n    const int c     = item_ct1.get_group(0);\n    const int jc    = j*ncols2 + c;\n    const int tid   = item_ct1.get_local_id(2);\n\n    const float * dst_fixup_data = ((const float *) dst_fixup) + item_ct1.get_group_range(2) * (2 * 2 * ncols);\n\n    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are 
> 1 Q matrices per K, V matrix.\n\n    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;\n    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;\n    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;\n\n    const int kbc0 = int64_t(bidx0 + 0) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);\n    const int kbc0_stop =\n        int64_t(bidx0 + 1) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);\n\n    const bool did_not_have_any_data   = kbc0 == kbc0_stop;\n    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;\n    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;\n    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {\n        return;\n    }\n\n    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index\n    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);\n    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);\n    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);\n    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;\n\n    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.\n\n    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {\n        return;\n    }\n\n    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;\n\n    // Load the partial result that needs a fixup:\n    float dst_val = 0.0f;\n    float max_val = 0.0f;\n    float rowsum  = 0.0f;\n    {\n        dst_val = *dst;\n\n        const sycl::float2 tmp = dst_fixup[bidx0 * ncols + jc];\n        max_val                = tmp.x();\n        rowsum                 = tmp.y();\n    }\n\n    // Iterate over 
previous blocks and compute the combined results.\n    // All SYCL blocks that get here must have a previous block that needs a fixup.\n    int bidx = bidx0 - 1;\n    int kbc_stop = kbc0;\n    while(true) {\n        const int kbc = int64_t(bidx) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);\n        if (kbc == kbc_stop) { // Did not have any data.\n            bidx--;\n            kbc_stop = kbc;\n            continue;\n        }\n\n        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];\n\n        const sycl::float2 tmp = dst_fixup[(item_ct1.get_group_range(2) + bidx) * ncols + jc];\n\n        // Scale the current and new value accumulators depending on the max. values.\n        const float max_val_new = sycl::fmax(max_val, tmp.x());\n\n        const float diff_val = max_val - max_val_new;\n        const float diff_add = tmp.x() - max_val_new;\n\n        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_val) : 0.0f;\n        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? 
sycl::native::exp(diff_add) : 0.0f;\n\n        dst_val = scale_val*dst_val + scale_add*dst_add;\n        rowsum  = scale_val * rowsum + scale_add * tmp.y();\n\n        max_val = max_val_new;\n\n        // If this block started in a previous tile we are done and don't need to combine additional partial results.\n        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {\n            break;\n        }\n        bidx--;\n        kbc_stop = kbc;\n    }\n\n    // Write back final result:\n    *dst = dst_val / rowsum;\n}\n\ntemplate <int D>  // D == head size\n\nstatic void flash_attn_combine_results(const float * __restrict__ VKQ_parts,\n                                       const sycl::float2 * __restrict__ VKQ_meta,\n                                       float * __restrict__ dst,\n                                       const int parallel_blocks,\n                                       uint8_t * dpct_local) {\n    // Dimension 0: threadIdx.x\n    // Dimension 1: blockIdx.x\n    // Dimension 2: blockIdx.y\n    // Dimension 3: blockIdx.z\n    // Memory layout is permuted with [0, 2, 1, 3]\n\n    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int ne01     = item_ct1.get_group_range(2);\n    const int ne02     = item_ct1.get_group_range(1);\n\n    const int col      = item_ct1.get_group(2);\n    const int head     = item_ct1.get_group(1);\n    const int sequence = item_ct1.get_group(0);\n\n    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;\n\n    VKQ_parts += j_dst_unrolled * parallel_blocks*D;\n    VKQ_meta  += j_dst_unrolled * parallel_blocks;\n    dst       += j_dst_unrolled *                 D;\n\n    const int tid = item_ct1.get_local_id(2);\n    __builtin_assume(tid < D);\n\n    auto meta = (sycl::float2 *) dpct_local;\n    for (int i = tid; i < 2*parallel_blocks; i += D) {\n        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];\n    }\n\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n\n    
float kqmax = meta[0].x();\n    for (int l = 1; l < parallel_blocks; ++l) {\n        kqmax = sycl::max(kqmax, meta[l].x());\n    }\n\n    float VKQ_numerator   = 0.0f;\n    float VKQ_denominator = 0.0f;\n    for (int l = 0; l < parallel_blocks; ++l) {\n        const float KQ_max_scale = sycl::native::exp(meta[l].x() - kqmax);\n\n        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];\n        VKQ_denominator += KQ_max_scale * meta[l].y();\n    }\n\n    dst[tid] = VKQ_numerator / VKQ_denominator;\n}\n\ntemplate <fattn_kernel_t fattn_kernel, int warp_size>\nstatic void lauch_kernel(\n    dpct::dim3 group_range,\n    dpct::dim3 local_range,\n    queue_ptr q,\n    unsigned int local_mem_size,\n    const char* __restrict__ Q,\n    const char* __restrict__ K,\n    const char* __restrict__ V,\n    const char* __restrict__ mask,\n    const char* __restrict__ sinks,\n    const int* __restrict__ KV_max,\n    float* __restrict__ dst,\n    sycl::float2* __restrict__ dst_meta,\n    const float scale,\n    const float max_bias,\n    const float m0,\n    const float m1,\n    const uint32_t n_head_log2,\n    const float logit_softcap,\n    const int32_t ne00,\n    const sycl::uint3 ne01,\n    const int32_t ne02,\n    const int32_t ne03,\n    const int32_t nb01,\n    const int32_t nb02,\n    const int32_t nb03,\n    const int32_t ne10,\n    const int32_t ne11,\n    const int32_t ne12,\n    const int32_t ne13,\n    const int32_t nb11,\n    const int32_t nb12,\n    const int64_t nb13,\n    const int32_t nb21,\n    const int32_t nb22,\n    const int64_t nb23,\n    const int32_t ne31,\n    const int32_t ne32,\n    const int32_t ne33,\n    const int32_t nb31,\n    const int32_t nb32,\n    const int64_t nb33) {\n    GGML_UNUSED(local_mem_size);\n    q->submit([&](sycl::handler &cgh) {\n        cgh.parallel_for(\n            sycl::nd_range<3>(\n                static_cast<sycl::range<3>>(group_range * local_range),\n                static_cast<sycl::range<3>>(local_range)),\n      
      [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {\n                GGML_UNUSED(item_ct1);\n                fattn_kernel(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n                             max_bias, m0, m1, n_head_log2, logit_softcap, ne00,\n                             ne01, ne02, ne03, nb01, nb02, nb03, ne10, ne11,\n                             ne12, ne13, nb11, nb12, nb13, nb21, nb22, nb23,\n                             ne31, ne32, ne33, nb31, nb32, nb33);\n            });\n    });\n}\n\ntemplate <int DV, int ncols1, int ncols2, fattn_kernel_t fattn_kernel, int warp_size>\nvoid launch_fattn(\n    ggml_backend_sycl_context & ctx, ggml_tensor * dst, const int nwarps, const size_t nbytes_shared,\n    const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k) {\n\n    constexpr int ncols = ncols1 * ncols2;\n\n    const ggml_tensor * Q = dst->src[0];\n    const ggml_tensor * K = dst->src[1];\n    const ggml_tensor * V = dst->src[2];\n\n    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));\n\n    const ggml_tensor * mask  = dst->src[3];\n    const ggml_tensor * sinks = dst->src[4];\n\n    ggml_tensor * KQV = dst;\n\n    GGML_ASSERT(Q->type == GGML_TYPE_F32);\n    GGML_ASSERT(KQV->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));\n    GGML_ASSERT(K->nb[0] == ggml_element_size(K));\n    GGML_ASSERT(V->nb[0] == ggml_element_size(V));\n\n    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);\n\n    ggml_sycl_pool & pool = ctx.pool();\n    dpct::queue_ptr  main_stream = ctx.stream();\n    const int id  = ggml_sycl_get_device();\n    const int nsm = ggml_sycl_info().devices[id].nsm;\n\n    ggml_sycl_pool_alloc<sycl::half>   K_f16(pool);\n    ggml_sycl_pool_alloc<sycl::half>   V_f16(pool);\n    ggml_sycl_pool_alloc<int>    KV_max(pool);\n    ggml_sycl_pool_alloc<float>  dst_tmp(pool);\n    
ggml_sycl_pool_alloc<sycl::float2> dst_tmp_meta(pool);\n\n    const char * K_data = (const char *) K->data;\n    size_t nb11 = K->nb[1];\n    size_t nb12 = K->nb[2];\n    size_t nb13 = K->nb[3];\n\n    const char * V_data = (const char *) V->data;\n    size_t nb21 = V->nb[1];\n    size_t nb22 = V->nb[2];\n    size_t nb23 = V->nb[3];\n\n    if (need_f16_K && K->type != GGML_TYPE_F16) {\n        const size_t bs = ggml_blck_size(K->type);\n        const size_t ts = ggml_type_size(K->type);\n\n        K_f16.alloc(ggml_nelements(K));\n        if (ggml_is_contiguously_allocated(K)) {\n            to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(K->type, dst);\n            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);\n\n            nb11 = nb11 * bs * sizeof(sycl::half) / ts;\n            nb12 = nb12 * bs * sizeof(sycl::half) / ts;\n            nb13 = nb13 * bs * sizeof(sycl::half) / ts;\n        } else {\n            GGML_ASSERT(K->nb[0] == ts);\n            to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(K->type);\n            const int64_t s01 = nb11 / ts;\n            const int64_t s02 = nb12 / ts;\n            const int64_t s03 = nb13 / ts;\n            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);\n\n            nb11 = K->ne[0] * sizeof(sycl::half);\n            nb12 = K->ne[1] * nb11;\n            nb13 = K->ne[2] * nb12;\n        }\n        K_data = (char *) K_f16.ptr;\n    }\n\n    if (need_f16_V && V->type != GGML_TYPE_F16) {\n        if (V_is_K_view) {\n            V_data = K_data;\n            nb21   = nb11;\n            nb22   = nb12;\n            nb23   = nb13;\n        } else {\n            const size_t bs = ggml_blck_size(V->type);\n            const size_t ts = ggml_type_size(V->type);\n\n            V_f16.alloc(ggml_nelements(V));\n            if (ggml_is_contiguously_allocated(V)) {\n                to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(V->type, dst);\n                
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);\n                V_data = (char *) V_f16.ptr;\n\n                nb21 = nb21 * bs * sizeof(sycl::half) / ts;\n                nb22 = nb22 * bs * sizeof(sycl::half) / ts;\n                nb23 = nb23 * bs * sizeof(sycl::half) / ts;\n            } else {\n                GGML_ASSERT(V->nb[0] == ts);\n                to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(V->type);\n                const int64_t s01 = nb21 / ts;\n                const int64_t s02 = nb22 / ts;\n                const int64_t s03 = nb23 / ts;\n                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);\n\n                nb21 = V->ne[0] * sizeof(sycl::half);\n                nb22 = V->ne[1] * nb21;\n                nb23 = V->ne[2] * nb22;\n            }\n            V_data = (char *) V_f16.ptr;\n        }\n    }\n\n    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);\n    const int gqa_ratio    = Q->ne[2] / K->ne[2];\n    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);\n    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];\n\n    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.\n    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or\n    //     multiple sequences of possibly different lengths.\n    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {\n        const int s31 = mask->nb[1] / sizeof(sycl::half2);\n        const int s33 = mask->nb[3] / sizeof(sycl::half2);\n\n        const dpct::dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);\n        const dpct::dim3 block_dim_KV_max(FATTN_KQ_STRIDE / 2, 1, 1);\n\n        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;\n        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;\n\n        KV_max.alloc(ne_KV_max);\n        {\n  
          dpct::has_capability_or_fail(main_stream->get_device(), { sycl::aspect::fp16 });\n\n            main_stream->submit([&](sycl::handler & cgh) {\n                sycl::local_accessor<int, 1> buf_iw_acc_ct1(sycl::range<1>(warp_size), cgh);\n\n                auto mask_data_ct0  = (const sycl::half2 *) mask->data;\n                auto KV_max_ptr_ct1 = KV_max.ptr;\n\n                cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),\n                                 [=](sycl::nd_item<3> item_ct1) {\n                                     GGML_UNUSED(item_ct1);\n                                     flash_attn_mask_to_KV_max<ncols1, warp_size>(\n                                         mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,\n                                         buf_iw_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());\n                                 });\n            });\n        }\n        SYCL_CHECK(0);\n    }\n\n    const dpct::dim3 block_dim(warp_size, nwarps, 1);\n\n    // Max. number of active blocks limited by occupancy.\n    int max_blocks_per_sm = ggml_sycl_info().devices[id].max_wg_per_cu;\n    int parallel_blocks = max_blocks_per_sm;\n    dpct::dim3 blocks_num;\n    if (stream_k) {\n        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.\n        const int max_blocks = max_blocks_per_sm*nsm;\n        const int nblocks_stream_k = max_blocks;\n        const bool use_stream_k = true;\n\n        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;\n        blocks_num.y = 1;\n        blocks_num.z = 1;\n\n        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.\n            dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));\n        }\n    } else {\n        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. 
number of parallel blocks limited by tensor size.\n\n        // parallel_blocks must not be larger than what the tensor size allows:\n        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);\n        // todo fix the hard code change\n        // parallel_blocks = ntiles_KQ;\n\n        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.\n        // Test whether parallel_blocks can be set to a higher value for better efficiency.\n        const int blocks_per_wave = nsm * max_blocks_per_sm;\n        int nwaves_best = 0;\n        int efficiency_percent_best = 0;\n        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {\n            const int nblocks_total = ntiles_total * parallel_blocks_test;\n            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;\n            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);\n\n            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.\n            if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {\n                break;\n            }\n\n            if (efficiency_percent > efficiency_percent_best) {\n                nwaves_best = nwaves;\n                efficiency_percent_best = efficiency_percent;\n                parallel_blocks = parallel_blocks_test;\n            }\n        }\n\n        blocks_num.x = ntiles_x;\n        blocks_num.y = parallel_blocks;\n        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];\n\n        if (parallel_blocks > 1) {\n            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));\n            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));\n        }\n    }\n\n    float scale         = 1.0f;\n    float max_bias      = 0.0f;\n    float logit_softcap = 0.0f;\n\n    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));\n    
memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));\n    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));\n\n    if (logit_softcap != 0.0f) {\n        scale /= logit_softcap;\n    }\n\n    const uint32_t n_head      = Q->ne[2];\n    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    // TODO other tensor dimensions after removal of WMMA kernel:\n    const sycl::uint3 ne01 = init_fastdiv_values(Q->ne[1]);\n\n    GGML_ASSERT(block_dim.x % warp_size == 0);\n\n    lauch_kernel<fattn_kernel, warp_size>(\n        blocks_num, block_dim, main_stream, (unsigned int) nbytes_shared, (const char *) Q->data, K_data, V_data,\n        mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr,\n        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, (sycl::float2 *)dst_tmp_meta.ptr, scale, max_bias, m0, m1,\n        n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0],\n        K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0,\n        mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,\n        mask ? 
mask->nb[3] : 0);\n    SYCL_CHECK(0);\n\n    if (stream_k) {\n        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.\n            const dpct::dim3 block_dim_combine(DV, 1, 1);\n            const dpct::dim3 blocks_num_combine = { blocks_num.x, ncols1, ncols2 };\n\n            main_stream->submit([&](sycl::handler & cgh) {\n                auto KQV_data_ct0         = (float *) KQV->data;\n                auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;\n                auto Q_ne_ct2             = Q->ne[1];\n                auto Q_ne_ct3             = Q->ne[2];\n                auto Q_ne_ct4             = Q->ne[3];\n                auto K_ne_ct5             = K->ne[1];\n                auto K_ne_ct6             = K->ne[2];\n\n                cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),\n                                 [=](sycl::nd_item<3> item_ct1) {\n                                     GGML_UNUSED(item_ct1);\n                                     flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1,\n                                                                                   Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,\n                                                                                   K_ne_ct5, K_ne_ct6, nbatch_fa);\n                                 });\n            });\n        }\n    } else if (parallel_blocks > 1) {\n        const dpct::dim3 block_dim_combine(DV, 1, 1);\n        const dpct::dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);\n        const size_t     nbytes_shared_combine = parallel_blocks * sizeof(sycl::float2);\n        main_stream->submit([&](sycl::handler & cgh) {\n            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(nbytes_shared_combine), cgh);\n\n            auto dst_tmp_ptr_ct0      = dst_tmp.ptr;\n            auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;\n            auto 
KQV_data_ct2         = (float *) KQV->data;\n\n            cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),\n                             [=](sycl::nd_item<3> item_ct1) {\n                                 GGML_UNUSED(item_ct1);\n                                 flash_attn_combine_results<DV>(\n                                     dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,\n                                     dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());\n                             });\n        });\n    }\n    SYCL_CHECK(0);\n}\n"
  },
  {
    "path": "src/ggml-sycl/fattn-tile.cpp",
    "content": "#include <sycl/sycl.hpp>\n#include <sycl/ext/oneapi/work_group_static.hpp>\n#include \"dpct/helper.hpp\"\n#include \"common.hpp\"\n#include \"fattn-common.hpp\"\n#include \"fattn-tile.hpp\"\n#include <cmath>\n#include <float.h>\nnamespace syclex = sycl::ext::oneapi::experimental;\n\nvoid ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * K = dst->src[1];\n    const ggml_tensor * V = dst->src[2];\n    switch (K->ne[0]) {\n        case  40: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_sycl_flash_attn_ext_tile_case< 40,  40>(ctx, dst);\n        } break;\n        case  64: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_sycl_flash_attn_ext_tile_case< 64,  64>(ctx, dst);\n        } break;\n        case  72: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_sycl_flash_attn_ext_tile_case< 72,  72>(ctx, dst);\n        } break;\n        case  80: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_sycl_flash_attn_ext_tile_case< 80,  80>(ctx, dst);\n        } break;\n        case  96: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_sycl_flash_attn_ext_tile_case< 96,  96>(ctx, dst);\n        } break;\n        case 112: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst);\n        } break;\n        case 128: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst);\n        } break;\n        case 256: {\n            GGML_ASSERT(V->ne[0] == K->ne[0]);\n            ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);\n        } break;\n        case 576: {\n            GGML_ASSERT(V->ne[0] == 512);\n            ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);\n        } break;\n        default: {\n            GGML_ABORT(\"Unsupported head size\");\n        } break;\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/fattn-tile.hpp",
    "content": "#include <sycl/sycl.hpp>\n#include <sycl/ext/oneapi/work_group_static.hpp>\n#include \"dpct/helper.hpp\"\n#include \"common.hpp\"\n#include \"fattn-common.hpp\"\n\n#include <cmath>\n#include <float.h>\n\nnamespace syclex = sycl::ext::oneapi::experimental;\n\n#define GGML_SYCL_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \\\n    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                          \\\n        static_assert((nthreads)          <= 512, \"bad nthreads\");                                    \\\n        static_assert((occupancy)         <=   8, \"bad occupancy\");                                   \\\n        static_assert((nbatch_fa)         <= 256, \"bad nbatch_fa\");                                   \\\n        static_assert((nbatch_K)          <= 256, \"bad nbatch_K\");                                    \\\n        return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23);    \\\n    }                                                                                                 \\\n\nstatic constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, const int DV, const int ncols) {\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  64,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  64,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  64,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  64,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  64,  40)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2, 
 64, 2,  64,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  64,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  64,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  64,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  64,  72)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  64,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  64,  40)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  64,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  64,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  64,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  64,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  64,  48)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  64,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  64,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  64,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  64,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  64,  56)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)\n    
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)\n\n    return 0;\n}\n\nstatic constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, const int DV, const int ncols) {\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2, 128, 3,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 3,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 
96,  96,  8, 256, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2, 128, 3,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 3,  32, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 3,  64, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3,  32, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  32, 256)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)\n\n    return 0;\n}\n\nstatic constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)\n\n    
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 3,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 2,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2, 128,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)\n    
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2, 256, 2, 128,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2,  64,  32)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)\n\n    return 0;\n}\n\nstatic constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 8,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4,  64, 8,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 5, 128,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 
64,  64, 16, 128, 5, 128,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 8,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 8,  64,  
64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 8,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64,  64)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8,  32,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6,  32, 256)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6,  32, 256)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)\n\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)\n    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)\n\n    return 0;\n}\n\nstatic constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {\n    if(fast_fp16_available(cc))\n        return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);\n    else\n        return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols);\n}\n\nstatic constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {\n#ifdef SYCL_FAST_FP16\n    return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);\n#else\n    return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols);\n#endif // SYCL_FAST_FP16\n}\n\nstatic int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {\n    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);\n}\n\nstatic constexpr int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {\n    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);\n}\n\nstatic int ggml_sycl_fattn_tile_get_occupancy(const 
int DKQ, const int DV, const int ncols, const int cc) {\n    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);\n}\n\nstatic constexpr int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {\n    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);\n}\n\nstatic int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {\n    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);\n}\n\nstatic constexpr int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {\n    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);\n}\n\nstatic int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {\n    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);\n}\n\nstatic constexpr int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {\n    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);\n}\n\ntemplate <int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>\nstatic __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV,\n                                                      sycl::half2 * const __restrict__ tile_KV,\n                                                      const int stride_KV,\n                                                      const int i_sup) {\n    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    auto load = [&] (const int n) {\n        auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n        const int stride_j = warp_size >> n;\n\n        if (stride_j == 0) {\n            return;\n        }\n\n        const int j0_start = stride_j == warp_size ? 
0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);\n        const int j0_stop  =                             ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);\n        const int stride_i = warp_size / stride_j;\n\n        if (j0_start == j0_stop) {\n            return;\n        }\n\n#pragma unroll\n        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {\n            const int i = i0 + item_ct1.get_local_id(1) * stride_i +\n                          (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j);\n\n            if (i0 + nwarps*stride_i <= I || i < I) {\n#pragma unroll\n                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {\n                    const int j = j0 * cpy_ne + (stride_j == warp_size ? item_ct1.get_local_id(2) :\n                                                                         item_ct1.get_local_id(2) % stride_j) *\n                                                    cpy_ne;\n\n                    const __dpct_align__(16) sycl::half2 zero[cpy_ne] = {\n                        { 0.0f, 0.0f }\n                    };\n                    ggml_sycl_memcpy_1<cpy_nb>(\n                        tile_KV + i*(J/2 + J_padding) + j,\n                        !oob_check || i < i_sup ? 
KV + i*stride_KV + j : zero);\n                }\n            }\n        }\n    };\n    // 1: max 64*16=1024 bytes, 512 half\n    // 2: max 32*16=512 bytes, 256 half\n    // 3: max 16*16=256 bytes, 128 half\n    // 4: max  8*16=128 bytes,  64 half\n    // 5: max  4*16= 64 bytes,  32 half\n    // 6: max  2*16= 32 bytes,  16 half\n    // 7: max  1*16= 16 bytes,   8 half\n    static_assert(J % 8 == 0, \"bad J\");\n    static_assert((J/2) % cpy_ne == 0, \"bad J\");\n    ggml_sycl_unroll<7>{}(load);\n}\n\ntemplate <int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>\nstatic __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV,\n                                                      float * const __restrict__ tile_KV,\n                                                      const int stride_KV,\n                                                      const int i_sup) {\n    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    auto load = [&] (const int n) {\n        auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n        const int stride_j = warp_size >> n;\n\n        if (stride_j == 0) {\n            return;\n        }\n\n        const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);\n        const int j0_stop  =                             (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);\n        const int stride_i = warp_size / stride_j;\n\n        if (j0_start == j0_stop) {\n            return;\n        }\n\n#pragma unroll\n        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {\n            const int i = i0 + item_ct1.get_local_id(1) * stride_i +\n                          (stride_j == warp_size ? 
0 : item_ct1.get_local_id(2) / stride_j);\n\n            if (i0 + nwarps*stride_i <= I || i < I) {\n#pragma unroll\n                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {\n                    const int j = j0 * (cpy_ne / 2) + (stride_j == warp_size ? item_ct1.get_local_id(2) :\n                                                                               item_ct1.get_local_id(2) % stride_j) *\n                                                          (cpy_ne / 2);\n\n                    const sycl::half2 zero[cpy_ne / 2] = {\n                        { 0.0f, 0.0f }\n                    };\n                    __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne / 2];\n                    ggml_sycl_memcpy_1<sizeof(tmp_h2)>(\n                        tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);\n\n                    __dpct_align__(16) sycl::float2 tmp_f2[cpy_ne / 2];\n#pragma unroll\n                    for (int l = 0; l < cpy_ne/2; ++l) {\n                        tmp_f2[l] = tmp_h2[l].template convert<float, sycl::rounding_mode::automatic>();\n                    }\n                    ggml_sycl_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);\n                }\n            }\n        }\n    };\n    // 1: max 32*16=512 bytes, 128 float\n    // 2: max 16*16=256 bytes,  64 float\n    // 3: max  8*16=128 bytes,  32 float\n    // 4: max  4*16= 64 bytes,  16 float\n    // 5: max  2*16= 32 bytes,   8 float\n    static_assert(J % 8 == 0, \"bad J\");\n    static_assert(J % cpy_ne == 0, \"bad J\");\n    ggml_sycl_unroll<5>{}(load);\n}\n\n// Function that performs a single iteration in for the KQ matrix multiplication:\ntemplate <int  warp_size,\n          int  nwarps,\n          int  ncols1,\n          int  ncols2,\n          int  DKQ,\n          int  nbatch_fa,\n          int  nbatch_K,\n          bool use_logit_softcap,\n          bool oob_check,\n          typename T_vec_dot>\nstatic __dpct_inline__ void 
flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,\n                                                    const sycl::half2 * const __restrict__ K_h2,\n                                                    T_vec_dot * const KV_tmp,\n                                                    const int         stride_K2,\n                                                    const int         k_VKQ_0,\n                                                    const int         k_VKQ_sup,\n                                                    const int         k_KQ_0,\n                                                    float *           KQ_acc) {\n    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    constexpr int cpy_nb   = ggml_sycl_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    constexpr int ncols = ncols1*ncols2;\n    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp\n    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column\n\n    flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>\n        (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);\n    item_ct1.barrier();\n\n#ifdef SYCL_FAST_FP16\n    static_assert((nbatch_K/2) % cpy_ne == 0, \"bad nbatch_K\");\n#pragma unroll\n    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {\n        __dpct_align__(16) sycl::half2 K_k[nbatch_fa / (np * warp_size)][cpy_ne];\n        __dpct_align__(16) sycl::half2 Q_k[cpw][cpy_ne];\n#else\n    static_assert(nbatch_K % cpy_ne == 0, \"bad nbatch_K\");\n#pragma unroll\n    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {\n        __dpct_align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];\n        __dpct_align__(16) float Q_k[cpw][cpy_ne];\n#endif // SYCL_FAST_FP16\n\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {\n            const int i_KQ = i_KQ_0 + 
(item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);\n\n#ifdef SYCL_FAST_FP16\n            ggml_sycl_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);\n#else\n            ggml_sycl_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K   + cpy_ne) + k_KQ_1]);\n#endif // SYCL_FAST_FP16\n        }\n#pragma unroll\n        for (int jc0 = 0; jc0 < cpw; ++jc0) {\n            const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;\n\n#ifdef SYCL_FAST_FP16\n            ggml_sycl_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);\n#else\n            ggml_sycl_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ    + k_KQ_0   + k_KQ_1]);\n#endif // SYCL_FAST_FP16\n        }\n\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {\n#pragma unroll\n            for (int jc0 = 0; jc0 < cpw; ++jc0) {\n#pragma unroll\n                for (int k = 0; k < cpy_ne; ++k) {\n                    ggml_sycl_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);\n                }\n            }\n        }\n    }\n\n    if (k_KQ_0 + nbatch_K < DKQ) {\n        item_ct1.barrier();  // Sync not needed on last iteration.\n    }\n}\n\n// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.\ntemplate <int  warp_size,\n          int  nwarps,\n          int  ncols1,\n          int  ncols2,\n          int  DKQ,\n          int  DV,\n          int  nbatch_fa,\n          int  nbatch_K,\n          bool use_logit_softcap,\n          bool oob_check,\n          typename T_vec_dot,\n          typename T_KQ,\n          typename T_acc>\n/*\nThe total declared local variable size in device function flash_attn_tile_iter exceeds 128 bytes and may cause high register pressure. 
Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure.\n*/\nstatic __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,\n                                                 const sycl::half2 * const __restrict__ K_h2,\n                                                 const sycl::half2 * const __restrict__ V_h2,\n                                                 const sycl::half * const __restrict__ mask,\n                                                 const sycl::uint3 ne01,\n                                                 const float       logit_softcap,\n                                                 const float       slope,\n                                                 T_KQ * const      KQ,\n                                                 T_vec_dot * const KV_tmp,\n                                                 const int         stride_K2,\n                                                 const int         stride_V2,\n                                                 const int         stride_mask,\n                                                 float * const     KQ_max,\n                                                 float * const     KQ_sum,\n                                                 T_acc * const     VKQ,\n                                                 const int         k_VKQ_0,\n                                                 const int         k_VKQ_max,\n                                                 const int         col_Q_0,\n                                                 float *           KQ_max_new_shared) {\n    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    constexpr int cpy_nb   = ggml_sycl_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    constexpr int ncols = ncols1*ncols2;\n    constexpr int cpw   = ncols > nwarps ? 
ncols/nwarps : 1; // Q columns per warp\n    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column\n\n    constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.\n\n#ifdef SYCL_FAST_FP16\n    constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;\n#else\n    constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;\n#endif // SYCL_FAST_FP16\n    static_assert(cpw % KQ_cs == 0, \"bad KQ_cs\");\n    const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data\n\n    float KQ_max_new[cpw];\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        KQ_max_new[jc0] = KQ_max[jc0];\n    }\n\n    float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.\n\n    // KQ = K @ Q matrix multiplication:\n    constexpr int nbatch_K_last = DKQ % nbatch_K;\n#pragma unroll\n    for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {\n        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(\n            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);\n    }\n    if (nbatch_K_last > 0) {\n        constexpr int k_KQ_0 = DKQ - nbatch_K_last;\n        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(\n            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);\n    }\n\n    // Apply logit softcap + mask, update KQ_max:\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        const int j = fastmodulo(col_Q_0 + (jc0 + (item_ct1.get_local_id(1) / np) * cpw) / ncols2, ne01);\n\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {\n            const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);\n\n#if defined(SYCL_FAST_FP16) && 
!defined(GGML_SYCL_F16)\n            // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.\n            // Therefore, scale down Q values and apply the inverse scale to the FP32 KQ values afterwards.\n            KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;\n#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)\n\n            if (use_logit_softcap) {\n                KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] =\n                    logit_softcap * sycl::tanh((float) KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0]);\n            }\n\n            if (!oob_check || i_KQ < k_VKQ_sup) {\n                KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] +=\n                    (ncols2 > 1 || mask) ? slope * sycl::vec<sycl::half, 1>(mask[j * stride_mask + k_VKQ_0 + i_KQ])\n                                                       .convert<float, sycl::rounding_mode::automatic>()[0] :\n                                           0.0f;\n\n                KQ_max_new[jc0] =\n                    sycl::fmax((float) KQ_max_new[jc0],\n                               (float) (KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] + FATTN_KQ_MAX_OFFSET));\n            }\n        }\n\n        KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);\n    }\n\n    if constexpr (np == 1) {\n        item_ct1.barrier();\n    } else {\n        static_assert(cpw == 1, \"bad cpw\");\n\n        if (item_ct1.get_local_id(2) == 0) {\n            KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];\n        }\n        item_ct1.barrier();\n        KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];\n        KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);\n    }\n\n    // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {\n#ifdef SYCL_FAST_FP16\n        __dpct_align__(16) 
sycl::half tmp[nbatch_fa / (np * warp_size)][KQ_cs];\n#else\n        __dpct_align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];\n#endif // SYCL_FAST_FP16\n\n#pragma unroll\n        for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {\n            const int jc = jc0 + jc1;\n\n            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc] - KQ_max_new[jc]));\n            KQ_max[jc] = KQ_max_new[jc];\n\n            float KQ_sum_add = 0.0f;\n#pragma unroll\n            for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {\n                const float val =\n                    !oob_check || i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2) <\n                                      static_cast<uint32_t>(k_VKQ_sup) ?\n                        sycl::native::exp((float) (KQ_acc[(i0 / (np * warp_size)) * cpw + jc] - KQ_max[jc])) :\n                        0.0f;\n                KQ_sum_add += val;\n                tmp[i0/(np*warp_size)][jc1] = val;\n            }\n            KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;\n\n#ifdef SYCL_FAST_FP16\n            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale_h2.x();\n                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale_h2.y();\n            }\n#else\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale;\n                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale;\n            }\n#endif // SYCL_FAST_FP16\n        }\n\n#pragma unroll\n        for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {\n            const int i = i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);\n\n            ggml_sycl_memcpy_1<sizeof(tmp[0])>(\n                KQ 
+ (jc0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs)) * (nbatch_fa * KQ_cs) + i * KQ_cs,\n                tmp[i0 / (np * warp_size)]);\n        }\n    }\n\n    // VKQ = V @ KQ matrix multiplication:\n    static_assert(DV <= DKQ, \"bad DV\");\n    static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), \"bad nbatch_K\");\n    constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.\n    static_assert(nbatch_fa % nbatch_V == 0, \"bad nbatch_V\");\n    static_assert(nbatch_V % np == 0, \"bad nbatch_V\");\n#pragma unroll\n    for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {\n        flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>\n            (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);\n        item_ct1.barrier();\n\n#ifdef SYCL_FAST_FP16\n#pragma unroll\n        for (int k1 = 0; k1 < nbatch_V; k1 += np) {\n            __dpct_align__(16) sycl::half2 V_k[(DVp / 2) / warp_size];\n            __dpct_align__(16) sycl::half2 KQ_k[cpw];\n\n            constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? 
cpy_ne/2 : (DVp/2)/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {\n                ggml_sycl_memcpy_1<cpy_ne_D * 4>(&V_k[i0 / warp_size],\n                                                 &KV_tmp[(k1 + item_ct1.get_local_id(1) % np) * (DV / 2) + i0 +\n                                                         item_ct1.get_local_id(2) * cpy_ne_D]);\n            }\n#pragma unroll\n            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {\n                const int jc_KQ = jc_VKQ_0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs);\n\n                __dpct_align__(16) sycl::half tmp[KQ_cs];\n                ggml_sycl_memcpy_1<KQ_cs * sizeof(sycl::half)>(\n                    &tmp, KQ + jc_KQ * (nbatch_fa * KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np) * KQ_cs);\n#pragma unroll\n                for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {\n                    KQ_k[jc_VKQ_0 + jc_VKQ_1] = sycl::half2(tmp[jc_VKQ_1]);\n                }\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n#pragma unroll\n                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {\n                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() +=\n                        V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0].x();\n                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() +=\n                        V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0].y();\n                }\n            }\n        }\n#else\n#pragma unroll\n        for (int k1 = 0; k1 < nbatch_V; k1 += np) {\n            __dpct_align__(16) sycl::float2 V_k[(DVp/2)/warp_size];\n            __dpct_align__(16) float  KQ_k[cpw];\n\n            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? 
cpy_ne : DVp/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {\n                ggml_sycl_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + item_ct1.get_local_id(1) % np)*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D]);\n            }\n#pragma unroll\n            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {\n                const int jc_KQ = jc_VKQ_0/KQ_cs + (item_ct1.get_local_id(1) / np)*(cpw/KQ_cs);\n\n                ggml_sycl_memcpy_1<KQ_cs*sizeof(float)>(\n                    &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np)*KQ_cs);\n            }\n\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n#pragma unroll\n                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {\n                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() += V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0];\n                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() += V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0];\n                }\n            }\n        }\n#endif // SYCL_FAST_FP16\n        item_ct1.barrier();\n    }\n}\n\ntemplate <int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, int warp_size>  // D == head size\n/*\nThe total declared local variable size in device function flash_attn_tile exceeds 128 bytes and may cause high register pressure. 
Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure.\n*/\nstatic void flash_attn_tile(const char *  Q,\n                            const char *  K,\n                            const char *  V,\n                            const char *  mask,\n                            const char *  sinks,\n                            const int *  KV_max,\n                            float *  dst,\n                            sycl::float2 *  dst_meta,\n                            const float          scale,\n                            const float          max_bias,\n                            const float          m0,\n                            const float          m1,\n                            const uint32_t       n_head_log2,\n                            const float          logit_softcap,\n                            const int32_t        ne00,\n                            const sycl::uint3    ne01,\n                            const int32_t        ne02,\n                            const int32_t        ne03,\n                            const int32_t        nb01,\n                            const int32_t        nb02,\n                            const int32_t        nb03,\n                            const int32_t        ne10,\n                            const int32_t        ne11,\n                            const int32_t        ne12,\n                            const int32_t        ne13,\n                            const int32_t        nb11,\n                            const int32_t        nb12,\n                            const int64_t        nb13,\n                            const int32_t        nb21,\n                            const int32_t        nb22,\n                            const int64_t        nb23,\n                            const int32_t        ne31,\n                            const int32_t        ne32,\n                            
const int32_t        ne33,\n                            const int32_t        nb31,\n                            const int32_t        nb32,\n                            const int64_t        nb33) {\n#ifdef SYCL_FLASH_ATTN\n    // Skip unused kernel variants for faster compilation:\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    if ((use_logit_softcap && !(DV == 128 || DV == 256))) {\n        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n            max_bias, m0, m1, n_head_log2, logit_softcap,\n            ne00, ne01, ne02, ne03,\n                  nb01, nb02, nb03,\n            ne10, ne11, ne12, ne13,\n                  nb11, nb12, nb13,\n                  nb21, nb22, nb23,\n                  ne31, ne32, ne33,\n                  nb31, nb32, nb33);\n        return;\n    }\n\n    static_assert(ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, \"kernel config not defined\");\n\n    constexpr int ncols     = ncols1*ncols2;\n\n    constexpr int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;\n    constexpr int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);\n    constexpr int nbatch_K  = ggml_sycl_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);\n\n    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.\n\n    const int col_Q_0 = item_ct1.get_group(2) * ncols1;  // Index of the first Q column for this SYCL block to work on.\n\n    const int           sequence  = item_ct1.get_group(0) / (ne02 / ncols2);\n    const int           head0     = item_ct1.get_group(0) * ncols2 - sequence * ne02;  // == item_ct1.get_group(0) % (ne02/ncols2)\n    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.\n    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0);\n    const sycl::half2 * K_h2      = (const sycl::half2 *) (K + nb13 * sequence + nb12 * (head0 / 
gqa_ratio));\n    const sycl::half2 * V_h2 =\n        (const sycl::half2 *) (V + nb23 * sequence + nb22 * (head0 / gqa_ratio));  // K and V have same shape\n\n    const sycl::half * maskh = mask ? (const sycl::half *) (mask + nb33 * (sequence % ne33)) : nullptr;\n\n    const int stride_K2   = nb11 / sizeof(sycl::half2);\n    const int stride_V2   = nb21 / sizeof(sycl::half2);\n    const int stride_mask = nb31 / sizeof(sycl::half);\n\n    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;\n\n    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.\n    constexpr int np  = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.\n\n    static_assert(cpw == 1 || np == 1, \"bad cpw / np\");\n    static_assert(nbatch_fa % (np*warp_size) == 0, \"nbatch_fa % (np*warp_size) != 0\");\n\n    constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.\n    constexpr int DVp  = (DV  + 2*warp_size - 1) & ~(2*warp_size - 1); // DV  padded to multiple of 2*warp_size.\n\n    // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.\n    // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.\n    //     KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).\n    // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.\n    // VKQ == Accumulators in registers for the final VKQ result.\n\n\n#ifdef SYCL_FAST_FP16\n    constexpr size_t lsm_size1 = ncols * DKQ/2 ;\n    constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV ;\n    constexpr size_t lsm_size3 = ncols * nbatch_fa;\n    constexpr size_t lsm_size4 = nwarps;\n\n    constexpr size_t local_share_mem_size = lsm_size1 * sizeof(sycl::half2) +\n                            
                lsm_size2 * sizeof(sycl::half2) +\n                                            lsm_size3 * sizeof(sycl::half) +\n                                            lsm_size4 * sizeof(float);\n\n    syclex::work_group_static<char[local_share_mem_size]> lsm;\n\n    sycl::half2 *Q_tmp = (sycl::half2 *)&lsm;\n    sycl::half2 *KV_tmp = (sycl::half2*)(Q_tmp +lsm_size1);\n    sycl::half *KQ = (sycl::half *)(KV_tmp+lsm_size2);\n    float *KQ_max_new_shared = (float *)(KQ+lsm_size3);\n\n    __dpct_align__(16) sycl::half2 VKQ[cpw * ((DVp / 2) / warp_size)] = {\n        { 0.0f, 0.0f }\n    };\n#else\n    constexpr size_t lsm_size1 = ncols * DKQ ;\n    constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV;\n    constexpr size_t lsm_size3 = ncols * nbatch_fa;\n    constexpr size_t lsm_size4 = nwarps;\n\n    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 +lsm_size3 + lsm_size4) * sizeof(float);\n\n    syclex::work_group_static<char[local_share_mem_size]> lsm;\n\n    float *Q_tmp = (float *)&lsm;\n    float *KV_tmp = Q_tmp +lsm_size1;\n    float *KQ = KV_tmp+lsm_size2;\n    float *KQ_max_new_shared = KQ+lsm_size3;\n\n    __dpct_align__(16) sycl::float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};\n\n\n#endif // SYCL_FAST_FP16\n\n    float KQ_max[cpw] = {};\n\n#pragma unroll\n    for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;\n    }\n    float KQ_sum[cpw] = {0.0f};\n\n    // Load Q data, convert to FP16 if fast:\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;\n\n        const int j = jc / ncols2;\n        const int c = jc % ncols2;\n\n        constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? 
cpy_ne : DKQp/warp_size;\n\n#pragma unroll\n        for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {\n            if (i0 + np * warp_size * cpy_ne_D <= DKQ ||\n                i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + item_ct1.get_local_id(2) * cpy_ne_D <\n                    DKQ) {\n                __dpct_align__(16) float tmp_f[cpy_ne_D] = { 0.0f };\n                ggml_sycl_memcpy_1<sizeof(tmp_f)>(\n                    tmp_f, &Q_f[c * (nb02 / sizeof(float)) + fastmodulo(col_Q_0 + j, ne01) * (nb01 / sizeof(float)) +\n                                i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) +\n                                item_ct1.get_local_id(2) * cpy_ne_D]);\n\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {\n                    tmp_f[i1] *= scale;\n                }\n\n#ifdef SYCL_FAST_FP16\n                __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne_D / 2];\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {\n                    tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);\n#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)\n                    // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.\n                    // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.\n                    tmp_h2[i1 / 2] *= sycl::half2(0.25f, 0.25f);\n#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)\n                }\n                ggml_sycl_memcpy_1<sizeof(tmp_h2)>(\n                    &Q_tmp[jc * (DKQ / 2) + i0 / 2 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D / 2) +\n                           item_ct1.get_local_id(2) * (cpy_ne_D / 2)],\n                    tmp_h2);\n#else\n                ggml_sycl_memcpy_1<sizeof(tmp_f)>(\n                    &Q_tmp[jc* DKQ    + i0   + (item_ct1.get_local_id(1) % np)*(warp_size*cpy_ne_D)   + 
item_ct1.get_local_id(2)* cpy_ne_D],\n                    tmp_f);\n#endif // SYCL_FAST_FP16\n            }\n        }\n    }\n\n    item_ct1.barrier();\n\n    // Main loop over KV cache:\n    const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;\n    if (ncols2 == 1) {\n        // Branch with out-of-bounds checks.\n        int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa;\n        while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {\n            constexpr bool oob_check = false;\n            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,\n                                 oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,\n                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,\n                                            KQ_max_new_shared);\n            k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa;\n        }\n        if (k_VKQ_0 < k_VKQ_max) {\n            constexpr bool oob_check = true;\n            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,\n                                 oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,\n                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,\n                                            KQ_max_new_shared);\n        }\n    } else {\n        // Branch without out-of-bounds checks.\n        for (int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa; k_VKQ_0 < k_VKQ_max;\n             k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa) {\n\n            constexpr bool oob_check = false;\n            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,\n                                 oob_check>(Q_tmp, K_h2, V_h2, maskh, 
ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,\n                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,\n                                            KQ_max_new_shared);\n        }\n    }\n\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; ++jc0) {\n        KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);\n    }\n\n    if constexpr (np > 1) {\n        static_assert(cpw == 1, \"bad cpw\");\n        static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, \"KV_tmp too small\");\n\n#ifdef SYCL_FAST_FP16\n        sycl::half2 * VKQ_combine = (sycl::half2 *) KV_tmp;\n#else\n        float * VKQ_combine    = (float *) KV_tmp;\n#endif // SYCL_FAST_FP16\n\n        float * KQ_sum_combine = (float *) Q_tmp;\n\n        if (item_ct1.get_local_id(1) % np != 0) {\n\n#ifdef SYCL_FAST_FP16\n            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {\n                ggml_sycl_memcpy_1<cpy_ne_D * 4>(\n                    &VKQ_combine[item_ct1.get_local_id(1) * (DVp / 2) + i0 + item_ct1.get_local_id(2) * cpy_ne_D],\n                    &VKQ[i0 / warp_size]);\n            }\n#else\n\n            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? 
cpy_ne : DVp/warp_size;\n\n#pragma unroll\n            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {\n                ggml_sycl_memcpy_1<cpy_ne_D*4>(\n                    &VKQ_combine[item_ct1.get_local_id(1)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);\n            }\n#endif // SYCL_FAST_FP16\n\n            if (item_ct1.get_local_id(2) == 0) {\n                KQ_sum_combine[item_ct1.get_local_id(1)] = KQ_sum[0];\n            }\n            return;\n        }\n\n        item_ct1.barrier();\n\n#pragma unroll\n        for (int ip = 1; ip < np; ++ip) {\n#ifdef SYCL_FAST_FP16\n            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {\n                __dpct_align__(16) sycl::half2 tmp[cpy_ne_D];\n                ggml_sycl_memcpy_1<cpy_ne_D * 4>(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip) * (DVp / 2) + i0 +\n                                                                   item_ct1.get_local_id(2) * cpy_ne_D]);\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {\n                    VKQ[i0/warp_size + i1] += tmp[i1];\n                }\n            }\n#else\n            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? 
cpy_ne : DVp/warp_size;\n#pragma unroll\n            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {\n                __dpct_align__(16) float tmp[cpy_ne_D];\n                ggml_sycl_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D]);\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {\n                    ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];\n                }\n            }\n#endif // SYCL_FAST_FP16\n\n            KQ_sum[0] += KQ_sum_combine[item_ct1.get_local_id(1) + ip];\n        }\n    }\n\n    // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:\n    if (sinks && item_ct1.get_group(1) == 0) {\n#pragma unroll\n        for (int jc0 = 0; jc0 < cpw; ++jc0) {\n            const int   jc   = jc0 + (item_ct1.get_local_id(1) / np) * cpw;\n            const float sink = ((const float *) sinks)[head0 + jc % ncols2];\n\n            float       KQ_max_new_j = sycl::fmax((float) KQ_max[jc0], sink);\n            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc0] - KQ_max_new_j));\n            KQ_max[jc0] = KQ_max_new_j;\n\n            const float val = sycl::native::exp((float) (sink - KQ_max[jc0]));\n            KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;\n\n#ifdef SYCL_FAST_FP16\n            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;\n            }\n#else\n#pragma unroll\n            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {\n                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale;\n                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale;\n            }\n#endif // SYCL_FAST_FP16\n        }\n    }\n\n    // Write back results:\n#pragma unroll\n    for (int jc0 = 0; jc0 < cpw; 
++jc0) {\n        const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;\n\n        const int j = jc / ncols2;\n        const int c = jc % ncols2;\n\n        if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z())) {\n            return;\n        }\n\n        const float scale = item_ct1.get_group_range(1) == 1 ? 1.0f / KQ_sum[jc0] : 1.0f;\n\n        const int j_dst_unrolled =\n            ((sequence * int(ne01.z()) + col_Q_0 + j) * ne02 + head0 + c) * item_ct1.get_group_range(1) +\n            item_ct1.get_group(1);\n\n#ifdef SYCL_FAST_FP16\n        constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;\n#pragma unroll\n        for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {\n            __dpct_align__(16) sycl::float2 tmp[cpy_ne_D];\n#pragma unroll\n            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {\n                tmp[i1] = VKQ[jc0 * ((DVp / 2) / warp_size) + i0 / warp_size + i1]\n                              .template convert<float, sycl::rounding_mode::automatic>();\n                tmp[i1].x() *= scale;\n                tmp[i1].y() *= scale;\n            }\n            if (i0 + warp_size * cpy_ne_D <= DV / 2 || i0 + item_ct1.get_local_id(2) * cpy_ne_D < DV / 2) {\n                ggml_sycl_memcpy_1<sizeof(tmp)>(\n                    &dst[j_dst_unrolled * DV + 2 * i0 + item_ct1.get_local_id(2) * (2 * cpy_ne_D)], tmp);\n            }\n        }\n#else\n        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? 
cpy_ne : DVp/warp_size;\n#pragma unroll\n        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {\n            if (i0 + warp_size*cpy_ne_D <= DV || i0 + item_ct1.get_local_id(2)*cpy_ne_D < DV) {\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {\n                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x() *= scale;\n                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y() *= scale;\n                }\n                ggml_sycl_memcpy_1<cpy_ne_D*4>(\n                    &dst[j_dst_unrolled*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D],\n                    &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);\n            }\n        }\n#endif // SYCL_FAST_FP16\n\n        if (item_ct1.get_group_range(1) != 1 && item_ct1.get_local_id(2) == 0) {\n            dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);\n        }\n    }\n#else\n    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n        max_bias, m0, m1, n_head_log2, logit_softcap,\n        ne00, ne01, ne02, ne03,\n              nb01, nb02, nb03,\n        ne10, ne11, ne12, ne13,\n              nb11, nb12, nb13,\n              nb21, nb22, nb23,\n              ne31, ne32, ne33,\n              nb31, nb32, nb33);\n#endif // SYCL_FLASH_ATTN\n}\n\ntemplate <int DKQ, int DV, int ncols2, bool use_logit_softcap>\nstatic void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * Q = dst->src[0];\n\n    const int id        = ggml_sycl_get_device();\n    const int cc        = ggml_sycl_info().devices[id].cc;\n    const int warp_size = WARP_32_SIZE; //can't support WARP_16_SIZE\n\n    constexpr size_t nbytes_shared = 0;\n\n    if constexpr (DV <= 256) {\n        if (Q->ne[1] > 16/ncols2) {\n            constexpr int cols_per_block = 32;\n            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n            const int 
nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n            launch_fattn<DV, cols_per_block/ncols2, ncols2,\n                flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>\n                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);\n            return;\n        }\n    }\n\n    if (Q->ne[1] > 8/ncols2) {\n        constexpr int cols_per_block = 16;\n        const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n        const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n        launch_fattn<DV, cols_per_block/ncols2, ncols2,\n            flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>\n            (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);\n        return;\n    }\n\n    if constexpr (ncols2 <= 8) {\n        if (Q->ne[1] > 4/ncols2) {\n            constexpr int cols_per_block = 8;\n            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n            launch_fattn<DV, cols_per_block/ncols2, ncols2,\n                flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>\n                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);\n            return;\n        }\n    }\n\n    if constexpr (ncols2 <= 4) {\n        if (Q->ne[1] > 2/ncols2) {\n            constexpr int cols_per_block = 4;\n            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n            launch_fattn<DV, cols_per_block/ncols2, ncols2,\n                flash_attn_tile<DKQ, DV, 
cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>\n                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);\n            return;\n        }\n    }\n\n    if constexpr (ncols2 <= 2) {\n        constexpr int cols_per_block = 2;\n        const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;\n        const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);\n        launch_fattn<DV, cols_per_block/ncols2, ncols2,\n            flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>\n            (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);\n        return;\n    }\n\n    GGML_ABORT(\"fatal error\");\n}\n\ntemplate <int DKQ, int DV, bool use_logit_softcap>\nstatic void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV  = dst;\n    const ggml_tensor * Q    = dst->src[0];\n    const ggml_tensor * K    = dst->src[1];\n    const ggml_tensor * mask = dst->src[3];\n\n    float max_bias = 0.0f;\n    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));\n\n    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);\n    const int gqa_ratio = Q->ne[2] / K->ne[2];\n\n    // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.\n    // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.\n    //const bool nvidia = GGML_SYCL_CC_IS_NVIDIA(ggml_sycl_info().devices[ggml_sycl_get_device()].cc);\n    const int gqa_limit = gqa_ratio <= 4 && DV <= 256 ? 
16 : INT_MAX;\n    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;\n\n    if constexpr (DV == 512) {\n        if (use_gqa_opt && gqa_ratio % 16 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);\n            return;\n        }\n        if (use_gqa_opt && gqa_ratio % 4 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);\n            return;\n        }\n    }\n\n    if constexpr (DV <= 256) {\n        if (use_gqa_opt && gqa_ratio % 8 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);\n            return;\n        }\n\n        if (use_gqa_opt && gqa_ratio % 4 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);\n            return;\n        }\n\n        if (use_gqa_opt && gqa_ratio % 2 == 0) {\n            launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);\n            return;\n        }\n\n        launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);\n        return;\n    }\n    GGML_ABORT(\"fatal error\");\n}\n\ntemplate <int DKQ, int DV>\nvoid ggml_sycl_flash_attn_ext_tile_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV = dst;\n\n    float logit_softcap;\n    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));\n\n    if (logit_softcap == 0.0f) {\n        constexpr bool use_logit_softcap = false;\n        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);\n    } else {\n        constexpr bool use_logit_softcap = true;\n        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);\n    }\n}\n\nvoid ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#define DECL_FATTN_TILE_CASE(DKQ, DV)                             \\\n    template void 
ggml_sycl_flash_attn_ext_tile_case              \\\n    <DKQ, DV>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \\\n\nextern DECL_FATTN_TILE_CASE( 40,  40);\nextern DECL_FATTN_TILE_CASE( 64,  64);\nextern DECL_FATTN_TILE_CASE( 72,  72);\nextern DECL_FATTN_TILE_CASE( 80,  80);\nextern DECL_FATTN_TILE_CASE( 96,  96);\nextern DECL_FATTN_TILE_CASE(112, 112);\nextern DECL_FATTN_TILE_CASE(128, 128);\nextern DECL_FATTN_TILE_CASE(256, 256);\nextern DECL_FATTN_TILE_CASE(576, 512);\n\n"
  },
  {
    "path": "src/ggml-sycl/fattn-vec.hpp",
    "content": "#ifndef GGML_SYCL_FATTN_VEC_HPP\n#define GGML_SYCL_FATTN_VEC_HPP\n\n#include <sycl/sycl.hpp>\n#include <sycl/ext/oneapi/work_group_static.hpp>\n#include <iostream>\n#include <iomanip>\n\n#include \"dpct/helper.hpp\"\n#include \"common.hpp\"\n#include \"ggml.h\"\n#include \"fattn-common.hpp\"\n#include <cmath>\n#include <float.h>\n\nnamespace syclex = sycl::ext::oneapi::experimental;\n\nstatic int ggml_sycl_fattn_vec_get_nthreads_host(const int cc) {\n    return 128;\n    GGML_UNUSED(cc);\n}\n\nstatic constexpr int ggml_sycl_fattn_vec_get_nthreads_device() {\n    return 128;\n}\n\n// Currently LLVM with the amdgcn target does not support unrolling loops\n// that contain a break that cannot be resolved at compile time.\n#ifdef __clang__\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wpass-failed\"\n#endif // __clang__\n\ntemplate <int D,\n          int ncols,\n          int type_K,\n          int type_V,\n          bool use_logit_softcap,\n          int warp_size>  // D == head size\nstatic void flash_attn_ext_vec(const char* __restrict__ Q,\n                        const char* __restrict__ K,\n                        const char* __restrict__ V,\n                        const char* __restrict__ mask,\n                        const char* __restrict__ sinks,\n                        const int* __restrict__ KV_max,\n                        float* __restrict__ dst,\n                        sycl::float2* __restrict__ dst_meta,\n                        const float scale,\n                        const float max_bias,\n                        const float m0,\n                        const float m1,\n                        const uint32_t n_head_log2,\n                        const float logit_softcap,\n                        const int32_t ne00,\n                        const sycl::uint3 ne01,\n                        const int32_t ne02,\n                        const int32_t ne03,\n                        const int32_t nb01,\n          
              const int32_t nb02,\n                        const int32_t nb03,\n                        const int32_t ne10,\n                        const int32_t ne11,\n                        const int32_t ne12,\n                        const int32_t ne13,\n                        const int32_t nb11,\n                        const int32_t nb12,\n                        const int64_t nb13,\n                        const int32_t nb21,\n                        const int32_t nb22,\n                        const int64_t nb23,\n                        const int32_t ne31,\n                        const int32_t ne32,\n                        const int32_t ne33,\n                        const int32_t nb31,\n                        const int32_t nb32,\n                        const int64_t nb33) {\n#ifdef SYCL_FLASH_ATTN\n    // Skip unused kernel variants for faster compilation:\n\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    if (use_logit_softcap && !(D == 128 || D == 256)) {\n        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n            max_bias, m0, m1, n_head_log2, logit_softcap,\n            ne00, ne01, ne02, ne03,\n                  nb01, nb02, nb03,\n            ne10, ne11, ne12, ne13,\n                  nb11, nb12, nb13,\n                  nb21, nb22, nb23,\n                  ne31, ne32, ne33,\n                  nb31, nb32, nb33);\n        return;\n    }\n\n    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.\n\n    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();\n    constexpr int cpy_ne = cpy_nb / 4;\n\n    constexpr int nthreads_KQ_q = (D/4 < warp_size ? D/4 : warp_size);\n    constexpr int nthreads_V_q  = (D/4 < warp_size ? D/4 : warp_size);\n\n    constexpr int nthreads    = ggml_sycl_fattn_vec_get_nthreads_device();\n    constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;\n    constexpr int nthreads_V  = type_V == GGML_TYPE_F16 ? 
128 / cpy_nb : nthreads_V_q;\n\n    static_assert(warp_size % nthreads_KQ == 0, \"bad nthreads_K\");\n    static_assert(warp_size % nthreads_V  == 0, \"bad nthreads_V\");\n\n    constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;\n    constexpr int V_cols_per_iter   = warp_size / nthreads_V;\n\n    constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ, warp_size>();\n    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;\n#ifdef GGML_SYCL_F16\n    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, sycl::half, V_rows_per_thread>();\n#else\n    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();\n#endif // GGML_SYCL_F16\n\n    const int ic0 = item_ct1.get_group(2) * ncols;  // Index of the Q/QKV column to work on.\n\n    const int sequence  = item_ct1.get_group(0) / ne02;\n    const int head      = item_ct1.get_group(0) - sequence * ne02;\n    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.\n    Q += nb03*sequence + nb02* head              + nb01*ic0;\n    K += nb13*sequence + nb12*(head / gqa_ratio);\n    V += nb23*sequence + nb22*(head / gqa_ratio);\n\n    const sycl::half * maskh = (const sycl::half *) (mask + nb33 * (sequence % ne33) + nb31 * ic0);\n\n    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);\n\n    static_assert(D % (2*warp_size) == 0, \"D not divisible by 2*warp_size == 64.\");\n    constexpr int nwarps = nthreads / warp_size;\n    const int     tid    = warp_size * item_ct1.get_local_id(1) + item_ct1.get_local_id(2);\n    __builtin_assume(tid < nthreads);\n\n    constexpr int ne_KQ      = ncols*D;\n    constexpr int ne_combine = nwarps*V_cols_per_iter*D;\n\n    constexpr size_t lsm_size1 = ncols * warp_size;\n    constexpr size_t lsm_size2 = ncols * warp_size;\n#ifdef GGML_SYCL_F16\n    sycl::half2 VKQ[ncols][(D / 2) / nthreads_V] = { { { 0.0f, 0.0f } } };\n    constexpr 
size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);\n    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2)*sizeof(float) + lsm_size3*sizeof(sycl::half);\n\n    syclex::work_group_static<char[local_share_mem_size]> lsm;\n\n    float *KQ_max_shared = (float *)&lsm;\n    float *KQ_sum_shared = KQ_max_shared+lsm_size1;\n    sycl::half* KQ = (sycl::half*)(KQ_sum_shared + lsm_size2);\n\n\n#else\n    sycl::float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};\n\n    constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);\n    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3)*sizeof(float);\n\n\n    syclex::work_group_static<char[local_share_mem_size]> lsm;\n    float *KQ_max_shared = (float *)&lsm;\n    float *KQ_sum_shared = KQ_max_shared+lsm_size1;\n    float* KQ = KQ_sum_shared + lsm_size2;\n\n#endif // GGML_SYCL_F16\n\n    float KQ_max[ncols];\n    float KQ_sum[ncols];\n#pragma unroll\n    for (int j = 0; j < ncols; ++j) {\n        KQ_max[j] = -FLT_MAX/2.0f;\n        KQ_sum[j] = 0.0f;\n    }\n\n    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:\n#ifdef GGML_SYCL_F16\n    sycl::half2 Q_reg[ncols][(D / 2) / nthreads_KQ] = {{{0.0f, 0.0f}}};  // Will be initialized completely.\n#else\n    sycl::float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.\n#endif // GGML_SYCL_F16\n    int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];\n    sycl::float2 Q_ds[ncols][1 > D / (sizeof(int) * nthreads_KQ) ? 
1 : D / (sizeof(int) * nthreads_KQ)];\n    if constexpr (Q_q8_1) {\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n            const int j = j0 + item_ct1.get_local_id(1);\n\n            if (j0 + nwarps > ncols && j >= ncols) {\n                break;\n            }\n\n            // Reuse KQ as temporary storage for converting Q to q8_1:\n            int    * tmp_q_i32 = (int    *) &KQ[j*D];\n            sycl::float2 * tmp_q_ds  = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));\n\n            // Set memory to zero if out of bounds:\n            if (ncols > 1 && ic0 + j >= int(ne01.z())) {\n#pragma unroll\n                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += warp_size) {\n                    const int i = i0 + item_ct1.get_local_id(2);\n\n                    if (i0 + warp_size <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {\n                        tmp_q_i32[i] = 0;\n                    }\n                }\n                if (item_ct1.get_local_id(2) < D/QK8_1) {\n                    tmp_q_ds[item_ct1.get_local_id(2)] = sycl::float2(0.0f, 0.0f);\n                }\n            } else {\n                const float * Q_f = (const float *) (Q + j*nb01);\n                constexpr int nthreads_quantize = D/sizeof(int) < warp_size ? 
D/sizeof(int) : warp_size;\n#pragma unroll\n                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {\n                    quantize_q8_1_to_shared<sycl::float2, nthreads_quantize, warp_size>\n                        (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);\n                }\n            }\n        }\n\n\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n            int    * tmp_q_i32 = (int    *) &KQ[j*D];\n            sycl::float2 * tmp_q_ds  = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));\n\n#pragma unroll\n            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {\n                const int i =\n                    i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ);\n\n                Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];\n                Q_ds[j][i0/nthreads_KQ]  = tmp_q_ds[i/QI8_1];\n            }\n        }\n\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n\n    } else {\n#ifdef GGML_SYCL_F16\n        const sycl::half2 scale_h2 = sycl::half2(scale, scale);\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n            const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j * nb01);\n#pragma unroll\n            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {\n                const int i = i0 + (nthreads_KQ == warp_size ? 
item_ct1.get_local_id(2) :\n                                                               item_ct1.get_local_id(2) % nthreads_KQ) *\n                                       cpy_ne;\n\n                sycl::float2 tmp[cpy_ne] = {\n                    { 0.0f, 0.0f }\n                };\n                if (ncols == 1 || ic0 + j < int(ne01.z())) {\n                    ggml_sycl_memcpy_1<cpy_nb>(tmp,            &Q_j[i]);\n                    ggml_sycl_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);\n                }\n#pragma unroll\n                for (int i1 = 0; i1 < cpy_ne; ++i1) {\n                    Q_reg[j][i0 / nthreads_KQ + i1] = sycl::half2(tmp[i1].x(), tmp[i1].y());\n                }\n            }\n#pragma unroll\n            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {\n                Q_reg[j][k] *= scale_h2;\n            }\n        }\n#else\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n            const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j*nb01);\n#pragma unroll\n            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {\n                const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ)*cpy_ne;\n                if (ncols == 1 || ic0 + j < int(ne01.z())) {\n                    ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ],            &Q_j[i]);\n                    ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);\n                }\n            }\n#pragma unroll\n            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {\n                Q_reg[j][k].x() *= scale;\n                Q_reg[j][k].y() *= scale;\n            }\n        }\n#endif // GGML_SYCL_F16\n    }\n\n    const int k_VKQ_max = KV_max ? 
KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;\n    K += item_ct1.get_group(1) * nthreads * nb11;\n    V += item_ct1.get_group(1) * nthreads * nb21;\n    maskh += item_ct1.get_group(1) * nthreads;\n    for (int k_VKQ_0 = item_ct1.get_group(1) * nthreads; k_VKQ_0 < k_VKQ_max;\n         k_VKQ_0 += item_ct1.get_group_range(1) * nthreads,\n             // Increment pointers after each loop:\n         K += item_ct1.get_group_range(1) * nthreads * nb11, V += item_ct1.get_group_range(1) * nthreads * nb21,\n             maskh += item_ct1.get_group_range(1) * nthreads) {\n        // Calculate KQ tile and keep track of new maximum KQ values:\n        float KQ_reg[ncols]={}; // KQ in registers.\n        float KQ_max_new[ncols]={};\n\n\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n            KQ_max_new[j] = KQ_max[j];\n        }\n\n#pragma unroll\n        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {\n            const int i_KQ = item_ct1.get_local_id(1) * warp_size +\n                             (nthreads_KQ == warp_size ? 0 : (item_ct1.get_local_id(2) & ~(nthreads_KQ - 1))) + i_KQ_0;\n\n#pragma unroll\n            for (int j = 0; j < ncols; ++j) {\n                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);\n                sum = warp_reduce_sum<nthreads_KQ>(sum);\n\n                if (use_logit_softcap) {\n                    sum = logit_softcap * sycl::tanh(sum);\n                }\n                if (mask) {\n                    sum += slope * sycl::vec<sycl::half, 1>(maskh[j * ne11 + i_KQ])\n                                       .convert<float, sycl::rounding_mode::automatic>()[0];\n                }\n\n                KQ_max_new[j] = sycl::fmax((float) KQ_max_new[j], sum);\n\n                if (int(nthreads_KQ == warp_size ? 
item_ct1.get_local_id(2)\n                                                 : item_ct1.get_local_id(2) %\n                                                       nthreads_KQ) == i_KQ_0) {\n                  KQ_reg[j] = sum;\n                }\n            }\n        }\n\n#pragma unroll\n        for (int j = 0; j < ncols; ++j) {\n#pragma unroll\n            for (int offset = nthreads_KQ; offset < warp_size; offset <<= 1) {\n               KQ_max_new[j] = sycl::fmax(\n                  (float)KQ_max_new[j],\n                  (float)dpct::permute_sub_group_by_xor(\n                      sycl::ext::oneapi::this_work_item::get_sub_group(),\n                      KQ_max_new[j],\n                      offset,\n                      warp_size));\n            }\n            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - KQ_max_new[j]));\n            KQ_max[j] = KQ_max_new[j];\n\n            KQ_reg[j]            = sycl::native::exp((float) (KQ_reg[j] - KQ_max[j]));\n            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];\n            KQ[j*nthreads + tid] = KQ_reg[j];\n\n#ifdef GGML_SYCL_F16\n            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;\n            }\n#else\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n                VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;\n                VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;\n            }\n#endif // GGML_SYCL_F16\n        }\n\n        sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_sub_group());\n\n#pragma unroll\n        for (int k0 = 0; k0 < warp_size; k0 += V_cols_per_iter) {\n            const int k = item_ct1.get_local_id(1) * warp_size + k0 +\n                          (nthreads_V == warp_size ? 
0 : item_ct1.get_local_id(2) / nthreads_V);\n\n#ifdef GGML_SYCL_F16\n            sycl::half2 KQ_k[ncols];\n#pragma unroll\n            for (int j = 0; j < ncols; ++j) {\n                KQ_k[j] = sycl::half2(KQ[j * nthreads + k]);\n            }\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {\n                sycl::half2 tmp[V_rows_per_thread / 2];\n                dequantize_V(V + k * nb21, tmp,\n                             2 * i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) :\n                                                                      item_ct1.get_local_id(2) % nthreads_V) *\n                                               V_rows_per_thread);\n#pragma unroll\n                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {\n#pragma unroll\n                    for (int j = 0; j < ncols; ++j) {\n                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];\n                    }\n                }\n            }\n#else\n            float KQ_k[ncols];\n#pragma unroll\n            for (int j = 0; j < ncols; ++j) {\n                KQ_k[j] = KQ[j*nthreads + k];\n            }\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {\n                sycl::float2 tmp[V_rows_per_thread/2];\n                dequantize_V(V + k*nb21, tmp,\n                    2*i_VKQ_0 + (nthreads_V == warp_size ? 
item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*V_rows_per_thread);\n#pragma unroll\n                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {\n#pragma unroll\n                    for (int j = 0; j < ncols; ++j) {\n                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x() += tmp[i_VKQ_1].x()*KQ_k[j];\n                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y() += tmp[i_VKQ_1].y()*KQ_k[j];\n                    }\n                }\n            }\n#endif // GGML_SYCL_F16\n        }\n    }\n\n    if (sinks && item_ct1.get_group(1) == 0) {\n        const float sink = ((const float *) sinks)[head];\n\n#pragma unroll\n        for (int j0 = 0; j0 < ncols; j0 += nwarps) {\n            const int j = j0 + item_ct1.get_local_id(1);\n\n            if (j0 + nwarps > ncols && j >= ncols) {\n                break;\n            }\n            const float kqmax_new_j  = sycl::fmax(sink, (float) KQ_max[j]);\n            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - kqmax_new_j));\n            KQ_max[j] = kqmax_new_j;\n\n            KQ_sum[j] = KQ_sum[j] * KQ_max_scale +\n                        (item_ct1.get_local_id(2) == 0 ? 
sycl::native::exp((float) (sink - KQ_max[j])) : 0.0f);\n#ifdef GGML_SYCL_F16\n            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;\n            }\n#else\n#pragma unroll\n            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n                VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;\n                VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;\n            }\n#endif // GGML_SYCL_F16\n        }\n    }\n\n#pragma unroll\n    for (int j = 0; j < ncols; ++j) {\n        if (item_ct1.get_local_id(1) == 0) {\n            KQ_max_shared[j*warp_size+item_ct1.get_local_id(2)] = -FLT_MAX / 2.0f;\n            KQ_sum_shared[j*warp_size+item_ct1.get_local_id(2)] = 0.0f;\n        }\n    }\n\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n\n#pragma unroll\n    for (int j = 0; j < ncols; ++j) {\n        if (item_ct1.get_local_id(2) == 0) {\n            KQ_max_shared[j*warp_size+item_ct1.get_local_id(1)] = KQ_max[j];\n        }\n    }\n\n\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n\n#pragma unroll\n    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {\n        if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z())) {\n            break;\n        }\n\n        float kqmax_new         = KQ_max_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];\n        kqmax_new = warp_reduce_max<warp_size>(kqmax_new);\n        const float kqmax_scale = sycl::native::exp((float) (KQ_max[j_VKQ] - kqmax_new));\n        KQ_max[j_VKQ] = kqmax_new;\n\n#ifdef GGML_SYCL_F16\n        sycl::half2 * VKQ_tmp = (sycl::half2 *) KQ + item_ct1.get_local_id(1) * (V_cols_per_iter * D / 2) +\n                                (nthreads_V == warp_size ? 
0 : item_ct1.get_local_id(2) / nthreads_V) * (D / 2);\n\n        const sycl::half2 kqmax_scale_h2 = sycl::half2(kqmax_scale, kqmax_scale);\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n            VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;\n        }\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {\n            const int i_VKQ =\n                i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V) *\n                              (V_rows_per_thread / 2);\n\n            ggml_sycl_memcpy_1<V_rows_per_thread * sizeof(sycl::half)>(VKQ_tmp + i_VKQ,\n                                                                       &VKQ[j_VKQ][i_VKQ_0 / nthreads_V]);\n        }\n#else\n        sycl::float2 * VKQ_tmp = (sycl::float2 *) KQ + item_ct1.get_local_id(1)*(V_cols_per_iter*D/2)\n            + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V)*(D/2);\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {\n            VKQ[j_VKQ][i_VKQ_0/nthreads_V].x() *= kqmax_scale;\n            VKQ[j_VKQ][i_VKQ_0/nthreads_V].y() *= kqmax_scale;\n        }\n#pragma unroll\n        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {\n            const int i_VKQ = i_VKQ_0 + (nthreads_V == warp_size ? 
item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*(V_rows_per_thread/2);\n\n            ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ,                       &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);\n            ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);\n        }\n#endif // GGML_SYCL_F16\n\n        KQ_sum[j_VKQ] *= kqmax_scale;\n        KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);\n        if (item_ct1.get_local_id(2) == 0) {\n            KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(1)] = KQ_sum[j_VKQ];\n        }\n\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n\n\n        if (nthreads <= D || tid < D) {\n            KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];\n            KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);\n\n#pragma unroll\n            for (int i0 = 0; i0 < D; i0 += nthreads) {\n                float dst_val = 0;\n#pragma unroll\n                for (int w = 0; w < nwarps; ++w) {\n#pragma unroll\n                    for (int v = 0; v < V_cols_per_iter; ++v) {\n                        dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);\n                    }\n                }\n                if (item_ct1.get_group_range(1) == 1) {\n                    dst_val /= KQ_sum[j_VKQ];\n                }\n                dst[(((sequence * int(ne01.z()) + ic0 + j_VKQ) * ne02 + head) * item_ct1.get_group_range(1) +\n                     item_ct1.get_group(1)) *\n                        D +\n                    i0 + tid] = dst_val;\n            }\n        }\n\n        if (j_VKQ < ncols-1) {\n            item_ct1.barrier(sycl::access::fence_space::local_space);\n        }\n\n    }\n\n    if (item_ct1.get_group_range(1) != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z()))) {\n        dst_meta[((sequence * int(ne01.z()) + ic0 
+ tid) * ne02 + head) * item_ct1.get_group_range(1) +\n                 item_ct1.get_group(1)] = make_float2(KQ_max[tid], KQ_sum[tid]);\n    }\n#else\n    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,\n        max_bias, m0, m1, n_head_log2, logit_softcap,\n        ne00, ne01, ne02, ne03,\n              nb01, nb02, nb03,\n        ne10, ne11, ne12, ne13,\n              nb11, nb12, nb13,\n              nb21, nb22, nb23,\n              ne31, ne32, ne33,\n              nb31, nb32, nb33);\n\n#endif // SYCL_FLASH_ATTN\n}\n#ifdef __clang__\n#pragma clang diagnostic pop\n#endif // __clang__\n\n\ntemplate <int D, int cols_per_block, int type_K, int type_V, bool use_logit_softcap>\nvoid ggml_sycl_flash_attn_ext_vec_case_impl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n\n    const int warp_size = WARP_16_SIZE; //better performance than WARP_32_SIZE\n\n    const int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;\n\n    const int nthreads = ggml_sycl_fattn_vec_get_nthreads_host(cc);\n    const int nwarps   = nthreads / warp_size;\n\n    const bool need_f16_K = type_K == GGML_TYPE_F16;\n    const bool need_f16_V = type_V == GGML_TYPE_F16;\n    constexpr size_t nbytes_shared = 0;\n\n    launch_fattn<D, cols_per_block, 1,\n                 flash_attn_ext_vec<D, cols_per_block, type_K, type_V,\n                                    use_logit_softcap, warp_size>, warp_size>(\n        ctx, dst, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);\n}\n\ntemplate <int D, int type_K, int type_V>\nvoid ggml_sycl_flash_attn_ext_vec_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * KQV = dst;\n    const ggml_tensor * Q   = dst->src[0];\n\n    float logit_softcap;\n    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));\n\n    if (Q->ne[1] == 1) {\n        constexpr int cols_per_block = 1;\n        if (logit_softcap == 0.0f) {\n            constexpr bool use_logit_softcap = false;\n       
     ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);\n        } else {\n            constexpr bool use_logit_softcap = true;\n            ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);\n        }\n        return;\n    }\n\n    constexpr int cols_per_block = 2;\n    if (logit_softcap == 0.0f) {\n        constexpr bool use_logit_softcap = false;\n        ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);\n    } else {\n        constexpr bool use_logit_softcap = true;\n        ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);\n    }\n}\n\n#define DECL_FATTN_VEC_CASE(D, type_K, type_V)                              \\\n    template void ggml_sycl_flash_attn_ext_vec_case                         \\\n    <D, type_K, type_V>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \\\n\n#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K)             \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16);  \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \\\n    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \\\n\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)\nEXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)\n\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)\nEXTERN_DECL_FATTN_VEC_CASES(128, 
GGML_TYPE_Q5_1)\nEXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)\n\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)\nEXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)\n\n#endif // GGML_SYCL_FATTN_VEC_HPP\n"
  },
  {
    "path": "src/ggml-sycl/fattn.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2025 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n\n#include <sycl/sycl.hpp>\n#include \"dpct/helper.hpp\"\n#include \"common.hpp\"\n#include \"fattn-common.hpp\"\n#include \"fattn-tile.hpp\"\n#include \"fattn-vec.hpp\"\n#include \"fattn.hpp\"\n\n\n#define FATTN_VEC_CASE(D, type_K, type_V)                                                                        \\\n    {                                                                                                            \\\n        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \\\n        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \\\n        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                     \\\n            ggml_sycl_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                      \\\n            return;                                                                                              \\\n        }                                                                                                        \\\n    }                                                                    \\\n\n#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \\\n    FATTN_VEC_CASE( 64, type_K, type_V)       \\\n    FATTN_VEC_CASE(128, type_K, type_V)       \\\n    FATTN_VEC_CASE(256, type_K, type_V)       \\\n\nstatic void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * Q = dst->src[0];\n    ggml_tensor * K = dst->src[1];\n    ggml_tensor * V = dst->src[2];\n\n#ifdef GGML_SYCL_FA_ALL_QUANTS\n    
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)\n\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, 
GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)\n#else\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)\n    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)\n#endif // GGML_SYCL_FA_ALL_QUANTS\n\n    GGML_ABORT(\"Not match KV type in vec\");\n}\n\n// Best FlashAttention kernel for a specific GPU:\nenum best_fattn_kernel {\n    BEST_FATTN_KERNEL_NONE     =   0,\n    BEST_FATTN_KERNEL_VEC      = 100,\n    BEST_FATTN_KERNEL_TILE     = 200,\n};\n\nstatic best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {\n    GGML_UNUSED(device);\n#ifndef SYCL_FLASH_ATTN\n    GGML_UNUSED(dst);\n    return BEST_FATTN_KERNEL_NONE;\n#endif// SYCL_FLASH_ATTN\n\n    if(!g_ggml_sycl_enable_flash_attention) return BEST_FATTN_KERNEL_NONE;\n\n    const ggml_tensor * KQV   = dst;\n    const ggml_tensor * Q     = dst->src[0];\n    const ggml_tensor * K     = dst->src[1];\n    const ggml_tensor * V     = dst->src[2];\n    const ggml_tensor * mask  = dst->src[3];\n\n    const int gqa_ratio = Q->ne[2] / K->ne[2];\n    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);\n\n    float max_bias = 0.0f;\n    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));\n\n    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;\n    for (const ggml_tensor * t : {Q, K, V, mask}) {\n        if (t == nullptr || ggml_is_quantized(t->type)) {\n            continue;\n        }\n        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {\n            if (t->nb[i] % 16 != 0) {\n                gqa_opt_applies = false;\n                break;\n            }\n        }\n    }\n\n    switch (K->ne[0]) {\n        case  40:\n        case  64:\n        case  72:\n        case  80:\n        case  96:\n        case 128:\n        case 112:\n        case 256:\n            if (V->ne[0] 
!= K->ne[0]) {\n                return BEST_FATTN_KERNEL_NONE;\n            }\n            break;\n        case 576:\n            if (V->ne[0] != 512) {\n                return BEST_FATTN_KERNEL_NONE;\n            }\n            if (!gqa_opt_applies) {\n                return BEST_FATTN_KERNEL_NONE;\n            }\n            break;\n        default:\n            return BEST_FATTN_KERNEL_NONE;\n    }\n\n#ifndef GGML_SYCL_FA_ALL_QUANTS\n    if (K->type != V->type) {\n        return BEST_FATTN_KERNEL_NONE;\n    }\n#endif // GGML_SYCL_FA_ALL_QUANTS\n\n    switch (K->type) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n            break;\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n#ifndef GGML_SYCL_FA_ALL_QUANTS\n            return BEST_FATTN_KERNEL_NONE;\n#endif // GGML_SYCL_FA_ALL_QUANTS\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q8_0:\n            break;\n        default:\n            return BEST_FATTN_KERNEL_NONE;\n    }\n\n    if (mask && mask->ne[2] != 1) {\n        return BEST_FATTN_KERNEL_NONE;\n    }\n\n    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:\n    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;\n\n    // Todo: Use the XMX kernel if possible:\n\n    // If there are no tensor cores available, use the generic tile kernel:\n    if (can_use_vector_kernel) {\n        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {\n            if (Q->ne[1] == 1) {\n                if (!gqa_opt_applies) {\n                    return BEST_FATTN_KERNEL_VEC;\n                }\n            }\n        } else {\n            if (Q->ne[1] <= 2) {\n                return BEST_FATTN_KERNEL_VEC;\n            }\n        }\n    }\n    return BEST_FATTN_KERNEL_TILE;\n}\n\nvoid ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    
ggml_sycl_set_device(ctx.device);\n    switch (ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst)) {\n        case BEST_FATTN_KERNEL_NONE:\n            GGML_ABORT(\"Not support Flash-Attention\");\n        case BEST_FATTN_KERNEL_TILE:\n            ggml_sycl_flash_attn_ext_tile(ctx, dst);\n            break;\n        case BEST_FATTN_KERNEL_VEC:\n            ggml_sycl_flash_attn_ext_vec(ctx, dst);\n            break;\n    }\n}\n\nbool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) {\n    return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;\n}\n"
  },
  {
    "path": "src/ggml-sycl/fattn.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2025 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_FATTN_HPP\n#define GGML_SYCL_FATTN_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nbool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst);\n\n#endif // GGML_SYCL_FATTN_HPP\n"
  },
  {
    "path": "src/ggml-sycl/gated_delta_net.cpp",
    "content": "#include <sycl/sycl.hpp>\n#include \"dpct/helper.hpp\"\n#include \"common.hpp\"\n#include \"ggml.h\"\n#include \"gated_delta_net.hpp\"\n#include <cmath>\n\n\ntemplate <int S_v, bool KDA>\nvoid gated_delta_net_sycl(const float *     q,\n                          const float *     k,\n                          const float *     v,\n                          const float *     g,\n                          const float *     beta,\n                          const float *     curr_state,\n                          float *           dst,\n                          int64_t           H,\n                          int64_t           n_tokens,\n                          int64_t           n_seqs,\n                          int64_t           sq1,\n                          int64_t           sq2,\n                          int64_t           sq3,\n                          int64_t           sv1,\n                          int64_t           sv2,\n                          int64_t           sv3,\n                          int64_t           sb1,\n                          int64_t           sb2,\n                          int64_t           sb3,\n                          const sycl::uint3 neqk1_magic,\n                          const sycl::uint3 rq3_magic,\n                          float             scale) {\n    auto           item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const uint32_t h_idx    = item_ct1.get_group(2);\n    const uint32_t sequence = item_ct1.get_group(1);\n    // each warp owns one column, using warp-level primitives to reduce across rows\n    const int      lane     = item_ct1.get_local_id(2);\n    const int      col      = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);\n\n    const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);\n    const uint32_t iq3 = fastdiv(sequence, rq3_magic);\n\n    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;\n    float *       attn_data        = 
dst;\n    float *       state            = dst + attn_score_elems;\n\n    const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;\n    state += state_offset;\n    curr_state += state_offset;\n    attn_data += (sequence * n_tokens * H + h_idx) * S_v;\n\n    constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;\n    static_assert(S_v % warp_size == 0, \"S_v must be a multiple of warp_size\");\n    constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;\n    float         s_shard[rows_per_lane];\n#pragma unroll\n    for (int r = 0; r < rows_per_lane; r++) {\n        const int i = r * warp_size + lane;\n        s_shard[r]  = curr_state[i * S_v + col];\n    }\n\n    for (int t = 0; t < n_tokens; t++) {\n        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;\n        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;\n        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;\n\n        const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;\n        const float * beta_t = beta + gb_offset;\n        const float * g_t    = g    + gb_offset * (KDA ? 
S_v : 1);\n\n        const float beta_val = *beta_t;\n\n        if constexpr (!KDA) {\n            const float g_val = sycl::native::exp(*g_t);\n\n            // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]\n            float kv_shard = 0.0f;\n#pragma unroll\n            for (int r = 0; r < rows_per_lane; r++) {\n                const int i = r * warp_size + lane;\n                kv_shard += s_shard[r] * k_t[i];\n            }\n            float kv_col = warp_reduce_sum<warp_size>(kv_shard);\n\n            // delta[col] = (v[col] - g * kv[col]) * beta\n            float delta_col = (v_t[col] - g_val * kv_col) * beta_val;\n\n            // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]\n            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]\n            float attn_partial = 0.0f;\n#pragma unroll\n            for (int r = 0; r < rows_per_lane; r++) {\n                const int i = r * warp_size + lane;\n                s_shard[r]  = g_val * s_shard[r] + k_t[i] * delta_col;\n                attn_partial += s_shard[r] * q_t[i];\n            }\n\n            float attn_col = warp_reduce_sum<warp_size>(attn_partial);\n\n            if (lane == 0) {\n                attn_data[col] = attn_col * scale;\n            }\n        } else {\n            // kv[col] = sum_i g[i] * S[i][col] * k[i]\n            float kv_shard = 0.0f;\n#pragma unroll\n            for (int r = 0; r < rows_per_lane; r++) {\n                const int i = r * warp_size + lane;\n                kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i];\n            }\n\n            float kv_col = warp_reduce_sum<warp_size>(kv_shard);\n\n            // delta[col] = (v[col] - kv[col]) * beta\n            float delta_col = (v_t[col] - kv_col) * beta_val;\n\n            // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]\n            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]\n            float attn_partial = 0.0f;\n#pragma unroll\n            for (int r = 0; r < 
rows_per_lane; r++) {\n                const int i = r * warp_size + lane;\n                s_shard[r]  = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col;\n                attn_partial += s_shard[r] * q_t[i];\n            }\n\n            float attn_col = warp_reduce_sum<warp_size>(attn_partial);\n\n            if (lane == 0) {\n                attn_data[col] = attn_col * scale;\n            }\n        }\n\n        attn_data += S_v * H;\n    }\n\n    // Write state back to global memory\n#pragma unroll\n    for (int r = 0; r < rows_per_lane; r++) {\n        const int i          = r * warp_size + lane;\n        state[i * S_v + col] = s_shard[r];\n    }\n}\n\ntemplate <bool KDA>\nstatic void launch_gated_delta_net(const float *   q_d,\n                                   const float *   k_d,\n                                   const float *   v_d,\n                                   const float *   g_d,\n                                   const float *   b_d,\n                                   const float *   s_d,\n                                   float *         dst_d,\n                                   int64_t         S_v,\n                                   int64_t         H,\n                                   int64_t         n_tokens,\n                                   int64_t         n_seqs,\n                                   int64_t         sq1,\n                                   int64_t         sq2,\n                                   int64_t         sq3,\n                                   int64_t         sv1,\n                                   int64_t         sv2,\n                                   int64_t         sv3,\n                                   int64_t         sb1,\n                                   int64_t         sb2,\n                                   int64_t         sb3,\n                                   int64_t         neqk1,\n                                   int64_t         rq3,\n                                   
float           scale,\n                                   dpct::queue_ptr stream) {\n    //TODO: Add chunked kernel for even faster pre-fill\n    const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;\n\n    const int num_warps = 4;\n    dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);\n    dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);\n\n    const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);\n    const sycl::uint3 rq3_magic   = init_fastdiv_values(rq3);\n\n    int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;\n\n    switch (S_v) {\n        case 16:\n            {\n                constexpr int sv = 16;\n                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,\n                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,\n                                                                       sb3, neqk1_magic, rq3_magic, scale);\n                                     });\n            }\n            break;\n        case 32:\n            {\n                constexpr int sv = 32;\n                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,\n                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,\n                                                                       sb3, neqk1_magic, rq3_magic, scale);\n                               
      });\n            }\n            break;\n        case 64: {\n            {\n                constexpr int sv = 64;\n                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                                        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                                            gated_delta_net_sycl<sv, KDA>(\n                                                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,\n                                                sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);\n                                        });\n            }\n            break;\n        }\n        case 128: {\n            {\n                constexpr int sv = 128;\n                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                                        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                                            gated_delta_net_sycl<sv, KDA>(\n                                                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,\n                                                sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);\n                                        });\n            }\n            break;\n        }\n        default:\n            GGML_ABORT(\"fatal error\");\n            break;\n    }\n}\n\nvoid ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src_q     = dst->src[0];\n    ggml_tensor * src_k     = dst->src[1];\n    ggml_tensor * src_v     = dst->src[2];\n    ggml_tensor * src_g     = dst->src[3];\n    ggml_tensor * src_beta  = dst->src[4];\n    ggml_tensor * src_state = dst->src[5];\n\n    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);\n    GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);\n    GGML_TENSOR_LOCALS(int64_t, nek, 
src_k, ne);\n    GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);\n    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);\n    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);\n    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);\n\n    const int64_t S_v      = nev0;\n    const int64_t H        = nev1;\n    const int64_t n_tokens = nev2;\n    const int64_t n_seqs   = nev3;\n\n    const bool kda = (src_g->ne[0] == S_v);\n\n    GGML_ASSERT(neq1 == nek1);\n    const int64_t neqk1 = neq1;\n\n    const int64_t rq3 = nev3 / neq3;\n\n    const float * q_d = (const float *) src_q->data;\n    const float * k_d = (const float *) src_k->data;\n    const float * v_d = (const float *) src_v->data;\n    const float * g_d = (const float *) src_g->data;\n    const float * b_d = (const float *) src_beta->data;\n\n    const float * s_d   = (const float *) src_state->data;\n    float *       dst_d = (float *) dst->data;\n\n    GGML_ASSERT(ggml_is_contiguous_rows(src_q));\n    GGML_ASSERT(ggml_is_contiguous_rows(src_k));\n    GGML_ASSERT(ggml_is_contiguous_rows(src_v));\n    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));\n    GGML_ASSERT(src_g->ne[0] == 1 || kda);\n    GGML_ASSERT(ggml_is_contiguous(src_g));\n    GGML_ASSERT(ggml_is_contiguous(src_beta));\n    GGML_ASSERT(ggml_is_contiguous(src_state));\n\n    // strides in floats (beta strides used for both g and beta offset computation)\n    const int64_t sq1 = nbq1 / sizeof(float);\n    const int64_t sq2 = nbq2 / sizeof(float);\n    const int64_t sq3 = nbq3 / sizeof(float);\n    const int64_t sv1 = nbv1 / sizeof(float);\n    const int64_t sv2 = nbv2 / sizeof(float);\n    const int64_t sv3 = nbv3 / sizeof(float);\n    const int64_t sb1 = nbb1 / sizeof(float);\n    const int64_t sb2 = nbb2 / sizeof(float);\n    const int64_t sb3 = nbb3 / sizeof(float);\n\n    const float scale = 1.0f / sqrtf((float) S_v);\n\n    dpct::queue_ptr stream = ctx.stream();\n\n    if (kda) {\n        launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,\n    
        S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,\n            sb1, sb2, sb3, neqk1, rq3, scale, stream);\n    } else {\n        launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,\n            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,\n            sb1, sb2, sb3, neqk1, rq3, scale, stream);\n    }\n}\n\nvoid ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6);\n    ggml_sycl_op_gated_delta_net(ctx, dst);\n}\n"
  },
  {
    "path": "src/ggml-sycl/gated_delta_net.hpp",
    "content": "#pragma once\n\n#include <sycl/sycl.hpp>\n#include \"dpct/helper.hpp\"\n#include \"common.hpp\"\n#include \"ggml.h\"\n\nvoid ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-sycl/gemm.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_GEMM_HPP\n#define GGML_SYCL_GEMM_HPP\n\n#include \"ggml-sycl.h\"\n\n#if GGML_SYCL_DNNL\n\n#include \"dnnl.hpp\"\n#include \"dnnl_sycl.hpp\"\n\nclass DnnlGemmWrapper {\npublic:\n    using dt = dnnl::memory::data_type;\n    using tag = dnnl::memory::format_tag;\n\n    template<typename T>\n    static constexpr dt to_dt() {\n        if constexpr (std::is_same_v<T, float>) return dt::f32;\n        else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;\n        else static_assert(0);\n    }\n\n    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,\n        const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,\n        const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,\n        void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {\n\n        auto stream = ctx.stream_dnnl(q);\n        auto eng = ctx.engine_dnnl(q);\n\n        dnnl::memory::dims a_dims = {batches_a, m, k };\n        dnnl::memory::dims a_strides = {stra2, stra1, stra0};\n        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);\n\n        dnnl::memory::dims b_dims = {batches_b, k, n };\n        dnnl::memory::dims b_strides = {strb2, strb0, strb1};\n        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);\n\n        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};\n        dnnl::memory::dims c_strides = {m*n, 1,  m };\n        const auto c_md    = dnnl::memory::desc(c_dims, ct, c_strides);\n        dnnl::primitive_attr primitive_attr;\n        
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);\n\n#ifdef GGML_SYCL_F16\n        primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);\n#endif\n\n        auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));\n        auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));\n        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);\n        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);\n\n        auto scratchpad_md = matmul_pd.scratchpad_desc();\n        auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);\n\n        auto matmul_prim = dnnl::matmul(matmul_pd);\n\n        std::unordered_map<int, dnnl::memory> matmul_args;\n        matmul_args.insert({ DNNL_ARG_SRC, a_mem });\n        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });\n\n        matmul_args.insert({ DNNL_ARG_DST, c_mem });\n        matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });\n\n        matmul_prim.execute(stream, matmul_args);\n    }\n\n    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,\n        const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {\n\n        gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);\n    }\n};\n\n#endif\n\n#endif // GGML_SYCL_GEMM_HPP\n"
  },
  {
    "path": "src/ggml-sycl/getrows.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#include \"ggml-impl.h\"\n#include \"common.hpp\"\n#include \"dequantize.hpp\"\n#include \"getrows.hpp\"\n\n\ntemplate<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>\nstatic void k_get_rows(\n            const void * src0, const int32_t * src1, dst_t * dst,\n            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/\n            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/\n            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,\n            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,\n            size_t s10, size_t s11, size_t s12,\n            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {\n\n    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +\n                     item_ct1.get_local_id(2)) *\n                    2;\n    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                    item_ct1.get_local_id(1);\n    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +\n                     item_ct1.get_local_id(0)) /\n                    ne12;\n    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +\n                     item_ct1.get_local_id(0)) %\n                    ne12;\n\n    if (i00 >= ne00) {\n        return;\n    }\n\n    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];\n\n    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;\n    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;\n\n    const int ib = i00/qk; // block index\n    const int iqs = (i00%qk)/qr; // quant index\n    const int iybs = i00 - i00%qk; // dst block start index\n    const 
int y_offset = qr == 1 ? 1 : qk/2;\n\n    // dequantize\n    dfloat2 v;\n    dequantize_kernel(src0_row, ib, iqs, v);\n\n    dst_row[iybs + iqs + 0] = v.x();\n    dst_row[iybs + iqs + y_offset] = v.y();\n}\n\ntemplate<typename src0_t, typename dst_t>\nstatic void k_get_rows_float(\n            const src0_t * src0, const int32_t * src1, dst_t * dst,\n            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/\n            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/\n            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,\n            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,\n            size_t s10, size_t s11, size_t s12,\n            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {\n\n    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +\n                    item_ct1.get_local_id(2);\n    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                    item_ct1.get_local_id(1);\n    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +\n                     item_ct1.get_local_id(0)) /\n                    ne12;\n    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +\n                     item_ct1.get_local_id(0)) %\n                    ne12;\n\n    if (i00 >= ne00) {\n        return;\n    }\n\n    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];\n\n    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;\n    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);\n\n    dst_row[i00] = src0_row[i00];\n}\n\ntemplate <int qk, int qr, dequantize_kernel_t dq>\nstatic void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,\n                          ggml_tensor *dst, const void *src0_dd,\n                          const int32_t *src1_dd, float *dst_dd,\n                          queue_ptr stream) {\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const 
sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);\n    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);\n    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);\n\n    // strides in elements\n    //const size_t s0 = nb0 / ggml_element_size(dst);\n    const size_t s1 = nb1 / ggml_element_size(dst);\n    const size_t s2 = nb2 / ggml_element_size(dst);\n    const size_t s3 = nb3 / ggml_element_size(dst);\n\n    const size_t s10 = nb10 / ggml_element_size(src1);\n    const size_t s11 = nb11 / ggml_element_size(src1);\n    const size_t s12 = nb12 / ggml_element_size(src1);\n    //const size_t s13 = nb13 / ggml_element_size(src1);\n\n    GGML_ASSERT(ne00 % 2 == 0);\n\n    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             k_get_rows<qk, qr, dq>(\n                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,\n                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);\n                         });\n\n    GGML_UNUSED(dst);\n    GGML_UNUSED(ctx);\n}\n\ntemplate <typename src0_t>\nstatic void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,\n                                const ggml_tensor *src1, ggml_tensor *dst,\n                                const src0_t *src0_dd, const int32_t *src1_dd,\n                                float *dst_dd, queue_ptr stream) {\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);\n    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;\n    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);\n\n    // strides in elements\n    //const size_t s0 = nb0 / ggml_element_size(dst);\n    const size_t s1 = nb1 / ggml_element_size(dst);\n    const size_t s2 = nb2 / ggml_element_size(dst);\n    const 
size_t s3 = nb3 / ggml_element_size(dst);\n\n    const size_t s10 = nb10 / ggml_element_size(src1);\n    const size_t s11 = nb11 / ggml_element_size(src1);\n    const size_t s12 = nb12 / ggml_element_size(src1);\n    //const size_t s13 = nb13 / ggml_element_size(src1);\n\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,\n                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);\n            });\n    }\n\n    GGML_UNUSED(dst);\n    GGML_UNUSED(ctx);\n}\n\nvoid ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type));\n    GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type));\n    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));\n\n    const int32_t * src1_i32 = (const int32_t *) dst->src[1]->data;\n    /* TODO: Refactor and remove duplicates */\n    switch (dst->src[0]->type) {\n        case GGML_TYPE_F16:\n            get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data,\n                                src1_i32, (float *)dst->data, ctx.stream());\n            break;\n        case GGML_TYPE_F32:\n            get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,\n            src1_i32, (float *)dst->data, ctx.stream());\n            break;\n        case GGML_TYPE_Q4_0:\n            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,\n            src1_i32, (float *)dst->data, ctx.stream());\n     
       break;\n        case GGML_TYPE_Q4_1:\n            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,\n            src1_i32, (float *)dst->data, ctx.stream());\n            break;\n        case GGML_TYPE_Q5_0:\n            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,\n            src1_i32, (float *)dst->data, ctx.stream());\n            break;\n        case GGML_TYPE_Q5_1:\n            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,\n            src1_i32, (float *)dst->data, ctx.stream());\n            break;\n        case GGML_TYPE_Q8_0:\n            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,\n            src1_i32, (float *)dst->data, ctx.stream());\n            break;\n        default:\n            // TODO: k-quants\n            GGML_LOG_ERROR(\"%s: unsupported type: %s\\n\", __func__, ggml_type_name(dst->src[0]->type));\n            GGML_ABORT(\"fatal error\");\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/getrows.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_GETROWS_HPP\n#define GGML_SYCL_GETROWS_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst);\n\n#endif // GGML_SYCL_GETROWS_HPP\n"
  },
  {
    "path": "src/ggml-sycl/ggml-sycl.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#include <algorithm>\n#include <assert.h>\n#include <atomic>\n#include <cinttypes>\n#include <cstddef>\n#include <cstdint>\n#include <cstdlib>\n#include <float.h>\n#include <limits>\n#include <stdint.h>\n#include <stdio.h>\n#include <vector>\n#include <cmath>\n#include <iostream>\n#include <fstream>\n#include <stdio.h>\n#include <stdlib.h>\n#include <regex>\n\n#include <sycl/sycl.hpp>\n#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC\n#    include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>\n#endif\n#include <sycl/half_type.hpp>\n\n#include \"ggml.h\"\n#include \"ggml-sycl.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-sycl/add-id.hpp\"\n#include \"ggml-sycl/backend.hpp\"\n#include \"ggml-sycl/common.hpp\"\n#include \"ggml-sycl/element_wise.hpp\"\n#include \"ggml-sycl/gated_delta_net.hpp\"\n#include \"ggml-sycl/gemm.hpp\"\n#include \"ggml-sycl/getrows.hpp\"\n#include \"ggml-sycl/norm.hpp\"\n#include \"ggml-sycl/presets.hpp\"\n#include \"ggml-sycl/quantize.hpp\"\n#include \"ggml-sycl/repeat_back.hpp\"\n#include \"ggml-sycl/set_rows.hpp\"\n#include \"ggml-sycl/set.hpp\"\n#include \"ggml-sycl/ssm_conv.hpp\"\n#include \"ggml-sycl/sycl_hw.hpp\"\n\n\nstatic bool g_sycl_loaded = false;\nint g_ggml_sycl_debug = 0;\nint g_ggml_sycl_disable_optimize = 0;\nint g_ggml_sycl_disable_graph = 0;\nint g_ggml_sycl_disable_dnn = 0;\nint g_ggml_sycl_prioritize_dmmv = 0;\nint g_ggml_sycl_use_async_mem_op = 0;\nint g_ggml_sycl_enable_flash_attention = 1;\n\n\nstatic ggml_sycl_device_info ggml_sycl_init() {\n    ggml_sycl_device_info info = {};\n\n    info.device_count = 
dpct::dev_mgr::instance().device_count();\n    if (info.device_count == 0) {\n        GGML_LOG_ERROR(\"%s: failed to initialize: %s\\n\", GGML_SYCL_NAME, __func__);\n        return info;\n    }\n\n    GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);\n\n    int64_t total_vram = 0;\n/* This is a bit misleading;  reserved for later */\n// #if defined(SYCL_USE_XMX)\n//     GGML_LOG_INFO(\"%s: SYCL_USE_XMX: yes\\n\", __func__);\n// #else\n//     GGML_LOG_INFO(\"%s: SYCL_USE_XMX: no\\n\", __func__);\n// #endif\n    for (int i = 0; i < info.device_count; ++i) {\n        info.devices[i].vmm = 0;\n        dpct::device_info prop;\n        sycl::device device = dpct::dev_mgr::instance().get_device(i);\n\n        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(\n            prop, device)));\n\n        info.default_tensor_split[i] = total_vram;\n        total_vram += prop.get_global_mem_size();\n\n        info.devices[i].cc =\n            100 * prop.get_major_version() + 10 * prop.get_minor_version();\n        info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores\n        info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);\n        info.devices[i].smpbo = prop.get_local_mem_size();\n        info.devices[i].warp_size = WARP_SIZE;\n\n        info.max_work_group_sizes[i] = prop.get_max_work_group_size();\n        info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();\n\n    }\n\n    for (int id = 0; id < info.device_count; ++id) {\n        info.default_tensor_split[id] /= total_vram;\n    }\n    return info;\n}\n\nconst ggml_sycl_device_info & ggml_sycl_info() {\n    static ggml_sycl_device_info info = ggml_sycl_init();\n    return info;\n}\n\nstatic void print_device_detail(int id, sycl::device &device, std::string device_type) {\n\n    dpct::device_info prop;\n    SYCL_CHECK(CHECK_TRY_ERROR(\n        dpct::get_device_info(prop, device)));\n\n    
std::string version;\n    version += std::to_string(prop.get_major_version());\n    version += \".\";\n    version += std::to_string(prop.get_minor_version());\n\n    device_type = std::regex_replace(device_type, std::regex(\"ext_oneapi_\"), \"\");\n    std::string name = std::string(prop.get_name());\n    name = std::regex_replace(name, std::regex(\"\\\\(R\\\\)\"), \"\");\n    name = std::regex_replace(name, std::regex(\"\\\\(TM\\\\)\"), \"\");\n\n    auto global_mem_size = prop.get_global_mem_size()/1000000;\n    GGML_LOG_INFO(\"|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\\n\", id, device_type.c_str(),\n            name.c_str(), version.c_str(), prop.get_max_compute_units(),\n            prop.get_max_work_group_size(), prop.get_max_sub_group_size(),\n            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());\n}\n\nstatic void print_device_opt_feature(int device_count) {\n    GGML_LOG_INFO(\"SYCL Optimization Feature:\\n\");\n    GGML_LOG_INFO(\n        \"|ID|        Device Type|Reorder|\\n\");\n    GGML_LOG_INFO(\n        \"|--|-------------------|-------|\\n\");\n    std::map<std::string, size_t> DeviceNums;\n    for (int id = 0; id < device_count; ++id) {\n      sycl::device device = dpct::dev_mgr::instance().get_device(id);\n      std::string backend_type = get_device_backend_and_type(device);\n      int type_id = DeviceNums[backend_type]++;\n      std::stringstream device_type;\n      device_type << \"[\" << backend_type << \":\" << std::to_string(type_id)\n                  << \"]\";\n      std::string device_type_s = device_type.str();\n      device_type_s = std::regex_replace(device_type_s, std::regex(\"ext_oneapi_\"), \"\");\n      GGML_LOG_INFO(\"|%2d|%19s|%7s|\\n\", id, device_type_s.c_str(),\n        ggml_sycl_info().devices[id].opt_feature.reorder ? 
\"Y\": \"N\");\n    }\n\n}\nvoid ggml_backend_sycl_print_sycl_devices() {\n    GGML_SYCL_DEBUG(\"[SYCL] call ggml_backend_sycl_print_sycl_devices\\n\");\n    int device_count = dpct::dev_mgr::instance().device_count();\n    std::map<std::string, size_t> DeviceNums;\n    GGML_LOG_INFO(\"Found %d SYCL devices:\\n\", device_count);\n\n    GGML_LOG_INFO(\n        \"|  |                   |                                       |      \"\n        \" |Max    |        |Max  |Global |                     |\\n\");\n    GGML_LOG_INFO(\n        \"|  |                   |                                       |      \"\n        \" |compute|Max work|sub  |mem    |                     |\\n\");\n    GGML_LOG_INFO(\n        \"|ID|        Device Type|                                   \"\n        \"Name|Version|units  |group   |group|size   |       Driver version|\\n\");\n    GGML_LOG_INFO(\n        \"|--|-------------------|---------------------------------------|------\"\n        \"-|-------|--------|-----|-------|---------------------|\\n\");\n\n    for (int id = 0; id < device_count; ++id) {\n      sycl::device device = dpct::dev_mgr::instance().get_device(id);\n      std::string backend_type = get_device_backend_and_type(device);\n      int type_id = DeviceNums[backend_type]++;\n      std::stringstream device_type;\n      device_type << \"[\" << backend_type << \":\" << std::to_string(type_id)\n                  << \"]\";\n      print_device_detail(id, device, device_type.str());\n    }\n\n    print_device_opt_feature(device_count);\n}\n\nstatic inline int get_sycl_env(const char *env_name, int default_val) {\n    char *user_device_string = getenv(env_name);\n    int user_number = default_val;\n\n    unsigned n;\n    if (user_device_string != NULL &&\n        sscanf(user_device_string, \" %u\", &n) == 1) {\n        user_number = (int)n;\n    } else {\n        user_number = default_val;\n    }\n    return user_number;\n}\n\nstatic void ggml_check_sycl() try {\n    static bool 
initialized = false;\n\n    if (!initialized) {\n        g_ggml_sycl_debug = get_sycl_env(\"GGML_SYCL_DEBUG\", 0);\n        g_ggml_sycl_disable_optimize = get_sycl_env(\"GGML_SYCL_DISABLE_OPT\", 0);\n        g_ggml_sycl_disable_graph = get_sycl_env(\"GGML_SYCL_DISABLE_GRAPH\", 1);\n        g_ggml_sycl_disable_dnn = get_sycl_env(\"GGML_SYCL_DISABLE_DNN\", 0);\n        g_ggml_sycl_prioritize_dmmv = get_sycl_env(\"GGML_SYCL_PRIORITIZE_DMMV\", 0);\n\n#ifdef SYCL_FLASH_ATTN\n        g_ggml_sycl_enable_flash_attention = get_sycl_env(\"GGML_SYCL_ENABLE_FLASH_ATTN\", 1);\n#else\n        g_ggml_sycl_enable_flash_attention = 0;\n#endif\n\n        GGML_SYCL_DEBUG(\"[SYCL] call ggml_check_sycl\\n\");\n\n        GGML_LOG_INFO(\"Build with Macros:\\n\");\n#if defined(GGML_SYCL_FORCE_MMQ)\n        GGML_LOG_INFO(\"  GGML_SYCL_FORCE_MMQ: yes\\n\");\n#else\n        GGML_LOG_INFO(\"  GGML_SYCL_FORCE_MMQ: no\\n\");\n#endif\n#if defined(GGML_SYCL_F16)\n        GGML_LOG_INFO(\"  GGML_SYCL_F16: yes\\n\");\n#else\n        GGML_LOG_INFO(\"  GGML_SYCL_F16: no\\n\");\n#endif\n#if defined(GGML_SYCL_GRAPH)\n        GGML_LOG_INFO(\"  GGML_SYCL_GRAPH: yes\\n\");\n#else\n        GGML_LOG_INFO(\"  GGML_SYCL_GRAPH: no\\n\");\n#endif\n#if defined(GGML_SYCL_DNNL)\n        GGML_LOG_INFO(\"  GGML_SYCL_DNNL: yes\\n\");\n#else\n        GGML_LOG_INFO(\"  GGML_SYCL_DNNL: no\\n\");\n#endif\n\n        GGML_LOG_INFO(\"Running with Environment Variables:\\n\");\n        GGML_LOG_INFO(\"  GGML_SYCL_DEBUG: %d\\n\", g_ggml_sycl_debug);\n        GGML_LOG_INFO(\"  GGML_SYCL_DISABLE_OPT: %d\\n\", g_ggml_sycl_disable_optimize);\n#ifdef GGML_SYCL_GRAPH\n        GGML_LOG_INFO(\"  GGML_SYCL_DISABLE_GRAPH: %d\\n\", g_ggml_sycl_disable_graph);\n#else\n        GGML_LOG_INFO(\"  GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\\n\");\n#endif\n#if GGML_SYCL_DNNL\n        GGML_LOG_INFO(\"  GGML_SYCL_DISABLE_DNN: %d\\n\", g_ggml_sycl_disable_dnn);\n#else\n        GGML_LOG_INFO(\"  GGML_SYCL_DISABLE_DNN: DNN disabled by 
compile flag\\n\");\n#endif\n        GGML_LOG_INFO(\"  GGML_SYCL_PRIORITIZE_DMMV: %d\\n\", g_ggml_sycl_prioritize_dmmv);\n\n#ifdef SYCL_FLASH_ATTN\n        GGML_LOG_INFO(\"  GGML_SYCL_ENABLE_FLASH_ATTN: %d\\n\", g_ggml_sycl_enable_flash_attention);\n#else\n        GGML_LOG_INFO(\"  GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\\n\",\n            g_ggml_sycl_enable_flash_attention);\n#endif\n\n/* NOT REMOVE, keep it for next optimize for XMX.\n#if defined(SYCL_USE_XMX)\n        fprintf(stderr, \"%s: SYCL_USE_XMX: yes\\n\", __func__);\n#else\n        fprintf(stderr, \"%s: SYCL_USE_XMX: no\\n\", __func__);\n#endif\n*/\n        // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be\n        // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in\n        // other places.\n#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC\n        g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;\n        if (g_ggml_sycl_use_async_mem_op) {\n            for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {\n                if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {\n                    g_ggml_sycl_use_async_mem_op = 0;\n                    break;\n                }\n            }\n        }\n#endif\n        if (CHECK_TRY_ERROR(g_all_sycl_device_count =\n                            dpct::dev_mgr::instance().device_count()) != 0) {\n            initialized = true;\n            g_sycl_loaded = false;\n            return;\n        }\n        GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);\n\n        initialized = true;\n        g_sycl_loaded = true;\n        ggml_backend_sycl_print_sycl_devices();\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << 
std::endl;\n  std::exit(1);\n}\n\n/*\ndevice_index: device index from 0 to n (continue numbers).\n    It is used for device select/set in SYCL backend internal data structure.\n*/\ninline void check_allow_gpu_index(const int device_index) {\n  if (device_index >= ggml_sycl_info().device_count) {\n    char error_buf[256];\n    snprintf(\n        error_buf,\n        sizeof(error_buf),\n        \"%s error: device_index:%d is out of range: [0-%d]\",\n        __func__,\n        device_index,\n        ggml_sycl_info().device_count - 1);\n    GGML_LOG_ERROR(\"%s\\n\", error_buf);\n    assert(false);\n  }\n}\n\nGGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call ggml_backend_sycl_get_gpu_list\\n\");\n    for(int i=0;i<max_len;i++) id_list[i] = -1;\n\n    for (int i=0;i< ggml_sycl_info().device_count;i++){\n        if (i>=max_len) break;\n        id_list[i] = i;\n    }\n    return;\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\n// sycl buffer\n\nstruct ggml_backend_sycl_buffer_context {\n    int device;\n    void * dev_ptr = nullptr;\n    queue_ptr stream;\n    std::string name;\n    optimize_feature opt_feature;\n    std::vector<ggml_tensor_extra_gpu *> tensor_extras;\n\n    ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :\n        device(device), dev_ptr(dev_ptr), stream(stream) {\n            check_allow_gpu_index(device);\n            name = (GGML_SYCL_NAME + std::to_string(device));\n            opt_feature = ggml_sycl_info().devices[device].opt_feature;\n        }\n\n    ~ggml_backend_sycl_buffer_context() {\n        if (dev_ptr != nullptr) {\n            ggml_sycl_set_device(device);\n            SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));\n        }\n\n        //release extra used by tensors\n        for 
(ggml_tensor_extra_gpu * extra : tensor_extras) {\n            release_extra_gpu(extra);\n        }\n\n    }\n};\n\nstatic const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft);\n\nstatic bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {\n    return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name;\n}\n\nstatic void\nggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {\n    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;\n    ggml_sycl_set_device(ctx->device);\n\n    delete ctx;\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {\n    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;\n    return ctx->dev_ptr;\n}\n\nstatic enum ggml_status\nggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,\n                                     ggml_tensor *tensor) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor, \"\\n\").c_str());\n    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;\n\n    if (tensor->view_src != NULL) {\n        assert(tensor->view_src->buffer->buft == buffer->buft);\n        return GGML_STATUS_SUCCESS;\n    }\n    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&\n        !g_ggml_sycl_disable_optimize) {\n        ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};\n        tensor->extra                 = extra;\n        ctx->tensor_extras.push_back(extra);  //used to release it when destroy ctx.\n    }\n\n    if (ggml_is_quantized(tensor->type)) {\n        // 
initialize padding to 0 to avoid possible NaN values\n        size_t original_size = ggml_nbytes(tensor);\n        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);\n\n        if (padded_size > original_size && tensor->view_src == nullptr) {\n            SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(\n                (char *)tensor->data + original_size, 0,\n                padded_size - original_size).wait()));\n        }\n    }\n    return GGML_STATUS_SUCCESS;\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,\n                                                ggml_tensor *tensor,\n                                                const void *data, size_t offset,\n                                                size_t size) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor).c_str());\n    GGML_SYCL_DEBUG(\" size=%zu offset=%zu\\n\", size, offset);\n    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;\n    ggml_sycl_set_device(ctx->device);\n    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());\n    SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));\n#ifndef _WIN32\n    // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU.\n    // This function will be called during load model from disk. 
Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here.\n    char * host_buf = (char *) malloc(size);\n    memcpy(host_buf, data, size);\n    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));\n    free(host_buf);\n#else\n    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));\n#endif\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,\n                                                const ggml_tensor *tensor,\n                                                void *data, size_t offset,\n                                                size_t size) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor).c_str());\n    GGML_SYCL_DEBUG(\" size=%zu offset=%zu\\n\", size, offset);\n    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;\n\n    ggml_sycl_set_device(ctx->device);\n    auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();\n\n    SYCL_CHECK(CHECK_TRY_ERROR(\n        stream.memcpy(data, (const char *)tensor->data + offset, size)\n            .wait()));\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,\n                    const void *ptr_src, size_t size) {\n    char *host_buf = (char *)malloc(size);\n    q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();\n    q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();\n    
free(host_buf);\n}\n\nstatic bool\nggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,\n                                    const ggml_tensor *src,\n                                    ggml_tensor *dst) try {\n    bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer);\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": dst\", dst).c_str());\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\" src\", src).c_str());\n    GGML_SYCL_DEBUG(\" is_cpy_supported=%d\\n\", is_cpy_supported);\n    if (is_cpy_supported) {\n        ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;\n        ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;\n\n        ggml_sycl_set_device(src_ctx->device);\n        /*\n        DPCT1009:198: SYCL uses exceptions to report errors and does not use the\n        error codes. The original code was commented out and a warning string\n        was inserted. You need to rewrite this code.\n        */\n        SYCL_CHECK(CHECK_TRY_ERROR(\n            dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));\n        ggml_sycl_set_device(dst_ctx->device);\n        /*\n        DPCT1009:199: SYCL uses exceptions to report errors and does not use the\n        error codes. The original code was commented out and a warning string\n        was inserted. You need to rewrite this code.\n        */\n        SYCL_CHECK(CHECK_TRY_ERROR(\n            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));\n        /*\n        DPCT1009:200: SYCL uses exceptions to report errors and does not use the\n        error codes. The original code was commented out and a warning string\n        was inserted. 
You need to rewrite this code.\n        */\n\n        queue_ptr stream_dst = dst_ctx->stream;\n        queue_ptr stream_src = src_ctx->stream;\n        size_t size = ggml_nbytes(src);\n\n        //todo. it's dirty solutino to walkaroud known issue:device2device cross GPUs.\n        dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);\n\n//todo, it's known issue：error in device2device cross GPUs. reused when the issue is fixed. DON\"T remove\n#if 0\n        SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(\n            (char *)dst->data, (const char *)src->data, size).wait()));\n\n        /*\n        DPCT1009:201: SYCL uses exceptions to report errors and does not use the\n        error codes. The original code was commented out and a warning string\n        was inserted. You need to rewrite this code.\n        */\n        SYCL_CHECK(CHECK_TRY_ERROR(\n            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));\n#endif\n        return true;\n    }\n    return false;\n    GGML_UNUSED(buffer);\n} catch (const sycl::exception & exc) {\n    std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__ << \", line:\" << __LINE__ << std::endl;\n    std::exit(1);\n}\n\nstatic void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,\n                                           uint8_t value) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s: size=%zu\\n\", __func__, buffer->size);\n    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;\n\n    ggml_sycl_set_device(ctx->device);\n    queue_ptr stream = ctx->stream;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));\n\n    SYCL_CHECK(CHECK_TRY_ERROR((*stream)\n                                    .memset(ctx->dev_ptr, value, buffer->size)\n                                    .wait()));\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << 
__FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,\n                                                   size_t offset, size_t size) {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor).c_str());\n    GGML_SYCL_DEBUG(\" size=%zu offset=%zu value=%u\\n\", size, offset, value);\n    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;\n    SYCL_CHECK(ggml_sycl_set_device(ctx->device));\n    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());\n    if (size == 0) {\n        return;  // Nothing to do\n    }\n    if (tensor->data == nullptr) {\n        GGML_ABORT(\"Error: Tensor data pointer is null.\\n\");\n    }\n    void * target_ptr = static_cast<char *>(tensor->data) + offset;\n    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));\n    SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));\n}\n\nstatic void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\\n\", __func__);\n    if (buffer == nullptr) {\n        return;\n    }\n\n    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;\n\n    if (ctx != nullptr) {\n        for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {\n            release_extra_gpu(extra);\n        }\n        ctx->tensor_extras.clear();  // reset the tensor_extras vector\n    }\n}\n\nstatic const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_sycl_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_sycl_buffer_init_tensor,\n    /* .memset_tensor   = */ ggml_backend_sycl_buffer_memset_tensor,\n    /* 
.set_tensor      = */ ggml_backend_sycl_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_sycl_buffer_clear,\n    /* .reset           = */ ggml_backend_sycl_buffer_reset,\n};\n\n// sycl buffer type\nstruct ggml_backend_sycl_buffer_type_context {\n    int device;\n    std::string name;\n\n    // each buffer type has its own stream\n    queue_ptr stream = nullptr;\n};\n\nstatic const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;\n\n    return ctx->name.c_str();\n}\n\nstatic ggml_backend_buffer_t\nggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,\n                                           size_t size) try {\n    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;\n    ggml_sycl_set_device(buft_ctx->device);\n    const queue_ptr stream = buft_ctx->stream;\n    size = std::max(size, (size_t)1); // syclMalloc returns null for size 0\n\n    void * dev_ptr;\n    SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(\n                                    size, *stream)));\n    if (!dev_ptr) {\n      GGML_LOG_ERROR(\"%s: can't allocate %lu Bytes of memory on device\\n\", __func__, size);\n      return nullptr;\n    }\n    ggml_backend_sycl_buffer_context * ctx = new  ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);\n    return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic size_t 
ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 128;\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {\n    return dpct::get_current_device().get_max_mem_alloc_size();\n\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    size_t size = ggml_nbytes(tensor);\n    int64_t ne0 = tensor->ne[0];\n\n    if (ggml_is_quantized(tensor->type)) {\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n    }\n\n    return size;\n\n    GGML_UNUSED(buft);\n}\n\nstatic const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_sycl_buffer_type_get_name,\n    /* .alloc_buffer     = */ ggml_backend_sycl_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_sycl_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_sycl_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_sycl_buffer_type_get_alloc_size,\n    /* .is_host          = */ NULL,\n};\n\nggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n\n    auto dev_count = ggml_backend_sycl_get_device_count();\n\n    if (device>=dev_count or device<0) {\n        GGML_LOG_ERROR(\"ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\\n\",\n            device, dev_count-1);\n        GGML_ASSERT(device<dev_count);\n    }\n    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];\n\n    static bool ggml_backend_sycl_buffer_type_initialized = false;\n\n    if (!ggml_backend_sycl_buffer_type_initialized) {\n   
     for (int i = 0; i < dev_count; i++) {\n            auto & device_i = dpct::dev_mgr::instance().get_device(i);\n            queue_ptr stream = &(device_i.default_queue());\n            ggml_backend_sycl_buffer_types[i] = {\n                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,\n                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), i),\n                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},\n            };\n        }\n        ggml_backend_sycl_buffer_type_initialized = true;\n    }\n    return &ggml_backend_sycl_buffer_types[device];\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {\n    GGML_SYCL_DEBUG(\"[SYCL] call ggml_backend_sycl_buffer_type\\n\");\n\n    int device = ctx->device;\n    if (device>=ggml_sycl_info().device_count or device<0) {\n        GGML_LOG_ERROR(\"ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\\n\",\n            device, ggml_sycl_info().device_count-1);\n        GGML_ASSERT(device<ggml_sycl_info().device_count);\n    }\n    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];\n\n    static bool ggml_backend_sycl_buffer_type_initialized = false;\n\n    if (!ggml_backend_sycl_buffer_type_initialized) {\n        for (int i = 0; i < ggml_sycl_info().device_count; i++) {\n            ggml_backend_sycl_buffer_types[i] = {\n                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,\n                /* .device   = */ nullptr,\n                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},\n            };\n        }\n        ggml_backend_sycl_buffer_type_initialized = true;\n    }\n    return &ggml_backend_sycl_buffer_types[device];\n}\n\n// sycl split 
buffer\n\nstatic int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {\n    int64_t min_compute_capability = INT_MAX;\n    int64_t max_compute_capability = INT_MIN;\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {\n            if (min_compute_capability > ggml_sycl_info().devices[i].cc) {\n                min_compute_capability = ggml_sycl_info().devices[i].cc;\n            }\n            if (max_compute_capability < ggml_sycl_info().devices[i].cc) {\n                max_compute_capability = ggml_sycl_info().devices[i].cc;\n            }\n        }\n    }\n\n    switch(type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n            return max_compute_capability >= VER_GEN9 ? 128 : 64;\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n            return 64;\n        case GGML_TYPE_F16:\n        case GGML_TYPE_F32:\n            return 1;\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ4_NL:\n            return max_compute_capability >= VER_GEN9 ? 128 : 64;\n        case GGML_TYPE_IQ3_S:\n            return max_compute_capability >= VER_GEN9 ? 
128 : 64;\n        case GGML_TYPE_Q6_K:\n            return 64;\n        default:\n            GGML_ABORT(\"fatal error\");\n    }\n}\n\nstatic void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {\n    const int64_t nrows = ggml_nrows(tensor);\n    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);\n\n    *row_low = id == 0 ? 0 : nrows*tensor_split[id];\n    *row_low -= *row_low % rounding;\n    if (id == ggml_sycl_info().device_count - 1) {\n        *row_high = nrows;\n    } else {\n        *row_high = nrows*tensor_split[id + 1];\n        *row_high -= *row_high % rounding;\n    }\n}\n\nstatic size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);\n}\n\nstruct ggml_backend_sycl_split_buffer_type_context {\n    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;\n};\n\nstruct ggml_backend_sycl_split_buffer_context {\n    ~ggml_backend_sycl_split_buffer_context() try {\n        for (ggml_tensor_extra_gpu * extra : tensor_extras) {\n            release_extra_gpu(extra, streams);\n        }\n    }\n    catch (sycl::exception const &exc) {\n      std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n                << \", line:\" << __LINE__ << std::endl;\n      std::exit(1);\n    }\n\n    std::vector<ggml_tensor_extra_gpu *> tensor_extras;\n    std::vector<queue_ptr> streams;\n};\n\nstatic void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;\n    delete ctx;\n}\n\nstatic void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {\n    // the pointers are stored in the tensor extras, this is 
just a dummy address and never dereferenced\n    return (void *)0x1000;\n\n    GGML_UNUSED(buffer);\n}\n\nstatic enum ggml_status\nggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,\n                                           ggml_tensor *tensor) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor, \"\\n\").c_str());\n    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported\n\n    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;\n    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;\n\n    const int64_t ne0 = tensor->ne[0];\n\n    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};\n\n    ctx->tensor_extras.push_back(extra);\n    ctx->streams.push_back(&(dpct::get_current_device().default_queue()));\n\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        int64_t row_low, row_high;\n        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);\n\n        int64_t nrows_split = row_high - row_low;\n        if (nrows_split == 0) {\n            continue;\n        }\n\n        size_t size = ggml_nbytes_split(tensor, nrows_split);\n        const size_t original_size = size;\n\n        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n\n        // FIXME: do not crash if SYCL Buffer alloc fails\n        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first\n        ggml_sycl_set_device(i);\n        const queue_ptr stream = ctx->streams[i];\n        char * buf;\n        /*\n        DPCT1009:208: SYCL uses exceptions to report errors and does not use the\n   
     error codes. The original code was commented out and a warning string\n        was inserted. You need to rewrite this code.\n        */\n        SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(\n                                        size, *stream)));\n        if (!buf) {\n            char err_buf[1024];\n            snprintf(err_buf, 1023, \"%s: can't allocate %lu Bytes of memory on device\\n\", __func__, size);\n            throw std::runtime_error(err_buf);\n        }\n        // set padding to 0 to avoid possible NaN values\n        if (size > original_size) {\n            /*\n            DPCT1009:209: SYCL uses exceptions to report errors and does not use\n            the error codes. The original code was commented out and a warning\n            string was inserted. You need to rewrite this code.\n            */\n            SYCL_CHECK(CHECK_TRY_ERROR(\n                (*stream)\n                    .memset(buf + original_size, 0, size - original_size)\n                    .wait()));\n        }\n\n        extra->data_device[i] = buf;\n\n        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {\n            /*\n            DPCT1009:210: SYCL uses exceptions to report errors and does not use\n            the error codes. The original code was commented out and a warning\n            string was inserted. 
You need to rewrite this code.\n            */\n            SYCL_CHECK(\n                CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));\n        }\n    }\n    tensor->extra = extra;\n    return GGML_STATUS_SUCCESS;\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void\nggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,\n                                          ggml_tensor *tensor, const void *data,\n                                          size_t offset, size_t size) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor).c_str());\n    GGML_SYCL_DEBUG(\" size=%zu offset=%zu\\n\", size, offset);\n    // split tensors must always be set in their entirety at once\n    GGML_ASSERT(offset == 0);\n    GGML_ASSERT(size == ggml_nbytes(tensor));\n\n    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;\n    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;\n\n    const int64_t ne0 = tensor->ne[0];\n    const size_t nb1 = tensor->nb[1];\n    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;\n\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        int64_t row_low, row_high;\n        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);\n\n        int64_t nrows_split = row_high - row_low;\n        if (nrows_split == 0) {\n            continue;\n        }\n\n        const size_t offset_split = row_low*nb1;\n        size_t size = ggml_nbytes_split(tensor, nrows_split);\n        const size_t original_size = size;\n\n        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses\n        if (ne0 % 
MATRIX_ROW_PADDING != 0) {\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n\n        const char * buf_host = (const char *)data + offset_split;\n        /*\n        DPCT1009:211: SYCL uses exceptions to report errors and does not use the\n        error codes. The original code was commented out and a warning string\n        was inserted. You need to rewrite this code.\n        */\n        ggml_sycl_set_device(i);\n        const queue_ptr stream = ctx->streams[i];\n        SYCL_CHECK(CHECK_TRY_ERROR(\n            (*stream)\n                .memcpy(extra->data_device[i], buf_host, original_size)\n                .wait()));\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void\nggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,\n                                          const ggml_tensor *tensor, void *data,\n                                          size_t offset, size_t size) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor).c_str());\n    GGML_SYCL_DEBUG(\" size=%zu offset=%zu\\n\", size, offset);\n    // split tensors must always be set in their entirety at once\n    GGML_ASSERT(offset == 0);\n    GGML_ASSERT(size == ggml_nbytes(tensor));\n\n    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;\n    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;\n\n    const int64_t ne0 = tensor->ne[0];\n    const size_t nb1 = tensor->nb[1];\n    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;\n\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        int64_t row_low, row_high;\n        
get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);\n\n        int64_t nrows_split = row_high - row_low;\n        if (nrows_split == 0) {\n            continue;\n        }\n\n        const size_t offset_split = row_low*nb1;\n        size_t size = ggml_nbytes_split(tensor, nrows_split);\n        const size_t original_size = size;\n\n        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n\n        char * buf_host = (char *)data + offset_split;\n        /*\n        DPCT1009:212: SYCL uses exceptions to report errors and does not use the\n        error codes. The original code was commented out and a warning string\n        was inserted. You need to rewrite this code.\n        */\n        ggml_sycl_set_device(i);\n        const queue_ptr stream = ctx->streams[i];\n        SYCL_CHECK(CHECK_TRY_ERROR(\n            (*stream)\n                .memcpy(buf_host, extra->data_device[i], original_size)\n                .wait()));\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    GGML_UNUSED(buffer);\n    GGML_UNUSED(value);\n}\n\nstatic struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_sycl_split_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_sycl_split_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_sycl_split_buffer_init_tensor,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ ggml_backend_sycl_split_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_sycl_split_buffer_get_tensor,\n    /* .cpy_tensor      
= */ NULL,\n    /* .clear           = */ ggml_backend_sycl_split_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\n// sycl split buffer type\n\nstatic const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    return GGML_SYCL_NAME \"_Split\";\n\n    GGML_UNUSED(buft);\n}\n\nstatic bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {\n   return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name;\n}\n\nstatic ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point\n    // instead, we allocate them for each tensor separately in init_tensor\n    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,\n    // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.\n    ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();\n\n    return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);\n}\n\nstatic size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 128;\n    GGML_UNUSED(buft);\n}\n\nstatic size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;\n\n    size_t total_size = 0;\n\n    const int64_t ne0 = tensor->ne[0];\n\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        int64_t row_low, row_high;\n        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);\n\n        int64_t nrows_split = row_high - row_low;\n        if (nrows_split == 0) {\n     
       continue;\n        }\n\n        total_size += ggml_nbytes_split(tensor, nrows_split);\n\n        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses\n        if (ne0 % MATRIX_ROW_PADDING != 0) {\n            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);\n        }\n    }\n\n    return total_size;\n}\n\nstatic bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {\n    return false;\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_sycl_split_buffer_type_get_name,\n    /* .alloc_buffer     = */ ggml_backend_sycl_split_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_sycl_split_buffer_type_get_alignment,\n    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX\n    /* .get_alloc_size   = */ ggml_backend_sycl_split_buffer_type_get_alloc_size,\n    /* .is_host          = */ ggml_backend_sycl_split_buffer_type_is_host,\n};\n\nggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {\n    static std::mutex mutex;\n    std::lock_guard<std::mutex> lock(mutex);\n\n    GGML_SYCL_DEBUG(\"[SYCL] call ggml_backend_sycl_split_buffer_type\\n\");\n    ggml_check_sycl();\n    // FIXME: this is not thread safe\n    static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;\n\n    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};\n\n    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });\n    if (all_zero) {\n        tensor_split_arr = ggml_sycl_info().default_tensor_split;\n    } else {\n        float split_sum = 0.0f;\n        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n            tensor_split_arr[i] = split_sum;\n            split_sum += 
tensor_split[i];\n        }\n        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n            tensor_split_arr[i] /= split_sum;\n        }\n    }\n\n    auto it = buft_map.find(tensor_split_arr);\n    if (it != buft_map.end()) {\n        return &it->second;\n    }\n\n    struct ggml_backend_buffer_type buft {\n        /* .iface   = */ ggml_backend_sycl_split_buffer_type_interface,\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),\n        /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},\n    };\n\n    auto result = buft_map.emplace(tensor_split_arr, buft);\n    return &result.first->second;\n}\n\n// host buffer type\n\nstatic const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {\n    return GGML_SYCL_NAME \"_Host\";\n\n    GGML_UNUSED(buft);\n}\n\ninline void * aligned_malloc_host(size_t alignment, size_t size) {\n#ifdef _WIN32\n    return _aligned_malloc(size, alignment);\n#else\n    return aligned_alloc(alignment, size);\n#endif\n}\n\ninline void free_aligned_mem_host(void * memblock) {\n#ifdef _WIN32\n    _aligned_free(memblock);\n#else\n    free(memblock);\n#endif\n}\n\nstatic void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    free_aligned_mem_host((void *)buffer->context);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);\n    if (ptr == nullptr) {\n        // fallback to cpu buffer\n        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);\n    }\n\n    // FIXME: this is a hack to avoid having to implement a new buffer type\n    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);\n    buffer->buft = buft;\n    buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;\n\n    return 
buffer;\n}\n\nggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {\n    GGML_SYCL_DEBUG(\"[SYCL] call ggml_backend_sycl_host_buffer_type\\n\");\n    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {\n        /* .iface    = */ {\n            /* .get_name         = */ ggml_backend_sycl_host_buffer_type_name,\n            /* .alloc_buffer     = */ ggml_backend_sycl_host_buffer_type_alloc_buffer,\n            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,\n            /* .get_max_size     = */ NULL, // TODO: return device.maxBufferLength\n            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,\n            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,\n        },\n        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),\n        /* .context  = */ nullptr,\n    };\n\n    return &ggml_backend_sycl_buffer_type_host;\n}\n\n// buffer pool for sycl (legacy)\nstruct ggml_sycl_pool_leg : public ggml_sycl_pool {\n    static const int MAX_SYCL_BUFFERS = 256;\n\n    int device;\n    queue_ptr qptr;\n    struct ggml_sycl_buffer {\n        void * ptr = nullptr;\n        size_t size = 0;\n    };\n\n    ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};\n    size_t pool_size = 0;\n\n    explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {}\n\n    ~ggml_sycl_pool_leg() {\n        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {\n            ggml_sycl_buffer & b = buffer_pool[i];\n            if (b.ptr != nullptr) {\n                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));\n                pool_size -= b.size;\n            }\n        }\n        GGML_ASSERT(pool_size == 0);\n    }\n\n    void * alloc(size_t size, size_t * actual_size) override {\n#ifdef DEBUG_sycl_MALLOC\n        int nnz = 0;\n        size_t max_size = 0;\n#endif\n        size_t best_diff = 1ull << 36;\n        int 
ibest = -1;\n        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {\n            ggml_sycl_buffer& b = buffer_pool[i];\n            if (b.ptr != nullptr) {\n#ifdef DEBUG_sycl_MALLOC\n                ++nnz;\n                if (b.size > max_size) max_size = b.size;\n#endif\n                if (b.size >= size) {\n                    size_t diff = b.size - size;\n                    if (diff < best_diff) {\n                        best_diff = diff;\n                        ibest = i;\n                        if (!best_diff) {\n                            void * ptr = b.ptr;\n                            *actual_size = b.size;\n                            b.ptr = nullptr;\n                            b.size = 0;\n                            return ptr;\n                        }\n                    }\n                }\n            }\n        }\n        if (ibest >= 0) {\n            ggml_sycl_buffer& b = buffer_pool[ibest];\n            void * ptr = b.ptr;\n            *actual_size = b.size;\n            b.ptr = nullptr;\n            b.size = 0;\n            return ptr;\n        }\n        void * ptr;\n        size_t look_ahead_size = (size_t) (1.05 * size);\n\n        SYCL_CHECK(\n            CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(\n                                look_ahead_size, *qptr)));\n        if (!ptr) {\n            GGML_LOG_ERROR(\"%s: can't allocate %lu Bytes of memory on device/GPU\\n\", __func__, look_ahead_size);\n            return nullptr;\n        }\n\n        *actual_size = look_ahead_size;\n        pool_size += look_ahead_size;\n\n#ifdef DEBUG_SYCL_MALLOC\n        GGML_LOG_DEBUG(\"%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\\n\", __func__, id, nnz,\n                (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));\n#endif\n\n        // GGML_SYCL_DEBUG(\"ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\\n\", look_ahead_size, ptr);\n        return 
ptr;\n    }\n\n    void free(void * ptr, size_t size) override {\n        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {\n            ggml_sycl_buffer& b = buffer_pool[i];\n            if (b.ptr == nullptr) {\n                b.ptr = ptr;\n                b.size = size;\n                return;\n            }\n        }\n        GGML_LOG_WARN(\"WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\\n\");\n        SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));\n        pool_size -= size;\n    }\n};\n\nstruct ggml_sycl_pool_host : public ggml_sycl_pool {\n    queue_ptr qptr;\n    int       device;\n\n    inline static int counter{ 0 };\n\n    struct ggml_sycl_buffer {\n        void * ptr  = nullptr;\n        size_t size = 0;\n    };\n\n    // Set arbitrarly to 64\n    static constexpr int          MAX_POOL_SIZE{ 64 };\n    std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);\n    size_t                        pool_size   = 0;\n\n    explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}\n\n    ~ggml_sycl_pool_host() {\n        for (int i = 0; i < MAX_POOL_SIZE; ++i) {\n            ggml_sycl_buffer & b = buffer_pool[i];\n            if (b.ptr != nullptr) {\n                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));\n                b.ptr = nullptr;\n                pool_size -= b.size;\n                b.size = 0;\n            }\n        }\n        counter = 0;\n    }\n\n    void * alloc(size_t size, size_t * actual_size) override {\n        if (counter == MAX_POOL_SIZE) {\n            ggml_sycl_buffer b               = buffer_pool[0];\n            void *           ptr             = b.ptr;\n            *actual_size                     = b.size;\n            counter                          = 1;\n            return ptr;\n        }\n        ggml_sycl_buffer & b = buffer_pool[counter];\n\n        if (b.ptr == nullptr) {\n            void * ptr;\n\n            
SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));\n            if (!ptr) {\n                GGML_LOG_ERROR(\"%s: can't allocate %lu Bytes of memory on host\\n\", __func__, size);\n                return nullptr;\n            }\n            pool_size += size;\n            *actual_size = size;\n            counter      = counter + 1;\n            return ptr;\n        } else {\n            ++counter;\n            b.size = size;\n            return b.ptr;\n        }\n    }\n\n    void free(void * ptr, size_t size) override {\n        // if the pool is not completed add the pointer to it in place of the first nullptr found.\n        // Otherwise do nothing, pointers will be freed once the pool is deallocated.\n        for (int i = 0; i < MAX_POOL_SIZE; ++i) {\n            ggml_sycl_buffer & b = buffer_pool[i];\n            if (b.ptr == nullptr) {\n                b.ptr  = ptr;\n                b.size = size;\n                return;\n            }\n        }\n    }\n};\n\nstd::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {\n    // return pool for the host to speed up memory management\n    return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));\n}\n\nstd::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {\n    // TBD: NO VMM support\n    // if (ggml_sycl_info().devices[device].vmm) {\n    //     return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));\n    // }\n   return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));\n}\n\n// TBD pool with virtual memory management\n// struct ggml_sycl_pool_vmm : public ggml_sycl_pool\n\n/// kernels\ntypedef void (*ggml_sycl_op_mul_mat_t)(\n    ggml_backend_sycl_context & ctx,\n    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,\n    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,\n    float 
*dst_dd_i, const int64_t row_low, const int64_t row_high,\n    const int64_t src1_ncols, const int64_t src1_padded_row_size,\n    const queue_ptr &stream);\n\n\n\nstatic void mul_mat_p021_f16_f32(\n    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,\n    const sycl::nd_item<3> &item_ct1) {\n\n    const sycl::half *x = (const sycl::half *)vx;\n\n    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                      item_ct1.get_local_id(1);\n    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +\n                        item_ct1.get_local_id(0);\n    const int channel_x = channel / (nchannels_y / nchannels_x);\n\n    const int nrows_y = ncols_x;\n    const int nrows_dst = nrows_x;\n    const int row_dst = row_x;\n\n    float tmp = 0.0f;\n\n    for (int col_x0 = 0; col_x0 < ncols_x;\n         col_x0 += item_ct1.get_local_range(2)) {\n        const int col_x = col_x0 + item_ct1.get_local_id(2);\n\n        if (col_x >= ncols_x) {\n            break;\n        }\n\n        // x is transposed and permuted\n        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;\n        const float xi =\n            sycl::vec<sycl::half, 1>(x[ix])\n                .convert<float, sycl::rounding_mode::automatic>()[0];\n\n        const int row_y = col_x;\n\n\n        // y is not transposed but permuted\n        const int iy = channel*nrows_y + row_y;\n\n        tmp += xi * y[iy];\n    }\n\n    // dst is not transposed and not permuted\n    const int idst = channel*nrows_dst + row_dst;\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[idst] = tmp;\n    
}\n}\n\nstatic void mul_mat_vec_nc_f16_f32( // nc == non-contiguous\n    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,\n    const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,\n    const sycl::nd_item<3> &item_ct1) {\n\n    const sycl::half *x = (const sycl::half *)vx;\n\n    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                      item_ct1.get_local_id(1);\n    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +\n                        item_ct1.get_local_id(0);\n    const int channel_x = channel / channel_x_divisor;\n\n    const int nrows_dst = nrows_x;\n    const int row_dst   = row_x;\n\n    const int idst = channel*nrows_dst + row_dst;\n\n    float tmp = 0.0f;\n\n    for (int col_x0 = 0; col_x0 < ncols_x;\n         col_x0 += item_ct1.get_local_range(2)) {\n        const int col_x = col_x0 + item_ct1.get_local_id(2);\n\n        if (col_x >= ncols_x) {\n            break;\n        }\n\n        const int row_y = col_x;\n\n        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;\n        const int iy = channel * channel_stride_y + row_y;\n\n        const float xi =\n            sycl::vec<sycl::half, 1>(x[ix])\n                .convert<float, sycl::rounding_mode::automatic>()[0];\n\n        tmp += xi * y[iy];\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[idst] = tmp;\n    }\n}\n\nstatic void k_sum_rows_f32(const float * x, float * dst, const int ncols,\n                           const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(1);\n    const int col = 
item_ct1.get_local_id(2);\n\n    float sum = 0.0f;\n    for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {\n        sum += x[row * ncols + i];\n    }\n\n    sum = warp_reduce_sum(sum, item_ct1);\n\n    if (col == 0) {\n        dst[row] = sum;\n    }\n}\n\n\ntemplate<typename T>\nstatic inline void ggml_sycl_swap(T & a, T & b) {\n    T tmp = a;\n    a = b;\n    b = tmp;\n}\n\ntemplate <ggml_sort_order order>\n__dpct_inline__ static void\nk_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,\n                  const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,\n                  uint8_t *dpct_local) {\n    // bitonic sort\n    int col_index =  item_ct1.get_local_id(2);\n    int row = item_ct1.get_group(1);\n\n    for (int i = 0; i < tasks_per_thread; i++) {\n        int col = col_index * tasks_per_thread + i;\n        if (col >= ncols_pad) {\n            return;\n        }\n    }\n\n    const float * x_row = x + row * ncols;\n    auto dst_row = (int *)dpct_local;\n\n    // initialize indices\n    for (int i=0;i<tasks_per_thread;i++){\n        int col = col_index*tasks_per_thread+i;\n        dst_row[col] = col;\n    }\n\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n\n    for (int k = 2; k <= ncols_pad; k *= 2) {\n        for (int j = k / 2; j > 0; j /= 2) {\n            for (int i = 0; i < tasks_per_thread; i++) {\n                int col = col_index * tasks_per_thread + i;\n                int ixj = col ^ j;\n                if (ixj > col) {\n                    if ((col & k) == 0) {\n                        if (dst_row[col] >= ncols ||\n                            (dst_row[ixj] < ncols &&\n                             (order == GGML_SORT_ORDER_ASC\n                                  ? 
x_row[dst_row[col]] > x_row[dst_row[ixj]]\n                                  : x_row[dst_row[col]] <\n                                        x_row[dst_row[ixj]]))) {\n                            ggml_sycl_swap(dst_row[col], dst_row[ixj]);\n                        }\n                    } else {\n                        if (dst_row[ixj] >= ncols ||\n                            (dst_row[col] < ncols &&\n                             (order == GGML_SORT_ORDER_ASC\n                                  ? x_row[dst_row[col]] < x_row[dst_row[ixj]]\n                                  : x_row[dst_row[col]] >\n                                        x_row[dst_row[ixj]]))) {\n                            ggml_sycl_swap(dst_row[col], dst_row[ixj]);\n                        }\n                    }\n                }\n                item_ct1.barrier(sycl::access::fence_space::local_space);\n            }\n        }\n    }\n\n    // copy the result to dst without the padding\n    for (int i = 0; i < tasks_per_thread; i++) {\n        int col = col_index * tasks_per_thread + i;\n        if (col < ncols) {\n            dst[row * ncols + col] = dst_row[col];\n        }\n    }\n}\n\nstatic void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,\n                              const sycl::nd_item<3> &item_ct1) {\n    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                    item_ct1.get_local_id(1);\n    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                    item_ct1.get_local_id(2);\n\n    if (col >= ncols) {\n        return;\n    }\n\n    const int i = row*ncols + col;\n    //dst[i] = col > (n_past + row % rows_per_channel) ? 
-INFINITY : x[i];\n    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU\n    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;\n}\n\nstatic void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,\n                      const sycl::nd_item<3> &item_ct1) {\n    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                  item_ct1.get_local_id(2);\n\n    if (i >= k) {\n        return;\n    }\n\n    dst[i] = scale * x[i] + bias;\n}\n\n\ntemplate <typename Ti, typename To>\nstatic  void pool2d_nchw_kernel(\n        const int ih, const int iw, const int oh, const int ow,\n        const int kh, const int kw, const int sh, const int sw,\n        const int ph, const int pw, const int parallel_elements,\n        const Ti* src, To* dst, const enum ggml_op_pool op,\n        const sycl::nd_item<3> &item_ct1) {\n        int idx = item_ct1.get_local_id(2) +\n                  item_ct1.get_group(2) * item_ct1.get_local_range(2);\n        if (idx >= parallel_elements) {\n            return;\n        }\n\n        const int I_HW = ih * iw;\n        const int O_HW = oh * ow;\n        const int nc = idx / O_HW;\n        const int cur_oh = idx % O_HW / ow;\n        const int cur_ow = idx % O_HW % ow;\n        const Ti* i_ptr = src + nc * I_HW;\n        To* o_ptr = dst + nc * O_HW;\n        const int start_h = cur_oh * sh - ph;\n        const int bh = sycl::max(0, start_h);\n        const int eh = sycl::min(ih, start_h + kh);\n        const int start_w = cur_ow * sw - pw;\n        const int bw = sycl::max(0, start_w);\n        const int ew = sycl::min(iw, start_w + kw);\n\n        To res = 0;\n\n        switch (op) {\n            case GGML_OP_POOL_AVG: res = 0; break;\n            case GGML_OP_POOL_MAX: res = -FLT_MAX; break;\n            default:\n                res      = (To) sycl::nan(uint32_t(0));\n                
break;\n        }\n\n        for (int i = bh; i < eh; i += 1) {\n            for (int j = bw; j < ew; j += 1) {\n#if DPCT_COMPATIBILITY_TEMP >= 350\n                /*\n                DPCT1098:106: The '*' expression is used instead of the __ldg\n                call. These two expressions do not provide the exact same\n                functionality. Check the generated code for potential precision\n                and/or performance issues.\n                */\n                Ti cur = *(i_ptr + i * iw + j);\n#else\n                Ti cur = i_ptr[i * iw + j];\n#endif\n                switch (op) {\n                    case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;\n                    case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;\n                    default:\n                        res = (To) sycl::nan(uint32_t(0));\n                        break;\n                }\n            }\n        }\n        o_ptr[cur_oh * ow + cur_ow] = res;\n}\n\n\nstatic void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,\n                                           float *dst, const int ncols_x,\n                                           const int nrows_x,\n                                           const int nchannels_x,\n                                           const int nchannels_y,\n                                           queue_ptr stream) {\n\n    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);\n    const sycl::range<3> block_dims(1, 1, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,\n                                     nchannels_y, item_ct1);\n            });\n    
}\n}\n\nstatic void ggml_mul_mat_vec_nc_f16_f32_sycl(\n    const void *vx, const float *y, float *dst, const int ncols_x,\n    const int nrows_x, const int row_stride_x, const int nchannels_x,\n    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {\n\n    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);\n    const sycl::range<3> block_dims(1, 1, WARP_SIZE);\n    {\n        dpct::has_capability_or_fail(stream->get_device(),\n                                     {sycl::aspect::fp16});\n\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,\n                                       row_stride_x, channel_stride_x, channel_stride_y,\n                                       nchannels_y / nchannels_x, item_ct1);\n            });\n    }\n}\n\n\n\nstatic void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,\n                           const int k, queue_ptr stream) {\n    const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;\n    stream->parallel_for(\n        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *\n                              sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),\n                          sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),\n        [=](sycl::nd_item<3> item_ct1) {\n            scale_f32(x, dst, scale, bias, k, item_ct1);\n        });\n}\n\n\nstatic void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,\n                              const int nrows, queue_ptr stream) {\n    const sycl::range<3> block_dims(1, 1, WARP_SIZE);\n    const sycl::range<3> block_nums(1, nrows, 1);\n    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                         [=](sycl::nd_item<3> item_ct1)\n        
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                                 k_sum_rows_f32(x, dst, ncols, item_ct1);\n                             });\n}\n\nstatic int next_power_of_2(int x) {\n    int n = 1;\n    while (n < x) {\n        n *= 2;\n    }\n    return n;\n}\n\nstatic void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,\n                                 const int nrows, ggml_sort_order order,\n                                 queue_ptr stream, int device) {\n    // bitonic sort requires ncols to be power of 2\n    const int ncols_pad = next_power_of_2(ncols);\n\n    int nth = 1;\n    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];\n    while (nth < ncols_pad && nth < max_block_size)\n        nth *= 2;\n    if (nth > max_block_size)\n        nth = max_block_size;\n\n    const int tasks_per_thread = ncols_pad / nth;\n\n    const sycl::range<3> block_dims(1, 1, nth);\n    const sycl::range<3> block_nums(1, nrows, 1);\n    const size_t shared_mem = ncols_pad * sizeof(int);\n    GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);\n\n    if (order == GGML_SORT_ORDER_ASC) {\n        stream->submit([&](sycl::handler &cgh) {\n            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(\n                sycl::range<1>(shared_mem), cgh);\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1) {\n                    k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(\n                        x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,\n                        dpct_local_acc_ct1\n                            .get_multi_ptr<sycl::access::decorated::no>()\n                            .get());\n                });\n        });\n    } else if (order == GGML_SORT_ORDER_DESC) {\n        stream->submit([&](sycl::handler &cgh) {\n            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(\n            
    sycl::range<1>(shared_mem), cgh);\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1) {\n                    k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(\n                        x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,\n                        dpct_local_acc_ct1\n                            .get_multi_ptr<sycl::access::decorated::no>()\n                            .get());\n                });\n        });\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\nstatic void top_k_f32_sycl(\n    const float * src,\n    int32_t * dst_indices,\n    const int64_t ncols,\n    const int64_t nrows,\n    const int k,\n    dpct::queue_ptr main_stream\n) {\n    const int block_size = 128;\n\n    const sycl::range<1> block_dims(block_size);\n    const sycl::range<1> grid_dims(nrows);\n\n    main_stream->submit([&](sycl::handler &cgh) {\n        sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);\n        sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh);\n\n        cgh.parallel_for(\n            sycl::nd_range<1>(grid_dims * block_dims, block_dims),\n            [=](sycl::nd_item<1> item_ct1) {\n                const int row = item_ct1.get_group(0);\n                const int tid = item_ct1.get_local_id(0);\n\n                if (row >= nrows) return;\n\n                const float * src_row = src + row * ncols;\n                int32_t * dst_idx_row = dst_indices + row * k;\n\n                float local_vals[32];\n                int local_idx[32];\n\n                for (int i = 0; i < k; i++) {\n                    local_vals[i] = -FLT_MAX;\n                    local_idx[i] = -1;\n                }\n\n                for (int col = tid; col < ncols; col += block_size) {\n                    float val = src_row[col];\n\n                    if (val > local_vals[k-1]) {\n                        int 
pos = k - 1;\n                        while (pos > 0 && val > local_vals[pos - 1]) {\n                            pos--;\n                        }\n\n                        for (int i = k - 1; i > pos; i--) {\n                            local_vals[i] = local_vals[i - 1];\n                            local_idx[i] = local_idx[i - 1];\n                        }\n                        local_vals[pos] = val;\n                        local_idx[pos] = col;\n                    }\n                }\n\n                for (int i = 0; i < k; i++) {\n                    shared_vals[tid * k + i] = local_vals[i];\n                    shared_idx[tid * k + i] = local_idx[i];\n                }\n                item_ct1.barrier(sycl::access::fence_space::local_space);\n\n                if (tid == 0) {\n                    float final_vals[32];\n                    int final_idx[32];\n\n                    for (int i = 0; i < k; i++) {\n                        final_vals[i] = -FLT_MAX;\n                        final_idx[i] = -1;\n                    }\n\n                    for (int t = 0; t < block_size; t++) {\n                        for (int i = 0; i < k; i++) {\n                            float val = shared_vals[t * k + i];\n                            int idx = shared_idx[t * k + i];\n\n                            if (val > final_vals[k-1]) {\n                                int pos = k - 1;\n                                while (pos > 0 && val > final_vals[pos - 1]) {\n                                    pos--;\n                                }\n\n                                for (int j = k - 1; j > pos; j--) {\n                                    final_vals[j] = final_vals[j - 1];\n                                    final_idx[j] = final_idx[j - 1];\n                                }\n                                final_vals[pos] = val;\n                                final_idx[pos] = idx;\n                            }\n                        }\n            
        }\n\n                    for (int i = 0; i < k; i++) {\n                        dst_idx_row[i] = final_idx[i];\n                    }\n\n                    if (k > 1) {\n                        int32_t temp = dst_idx_row[0];\n                        dst_idx_row[0] = dst_idx_row[1];\n                        dst_idx_row[1] = temp;\n                    }\n                }\n            });\n    });\n}\n\nstatic void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,\n                               const int nrows, queue_ptr stream) {\n    const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);\n    const sycl::range<3> block_nums(1, nrows, 1);\n    const size_t shared_mem = 256 * sizeof(float);\n\n    stream->submit([&](sycl::handler &cgh) {\n        sycl::local_accessor<float, 1> shared_data(\n            sycl::range<1>(shared_mem/sizeof(float)), cgh);\n        sycl::local_accessor<int, 1> shared_indices(\n            sycl::range<1>(shared_mem/sizeof(float)), cgh);\n\n        cgh.parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                const int tid = item_ct1.get_local_id(2);\n                const int row = item_ct1.get_global_id(1);\n\n                float max_val = -INFINITY;\n                int max_idx = -1;\n\n                for (int col = tid; col < ncols; col += 256) {\n                    float val = x[row * ncols + col];\n                    if (val > max_val) {\n                        max_val = val;\n                        max_idx = col;\n                    }\n                }\n\n                shared_data[tid] = max_val;\n                shared_indices[tid] = max_idx;\n                item_ct1.barrier(sycl::access::fence_space::local_space);\n\n                for (int stride = 256/2; stride > 0; stride >>= 1) {\n                    if (tid < stride) {\n                        float val1 = shared_data[tid];\n                        
float val2 = shared_data[tid + stride];\n                        if (val2 > val1) {\n                            shared_data[tid] = val2;\n                            shared_indices[tid] = shared_indices[tid + stride];\n                        }\n                    }\n                    item_ct1.barrier(sycl::access::fence_space::local_space);\n                }\n\n\n                if (tid == 0) {\n                    dst[row] = shared_indices[0];\n                }\n            });\n    });\n}\nstatic void diag_mask_inf_f32_sycl(const float *x, float *dst,\n                                   const int ncols_x, const int nrows_x,\n                                   const int rows_per_channel, const int n_past,\n                                   queue_ptr stream) {\n    const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);\n    const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;\n    const sycl::range<3> block_nums(1, block_num_x, nrows_x);\n    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             diag_mask_inf_f32(x, dst, ncols_x,\n                                               rows_per_channel, n_past,\n                                               item_ct1);\n                         });\n}\n\nstatic dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,\n                                          const struct ggml_tensor *src,\n                                          int64_t i3, int64_t i2,\n                                          int64_t i1_low, int64_t i1_high,\n                                          queue_ptr stream) try {\n\n    dpct::memcpy_direction kind;\n    char * src_ptr;\n    if (ggml_backend_buffer_is_host(src->buffer)) {\n        kind = dpct::host_to_device;\n        //GGML_SYCL_DEBUG(\"%s: Host buffer type src tensor\\n\", __func__);\n        src_ptr = (char *) 
src->data;\n        // GGML_SYCL_DEBUG(\"ggml_sycl_cpy_tensor_2d  GGML_BACKEND_TYPE_CPU src_ptr %p\\n\", src_ptr);\n    } else if (ggml_backend_buffer_is_sycl(src->buffer)) {\n        // If buffer is a SYCL buffer\n        //GGML_SYCL_DEBUG(\"%s: SYCL buffer type src tensor\\n\", __func__);\n        kind    = dpct::device_to_device;\n        src_ptr = (char *) src->data;\n    } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {\n        /*\n        If buffer is a SYCL split buffer\n        */\n        //GGML_SYCL_DEBUG(\"%s: Split buffer type src tensor\\n\", __func__);\n        GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);\n        kind = dpct::device_to_device;\n        ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;\n        int id;\n        SYCL_CHECK(CHECK_TRY_ERROR(\n            id = get_current_device_id()));\n        // GGML_SYCL_DEBUG(\"current device index %d\\n\", id);\n        src_ptr = (char *) extra->data_device[id];\n    } else {\n        // GGML_SYCL_DEBUG(\"GGML_ABORT(\"fatal error\")\\n\");\n        GGML_ABORT(\"fatal error\");\n    }\n    char * dst_ptr = (char *) dst;\n\n    GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);\n    GGML_TENSOR_LOCALS(int64_t, nb, src, nb);\n    const enum ggml_type type = src->type;\n    const int64_t ts = ggml_type_size(type);\n    const int64_t bs = ggml_blck_size(type);\n    int64_t i1_diff = i1_high - i1_low;\n\n    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;\n    if (nb0 == ts && nb1 == ts*ne0/bs) {\n        // GGML_SYCL_DEBUG(\"stream->memcpy: dst_ptr=%p, x=%p, size=%lu\\n\", dst_ptr, x, i1_diff * nb1);\n        // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));\n        return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,\n                                    kind, *stream));\n\n    } else if (nb0 == ts) {\n        return CHECK_TRY_ERROR(\n            dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,\n                      
              ts * ne0 / bs, i1_diff, kind, *stream));\n    } else {\n        for (int64_t i1 = 0; i1 < i1_diff; i1++) {\n            const void * rx = (const void *) ((const char *) x + i1*nb1);\n            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);\n            // pretend the row is a matrix with cols=1\n            dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(\n                rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));\n            /*\n            DPCT1001:85: The statement could not be removed.\n            */\n            /*\n            DPCT1000:86: Error handling if-stmt was detected but could not be\n            rewritten.\n            */\n            if (r != 0) return r;\n        }\n        return 0;\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\ninline void ggml_sycl_op_mul_mat_sycl(\n    ggml_backend_sycl_context & ctx,\n    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,\n    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,\n    float *dst_dd_i, const int64_t row_low, const int64_t row_high,\n    const int64_t src1_ncols, const int64_t src1_padded_row_size,\n    const queue_ptr &stream) try {\n\n    GGML_ASSERT(src0_dd_i  != nullptr);\n    GGML_ASSERT(src1_ddf_i != nullptr);\n    GGML_ASSERT(dst_dd_i   != nullptr);\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne10 = src1->ne[0];\n    GGML_ASSERT(ne00 == ne10);\n\n    const int64_t row_diff = row_high - row_low;\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n\n    const int64_t ne0 = dst->ne[0]; // used by MKL only\n    // the main device has a larger memory buffer to hold the results from all GPUs\n    // ldc == nrows of the matrix that cuBLAS writes into\n    int ldc = id == ctx.device ? 
ne0 : row_diff; // used by MKL only\n\n#ifdef GGML_SYCL_F16\n    bool use_fp16 = true;  // TODO(Yu) SYCL capability check\n#else\n    bool use_fp16 = false;\n#endif\n    if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&\n        row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {\n        ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());\n        if (src0->type != GGML_TYPE_F16) {\n            scope_op_debug_print scope_dbg_print(__func__, \"/to_fp16_sycl\", dst, /*num_src=*/2,\n                                                 \" : converting src0 to fp16\");\n            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);\n            GGML_ASSERT(to_fp16_sycl != nullptr);\n            size_t ne = row_diff*ne00;\n            src0_as_f16.alloc(ne);\n            to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);\n        }\n        const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16\n                                         ? (const sycl::half *)src0_dd_i\n                                         : src0_as_f16.get();\n\n        ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());\n        if (src1->type != GGML_TYPE_F16) {\n            scope_op_debug_print scope_dbg_print(__func__, \"/to_fp16_sycl\", dst, /*num_src=*/2,\n                                                 \" : converting src1 to fp16\");\n            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);\n            GGML_ASSERT(to_fp16_sycl != nullptr);\n            size_t ne = src1_ncols*ne10;\n            src1_as_f16.alloc(ne);\n            to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);\n        }\n        const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16\n                ? 
(const sycl::half *)src1->data + src1_padded_row_size\n                                         : src1_as_f16.get();\n\n#if GGML_SYCL_DNNL\n        if (!g_ggml_sycl_disable_dnn) {\n                DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,\n                                     DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),\n                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);\n        }\n        else\n#endif\n        {\n            ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);\n\n            const sycl::half alpha_f16 = 1.0f;\n            const sycl::half beta_f16  = 0.0f;\n            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(\n                *stream, oneapi::mkl::transpose::trans,\n                oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,\n                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,\n                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,\n                dst_f16.get(), dpct::library_data_t::real_half, ldc,\n                dpct::library_data_t::real_half)));\n            scope_op_debug_print scope_dbg_print(__func__, \"/to_fp32_sycl\", dst, /*num_src=*/2,\n                                                 \" : converting dst to fp32\");\n            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);\n            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);\n        }\n    } else {\n        ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());\n        ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());\n        if (src0->type != GGML_TYPE_F32) {\n            scope_op_debug_print scope_dbg_print(__func__, \"/to_fp32_sycl\", dst, /*num_src=*/2,\n                                                 \" : converting src0 to fp32\");\n            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, 
dst);\n            GGML_ASSERT(to_fp32_sycl != nullptr);\n            src0_ddq_as_f32.alloc(row_diff*ne00);\n            to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);\n        }\n        if (src1->type != GGML_TYPE_F32) {\n            scope_op_debug_print scope_dbg_print(__func__, \"/to_fp32_sycl\", dst, /*num_src=*/2,\n                                                 \" : converting src1 to fp32\");\n            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);\n            GGML_ASSERT(to_fp32_sycl != nullptr);\n            src1_ddq_as_f32.alloc(src1_ncols*ne10);\n            to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);\n        }\n        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();\n        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();\n\n#if GGML_SYCL_DNNL\n        if (!g_ggml_sycl_disable_dnn) {\n            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,\n                                      DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),\n                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);\n        }\n        else\n#endif\n        {\n            const float alpha = 1.0f;\n            const float beta  = 0.0f;\n            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(\n                *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff,\n                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,\n                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));\n        }\n    }\n    GGML_UNUSED(dst);\n    GGML_UNUSED(src1_ddq_i);\n    GGML_UNUSED(src1_padded_row_size);\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" 
<< __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    const int32_t * opts = (const int32_t *)dst->op_params;\n    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);\n    const int k0 = opts[1];\n    const int k1 = opts[2];\n    const int s0 = opts[3];\n    const int s1 = opts[4];\n    const int p0 = opts[5];\n    const int p1 = opts[6];\n\n    const int64_t IH = dst->src[0]->ne[1];\n    const int64_t IW = dst->src[0]->ne[0];\n\n    const int64_t N = dst->ne[3];\n    const int64_t OC = dst->ne[2];\n    const int64_t OH = dst->ne[1];\n    const int64_t OW = dst->ne[0];\n\n    const int parallel_elements = N * OC * OH * OW;\n    const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;\n    sycl::range<3> block_nums(1, 1, num_blocks);\n    main_stream->parallel_for(\n        sycl::nd_range<3>(block_nums *\n                              sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),\n                          sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),\n        [=](sycl::nd_item<3> item_ct1) {\n            pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,\n                               parallel_elements, src0_dd, dst_dd, op,\n                               item_ct1);\n        });\n}\n\ninline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n    dpct::queue_ptr main_stream = ctx.stream();\n    
SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    const int64_t ne = ggml_nelements(dst->src[0]);\n\n    sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);\n}\n\ninline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    const int64_t ncols = dst->src[0]->ne[0];\n    const int64_t nrows = ggml_nrows(dst->src[0]);\n\n    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);\n}\n\ninline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    const int64_t ncols = dst->src[0]->ne[0];\n    const int64_t nrows = ggml_nrows(dst->src[0]);\n\n    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);\n\n    main_stream->parallel_for(\n        sycl::range<1>(nrows),\n        [=](sycl::id<1> row) {\n            dst_dd[row] /= ncols;\n        }\n    );\n}\n\n\ninline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_I32);\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const float * src0_dd = 
static_cast<const float *>(dst->src[0]->data);\n    int32_t *       dst_dd  = static_cast<int32_t *>(dst->data);\n\n\n    const int64_t ncols = dst->src[0]->ne[0];\n    const int64_t nrows = ggml_nrows(dst->src[0]);\n\n    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];\n\n    argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,\n                         main_stream, ctx.device);\n}\n\nstatic void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(src0);\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_I32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n\n    const float * src0_dd = static_cast<const float *>(src0->data);\n    int32_t * dst_dd = static_cast<int32_t *>(dst->data);\n\n    const int k = dst->ne[0];\n    const int64_t ncols = src0->ne[0];\n    const int64_t nrows = ggml_nrows(src0);\n\n    GGML_ASSERT(k > 0 && k <= 32);\n    GGML_ASSERT(k <= ncols);\n\n    top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);\n}\n\ninline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_I32);\n\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    int32_t *       dst_dd  = static_cast<int32_t *>(dst->data);\n\n    const int64_t ncols = dst->src[0]->ne[0];\n    const int64_t nrows = ggml_nrows(dst->src[0]);\n\n    argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);\n}\n\ninline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == 
GGML_TYPE_F32);\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    const int64_t ne00 = dst->src[0]->ne[0];\n    const int64_t ne01 = dst->src[0]->ne[1];\n    const int nrows0 = ggml_nrows(dst->src[0]);\n\n    const int n_past = ((int32_t *) dst->op_params)[0];\n\n    diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);\n}\n\nstatic void tri_f32_sycl(\n    const float * src,\n    float * dst,\n    const int64_t ne0,\n    const int64_t ne1,\n    const int64_t ne2,\n    const int64_t ne3,\n    const ggml_tri_type ttype,\n    dpct::queue_ptr main_stream\n) {\n    const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;\n\n    main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {\n        const int64_t idx = (int64_t) tid[0];\n\n        const int64_t i0 = idx % ne0;\n        const int64_t t1 = idx / ne0;\n        const int64_t i1 = t1 % ne1;\n\n        bool keep = false;\n        switch (ttype) {\n            case GGML_TRI_TYPE_LOWER:      keep = (i0 <  i1); break;\n            case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;\n            case GGML_TRI_TYPE_UPPER:      keep = (i0 >  i1); break;\n            case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;\n            default: keep = false; break;\n        }\n\n        dst[idx] = keep ? 
src[idx] : 0.0f;\n    });\n}\n\nstatic void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    GGML_ASSERT(src0);\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n\n    const float * src0_dd = static_cast<const float *>(src0->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);\n\n    const int64_t ne0 = src0->ne[0];\n    const int64_t ne1 = src0->ne[1];\n    const int64_t ne2 = src0->ne[2];\n    const int64_t ne3 = src0->ne[3];\n\n    tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);\n}\n\n\ninline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    float scale;\n    float bias;\n    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));\n    memcpy(&bias,  (float *) dst->op_params + 1, sizeof(float));\n\n    scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);\n    /*\n    DPCT1010:87: SYCL uses exceptions to report errors and does not use the\n    error codes. The call was replaced with 0. 
You need to rewrite this code.\n    */\n    SYCL_CHECK(0);\n}\n\nstatic void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {\n    static bool peer_access_enabled = false;\n\n    const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;\n\n    if (peer_access_enabled == enable_peer_access) {\n        return;\n    }\n\n#ifdef NDEBUG\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        SYCL_CHECK(ggml_sycl_set_device(i));\n    }\n\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        SYCL_CHECK(ggml_sycl_set_device(i));\n\n        for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {\n            if (i == id_other) {\n                continue;\n            }\n            if (i != main_device && id_other != main_device) {\n                continue;\n            }\n\n            // int can_access_peer;\n            // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other));\n            // if (can_access_peer) {\n            //     if (enable_peer_access) {\n            //         SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0));\n            //     } else {\n            //         SYCL_CHECK(syclDeviceDisablePeerAccess(id_other));\n            //     }\n            // }\n        }\n    }\n#endif // NDEBUG\n\n    peer_access_enabled = enable_peer_access;\n}\n\ntemplate <template <int> typename quantize_f>\nstatic void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,\n                                 const ggml_tensor *src1, ggml_tensor *dst,\n                                 ggml_sycl_op_mul_mat_t op) try {\n\n    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);\n\n    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);\n    const int64_t nrows1 = ggml_nrows(src1);\n\n    GGML_ASSERT(ne03 == ne13);\n\n    const int64_t ne0 = dst->ne[0];\n    const int64_t ne1 = dst->ne[1];\n\n    const int nb2 = dst->nb[2];\n    const int nb3 = dst->nb[3];\n\n  
  GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));\n    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer));\n    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));\n\n    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);\n\n    const int64_t i02_divisor = ne12 / ne02;\n\n    const size_t src0_ts = ggml_type_size(src0->type);\n    const size_t src0_bs = ggml_blck_size(src0->type);\n    const size_t q8_1_ts = sizeof(block_q8_1);\n    const size_t q8_1_bs = QK8_1;\n\n    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;\n    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;\n\n    const bool src0_is_contiguous = ggml_is_contiguous(src0);\n    const bool src1_is_contiguous = ggml_is_contiguous(src1);\n\n    int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);\n\n    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);\n    GGML_ASSERT(!(split && ne02 > 1));\n    GGML_ASSERT(!(split && ne03 > 1));\n    GGML_ASSERT(!(split && ne02 < ne12));\n\n    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;\n    if (split) {\n        // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check\n        // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);\n        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;\n        tensor_split = buft_ctx->tensor_split;\n    }\n\n    struct dev_data {\n        ggml_sycl_pool_alloc<char> src0_dd_alloc;\n        ggml_sycl_pool_alloc<float> src1_ddf_alloc;\n        ggml_sycl_pool_alloc<char> src1_ddq_alloc;\n        ggml_sycl_pool_alloc<float> dst_dd_alloc;\n\n        char *src0_dd = nullptr;\n        float *src1_ddf = nullptr; // float\n        char *src1_ddq = nullptr;  // q8_1\n        float *dst_dd = nullptr;\n\n        int64_t row_low;\n        int64_t 
row_high;\n    };\n\n    dev_data dev[GGML_SYCL_MAX_DEVICES];\n\n    int used_devices = 0;\n    queue_ptr main_stream = ctx.stream();\n\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        // by default, use all rows\n        dev[i].row_low  = 0;\n        dev[i].row_high = ne01;\n\n        // for multi GPU, get the row boundaries from tensor split\n        // and round to mul_mat_q tile sizes\n        if (split) {\n            const int64_t rounding = get_row_rounding(src0->type, tensor_split);\n\n            if (i != 0) {\n                dev[i].row_low  = ne01*tensor_split[i];\n                if (dev[i].row_low < ne01) {\n                    dev[i].row_low -= dev[i].row_low % rounding;\n                }\n            }\n\n            if (i != ggml_sycl_info().device_count - 1) {\n                dev[i].row_high  = ne01*tensor_split[i + 1];\n                if (dev[i].row_high < ne01) {\n                    dev[i].row_high -= dev[i].row_high % rounding;\n                }\n            }\n        }\n    }\n\n    constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,\n                                                      no_quantize_q8_1<QK8_1 / WARP_SIZE>>;\n    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n        if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {\n            continue;\n        }\n\n        used_devices++;\n\n        const bool src1_on_device = i == ctx.device;\n        const bool  dst_on_device = i == ctx.device;\n\n        ggml_sycl_set_device(i);\n        queue_ptr stream = ctx.stream(i, 0);\n\n        if (src0_is_contiguous) {\n            dev[i].src0_dd = (char *) src0->data;\n        } else {\n            dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));\n        }\n\n        if (src1_on_device && src1_is_contiguous) {\n            dev[i].src1_ddf = (float *) src1->data;\n        } else {\n            dev[i].src1_ddf = 
dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));\n        }\n\n        if constexpr(quantize_enabled) {\n            dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);\n\n            if (src1_on_device && src1_is_contiguous) {\n                scope_op_debug_print scope_dbg_print(__func__, \"/quantize_row_q8_1_sycl\", dst,\n                                                     /*num_src=*/2, \" : converting src1 to Q8_1\");\n                try {\n                    quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);\n                } catch (sycl::exception const &exc) {\n                    std::cerr << \"Quantize_row_q8_1_sycl error\" << exc.what() << \"Exception caught at file:\" << __FILE__\n                              << \", line:\" << __LINE__ << std::endl;\n                    std::exit(1);\n                }\n            }\n        }\n\n        if (dst_on_device) {\n            dev[i].dst_dd = (float *) dst->data;\n        } else {\n            const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);\n            dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);\n        }\n    }\n\n    // if multiple devices are used they need to wait for the main device\n    // here an event is recorded that signals that the main device has finished calculating the input data\n    if (split && used_devices > 1) {\n        ggml_sycl_set_device(ctx.device);\n        SYCL_CHECK(CHECK_TRY_ERROR(\n            *src0_extra->events[ctx.device][0] =\n                ctx.stream()->ext_oneapi_submit_barrier()));\n    }\n\n    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;\n    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {\n        const int64_t is = split ? 
(src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;\n        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;\n        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n            if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {\n                continue;\n            }\n\n            const bool src1_on_device = i == ctx.device;\n            const bool  dst_on_device = i == ctx.device;\n            const int64_t row_diff = dev[i].row_high - dev[i].row_low;\n\n            ggml_sycl_set_device(i);\n            queue_ptr stream = ctx.stream(i, is);\n\n            // wait for main GPU data if necessary\n            if (split && (i != ctx.device || is != 0)) {\n                SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(\n                    {*src0_extra->events[ctx.device][0]})));\n            }\n\n            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {\n                const int64_t i03 = i0 / ne12;\n                const int64_t i02 = i0 % ne12;\n\n                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;\n\n                // for split tensors the data begins at i0 == i0_offset_low\n                char  *  src0_dd_i =  dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;\n                float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;\n                char  * src1_ddq_i = dev[i].src1_ddq +  src1_ddq_i_offset;\n                float *   dst_dd_i =   dev[i].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? 
ne0 : row_diff);\n\n                // the main device memory buffer can be on VRAM scratch, with space for all partial results\n                // in that case an offset on dst_ddf_i is needed\n                if (i == ctx.device) {\n                    dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split\n                }\n\n                // copy src0, src1 to device if necessary\n                if (src1_is_contiguous) {\n                    if (i != ctx.device) {\n                        if constexpr (quantize_enabled) {\n                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;\n                            SYCL_CHECK(\n                                CHECK_TRY_ERROR(stream\n                                                    ->memcpy(src1_ddq_i, src1_ddq_i_source,\n                                                             src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)\n                                                    .wait()));\n                        } else {\n                            float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];\n                            src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;\n\n                            SYCL_CHECK(\n                                CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,\n                                                               src1_ncols * ne10 * sizeof(float))));\n                        }\n                    }\n                } else {\n                    if (src1_on_device) {\n                        SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,\n                                                           src1_col_0 + src1_ncols, stream));\n                    } else {\n                        GGML_ABORT(\"src1 is non-contiguous and not on device\");\n                    }\n\n                    if constexpr 
(quantize_enabled) {\n                        scope_op_debug_print scope_dbg_print(__func__, \"/quantize_row_q8_1_sycl\", dst,\n                                                             /*num_src=*/2, \" : converting src1 to Q8_1\");\n                        try {\n                            quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,\n                                                                  src1_padded_col_size, stream);\n                        } catch (const sycl::exception & exc) {\n                            std::cerr << \"Quantize_row_q8_1_sycl error\" << exc.what()\n                                      << \"Exception caught at file:\" << __FILE__ << \", line:\" << __LINE__ << std::endl;\n                            std::exit(1);\n                        }\n                    }\n                }\n\n                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {\n                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));\n                }\n                if (src1->type == GGML_TYPE_F16) {\n                    src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;\n                }\n                // do the computation\n                SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,\n                    dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));\n\n                // copy dst to host or other device if necessary\n                if (!dst_on_device) {\n                    void * dst_off_device = dst->data;\n                    if (split) {\n                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.\n                        // dst is NOT transposed.\n                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.\n    
                    // Instead they need to be copied to the correct slice in ne0 = dst row index.\n                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.\n                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);\n                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));\n                        dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;\n\n                        SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(\n                            dhf_dst_i, ne0 * sizeof(float), dst_dd_i,\n                            row_diff * sizeof(float), row_diff * sizeof(float),\n                            src1_ncols, dpct::device_to_device, *stream)));\n                    } else {\n                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);\n                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));\n                        dhf_dst_i += src1_col_0*ne0;\n                        SYCL_CHECK(CHECK_TRY_ERROR(\n                            stream->memcpy(dhf_dst_i, dst_dd_i,\n                                           src1_ncols * ne0 * sizeof(float)).wait()));\n                    }\n                }\n\n                // add event for the main device to wait on until other device is done\n                if (split && (i != ctx.device || is != 0)) {\n                    SYCL_CHECK(CHECK_TRY_ERROR(\n                        *src0_extra->events[i][is] =\n                            stream->ext_oneapi_submit_barrier()));\n                }\n            }\n        }\n    }\n\n    // main device waits for all other devices to be finished\n    if (split && ggml_sycl_info().device_count > 1) {\n        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;\n        is_max = is_max <= GGML_SYCL_MAX_STREAMS ? 
is_max : GGML_SYCL_MAX_STREAMS;\n\n        ggml_sycl_set_device(ctx.device);\n        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {\n            if (dev[i].row_low == dev[i].row_high) {\n                continue;\n            }\n            for (int64_t is = 0; is < is_max; ++is) {\n                SYCL_CHECK(CHECK_TRY_ERROR(\n                    ctx.stream()->ext_oneapi_submit_barrier(\n                        {*src0_extra->events[i][is]})));\n            }\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_repeat_back(ctx, dst);\n}\n\nstatic void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    ggml_sycl_op_get_rows(ctx, dst);\n}\n\nstatic void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_norm(ctx, dst);\n}\n\nstatic void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_rms_norm(ctx, dst);\n}\n\nstatic void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    ggml_sycl_op_rms_norm_back(ctx, dst);\n}\n\nstatic void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_l2_norm(ctx, dst);\n}\n\nstatic void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_group_norm(ctx, dst);\n}\n\nstatic void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,\n                                       const ggml_tensor *src1,\n                                       ggml_tensor *dst) try {\n    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));\n    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));\n    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation\n    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne01 = src0->ne[1];\n    const int64_t ne02 = src0->ne[2];\n\n    const int64_t ne12 = src1->ne[2];\n\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    queue_ptr main_stream = ctx.stream();\n\n    void  * src0_ddq = src0->data;\n    float * src1_ddf = (float *) src1->data;\n    float * dst_ddf  = (float *) dst->data;\n\n    ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,\n                                     const ggml_tensor *src1,\n                                     ggml_tensor *dst) try {\n    GGML_ASSERT(!ggml_is_transposed(src0));\n    GGML_ASSERT(!ggml_is_transposed(src1));\n    GGML_ASSERT(!ggml_is_permuted(src0));\n    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->ne[1] == 1);\n    
GGML_ASSERT(src1->ne[3] == 1);\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne01 = src0->ne[1];\n    const int64_t ne02 = src0->ne[2];\n\n    const int64_t nb01 = src0->nb[1];\n    const int64_t nb02 = src0->nb[2];\n\n    const int64_t ne12 = src1->ne[2];\n    const int64_t nb11 = src1->nb[1];\n\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    queue_ptr main_stream = ctx.stream();\n\n    void  * src0_ddq = src0->data;\n    float * src1_ddf = (float *) src1->data;\n    float * dst_ddf  = (float *) dst->data;\n\n    const int64_t row_stride_x = nb01 / sizeof(sycl::half);\n    const int64_t channel_stride_x = nb02 / sizeof(sycl::half);\n    const int64_t channel_stride_y = nb11 / sizeof(float);\n\n    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,\n                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,\n                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,\n                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {\n    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);\n    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);\n\n    if (i13 >= ne13 || i12 >= ne12) {\n        return;\n    }\n\n    const int64_t i03 = i13 / r3;\n    const int64_t i02 = i12 / r2;\n\n    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);\n    const uint8_t * src1_bytes = reinterpret_cast<const 
uint8_t *>(src1_as_f16);\n    uint8_t *       dst_bytes  = static_cast<uint8_t *>(dst);\n\n    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;\n    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;\n    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;\n}\n\nstatic void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,\n                                           const ggml_tensor * src1, ggml_tensor * dst) try {\n    GGML_ASSERT(!ggml_is_transposed(src0));\n    GGML_ASSERT(!ggml_is_transposed(src1));\n    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155\n    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    queue_ptr queue = ctx.stream();\n\n    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });\n\n    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);\n    float *            dst_ddf  = static_cast<float *>(dst->data);\n\n    const sycl::half * src1_f16       = static_cast<const sycl::half *>(src1->data);\n    const size_t       type_size_src0 = ggml_type_size(src0->type);\n    const size_t       type_size_src1 = ggml_type_size(src1->type);\n\n    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);\n    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);\n\n    // SRC1 strides\n    int64_t                          s11 = nb11 / type_size_src1;\n    int64_t                          s12 = nb12 / type_size_src1;\n    int64_t                          s13 = nb13 / type_size_src1;\n    ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());\n\n    // convert 
src1 to fp16\n    if (src1->type != GGML_TYPE_F16) {\n        scope_op_debug_print    scope_dbg_print(__func__, \"/to_fp16_nc_sycl\", dst, /*num_src=*/2,\n                                                \" : converting src1 to fp16\");\n\n        // iterate tensor dims and find the slowest moving dim and stride\n        int last_dim=0;\n        int last_str=0;\n        size_t largest_str=0;\n        for(int i = 0; i< 4; i++){\n            // last stride is always the largest\n            if(src1->nb[i] == largest_str){\n                if(src1->ne[last_dim] == 1){\n                    last_str = i;\n                    last_dim = i;\n                }\n            }\n            if(src1->nb[i] > largest_str){\n                largest_str = src1->nb[i];\n                last_str = i;\n                last_dim = i;\n            }\n\n        }\n#if GGML_SYCL_DNNL\n        // oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl\n        const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;\n        src1_f16_alloc.alloc(ne_src1);\n        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);\n        GGML_ASSERT(to_fp16_sycl != nullptr);\n        to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);\n# else\n        const int64_t ne_src1 = ggml_nelements(src1);\n        src1_f16_alloc.alloc(ne_src1);\n        const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);\n        GGML_ASSERT(to_fp16_nc_sycl != nullptr);\n        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);\n#endif\n\n        src1_f16 = src1_f16_alloc.get();\n        s11      = ne10;\n        s12      = ne11 * s11;\n        s13      = ne12 * s12;\n\n        is_src1_cont_2 = true;\n    }\n\n    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());\n\n    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;\n    dpct::library_data_t 
mkl_data_type    = dpct::library_data_t::real_float;\n\n    // dst strides\n    size_t nbd2 = dst->nb[2];\n    size_t nbd3 = dst->nb[3];\n\n    const float alpha_f32 = 1.0f;\n    const float beta_f32  = 0.0f;\n\n    const void * alpha = &alpha_f32;\n    const void * beta  = &beta_f32;\n\n    GGML_ASSERT(ne12 % ne02 == 0);\n    GGML_ASSERT(ne13 % ne03 == 0);\n    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));\n    GGML_ASSERT(ne10 == ne00);\n\n    // broadcast factors\n    const int64_t r2 = ne12 / ne02;\n    const int64_t r3 = ne13 / ne03;\n\n#if GGML_SYCL_DNNL\n    if (!g_ggml_sycl_disable_dnn) {\n            int64_t str_a0 = nb00 / type_size_src0;\n            int64_t str_a1 = nb01 / type_size_src0;\n            int64_t str_a2 = nb02 / type_size_src0;\n\n            int64_t str_b0 = nb10 / type_size_src1;\n            int64_t str_b1 = nb11 / type_size_src1;\n            int64_t str_b2 = nb12 / type_size_src1;\n\n            auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,\n                                                const sycl::half *src1, float *dst,\n                                                int64_t a0, int64_t a1, int64_t batcha,\n                                                int64_t /*b0*/, int64_t b1, int64_t batchb,\n                                                int64_t sa0, int64_t sa1, int64_t sa2,\n                                                int64_t sb0, int64_t sb1, int64_t sb2,\n                                                int64_t sd2) {\n                bool supported_broadcast = batchb == batcha ? true\n                        : batchb == 1 || batcha == 1        ? 
true\n                                                            : false;\n                if (supported_broadcast) {\n                    DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,\n                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,\n                            DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,\n                            DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);\n                } else {\n                    // iterate over batches from smaller set of matrices (matrix 0)\n                    int64_t batches0 = batcha;\n                    int64_t batches1 = batchb;\n\n                    if (batches0 > batches1) {\n                        int64_t num_mul_mats = batches1;\n                        int64_t sub_batch = batches0 / num_mul_mats;\n                        // src0 is batched and bigger, shift and multiply with src1\n                        for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {\n                            const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);\n                            const sycl::half *src1_shifted = src1 + (sb2 * i0);\n                            float *dst_shifted = dst + (sd2 * i0 * sub_batch);\n                            DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,\n                                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,\n                                    src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,\n                                    sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),\n                                    queue, sub_batch, 1);\n                        }\n                    } else {\n                        int64_t num_mul_mats = batches0;\n                        int64_t sub_batch = batches1 / num_mul_mats;\n                        // src1 is batched and bigger, shift and multiply with src0\n                        for (int64_t i1 = 0; i1 < num_mul_mats; 
i1++) {\n                            const sycl::half *src0_shifted = src0 + (sa2 * i1);\n                            const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);\n                            float *dst_shifted = dst + (sd2 * i1 * sub_batch);\n                            DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,\n                                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,\n                                    src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,\n                                    sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),\n                                    queue, 1, sub_batch);\n                        }\n                    }\n                }\n            };\n\n            const bool cont_batches_dim2_a = nb02 * ne02 == nb03;\n            const bool cont_batches_dim2_b = nb12 * ne12 == nb13;\n            const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;\n            const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;\n            if (cont_batches_dim2_a && cont_batches_dim2_b) {\n                // A batch is considered contiguous if the dimension 2 is not strided\n                int64_t batches0 = ne02 * ne03;\n                int64_t batches1 = ne12 * ne13;\n                launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,\n                        ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,\n                        str_b2, nb2 / sizeof(float));\n            } else if (cont_batches_dim3_a && cont_batches_dim3_b) {\n                // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.\n                int64_t batches0 = ne02 * ne03;\n                int64_t batches1 = ne12 * ne13;\n                int64_t str_a3 = nb03 / type_size_src0;\n                int64_t str_b3 = nb13 / type_size_src1;\n                
launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,\n                        ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,\n                        str_b3, nb2 / sizeof(float));\n            } else {\n                for (int64_t b_a = 0; b_a < ne03; b_a++) {\n                    const sycl::half *src0_f16_shifted\n                            = src0_f16 + (nb03 * b_a / type_size_src0);\n                    const sycl::half *src1_f16_shifted\n                            = src1_f16 + (nb13 * b_a / type_size_src1);\n                    float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));\n                    int64_t batches0 = ne02;\n                    int64_t batches1 = ne12;\n                    launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,\n                            ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,\n                            str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));\n                }\n            }\n\n    }\n    else\n#endif\n    {\n        if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {\n            // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:\n            const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;\n            const int64_t smb = ne12 == 1 ? 
s13       : s12;\n\n            // there is no broadcast and src0, src1 are contiguous across dims 2, 3\n            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans,\n                                                        oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,\n                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,\n                                                        src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,\n                                                        mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));\n        } else {\n            const int ne23 = ne12 * ne13;\n\n            ggml_sycl_pool_alloc<const void *>         ptrs_src(ctx.pool(), 2 * ne23);\n            ggml_sycl_pool_alloc<void *>               ptrs_dst(ctx.pool(), 1 * ne23);\n            ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);\n\n            sycl::range<3> block_dims(1, ne12, ne13);\n            queue->submit([&](sycl::handler & cgh) {\n                const void ** ptrs_src_get = ptrs_src.get();\n                void **       ptrs_dst_get = ptrs_dst.get();\n                size_t        nb12_scaled  = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);\n                size_t        nb13_scaled  = src1->type == GGML_TYPE_F16 ? 
nb13 : s13 * sizeof(sycl::half);\n                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {\n                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,\n                                           nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);\n                });\n            });\n\n            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(\n                *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,\n                (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,\n                (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,\n                (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));\n        }\n    }\n} catch (const sycl::exception & exc) {\n    std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__ << \", line:\" << __LINE__ << std::endl;\n    std::exit(1);\n}\n\nenum class mul_mat_algo {\n    DMMV         = 0,\n    MMVQ         = 1,\n    MUL_MAT_SYCL = 2,\n};\n\ninline bool ggml_sycl_supports_mmq(enum ggml_type type) {\n    // TODO: accuracy issues in MMQ\n    GGML_UNUSED(type);\n    return false;\n}\n\ninline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n            return true;\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q6_K:\n            return !g_ggml_sycl_prioritize_dmmv;\n        default:\n            return false;\n    }\n}\n\ninline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n            return true;\n        default:\n            return false;\n    }\n}\n\ninline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {\n    switch (type) {\n        case 
GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q6_K:\n            return true;\n        default:\n            return false;\n    }\n}\n\nstatic bool ggml_sycl_supports_dmmv(enum ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_F16:\n            return true;\n        default:\n            return false;\n    }\n}\n\n// Helper functions to unify device memory allocation for both async and sync paths\nstatic inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {\n    bool use_async = g_ggml_sycl_use_async_mem_op;\n#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC\n    if (use_async) {\n        return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);\n    }\n#else\n    // If async allocation extension is not available, use_async should always be false.\n    GGML_ASSERT(!use_async);\n#endif\n    return sycl::malloc(size, *stream, sycl::usm::alloc::device);\n}\n\nstatic inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {\n    bool use_async = g_ggml_sycl_use_async_mem_op;\n#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC\n    if (use_async) {\n        syclex::async_free(*stream, ptr);\n        return;\n    }\n#else\n    // If async allocation extension is not available, use_async should always be false.\n    GGML_ASSERT(!use_async);\n#endif\n    sycl::free(ptr, *stream);\n}\n\nstatic void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,\n                            dpct::queue_ptr stream) {\n    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));\n\n    sycl::event copy_event;\n    
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));\n    if (!g_ggml_sycl_use_async_mem_op) {\n        copy_event.wait();\n    }\n\n    GGML_ASSERT((size % sizeof(block_q4_0) == 0));\n    GGML_ASSERT((offset % sizeof(block_q4_0) == 0));\n    int offset_blks = offset / sizeof(block_q4_0);\n    auto qs_ptr      = data_device + offset_blks * QK4_0 / 2;\n    auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;\n\n    auto reorder_event = stream->parallel_for(\n        size / sizeof(block_q4_0),\n            [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n            const block_q4_0* x = (const block_q4_0*)tmp_buf;\n            const int ib = i;\n\n            for (int j = 0; j < QK4_0/2; j ++)\n            {\n                *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];\n            }\n            *(d_ptr + ib) = x[ib].d;\n        });\n    if (!g_ggml_sycl_use_async_mem_op) {\n        reorder_event.wait_and_throw();\n    }\n    sycl_ext_free(stream, tmp_buf);\n}\n\nstatic void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {\n    GGML_ASSERT(size % sizeof(block_q4_K) == 0);\n    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);\n\n    const int nblocks = size / sizeof(block_q4_K);\n\n    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));\n\n    sycl::event copy_event;\n    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));\n    if (!g_ggml_sycl_use_async_mem_op) {\n        copy_event.wait();\n    }\n\n    auto * qs_ptr     = data_device;\n    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;\n    auto * dm_ptr     = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);\n\n    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {\n        const block_q4_K * x  = (const block_q4_K *) tmp_buf;\n        const int          ib = i;\n\n        for (int j = 0; j < QK_K / 2; ++j) {\n            
qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];\n        }\n\n        for (int j = 0; j < K_SCALE_SIZE; ++j) {\n            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];\n        }\n\n        dm_ptr[ib] = x[ib].dm;\n    });\n    if (!g_ggml_sycl_use_async_mem_op) {\n        reorder_event.wait_and_throw();\n    }\n    sycl_ext_free(stream, tmp_buf);\n}\n\nstatic void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {\n    GGML_ASSERT(size % sizeof(block_q6_K) == 0);\n    GGML_ASSERT(offset % sizeof(block_q6_K) == 0);\n\n    const int nblocks = size / sizeof(block_q6_K);\n\n    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));\n\n    sycl::event copy_event;\n    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));\n    if (!g_ggml_sycl_use_async_mem_op) {\n        copy_event.wait();\n    }\n\n    auto *       ql_ptr     = data_device;\n    auto *       qh_ptr     = ql_ptr + (QK_K / 2) * nblocks;\n    auto *       scales_ptr = qh_ptr + (QK_K / 4) * nblocks;\n    sycl::half * dm_ptr     = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);\n\n    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {\n        const block_q6_K * x  = (const block_q6_K *) tmp_buf;\n        const int          ib = i;\n\n        const uint8_t * ql              = x[ib].ql;\n        const uint8_t * qh              = x[ib].qh;\n        uint8_t *       base_ql_ptr     = ql_ptr + (QK_K / 2) * ib;\n        uint8_t *       base_qh_ptr     = qh_ptr + (QK_K / 4) * ib;\n        uint8_t *       base_scales_ptr = scales_ptr + (QK_K / 16) * ib;\n\n        for (int j = 0; j < QK_K / 2; ++j) {\n            base_ql_ptr[j] = ql[j];\n        }\n        for (int j = 0; j < QK_K / 4; ++j) {\n            base_qh_ptr[j] = qh[j];\n        }\n\n        for (int j = 0; j < QK_K / 16; ++j) {\n            base_scales_ptr[j] = x[ib].scales[j];\n        }\n\n        dm_ptr[ib] = x[ib].d;\n    });\n   
 if (!g_ggml_sycl_use_async_mem_op) {\n        reorder_event.wait_and_throw();\n    }\n    sycl_ext_free(stream, tmp_buf);\n}\n\nstatic void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {\n    uint8_t * data_device = (uint8_t *) src0->data;\n    size_t ncols = src0->ne[0];\n    size_t nrows = src0->ne[1];\n    size_t size = ggml_nbytes(src0);\n\n    switch (src0->type) {\n        case GGML_TYPE_Q4_0:\n            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);\n            break;\n        case GGML_TYPE_Q4_K:\n            reorder_qw_q4_k(data_device, size, 0, stream);\n            break;\n        case GGML_TYPE_Q6_K:\n            reorder_qw_q6_k(data_device, size, 0, stream);\n            break;\n        default:\n            GGML_ABORT(\"reorder_qw() called with unsupported type\");\n            break;\n    }\n}\n\nstatic bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {\n    return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT\n            ctx.opt_feature.reorder &&      //allow this device due to good perf, skip the devices with bad perf.\n            dst->op == GGML_OP_MUL_MAT &&   //limit to some supported cases of Q4_0, to do for more cases.\n            dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;\n}\n\nstatic void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,\n                            ggml_tensor * dst, mul_mat_algo mm_algorithm) {\n    if (!should_reorder_tensor(*ctx, dst)) {\n        return;\n    }\n\n    ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);\n    if (!extra || extra->optimized_feature.reorder) {\n        return;  // Skip permutations and already reordered tensors\n    }\n\n    switch (mm_algorithm) {\n        case mul_mat_algo::DMMV:\n            if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {\n                
return;\n            }\n            break;\n        case mul_mat_algo::MMVQ:\n            if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {\n                return;\n            }\n            break;\n        case mul_mat_algo::MUL_MAT_SYCL:\n            if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {\n                return;\n            }\n            break;\n    }\n\n    reorder_qw(src0, ctx->stream());\n    extra->optimized_feature.reorder = true;  // Used to decode/dequan in next steps and avoid re-reordering\n}\n\n\nstatic bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&\n           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;\n}\n\nstatic bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&\n           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;\n}\n\nstatic void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);\n    int64_t min_compute_capability = INT_MAX;\n\n    if (split) {\n        ggml_backend_sycl_split_buffer_type_context * buft_ctx =\n            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;\n        auto & tensor_split = buft_ctx->tensor_split;\n        for (int id = 0; id < ggml_sycl_info().device_count; ++id) {\n            // skip devices that are not going to do any work:\n            if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? 
tensor_split[id + 1] : 1.0f)) {\n                continue;\n            }\n\n            if (min_compute_capability > ggml_sycl_info().devices[id].cc) {\n                min_compute_capability = ggml_sycl_info().devices[id].cc;\n            }\n        }\n    } else {\n        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;\n    }\n\n    // check data types and tensor shapes for custom matrix multiplication kernels:\n    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);\n\n    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);\n\n    bool use_mul_mat_q =  ggml_sycl_supports_mmq(src0->type)\n        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;\n\n\n    // mmvq and mmq need the __dp4a instruction which is available for gen12+\n    // Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e\n    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);\n#ifdef SYCL_USE_XMX\n    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);\n#endif // SYCL_USE_XMX\n\n    // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization\n    // is enabled takes precedence over DMMV, the current if-else implementation\n    // requires disabling DMMV if both conditions are met\n    if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&\n                                          ggml_sycl_supports_reorder_mmvq(src0->type)))) {\n        use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;\n    }\n\n    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {\n        // TODO: Refactor and cleanup of mul mat dispatching.\n        if (src0->ne[3] == 1 && src1->ne[3] == 1) {\n            // KQ single-batch\n            // mmv p021 was specific for these dimensions\n            ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, 
dst);\n        } else {\n            // The kernel from the if path is faster for that specific case, but does not support all mul mats.\n            ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);\n        }\n    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {\n        // KQV single-batch\n        ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);\n    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {\n        // KQ + KQV multi-batch\n        ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);\n    } else if (use_dequantize_mul_mat_vec) {\n        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);\n        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);\n    } else if (use_mul_mat_vec_q) {\n        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);\n        ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);\n        if (extra && extra->optimized_feature.reorder) {\n            ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);\n        } else {\n            ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);\n        }\n    } else if (use_mul_mat_q) {\n        ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);\n    } else {\n        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);\n    }\n}\n\n\nstruct mmid_row_mapping {\n    int32_t i1;\n    int32_t i2;\n};\n\n__dpct_inline__ static void k_copy_src1_to_contiguous(\n    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,\n    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,\n    const char *__restrict ids, 
int64_t i02, size_t ids_nb1, size_t ids_nb0,\n    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,\n    const sycl::nd_item<3> &item_ct1, int &src1_row) {\n    int32_t iid1 = item_ct1.get_group(2);\n    int32_t id = item_ct1.get_group(1);\n\n    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);\n\n    if (row_id_i != i02) {\n        return;\n    }\n\n    const int64_t i11 = id % ne11;\n    const int64_t i12 = iid1;\n\n    if (item_ct1.get_local_id(2) == 0) {\n        src1_row =\n            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(\n                cur_src1_row, 1);\n        row_mapping[src1_row] = {id, iid1};\n    }\n    /*\n    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with\n    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better\n    performance if there is no access to global memory.\n    */\n    item_ct1.barrier();\n\n    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);\n    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);\n\n#pragma unroll\n    for (int i = item_ct1.get_local_id(2); i < ne10;\n         i += item_ct1.get_local_range(2)) {\n        src1_row_contiguous[i] = src1_row_original[i];\n    }\n}\n\n__dpct_inline__ static void k_copy_dst_from_contiguous(\n    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,\n    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,\n    size_t nb2, const sycl::nd_item<3> &item_ct1) {\n    int32_t i = item_ct1.get_group(2);\n\n    const int32_t i1 = row_mapping[i].i1;\n    const int32_t i2 = row_mapping[i].i2;\n\n    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);\n    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);\n\n#pragma unroll\n    for (int j = item_ct1.get_local_id(2); j < ne0;\n         j += item_ct1.get_local_range(2)) {\n        dst_row_original[j] = 
dst_row_contiguous[j];\n    }\n}\n\nstatic void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,\n                                 ggml_tensor *dst) try {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);\n    const ggml_tensor *src0 = dst->src[0];\n    const ggml_tensor *src1 = dst->src[1];\n    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && \"mul_mat_id does not support split buffers\");\n\n    const ggml_tensor *ids = dst->src[2];\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const queue_ptr stream = ctx.stream();\n\n    const int64_t n_as = ne02;\n    const int64_t n_ids = ids->ne[0];\n\n    std::vector<char> ids_host(ggml_nbytes(ids));\n    const char * ids_dev = (const char *) ids->data;\n\n    SYCL_CHECK(CHECK_TRY_ERROR(\n        stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));\n    SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));\n\n    ggml_tensor src0_row = *src0;\n    ggml_tensor src1_row = *src1;\n    ggml_tensor dst_row = *dst;\n\n    char *src0_original = (char *)src0->data;\n    char *src1_original = (char *)src1->data;\n    char *dst_original = (char *)dst->data;\n\n    src0_row.ne[2] = 1;\n    src0_row.ne[3] = 1;\n    src0_row.nb[3] = nb02;\n\n    src1_row.ne[1] = 1;\n    src1_row.ne[2] = 1;\n    src1_row.ne[3] = 1;\n    src1_row.nb[2] = nb11;\n    src1_row.nb[3] = nb11;\n\n    dst_row.ne[1] = 1;\n    dst_row.ne[2] = 1;\n    dst_row.ne[3] = 1;\n    dst_row.nb[2] = nb1;\n    dst_row.nb[3] = nb1;\n    if (ne12 == 1) {\n        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {\n            for (int64_t id = 0; id < n_ids; id++) {\n                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);\n                GGML_ASSERT(i02 >= 0 && i02 < n_as);\n\n                const int64_t i11 = id % ne11;\n                const int64_t i12 = iid1;\n\n                const int64_t i1 = id;\n                const int64_t i2 = i12;\n\n            src0_row.data = 
src0_original + i02*nb02;\n            src1_row.data = src1_original + i11*nb11 + i12*nb12;\n            dst_row.data = dst_original + i1*nb1 + i2*nb2;\n\n            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);\n            }\n        }\n    } else {\n        ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));\n        ggml_sycl_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));\n\n        src1_row.data = src1_contiguous.get();\n        dst_row.data  =  dst_contiguous.get();\n\n        for (int64_t i02 = 0; i02 < n_as; i02++) {\n            int64_t num_src1_rows = 0;\n            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {\n                for (int64_t id = 0; id < n_ids; id++) {\n                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);\n\n                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);\n\n                    if (row_id_i != i02) {\n                        continue;\n                    }\n\n                    num_src1_rows++;\n                }\n            }\n\n            if (num_src1_rows == 0) {\n                continue;\n            }\n\n\n            ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);\n            ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);\n            SYCL_CHECK(CHECK_TRY_ERROR(\n                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));\n\n            const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];\n            assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);\n\n            {\n                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));\n                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);\n                stream->submit([&](sycl::handler &cgh) {\n                    sycl::local_accessor<int, 0> 
src1_row_acc(cgh);\n\n                    char *__restrict src1_contiguous_get =\n                        src1_contiguous.get();\n                    int *__restrict dev_cur_src1_row_get =\n                        dev_cur_src1_row.get();\n                    mmid_row_mapping *__restrict dev_row_mapping_get =\n                        dev_row_mapping.get();\n                    size_t ids_nb_ct6 = ids->nb[1];\n                    size_t ids_nb_ct7 = ids->nb[0];\n\n                    cgh.parallel_for(\n                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                        [=](sycl::nd_item<3> item_ct1) {\n                            k_copy_src1_to_contiguous(\n                                src1_original, src1_contiguous_get,\n                                dev_cur_src1_row_get,\n                                dev_row_mapping_get, ids_dev, i02,\n                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,\n                                item_ct1, src1_row_acc);\n                        });\n                });\n            }\n\n            src0_row.data = src0_original + i02*nb02;\n\n            GGML_ASSERT(nb11 == sizeof(float)*ne10);\n            GGML_ASSERT(nb1 == sizeof(float)*ne0);\n            src1_row.ne[1] = num_src1_rows;\n\n            src1_row.nb[1] = nb11;\n            src1_row.nb[2] = num_src1_rows*nb11;\n            src1_row.nb[3] = num_src1_rows*nb11;\n\n            dst_row.ne[1] = num_src1_rows;\n            dst_row.nb[1] = nb1;\n            dst_row.nb[2] = num_src1_rows*nb1;\n            dst_row.nb[3] = num_src1_rows*nb1;\n\n            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);\n\n            {\n                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));\n                sycl::range<3> grid_dims(1, 1, num_src1_rows);\n                stream->submit([&](sycl::handler &cgh) {\n                    const char *__restrict dst_contiguous_get =\n     
                   dst_contiguous.get();\n                    const mmid_row_mapping *__restrict dev_row_mapping_get =\n                        dev_row_mapping.get();\n\n                    cgh.parallel_for(\n                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                        [=](sycl::nd_item<3> item_ct1) {\n                            k_copy_dst_from_contiguous(dst_original,\n                                                       dst_contiguous_get,\n                                                       dev_row_mapping_get,\n                                                       ne0, nb1, nb2, item_ct1);\n                        });\n                });\n            }\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_scale(ctx, dst);\n}\n\nstatic void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_diag_mask_inf(ctx, dst);\n}\n\nstatic void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_pool2d(ctx, dst);\n}\n\nstatic void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    ggml_sycl_op_im2col(ctx, dst);\n}\n\nstatic void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));\n    ggml_sycl_op_sum(ctx, dst);\n}\n\nstatic void 
ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));\n    ggml_sycl_op_sum_rows(ctx, dst);\n}\n\nstatic void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));\n    ggml_sycl_op_mean(ctx, dst);\n}\n\nstatic void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));\n    ggml_sycl_op_argsort(ctx, dst);\n}\n\nstatic void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));\n    ggml_sycl_op_argmax(ctx, dst);\n}\n\n\nstatic void ggml_sycl_set_main_device(const int main_device) try {\n    if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {\n        return;\n    }\n    check_allow_gpu_index(main_device);\n    dpct::select_device(main_device);\n\n    if (g_ggml_sycl_debug) {\n        dpct::device_info prop;\n        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(\n            prop, dpct::dev_mgr::instance().get_device(main_device))));\n        GGML_LOG_INFO(\"Using device %d (%s) as main device\\n\",\n                main_device, prop.get_name());\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {\n    if (!g_sycl_loaded) return false;\n\n    if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {\n        
ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);\n    }\n\n    switch (dst->op) {\n        case GGML_OP_ARGMAX:\n            ggml_sycl_argmax(ctx, dst);\n            break;\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            ggml_sycl_op_conv_transpose_1d(ctx, dst);\n            break;\n        case GGML_OP_REPEAT:\n            ggml_sycl_repeat(ctx, dst);\n            break;\n        case GGML_OP_REPEAT_BACK:\n            ggml_sycl_repeat_back(ctx, dst);\n            break;\n        case GGML_OP_GET_ROWS:\n            ggml_sycl_get_rows(ctx, dst);\n            break;\n        case GGML_OP_SET:\n            ggml_sycl_op_set(ctx, dst);\n            break;\n        case GGML_OP_SET_ROWS:\n            ggml_sycl_op_set_rows(ctx, dst);\n            break;\n        case GGML_OP_DUP:\n            ggml_sycl_dup(ctx, dst);\n            break;\n        case GGML_OP_ADD:\n        case GGML_OP_ADD1: // TODO: more efficient implementation\n            ggml_sycl_add(ctx, dst);\n            break;\n        case GGML_OP_ADD_ID:\n            ggml_sycl_add_id(ctx, dst);\n            break;\n        case GGML_OP_SUB:\n            ggml_sycl_sub(ctx, dst);\n            break;\n        case GGML_OP_COUNT_EQUAL:\n            ggml_sycl_count_equal(ctx, dst);\n            break;\n        case GGML_OP_ACC:\n            ggml_sycl_acc(ctx, dst);\n            break;\n        case GGML_OP_MUL:\n            ggml_sycl_mul(ctx, dst);\n            break;\n        case GGML_OP_LOG:\n            ggml_sycl_log(ctx, dst);\n            break;\n        case GGML_OP_DIV:\n            ggml_sycl_div(ctx, dst);\n            break;\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(dst)) {\n                case GGML_UNARY_OP_NEG:\n                    ggml_sycl_neg(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_STEP:\n                    ggml_sycl_step(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_GELU:\n                    
ggml_sycl_gelu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SILU:\n                    ggml_sycl_silu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_GELU_QUICK:\n                    ggml_sycl_gelu_quick(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_GELU_ERF:\n                    ggml_sycl_gelu_erf(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_TANH:\n                    ggml_sycl_tanh(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_RELU:\n                    ggml_sycl_relu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SIGMOID:\n                    ggml_sycl_sigmoid(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_HARDSIGMOID:\n                    ggml_sycl_hardsigmoid(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_HARDSWISH:\n                    ggml_sycl_hardswish(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_EXP:\n                    ggml_sycl_exp(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SOFTPLUS:\n                    ggml_sycl_softplus(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_SGN:\n                    ggml_sycl_sgn(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_ABS:\n                    ggml_sycl_abs(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_ELU:\n                    ggml_sycl_elu(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_FLOOR:\n                    ggml_sycl_floor(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_CEIL:\n                    ggml_sycl_ceil(ctx, dst);\n                    break;\n                case GGML_UNARY_OP_ROUND:\n                    ggml_sycl_round(ctx, dst);\n                    break;\n                
case GGML_UNARY_OP_TRUNC:\n                    ggml_sycl_trunc(ctx, dst);\n                    break;\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(dst)) {\n                case GGML_GLU_OP_REGLU:\n                    ggml_sycl_reglu(ctx, dst);\n                    break;\n                case GGML_GLU_OP_GEGLU:\n                    ggml_sycl_geglu(ctx, dst);\n                    break;\n                case GGML_GLU_OP_SWIGLU:\n                    ggml_sycl_swiglu(ctx, dst);\n                    break;\n                case GGML_GLU_OP_SWIGLU_OAI:\n                    ggml_sycl_swiglu_oai(ctx, dst);\n                    break;\n                case GGML_GLU_OP_GEGLU_ERF:\n                    ggml_sycl_geglu_erf(ctx, dst);\n                    break;\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    ggml_sycl_geglu_quick(ctx, dst);\n                    break;\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_NORM:\n            ggml_sycl_norm(ctx, dst);\n            break;\n        case GGML_OP_GROUP_NORM:\n            ggml_sycl_group_norm(ctx, dst);\n            break;\n        case GGML_OP_CONCAT:\n            ggml_sycl_op_concat(ctx, dst);\n            break;\n        case GGML_OP_PAD_REFLECT_1D:\n            ggml_sycl_op_pad_reflect_1d(ctx,dst);\n            break;\n        case GGML_OP_UPSCALE:\n            ggml_sycl_upscale(ctx, dst);\n            break;\n        case GGML_OP_PAD:\n            ggml_sycl_pad(ctx, dst);\n            break;\n        case GGML_OP_LEAKY_RELU:\n            ggml_sycl_leaky_relu(ctx, dst);\n            break;\n        case GGML_OP_RMS_NORM_BACK:\n            ggml_sycl_rms_norm_back(ctx, dst);\n            break;\n        case GGML_OP_RMS_NORM:\n            ggml_sycl_rms_norm(ctx, dst);\n            break;\n        case GGML_OP_L2_NORM:\n  
          ggml_sycl_l2_norm(ctx, dst);\n            break;\n        case GGML_OP_MUL_MAT:\n            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {\n                return false;\n            }\n            /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */\n            ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);\n            break;\n        case GGML_OP_MUL_MAT_ID:\n            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {\n                return false;\n            }\n            ggml_sycl_mul_mat_id(ctx, dst);\n            break;\n        case GGML_OP_OUT_PROD:\n            ggml_sycl_op_out_prod(ctx, dst);\n            break;\n        case GGML_OP_SCALE:\n            ggml_sycl_scale(ctx, dst);\n            break;\n        case GGML_OP_SQR:\n            ggml_sycl_sqr(ctx, dst);\n            break;\n        case GGML_OP_SQRT:\n            ggml_sycl_sqrt(ctx, dst);\n            break;\n        case GGML_OP_SIN:\n            ggml_sycl_sin(ctx, dst);\n            break;\n        case GGML_OP_COS:\n            ggml_sycl_cos(ctx, dst);\n            break;\n        case GGML_OP_CLAMP:\n            ggml_sycl_clamp(ctx, dst);\n            break;\n        case GGML_OP_CPY:\n            ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);\n            break;\n        case GGML_OP_CONT:\n            ggml_sycl_dup(ctx, dst);\n            break;\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n            GGML_SYCL_DEBUG(\"%s: Tensor NO-OP\\n\", __func__);\n            break;\n        case GGML_OP_TRI:\n            ggml_sycl_op_tri(ctx, dst);\n            break;\n        case GGML_OP_DIAG_MASK_INF:\n            ggml_sycl_diag_mask_inf(ctx, dst);\n            break;\n        case GGML_OP_SOFT_MAX:\n            ggml_sycl_op_soft_max(ctx, dst);\n            break;\n        case GGML_OP_SOFT_MAX_BACK:\n            ggml_sycl_op_soft_max_back(ctx, dst);\n         
   break;\n        case GGML_OP_ROPE:\n            ggml_sycl_rope(ctx, dst);\n            break;\n        case GGML_OP_ROPE_BACK:\n            ggml_sycl_rope_back(ctx, dst);\n            break;\n        case GGML_OP_IM2COL:\n            ggml_sycl_im2col(ctx, dst);\n            break;\n        case GGML_OP_POOL_2D:\n            ggml_sycl_pool2d(ctx, dst);\n            break;\n        case GGML_OP_SUM:\n            ggml_sycl_sum(ctx, dst);\n            break;\n        case GGML_OP_SUM_ROWS:\n            ggml_sycl_sum_rows(ctx, dst);\n            break;\n        case GGML_OP_MEAN:\n            ggml_sycl_mean(ctx, dst);\n            break;\n        case GGML_OP_ARGSORT:\n            ggml_sycl_argsort(ctx, dst);\n            break;\n        case GGML_OP_TOP_K:\n            ggml_sycl_op_top_k(ctx, dst);\n            break;\n        case GGML_OP_TIMESTEP_EMBEDDING:\n            ggml_sycl_op_timestep_embedding(ctx, dst);\n            break;\n        case GGML_OP_RWKV_WKV6:\n            ggml_sycl_op_rwkv_wkv6(ctx, dst);\n            break;\n        case GGML_OP_RWKV_WKV7:\n            ggml_sycl_op_rwkv_wkv7(ctx, dst);\n            break;\n        case GGML_OP_GATED_LINEAR_ATTN:\n            ggml_sycl_op_gated_linear_attn(ctx, dst);\n            break;\n        case GGML_OP_GATED_DELTA_NET:\n            ggml_sycl_gated_delta_net(ctx, dst);\n            break;\n        case GGML_OP_SSM_CONV:\n            ggml_sycl_ssm_conv(ctx, dst);\n            break;\n        case GGML_OP_ROLL:\n            ggml_sycl_roll(ctx, dst);\n            break;\n        case GGML_OP_ARANGE:\n            ggml_sycl_arange(ctx, dst);\n            break;\n        case GGML_OP_FLASH_ATTN_EXT:\n            ggml_sycl_flash_attn_ext(ctx, dst);\n            break;\n        default:\n            return false;\n    }\n\n    return true;\n} catch (sycl::exception & e) {\n    std::cerr << e.what() << \"Exception caught at file:\" << __FILE__ << \", line:\" << __LINE__ << std::endl;\n    std::cerr << \"Error OP 
\"<<ggml_op_name(dst->op)<< std::endl;\n    std::exit(1);\n}\n\nGGML_API void ggml_backend_sycl_get_device_description(int device, char *description,\n                                      size_t description_size) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call ggml_backend_sycl_get_device_description\\n\");\n    dpct::device_info prop;\n    SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(\n        prop, dpct::dev_mgr::instance().get_device(device))));\n    snprintf(description, description_size, \"%s\", prop.get_name());\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nvoid ggml_backend_sycl_get_device_memory(int device, size_t *free,\n                                                   size_t *total) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call ggml_backend_sycl_get_device_memory\\n\");\n    ggml_sycl_set_device(device);\n\n    SYCL_CHECK(CHECK_TRY_ERROR(\n        dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\n// backend\n\nstatic const char * ggml_backend_sycl_get_name(ggml_backend_t backend) {\n\n    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;\n\n    return sycl_ctx->name.c_str();\n}\n\nstatic void ggml_backend_sycl_free(ggml_backend_t backend) {\n    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;\n\n    delete sycl_ctx;\n    delete backend;\n}\n\nstatic void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,\n                                               ggml_tensor *tensor,\n                                               const void *data, 
size_t offset,\n                                               size_t size) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor).c_str());\n    GGML_SYCL_DEBUG(\" size=%zu offset=%zu\\n\", size, offset);\n    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;\n    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;\n\n    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && \"unsupported buffer type\");\n    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);\n    SYCL_CHECK(CHECK_TRY_ERROR(\n        (stream)->memcpy((char *)tensor->data + offset, data, size)));\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,\n                                               const ggml_tensor *tensor,\n                                               void *data, size_t offset,\n                                               size_t size) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": tensor\", tensor).c_str());\n    GGML_SYCL_DEBUG(\" size=%zu offset=%zu\\n\", size, offset);\n    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;\n    ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer;\n\n    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && \"unsupported buffer type\");\n    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);\n    SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(\n        data, (const char *)tensor->data + offset, size)));\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,\n                                               const ggml_tensor *src,\n                                               ggml_tensor *dst) try {\n    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;\n    bool is_cpy_supported                = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) &&\n                            ggml_backend_buffer_is_sycl(src->buffer);\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\", __func__);\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\": dst\", dst).c_str());\n    GGML_SYCL_DEBUG(\"%s\", debug_get_tensor_str(\" src\", src).c_str());\n    GGML_SYCL_DEBUG(\" is_cpy_supported=%d\\n\", is_cpy_supported);\n    if (is_cpy_supported) {\n        /*\n        DPCT1009:215: SYCL uses exceptions to report errors and does not use the\n        error codes. The original code was commented out and a warning string\n        was inserted. 
You need to rewrite this code.\n        */\n        const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);\n        SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(\n            dst->data, src->data, ggml_nbytes(dst))));\n        return true;\n    }\n\n    return false;\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\\n\", __func__);\n    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;\n    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);\n    SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));\n\n    GGML_UNUSED(backend);\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {\n    ggml_sycl_set_main_device(sycl_ctx->device);\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        ggml_tensor * node = cgraph->nodes[i];\n        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {\n            continue;\n        }\n        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n            continue;\n        }\n#ifndef NDEBUG\n        assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            if (node->src[j] != nullptr) {\n                assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));\n            }\n        }\n#endif\n        bool ok = ggml_sycl_compute_forward(*sycl_ctx, 
node);\n        if (!ok) {\n            GGML_LOG_ERROR(\"%s: error: op not supported %s (%s)\\n\", __func__, node->name, ggml_op_name(node->op));\n        }\n        GGML_ASSERT(ok);\n    }\n}\n\n#ifdef GGML_SYCL_GRAPH\nstatic bool check_graph_compatibility(ggml_cgraph * cgraph) {\n    if (ggml_sycl_info().device_count > 1) {\n        // A sycl_ex::command_graph object can only be created for a single device\n        GGML_LOG_INFO(\"%s: disabling SYCL graphs due to multiple devices\\n\", __func__);\n        return false;\n    }\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        const ggml_op node_op = cgraph->nodes[i]->op;\n        switch (node_op) {\n            default:\n                break;\n            case GGML_OP_CONCAT:\n                // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,\n                // but wait() can't be called on the events returned by a queue recording\n                // to a graph.\n                [[fallthrough]];\n            case GGML_OP_MUL_MAT_ID:\n                // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after\n                // submitting a memcpy operation, but wait() can't be called on a queue that\n                // is recording to a graph.\n                GGML_LOG_INFO(\"%s: disabling SYCL graphs due to unsupported node type %s\\n\", __func__,\n                              ggml_op_name(node_op));\n                return false;\n            case GGML_OP_MUL_MAT:\n                // We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,\n                // as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present\n                // in reordering.\n                if (!g_ggml_sycl_use_async_mem_op) {\n                    GGML_LOG_INFO(\n                        \"%s: disabling SYCL graphs due to unsupported node type when using a compiler without the \"\n 
                       \"oneAPI async memory allocation extension \"\n                        \"%s\\n\",\n                        __func__, ggml_op_name(node_op));\n                    return false;\n                }\n        }\n    }\n    return true;\n}\n#endif\n\nstatic ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);\n\n#ifdef GGML_SYCL_GRAPH\n    bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);\n    if (use_sycl_graph) {\n        const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);\n        if (!graph_support) {\n            GGML_SYCL_DEBUG(\"[SYCL-GRAPH] can not use graphs on device:%d\\n\", sycl_ctx->device);\n            ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);\n            return GGML_STATUS_SUCCESS;\n        }\n\n        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});\n\n        model_sycl_graph.begin_recording(*(sycl_ctx->stream()));\n        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);\n        model_sycl_graph.end_recording();\n\n        const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);\n        if (!sycl_ctx->exec_graph || !graph_update_support) {\n            auto exec_graph = graph_update_support ? 
model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :\n                                                     model_sycl_graph.finalize();\n            sycl_ctx->exec_graph = std::make_unique<\n                sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);\n        } else {\n            try {\n                sycl_ctx->exec_graph->update(model_sycl_graph);\n                GGML_SYCL_DEBUG(\"[SYCL-GRAPH] update success\\n\");\n            } catch (sycl::exception const & e) {\n                GGML_SYCL_DEBUG(\"[SYCL-GRAPH] Exception when updating graph, %s\\n\", e.what());\n                auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});\n                sycl_ctx->exec_graph = std::make_unique<\n                    sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);\n            }\n        }\n\n        sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));\n    } else\n#endif\n    {\n        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event)\ntry\n{\n    ggml_backend_sycl_context *sycl_ctx =\n        (ggml_backend_sycl_context *)backend->context;\n\n    sycl::event *sycl_event = static_cast<sycl::event *>(event->context);\n\n    const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0);\n    // Record the current state of the queue\n    SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier()));\n}\ncatch (sycl::exception const &exc)\n{\n    std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n              << \", line:\" << __LINE__ << std::endl;\n    std::exit(1);\n}\n\nstatic void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {\n    GGML_SYCL_DEBUG(\"[SYCL] call %s\\n\", __func__);\n    sycl::event* sycl_event = 
static_cast<sycl::event*>(event->context);\n\n    if (ggml_backend_is_sycl(backend)) {\n        SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));\n    } else\n        GGML_ABORT(\"fatal error\");\n} catch (sycl::exception const& exc) {\n    std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n              << \", line:\" << __LINE__ << std::endl;\n    std::exit(1);\n}\n\nstatic ggml_backend_i ggml_backend_sycl_interface = {\n    /* .get_name                = */ ggml_backend_sycl_get_name,\n    /* .free                    = */ ggml_backend_sycl_free,\n    /* .set_tensor_async        = */ ggml_backend_sycl_set_tensor_async,\n    /* .get_tensor_async        = */ ggml_backend_sycl_get_tensor_async,\n    /* .cpy_tensor_async        = */ NULL, // ggml_backend_sycl_cpy_tensor_async,\n                                           // // TODO: update for the new\n                                           // interface\n    /* .synchronize             = */ ggml_backend_sycl_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_sycl_graph_compute,\n    /* .event_record            = */ ggml_backend_sycl_event_record,\n    /* .event_wait              = */ ggml_backend_sycl_event_wait,\n    /* .graph_optimize          = */ NULL,\n};\n\nstatic ggml_guid_t ggml_backend_sycl_guid() {\n    static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };\n    return &guid;\n}\n\nbool ggml_backend_is_sycl(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());\n}\n\nint ggml_backend_sycl_get_device_count() {\n    return ggml_sycl_info().device_count;\n}\n\n\n// backend device\n\nstruct ggml_backend_sycl_device_context {\n    int device;\n    std::string name;\n    
std::string description;\n    int op_offload_min_batch_size;\n};\n\nstatic const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {\n    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;\n    return ctx->name.c_str();\n}\n\nstatic const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) {\n    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;\n    return ctx->description.c_str();\n}\n\nstatic void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;\n    ggml_sycl_set_device(ctx->device);\n    SYCL_CHECK(CHECK_TRY_ERROR(\n    dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total)));\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) {\n    GGML_UNUSED(dev);\n    return GGML_BACKEND_DEVICE_TYPE_GPU;\n}\n\nstatic void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_sycl_device_get_name(dev);\n    props->description = ggml_backend_sycl_device_get_description(dev);\n    props->type        = ggml_backend_sycl_device_get_type(dev);\n    ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total);\n\n    bool host_buffer = getenv(\"GGML_SYCL_NO_PINNED\") == nullptr;\n#ifdef GGML_SYCL_NO_PEER_COPY\n    bool events = false;\n#else\n    bool events = true;\n#endif\n\n    props->caps = {\n        /* .async                 = */ true,\n        /* .host_buffer           = */ host_buffer,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ events,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) {\n    GGML_UNUSED(params);\n    
ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;\n    return ggml_backend_sycl_init(ctx->device);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) {\n    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;\n    return ggml_backend_sycl_buffer_type(ctx->device);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) {\n    GGML_UNUSED(dev);\n    return ggml_backend_sycl_host_buffer_type();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {\n    GGML_UNUSED(dev);\n    GGML_UNUSED(ptr);\n    GGML_UNUSED(size);\n    GGML_UNUSED(max_tensor_size);\n    return nullptr;\n}\n\nstatic bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_sycl_device_context *sycl_ctx =\n        (ggml_backend_sycl_device_context *)dev->context;\n    int device = sycl_ctx->device;\n    switch (op->op) {\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                ggml_type src1_type = op->src[1]->type;\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                return false;\n            }\n        case GGML_OP_UNARY:\n            switch (ggml_get_unary_op(op)) {\n                case GGML_UNARY_OP_SGN:\n                case GGML_UNARY_OP_ABS:\n                case GGML_UNARY_OP_NEG:\n                case GGML_UNARY_OP_STEP:\n                case GGML_UNARY_OP_RELU:\n                case GGML_UNARY_OP_HARDSIGMOID:\n                case GGML_UNARY_OP_TANH:\n                case GGML_UNARY_OP_GELU:\n                case GGML_UNARY_OP_SILU:\n                case GGML_UNARY_OP_SIGMOID:\n        
        case GGML_UNARY_OP_HARDSWISH:\n                case GGML_UNARY_OP_GELU_QUICK:\n                case GGML_UNARY_OP_GELU_ERF:\n                case GGML_UNARY_OP_EXP:\n                case GGML_UNARY_OP_SOFTPLUS:\n                case GGML_UNARY_OP_ELU:\n                case GGML_UNARY_OP_CEIL:\n                    return true;\n                case GGML_UNARY_OP_FLOOR:\n                case GGML_UNARY_OP_ROUND:\n                case GGML_UNARY_OP_TRUNC:\n#if defined (GGML_SYCL_F16)\n                    return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);\n#else\n                    return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);\n#endif\n                default:\n                    return false;\n            }\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(op)) {\n                case GGML_GLU_OP_REGLU:\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_SWIGLU:\n                case GGML_GLU_OP_SWIGLU_OAI:\n                case GGML_GLU_OP_GEGLU_ERF:\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    return ggml_is_contiguous_1(op->src[0]);\n                default:\n                    return false;\n            }\n            break;\n        case GGML_OP_MUL_MAT:\n        case GGML_OP_MUL_MAT_ID:\n            {\n                struct ggml_tensor * a = op->src[0];\n                struct ggml_tensor * b = op->src[1];\n\n                if (a->ne[3] != b->ne[3]) {\n                    return false;\n                }\n                ggml_type a_type = a->type;\n                if (a_type == GGML_TYPE_IQ4_NL  || a_type == GGML_TYPE_IQ4_XS ||\n                    a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S  ||\n                    a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||\n                    a_type == GGML_TYPE_IQ1_S || a_type 
== GGML_TYPE_IQ1_M\n                    ) {\n                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {\n                        return false;\n                    }\n                }\n                ggml_type src0_type = op->src[0]->type;\n                if (src0_type == GGML_TYPE_BF16 ) {\n                    // TODO: support GGML_TYPE_BF16\n                    // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added\n                    return false;\n                }\n\n                // TODO: The configuration below needs more work to be supported with oneDNN\n                if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&\n                    a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {\n                  return false;\n                }\n\n                // TODO: This specific configuration can fail with oneDNN and needs more debugging\n                if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&\n                    a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {\n                    return false;\n                }\n                return true;\n            }\n        case GGML_OP_OUT_PROD:\n            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;\n        case GGML_OP_GET_ROWS:\n            {\n                switch (op->src[0]->type) {\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n                        return true;\n                    default:\n                        return false;\n                }\n            }\n         case GGML_OP_SET:\n               return (op->type == 
GGML_TYPE_F32) &&\n                      (op->src[0] && op->src[1]) &&\n                      (op->src[0]->type == GGML_TYPE_F32) &&\n                      (op->src[1]->type == GGML_TYPE_F32);\n\n        case GGML_OP_SET_ROWS:\n            {\n                return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||\n                         op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||\n                         op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&\n                        (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));\n            }\n            break;\n        case GGML_OP_CPY:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                ggml_type src1_type = op->src[1]->type;\n                if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {\n                    return 
true;\n                }\n                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {\n                    return true;\n                }\n                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {\n                    return true;\n                }\n                if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {\n                    return true;\n                }\n                if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {\n                    return true;\n                }\n                if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {\n                    return true;\n                }\n                if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {\n                    return true;\n                }\n                if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {\n                    return true;\n                }\n                return false;\n            }\n        case GGML_OP_REPEAT_BACK:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                return src0_type == GGML_TYPE_F32;\n            
}\n        case GGML_OP_CONCAT:\n        case GGML_OP_DUP:\n        case GGML_OP_ARGMAX:\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_ADD:\n        case GGML_OP_ADD1:\n        case GGML_OP_ADD_ID:\n        case GGML_OP_SUB:\n        case GGML_OP_COUNT_EQUAL:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n        case GGML_OP_REPEAT:\n            return true;\n        case GGML_OP_PAD_REFLECT_1D:\n            return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_SIN:\n        case GGML_OP_COS:\n        case GGML_OP_CLAMP:\n        case GGML_OP_LOG:\n#if defined (GGML_SYCL_F16)\n            return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));\n#else\n            return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);\n#endif\n        case GGML_OP_NORM:\n        case GGML_OP_L2_NORM:\n        case GGML_OP_GROUP_NORM:\n        case GGML_OP_RMS_NORM:\n            return true;\n        case GGML_OP_RMS_NORM_BACK:\n            return ggml_is_contiguous(op->src[0]);\n        case GGML_OP_SCALE:\n            return true;\n        case GGML_OP_CONT:\n            return op->src[0]->type != GGML_TYPE_BF16;\n        case GGML_OP_TRI:\n            {\n                const ggml_tensor * src0 = op->src[0];\n                return src0 &&\n                       op->type == GGML_TYPE_F32 &&\n                       ggml_is_contiguous(src0);\n            }\n        case GGML_OP_DIAG_MASK_INF:\n            return true;\n        case GGML_OP_SOFT_MAX:\n            return true;\n        case GGML_OP_SOFT_MAX_BACK: {\n            float max_bias = 0.0f;\n   
         memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));\n            return max_bias == 0.0f;\n        }\n        case GGML_OP_ROPE:\n        case GGML_OP_ROPE_BACK:\n        case GGML_OP_IM2COL:\n            return true;\n        case GGML_OP_UPSCALE:\n            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);\n        case GGML_OP_SUM:\n        case GGML_OP_SUM_ROWS:\n        case GGML_OP_MEAN:\n            return ggml_is_contiguous(op->src[0]);\n        case GGML_OP_ARGSORT:\n            return op->src[0]->ne[0] * sizeof(int) <=\n                   ggml_sycl_info().devices[device].smpbo;\n        case GGML_OP_TOP_K: {\n            const ggml_tensor * src0 = op->src[0];\n            const int k = op->ne[0];\n            return src0 &&\n                op->type == GGML_TYPE_I32 &&\n                src0->type == GGML_TYPE_F32 &&\n                ggml_is_contiguous(src0) &&\n                k > 0 && k <= 32;\n        }\n        case GGML_OP_POOL_2D:\n            return true;\n        case GGML_OP_ACC:\n            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);\n        case GGML_OP_PAD:\n            // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985\n            if (ggml_get_op_params_i32(op, 8) != 0) {\n                return false;\n            }\n            return ggml_is_contiguous(op->src[0]);\n        case GGML_OP_LEAKY_RELU:\n        case GGML_OP_TIMESTEP_EMBEDDING:\n        case GGML_OP_RWKV_WKV6:\n        case GGML_OP_RWKV_WKV7:\n        case GGML_OP_GATED_LINEAR_ATTN:\n        case GGML_OP_GATED_DELTA_NET:\n            return true;\n        case GGML_OP_SSM_CONV:\n            return op->type == GGML_TYPE_F32 &&\n                   op->src[0]->type == GGML_TYPE_F32 &&\n                   op->src[1]->type == GGML_TYPE_F32;\n        case GGML_OP_ROLL:\n            
return op->type == GGML_TYPE_F32;\n        case GGML_OP_ARANGE:\n            return op->type == GGML_TYPE_F32;\n        case GGML_OP_FLASH_ATTN_EXT:\n            return ggml_sycl_flash_attn_ext_supported(device, op);\n        default:\n            return false;\n    }\n\n    GGML_UNUSED(dev);\n}\n\nstatic bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) {\n        return false;\n    }\n    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;\n    ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;\n    return buft_ctx->device == sycl_ctx->device;\n}\n\nstatic int64_t get_op_batch_size(const ggml_tensor * op) {\n    switch (op->op) {\n        case GGML_OP_GET_ROWS:\n            return 0;\n        case GGML_OP_MUL_MAT:\n            return op->ne[1];\n        case GGML_OP_MUL_MAT_ID:\n        case GGML_OP_ROPE:\n            return op->ne[2];\n        default:\n            return ggml_nrows(op);\n    }\n}\n\nstatic bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;\n    return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;\n}\n\nstatic ggml_backend_event_t\nggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) {\n\n#ifdef GGML_SYCL_NO_PEER_COPY\n    return nullptr;\n#else\n  sycl::event *event_ptr = new sycl::event();\n\n  return new ggml_backend_event{\n      /* .device = */ dev,\n      /* .context = */ event_ptr,\n  };\n#endif\n}\n\nstatic void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try {\n  GGML_UNUSED(dev);\n  if (event == nullptr) {\n    return;\n  }\n\n  if (event->context != nullptr) {\n    sycl::event *sycl_event = 
static_cast<sycl::event *>(event->context);\n    delete sycl_event;\n    event->context = nullptr;\n  }\n\n  delete event;\n} catch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\n\nstatic void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {\n  GGML_UNUSED(dev);\n  GGML_SYCL_DEBUG(\"[SYCL] call %s\\n\", __func__);\n\n  sycl::event *sycl_event = static_cast<sycl::event *>(event->context);\n  SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));\n} catch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic const ggml_backend_device_i ggml_backend_sycl_device_interface = {\n    /* .get_name                = */ ggml_backend_sycl_device_get_name,\n    /* .get_description         = */ ggml_backend_sycl_device_get_description,\n    /* .get_memory              = */ ggml_backend_sycl_device_get_memory,\n    /* .get_type                = */ ggml_backend_sycl_device_get_type,\n    /* .get_props               = */ ggml_backend_sycl_device_get_props,\n    /* .init_backend            = */ ggml_backend_sycl_device_init,\n    /* .get_buffer_type         = */ ggml_backend_sycl_device_get_buffer_type,\n    /* .get_host_buffer_type    = */ ggml_backend_sycl_device_get_host_buffer_type,\n    /* .buffer_from_host_ptr    = */ ggml_backend_sycl_device_buffer_from_host_ptr,\n    /* .supports_op             = */ ggml_backend_sycl_device_supports_op,\n    /* .supports_buft           = */ ggml_backend_sycl_device_supports_buft,\n    /* .offload_op              = */ ggml_backend_sycl_device_offload_op,\n    /* .event_new               = */ ggml_backend_sycl_device_event_new,\n    /* .event_free              = */ ggml_backend_sycl_device_event_free,\n    /* .event_synchronize       = 
*/ ggml_backend_sycl_device_event_synchronize,\n};\n\n// backend reg\n\nstruct ggml_backend_sycl_reg_context {\n    std::vector<ggml_backend_dev_t> devices;\n};\n\nstatic const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) {\n    GGML_UNUSED(reg);\n    return GGML_SYCL_NAME;\n}\n\nstatic size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) {\n    ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;\n    return ctx->devices.size();\n}\n\nstatic ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;\n    GGML_ASSERT(index < ctx->devices.size());\n    return ctx->devices[index];\n}\n\nstatic void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {\n    GGML_UNUSED(reg);\n\n    if (strcmp(name, \"ggml_backend_split_buffer_type\") == 0) {\n        return (void *)ggml_backend_sycl_split_buffer_type;\n    }\n\n    // SYCL doesn't support registering host memory, left here for reference\n    // \"ggml_backend_register_host_buffer\"\n    // \"ggml_backend_unregister_host_buffer\"\n    GGML_UNUSED(name);\n    return nullptr;\n}\n\nstatic const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {\n    /* .get_name          = */ ggml_backend_sycl_reg_get_name,\n    /* .get_device_count  = */ ggml_backend_sycl_reg_get_device_count,\n    /* .get_device        = */ ggml_backend_sycl_reg_get_device,\n    /* .get_proc_address  = */ ggml_backend_sycl_reg_get_proc_address,\n};\n\n\n// backend registry\n\nggml_backend_reg_t ggml_backend_sycl_reg() {\n    static ggml_backend_reg reg;\n    static bool initialized = false;\n\n    {\n        static std::mutex mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n        if (!initialized) {\n            ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;\n            const int min_batch_size = 
getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\") ? atoi(getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\")) : 32;\n\n            for (int i = 0; i < ggml_sycl_info().device_count; i++) {\n                ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;\n                dev_ctx->device = i;\n                dev_ctx->name = GGML_SYCL_NAME + std::to_string(i);\n\n                ggml_sycl_set_device(i);\n\n                dpct::device_info prop;\n                SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(\n                    prop, dpct::dev_mgr::instance().get_device(i))));\n\n                dev_ctx->description = prop.get_name();\n                dev_ctx->op_offload_min_batch_size = min_batch_size;\n\n                ggml_backend_dev_t dev = new ggml_backend_device {\n                    /* .iface       = */ ggml_backend_sycl_device_interface,\n                    /* .reg         = */ &reg,\n                    /* .context     = */ dev_ctx\n                };\n                ctx->devices.push_back(dev);\n            }\n\n            reg = ggml_backend_reg {\n                /* .api_version = */ GGML_BACKEND_API_VERSION,\n                /* .iface       = */ ggml_backend_sycl_reg_interface,\n                /* .context     = */ ctx\n            };\n        }\n\n        initialized = true;\n    }\n\n    return &reg;\n}\n\nggml_backend_t ggml_backend_sycl_init(int device) {\n    GGML_SYCL_DEBUG(\"[SYCL] call ggml_backend_sycl_init\\n\");\n    ggml_check_sycl();\n\n    check_allow_gpu_index(device);\n\n    ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device);\n    if (ctx == nullptr) {\n        GGML_LOG_ERROR(\"%s: error: failed to allocate context\\n\", __func__);\n        return nullptr;\n    };\n\n    ggml_backend_t sycl_backend = new ggml_backend {\n        /* .guid    = */ ggml_backend_sycl_guid(),\n        /* .iface   = */ ggml_backend_sycl_interface,\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 
device),\n        /* .context = */ ctx\n    };\n\n    return sycl_backend;\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_sycl_reg)\n"
  },
  {
    "path": "src/ggml-sycl/gla.cpp",
    "content": "#include <sycl/sycl.hpp>\n\n#include \"common.hpp\"\n\ntemplate <u_int HEAD_SIZE>\nstatic void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale,\n                                         const float * k, const float * v, const float * r, const float * td,\n                                         const float * s, float * dst) {\n    const u_int head_size    = HEAD_SIZE;\n    const u_int state_size   = C * head_size;\n    const u_int n_seq_tokens = T / B;\n    sycl::range<1> block_dims((C / H));\n    sycl::range<1> grid_dims((B * H));\n    stream->submit([&](sycl::handler & cgh) {\n        /* local memory accessors*/\n        auto _k  = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);\n        auto _r  = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);\n        auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);\n\n        cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {\n            u_int tid = item.get_local_id(0);\n            u_int bid = item.get_group(0);\n\n            u_int batch_i = bid / H;\n            u_int head_i  = bid % H;\n\n            float state[head_size];\n\n#pragma unroll\n            for (u_int i = 0; i < head_size; i++) {\n                state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];\n            }\n\n            for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;\n                 t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {\n\n                item.barrier(sycl::access::fence_space::local_space);  //sync threads\n                _k[tid]  = k[t];\n                _r[tid]  = r[t];\n                _td[tid] = td[t];\n                item.barrier(sycl::access::fence_space::local_space);  //sync threads\n\n                const float _v = v[t];\n                float       y  = 
0;\n\n                for (u_int j = 0; j < head_size; j += 4) {\n                    const sycl::float4 & k  = (sycl::float4 &) (_k[j]);\n                    const sycl::float4 & r  = (sycl::float4 &) (_r[j]);\n                    const sycl::float4 & td = (sycl::float4 &) (_td[j]);\n                    sycl::float4 &       s  = (sycl::float4 &) (state[j]);\n                    sycl::float4         kv;\n\n                    kv.x() = k.x() * _v;\n                    kv.y() = k.y() * _v;\n                    kv.z() = k.z() * _v;\n                    kv.w() = k.w() * _v;\n\n                    s.x() = s.x() * td.x() + kv.x();\n                    s.y() = s.y() * td.y() + kv.y();\n                    s.z() = s.z() * td.z() + kv.z();\n                    s.w() = s.w() * td.w() + kv.w();\n\n                    y += r.x() * s.x();\n                    y += r.y() * s.y();\n                    y += r.z() * s.z();\n                    y += r.w() * s.w();\n                }\n                dst[t] = y * scale;\n            }\n#pragma unroll\n            for (u_int i = 0; i < head_size; i++) {\n                dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];\n            }\n        });\n    });\n}\n\nvoid ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/5);\n    const float * k_d  = static_cast<const float *>(dst->src[0]->data);\n    const float * v_d  = static_cast<const float *>(dst->src[1]->data);\n    const float * r_d  = static_cast<const float *>(dst->src[2]->data);\n    const float * td_d = static_cast<const float *>(dst->src[3]->data);\n    const float * s_d  = static_cast<const float *>(dst->src[4]->data);\n\n    const int64_t B = dst->src[4]->ne[1];\n    const int64_t T = dst->src[0]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t H = dst->src[0]->ne[1];\n\n    dpct::queue_ptr stream = 
ctx.stream();\n    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);\n    GGML_ASSERT(C % H == 0);\n    GGML_ASSERT(C / H == 64 || C / H == 128);\n\n    float scale;\n    memcpy(&scale, dst->op_params, sizeof(float));\n\n    float * dst_d = (float *) dst->data;\n\n    if (C / H == 64) {\n        gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);\n    } else {\n        gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/gla.hpp",
    "content": "#ifndef GGML_SYCL_GLA_HPP\n#define GGML_SYCL_GLA_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif  // GGML_SYCL_GLA_HPP\n"
  },
  {
    "path": "src/ggml-sycl/im2col.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#include \"im2col.hpp\"\n\n#include <sycl/sycl.hpp>\n#include <type_traits>  // For std::is_same_v\n\n#include \"ggml.h\"\n\ntemplate <typename T>\nstatic void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW,\n                          int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,\n                          int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) {\n    const int64_t work_group_size = item_ct1.get_local_range(2);\n    const int64_t global_id       = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2));\n\n    // make each work-item deal with more elements since sycl global range can not exceed max int\n    for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) {\n        const int64_t ksize = OW * KH;\n        const int64_t kx    = i / ksize;\n        const int64_t kd    = kx * ksize;\n        const int64_t ky    = (i - kd) / OW;\n        const int64_t ix    = i % OW;\n\n        const int64_t oh    = item_ct1.get_group(1);\n        const int64_t batch = item_ct1.get_group(0) / IC;\n        const int64_t ic    = item_ct1.get_group(0) % IC;\n\n        const int64_t iiw = (ix * s0) + (kx * d0) - p0;\n        const int64_t iih = (oh * s1) + (ky * d1) - p1;\n\n        const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx);\n\n        const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset);\n        const int64_t offset_src      = offset_src_base + (iih * IW) + iiw;\n\n        
const bool  out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW);\n        const float src_val       = out_of_bounds ? 0.0f : x[offset_src];\n\n        if constexpr (std::is_same_v<T, sycl::half>) {\n            dst[offset_dst] = sycl::half(src_val);\n        } else if constexpr (std::is_same_v<T, float>) {\n            dst[offset_dst] = src_val;\n        }\n    }\n}\n\ntemplate <typename T>\nstatic void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,\n                                 int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,\n                                 int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {\n    const int64_t parallel_elements = OW * KW * KH;\n    const int64_t num_blocks        = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;\n\n    // decrease global range when it exceeds the max int\n    int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);\n\n    sycl::range<3> block_nums(batch * IC, OH, num_blocks);\n    sycl::range<3> local_range(1, 1, local_size);\n\n    const int64_t CHW = IC * KH * KW;\n\n    stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {\n        im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,\n                         p0, p1, d0, d1, item_ct1);\n    });\n}\n\nstatic void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH,\n                            int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset,\n                            int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {\n    if (!stream->get_device().has(sycl::aspect::fp16)) {\n        throw 
sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported),\n                              \"Device does not support half precision (fp16) operations!\");\n    }\n    im2col_sycl_internal<sycl::half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0,\n                                     p1, d0, d1, stream);\n}\n\nstatic void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,\n                            int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0,\n                            int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {\n    im2col_sycl_internal<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1,\n                                d0, d1, stream);\n}\n\nvoid ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);\n\n    const int32_t s0 = ((const int32_t *) (dst->op_params))[0];\n    const int32_t s1 = ((const int32_t *) (dst->op_params))[1];\n    const int32_t p0 = ((const int32_t *) (dst->op_params))[2];\n    const int32_t p1 = ((const int32_t *) (dst->op_params))[3];\n    const int32_t d0 = ((const int32_t *) (dst->op_params))[4];\n    const int32_t d1 = ((const int32_t *) (dst->op_params))[5];\n\n    const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1;\n\n    const int64_t IC = src1->ne[is_2D ? 2 : 1];\n    const int64_t IH = is_2D ? src1->ne[1] : 1;\n    const int64_t IW = src1->ne[0];\n\n    const int64_t KH = is_2D ? src0->ne[1] : 1;\n    const int64_t KW = src0->ne[0];\n\n    const int64_t OH = is_2D ? dst->ne[2] : 1;\n    const int64_t OW = dst->ne[1];\n\n    const size_t  delta_offset = src1->nb[is_2D ? 
2 : 1] / sizeof(float);\n    const int64_t batch        = src1->ne[is_2D ? 3 : 2];\n    const size_t  batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float);\n\n    queue_ptr stream = ctx.stream();\n\n    if (dst->type == GGML_TYPE_F16) {\n        im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,\n                        batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);\n    } else {\n        im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,\n                        batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/im2col.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_IM2COL_HPP\n#define GGML_SYCL_IM2COL_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_im2col(\n        ggml_backend_sycl_context & ctx, ggml_tensor *dst);\n\n#endif // GGML_SYCL_IM2COL_HPP\n"
  },
  {
    "path": "src/ggml-sycl/mmq.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#include \"mmq.hpp\"\n#include \"vecdotq.hpp\"\n\ntypedef void (*allocate_tiles_sycl_t)(\n    int** x_ql,\n    sycl::half2** x_dm,\n    int** x_qh,\n    int** x_sc);\ntypedef void (*load_tiles_sycl_t)(\n    const void* __restrict__ vx,\n    int* __restrict__ x_ql,\n    sycl::half2* __restrict__ x_dm,\n    int* __restrict__ x_qh,\n    int* __restrict__ x_sc,\n    const int& i_offset,\n    const int& i_max,\n    const int& k,\n    const int& blocks_per_row);\ntypedef float (*vec_dot_q_mul_mat_sycl_t)(\n    const int* __restrict__ x_ql,\n    const sycl::half2* __restrict__ x_dm,\n    const int* __restrict__ x_qh,\n    const int* __restrict__ x_sc,\n    const int* __restrict__ y_qs,\n    const sycl::half2* __restrict__ y_ms,\n    const int& i,\n    const int& j,\n    const int& k);\n\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {\n    (void)x_qh; (void)x_sc;\n\n    *x_ql = tile_x_qs_q4_0;\n    *x_dm = (sycl::half2 *)tile_x_d_q4_0;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh; (void)x_sc;\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / 
QI4_0;\n    const int kqsx = k % QI4_0;\n\n    const block_q4_0 * bx0 = (const block_q4_0 *) vx;\n\n    float * x_dmf = (float *) x_dm;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;\n\n        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);\n        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;\n    }\n\n    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;\n    const int kbxd = k % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {\n        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;\n\n        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;\n    }\n}\n\nstatic __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh; (void)x_sc;\n\n    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));\n    const float * x_dmf = (const float *) x_dm;\n\n    int u[2*VDR_Q4_0_Q8_1_MMQ];\n\n#pragma unroll\n    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {\n        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];\n        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];\n    }\n\n    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>\n        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],\n         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);\n}\n\ntemplate <int mmq_y>\nstatic 
__dpct_inline__ void\nallocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {\n    (void)x_qh; (void)x_sc;\n\n    *x_ql = tile_x_qs_q4_1;\n    *x_dm = tile_x_dm_q4_1;\n}\n\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh; (void)x_sc;\n\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / QI4_1;\n    const int kqsx = k % QI4_1;\n\n    const block_q4_1 * bx0 = (const block_q4_1 *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;\n\n        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);\n    }\n\n    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;\n    const int kbxd = k % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {\n        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;\n\n        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;\n    }\n}\n\nstatic __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const 
sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh; (void)x_sc;\n\n    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));\n\n    int u[2*VDR_Q4_1_Q8_1_MMQ];\n\n#pragma unroll\n    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {\n        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];\n        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];\n    }\n\n    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>\n        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],\n         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);\n}\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {\n    (void)x_qh; (void)x_sc;\n\n    *x_ql = tile_x_ql_q5_0;\n    *x_dm = (sycl::half2 *)tile_x_d_q5_0;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh; (void)x_sc;\n\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / QI5_0;\n    const int kqsx = k % QI5_0;\n\n    const block_q5_0 * bx0 = (const block_q5_0 *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;\n\n        const int ql = get_int_from_uint8(bxi->qs, kqsx);\n        const int qh = get_int_from_uint8(bxi->qh, 0) >> 
(4 * (k % QI5_0));\n\n        int qs0 = (ql >>  0)   & 0x0F0F0F0F;\n        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4\n        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12\n        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20\n        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28\n        qs0 = dpct::vectorized_binary<sycl::char4>(\n            qs0, 0x10101010, dpct::sub_sat()); // subtract 16\n\n        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;\n\n        int qs1 = (ql >>  4)   & 0x0F0F0F0F;\n        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4\n        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12\n        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20\n        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28\n        qs1 = dpct::vectorized_binary<sycl::char4>(\n            qs1, 0x10101010, dpct::sub_sat()); // subtract 16\n\n        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;\n    }\n\n    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;\n    const int kbxd = k % blocks_per_tile_x_row;\n    float * x_dmf = (float *) x_dm;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {\n        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;\n\n        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;\n    }\n}\n\nstatic __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh; (void)x_sc;\n\n    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));\n    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;\n    const float * x_dmf = (const float *) x_dm;\n    const float * y_df  
= (const float *) y_ds;\n\n    int u[2*VDR_Q5_0_Q8_1_MMQ];\n\n#pragma unroll\n    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {\n        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];\n        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];\n    }\n\n    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>\n        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);\n}\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {\n    (void)x_qh; (void)x_sc;\n\n    *x_ql = tile_x_ql_q5_1;\n    *x_dm = tile_x_dm_q5_1;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh; (void)x_sc;\n\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset < nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / QI5_1;\n    const int kqsx = k % QI5_1;\n\n    const block_q5_1 * bx0 = (const block_q5_1 *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;\n\n        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);\n        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));\n\n        int qs0 = (ql >>  0) & 0x0F0F0F0F;\n        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4\n        qs0    |= (qh << 11) & 0x00001000; // 1 -> 
12\n        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20\n        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28\n\n        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;\n\n        int qs1 = (ql >>  4) & 0x0F0F0F0F;\n        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4\n        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12\n        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20\n        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28\n\n        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;\n    }\n\n    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;\n    const int kbxd = k % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {\n        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;\n\n        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;\n    }\n}\n\nstatic __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh; (void)x_sc;\n\n    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));\n    const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;\n\n    int u[2*VDR_Q5_1_Q8_1_MMQ];\n\n#pragma unroll\n    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {\n        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];\n        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];\n    }\n\n    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>\n        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);\n}\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q8_0(int **x_ql, 
sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {\n    (void)x_qh; (void)x_sc;\n\n    *x_ql = tile_x_qs_q8_0;\n    *x_dm = (sycl::half2 *)tile_x_d_q8_0;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh; (void)x_sc;\n\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / QI8_0;\n    const int kqsx = k % QI8_0;\n    float * x_dmf = (float *) x_dm;\n\n    const block_q8_0 * bx0 = (const block_q8_0 *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;\n\n        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);\n    }\n\n    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;\n    const int kbxd = k % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {\n        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;\n\n        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;\n    }\n}\n\nstatic __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ 
y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh; (void)x_sc;\n\n    const float * x_dmf = (const float *) x_dm;\n    const float * y_df  = (const float *) y_ds;\n\n    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>\n        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],\n         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);\n}\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,\n                    int *tile_x_sc_q2_K) {\n    (void)x_qh;\n\n    *x_ql = tile_x_ql_q2_K;\n    *x_dm = tile_x_dm_q2_K;\n    *x_sc = tile_x_sc_q2_K;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh;\n\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / QI2_K;\n    const int kqsx = k % QI2_K;\n\n    const block_q2_K * bx0 = (const block_q2_K *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;\n\n        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);\n    }\n\n    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;\n    const int kbxd = k % blocks_per_tile_x_row;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {\n        int i = (i0 + i_offset * QI2_K + k / 
blocks_per_tile_x_row) % mmq_y;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;\n\n        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;\n    }\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {\n        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);\n\n        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));\n    }\n}\n\n#define VDR_Q2_K_Q8_1_MMQ  2\n// contiguous u/y values\nstatic __dpct_inline__ float\nvec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,\n                           const uint8_t *__restrict__ scales,\n                           const sycl::half2 &dm2, const float &d8) {\n\n    int sumi_d = 0;\n    int sumi_m = 0;\n\n#pragma unroll\n    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {\n        int sumi_d_sc = 0;\n\n        const int sc = scales[i0 / (QI8_1/2)];\n\n        // fill int with 4x m\n        int m = sc >> 4;\n        m |= m <<  8;\n        m |= m << 16;\n\n#pragma unroll\n        for (int i = i0; i < i0 + QI8_1/2; ++i) {\n            sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product\n            sumi_m = dpct::dp4a(m, u[i],\n                                sumi_m); // multiply sum of q8_1 values with m\n        }\n\n        sumi_d += sumi_d_sc * (sc & 0xF);\n    }\n\n    const sycl::float2 dm2f =\n        dm2.convert<float, sycl::rounding_mode::automatic>();\n\n    return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);\n}\n\nstatic __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int 
*__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh;\n\n    const int kbx = k / QI2_K;\n    const int ky  = (k % QI2_K) * QR2_K;\n    const float * y_df = (const float *) y_ds;\n\n    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];\n\n    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);\n    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));\n\n#pragma unroll\n    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {\n        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;\n    }\n\n    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;\n\n    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;\n    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);\n}\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,\n                    int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {\n\n    *x_ql = tile_x_ql_q3_K;\n    *x_dm = tile_x_dm_q3_K;\n    *x_qh = tile_x_qh_q3_K;\n    *x_sc = tile_x_sc_q3_K;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / QI3_K;\n    const int kqsx = k % QI3_K;\n\n    const block_q3_K * bx0 = (const block_q3_K *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int 
i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;\n\n        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);\n    }\n\n    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;\n    const int kbxd = k % blocks_per_tile_x_row;\n    float * x_dmf = (float *) x_dm;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {\n        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;\n\n        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;\n    }\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {\n        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);\n\n        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted\n        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));\n    }\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {\n        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);\n\n        const int ksc = k % (QI3_K/4);\n\n        const int ksc_low = ksc % (QI3_K/8);\n        const int shift_low = 4 * (ksc / (QI3_K/8));\n        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;\n\n        const int ksc_high = QI3_K/8;\n        const int shift_high = 2 * ksc;\n        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 
0x30303030;\n\n        const int sc = dpct::vectorized_binary<sycl::char4>(\n            sc_low | sc_high, 0x20202020, dpct::sub_sat());\n\n        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;\n    }\n}\n\n#define VDR_Q3_K_Q8_1_MMQ  2\n// contiguous u/y values\nstatic __dpct_inline__ float\nvec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,\n                           const int8_t *__restrict__ scales, const float &d3,\n                           const float &d8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {\n        int sumi_sc = 0;\n\n        for (int i = i0; i < i0 + QI8_1/2; ++i) {\n            sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product\n        }\n\n        sumi += sumi_sc * scales[i0 / (QI8_1/2)];\n    }\n\n    return d3*d8 * sumi;\n}\n\nstatic __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n\n    const int kbx  = k / QI3_K;\n    const int ky  = (k % QI3_K) * QR3_K;\n    const float * x_dmf = (const float *) x_dm;\n    const float * y_df  = (const float *) y_ds;\n\n    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;\n\n    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];\n\n#pragma unroll\n    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {\n        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);\n        const int shift = 2 * ((ky % 32) / 8);\n        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;\n\n        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);\n        const int vlh = (vh << 2) & 0x04040404;\n\n        v[l] = 
dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());\n    }\n\n    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;\n    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);\n}\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,\n                    int *tile_x_sc_q4_K) {\n    (void)x_qh;\n\n    *x_ql = tile_x_ql_q4_K;\n    *x_dm = tile_x_dm_q4_K;\n    *x_sc = tile_x_sc_q4_K;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh;\n\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / QI4_K; // == 0 if QK_K == 256\n    const int kqsx = k % QI4_K; // == k if QK_K == 256\n\n    const block_q4_K * bx0 = (const block_q4_K *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;\n\n        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);\n    }\n\n    constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 
1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256\n    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {\n        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;\n\n#if QK_K == 256\n        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;\n#else\n        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};\n#endif\n    }\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {\n        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);\n\n        const int * scales = (const int *) bxi->scales;\n\n        const int ksc = k % (WARP_SIZE/8);\n\n        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8\n        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits\n        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits\n\n        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;\n    }\n}\n\n\n#define VDR_Q4_K_Q8_1_MMQ  8\n\n// contiguous u/y values\nstatic __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(\n    const int *__restrict__ v, const int *__restrict__ u,\n    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,\n    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {\n        int sumi_d = 0;\n\n#pragma unroll\n        for (int j = 0; j < QI8_1; ++j) {\n            sumi_d = 
dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,\n                                u[i * QI8_1 + j], sumi_d); // SIMD dot product\n        }\n\n        const sycl::float2 ds8f =\n            ds8[i].convert<float, sycl::rounding_mode::automatic>();\n\n        sumf_d += ds8f.x() * (sc[i] * sumi_d);\n        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val\n    }\n\n    const sycl::float2 dm4f =\n        dm4.convert<float, sycl::rounding_mode::automatic>();\n\n    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;\n}\n\n\nstatic __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh;\n\n    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);\n\n    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;\n    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,\n                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);\n}\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,\n                    int *tile_x_sc_q5_K) {\n    (void)x_qh;\n\n    *x_ql = tile_x_ql_q5_K;\n    *x_dm = tile_x_dm_q5_K;\n    *x_sc = tile_x_sc_q5_K;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh;\n\n    
GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k <  WARP_SIZE);\n\n    const int kbx  = k / QI5_K; // == 0 if QK_K == 256\n    const int kqsx = k % QI5_K; // == k if QK_K == 256\n\n    const block_q5_K * bx0 = (const block_q5_K *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;\n        const int ky = QR5_K*kqsx;\n\n        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);\n        const int ql0 = (ql >> 0) & 0x0F0F0F0F;\n        const int ql1 = (ql >> 4) & 0x0F0F0F0F;\n\n        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));\n        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;\n        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;\n\n        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;\n        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);\n\n        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;\n        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;\n    }\n\n    constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 
1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256\n    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {\n        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;\n\n#if QK_K == 256\n        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;\n#endif\n    }\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {\n        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);\n\n        const int * scales = (const int *) bxi->scales;\n\n        const int ksc = k % (WARP_SIZE/8);\n\n        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8\n        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits\n        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits\n\n        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;\n    }\n}\n\n#define VDR_Q5_K_Q8_1_MMQ  8\n\n// contiguous u/y values\nstatic __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(\n    const int *__restrict__ v, const int *__restrict__ u,\n    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,\n    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {\n        int sumi_d = 0;\n\n#pragma unroll\n        for (int j = 0; j < QI8_1; ++j) {\n            sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],\n                                sumi_d); // SIMD dot 
product\n        }\n\n        const sycl::float2 ds8f =\n            ds8[i].convert<float, sycl::rounding_mode::automatic>();\n\n        sumf_d += ds8f.x() * (sc[i] * sumi_d);\n        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val\n    }\n\n    const sycl::float2 dm4f =\n        dm4.convert<float, sycl::rounding_mode::automatic>();\n\n    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;\n}\n\nstatic __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh;\n\n    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);\n\n    const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k;\n    const int index_y = j * WARP_SIZE             + (QR5_K*k) % WARP_SIZE;\n    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,\n                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);\n}\n\ntemplate <int mmq_y>\nstatic __dpct_inline__ void\nallocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,\n                    int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {\n    (void)x_qh;\n\n    *x_ql = tile_x_ql;\n    *x_dm = tile_x_dm;\n    *x_sc = tile_x_sc;\n}\n\ntemplate <int mmq_y, int nwarps, bool need_check>\nstatic __dpct_inline__ void\nload_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,\n                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,\n                int *__restrict__ x_sc, const int &i_offset, const int &i_max,\n                const int &k, const int &blocks_per_row) {\n    (void)x_qh;\n\n    GGML_SYCL_ASSUME(i_offset >= 0);\n    GGML_SYCL_ASSUME(i_offset <  nwarps);\n    GGML_SYCL_ASSUME(k >= 0);\n    GGML_SYCL_ASSUME(k 
<  WARP_SIZE);\n\n    const int kbx  = k / QI6_K; // == 0 if QK_K == 256\n    const int kqsx = k % QI6_K; // == k if QK_K == 256\n\n    const block_q6_K * bx0 = (const block_q6_K *) vx;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {\n        int i = i0 + i_offset;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;\n        const int ky = QR6_K*kqsx;\n\n        const int ql = get_int_from_uint8(bxi->ql, kqsx);\n        const int ql0 = (ql >> 0) & 0x0F0F0F0F;\n        const int ql1 = (ql >> 4) & 0x0F0F0F0F;\n\n        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));\n        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;\n        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;\n\n        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;\n        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);\n\n        x_ql[i * (2 * WARP_SIZE + 1) + kq0] =\n            dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,\n                                                 dpct::sub_sat());\n        x_ql[i * (2 * WARP_SIZE + 1) + kq1] =\n            dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,\n                                                 dpct::sub_sat());\n    }\n\n    constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 
1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256\n    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256\n    float * x_dmf = (float *) x_dm;\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {\n        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;\n\n        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;\n    }\n\n#pragma unroll\n    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {\n        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;\n\n        if (need_check) {\n            i = sycl::min(i, i_max);\n        }\n\n        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;\n\n        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));\n    }\n}\n\n#define VDR_Q6_K_Q8_1_MMQ  8\n\n// contiguous u/y values\nstatic __dpct_inline__ float\nvec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,\n                           const int8_t *__restrict__ sc, const float &d6,\n                           const float *__restrict__ d8) {\n\n    float sumf_d = 0.0f;\n\n#pragma unroll\n    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {\n        sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale\n\n#pragma unroll\n        for (int i = i0; i < i0 + 2; ++i) {\n            sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],\n                                    sumi_d.x()); // SIMD dot product\n            sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],\n                                    sumi_d.x()); // SIMD dot product\n\n            sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],\n                                    sumi_d.y()); // SIMD dot product\n            sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],\n                     
               sumi_d.y()); // SIMD dot product\n        }\n\n        sumf_d += d8[i0 / 4] *\n                  (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());\n    }\n\n    return d6 * sumf_d;\n}\n\nstatic __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(\n    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,\n    const int *__restrict__ x_qh, const int *__restrict__ x_sc,\n    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,\n    const int &i, const int &j, const int &k) {\n    (void)x_qh;\n\n    const float * x_dmf = (const float *) x_dm;\n    const float * y_df  = (const float *) y_ds;\n\n    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);\n\n    const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k;\n    const int index_y = j * WARP_SIZE             + (QR6_K*k) % WARP_SIZE;\n    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);\n}\n\ntemplate <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,\n          int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,\n          vec_dot_q_mul_mat_sycl_t vec_dot>\n/*\nDPCT1110:8: The total declared local variable size in device function mul_mat_q\nexceeds 128 bytes and may cause high register pressure. 
Consult with your\nhardware vendor to find the total register size available and adjust the code,\nor use smaller sub-group size to avoid high register pressure.\n*/\nstatic __dpct_inline__ void\nmul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,\n          float *__restrict__ dst, const int ncols_x, const int nrows_x,\n          const int ncols_y, const int nrows_y, const int nrows_dst,\n          int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,\n          int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,\n          sycl::half2 *tile_y_ds) {\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    const int blocks_per_row_x = ncols_x / qk;\n    const int blocks_per_col_y = nrows_y / QK8_1;\n    const int blocks_per_warp = WARP_SIZE / qi;\n\n    const int & ncols_dst = ncols_y;\n\n    const int row_dst_0 = item_ct1.get_group(2) * mmq_y;\n    const int & row_x_0 = row_dst_0;\n\n    const int col_dst_0 = item_ct1.get_group(1) * mmq_x;\n    const int & col_y_0 = col_dst_0;\n\n    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};\n\n    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {\n\n        load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,\n                   tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),\n                   nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),\n                   blocks_per_row_x);\n\n#pragma unroll\n        for (int ir = 0; ir < qr; ++ir) {\n            const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);\n            const int kbxd = kqs / QI8_1;\n\n#pragma unroll\n            for (int i = 0; i < mmq_x; i += nwarps) {\n                const int col_y_eff = dpct::min(\n                    (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),\n                    ncols_y - 1); // to prevent out-of-bounds memory accesses\n\n                const block_q8_1 * by0 = 
&y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];\n\n                const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +\n                                    kqs % WARP_SIZE;\n                tile_y_qs[index_y] = get_int_from_int8_aligned(\n                    by0->qs, item_ct1.get_local_id(2) % QI8_1);\n            }\n\n#pragma unroll\n            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {\n                const int ids =\n                    (ids0 + item_ct1.get_local_id(1) * QI8_1 +\n                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %\n                    mmq_x;\n                const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);\n                const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);\n\n                // if the sum is not needed it's faster to transform the scale to f32 ahead of time\n                const sycl::half2 *dsi_src =\n                    &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +\n                       ir * (WARP_SIZE / QI8_1) + kby]\n                         .ds;\n                sycl::half2 *dsi_dst =\n                    &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];\n                if (need_sum) {\n                    *dsi_dst = *dsi_src;\n                } else {\n                    float * dfi_dst = (float *) dsi_dst;\n                    *dfi_dst = (*dsi_src)[0];\n                }\n            }\n\n            /*\n            DPCT1118:9: SYCL group functions and algorithms must be encountered\n            in converged control flow. 
You may need to adjust the code.\n            */\n            /*\n            DPCT1065:56: Consider replacing sycl::nd_item::barrier() with\n            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for\n            better performance if there is no access to global memory.\n            */\n            item_ct1.barrier();\n\n// #pragma unroll // unrolling this loop causes too much register pressure\n            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {\n#pragma unroll\n                for (int j = 0; j < mmq_x; j += nwarps) {\n#pragma unroll\n                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {\n                        sum[i / WARP_SIZE][j / nwarps] += vec_dot(\n                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,\n                            tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,\n                            item_ct1.get_local_id(1) + j, k);\n                    }\n                }\n            }\n\n            /*\n            DPCT1118:10: SYCL group functions and algorithms must be encountered\n            in converged control flow. 
You may need to adjust the code.\n            */\n            /*\n            DPCT1065:57: Consider replacing sycl::nd_item::barrier() with\n            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for\n            better performance if there is no access to global memory.\n            */\n            item_ct1.barrier();\n        }\n    }\n\n#pragma unroll\n    for (int j = 0; j < mmq_x; j += nwarps) {\n        const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);\n\n        if (col_dst >= ncols_dst) {\n            return;\n        }\n\n#pragma unroll\n        for (int i = 0; i < mmq_y; i += WARP_SIZE) {\n            const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;\n\n            if (row_dst >= nrows_dst) {\n                continue;\n            }\n\n            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];\n        }\n    }\n}\n\n#define  MMQ_X_Q4_0_RDNA2  64\n#define  MMQ_Y_Q4_0_RDNA2  128\n#define NWARPS_Q4_0_RDNA2  8\n#define  MMQ_X_Q4_0_RDNA1  64\n#define  MMQ_Y_Q4_0_RDNA1  64\n#define NWARPS_Q4_0_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q4_0_AMPERE 4\n#define  MMQ_Y_Q4_0_AMPERE 32\n#define NWARPS_Q4_0_AMPERE 4\n#else\n#define  MMQ_X_Q4_0_AMPERE 64\n#define  MMQ_Y_Q4_0_AMPERE 128\n#define NWARPS_Q4_0_AMPERE 4\n#endif\n#define  MMQ_X_Q4_0_PASCAL 64\n#define  MMQ_Y_Q4_0_PASCAL 64\n#define NWARPS_Q4_0_PASCAL 8\n\ntemplate <bool need_check> static void\n    mul_mat_q4_0(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,\n    int *tile_y_qs, sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n\n    const int mmq_x  =  
MMQ_X_Q4_0_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;\n    const int nwarps = NWARPS_Q4_0_AMPERE;\n    allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_qs_q4_0, tile_x_d_q4_0);\n    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,\n              load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,\n              vec_dot_q4_0_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q4_1_RDNA2  64\n#define  MMQ_Y_Q4_1_RDNA2  128\n#define NWARPS_Q4_1_RDNA2  8\n#define  MMQ_X_Q4_1_RDNA1  64\n#define  MMQ_Y_Q4_1_RDNA1  64\n#define NWARPS_Q4_1_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q4_1_AMPERE 4\n#define  MMQ_Y_Q4_1_AMPERE 32\n#define NWARPS_Q4_1_AMPERE 4\n#else\n#define  MMQ_X_Q4_1_AMPERE 64\n#define  MMQ_Y_Q4_1_AMPERE 128\n#define NWARPS_Q4_1_AMPERE 4\n#endif\n#define  MMQ_X_Q4_1_PASCAL 64\n#define  MMQ_Y_Q4_1_PASCAL 64\n#define NWARPS_Q4_1_PASCAL 8\n\ntemplate <bool need_check> static void\n    mul_mat_q4_1(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,\n    sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q4_1_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;\n    const int nwarps = NWARPS_Q4_1_AMPERE;\n    allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_qs_q4_1, tile_x_dm_q4_1);\n    mul_mat_q<QK4_1, QR4_1, QI4_1, 
true, block_q4_1, mmq_x, mmq_y, nwarps,\n              load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,\n              vec_dot_q4_1_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q5_0_RDNA2  64\n#define  MMQ_Y_Q5_0_RDNA2  128\n#define NWARPS_Q5_0_RDNA2  8\n#define  MMQ_X_Q5_0_RDNA1  64\n#define  MMQ_Y_Q5_0_RDNA1  64\n#define NWARPS_Q5_0_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q5_0_AMPERE 4\n#define  MMQ_Y_Q5_0_AMPERE 32\n#define NWARPS_Q5_0_AMPERE 4\n#else\n#define  MMQ_X_Q5_0_AMPERE 128\n#define  MMQ_Y_Q5_0_AMPERE 64\n#define NWARPS_Q5_0_AMPERE 4\n#endif\n#define  MMQ_X_Q5_0_PASCAL 64\n#define  MMQ_Y_Q5_0_PASCAL 64\n#define NWARPS_Q5_0_PASCAL 8\n\ntemplate <bool need_check> static void\n    mul_mat_q5_0(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,\n    int *tile_y_qs, sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q5_0_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;\n    const int nwarps = NWARPS_Q5_0_AMPERE;\n    allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_ql_q5_0, tile_x_d_q5_0);\n    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,\n              load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,\n              vec_dot_q5_0_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, 
tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q5_1_RDNA2  64\n#define  MMQ_Y_Q5_1_RDNA2  128\n#define NWARPS_Q5_1_RDNA2  8\n#define  MMQ_X_Q5_1_RDNA1  64\n#define  MMQ_Y_Q5_1_RDNA1  64\n#define NWARPS_Q5_1_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q5_1_AMPERE 4\n#define  MMQ_Y_Q5_1_AMPERE 32\n#define NWARPS_Q5_1_AMPERE 4\n#else\n#define  MMQ_X_Q5_1_AMPERE 128\n#define  MMQ_Y_Q5_1_AMPERE 64\n#define NWARPS_Q5_1_AMPERE 4\n#endif\n#define  MMQ_X_Q5_1_PASCAL 64\n#define  MMQ_Y_Q5_1_PASCAL 64\n#define NWARPS_Q5_1_PASCAL 8\n\ntemplate <bool need_check> static void\nmul_mat_q5_1(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,\n    sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q5_1_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;\n    const int nwarps = NWARPS_Q5_1_AMPERE;\n    allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_ql_q5_1, tile_x_dm_q5_1);\n    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,\n              load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,\n              vec_dot_q5_1_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q8_0_RDNA2  64\n#define  MMQ_Y_Q8_0_RDNA2  128\n#define NWARPS_Q8_0_RDNA2  8\n#define  MMQ_X_Q8_0_RDNA1  64\n#define  MMQ_Y_Q8_0_RDNA1  64\n#define NWARPS_Q8_0_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q8_0_AMPERE 4\n#define  
MMQ_Y_Q8_0_AMPERE 32\n#define NWARPS_Q8_0_AMPERE 4\n#else\n#define  MMQ_X_Q8_0_AMPERE 128\n#define  MMQ_Y_Q8_0_AMPERE 64\n#define NWARPS_Q8_0_AMPERE 4\n#endif\n#define  MMQ_X_Q8_0_PASCAL 64\n#define  MMQ_Y_Q8_0_PASCAL 64\n#define NWARPS_Q8_0_PASCAL 8\n\ntemplate <bool need_check> static void\n    mul_mat_q8_0(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,\n    int *tile_y_qs, sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q8_0_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;\n    const int nwarps = NWARPS_Q8_0_AMPERE;\n    allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_qs_q8_0, tile_x_d_q8_0);\n    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,\n              load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,\n              vec_dot_q8_0_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q2_K_RDNA2  64\n#define  MMQ_Y_Q2_K_RDNA2  128\n#define NWARPS_Q2_K_RDNA2  8\n#define  MMQ_X_Q2_K_RDNA1  128\n#define  MMQ_Y_Q2_K_RDNA1  32\n#define NWARPS_Q2_K_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q2_K_AMPERE 4\n#define  MMQ_Y_Q2_K_AMPERE 32\n#define NWARPS_Q2_K_AMPERE 4\n#else\n#define  MMQ_X_Q2_K_AMPERE 64\n#define  MMQ_Y_Q2_K_AMPERE 128\n#define NWARPS_Q2_K_AMPERE 4\n#endif\n#define  MMQ_X_Q2_K_PASCAL 64\n#define  MMQ_Y_Q2_K_PASCAL 64\n#define NWARPS_Q2_K_PASCAL 8\n\ntemplate <bool need_check> static 
void\nmul_mat_q2_K(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,\n    sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,\n    sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q2_K_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;\n    const int nwarps = NWARPS_Q2_K_AMPERE;\n    allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);\n    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,\n              load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,\n              vec_dot_q2_K_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q3_K_RDNA2  128\n#define  MMQ_Y_Q3_K_RDNA2  64\n#define NWARPS_Q3_K_RDNA2  8\n#define  MMQ_X_Q3_K_RDNA1  32\n#define  MMQ_Y_Q3_K_RDNA1  128\n#define NWARPS_Q3_K_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q3_K_AMPERE 4\n#define  MMQ_Y_Q3_K_AMPERE 32\n#define NWARPS_Q3_K_AMPERE 4\n#else\n#define  MMQ_X_Q3_K_AMPERE 128\n#define  MMQ_Y_Q3_K_AMPERE 128\n#define NWARPS_Q3_K_AMPERE 4\n#endif\n#define  MMQ_X_Q3_K_PASCAL 64\n#define  MMQ_Y_Q3_K_PASCAL 64\n#define NWARPS_Q3_K_PASCAL 8\n\ntemplate <bool need_check> static void\nmul_mat_q3_K(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> 
&item_ct1, int *tile_x_ql_q3_K,\n    sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,\n    int *tile_y_qs, sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q3_K_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;\n    const int nwarps = NWARPS_Q3_K_AMPERE;\n    allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,\n                               tile_x_sc_q3_K);\n    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,\n              load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,\n              vec_dot_q3_K_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q4_K_RDNA2  64\n#define  MMQ_Y_Q4_K_RDNA2  128\n#define NWARPS_Q4_K_RDNA2  8\n#define  MMQ_X_Q4_K_RDNA1  32\n#define  MMQ_Y_Q4_K_RDNA1  64\n#define NWARPS_Q4_K_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q4_K_AMPERE 4\n#define  MMQ_Y_Q4_K_AMPERE 32\n#define NWARPS_Q4_K_AMPERE 4\n#else\n#define  MMQ_X_Q4_K_AMPERE 64\n#define  MMQ_Y_Q4_K_AMPERE 128\n#define NWARPS_Q4_K_AMPERE 4\n#endif\n#define  MMQ_X_Q4_K_PASCAL 64\n#define  MMQ_Y_Q4_K_PASCAL 64\n#define NWARPS_Q4_K_PASCAL 8\n\ntemplate <bool need_check> static void\n    mul_mat_q4_K(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,\n    sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,\n    sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    
sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q4_K_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;\n    const int nwarps = NWARPS_Q4_K_AMPERE;\n    allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);\n    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,\n              load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,\n              vec_dot_q4_K_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q5_K_RDNA2  64\n#define  MMQ_Y_Q5_K_RDNA2  128\n#define NWARPS_Q5_K_RDNA2  8\n#define  MMQ_X_Q5_K_RDNA1  32\n#define  MMQ_Y_Q5_K_RDNA1  64\n#define NWARPS_Q5_K_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q5_K_AMPERE 4\n#define  MMQ_Y_Q5_K_AMPERE 32\n#define NWARPS_Q5_K_AMPERE 4\n#else\n#define  MMQ_X_Q5_K_AMPERE 64\n#define  MMQ_Y_Q5_K_AMPERE 128\n#define NWARPS_Q5_K_AMPERE 4\n#endif\n#define  MMQ_X_Q5_K_PASCAL 64\n#define  MMQ_Y_Q5_K_PASCAL 64\n#define NWARPS_Q5_K_PASCAL 8\n\ntemplate <bool need_check> static void\nmul_mat_q5_K(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,\n    sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,\n    sycl::half2 *tile_y_ds) {\n    int   * tile_x_ql = nullptr;\n    sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q5_K_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;\n    const int 
nwarps = NWARPS_Q5_K_AMPERE;\n    allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);\n    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,\n              load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,\n              vec_dot_q5_K_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\n#define  MMQ_X_Q6_K_RDNA2  64\n#define  MMQ_Y_Q6_K_RDNA2  128\n#define NWARPS_Q6_K_RDNA2  8\n#define  MMQ_X_Q6_K_RDNA1  32\n#define  MMQ_Y_Q6_K_RDNA1  64\n#define NWARPS_Q6_K_RDNA1  8\n#if defined(SYCL_USE_XMX)\n#define  MMQ_X_Q6_K_AMPERE 4\n#define  MMQ_Y_Q6_K_AMPERE 32\n#define NWARPS_Q6_K_AMPERE 4\n#else\n#define  MMQ_X_Q6_K_AMPERE 64\n#define  MMQ_Y_Q6_K_AMPERE 64\n#define NWARPS_Q6_K_AMPERE 4\n#endif\n#define  MMQ_X_Q6_K_PASCAL 64\n#define  MMQ_Y_Q6_K_PASCAL 64\n#define NWARPS_Q6_K_PASCAL 8\n\ntemplate <bool need_check> static void\n    mul_mat_q6_K(\n    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,\n    const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,\n    int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {\n    // int   * tile_x_ql = nullptr;\n    // sycl::half2 *tile_x_dm = nullptr;\n    int   * tile_x_qh = nullptr;\n    // int   * tile_x_sc = nullptr;\n\n//sycl_todo: change according to hardware\n    const int mmq_x  =  MMQ_X_Q6_K_AMPERE;\n    const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;\n    const int nwarps = NWARPS_Q6_K_AMPERE;\n    allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,\n                               tile_x_ql, tile_x_dm, tile_x_sc);\n    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,\n      
        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,\n              vec_dot_q6_K_q8_1_mul_mat>(\n        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,\n        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);\n}\n\nstatic void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,\n                                        float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int nrows_dst,\n                                        dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q4_0_RDNA2;\n        mmq_y  =  MMQ_Y_Q4_0_RDNA2;\n        nwarps = NWARPS_Q4_0_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q4_0_RDNA1;\n        mmq_y  =  MMQ_Y_Q4_0_RDNA1;\n        nwarps = NWARPS_Q4_0_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q4_0_AMPERE;\n        mmq_y  =  MMQ_Y_Q4_0_AMPERE;\n        nwarps = NWARPS_Q4_0_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q4_0_PASCAL;\n        mmq_y  =  MMQ_Y_Q4_0_PASCAL;\n        nwarps = NWARPS_Q4_0_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:20: The work-group size passed to the SYCL kernel may exceed\n        the limit. 
To get the device limit, query\n        info::device::max_work_group_size. Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q4_0<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_qs_q4_0_acc_ct1),\n                            get_pointer(tile_x_d_q4_0_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:21: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q4_0<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_qs_q4_0_acc_ct1),\n                            get_pointer(tile_x_d_q4_0_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,\n                                        float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int nrows_dst,\n      
                                  dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q4_1_RDNA2;\n        mmq_y  =  MMQ_Y_Q4_1_RDNA2;\n        nwarps = NWARPS_Q4_1_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q4_1_RDNA1;\n        mmq_y  =  MMQ_Y_Q4_1_RDNA1;\n        nwarps = NWARPS_Q4_1_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q4_1_AMPERE;\n        mmq_y  =  MMQ_Y_Q4_1_AMPERE;\n        nwarps = NWARPS_Q4_1_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q4_1_PASCAL;\n        mmq_y  =  MMQ_Y_Q4_1_PASCAL;\n        nwarps = NWARPS_Q4_1_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:22: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q4_1<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_qs_q4_1_acc_ct1),\n                            get_pointer(tile_x_dm_q4_1_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:23: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q4_1<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_qs_q4_1_acc_ct1),\n                            get_pointer(tile_x_dm_q4_1_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,\n                                        float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int 
nrows_dst,\n                                        dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q5_0_RDNA2;\n        mmq_y  =  MMQ_Y_Q5_0_RDNA2;\n        nwarps = NWARPS_Q5_0_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q5_0_RDNA1;\n        mmq_y  =  MMQ_Y_Q5_0_RDNA1;\n        nwarps = NWARPS_Q5_0_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q5_0_AMPERE;\n        mmq_y  =  MMQ_Y_Q5_0_AMPERE;\n        nwarps = NWARPS_Q5_0_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q5_0_PASCAL;\n        mmq_y  =  MMQ_Y_Q5_0_PASCAL;\n        nwarps = NWARPS_Q5_0_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:24: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q5_0<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q5_0_acc_ct1),\n                            get_pointer(tile_x_d_q5_0_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:25: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q5_0<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q5_0_acc_ct1),\n                            get_pointer(tile_x_d_q5_0_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,\n                                        float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int nrows_dst,\n  
                                      dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q5_1_RDNA2;\n        mmq_y  =  MMQ_Y_Q5_1_RDNA2;\n        nwarps = NWARPS_Q5_1_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q5_1_RDNA1;\n        mmq_y  =  MMQ_Y_Q5_1_RDNA1;\n        nwarps = NWARPS_Q5_1_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q5_1_AMPERE;\n        mmq_y  =  MMQ_Y_Q5_1_AMPERE;\n        nwarps = NWARPS_Q5_1_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q5_1_PASCAL;\n        mmq_y  =  MMQ_Y_Q5_1_PASCAL;\n        nwarps = NWARPS_Q5_1_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:26: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(\n                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q5_1<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q5_1_acc_ct1),\n                            get_pointer(tile_x_dm_q5_1_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:27: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(\n                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q5_1<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q5_1_acc_ct1),\n                            get_pointer(tile_x_dm_q5_1_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,\n                                        float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int 
nrows_dst,\n                                        dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q8_0_RDNA2;\n        mmq_y  =  MMQ_Y_Q8_0_RDNA2;\n        nwarps = NWARPS_Q8_0_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q8_0_RDNA1;\n        mmq_y  =  MMQ_Y_Q8_0_RDNA1;\n        nwarps = NWARPS_Q8_0_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q8_0_AMPERE;\n        mmq_y  =  MMQ_Y_Q8_0_AMPERE;\n        nwarps = NWARPS_Q8_0_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q8_0_PASCAL;\n        mmq_y  =  MMQ_Y_Q8_0_PASCAL;\n        nwarps = NWARPS_Q8_0_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:28: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q8_0<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_qs_q8_0_acc_ct1),\n                            get_pointer(tile_x_d_q8_0_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:29: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q8_0<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_qs_q8_0_acc_ct1),\n                            get_pointer(tile_x_d_q8_0_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,\n                                        float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int nrows_dst,\n      
                                  dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q2_K_RDNA2;\n        mmq_y  =  MMQ_Y_Q2_K_RDNA2;\n        nwarps = NWARPS_Q2_K_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q2_K_RDNA1;\n        mmq_y  =  MMQ_Y_Q2_K_RDNA1;\n        nwarps = NWARPS_Q2_K_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q2_K_AMPERE;\n        mmq_y  =  MMQ_Y_Q2_K_AMPERE;\n        nwarps = NWARPS_Q2_K_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q2_K_PASCAL;\n        mmq_y  =  MMQ_Y_Q2_K_PASCAL;\n        nwarps = NWARPS_Q2_K_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:30: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q2_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q2_K_acc_ct1),\n                            get_pointer(tile_x_dm_q2_K_acc_ct1),\n                            get_pointer(tile_x_sc_q2_K_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:31: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q2_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q2_K_acc_ct1),\n                            get_pointer(tile_x_dm_q2_K_acc_ct1),\n                            get_pointer(tile_x_sc_q2_K_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,\n              
                          float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int nrows_dst,\n                                        dpct::queue_ptr stream) try {\n\n#if QK_K == 256\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q3_K_RDNA2;\n        mmq_y  =  MMQ_Y_Q3_K_RDNA2;\n        nwarps = NWARPS_Q3_K_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q3_K_RDNA1;\n        mmq_y  =  MMQ_Y_Q3_K_RDNA1;\n        nwarps = NWARPS_Q3_K_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q3_K_AMPERE;\n        mmq_y  =  MMQ_Y_Q3_K_AMPERE;\n        nwarps = NWARPS_Q3_K_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q3_K_PASCAL;\n        mmq_y  =  MMQ_Y_Q3_K_PASCAL;\n        nwarps = NWARPS_Q3_K_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:32: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q3_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q3_K_acc_ct1),\n                            get_pointer(tile_x_dm_q3_K_acc_ct1),\n                            get_pointer(tile_x_qh_q3_K_acc_ct1),\n                            get_pointer(tile_x_sc_q3_K_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n  
      DPCT1049:33: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q3_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q3_K_acc_ct1),\n                            get_pointer(tile_x_dm_q3_K_acc_ct1),\n                            get_pointer(tile_x_qh_q3_K_acc_ct1),\n                            get_pointer(tile_x_sc_q3_K_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n           
                 get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n#endif\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,\n                                        float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int nrows_dst,\n                                        dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q4_K_RDNA2;\n        mmq_y  =  MMQ_Y_Q4_K_RDNA2;\n        nwarps = NWARPS_Q4_K_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q4_K_RDNA1;\n        mmq_y  =  MMQ_Y_Q4_K_RDNA1;\n        nwarps = NWARPS_Q4_K_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q4_K_AMPERE;\n        mmq_y  =  MMQ_Y_Q4_K_AMPERE;\n        nwarps = NWARPS_Q4_K_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q4_K_PASCAL;\n        mmq_y  =  MMQ_Y_Q4_K_PASCAL;\n        nwarps = NWARPS_Q4_K_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:34: The work-group size passed to the SYCL kernel may exceed\n      
  the limit. To get the device limit, query\n        info::device::max_work_group_size. Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q4_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q4_K_acc_ct1),\n                            get_pointer(tile_x_dm_q4_K_acc_ct1),\n                            get_pointer(tile_x_sc_q4_K_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:35: The work-group size passed to the SYCL kernel may exceed\n        the limit. 
To get the device limit, query\n        info::device::max_work_group_size. Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q4_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q4_K_acc_ct1),\n                            get_pointer(tile_x_dm_q4_K_acc_ct1),\n                            get_pointer(tile_x_sc_q4_K_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void 
ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,\n                                        float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int nrows_dst,\n                                        dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q5_K_RDNA2;\n        mmq_y  =  MMQ_Y_Q5_K_RDNA2;\n        nwarps = NWARPS_Q5_K_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q5_K_RDNA1;\n        mmq_y  =  MMQ_Y_Q5_K_RDNA1;\n        nwarps = NWARPS_Q5_K_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q5_K_AMPERE;\n        mmq_y  =  MMQ_Y_Q5_K_AMPERE;\n        nwarps = NWARPS_Q5_K_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q5_K_PASCAL;\n        mmq_y  =  MMQ_Y_Q5_K_PASCAL;\n        nwarps = NWARPS_Q5_K_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:36: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q5_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q5_K_acc_ct1),\n                            get_pointer(tile_x_dm_q5_K_acc_ct1),\n                            get_pointer(tile_x_sc_q5_K_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:37: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q5_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_q5_K_acc_ct1),\n                            get_pointer(tile_x_dm_q5_K_acc_ct1),\n                            get_pointer(tile_x_sc_q5_K_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nstatic void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,\n          
                              float *dst, const int ncols_x,\n                                        const int nrows_x, const int ncols_y,\n                                        const int nrows_y, const int nrows_dst,\n                                        dpct::queue_ptr stream) try {\n\n    int id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(id = get_current_device_id()));\n    const int compute_capability = ggml_sycl_info().devices[id].cc;\n\n    int mmq_x, mmq_y, nwarps;\n    if (compute_capability >= VER_GEN13) {\n        mmq_x  =  MMQ_X_Q6_K_RDNA2;\n        mmq_y  =  MMQ_Y_Q6_K_RDNA2;\n        nwarps = NWARPS_Q6_K_RDNA2;\n    } else if (compute_capability >= VER_GEN12) {\n        mmq_x  =  MMQ_X_Q6_K_RDNA1;\n        mmq_y  =  MMQ_Y_Q6_K_RDNA1;\n        nwarps = NWARPS_Q6_K_RDNA1;\n    } else if (compute_capability >= VER_GEN9) {\n        mmq_x  =  MMQ_X_Q6_K_AMPERE;\n        mmq_y  =  MMQ_Y_Q6_K_AMPERE;\n        nwarps = NWARPS_Q6_K_AMPERE;\n    } else if (compute_capability >= VER_4VEC) {\n        mmq_x  =  MMQ_X_Q6_K_PASCAL;\n        mmq_y  =  MMQ_Y_Q6_K_PASCAL;\n        nwarps = NWARPS_Q6_K_PASCAL;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;\n    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;\n    const sycl::range<3> block_nums(1, block_num_y, block_num_x);\n    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);\n\n    if (nrows_x % mmq_y == 0) {\n        const bool need_check = false;\n        /*\n        DPCT1049:38: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(\n                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q6_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_acc_ct1),\n                            get_pointer(tile_x_dm_acc_ct1),\n                            get_pointer(tile_x_sc_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    } else {\n        const bool need_check = true;\n        /*\n        DPCT1049:39: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        {\n            dpct::has_capability_or_fail(stream->get_device(),\n                                         {sycl::aspect::fp16});\n\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(\n                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),\n                    cgh);\n                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(\n                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);\n                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);\n                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(\n                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) {\n                        mul_mat_q6_K<need_check>(\n                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,\n                            nrows_dst, item_ct1,\n                            get_pointer(tile_x_ql_acc_ct1),\n                            get_pointer(tile_x_dm_acc_ct1),\n                            get_pointer(tile_x_sc_acc_ct1),\n                            get_pointer(tile_y_qs_acc_ct1),\n                            get_pointer(tile_y_ds_acc_ct1));\n                    });\n            });\n        }\n    }\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n\nvoid ggml_sycl_op_mul_mat_q(\n    ggml_backend_sycl_context & ctx,\n    const ggml_tensor *src0, const 
ggml_tensor *src1, ggml_tensor *dst,\n    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,\n    float *dst_dd_i, const int64_t row_low, const int64_t row_high,\n    const int64_t src1_ncols, const int64_t src1_padded_row_size,\n    const dpct::queue_ptr &stream) try {\n\n    const int64_t ne00 = src0->ne[0];\n\n    const int64_t ne10 = src1->ne[0];\n    GGML_ASSERT(ne10 % QK8_1 == 0);\n\n    const int64_t ne0 = dst->ne[0];\n\n    const int64_t row_diff = row_high - row_low;\n\n    int device_id;\n    SYCL_CHECK(\n        CHECK_TRY_ERROR(device_id = get_current_device_id()));\n\n    // the main device has a larger memory buffer to hold the results from all GPUs\n    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into\n    const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;\n\n    switch (src0->type) {\n        case GGML_TYPE_Q4_0:\n            ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q4_1:\n            ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q5_0:\n            ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q5_1:\n            ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q8_0:\n            ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q2_K:\n            ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, 
src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q3_K:\n            ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q4_K:\n            ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q5_K:\n            ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        case GGML_TYPE_Q6_K:\n            ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);\n            break;\n        default:\n            GGML_ABORT(\"fatal error\");\n    }\n\n    GGML_UNUSED(src1);\n    GGML_UNUSED(dst);\n    GGML_UNUSED(src1_ddf_i);\n}\ncatch (sycl::exception const &exc) {\n  std::cerr << exc.what() << \"Exception caught at file:\" << __FILE__\n            << \", line:\" << __LINE__ << std::endl;\n  std::exit(1);\n}\n"
  },
  {
    "path": "src/ggml-sycl/mmq.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_MMQ_HPP\n#define GGML_SYCL_MMQ_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_mul_mat_q(\n    ggml_backend_sycl_context & ctx,\n    const ggml_tensor* src0,\n    const ggml_tensor* src1,\n    ggml_tensor* dst,\n    const char* src0_dd_i,\n    const float* src1_ddf_i,\n    const char* src1_ddq_i,\n    float* dst_dd_i,\n    const int64_t row_low,\n    const int64_t row_high,\n    const int64_t src1_ncols,\n    const int64_t src1_padded_row_size,\n    const dpct::queue_ptr& stream);\n\n#endif // GGML_SYCL_MMQ_HPP\n"
  },
  {
    "path": "src/ggml-sycl/mmvq.cpp",
    "content": "#include \"mmvq.hpp\"\n\n#include \"ggml.h\"\n#include \"common.hpp\"\n#include \"quants.hpp\"\n#include \"vecdotq.hpp\"\n\ntemplate <typename reorder_vec_dot_q_sycl>\nstatic void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n                                  const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {\n    using block_type   = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;\n    using block_traits = typename block_type::traits;\n\n    const auto sg           = nd_item.get_sub_group();\n    const int  sg_range     = sg.get_group_linear_range();\n    const int  workgroup_id = nd_item.get_group_linear_id();\n    const int  sg_id        = sg.get_group_linear_id();\n    const int  row          = workgroup_id * sg_range + sg_id;\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int     blocks_per_row              = ncols / block_traits::qk;\n    constexpr int blocks_per_subgroup         = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);\n    constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;\n    const int     nblocks                     = nrows * (ncols / block_traits::qk);\n\n    static_assert(blocks_per_subgroup > 0);\n    static_assert(block_elements_per_subgroup > 0);\n\n    float partial_sum = 0.0f;\n    for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {\n        const int ibx = row * blocks_per_row + i;  // x block index\n\n        const auto         bx_offset      = block_type::get_block_offset(ibx, nblocks);\n        const auto         d_offset       = block_type::get_d_offset(nrows, ncols, ibx);\n        // Y block index that aligns with ibx\n        const int iby = i * block_type::block_to_q8_1_ratio();\n        const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;\n        const sycl::half2* 
q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));\n\n#pragma unroll\n        for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {\n            // x block quant index when casting the quants to int\n            const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);\n\n            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);\n        }\n    }\n\n    auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>());\n\n    if (sg.leader()) {\n        dst[row] = sum;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>\nstatic void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,\n                          const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int     blocks_per_row  = ncols / qk;\n    constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;  // Ensuring blocks_per_warp > 0\n\n    assert(blocks_per_warp > 0);\n\n    // partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t *  x = (const block_q_t *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {\n        const int ibx = row * blocks_per_row + i;  // x block index\n\n        const int iby = i * (qk / QK8_1);          // y block index that aligns with ibx\n\n        for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {\n            const int iqs = elem + vdr * (item_ct1.get_local_id(2) %\n                                          (qi / vdr));  // x block quant index when casting the quants to int\n\n 
           tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);\n        }\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,\n                                       const void *__restrict__ vy,\n                                       float *__restrict__ dst, const int ncols,\n                                       const int nrows,\n                                       const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n\n// partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,\n                                      const void *__restrict__ vy,\n                                      float *__restrict__ dst, const int ncols,\n                                      const int nrows,\n                                      const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n// partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,\n             
                        const void *__restrict__ vy,\n                                     float *__restrict__ dst, const int ncols,\n                                     const int nrows,\n                                     const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n// partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,\n                                       const void *__restrict__ vy,\n                                       float *__restrict__ dst, const int ncols,\n                                       const int nrows,\n                                       const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * 
item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n// partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,\n                                     const void *__restrict__ vy,\n                                     float *__restrict__ dst, const int ncols,\n                                     const int nrows,\n                                     const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n// partial sum for each thread\n   
 float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,\n                                     const void *__restrict__ vy,\n                                     float *__restrict__ dst, const int ncols,\n                                     const int nrows,\n                                     const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n// partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n  
      const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,\n                                     const void *__restrict__ vy,\n                                     float *__restrict__ dst, const int ncols,\n                                     const int nrows,\n                                     const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n// partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);\n 
   }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,\n                                      const void *__restrict__ vy,\n                                      float *__restrict__ dst, const int ncols,\n                                      const int nrows,\n                                      const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n// partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    
}\n}\n\n\ntemplate <int qk, int qi, typename block_q_t, int vdr>\nstatic void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,\n                                      const void *__restrict__ vy,\n                                      float *__restrict__ dst, const int ncols,\n                                      const int nrows,\n                                      const sycl::nd_item<3> &item_ct1) {\n    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +\n                    item_ct1.get_local_id(1);\n\n    if (row >= nrows) {\n        return;\n    }\n\n    const int blocks_per_row = ncols / qk;\n    const int blocks_per_warp = vdr * WARP_SIZE / qi;\n    assert(blocks_per_warp>0);\n// partial sum for each thread\n    float tmp = 0.0f;\n\n    const block_q_t  * x = (const block_q_t  *) vx;\n    const block_q8_1 * y = (const block_q8_1 *) vy;\n\n    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;\n         i += blocks_per_warp) {\n        const int ibx = row*blocks_per_row + i; // x block index\n\n        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx\n\n        const int iqs =\n            vdr *\n            (item_ct1.get_local_id(2) %\n             (qi / vdr)); // x block quant index when casting the quants to int\n\n        tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);\n    }\n\n    // sum up partial sums and write back result\n#pragma unroll\n    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {\n        tmp +=\n            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);\n    }\n\n    if (item_ct1.get_local_id(2) == 0) {\n        dst[row] = tmp;\n    }\n}\n\nstatic void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,\n                                                    const int nrows, dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK4_0 == 0);\n    const int        block_num_y   = ceil_div(nrows, 
GGML_SYCL_MMV_Y);\n    constexpr size_t num_subgroups = 16;\n    GGML_ASSERT(block_num_y % num_subgroups == 0);\n\n    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));\n    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);\n\n    stream->submit([&](sycl::handler & cgh) {\n        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),\n                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,\n                                                                                           nd_item);\n                         });\n    });\n}\n\nstatic void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK4_0 == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n\n    {\n        stream->submit([&](sycl::handler & cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                                 mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(\n                                     vx, vy, dst, ncols, nrows, item_ct1);\n                             });\n        });\n    }\n}\n\nstatic void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const int ncols,\n                                       const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK4_1 
== 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,\n                                      VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,\n                                        dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_MXFP4 == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n\n    {\n        stream->submit([&](sycl::handler & cgh) {\n            cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                                 mul_mat_vec_q<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(\n                                     vx, vy, dst, ncols, nrows, item_ct1);\n                             });\n        });\n    }\n}\n\n\nstatic void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const int ncols,\n                                       const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK5_0 == 0);\n    
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,\n                                      VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const int ncols,\n                                       const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK5_1 == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,\n                                      VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const int ncols,\n                                       const int nrows,\n                    
                   dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK8_0 == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,\n                                      VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const int ncols,\n                                       const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK_K, QI2_K, block_q2_K,\n                                      VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const 
int ncols,\n                                       const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK_K, QI3_K, block_q3_K,\n                                      VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const int ncols,\n                                       const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK_K, QI4_K, block_q4_K,\n                                      VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void 
reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,\n    const int nrows, dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n\n    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);\n    constexpr size_t num_subgroups = 16;\n    GGML_ASSERT(block_num_y % num_subgroups == 0);\n\n    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);\n    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);\n\n    stream->submit([&](sycl::handler & cgh) {\n        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),\n                            [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                                mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,\n                                                                                            nrows, nd_item);\n                            });\n    });\n}\n\n\nstatic void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const int ncols,\n                                       const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK_K, QI5_K, block_q5_K,\n                                      VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(\n                            vx, vy, dst, ncols, nrows, 
item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,\n                                               const int nrows, dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int        block_num_y   = ceil_div(nrows, GGML_SYCL_MMV_Y);\n    constexpr size_t num_subgroups = 16;\n    GGML_ASSERT(block_num_y % num_subgroups == 0);\n\n    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);\n    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);\n\n    stream->submit([&](sycl::handler & cgh) {\n        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),\n                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,\n                                                                                           nd_item);\n                         });\n    });\n}\nstatic void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,\n                                       float *dst, const int ncols,\n                                       const int nrows,\n                                       dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q<QK_K, QI6_K, block_q6_K,\n                          
            VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\n\nstatic void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,\n                                          float *dst, const int ncols,\n                                          const int nrows,\n                                          dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,\n                                         float *dst, const int ncols,\n                                         const int nrows,\n                                         dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        stream->submit([&](sycl::handler & cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,\n                                         float *dst, const int ncols,\n                                         const int nrows,\n                                         dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,\n                                          float *dst, const int ncols,\n                                          const int nrows,\n                                          dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,\n                                          float *dst, const int ncols,\n                                          const int nrows,\n                                          dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,\n                                          float *dst, const int ncols,\n                                          const int nrows,\n                                          dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,\n                                          float *dst, const int ncols,\n                                          const int nrows,\n                                          dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,\n                                          float *dst, const int ncols,\n                                          const int nrows,\n                                          dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK4_NL == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nstatic void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,\n                                          float *dst, const int ncols,\n                                          const int nrows,\n                                          dpct::queue_ptr stream) {\n    GGML_ASSERT(ncols % QK_K == 0);\n    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;\n    const sycl::range<3> block_nums(1, 1, block_num_y);\n    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);\n    {\n\n        stream->submit([&](sycl::handler &cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                        mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(\n                            vx, vy, dst, ncols, nrows, item_ct1);\n                    });\n        });\n    }\n}\n\nvoid ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,\n                                ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,\n                                const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,\n                                const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,\n                                const dpct::queue_ptr & stream) {\n    const int64_t ne10 = src1->ne[0];\n    GGML_ASSERT(ne10 % QK8_1 == 0);\n\n    const int64_t ne00     = src0->ne[0];\n    const int64_t row_diff = row_high - row_low;\n\n    int id;\n    SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));\n    const size_t q8_1_ts = sizeof(block_q8_1);\n    const size_t q8_1_bs 
= QK8_1;\n    // the main device has a larger memory buffer to hold the results from all GPUs\n    // nrows_dst == nrows of the matrix that the kernel writes into\n\n    for (int i = 0; i < src1_ncols; i++) {\n        const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;\n        const char * src1_ddq_i_bs     = src1_ddq_i + src1_ddq_i_offset;\n        float *      dst_dd_i_bs       = dst_dd_i + i * dst->ne[0];\n        switch (src0->type) {\n            case GGML_TYPE_Q4_0:\n                if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&\n                    ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {\n                    GGML_SYCL_DEBUG(\"Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\\n\");\n                    reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                } else {\n                    GGML_SYCL_DEBUG(\"Calling mul_mat_vec_q4_0_q8_1_sycl\\n\");\n                    mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                }\n                break;\n            case GGML_TYPE_Q4_1:\n                mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_Q5_0:\n                mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_Q5_1:\n                mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_Q8_0:\n                mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_Q2_K:\n                mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case 
GGML_TYPE_Q3_K:\n                mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_Q4_K:\n                if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&\n                    ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {\n                    GGML_SYCL_DEBUG(\"Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\\n\");\n                    reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                } else {\n                    GGML_SYCL_DEBUG(\"Calling mul_mat_vec_q4_K_q8_1_sycl\\n\");\n                    mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                }\n                break;\n            case GGML_TYPE_Q5_K:\n                mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_Q6_K:\n                if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&\n                    ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {\n                    GGML_SYCL_DEBUG(\"Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\\n\");\n                    reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                } else {\n                    GGML_SYCL_DEBUG(\"Calling mul_mat_vec_q6_k_q8_1_sycl\\n\");\n                    mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                }\n                break;\n            case GGML_TYPE_IQ1_S:\n                mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_IQ1_M:\n                mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n     
       case GGML_TYPE_IQ2_XXS:\n                mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_IQ2_XS:\n                mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_IQ2_S:\n                mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_IQ3_XXS:\n                mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_IQ3_S:\n                mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_IQ4_NL:\n                mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_IQ4_XS:\n                mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            case GGML_TYPE_MXFP4:\n                mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);\n                break;\n            default:\n                GGML_ABORT(\"fatal error\");\n        }\n    }\n    GGML_UNUSED(src1);\n    GGML_UNUSED(dst);\n    GGML_UNUSED(src1_ddf_i);\n    GGML_UNUSED(ctx);\n}\n"
  },
  {
    "path": "src/ggml-sycl/mmvq.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_MMVQ_HPP\n#define GGML_SYCL_MMVQ_HPP\n\n#include \"common.hpp\"\n\n\nvoid ggml_sycl_op_mul_mat_vec_q(\n    ggml_backend_sycl_context & ctx,\n    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,\n    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,\n    float *dst_dd_i, const int64_t row_low, const int64_t row_high,\n    const int64_t src1_ncols, const int64_t src1_padded_row_size,\n    const dpct::queue_ptr &stream);\n\n#endif // GGML_SYCL_MMVQ_HPP\n"
  },
  {
    "path": "src/ggml-sycl/norm.cpp",
    "content": "#include \"norm.hpp\"\n#include \"ggml-sycl/common.hpp\"\n#include \"ggml-sycl/presets.hpp\"\n\nstatic void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,\n        const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {\n\n    const int nrows = item_ct1.get_group_range(2);\n    const int nchannels = item_ct1.get_group_range(1);\n\n    const int nthreads = item_ct1.get_local_range(2);\n    const int sample  = item_ct1.get_group(0);\n    const int channel = item_ct1.get_group(1);\n    const int row     = item_ct1.get_group(2);\n\n    const int tid = item_ct1.get_local_id(2);\n    const int nwarps = nthreads / WARP_SIZE;\n\n    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});\n    const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});\n\n    x += strided_offset;\n    dst += packed_offset;\n\n    sycl::float2 mean_var = sycl::float2(0.f, 0.f);\n\n    for (int col = tid; col < ncols; col += block_size) {\n        const float xi = x[col];\n        mean_var.x() += xi;\n        mean_var.y() += xi * xi;\n    }\n\n    // sum up partial sums\n    mean_var = warp_reduce_sum(mean_var, item_ct1);\n    if  (block_size > WARP_SIZE) {\n        const auto sub_group = item_ct1.get_sub_group();\n        const auto sg_id = sub_group.get_group_linear_id();\n        const auto wi_in_sg = sub_group.get_local_linear_id();\n        if (wi_in_sg == 0) {\n            s_sum[sg_id] = mean_var;\n        }\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n        mean_var = 0.f;\n        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);\n        for (size_t i = 0; i < nreduce; i += 1)\n        {\n            mean_var += s_sum[wi_in_sg + i * WARP_SIZE];\n        }\n        mean_var = 
warp_reduce_sum(mean_var, item_ct1);\n    }\n\n    const float mean = mean_var.x() / ncols;\n    const float var = mean_var.y() / ncols - mean * mean;\n    const float inv_std = sycl::rsqrt(var + eps);\n\n    for (int col = tid; col < ncols; col += block_size) {\n        dst[col] = (x[col] - mean) * inv_std;\n    }\n}\n\nstatic void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,\n    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {\n    int start = item_ct1.get_group(2) * group_size;\n    int end = start + group_size;\n    const int nthreads = item_ct1.get_local_range(2);\n    const int nwarps = nthreads / WARP_SIZE;\n    start += item_ct1.get_local_id(2);\n    size_t nreduce = nwarps / WARP_SIZE;\n\n    if (end >= ne_elements) {\n        end = ne_elements;\n    }\n\n    float tmp = 0.0f; // partial sum for thread in warp\n\n    for (int j = start; j < end; j += block_size) {\n        tmp += x[j];\n    }\n\n    tmp = warp_reduce_sum(tmp, item_ct1);\n    if (block_size > WARP_SIZE) {\n\n        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;\n        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;\n        if (lane_id == 0) {\n            s_sum[warp_id] = tmp;\n        }\n        /*\n        DPCT1118:1: SYCL group functions and algorithms must be encountered in\n        converged control flow. 
You may need to adjust the code.\n        */\n        /*\n        DPCT1065:54: Consider replacing sycl::nd_item::barrier() with\n        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for\n        better performance if there is no access to global memory.\n        */\n        item_ct1.barrier();\n        tmp = 0.f;\n        for (size_t i = 0; i < nreduce; i += 1)\n        {\n            tmp += s_sum[lane_id + i * WARP_SIZE];\n        }\n        tmp = warp_reduce_sum(tmp, item_ct1);\n    }\n\n    float mean = tmp / group_size;\n    tmp = 0.0f;\n\n    for (int j = start; j < end; j += block_size) {\n        float xi = x[j] - mean;\n        dst[j] = xi;\n        tmp += xi * xi;\n    }\n\n    tmp = warp_reduce_sum(tmp, item_ct1);\n    if (block_size > WARP_SIZE) {\n\n        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;\n        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;\n        if (lane_id == 0) {\n            s_sum[warp_id] = tmp;\n        }\n        /*\n        DPCT1118:2: SYCL group functions and algorithms must be encountered in\n        converged control flow. 
You may need to adjust the code.\n        */\n        /*\n        DPCT1065:55: Consider replacing sycl::nd_item::barrier() with\n        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for\n        better performance if there is no access to global memory.\n        */\n        item_ct1.barrier();\n        tmp = 0.f;\n        for (size_t i = 0; i < nreduce; i += 1)\n        {\n            tmp += s_sum[lane_id + i * WARP_SIZE];\n        }\n        tmp = warp_reduce_sum(tmp, item_ct1);\n    }\n\n    float variance = tmp / group_size;\n    float scale = sycl::rsqrt(variance + eps);\n    for (int j = start; j < end; j += block_size) {\n        dst[j] *= scale;\n    }\n}\n\nstatic void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,\n        const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {\n\n    const int nrows = item_ct1.get_group_range(2);\n    const int nchannels = item_ct1.get_group_range(1);\n\n    const int sample  = item_ct1.get_group(0);\n    const int channel = item_ct1.get_group(1);\n    const int row     = item_ct1.get_group(2);\n\n    const int nthreads = item_ct1.get_local_range(2);\n\n    const int tid = item_ct1.get_local_id(2);\n    const int nwarps = nthreads / WARP_SIZE;\n\n    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});\n    const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});\n\n    x   += strided_offset;\n    dst += packed_offset;\n\n\n    float tmp = 0.0f; // partial sum for thread in warp\n\n    for (int col = tid; col < ncols; col += block_size) {\n        const float xi = x[col];\n        tmp += xi * xi;\n    }\n\n    // sum up partial sums\n    tmp = warp_reduce_sum(tmp, item_ct1);\n    if (block_size > WARP_SIZE) {\n        const auto sub_group = 
item_ct1.get_sub_group();\n        const auto sg_id = sub_group.get_group_linear_id();\n        const auto wi_in_sg = sub_group.get_local_linear_id();\n        if (wi_in_sg == 0) {\n            s_sum[sg_id] = tmp;\n        }\n\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);\n        tmp = 0.f;\n        for (size_t i = 0; i < nreduce; i += 1)\n        {\n            tmp += s_sum[wi_in_sg + i * WARP_SIZE];\n        }\n        tmp = warp_reduce_sum(tmp, item_ct1);\n    }\n\n    const float mean = tmp / ncols;\n    const float scale = sycl::rsqrt(mean + eps);\n\n    for (int col = tid; col < ncols; col += block_size) {\n        dst[col] = scale * x[col];\n    }\n}\n\ntemplate<int warp_size>\nstatic void l2_norm_f32(const float * x, float * dst, const int ncols,\n    const int64_t stride_row, const int64_t stride_channel,\n    const int64_t stride_sample, const float eps,\n    const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {\n    const int nrows     = item_ct1.get_group_range(2);\n    const int nchannels = item_ct1.get_group_range(1);\n\n    const int row     = item_ct1.get_group(2);\n    const int channel = item_ct1.get_group(1);\n    const int sample  = item_ct1.get_group(0);\n    const int tid     = item_ct1.get_local_id(2);\n\n    x   += sample*stride_sample + channel*stride_channel + row*stride_row;\n    dst += ((sample*nchannels + channel)*nrows + row)*ncols;\n\n    float tmp = 0.0f; // partial sum for thread in warp\n\n    for (int col = tid; col < ncols; col += block_size) {\n        const float xi = x[col];\n        tmp += xi * xi;\n    }\n\n    tmp = block_reduce<block_reduce_method::SUM, warp_size>(tmp, s_sum, block_size);\n    const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));\n\n    for (int col = tid; col < ncols; col += block_size) {\n        dst[col] = scale * x[col];\n    }\n}\n\nstatic void norm_f32_sycl(const float * x, float * dst, const 
int ncols, const int nrows, const int nchannels, const int nsamples,\n        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,\n        const float eps, queue_ptr stream, int device) {\n\n    const sycl::range<3> global_dims(nsamples, nchannels, nrows);\n    if (ncols < 1024) {\n        const sycl::range<3> block_dims(1, 1, WARP_SIZE);\n        stream->submit([&](sycl::handler& cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(global_dims * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                    norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);\n                });\n            });\n    }\n    else {\n        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];\n        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);\n        const sycl::range<3> block_dims(1, 1, work_group_size);\n        /*\n        DPCT1049:17: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n        stream->submit([&](sycl::handler& cgh) {\n            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(\n                            sycl::range<1>(work_group_size / WARP_SIZE), cgh);\n            cgh.parallel_for(\n                sycl::nd_range<3>(global_dims * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                    norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);\n                });\n            });\n    }\n}\n\nstatic void group_norm_f32_sycl(const float* x, float* dst,\n    const int num_groups, const float eps, const int group_size,\n    const int ne_elements, queue_ptr stream, int device) {\n    if (group_size < 1024) {\n        const sycl::range<3> block_dims(1, 1, WARP_SIZE);\n        stream->submit([&](sycl::handler& cgh) {\n            const float eps_ct4 = eps;\n            cgh.parallel_for(\n                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,\n                    block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                    group_norm_f32(\n                        x, dst, group_size, ne_elements, eps_ct4, item_ct1,\n                        nullptr, WARP_SIZE);\n                });\n            });\n    }\n    else {\n        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];\n        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);\n        const sycl::range<3> block_dims(1, 1, work_group_size);\n        /*\n        DPCT1049:18: The work-group size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. 
Adjust the work-group size if needed.\n        */\n\n        stream->submit([&](sycl::handler& cgh) {\n            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),\n                cgh);\n\n            const float eps_ct4 = eps;\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,\n                    block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                    group_norm_f32(x, dst, group_size, ne_elements,\n                        eps_ct4, item_ct1,\n                        get_pointer(s_sum_acc_ct1), work_group_size);\n                });\n            });\n    }\n}\n\nstatic void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,\n        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {\n    // printf(\"%s ncols=%d, nrows=%d, WARP_SIZE=%d\\n\", __func__, ncols, nrows, WARP_SIZE);\n\n    const sycl::range<3> global_dims(nsamples, nchannels, nrows);\n    if (ncols < 1024) {\n        const sycl::range<3> block_dims(1, 1, WARP_SIZE);\n        stream->submit([&](sycl::handler& cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(global_dims * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                    rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);\n                });\n            });\n    }\n    else {\n        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];\n        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);\n        const sycl::range<3> block_dims(1, 1, work_group_size);\n        /*\n        DPCT1049:19: The work-group 
size passed to the SYCL kernel may exceed\n        the limit. To get the device limit, query\n        info::device::max_work_group_size. Adjust the work-group size if needed.\n        */\n        stream->submit([&](sycl::handler& cgh) {\n            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),\n                cgh);\n            cgh.parallel_for(\n                sycl::nd_range<3>(global_dims * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                    rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);\n                });\n            });\n    }\n}\n\ntemplate<int warp_size>\nstatic void l2_norm_f32_sycl(const float *   x,\n                             float *         dst,\n                             const int       ncols,\n                             const int       nrows,\n                             const int       nchannels,\n                             const int       nsamples,\n                             const int64_t   stride_row,\n                             const int64_t   stride_channel,\n                             const int64_t   stride_sample,\n                             const float     eps,\n                             queue_ptr       stream,\n                             int             device) {\n    const dpct::dim3 blocks_num(nrows, nchannels, nsamples);\n\n    if (ncols < 1024) {\n        const dpct::dim3 block_dims(warp_size, 1, 1);\n        stream->submit([&](sycl::handler& cgh) {\n            cgh.parallel_for(\n                sycl::nd_range<3>(blocks_num * block_dims,\n                    block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(warp_size)]] {\n                    l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, 
item_ct1,\n                        nullptr, warp_size);\n                });\n            });\n    }\n    else {\n        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];\n        assert(work_group_size % (warp_size * warp_size) == 0);\n        const sycl::range<3> block_dims(1, 1, work_group_size);\n        int lsm_size =  block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float): 0;\n        stream->submit([&](sycl::handler& cgh) {\n            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(lsm_size),\n                cgh);\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(blocks_num * block_dims,\n                    block_dims),\n                [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(warp_size)]] {\n                    l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,\n                        eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);\n                });\n            });\n    }\n}\n\nvoid ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {\n    const ggml_tensor * src0 = dst->src[0];\n\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n    GGML_ASSERT(eps >= 0.0f);\n    const size_t ts0 = ggml_type_size(src0->type);\n    GGML_ASSERT(nb00 == ts0);\n    const int64_t s01 = nb01 / ts0;\n    const int64_t s02 = nb02 / ts0;\n    const int64_t s03 = nb03 / ts0;\n\n    norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);\n}\n\nvoid 
ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {\n\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    int num_groups = dst->op_params[0];\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    float eps;\n    memcpy(&eps, dst->op_params + 1, sizeof(float));\n\n    int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups);\n    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);\n}\n\nvoid ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    dpct::queue_ptr main_stream = ctx.stream();\n    SYCL_CHECK(ggml_sycl_set_device(ctx.device));\n\n    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);\n    float *       dst_dd  = static_cast<float *>(dst->data);\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n    const size_t ts0 = ggml_type_size(src0->type);\n    GGML_ASSERT(nb00 == ts0);\n    const int64_t s01 = nb01 / ts0;\n    const int64_t s02 = nb02 / ts0;\n    const int64_t s03 = nb03 / ts0;\n    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);\n}\n\nvoid ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz\n    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x\n    
GGML_ASSERT(dst->type         == GGML_TYPE_F32);\n\n    float eps = 1e-5f;\n    std::memcpy(&eps, dst->op_params, sizeof(float));\n    if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f;\n\n    const float * g_base  = static_cast<const float *>(dst->src[0]->data); // dz\n    const float * x_base  = static_cast<const float *>(dst->src[1]->data); // x\n          float * dx_base = static_cast<      float *>(dst->data);\n\n    const int64_t D  = dst->ne[0];\n    const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;\n    const int64_t N  = ggml_nrows(dst);\n    if (D == 0 || N == 0) return;\n\n    const ggml_tensor *G = dst->src[0];\n    const ggml_tensor *X = dst->src[1];\n    const int ts = (int) ggml_type_size(X->type);\n    GGML_ASSERT((size_t) X->nb[0]   == (size_t) ts);\n    GGML_ASSERT((size_t) G->nb[0]   == (size_t) ts);\n    GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);\n\n    const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;\n    const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;\n    const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;\n\n    dpct::queue_ptr q = ctx.stream();\n\n    // work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D\n    const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];\n    auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };\n    int wg_cap = 256;\n    if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);\n    int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));\n\n    // FP32 path: per-thread compensated accumulation + hierarchical reduction\n    q->submit([&](sycl::handler &cgh) {\n        const int nwarps_loc = std::max(1, WG / WARP_SIZE);\n        // store one partial value per warp (xx and xg) for cross-warp reduction\n        auto l_xx   = sycl::local_accessor<sycl::float2, 
1>(sycl::range<1>(nwarps_loc), cgh);\n        auto l_xg   = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);\n\n        cgh.parallel_for(\n            sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),\n                              sycl::range<3>(1, 1, WG)),\n            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                const int row = item_ct1.get_group(2);\n                const int tid = item_ct1.get_local_id(2);\n\n                const int64_t i1 = row % n1;\n                const int64_t i2 = (row / n1) % n2;\n                const int64_t i3 = row / (n1 * n2);\n\n                const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1;\n                const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;\n                float *__restrict d_row       = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;\n\n                // per-thread accumulation (compensated by default)\n                float sum_xx = 0.f, sum_xg = 0.f;\n#ifndef GGML_SYCL_RMS_BACK_FAST\n                float c_xx = 0.f, c_xg = 0.f;\n#endif\n                for (int64_t col = tid; col < D; col += WG) {\n                    const float xv = x_row[col];\n                    const float gv = g_row[col];\n#ifdef GGML_SYCL_RMS_BACK_FAST\n                    sum_xx += xv * xv;\n                    sum_xg += xv * gv;\n#else\n                    float y1 = xv * xv - c_xx;\n                    float t1 = sum_xx + y1;\n                    c_xx = (t1 - sum_xx) - y1;\n                    sum_xx = t1;\n\n                    float y2 = xv * gv - c_xg;\n                    float t2 = sum_xg + y2;\n                    c_xg = (t2 - sum_xg) - y2;\n                    sum_xg = t2;\n#endif\n                }\n\n                // warp-level reduction\n                sycl::float2 xx = sycl::float2(sum_xx,\n#ifndef GGML_SYCL_RMS_BACK_FAST\n                    c_xx\n#else\n                    
0.f\n#endif\n                );\n                sycl::float2 xg = sycl::float2(sum_xg,\n#ifndef GGML_SYCL_RMS_BACK_FAST\n                    c_xg\n#else\n                    0.f\n#endif\n                );\n                xx = warp_reduce_sum(xx, item_ct1);\n                xg = warp_reduce_sum(xg, item_ct1);\n\n                // cross-warp reduction using local memory (single barrier)\n                const auto sub_group = item_ct1.get_sub_group();\n                const auto sg_id     = sub_group.get_group_linear_id();\n                const auto wi_in_sg  = sub_group.get_local_linear_id();\n                const int nthreads   = item_ct1.get_local_range(2);\n                const int nwarps     = nthreads / WARP_SIZE;\n\n                sycl::float2 xx_total = xx;\n                sycl::float2 xg_total = xg;\n                if (nwarps > 1) {\n                    if (wi_in_sg == 0) {\n                        l_xx[sg_id] = xx;\n                        l_xg[sg_id] = xg;\n                    }\n                    item_ct1.barrier(sycl::access::fence_space::local_space);\n\n                    if (sg_id == 0) {\n                        const unsigned wi_u = wi_in_sg;\n                        sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);\n                        sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? 
l_xg[wi_u] : sycl::float2(0.f, 0.f);\n                        xx_total = warp_reduce_sum(xx_first, item_ct1);\n                        xg_total = warp_reduce_sum(xg_first, item_ct1);\n                    } else {\n                        // other subgroups keep their local totals; they'll be ignored\n                        xx_total = xx;\n                        xg_total = xg;\n                    }\n                    // ensure all threads see the first-subgroup result via broadcast below\n                }\n\n                // compute inv_r and coeff once per row and broadcast to the whole work-group\n                float inv_r = 0.f;\n                float coeff = 0.f;\n                if (tid == 0) {\n                    const float sum_xx_f  = xx_total.x() + xx_total.y();\n                    const float sum_xdz_f = xg_total.x() + xg_total.y();\n                    const float mean_eps  = sum_xx_f / (float) D + eps;\n                    const float sum_eps   = sum_xx_f + eps * (float) D;\n                    inv_r = sycl::rsqrt(mean_eps);\n                    coeff = -sum_xdz_f / sum_eps;\n                }\n                inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);\n                coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);\n\n                for (int64_t col = tid; col < D; col += WG) {\n                    d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;\n                }\n            });\n    });\n\n}\n\nvoid ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *) src0->data;\n    float * dst_d = (float *) dst->data;\n    dpct::queue_ptr     stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_UNARY_OP_LOCALS;\n\n    float eps;\n    memcpy(&eps, dst->op_params, sizeof(float));\n    GGML_ASSERT(eps >= 0.0f);\n\n    const 
size_t ts0 = ggml_type_size(src0->type);\n    GGML_ASSERT(nb00 == ts0);\n    const int64_t s01 = nb01 / ts0;\n    const int64_t s02 = nb02 / ts0;\n    const int64_t s03 = nb03 / ts0;\n\n    /*support both WARP_SIZE or WARP_32_SIZE in code\n      choose by hardware for better performance\n    */\n    l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);\n}\n"
  },
  {
    "path": "src/ggml-sycl/norm.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_NORM_HPP\n#define GGML_SYCL_NORM_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);\n\nvoid ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);\n\nvoid ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context& ctx, ggml_tensor* dst);\n\nvoid ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);\n\nvoid ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);\n\n#endif // GGML_SYCL_NORM_HPP\n"
  },
  {
    "path": "src/ggml-sycl/outprod.cpp",
    "content": "#include \"outprod.hpp\"\n\nvoid ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    const ggml_tensor *src0 = dst->src[0];\n    const ggml_tensor *src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(dst));\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    // Get SYCL queue\n    dpct::queue_ptr stream = ctx.stream();\n\n    // Dimension checks\n    GGML_ASSERT(ne01 == ne11);  // Inner dimensions must match\n    GGML_ASSERT(ne0 == ne00);   // Output rows match src0 rows\n    GGML_ASSERT(ne1 == ne10);   // Output cols match src1 cols\n\n    // Get data pointers\n    const float* src0_d = (const float*)src0->data;\n    const float* src1_d = (const float*)src1->data;\n    float* dst_d = (float*)dst->data;\n\n    // GEMM parameters\n    const float alpha = 1.0f;\n    const float beta = 0.0f;\n\n    // Handle transposition of src1\n    const bool src1_T = ggml_is_transposed(src1);\n    const oneapi::mkl::transpose src1_op = src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;\n    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);\n\n    try {\n        // Perform matrix multiplication using oneMKL GEMM\n        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op,\n                                               ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);\n    }\n    catch (sycl::exception const& exc) {\n        std::cerr << exc.what() << std::endl;\n        GGML_ASSERT(false);\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/outprod.hpp",
    "content": "#ifndef GGML_SYCL_OUTPROD_HPP\n#define GGML_SYCL_OUTPROD_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst);\n\n\n#endif // GGML_SYCL_OUTPROD_HPP\n\n"
  },
  {
    "path": "src/ggml-sycl/pad.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2025 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n//#include \"common.hpp\"\n#include \"pad.hpp\"\n\nstatic void pad_f32(const float * src, float * dst,\n                    const int lp0, const int rp0, const int lp1, const int rp1,\n                    const int lp2, const int rp2, const int lp3, const int rp3,\n                    const int ne0, const int ne1, const int ne2, const int ne3,\n                    sycl::nd_item<3> item_ct1) {\n    int i0 = item_ct1.get_local_id(2) +\n             item_ct1.get_group(2) * item_ct1.get_local_range(2);\n    int i1 = item_ct1.get_group(1);\n    int i2 = item_ct1.get_group(0) % ne2;\n    int i3 = item_ct1.get_group(0) / ne2;\n    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {\n        return;\n    }\n\n    // operation\n    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;\n    if ((i0 >= lp0 && i0 < ne0 - rp0) &&\n        (i1 >= lp1 && i1 < ne1 - rp1) &&\n        (i2 >= lp2 && i2 < ne2 - rp2) &&\n        (i3 >= lp3 && i3 < ne3 - rp3)) {\n        const int64_t i00 = i0 - lp0;\n        const int64_t i01 = i1 - lp1;\n        const int64_t i02 = i2 - lp2;\n        const int64_t i03 = i3 - lp3;\n        const int64_t ne02 = ne2 - lp2 - rp2;\n        const int64_t ne01 = ne1 - lp1 - rp1;\n        const int64_t ne00 = ne0 - lp0 - rp0;\n\n        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +\n                                i02 * (ne00 * ne01) + i01 * ne00 + i00;\n\n        dst[dst_idx] = src[src_idx];\n    } else {\n        dst[dst_idx] = 0.0f;\n    }\n}\n\nstatic void pad_f32_sycl(const float *src, float *dst, const int lp0,\n                         const int rp0, const int lp1, const int rp1,\n                     
    const int lp2, const int rp2, const int lp3,\n                         const int rp3, const int ne0, const int ne1,\n                         const int ne2, const int ne3,\n                         dpct::queue_ptr stream) {\n    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;\n    dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);\n    stream->parallel_for(\n        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),\n                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),\n        [=](sycl::nd_item<3> item_ct1) {\n            pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,\n                    ne2, ne3, item_ct1);\n        });\n}\n\nvoid ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    dpct::queue_ptr     stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];\n    const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];\n    const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];\n    const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];\n    const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];\n    const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];\n    const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];\n    const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];\n\n    pad_f32_sycl(src0_d, dst_d,\n                 lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,\n                 dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);\n}\n\nvoid ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    ggml_sycl_op_pad(ctx, dst);\n}\n"
  },
  {
    "path": "src/ggml-sycl/pad.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2025 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_PAD_HPP\n#define GGML_SYCL_PAD_HPP\n\n#include \"common.hpp\"\n\n#define SYCL_PAD_BLOCK_SIZE 256\n\nvoid ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif // GGML_SYCL_PAD_HPP\n"
  },
  {
    "path": "src/ggml-sycl/pad_reflect_1d.cpp",
    "content": "#include \"pad_reflect_1d.hpp\"\n\nstatic void pad_reflect_1d_kernel_f32(\n    const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,\n    const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,\n    const int64_t ne03, const int64_t nb00, const int64_t nb01,\n    const int64_t nb02, const int64_t nb03, const int64_t nb0,\n    const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,\n    const int p1, sycl::nd_item<3> item_ct1) {\n\n    const int64_t i3 = item_ct1.get_group(0);\n    const int64_t i2 = item_ct1.get_group(1);\n\n    const sycl::uint2 div_mod_packed =\n        fast_div_modulo(item_ct1.get_group(2), ne01);\n    const int64_t tile1 = div_mod_packed.y();\n    const int64_t tile0 = div_mod_packed.x();\n    const int64_t i1 = tile1;\n    const int64_t i0 =\n        item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);\n\n    if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {\n        return;\n    }\n\n    const char *src0_ptr =\n        (const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;\n    char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;\n\n    const int64_t rel_i0 = i0 - p0; // relative i0 in src0\n    int64_t src_idx;\n\n    if (rel_i0 < 0) {\n        // Left padding - reflect\n        src_idx = -rel_i0;\n    } else if (rel_i0 < ne00) {\n        // Middle - copy\n        src_idx = rel_i0;\n    } else {\n        // Right padding - reflect\n        src_idx = 2 * ne00 - 2 - rel_i0;\n    }\n    const float value = *(const float *)(src0_ptr + src_idx * nb00);\n    *(float *)(dst_ptr + i0 * nb0) = value;\n\n    GGML_UNUSED(p1);\n}\n\nvoid ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,\n                                 ggml_tensor *dst) {\n\n    const ggml_tensor *src0 = dst->src[0];\n    dpct::queue_ptr stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    const int32_t *opts = 
(const int32_t *)dst->op_params;\n    const int p0 = opts[0];\n    const int p1 = opts[1];\n\n    const int64_t ne00 = src0->ne[0];\n    const int64_t ne01 = src0->ne[1];\n    const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);\n    const int64_t ne02 = src0->ne[2];\n    const int64_t ne03 = src0->ne[3];\n\n    const int64_t ne0 = dst->ne[0];\n\n    GGML_ASSERT(ne0 == ne00 + p0 + p1);\n\n    constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;\n    const int64_t tiles0 = (ne0 + bx - 1) / bx;\n    const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,\n                               (unsigned)ne03);\n    const dpct::dim3 block_dims((unsigned)bx, 1, 1);\n\n    stream->submit([&](sycl::handler &cgh) {\n        auto src0_data_ct0 = src0->data;\n        auto dst_data_ct1 = dst->data;\n        auto src0_nb_ct7 = src0->nb[0];\n        auto src0_nb_ct8 = src0->nb[1];\n        auto src0_nb_ct9 = src0->nb[2];\n        auto src0_nb_ct10 = src0->nb[3];\n        auto dst_nb_ct11 = dst->nb[0];\n        auto dst_nb_ct12 = dst->nb[1];\n        auto dst_nb_ct13 = dst->nb[2];\n        auto dst_nb_ct14 = dst->nb[3];\n\n        cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             pad_reflect_1d_kernel_f32(\n                                 src0_data_ct0, dst_data_ct1, ne0, ne00,\n                                 ne01_packed, ne02, ne03, src0_nb_ct7,\n                                 src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,\n                                 dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,\n                                 dst_nb_ct14, p0, p1, item_ct1);\n                         });\n    });\n}\n"
  },
  {
    "path": "src/ggml-sycl/pad_reflect_1d.hpp",
    "content": "#ifndef GGML_SYCL_PAD_REFLECT_1D_HPP\n#define GGML_SYCL_PAD_REFLECT_1D_HPP\n\n#include \"common.hpp\"\n\n#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256\n\nvoid ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);\n\n#endif // GGML_SYCL_PAD_REFLECT_1D_HPP\n"
  },
  {
    "path": "src/ggml-sycl/presets.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_PRESETS_HPP\n#define GGML_SYCL_PRESETS_HPP\n\n#define GGML_SYCL_MAX_STREAMS       8\n#define GGML_SYCL_MAX_BUFFERS       256\n\n#define WARP_SIZE GGML_SYCL_WARP_SIZE\n#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses\n\n#define SYCL_GELU_BLOCK_SIZE 256\n#define SYCL_SILU_BLOCK_SIZE 256\n#define SYCL_TANH_BLOCK_SIZE 256\n#define SYCL_RELU_BLOCK_SIZE 256\n#define SYCL_HARDSIGMOID_BLOCK_SIZE 256\n#define SYCL_HARDSWISH_BLOCK_SIZE 256\n#define SYCL_EXP_BLOCK_SIZE 256\n#define SYCL_NEG_BLOCK_SIZE 256\n#define SYCL_SIGMOID_BLOCK_SIZE 256\n#define SYCL_SQRT_BLOCK_SIZE 256\n#define SYCL_SIN_BLOCK_SIZE 256\n#define SYCL_SQR_BLOCK_SIZE 256\n#define SYCL_SET_BLOCK_SIZE 256\n#define SYCL_CPY_BLOCK_SIZE 32\n#define SYCL_SCALE_BLOCK_SIZE 256\n#define SYCL_CLAMP_BLOCK_SIZE 256\n#define SYCL_ROPE_BLOCK_SIZE 256\n#define SYCL_ALIBI_BLOCK_SIZE 32\n#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32\n#define SYCL_QUANTIZE_BLOCK_SIZE 256\n#define SYCL_DEQUANTIZE_BLOCK_SIZE 256\n#define SYCL_GET_ROWS_BLOCK_SIZE 256\n#define SYCL_UPSCALE_BLOCK_SIZE 256\n#define SYCL_CONCAT_BLOCK_SIZE 256\n#define SYCL_PAD_BLOCK_SIZE 256\n#define SYCL_ACC_BLOCK_SIZE 256\n#define SYCL_IM2COL_BLOCK_SIZE 256\n#define SYCL_POOL2D_BLOCK_SIZE 256\n#define SYCL_ARGMAX_BLOCK_SIZE 256\n#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256\n#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256\n#define SYCL_ARANGE_BLOCK_SIZE 256\n\n// dmmv = dequantize_mul_mat_vec\n#ifndef GGML_SYCL_DMMV_X\n#define GGML_SYCL_DMMV_X 32\n#endif\n#ifndef GGML_SYCL_MMV_Y\n#define GGML_SYCL_MMV_Y 1\n#endif\n\n#ifndef K_QUANTS_PER_ITERATION\n#define 
K_QUANTS_PER_ITERATION 2\n#else\nstatic_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, \"K_QUANTS_PER_ITERATION must be 1 or 2\");\n#endif\n\n#ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE\n#define GGML_SYCL_PEER_MAX_BATCH_SIZE 128\n#endif // GGML_SYCL_PEER_MAX_BATCH_SIZE\n\n#define MUL_MAT_SRC1_COL_STRIDE 128\n\n#define QK_WARP_SIZE 32\n#define WARP_32_SIZE 32\n#define WARP_16_SIZE 16\n\n#endif // GGML_SYCL_PRESETS_HPP\n"
  },
  {
    "path": "src/ggml-sycl/quantize.hpp",
    "content": "/***************************************************************************\n *\n *  Copyright (C) 2025 Codeplay Software Ltd.\n *  Copyright (C) 2025 Intel Corporation\n *\n *  MIT License\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n *  quantize.hpp\n *\n *  Description:\n *     Sycl backend specific quantization functions\n **************************************************************************/\n\n#pragma once\n\n#include <sycl/nd_item.hpp>\n\n#include \"ggml-sycl/dpct/helper.hpp\"\n\ntemplate <int ElementsPerWI>\n__dpct_inline__ static void quantize_q8_1_impl(const float * __restrict__ x,\n                                               sycl::vec<int8_t, ElementsPerWI> & quantized_values, float & d,\n                                               float & sum, const sycl::nd_item<1> & it) {\n    auto subgroup_id = it.get_group(0);\n    auto wi_id       = it.get_local_id(0);\n\n    sycl::vec<float, ElementsPerWI> wi_f32_vals;\n\n    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;\n    wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);\n\n    float amax = 0.0f;\n\n#pragma unroll(ElementsPerWI)\n    for (int i = 0; i < ElementsPerWI; i++) {\n        sum += wi_f32_vals[i];\n        amax                = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));\n        quantized_values[i] = 0;\n    }\n    sum  = sycl::reduce_over_group(it.get_sub_group(), sum, sycl::plus<float>());\n    amax = sycl::reduce_over_group(it.get_sub_group(), amax, sycl::maximum<float>());\n    d    = amax == 0 ? 
1 : amax / 127;\n\n#pragma unroll(ElementsPerWI)\n    for (int i = 0; i < ElementsPerWI; i++) {\n        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);\n    }\n\n    d = amax == 0 ? 0 : d;\n}\n\n// No op to control codepath in ggml_sycl_op_mul_mat\ntemplate <int ElementsPerWI> struct no_quantize_q8_1 {\n    void operator()(const float *, void *, int, int, const sycl::nd_item<1> &) const {}\n};\n\ntemplate <int ElementsPerWI> struct quantize_and_reorder_q8_1_soa {\n    __dpct_inline__ void operator()(const float * __restrict__ x, void * reordered_q8_tensor, const int kx,\n                                    const int kx_padded, const sycl::nd_item<1> & it) const {\n        /*\n        Quantizes and reorders the resultant q8 tensor in a per row fashion\n        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values\n    */\n        auto subgroup_id = it.get_group(0);\n        auto wi_id       = it.get_local_id(0);\n\n        sycl::vec<int8_t, ElementsPerWI> quantized_values;\n        float                            d   = 0.0f;\n        float                            sum = 0.0f;\n        quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);\n\n        const int num_blocks_per_row = kx / QK8_1;\n        auto      row                = subgroup_id / num_blocks_per_row;\n        auto      col                = subgroup_id % num_blocks_per_row;\n        auto      row_offset         = row * (kx_padded / QK8_1) * sizeof(block_q8_1);\n        auto      col_offset         = QK8_1 * col + wi_id * ElementsPerWI;\n\n        auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);\n        *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;\n\n        auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));\n        if (wi_id == 0) {\n            *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));\n        
}\n    }\n};\n\ntemplate <int ElementsPerWI> struct quantize_q8_1 {\n    __dpct_inline__ void operator()(const float * __restrict__ x, void * q8_tensor, const int kx, const int kx_padded,\n                                    const sycl::nd_item<1> & it) const {\n        auto subgroup_id = it.get_group(0);\n        auto wi_id       = it.get_local_id(0);\n\n        const int num_blocks_per_row = kx / QK8_1;\n        auto      row                = subgroup_id / num_blocks_per_row;\n        const int pitch              = kx_padded / QK8_1;\n\n        sycl::vec<int8_t, ElementsPerWI> quantized_values;\n        float                            d   = 0.0f;\n        float                            sum = 0.0f;\n        quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);\n\n        block_q8_1 * quant_ptr = (block_q8_1 *) q8_tensor;\n        auto         block_id  = subgroup_id % num_blocks_per_row + row * pitch;\n\n        int8_t * qs                                               = &(quant_ptr[block_id].qs[wi_id * ElementsPerWI]);\n        *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(qs) = quantized_values;\n        if (wi_id == 0) {\n            quant_ptr[block_id].ds = sycl::half2(sycl::half(d), sycl::half(sum));\n        }\n    }\n};\n\ntemplate <template <int> typename quantize_f>\nvoid quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,\n                            dpct::queue_ptr stream) {\n    static_assert(QK8_1 % WARP_SIZE == 0);\n    auto local_range      = std::size_t(WARP_SIZE);\n    auto num_quant_blocks = ky * (kx / QK8_1);\n    auto global_range     = num_quant_blocks * local_range;\n    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });\n\n    stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),\n                         [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                             quantize_f<QK8_1 / 
WARP_SIZE>()(x, vy, kx, kx_padded, it);\n                         });\n}\n"
  },
  {
    "path": "src/ggml-sycl/quants.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2025 Codeplay Software Ltd.\n// Copyright (C) 2025 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_QUANTS_HPP\n#define GGML_SYCL_QUANTS_HPP\n\n#include <utility>\n\n#include \"ggml-common.h\"\n#include \"ggml.h\"\n\nnamespace ggml_sycl_reordered {\n\n// The reordered block moves quants (qs) and  scales(d) to two\n// uniform regions of memory that is contiguous in the same tensor.\n// What this means is that instead of having:\n// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN]\n// We have:\n// [qs0, qs1, qs2, ..., qsN]  [d0, d1, d2, ..., dN]\n//\n// Notes: out-of-bounds qs will run into d values\n// Alignment relies on the allocated size of qs\n\ntemplate <ggml_type type> struct block_q_t;\n\n// qk number of weights / quants in a block\n// qr number of weights in a byte (described as 'before dequantization')\n//    for quantization types that has low and high bits split, qr is calculated with\n//    using the lower bits, e.g for Q6 quants QR6 is 2\n// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field)\n// See ggml-common.h to see how these are calculated\ntemplate <> struct block_q_t<GGML_TYPE_Q4_0> {\n    struct traits {\n        static constexpr uint32_t qk       = QK4_0;\n        static constexpr uint32_t qi       = QI4_0;\n        static constexpr uint32_t qr       = QR4_0;\n        static constexpr uint32_t vdr_mmvq = 2;\n    };\n\n    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {\n        return { block_index * (QK4_0 / QR4_0), 0 };\n    }\n\n    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {\n        return { (ncols / QR4_0 * 
nrows) + block_index * sizeof(ggml_half), 0 };\n    }\n\n    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }\n};\n\ntemplate <> struct block_q_t<GGML_TYPE_Q4_K> {\n    struct traits {\n        static constexpr uint32_t qk       = QK_K;\n        static constexpr uint32_t qi       = QI4_K;\n        static constexpr uint32_t qr       = QR4_K;\n        static constexpr uint32_t vdr_mmvq = 2;\n    };\n\n    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {\n        return { block_index * (traits::qk / traits::qr), 0 };\n    }\n\n    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {\n        auto nblocks = (nrows * (ncols / QK_K));\n        return { nblocks * (QK_K / 2) + (block_index * K_SCALE_SIZE),\n                 (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };\n    }\n\n    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }\n};\n\ntemplate <> struct block_q_t<GGML_TYPE_Q6_K> {\n    struct traits {\n        static constexpr uint32_t qk       = QK_K;\n        static constexpr uint32_t qi       = QI6_K;\n        static constexpr uint32_t qr       = QR6_K;\n        static constexpr uint32_t vdr_mmvq = 1;\n    };\n\n    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {\n        auto low_bits_index  = block_index * (QK_K / QR6_K);\n        // the index of high bits it's after all low bits\n        auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));\n        return { low_bits_index, high_bits_index };\n    }\n\n    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {\n        auto nblocks        = (nrows * (ncols / QK_K));\n        auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);\n        auto block_scales   = total_qs_bytes + block_index * (QK_K / 
16);\n        auto sb_scale       = total_qs_bytes + nblocks * (QK_K / 16) + block_index * sizeof(ggml_half);\n        return { block_scales, sb_scale };\n    }\n\n    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }\n};\n\n}  // namespace ggml_sycl_reordered\n\n#endif  // GGML_SYCL_QUANTS_HPP\n"
  },
  {
    "path": "src/ggml-sycl/repeat_back.cpp",
    "content": "#include \"repeat_back.hpp\"\n\n#include \"common.hpp\"\n\n#define GGML_ASSERT_TENSOR_FITS_INT(t) \\\n    GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)\n\nvoid ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    const float * src0_dd = (const float *) dst->src[0]->data;\n    float *       dst_dd  = (float *) dst->data;\n\n    GGML_ASSERT_TENSOR_FITS_INT(dst);\n    GGML_ASSERT_TENSOR_FITS_INT(dst->src[0]);\n\n    const int ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];\n    const int ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],\n              ne03 = dst->src[0]->ne[3];\n\n    const int nr0 = ne00 / ne0;\n    const int nr1 = ne01 / ne1;\n    const int nr2 = ne02 / ne2;\n    const int nr3 = ne03 / ne3;\n\n    const int nb0 = dst->src[0]->nb[0];\n    const int nb1 = dst->src[0]->nb[1];\n    const int nb2 = dst->src[0]->nb[2];\n    const int nb3 = dst->src[0]->nb[3];\n\n    const char * base = (const char *) src0_dd;\n\n    const size_t  total      = (size_t) ne0 * ne1 * ne2 * ne3;\n    constexpr int BLOCK_SIZE = 256;\n    const int     num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;\n\n    const float inv_ne0      = 1.0f / ne0;\n    const float inv_ne_01    = 1.0f / (ne0 * ne1);\n    const float inv_ne_012   = 1.0f / (ne0 * ne1 * ne2);\n    const int   repeat_count = nr0 * nr1 * nr2 * nr3;\n\n    queue_ptr stream = ctx.stream();\n\n    stream->parallel_for(\n        sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),\n        [=](sycl::nd_item<1> item_ct1) {\n            const size_t i = item_ct1.get_global_linear_id();\n            if (i >= total) {\n                return;\n            }\n\n            const int i3 = (int) (i * inv_ne_012);\n            const 
int i2 = (int) (i * inv_ne_01) - i3 * ne2;\n            const int i1 = (int) (i * inv_ne0) - (int) (i * inv_ne_01) * ne1;\n            const int i0 = i - (int) (i * inv_ne0) * ne0;\n\n            int   j0 = 0, j1 = 0, j2 = 0, j3 = 0;\n            float acc = 0.0f;\n\n            for (int j = 0; j < repeat_count; ++j) {\n                const float * ptr = (const float *) (base + (i0 + j0 * ne0) * nb0 + (i1 + j1 * ne1) * nb1 +\n                    (i2 + j2 * ne2) * nb2 + (i3 + j3 * ne3) * nb3);\n                acc += *ptr;\n\n                int carry = (++j0 >= nr0);\n                j0 -= carry * nr0;\n                carry = (carry && (++j1 >= nr1));\n                j1 -= carry * nr1;\n                carry = (carry && (++j2 >= nr2));\n                j2 -= carry * nr2;\n                j3 += carry;\n            }\n            dst_dd[i] = acc;\n        });\n}\n"
  },
  {
    "path": "src/ggml-sycl/repeat_back.hpp",
    "content": "#ifndef GGML_SYCL_REPEAT_BACK_HPP\n#define GGML_SYCL_REPEAT_BACK_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif  // GGML_SYCL_REPEAT_BACK_HPP\n"
  },
  {
    "path": "src/ggml-sycl/roll.cpp",
    "content": "#include \"roll.hpp\"\n#include \"common.hpp\"\n\nusing namespace sycl;\n\nstatic inline int wrap_add(int i, int shift, int n) {\n\n    int s = i + shift;\n    return (s >= n) ? (s - n) : s;\n}\n\nstatic void kernel_roll_fused_i0_i1(\n    queue &q,\n    const float *src_d,\n    float *dst_d,\n    int ne0, int ne1, int ne2, int ne3,\n    int sh0, int sh1, int sh2, int sh3)\n{\n    if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;\n\n\n    const int stride1 = ne0;\n    const int stride2 = ne0 * ne1;\n    const int stride3 = ne0 * ne1 * ne2;\n\n\n    const int shNe0 = (ne0 - sh0) % ne0;\n    const int shNe1 = (ne1 - sh1) % ne1;\n    const int shNe2 = (ne2 - sh2) % ne2;\n    const int shNe3 = (ne3 - sh3) % ne3;\n\n\n    const size_t g0 = (size_t) ne3;\n    const size_t g1 = (size_t) ne2;\n    const size_t g2 = (size_t) (ne1 * ne0);\n\n    const range<3> global{ g0, g1, g2 };\n\n    q.submit([&](handler &h) {\n        h.parallel_for(global, [=](id<3> idx) {\n            const int i3 = (int) idx[0];\n            const int i2 = (int) idx[1];\n\n            const int fused = (int) idx[2];\n            const int i1 = fused / ne0;\n            const int i0 = fused - i1 * ne0;  // fused % ne0\n\n\n            const int idx_dst = i0\n                              + i1 * stride1\n                              + i2 * stride2\n                              + i3 * stride3;\n\n\n            const int s0 = wrap_add(i0, shNe0, ne0);\n            const int s1 = wrap_add(i1, shNe1, ne1);\n            const int s2 = wrap_add(i2, shNe2, ne2);\n            const int s3 = wrap_add(i3, shNe3, ne3);\n\n            const int idx_src = s0\n                              + s1 * stride1\n                              + s2 * stride2\n                              + s3 * stride3;\n\n            dst_d[idx_dst] = src_d[idx_src];\n        });\n    });\n}\n\nvoid ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n 
   const ggml_tensor *src = dst->src[0];\n    GGML_ASSERT(src && src->type == GGML_TYPE_F32);\n\n    const int ne0 = (int) dst->ne[0];\n    const int ne1 = (int) dst->ne[1];\n    const int ne2 = (int) dst->ne[2];\n    const int ne3 = (int) dst->ne[3];\n\n    const int32_t *params = (const int32_t *) dst->op_params;\n    int shift0 = params[0];\n    int shift1 = params[1];\n    int shift2 = params[2];\n    int shift3 = params[3];\n\n\n    if ((shift0 | shift1 | shift2 | shift3) == 0) {\n        const size_t nb = ggml_nbytes(src);\n        queue *q = ctx.stream();\n        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));\n        return;\n    }\n\n    auto norm = [](int sh, int n) -> int {\n        if (n <= 0) return 0;\n        sh %= n;\n        if (sh < 0) sh += n;\n        return sh;\n    };\n    shift0 = norm(shift0, ne0);\n    shift1 = norm(shift1, ne1);\n    shift2 = norm(shift2, ne2);\n    shift3 = norm(shift3, ne3);\n\n    try {\n        queue *q = ctx.stream();\n\n        const float *src_d = (const float *) src->data;\n        float *dst_d = (float *) dst->data;\n        GGML_ASSERT(src_d && dst_d);\n\n        kernel_roll_fused_i0_i1(\n            *q, src_d, dst_d,\n            ne0, ne1, ne2, ne3,\n            shift0, shift1, shift2, shift3\n        );\n    } catch (const std::exception &e) {\n        std::fprintf(stderr, \"[SYCL-ROLL] ERROR: %s\\n\", e.what());\n        throw;\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/roll.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_ROLL_HPP\n#define GGML_SYCL_ROLL_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);\n\n#endif // GGML_SYCL_ROLL_HPP\n"
  },
  {
    "path": "src/ggml-sycl/rope.cpp",
    "content": "#include \"rope.hpp\"\n#include \"convert.hpp\"\n#include \"ggml-sycl/common.hpp\"\n#include \"ggml.h\"\n\nstruct rope_corr_dims {\n    float v[2];\n};\n\nstruct mrope_sections {\n    int v[4];\n};\n\nstatic float rope_yarn_ramp(const float low, const float high, const int i0) {\n    const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);\n    return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));\n}\n\ntemplate <bool forward>\nstatic void rope_yarn(const float theta_extrap, const float freq_scale,\n                      const rope_corr_dims corr_dims, const int64_t i0,\n                      const float ext_factor, float mscale, float &cos_theta,\n                      float &sin_theta) {\n    float theta_interp = freq_scale * theta_extrap;\n    float theta = theta_interp;\n    if (ext_factor != 0.0f) {\n        float ramp_mix =\n            rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;\n        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;\n\n        mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);\n    }\n    cos_theta = sycl::cos(theta) * mscale;\n    sin_theta = sycl::sin(theta) * mscale;\n    if (!forward) {\n        sin_theta *= -1.0f;\n    }\n}\n\ntemplate <bool forward, bool has_ff, typename T, typename D>\nstatic void rope_norm(const T *x, D *dst, const int ne00, const int ne01,\n                      const int ne02, const int s01, const int s02,\n                      const int s03, const int s1, const int s2, const int s3,\n                      const int n_dims, const int32_t *pos,\n                      const float freq_scale, const float ext_factor,\n                      const float attn_factor, const rope_corr_dims corr_dims,\n                      const float theta_scale, const float *freq_factors,\n                      const int64_t *row_indices, const int set_rows_stride) {\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int i0 = 2 * 
(item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                        item_ct1.get_local_id(1));\n\n    if (i0 >= ne00) {\n        return;\n    }\n\n    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                        item_ct1.get_local_id(2);\n\n    const uint32_t i3 = row_dst / (ne01 * ne02);\n    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;\n    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;\n\n    int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;\n    const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;\n\n    if (set_rows_stride != 0) {\n        idst = i1 * s1 + i0;\n        idst += row_indices[i2] * set_rows_stride;\n    }\n\n    const auto &store_coaelsced = [&](float x0, float x1) {\n        if constexpr (std::is_same_v<float, D>) {\n            sycl::float2 v = sycl::float2(x0, x1);\n            ggml_sycl_memcpy_1<8>(dst + idst, &v);\n        } else if constexpr (std::is_same_v<sycl::half, D>) {\n            sycl::half2 v = sycl::half2(x0, x1);\n            ggml_sycl_memcpy_1<4>(dst + idst, &v);\n        }\n    };\n    if (i0 >= n_dims) {\n        store_coaelsced(x[ix + 0], x[ix + 1]);\n        return;\n    }\n\n    const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);\n\n    const float freq_factor = has_ff ? 
freq_factors[i0 / 2] : 1.0f;\n\n    float cos_theta;\n    float sin_theta;\n\n    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,\n                       ext_factor, attn_factor, cos_theta, sin_theta);\n\n    const float x0 = x[ix + 0];\n    const float x1 = x[ix + 1];\n\n    store_coaelsced(x0 * cos_theta - x1 * sin_theta,\n                    x0 * sin_theta + x1 * cos_theta);\n}\n\ntemplate <bool forward, bool has_ff, typename T, typename D>\nstatic void rope_neox(const T *x, D *dst, const int ne00, const int ne01,\n                      const int ne02, const int s01, const int s02,\n                      const int s03, const int s1, const int s2, const int s3,\n                      const int n_dims, const int32_t *pos,\n                      const float freq_scale, const float ext_factor,\n                      const float attn_factor, const rope_corr_dims corr_dims,\n                      const float theta_scale, const float *freq_factors,\n                      const int64_t *row_indices, const int set_rows_stride) {\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                        item_ct1.get_local_id(1));\n\n    if (i0 >= ne00) {\n        return;\n    }\n\n    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                        item_ct1.get_local_id(2);\n\n    const uint32_t i3 = row_dst / (ne01 * ne02);\n    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;\n    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;\n\n    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;\n    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;\n\n    if (set_rows_stride != 0) {\n        idst = i1 * s1 + i0 / 2;\n        idst += row_indices[i2] * set_rows_stride;\n    }\n\n    if (i0 >= n_dims) {\n        dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);\n        dst[idst + i0 / 2 
+ 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);\n\n        return;\n    }\n\n    const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);\n\n    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;\n\n    float cos_theta;\n    float sin_theta;\n\n    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,\n                       ext_factor, attn_factor, cos_theta, sin_theta);\n\n    const float x0 = x[ix + 0];\n    const float x1 = x[ix + n_dims / 2];\n\n    dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);\n    dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);\n}\n\ntemplate <bool forward, bool has_ff, typename T>\nstatic void rope_multi(const T *x, T *dst, const int ne00, const int ne01,\n                       const int ne02, const int s01, const int s02,\n                       const int s03, const int s1, const int s2, const int s3,\n                       const int n_dims, const int32_t *pos,\n                       const float freq_scale, const float ext_factor,\n                       const float attn_factor, const rope_corr_dims corr_dims,\n                       const float theta_scale, const float *freq_factors,\n                       const mrope_sections sections, const bool is_imrope) {\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                        item_ct1.get_local_id(1));\n\n    if (i0 >= ne00) {\n        return;\n    }\n\n    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                        item_ct1.get_local_id(2);\n\n    const uint32_t i3 = row_dst / (ne01 * ne02);\n    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;\n    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;\n\n    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;\n    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;\n\n    if (i0 
>= n_dims) {\n        dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];\n        dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];\n\n        return;\n    }\n\n    const int sect_dims =\n        sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];\n    const int sec_w = sections.v[1] + sections.v[0];\n    const int sector = (i0 / 2) % sect_dims;\n\n    float theta_base = 0.0;\n    if (is_imrope) {\n        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h\n            theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);\n        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w\n            theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);\n        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t\n            theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);\n        } else {\n            theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);\n        }\n    } else {\n        if (sector < sections.v[0]) {\n            theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);\n        } else if (sector >= sections.v[0] && sector < sec_w) {\n            theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);\n        } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {\n            theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);\n        } else if (sector >= sec_w + sections.v[2]) {\n            theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);\n        }\n    }\n\n    const float freq_factor = has_ff ? 
freq_factors[i0 / 2] : 1.0f;\n\n    float cos_theta;\n    float sin_theta;\n\n    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,\n                       ext_factor, attn_factor, cos_theta, sin_theta);\n\n    const float x0 = x[ix + 0];\n    const float x1 = x[ix + n_dims / 2];\n\n    dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;\n    dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;\n}\n\ntemplate <bool forward, bool has_ff, typename T>\nstatic void rope_vision(const T *x, T *dst, const int ne00, const int ne01,\n                        const int ne02, const int s01, const int s02,\n                        const int s03, const int s1, const int s2, const int s3,\n                        const int n_dims, const int32_t *pos,\n                        const float freq_scale, const float ext_factor,\n                        const float attn_factor, const rope_corr_dims corr_dims,\n                        const float theta_scale, const float *freq_factors,\n                        const mrope_sections sections) {\n    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +\n                        item_ct1.get_local_id(1));\n\n    if (i0 >= ne00) {\n        return;\n    }\n\n    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +\n                        item_ct1.get_local_id(2);\n\n    const uint32_t i3 = row_dst / (ne01 * ne02);\n    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;\n    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;\n\n    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;\n    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;\n\n    const int sect_dims = sections.v[0] + sections.v[1];\n    const int sec_w = sections.v[1] + sections.v[0];\n    const int sector = (i0 / 2) % sect_dims;\n\n    float theta_base = 0.0;\n    if (sector < sections.v[0]) {\n        const int p = 
sector;\n        theta_base = pos[i2] * dpct::pow(theta_scale, p);\n    } else if (sector >= sections.v[0] && sector < sec_w) {\n        const int p = sector - sections.v[0];\n        theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);\n    }\n\n    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;\n\n    float cos_theta;\n    float sin_theta;\n\n    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,\n                       ext_factor, attn_factor, cos_theta, sin_theta);\n\n    const float x0 = x[ix + 0];\n    const float x1 = x[ix + n_dims];\n\n    dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;\n    dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;\n}\n\ntemplate <bool forward, typename T, typename D>\nstatic void\nrope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,\n               const int ne02, const int s01, const int s02, const int s03,\n               const int s1, const int s2, const int s3, const int n_dims,\n               const int nr, const int32_t *pos, const float freq_scale,\n               const float freq_base, const float ext_factor,\n               const float attn_factor, const rope_corr_dims corr_dims,\n               const float *freq_factors, const int64_t *row_indices,\n               const int set_rows_stride, dpct::queue_ptr stream) {\n    GGML_ASSERT(ne00 % 2 == 0);\n    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);\n    const int n_blocks_x =\n        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);\n    const dpct::dim3 block_nums(nr, n_blocks_x, 1);\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    if (freq_factors == nullptr) {\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                GGML_UNUSED(item_ct1);\n                rope_norm<forward, false>(\n                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, 
s3, n_dims,\n                    pos, freq_scale, ext_factor, attn_factor, corr_dims,\n                    theta_scale, freq_factors, row_indices, set_rows_stride);\n            });\n    } else {\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                GGML_UNUSED(item_ct1);\n                rope_norm<forward, true>(\n                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,\n                    pos, freq_scale, ext_factor, attn_factor, corr_dims,\n                    theta_scale, freq_factors, row_indices, set_rows_stride);\n            });\n    }\n}\n\ntemplate <bool forward, typename T, typename D>\nstatic void\nrope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,\n               const int ne02, const int s01, const int s02, const int s03,\n               const int s1, const int s2, const int s3, const int n_dims,\n               const int nr, const int32_t *pos, const float freq_scale,\n               const float freq_base, const float ext_factor,\n               const float attn_factor, const rope_corr_dims corr_dims,\n               const float *freq_factors, const int64_t *row_indices,\n               const int set_rows_stride, dpct::queue_ptr stream) {\n    GGML_ASSERT(ne00 % 2 == 0);\n    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);\n    const int n_blocks_x =\n        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);\n    const dpct::dim3 block_nums(nr, n_blocks_x, 1);\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    if (freq_factors == nullptr) {\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                GGML_UNUSED(item_ct1);\n                rope_neox<forward, false>(\n                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,\n        
            pos, freq_scale, ext_factor, attn_factor, corr_dims,\n                    theta_scale, freq_factors, row_indices, set_rows_stride);\n            });\n    } else {\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                GGML_UNUSED(item_ct1);\n                rope_neox<forward, true>(\n                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,\n                    pos, freq_scale, ext_factor, attn_factor, corr_dims,\n                    theta_scale, freq_factors, row_indices, set_rows_stride);\n            });\n    }\n}\n\ntemplate <bool forward, typename T>\nstatic void\nrope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,\n                const int ne02, const int s01, const int s02, const int s03,\n                const int s1, const int s2, const int s3, const int n_dims,\n                const int nr, const int32_t *pos, const float freq_scale,\n                const float freq_base, const float ext_factor,\n                const float attn_factor, const rope_corr_dims corr_dims,\n                const float *freq_factors, const mrope_sections sections,\n                const bool is_imrope, dpct::queue_ptr stream) {\n    GGML_ASSERT(ne00 % 2 == 0);\n    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);\n    const int n_blocks_x =\n        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);\n    const dpct::dim3 block_nums(nr, n_blocks_x, 1);\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    if (freq_factors == nullptr) {\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                GGML_UNUSED(item_ct1);\n                rope_multi<forward, false, T>(\n                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,\n                    pos, 
freq_scale, ext_factor, attn_factor, corr_dims,\n                    theta_scale, freq_factors, sections, is_imrope);\n            });\n    } else {\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                GGML_UNUSED(item_ct1);\n                rope_multi<forward, true, T>(\n                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,\n                    pos, freq_scale, ext_factor, attn_factor, corr_dims,\n                    theta_scale, freq_factors, sections, is_imrope);\n            });\n    }\n}\n\ntemplate <bool forward, typename T>\nstatic void\nrope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,\n                 const int ne02, const int s01, const int s02, const int s03,\n                 const int s1, const int s2, const int s3, const int n_dims,\n                 const int nr, const int32_t *pos, const float freq_scale,\n                 const float freq_base, const float ext_factor,\n                 const float attn_factor, const rope_corr_dims corr_dims,\n                 const float *freq_factors, const mrope_sections sections,\n                 dpct::queue_ptr stream) {\n    GGML_ASSERT(ne00 % 2 == 0);\n    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);\n    const int n_blocks_x =\n        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);\n    const dpct::dim3 block_nums(nr, n_blocks_x, 1);\n\n    const float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    if (freq_factors == nullptr) {\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                GGML_UNUSED(item_ct1);\n                rope_vision<forward, false, T>(\n                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,\n                    pos, freq_scale, ext_factor, attn_factor, corr_dims,\n 
                   theta_scale, freq_factors, sections);\n            });\n    } else {\n        stream->parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1) {\n                GGML_UNUSED(item_ct1);\n                rope_vision<forward, true, T>(\n                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,\n                    pos, freq_scale, ext_factor, attn_factor, corr_dims,\n                    theta_scale, freq_factors, sections);\n            });\n    }\n}\n\ntemplate <bool forward>\nvoid ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,\n                            const ggml_tensor *set_rows = nullptr) {\n    const ggml_tensor *src0 = dst->src[0];\n    const ggml_tensor *src1 = dst->src[1];\n    const ggml_tensor *src2 = dst->src[2];\n\n    const float *src0_d = (const float *)src0->data;\n    const float *src1_d = (const float *)src1->data;\n\n    void *dst_d = dst->data;\n    const int64_t *row_indices = nullptr;\n    ggml_type dst_type = dst->type;\n    int set_rows_stride = 0;\n\n    if (set_rows != nullptr) {\n        GGML_ASSERT(forward);\n        dst_d = set_rows->data;\n        row_indices = (const int64_t *)set_rows->src[1]->data;\n        dst_type = set_rows->type;\n        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);\n    }\n    dpct::queue_ptr stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);\n    GGML_ASSERT(src0->type == dst->type ||\n                (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));\n\n    const int64_t ne00 = src0->ne[0]; // head dims\n    const int64_t ne01 = src0->ne[1]; // num heads\n    const int64_t ne02 = src0->ne[2]; // num positions (pos[] is indexed by i2)\n    const int64_t nr = ggml_nrows(src0);\n\n    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);\n  
  const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);\n    const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);\n\n    const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);\n    const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);\n    const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);\n\n    const int n_dims = ((int32_t *)dst->op_params)[1];\n    const int mode = ((int32_t *)dst->op_params)[2];\n    const int n_ctx_orig = ((int32_t *)dst->op_params)[4];\n    mrope_sections sections;\n\n    float freq_base;\n    float freq_scale;\n    float ext_factor;\n    float attn_factor;\n    float beta_fast;\n    float beta_slow;\n\n    memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));\n    memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));\n    memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));\n    memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));\n    memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));\n    memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));\n    memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);\n\n    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;\n    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;\n    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;\n    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n\n    if (is_mrope) {\n        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||\n                    sections.v[2] > 0);\n    }\n\n    if (is_vision) {\n        GGML_ASSERT(n_dims == ne00 / 2);\n    }\n\n    const int32_t *pos = (const int32_t *)src1_d;\n\n    const float *freq_factors = nullptr;\n    if (src2 != nullptr) {\n        freq_factors = (const float *)src2->data;\n    }\n\n    rope_corr_dims corr_dims;\n    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,\n                             beta_slow, corr_dims.v);\n\n    // compute\n    if (is_neox) {\n        
GGML_SYCL_DEBUG(\"%s: neox path\\n\", __func__);\n        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {\n            rope_neox_sycl<forward, float, float>(\n                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,\n                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,\n                set_rows_stride, stream);\n        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {\n            rope_neox_sycl<forward, float, sycl::half>(\n                (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,\n                s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,\n                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,\n                row_indices, set_rows_stride, stream);\n        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {\n            rope_neox_sycl<forward, sycl::half, sycl::half>(\n                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,\n                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,\n                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,\n                row_indices, set_rows_stride, stream);\n        } else {\n            GGML_ABORT(\"Fatal error: Tensor type unsupported!\");\n        }\n    } else if (is_mrope && !is_vision) {\n        GGML_SYCL_DEBUG(\"%s: mrope path\\n\", __func__);\n        if (src0->type == GGML_TYPE_F32) {\n            rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,\n                                     ne00, ne01, ne02, s01, s02, s03, s1, s2,\n                                     s3, n_dims, nr, pos, freq_scale, freq_base,\n                                     ext_factor, attn_factor, corr_dims,\n                                     freq_factors, sections, is_imrope, stream);\n        } else if (src0->type == GGML_TYPE_F16) {\n 
           rope_multi_sycl<forward>(\n                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,\n                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,\n                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,\n                sections, is_imrope, stream);\n        } else {\n            GGML_ABORT(\"Fatal error: Tensor type unsupported!\");\n        }\n    } else if (is_vision) {\n        GGML_SYCL_DEBUG(\"%s: vision path\\n\", __func__);\n        if (src0->type == GGML_TYPE_F32) {\n            rope_vision_sycl<forward>(\n                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,\n                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                ext_factor, attn_factor, corr_dims, freq_factors, sections,\n                stream);\n        } else if (src0->type == GGML_TYPE_F16) {\n            rope_vision_sycl<forward>(\n                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,\n                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,\n                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,\n                sections, stream);\n        } else {\n            GGML_ABORT(\"Fatal error: Tensor type unsupported!\");\n        }\n    } else {\n        GGML_SYCL_DEBUG(\"%s: norm path\\n\", __func__);\n        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {\n            rope_norm_sycl<forward, float, float>(\n                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,\n                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,\n                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,\n                set_rows_stride, stream);\n        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {\n            rope_norm_sycl<forward, float, sycl::half>(\n                (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, 
ne02,\n                s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,\n                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,\n                row_indices, set_rows_stride, stream);\n        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {\n            rope_norm_sycl<forward, sycl::half, sycl::half>(\n                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,\n                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,\n                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,\n                row_indices, set_rows_stride, stream);\n        } else {\n            GGML_ABORT(\"Fatal error: Tensor type unsupported!\");\n        }\n    }\n}\n\nvoid ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);\n\n    ggml_sycl_op_rope_impl<true>(ctx, dst);\n}\n\nvoid ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);\n    ggml_sycl_op_rope_impl<false>(ctx, dst);\n}\n\nvoid ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,\n                          ggml_tensor *set_rows) {\n    scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);\n    ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);\n}\n"
  },
  {
    "path": "src/ggml-sycl/rope.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_ROPE_HPP\n#define GGML_SYCL_ROPE_HPP\n\n#include \"common.hpp\"\n\n#define SYCL_ROPE_BLOCK_SIZE 256\n\nvoid ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);\n\nvoid ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);\n\n#endif // GGML_SYCL_ROPE_HPP\n"
  },
  {
    "path": "src/ggml-sycl/set.cpp",
    "content": "#include \"presets.hpp\"\n#include \"common.hpp\"\n#include \"ggml.h\"\n#include \"set.hpp\"\n#include <cstdint>\n#include <sycl/sycl.hpp>\nusing namespace sycl;\n\n// Internal function: perform element-wise set operation for each thread\ninline void set_f32(const float* src, float* dst,\n                    const int64_t ne0, const int64_t ne1,\n                    const int64_t ne2, const int64_t ne3,\n                    const int64_t nb[3], const int64_t src_nb[3],\n                    const int64_t offset_elem,\n                    const nd_item<1>& item)\n{\n    const size_t idx = item.get_global_id(0);\n    const size_t total = ne0 * ne1 * ne2 * ne3;\n    if (idx >= total) return;\n\n    // Convert linear index to 4D indices\n    const size_t i3 = idx / (ne2 * ne1 * ne0);\n    const size_t rem = idx % (ne2 * ne1 * ne0);\n    const size_t i2 = rem / (ne1 * ne0);\n    const size_t rem2 = rem % (ne1 * ne0);\n    const size_t i1 = rem2 / ne0;\n    const size_t i0 = rem2 % ne0;\n\n    // Compute source and destination indices and copy\n    dst[i0 + i1*nb[0] + i2*nb[1] + i3*nb[2] + offset_elem] =\n        src[i0 + i1*src_nb[0] + i2*src_nb[1] + i3*src_nb[2]];\n}\n\n// Main function: prepare GPU queue and launch parallel_for\nvoid ggml_sycl_op_set(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {\n    const ggml_tensor* src0 = dst->src[0];\n    const ggml_tensor* src1 = dst->src[1];\n\n    // Ensure shapes and types are compatible\n    GGML_ASSERT(ggml_are_same_shape(src0, dst));\n    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));\n    GGML_ASSERT(dst->type == src0->type && src0->type == src1->type && dst->type == GGML_TYPE_F32);\n\n    const int32_t* opts = (const int32_t*) dst->op_params;\n    const int64_t nb[3]     = {opts[0]/sizeof(float), opts[1]/sizeof(float), opts[2]/sizeof(float)};\n    const int64_t offset_elem = opts[3] / sizeof(float);\n    const bool inplace = opts[4];\n\n    float* dst_ptr = (float*) dst->data;\n    
const float* src0_ptr = (const float*) src0->data;\n    const float* src1_ptr = (const float*) src1->data;\n\n    queue_ptr stream = ctx.stream();\n\n    // Copy src0 to dst if not inplace\n    if (!inplace)\n        stream->memcpy(dst_ptr, src0_ptr, ggml_nbytes(dst));\n\n    const int64_t ne[4] = {src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]};\n    const int64_t src_nb[3] = {src1->nb[1]/sizeof(float), src1->nb[2]/sizeof(float), src1->nb[3]/sizeof(float)};\n\n    const size_t total_threads = ne[0]*ne[1]*ne[2]*ne[3];\n    const size_t grid_size = ((total_threads + SYCL_SET_BLOCK_SIZE - 1) / SYCL_SET_BLOCK_SIZE) * SYCL_SET_BLOCK_SIZE;\n\n    // Copy every src1 element into the dst view addressed by nb[] and offset_elem\n    stream->parallel_for(\n        nd_range<1>(range<1>(grid_size), range<1>(SYCL_SET_BLOCK_SIZE)),\n        [=](nd_item<1> item) {\n            set_f32(src1_ptr, dst_ptr,\n                ne[0], ne[1], ne[2], ne[3],\n                nb, src_nb, offset_elem, item); }\n    );\n}\n"
  },
  {
    "path": "src/ggml-sycl/set.hpp",
    "content": "#pragma once\n#include \"backend.hpp\"\n#include \"ggml.h\"\n\nvoid ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-sycl/set_rows.cpp",
    "content": "#include \"set_rows.hpp\"\n#include \"cpy.hpp\"\n\nnamespace utils {\ntemplate<typename T>\nstatic constexpr bool is_arithmetic_v() {\n    return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;\n}\n}\n\ntemplate<typename TIn, typename TOut>\nstatic inline std::enable_if_t<utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void>\nconvert (const char* src, char* dst) {\n    auto src_val = *reinterpret_cast<const TIn*>(src);\n    auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut, sycl::rounding_mode::automatic>()[0];\n   *reinterpret_cast<TOut*>(dst) = dst_val;\n}\n\ntemplate <typename TIdx, typename blockType, int qk, cpy_kernel_t cpyblck>\nstatic void set_rows_sycl_q(const char * __restrict__ src0_d,\n                            const TIdx * __restrict__ src1_d,\n                            blockType * __restrict__ dst_d,\n                            // tensor dimensions src0 and src1\n                            const int64_t ne00,\n                            const int64_t ne01,\n                            const int64_t ne02,\n                            const int64_t ne03,\n                            const int64_t ne10,\n                            const int64_t ne11,\n                            const int64_t ne12,\n                            const int64_t ne13,\n                            // strides for src0\n                            const size_t  nb00,\n                            const size_t  nb01,\n                            const size_t  nb02,\n                            const size_t  nb03,\n                            // strides for src1\n                            const size_t  nb10,\n                            const size_t  nb11,\n                            const size_t  nb12,\n                            const size_t  nb13,\n                            // strides for dst\n                            const size_t  nb1,\n                    
        const size_t  nb2,\n                            const size_t  nb3,\n                            queue_ptr     stream) {\n    const int64_t total_blocks = (ne00 * ne01 * ne02 * ne03) / qk;\n    constexpr int block_size   = 256;\n    const int64_t grid_size    = ceil_div(total_blocks, block_size);\n\n    stream->parallel_for(sycl::nd_range<1>(grid_size * block_size, block_size), [=](sycl::nd_item<1> item_ct1) {\n        const int64_t i = item_ct1.get_global_linear_id();\n        if (i >= total_blocks) {\n            return;\n        }\n        const int64_t i_base      = i * qk;\n        const int64_t i03         = i_base / (ne00 * ne01 * ne02);\n        const int64_t rem1        = i_base - i03 * (ne00 * ne01 * ne02);\n        const int64_t i02         = rem1 / (ne00 * ne01);\n        const int64_t rem2        = rem1 - i02 * ne00 * ne01;\n        const int64_t i01         = rem2 / ne00;\n        const int64_t i00         = rem2 - i01 * ne00;\n        const int64_t i12         = i03 % ne12;\n        const int64_t i11         = i02 % ne11;\n        const int64_t i10         = i01;\n        const size_t  src_offset  = calculate_offset<3>({ nb01, nb02, nb03 }, { i01, i02, i03 });\n        const char *  src_block   = src0_d + src_offset + i00 * sizeof(float);\n        const size_t  src1_offset = calculate_offset<3>({ nb10, nb11, nb12 }, { i10, i11, i12 });\n        const int64_t dst_row     = src1_d[src1_offset / sizeof(TIdx)];\n        const size_t  dst_offset =\n            calculate_offset<3>({ nb1, nb2, nb3 }, { dst_row, i02, i03 }) + (i00 / qk) * sizeof(blockType);\n        char * dst_block = reinterpret_cast<char *>(reinterpret_cast<char *>(dst_d) + dst_offset);\n        cpyblck(src_block, dst_block);\n    });\n    GGML_UNUSED(ne10);\n    GGML_UNUSED(ne13);\n    GGML_UNUSED(nb00);\n    GGML_UNUSED(nb13);\n}\n\ntemplate<typename TIn, typename TIdx, typename TOut>\nstatic void k_set_rows(\n        const char * __restrict__ src0, const TIdx * __restrict__ src1, 
char * __restrict__ dst,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02,\n        const int64_t ne11, const int64_t ne12,\n        const size_t nb01, const size_t nb02, const size_t nb03,\n        const size_t nb10, const size_t nb11, const size_t nb12,\n        const size_t nb1, const size_t nb2, const size_t nb3,\n        const size_t src_type_size, const size_t dst_type_size,\n        const int64_t total_elements,\n        const sycl::nd_item<1> & item_ct1) {\n\n    const int64_t i = item_ct1.get_global_linear_id();\n    if (i >= total_elements) {\n        return;\n    }\n\n    const int64_t i03 = i / (ne00 * ne01 * ne02);\n    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);\n    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;\n    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;\n\n    const int64_t i12 = i03 % ne12;\n    const int64_t i11 = i02 % ne11;\n    const int64_t i10 = i01;\n\n    const int64_t dst_row = *(const TIdx *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));\n\n    const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});\n    const char * src_elem = src0_row + i00 * src_type_size;\n    char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;\n    char * dst_elem = dst_row_ptr + i00 * dst_type_size;\n\n    convert<TIn, TOut>(src_elem, dst_elem);\n}\n\ntemplate<typename TIn, typename TIdx, typename TOut>\nstatic void set_rows_sycl(\n        const char * src0_d, const TIdx * src1_d, char * dst_d,\n        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,\n        const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03,\n        const size_t nb10, const size_t nb11, const size_t nb12,\n        const size_t nb1, const size_t nb2, const size_t nb3,\n        const size_t src_type_size, const size_t 
dst_type_size,\n        queue_ptr stream) {\n\n    const int64_t total_elements = ne00 * ne01 * ne02 * ne03;\n\n    constexpr int block_size = 64;\n    const int64_t grid_size = ceil_div(total_elements, block_size);\n\n    stream->parallel_for(\n        sycl::nd_range<1>(grid_size * block_size, block_size),\n        [=](sycl::nd_item<1> item_ct1) {\n            k_set_rows<TIn, TIdx, TOut>(\n                src0_d, src1_d, dst_d,\n                ne00, ne01, ne02,\n                ne11, ne12,\n                nb01, nb02, nb03,\n                nb10, nb11, nb12,\n                nb1, nb2, nb3,\n                src_type_size, dst_type_size,\n                total_elements,\n                item_ct1\n            );\n        }\n    );\n}\n\ntemplate<typename TIn, typename TIdx>\nstatic void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const char * src0_d = (const char *)src0->data;\n    const TIdx * src1_d = (const TIdx *)src1->data;\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    dpct::queue_ptr stream = ctx.stream();\n    switch (dst->type) {\n        case GGML_TYPE_F32:\n            set_rows_sycl<TIn, TIdx, float>(\n                src0_d, src1_d, (char *)dst->data,\n                ne00, ne01, ne02, ne03,\n                ne11, ne12,\n                nb01, nb02, nb03,\n                nb10, nb11, nb12,\n                nb1, nb2, nb3,\n                sizeof(TIn), sizeof(float),\n                stream\n            );\n            break;\n        case GGML_TYPE_F16:\n            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });\n            set_rows_sycl<TIn, TIdx, sycl::half>(\n                src0_d, src1_d, (char *)dst->data,\n                ne00, ne01, ne02, ne03,\n                ne11, ne12,\n                nb01, nb02, nb03,\n                nb10, nb11, nb12,\n                nb1, nb2, nb3,\n                sizeof(TIn), sizeof(sycl::half),\n                
stream\n            );\n            break;\n        case GGML_TYPE_BF16:\n            set_rows_sycl<TIn, TIdx, sycl::ext::oneapi::bfloat16>(\n                src0_d, src1_d, (char *)dst->data,\n                ne00, ne01, ne02, ne03,\n                ne11, ne12,\n                nb01, nb02, nb03,\n                nb10, nb11, nb12,\n                nb1, nb2, nb3,\n                sizeof(TIn), sizeof(sycl::ext::oneapi::bfloat16),\n                stream\n            );\n            break;\n        case GGML_TYPE_Q8_0:\n            set_rows_sycl_q<TIdx, block_q8_0, QK8_0, cpy_blck_f32_q8_0>(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q5_1:\n            set_rows_sycl_q<TIdx, block_q5_1, QK5_1, cpy_blck_f32_q5_1>(src0_d, src1_d, (block_q5_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q5_0:\n            set_rows_sycl_q<TIdx, block_q5_0, QK5_0, cpy_blck_f32_q5_0>(src0_d, src1_d, (block_q5_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q4_1:\n            set_rows_sycl_q<TIdx, block_q4_1, QK4_1, cpy_blck_f32_q4_1>(src0_d, src1_d, (block_q4_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_Q4_0:\n            set_rows_sycl_q<TIdx, block_q4_0, QK4_0, cpy_blck_f32_q4_0>(src0_d, src1_d, (block_q4_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);\n            break;\n        case GGML_TYPE_IQ4_NL:\n            set_rows_sycl_q<TIdx, block_iq4_nl, 
QK4_NL, cpy_blck_f32_iq4_nl>(src0_d, src1_d, (block_iq4_nl *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);\n            break;\n\n        default:\n            GGML_ABORT(\"Unsupported tensor type!\");\n            break;\n    }\n}\n\nvoid ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64 || dst->src[1]->type == GGML_TYPE_I32);\n\n    if (src1->type == GGML_TYPE_I64) {\n        set_rows_sycl<float, int64_t>(ctx, src0, src1, dst);\n    } else {\n        set_rows_sycl<float, int32_t>(ctx, src0, src1, dst);\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/set_rows.hpp",
    "content": "#ifndef GGML_SYCL_SET_ROWS_HPP\n#define GGML_SYCL_SET_ROWS_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif // GGML_SYCL_SET_ROWS_HPP\n"
  },
  {
    "path": "src/ggml-sycl/softmax.cpp",
    "content": "#include \"softmax.hpp\"\n#include <cstdint>\n#include <utility>\n#include <cmath>\n\n\ntemplate <typename T> static __dpct_inline__ float t2f32(T val) {\n    return (float) val;\n}\n\ntemplate <> float __dpct_inline__ t2f32<sycl::half>(sycl::half val) {\n  return sycl::vec<sycl::half, 1>(val)\n      .convert<float, sycl::rounding_mode::automatic>()[0];\n}\n\nstruct soft_max_params {\n\n    int64_t nheads;\n    uint32_t n_head_log2;\n    int64_t ncols;\n    int64_t nrows_x;\n    int64_t nrows_y;\n    int64_t ne00;\n    int64_t ne01;\n    int64_t ne02;\n    int64_t ne03;\n    int64_t nb11;\n    int64_t nb12;\n    int64_t nb13;\n\n    int64_t ne12;\n    int64_t ne13;\n    float scale;\n    float max_bias;\n    float m0;\n    float m1;\n};\n\n// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.\n// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.\n#ifdef __clang__\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wpass-failed\"\n#endif // __clang__\ntemplate <bool use_shared, int ncols_template, int block_size_template, typename T>\nstatic void soft_max_f32(const float *         x,\n                         const T *             mask,\n                         const float *         sinks,\n                         float *               dst,\n                         const soft_max_params p,\n                         uint8_t *             dpct_local) {\n    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int ncols    = ncols_template == 0 ? p.ncols : ncols_template;\n    const int block_size = block_size_template == 0\n                               ? 
item_ct1.get_local_range(2)\n                               : block_size_template;\n    const int nthreads = block_size;\n    const int nwarps = nthreads / WARP_SIZE;\n    size_t nreduce = nwarps / WARP_SIZE;\n\n    const int tid = item_ct1.get_local_id(2);\n\n    const int64_t i03 = item_ct1.get_group(0);\n    const int64_t i02 = item_ct1.get_group(1);\n    const int64_t i01 = item_ct1.get_group(2);\n\n    //TODO: noncontiguous inputs/outputs\n    const int rowx = item_ct1.get_group(2) +\n                     item_ct1.get_group(1) * item_ct1.get_group_range(2) +\n                     item_ct1.get_group(0) * item_ct1.get_group_range(2) *\n                         item_ct1.get_group_range(1);\n\n    const int64_t i11 = i01;\n    const int64_t i12 = i02 % p.ne12;\n    const int64_t i13 = i03 % p.ne13;\n\n    x    += int64_t(rowx)*ncols;\n    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);\n    dst  += int64_t(rowx)*ncols;\n\n    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;\n    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;\n\n    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);\n\n    float * buf_iw = (float *) dpct_local;\n\n    // shared memory buffer to cache values between iterations:\n    float *vals = use_shared ? buf_iw + sycl::max(nwarps, WARP_SIZE) : dst;\n    float max_val = sinks ? sinks[i02] : -INFINITY;\n#pragma unroll\n    for (int col0 = 0; col0 < ncols; col0 += block_size) {\n        const int col = col0 + tid;\n\n        if (ncols_template == 0 && col >= ncols) {\n            break;\n        }\n\n        const float val = x[col]*p.scale + (mask ? 
slope*t2f32(mask[col]) : 0.0f);\n\n        vals[col] = val;\n        max_val   = sycl::max(max_val, val);\n    }\n    // find the max value in the block\n    max_val = warp_reduce_max<WARP_SIZE>(max_val);\n\n    if (block_size > WARP_SIZE) {\n        if (warp_id == 0) {\n            buf_iw[lane_id] = -INFINITY;\n        }\n        item_ct1.barrier();\n\n        if (lane_id == 0) {\n            buf_iw[warp_id] = max_val;\n        }\n        item_ct1.barrier();\n\n        max_val = buf_iw[lane_id];\n        max_val = warp_reduce_max<WARP_SIZE>(max_val);\n    }\n    float tmp = 0.0f; // partial sum\n\n#pragma unroll\n    for (int col0 = 0; col0 < ncols; col0 += block_size) {\n        const int col = col0 + tid;\n\n        if (ncols_template == 0 && col >= ncols) {\n            break;\n        }\n\n        const float val = sycl::native::exp(vals[col] - max_val);\n        tmp += val;\n        vals[col] = val;\n    }\n    // find the sum of exps in the block\n    tmp = warp_reduce_sum<WARP_SIZE>(tmp);\n    if (block_size > WARP_SIZE) {\n        item_ct1.barrier();\n        if (warp_id == 0) {\n            buf_iw[lane_id] = 0.0f;\n            for (size_t i = 1; i < nreduce; i += 1) {\n                buf_iw[lane_id + i * WARP_SIZE] = 0.f;\n            }\n        }\n        item_ct1.barrier();\n\n        if (lane_id == 0) {\n            buf_iw[warp_id] = tmp;\n        }\n        item_ct1.barrier();\n\n        tmp = buf_iw[lane_id];\n        for (size_t i = 1; i < nreduce; i += 1) {\n            tmp += buf_iw[lane_id + i * WARP_SIZE];\n        }\n        tmp = warp_reduce_sum<WARP_SIZE>(tmp);\n    }\n    if (sinks) {\n        tmp += sycl::native::exp(sinks[i02] - max_val);\n    }\n    const float inv_sum = 1.0f / tmp;\n\n#pragma unroll\n    for (int col0 = 0; col0 < ncols; col0 += block_size) {\n        const int col = col0 + tid;\n\n        if (ncols_template == 0 && col >= ncols) {\n            return;\n        }\n\n        dst[col] = vals[col] * inv_sum;\n    
}\n}\n#ifdef __clang__\n#pragma clang diagnostic pop\n#endif // __clang__\n\nstatic void soft_max_back_f32(const float *grad, const float *dstf, float *dst,\n                              const int ncols, const float scale) {\n    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();\n    const int tid      = item_ct1.get_local_id(2);\n    const int rowx     = item_ct1.get_group(2);\n\n    grad += int64_t(rowx)*ncols;\n    dstf += int64_t(rowx)*ncols;\n    dst  += int64_t(rowx)*ncols;\n\n    float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients\n\n    for (int col = tid; col < ncols; col += WARP_SIZE) {\n        dgf_dot += dstf[col]*grad[col];\n    }\n\n    dgf_dot = warp_reduce_sum<WARP_SIZE>(dgf_dot);\n\n    for (int col = tid; col < ncols; col += WARP_SIZE) {\n        dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];\n    }\n}\n\ntemplate <int... Ns, typename T>\nstatic void launch_soft_max_kernels(const float *           x,\n                                    const T *               mask,\n                                    const float *           sinks,\n                                    float *                 dst,\n                                    const soft_max_params & p,\n                                    dpct::queue_ptr         stream,\n                                    dpct::dim3              block_dims,\n                                    dpct::dim3              block_nums,\n                                    size_t                  nbytes_shared)\n{\n    auto launch_kernel = [=](auto I) -> bool {\n        constexpr int ncols = decltype(I)::value;\n        constexpr int block = (ncols > 1024 ? 
1024 : ncols);\n        if (p.ncols == ncols) {\n            stream->submit([&](sycl::handler &cgh) {\n                sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(\n                    sycl::range<1>(nbytes_shared), cgh);\n\n                cgh.parallel_for(\n                    sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                    [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(\n                        WARP_SIZE)]] {\n                        soft_max_f32<true, ncols, block>(\n                            x, mask, sinks, dst, p,\n                            dpct_local_acc_ct1\n                                .get_multi_ptr<sycl::access::decorated::no>()\n                                .get());\n                        GGML_UNUSED(item_ct1);\n                    });\n            });\n            return true;\n        }\n        return false;\n    };\n\n    // unary fold over launch_kernel\n    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {\n        return;\n    }\n\n    stream->submit([&](sycl::handler &cgh) {\n        sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(\n            sycl::range<1>(nbytes_shared), cgh);\n\n        cgh.parallel_for(\n            sycl::nd_range<3>(block_nums * block_dims, block_dims),\n            [=](sycl::nd_item<3> item_ct1)\n                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {\n                    soft_max_f32<true, 0, 0>(\n                        x, mask, sinks, dst, p,\n                        dpct_local_acc_ct1\n                            .get_multi_ptr<sycl::access::decorated::no>()\n                            .get());\n                    GGML_UNUSED(item_ct1);\n                });\n    });\n}\n\ntemplate <typename T>\nstatic void soft_max_f32_sycl(const float *x, const T *mask,\n                              const float *sinks, float *dst,\n                              const soft_max_params &params,\n                              dpct::queue_ptr stream, 
int device) {\n    int nth = WARP_SIZE;\n    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];\n    const int64_t ncols_x = params.ncols;\n\n    while (nth < ncols_x && nth < max_block_size) nth *= 2;\n    if (nth>max_block_size) nth = max_block_size;\n\n    const dpct::dim3 block_dims(nth, 1, 1);\n    const dpct::dim3 block_nums(params.ne01, params.ne02, params.ne03);\n    const size_t nbytes_shared =\n        (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof(float);\n\n    const int id       = get_current_device_id();\n    const size_t smpbo = ggml_sycl_info().devices[id].smpbo;\n\n    if (nbytes_shared <= smpbo && ncols_x <= max_block_size) {\n        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(\n            x, mask, sinks, dst, params, stream, block_dims, block_nums,\n            nbytes_shared);\n    } else {\n        const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);\n\n        stream->submit([&](sycl::handler &cgh) {\n            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(\n                sycl::range<1>(nbytes_shared_low), cgh);\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1) {\n                    soft_max_f32<false, 0, 0>(\n                        x, mask, sinks, dst, params,\n                        dpct_local_acc_ct1\n                            .get_multi_ptr<sycl::access::decorated::no>()\n                            .get());\n                    GGML_UNUSED(item_ct1);\n                });\n        });\n    }\n}\n\nstatic void soft_max_back_f32_sycl(const float *   grad,\n                                   const float *   dstf,\n                                   float *         dst,\n                                   const int       ncols,\n                                   const int       nrows,\n                                   const float     scale,\n                                 
  dpct::queue_ptr stream) {\n    const dpct::dim3 block_dims(WARP_SIZE, 1, 1);\n    const dpct::dim3 block_nums(nrows, 1, 1);\n\n    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),\n                         [=](sycl::nd_item<3> item_ct1) {\n                             soft_max_back_f32(grad, dstf, dst, ncols, scale);\n                             GGML_UNUSED(item_ct1);\n                         });\n}\n\nvoid ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n\n    const float * src0_d = (const float *) src0->data;\n    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;\n    const void  * src2_d = src2 ? (const void *) src2->data : nullptr;\n    float       *  dst_d = (float *) dst->data;\n\n    dpct::queue_ptr stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    // src1 contains mask and it is optional\n    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);\n\n    const int64_t nrows_x = ggml_nrows(src0);\n    const int64_t nrows_y = src0->ne[1];\n\n    const int64_t ne00 = src0->ne[0];\n\n    float scale    = 1.0f;\n    float max_bias = 0.0f;\n\n    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));\n\n    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);\n\n    const int64_t nb11 = src1 ? src1->nb[1] : 1;\n    const int64_t nb12 = src1 ? src1->nb[2] : 1;\n    const int64_t nb13 = src1 ? src1->nb[3] : 1;\n\n    const int64_t ne12 = src1 ? src1->ne[2] : 1;\n    const int64_t ne13 = src1 ? 
src1->ne[3] : 1;\n\n    const uint32_t n_head      = src0->ne[2];\n    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n\n    soft_max_params params = {};\n    params.nheads = src0->ne[2];\n    params.n_head_log2 = n_head_log2;\n    params.ncols = ne00;\n    params.nrows_x = nrows_x;\n    params.nrows_y = nrows_y;\n    params.ne00 = src0->ne[0];\n    params.ne01 = src0->ne[1];\n    params.ne02 = src0->ne[2];\n    params.ne03 = src0->ne[3];\n    params.nb11 = nb11;\n    params.nb12 = nb12;\n    params.nb13 = nb13;\n    params.ne12 = ne12;\n    params.ne13 = ne13;\n    params.scale = scale;\n    params.max_bias = max_bias;\n    params.m0 = m0;\n    params.m1 = m1;\n\n    if (use_f16) {\n        soft_max_f32_sycl(src0_d, (const sycl::half *)src1_d,\n                          (const float *)src2_d, dst_d, params, stream,\n                          ctx.device);\n    } else {\n        soft_max_f32_sycl(src0_d, (const float *)src1_d, (const float *)src2_d,\n                          dst_d, params, stream, ctx.device);\n    }\n}\n\nvoid ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);\n    const ggml_tensor * src0 = dst->src[0]; // grad\n    const ggml_tensor * src1 = dst->src[1]; // forward pass output\n\n    const float * src0_d = (const float *) src0->data;\n    const float * src1_d = (const float *) src1->data;\n    float       * dst_d  = (float       *) dst->data;\n\n    dpct::queue_ptr stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    const int64_t ncols = src0->ne[0];\n    const int64_t nrows = ggml_nrows(src0);\n\n    float scale    = 1.0f;\n    float max_bias = 0.0f;\n\n    
memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));\n\n    GGML_ASSERT(max_bias == 0.0f);\n\n    soft_max_back_f32_sycl(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);\n}\n"
  },
  {
    "path": "src/ggml-sycl/softmax.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_SOFTMAX_HPP\n#define GGML_SYCL_SOFTMAX_HPP\n\n#include \"common.hpp\"\n\n#define SYCL_SOFT_MAX_BLOCK_SIZE 1024\n\nvoid ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);\n\nvoid ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif // GGML_SYCL_SOFTMAX_HPP\n"
  },
  {
    "path": "src/ggml-sycl/ssm_conv.cpp",
    "content": "#include \"ssm_conv.hpp\"\n#include \"common.hpp\"\n\n#include <cstdio>\n\nusing namespace sycl;\n\nstatic void kernel_ssm_conv(\n    queue &q,\n    const float *src_data,\n    const float *weights,\n    float *dst_data,\n    int d_conv,\n    int d_inner,\n    int n_t,\n    int n_s,\n    int ncs __attribute__((unused)),\n    int src_stride_inner,\n    int src_stride_seq,\n    int dst_stride_token,\n    int dst_stride_seq\n) {\n    const size_t total_work = static_cast<size_t>(d_inner) * static_cast<size_t>(n_t) * static_cast<size_t>(n_s);\n    const size_t work_group_size = 256;\n    const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;\n\n    const range<1> global_range(num_work_groups * work_group_size);\n    const range<1> local_range(work_group_size);\n\n    q.submit([&](handler &h) {\n        h.parallel_for(\n            nd_range<1>(global_range, local_range),\n            [=](nd_item<1> item) {\n                const size_t idx = item.get_global_id(0);\n                if (idx >= total_work) {\n                    return;\n                }\n\n                const int channel = static_cast<int>(idx % d_inner);\n                const int token   = static_cast<int>((idx / d_inner) % n_t);\n                const int seq     = static_cast<int>(idx / (static_cast<size_t>(d_inner) * static_cast<size_t>(n_t)));\n\n                const float *s = src_data\n                    + static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq)\n                    + static_cast<size_t>(channel) * static_cast<size_t>(src_stride_inner)\n                    + static_cast<size_t>(token);\n\n                const float *c = weights + static_cast<size_t>(channel) * static_cast<size_t>(d_conv);\n\n                float sumf = 0.0f;\n                for (int i0 = 0; i0 < d_conv; ++i0) {\n                    sumf += s[i0] * c[i0];\n                }\n\n                const size_t dst_idx =\n                    
static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) +\n                    static_cast<size_t>(token) * static_cast<size_t>(dst_stride_token) +\n                    static_cast<size_t>(channel);\n\n                dst_data[dst_idx] = sumf;\n            }\n        );\n    });\n}\n\nvoid ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type  == GGML_TYPE_F32);\n\n    const int d_conv   = src1->ne[0];\n    const int ncs      = src0->ne[0];\n    const int d_inner  = src0->ne[1];\n    const int n_t      = dst->ne[1];\n    const int n_s      = dst->ne[2];\n\n    GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);\n    GGML_ASSERT(src0->ne[1] == d_inner);\n    GGML_ASSERT(src1->ne[1] == d_inner);\n\n    GGML_ASSERT(dst->ne[0] == d_inner);\n    GGML_ASSERT(dst->ne[1] == n_t);\n    GGML_ASSERT(dst->ne[2] == n_s);\n\n    GGML_ASSERT(src0->nb[0] == sizeof(float));\n    GGML_ASSERT(src1->nb[0] == sizeof(float));\n\n    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));\n\n    const int src_stride_inner = ncs;\n    const int src_stride_seq   = ncs * d_inner;\n    const int dst_stride_token = d_inner;\n    const int dst_stride_seq   = d_inner * n_t;\n\n    try {\n        queue *q = ctx.stream();\n\n        const float *src_data = static_cast<const float *>(src0->data);\n        const float *weights  = static_cast<const float *>(src1->data);\n        float *dst_data       = static_cast<float *>(dst->data);\n\n        GGML_ASSERT(src_data && weights && dst_data);\n\n        kernel_ssm_conv(\n            *q,\n            src_data,\n            weights,\n            dst_data,\n            d_conv,\n            d_inner,\n            n_t,\n            n_s,\n            ncs,\n            src_stride_inner,\n            src_stride_seq,\n            
dst_stride_token,\n            dst_stride_seq\n        );\n\n    } catch (const std::exception &e) {\n        std::fprintf(stderr, \"[SYCL-SSM_CONV] ERROR: %s\\n\", e.what());\n        throw;\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/ssm_conv.hpp",
    "content": "#pragma once\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n"
  },
  {
    "path": "src/ggml-sycl/sycl_hw.cpp",
    "content": "#include \"sycl_hw.hpp\"\n\n// TODO: currently not used\n/*\nsycl_hw_info get_device_hw_info(sycl::device *device_ptr) {\n  sycl_hw_info res;\n  int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();\n  res.device_id = id;\n\n  syclex::architecture arch = device_ptr->get_info<syclex::info::device::architecture>();\n  res.arch = arch;\n\n  return res;\n}\n*/\n"
  },
  {
    "path": "src/ggml-sycl/sycl_hw.hpp",
    "content": "#ifndef SYCL_HW_HPP\n#define SYCL_HW_HPP\n\n#include <algorithm>\n#include <stdio.h>\n#include <vector>\n#include <map>\n\n#include <sycl/sycl.hpp>\n\nnamespace syclex = sycl::ext::oneapi::experimental;\n\n// TODO: currently not used\n/*\nstruct sycl_hw_info {\n  syclex::architecture arch;\n  int32_t device_id;\n};\n\nbool is_in_vector(std::vector<int> &vec, int item);\n\nsycl_hw_info get_device_hw_info(sycl::device *device_ptr);\n*/\n\n\n#endif // SYCL_HW_HPP\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(112, 112);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(128, 128);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(256, 256);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(40, 40);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(576, 512);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(64, 64);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(72, 72);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(80, 80);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-tile.hpp\"\n\nDECL_FATTN_TILE_CASE(96, 96);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);\n"
  },
  {
    "path": "src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp",
    "content": "// This file has been autogenerated by generate_cu_files.py, do not edit manually.\n\n#include \"../fattn-vec.hpp\"\n\nDECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);\nDECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);\n"
  },
  {
    "path": "src/ggml-sycl/tsembd.cpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#include \"tsembd.hpp\"\n\nstatic void timestep_embedding_f32(\n        const float * timesteps, float * dst, const int nb1,\n        const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {\n    // item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]\n    // item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE\n    int i = item_ct1.get_group(1);\n    int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);\n    float * embed_data = (float *)((char *)dst +  i*nb1);\n\n    int half = dim / 2;\n\n    if (dim % 2 != 0 && j == half) {\n        embed_data[2 * half] = 0.f;\n    }\n\n    if (j >= half) {\n        return;\n    }\n\n    float timestep = timesteps[i];\n    float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);\n    float arg = timestep * freq;\n    embed_data[j] = sycl::cos(arg);\n    embed_data[j + half] = sycl::sin(arg);\n}\n\nstatic void timestep_embedding_f32_sycl(\n        const float * x, float * dst, const int ne00, const int nb1,\n        const int dim, const int max_period, const queue_ptr& stream) {\n    // As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad\n    int half_ceil = dim / 2;\n    int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;\n    sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);\n    sycl::range<3> gridDim(1, ne00, num_blocks);\n    stream->parallel_for(\n        sycl::nd_range<3>(\n            gridDim * block_dims, block_dims),\n        [=](sycl::nd_item<3> item_ct1) {\n            
timestep_embedding_f32(\n                x, dst, nb1, dim, max_period, item_ct1\n            );\n        });\n}\n\nvoid ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);\n    const ggml_tensor *  src0   = dst->src[0];\n    const float * src0_d = (const float *)src0->data;\n    float * dst_d = (float *)dst->data;\n    dpct::queue_ptr stream = ctx.stream();\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    const int dim = dst->op_params[0];\n    const int max_period = dst->op_params[1];\n\n    timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);\n}\n"
  },
  {
    "path": "src/ggml-sycl/tsembd.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2024 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_TSEMBD_HPP\n#define GGML_SYCL_TSEMBD_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif // GGML_SYCL_TSEMBD_HPP\n"
  },
  {
    "path": "src/ggml-sycl/vecdotq.hpp",
    "content": "//\n// MIT license\n// Copyright (C) 2025 Intel Corporation\n// SPDX-License-Identifier: MIT\n//\n\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n\n#ifndef GGML_SYCL_VECDOTQ_HPP\n#define GGML_SYCL_VECDOTQ_HPP\n\n#include \"dpct/helper.hpp\"\n#include \"ggml.h\"\n#include \"quants.hpp\"\n\ntypedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,\n                                  const int & iqs);\n\nstatic __dpct_inline__ int get_int_b1(const void * x, const int & i32) {\n    const uint8_t * x8 = (const uint8_t *) x;\n\n    int x32  = x8[4*i32 + 0] <<  0;\n    x32     |= x8[4*i32 + 1] <<  8;\n    x32     |= x8[4*i32 + 2] << 16;\n    x32     |= x8[4*i32 + 3] << 24;\n\n    return x32;\n}\n\n\nstatic __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {\n  const uint16_t* x16 =\n      (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte\n                                                 // alignment\n\n  int x32 = 0;\n  x32 |= x16[0] << 0;\n  x32 |= x16[1] << 16;\n\n  return x32;\n}\n\nstatic __dpct_inline__ int get_int_from_uint8(\n    const uint8_t* x8,\n    const int& i32) {\n  const uint16_t* x16 =\n      (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte\n                                                 // alignment\n\n  int x32 = 0;\n  x32 |= x16[0] << 0;\n  x32 |= x16[1] << 16;\n\n  return x32;\n}\n\nstatic __dpct_inline__ int get_int_from_int8_aligned(\n    const int8_t* x8,\n    const int& i32) {\n  return *(\n      (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment\n}\n\nstatic __dpct_inline__ int get_int_from_uint8_aligned(\n    const uint8_t* x8,\n    const int& i32) {\n  return *(\n      (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte 
alignment\n}\n\nstatic __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,\n                                                  const uint8_t *values,\n                                                  int &val1, int &val2) {\n\n    uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;\n    aux32 = q4 & 0x0f0f0f0f;\n    uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);\n    uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);\n    val1 = v1 | (v2 << 16);\n    aux32 = (q4 >> 4) & 0x0f0f0f0f;\n    v1 = values[q8[0]] | (values[q8[1]] << 8);\n    v2 = values[q8[2]] | (values[q8[3]] << 8);\n    val2 = v1 | (v2 << 16);\n}\n\nstatic __dpct_inline__ sycl::int2 get_int_from_table_16(\n    const int& q4, const int8_t* table) {\n  const uint32_t* table32 = (const uint32_t*)table;\n  uint32_t tmp[2];\n  const uint32_t low_high_selection_indices =\n      (0x32103210 | ((q4 & 0x88888888) >> 1));\n#pragma unroll\n  for (uint32_t i = 0; i < 2; ++i) {\n    const uint32_t shift = 16 * i;\n\n    const uint32_t low =\n        dpct::byte_level_permute(table32[0], table32[1], q4 >> shift);\n    const uint32_t high =\n        dpct::byte_level_permute(table32[2], table32[3], q4 >> shift);\n    tmp[i] = dpct::byte_level_permute(\n        low, high, low_high_selection_indices >> shift);\n  }\n  return sycl::int2(\n      dpct::byte_level_permute(tmp[0], tmp[1], 0x6420),\n      dpct::byte_level_permute(tmp[0], tmp[1], 0x7531));\n}\n\n#define VDR_Q2_K_Q8_1_MMVQ 1\n\n// contiguous v/x values\nstatic __dpct_inline__ float vec_dot_q2_K_q8_1_impl_mmvq(\n    const int &v, const int *__restrict__ u, const uint8_t *__restrict__ scales,\n    const sycl::half2 &dm2, const float *__restrict__ d8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR2_K; ++i) {\n        const int sc = scales[2*i];\n\n        const int vi = (v >> (2*i)) & 0x03030303;\n\n        sumf_d +=\n            d8[i] * (dpct::dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD 
dot product\n\n        // fill int with 4x m\n        int m = sc >> 4;\n        m |= m <<  8;\n        m |= m << 16;\n        sumf_m += d8[i] *\n                  dpct::dp4a(\n                      m, u[i],\n                      0); // multiply constant q2_K part with sum of q8_1 values\n    }\n\n    const sycl::float2 dm2f =\n        dm2.convert<float, sycl::rounding_mode::automatic>();\n\n    return dm2f.x() * sumf_d - dm2f.y() * sumf_m;\n}\n\n\n#define VDR_Q3_K_Q8_1_MMVQ 1\n\n// contiguous v/x values\nstatic __dpct_inline__ float vec_dot_q3_K_q8_1_impl_mmvq(\n    const int &vl, const int &vh, const int *__restrict__ u,\n    const uint8_t *__restrict__ scales, const int &scale_offset,\n    const float &d3, const float *__restrict__ d8) {\n\n    float sumf = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR3_K; ++i) {\n        const int isc = scale_offset + 2*i;\n\n        const int isc_low = isc % (QK_K/32);\n        const int sc_shift_low = 4 * (isc / (QK_K/32));\n        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;\n\n        const int isc_high = isc % (QK_K/64);\n        const int sc_shift_high = 2 * (isc / (QK_K/64));\n        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;\n\n        const int sc = (sc_low | sc_high) - 32;\n\n        const int vil = (vl >> (2*i)) & 0x03030303;\n\n        const int vih = ((vh >> i) << 2) & 0x04040404;\n\n        const int vi =\n            dpct::vectorized_binary<sycl::char4>(vil, vih, dpct::sub_sat());\n\n        sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product\n    }\n\n    return d3 * sumf;\n}\n\n#define VDR_Q4_K_Q8_1_MMVQ 2\n\n// contiguous v/x values\nstatic __dpct_inline__ float vec_dot_q4_K_q8_1_impl_vmmq(\n    const int *__restrict__ v, const int *__restrict__ u,\n    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,\n    const sycl::half2 &dm4, const float *__restrict__ d8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 
0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR4_K; ++i) {\n        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;\n        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;\n\n        const int dot1 =\n            dpct::dp4a(v1i, u[2 * i + 1],\n                       dpct::dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product\n        const int dot2 =\n            dpct::dp4a(0x01010101, u[2 * i + 1],\n                       dpct::dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u\n\n        sumf_d += d8[i] * (dot1 * sc[i]);\n        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values\n    }\n\n    const sycl::float2 dm4f =\n        dm4.convert<float, sycl::rounding_mode::automatic>();\n\n    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;\n}\n\n\n#define VDR_Q5_K_Q8_1_MMVQ 2\n\n// contiguous v/x values\nstatic __dpct_inline__ float vec_dot_q5_K_q8_1_impl_vmmq(\n    const int *__restrict__ vl, const int *__restrict__ vh,\n    const int *__restrict__ u, const uint8_t *__restrict__ sc,\n    const uint8_t *__restrict__ m, const sycl::half2 &dm5,\n    const float *__restrict__ d8) {\n\n    float sumf_d = 0.0f;\n    float sumf_m = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR5_K; ++i) {\n        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;\n        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;\n\n        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;\n        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;\n\n        const int v0i = vl0i | vh0i;\n        const int v1i = vl1i | vh1i;\n\n        const int dot1 =\n            dpct::dp4a(v0i, u[2 * i + 0],\n                       dpct::dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product\n        const int dot2 =\n            dpct::dp4a(0x01010101, u[2 * i + 0],\n                       dpct::dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u\n\n        sumf_d += d8[i] * (dot1 * sc[i]);\n        sumf_m += d8[i] * (dot2 * m[i]);\n\n    }\n\n    const sycl::float2 dm5f =\n        
dm5.convert<float, sycl::rounding_mode::automatic>();\n\n    return dm5f.x() * sumf_d - dm5f.y() * sumf_m;\n}\n\n\n#define VDR_Q6_K_Q8_1_MMVQ 1\n\n// contiguous v/x values\nstatic __dpct_inline__ float\nvec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,\n                            const int *__restrict__ u,\n                            const int8_t *__restrict__ scales, const float &d,\n                            const float *__restrict__ d8) {\n\n    float sumf = 0.0f;\n\n#pragma unroll\n    for (int i = 0; i < QR6_K; ++i) {\n        const int sc = scales[4*i];\n\n        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;\n\n        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;\n\n        const int vi = dpct::vectorized_binary<sycl::char4>(\n            (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32\n\n        sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product\n    }\n\n    return d*sumf;\n}\n\n// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called\n// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q\n\ntemplate <ggml_type T> struct reorder_vec_dot_q_sycl {\n    static_assert(T != T, \"ggml_type for reorder vecdot not implemented\");\n};\n\ntemplate <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {\n    static constexpr ggml_type gtype = GGML_TYPE_Q4_0;\n\n    using q4_0_block  = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_0>;\n    using q4_0_traits = typename q4_0_block::traits;\n\n    __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) {\n        int sumi = 0;\n\n#pragma unroll\n        for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {\n            const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;\n            const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;\n\n            // SIMD dot product of quantized values\n            sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);\n            sumi = dpct::dp4a(vi1, u[2 * i + 1], 
sumi);\n        }\n\n        const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();\n\n        // second part effectively subtracts 8 from each quant value\n        return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());\n    }\n\n    __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,\n                                     const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,\n                                     const sycl::half2 * q8_1_ds, const int & iqs) {\n        const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;\n        const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));\n        int             v[q4_0_traits::vdr_mmvq];\n        int             u[2 * q4_0_traits::vdr_mmvq];\n\n\n#pragma unroll\n        for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {\n            v[i]         = get_int_from_uint8(bq4_0, iqs + i);\n            u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);\n            u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi);\n        }\n\n        return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds);\n    };\n};\n\nstatic inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,\n                                             const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,\n                                             const int &        iqs) {\n    int   v[2];\n    int   u[2 * QR4_K];\n    float d8[QR4_K];\n\n    v[0] = q4[0];\n    v[1] = q4[4];\n\n    uint16_t  aux[2];\n    const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;\n    if (j < 2) {\n        aux[0] = scales[j + 0] & 0x3f3f;\n        aux[1] = scales[j + 2] & 0x3f3f;\n    } else {\n        aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);\n     
   aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);\n    }\n\n    const uint8_t * sc = (const uint8_t *) aux;\n    const uint8_t * m  = sc + 2;\n\n    const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));\n\n    for (int i = 0; i < QR4_K; ++i) {\n        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;\n        d8[i]                   = bq8i->ds[0];\n\n        const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4);\n        u[2 * i + 0]   = q8[0];\n        u[2 * i + 1]   = q8[4];\n    }\n\n    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8);\n}\n\ntemplate <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {\n    static constexpr ggml_type gtype = GGML_TYPE_Q4_K;\n\n    using q4_k_block  = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;\n    using q4_k_traits = typename q4_k_block::traits;\n\n    __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,\n                                     const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,\n                                     const sycl::half2 * q8_1_ds, const int & iqs) {\n        const uint8_t *    base           = static_cast<const uint8_t *>(vbq);\n        const uint8_t *    qs             = base + ibx_offset.first;\n        const uint8_t *    scs            = base + d_offset.first;\n        const ggml_half2 * dms            = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);\n\n        const int        bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));\n        const int *      q4         = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));\n        const uint16_t * scales     = (const uint16_t *) scs;\n\n        int   v[2];\n        int   u[2 * QR4_K];\n        float d8[QR4_K];\n\n        v[0] = q4[0];\n        v[1] = q4[4];\n\n        uint16_t  aux[2];\n        const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;\n        if (j < 2) {\n            aux[0] = scales[j + 0] & 0x3f3f;\n       
     aux[1] = scales[j + 2] & 0x3f3f;\n        } else {\n            aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);\n            aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);\n        }\n\n        const uint8_t * sc = (const uint8_t *) aux;\n        const uint8_t * m  = sc + 2;\n\n        for (int i = 0; i < QR4_K; ++i) {\n            const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;\n            sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);\n\n            d8[i]                   = ds_values[0];\n\n            const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);\n            u[2 * i + 0]   = q8[0];\n            u[2 * i + 1]   = q8[4];\n        }\n\n        return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8);\n    }\n};\n\ntemplate <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {\n    static constexpr ggml_type gtype = GGML_TYPE_Q6_K;\n\n    using q6_k_block  = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;\n    using q6_k_traits = typename q6_k_block::traits;\n\n    __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,\n                                                      const int8_t * __restrict__ scales, const float d,\n                                                      const float * __restrict__ d8) {\n        float sumf = 0.0f;\n\n#pragma unroll\n        for (int i = 0; i < QR6_K; ++i) {\n            const int sc = scales[4 * i];\n\n            const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;\n\n            const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;\n\n            const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,\n                                                                dpct::sub_sat());  // vi = (vil | vih) - 32\n\n            sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc);                        // SIMD dot product\n        }\n\n        return d * 
sumf;\n    }\n\n    __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,\n                     const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,\n                     const int iqs) {\n        const uint8_t *   base   = static_cast<const uint8_t *>(vbq);\n        const uint8_t *   ql     = base + ibx_offset.first;\n        const uint8_t *   qh     = base + ibx_offset.second;\n        const int8_t *    scales = reinterpret_cast<const int8_t *>(base + d_offset.first);\n        const ggml_half * d      = (const ggml_half *) (base + d_offset.second);\n\n        const int bq8_offset   = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);\n        const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);\n        const int vh_shift     = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));\n\n        const int vl = get_int_from_uint8(ql, iqs);\n        const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;\n\n        const int8_t * scs = scales + scale_offset;\n\n        int   u[QR6_K];\n        float d8[QR6_K];\n\n#pragma unroll\n        for (int i = 0; i < QR6_K; ++i) {\n            u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);\n            const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);\n            d8[i]                       = ds_values[0];\n        }\n        return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);\n    }\n};\n#define VDR_Q4_0_Q8_1_MMVQ 2\n#define VDR_Q4_0_Q8_1_MMQ  4\n\ntemplate <int vdr>\nstatic __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4,\n                                                    const sycl::half2 & ds8) {\n    int sumi = 0;\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;\n        
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;\n\n        // SIMD dot product of quantized values\n        sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);\n        sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);\n    }\n\n    const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();\n\n    // second part effectively subtracts 8 from each quant value\n    return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());\n}\n\n#define VDR_Q4_1_Q8_1_MMVQ 2\n#define VDR_Q4_1_Q8_1_MMQ  4\n\ntemplate <int vdr>\nstatic __dpct_inline__ float vec_dot_q4_1_q8_1_impl(const int *v, const int *u,\n                                                    const sycl::half2 &dm4,\n                                                    const sycl::half2 &ds8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;\n        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;\n\n        // SIMD dot product of quantized values\n        sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);\n        sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);\n    }\n\n#ifdef GGML_SYCL_F16\n    const sycl::float2 tmp =\n        (dm4 * ds8).convert<float, sycl::rounding_mode::automatic>();\n    const float d4d8 = tmp.x();\n    const float m4s8 = tmp.y();\n#else\n    const sycl::float2 dm4f =\n        dm4.convert<float, sycl::rounding_mode::automatic>();\n    const sycl::float2 ds8f =\n        ds8.convert<float, sycl::rounding_mode::automatic>();\n    const float d4d8 = dm4f.x() * ds8f.x();\n    const float m4s8 = dm4f.y() * ds8f.y();\n#endif // GGML_SYCL_F16\n\n    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it\n    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));\n}\n\n#define VDR_Q5_0_Q8_1_MMVQ 2\n#define VDR_Q5_0_Q8_1_MMQ  4\n\ntemplate <int vdr>\nstatic __dpct_inline__ float\nvec_dot_q5_0_q8_1_impl(const int *vl, const int *vh, const int *u,\n                       const float &d5, const sycl::half2 
&ds8) {\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits\n        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4\n        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12\n        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20\n        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28\n        sumi = dpct::dp4a(vi0, u[2 * i + 0],\n                          sumi); // SIMD dot product of quantized values\n\n        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits\n        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4\n        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12\n        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20\n        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28\n        sumi = dpct::dp4a(vi1, u[2 * i + 1],\n                          sumi); // SIMD dot product of quantized values\n    }\n\n    const sycl::float2 ds8f =\n        ds8.convert<float, sycl::rounding_mode::automatic>();\n\n    // second part effectively subtracts 16 from each quant value\n    return d5 * (sumi * ds8f.x() - (16 * vdr / QI5_0) * ds8f.y());\n}\n\n#define VDR_Q5_1_Q8_1_MMVQ 2\n#define VDR_Q5_1_Q8_1_MMQ  4\n\ntemplate <int vdr>\nstatic __dpct_inline__ float\nvec_dot_q5_1_q8_1_impl(const int *vl, const int *vh, const int *u,\n                       const sycl::half2 &dm5, const sycl::half2 &ds8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits\n        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4\n        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12\n        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20\n        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28\n        sumi = dpct::dp4a(vi0, u[2 * i + 0],\n                          sumi); // SIMD dot product of quantized 
values\n\n        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits\n        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4\n        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12\n        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20\n        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28\n        sumi = dpct::dp4a(vi1, u[2 * i + 1],\n                          sumi); // SIMD dot product of quantized values\n    }\n\n#ifdef GGML_SYCL_F16\n     const sycl::float2 tmp =\n        (dm5 * ds8).convert<float, sycl::rounding_mode::automatic>();\n    const float d5d8 = tmp.x();\n    const float m5s8 = tmp.y();\n\n\n#else\n    const sycl::float2 dm5f =\n        dm5.convert<float, sycl::rounding_mode::automatic>();\n    const sycl::float2 ds8f =\n        ds8.convert<float, sycl::rounding_mode::automatic>();\n    const float d5d8 = dm5f.x() * ds8f.x();\n    const float m5s8 = dm5f.y() * ds8f.y();\n#endif // GGML_SYCL_F16\n\n    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it\n    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);\n}\n\n#define VDR_Q8_0_Q8_1_MMVQ 2\n#define VDR_Q8_0_Q8_1_MMQ 8\n\ntemplate <int vdr>\nstatic __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u,\n                                                    const float &d8_0,\n                                                    const float &d8_1) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        // SIMD dot product of quantized values\n        sumi = dpct::dp4a(v[i], u[i], sumi);\n    }\n\n    return d8_0*d8_1 * sumi;\n}\n\ntemplate <typename T, int vdr>\nstatic __dpct_inline__ T vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const T & d8_0, const T & d8_1) {\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        // SIMD dot product of quantized values\n        sumi = ggml_sycl_dp4a(v[i], u[i], sumi);\n    }\n\n    return d8_0*d8_1 * ((T) 
sumi);\n}\n\ntemplate <int vdr>\nstatic __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,\n                                                    const sycl::half2 &dm8,\n                                                    const sycl::half2 &ds8) {\n\n    int sumi = 0;\n\n#pragma unroll\n    for (int i = 0; i < vdr; ++i) {\n        // SIMD dot product of quantized values\n        sumi = dpct::dp4a(v[i], u[i], sumi);\n    }\n\n#ifdef GGML_SYCL_F16\n    const sycl::float2 tmp =\n        (dm8 * ds8).convert<float, sycl::rounding_mode::automatic>();\n    const float d8d8 = tmp.x();\n    const float m8s8 = tmp.y();\n#else\n    const sycl::float2 dm8f =\n        dm8.convert<float, sycl::rounding_mode::automatic>();\n    const sycl::float2 ds8f =\n        ds8.convert<float, sycl::rounding_mode::automatic>();\n    const float d8d8 = dm8f.x() * ds8f.x();\n    const float m8s8 = dm8f.y() * ds8f.y();\n#endif // GGML_SYCL_F16\n\n    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it\n    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);\n}\n\nstatic __dpct_inline__ float\nvec_dot_q4_0_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;\n\n    int v[VDR_Q4_0_Q8_1_MMVQ];\n    int u[2 * VDR_Q4_0_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {\n        v[i]         = get_int_from_uint8(bq4_0->qs, iqs + i);\n        u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);\n        u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);\n    }\n\n    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);\n}\n\nstatic __dpct_inline__ float\nvec_dot_q4_1_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;\n\n    int 
v[VDR_Q4_1_Q8_1_MMVQ];\n    int u[2*VDR_Q4_1_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {\n        v[i]    = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);\n        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);\n        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);\n    }\n\n    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);\n}\n\n#define VDR_MXFP4_Q8_1_MMVQ 2\n#define VDR_MXFP4_Q8_1_MMQ  4\n\nstatic __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,\n                                                const block_q8_1 * __restrict__ bq8_1,\n                                                const int & iqs) {\n    const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq;\n\n    const int * q8 = (const int *) bq8_1->qs + iqs;\n\n    int sumi = 0;\n#pragma unroll\n    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {\n        const int aux_q4 = get_int_b1(bq4->qs, iqs + l);\n        const sycl::int2 v      = get_int_from_table_16(aux_q4, kvalues_mxfp4);\n        sumi = ggml_sycl_dp4a(v.x(), q8[l + 0], sumi);\n        sumi = ggml_sycl_dp4a(v.y(), q8[l + 4], sumi);\n    }\n\n    const float d = ggml_sycl_e8m0_to_fp32(bq4->e) * 0.5f * (bq8_1->ds)[0];\n    return d * sumi;\n}\n\n\nstatic __dpct_inline__ float\nvec_dot_q5_0_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;\n\n    int vl[VDR_Q5_0_Q8_1_MMVQ];\n    int vh[VDR_Q5_0_Q8_1_MMVQ];\n    int  u[2*VDR_Q5_0_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {\n        vl[i]    = get_int_from_uint8(bq5_0->qs, iqs + i);\n        vh[i]    = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));\n        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);\n        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);\n    }\n\n    return 
vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);\n}\n\nstatic __dpct_inline__ float\nvec_dot_q5_1_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;\n\n    int vl[VDR_Q5_1_Q8_1_MMVQ];\n    int vh[VDR_Q5_1_Q8_1_MMVQ];\n    int  u[2*VDR_Q5_1_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {\n        vl[i]   = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);\n        vh[i]   = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));\n        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);\n        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);\n    }\n\n    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);\n}\n\nstatic __dpct_inline__ float\nvec_dot_q8_0_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;\n\n    int v[VDR_Q8_0_Q8_1_MMVQ];\n    int u[VDR_Q8_0_Q8_1_MMVQ];\n\n#pragma unroll\n    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {\n        v[i] = get_int_from_int8(bq8_0->qs, iqs + i);\n        u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);\n    }\n\n    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d,\n                                                      bq8_1->ds[0]);\n}\n\nstatic __dpct_inline__ float\nvec_dot_q2_K_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_q2_K * bq2_K = (const block_q2_K *) vbq;\n\n    const int bq8_offset = QR2_K * (iqs / QI8_1);\n    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);\n\n    const uint8_t * scales = bq2_K->scales + scale_offset;\n\n    const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);\n    int    u[QR2_K];\n    float 
d8[QR2_K];\n\n#pragma unroll\n    for (int i = 0; i < QR2_K; ++ i) {\n        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);\n        d8[i] = bq8_1[bq8_offset + i].ds[0];\n    }\n\n    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);\n}\n\nstatic __dpct_inline__ float\nvec_dot_q3_K_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_q3_K * bq3_K = (const block_q3_K *) vbq;\n\n    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));\n    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);\n\n    const float d = bq3_K->d;\n\n    const int vl = get_int_from_uint8(bq3_K->qs, iqs);\n\n    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted\n    const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;\n\n    int    u[QR3_K];\n    float d8[QR3_K];\n\n#pragma unroll\n    for (int i = 0; i < QR3_K; ++i) {\n        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);\n        d8[i] = bq8_1[bq8_offset + i].ds[0];\n    }\n\n    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);\n}\n\nstatic __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,\n                                               const int & iqs) {\n#ifndef GGML_QKK_64\n\n    const block_q4_K * bq4_K = (const block_q4_K *) vbq;\n\n    const int        bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));\n    const int *      q4         = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));\n    const uint16_t * scales     = (const uint16_t *) bq4_K->scales;\n\n    return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs);\n\n#else\n\n#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics\n    const block_q4_K * bq4_K = (const block_q4_K *) vbq;\n\n    float sumf_d = 0.0f;\n    float 
sumf_m = 0.0f;\n\n    uint16_t aux16[2];\n    const uint8_t * s = (const uint8_t *)aux16;\n\n    const uint16_t * a = (const uint16_t *)bq4_K->scales;\n    aux16[0] = a[0] & 0x0f0f;\n    aux16[1] = (a[0] >> 4) & 0x0f0f;\n\n    const float dall = bq4_K->dm[0];\n    const float dmin = bq4_K->dm[1];\n\n    const float d8_1 = bq8_1[0].ds[0];\n    const float d8_2 = bq8_1[1].ds[1];\n\n    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));\n    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);\n    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));\n    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);\n\n    const int * q4 = (const int *)bq4_K->qs + (iqs/2);\n    const int v1 = q4[0];\n    const int v2 = q4[4];\n\n    const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0));\n    const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));\n    const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0));\n    const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0));\n\n    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);\n    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);\n\n    return dall * sumf_d - dmin * sumf_m;\n\n#else\n    bad_arch();\n#endif // __SYCL_ARCH__ >= VER_4VEC\n\n#endif\n}\n\nstatic __dpct_inline__ float\nvec_dot_q5_K_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n#ifndef GGML_QKK_64\n    const block_q5_K * bq5_K = (const block_q5_K *) vbq;\n\n    int   vl[2];\n    int   vh[2];\n    int    u[2*QR5_K];\n    float d8[QR5_K];\n\n    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));\n    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));\n    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));\n\n    vl[0] = ql[0];\n    vl[1] = ql[4];\n\n    vh[0] = qh[0] >> bq8_offset;\n    vh[1] = qh[4] >> 
bq8_offset;\n\n    const uint16_t * scales = (const uint16_t *)bq5_K->scales;\n    uint16_t aux[2];\n    const int j = bq8_offset/2;\n    if (j < 2) {\n        aux[0] = scales[j+0] & 0x3f3f;\n        aux[1] = scales[j+2] & 0x3f3f;\n    } else {\n        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);\n        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);\n    }\n    const uint8_t * sc = (const uint8_t *)aux;\n    const uint8_t * m  = sc + 2;\n\n#pragma unroll\n    for (int i = 0; i < QR5_K; ++i) {\n        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;\n        d8[i] = bq8i->ds[0];\n\n        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);\n        u[2*i+0] = q8[0];\n        u[2*i+1] = q8[4];\n    }\n\n    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);\n\n#else\n\n#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics\n    const block_q5_K * bq5_K = (const block_q5_K *) vbq;\n\n    const int8_t * s = bq5_K->scales;\n\n    const float d = bq5_K->d;\n\n    const float d8_1 = bq8_1[0].ds[0];\n    const float d8_2 = bq8_1[1].ds[1];\n\n    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));\n    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);\n    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));\n    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);\n\n    const int * ql = (const int *)bq5_K->qs + (iqs/2);\n    const int vl1 = ql[0];\n    const int vl2 = ql[4];\n\n    const int step = 4 * (iqs/2); // 0, 4, 8, 12\n    const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6\n    const int in = step%8; // 0, 4, 0, 4\n    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;\n\n    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);\n    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);\n    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 
0x0f0f0f0f);\n    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);\n\n    const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1])\n                       + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]);\n\n    return d * sumf_d;\n\n#else\n    bad_arch();\n#endif // __SYCL_ARCH__ >= VER_4VEC\n\n#endif\n}\n\nstatic __dpct_inline__ float\nvec_dot_q6_K_q8_1(const void *__restrict__ vbq,\n                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_q6_K * bq6_K = (const block_q6_K *) vbq;\n\n    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);\n    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);\n    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));\n\n    const int vl = get_int_from_uint8(bq6_K->ql, iqs);\n    const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;\n\n    const int8_t * scales = bq6_K->scales + scale_offset;\n\n    int    u[QR6_K];\n    float d8[QR6_K];\n\n#pragma unroll\n    for (int i = 0; i < QR6_K; ++i) {\n        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);\n        d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];\n    }\n\n    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);\n}\n\n\nstatic __dpct_inline__ float\nvec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,\n                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,\n                     const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs,\n                     const uint8_t *kmask_iq2xs) {\n#if QK_K == 256\n    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;\n\n    const int ib32 = iqs;\n    const uint16_t * q2 = bq2->qs + 4*ib32;\n    const uint8_t  * aux8 = (const uint8_t *)q2;\n    const int8_t   * q8 = bq8_1[ib32].qs;\n    uint32_t aux32 = q2[2] | (q2[3] 
<< 16);\n    int sumi = 0;\n    for (int l = 0; l < 4; ++l) {\n        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);\n        const uint8_t  signs = ksigns_iq2xs[aux32 & 127];\n        for (int j = 0; j < 8; ++j) {\n            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);\n        }\n        q8 += 8;\n        aux32 >>= 7;\n    }\n    const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;\n    return d * sumi;\n#else\n    assert(false);\n    return 0.f;\n#endif\n}\n\nstatic __dpct_inline__ float\nvec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,\n                    const block_q8_1 *__restrict__ bq8_1, const int &iqs,\n                    const uint64_t *iq2xs_grid, const uint64_t *ksigns64) {\n#if DPCT_COMPATIBILITY_TEMP >=                                                 \\\n    MIN_CC_DP4A // lowest compute capability for integer intrinsics\n#if QK_K == 256\n    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;\n\n    const int ib32 = iqs;\n    const uint16_t * q2 = bq2->qs + 4*ib32;\n    const int8_t   * q8 = bq8_1[ib32].qs;\n    const uint8_t ls1 = bq2->scales[ib32] & 0xf;\n    const uint8_t ls2 = bq2->scales[ib32] >>  4;\n    int sumi1 = 0;\n    for (int l = 0; l < 2; ++l) {\n        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));\n        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));\n        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(\n            grid[0] ^ signs[0], signs[0], std::minus<>());\n        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(\n            grid[1] ^ signs[1], signs[1], std::minus<>());\n        sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);\n        sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);\n        q8 += 8;\n    }\n    int sumi2 = 0;\n    for (int l = 2; l < 4; ++l) {\n        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));\n        const uint32_t * 
signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));\n        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(\n            grid[0] ^ signs[0], signs[0], std::minus<>());\n        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(\n            grid[1] ^ signs[1], signs[1], std::minus<>());\n        sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);\n        sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);\n        q8 += 8;\n    }\n    const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;\n    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);\n#else\n    assert(false);\n    return 0.f;\n#endif\n#else\n    assert(false);\n    return 0.f;\n#endif\n}\n\nstatic __dpct_inline__ float\nvec_dot_iq2_s_q8_1(const void *__restrict__ vbq,\n                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n#if QK_K == 256\n    const block_iq2_s * bq2 = (const block_iq2_s *) vbq;\n\n    const int ib32 = iqs;\n    const int8_t  * q8 = bq8_1[ib32].qs;\n    const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;\n    const uint8_t ls1 = bq2->scales[ib32] & 0xf;\n    const uint8_t ls2 = bq2->scales[ib32] >>  4;\n    int sumi1 = 0;\n    for (int l = 0; l < 2; ++l) {\n        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));\n        const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(\n            ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,\n            std::equal_to<>());\n        const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(\n            ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,\n            std::equal_to<>());\n        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(\n            grid[0] ^ signs0, signs0, std::minus<>());\n        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(\n            grid[1] ^ signs1, signs1, std::minus<>());\n        sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), 
sumi1);\n        sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);\n        q8 += 8;\n    }\n    int sumi2 = 0;\n    for (int l = 2; l < 4; ++l) {\n        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));\n        const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(\n            ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,\n            std::equal_to<>());\n        const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(\n            ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,\n            std::equal_to<>());\n        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(\n            grid[0] ^ signs0, signs0, std::minus<>());\n        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(\n            grid[1] ^ signs1, signs1, std::minus<>());\n        sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);\n        sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);\n        q8 += 8;\n    }\n    const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;\n    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);\n#else\n    assert(false);\n#endif\n}\n\nstatic __dpct_inline__ float\nvec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,\n                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,\n                     const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) {\n#if DPCT_COMPATIBILITY_TEMP >=                                                 \\\n    MIN_CC_DP4A // lowest compute capability for integer intrinsics\n#if QK_K == 256\n    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;\n\n    const int ib32 = iqs;\n    const uint8_t  * q3 = bq2->qs + 8*ib32;\n    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;\n    const int8_t   * q8 = bq8_1[ib32].qs;\n    uint32_t aux32 = gas[0] | (gas[1] << 16);\n    int sumi = 0;\n    for (int l = 0; l < 4; ++l) {\n        const uint32_t * grid1 
= iq3xxs_grid + q3[2*l+0];\n        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];\n        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));\n        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(\n            grid1[0] ^ signs[0], signs[0], std::minus<>());\n        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(\n            grid2[0] ^ signs[1], signs[1], std::minus<>());\n        sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);\n        sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);\n        q8 += 8;\n        aux32 >>= 7;\n    }\n    const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f;\n    return d * sumi;\n#else\n    assert(false);\n    return 0.f;\n#endif\n#else\n    assert(false);\n    return 0.f;\n#endif\n}\n\nstatic __dpct_inline__ float\nvec_dot_iq3_s_q8_1(const void *__restrict__ vbq,\n                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,\n                   const uint32_t *iq3s_grid) {\n#if QK_K == 256\n    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;\n\n    const int ib32 = iqs;\n    const uint8_t  * qs = bq2->qs + 8*ib32;\n    const int8_t   * q8 = bq8_1[ib32].qs;\n    int sumi = 0;\n    for (int l = 0; l < 4; ++l) {\n        const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));\n        const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));\n        uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(\n            ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,\n            0x08040201, std::equal_to<>());\n        uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(\n            ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,\n            0x08040201, std::equal_to<>());\n        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(\n            grid1[0] ^ signs0, signs0, std::minus<>());\n        const int grid_h = 
dpct::vectorized_binary<sycl::uchar4>(\n            grid2[0] ^ signs1, signs1, std::minus<>());\n        sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);\n        sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);\n        q8 += 8;\n    }\n    const float d =\n        (float)bq2->d *\n        (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *\n        bq8_1[ib32].ds[0];\n    return d * sumi;\n#else\n    assert(false);\n#endif\n}\n\nstatic __dpct_inline__ float\nvec_dot_iq1_s_q8_1(const void *__restrict__ vbq,\n                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,\n                   const uint32_t *iq1s_grid_gpu) {\n#if QK_K == 256\n    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;\n\n    const int ib32 = iqs;\n    int sumi = 0;\n    const int * q8 = (const int *)bq8_1[ib32].qs;\n    for (int l = 0; l < 4; ++l) {\n        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));\n        int grid0 = grid[0] & 0x0f0f0f0f;\n        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;\n        sumi = dpct::dp4a(q8[2 * l + 1], grid1,\n                          dpct::dp4a(q8[2 * l + 0], grid0, sumi));\n    }\n\n    const float delta = bq1->qh[ib32] & 0x8000 ? 
-1-IQ1S_DELTA : -1+IQ1S_DELTA;\n    const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);\n    const float d = d1q * bq8_1[ib32].ds[0];\n    const float m = d1q * bq8_1[ib32].ds[1];\n    return d * sumi + m * delta;\n#else\n    assert(false);\n#endif\n}\n\nstatic __dpct_inline__ float\nvec_dot_iq1_m_q8_1(const void *__restrict__ vbq,\n                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n#if QK_K == 256\n    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;\n\n    const int ib32 = iqs;\n    int   sumi[2] = {0, 0};\n    float sumf[2] = {0.f, 0.f};\n\n    const int * q8 = (const int *)bq8_1[ib32].qs;\n    for (int l = 0; l < 4; ++l) {\n        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));\n        int grid0 = grid[0] & 0x0f0f0f0f;\n        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;\n        sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1,\n                                 dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));\n        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? 
-1-IQ1M_DELTA : -1+IQ1M_DELTA;\n        const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,\n                                    dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));\n        sumf[l/2] += delta*sumy;\n    }\n\n    iq1m_scale_t scale;\n    const uint16_t * sc = (const uint16_t *)bq1->scales;\n    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);\n    const float d = (float)scale.f16 * bq8_1[ib32].ds[0];\n    return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));\n#else\n    assert(false);\n#endif\n}\n\n\nstatic __dpct_inline__ float\nvec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,\n                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;\n\n    const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;\n    const int32_t  * q8 = (const int32_t  *)bq8_1->qs + iqs;\n\n    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;\n\n    int v1, v2;\n    int sumi1 = 0, sumi2 = 0;\n    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {\n        const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);\n        get_int_from_table_16(aux, values, v1, v2);\n        sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);\n        sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);\n    }\n\n    const float d = (float)bq->d * bq8_1->ds[0];\n    return d * (sumi1 + sumi2);\n}\n\n\nstatic __dpct_inline__ float\nvec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,\n                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {\n\n#if QK_K == 256\n    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;\n    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;\n\n    // iqs is 0...7\n    const int ib32 = iqs;\n    const int32_t  * q8 = (const int *)bq8_1[ib32].qs;\n    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;\n    const int8_t ls = ((bq4->scales_l[ib32/2] 
>> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);\n    const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];\n    int v1, v2;\n    int sumi1 = 0, sumi2 = 0;\n    for (int j = 0; j < 4; ++j) {\n        get_int_from_table_16(q4[j], values, v1, v2);\n        sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);\n        sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);\n    }\n    return d * (sumi1 + sumi2);\n#else\n    assert(false);\n#endif\n}\n\n#endif // GGML_SYCL_VECDOTQ_HPP\n"
  },
  {
    "path": "src/ggml-sycl/wkv.cpp",
    "content": "#include <sycl/sycl.hpp>\n#include \"wkv.hpp\"\n\nconstexpr int WKV_BLOCK_SIZE = 64;\n\n// Helper function for the main kernel\ntemplate <int block_size>\nstatic void rwkv_wkv6_f32_kernel(\n    const int B, const int T, const int C, const int H,\n    const float* k, const float* v, const float* r,\n    const float* tf, const float* td, const float* s,\n    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {\n\n    const int tid = item_ct1.get_local_id(2);\n    const int bid = item_ct1.get_group(2);\n\n    const int head_size = block_size;\n    const int batch_i = bid / H;\n    const int head_i = bid % H;\n    const int state_size = C * head_size;\n    const int n_seq_tokens = T / B;\n\n    // Set up shared memory pointers\n    float* _k = shared_mem;\n    float* _r = _k + head_size;\n    float* _tf = _r + head_size;\n    float* _td = _tf + head_size;\n\n    // Local state array\n    float state[block_size];\n\n    // Load initial state\n    #pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];\n    }\n\n    // Sync threads before shared memory operations\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n\n    // Load time-mixing parameters\n    _tf[tid] = tf[head_i * head_size + tid];\n    item_ct1.barrier(sycl::access::fence_space::local_space);\n\n    // Main sequence processing loop\n    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;\n         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;\n         t += C) {\n\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n\n        // Load current timestep data to shared memory\n        _k[tid] = k[t];\n        _r[tid] = r[t];\n        _td[tid] = td[t];\n\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n\n        const float _v = v[t];\n        float y = 0;\n\n        // Process in chunks of 4 for 
better vectorization\n        sycl::float4 k4, r4, tf4, td4, s4;\n        #pragma unroll\n        for (int j = 0; j < head_size; j += 4) {\n            // Load data in vec4 chunks\n            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);\n            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);\n            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);\n            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);\n            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);\n\n            // Compute key-value product\n            sycl::float4 kv4 = k4 * _v;\n\n            // Accumulate weighted sum\n            y += sycl::dot(r4, tf4 * kv4 + s4);\n\n            // Update state\n            s4 = s4 * td4 + kv4;\n\n            // Store updated state\n            state[j] = s4.x();\n            state[j+1] = s4.y();\n            state[j+2] = s4.z();\n            state[j+3] = s4.w();\n        }\n\n        dst[t] = y;\n    }\n\n    // Save final state\n    #pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];\n    }\n}\n\ntemplate <int block_size>\nstatic void rwkv_wkv7_f32_kernel(\n    const int B, const int T, const int C, const int H,\n    const float* r, const float* w, const float* k, const float* v,\n    const float* a, const float* b, const float* s,\n    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {\n\n    const int tid = item_ct1.get_local_id(2);\n    const int bid = item_ct1.get_group(2);\n\n    const int head_size = block_size;\n    const int batch_i = bid / H;\n    const int head_i = bid % H;\n    const int state_size = C * head_size;\n    const int n_seq_tokens = T / B;\n\n    float* _r = shared_mem;\n    float* _w = _r + head_size;\n    float* _k = _w + head_size;\n    float* _a = _k + head_size;\n    float* _b = _a + head_size;\n\n    float state[block_size];\n\n    
#pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];\n    }\n\n    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;\n         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;\n         t += C) {\n\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n\n        _r[tid] = r[t];\n        _w[tid] = w[t];\n        _k[tid] = k[t];\n        _a[tid] = a[t];\n        _b[tid] = b[t];\n\n        item_ct1.barrier(sycl::access::fence_space::local_space);\n\n        const float _v = v[t];\n        float y = 0, sa = 0;\n        sycl::float4 a4, s4;\n\n        #pragma unroll\n        for (int j = 0; j < head_size; j += 4) {\n            a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);\n            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);\n            sa += sycl::dot(a4, s4);\n        }\n\n        sycl::float4 r4, w4, k4, b4;\n        #pragma unroll\n        for (int j = 0; j < head_size; j += 4) {\n            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);\n            w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);\n            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);\n            b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);\n            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);\n\n            sycl::float4 kv4 = k4 * _v;\n\n            s4 = s4 * w4 + kv4 + sa * b4;\n            y += sycl::dot(r4, s4);\n\n            state[j] = s4.x();\n            state[j+1] = s4.y();\n            state[j+2] = s4.z();\n            state[j+3] = s4.w();\n        }\n\n        dst[t] = y;\n    }\n\n    #pragma unroll\n    for (int i = 0; i < head_size; i++) {\n        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];\n    }\n}\n\nvoid ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {\n    
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6);\n    const float* k_d = (const float*)dst->src[0]->data;\n    const float* v_d = (const float*)dst->src[1]->data;\n    const float* r_d = (const float*)dst->src[2]->data;\n    const float* tf_d = (const float*)dst->src[3]->data;\n    const float* td_d = (const float*)dst->src[4]->data;\n    const float* s_d = (const float*)dst->src[5]->data;\n    float* dst_d = (float*)dst->data;\n\n    const int64_t B = dst->src[5]->ne[1];\n    const int64_t T = dst->src[0]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t H = dst->src[0]->ne[1];\n\n    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);\n    GGML_ASSERT(C % H == 0);\n    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64\n\n    dpct::queue_ptr stream = ctx.stream();\n\n    // Calculate execution configuration\n    const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td\n    sycl::range<3> block_dims(1, 1, C / H);\n    sycl::range<3> grid_dims(1, 1, B * H);\n\n    // Submit kernel\n    if (C / H == WKV_BLOCK_SIZE) {\n        stream->submit([&](sycl::handler& cgh) {\n            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1) {\n                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(\n                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,\n                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()\n                    );\n                });\n        });\n    } else {\n        stream->submit([&](sycl::handler& cgh) {\n            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(grid_dims * block_dims, 
block_dims),\n                [=](sycl::nd_item<3> item_ct1) {\n                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(\n                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,\n                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()\n                    );\n                });\n        });\n    }\n}\n\nvoid ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {\n    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7);\n    const float* r_d = (const float*)dst->src[0]->data;\n    const float* w_d = (const float*)dst->src[1]->data;\n    const float* k_d = (const float*)dst->src[2]->data;\n    const float* v_d = (const float*)dst->src[3]->data;\n    const float* a_d = (const float*)dst->src[4]->data;\n    const float* b_d = (const float*)dst->src[5]->data;\n    const float* s_d = (const float*)dst->src[6]->data;\n    float* dst_d = (float*)dst->data;\n\n    const int64_t B = dst->src[6]->ne[1];\n    const int64_t T = dst->src[0]->ne[2];\n    const int64_t C = dst->ne[0];\n    const int64_t H = dst->src[0]->ne[1];\n\n    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);\n    GGML_ASSERT(C % H == 0);\n    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);\n\n    dpct::queue_ptr stream = ctx.stream();\n\n    // Calculate execution configuration\n    const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b\n    sycl::range<3> block_dims(1, 1, C / H);\n    sycl::range<3> grid_dims(1, 1, B * H);\n\n    // Submit kernel\n    if (C / H == WKV_BLOCK_SIZE) {\n        stream->submit([&](sycl::handler& cgh) {\n            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1) {\n                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(\n                        B, T, 
C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,\n                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()\n                    );\n                });\n        });\n    } else {\n        stream->submit([&](sycl::handler& cgh) {\n            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);\n\n            cgh.parallel_for(\n                sycl::nd_range<3>(grid_dims * block_dims, block_dims),\n                [=](sycl::nd_item<3> item_ct1) {\n                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(\n                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,\n                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()\n                    );\n                });\n        });\n    }\n}\n"
  },
  {
    "path": "src/ggml-sycl/wkv.hpp",
    "content": "#ifndef GGML_SYCL_WKV_HPP\n#define GGML_SYCL_WKV_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\nvoid ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);\n\n#endif // GGML_SYCL_WKV_HPP\n"
  },
  {
    "path": "src/ggml-threading.cpp",
    "content": "#include \"ggml-threading.h\"\n#include <mutex>\n\nstd::mutex ggml_critical_section_mutex;\n\nvoid ggml_critical_section_start() {\n    ggml_critical_section_mutex.lock();\n}\n\nvoid ggml_critical_section_end(void) {\n    ggml_critical_section_mutex.unlock();\n}\n"
  },
  {
    "path": "src/ggml-threading.h",
    "content": "#pragma once\n\n#include \"ggml.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nGGML_API void ggml_critical_section_start(void);\nGGML_API void ggml_critical_section_end(void);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "src/ggml-virtgpu/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.19)\ncmake_policy(SET CMP0114 NEW)\n\ninclude(ExternalProject)\n\nmessage(STATUS \"Including the VirtGPU/Virglrenderer API Remoting\")\n\n# Download venus_hw.h from virglrenderer repository\nExternalProject_Add(\n    venus_hw_header\n    URL https://gitlab.freedesktop.org/virgl/virglrenderer/-/raw/virglrenderer-1.2.0/src/venus_hw.h\n    DOWNLOAD_NO_EXTRACT YES\n    DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include\n    DOWNLOAD_NAME venus_hw.h\n    CONFIGURE_COMMAND \"\"\n    BUILD_COMMAND \"\"\n    INSTALL_COMMAND \"\"\n    LOG_DOWNLOAD ON\n)\n\nif (NOT GGML_VIRTGPU_BACKEND STREQUAL \"ONLY\")\n    message(STATUS \"Enable the VirtGPU/Virglrenderer API Remoting frontend library\")\n\n    find_package(PkgConfig REQUIRED)\n    pkg_check_modules(DRM REQUIRED libdrm)\n    if (NOT GGML_BACKEND_DL)\n      # cannot simply use USE_VIRTGPU, as in the 'else()' case the\n      # frontend isn't compiled\n      target_compile_definitions(ggml PUBLIC \"GGML_USE_VIRTGPU_FRONTEND\")\n    endif()\n\n    ggml_add_backend_library(ggml-virtgpu\n                             ggml-backend-buffer.cpp\n                             ggml-backend.cpp\n                             ggml-backend-device.cpp\n                             ggml-backend-reg.cpp\n                             ggml-backend-buffer-type.cpp\n                             virtgpu-apir.h\n                             virtgpu-forward.gen.h\n                             virtgpu.cpp\n                             virtgpu-shm.cpp\n                             virtgpu-utils.cpp\n                             virtgpu-forward-device.cpp\n                             virtgpu-forward-buffer-type.cpp\n                             virtgpu-forward-buffer.cpp\n                             virtgpu-forward-backend.cpp\n                             virtgpu-forward-impl.h\n                             apir_cs_ggml-rpc-front.cpp\n                             ../../include/ggml-virtgpu.h)\n\n    
target_include_directories(ggml-virtgpu PUBLIC /usr/include/libdrm/)\n\n    target_link_libraries(ggml-virtgpu PUBLIC ${DRM_LIBRARIES})\n    target_include_directories(ggml-virtgpu PUBLIC ${DRM_INCLUDE_DIRS})\n    target_compile_options(ggml-virtgpu PUBLIC ${DRM_CFLAGS_OTHER})\n\n    target_include_directories(ggml-virtgpu PUBLIC ./include)\n    target_include_directories(ggml-virtgpu PRIVATE ${CMAKE_CURRENT_BINARY_DIR})\n\n    # Ensure venus_hw.h is downloaded before building ggml-virtgpu\n    add_dependencies(ggml-virtgpu venus_hw_header)\n\n    target_compile_options(ggml-virtgpu PRIVATE -std=c++20)\nelse()\n    message(STATUS \"Not building the VirtGPU/Virglrenderer API Remoting frontend library\")\nendif()\n\nif (NOT GGML_VIRTGPU_BACKEND STREQUAL \"OFF\")\n    add_subdirectory(\"backend\")\nendif()\n"
  },
  {
    "path": "src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp",
    "content": "#include \"backend/shared/apir_cs_rpc.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-remoting.h\"\n\n#include <cinttypes>\n#include <unordered_map>\n#include <unordered_set>\n#include <vector>\n\napir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor) {\n    apir_rpc_tensor result;\n    result.id   = reinterpret_cast<uint64_t>(tensor);\n    result.type = tensor->type;\n    if (tensor->buffer) {\n        ggml_backend_buffer_t buffer = tensor->buffer;\n\n        result.buffer = BUFFER_TO_HOST_HANDLE(buffer);\n    } else {\n        result.buffer = 0;\n    }\n    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {\n        result.ne[i] = tensor->ne[i];\n        result.nb[i] = tensor->nb[i];\n    }\n    result.op = tensor->op;\n    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {\n        result.op_params[i] = tensor->op_params[i];\n    }\n    result.flags = tensor->flags;\n    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {\n        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);\n    }\n    result.view_src  = reinterpret_cast<uint64_t>(tensor->view_src);\n    result.view_offs = tensor->view_offs;\n    result.data      = reinterpret_cast<uint64_t>(tensor->data);\n    if (tensor->data) {\n        if (!tensor->buffer) {\n            GGML_ABORT(\"%s: tensor has data but not buffer\", __func__);\n        }\n        // tensor->data is serialized as an offset to the buffer base address\n        result.data -= reinterpret_cast<uint64_t>(BUFFER_TO_GGML_CONTEXT(tensor->buffer)->base);\n    }\n    snprintf(result.name, GGML_MAX_NAME, \"%s\", tensor->name);\n    return result;\n}\n\nvoid apir_add_tensor(ggml_tensor *                       tensor,\n                     std::vector<apir_rpc_tensor> &      tensors,\n                     std::unordered_set<ggml_tensor *> & visited) {\n    if (tensor == nullptr) {\n        return;\n    }\n    if (visited.find(tensor) != visited.end()) {\n        
return;\n    }\n    visited.insert(tensor);\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        apir_add_tensor(tensor->src[i], tensors, visited);\n    }\n    apir_add_tensor(tensor->view_src, tensors, visited);\n    tensors.push_back(apir_serialize_tensor(tensor));\n}\n\nvoid apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {\n    uint32_t                          n_nodes = cgraph->n_nodes;\n    std::vector<apir_rpc_tensor>      tensors;\n    std::unordered_set<ggml_tensor *> visited;\n    for (uint32_t i = 0; i < n_nodes; i++) {\n        apir_add_tensor(cgraph->nodes[i], tensors, visited);\n    }\n    // serialization format:\n    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(apir_rpc_tensor)) |\n    uint32_t n_tensors = tensors.size();\n    int      output_size =\n        sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(apir_rpc_tensor);\n    output.resize(output_size, 0);\n    memcpy(output.data(), &n_nodes, sizeof(n_nodes));\n    for (uint32_t i = 0; i < n_nodes; i++) {\n        memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));\n    }\n    uint32_t * out_ntensors = (uint32_t *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));\n    *out_ntensors           = n_tensors;\n    apir_rpc_tensor * out_tensors =\n        (apir_rpc_tensor *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));\n    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(apir_rpc_tensor));\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.19)\ncmake_policy(SET CMP0114 NEW)\n\nmessage(STATUS \"Enable the VirtGPU/Virglrenderer backend library\")\n\nggml_add_backend_library(ggml-virtgpu-backend\n                         backend.cpp\n                         backend-dispatched.cpp\n                         backend-dispatched-backend.cpp\n                         backend-dispatched-device.cpp\n                         backend-dispatched-buffer.cpp\n                         backend-dispatched-buffer-type.cpp\n                         shared/api_remoting.h\n                         shared/apir_backend.h\n                         shared/apir_cs.h\n                         apir_cs_ggml-rpc-back.cpp)\n\ntarget_compile_options(ggml-virtgpu-backend PRIVATE -std=c++20)\n\n# Add include directory for ggml-backend-impl.h and other core headers\ntarget_include_directories(ggml-virtgpu-backend PRIVATE ../..)\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp",
    "content": "#include \"ggml-backend-impl.h\"\n#include \"ggml-impl.h\"\n#include \"shared/apir_cs_rpc.h\"\n\n#include <cinttypes>\n#include <unordered_map>\n#include <unordered_set>\n#include <vector>\n\nstd::unordered_set<ggml_backend_buffer_t> backend_buffers;\n\nvoid apir_track_backend_buffer(ggml_backend_buffer_t buffer) {\n    backend_buffers.insert(buffer);\n}\n\nbool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) {\n    auto it = backend_buffers.find(buffer);\n    if (it == backend_buffers.end()) {\n        return false;\n    }\n\n    backend_buffers.erase(it);\n    return true;\n}\n\nstd::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers() {\n    return backend_buffers;\n}\n\nggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) {\n    ggml_tensor * result =\n        ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);\n    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {\n        result->nb[i] = tensor->nb[i];\n    }\n    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);\n    if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) {\n        printf(\"WARNING: HOST BUFFER NOT FOUND | %p\\n\", (void *) result->buffer);\n        result->buffer = nullptr;\n    }\n\n    uint64_t tensor_data = tensor->data;\n    if (result->buffer) {\n        // require that the tensor data does not go beyond the buffer end\n        uint64_t tensor_size  = (uint64_t) ggml_nbytes(result);\n        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);\n        uint64_t buffer_size  = (uint64_t) ggml_backend_buffer_get_size(result->buffer);\n\n        // tensor->data is serialized as an offset to the buffer base address\n        tensor_data += buffer_start;\n\n        GGML_ASSERT(tensor_data + tensor_size >= tensor_data);  // check for overflow\n        GGML_ASSERT(tensor_data >= 
buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size);\n    }\n\n    result->op = (ggml_op) tensor->op;\n    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {\n        result->op_params[i] = tensor->op_params[i];\n    }\n    result->flags = tensor->flags;\n    result->data  = reinterpret_cast<void *>(tensor_data);\n    ggml_set_name(result, tensor->name);\n    return result;\n}\n\nggml_tensor * apir_create_node(uint64_t                                                      id,\n                               ggml_context *                                                ctx,\n                               const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,\n                               std::unordered_map<uint64_t, ggml_tensor *> &                 tensor_map) {\n    if (id == 0) {\n        return nullptr;\n    }\n    if (tensor_map.find(id) != tensor_map.end()) {\n        return tensor_map[id];\n    }\n    const apir_rpc_tensor * tensor = tensor_ptrs.at(id);\n    ggml_tensor *           result = apir_deserialize_tensor(ctx, tensor);\n    if (result == nullptr) {\n        return nullptr;\n    }\n    tensor_map[id] = result;\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);\n    }\n    result->view_src  = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);\n    result->view_offs = tensor->view_offs;\n    return result;\n}\n\nggml_cgraph * apir_deserialize_graph(uint32_t                n_nodes,\n                                     uint32_t                n_tensors,\n                                     const apir_rpc_tensor * tensors,\n                                     const uint64_t *        nodes) {\n    size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);\n    ggml_init_params params = {\n        /*.mem_size   =*/buf_size,\n        /*.mem_buffer 
=*/NULL,\n        /*.no_alloc   =*/true,\n    };\n    ggml_context * ctx   = ggml_init(params);\n    ggml_cgraph *  graph = ggml_new_graph_custom(ctx, n_nodes, false);\n    graph->n_nodes       = n_nodes;\n    std::unordered_map<uint64_t, const apir_rpc_tensor *> tensor_ptrs;\n    for (uint32_t i = 0; i < n_tensors; i++) {\n        tensor_ptrs[tensors[i].id] = &tensors[i];\n    }\n    std::unordered_map<uint64_t, ggml_tensor *> tensor_map;\n    for (uint32_t i = 0; i < n_nodes; i++) {\n        int64_t id;\n        memcpy(&id, &nodes[i], sizeof(id));\n        graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map);\n    }\n\n    return graph;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-convert.h",
    "content": "#include \"shared/apir_backend.h\"\n\n#define BUFFER_TO_HOST_HANDLE(name) ggml_buffer_to_apir_handle(name)\n\nstatic inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) {\n    // in the backend, the buffer handle is the buffer pointer\n    return (apir_buffer_host_handle_t) buffer;\n}\n\nstatic inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) {\n    // in the backend, the buffer handle is the buffer pointer\n    return (apir_buffer_type_host_handle_t) buft;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-dispatched-backend.cpp",
    "content": "#include \"backend-dispatched.h\"\n#include \"backend-virgl-apir.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n#include \"shared/apir_backend.h\"\n\n#include <cstdint>\n\nstatic uint32_t validate_graph_operation(size_t cgraph_size, uint32_t shmem_res_id, const char * operation) {\n    if (cgraph_size == 0) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Zero-size computation graph\\n\", operation);\n        return 1;\n    }\n\n    // place-holder: validate that the size of shmem_res_id is <= cgraph_size\n    // need to add another method in the Virgl->APIR callback interface\n    GGML_UNUSED(shmem_res_id);\n\n    return 0;  // Valid\n}\n\nuint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n\n    static bool async_backend_initialized = false;\n    static bool async_backend;\n\n    if (!async_backend_initialized) {\n        ggml_backend_dev_props props;\n\n        dev->iface.get_props(dev, &props);\n        async_backend             = props.caps.async;\n        async_backend_initialized = true;\n    }\n\n    uint32_t shmem_res_id;\n    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);\n\n    const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);\n    if (!shmem_data) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Couldn't get the shmem addr from virgl\\n\", __func__);\n        apir_decoder_set_fatal(dec);\n        return 1;\n    }\n    size_t cgraph_size;\n    apir_decode_size_t(dec, &cgraph_size);\n\n    if (validate_graph_operation(cgraph_size, shmem_res_id, __func__) != 0) {\n        apir_decoder_set_fatal(dec);\n        return 1;\n    }\n\n    apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size);\n\n    ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size);\n\n    if (!cgraph || apir_decoder_get_fatal(&secondary_dec)) {\n        
GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Failed to deserialize computation graph\\n\", __func__);\n        return 1;\n    }\n\n    if (cgraph->n_nodes < 0 || cgraph->n_leafs < 0) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Invalid negative node/leaf count: nodes=%d leafs=%d\\n\", __func__,\n                       cgraph->n_nodes, cgraph->n_leafs);\n        return 1;\n    }\n\n    ggml_status status;\n#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1\n    for (int idx = 0; idx < cgraph->n_nodes; idx++) {\n        ggml_tensor * op = ggml_graph_node(cgraph, idx);\n        if (dev->iface.supports_op(dev, op)) {\n            continue;\n        }\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Graph node %d (%s) not supported by the backend\\n\", __func__, idx,\n                       ggml_op_desc(op));\n\n        status = GGML_STATUS_ABORTED;\n        apir_encode_ggml_status(enc, &status);\n\n        return 0;\n    }\n#endif\n\n    // Check if backend is properly initialized\n    if (!bck) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Backend not initialized (bck is null)\\n\", __func__);\n\n        return 1;\n    }\n\n    status = bck->iface.graph_compute(bck, cgraph);\n\n    if (async_backend && bck->iface.synchronize) {\n        bck->iface.synchronize(bck);\n    }\n\n    apir_encode_ggml_status(enc, &status);\n\n    return 0;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp",
    "content": "#include \"backend-dispatched.h\"\n#include \"backend-virgl-apir.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n\n#include <cstdint>\n\nuint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    ggml_backend_buffer_type_t buft;\n    buft = apir_decode_ggml_buffer_type(dec);\n\n    const char * string = buft->iface.get_name(buft);\n\n    const size_t string_size = strlen(string) + 1;\n    apir_encode_array_size(enc, string_size);\n    apir_encode_char_array(enc, string, string_size);\n\n    return 0;\n}\n\nuint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    ggml_backend_buffer_type_t buft;\n    buft = apir_decode_ggml_buffer_type(dec);\n\n    size_t value = buft->iface.get_alignment(buft);\n    apir_encode_size_t(enc, &value);\n\n    return 0;\n}\n\nuint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    ggml_backend_buffer_type_t buft;\n    buft = apir_decode_ggml_buffer_type(dec);\n\n    size_t value = SIZE_MAX;\n    if (buft->iface.get_max_size) {\n        value = buft->iface.get_max_size(buft);\n    }\n\n    apir_encode_size_t(enc, &value);\n\n    return 0;\n}\n\n/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. 
*/\nuint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n    const bool is_host = false;\n\n    apir_encode_bool_t(enc, &is_host);\n\n    return 0;\n}\n\nuint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    ggml_backend_buffer_type_t buft;\n    buft = apir_decode_ggml_buffer_type(dec);\n\n    size_t size;\n    apir_decode_size_t(dec, &size);\n\n    ggml_backend_buffer_t buffer;\n\n    buffer = buft->iface.alloc_buffer(buft, size);\n\n    apir_encode_ggml_buffer(enc, buffer);\n\n    if (buffer) {\n        apir_track_backend_buffer(buffer);\n    }\n\n    return 0;\n}\n\nuint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    ggml_backend_buffer_type_t buft;\n    buft = apir_decode_ggml_buffer_type(dec);\n\n    const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);\n\n    // Check for decode error\n    if (op == nullptr) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Failed to decode tensor\\n\", __func__);\n        apir_decoder_set_fatal(dec);\n        return 1;\n    }\n\n    size_t value;\n    if (buft->iface.get_alloc_size) {\n        value = buft->iface.get_alloc_size(buft, op);\n    } else {\n        value = ggml_nbytes(op);  // Default fallback\n    }\n\n    apir_encode_size_t(enc, &value);\n\n    return 0;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp",
    "content": "#include \"backend-dispatched.h\"\n#include \"backend-virgl-apir.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n\n#include <cstdint>\n\nstatic uint32_t validate_buffer_operation(size_t offset, size_t size, const char * operation) {\n    // Only check for critical integer overflow - no arbitrary size limits\n    if (offset > SIZE_MAX - size) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Integer overflow in offset+size: %zu + %zu\\n\", operation, offset, size);\n        return 1;\n    }\n\n    return 0;  // Valid\n}\n\nuint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    ggml_backend_buffer_t buffer;\n    buffer = apir_decode_ggml_buffer(dec);\n\n    if (!buffer || apir_decoder_get_fatal(dec)) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Invalid buffer handle from guest\\n\", __func__);\n        return 1;\n    }\n\n    uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer);\n    apir_encode_uintptr_t(enc, &base);\n\n    return 0;\n}\n\nuint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(enc);\n\n    ggml_backend_buffer_t buffer;\n    buffer = apir_decode_ggml_buffer(dec);\n\n    if (!buffer || apir_decoder_get_fatal(dec)) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Invalid buffer handle from guest\\n\", __func__);\n        return 1;\n    }\n\n    ggml_tensor * tensor;\n    // safe to remove the const qualifier here\n    tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);\n\n    uint32_t shmem_res_id;\n    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);\n\n    size_t offset;\n    apir_decode_size_t(dec, &offset);\n\n    size_t size;\n    apir_decode_size_t(dec, &size);\n\n    if (validate_buffer_operation(offset, size, __func__) != 0) {\n        return 1;\n    }\n\n    void * shmem_data = 
ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);\n\n    if (!shmem_data) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Couldn't get the shmem addr from virgl\\n\", __func__);\n        return 1;\n    }\n\n    buffer->iface.set_tensor(buffer, tensor, shmem_data, offset, size);\n\n    return 0;\n}\n\nuint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(enc);\n\n    ggml_backend_buffer_t buffer;\n    buffer = apir_decode_ggml_buffer(dec);\n\n    if (!buffer || apir_decoder_get_fatal(dec)) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Invalid buffer handle from guest\\n\", __func__);\n        return 1;\n    }\n\n    const ggml_tensor * tensor;\n    // safe to remove the const qualifier here\n    tensor = apir_decode_ggml_tensor(dec);\n\n    uint32_t shmem_res_id;\n    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);\n\n    size_t offset;\n    apir_decode_size_t(dec, &offset);\n\n    size_t size;\n    apir_decode_size_t(dec, &size);\n\n    if (validate_buffer_operation(offset, size, __func__) != 0) {\n        return 1;\n    }\n\n    void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);\n    if (!shmem_data) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Couldn't get the shmem addr from virgl\\n\", __func__);\n        return 1;\n    }\n\n    buffer->iface.get_tensor(buffer, tensor, shmem_data, offset, size);\n\n    return 0;\n}\n\nuint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n\n    ggml_backend_buffer_t buffer;\n    buffer = apir_decode_ggml_buffer(dec);\n\n    if (!buffer || apir_decoder_get_fatal(dec)) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Invalid buffer handle from guest\\n\", __func__);\n        return 1;\n    }\n\n    const ggml_tensor * src;\n    // safe to remove the const qualifier here\n    src               = apir_decode_ggml_tensor(dec);\n    
ggml_tensor * dst = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);\n\n    bool ret = buffer->iface.cpy_tensor(buffer, src, (ggml_tensor *) dst);\n\n    apir_encode_bool_t(enc, &ret);\n\n    return 0;\n}\n\nuint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(enc);\n\n    ggml_backend_buffer_t buffer;\n    buffer = apir_decode_ggml_buffer(dec);\n\n    if (!buffer || apir_decoder_get_fatal(dec)) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Invalid buffer handle from guest\\n\", __func__);\n        return 1;\n    }\n\n    uint8_t value;\n    apir_decode_uint8_t(dec, &value);\n\n    buffer->iface.clear(buffer, value);\n\n    return 0;\n}\n\nuint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(enc);\n\n    ggml_backend_buffer_t buffer;\n    buffer = apir_decode_ggml_buffer(dec);\n\n    if (!buffer || apir_decoder_get_fatal(dec)) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Invalid buffer handle from guest\\n\", __func__);\n        return 1;\n    }\n\n    if (!apir_untrack_backend_buffer(buffer)) {\n        GGML_LOG_WARN(GGML_VIRTGPU_BCK \"%s: unknown buffer %p\\n\", __func__, (void *) buffer);\n        return 1;\n    }\n\n    buffer->iface.free_buffer(buffer);\n\n    return 0;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-dispatched-device.cpp",
    "content": "#include \"backend-dispatched.h\"\n#include \"backend-virgl-apir.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n\n#include <cstdint>\n\nuint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    int32_t dev_count = reg->iface.get_device_count(reg);\n    apir_encode_int32_t(enc, &dev_count);\n\n    return 0;\n}\n\nuint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    int32_t dev_count = reg->iface.get_device_count(reg);\n    apir_encode_int32_t(enc, &dev_count);\n\n    return 0;\n}\n\nuint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    const char * string = dev->iface.get_name(dev);\n\n    const size_t string_size = strlen(string) + 1;\n    apir_encode_array_size(enc, string_size);\n    apir_encode_char_array(enc, string, string_size);\n\n    return 0;\n}\n\nuint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    const char * string = dev->iface.get_description(dev);\n\n    const size_t string_size = strlen(string) + 1;\n    apir_encode_array_size(enc, string_size);\n    apir_encode_char_array(enc, string, string_size);\n\n    return 0;\n}\n\nuint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    uint32_t type = dev->iface.get_type(dev);\n    apir_encode_uint32_t(enc, &type);\n\n    return 0;\n}\n\nuint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    size_t 
free, total;\n    dev->iface.get_memory(dev, &free, &total);\n\n    apir_encode_size_t(enc, &free);\n    apir_encode_size_t(enc, &total);\n\n    return 0;\n}\n\nuint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n\n    const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);\n\n    bool supports_op = dev->iface.supports_op(dev, op);\n\n    apir_encode_bool_t(enc, &supports_op);\n\n    return 0;\n}\n\nuint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    ggml_backend_buffer_type_t bufft = dev->iface.get_buffer_type(dev);\n\n    apir_encode_ggml_buffer_type(enc, bufft);\n\n    return 0;\n}\n\nuint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    ggml_backend_dev_props props;\n    dev->iface.get_props(dev, &props);\n\n    apir_encode_bool_t(enc, &props.caps.async);\n    apir_encode_bool_t(enc, &props.caps.host_buffer);\n    apir_encode_bool_t(enc, &props.caps.buffer_from_host_ptr);\n    apir_encode_bool_t(enc, &props.caps.events);\n\n    return 0;\n}\n\nuint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(dec);\n\n    uint32_t shmem_res_id;\n    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);\n\n    void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);\n    if (!shmem_ptr) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Couldn't get the shmem addr from virgl\\n\", __func__);\n        apir_decoder_set_fatal(dec);\n        return 1;\n    }\n\n    size_t size;\n    apir_decode_size_t(dec, &size);\n    size_t max_tensor_size;\n    apir_decode_size_t(dec, &max_tensor_size);\n\n    ggml_backend_buffer_t buffer;\n    buffer = dev->iface.buffer_from_host_ptr(dev, 
shmem_ptr, size, max_tensor_size);\n\n    apir_encode_ggml_buffer(enc, buffer);\n    apir_encode_ggml_buffer_type(enc, buffer->buft);\n\n    if (buffer) {\n        apir_track_backend_buffer(buffer);\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-dispatched.cpp",
    "content": "#include \"backend-dispatched.h\"\n\n#include \"backend-virgl-apir.h\"\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n\n#include <cstdint>\n\nggml_backend_reg_t reg = NULL;\nggml_backend_dev_t dev = NULL;\nggml_backend_t     bck = NULL;\n\nuint64_t timer_start = 0;\nuint64_t timer_total = 0;\nuint64_t timer_count = 0;\n\nuint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) {\n    if (reg != NULL) {\n        GGML_LOG_WARN(GGML_VIRTGPU_BCK \"%s: already initialized\\n\", __func__);\n        return APIR_BACKEND_INITIALIZE_ALREADY_INITED;\n    }\n    ggml_backend_reg_t (*ggml_backend_reg_fct)(void) = (ggml_backend_reg_t (*)()) ggml_backend_reg_fct_p;\n\n    reg = ggml_backend_reg_fct();\n    if (reg == NULL) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: backend registration failed\\n\", __func__);\n        return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED;\n    }\n\n    size_t device_count = reg->iface.get_device_count(reg);\n    if (!device_count) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: no device found\\n\", __func__);\n        return APIR_BACKEND_INITIALIZE_NO_DEVICE;\n    }\n\n    dev = reg->iface.get_device(reg, 0);\n\n    if (!dev) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: failed to get device\\n\", __func__);\n        return APIR_BACKEND_INITIALIZE_NO_DEVICE;\n    }\n\n    bck = dev->iface.init_backend(dev, NULL);\n    if (!bck) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: backend initialization failed\\n\", __func__);\n        return APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED;\n    }\n\n    return APIR_BACKEND_INITIALIZE_SUCCESS;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-dispatched.gen.h",
    "content": "#pragma once\n\n/* device */\nuint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\n\n/* buffer-type */\nuint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\n/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. 
*/\nuint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\n\n/* buffer */\nuint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\nuint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\n\n/* backend */\nuint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\n\nextern \"C\" {\nstatic const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {\n\n    /* device */\n\n    /* APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT  = */ backend_device_get_device_count,\n    /* APIR_COMMAND_TYPE_DEVICE_GET_COUNT  = */ backend_device_get_count,\n    /* APIR_COMMAND_TYPE_DEVICE_GET_NAME  = */ backend_device_get_name,\n    /* APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION  = */ backend_device_get_description,\n    /* APIR_COMMAND_TYPE_DEVICE_GET_TYPE  = */ backend_device_get_type,\n    /* APIR_COMMAND_TYPE_DEVICE_GET_MEMORY  = */ backend_device_get_memory,\n    /* APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP  = */ backend_device_supports_op,\n    /* APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE  = */ backend_device_get_buffer_type,\n    /* APIR_COMMAND_TYPE_DEVICE_GET_PROPS  = */ backend_device_get_props,\n    /* APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR  = */ 
backend_device_buffer_from_ptr,\n\n    /* buffer-type */\n\n    /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME  = */ backend_buffer_type_get_name,\n    /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT  = */ backend_buffer_type_get_alignment,\n    /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE  = */ backend_buffer_type_get_max_size,\n    /* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST  = */ backend_buffer_type_is_host /* DEPRECATED */,\n    /* APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER  = */ backend_buffer_type_alloc_buffer,\n    /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE  = */ backend_buffer_type_get_alloc_size,\n\n    /* buffer */\n\n    /* APIR_COMMAND_TYPE_BUFFER_GET_BASE  = */ backend_buffer_get_base,\n    /* APIR_COMMAND_TYPE_BUFFER_SET_TENSOR  = */ backend_buffer_set_tensor,\n    /* APIR_COMMAND_TYPE_BUFFER_GET_TENSOR  = */ backend_buffer_get_tensor,\n    /* APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR  = */ backend_buffer_cpy_tensor,\n    /* APIR_COMMAND_TYPE_BUFFER_CLEAR  = */ backend_buffer_clear,\n    /* APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER  = */ backend_buffer_free_buffer,\n\n    /* backend */\n\n    /* APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE  = */ backend_backend_graph_compute,\n};\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-dispatched.h",
    "content": "#pragma once\n\n// clang-format off\n#include <cstdint>\n#include <cstddef>\n\n#include <ggml-backend.h>\n\n#include \"backend-convert.h\"\n#include \"backend-virgl-apir.h\"\n#include \"shared/apir_backend.h\"\n#include \"shared/apir_cs.h\"\n#include \"shared/apir_cs_ggml.h\"\n// clang-format on\n\n#define GGML_VIRTGPU_BCK \"ggml-virtgpu-backend: \"\n\nstruct virgl_apir_context {\n    uint32_t               ctx_id;\n    virgl_apir_callbacks * iface;\n};\n\ntypedef uint32_t (*backend_dispatch_t)(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);\n\n#include \"backend-dispatched.gen.h\"\n\nuint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p);\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend-virgl-apir.h",
    "content": "#pragma once\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n#include \"shared/api_remoting.h\"\n\n#include <cstdarg>\n#include <cstdio>\n#include <cstdlib>\n\nextern ggml_backend_reg_t reg;\nextern ggml_backend_dev_t dev;\nextern ggml_backend_t     bck;\n\nstruct virgl_apir_callbacks {\n    const char * (*get_config)(uint32_t virgl_ctx_id, const char * key);\n    void * (*get_shmem_ptr)(uint32_t virgl_ctx_id, uint32_t res_id);\n};\n\nextern \"C\" {\nApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs);\nvoid                      apir_backend_deinit(uint32_t virgl_ctx_id);\nuint32_t                  apir_backend_dispatcher(uint32_t               virgl_ctx_id,\n                                                  virgl_apir_callbacks * virgl_cbs,\n                                                  uint32_t               cmd_type,\n                                                  char *                 dec_cur,\n                                                  const char *           dec_end,\n                                                  char *                 enc_cur,\n                                                  const char *           enc_end,\n                                                  char **                enc_cur_after);\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/backend.cpp",
    "content": "#include \"backend-dispatched.h\"\n#include \"backend-virgl-apir.h\"\n#include \"shared/api_remoting.h\"\n#include \"shared/apir_backend.h\"\n#include \"shared/apir_cs.h\"\n\n#include <dlfcn.h>\n#include <ggml-backend.h>\n\n#include <iostream>\n\n#define APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV \"APIR_LLAMA_CPP_GGML_LIBRARY_PATH\"\n#define APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV  \"APIR_LLAMA_CPP_GGML_LIBRARY_REG\"\n#define APIR_LLAMA_CPP_LOG_TO_FILE_ENV       \"APIR_LLAMA_CPP_LOG_TO_FILE\"\n\n#define GGML_DEFAULT_BACKEND_REG \"ggml_backend_init\"\n\nstatic void * backend_library_handle = NULL;\nstatic FILE * apir_logfile           = NULL;\n\nstatic void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) {\n    FILE * logfile = (FILE *) user_data;\n    fprintf(logfile, \"[%d] %s\", level, text);\n    fflush(logfile);\n}\n\nextern \"C\" {\nvoid apir_backend_deinit(uint32_t virgl_ctx_id) {\n    GGML_UNUSED(virgl_ctx_id);\n\n    auto buffers = apir_get_track_backend_buffers();\n    for (const auto & buffer : buffers) {\n        apir_untrack_backend_buffer(buffer);\n        buffer->iface.free_buffer(buffer);\n    }\n\n    if (backend_library_handle) {\n        GGML_LOG_INFO(GGML_VIRTGPU_BCK \"The GGML backend library was loaded. 
Unloading it.\\n\");\n        dlclose(backend_library_handle);\n        backend_library_handle = NULL;\n    }\n\n    if (apir_logfile) {\n        fclose(apir_logfile);\n        apir_logfile = NULL;\n    }\n}\n\n#define APIR_GGML_LIBRARY_PATH_KEY \"ggml.library.path\"\n#define APIR_GGML_LIBRARY_REG_KEY  \"ggml.library.reg\"\n\nApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs) {\n    const char * dlsym_error;\n\n    const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV);\n    if (apir_log_to_file) {\n        apir_logfile = fopen(apir_log_to_file, \"w\");\n        if (apir_logfile) {\n            ggml_log_set(log_to_file_callback, apir_logfile);\n        } else {\n            GGML_LOG_INFO(GGML_VIRTGPU_BCK \"Could not open the log file at '%s'\\n\", apir_log_to_file);\n        }\n    }\n\n    const char * library_name      = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);\n    const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY);\n    const char * library_reg       = virgl_library_reg ? 
virgl_library_reg : GGML_DEFAULT_BACKEND_REG;\n\n    if (!library_name) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: cannot open the GGML library: env var '%s' not defined\\n\", __func__,\n                       APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);\n\n        return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;\n    }\n\n    backend_library_handle = dlopen(library_name, RTLD_LAZY);\n\n    if (!backend_library_handle) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: cannot open the GGML library: %s\\n\", __func__, dlerror());\n\n        return APIR_LOAD_LIBRARY_CANNOT_OPEN;\n    }\n\n    if (!library_reg) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: cannot register the GGML library: env var '%s' not defined\\n\", __func__,\n                       APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);\n\n        return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;\n    }\n\n    void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg);\n    dlsym_error                 = dlerror();\n    if (dlsym_error) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\\n\",\n                       __func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error);\n\n        return APIR_LOAD_LIBRARY_SYMBOL_MISSING;\n    }\n\n    uint32_t ret = backend_dispatch_initialize(ggml_backend_reg_fct);\n\n    return (ApirLoadLibraryReturnCode) (APIR_LOAD_LIBRARY_INIT_BASE_INDEX + ret);\n}\n\nuint32_t apir_backend_dispatcher(uint32_t               virgl_ctx_id,\n                                 virgl_apir_callbacks * virgl_cbs,\n                                 uint32_t               cmd_type,\n                                 char *                 dec_cur,\n                                 const char *           dec_end,\n                                 char *                 enc_cur,\n                                 const char *           enc_end,\n                                 char **                enc_cur_after) {\n    
apir_encoder enc = {\n        .cur   = enc_cur,\n        .start = enc_cur,\n        .end   = enc_end,\n        .fatal = false,\n    };\n\n    apir_decoder dec = {\n        .cur   = dec_cur,\n        .end   = dec_end,\n        .fatal = false,\n    };\n\n    virgl_apir_context ctx = {\n        .ctx_id = virgl_ctx_id,\n        .iface  = virgl_cbs,\n    };\n\n    if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) {\n        GGML_LOG_ERROR(GGML_VIRTGPU_BCK \"%s: Received an invalid dispatch index (%d >= %d)\\n\", __func__, cmd_type,\n                       APIR_BACKEND_DISPATCH_TABLE_COUNT);\n        return APIR_BACKEND_FORWARD_INDEX_INVALID;\n    }\n\n    backend_dispatch_t forward_fct = apir_backend_dispatch_table[cmd_type];\n    uint32_t           ret         = forward_fct(&enc, &dec, &ctx);\n\n    *enc_cur_after = enc.cur;\n\n    return ret;\n}\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/shared/api_remoting.h",
    "content": "#pragma once\n\n/* the rest of this file must match virglrenderer/src/apir-protocol.h */\n\n#include <unistd.h>\n\n#include <cstdint>\n\n#define APIR_PROTOCOL_MAJOR 0\n#define APIR_PROTOCOL_MINOR 1\n\n#define APIR_HANDSHAKE_MAGIC 0xab1e\n\nenum ApirCommandType {\n    APIR_COMMAND_TYPE_HANDSHAKE   = 0,\n    APIR_COMMAND_TYPE_LOADLIBRARY = 1,\n    APIR_COMMAND_TYPE_FORWARD     = 2,\n\n    APIR_COMMAND_TYPE_LENGTH = 3,\n};\n\ntypedef uint64_t ApirCommandFlags;\n\nenum ApirLoadLibraryReturnCode {\n    APIR_LOAD_LIBRARY_SUCCESS                        = 0,\n    // these error codes are returned by the Virglrenderer APIR component\n    APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1,\n    APIR_LOAD_LIBRARY_ALREADY_LOADED                 = 2,\n    APIR_LOAD_LIBRARY_ENV_VAR_MISSING                = 3,\n    APIR_LOAD_LIBRARY_CANNOT_OPEN                    = 4,\n    APIR_LOAD_LIBRARY_SYMBOL_MISSING                 = 5,\n    // any value greater than this is an APIR *backend library* initialization return code\n    APIR_LOAD_LIBRARY_INIT_BASE_INDEX                = 6,\n};\n\nenum ApirForwardReturnCode {\n    APIR_FORWARD_SUCCESS                = 0,\n    // these error codes are returned by the Virglrenderer APIR component\n    APIR_FORWARD_NO_DISPATCH_FCT        = 1,\n    APIR_FORWARD_TIMEOUT                = 2,\n    APIR_FORWARD_FAILED_TO_SYNC_STREAMS = 3,\n    // any value greater than this index an APIR *backend library* forward return code\n    APIR_FORWARD_BASE_INDEX             = 4,\n};\n\n__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) {\n    switch (type) {\n        case APIR_COMMAND_TYPE_HANDSHAKE:\n            return \"HandShake\";\n        case APIR_COMMAND_TYPE_LOADLIBRARY:\n            return \"LoadLibrary\";\n        case APIR_COMMAND_TYPE_FORWARD:\n            return \"Forward\";\n        default:\n            return \"unknown\";\n    }\n}\n\n__attribute__((unused)) static const char * 
apir_load_library_error(ApirLoadLibraryReturnCode code) {\n#define APIR_LOAD_LIBRARY_ERROR(code_name) \\\n    do {                                   \\\n        if (code == code_name)             \\\n            return #code_name;             \\\n    } while (0)\n\n    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SUCCESS);\n    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR);\n    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ALREADY_LOADED);\n    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ENV_VAR_MISSING);\n    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_CANNOT_OPEN);\n    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SYMBOL_MISSING);\n    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_INIT_BASE_INDEX);\n\n    return \"Unknown APIR_COMMAND_TYPE_LoadLibrary error\";\n\n#undef APIR_LOAD_LIBRARY_ERROR\n}\n\n__attribute__((unused)) static const char * apir_forward_error(ApirForwardReturnCode code) {\n#define APIR_FORWARD_ERROR(code_name) \\\n    do {                              \\\n        if (code == code_name)        \\\n            return #code_name;        \\\n    } while (0)\n\n    APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS);\n    APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT);\n    APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT);\n    APIR_FORWARD_ERROR(APIR_FORWARD_FAILED_TO_SYNC_STREAMS);\n    APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX);\n\n    return \"Unknown APIR_COMMAND_TYPE_FORWARD error\";\n\n#undef APIR_FORWARD_ERROR\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/shared/apir_backend.gen.h",
    "content": "typedef enum ApirBackendCommandType {\n\n    /* device */\n    APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = 0,\n    APIR_COMMAND_TYPE_DEVICE_GET_COUNT        = 1,\n    APIR_COMMAND_TYPE_DEVICE_GET_NAME         = 2,\n    APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION  = 3,\n    APIR_COMMAND_TYPE_DEVICE_GET_TYPE         = 4,\n    APIR_COMMAND_TYPE_DEVICE_GET_MEMORY       = 5,\n    APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP      = 6,\n    APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE  = 7,\n    APIR_COMMAND_TYPE_DEVICE_GET_PROPS        = 8,\n    APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR  = 9,\n\n    /* buffer-type */\n    APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME       = 10,\n    APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT  = 11,\n    APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE   = 12,\n    APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST        = 13,\n    APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER   = 14,\n    APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = 15,\n\n    /* buffer */\n    APIR_COMMAND_TYPE_BUFFER_GET_BASE    = 16,\n    APIR_COMMAND_TYPE_BUFFER_SET_TENSOR  = 17,\n    APIR_COMMAND_TYPE_BUFFER_GET_TENSOR  = 18,\n    APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR  = 19,\n    APIR_COMMAND_TYPE_BUFFER_CLEAR       = 20,\n    APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = 21,\n\n    /* backend */\n    APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = 22,\n\n    // last command_type index + 1\n    APIR_BACKEND_DISPATCH_TABLE_COUNT = 23,\n} ApirBackendCommandType;\n\nstatic inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {\n    switch (type) {\n        /* device */\n        case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:\n            return \"device_get_device_count\";\n        case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:\n            return \"device_get_count\";\n        case APIR_COMMAND_TYPE_DEVICE_GET_NAME:\n            return \"device_get_name\";\n        case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:\n            return \"device_get_description\";\n        case 
APIR_COMMAND_TYPE_DEVICE_GET_TYPE:\n            return \"device_get_type\";\n        case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:\n            return \"device_get_memory\";\n        case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:\n            return \"device_supports_op\";\n        case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:\n            return \"device_get_buffer_type\";\n        case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:\n            return \"device_get_props\";\n        case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:\n            return \"device_buffer_from_ptr\";\n        /* buffer-type */\n        case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:\n            return \"buffer_type_get_name\";\n        case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:\n            return \"buffer_type_get_alignment\";\n        case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:\n            return \"buffer_type_get_max_size\";\n        case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:\n            return \"buffer_type_is_host\";\n        case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:\n            return \"buffer_type_alloc_buffer\";\n        case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:\n            return \"buffer_type_get_alloc_size\";\n        /* buffer */\n        case APIR_COMMAND_TYPE_BUFFER_GET_BASE:\n            return \"buffer_get_base\";\n        case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:\n            return \"buffer_set_tensor\";\n        case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:\n            return \"buffer_get_tensor\";\n        case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:\n            return \"buffer_cpy_tensor\";\n        case APIR_COMMAND_TYPE_BUFFER_CLEAR:\n            return \"buffer_clear\";\n        case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:\n            return \"buffer_free_buffer\";\n        /* backend */\n        case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:\n            return \"backend_graph_compute\";\n\n        default:\n            return \"unknown\";\n    }\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/shared/apir_backend.h",
    "content": "#pragma once\n\n#include \"apir_backend.gen.h\"\n\n#include <stdint.h>  // for uintptr_t\n#include <time.h>    // for timespec, clock_gettime\n\n#define APIR_BACKEND_INITIALIZE_SUCCESS                     0\n#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY 1\n#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY    2\n#define APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS     3\n#define APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS        4\n#define APIR_BACKEND_INITIALIZE_BACKEND_FAILED              5\n#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED          6\n#define APIR_BACKEND_INITIALIZE_ALREADY_INITED              7\n#define APIR_BACKEND_INITIALIZE_NO_DEVICE                   8\n#define APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED         9\n\n// new entries here need to be added to the apir_backend_initialize_error function below\n\n#define APIR_BACKEND_FORWARD_INDEX_INVALID 6\n\n// 0 is fast, 1 avoids the backend to crash if an unsupported tensor is received\n#define APIR_BACKEND_CHECK_SUPPORTS_OP 0\n\ntypedef uintptr_t apir_buffer_type_host_handle_t;\ntypedef uintptr_t apir_buffer_host_handle_t;\n\nstatic const char * apir_backend_initialize_error(int code) {\n#define APIR_BACKEND_INITIALIZE_ERROR(code_name) \\\n    do {                                         \\\n        if (code == code_name)                   \\\n            return #code_name;                   \\\n    } while (0)\n\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_SUCCESS);\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY);\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY);\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS);\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS);\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED);\n    
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED);\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_ALREADY_INITED);\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_NO_DEVICE);\n    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED);\n\n    return \"Unknown APIR_BACKEND_INITIALIZE error:/\";\n\n#undef APIR_BACKEND_INITIALIZE_ERROR\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/shared/apir_cs.h",
    "content": "#pragma once\n\n#include \"ggml-impl.h\"\n\n#include <cassert>\n#include <cstring>\n\n#define likely(x)   __builtin_expect(!!(x), 1)\n#define unlikely(x) __builtin_expect(!!(x), 0)\n\nstruct apir_encoder {\n    char *       cur;\n    const char * start;\n    const char * end;\n    bool         fatal;\n};\n\nstruct apir_decoder {\n    const char * cur;\n    const char * end;\n    bool         fatal;\n};\n\n/*\n * new encoder and decoder\n */\n\nstatic apir_decoder apir_new_decoder(const char * ptr, size_t size) {\n    apir_decoder dec = {\n        .cur   = ptr,\n        .end   = ptr + size,\n        .fatal = false,\n    };\n\n    return dec;\n}\n\nstatic apir_encoder apir_new_encoder(char * ptr, size_t size) {\n    apir_encoder enc = {\n        .cur   = ptr,\n        .start = ptr,\n        .end   = ptr + size,\n        .fatal = false,\n    };\n\n    return enc;\n}\n\n/*\n * fatal flag handling\n */\n\nstatic inline void apir_encoder_reset_fatal(apir_encoder * enc) {\n    enc->fatal = false;\n}\n\nstatic inline void apir_encoder_set_fatal(apir_encoder * enc) {\n    enc->fatal = true;\n}\n\nstatic inline bool apir_encoder_get_fatal(const apir_encoder * enc) {\n    return enc->fatal;\n}\n\nstatic inline void apir_decoder_reset_fatal(apir_decoder * dec) {\n    dec->fatal = false;\n}\n\nstatic inline void apir_decoder_set_fatal(apir_decoder * dec) {\n    dec->fatal = true;\n}\n\nstatic inline bool apir_decoder_get_fatal(const apir_decoder * dec) {\n    return dec->fatal;\n}\n\n/*\n * encode peek\n */\n\nstatic inline bool apir_decoder_peek_internal(apir_decoder * dec, size_t size, void * val, size_t val_size) {\n    assert(val_size <= size);\n\n    if (unlikely(size > (size_t) (dec->end - dec->cur))) {\n        GGML_LOG_ERROR(\"%s: reading too much from the decoder ...\\n\", __func__);\n        apir_decoder_set_fatal(dec);\n        memset(val, 0, val_size);\n        return false;\n    }\n\n    /* we should not rely on the compiler to optimize away 
memcpy... */\n    memcpy(val, dec->cur, val_size);\n    return true;\n}\n\nstatic inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) {\n    apir_decoder_peek_internal(dec, size, val, val_size);\n}\n\nstatic inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) {\n    if (unlikely(size > (size_t) (dec->end - dec->cur))) {\n        GGML_LOG_ERROR(\"%s: reading too much from the decoder ...\\n\", __func__);\n        apir_decoder_set_fatal(dec);\n        return NULL;\n    }\n    const void * addr = dec->cur;\n    dec->cur += size;\n\n    return addr;\n}\n\n/*\n * read/write\n */\n\nstatic inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) {\n    if (apir_decoder_peek_internal(dec, size, val, val_size)) {\n        dec->cur += size;\n    }\n}\n\nstatic inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) {\n    assert(val_size <= size);\n    assert(size <= ((size_t) (enc->end - enc->cur)));\n\n    char * write_addr = enc->cur;\n    /* we should not rely on the compiler to optimize away memcpy... 
*/\n    memcpy(write_addr, val, val_size);\n    enc->cur += size;\n\n    return write_addr;\n}\n\n/*\n * encode/decode\n */\n\nstatic inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) {\n    assert(size % 4 == 0);\n    apir_decoder_read(dec, size, data, data_size);\n}\n\nstatic inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) {\n    assert(size % 4 == 0);\n    apir_encoder_write(enc, size, data, data_size);\n}\n\n/*\n * typed encode/decode\n */\n\n/* uint8_t */\n\nstatic inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) {\n    apir_encode(enc, sizeof(int), val, sizeof(*val));\n}\n\nstatic inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) {\n    apir_decode(dec, sizeof(int), val, sizeof(*val));\n}\n\n/* uint64_t */\n\nstatic inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) {\n    apir_encode(enc, 8, val, sizeof(*val));\n}\n\nstatic inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) {\n    apir_decode(dec, 8, val, sizeof(*val));\n}\n\nstatic inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) {\n    const size_t size = sizeof(*val) * count;\n    assert(size >= count);\n    apir_encode(enc, size, val, size);\n}\n\nstatic inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) {\n    const size_t size = sizeof(*val) * count;\n    assert(size >= count);\n    apir_decode(dec, size, val, size);\n}\n\nstatic inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) {\n    return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t));\n}\n\n/* int32_t */\n\nstatic inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) {\n    apir_encode(enc, 4, val, sizeof(*val));\n}\n\nstatic inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) {\n    
apir_decode(dec, 4, val, sizeof(*val));\n}\n\nstatic inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) {\n    const size_t size = sizeof(*val) * count;\n    assert(size >= count);\n    apir_encode(enc, size, val, size);\n}\n\nstatic inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) {\n    const size_t size = sizeof(*val) * count;\n    assert(size >= count);\n    apir_decode(dec, size, val, size);\n}\n\n/* array size (uint64_t) */\n\nstatic inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) {\n    apir_encode_uint64_t(enc, &size);\n}\n\nstatic inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) {\n    uint64_t size;\n    apir_decode_uint64_t(dec, &size);\n    if (size != expected_size) {\n        GGML_LOG_ERROR(\"%s: Couldn't decode array from the decoder\\n\", __func__);\n        apir_decoder_set_fatal(dec);\n        size = 0;\n    }\n    return size;\n}\n\nstatic inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) {\n    uint64_t size;\n    apir_decode_uint64_t(dec, &size);\n    return size;\n}\n\n/* non-array pointer */\n\nstatic inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) {\n    apir_encode_array_size(enc, val ? 
1 : 0);\n    return val;\n}\n\nstatic inline bool apir_decode_simple_pointer(apir_decoder * dec) {\n    return apir_decode_array_size_unchecked(dec);\n}\n\n/* uint32_t */\n\nstatic inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) {\n    apir_encode(enc, 4, val, sizeof(*val));\n}\n\nstatic inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) {\n    apir_decode(dec, 4, val, sizeof(*val));\n}\n\nstatic inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) {\n    const size_t size = sizeof(*val) * count;\n    assert(size >= count);\n    apir_encode(enc, size, val, size);\n}\n\nstatic inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) {\n    const size_t size = sizeof(*val) * count;\n    assert(size >= count);\n    apir_decode(dec, size, val, size);\n}\n\n/* size_t */\n\nstatic inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) {\n    const uint64_t tmp = *val;\n    apir_encode_uint64_t(enc, &tmp);\n}\n\nstatic inline void apir_decode_size_t(apir_decoder * dec, size_t * val) {\n    uint64_t tmp;\n    apir_decode_uint64_t(dec, &tmp);\n    *val = tmp;\n}\n\nstatic inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) {\n    if (sizeof(size_t) == sizeof(uint64_t)) {\n        apir_encode_uint64_t_array(enc, (const uint64_t *) val, count);\n    } else {\n        for (uint32_t i = 0; i < count; i++) {\n            apir_encode_size_t(enc, &val[i]);\n        }\n    }\n}\n\nstatic inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) {\n    if (sizeof(size_t) == sizeof(uint64_t)) {\n        apir_decode_uint64_t_array(dec, (uint64_t *) val, count);\n    } else {\n        for (uint32_t i = 0; i < count; i++) {\n            apir_decode_size_t(dec, &val[i]);\n        }\n    }\n}\n\n/* opaque blob */\n\nstatic inline void apir_encode_blob_array(apir_encoder * enc, const 
void * val, size_t size) {\n    apir_encode(enc, (size + 3) & ~3, val, size);\n}\n\nstatic inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) {\n    apir_decode(dec, (size + 3) & ~3, val, size);\n}\n\n/* string */\n\nstatic inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) {\n    assert(size && strlen(val) < size);\n    apir_encode_blob_array(enc, val, size);\n}\n\nstatic inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) {\n    apir_decode_blob_array(dec, val, size);\n    if (size) {\n        val[size - 1] = '\\0';\n    } else {\n        GGML_LOG_ERROR(\"%s: Couldn't decode the blog array\\n\", __func__);\n        apir_decoder_set_fatal(dec);\n    }\n}\n\n/* (temp) buffer allocation */\n\nstatic inline void * apir_decoder_alloc_array(size_t size, size_t count) {\n    size_t alloc_size;\n    if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {\n        GGML_LOG_ERROR(\"%s: overflow in array allocation of %zu * %zu bytes\\n\", __func__, size, count);\n        return NULL;\n    }\n\n    return malloc(alloc_size);\n}\n\n/* bool */\n\nstatic inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) {\n    apir_encode(enc, sizeof(int), val, sizeof(bool));\n}\n\nstatic inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {\n    apir_decode(dec, sizeof(int), val, sizeof(bool));\n}\n\n/* apir_buffer_type_host_handle_t */\n\nstatic inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder *                         enc,\n                                                              const apir_buffer_type_host_handle_t * val) {\n    apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));\n}\n\nstatic inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder *                   dec,\n                                                              apir_buffer_type_host_handle_t * val) {\n 
   apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));\n}\n\n/* apir_buffer_host_handle_t */\n\nstatic inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, const apir_buffer_host_handle_t * val) {\n    apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));\n}\n\nstatic inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) {\n    apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));\n}\n\n/* uintptr_t */\n\nstatic inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) {\n    apir_encode(enc, sizeof(*val), val, sizeof(*val));\n}\n\nstatic inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) {\n    apir_decode(dec, sizeof(*val), val, sizeof(*val));\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/shared/apir_cs_ggml.h",
    "content": "#pragma once\n\n#include \"apir_cs.h\"\n#include \"apir_cs_rpc.h\"\n#include \"ggml-impl.h\"\n\n// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer);\n\nstatic inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle);\n\nstatic inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec);\n\n/* apir_rpc_tensor */\n\nstatic inline void apir_encode_rcp_tensor(apir_encoder * enc, const apir_rpc_tensor * apir_rpc_tensor) {\n    size_t apir_rpc_tensor_size = sizeof(*apir_rpc_tensor);\n    apir_encode(enc, apir_rpc_tensor_size, apir_rpc_tensor, apir_rpc_tensor_size);\n}\n\nstatic inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder * dec) {\n    size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor);\n\n    return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);\n}\n\nstatic inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, uint32_t n_tensors) {\n    size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors;\n\n    return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);\n}\n\n/* ggml_tensor */\n\nstatic inline void apir_encode_ggml_tensor(apir_encoder * enc, const ggml_tensor * tensor) {\n    apir_rpc_tensor serialized = apir_serialize_tensor(tensor);\n\n    apir_encode_rcp_tensor(enc, &serialized);\n}\n\nstatic inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) {\n    const apir_rpc_tensor * apir_rpc_tensor = apir_decode_apir_rpc_tensor_inplace(dec);\n\n    if (!apir_rpc_tensor) {\n        return NULL;\n    }\n\n    ggml_init_params params{\n        /*.mem_size   =*/ggml_tensor_overhead(),\n        /*.mem_buffer =*/NULL,\n        /*.no_alloc   =*/true,\n    };\n\n    // NOTE(review): ctx is neither freed nor returned, so it appears to outlive this call\n    // for every decoded tensor - TODO confirm the intended lifetime.\n    ggml_context * ctx = ggml_init(params);\n\n    const ggml_tensor * tensor = apir_deserialize_tensor(ctx, apir_rpc_tensor);\n\n    return tensor;\n}\n\n/* *** ggml_backend_buffer_type_t *** */\n\n// ggml_backend_buffer_type_t is a POINTER (to a struct).\n// Only the host pointer is shared between the host and guest.\n// The guest stores it in `buft->context`.\n// The host simply writes the pointer address in the buffer variable.\n\nstatic inline void apir_encode_ggml_buffer_type(apir_encoder * enc, ggml_backend_buffer_type_t buft) {\n    apir_buffer_type_host_handle_t handle = ggml_buffer_type_to_apir_handle(buft);\n    apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));\n}\n\nstatic inline ggml_backend_buffer_type_t apir_decode_ggml_buffer_type(apir_decoder * dec) {\n    apir_buffer_type_host_handle_t handle;\n\n    apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle));\n\n    return (ggml_backend_buffer_type_t) handle;\n}\n\nstatic inline void apir_encode_apir_buffer_type_host_handle(apir_encoder * enc, apir_buffer_type_host_handle_t handle) {\n    apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));\n}\n\nstatic inline apir_buffer_type_host_handle_t apir_decode_apir_buffer_type_host_handle(apir_decoder * dec) {\n    apir_buffer_type_host_handle_t handle;\n\n    apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle));\n\n    return handle;\n}\n\n/* *** ggml_backend_type_t *** */\n\n// ggml_backend_buffer_t is a POINTER.\n// same logic as for ggml_backend_buffer_type_t\n\nstatic inline void apir_encode_ggml_buffer(apir_encoder * enc, const ggml_backend_buffer_t buffer) {\n    apir_buffer_host_handle_t handle = BUFFER_TO_HOST_HANDLE(buffer);\n    apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));\n}\n\nstatic inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) {\n    ggml_backend_buffer_t buffer;\n    size_t                buffer_ptr_size = sizeof(buffer);\n\n    apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size);\n\n    // SECURITY: Validate buffer handle against tracked buffers to prevent\n    // guest VM from providing arbitrary host memory addresses\n    if (buffer) {\n        extern std::unordered_set<ggml_backend_buffer_t> backend_buffers;\n        if (backend_buffers.find(buffer) == backend_buffers.end()) {\n            GGML_LOG_WARN(\"ggml-virtgpu-backend: %s: Invalid buffer handle from guest: %p\\n\", __func__,\n                          (void *) buffer);\n            // Set fatal flag to prevent further processing with invalid handle\n            apir_decoder_set_fatal(dec);\n            return NULL;\n        }\n    }\n\n    return buffer;\n}\n\n/* enum ggml_status */\n\nstatic inline void apir_encode_ggml_status(apir_encoder * enc, const ggml_status * status) {\n    apir_encoder_write(enc, sizeof(*status), status, sizeof(*status));\n}\n\nstatic inline void apir_decode_ggml_status(apir_decoder * dec, ggml_status * status) {\n    apir_decoder_read(dec, sizeof(*status), status, sizeof(*status));\n}\n\n/* virtgpu_shmem */\n\nstatic inline void apir_encode_virtgpu_shmem_res_id(apir_encoder * enc, uint32_t shmem_res_id) {\n    apir_encode_uint32_t(enc, &shmem_res_id);\n}\n\nstatic inline void apir_decode_virtgpu_shmem_res_id(apir_decoder * dec, uint32_t * shmem_res_id) {\n    apir_decode_uint32_t(dec, shmem_res_id);\n}\n\n/* ggml_cgraph */\n\nstatic inline size_t apir_serialize_ggml_cgraph(ggml_cgraph * cgraph, std::vector<uint8_t> & cgraph_data) {\n    apir_serialize_graph(cgraph, cgraph_data);\n\n    return cgraph_data.size();\n}\n\nstatic inline void apir_encode_cgraph_data(apir_encoder * enc, std::vector<uint8_t> & cgraph_data) {\n    size_t cgraph_size = cgraph_data.size();\n\n    apir_encode(enc, cgraph_size, cgraph_data.data(), cgraph_size);\n}\n\nstatic inline ggml_cgraph * apir_decode_ggml_cgraph(apir_decoder * dec, size_t cgraph_size) {\n    GGML_UNUSED(cgraph_size);\n\n    uint32_t n_nodes;\n    apir_decode_uint32_t(dec, &n_nodes);\n    const uint64_t * nodes = apir_decode_uint64_t_array_inplace(dec, n_nodes);\n\n    uint32_t n_tensors;\n    apir_decode_uint32_t(dec, &n_tensors);\n    const apir_rpc_tensor * tensors = apir_decode_apir_rpc_tensor_array_inplace(dec, n_tensors);\n\n    return apir_deserialize_graph(n_nodes, n_tensors, tensors, nodes);\n}\n\n// Write the host buffer handle pointed to by `handle`.\n// `handle` already points at the value to send; passing `&handle` would serialize\n// the bytes of the pointer itself (a guest stack address) instead of the handle.\nstatic inline void apir_encode_ggml_buffer_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle) {\n    apir_encoder_write(enc, sizeof(*handle), handle, sizeof(*handle));\n}\n\n// Encode `tensor` by value, followed (in this order) by the host handle of its\n// buffer (if any), a copy of its view_src (if any) and a copy of every non-NULL\n// src slot. apir_decode_ggml_tensor_inplace() must consume the same sequence.\nstatic inline void apir_encode_ggml_tensor_inline(apir_encoder * enc, const ggml_tensor * tensor) {\n    size_t tensor_size = sizeof(*tensor);\n\n    if (tensor->extra) {\n        GGML_ABORT(\"%s: Cannot pass tensors with extra\", __func__);\n    }\n\n    if (tensor->src[0] && tensor->buffer) {\n        static int first = 1;\n        if (first) {\n            GGML_LOG_WARN(\"%s: Cannot pass tensors with src and buffer\\n\", __func__);\n            first = 0;\n        }\n    }\n\n    apir_encoder_write(enc, tensor_size, tensor, tensor_size);\n\n    // tensor->data is a pointer inside the device buffer. No need to touch it\n    // tensor->buffer is a pointer to a buffer. Encoding the buffer handle in sequence.\n    // (could also make a copy of the tensor, and update locally.)\n\n    if (tensor->buffer) {\n        apir_buffer_host_handle_t buffer_handle = ggml_buffer_to_apir_handle(tensor->buffer);\n        apir_encode_ggml_buffer_handle(enc, &buffer_handle);\n    }\n\n    if (tensor->view_src) {\n        apir_encoder_write(enc, tensor_size, tensor->view_src, tensor_size);\n    }\n\n    // Visit every slot (bounded by GGML_MAX_SRC) instead of stopping at the first NULL:\n    // a full src array no longer reads past its end, and sparse arrays (a NULL slot\n    // followed by a non-NULL one) are fully encoded. The decoder walks the same slots.\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        const ggml_tensor * tensor_src = tensor->src[i];\n        if (tensor_src) {\n            apir_encoder_write(enc, tensor_size, tensor_src, tensor_size);\n        }\n    }\n}\n\n// Decode a tensor written by apir_encode_ggml_tensor_inline() directly in the\n// decoder's memory, fixing up its buffer, view_src and src pointers.\n// Returns NULL if not enough data is left for the tensor itself (the decoder is\n// then flagged fatal by apir_decoder_use_inplace()).\nstatic inline const ggml_tensor * apir_decode_ggml_tensor_inplace(apir_decoder * dec) {\n    // it is safe to remove the `const` qualifier here, we *do* want to\n    // modify the shared memory data to fix the `src` pointers.\n    ggml_tensor * tensor = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));\n    if (!tensor) {\n        return NULL;\n    }\n\n    // tensor->data is a pointer inside the device buffer. No need to touch it\n    // tensor->buffer is a pointer to a buffer. Decode the buffer handle encoded in sequence.\n    if (tensor->buffer) {\n        tensor->buffer = apir_decode_ggml_buffer(dec);\n    }\n\n    if (tensor->view_src) {\n        ggml_tensor * tensor_view_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));\n        tensor->view_src              = tensor_view_src;\n    }\n\n    // Walk every slot (bounded by GGML_MAX_SRC, mirroring the encoder) so that no\n    // guest-provided src pointer is left in place and a corrupted src array cannot\n    // drive the loop past its end. A slot becomes NULL if the decoder ran out of data.\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (tensor->src[i]) {\n            ggml_tensor * tensor_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));\n            tensor->src[i] = tensor_src;  // overwrite op->src[i] pointer with the actual location of the src tensor\n        }\n    }\n\n    return tensor;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/backend/shared/apir_cs_rpc.h",
    "content": "#pragma once\n\n// Serialization types and entry points shared by the guest-side frontend and\n// the host-side backend of ggml-virtgpu.\n\n// clang-format off\n#include \"ggml.h\"\n#include \"ggml-backend-impl.h\"\n\n#include <unordered_map>\n#include <unordered_set>\n#include <vector>\n#include <cstdint>\n// clang-format on\n\n// Flat representation of a ggml_tensor used on the wire. The pointer-like\n// fields (buffer, src, view_src, data) are stored as 64-bit integers.\n// NOTE(review): the struct is not packed, so guest and host must agree on its\n// layout, including any implicit padding (e.g. between `type` and `buffer` on\n// typical 64-bit ABIs) - presumably ensured by building both from this header.\nstruct apir_rpc_tensor {\n    uint64_t id;\n    uint32_t type;\n    uint64_t buffer;\n    uint32_t ne[GGML_MAX_DIMS];\n    uint32_t nb[GGML_MAX_DIMS];\n    uint32_t op;\n    int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];\n    int32_t  flags;\n    uint64_t src[GGML_MAX_SRC];\n    uint64_t view_src;\n    uint64_t view_offs;\n    uint64_t data;\n    char     name[GGML_MAX_NAME];\n\n    char padding[4];\n};\n\n//\n// frontend: ggml objects -> wire representation\n//\n\napir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor);\nvoid            apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output);\n\n//\n// backend: buffer tracking, wire representation -> ggml objects\n//\n\nvoid apir_track_backend_buffer(ggml_backend_buffer_t buffer);\nbool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer);\n\nstd::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers();\n\nvoid apir_add_tensor(ggml_tensor * tensor, std::vector<apir_rpc_tensor> & tensors, std::unordered_set<ggml_tensor *> & visited);\n\nggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor);\n\nggml_tensor * apir_create_node(\n    uint64_t id,\n    ggml_context * ctx,\n    const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,\n    std::unordered_map<uint64_t, ggml_tensor *> & tensor_map);\n\nggml_cgraph * apir_deserialize_graph(\n    uint32_t n_nodes,\n    uint32_t n_tensors,\n    const apir_rpc_tensor * tensors,\n    const uint64_t * nodes);\n"
  },
  {
    "path": "src/ggml-virtgpu/ggml-backend-buffer-type.cpp",
    "content": "#include \"ggml-remoting.h\"\n\n// Buffer-type callbacks of the virtgpu remoting backend. The static properties\n// (name, alignment, max size) come from gpu->cached_buffer_type; allocation and\n// per-tensor sizing go through the apir_* calls.\n\n// Allocate a `size`-byte buffer of type `buft`.\n// When the device reports the buffer_from_host_ptr capability, the buffer is\n// created with apir_device_buffer_from_ptr() and its base is the mapped shared\n// memory; otherwise it is allocated with apir_buffer_type_alloc_buffer() and the\n// base is left NULL (the buffer's get_base callback resolves it later).\nstatic ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,\n                                                                            size_t                     size) {\n    virtgpu * gpu = BUFT_TO_GPU(buft);\n\n    auto * ctx = (ggml_backend_remoting_buffer_context *) malloc(sizeof(ggml_backend_remoting_buffer_context));\n    if (ctx == NULL) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: Couldn't allocate the buffer context ...\", __func__);\n    }\n    ctx->gpu = gpu;\n\n    // Only the buffer_from_host_ptr capability matters here.\n    bool ignored_async;\n    bool ignored_host_buffer;\n    bool ignored_events;\n    bool from_host_ptr;\n    apir_device_get_props(gpu, &ignored_async, &ignored_host_buffer, &from_host_ptr, &ignored_events);\n\n    if (!from_host_ptr) {\n        ctx->apir_context = apir_buffer_type_alloc_buffer(gpu, gpu->cached_buffer_type.host_handle, size);\n        ctx->is_from_ptr  = false;\n        ctx->base         = NULL;\n    } else {\n        ctx->apir_context = apir_device_buffer_from_ptr(gpu, size, size);\n        ctx->is_from_ptr  = true;\n        ctx->base         = ctx->apir_context.shmem.mmap_ptr;\n    }\n\n    return ggml_backend_buffer_init(buft, ggml_backend_remoting_buffer_interface, (void *) ctx, size);\n}\n\n// Name, alignment and maximum size are served from the values cached in\n// gpu->cached_buffer_type, without any call to the host.\nstatic const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    const virtgpu * remote = BUFT_TO_GPU(buft);\n    return remote->cached_buffer_type.name;\n}\n\nstatic size_t ggml_backend_remoting_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    const virtgpu * remote = BUFT_TO_GPU(buft);\n    return remote->cached_buffer_type.alignment;\n}\n\nstatic size_t ggml_backend_remoting_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {\n    const virtgpu * remote = BUFT_TO_GPU(buft);\n    return remote->cached_buffer_type.max_size;\n}\n\n// Size needed to allocate `tensor` in this buffer type. Only tensors that already\n// live in a buffer with a context, whose type this device supports, are sized by\n// apir_buffer_type_get_alloc_size(); every other tensor falls back to ggml_nbytes().\nstatic size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,\n                                                               const ggml_tensor *        tensor) {\n    virtgpu * gpu = BUFT_TO_GPU(buft);\n\n    const bool ask_host = tensor->buffer != NULL && tensor->buffer->context &&\n                          buft->device->iface.supports_buft(buft->device, tensor->buffer->buft);\n    if (ask_host) {\n        return apir_buffer_type_get_alloc_size(gpu, gpu->cached_buffer_type.host_handle, tensor);\n    }\n\n    return ggml_nbytes(tensor);\n}\n\n// Interface used by regular buffer types (allocation goes through alloc_buffer).\nconst ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_remoting_buffer_type_get_name,\n    /* .alloc_buffer     = */ ggml_backend_remoting_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_remoting_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_remoting_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_remoting_buffer_type_get_alloc_size,\n    /* .is_host          = */ NULL,\n};\n\n// Interface used by buffer types that wrap existing memory: no alloc_buffer.\nconst ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface = {\n    /* .get_name         = */ ggml_backend_remoting_buffer_type_get_name,\n    /* .alloc_buffer     = */ NULL,\n    /* .get_alignment    = */ ggml_backend_remoting_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_remoting_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_remoting_buffer_type_get_alloc_size,\n    /* .is_host          = */ NULL,\n};\n"
  },
  {
    "path": "src/ggml-virtgpu/ggml-backend-buffer.cpp",
    "content": "#include \"ggml-remoting.h\"\n\n#define BUFFER_TO_GPU(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->gpu\n\n// Buffer callbacks of the virtgpu remoting backend. For contexts with\n// is_from_ptr set, tensor data is copied with memcpy through tensor->data;\n// everything else goes through the apir_buffer_* calls.\n\n// Return the base address of `buffer`, fetching it with apir_buffer_get_base()\n// on first use and caching it in the buffer context.\nstatic void * ggml_backend_remoting_buffer_get_base(ggml_backend_buffer_t buffer) {\n    auto * ctx = (ggml_backend_remoting_buffer_context *) buffer->context;\n\n    if (ctx->base == NULL) {\n        ctx->base = apir_buffer_get_base(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer));\n    }\n    return ctx->base;\n}\n\n// Copy `size` bytes from `data` into `tensor`, starting `offset` bytes in.\nstatic void ggml_backend_remoting_buffer_set_tensor(ggml_backend_buffer_t buffer,\n                                                    ggml_tensor *         tensor,\n                                                    const void *          data,\n                                                    size_t                offset,\n                                                    size_t                size) {\n    ggml_backend_remoting_buffer_context * ctx = BUFFER_TO_GGML_CONTEXT(buffer);\n\n    if (!ctx->is_from_ptr) {\n        apir_buffer_set_tensor(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size);\n        return;\n    }\n\n    memcpy((char *) tensor->data + offset, data, size);\n}\n\n// Copy `size` bytes out of `tensor`, starting `offset` bytes in, into `data`.\nstatic void ggml_backend_remoting_buffer_get_tensor(ggml_backend_buffer_t buffer,\n                                                    const ggml_tensor *   tensor,\n                                                    void *                data,\n                                                    size_t                offset,\n                                                    size_t                size) {\n    ggml_backend_remoting_buffer_context * ctx = BUFFER_TO_GGML_CONTEXT(buffer);\n\n    if (!ctx->is_from_ptr) {\n        apir_buffer_get_tensor(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size);\n        return;\n    }\n\n    memcpy(data, (const char *) tensor->data + offset, size);\n}\n\n// from_ptr variants: always a direct memcpy through tensor->data.\nstatic void ggml_backend_remoting_buffer_set_tensor_from_ptr(ggml_backend_buffer_t buffer,\n                                                             ggml_tensor *         tensor,\n                                                             const void *          data,\n                                                             size_t                offset,\n                                                             size_t                size) {\n    (void) buffer;\n    memcpy((char *) tensor->data + offset, data, size);\n}\n\nstatic void ggml_backend_remoting_buffer_get_tensor_from_ptr(ggml_backend_buffer_t buffer,\n                                                             const ggml_tensor *   tensor,\n                                                             void *                data,\n                                                             size_t                offset,\n                                                             size_t                size) {\n    (void) buffer;\n    memcpy(data, (const char *) tensor->data + offset, size);\n}\n\n// Copy `src` into `dst` with apir_buffer_cpy_tensor() and report its result.\nstatic bool ggml_backend_remoting_buffer_cpy_tensor(ggml_backend_buffer_t buffer,\n                                                    const ggml_tensor *   src,\n                                                    ggml_tensor *         dst) {\n    return apir_buffer_cpy_tensor(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer), src, dst);\n}\n\n// Clear the buffer to `value` via apir_buffer_clear().\nstatic void ggml_backend_remoting_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    apir_buffer_clear(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer), value);\n}\n\n// Release the buffer through apir_buffer_free_buffer(), then free the local\n// context and detach it from `buffer`.\nstatic void ggml_backend_remoting_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    apir_buffer_free_buffer(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer));\n\n    ggml_backend_remoting_buffer_context * ctx = BUFFER_TO_GGML_CONTEXT(buffer);\n    buffer->context = NULL;\n    free(ctx);\n}\n\nconst ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_remoting_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_remoting_buffer_get_base,\n    /* .init_tensor     = */ NULL,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ ggml_backend_remoting_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_remoting_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_remoting_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_remoting_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\nconst ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = {\n    /* .free_buffer     = */ ggml_backend_remoting_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_remoting_buffer_get_base,\n    /* .init_tensor     = */ NULL,\n    /* .memset_tensor   = */ NULL,\n    /* .set_tensor      = */ ggml_backend_remoting_buffer_set_tensor_from_ptr,\n    /* .get_tensor      = */ ggml_backend_remoting_buffer_get_tensor_from_ptr,\n    /* .cpy_tensor      = */ ggml_backend_remoting_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_remoting_buffer_clear,\n    /* .reset           = */ NULL,\n};\n"
  },
  {
    "path": "src/ggml-virtgpu/ggml-backend-device.cpp",
    "content": "#include \"ggml-remoting.h\"\n\nstatic const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {\n    virtgpu * gpu = DEV_TO_GPU(dev);\n\n    // Return the prefixed name that was built once during initialization\n    return gpu->cached_device_info.name;\n}\n\nstatic const char * ggml_backend_remoting_device_get_description(ggml_backend_dev_t dev) {\n    virtgpu * gpu = DEV_TO_GPU(dev);\n\n    // Return the pre-cached description from the virtgpu structure\n    return gpu->cached_device_info.description;\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_backend_dev_t dev) {\n    virtgpu * gpu = DEV_TO_GPU(dev);\n\n    return (enum ggml_backend_dev_type) gpu->cached_device_info.type;\n}\n\nstatic void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    virtgpu * gpu = DEV_TO_GPU(dev);\n\n    *free  = gpu->cached_device_info.memory_free;\n    *total = gpu->cached_device_info.memory_total;\n}\n\nstatic bool ggml_backend_remoting_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n#if USE_ALWAYS_TRUE_SUPPORTS_OP == 1\n    /* ggml-rpc cheats it like this */\n    /* with the current implementation of serialize_tensor, the src/view aren't properly passed */\n    UNUSED(dev);\n    UNUSED(op);\n\n    return true;\n#else\n    virtgpu * gpu = DEV_TO_GPU(dev);\n\n    return apir_device_supports_op(gpu, op);\n#endif\n}\n\nstatic bool ggml_backend_remoting_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    bool supported = buft->device == dev;\n\n    return supported;\n}\n\nstatic bool ggml_backend_remoting_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    UNUSED(dev);\n    UNUSED(op);\n\n    return false;\n}\n\nstatic void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_remoting_device_get_name(dev);\n  
  props->description = ggml_backend_remoting_device_get_description(dev);\n    props->type        = ggml_backend_remoting_device_get_type(dev);\n    ggml_backend_remoting_device_get_memory(dev, &props->memory_free, &props->memory_total);\n\n    virtgpu * gpu = DEV_TO_GPU(dev);\n    apir_device_get_props(gpu, &props->caps.async, &props->caps.host_buffer, &props->caps.buffer_from_host_ptr,\n                          &props->caps.events);\n\n    props->caps.buffer_from_host_ptr = false;\n    props->caps.async                = false;\n    props->caps.events               = false;\n}\n\nggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) {\n    virtgpu * gpu = DEV_TO_GPU(dev);\n\n    static std::atomic<bool>        initialized = false;\n    static ggml_backend_buffer_type buft;\n\n    if (!initialized) {\n        static std::mutex           mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n\n        if (!initialized) {\n            buft = {\n                /* .iface    = */ ggml_backend_remoting_buffer_type_interface,\n                /* .device   = */ dev,\n                /* .context  = */ (void *) gpu->cached_buffer_type.host_handle,\n            };\n            initialized = true;\n        }\n    }\n\n    return &buft;\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) {\n    virtgpu * gpu = DEV_TO_GPU(dev);\n\n    static std::atomic<bool>        initialized = false;\n    static ggml_backend_buffer_type buft;\n\n    if (!initialized) {\n        static std::mutex           mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n\n        if (!initialized) {\n            buft = {\n                /* .iface    = */ ggml_backend_remoting_buffer_from_ptr_type_interface,\n                /* .device   = */ dev,\n                /* .context  = */ (void *) gpu->cached_buffer_type.host_handle,\n            };\n            initialized = true;\n        }\n   
 }\n\n    return &buft;\n}\n\nstatic ggml_backend_buffer_t ggml_backend_remoting_device_buffer_from_ptr(ggml_backend_dev_t dev,\n                                                                          void *             ptr,\n                                                                          size_t             size,\n                                                                          size_t             max_tensor_size) {\n    virtgpu * gpu = DEV_TO_GPU(dev);\n\n    ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context));\n    if (!context) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: Couldn't allocate the buffer context ...\", __func__);\n    }\n\n    context->gpu          = gpu;\n    context->apir_context = apir_device_buffer_from_ptr(gpu, size, max_tensor_size);\n    context->base         = ptr;\n    context->is_from_ptr  = true;\n\n    ggml_backend_buffer_t buffer =\n        ggml_backend_buffer_init(ggml_backend_remoting_device_get_buffer_from_ptr_type(dev),\n                                 ggml_backend_remoting_buffer_from_ptr_interface, (void *) context, size);\n\n    return buffer;\n}\n\nconst ggml_backend_device_i ggml_backend_remoting_device_interface = {\n    /* .get_name             = */ ggml_backend_remoting_device_get_name,\n    /* .get_description      = */ ggml_backend_remoting_device_get_description,\n    /* .get_memory           = */ ggml_backend_remoting_device_get_memory,\n    /* .get_type             = */ ggml_backend_remoting_device_get_type,\n    /* .get_props            = */ ggml_backend_remoting_device_get_props,\n    /* .init_backend         = */ ggml_backend_remoting_device_init,\n    /* .get_buffer_type      = */ ggml_backend_remoting_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,\n    /* .buffer_from_host_ptr = */ ggml_backend_remoting_device_buffer_from_ptr,\n    /* .supports_op          = */ ggml_backend_remoting_device_supports_op,\n    /* 
.supports_buft        = */ ggml_backend_remoting_device_supports_buft,\n    /* .offload_op           = */ ggml_backend_remoting_device_offload_op,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ NULL,\n};\n"
  },
  {
    "path": "src/ggml-virtgpu/ggml-backend-reg.cpp",
    "content": "#include \"ggml-remoting.h\"\n#include \"ggml-virtgpu.h\"\n\n#include <iostream>\n#include <mutex>\n\nvoid ggml_virtgpu_cleanup(virtgpu * gpu);\n\nstatic virtgpu * apir_initialize() {\n    static virtgpu *         gpu         = NULL;\n    static std::atomic<bool> initialized = false;\n\n    if (initialized) {\n        // fast track\n        return gpu;\n    }\n\n    {\n        static std::mutex           mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n\n        if (initialized) {\n            // thread safe\n            return gpu;\n        }\n\n        gpu = create_virtgpu();\n        if (!gpu) {\n            initialized = true;\n            return NULL;\n        }\n\n        // Pre-fetch and cache all device information, it will not change\n        gpu->cached_device_info.description = apir_device_get_description(gpu);\n        if (!gpu->cached_device_info.description) {\n            GGML_ABORT(GGML_VIRTGPU \"%s: failed to initialize the virtgpu device description\", __func__);\n        }\n        gpu->cached_device_info.device_count = apir_device_get_count(gpu);\n        gpu->cached_device_info.type         = apir_device_get_type(gpu);\n\n        {\n            // Get the remote name and create prefixed version\n            char * rmt_device_name = apir_device_get_name(gpu);\n            if (!rmt_device_name) {\n                GGML_ABORT(GGML_VIRTGPU \"%s: failed to get the virtgpu device name\", __func__);\n            }\n\n            size_t device_name_len       = strlen(rmt_device_name) + 11;  // \"[virtgpu] \" + null terminator\n            gpu->cached_device_info.name = (char *) malloc(device_name_len);\n            if (!gpu->cached_device_info.name) {\n                free(rmt_device_name);\n                GGML_ABORT(GGML_VIRTGPU \"%s: failed to allocate memory for prefixed device name\", __func__);\n            }\n            snprintf(gpu->cached_device_info.name, device_name_len, \"[virtgpu] %s\", rmt_device_name);\n          
  free(rmt_device_name);\n        }\n\n        apir_device_get_memory(gpu, &gpu->cached_device_info.memory_free, &gpu->cached_device_info.memory_total);\n\n        apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu);\n        gpu->cached_buffer_type.host_handle             = buft_host_handle;\n        {\n            // Get the remote name and create prefixed version\n            char * rmt_name = apir_buffer_type_get_name(gpu, buft_host_handle);\n            if (!rmt_name) {\n                GGML_ABORT(GGML_VIRTGPU \"%s: failed to get the virtgpu buffer type name\", __func__);\n            }\n\n            size_t prefixed_len          = strlen(rmt_name) + 11;  // \"[virtgpu] \" + null terminator\n            gpu->cached_buffer_type.name = (char *) malloc(prefixed_len);\n            if (!gpu->cached_buffer_type.name) {\n                free(rmt_name);\n                GGML_ABORT(GGML_VIRTGPU \"%s: failed to allocate memory for prefixed buffer type name\", __func__);\n            }\n            snprintf(gpu->cached_buffer_type.name, prefixed_len, \"[virtgpu] %s\", rmt_name);\n            free(rmt_name);\n        }\n\n        gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);\n        gpu->cached_buffer_type.max_size  = apir_buffer_type_get_max_size(gpu, buft_host_handle);\n\n        initialized = true;\n    }\n\n    return gpu;\n}\n\nstatic int ggml_backend_remoting_get_device_count() {\n    virtgpu * gpu = apir_initialize();\n    if (!gpu) {\n        return 0;\n    }\n\n    return gpu->cached_device_info.device_count;\n}\n\nstatic size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) {\n    UNUSED(reg);\n\n    return ggml_backend_remoting_get_device_count();\n}\n\nstatic std::vector<ggml_backend_dev_t> devices;\n\nggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) {\n    GGML_ASSERT(device < devices.size());\n    return devices[device];\n}\n\nstatic void 
ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) {\n    if (devices.size() > 0) {\n        GGML_LOG_INFO(GGML_VIRTGPU \"%s: already initialized\\n\", __func__);\n        return;\n    }\n\n    virtgpu * gpu = apir_initialize();\n    if (!gpu) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: apir_initialize failed\\n\", __func__);\n        return;\n    }\n\n    static std::atomic<bool> initialized = false;\n\n    if (initialized) {\n        return;  // fast track\n    }\n\n    {\n        static std::mutex           mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n        if (!initialized) {\n            for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) {\n                ggml_backend_remoting_device_context * ctx       = new ggml_backend_remoting_device_context;\n                char                                   desc[256] = \"ggml-virtgpu API Remoting device\";\n\n                ctx->device      = i;\n                ctx->name        = GGML_VIRTGPU_NAME + std::to_string(i);\n                ctx->description = desc;\n                ctx->gpu         = gpu;\n\n                ggml_backend_dev_t dev = new ggml_backend_device{\n                    /* .iface   = */ ggml_backend_remoting_device_interface,\n                    /* .reg     = */ reg,\n                    /* .context = */ ctx,\n                };\n                devices.push_back(dev);\n            }\n            initialized = true;\n        }\n    }\n}\n\nstatic ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_t reg, size_t device) {\n    UNUSED(reg);\n\n    return ggml_backend_remoting_get_device(device);\n}\n\nstatic const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) {\n    UNUSED(reg);\n\n    return GGML_VIRTGPU_NAME;\n}\n\nstatic const ggml_backend_reg_i ggml_backend_remoting_reg_i = {\n    /* .get_name         = */ ggml_backend_remoting_reg_get_name,\n    /* .get_device_count = */ 
ggml_backend_remoting_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_remoting_reg_get_device,\n    /* .get_proc_address = */ NULL,\n};\n\nggml_backend_reg_t ggml_backend_virtgpu_reg() {\n    virtgpu * gpu = apir_initialize();\n    if (!gpu) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: virtgpu_apir_initialize failed\\n\", __func__);\n    }\n\n    static ggml_backend_reg reg = {\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_remoting_reg_i,\n        /* .context     = */ gpu,\n    };\n\n    static bool initialized = false;\n    if (initialized) {\n        return &reg;\n    }\n    initialized = true;\n\n    ggml_backend_remoting_reg_init_devices(&reg);\n\n    return &reg;\n}\n\n// public function, not exposed in the GGML interface at the moment\nvoid ggml_virtgpu_cleanup(virtgpu * gpu) {\n    if (gpu->cached_device_info.name) {\n        free(gpu->cached_device_info.name);\n        gpu->cached_device_info.name = NULL;\n    }\n    if (gpu->cached_device_info.description) {\n        free(gpu->cached_device_info.description);\n        gpu->cached_device_info.description = NULL;\n    }\n    if (gpu->cached_buffer_type.name) {\n        free(gpu->cached_buffer_type.name);\n        gpu->cached_buffer_type.name = NULL;\n    }\n\n    mtx_destroy(&gpu->data_shmem_mutex);\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg)\n"
  },
  {
    "path": "src/ggml-virtgpu/ggml-backend.cpp",
    "content": "#include \"../../include/ggml-virtgpu.h\"\n#include \"ggml-remoting.h\"\n\nstatic const char * ggml_backend_remoting_get_name(ggml_backend_t backend) {\n    UNUSED(backend);\n\n    return \"API Remoting backend\";\n}\n\nstatic void ggml_backend_remoting_free(ggml_backend_t backend) {\n    delete backend;\n}\n\nstatic ggml_status ggml_backend_remoting_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    virtgpu * gpu = DEV_TO_GPU(backend->device);\n\n    return apir_backend_graph_compute(gpu, cgraph);\n}\n\nstatic void ggml_backend_remoting_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    virtgpu * gpu = DEV_TO_GPU(backend->device);\n#if true\n    UNUSED(gpu);\n    UNUSED(cgraph);\n#else\n    // not working yet\n\n    apir_backend_graph_optimize(gpu, cgraph);\n#endif\n}\n\nstatic ggml_backend_i ggml_backend_remoting_interface = {\n    /* .get_name                = */ ggml_backend_remoting_get_name,\n    /* .free                    = */ ggml_backend_remoting_free,\n    /* .set_tensor_async        = */ NULL,  // ggml_backend_remoting_set_tensor_async,\n    /* .get_tensor_async        = */ NULL,  // ggml_backend_remoting_get_tensor_async,\n    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_remoting_cpy_tensor_async,\n    /* .synchronize             = */ NULL,  // ggml_backend_remoting_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_remoting_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ ggml_backend_remoting_graph_optimize,\n};\n\nstatic ggml_guid_t ggml_backend_remoting_guid() {\n    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x14, 0x03, 0x86, 0x02,\n                              0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 
0x2b };\n\n    return &guid;\n}\n\nggml_backend_t ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params) {\n    UNUSED(params);\n\n    ggml_backend_remoting_device_context * ctx = (ggml_backend_remoting_device_context *) dev->context;\n\n    ggml_backend_t remoting_backend = new ggml_backend{\n        /* .guid      = */ ggml_backend_remoting_guid(),\n        /* .interface = */ ggml_backend_remoting_interface,\n        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_virtgpu_reg(), ctx->device),\n        /* .context   = */ ctx,\n    };\n\n    return remoting_backend;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/ggml-remoting.h",
    "content": "#pragma once\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n#include \"virtgpu.h\"\n\n#include <memory>\n#include <string>\n\n#define GGML_VIRTGPU_NAME \"ggml-virtgpu\"\n#define GGML_VIRTGPU      \"ggml-virtgpu: \"\n\n// USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes\n\n#define USE_ALWAYS_TRUE_SUPPORTS_OP 1\n#define USE_METAL_GUEST_SUPPORTS_OP 0\n\n#define DEV_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->context)->gpu\n\n#define BUFFER_TO_GGML_CONTEXT(name) ((ggml_backend_remoting_buffer_context *) (name)->context)\n\n#define BUFFER_TO_APIR_CONTEXT(name) &((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context\n\n#define BUFFER_TO_HOST_HANDLE(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context.host_handle\n\n#define GET_DEVICE_CONTEXT() (ggml_backend_remoting_device_context *) ggml_backend_remoting_get_device(0)->context\n\n#define BUFT_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->device->context)->gpu\n\nstruct ggml_backend_remoting_device_context {\n    size_t      device;\n    std::string name;\n    std::string description;\n\n    std::vector<std::tuple<void *, size_t, virtgpu_shmem *>> shared_memory;\n\n    virtgpu * gpu;\n};\n\nstruct ggml_backend_remoting_buffer_context {\n    apir_buffer_context_t apir_context;\n\n    virtgpu * gpu;\n\n    void * base;\n\n    bool is_from_ptr;\n};\n\nextern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface;\nextern const ggml_backend_device_i      ggml_backend_remoting_device_interface;\nextern const ggml_backend_buffer_i      ggml_backend_remoting_buffer_interface;\nextern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface;\nextern const ggml_backend_buffer_i      ggml_backend_remoting_buffer_from_ptr_interface;\n\nggml_backend_dev_t         ggml_backend_remoting_get_device(size_t 
device);\nggml_backend_t             ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params);\nggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev);\n\nstatic inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) {\n    // in the backend, the buffer handle is the buffer pointer\n    return (apir_buffer_type_host_handle_t) buft->context;\n}\n\nstatic inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) {\n    if (!buffer->context) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: no context available :/\", __func__);\n    }\n    return BUFFER_TO_HOST_HANDLE(buffer);\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/ggmlremoting_functions.yaml",
    "content": "# YAML schema for GGML remoting API functions\n# This defines the structure for generating the remoting layer code\n\n# Configuration for the generated files\nconfig:\n  # Base path for the generated files\n  base_path: \"ggml/src\"\n\n  # Header files to update\n  files:\n    apir_backend_header: \"ggml-virtgpu-apir/backend/shared/apir_backend.gen.h\"\n    backend_dispatched_header: \"ggml-virtgpu-apir/backend/backend-dispatched.gen.h\"\n    virtgpu_forward_header: \"ggml-virtgpu-apir/virtgpu-forward.gen.h\"\n\n# Simplified function definitions with grouping and metadata combined\nfunctions:\n  device:\n    group_description: \"device\"\n    functions:\n      get_device_count:\n        # No specific metadata - uses default void return and base params\n\n      get_count:\n        frontend_return: \"int\"\n\n      get_name:\n        frontend_return: \"char *\"\n\n      get_description:\n        frontend_return: \"char *\"\n\n      get_type:\n        frontend_return: \"uint32_t\"\n\n      get_memory:\n        frontend_return: \"void\"\n        frontend_extra_params:\n        - \"size_t *free\"\n        - \"size_t *total\"\n\n      supports_op:\n        frontend_return: \"bool\"\n        frontend_extra_params:\n        - \"const ggml_tensor *op\"\n\n      get_buffer_type:\n        frontend_return: \"apir_buffer_type_host_handle_t\"\n\n      get_props:\n        frontend_return: \"void\"\n        frontend_extra_params:\n        - \"bool *async\"\n        - \"bool *host_buffer\"\n        - \"bool *buffer_from_host_ptr\"\n        - \"bool *events\"\n\n      buffer_from_ptr:\n        frontend_return: \"apir_buffer_context_t\"\n        frontend_extra_params:\n        - \"size_t size\"\n        - \"size_t max_tensor_size\"\n\n  buffer_type:\n    group_description: \"buffer-type\"\n    functions:\n      get_name:\n        frontend_return: \"char *\"\n        frontend_extra_params:\n        - \"apir_buffer_type_host_handle_t host_handle\"\n\n      
get_alignment:\n        frontend_return: \"size_t\"\n        frontend_extra_params:\n        - \"apir_buffer_type_host_handle_t host_handle\"\n\n      get_max_size:\n        frontend_return: \"size_t\"\n        frontend_extra_params:\n        - \"apir_buffer_type_host_handle_t host_handle\"\n\n      is_host:\n        deprecated: true\n\n      alloc_buffer:\n        frontend_return: \"apir_buffer_context_t\"\n        frontend_extra_params:\n        - \"apir_buffer_type_host_handle_t host_handle\"\n        - \"size_t size\"\n\n      get_alloc_size:\n        frontend_return: \"size_t\"\n        frontend_extra_params:\n        - \"apir_buffer_type_host_handle_t host_handle\"\n        - \"const ggml_tensor *op\"\n\n  buffer:\n    group_description: \"buffer\"\n    functions:\n      get_base:\n        frontend_return: \"void *\"\n        frontend_extra_params:\n        - \"apir_buffer_context_t *buffer_context\"\n\n      set_tensor:\n        frontend_return: \"void\"\n        frontend_extra_params:\n        - \"apir_buffer_context_t *buffer_context\"\n        - \"ggml_tensor *tensor\"\n        - \"const void *data\"\n        - \"size_t offset\"\n        - \"size_t size\"\n\n      get_tensor:\n        frontend_return: \"void\"\n        frontend_extra_params:\n        - \"apir_buffer_context_t *buffer_context\"\n        - \"const ggml_tensor *tensor\"\n        - \"void *data\"\n        - \"size_t offset\"\n        - \"size_t size\"\n\n      cpy_tensor:\n        frontend_return: \"bool\"\n        frontend_extra_params:\n        - \"apir_buffer_context_t *buffer_context\"\n        - \"const ggml_tensor *src\"\n        - \"const ggml_tensor *dst\"\n\n      clear:\n        frontend_return: \"void\"\n        frontend_extra_params:\n        - \"apir_buffer_context_t *buffer_context\"\n        - \"uint8_t value\"\n\n      free_buffer:\n        frontend_return: \"void\"\n        frontend_extra_params:\n        - \"apir_buffer_context_t *buffer_context\"\n\n  backend:\n    
group_description: \"backend\"\n    functions:\n      graph_compute:\n        frontend_return: \"ggml_status\"\n        frontend_extra_params:\n        - \"ggml_cgraph *cgraph\"\n\n      graph_optimize:\n        frontend_return: \"ggml_cgraph *\"\n        frontend_extra_params:\n        - \"ggml_cgraph *cgraph\"\n        enabled: false\n\n# Naming patterns used for code generation\nnaming_patterns:\n  # How to generate enum names\n  enum_prefix: \"APIR_COMMAND_TYPE_\"\n\n  # How to generate backend function names\n  backend_function_prefix: \"backend_\"\n\n  # How to generate frontend function names\n  frontend_function_prefix: \"apir_\"\n\n  # Standard frontend first parameter\n  frontend_base_param: \"struct virtgpu *gpu\"\n"
  },
  {
    "path": "src/ggml-virtgpu/include/apir_hw.h",
    "content": "#pragma once\n\n#include <stdint.h>\n\nstruct virgl_renderer_capset_apir {\n    uint32_t apir_version;\n    uint32_t supports_blob_resources;\n    uint32_t reserved[4];  // For future expansion\n};\n"
  },
  {
    "path": "src/ggml-virtgpu/regenerate_remoting.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\n# Generated by Claude AI\n\nScript to regenerate the GGML remoting layer's generated headers from YAML configuration.\n\nThis script reads ggmlremoting_functions.yaml and regenerates the three generated\nheader files of the GGML remoting layer (apir_backend.gen.h, backend-dispatched.gen.h\nand virtgpu-forward.gen.h).\n\nUsage:\n  python regenerate_remoting.py\n\nThe script will:\n1. Read the ggmlremoting_functions.yaml configuration\n2. Generate the updated header files\n3. Format the generated files with clang-format, when it is available\n4. Show a summary of what was generated\n\"\"\"\n\nimport yaml\nfrom typing import Dict, List, Any\nfrom pathlib import Path\nimport os\nimport subprocess\nimport shutil\nimport logging\n\nNL = '\\n' # can't have f\"{'\\n'}\" in f-strings\n\n\nclass RemotingCodebaseGenerator:\n    def __init__(self, yaml_path: str = \"ggmlremoting_functions.yaml\"):\n        \"\"\"Initialize the generator with the YAML configuration.\"\"\"\n        self.yaml_path = yaml_path\n\n        if not Path(yaml_path).exists():\n            raise FileNotFoundError(f\"Configuration file {yaml_path} not found\")\n\n        with open(yaml_path, 'r') as f:\n            self.config = yaml.safe_load(f)\n\n        self.functions = self.config['functions']\n        self.naming_patterns = self.config['naming_patterns']\n        self.config_data = self.config['config']\n\n        # Check if clang-format is available\n        self.clang_format_available = self._check_clang_format_available()\n\n    def _check_clang_format_available(self) -> bool:\n        \"\"\"Check if clang-format is available in the system PATH.\"\"\"\n        return shutil.which(\"clang-format\") is not None\n\n    def _format_file_with_clang_format(self, file_path: Path) -> bool:\n        \"\"\"Format a file with clang-format -i. 
Returns True if successful, False otherwise.\"\"\"\n        if not self.clang_format_available:\n            return False\n\n        try:\n            subprocess.run(\n                [\"clang-format\", \"-i\", str(file_path)],\n                check=True,\n                capture_output=True,\n                text=True\n            )\n            return True\n        except subprocess.CalledProcessError:\n            logging.exception(f\"   ⚠️  clang-format failed for {file_path}\")\n            return False\n        except Exception as e:\n            logging.exception(f\"   ⚠️  Unexpected error formatting {file_path}: {e}\")\n            return False\n\n    def generate_enum_name(self, group_name: str, function_name: str) -> str:\n        \"\"\"Generate the APIR_COMMAND_TYPE enum name for a function.\"\"\"\n        prefix = self.naming_patterns['enum_prefix']\n        return f\"{prefix}{group_name.upper()}_{function_name.upper()}\"\n\n    def generate_backend_function_name(self, group_name: str, function_name: str) -> str:\n        \"\"\"Generate the backend function name.\"\"\"\n        function_key = f\"{group_name}_{function_name}\"\n        overrides = self.naming_patterns.get('backend_function_overrides', {})\n\n        if function_key in overrides:\n            return overrides[function_key]\n\n        prefix = self.naming_patterns['backend_function_prefix']\n        return f\"{prefix}{group_name}_{function_name}\"\n\n    def generate_frontend_function_name(self, group_name: str, function_name: str) -> str:\n        \"\"\"Generate the frontend function name.\"\"\"\n        prefix = self.naming_patterns['frontend_function_prefix']\n        return f\"{prefix}{group_name}_{function_name}\"\n\n    def get_enabled_functions(self) -> List[Dict[str, Any]]:\n        \"\"\"Get all enabled functions with their metadata.\"\"\"\n        functions = []\n        enum_value = 0\n\n        for group_name, group_data in self.functions.items():\n            
group_description = group_data['group_description']\n\n            for function_name, func_metadata in group_data['functions'].items():\n                # Handle case where func_metadata is None or empty (functions with only comments)\n                if func_metadata is None:\n                    func_metadata = {}\n\n                # Functions are enabled by default unless explicitly disabled\n                if func_metadata.get('enabled', True):\n                    functions.append({\n                        'group_name': group_name,\n                        'function_name': function_name,\n                        'enum_name': self.generate_enum_name(group_name, function_name),\n                        'enum_value': enum_value,\n                        'backend_function': self.generate_backend_function_name(group_name, function_name),\n                        'frontend_function': self.generate_frontend_function_name(group_name, function_name),\n                        'frontend_return': func_metadata.get('frontend_return', 'void'),\n                        'frontend_extra_params': func_metadata.get('frontend_extra_params', []),\n                        'group_description': group_description,\n                        'deprecated': func_metadata.get('deprecated', False),\n                    })\n                    enum_value += 1\n\n        return functions\n\n    def generate_apir_backend_header(self) -> str:\n        \"\"\"Generate the complete apir_backend.h file.\"\"\"\n        functions = self.get_enabled_functions()\n\n        # Generate the enum section\n        enum_lines = [\"typedef enum ApirBackendCommandType {\"]\n        current_group = None\n\n        for func in functions:\n            # Add comment for new group\n            if func['group_name'] != current_group:\n                enum_lines.append(\"\")\n                enum_lines.append(f\"  /* {func['group_description']} */\")\n                current_group = func['group_name']\n\n           
 enum_lines.append(f\"  {func['enum_name']} = {func['enum_value']},\")\n\n        # Add the count\n        total_count = len(functions)\n        enum_lines.append(\"\\n  // last command_type index + 1\")\n        enum_lines.append(f\"  APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},\")\n        enum_lines.append(\"} ApirBackendCommandType;\")\n\n        # Generate function name mapping\n        func_lines = []\n        func_lines.append(\"static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {\")\n        func_lines.append(\"    switch (type) {\")\n\n        current_group = None\n        for func in functions:\n            # Add comment for new group\n            if func['group_name'] != current_group:\n                func_lines.append(f\"        /* {func['group_description']} */\")\n                current_group = func['group_name']\n\n            # Generate clean function name without backend_ prefix\n            clean_name = f\"{func['group_name']}_{func['function_name']}\"\n            func_lines.append(f\"        case {func['enum_name']}:\")\n            func_lines.append(f\"            return \\\"{clean_name}\\\";\")\n\n        func_lines.append(\"\")\n        func_lines.append(\"        default:\")\n        func_lines.append(\"            return \\\"unknown\\\";\")\n        func_lines.append(\"    }\")\n        func_lines.append(\"}\")\n\n        # Full header template\n        header_content = NL.join(enum_lines) + \"\\n\\n\" + NL.join(func_lines) + \"\\n\"\n\n        return header_content\n\n    def generate_backend_dispatched_header(self) -> str:\n        \"\"\"Generate the complete backend-dispatched.h file.\"\"\"\n        functions = self.get_enabled_functions()\n\n        # Function declarations\n        decl_lines = []\n        current_group = None\n\n        for func in functions:\n            if func['group_name'] != current_group:\n                decl_lines.append(f\"\\n/* {func['group_description']} */\")\n        
        current_group = func['group_name']\n\n            signature = \"uint32_t\"\n            params = \"apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx\"\n            if func['deprecated']:\n                decl_lines.append(f\"/* {func['enum_name']} is deprecated. Keeping the handler for backward compatibility. */\")\n\n            decl_lines.append(f\"{signature} {func['backend_function']}({params});\")\n\n        # Dispatch table\n        table_lines = []\n        current_group = None\n\n        for func in functions:\n            if func['group_name'] != current_group:\n                table_lines.append(f\"\\n  /* {func['group_description']} */\")\n                table_lines.append(\"\")\n                current_group = func['group_name']\n\n            deprecated = \" /* DEPRECATED */\" if func['deprecated'] else \"\"\n            table_lines.append(f\"  /* {func['enum_name']}  = */ {func['backend_function']}{deprecated},\")\n\n        header_content = f'''\\\n#pragma once\n\n{NL.join(decl_lines)}\n\nextern \"C\" {{\nstatic const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{\n  {NL.join(table_lines)}\n}};\n}}\n'''\n        return header_content\n\n    def generate_virtgpu_forward_header(self) -> str:\n        \"\"\"Generate the complete virtgpu-forward.gen.h file.\"\"\"\n        functions = self.get_enabled_functions()\n\n        decl_lines = []\n        current_group = None\n\n        for func in functions:\n            if func['group_name'] != current_group:\n                decl_lines.append(\"\")\n                decl_lines.append(f\"/* {func['group_description']} */\")\n                current_group = func['group_name']\n\n            if func['deprecated']:\n                decl_lines.append(f\"/* {func['frontend_function']} is deprecated. 
*/\")\n                continue\n\n            # Build parameter list\n            params = [self.naming_patterns['frontend_base_param']]\n            params.extend(func['frontend_extra_params'])\n            param_str = ', '.join(params)\n\n            decl_lines.append(f\"{func['frontend_return']} {func['frontend_function']}({param_str});\")\n\n        header_content = f'''\\\n#pragma once\n{NL.join(decl_lines)}\n'''\n        return header_content\n\n    def regenerate_codebase(self) -> None:\n        \"\"\"Regenerate the entire remoting codebase.\"\"\"\n        logging.info(\"🔄 Regenerating GGML Remoting Codebase...\")\n        logging.info(\"=\" * 50)\n\n        # Detect if we're running from frontend directory\n        current_dir = os.getcwd()\n        is_frontend_dir = current_dir.endswith('ggml-virtgpu')\n\n        if is_frontend_dir:\n            # Running from ggml/src/ggml-virtgpu-apir\n            logging.info(\"📍 Detected frontend directory execution\")\n            frontend_base = Path(\".\")\n        else:\n            # Running from project root (fallback to original behavior)\n            logging.info(\"📍 Detected project root execution\")\n            base_path = self.config_data.get('base_path', 'ggml/src')\n            frontend_base = Path(base_path) / \"ggml-virtgpu\"\n\n        # Compute final file paths\n        backend_base = frontend_base / \"backend\"\n        apir_backend_path = backend_base / \"shared\" / \"apir_backend.gen.h\"\n        backend_dispatched_path = backend_base / \"backend-dispatched.gen.h\"\n        virtgpu_forward_path = frontend_base / \"virtgpu-forward.gen.h\"\n\n        # Create output directories for each file\n        apir_backend_path.parent.mkdir(parents=True, exist_ok=True)\n        backend_dispatched_path.parent.mkdir(parents=True, exist_ok=True)\n        virtgpu_forward_path.parent.mkdir(parents=True, exist_ok=True)\n\n        # Generate header files\n        logging.info(\"📁 Generating header files...\")\n\n    
    apir_backend_content = self.generate_apir_backend_header()\n        apir_backend_path.write_text(apir_backend_content)\n        logging.info(f\"   ✅ {apir_backend_path.resolve()}\")\n\n        backend_dispatched_content = self.generate_backend_dispatched_header()\n        backend_dispatched_path.write_text(backend_dispatched_content)\n        logging.info(f\"   ✅ {backend_dispatched_path.resolve()}\")\n\n        virtgpu_forward_content = self.generate_virtgpu_forward_header()\n        virtgpu_forward_path.write_text(virtgpu_forward_content)\n        logging.info(f\"   ✅ {virtgpu_forward_path.resolve()}\")\n\n        # Format generated files with clang-format\n        generated_files = [apir_backend_path, backend_dispatched_path, virtgpu_forward_path]\n\n        if not self.clang_format_available:\n            logging.warning(\"\\n⚠️clang-format not found in PATH. Generated files will not be formatted.\\n\"\n                            \"   Install clang-format to enable automatic code formatting.\")\n        else:\n            logging.info(\"\\n🎨 Formatting files with clang-format...\")\n            for file_path in generated_files:\n                if self._format_file_with_clang_format(file_path):\n                    logging.info(f\"   ✅ Formatted {file_path.name}\")\n                else:\n                    logging.warning(f\"   ❌ Failed to format {file_path.name}\")\n\n        # Generate summary\n        functions = self.get_enabled_functions()\n        total_functions = len(functions)\n\n        logging.info(\"\\n📊 Generation Summary:\")\n        logging.info(\"=\" * 50)\n        logging.info(f\"   Total functions: {total_functions}\")\n        logging.info(f\"   Function groups: {len(self.functions)}\")\n        logging.info(\"   Header files: 3\")\n        logging.info(f\"   Working directory: {current_dir}\")\n\n\ndef main():\n    try:\n        generator = RemotingCodebaseGenerator()\n        generator.regenerate_codebase()\n    except Exception as 
e:\n        logging.exception(f\"❌ Error: {e}\")\n        exit(1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-apir.h",
    "content": "#pragma once\n\n#include \"backend/shared/apir_backend.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-impl.h\"\n#include \"ggml.h\"\n#include \"virtgpu-shm.h\"\n#include \"virtgpu-utils.h\"\n\n// Guest-side bookkeeping for a buffer managed by the host backend.\nstruct apir_buffer_context_t {\n    // Handle the host uses to identify this buffer.\n    apir_buffer_host_handle_t host_handle;\n\n    // Guest-host shared memory region (filled in by apir_device_buffer_from_ptr).\n    struct virtgpu_shmem           shmem;\n    // Host handle of the buffer type this buffer belongs to.\n    apir_buffer_type_host_handle_t buft_host_handle;\n};\n\n// The generated prototypes use apir_buffer_context_t, so this include must\n// stay after the struct definition.\n#include \"virtgpu-forward.gen.h\"\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-forward-backend.cpp",
    "content": "#include \"virtgpu-forward-impl.h\"\n\n// Monotonic timestamp in milliseconds, suitable for measuring elapsed time.\n// Not called in this file.\n[[maybe_unused]] static long long current_time_ms() {\n    timespec ts;\n    // CLOCK_MONOTONIC is not affected by wall-clock adjustments, unlike CLOCK_REALTIME.\n    clock_gettime(CLOCK_MONOTONIC, &ts);\n    return (long long) ts.tv_sec * 1000LL + ts.tv_nsec / 1000000LL;\n}\n\n// Serializes `cgraph`, sends it to the host through guest-host shared memory and\n// returns the ggml_status decoded from the reply. Aborts if no shared memory can\n// be obtained.\nggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE);\n\n    std::vector<uint8_t> cgraph_data;\n    size_t               cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data);\n\n    virtgpu_shmem   temp_shmem;  // Local storage for large buffers\n    virtgpu_shmem * shmem              = &temp_shmem;\n    bool            using_shared_shmem = false;\n\n    if (cgraph_size <= gpu->data_shmem.mmap_size) {\n        // Small graphs reuse gpu->data_shmem, which is shared between calls and\n        // therefore held under data_shmem_mutex until the call completes.\n        if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {\n            GGML_ABORT(GGML_VIRTGPU \"%s: Failed to lock data_shmem mutex\", __func__);\n        }\n        using_shared_shmem = true;\n        shmem              = &gpu->data_shmem;\n    } else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: Couldn't allocate the guest-host shared buffer\", __func__);\n    }\n\n    apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id);\n\n    apir_encode_size_t(encoder, &cgraph_size);\n\n    char *       shmem_data    = (char *) shmem->mmap_ptr;\n    apir_encoder secondary_enc = apir_new_encoder(shmem_data, cgraph_size);\n\n    apir_encode_cgraph_data(&secondary_enc, cgraph_data);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    // Fallback, returned if apir_decode_ggml_status leaves it untouched.\n    ggml_status status = GGML_STATUS_ABORTED;\n    apir_decode_ggml_status(decoder, &status);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    // Release the shared buffer, or destroy the dedicated one.\n    if (using_shared_shmem) {\n        mtx_unlock(&gpu->data_shmem_mutex);\n    } else {\n        virtgpu_shmem_destroy(gpu, shmem);\n    }\n\n    return status;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp",
    "content": "#include \"virtgpu-forward-impl.h\"\n\n// Returns the name of buffer type `host_handle` in a buffer from\n// apir_decoder_alloc_array, or NULL if that allocation fails.\nchar * apir_buffer_type_get_name(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME);\n\n    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    const size_t string_size = apir_decode_array_size_unchecked(decoder);\n    char *       string      = (char *) apir_decoder_alloc_array(sizeof(char), string_size);\n    if (!string) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: Could not allocate the buffer type name buffer\\n\", __func__);\n        apir_decoder_set_fatal(decoder);\n    } else {\n        // Never decode into a NULL destination.\n        apir_decode_char_array(decoder, string, string_size);\n    }\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return string;\n}\n\n// Returns the alignment reported by the host for buffer type `host_handle`.\nsize_t apir_buffer_type_get_alignment(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT);\n\n    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    size_t alignment = 0;\n    apir_decode_size_t(decoder, &alignment);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return alignment;\n}\n\n// Returns the maximum buffer size reported by the host for buffer type `host_handle`.\nsize_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE);\n\n    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    size_t max_size = 0;\n    apir_decode_size_t(decoder, &max_size);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return max_size;\n}\n\n// Asks the host to allocate a `size`-byte buffer of type `host_handle`. Only\n// host_handle is filled in from the reply; every other field is zeroed.\napir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu *                      gpu,\n                                                    apir_buffer_type_host_handle_t host_handle,\n                                                    size_t                         size) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    apir_buffer_context_t buffer_context = {};\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER);\n\n    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);\n\n    apir_encode_size_t(encoder, &size);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return buffer_context;\n}\n\n// Returns the allocation size the host reports for tensor `op` in buffer type `host_handle`.\nsize_t apir_buffer_type_get_alloc_size(virtgpu *                      gpu,\n                                       apir_buffer_type_host_handle_t host_handle,\n                                       const ggml_tensor *            op) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE);\n\n    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);\n\n    apir_encode_ggml_tensor_inline(encoder, op);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    size_t alloc_size = 0;\n    apir_decode_size_t(decoder, &alloc_size);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return alloc_size;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-forward-buffer.cpp",
    "content": "#include \"virtgpu-forward-impl.h\"\n\n// Returns the base address the host reports for this buffer.\nvoid * apir_buffer_get_base(virtgpu * gpu, apir_buffer_context_t * buffer_context) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_BASE);\n\n    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    uintptr_t base = 0;\n    apir_decode_uintptr_t(decoder, &base);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return (void *) base;\n}\n\n// Sends `size` bytes from `data` to the host for `tensor`, together with\n// `offset` and `size`. The bytes are staged in gpu->data_shmem when they fit\n// (held under data_shmem_mutex), otherwise in a temporary shared region.\nvoid apir_buffer_set_tensor(virtgpu *               gpu,\n                            apir_buffer_context_t * buffer_context,\n                            ggml_tensor *           tensor,\n                            const void *            data,\n                            size_t                  offset,\n                            size_t                  size) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_SET_TENSOR);\n\n    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);\n    apir_encode_ggml_tensor(encoder, tensor);\n\n    virtgpu_shmem   temp_shmem;  // Local storage for large buffers\n    virtgpu_shmem * shmem              = &temp_shmem;\n    bool            using_shared_shmem = false;\n\n    if (size <= gpu->data_shmem.mmap_size) {\n        // Lock mutex before using shared data_shmem buffer\n        if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {\n            GGML_ABORT(GGML_VIRTGPU \"%s: Failed to lock data_shmem mutex\", __func__);\n        }\n        using_shared_shmem = true;\n        shmem              = &gpu->data_shmem;\n\n    } else if (virtgpu_shmem_create(gpu, size, shmem)) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: Couldn't allocate the guest-host shared buffer\", __func__);\n    }\n\n    memcpy(shmem->mmap_ptr, data, size);\n    apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id);\n\n    apir_encode_size_t(encoder, &offset);\n    apir_encode_size_t(encoder, &size);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    // Release the shared buffer, or destroy the temporary one.\n    if (using_shared_shmem) {\n        mtx_unlock(&gpu->data_shmem_mutex);\n    } else {\n        virtgpu_shmem_destroy(gpu, shmem);\n    }\n}\n\n// Asks the host for `size` bytes of `tensor` (with `offset` and `size` forwarded)\n// and copies them from guest-host shared memory into `data`. Uses the same\n// staging strategy as apir_buffer_set_tensor.\nvoid apir_buffer_get_tensor(virtgpu *               gpu,\n                            apir_buffer_context_t * buffer_context,\n                            const ggml_tensor *     tensor,\n                            void *                  data,\n                            size_t                  offset,\n                            size_t                  size) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_TENSOR);\n\n    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);\n    apir_encode_ggml_tensor(encoder, tensor);\n\n    virtgpu_shmem   temp_shmem;  // Local storage for large buffers\n    virtgpu_shmem * shmem              = &temp_shmem;\n    bool            using_shared_shmem = false;\n\n    if (size <= gpu->data_shmem.mmap_size) {\n        // Lock mutex before using shared data_shmem buffer\n        if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {\n            GGML_ABORT(GGML_VIRTGPU \"%s: Failed to lock data_shmem mutex\", __func__);\n        }\n        using_shared_shmem = true;\n        shmem              = &gpu->data_shmem;\n\n    } else if (virtgpu_shmem_create(gpu, size, shmem)) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: Couldn't allocate the guest-host shared buffer\", __func__);\n    }\n\n    apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id);\n    apir_encode_size_t(encoder, &offset);\n    apir_encode_size_t(encoder, &size);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    // Copy the result out while the shared region is still held.\n    memcpy(data, shmem->mmap_ptr, size);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    // Release the shared buffer, or destroy the temporary one.\n    if (using_shared_shmem) {\n        mtx_unlock(&gpu->data_shmem_mutex);\n    } else {\n        virtgpu_shmem_destroy(gpu, shmem);\n    }\n}\n\n// Asks the host to copy `src` to `dst` and returns its answer (false if\n// apir_decode_bool_t leaves the result untouched).\nbool apir_buffer_cpy_tensor(virtgpu *               gpu,\n                            apir_buffer_context_t * buffer_context,\n                            const ggml_tensor *     src,\n                            const ggml_tensor *     dst) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR);\n\n    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);\n    apir_encode_ggml_tensor(encoder, src);\n    apir_encode_ggml_tensor(encoder, dst);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    bool ret_val = false;\n    apir_decode_bool_t(decoder, &ret_val);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return ret_val;\n}\n\n// Asks the host to clear the buffer with `value`.\nvoid apir_buffer_clear(virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CLEAR);\n\n    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);\n    apir_encode_uint8_t(encoder, &value);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    remote_call_finish(gpu, encoder, decoder);\n}\n\n// Asks the host to free the buffer identified by buffer_context->host_handle.\nvoid apir_buffer_free_buffer(virtgpu * gpu, apir_buffer_context_t * buffer_context) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER);\n\n    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    remote_call_finish(gpu, encoder, decoder);\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-forward-device.cpp",
    "content": "#include \"virtgpu-forward-impl.h\"\n#include \"virtgpu-shm.h\"\n\n// Returns the device count reported by the host (-1 if apir_decode_int32_t\n// leaves it untouched).\nint apir_device_get_count(virtgpu * gpu) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_COUNT);\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    int32_t dev_count = -1;\n    apir_decode_int32_t(decoder, &dev_count);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return dev_count;\n}\n\n// Returns the device name in a buffer from apir_decoder_alloc_array, or NULL if\n// that allocation fails.\nchar * apir_device_get_name(virtgpu * gpu) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_NAME);\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    const size_t string_size = apir_decode_array_size_unchecked(decoder);\n    char *       string      = (char *) apir_decoder_alloc_array(sizeof(char), string_size);\n    if (!string) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: Could not allocate the device name buffer\\n\", __func__);\n        // The remote call must be finished on the error path as well.\n        remote_call_finish(gpu, encoder, decoder);\n        return NULL;\n    }\n    apir_decode_char_array(decoder, string, string_size);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return string;\n}\n\n// Returns the device description in a buffer from apir_decoder_alloc_array, or\n// NULL if that allocation fails.\nchar * apir_device_get_description(virtgpu * gpu) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    const size_t string_size = apir_decode_array_size_unchecked(decoder);\n    char *       string      = (char *) apir_decoder_alloc_array(sizeof(char), string_size);\n    if (!string) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: Could not allocate the device description buffer\\n\", __func__);\n        // The remote call must be finished on the error path as well.\n        remote_call_finish(gpu, encoder, decoder);\n        return NULL;\n    }\n    apir_decode_char_array(decoder, string, string_size);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return string;\n}\n\n// Returns the device type reported by the host. Once a value other than 255 has\n// been decoded it is kept in a function-local static and returned without a\n// remote call.\n// NOTE(review): the cache is shared by every `gpu` and is not synchronized.\nuint32_t apir_device_get_type(virtgpu * gpu) {\n    static uint32_t dev_type = 255;\n    if (dev_type != 255) {\n        return dev_type;\n    }\n\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_TYPE);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    apir_decode_uint32_t(decoder, &dev_type);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return dev_type;\n}\n\n// Stores the free and total memory reported by the host in *free and *total.\nvoid apir_device_get_memory(virtgpu * gpu, size_t * free, size_t * total) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_MEMORY);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    // Locals rather than function statics, so concurrent callers do not share them.\n    size_t dev_free  = 0;\n    size_t dev_total = 0;\n    apir_decode_size_t(decoder, &dev_free);\n    apir_decode_size_t(decoder, &dev_total);\n\n    *free  = dev_free;\n    *total = dev_total;\n\n    remote_call_finish(gpu, encoder, decoder);\n}\n\n// Returns whether the host reports `op` as supported (false if\n// apir_decode_bool_t leaves the result untouched).\nbool apir_device_supports_op(virtgpu * gpu, const ggml_tensor * op) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP);\n\n    apir_encode_ggml_tensor_inline(encoder, op);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    bool supports_op = false;\n    apir_decode_bool_t(decoder, &supports_op);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return supports_op;\n}\n\n// Returns the host handle of the device's buffer type.\napir_buffer_type_host_handle_t apir_device_get_buffer_type(virtgpu * gpu) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    apir_buffer_type_host_handle_t buft_handle = {};\n    apir_decode_apir_buffer_type_host_handle_t(decoder, &buft_handle);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return buft_handle;\n}\n\n// Stores the device properties reported by the host in the four out-parameters.\nvoid apir_device_get_props(virtgpu * gpu,\n                           bool *    async,\n                           bool *    host_buffer,\n                           bool *    buffer_from_host_ptr,\n                           bool *    events) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_PROPS);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    apir_decode_bool_t(decoder, async);\n    apir_decode_bool_t(decoder, host_buffer);\n    apir_decode_bool_t(decoder, buffer_from_host_ptr);\n    apir_decode_bool_t(decoder, events);\n\n    remote_call_finish(gpu, encoder, decoder);\n}\n\n// Creates a guest-host shared region of `size` bytes and asks the host to build a\n// buffer on top of it. Aborts if the shared region cannot be created.\napir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, size_t max_tensor_size) {\n    apir_encoder *        encoder;\n    apir_decoder *        decoder;\n    ApirForwardReturnCode ret;\n\n    apir_buffer_context_t buffer_context = {};\n\n    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR);\n\n    if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: Couldn't allocate %zub of guest-host shared buffer\", __func__, size);\n    }\n\n    apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id);\n\n    apir_encode_size_t(encoder, &size);\n    apir_encode_size_t(encoder, &max_tensor_size);\n\n    REMOTE_CALL(gpu, encoder, decoder, ret);\n\n    apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle);\n    buffer_context.buft_host_handle = apir_decode_apir_buffer_type_host_handle(decoder);\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    return buffer_context;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-forward-impl.h",
    "content": "#pragma once\n\n// clang-format off\n#include \"virtgpu.h\"\n#include \"ggml-remoting.h\"\n#include \"backend/shared/apir_backend.h\"\n#include \"backend/shared/apir_cs_ggml.h\"\n#include \"ggml-backend-impl.h\"\n// clang-format on\n\n// Starts a forwarded remote call for `apir_command_type__` and stores the encoder\n// in `encoder_name`, aborting on failure.\n// Not hygienic: it declares REMOTE_CALL_PREPARE_forward_flag and\n// REMOTE_CALL_PREPARE_command_name in the enclosing scope (REMOTE_CALL reads the\n// latter), so use it at most once per scope and before REMOTE_CALL.\n#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \\\n    int32_t      REMOTE_CALL_PREPARE_forward_flag = (int32_t) (apir_command_type__); \\\n    const char * REMOTE_CALL_PREPARE_command_name = apir_dispatch_command_name(apir_command_type__); \\\n    do { \\\n        (encoder_name) = remote_call_prepare((gpu_dev_name), APIR_COMMAND_TYPE_FORWARD, REMOTE_CALL_PREPARE_forward_flag); \\\n        if (!(encoder_name)) { \\\n            GGML_ABORT(GGML_VIRTGPU \"%s: failed to prepare the remote call encoder\", __func__); \\\n        } \\\n    } while (0)\n\n// Sends the call started by REMOTE_CALL_PREPARE and stores the reply decoder in\n// `decoder_name`. Aborts if no decoder comes back, if the forwarding layer reports\n// an error (a code below APIR_FORWARD_BASE_INDEX), or if the backend function\n// returns non-zero; when it completes, `ret_name` holds 0.\n#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \\\n    do { \\\n        (ret_name) = (ApirForwardReturnCode) remote_call((gpu_dev_name), (encoder_name), &(decoder_name), 0, NULL); \\\n        if (!(decoder_name)) { \\\n            GGML_ABORT(GGML_VIRTGPU \"%s: failed to kick the remote call\", __func__); \\\n        } \\\n        if ((ret_name) < APIR_FORWARD_BASE_INDEX) { \\\n            GGML_ABORT(GGML_VIRTGPU \"%s: failed to forward the API call: %s: code %d\", __func__, \\\n                       apir_forward_error(ret_name), (int) (ret_name)); \\\n        } \\\n        (ret_name) = (ApirForwardReturnCode) ((ret_name) - APIR_FORWARD_BASE_INDEX); \\\n        if ((ret_name) != 0) { \\\n            GGML_ABORT(GGML_VIRTGPU \"backend function '%s' failed (return code: %d)\", \\\n                       REMOTE_CALL_PREPARE_command_name, (int) (ret_name)); \\\n        } \\\n    } while (0)\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-forward.gen.h",
    "content": "#pragma once\n\n/* device */\nint                            apir_device_get_count(struct virtgpu * gpu);\nchar *                         apir_device_get_name(struct virtgpu * gpu);\nchar *                         apir_device_get_description(struct virtgpu * gpu);\nuint32_t                       apir_device_get_type(struct virtgpu * gpu);\nvoid                           apir_device_get_memory(struct virtgpu * gpu, size_t * free, size_t * total);\nbool                           apir_device_supports_op(struct virtgpu * gpu, const ggml_tensor * op);\napir_buffer_type_host_handle_t apir_device_get_buffer_type(struct virtgpu * gpu);\nvoid                           apir_device_get_props(struct virtgpu * gpu,\n                                                     bool *           async,\n                                                     bool *           host_buffer,\n                                                     bool *           buffer_from_host_ptr,\n                                                     bool *           events);\napir_buffer_context_t          apir_device_buffer_from_ptr(struct virtgpu * gpu, size_t size, size_t max_tensor_size);\n\n/* buffer-type */\nchar *                apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);\nsize_t                apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);\nsize_t                apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);\n/* apir_buffer_type_is_host is deprecated. */\napir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu *               gpu,\n                                                    apir_buffer_type_host_handle_t host_handle,\n                                                    size_t                         size);\nsize_t                apir_buffer_type_get_alloc_size(struct virtgpu *               gpu,\n                                                      apir_buffer_type_host_handle_t host_handle,\n                                                      const ggml_tensor *            op);\n\n/* buffer */\nvoid * apir_buffer_get_base(struct virtgpu * gpu, apir_buffer_context_t * buffer_context);\nvoid   apir_buffer_set_tensor(struct virtgpu *        gpu,\n                              apir_buffer_context_t * buffer_context,\n                              ggml_tensor *           tensor,\n                              const void *            data,\n                              size_t                  offset,\n                              size_t                  size);\nvoid   apir_buffer_get_tensor(struct virtgpu *        gpu,\n                              apir_buffer_context_t * buffer_context,\n                              const ggml_tensor *     tensor,\n                              void *                  data,\n                              size_t                  offset,\n                              size_t                  size);\nbool   apir_buffer_cpy_tensor(struct virtgpu *        gpu,\n                              apir_buffer_context_t * buffer_context,\n                              const ggml_tensor *     src,\n                              const ggml_tensor *     dst);\nvoid   apir_buffer_clear(struct virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value);\nvoid   apir_buffer_free_buffer(struct virtgpu * gpu, apir_buffer_context_t * buffer_context);\n\n/* backend */\nggml_status apir_backend_graph_compute(struct virtgpu * gpu, ggml_cgraph * cgraph);\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-shm.cpp",
    "content": "#include \"virtgpu-shm.h\"\n\n#include \"virtgpu.h\"\n\n#include <assert.h>\n\n// Creates a virtio-gpu blob resource of `blob_size` bytes. Returns its GEM handle\n// and stores the resource id in *res_id, or returns 0 (leaving *res_id untouched)\n// if the ioctl fails.\nstatic uint32_t virtgpu_ioctl_resource_create_blob(virtgpu *  gpu,\n                                                   uint32_t   blob_mem,\n                                                   uint32_t   blob_flags,\n                                                   size_t     blob_size,\n                                                   uint64_t   blob_id,\n                                                   uint32_t * res_id) {\n#ifdef SIMULATE_BO_SIZE_FIX\n    blob_size = align64(blob_size, 4096);\n#endif\n\n    drm_virtgpu_resource_create_blob args = {\n        .blob_mem   = blob_mem,\n        .blob_flags = blob_flags,\n        .bo_handle  = 0,\n        .res_handle = 0,\n        .size       = blob_size,\n        .pad        = 0,\n        .cmd_size   = 0,\n        .cmd        = 0,\n        .blob_id    = blob_id,\n    };\n\n    if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_RESOURCE_CREATE_BLOB, &args)) {\n        return 0;\n    }\n\n    *res_id = args.res_handle;\n    return args.bo_handle;\n}\n\n// Closes a GEM handle. A failure is only detected by assert() in debug builds.\nstatic void virtgpu_ioctl_gem_close(virtgpu * gpu, uint32_t gem_handle) {\n    drm_gem_close args = {\n        .handle = gem_handle,\n        .pad    = 0,\n    };\n\n    const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_GEM_CLOSE, &args);\n    assert(!ret);\n#ifdef NDEBUG\n    UNUSED(ret);\n#endif\n}\n\n// Maps `size` bytes of the GEM object into the guest address space.\n// Returns NULL if either the map ioctl or mmap() fails.\nstatic void * virtgpu_ioctl_map(virtgpu * gpu, uint32_t gem_handle, size_t size) {\n    drm_virtgpu_map args = {\n        .offset = 0,\n        .handle = gem_handle,\n        .pad    = 0,\n    };\n\n    if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_MAP, &args)) {\n        return NULL;\n    }\n\n    void * ptr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, gpu->fd, args.offset);\n    if (ptr == MAP_FAILED) {\n        return NULL;\n    }\n\n    return ptr;\n}\n\n// Unmaps the region, closes its GEM handle and clears *shmem so the stale\n// pointer and handles cannot be reused by mistake.\nvoid virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem) {\n    munmap(shmem->mmap_ptr, shmem->mmap_size);\n    virtgpu_ioctl_gem_close(gpu, shmem->gem_handle);\n\n    shmem->res_id     = 0;\n    shmem->mmap_size  = 0;\n    shmem->mmap_ptr   = NULL;\n    shmem->gem_handle = 0;\n}\n\n// Creates a mappable host blob of at least `size` bytes (padded with\n// align64(size, 16384)) and maps it into *shmem. Returns 0 on success, 1 on failure.\nint virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem) {\n    size = align64(size, 16384);\n\n    uint32_t res_id;\n    uint32_t gem_handle = virtgpu_ioctl_resource_create_blob(gpu, VIRTGPU_BLOB_MEM_HOST3D,\n                                                             VIRTGPU_BLOB_FLAG_USE_MAPPABLE, size, 0, &res_id);\n\n    if (!gem_handle) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: virtgpu_ioctl_resource_create_blob failed\\n\", __func__);\n        return 1;\n    }\n\n    void * ptr = virtgpu_ioctl_map(gpu, gem_handle, size);\n    if (!ptr) {\n        virtgpu_ioctl_gem_close(gpu, gem_handle);\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: virtgpu_ioctl_map failed\\n\", __func__);\n        return 1;\n    }\n\n    shmem->res_id     = res_id;\n    shmem->mmap_size  = size;\n    shmem->mmap_ptr   = ptr;\n    shmem->gem_handle = gem_handle;\n\n    return 0;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-shm.h",
    "content": "#pragma once\n\n#include \"virtgpu-utils.h\"\n\n#include <sys/mman.h>\n\n#include <atomic>\n#include <cassert>\n#include <cstddef>\n#include <cstdint>\n\nstruct virtgpu;\n\n// A virtio-gpu blob resource mapped into the guest, used to exchange data with the host.\nstruct virtgpu_shmem {\n    uint32_t res_id;     // resource id, sent to the host to identify the region\n    size_t   mmap_size;  // size of the guest mapping, in bytes\n    void *   mmap_ptr;   // guest address of the mapping\n\n    uint32_t gem_handle; // DRM GEM handle of the underlying buffer object\n};\n\n// Creates a shared region of at least `size` bytes and maps it into *shmem.\n// Returns 0 on success and non-zero on failure.\nint  virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem);\n// Unmaps *shmem and closes its GEM handle.\nvoid virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem);\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-utils.cpp",
    "content": "#include \"virtgpu-utils.h\"\n\n#include <malloc.h>\n#include <stdlib.h>\n\n#include <cstring>\n\n// Nodes are allocated NODE_ALLOC_ALIGN-aligned, so the low bits of a node pointer\n// are free to carry the node's level (NODE_LEVEL_MASK); NODE_PTR_MASK recovers the pointer.\n#define NODE_ALLOC_ALIGN 64\n#define NODE_PTR_MASK    (~((uintptr_t) NODE_ALLOC_ALIGN - 1))\n#define NODE_LEVEL_MASK  ((uintptr_t) NODE_ALLOC_ALIGN - 1)\n#define NULL_NODE        0\n\n#define os_free_aligned(_ptr)          free(_ptr)\n#define p_atomic_cmpxchg(v, old, _new) __sync_val_compare_and_swap((v), (old), (_new))\n\n// Returns floor(log2(n)), and 0 for n == 0.\nstatic inline uint64_t util_logbase2_64(uint64_t n) {\n#if defined(HAVE___BUILTIN_CLZLL)\n    return ((sizeof(uint64_t) * 8 - 1) - __builtin_clzll(n | 1));\n#else\n    uint64_t pos = 0ull;\n    if (n >= 1ull << 32) {\n        n >>= 32;\n        pos += 32;\n    }\n    if (n >= 1ull << 16) {\n        n >>= 16;\n        pos += 16;\n    }\n    if (n >= 1ull << 8) {\n        n >>= 8;\n        pos += 8;\n    }\n    if (n >= 1ull << 4) {\n        n >>= 4;\n        pos += 4;\n    }\n    if (n >= 1ull << 2) {\n        n >>= 2;\n        pos += 2;\n    }\n    if (n >= 1ull << 1) {\n        pos += 1;\n    }\n    return pos;\n#endif\n}\n\n// Initializes an empty sparse array of `elem_size`-byte elements whose nodes hold\n// `node_size` entries; `node_size` must be a power of two >= 2.\nvoid util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size) {\n    memset(arr, 0, sizeof(*arr));\n    arr->elem_size      = elem_size;\n    arr->node_size_log2 = util_logbase2_64(node_size);\n    assert(node_size >= 2 && node_size == (1ull << arr->node_size_log2));\n}\n\n// Allocates `size` bytes aligned to `alignment` (rounded up to a multiple of\n// sizeof(void *), as posix_memalign requires). Returns NULL on failure.\nstatic inline void * os_malloc_aligned(size_t size, size_t alignment) {\n    void * ptr;\n    alignment = (alignment + sizeof(void *) - 1) & ~(sizeof(void *) - 1);\n    if (posix_memalign(&ptr, alignment, size) != 0) {\n        return NULL;\n    }\n    return ptr;\n}\n\n// Strips the level bits from a node handle, leaving the node's data pointer.\nstatic inline void * _util_sparse_array_node_data(uintptr_t handle) {\n    return (void *) (handle & NODE_PTR_MASK);\n}\n\n// Returns the level stored in the low bits of a node handle.\nstatic inline unsigned _util_sparse_array_node_level(uintptr_t handle) {\n    return handle & NODE_LEVEL_MASK;\n}\n\n// Frees `node` and, for interior nodes, every allocated descendant.\nstatic inline void _util_sparse_array_node_finish(util_sparse_array * arr, uintptr_t node) {\n    if (_util_sparse_array_node_level(node) > 0) {\n        uintptr_t * children  = (uintptr_t *) _util_sparse_array_node_data(node);\n        size_t      node_size = 1ull << arr->node_size_log2;\n        for (size_t i = 0; i < node_size; i++) {\n            if (children[i]) {\n                _util_sparse_array_node_finish(arr, children[i]);\n            }\n        }\n    }\n\n    os_free_aligned(_util_sparse_array_node_data(node));\n}\n\n// Packs an aligned data pointer and its level into a single node handle.\nstatic inline uintptr_t _util_sparse_array_node(void * data, unsigned level) {\n    assert(data != NULL);\n    assert(((uintptr_t) data & NODE_LEVEL_MASK) == 0);\n    assert((level & NODE_PTR_MASK) == 0);\n    return (uintptr_t) data | level;\n}\n\n// Allocates a zeroed node for `level` (level 0 stores elements, higher levels store\n// child handles) and returns its tagged handle. Aborts if the allocation fails.\ninline uintptr_t _util_sparse_array_node_alloc(util_sparse_array * arr, unsigned level) {\n    size_t size;\n    if (level == 0) {\n        size = arr->elem_size << arr->node_size_log2;\n    } else {\n        size = sizeof(uintptr_t) << arr->node_size_log2;\n    }\n\n    void * data = os_malloc_aligned(size, NODE_ALLOC_ALIGN);\n    if (data == NULL) {\n        // Callers have no way to handle a missing node, and memset(NULL) would crash.\n        abort();\n    }\n    memset(data, 0, size);\n\n    return _util_sparse_array_node(data, level);\n}\n\n// Atomically stores `node` in *node_ptr if it still holds `cmp_node` and returns\n// `node`; otherwise frees `node` and returns the node another thread installed.\nstatic inline uintptr_t _util_sparse_array_set_or_free_node(uintptr_t * node_ptr, uintptr_t cmp_node, uintptr_t node) {\n    uintptr_t prev_node = p_atomic_cmpxchg(node_ptr, cmp_node, node);\n\n    if (prev_node != cmp_node) {\n        /* We lost the race.  Free this one and return the one that was already\n         * allocated.\n         */\n        os_free_aligned(_util_sparse_array_node_data(node));\n        return prev_node;\n    } else {\n        return node;\n    }\n}\n\n// Returns a pointer to element `idx`, allocating the root, extra root levels and\n// zero-filled nodes on demand. Concurrent allocations are reconciled with\n// compare-and-swap via _util_sparse_array_set_or_free_node.\nvoid * util_sparse_array_get(util_sparse_array * arr, uint64_t idx) {\n    const unsigned node_size_log2 = arr->node_size_log2;\n    uintptr_t      root           = p_atomic_read(&arr->root);\n    if (unlikely(!root)) {\n        // Pick the smallest root level whose span covers idx.\n        unsigned root_level = 0;\n        uint64_t idx_iter   = idx >> node_size_log2;\n        while (idx_iter) {\n            idx_iter >>= node_size_log2;\n            root_level++;\n        }\n        uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level);\n        root               = _util_sparse_array_set_or_free_node(&arr->root, NULL_NODE, new_root);\n    }\n\n    while (1) {\n        unsigned root_level = _util_sparse_array_node_level(root);\n        uint64_t root_idx   = idx >> (root_level * node_size_log2);\n        if (likely(root_idx < (1ull << node_size_log2))) {\n            break;\n        }\n\n        /* In this case, we have a root but its level is low enough that the\n         * requested index is out-of-bounds.\n         */\n        uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level + 1);\n\n        uintptr_t * new_root_children = (uintptr_t *) _util_sparse_array_node_data(new_root);\n        new_root_children[0]          = root;\n\n        /* We only add one at a time instead of the whole tree because it's\n         * easier to ensure correctness of both the tree building and the\n         * clean-up path.  Because we're only adding one node we never have to\n         * worry about trying to free multiple things without freeing the old\n         * things.\n         */\n        root = _util_sparse_array_set_or_free_node(&arr->root, root, new_root);\n    }\n\n    void *   node_data  = _util_sparse_array_node_data(root);\n    unsigned node_level = _util_sparse_array_node_level(root);\n    while (node_level > 0) {\n        uint64_t child_idx = (idx >> (node_level * node_size_log2)) & ((1ull << node_size_log2) - 1);\n\n        uintptr_t * children = (uintptr_t *) node_data;\n        uintptr_t   child    = p_atomic_read(&children[child_idx]);\n\n        if (unlikely(!child)) {\n            child = _util_sparse_array_node_alloc(arr, node_level - 1);\n            child = _util_sparse_array_set_or_free_node(&children[child_idx], NULL_NODE, child);\n        }\n\n        node_data  = _util_sparse_array_node_data(child);\n        node_level = _util_sparse_array_node_level(child);\n    }\n\n    uint64_t elem_idx = idx & ((1ull << node_size_log2) - 1);\n    return (void *) ((char *) node_data + (elem_idx * arr->elem_size));\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu-utils.h",
    "content": "#pragma once\n\n#include <atomic>\n#include <cassert>\n#include <cerrno>\n#include <cstdarg>\n#include <cstddef>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <ctime>\n\n#define unlikely(x) __builtin_expect(!!(x), 0)\n#define likely(x)   __builtin_expect(!!(x), 1)\n\n#ifndef UNUSED\n#    define UNUSED(x) (void) (x)\n#endif\n\n/** Checks if a value is a power of two. Does not handle zero (0 is reported as a power of two). */\n#define IS_POT(v) (((v) & ((v) - 1)) == 0)\n\n/** Checks if a value is a nonzero power of two. Zero yields false. */\n#define IS_POT_NONZERO(v) ((v) != 0 && IS_POT(v))\n\n/** Round x up to the next multiple of pot_align, which must be a power of two */\n#define ALIGN_POT(x, pot_align) (((x) + (pot_align) - 1) & ~((pot_align) - 1))\n\n#define p_atomic_read(_v) __atomic_load_n((_v), __ATOMIC_ACQUIRE)\n\nstatic inline bool util_is_power_of_two_nonzero64(uint64_t v) {\n    return IS_POT_NONZERO(v);\n}\n\nstatic inline uint64_t align64(uint64_t value, uint64_t alignment) {\n    assert(util_is_power_of_two_nonzero64(alignment));\n    return ALIGN_POT(value, alignment);\n}\n\nstruct list_head {\n    list_head * prev;\n    list_head * next;\n};\n\nstruct util_sparse_array {\n    size_t   elem_size;\n    unsigned node_size_log2;\n\n    uintptr_t root;\n};\n\nvoid * util_sparse_array_get(util_sparse_array * arr, uint64_t idx);\nvoid   util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size);\n\ninline void os_time_sleep(int64_t usecs) {\n    timespec time;\n    time.tv_sec  = usecs / 1000000;\n    time.tv_nsec = (usecs % 1000000) * 1000;\n    while (clock_nanosleep(CLOCK_MONOTONIC, 0, &time, &time) == EINTR)\n        ;\n}\n\nstruct timer_data {\n    long long start;\n    long long total;\n    long long count;\n};\n\nstatic inline void start_timer(timer_data * timer) {\n    timespec ts;\n    clock_gettime(CLOCK_MONOTONIC, &ts);\n    timer->start = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec;\n}\n\n// returns the duration in ns\nstatic inline long long stop_timer(timer_data 
* timer) {\n    timespec ts;\n    clock_gettime(CLOCK_MONOTONIC, &ts);\n    long long timer_end = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec;\n\n    long long duration = (timer_end - timer->start);\n    timer->total += duration;\n    timer->count += 1;\n\n    return duration;\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu.cpp",
    "content": "#include \"virtgpu.h\"\n\n#include <stdio.h>\n#include <unistd.h>\n\n#include <cassert>\n#include <cerrno>\n#include <cstdlib>\n\nstatic virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev);\nstatic virt_gpu_result_t virtgpu_open(virtgpu * gpu);\n\nstatic virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu);\nstatic virt_gpu_result_t virtgpu_init_context(virtgpu * gpu);\n\nstatic int      virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id);\nstatic int      virtgpu_ioctl_get_caps(virtgpu *             gpu,\n                                       virgl_renderer_capset id,\n                                       uint32_t              version,\n                                       void *                capset,\n                                       size_t                capset_size);\nstatic uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param);\nstatic void     virtgpu_init_renderer_info(virtgpu * gpu);\n\nstatic void log_call_duration(long long call_duration_ns, const char * name);\n\nconst uint64_t APIR_HANDSHAKE_MAX_WAIT_MS   = 2 * 1000;   // 2s\nconst uint64_t APIR_LOADLIBRARY_MAX_WAIT_MS = 60 * 1000;  // 60s\n\nstatic int virtgpu_handshake(virtgpu * gpu) {\n    apir_encoder * encoder;\n    apir_decoder * decoder;\n\n    encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_HANDSHAKE, 0);\n    if (!encoder) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: failed to prepare the remote call encoder\", __func__);\n        return 1;\n    }\n\n    /* write handshake props */\n\n    uint32_t guest_major = APIR_PROTOCOL_MAJOR;\n    uint32_t guest_minor = APIR_PROTOCOL_MINOR;\n    apir_encode_uint32_t(encoder, &guest_major);\n    apir_encode_uint32_t(encoder, &guest_minor);\n\n    /* *** */\n\n    uint32_t  ret_magic;\n    long long call_duration_ns;\n    ret_magic = remote_call(gpu, encoder, &decoder, APIR_HANDSHAKE_MAX_WAIT_MS, &call_duration_ns);\n    log_call_duration(call_duration_ns, \"API Remoting 
handshake\");\n\n    if (!decoder) {\n        GGML_ABORT(GGML_VIRTGPU\n                   \"%s: failed to initiate the communication with the virglrenderer library. \"\n                   \"Most likely, the wrong virglrenderer library was loaded in the hypervisor.\",\n                   __func__);\n        return 1;\n    }\n\n    /* read handshake return values */\n\n    uint32_t host_major;\n    uint32_t host_minor;\n\n    if (ret_magic != APIR_HANDSHAKE_MAGIC) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: handshake with the virglrenderer failed (code=%d | %s)\", __func__, ret_magic,\n                   apir_backend_initialize_error(ret_magic));\n    } else {\n        apir_decode_uint32_t(decoder, &host_major);\n        apir_decode_uint32_t(decoder, &host_minor);\n    }\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    if (ret_magic != APIR_HANDSHAKE_MAGIC) {\n        return 1;\n    }\n\n    GGML_LOG_INFO(GGML_VIRTGPU \"%s: Guest is running with %u.%u\\n\", __func__, guest_major, guest_minor);\n    GGML_LOG_INFO(GGML_VIRTGPU \"%s: Host is running with %u.%u\\n\", __func__, host_major, host_minor);\n\n    if (guest_major != host_major) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"Host major (%d) and guest major (%d) version differ\\n\", host_major, guest_major);\n    } else if (guest_minor != host_minor) {\n        GGML_LOG_WARN(GGML_VIRTGPU \"Host minor (%d) and guest minor (%d) version differ\\n\", host_minor, guest_minor);\n    }\n\n    return 0;\n}\n\nstatic ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {\n    apir_encoder *            encoder;\n    apir_decoder *            decoder;\n    ApirLoadLibraryReturnCode ret;\n\n    encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_LOADLIBRARY, 0);\n    if (!encoder) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: hypercall error: failed to prepare the API Remoting command encoder\", __func__);\n        return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR;\n    }\n\n    long long call_duration_ns;\n\n    
ret = (ApirLoadLibraryReturnCode) remote_call(gpu, encoder, &decoder, APIR_LOADLIBRARY_MAX_WAIT_MS,\n                                                  &call_duration_ns);\n    log_call_duration(call_duration_ns, \"API Remoting LoadLibrary\");\n\n    if (!decoder) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: hypercall error: failed to trigger the API Remoting hypercall.\\n\", __func__);\n        return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR;\n    }\n\n    remote_call_finish(gpu, encoder, decoder);\n\n    if (ret == APIR_LOAD_LIBRARY_SUCCESS) {\n        GGML_LOG_INFO(GGML_VIRTGPU \"The API Remoting backend was successfully loaded and initialized\\n\");\n\n        return ret;\n    }\n\n    // something wrong happened, find out what.\n    if (ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) {\n        if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) {\n            GGML_ABORT(GGML_VIRTGPU\n                       \"%s: virglrenderer could not open the API Remoting backend library, \"\n                       \"some environment variables are missing. \"\n                       \"Make sure virglrenderer is correctly configured by the hypervisor. (%s)\",\n                       __func__, apir_load_library_error(ret));\n        } else if (ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) {\n            GGML_ABORT(GGML_VIRTGPU\n                       \"%s: virglrenderer could not open the API Remoting backend library. \"\n                       \"Make sure virglrenderer is correctly configured by the hypervisor. (%s)\",\n                       __func__, apir_load_library_error(ret));\n        } else if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) {\n            GGML_ABORT(GGML_VIRTGPU\n                       \"%s: could not load the backend library, some symbols are missing. \"\n                       \"Make sure virglrenderer is correctly configured by the hypervisor. 
(%s) \",\n                       __func__, apir_load_library_error(ret));\n        } else {\n            GGML_ABORT(GGML_VIRTGPU \"%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)\",\n                       __func__, apir_load_library_error(ret), ret);\n        }\n        return ret;\n    }\n\n    GGML_LOG_INFO(GGML_VIRTGPU \"%s: virglrenderer successfully loaded the API Remoting backend library.\\n\", __func__);\n\n    ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX);\n\n    if (apir_ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) {\n        GGML_ABORT(GGML_VIRTGPU\n                   \"%s: the API Remoting backend library couldn't load the GGML backend library. \"\n                   \"Make sure virglrenderer is correctly configured by the hypervisor. (%s)\",\n                   __func__, apir_load_library_error(apir_ret));\n    } else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) {\n        GGML_ABORT(\n            GGML_VIRTGPU\n            \"%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. \"\n            \"Make sure virglrenderer is correctly configured by the hypervisor. 
(%s)\",\n            __func__, apir_load_library_error(apir_ret));\n    } else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) {\n        GGML_ABORT(GGML_VIRTGPU\n                   \"%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)\",\n                   __func__, apir_ret, apir_load_library_error(apir_ret));\n    } else {\n        uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX;\n        GGML_ABORT(GGML_VIRTGPU\n                   \"%s: the API Remoting backend library failed to initialize its backend library: apir code=%d)\",\n                   __func__, lib_ret);\n    }\n    return ret;\n}\n\nvirtgpu * create_virtgpu() {\n    virtgpu * gpu = new virtgpu();\n\n    gpu->use_apir_capset = getenv(\"GGML_REMOTING_USE_APIR_CAPSET\") != nullptr;\n    util_sparse_array_init(&gpu->shmem_array, sizeof(virtgpu_shmem), 1024);\n\n    // Initialize mutex to protect shared data_shmem buffer\n    if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) {\n        delete gpu;\n        GGML_ABORT(GGML_VIRTGPU \"%s: failed to initialize data_shmem mutex\", __func__);\n        return NULL;\n    }\n\n    if (virtgpu_open(gpu) != APIR_SUCCESS) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: failed to open the virtgpu device\\n\", __func__);\n        return NULL;\n    }\n\n    if (virtgpu_init_capset(gpu) != APIR_SUCCESS) {\n        if (gpu->use_apir_capset) {\n            GGML_ABORT(GGML_VIRTGPU\n                       \"%s: failed to initialize the virtgpu APIR capset. 
Make sure that the virglrenderer library \"\n                       \"supports it.\",\n                       __func__);\n        } else {\n            GGML_ABORT(GGML_VIRTGPU \"%s: failed to initialize the virtgpu Venus capset\", __func__);\n        }\n        return NULL;\n    }\n\n    if (virtgpu_init_context(gpu) != APIR_SUCCESS) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: failed to initialize the GPU context\", __func__);\n        return NULL;\n    }\n\n    if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: failed to create the shared reply memory pages\", __func__);\n        return NULL;\n    }\n\n    if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: failed to create the shared data memory pages\", __func__);\n        return NULL;\n    }\n\n    if (virtgpu_handshake(gpu)) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: failed to handshake with the virglrenderer library\", __func__);\n        return NULL;\n    }\n\n    if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: failed to load the backend library\", __func__);\n        return NULL;\n    }\n\n    return gpu;\n}\n\nstatic virt_gpu_result_t virtgpu_open(virtgpu * gpu) {\n    drmDevicePtr devs[8];\n    int          count = drmGetDevices2(0, devs, ARRAY_SIZE(devs));\n    if (count < 0) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: failed to enumerate DRM devices\\n\", __func__);\n        return APIR_ERROR_INITIALIZATION_FAILED;\n    }\n\n    virt_gpu_result_t result = APIR_ERROR_INITIALIZATION_FAILED;\n    for (int i = 0; i < count; i++) {\n        result = virtgpu_open_device(gpu, devs[i]);\n        if (result == APIR_SUCCESS) {\n            break;\n        }\n    }\n\n    drmFreeDevices(devs, count);\n\n    return result;\n}\n\nstatic virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev) {\n    const char * node_path = 
dev->nodes[DRM_NODE_RENDER];\n\n    int fd = open(node_path, O_RDWR | O_CLOEXEC);\n    if (fd < 0) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: failed to open %s\", __func__, node_path);\n        return APIR_ERROR_INITIALIZATION_FAILED;\n    }\n\n    drmVersionPtr version = drmGetVersion(fd);\n    if (!version || strcmp(version->name, \"virtio_gpu\") || version->version_major != 0) {\n        if (version) {\n            GGML_LOG_ERROR(GGML_VIRTGPU \"%s: unknown DRM driver %s version %d\\n\", __func__, version->name,\n                           version->version_major);\n        } else {\n            GGML_LOG_ERROR(GGML_VIRTGPU \"%s: failed to get DRM driver version\\n\", __func__);\n        }\n\n        if (version) {\n            drmFreeVersion(version);\n        }\n        close(fd);\n        return APIR_ERROR_INITIALIZATION_FAILED;\n    }\n\n    gpu->fd = fd;\n\n    drmFreeVersion(version);\n\n    GGML_LOG_INFO(GGML_VIRTGPU \"using DRM device %s\\n\", node_path);\n\n    return APIR_SUCCESS;\n}\n\nstatic virt_gpu_result_t virtgpu_init_context(virtgpu * gpu) {\n    assert(!gpu->capset.version);\n    const int ret = virtgpu_ioctl_context_init(gpu, gpu->capset.id);\n    if (ret) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: failed to initialize context: %s\\n\", __func__, strerror(errno));\n        return APIR_ERROR_INITIALIZATION_FAILED;\n    }\n\n    return APIR_SUCCESS;\n}\n\nstatic virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) {\n    if (gpu->use_apir_capset) {\n        GGML_LOG_INFO(GGML_VIRTGPU \"Using the APIR capset\\n\");\n        gpu->capset.id = VIRTGPU_DRM_CAPSET_APIR;\n    } else {\n        GGML_LOG_INFO(GGML_VIRTGPU \"Using the Venus capset\\n\");\n        gpu->capset.id = VIRTGPU_DRM_CAPSET_VENUS;\n    }\n    gpu->capset.version = 0;\n\n    int ret =\n        virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data));\n\n    if (ret) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: failed to get APIR v%d 
capset: %s\\n\", __func__, gpu->capset.version,\n                       strerror(errno));\n        return APIR_ERROR_INITIALIZATION_FAILED;\n    }\n\n    assert(gpu->capset.data.supports_blob_resources);\n\n    return APIR_SUCCESS;\n}\n\nstatic int virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id) {\n    drm_virtgpu_context_set_param ctx_set_params[3] = {\n        {\n         .param = VIRTGPU_CONTEXT_PARAM_CAPSET_ID,\n         .value = capset_id,\n         },\n        {\n         .param = VIRTGPU_CONTEXT_PARAM_NUM_RINGS,\n         .value = 1,\n         },\n        {\n         .param = VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK,\n         .value = 0, /* don't generate drm_events on fence signaling */\n        },\n    };\n\n    drm_virtgpu_context_init args = {\n        .num_params     = ARRAY_SIZE(ctx_set_params),\n        .pad            = 0,\n        .ctx_set_params = (uintptr_t) &ctx_set_params,\n    };\n\n    return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_CONTEXT_INIT, &args);\n}\n\nstatic int virtgpu_ioctl_get_caps(virtgpu *             gpu,\n                                  virgl_renderer_capset id,\n                                  uint32_t              version,\n                                  void *                capset,\n                                  size_t                capset_size) {\n    drm_virtgpu_get_caps args = {\n        .cap_set_id  = id,\n        .cap_set_ver = version,\n        .addr        = (uintptr_t) capset,\n        .size        = (__u32) capset_size,\n        .pad         = 0,\n    };\n\n    return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GET_CAPS, &args);\n}\n\nstatic uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param) {\n    /* val must be zeroed because kernel only writes the lower 32 bits */\n    uint64_t             val  = 0;\n    drm_virtgpu_getparam args = {\n        .param = param,\n        .value = (uintptr_t) &val,\n    };\n\n    const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GETPARAM, 
&args);\n    return ret ? 0 : val;\n}\n\napir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags) {\n    /*\n     * Prepare the command encoder and its buffer\n     */\n\n    thread_local char encoder_buffer[4096];\n\n    thread_local apir_encoder enc;\n    enc = {\n        .cur   = encoder_buffer,\n        .start = encoder_buffer,\n        .end   = encoder_buffer + sizeof(encoder_buffer),\n        .fatal = false,\n    };\n\n    /*\n     * Fill the command encoder with the common args:\n     * - cmd_type (int32_t)\n     * - cmd_flags (int32_t)\n     * - reply res id (uint32_t)\n   */\n\n    int32_t cmd_type = apir_cmd_type;\n\n    // for testing during the hypervisor transition\n    if (!gpu->use_apir_capset) {\n        cmd_type += VENUS_COMMAND_TYPE_LENGTH;\n    }\n    apir_encode_int32_t(&enc, &cmd_type);\n    apir_encode_int32_t(&enc, &cmd_flags);\n\n    uint32_t reply_res_id = gpu->reply_shmem.res_id;\n    apir_encode_uint32_t(&enc, &reply_res_id);\n\n    return &enc;\n}\n\nvoid remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec) {\n    UNUSED(gpu);\n\n    if (!enc) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: Invalid (null) encoder\", __func__);\n    }\n\n    if (!dec) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: Invalid (null) decoder\", __func__);\n    }\n\n    if (apir_encoder_get_fatal(enc)) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: Failed to encode the output parameters.\", __func__);\n    }\n\n    if (apir_decoder_get_fatal(dec)) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: Failed to decode the input parameters.\", __func__);\n    }\n}\n\nuint32_t remote_call(virtgpu *       gpu,\n                     apir_encoder *  encoder,\n                     apir_decoder ** decoder,\n                     float           max_wait_ms,\n                     long long *     call_duration_ns) {\n    /*\n     * Prepare the reply notification pointer\n     */\n\n    volatile std::atomic_uint * atomic_reply_notif 
= (volatile std::atomic_uint *) gpu->reply_shmem.mmap_ptr;\n    *atomic_reply_notif                            = 0;\n\n    /*\n     * Trigger the execbuf ioctl\n     */\n\n    drm_virtgpu_execbuffer args = {\n        .flags   = VIRTGPU_EXECBUF_RING_IDX,\n        .size    = (uint32_t) (encoder->cur - encoder->start),\n        .command = (uintptr_t) encoder->start,\n\n        .bo_handles     = 0,\n        .num_bo_handles = 0,\n\n        .fence_fd         = 0,\n        .ring_idx         = 0,\n        .syncobj_stride   = 0,\n        .num_in_syncobjs  = 0,\n        .num_out_syncobjs = 0,\n        .in_syncobjs      = 0,\n        .out_syncobjs     = 0,\n    };\n\n    *decoder = NULL;\n\n    int ret = drmIoctl(gpu->fd, DRM_IOCTL_VIRTGPU_EXECBUFFER, &args);\n\n    if (ret != 0) {\n        GGML_ABORT(GGML_VIRTGPU \"%s: the virtgpu EXECBUFFER ioctl failed (%d)\", __func__, ret);\n    }\n\n    /*\n     * Wait for the response notification\n     */\n    timer_data wait_host_reply_timer = { 0, 0, 0 };\n\n    start_timer(&wait_host_reply_timer);\n\n    timespec ts_start, ts_end;\n    clock_gettime(CLOCK_MONOTONIC, &ts_start);\n    long long start_time = (long long) ts_start.tv_sec * 1000000000LL + ts_start.tv_nsec;\n\n    bool     timedout    = false;\n    uint32_t notif_value = 0;\n    while (true) {\n        notif_value = std::atomic_load_explicit(atomic_reply_notif, std::memory_order_acquire);\n\n        if (notif_value != 0) {\n            break;\n        }\n\n        int64_t base_sleep_us = 15;\n\n        os_time_sleep(base_sleep_us);\n\n        if (max_wait_ms) {\n            clock_gettime(CLOCK_MONOTONIC, &ts_end);\n            long long end_time    = (long long) ts_end.tv_sec * 1000000000LL + ts_end.tv_nsec;\n            float     duration_ms = (end_time - start_time) / 1000000;\n\n            if (duration_ms > max_wait_ms) {\n                timedout = true;\n                break;\n            }\n        }\n    }\n\n    if (call_duration_ns) {\n        *call_duration_ns 
= stop_timer(&wait_host_reply_timer);\n    }\n\n    if (max_wait_ms && timedout) {\n        GGML_LOG_ERROR(GGML_VIRTGPU \"%s: timed out waiting for the host answer...\\n\", __func__);\n        return APIR_FORWARD_TIMEOUT;\n    }\n\n    /*\n     * Prepare the decoder\n     */\n    static apir_decoder response_dec;\n    response_dec.cur = (char *) gpu->reply_shmem.mmap_ptr + sizeof(*atomic_reply_notif);\n    response_dec.end = (char *) gpu->reply_shmem.mmap_ptr + gpu->reply_shmem.mmap_size;\n    *decoder         = &response_dec;\n\n    // extract the actual return value from the notif flag\n    uint32_t returned_value = notif_value - 1;\n    return returned_value;\n}\n\nstatic void log_call_duration(long long call_duration_ns, const char * name) {\n    double call_duration_ms = (double) call_duration_ns / 1e6;  // 1 millisecond = 1e6 nanoseconds\n    double call_duration_s  = (double) call_duration_ns / 1e9;  // 1 second = 1e9 nanoseconds\n\n    if (call_duration_s > 1) {\n        GGML_LOG_INFO(GGML_VIRTGPU \"waited %.2fs for the %s host reply...\\n\", call_duration_s, name);\n    } else if (call_duration_ms > 1) {\n        GGML_LOG_INFO(GGML_VIRTGPU \"waited %.2fms for the %s host reply...\\n\", call_duration_ms, name);\n    } else {\n        GGML_LOG_INFO(GGML_VIRTGPU \"waited %lldns for the %s host reply...\\n\", call_duration_ns, name);\n    }\n}\n"
  },
  {
    "path": "src/ggml-virtgpu/virtgpu.h",
    "content": "#pragma once\n\n// clang-format off\n#include \"virtgpu-utils.h\"\n#include \"virtgpu-shm.h\"\n#include \"virtgpu-apir.h\"\n\n#include \"backend/shared/api_remoting.h\"\n#include \"backend/shared/apir_cs.h\"\n\n#include <fcntl.h>\n#include <stdbool.h>\n#include <stdio.h>\n#include <sys/stat.h>\n#include <sys/sysmacros.h>\n#include <threads.h>\n#include <xf86drm.h>\n\n#include <cstring>\n\n#include \"ggml-remoting.h\"\n\n#define VIRGL_RENDERER_UNSTABLE_APIS 1\n#include \"apir_hw.h\"\n#include <drm/virtgpu_drm.h>\n#include \"venus_hw.h\"\n// clang-format on\n\n#ifndef VIRTGPU_DRM_CAPSET_APIR\n// Will be defined in include/drm/virtgpu_drm.h once\n// https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs\n// is merged\n#    define VIRTGPU_DRM_CAPSET_APIR 10\n#endif\n\n// Mesa/Virglrenderer Venus internal. Only necessary during the\n// Venus->APIR transition in Virglrenderer\n#define VENUS_COMMAND_TYPE_LENGTH 331\n\n#ifndef VIRTGPU_DRM_CAPSET_VENUS  // only available with Linux >= v6.16\n#    define VIRTGPU_DRM_CAPSET_VENUS 4\n#endif\n\ntypedef uint32_t virgl_renderer_capset;\n\n/* from src/virtio/vulkan/vn_renderer_virtgpu.c */\n#define VIRTGPU_PCI_VENDOR_ID       0x1af4\n#define VIRTGPU_PCI_DEVICE_ID       0x1050\n#define VIRTGPU_BLOB_MEM_GUEST_VRAM 0x0004\n#define VIRTGPU_PARAM_GUEST_VRAM    9\n\n#define SHMEM_DATA_SIZE  0x1830000  // 24 MiB + 192 KiB\n#define SHMEM_REPLY_SIZE 0x4000\n\n#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))\n\nenum virt_gpu_result_t {\n    APIR_SUCCESS                     = 0,\n    APIR_ERROR_INITIALIZATION_FAILED = -1,\n};\n\n#define PRINTFLIKE(f, a) __attribute__((format(__printf__, f, a)))\n\nstruct virtgpu {\n    bool use_apir_capset;\n\n    int fd;\n\n    struct {\n        virgl_renderer_capset      id;\n        uint32_t                   version;\n        virgl_renderer_capset_apir data;\n    } capset;\n\n    util_sparse_array shmem_array;\n\n    /* APIR communication pages */\n    virtgpu_shmem 
reply_shmem;\n    virtgpu_shmem data_shmem;\n\n    /* Mutex to protect shared data_shmem buffer from concurrent access */\n    mtx_t data_shmem_mutex;\n\n    /* Cached device information to prevent memory leaks and race conditions */\n    struct {\n        char *   description;\n        char *   name;\n        int32_t  device_count;\n        uint32_t type;\n        size_t   memory_free;\n        size_t   memory_total;\n    } cached_device_info;\n\n    /* Cached buffer type information to prevent memory leaks and race conditions */\n    struct {\n        apir_buffer_type_host_handle_t host_handle;\n        char *                         name;\n        size_t                         alignment;\n        size_t                         max_size;\n    } cached_buffer_type;\n};\n\nstatic inline int virtgpu_ioctl(virtgpu * gpu, unsigned long request, void * args) {\n    return drmIoctl(gpu->fd, request, args);\n}\n\nvirtgpu * create_virtgpu();\n\napir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags);\n\nuint32_t remote_call(virtgpu *       gpu,\n                     apir_encoder *  enc,\n                     apir_decoder ** dec,\n                     float           max_wait_ms,\n                     long long *     call_duration_ns);\n\nvoid remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec);\n"
  },
  {
    "path": "src/ggml-vulkan/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.19)\ncmake_policy(SET CMP0114 NEW)\ncmake_policy(SET CMP0116 NEW)\nif (POLICY CMP0147)\n    # Parallel build custom build steps\n    cmake_policy(SET CMP0147 NEW)\nendif()\n\nfind_package(Vulkan COMPONENTS glslc REQUIRED)\n\nif (CMAKE_CXX_COMPILER_ID STREQUAL \"MSVC\")\n    # Parallel build object files\n    add_definitions(/MP)\nendif()\n\nfunction(detect_host_compiler)\n    if (CMAKE_HOST_SYSTEM_NAME STREQUAL \"Windows\")\n        find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)\n        find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH)\n    else()\n        find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH)\n        find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH)\n    endif()\n    set(HOST_C_COMPILER \"${HOST_C_COMPILER}\" PARENT_SCOPE)\n    set(HOST_CXX_COMPILER \"${HOST_CXX_COMPILER}\" PARENT_SCOPE)\nendfunction()\n\n# Function to test shader extension support\n# Parameters:\n#  EXTENSION_NAME - Name of the extension to test (e.g., \"GL_EXT_integer_dot_product\")\n#  TEST_SHADER_FILE - Path to the test shader file\n#  RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result\nfunction(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE)\n    execute_process(\n        COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 \"${TEST_SHADER_FILE}\"\n        OUTPUT_VARIABLE glslc_output\n        ERROR_VARIABLE glslc_error\n    )\n\n    if (${glslc_error} MATCHES \".*extension not supported: ${EXTENSION_NAME}.*\")\n        message(STATUS \"${EXTENSION_NAME} not supported by glslc\")\n        set(${RESULT_VARIABLE} OFF PARENT_SCOPE)\n    else()\n        message(STATUS \"${EXTENSION_NAME} supported by glslc\")\n        set(${RESULT_VARIABLE} ON PARENT_SCOPE)\n        add_compile_definitions(${RESULT_VARIABLE})\n\n        # Ensure the extension support 
is forwarded to vulkan-shaders-gen\n        list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON)\n        set(VULKAN_SHADER_GEN_CMAKE_ARGS \"${VULKAN_SHADER_GEN_CMAKE_ARGS}\" PARENT_SCOPE)\n    endif()\nendfunction()\n\nif (Vulkan_FOUND)\n    message(STATUS \"Vulkan found\")\n\n    ggml_add_backend_library(ggml-vulkan\n                             ggml-vulkan.cpp\n                             ../../include/ggml-vulkan.h\n                            )\n\n    set(VULKAN_SHADER_GEN_CMAKE_ARGS \"\")\n\n    # Test all shader extensions\n    test_shader_extension_support(\n        \"GL_KHR_cooperative_matrix\"\n        \"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp\"\n        \"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT\"\n    )\n\n    test_shader_extension_support(\n        \"GL_NV_cooperative_matrix2\"\n        \"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp\"\n        \"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT\"\n    )\n\n    test_shader_extension_support(\n        \"GL_EXT_integer_dot_product\"\n        \"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp\"\n        \"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT\"\n    )\n\n    test_shader_extension_support(\n        \"GL_EXT_bfloat16\"\n        \"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp\"\n        \"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT\"\n    )\n\n    target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)\n    target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})\n\n    # Workaround to the \"can't dereference invalidated vector iterator\" bug in clang-cl debug build\n    # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector\n    if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL \"Clang\")\n        add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)\n    endif()\n\n    if (GGML_VULKAN_CHECK_RESULTS)\n        
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)\n    endif()\n\n    if (GGML_VULKAN_DEBUG)\n        add_compile_definitions(GGML_VULKAN_DEBUG)\n    endif()\n\n    if (GGML_VULKAN_MEMORY_DEBUG)\n        add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)\n    endif()\n\n    if (GGML_VULKAN_SHADER_DEBUG_INFO)\n        add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)\n        list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON)\n    endif()\n\n    if (GGML_VULKAN_VALIDATE)\n        add_compile_definitions(GGML_VULKAN_VALIDATE)\n    endif()\n\n    if (GGML_VULKAN_RUN_TESTS)\n        add_compile_definitions(GGML_VULKAN_RUN_TESTS)\n    endif()\n\n    # Set up toolchain for host compilation whether cross-compiling or not\n    if (CMAKE_CROSSCOMPILING)\n        if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)\n            set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})\n        else()\n            detect_host_compiler()\n            if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER)\n                message(FATAL_ERROR \"Host compiler not found\")\n            else()\n                message(STATUS \"Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}\")\n            endif()\n            configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)\n            set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)\n        endif()\n    else()\n        # For non-cross-compiling, use empty toolchain (use host compiler)\n        set(HOST_CMAKE_TOOLCHAIN_FILE \"\")\n    endif()\n\n    include(ExternalProject)\n\n    if (CMAKE_CROSSCOMPILING)\n        list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})\n        message(STATUS \"vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}\")\n    endif()\n\n    ExternalProject_Add(\n        vulkan-shaders-gen\n        SOURCE_DIR 
${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders\n        CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$<CONFIG>\n                   -DCMAKE_INSTALL_BINDIR=.\n                   -DCMAKE_BUILD_TYPE=$<CONFIG>\n                   ${VULKAN_SHADER_GEN_CMAKE_ARGS}\n\n        BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $<CONFIG>\n        BUILD_ALWAYS  TRUE\n\n        # NOTE: When DESTDIR is set using Makefile generators and\n        # \"make install\" triggers the build step, vulkan-shaders-gen\n        # would be installed into the DESTDIR prefix, so it is unset\n        # to ensure that does not happen.\n\n        INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR\n                        ${CMAKE_COMMAND} --install . --config $<CONFIG>\n    )\n\n    set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)\n    set (_ggml_vk_genshaders_dir \"${CMAKE_BINARY_DIR}/$<CONFIG>\")\n    set (_ggml_vk_genshaders_cmd \"${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}\")\n    set (_ggml_vk_header     \"${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp\")\n    set (_ggml_vk_input_dir  \"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders\")\n    set (_ggml_vk_output_dir \"${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv\")\n\n    file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS \"${_ggml_vk_input_dir}/*.comp\")\n\n    # Because external projects do not provide source-level tracking,\n    # the vulkan-shaders-gen sources need to be explicitly added to\n    # ensure that changes will cascade into shader re-generation.\n\n    file(GLOB _ggml_vk_shaders_gen_sources\n              CONFIGURE_DEPENDS \"${_ggml_vk_input_dir}/*.cpp\"\n                                \"${_ggml_vk_input_dir}/*.h\")\n\n    add_custom_command(\n        OUTPUT ${_ggml_vk_header}\n        COMMAND ${_ggml_vk_genshaders_cmd}\n            --output-dir ${_ggml_vk_output_dir}\n            --target-hpp ${_ggml_vk_header}\n        DEPENDS 
${_ggml_vk_shaders_gen_sources}\n                vulkan-shaders-gen\n        COMMENT \"Generate vulkan shaders header\"\n    )\n    target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header})\n\n    foreach (file_full ${_ggml_vk_shader_files})\n        get_filename_component(file ${file_full} NAME)\n        set (_ggml_vk_target_cpp \"${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp\")\n\n        add_custom_command(\n            OUTPUT  ${_ggml_vk_target_cpp}\n            DEPFILE ${_ggml_vk_target_cpp}.d\n            COMMAND ${_ggml_vk_genshaders_cmd}\n                --glslc      ${Vulkan_GLSLC_EXECUTABLE}\n                --source     ${file_full}\n                --output-dir ${_ggml_vk_output_dir}\n                --target-hpp ${_ggml_vk_header}\n                --target-cpp ${_ggml_vk_target_cpp}\n            DEPENDS ${file_full}\n                    ${_ggml_vk_shaders_gen_sources}\n                    vulkan-shaders-gen\n            COMMENT \"Generate vulkan shaders for ${file}\"\n        )\n        target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp})\n    endforeach()\n\nelse()\n    message(WARNING \"Vulkan not found\")\nendif()\n"
  },
  {
    "path": "src/ggml-vulkan/cmake/host-toolchain.cmake.in",
    "content": "set(CMAKE_BUILD_TYPE Release)\nset(CMAKE_C_FLAGS -O2)\nset(CMAKE_CXX_FLAGS -O2)\nset(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)\nset(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER)\nset(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER)\nset(CMAKE_C_COMPILER \"@HOST_C_COMPILER@\")\nset(CMAKE_CXX_COMPILER \"@HOST_CXX_COMPILER@\")\nset(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@)\n\nif(\"@CMAKE_C_COMPILER_ID@\" STREQUAL \"MSVC\")\n    foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO)\n        set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})\n    endforeach()\nendif()\n"
  },
  {
    "path": "src/ggml-vulkan/ggml-vulkan.cpp",
    "content": "#include \"ggml-vulkan.h\"\n#include <vulkan/vulkan_core.h>\n#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)\n#include <chrono>\n#include \"ggml-cpu.h\"\n#endif\n\n// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-\n#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1\n// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE\n// to avoid conflicts with applications or other libraries who might use it.\n#if VK_HEADER_VERSION >= 301\nnamespace vk::detail { class DispatchLoaderDynamic; }\nusing vk::detail::DispatchLoaderDynamic;\n#else\nnamespace vk { class DispatchLoaderDynamic; }\nusing vk::DispatchLoaderDynamic;\n#endif\nDispatchLoaderDynamic & ggml_vk_default_dispatcher();\n#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()\n\n#include <vulkan/vulkan.hpp>\n\n#include <algorithm>\n#include <cmath>\n#include <iomanip>\n#include <iostream>\n#include <tuple>\n#include <vector>\n#include <deque>\n#include <sstream>\n#include <utility>\n#include <memory>\n#include <limits>\n#include <map>\n#include <set>\n#include <unordered_map>\n#include <memory>\n#include <mutex>\n#include <future>\n#include <thread>\n\n#if defined(_MSC_VER)\n# define NOMINMAX 1\n# include <windows.h>\n# define YIELD() YieldProcessor()\n#elif defined(__clang__) || defined(__GNUC__)\n# if defined(__x86_64__) ||defined(__i386__)\n#  include <immintrin.h>\n#  define YIELD() _mm_pause()\n# elif defined(__arm__) || defined(__aarch64__)\n#  if defined(__clang__)\n#   include <arm_acle.h>\n#   define YIELD() __yield()\n#  else\n#   define YIELD() asm volatile(\"yield\")\n#  endif\n# endif\n#endif\n\n#if !defined(YIELD)\n#define YIELD()\n#endif\n\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-vulkan-shaders.hpp\"\n\n// remove this once it's more widely available in the SDK\n#if !defined(VK_KHR_shader_bfloat16)\n\n#define 
VK_KHR_shader_bfloat16 1\n#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION                          1\n#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME                        \"VK_KHR_shader_bfloat16\"\n#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)\n#define VK_COMPONENT_TYPE_BFLOAT16_KHR                               ((VkComponentTypeKHR)1000141000)\n\ntypedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {\n    VkStructureType                       sType;\n    void*                                 pNext;\n    VkBool32                              shaderBFloat16Type;\n    VkBool32                              shaderBFloat16DotProduct;\n    VkBool32                              shaderBFloat16CooperativeMatrix;\n} VkPhysicalDeviceShaderBfloat16FeaturesKHR;\n#endif\n\n#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))\n#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))\nstatic bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }\n\n#define VK_VENDOR_ID_AMD 0x1002\n#define VK_VENDOR_ID_APPLE 0x106b\n#define VK_VENDOR_ID_INTEL 0x8086\n#define VK_VENDOR_ID_NVIDIA 0x10de\n#define VK_VENDOR_ID_QUALCOMM 0x5143\n\n#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256\n\n#define GGML_VK_MAX_NODES 8192\n\n#define VK_CHECK(err, msg)                                          \\\n    do {                                                            \\\n        vk::Result err_ = (err);                                    \\\n        if (err_ != vk::Result::eSuccess) {                         \\\n            fprintf(stderr, \"ggml_vulkan: %s error %s at %s:%d\\n\",  \\\n                #err, to_string(err_).c_str(), __FILE__, __LINE__); \\\n            exit(1);                                                \\\n        }                                                           \\\n    } while (0)\n\n#ifdef GGML_VULKAN_DEBUG\n#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl\n#else\n#define VK_LOG_DEBUG(msg) ((void) 0)\n#endif // 
GGML_VULKAN_DEBUG\n\nstruct ggml_backend_vk_context;\n\n#define MAX_PARAMETER_COUNT 12\n// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.\n#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)\n\ntypedef std::shared_ptr<struct vk_pipeline_struct> vk_pipeline;\n\nstruct vk_pipeline_struct {\n    std::string name;\n    vk::ShaderModule shader_module;\n    vk::PipelineLayout layout;\n    vk::Pipeline pipeline;\n    uint32_t push_constant_size;\n    uint32_t parameter_count;\n    std::array<uint32_t, 3> wg_denoms;\n    uint32_t align;\n    // true if fields have been set by ggml_vk_create_pipeline\n    bool initialized {};\n    // set to true to request the pipeline is compiled\n    std::atomic<bool> needed {};\n    // set to true when the shader has been compiled\n    std::atomic<bool> compiled {};\n    // number of registers used, extracted from pipeline executable properties\n    uint32_t register_count {};\n\n#if defined(VK_EXT_shader_64bit_indexing)\n    bool is_64b_indexing {};\n#endif\n    // linked list of pipelines for multiple compilation variants.\n    // currently only used to compile a 64-bit indexing variant.\n    vk_pipeline next;\n};\n\ntypedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;\n\nstatic void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);\n\nstruct vk_matmul_pipeline_struct {\n    vk_pipeline l, m, s;\n    vk_pipeline a_l, a_m, a_s;\n    // Returns true when all unaligned pipelines are null.\n    // We only check for unaligned variants since one of the unaligned pipelines must exist\n    // while aligned pipelines are optional\n    bool is_empty() const {\n        return l == nullptr && m == nullptr && s == nullptr;\n    }\n};\ntypedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;\n\nstruct vk_matmul_pipeline2 {\n    vk_matmul_pipeline2() {\n        f16acc = std::make_shared<vk_matmul_pipeline_struct>();\n        f32acc = std::make_shared<vk_matmul_pipeline_struct>();\n    
}\n    vk_matmul_pipeline f32acc;\n    vk_matmul_pipeline f16acc;\n};\n\nstruct vk_device_struct;\ntypedef std::shared_ptr<vk_device_struct> vk_device;\ntypedef std::weak_ptr<vk_device_struct> vk_device_ref;\n\nstruct vk_buffer_struct;\ntypedef std::shared_ptr<vk_buffer_struct> vk_buffer;\ntypedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;\n\nstruct ggml_backend_vk_buffer_type_context {\n    std::string name;\n    vk_device device;\n};\n\nstruct vk_queue;\n\nstruct vk_command_buffer {\n    vk::CommandBuffer buf;\n    bool in_use = false;\n};\n\n// Stores command pool/buffers. There's an instance of this\n// for each (context,queue) pair and for each (device,queue) pair.\nstruct vk_command_pool {\n    void init(vk_device& device, vk_queue *q_);\n    void destroy(vk::Device& device);\n\n    vk::CommandPool pool;\n    // Using deque so the pointers to command buffers\n    // remain valid even if we add more\n    std::deque<vk_command_buffer> cmd_buffers;\n\n    vk_queue *q;\n\n    size_t buffers_in_use() const {\n        return std::count_if(cmd_buffers.begin(), cmd_buffers.end(),\n            [](const auto& cb) { return cb.in_use; });\n    }\n};\n\n// Prevent simultaneous submissions to the same queue.\n// This could be per vk_queue if we stopped having two vk_queue structures\n// sharing the same vk::Queue.\nstatic std::mutex queue_mutex;\n\nstruct vk_queue {\n    uint32_t queue_family_index;\n    vk::Queue queue;\n\n    vk_command_pool cmd_pool;\n\n    vk::PipelineStageFlags stage_flags;\n\n    bool transfer_only;\n\n    // copy everything except the cmd_pool\n    void copyFrom(vk_queue &other) {\n        queue_family_index = other.queue_family_index;\n        queue = other.queue;\n        stage_flags = other.stage_flags;\n        transfer_only = other.transfer_only;\n    }\n};\n\nstatic const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);\nstatic ggml_backend_buffer_t 
ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);\nstatic size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);\nstatic size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);\nstatic size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);\nstatic ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {\n    /* .get_name         = */ ggml_backend_vk_buffer_type_name,\n    /* .alloc_buffer     = */ ggml_backend_vk_buffer_type_alloc_buffer,\n    /* .get_alignment    = */ ggml_backend_vk_buffer_type_get_alignment,\n    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,\n    /* .get_alloc_size   = */ ggml_backend_vk_buffer_type_get_alloc_size,\n    /* .is_host          = */ NULL,\n};\n\nclass vk_memory_logger;\nclass vk_perf_logger;\nstatic void ggml_vk_destroy_buffer(vk_buffer& buf);\nstatic void ggml_vk_synchronize(ggml_backend_vk_context * ctx);\n\nstatic constexpr uint32_t mul_mat_vec_max_cols = 8;\nstatic constexpr uint32_t p021_max_gqa_ratio = 8;\n\nenum vk_device_architecture {\n    OTHER,\n    AMD_GCN,\n    AMD_RDNA1,\n    AMD_RDNA2,\n    AMD_RDNA3,\n    INTEL_XE2,\n    NVIDIA_PRE_TURING,\n    NVIDIA_TURING,\n};\n\nstatic vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {\n    vk::PhysicalDeviceProperties props = device.getProperties();\n\n    if (props.vendorID == VK_VENDOR_ID_AMD) {\n        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();\n\n        bool amd_shader_core_properties = false;\n        bool integer_dot_product = false;\n        bool subgroup_size_control = false;\n\n        for (const auto& properties : ext_props) {\n            if (strcmp(\"VK_AMD_shader_core_properties\", properties.extensionName) == 0) {\n                amd_shader_core_properties = true;\n            } else if 
(strcmp(\"VK_KHR_shader_integer_dot_product\", properties.extensionName) == 0) {\n                integer_dot_product = true;\n            } else if (strcmp(\"VK_EXT_subgroup_size_control\", properties.extensionName) == 0) {\n                subgroup_size_control = true;\n            }\n        }\n\n        if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {\n            return vk_device_architecture::OTHER;\n        }\n\n        vk::PhysicalDeviceProperties2 props2;\n        vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;\n        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;\n        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;\n\n        props2.pNext = &shader_core_props_amd;\n        shader_core_props_amd.pNext = &integer_dot_props;\n        integer_dot_props.pNext = &subgroup_size_control_props;\n\n        device.getProperties2(&props2);\n\n        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {\n            return vk_device_architecture::AMD_GCN;\n        }\n        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {\n            // RDNA\n            if (shader_core_props_amd.wavefrontsPerSimd == 20) {\n                return vk_device_architecture::AMD_RDNA1;\n            }\n            if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {\n                return vk_device_architecture::AMD_RDNA3;\n            }\n            return vk_device_architecture::AMD_RDNA2;\n        }\n    } else if (props.vendorID == VK_VENDOR_ID_INTEL) {\n        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();\n\n        bool subgroup_size_control = false;\n\n        for (const auto& properties : ext_props) {\n            if (strcmp(\"VK_EXT_subgroup_size_control\", 
properties.extensionName) == 0) {\n                subgroup_size_control = true;\n            }\n        }\n\n        if (!subgroup_size_control) {\n            return vk_device_architecture::OTHER;\n        }\n\n        vk::PhysicalDeviceProperties2 props2;\n        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;\n\n        props2.pNext = &subgroup_size_control_props;\n        device.getProperties2(&props2);\n\n        if (subgroup_size_control_props.minSubgroupSize == 16) {\n            // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8.\n            // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value.\n            // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html\n            // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html\n            return vk_device_architecture::INTEL_XE2;\n        }\n    } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {\n        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();\n\n        bool cooperative_matrix = false;\n        bool sm_builtins = false;\n\n        // Detect \"pre-turing\" based on lack of coopmat support.\n        for (const auto& properties : ext_props) {\n            if (strcmp(\"VK_KHR_cooperative_matrix\", properties.extensionName) == 0) {\n                cooperative_matrix = true;\n            } else if (strcmp(\"VK_NV_shader_sm_builtins\", properties.extensionName) == 0) {\n                sm_builtins = true;\n            }\n        }\n\n        if (!cooperative_matrix) {\n            return vk_device_architecture::NVIDIA_PRE_TURING;\n        }\n\n        if (sm_builtins) {\n            vk::PhysicalDeviceProperties2 props2;\n            vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;\n\n            props2.pNext 
= &sm_props;\n\n            device.getProperties2(&props2);\n\n            // Turing has 32, following architectures have 48\n            if (sm_props.shaderWarpsPerSM == 32) {\n                return vk_device_architecture::NVIDIA_TURING;\n            }\n        }\n    }\n    return vk_device_architecture::OTHER;\n}\n\nenum vk_conv_shapes {\n    CONV_SHAPE_128x128,\n    CONV_SHAPE_64x32,\n    CONV_SHAPE_32x256,\n    CONV_SHAPE_COUNT,\n};\n\nstruct vk_conv_block_size {\n    uint32_t K;\n    uint32_t NPQ;\n    uint32_t CRS;\n};\n\nvk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = {\n    // K   NPQ  CRS\n    { 128, 128, 16 }, // CONV_SHAPE_128x128\n    {  64,  32, 32 }, // CONV_SHAPE_64x32\n    {  32, 256, 16 }, // CONV_SHAPE_32x256\n};\n\nenum dmmv_wg_sizes {\n    DMMV_WG_SIZE_SUBGROUP,\n    DMMV_WG_SIZE_LARGE,\n    DMMV_WG_SIZE_COUNT,\n};\n\nenum FaCodePath {\n    FA_SCALAR,\n    FA_COOPMAT1,\n    FA_COOPMAT2,\n};\n\nstruct vk_fa_pipeline_state {\n    uint32_t HSK, HSV;\n    uint32_t Br, Bc;\n    uint32_t D_split, row_split;\n    bool shmem_staging;\n    FaCodePath path;\n    uint32_t workgroup_size, subgroup_size;\n    bool aligned;\n    bool f32acc;\n    uint32_t flags;\n    uint32_t limit_occupancy_shmem;\n\n    bool operator<(const vk_fa_pipeline_state &b) const {\n        return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <\n               std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);\n    }\n};\n\nstruct vk_conv2d_pipeline_state {\n    vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)\n        : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}\n\n    uint32_t s0, s1, p0, p1, d0, d1, KW, KH;\n\n    bool operator<(const 
vk_conv2d_pipeline_state &b) const {\n        return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <\n               std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);\n    }\n};\n\nstruct vk_solve_tri_pipeline_state {\n    vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)\n        : N(N), K(K) {}\n\n    uint32_t N, K;\n\n    bool operator<(const vk_solve_tri_pipeline_state &b) const {\n        return std::tie(N, K) <\n               std::tie(b.N, b.K);\n    }\n};\n\nenum shader_reduction_mode {\n    SHADER_REDUCTION_MODE_SHMEM,\n    SHADER_REDUCTION_MODE_HYBRID,\n    SHADER_REDUCTION_MODE_SUBGROUP,\n    SHADER_REDUCTION_MODE_COUNT,\n};\n\n// argsort pipelines for up to 1<<10 invocations per workgroup\nstatic constexpr uint32_t num_argsort_pipelines = 11;\nstatic constexpr uint32_t num_topk_moe_pipelines = 10;\nstatic constexpr uint32_t num_topk_pipelines = 11;\n\nstatic constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,\n                                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,\n                                                                             GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,\n                                                                             GGML_OP_RESHAPE };\n\nstatic constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY,    GGML_OP_RESHAPE,  GGML_OP_ADD,\n                                                                            GGML_OP_ARGSORT,  GGML_OP_VIEW,     GGML_OP_GET_ROWS,\n                                                                            GGML_OP_RESHAPE,  GGML_OP_SUM_ROWS, GGML_OP_CLAMP,\n                                                                            GGML_OP_DIV,      GGML_OP_RESHAPE };\n\nstatic constexpr std::initializer_list<ggml_op> topk_moe_early_softmax     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  
GGML_OP_ARGSORT,\n                                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS };\n\nstatic constexpr std::initializer_list<ggml_op> topk_moe_late_softmax      { GGML_OP_ARGSORT,  GGML_OP_VIEW,\n                                                                             GGML_OP_GET_ROWS, GGML_OP_RESHAPE,\n                                                                             GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };\n\n//node #978 (  SOFT_MAX):     ffn_moe_probs-15 (   0K) [Vulka         ] use=2:    ffn_moe_logits-15 (   0K) [Vulka         ]\n//node #979 (   RESHAPE): ffn_moe_probs-15 (re (   0K) [Vulka         ] use=1:     ffn_moe_probs-15 (   0K) [Vulka         ]\n//node #980 (   ARGSORT):   ffn_moe_argsort-15 (   0K) [Vulka         ] use=1:     ffn_moe_probs-15 (   0K) [Vulka         ]\n//node #981 (      VIEW):      ffn_moe_topk-15 (   0K) [Vulka         ] use=4:   ffn_moe_argsort-15 (   0K) [Vulka         ]\n//node #982 (  GET_ROWS):   ffn_moe_weights-15 (   0K) [Vulka         ] use=1: ffn_moe_probs-15 (re (   0K) [Vulka         ]      ffn_moe_topk-15 (   0K) [Vulka         ]\n//node #983 (   RESHAPE): ffn_moe_weights-15 ( (   0K) [Vulka         ] use=2:   ffn_moe_weights-15 (   0K) [Vulka         ]\n//node #984 (  SUM_ROWS): ffn_moe_weights_sum- (   0K) [Vulka         ] use=1: ffn_moe_weights-15 ( (   0K) [Vulka         ]\n//node #985 (     CLAMP): ffn_moe_weights_sum_ (   0K) [Vulka         ] use=1: ffn_moe_weights_sum- (   0K) [Vulka         ]\n//node #986 (       DIV): ffn_moe_weights_norm (   0K) [Vulka         ] use=1: ffn_moe_weights-15 ( (   0K) [Vulka         ] ffn_moe_weights_sum_ (   0K) [Vulka         ]\n//node #987 (   RESHAPE): ffn_moe_weights_norm (   0K) [Vulka         ] use=1: ffn_moe_weights_norm (   0K) [Vulka         ]\nstatic constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {\n    { 1, 0, 0 }, // reshape->src[0]  == softmax\n    { 2, 0, 0 }, // 
argsort->src[0]  == softmax\n    { 3, 0, 2 }, // view->src[0]     == argsort\n    { 4, 0, 1 }, // get_rows->src[0] == reshape\n    { 4, 1, 3 }, // get_rows->src[1] == view\n    { 5, 0, 4 }, // reshape->src[0]  == get_rows\n    { 6, 0, 5 }, // sum_rows->src[0] == reshape\n    { 7, 0, 6 }, // clamp->src[0]    == sum_rows\n    { 8, 0, 5 }, // div->src[0]      == reshape\n    { 8, 1, 7 }, // div->src[1]      == clamp\n    { 9, 0, 8 }, // reshape->src[0]  == div\n};\n\n//node #436 (     UNARY):     ffn_moe_probs-10 ( 256K) [Vulka         ] use=2:    ffn_moe_logits-10 ( 256K) [Vulka         ]\n//node #437 (   RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ]\n//node #438 (       ADD): ffn_moe_probs_biased ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ] blk.10.exp_probs_b.b (   0K) [Vulka         ]\n//node #439 (   ARGSORT):   ffn_moe_argsort-10 ( 256K) [Vulka         ] use=1: ffn_moe_probs_biased ( 256K) [Vulka         ]\n//node #440 (      VIEW):      ffn_moe_topk-10 ( 255K) [Vulka         ] use=3:   ffn_moe_argsort-10 ( 256K) [Vulka         ]\n//node #441 (  GET_ROWS):   ffn_moe_weights-10 (  12K) [Vulka         ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka         ]      ffn_moe_topk-10 ( 255K) [Vulka         ]\n//node #442 (   RESHAPE): ffn_moe_weights-10 ( (  12K) [Vulka         ] use=2:   ffn_moe_weights-10 (  12K) [Vulka         ]\n//node #443 (  SUM_ROWS): ffn_moe_weights_sum- (   2K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ]\n//node #444 (     CLAMP): ffn_moe_weights_sum_ (   2K) [Vulka         ] use=1: ffn_moe_weights_sum- (   2K) [Vulka         ]\n//node #445 (       DIV): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ] ffn_moe_weights_sum_ (   2K) [Vulka         ]\n//node #446 (   RESHAPE): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights_norm (  12K) [Vulka         ]\nstatic 
constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {\n    { 1, 0, 0 }, // reshape->src[0]  == sigmoid\n    { 2, 0, 0 }, // add->src[0]      == sigmoid\n    { 3, 0, 2 }, // argsort->src[0]  == add\n    { 4, 0, 3 }, // view->src[0]     == argsort\n    { 5, 0, 1 }, // get_rows->src[0] == reshape\n    { 5, 1, 4 }, // get_rows->src[1] == view\n    { 6, 0, 5 }, // reshape->src[0]  == get_rows\n    { 7, 0, 6 }, // sum_rows->src[0] == reshape\n    { 8, 0, 7 }, // clamp->src[0]    == sum_rows\n    { 9, 0, 6 }, // div->src[0]      == reshape\n    { 9, 1, 8 }, // div->src[1]      == clamp\n    {10, 0, 9 }, // reshape->src[0]  == div\n};\n\n// same as early_softmax_norm but ending after the get_rows\nstatic constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {\n    { 1, 0, 0 }, // reshape->src[0]  == softmax\n    { 2, 0, 0 }, // argsort->src[0]  == softmax\n    { 3, 0, 2 }, // view->src[0]     == argsort\n    { 4, 0, 1 }, // get_rows->src[0] == reshape\n    { 4, 1, 3 }, // get_rows->src[1] == view\n};\n\n//node #652 (   ARGSORT):   ffn_moe_argsort-11 (   0K) [Vulka         ] use=1:     ffn_moe_probs-11 (   0K) [Vulka         ]\n//node #653 (      VIEW):      ffn_moe_topk-11 (   0K) [Vulka         ] use=7:   ffn_moe_argsort-11 (   0K) [Vulka         ]\n//node #654 (  GET_ROWS):   ffn_moe_weights-11 (   0K) [Vulka         ] use=1: ffn_moe_probs-11 (re (   0K) [Vulka         ]      ffn_moe_topk-11 (   0K) [Vulka         ]\n//node #655 (   RESHAPE): ffn_moe_weights-11 ( (   0K) [Vulka         ] use=1:   ffn_moe_weights-11 (   0K) [Vulka         ]\n//node #656 (  SOFT_MAX):             node_656 (   0K) [Vulka         ] use=1: ffn_moe_weights-11 ( (   0K) [Vulka         ]\n//node #657 (   RESHAPE): ffn_moe_weights_soft (   0K) [Vulka         ] use=1:             node_656 (   0K) [Vulka         ]\nstatic constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {\n    { 1, 0, 0 }, // view->src[0]     
== argsort\n    { 2, 1, 1 }, // get_rows->src[1] == view\n    { 3, 0, 2 }, // reshape->src[0]  == get_rows\n    { 4, 0, 3 }, // soft_max->src[0] == reshape\n    { 5, 0, 4 }, // reshape->src[0]  == soft_max\n};\n\nenum topk_moe_mode {\n    TOPK_MOE_EARLY_SOFTMAX,\n    TOPK_MOE_EARLY_SOFTMAX_NORM,\n    TOPK_MOE_LATE_SOFTMAX,\n    TOPK_MOE_SIGMOID_NORM_BIAS,\n    TOPK_MOE_COUNT,\n};\n\nstatic constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {\n    { 1, 0, 0 }, // view->src[0]     == rope\n    { 2, 0, 1 }, // set_rows->src[0] == view\n};\n\nstatic constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_view_set_rows_edges {\n    { 1, 0, 0 }, // mul->src[0]      == rms\n    { 2, 0, 1 }, // rope->src[0]     == mul\n    { 3, 0, 2 }, // view->src[0]     == rope\n    { 4, 0, 3 }, // set_rows->src[0] == view\n};\n\n\nstruct vk_device_struct {\n    std::recursive_mutex mutex;\n\n    vk::PhysicalDevice physical_device;\n    vk::PhysicalDeviceProperties properties;\n    std::string name;\n    uint64_t max_memory_allocation_size;\n    uint64_t max_buffer_size;\n    uint64_t suballocation_block_size;\n    uint64_t min_imported_host_pointer_alignment;\n    bool external_memory_host {};\n    bool fp16;\n    bool bf16;\n    bool pipeline_robustness;\n    bool memory_priority;\n    vk::Device device;\n    uint32_t vendor_id;\n    vk::DriverId driver_id;\n    vk_device_architecture architecture;\n    vk_queue compute_queue;\n    vk_queue transfer_queue;\n    bool single_queue;\n    bool support_async;\n    bool async_use_transfer_queue;\n    uint32_t subgroup_size;\n    uint32_t subgroup_size_log2;\n    uint32_t shader_core_count;\n    bool uma;\n    bool prefer_host_memory;\n    bool float_controls_rte_fp16;\n    bool subgroup_basic;\n    bool subgroup_arithmetic;\n    bool subgroup_shuffle;\n    bool subgroup_ballot;\n    bool subgroup_clustered;\n    bool subgroup_vote;\n    bool multi_add;\n    bool shader_int64;\n    bool 
buffer_device_address;\n    bool vulkan_memory_model;\n\n    bool add_rms_fusion;\n    uint32_t partials_binding_alignment;\n\n    bool shader_64b_indexing;\n\n    bool integer_dot_product;\n    // 0: default, 1: force mmvq, -1: disable mmvq\n    int32_t mmvq_mode;\n\n    bool subgroup_size_control;\n    uint32_t subgroup_min_size;\n    uint32_t subgroup_max_size;\n    bool subgroup_require_full_support;\n\n    // floor(log2(maxComputeWorkGroupInvocations))\n    uint32_t max_workgroup_size_log2 {};\n\n    bool coopmat_support;\n    bool coopmat_acc_f32_support {};\n    bool coopmat_acc_f16_support {};\n    bool coopmat_bf16_support {};\n    bool coopmat_support_16x16x16_f16acc {};\n    bool coopmat_support_16x16x16_f32acc {};\n    bool coopmat1_fa_support {};\n    uint32_t coopmat_m;\n    uint32_t coopmat_n;\n    uint32_t coopmat_k;\n\n    bool coopmat_int_support;\n    uint32_t coopmat_int_m;\n    uint32_t coopmat_int_n;\n    uint32_t coopmat_int_k;\n\n    bool coopmat2;\n\n    bool pipeline_executable_properties_support {};\n\n    size_t idx;\n\n    bool mul_mat_l[GGML_TYPE_COUNT];\n    bool mul_mat_m[GGML_TYPE_COUNT];\n    bool mul_mat_s[GGML_TYPE_COUNT];\n    bool mul_mat_id_l[GGML_TYPE_COUNT];\n    bool mul_mat_id_m[GGML_TYPE_COUNT];\n    bool mul_mat_id_s[GGML_TYPE_COUNT];\n\n    vk::DescriptorSetLayout dsl;\n\n    vk_matmul_pipeline pipeline_matmul_f32 {};\n    vk_matmul_pipeline pipeline_matmul_f32_f16 {};\n    vk_matmul_pipeline pipeline_matmul_bf16 {};\n    vk_matmul_pipeline2 pipeline_matmul_f16;\n    vk_matmul_pipeline2 pipeline_matmul_f16_f32;\n\n    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];\n    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];\n    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];\n\n    vk_matmul_pipeline pipeline_matmul_id_f32 {};\n    vk_matmul_pipeline pipeline_matmul_id_bf16 {};\n    vk_matmul_pipeline2 pipeline_matmul_id_f16;\n    vk_matmul_pipeline2 
pipeline_matmul_id_f16_f32;\n\n    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];\n    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];\n\n    vk_pipeline pipeline_matmul_split_k_reduce;\n    vk_pipeline pipeline_quantize_q8_1_x4;\n\n    vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];\n    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];\n    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];\n    vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];\n\n    vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];\n    vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];\n\n    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];\n    vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;\n    vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];\n    vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];\n    vk_pipeline pipeline_acc_f32;\n    vk_pipeline pipeline_set_f32;\n\n    // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]\n    vk_pipeline pipeline_add[2][2][2];\n    vk_pipeline pipeline_add_norepeat[2][2][2];\n    vk_pipeline pipeline_sub[2][2][2];\n    vk_pipeline pipeline_sub_norepeat[2][2][2];\n    vk_pipeline pipeline_mul[2][2][2];\n    vk_pipeline pipeline_mul_norepeat[2][2][2];\n    vk_pipeline pipeline_div[2][2][2];\n    vk_pipeline pipeline_div_norepeat[2][2][2];\n    vk_pipeline pipeline_add_rms[2][2][2];\n    vk_pipeline pipeline_add_rms_norepeat[2][2][2];\n\n    // indexed by num_additional_fused_ops == num_adds - 1\n    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];\n    vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];\n\n    vk_pipeline pipeline_add_id_f32;\n\n    vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;\n    
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;\n    vk_pipeline pipeline_scale_f32;\n    vk_pipeline pipeline_sqr_f32;\n    vk_pipeline pipeline_sqrt_f32;\n    vk_pipeline pipeline_sin_f32;\n    vk_pipeline pipeline_cos_f32;\n    vk_pipeline pipeline_log[2];\n    vk_pipeline pipeline_tri[2];\n    vk_pipeline pipeline_diag[2];\n    vk_pipeline pipeline_clamp_f32;\n    vk_pipeline pipeline_pad_f32;\n    vk_pipeline pipeline_roll_f32;\n    vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;\n    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;\n    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;\n    vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];\n    vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];\n    vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;\n    vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];\n    vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];\n    vk_pipeline pipeline_norm_f32;\n    vk_pipeline pipeline_group_norm_f32;\n    vk_pipeline pipeline_rms_norm_f32;\n    vk_pipeline pipeline_rms_norm_mul_f32;\n    vk_pipeline pipeline_rms_norm_partials_f32;\n    vk_pipeline pipeline_rms_norm_mul_partials_f32;\n    vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;\n    vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;\n    vk_pipeline pipeline_rms_norm_back_f32;\n    vk_pipeline pipeline_l2_norm_f32;\n\n    // [src/dst 0=fp32,1=fp16]\n    vk_pipeline pipeline_exp[2];\n    vk_pipeline pipeline_elu[2];\n    vk_pipeline pipeline_gelu[2];\n    vk_pipeline pipeline_gelu_erf[2];\n    vk_pipeline pipeline_gelu_quick[2];\n    vk_pipeline pipeline_silu[2];\n  
  vk_pipeline pipeline_relu[2];\n    vk_pipeline pipeline_xielu[2];\n    vk_pipeline pipeline_neg[2];\n    vk_pipeline pipeline_tanh[2];\n    vk_pipeline pipeline_sigmoid[2];\n    vk_pipeline pipeline_hardsigmoid[2];\n    vk_pipeline pipeline_hardswish[2];\n    vk_pipeline pipeline_abs[2];\n    vk_pipeline pipeline_softplus[2];\n    vk_pipeline pipeline_step[2];\n    vk_pipeline pipeline_round[2];\n    vk_pipeline pipeline_ceil[2];\n    vk_pipeline pipeline_floor[2];\n    vk_pipeline pipeline_trunc[2];\n    vk_pipeline pipeline_sgn[2];\n\n    vk_pipeline pipeline_add1_f16_f16;\n    vk_pipeline pipeline_add1_f16_f32;\n    vk_pipeline pipeline_add1_f32_f32;\n\n    vk_pipeline pipeline_arange_f32;\n\n    vk_pipeline pipeline_fill_f32;\n\n    vk_pipeline pipeline_geglu[2];\n    vk_pipeline pipeline_reglu[2];\n    vk_pipeline pipeline_swiglu[2];\n    vk_pipeline pipeline_swiglu_oai[2];\n    vk_pipeline pipeline_geglu_erf[2];\n    vk_pipeline pipeline_geglu_quick[2];\n\n    vk_pipeline pipeline_leaky_relu_f32;\n    vk_pipeline pipeline_silu_back_f32;\n    vk_pipeline pipeline_diag_mask_inf_f32;\n    vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;\n    vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;\n    vk_pipeline pipeline_soft_max_back_f32;\n\n    vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;\n    vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;\n    vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;\n\n    vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;\n    vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;\n    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;\n    vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;\n    vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];\n    vk_pipeline 
pipeline_argsort_large_f32[num_argsort_pipelines];\n    vk_pipeline pipeline_topk_f32[num_topk_pipelines];\n    vk_pipeline pipeline_sum_rows_f32;\n    vk_pipeline pipeline_cumsum_f32;\n    vk_pipeline pipeline_cumsum_small_f32;\n    vk_pipeline pipeline_cumsum_multipass1_f32;\n    vk_pipeline pipeline_cumsum_multipass2_f32;\n    vk_pipeline pipeline_argmax_f32;\n    vk_pipeline pipeline_count_equal_i32;\n    std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;\n    vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;\n    vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;\n    vk_pipeline pipeline_timestep_embedding_f32;\n    vk_pipeline pipeline_conv_transpose_1d_f32;\n    vk_pipeline pipeline_pool2d_f32;\n    vk_pipeline pipeline_rwkv_wkv6_f32;\n    vk_pipeline pipeline_rwkv_wkv7_f32;\n    // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128\n    vk_pipeline pipeline_gated_delta_net[3][2];\n    vk_pipeline pipeline_ssm_scan_f32_d128;\n    vk_pipeline pipeline_ssm_scan_f32_d256;\n    vk_pipeline pipeline_ssm_conv_f32;\n    vk_pipeline pipeline_opt_step_adamw_f32;\n    vk_pipeline pipeline_opt_step_sgd_f32;\n    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];\n    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];\n    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];\n    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];\n    vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;\n    vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;\n\n    std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];\n\n    std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;\n\n    vk_pipeline pipeline_flash_attn_split_k_reduce;\n    vk_pipeline 
pipeline_count_experts;\n\n    // [2] is for whether to take n_experts from spec constant (0) or push constant (1)\n    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];\n\n    std::vector<vk_pipeline_ref> all_pipelines;\n\n    std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;\n\n    vk::Fence fence;\n    vk_buffer sync_staging;\n\n    ggml_backend_buffer_type buffer_type;\n\n    bool disable_fusion;\n    bool disable_host_visible_vidmem;\n    bool allow_sysmem_fallback;\n    bool disable_graph_optimize;\n\n    std::unique_ptr<vk_memory_logger> memory_logger;\n\n    ~vk_device_struct() {\n        VK_LOG_DEBUG(\"destroy device \" << name);\n\n        device.destroyFence(fence);\n\n        ggml_vk_destroy_buffer(sync_staging);\n\n        compute_queue.cmd_pool.destroy(device);\n        transfer_queue.cmd_pool.destroy(device);\n\n        for (auto& pipeline : all_pipelines) {\n            if (pipeline.expired()) {\n                continue;\n            }\n\n            vk_pipeline pl = pipeline.lock();\n            ggml_vk_destroy_pipeline(device, pl);\n        }\n        all_pipelines.clear();\n\n        device.destroyDescriptorSetLayout(dsl);\n\n        device.destroy();\n    }\n};\n\nvoid vk_command_pool::init(vk_device& device, vk_queue *q_) {\n    cmd_buffers.clear();\n    q = q_;\n\n    vk::CommandPoolCreateInfo command_pool_create_info(\n        vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT | VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT),\n        q->queue_family_index);\n    pool = device->device.createCommandPool(command_pool_create_info);\n}\n\nvoid vk_command_pool::destroy(vk::Device& device) {\n    device.destroyCommandPool(pool);\n    pool = nullptr;\n    cmd_buffers.clear();\n}\n\nstruct vk_buffer_struct {\n    vk::Buffer buffer = VK_NULL_HANDLE;\n    vk::DeviceMemory device_memory = VK_NULL_HANDLE;\n    vk::MemoryPropertyFlags memory_property_flags;\n    void * ptr;\n    size_t size = 0;\n    
vk::DeviceAddress bda_addr {};\n\n    vk_device device;\n\n    ~vk_buffer_struct() {\n        if (size == 0) {\n            return;\n        }\n        VK_LOG_DEBUG(\"~vk_buffer_struct(\" << buffer << \", \" << size << \")\");\n\n        device->device.freeMemory(device_memory);\n        device->device.destroyBuffer(buffer);\n    }\n};\n\nstruct vk_subbuffer {\n    vk_buffer buffer;\n    uint64_t offset;\n    uint64_t size;\n\n    operator vk::DescriptorBufferInfo() const {\n        return { buffer->buffer, offset, size };\n    }\n};\n\n// vk_event is used for the event-related backend interfaces. It uses 'event' for\n// event_wait and 'fence' for event_synchronize. Polling on an event for\n// event_synchronize wouldn't be sufficient to wait for command buffers to complete,\n// and would lead to validation errors.\nstruct vk_event {\n    vk::Event event;\n    vk::Fence fence;\n    vk_command_buffer* cmd_buffer = nullptr;\n};\n\nstruct vk_semaphore {\n    vk::Semaphore s;\n    uint64_t value;\n};\n\nstruct vk_submission {\n    vk_command_buffer* buffer = nullptr;\n    std::vector<vk_semaphore> wait_semaphores;\n    std::vector<vk_semaphore> signal_semaphores;\n};\n\ntypedef std::vector<vk_submission> vk_sequence;\n\nstruct vk_mat_mat_push_constants {\n    uint32_t M; uint32_t N; uint32_t K;\n    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;\n    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;\n    uint32_t base_work_group_z; uint32_t num_batches;\n    uint32_t k_split;\n    uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;\n    uint32_t padded_N;\n};\n\n#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1\n#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2\n#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4\n#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8\n\nstruct vk_mat_vec_push_constants {\n    uint32_t ncols;\n    uint32_t stride_a;\n    uint32_t stride_b;\n    uint32_t stride_d;\n    uint32_t batch_stride_a;\n    uint32_t batch_stride_b;\n    
uint32_t batch_stride_d;\n    uint32_t fusion_flags;\n    uint32_t base_work_group_y;\n    uint32_t ne02;\n    uint32_t ne12;\n    uint32_t broadcast2;\n    uint32_t broadcast3;\n};\n\nstruct vk_mat_vec_p021_push_constants {\n    uint32_t ncols_x;\n    uint32_t nrows_x;\n    uint32_t nchannels_x;\n    uint32_t nchannels_y;\n    uint32_t b_offset;\n    uint32_t d_offset;\n    uint32_t fusion_flags;\n};\n\nstruct vk_mat_vec_nc_push_constants {\n    uint32_t ncols_x;\n    uint32_t nrows_x;\n    uint32_t row_stride_x;\n    uint32_t channel_stride_x;\n    uint32_t channel_stride_y;\n    uint32_t channel_x_divisor;\n    uint32_t ne12;\n    uint32_t b_offset;\n    uint32_t d_offset;\n    uint32_t nb03;\n    uint32_t nb13;\n    uint32_t nb23;\n    uint32_t fusion_flags;\n};\n\nstruct vk_mat_mat_id_push_constants {\n    uint32_t M; uint32_t N; uint32_t K;\n    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;\n    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;\n    uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;\n    uint32_t padded_N;\n};\nstruct vk_mat_vec_id_push_constants {\n    uint32_t ncols;\n    uint32_t stride_a;\n    uint32_t stride_b;\n    uint32_t stride_d;\n    uint32_t batch_stride_a;\n    uint32_t batch_stride_b;\n    uint32_t batch_stride_d;\n    uint32_t fusion_flags;\n    uint32_t nei0;\n    uint32_t ne11;\n    uint32_t expert_i1;\n    uint32_t nbi1;\n};\n\nstruct vk_flash_attn_push_constants {\n    uint32_t N;\n    uint32_t KV;\n\n    uint32_t ne1;\n    uint32_t ne2;\n    uint32_t ne3;\n\n    uint32_t neq2;\n    uint32_t neq3;\n    uint32_t nek2;\n    uint32_t nek3;\n    uint32_t nev2;\n    uint32_t nev3;\n    uint32_t nem1;\n    uint32_t nem2;\n    uint32_t nem3;\n\n    uint32_t nb01;\n    uint32_t nb02;\n    uint32_t nb03;\n    uint32_t nb11;\n    uint32_t nb12;\n    uint32_t nb13;\n    uint32_t nb21;\n    uint32_t nb22;\n    uint32_t nb23;\n\n    float scale;\n    float max_bias;\n    float 
logit_softcap;\n\n    uint32_t mask_n_head_log2;\n    float m0;\n    float m1;\n\n    uint32_t gqa_ratio;\n    uint32_t split_kv;\n    uint32_t k_num;\n};\nstatic_assert(sizeof(vk_flash_attn_push_constants) <= 128, \"sizeof(vk_flash_attn_push_constants) must be <= 128\");\n\nstruct vk_op_push_constants {\n    uint32_t KX;\n    uint32_t KY;\n    float param1;\n    float param2;\n    float param3;\n    float param4;\n};\n\nstruct vk_op_count_experts_push_constants {\n    uint32_t ne00;\n    uint32_t ne01;\n    uint32_t nb00;\n    uint32_t nb01;\n    uint32_t a_offset;\n};\n\nstruct vk_op_glu_push_constants {\n    uint32_t N;\n    uint32_t ne00;\n    uint32_t ne20;\n    uint32_t mode;  // 0: default, 1: swapped, 2: split\n    float alpha; // for swiglu_oai\n    float limit;\n};\n\nstruct vk_op_unary_push_constants {\n    uint32_t ne;\n    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;\n    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;\n    uint32_t misalign_offsets;\n    float param1; float param2;\n    uint32_t ne0_012mp; uint32_t ne0_012L;\n    uint32_t ne0_01mp;  uint32_t ne0_01L;\n    uint32_t ne0_0mp;   uint32_t ne0_0L;\n    uint32_t ne1_012mp; uint32_t ne1_012L;\n    uint32_t ne1_01mp;  uint32_t ne1_01L;\n    uint32_t ne1_0mp;   uint32_t ne1_0L;\n};\nstatic_assert(sizeof(vk_op_unary_push_constants) <= 128, \"sizeof(vk_op_unary_push_constants) must be <= 128\");\n\nstatic vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {\n    GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));\n    ne = ne != 0 ? 
ne : ggml_nelements(dst);\n    GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());\n\n    vk_op_unary_push_constants p{};\n    p.ne = (uint32_t)ne;\n\n    size_t src0_tsize = ggml_type_size(src0->type);\n    p.ne00 = (uint32_t)src0->ne[0];\n    p.ne01 = (uint32_t)src0->ne[1];\n    p.ne02 = (uint32_t)src0->ne[2];\n    p.ne03 = (uint32_t)src0->ne[3];\n    p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);\n    p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);\n    p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);\n    p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);\n\n    size_t dst_tsize = ggml_type_size(dst->type);\n    p.ne10 = (uint32_t)dst->ne[0];\n    p.ne11 = (uint32_t)dst->ne[1];\n    p.ne12 = (uint32_t)dst->ne[2];\n    p.ne13 = (uint32_t)dst->ne[3];\n    p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);\n    p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);\n    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);\n    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);\n\n    return p; // offsets are initialized later in ggml_vk_op\n}\n\nstruct vk_op_pad_push_constants {\n    uint32_t ne;\n    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;\n    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;\n    uint32_t misalign_offsets;\n    uint32_t circular;\n\n    uint32_t lp0; uint32_t rp0;\n    uint32_t lp1; uint32_t rp1;\n    uint32_t lp2; uint32_t rp2;\n    uint32_t lp3; uint32_t rp3;\n};\n\nstatic vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {\n    int64_t ne = ggml_nelements(dst);\n    GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());\n\n    vk_op_pad_push_constants p{};\n    p.ne = (uint32_t)ne;\n\n    size_t src0_tsize = ggml_type_size(src0->type);\n    p.ne00 = (uint32_t)src0->ne[0];\n    p.ne01 = (uint32_t)src0->ne[1];\n    p.ne02 = (uint32_t)src0->ne[2];\n    
p.ne03 = (uint32_t)src0->ne[3];\n    p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);\n    p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);\n    p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);\n    p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);\n\n    size_t dst_tsize = ggml_type_size(dst->type);\n    p.ne10 = (uint32_t)dst->ne[0];\n    p.ne11 = (uint32_t)dst->ne[1];\n    p.ne12 = (uint32_t)dst->ne[2];\n    p.ne13 = (uint32_t)dst->ne[3];\n    p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);\n    p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);\n    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);\n    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);\n\n    p.lp0 = dst->op_params[0];\n    p.rp0 = dst->op_params[1];\n    p.lp1 = dst->op_params[2];\n    p.rp1 = dst->op_params[3];\n    p.lp2 = dst->op_params[4];\n    p.rp2 = dst->op_params[5];\n    p.lp3 = dst->op_params[6];\n    p.rp3 = dst->op_params[7];\n    p.circular = dst->op_params[8];\n\n    return p; // fastdiv values and offsets are initialized later in ggml_vk_op\n}\n\n// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.\n// Precompute mp (m' in the paper) and L such that division\n// can be computed using a multiply (high 32b of 64b result)\n// and a shift:\n//\n// n/d = (mulhi(n, mp) + n) >> L;\nstatic void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)\n{\n    // compute L = ceil(log2(d));\n    L = 0;\n    while (L < 32 && (uint32_t{1} << L) < d) {\n        L++;\n    }\n\n    mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);\n}\n\ntemplate <typename T> void init_pushconst_fastdiv(T &p) {\n    GGML_UNUSED(p);\n    static_assert(!std::is_const<T>::value, \"unexpected type\");\n}\n\ntemplate <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {\n    // Compute magic values to divide by these six numbers.\n    init_fastdiv_values(p.ne02*p.ne01*p.ne00,  p.ne0_012mp,    p.ne0_012L);\n    init_fastdiv_values(p.ne01*p.ne00,         p.ne0_01mp,     p.ne0_01L);\n    
init_fastdiv_values(p.ne00,                p.ne0_0mp,      p.ne0_0L);\n    init_fastdiv_values(p.ne12*p.ne11*p.ne10,  p.ne1_012mp,    p.ne1_012L);\n    init_fastdiv_values(p.ne11*p.ne10,         p.ne1_01mp,     p.ne1_01L);\n    init_fastdiv_values(p.ne10,                p.ne1_0mp,      p.ne1_0L);\n}\n\nstruct vk_op_binary_push_constants {\n    uint32_t ne;\n    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;\n    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;\n    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;\n    uint32_t misalign_offsets;\n    float param1; float param2; int32_t param3;\n};\n\nstruct vk_op_multi_add_push_constants {\n    // shape for dst\n    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;\n\n    // strides for srcs+dst\n    uint32_t nb[MAX_PARAMETER_COUNT][4];\n\n    uint32_t rms_partials;\n};\n// update multi_add.comp if this changes\nstatic_assert(MAX_PARAMETER_COUNT == 12);\nstatic_assert(sizeof(vk_op_multi_add_push_constants) <= 256);\n\nstruct vk_op_topk_moe_push_constants {\n    uint32_t n_rows;\n    uint32_t n_experts_push;\n    uint32_t n_expert_used;\n    float clamp_min;\n    float clamp_max;\n    uint32_t gating_func;\n    uint32_t has_bias;\n    uint32_t with_norm;\n    float output_scale;\n    float output_bias;\n};\n\nstruct vk_op_add_id_push_constants {\n    uint32_t ne0;\n    uint32_t ne1;\n    uint32_t s01;\n    uint32_t s02;\n    uint32_t s11;\n    uint32_t s21;\n};\n\nstruct vk_op_diag_mask_push_constants {\n    uint32_t ncols;\n    uint32_t rows_per_channel;\n    int32_t n_past;\n};\n\nstruct vk_op_rope_push_constants {\n    uint32_t rope_mode;\n    uint32_t nrows;\n    uint32_t n_dims;\n    float freq_scale;\n    float freq_base;\n    float ext_factor;\n    float attn_factor;\n    float 
corr_dims[2];\n    float theta_scale;\n    uint32_t has_ff;\n    int32_t sections[4];\n    uint32_t is_imrope;\n    uint32_t is_back;\n    uint32_t set_rows_stride;\n    uint32_t ne00;\n    uint32_t ne01;\n    uint32_t ne02;\n    uint32_t nb01;\n    uint32_t nb02;\n    uint32_t nb03;\n    uint32_t nb11;\n    uint32_t nb12;\n    uint32_t nb13;\n};\nstatic_assert(sizeof(vk_op_rope_push_constants) <= 128, \"sizeof(vk_op_rope_push_constants) must be <= 128\");\n\n// For fused rms_norm+mul+rope(+view+set_rows)\nstruct vk_op_rms_norm_mul_rope_push_constants {\n    vk_op_binary_push_constants bin;\n    vk_op_rope_push_constants rope;\n};\n\nstruct vk_op_soft_max_push_constants {\n    uint32_t KX;\n    uint32_t KY;\n    uint32_t ne00;\n    uint32_t ne01;\n    uint32_t ne02;\n    uint32_t ne12;\n    uint32_t ne13;\n    uint32_t nb11;\n    uint32_t nb12;\n    uint32_t nb13;\n    float scale;\n    float max_bias;\n    float m0;\n    float m1;\n    uint32_t n_head_log2;\n    uint32_t nrows_x;\n    uint32_t has_sinks;\n};\n\nstruct vk_op_argsort_push_constants {\n    uint32_t ncols;\n    uint32_t ncols_padded;\n    uint32_t ncols_padded_log2;\n    uint32_t nrows;\n    uint32_t order;\n    uint32_t outer_start;\n    uint32_t outer_end;\n    uint32_t inner_start;\n    uint32_t inner_end;\n};\n\nstruct vk_op_topk_push_constants {\n    uint32_t orig_ncols;\n    uint32_t ncols_input;\n    uint32_t ncols_output;\n    uint32_t k;\n    uint32_t nrows;\n    uint32_t first_pass;\n    uint32_t last_pass;\n};\n\nstruct vk_op_im2col_push_constants {\n    uint64_t dst_addr;\n    uint32_t batch_offset; uint32_t offset_delta;\n    uint32_t IC;\n    uint32_t IW; uint32_t IH;\n    uint32_t OW; uint32_t OH;\n    uint32_t KW; uint32_t KH;\n    uint32_t pelements;\n    uint32_t CHW;\n    int32_t s0; int32_t s1;\n    int32_t p0; int32_t p1;\n    int32_t d0; int32_t d1;\n    uint32_t batch_IC;\n};\n\nstruct vk_op_im2col_3d_push_constants {\n    uint64_t dst_addr;\n    uint32_t nb10;\n    uint32_t 
nb11;\n    uint32_t nb12;\n    uint32_t nb13;\n    uint32_t s0;\n    uint32_t s1;\n    uint32_t s2;\n    uint32_t p0;\n    uint32_t p1;\n    uint32_t p2;\n    uint32_t d0;\n    uint32_t d1;\n    uint32_t d2;\n    uint32_t IW;\n    uint32_t IH;\n    uint32_t ID;\n    uint32_t IC;\n    uint32_t KW;\n    uint32_t OH;\n    uint32_t KD_KH_KW;\n    uint32_t KH_KW;\n    uint32_t IC_KD_KH_KW;\n    uint32_t N_OD_OH;\n    uint32_t OD_OH;\n    uint32_t OD_OH_OW_IC_KD_KH_KW;\n    uint32_t OH_OW_IC_KD_KH_KW;\n    uint32_t OW_IC_KD_KH_KW;\n    uint32_t misalign_offsets;\n};\n\nstruct vk_op_timestep_embedding_push_constants {\n    uint32_t nb1;\n    uint32_t dim;\n    uint32_t max_period;\n};\n\nstruct vk_op_conv_transpose_1d_push_constants {\n    uint32_t Cout;\n    uint32_t Cin;\n    uint32_t K;\n    uint32_t L;\n    uint32_t KL;\n\n    uint32_t nb01;\n    uint32_t nb02;\n    uint32_t nb11;\n    uint32_t nb1;\n\n    int32_t s0;\n};\n\nstruct vk_op_pool2d_push_constants {\n    uint32_t IW; uint32_t IH;\n    uint32_t OW; uint32_t OH;\n    uint32_t OC;\n    uint32_t pelements;\n    uint32_t op;\n    int32_t k0; int32_t k1;\n    int32_t s0; int32_t s1;\n    int32_t p0; int32_t p1;\n};\n\nstruct vk_op_rwkv_wkv6_push_constants {\n    uint32_t B;\n    uint32_t T;\n    uint32_t C;\n    uint32_t H;\n};\n\nstruct vk_op_rwkv_wkv7_push_constants {\n    uint32_t B;\n    uint32_t T;\n    uint32_t C;\n    uint32_t H;\n};\nstruct vk_op_gated_delta_net_push_constants {\n    uint32_t H;\n    uint32_t n_tokens;\n    uint32_t n_seqs;\n    uint32_t s_off;\n    uint32_t sq1, sq2, sq3;\n    uint32_t sv1, sv2, sv3;\n    uint32_t sb1, sb2, sb3;\n    uint32_t neq1, rq3;\n    float scale;\n};\n\nstruct vk_op_ssm_scan_push_constants {\n    uint32_t nb02, nb03, nb12, nb13;\n    uint32_t nb21, nb22, nb31;\n    uint32_t nb42, nb43, nb52, nb53;\n    uint32_t s_off;\n    uint32_t n_head, d_head, n_group, n_tok;\n};\nstruct vk_op_ssm_conv_push_constants {\n    uint32_t nb01, nb02;\n    uint32_t nb11;\n    
uint32_t dst_nb0, dst_nb1, dst_nb2;\n    uint32_t nc, ncs, nr, n_t, n_s;\n};\n\nstruct vk_op_conv2d_push_constants {\n    uint32_t Cout;\n    uint32_t Cin;\n    uint32_t N;\n\n    uint32_t W;\n    uint32_t H;\n    uint32_t OW;\n    uint32_t OH;\n\n    uint32_t nb01;\n    uint32_t nb02;\n    uint32_t nb03;\n\n    uint32_t nb11;\n    uint32_t nb12;\n    uint32_t nb13;\n\n    uint32_t nb1;\n    uint32_t nb2;\n    uint32_t nb3;\n\n    // init_fastdiv_values constants for dividing by OW, OW*OH\n    uint32_t OWmp;   uint32_t OWL;\n    uint32_t OWOHmp; uint32_t OWOHL;\n};\n\ntemplate <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {\n    // Compute magic values to divide by OW, OW*OH\n    init_fastdiv_values(p.OW,       p.OWmp,    p.OWL);\n    init_fastdiv_values(p.OW*p.OH,  p.OWOHmp,  p.OWOHL);\n}\n\nstruct vk_op_conv2d_dw_push_constants {\n    uint32_t ne;\n    uint32_t batches;\n    uint32_t channels;\n    uint32_t dst_w;\n    uint32_t dst_h;\n    uint32_t src_w;\n    uint32_t src_h;\n    uint32_t knl_w;\n    uint32_t knl_h;\n    int32_t stride_x;\n    int32_t stride_y;\n    int32_t pad_x;\n    int32_t pad_y;\n    int32_t dilation_x;\n    int32_t dilation_y;\n};\n\nstruct vk_op_upscale_push_constants {\n    uint32_t ne; uint32_t a_offset; uint32_t d_offset;\n    uint32_t ne00; uint32_t ne01;\n    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;\n    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;\n    float sf0; float sf1; float sf2; float sf3;\n    float pixel_offset;\n};\n\nstruct vk_op_sum_rows_push_constants\n{\n    uint32_t n_cols;\n    uint32_t ne01, ne02;\n    uint32_t nb01, nb02, nb03;\n    uint32_t nb11, nb12, nb13;\n    float weight;\n    uint32_t misalign_offsets;\n    uint32_t ne0_12mp, ne0_12L;\n    uint32_t ne0_1mp, ne0_1L;\n};\n\nstatic vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {\n    uint32_t type_size = 
(uint32_t)ggml_type_size(src->type);\n    vk_op_sum_rows_push_constants p = {};\n    p.n_cols = (uint32_t)n_cols;\n    p.ne01 = (uint32_t)src->ne[1];\n    p.ne02 = (uint32_t)src->ne[2];\n    p.nb01 = (uint32_t)src->nb[1] / type_size;\n    p.nb02 = (uint32_t)src->nb[2] / type_size;\n    p.nb03 = (uint32_t)src->nb[3] / type_size;\n    p.nb11 = (uint32_t)dst->nb[1] / type_size;\n    p.nb12 = (uint32_t)dst->nb[2] / type_size;\n    p.nb13 = (uint32_t)dst->nb[3] / type_size;\n    p.weight = 1.0f;\n    return p;\n}\n\ntemplate <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {\n    init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);\n    init_fastdiv_values(p.ne01,        p.ne0_1mp,  p.ne0_1L);\n}\n\nstruct vk_quantize_q8_1_push_constants {\n    uint32_t ne;\n    uint32_t num_blocks;\n};\n\nstruct vk_op_flash_attn_split_k_reduce_push_constants {\n    uint32_t D;\n    uint32_t ne1;\n    uint32_t ne2;\n    uint32_t ne3;\n    uint32_t k_num;\n    uint32_t sinks;\n};\n\nstruct vk_op_flash_attn_mask_opt_push_constants {\n    uint32_t nem0;\n    uint32_t nem1;\n    uint32_t nem2;\n    uint32_t nbm1;\n    uint32_t nbm2;\n    uint32_t nbm3;\n    uint32_t nbd1;\n    uint32_t nbd2;\n    uint32_t nbd3;\n};\n\n// Allow pre-recording command buffers\nstruct vk_staging_memcpy {\n    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}\n\n    void * dst;\n    const void * src;\n    size_t n;\n};\n\nstruct vk_staging_memset {\n    vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}\n\n    void * dst;\n    uint32_t val;\n    size_t n;\n};\n\nstruct vk_context_struct {\n    vk_submission * s;\n    std::vector<vk_sequence> seqs;\n\n    int exit_tensor_idx;\n\n    std::vector<vk_staging_memcpy> in_memcpys;\n    std::vector<vk_staging_memcpy> out_memcpys;\n    std::vector<vk_staging_memset> memsets;\n\n    vk_command_pool * p {};\n};\ntypedef std::shared_ptr<vk_context_struct> 
vk_context;\ntypedef std::weak_ptr<vk_context_struct> vk_context_ref;\n\nstruct ggml_vk_garbage_collector {\n    std::vector<vk_semaphore> tl_semaphores;\n    std::vector<vk_semaphore> semaphores;\n    std::vector<vk::Event> events;\n    std::vector<vk_context> contexts;\n};\n\nstatic void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);\nstatic void ggml_vk_load_shaders(vk_device& device);\nstatic void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);\n\nstatic bool vk_memory_logger_enabled = false;\n\n#define VK_LOG_MEMORY(msg) if (vk_memory_logger_enabled) { std::cerr << \"ggml_vulkan memory: \" << msg << std::endl; }\n\nstatic std::string format_size(size_t size) {\n    const size_t kib = 1024;\n    const size_t mib = kib * 1024;\n    const size_t gib = mib * 1024;\n\n    std::ostringstream oss;\n    oss << std::fixed << std::setprecision(2);\n\n    if (size >= gib) {\n        oss << static_cast<double>(size) / gib << \" GiB\";\n    } else if (size >= mib) {\n        oss << static_cast<double>(size) / mib << \" MiB\";\n    } else if (size >= kib) {\n        oss << static_cast<double>(size) / kib << \" KiB\";\n    } else {\n        oss << size << \" B\";\n    }\n\n    return oss.str();\n}\n\nclass vk_memory_logger {\npublic:\n    vk_memory_logger(): total_device(0), total_host(0) {}\n    void log_allocation(vk_buffer_ref buf_ref, size_t size);\n    void log_deallocation(vk_buffer_ref buf_ref);\n\nprivate:\n    std::map<vk::Buffer, size_t> allocations; // Track allocations\n    size_t total_device;\n    size_t total_host;\n    static std::mutex log_mutex;\n};\n\nstd::mutex vk_memory_logger::log_mutex;\n\nstatic bool vk_perf_logger_enabled = false;\nstatic bool vk_perf_logger_concurrent = false;\nstatic bool vk_enable_sync_logger = false;\n// number of calls between perf logger prints\nstatic uint32_t vk_perf_logger_frequency = 1;\nstatic std::string vk_pipeline_stats_filter;\n\nclass vk_perf_logger {\n  public:\n 
   void print_timings(bool force = false) {\n        if (timings.empty()) {\n            return;\n        }\n        print_count++;\n        if ((print_count % vk_perf_logger_frequency) != 0 && !force) {\n            return;\n        }\n        print_count = 0;\n        uint64_t total_all_op_times = 0;\n        std::cerr << \"----------------\\nVulkan Timings:\" << std::endl;\n        for (const auto & t : timings) {\n            uint64_t total_op_times = 0;\n            for (const auto & time : t.second) {\n                total_op_times += time;\n            }\n            std::cerr << t.first << \": \" << t.second.size() << \" x \" << (total_op_times / t.second.size() / 1000.0)\n                      << \" us = \" << (total_op_times / 1000.0) << \" us\";\n\n            // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.\n            auto it = flops.find(t.first);\n            if (it != flops.end() && (it->second).size() == t.second.size()) {\n                uint64_t total_op_flops = 0;\n                for (const auto & elem : it->second) {\n                    total_op_flops += elem;\n                }\n                std::cerr << \" (\"\n                          << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /\n                                 (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))\n                          << \" GFLOPS/s)\";\n            }\n\n            total_all_op_times += total_op_times;\n\n            std::cerr << std::endl;\n        }\n\n        if (timings.size() > 0) {\n            std::cerr << \"Total time: \" << total_all_op_times / 1000.0 << \" us.\" << std::endl;\n        }\n\n        timings.clear();\n        flops.clear();\n    }\n\n    std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {\n        *n_flops = 0;\n        std::string fusion_str;\n        if (fusion_name) {\n            fusion_str = fusion_name + 
std::string(\" \");\n        }\n        if (node->op == GGML_OP_UNARY) {\n            return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node));\n        }\n        if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {\n            const uint64_t m     = node->ne[0];\n            const uint64_t n     = node->ne[1];\n            const uint64_t k     = node->src[1]->ne[0];\n            const uint64_t batch = node->ne[2] * node->ne[3];\n            std::string    name  = ggml_op_name(node->op);\n            if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||\n                (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {\n                name += \"_VEC\";\n            }\n            name += \" \";\n            name += ggml_type_name(node->src[0]->type);\n            name += \" m=\" + std::to_string(m) + \" n=\" + std::to_string(n) + \" k=\" + std::to_string(k);\n            if (node->op == GGML_OP_MUL_MAT_ID) {\n                name += \" n_expert=\" + std::to_string(node->src[0]->ne[2]);\n            }\n            if (batch > 1) {\n                name += \" batch=\" + std::to_string(batch);\n            }\n            name = fusion_str + name;\n            *n_flops = m * n * (k + (k - 1)) * batch;\n            return name;\n        }\n        if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {\n            std::string   name    = ggml_op_name(node->op);\n            ggml_tensor * knl     = node->src[0];\n            uint64_t      OW      = node->ne[0];\n            uint64_t      OH      = node->ne[1];\n            uint64_t      N       = node->ne[3];\n            uint64_t      Cout    = node->ne[2];\n            uint64_t      KW      = knl->ne[0];\n            uint64_t      KH      = knl->ne[1];\n            uint64_t      Cin     = node->src[1]->ne[2];\n            // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ\n            uint64_t      size_M  = Cout;\n            uint64_t      size_K  = 
Cin * KW * KH;\n            uint64_t      size_N  = N * OW * OH;\n            *n_flops = size_M * size_N * (size_K + (size_K - 1));\n            name += \" M=Cout=\" + std::to_string(size_M) + \", K=Cin*KW*KH=\" + std::to_string(size_K) +\n                    \", N=N*OW*OH=\" + std::to_string(size_N);\n            name = fusion_str + name;\n            return name;\n        }\n        if (node->op == GGML_OP_RMS_NORM) {\n            std::string   name    = ggml_op_name(node->op);\n            name += \"(\" + std::to_string(node->ne[0]) + \",\" + std::to_string(node->ne[1]) + \",\" + std::to_string(node->ne[2]) + \",\" + std::to_string(node->ne[3]) + \")\";\n            name = fusion_str + name;\n            return name;\n        }\n        if (node->op == GGML_OP_FLASH_ATTN_EXT) {\n            const ggml_tensor * dst = node;\n            const ggml_tensor * q = node->src[0];\n            const ggml_tensor * k = node->src[1];\n            const ggml_tensor * v = node->src[2];\n            const ggml_tensor * m = node->src[3];\n            std::stringstream name;\n            name << fusion_str;\n            name << ggml_op_name(node->op) <<\n                \" dst(\" << dst->ne[0] << \",\" << dst->ne[1] << \",\" << dst->ne[2] << \",\" << dst->ne[3] << \"), \" <<\n                \" q(\" << q->ne[0] << \",\" << q->ne[1] << \",\" << q->ne[2] << \",\" << q->ne[3] << \"), \" <<\n                \" k(\" << k->ne[0] << \",\" << k->ne[1] << \",\" << k->ne[2] << \",\" << k->ne[3] << \"), \" <<\n                \" v(\" << v->ne[0] << \",\" << v->ne[1] << \",\" << v->ne[2] << \",\" << v->ne[3] << \"), \" <<\n                \" m(\" << (m?m->ne[0]:0) << \",\" << (m?m->ne[1]:0) << \",\" << (m?m->ne[2]:0) << \",\" << (m?m->ne[3]:0) << \")\";\n            *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];\n            return name.str();\n        }\n        if (node->op == GGML_OP_TOP_K) {\n            std::stringstream name;\n            name << 
fusion_str;\n            name << ggml_op_name(node->op) <<\n                \" K=\" << node->ne[0] <<\n                \" (\" << node->src[0]->ne[0] << \",\" << node->src[0]->ne[1] << \",\" << node->src[0]->ne[2] << \",\" << node->src[0]->ne[3] << \")\";\n            return name.str();\n        }\n        return fusion_str + ggml_op_name(node->op);\n    }\n\n    void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {\n        uint64_t n_flops;\n        std::string name = get_node_fusion_name(node, fusion_name, &n_flops);\n        if (n_flops) {\n            flops[name].push_back(n_flops);\n        }\n        timings[name].push_back(time);\n    }\n\n    void log_timing(const std::vector<ggml_tensor *> &nodes, const std::vector<const char *> &names, uint64_t time) {\n        uint64_t total_flops = 0;\n        std::string name;\n        for (size_t n = 0; n < nodes.size(); ++n) {\n            uint64_t n_flops = 0;\n            name += get_node_fusion_name(nodes[n], names[n], &n_flops);\n            total_flops += n_flops;\n\n            if (n != nodes.size() - 1) {\n                name += \", \";\n            }\n        }\n        if (total_flops) {\n            flops[name].push_back(total_flops);\n        }\n        timings[name].push_back(time);\n    }\n\n  private:\n    std::map<std::string, std::vector<uint64_t>> timings;\n    std::map<std::string, std::vector<uint64_t>> flops;\n    uint32_t print_count {};\n};\n\nstruct ggml_backend_vk_context {\n    std::string name;\n\n    vk_device device;\n\n    size_t semaphore_idx, event_idx;\n    ggml_vk_garbage_collector gc;\n    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;\n    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials, sync_staging;\n    vk::Fence fence, almost_ready_fence;\n    bool submit_pending {};\n    bool almost_ready_fence_pending {};\n    // Set before op_add 
and unset after op_rms_norm to indicate that the add should\n    // write partial sums to accumulate the square of the vector components\n    bool do_add_rms_partials_offset_calculation;\n    bool do_add_rms_partials;\n\n    uint64_t last_total_mul_mat_bytes {};\n\n    // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.\n    vk_pipeline_struct * prealloc_y_last_pipeline_used {};\n    const ggml_tensor * prealloc_y_last_tensor_used {};\n\n    // Track which nodes have been used since the last sync, and whether they were written to\n    std::vector<const ggml_tensor *> unsynced_nodes_written;\n    std::vector<const ggml_tensor *> unsynced_nodes_read;\n    // Track which prealloc buffers have pending reads that need to be synchronized.\n    // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),\n    // and set to true after the buffer contents are consumed.\n    bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;\n\n    vk_context_ref compute_ctx;\n\n    vk_context_ref transfer_ctx;\n    vk_semaphore transfer_semaphore;\n    uint64_t transfer_semaphore_last_submitted {};\n\n    std::vector<vk_context_ref> tensor_ctxs;\n\n    std::vector<vk::DescriptorPool> descriptor_pools;\n    std::vector<vk::DescriptorSet> descriptor_sets;\n    uint32_t descriptor_set_idx {};\n    uint32_t pipeline_descriptor_set_requirements {};\n\n    vk_command_pool compute_cmd_pool;\n    vk_command_pool transfer_cmd_pool;\n\n    // number of additional consecutive nodes that are being fused with the\n    // node currently being processed\n    int num_additional_fused_ops {};\n    // Bitmask of which fused ops need to write an intermediate value to memory.\n    // Bit 'i' means nodes[start_of_fusion + i] writes to memory.\n    // If there's no fusion, bit 0 is still set.\n    int fused_ops_write_mask {};\n    topk_moe_mode fused_topk_moe_mode {};\n    bool fused_topk_moe_scale {};\n\n    
// for GGML_VK_PERF_LOGGER\n    std::unique_ptr<vk_perf_logger> perf_logger;\n    vk::QueryPool query_pool;\n    std::vector<const char *> query_fusion_names;\n    std::vector<int> query_fusion_node_count;\n    std::vector<ggml_tensor *> query_nodes;\n    std::vector<int> query_node_idx;\n    int32_t num_queries {};\n    int32_t query_idx {};\n};\n\nstatic void * const vk_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT\n\nstatic uint64_t vk_tensor_offset(const ggml_tensor * tensor) {\n    if (tensor->view_src) {\n        return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;\n    }\n    return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;\n}\n\nstatic uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t)\n{\n    return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;\n}\n\ntemplate <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    GGML_UNUSED(p);\n    GGML_UNUSED(src0);\n    GGML_UNUSED(src1);\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n    GGML_UNUSED(dst);\n    static_assert(!std::is_const<T>::value, \"unexpected type\");\n    GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);\n    GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);\n    GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);\n    GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);\n    GGML_ASSERT(!dst  || get_misalign_bytes(ctx, dst) == 0);\n}\n\ntemplate <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_p021_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);\n    const uint32_t 
d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);\n\n    p.b_offset = b_offset;\n    p.d_offset = d_offset;\n\n    GGML_UNUSED(src0);\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n}\n\ntemplate <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_nc_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);\n    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);\n\n    p.b_offset = b_offset;\n    p.d_offset = d_offset;\n\n    GGML_UNUSED(src0);\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n}\n\nstruct ggml_backend_vk_buffer_context {\n    vk_device_ref device;\n    vk_buffer dev_buffer;\n    std::string name;\n\n    ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :\n        device(device),\n        dev_buffer(dev_buffer),\n        name(name) {\n    }\n\n    ~ggml_backend_vk_buffer_context() {\n        ggml_vk_destroy_buffer(dev_buffer);\n    }\n};\n\nvoid vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {\n    if (!vk_memory_logger_enabled) {\n        return;\n    }\n    std::lock_guard<std::mutex> guard(log_mutex);\n    vk_buffer buf = buf_ref.lock();\n    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);\n    const std::string type = device ? \"device\" : \"host\";\n    allocations[buf->buffer] = size;\n    total_device += device ? size : 0;\n    total_host += device ? 0 : size;\n    VK_LOG_MEMORY(buf->device->name << \": +\" << format_size(size) << \" \" << type << \" at \" << buf->buffer << \". 
Total device: \" << format_size(total_device) << \", total host: \" << format_size(total_host));\n}\n\nvoid vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {\n    if (buf_ref.expired() || buf_ref.lock()->size == 0 || !vk_memory_logger_enabled) {\n        return;\n    }\n\n    std::lock_guard<std::mutex> guard(log_mutex);\n    vk_buffer buf = buf_ref.lock();\n    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);\n    std::string type = device ? \"device\" : \"host\";\n    auto it = allocations.find(buf->buffer);\n    total_device -= device ? it->second : 0;\n    total_host -= device ? 0 : it->second;\n    if (it != allocations.end()) {\n        VK_LOG_MEMORY(buf->device->name << \": -\" << format_size(it->second) << \" \" << type << \" at \" << buf->buffer << \". Total device: \" << format_size(total_device) << \", total host: \" << format_size(total_host));\n        allocations.erase(it);\n    } else {\n        VK_LOG_MEMORY(\"ERROR \" << buf->device->name << \": Attempted to deallocate unknown \" << type << \" memory at \" << buf->buffer);\n    }\n}\n\nstruct vk_instance_t {\n    vk::Instance instance;\n\n    bool debug_utils_support = false;  // VK_EXT_debug_utils enabled\n    PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};\n    PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};\n    PFN_vkQueueEndDebugUtilsLabelEXT   pfn_vkQueueEndDebugUtilsLabelEXT   = {};\n    PFN_vkCmdBeginDebugUtilsLabelEXT   pfn_vkCmdBeginDebugUtilsLabelEXT   = {};\n    PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};\n    PFN_vkCmdInsertDebugUtilsLabelEXT  pfn_vkCmdInsertDebugUtilsLabelEXT  = {};\n\n    std::vector<size_t> device_indices;\n    std::vector<bool>   device_supports_membudget;\n    vk_device devices[GGML_VK_MAX_DEVICES];\n};\n\nstatic bool vk_instance_initialized = false;\nstatic vk_instance_t vk_instance;\n\n#ifdef GGML_VULKAN_CHECK_RESULTS\nstatic 
size_t vk_skip_checks;\nstatic size_t vk_output_tensor;\n\nstatic void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);\nstatic void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);\nstatic void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);\n#endif\n\ntypedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);\n\nstatic void ggml_backend_vk_free(ggml_backend_t backend);\n\nstatic VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {\n    const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},\n                                        VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});\n    return range;\n}\n\n// Wait for ctx->fence to be signaled.\nstatic void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {\n    // Use waitForFences while most of the graph executes. 
Hopefully the CPU can sleep\n    // during this wait.\n    if (ctx->almost_ready_fence_pending) {\n        VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), \"almost_ready_fence\");\n        ctx->device->device.resetFences({ ctx->almost_ready_fence });\n        ctx->almost_ready_fence_pending = false;\n    }\n\n    // Spin (w/pause) waiting for the graph to finish executing.\n    vk::Result result;\n    while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {\n        if (result != vk::Result::eNotReady) {\n            fprintf(stderr, \"ggml_vulkan: error %s at %s:%d\\n\", to_string(result).c_str(), __FILE__, __LINE__);\n            exit(1);\n        }\n        for (uint32_t i = 0; i < 100; ++i) {\n            YIELD();\n            YIELD();\n            YIELD();\n            YIELD();\n            YIELD();\n            YIELD();\n            YIELD();\n            YIELD();\n            YIELD();\n            YIELD();\n        }\n    }\n    ctx->device->device.resetFences({ ctx->fence });\n}\n\n// variables to track number of compiles in progress\nstatic uint32_t compile_count = 0;\nstatic std::mutex compile_count_mutex;\nstatic std::condition_variable compile_count_cond;\n\nstatic void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,\n                                         uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,\n                                         bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {\n    VK_LOG_DEBUG(\"ggml_vk_create_pipeline(\" << device->name << \", \" << pipeline->name << \", \" << entrypoint << \", \" << parameter_count <<\n                 \", (\" << wg_denoms[0] << \",\" << wg_denoms[1] << \",\" << wg_denoms[2] << \"), specialization_constants, \" <<\n                 
disable_robustness << \", \" << require_full_subgroups << \", \" << required_subgroup_size << \")\");\n    GGML_ASSERT(parameter_count > 0);\n    GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT);\n    GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT\n\n    vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));\n    pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);\n\n    vk::PushConstantRange pcr(\n        vk::ShaderStageFlagBits::eCompute,\n        0,\n        pipeline->push_constant_size\n    );\n\n    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr);\n    pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);\n\n    std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());\n\n    for (size_t i = 0; i < specialization_constants.size(); i++) {\n        specialization_entries[i].constantID = i;\n        specialization_entries[i].offset = i * sizeof(uint32_t);\n        specialization_entries[i].size = sizeof(uint32_t);\n    }\n\n    vk::SpecializationInfo specialization_info(\n        specialization_entries.size(),\n        specialization_entries.data(),\n        specialization_constants.size() * sizeof(uint32_t),\n        specialization_constants.data()\n    );\n\n    vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};\n\n    if (device->subgroup_require_full_support && require_full_subgroups) {\n        pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;\n    }\n\n    vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(\n            pipeline_shader_stage_create_flags,\n            vk::ShaderStageFlagBits::eCompute,\n            pipeline->shader_module,\n            entrypoint.c_str(),\n            &specialization_info);\n\n    
vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;\n    pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;\n    if (device->subgroup_size_control && required_subgroup_size > 0) {\n        GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);\n        pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);\n    }\n\n    vk::ComputePipelineCreateInfo compute_pipeline_create_info(\n        device->pipeline_executable_properties_support ?\n            vk::PipelineCreateFlagBits::eCaptureStatisticsKHR :\n            vk::PipelineCreateFlags{},\n        pipeline_shader_create_info,\n        pipeline->layout);\n\n    vk::PipelineRobustnessCreateInfoEXT rci;\n\n    if (device->pipeline_robustness && disable_robustness) {\n        rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;\n        rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;\n        compute_pipeline_create_info.setPNext(&rci);\n    }\n\n#if defined(VK_EXT_shader_64bit_indexing)\n    vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo;\n    if (pipeline->is_64b_indexing)\n    {\n        pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT;\n        if (device->pipeline_executable_properties_support) {\n            pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR;\n        }\n        pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext);\n        compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo);\n    }\n#endif\n\n    try {\n        pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;\n    } catch (const vk::SystemError& e) {\n        std::cerr << \"ggml_vulkan: Compute pipeline creation 
failed for \" << pipeline->name << std::endl;\n        std::cerr << \"ggml_vulkan: \" << e.what() << std::endl;\n        throw e;\n    }\n    pipeline->compiled = true;\n\n    if (vk_instance.debug_utils_support) {\n        vk::DebugUtilsObjectNameInfoEXT duoni;\n        duoni.objectType = vk::ObjectType::ePipeline;\n        duoni.pObjectName = pipeline->name.c_str();\n        duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast<VkPipeline>(pipeline->pipeline));\n        vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));\n    }\n\n    if (device->pipeline_executable_properties_support) {\n        vk::PipelineExecutableInfoKHR executableInfo;\n        executableInfo.pipeline = pipeline->pipeline;\n\n        auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);\n\n        bool print_stats = !vk_pipeline_stats_filter.empty() &&\n                           pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos;\n        if (print_stats) {\n            std::cerr << \"ggml_vulkan: pipeline stats for \" << pipeline->name << \":\" << std::endl;\n        }\n\n        for (auto & s : statistics) {\n            if (print_stats) {\n                std::cerr << \"ggml_vulkan:   \" << s.name.data() << \": \";\n                switch (s.format) {\n                    case vk::PipelineExecutableStatisticFormatKHR::eBool32:\n                        std::cerr << (s.value.b32 ? 
\"true\" : \"false\");\n                        break;\n                    case vk::PipelineExecutableStatisticFormatKHR::eInt64:\n                        std::cerr << s.value.i64;\n                        break;\n                    case vk::PipelineExecutableStatisticFormatKHR::eUint64:\n                        std::cerr << s.value.u64;\n                        break;\n                    case vk::PipelineExecutableStatisticFormatKHR::eFloat64:\n                        std::cerr << s.value.f64;\n                        break;\n                }\n                std::cerr << std::endl;\n            }\n            // \"Register Count\" is reported by NVIDIA drivers.\n            if (strcmp(s.name, \"Register Count\") == 0) {\n                VK_LOG_DEBUG(pipeline->name << \" \" << s.name << \": \" << s.value.u64 << \" registers\");\n                pipeline->register_count = (uint32_t)s.value.u64;\n            }\n        }\n    }\n\n    device->all_pipelines.push_back(pipeline);\n\n    {\n        std::lock_guard<std::mutex> guard(compile_count_mutex);\n        assert(compile_count > 0);\n        compile_count--;\n    }\n    compile_count_cond.notify_all();\n}\n\nstatic void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {\n    VK_LOG_DEBUG(\"ggml_pipeline_destroy_pipeline(\" << pipeline->name << \")\");\n    device.destroyPipelineLayout(pipeline->layout);\n\n    device.destroyShaderModule(pipeline->shader_module);\n\n    device.destroyPipeline(pipeline->pipeline);\n}\n\nstatic void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) {\n    VK_LOG_DEBUG(\"ggml_pipeline_request_descriptor_sets(\" << pipeline->name << \", \" << n << \")\");\n    ctx->pipeline_descriptor_set_requirements += n;\n    if (!pipeline->compiled) {\n        pipeline->needed = true;\n        ggml_vk_load_shaders(ctx->device);\n    }\n    ggml_pipeline_allocate_descriptor_sets(ctx);\n}\n\nstatic void 
ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) {\n\n    if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) {\n        // Enough descriptors are available\n        return;\n    }\n\n    vk_device& device = ctx->device;\n\n    // Grow by 50% to avoid frequent allocations\n    uint32_t needed = std::max(3 * ctx->descriptor_sets.size() / 2, size_t{ctx->pipeline_descriptor_set_requirements});\n    uint32_t to_alloc = needed - ctx->descriptor_sets.size();\n    uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;\n    uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;\n\n    while (to_alloc > 0) {\n        const uint32_t alloc_count = std::min(pool_remaining, to_alloc);\n        to_alloc -= alloc_count;\n        pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;\n\n        if (pool_idx >= ctx->descriptor_pools.size()) {\n            vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);\n            vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);\n            ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));\n        }\n\n        std::vector<vk::DescriptorSetLayout> layouts(alloc_count);\n        for (uint32_t i = 0; i < alloc_count; i++) {\n            layouts[i] = device->dsl;\n        }\n        vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());\n        std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);\n        ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());\n\n        pool_idx++;\n    }\n}\n\nstatic vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, 
vk_command_pool& p) {\n    VK_LOG_DEBUG(\"ggml_vk_create_cmd_buffer()\");\n    vk::CommandBufferAllocateInfo command_buffer_alloc_info(\n        p.pool,\n        vk::CommandBufferLevel::ePrimary,\n        1);\n    const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);\n    p.cmd_buffers.push_back({ cmd_buffers.front(), true });\n    return &p.cmd_buffers[p.cmd_buffers.size()-1];\n}\n\nstatic void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {\n    if (ctx->seqs.empty()) {\n        if (fence) {\n            std::lock_guard<std::mutex> guard(queue_mutex);\n            ctx->p->q->queue.submit({}, fence);\n        }\n        return;\n    }\n    VK_LOG_DEBUG(\"ggml_vk_submit(\" << ctx << \", \" << fence << \")\");\n\n    std::vector<std::vector<uint64_t>> tl_wait_vals;\n    std::vector<std::vector<uint64_t>> tl_signal_vals;\n    std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;\n    std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;\n    std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;\n    std::vector<vk::SubmitInfo> submit_infos;\n    int idx = -1;\n    std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;\n\n    size_t reserve = 0;\n\n    for (const auto& sequence : ctx->seqs) {\n        reserve += sequence.size();\n    }\n\n    // Pre-reserve vectors to prevent reallocation, which invalidates pointers\n    tl_wait_semaphores.reserve(reserve);\n    tl_wait_vals.reserve(reserve);\n    tl_signal_semaphores.reserve(reserve);\n    tl_signal_vals.reserve(reserve);\n    tl_submit_infos.reserve(reserve);\n    submit_infos.reserve(reserve);\n    stage_flags.reserve(reserve);\n\n    for (const auto& sequence : ctx->seqs) {\n        for (const auto& submission : sequence) {\n            stage_flags.push_back({});\n            idx++;\n            tl_wait_vals.push_back({});\n            tl_wait_semaphores.push_back({});\n            tl_signal_vals.push_back({});\n         
   tl_signal_semaphores.push_back({});\n            for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {\n                stage_flags[idx].push_back(ctx->p->q->stage_flags);\n                tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);\n                tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);\n            }\n            for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {\n                tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);\n                tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);\n            }\n            tl_submit_infos.push_back({\n                (uint32_t) submission.wait_semaphores.size(),\n                tl_wait_vals[idx].data(),\n                (uint32_t) submission.signal_semaphores.size(),\n                tl_signal_vals[idx].data(),\n            });\n            tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;\n            tl_submit_infos[idx].pNext = nullptr;\n            vk::SubmitInfo si{\n                (uint32_t) submission.wait_semaphores.size(),\n                tl_wait_semaphores[idx].data(),\n                stage_flags[idx].data(),\n                1,\n                &submission.buffer->buf,\n                (uint32_t) submission.signal_semaphores.size(),\n                tl_signal_semaphores[idx].data(),\n            };\n            si.setPNext(&tl_submit_infos[idx]);\n            submit_infos.push_back(si);\n        }\n    }\n\n    std::lock_guard<std::mutex> guard(queue_mutex);\n    ctx->p->q->queue.submit(submit_infos, fence);\n\n    ctx->seqs.clear();\n}\n\nstatic uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {\n    VK_LOG_DEBUG(\"ggml_vk_find_queue_family_index()\");\n    const uint32_t qfsize = 
queue_family_props.size();\n\n    // Try with avoid preferences first\n    for (uint32_t i = 0; i < qfsize; i++) {\n        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {\n            return i;\n        }\n    }\n\n    // Fall back to only required\n    for (size_t i = 0; i < qfsize; i++) {\n        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {\n            return i;\n        }\n    }\n\n    // Fall back to reusing compute queue\n    for (size_t i = 0; i < qfsize; i++) {\n        if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {\n            return i;\n        }\n    }\n\n    // Fall back to ignoring min_num_queries\n    for (size_t i = 0; i < qfsize; i++) {\n        if (queue_family_props[i].queueFlags & required) {\n            return i;\n        }\n    }\n\n    // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.\n    // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.\n    if (compute_index >= 0) {\n        return compute_index;\n    }\n\n    std::cerr << \"ggml_vulkan: No suitable queue family index found.\" << std::endl;\n\n    for(auto &q_family : queue_family_props) {\n        std::cerr << \"Queue number: \"  + std::to_string(q_family.queueCount) << \" flags: \" + to_string(q_family.queueFlags) << std::endl;\n    }\n    abort();\n}\n\nstatic void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& 
stage_flags, bool transfer_only) {\n    VK_LOG_DEBUG(\"ggml_vk_create_queue()\");\n    std::lock_guard<std::recursive_mutex> guard(device->mutex);\n\n    q.queue_family_index = queue_family_index;\n    q.transfer_only = transfer_only;\n\n    q.cmd_pool.init(device, &q);\n\n    q.queue = device->device.getQueue(queue_family_index, queue_index);\n\n    q.stage_flags = stage_flags;\n}\n\nstatic vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) {\n    vk_context result = std::make_shared<vk_context_struct>();\n    VK_LOG_DEBUG(\"ggml_vk_create_context(\" << result << \")\");\n    ctx->gc.contexts.emplace_back(result);\n    result->p = &p;\n    return result;\n}\n\nstatic vk_context ggml_vk_create_temporary_context(vk_command_pool& p) {\n    vk_context result = std::make_shared<vk_context_struct>();\n    VK_LOG_DEBUG(\"ggml_vk_create_temporary_context(\" << result << \")\");\n    result->p = &p;\n    return result;\n}\n\nstatic vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {\n    VK_LOG_DEBUG(\"ggml_vk_create_timeline_semaphore()\");\n    vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };\n    vk::SemaphoreCreateInfo ci{};\n    ci.setPNext(&tci);\n    vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);\n    ctx->gc.semaphores.push_back({ semaphore, 0 });\n    return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];\n}\n\nstatic vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {\n    VK_LOG_DEBUG(\"ggml_vk_create_timeline_semaphore()\");\n    if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {\n        vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };\n        vk::SemaphoreCreateInfo ci{};\n        ci.setPNext(&tci);\n        vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);\n        ctx->gc.tl_semaphores.push_back({ semaphore, 0 });\n    }\n    return 
&ctx->gc.tl_semaphores[ctx->semaphore_idx++];\n}\n\nstatic vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {\n    if (ctx->event_idx >= ctx->gc.events.size()) {\n        ctx->gc.events.push_back(ctx->device->device.createEvent({}));\n    }\n    return ctx->gc.events[ctx->event_idx++];\n}\n\nstatic void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) {\n    VK_LOG_DEBUG(\"ggml_vk_command_pool_cleanup()\");\n\n    // Requires command buffers to be done\n    device->device.resetCommandPool(p.pool);\n    // Don't clear the command buffers and mark them as not in use.\n    // This allows us to reuse them\n    for (auto& cmd_buffer : p.cmd_buffers) {\n        cmd_buffer.in_use = false;\n    }\n}\n\nstatic void ggml_vk_queue_command_pools_cleanup(vk_device& device) {\n    VK_LOG_DEBUG(\"ggml_vk_queue_command_pools_cleanup()\");\n\n    // Arbitrary frequency to cleanup/reuse command buffers\n    static constexpr uint32_t cleanup_frequency = 10;\n\n    if (device->compute_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {\n        ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);\n    }\n    if (device->transfer_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {\n        ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);\n    }\n}\n\nstatic std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {\n    std::vector<uint32_t> indices;\n\n    for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {\n        vk::MemoryType memory_type = mem_props->memoryTypes[i];\n        if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&\n            (flags & memory_type.propertyFlags) == flags &&\n            mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {\n            indices.push_back(i);\n        }\n    }\n    return indices;\n}\n\nstatic vk_buffer 
ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,\n                                       void *import_ptr = nullptr) {\n    VK_LOG_DEBUG(\"ggml_vk_create_buffer(\" << device->name << \", \" << size << \", \" << to_string(req_flags_list.begin()[0]) << \", \" << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << \")\");\n    if (size > device->max_buffer_size) {\n        throw vk::OutOfDeviceMemoryError(\"Requested buffer size exceeds device buffer size limit\");\n    }\n\n    vk_buffer buf = std::make_shared<vk_buffer_struct>();\n\n    if (size == 0) {\n        buf->size = 0;\n        return buf;\n    }\n\n    vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;\n    vk::MemoryAllocateFlags mem_flags {};\n    if (device->buffer_device_address) {\n        usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;\n        mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;\n    }\n\n    vk::BufferCreateInfo buffer_create_info{\n        vk::BufferCreateFlags(),\n        size,\n        usage_flags,\n        vk::SharingMode::eExclusive,\n        0,\n        nullptr,\n    };\n\n    vk::ExternalMemoryBufferCreateInfo external_memory_bci;\n    if (import_ptr) {\n        external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;\n        buffer_create_info.setPNext(&external_memory_bci);\n    }\n\n    buf->buffer = device->device.createBuffer(buffer_create_info);\n\n    vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);\n\n    vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();\n\n    const vk::MemoryPriorityAllocateInfoEXT mem_priority_info { 1.0f };\n\n    vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };\n\n    if (device->memory_priority) {\n        
mem_flags_info.setPNext(&mem_priority_info);\n    }\n\n    if (import_ptr) {\n        vk::MemoryHostPointerPropertiesEXT host_pointer_props;\n        try {\n            host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);\n        } catch (vk::SystemError& e) {\n            GGML_LOG_WARN(\"ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\\n\", e.what());\n            device->device.destroyBuffer(buf->buffer);\n            return {};\n        }\n        vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();\n\n        uint32_t memory_type_idx;\n        vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();\n        for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {\n            if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {\n                continue;\n            }\n            if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {\n                continue;\n            }\n\n            vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];\n            // check for visible+coherent+cached. Other flags (e.g. 
devicelocal) are allowed\n            if ((memory_type.propertyFlags & property_flags) == property_flags) {\n                property_flags = memory_type.propertyFlags;\n                break;\n            }\n        }\n        if (memory_type_idx == 32) {\n            GGML_LOG_WARN(\"ggml_vulkan: Memory type for host allocation not found\\n\");\n            device->device.destroyBuffer(buf->buffer);\n            return {};\n        }\n\n        buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;\n        try {\n            vk::ImportMemoryHostPointerInfoEXT import_info;\n            import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;\n            import_info.pHostPointer = import_ptr;\n            import_info.setPNext(&mem_flags_info);\n            buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });\n        } catch (const vk::SystemError& e) {\n        }\n    } else {\n        for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {\n            const auto & req_flags = *it;\n\n            const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);\n\n            if (memory_type_indices.empty()) {\n                continue;\n            }\n            buf->memory_property_flags = req_flags;\n\n            bool done = false;\n\n            for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {\n                try {\n                    buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });\n                    done = true;\n                    break;\n                } catch (const vk::SystemError& e) {\n                    // loop and retry\n                    // during last attempt throw the exception\n                    if (it + 1 == req_flags_list.end() && mtype_it + 1 == 
memory_type_indices.end()) {\n                        device->device.destroyBuffer(buf->buffer);\n                        throw e;\n                    }\n                }\n            }\n\n            if (done) {\n                break;\n            }\n        }\n    }\n\n    if (!buf->device_memory) {\n        device->device.destroyBuffer(buf->buffer);\n        throw vk::OutOfDeviceMemoryError(\"No suitable memory type found\");\n    }\n\n    buf->ptr = nullptr;\n\n    if (import_ptr) {\n        buf->ptr = import_ptr;\n    } else {\n        if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {\n            buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);\n        }\n    }\n\n    device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);\n\n    buf->device = device;\n    buf->size = size;\n\n    if (device->buffer_device_address) {\n        const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);\n        buf->bda_addr = device->device.getBufferAddress(addressInfo);\n    }\n\n    device->memory_logger->log_allocation(buf, size);\n\n    return buf;\n}\n\nstatic vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {\n    try {\n        return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});\n    } catch (const vk::SystemError& e) {\n        std::cerr << \"ggml_vulkan: Memory allocation of size \" << size << \" failed.\" << std::endl;\n        std::cerr << \"ggml_vulkan: \" << e.what() << std::endl;\n        throw e;\n    }\n}\n\nstatic vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {\n    vk_buffer buf;\n    try {\n        if (device->prefer_host_memory) {\n            buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,\n                                                  
     vk::MemoryPropertyFlagBits::eDeviceLocal});\n        } else if (device->uma) {\n            // Fall back to host memory type\n            buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,\n                                                       vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});\n        } else if (device->disable_host_visible_vidmem) {\n            if (device->allow_sysmem_fallback) {\n                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,\n                                                           vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});\n            } else {\n                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n            }\n        } else {\n            // use rebar if available, otherwise fallback to device only visible memory\n            if (device->allow_sysmem_fallback) {\n                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,\n                                                           vk::MemoryPropertyFlagBits::eDeviceLocal,\n                                                           vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});\n            } else {\n                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,\n                                                           vk::MemoryPropertyFlagBits::eDeviceLocal});\n            }\n        }\n    } catch (const vk::SystemError& e) {\n        std::cerr << \"ggml_vulkan: Device memory allocation of size \" << size << \" failed.\" << std::endl;\n        std::cerr << \"ggml_vulkan: 
\" << e.what() << std::endl;\n        throw e;\n    }\n\n    return buf;\n}\n\nstatic void ggml_vk_destroy_buffer(vk_buffer& buf) {\n    if (buf == nullptr) {\n        return;\n    }\n\n    if (buf->device != nullptr) {\n        buf->device->memory_logger->log_deallocation(buf);\n    }\n\n    buf.reset();\n}\n\nstatic vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) {\n    return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };\n}\n\nstatic void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {\n    VK_LOG_DEBUG(\"ggml_vk_sync_buffers()\");\n\n    const bool transfer_queue = subctx->p->q->transfer_only;\n\n    if (ctx) {\n        ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;\n    }\n\n    subctx->s->buffer->buf.pipelineBarrier(\n        subctx->p->q->stage_flags,\n        subctx->p->q->stage_flags,\n        {},\n        { {\n          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },\n          { !transfer_queue ? 
(vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }\n        } },\n        {},\n        {}\n    );\n}\n\nstatic void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {\n    VK_LOG_DEBUG(\"ggml_vk_set_event()\");\n\n    ctx->s->buffer->buf.setEvent(\n        event,\n        ctx->p->q->stage_flags\n    );\n}\n\nstatic void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {\n    VK_LOG_DEBUG(\"ggml_vk_wait_events()\");\n    if (events.empty()) {\n        return;\n    }\n\n    ctx->s->buffer->buf.waitEvents(\n        events,\n        ctx->p->q->stage_flags,\n        ctx->p->q->stage_flags,\n        {},\n        {},\n        {}\n    );\n}\n\nstruct vk_fa_tuning_params {\n    FaCodePath path;\n    uint32_t workgroup_size;\n    uint32_t subgroup_size;\n    uint32_t block_rows;\n    uint32_t block_cols;\n    uint32_t d_split;\n    uint32_t row_split;\n    bool shmem_staging;\n    bool disable_subgroups;\n    uint32_t limit_occupancy_shmem;\n\n    void print() const {\n        std::cerr << \"path=\" << path << \" workgroup_size=\" << workgroup_size << \" subgroup_size=\" << subgroup_size <<\n                     \" block_rows=\" << block_rows << \" block_cols=\" << block_cols << \" d_split=\" << d_split <<\n                     \" row_split=\" << row_split << \" shmem_staging=\" << shmem_staging << \" disable_subgroups=\" << disable_subgroups <<\n                     \" limit_occupancy_shmem=\" << limit_occupancy_shmem << std::endl;\n    }\n};\n\nstatic bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);\nstatic bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);\n\nstatic vk_fa_tuning_params 
get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {\n    GGML_UNUSED(kv_type);\n\n    vk_fa_tuning_params result{};\n    result.path = FA_SCALAR;\n\n    if (device->vendor_id == VK_VENDOR_ID_INTEL) {\n        // Disable subgroup use due to performance issues when enforcing subgroup sizes\n        result.subgroup_size = 32;\n        result.disable_subgroups = true;\n    } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {\n        result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size;\n    } else {\n        result.subgroup_size = device->subgroup_size;\n    }\n\n    // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers\n    uint32_t row_split_max_hsk = 64;\n    if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) {\n        row_split_max_hsk = n_rows <= 8 ? 64 : 128;\n    }\n    result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4;\n\n    if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) {\n        result.workgroup_size = result.subgroup_size * 2;\n    } else {\n        result.workgroup_size = result.subgroup_size * 4;\n    }\n\n    const uint32_t D = hsk | hsv;\n\n    const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL;\n\n    if (n_rows == 1) {\n        result.block_rows = 1;\n        result.block_cols = 64;\n    } else {\n        // row_split 1 means higher register use per row, so block size has to be adjusted\n        if (result.row_split == 1) {\n            result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8);\n        } else {\n            result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16);\n        }\n\n        result.block_cols = (D & 8) ? 
64 : 32;\n    }\n\n    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit\n\n    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);\n\n    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;\n\n    if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {\n        result.block_rows /= 2;\n    }\n\n    // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled\n    // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy.\n    // This targets an occupancy of 4 subgroups per SIMD.\n    if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) {\n        if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) {\n            // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size\n            // Values are guessed, tested on RDNA2\n            result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4;\n        } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) {\n            // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD.\n            // Here low-batch FA with large head size is affected.\n            // n_rows < 4 switch because workgroup size switches from 128 to 256 there.\n            result.limit_occupancy_shmem = (n_rows < 4 ? 
14 : 26) * 1024 / 4 / 4;\n        }\n    }\n\n    return result;\n}\n\nstatic vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {\n    GGML_UNUSED(n_rows);\n    GGML_UNUSED(n_kv);\n    GGML_UNUSED(kv_type);\n    GGML_UNUSED(f32acc);\n\n    vk_fa_tuning_params result{};\n    result.path = FA_COOPMAT1;\n\n    const uint32_t D = hsk | hsv;\n\n    const uint32_t coopmat_block_rows = 16;\n    const uint32_t coopmat_block_cols = 16;\n\n    const uint32_t num_subgroups = 4;\n\n    result.block_rows = coopmat_block_rows;\n    result.block_cols = coopmat_block_cols * num_subgroups;\n    result.row_split = num_subgroups;\n    result.subgroup_size = device->subgroup_size;\n    result.workgroup_size = num_subgroups * result.subgroup_size;\n\n    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit\n    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);\n\n    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;\n\n    return result;\n}\n\nstatic vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {\n    GGML_UNUSED(n_kv);\n    GGML_UNUSED(f32acc);\n\n    vk_fa_tuning_params result{};\n    result.path = FA_COOPMAT2;\n\n    const uint32_t D = hsk | hsv;\n\n    const bool small_rows = n_rows < 32;\n\n    if (small_rows) {\n        result.block_rows = 32;\n        result.block_cols = 32;\n    } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {\n        result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;\n        result.block_cols = 32;\n    } else {\n        result.block_rows = 64;\n        result.block_cols = 64;\n    }\n\n    result.subgroup_size = device->subgroup_size;\n    result.workgroup_size = (small_rows && (D % 32) == 0) ? 
256 : 128;\n\n    return result;\n}\n\nstatic vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {\n    FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :\n                      device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;\n\n    if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {\n        // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090\n        path = FA_SCALAR;\n    }\n\n    if (path == FA_COOPMAT1) {\n        bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||\n                        (!f32acc && device->coopmat_support_16x16x16_f16acc);\n        const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);\n        bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);\n\n        if (!shape_ok || !shmem_ok) {\n            path = FA_SCALAR;\n        }\n    }\n\n    // scalar is faster than coopmat when N==1\n    if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {\n        path = FA_SCALAR;\n    }\n\n    switch (path) {\n    case FA_SCALAR:\n        return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);\n    case FA_COOPMAT1:\n        return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);\n    case FA_COOPMAT2:\n        return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);\n    default:\n        throw std::runtime_error(\"unsupported FaCodePath\");\n    }\n}\n\nstatic vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,\n                                                  bool use_mask, bool use_mask_opt, bool use_logit_softcap) {\n    const 
bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&\n                                 (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);\n\n    uint32_t flags = (use_mask_opt      ? 1 : 0) |\n                     (use_mask          ? 2 : 0) |\n                     (use_logit_softcap ? 4 : 0) |\n                     (old_amd_windows   ? 8 : 0);\n\n    const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;\n\n    return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};\n}\n\nstatic std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {\n    return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,\n            state.row_split, state.subgroup_size, state.shmem_staging ? 
1u : 0u, state.flags, state.limit_occupancy_shmem};\n}\n\nstatic bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {\n\n    uint32_t lut_size = 0;\n    switch (src0_type) {\n    case GGML_TYPE_IQ1_S:\n    case GGML_TYPE_IQ1_M:\n        lut_size = 2*2048 + 4*2048;\n        break;\n    case GGML_TYPE_IQ2_XXS:\n        lut_size = 8*256;\n        break;\n    case GGML_TYPE_IQ2_XS:\n        lut_size = 8*512;\n        break;\n    case GGML_TYPE_IQ2_S:\n        lut_size = 8*1024;\n        break;\n    case GGML_TYPE_IQ3_XXS:\n        lut_size = 4*256;\n        break;\n    case GGML_TYPE_IQ3_S:\n        lut_size = 4*512;\n        break;\n    case GGML_TYPE_IQ4_NL:\n    case GGML_TYPE_IQ4_XS:\n    case GGML_TYPE_MXFP4:\n        lut_size = 4*16;\n        break;\n    default:\n        break;\n    }\n\n    // Needs to be kept up to date on shader changes\n    const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;\n    const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);\n    const uint32_t warps = warptile[0] / warptile[10];\n\n    const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;\n    const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;\n    const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;\n    const uint32_t ballots_sh = mul_mat_id ? 
(warps * 4 * sizeof(uint32_t)) : 0;\n\n    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;\n    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;\n\n    VK_LOG_DEBUG(\"ggml_vk_matmul_shmem_support(warptile=(\" << warptile[0] << \",\" << warptile[1] << \",\" << warptile[2] << \"), \"\n                 \"mul_mat_id=\" << mul_mat_id << \", src0_type=\" << ggml_type_name(src0_type) << \", supported=\" << supported);\n\n    return supported;\n}\n\nstruct GpuPipelineConfig {\n    // GPU architecture identifier.\n    // Example: vk_device_architecture::AMD_GCN\n    vk_device_architecture arch;\n\n    // Mapping of pipeline names to their specific subgroup sizes.\n    // Example: {\"soft_max_f32\", 64}\n    std::unordered_map<std::string, uint32_t> pipelines;\n\n    // Default subgroup size for this GPU.\n    // Defaults to 0 if not explicitly provided.\n    uint32_t default_subgroup_size = 0;\n};\n\n// Pipeline configuration for RDNA1 GPUs.\nstatic const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {\n    {\"soft_max\", 64}, {\"im2col\", 64},\n    {\"argmax\", 64}, {\"mul_mat_vec\", 64},\n    {\"mul_mat_vec_f16\", 32}, {\"mul_mat_vec_f32_f16\", 32}\n};\n\n// Pipeline configuration for RDNA2 GPUs.\nstatic const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {\n    {\"soft_max\", 64}, {\"im2col\", 64},\n};\n\nstatic constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;\n\n// Define configurations for different GPUs.\nstatic std::vector<GpuPipelineConfig> gpu_pipeline_configs = {\n    {\n        vk_device_architecture::AMD_RDNA1,\n        {\n            rdna1_pipelines,\n        },\n        RDNA_DEFAULT_SUBGROUP_SIZE\n    },\n    {\n        vk_device_architecture::AMD_RDNA2,\n        {\n            rdna2_pipelines,\n        },\n        RDNA_DEFAULT_SUBGROUP_SIZE\n    },\n};\n\nstatic uint32_t get_subgroup_size(const std::string &pipeline_name, const 
vk_device_architecture &arch) {\n    for (const auto &config : gpu_pipeline_configs) {\n        if (config.arch == arch) {\n            auto pipIt = config.pipelines.find(pipeline_name);\n            if (pipIt != config.pipelines.end()) {\n                return pipIt->second;\n            }\n            std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());\n            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),\n                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });\n            for (const auto &entry : sorted_pipelines) {\n                if (pipeline_name.find(entry.first) != std::string::npos) {\n                    return entry.second;\n                }\n            }\n            return config.default_subgroup_size;\n        }\n    }\n    return 0; // If no matching configuration is found\n}\n\nstatic void ggml_vk_load_shaders(vk_device& device) {\n    VK_LOG_DEBUG(\"ggml_vk_load_shaders(\" << device->name << \")\");\n\n    std::lock_guard<std::recursive_mutex> guard(device->mutex);\n    // some shaders have a minimum subgroup size\n    const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);\n    const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);\n    const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);\n\n    const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? 
device->subgroup_min_size : device->subgroup_size;\n    const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);\n    const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);\n    const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);\n\n    const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||\n                                      (device->subgroup_size_control && device->subgroup_max_size >= 16);\n\n    // mulmat\n    std::vector<uint32_t> l_warptile, m_warptile, s_warptile,\n                          l_warptile_id, m_warptile_id, s_warptile_id,\n                          l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,\n                          l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,\n                          l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,\n                          l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,\n                          l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,\n                          l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,\n                          l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;\n    std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,\n                            l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,\n                            l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,\n                            l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;\n\n    uint32_t l_align, m_align, s_align;\n    if (device->coopmat2) {\n        // spec constants and tile sizes for non-quant matmul/matmul_id\n        l_warptile = { 256, 128, 256, 64, 1 };\n        m_warptile = { 256, 128, 128, 64, 0 };\n        s_warptile = { 128,  64,  64, 64, 0 };\n        l_wg_denoms = {128, 256, 1 };\n        m_wg_denoms = {128, 128, 1 };\n        s_wg_denoms = { 64,  
64, 1 };\n\n        // spec constants and tile sizes for quant matmul (non-Qi_K)\n        l_warptile_mmq = { 256, 128, 256, 64, 1 };\n        m_warptile_mmq = { 256, 128, 128, 64, 1 };\n        s_warptile_mmq = { 256, 32,  64, 128, 0 };\n        l_mmq_wg_denoms = { 128, 256, 1 };\n        m_mmq_wg_denoms = { 128, 128, 1 };\n        s_mmq_wg_denoms = { 32,  64,  1 };\n\n        // spec constants and tile sizes for quant matmul (Qi_K)\n        l_warptile_mmq_k = { 256, 128, 256, 64, 1 };\n        m_warptile_mmq_k = { 256, 128, 128, 64, 1 };\n        s_warptile_mmq_k = { 256, 32,  64, 128, 0 };\n        l_mmq_wg_denoms_k = { 128, 256, 1 };\n        m_mmq_wg_denoms_k = { 128, 128, 1 };\n        s_mmq_wg_denoms_k = { 32,  64,  1 };\n\n        // spec constants and tile sizes for quant matmul_id\n        l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };\n        m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };\n        s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };\n        l_mmqid_wg_denoms = { 128, 128, 1 };\n        m_mmqid_wg_denoms = { 128, 64, 1 };\n        s_mmqid_wg_denoms = { 128, 64, 1 };\n\n        l_align = 128;\n        m_align =  64;\n        s_align =  32;\n    } else {\n        // Matrix cores require different warp group sizes\n        const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;\n        const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;\n        const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;\n        const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;\n        const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;\n        const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;\n        const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;\n        const uint32_t tk_m = device->coopmat_support ? 
device->coopmat_k : 1;\n        const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;\n\n        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;\n\n        l_warptile = { 128,             128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };\n        m_warptile = { 128,              64,  64, 16, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };\n        s_warptile = { subgroup_size_32, 32,  32, 16, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };\n\n        l_warptile_mmq = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };\n        m_warptile_mmq = { 128,              64,  64, 32, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };\n        s_warptile_mmq = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };\n\n        // Integer MMQ has a smaller shared memory profile, but heavier register use\n        l_warptile_mmq_int = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };\n        m_warptile_mmq_int = { 128,              64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };\n        s_warptile_mmq_int = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, 2, 1, 1, subgroup_size_8 };\n\n        // K-quants use even more registers, mitigate by setting WMITER to 1\n        l_warptile_mmq_int_k = { 128,               128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };\n        m_warptile_mmq_int_k = { 128,                64,  64, 32, subgroup_size_8,     32, 1, 2, 2, 1, subgroup_size_8 };\n        s_warptile_mmq_int_k = { subgroup_size_32,   32,  32, 32, s_warptile_wm,       32, 1, 2, 1, 1, subgroup_size_8 };\n\n        l_warptile_id = { 128,                      128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };\n        m_warptile_id = { 128,     
                  64,  64, 16, mul_mat_subgroup_size_16,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };\n        s_warptile_id = { mul_mat_subgroup_size_16,  32,  32, 16, s_warptile_wm,                32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };\n\n        l_warptile_mmqid = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };\n        m_warptile_mmqid = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };\n        s_warptile_mmqid = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };\n\n        l_warptile_mmqid_int = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };\n        m_warptile_mmqid_int = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };\n        s_warptile_mmqid_int = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };\n\n        l_warptile_mmqid_int_k = { 128,                     128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };\n        m_warptile_mmqid_int_k = { 128,                      64,  64, 32, mul_mat_subgroup_size_16,     32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };\n        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32,  32, 32, s_warptile_wm,                32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };\n\n        // chip specific tuning\n        if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {\n            m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };\n            m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };\n        } else if (device->vendor_id == 
VK_VENDOR_ID_AMD && device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary) {\n            // This is intentionally using tx_m values, slight performance increase\n            l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };\n            l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };\n            l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };\n        } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {\n            // Xe2/Xe3 with coopmat enabled - warptile performance tuning\n            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };\n            l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };\n        }\n\n        l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };\n        m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };\n        s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };\n        l_align = 128;\n        m_align =  64;\n        s_align =  32;\n\n        for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {\n            ggml_type t = (ggml_type)i;\n            // Disable medium and large matrix multiplication if not enough shared memory is available\n            // Check mmq warptiles as the largest configuration\n            // Throw an error if not enough for any matrix multiplication is available\n            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {\n                std::cerr << \"ggml_vulkan: Error: Shared memory size too small for matrix multiplication.\" << std::endl;\n                throw std::runtime_error(\"Shared memory size too small for matrix multiplication.\");\n            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {\n                
device->mul_mat_m[i] = false;\n                device->mul_mat_l[i] = false;\n            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {\n                device->mul_mat_l[i] = false;\n            }\n\n            // Disable mul_mat_id if not enough shared memory is available\n            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {\n                device->mul_mat_id_s[i] = false;\n                device->mul_mat_id_m[i] = false;\n                device->mul_mat_id_l[i] = false;\n            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {\n                device->mul_mat_id_m[i] = false;\n                device->mul_mat_id_l[i] = false;\n            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {\n                device->mul_mat_id_l[i] = false;\n            }\n        }\n    }\n\n    if (!device->pipeline_matmul_f32) {\n        device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();\n    }\n    if (!device->pipeline_matmul_f32_f16) {\n        device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();\n    }\n    if (!device->pipeline_matmul_id_f32) {\n        device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();\n    }\n    if (!device->pipeline_matmul_bf16) {\n        device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();\n    }\n    if (!device->pipeline_matmul_id_bf16) {\n        device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();\n    }\n\n    std::vector<std::future<void>> compiles;\n    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,\n                                              uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& 
specialization_constants,\n                                              uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {\n\n        if (!require_full_subgroups && required_subgroup_size == 0) {\n            required_subgroup_size = get_subgroup_size(name, device->architecture);\n        }\n\n        vk_pipeline *ptr = &base_pipeline;\n\n        int num_pipelines = 1;\n#if defined(VK_EXT_shader_64bit_indexing)\n        if (device->shader_64b_indexing) {\n            num_pipelines = 2;\n        }\n#endif\n        for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) {\n            vk_pipeline &pipeline = *ptr;\n            if (!pipeline) {\n                pipeline = std::make_shared<vk_pipeline_struct>();\n            }\n            if (!pipeline->initialized) {\n                pipeline->name = name;\n                pipeline->parameter_count = parameter_count;\n                pipeline->push_constant_size = push_constant_size;\n                pipeline->wg_denoms = wg_denoms;\n                pipeline->align = align;\n                pipeline->initialized = true;\n#if defined(VK_EXT_shader_64bit_indexing)\n                pipeline->is_64b_indexing = (i == 1);\n#endif\n            }\n\n            if (!pipeline->needed || pipeline->compiled) {\n                continue;\n            }\n            // TODO: We're no longer benefitting from the async compiles (shaders are\n            // compiled individually, as needed) and this complexity can be removed.\n            {\n                // wait until fewer than N compiles are in progress\n                uint32_t N = std::max(1u, std::thread::hardware_concurrency());\n                std::unique_lock<std::mutex> guard(compile_count_mutex);\n                while (compile_count >= N) {\n                    compile_count_cond.wait(guard);\n                }\n                compile_count++;\n            }\n\n            
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,\n                                          parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));\n        }\n    };\n\n    auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,\n                                              uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,\n                                              uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {\n        return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint,\n                                       parameter_count, push_constant_size, wg_denoms, specialization_constants,\n                                       align, disable_robustness, require_full_subgroups, required_subgroup_size);\n    };\n\n#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \\\n        for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \\\n            FaCodePath path = fa.first.path; \\\n            uint32_t Br = fa.first.Br; \\\n            uint32_t Bc = fa.first.Bc; \\\n            bool aligned = fa.first.aligned; \\\n            bool f32acc = fa.first.f32acc; \\\n            uint32_t fa_sgs = fa.first.subgroup_size; \\\n            bool fa_ds = fa.first.subgroup_size == 0; \\\n            if (path == FAPATH) { \\\n                if (aligned) { \\\n                    if (f32acc) { \\\n                        ggml_vk_create_pipeline(device, fa.second, \"flash_attn_f32_f16_aligned_f32acc\" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##     
       SUFFIX ## _data,  \"main\", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \\\n                    } else { \\\n                        ggml_vk_create_pipeline(device, fa.second, \"flash_attn_f32_f16_aligned_f16acc\" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  \"main\", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \\\n                    } \\\n                } else { \\\n                    if (f32acc) { \\\n                        ggml_vk_create_pipeline(device, fa.second, \"flash_attn_f32_f16_f32acc\"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  \"main\", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \\\n                    } else { \\\n                        ggml_vk_create_pipeline(device, fa.second, \"flash_attn_f32_f16_f16acc\"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  \"main\", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? 
fa_sgs : 0));     \\\n                    } \\\n                } \\\n            } \\\n        }\n\n    if (device->fp16) {\n        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )\n        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )\n        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )\n        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )\n    } else {\n        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)\n        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)\n        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)\n        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)\n    }\n#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n    if (device->coopmat1_fa_support) {\n        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)\n        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)\n        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)\n        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)\n    }\n#endif\n#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n    if (device->coopmat2) {\n        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)\n        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)\n        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)\n        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)\n        CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)\n        CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)\n        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)\n        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)\n    }\n#endif\n#undef CREATE_FA\n\n    const int mul_mat_id_param_count = 5;\n\n#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n    if (device->coopmat2) {\n\n        // Create 6 variants, {s,m,l}x{unaligned,aligned}\n#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \\\n        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC 
#F16ACC \"_l\", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true);   \\\n        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC \"_m\", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true);   \\\n        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC \"_s\", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true);   \\\n        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC \"_aligned_l\", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true);   \\\n        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC \"_aligned_m\", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true);   \\\n        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC \"_aligned_s\", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true);   \\\n\n        // Create 2 variants, {f16,f32} accumulator\n#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \\\n        CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \\\n        CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \\\n\n        CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)\n#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n        if (device->coopmat_bf16_support) {\n            CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)\n        }\n#endif\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S],   matmul_iq1_s_f16,   mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M],   matmul_iq1_m_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S],   matmul_iq2_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S],   matmul_iq3_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4],   matmul_mxfp4_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)\n\n        GGML_ASSERT(device->subgroup_ballot);\n\n        CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)\n#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n        if (device->coopmat_bf16_support) {\n            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)\n        }\n#endif\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)\n#undef CREATE_MM\n#undef CREATE_MM2\n    } else\n#endif  // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n    if (device->coopmat_support) {\n        // Create 6 variants, {s,m,l}x{unaligned,aligned}\n#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \\\n        if (device->mul_mat ## ID ## _l[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC \"_l\", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true);   \\\n        if (device->mul_mat ## ID ## _m[TYPE]) \\\n            
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC \"_m\", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true);   \\\n        if (device->mul_mat ## ID ## _s[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC \"_s\", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true);   \\\n        if (device->mul_mat ## ID ## _l[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC \"_aligned_l\", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true);   \\\n        if (device->mul_mat ## ID ## _m[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC \"_aligned_m\", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true);   \\\n        if (device->mul_mat ## ID ## _s[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC \"_aligned_s\", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true);   \\\n\n        // Create 2 variants, {f16,f32} accumulator\n#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \\\n        if (device->coopmat_acc_f16_support) { \\\n            CREATE_MM(TYPE, PIPELINE_NAME . 
f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \\\n        } \\\n        if (device->coopmat_acc_f32_support) { \\\n            CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \\\n        } \\\n\n        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );\n        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );\n        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );\n        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );\n#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n        if (device->coopmat_bf16_support) {\n            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )\n        }\n#endif\n\n        if (device->coopmat_acc_f16_support) {\n            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n\n            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], 
matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S],   matmul_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M],   matmul_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S],   matmul_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S],   matmul_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ4_XS,  
pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4],   matmul_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n        } else {\n            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n\n            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , 
mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc,   matmul_iq1_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc,   matmul_iq1_m_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc,   matmul_mxfp4_f32,   , 
mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );\n        }\n\n        GGML_ASSERT(device->subgroup_ballot);\n\n        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n        if (device->coopmat_bf16_support) {\n            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        }\n#endif\n\n        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q2_K, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n        CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);\n#undef CREATE_MM2\n#undef CREATE_MM\n    } else\n#endif  // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n    if (device->fp16) {\n        // Create 6 variants, {s,m,l}x{unaligned,aligned}\n#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \\\n        if (device->mul_mat ## ID ## _l[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC \"_l\", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _m[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC \"_m\", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, 
\"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _s[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC \"_s\", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _l[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC \"_aligned_l\", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _m[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC \"_aligned_m\", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _s[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC \"_aligned_s\", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n\n#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \\\n        if (device->mul_mat ## ID ## _l[TYPE]) { \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC        \"_l\", NAMELC ## _len,        NAMELC ##  _data,        \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, 
REQSUBGROUPSIZE);   \\\n        } \\\n        if (device->mul_mat ## ID ## _m[TYPE]) { \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC        \"_m\", NAMELC ## _len,        NAMELC ##  _data,        \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        } \\\n        if (device->mul_mat ## ID ## _s[TYPE]) { \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC        \"_s\", NAMELC ## _len,        NAMELC ##  _data,        \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        } \\\n\n        // Create 2 variants, {f16,f32} accumulator\n#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \\\n        CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \\\n        CREATE_MM(TYPE, PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \\\n\n        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n\n        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n\n        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n\n        CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, 
mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S],   matmul_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M],   matmul_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S],   matmul_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S],   matmul_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM2(GGML_TYPE_MXFP4,   
pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4],   matmul_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n        if (device->integer_dot_product) {\n            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);\n            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);\n            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);\n            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);\n            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);\n\n            CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);\n\n            CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);\n            CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);\n            CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);\n            CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, 
vk_mat_mat_push_constants, 3, , 0);\n            CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);\n        }\n#endif\n\n        if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {\n            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n\n            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, 
vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ2_XXS, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n            if (device->integer_dot_product) {\n            
    CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n\n                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n\n                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n               
 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n            }\n#endif\n        } else {\n            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n\n            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, 
vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ2_XXS, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n            if (device->integer_dot_product) {\n                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n                
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n\n                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n\n                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n                
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            }\n#endif\n        }\n#undef CREATE_MM2\n#undef CREATE_MMQ\n#undef CREATE_MM\n    } else {\n        // Create 6 variants, {s,m,l}x{unaligned,aligned}\n#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \\\n        if (device->mul_mat ## ID ## _l[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC \"_l\", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _m[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC \"_m\", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _s[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC \"_s\", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _l[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC \"_aligned_l\", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _m[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> 
PIPELINE_NAME ->a_m, #NAMELC #F16ACC \"_aligned_m\", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n        if (device->mul_mat ## ID ## _s[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC \"_aligned_s\", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \\\n\n#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \\\n        if (device->mul_mat ## ID ## _l[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC \"_l\", NAMELC ## _fp32_len, NAMELC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \\\n        if (device->mul_mat ## ID ## _m[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC \"_m\", NAMELC ## _fp32_len, NAMELC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \\\n        if (device->mul_mat ## ID ## _s[TYPE]) \\\n            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC \"_s\", NAMELC ## _fp32_len, NAMELC ## _fp32_data, \"main\", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \\\n\n        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_F16, 
pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n\n        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n\n        CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n\n        CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ1_S,   
pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc,   matmul_iq1_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc,   matmul_iq1_m_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc,   matmul_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);\n\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n        if (device->integer_dot_product) {\n            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, 
mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );\n            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );\n            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );\n            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );\n            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );\n\n            CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );\n            CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );\n            CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );\n            CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );\n            CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );\n        }\n#endif\n\n        if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {\n            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 
mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);\n\n            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, 
matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_subgroup_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_subgroup_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_subgroup_iq2_xs_f32,  , 
mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_subgroup_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_subgroup_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_subgroup_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_subgroup_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_subgroup_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);\n        } else {\n            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n\n            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 
mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n    
        CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n        }\n    }\n    // reusing CREATE_MM from the fp32 path\n    if ((device->coopmat2 || device->coopmat_support)\n#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n        && !device->coopmat_bf16_support\n#endif\n        ) {\n        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;\n\n        // use scalar tile sizes\n        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };\n        m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };\n        s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 };\n\n        l_wg_denoms = {128, 128, 1 };\n        m_wg_denoms = { 64,  64, 1 };\n        s_wg_denoms = { 32,  32, 1 };\n\n        if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {\n            // Xe2/Xe3 - bf16 warptile performance tuning\n            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };\n        }\n\n        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);\n        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);\n    }\n#undef 
CREATE_MM\n\n    // mul mat vec\n\n    // the number of rows computed per shader depends on GPU model and quant\n    uint32_t rm_stdq = 1;\n    uint32_t rm_kq = 2;\n    uint32_t rm_stdq_int = 1;\n    uint32_t rm_kq_int = 1;\n    auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };\n    if (device->vendor_id == VK_VENDOR_ID_AMD) {\n        if (device->architecture == AMD_GCN) {\n            rm_stdq = 2;\n            rm_kq = 4;\n            rm_stdq_int = 4;\n        }\n    } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {\n        rm_stdq = 2;\n        rm_stdq_int = 2;\n    }\n    uint32_t rm_iq = 2 * rm_kq;\n\n    const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;\n    // Ensure a subgroup size >= 16 is available\n    const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;\n\n    const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size;\n    const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);\n\n    const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;\n    const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;\n    static constexpr uint32_t mul_mat_vec_num_bindings = 5;\n    static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;\n\n    for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {\n        const uint32_t wg_size_subgroup   = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);\n        const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);\n\n        const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :\n                                            (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? 
SHADER_REDUCTION_MODE_HYBRID :\n                                            SHADER_REDUCTION_MODE_SHMEM;\n\n        const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :\n                                              (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :\n                                              SHADER_REDUCTION_MODE_SHMEM;\n\n        for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], \"mul_mat_vec_f32_f32_f32\",  arr_dmmv_f32_f32_f32_len[reduc],  arr_dmmv_f32_f32_f32_data[reduc],  \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], \"mul_mat_vec_f16_f32_f32\",  arr_dmmv_f16_f32_f32_len[reduc],  arr_dmmv_f16_f32_f32_data[reduc],  \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], \"mul_mat_vec_bf16_f32_f32\", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], \"mul_mat_vec_q4_0_f32_f32\", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n     
       ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], \"mul_mat_vec_q4_1_f32_f32\", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], \"mul_mat_vec_q5_0_f32_f32\", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], \"mul_mat_vec_q5_1_f32_f32\", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], \"mul_mat_vec_q8_0_f32_f32\", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], \"mul_mat_vec_q2_k_f32_f32\", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], \"mul_mat_vec_q3_k_f32_f32\", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], \"mul_mat_vec_q4_k_f32_f32\", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], \"mul_mat_vec_q5_k_f32_f32\", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], \"mul_mat_vec_q6_k_f32_f32\", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i],   \"mul_mat_vec_iq1_s_f32_f32\",   arr_dmmv_iq1_s_f32_f32_len[reduc16],   arr_dmmv_iq1_s_f32_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i],   \"mul_mat_vec_iq1_m_f32_f32\",   arr_dmmv_iq1_m_f32_f32_len[reduc16],   arr_dmmv_iq1_m_f32_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], \"mul_mat_vec_iq2_xxs_f32_f32\", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i],  \"mul_mat_vec_iq2_xs_f32_f32\",  arr_dmmv_iq2_xs_f32_f32_len[reduc16],  arr_dmmv_iq2_xs_f32_f32_data[reduc16],  \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i],   \"mul_mat_vec_iq2_s_f32_f32\",   arr_dmmv_iq2_s_f32_f32_len[reduc16],   arr_dmmv_iq2_s_f32_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], \"mul_mat_vec_iq3_xxs_f32_f32\", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i],   \"mul_mat_vec_iq3_s_f32_f32\",   arr_dmmv_iq3_s_f32_f32_len[reduc16],   arr_dmmv_iq3_s_f32_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i],  \"mul_mat_vec_iq4_xs_f32_f32\",  arr_dmmv_iq4_xs_f32_f32_len[reduc16],  arr_dmmv_iq4_xs_f32_f32_data[reduc16],  \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i],  \"mul_mat_vec_iq4_nl_f32_f32\",  arr_dmmv_iq4_nl_f32_f32_len[reduc16],  arr_dmmv_iq4_nl_f32_f32_data[reduc16],  \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i],   \"mul_mat_vec_mxfp4_f32_f32\",   arr_dmmv_mxfp4_f32_f32_len[reduc16],   arr_dmmv_mxfp4_f32_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], \"mul_mat_vec_f32_f16_f32\",  arr_dmmv_f32_f16_f32_len[reduc],  arr_dmmv_f32_f16_f32_data[reduc],  \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], \"mul_mat_vec_f16_f16_f32\",  arr_dmmv_f16_f16_f32_len[reduc],  arr_dmmv_f16_f16_f32_data[reduc],  \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], \"mul_mat_vec_bf16_f16_f32\", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], \"mul_mat_vec_q4_0_f16_f32\", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], \"mul_mat_vec_q4_1_f16_f32\", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], \"mul_mat_vec_q5_0_f16_f32\", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], \"mul_mat_vec_q5_1_f16_f32\", 
arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], \"mul_mat_vec_q8_0_f16_f32\", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], \"mul_mat_vec_q2_k_f16_f32\", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], \"mul_mat_vec_q3_k_f16_f32\", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], \"mul_mat_vec_q4_k_f16_f32\", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], \"mul_mat_vec_q5_k_f16_f32\", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], 
\"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], \"mul_mat_vec_q6_k_f16_f32\", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i],   \"mul_mat_vec_iq1_s_f16_f32\",   arr_dmmv_iq1_s_f16_f32_len[reduc16],   arr_dmmv_iq1_s_f16_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i],   \"mul_mat_vec_iq1_m_f16_f32\",   arr_dmmv_iq1_m_f16_f32_len[reduc16],   arr_dmmv_iq1_m_f16_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], \"mul_mat_vec_iq2_xxs_f16_f32\", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i],  \"mul_mat_vec_iq2_xs_f16_f32\",  arr_dmmv_iq2_xs_f16_f32_len[reduc16],  arr_dmmv_iq2_xs_f16_f32_data[reduc16],  \"main\", 
mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i],   \"mul_mat_vec_iq2_s_f16_f32\",   arr_dmmv_iq2_s_f16_f32_len[reduc16],   arr_dmmv_iq2_s_f16_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], \"mul_mat_vec_iq3_xxs_f16_f32\", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i],   \"mul_mat_vec_iq3_s_f16_f32\",   arr_dmmv_iq3_s_f16_f32_len[reduc16],   arr_dmmv_iq3_s_f16_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i],  \"mul_mat_vec_iq4_xs_f16_f32\",  arr_dmmv_iq4_xs_f16_f32_len[reduc16],  arr_dmmv_iq4_xs_f16_f32_data[reduc16],  \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i],  \"mul_mat_vec_iq4_nl_f16_f32\",  arr_dmmv_iq4_nl_f16_f32_len[reduc16],  arr_dmmv_iq4_nl_f16_f32_data[reduc16],  \"main\", 
mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i],   \"mul_mat_vec_mxfp4_f16_f32\",   arr_dmmv_mxfp4_f16_f32_len[reduc16],   arr_dmmv_mxfp4_f16_f32_data[reduc16],   \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);\n\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n            if (device->integer_dot_product) {\n                const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;\n                const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);\n\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], \"mul_mat_vec_q4_0_q8_1_f32\", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], \"mul_mat_vec_q4_1_q8_1_f32\", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], \"mul_mat_vec_q5_0_q8_1_f32\", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], 
\"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], \"mul_mat_vec_q5_1_q8_1_f32\", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], \"mul_mat_vec_q8_0_q8_1_f32\", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], \"mul_mat_vec_mxfp4_q8_1_f32\", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], \"mul_mat_vec_q2_k_q8_1_f32\", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], \"mul_mat_vec_q3_k_q8_1_f32\", arr_dmmv_q3_k_q8_1_f32_len[reduc], 
arr_dmmv_q3_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], \"mul_mat_vec_q4_k_q8_1_f32\", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], \"mul_mat_vec_q5_k_q8_1_f32\", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], \"mul_mat_vec_q6_k_q8_1_f32\", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);\n\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], \"mul_mat_vec_iq1_s_q8_1_f32\", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);\n                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], \"mul_mat_vec_iq1_m_q8_1_f32\", 
arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);\n\n            }\n#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT\n        }\n\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], \"mul_mat_vec_id_f32_f32\",        arr_dmmv_id_f32_f32_f32_len[reduc],     arr_dmmv_id_f32_f32_f32_data[reduc],     \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], \"mul_mat_vec_id_f16_f32\",        arr_dmmv_id_f16_f32_f32_len[reduc],     arr_dmmv_id_f16_f32_f32_data[reduc],     \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], \"mul_mat_vec_id_bf16_f32\",       arr_dmmv_id_bf16_f32_f32_len[reduc],    arr_dmmv_id_bf16_f32_f32_data[reduc],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], \"mul_mat_vec_id_q4_0_f32\",       arr_dmmv_id_q4_0_f32_f32_len[reduc],    arr_dmmv_id_q4_0_f32_f32_data[reduc],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], 
\"mul_mat_vec_id_q4_1_f32\",       arr_dmmv_id_q4_1_f32_f32_len[reduc],    arr_dmmv_id_q4_1_f32_f32_data[reduc],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], \"mul_mat_vec_id_q5_0_f32\",       arr_dmmv_id_q5_0_f32_f32_len[reduc],    arr_dmmv_id_q5_0_f32_f32_data[reduc],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], \"mul_mat_vec_id_q5_1_f32\",       arr_dmmv_id_q5_1_f32_f32_len[reduc],    arr_dmmv_id_q5_1_f32_f32_data[reduc],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], \"mul_mat_vec_id_q8_0_f32\",       arr_dmmv_id_q8_0_f32_f32_len[reduc],    arr_dmmv_id_q8_0_f32_f32_data[reduc],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], \"mul_mat_vec_id_q2_k_f32\",       arr_dmmv_id_q2_k_f32_f32_len[reduc16],    arr_dmmv_id_q2_k_f32_f32_data[reduc16],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], 
\"mul_mat_vec_id_q3_k_f32\",       arr_dmmv_id_q3_k_f32_f32_len[reduc16],    arr_dmmv_id_q3_k_f32_f32_data[reduc16],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], \"mul_mat_vec_id_q4_k_f32\",       arr_dmmv_id_q4_k_f32_f32_len[reduc16],    arr_dmmv_id_q4_k_f32_f32_data[reduc16],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], \"mul_mat_vec_id_q5_k_f32\",       arr_dmmv_id_q5_k_f32_f32_len[reduc16],    arr_dmmv_id_q5_k_f32_f32_data[reduc16],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], \"mul_mat_vec_id_q6_k_f32\",       arr_dmmv_id_q6_k_f32_f32_len[reduc16],    arr_dmmv_id_q6_k_f32_f32_data[reduc16],    \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S],   \"mul_mat_vec_id_iq1_s_f32\",   arr_dmmv_id_iq1_s_f32_f32_len[reduc16],   arr_dmmv_id_iq1_s_f32_f32_data[reduc16],   \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M],   
\"mul_mat_vec_id_iq1_m_f32\",   arr_dmmv_id_iq1_m_f32_f32_len[reduc16],   arr_dmmv_id_iq1_m_f32_f32_data[reduc16],   \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], \"mul_mat_vec_id_iq2_xxs_f32\", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS],  \"mul_mat_vec_id_iq2_xs_f32\",  arr_dmmv_id_iq2_xs_f32_f32_len[reduc16],  arr_dmmv_id_iq2_xs_f32_f32_data[reduc16],  \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S],   \"mul_mat_vec_id_iq2_s_f32\",   arr_dmmv_id_iq2_s_f32_f32_len[reduc16],   arr_dmmv_id_iq2_s_f32_f32_data[reduc16],   \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], \"mul_mat_vec_id_iq3_xxs_f32\", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S],   
\"mul_mat_vec_id_iq3_s_f32\",   arr_dmmv_id_iq3_s_f32_f32_len[reduc16],   arr_dmmv_id_iq3_s_f32_f32_data[reduc16],   \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS],  \"mul_mat_vec_id_iq4_xs_f32\",  arr_dmmv_id_iq4_xs_f32_f32_len[reduc16],  arr_dmmv_id_iq4_xs_f32_f32_data[reduc16],  \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL],  \"mul_mat_vec_id_iq4_nl_f32\",  arr_dmmv_id_iq4_nl_f32_f32_len[reduc16],  arr_dmmv_id_iq4_nl_f32_f32_data[reduc16],  \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4],   \"mul_mat_vec_id_mxfp4_f32\",   arr_dmmv_id_mxfp4_f32_f32_len[reduc16],   arr_dmmv_id_mxfp4_f32_f32_data[reduc16],   \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);\n\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n        if (device->integer_dot_product) {\n            const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;\n            const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? 
subgroup_size_int : (subgroup_size_int * 4);\n\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], \"mul_mat_vec_id_q4_0_q8_1_f32\", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], \"mul_mat_vec_id_q4_1_q8_1_f32\", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], \"mul_mat_vec_id_q5_0_q8_1_f32\", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], \"mul_mat_vec_id_q5_1_q8_1_f32\", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], \"mul_mat_vec_id_q8_0_q8_1_f32\", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 
1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);\n\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], \"mul_mat_vec_id_mxfp4_q8_1_f32\", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);\n\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], \"mul_mat_vec_id_q2_k_q8_1_f32\", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], \"mul_mat_vec_id_q3_k_q8_1_f32\", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], \"mul_mat_vec_id_q4_k_q8_1_f32\", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], \"mul_mat_vec_id_q5_k_q8_1_f32\", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, 
sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], \"mul_mat_vec_id_q6_k_q8_1_f32\", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);\n\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], \"mul_mat_vec_id_iq1_s_q8_1_f32\", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);\n            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], \"mul_mat_vec_id_iq1_m_q8_1_f32\", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], \"main\", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);\n        }\n#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT\n    }\n\n#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n    GGML_UNUSED(rm_stdq_int);\n    GGML_UNUSED(rm_kq_int);\n    GGML_UNUSED(rm_iq_int);\n#endif\n\n    // dequant shaders\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], \"f32_to_f16\",   dequant_f32_len,  dequant_f32_data,  \"main\", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], \"dequant_q4_0\", dequant_q4_0_len, dequant_q4_0_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 
1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], \"dequant_q4_1\", dequant_q4_1_len, dequant_q4_1_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], \"dequant_q5_0\", dequant_q5_0_len, dequant_q5_0_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], \"dequant_q5_1\", dequant_q5_1_len, dequant_q5_1_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], \"dequant_q8_0\", dequant_q8_0_len, dequant_q8_0_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], \"dequant_q2_k\", dequant_q2_k_len, dequant_q2_k_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], \"dequant_q3_k\", dequant_q3_k_len, dequant_q3_k_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], \"dequant_q4_k\", dequant_q4_k_len, dequant_q4_k_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], \"dequant_q5_k\", dequant_q5_k_len, dequant_q5_k_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], \"dequant_q6_k\", dequant_q6_k_len, dequant_q6_k_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S],   \"dequant_iq1_s\",   dequant_iq1_s_len,   dequant_iq1_s_data,   \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M],   \"dequant_iq1_m\",   dequant_iq1_m_len,   dequant_iq1_m_data,   \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], \"dequant_iq2_xxs\", dequant_iq2_xxs_len, dequant_iq2_xxs_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS],  \"dequant_iq2_xs\",  dequant_iq2_xs_len,  dequant_iq2_xs_data,  \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S],   \"dequant_iq2_s\",   dequant_iq2_s_len,   dequant_iq2_s_data,   \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], \"dequant_iq3_xxs\", dequant_iq3_xxs_len, dequant_iq3_xxs_data, \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S],   \"dequant_iq3_s\",   dequant_iq3_s_len,   dequant_iq3_s_data,   \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS],  \"dequant_iq4_xs\",  dequant_iq4_xs_len,  dequant_iq4_xs_data,  \"main\", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL],  \"dequant_iq4_nl\",  dequant_iq4_nl_len,  dequant_iq4_nl_data,  \"main\", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4],   \"dequant_mxfp4\",   dequant_mxfp4_len,   dequant_mxfp4_data,   \"main\", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);\n\n    // get_rows\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], \"get_rows_f32\",  get_rows_f32_len,  
get_rows_f32_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], \"get_rows_f16\",  get_rows_f16_len,  get_rows_f16_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], \"get_rows_bf16\", get_rows_bf16_len, get_rows_bf16_data, \"main\", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], \"get_rows_q4_0\", get_rows_q4_0_len, get_rows_q4_0_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], \"get_rows_q4_1\", get_rows_q4_1_len, get_rows_q4_1_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], \"get_rows_q5_0\", get_rows_q5_0_len, get_rows_q5_0_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], \"get_rows_q5_1\", get_rows_q5_1_len, get_rows_q5_1_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], \"get_rows_q8_0\", get_rows_q8_0_len, get_rows_q8_0_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], \"get_rows_q2_k\", get_rows_q2_k_len, get_rows_q2_k_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], \"get_rows_q3_k\", get_rows_q3_k_len, get_rows_q3_k_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], \"get_rows_q4_k\", get_rows_q4_k_len, get_rows_q4_k_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], \"get_rows_q5_k\", get_rows_q5_k_len, get_rows_q5_k_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], \"get_rows_q6_k\", get_rows_q6_k_len, get_rows_q6_k_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S],   \"get_rows_iq1_s\",   get_rows_iq1_s_len,   get_rows_iq1_s_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M],   \"get_rows_iq1_m\",   get_rows_iq1_m_len,   get_rows_iq1_m_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], \"get_rows_iq2_xxs\", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS],  \"get_rows_iq2_xs\",  get_rows_iq2_xs_len,  get_rows_iq2_xs_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S],   \"get_rows_iq2_s\",   get_rows_iq2_s_len,   get_rows_iq2_s_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], \"get_rows_iq3_xxs\", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, 
device->pipeline_get_rows[GGML_TYPE_IQ3_S],   \"get_rows_iq3_s\",   get_rows_iq3_s_len,   get_rows_iq3_s_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS],  \"get_rows_iq4_xs\",  get_rows_iq4_xs_len,  get_rows_iq4_xs_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL],  \"get_rows_iq4_nl\",  get_rows_iq4_nl_len,  get_rows_iq4_nl_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4],   \"get_rows_mxfp4\",   get_rows_mxfp4_len,   get_rows_mxfp4_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32],     \"get_rows_i32\",     get_rows_i32_len,     get_rows_i32_data,     \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], \"get_rows_f32_f32\",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], \"get_rows_f16_f32\",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], \"get_rows_bf16_f32\", get_rows_bf16_f32_len, get_rows_bf16_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], \"get_rows_q4_0_f32\", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], \"get_rows_q4_1_f32\", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], \"get_rows_q5_0_f32\", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], \"get_rows_q5_1_f32\", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], \"get_rows_q8_0_f32\", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], \"get_rows_q2_k_f32\", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], \"get_rows_q3_k_f32\", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], \"get_rows_q4_k_f32\", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], \"get_rows_q5_k_f32\", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], \"get_rows_q6_k_f32\", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, \"main\", 3, 
sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S],   \"get_rows_iq1_s_f32\",   get_rows_iq1_s_f32_len,   get_rows_iq1_s_f32_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M],   \"get_rows_iq1_m_f32\",   get_rows_iq1_m_f32_len,   get_rows_iq1_m_f32_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], \"get_rows_iq2_xxs_f32\", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS],  \"get_rows_iq2_xs_f32\",  get_rows_iq2_xs_f32_len,  get_rows_iq2_xs_f32_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S],   \"get_rows_iq2_s_f32\",   get_rows_iq2_s_f32_len,   get_rows_iq2_s_f32_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], \"get_rows_iq3_xxs_f32\", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S],   \"get_rows_iq3_s_f32\",   get_rows_iq3_s_f32_len,   get_rows_iq3_s_f32_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS],  \"get_rows_iq4_xs_f32\",  get_rows_iq4_xs_f32_len,  get_rows_iq4_xs_f32_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, 
device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL],  \"get_rows_iq4_nl_f32\",  get_rows_iq4_nl_f32_len,  get_rows_iq4_nl_f32_data,  \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4],   \"get_rows_mxfp4_f32\",   get_rows_mxfp4_f32_len,   get_rows_mxfp4_f32_data,   \"main\", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, \"split_k_reduce\", split_k_reduce_len, split_k_reduce_data, \"main\", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, \"fa_split_k_reduce\", fa_split_k_reduce_len, fa_split_k_reduce_data, \"main\", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);\n\n    for (auto &it : device->pipeline_fa_mask_opt) {\n        auto BrBc = it.first;\n        ggml_vk_create_pipeline(device, it.second, \"fa_mask_opt\", fa_mask_opt_len, fa_mask_opt_data, \"main\", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);\n    }\n\n    if (device->subgroup_clustered && device->subgroup_require_full_support) {\n        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, \"quantize_q8_1_x4\", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, \"main\", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);\n    } else {\n        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, \"quantize_q8_1_x4\", quantize_q8_1_x4_len, quantize_q8_1_x4_data, \"main\", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);\n    }\n\n    for (uint32_t i = 0; i < 
p021_max_gqa_ratio; ++i) {\n        if (device->subgroup_arithmetic && device->subgroup_require_full_support) {\n            ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], \"mul_mat_vec_p021_f16_f32\"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_p021_push_constants), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);\n        } else {\n            ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], \"mul_mat_vec_p021_f16_f32\"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len,              mul_mat_vec_p021_f16_f32_data,              \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_p021_push_constants), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);\n        }\n    }\n    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, \"mul_mat_vec_nc_f16_f32\", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, \"main\", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_norm_f32, \"norm_f32\", norm_f32_len, norm_f32_data, \"main\", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, \"group_norm_f32\", group_norm_f32_len, group_norm_f32_data, \"main\", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, \"rms_norm_f32\", rms_norm_f32_len, rms_norm_f32_data, \"main\", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);\n    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, \"rms_norm_mul_f32\", rms_norm_f32_len, rms_norm_f32_data, \"main\", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);\n    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, 
\"rms_norm_partials_f32\", rms_norm_partials_f32_len, rms_norm_partials_f32_data, \"main\", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);\n    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, \"rms_norm_mul_partials_f32\", rms_norm_partials_f32_len, rms_norm_partials_f32_data, \"main\", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);\n\n    if (device->float_controls_rte_fp16 &&\n        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {\n        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, \"rms_norm_mul_rope_f32_f32\", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, \"main\", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);\n        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, \"rms_norm_mul_rope_f32_f16\", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, \"main\", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);\n    }\n\n    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, \"rms_norm_back_f32\", rms_norm_back_f32_len, rms_norm_back_f32_data, \"main\", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, \"l2_norm_f32\", l2_norm_f32_len, l2_norm_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, \"cpy_f32_f32\", cpy_f32_f32_len, cpy_f32_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, \"cpy_f32_f16\", cpy_f32_f16_len, cpy_f32_f16_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, \"cpy_f16_f16\", cpy_f16_f16_len, 
cpy_f16_f16_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, \"cpy_f16_f32\", cpy_f16_f32_len, cpy_f16_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,\"cpy_f32_bf16\",cpy_f32_bf16_len,cpy_f32_bf16_data,\"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, \"cpy_i32_f32\", cpy_i32_f32_len, cpy_i32_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, \"cpy_f32_i32\", cpy_f32_i32_len, cpy_f32_i32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, \"contig_cpy_f32_f32\", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, \"contig_cpy_f32_f16\", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, \"contig_cpy_f16_f16\", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, \"contig_cpy_f16_f32\", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,\"contig_cpy_f32_bf16\",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,\"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, 
device->pipeline_contig_cpy_i32_f32, \"contig_cpy_i32_f32\", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, \"contig_cpy_f32_i32\", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, \"cpy_transpose_32\", cpy_transpose_32_len, cpy_transpose_32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, \"cpy_transpose_16\", cpy_transpose_16_len, cpy_transpose_16_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);\n\n    if (device->float_controls_rte_fp16) {\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], \"cpy_f32_q4_0\", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], \"cpy_f32_q4_1\", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], \"cpy_f32_q5_0\", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], \"cpy_f32_q5_1\", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], \"cpy_f32_q8_0\", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], \"cpy_f32_iq4_nl\", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n    } else {\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], \"cpy_f32_q4_0\", cpy_f32_q4_0_len, cpy_f32_q4_0_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], \"cpy_f32_q4_1\", cpy_f32_q4_1_len, cpy_f32_q4_1_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], \"cpy_f32_q5_0\", cpy_f32_q5_0_len, cpy_f32_q5_0_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], \"cpy_f32_q5_1\", cpy_f32_q5_1_len, cpy_f32_q5_1_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], \"cpy_f32_q8_0\", cpy_f32_q8_0_len, cpy_f32_q8_0_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], \"cpy_f32_iq4_nl\", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);\n    }\n\n#define SET_ROWS(itype, rte) \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32],  \"set_rows_f32\" #itype,  set_rows_f32 ## itype ## rte ## _len,  set_rows_f32 ## itype ## rte ## _data,  \"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16],  \"set_rows_f16\" #itype,  set_rows_f16 ## itype ## rte ## _len,  
set_rows_f16 ## itype ## rte ## _data,  \"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], \"set_rows_bf16\" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], \"set_rows_q4_0\" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], \"set_rows_q4_1\" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], \"set_rows_q5_0\" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], \"set_rows_q5_1\" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], \"set_rows_q8_0\" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, \"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \\\n        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], \"set_rows_iq4_nl\" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, 
\"main\", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);\n\n    if (device->float_controls_rte_fp16) {\n        SET_ROWS(_i32, _rte)\n        SET_ROWS(_i64, _rte)\n    } else {\n        SET_ROWS(_i32, )\n        SET_ROWS(_i64, )\n    }\n#undef SET_ROWS\n\n\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], \"cpy_q4_0_f32\", cpy_q4_0_f32_len, cpy_q4_0_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], \"cpy_q4_1_f32\", cpy_q4_1_f32_len, cpy_q4_1_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], \"cpy_q5_0_f32\", cpy_q5_0_f32_len, cpy_q5_0_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], \"cpy_q5_1_f32\", cpy_q5_1_f32_len, cpy_q5_1_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], \"cpy_q8_0_f32\", cpy_q8_0_f32_len, cpy_q8_0_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], \"cpy_iq4_nl_f32\", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);\n\n    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {\n        std::string s;\n        s += std::string(src0_f16 ? \"_f16\" : \"_f32\");\n        s += std::string(src1_f16 ? 
\"_f16\" : \"_f32\");\n        s += std::string(dst_f16 ? \"_f16\" : \"_f32\");\n        return s;\n    };\n\n    bool rte = device->float_controls_rte_fp16;\n#define CREATE_BINARY(name, namemod, spec, bindings) \\\n    for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \\\n        ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \\\n                                #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \\\n                                \"main\", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);\n\n    CREATE_BINARY(add, , {0}, 4)\n    CREATE_BINARY(add, _norepeat, {1}, 4)\n    CREATE_BINARY(sub, , {0}, 3)\n    CREATE_BINARY(sub, _norepeat, {1}, 3)\n    CREATE_BINARY(mul, , {0}, 3)\n    CREATE_BINARY(mul, _norepeat, {1}, 3)\n    CREATE_BINARY(div, , {0}, 3)\n    CREATE_BINARY(div, _norepeat, {1}, 3)\n    CREATE_BINARY(add_rms, , {0}, 4)\n    CREATE_BINARY(add_rms, _norepeat, {1}, 4)\n#undef CREATE_BINARY\n\n    if (device->multi_add) {\n        for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {\n            ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i],     \"multi_add_f32_\"     + std::to_string(i+1), multi_add_f32_len,     multi_add_f32_data,     \"main\", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);\n            ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], \"multi_add_rms_f32_\" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, \"main\", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);\n        }\n    }\n\n    ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, \"add_id_f32\", add_id_f32_len, add_id_f32_data, \"main\", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, \"acc_f32\", acc_f32_len, acc_f32_data, \"main\", 3, 
sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_set_f32, \"set_f32\", acc_f32_len, acc_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_concat_f32, \"concat_f32\", concat_f32_len, concat_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_concat_f16, \"concat_f16\", concat_f16_len, concat_f16_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_concat_i32, \"concat_i32\", concat_i32_len, concat_i32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, \"upscale_f32\", upscale_f32_len, upscale_f32_data, \"main\", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, \"upscale_f32\", upscale_f32_len, upscale_f32_data, \"main\", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, \"upscale_f32\", upscale_f32_len, upscale_f32_data, \"main\", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, \"upscale_f32\", upscale_f32_len, upscale_f32_data, \"main\", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_scale_f32, \"scale_f32\", scale_f32_len, scale_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, \"sqr_f32\", 
sqr_f32_len, sqr_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, \"sqrt_f32\", sqrt_f32_len, sqrt_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_sin_f32, \"sin_f32\", sin_f32_len, sin_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_cos_f32, \"cos_f32\", cos_f32_len, cos_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n    if (device->float_controls_rte_fp16) {\n        ggml_vk_create_pipeline(device, device->pipeline_log[0], \"log_f32_rte\", log_f32_rte_len, log_f32_rte_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_log[1], \"log_f16_rte\", log_f16_rte_len, log_f16_rte_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    } else {\n        ggml_vk_create_pipeline(device, device->pipeline_log[0], \"log_f32\", log_f32_len, log_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_log[1], \"log_f16\", log_f16_len, log_f16_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    }\n\n    ggml_vk_create_pipeline(device, device->pipeline_tri[0], \"tri_f32\", tri_f32_len, tri_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_tri[1], \"tri_f16\", tri_f16_len, tri_f16_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_diag[0], \"diag_f32\", diag_f32_len, diag_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, 
device->pipeline_diag[1], \"diag_f16\", diag_f16_len, diag_f16_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, \"clamp_f32\", clamp_f32_len, clamp_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_pad_f32, \"pad_f32\", pad_f32_len, pad_f32_data, \"main\", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_roll_f32, \"roll_f32\", roll_f32_len, roll_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, \"repeat_f32\", repeat_f32_len, repeat_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, \"repeat_back_f32\", repeat_back_f32_len, repeat_back_f32_data, \"main\", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);\n\n#define CREATE_UNARY(name)  \\\n    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name \"_f32\", name ## _f32_len, name ## _f32_data, \"main\", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);  \\\n    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name \"_f16\", name ## _f16_len, name ## _f16_data, \"main\", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);\n\n    CREATE_UNARY(elu)\n    CREATE_UNARY(gelu)\n    CREATE_UNARY(gelu_erf)\n    CREATE_UNARY(gelu_quick)\n    CREATE_UNARY(silu)\n    CREATE_UNARY(relu)\n    CREATE_UNARY(xielu)\n    CREATE_UNARY(neg)\n    CREATE_UNARY(tanh)\n    CREATE_UNARY(sigmoid)\n    CREATE_UNARY(hardsigmoid)\n    CREATE_UNARY(hardswish)\n    CREATE_UNARY(abs)\n    CREATE_UNARY(softplus)\n    CREATE_UNARY(step)\n    CREATE_UNARY(round)\n    CREATE_UNARY(ceil)\n    CREATE_UNARY(floor)\n    CREATE_UNARY(trunc)\n    CREATE_UNARY(sgn)\n#undef 
CREATE_UNARY\n\n#define CREATE_UNARY_RTE(name)  \\\n    if (device->float_controls_rte_fp16) {  \\\n        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name \"_f32_rte\", name ## _f32_rte_len, name ## _f32_rte_data, \"main\", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \\\n        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name \"_f16_rte\", name ## _f16_rte_len, name ## _f16_rte_data, \"main\", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \\\n    } else {    \\\n        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name \"_f32\", name ## _f32_len, name ## _f32_data, \"main\", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \\\n        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name \"_f16\", name ## _f16_len, name ## _f16_data, \"main\", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \\\n    }\n    CREATE_UNARY_RTE(exp)\n#undef CREATE_UNARY_RTE\n\n    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, \"add1_f16_f16\", add1_f16_f16_len, add1_f16_f16_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, \"add1_f16_f32\", add1_f16_f32_len, add1_f16_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, \"add1_f32_f32\", add1_f32_f32_len, add1_f32_f32_data, \"main\", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_arange_f32, \"arange_f32\", arange_f32_len, arange_f32_data, \"main\", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_fill_f32, \"fill_f32\", fill_f32_len, fill_f32_data, \"main\", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);\n\n#define CREATE_GLU(name)  \\\n    if (device->float_controls_rte_fp16) {  \\\n   
     ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name \"_f32_rte\", name ## _f32_rte_len, name ## _f32_rte_data, \"main\", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \\\n        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name \"_f16_rte\", name ## _f16_rte_len, name ## _f16_rte_data, \"main\", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \\\n    } else {    \\\n        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name \"_f32\", name ## _f32_len, name ## _f32_data, \"main\", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \\\n        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name \"_f16\", name ## _f16_len, name ## _f16_data, \"main\", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \\\n    }\n\n    CREATE_GLU(geglu)\n    CREATE_GLU(reglu)\n    CREATE_GLU(swiglu)\n    CREATE_GLU(swiglu_oai)\n    CREATE_GLU(geglu_erf)\n    CREATE_GLU(geglu_quick)\n#undef CREATE_GLU\n\n    ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, \"leaky_relu_f32\", leaky_relu_f32_len, leaky_relu_f32_data, \"main\", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, \"silu_back_f32\", silu_back_f32_len, silu_back_f32_data, \"main\", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, \"diag_mask_inf_f32\", diag_mask_inf_f32_len, diag_mask_inf_f32_data, \"main\", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);\n\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, \"soft_max_f32\", soft_max_f32_len, soft_max_f32_data, \"main\", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, \"soft_max_f32_wg512\", soft_max_f32_len, 
soft_max_f32_data, \"main\", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, \"soft_max_f32_f16\", soft_max_f32_f16_len, soft_max_f32_f16_data, \"main\", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, \"soft_max_f32_f16_wg512\", soft_max_f32_f16_len, soft_max_f32_f16_data, \"main\", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, \"soft_max_back_f32\", soft_max_back_f32_len, soft_max_back_f32_data, \"main\", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);\n\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32,     \"soft_max_large1_f32\",     soft_max_large1_f32_len,     soft_max_large1_f32_data,     \"main\", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32,     \"soft_max_large2_f32\",     soft_max_large2_f32_len,     soft_max_large2_f32_data,     \"main\", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32,     \"soft_max_large3_f32\",     soft_max_large3_f32_len,     soft_max_large3_f32_data,     \"main\", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, \"soft_max_large1_f32_f16\", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, \"main\", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, \"soft_max_large2_f32_f16\", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, \"main\", 6, 
sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);\n    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, \"soft_max_large3_f32_f16\", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, \"main\", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);\n\n    ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, \"rope_norm_f32\", rope_norm_f32_len, rope_norm_f32_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, \"rope_neox_f32\", rope_neox_f32_len, rope_neox_f32_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, \"rope_multi_f32\", rope_multi_f32_len, rope_multi_f32_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, \"rope_vision_f32\", rope_vision_f32_len, rope_vision_f32_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n\n    if (device->float_controls_rte_fp16) {\n        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, \"rope_norm_f16\", rope_norm_f16_rte_len, rope_norm_f16_rte_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, \"rope_neox_f16\", rope_neox_f16_rte_len, rope_neox_f16_rte_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, \"rope_multi_f16\", rope_multi_f16_rte_len, rope_multi_f16_rte_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, \"rope_vision_f16\", rope_vision_f16_rte_len, rope_vision_f16_rte_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 
1}, {}, 1);\n\n        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, \"rope_norm_f32_f16\", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, \"rope_neox_f32_f16\", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, \"rope_multi_f32_f16\", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n    } else {\n        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, \"rope_norm_f16\", rope_norm_f16_len, rope_norm_f16_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, \"rope_neox_f16\", rope_neox_f16_len, rope_neox_f16_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, \"rope_multi_f16\", rope_multi_f16_len, rope_multi_f16_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, \"rope_vision_f16\", rope_vision_f16_len, rope_vision_f16_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n\n        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, \"rope_norm_f32_f16\", rope_norm_f32_f16_len, rope_norm_f32_f16_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, \"rope_neox_f32_f16\", rope_neox_f32_f16_len, rope_neox_f32_f16_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n        ggml_vk_create_pipeline(device, 
device->pipeline_rope_multi_f32_f16, \"rope_multi_f32_f16\", rope_multi_f32_f16_len, rope_multi_f32_f16_data, \"main\", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);\n    }\n\n    for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {\n        uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);\n        if (i <= device->max_workgroup_size_log2 &&\n            2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {\n            const uint32_t NCOLS_PADDED_LOG2 = i;\n            ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], \"argsort_f32_\"+std::to_string(i), argsort_f32_len, argsort_f32_data, \"main\", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);\n        }\n        const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;\n        BLOCK_SIZE /= WG_UNROLL_FACTOR;\n        ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], \"argsort_large_f32_\"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, \"main\", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);\n    }\n\n    for (uint32_t i = 0; i < num_topk_pipelines; ++i) {\n        const uint32_t BLOCK_SIZE = 1u << i;\n        const uint32_t NCOLS_PADDED_LOG2 = i;\n        if (i <= device->max_workgroup_size_log2) {\n            uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +\n                                  sizeof(int) * device->subgroup_size +\n                                  2 * sizeof(int) +\n                                  2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);\n            if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&\n                nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {\n                ggml_vk_create_pipeline2(device, 
device->pipeline_topk_f32[i], \"topk_f32_\"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, \"main\", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);\n            } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {\n                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], \"topk_f32_\"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, \"main\", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);\n            }\n        }\n    }\n\n    ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, \"argmax_f32\", argmax_f32_len, argmax_f32_data, \"main\", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, \"sum_rows_f32\", sum_rows_f32_len, sum_rows_f32_data, \"main\", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);\n\n    const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 
2 : 4;\n    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32,       \"cumsum_f32\", cumsum_f32_len, cumsum_f32_data, \"main\", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);\n    ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, \"cumsum_f32\", cumsum_f32_len, cumsum_f32_data, \"main\", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);\n    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, \"cumsum_multipass1_f32\", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, \"main\", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);\n    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, \"cumsum_multipass2_f32\", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, \"main\", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);\n\n    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, \"count_equal_i32\", count_equal_i32_len, count_equal_i32_data, \"main\", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_count_experts, \"count_experts\", count_experts_len, count_experts_data, \"main\", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);\n\n    for (auto &s : device->pipeline_solve_tri_f32) {\n        const vk_solve_tri_pipeline_state &state = s.first;\n\n        // Max number of rows to load at a time, limited by shared memory\n        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));\n        // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory\n   
     const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));\n\n        ggml_vk_create_pipeline(\n            device, s.second, \"solve_tri_f32\",\n            solve_tri_f32_len, solve_tri_f32_data, \"main\", 3,\n            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);\n    }\n\n#define IM2COL(bda) \\\n    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, \"im2col_f32\", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, \"main\", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \\\n    ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, \"im2col_3d_f32\", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, \"main\", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \\\n    if (device->float_controls_rte_fp16) {  \\\n        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, \"im2col_f32_f16\", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, \"main\", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \\\n        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, \"im2col_3d_f32_f16\", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, \"main\", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \\\n    } else {    \\\n        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, \"im2col_f32_f16\", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, \"main\", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \\\n        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, \"im2col_3d_f32_f16\", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, \"main\", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, 
true);      \\\n    }\n    if (device->shader_int64 && device->buffer_device_address) {\n        IM2COL(_bda)\n    } else {\n        IM2COL()\n    }\n\n    ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, \"timestep_embedding_f32\", timestep_embedding_f32_len, timestep_embedding_f32_data, \"main\", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, \"conv_transpose_1d_f32\", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, \"main\", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, \"pool2d_f32\", pool2d_f32_len, pool2d_f32_data, \"main\", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, \"rwkv_wkv6_f32\", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, \"main\", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, \"rwkv_wkv7_f32\", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, \"main\", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);\n\n    {\n        const uint32_t gdn_sizes[] = {32, 64, 128};\n        const char * gdn_names[][2] = {\n            {\"gated_delta_net_f32_d32\",     \"gated_delta_net_f32_d32_kda\"},\n            {\"gated_delta_net_f32_d64\",     \"gated_delta_net_f32_d64_kda\"},\n            {\"gated_delta_net_f32_d128\",    \"gated_delta_net_f32_d128_kda\"},\n        };\n        for (uint32_t si = 0; si < 3; si++) {\n            for (uint32_t kda = 0; kda < 2; kda++) {\n                ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],\n                    gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,\n                    \"main\", 7, sizeof(vk_op_gated_delta_net_push_constants),\n     
               {1, 1, 1}, {gdn_sizes[si], kda}, 1);\n            }\n        }\n    }\n\n    if (device->subgroup_arithmetic && device->subgroup_require_full_support) {\n        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, \"ssm_scan_128_f32\", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, \"main\", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);\n        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, \"ssm_scan_256_f32\", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, \"main\", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);\n    } else {\n        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, \"ssm_scan_128_f32\", ssm_scan_f32_len, ssm_scan_f32_data, \"main\", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);\n        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, \"ssm_scan_256_f32\", ssm_scan_f32_len, ssm_scan_f32_data, \"main\", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);\n    }\n\n    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, \"ssm_conv_f32\", ssm_conv_f32_len, ssm_conv_f32_data, \"main\", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, \"opt_step_adamw_f32\", opt_step_adamw_f32_len, opt_step_adamw_f32_data, \"main\", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);\n\n    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, \"opt_step_sgd_f32\", opt_step_sgd_f32_len, opt_step_sgd_f32_data, \"main\", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);\n\n    // conv2d, conv_transpose_2d\n    for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {\n        uint32_t conv2d_WG_SIZE  = 256;\n        uint32_t 
use_collectives = 0;  // Enables subgroup ops for preventing the re-calculation of indices.\n        uint32_t conv2d_TS_K     = (s == CONV_SHAPE_64x32) ? 4 : 8;\n        uint32_t conv2d_SHMEM_PAD = 4;\n        vk_conv_block_size conv2d_BS = vk_conv_block_sizes[s];\n        bool conv2d_UNROLL = true;\n\n#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n        if (device->coopmat2) {\n            conv2d_SHMEM_PAD = 8; // 8 float16_t\n        }\n#endif\n\n        if (device->vendor_id == VK_VENDOR_ID_INTEL) {\n            conv2d_SHMEM_PAD = 0;\n            conv2d_UNROLL = false;\n        } else if (device->vendor_id == VK_VENDOR_ID_AMD) {\n            conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;\n            if (s == CONV_SHAPE_128x128 && device->architecture != vk_device_architecture::AMD_GCN) {\n                conv2d_UNROLL = false;\n            }\n        }\n\n        // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.\n        bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||\n                                    device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;\n        bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||\n                                     device->architecture == vk_device_architecture::AMD_GCN;\n\n        if (device->subgroup_shuffle &&\n            device->vendor_id != VK_VENDOR_ID_INTEL &&   // Do not enable collectives on Intel, see PR 14316.\n            allow_collectives_nv &&\n            allow_collectives_amd) {\n            use_collectives = 1;\n            conv2d_BS.CRS   = std::min(\n                device->subgroup_size,\n                conv2d_BS.CRS);  // CRS block size should be capped at subgroup size for correctness when shuffle is used.\n        }\n\n        uint32_t conv2d_shmem_req =\n            (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + 
conv2d_SHMEM_PAD)) * sizeof(float);\n        if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {\n            conv2d_BS.CRS = 8;\n            if (use_collectives) {\n                conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS);\n            }\n        }\n\n        std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 };\n        std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };\n\n#define CREATE_CONV(name, type_suffix, spv_suffix) \\\n        for (auto &c : device->pipeline_##name##type_suffix[s]) { \\\n            const vk_conv2d_pipeline_state &state = c.first;  \\\n            std::vector<uint32_t> spec_constants_cpy = spec_constants; \\\n            spec_constants_cpy.push_back(state.s0); \\\n            spec_constants_cpy.push_back(state.s1); \\\n            spec_constants_cpy.push_back(state.p0); \\\n            spec_constants_cpy.push_back(state.p1); \\\n            spec_constants_cpy.push_back(state.d0); \\\n            spec_constants_cpy.push_back(state.d1); \\\n            spec_constants_cpy.push_back(state.KW); \\\n            spec_constants_cpy.push_back(state.KH); \\\n            ggml_vk_create_pipeline( \\\n                device, c.second, #name #type_suffix, \\\n                name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, \"main\", 3, \\\n                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives);    \\\n        }\n#define CREATE_CONVS(spv_suffix) \\\n        CREATE_CONV(conv2d, _f32, spv_suffix) \\\n        CREATE_CONV(conv2d, _f16_f32, spv_suffix) \\\n        CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \\\n        CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix)\n#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n        if (device->coopmat2) {\n            CREATE_CONVS(_cm2)\n        } else\n#endif\n        if 
(conv2d_UNROLL) {\n            CREATE_CONVS(_unroll)\n        } else {\n            CREATE_CONVS( )\n        }\n#undef CREATE_CONV\n#undef CREATE_CONVS\n    }\n\n    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, \"conv2d_dw_whcn_f32\", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, \"main\", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, \"conv2d_dw_cwhn_f32\", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, \"main\", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, \"conv2d_dw_whcn_f16_f32\", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, \"main\", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);\n    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, \"conv2d_dw_cwhn_f16_f32\", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, \"main\", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);\n\n    for (uint32_t use_push = 0; use_push < 2; ++use_push) {\n        for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {\n            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], \"topk_moe_f32_\"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, \"main\", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);\n        }\n    }\n\n    for (auto &c : compiles) {\n        c.wait();\n    }\n}\n\nstatic bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);\nstatic uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);\n\nstatic vk_device ggml_vk_get_device(size_t idx) {\n    VK_LOG_DEBUG(\"ggml_vk_get_device(\" << idx << \")\");\n\n    if 
(vk_instance.devices[idx] == nullptr) {\n        VK_LOG_DEBUG(\"Initializing new vk_device\");\n        vk_device device = std::make_shared<vk_device_struct>();\n        vk_instance.devices[idx] = device;\n\n        device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());\n\n        size_t dev_num = vk_instance.device_indices[idx];\n\n        std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();\n\n        if (dev_num >= physical_devices.size()) {\n            std::cerr << \"ggml_vulkan: Device with index \" << dev_num << \" does not exist.\" << std::endl;\n            throw std::runtime_error(\"Device not found\");\n        }\n\n        device->physical_device = physical_devices[dev_num];\n        const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();\n\n        device->architecture = get_device_architecture(device->physical_device);\n\n        const char* GGML_VK_PREFER_HOST_MEMORY = getenv(\"GGML_VK_PREFER_HOST_MEMORY\");\n        device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;\n\n        const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv(\"GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM\");\n        device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;\n\n        const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv(\"GGML_VK_ALLOW_SYSMEM_FALLBACK\");\n        device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;\n\n        const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv(\"GGML_VK_DISABLE_GRAPH_OPTIMIZE\");\n        device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr;\n\n        bool fp16_storage = false;\n        bool fp16_compute = false;\n        bool maintenance4_support = false;\n        bool sm_builtins = false;\n        bool amd_shader_core_properties2 = false;\n        bool pipeline_robustness = false;\n        bool coopmat2_support 
= false;\n        bool pipeline_executable_properties_support = false;\n        device->coopmat_support = false;\n        device->integer_dot_product = false;\n        device->shader_64b_indexing = false;\n        bool bfloat16_support = false;\n\n        for (const auto& properties : ext_props) {\n            if (strcmp(\"VK_KHR_maintenance4\", properties.extensionName) == 0) {\n                maintenance4_support = true;\n            } else if (strcmp(\"VK_KHR_16bit_storage\", properties.extensionName) == 0) {\n                fp16_storage = true;\n            } else if (strcmp(\"VK_KHR_shader_float16_int8\", properties.extensionName) == 0) {\n                fp16_compute = true;\n            } else if (strcmp(\"VK_NV_shader_sm_builtins\", properties.extensionName) == 0) {\n                sm_builtins = true;\n            } else if (strcmp(\"VK_AMD_shader_core_properties2\", properties.extensionName) == 0) {\n                amd_shader_core_properties2 = true;\n            } else if (strcmp(\"VK_EXT_pipeline_robustness\", properties.extensionName) == 0) {\n                pipeline_robustness = true;\n            } else if (strcmp(\"VK_EXT_subgroup_size_control\", properties.extensionName) == 0) {\n                device->subgroup_size_control = true;\n#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n            } else if (strcmp(\"VK_KHR_cooperative_matrix\", properties.extensionName) == 0 &&\n                       !getenv(\"GGML_VK_DISABLE_COOPMAT\")) {\n                device->coopmat_support = true;\n                device->coopmat_m = 0;\n                device->coopmat_n = 0;\n                device->coopmat_k = 0;\n#endif\n#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n            } else if (strcmp(\"VK_NV_cooperative_matrix2\", properties.extensionName) == 0 &&\n                       !getenv(\"GGML_VK_DISABLE_COOPMAT2\")) {\n                coopmat2_support = true;\n#endif\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n            } else if 
(strcmp(\"VK_KHR_shader_integer_dot_product\", properties.extensionName) == 0 &&\n                       !getenv(\"GGML_VK_DISABLE_INTEGER_DOT_PRODUCT\")) {\n                device->integer_dot_product = true;\n#endif\n#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n            } else if (strcmp(\"VK_KHR_shader_bfloat16\", properties.extensionName) == 0 &&\n                       !getenv(\"GGML_VK_DISABLE_BFLOAT16\")) {\n                bfloat16_support = true;\n#endif\n            } else if (strcmp(\"VK_KHR_pipeline_executable_properties\", properties.extensionName) == 0) {\n                pipeline_executable_properties_support = true;\n            } else if (strcmp(\"VK_EXT_memory_priority\", properties.extensionName) == 0 &&\n                       getenv(\"GGML_VK_ENABLE_MEMORY_PRIORITY\")) {\n                device->memory_priority = true;\n            } else if (strcmp(\"VK_EXT_external_memory_host\", properties.extensionName) == 0) {\n                device->external_memory_host = true;\n#if defined(VK_EXT_shader_64bit_indexing)\n            } else if (strcmp(\"VK_EXT_shader_64bit_indexing\", properties.extensionName) == 0) {\n                device->shader_64b_indexing = true;\n#endif\n            }\n        }\n\n        vk::PhysicalDeviceProperties2 props2;\n        vk::PhysicalDeviceMaintenance3Properties props3;\n        vk::PhysicalDeviceMaintenance4Properties props4;\n        vk::PhysicalDeviceSubgroupProperties subgroup_props;\n        vk::PhysicalDeviceDriverProperties driver_props;\n        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;\n        vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;\n        vk::PhysicalDeviceVulkan11Properties vk11_props;\n        vk::PhysicalDeviceVulkan12Properties vk12_props;\n        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;\n        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;\n        
vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;\n\n        props2.pNext = &props3;\n        props3.pNext = &subgroup_props;\n        subgroup_props.pNext = &driver_props;\n        driver_props.pNext = &vk11_props;\n        vk11_props.pNext = &vk12_props;\n\n        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;\n\n        if (maintenance4_support) {\n            last_struct->pNext = (VkBaseOutStructure *)&props4;\n            last_struct = (VkBaseOutStructure *)&props4;\n        }\n        if (sm_builtins) {\n            last_struct->pNext = (VkBaseOutStructure *)&sm_props;\n            last_struct = (VkBaseOutStructure *)&sm_props;\n        }\n        if (amd_shader_core_properties2) {\n            last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;\n            last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;\n        }\n        if (device->subgroup_size_control) {\n            last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;\n            last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;\n        }\n\n#if defined(VK_NV_cooperative_matrix2)\n        vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;\n        if (coopmat2_support) {\n            last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;\n            last_struct = (VkBaseOutStructure *)&coopmat2_props;\n        }\n#endif\n\n        if (device->integer_dot_product) {\n            last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;\n            last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;\n        }\n\n        if (device->external_memory_host) {\n            last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;\n            last_struct = (VkBaseOutStructure *)&external_memory_host_props;\n        }\n\n        device->physical_device.getProperties2(&props2);\n        
device->properties = props2.properties;\n        device->vendor_id = device->properties.vendorID;\n        device->driver_id = driver_props.driverID;\n\n        if (device->driver_id == vk::DriverId::eMoltenvk) {\n            // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622\n            // is available in the Vulkan SDK.\n            device->external_memory_host = false;\n        }\n\n        // Implementing the async backend interfaces seems broken on older Intel HW,\n        // see https://github.com/ggml-org/llama.cpp/issues/17302.\n        device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||\n                                 std::string(device->properties.deviceName.data()).find(\"(DG1)\") == std::string::npos) &&\n                                getenv(\"GGML_VK_DISABLE_ASYNC\") == nullptr;\n\n        if (!device->support_async) {\n            GGML_LOG_DEBUG(\"ggml_vulkan: WARNING: Async execution disabled on certain Intel devices.\\n\");\n        }\n\n        const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv(\"GGML_VK_FORCE_MAX_ALLOCATION_SIZE\");\n\n        if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {\n            device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);\n        } else if (maintenance4_support) {\n            device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);\n        } else {\n            device->max_memory_allocation_size = props3.maxMemoryAllocationSize;\n        }\n\n        const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv(\"GGML_VK_FORCE_MAX_BUFFER_SIZE\");\n\n        if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) {\n            device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE);\n        } else if (maintenance4_support) {\n            device->max_buffer_size = props4.maxBufferSize;\n        } else {\n            device->max_buffer_size = device->max_memory_allocation_size;\n 
       }\n\n        const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv(\"GGML_VK_SUBALLOCATION_BLOCK_SIZE\");\n\n        if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {\n            device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE);\n        } else {\n            // Limit batching of allocations to 1GB by default to avoid fragmentation issues\n            device->suballocation_block_size = 1024*1024*1024;\n        }\n        device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);\n\n        device->subgroup_size = subgroup_props.subgroupSize;\n        device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));\n        device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;\n        if (sm_builtins) {\n            device->shader_core_count = sm_props.shaderSMCount;\n        } else if (amd_shader_core_properties2) {\n            device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;\n        } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {\n            device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);\n        } else {\n            device->shader_core_count = 0;\n        }\n        device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;\n\n        device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&\n                                 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);\n        device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&\n                                      (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);\n#ifdef __APPLE__\n        // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846)\n        if (device->vendor_id == 
VK_VENDOR_ID_AMD) {\n            device->subgroup_arithmetic = false;\n        }\n#endif\n        device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&\n                                   (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);\n        device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&\n                                     (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);\n\n        device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&\n                                  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);\n\n        device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&\n                                (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);\n\n        const bool force_disable_f16 = getenv(\"GGML_VK_DISABLE_F16\") != nullptr;\n\n        device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;\n\n        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {\n            device->coopmat_support = false;\n        }\n\n        device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;\n\n        device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;\n\n        device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));\n\n        std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();\n\n        // Try to find a non-graphics compute queue and transfer-focused queues\n        // On AMD, the graphics queue seems to be 
faster, so don't avoid it\n        const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;\n        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1);\n        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1);\n\n        const float priorities[] = { 1.0f, 1.0f };\n        device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;\n\n        std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;\n        if (compute_queue_family_index != transfer_queue_family_index) {\n            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});\n            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});\n        } else if(!device->single_queue) {\n            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});\n        } else {\n            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});\n        }\n        vk::DeviceCreateInfo device_create_info;\n        std::vector<const char *> device_extensions;\n        vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();\n\n        VkPhysicalDeviceFeatures2 device_features2;\n        device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;\n        device_features2.pNext = nullptr;\n        device_features2.features = (VkPhysicalDeviceFeatures)device_features;\n\n        VkPhysicalDeviceVulkan11Features vk11_features;\n       
 vk11_features.pNext = nullptr;\n        vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;\n        device_features2.pNext = &vk11_features;\n\n        VkPhysicalDeviceVulkan12Features vk12_features;\n        vk12_features.pNext = nullptr;\n        vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;\n        vk11_features.pNext = &vk12_features;\n\n        last_struct = (VkBaseOutStructure *)&vk12_features;\n\n        VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;\n        pl_robustness_features.pNext = nullptr;\n        pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;\n        pl_robustness_features.pipelineRobustness = VK_FALSE;\n\n        if (pipeline_robustness) {\n            last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;\n            last_struct = (VkBaseOutStructure *)&pl_robustness_features;\n            device_extensions.push_back(\"VK_EXT_pipeline_robustness\");\n        }\n\n        VkPhysicalDeviceMemoryPriorityFeaturesEXT memory_priority_features;\n        memory_priority_features.pNext = nullptr;\n        memory_priority_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PRIORITY_FEATURES_EXT;\n        memory_priority_features.memoryPriority = VK_FALSE;\n        if (device->memory_priority) {\n            last_struct->pNext = (VkBaseOutStructure *)&memory_priority_features;\n            last_struct = (VkBaseOutStructure *)&memory_priority_features;\n            device_extensions.push_back(\"VK_EXT_memory_priority\");\n        }\n\n        VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;\n        subgroup_size_control_features.pNext = nullptr;\n        subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;\n        subgroup_size_control_features.computeFullSubgroups = false;\n        
subgroup_size_control_features.subgroupSizeControl = false;\n\n        if (device->subgroup_size_control) {\n            last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;\n            last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;\n        }\n\n#if defined(VK_KHR_cooperative_matrix)\n        VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;\n        coopmat_features.pNext = nullptr;\n        coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;\n        coopmat_features.cooperativeMatrix = VK_FALSE;\n\n        if (device->coopmat_support) {\n            last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;\n            last_struct = (VkBaseOutStructure *)&coopmat_features;\n        }\n#endif\n\n#if defined(VK_NV_cooperative_matrix2)\n        VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};\n        coopmat2_features.pNext = nullptr;\n        coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;\n        if (coopmat2_support) {\n            last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;\n            last_struct = (VkBaseOutStructure *)&coopmat2_features;\n            device_extensions.push_back(\"VK_NV_cooperative_matrix2\");\n        }\n#endif\n\n#if defined(VK_KHR_shader_bfloat16)\n        VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};\n        bfloat16_features.pNext = nullptr;\n        bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;\n        if (bfloat16_support) {\n            last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;\n            last_struct = (VkBaseOutStructure *)&bfloat16_features;\n            device_extensions.push_back(\"VK_KHR_shader_bfloat16\");\n        }\n#endif\n\n        VkPhysicalDeviceMaintenance4Features maint4_features {};\n        maint4_features.sType = 
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;\n        if (maintenance4_support) {\n            last_struct->pNext = (VkBaseOutStructure *)&maint4_features;\n            last_struct = (VkBaseOutStructure *)&maint4_features;\n            device_extensions.push_back(\"VK_KHR_maintenance4\");\n        }\n\n        VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};\n        shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;\n        if (device->integer_dot_product) {\n            last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;\n            last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;\n            device_extensions.push_back(\"VK_KHR_shader_integer_dot_product\");\n        }\n\n        VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};\n        pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;\n        if (pipeline_executable_properties_support) {\n            last_struct->pNext = (VkBaseOutStructure *)&pep_features;\n            last_struct = (VkBaseOutStructure *)&pep_features;\n            device_extensions.push_back(\"VK_KHR_pipeline_executable_properties\");\n        }\n\n        if (device->external_memory_host) {\n            device_extensions.push_back(\"VK_EXT_external_memory_host\");\n        }\n\n#if defined(VK_EXT_shader_64bit_indexing)\n        VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};\n        shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;\n        if (device->shader_64b_indexing) {\n            last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features;\n            last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features;\n            
device_extensions.push_back(\"VK_EXT_shader_64bit_indexing\");\n        }\n#endif\n\n        vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);\n\n        device->pipeline_executable_properties_support = pipeline_executable_properties_support;\n\n        device->fp16 = device->fp16 && vk12_features.shaderFloat16;\n\n#if defined(VK_KHR_shader_bfloat16)\n        device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;\n#else\n        device->bf16 = false;\n#endif\n\n        device->pipeline_robustness = pl_robustness_features.pipelineRobustness;\n\n        device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&\n                            device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&\n                            getenv(\"GGML_VK_DISABLE_MULTI_ADD\") == nullptr;\n\n        device->shader_int64 = device_features2.features.shaderInt64;\n        device->buffer_device_address = vk12_features.bufferDeviceAddress;\n        device->vulkan_memory_model = vk12_features.vulkanMemoryModel;\n\n        if (device->subgroup_size_control) {\n            device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;\n            device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;\n            device_extensions.push_back(\"VK_EXT_subgroup_size_control\");\n        }\n\n        device->subgroup_size_control = device->subgroup_size_control &&\n                (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&\n                subgroup_size_control_features.subgroupSizeControl;\n\n        device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;\n\n#if defined(VK_KHR_cooperative_matrix)\n        device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;\n        device->coopmat1_fa_support = device->coopmat_support && 
device->subgroup_require_full_support;\n#endif\n\n        if (coopmat2_support) {\n#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n            if (coopmat2_features.cooperativeMatrixWorkgroupScope &&\n                coopmat2_features.cooperativeMatrixFlexibleDimensions &&\n                coopmat2_features.cooperativeMatrixReductions &&\n                coopmat2_features.cooperativeMatrixConversions &&\n                coopmat2_features.cooperativeMatrixPerElementOperations &&\n                coopmat2_features.cooperativeMatrixTensorAddressing &&\n                coopmat2_features.cooperativeMatrixBlockLoads &&\n                vk12_features.bufferDeviceAddress) {\n\n                std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;\n                uint32_t count = 0;\n\n                PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV\n                    _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =\n                        (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)\n                        vk_instance.instance.getProcAddr(\"vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV\");\n\n                _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);\n\n                VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};\n                empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;\n                flexible_dimensions.resize(count, empty_prop);\n\n                _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());\n\n                bool found_fp16_128 = false,\n                     found_fp16_256 = false,\n                     found_fp32_128 = false,\n                     found_fp32_256 = false;\n              
  // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128\n                // with 32x16x16 and 256 with 32x32x16.\n                for (auto &prop : flexible_dimensions) {\n                    if (prop.saturatingAccumulation == VK_FALSE &&\n                        prop.scope == VK_SCOPE_WORKGROUP_KHR &&\n                        prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&\n                        prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {\n\n                        if (prop.workgroupInvocations == 128 &&\n                            prop.MGranularity <= 32 &&\n                            prop.NGranularity <= 16 &&\n                            prop.KGranularity <= 16) {\n                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&\n                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {\n                                found_fp16_128 = true;\n                            }\n                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&\n                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {\n                                found_fp32_128 = true;\n                            }\n                        }\n                        if (prop.workgroupInvocations == 256 &&\n                            prop.MGranularity <= 32 &&\n                            prop.NGranularity <= 32 &&\n                            prop.KGranularity <= 16) {\n                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&\n                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {\n                                found_fp16_256 = true;\n                            }\n                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&\n                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {\n                                found_fp32_256 = true;\n                            }\n                       
 }\n                    }\n                }\n                if (found_fp16_128 && found_fp16_256 &&\n                    found_fp32_128 && found_fp32_256 &&\n                    coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {\n                    device->coopmat2 = true;\n                }\n            }\n#endif\n        }\n\n        if (!vk11_features.storageBuffer16BitAccess) {\n            std::cerr << \"ggml_vulkan: device \" << GGML_VK_NAME << idx << \" does not support 16-bit storage.\" << std::endl;\n            throw std::runtime_error(\"Unsupported device\");\n        }\n\n        device_extensions.push_back(\"VK_KHR_16bit_storage\");\n\n#ifdef GGML_VULKAN_VALIDATE\n        device_extensions.push_back(\"VK_KHR_shader_non_semantic_info\");\n#endif\n\n        if (device->fp16) {\n            device_extensions.push_back(\"VK_KHR_shader_float16_int8\");\n        }\n\n#if defined(VK_KHR_cooperative_matrix)\n        if (device->coopmat_support) {\n            // Query supported shapes\n            std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;\n\n            PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =\n                (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, \"vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR\");\n\n            uint32_t cm_props_num;\n\n            pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);\n\n            cm_props.resize(cm_props_num);\n\n            for (auto& prop : cm_props) {\n                prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;\n            }\n\n            pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());\n\n            VK_LOG_DEBUG(\"ggml_vulkan: Cooperative Matrix Shapes: \" << cm_props.size());\n\n            for (auto& prop : 
cm_props) {\n                VK_LOG_DEBUG(\"ggml_vulkan: M: \" << prop.MSize << \" N: \" << prop.NSize << \" K: \" << prop.KSize << \" A: \" << vk::to_string((vk::ComponentTypeKHR)prop.AType) << \" B: \" << vk::to_string((vk::ComponentTypeKHR)prop.BType) << \" C: \" << vk::to_string((vk::ComponentTypeKHR)prop.CType) << \" Result: \" << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << \" saturatingAccumulation: \" << prop.saturatingAccumulation << \" scope: \" << vk::to_string((vk::ScopeKHR)prop.scope));\n\n                if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&\n                    (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&\n                    (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup\n                ) {\n                    if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&\n                        (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {\n                        // coopmat sizes not set yet\n                        if (device->coopmat_m == 0) {\n                            device->coopmat_acc_f32_support = true;\n                            device->coopmat_m = prop.MSize;\n                            device->coopmat_n = prop.NSize;\n                            device->coopmat_k = prop.KSize;\n                        } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {\n                            // Only enable if shape is identical\n                            device->coopmat_acc_f32_support = true;\n                        }\n                        if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {\n                            device->coopmat_support_16x16x16_f32acc = true;\n                        }\n                    } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&\n                               
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {\n                        // coopmat sizes not set yet\n                        if (device->coopmat_m == 0) {\n                            device->coopmat_acc_f16_support = true;\n                            device->coopmat_m = prop.MSize;\n                            device->coopmat_n = prop.NSize;\n                            device->coopmat_k = prop.KSize;\n                        } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {\n                            // Only enable if shape is identical\n                            device->coopmat_acc_f16_support = true;\n                        }\n                        if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {\n                            device->coopmat_support_16x16x16_f16acc = true;\n                        }\n                    }\n                } else if ((vk::ComponentTypeKHR)prop.AType      == vk::ComponentTypeKHR::eSint8 &&\n                           (vk::ComponentTypeKHR)prop.BType      == vk::ComponentTypeKHR::eSint8 &&\n                           (vk::ComponentTypeKHR)prop.CType      == vk::ComponentTypeKHR::eSint32 &&\n                           (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&\n                           (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&\n                           device->coopmat_int_m == 0\n                ) {\n                    device->coopmat_int_support = true;\n                    device->coopmat_int_m = prop.MSize;\n                    device->coopmat_int_n = prop.NSize;\n                    device->coopmat_int_k = prop.KSize;\n                }\n#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n                if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&\n                    prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&\n          
          prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&\n                    prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&\n                    (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup\n                ) {\n                    // coopmat sizes not set yet\n                    if (device->coopmat_m == 0) {\n                        device->coopmat_bf16_support = true;\n                        device->coopmat_m = prop.MSize;\n                        device->coopmat_n = prop.NSize;\n                        device->coopmat_k = prop.KSize;\n                    } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {\n                        // Only enable if shape is identical\n                        device->coopmat_bf16_support = true;\n                    }\n                }\n#endif\n            }\n\n            if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {\n                // No suitable matmul mode found\n                GGML_LOG_DEBUG(\"ggml_vulkan: WARNING: No suitable matrix core mode found. 
Disabling matrix cores.\\n\");\n                device->coopmat_support = false;\n            }\n            if (getenv(\"GGML_VK_DISABLE_BFLOAT16\")) {\n                device->coopmat_bf16_support = false;\n            }\n        }\n\n        if (device->coopmat_support) {\n            device_extensions.push_back(\"VK_KHR_cooperative_matrix\");\n        }\n#if defined(VK_KHR_shader_bfloat16)\n        if (device->coopmat_bf16_support) {\n            device_extensions.push_back(\"VK_KHR_shader_bfloat16\");\n        }\n#endif\n#endif\n        device->name = GGML_VK_NAME + std::to_string(idx);\n\n        device_create_info = {\n            vk::DeviceCreateFlags(),\n            device_queue_create_infos,\n            {},\n            device_extensions\n        };\n        device_create_info.setPNext(&device_features2);\n        device->device = device->physical_device.createDevice(device_create_info);\n\n        // Queues\n        ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);\n\n        // Shaders\n        // Disable matmul tile sizes early if performance low or not supported\n        for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {\n            switch (device->vendor_id) {\n#ifndef GGML_VULKAN_RUN_TESTS\n            case VK_VENDOR_ID_AMD:\n                device->mul_mat_l[i]    = device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary;\n                device->mul_mat_m[i]    = true;\n                device->mul_mat_s[i]    = true;\n                device->mul_mat_id_l[i] = false;\n                device->mul_mat_id_m[i] = true;\n                device->mul_mat_id_s[i] = true;\n                break;\n            case VK_VENDOR_ID_INTEL:\n                if (!device->coopmat_support || device->architecture != INTEL_XE2) {\n                    device->mul_mat_l[i] = false;\n                    device->mul_mat_id_l[i] 
= false;\n                } else {\n                    device->mul_mat_l[i] = true;  // if coopmat & XE2+, allow large matmul warptile config for Intel\n                    device->mul_mat_id_l[i] = true;\n                }\n                device->mul_mat_m[i] = true;\n                device->mul_mat_s[i] = true;\n                device->mul_mat_id_m[i] = true;\n                device->mul_mat_id_s[i] = true;\n                break;\n            case VK_VENDOR_ID_APPLE:\n                device->mul_mat_l[i] = false;\n                device->mul_mat_m[i] = true;\n                device->mul_mat_s[i] = false;\n                device->mul_mat_id_l[i] = false;\n                device->mul_mat_id_m[i] = true;\n                device->mul_mat_id_s[i] = false;\n                break;\n#endif\n            default:\n                device->mul_mat_l[i] = true;\n                device->mul_mat_m[i] = true;\n                device->mul_mat_s[i] = true;\n                device->mul_mat_id_l[i] = true;\n                device->mul_mat_id_m[i] = true;\n                device->mul_mat_id_s[i] = true;\n                break;\n            }\n        }\n\n\n        std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;\n        std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;\n        for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) {\n            dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});\n            dsl_binding_flags.push_back({});\n        }\n\n        vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };\n\n        vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(\n            {},\n            dsl_binding);\n        descriptor_set_layout_create_info.setPNext(&dslbfci);\n        device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);\n\n        ggml_vk_load_shaders(device);\n\n        if (!device->single_queue) {\n            
const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;\n            ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);\n\n            device->async_use_transfer_queue = (getenv(\"GGML_VK_ASYNC_USE_TRANSFER_QUEUE\") != nullptr);\n        } else {\n            // TODO: Use pointer or reference to avoid copy\n            device->transfer_queue.copyFrom(device->compute_queue);\n            device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);\n\n            device->async_use_transfer_queue = false;\n        }\n\n        device->buffer_type = {\n            /* .iface    = */ ggml_backend_vk_buffer_type_interface,\n            /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),\n            /* .context  = */ new ggml_backend_vk_buffer_type_context{ device->name, device },\n        };\n\n        device->fence = device->device.createFence({});\n\n        device->idx = idx;\n\n        device->disable_fusion = getenv(\"GGML_VK_DISABLE_FUSION\") != nullptr;\n\n        device->add_rms_fusion = !device->disable_fusion &&\n                                 device->subgroup_arithmetic &&\n                                 device->vendor_id != VK_VENDOR_ID_INTEL;\n        device->partials_binding_alignment =\n            std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);\n\n        device->mmvq_mode = 0;\n        if (getenv(\"GGML_VK_DISABLE_MMVQ\")) {\n            device->mmvq_mode = -1;\n        } else if (getenv(\"GGML_VK_FORCE_MMVQ\")) {\n            device->mmvq_mode = 1;\n        }\n\n        return device;\n    }\n\n    return vk_instance.devices[idx];\n}\n\nstatic void ggml_vk_print_gpu_info(size_t idx) {\n    GGML_ASSERT(idx < vk_instance.device_indices.size());\n    size_t dev_num = vk_instance.device_indices[idx];\n    
VK_LOG_DEBUG(\"ggml_vk_print_gpu_info(\" << dev_num << \")\");\n    GGML_ASSERT(vk_instance_initialized);\n\n    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();\n\n    if (dev_num >= devices.size()) {\n        std::cerr << \"ggml_vulkan: Device with index \" << dev_num << \" does not exist.\" << std::endl;\n        throw std::runtime_error(\"Device not found\");\n    }\n\n    vk::PhysicalDevice physical_device = devices[dev_num];\n    std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();\n\n    bool fp16_storage = false;\n    bool fp16_compute = false;\n    bool coopmat_support = false;\n    bool coopmat2_support = false;\n    bool integer_dot_product = false;\n    bool bfloat16_support = false;\n\n    for (auto properties : ext_props) {\n        if (strcmp(\"VK_KHR_16bit_storage\", properties.extensionName) == 0) {\n            fp16_storage = true;\n        } else if (strcmp(\"VK_KHR_shader_float16_int8\", properties.extensionName) == 0) {\n            fp16_compute = true;\n#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n       } else if (strcmp(\"VK_KHR_cooperative_matrix\", properties.extensionName) == 0 &&\n                   !getenv(\"GGML_VK_DISABLE_COOPMAT\")) {\n            coopmat_support = true;\n#endif\n#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n        } else if (strcmp(\"VK_NV_cooperative_matrix2\", properties.extensionName) == 0 &&\n                   !getenv(\"GGML_VK_DISABLE_COOPMAT2\")) {\n            coopmat2_support = true;\n#endif\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n        } else if (strcmp(\"VK_KHR_shader_integer_dot_product\", properties.extensionName) == 0 &&\n                    !getenv(\"GGML_VK_DISABLE_INTEGER_DOT_PRODUCT\")) {\n            integer_dot_product = true;\n#endif\n#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n        } else if (strcmp(\"VK_KHR_shader_bfloat16\", properties.extensionName) == 0 &&\n              
      !getenv(\"GGML_VK_DISABLE_BFLOAT16\")) {\n            bfloat16_support = true;\n#endif\n        }\n    }\n\n    const vk_device_architecture device_architecture = get_device_architecture(physical_device);\n\n    const char* GGML_VK_DISABLE_F16 = getenv(\"GGML_VK_DISABLE_F16\");\n    bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;\n\n    bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;\n\n    vk::PhysicalDeviceProperties2 props2;\n    vk::PhysicalDeviceMaintenance3Properties props3;\n    vk::PhysicalDeviceSubgroupProperties subgroup_props;\n    vk::PhysicalDeviceDriverProperties driver_props;\n    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;\n    props2.pNext = &props3;\n    props3.pNext = &subgroup_props;\n    subgroup_props.pNext = &driver_props;\n\n    // Pointer to the last chain element\n    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;\n\n    if (integer_dot_product) {\n        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;\n        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;\n    }\n\n    physical_device.getProperties2(&props2);\n\n    VkPhysicalDeviceFeatures2 device_features2;\n    device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;\n    device_features2.pNext = nullptr;\n\n    VkPhysicalDeviceVulkan11Features vk11_features;\n    vk11_features.pNext = nullptr;\n    vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;\n    device_features2.pNext = &vk11_features;\n\n    VkPhysicalDeviceVulkan12Features vk12_features;\n    vk12_features.pNext = nullptr;\n    vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;\n    vk11_features.pNext = &vk12_features;\n\n    // Pointer to the last chain element\n    last_struct = (VkBaseOutStructure *)&vk12_features;\n\n#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n    
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;\n    coopmat_features.pNext = nullptr;\n    coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;\n    coopmat_features.cooperativeMatrix = VK_FALSE;\n\n    if (coopmat_support) {\n        last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;\n        last_struct = (VkBaseOutStructure *)&coopmat_features;\n    }\n#endif\n\n    VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};\n    shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;\n    if (integer_dot_product) {\n        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;\n        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;\n    }\n\n#if defined(VK_KHR_shader_bfloat16)\n    VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};\n    bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;\n    if (bfloat16_support) {\n        last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;\n        last_struct = (VkBaseOutStructure *)&bfloat16_features;\n    }\n#endif\n\n    vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);\n\n    fp16 = fp16 && vk12_features.shaderFloat16;\n\n#if defined(VK_KHR_shader_bfloat16)\n    bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;\n#else\n    bool bf16 = false;\n#endif\n\n    uint32_t default_subgroup_size = get_subgroup_size(\"\", device_architecture);\n    const size_t subgroup_size = (default_subgroup_size != 0) ? 
default_subgroup_size : subgroup_props.subgroupSize;\n    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;\n\n    integer_dot_product = integer_dot_product\n                       && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated\n                       && shader_integer_dot_product_features.shaderIntegerDotProduct;\n\n    coopmat_support = coopmat_support\n#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n                   && coopmat_features.cooperativeMatrix\n#endif\n                   && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);\n\n    std::string matrix_cores = coopmat2_support ? \"NV_coopmat2\" : coopmat_support ? \"KHR_coopmat\" : \"none\";\n\n    std::string device_name = props2.properties.deviceName.data();\n    GGML_LOG_DEBUG(\"ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\\n\",\n              idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,\n              props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());\n\n    if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {\n        GGML_LOG_DEBUG(\"ggml_vulkan: Warning: Device type is CPU. 
This is probably not the device you want.\\n\");\n    }\n}\n\nstatic bool ggml_vk_instance_layer_settings_available();\nstatic bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);\nstatic bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);\nstatic bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);\n\nstatic DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;\nDispatchLoaderDynamic & ggml_vk_default_dispatcher() {\n    return ggml_vk_default_dispatcher_instance;\n}\n\nstatic void ggml_vk_instance_init() {\n    if (vk_instance_initialized) {\n        return;\n    }\n    VK_LOG_DEBUG(\"ggml_vk_instance_init()\");\n\n    // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-\n    ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr);\n\n    uint32_t api_version = vk::enumerateInstanceVersion();\n\n    if (api_version < VK_API_VERSION_1_2) {\n        std::cerr << \"ggml_vulkan: Error: Vulkan 1.2 required.\" << std::endl;\n        throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, \"Vulkan 1.2 required\");\n    }\n\n    vk::ApplicationInfo app_info{ \"ggml-vulkan\", 1, nullptr, 0, api_version };\n\n    const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();\n    const bool layer_settings = ggml_vk_instance_layer_settings_available();\n#ifdef __APPLE__\n    const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);\n#endif\n    const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv(\"GGML_VK_DEBUG_MARKERS\") != nullptr;\n    std::vector<const char*> layers;\n\n    if (layer_settings) {\n        layers.push_back(\"VK_LAYER_KHRONOS_validation\");\n    }\n    std::vector<const char*> 
extensions;\n    if (layer_settings) {\n        extensions.push_back(\"VK_EXT_layer_settings\");\n    }\n#ifdef __APPLE__\n    if (portability_enumeration_ext) {\n        extensions.push_back(\"VK_KHR_portability_enumeration\");\n    }\n#endif\n    if (debug_utils_ext) {\n        extensions.push_back(\"VK_EXT_debug_utils\");\n    }\n    VkBool32 enable_best_practice = layer_settings;\n    std::vector<vk::LayerSettingEXT> settings = {\n        {\n            \"VK_LAYER_KHRONOS_validation\",\n            \"validate_best_practices\",\n            vk::LayerSettingTypeEXT::eBool32,\n            1,\n            &enable_best_practice\n        },\n    };\n    vk::LayerSettingsCreateInfoEXT layer_setting_info(settings);\n    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions, &layer_setting_info);\n#ifdef __APPLE__\n    if (portability_enumeration_ext) {\n        instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;\n    }\n#endif\n\n    vk_instance.instance = vk::createInstance(instance_create_info);\n    vk_instance_initialized = true;\n\n    if (debug_utils_ext) {\n        vk_instance.debug_utils_support              = true;\n        vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, \"vkSetDebugUtilsObjectNameEXT\");\n        vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, \"vkQueueBeginDebugUtilsLabelEXT\");\n        vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, \"vkQueueEndDebugUtilsLabelEXT\");\n        vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, \"vkCmdBeginDebugUtilsLabelEXT\");\n        vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT =   (PFN_vkCmdEndDebugUtilsLabelEXT) 
vkGetInstanceProcAddr(vk_instance.instance, \"vkCmdEndDebugUtilsLabelEXT\");\n        vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, \"vkCmdInsertDebugUtilsLabelEXT\");\n    }\n\n    vk_perf_logger_enabled = getenv(\"GGML_VK_PERF_LOGGER\") != nullptr;\n    vk_perf_logger_concurrent = getenv(\"GGML_VK_PERF_LOGGER_CONCURRENT\") != nullptr;\n    vk_enable_sync_logger = getenv(\"GGML_VK_SYNC_LOGGER\") != nullptr;\n    vk_memory_logger_enabled = getenv(\"GGML_VK_MEMORY_LOGGER\") != nullptr;\n    const char* GGML_VK_PIPELINE_STATS = getenv(\"GGML_VK_PIPELINE_STATS\");\n    if (GGML_VK_PIPELINE_STATS != nullptr) {\n        vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS;\n    }\n    const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv(\"GGML_VK_PERF_LOGGER_FREQUENCY\");\n\n    if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {\n        vk_perf_logger_frequency = std::stoul(GGML_VK_PERF_LOGGER_FREQUENCY);\n    }\n\n    // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-\n    VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance);\n\n    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();\n\n    // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan\n    char * devices_env = getenv(\"GGML_VK_VISIBLE_DEVICES\");\n    if (devices_env != nullptr) {\n        size_t num_available_devices = devices.size();\n\n        std::string devices(devices_env);\n        std::replace(devices.begin(), devices.end(), ',', ' ');\n\n        std::stringstream ss(devices);\n        size_t tmp;\n        while (ss >> tmp) {\n            if(tmp >= num_available_devices) {\n                std::cerr << \"ggml_vulkan: Invalid device index \" << tmp << \" in GGML_VK_VISIBLE_DEVICES.\" << std::endl;\n                throw std::runtime_error(\"Invalid Vulkan device index\");\n            }\n            
vk_instance.device_indices.push_back(tmp);\n        }\n    } else {\n        // If no vulkan devices are found, return early\n        if (devices.empty()) {\n            GGML_LOG_INFO(\"ggml_vulkan: No devices found.\\n\");\n            return;\n        }\n\n        // Default to using all dedicated GPUs\n        for (size_t i = 0; i < devices.size(); i++) {\n            vk::PhysicalDeviceProperties2 new_props;\n            vk::PhysicalDeviceDriverProperties new_driver;\n            vk::PhysicalDeviceIDProperties new_id;\n            new_props.pNext = &new_driver;\n            new_driver.pNext = &new_id;\n            devices[i].getProperties2(&new_props);\n\n            if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {\n                // Check if there are two physical devices corresponding to the same GPU\n                // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),\n                // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.\n                // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,\n                // see https://github.com/KhronosGroup/MoltenVK/issues/2683. 
Skip when both old/new\n                // driver is MoltenVK\n                auto old_device = std::find_if(\n                    vk_instance.device_indices.begin(),\n                    vk_instance.device_indices.end(),\n                    [&devices, &new_id, &new_driver](const size_t k){\n                        vk::PhysicalDeviceProperties2 old_props;\n                        vk::PhysicalDeviceDriverProperties old_driver;\n                        vk::PhysicalDeviceIDProperties old_id;\n                        old_props.pNext = &old_driver;\n                        old_driver.pNext = &old_id;\n                        devices[k].getProperties2(&old_props);\n\n                        bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));\n                        same_uuid = same_uuid || (\n                            old_id.deviceLUIDValid && new_id.deviceLUIDValid &&\n                            std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))\n                        );\n                        bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);\n\n                        return same_uuid && !both_molten_vk;\n                    }\n                );\n                if (old_device == vk_instance.device_indices.end()) {\n                    vk_instance.device_indices.push_back(i);\n                } else {\n                    // There can be two physical devices corresponding to the same GPU if there are 2 different drivers\n                    // This can cause error when splitting layers aross the devices, need to keep only 1\n                    VK_LOG_DEBUG(\"Device \" << i << \" and device \" << *old_device << \" have the same deviceUUID\");\n\n                    vk::PhysicalDeviceProperties2 old_props;\n                    vk::PhysicalDeviceDriverProperties 
old_driver;\n                    old_props.pNext = &old_driver;\n                    devices[*old_device].getProperties2(&old_props);\n\n                    std::map<vk::DriverId, int> driver_priorities {};\n                    int old_priority = std::numeric_limits<int>::max();\n                    int new_priority = std::numeric_limits<int>::max();\n\n                    // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id\n                    // Smaller number -> higher priority\n                    switch (old_props.properties.vendorID) {\n                        case VK_VENDOR_ID_AMD:\n                            driver_priorities[vk::DriverId::eMesaRadv] = 1;\n                            driver_priorities[vk::DriverId::eAmdOpenSource] = 2;\n                            driver_priorities[vk::DriverId::eAmdProprietary] = 3;\n                            break;\n                        case VK_VENDOR_ID_INTEL:\n                            driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1;\n                            driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2;\n                            break;\n                        case VK_VENDOR_ID_NVIDIA:\n                            driver_priorities[vk::DriverId::eNvidiaProprietary] = 1;\n#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235\n                            driver_priorities[vk::DriverId::eMesaNvk] = 2;\n#endif\n                            break;\n                        case VK_VENDOR_ID_QUALCOMM:\n                            driver_priorities[vk::DriverId::eQualcommProprietary] = 1;\n                            driver_priorities[vk::DriverId::eMesaTurnip] = 2;\n                            break;\n                    }\n                    driver_priorities[vk::DriverId::eMesaDozen] = 100;\n\n                    if (driver_priorities.count(old_driver.driverID)) {\n                        old_priority = 
driver_priorities[old_driver.driverID];\n                    }\n                    if (driver_priorities.count(new_driver.driverID)) {\n                        new_priority = driver_priorities[new_driver.driverID];\n                    }\n\n                    if (new_priority < old_priority) {\n                        auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);\n                        vk_instance.device_indices.erase(r, vk_instance.device_indices.end());\n                        vk_instance.device_indices.push_back(i);\n\n                        VK_LOG_DEBUG(\"Prioritize device \" << i << \" driver \" << new_driver.driverName << \" over device \" << *old_device << \" driver \" << old_driver.driverName);\n                    }\n                    else {\n                        VK_LOG_DEBUG(\"Prioritize device \" << *old_device << \" driver \" << old_driver.driverName << \" over device \" << i << \" driver \" << new_driver.driverName << std::endl);\n                    }\n                }\n            }\n        }\n\n        // If no GPUs found, fall back to the first non-CPU device.\n        // If only CPU devices are available, return without devices.\n        if (vk_instance.device_indices.empty()) {\n            for (size_t i = 0; i < devices.size(); i++) {\n                if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) {\n                    vk_instance.device_indices.push_back(i);\n                    break;\n                }\n            }\n        }\n\n        if (vk_instance.device_indices.empty()) {\n            GGML_LOG_INFO(\"ggml_vulkan: No devices found.\\n\");\n            return;\n        }\n    }\n    GGML_LOG_DEBUG(\"ggml_vulkan: Found %zu Vulkan devices:\\n\", vk_instance.device_indices.size());\n\n    for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {\n        vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];\n        
std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();\n\n        bool membudget_supported = false;\n        for (const auto & ext : extensionprops) {\n            if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) {\n                membudget_supported = true;\n                break;\n            }\n        }\n\n        vk_instance.device_supports_membudget.push_back(membudget_supported);\n\n        ggml_vk_print_gpu_info(i);\n    }\n}\n\nstatic void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {\n    VK_LOG_DEBUG(\"ggml_vk_init(\" << ctx->name << \", \" << idx << \")\");\n    ggml_vk_instance_init();\n    GGML_ASSERT(idx < vk_instance.device_indices.size());\n\n    ctx->name = GGML_VK_NAME + std::to_string(idx);\n\n    ctx->device = ggml_vk_get_device(idx);\n\n    ctx->semaphore_idx = 0;\n    ctx->event_idx = 0;\n\n    ctx->prealloc_size_x = 0;\n    ctx->prealloc_size_y = 0;\n    ctx->prealloc_size_split_k = 0;\n    // Fixed size of 1KB, for deterministic behavior\n    ctx->prealloc_size_add_rms_partials = 1024;\n\n    ctx->fence = ctx->device->device.createFence({});\n    ctx->almost_ready_fence = ctx->device->device.createFence({});\n\n    ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);\n    if (ctx->device->async_use_transfer_queue) {\n        vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };\n        vk::SemaphoreCreateInfo ci{};\n        ci.setPNext(&tci);\n        ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci);\n        ctx->transfer_semaphore.value = 0;\n\n        ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);\n    }\n\n    if (vk_perf_logger_enabled) {\n        ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());\n    }\n\n#ifdef GGML_VULKAN_CHECK_RESULTS\n    const char* skip_checks = getenv(\"GGML_VULKAN_SKIP_CHECKS\");\n    vk_skip_checks = (skip_checks == NULL ? 
0 : atoi(skip_checks));\n    const char* output_tensor = getenv(\"GGML_VULKAN_OUTPUT_TENSOR\");\n    vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor));\n#endif\n}\n\nstatic vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {\n    VK_LOG_DEBUG(\"ggml_vk_get_to_fp16()\");\n    switch (type) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_MXFP4:\n            break;\n        default:\n            return nullptr;\n    }\n\n    return ctx->device->pipeline_dequant[type];\n}\n\nstatic vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {\n    VK_LOG_DEBUG(\"ggml_vk_get_mul_mat_mat_pipeline(\" << ggml_type_name(src0_type) << \", \" << ggml_type_name(src1_type) << \", \" << prec << \")\");\n    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {\n        return ctx->device->pipeline_matmul_f32;\n    }\n    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {\n        return ctx->device->pipeline_matmul_f32_f16;\n    }\n    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {\n        return ctx->device->pipeline_matmul_bf16;\n    }\n    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {\n        if (src0_type == GGML_TYPE_F16 && src1_type == 
GGML_TYPE_F32) {\n            return ctx->device->pipeline_matmul_f16_f32.f16acc;\n        }\n        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_matmul_f16.f16acc;\n        }\n    } else {\n        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_matmul_f16_f32.f32acc;\n        }\n        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_matmul_f16.f32acc;\n        }\n    }\n\n    // MMQ\n    if (src1_type == GGML_TYPE_Q8_1) {\n        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;\n\n        if (pipelines->is_empty()) {\n            return nullptr;\n        }\n\n        return pipelines;\n    }\n\n    if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {\n        return nullptr;\n    }\n\n    switch (src0_type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_MXFP4:\n            break;\n        default:\n            return nullptr;\n    }\n\n    if (ctx->device->coopmat2) {\n        assert(src1_type == GGML_TYPE_F16);\n        return prec == GGML_PREC_DEFAULT ? 
ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc;\n    }\n    if (ctx->device->coopmat_support) {\n        return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;\n    }\n    return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;\n}\n\nstatic vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) {\n    VK_LOG_DEBUG(\"ggml_vk_get_dequantize_mul_mat_vec()\");\n    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1);\n    GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);\n\n    if (b_type == GGML_TYPE_Q8_1) {\n        switch (a_type) {\n            case GGML_TYPE_Q4_0:\n            case GGML_TYPE_Q4_1:\n            case GGML_TYPE_Q5_0:\n            case GGML_TYPE_Q5_1:\n            case GGML_TYPE_Q8_0:\n            case GGML_TYPE_MXFP4:\n            case GGML_TYPE_Q2_K:\n            case GGML_TYPE_Q3_K:\n            case GGML_TYPE_Q4_K:\n            case GGML_TYPE_Q5_K:\n            case GGML_TYPE_Q6_K:\n            case GGML_TYPE_IQ1_S:\n            case GGML_TYPE_IQ1_M:\n                break;\n            default:\n                return nullptr;\n        }\n    }\n\n    switch (a_type) {\n        case GGML_TYPE_F32:\n        case GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n      
  case GGML_TYPE_Q6_K:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_MXFP4:\n            break;\n        default:\n            return nullptr;\n    }\n\n    // heuristic to choose workgroup size\n    uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;\n    if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {\n        // Prefer larger workgroups when M is small, to spread the work out more\n        // and keep more SMs busy.\n        // q6_k seems to prefer small workgroup size even for \"medium\" values of M.\n        if (a_type == GGML_TYPE_Q6_K) {\n            if (m < 4096 && k >= 1024) {\n                dmmv_wg = DMMV_WG_SIZE_LARGE;\n            }\n        } else {\n            if (m <= 8192 && k >= 1024) {\n                dmmv_wg = DMMV_WG_SIZE_LARGE;\n            }\n        }\n    }\n\n    if (b_type == GGML_TYPE_Q8_1) {\n        if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {\n            dmmv_wg = DMMV_WG_SIZE_SUBGROUP;\n        }\n        return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1];\n    }\n\n    return b_type == GGML_TYPE_F32 ? 
ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1];\n}\n\nstatic vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {\n    VK_LOG_DEBUG(\"ggml_vk_get_mul_mat_mat_id_pipeline()\");\n    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {\n        return ctx->device->pipeline_matmul_id_f32;\n    }\n    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {\n        return ctx->device->pipeline_matmul_id_bf16;\n    }\n    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {\n        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_matmul_id_f16_f32.f16acc;\n        }\n        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_matmul_id_f16.f16acc;\n        }\n    } else {\n        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_matmul_id_f16_f32.f32acc;\n        }\n        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_matmul_id_f16.f32acc;\n        }\n    }\n\n    // MMQ\n    if (src1_type == GGML_TYPE_Q8_1) {\n        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;\n\n        if (pipelines->is_empty()) {\n            return nullptr;\n        }\n\n        return pipelines;\n    }\n\n    GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));\n\n    switch (src0_type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q2_K:\n        
case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_MXFP4:\n            break;\n        default:\n            return nullptr;\n    }\n\n    vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];\n    // XXX TODO 'prec' is not actually allowed in mul_mat_id.\n    bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;\n    bool support_fp16acc = !mmp.f16acc->is_empty();\n    bool support_fp32acc = !mmp.f32acc->is_empty();\n\n    if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {\n        return mmp.f16acc;\n    } else {\n        GGML_ASSERT(support_fp32acc);\n        return mmp.f32acc;\n    }\n}\n\nstatic vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {\n    VK_LOG_DEBUG(\"ggml_vk_get_dequantize_mul_mat_vec_id()\");\n    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);\n\n    if (b_type == GGML_TYPE_Q8_1) {\n        switch (a_type) {\n            case GGML_TYPE_Q4_0:\n            case GGML_TYPE_Q4_1:\n            case GGML_TYPE_Q5_0:\n            case GGML_TYPE_Q5_1:\n            case GGML_TYPE_Q8_0:\n            case GGML_TYPE_MXFP4:\n            case GGML_TYPE_Q2_K:\n            case GGML_TYPE_Q3_K:\n            case GGML_TYPE_Q4_K:\n            case GGML_TYPE_Q5_K:\n            case GGML_TYPE_Q6_K:\n            case GGML_TYPE_IQ1_S:\n            case GGML_TYPE_IQ1_M:\n                break;\n            default:\n                return nullptr;\n        }\n    }\n\n    switch (a_type) {\n        case GGML_TYPE_F32:\n        case 
GGML_TYPE_F16:\n        case GGML_TYPE_BF16:\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ4_XS:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_MXFP4:\n            break;\n        default:\n            return nullptr;\n    }\n\n    // heuristic to choose workgroup size\n    uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;\n    if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {\n        // Prefer larger workgroups when M is small, to spread the work out more\n        // and keep more SMs busy.\n        // q6_k seems to prefer small workgroup size even for \"medium\" values of M.\n        if (a_type == GGML_TYPE_Q6_K) {\n            if (m < 4096 && k >= 1024) {\n                dmmv_wg = DMMV_WG_SIZE_LARGE;\n            }\n        } else {\n            if (m <= 8192 && k >= 1024) {\n                dmmv_wg = DMMV_WG_SIZE_LARGE;\n            }\n        }\n    }\n\n    if (b_type == GGML_TYPE_Q8_1) {\n        if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {\n            dmmv_wg = DMMV_WG_SIZE_SUBGROUP;\n        }\n        return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];\n    }\n\n    return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];\n}\n\nstatic void * ggml_vk_host_malloc(vk_device& device, size_t size) {\n    VK_LOG_MEMORY(\"ggml_vk_host_malloc(\" << size << \")\");\n    vk_buffer buf = 
ggml_vk_create_buffer(device, size,\n        {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,\n         vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});\n\n    if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {\n        fprintf(stderr, \"WARNING: failed to allocate %.2f MB of pinned memory\\n\",\n            size/1024.0/1024.0);\n        device->device.freeMemory(buf->device_memory);\n        device->device.destroyBuffer(buf->buffer);\n        return nullptr;\n    }\n\n    std::lock_guard<std::recursive_mutex> guard(device->mutex);\n    device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));\n\n    return buf->ptr;\n}\n\nstatic void ggml_vk_host_free(vk_device& device, void* ptr) {\n    if (ptr == nullptr) {\n        return;\n    }\n    VK_LOG_MEMORY(\"ggml_vk_host_free(\" << ptr << \")\");\n    std::lock_guard<std::recursive_mutex> guard(device->mutex);\n\n    vk_buffer buf;\n    size_t index;\n    for (size_t i = 0; i < device->pinned_memory.size(); i++) {\n        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);\n        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);\n        if (ptr >= addr && ptr < endr) {\n            buf = std::get<2>(device->pinned_memory[i]);\n            index = i;\n            break;\n        }\n    }\n    if (buf == nullptr) {\n        fprintf(stderr, \"WARNING: failed to free pinned memory: memory not in map\\n\");\n        return;\n    }\n\n    ggml_vk_destroy_buffer(buf);\n\n    device->pinned_memory.erase(device->pinned_memory.begin() + index);\n}\n\nstatic void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {\n    std::lock_guard<std::recursive_mutex> guard(device->mutex);\n    buf = nullptr;\n    buf_offset = 0;\n    for (size_t i = 0; i < device->pinned_memory.size(); i++) 
{\n        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);\n        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);\n        if (ptr >= addr && ptr < endr) {\n            buf = std::get<2>(device->pinned_memory[i]);\n            buf_offset = ((const uint8_t *)ptr) - addr;\n            break;\n        }\n    }\n}\n\nstatic vk_subbuffer ggml_vk_tensor_subbuffer(\n    const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {\n\n    vk_buffer buffer = nullptr;\n    size_t offset = 0;\n    if (ctx->device->uma) {\n        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);\n    }\n    if (!buffer) {\n        auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;\n        buffer = buf_ctx->dev_buffer;\n        offset = vk_tensor_offset(tensor) + tensor->view_offs;\n    }\n    GGML_ASSERT(buffer != nullptr);\n\n    size_t size = ggml_nbytes(tensor);\n\n    size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);\n    // The shader must support misaligned offsets when indexing into the buffer\n    GGML_ASSERT(allow_misalign || misalign_bytes == 0);\n    offset &= ~misalign_bytes;\n    size += misalign_bytes;\n\n    return vk_subbuffer{buffer, offset, size};\n}\n\n// Get a command buffer from pool. 
Create a new one if no reusable buffer is available\nstatic vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {\n    for (auto& cmd_buffer : pool.cmd_buffers) {\n        if (!cmd_buffer.in_use) {\n            cmd_buffer.in_use = true;\n            return &cmd_buffer;\n        }\n    }\n    return ggml_vk_create_cmd_buffer(device, pool);\n}\n\nstatic vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {\n    vk_submission s;\n    s.buffer = ggml_vk_get_or_create_cmd_buffer(device, p);\n    if (one_time) {\n        s.buffer->buf.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });\n    } else {\n        s.buffer->buf.begin({ vk::CommandBufferUsageFlags{} });\n    }\n\n    return s;\n}\n\ntemplate <typename T> size_t push_constant_size(const T &t) {\n    static_assert(std::is_class<T>::value, \"T must be a struct/class\");\n    GGML_UNUSED(t);\n    return sizeof(T);\n}\ntemplate <typename T> size_t push_constant_size(const std::vector<T> &t) {\n    GGML_UNUSED(t);\n    return sizeof(T) * t.size();\n}\ntemplate <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {\n    GGML_UNUSED(t);\n    return sizeof(T) * N;\n}\n\ntemplate <typename T> const T *push_constant_data(const T &t) {\n    static_assert(std::is_class<T>::value, \"T must be a struct/class\");\n    return &t;\n}\ntemplate <typename T> const T *push_constant_data(const std::vector<T> &t) {\n    return t.data();\n}\ntemplate <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {\n    return t.data();\n}\n\ntemplate <typename T>\nstatic void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {\n    const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);\n    const 
uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);\n    const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);\n    VK_LOG_DEBUG(\"ggml_vk_dispatch_pipeline(\" << pipeline->name << \", {\";\n    for (auto& buffer : descriptor_buffer_infos) {\n        std::cerr << \"(\" << buffer.buffer << \", \" << buffer.offset << \", \" << buffer.range << \"), \";\n    }\n    std::cerr << \"}, (\" << wg0 << \",\" << wg1 << \",\" << wg2 << \"))\");\n    GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&\n                wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&\n                wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);\n    GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());\n    GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);\n    GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());\n    GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants));\n\n    vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];\n    vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };\n    ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});\n\n    subctx->s->buffer->buf.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));\n    subctx->s->buffer->buf.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);\n    subctx->s->buffer->buf.bindDescriptorSets(vk::PipelineBindPoint::eCompute,\n                                pipeline->layout,\n                                0,\n                                { descriptor_set },\n                                {});\n    subctx->s->buffer->buf.dispatch(wg0, wg1, wg2);\n}\n\nstatic void 
ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {\n    s.buffer->buf.end();\n\n    s.wait_semaphores = std::move(wait_semaphores);\n    s.signal_semaphores = std::move(signal_semaphores);\n}\n\nstatic void ggml_vk_ctx_end(vk_context& ctx) {\n    VK_LOG_DEBUG(\"ggml_vk_ctx_end(\" << ctx << \", \" << ctx->seqs.size() << \")\");\n    if (ctx->s == nullptr) {\n        return;\n    }\n\n    ctx->s->buffer->buf.end();\n    ctx->s = nullptr;\n}\n\nstatic void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {\n    VK_LOG_DEBUG(\"ggml_vk_ctx_begin(\" << device->name << \")\");\n    if (subctx->s != nullptr) {\n        ggml_vk_ctx_end(subctx);\n    }\n\n    subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) });\n    subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();\n}\n\nstatic vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {\n    if (!ctx->compute_ctx.expired()) {\n        return ctx->compute_ctx.lock();\n    }\n\n    vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);\n\n    ctx->compute_ctx = result;\n    ggml_vk_ctx_begin(ctx->device, result);\n\n    if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {\n        result->s->wait_semaphores.push_back(ctx->transfer_semaphore);\n        ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;\n    }\n\n    return result;\n}\n\n// Submit any pending transfer queue work and signal the transfer semaphore.\n// The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore.\n// Returns true if work was submitted.\nstatic bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) {\n    if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) {\n        return false;\n    }\n\n    vk_context cpy_ctx = ctx->transfer_ctx.lock();\n    
ggml_vk_ctx_end(cpy_ctx);\n\n    for (auto& cpy : cpy_ctx->in_memcpys) {\n        memcpy(cpy.dst, cpy.src, cpy.n);\n    }\n\n    ctx->transfer_semaphore.value++;\n    cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore);\n\n    ggml_vk_submit(cpy_ctx, {});\n    ctx->transfer_ctx.reset();\n    return true;\n}\n\nstatic size_t ggml_vk_align_size(size_t width, size_t align) {\n    VK_LOG_DEBUG(\"ggml_vk_align_size(\" << width << \", \" << align << \")\");\n    return CEIL_DIV(width, align) * align;\n}\n\nstatic void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) {\n    if (memcpys == nullptr) {\n        memcpy(dst, src, size);\n    } else {\n        memcpys->emplace_back(dst, src, size);\n    }\n}\n\nstatic void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) {\n    if (memsets == nullptr) {\n        memset(dst, val, size);\n    } else {\n        memsets->emplace_back(dst, val, size);\n    }\n}\n\nstatic void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {\n    if (device->sync_staging == nullptr || device->sync_staging->size < size) {\n        VK_LOG_MEMORY(\"ggml_vk_ensure_sync_staging_buffer(\" << size << \")\");\n        ggml_vk_destroy_buffer(device->sync_staging);\n        device->sync_staging = ggml_vk_create_buffer_check(device, size,\n            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,\n            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);\n    }\n}\n\nstatic void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {\n    if (ctx->sync_staging == nullptr || ctx->sync_staging->size < size) {\n        VK_LOG_MEMORY(\"ggml_vk_ensure_sync_staging_buffer(\" << size << \")\");\n        ggml_vk_destroy_buffer(ctx->sync_staging);\n        
ctx->sync_staging = ggml_vk_create_buffer_check(ctx->device, size,\n            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,\n            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);\n    }\n}\n\nstatic void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_write_nc_async(\" << tensor << \")\");\n    GGML_ASSERT(!ggml_is_contiguous(tensor));\n    // Buffer is already mapped\n    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {\n        std::cerr << \"ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write.\" << std::endl;\n        GGML_ABORT(\"fatal error\");\n    }\n    // Check if src is pinned memory\n    vk_buffer buf = nullptr;\n    size_t buf_offset = 0;\n    ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);\n\n    const uint64_t ne0 = tensor->ne[0];\n    const uint64_t ne1 = tensor->ne[1];\n    const uint64_t ne2 = tensor->ne[2];\n    const uint64_t ne3 = tensor->ne[3];\n    const uint64_t nb0 = tensor->nb[0];\n    const uint64_t nb1 = tensor->nb[1];\n    const uint64_t nb2 = tensor->nb[2];\n    const uint64_t nb3 = tensor->nb[3];\n    const ggml_type type = tensor->type;\n    const uint64_t ts = ggml_type_size(type);\n    const uint64_t bs = ggml_blck_size(type);\n\n    const uint64_t dstnb0 = ts;\n    const uint64_t dstnb1 = dstnb0*(ne0/bs);\n    const uint64_t dstnb2 = dstnb1*ne1;\n    const uint64_t dstnb3 = dstnb2*ne2;\n\n    const uint64_t ne = ggml_nelements(tensor);\n\n    if (buf != nullptr) {\n        // Memory is pinned, use as staging buffer\n        std::vector<vk::BufferCopy> slices;\n\n        for (uint64_t i3 = 0; i3 < ne3; i3++) {\n            for (uint64_t i2 = 0; i2 < ne2; i2++) {\n                
// Find longest contiguous slice\n                if (ne1*nb1 == dstnb2) {\n                    slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });\n                } else {\n                    for (uint64_t i1 = 0; i1 < ne1; i1++) {\n                        if (ne0*nb0/bs == dstnb1) {\n                            slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });\n                        } else {\n                            const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;\n                            const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;\n                            for (uint64_t i0 = 0; i0 < ne0; i0++) {\n                                slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 });\n                            }\n                        }\n                    }\n                }\n            }\n        }\n\n        ggml_vk_sync_buffers(ctx, subctx);\n        subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);\n        return;\n    }\n\n    if (!sync_staging) {\n        GGML_ABORT(\"Asynchronous write to non-pinned memory not supported\");\n    }\n\n    // Staging buffer required\n    vk_buffer& staging = ctx->device->sync_staging;\n    const uint64_t copy_size = ts*ne/bs;\n    ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);\n    VkBufferCopy buf_copy{ 0, offset, copy_size };\n\n    ggml_vk_sync_buffers(ctx, subctx);\n    vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);\n\n    for (uint64_t i3 = 0; i3 < ne3; i3++) {\n        for (uint64_t i2 = 0; i2 < ne2; i2++) {\n            // Find longest contiguous slice\n            if (ne1*nb1 == dstnb2) {\n                deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, 
&subctx->in_memcpys);\n            } else {\n                for (uint64_t i1 = 0; i1 < ne1; i1++) {\n                    if (ne0*nb0/bs == dstnb1) {\n                        deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);\n                    } else {\n                        const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;\n                        const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;\n                        for (uint64_t i0 = 0; i0 < ne0; i0++) {\n                            deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);\n                        }\n                    }\n                }\n            }\n        }\n    }\n}\n\nstatic bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_write_2d_async(\" << width << \", \" << height << \")\");\n    // Check if src is pinned memory\n    vk_buffer buf = nullptr;\n    size_t buf_offset = 0;\n    ggml_vk_host_get(dst->device, src, buf, buf_offset);\n\n    if (buf != nullptr) {\n        // Memory is pinned, use as staging buffer\n        std::vector<vk::BufferCopy> slices(1);\n        if (width == spitch) {\n            // Only do single write if stride is equal\n            slices[0].srcOffset = buf_offset;\n            slices[0].dstOffset = offset;\n            slices[0].size = width * height;\n        } else {\n            slices.resize(height);\n            for (size_t i = 0; i < height; i++) {\n                slices[i].srcOffset = buf_offset + i * spitch;\n                slices[i].dstOffset = offset + i * width;\n                slices[i].size = width;\n            }\n        }\n\n        
ggml_vk_sync_buffers(nullptr, subctx);\n        subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);\n        return true;\n    }\n    VK_LOG_DEBUG(\"STAGING\");\n\n    if (!sync_staging) {\n        // copy was not handled caller needs to fall back\n        return false;\n    }\n\n    // Staging buffer required\n    const size_t copy_size = width*height;\n    ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);\n\n    vk_buffer& staging_buffer = dst->device->sync_staging;\n\n    VkBufferCopy buf_copy = {\n        0,\n        offset,\n        copy_size};\n\n    ggml_vk_sync_buffers(nullptr, subctx);\n    vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);\n\n    if (width == spitch) {\n        deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);\n    } else {\n        for (size_t i = 0; i < height; i++) {\n            deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);\n        }\n    }\n    return true;\n}\n\nstatic bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_write_async(\" << size << \")\");\n    return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);\n}\n\nstatic void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_write_2d(\" << width << \", \" << height << \")\");\n    // Buffer is already mapped\n    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {\n        GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);\n\n        for (size_t i = 0; i < height; i++) {\n            memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + 
i * spitch, width);\n        }\n    } else {\n        std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);\n\n        vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);\n        ggml_vk_ctx_begin(dst->device, subctx);\n        bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);\n        GGML_ASSERT(ret);\n        ggml_vk_ctx_end(subctx);\n\n        for (auto& cpy : subctx->in_memcpys) {\n            memcpy(cpy.dst, cpy.src, cpy.n);\n        }\n\n        for (auto& mset : subctx->memsets) {\n            memset(mset.dst, mset.val, mset.n);\n        }\n\n        ggml_vk_submit(subctx, dst->device->fence);\n        VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), \"vk_buffer_write_2d waitForFences\");\n        dst->device->device.resetFences({ dst->device->fence });\n        ggml_vk_queue_command_pools_cleanup(dst->device);\n    }\n}\n\nstatic void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_write(\" << size << \")\");\n    ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);\n}\n\nstatic bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_read_2d_async(offset=\" << offset << \", width=\" << width << \", height=\" << height << \")\");\n    GGML_ASSERT(width > 0);\n    GGML_ASSERT(height > 0);\n    GGML_ASSERT(src != nullptr);\n\n    // TODO: staging_offset is not used\n\n    // Check if dst is pinned memory\n    vk_buffer buf = nullptr;\n    size_t buf_offset = 0;\n    ggml_vk_host_get(src->device, dst, buf, buf_offset);\n\n    std::vector<vk::BufferCopy> slices(1);\n    if (width == spitch && width == dpitch) {\n        // Only do single write if stride is equal\n        slices[0].srcOffset = 
offset;\n        slices[0].dstOffset = buf_offset;\n        slices[0].size = width * height;\n    } else {\n        slices.resize(height);\n        for (size_t i = 0; i < height; i++) {\n            slices[i].srcOffset = offset + i * spitch;\n            slices[i].dstOffset = buf_offset + i * dpitch;\n            slices[i].size = width;\n        }\n    }\n\n    if (buf != nullptr) {\n        // Memory is pinned, use as staging buffer\n        ggml_vk_sync_buffers(nullptr, subctx);\n        subctx->s->buffer->buf.copyBuffer(src->buffer, buf->buffer, slices);\n\n        return true;\n    }\n    VK_LOG_DEBUG(\"STAGING\");\n\n    if (!sync_staging) {\n        // copy was not handled caller needs to fall back\n        return false;\n    }\n\n    // Fall back to staging buffer\n    const size_t copy_size = dpitch * height;\n    ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);\n\n    vk_buffer& staging_buffer = src->device->sync_staging;\n\n    ggml_vk_sync_buffers(nullptr, subctx);\n    subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices);\n\n    deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);\n    return true;\n}\n\nstatic bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {\n    return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);\n}\n\nstatic void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_read(\" << src->buffer << \", \" << offset << \", \" << size << \")\");\n\n    // If the device is not an UMA device the memory is host-accessible through rebar. 
While writing\n    // through PCIe is sufficient fast reading back data from PCIe is slower than going through\n    // the HW device to host copy path.\n    if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {\n        GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);\n\n        memcpy(dst, (uint8_t *) src->ptr + offset, size);\n    } else {\n        std::lock_guard<std::recursive_mutex> guard(src->device->mutex);\n\n        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);\n        ggml_vk_ctx_begin(src->device, subctx);\n        bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);\n        GGML_ASSERT(ret);\n        ggml_vk_ctx_end(subctx);\n\n        ggml_vk_submit(subctx, src->device->fence);\n        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), \"vk_buffer_read waitForFences\");\n        src->device->device.resetFences({ src->device->fence });\n        ggml_vk_queue_command_pools_cleanup(src->device);\n\n        for (auto& cpy : subctx->out_memcpys) {\n            memcpy(cpy.dst, cpy.src, cpy.n);\n        }\n    }\n}\n\nstatic void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_copy_async(\" << size << \")\");\n    // Make sure both buffers are on same device\n    GGML_ASSERT(src->device == dst->device);\n\n    VkBufferCopy bc{ src_offset, dst_offset, size };\n\n    vkCmdCopyBuffer(ctx->s->buffer->buf, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);\n}\n\nstatic void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {\n    if (src->device == dst->device) {\n        std::lock_guard<std::recursive_mutex> guard(src->device->mutex);\n        VK_LOG_DEBUG(\"ggml_vk_buffer_copy(SINGLE_DEVICE, \" << size << 
\")\");\n        // Copy within the device\n        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);\n        ggml_vk_ctx_begin(src->device, subctx);\n        ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);\n        ggml_vk_ctx_end(subctx);\n        ggml_vk_submit(subctx, src->device->fence);\n        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), \"vk_buffer_copy waitForFences\");\n        src->device->device.resetFences({ src->device->fence });\n        ggml_vk_queue_command_pools_cleanup(src->device);\n    } else {\n        VK_LOG_DEBUG(\"ggml_vk_buffer_copy(MULTI_DEVICE, \" << size << \")\");\n        // Copy device to device\n        ggml_vk_ensure_sync_staging_buffer(src->device, size);\n\n        // Copy to src staging buffer\n        ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);\n        // Copy to dst buffer\n        ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);\n    }\n}\n\nstatic void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_memset_async(\" << offset << \", \" << c << \", \" << size << \")\");\n\n    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&\n        dst->device->uma) {\n        deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);\n        return;\n    }\n\n    // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers\n    ctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);\n}\n\nstatic void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {\n    VK_LOG_DEBUG(\"ggml_vk_buffer_memset(\" << offset << \", \" << c << \", \" << size << \")\");\n\n    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&\n        dst->device->uma) {\n        memset((uint8_t*)dst->ptr 
+ offset, c, size);\n        return;\n    }\n\n    std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);\n    vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);\n    ggml_vk_ctx_begin(dst->device, subctx);\n    subctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);\n    ggml_vk_ctx_end(subctx);\n\n    ggml_vk_submit(subctx, dst->device->fence);\n    VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), \"vk_memset waitForFences\");\n    dst->device->device.resetFences({ dst->device->fence });\n    ggml_vk_queue_command_pools_cleanup(dst->device);\n}\n\nstatic uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {\n    VK_LOG_DEBUG(\"ggml_vk_guess_split_k(\" << m << \", \" << n << \", \" << k << \", \" << disable_split_k << \")\");\n\n    if (disable_split_k) {\n        return 1;\n    }\n\n    uint32_t split_k = 1;\n    if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {\n        // If k is 'large' and the SMs will fill less than halfway, use split_k.\n        uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);\n        uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);\n\n        if (k >= 2048) {\n            if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {\n                split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);\n            } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {\n                split_k = 3;\n            }\n            // Cap the split at 8x. 
Unless k is huge this is a lot of overhead.\n            split_k = std::min(split_k, 8u);\n\n            // ggml_vk_matmul will align the splits to be a multiple of 256.\n            // If this rounded up size would cause the last split to be empty,\n            // then reduce the split count.\n            while (true) {\n                if (split_k == 1) {\n                    break;\n                }\n                uint32_t k_split = CEIL_DIV(k, split_k);\n                k_split = ROUNDUP_POW2(k_split, 256);\n                if (k_split * (split_k - 1) < k) {\n                    break;\n                }\n                split_k--;\n            }\n        }\n    }\n\n    return split_k;\n}\n\nstatic vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {\n    VK_LOG_DEBUG(\"ggml_vk_guess_matmul_pipeline(\" << m << \", \" << n << \", \" << aligned << \", \" << ggml_type_name(src0_type) << \", \" << ggml_type_name(src1_type) << \")\");\n\n    if (ctx->device->coopmat2) {\n        const uint32_t shader_core_count = ctx->device->shader_core_count;\n        const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);\n        const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);\n\n        // Use large shader when the N dimension is greater than the medium shader's tile size\n        uint32_t crossover_large = mmp->m->wg_denoms[1];\n\n        // Prefer large over medium if either:\n        // - medium or large tiles would overfill the GPU\n        // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not\n        //   (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)\n        bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||\n           
                 // split_k==3 with large tiles likely better than medium tiles with no split_k.\n                            (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);\n\n        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {\n            return aligned ? mmp->a_l : mmp->l;\n        }\n        // Use medium shader when the N dimension is greater than the small shader's tile size\n        uint32_t crossover_medium = mmp->s->wg_denoms[1];\n        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {\n            return aligned ? mmp->a_m : mmp->m;\n        }\n        return aligned ? mmp->a_s : mmp->s;\n    }\n\n    if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {\n        return aligned ? mmp->a_s : mmp->s;\n    }\n    if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {\n        return aligned ? mmp->a_m : mmp->m;\n    }\n    return aligned ? 
mmp->a_l : mmp->l;\n\n    GGML_UNUSED(src1_type);\n}\n\nstatic uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {\n    VK_LOG_DEBUG(\"ggml_vk_guess_matmul_pipeline_align(\" << m << \", \" << n << \", \" << ggml_type_name(src0_type) << \", \" << ggml_type_name(src1_type) << \")\");\n    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;\n}\n\nstatic void ggml_vk_matmul(\n        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,\n        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,\n        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,\n        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,\n        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,\n        uint32_t padded_n) {\n        VK_LOG_DEBUG(\"ggml_vk_matmul(a: (\" << a.buffer->buffer << \", \" << a.offset << \", \" << a.size << \"), b: (\" << b.buffer->buffer << \", \" << b.offset << \", \" << b.size << \"), d: (\" << d.buffer->buffer << \", \" << d.offset << \", \" << d.size << \"), split_k: (\" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << \", \" << split_k_buffer.offset << \", \" << split_k_buffer.size << \"), m: \" << m << \", n: \" << n << \", k: \" << k << \", stride_a: \" << stride_a << \", stride_b: \" << stride_b << \", stride_d: \" << stride_d << \", batch_stride_a: \" << batch_stride_a << \", batch_stride_b: \" << batch_stride_b << \", batch_stride_d: \" << batch_stride_d << \", split_k: \" << split_k << \", batch: \" << batch << \", ne02: \" << ne02 << \", ne12: \" << ne12 << \", broadcast2: \" << broadcast2 << \", broadcast3: \" << broadcast3 << \", padded_n: \" << padded_n << \")\");\n    if (split_k == 1) {\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));\n\n        uint32_t base_work_group_z = 0;\n        while (base_work_group_z < batch) {\n            uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);\n\n            const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };\n            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });\n            base_work_group_z += groups_z;\n        }\n        return;\n    }\n\n    if (ctx->prealloc_split_k_need_sync) {\n        ggml_vk_sync_buffers(ctx, subctx);\n    }\n\n    GGML_ASSERT(batch_stride_d == m * n);\n\n    // Round the split size up to a multiple of 256 (k-quant alignment)\n    uint32_t k_split = CEIL_DIV(k, split_k);\n    k_split = ROUNDUP_POW2(k_split, 256);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));\n\n    uint32_t base_work_group_z = 0;\n    while (base_work_group_z < batch) {\n        uint32_t groups_z = std::min(batch - base_work_group_z, 
ctx->device->properties.limits.maxComputeWorkGroupCount[2]);\n\n        const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };\n        // Make sure enough workgroups get assigned for split k to work\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });\n        base_work_group_z += groups_z;\n    }\n    ggml_vk_sync_buffers(ctx, subctx);\n    const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };\n    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });\n    ctx->prealloc_split_k_need_sync = true;\n}\n\nstatic vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {\n    VK_LOG_DEBUG(\"ggml_vk_guess_matmul_id_pipeline(\" << m << \", \" << n << \", \" << aligned << \", \" << ggml_type_name(src0_type) << \")\");\n\n    if (ctx->device->coopmat2) {\n        // Use large shader when the N dimension is greater than the medium shader's tile size\n        uint32_t crossover_large = mmp->m->wg_denoms[1];\n        if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {\n            return aligned ? mmp->a_l : mmp->l;\n        }\n        // Use medium shader when the N dimension is greater than the small shader's tile size\n        uint32_t crossover_medium = mmp->s->wg_denoms[1];\n        if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {\n            return aligned ? mmp->a_m : mmp->m;\n        }\n        return aligned ? 
mmp->a_s : mmp->s;\n    }\n\n    if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {\n        return aligned ? mmp->a_s : mmp->s;\n    }\n    if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {\n        return aligned ? mmp->a_m : mmp->m;\n    }\n    return aligned ? mmp->a_l : mmp->l;\n}\n\nstatic uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {\n    VK_LOG_DEBUG(\"ggml_vk_guess_matmul_pipeline_align(\" << m << \", \" << n << \", \" << ggml_type_name(src0_type) << \")\");\n    return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;\n}\n\nstatic void ggml_vk_matmul_id(\n        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,\n        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,\n        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,\n        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,\n        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,\n        uint32_t padded_n) {\n    VK_LOG_DEBUG(\"ggml_vk_matmul_id(a: (\" << a.buffer->buffer << \", \" << a.offset << \", \" << a.size << \"), b: (\" << b.buffer->buffer << \", \" << b.offset << \", \" << b.size << \"), d: (\" << d.buffer->buffer << \", \" << d.offset << \", \" << d.size << \"), ids: (\" << ids.buffer->buffer << \", \" << ids.offset << \", \" << ids.size << \"), expert_count: (\" << expert_count_buf.buffer->buffer << \", \" << expert_count_buf.offset << \", \" << expert_count_buf.size << \"), \" <<\n        \"m: \" << m << \", n: \" << n << \", k: \" << k << \", stride_a: \" << stride_a << \", stride_b: \" << stride_b << \", stride_d: \" << 
stride_d << \", \" <<\n        \"batch_stride_a: \" << batch_stride_a << \", batch_stride_b: \" << batch_stride_b << \", batch_stride_d: \" << batch_stride_d << \", \" <<\n        \"n_as: \" << n_as << \", nei0: \" << nei0 << \", nei1: \" << nei1 << \", nbi1: \" << nbi1 << \", ne11: \" << ne11 << \")\");\n    const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,\n                                              nei0, nei1, nbi1, ne11, padded_n };\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });\n}\n\nstatic bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {\n    return\n        tensor->nb[0] == ggml_type_size(tensor->type) &&\n        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&\n        (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);\n}\n\nstatic vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {\n\n    // Choose \"contiguous copy\" shader if src/dst are contiguous\n    bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));\n\n    // Use optimized \"transpose\" shader if src dim1 is the innermost dimension.\n    bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src);\n\n    if (transpose && src->type == to) {\n        if (ggml_type_size(to) == 4) {\n            return ctx->device->pipeline_cpy_transpose_32;\n        } else if (ggml_type_size(to) == 2) {\n            return ctx->device->pipeline_cpy_transpose_16;\n        }\n    }\n\n    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {\n        if (contig) {\n            return ctx->device->pipeline_contig_cpy_f32_f32;\n        } else {\n            return ctx->device->pipeline_cpy_f32_f32;\n        }\n    }\n    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {\n   
     if (contig) {\n            return ctx->device->pipeline_contig_cpy_f32_f16;\n        } else {\n            return ctx->device->pipeline_cpy_f32_f16;\n        }\n    }\n    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {\n        if (contig) {\n            return ctx->device->pipeline_contig_cpy_f16_f16;\n        } else {\n            return ctx->device->pipeline_cpy_f16_f16;\n        }\n    }\n    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {\n        if (contig) {\n            return ctx->device->pipeline_contig_cpy_f16_f32;\n        } else {\n            return ctx->device->pipeline_cpy_f16_f32;\n        }\n    }\n    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {\n        if (contig) {\n            return ctx->device->pipeline_contig_cpy_f32_bf16;\n        } else {\n            return ctx->device->pipeline_cpy_f32_bf16;\n        }\n    }\n    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {\n        if (contig) {\n            return ctx->device->pipeline_contig_cpy_f32_i32;\n        } else {\n            return ctx->device->pipeline_cpy_f32_i32;\n        }\n    }\n    if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) {\n        if (contig) {\n            return ctx->device->pipeline_contig_cpy_i32_f32;\n        } else {\n            return ctx->device->pipeline_cpy_i32_f32;\n        }\n    }\n    if (src->type == GGML_TYPE_F32) {\n        switch (to) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_IQ4_NL:\n            return ctx->device->pipeline_cpy_f32_quant[to];\n        default:\n            break;\n        }\n    }\n\n    if (to == GGML_TYPE_F32) {\n        switch (src->type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_IQ4_NL:\n            return 
ctx->device->pipeline_cpy_quant_f32[src->type];\n        default:\n            break;\n        }\n    }\n\n    if (src->type == to) {\n        // Copy two or four bytes at a time, depending on block size.\n        // For quantized types, we scale by block size/type size. But\n        // this path is also used for bf16->bf16 for example, where the\n        // type size must be exactly 2 or 4.\n        GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);\n        if ((ggml_type_size(src->type) % 4) == 0) {\n            if (contig) {\n                return ctx->device->pipeline_contig_cpy_f32_f32;\n            } else {\n                return ctx->device->pipeline_cpy_f32_f32;\n            }\n        } else {\n            if (contig) {\n                return ctx->device->pipeline_contig_cpy_f16_f16;\n            } else {\n                return ctx->device->pipeline_cpy_f16_f16;\n            }\n        }\n    }\n\n    std::cerr << \"Missing CPY op for types: \" << ggml_type_name(src->type) << \" \" << ggml_type_name(to) << std::endl;\n    GGML_ABORT(\"fatal error\");\n}\n\nstatic void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, const vk_subbuffer & in, const vk_subbuffer & out) {\n    VK_LOG_DEBUG(\"ggml_vk_cpy_to_contiguous((\" << tensor << \", type=\" << tensor->type << \", ne0=\" << tensor->ne[0] << \", ne1=\" << tensor->ne[1] << \", ne2=\" << tensor->ne[2] << \", ne3=\" << tensor->ne[3] << \", nb0=\" << tensor->nb[0] << \", nb1=\" << tensor->nb[1] << \", nb2=\" << tensor->nb[2] << \", nb3=\" << tensor->nb[3] << \"), \";\n    std::cerr << \"buffer in size=\" << in.buffer->size << \", buffer out size=\" << out.buffer->size << \")\");\n    const int tensor_type_size = ggml_type_size(tensor->type);\n\n    const uint32_t ne = ggml_nelements(tensor);\n    std::array<uint32_t, 3> elements;\n\n    if (ne > 262144) {\n        elements = { 
512, 512, CEIL_DIV(ne, 262144) };\n    } else if (ne > 512) {\n        elements = { 512, CEIL_DIV(ne, 512), 1 };\n    } else {\n        elements = { ne, 1, 1 };\n    }\n\n    vk_op_unary_push_constants pc = {\n        (uint32_t)ne,\n        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,\n        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3],                       1                   , (uint32_t)tensor->ne[0]                   , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),\n        0,\n        0.0f, 0.0f,\n        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n    };\n    init_pushconst_fastdiv(pc);\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);\n    ggml_vk_sync_buffers(ctx, subctx);\n}\n\nstatic vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {\n    switch(type) {\n        case GGML_TYPE_Q8_1:\n            return ctx->device->pipeline_quantize_q8_1_x4;\n        default:\n            std::cerr << \"Missing quantize pipeline for type: \" << ggml_type_name(type) << std::endl;\n            GGML_ABORT(\"fatal error\");\n    }\n}\n\nstatic void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, const vk_subbuffer & in, const vk_subbuffer & out, uint32_t ne) {\n    VK_LOG_DEBUG(\"ggml_vk_quantize_q8_1(\" << \"buffer in size=\" << in.buffer->size << \", buffer out size=\" << out.buffer->size << \", \" << ne << \")\");\n\n    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);\n\n    const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);\n    // clamp the number of elements to the max workgroup count. 
The shader will iterate over the total number of blocks.\n    const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());\n    const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));\n\n    const vk_quantize_q8_1_push_constants pc = {\n        ne,\n        num_blocks,\n    };\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 });\n    ggml_vk_sync_buffers(ctx, subctx);\n}\n\nstatic vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) {\n    GGML_UNUSED(ctx);\n#if defined(VK_EXT_shader_64bit_indexing)\n    vk_pipeline *ptr = &pipeline;\n    while (*ptr) {\n        if ((*ptr)->is_64b_indexing) {\n            return *ptr;\n        }\n        ptr = &(*ptr)->next;\n    }\n#endif\n    return pipeline;\n}\n\nstatic void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {\n    VK_LOG_DEBUG(\"ggml_vk_mul_mat_q_f16((\" << src0 << \", name=\" << src0->name << \", type=\" << ggml_type_name(src0->type) << \", ne0=\" << src0->ne[0] << \", ne1=\" << src0->ne[1] << \", ne2=\" << src0->ne[2] << \", ne3=\" << src0->ne[3] << \", nb0=\" << src0->nb[0] << \", nb1=\" << src0->nb[1] << \", nb2=\" << src0->nb[2] << \", nb3=\" << src0->nb[3];\n    std::cerr << \"), (\" << src1 << \", name=\" << src1->name << \", type=\" << ggml_type_name(src1->type) << \", ne0=\" << src1->ne[0] << \", ne1=\" << src1->ne[1] << \", ne2=\" << src1->ne[2] << \", ne3=\" << src1->ne[3] << \", nb0=\" << src1->nb[0] << \", nb1=\" << src1->nb[1] << \", nb2=\" << src1->nb[2] << \", nb3=\" << src1->nb[3];\n    std::cerr << \"), (\" << dst << \", name=\" << dst->name << \", type=\" << ggml_type_name(dst->type) << \", ne0=\" << dst->ne[0] << \", ne1=\" << dst->ne[1] << \", ne2=\" 
<< dst->ne[2] << \", ne3=\" << dst->ne[3] << \", nb0=\" << dst->nb[0] << \", nb1=\" << dst->nb[1] << \", nb2=\" << dst->nb[2] << \", nb3=\" << dst->nb[3];\n    std::cerr << \"))\");\n    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT\n    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT\n\n    const uint64_t ne00 = src0->ne[0];\n    const uint64_t ne01 = src0->ne[1];\n    const uint64_t ne02 = src0->ne[2];\n    const uint64_t ne03 = src0->ne[3];\n\n    const uint64_t ne10 = src1->ne[0];\n    const uint64_t ne11 = src1->ne[1];\n    const uint64_t ne12 = src1->ne[2];\n    const uint64_t ne13 = src1->ne[3];\n\n    const uint64_t ne21 = dst->ne[1];\n    const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);\n    const uint32_t stride_batch_d = stride_d*ne21;\n\n    const uint64_t r2 = ne12 / ne02;\n    const uint64_t r3 = ne13 / ne03;\n\n    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;\n    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;\n    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;\n\n    vk_buffer d_Qx = nullptr;\n    size_t qx_buf_offset = 0;\n    vk_buffer d_Qy = nullptr;\n    size_t qy_buf_offset = 0;\n\n    bool src0_uma = false;\n    bool src1_uma = false;\n\n    if (ctx->device->uma) {\n        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);\n        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);\n        src0_uma = d_Qx != nullptr;\n        src1_uma = d_Qy != nullptr;\n    }\n\n    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf\n    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||\n                      
        !ggml_vk_dim01_contiguous(src0);\n    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||\n                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||\n                              !ggml_vk_dim01_contiguous(src1);\n\n    // If src0 is BF16, try to use a BF16 x BF16 multiply\n    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;\n\n    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;\n\n    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;\n\n    // Check for mmq first\n    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;\n\n    if (mmp == nullptr) {\n        // Fall back to f16 dequant mul mat\n        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);\n        quantize_y = false;\n    }\n\n    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;\n    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);\n\n    if (qx_needs_dequant) {\n        // Fall back to dequant + f16 mulmat\n        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);\n    }\n\n    // Not implemented\n    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT\n\n    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? 
GGML_TYPE_F32 : src1->type)));\n    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;\n\n    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));\n\n    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {\n        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);\n    }\n\n    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking\n    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;\n    const uint64_t x_ne = ggml_nelements(src0);\n    // 128 elements per Q8_1 x4 block\n    const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;\n    const uint64_t d_ne = ggml_nelements(dst);\n\n    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);\n\n    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);\n    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);\n    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;\n    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);\n    const uint64_t d_sz = sizeof(float) * d_ne;\n\n    vk_pipeline to_fp16_vk_0 = nullptr;\n    vk_pipeline to_fp16_vk_1 = nullptr;\n    vk_pipeline to_q8_1 = nullptr;\n\n    if (x_non_contig) {\n        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);\n    } else {\n        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);\n    }\n    if (y_non_contig) {\n        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);\n    } else {\n        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);\n    }\n    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT\n    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT\n\n    if (quantize_y) {\n        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);\n    }\n\n    {\n        const uint64_t split_k_size = split_k > 1 ? d_sz * split_k : 0;\n        if (\n                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||\n                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange) ||\n                (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) {\n            GGML_ABORT(\"Requested preallocation size is too large\");\n        }\n        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {\n            ctx->prealloc_size_x = x_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {\n            ctx->prealloc_size_y = y_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {\n            ctx->prealloc_size_split_k = split_k_size;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n\n        // Request descriptor sets\n        if (qx_needs_dequant) {\n            
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);\n        }\n        if (qy_needs_dequant) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);\n        }\n        if (quantize_y) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);\n        }\n        if (split_k > 1) {\n            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);\n        }\n    }\n\n    vk_buffer d_D = dst_buf_ctx->dev_buffer;\n    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;\n    GGML_ASSERT(d_D != nullptr);\n    GGML_ASSERT(d_D->size >= d_buf_offset + d_sz);\n    vk_buffer d_X;\n    uint64_t x_buf_offset = 0;\n    vk_buffer d_Y;\n    uint64_t y_buf_offset = 0;\n    if (!src0_uma) {\n        d_Qx = src0_buf_ctx->dev_buffer;\n        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;\n        GGML_ASSERT(d_Qx != nullptr);\n    }\n    if (!src1_uma) {\n        d_Qy = src1_buf_ctx->dev_buffer;\n        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;\n        GGML_ASSERT(d_Qy != nullptr);\n    }\n    if (qx_needs_dequant) {\n        d_X = ctx->prealloc_x;\n        GGML_ASSERT(d_X->size >= x_sz);\n    } else {\n        d_X = d_Qx;\n        x_buf_offset = qx_buf_offset;\n        GGML_ASSERT(qx_sz == x_sz);\n    }\n    if (qy_needs_dequant) {\n        d_Y = ctx->prealloc_y;\n        GGML_ASSERT(d_Y->size >= y_sz);\n    } else if (quantize_y) {\n        d_Y = ctx->prealloc_y;\n        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz, 144) * 144);\n    } else {\n        d_Y = d_Qy;\n        y_buf_offset = qy_buf_offset;\n        GGML_ASSERT(qy_sz == y_sz);\n    }\n\n    if (x_non_contig || qx_needs_dequant) {\n        if (ctx->prealloc_x_need_sync) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n    }\n\n    if (x_non_contig) {\n        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, 
d_X, 0));\n    } else if (qx_needs_dequant) {\n        const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };\n        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)(x_ne), 1, 1});\n        ggml_vk_sync_buffers(ctx, subctx);\n    }\n    if (y_non_contig) {\n        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||\n            ctx->prealloc_y_last_tensor_used != src1) {\n            if (ctx->prealloc_y_need_sync) {\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));\n            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();\n            ctx->prealloc_y_last_tensor_used = src1;\n        }\n    }\n    if (quantize_y) {\n        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||\n            ctx->prealloc_y_last_tensor_used != src1) {\n            if (ctx->prealloc_y_need_sync) {\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n            ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);\n            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();\n            ctx->prealloc_y_last_tensor_used = src1;\n        }\n    }\n\n    uint32_t stride_batch_x = ne00*ne01;\n    uint32_t stride_batch_y = ne10*ne11;\n\n    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {\n        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);\n    }\n\n    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {\n        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);\n    }\n\n    // compute\n    ggml_vk_matmul(\n        ctx, subctx, pipeline,\n        { d_X, x_buf_offset, x_sz }, { 
d_Y, y_buf_offset, y_sz },\n        ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * split_k },\n        ne01, ne11, ne10,\n        ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,\n        split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n\n    );  // NOLINT\n\n    if (x_non_contig || qx_needs_dequant) {\n        ctx->prealloc_x_need_sync = true;\n    }\n    if (y_non_contig || quantize_y) {\n        ctx->prealloc_y_need_sync = true;\n    }\n}\n\n// Device tuning\nstatic bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) {\n    if (device->mmvq_mode == 1) {\n        return true;\n    } else if (device->mmvq_mode == -1) {\n        return false;\n    }\n\n    // General performance issue with q3_k and q6_k due to 2-byte alignment\n    if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {\n        return false;\n    }\n\n    // MMVQ is generally good for batches\n    if (n > 1) {\n        return true;\n    }\n\n    // Quantization overhead is not worth it for small k\n    switch (device->vendor_id) {\n    case VK_VENDOR_ID_NVIDIA:\n        if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {\n            return true;\n        }\n\n        if (k <= 4096) {\n            return false;\n        }\n\n        switch (src0_type) {\n        case GGML_TYPE_MXFP4:\n        case GGML_TYPE_Q8_0:\n            return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;\n        default:\n            return true;\n        }\n    case VK_VENDOR_ID_AMD:\n        if (k < 2048) {\n            return false;\n        }\n\n        switch (src0_type) {\n        case GGML_TYPE_Q8_0:\n            return device->architecture == vk_device_architecture::AMD_GCN;\n        default:\n            return true;\n        }\n    case VK_VENDOR_ID_INTEL:\n        if (k < 2048) {\n            return false;\n        }\n\n        
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {\n            // Intel Windows proprietary driver tuning\n            switch (src0_type) {\n            case GGML_TYPE_MXFP4:\n            case GGML_TYPE_Q4_K:\n            case GGML_TYPE_Q5_K:\n                return false;\n            default:\n                return true;\n            }\n        }\n\n        switch (src0_type) {\n        // From tests on A770 Linux, may need more tuning\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q5_1:\n            return false;\n        default:\n            return true;\n        }\n    default:\n        return true;\n    }\n\n    GGML_UNUSED(m);\n}\n\nstatic void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {\n    ggml_tensor * dst = cgraph->nodes[node_idx];\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    VK_LOG_DEBUG(\"ggml_vk_mul_mat_vec_q_f16((\" << src0 << \", name=\" << src0->name << \", type=\" << src0->type << \", ne0=\" << src0->ne[0] << \", ne1=\" << src0->ne[1] << \", ne2=\" << src0->ne[2] << \", ne3=\" << src0->ne[3] << \", nb0=\" << src0->nb[0] << \", nb1=\" << src0->nb[1] << \", nb2=\" << src0->nb[2] << \", nb3=\" << src0->nb[3];\n    std::cerr << \"), (\" << src1 << \", name=\" << src1->name << \", type=\" << src1->type << \", ne0=\" << src1->ne[0] << \", ne1=\" << src1->ne[1] << \", ne2=\" << src1->ne[2] << \", ne3=\" << src1->ne[3] << \", nb0=\" << src1->nb[0] << \", nb1=\" << src1->nb[1] << \", nb2=\" << src1->nb[2] << \", nb3=\" << src1->nb[3];\n    std::cerr << \"), (\" << dst << \", name=\" << dst->name << \", type=\" << dst->type << \", ne0=\" << dst->ne[0] << \", ne1=\" << dst->ne[1] << \", ne2=\" << dst->ne[2] << \", ne3=\" << dst->ne[3] << \", nb0=\" << dst->nb[0] << \", nb1=\" << dst->nb[1] << \", nb2=\" << dst->nb[2] << \", nb3=\" << dst->nb[3];\n    std::cerr << \")),)\");\n    
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT\n    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT\n\n    const uint64_t ne00 = src0->ne[0];\n    const uint64_t ne01 = src0->ne[1];\n    const uint64_t ne02 = src0->ne[2];\n    const uint64_t ne03 = src0->ne[3];\n\n    const uint64_t ne10 = src1->ne[0];\n    const uint64_t ne11 = src1->ne[1];\n    const uint64_t ne12 = src1->ne[2];\n    const uint64_t ne13 = src1->ne[3];\n\n    const uint64_t ne20 = dst->ne[0];\n    const uint64_t ne21 = dst->ne[1];\n    // const uint64_t ne22 = dst->ne[2];\n    // const uint64_t ne23 = dst->ne[3];\n\n    const uint64_t r2 = ne12 / ne02;\n    const uint64_t r3 = ne13 / ne03;\n\n    // batch_n indicates that we need to compute a few vector results, and this assumes\n    // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.\n    GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);\n    bool batch_n = ne11 > 1;\n\n    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);\n    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);\n\n    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;\n    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);\n\n    vk_pipeline to_fp16_vk_0 = nullptr;\n    vk_pipeline to_fp16_vk_1 = nullptr;\n    if (x_non_contig) {\n        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);\n    }\n    if (y_non_contig) {\n        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);\n    } else {\n        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);\n    }\n\n    // Check for mmq first\n    vk_pipeline dmmv = quantize_y ? 
ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr;\n    vk_pipeline to_q8_1 = nullptr;\n\n    if (dmmv == nullptr) {\n        // Fall back to f16 dequant mul mat\n        dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);\n        quantize_y = false;\n    }\n\n    if (quantize_y) {\n        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);\n    }\n\n    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {\n        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);\n    }\n\n    const bool qx_needs_dequant = x_non_contig;\n    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);\n\n    // Not implemented\n    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT\n\n    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT\n    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT\n    GGML_ASSERT(dmmv != nullptr);\n\n    const uint64_t x_ne = ggml_nelements(src0);\n    const uint64_t y_ne = ggml_nelements(src1);\n\n    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);\n    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;\n    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :\n                         (f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);\n\n    {\n        if (\n                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||\n                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {\n            GGML_ABORT(\"Requested preallocation size is too large\");\n        }\n        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {\n            ctx->prealloc_size_x = x_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {\n            ctx->prealloc_size_y = y_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n\n        // Request descriptor sets\n        if (qx_needs_dequant) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);\n        }\n        if (qy_needs_dequant) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);\n        }\n        if (quantize_y) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);\n        }\n    }\n\n    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);\n    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);\n    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);\n    vk_subbuffer d_X, d_Y;\n\n    if (qx_needs_dequant) {\n        d_X = { ctx->prealloc_x, 0, ctx->prealloc_x->size };\n    } else {\n        d_X = d_Qx;\n        GGML_ASSERT(qx_sz == x_sz);\n    }\n    if (qy_needs_dequant || quantize_y) {\n        d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };\n    } else {\n        d_Y = d_Qy;\n    }\n\n    if (x_non_contig) {\n        if (ctx->prealloc_x_need_sync) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n\n        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));\n        
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, d_Qx, d_X);\n    }\n    if (y_non_contig) {\n        GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);\n        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||\n            ctx->prealloc_y_last_tensor_used != src1) {\n            if (ctx->prealloc_y_need_sync) {\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);\n            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();\n            ctx->prealloc_y_last_tensor_used = src1;\n        }\n    }\n    if (quantize_y) {\n        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||\n            ctx->prealloc_y_last_tensor_used != src1) {\n            if (ctx->prealloc_y_need_sync) {\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n            ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);\n            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();\n            ctx->prealloc_y_last_tensor_used = src1;\n        }\n    }\n\n    // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride\n    uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;\n    uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);\n    uint32_t stride_batch_d = batch_n ? 
ne20 : (ne20*ne21);\n\n    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {\n        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);\n    }\n\n    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {\n        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);\n    }\n\n    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];\n\n    uint32_t groups_x = ne01;\n    uint32_t groups_z = 1;\n\n    if (ne01 > max_groups_x) {\n        groups_z = 64;\n        groups_x = CEIL_DIV(groups_x, groups_z);\n    }\n\n    uint32_t fusion_flags = 0;\n\n    vk_subbuffer d_F0 = d_D;\n    if (ctx->num_additional_fused_ops > 0) {\n        const ggml_tensor * add = cgraph->nodes[node_idx + 1];\n        const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];\n\n        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);\n        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;\n    }\n\n    vk_subbuffer d_F1 = d_D;\n    if (ctx->num_additional_fused_ops == 2) {\n        const ggml_tensor * add = cgraph->nodes[node_idx + 2];\n        const ggml_tensor * bias = add->src[0] == cgraph->nodes[node_idx + 1] ? 
add->src[1] : add->src[0];\n\n        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);\n        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;\n    }\n\n    ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));\n\n    uint32_t base_work_group_y = 0;\n    while (base_work_group_y < ne12 * ne13) {\n\n        uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);\n        const vk_mat_vec_push_constants pc = {\n            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,\n            stride_batch_x, stride_batch_y, stride_batch_d,\n            fusion_flags, base_work_group_y,\n            (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,\n        };\n        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,\n                                  {\n                                    d_X,\n                                    d_Y,\n                                    d_D,\n                                    d_F0,\n                                    d_F1,\n                                  },\n                                  pc, { groups_x, groups_y, groups_z });\n        base_work_group_y += groups_y;\n    }\n\n    if (x_non_contig) {\n        ctx->prealloc_x_need_sync = true;\n    }\n    if (y_non_contig || quantize_y) {\n        ctx->prealloc_y_need_sync = true;\n    }\n}\n\nstatic void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {\n    ggml_tensor * dst = cgraph->nodes[node_idx];\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    VK_LOG_DEBUG(\"ggml_vk_mul_mat_p021_f16_f32(\" << src0 << \", name=\" << src0->name << \", type=\" << src0->type << \", ne0=\" << src0->ne[0] << \", ne1=\" << src0->ne[1] << \", ne2=\" << src0->ne[2] << \", ne3=\" << src0->ne[3] << \", nb0=\" 
<< src0->nb[0] << \", nb1=\" << src0->nb[1] << \", nb2=\" << src0->nb[2] << \", nb3=\" << src0->nb[3];\n    std::cerr << \"), (\" << src1 << \", name=\" << src1->name << \", type=\" << src1->type << \", ne0=\" << src1->ne[0] << \", ne1=\" << src1->ne[1] << \", ne2=\" << src1->ne[2] << \", ne3=\" << src1->ne[3] << \", nb0=\" << src1->nb[0] << \", nb1=\" << src1->nb[1] << \", nb2=\" << src1->nb[2] << \", nb3=\" << src1->nb[3];\n    std::cerr << \"), (\" << dst << \", name=\" << dst->name << \", type=\" << dst->type << \", ne0=\" << dst->ne[0] << \", ne1=\" << dst->ne[1] << \", ne2=\" << dst->ne[2] << \", ne3=\" << dst->ne[3] << \", nb0=\" << dst->nb[0] << \", nb1=\" << dst->nb[1] << \", nb2=\" << dst->nb[2] << \", nb3=\" << dst->nb[3];\n    std::cerr << \"))\");\n    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));\n    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]);  // NOLINT\n    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]);  // NOLINT\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n    const uint64_t ne00 = src0->ne[0];\n    const uint64_t ne01 = src0->ne[1];\n    const uint64_t ne02 = src0->ne[2];\n    // const uint64_t ne03 = src0->ne[3];\n\n    //const uint64_t ne10 = src1->ne[0];\n    const uint64_t ne11 = src1->ne[1];\n    const uint64_t ne12 = src1->ne[2];\n    // const uint64_t ne13 = src1->ne[3];\n\n    GGML_ASSERT(ne11 == 1);\n\n    // With grouped query attention there are > 1 Q matrices per K, V matrix.\n    uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;\n    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {\n        gqa_ratio = 1;\n    }\n\n    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1];\n\n    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {\n        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);\n    }\n\n    {\n        // Request 
descriptor sets\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n    }\n\n    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);\n    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);\n    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);\n\n    vk_subbuffer d_F0 = d_D;\n\n    uint32_t fusion_flags = 0;\n\n    if (ctx->num_additional_fused_ops > 0) {\n        const ggml_tensor * add = cgraph->nodes[node_idx + 1];\n        const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];\n\n        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);\n        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;\n    }\n\n    vk_subbuffer d_F1 = d_D;\n    if (ctx->num_additional_fused_ops > 1) {\n        const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];\n\n        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);\n        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;\n    }\n\n    // compute\n\n    vk_mat_vec_p021_push_constants pc = {\n        (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12,\n        0, 0, fusion_flags\n    };\n\n    init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);\n\n    uint32_t workgroups_z = (uint32_t)ne12;\n    // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups\n    if (gqa_ratio > 1) {\n        workgroups_z /= gqa_ratio;\n    }\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n        {\n            d_Qx,\n            d_Qy,\n            d_D,\n            d_F0,\n            d_F1,\n        }, pc, { 1, (uint32_t)ne01, workgroups_z });\n}\n\nstatic void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {\n    ggml_tensor * dst = cgraph->nodes[node_idx];\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = 
dst->src[1];\n    VK_LOG_DEBUG(\"ggml_vk_mul_mat_nc_f16_f32((\" << src0 << \", name=\" << src0->name << \", type=\" << src0->type << \", ne0=\" << src0->ne[0] << \", ne1=\" << src0->ne[1] << \", ne2=\" << src0->ne[2] << \", ne3=\" << src0->ne[3] << \", nb0=\" << src0->nb[0] << \", nb1=\" << src0->nb[1] << \", nb2=\" << src0->nb[2] << \", nb3=\" << src0->nb[3];\n    std::cerr << \"), (\" << src1 << \", name=\" << src1->name << \", type=\" << src1->type << \", ne0=\" << src1->ne[0] << \", ne1=\" << src1->ne[1] << \", ne2=\" << src1->ne[2] << \", ne3=\" << src1->ne[3] << \", nb0=\" << src1->nb[0] << \", nb1=\" << src1->nb[1] << \", nb2=\" << src1->nb[2] << \", nb3=\" << src1->nb[3];\n    std::cerr << \"), (\" << dst << \", name=\" << dst->name << \", type=\" << dst->type << \", ne0=\" << dst->ne[0] << \", ne1=\" << dst->ne[1] << \", ne2=\" << dst->ne[2] << \", ne3=\" << dst->ne[3] << \", nb0=\" << dst->nb[0] << \", nb1=\" << dst->nb[1] << \", nb2=\" << dst->nb[2] << \", nb3=\" << dst->nb[3];\n    std::cerr << \"))\");\n    GGML_ASSERT(!ggml_is_transposed(src0));\n    GGML_ASSERT(!ggml_is_transposed(src1));\n    GGML_ASSERT(!ggml_is_permuted(src0));\n    GGML_ASSERT(src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n\n    const uint64_t ne00 = src0->ne[0];\n    const uint64_t ne01 = src0->ne[1];\n    const uint64_t ne02 = src0->ne[2];\n    const uint64_t ne03 = src0->ne[3];\n\n    const uint64_t nb01 = src0->nb[1];\n    const uint64_t nb02 = src0->nb[2];\n\n    const uint64_t nb12 = src1->nb[2];\n\n    // const uint64_t ne10 = src1->ne[0];\n    const uint64_t ne11 = src1->ne[1];\n    const uint64_t ne12 = src1->ne[2];\n    // const uint64_t ne13 = src1->ne[3];\n\n    const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));\n    const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));\n    const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));\n\n    GGML_ASSERT(ne11 == 1);\n    GGML_ASSERT(src0->ne[3] == src1->ne[3]); 
// checked in supports_op\n\n    const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);\n    const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);\n    const uint32_t channel_stride_y = nb12 / sizeof(float);\n\n    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32;\n    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {\n        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);\n    }\n\n    {\n        // Request descriptor sets\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n    }\n\n    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);\n    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);\n    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);\n    vk_subbuffer d_F0 = d_D;\n\n    uint32_t fusion_flags = 0;\n\n    if (ctx->num_additional_fused_ops > 0) {\n        const ggml_tensor * add = cgraph->nodes[node_idx + 1];\n        const ggml_tensor * bias = add->src[0] == dst ? 
add->src[1] : add->src[0];\n\n        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);\n        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;\n    }\n\n    vk_subbuffer d_F1 = d_D;\n    if (ctx->num_additional_fused_ops > 1) {\n        const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];\n\n        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);\n        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;\n    }\n\n    // compute\n    vk_mat_vec_nc_push_constants pc = {\n        (uint32_t)ne00, (uint32_t)ne01,\n        row_stride_x, channel_stride_x, channel_stride_y,\n        (uint32_t)(ne12 / ne02), (uint32_t)ne12,\n        0, 0,\n        nb03, nb13, nb23, fusion_flags\n    };\n\n    init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n        {\n            d_Qx,\n            d_Qy,\n            d_D,\n            d_F0,\n            d_F1,\n        }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });\n}\n\nstatic void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {\n    ggml_tensor * dst = cgraph->nodes[node_idx];\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n    VK_LOG_DEBUG(\"ggml_vk_mul_mat(\" << src0 << \", \" << src1 << \", \" << dst << \")\");\n\n    // Handle huge A matrix by splitting the M dimensions. 
This works well for convolution use cases\n    // where the M dimension is very large.\n    // Split_k doesn't work with M splitting.\n    // This only supports batchsize == 1.\n    const size_t nbytes = ggml_nbytes(src0);\n    const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange;\n    if (needs_split) {\n        // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)\n        const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);\n        uint32_t m_offset = 0;\n        while (m_offset < dst->ne[0]) {\n            const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));\n            ggml_tensor dst2 = *dst;\n            ggml_tensor src02 = *src0;\n\n            dst2.view_src = dst->view_src ? dst->view_src : dst;\n            src02.view_src = src0->view_src ? src0->view_src : src0;\n\n            dst2.view_offs += m_offset * dst->nb[0];\n            src02.view_offs += m_offset * src0->nb[1];\n            dst2.ne[0] = cur_M_size;\n            src02.ne[1] = cur_M_size;\n\n            ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true);\n\n            m_offset += cur_M_size;\n        }\n    } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&\n        // detect 0213 permutation, and batch size of 1\n        src0->nb[0] <= src0->nb[2] &&\n        src0->nb[2] <= src0->nb[1] &&\n        src0->nb[1] <= src0->nb[3] &&\n        src1->nb[0] <= src1->nb[2] &&\n        src1->nb[2] <= src1->nb[1] &&\n        src1->nb[1] <= src1->nb[3] &&\n        src0->ne[3] == 1 &&\n        src1->ne[3] == 1 &&\n        src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&\n        src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {\n        ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, 
node_idx);\n    } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&\n               !ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&\n               src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&\n               src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&\n               src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {\n        ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);\n    // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)\n    // when ne12 and ne13 are one.\n    } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&\n               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {\n        ggml_vk_mul_mat_vec_q_f16(ctx, subctx, cgraph, node_idx);\n    } else {\n        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false);\n    }\n}\n\nstatic void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {\n    VK_LOG_DEBUG(\"ggml_vk_mul_mat_id_q_f16((\" << src0 << \", name=\" << src0->name << \", type=\" << src0->type << \", ne0=\" << src0->ne[0] << \", ne1=\" << src0->ne[1] << \", ne2=\" << src0->ne[2] << \", ne3=\" << src0->ne[3] << \", nb0=\" << src0->nb[0] << \", nb1=\" << src0->nb[1] << \", nb2=\" << src0->nb[2] << \", nb3=\" << src0->nb[3];\n    std::cerr << \"), (\" << src1 << \", name=\" << src1->name << \", type=\" << src1->type << \", ne0=\" << src1->ne[0] << \", ne1=\" << src1->ne[1] << \", ne2=\" << src1->ne[2] << \", ne3=\" << src1->ne[3] << \", nb0=\" << src1->nb[0] << \", nb1=\" << src1->nb[1] << \", nb2=\" << src1->nb[2] << \", nb3=\" << src1->nb[3];\n    std::cerr << 
\"), (\" << ids << \", name=\" << ids->name << \", type=\" << ids->type << \", ne0=\" << ids->ne[0] << \", ne1=\" << ids->ne[1] << \", ne2=\" << ids->ne[2] << \", ne3=\" << ids->ne[3] << \", nb0=\" << ids->nb[0] << \", nb1=\" << ids->nb[1] << \", nb2=\" << ids->nb[2] << \", nb3=\" << ids->nb[3];\n    std::cerr << \"), (\" << dst << \", name=\" << dst->name << \", type=\" << dst->type << \", ne0=\" << dst->ne[0] << \", ne1=\" << dst->ne[1] << \", ne2=\" << dst->ne[2] << \", ne3=\" << dst->ne[3] << \", nb0=\" << dst->nb[0] << \", nb1=\" << dst->nb[1] << \", nb2=\" << dst->nb[2] << \", nb3=\" << dst->nb[3] << \"),)\");\n    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT\n    GGML_ASSERT(ids->type == GGML_TYPE_I32);\n\n    const uint64_t ne00 = src0->ne[0];\n    const uint64_t ne01 = src0->ne[1];\n    const uint64_t ne02 = src0->ne[2];\n    // const uint64_t ne03 = src0->ne[3];\n\n    const uint64_t ne10 = src1->ne[0];\n    const uint64_t ne11 = src1->ne[1];\n    const uint64_t ne12 = src1->ne[2];\n    const uint64_t ne13 = src1->ne[3];\n\n    const uint64_t nei0 = ids->ne[0];\n    const uint64_t nei1 = ids->ne[1];\n\n    const uint32_t nbi0 = ids->nb[0];\n    const uint32_t nbi1 = ids->nb[1];\n    const uint32_t nbi2 = ids->nb[2];\n\n    const uint64_t ne20 = dst->ne[0];\n    const uint64_t ne21 = dst->ne[1];\n    // const uint64_t ne22 = dst->ne[2];\n    // const uint64_t ne23 = dst->ne[3];\n\n    const uint64_t n_as = ne02;\n\n    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;\n    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;\n    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;\n    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;\n\n    vk_buffer d_Qx = nullptr;\n    
size_t qx_buf_offset = 0;\n    vk_buffer d_Qy = nullptr;\n    size_t qy_buf_offset = 0;\n    vk_buffer d_ids = nullptr;\n    size_t ids_buf_offset = 0;\n\n    bool src0_uma = false;\n    bool src1_uma = false;\n    bool ids_uma = false;\n\n    if (ctx->device->uma) {\n        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);\n        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);\n        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);\n        src0_uma = d_Qx != nullptr;\n        src1_uma = d_Qy != nullptr;\n        ids_uma = d_ids != nullptr;\n    }\n\n    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf\n    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||\n                              !ggml_vk_dim01_contiguous(src0);\n    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||\n                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||\n                              !ggml_vk_dim01_contiguous(src1);\n\n    // If src0 is BF16, try to use a BF16 x BF16 multiply\n    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;\n\n    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;\n\n    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;\n\n    // Check for mmq first\n    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;\n\n    if (mmp == nullptr) {\n        // Fall back to f16 dequant mul mat\n        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? 
f16_type : src1->type, (ggml_prec)dst->op_params[0]);\n        quantize_y = false;\n    }\n\n    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;\n    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);\n\n    if (qx_needs_dequant) {\n        // Fall back to dequant + f16 mulmat\n        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);\n    }\n\n    // Not implemented\n    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT\n\n    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));\n    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;\n\n    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);\n\n    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {\n        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);\n    }\n    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking\n    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;\n    const uint64_t x_ne = ggml_nelements(src0);\n    const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;\n    const uint64_t d_ne = ggml_nelements(dst);\n\n    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);\n    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);\n    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;\n    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);\n    const uint64_t ids_sz = nbi2;\n    const uint64_t d_sz = sizeof(float) * d_ne;\n\n    vk_pipeline to_fp16_vk_0 = nullptr;\n    vk_pipeline to_fp16_vk_1 = nullptr;\n    vk_pipeline to_q8_1 = nullptr;\n\n    if (x_non_contig) {\n        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);\n    } else {\n        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);\n    }\n    if (y_non_contig) {\n        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);\n    } else {\n        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);\n    }\n    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT\n    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT\n\n    if (quantize_y) {\n        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);\n    }\n    vk_pipeline count_experts = ctx->device->pipeline_count_experts;\n\n    uint32_t expert_count_size = sizeof(uint32_t) * n_as;\n\n    {\n        if (\n                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||\n                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {\n            GGML_ABORT(\"Requested preallocation size is too large\");\n        }\n        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {\n            ctx->prealloc_size_x = x_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {\n            ctx->prealloc_size_y = y_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if (ctx->prealloc_size_split_k < expert_count_size) {\n            ctx->prealloc_size_split_k = expert_count_size;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n\n        // Request descriptor sets\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n        if 
(qx_needs_dequant) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);\n        }\n        if (qy_needs_dequant) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);\n        }\n        if (quantize_y) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);\n        }\n        ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);\n    }\n\n    vk_buffer d_D = dst_buf_ctx->dev_buffer;\n    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;\n    GGML_ASSERT(d_D != nullptr);\n    vk_buffer d_X;\n    uint64_t x_buf_offset = 0;\n    vk_buffer d_Y;\n    uint64_t y_buf_offset = 0;\n    if (!src0_uma) {\n        d_Qx = src0_buf_ctx->dev_buffer;\n        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;\n        GGML_ASSERT(d_Qx != nullptr);\n    }\n    if (!src1_uma) {\n        d_Qy = src1_buf_ctx->dev_buffer;\n        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;\n        GGML_ASSERT(d_Qy != nullptr);\n    }\n    if (!ids_uma) {\n        d_ids = ids_buf_ctx->dev_buffer;\n        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;\n        GGML_ASSERT(d_ids != nullptr);\n    }\n    if (qx_needs_dequant) {\n        d_X = ctx->prealloc_x;\n        GGML_ASSERT(d_X->size >= x_sz);\n    } else {\n        d_X = d_Qx;\n        x_buf_offset = qx_buf_offset;\n        GGML_ASSERT(qx_sz == x_sz);\n    }\n    if (qy_needs_dequant) {\n        d_Y = ctx->prealloc_y;\n        GGML_ASSERT(d_Y->size >= y_sz);\n    } else if (quantize_y) {\n        d_Y = ctx->prealloc_y;\n        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz, 144) * 144);\n    } else {\n        d_Y = d_Qy;\n        y_buf_offset = qy_buf_offset;\n        GGML_ASSERT(qy_sz == y_sz);\n    }\n\n    if (x_non_contig || qx_needs_dequant) {\n        if (ctx->prealloc_x_need_sync) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n    }\n    // Count how many times each expert is used\n    vk_subbuffer 
expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);\n    if (ctx->prealloc_split_k_need_sync) {\n        ggml_vk_sync_buffers(ctx, subctx);\n    }\n    {\n        const std::vector<uint32_t> pc = { (uint32_t)nei0,\n                                           (uint32_t)nei1,\n                                           (uint32_t)(nbi0 / ggml_type_size(ids->type)),\n                                           (uint32_t)(nbi1 / ggml_type_size(ids->type)),\n                                           (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };\n        ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,\n            { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});\n    }\n\n    if (x_non_contig) {\n        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));\n    } else if (qx_needs_dequant) {\n        const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };\n        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,\n            { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});\n    }\n    if (y_non_contig) {\n        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||\n            ctx->prealloc_y_last_tensor_used != src1) {\n            if (ctx->prealloc_y_need_sync) {\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));\n            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();\n            ctx->prealloc_y_last_tensor_used = src1;\n        }\n    }\n    if (quantize_y) {\n        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||\n            
ctx->prealloc_y_last_tensor_used != src1) {\n            if (ctx->prealloc_y_need_sync) {\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n            ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);\n            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();\n            ctx->prealloc_y_last_tensor_used = src1;\n        }\n    }\n    ggml_vk_sync_buffers(ctx, subctx);\n\n    uint32_t stride_batch_x = ne00*ne01;\n    uint32_t stride_batch_y = ne10*ne11;\n\n    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {\n        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);\n    }\n\n    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {\n        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);\n    }\n\n    // compute\n    ggml_vk_matmul_id(\n        ctx, subctx, pipeline,\n        { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },\n        { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,\n        ne01, ne21, ne10, ne10, ne10, ne01,\n        stride_batch_x, stride_batch_y, ne20*ne21,\n        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n\n    );  // NOLINT\n\n    if (x_non_contig || qx_needs_dequant) {\n        ctx->prealloc_x_need_sync = true;\n    }\n    if (y_non_contig || quantize_y) {\n        ctx->prealloc_y_need_sync = true;\n    }\n    ctx->prealloc_split_k_need_sync = true;\n}\n\nstatic void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {\n    ggml_tensor * dst = cgraph->nodes[node_idx];\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n    ggml_tensor * ids = dst->src[2];\n    VK_LOG_DEBUG(\"ggml_vk_mul_mat_vec_id_q_f16((\" << src0 << \", name=\" << src0->name << \", type=\" << src0->type << \", ne0=\" << src0->ne[0] << \", ne1=\" << src0->ne[1] 
<< \", ne2=\" << src0->ne[2] << \", ne3=\" << src0->ne[3] << \", nb0=\" << src0->nb[0] << \", nb1=\" << src0->nb[1] << \", nb2=\" << src0->nb[2] << \", nb3=\" << src0->nb[3];\n    std::cerr << \"), (\" << src1 << \", name=\" << src1->name << \", type=\" << src1->type << \", ne0=\" << src1->ne[0] << \", ne1=\" << src1->ne[1] << \", ne2=\" << src1->ne[2] << \", ne3=\" << src1->ne[3] << \", nb0=\" << src1->nb[0] << \", nb1=\" << src1->nb[1] << \", nb2=\" << src1->nb[2] << \", nb3=\" << src1->nb[3];\n    std::cerr << \"), (\" << ids << \", name=\" << ids->name << \", type=\" << ids->type << \", ne0=\" << ids->ne[0] << \", ne1=\" << ids->ne[1] << \", ne2=\" << ids->ne[2] << \", ne3=\" << ids->ne[3] << \", nb0=\" << ids->nb[0] << \", nb1=\" << ids->nb[1] << \", nb2=\" << ids->nb[2] << \", nb3=\" << ids->nb[3];\n    std::cerr << \"), (\" << dst << \", name=\" << dst->name << \", type=\" << dst->type << \", ne0=\" << dst->ne[0] << \", ne1=\" << dst->ne[1] << \", ne2=\" << dst->ne[2] << \", ne3=\" << dst->ne[3] << \", nb0=\" << dst->nb[0] << \", nb1=\" << dst->nb[1] << \", nb2=\" << dst->nb[2] << \", nb3=\" << dst->nb[3];\n    std::cerr << \"))\");\n    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT\n    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT\n    GGML_ASSERT(ids->type == GGML_TYPE_I32);\n\n    const uint64_t ne00 = src0->ne[0];\n    const uint64_t ne01 = src0->ne[1];\n    // const uint64_t ne02 = src0->ne[2];\n    // const uint64_t ne03 = src0->ne[3];\n\n    const uint64_t ne10 = src1->ne[0];\n    const uint64_t ne11 = src1->ne[1];\n    const uint64_t ne12 = src1->ne[2];\n    // const uint64_t ne13 = src1->ne[3];\n\n    const uint64_t nei0 = ids->ne[0];\n    const uint64_t nei1 = ids->ne[1];\n    const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));\n\n    const uint64_t ne20 = 
dst->ne[0];\n    const uint64_t ne21 = dst->ne[1];\n    // const uint64_t ne22 = dst->ne[2];\n    // const uint64_t ne23 = dst->ne[3];\n\n    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);\n    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);\n\n    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;\n    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);\n\n    vk_pipeline to_fp16_vk_0 = nullptr;\n    vk_pipeline to_fp16_vk_1 = nullptr;\n    if (x_non_contig) {\n        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);\n    }\n    if (y_non_contig) {\n        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);\n    } else {\n        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);\n    }\n\n    // Check for mmq first\n    vk_pipeline dmmv = quantize_y ? 
ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;\n    vk_pipeline to_q8_1 = nullptr;\n\n    if (dmmv == nullptr) {\n        // Fall back to f16 dequant mul mat\n        dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);\n        quantize_y = false;\n    }\n\n    if (quantize_y) {\n        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);\n    }\n\n    const bool qx_needs_dequant = x_non_contig;\n    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);\n\n    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {\n        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);\n    }\n\n    // Not implemented\n    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT\n    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT\n    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT\n    GGML_ASSERT(dmmv != nullptr);\n\n    const uint64_t x_ne = ggml_nelements(src0);\n    const uint64_t y_ne = ggml_nelements(src1);\n\n    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);\n    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;\n    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :\n                                       (f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);\n\n    {\n        if (\n                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||\n                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {\n            GGML_ABORT(\"Requested preallocation size is too large\");\n        }\n        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {\n            ctx->prealloc_size_x = x_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {\n            ctx->prealloc_size_y = y_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n\n        // Request descriptor sets\n        if (qx_needs_dequant) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);\n        }\n        if (qy_needs_dequant) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);\n        }\n        if (quantize_y) {\n            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);\n        }\n        ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);\n    }\n\n    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);\n    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);\n    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);\n    vk_subbuffer d_ids = ggml_vk_tensor_subbuffer(ctx, ids);\n    vk_subbuffer d_F0 = d_D;\n    vk_subbuffer d_X, d_Y;\n\n    if (qx_needs_dequant) {\n        d_X = { ctx->prealloc_x, 0, ctx->prealloc_x->size };\n    } else {\n        d_X = d_Qx;\n    }\n    if (qy_needs_dequant || quantize_y) {\n        d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };\n    } else {\n        d_Y = d_Qy;\n    }\n\n    if (x_non_contig) {\n        if (ctx->prealloc_x_need_sync) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n    }\n\n    if (x_non_contig) {\n        
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));\n        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, d_Qx, d_X);\n    }\n    if (y_non_contig) {\n        GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);\n        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||\n            ctx->prealloc_y_last_tensor_used != src1) {\n            if (ctx->prealloc_y_need_sync) {\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);\n            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();\n            ctx->prealloc_y_last_tensor_used = src1;\n        }\n    }\n    if (quantize_y) {\n        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||\n            ctx->prealloc_y_last_tensor_used != src1) {\n            if (ctx->prealloc_y_need_sync) {\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n            ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);\n            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();\n            ctx->prealloc_y_last_tensor_used = src1;\n        }\n    }\n\n    uint32_t stride_batch_y = ne10*ne11;\n\n    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {\n        stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);\n    }\n\n    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];\n\n    uint32_t groups_x = ne01;\n    uint32_t groups_z = 1;\n\n    if (ne01 > max_groups_x) {\n        groups_z = 64;\n        groups_x = CEIL_DIV(groups_x, groups_z);\n    }\n\n    uint32_t fusion_flags = 0;\n\n    if (ctx->num_additional_fused_ops > 0) {\n        const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];\n\n        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);\n\n        if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {\n            
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE0;\n        } else {\n            GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);\n            fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;\n        }\n    }\n\n    vk_subbuffer d_F1 = d_D;\n    if (ctx->num_additional_fused_ops > 1) {\n        const ggml_tensor * scale = cgraph->nodes[node_idx + 2]->src[1];\n\n        d_F1 = ggml_vk_tensor_subbuffer(ctx, scale);\n        fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;\n    }\n\n    // Loop over the batch dimension\n    for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {\n        const vk_mat_vec_id_push_constants pc = {\n            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,\n            (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),\n            fusion_flags,\n            (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1\n        };\n        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,\n            {\n                d_X,\n                d_Y,\n                d_D,\n                d_F0,\n                d_F1,\n                d_ids,\n            },\n            pc, { groups_x, (uint32_t)nei0, groups_z });\n    }\n\n    if (x_non_contig) {\n        ctx->prealloc_x_need_sync = true;\n    }\n    if (y_non_contig || quantize_y) {\n        ctx->prealloc_y_need_sync = true;\n    }\n}\n\nstatic bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int node_idx) {\n    ggml_tensor * dst = cgraph->nodes[node_idx];\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src2 = dst->src[2];\n    return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));\n}\n\nstatic void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {\n    ggml_tensor * dst = cgraph->nodes[node_idx];\n    ggml_tensor * src0 = dst->src[0];\n    ggml_tensor * src1 = dst->src[1];\n    ggml_tensor * src2 = 
dst->src[2];\n    VK_LOG_DEBUG(\"ggml_vk_mul_mat_id(\" << src0 << \", \" << src1 << \", \" << src2 << \", \" << dst << \")\");\n    if (ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {\n        ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, cgraph, node_idx);\n    } else {\n        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst);\n    }\n}\n\nstatic bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {\n    GGML_UNUSED(f32acc);\n    // Needs to be kept up to date on shader changes\n    const uint32_t wg_size = params.workgroup_size;\n    const uint32_t Br = params.block_rows;\n    const uint32_t Bc = params.block_cols;\n\n    const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);\n\n    // tmpsh is overestimated slightly\n    const uint32_t tmpsh = wg_size * sizeof(float);\n    const uint32_t tmpshv4 = wg_size * 4 * float_type_size;\n\n    const uint32_t masksh = Bc * (Br + 1) * float_type_size;\n\n    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;\n\n    const uint32_t D = std::max(hsk, hsv);\n    const uint32_t kvsh = params.shmem_staging ? 
Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;\n\n    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;\n    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;\n\n    VK_LOG_DEBUG(\"ggml_vk_flash_attn_scalar_shmem_support(HSK=\" << hsk << \", HSV=\" << hsv << \", total_size=\" << total_size << \", supported=\" << supported);\n\n    return supported;\n}\n\nstatic bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {\n    // Needs to be kept up to date on shader changes\n    const uint32_t Br = params.block_rows;\n    const uint32_t Bc = params.block_cols;\n\n    const uint32_t MatBr = 16, MatBc = 16;\n\n    const uint32_t row_split = Bc / MatBc;\n\n    const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);\n    const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16);\n\n    const uint32_t acctype = f32acc ? 4 : 2;\n    const uint32_t f16vec4 = 8;\n\n    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);\n\n    const uint32_t qstride = hsk_pad / 4 + 2;\n    const uint32_t Qf = Br * qstride * f16vec4;\n\n    const uint32_t psh_stride = Br / 4 + 2;\n    const uint32_t Psh = Bc * psh_stride * f16vec4;\n\n    const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;\n    const uint32_t sfsh = Bc * sfshstride * acctype;\n\n    const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2;\n    const uint32_t vsh_stride = MatBc / 4 * row_split;\n    const uint32_t ksh = ((kvshstride >= vsh_stride) ? 
(Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;\n\n    const uint32_t osh_stride = params.row_split * MatBr / 4;\n    const uint32_t pvsh = MatBc * osh_stride * f16vec4;\n\n    const uint32_t slope = Br * acctype;\n\n    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope;\n    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;\n\n    VK_LOG_DEBUG(\"ggml_vk_flash_attn_coopmat_shmem_support(HSK=\" << hsk << \", HSV=\" << hsv << \", f32acc=\" << f32acc << \", total_size=\" << total_size << \", supported=\" << supported);\n\n    return supported;\n}\n\nstatic void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst) {\n    VK_LOG_DEBUG(\"ggml_vk_flash_attn((\" << q << \", name=\" << q->name << \", type=\" << q->type << \", ne0=\" << q->ne[0] << \", ne1=\" << q->ne[1] << \", ne2=\" << q->ne[2] << \", ne3=\" << q->ne[3] << \", nb0=\" << q->nb[0] << \", nb1=\" << q->nb[1] << \", nb2=\" << q->nb[2] << \", nb3=\" << q->nb[3];\n    std::cerr << \"), (\" << k << \", name=\" << k->name << \", type=\" << k->type << \", ne0=\" << k->ne[0] << \", ne1=\" << k->ne[1] << \", ne2=\" << k->ne[2] << \", ne3=\" << k->ne[3] << \", nb0=\" << k->nb[0] << \", nb1=\" << k->nb[1] << \", nb2=\" << k->nb[2] << \", nb3=\" << k->nb[3];\n    std::cerr << \"), (\" << v << \", name=\" << v->name << \", type=\" << v->type << \", ne0=\" << v->ne[0] << \", ne1=\" << v->ne[1] << \", ne2=\" << v->ne[2] << \", ne3=\" << v->ne[3] << \", nb0=\" << v->nb[0] << \", nb1=\" << v->nb[1] << \", nb2=\" << v->nb[2] << \", nb3=\" << v->nb[3];\n    std::cerr << \"), (\" << dst << \", name=\" << dst->name << \", type=\" << dst->type << \", ne0=\" << dst->ne[0] << \", ne1=\" << dst->ne[1] << \", ne2=\" << dst->ne[2] << \", ne3=\" << dst->ne[3] << \", nb0=\" << dst->nb[0] << \", nb1=\" << dst->nb[1] << 
\", nb2=\" << dst->nb[2] << \", nb3=\" << dst->nb[3];\n    if (sinks) {\n        std::cerr << \"), (\" << sinks << \", name=\" << sinks->name << \", type=\" << sinks->type << \", ne0=\" << sinks->ne[0] << \", ne1=\" << sinks->ne[1] << \", ne2=\" << sinks->ne[2] << \", ne3=\" << sinks->ne[3] << \", nb0=\" << sinks->nb[0] << \", nb1=\" << sinks->nb[1] << \", nb2=\" << sinks->nb[2] << \", nb3=\" << sinks->nb[3];\n    }\n    std::cerr << \"))\");\n\n    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)\n    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)\n    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)\n    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)\n    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)\n\n    const uint32_t nem0 = mask ? mask->ne[0] : 0;\n    const uint32_t nem1 = mask ? mask->ne[1] : 0;\n    const uint32_t nem2 = mask ? mask->ne[2] : 0;\n    const uint32_t nem3 = mask ? 
mask->ne[3] : 0;\n\n    const uint32_t HSK = nek0;\n    const uint32_t HSV = nev0;\n    uint32_t N = neq1;\n    const uint32_t KV = nek1;\n\n    GGML_ASSERT(ne0 == HSV);\n    GGML_ASSERT(ne2 == N);\n\n    // input tensor rows must be contiguous\n    GGML_ASSERT(nbq0 == ggml_type_size(q->type));\n    GGML_ASSERT(nbk0 == ggml_type_size(k->type));\n    GGML_ASSERT(nbv0 == ggml_type_size(v->type));\n\n    GGML_ASSERT(neq0 == HSK);\n\n    GGML_ASSERT(neq1 == N);\n\n    GGML_ASSERT(nev1 == nek1);\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    assert(dst->type == GGML_TYPE_F32);\n    assert(q->type == GGML_TYPE_F32);\n    assert(k->type == v->type);\n\n    uint32_t gqa_ratio = 1;\n    uint32_t qk_ratio = neq2 / nek2;\n    uint32_t workgroups_x = (uint32_t)neq1;\n    uint32_t workgroups_y = (uint32_t)neq2;\n    uint32_t workgroups_z = (uint32_t)neq3;\n\n    const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;\n\n    // For scalar/coopmat1 FA, we can use the \"large\" size to accommodate qga.\n    // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).\n    vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);\n    const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);\n\n    if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&\n        qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {\n        // grouped query attention - make the N dimension equal to gqa_ratio, reduce\n        // workgroups proportionally in y dimension. 
The shader will detect gqa_ratio > 1\n        // and change addressing calculations to index Q's dimension 2.\n        gqa_ratio = qk_ratio;\n        N = gqa_ratio;\n        workgroups_y /= gqa_ratio;\n    }\n\n    tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);\n\n    const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));\n    uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));\n    uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));\n\n    // For F32, the shader treats it as a block of size 4 (for vec4 loads)\n    if (k->type == GGML_TYPE_F32) {\n        k_stride /= 4;\n    }\n    if (v->type == GGML_TYPE_F32) {\n        v_stride /= 4;\n    }\n\n    const uint32_t alignment = tuning_params.block_cols;\n    bool aligned = (KV % alignment) == 0 &&\n                   // the \"aligned\" shader variant will forcibly align strides, for performance\n                   (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;\n\n    // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.\n    if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {\n        aligned = false;\n    }\n\n    float scale         = 1.0f;\n    float max_bias      = 0.0f;\n    float logit_softcap = 0.0f;\n\n    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));\n    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));\n    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));\n\n    if (logit_softcap != 0) {\n        scale /= logit_softcap;\n    }\n\n    // Only use mask opt when the mask is fairly large. 
This hasn't been tuned extensively.\n    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;\n    vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,\n                                                                   mask != nullptr, use_mask_opt, logit_softcap != 0);\n\n    vk_pipeline pipeline = nullptr;\n\n    {\n        std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);\n        auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];\n        auto it = pipelines.find(fa_pipeline_state);\n        if (it != pipelines.end()) {\n            pipeline = it->second;\n        } else {\n            pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();\n        }\n    }\n\n    assert(pipeline);\n    // Compile early to initialize wg_denoms.\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n    uint32_t split_kv = KV;\n    uint32_t split_k = 1;\n\n    // Intel Alchemist prefers more workgroups\n    const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;\n\n    // Use a placeholder core count if one isn't available. split_k is a big help for perf.\n    const uint32_t shader_core_count = ctx->device->shader_core_count ? 
ctx->device->shader_core_count * shader_core_count_multiplier : 16;\n\n    const uint32_t Br = fa_pipeline_state.Br;\n    const uint32_t Bc = fa_pipeline_state.Bc;\n\n    GGML_ASSERT(Br == pipeline->wg_denoms[0]);\n    const uint32_t Tr = CEIL_DIV(N, Br);\n\n    // Try to use split_k when KV is large enough to be worth the overhead.\n    if (gqa_ratio > 1 && workgroups_x <= Br) {\n        split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);\n    } else if (gqa_ratio <= 1) {\n        uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;\n        if (total_wgs_no_split < shader_core_count * 2) {\n            split_k = shader_core_count * 2 / total_wgs_no_split;\n        }\n    }\n\n    if (split_k > 1) {\n        // Try to evenly split KV into split_k chunks, but it needs to be a multiple\n        // of \"align\", so recompute split_k based on that.\n        split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);\n        split_k = CEIL_DIV(KV, split_kv);\n    }\n\n    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)\n    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.\n    // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].\n    // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].\n    const uint64_t split_k_size = split_k > 1 ? 
(HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;\n    if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {\n        GGML_ABORT(\"Requested preallocation size is too large\");\n    }\n    if (ctx->prealloc_size_split_k < split_k_size) {\n        ctx->prealloc_size_split_k = split_k_size;\n        ggml_vk_preallocate_buffers(ctx, subctx);\n    }\n\n    const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);\n    const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;\n\n    vk_pipeline pipeline_fa_mask_opt = nullptr;\n    if (use_mask_opt) {\n        std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);\n        auto &pipelines = ctx->device->pipeline_fa_mask_opt;\n        auto it = pipelines.find({Br, Bc});\n        if (it != pipelines.end()) {\n            pipeline_fa_mask_opt = it->second;\n        } else {\n            pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();\n        }\n        assert(pipeline_fa_mask_opt);\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);\n\n        if (ctx->prealloc_size_y < mask_opt_size) {\n            ctx->prealloc_size_y = mask_opt_size;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if (ctx->prealloc_y_need_sync) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n    }\n\n    const uint32_t n_head_kv   = neq2;\n    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);\n    vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);\n    vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);\n    vk_subbuffer mask_buf = mask ? 
ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;\n    vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;\n    vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;\n\n    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;\n\n    if (use_mask_opt)\n    {\n        const vk_op_flash_attn_mask_opt_push_constants opt_pc = {\n            nem0,\n            nem1,\n            nem2,\n            (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),\n            (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),\n            (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),\n            mask_opt_num_dwords,\n            mask_opt_num_dwords * CEIL_DIV(nem1, Br),\n            mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,\n        };\n\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,\n                                  { mask_buf, mask_opt_buf }, opt_pc,\n                                  { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });\n        ggml_vk_sync_buffers(ctx, subctx);\n    }\n\n    const vk_flash_attn_push_constants pc = { N, KV,\n                                              (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,\n                                              (uint32_t)neq2, (uint32_t)neq3,\n                                              (uint32_t)nek2, (uint32_t)nek3,\n                                              (uint32_t)nev2, (uint32_t)nev3,\n                                              nem1, nem2, nem3,\n                                              q_stride, (uint32_t)nbq2, (uint32_t)nbq3,\n                                              k_stride, (uint32_t)nbk2, (uint32_t)nbk3,\n                                              v_stride, (uint32_t)nbv2, (uint32_t)nbv3,\n                                              scale, max_bias, logit_softcap,\n                                              mask_n_head_log2, m0, m1,\n                                     
         gqa_ratio, split_kv, split_k };\n\n    if (split_k > 1) {\n        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);\n\n        if (ctx->prealloc_split_k_need_sync) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n\n        // We reuse workgroups_x to mean the number of splits, so we need to\n        // cancel out the divide by wg_denoms[0].\n        uint32_t dispatch_x;\n        if (gqa_ratio > 1) {\n            workgroups_x *= pipeline->wg_denoms[0];\n            dispatch_x = split_k * workgroups_x;\n        } else {\n            dispatch_x = Tr * split_k * pipeline->wg_denoms[0];\n        }\n\n        vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},\n                                    pc, { dispatch_x, workgroups_y, workgroups_z });\n\n        ggml_vk_sync_buffers(ctx, subctx);\n        const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };\n        ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,\n                                    {split_k_buf, sinks_buf, dst_buf},\n                                    pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });\n        ctx->prealloc_split_k_need_sync = true;\n    } else {\n        if (gqa_ratio > 1) {\n            // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms\n            workgroups_x *= pipeline->wg_denoms[0];\n        }\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},\n                                    pc, { workgroups_x, workgroups_y, workgroups_z });\n    }\n}\n\nstatic 
vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, uint32_t K, uint32_t NPQ) {\n    auto n_tiles = [&](vk_conv_shapes s) {\n        return CEIL_DIV(K, vk_conv_block_sizes[s].K)\n            * CEIL_DIV(NPQ, vk_conv_block_sizes[s].NPQ);\n    };\n\n    // We can't query number of shader cores on Intel, use 32 as a placeholder\n    // so small convolutions will still choose a smaller tile.\n    const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;\n\n    if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {\n        return CONV_SHAPE_128x128;\n    } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) {\n        return CONV_SHAPE_32x256;\n    } else {\n        return CONV_SHAPE_64x32;\n    }\n}\n\nstatic vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {\n    switch (op) {\n    case GGML_OP_GET_ROWS:\n        GGML_ASSERT(src1->type == GGML_TYPE_I32);\n        if (src0->type == GGML_TYPE_I32) {\n            // i32 src only supports i32 result\n            GGML_ASSERT(dst->type == GGML_TYPE_I32);\n            return ctx->device->pipeline_get_rows[src0->type];\n        }\n        if (dst->type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_get_rows[src0->type];\n        }\n        if (dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_get_rows_f32[src0->type];\n        }\n        return nullptr;\n    case GGML_OP_ACC:\n        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_acc_f32;\n        }\n        return nullptr;\n    case GGML_OP_SET:\n        if (src0->type == src1->type && src0->type == dst->type &&\n            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {\n            return 
ctx->device->pipeline_set_f32;\n        }\n        return nullptr;\n    case GGML_OP_ADD:\n    case GGML_OP_SUB:\n    case GGML_OP_MUL:\n    case GGML_OP_DIV:\n        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||\n            (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||\n            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {\n            return nullptr;\n        }\n        switch (op) {\n        case GGML_OP_ADD:\n        {\n            if (ctx->num_additional_fused_ops > 0) {\n                if (ctx->do_add_rms_partials) {\n                    return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];\n                } else {\n                    return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];\n                }\n            }\n            if (ctx->do_add_rms_partials) {\n                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;\n                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];\n            } else {\n                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;\n                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];\n            }\n        }\n        case GGML_OP_SUB:\n        {\n            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;\n            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];\n        }\n        case GGML_OP_MUL:\n        {\n            auto pipelines = ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;\n            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];\n        }\n        case GGML_OP_DIV:\n        {\n            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;\n            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];\n        }\n        default:\n            break;\n        }\n        return nullptr;\n    case GGML_OP_ADD_ID:\n        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_add_id_f32;\n        }\n        return nullptr;\n    case GGML_OP_CONCAT:\n        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_concat_f32;\n        }\n        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_concat_f16;\n        }\n        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {\n            return ctx->device->pipeline_concat_i32;\n        }\n        return nullptr;\n    case GGML_OP_UPSCALE:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));\n            switch (mode) {\n                case GGML_SCALE_MODE_NEAREST:\n                    return ctx->device->pipeline_upscale_nearest_f32;\n                case GGML_SCALE_MODE_BILINEAR:\n                    return ctx->device->pipeline_upscale_bilinear_f32;\n                case GGML_SCALE_MODE_BICUBIC:\n                    return ctx->device->pipeline_upscale_bicubic_f32;\n                case 
GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:\n                    return ctx->device->pipeline_upscale_bilinear_antialias_f32;\n                default:\n                    return nullptr;\n            }\n        }\n        return nullptr;\n    case GGML_OP_SCALE:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_scale_f32;\n        }\n        return nullptr;\n    case GGML_OP_SQR:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_sqr_f32;\n        }\n        return nullptr;\n    case GGML_OP_SQRT:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_sqrt_f32;\n        }\n        return nullptr;\n    case GGML_OP_SIN:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_sin_f32;\n        }\n        return nullptr;\n    case GGML_OP_COS:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_cos_f32;\n        }\n        return nullptr;\n    case GGML_OP_LOG:\n        if (src0->type == dst->type &&\n            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {\n            return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];\n        }\n        return nullptr;\n    case GGML_OP_TRI:\n        if (src0->type == dst->type &&\n            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {\n            return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];\n        }\n        return nullptr;\n    case GGML_OP_DIAG:\n        if (src0->type == dst->type &&\n            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {\n            return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];\n        }\n        return nullptr;\n    case GGML_OP_CLAMP:\n        if (src0->type == GGML_TYPE_F32 && 
dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_clamp_f32;\n        }\n        return nullptr;\n    case GGML_OP_PAD:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_pad_f32;\n        }\n        return nullptr;\n    case GGML_OP_ROLL:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_roll_f32;\n        }\n        return nullptr;\n    case GGML_OP_REPEAT:\n        if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {\n            return ctx->device->pipeline_repeat_f32;\n        }\n        return nullptr;\n    case GGML_OP_REPEAT_BACK:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_repeat_back_f32;\n        }\n        return nullptr;\n    case GGML_OP_CPY:\n    case GGML_OP_CONT:\n    case GGML_OP_DUP:\n        return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);\n    case GGML_OP_SET_ROWS:\n        if (src1->type == GGML_TYPE_I64) {\n            return ctx->device->pipeline_set_rows_i64[dst->type];\n        } else {\n            return ctx->device->pipeline_set_rows_i32[dst->type];\n        }\n    case GGML_OP_SILU_BACK:\n        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_silu_back_f32;\n        }\n        return nullptr;\n    case GGML_OP_NORM:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_norm_f32;\n        }\n        return nullptr;\n    case GGML_OP_GROUP_NORM:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_group_norm_f32;\n        }\n        return nullptr;\n    case GGML_OP_RMS_NORM:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n 
           if (ctx->do_add_rms_partials) {\n                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;\n            } else {\n                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;\n            }\n        }\n        return nullptr;\n    case GGML_OP_RMS_NORM_BACK:\n        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_rms_norm_back_f32;\n        }\n        return nullptr;\n    case GGML_OP_L2_NORM:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_l2_norm_f32;\n        }\n        return nullptr;\n    case GGML_OP_UNARY:\n        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||\n            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||\n            (src0->type != dst->type)) {\n            return nullptr;\n        }\n\n        switch (ggml_get_unary_op(dst)) {\n            case GGML_UNARY_OP_EXP:\n                return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_ELU:\n                return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_SILU:\n                return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_GELU:\n                return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_GELU_ERF:\n                return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_GELU_QUICK:\n                return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_RELU:\n                return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];\n            case 
GGML_UNARY_OP_XIELU:\n                return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_NEG:\n                return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_TANH:\n                return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_SIGMOID:\n                return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_HARDSIGMOID:\n                return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_HARDSWISH:\n                return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_ABS:\n                return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_SOFTPLUS:\n                return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_STEP:\n                return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_ROUND:\n                return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_CEIL:\n                return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_FLOOR:\n                return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_TRUNC:\n                return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];\n            case GGML_UNARY_OP_SGN:\n                return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];\n            default:\n                break;\n        }\n        return nullptr;\n    case GGML_OP_GLU:\n        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||\n            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||\n            (src0->type != dst->type)) {\n            return 
nullptr;\n        }\n\n        switch (ggml_get_glu_op(dst)) {\n            case GGML_GLU_OP_GEGLU:\n                return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];\n            case GGML_GLU_OP_REGLU:\n                return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];\n            case GGML_GLU_OP_SWIGLU:\n                return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];\n            case GGML_GLU_OP_SWIGLU_OAI:\n                return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];\n            case GGML_GLU_OP_GEGLU_ERF:\n                return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];\n            case GGML_GLU_OP_GEGLU_QUICK:\n                return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];\n            default:\n                break;\n        }\n        return nullptr;\n    case GGML_OP_DIAG_MASK_INF:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_diag_mask_inf_f32;\n        }\n        return nullptr;\n    case GGML_OP_SOFT_MAX:\n        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);\n        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);\n\n        if (ctx->num_additional_fused_ops) {\n            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));\n            GGML_ASSERT(idx < num_topk_moe_pipelines);\n            // use n_experts from push constant if it's not equal to the power of two spec constant\n            bool use_push = dst->ne[0] != (1u << idx);\n            return ctx->device->pipeline_topk_moe[idx][use_push];\n        }\n\n        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {\n            return src0->ne[0] > 1024 ? 
ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;\n        }\n        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {\n            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;\n        }\n        return nullptr;\n    case GGML_OP_SOFT_MAX_BACK:\n        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_soft_max_back_f32;\n        }\n        return nullptr;\n    case GGML_OP_ROPE:\n    case GGML_OP_ROPE_BACK:\n        {\n            const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst;\n            const int mode = ((const int32_t *) rope->op_params)[2];\n            const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;\n            const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;\n            const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n\n            if (is_neox) {\n                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n                    return ctx->device->pipeline_rope_neox_f32;\n                }\n                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n                    return ctx->device->pipeline_rope_neox_f32_f16;\n                }\n                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n                    return ctx->device->pipeline_rope_neox_f16;\n                }\n            } else if (is_mrope && !is_vision) {\n                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n                    return ctx->device->pipeline_rope_multi_f32;\n                }\n                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n                    return ctx->device->pipeline_rope_multi_f32_f16;\n                }\n                if (src0->type == GGML_TYPE_F16 && 
dst->type == GGML_TYPE_F16) {\n                    return ctx->device->pipeline_rope_multi_f16;\n                }\n            } else if (is_vision) {\n                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n                    return ctx->device->pipeline_rope_vision_f32;\n                }\n                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n                    return ctx->device->pipeline_rope_vision_f16;\n                }\n            } else {\n                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n                    return ctx->device->pipeline_rope_norm_f32;\n                }\n                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n                    return ctx->device->pipeline_rope_norm_f32_f16;\n                }\n                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n                    return ctx->device->pipeline_rope_norm_f16;\n                }\n            }\n            return nullptr;\n        }\n    case GGML_OP_SUM:\n    case GGML_OP_SUM_ROWS:\n    case GGML_OP_MEAN:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_sum_rows_f32;\n        }\n        return nullptr;\n    case GGML_OP_CUMSUM:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            if (src0->ne[0] <= 512) {\n                return ctx->device->pipeline_cumsum_small_f32;\n            } else {\n                return ctx->device->pipeline_cumsum_f32;\n            }\n        }\n        return nullptr;\n    case GGML_OP_SOLVE_TRI:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n\n            vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);\n\n            vk_pipeline pipeline = nullptr;\n\n            {\n                std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);\n             
   auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);\n                if (it != ctx->device->pipeline_solve_tri_f32.end()) {\n                    pipeline = it->second;\n                } else {\n                    ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();\n                }\n            }\n\n            return pipeline;\n        }\n        return nullptr;\n    case GGML_OP_ARGMAX:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {\n            return ctx->device->pipeline_argmax_f32;\n        }\n        return nullptr;\n    case GGML_OP_COUNT_EQUAL:\n        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {\n            return ctx->device->pipeline_count_equal_i32;\n        }\n        return nullptr;\n    case GGML_OP_IM2COL:\n        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_im2col_f32;\n        }\n        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_im2col_f32_f16;\n        }\n        return nullptr;\n    case GGML_OP_IM2COL_3D:\n        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_im2col_3d_f32;\n        }\n        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_im2col_3d_f32_f16;\n        }\n        return nullptr;\n    case GGML_OP_TIMESTEP_EMBEDDING:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_timestep_embedding_f32;\n        }\n        return nullptr;\n    case GGML_OP_CONV_TRANSPOSE_1D:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_conv_transpose_1d_f32;\n        }\n        return nullptr;\n    case 
GGML_OP_POOL_2D:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_pool2d_f32;\n        }\n        return nullptr;\n    case GGML_OP_RWKV_WKV6:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_rwkv_wkv6_f32;\n        }\n        return nullptr;\n    case GGML_OP_RWKV_WKV7:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_rwkv_wkv7_f32;\n        }\n        return nullptr;\n    case GGML_OP_GATED_DELTA_NET:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            const uint32_t S_v = dst->src[2]->ne[0];\n            const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;\n            uint32_t si;\n            switch (S_v) {\n                case 32:  si = 0; break;\n                case 64:  si = 1; break;\n                case 128: si = 2; break;\n                default: return nullptr;\n            }\n            return ctx->device->pipeline_gated_delta_net[si][kda];\n        }\n        return nullptr;\n    case GGML_OP_SSM_SCAN:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            const uint32_t d_state = src0->ne[0];\n            if (d_state == 128) {\n                return ctx->device->pipeline_ssm_scan_f32_d128;\n            } else if (d_state == 256) {\n                return ctx->device->pipeline_ssm_scan_f32_d256;\n            }\n        }\n        return nullptr;\n    case GGML_OP_SSM_CONV:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_ssm_conv_f32;\n        }\n        return nullptr;\n    case GGML_OP_OPT_STEP_ADAMW:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_opt_step_adamw_f32;\n        }\n        return nullptr;\n    case 
GGML_OP_OPT_STEP_SGD:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_opt_step_sgd_f32;\n        }\n        return nullptr;\n    case GGML_OP_LEAKY_RELU:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_leaky_relu_f32;\n        }\n        return nullptr;\n    case GGML_OP_CONV_2D:\n    case GGML_OP_CONV_TRANSPOSE_2D:\n        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            uint32_t K = dst->ne[2]; // Cout\n            uint32_t NPQ = dst->ne[3] * dst->ne[1] * dst->ne[0]; // N * OH * OW\n            vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, K, NPQ);\n\n            bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;\n            uint32_t KW = (uint32_t)src0->ne[0];\n            uint32_t KH = (uint32_t)src0->ne[1];\n            uint32_t s0 = (uint32_t)(ggml_get_op_params_i32(dst, 0));\n            uint32_t s1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 1) : s0;\n            uint32_t p0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 2) : 0;\n            uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0;\n            uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1;\n            uint32_t d1 = !transpose ? 
(uint32_t)ggml_get_op_params_i32(dst, 5) : 1;\n            vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);\n\n            std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;\n            if (op == GGML_OP_CONV_2D) {\n                if (src0->type == GGML_TYPE_F32) {\n                    pipelines = &ctx->device->pipeline_conv2d_f32[shape];\n                } else if (src0->type == GGML_TYPE_F16) {\n                    pipelines = &ctx->device->pipeline_conv2d_f16_f32[shape];\n                }\n            } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {\n                if (src0->type == GGML_TYPE_F32) {\n                    pipelines = &ctx->device->pipeline_conv_transpose_2d_f32[shape];\n                } else if (src0->type == GGML_TYPE_F16) {\n                    pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];\n                }\n            }\n\n            vk_pipeline pipeline = nullptr;\n\n            {\n                std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);\n                auto it = pipelines->find(conv2d_pipeline_state);\n                if (it != pipelines->end()) {\n                    pipeline = it->second;\n                } else {\n                    (*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();\n                }\n            }\n\n            return pipeline;\n        }\n        return nullptr;\n    case GGML_OP_CONV_2D_DW:\n        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            if (ggml_is_contiguous(src1)) {\n                return ctx->device->pipeline_conv2d_dw_whcn_f32;\n            } else if (ggml_is_contiguous_channels(src1)) {\n                return ctx->device->pipeline_conv2d_dw_cwhn_f32;\n            }\n        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {\n            if (ggml_is_contiguous(src1)) {\n                return 
ctx->device->pipeline_conv2d_dw_whcn_f16_f32;\n            } else if (ggml_is_contiguous_channels(src1)) {\n                return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;\n            }\n        }\n        return nullptr;\n    case GGML_OP_ADD1:\n        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_add1_f16_f16;\n        }\n        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {\n            return ctx->device->pipeline_add1_f16_f32;\n        }\n        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_add1_f32_f32;\n        }\n        return nullptr;\n    case GGML_OP_ARANGE:\n        if (dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_arange_f32;\n        }\n        return nullptr;\n    case GGML_OP_FILL:\n        if (dst->type == GGML_TYPE_F32) {\n            return ctx->device->pipeline_fill_f32;\n        }\n        return nullptr;\n    default:\n        return nullptr;\n    }\n\n    GGML_UNUSED(src2);\n}\n\ntemplate <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);\n    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);\n\n    p.misalign_offsets = (a_offset << 16) | d_offset;\n\n    GGML_UNUSED(src1);\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n}\n\ntemplate <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    const uint32_t 
a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);\n    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);\n\n    p.misalign_offsets = (a_offset << 16) | d_offset;\n\n    GGML_UNUSED(src1);\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n}\n\ntemplate <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);\n    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);\n\n    p.misalign_offsets = (a_offset << 16) | d_offset;\n\n    GGML_UNUSED(src1);\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n}\n\ntemplate <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);\n    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);\n\n    p.misalign_offsets = (a_offset << 16) | d_offset;\n\n    GGML_UNUSED(src0);\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n}\n\ntemplate <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);\n    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);\n    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);\n\n    GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && 
d_offset == 0));\n\n    p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;\n\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n}\n\ntemplate <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {\n    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);\n    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);\n\n    p.a_offset = a_offset;\n    p.d_offset = d_offset;\n\n    GGML_UNUSED(src1);\n    GGML_UNUSED(src2);\n    GGML_UNUSED(src3);\n}\n\ntemplate<typename PC>\nstatic void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {\n    VK_LOG_DEBUG(\"ggml_vk_op_f32((\" << src0 << \", name=\" << src0->name << \", type=\" << src0->type << \", ne0=\" << src0->ne[0] << \", ne1=\" << src0->ne[1] << \", ne2=\" << src0->ne[2] << \", ne3=\" << src0->ne[3] << \", nb0=\" << src0->nb[0] << \", nb1=\" << src0->nb[1] << \", nb2=\" << src0->nb[2] << \", nb3=\" << src0->nb[3];\n    if (src1 != nullptr) {\n        std::cerr << \"), (\" << src1 << \", name=\" << src1->name << \", type=\" << src1->type << \", ne0=\" << src1->ne[0] << \", ne1=\" << src1->ne[1] << \", ne2=\" << src1->ne[2] << \", ne3=\" << src1->ne[3] << \", nb0=\" << src1->nb[0] << \", nb1=\" << src1->nb[1] << \", nb2=\" << src1->nb[2] << \", nb3=\" << src1->nb[3];\n    }\n    if (src2 != nullptr) {\n        std::cerr << \"), (\" << src2 << \", name=\" << src2->name << \", type=\" << src2->type << \", ne0=\" << src2->ne[0] << \", ne1=\" << src2->ne[1] << \", ne2=\" << src2->ne[2] << \", ne3=\" << src2->ne[3] << \", nb0=\" << src2->nb[0] << \", nb1=\" << src2->nb[1] << \", nb2=\" << src2->nb[2] << \", nb3=\" << 
src2->nb[3];\n    }\n    if (src3 != nullptr) {\n        std::cerr << \"), (\" << src3 << \", name=\" << src3->name << \", type=\" << src3->type << \", ne0=\" << src3->ne[0] << \", ne1=\" << src3->ne[1] << \", ne2=\" << src3->ne[2] << \", ne3=\" << src3->ne[3] << \", nb0=\" << src3->nb[0] << \", nb1=\" << src3->nb[1] << \", nb2=\" << src3->nb[2] << \", nb3=\" << src3->nb[3];\n    }\n    std::cerr << \"), (\" << dst << \", name=\" << dst->name << \", type=\" << dst->type << \", ne0=\" << dst->ne[0] << \", ne1=\" << dst->ne[1] << \", ne2=\" << dst->ne[2] << \", ne3=\" << dst->ne[3] << \", nb0=\" << dst->nb[0] << \", nb1=\" << dst->nb[1] << \", nb2=\" << dst->nb[2] << \", nb3=\" << dst->nb[3];\n    std::cerr << \"), \" << ggml_op_name(op) << \")\");\n    GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))));  // NOLINT\n    GGML_ASSERT(dst->buffer != nullptr);\n    const uint64_t ne00 = src0->ne[0];\n    const uint64_t ne01 = src0->ne[1];\n    const uint64_t ne02 = src0->ne[2];\n    const uint64_t ne03 = src0->ne[3];\n\n    const bool use_src1 = src1 != nullptr;\n    const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;\n    const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;\n    const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;\n    const uint64_t ne13 = use_src1 ? 
src1->ne[3] : 0;\n\n    const bool use_src2 = src2 != nullptr;\n    const bool use_src3 = src3 != nullptr;\n\n    init_pushconst_fastdiv(pc);\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);\n\n    if (pipeline == nullptr) {\n        std::cerr << \"ggml_vulkan: Error: Missing op: \" << ggml_op_name(op) << \" for \" << ggml_type_name(src0->type);\n        if (src1 != nullptr) {\n            std::cerr << \" and \" << ggml_type_name(src1->type);\n        }\n        std::cerr << \" to \" << ggml_type_name(dst->type) << std::endl;\n        GGML_ABORT(\"fatal error\");\n    }\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);\n    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};\n    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};\n    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);\n\n    // Compute misalignment offset for descriptors and store it in in push constants.\n    init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);\n\n    std::array<uint32_t, 3> elements;\n\n    switch (op) {\n    case GGML_OP_NORM:\n    case GGML_OP_RMS_NORM_BACK:\n    case GGML_OP_L2_NORM:\n    case GGML_OP_SOFT_MAX:\n    case GGML_OP_SOFT_MAX_BACK:\n    case GGML_OP_SUM_ROWS:\n    case GGML_OP_CUMSUM:\n    case GGML_OP_MEAN:\n    case GGML_OP_ARGMAX:\n        {\n            const uint32_t nr = ggml_nrows(src0);\n            if (nr > 262144) {\n                elements = { 512, 512, CEIL_DIV(nr, 262144) };\n            } else if (nr > 512) {\n                elements = { 512, CEIL_DIV(nr, 512), 1 };\n            } else {\n                elements = { nr, 1, 1 };\n            }\n        } break;\n    case GGML_OP_SOLVE_TRI:\n        {\n            
uint32_t nr = (uint32_t)(ne02 * ne03);\n            if (nr > 262144) {\n                elements = { 512, 512, CEIL_DIV(nr, 262144) };\n            } else if (nr > 512) {\n                elements = { 512, CEIL_DIV(nr, 512), 1 };\n            } else {\n                elements = { nr, 1, 1 };\n            }\n        }\n        break;\n    case GGML_OP_RMS_NORM:\n        if (ctx->do_add_rms_partials) {\n            // Run one element per thread, 128 threads per workgroup\n            elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };\n        } else {\n            elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };\n        }\n        break;\n\n    case GGML_OP_SUM:\n        // We use GGML_OP_SUM_ROWS with 1 row.\n        elements = { 1, 1, 1 };\n        break;\n    case GGML_OP_GROUP_NORM:\n        {\n            const uint32_t num_groups = dst->op_params[0];\n            elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };\n        } break;\n    case GGML_OP_DIAG_MASK_INF:\n        elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };\n        break;\n    case GGML_OP_ROPE:\n    case GGML_OP_ROPE_BACK:\n        {\n            uint32_t nrows = (uint32_t)ggml_nrows(src0);\n            uint32_t z = 1;\n            if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {\n                z = CEIL_DIV(nrows, 32768);\n                nrows = 32768;\n            }\n            elements = { nrows, (uint32_t)ne00, z };\n\n        } break;\n    case GGML_OP_GET_ROWS:\n        elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };\n        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);\n        elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);\n        break;\n    case GGML_OP_ARGSORT:\n        GGML_ASSERT(0);\n        break;\n    case GGML_OP_IM2COL:\n        {\n            const bool is_2D = dst->op_params[6] == 1;\n\n   
         const uint32_t IC = src1->ne[is_2D ? 2 : 1];\n\n            const uint32_t KH = is_2D ? src0->ne[1] : 1;\n            const uint32_t KW =         src0->ne[0];\n\n            const uint32_t OH = is_2D ? dst->ne[2] : 1;\n            const uint32_t OW =         dst->ne[1];\n\n            const uint32_t batch = src1->ne[is_2D ? 3 : 2];\n\n            elements = { OW * KW * KH, OH, batch * IC };\n            elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);\n            elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);\n        } break;\n    case GGML_OP_IM2COL_3D:\n        {\n            const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];\n\n            const uint32_t N  = ne13 / IC;\n\n            const uint32_t KD = ne02;\n            const uint32_t KH = ne01;\n            const uint32_t KW = ne00;\n\n            const uint32_t OD = dst->ne[3] / N;\n            const uint32_t OH = dst->ne[2];\n            const uint32_t OW = dst->ne[1];\n\n            const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;\n            const uint32_t N_OD_OH = N*OD*OH;\n\n            elements = { IC_KD_KH_KW, OW, N_OD_OH };\n            elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);\n        } break;\n    case GGML_OP_TIMESTEP_EMBEDDING:\n        {\n            const uint32_t dim = dst->op_params[0];\n            uint32_t half_ceil = (dim + 1) / 2;\n            elements = { half_ceil, (uint32_t)src0->ne[0], 1 };\n        } break;\n    case GGML_OP_CONV_TRANSPOSE_1D:\n        {\n            elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}\n        } break;\n    case GGML_OP_POOL_2D:\n        {\n            const uint32_t N = dst->ne[3];\n            const uint32_t OC = dst->ne[2];\n            const uint32_t OH = dst->ne[1];\n            const uint32_t OW = dst->ne[0];\n            elements = { N * OC * OH * OW, 1, 
1};\n        } break;\n    case GGML_OP_CONV_2D:\n    case GGML_OP_CONV_TRANSPOSE_2D:\n        if constexpr (std::is_same_v<PC, vk_op_conv2d_push_constants>) {\n            const uint32_t NPQ = pc.N * pc.OH * pc.OW;\n            const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.Cout, NPQ);\n            const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);\n\n            elements = { pc.Cout, NPQ_blocks, 1 };\n            if (elements[1] > 512) {\n                elements[2] = CEIL_DIV(elements[1], 512);\n                elements[1] = 512;\n            }\n        } else {\n            GGML_ABORT(\"invalid push constant type for CONV_2D\");\n        }\n        break;\n    case GGML_OP_ADD:\n    case GGML_OP_SUB:\n    case GGML_OP_DIV:\n    case GGML_OP_MUL:\n    case GGML_OP_ADD1:\n    case GGML_OP_ARANGE:\n    case GGML_OP_FILL:\n    case GGML_OP_SCALE:\n    case GGML_OP_SQR:\n    case GGML_OP_SQRT:\n    case GGML_OP_SIN:\n    case GGML_OP_COS:\n    case GGML_OP_LOG:\n    case GGML_OP_TRI:\n    case GGML_OP_DIAG:\n    case GGML_OP_CLAMP:\n    case GGML_OP_PAD:\n    case GGML_OP_ROLL:\n    case GGML_OP_REPEAT:\n    case GGML_OP_REPEAT_BACK:\n    case GGML_OP_CPY:\n    case GGML_OP_CONCAT:\n    case GGML_OP_UPSCALE:\n    case GGML_OP_UNARY:\n    case GGML_OP_GLU:\n    case GGML_OP_CONV_2D_DW:\n        {\n            uint32_t ne = ggml_nelements(dst);\n            if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {\n                // Convert from number of logical elements to 2- or 4-byte units.\n                ne /= ggml_blck_size(src0->type);\n                if ((ggml_type_size(src0->type) % 4) == 0) {\n                    ne *= ggml_type_size(src0->type) / 4;\n                } else {\n                    ne *= ggml_type_size(src0->type) / 2;\n                }\n            }\n            // copy_to_quant has block size of 32, and each thread does QUANT_K elements.\n            // Splitting 
into 512x512xZ wouldn't work well since each workgroup does 1024 elements.\n            // So divide by block size here before splitting into 512x512 groups.\n            if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {\n                ne = CEIL_DIV(ne, ggml_blck_size(dst->type));\n            }\n            if (ne > 262144) {\n                elements = { 512, 512, CEIL_DIV(ne, 262144) };\n            } else if (ne > 512) {\n                elements = { 512, CEIL_DIV(ne, 512), 1 };\n            } else {\n                elements = { ne, 1, 1 };\n            }\n\n            if (pipeline == ctx->device->pipeline_cpy_transpose_32 ||\n                pipeline == ctx->device->pipeline_cpy_transpose_16) {\n                // 32x32 tiles\n                elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32);\n                elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32);\n                elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]);\n                elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]);\n                elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);\n                elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);\n            }\n        } break;\n    case GGML_OP_ADD_ID:\n        {\n            elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };\n        } break;\n    case GGML_OP_SET_ROWS:\n        {\n            uint32_t ne = ggml_nelements(src0);\n            if (ggml_is_quantized(dst->type)) {\n                // quants run 32 threads each doing QUANT_K elements\n                ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));\n            } else {\n                // scalar types do one element per thread, running 512 threads\n                ne = CEIL_DIV(ne, 512);\n            }\n            if (ne > 262144) {\n                elements = { 512, 512, CEIL_DIV(ne, 
262144) };\n            } else if (ne > 512) {\n                elements = { 512, CEIL_DIV(ne, 512), 1 };\n            } else {\n                elements = { ne, 1, 1 };\n            }\n        }\n        break;\n    case GGML_OP_SSM_CONV:\n        {\n            const uint32_t nr  = src0->ne[1];\n            const uint32_t n_t = dst->ne[1];\n            const uint32_t n_s = dst->ne[2];\n            elements = { nr, n_t, n_s };\n        }\n        break;\n    default:\n        elements = { (uint32_t)ggml_nelements(src0), 1, 1 };\n        break;\n    }\n\n    if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {\n        vk_subbuffer a_buf = src0_buf;\n        if (ctx->do_add_rms_partials) {\n            a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);\n        }\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n            { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);\n    } else if (op == GGML_OP_GLU) {\n        // Empty src1 is possible in glu, but the shader needs a buffer\n        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);\n    } else if (op == GGML_OP_SOFT_MAX) {\n        // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer\n        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;\n        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);\n    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {\n        // Empty src2 and src3 is possible in rope, but the shader needs a buffer\n        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;\n        vk_subbuffer subbuf3 = use_src3 ? 
src3_buf : src0_buf;\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements);\n    } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {\n        if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {\n            // buffer device address path doesn't use dst buffer\n            dst_buf.size = 1;\n        }\n        // im2col uses only src1 and dst buffers\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);\n    } else if (op == GGML_OP_COUNT_EQUAL) {\n        // count_equal assumes that destination buffer is initialized with zeroes\n        ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size);\n        ggml_vk_sync_buffers(ctx, subctx);\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);\n    } else if (op == GGML_OP_OPT_STEP_SGD) {\n        // OPT_STEP_SGD works on src0, it does not need dst\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);\n    } else if (use_src3) {\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements);\n    } else if (use_src2) {\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);\n    } else if (use_src1) {\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);\n    } else {\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);\n    }\n}\n\nstatic void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t 
dst_type_size = ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, {\n        (uint32_t)ggml_nelements(src0),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, 0,\n    });\n}\n\nstatic void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32\n    int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32\n    int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32\n    int offset = dst->op_params[3] / src0_type_size; // offset in bytes\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {\n        (uint32_t)ggml_nelements(src0),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, 
(uint32_t)nb3,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,\n        0,\n        0.0f, 0.0f, offset,\n    });\n}\n\nstatic void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {\n    const ggml_tensor *first_node = cgraph->nodes[node_idx];\n    const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];\n\n    // Make a list of all the tensors used by the op.\n    // Last element of the list is the dest tensor.\n    const ggml_tensor *tensors[MAX_PARAMETER_COUNT];\n    uint32_t num_srcs = ctx->num_additional_fused_ops + 2;\n    uint32_t num_tensors = num_srcs + 1;\n    GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);\n\n    tensors[0] = first_node->src[0];\n    tensors[1] = first_node->src[1];\n    for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {\n        // check whether the previous result is src[0] or src[1]\n        if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {\n            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];\n        } else {\n            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];\n        }\n    }\n    tensors[num_srcs] = dst;\n\n    vk_op_multi_add_push_constants pc;\n    pc.ne20 = (uint32_t)dst->ne[0];\n    pc.ne21 = (uint32_t)dst->ne[1];\n    pc.ne22 = (uint32_t)dst->ne[2];\n    pc.ne23 = (uint32_t)dst->ne[3];\n\n    for (uint32_t i = 0; i < num_tensors; ++i) {\n        const ggml_tensor *t = tensors[i];\n        pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);\n        pc.nb[i][1] 
= (uint32_t)t->nb[1] / sizeof(float);\n        pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);\n        pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);\n    }\n    pc.rms_partials = ctx->do_add_rms_partials;\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);\n\n    if (pipeline == nullptr) {\n        std::cerr << \"ggml_vulkan: Error: Missing multi_add\";\n        GGML_ABORT(\"fatal error\");\n    }\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n    ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];\n    vk_buffer buf[MAX_PARAMETER_COUNT];\n    size_t offset[MAX_PARAMETER_COUNT];\n    bool uma[MAX_PARAMETER_COUNT];\n\n    for (uint32_t i = 0; i < num_tensors; ++i) {\n        buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;\n        buf[i] = nullptr;\n        offset[i] = 0;\n        uma[i] = false;\n\n        if (ctx->device->uma) {\n            ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);\n            uma[i] = buf[i] != nullptr;\n        }\n        if (!uma[i]) {\n            buf[i] = buf_ctx[i]->dev_buffer;\n            offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;\n        }\n        GGML_ASSERT(buf[i] != nullptr);\n    }\n    // If any remaining descriptors are unused, just point them at src[0]\n    for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {\n        buf[i] = buf[0];\n        offset[i] = 0;\n    }\n    if (ctx->do_add_rms_partials) {\n        buf[num_tensors] = ctx->prealloc_add_rms_partials;\n        offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;\n    }\n\n    std::array<uint32_t, 3> elements;\n\n    uint32_t ne = ggml_nelements(dst);\n    if (ne > 262144) {\n        elements = { 512, 512, CEIL_DIV(ne, 262144) };\n    } else if (ne > 512) {\n        elements = { 512, CEIL_DIV(ne, 512), 1 };\n    } else {\n        elements = { ne, 1, 1 };\n    }\n\n    
static_assert(MAX_PARAMETER_COUNT == 12);\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n        {\n            ggml_vk_subbuffer(ctx, buf[0], offset[0]),\n            ggml_vk_subbuffer(ctx, buf[1], offset[1]),\n            ggml_vk_subbuffer(ctx, buf[2], offset[2]),\n            ggml_vk_subbuffer(ctx, buf[3], offset[3]),\n            ggml_vk_subbuffer(ctx, buf[4], offset[4]),\n            ggml_vk_subbuffer(ctx, buf[5], offset[5]),\n            ggml_vk_subbuffer(ctx, buf[6], offset[6]),\n            ggml_vk_subbuffer(ctx, buf[7], offset[7]),\n            ggml_vk_subbuffer(ctx, buf[8], offset[8]),\n            ggml_vk_subbuffer(ctx, buf[9], offset[9]),\n            ggml_vk_subbuffer(ctx, buf[10], offset[10]),\n            ggml_vk_subbuffer(ctx, buf[11], offset[11]),\n        }, pc, elements);\n}\n\nstatic void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, {\n        (uint32_t)ggml_nelements(src0),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  
dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, ctx->do_add_rms_partials,\n    });\n}\n\nstatic void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, {\n        (uint32_t)ggml_nelements(src0),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, 0,\n    });\n}\n\nstatic void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, {\n        (uint32_t)ggml_nelements(src0),\n       
 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, 0,\n    });\n}\n\nstatic void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, {\n        (uint32_t)ggml_nelements(src0),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) 
dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, 0,\n    });\n}\n\nstatic void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t src2_type_size = ggml_type_size(src2->type);\n\n    ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, {\n        (uint32_t)dst->ne[0],\n        (uint32_t)dst->ne[1],\n        (uint32_t)src0->nb[1] / src0_type_size,\n        (uint32_t)src0->nb[2] / src0_type_size,\n        (uint32_t)src1->nb[1] / src1_type_size,\n        (uint32_t)src2->nb[1] / src2_type_size,\n    });\n}\n\nstatic void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) {\n    GGML_ASSERT(version == 6 || version == 7);\n    int num_srcs = version == 6 ? 
6 : 7;\n\n    for (int i = 0; i < num_srcs; i++) {\n        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));\n    }\n\n    GGML_ASSERT(dst->buffer != nullptr);\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);\n    GGML_ASSERT(pipeline != nullptr);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);\n    vk_subbuffer src_buf[7] = {};\n    for (int i = 0; i < num_srcs; i++) {\n        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);\n    }\n\n    std::array<uint32_t, 3> elements = {\n        (uint32_t)(pc.B * pc.H),\n        1,\n        1\n    };\n\n    if (version == 6) {\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},\n            pc, elements);\n    } else if (version == 7) {\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},\n            pc, elements);\n    } else {\n        // shouldn't happen\n        GGML_ASSERT(false);\n    }\n}\n\nstatic void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {\n    const size_t seq_length = dst->src[0]->ne[2];\n    const size_t n_embed = dst->ne[0];\n    const size_t n_heads = dst->src[0]->ne[1];\n    const size_t n_seqs = dst->src[5]->ne[1];\n\n    ggml_vk_op_f32_wkv(\n        ctx, subctx, dst,\n        {\n            (uint32_t)n_seqs,\n            (uint32_t)seq_length,\n            (uint32_t)n_embed,\n            (uint32_t)n_heads,\n        },\n        6\n    );\n}\n\nstatic void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {\n    const size_t seq_length = dst->src[0]->ne[2];\n    const size_t n_embed = dst->ne[0];\n    const size_t n_heads = dst->src[0]->ne[1];\n    const 
size_t n_seqs = dst->src[6]->ne[1];\n\n    ggml_vk_op_f32_wkv(\n        ctx, subctx, dst,\n        {\n            (uint32_t)n_seqs,\n            (uint32_t)seq_length,\n            (uint32_t)n_embed,\n            (uint32_t)n_heads,\n        },\n        7\n    );\n}\n\nstatic void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {\n    const ggml_tensor * src_q     = dst->src[0];\n    const ggml_tensor * src_v     = dst->src[2];\n    const ggml_tensor * src_beta  = dst->src[4];\n\n    GGML_ASSERT(dst->buffer != nullptr);\n\n    const uint32_t S_v      = (uint32_t)src_v->ne[0];\n    const uint32_t H        = (uint32_t)src_v->ne[1];\n    const uint32_t n_tokens = (uint32_t)src_v->ne[2];\n    const uint32_t n_seqs   = (uint32_t)src_v->ne[3];\n\n    const uint32_t s_off = S_v * H * n_tokens * n_seqs;\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);\n    GGML_ASSERT(pipeline != nullptr);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);\n    vk_subbuffer src_buf[6] = {};\n    for (int i = 0; i < 6; i++) {\n        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);\n    }\n\n    const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));\n    const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));\n    const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));\n    const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));\n    const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));\n    const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));\n    const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));\n    const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));\n    const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));\n\n    const uint32_t neq1 = (uint32_t)src_q->ne[1];\n    const uint32_t rq3  = 
(uint32_t)(src_v->ne[3] / src_q->ne[3]);\n\n    const float scale = 1.0f / sqrtf((float)S_v);\n    const vk_op_gated_delta_net_push_constants pc = {\n        H, n_tokens, n_seqs, s_off,\n        sq1, sq2, sq3,\n        sv1, sv2, sv3,\n        sb1, sb2, sb3,\n        neq1, rq3,\n        scale\n    };\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},\n        pc, { H, n_seqs, 1u });\n}\n\nstatic void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n    const ggml_tensor * src3 = dst->src[3];\n    const ggml_tensor * src4 = dst->src[4];\n    const ggml_tensor * src5 = dst->src[5];\n\n    GGML_ASSERT(dst->buffer != nullptr);\n\n    const uint32_t head_dim = src0->ne[1];\n    const uint32_t n_head = src1->ne[1];\n    const uint32_t n_group = src4->ne[1];\n    const uint32_t n_tok = src1->ne[2];\n    const uint32_t n_seq = src1->ne[3];\n\n    bool is_mamba2 = (src3->nb[1] == sizeof(float));\n    GGML_ASSERT(is_mamba2);\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);\n    GGML_ASSERT(pipeline != nullptr);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n    const int64_t s_off = ggml_nelements(src1) * sizeof(float);\n\n    const vk_op_ssm_scan_push_constants pc = {\n        (uint32_t)src0->nb[2], (uint32_t)src0->nb[3],\n        (uint32_t)src1->nb[2], (uint32_t)src1->nb[3],\n        (uint32_t)src2->nb[1], (uint32_t)src2->nb[2],\n        (uint32_t)src3->nb[1],\n        (uint32_t)src4->nb[2], (uint32_t)src4->nb[3],\n        (uint32_t)src5->nb[2], (uint32_t)src5->nb[3],\n        (uint32_t)s_off,\n        n_head, head_dim, n_group, n_tok\n    };\n\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);\n    vk_subbuffer src_buf[7] = {};\n  
  for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {\n        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);\n    }\n\n    std::array<uint32_t, 3> elements;\n\n    const uint32_t d_state = src0->ne[0];\n    uint32_t num_subgroups = d_state / ctx->device->subgroup_size;\n    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);\n    const uint32_t num_workgroups_y = n_seq;\n    elements = { num_workgroups_x, num_workgroups_y, 1 };\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},\n        pc, elements);\n}\n\nstatic void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n\n    ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {\n        (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],\n        (uint32_t)src1->nb[1],\n        (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],\n        (uint32_t)src1->ne[0],\n        (uint32_t)src0->ne[0],\n        (uint32_t)src0->ne[1],\n        (uint32_t)dst->ne[1],\n        (uint32_t)dst->ne[2],\n    });\n}\n\nstatic void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc) {\n    const ggml_tensor * x = dst->src[0];\n    const ggml_tensor * g = dst->src[1];\n    const ggml_tensor * gm = dst->src[2];\n    const ggml_tensor * gv = dst->src[3];\n    const ggml_tensor * p = dst->src[4];\n\n    GGML_ASSERT(x->type == GGML_TYPE_F32);\n    GGML_ASSERT(g->type == GGML_TYPE_F32);\n    GGML_ASSERT(gm->type == GGML_TYPE_F32);\n    GGML_ASSERT(gv->type == GGML_TYPE_F32);\n    GGML_ASSERT(p->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->buffer != nullptr);\n    GGML_ASSERT(ggml_is_contiguous(x));\n    
GGML_ASSERT(ggml_is_contiguous(g));\n    GGML_ASSERT(ggml_is_contiguous(gm));\n    GGML_ASSERT(ggml_is_contiguous(gv));\n    GGML_ASSERT(ggml_is_contiguous(p));\n    GGML_ASSERT(ggml_are_same_shape(x, g));\n    GGML_ASSERT(ggml_are_same_shape(x, gm));\n    GGML_ASSERT(ggml_are_same_shape(x, gv));\n    GGML_ASSERT(ggml_nelements(p) == 7);\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);\n    GGML_ASSERT(pipeline != nullptr);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n    vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);\n    vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, g);\n    vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, gm);\n    vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, gv);\n    vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, p);\n\n    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n        {x_buf, g_buf, gm_buf, gv_buf, p_buf},\n        pc, elements);\n}\n\nstatic void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {\n    const size_t n = ggml_nelements(dst->src[0]);\n\n    ggml_vk_op_f32_opt_step_adamw(\n        ctx, subctx, dst,\n        { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }\n    );\n}\n\nstatic void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {\n    const size_t n = ggml_nelements(dst->src[0]);\n\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });\n}\n\nstatic void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    int * op_params = (int *)dst->op_params;\n\n    const uint32_t src0_type_size = 
ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, {\n        (uint32_t)ggml_nelements(dst),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, op_params[0],\n    });\n}\n\nstatic void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);\n\n    GGML_TENSOR_UNARY_OP_LOCALS\n\n    float sf0 = (float)ne0 / ne00;\n    float sf1 = (float)ne1 / ne01;\n    float sf2 = (float)ne2 / ne02;\n    float sf3 = (float)ne3 / ne03;\n    float pixel_offset = 0.5f;\n\n    if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {\n        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;\n        sf1 = ne1 > 1 && ne01 > 1 ? 
(float)(ne1 - 1) / (ne01 - 1) : sf1;\n        pixel_offset = 0.0f;\n    }\n\n    ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, {\n        (uint32_t)ggml_nelements(dst), 0, 0,\n        (uint32_t)ne00, (uint32_t)ne01,\n        (uint32_t)nb00 / src0_type_size, (uint32_t)nb01 / src0_type_size, (uint32_t)nb02 / src0_type_size, (uint32_t)nb03 / src0_type_size,\n        (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,\n        sf0, sf1, sf2, sf3, pixel_offset\n    });\n}\n\nstatic void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);\n    p.param1 = ggml_get_op_params_f32(dst, 0);\n    p.param2 = ggml_get_op_params_f32(dst, 1);\n\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p));\n}\n\nstatic void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst));\n}\n\nstatic void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));\n}\n\nstatic void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {\n        (uint32_t)ggml_nelements(src0),\n        
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, 0,\n    });\n}\n\nstatic void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {\n    VK_LOG_DEBUG(\"ggml_vk_arange(dst=\" << dst << \", ne=\" << ggml_nelements(dst) << \")\");\n\n    vk_op_push_constants pc = {\n        (uint32_t)ggml_nelements(dst),\n        1,\n        ggml_get_op_params_f32(dst, 0),\n        ggml_get_op_params_f32(dst, 2),\n        0.0f, 0.0f,\n    };\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);\n    GGML_ASSERT(pipeline != nullptr);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);\n\n    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);\n}\n\nstatic void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {\n    VK_LOG_DEBUG(\"ggml_vk_fill(dst=\" << dst << \", ne=\" << ggml_nelements(dst) << \")\");\n\n    vk_op_push_constants pc = {\n        (uint32_t)ggml_nelements(dst),\n        1,\n        ggml_get_op_params_f32(dst, 0),\n        
0.0f,\n        0.0f, 0.0f,\n    };\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);\n    GGML_ASSERT(pipeline != nullptr);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);\n\n    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);\n}\n\nstatic void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));\n}\n\nstatic void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));\n}\n\nstatic void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));\n}\n\nstatic void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);\n    p.param1 = ggml_get_op_params_f32(dst, 0);\n\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));\n}\n\nstatic void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));\n\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));\n}\n\nstatic void ggml_vk_clamp(ggml_backend_vk_context * 
ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);\n    p.param1 = ggml_get_op_params_f32(dst, 0);\n    p.param2 = ggml_get_op_params_f32(dst, 1);\n\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p));\n}\n\nstatic void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p));\n}\n\nstatic void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    const int32_t s0 = ggml_get_op_params_i32(dst, 0);\n    const int32_t s1 = ggml_get_op_params_i32(dst, 1);\n    const int32_t s2 = ggml_get_op_params_i32(dst, 2);\n    const int32_t s3 = ggml_get_op_params_i32(dst, 3);\n    const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);\n    const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);\n\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);\n    memcpy(&p.param1, &s01_packed, sizeof(float));\n    memcpy(&p.param2, &s23_packed, sizeof(float));\n\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p));\n}\n\nstatic void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p));\n}\n\nstatic void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, 
ggml_nelements(dst));\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p));\n}\n\nstatic void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    uint32_t ne = (uint32_t)ggml_nelements(src0);\n    if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {\n        // Convert from number of logical elements to 2- or 4-byte units.\n        ne /= ggml_blck_size(src0->type);\n        if ((ggml_type_size(src0->type) % 4) == 0) {\n            ne *= ggml_type_size(src0->type) / 4;\n        } else {\n            ne *= ggml_type_size(src0->type) / 2;\n        }\n    }\n\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p));\n}\n\nstatic void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    // Skip empty set_rows operations. 
For most ops the empty check at the start\n    // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst\n    // with empty srcs.\n    if (ggml_is_empty(src0) || ggml_is_empty(src1)) {\n        return;\n    }\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, {\n        (uint32_t)ggml_nelements(src0),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, 0,\n    });\n}\n\nstatic void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });\n}\n\nstatic void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    float * op_params = (float *)dst->op_params;\n\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });\n}\n\nstatic void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& 
subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    const int * int_op_params = (const int *)dst->op_params;\n    const float * float_op_params = (const float *)dst->op_params;\n\n    const uint32_t num_groups = int_op_params[0];\n    const float eps = float_op_params[1];\n    const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);\n\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });\n}\n\nstatic uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {\n    const uint32_t ne = (uint32_t)node->ne[0];\n    const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];\n    const uint32_t num_partials = CEIL_DIV(ne, denom);\n    return num_partials;\n}\n\nstatic uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {\n    const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);\n    const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);\n    return num_bytes;\n}\n\nstatic vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) {\n    const int n_dims        = ((const int32_t *) dst->op_params)[1];\n    const int mode          = ((const int32_t *) dst->op_params)[2];\n    // const int n_ctx         = ((const int32_t *) dst->op_params)[3];\n    const int n_ctx_orig    = ((const int32_t *) dst->op_params)[4];\n    const float freq_base   = ((const float *)   dst->op_params)[5];\n    const float freq_scale  = ((const float *)   dst->op_params)[6];\n    const float ext_factor  = ((const float *)   dst->op_params)[7];\n    const float attn_factor = ((const float *)   dst->op_params)[8];\n    const float beta_fast   = ((const float *)   
dst->op_params)[9];\n    const float beta_slow   = ((const float *)   dst->op_params)[10];\n    int sections[4] {};\n    if (mode & GGML_ROPE_TYPE_MROPE) {\n        memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4);\n    }\n\n    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;\n\n    float corr_dims[2];\n    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);\n\n    const float theta_scale = powf(freq_base, -2.0f/n_dims);\n\n    uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);\n    uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);\n    uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);\n\n    uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);\n    uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);\n    uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);\n\n    vk_op_rope_push_constants rope {\n        (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,\n        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,\n        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,\n\n        (uint32_t)src0->ne[0],\n        (uint32_t)src0->ne[1],\n        (uint32_t)src0->ne[2],\n        nb01, nb02, nb03,\n        nb11, nb12, nb13,\n    };\n\n    return rope;\n}\n\nstatic void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) {\n    ggml_tensor * dst;\n    const ggml_tensor * src0;\n    const ggml_tensor * src1;\n\n    if (ctx->num_additional_fused_ops > 0) {\n        // fused rms_norm + mul\n        ggml_tensor *mul = cgraph->nodes[node_idx + 1];\n        ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? 
mul->src[1] : mul->src[0];\n        dst = mul;\n        src0 = cgraph->nodes[node_idx]->src[0];\n        src1 = other_src;\n    } else {\n        dst = cgraph->nodes[node_idx];\n        src0 = src1 = dst->src[0];\n    }\n\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;\n\n    vk_op_binary_push_constants bin {\n        (uint32_t)ggml_nelements(src0),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        op_params[0], 0.0f, (int32_t)param3,\n    };\n\n    // more than one fused op means rms_norm+mul+rope\n    if (ctx->num_additional_fused_ops > 1) {\n        static constexpr uint32_t max_tensors = 7;\n        const ggml_tensor *tensors[max_tensors] {};\n\n        ggml_tensor *rms = cgraph->nodes[node_idx + 0];\n        ggml_tensor *mul = cgraph->nodes[node_idx + 1];\n        ggml_tensor *rope = cgraph->nodes[node_idx + 2];\n\n        ggml_tensor *other_src = mul->src[0] == rms ? 
mul->src[1] : mul->src[0];\n\n        bool do_set_rows = ctx->num_additional_fused_ops == 4;\n\n        tensors[0] = rms->src[0];\n        tensors[1] = other_src;\n        tensors[2] = mul;\n        tensors[3] = rope->src[1]; // pos\n        tensors[4] = rope->src[2]; // ff\n        tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst\n        tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr;\n        const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0;\n\n        vk_op_rms_norm_mul_rope_push_constants pc;\n        pc.bin = bin;\n        pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride);\n\n        vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;\n\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n        ggml_backend_vk_buffer_context * buf_ctx[max_tensors];\n        vk_buffer buf[max_tensors];\n        size_t offset[max_tensors];\n        bool uma[max_tensors];\n\n        for (uint32_t i = 0; i < max_tensors; ++i) {\n            if (!tensors[i]) {\n                // If any remaining descriptors are unused, just point them at src[0]\n                buf[i] = buf[0];\n                offset[i] = 0;\n                continue;\n            }\n            buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;\n            buf[i] = nullptr;\n            offset[i] = 0;\n            uma[i] = false;\n\n            if (ctx->device->uma) {\n                ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);\n                uma[i] = buf[i] != nullptr;\n            }\n            if (!uma[i]) {\n                buf[i] = buf_ctx[i]->dev_buffer;\n                offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;\n            }\n            GGML_ASSERT(buf[i] != 
nullptr);\n        }\n\n        std::array<uint32_t, 3> elements;\n        elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };\n\n        static_assert(max_tensors == 7);\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,\n            {\n                ggml_vk_subbuffer(ctx, buf[0], offset[0]),\n                ggml_vk_subbuffer(ctx, buf[1], offset[1]),\n                ggml_vk_subbuffer(ctx, buf[2], offset[2]),\n                ggml_vk_subbuffer(ctx, buf[3], offset[3]),\n                ggml_vk_subbuffer(ctx, buf[4], offset[4]),\n                ggml_vk_subbuffer(ctx, buf[5], offset[5]),\n                ggml_vk_subbuffer(ctx, buf[6], offset[6]),\n            }, pc, elements);\n    } else {\n        ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin));\n    }\n\n    if (ctx->do_add_rms_partials_offset_calculation) {\n        ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);\n        ctx->do_add_rms_partials = false;\n        ctx->do_add_rms_partials_offset_calculation = false;\n    }\n}\n\nstatic void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    float * op_params = (float *)dst->op_params;\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });\n}\n\nstatic void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    const float * op_params = (const float *)dst->op_params;\n    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);\n    p.param1 = op_params[0];\n    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, 
std::move(p));\n}\n\nstatic void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });\n}\n\nstatic void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    float * op_params = (float *)dst->op_params;\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,\n        {\n            (uint32_t)ggml_nelements(src0), 0,\n            op_params[1], op_params[2], op_params[3], op_params[4]\n        }\n    );\n}\n\nstatic void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const float * op_params_f = (const float *)dst->op_params;\n\n    const bool swapped = (bool)dst->op_params[1];\n    const bool split = src1 != nullptr;\n    const float alpha = op_params_f[2];\n    const float limit = op_params_f[3];\n\n    GGML_ASSERT(ggml_is_contiguous(src0));\n\n    if (!split) {\n        GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);\n    } else {\n        GGML_ASSERT(src0->ne[0] == src1->ne[0]);\n        GGML_ASSERT(src0->ne[0] == dst->ne[0]);\n        GGML_ASSERT(src0->type == src1->type);\n    }\n\n    const uint32_t mode = split ? 2 : (swapped ? 
1 : 0);\n\n    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU,\n        {\n            (uint32_t)ggml_nelements(dst),\n            (uint32_t)src0->ne[0],\n            (uint32_t)dst->ne[0],\n            mode,\n            alpha,\n            limit\n        });\n}\n\nstatic void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    int32_t * op_params = (int32_t *)dst->op_params;\n    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });\n}\n\nstatic void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {\n    float * op_params = (float *)dst->op_params;\n\n    float scale = op_params[0];\n    float max_bias = op_params[1];\n\n    const uint32_t ncols =   (uint32_t)src0->ne[0];\n    const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);\n    const uint32_t nrows_y = (uint32_t)src0->ne[1];\n\n    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;\n    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;\n    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;\n    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;\n    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;\n\n    const uint32_t n_head_kv   = src0->ne[2];\n    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));\n\n    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);\n    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    vk_op_soft_max_push_constants pc {\n        ncols,\n        src1 != nullptr ? 
nrows_y : (uint32_t)0,\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],\n        ne12, ne13,\n        nb11, nb12, nb13,\n        scale, max_bias,\n        m0, m1,\n        n_head_log2,\n        nrows_x,\n        src2 != nullptr\n    };\n\n    if (ncols <= 16384) {\n        ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));\n    } else {\n\n        vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);\n        vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;\n        vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;\n        vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);\n\n        uint32_t elems_per_wg = 128 * 4;\n        uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);\n        size_t tmp_size = num_wgs * nrows_x * sizeof(float);\n\n        if (ctx->prealloc_size_x < tmp_size) {\n            ctx->prealloc_size_x = tmp_size;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if (ctx->prealloc_size_y < tmp_size) {\n            ctx->prealloc_size_y = tmp_size;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n\n        vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };\n        vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };\n\n        std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };\n\n        vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;\n        vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;\n        vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? 
ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;\n\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);\n\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);\n        ggml_vk_sync_buffers(ctx, subctx);\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);\n        ggml_vk_sync_buffers(ctx, subctx);\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);\n\n        ctx->prealloc_x_need_sync = true;\n        ctx->prealloc_y_need_sync = true;\n    }\n}\n\nstatic void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    float * op_params = (float *)dst->op_params;\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });\n}\n\nstatic void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {\n    topk_moe_mode mode = ctx->fused_topk_moe_mode;\n    ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];\n    ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;\n    ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];\n    ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :\n                        (mode == TOPK_MOE_LATE_SOFTMAX) ?      
cgraph->nodes[node_idx + 1] :\n                                                               cgraph->nodes[node_idx + 3];\n\n    GGML_ASSERT(logits->type == GGML_TYPE_F32);\n    GGML_ASSERT(bias->type == GGML_TYPE_F32);\n    GGML_ASSERT(weights->type == GGML_TYPE_F32);\n    GGML_ASSERT(ids->type == GGML_TYPE_I32);\n\n    const int n_experts = logits->ne[0];\n    const int n_rows    = logits->ne[1];\n    const int n_expert_used = weights->ne[1];\n\n    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);\n\n    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n\n    vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);\n    vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);\n    vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);\n    vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);\n\n    vk_op_topk_moe_push_constants pc {};\n    pc.n_rows = n_rows;\n    pc.n_experts_push = n_experts;\n    pc.n_expert_used = n_expert_used;\n    pc.clamp_min = -std::numeric_limits<float>::infinity();\n    pc.clamp_max = std::numeric_limits<float>::infinity();\n    if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {\n        ggml_tensor * clamp = cgraph->nodes[node_idx + 7];\n        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);\n        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);\n        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);\n    }\n    if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {\n        ggml_tensor * clamp = cgraph->nodes[node_idx + 8];\n        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);\n        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);\n        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);\n    }\n\n#define GATING_FUNC_SOFTMAX 0\n#define GATING_FUNC_SIGMOID 1\n#define GATING_FUNC_SOFTMAX_WEIGHT 2\n\n    pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? 
GATING_FUNC_SIGMOID :\n                     mode == TOPK_MOE_LATE_SOFTMAX ?      GATING_FUNC_SOFTMAX_WEIGHT :\n                                                          GATING_FUNC_SOFTMAX;\n    pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;\n    pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;\n    if (ctx->fused_topk_moe_scale) {\n        GGML_ASSERT(weights->op == GGML_OP_SCALE);\n        pc.output_scale = ggml_get_op_params_f32(weights, 0);\n        pc.output_bias = ggml_get_op_params_f32(weights, 1);\n    } else {\n        pc.output_scale = 1.0f;\n        pc.output_bias = 0.0f;\n    }\n\n    GGML_ASSERT(n_expert_used <= n_experts);\n\n    const uint32_t rows_per_block = 4;\n    std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);\n}\n\nstatic void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {\n    ggml_tensor * dst = cgraph->nodes[node_idx];\n    const ggml_tensor * src0 = dst->src[0];\n    const ggml_tensor * src1 = dst->src[1];\n    const ggml_tensor * src2 = dst->src[2];\n    const ggml_tensor * src3 = nullptr;\n    const int n_dims        = ((int32_t *) dst->op_params)[1];\n    const int mode          = ((int32_t *) dst->op_params)[2];\n    // const int n_ctx         = ((int32_t *) dst->op_params)[3];\n    const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];\n    const float freq_base   = ((float *)   dst->op_params)[5];\n    const float beta_fast   = ((float *)   dst->op_params)[9];\n    const float beta_slow   = ((float *)   dst->op_params)[10];\n    int sections[4] {};\n    if (mode & GGML_ROPE_TYPE_MROPE) {\n        memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);\n    }\n\n    float corr_dims[2];\n    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, 
beta_slow, corr_dims);\n\n    uint32_t set_rows_stride = 0;\n    // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride\n    // and overrides the dst and sets src3=row_indices\n    if (ctx->num_additional_fused_ops > 0) {\n        set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type);\n        src3 = cgraph->nodes[node_idx + 2]->src[1];\n        dst = cgraph->nodes[node_idx + 2];\n    }\n\n    ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE,\n        ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));\n}\n\nstatic void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    const uint32_t * op_params = (const uint32_t *)dst->op_params;\n\n    uint32_t ncols = src0->ne[0];\n    uint32_t nrows = ggml_nrows(src0);\n\n    uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));\n    uint32_t ncolsp2 = 1 << ncols_pad_log2;\n\n    vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };\n\n    // Pick the largest workgroup size <= ncolsp2\n    uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);\n\n    // Use the \"small\" argsort shader if the whole sort can be done by a single workgroup.\n    bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&\n                     ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;\n\n    vk_pipeline pipeline = use_small ? 
ctx->device->pipeline_argsort_f32[pipeline_idx]\n                                     : ctx->device->pipeline_argsort_large_f32[pipeline_idx];\n\n    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);\n    vk_subbuffer subbuf1 = dst_buf;\n\n    // Reserve space for ivec2 per element, with rows padded to a power of two\n    if (!use_small) {\n        const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);\n\n        if (ctx->prealloc_size_x < x_sz) {\n            ctx->prealloc_size_x = x_sz;\n            ggml_vk_preallocate_buffers(ctx, subctx);\n        }\n        if (ctx->prealloc_x_need_sync) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n        subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };\n    }\n\n    std::array<uint32_t, 3> elements;\n\n    elements[0] = ncolsp2;\n    elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);\n    elements[2] = 1;\n\n    // First dispatch initializes tmp_idx and does the first N passes where\n    // there is only communication between threads in the same workgroup.\n    {\n        vk_op_argsort_push_constants pc2 = pc;\n        pc2.outer_start = 0;\n        pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);\n        pc2.inner_start = 0;\n        pc2.inner_end = 100;\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);\n    }\n    if (!use_small) {\n        ggml_vk_sync_buffers(ctx, subctx);\n        // Loop over outer/inner passes, synchronizing between each pass.\n        for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {\n            for (uint32_t inner = 0; inner < outer + 1; ++inner) {\n                vk_op_argsort_push_constants pc2 = pc;\n                pc2.outer_start = 
outer;\n                pc2.outer_end = outer + 1;\n                pc2.inner_start = inner;\n                pc2.inner_end = inner + 1;\n                // When the inner idx is large enough, there's only communication\n                // within a workgroup. So the remaining inner iterations can all\n                // run in the same dispatch.\n                if (outer - inner < pipeline_idx) {\n                    pc2.inner_end = 100;\n                    inner = outer;\n                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];\n                } else {\n                    // Smaller workgroup empirically seems to perform better\n                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];\n                }\n                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n                ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);\n                ggml_vk_sync_buffers(ctx, subctx);\n            }\n        }\n        ctx->prealloc_x_need_sync = true;\n    }\n}\n\nstatic void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    uint32_t ncols = src0->ne[0];\n    uint32_t nrows = ggml_nrows(src0);\n    uint32_t k = dst->ne[0];\n\n    vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };\n\n    if (ctx->prealloc_x_need_sync) {\n        ggml_vk_sync_buffers(ctx, subctx);\n    }\n\n    std::array<uint32_t, 3> elements;\n    elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);\n    elements[2] = 1;\n\n    uint32_t num_elements = ncols;\n\n    // Each iteration reduces a workgroup's worth of elements down to the K\n    // largest elements. 
Repeat until we have the top K elements.\n    // Need to do at least one iteration to write out the results.\n    bool done_one_iter = false;\n    uint32_t dbl_buf_index = 0;\n    size_t dbl_buf_size;\n    while (num_elements > k || !done_one_iter) {\n\n        // Prefer going as small as num_topk_pipelines - 3 for perf reasons.\n        // But if K is larger, then we need a larger workgroup\n        uint32_t max_pipeline = num_topk_pipelines - 1;\n        uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);\n        max_pipeline = std::min(preferred_pipeline, max_pipeline);\n        uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;\n        // require full subgroup\n        min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);\n\n        uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));\n        pipeline_idx = std::min(pipeline_idx, max_pipeline);\n        pipeline_idx = std::max(pipeline_idx, min_pipeline);\n\n        if (num_elements > (1u << pipeline_idx)) {\n            // If we could finish on this loop iteration (i.e. a single workgroup)\n            // then do so. 
It's better than the overhead of another pass.\n            for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {\n                if (num_elements <= (1u << i)) {\n                    pipeline_idx = i;\n                    break;\n                }\n            }\n        }\n\n        vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];\n        // If the device doesn't support a pipeline this large, use smaller\n        while (!pipeline) {\n            pipeline_idx--;\n            GGML_ASSERT(pipeline_idx >= min_pipeline);\n            pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];\n        }\n\n        vk_op_topk_push_constants pc2 = pc;\n        pc2.ncols_input = num_elements;\n\n        // Number of elements remaining after this pass\n        uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);\n\n        pc2.ncols_output = num_dst_elements;\n\n        if (!done_one_iter) {\n            // Reserve space for ivec2 per element, double buffered\n            // K per workgroup per row\n            dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);\n            dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);\n            const size_t x_sz = dbl_buf_size * 2;\n\n            if (ctx->prealloc_size_x < x_sz) {\n                ctx->prealloc_size_x = x_sz;\n                ggml_vk_preallocate_buffers(ctx, subctx);\n            }\n        }\n\n        vk_subbuffer src_buf;\n        vk_subbuffer dst_buf;\n\n        if (num_elements == ncols) {\n            pc2.first_pass = 1;\n            src_buf = ggml_vk_tensor_subbuffer(ctx, src0);\n        } else {\n            src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };\n        }\n        if (num_dst_elements == k) {\n            pc2.last_pass = 1;\n            dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);\n        } else {\n     
       dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };\n        }\n\n        elements[0] = num_elements;\n\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);\n        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);\n        num_elements = num_dst_elements;\n        dbl_buf_index ^= 1;\n        if (num_elements > k) {\n            ggml_vk_sync_buffers(ctx, subctx);\n        }\n        done_one_iter = true;\n    }\n    ctx->prealloc_x_need_sync = true;\n}\n\nstatic void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);\n}\n\nstatic void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p);\n}\n\nstatic void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);\n    p.weight = 1.0f / (float)src0->ne[0];\n    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);\n}\n\nstatic void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);\n    // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.\n    // For fewer, larger rows, use the multipass shader to spread each row across SMs.\n    if (dst->ne[0] <= 4096 
|| ggml_nrows(dst) >= ctx->device->shader_core_count) {\n        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);\n        return;\n    }\n\n    // First pass computes partial sums within a block, and stores the last partial\n    // to the temp buffer. Second pass sums the block partials from the temp buffer\n    // and adds that to the result of the first pass.\n    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;\n    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;\n    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);\n\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);\n    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);\n\n    std::array<uint32_t, 3> elements;\n\n    elements[0] = dst->ne[0];\n    elements[1] = (uint32_t)ggml_nrows(dst);\n    elements[2] = 1;\n\n    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);\n\n    if (ctx->prealloc_size_split_k < temp_size) {\n        ctx->prealloc_size_split_k = temp_size;\n        ggml_vk_preallocate_buffers(ctx, subctx);\n    }\n\n    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);\n    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);\n    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);\n\n    if (ctx->prealloc_split_k_need_sync) {\n        ggml_vk_sync_buffers(ctx, subctx);\n    }\n\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);\n    ggml_vk_sync_buffers(ctx, subctx);\n    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);\n\n    ctx->prealloc_split_k_need_sync = true;\n}\n\nstatic void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 
(uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });\n}\n\nstatic void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });\n}\n\nstatic void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const uint32_t src0_type_size = ggml_type_size(src0->type);\n    const uint32_t src1_type_size = ggml_type_size(src1->type);\n    const uint32_t dst_type_size = ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {\n        (uint32_t)ggml_nelements(src0),\n        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,\n        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,\n        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,\n        0,\n        0.0f, 0.0f, 0,\n    });\n}\n\nstatic void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    const int32_t s0 = dst->op_params[0];\n    const int32_t s1 = dst->op_params[1];\n    const int32_t p0 = 
dst->op_params[2];\n    const int32_t p1 = dst->op_params[3];\n    const int32_t d0 = dst->op_params[4];\n    const int32_t d1 = dst->op_params[5];\n\n    const bool is_2D = dst->op_params[6] == 1;\n\n    const uint32_t IC = src1->ne[is_2D ? 2 : 1];\n    const uint32_t IH = is_2D ? src1->ne[1] : 1;\n    const uint32_t IW =         src1->ne[0];\n\n    const uint32_t KH = is_2D ? src0->ne[1] : 1;\n    const uint32_t KW =         src0->ne[0];\n\n    const uint32_t OH = is_2D ? dst->ne[2] : 1;\n    const uint32_t OW =         dst->ne[1];\n\n    const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32\n    const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32\n\n    const uint32_t pelements = OW * KW * KH;\n    const uint32_t batch = src1->ne[is_2D ? 3 : 2];\n\n    const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;\n    const vk_buffer d_buf = d_buf_ctx->dev_buffer;\n\n    const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;\n\n    ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, {\n        dst_addr,\n        batch_offset, offset_delta,\n        IC, IW, IH, OW, OH, KW, KH,\n        pelements,\n        IC * KH * KW,\n        s0, s1, p0, p1, d0, d1, batch * IC\n    });\n}\n\nstatic void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];\n    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];\n    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];\n    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];\n    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];\n    const int32_t p2 = ((const int32_t 
*)(dst->op_params))[5];\n    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];\n    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];\n    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];\n    const int32_t IC = ((const int32_t *)(dst->op_params))[9];\n\n    const int64_t N  = ne13 / IC;\n    const int64_t ID = ne12;\n    const int64_t IH = ne11;\n    const int64_t IW = ne10;\n\n    const int64_t KD = ne02;\n    const int64_t KH = ne01;\n    const int64_t KW = ne00;\n\n    const int64_t OD = ne3 / N;\n    const int64_t OH = ne2;\n    const int64_t OW = ne1;\n\n    const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;\n    const vk_buffer d_buf = d_buf_ctx->dev_buffer;\n\n    const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;\n\n    vk_op_im2col_3d_push_constants pc {};\n\n    pc.dst_addr = dst_addr;\n    pc.nb10 = nb10 / ggml_type_size(src1->type);\n    pc.nb11 = nb11 / ggml_type_size(src1->type);\n    pc.nb12 = nb12 / ggml_type_size(src1->type);\n    pc.nb13 = nb13 / ggml_type_size(src1->type);\n    pc.s0 = s0;\n    pc.s1 = s1;\n    pc.s2 = s2;\n    pc.p0 = p0;\n    pc.p1 = p1;\n    pc.p2 = p2;\n    pc.d0 = d0;\n    pc.d1 = d1;\n    pc.d2 = d2;\n    pc.IW = IW;\n    pc.IH = IH;\n    pc.ID = ID;\n    pc.IC = IC;\n    pc.KW = KW;\n    pc.OH = OH;\n    pc.KD_KH_KW = KD*KH*KW;\n    pc.KH_KW = KH*KW;\n    pc.IC_KD_KH_KW = IC*KD*KH*KW;\n    pc.N_OD_OH = N*OD*OH;\n    pc.OD_OH = OD*OH;\n    pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;\n    pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;\n    pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;\n\n    ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc));\n}\n\nstatic void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    const uint32_t dim = dst->op_params[0];\n    
const uint32_t max_period = dst->op_params[1];\n    const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);\n\n    ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {\n        nb1, dim, max_period,\n    });\n}\n\nstatic void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    // src0: (K, Cout, Cin, 1) -- kernel\n    // src1: (L, Cin, 1, 1) -- input\n    // dst: (*, Cout, 1, 1)\n\n    GGML_ASSERT(src0->type == GGML_TYPE_F32);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT( dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    GGML_ASSERT(nb00 == sizeof(float));\n    GGML_ASSERT(nb10 == sizeof(float));\n\n    const int32_t s0 = dst->op_params[0];\n\n    vk_op_conv_transpose_1d_push_constants p{};\n    p.Cout = static_cast<uint32_t>(ne01);\n    p.Cin = static_cast<uint32_t>(ne02);\n    p.K = static_cast<uint32_t>(ne00);\n    p.L = static_cast<uint32_t>(ne10);\n    p.KL = static_cast<uint32_t>(ne0);\n    p.nb01 = static_cast<uint32_t>(nb01 / nb00);\n    p.nb02 = static_cast<uint32_t>(nb02 / nb00);\n    p.nb11 = static_cast<uint32_t>(nb11 / nb10);\n    p.nb1 = static_cast<uint32_t>(nb1 / nb0);\n    p.s0 = static_cast<uint32_t>(s0);\n\n    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));\n}\n\nstatic void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    uint32_t op = static_cast<uint32_t>(dst->op_params[0]);\n    const int32_t k1 = dst->op_params[1];\n    const int32_t k0 = dst->op_params[2];\n    const int32_t s1 = dst->op_params[3];\n    const int32_t s0 = dst->op_params[4];\n    const int32_t p1 = dst->op_params[5];\n    const int32_t p0 = dst->op_params[6];\n\n    const uint32_t IH = src0->ne[1];\n    const uint32_t IW 
= src0->ne[0];\n\n    const uint32_t N = dst->ne[3];\n\n    const uint32_t OC = dst->ne[2];\n    const uint32_t OH = dst->ne[1];\n    const uint32_t OW = dst->ne[0];\n\n    const uint32_t parallel_elements = N * OC * OH * OW;\n\n    ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, {\n        IW, IH, OW, OH, OC,\n        parallel_elements,\n        op,\n        k0, k1, s0, s1, p0, p1,\n    });\n}\n\nstatic void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,\n                            const ggml_tensor * src1, ggml_tensor * dst) {\n    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);\n    GGML_ASSERT(src1->type == GGML_TYPE_F32);\n    GGML_ASSERT(dst->type == GGML_TYPE_F32);\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));\n    GGML_ASSERT(nb10 == sizeof(float));\n    GGML_ASSERT(nb0 == sizeof(float));\n\n    bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;\n\n    vk_op_conv2d_push_constants p{};\n    p.Cout = static_cast<uint32_t>(!transpose ? ne03 : ne02);\n    p.Cin  = static_cast<uint32_t>(!transpose ? 
ne02 : ne03);\n    p.N    = static_cast<uint32_t>(ne13);\n    GGML_ASSERT(p.Cout == ne2);\n    GGML_ASSERT(p.Cin == ne12);\n\n    p.W  = static_cast<uint32_t>(ne10);\n    p.H  = static_cast<uint32_t>(ne11);\n    p.OW = static_cast<uint32_t>(ne0);\n    p.OH = static_cast<uint32_t>(ne1);\n\n    p.nb01 = static_cast<uint32_t>(nb01 / nb00);\n    p.nb02 = static_cast<uint32_t>(nb02 / nb00);\n    p.nb03 = static_cast<uint32_t>(nb03 / nb00);\n\n    p.nb11 = static_cast<uint32_t>(nb11 / nb10);\n    p.nb12 = static_cast<uint32_t>(nb12 / nb10);\n    p.nb13 = static_cast<uint32_t>(nb13 / nb10);\n\n    p.nb1 = static_cast<uint32_t>(nb1 / nb0);\n    p.nb2 = static_cast<uint32_t>(nb2 / nb0);\n    p.nb3 = static_cast<uint32_t>(nb3 / nb0);\n\n    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));\n}\n\nstatic void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {\n    vk_op_conv2d_dw_push_constants p{};\n    p.ne = ggml_nelements(dst);\n    p.channels = dst->ne[2];\n    p.batches = dst->ne[3];\n    p.dst_w = dst->ne[0];\n    p.dst_h = dst->ne[1];\n    p.src_w = src1->ne[0];\n    p.src_h = src1->ne[1];\n    p.knl_w = src0->ne[0];\n    p.knl_h = src0->ne[1];\n    p.stride_x = dst->op_params[0];\n    p.stride_y = dst->op_params[1];\n    p.pad_x = dst->op_params[2];\n    p.pad_y = dst->op_params[3];\n    p.dilation_x = dst->op_params[4];\n    p.dilation_y = dst->op_params[5];\n\n    GGML_ASSERT(src0->ne[3] == p.channels);\n    GGML_ASSERT(src1->ne[3] == p.batches);\n\n    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p));\n}\n\nstatic void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {\n    const float * op_params = (const float *)dst->op_params;\n    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, 
GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });\n}\n\n#ifdef GGML_VULKAN_RUN_TESTS\nstatic void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {\n    if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {\n        return;\n    }\n    i0 = std::max(i0, 5);\n    i1 = std::max(i1, 5);\n    i2 = std::max(i2, 0);\n    fprintf(stderr, \"         \");\n    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {\n        fprintf(stderr, \"%7d \", idx1);\n    }\n    fprintf(stderr, \"\\n\");\n    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {\n        fprintf(stderr, \"%7d: \", idx0);\n        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {\n            if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) {\n                float val;\n                if (type == GGML_TYPE_F32) {\n                    val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0);\n                } else if (type == GGML_TYPE_F16) {\n                    val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0));\n                } else {\n                    GGML_ABORT(\"fatal error\");\n                }\n                fprintf(stderr, \"% 7.2f \", val);\n            } else {\n                fprintf(stderr, \"        \");\n            }\n        }\n        fprintf(stderr, \"\\n\");\n    }\n}\n\ntemplate <typename X_TYPE, typename Y_TYPE>\nstatic void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) {\n    VK_LOG_DEBUG(\"ggml_vk_test_matmul(\" << m << \", \" << n << \", \" << k << \", \" << batch << \", \" << num_it << \", \" << split_k << \", \" << shader_size << \")\");\n    const size_t x_ne = m * k * batch;\n    const size_t y_ne = k * n * batch;\n    const size_t d_ne = m * n * batch;\n\n    vk_pipeline p;\n    std::string shname;\n    if (shader_size == 0) {\n        if 
(std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f32->a_s;\n            shname = \"F32_ALIGNED_S\";\n        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f32_f16->a_s;\n            shname = \"F32_F16_ALIGNED_S\";\n        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s;\n            shname = \"F16_F32_ALIGNED_S\";\n        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f16.f32acc->a_s;\n            shname = \"F16_ALIGNED_S\";\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    } else if (shader_size == 1) {\n        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f32->a_m;\n            shname = \"F32_ALIGNED_M\";\n        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f32_f16->a_m;\n            shname = \"F32_F16_ALIGNED_M\";\n        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m;\n            shname = \"F16_F32_ALIGNED_M\";\n        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f16.f32acc->a_m;\n            shname = \"F16_ALIGNED_M\";\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    } else if (shader_size == 2) {\n        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f32->a_l;\n            shname = \"F32_ALIGNED_L\";\n        } else if (std::is_same<float, X_TYPE>() && 
std::is_same<ggml_fp16_t, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f32_f16->a_l;\n            shname = \"F32_F16_ALIGNED_L\";\n        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l;\n            shname = \"F16_F32_ALIGNED_L\";\n        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n            p = ctx->device->pipeline_matmul_f16.f32acc->a_l;\n            shname = \"F16_ALIGNED_L\";\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    } else {\n        GGML_ASSERT(0);\n    }\n\n    const size_t kpad = ggml_vk_align_size(k, p->align);\n\n    if (k != kpad) {\n        if (shader_size == 0) {\n            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f32->s;\n                shname = \"F32_S\";\n            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f32_f16->s;\n                shname = \"F32_F16_S\";\n            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f16_f32.f32acc->s;\n                shname = \"F16_F32_S\";\n            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f16.f32acc->s;\n                shname = \"F16_S\";\n            }\n        } else if (shader_size == 1) {\n            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f32->m;\n                shname = \"F32_M\";\n            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f32_f16->m;\n                shname = 
\"F32_F16_M\";\n            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f16_f32.f32acc->m;\n                shname = \"F16_F32_M\";\n            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f16.f32acc->m;\n                shname = \"F16_M\";\n            }\n        } else if (shader_size == 2) {\n            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f32->l;\n                shname = \"F32_L\";\n            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f32_f16->l;\n                shname = \"F32_F16_L\";\n            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f16_f32.f32acc->l;\n                shname = \"F16_F32_L\";\n            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {\n                p = ctx->device->pipeline_matmul_f16.f32acc->l;\n                shname = \"F16_L\";\n            }\n        }\n    }\n\n    if (split_k > 1) {\n        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);\n\n        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {\n            // Resize buffer\n            if (ctx->prealloc_split_k != nullptr) {\n                ggml_vk_destroy_buffer(ctx->prealloc_split_k);\n            }\n            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n        }\n    }\n\n    ggml_pipeline_allocate_descriptor_sets(ctx);\n\n    vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, 
sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n    vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n    vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n\n    X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);\n    Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);\n    float* d = (float *) malloc(sizeof(float) * d_ne);\n\n    for (size_t i = 0; i < x_ne; i++) {\n        if (std::is_same<float, X_TYPE>()) {\n            x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;\n            // x[i] = 1.0f;\n            // x[i] = i + 1;\n            // x[i] = (i % k == i / k) ? 1.0f : 0.0f;\n        } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {\n            x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);\n            // x[i] = ggml_fp32_to_fp16(1.0f);\n            // x[i] = ggml_fp32_to_fp16(i + 1);\n            // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    }\n    for (size_t i = 0; i < y_ne; i++) {\n        if (std::is_same<float, Y_TYPE>()) {\n            y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;\n            // y[i] = (i % k == i / k) ? 1.0f : 0.0f;\n            // y[i] = i + 1;\n        } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {\n            y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);\n            // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 
1.0f : 0.0f);\n            // y[i] = ggml_fp32_to_fp16(i + 1);\n        } else {\n            GGML_ABORT(\"fatal error\");\n        }\n    }\n\n    ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);\n    ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);\n\n    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);\n    ggml_vk_ctx_begin(ctx->device, subctx);\n    for (size_t i = 0; i < num_it; i++) {\n        ggml_vk_matmul(\n            ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),\n            m, n, k,\n            k, k, m, k*m, k*n, m*n,\n            split_k, batch, batch, batch, 1, 1, n\n        );\n    }\n    ggml_vk_ctx_end(subctx);\n\n    auto begin = std::chrono::high_resolution_clock::now();\n    ggml_vk_submit(subctx, ctx->fence);\n    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), \"ggml_vk_test_matmul waitForFences\");\n    ctx->device->device.resetFences({ ctx->fence });\n    ggml_vk_queue_command_pools_cleanup(ctx->device);\n\n    auto end = std::chrono::high_resolution_clock::now();\n    double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;\n\n    // copy dst to host\n    ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);\n\n    float * d_chk = (float *) malloc(sizeof(float) * d_ne);\n\n    ggml_init_params iparams = {\n        /*.mem_size   =*/ 1024*1024*1024,\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n\n    ggml_context * ggml_ctx = ggml_init(iparams);\n\n    ggml_type src0_type;\n    ggml_type src1_type;\n\n    if (std::is_same<float, X_TYPE>()) {\n        src0_type = GGML_TYPE_F32;\n    } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {\n        src0_type = GGML_TYPE_F16;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n    if (std::is_same<float, Y_TYPE>()) {\n        src1_type = 
GGML_TYPE_F32;\n    } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {\n        src1_type = GGML_TYPE_F16;\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n\n    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);\n    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);\n    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);\n\n    src0_ggml->data = x;\n    src1_ggml->data = y;\n    tensor_ggml->data = d_chk;\n\n    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);\n    ggml_build_forward_expand(cgraph, tensor_ggml);\n\n    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);\n\n    ggml_free(ggml_ctx);\n\n    double avg_err = 0.0;\n    int first_err_n = -1;\n    int first_err_m = -1;\n    int first_err_b = -1;\n\n    for (size_t i = 0; i < m*n*batch; i++) {\n        double err = std::fabs(d[i] - d_chk[i]);\n        avg_err += err;\n\n        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {\n            first_err_b = i / (m * n);\n            first_err_n = (i % (m * n)) / m;\n            first_err_m = (i % (m * n)) % m;\n        }\n    }\n\n    avg_err /= m * n;\n\n    double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);\n\n    std::cerr << \"TEST \" << shname << \" m=\" << m << \" n=\" << n << \" k=\" << k << \" batch=\" << batch << \" split_k=\" << split_k << \" matmul \" << time / num_it << \"ms \" << tflops << \" TFLOPS avg_err=\" << avg_err << std::endl;\n\n    if (avg_err > 0.1 || std::isnan(avg_err)) {\n        std::cerr << \"m = \" << first_err_m << \" n = \" << first_err_n << \" b = \" << first_err_b << std::endl;\n        std::cerr << \"Actual result: \" << std::endl << std::endl;\n        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n        std::cerr << \"Expected result: \" << std::endl << std::endl;\n        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, 
first_err_n, first_err_b);\n\n        if (split_k > 1) {\n            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);\n            ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);\n\n            std::cerr << \"d_buf0: \" << std::endl << std::endl;\n            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n            std::cerr << \"d_buf1: \" << std::endl << std::endl;\n            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n            std::cerr << \"d_buf2: \" << std::endl << std::endl;\n            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n            std::cerr << \"d_buf3: \" << std::endl << std::endl;\n            ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n            free(split_k_buf);\n        }\n    }\n\n    free(d_chk);\n\n    ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);\n\n    ggml_vk_destroy_buffer(d_X);\n    ggml_vk_destroy_buffer(d_Y);\n    ggml_vk_destroy_buffer(d_D);\n\n    free(x);\n    free(y);\n    free(d);\n}\n\nstatic void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) {\n    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {\n        return;\n    }\n    i0 = std::max(i0, 5);\n    i1 = std::max(i1, 5);\n    i2 = std::max(i2, 0);\n    i3 = std::max(i3, 0);\n    fprintf(stderr, \"         \");\n    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {\n        fprintf(stderr, \"%7d \", idx1);\n    }\n    fprintf(stderr, \"\\n\");\n    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {\n        fprintf(stderr, \"%7d: \", idx0);\n        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {\n            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && 
idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {\n                float val;\n                if (tensor->type == GGML_TYPE_F32) {\n                    val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);\n                } else if (tensor->type == GGML_TYPE_F16) {\n                    val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));\n                } else {\n                    GGML_ABORT(\"fatal error\");\n                }\n                fprintf(stderr, \"% 7.2f \", val);\n            } else {\n                fprintf(stderr, \"        \");\n            }\n        }\n        fprintf(stderr, \"\\n\");\n    }\n}\n\nstatic void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {\n    ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);\n}\n\nstatic void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) {\n    if (quant == GGML_TYPE_F32) {\n        memcpy(to, from, sizeof(float) * ne);\n        return;\n    }\n\n    const auto * tt = ggml_get_type_traits(quant);\n\n    ggml_to_float_t dequant_fn = tt->to_float;\n\n    dequant_fn(from, to, ne);\n}\n\nstatic void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {\n    VK_LOG_DEBUG(\"ggml_vk_test_dequant(\" << ne << \")\");\n    const size_t x_sz = sizeof(float) * ne;\n    const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;\n    const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);\n    float * x = (float *) malloc(x_sz);\n    void * qx = malloc(qx_sz);\n    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n    vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n    float * 
x_ref = (float *) malloc(x_sz);\n    ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);\n\n    for (size_t i = 0; i < ne; i++) {\n        x[i] = rand() / (float)RAND_MAX;\n    }\n\n    vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant);\n\n    ggml_vk_quantize_data(x, qx, ne, quant);\n    ggml_vk_dequantize_data(qx, x_ref, ne, quant);\n\n    ggml_pipeline_request_descriptor_sets(ctx, p, 1);\n\n    ggml_pipeline_allocate_descriptor_sets(ctx);\n\n    ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);\n\n    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);\n    ggml_vk_ctx_begin(ctx->device, subctx);\n    const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };\n    ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});\n    ggml_vk_ctx_end(subctx);\n\n    auto begin = std::chrono::high_resolution_clock::now();\n\n    ggml_vk_submit(subctx, ctx->fence);\n    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), \"ggml_vk_test_dequant waitForFences\");\n    ctx->device->device.resetFences({ ctx->fence });\n    ggml_vk_queue_command_pools_cleanup(ctx->device);\n\n    auto end = std::chrono::high_resolution_clock::now();\n\n    double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;\n    ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16);\n\n    int first_err = -1;\n\n    double avg_err = 0.0;\n    for (size_t i = 0; i < ne; i++) {\n        double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i]));\n        avg_err += error;\n\n        if (first_err < 0 && error > 0.05) {\n            first_err = i;\n        }\n    }\n\n    avg_err /= ne;\n\n    std::cerr << \"TEST DEQUANT \" << ggml_type_name(quant) << \" time=\" << ms_dequant << \"ms avg_err=\" << avg_err << std::endl;\n\n    if (avg_err > 0.1) {\n        std::cerr << \"first_error = \" << first_err << 
std::endl;\n        std::cerr << \"Actual result: \" << std::endl << std::endl;\n        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {\n            std::cerr << ggml_fp16_to_fp32(x_chk[i]) << \", \";\n        }\n        std::cerr << std::endl << \"Expected result: \" << std::endl << std::endl;\n        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {\n            std::cerr << x_ref[i] << \", \";\n        }\n        std::cerr << std::endl;\n    }\n\n    ggml_vk_destroy_buffer(x_buf);\n    ggml_vk_destroy_buffer(qx_buf);\n\n    free(x);\n    free(qx);\n    free(x_ref);\n    free(x_chk);\n}\n\n// This does not work without ggml q8_1 quantization support\n//\n// typedef uint16_t ggml_half;\n// typedef uint32_t ggml_half2;\n//\n// #define QK8_1 32\n// typedef struct {\n//     union {\n//         struct {\n//             ggml_half d; // delta\n//             ggml_half s; // d * sum(qs[i])\n//         } GGML_COMMON_AGGR_S;\n//         ggml_half2 ds;\n//     } GGML_COMMON_AGGR_U;\n//     int8_t qs[QK8_1]; // quants\n// } block_q8_1;\n//\n// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {\n//     VK_LOG_DEBUG(\"ggml_vk_test_quantize(\" << ne << \")\");\n//     GGML_ASSERT(quant == GGML_TYPE_Q8_1);\n//\n//     const size_t x_sz = sizeof(float) * ne;\n//     const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);\n//     float * x = (float *) malloc(x_sz);\n//     block_q8_1 * qx     = (block_q8_1 *)malloc(qx_sz);\n//     block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);\n//     vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n//     vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n//\n//     for (size_t i = 0; i < ne; i++) {\n//         x[i] = rand() / (float)RAND_MAX;\n//     }\n//\n//     vk_pipeline p = 
ggml_vk_get_quantize_pipeline(ctx, quant);\n//\n//     ggml_pipeline_request_descriptor_sets(ctx, p, 1);\n//\n//     ggml_pipeline_allocate_descriptor_sets(ctx);\n//\n//     ggml_vk_buffer_write(x_buf, 0, x, x_sz);\n//\n//     vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);\n//     ggml_vk_ctx_begin(ctx->device, subctx);\n//     ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne);\n//     ggml_vk_ctx_end(subctx);\n//\n//     auto begin = std::chrono::high_resolution_clock::now();\n//\n//     ggml_vk_submit(subctx, ctx->fence);\n//     VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), \"ggml_vk_test_quantize waitForFences\");\n//     ctx->device->device.resetFences({ ctx->fence });\n//     ggml_vk_queue_command_pools_cleanup(ctx->device);\n//\n//     auto end = std::chrono::high_resolution_clock::now();\n//\n//     double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;\n//     ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);\n//\n//     ggml_vk_quantize_data(x, qx_res, ne, quant);\n//\n//     int first_err = -1;\n//\n//     for (size_t i = 0; i < ne / 32; i++) {\n//         double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));\n//\n//         if (first_err < 0 && error > 0.1) {\n//             first_err = i;\n//         }\n//\n//         error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));\n//\n//         if (first_err < 0 && error > 0.1) {\n//             first_err = i;\n//         }\n//\n//         for (size_t j = 0; j < 32; j++) {\n//             uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);\n//\n//             if (first_err < 0 && error > 1) {\n//                 first_err = i;\n//             }\n//     
    }\n//     }\n//\n//     std::cerr << \"TEST QUANTIZE \" << ggml_type_name(quant) << \" time=\" << ms_quant << \"ms \" << (first_err == -1 ? \"CORRECT\" : \"INCORRECT\") << std::endl;\n//\n//     if (first_err != -1) {\n//         std::cerr << \"first_error = \" << first_err << std::endl;\n//         std::cerr << \"Actual result: \" << std::endl << std::endl;\n//         std::cout << \"d=\" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << \" s=\" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << \" \";\n//         for (size_t j = 0; j < 32; j++) {\n//             std::cout << \" qs\" << j << \"=\" << (uint32_t)qx[first_err].qs[j] << \" \";\n//         }\n//         std::cerr << std::endl << std::endl << \"Expected result: \" << std::endl << std::endl;\n//         std::cout << \"d=\" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << \" s=\" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << \" \";\n//         for (size_t j = 0; j < 32; j++) {\n//             std::cout << \" qs\" << j << \"=\" << (uint32_t)qx_res[first_err].qs[j] << \" \";\n//         }\n//         std::cerr << std::endl;\n//     }\n//\n//     ggml_vk_destroy_buffer(x_buf);\n//     ggml_vk_destroy_buffer(qx_buf);\n//\n//     free(x);\n//     free(qx);\n//     free(qx_res);\n// }\n\nstatic void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {\n    VK_LOG_DEBUG(\"ggml_vk_test_dequant_matmul(\" << m << \", \" << n << \", \" << k << \", \" << batch << \", \" << num_it << \", \" << split_k << \", \" << ggml_type_name(quant) << \")\");\n    const size_t x_ne = m * k * batch;\n    const size_t y_ne = k * n * batch;\n    const size_t d_ne = m * n * batch;\n\n    vk_matmul_pipeline2 * pipelines;\n\n    if (mmq) {\n        pipelines = 
ctx->device->pipeline_dequant_mul_mat_mat_q8_1;\n    } else {\n        pipelines = ctx->device->pipeline_dequant_mul_mat_mat;\n    }\n\n    const bool fp16acc = ctx->device->fp16;\n\n    vk_pipeline p;\n    std::string shname;\n    if (shader_size == 0) {\n        p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;\n        shname = std::string(ggml_type_name(quant)) + \"_ALIGNED_S\";\n    } else if (shader_size == 1) {\n        p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;\n        shname = std::string(ggml_type_name(quant)) + \"_ALIGNED_M\";\n    } else if (shader_size == 2) {\n        p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;\n        shname = std::string(ggml_type_name(quant)) + \"_ALIGNED_L\";\n    } else {\n        GGML_ASSERT(0);\n    }\n\n    const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);\n\n    if (mmq || k != kpad) {\n        if (shader_size == 0) {\n            p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;\n            shname = std::string(ggml_type_name(quant)) + \"_S\";\n        } else if (shader_size == 1) {\n            p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;\n            shname = std::string(ggml_type_name(quant)) + \"_M\";\n        } else if (shader_size == 2) {\n            p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;\n            shname = std::string(ggml_type_name(quant)) + \"_L\";\n        } else {\n            GGML_ASSERT(0);\n        }\n    }\n\n    if (p == nullptr) {\n        std::cerr << \"error: no pipeline for ggml_vk_test_dequant_matmul \" << ggml_type_name(quant) << std::endl;\n        return;\n    }\n\n    const size_t x_sz = sizeof(float) * x_ne;\n    const size_t y_sz = sizeof(float) * y_ne;\n    const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);\n    const size_t qy_sz = mmq ? 
y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;\n    const size_t d_sz = sizeof(float) * d_ne;\n    float * x = (float *) malloc(x_sz);\n    float * y = (float *) malloc(y_sz);\n    void * qx = malloc(qx_sz);\n    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n    vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n    vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n    vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n    float * d = (float *) malloc(d_sz);\n    float * d_chk = (float *) malloc(d_sz);\n\n    for (size_t i = 0; i < x_ne; i++) {\n        x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;\n        // x[i] = (i % k == i / k) ? 1.0f : 0.0f;\n        // x[i] = i % k;\n    }\n\n    ggml_vk_quantize_data(x, qx, x_ne, quant);\n\n    for (size_t i = 0; i < y_ne; i++) {\n        y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;\n        // y[i] = (i % k == i / k) ? 
1.0f : 0.0f;\n        // y[i] = i % k;\n    }\n\n    if (split_k > 1) {\n        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);\n\n        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {\n            // Resize buffer\n            if (ctx->prealloc_split_k != nullptr) {\n                ggml_vk_destroy_buffer(ctx->prealloc_split_k);\n            }\n            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});\n        }\n    }\n    if (mmq) {\n        vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);\n        ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it);\n    }\n\n    ggml_pipeline_allocate_descriptor_sets(ctx);\n\n    ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);\n    ggml_vk_buffer_write(y_buf, 0, y, y_sz);\n\n    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);\n    ggml_vk_ctx_begin(ctx->device, subctx);\n    if (mmq) {\n        for (size_t i = 0; i < num_it; i++) {\n            ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);\n            ggml_vk_matmul(\n                ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },\n                m, n, k,\n                k, k, m, k*m, k*n, m*n,\n                split_k, batch, batch, batch, 1, 1, n\n            );\n        }\n    } else {\n        for (size_t i = 0; i < num_it; i++) {\n            ggml_vk_matmul(\n                ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },\n                m, n, k,\n                k, k, m, k*m, k*n, m*n,\n                split_k, batch, batch, batch, 1, 1, n\n            );\n        }\n    
}\n    ggml_vk_ctx_end(subctx);\n\n    auto begin = std::chrono::high_resolution_clock::now();\n\n    ggml_vk_submit(subctx, ctx->fence);\n    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), \"ggml_vk_test_dequant waitForFences\");\n    ctx->device->device.resetFences({ ctx->fence });\n    ggml_vk_queue_command_pools_cleanup(ctx->device);\n\n    auto end = std::chrono::high_resolution_clock::now();\n\n    double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;\n    ggml_vk_buffer_read(d_buf, 0, d, d_sz);\n\n    ggml_init_params iparams = {\n        /*.mem_size   =*/ 1024*1024*1024,\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true,\n    };\n\n    ggml_context * ggml_ctx = ggml_init(iparams);\n\n    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);\n    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);\n    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);\n\n    src0_ggml->data = qx;\n    src1_ggml->data = y;\n    tensor_ggml->data = d_chk;\n\n    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);\n    ggml_build_forward_expand(cgraph, tensor_ggml);\n\n    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);\n\n    ggml_free(ggml_ctx);\n\n    double avg_err = 0.0;\n    int first_err_n = -1;\n    int first_err_m = -1;\n    int first_err_b = -1;\n\n    for (size_t i = 0; i < m*n*batch; i++) {\n        double err = std::fabs(d[i] - d_chk[i]);\n        avg_err += err;\n\n        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {\n            first_err_b = i / (m * n);\n            first_err_n = (i % (m * n)) / m;\n            first_err_m = (i % (m * n)) % m;\n        }\n    }\n\n    avg_err /= m * n;\n\n    double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);\n\n    std::cerr << \"TEST dequant matmul \" << shname;\n    if (mmq) {\n        std::cerr << \" 
mmq\";\n    }\n    std::cerr << \" m=\" << m << \" n=\" << n << \" k=\" << k << \" batch=\" << batch << \" split_k=\" << split_k << \" matmul \" << time_ms / num_it << \"ms \" << tflops << \" TFLOPS avg_err=\" << avg_err << std::endl;\n\n    if (avg_err > 0.01 || std::isnan(avg_err)) {\n        std::cerr << \"m = \" << first_err_m << \" n = \" << first_err_n << \" b = \" << first_err_b << std::endl;\n        std::cerr << \"Actual result: \" << std::endl << std::endl;\n        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n        std::cerr << std::endl;\n        std::cerr << \"Expected result: \" << std::endl << std::endl;\n        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n        std::cerr << \"src0: \" << std::endl << std::endl;\n        ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);\n        std::cerr << std::endl;\n        std::cerr << \"src1: \" << std::endl << std::endl;\n        ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);\n\n        if (split_k > 1) {\n            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);\n            ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);\n\n            std::cerr << \"d_buf0: \" << std::endl << std::endl;\n            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n            std::cerr << \"d_buf1: \" << std::endl << std::endl;\n            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n            std::cerr << \"d_buf2: \" << std::endl << std::endl;\n            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n            std::cerr << \"d_buf3: \" << std::endl << std::endl;\n            
ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);\n\n            free(split_k_buf);\n        }\n    }\n\n    ggml_vk_destroy_buffer(qx_buf);\n    ggml_vk_destroy_buffer(y_buf);\n    ggml_vk_destroy_buffer(qy_buf);\n    ggml_vk_destroy_buffer(d_buf);\n\n    free(x);\n    free(qx);\n    free(y);\n    free(d);\n    free(d_chk);\n}\n#endif\n\nstatic void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx) {\n#if defined(GGML_VULKAN_RUN_TESTS)\n    const std::vector<size_t> vals {\n        512, 512, 128,\n        128, 512, 512,\n        4096, 512, 4096,\n        11008, 512, 4096,\n        4096, 512, 11008,\n        32000, 512, 4096,\n        8, 8, 8,\n        100, 46, 576,\n        623, 111, 128,\n        100, 46, 558,\n        512, 1, 256,\n        128, 110, 622,\n        511, 511, 127,\n        511, 511, 7,\n        511, 511, 17,\n        49, 49, 128,\n        128, 49, 49,\n        4096, 49, 4096,\n    };\n    const size_t num_it = 100;\n\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);\n\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);\n\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);\n\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);\n    ggml_vk_test_dequant_matmul(ctx, 
4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);\n    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);\n\n    abort();\n\n    for (size_t i = 0; i < vals.size(); i += 3) {\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);\n        std::cerr << '\\n';\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);\n        std::cerr << '\\n';\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);\n        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);\n        std::cerr << '\\n' << std::endl;\n\n        if (vals[i + 2] % 32 == 0) {\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);\n            std::cerr << '\\n';\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], 
vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);\n            std::cerr << '\\n';\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);\n            std::cerr << '\\n' << std::endl;\n        }\n\n        if (vals[i + 2] % 256 == 0) {\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);\n            std::cerr << '\\n';\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);\n            std::cerr << '\\n';\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);\n            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);\n            std::cerr << '\\n' << std::endl;\n        }\n    }\n\n    GGML_ABORT(\"fatal error\");\n#endif\n\n    if (subctx) {\n        // Submit and wait for any pending work before reallocating the buffers\n        ggml_vk_ctx_end(subctx);\n        ggml_vk_submit(subctx, {});\n        ctx->submit_pending = true;\n        ggml_vk_synchronize(ctx);\n        
GGML_ASSERT(ctx->compute_ctx.expired());\n        ggml_vk_ctx_begin(ctx->device, subctx);\n        ctx->compute_ctx = subctx;\n    }\n\n    if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {\n        VK_LOG_MEMORY(\"ggml_vk_preallocate_buffers(x_size: \" << ctx->prealloc_size_x << \")\");\n        // Resize buffer\n        if (ctx->prealloc_x != nullptr) {\n            ggml_vk_destroy_buffer(ctx->prealloc_x);\n        }\n        ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);\n    }\n    if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {\n        VK_LOG_MEMORY(\"ggml_vk_preallocate_buffers(y_size: \" << ctx->prealloc_size_y << \")\");\n        // Resize buffer\n        if (ctx->prealloc_y != nullptr) {\n            ggml_vk_destroy_buffer(ctx->prealloc_y);\n        }\n        ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);\n        ctx->prealloc_y_last_tensor_used = nullptr;\n    }\n    if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {\n        VK_LOG_MEMORY(\"ggml_vk_preallocate_buffers(split_k_size: \" << ctx->prealloc_size_split_k << \")\");\n        // Resize buffer\n        if (ctx->prealloc_split_k != nullptr) {\n            ggml_vk_destroy_buffer(ctx->prealloc_split_k);\n        }\n        ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);\n    }\n    if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {\n        VK_LOG_MEMORY(\"ggml_vk_preallocate_buffers(add_partials_size: \" << ctx->prealloc_add_rms_partials << \")\");\n        // Resize buffer\n        if (ctx->prealloc_add_rms_partials != nullptr) {\n            
ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);\n        }\n        ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);\n    }\n}\n\nstatic void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);\n\n// Returns true if node has enqueued work into the queue, false otherwise\n// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.\nstatic bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){\n    ggml_tensor * node = cgraph->nodes[node_idx];\n    if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {\n        return false;\n    }\n    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n        return false;\n    }\n\n    VK_LOG_DEBUG(\"ggml_vk_build_graph(\" << node << \", \" << ggml_op_name(node->op) << \")\");\n    ctx->semaphore_idx = 0;\n\n    ggml_tensor * src0 = node->src[0];\n    ggml_tensor * src1 = node->src[1];\n    ggml_tensor * src2 = node->src[2];\n    ggml_tensor * src3 = node->src[3];\n\n    if (node->op == GGML_OP_ADD) {\n        int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;\n        if (next_node_idx < cgraph->n_nodes &&\n            cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&\n            cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&\n            ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&\n            ctx->device->add_rms_fusion) {\n            uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);\n            ctx->do_add_rms_partials_offset_calculation = true;\n            if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {\n                
ctx->do_add_rms_partials = true;\n            }\n        }\n    }\n\n    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);\n\n    {\n        // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers\n        // to synchronize them. This handles most \"normal\" synchronization when computing the graph, and when\n        // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers\n        // outside of this logic. When a node uses one of the prealloc buffers for something like\n        // dequantization or split_k, additional synchronization is needed between those passes.\n        bool need_sync = false;\n\n        // Check whether \"node\" requires synchronization. The node requires synchronization if it\n        // overlaps in memory with another unsynchronized node and at least one of them is a write.\n        // Destination nodes are checked against both the written/read lists. Source nodes are only\n        // checked against the written list. 
Two nodes overlap in memory if they come from the same\n        // buffer and the tensor or view ranges overlap.\n        auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {\n            if (unsynced_nodes.size() == 0) {\n                return false;\n            }\n            auto n_base = vk_tensor_offset(node) + node->view_offs;\n            auto n_size = ggml_nbytes(node);\n            ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;\n            vk_buffer a_buf = a_buf_ctx->dev_buffer;\n            for (auto &other : unsynced_nodes) {\n                ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;\n                vk_buffer o_buf = o_buf_ctx->dev_buffer;\n                if (a_buf == o_buf) {\n                    auto o_base = vk_tensor_offset(other) + other->view_offs;\n                    auto o_size = ggml_nbytes(other);\n\n                    if ((o_base <= n_base && n_base < o_base + o_size) ||\n                        (n_base <= o_base && o_base < n_base + n_size)) {\n                        return true;\n                    }\n                }\n            }\n            return false;\n        };\n\n        // For all fused ops, check if the destination node or any of the source\n        // nodes require synchronization.\n        for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {\n            const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];\n            // If the node actually writes to memory, then check if it needs to sync\n            if (ctx->fused_ops_write_mask & (1 << i)) {\n                if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {\n                    need_sync = true;\n                    break;\n                }\n            }\n            for 
(uint32_t j = 0; j < GGML_MAX_SRC; ++j) {\n                if (!cur_node->src[j]) {\n                    continue;\n                }\n                if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {\n                    need_sync = true;\n                    break;\n                }\n            }\n        }\n\n        if (need_sync) {\n            if (vk_enable_sync_logger) {\n                std::cerr <<  \"sync\" << std::endl;\n            }\n            ctx->unsynced_nodes_written.clear();\n            ctx->unsynced_nodes_read.clear();\n            ggml_vk_sync_buffers(ctx, compute_ctx);\n\n            if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {\n                ctx->query_node_idx[ctx->query_idx] = node_idx;\n                compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);\n            }\n        }\n        // Add all fused nodes to the unsynchronized lists.\n        for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {\n            const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];\n            // Multiple outputs could be written, e.g. in topk_moe. 
Add them all to the list.\n            if (ctx->fused_ops_write_mask & (1 << i)) {\n                ctx->unsynced_nodes_written.push_back(cur_node);\n            }\n            for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {\n                if (!cur_node->src[j]) {\n                    continue;\n                }\n                ctx->unsynced_nodes_read.push_back(cur_node->src[j]);\n            }\n        }\n    }\n    if (vk_enable_sync_logger) {\n        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {\n            auto *n = cgraph->nodes[node_idx + i];\n            std::cerr << node_idx + i << \" \" << ggml_op_name(n->op) << \" \" <<  n->name;\n            if (n->op == GGML_OP_GLU) {\n                std::cerr << \" \" << ggml_glu_op_name(ggml_get_glu_op(n)) << \" \" << (n->src[1] ? \"split\" : \"single\") << \" \";\n            }\n            if (n->op == GGML_OP_ROPE) {\n                const int mode = ((const int32_t *) n->op_params)[2];\n                std::cerr << \" rope mode: \" << mode;\n            }\n            std::cerr << std::endl;\n        }\n    }\n\n    switch (node->op) {\n    case GGML_OP_REPEAT:\n        ggml_vk_repeat(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_REPEAT_BACK:\n        ggml_vk_repeat_back(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_ACC:\n    case GGML_OP_SET:\n        ggml_vk_acc(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_GET_ROWS:\n        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_ADD:\n        if (ctx->num_additional_fused_ops) {\n            ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx);\n        } else {\n            ggml_vk_add(ctx, compute_ctx, src0, src1, node);\n        }\n        break;\n    case GGML_OP_SUB:\n        ggml_vk_sub(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_MUL:\n        ggml_vk_mul(ctx, compute_ctx, src0, src1, node);\n\n        
break;\n    case GGML_OP_DIV:\n        ggml_vk_div(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_ADD_ID:\n        ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node);\n\n        break;\n    case GGML_OP_CONCAT:\n        ggml_vk_concat(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_UPSCALE:\n        ggml_vk_upscale(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_ADD1:\n        ggml_vk_add1(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_ARANGE:\n        ggml_vk_arange(ctx, compute_ctx, node);\n\n        break;\n    case GGML_OP_FILL:\n        ggml_vk_fill(ctx, compute_ctx, node);\n\n        break;\n    case GGML_OP_SCALE:\n        ggml_vk_scale(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_SQR:\n        ggml_vk_sqr(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_SQRT:\n        ggml_vk_sqrt(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_SIN:\n        ggml_vk_sin(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_COS:\n        ggml_vk_cos(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_LOG:\n        ggml_vk_log(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_TRI:\n        ggml_vk_tri(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_DIAG:\n        ggml_vk_diag(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_CLAMP:\n        ggml_vk_clamp(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_PAD:\n        ggml_vk_pad(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_ROLL:\n        ggml_vk_roll(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_CPY:\n    case GGML_OP_CONT:\n    case GGML_OP_DUP:\n        ggml_vk_cpy(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_SET_ROWS:\n        ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    
case GGML_OP_SILU_BACK:\n        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_NORM:\n        ggml_vk_norm(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_GROUP_NORM:\n        ggml_vk_group_norm(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_RMS_NORM:\n        ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);\n        break;\n    case GGML_OP_RMS_NORM_BACK:\n        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_L2_NORM:\n        ggml_vk_l2_norm(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_UNARY:\n        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {\n            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);\n            break;\n        }\n\n        switch (ggml_get_unary_op(node)) {\n        case GGML_UNARY_OP_ELU:\n        case GGML_UNARY_OP_EXP:\n        case GGML_UNARY_OP_SILU:\n        case GGML_UNARY_OP_GELU:\n        case GGML_UNARY_OP_GELU_ERF:\n        case GGML_UNARY_OP_GELU_QUICK:\n        case GGML_UNARY_OP_RELU:\n        case GGML_UNARY_OP_NEG:\n        case GGML_UNARY_OP_TANH:\n        case GGML_UNARY_OP_SIGMOID:\n        case GGML_UNARY_OP_HARDSIGMOID:\n        case GGML_UNARY_OP_HARDSWISH:\n        case GGML_UNARY_OP_ABS:\n        case GGML_UNARY_OP_SOFTPLUS:\n        case GGML_UNARY_OP_STEP:\n        case GGML_UNARY_OP_ROUND:\n        case GGML_UNARY_OP_CEIL:\n        case GGML_UNARY_OP_FLOOR:\n        case GGML_UNARY_OP_TRUNC:\n        case GGML_UNARY_OP_SGN:\n            ggml_vk_unary(ctx, compute_ctx, src0, node);\n            break;\n        case GGML_UNARY_OP_XIELU:\n            ggml_vk_xielu(ctx, compute_ctx, src0, node);\n            break;\n        default:\n            return false;\n        }\n        break;\n    case GGML_OP_GLU:\n        switch (ggml_get_glu_op(node)) {\n        case GGML_GLU_OP_GEGLU:\n        case GGML_GLU_OP_REGLU:\n        
case GGML_GLU_OP_SWIGLU:\n        case GGML_GLU_OP_SWIGLU_OAI:\n        case GGML_GLU_OP_GEGLU_ERF:\n        case GGML_GLU_OP_GEGLU_QUICK:\n            ggml_vk_glu(ctx, compute_ctx, src0, src1, node);\n            break;\n        default:\n            return false;\n        }\n        break;\n    case GGML_OP_DIAG_MASK_INF:\n        ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_SOFT_MAX:\n        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {\n            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);\n        } else {\n            ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);\n        }\n\n        break;\n    case GGML_OP_SOFT_MAX_BACK:\n        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_ROPE:\n        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false);\n\n        break;\n    case GGML_OP_ROPE_BACK:\n        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);\n\n        break;\n    case GGML_OP_ARGSORT:\n        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {\n            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);\n        } else {\n            ggml_vk_argsort(ctx, compute_ctx, src0, node);\n        }\n\n        break;\n    case GGML_OP_TOP_K:\n        ggml_vk_topk(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_SUM:\n        ggml_vk_sum(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_SUM_ROWS:\n        ggml_vk_sum_rows(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_CUMSUM:\n        ggml_vk_cumsum(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_MEAN:\n        ggml_vk_mean(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_ARGMAX:\n        ggml_vk_argmax(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_COUNT_EQUAL:\n        ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case 
GGML_OP_SOLVE_TRI:\n        ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_IM2COL:\n        ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_IM2COL_3D:\n        ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_TIMESTEP_EMBEDDING:\n        ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_CONV_TRANSPOSE_1D:\n        ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_POOL_2D:\n        ggml_vk_pool_2d(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_CONV_2D:\n    case GGML_OP_CONV_TRANSPOSE_2D:\n        ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_CONV_2D_DW:\n        ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);\n\n        break;\n    case GGML_OP_LEAKY_RELU:\n        ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);\n\n        break;\n    case GGML_OP_MUL_MAT:\n        ggml_vk_mul_mat(ctx, compute_ctx, cgraph, node_idx);\n\n        break;\n    case GGML_OP_MUL_MAT_ID:\n        ggml_vk_mul_mat_id(ctx, compute_ctx, cgraph, node_idx);\n\n        break;\n\n    case GGML_OP_FLASH_ATTN_EXT:\n        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node);\n\n        break;\n\n    case GGML_OP_RWKV_WKV6:\n        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node);\n\n        break;\n\n    case GGML_OP_RWKV_WKV7:\n        ggml_vk_rwkv_wkv7(ctx, compute_ctx, node);\n\n        break;\n\n    case GGML_OP_GATED_DELTA_NET:\n        ggml_vk_gated_delta_net(ctx, compute_ctx, node);\n\n        break;\n\n    case GGML_OP_SSM_SCAN:\n        ggml_vk_ssm_scan(ctx, compute_ctx, node);\n\n        break;\n\n    case GGML_OP_SSM_CONV:\n        ggml_vk_ssm_conv(ctx, compute_ctx, node);\n\n        break;\n\n    case GGML_OP_OPT_STEP_ADAMW:\n        ggml_vk_opt_step_adamw(ctx, compute_ctx, 
node);\n\n        break;\n\n    case GGML_OP_OPT_STEP_SGD:\n        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node);\n\n        break;\n    default:\n        return false;\n    }\n\n    ctx->tensor_ctxs[node_idx] = compute_ctx;\n\n#if defined(GGML_VULKAN_CHECK_RESULTS)\n    // Force context reset on each node so that each tensor ends up in its own context\n    // and can be run and compared to its CPU equivalent separately\n    last_node = true;\n#endif\n\n    if (submit || last_node) {\n        ggml_vk_ctx_end(compute_ctx);\n\n        // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward\n        if (last_node) {\n            compute_ctx->exit_tensor_idx = node_idx_begin;\n        }\n        else {\n            compute_ctx->exit_tensor_idx = -1;\n        }\n\n        ctx->compute_ctx.reset();\n\n        ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);\n    }\n    return true;\n}\n\nstatic void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {\n    GGML_UNUSED(cgraph);\n    GGML_UNUSED(tensor);\n\n    VK_LOG_DEBUG(\"ggml_vk_compute_forward(\" << tensor << \", name=\" << tensor->name << \", op=\" << ggml_op_name(tensor->op) << \", type=\" << tensor->type << \", ne0=\" << tensor->ne[0] << \", ne1=\" << tensor->ne[1] << \", ne2=\" << tensor->ne[2] << \", ne3=\" << tensor->ne[3] << \", nb0=\" << tensor->nb[0] << \", nb1=\" << tensor->nb[1] << \", nb2=\" << tensor->nb[2] << \", nb3=\" << tensor->nb[3] << \", view_src=\" << tensor->view_src << \", view_offs=\" << tensor->view_offs << \")\");\n\n    vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();\n\n    // Only run if ctx hasn't been submitted yet\n    if (!subctx->seqs.empty()) {\n#ifdef GGML_VULKAN_CHECK_RESULTS\n        ggml_vk_check_results_0(ctx, cgraph, tensor_idx);\n#endif\n\n        // Do staging buffer copies\n        for (auto& cpy : 
subctx->in_memcpys) {\n            memcpy(cpy.dst, cpy.src, cpy.n);\n        }\n\n        for (auto& mset : subctx->memsets) {\n            memset(mset.dst, mset.val, mset.n);\n        }\n\n        if (almost_ready && !ctx->almost_ready_fence_pending) {\n            ggml_vk_submit(subctx, ctx->almost_ready_fence);\n            ctx->almost_ready_fence_pending = true;\n        } else {\n            ggml_vk_submit(subctx, {});\n        }\n        ctx->submit_pending = true;\n\n#ifdef GGML_VULKAN_CHECK_RESULTS\n        ggml_vk_synchronize(ctx);\n        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);\n#endif\n    }\n\n    if (tensor_idx == subctx->exit_tensor_idx) {\n        // Do staging buffer copies\n        for (auto& cpy : subctx->out_memcpys) {\n            memcpy(cpy.dst, cpy.src, cpy.n);\n        }\n        subctx->in_memcpys.clear();\n        subctx->out_memcpys.clear();\n        subctx->memsets.clear();\n    }\n}\n\n// Clean up after graph processing is done\nstatic void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {\n    VK_LOG_DEBUG(\"ggml_vk_graph_cleanup()\");\n    ctx->prealloc_y_last_pipeline_used = {};\n\n    ctx->unsynced_nodes_written.clear();\n    ctx->unsynced_nodes_read.clear();\n    ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;\n\n    ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);\n    if (ctx->device->async_use_transfer_queue) {\n        ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);\n    }\n\n    for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {\n        ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });\n    }\n    ctx->gc.semaphores.clear();\n\n    for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {\n        ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });\n    }\n    ctx->gc.tl_semaphores.clear();\n    ctx->semaphore_idx = 0;\n\n    ctx->event_idx = 0;\n\n    for (auto& event : 
ctx->gc.events) {\n        ctx->device->device.resetEvent(event);\n    }\n\n    ctx->tensor_ctxs.clear();\n    ctx->gc.contexts.clear();\n    ctx->pipeline_descriptor_set_requirements = 0;\n    ctx->descriptor_set_idx = 0;\n}\n\n// Clean up on backend free\nstatic void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {\n    VK_LOG_DEBUG(\"ggml_vk_cleanup(\" << ctx->name << \")\");\n    // discard any unsubmitted command buffers\n    ctx->compute_ctx.reset();\n    // wait for any pending command buffers to finish\n    ggml_vk_synchronize(ctx);\n\n    ggml_vk_graph_cleanup(ctx);\n\n    ggml_vk_destroy_buffer(ctx->prealloc_x);\n    ggml_vk_destroy_buffer(ctx->prealloc_y);\n    ggml_vk_destroy_buffer(ctx->prealloc_split_k);\n    ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);\n    ggml_vk_destroy_buffer(ctx->sync_staging);\n\n    ctx->prealloc_y_last_pipeline_used = nullptr;\n\n    ctx->prealloc_size_x = 0;\n    ctx->prealloc_size_y = 0;\n    ctx->prealloc_size_split_k = 0;\n\n    for (auto& event : ctx->gc.events) {\n        ctx->device->device.destroyEvent(event);\n    }\n    ctx->gc.events.clear();\n\n    ctx->device->device.destroyFence(ctx->fence);\n    ctx->device->device.destroyFence(ctx->almost_ready_fence);\n\n    for (auto& pool : ctx->descriptor_pools) {\n        ctx->device->device.destroyDescriptorPool(pool);\n    }\n    ctx->descriptor_pools.clear();\n    ctx->descriptor_sets.clear();\n\n    ctx->compute_cmd_pool.destroy(ctx->device->device);\n    if (ctx->device->async_use_transfer_queue) {\n        ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s);\n\n        ctx->transfer_cmd_pool.destroy(ctx->device->device);\n    }\n    if (vk_perf_logger_enabled) {\n        ctx->perf_logger->print_timings(true);\n    }\n}\n\nstatic int ggml_vk_get_device_count() {\n    ggml_vk_instance_init();\n\n    return vk_instance.device_indices.size();\n}\n\nstatic void ggml_vk_get_device_description(int device, char * description, size_t description_size) 
{\n    ggml_vk_instance_init();\n\n    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();\n\n    vk::PhysicalDeviceProperties props;\n    devices[device].getProperties(&props);\n\n    snprintf(description, description_size, \"%s\", props.deviceName.data());\n}\n\n// backend interface\n\n#define UNUSED GGML_UNUSED\n\n// device backend\n\nstatic bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {\n    return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;\n}\n\nstatic void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    VK_LOG_MEMORY(\"ggml_backend_vk_buffer_free_buffer()\");\n    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;\n    ggml_vk_destroy_buffer(ctx->dev_buffer);\n    delete ctx;\n}\n\nstatic void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {\n    return vk_ptr_base;\n\n    UNUSED(buffer);\n}\n\nstatic enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_buffer_init_tensor(\" << buffer << \" (\" << buffer->context << \"), \" << tensor << \")\");\n    if (tensor->view_src != nullptr) {\n        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);\n    }\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_buffer_memset_tensor(\" << buffer << \", \" << tensor << \", \" << value << \", \" << offset << \", \" << size << \")\");\n    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;\n    vk_buffer buf = buf_ctx->dev_buffer;\n\n    if (size == 0) {\n        return;\n    }\n\n    uint32_t val32 = (uint32_t)value * 0x01010101;\n    ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + 
offset, val32, size);\n}\n\nstatic void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_buffer_set_tensor(\" << buffer << \", \" << tensor << \", \" << data << \", \" << offset << \", \" << size << \")\");\n    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;\n    vk_buffer buf = buf_ctx->dev_buffer;\n\n    if (size == 0) {\n        return;\n    }\n\n    ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);\n}\n\nstatic void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_buffer_get_tensor(\" << buffer << \", \" << tensor << \", \" << data << \", \" << offset << \", \" << size << \")\");\n    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;\n\n    if (size == 0) {\n        return;\n    }\n\n    vk_buffer buf = buf_ctx->dev_buffer;\n\n    ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);\n}\n\nstatic bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {\n    if (ggml_nbytes(src) == 0) {\n        return true;\n    }\n\n    if (ggml_backend_buffer_is_vk(src->buffer)) {\n        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;\n        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;\n\n        vk_buffer src_buf = src_buf_ctx->dev_buffer;\n        vk_buffer dst_buf = dst_buf_ctx->dev_buffer;\n\n        ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));\n\n        return true;\n    }\n    return false;\n\n    
UNUSED(buffer);\n}\n\nstatic void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;\n\n    ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);\n}\n\nstatic ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_vk_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_vk_buffer_get_base,\n    /* .init_tensor     = */ ggml_backend_vk_buffer_init_tensor,\n    /* .memset_tensor   = */ ggml_backend_vk_buffer_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_vk_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_vk_buffer_get_tensor,\n    /* .cpy_tensor      = */ ggml_backend_vk_buffer_cpy_tensor,\n    /* .clear           = */ ggml_backend_vk_buffer_clear,\n    /* .reset           = */ NULL,\n};\n\n// vk buffer type\nstatic const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;\n\n    return ctx->name.c_str();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    VK_LOG_MEMORY(\"ggml_backend_vk_buffer_type_alloc_buffer(\" << size << \")\");\n    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;\n\n    vk_buffer dev_buffer = nullptr;\n    try {\n        dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);\n    } catch (const vk::SystemError& e) {\n        return nullptr;\n    }\n\n    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);\n\n    return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);\n}\n\nstatic size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    
ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;\n    return ctx->device->properties.limits.minStorageBufferOffsetAlignment;\n}\n\nstatic size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {\n    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;\n    return ctx->device->suballocation_block_size;\n}\n\nstatic size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {\n    return ggml_nbytes(tensor);\n\n    UNUSED(buft);\n}\n\nggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {\n    ggml_vk_instance_init();\n\n    VK_LOG_DEBUG(\"ggml_backend_vk_buffer_type(\" << dev_num << \")\");\n\n    vk_device dev = ggml_vk_get_device(dev_num);\n\n    return &dev->buffer_type;\n}\n\n// host buffer type\n\nstatic const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) {\n    return GGML_VK_NAME \"_Host\";\n\n    UNUSED(buft);\n}\n\nstatic const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {\n    return GGML_VK_NAME \"_Host\";\n\n    UNUSED(buffer);\n}\n\nstatic void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    VK_LOG_MEMORY(\"ggml_backend_vk_host_buffer_free_buffer()\");\n    ggml_vk_host_free(vk_instance.devices[0], buffer->context);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    VK_LOG_MEMORY(\"ggml_backend_vk_host_buffer_type_alloc_buffer(\" << size << \")\");\n\n    size += 32;  // Behave like the CPU buffer type\n    void * ptr = nullptr;\n    try {\n        ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);\n    } catch (vk::SystemError& e) {\n        GGML_LOG_WARN(\"ggml_vulkan: Failed to allocate pinned memory (%s)\\n\", e.what());\n        // fallback to cpu buffer\n        return 
ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);\n    }\n\n    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);\n    buffer->buft = buft;\n    buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;\n\n    return buffer;\n\n    UNUSED(buft);\n}\n\nstatic size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;\n\n    UNUSED(buft);\n}\n\nstatic size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {\n    return vk_instance.devices[0]->suballocation_block_size;\n\n    UNUSED(buft);\n}\n\n// Should be changed to return device-specific host buffer type\n// but that probably requires changes in llama.cpp\nggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {\n    static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {\n        /* .iface    = */ {\n            /* .get_name         = */ ggml_backend_vk_host_buffer_type_name,\n            /* .alloc_buffer     = */ ggml_backend_vk_host_buffer_type_alloc_buffer,\n            /* .get_alignment    = */ ggml_backend_vk_host_buffer_type_get_alignment,\n            /* .get_max_size     = */ ggml_backend_vk_host_buffer_type_get_max_size,\n            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,\n            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,\n        },\n        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),\n        /* .context  = */ nullptr,\n    };\n\n    // Make sure device 0 is initialized\n    ggml_vk_instance_init();\n    ggml_vk_get_device(0);\n\n    return &ggml_backend_vk_buffer_type_host;\n}\n\n\n// backend\n\nstatic const char * ggml_backend_vk_name(ggml_backend_t backend) {\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n\n    return 
ctx->name.c_str();\n}\n\nstatic void ggml_backend_vk_free(ggml_backend_t backend) {\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n    VK_LOG_DEBUG(\"ggml_backend_vk_free(\" << ctx->name << \")\");\n\n    ggml_vk_cleanup(ctx);\n\n    delete ctx;\n    delete backend;\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n\n    return &ctx->device->buffer_type;\n}\n\nstatic void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_set_tensor_async(\" << size << \")\");\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && \"unsupported buffer type\");\n\n    if (size == 0) {\n        return;\n    }\n\n    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;\n\n    vk_context cpy_ctx;\n\n    if (ctx->device->async_use_transfer_queue) {\n        if (ctx->transfer_ctx.expired()) {\n            // Initialize new transfer context\n            cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);\n            ctx->transfer_ctx = cpy_ctx;\n            ggml_vk_ctx_begin(ctx->device, cpy_ctx);\n        } else {\n            cpy_ctx = ctx->transfer_ctx.lock();\n        }\n    } else {\n        cpy_ctx = ggml_vk_get_compute_ctx(ctx);\n    }\n\n    vk_buffer buf = buf_ctx->dev_buffer;\n\n    auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;\n\n    bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size);\n\n    if (!ret) {\n        ggml_vk_ensure_sync_staging_buffer(ctx, size);\n        ggml_vk_sync_buffers(nullptr, 
cpy_ctx);\n\n        vk::BufferCopy buffer_cpy;\n        buffer_cpy.srcOffset = 0;\n        buffer_cpy.dstOffset = dst_offset;\n        buffer_cpy.size = size;\n\n        cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });\n        deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);\n        ggml_vk_synchronize(ctx);\n    }\n}\n\nstatic void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_get_tensor_async(\" << size << \")\");\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && \"unsupported buffer type\");\n\n    if (size == 0) {\n        return;\n    }\n\n    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;\n\n    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);\n\n    vk_buffer buf = buf_ctx->dev_buffer;\n\n    auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;\n    bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);\n\n    // If that failed, copy synchronously through a staging buffer\n    if (!ret) {\n        ggml_vk_ensure_sync_staging_buffer(ctx, size);\n        ggml_vk_sync_buffers(nullptr, compute_ctx);\n\n        vk::BufferCopy buffer_cpy;\n        buffer_cpy.srcOffset = src_offset;\n        buffer_cpy.dstOffset = 0;\n        buffer_cpy.size = size;\n\n        compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });\n        deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);\n        ggml_vk_synchronize(ctx);\n    }\n}\n\nstatic bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, 
const ggml_tensor * src, ggml_tensor * dst) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_cpy_tensor_async(\" << src << \" -> \" << dst << \", size=\" << ggml_nbytes(src) << \")\");\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;\n\n    // Skip zero-size tensors\n    if (ggml_nbytes(src) == 0) {\n        return true;\n    }\n\n    if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {\n        return false;\n    }\n\n    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;\n    vk_buffer dst_buf = dst_buf_ctx->dev_buffer;\n\n    if (ggml_backend_buffer_is_vk(src->buffer)) {\n        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;\n\n        // Async copy only works within the same device\n        if (src_buf_ctx->dev_buffer->device != dst_buf->device) {\n            return false;\n        }\n\n        vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);\n\n        ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs,\n                                   src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs,\n                                   ggml_nbytes(src));\n        return true;\n    }\n\n    if (ggml_backend_buffer_is_host(src->buffer)) {\n        vk_buffer pinned_buf = nullptr;\n        size_t pinned_offset = 0;\n        ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset);\n        if (pinned_buf == nullptr) {\n            return false;\n        }\n\n        vk_context cpy_ctx;\n        if (ctx->device->async_use_transfer_queue) {\n            if (ctx->transfer_ctx.expired()) {\n                cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);\n                ctx->transfer_ctx = cpy_ctx;\n                ggml_vk_ctx_begin(ctx->device, cpy_ctx);\n            } else {\n                cpy_ctx = ctx->transfer_ctx.lock();\n         
   }\n        } else {\n            cpy_ctx = ggml_vk_get_compute_ctx(ctx);\n        }\n\n        return ggml_vk_buffer_write_async(cpy_ctx, dst_buf,\n                                          vk_tensor_offset(dst) + dst->view_offs,\n                                          src->data, ggml_nbytes(src));\n    }\n\n    GGML_UNUSED(backend_src);\n    return false;\n}\n\nstatic void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {\n    VK_LOG_DEBUG(\"ggml_vk_synchronize()\");\n\n    bool do_transfer = !ctx->compute_ctx.expired();\n\n    if (ggml_vk_submit_transfer_ctx(ctx)) {\n        ctx->submit_pending = true;\n    }\n\n    vk_context compute_ctx;\n    vk_command_buffer* cmd_buf = nullptr;\n    if (do_transfer) {\n        compute_ctx = ctx->compute_ctx.lock();\n        if (compute_ctx->s) {\n            cmd_buf = compute_ctx->s->buffer;\n        }\n\n        ggml_vk_ctx_end(compute_ctx);\n\n        for (auto& cpy : compute_ctx->in_memcpys) {\n            memcpy(cpy.dst, cpy.src, cpy.n);\n        }\n\n        ggml_vk_submit(compute_ctx, {});\n        ctx->submit_pending = true;\n    }\n\n    if (ctx->submit_pending) {\n        if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {\n            vk::TimelineSemaphoreSubmitInfo tl_info{\n                1, &ctx->transfer_semaphore.value,\n                0, nullptr,\n            };\n            vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags;\n            vk::SubmitInfo si{\n                1, &ctx->transfer_semaphore.s, &stage,\n                0, nullptr,\n                0, nullptr,\n            };\n            si.setPNext(&tl_info);\n            std::lock_guard<std::mutex> guard(queue_mutex);\n            ctx->device->compute_queue.queue.submit({ si }, ctx->fence);\n            ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;\n        } else {\n            std::lock_guard<std::mutex> 
guard(queue_mutex);\n            ctx->device->compute_queue.queue.submit({}, ctx->fence);\n        }\n        ggml_vk_wait_for_fence(ctx);\n        ctx->submit_pending = false;\n        if (cmd_buf) {\n            cmd_buf->in_use = false;\n        }\n    }\n\n    if (do_transfer) {\n        for (auto& cpy : compute_ctx->out_memcpys) {\n            memcpy(cpy.dst, cpy.src, cpy.n);\n        }\n        ctx->compute_ctx.reset();\n    }\n}\n\nstatic void ggml_backend_vk_synchronize(ggml_backend_t backend) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_synchronize()\");\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n\n    ggml_vk_synchronize(ctx);\n\n    ggml_vk_graph_cleanup(ctx);\n}\n\nstatic bool ggml_vk_is_empty(ggml_tensor * node) {\n    return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;\n}\n\nstatic bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {\n    if (!ggml_can_fuse(cgraph, node_idx, ops)) {\n        return false;\n    }\n\n    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {\n        // additional constraints specific to this fusion\n        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];\n        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];\n\n        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);\n        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);\n        // rms_norm only supports f32\n        if (mul->src[0]->type != GGML_TYPE_F32 ||\n            mul->src[1]->type != GGML_TYPE_F32 ||\n            mul->type != GGML_TYPE_F32) {\n            return false;\n        }\n        // if rms_norm is the B operand, then we don't handle broadcast\n        if (rms_norm == mul->src[1] &&\n            !ggml_are_same_shape(mul->src[0], rms_norm)) 
{\n            return false;\n        }\n        // rms_norm shader assumes contiguous rows\n        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {\n            return false;\n        }\n    }\n    auto const &mm_add_ok = [&](const ggml_tensor *mul, const ggml_tensor *add) {\n        const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0];\n\n        // mat-vec only\n        if (ggml_nrows(mul) != 1) {\n            return false;\n        }\n        // shaders assume the types match\n        if (mul->type != bias->type) {\n            return false;\n        }\n        // shaders reuse the D shape for bias\n        if (!ggml_are_same_shape(mul, bias) ||\n            !ggml_are_same_stride(mul, bias)) {\n            return false;\n        }\n        // unaligned bias isn't handled\n        if (get_misalign_bytes(ctx, bias) != 0) {\n            return false;\n        }\n        return true;\n    };\n\n    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {\n        // additional constraints specific to this fusion\n        const ggml_tensor *mul = cgraph->nodes[node_idx];\n        const ggml_tensor *add = cgraph->nodes[node_idx + 1];\n\n        if (!mm_add_ok(mul, add)) {\n            return false;\n        }\n        if (ops.size() == 3) {\n            if (ops.begin()[2] != GGML_OP_ADD) {\n                return false;\n            }\n            if (!mm_add_ok(add, cgraph->nodes[node_idx + 2])) {\n                return false;\n            }\n        }\n    }\n\n    auto const &mmid_mul_ok = [&](const ggml_tensor *mmid, const ggml_tensor *mul) {\n        const ggml_tensor *scale = mul->src[1];\n\n        if (mmid != mul->src[0]) {\n            return false;\n        }\n        // mat-vec only\n        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {\n            return false;\n        }\n        // shaders assume the types match\n        if 
(mmid->type != scale->type) {\n            return false;\n        }\n        // shaders assume the bias is contiguous\n        if (!ggml_is_contiguous(scale)) {\n            return false;\n        }\n        // unaligned bias isn't handled\n        if (get_misalign_bytes(ctx, scale) != 0) {\n            return false;\n        }\n        // shader only indexes by expert index\n        if (scale->ne[0] != 1 ||\n            scale->ne[1] != mul->ne[1] ||\n            scale->ne[2] != 1 ||\n            scale->ne[3] != 1) {\n            return false;\n        }\n        return true;\n    };\n\n    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {\n        // additional constraints specific to this fusion\n        const ggml_tensor *mul = cgraph->nodes[node_idx];\n        const ggml_tensor *add = cgraph->nodes[node_idx + 1];\n        const ggml_tensor *bias = add->src[1];\n\n        if (mul != add->src[0]) {\n            return false;\n        }\n        // mat-vec only\n        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {\n            return false;\n        }\n        // shaders assume the types match\n        if (mul->type != bias->type) {\n            return false;\n        }\n        // shaders assume the bias is contiguous\n        if (!ggml_is_contiguous(bias)) {\n            return false;\n        }\n        // the ID tensor must be the same for mul_mat_id and add_id\n        if (mul->src[2] != add->src[2]) {\n            return false;\n        }\n        // unaligned bias isn't handled\n        if (get_misalign_bytes(ctx, bias) != 0) {\n            return false;\n        }\n\n        if (ops.size() == 3) {\n            if (ops.begin()[2] != GGML_OP_MUL) {\n                return false;\n            }\n            const ggml_tensor *mul = cgraph->nodes[node_idx + 2];\n            return mmid_mul_ok(add, mul);\n        }\n    }\n\n    if (ops.size() == 2 && ops.begin()[0] == 
GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {\n        // additional constraints specific to this fusion\n        const ggml_tensor *mmid = cgraph->nodes[node_idx];\n        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];\n\n        if (!mmid_mul_ok(mmid, mul)) {\n            return false;\n        }\n    }\n\n    return true;\n}\n\nstatic bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,\n                                      int node_idx, topk_moe_mode mode) {\n\n    const ggml_tensor * softmax;\n    const ggml_tensor * weights;\n    const ggml_tensor * get_rows;\n    const ggml_tensor * argsort;\n\n    switch (mode) {\n    case TOPK_MOE_EARLY_SOFTMAX_NORM:\n        softmax = cgraph->nodes[node_idx + 0];\n        weights = cgraph->nodes[node_idx + 9];\n        get_rows = cgraph->nodes[node_idx + 4];\n        argsort = cgraph->nodes[node_idx + 2];\n        break;\n    case TOPK_MOE_SIGMOID_NORM_BIAS:\n        softmax = cgraph->nodes[node_idx + 0]; // really sigmoid\n        weights = cgraph->nodes[node_idx + 10];\n        get_rows = cgraph->nodes[node_idx + 5];\n        argsort = cgraph->nodes[node_idx + 3];\n        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {\n            return false;\n        }\n        // bias is expected to be 1D\n        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||\n            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {\n            return false;\n        }\n        // sigmoid fusion seems to generate infinities on moltenvk\n        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {\n            return false;\n        }\n        break;\n    case TOPK_MOE_EARLY_SOFTMAX:\n        softmax = cgraph->nodes[node_idx + 0];\n        weights = cgraph->nodes[node_idx + 4];\n        get_rows = cgraph->nodes[node_idx + 4];\n        argsort = cgraph->nodes[node_idx + 2];\n        break;\n    case TOPK_MOE_LATE_SOFTMAX:\n        softmax = 
cgraph->nodes[node_idx + 4];\n        weights = cgraph->nodes[node_idx + 5];\n        get_rows = cgraph->nodes[node_idx + 2];\n        argsort = cgraph->nodes[node_idx + 0];\n        break;\n    default:\n        return false;\n    }\n\n    ggml_tensor * probs = get_rows->src[0];\n    if (probs->op != GGML_OP_RESHAPE) {\n        return false;\n    }\n    probs = probs->src[0];\n    ggml_tensor * selection_probs = argsort->src[0];\n\n    if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {\n        return false;\n    }\n\n    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {\n        return false;\n    }\n\n    if (softmax->op == GGML_OP_SOFT_MAX) {\n        const float * op_params = (const float *)softmax->op_params;\n\n        float scale = op_params[0];\n        float max_bias = op_params[1];\n\n        if (scale != 1.0f || max_bias != 0.0f) {\n            return false;\n        }\n\n        // don't fuse when masks or sinks are present\n        if (softmax->src[1] || softmax->src[2]) {\n            return false;\n        }\n    }\n\n    const int n_expert = softmax->ne[0];\n    if (n_expert > (1 << (num_topk_moe_pipelines-1))) {\n        return false;\n    }\n\n    if (!ctx->device->subgroup_arithmetic ||\n        !ctx->device->subgroup_shuffle ||\n        !ctx->device->subgroup_require_full_support ||\n        ctx->device->disable_fusion) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,\n                                           int node_idx) {\n    GGML_UNUSED(ctx);\n    const ggml_tensor *rope = cgraph->nodes[node_idx + 0];\n    const ggml_tensor *view = cgraph->nodes[node_idx + 1];\n    const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2];\n\n    // ne3 not tested\n    if (rope->src[0]->ne[3] != 1) {\n        return false;\n    }\n\n    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != 
GGML_TYPE_F16) {\n        return false;\n    }\n\n    if (set_rows->src[1]->type != GGML_TYPE_I64) {\n        return false;\n    }\n\n    // The view should flatten two dims of rope into one dim\n    if (!ggml_is_contiguous(view) ||\n        view->ne[0] != rope->ne[0] * rope->ne[1]) {\n        return false;\n    }\n\n    // Only norm/neox/mrope shaders have the fusion code\n    const int mode = ((const int32_t *) rope->op_params)[2];\n    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {\n        return false;\n    }\n\n    return true;\n}\n\n// Check whether the tensors overlap in memory.\n// Fusions can potentially overwrite src tensors in ways that are not prevented\n// by ggml-alloc. If the fusion src is being applied in a way that's elementwise\n// with the destination, then it's OK for them to overlap if they are exactly equal.\nstatic bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) {\n    ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;\n    vk_buffer a_buf = a_buf_ctx->dev_buffer;\n    ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;\n    vk_buffer b_buf = b_buf_ctx->dev_buffer;\n    if (a_buf == b_buf) {\n        auto a_base = vk_tensor_offset(a) + a->view_offs;\n        auto a_size = ggml_nbytes(a);\n        auto b_base = vk_tensor_offset(b) + b->view_offs;\n        auto b_size = ggml_nbytes(b);\n\n        if (elementwise && a_base == b_base && a_size == b_size) {\n            return false;\n        }\n\n        if ((b_base <= a_base && a_base < b_base + b_size) ||\n            (a_base <= b_base && b_base < a_base + a_size)) {\n            return true;\n        }\n    }\n    return false;\n}\n\nstatic bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,\n                                               int 
node_idx) {\n    GGML_UNUSED(ctx);\n    const ggml_tensor *rms = cgraph->nodes[node_idx + 0];\n    const ggml_tensor *mul = cgraph->nodes[node_idx + 1];\n    const ggml_tensor *rope = cgraph->nodes[node_idx + 2];\n\n    const int mode = ((const int32_t *) rope->op_params)[2];\n\n    // noncontig tensors aren't tested, and don't seem common in practice\n    if (!ggml_is_contiguous(rms) ||\n        !ggml_is_contiguous(mul) ||\n        !ggml_is_contiguous(rope)) {\n        return false;\n    }\n\n    // only norm/neox are handled in the shader\n    if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {\n        return false;\n    }\n\n    // shared memory size for passing data from mul->rope\n    if (mul->ne[0] > 1024) {\n        return false;\n    }\n\n    // conditions for pipeline creation\n    if (!(ctx->device->float_controls_rte_fp16 &&\n        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {\n        return false;\n    }\n\n    return true;\n}\n\nstatic uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {\n\n    const ggml_tensor *first_node = cgraph->nodes[node_idx];\n    if (first_node->op != GGML_OP_ADD) {\n        return 0;\n    }\n\n    if (!ctx->device->multi_add) {\n        return 0;\n    }\n\n    int32_t num_adds = 1;\n    while (node_idx + num_adds < cgraph->n_nodes &&\n           cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&\n           num_adds < MAX_FUSED_ADDS) {\n        num_adds++;\n    }\n\n    // The shader currently requires same shapes (but different strides are allowed),\n    // everything f32, and no misalignment\n    for (int32_t i = 0; i < num_adds; ++i) {\n        const ggml_tensor *next_node = cgraph->nodes[node_idx + i];\n        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||\n            !ggml_are_same_shape(first_node, next_node->src[1]) ||\n            next_node->type != 
GGML_TYPE_F32 ||\n            next_node->src[0]->type != GGML_TYPE_F32 ||\n            next_node->src[1]->type != GGML_TYPE_F32 ||\n            get_misalign_bytes(ctx, next_node) ||\n            get_misalign_bytes(ctx, next_node->src[0]) ||\n            get_misalign_bytes(ctx, next_node->src[1])) {\n            num_adds = i;\n        }\n    }\n\n    // Verify we can fuse these\n    ggml_op adds[MAX_FUSED_ADDS];\n    for (int32_t i = 0; i < num_adds; ++i) {\n        adds[i] = GGML_OP_ADD;\n    }\n\n    // decrease num_adds if they can't all be fused\n    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {\n        num_adds--;\n    }\n\n    // a single add is not \"fused\", so just return zero\n    if (num_adds == 1) {\n        return 0;\n    }\n    return num_adds;\n}\n\nstatic int32_t find_first_set(uint32_t x) {\n    int32_t ret = 0;\n    if (!x) {\n        return -1;\n    }\n    while (!(x & 1)) {\n        x >>= 1;\n        ret++;\n    }\n    return ret;\n}\n\nstatic ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_graph_compute(\" << cgraph->n_nodes << \" nodes)\");\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n\n    if (vk_instance.debug_utils_support) {\n        vk::DebugUtilsLabelEXT dul = {};\n        dul.pLabelName = \"ggml_backend_vk_graph_compute\";\n        dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};\n        vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));\n    }\n\n    ctx->prealloc_size_add_rms_partials_offset = 0;\n    ctx->do_add_rms_partials = false;\n    ctx->do_add_rms_partials_offset_calculation = false;\n\n    int last_node = cgraph->n_nodes - 1;\n\n    // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly\n    while (last_node > 0 && 
(ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {\n        last_node -= 1;\n    }\n\n    // Reserve tensor context space for all nodes\n    ctx->tensor_ctxs.resize(cgraph->n_nodes);\n\n    bool first_node_in_batch = true; // true if next node will be first node in a batch\n    int submit_node_idx = 0; // index to first node in a batch\n\n    ggml_vk_submit_transfer_ctx(ctx);\n\n    vk_context compute_ctx;\n    if (vk_perf_logger_enabled) {\n        // allocate/resize the query pool\n        if (ctx->num_queries < cgraph->n_nodes + 1) {\n            if (ctx->query_pool) {\n                ctx->device->device.destroyQueryPool(ctx->query_pool);\n            }\n            vk::QueryPoolCreateInfo query_create_info;\n            query_create_info.queryType = vk::QueryType::eTimestamp;\n            query_create_info.queryCount = cgraph->n_nodes + 100;\n            ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);\n            ctx->num_queries = query_create_info.queryCount;\n            ctx->query_fusion_names.resize(ctx->num_queries);\n            ctx->query_fusion_node_count.resize(ctx->num_queries);\n            ctx->query_nodes.resize(ctx->num_queries);\n            ctx->query_node_idx.resize(ctx->num_queries);\n        }\n\n        ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);\n        std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);\n        std::fill(ctx->query_fusion_node_count.begin(), ctx->query_fusion_node_count.end(), 0);\n        std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);\n        std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);\n\n        GGML_ASSERT(ctx->compute_ctx.expired());\n        compute_ctx = ggml_vk_get_compute_ctx(ctx);\n        ctx->query_idx = 0;\n        compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, 
ctx->query_pool, ctx->query_idx++);\n    }\n\n    ctx->prealloc_y_last_pipeline_used = nullptr;\n    ctx->prealloc_y_last_tensor_used = nullptr;\n\n    if (ctx->prealloc_size_add_rms_partials) {\n        ggml_vk_preallocate_buffers(ctx, nullptr);\n        compute_ctx = ggml_vk_get_compute_ctx(ctx);\n        // initialize partial sums to zero.\n        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);\n        ggml_vk_sync_buffers(ctx, compute_ctx);\n    }\n\n    // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.\n    // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB\n    // (and scaled down based on model size, so smaller models submit earlier).\n    // Also submit at least every 100 nodes, in case there are workloads without as much matmul.\n    int nodes_per_submit = 100;\n    int submitted_nodes = 0;\n    int submit_count = 0;\n    uint64_t mul_mat_bytes = 0;\n    uint64_t total_mul_mat_bytes = 0;\n    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        if (first_node_in_batch) {\n            submit_node_idx = i;\n        }\n\n        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {\n            auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);\n            mul_mat_bytes += bytes;\n            total_mul_mat_bytes += bytes;\n        }\n\n        // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to\n        // the fused result in an elementwise-way. 
This affects whether the memory for\n        // the src is allowed to overlap the memory for the destination.\n        // The array is sized to handle the largest fusion (asserted later).\n        bool op_srcs_fused_elementwise[12];\n\n        ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;\n        ctx->fused_topk_moe_scale = false;\n        const char *fusion_string {};\n        if (!ctx->device->disable_fusion) {\n            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);\n            if (num_adds) {\n                ctx->num_additional_fused_ops = num_adds - 1;\n                fusion_string = \"MULTI_ADD\";\n                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true);\n            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {\n                ctx->num_additional_fused_ops = 2;\n                fusion_string = \"MUL_MAT_ADD_ADD\";\n                op_srcs_fused_elementwise[0] = false;\n                op_srcs_fused_elementwise[1] = true;\n                op_srcs_fused_elementwise[2] = true;\n            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {\n                ctx->num_additional_fused_ops = 1;\n                fusion_string = \"MUL_MAT_ADD\";\n                op_srcs_fused_elementwise[0] = false;\n                op_srcs_fused_elementwise[1] = true;\n            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {\n                ctx->num_additional_fused_ops = 2;\n                fusion_string = \"MUL_MAT_ID_ADD_ID_MUL\";\n                op_srcs_fused_elementwise[0] = false;\n                op_srcs_fused_elementwise[1] = true;\n                op_srcs_fused_elementwise[2] = true;\n            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {\n                ctx->num_additional_fused_ops = 1;\n                fusion_string = 
\"MUL_MAT_ID_ADD_ID\";\n                op_srcs_fused_elementwise[0] = false;\n                op_srcs_fused_elementwise[1] = true;\n            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {\n                ctx->num_additional_fused_ops = 1;\n                fusion_string = \"MUL_MAT_ID_MUL\";\n                op_srcs_fused_elementwise[0] = false;\n                op_srcs_fused_elementwise[1] = true;\n            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&\n                       ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&\n                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&\n                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {\n                ctx->num_additional_fused_ops = 4;\n                fusion_string = \"RMS_NORM_MUL_ROPE_VIEW_SET_ROWS\";\n                op_srcs_fused_elementwise[0] = false;\n                op_srcs_fused_elementwise[1] = false;\n                op_srcs_fused_elementwise[2] = false;\n                op_srcs_fused_elementwise[3] = false;\n                op_srcs_fused_elementwise[4] = false;\n            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&\n                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {\n                ctx->num_additional_fused_ops = 2;\n                fusion_string = \"RMS_NORM_MUL_ROPE\";\n                // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise\n                op_srcs_fused_elementwise[0] = false;\n                op_srcs_fused_elementwise[1] = true;\n                op_srcs_fused_elementwise[2] = true;\n            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {\n                ctx->num_additional_fused_ops = 1;\n                fusion_string = 
\"RMS_NORM_MUL\";\n                // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before\n                // they are overwritten, and one workgroup per row. So close enough.\n                op_srcs_fused_elementwise[0] = true;\n                op_srcs_fused_elementwise[1] = true;\n            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&\n                       ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&\n                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {\n                ctx->num_additional_fused_ops = 2;\n                fusion_string = \"ROPE_VIEW_SET_ROWS\";\n                op_srcs_fused_elementwise[0] = false;\n                op_srcs_fused_elementwise[1] = false;\n                op_srcs_fused_elementwise[2] = false;\n            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&\n                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&\n                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {\n                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;\n                // view of argsort writes to memory\n                ctx->fused_ops_write_mask |= 1 << 3;\n                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;\n                fusion_string = \"TOPK_MOE_EARLY_SOFTMAX_NORM\";\n                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);\n            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&\n                       ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&\n                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {\n                ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;\n      
          // view of argsort writes to memory\n                ctx->fused_ops_write_mask |= 1 << 4;\n                ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;\n                fusion_string = \"TOPK_MOE_SIGMOID_NORM_BIAS\";\n                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);\n            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&\n                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&\n                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {\n                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;\n                // view of argsort writes to memory\n                ctx->fused_ops_write_mask |= 1 << 3;\n                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;\n                fusion_string = \"TOPK_MOE_EARLY_SOFTMAX\";\n                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);\n            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&\n                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&\n                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {\n                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;\n                // view of argsort writes to memory\n                ctx->fused_ops_write_mask |= 1 << 1;\n                ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;\n                fusion_string = \"TOPK_MOE_LATE_SOFTMAX\";\n                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);\n            }\n            if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {\n                // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.\n                if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 
1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||\n                    ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {\n                    ctx->fused_topk_moe_scale = true;\n                    ctx->num_additional_fused_ops++;\n                    op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false;\n                }\n            }\n        }\n        GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0])));\n        ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;\n\n        // Check whether fusion would overwrite src operands while they're still in use.\n        // If so, disable fusion.\n        if (ctx->num_additional_fused_ops) {\n            // There are up to two output nodes - topk_moe has two.\n            uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops);\n            ggml_tensor *output_nodes[2] {};\n            output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops];\n            if (bits) {\n                int output_idx = find_first_set(bits);\n                GGML_ASSERT(bits == (1u << output_idx));\n                output_nodes[1] = cgraph->nodes[i + output_idx];\n            }\n\n            bool need_disable = false;\n\n            // topk_moe often overwrites the source, but for a given row all the src values are\n            // loaded before anything is stored. 
If there's only one row, this is safe, so treat\n            // this as a special case.\n            bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT &&\n                                          ggml_nrows(cgraph->nodes[i]->src[0]) == 1;\n\n            if (!is_topk_moe_single_row) {\n                for (int j = 0; j < 2; ++j) {\n                    ggml_tensor *dst = output_nodes[j];\n                    if (!dst) {\n                        continue;\n                    }\n                    // Loop over all srcs of all nodes in the fusion. If the src overlaps\n                    // the destination and the src is not an intermediate node that's being\n                    // elided, then disable fusion.\n                    for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) {\n                        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {\n                            ggml_tensor *src = cgraph->nodes[i + k]->src[s];\n                            if (!src || src->op == GGML_OP_NONE) {\n                                continue;\n                            }\n                            if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) {\n                                bool found = false;\n                                for (int n = 0; n < k; ++n) {\n                                    if (cgraph->nodes[i + n] == src) {\n                                        found = true;\n                                        break;\n                                    }\n                                }\n                                if (!found) {\n                                    need_disable = true;\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n            if (need_disable) {\n                ctx->num_additional_fused_ops = 0;\n                ctx->fused_ops_write_mask = 1;\n                
ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;\n                ctx->fused_topk_moe_scale = false;\n            }\n        }\n\n        // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)\n        bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;\n        bool submit = (submitted_nodes >= nodes_per_submit) ||\n                      (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||\n                      (i + ctx->num_additional_fused_ops >= last_node) ||\n                      (almost_ready && !ctx->almost_ready_fence_pending);\n\n        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);\n\n        if (vk_perf_logger_enabled && enqueued) {\n            compute_ctx = ggml_vk_get_compute_ctx(ctx);\n            if (!vk_perf_logger_concurrent) {\n                // track a single node/fusion for the current query\n                ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];\n                ctx->query_fusion_names[ctx->query_idx] = fusion_string;\n                compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);\n            } else {\n                // track a fusion string and number of fused ops for the current node_idx\n                ctx->query_fusion_names[i] = fusion_string;\n                ctx->query_fusion_node_count[i] = ctx->num_additional_fused_ops;\n            }\n        }\n\n        if (enqueued) {\n            ++submitted_nodes;\n\n#ifndef GGML_VULKAN_CHECK_RESULTS\n            if (first_node_in_batch) {\n                first_node_in_batch = false;\n            }\n#endif\n        }\n\n        if (submit && enqueued) {\n            first_node_in_batch = true;\n            submitted_nodes = 0;\n            mul_mat_bytes = 0;\n            if (submit_count < 3) {\n                
mul_mat_bytes_per_submit *= 2;\n            }\n            submit_count++;\n        }\n        i += ctx->num_additional_fused_ops;\n        ctx->num_additional_fused_ops = 0;\n        ctx->fused_ops_write_mask = 0;\n    }\n\n    ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;\n\n    if (vk_perf_logger_enabled) {\n        // End the command buffer and submit/wait\n        GGML_ASSERT(!ctx->compute_ctx.expired());\n        compute_ctx = ctx->compute_ctx.lock();\n        ggml_vk_ctx_end(compute_ctx);\n\n        ggml_vk_submit(compute_ctx, ctx->device->fence);\n        VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), \"GGML_VULKAN_PERF waitForFences\");\n        ctx->device->device.resetFences({ ctx->device->fence });\n        ctx->compute_ctx.reset();\n\n        // Get the results and pass them to the logger\n        std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);\n        VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), \"get timestamp results\");\n        if (!vk_perf_logger_concurrent) {\n            // Log each op separately\n            for (int i = 1; i < ctx->query_idx; i++) {\n                auto node = ctx->query_nodes[i];\n                auto name = ctx->query_fusion_names[i];\n                ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));\n            }\n        } else {\n            // Log each group of nodes\n            int prev_node_idx = 0;\n            for (int i = 1; i < ctx->query_idx; i++) {\n                auto cur_node_idx = ctx->query_node_idx[i];\n                std::vector<ggml_tensor *> nodes;\n                std::vector<const char *> names;\n                for (int node_idx = prev_node_idx; node_idx < cur_node_idx; ++node_idx) 
{\n                    if (ggml_op_is_empty(cgraph->nodes[node_idx]->op)) {\n                        continue;\n                    }\n                    nodes.push_back(cgraph->nodes[node_idx]);\n                    names.push_back(ctx->query_fusion_names[node_idx]);\n                    node_idx += ctx->query_fusion_node_count[node_idx];\n                }\n                prev_node_idx = cur_node_idx;\n                ctx->perf_logger->log_timing(nodes, names, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));\n            }\n        }\n        ctx->perf_logger->print_timings();\n    }\n\n    if (!ctx->device->support_async) {\n        ggml_vk_synchronize(ctx);\n    }\n\n    return GGML_STATUS_SUCCESS;\n\n    UNUSED(backend);\n}\n\n// Sort the graph for improved parallelism.\nstatic void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph)\n{\n    VK_LOG_DEBUG(\"ggml_vk_graph_optimize(\" << graph->n_nodes << \" nodes)\");\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n\n    if (ctx->device->disable_graph_optimize) {\n        return;\n    }\n\n    auto const &is_empty = [](ggml_tensor * node) -> bool {\n        return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;\n    };\n\n    auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {\n        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {\n            if (dst->src[s] == src) {\n                return true;\n            }\n        }\n        // implicit dependency if they view the same tensor\n        const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;\n        const ggml_tensor *src2 = src->view_src ? 
src->view_src : src;\n        if (dst2 == src2) {\n            return true;\n        }\n        return false;\n    };\n\n    std::vector<ggml_tensor *> new_order;\n    std::vector<bool> used(graph->n_nodes, false);\n    std::set<ggml_tensor *> used_node_set;\n\n    int first_unused = 0;\n    while (first_unused < graph->n_nodes) {\n        std::vector<int> current_set;\n\n        // Check for fusion patterns and avoid reordering them\n        auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {\n            if (start + (int)pattern.size() <= graph->n_nodes) {\n                bool is_pattern = true;\n                for (size_t j = 0; j < pattern.size(); ++j) {\n                    if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {\n                        is_pattern = false;\n                    }\n                }\n                return is_pattern;\n            }\n            return false;\n        };\n\n        auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {\n            if (match_pattern(pattern, first_unused)) {\n                for (size_t j = 0; j < pattern.size(); ++j) {\n                    new_order.push_back(graph->nodes[first_unused + j]);\n                    used_node_set.insert(graph->nodes[first_unused + j]);\n                    used[first_unused + j] = true;\n                }\n                while (first_unused < graph->n_nodes && used[first_unused]) {\n                    first_unused++;\n                }\n                return true;\n            }\n            return false;\n        };\n\n        if (keep_pattern(topk_moe_early_softmax_norm)) {\n            continue;\n        }\n        if (keep_pattern(topk_moe_sigmoid_norm_bias)) {\n            continue;\n        }\n        if (keep_pattern(topk_moe_early_softmax)) {\n            continue;\n        }\n        if (keep_pattern(topk_moe_late_softmax)) {\n            continue;\n   
     }\n\n        // First, grab the next unused node.\n        current_set.push_back(first_unused);\n\n        // Loop through the next N nodes. Grab any that don't depend on other nodes that\n        // haven't already been run. Nodes that have already been run have used[i] set\n        // to true. Allow nodes that depend on the previous node if it's a fusion pattern\n        // that we support (e.g. RMS_NORM + MUL).\n        // This first pass only grabs \"real\" (non-view nodes). Second pass grabs view nodes.\n        // The goal is to not interleave real and view nodes in a way that breaks fusion.\n        const int NUM_TO_CHECK = 20;\n        for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {\n            if (used[j]) {\n                continue;\n            }\n            if (is_empty(graph->nodes[j])) {\n                continue;\n            }\n            // Don't pull forward nodes from fusion patterns\n            if (match_pattern(topk_moe_early_softmax_norm, j) ||\n                match_pattern(topk_moe_sigmoid_norm_bias, j) ||\n                match_pattern(topk_moe_early_softmax, j) ||\n                match_pattern(topk_moe_late_softmax, j)) {\n                continue;\n            }\n            bool ok = true;\n            for (int c = first_unused; c < j; ++c) {\n                if (!used[c] &&\n                    is_src_of(graph->nodes[j], graph->nodes[c]) &&\n                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&\n                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&\n                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&\n                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && 
graph->nodes[j]->op == GGML_OP_MUL) &&\n                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {\n                    ok = false;\n                    break;\n                }\n            }\n            if (ok) {\n                current_set.push_back(j);\n\n                int rope_idx = j;\n\n                // When we've found RMS_NORM + MUL, try to find a ROPE that uses it\n                if (j > 0 &&\n                    graph->nodes[j]->op == GGML_OP_MUL &&\n                    graph->nodes[j-1]->op == GGML_OP_RMS_NORM) {\n                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {\n                        if (graph->nodes[k]->op == GGML_OP_ROPE &&\n                            graph->nodes[k]->src[0] == graph->nodes[j] &&\n                            // Check that other srcs are already valid\n                            graph->nodes[k]->src[1]->op == GGML_OP_NONE &&\n                            (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) {\n                            rope_idx = k;\n                            current_set.push_back(rope_idx);\n                            used[rope_idx] = true;\n                            break;\n                        }\n                    }\n                }\n                // Look for ROPE + VIEW + SET_ROWS and make them consecutive\n                if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) {\n                    int view_idx = -1;\n                    int set_rows_idx = -1;\n                    for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) {\n                        if (view_idx == -1 &&\n                            graph->nodes[k]->op == GGML_OP_VIEW &&\n                            graph->nodes[k]->src[0] == graph->nodes[rope_idx]) {\n                            view_idx = k;\n                            continue;\n                    
    }\n                        if (view_idx != -1 &&\n                            set_rows_idx == -1 &&\n                            graph->nodes[k]->op == GGML_OP_SET_ROWS &&\n                            graph->nodes[k]->src[0] == graph->nodes[view_idx]) {\n                            set_rows_idx = k;\n                            break;\n                        }\n                    }\n                    if (set_rows_idx != -1) {\n                        current_set.push_back(view_idx);\n                        current_set.push_back(set_rows_idx);\n                        used[view_idx] = true;\n                        used[set_rows_idx] = true;\n                    }\n                }\n                // Look for MUL_MAT_ID + ADD_ID + MUL\n                if (j > 0 &&\n                    graph->nodes[j]->op == GGML_OP_ADD_ID &&\n                    graph->nodes[j-1]->op == GGML_OP_MUL_MAT_ID) {\n                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {\n                        if (graph->nodes[k]->op == GGML_OP_MUL &&\n                            graph->nodes[k]->src[0] == graph->nodes[j] &&\n                            // src1 must either be weights or already processed\n                            (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {\n                            current_set.push_back(k);\n                            used[k] = true;\n                            break;\n                        }\n                    }\n                }\n                // Look for MUL_MAT + ADD + ADD\n                if (j > 0 &&\n                    graph->nodes[j]->op == GGML_OP_ADD &&\n                    graph->nodes[j-1]->op == GGML_OP_MUL_MAT) {\n                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {\n                        if (graph->nodes[k]->op == GGML_OP_ADD &&\n                            graph->nodes[k]->src[0] == 
graph->nodes[j] &&\n                            // src1 must either be weights or already processed\n                            (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {\n                            current_set.push_back(k);\n                            used[k] = true;\n                            break;\n                        }\n                    }\n                }\n            }\n        }\n        // Second pass grabs view nodes.\n        // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).\n        if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {\n            for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {\n                if (used[j]) {\n                    continue;\n                }\n                if (!is_empty(graph->nodes[j])) {\n                    continue;\n                }\n                bool ok = true;\n                for (int c = first_unused; c < j; ++c) {\n                    bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();\n                    // skip views whose srcs haven't been processed.\n                    if (!used[c] &&\n                        is_src_of(graph->nodes[j], graph->nodes[c]) &&\n                        !c_in_current_set) {\n                        ok = false;\n                        break;\n                    }\n                }\n                if (ok) {\n                    current_set.push_back(j);\n                }\n            }\n        }\n\n        // Push the current set into new_order\n        for (auto c : current_set) {\n            new_order.push_back(graph->nodes[c]);\n            used_node_set.insert(graph->nodes[c]);\n            used[c] = true;\n        }\n        while (first_unused < graph->n_nodes && used[first_unused]) {\n            first_unused++;\n        }\n    
}\n    // Replace the graph with the new order.\n    for (int i = 0; i < graph->n_nodes; ++i) {\n        graph->nodes[i] = new_order[i];\n    }\n}\n\nstatic void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_event_record(backend=\" << backend << \", event=\" << event << \")\");\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n    vk_event *vkev = (vk_event *)event->context;\n\n    ggml_vk_submit_transfer_ctx(ctx);\n\n    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);\n    auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset\n\n    // the backend interface doesn't have an explicit reset, so reset it here\n    // before we record the command to set it\n    ctx->device->device.resetEvent(vkev->event);\n    ctx->device->device.resetFences({ vkev->fence });\n\n    ggml_vk_set_event(compute_ctx, vkev->event);\n\n    ggml_vk_ctx_end(compute_ctx);\n\n    ggml_vk_submit(compute_ctx, {vkev->fence});\n    ctx->submit_pending = true;\n    vkev->cmd_buffer = cmd_buf;\n    ctx->compute_ctx.reset();\n}\n\nstatic void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_event_wait(backend=\" << backend << \", event=\" << event << \")\");\n    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;\n    vk_event *vkev = (vk_event *)event->context;\n\n    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);\n\n    ggml_vk_wait_events(compute_ctx, {vkev->event});\n    ggml_vk_ctx_end(compute_ctx);\n    ctx->compute_ctx.reset();\n}\n\n// TODO: enable async and synchronize\nstatic ggml_backend_i ggml_backend_vk_interface = {\n    /* .get_name                = */ ggml_backend_vk_name,\n    /* .free                    = */ ggml_backend_vk_free,\n    /* .set_tensor_async        = */ ggml_backend_vk_set_tensor_async,\n    /* .get_tensor_async        = */ 
ggml_backend_vk_get_tensor_async,\n    /* .cpy_tensor_async        = */ ggml_backend_vk_cpy_tensor_async,\n    /* .synchronize             = */ ggml_backend_vk_synchronize,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_vk_graph_compute,\n    /* .event_record            = */ ggml_backend_vk_event_record,\n    /* .event_wait              = */ ggml_backend_vk_event_wait,\n    /* .graph_optimize          = */ ggml_vk_graph_optimize,\n};\n\nstatic ggml_guid_t ggml_backend_vk_guid() {\n    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };\n    return &guid;\n}\n\nggml_backend_t ggml_backend_vk_init(size_t dev_num) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_init(\" << dev_num << \")\");\n\n    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;\n    ggml_vk_init(ctx, dev_num);\n\n    ggml_backend_t vk_backend = new ggml_backend {\n        /* .guid    = */ ggml_backend_vk_guid(),\n        /* .iface   = */ ggml_backend_vk_interface,\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),\n        /* .context = */ ctx,\n    };\n\n    if (!ctx->device->support_async) {\n        vk_backend->iface.get_tensor_async = nullptr;\n    }\n\n    return vk_backend;\n}\n\nbool ggml_backend_is_vk(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());\n}\n\nint ggml_backend_vk_get_device_count() {\n    return ggml_vk_get_device_count();\n}\n\nvoid ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {\n    GGML_ASSERT(device < (int) vk_instance.device_indices.size());\n    int dev_idx = vk_instance.device_indices[device];\n    ggml_vk_get_device_description(dev_idx, description, 
description_size);\n}\n\nvoid ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {\n    GGML_ASSERT(device < (int) vk_instance.device_indices.size());\n    GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());\n\n    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];\n    vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;\n    vk::PhysicalDeviceMemoryProperties2 memprops = {};\n    const bool membudget_supported = vk_instance.device_supports_membudget[device];\n    const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu;\n\n    if (membudget_supported) {\n        memprops.pNext = &budgetprops;\n    }\n    vkdev.getMemoryProperties2(&memprops);\n\n    *total = 0;\n    *free = 0;\n\n    for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {\n        const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];\n\n        if (is_integrated_gpu || (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal)) {\n            *total += heap.size;\n\n            if (membudget_supported && i < budgetprops.heapUsage.size()) {\n                *free += budgetprops.heapBudget[i] - budgetprops.heapUsage[i];\n            } else {\n                *free += heap.size;\n            }\n        }\n    }\n}\n\nstatic vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) {\n    GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());\n\n    vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];\n\n    vk::PhysicalDeviceProperties2 props = {};\n    device.getProperties2(&props);\n\n    return props.properties.deviceType;\n}\n\nstatic std::string ggml_backend_vk_get_device_pci_id(int device_idx) {\n    GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());\n\n    
vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];\n\n    const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();\n\n    bool ext_support = false;\n\n    for (const auto& properties : ext_props) {\n        if (strcmp(\"VK_EXT_pci_bus_info\", properties.extensionName) == 0) {\n            ext_support = true;\n            break;\n        }\n    }\n\n    if (!ext_support) {\n        return \"\";\n    }\n\n    vk::PhysicalDeviceProperties2 props = {};\n    vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {};\n\n    props.pNext = &pci_bus_info;\n\n    device.getProperties2(&props);\n\n    const uint32_t pci_domain = pci_bus_info.pciDomain;\n    const uint32_t pci_bus = pci_bus_info.pciBus;\n    const uint32_t pci_device = pci_bus_info.pciDevice;\n    const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning\n\n    char pci_bus_id[16] = {};\n    snprintf(pci_bus_id, sizeof(pci_bus_id), \"%04x:%02x:%02x.%x\", pci_domain, pci_bus, pci_device, pci_function);\n\n    return std::string(pci_bus_id);\n}\n\n//////////////////////////\n\nstruct ggml_backend_vk_device_context {\n    size_t device;\n    std::string name;\n    std::string description;\n    bool is_integrated_gpu;\n    std::string pci_bus_id;\n    int op_offload_min_batch_size;\n};\n\nstatic const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    return ctx->name.c_str();\n}\n\nstatic const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    return ctx->description.c_str();\n}\n\nstatic void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {\n    
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;\n    ggml_backend_vk_get_device_memory(ctx->device, free, total);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    return ggml_backend_vk_buffer_type(ctx->device);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {\n    UNUSED(dev);\n    return ggml_backend_vk_host_buffer_type();\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n\n    return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU;\n}\n\nstatic void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n\n    props->name        = ggml_backend_vk_device_get_name(dev);\n    props->description = ggml_backend_vk_device_get_description(dev);\n    props->type        = ggml_backend_vk_device_get_type(dev);\n    props->device_id   = ctx->pci_bus_id.empty() ? 
nullptr : ctx->pci_bus_id.c_str();\n    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);\n    props->caps = {\n        /* .async                 = */ true,\n        /* .host_buffer           = */ true,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ true,\n    };\n}\n\nstatic ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {\n    UNUSED(params);\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    return ggml_backend_vk_init(ctx->device);\n}\n\nstatic bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    const vk_device& device = ggml_vk_get_device(ctx->device);\n\n    const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) &&\n                          device->shader_int64 && device->buffer_device_address;\n\n    auto const & tensor_size_supported = [&](size_t tensor_size) {\n        if (tensor_size > device->max_buffer_size) {\n            return false;\n        }\n        // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply.\n        // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply.\n        if (!uses_bda && !device->shader_64b_indexing) {\n            if (tensor_size > device->properties.limits.maxStorageBufferRange) {\n                return false;\n            }\n        }\n        return true;\n    };\n    // reject any tensors larger than the max buffer size\n    for (int i = 0; i < GGML_MAX_SRC; i++) {\n        if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) {\n            return false;\n        }\n    }\n    if (!tensor_size_supported(ggml_nbytes(op))) {\n        return false;\n    }\n\n    switch (op->op) {\n        case GGML_OP_UNARY:\n            switch 
(ggml_get_unary_op(op)) {\n                case GGML_UNARY_OP_EXP:\n                case GGML_UNARY_OP_ELU:\n                case GGML_UNARY_OP_GELU:\n                case GGML_UNARY_OP_GELU_ERF:\n                case GGML_UNARY_OP_GELU_QUICK:\n                case GGML_UNARY_OP_SILU:\n                case GGML_UNARY_OP_RELU:\n                case GGML_UNARY_OP_XIELU:\n                case GGML_UNARY_OP_NEG:\n                case GGML_UNARY_OP_TANH:\n                case GGML_UNARY_OP_SIGMOID:\n                case GGML_UNARY_OP_HARDSIGMOID:\n                case GGML_UNARY_OP_HARDSWISH:\n                case GGML_UNARY_OP_ABS:\n                case GGML_UNARY_OP_SOFTPLUS:\n                case GGML_UNARY_OP_STEP:\n                case GGML_UNARY_OP_ROUND:\n                case GGML_UNARY_OP_CEIL:\n                case GGML_UNARY_OP_FLOOR:\n                case GGML_UNARY_OP_TRUNC:\n                case GGML_UNARY_OP_SGN:\n                    return ggml_is_contiguous(op->src[0]) &&\n                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&\n                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&\n                           (op->src[0]->type == op->type);\n                default:\n                    return false;\n            }\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(op)) {\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_REGLU:\n                case GGML_GLU_OP_SWIGLU:\n                case GGML_GLU_OP_SWIGLU_OAI:\n                case GGML_GLU_OP_GEGLU_ERF:\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    return ggml_is_contiguous(op->src[0]) &&\n                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&\n                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&\n                           (op->src[0]->type == op->type);\n                
default:\n                    return false;\n            }\n        case GGML_OP_MUL_MAT:\n        case GGML_OP_MUL_MAT_ID:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                if (op->op == GGML_OP_MUL_MAT_ID) {\n                    if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {\n                        // If there's not enough shared memory for row_ids and the result tile, fallback to CPU\n                        return false;\n                    }\n                }\n                switch (src0_type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_BF16:\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n                    case GGML_TYPE_Q2_K:\n                    case GGML_TYPE_Q3_K:\n                    case GGML_TYPE_Q4_K:\n                    case GGML_TYPE_Q5_K:\n                    case GGML_TYPE_Q6_K:\n                    case GGML_TYPE_IQ1_S:\n                    case GGML_TYPE_IQ1_M:\n                    case GGML_TYPE_IQ2_XXS:\n                    case GGML_TYPE_IQ2_XS:\n                    case GGML_TYPE_IQ2_S:\n                    case GGML_TYPE_IQ3_XXS:\n                    case GGML_TYPE_IQ3_S:\n                    case GGML_TYPE_IQ4_XS:\n                    case GGML_TYPE_IQ4_NL:\n                    case GGML_TYPE_MXFP4:\n                        break;\n                    default:\n                        return false;\n                }\n                struct ggml_tensor * a;\n                struct ggml_tensor * b;\n                if (op->op == GGML_OP_MUL_MAT) {\n                    a = op->src[0];\n                    b = op->src[1];\n                } else {\n                    a = op->src[2];\n               
     b = op->src[1];\n                }\n                if (a->ne[3] != b->ne[3]) {\n                    return false;\n                }\n                if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||\n                    !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {\n                    return false;\n                }\n                if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {\n                    // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.\n                    // So don't support this combination for now.\n                    return false;\n                }\n\n                return true;\n            }\n        case GGML_OP_FLASH_ATTN_EXT:\n            {\n                bool coopmat2 = device->coopmat2;\n                uint32_t HSK = op->src[1]->ne[0];\n                uint32_t HSV = op->src[2]->ne[0];\n                if ((HSK % 8) != 0 || (HSV % 8) != 0) {\n                    return false;\n                }\n                if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {\n                    return false;\n                }\n                if (op->src[0]->type != GGML_TYPE_F32) {\n                    return false;\n                }\n                if (op->type != GGML_TYPE_F32) {\n                    return false;\n                }\n                if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {\n                    return false;\n                }\n                // It's straightforward to support different K/V dequant, but would\n                // significantly increase the number of pipelines\n                if (op->src[1]->type != op->src[2]->type) {\n                    return false;\n                }\n                switch (op->src[1]->type) {\n                case 
GGML_TYPE_F16:\n                case GGML_TYPE_F32:\n                case GGML_TYPE_Q4_0:\n                case GGML_TYPE_Q8_0:\n                    // supported in scalar and coopmat2 paths\n                    break;\n                case GGML_TYPE_Q4_1:\n                case GGML_TYPE_Q5_0:\n                case GGML_TYPE_Q5_1:\n                // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently\n                //case GGML_TYPE_Q2_K:\n                //case GGML_TYPE_Q3_K:\n                //case GGML_TYPE_Q4_K:\n                //case GGML_TYPE_Q5_K:\n                //case GGML_TYPE_Q6_K:\n                //case GGML_TYPE_IQ1_S:\n                //case GGML_TYPE_IQ1_M:\n                //case GGML_TYPE_IQ2_XXS:\n                //case GGML_TYPE_IQ2_XS:\n                //case GGML_TYPE_IQ2_S:\n                //case GGML_TYPE_IQ3_XXS:\n                //case GGML_TYPE_IQ3_S:\n                //case GGML_TYPE_IQ4_XS:\n                case GGML_TYPE_IQ4_NL:\n                    // currently supported only in coopmat2 path\n                    if (!coopmat2) {\n                        return false;\n                    }\n                    break;\n                default:\n                    return false;\n                }\n                if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {\n                    // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll\n                    return false;\n                }\n                return true;\n            }\n        case GGML_OP_GET_ROWS:\n            {\n                switch (op->src[0]->type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_BF16:\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n             
       case GGML_TYPE_Q2_K:\n                    case GGML_TYPE_Q3_K:\n                    case GGML_TYPE_Q4_K:\n                    case GGML_TYPE_Q5_K:\n                    case GGML_TYPE_Q6_K:\n                    case GGML_TYPE_IQ1_S:\n                    case GGML_TYPE_IQ1_M:\n                    case GGML_TYPE_IQ2_XXS:\n                    case GGML_TYPE_IQ2_XS:\n                    case GGML_TYPE_IQ2_S:\n                    case GGML_TYPE_IQ3_XXS:\n                    case GGML_TYPE_IQ3_S:\n                    case GGML_TYPE_IQ4_XS:\n                    case GGML_TYPE_IQ4_NL:\n                    case GGML_TYPE_MXFP4:\n                    case GGML_TYPE_I32:\n                        return true;\n                    default:\n                        return false;\n                }\n            }\n        case GGML_OP_SET_ROWS:\n            {\n                switch (op->type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_BF16:\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n                    case GGML_TYPE_IQ4_NL:\n                        return true;\n                    default:\n                        return false;\n                }\n            }\n        case GGML_OP_CONT:\n        case GGML_OP_CPY:\n        case GGML_OP_DUP:\n            {\n                ggml_type src0_type = op->src[0]->type;\n                ggml_type src1_type = op->src[1] != nullptr ? 
op->src[1]->type : src0_type;\n\n                if (src0_type == GGML_TYPE_F32) {\n                    switch (src1_type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_BF16:\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n                    case GGML_TYPE_IQ4_NL:\n                        return true;\n                    default:\n                        break;\n                    }\n                }\n                if (src1_type == GGML_TYPE_F32) {\n                    switch (src0_type) {\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_Q4_0:\n                    case GGML_TYPE_Q4_1:\n                    case GGML_TYPE_Q5_0:\n                    case GGML_TYPE_Q5_1:\n                    case GGML_TYPE_Q8_0:\n                    case GGML_TYPE_IQ4_NL:\n                        return true;\n                    default:\n                        break;\n                    }\n                }\n\n                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {\n                    return true;\n                }\n\n                if (\n                    (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) ||\n                    (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32)\n                ) {\n                    return true;\n                }\n\n                // We can handle copying from a type to the same type if it's\n                // either not quantized or is quantized and contiguous.\n                // We use f16 or f32 shaders to do the copy,\n                // so the type/block size must be a multiple of 4.\n                if (src0_type == src1_type &&\n                    (!ggml_is_quantized(src0_type) || (ggml_is_contiguous(op->src[0]) && 
ggml_is_contiguous(op))) &&\n                    (ggml_type_size(src0_type) % 2) == 0) {\n                    return true;\n                }\n                return false;\n            }\n        case GGML_OP_REPEAT:\n            return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);\n        case GGML_OP_REPEAT_BACK:\n            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_ROPE:\n            return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);\n        case GGML_OP_ROPE_BACK:\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_RMS_NORM:\n            return true;\n        case GGML_OP_NORM:\n        case GGML_OP_GROUP_NORM:\n            return ggml_is_contiguous(op->src[0]);\n        case GGML_OP_L2_NORM:\n            return ggml_is_contiguous_rows(op->src[0]) &&\n                   op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;\n        case GGML_OP_ADD:\n        case GGML_OP_SUB:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&\n                   (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&\n                   (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);\n        case GGML_OP_ADD_ID:\n            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&\n                   op->type == GGML_TYPE_F32;\n        case GGML_OP_SILU_BACK:\n        case GGML_OP_RMS_NORM_BACK:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_SIN:\n        case GGML_OP_COS:\n        case GGML_OP_CLAMP:\n            
return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_LEAKY_RELU:\n        case GGML_OP_OPT_STEP_ADAMW:\n        case GGML_OP_OPT_STEP_SGD:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_LOG:\n        case GGML_OP_TRI:\n        case GGML_OP_DIAG:\n            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&\n                   op->type == op->src[0]->type;\n        case GGML_OP_ARGSORT:\n            {\n                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {\n                    return false;\n                }\n                // pipeline_argsort_large_f32 requires vulkan memory model.\n                if (device->vulkan_memory_model) {\n                    return true;\n                } else {\n                    return op->ne[0] <= (1 << device->max_workgroup_size_log2);\n                }\n            }\n        case GGML_OP_TOP_K:\n            {\n                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {\n                    return false;\n                }\n                // We could potentially support larger, using argsort to sort the\n                // whole thing. 
Not clear if this is needed.\n                uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;\n                if (min_pipeline >= num_topk_pipelines ||\n                    !device->pipeline_topk_f32[min_pipeline]) {\n                    return false;\n                }\n            }\n            return true;\n        case GGML_OP_UPSCALE:\n            if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {\n                if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {\n                    return false;\n                }\n            }\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_ACC:\n            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;\n        case GGML_OP_SET:\n            return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&\n                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);\n        case GGML_OP_CONCAT:\n            return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);\n        case GGML_OP_ADD1:\n            return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)\n                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)\n                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);\n        case GGML_OP_ARANGE:\n        case GGML_OP_FILL:\n            return op->type == GGML_TYPE_F32;\n        case GGML_OP_SCALE:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_PAD:\n        case GGML_OP_ROLL:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_DIAG_MASK_INF:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_SOFT_MAX:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32\n        
        && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));\n        case GGML_OP_SOFT_MAX_BACK:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32\n                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;\n        case GGML_OP_SUM:\n        case GGML_OP_SUM_ROWS:\n        case GGML_OP_MEAN:\n            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);\n        case GGML_OP_CUMSUM:\n            {\n                if (device->subgroup_arithmetic && device->subgroup_require_full_support) {\n                    return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);\n                }\n                return false;\n            }\n        case GGML_OP_SOLVE_TRI:\n            {\n                if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {\n                    return false;\n                }\n                const uint32_t N = op->src[0]->ne[0];\n                const uint32_t K = op->src[1]->ne[0];\n                // K dimension limited to workgroup size\n                if (K > 1u << device->max_workgroup_size_log2) {\n                    return false;\n                }\n                const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));\n\n                if (batch_N == 0) {\n                    return false;\n                }\n                return true;\n            }\n        case GGML_OP_ARGMAX:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_COUNT_EQUAL:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32\n                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;\n        case GGML_OP_IM2COL:\n            return ggml_is_contiguous(op->src[1])\n                && op->src[1]->type == 
GGML_TYPE_F32\n                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);\n        case GGML_OP_IM2COL_3D:\n            return op->src[1]->type == GGML_TYPE_F32\n                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);\n        case GGML_OP_TIMESTEP_EMBEDDING:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_CONV_2D_DW:\n            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)\n                && op->src[1]->type == GGML_TYPE_F32;\n        case GGML_OP_POOL_2D:\n            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_RWKV_WKV6:\n        case GGML_OP_RWKV_WKV7:\n            return true; // all inputs are contiguous, see ggml.c\n        case GGML_OP_GATED_DELTA_NET:\n            {\n                const uint32_t S_v = op->src[2]->ne[0];\n                if (S_v != 32 && S_v != 64 && S_v != 128) {\n                    return false;\n                }\n                for (int i = 0; i < 6; i++) {\n                    if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {\n                        return false;\n                    }\n                }\n                return op->type == GGML_TYPE_F32;\n            }\n        case GGML_OP_SSM_SCAN:\n            {\n                for (int i = 0; i < 6; i++) {\n                    if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {\n                        return false;\n                    }\n                }\n                if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {\n                    return false;\n                }\n                if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {\n                    return false;\n                }\n\n                const uint32_t d_state = op->src[0]->ne[0];\n                const uint32_t head_dim = op->src[0]->ne[1];\n\n                bool is_mamba2 = (op->src[3] && 
op->src[3]->nb[1] == sizeof(float));\n                if (!is_mamba2) {\n                    return false;\n                }\n\n                if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {\n                    return false;\n                }\n\n                size_t shmem_size = d_state * sizeof(float);\n\n                if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {\n                    return false;\n                }\n\n                if (!device->subgroup_basic) {\n                    return false;\n                }\n\n                return true;\n            }\n        case GGML_OP_SSM_CONV:\n            return op->src[0]->type == GGML_TYPE_F32;\n        case GGML_OP_CONV_TRANSPOSE_1D:\n            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;\n        case GGML_OP_CONV_2D:\n        case GGML_OP_CONV_TRANSPOSE_2D:\n            {\n                // Channel-contiguous format is not supported yet.\n                return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&\n                    op->src[1]->type == GGML_TYPE_F32 &&\n                    op->type == GGML_TYPE_F32 &&\n                    ggml_is_contiguous(op->src[0]) &&\n                    ggml_is_contiguous(op->src[1]) &&\n                    ggml_is_contiguous(op));\n            }\n        default:\n            return false;\n    }\n\n    UNUSED(dev);\n}\n\nstatic bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {\n        return false;\n    }\n\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;\n\n    return buft_ctx->device->idx == ctx->device;\n}\n\nstatic int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) {\n    
switch (op->op) {\n        case GGML_OP_GET_ROWS:\n            return 0;\n        case GGML_OP_MUL_MAT:\n            return op->ne[1];\n        case GGML_OP_MUL_MAT_ID:\n        case GGML_OP_ROPE:\n        case GGML_OP_ROPE_BACK:\n            return op->ne[2];\n        default:\n            return ggml_nrows(op);\n    }\n}\n\nstatic bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;\n\n    return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;\n}\n\nstatic ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    auto device = ggml_vk_get_device(ctx->device);\n\n    vk_event *vkev = new vk_event;\n    if (!vkev) {\n        return nullptr;\n    }\n\n    // The event/fence is expected to initially be in the signaled state.\n    vkev->event = device->device.createEvent({});\n    vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});\n    device->device.setEvent(vkev->event);\n\n    return new ggml_backend_event {\n        /* .device  = */ dev,\n        /* .context = */ vkev,\n    };\n}\n\nstatic void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    auto device = ggml_vk_get_device(ctx->device);\n\n    vk_event *vkev = (vk_event *)event->context;\n\n    device->device.destroyFence(vkev->fence);\n    device->device.destroyEvent(vkev->event);\n    delete vkev;\n    delete event;\n}\n\nstatic void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_device_event_synchronize(backend=\" << dev << \", event=\" << event << \")\");\n    ggml_backend_vk_device_context * ctx 
= (ggml_backend_vk_device_context *)dev->context;\n    auto device = ggml_vk_get_device(ctx->device);\n    vk_event *vkev = (vk_event *)event->context;\n\n    VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), \"event_synchronize\");\n    // Finished using current command buffer so we flag for reuse\n    if (vkev->cmd_buffer) {\n        vkev->cmd_buffer->in_use = false;\n    }\n}\n\nstatic vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {\n    if (!device->external_memory_host) {\n        return {};\n    }\n\n    uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);\n    if (uptr & (device->min_imported_host_pointer_alignment - 1)) {\n        return {};\n    }\n    if (size & (device->min_imported_host_pointer_alignment - 1)) {\n        return {};\n    }\n\n    const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;\n\n    vk_buffer buf {};\n    try {\n        buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);\n    } catch (vk::SystemError& e) {\n        GGML_LOG_WARN(\"ggml_vulkan: Failed ggml_vk_create_buffer (%s)\\n\", e.what());\n    }\n\n    return buf;\n}\n\nstatic ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {\n    VK_LOG_DEBUG(\"ggml_backend_vk_device_buffer_from_host_ptr(backend=\" << dev << \", ptr=\" << ptr << \", size=\" << size << \")\");\n    GGML_UNUSED(max_tensor_size);\n\n    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;\n    auto device = ggml_vk_get_device(ctx->device);\n\n    vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);\n\n    if (!buf) {\n        return {};\n    }\n\n    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);\n\n    
ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);\n\n    return ret;\n}\n\nstatic const struct ggml_backend_device_i ggml_backend_vk_device_i = {\n    /* .get_name             = */ ggml_backend_vk_device_get_name,\n    /* .get_description      = */ ggml_backend_vk_device_get_description,\n    /* .get_memory           = */ ggml_backend_vk_device_get_memory,\n    /* .get_type             = */ ggml_backend_vk_device_get_type,\n    /* .get_props            = */ ggml_backend_vk_device_get_props,\n    /* .init_backend         = */ ggml_backend_vk_device_init,\n    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,\n    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,\n    /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,\n    /* .supports_op          = */ ggml_backend_vk_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,\n    /* .offload_op           = */ ggml_backend_vk_device_offload_op,\n    /* .event_new            = */ ggml_backend_vk_device_event_new,\n    /* .event_free           = */ ggml_backend_vk_device_event_free,\n    /* .event_synchronize    = */ ggml_backend_vk_device_event_synchronize,\n};\n\nstatic const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {\n    UNUSED(reg);\n    return GGML_VK_NAME;\n}\n\nstatic size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {\n    UNUSED(reg);\n    return ggml_backend_vk_get_device_count();\n}\n\nstatic ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {\n    static std::vector<ggml_backend_dev_t> devices;\n\n    static bool initialized = false;\n\n    {\n        static std::mutex mutex;\n        std::lock_guard<std::mutex> lock(mutex);\n        if (!initialized) {\n            const int min_batch_size = getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\") ? 
atoi(getenv(\"GGML_OP_OFFLOAD_MIN_BATCH\")) : 32;\n            for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {\n                ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;\n                char desc[256];\n                ggml_backend_vk_get_device_description(i, desc, sizeof(desc));\n                ctx->device = i;\n                ctx->name = GGML_VK_NAME + std::to_string(i);\n                ctx->description = desc;\n                ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;\n                ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);\n                ctx->op_offload_min_batch_size = min_batch_size;\n                devices.push_back(new ggml_backend_device {\n                    /* .iface   = */ ggml_backend_vk_device_i,\n                    /* .reg     = */ reg,\n                    /* .context = */ ctx,\n                });\n            }\n            initialized = true;\n        }\n    }\n\n    GGML_ASSERT(device < devices.size());\n    return devices[device];\n}\n\nstatic const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {\n    /* .get_name         = */ ggml_backend_vk_reg_get_name,\n    /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_vk_reg_get_device,\n    /* .get_proc_address = */ NULL,\n};\n\nggml_backend_reg_t ggml_backend_vk_reg() {\n    static ggml_backend_reg reg = {\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_vk_reg_i,\n        /* .context     = */ nullptr,\n    };\n    try {\n        ggml_vk_instance_init();\n        return &reg;\n    } catch (const vk::SystemError& e) {\n        VK_LOG_DEBUG(\"ggml_backend_vk_reg() -> Error: System error: \" << e.what());\n        return nullptr;\n    } catch (const std::exception &e) {\n        VK_LOG_DEBUG(\"ggml_backend_vk_reg() -> Error: \" << e.what());\n        return 
nullptr;\n    } catch (...) {\n        VK_LOG_DEBUG(\"ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init\");\n        return nullptr;\n    }\n}\n\n// Extension availability\nstatic bool ggml_vk_instance_layer_settings_available() {\n#ifdef GGML_VULKAN_VALIDATE\n    // Check if validation layer provides the extension\n    const std::string layer_name = \"VK_LAYER_KHRONOS_validation\";\n    for (const auto& layer : vk::enumerateInstanceLayerProperties()) {\n        if (layer_name == layer.layerName.data()) {\n            for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {\n                if (strcmp(\"VK_EXT_layer_settings\", ext.extensionName.data()) == 0) {\n                    return true;\n                }\n            }\n        }\n    }\n\n    std::cerr << \"ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found.\" << std::endl;\n#endif\n    return false;\n}\nstatic bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {\n#ifdef __APPLE__\n    // Check for portability enumeration extension for MoltenVK support\n    for (const auto& properties : instance_extensions) {\n        if (strcmp(\"VK_KHR_portability_enumeration\", properties.extensionName) == 0) {\n            return true;\n        }\n    }\n    std::cerr << \"ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found.\" << std::endl;\n#endif\n    return false;\n\n    UNUSED(instance_extensions);\n}\n\n// Extension availability\nstatic bool ggml_vk_instance_debug_utils_ext_available(\n    const std::vector<vk::ExtensionProperties> & instance_extensions) {\n    // Check for portability enumeration extension for MoltenVK support\n    for (const auto & properties : instance_extensions) {\n        if (strcmp(\"VK_EXT_debug_utils\", properties.extensionName) == 0) {\n            return true;\n        }\n    }\n\n    std::cerr << 
\"ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found.\" << std::endl;\n    return false;\n\n    UNUSED(instance_extensions);\n}\n\nstatic bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {\n    VkPhysicalDeviceFeatures2 device_features2;\n    device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;\n\n    VkPhysicalDeviceVulkan11Features vk11_features;\n    vk11_features.pNext = nullptr;\n    vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;\n    device_features2.pNext = &vk11_features;\n\n    vkGetPhysicalDeviceFeatures2(vkdev, &device_features2);\n\n    return vk11_features.storageBuffer16BitAccess;\n}\n\nstatic bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {\n    switch (props.vendorID) {\n    case VK_VENDOR_ID_INTEL:\n        // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,\n        // while some older hardware (ex. 
Arc A770) has performance regressions\n        return arch == vk_device_architecture::INTEL_XE2;\n    case VK_VENDOR_ID_AMD:\n        if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {\n            // Workaround for AMD proprietary driver reporting support on all GPUs\n            return arch == vk_device_architecture::AMD_RDNA3;\n        }\n        return true;\n    default:\n        return true;\n    }\n}\n\nstatic uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {\n    VkPhysicalDeviceProperties2 props = vkdev.getProperties2();\n\n    if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {\n        return 0;\n    }\n\n    const uint32_t device_id = props.properties.deviceID;\n\n    switch (device_id) {\n    case 0x56A6:  // A310\n        return 6;\n    case 0x5693:  // A370M\n    case 0x56A5:  // A380\n    case 0x56B1:  // Pro A40/A50\n        return 8;\n    case 0x5697:  // A530M\n        return 12;\n    case 0x5692:  // A550M\n    case 0x56B3:  // Pro A60\n        return 16;\n    case 0x56A2:  // A580\n        return 24;\n    case 0x5691:  // A730M\n    case 0x56A1:  // A750\n        return 28;\n    case 0x56A0:  // A770\n    case 0x5690:  // A770M\n        return 32;\n    case 0xE212:  // Pro B50\n        return 16;\n    case 0xE20C:  // B570\n        return 18;\n    case 0xE20B:  // B580\n        return 20;\n    default:\n        return 0;\n    }\n}\n\n// checks\n\n#ifdef GGML_VULKAN_CHECK_RESULTS\nstatic void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {\n    if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) {\n        return;\n    }\n    for (int j = 0; j < level; j++) {\n        std::cerr << \" \";\n    }\n    std::cerr << ggml_op_name(tensor->op) << \" gpu=\" << (tensor->extra != nullptr) << std::endl;\n\n    done.push_back(tensor);\n\n    for (int i = 0; i < GGML_MAX_SRC; 
i++) {\n        if (tensor->src[i] != nullptr) {\n            ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);\n        }\n    }\n}\n\nstatic void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {\n    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {\n        return;\n    }\n    i0 = std::max(i0, 5);\n    i1 = std::max(i1, 5);\n    i2 = std::max(i2, 0);\n    i3 = std::max(i3, 0);\n    fprintf(stderr, \"         \");\n    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {\n        fprintf(stderr, \"%7d \", idx1);\n    }\n    fprintf(stderr, \"\\n\");\n    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {\n        fprintf(stderr, \"%7d: \", idx0);\n        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {\n            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {\n                float val;\n                if (tensor->type == GGML_TYPE_F32) {\n                    val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);\n                } else if (tensor->type == GGML_TYPE_F16) {\n                    val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));\n                } else if (tensor->type == GGML_TYPE_I32) {\n                    val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);\n                } else {\n                    GGML_ABORT(\"fatal error\");\n                }\n                fprintf(stderr, \"% 7.2f \", val);\n            } else {\n                fprintf(stderr, \"        \");\n            }\n        }\n        fprintf(stderr, \"\\n\");\n    }\n}\n\nstatic void ggml_vk_print_tensor(const 
ggml_tensor * tensor, const char * name) {\n    void * tensor_data = tensor->data;\n\n    const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);\n\n    if (is_gpu) {\n        const size_t tensor_size = ggml_nbytes(tensor);\n        tensor_data = malloc(tensor_size);\n\n        ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;\n\n        vk_buffer buffer_gpu = buf_ctx->dev_buffer;\n        ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size);\n    }\n\n    std::cerr << \"TENSOR CHECK \" << name << \" (\" << tensor->name << \"): \" << ggml_op_name(tensor->op) << std::endl;\n    std::cerr << \"tensor=\" << tensor << \" tensor->type: \" << ggml_type_name(tensor->type) << \" ne0=\" << tensor->ne[0] << \" nb0=\" << tensor->nb[0] << \" ne1=\" << tensor->ne[1] << \" nb1=\" << tensor->nb[1] << \" ne2=\" << tensor->ne[2] << \" nb2=\" << tensor->nb[2] << \" ne3=\" << tensor->ne[3] << \" nb3=\" << tensor->nb[3] << std::endl;\n    if (tensor->src[0] != nullptr) {\n        std::cerr << \"tensor->src[0]=\" << tensor->src[0] << \" name=\" << tensor->src[0]->name << \" op=\" << ggml_op_name(tensor->src[0]->op) << \" type=\" << ggml_type_name(tensor->src[0]->type) << \" ne0=\" << tensor->src[0]->ne[0] << \" nb0=\" << tensor->src[0]->nb[0] << \" ne1=\" << tensor->src[0]->ne[1] << \" nb1=\" << tensor->src[0]->nb[1] << \" ne2=\" << tensor->src[0]->ne[2] << \" nb2=\" << tensor->src[0]->nb[2] << \" ne3=\" << tensor->src[0]->ne[3] << \" nb3=\" << tensor->src[0]->nb[3] << std::endl;\n    }\n    if (tensor->src[1] != nullptr) {\n        std::cerr << \"tensor->src[1]=\" << tensor->src[1] << \" name=\" << tensor->src[1]->name << \" op=\" << ggml_op_name(tensor->src[1]->op) << \" type=\" << ggml_type_name(tensor->src[1]->type) << \" ne0=\" << tensor->src[1]->ne[0] << \" nb0=\" << tensor->src[1]->nb[0] << \" ne1=\" << tensor->src[1]->ne[1] << \" nb1=\" << 
tensor->src[1]->nb[1] << \" ne2=\" << tensor->src[1]->ne[2] << \" nb2=\" << tensor->src[1]->nb[2] << \" ne3=\" << tensor->src[1]->ne[3] << \" nb3=\" << tensor->src[1]->nb[3] << std::endl;\n    }\n    std::cerr << std::endl << \"Result:\" << std::endl;\n    ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);\n    std::cerr << std::endl;\n    std::vector<const ggml_tensor *> done;\n    ggml_vk_print_graph_origin(tensor, done);\n\n    if (is_gpu) {\n        free(tensor_data);\n    }\n}\n\nvoid * comp_result;\nsize_t comp_size;\nsize_t comp_nb[GGML_MAX_DIMS];\nsize_t check_counter = 0;\nstatic void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {\n    ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];\n    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {\n        return;\n    }\n\n    check_counter++;\n    if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {\n        return;\n    }\n\n    VK_LOG_DEBUG(\"ggml_vk_check_results_0(\" << tensor->name << \")\");\n\n    struct ggml_init_params iparams = {\n        /*.mem_size   =*/ 2ul*1024ul*1024ul*1024ul,\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ false,\n    };\n\n    struct ggml_context * ggml_ctx = ggml_init(iparams);\n\n    std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};\n    const char * srci_name[GGML_MAX_SRC] = {\"src0\", \"src1\", \"src2\", \"src3\", \"src4\", \"src5\", \"src6\", \"src7\", \"src8\", \"src9\"};\n\n    std::map<ggml_tensor *, ggml_tensor *> cloned_tensors;\n    std::vector<void *> cloned_mallocs;\n\n    struct ggml_tensor * tensor_clone = nullptr;\n\n    for (int f = 0; f < ctx->num_additional_fused_ops + 1; ++f) {\n        tensor = cgraph->nodes[tensor_idx + f];\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            
ggml_tensor * srci = tensor->src[i];\n            if (srci == nullptr) {\n                continue;\n            }\n            // If a src tensor has been cloned, use that one\n            auto it = cloned_tensors.find(srci);\n            if (it != cloned_tensors.end()) {\n                src_clone[i] = it->second;\n                continue;\n            }\n            ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);\n            size_t srci_size = ggml_nbytes(srci);\n\n            src_clone[i] = srci_clone;\n            void *src_buffer = malloc(srci_size);\n            cloned_mallocs.push_back(src_buffer);\n\n            srci_clone->data = src_buffer;\n            if (ggml_backend_buffer_is_host(srci->buffer)) {\n                memcpy(srci_clone->data, srci->data, srci_size);\n                memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);\n            } else if (ggml_backend_buffer_is_vk(srci->buffer)) {\n                ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;\n                vk_buffer& buffer_gpu = buf_ctx->dev_buffer;\n                uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;\n                if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {\n                    for (int i3 = 0; i3 < srci->ne[3]; i3++) {\n                        for (int i2 = 0; i2 < srci->ne[2]; i2++) {\n                            const int idx = i3*srci->ne[2] + i2;\n                            ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);\n                        }\n                    }\n\n                    srci_clone->nb[0] = srci->nb[0];\n                    srci_clone->nb[1] = srci->nb[1];\n                    for (int i = 2; i < GGML_MAX_DIMS; i++) {\n                        srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];\n                    }\n  
              } else {\n                    if (offset + srci_size >= buffer_gpu->size) {\n                        srci_size = buffer_gpu->size - offset;\n                    }\n                    ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);\n                    memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);\n                }\n            } else {\n                GGML_ABORT(\"fatal error\");\n            }\n\n            if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {\n                ggml_vk_print_tensor(srci, srci_name[i]);\n            }\n        }\n\n        if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {\n            const float * params = (const float *)tensor->op_params;\n            tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);\n            if (src_clone[4]) {\n                ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);\n            }\n        } else if (tensor->op == GGML_OP_MUL_MAT) {\n            tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_MUL_MAT_ID) {\n            tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);\n        } else if (tensor->op == GGML_OP_SUB) {\n            tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_MUL) {\n            tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_DIV) {\n            tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_CONCAT) {\n            tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);\n        } else if (tensor->op == GGML_OP_UPSCALE) {\n            tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], 
tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);\n        } else if (tensor->op == GGML_OP_SCALE) {\n            const float * params = (const float *)tensor->op_params;\n            tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);\n        } else if (tensor->op == GGML_OP_ADD1) {\n            tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_ARANGE) {\n            const float start = ggml_get_op_params_f32(tensor, 0);\n            const float stop = ggml_get_op_params_f32(tensor, 1);\n            const float step = ggml_get_op_params_f32(tensor, 2);\n            tensor_clone = ggml_arange(ggml_ctx, start, stop, step);\n        } else if (tensor->op == GGML_OP_FILL) {\n            const float value = ggml_get_op_params_f32(tensor, 0);\n            tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value);\n        } else if (tensor->op == GGML_OP_SQR) {\n            tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_SQRT) {\n            tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_SIN) {\n            tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_COS) {\n            tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_LOG) {\n            tensor_clone = ggml_log(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_TRI) {\n            tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));\n        } else if (tensor->op == GGML_OP_DIAG) {\n            tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_CLAMP) {\n            const float * params = (const float *)tensor->op_params;\n            tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);\n        } else if (tensor->op == 
GGML_OP_PAD) {\n            tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],\n                                                                tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);\n        } else if (tensor->op == GGML_OP_REPEAT) {\n            tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);\n        } else if (tensor->op == GGML_OP_REPEAT_BACK) {\n            tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);\n        } else if (tensor->op == GGML_OP_ADD) {\n            tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_ACC) {\n            tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);\n        } else if (tensor->op == GGML_OP_SET) {\n            tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);\n        } else if (tensor->op == GGML_OP_NORM) {\n            tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);\n        } else if (tensor->op == GGML_OP_GROUP_NORM) {\n            const float * float_params = (const float *)tensor->op_params;\n            tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);\n        } else if (tensor->op == GGML_OP_RMS_NORM) {\n            tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);\n        } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {\n            const float eps = ((float *) tensor->op_params)[0];\n            tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);\n        } else if (tensor->op == GGML_OP_SILU_BACK) {\n            tensor_clone = ggml_silu_back(ggml_ctx, 
src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_L2_NORM) {\n            const float eps = ((float *) tensor->op_params)[0];\n            tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);\n        } else if (tensor->op == GGML_OP_SOFT_MAX) {\n            if (tensor->src[1] != nullptr) {\n                const float * params = (const float *)tensor->op_params;\n                tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);\n            } else {\n                tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);\n            }\n        } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {\n            tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);\n        } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {\n            tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);\n        } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {\n            const int n_dims      = ((int32_t *) tensor->op_params)[1];\n            const int mode        = ((int32_t *) tensor->op_params)[2];\n            //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3];\n            const int n_ctx_orig_ggml  = ((int32_t *) tensor->op_params)[4];\n            const float freq_base       = ((float *) tensor->op_params)[5];\n            const float freq_scale      = ((float *) tensor->op_params)[6];\n            const float ext_factor      = ((float *) tensor->op_params)[7];\n            const float attn_factor     = ((float *) tensor->op_params)[8];\n            const float beta_fast       = ((float *) tensor->op_params)[9];\n            const float beta_slow       = ((float *) tensor->op_params)[10];\n            if (mode & GGML_ROPE_TYPE_MROPE) {\n                int32_t *sections = ((int32_t *) tensor->op_params) + 11;\n                if (tensor->op == 
GGML_OP_ROPE) {\n                    tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);\n                } else {\n                    tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);\n                }\n            } else {\n                if (tensor->op == GGML_OP_ROPE) {\n                    tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);\n                } else {\n                    tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);\n                }\n            }\n        } else if (tensor->op == GGML_OP_UNARY) {\n            switch (ggml_get_unary_op(tensor)) {\n            case GGML_UNARY_OP_EXP:\n                tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_ELU:\n                tensor_clone = ggml_elu(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_SILU:\n                tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_GELU:\n                tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_GELU_ERF:\n                tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_GELU_QUICK:\n                tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_RELU:\n                
tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_XIELU:\n                tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);\n                ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));\n                ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));\n                ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));\n                ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));\n                break;\n            case GGML_UNARY_OP_NEG:\n                tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_TANH:\n                tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_SIGMOID:\n                tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_HARDSIGMOID:\n                tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_HARDSWISH:\n                tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_ABS:\n                tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_SOFTPLUS:\n                tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_STEP:\n                tensor_clone = ggml_step(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_ROUND:\n                tensor_clone = ggml_round(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_CEIL:\n                tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_FLOOR:\n                tensor_clone = 
ggml_floor(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_TRUNC:\n                tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);\n                break;\n            case GGML_UNARY_OP_SGN:\n                tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);\n                break;\n            default:\n                std::cerr << \"Missing vk_check_results OP: \" << ggml_op_name(tensor->op) << std::endl;\n                GGML_ABORT(\"fatal error\");\n            }\n        } else if (tensor->op == GGML_OP_GLU) {\n            if (src_clone[1] == nullptr) {\n                tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);\n            } else {\n                tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);\n            }\n            ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));\n            ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));\n        } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {\n            if (tensor->src[1] == nullptr) {\n                tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);\n                tensor_clone->type = tensor->type;\n            } else {\n                tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);\n            }\n        } else if (tensor->op == GGML_OP_CONT) {\n            tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);\n        } else if (tensor->op == GGML_OP_RESHAPE) {\n            tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);\n        } else if (tensor->op == GGML_OP_VIEW) {\n            tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], 
((int32_t *) tensor->op_params)[0]);\n        } else if (tensor->op == GGML_OP_PERMUTE) {\n            int32_t * params = (int32_t *)tensor->op_params;\n            tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);\n        } else if (tensor->op == GGML_OP_TRANSPOSE) {\n            tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_GET_ROWS) {\n            tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_ARGSORT) {\n            tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);\n        } else if (tensor->op == GGML_OP_TOP_K) {\n            tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);\n        } else if (tensor->op == GGML_OP_SUM) {\n            tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_SUM_ROWS) {\n            tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_CUMSUM) {\n            tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_MEAN) {\n            tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_ARGMAX) {\n            tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);\n        } else if (tensor->op == GGML_OP_COUNT_EQUAL) {\n            tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_SOLVE_TRI) {\n            tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);\n        } else if (tensor->op == GGML_OP_IM2COL) {\n            const int32_t s0 = tensor->op_params[0];\n            const int32_t s1 = tensor->op_params[1];\n            const int32_t p0 = tensor->op_params[2];\n            const int32_t p1 = tensor->op_params[3];\n            const int32_t d0 = 
tensor->op_params[4];\n            const int32_t d1 = tensor->op_params[5];\n\n            const bool is_2D = tensor->op_params[6] == 1;\n            tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);\n        } else if (tensor->op == GGML_OP_IM2COL_3D) {\n            const int32_t s0 = tensor->op_params[0];\n            const int32_t s1 = tensor->op_params[1];\n            const int32_t s2 = tensor->op_params[2];\n            const int32_t p0 = tensor->op_params[3];\n            const int32_t p1 = tensor->op_params[4];\n            const int32_t p2 = tensor->op_params[5];\n            const int32_t d0 = tensor->op_params[6];\n            const int32_t d1 = tensor->op_params[7];\n            const int32_t d2 = tensor->op_params[8];\n            const int32_t IC = tensor->op_params[9];\n\n            tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);\n        } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {\n            const int32_t dim = tensor->op_params[0];\n            const int32_t max_period = tensor->op_params[1];\n            tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);\n        } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){\n            const int32_t s0 = tensor->op_params[0];\n            const int32_t p0 = tensor->op_params[1];\n            const int32_t d0 = tensor->op_params[2];\n            tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);\n        } else if (tensor->op == GGML_OP_POOL_2D) {\n            enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);\n            const int32_t k0 = tensor->op_params[1];\n            const int32_t k1 = tensor->op_params[2];\n            const int32_t s0 = tensor->op_params[3];\n            const int32_t s1 = tensor->op_params[4];\n            const int32_t p0 = 
tensor->op_params[5];\n            const int32_t p1 = tensor->op_params[6];\n\n            tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);\n        } else if (tensor->op == GGML_OP_CONV_2D) {\n            const int32_t s0 = tensor->op_params[0];\n            const int32_t s1 = tensor->op_params[1];\n            const int32_t p0 = tensor->op_params[2];\n            const int32_t p1 = tensor->op_params[3];\n            const int32_t d0 = tensor->op_params[4];\n            const int32_t d1 = tensor->op_params[5];\n            tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);\n        } else if (tensor->op == GGML_OP_CONV_2D_DW) {\n            const int32_t s0 = tensor->op_params[0];\n            const int32_t s1 = tensor->op_params[1];\n            const int32_t p0 = tensor->op_params[2];\n            const int32_t p1 = tensor->op_params[3];\n            const int32_t d0 = tensor->op_params[4];\n            const int32_t d1 = tensor->op_params[5];\n            tensor_clone = ggml_conv_2d_dw_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);\n        } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {\n            const int32_t s = tensor->op_params[0];\n            tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);\n        } else if (tensor->op == GGML_OP_LEAKY_RELU) {\n            const float * op_params = (const float *)tensor->op_params;\n            tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);\n        } else if (tensor->op == GGML_OP_RWKV_WKV6) {\n            tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],\n            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);\n        } else if (tensor->op == GGML_OP_RWKV_WKV7) {\n            tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],\n            src_clone[4], src_clone[5], 
src_clone[6]);\n        } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {\n            tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],\n            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);\n        } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {\n            src_clone[0]->flags = tensor->src[0]->flags;\n            tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],\n            src_clone[2], src_clone[3], src_clone[4]);\n        } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {\n            src_clone[0]->flags = tensor->src[0]->flags;\n            tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],\n            src_clone[2]);\n        } else if (tensor->op == GGML_OP_ADD_ID) {\n            tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);\n        } else if (tensor->op == GGML_OP_SSM_SCAN) {\n            tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],\n                                         src_clone[3], src_clone[4], src_clone[5], src_clone[6]);\n        } else if (tensor->op == GGML_OP_SSM_CONV) {\n            tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);\n        } else if (tensor->op == GGML_OP_ROLL) {\n            const int32_t s0 = tensor->op_params[0];\n            const int32_t s1 = tensor->op_params[1];\n            const int32_t s2 = tensor->op_params[2];\n            const int32_t s3 = tensor->op_params[3];\n            tensor_clone = ggml_roll(ggml_ctx, src_clone[0], s0, s1, s2, s3);\n        }\n        else {\n            std::cerr << \"Missing vk_check_results OP: \" << ggml_op_name(tensor->op) << std::endl;\n            GGML_ABORT(\"fatal error\");\n        }\n        cloned_tensors[tensor] = tensor_clone;\n    }\n\n    ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);\n    ggml_build_forward_expand(cgraph_cpu, tensor_clone);\n\n    ggml_graph_compute_with_ctx(ggml_ctx, 
cgraph_cpu, 8);\n\n    if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {\n        ggml_vk_print_tensor(tensor_clone, \"tensor_clone\");\n    }\n\n    comp_size = ggml_nbytes(tensor_clone);\n\n    comp_result = malloc(comp_size);\n    memcpy(comp_result, tensor_clone->data, comp_size);\n    memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);\n\n    for (auto m : cloned_mallocs) {\n        free(m);\n    }\n\n    ggml_free(ggml_ctx);\n\n    VK_LOG_DEBUG(\"END ggml_vk_check_results_0(\" << tensor->name << \")\");\n}\n\nstatic void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {\n    ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];\n    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {\n        return;\n    }\n\n    if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {\n        return;\n    }\n\n    VK_LOG_DEBUG(\"ggml_vk_check_results_1(\" << tensor->name << \")\");\n\n    ggml_tensor * src0 = tensor->src[0];\n    ggml_tensor * src1 = tensor->src[1];\n    ggml_tensor * src2 = tensor->src[2];\n    ggml_tensor * src3 = tensor->src[3];\n\n    void * tensor_data = tensor->data;\n\n    if (ggml_backend_buffer_is_vk(tensor->buffer)) {\n        size_t tensor_size = ggml_nbytes(tensor);\n        tensor_data = malloc(tensor_size);\n\n        ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;\n\n        vk_buffer& buffer_gpu = buf_ctx->dev_buffer;\n        uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs;\n        if (offset + tensor_size >= buffer_gpu->size) {\n            tensor_size = buffer_gpu->size - offset;\n        }\n\n        ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size);\n    }\n\n    float first_error_result = -1.0f;\n    float first_error_correct = -1.0f;\n    std::array<int, 4> first_error = { -1, -1, 
-1, -1 };\n    double avg_err = 0.0;\n    size_t counter = 0;\n\n    for (int i3 = 0; i3 < tensor->ne[3]; i3++) {\n        for (int i2 = 0; i2 < tensor->ne[2]; i2++) {\n            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {\n                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {\n                    const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size;\n                    float correct = 0.0f;\n                    float result = 0.0f;\n\n                    if (buffer_size_fit) {\n                        if (tensor->type == GGML_TYPE_F32) {\n                            correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);\n                            result  = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);\n                        } else if (tensor->type == GGML_TYPE_F16) {\n                            correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));\n                            result  = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));\n                        } else if (tensor->type == GGML_TYPE_BF16) {\n                            correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));\n                            result  = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));\n                        } else if (tensor->type == GGML_TYPE_I32) {\n                            correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);\n                            result  = *(int32_t *) ((char *) tensor_data + 
i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);\n                        } else if (tensor->type == GGML_TYPE_I64) {\n                            correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);\n                            result  = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);\n                        } else {\n                            std::cerr << \"Results check not implemented for type \" << ggml_type_name(tensor->type) << std::endl;\n                        }\n                    } else {\n                        std::cerr << \"Missing debug code for type \" << ggml_type_name(tensor->type) << std::endl;\n                        GGML_ABORT(\"fatal error\");\n                    }\n\n                    if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) {\n                        std::cerr << \"ERROR: Invalid value in \" << ggml_op_name(tensor->op) << \" i3=\" << i3 << \" i2=\" << i2 << \" i1=\" << i1 << \" i0=\" << i0 << \" result=\" << result << \" correct=\" << correct << \" avg_err=\" << (avg_err / counter) << std::endl;\n                        std::cerr << \"tensor=\" << tensor << \" tensor->name=\" << tensor->name << \" tensor->type: \" << ggml_type_name(tensor->type) << \" ne0=\" << tensor->ne[0] << \" nb0=\" << tensor->nb[0] << \" ne1=\" << tensor->ne[1] << \" nb1=\" << tensor->nb[1] << \" ne2=\" << tensor->ne[2] << \" nb2=\" << tensor->nb[2] << \" ne3=\" << tensor->ne[3] << \" nb3=\" << tensor->nb[3] << \" offset=\" << tensor->view_offs << std::endl;\n                        if (src0 != nullptr) {\n                            std::cerr << \"src0=\" << src0 << \" src0->name=\" << src0->name << \" op=\" << ggml_op_name(src0->op) << \" type=\" << ggml_type_name(src0->type) << \" ne0=\" << src0->ne[0] << \" nb0=\" << src0->nb[0] << \" 
ne1=\" << src0->ne[1] << \" nb1=\" << src0->nb[1] << \" ne2=\" << src0->ne[2] << \" nb2=\" << src0->nb[2] << \" ne3=\" << src0->ne[3] << \" nb3=\" << src0->nb[3] << \" offset=\" << src0->view_offs << std::endl;\n                        }\n                        if (src1 != nullptr) {\n                            std::cerr << \"src1=\" << src1 << \" src1->name=\" << src1->name << \" op=\" << ggml_op_name(src1->op) << \" type=\" << ggml_type_name(src1->type) << \" ne0=\" << src1->ne[0] << \" nb0=\" << src1->nb[0] << \" ne1=\" << src1->ne[1] << \" nb1=\" << src1->nb[1] << \" ne2=\" << src1->ne[2] << \" nb2=\" << src1->nb[2] << \" ne3=\" << src1->ne[3] << \" nb3=\" << src1->nb[3] << \" offset=\" << src1->view_offs << std::endl;\n                        }\n                        if (src2 != nullptr) {\n                            std::cerr << \"src2=\" << src2 << \" src2->name=\" << src2->name << \" op=\" << ggml_op_name(src2->op) << \" type=\" << ggml_type_name(src2->type) << \" ne0=\" << src2->ne[0] << \" nb0=\" << src2->nb[0] << \" ne1=\" << src2->ne[1] << \" nb1=\" << src2->nb[1] << \" ne2=\" << src2->ne[2] << \" nb2=\" << src2->nb[2] << \" ne3=\" << src2->ne[3] << \" nb3=\" << src2->nb[3] << \" offset=\" << src2->view_offs << std::endl;\n                        }\n                        if (src3 != nullptr) {\n                            std::cerr << \"src3=\" << src3 << \" src3->name=\" << src3->name << \" op=\" << ggml_op_name(src3->op) << \" type=\" << ggml_type_name(src3->type) << \" ne0=\" << src3->ne[0] << \" nb0=\" << src3->nb[0] << \" ne1=\" << src3->ne[1] << \" nb1=\" << src3->nb[1] << \" ne2=\" << src3->ne[2] << \" nb2=\" << src3->nb[2] << \" ne3=\" << src3->ne[3] << \" nb3=\" << src3->nb[3] << \" offset=\" << src3->view_offs << std::endl;\n                        }\n                        std::cerr << \"First error: result=\" << first_error_result << \" correct=\" << first_error_correct  << \" i3=\" << first_error[3] << \" i2=\" << first_error[2] << 
\" i1=\" << first_error[1] << \" i0=\" << first_error[0] << std::endl;\n                        std::cerr << std::endl << \"Result:\" << std::endl;\n                        ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);\n                        std::cerr << std::endl << \"Correct:\" << std::endl;\n                        ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3);\n                        std::cerr << std::endl;\n                        std::vector<const ggml_tensor *> done;\n                        ggml_vk_print_graph_origin(tensor, done);\n                        GGML_ABORT(\"fatal error\");\n                    }\n                    const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f;\n                    if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) {\n                        first_error[0] = i0;\n                        first_error[1] = i1;\n                        first_error[2] = i2;\n                        first_error[3] = i3;\n                        first_error_result = result;\n                        first_error_correct = correct;\n                    }\n\n                    // Special case, value is infinite, avoid NaN result in avg_err\n                    // NaN also appears in results, if both are nan error is 0\n                    if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {\n                        avg_err += std::fabs(correct - result) / denom;\n                    }\n                    counter++;\n                }\n            }\n        }\n    }\n\n    avg_err /= counter;\n\n    if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {\n        std::cerr << \"TENSOR CHECK: avg_err=\" << avg_err << \" in \" << ggml_op_name(tensor->op) << \" (check \" << check_counter << \")\" << std::endl;\n        std::cerr << \"tensor=\" << tensor << \" tensor->name=\" << 
tensor->name << \" tensor->type: \" << ggml_type_name(tensor->type) << \" ne0=\" << tensor->ne[0] << \" nb0=\" << tensor->nb[0] << \" ne1=\" << tensor->ne[1] << \" nb1=\" << tensor->nb[1] << \" ne2=\" << tensor->ne[2] << \" nb2=\" << tensor->nb[2] << \" ne3=\" << tensor->ne[3] << \" nb3=\" << tensor->nb[3] << \" offset=\" << tensor->view_offs << std::endl;\n        if (src0 != nullptr) {\n            std::cerr << \"src0=\" << src0 << \" op=\" << ggml_op_name(src0->op) << \" type=\" << ggml_type_name(src0->type) << \" ne0=\" << src0->ne[0] << \" nb0=\" << src0->nb[0] << \" ne1=\" << src0->ne[1] << \" nb1=\" << src0->nb[1] << \" ne2=\" << src0->ne[2] << \" nb2=\" << src0->nb[2] << \" ne3=\" << src0->ne[3] << \" nb3=\" << src0->nb[3] << \" offset=\" << src0->view_offs << std::endl;\n        }\n        if (src1 != nullptr) {\n            std::cerr << \"src1=\" << src1 << \" op=\" << ggml_op_name(src1->op) << \" type=\" << ggml_type_name(src1->type) << \" ne0=\" << src1->ne[0] << \" nb0=\" << src1->nb[0] << \" ne1=\" << src1->ne[1] << \" nb1=\" << src1->nb[1] << \" ne2=\" << src1->ne[2] << \" nb2=\" << src1->nb[2] << \" ne3=\" << src1->ne[3] << \" nb3=\" << src1->nb[3] << \" offset=\" << src1->view_offs << std::endl;\n        }\n        if (src2 != nullptr) {\n            std::cerr << \"src2=\" << src2 << \" op=\" << ggml_op_name(src2->op) << \" type=\" << ggml_type_name(src2->type) << \" ne0=\" << src2->ne[0] << \" nb0=\" << src2->nb[0] << \" ne1=\" << src2->ne[1] << \" nb1=\" << src2->nb[1] << \" ne2=\" << src2->ne[2] << \" nb2=\" << src2->nb[2] << \" ne3=\" << src2->ne[3] << \" nb3=\" << src2->nb[3] << \" offset=\" << src2->view_offs << std::endl;\n        }\n        if (src3 != nullptr) {\n            std::cerr << \"src3=\" << src3 << \" op=\" << ggml_op_name(src3->op) << \" type=\" << ggml_type_name(src3->type) << \" ne0=\" << src3->ne[0] << \" nb0=\" << src3->nb[0] << \" ne1=\" << src3->ne[1] << \" nb1=\" << src3->nb[1] << \" ne2=\" << src3->ne[2] << \" nb2=\" << 
src3->nb[2] << \" ne3=\" << src3->ne[3] << \" nb3=\" << src3->nb[3] << \" offset=\" << src3->view_offs << std::endl;\n        }\n        std::cerr << \"First error: result=\" << first_error_result << \" correct=\" << first_error_correct  << \" i3=\" << first_error[3] << \" i2=\" << first_error[2] << \" i1=\" << first_error[1] << \" i0=\" << first_error[0] << std::endl;\n        std::cerr << std::endl << \"Result:\" << std::endl;\n        ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);\n        std::cerr << std::endl << \"Correct:\" << std::endl;\n        ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);\n        std::cerr << std::endl;\n        std::vector<const ggml_tensor *> done;\n        ggml_vk_print_graph_origin(tensor, done);\n    }\n\n    if (avg_err > 0.01 || std::isnan(avg_err)) {\n        std::cerr << \"ERROR: avg_err=\" << avg_err << \" in \" << ggml_op_name(tensor->op) << \" (check \" << check_counter << \")\" << std::endl;\n        std::cerr << \"tensor=\" << tensor << \" tensor->name=\" << tensor->name << \" tensor->type: \" << ggml_type_name(tensor->type) << \" ne0=\" << tensor->ne[0] << \" nb0=\" << tensor->nb[0] << \" ne1=\" << tensor->ne[1] << \" nb1=\" << tensor->nb[1] << \" ne2=\" << tensor->ne[2] << \" nb2=\" << tensor->nb[2] << \" ne3=\" << tensor->ne[3] << \" nb3=\" << tensor->nb[3] << \" offset=\" << tensor->view_offs << std::endl;\n        if (src0 != nullptr) {\n            std::cerr << \"src0=\" << src0 << \" op=\" << ggml_op_name(src0->op) << \" type=\" << ggml_type_name(src0->type) << \" ne0=\" << src0->ne[0] << \" nb0=\" << src0->nb[0] << \" ne1=\" << src0->ne[1] << \" nb1=\" << src0->nb[1] << \" ne2=\" << src0->ne[2] << \" nb2=\" << src0->nb[2] << \" ne3=\" << src0->ne[3] << \" nb3=\" << src0->nb[3] << \" offset=\" << src0->view_offs << std::endl;\n        }\n        if (src1 != nullptr) {\n            std::cerr << \"src1=\" << src1 << \" op=\" << ggml_op_name(src1->op) << \" type=\" << 
ggml_type_name(src1->type) << \" ne0=\" << src1->ne[0] << \" nb0=\" << src1->nb[0] << \" ne1=\" << src1->ne[1] << \" nb1=\" << src1->nb[1] << \" ne2=\" << src1->ne[2] << \" nb2=\" << src1->nb[2] << \" ne3=\" << src1->ne[3] << \" nb3=\" << src1->nb[3] << \" offset=\" << src1->view_offs << std::endl;\n        }\n        if (src2 != nullptr) {\n            std::cerr << \"src2=\" << src2 << \" op=\" << ggml_op_name(src2->op) << \" type=\" << ggml_type_name(src2->type) << \" ne0=\" << src2->ne[0] << \" nb0=\" << src2->nb[0] << \" ne1=\" << src2->ne[1] << \" nb1=\" << src2->nb[1] << \" ne2=\" << src2->ne[2] << \" nb2=\" << src2->nb[2] << \" ne3=\" << src2->ne[3] << \" nb3=\" << src2->nb[3] << \" offset=\" << src2->view_offs << std::endl;\n        }\n        if (src3 != nullptr) {\n            std::cerr << \"src3=\" << src3 << \" op=\" << ggml_op_name(src3->op) << \" type=\" << ggml_type_name(src3->type) << \" ne0=\" << src3->ne[0] << \" nb0=\" << src3->nb[0] << \" ne1=\" << src3->ne[1] << \" nb1=\" << src3->nb[1] << \" ne2=\" << src3->ne[2] << \" nb2=\" << src3->nb[2] << \" ne3=\" << src3->ne[3] << \" nb3=\" << src3->nb[3] << \" offset=\" << src3->view_offs << std::endl;\n        }\n        std::cerr << \"First error: result=\" << first_error_result << \" correct=\" << first_error_correct  << \" i3=\" << first_error[3] << \" i2=\" << first_error[2] << \" i1=\" << first_error[1] << \" i0=\" << first_error[0] << std::endl;\n        std::cerr << std::endl << \"Result:\" << std::endl;\n        ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);\n        std::cerr << std::endl << \"Correct:\" << std::endl;\n        ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]);\n        std::cerr << std::endl;\n        std::vector<const ggml_tensor *> done;\n        ggml_vk_print_graph_origin(tensor, done);\n        GGML_ABORT(\"fatal error\");\n    } else {\n        
std::cerr << check_counter << \" \" << tensor->name << \" op=\" << ggml_op_name(tensor->op) << \" avg_err=\" << avg_err << std::endl;\n    }\n\n    free(comp_result);\n    comp_result = nullptr;\n    comp_size = 0;\n\n    if (ggml_backend_buffer_is_vk(tensor->buffer)) {\n        free(tensor_data);\n    }\n\n    VK_LOG_DEBUG(\"END ggml_vk_check_results_1(\" << tensor->name << \")\");\n}\n#endif\n\nGGML_BACKEND_DL_IMPL(ggml_backend_vk_reg)\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.19)\nproject(\"vulkan-shaders-gen\" C CXX)\n\nfind_package (Threads REQUIRED)\n\nif (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n    add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n    message(STATUS \"Enabling coopmat glslc support\")\nendif()\nif (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n    add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n    message(STATUS \"Enabling coopmat2 glslc support\")\nendif()\nif (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n    add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n    message(STATUS \"Enabling dot glslc support\")\nendif()\nif (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n    add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n    message(STATUS \"Enabling bfloat16 glslc support\")\nendif()\nif (GGML_VULKAN_SHADER_DEBUG_INFO)\n    add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)\n    message(STATUS \"Enabling shader debug info\")\nendif()\n\nset(TARGET vulkan-shaders-gen)\nadd_executable(${TARGET} vulkan-shaders-gen.cpp)\ninstall(TARGETS ${TARGET} RUNTIME)\ntarget_compile_features(${TARGET} PRIVATE cxx_std_17)\ntarget_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/abs.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    data_d[i] = D_TYPE(abs(float(data_a[i])));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/acc.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\n// false for SET, true for ACC\nlayout(constant_id = 1) const bool ACC = true;\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = gl_GlobalInvocationID.x;\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const uint offset = p.param3;\n    const uint src1_i = idx - offset;\n    const uint i3 = src1_i / p.nb03;\n    const uint rem2 = src1_i - i3 * p.nb03;\n    const uint i2 = rem2 / p.nb02;\n    const uint rem1 = rem2 - i2 * p.nb02;\n    const uint i1 = rem1 / p.nb01;\n    const uint i0 = rem1 % p.nb01;\n\n    uint i00, i01, i02, i03;\n\n    if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {\n        if (ACC) {\n            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));\n        } else {\n            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));\n        }\n    } else {\n        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/add.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_16bit_storage : require\n#if ADD_RMS\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_basic : enable\n#endif\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\nconst uint num_threads = 256;\n\nlayout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};\n\nlayout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;\n\n#if ADD_RMS\n// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant\nshared FLOAT_TYPE sumsh[num_threads];\n#endif\n\nvoid main() {\n    uint idx = get_idx();\n    uint orig_idx = idx;\n\n    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation\n    const uint num_iter = 2;\n\n    FLOAT_TYPE sum_sq = 0;\n\n    [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n        if (idx >= p.ne) {\n            continue;\n        }\n        uint i00, i01, i02, i03;\n        get_indices(idx, i00, i01, i02, i03);\n\n        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);\n        sum_sq += sum*sum;\n\n        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);\n\n        idx += num_threads;\n    }\n\n#if ADD_RMS\n    if (p.param3 != 0) {\n        // reduce the sum within each subgroup, then across subgroups\n        const uint NumSubgroups = num_threads / gl_SubgroupSize;\n        sum_sq = subgroupAdd(sum_sq);\n        if (gl_SubgroupInvocationID == 0) {\n            sumsh[gl_SubgroupID] = sum_sq;\n        }\n        barrier();\n        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {\n            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {\n                sum_sq += sumsh[gl_SubgroupID + s];\n                sumsh[gl_SubgroupID] = sum_sq;\n            }\n            barrier();\n        }\n\n        if
 (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {\n            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;\n        }\n    }\n#endif\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/add1.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_16bit_storage : require\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\nconst uint num_threads = 256;\n\nlayout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    uint idx = get_idx();\n\n    const uint num_iter = 2;\n\n    [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n        if (idx >= p.ne) {\n            continue;\n        }\n        uint i00, i01, i02, i03;\n        get_indices(idx, i00, i01, i02, i03);\n\n        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset()]));\n\n        idx += num_threads;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/add_id.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : require\n\n#include \"types.glsl\"\n\nlayout (push_constant) uniform parameter\n{\n    uint ne0;\n    uint ne1;\n    uint s01;\n    uint s02;\n    uint s11;\n    uint s21;\n} p;\n\n#define BLOCK_SIZE 512\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) readonly buffer Y {B_TYPE data_b[];};\nlayout (binding = 2) readonly buffer Z {int32_t data_c[];};\nlayout (binding = 3) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i1 = gl_WorkGroupID.x;\n    const uint i2 = gl_WorkGroupID.y;\n\n    const uint i11 = data_c[i1 + i2 * p.s21];\n\n    const uint s1 = p.ne0;\n    const uint s2 = p.ne0 * p.ne1;\n\n    const uint d0 = i1 * s1 + i2 * s2;\n    const uint a0 = i1 * p.s01 + i2 * p.s02;\n    const uint b0 = i11 * p.s11;\n\n    for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {\n        data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/arange.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    // p.param1 = start, p.param2 = step\n    float value = p.param1 + p.param2 * float(i);\n    data_d[i] = D_TYPE(value);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/argmax.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\n#define FLT_MAX 3.402823466e+38F\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 32;\n\nshared FLOAT_TYPE tmpmax[BLOCK_SIZE];\nshared uint tmp[BLOCK_SIZE];\n\nvoid main() {\n    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n    const uint col = gl_LocalInvocationID.x;\n\n    if (row >= p.KY) {\n        return;\n    }\n\n    A_TYPE amax = -FLT_MAX;\n    uint acol = col;\n\n    if (col < p.KX) {\n        amax = data_a[row*p.KX + col];\n    }\n\n    for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {\n        A_TYPE val = data_a[row*p.KX + i];\n        if (val > amax) {\n            amax = val;\n            acol = i;\n        }\n    }\n\n    tmp[col] = acol;\n    tmpmax[col] = amax;\n\n    barrier();\n    [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {\n        if (col < s && col + s < p.KX) {\n            if (tmpmax[col] < tmpmax[col + s]) {\n                tmpmax[col] = tmpmax[col + s];\n                tmp[col] = tmp[col + s];\n            }\n        }\n        barrier();\n    }\n\n    if (col == 0) {\n        data_d[row] = D_TYPE(tmp[0]);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/argsort.comp",
    "content": "#version 450\n#extension GL_EXT_control_flow_attributes : enable\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const int BLOCK_SIZE = 1024;\nlayout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;\n#define ASC 0\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 2) writeonly buffer D {int data_d[];};\n\nlayout (push_constant) uniform parameter {\n    uint ncols;\n    uint ncols_padded;\n    uint ncols_padded_log2;\n    uint nrows;\n    uint order;\n    uint outer_start;\n    uint outer_end;\n    uint inner_start;\n    uint inner_end;\n} p;\n\nshared ivec2 dst_row[BLOCK_SIZE];\n\nvoid argsort(bool needs_bounds_check, const uint row) {\n    // bitonic sort\n    const int col = int(gl_LocalInvocationID.x);\n\n    const uint row_offset = row * p.ncols;\n\n    // initialize indices\n    dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));\n    barrier();\n\n    uint num_outer_loop_iters = NCOLS_PADDED_LOG2;\n    [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {\n        uint num_inner_loop_iters = outer_idx + 1;\n        [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {\n            const int ixj = int(col ^ j);\n\n            int idx_0 = (col & k) == 0 ? col : ixj;\n            int idx_1 = (col & k) == 0 ? ixj : col;\n\n            ivec2 sh_idx_0 = dst_row[idx_0];\n            ivec2 sh_idx_1 = dst_row[idx_1];\n            bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;\n            bool idx_1_oob = needs_bounds_check ? 
sh_idx_1.x >= p.ncols : false;\n\n            if ((idx_0_oob ||\n                (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {\n                dst_row[idx_0] = sh_idx_1;\n                dst_row[idx_1] = sh_idx_0;\n            }\n\n            barrier();\n        }\n    }\n\n    if (col < p.ncols) {\n        if (p.order == ASC) {\n            data_d[row_offset + col] = dst_row[col].x;\n        } else {\n            data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;\n        }\n    }\n}\n\nvoid main() {\n    if (p.ncols == BLOCK_SIZE) {\n        uint row = gl_WorkGroupID.y;\n        while (row < p.nrows) {\n            argsort(false, row);\n            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n        }\n    } else {\n        uint row = gl_WorkGroupID.y;\n        while (row < p.nrows) {\n            argsort(true, row);\n            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/argsort_large.comp",
    "content": "#version 450\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_KHR_memory_scope_semantics : enable\n#pragma use_vulkan_memory_model\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const int BLOCK_SIZE = 1024;\nlayout(constant_id = 1) const int WG_UNROLL_FACTOR = 2;\n#define ASC 0\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];};\nlayout (binding = 2) workgroupcoherent buffer D {int data_d[];};\n\nlayout (push_constant) uniform parameter {\n    uint ncols;\n    uint ncols_padded;\n    uint ncols_padded_log2;\n    uint nrows;\n    uint order;\n    uint outer_start;\n    uint outer_end;\n    uint inner_start;\n    uint inner_end;\n} p;\n\nvoid argsort(bool needs_bounds_check, const uint row) {\n    // bitonic sort\n    int col = int(gl_GlobalInvocationID.x);\n    col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR;\n\n    const uint row_offset = row * p.ncols;\n    uint idx_offset = row * p.ncols_padded;\n\n    bool need_barrier = false;\n\n    // initialize indices\n    if (p.outer_start == 0 && p.inner_start == 0) {\n        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {\n            uint c = u*BLOCK_SIZE + col;\n            if (c < p.ncols_padded) {\n                ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c]));\n                tmp_idx[idx_offset + c] = v;\n            }\n        }\n        need_barrier = true;\n    }\n\n    [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) {\n        uint inner_end = min(p.inner_end, outer_idx + 1);\n        for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) {\n            if (need_barrier) {\n                controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, 
gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);\n            }\n            need_barrier = true;\n            [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {\n                int c = u*BLOCK_SIZE + col;\n                const int ixj = int(c ^ j);\n\n                if (ixj < c) {\n                    continue;\n                }\n\n                int idx_0 = (c & k) == 0 ? c : ixj;\n                int idx_1 = (c & k) == 0 ? ixj : c;\n\n                ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0];\n                ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1];\n                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;\n                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;\n\n                if ((idx_0_oob ||\n                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) {\n                    tmp_idx[idx_offset + idx_0] = sh_idx_1;\n                    tmp_idx[idx_offset + idx_1] = sh_idx_0;\n                }\n            }\n        }\n    }\n\n    if (p.outer_end == p.ncols_padded_log2 &&\n        p.inner_end >= p.ncols_padded_log2 + 1) {\n        controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);\n        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {\n            uint c = u*BLOCK_SIZE + col;\n            if (c < p.ncols) {\n                if (p.order == ASC) {\n                    data_d[row_offset + c] = tmp_idx[idx_offset + c].x;\n                } else {\n                    data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x;\n                }\n            }\n        }\n    }\n}\n\nvoid main() {\n    if (p.ncols == p.ncols_padded) {\n        uint row = gl_WorkGroupID.y;\n        while (row < p.nrows) {\n            argsort(false, row);\n            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n        }\n    } else {\n        uint row = gl_WorkGroupID.y;\n        while 
(row < p.nrows) {\n            argsort(true, row);\n            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/ceil.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    data_d[i] = D_TYPE(ceil(x));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/clamp.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/concat.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n    const int dim = p.param3;\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const uint i3 = idx / (p.ne22*p.ne21*p.ne20);\n    const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;\n    const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);\n    const uint i2_offset = i2*p.ne21*p.ne20;\n    const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;\n    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;\n\n    uint o[4] = {0, 0, 0, 0};\n    o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));\n\n    const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;\n    const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;\n    const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;\n\n    const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;\n\n#ifndef OPTIMIZATION_ERROR_WORKAROUND\n    data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);\n#else\n    if (is_src0) {\n        data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];\n    } else {\n        data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];\n    }\n#endif\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/contig_copy.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\n#extension GL_EXT_control_flow_attributes : require\n\nconst uint num_threads = 128;\n\nlayout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    uint idx = get_idx();\n\n    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation\n    const uint num_iter = 4;\n\n    // fast path for when all four iterations are in-bounds\n    if (idx + (num_iter-1)*num_threads < p.ne) {\n        [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n\n#if defined(DATA_D_BF16)\n            float f = float(data_a[get_aoffset() + idx]);\n            data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));\n#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)\n            data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);\n#else\n            data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];\n#endif\n            idx += num_threads;\n        }\n    } else {\n        [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n            if (idx >= p.ne) {\n                continue;\n            }\n\n#if defined(DATA_D_BF16)\n            float f = float(data_a[get_aoffset() + idx]);\n            data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));\n#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)\n            data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);\n#else\n            data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];\n#endif\n            idx += num_threads;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n\nlayout (push_constant) uniform parameter\n{\n    uint ne;\n    uint batches;\n    uint channels;\n    uint dst_w;\n    uint dst_h;\n    uint src_w;\n    uint src_h;\n    uint knl_w;\n    uint knl_h;\n    int stride_x;\n    int stride_y;\n    int pad_x;\n    int pad_y;\n    int dilation_x;\n    int dilation_y;\n} p;\n\nlayout (binding = 0) readonly buffer A {A_TYPE knl_data[];};\nlayout (binding = 1) readonly buffer B {B_TYPE src_data[];};\nlayout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE conv_2d_dw_whcn(uint idx) {\n    uint i0 = idx / p.dst_w;\n    uint dst_x = idx - i0 * p.dst_w;\n    uint i1 = i0 / p.dst_h;\n    uint dst_y = i0 - i1 * p.dst_h;\n    uint n = i1 / p.channels;\n    uint c = i1 - n * p.channels;\n\n    uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;\n    uint knl_i = c * p.knl_h * p.knl_w;\n\n    FLOAT_TYPE sum = 0.0;\n    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {\n        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;\n        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int\n            continue;\n        }\n        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {\n            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;\n            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int\n                continue;\n            }\n            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);\n            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);\n            sum = fma(v, k, sum);\n        }\n    }\n    return sum;\n}\n\nFLOAT_TYPE conv_2d_dw_cwhn(uint idx) {\n    uint i0 = idx / p.channels;\n    uint c = idx - i0 * p.channels;\n    uint i1 = i0 / p.dst_w;\n    uint dst_x = i0 - i1 * p.dst_w;\n    uint n = i1 / p.dst_h;\n    uint dst_y = i1 
- n * p.dst_h;\n\n    uint src_i = n * p.channels * p.src_h * p.src_w;\n    uint src_row = p.src_w * p.channels;\n    uint knl_row = p.knl_w * p.channels;\n\n    FLOAT_TYPE sum = 0.0;\n    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {\n        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;\n        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int\n            continue;\n        }\n        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {\n            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;\n            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int\n                continue;\n            }\n            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);\n            FLOAT_TYPE k = FLOAT_TYPE(knl_data[        knl_y * knl_row + knl_x * p.channels + c]);\n            sum = fma(v, k, sum);\n        }\n    }\n    return sum;\n}\n\nvoid main() {\n    uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n    if (idx >= p.ne) {\n        return;\n    }\n\n    FLOAT_TYPE result =\n#ifdef WHCN\n        conv_2d_dw_whcn(idx);\n#else\n        conv_2d_dw_cwhn(idx);\n#endif\n    dst_data[idx] = D_TYPE(result);\n}\n\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#ifdef COOPMAT2\n#extension GL_NV_cooperative_matrix2 : enable\n#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require\n#extension GL_KHR_memory_scope_semantics : enable\n#endif\n\n#ifdef USE_COLLECTIVES\n#    extension GL_KHR_shader_subgroup_shuffle : enable\n#endif\n\n#include \"types.glsl\"\n\n// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j\nlayout(binding = 0) readonly buffer A {\n    A_TYPE knl_data[];\n};  // src0 - kernel:   [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d\n\nlayout(binding = 1) readonly buffer B {\n    B_TYPE src_data[];\n};  // src1 - input:    [W, H, Cin, N] -- channel_first format\n\nlayout(binding = 2) writeonly buffer D {\n    D_TYPE dst_data[];\n};  // dst - result:    [OW, OH, Cout, N]\n\nlayout(push_constant) uniform parameter {\n    // I/O channels, batch size\n    uint32_t Cout;\n    uint32_t Cin;\n    uint32_t N;\n\n    // Tensor spatial sizes: input, output\n    uint32_t W;\n    uint32_t H;\n    uint32_t OW;\n    uint32_t OH;\n\n    // Strides in elements\n    uint32_t nb01;\n    uint32_t nb02;\n    uint32_t nb03;\n\n    uint32_t nb11;\n    uint32_t nb12;\n    uint32_t nb13;\n\n    uint32_t nb1;\n    uint32_t nb2;\n    uint32_t nb3;\n\n    // fastdiv helper values\n    uint32_t OWmp;   uint32_t OWL;\n    uint32_t OWOHmp; uint32_t OWOHL;\n}\n\np;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n// Blocktile sizes\nlayout(constant_id = 1) const uint BS_K            = 128;\nlayout(constant_id = 2) const uint BS_CRS          = 16;\nlayout(constant_id = 3) const uint BS_NPQ          = 128;\n// Thread-tile sizes\nlayout(constant_id = 4) const uint TS_K            = 8;\nlayout(constant_id = 5) const uint use_collectives = 1;\nlayout(constant_id = 6) const uint SHMEM_PAD       = 4;\n// Stride, padding, dilation\nlayout(constant_id = 7)  const uint s0     
        = 1;\nlayout(constant_id = 8)  const uint s1             = 1;\nlayout(constant_id = 9)  const uint p0             = 0;\nlayout(constant_id = 10) const uint p1             = 0;\nlayout(constant_id = 11) const uint d0             = 1;\nlayout(constant_id = 12) const uint d1             = 1;\n// Kernel spatial sizes\nlayout(constant_id = 13) const uint KW             = 1;\nlayout(constant_id = 14) const uint KH             = 1;\n\nuint32_t       tid     = gl_LocalInvocationID.x;\nconst uint32_t WG_SIZE = gl_WorkGroupSize.x;\n\nuint splitWork(uint work_size, uint block_size) {\n    return (block_size + work_size - 1) / block_size;\n}\n\nuint32_t K   = p.Cout;\nuint32_t CRS = p.Cin * KH * KW;\nuint32_t NPQ = p.N * p.OH * p.OW;\n\nuint32_t n_elems_out = K * NPQ;\n\n// Number of blocktiles per input\nuint32_t NB_CRS = splitWork(CRS, BS_CRS);\n\n#ifdef COOPMAT2\n#define SHMEM_TYPE float16_t\n#else\n#define SHMEM_TYPE float\n#endif\n\nconst uint32_t Ash_stride = BS_CRS + SHMEM_PAD;\nconst uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;\n\nconst uint32_t Ash_numel = BS_K * BS_CRS;\nconst uint32_t Bsh_numel = BS_CRS * BS_NPQ;\n\nconst uint32_t Ash_len = BS_K * Ash_stride;\nconst uint32_t Bsh_len = BS_CRS * Bsh_stride;\n\nshared SHMEM_TYPE Ash[Ash_len];  // K x CRS\nshared SHMEM_TYPE Bsh[Bsh_len];  // CRS x NPQ\n\n// Threadtile sizes\nconst uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;\n\n// Number of threadtiles per blocktile\nconst uint32_t NT_K   = BS_K / TS_K;\nconst uint32_t NT_NPQ = BS_NPQ / TS_NPQ;\n\n/*\nCompute\nKxCRS @ CRSxNPQ = K x NPQ\nK=Cout\nC=Cin\nR,S=KH,KW\nP,Q=OH,OW\n*/\n\nuint32_t B_idx_K   = gl_WorkGroupID.x;\nuint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;\n\nuint32_t T_y = tid / NT_NPQ;\nuint32_t T_x = tid % NT_NPQ;\n\nuint32_t       Ar    = tid / BS_CRS;\nuint32_t       Ac    = tid % BS_CRS;\nconst uint32_t ArpWg = WG_SIZE / BS_CRS;\n\nuint32_t       Br    = tid / BS_NPQ;\nuint32_t       Bc    = tid % BS_NPQ;\nconst uint32_t BrpWg = 
WG_SIZE / BS_NPQ;\n\n// see init_fastdiv_values in ggml-vulkan.cpp\nuint fastdiv(uint n, uint mp, uint L) {\n    uint msbs, lsbs;\n    // msbs = mulhi(n, mp)\n    umulExtended(n, mp, msbs, lsbs);\n    return (msbs + n) >> L;\n}\n\n#ifdef COOPMAT2\n#define ACC_TYPE float16_t\n\nACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)\n{\n    uint32_t K_idx   = B_idx_K * BS_K + r;\n    uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;\n    uint32_t N_idx   = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;\n    uint32_t OH_idx  = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;\n    uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;\n    uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;\n    if (K_idx < K && NPQ_idx < NPQ) {\n        dst_data[dst_idx] = D_TYPE(elem);\n    }\n    return elem;\n}\n#endif\n\nvoid main() {\n    if (B_idx_NPQ * BS_NPQ >= NPQ) {\n        return;\n    }\n\n#ifdef COOPMAT2\n    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;\n    matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);\n#else\n    float regC[TS_K][TS_NPQ];\n    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {\n        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {\n            regC[T_ly][T_lx] = 0.0;\n        }\n    }\n#endif\n    /* Advance block in CRS dim */\n    [[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {\n        uint32_t CRS_idx_a;\n        uint32_t Cin_idx_a;\n        uint32_t KH_idx_a;\n        uint32_t KW_idx_a;\n\n#ifdef USE_COLLECTIVES\n        uint32_t cached_CRS_idx;\n        uint32_t cached_Cin_idx;\n        uint32_t cached_KH_idx;\n        uint32_t cached_KW_idx;\n        if (use_collectives == 1) {\n            cached_CRS_idx                = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;\n            cached_Cin_idx                = cached_CRS_idx / (KW * 
KH);\n            uint32_t cached_CRS_remainder = cached_CRS_idx % (KW * KH);\n            cached_KH_idx                 = cached_CRS_remainder / KW;\n            cached_KW_idx                 = cached_CRS_remainder % KW;\n\n            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);\n            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);\n            KH_idx_a  = subgroupShuffle(cached_KH_idx, Ac);\n            KW_idx_a  = subgroupShuffle(cached_KW_idx, Ac);\n        } else {\n            CRS_idx_a              = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)\n            Cin_idx_a              = CRS_idx_a / (KW * KH);\n            uint32_t CRS_remainder = CRS_idx_a % (KW * KH);\n            KH_idx_a               = CRS_remainder / KW;\n            KW_idx_a               = CRS_remainder % KW;\n        }\n#else\n        CRS_idx_a     = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)\n        Cin_idx_a     = CRS_idx_a / (KW * KH);\n        CRS_remainder = CRS_idx_a % (KW * KH);\n        KH_idx_a      = CRS_remainder / KW;\n        KW_idx_a      = CRS_remainder % KW;\n#endif\n\n        /* Load kernel to A_block: (BS_K x BS_CRS)*/\n        UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {\n            uint32_t B_ly    = r_offset + Ar;\n            uint32_t B_lx    = Ac;\n            uint32_t K_idx   = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/\n#ifdef TRANSPOSE\n            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);\n#else\n            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);\n#endif\n            float    val     = knl_data[knl_idx];\n            if (K_idx >= K || CRS_idx_a >= CRS) {\n                val = 0.0;\n            }\n            Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);\n        }\n        /* Load input to B_block: (BS_CRS x BS_NPQ) */\n        
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {\n            uint32_t B_ly          = r_offset + Br;             /* Row index of B block */\n            uint32_t B_lx          = Bc;\n            uint32_t NPQ_idx       = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */\n            uint32_t N_idx         = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;\n            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;\n            uint32_t OH_idx        = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;\n            uint32_t OW_idx        = NPQ_remainder - OH_idx * p.OW;\n\n            uint32_t CRS_idx_b;\n            uint32_t Cin_idx_b;\n            uint32_t KH_idx_b;\n            uint32_t KW_idx_b;\n#ifdef USE_COLLECTIVES\n            if (use_collectives == 1) {\n                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);\n                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);\n                KH_idx_b  = subgroupShuffle(cached_KH_idx, r_offset + Br);\n                KW_idx_b  = subgroupShuffle(cached_KW_idx, r_offset + Br);\n            } else {\n                CRS_idx_b              = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */\n                Cin_idx_b              = CRS_idx_b / (KW * KH);\n                uint32_t CRS_remainder = CRS_idx_b % (KW * KH);\n                KH_idx_b               = CRS_remainder / KW;\n                KW_idx_b               = CRS_remainder % KW;\n            }\n#else\n            CRS_idx_b              = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */\n            Cin_idx_b              = CRS_idx_b / (KW * KH);\n            uint32_t CRS_remainder = CRS_idx_b % (KW * KH);\n            KH_idx_b               = CRS_remainder / KW;\n            KW_idx_b               = CRS_remainder % KW;\n#endif\n\n#ifdef TRANSPOSE\n            uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * d1 + p1;\n  
          uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * d0 + p0;\n            uint32_t H_idx = H_idx_x_s1 / s1;\n            uint32_t W_idx = W_idx_x_s0 / s0;\n#else\n            uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1;\n            uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0;\n#endif\n            uint32_t src_idx =\n                min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);\n            float val = src_data[src_idx];\n            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ\n                || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)\n#ifdef TRANSPOSE\n                || (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0)\n#endif\n                ) {\n                val = 0.0;\n            }\n            Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);\n        }\n        barrier();\n#ifdef COOPMAT2\n        coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;\n        coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;\n\n        coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);\n        coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);\n        matC = coopMatMulAdd(matA, matB, matC);\n#else\n        if (T_y * TS_K < K) {\n            UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {\n                float regA[TS_K];\n                float regB[TS_NPQ];\n                for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {\n                    regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];\n                }\n                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {\n                    regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];\n                }\n                for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {\n                    for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) 
{\n                        regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);\n                    }\n                }\n            }\n        }\n#endif\n        barrier();\n    }\n    /* Save C* */\n#ifdef COOPMAT2\n    coopMatPerElementNV(matC, matC, perElemOpStore);\n#else\n    if (T_y * TS_K < K) {\n        for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {\n            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {\n                uint32_t K_idx   = B_idx_K * BS_K + T_y * TS_K + T_ly;\n                uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;\n                uint32_t N_idx   = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;\n                uint32_t OH_idx  = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;\n                uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;\n                uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;\n                if (K_idx < K && NPQ_idx < NPQ) {\n                    dst_data[dst_idx] = regC[T_ly][T_lx];\n                }\n            }\n        }\n    }\n#endif\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};   // src0 - kernel:    [K, Cout, Cin]\nlayout (binding = 1) readonly buffer B {B_TYPE data_b[];};   // src1 - input:     [L, Cin]\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};     // dst - result      [KL, Cout]\n\nlayout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;\n\nlayout (push_constant) uniform parameter {\n    uint32_t Cout;\n    uint32_t Cin;\n    uint32_t K;\n    uint32_t L;\n    uint32_t KL;\n\n    uint32_t nb01;\n    uint32_t nb02;\n    uint32_t nb11;\n    uint32_t nb1;\n\n    int32_t s0;\n} p;\n\n\nuint32_t Cout_idx = gl_WorkGroupID.x;\nconst uint32_t bs = gl_WorkGroupSize.x;\nuint32_t tid = gl_LocalInvocationID.x;\n// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.\nuint32_t tmp_len = bs*p.s0+p.K;\nshared D_TYPE tmp[4096];\n\nuint splitWork(uint workSize){\n    return (bs + workSize -1) / bs;\n}\n\nvoid main(){\n    for(uint32_t i = 0; i < splitWork(tmp_len); i++){\n        uint32_t idx = i*bs+tid;\n        if(idx < tmp_len){\n            tmp[idx] = 0.0;\n        }\n    }\n\n    uint32_t L_blocks = splitWork(p.L);\n    for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){\n        if(L_block_id > 0){\n            barrier();\n            // Shift values in tmp to the current processing window\n            for(int i = 0; i < splitWork(tmp_len); i++){\n                uint32_t idx = i*bs+tid;\n                if(idx >= bs*p.s0 && idx < tmp_len){\n                    tmp[idx-bs*p.s0] = tmp[idx];\n                    tmp[idx] = 0.0;\n                }else if(idx >= p.K && idx < bs*p.s0){\n                    tmp[idx] = 0.0;\n                }\n            }\n        }\n        barrier();\n\n        // Save contributions of the block to tmp\n        uint32_t L_idx = L_block_id*bs + tid;\n        for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){\n            D_TYPE dp = 
0.0;\n            for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){\n                A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];\n                if(L_idx < p.L){\n                    B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];\n                    dp = fma(elemKrn, elemInp, dp);\n                }\n            }\n            tmp[tid*p.s0 + K_idx] += dp;\n            barrier();\n        }\n\n        // Save the computed values except the last block that can have different size\n        uint32_t KLb_idx = L_block_id*bs*p.s0;\n        if(L_block_id < L_blocks-1){\n            for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){\n                uint32_t sh_idx = p.s0*tid+s0_idx;\n                uint32_t KL_idx = KLb_idx+sh_idx;\n                if(KL_idx < p.KL){\n                    data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];\n                }\n            }\n        }\n    }\n\n    for(uint32_t i = 0; i < splitWork(tmp_len); i++){\n        uint32_t idx = i*bs+tid;\n        uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;\n        if(KL_idx < p.KL){\n            data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/copy.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n#if defined(DATA_D_BF16)\n    float f = float(data_a[get_aoffset() + src0_idx(idx)]);\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));\n#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);\n#else\n    data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];\n#endif\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n#include \"dequant_funcs.glsl\"\n\n#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)\n// 16 invocations needed for init_iq_shmem\nlayout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;\n#else\nlayout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;\n#endif\n\nvoid main() {\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n    if (gl_LocalInvocationIndex.x != 0) {\n        return;\n    }\n#endif\n\n    const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    uint dst_idx = get_doffset() + dst_idx(idx);\n    uint src_idx = src0_idx_quant(idx, QUANT_K);\n\n    const uint a_offset = 0;\n    const uint ib = src_idx;\n    const vec2 dm = get_dm(ib, a_offset);\n\n    [[unroll]] for (int j = 0; j < QUANT_K; j += 4) {\n        vec4 v = dequantize4(ib, j / QUANT_R, a_offset);\n        v = v * dm.x + vec4(dm.y);\n\n#if QUANT_R == 2\n        data_d[dst_idx + j/2 +             0] = v[0];\n        data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1];\n        data_d[dst_idx + j/2 +             1] = v[2];\n        data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3];\n#else\n        data_d[dst_idx + j + 0] = v[0];\n        data_d[dst_idx + j + 1] = v[1];\n        data_d[dst_idx + j + 2] = v[2];\n        data_d[dst_idx + j + 3] = v[3];\n#endif\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp",
    "content": "#version 450\n\n#include \"rte.glsl\"\n#include \"types.glsl\"\n\n#if defined(SET_ROWS) && QUANT_K == 1\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\nconst uint BLOCK_SIZE = 512;\n#else\nlayout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;\nconst uint BLOCK_SIZE = 32;\n#endif\n\nlayout (binding = 0) readonly buffer S {float data_s[];};\n\n#if defined(SET_ROWS)\n#include \"generic_binary_head.glsl\"\nlayout (binding = 1) readonly buffer C {B_TYPE data_i[];};\nlayout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};\n\n#if B_SIZE == 64\n#define DATA_I_SWIZZLE .x\n#else\n#define DATA_I_SWIZZLE\n#endif\n\n#else\n#include \"generic_unary_head.glsl\"\nlayout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};\n#endif\n\n#if defined(DATA_A_Q4_0)\nvoid quantize(uint dst_idx, uint src_idx)\n{\n    float amax = 0.0;\n    float vmax = 0.0;\n\n    [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {\n        const float v = data_s[src_idx + j];\n        if (amax < abs(v)) {\n            amax = abs(v);\n            vmax = v;\n        }\n    }\n\n    const float d  = vmax / -8;\n    const float id = (d != 0.0) ? 1.0/d : 0.0;\n\n    data_q[dst_idx].d = float16_t(d);\n\n    [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {\n        const float x0 = data_s[src_idx + 0              + j]*id;\n        const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;\n\n        const uint xi0 = min(15, int(x0 + 8.5));\n        const uint xi1 = min(15, int(x1 + 8.5));\n\n        data_q[dst_idx].qs[j]  = uint8_t(xi0 | (xi1 << 4));\n    }\n}\n#endif\n\n#if defined(DATA_A_Q4_1)\nvoid quantize(uint dst_idx, uint src_idx)\n{\n    float vmin = 1.0/0.0;\n    float vmax = -vmin;\n\n    [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {\n        const float v = data_s[src_idx + j];\n\n        if (v < vmin) vmin = v;\n        if (v > vmax) vmax = v;\n    }\n\n    const float d  = (vmax - vmin) / ((1 << 4) - 1);\n    const float id = (d != 0.0) ? 
1.0/d : 0.0;\n\n    data_q[dst_idx].d = float16_t(d);\n    data_q[dst_idx].m = float16_t(vmin);\n\n    [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {\n        const float x0 = (data_s[src_idx + 0              + j] - vmin)*id;\n        const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;\n\n        const uint xi0 = min(15, int(x0 + 0.5));\n        const uint xi1 = min(15, int(x1 + 0.5));\n\n        data_q[dst_idx].qs[j]  = uint8_t(xi0 | (xi1 << 4));\n    }\n}\n#endif\n\n#if defined(DATA_A_Q5_0)\nvoid quantize(uint dst_idx, uint src_idx)\n{\n    float amax = 0.0;\n    float vmax = 0.0;\n\n    [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {\n        const float v = data_s[src_idx + j];\n        if (amax < abs(v)) {\n            amax = abs(v);\n            vmax = v;\n        }\n    }\n\n    const float d  = vmax / -16;\n    const float id = (d != 0.0) ? 1.0/d : 0.0;\n\n    data_q[dst_idx].d = float16_t(d);\n\n    uint32_t qh = 0;\n    [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {\n        const float x0 = data_s[src_idx + 0              + j]*id;\n        const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;\n\n        const uint xi0 = min(31, int(x0 + 16.5));\n        const uint xi1 = min(31, int(x1 + 16.5));\n\n        data_q[dst_idx].qs[j]  = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));\n        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n        qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);\n    }\n    data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);\n    data_q[dst_idx].qh[1] = uint16_t(qh >> 16);\n}\n#endif\n\n#if defined(DATA_A_Q5_1)\nvoid quantize(uint dst_idx, uint src_idx)\n{\n    float min = data_s[src_idx + 0];\n    float max = min;\n\n    [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {\n        const float v = data_s[src_idx + j];\n        min = v < min ? v : min;\n        max = v > max ? v : max;\n    }\n\n    const float d  = (max - min) / 31;\n    const float id = (d != 0) ? 
1.0/d : 0.0;\n\n    data_q[dst_idx].d = float16_t(d);\n    data_q[dst_idx].m = float16_t(min);\n\n    uint32_t qh = 0;\n    [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {\n        const float x0 = (data_s[src_idx + 0              + j] - min)*id;\n        const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;\n\n        const uint xi0 = uint(x0 + 0.5);\n        const uint xi1 = uint(x1 + 0.5);\n\n        data_q[dst_idx].qs[j]  = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));\n        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);\n        qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);\n    }\n    data_q[dst_idx].qh = qh;\n}\n#endif\n\n#if defined(DATA_A_Q8_0)\nvoid quantize(uint dst_idx, uint src_idx)\n{\n    float amax = 0.0; // absolute max\n\n    [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {\n        const float v = data_s[src_idx + j];\n        amax = max(amax, abs(v));\n    }\n\n    const float d = amax / ((1 << 7) - 1);\n    const float id = (d != 0.0) ? 1.0/d : 0.0;\n\n    data_q[dst_idx].d = float16_t(d);\n\n    [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {\n        const float x0 = data_s[src_idx + j]*id;\n\n        data_q[dst_idx].qs[j] = int8_t(round(x0));\n    }\n}\n#endif\n\n#if defined(DATA_A_IQ4_NL)\nuint best_index(float x) {\n    if (x <= kvalues_iq4nl[0]) return 0;\n    if (x >= kvalues_iq4nl[15]) return 15;\n    int ml = 0, mu = 15;\n    while (mu-ml > 1) {\n        int mav = (ml+mu)/2;\n        if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;\n    }\n    return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;\n}\n\nvoid quantize(uint dst_idx, uint src_idx)\n{\n    float amax = 0.0;\n    float vmax = 0.0;\n\n    [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {\n        const float v = data_s[src_idx + j];\n        if (amax < abs(v)) {\n            amax = abs(v);\n            vmax = v;\n        }\n    }\n\n    float d = vmax / kvalues_iq4nl[0];\n    const float id = (d != 0.0) ? 
1.0/d : 0.0;\n\n    float sumqx = 0, sumq2 = 0;\n    [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {\n        const float x0 = data_s[src_idx + 0                + j]*id;\n        const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;\n        const uint xi0 = best_index(x0);\n        const uint xi1 = best_index(x1);\n        data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));\n        const float v0 = kvalues_iq4nl[xi0];\n        const float v1 = kvalues_iq4nl[xi1];\n        const float w0 = data_s[src_idx + 0                + j]*data_s[src_idx + 0                + j];\n        const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];\n        sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];\n        sumq2 += w0*v0*v0 + w1*v1*v1;\n    }\n\n    data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);\n\n}\n#endif\n\n#if defined(DATA_A_F32) || defined(DATA_A_F16)\nvoid quantize(uint dst_idx, uint src_idx)\n{\n    data_q[dst_idx] = A_TYPE(data_s[src_idx]);\n}\n#endif\n\n#if defined(DATA_A_BF16)\nvoid quantize(uint dst_idx, uint src_idx)\n{\n    data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));\n}\n#endif\n\n#if defined(SET_ROWS)\n\nvoid main() {\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n    const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    uint i00, i01, i02, i03;\n    get_indices(idx, i00, i01, i02, i03);\n\n    uint i12 = fastmod(i03, p.ne12);\n    uint i11 = fastmod(i02, p.ne11);\n    uint i10 = i01;\n\n    uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;\n\n    uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();\n    uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();\n\n    quantize(dst_idx, src0_idx);\n}\n\n#else\n\nvoid main() {\n#ifdef 
NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n    const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    uint dst_idx = dst_idx_quant(idx, QUANT_K);\n    uint src_idx = get_aoffset() + src0_idx(idx);\n\n    quantize(dst_idx, src_idx);\n}\n\n#endif\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/copy_transpose.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\n// workgroup does 32x32 tile, but uses 32x8 threads\n#define TILE_DIM 32\nlayout(local_size_x = 32, local_size_y = 8, local_size_z = 1) in;\n\nshared uint sh[TILE_DIM][TILE_DIM + 1];\n\nvoid iter(uvec3 wg_id) {\n    const uint tile_col = wg_id.x;\n    const uint tile_row = wg_id.y;\n\n    const uint tid_col = gl_LocalInvocationID.x;\n    const uint tid_row = gl_LocalInvocationID.y;\n\n    const uint i2 = wg_id.z % p.ne12;\n    const uint i3 = wg_id.z / p.ne12;\n    const uint i02 = i2;\n    const uint i03 = i3;\n\n    // The workgroup does TILE_DIM x TILE_DIM, but swaps the LSBs of the\n    // src coords to make memory accesses contiguous, dst has tid.x in i0,\n    // src has tid.x in i01\n\n    [[unroll]] for (uint y = 0; y < 4; ++y) {\n        const uint i00 = tile_col * TILE_DIM + tid_row + 8 * y;\n        const uint i01 = tile_row * TILE_DIM + tid_col;\n        if (i00 < p.ne00 && i01 < p.ne01 && i02 < p.ne02 && i03 < p.ne03) {\n            const uint src_idx = i00 * p.nb00 + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;\n            sh[tid_row + 8 * y][tid_col] = uint(data_a[get_aoffset() + src_idx]);\n        }\n    }\n\n    barrier();\n\n    [[unroll]] for (uint y = 0; y < 4; ++y) {\n        const uint i0 = tile_col * TILE_DIM + tid_col;\n        const uint i1 = tile_row * TILE_DIM + tid_row + 8 * y;\n        if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {\n            const uint dst_idx = i0 * p.nb10 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;\n            // load transposed\n            data_d[get_doffset() + dst_idx] = D_TYPE(sh[tid_col][tid_row + 8 * y]);\n        }\n    }\n}\n\n#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))\n\nvoid main() {\n    uint z = gl_WorkGroupID.z;\n    uint y = gl_WorkGroupID.y;\n    bool need_barrier = false;\n    for (uint z = gl_WorkGroupID.z; z < p.ne12 * p.ne13; z += gl_NumWorkGroups.z) {\n        for (uint y = 
gl_WorkGroupID.y; y < CEIL_DIV(p.ne11, TILE_DIM); y += gl_NumWorkGroups.y) {\n            for (uint x = gl_WorkGroupID.x; x < CEIL_DIV(p.ne10, TILE_DIM); x += gl_NumWorkGroups.x) {\n                if (need_barrier) {\n                    barrier();\n                }\n                need_barrier = true;\n                iter(uvec3(x, y, z));\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/cos.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/count_equal.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n\n#include \"types.glsl\"\n#include \"generic_head.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) readonly buffer Y {B_TYPE data_b[];};\nlayout (binding = 2) buffer D {D_TYPE data_d[];};\n\nconst uint CHUNK_SIZE = 512;\n\nvoid main() {\n    const uint base = gl_WorkGroupID.x * CHUNK_SIZE;\n    const uint col = gl_LocalInvocationID.x;\n\n    uint count = 0;\n    [[unroll]]\n    for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {\n        const uint idx = base + i + col;\n        if (idx >= p.KX) {\n            break;\n        }\n        count += uint(data_a[idx] == data_b[idx]);\n    }\n\n    atomicAdd(data_d[0], D_TYPE(count));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/count_experts.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n\n#include \"types.glsl\"\n\nlayout (push_constant) uniform parameter\n{\n    uint32_t ne00;\n    uint32_t ne01;\n    uint32_t nb00;\n    uint32_t nb01;\n    uint32_t a_offset;\n} p;\n\n#define BLOCK_SIZE 256\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {uint data_a[];};\nlayout (binding = 1) writeonly buffer D {uint data_d[];};\n\nshared uint vals[BLOCK_SIZE];\n\nvoid main() {\n    const uint expert_id = gl_WorkGroupID.x;\n    const uint num_elements = p.ne00 * p.ne01;\n    const uint tid = gl_LocalInvocationID.x;\n\n    uint count = 0;\n    for (uint idx = tid; idx < num_elements; idx += BLOCK_SIZE) {\n        const uint i01 = idx / p.ne00;\n        const uint i00 = idx % p.ne00;\n        const uint a = data_a[p.a_offset + i01 * p.nb01 + i00 * p.nb00];\n\n        count += uint(a == expert_id);\n    }\n\n    vals[tid] = count;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            vals[tid] += vals[tid + s];\n        }\n        barrier();\n    }\n\n    if (tid == 0) {\n        data_d[expert_id] = vals[0];\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/cumsum.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"sum_rows.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_basic : enable\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 128;\nlayout (constant_id = 1) const uint SUBGROUP_SIZE = 32;\nlayout (constant_id = 2) const uint ELEM_PER_THREAD = 4;\n\n#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))\n\nshared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];\nshared FLOAT_TYPE last_sum;\n\nvoid main() {\n    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n    const uint tid = gl_LocalInvocationID.x;\n\n    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);\n    const uint i03_offset = i03 * p.ne01*p.ne02;\n    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);\n    const uint i01 = row - i03_offset - i02*p.ne01;\n\n    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;\n    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;\n\n    uint subgroup_id = tid / SUBGROUP_SIZE;\n\n    if (tid == 0) {\n        last_sum = 0;\n    }\n\n    uint col = tid * ELEM_PER_THREAD;\n    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);\n    for (int i = 0; i < num_iter; ++i) {\n        FLOAT_TYPE v[ELEM_PER_THREAD];\n        FLOAT_TYPE thread_sum = 0;\n        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {\n            if (col + j < p.n_cols) {\n                thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);\n            }\n            v[j] = thread_sum;\n        }\n\n        thread_sum = subgroupExclusiveAdd(thread_sum);\n        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {\n    
        v[j] += thread_sum;\n        }\n        // Store each subgroup's total (the last prefix sum of its last invocation), then add the totals\n        // of all lower subgroups and the running total carried over from the previous iteration.\n        if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {\n            partial[subgroup_id] = v[ELEM_PER_THREAD - 1];\n        }\n        barrier();\n        for (int s = 0; s < subgroup_id; ++s) {\n            [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {\n                v[j] += partial[s];\n            }\n        }\n        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {\n            v[j] += last_sum;\n        }\n        barrier();\n        if (tid == BLOCK_SIZE - 1) {\n            last_sum = v[ELEM_PER_THREAD - 1];\n        }\n        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {\n            if (col + j < p.n_cols) {\n                data_d[dst_idx + col + j] = D_TYPE(v[j]);\n            }\n        }\n        col += BLOCK_SIZE * ELEM_PER_THREAD;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"sum_rows.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_basic : enable\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\nlayout (binding = 2) writeonly buffer T {D_TYPE data_t[];};\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 128;\nlayout (constant_id = 1) const uint SUBGROUP_SIZE = 32;\n\n#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))\n\nshared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];\n\nvoid main() {\n    const uint row = gl_WorkGroupID.y;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint col = gl_GlobalInvocationID.x;\n\n    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);\n    const uint i03_offset = i03 * p.ne01*p.ne02;\n    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);\n    const uint i01 = row - i03_offset - i02*p.ne01;\n\n    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;\n    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;\n\n    uint subgroup_id = tid / SUBGROUP_SIZE;\n\n    FLOAT_TYPE v = 0;\n    if (col < p.n_cols) {\n        v = FLOAT_TYPE(data_a[src_idx + col]);\n    }\n    v = subgroupInclusiveAdd(v);\n\n    // Store each subgroup's total (the inclusive sum held by its last invocation), then add the\n    // totals of all lower subgroups so v becomes the inclusive prefix sum across the workgroup.\n    if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {\n        partial[subgroup_id] = v;\n    }\n    barrier();\n    for (int j = 0; j < subgroup_id; ++j) {\n        v += partial[j];\n    }\n    barrier();\n    if (tid == BLOCK_SIZE - 1) {\n        data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v;\n    }\n    if (col < p.n_cols) {\n        
data_d[dst_idx + col] = D_TYPE(v);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"sum_rows.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_basic : enable\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) buffer D {D_TYPE data_d[];};\nlayout (binding = 2) readonly buffer T {D_TYPE data_t[];};\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 128;\nlayout (constant_id = 1) const uint SUBGROUP_SIZE = 32;\n\n#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))\n\nshared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE];\n\nvoid main() {\n    const uint row = gl_WorkGroupID.y;\n    const uint tid = gl_LocalInvocationID.x;\n\n    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);\n    const uint i03_offset = i03 * p.ne01*p.ne02;\n    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);\n    const uint i01 = row - i03_offset - i02*p.ne01;\n\n    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;\n    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;\n\n    const uint col = gl_GlobalInvocationID.x;\n\n    float v = 0;\n    // prefetch value we're adding to\n    if (col < p.n_cols) {\n        v = data_d[dst_idx + col];\n    }\n\n    // compute the sum of all previous blocks\n    uint c = tid;\n    float sum = 0;\n    while (c < gl_WorkGroupID.x) {\n        sum += data_t[c + gl_NumWorkGroups.x * row];\n        c += BLOCK_SIZE;\n    }\n\n    sum = subgroupAdd(sum);\n    if (gl_SubgroupInvocationID == 0) {\n        temp[gl_SubgroupID] = sum;\n    }\n    barrier();\n    sum = 0;\n    [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) {\n        sum += temp[s];\n    }\n\n    // Add the sum to what the first pass computed\n    if (col < p.n_cols) {\n        data_d[dst_idx + col] = v + sum;\n    }\n}\n\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_f32.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {float data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.x * 16;\n\n    if (i >= p.nel) {\n        return;\n    }\n\n    [[unroll]] for (uint l = 0; l < 16; l++) {\n        data_b[i + l] = D_TYPE(data_a[i + l]);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl",
    "content": "#if !defined(DATA_A_F32) && !defined(DATA_A_F16)\n#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n#endif\n\n#include \"types.glsl\"\n\n#if defined(DATA_A_F32)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);\n}\n#endif\n\n#if defined(DATA_A_F16)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);\n}\n#endif\n\n#if defined(DATA_A_BF16)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));\n}\n#endif\n\n#if defined(DATA_A_Q4_0)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);\n    return (vec2(vui & 0xF, vui >> 4) - 8.0f);\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);\n    return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);\n}\n#endif\n\n#if defined(DATA_A_Q4_1)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);\n    return vec2(vui & 0xF, vui >> 4);\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);\n    return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12);\n}\n#endif\n\n#if defined(DATA_A_Q5_0)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];\n    const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);\n    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);\n    return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f);\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) 
<< 16 | data_a_packed16[a_offset + ib].qh[0];\n    const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);\n    const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);\n    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);\n    return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f);\n}\n#endif\n\n#if defined(DATA_A_Q5_1)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint uint_qh = data_a[a_offset + ib].qh;\n    const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);\n    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);\n    return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y);\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint uint_qh = data_a_packed16[a_offset + ib].qh;\n    const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);\n    const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);\n    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);\n    return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y);\n}\n#endif\n\n#if defined(DATA_A_Q8_0)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147\n    const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;\n    return vec4(v0.x, v0.y, v1.x, v1.y);\n}\n#endif\n\n#if defined(DATA_A_IQ1_S)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint ib32 = iqs / 32;\n    const uint ib8 = iqs / 8;\n    const int i8 = int(iqs % 8);\n    const uint qh = 
data_a[a_offset + ib].qh[ib32];\n    const uint qs = data_a[a_offset + ib].qs[ib8];\n    const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1);\n    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;\n    const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3);\n    const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);\n    // Signed bitfield extract.\n    const ivec2 gvec = ivec2(\n      bitfieldExtract(grid, 2 * (i8), 2),\n      bitfieldExtract(grid, 2 * (i8 + 1), 2)\n    );\n    return dl * (vec2(gvec) + delta);\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint ib32 = iqs / 32;\n    const uint ib8 = iqs / 8;\n    const int i8 = int(iqs % 8);\n    const uint qh = data_a[a_offset + ib].qh[ib32];\n    const uint qs = data_a[a_offset + ib].qs[ib8];\n    const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1;\n    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;\n    const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);\n    // Signed bitfield extract.\n    const ivec4 gvec = ivec4(\n      bitfieldExtract(grid, 2 * (i8), 2),\n      bitfieldExtract(grid, 2 * (i8 + 1), 2),\n      bitfieldExtract(grid, 2 * (i8 + 2), 2),\n      bitfieldExtract(grid, 2 * (i8 + 3), 2)\n    );\n    return dl * (vec4(gvec) + delta);\n}\n#endif\n\n#if defined(DATA_A_IQ1_M)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint ib8 = iqs / 8;\n    const uint ib16 = iqs / 16;\n    const int i8 = int(iqs % 8);\n    const uint sc = data_a[a_offset + ib].scales[iqs / 64];\n    const uint qs = data_a[a_offset + ib].qs[ib8];\n    const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));\n    const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;\n    const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA;\n    const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);\n    // Signed bitfield extract.\n    const ivec2 gvec = ivec2(\n      bitfieldExtract(grid, 2 * (i8), 2),\n      bitfieldExtract(grid, 2 * (i8 + 1), 2)\n    );\n    return dl * (vec2(gvec) + delta);\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint ib8 = iqs / 8;\n    const uint ib16 = iqs / 16;\n    const int i8 = int(iqs % 8);\n    const uint sc = data_a[a_offset + ib].scales[iqs / 64];\n    const uint qs = data_a[a_offset + ib].qs[ib8];\n    const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));\n    const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;\n    const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;\n    const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);\n    // Signed bitfield extract.\n    const ivec4 gvec = ivec4(\n      bitfieldExtract(grid, 2 * (i8), 2),\n      bitfieldExtract(grid, 2 * (i8 + 1), 2),\n      bitfieldExtract(grid, 2 * (i8 + 2), 2),\n      bitfieldExtract(grid, 2 * (i8 + 3), 2)\n    );\n    return dl * (vec4(gvec) + delta);\n}\n#endif\n\n#if defined(DATA_A_IQ2_XXS)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint ib32 = iqs / 32;\n    const uint ib8 = (iqs / 8) % 4;\n    const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];\n    // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)\n    const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],\n        data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));\n    const float db = 0.25 * (0.5 + (signs >> 28));\n    const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);\n    // Add parity bit\n    const uint sign8 = sign7 | (bitCount(sign7) << 7);\n    const uint sign = sign8 >> (iqs % 8);\n    const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    
return db * vec2(\n        grid.x * (sign0 ? -1.0 : 1.0),\n        grid.y * (sign1 ? -1.0 : 1.0)\n    );\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint ib32 = iqs / 32;\n    const uint ib8 = (iqs / 8) % 4;\n    const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];\n    // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)\n    const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],\n        data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));\n    const float db = 0.25 * (0.5 + (signs >> 28));\n    const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);\n    // Add parity bit\n    const uint sign8 = sign7 | (bitCount(sign7) << 7);\n    const uint sign = sign8 >> (iqs % 8);\n    const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    bool sign2 = (sign & 4) != 0;\n    bool sign3 = (sign & 8) != 0;\n    return db * vec4(\n        grid.x * (sign0 ? -1.0 : 1.0),\n        grid.y * (sign1 ? -1.0 : 1.0),\n        grid.z * (sign2 ? -1.0 : 1.0),\n        grid.w * (sign3 ? -1.0 : 1.0)\n    );\n}\n#endif\n\n#if defined(DATA_A_IQ2_XS)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;\n    const uint qs = data_a[a_offset + ib].qs[iqs / 8];\n    const float db = 0.25 * (0.5 + scale);\n    const uint sign7 = qs >> 9;\n    // Add parity bit\n    const uint sign8 = sign7 | (bitCount(sign7) << 7);\n    const uint sign = sign8 >> (iqs % 8);\n    const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    return db * vec2(\n        grid.x * (sign0 ? -1.0 : 1.0),\n        grid.y * (sign1 ? 
-1.0 : 1.0)\n    );\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;\n    const uint qs = data_a[a_offset + ib].qs[iqs / 8];\n    const float db = 0.25 * (0.5 + scale);\n    const uint sign7 = qs >> 9;\n    // Add parity bit\n    const uint sign8 = sign7 | (bitCount(sign7) << 7);\n    const uint sign = sign8 >> (iqs % 8);\n    const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    bool sign2 = (sign & 4) != 0;\n    bool sign3 = (sign & 8) != 0;\n    return db * vec4(\n        grid.x * (sign0 ? -1.0 : 1.0),\n        grid.y * (sign1 ? -1.0 : 1.0),\n        grid.z * (sign2 ? -1.0 : 1.0),\n        grid.w * (sign3 ? -1.0 : 1.0)\n    );\n}\n#endif\n\n#if defined(DATA_A_IQ2_S)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint ib32 = iqs / 32;\n    const uint ib8 = iqs / 8;\n\n    const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;\n    const uint qs = data_a[a_offset + ib].qs[ib8];\n    const uint qh = data_a[a_offset + ib].qh[ib32];\n    const uint qhshift = 2 * (ib8 % 4);\n    const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);\n\n    const float db = 0.25 * (0.5 + scale);\n    const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    return db * vec2(\n        grid[iqs % 4] * (sign0 ? -1.0 : 1.0),\n        grid[(iqs % 4) + 1] * (sign1 ? 
-1.0 : 1.0)\n    );\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint ib32 = iqs / 32;\n    const uint ib8 = iqs / 8;\n\n    const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;\n    const uint qs = data_a[a_offset + ib].qs[ib8];\n    const uint qh = data_a[a_offset + ib].qh[ib32];\n    const uint qhshift = 2 * (ib8 % 4);\n    const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);\n\n    const float db = 0.25 * (0.5 + scale);\n    const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    bool sign2 = (sign & 4) != 0;\n    bool sign3 = (sign & 8) != 0;\n    return db * vec4(\n        grid.x * (sign0 ? -1.0 : 1.0),\n        grid.y * (sign1 ? -1.0 : 1.0),\n        grid.z * (sign2 ? -1.0 : 1.0),\n        grid.w * (sign3 ? -1.0 : 1.0)\n    );\n}\n#endif\n\n#if defined(DATA_A_IQ3_XXS)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint ib4 = iqs / 4;\n    const uint ib32 = iqs / 32;\n    const uint is = QUANT_K / 4 + 4 * ib32;\n    const uint qs = data_a[a_offset + ib].qs[ib4];\n    // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)\n    const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],\n        data_a_packed16[a_offset + ib].qs[is / 2 + 1]));\n    const float db = 0.5 * (0.5 + (signs >> 28));\n    const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);\n    // Add parity bit\n    const uint sign8 = sign7 | (bitCount(sign7) << 7);\n    const uint sign = sign8 >> (iqs % 8);\n    const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    return db * vec2(\n        grid.x * (sign0 ? -1.0 : 1.0),\n        grid.y * (sign1 ? 
-1.0 : 1.0)\n    );\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint ib4 = iqs / 4;\n    const uint ib32 = iqs / 32;\n    const uint is = QUANT_K / 4 + 4 * ib32;\n    const uint qs = data_a[a_offset + ib].qs[ib4];\n    const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],\n        data_a_packed16[a_offset + ib].qs[is / 2 + 1]));\n    const float db = 0.5 * (0.5 + (signs >> 28));\n    const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);\n    // Add parity bit\n    const uint sign8 = sign7 | (bitCount(sign7) << 7);\n    const uint sign = sign8 >> (iqs % 8);\n    const u8vec4 grid = unpack8(iq3xxs_grid[qs]);\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    bool sign2 = (sign & 4) != 0;\n    bool sign3 = (sign & 8) != 0;\n    return db * vec4(\n        grid.x * (sign0 ? -1.0 : 1.0),\n        grid.y * (sign1 ? -1.0 : 1.0),\n        grid.z * (sign2 ? -1.0 : 1.0),\n        grid.w * (sign3 ? -1.0 : 1.0)\n    );\n}\n#endif\n\n#if defined(DATA_A_IQ3_S)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint qs = data_a[a_offset + ib].qs[iqs / 4];\n    const uint qh = data_a[a_offset + ib].qh[iqs / 32];\n    const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);\n    const uint scale = data_a[a_offset + ib].scales[iqs / 64];\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);\n    const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));\n    return db * vec2(\n        int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),\n        int((grid >> 8) & 0xFF) * (sign1 ? 
-1.0 : 1.0)\n    );\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint ib4 = iqs / 4;\n    const uint ib32 = iqs / 32;\n    const uint qs = data_a[a_offset + ib].qs[ib4];\n    const uint qh = data_a[a_offset + ib].qh[ib32];\n    const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);\n    const uint scale = data_a[a_offset + ib].scales[ib32 / 2];\n    bool sign0 = (sign & 1) != 0;\n    bool sign1 = (sign & 2) != 0;\n    bool sign2 = (sign & 4) != 0;\n    bool sign3 = (sign & 8) != 0;\n    const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);\n    const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));\n    return db * vec4(\n        int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),\n        int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),\n        int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),\n        int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)\n    );\n}\n#endif\n\n#if defined(DATA_A_IQ4_XS)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint ib32 = iqs / 32;\n    const uint iq = 16 * ib32 + (iqs % 16);\n\n    const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;\n    const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;\n    const uint qshift = (iqs & 16) >> 2;\n    u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);\n    qs = (qs >> qshift) & uint8_t(0xF);\n\n    const float dl = float(int(sl | (sh << 4)) - 32);\n    return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint ib32 = iqs / 32;\n    const uint iq = 16 * ib32 + (iqs % 16);\n\n    const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;\n    const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;\n    const uint qshift = (iqs & 16) >> 2;\n    const u8vec4 qs = unpack8((data_a_packed32[a_offset + ib].qs[iq/4] >> qshift) & 
0x0F0F0F0F);\n\n    const float dl = float(int(sl | (sh << 4)) - 32);\n    return dl * vec4(\n        kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],\n        kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);\n}\n#endif\n\n#if defined(DATA_A_IQ4_NL)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);\n    return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]);\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);\n    return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]);\n}\n#endif\n\n#if defined(DATA_A_MXFP4)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);\n    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;\n}\nvec4 dequantize4(uint ib, uint iqs, uint a_offset) {\n    vec2 v0 = dequantize(ib, iqs, a_offset);\n    vec2 v1 = dequantize(ib, iqs + 1, a_offset);\n    return vec4(v0.x, v0.y, v1.x, v1.y);\n}\n#endif\n\n#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)\nvec2 get_dm(uint ib, uint a_offset) {\n    return vec2(0, 0);\n}\n#endif\n\n#if defined(DATA_A_IQ1_M)\nvec2 get_dm(uint ib, uint a_offset) {\n    const uint16_t[4] scales = data_a[a_offset + ib].scales;\n    const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;\n    const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);\n    return vec2(d, 0);\n}\n#endif\n\n#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)\nvec2 get_dm(uint ib, uint a_offset) {\n    return vec2(float(data_a[a_offset + ib].d), 0);\n}\n#endif\n\n#if 
defined(DATA_A_MXFP4)\nvec2 get_dm(uint ib, uint a_offset) {\n    return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);\n}\n#endif\n\n#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)\nvec2 get_dm(uint ib, uint a_offset) {\n    const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);\n    return dm;\n}\n#endif\n\n#if defined(DATA_A_Q2_K)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    iqs /= 2;\n    const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..62\n    const uint scalesi = iqs / 8;                      // 0..15\n    const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6\n\n    const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);\n    const uint scales = data_a[a_offset + ib].scales[scalesi];\n    const vec2 dm = vec2(data_a[a_offset + ib].dm);\n\n    return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);\n}\nvec2 get_dm(uint ib, uint a_offset) {\n    return vec2(1, 0);\n}\n#endif\n\n#if defined(DATA_A_Q3_K)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    iqs /= 2;\n    const uint n = iqs / 64;                     // 0,1\n    const uint qsi = n * 32 + (iqs % 16) * 2;    // 0,2,4..62\n    const uint hmi =          (iqs % 16) * 2;    // 0,2,4..30\n    const uint j = (iqs % 64) / 4;               // 0..15\n    const uint is = iqs / 8;                     // 0..15\n    const uint halfsplit = ((iqs % 64) / 16);    // 0,1,2,3\n    const uint qsshift = halfsplit * 2;          // 0,2,4,6\n    const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128\n\n    const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)\n                          | (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));\n    const float dl = float(data_a[a_offset + ib].d) * float(us - 32);\n\n    return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi    ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi    
] & m) != 0) ? 0 : 4)),\n                dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));\n}\nvec2 get_dm(uint ib, uint a_offset) {\n    return vec2(1, 0);\n}\n#endif\n\n#if defined(DATA_A_Q4_K)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    iqs /= 2;\n    const uint n = iqs / 32;                   // 0,1,2,3\n    const uint b = (iqs % 32) / 16;            // 0,1\n    const uint is = 2 * n + b;                 // 0..7\n    const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126\n\n    const vec2 loadd = vec2(data_a[a_offset + ib].dm);\n\n    const uint scidx0 = (is < 4) ? is : (is + 4);\n    const uint scidx1 = (is < 4) ? is : (is - 4);\n    const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;\n    const uint scidxshift1 = (is < 4) ? 0 : 2;\n    const uint mbidx0 = is + 4;\n    const uint mbidx1 = (is < 4) ? is + 4 : is;\n    const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;\n    const uint mbidxshift0 = (is < 4) ? 0 : 4;\n    const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;\n    const uint mbidxshift1 = (is < 4) ? 
0 : 2;\n\n    const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));\n    const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));\n\n    const float d = loadd.x * sc;\n    const float m = -loadd.y * mbyte;\n\n    return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi    ] >> (b * 4)) & 0xF), m),\n                fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));\n}\nvec2 get_dm(uint ib, uint a_offset) {\n    return vec2(1, 0);\n}\n#endif\n\n#if defined(DATA_A_Q5_K)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    iqs /= 2;\n    const uint n = iqs / 32;                   // 0,1,2,3\n    const uint b = (iqs % 32) / 16;            // 0,1\n    const uint is = 2 * n + b;                 // 0..7\n    const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126\n    const uint qhi = (iqs % 16) * 2;           // 0,2,4..30\n\n    const uint8_t hm = uint8_t(1 << (iqs / 16));\n\n    const vec2 loadd = vec2(data_a[a_offset + ib].dm);\n\n    const uint scidx0 = (is < 4) ? is : (is + 4);\n    const uint scidx1 = (is < 4) ? is : (is - 4);\n    const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;\n    const uint scidxshift1 = (is < 4) ? 0 : 2;\n    const uint mbidx0 = is + 4;\n    const uint mbidx1 = (is < 4) ? is + 4 : is;\n    const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;\n    const uint mbidxshift0 = (is < 4) ? 0 : 4;\n    const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;\n    const uint mbidxshift1 = (is < 4) ? 
0 : 2;\n\n    const uint8_t sc    = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF)                         | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));\n    const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));\n\n    const float d = loadd.x * sc;\n    const float m = -loadd.y * mbyte;\n\n    return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m),\n                fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));\n}\nvec2 get_dm(uint ib, uint a_offset) {\n    return vec2(1, 0);\n}\n#endif\n\n#if defined(DATA_A_Q6_K)\nvec2 dequantize(uint ib, uint iqs, uint a_offset) {\n    iqs /= 2;\n    const uint n = iqs / 64;                    // 0,1\n    const uint b = (iqs % 64) / 32;             // 0,1\n    const uint is_b = (iqs % 16) / 8;           // 0,1\n    const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6\n    const uint is = 8 * n + qhshift + is_b;     // 0..15\n    const uint qsi = n * 64 + (iqs % 32) * 2;   // 0,2,4..126\n    const uint qhi = n * 32 + (iqs % 16) * 2;   // 0,2,4..62\n\n    const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]);\n\n    return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi    ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi    ] >> qhshift) & 3) << 4)) - 32),\n                dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));\n}\nvec2 get_dm(uint ib, uint a_offset) {\n    return vec2(1, 0);\n}\n#endif\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl",
    "content": "\n#include \"types.glsl\"\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {\n   vec4 block;\n};\n\nfloat16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const vec4 v = bl.block;\n    const uint idx = coordInBlock[1];\n    const f16vec4 vf16 = f16vec4(v);\n    return vf16[idx];\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {\n   block_q4_0_packed16 block;\n};\n\nfloat16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const uint idx = coordInBlock[1];\n    const uint shift = (idx & 0x10) >> 2;\n    uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);\n    qs >>= shift;\n    qs &= 0x0F0F;\n    qs = unpack8(qs)[idx & 1];\n    float16_t ret = (float16_t(qs) - float16_t(8)) * d;\n    return ret;\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {\n   block_q4_1 block;\n};\n\nfloat16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const float16_t m = bl.block.m;\n    const uint idx = coordInBlock[1];\n    const uint iqs = idx & 0xF;\n    const uint shift = (idx & 0x10) >> 2;\n    uint32_t qs = bl.block.qs[iqs];\n    qs >>= shift;\n    qs &= 0xF;\n    float16_t ret = float16_t(qs) * d + m;\n    return ret;\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {\n   block_q5_0 block;\n};\n\nfloat16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const uint idx = coordInBlock[1];\n    const uint iqs = idx & 0xF;\n\n    const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];\n    const uint qh = ((uint_qh >> idx) << 4) & 0x10;\n\n   
 const uint shift = (idx & 0x10) >> 2;\n    uint32_t qs = bl.block.qs[iqs];\n    qs >>= shift;\n    qs &= 0xF;\n\n    float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;\n    return ret;\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {\n   block_q5_1 block;\n};\n\nfloat16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const float16_t m = bl.block.m;\n    const uint idx = coordInBlock[1];\n    const uint iqs = idx & 0xF;\n\n    const uint uint_qh = bl.block.qh;\n    const uint qh = ((uint_qh >> idx) << 4) & 0x10;\n\n    const uint shift = (idx & 0x10) >> 2;\n    uint32_t qs = bl.block.qs[iqs];\n    qs >>= shift;\n    qs &= 0xF;\n\n    float16_t ret = float16_t(qs | qh) * d + m;\n    return ret;\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {\n   block_q8_0_packed16 block;\n};\n\nfloat16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const uint idx = coordInBlock[1];\n    const uint iqs = idx;\n\n    // Load 16b and select the byte for this element\n    int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];\n    float16_t ret = float16_t(qs) * d;\n    return ret;\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {\n   block_q2_K block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {\n   block_q2_K_packed16 block;\n};\n\nfloat16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);\n    const f16vec2 dm = bl.block.dm;\n    const uint idx = coordInBlock[1];\n\n    const uint scalesi = (idx & 0xF0) >> 4;             // 0..15\n    const uint qsshift = (idx & 0x60) >> 
4;             // 0,2,4,6\n\n    uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);\n    qs = (qs >> qsshift) & 0x0303;\n    qs = unpack8(qs)[idx & 1];\n\n    const uint scales = bl.block.scales[scalesi];\n    float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);\n    return ret;\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {\n   block_q3_K block;\n};\n\nfloat16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const uint idx = coordInBlock[1];\n    const uint iqs = idx;\n\n    const uint n = iqs / 128;                    // 0,1\n    const uint qsi = n * 32 + (iqs % 32);        // 0..63\n    const uint hmi =          (iqs % 32);        // 0..31\n    const uint j = (iqs % 128) / 8;              // 0..15\n    const uint is = iqs / 16;                    // 0..15\n    const uint halfsplit = ((iqs % 128) / 32);   // 0,1,2,3\n    const uint qsshift = halfsplit * 2;          // 0,2,4,6\n    const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128\n\n    uint32_t scaleidx0 = (is < 8) ? is : (is-8);\n    uint32_t scaleidx0shift = (is < 8) ? 0 : 4;\n    uint32_t scaleidx1 = is + 8 - (is/4)*4;\n    uint32_t scaleidx1shift = (is/4)*2;\n\n    const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));\n\n    const float16_t dl = bl.block.d * float16_t(us - 32);\n\n    float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi    ] >> qsshift) & 3) - (((bl.block.hmask[hmi    ] & m) != 0) ? 
0 : 4));\n\n    return ret;\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {\n   block_q4_K block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {\n   block_q4_K_packed16 block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {\n   block_q4_K_packed128 block;\n};\n\n#if defined(IS_MUL_MM2)\n\n// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales\n// into shared memory and then process the whole tile using those scales.\n// There is a fetch function that loads into private variables and then a store\n// function that stores into shared memory.\n// Q4_K and Q5_K have the same encoding of scales, so everything is shared except\n// the part that fetches from the structure (which has a different block layout).\n#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)\nconst uint shAscales_stride = (BM + 2);\n// 1 scale per 32 elements -> 8 scales per block, per row\nshared vec2 shAscales[8 * shAscales_stride];\nuvec4 row_v;\n#endif\n\n#if defined(DATA_A_Q4_K)\nlayout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};\n\nvoid fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)\n{\n    uint tids_per_row = BLOCK_SIZE / BM;\n    uint is_per_tid = 8 / tids_per_row;\n    uint is_start = is_per_tid * (tid % tids_per_row);\n    uint tid_row = tid / tids_per_row;\n\n    uint row = ir_BM + tid_row;\n    uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);\n    if (in_bounds || row < p.M) {\n        row_v = data_a_q4_k_packed128[block_index].q4k[0];\n    }\n}\n#endif\n#if defined(DATA_A_Q5_K)\nlayout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};\n\nvoid fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)\n{\n    uint tids_per_row = 
BLOCK_SIZE / BM;\n    uint is_per_tid = 8 / tids_per_row;\n    uint is_start = is_per_tid * (tid % tids_per_row);\n    uint tid_row = tid / tids_per_row;\n\n    uint row = ir_BM + tid_row;\n    uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);\n    if (in_bounds || row < p.M) {\n        row_v = data_a_q5_k_packed128[block_index].q5k[0];\n    }\n}\n#endif\n\n#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)\nvoid store_scalesQ4_K(uint tid)\n{\n    barrier();\n\n    uint tids_per_row = BLOCK_SIZE / BM;\n    uint is_per_tid = 8 / tids_per_row;\n    uint is_start = is_per_tid * (tid % tids_per_row);\n    uint tid_row = tid / tids_per_row;\n\n    [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {\n        uint is = idx + is_start;\n        uvec4 v = row_v;\n        const vec2 loadd = vec2(unpackFloat2x16(v.x));\n\n        uint32_t sc;\n        uint32_t mbyte;\n\n        uint32_t scale0 = v.y;\n        uint32_t scale4 = v.z;\n        uint32_t scale8 = v.w;\n\n        uint32_t sc_lo = scale0;\n        uint32_t mb_lo = scale4;\n        uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);\n        uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);\n\n        sc = is < 4 ? sc_lo : sc_hi;\n        mbyte = is < 4 ? 
mb_lo : mb_hi;\n        sc = sc >> (8 * (is & 3));\n        mbyte = mbyte >> (8 * (is & 3));\n        sc &= 0x3F;\n        mbyte &= 0x3F;\n\n        const float d = loadd.x * float(sc);\n        const float m = loadd.y * float(mbyte);\n        shAscales[is * shAscales_stride + tid_row] = vec2(d,m);\n    }\n\n    barrier();\n}\n#endif\n\n#endif\n\nfloat16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);\n    decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);\n    const uint idx = coordInBlock[1];\n\n    const uint b = (idx & 0x20) >> 5;            // 0,1\n    const uint is = (idx & 0xE0) >> 5;         // 0..7\n\n#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)\n    vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];\n    float d = v.x;\n    float m = v.y;\n#else\n    uvec4 v = bl128.block.q4k[0];\n    const vec2 loadd = vec2(unpackFloat2x16(v.x));\n\n    uint32_t sc;\n    uint32_t mbyte;\n\n    uint32_t scale0 = v.y;\n    uint32_t scale4 = v.z;\n    uint32_t scale8 = v.w;\n\n    uint32_t sc_lo = scale0;\n    uint32_t mb_lo = scale4;\n    uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);\n    uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);\n\n    sc = is < 4 ? sc_lo : sc_hi;\n    mbyte = is < 4 ? 
mb_lo : mb_hi;\n    sc = sc >> (8 * (is & 3));\n    mbyte = mbyte >> (8 * (is & 3));\n    sc &= 0x3F;\n    mbyte &= 0x3F;\n\n    const float d = loadd.x * float(sc);\n    const float m = loadd.y * float(mbyte);\n#endif\n\n    uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);\n    qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;\n\n    float ret = d * float(qs) - m;\n\n    return float16_t(ret);\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {\n   block_q5_K block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {\n   block_q5_K_packed16 block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {\n   block_q5_K_packed128 block;\n};\n\nfloat16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);\n    decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);\n    const uint idx = coordInBlock[1];\n\n    const uint b = (idx & 0x20) >> 5;          // 0,1\n    const uint is = (idx & 0xE0) >> 5;         // 0..7\n\n#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)\n    vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];\n    float d = v.x;\n    float m = v.y;\n#else\n    uvec4 v = bl128.block.q5k[0];\n\n    const f16vec2 loadd = unpackFloat2x16(v.x);\n\n    uint32_t sc;\n    uint32_t mbyte;\n\n    uint32_t scale0 = v.y;\n    uint32_t scale4 = v.z;\n    uint32_t scale8 = v.w;\n\n    uint32_t sc_lo = scale0;\n    uint32_t mb_lo = scale4;\n    uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);\n    uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);\n\n    sc = is < 4 ? sc_lo : sc_hi;\n    mbyte = is < 4 ? 
mb_lo : mb_hi;\n    sc = sc >> (8 * (is & 3));\n    mbyte = mbyte >> (8 * (is & 3));\n    sc &= 0x3F;\n    mbyte &= 0x3F;\n\n    const float16_t d = loadd.x * float16_t(sc);\n    const float16_t m = loadd.y * float16_t(mbyte);\n#endif\n\n    uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);\n    qh = ((qh >> is) & 0x101) << 4;\n\n    uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);\n    qs = (qs >> (b * 4)) & 0x0F0F;\n    qs = unpack8(qs | qh)[idx & 1];\n\n    float ret = d * float(qs) - m;\n\n    return float16_t(ret);\n}\n\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {\n   block_q6_K block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {\n   block_q6_K_packed16 block;\n};\n\nfloat16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);\n    const uint idx = coordInBlock[1];\n\n    const uint b = (idx & 0x40) >> 6;           // 0,1\n    const uint qhshift = (idx & 0x60) >> 4;    // 0,2,4,6\n    const uint is = (idx & 0xF0) >> 4;          // 0..15\n\n    const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);\n\n    uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);\n    ql = (ql >> (b * 4)) & 0x0F0F;\n\n    uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);\n    qh = ((qh >> qhshift) & 0x0303) << 4;\n\n    int q = unpack8(ql | qh)[idx & 1];\n\n    float16_t ret = dscale * float16_t(q - 32);\n\n    return ret;\n}\n\n#if defined(DATA_A_IQ1_S)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {\n   block_iq1_s block;\n};\n\nfloat16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const uint idx = coordInBlock[1];\n\n    const 
uint ib32 = (idx & 0xE0) >> 5;\n    const uint ib8 = (idx & 0xF8) >> 3;\n\n    const uint qh = bl.block.qh[ib32];\n    const uint qs = bl.block.qs[ib8];\n    const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);\n    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;\n    const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];\n\n    float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));\n    return ret;\n}\n#endif\n\n#if defined(DATA_A_IQ1_M)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {\n   block_iq1_m block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {\n   block_iq1_m_packed64 block;\n};\n\nfloat16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);\n    const uint idx = coordInBlock[1];\n\n    uvec2 scales = unpack32(bl64.block.scales);\n    const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));\n\n    const uint ib8 = (idx & 0xF8) >> 3;\n    const uint ib16 = (idx & 0xF0) >> 4;\n    const int i8 = int(idx % 8);\n    const uint sc = bl.block.scales[ib8 / 8];\n    const uint qs = bl.block.qs[ib8];\n    const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));\n    const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;\n    const float delta = ((qh & 8) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA;\n    const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];\n\n    float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));\n    return ret;\n}\n#endif\n\n#if defined(DATA_A_IQ2_XXS)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {\n   block_iq2_xxs block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {\n   block_iq2_xxs_packed16 block;\n};\n\nfloat16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);\n    const float16_t d = bl.block.d;\n    const uint idx = coordInBlock[1];\n\n    const uint ib32 = (idx & 0xE0) >> 5; // 0..7\n    const uint ib8 = (idx & 0x18) >> 3;  // 0..3\n    const uint iqs = 8 * ib32 + ib8;\n\n    const uint qs = bl.block.qs[iqs];\n    const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));\n\n    const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));\n    uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);\n    sign |= bitCount(sign) << 7;\n\n    uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];\n    g2 >>= (idx & 2) * 8;\n    const vec2 g = vec2(unpack8(g2));\n\n    vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? 
-1.0hf : 1.0hf);\n    return float16_t(ret[idx & 1]);\n}\n#endif\n\n#if defined(DATA_A_IQ2_XS)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {\n   block_iq2_xs block;\n};\n\nfloat16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const uint idx = coordInBlock[1];\n\n    const uint is = (idx & 0xE0) >> 5;     // 0..7\n    const uint sshift = (idx & 0x10) >> 2; // 0,4\n    const uint iqs = (idx & 0xF8) >> 3;    // 0..31\n\n    const uint16_t qs = bl.block.qs[iqs];\n    const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));\n\n    uint sign = uint(qs >> 9);\n    sign |= bitCount(sign) << 7;\n    uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];\n    g2 >>= (idx & 2) * 8;\n    const vec2 g = vec2(unpack8(g2));\n\n    vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);\n    return float16_t(ret[idx & 1]);\n}\n#endif\n\n#if defined(DATA_A_IQ2_S)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {\n   block_iq2_s block;\n};\n\nfloat16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    uint idx = coordInBlock[1];\n\n    const uint ib32 = (idx & 0xE0) >> 5;        // 0..7\n    const uint ib8 = (idx & 0xF8) >> 3;         // 0..31\n    const uint qhshift = 2 * (ib8 % 4);\n\n    const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;\n    const uint qs = bl.block.qs[ib8];\n    const uint qh = bl.block.qh[ib32];\n    const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);\n\n    const float d = float(bl.block.d);\n    const float db = d * 0.25 * (0.5 + scale);\n    const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));\n    uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];\n    g2 >>= (idx & 2) * 8;\n    const vec2 v = 
db * vec2(sign01) * vec2(unpack8(g2));\n    return float16_t(v[idx & 1]);\n}\n#endif\n\n#if defined(DATA_A_IQ3_XXS)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {\n   block_iq3_xxs block;\n};\n\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {\n   block_iq3_xxs_packed16 block;\n};\n\nfloat16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);\n    uint idx = coordInBlock[1];\n\n    const uint iqs = (idx & 0xFC) >> 2;             // 0..63\n    const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values\n\n    const float d = float(bl.block.d);\n    const uint qs = bl.block.qs[iqs];\n    const uint signs = pack32(u16vec2(\n        bl16.block.qs[is/2+0],\n        bl16.block.qs[is/2+1]\n    ));\n    const float db = d * 0.5 * (0.5 + (signs >> 28));\n    const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);\n    const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);\n    const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));\n    const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));\n    const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);\n    return float16_t(v[idx & 1]);\n}\n#endif\n\n#if defined(DATA_A_IQ3_S)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {\n   block_iq3_s block;\n};\n\nfloat16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    uint idx = coordInBlock[1];\n\n    const uint iqs = (idx & 0xFC) >> 2;           // 0..63\n    const uint iqh = (idx & 0xE0) >> 5;\n\n    const float d = float(bl.block.d);\n    const uint qs = bl.block.qs[iqs];\n    const uint qh = bl.block.qh[iqh];\n    const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));\n    const uint scale = 
bl.block.scales[iqs / 16];\n    const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));\n    const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));\n    const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);\n    const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);\n\n    return float16_t(v[idx & 1]);\n}\n#endif\n\n#if defined(DATA_A_IQ4_XS)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {\n   block_iq4_xs block;\n};\n\nfloat16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const uint idx = coordInBlock[1];\n\n    const uint ib32 = (idx & 0xE0) >> 5; // 0..7\n\n    const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;\n    const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;\n    const uint qshift = (idx & 16) >> 2;\n    const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;\n\n    float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);\n    return ret;\n}\n#endif\n\n#if defined(DATA_A_IQ4_NL)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {\n   block_iq4_nl block;\n};\n\nfloat16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const float16_t d = bl.block.d;\n    const uint idx = coordInBlock[1];\n    const uint iqs = idx & 0xF;\n    const uint shift = (idx & 0x10) >> 2;\n    uint32_t qs = bl.block.qs[iqs];\n    qs >>= shift;\n    qs &= 0xF;\n    float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;\n    return ret;\n}\n#endif\n\n#if defined(DATA_A_MXFP4)\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {\n   block_mxfp4 block;\n};\n\nfloat16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    
const float d = e8m0_to_fp32(bl.block.e);\n    const uint idx = coordInBlock[1];\n    const uint iqs = idx & 0xF;\n    const uint shift = (idx & 0x10) >> 2;\n    uint32_t qs = bl.block.qs[iqs];\n    qs >>= shift;\n    qs &= 0xF;\n    float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);\n    return ret;\n}\n#endif\n\n#if defined(DATA_A_Q4_0)\n#define dequantFuncA dequantFuncQ4_0\n#elif defined(DATA_A_Q4_1)\n#define dequantFuncA dequantFuncQ4_1\n#elif defined(DATA_A_Q5_0)\n#define dequantFuncA dequantFuncQ5_0\n#elif defined(DATA_A_Q5_1)\n#define dequantFuncA dequantFuncQ5_1\n#elif defined(DATA_A_Q8_0)\n#define dequantFuncA dequantFuncQ8_0\n#elif defined(DATA_A_Q2_K)\n#define dequantFuncA dequantFuncQ2_K\n#elif defined(DATA_A_Q3_K)\n#define dequantFuncA dequantFuncQ3_K\n#elif defined(DATA_A_Q4_K)\n#define dequantFuncA dequantFuncQ4_K\n#define fetch_scales fetch_scalesQ4_K\n#define store_scales store_scalesQ4_K\n#elif defined(DATA_A_Q5_K)\n#define dequantFuncA dequantFuncQ5_K\n#define fetch_scales fetch_scalesQ5_K\n#define store_scales store_scalesQ4_K\n#elif defined(DATA_A_Q6_K)\n#define dequantFuncA dequantFuncQ6_K\n#elif defined(DATA_A_IQ1_S)\n#define dequantFuncA dequantFuncIQ1_S\n#elif defined(DATA_A_IQ1_M)\n#define dequantFuncA dequantFuncIQ1_M\n#elif defined(DATA_A_IQ2_XXS)\n#define dequantFuncA dequantFuncIQ2_XXS\n#elif defined(DATA_A_IQ2_XS)\n#define dequantFuncA dequantFuncIQ2_XS\n#elif defined(DATA_A_IQ2_S)\n#define dequantFuncA dequantFuncIQ2_S\n#elif defined(DATA_A_IQ3_XXS)\n#define dequantFuncA dequantFuncIQ3_XXS\n#elif defined(DATA_A_IQ3_S)\n#define dequantFuncA dequantFuncIQ3_S\n#elif defined(DATA_A_IQ4_XS)\n#define dequantFuncA dequantFuncIQ4_XS\n#elif defined(DATA_A_IQ4_NL)\n#define dequantFuncA dequantFuncIQ4_NL\n#elif defined(DATA_A_MXFP4)\n#define dequantFuncA dequantFuncMXFP4\n#elif defined(DATA_A_F32)\n#define dequantFuncA dequantFuncF32\n#endif\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_head.glsl",
    "content": "#extension GL_EXT_control_flow_attributes : require\n#extension GL_EXT_shader_16bit_storage : require\n\nlayout (push_constant) uniform parameter\n{\n    uint M;\n    uint K;\n    uint stride_a;\n    uint stride_b;\n    uint nel;\n} p;\n\n#include \"types.glsl\"\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq1_m data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    // Each thread handles 1 subblock (32 values with 2 scales)\n    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    if (ib >= p.nel / 256) {\n        return;\n    }\n\n    const uint ib32 = gl_LocalInvocationID.x % 8;\n    const uint ib64 = ib32 / 2;\n    const uint b_idx = 256 * ib + 32 * ib32;\n\n    const uint16_t[4] scales = data_a[ib].scales;\n    const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;\n    const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);\n\n    const uint sc = data_a[ib].scales[ib64];\n    [[unroll]] for (int l = 0; l < 4; ++l) {\n        const uint ib16 = 2 * ib32 + l / 2;\n        const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);\n        const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));\n        const uint qs = data_a[ib].qs[4 * ib32 + l];\n        const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;\n        const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);\n        [[unroll]] for (int j = 0; j < 8; ++j) {\n            data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq1_s data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    // Each thread handles 1 subblock (32 values with 1 scale)\n    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    if (ib >= p.nel / 256) {\n        return;\n    }\n\n    const uint ib32 = gl_LocalInvocationID.x % 8;\n    const uint b_idx = 256 * ib + 32 * ib32;\n\n    uint qh = data_a[ib].qh[ib32];\n    const float d = float(data_a[ib].d);\n    const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);\n    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;\n    [[unroll]] for (uint l = 0; l < 4; ++l) {\n        const uint qs = data_a[ib].qs[4 * ib32 + l];\n        const uint hi = bitfieldExtract(qh, 3 * int(l), 3);\n        const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]);\n        [[unroll]] for (int j = 0; j < 8; ++j) {\n            data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq2_s data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    // Each thread handles 1 subblock (32 values with 2 scales)\n    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    if (ib >= p.nel / 256) {\n        return;\n    }\n\n    const uint ib32 = gl_LocalInvocationID.x % 8;\n    const uint b_idx = 256 * ib + 32 * ib32;\n\n    const float d = float(data_a[ib].d);\n    const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);\n    const vec2 db = d * (0.5 + scale) * 0.25;\n\n    uint qh = data_a[ib].qh[ib32];\n    [[unroll]] for (uint l = 0; l < 4; ++l) {\n        uint qs = data_a[ib].qs[4 * ib32 + l];\n        const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];\n        qs |= (qh << (8 - 2 * l)) & 0x300;\n        const uvec2 grid = iq2s_grid[qs];\n        const u8vec4 grid0 = unpack8(grid.x);\n        const u8vec4 grid1 = unpack8(grid.y);\n        data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq2_xs data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    // Each thread handles 1 subblock (32 values with 2 scales)\n    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    if (ib >= p.nel / 256) {\n        return;\n    }\n\n    const uint ib32 = gl_LocalInvocationID.x % 8;\n    const uint b_idx = 256 * ib + 32 * ib32;\n\n    const float d = float(data_a[ib].d);\n    const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);\n    const vec2 db = d * (0.5 + scale) * 0.25;\n\n    [[unroll]] for (uint l = 0; l < 4; ++l) {\n        uint16_t qs = data_a[ib].qs[4 * ib32 + l];\n        const uint sign7 = qs >> 9;\n        const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit\n        const uvec2 grid = iq2xs_grid[qs & 511];\n        const u8vec4 grid0 = unpack8(grid.x);\n        const u8vec4 grid1 = unpack8(grid.y);\n        data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    // Each thread handles 1 scale block (32 values)\n    // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits\n    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    if (ib >= p.nel / 256) {\n        return;\n    }\n\n    const uint is = gl_LocalInvocationID.x % 8;\n    const uint b_idx = 256 * ib + 32 * is;\n\n    const float d = float(data_a[ib].d);\n    uint signscale = pack32(u8vec4(\n        data_a[ib].qs[8*is + 4],\n        data_a[ib].qs[8*is + 5],\n        data_a[ib].qs[8*is + 6],\n        data_a[ib].qs[8*is + 7]\n    ));\n    const float db = d * (0.5 + (signscale >> 28)) * 0.25;\n\n    [[unroll]] for (uint l = 0; l < 4; ++l) {\n        const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);\n        const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit\n        const uint qs = data_a[ib].qs[8 * is + l];\n        const uvec2 grid = iq2xxs_grid[qs];\n        const u8vec4 grid0 = unpack8(grid.x);\n        const u8vec4 grid1 = unpack8(grid.y);\n        data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? 
-1.0 : 1.0));\n        data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq3_s data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    // Each thread handles 1 scale nibble.\n    // Each block contains 4 scale bytes (8 scales) for 256 output values.\n    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    if (ib >= p.nel / 256) {\n        return;\n    }\n\n    const uint is = gl_LocalInvocationID.x % 8;\n    const uint b_idx = 256 * ib + 32 * is;\n\n    const float d = float(data_a[ib].d);\n    const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf));\n\n    // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.\n    uint qh = data_a[ib].qh[is];\n    [[unroll]] for (uint l = 0; l < 8; ++l) {\n        const uint iqs = 8 * is + l;\n        const uint qs = data_a[ib].qs[iqs];\n        const uint gidx = qs | ((qh << (8 - l)) & 256);\n        const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1));\n        const u8vec4 grid = unpack8(iq3s_grid[gidx]);\n        data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    // Each thread handles 1 scale block (32 values)\n    // 8 threads handle 1 superblock\n    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    if (ib >= p.nel / 256) {\n        return;\n    }\n\n    const uint is = gl_LocalInvocationID.x % 8;\n    const uint b_idx = 256 * ib + 32 * is;\n    const uint s_idx = QUANT_K / 4 + 4 * is;\n\n    const float d = float(data_a[ib].d);\n    uint signscale = pack32(u8vec4(\n        data_a[ib].qs[s_idx + 0],\n        data_a[ib].qs[s_idx + 1],\n        data_a[ib].qs[s_idx + 2],\n        data_a[ib].qs[s_idx + 3]\n    ));\n    const float db = d * (0.5 + (signscale >> 28)) * 0.5;\n\n    [[unroll]] for (uint l = 0; l < 4; ++l) {\n        const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);\n        // Restore parity bit.\n        const uint sign8 = sign7 | (bitCount(sign7) << 7);\n        const uint qs0 = data_a[ib].qs[8 * is + 2 * l];\n        const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1];\n        const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]);\n        const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]);\n        data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? 
-1.0 : 1.0));\n        data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));\n        data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq4_nl data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    const uint tid = gl_LocalInvocationID.x % 64;\n    const uint il  = tid/32;\n    const uint ir  = tid%32;\n    const uint ib = 32*i + ir;\n    if (ib >= p.nel / 32) {\n        return;\n    }\n\n    const uint q_idx = 8*il;\n    const uint b_idx = 1024*i + 32*ir + q_idx;\n\n    const float d = float(data_a[ib].d);\n\n    [[unroll]] for (uint l = 0; l < 8; ++l) {\n        data_b[b_idx + l +  0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);\n        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >>  4]);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    // Each thread handles 1 subblock (1 scale and 32 quantized values)\n    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    if (ib >= p.nel / 256) {\n        return;\n    }\n\n    const uint ib32 = gl_LocalInvocationID.x % 8;\n\n    const float d = float(data_a[ib].d);\n    // Scales are 6 bits\n    const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)\n                     | (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);\n    const float dl = d * (int(scale) - 32);\n\n    const uint b_idx = 256 * ib + 32 * ib32;\n    const uint q_idx = 16 * ib32;\n    [[unroll]] for (uint l = 0; l < 16; ++l) {\n        data_b[b_idx + l +  0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);\n        data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >>  4]);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    const uint tid = gl_LocalInvocationID.x % 64;\n    const uint il  = tid/32;\n    const uint ir  = tid%32;\n    const uint ib = 32*i + ir;\n    if (ib >= p.nel / 32) {\n        return;\n    }\n\n    const uint q_idx = 8*il;\n    const uint b_idx = 1024*i + 32*ir + q_idx;\n\n    const float d = e8m0_to_fp32(data_a[ib].e);\n\n    [[unroll]] for (uint l = 0; l < 8; ++l) {\n        data_b[b_idx + l +  0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));\n        data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >>  4]));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {\n        const uint i = gl_WorkGroupID.x * 256 + wgy;\n        if (i >= p.nel / QUANT_K) {\n            return;\n        }\n\n        const uint tid = gl_LocalInvocationID.x;\n        const uint ip = tid / 32;\n        const uint il = tid - 32 * ip;\n        const uint is = 8 * ip + il / 16;\n\n        const uint y_idx = i * QUANT_K + 128 * ip + il;\n\n        const uint ql_idx = 32 * ip + il;\n        const uint8_t qs = data_a[i].qs[32 * ip + il];\n\n        FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);\n        FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);\n        data_b[y_idx +  0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));\n        data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));\n        data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));\n        data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {\n        const uint i = uint(gl_WorkGroupID.x * 256 + wgy);\n        if (i >= p.nel / QUANT_K) {\n            return;\n        }\n\n        const uint r = gl_LocalInvocationID.x / 4;\n        const uint tid = r / 2;\n        const uint is0 = r % 2;\n        const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);\n        const uint n = tid / 4;\n        const uint j = tid - 4*n;\n\n        const uint8_t m = uint8_t(1 << (4*n + j));\n        const uint is = 8*n + 2*j + is0;\n        const uint shift = 2*j;\n\n        const int8_t us = int8_t(is <  4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :\n                                 is <  8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :\n                                 is < 12 ? (data_a[i].scales[is-8] >>  4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :\n                                           (data_a[i].scales[is-8] >>  4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));\n        const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);\n        const FLOAT_TYPE dl    = d_all * FLOAT_TYPE(us - 32);\n\n        const uint y_idx = i * QUANT_K + 128 * n + 32 * j;\n        const uint qs_idx = 32*n;\n\n        for (uint l = l0; l < l0 + 4; ++l) {\n            data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_q4_0 data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;\n\n    const uint tid = gl_LocalInvocationID.x % 64;\n    const uint il  = tid/32;\n    const uint ir  = tid%32;\n    const uint ib = 32*i + ir;\n    if (ib >= p.nel / 32) {\n        return;\n    }\n\n    const uint q_idx = 8*il;\n    const uint b_idx = 1024*i + 32*ir + q_idx;\n\n    const float d = float(data_a[ib].d);\n\n    [[unroll]] for (uint l = 0; l < 8; ++l) {\n        data_b[b_idx + l +  0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f));\n        data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >>  4) - 8.0f));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_q4_1 data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;\n\n    const uint tid = gl_LocalInvocationID.x % 64;\n    const uint il  = tid/32;\n    const uint ir  = tid%32;\n    const uint ib = 32*i + ir;\n    if (ib >= p.nel / 32) {\n        return;\n    }\n\n    const uint b_idx = 1024*i + 32*ir + 8*il;\n\n    const float d = float(data_a[ib].d);\n    const float m = float(data_a[ib].m);\n\n    const uint q_idx = 8*il;\n\n    [[unroll]] for (uint l = 0; l < 8; ++l) {\n        data_b[b_idx + l +  0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);\n        data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >>  4) + m);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {\n        const uint ib = gl_WorkGroupID.x * 256 + wgy;\n        if (ib >= p.nel / QUANT_K) {\n            return;\n        }\n\n        const uint tid = gl_LocalInvocationID.x;\n        const uint il = tid / 8;\n        const uint ir = tid % 8;\n        const uint is = 2 * il;\n        const uint n = 4;\n\n        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);\n        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);\n\n        const uint y_idx = ib * QUANT_K + 64 * il + n * ir;\n        const uint qs_idx = 32*il + n * ir;\n\n        uint scidx0 = (is < 4) ? is : (is + 4);\n        uint scidx1 = (is < 4) ? is : (is - 4);\n        uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;\n        uint scidxshift1 = (is < 4) ? 0 : 2;\n        uint mbidx0 = is + 4;\n        uint mbidx1 = (is < 4) ? is + 4 : is;\n        uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;\n        uint mbidxshift0 = (is < 4) ? 0 : 4;\n        uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;\n        uint mbidxshift1 = (is < 4) ? 0 : 2;\n\n        uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));\n        uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));\n\n        const FLOAT_TYPE d1 = dall * sc;\n        const FLOAT_TYPE m1 = dmin * mbyte;\n\n        scidx0 = (is < 4) ? is + 1 : (is + 5);\n        scidx1 = (is < 4) ? is + 1 : (is - 3);\n        scidxmask1 = (is < 4) ? 0x30 : 0xC0;\n        scidxshift1 = (is < 4) ? 0 : 2;\n        mbidx0 = is + 5;\n        mbidx1 = (is < 4) ? is + 5 : is + 1;\n        mbidxmask0 = (is < 4) ? 
0xF : 0xF0;\n        mbidxshift0 = (is < 4) ? 0 : 4;\n        mbidxmask1 = (is < 4) ? 0x30 : 0xC0;\n        mbidxshift1 = (is < 4) ? 0 : 2;\n\n        sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));\n        mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));\n\n        const FLOAT_TYPE d2 = dall * sc;\n        const FLOAT_TYPE m2 = dmin * mbyte;\n\n        [[unroll]] for (uint l = 0; l < n; ++l) {\n            data_b[y_idx + l     ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1);\n            data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >>  4) - m2);\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_q5_0 data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;\n\n    const uint tid = gl_LocalInvocationID.x % 64;\n    const uint il  = tid/32;\n    const uint ir  = tid%32;\n    const uint ib = 32*i + ir;\n    if (ib >= p.nel / 32) {\n        return;\n    }\n\n    const uint b_idx = 1024*i + 32*ir + 8*il;\n\n    const float d = float(data_a[ib].d);\n    const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];\n\n    const uint q_idx = 8*il;\n\n    [[unroll]] for (uint l = 0; l < 8; ++l) {\n        const uint iqs = q_idx + l;\n        const uint vui = uint(data_a[ib].qs[iqs]);\n        data_b[b_idx + l +  0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));\n        data_b[b_idx + l + 16] = D_TYPE(d * (((vui >>  4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_q5_1 data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;\n\n    const uint tid = gl_LocalInvocationID.x % 64;\n    const uint il  = tid/32;\n    const uint ir  = tid%32;\n    const uint ib = 32*i + ir;\n    if (ib >= p.nel / 32) {\n        return;\n    }\n\n    const uint b_idx = 1024*i + 32*ir + 8*il;\n\n    const float d = float(data_a[ib].d);\n    const float m = float(data_a[ib].m);\n    const uint qh = data_a[ib].qh;\n\n    const uint q_idx = 8*il;\n\n    [[unroll]] for (uint l = 0; l < 8; ++l) {\n        const uint iqs = q_idx + l;\n        const uint vui = uint(data_a[ib].qs[iqs]);\n        data_b[b_idx + l +  0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);\n        data_b[b_idx + l + 16] = D_TYPE(d * (((vui >>  4) | ((qh >> (iqs + 12)) & 0x10))) + m);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {\n        const uint ib = gl_WorkGroupID.x * 256 + wgy;\n        if (ib >= p.nel / QUANT_K) {\n            return;\n        }\n\n        const uint tid = gl_LocalInvocationID.x;\n        const uint il = tid / 16;\n        const uint ir = tid % 16;\n        const uint is = 2 * il;\n\n        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);\n        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);\n\n        const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;\n        const uint qs_idx = 32*il + 2 * ir;\n        const uint qh_idx = 2 * ir;\n\n        uint scidx0 = (is < 4) ? is : (is + 4);\n        uint scidx1 = (is < 4) ? is : (is - 4);\n        uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;\n        uint scidxshift1 = (is < 4) ? 0 : 2;\n        uint mbidx0 = is + 4;\n        uint mbidx1 = (is < 4) ? is + 4 : is;\n        uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;\n        uint mbidxshift0 = (is < 4) ? 0 : 4;\n        uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;\n        uint mbidxshift1 = (is < 4) ? 0 : 2;\n\n        uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));\n        uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));\n\n        const FLOAT_TYPE d1 = dall * sc;\n        const FLOAT_TYPE m1 = dmin * mbyte;\n\n        scidx0 = (is < 4) ? is + 1 : (is + 5);\n        scidx1 = (is < 4) ? is + 1 : (is - 3);\n        scidxmask1 = (is < 4) ? 0x30 : 0xC0;\n        scidxshift1 = (is < 4) ? 0 : 2;\n        mbidx0 = is + 5;\n        mbidx1 = (is < 4) ? 
is + 5 : is + 1;\n        mbidxmask0 = (is < 4) ? 0xF : 0xF0;\n        mbidxshift0 = (is < 4) ? 0 : 4;\n        mbidxmask1 = (is < 4) ? 0x30 : 0xC0;\n        mbidxshift1 = (is < 4) ? 0 : 2;\n\n        sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));\n        mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));\n\n        const FLOAT_TYPE d2 = dall * sc;\n        const FLOAT_TYPE m2 = dmin * mbyte;\n\n        const uint8_t hm1 = uint8_t(1 << (2 * il    ));\n        const uint8_t hm2 = uint8_t(1 << (2 * il + 1));\n        data_b[y_idx     ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx    ] & 0xF) + (((data_a[ib].qh[qh_idx    ] & hm1) != 0) ? 16 : 0)) - m1);\n        data_b[y_idx +  1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);\n        data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx    ]  >> 4) + (((data_a[ib].qh[qh_idx    ] & hm2) != 0) ? 16 : 0)) - m2);\n        data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1]  >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {\n        const uint i = gl_WorkGroupID.x * 256 + wgy;\n        if (i >= p.nel / QUANT_K) {\n            return;\n        }\n        const uint tid = gl_LocalInvocationID.x;\n        const uint ip = tid / 32;\n        const uint il = tid - 32 * ip;\n        const uint is = 8 * ip + il / 16;\n\n        const uint y_idx = i * QUANT_K + 128 * ip + il;\n\n        const uint ql_idx = 64 * ip + il;\n        const uint8_t qh = data_a[i].qh[32 * ip + il];\n\n        const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);\n\n        data_b[y_idx +  0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx +  0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));\n        data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));\n        data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx +  0] >>  4) | (((qh >> 4) & 3) << 4)) - 32)));\n        data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >>  4) | (((qh >> 6) & 3) << 4)) - 32)));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp",
    "content": "#version 450\n\n#include \"dequant_head.glsl\"\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {block_q8_0 data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_b[];};\n\nvoid main() {\n    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;\n\n    const uint tid = gl_LocalInvocationID.x % 64;\n    const uint il  = tid/32;\n    const uint ir  = tid%32;\n    const uint ib = 32*i + ir;\n    if (ib >= p.nel / 32) {\n        return;\n    }\n\n    const uint b_idx = 1024*i + 32*ir + 16*il;\n\n    const float d = float(data_a[ib].d);\n\n    const uint q_idx = 16*il;\n\n    [[unroll]] for (uint l = 0; l < 16; l += 2) {\n        data_b[b_idx + l    ] = D_TYPE(d * data_a[ib].qs[q_idx + l    ]);\n        data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/diag.comp",
    "content": "#version 450\n\n#include \"rte.glsl\"\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);\n    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;\n    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);\n    const uint i12_offset = i12*p.ne11*p.ne10;\n    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);\n    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;\n\n    if (i10 == i11) {\n        const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);\n        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);\n    } else {\n        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_16bit_storage : require\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout (push_constant) uniform parameter\n{\n    uint ncols;\n    uint rows_per_channel;\n    uint n_past;\n} p;\n\n#include \"types.glsl\"\n\nlayout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint col = gl_GlobalInvocationID.y;\n    const uint row = gl_GlobalInvocationID.x;\n\n    if (col >= p.ncols) {\n        return;\n    }\n\n    const uint i = row*p.ncols + col;\n    if (col > p.n_past + row % p.rows_per_channel) {\n        data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));\n    } else {\n        data_d[i] = D_TYPE(data_a[i]);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/div.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\nconst uint num_threads = 256;\n\nlayout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    uint idx = get_idx();\n\n    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation\n    const uint num_iter = 2;\n\n    [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n        if (idx >= p.ne) {\n            continue;\n        }\n        uint i00, i01, i02, i03;\n        get_indices(idx, i00, i01, i02, i03);\n\n        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));\n\n        idx += num_threads;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/elu.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    float x = float(data_a[i]);\n\n    if (x < 0.0f) {\n        x = exp(x) - 1;\n    }\n\n    data_d[i] = D_TYPE(x);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/exp.comp",
    "content": "#version 450\n\n#include \"rte.glsl\"\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n    data_d[i] = D_TYPE(exp(float(data_a[i])));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp",
    "content": "#version 460\n\n#extension GL_EXT_bfloat16 : require\n\nvoid main()\n{\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp",
    "content": "#version 460\n\n#extension GL_KHR_cooperative_matrix : require\n\nvoid main()\n{\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp",
    "content": "#version 460\n\n#extension GL_NV_cooperative_matrix2 : require\n\nvoid main()\n{\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp",
    "content": "#version 460\n\n#extension GL_EXT_integer_dot_product : require\n\nvoid main()\n{\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/fill.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    // p.param1 = fill value\n    data_d[i] = D_TYPE(p.param1);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/flash_attn.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#ifdef FLOAT16\n#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require\n#extension GL_EXT_shader_subgroup_extended_types_float16 : require\n#endif\n\n#extension GL_KHR_shader_subgroup_shuffle : enable\n#extension GL_KHR_shader_subgroup_vote : enable\n\n#include \"types.glsl\"\n#include \"flash_attn_base.glsl\"\n\nconst uint32_t HSK_per_thread = HSK / D_split;\nconst uint32_t HSV_per_thread = HSV / D_split;\n\nconst uint32_t rows_per_thread = Br / row_split;\nconst uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;\nconst uint32_t cols_per_thread = Bc / cols_per_iter;\nconst uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;\n\n\nlayout (binding = 0) readonly buffer Q {float data_q[];};\nlayout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};\nlayout (binding = 1) readonly buffer K {float16_t data_k[];};\nlayout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};\nlayout (binding = 2) readonly buffer V {float16_t data_v[];};\nlayout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};\nlayout (binding = 3) readonly buffer M {float16_t data_m[];};\n\n// If SubGroupSize is set to 0 then only use shmem reductions\nconst uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;\nshared float tmpsh[tmpsh_size];\nshared FLOAT_TYPEV4 tmpshv4[tmpsh_size];\n\nconst uint32_t masksh_stride = Br + 1;\nshared FLOAT_TYPE masksh[Bc * masksh_stride];\n\nconst uint32_t qf_stride = HSK / 4 + 1;\nshared FLOAT_TYPEV4 Qf[Br * qf_stride];\n\nconst uint32_t D = HSK > HSV ? HSK : HSV;\nconst uint32_t kvsh_stride = D / 4 + 1;\nshared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];\n\nshared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? 
LIMIT_OCCUPANCY_SHMEM : 1];\n\nvoid main() {\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n    init_indices();\n\n    const uint32_t tid = gl_LocalInvocationIndex;\n    const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;\n    const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;\n    const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;\n    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;\n    const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;\n\n    if (LIMIT_OCCUPANCY_SHMEM > 0) {\n        // This just exists to avoid the occupancy_limiter array getting optimized out\n        occupancy_limiter[tid] = vec4(tid);\n\n        barrier();\n\n        if (occupancy_limiter[tid] == vec4(99999.0)) {\n            data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]);\n        }\n    }\n\n#define tile_row(r) (row_tid * rows_per_thread + (r))\n\n    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;\n\n    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {\n        uint32_t d = (idx + tid) % (HSK / 4);\n        uint32_t r = (idx + tid) / (HSK / 4);\n        if (r < Br && d < HSK / 4 &&\n            i * Br + r < N) {\n            Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);\n        }\n    }\n    barrier();\n\n    FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];\n    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            Of[r][d] = FLOAT_TYPEV4(0.0);\n        }\n    }\n\n    float Lf[rows_per_thread], Mf[rows_per_thread];\n\n    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. 
when computing Mold-M.\n    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);\n\n    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n        Lf[r] = 0;\n        Mf[r] = NEG_FLT_MAX_OVER_2;\n    }\n\n    ACC_TYPE slope[rows_per_thread];\n    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n        slope[r] = ACC_TYPE(1.0);\n    }\n\n    // ALiBi\n    if (p.max_bias > 0.0f) {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);\n        }\n    }\n\n    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);\n    // mo_offset will point to the tile starting at row i*Br and col 0\n    uint32_t mo_offset = mo_stride * i;\n\n#if BLOCK_SIZE > 1\n    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;\n    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;\n#else\n    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;\n    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;\n#endif\n    uint32_t m_offset = gqa_iq1*KV;\n    if (p.nem2 != 1 || p.nem3 != 1) {\n        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;\n        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;\n    }\n\n    uint32_t mask_opt = 0;\n    uint32_t mask_opt_idx = ~0;\n    uint32_t mask_opt_bits = 0;\n\n    [[dont_unroll]]\n    for (uint32_t j = start_j; j < end_j; ++j) {\n        if (MASK_ENABLE) {\n            if (USE_MASK_OPT && mask_opt_idx != j / 16) {\n                mask_opt_idx = j / 16;\n                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];\n            }\n            mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;\n            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {\n                // skip this block\n                continue;\n            }\n            // Only load if the block is not all zeros\n            if (mask_opt_bits != MASK_OPT_ALL_ZERO) 
{\n                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;\n\n                float max_mask = NEG_FLT_MAX_OVER_2;\n                barrier();\n                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {\n                    uint32_t c = (idx + tid) % Bc;\n                    uint32_t r = (idx + tid) / Bc;\n                    if (idx + tid < Bc * Br) {\n                        if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {\n                            FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);\n                            masksh[c * masksh_stride + r] = m;\n                            max_mask = max(max_mask, float(m));\n                        } else {\n                            masksh[c * masksh_stride + r] = FLOAT_TYPE(0);\n                        }\n                    }\n                }\n                // skip the block if the mask is entirely -inf\n                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);\n                barrier();\n                if (gl_SubgroupInvocationID == 0) {\n                    tmpsh[gl_SubgroupID] = all_less ? 
NEG_FLT_MAX_OVER_2 : 0.0f;\n                }\n                barrier();\n                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {\n                    max_mask = max(max_mask, tmpsh[s]);\n                }\n                if (max_mask <= NEG_FLT_MAX_OVER_2) {\n                    continue;\n                }\n            }\n        }\n\n        ACC_TYPE Sf[rows_per_thread][cols_per_thread];\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n                Sf[r][c] = ACC_TYPE(0.0);\n            }\n        }\n\n        if (SHMEM_STAGING != 0) {\n            barrier();\n            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {\n                uint32_t d = (idx + tid) % (HSK / 4);\n                uint32_t c = (idx + tid) / (HSK / 4);\n                if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {\n                    FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);\n                    if (!KV_bounds_check || j * Bc + c < KV) {\n#if BLOCK_SIZE > 1\n                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;\n                        uint ib = coord / BLOCK_SIZE;\n                        uint iqs = (coord % BLOCK_SIZE);\n                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);\n#else\n                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);\n#endif\n                    }\n\n                    kvsh[c * kvsh_stride + d] = K_Tf;\n                }\n            }\n            barrier();\n        }\n\n        // More d iterations means Q register caching becomes relevant\n        // Few iterations means the additional registers needed are worse than the speed-up from caching\n        if (HSK_per_thread / 4 > 4) {\n            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {\n                FLOAT_TYPEV4 
Q_cache[rows_per_thread];\n                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                    Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];\n                }\n\n                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n                    if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {\n                        continue;\n                    }\n\n                    FLOAT_TYPEV4 K_Tf;\n                    if (SHMEM_STAGING != 0) {\n                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];\n                    } else {\n#if BLOCK_SIZE > 1\n                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);\n                        uint ib = coord / BLOCK_SIZE;\n                        uint iqs = (coord % BLOCK_SIZE);\n                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);\n#else\n                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);\n#endif\n                    }\n                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                        Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));\n                    }\n                }\n            }\n        } else {\n            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {\n                    continue;\n                }\n\n                [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {\n                    FLOAT_TYPEV4 K_Tf;\n                    if (SHMEM_STAGING != 0) {\n                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];\n                    } else {\n#if BLOCK_SIZE > 1\n                        uint coord = (j * Bc + c 
* cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);\n                        uint ib = coord / BLOCK_SIZE;\n                        uint iqs = (coord % BLOCK_SIZE);\n                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);\n#else\n                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);\n#endif\n                    }\n                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                        Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));\n                    }\n                }\n            }\n        }\n\n        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n            // Compute sum across the D_split\n            [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {\n                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                    Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);\n                }\n            }\n        }\n\n        if (LOGIT_SOFTCAP) {\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n                    Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));\n                }\n            }\n        }\n\n        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {\n            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                    FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];\n\n                    Sf[r][c] += slope[r]*mvf;\n                }\n            }\n        }\n\n        float eMf[rows_per_thread];\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            float rowmaxf = NEG_FLT_MAX_OVER_2;\n            [[unroll]] for 
(uint32_t c = 0; c < cols_per_thread; ++c) {\n                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {\n                    continue;\n                }\n                rowmaxf = max(rowmaxf, float(Sf[r][c]));\n            }\n            float Moldf = Mf[r];\n\n            // M = max(rowmax, Mold)\n            // P = e^(S - M)\n            // eM = e^(Mold - M)\n            Mf[r] = max(rowmaxf, Moldf);\n            eMf[r] = exp(Moldf - Mf[r]);\n            Lf[r] = eMf[r]*Lf[r];\n        }\n\n        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d];\n            }\n        }\n\n        if (SHMEM_STAGING != 0) {\n            barrier();\n            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {\n                uint32_t d = (idx + tid) % (HSV / 4);\n                uint32_t c = (idx + tid) / (HSV / 4);\n                if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {\n                    FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);\n                    if (!KV_bounds_check || j * Bc + c < KV) {\n#if BLOCK_SIZE > 1\n                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;\n                        uint ib = coord / BLOCK_SIZE;\n                        uint iqs = (coord % BLOCK_SIZE);\n                        V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);\n#else\n                        V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);\n#endif\n                    }\n\n                    kvsh[c * kvsh_stride + d] = V_Tf;\n                }\n            }\n            barrier();\n        }\n\n        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {\n                continue;\n            }\n\n            
FLOAT_TYPE Pf[rows_per_thread];\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));\n                Lf[r] += Pf[r];\n            }\n\n            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n                FLOAT_TYPEV4 Vf;\n                if (SHMEM_STAGING != 0) {\n                    Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];\n                } else {\n#if BLOCK_SIZE > 1\n                    uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);\n                    uint ib = coord / BLOCK_SIZE;\n                    uint iqs = (coord % BLOCK_SIZE);\n                    Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);\n#else\n                    Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);\n#endif\n                }\n                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                    Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);\n                }\n            }\n        }\n    }\n\n    // prevent race on tmpsh\n    barrier();\n\n    // reduce across threads\n\n    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n        float rowmaxf = Mf[r];\n\n        // Compute max across the row\n        if (SubGroupSize > 0) {\n            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {\n                rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));\n            }\n            if (row_split == 1) {\n                // Reduce inside workgroup with shmem\n                barrier();\n                if (gl_SubgroupInvocationID == d_tid) {\n                    tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;\n                }\n                barrier();\n                rowmaxf = tmpsh[d_tid];\n                [[unroll]] for (uint32_t s = 1; s < 
num_subgroups; ++s) {\n                    rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);\n                }\n            }\n        } else {\n            barrier();\n            tmpsh[tid] = rowmaxf;\n            barrier();\n            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {\n                if (rowgroup_tid < s) {\n                    tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);\n                }\n                barrier();\n            }\n            rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];\n        }\n\n        float Moldf = Mf[r];\n\n        // M = max(rowmax, Mold)\n        // eM = e^(Mold - M)\n        Mf[r] = max(rowmaxf, Moldf);\n        float eMf = exp(Moldf - Mf[r]);\n\n        Lf[r] = eMf*Lf[r];\n\n        // Compute sum across the row\n        if (SubGroupSize > 0) {\n            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {\n                Lf[r] += subgroupShuffleXor(Lf[r], s);\n            }\n            if (row_split == 1) {\n                barrier();\n                if (gl_SubgroupInvocationID == d_tid) {\n                    tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];\n                }\n                barrier();\n                Lf[r] = tmpsh[d_tid];\n                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {\n                    Lf[r] += tmpsh[s * D_split + d_tid];\n                }\n            }\n        } else {\n            barrier();\n            tmpsh[tid] = Lf[r];\n            barrier();\n            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {\n                if (rowgroup_tid < s) {\n                    tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];\n                }\n                barrier();\n            }\n            Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];\n        }\n\n        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n            Of[r][d] = FLOAT_TYPE(eMf) * 
Of[r][d];\n\n            if (SubGroupSize > 0) {\n                [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {\n                    if (!OLD_AMD_WINDOWS) {\n                        Of[r][d] += subgroupShuffleXor(Of[r][d], s);\n                    } else {\n                        // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.\n                        // Shuffle full vec4 as workaround.\n                        // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697\n                        Of[r][d] += FLOAT_TYPEV4(subgroupShuffleXor(vec4(Of[r][d]), s));\n                    }\n                }\n                if (row_split == 1) {\n                    barrier();\n                    if (gl_SubgroupInvocationID == d_tid) {\n                        tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];\n                    }\n                    barrier();\n                    Of[r][d] = tmpshv4[d_tid];\n                    [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {\n                        Of[r][d] += tmpshv4[s * D_split + d_tid];\n                    }\n                }\n            } else {\n                barrier();\n                tmpshv4[tid] = Of[r][d];\n                barrier();\n                [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {\n                    if (rowgroup_tid < s) {\n                        Of[r][d] += tmpshv4[tid ^ s];\n                        tmpshv4[tid] = Of[r][d];\n                    }\n                    barrier();\n                }\n                Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];\n            }\n        }\n    }\n\n\n    // If there is split_k, then the split_k resolve shader does the final\n    // division by L. 
Store the intermediate O value and per-row m and L values.\n    if (p.k_num > 1) {\n        if (p.gqa_ratio > 1) {\n            // note: O and Q have swapped coord 1,2.\n            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;\n\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                const uint row = tile_row(r);\n                if (row < N) {\n                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n                        gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);\n                    }\n                }\n            }\n\n            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                const uint row = tile_row(r);\n                if (row < N) {\n                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);\n                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);\n                }\n            }\n        } else {\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                const uint row = tile_row(r);\n                const uint global_row = i * Br + row;\n\n                if (global_row < N) {\n                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;\n\n                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n                        data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);\n                    }\n                }\n\n                if (global_row < N && d_tid == 0 && col_tid == 0) {\n                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));\n                    data_o[lm_offset + iq2] = 
D_TYPE(Lf[r]);\n                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);\n                }\n            }\n        }\n        return;\n    }\n\n    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);\n\n            float ms = 1.0f;\n            float vs = 1.0f;\n\n            if (sink > Mf[r]) {\n                ms = exp(Mf[r] - sink);\n\n                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n                    Of[r][d] *= FLOAT_TYPE(ms);\n                }\n            } else {\n                vs = exp(sink - Mf[r]);\n            }\n\n            Lf[r] = Lf[r]*ms + vs;\n        }\n    }\n\n    float Lfrcp[rows_per_thread];\n    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n        Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);\n    }\n\n    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            Of[r][d] *= FLOAT_TYPE(Lfrcp[r]);\n#if defined(FLOAT_TYPE_MAX)\n            Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);\n#endif\n        }\n    }\n\n    uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;\n\n    if (p.gqa_ratio > 1) {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            const uint row = tile_row(r);\n            if (row < N) {\n                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n                    gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);\n                }\n            }\n        }\n    } else {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            const uint row = tile_row(r);\n            if (i * Br + row < N) {\n                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {\n                    data_ov4[o_offset 
+ (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl",
    "content": "\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (constant_id =  0) const uint32_t WorkGroupSize = 128;\nlayout (constant_id =  1) const uint32_t Br = 1;\nlayout (constant_id =  2) const uint32_t Bc = 32;\nlayout (constant_id =  3) const uint32_t HSK = 32;\nlayout (constant_id =  4) const uint32_t HSV = 32;\nlayout (constant_id =  5) const uint32_t Clamp = 0;\nlayout (constant_id =  6) const uint32_t D_split = 16;\nlayout (constant_id =  7) const uint32_t row_split = 1;\nlayout (constant_id =  8) const uint32_t SubGroupSize = 32;\nlayout (constant_id =  9) const uint32_t SHMEM_STAGING = 0;\nlayout (constant_id = 10) const uint32_t Flags = 0;\nlayout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;\n\nconst bool USE_MASK_OPT    = (Flags & 1) != 0;\nconst bool MASK_ENABLE     = (Flags & 2) != 0;\nconst bool LOGIT_SOFTCAP   = (Flags & 4) != 0;\nconst bool OLD_AMD_WINDOWS = (Flags & 8) != 0;\n\n// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths\nconst uint32_t HSK_pad = (HSK + 15) & ~15;\nconst uint32_t HSV_pad = (HSV + 15) & ~15;\n\nconst bool KV_bounds_check = Clamp != 0;\n\nlayout (push_constant) uniform parameter {\n    uint32_t N;\n    uint32_t KV;\n\n    uint32_t ne1;\n    uint32_t ne2;\n    uint32_t ne3;\n\n    uint32_t neq2;\n    uint32_t neq3;\n    uint32_t nek2;\n    uint32_t nek3;\n    uint32_t nev2;\n    uint32_t nev3;\n    uint32_t nem1;\n    uint32_t nem2;\n    uint32_t nem3;\n\n    uint32_t nb01;\n    uint32_t nb02;\n    uint32_t nb03;\n    uint32_t nb11;\n    uint32_t nb12;\n    uint32_t nb13;\n    uint32_t nb21;\n    uint32_t nb22;\n    uint32_t nb23;\n\n    float scale;\n    float max_bias;\n    float logit_softcap;\n\n    uint32_t mask_n_head_log2;\n    float m0;\n    float m1;\n\n    uint32_t gqa_ratio;\n    uint32_t split_kv;\n    uint32_t k_num;\n} p;\n\n#define SINK_ENABLE_BIT (1<<24)\n#define N_LOG2_MASK 0xFFFF\n\nlayout (binding = 4) readonly buffer S {float 
data_s[];};\n\nlayout (binding = 5) writeonly buffer O {D_TYPE data_o[];};\nlayout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};\n\nlayout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};\n\n#define MASK_OPT_ALL_NEG_INF 1\n#define MASK_OPT_ALL_ZERO 2\n\n#define BINDING_IDX_K 0\n#define BINDING_IDX_V 1\n#if defined(DATA_A_F32)\nlayout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;\nlayout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;\n#elif defined(A_TYPE_PACKED16)\nlayout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;\nlayout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;\n#endif\n\n#ifndef BLOCK_SIZE\n#define BLOCK_SIZE 1\n#endif\n\n#if defined(DATA_A_F32)\n#undef BLOCK_SIZE\n#define BLOCK_SIZE 4\n#define BLOCK_BYTE_SIZE 16\n\nFLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {\n    // iqs is currently always zero in the flash attention shaders\n    if (binding_idx == BINDING_IDX_K) {\n        return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);\n    } else {\n        return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);\n    }\n}\n#endif\n\n#if defined(DATA_A_Q4_0)\n#define BLOCK_BYTE_SIZE 18\n\nFLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {\n    if (binding_idx == BINDING_IDX_K) {\n        uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);\n        uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);\n        uint shift = (iqs & 0x10) >> 2;\n        vui_lo >>= shift;\n        vui_hi >>= shift;\n\n        return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));\n    } else {\n        uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 
+ 0]);\n        uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);\n        uint shift = (iqs & 0x10) >> 2;\n        vui_lo >>= shift;\n        vui_hi >>= shift;\n\n        return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));\n    }\n}\n#endif\n\n#if defined(DATA_A_Q8_0)\n#define BLOCK_BYTE_SIZE 34\nFLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {\n    if (binding_idx == BINDING_IDX_K) {\n        const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147\n        const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;\n\n        return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);\n    } else {\n        const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147\n        const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;\n\n        return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);\n    }\n}\n#endif\n\n#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))\n\n\n// Store column zero. 
This is used to save per-row m and L values for split_k.\nACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)\n{\n    if (r < N && c == 0) {\n        uint32_t offset = iq2 + r;\n        data_o[o_offset + offset] = D_TYPE(elem);\n    }\n    return elem;\n}\n\n// Load the slope matrix, indexed by Q's dimension 2.\nACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)\n{\n    const uint32_t h = iq2 + (r % p.gqa_ratio);\n\n    uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;\n\n    const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);\n    const int      exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);\n\n    return ACC_TYPE(pow(base, ACC_TYPE(exph)));\n}\n\n// Load the sink value, indexed by Q's dimension 2.\nACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)\n{\n    const uint32_t h = iq2 + (r % p.gqa_ratio);\n\n    return ACC_TYPE(data_s[h]);\n}\n\nuint32_t i, N, KV, split_k_index, Tr, start_j, end_j,\n         gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,\n         q_stride, k_stride, v_stride, m_stride;\n\nvoid init_indices()\n{\n    N = p.N;\n    KV = p.KV;\n\n    if (p.k_num > 1) {\n        if (p.gqa_ratio > 1) {\n            i = 0;\n            // batch and split_k share gl_WorkGroupID.x\n            gqa_iq1 = gl_WorkGroupID.x / p.k_num;\n            split_k_index = gl_WorkGroupID.x % p.k_num;\n        } else {\n            gqa_iq1 = 0;\n            split_k_index = gl_WorkGroupID.x % p.k_num;\n            i = gl_WorkGroupID.x / p.k_num;\n        }\n    } else if (p.gqa_ratio > 1) {\n        i = 0;\n        gqa_iq1 = gl_WorkGroupID.x;\n        split_k_index = 0;\n    } else {\n        i = gl_WorkGroupID.x;\n        gqa_iq1 = 0;\n        split_k_index = 0;\n    }\n\n    Tr = 
CEIL_DIV(N, Br);\n\n    start_j = split_k_index * p.split_kv / Bc;\n    end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);\n\n    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.\n    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.\n    iq2 = gl_WorkGroupID.y * p.gqa_ratio;\n    iq3 = gl_WorkGroupID.z;\n\n    // broadcast factors\n    rk2 = p.neq2/p.nek2;\n    rk3 = p.neq3/p.nek3;\n\n    rv2 = p.neq2/p.nev2;\n    rv3 = p.neq3/p.nev3;\n\n    // k indices\n    ik3 = iq3 / rk3;\n    ik2 = iq2 / rk2;\n\n    // v indices\n    iv3 = iq3 / rv3;\n    iv2 = iq2 / rv2;\n\n    // nb?1 are already divided by the type size and are in units of elements.\n    // When using grouped query attention, Q is indexed by iq2, so the stride\n    // should be nb02 (which is in bytes).\n    q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;\n    k_stride = p.nb11;\n    v_stride = p.nb21;\n    // When using grouped query attention, all rows use the same mask (stride 0).\n    // \"p.gqa_ratio >> 16\" is just a roundabout way of writing zero\n    // that prevents the compiler from folding the \"&\" through the select\n    // and breaking the alignment detection.\n    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;\n}\n\n// Bias applied to softmax to stay in fp16 range.\n// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606\nconst float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;\n\n// Store the output when doing grouped query attention.\n// Rows index by Q's dimension 2, and the first N rows are valid.\nvoid gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)\n{\n    uint32_t offset = (iq2 + r) * HSV / 4 + c;\n    data_ov4[o_offset + offset] = D_TYPEV4(elems);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n\n#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#extension GL_KHR_shader_subgroup_basic : enable\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_vote : enable\n#extension GL_KHR_memory_scope_semantics : enable\n#extension GL_KHR_cooperative_matrix : enable\n\n#include \"types.glsl\"\n#include \"flash_attn_base.glsl\"\n\n// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd\nconst uint32_t MatBr = 16;\nconst uint32_t MatBc = 16;\n\nconst uint32_t rows_per_thread = Br / row_split;\nconst uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;\nconst uint32_t cols_per_thread = Bc / cols_per_iter;\n\n\nlayout (binding = 0) readonly buffer Q {float data_q[];};\nlayout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};\nlayout (binding = 1) readonly buffer K {float16_t data_k[];};\nlayout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};\nlayout (binding = 2) readonly buffer V {float16_t data_v[];};\nlayout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};\nlayout (binding = 3) readonly buffer M {float16_t data_m[];};\n\nshared float tmpsh[row_split];\n\nconst uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4\nshared f16vec4 Qf[Br * qstride];\n\nconst uint psh_stride = Br / 4 + 2;\nshared f16vec4 Psh[Bc * psh_stride];\n\n// Avoid padding for hsk==256 to make it fit in 48KB shmem.\nconst uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;\nshared ACC_TYPEV4 sfsh[Bc * sfshstride];\n\nconst uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;\nconst uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? 
D_pad : MatBr) / 4 + 2; // in units of f16vec4\nconst uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups\nconst uint vsh_stride = v_cols;\nshared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];\n\nconst uint32_t osh_stride = row_split * MatBr / 4;\nshared f16vec4 pvsh[MatBc * osh_stride];\n\nshared ACC_TYPE slope[Br];\n\nvoid main() {\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n    init_indices();\n\n    const uint32_t tid = gl_LocalInvocationIndex;\n\n    const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;\n    const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;\n    const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;\n    const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;\n\n#define tile_row(r) (row_tid * rows_per_thread + (r))\n\n    // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).\n    if ((HSK % 16) != 0) {\n        [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {\n            if (i + tid < Br * qstride) {\n                Qf[i + tid] = f16vec4(0);\n            }\n        }\n        barrier();\n    }\n\n    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;\n\n    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {\n        uint32_t d = (idx + tid) % (HSK / 4);\n        uint32_t r = (idx + tid) / (HSK / 4);\n        if (r < Br && d < HSK / 4 &&\n            i * Br + r < N) {\n            Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);\n        }\n    }\n    barrier();\n\n    f16vec4 Of[rows_per_thread][d_per_thread];\n    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n        [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {\n            Of[r][d] = f16vec4(0.0);\n        }\n  
  }\n\n    float Lf[rows_per_thread], Mf[rows_per_thread];\n\n    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.\n    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);\n\n    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n        Lf[r] = 0;\n        Mf[r] = NEG_FLT_MAX_OVER_2;\n    }\n\n    // ALiBi\n    if (p.max_bias > 0.0f) {\n        if (tid < Br) {\n            uint r = tid;\n            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);\n        }\n    } else {\n        if (tid < Br) {\n            uint r = tid;\n            slope[r] = ACC_TYPE(1.0);\n        }\n    }\n\n    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);\n    // mo_offset will point to the tile starting at row i*Br and col 0\n    uint32_t mo_offset = mo_stride * i;\n\n#if BLOCK_SIZE > 1\n    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;\n    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;\n#else\n    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;\n    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;\n#endif\n    uint32_t m_offset = gqa_iq1*KV;\n    if (p.nem2 != 1 || p.nem3 != 1) {\n        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;\n        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;\n    }\n\n    uint32_t mask_opt = 0;\n    uint32_t mask_opt_idx = ~0;\n    uint32_t mask_opt_bits = 0;\n    f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];\n\n    [[dont_unroll]]\n    for (uint32_t j = start_j; j < end_j; ++j) {\n\n        [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {\n            mask_cache[idx] = f16vec4(0);\n        }\n\n        if (MASK_ENABLE) {\n            if (USE_MASK_OPT && mask_opt_idx != j / 16) {\n                mask_opt_idx = j / 16;\n                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];\n            }\n            mask_opt_bits = 
(mask_opt >> ((j % 16) * 2)) & 0x3;\n            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {\n                // skip this block\n                continue;\n            }\n            // Only load if the block is not all zeros\n            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {\n                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;\n\n                float max_mask = NEG_FLT_MAX_OVER_2;\n                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {\n                    uint32_t c = (idx + tid) / (Br / 4);\n                    uint32_t r = (idx + tid) % (Br / 4);\n                    if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {\n                        if ((!KV_bounds_check || j * Bc + c < KV)) {\n                            f16vec4 m;\n                            if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {\n                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],\n                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],\n                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],\n                                            data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);\n                                max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));\n                            } else if (i * Br + r * 4 + 2 < p.nem1) {\n                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],\n                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],\n                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],\n                                            0.0);\n                                
max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));\n                            } else if (i * Br + r * 4 + 1 < p.nem1) {\n                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],\n                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],\n                                            0.0,\n                                            0.0);\n                                max_mask = max(max(max_mask, float(m[0])), float(m[1]));\n                            } else if (i * Br + r * 4 < p.nem1) {\n                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],\n                                            0.0,\n                                            0.0,\n                                            0.0);\n                                max_mask = max(max_mask, float(m[0]));\n                            } else {\n                                m = f16vec4(0.0);\n                            }\n                            mask_cache[idx / WorkGroupSize] = m;\n                        }\n                    }\n                }\n                // skip the block if the mask is entirely -inf\n                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);\n                barrier();\n                if (gl_SubgroupInvocationID == 0) {\n                    tmpsh[gl_SubgroupID] = all_less ? 
NEG_FLT_MAX_OVER_2 : 0.0f;\n                }\n                barrier();\n                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {\n                    max_mask = max(max_mask, tmpsh[s]);\n                }\n                if (max_mask <= NEG_FLT_MAX_OVER_2) {\n                    continue;\n                }\n            }\n        }\n\n        if (SHMEM_STAGING != 0) {\n            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {\n                uint32_t d = (idx + tid) % (HSK_pad / 4);\n                uint32_t c = (idx + tid) / (HSK_pad / 4);\n                if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {\n                    f16vec4 K_Tf = f16vec4(0);\n                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {\n#if BLOCK_SIZE > 1\n                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;\n                        uint ib = coord / BLOCK_SIZE;\n                        uint iqs = (coord % BLOCK_SIZE);\n                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);\n#else\n                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);\n#endif\n                    }\n\n                    kvsh[c * kvsh_stride + d] = K_Tf;\n                }\n            }\n            barrier();\n        }\n\n        // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br\n        // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16\n        // This is written transposed in order to allow for N being 8 if implementations need it\n        coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);\n        coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;\n        coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> 
QMat;\n\n        [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {\n            // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem\n            // If not, f16 K is loaded directly from global memory if aligned, otherwise\n            // staged through a Bc * MatBr size staging buffer.\n            // If K is not type f16, then it is always staged for dequantization.\n            if (SHMEM_STAGING == 0) {\n#if BLOCK_SIZE == 1\n            if (KV_bounds_check || d * 16 + 16 > HSK) {\n#endif\n            barrier();\n            [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {\n                uint32_t col_vec = (idx + tid) % (MatBr / 4);\n                uint32_t row = (idx + tid) / (MatBr / 4);\n                if (idx + tid < Bc * MatBr / 4) {\n                    f16vec4 K_Tf = f16vec4(0);\n                    if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {\n#if BLOCK_SIZE > 1\n                        uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;\n                        uint ib = coord / BLOCK_SIZE;\n                        uint iqs = (coord % BLOCK_SIZE);\n                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);\n#else\n                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);\n#endif\n                    }\n\n                    kvsh[row * kvsh_stride + col_vec] = K_Tf;\n                }\n            }\n            barrier();\n#if BLOCK_SIZE == 1\n            }\n#endif\n\n#if BLOCK_SIZE == 1\n            if (KV_bounds_check || d * 16 + 16 > HSK)\n#endif\n            {\n                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;\n                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);\n            }\n#if BLOCK_SIZE == 1\n            else {\n                const uint coord = k_offset / 4 
+ (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;\n                coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);\n            }\n#endif\n            } else {\n                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;\n                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);\n            }\n\n            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);\n\n            SfMat = coopMatMulAdd(KMat, QMat, SfMat);\n        }\n\n        uint coord = gl_SubgroupID * MatBc * sfshstride;\n        coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);\n        barrier();\n\n        if (LOGIT_SOFTCAP) {\n            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {\n                uint32_t c = (idx + tid) / (Br / 4);\n                uint32_t r = (idx + tid) % (Br / 4);\n                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {\n                    sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));\n                }\n            }\n            barrier();\n        }\n\n        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {\n            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {\n                uint32_t c = (idx + tid) / (Br / 4);\n                uint32_t r = (idx + tid) % (Br / 4);\n                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {\n                    if (!KV_bounds_check || j * Bc + c < KV) {\n                        // Mask nem1 bounds check is handled when loading masks\n                        ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);\n                        ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);\n                        
sfsh[c * sfshstride + r] += slopes * masks;\n                    }\n                }\n            }\n            barrier();\n        }\n\n        float eMf[rows_per_thread];\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            const uint r_vec  = tile_row(r) / 4;\n            const uint r_comp = tile_row(r) % 4;\n\n            float rowmaxf = NEG_FLT_MAX_OVER_2;\n            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {\n                    continue;\n                }\n                rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));\n            }\n            float Moldf = Mf[r];\n\n            // Compute max across the row\n            rowmaxf = subgroupMax(rowmaxf);\n\n            // M = max(rowmax, Mold)\n            // P = e^(S - M)\n            // eM = e^(Mold - M)\n            Mf[r] = max(rowmaxf, Moldf);\n            eMf[r] = exp(Moldf - Mf[r]);\n\n            Lf[r] = eMf[r]*Lf[r];\n        }\n\n        [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {\n            const uint d_local = d0 / threads_per_rowgroup;\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];\n            }\n        }\n\n        // Calculate and store Pf in Psh\n        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {\n            const uint col = c * cols_per_iter + col_tid;\n\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {\n                const uint row = tile_row(r);\n                if (KV_bounds_check && j * Bc + col >= KV) {\n                    Psh[col * psh_stride + row / 4] = f16vec4(0.0f);\n                } else {\n                    const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);\n                    const f16vec4 Pf = 
f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));\n                    [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {\n                        Lf[r + vec_idx] += Pf[vec_idx];\n                    }\n                    Psh[col * psh_stride + row / 4] = Pf;\n                }\n            }\n        }\n\n        if (SHMEM_STAGING != 0) {\n            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {\n                uint32_t d = (idx + tid) % (HSV_pad / 4);\n                uint32_t c = (idx + tid) / (HSV_pad / 4);\n                if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {\n                    f16vec4 V_Tf = f16vec4(0);\n                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {\n#if BLOCK_SIZE > 1\n                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;\n                        uint ib = coord / BLOCK_SIZE;\n                        uint iqs = (coord % BLOCK_SIZE);\n                        V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);\n#else\n                        V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);\n#endif\n                    }\n\n                    kvsh[c * kvsh_stride + d] = V_Tf;\n                }\n            }\n        }\n        barrier();\n\n        const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up\n\n        // Each subgroup handles HSV/4 columns\n        [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {\n            const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;\n\n            coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);\n\n            // Preload V tiles for [Bc, 16 * num subgroups]\n            const uint v_rows = Bc;\n            const uint 
v_total = v_rows * v_cols;\n            const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;\n\n            // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.\n            // If not, f16 V is loaded directly from global memory if aligned, otherwise\n            // staged through a Bc * MatBr size staging buffer.\n            // If V is not type f16, then it is always staged for dequantization.\n            if (SHMEM_STAGING == 0) {\n#if BLOCK_SIZE == 1\n            // For f16, only preload if not aligned\n            if (KV_bounds_check) {\n#endif\n            [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {\n                const uint idx = i * gl_WorkGroupSize.x + tid;\n                const uint row = idx / v_cols;\n                const uint col = idx % v_cols;\n\n                const uint v_row = j * Bc + row;\n                const uint v_col = hsv_tile * MatBc * row_split + col * 4;\n\n                const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;\n                const uint ib = coord / BLOCK_SIZE;\n                const uint iqs = coord % BLOCK_SIZE;\n\n                if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {\n#if BLOCK_SIZE > 1\n                    kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);\n#else\n                    kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];\n#endif\n                } else {\n                    kvsh[row * vsh_stride + col] = f16vec4(0.0f);\n                }\n            }\n\n#if BLOCK_SIZE == 1\n            }\n#endif\n            }\n            barrier();\n\n            const uint o_offset = gl_SubgroupID * MatBr / 4;\n\n            if (hsv_offset < HSV_pad) {\n                [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {\n                    coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);\n\n       
             if (SHMEM_STAGING == 0) {\n#if BLOCK_SIZE == 1\n                    if (!KV_bounds_check) {\n                        // F16 values can be loaded directly from global memory\n                        const uint v_tile_row = j * Bc + bc_chunk * MatBc;\n                        const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;\n                        coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);\n                    } else\n#endif\n                    {\n                        const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);\n                        coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);\n                    }\n                    } else {\n                        const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);\n                        coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);\n                    }\n\n                    PVMat = coopMatMulAdd(KMat, QMat, PVMat);\n                }\n\n                // Store PVMat to pvsh and load into Of\n                coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);\n            }\n\n            barrier();\n\n            const uint hsv_per_tile = row_split * MatBc;\n            const uint hsv_base = hsv_tile * hsv_per_tile;\n            const uint d_values_per_tile = hsv_per_tile / 4;\n\n            const uint d_start = hsv_tile * d_values_per_tile;\n            const uint d_end = min(d_start + d_values_per_tile, HSV / 4);\n\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                const uint row = tile_row(r);\n\n                [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {\n                    const uint d = d_local * threads_per_rowgroup + 
col_tid;\n                    const uint hsv_col = 4 * d;\n\n                    if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {\n                        const uint local_hsv = (hsv_col - hsv_base) / 4;\n                        Of[r][d_local] += pvsh[row * osh_stride + local_hsv];\n                    }\n                }\n            }\n        }\n\n        barrier();\n    }\n\n    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n        Lf[r] = subgroupAdd(Lf[r]);\n    }\n\n    // If there is split_k, then the split_k resolve shader does the final\n    // division by L. Store the intermediate O value and per-row m and L values.\n    if (p.k_num > 1) {\n        if (p.gqa_ratio > 1) {\n            // note: O and Q have swapped coord 1,2.\n            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;\n\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                if (tile_row(r) < N) {\n                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {\n                        const uint d = d0 + col_tid;\n                        if (d >= HSV/4) break;\n                        const uint d_local = d0 / threads_per_rowgroup;\n                        gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);\n                    }\n                }\n            }\n\n            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n                if (tile_row(r) < N) {\n                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);\n                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);\n                }\n            }\n        } else {\n            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n         
       const uint row = tile_row(r);\n                const uint global_row = i * Br + row;\n\n                if (global_row < N) {\n                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;\n\n                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {\n                        const uint d = d0 + col_tid;\n                        if (d >= HSV/4) break;\n                        data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);\n                    }\n                }\n\n                if (global_row < N && col_tid == 0) {\n                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));\n                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);\n                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);\n                }\n            }\n        }\n\n        return;\n    }\n\n    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);\n\n            float ms = 1.0f;\n            float vs = 1.0f;\n\n            if (sink > Mf[r]) {\n                ms = exp(Mf[r] - sink);\n\n                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {\n                    const uint d_local = d0 / threads_per_rowgroup;\n                    Of[r][d_local] *= float16_t(ms);\n                }\n            } else {\n                vs = exp(sink - Mf[r]);\n            }\n\n            Lf[r] = Lf[r]*ms + vs;\n        }\n    }\n\n    float Lfrcp[rows_per_thread];\n    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n        Lfrcp[r] = (Lf[r] == 0.0) ? 
0.0 : (1.0 / Lf[r]);\n    }\n\n    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {\n        const uint d_local = d0 / threads_per_rowgroup;\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            Of[r][d_local] *= float16_t(Lfrcp[r]);\n#if defined(FLOAT_TYPE_MAX)\n            Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);\n#endif\n        }\n    }\n\n    uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;\n\n    if (p.gqa_ratio > 1) {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            if (tile_row(r) < N) {\n                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {\n                    const uint d = d0 + col_tid;\n                    if (d >= HSV / 4) break;\n                    const uint d_local = d0 / threads_per_rowgroup;\n                    gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);\n                }\n            }\n        }\n    } else {\n        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {\n            if (i * Br + tile_row(r) < N) {\n                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {\n                    const uint d = d0 + col_tid;\n                    if (d >= HSV / 4) break;\n                    const uint d_local = d0 / threads_per_rowgroup;\n                    data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n\n#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n\n#extension GL_KHR_memory_scope_semantics : enable\n#extension GL_KHR_cooperative_matrix : enable\n#extension GL_NV_cooperative_matrix2 : enable\n#extension GL_EXT_buffer_reference : enable\n#extension GL_KHR_shader_subgroup_ballot : enable\n#extension GL_KHR_shader_subgroup_vote : enable\n#extension GL_EXT_null_initializer : enable\n\n#include \"types.glsl\"\n#include \"dequant_funcs_cm2.glsl\"\n#include \"flash_attn_base.glsl\"\n\nlayout (binding = 0) readonly buffer Q {uint8_t data_q[];};\nlayout (binding = 1) readonly buffer K {uint8_t data_k[];};\nlayout (binding = 2) readonly buffer V {uint8_t data_v[];};\nlayout (binding = 3) readonly buffer M {uint8_t data_m[];};\n\nACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {\n    return max(x, y);\n}\n\nfloat16_t maxReduceFp16(const in float16_t x, const in float16_t y) {\n    return max(x, y);\n}\n\nACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {\n    return x;\n}\n\n// Replace matrix elements >= numRows or numCols with 'replace'\nACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {\n    if (row >= numRows || col >= numCols) {\n        return replace;\n    }\n    return elem;\n}\n\nACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)\n{\n    return exp(elem);\n}\n\nACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)\n{\n    return max(elem0, elem1);\n}\n\n#if BLOCK_SIZE > 
1\n#define DECODEFUNC , DEQUANTFUNC\n#else\n#define DECODEFUNC\n#endif\n\n// Store the output when doing grouped query attention.\n// Rows index by Q's dimension 2, and the first N rows are valid.\nD_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)\n{\n    if (r < N && c < HSV) {\n        uint32_t offset = (iq2 + r) * HSV + c;\n        data_o[o_offset + offset] = D_TYPE(elem);\n    }\n    return elem;\n}\n\n// Store O values for non-GQA split_k. Rows are tokens, not heads.\nD_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {\n    uint32_t global_row = i * Br + r;\n    if (global_row < N && c < HSV) {\n        uint32_t o_off = HSV * p.ne1\n            * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));\n        data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);\n    }\n    return elem;\n}\n\n// Store L/M values for non-GQA split_k.\nACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {\n    uint32_t global_row = i * Br + r;\n    if (global_row < N && c == 0) {\n        uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3\n            + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));\n        data_o[lm_off + lm_base + iq2] = D_TYPE(elem);\n    }\n    return elem;\n}\n\nvoid main() {\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n    init_indices();\n\n    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);\n    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);\n    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);\n\n    
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);\n\n#if BLOCK_SIZE > 1\n    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);\n    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);\n#endif\n\n    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);\n    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);\n    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);\n\n    // hint to the compiler that strides are aligned for the aligned variant of the shader\n    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)\n    {\n        q_stride &= ~7;\n#if BLOCK_SIZE == 1\n        k_stride &= ~7;\n        v_stride &= ~7;\n#endif\n        m_stride &= ~7;\n    }\n    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);\n    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);\n    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);\n\n    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;\n    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;\n\n    uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;\n    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));\n\n    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);\n    Qf16 *= float16_t(p.scale);\n\n    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);\n\n    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;\n\n    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. 
when computing Mold-M.\n    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);\n\n    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);\n#if defined(ACC_TYPE_MAX)\n    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));\n#else\n    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);\n#endif\n\n    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);\n\n    // ALiBi\n    if (p.max_bias > 0.0f) {\n        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);\n    }\n\n    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);\n    // mo_offset will point to the tile starting at row i*Br and col 0\n    uint32_t mo_offset = mo_stride * i;\n\n    uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;\n    if (p.nem2 != 1 || p.nem3 != 1) {\n        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;\n        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;\n    }\n\n    uint32_t mask_opt = 0;\n    uint32_t mask_opt_idx = ~0;\n\n    [[dont_unroll]]\n    for (uint32_t j = start_j; j < end_j; ++j) {\n\n        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);\n        if (MASK_ENABLE) {\n\n            if (USE_MASK_OPT && mask_opt_idx != j / 16) {\n                mask_opt_idx = j / 16;\n                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];\n            }\n            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;\n            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {\n                // skip this block\n                continue;\n            }\n            // Only load if the block is not all zeros\n            
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {\n                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;\n\n                if (nem1_bounds_check) {\n                    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);\n                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);\n                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);\n                    tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t\n\n                    coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;\n\n                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));\n                    // skip the block if the mask is entirely -inf\n                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);\n                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {\n                        continue;\n                    }\n                } else {\n                    tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);\n                    // Don't clamp against nem1 when GQA is enabled\n                    uint32_t m_height = p.gqa_ratio > 1 ? 
~0 : p.nem1;\n                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);\n                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);\n\n                    coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;\n\n                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));\n                    // skip the block if the mask is entirely -inf\n                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);\n                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {\n                        continue;\n                    }\n                }\n            }\n        }\n\n        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);\n\n        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;\n\n        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;\n        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);\n        S = coopMatMulAdd(Qf16, K_T, S);\n\n        if (LOGIT_SOFTCAP) {\n            [[unroll]]\n            for (int k = 0; k < S.length(); ++k) {\n                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);\n            }\n        }\n\n        if (MASK_ENABLE) {\n            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);\n        }\n\n        // Clear padding elements to -inf, so they don't contribute to rowmax\n        if (Clamp != 0 &&\n            ((j + 1) * Bc > KV ||\n             (i + 1) * Br > N)) {\n\n            uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br;\n            uint C = ((j + 1) * Bc > KV) ? 
(KV % Bc) : Bc;\n\n            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);\n        }\n\n        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;\n\n        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);\n\n        rowmax += coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(FATTN_KQ_MAX_OFFSET);\n\n        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;\n\n        // M = max(rowmax, Mold)\n        // P = e^(S - M)\n        // eM = e^(Mold - M)\n        coopMatPerElementNV(M, rowmax, Max, Mold);\n        coopMatPerElementNV(P, S - M, Exp);\n        coopMatPerElementNV(eM, Mold - M, Exp);\n\n        // Clear padding elements to 0, so they don't contribute to rowsum\n        if (Clamp != 0 &&\n            ((j + 1) * Bc > KV ||\n             (i + 1) * Br > N)) {\n\n            uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br;\n            uint C = ((j + 1) * Bc > KV) ? 
(KV % Bc) : Bc;\n\n            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);\n        }\n\n        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);\n\n        // compute rowsum by multiplying by matrix of all ones.\n        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);\n\n        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);\n        rowsum = coopMatMulAdd(P_A, One, rowsum);\n\n        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;\n        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;\n        coopMatLoadTensorNV(V,  data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);\n\n        L = eM*L + rowsum;\n\n        // This is the \"diagonal\" matrix in the paper, but since we do componentwise\n        // multiply rather than matrix multiply it has the diagonal element smeared\n        // across the row\n        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;\n\n        // resize eM by using smear/reduce\n        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);\n\n        O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);\n        O = coopMatMulAdd(P_A, V, O);\n    }\n\n    // If there is split_k, then the split_k resolve shader does the final\n    // division by L. 
Store the intermediate O value and per-row m and L values.\n    if (p.k_num > 1) {\n        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);\n\n        if (p.gqa_ratio > 1) {\n            // note: O and Q have swapped coord 1,2.\n            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));\n            coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);\n\n            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));\n            coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);\n            coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);\n        } else {\n            coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);\n            coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);\n            coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);\n        }\n        return;\n    }\n\n    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;\n\n    // resize L by using smear/reduce\n    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);\n\n    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {\n        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;\n        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);\n\n        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;\n\n        // resize M by using smear/reduce\n        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);\n\n        // O, Ldiag, Mr all have the same type so all element locations match\n        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {\n            ACC_TYPE sink = S[i];\n\n            ACC_TYPE ms = ACC_TYPE(1.0f);\n            ACC_TYPE vs 
= ACC_TYPE(1.0f);\n\n            if (sink > Mr[i]) {\n                ms = exp(Mr[i] - sink);\n\n                O[i] *= float16_t(ms);\n            } else {\n                vs = exp(sink - Mr[i]);\n            }\n\n            Ldiag[i] = Ldiag[i]*ms + vs;\n        }\n    }\n\n    [[unroll]]\n    for (int k = 0; k < Ldiag.length(); ++k) {\n        Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);\n    }\n\n    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);\n\n    O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(Ldiag)*O_D;\n\n#if defined(ACC_TYPE_MAX)\n    [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); }\n#endif\n\n    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;\n\n    if (p.gqa_ratio > 1) {\n        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);\n    } else {\n        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);\n        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);\n\n        // permute dimensions\n        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);\n\n        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : enable\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 128;\nlayout (constant_id = 1) const uint NUM_SUBGROUPS = 4;\nlayout (constant_id = 2) const uint Br = 32;\nlayout (constant_id = 3) const uint Bc = 32;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {float16_t data_a[];};\nlayout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};\nlayout (binding = 1) writeonly buffer D {uint data_d[];};\n\nlayout (push_constant) uniform parameter {\n    uint nem0;\n    uint nem1;\n    uint nem2;\n    uint nbm1;\n    uint nbm2;\n    uint nbm3;\n    uint nbd1;\n    uint nbd2;\n    uint nbd3;\n};\n\n#define MASK_OPT_ALL_NEG_INF 1\n#define MASK_OPT_ALL_ZERO 2\n\nshared float minsh[NUM_SUBGROUPS];\nshared float maxsh[NUM_SUBGROUPS];\n\nfloat FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);\n\nvoid loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {\n    const uint tid = gl_LocalInvocationIndex;\n\n    [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {\n        float min_v = FLT_MAX_OVER_2;\n        float max_v = -FLT_MAX_OVER_2;\n        [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {\n            uint j0 = (i + tid) % (Bc / 4);\n            uint j1 = (i + tid) / (Bc / 4);\n\n            j0 *= 4;\n            j0 += (i0 * 16 + block_x) * Bc;\n            j1 += i1 * Br;\n\n            if (!need_bounds_check || j0 + 3 < nem0) {\n                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);\n                [[unroll]] for (int c = 0; c < 4; ++c) {\n                    min_v = min(min_v, f[c]);\n                    max_v = max(max_v, f[c]);\n                }\n            } else {\n                [[unroll]] for (int c = 
0; c < 4; ++c) {\n                    if (j0 + c < nem0) {\n                        float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);\n                        min_v = min(min_v, f);\n                        max_v = max(max_v, f);\n                    }\n                }\n            }\n        }\n        min_v = subgroupMin(min_v);\n        max_v = subgroupMax(max_v);\n        if (gl_SubgroupInvocationID == 0) {\n            minsh[gl_SubgroupID] = min_v;\n            maxsh[gl_SubgroupID] = max_v;\n        }\n        barrier();\n        if (tid == 0) {\n            [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {\n                min_v = min(min_v, minsh[i]);\n                max_v = max(max_v, maxsh[i]);\n            }\n            if (max_v <= -FLT_MAX_OVER_2) {\n                result |= 1 << (2*block_x);\n            }\n            if (min_v == 0.0f && max_v == 0.0f) {\n                result |= 2 << (2*block_x);\n            }\n        }\n        barrier();\n    }\n}\n\n// For each Br x Bc block of the mask (input) buffer, read all values and check\n// if it's all -inf or all zero. Write out a two-bit code indicating which it is\n// (or zero for neither). 
Each workgroup processes 16 tiles and writes out a\n// 32-bit result mask.\n//\n// TODO: This is a lot of work per workgroup, might make sense to split this into\n// more workgroups in the future.\nvoid main() {\n    // Each workgroup handles a row\n    const uint tid = gl_LocalInvocationIndex;\n    const uint i0 = gl_WorkGroupID.x;\n    const uint i1 = gl_WorkGroupID.y;\n    const uint i2 = gl_WorkGroupID.z % nem2;\n    const uint i3 = gl_WorkGroupID.z / nem2;\n\n    uint result = 0;\n\n    // Fast path for fully in-bounds blocks where we can do f16vec4 loads\n    if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&\n        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {\n        if ((i0 + 1) * 16 * Bc <= nem0) {\n            loadvec4(result, i0, i1, i2, i3, false);\n        } else {\n            loadvec4(result, i0, i1, i2, i3, true);\n        }\n    } else {\n        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {\n            float min_v = FLT_MAX_OVER_2;\n            float max_v = -FLT_MAX_OVER_2;\n            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {\n                if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {\n                    continue;\n                }\n                uint j0 = (i + tid) % Bc;\n                uint j1 = (i + tid) / Bc;\n\n                j0 += (i0 * 16 + block_x) * Bc;\n                j1 += i1 * Br;\n\n                if (j0 < nem0 && j1 < nem1) {\n                    float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);\n                    min_v = min(min_v, f);\n                    max_v = max(max_v, f);\n                }\n            }\n            min_v = subgroupMin(min_v);\n            max_v = subgroupMax(max_v);\n            if (gl_SubgroupInvocationID == 0) {\n                minsh[gl_SubgroupID] = min_v;\n                maxsh[gl_SubgroupID] = max_v;\n            }\n            barrier();\n            if (tid == 0) {\n                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) 
{\n                    min_v = min(min_v, minsh[i]);\n                    max_v = max(max_v, maxsh[i]);\n                }\n                if (max_v <= -FLT_MAX_OVER_2) {\n                    result |= 1 << (2*block_x);\n                }\n                if (min_v == 0.0f && max_v == 0.0f) {\n                    result |= 2 << (2*block_x);\n                }\n            }\n            barrier();\n        }\n    }\n\n    if (tid == 0) {\n        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(constant_id = 0) const uint BLOCK_SIZE = 32;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {float data_a[];};\nlayout (binding = 1) readonly buffer B {float data_s[];};\nlayout (binding = 2) writeonly buffer D {float data_d[];};\n\nlayout (push_constant) uniform parameter {\n    uint D;\n    uint ne1;\n    uint ne2;\n    uint ne3;\n    uint k_num;\n    uint sinks;\n} p;\n\nshared float tmpsh[BLOCK_SIZE];\n\nvoid main() {\n    // Each workgroup handles a row\n    const uint n = gl_WorkGroupID.x;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint i2 = gl_WorkGroupID.z % p.ne2;\n    const uint i3 = gl_WorkGroupID.z / p.ne2;\n\n    uint D = p.D;\n    uint k_num = p.k_num;\n\n    uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n;\n    uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n;\n    uint lm_stride = p.ne1 * 2;\n\n    // Compute the max m value for the row\n    float m_max = -1.0/0.0;\n    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {\n        float m = data_a[m_offset + (k + tid) * lm_stride];\n        m_max = max(m_max, m);\n    }\n\n    // reduce across the workgroup\n    tmpsh[tid] = m_max;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {\n        if (tid < s) {\n            m_max = max(m_max, tmpsh[tid + s]);\n            tmpsh[tid] = m_max;\n        }\n        barrier();\n    }\n    m_max = tmpsh[0];\n\n    barrier();\n\n    // Compute L based on m_max\n    float L = 0;\n    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {\n        float l = data_a[l_offset + (k + tid) * lm_stride];\n        float m = data_a[m_offset + (k + tid) * lm_stride];\n        L += exp(m - m_max) * l;\n    }\n\n    // reduce across 
the workgroup\n    tmpsh[tid] = L;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {\n        if (tid < s) {\n            L += tmpsh[tid + s];\n            tmpsh[tid] = L;\n        }\n        barrier();\n    }\n    L = tmpsh[0];\n\n    float sink;\n    if (p.sinks != 0) {\n        sink = data_s[n];\n\n        float ms = 1.0f;\n        float vs = 1.0f;\n\n        if (sink > m_max) {\n            ms = exp(m_max - sink);\n        } else {\n            vs = exp(sink - m_max);\n        }\n\n        L = L*ms + vs;\n    }\n\n    L = (L == 0.0) ? 0.0 : 1.0 / L;\n\n    // D dimension is split across workgroups in the y dimension\n    uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;\n    // Scale and sum the O contributions based on m_max and store the result to memory\n    if (d < D) {\n        float O = 0.0;\n        [[unroll]] for (uint k = 0; k < k_num; ++k) {\n            uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d;\n            float m = data_a[m_offset + k * lm_stride];\n            O += exp(m - m_max) * data_a[o_offset];\n        }\n        if (p.sinks != 0) {\n            if (sink > m_max) {\n                float ms = 1.0f;\n                ms = exp(m_max - sink);\n                O *= ms;\n            }\n        }\n        O *= L;\n\n        const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);\n        O = clamp(O, -FLT_MAX, FLT_MAX);\n\n        data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/floor.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    data_d[i] = D_TYPE(floor(x));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : require\n\nlayout(constant_id = 0) const uint S_V = 128;\nlayout(constant_id = 1) const uint KDA = 0;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout(push_constant) uniform Parameters {\n    uint H;\n    uint n_tokens;\n    uint n_seqs;\n    uint s_off;\n    uint sq1, sq2, sq3;\n    uint sv1, sv2, sv3;\n    uint sb1, sb2, sb3;\n    uint neq1, rq3;\n    float scale;\n};\n\nlayout(binding = 0) readonly  buffer QBuf     { FLOAT_TYPE data_q[];     };\nlayout(binding = 1) readonly  buffer KBuf     { FLOAT_TYPE data_k[];     };\nlayout(binding = 2) readonly  buffer VBuf     { FLOAT_TYPE data_v[];     };\nlayout(binding = 3) readonly  buffer GBuf     { FLOAT_TYPE data_g[];     };\nlayout(binding = 4) readonly  buffer BetaBuf  { FLOAT_TYPE data_beta[];  };\nlayout(binding = 5) readonly  buffer StateBuf { FLOAT_TYPE data_state[]; };\nlayout(binding = 6)           buffer DstBuf   { FLOAT_TYPE data_dst[];   };\n\nshared FLOAT_TYPE s_k[S_V];\nshared FLOAT_TYPE s_q[S_V];\nshared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])\n\nvoid main() {\n    const uint head_id = gl_WorkGroupID.x;\n    const uint seq_id  = gl_WorkGroupID.y;\n    const uint col     = gl_LocalInvocationID.x;\n\n    const uint iq1 = head_id % neq1;\n    const uint iq3 = seq_id / rq3;\n\n    const uint state_size = S_V * S_V;\n    const uint state_base = (seq_id * H + head_id) * state_size;\n\n    FLOAT_TYPE state[S_V];\n    [[unroll]] for (uint i = 0; i < S_V; i++) {\n        state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);\n    }\n\n    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;\n\n    for (uint t = 0; t < n_tokens; t++) {\n        const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;\n        const uint k_off = q_off;\n        const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;\n\n        s_q[col] = FLOAT_TYPE(data_q[q_off + col]);\n        s_k[col] = 
FLOAT_TYPE(data_k[k_off + col]);\n\n        const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;\n\n        if (KDA != 0) {\n            const uint g_base = gb_off * S_V;\n            s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));\n        }\n\n        barrier();\n\n        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);\n        const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);\n\n        if (KDA == 0) {\n            const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));\n\n            FLOAT_TYPE kv_col = 0.0;\n            [[unroll]] for (uint i = 0; i < S_V; i += 4) {\n                kv_col += dot(\n                    vec4(state[i], state[i+1], state[i+2], state[i+3]),\n                    vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])\n                );\n            }\n\n            FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;\n\n            FLOAT_TYPE attn_col = 0.0;\n            [[unroll]] for (uint i = 0; i < S_V; i += 4) {\n                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);\n                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);\n                sv = g_val * sv + kv * delta_col;\n                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;\n\n                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));\n            }\n\n            data_dst[attn_off + col] = attn_col * scale;\n        } else {\n            FLOAT_TYPE kv_col = 0.0;\n            [[unroll]] for (uint i = 0; i < S_V; i += 4) {\n                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);\n                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);\n                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);\n                kv_col += dot(gv * sv, kv);\n            }\n\n            FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;\n\n            FLOAT_TYPE attn_col = 0.0;\n            [[unroll]] for (uint i = 0; i < S_V; i 
+= 4) {\n                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);\n                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);\n                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);\n                sv = gv * sv + kv * delta_col;\n                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;\n\n                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));\n            }\n\n            data_dst[attn_off + col] = attn_col * scale;\n        }\n\n        attn_off += S_V * H;\n        barrier();\n    }\n\n    [[unroll]] for (uint i = 0; i < S_V; i++) {\n        data_dst[s_off + state_base + col * S_V + i] = state[i];\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/geglu.comp",
    "content": "#version 450\n\n#include \"glu_head.glsl\"\n\nconst float GELU_COEF_A    = 0.044715f;\nconst float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;\n\nfloat op(float a, float b) {\n    const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);\n    return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;\n}\n\n#include \"glu_main.glsl\"\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/geglu_erf.comp",
    "content": "#version 450\n\n#include \"glu_head.glsl\"\n\n// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation\n// ref: https://www.johndcook.com/blog/python_erf/\nconst float p_erf  = 0.3275911f;\nconst float a1_erf = 0.254829592f;\nconst float a2_erf = -0.284496736f;\nconst float a3_erf = 1.421413741f;\nconst float a4_erf = -1.453152027f;\nconst float a5_erf = 1.061405429f;\n\nconst float SQRT_2_INV = 0.70710678118654752440084436210484f;\n\nfloat op(float a, float b) {\n    const float a_div_sqr2 = a * SQRT_2_INV;\n    const float sign_x = sign(a_div_sqr2);\n    const float x = abs(a_div_sqr2);\n    const float t = 1.0f / (1.0f + p_erf * x);\n    const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);\n    const float erf_approx = sign_x * y;\n\n    return 0.5f * a * (1.0f + erf_approx) * b;\n}\n\n#include \"glu_main.glsl\"\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/geglu_quick.comp",
    "content": "#version 450\n\n#include \"glu_head.glsl\"\n\nconst float GELU_QUICK_COEF = -1.702f;\n\nfloat op(float a, float b) {\n    return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;\n}\n\n#include \"glu_main.glsl\"\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/gelu.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const float GELU_COEF_A    = 0.044715f;\n    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float xi = float(data_a[i]);\n    const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);\n    data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/gelu_erf.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation\n    // ref: https://www.johndcook.com/blog/python_erf/\n    const float p_erf  = 0.3275911f;\n    const float a1_erf = 0.254829592f;\n    const float a2_erf = -0.284496736f;\n    const float a3_erf = 1.421413741f;\n    const float a4_erf = -1.453152027f;\n    const float a5_erf = 1.061405429f;\n\n    const float SQRT_2_INV = 0.70710678118654752440084436210484f;\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float a = float(data_a[i]);\n    const float a_div_sqr2 = a * SQRT_2_INV;\n    const float sign_x = sign(a_div_sqr2);\n    const float x = abs(a_div_sqr2);\n    const float t = 1.0f / (1.0f + p_erf * x);\n    const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);\n    const float erf_approx = sign_x * y;\n\n    data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/gelu_quick.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const float GELU_QUICK_COEF = -1.702f;\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl",
    "content": "#extension GL_EXT_shader_16bit_storage : require\n#extension GL_EXT_control_flow_attributes : require\n\n#include \"rte.glsl\"\n#include \"utils.glsl\"\n#if RMS_NORM_ROPE_FUSION\n#include \"rope_params.glsl\"\n#endif\n\nlayout (push_constant) uniform parameter\n{\n    uint ne;\n    uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;\n    uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;\n    uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;\n    uint misalign_offsets;\n    float param1; float param2; int param3;\n#if RMS_NORM_ROPE_FUSION\n    rope_params rope;\n#endif\n} p;\n\n#if !RMS_NORM_ROPE_FUSION\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\n#if defined(A_TYPE_PACKED16)\nlayout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};\n#endif\n#if defined(A_TYPE_PACKED32)\nlayout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};\n#endif\n\nlayout (binding = 1) readonly buffer B {B_TYPE data_b[];};\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};\n#endif\n\n// true if src0/src1 are the same shape and the indices can be reused without additional modulus\nlayout(constant_id = 0) const bool norepeat = false;\n\nuint get_idx() {\n    return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n}\n\nuint get_aoffset() { return p.misalign_offsets >> 16; }\nuint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }\nuint get_doffset() { return p.misalign_offsets & 0xFF; }\n\n\nvoid get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {\n    get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);\n}\n\nuint src0_idx(uint i00, uint i01, uint i02, uint i03) {\n    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;\n}\n\nuint src1_idx(uint i00, uint i01, uint i02, uint i03) {\n    
if (norepeat) {\n        return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;\n    } else {\n        return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;\n    }\n}\n\nuint dst_idx(uint i00, uint i01, uint i02, uint i03) {\n    return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/generic_head.glsl",
    "content": "#extension GL_EXT_shader_16bit_storage : require\n\nlayout (push_constant) uniform parameter\n{\n    uint KX;\n    uint KY;\n    float param1;\n    float param2;\n    float param3;\n    float param4;\n} p;\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl",
    "content": "#extension GL_EXT_shader_16bit_storage : require\n#extension GL_EXT_control_flow_attributes : require\n\nlayout (push_constant) uniform parameter\n{\n    uint ne;\n    uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;\n    uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;\n    uint misalign_offsets;\n    float param1; float param2;\n\n    uint ne0_012mp; uint ne0_012L;\n    uint ne0_01mp;  uint ne0_01L;\n    uint ne0_0mp;   uint ne0_0L;\n    uint ne1_012mp; uint ne1_012L;\n    uint ne1_01mp;  uint ne1_01L;\n    uint ne1_0mp;   uint ne1_0L;\n} p;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\n#if defined(A_TYPE_PACKED16)\nlayout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};\n#endif\n#if defined(A_TYPE_PACKED32)\nlayout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};\n#endif\n\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nuint get_idx() {\n    return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n}\n\nuint get_aoffset() { return p.misalign_offsets >> 16; }\nuint get_doffset() { return p.misalign_offsets & 0xFFFF; }\n\n// see init_fastdiv_values in ggml-vulkan.cpp\nuint fastdiv(uint n, uint mp, uint L) {\n    uint msbs, lsbs;\n    // msbs = mulhi(n, mp)\n    umulExtended(n, mp, msbs, lsbs);\n    return (msbs + n) >> L;\n}\n\nuint src0_idx(uint idx) {\n    const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);\n    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;\n    const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);\n    const uint i02_offset = i02*p.ne01*p.ne00;\n    const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);\n    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;\n    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;\n}\n\nuint dst_idx(uint idx) {\n    const 
uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);\n    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;\n    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);\n    const uint i12_offset = i12*p.ne11*p.ne10;\n    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);\n    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;\n    return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;\n}\n\nuint src0_idx_quant(uint idx, uint qk) {\n    const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);\n    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;\n    const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);\n    const uint i02_offset = i02*p.ne01*p.ne00;\n    const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);\n    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;\n    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00;\n}\n\nuint dst_idx_quant(uint idx, uint qk) {\n    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);\n    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;\n    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);\n    const uint i12_offset = i12*p.ne11*p.ne10;\n    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);\n    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;\n    return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10;\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/get_rows.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint i00 = gl_GlobalInvocationID.x;\n\n    if (i00 >= p.ne00) {\n        return;\n    }\n\n    uint gid_z = gl_GlobalInvocationID.z;\n    while (gid_z < p.ne11 * p.ne12) {\n        uint gid_y = gl_GlobalInvocationID.y;\n        while (gid_y < p.ne10) {\n            const uint i10 = gid_y;\n            const uint i11 = gid_z / p.ne12;\n            const uint i12 = gid_z % p.ne12;\n\n            const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];\n\n            const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;\n            const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;\n\n#if defined(DATA_A_BF16)\n            TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00]));\n#else\n            TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]);\n#endif\n#ifndef OPTIMIZATION_ERROR_WORKAROUND\n            data_d[d_offset + i00] = D_TYPE(v);\n#else\n            data_d[d_offset + i00] = D_TYPE(v);\n#endif\n            gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n        }\n        gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n#include \"dequant_funcs.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint i00 = (gl_GlobalInvocationID.x)*2;\n\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n    if (i00 >= p.ne00) {\n        return;\n    }\n\n    uint gid_z = gl_GlobalInvocationID.z;\n    while (gid_z < p.ne11 * p.ne12) {\n        uint gid_y = gl_GlobalInvocationID.y;\n        while (gid_y < p.ne10) {\n            const uint i10 = gid_y;\n            const uint i11 = gid_z / p.ne12;\n            const uint i12 = gid_z % p.ne12;\n\n            const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];\n\n            const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;\n            const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;\n\n            const uint ib = a_offset + i00/QUANT_K; // block index\n            const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index\n            const uint iybs = i00 - i00%QUANT_K; // dst block start index\n            const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;\n\n            vec2 v = dequantize(ib, iqs, 0);\n            const vec2 dm = get_dm(ib, 0);\n            v = v * dm.x + dm.y;\n\n            data_d[d_offset + iybs + iqs           ] = D_TYPE(v.x);\n            data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);\n\n            gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n        }\n        gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/glu_head.glsl",
    "content": "#extension GL_EXT_shader_16bit_storage : require\n\n#include \"rte.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) readonly buffer B {A_TYPE data_b[];};\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};\n\nlayout (push_constant) uniform parameter\n{\n    uint N;\n    uint ne00;\n    uint ne20;\n    uint mode;\n    float alpha;\n    float limit;\n} p;\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/glu_main.glsl",
    "content": "void main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.N) {\n        return;\n    }\n\n    const uint row = i / p.ne20;\n    const uint col = i - row * p.ne20;\n\n    if (p.mode == 0) {\n        // Default\n        const uint offset = p.ne00 / 2;\n        const uint idx = row * p.ne00 + col;\n\n        data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));\n    } else if (p.mode == 1) {\n        // Swapped\n        const uint offset = p.ne00 / 2;\n        const uint idx = row * p.ne00 + col;\n\n        data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));\n    } else {\n        // Split\n        const uint idx = row * p.ne00 + col;\n\n        data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/group_norm.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#define BLOCK_SIZE 512\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nshared float tmp[BLOCK_SIZE];\n\nvoid main() {\n    const uint group_size = p.KX;\n    const float eps = p.param1;\n\n    const uint tid = gl_LocalInvocationID.x;\n    const uint start = gl_WorkGroupID.x * group_size + tid;\n    const uint end = (gl_WorkGroupID.x + 1) * group_size;\n\n    tmp[tid] = 0.0f;\n\n    // Calculate mean\n    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {\n        tmp[tid] += float(data_a[col]);\n    }\n\n    // tmp up partial tmps and write back result\n    barrier();\n    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            tmp[tid] += tmp[tid + s];\n        }\n        barrier();\n    }\n\n    const float mean = tmp[0] / group_size;\n    barrier();\n    tmp[tid] = 0.0f;\n\n    // Calculate variance\n    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {\n        const float xi = float(data_a[col]) - mean;\n        data_d[col] = D_TYPE(xi);\n        tmp[tid] += xi * xi;\n    }\n\n    // sum up partial sums and write back result\n    barrier();\n    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            tmp[tid] += tmp[tid + s];\n        }\n        barrier();\n    }\n\n    const float variance = tmp[0] / group_size;\n    const float scale = inversesqrt(variance + eps);\n\n    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {\n        data_d[col] *= D_TYPE(scale);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/hardswish.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/im2col.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_16bit_storage : require\n#extension GL_EXT_control_flow_attributes : require\n\n#include \"rte.glsl\"\n#include \"types.glsl\"\n\nlayout (push_constant) uniform parameter\n{\n    BDA_STORAGE_T dst_addr;\n    uint batch_offset; uint offset_delta;\n    uint IC;\n    uint IW; uint IH;\n    uint OW; uint OH;\n    uint KW; uint KH;\n    uint pelements;\n    uint CHW;\n    int s0; int s1;\n    int p0; int p1;\n    int d0; int d1;\n    uint batch_IC;\n} p;\n\nlayout(constant_id = 0) const uint BLOCK_SIZE = 32;\n\nconst uint NUM_ITER = 512 / BLOCK_SIZE;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\n#if BDA\nlayout (buffer_reference) buffer D_ptr {D_TYPE d;};\n#endif\n\nvoid im2col(const uint y, const uint z) {\n    const uint gidx = gl_GlobalInvocationID.x;\n\n    const uint oh = y;\n    const uint batch = z / p.IC;\n    const uint ic = z % p.IC;\n\n    const uint src_base = ic * p.offset_delta + batch * p.batch_offset;\n    const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);\n    const int oh_s1 = int(oh) * p.s1;\n    const uint ksize = p.OW * p.KH;\n\n    const uint base_linear_idx = gidx * NUM_ITER;\n\n    uint current_kx = base_linear_idx / ksize;\n    const uint rem = base_linear_idx - (current_kx * ksize);\n    uint current_ky = rem / p.OW;\n    uint current_ix = rem % p.OW;\n\n    A_TYPE values[NUM_ITER];\n    BDA_OFFSET_T offset_dst[NUM_ITER];\n    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {\n        values[idx] = A_TYPE(0);\n    }\n\n    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {\n\n        const uint linear_idx = base_linear_idx + idx;\n\n        if (linear_idx >= p.pelements) {\n            continue;\n        }\n\n        const uint iiw = current_ix * p.s0 + current_kx * 
p.d0 - p.p0;\n        const uint iih = oh_s1 + current_ky * p.d1 - p.p1;\n\n        offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;\n\n        if ((iih < p.IH) && (iiw < p.IW)) {\n            values[idx] = data_a[src_base + iih * p.IW + iiw];\n        }\n\n        if (++current_ix == p.OW) {\n            current_ix = 0;\n            if (++current_ky == p.KH) {\n                current_ky = 0;\n                current_kx++;\n            }\n        }\n    }\n\n    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {\n\n        const uint linear_idx = base_linear_idx + idx;\n\n        if (linear_idx >= p.pelements) {\n            continue;\n        }\n\n#if BDA\n        D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);\n        dst_addr.d = D_TYPE(values[idx]);\n#else\n        data_d[offset_dst[idx]] = D_TYPE(values[idx]);\n#endif\n    }\n}\n\nvoid main() {\n    uint y = gl_GlobalInvocationID.y;\n    while (y < p.OH) {\n        uint z = gl_GlobalInvocationID.z;\n        while (z < p.batch_IC) {\n            im2col(y, z);\n            z += gl_NumWorkGroups.z;\n        }\n        y += gl_NumWorkGroups.y;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/im2col_3d.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_16bit_storage : require\n#extension GL_EXT_control_flow_attributes : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"rte.glsl\"\n#include \"types.glsl\"\n\nlayout (push_constant) uniform parameter\n{\n    BDA_STORAGE_T dst_addr;\n    uint32_t nb10;\n    uint32_t nb11;\n    uint32_t nb12;\n    uint32_t nb13;\n    uint32_t s0;\n    uint32_t s1;\n    uint32_t s2;\n    uint32_t p0;\n    uint32_t p1;\n    uint32_t p2;\n    uint32_t d0;\n    uint32_t d1;\n    uint32_t d2;\n    uint32_t IW;\n    uint32_t IH;\n    uint32_t ID;\n    uint32_t IC;\n    uint32_t KW;\n    uint32_t OH;\n    uint32_t KD_KH_KW;\n    uint32_t KH_KW;\n    uint32_t IC_KD_KH_KW;\n    uint32_t N_OD_OH;\n    uint32_t OD_OH;\n    uint32_t OD_OH_OW_IC_KD_KH_KW;\n    uint32_t OH_OW_IC_KD_KH_KW;\n    uint32_t OW_IC_KD_KH_KW;\n    uint32_t misalign_offsets;\n} p;\n\nuint get_aoffset() { return p.misalign_offsets >> 16; }\nuint get_doffset() { return p.misalign_offsets & 0xFFFF; }\n\nlayout(constant_id = 0) const uint BLOCK_SIZE = 32;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\n#if BDA\nlayout (buffer_reference) buffer D_ptr {D_TYPE d;};\n#endif\n\nvoid main() {\n    const uint32_t i = gl_GlobalInvocationID.x;\n\n    uint32_t nb10 = p.nb10;\n    uint32_t nb11 = p.nb11;\n    uint32_t nb12 = p.nb12;\n    uint32_t nb13 = p.nb13;\n    uint32_t s0 = p.s0;\n    uint32_t s1 = p.s1;\n    uint32_t s2 = p.s2;\n    uint32_t p0 = p.p0;\n    uint32_t p1 = p.p1;\n    uint32_t p2 = p.p2;\n    uint32_t d0 = p.d0;\n    uint32_t d1 = p.d1;\n    uint32_t d2 = p.d2;\n    uint32_t IW = p.IW;\n    uint32_t IH = p.IH;\n    uint32_t ID = p.ID;\n    uint32_t IC = p.IC;\n    uint32_t KW = p.KW;\n    uint32_t OH = p.OH;\n    uint32_t KD_KH_KW = p.KD_KH_KW;\n    uint32_t KH_KW = p.KH_KW;\n 
   uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW;\n    uint32_t N_OD_OH = p.N_OD_OH;\n    uint32_t OD_OH = p.OD_OH;\n    uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW;\n    uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW;\n    uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW;\n\n    if (i >= IC_KD_KH_KW) {\n        return;\n    }\n\n    const uint32_t iic = i / KD_KH_KW;\n    const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW;\n    const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;\n    const uint32_t ikw = i % KW;\n\n    const uint32_t iow = gl_GlobalInvocationID.y;\n    for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) {\n        const uint32_t in_ = iz / OD_OH;\n        const uint32_t iod = (iz - in_*OD_OH) / OH;\n        const uint32_t ioh = iz % OH;\n\n        const uint32_t iiw = iow * s0 + ikw * d0 - p0;\n        const uint32_t iih = ioh * s1 + ikh * d1 - p1;\n        const uint32_t iid = iod * s2 + ikd * d2 - p2;\n\n        const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;\n\n        const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;\n#if BDA\n        D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);\n        if (iih >= IH || iiw >= IW || iid >= ID) {\n            dst_addr.d = D_TYPE(0.0f);\n        } else {\n            dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);\n        }\n#else\n        if (iih >= IH || iiw >= IW || iid >= ID) {\n            data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);\n        } else {\n            data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);\n        }\n#endif\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/l2_norm.comp",
    "content": "#version 450\n\n#include \"generic_unary_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#define BLOCK_SIZE 512\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nshared FLOAT_TYPE sum[BLOCK_SIZE];\n\nvoid main() {\n    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n    const uint tid = gl_LocalInvocationID.x;\n\n    const uint i3 = row / (p.ne11 * p.ne12);\n    const uint i3_offset = i3 * p.ne12 * p.ne11;\n    const uint i2 = (row - i3_offset) / p.ne11;\n    const uint i2_offset = i2 * p.ne11;\n    const uint i1 = row - i3_offset - i2_offset;\n\n    sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp\n\n    [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {\n        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);\n        sum[tid] += xi * xi;\n    }\n\n    // sum up partial sums and write back result\n    barrier();\n    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            sum[tid] += sum[tid + s];\n        }\n        barrier();\n    }\n\n    const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));\n\n    [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {\n        data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/leaky_relu.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float val = float(data_a[i]);\n    data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/log.comp",
    "content": "#version 450\n\n#include \"rte.glsl\"\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const float val = float(data_a[get_aoffset() + src0_idx(idx)]);\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\nconst uint num_threads = 256;\n\nlayout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    uint idx = get_idx();\n\n    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation\n    const uint num_iter = 2;\n\n    [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n        if (idx >= p.ne) {\n            continue;\n        }\n        uint i00, i01, i02, i03;\n        get_indices(idx, i00, i01, i02, i03);\n\n        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));\n\n        idx += num_threads;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {float data_a[];};\nlayout (binding = 0) readonly buffer A4 {vec4 data_a4[];};\nlayout (binding = 1) writeonly buffer D {float data_d[];};\nlayout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};\n\nlayout (push_constant) uniform parameter {\n    uint ne;\n    uint k_num;\n} p;\n\nvoid main() {\n    // Each invocation handles four consecutive components\n    const uint idx = gl_GlobalInvocationID.x * 4;\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    // Check if all four components are in bounds and aligned,\n    // then use vector loads\n    if (idx + 3 < p.ne && (p.ne % 4) == 0) {\n        vec4 result = vec4(0.0f);\n\n        [[unroll]] for (uint i = 0; i < p.k_num; i++) {\n            result += data_a4[(i * p.ne + idx) / 4];\n        }\n\n        data_d4[idx / 4] = result;\n    } else {\n        [[unroll]] for (uint j = 0; j < 4; ++j) {\n            if (idx + j < p.ne) {\n                float result = 0.0f;\n\n                [[unroll]] for (uint i = 0; i < p.k_num; i++) {\n                    result += data_a[i * p.ne + idx + j];\n                }\n\n                data_d[idx + j] = result;\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n#include \"dequant_funcs.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\n#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)\n#define K_PER_ITER 8\n#else\n#define K_PER_ITER 2\n#endif\n\n\nuint a_offset, b_offset, d_offset, y_offset;\n\nvoid iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)\n{\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;\n        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index\n        const uint iybs = col - col%QUANT_K; // y block start index\n\n#if K_PER_ITER == 8\n#if QUANT_R == 2\n        const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);\n        const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);\n        const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);\n        const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);\n#else\n        const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);\n        const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);\n#endif\n#else\n        // Check if the second of the pair of elements is OOB, and don't fetch B or\n        // accumulate it. We still fetch a pair of elements for A, which is fine for\n        // quantized formats since they'll be within the same block. 
We should\n        // probably skip fetching the second element for F16/F32, but as of now we\n        // still do.\n        const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);\n\n        FLOAT_TYPE b0 = 0, b1 = 0;\n        b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);\n        if (!OOB) {\n            b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);\n        }\n#endif\n        uint ibi = first_row*p.ncols;\n        [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n            const uint ib = (ibi + col)/QUANT_K; // block index\n            ibi += p.ncols;\n\n#if K_PER_ITER == 8\n            vec4 v = dequantize4(ib, iqs, a_offset);\n            vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);\n\n            const vec2 dm = get_dm(ib, a_offset);\n            if (dm.y != 0) { // quant has min component\n                v = v * dm.x + dm.y;\n                v2 = v2 * dm.x + dm.y;\n            }\n\n            // matrix multiplication\n            FLOAT_TYPE rowtmp = dot(bv0, v);\n            rowtmp += dot(bv1, v2);\n\n            if (dm.y == 0)\n                rowtmp *= dm.x;\n\n            temp[j][n] += rowtmp;\n#else\n            const vec2 v = dequantize(ib, iqs, a_offset);\n\n            // matrix multiplication\n            temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);\n            if (!OOB) {\n                temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);\n            }\n#endif\n        }\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    const uint tid = gl_LocalInvocationID.x;\n\n    get_offsets(a_offset, b_offset, d_offset);\n\n    y_offset = QUANT_R == 1 ? 
1 : QUANT_K/2;\n\n    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);\n    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {\n        num_iters++;\n    }\n    int unroll_count = 4;\n    uint unrolled_iters = num_iters & ~(unroll_count - 1);\n\n#if K_PER_ITER == 2\n    // If the K dimension is odd, we need lastiter==true on the last iteration\n    // so OOB is computed correctly. Skip some unrolling to make that happen.\n    if ((p.ncols & 1) != 0 &&\n        unrolled_iters == num_iters &&\n        unrolled_iters > 0) {\n        unrolled_iters -= unroll_count;\n    }\n#endif\n\n    uint i = 0;\n    while (i < unrolled_iters) {\n        // Manually partially unroll the loop\n        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {\n            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);\n            i++;\n        }\n    }\n\n    unroll_count = 2;\n    unrolled_iters = num_iters & ~(unroll_count - 1);\n\n#if K_PER_ITER == 2\n    if ((p.ncols & 1) != 0 &&\n        unrolled_iters == num_iters &&\n        unrolled_iters > 0) {\n        unrolled_iters -= unroll_count;\n    }\n#endif\n\n    while (i < unrolled_iters) {\n        // Manually partially unroll the loop\n        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {\n            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);\n            i++;\n        }\n    }\n    while (i < num_iters) {\n        iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);\n        i++;\n    }\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n    // do 
NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl",
    "content": "#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n#extension GL_EXT_shader_8bit_storage : require\n\n#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM\n#extension GL_KHR_shader_subgroup_basic : require\n#extension GL_KHR_shader_subgroup_arithmetic : require\n#endif\n\n#ifdef MUL_MAT_ID\n#define EXPERT_COUNT 8\n#endif\n\n#include \"mul_mat_vec_iface.glsl\"\n\nlayout (push_constant) uniform parameter\n{\n    uint ncols;\n    uint stride_a;\n    uint stride_b;\n    uint stride_d;\n\n    uint batch_stride_a;\n    uint batch_stride_b;\n    uint batch_stride_d;\n\n    uint fusion_flags;\n\n#ifdef MUL_MAT_ID\n    uint nei0;\n    uint ne11;\n    uint expert_i1;\n    uint nbi1;\n#else\n    uint base_work_group_y;\n    uint ne02;\n    uint ne12;\n    uint broadcast2;\n    uint broadcast3;\n#endif\n} p;\n\n#ifdef MUL_MAT_ID\nuint expert_id;\n#endif\n\nvoid get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {\n#ifdef MUL_MAT_ID\n    const uint expert_i0 = gl_WorkGroupID.y;\n#else\n    const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;\n#endif\n\n#ifndef MUL_MAT_ID\n    uint batch_idx_a = 0;\n    if (batch_idx != 0) {\n        const uint i13 = batch_idx / p.ne12;\n        const uint i12 = batch_idx % p.ne12;\n\n        const uint i03 = i13 / p.broadcast3;\n        const uint i02 = i12 / p.broadcast2;\n\n        batch_idx_a = i03 * p.ne02 + i02;\n    }\n#else\n    expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];\n#endif\n\n    a_offset =\n#ifdef MUL_MAT_ID\n            expert_id * (p.batch_stride_a / QUANT_K);\n#else\n            batch_idx_a * (p.batch_stride_a / QUANT_K);\n#endif\n    b_offset =\n#ifdef MUL_MAT_ID\n            (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;\n#else\n            batch_idx * p.batch_stride_b;\n#endif\n    d_offset =\n#ifdef MUL_MAT_ID\n            expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;\n#else\n            
batch_idx * p.batch_stride_d;\n#endif\n}\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 32;\nlayout (constant_id = 1) const uint NUM_ROWS = 1;\nlayout (constant_id = 2) const uint NUM_COLS = 1;\n\n#ifdef USE_SUBGROUP_ADD_NO_SHMEM\nvoid reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n            temp[j][n] = subgroupAdd(temp[j][n]);\n        }\n    }\n\n    if (tid == 0) {\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n#ifdef MUL_MAT_ID\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {\n                    temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {\n                    const uint expert_i0 = gl_GlobalInvocationID.y;\n                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {\n                    const uint expert_i0 = gl_GlobalInvocationID.y;\n                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);\n                }\n#else\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {\n                    temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {\n                    temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);\n                }\n#endif\n                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);\n            }\n        }\n    }\n}\n#else\nshared FLOAT_TYPE 
tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];\n\nvoid reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {\n    // subgroupAdd is probably faster on devices that support it,\n    // particularly when the workgroup has more than one subgroup\n#if USE_SUBGROUP_ADD\n    // sum up partial sums within a subgroup\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n            temp[j][n] = subgroupAdd(temp[j][n]);\n        }\n    }\n\n    // Go through shared memory to sum partials across subgroups\n    if (gl_SubgroupInvocationID == 0) {\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n                tmpsh[j][n][gl_SubgroupID] = temp[j][n];\n            }\n        }\n    }\n    barrier();\n    if (tid == 0) {\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n                temp[j][n] = FLOAT_TYPE(0);\n                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {\n                    temp[j][n] += tmpsh[j][n][s];\n                }\n#ifdef MUL_MAT_ID\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {\n                    temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {\n                    const uint expert_i0 = gl_GlobalInvocationID.y;\n                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {\n                    const uint expert_i0 = gl_GlobalInvocationID.y;\n                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);\n                }\n#else\n                if ((p.fusion_flags & 
MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {\n                    temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {\n                    temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);\n                }\n#endif\n                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);\n            }\n        }\n    }\n#else\n    // sum up partial sums and write back result\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n            tmpsh[j][n][tid] = temp[j][n];\n        }\n    }\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {\n        if (tid < s) {\n            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n                [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n                    tmpsh[j][n][tid] += tmpsh[j][n][tid + s];\n                }\n            }\n        }\n        barrier();\n    }\n    if (tid == 0) {\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n#ifdef MUL_MAT_ID\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {\n                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {\n                    const uint expert_i0 = gl_GlobalInvocationID.y;\n                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {\n                    const uint expert_i0 = gl_GlobalInvocationID.y;\n                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);\n                }\n#else\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 
0) {\n                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);\n                }\n                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {\n                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);\n                }\n#endif\n                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);\n            }\n        }\n    }\n#endif\n}\n#endif\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl",
    "content": "#include \"types.glsl\"\n\n#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1\n#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2\n#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4\n#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\n#if defined(A_TYPE_VEC4)\nlayout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};\n#endif\n#if defined(A_TYPE_PACKED16)\nlayout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};\n#endif\n#if defined(A_TYPE_PACKED32)\nlayout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};\n#endif\n\nlayout (binding = 1) readonly buffer B {B_TYPE data_b[];};\n#ifdef B_TYPE_VEC2\nlayout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};\n#endif\n#ifdef B_TYPE_VEC4\nlayout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};\n#endif\n\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};\n\nlayout (binding = 3) readonly buffer Fuse0 {D_TYPE data_fuse0[];};\nlayout (binding = 4) readonly buffer Fuse1 {D_TYPE data_fuse1[];};\n\n#ifdef MUL_MAT_ID\nlayout (binding = 5) readonly buffer IDS {int data_ids[];};\n#endif\n\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i,\n                               const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    // Compute starting index in matrix B for this superblock\n    const uint y_idx = i * QUANT_K + 32 * ib32;\n    uint ibi = a_offset + first_row * num_blocks_per_row + i;\n\n    // Precompute indices for quantization lookup tables\n    const uint qh_base = 2 * ib32;\n    const uint qs_base = 4 * ib32;\n    const uint sc_index = ib32 / 2;\n    const uint sc_shift = 6 * (ib32 & 1);\n\n    // Loop over rows in the superblock\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        // Load per-block scales and shift for quantization\n        const uint16_t[4] scales = data_a[ibi].scales;\n        const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;\n        const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);\n        const uint sc = data_a[ibi].scales[sc_index] >> sc_shift;\n\n        // Temporary caches for decoding\n        FLOAT_TYPE dl_cache[4];\n        uint16_t gvf_cache[4];\n        float delta_cache[4];\n\n        // Precompute the multiplier and lookup values for 4 sub-blocks\n        [[unroll]] for (uint l = 0; l < 4; ++l) {\n            dl_cache[l] = FLOAT_TYPE(d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1));\n            const uint qh = data_a[ibi].qh[qh_base + l / 2] >> (4 * (l & 1));\n            const uint qs = data_a[ibi].qs[qs_base + l];\n            gvf_cache[l] = iq1s_grid[qs | ((qh & 7) << 8)];\n            delta_cache[l] = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA;\n        }\n\n        // Loop over columns of the output\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            // Compute base index for matrix B\n            const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4;\n            vec4 b_vals[8];\n\n            // Load 8 vec4 values from matrix B\n            [[unroll]] for (int idx = 0; idx < 8; ++idx) {\n                b_vals[idx] = vec4(data_b_v4[base_b_idx + idx]);\n            }\n\n            FLOAT_TYPE col_sum = FLOAT_TYPE(0.0);\n\n            // Loop over sub-blocks\n            [[unroll]] for (uint l = 0; l < 4; ++l) {\n                const uint16_t grid = gvf_cache[l];\n                const float dl = dl_cache[l];\n\n                // Decode 8 2-bit fbits from gvf_cache\n                float f0 = float(bitfieldExtract(grid, 0, 2));\n                float f1 = float(bitfieldExtract(grid, 2, 2));\n                float f2 = float(bitfieldExtract(grid, 4, 2));\n                float f3 = float(bitfieldExtract(grid, 6, 2));\n                float f4 = float(bitfieldExtract(grid, 8, 2));\n                float f5 = float(bitfieldExtract(grid, 10, 2));\n                float f6 = float(bitfieldExtract(grid, 12, 2));\n                float f7 = float(bitfieldExtract(grid, 14, 2));\n\n                // Pack into vec4 for vectorized FMA\n                const vec4 fbits_v0 = vec4(f0, f1, f2, f3);\n                const vec4 fbits_v1 = vec4(f4, f5, f6, f7);\n                const vec4 delta_v = vec4(delta_cache[l]);\n\n                // Vectorized fused multiply-add\n                vec4 sum_v = fma(b_vals[2*l + 0], fbits_v0 + delta_v, vec4(0.0));\n                sum_v      = fma(b_vals[2*l + 1], fbits_v1 + delta_v, sum_v);\n\n                // Horizontal add to get scalar sum\n                FLOAT_TYPE sum = sum_v.x + sum_v.y + sum_v.z + sum_v.w;\n\n                // Accumulate to column sum\n                col_sum = fma(dl, sum, col_sum);\n        
    }\n            // Write result to temporary buffer\n            temp[j][n] += col_sum;\n        }\n        ibi += num_blocks_per_row;\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 8 threads are used to process each block\n    const uint blocks_per_wg = gl_WorkGroupSize.x/8;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid % 8;  // 0...7\n    const uint ix = tid / 8;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)\n        calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i,\n                     const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    const uint y_idx_base = i * QUANT_K + 32 * ib32;\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx_base) / 4;\n        [[unroll]] for (uint l = 0; l < 4; ++l) {\n            const vec4 b_val_0 = vec4(data_b_v4[base_b_idx + 2 * l]);\n            const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]);\n\n            // index for data_a\n            uint ibi = a_offset + first_row * num_blocks_per_row + i;\n\n            [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n                const float d = float(data_a[ibi].d);\n                const uint qh = data_a[ibi].qh[ib32];\n\n                const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);\n                const uint qs = data_a[ibi].qs[4 * ib32 + l];\n                const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);\n                const uint16_t grid = uint16_t(iq1s_grid[qs | (idxhi << 8)]);\n\n                const float delta_val = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA;\n                const vec4 delta_v = vec4(delta_val);\n                const vec4 fbits0 = vec4(\n                    float(bitfieldExtract(grid, 0, 2)),\n                    float(bitfieldExtract(grid, 2, 2)),\n                    float(bitfieldExtract(grid, 4, 2)),\n                    float(bitfieldExtract(grid, 6, 2))\n                );\n                const vec4 fbits1 = vec4(\n                    float(bitfieldExtract(grid, 8, 2)),\n                    float(bitfieldExtract(grid, 10, 2)),\n                    float(bitfieldExtract(grid, 12, 2)),\n                    float(bitfieldExtract(grid, 14, 2))\n                );\n\n                vec4 sum_v = fma(b_val_0, fbits0 + delta_v, vec4(0.0));\n                sum_v      = fma(b_val_1, fbits1 + delta_v, sum_v);\n                FLOAT_TYPE sum = dot(sum_v, vec4(1.0));\n\n                temp[j][n] = fma(dl, sum, temp[j][n]);\n                ibi += num_blocks_per_row;\n            }\n        }\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 8 threads are used to process each block\n    const uint blocks_per_wg = gl_WorkGroupSize.x/8;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid % 8;  // 0...7\n    const uint ix = tid / 8;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)\n        calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    
init_iq_shmem(gl_WorkGroupSize);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    const uint y_idx = i * QUANT_K + 16 * itid;\n    const uint nibble_shift = 4 * (itid & 1);\n    const uint ib32 = itid / 2; // 0..7\n\n    uint ibi = a_offset + first_row * num_blocks_per_row + i;\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const float d = float(data_a[ibi].d);\n        const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;\n        const float db = d * (0.5 + scale) * 0.25;\n\n        const uint qh = data_a[ibi].qh[ib32];\n        const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147\n        const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;\n        [[unroll]] for (uint l = 0; l < 2; ++l) {\n            const uint8_t sign = sign16[l];\n            const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);\n            const uvec2 grid = iq2s_grid[qs];\n            const vec4 grid0 = vec4(unpack8(grid.x));\n            const vec4 grid1 = vec4(unpack8(grid.y));\n\n            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n                vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);\n                vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);\n\n                FLOAT_TYPE sum =\n                      fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x),\n                      fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign &   2) != 0 ? 
-grid0.y : grid0.y),\n                      fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z),\n                      fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w),\n                      fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x),\n                      fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y),\n                      fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z),\n                      fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),\n                      FLOAT_TYPE(0.0)))))))));\n                temp[j][n] = fma(db, sum, temp[j][n]);\n            }\n        }\n        ibi += num_blocks_per_row;\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 16 threads are used to process each block\n    const uint blocks_per_wg = gl_WorkGroupSize.x/16;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid % 16;  // 0...15\n    const uint ix = tid / 16;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)\n        calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if 
(first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    const uint y_idx = i * QUANT_K + 16 * itid;\n    const uint nibble_shift = 4 * (itid & 1);\n    const uint ib32 = itid / 2; // 0..7\n    uint ibi = a_offset + first_row * num_blocks_per_row + i;\n    // Precompute db multiplication factors\n    float db_vals[NUM_ROWS];\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const float d = float(data_a[ibi].d);\n        const uint scale_raw = data_a[ibi].scales[ib32];\n        const uint scale = (scale_raw >> nibble_shift) & 0xF;\n        // Merge constant calculations d * (0.5 + scale) * 0.25 = d*0.125 + d*scale*0.25\n        db_vals[n] = d * (0.125f + float(scale) * 0.25f);\n        ibi += num_blocks_per_row;\n    }\n    ibi = a_offset + first_row * num_blocks_per_row + i;\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        // Preload grid and sign data for all l values\n        vec4 grid0_vals[2], grid1_vals[2];\n        uint sign_vals[2], sign7_vals[2];\n        [[unroll]] for (uint l = 0; l < 2; ++l) {\n            const uint qs = data_a[ibi].qs[2 * itid + l];\n            sign_vals[l] = qs >> 9;\n            sign7_vals[l] = bitCount(sign_vals[l]);\n            const uvec2 grid_data = iq2xs_grid[qs & 511];\n            grid0_vals[l] = vec4(unpack8(grid_data.x));\n            grid1_vals[l] = vec4(unpack8(grid_data.y));\n        }\n        // Preload B data for all j columns (reduce repeated index calculations)\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            FLOAT_TYPE sum = FLOAT_TYPE(0.0);\n            [[unroll]] for (uint l = 0; l < 2; ++l) {\n          
      const uint sign = sign_vals[l];\n                const uint sign7 = sign7_vals[l];\n                const vec4 grid0 = grid0_vals[l];\n                const vec4 grid1 = grid1_vals[l];\n                // Precompute indices\n                const uint b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4 + 2 * l;\n                const vec4 b0 = vec4(data_b_v4[b_idx + 0]);\n                const vec4 b4 = vec4(data_b_v4[b_idx + 1]);\n                sum +=\n                    fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x),\n                    fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y),\n                    fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z),\n                    fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w),\n                    fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x),\n                    fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y),\n                    fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z),\n                    fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 &  1) != 0 ? 
-grid1.w : grid1.w),\n                    FLOAT_TYPE(0.0)))))))));\n            }\n            temp[j][n] = fma(FLOAT_TYPE(db_vals[n]), sum, temp[j][n]);\n        }\n        ibi += num_blocks_per_row;\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 16 threads are used to process each block\n    const uint blocks_per_wg = gl_WorkGroupSize.x/16;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid % 16;  // 0...15\n    const uint ix = tid / 16;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)\n        calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    const uint y_idx = i * QUANT_K + 16 * itid;\n    const uint ib32 = itid / 2; // 0..7\n\n    uint ibi = a_offset + first_row * num_blocks_per_row + i;\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const float d = float(data_a[ibi].d);\n        const uint signscale = pack32(u16vec2(\n            data_a_packed16[ibi].qs[4 * ib32 + 2],\n            data_a_packed16[ibi].qs[4 * ib32 + 3]));\n        const float db = d * 0.25 * (0.5 + (signscale >> 28));\n        [[unroll]] for (uint l = 0; l < 2; ++l) {\n            const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l];\n            const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);\n            const uint sign7 = bitCount(sign);\n            const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x));\n            const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y));\n\n            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n                const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);\n                const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);\n\n                FLOAT_TYPE sum =\n                      fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x),\n                      fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y),\n                      fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z),\n                      fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign &   8) != 0 ? 
-grid0.w : grid0.w),\n                      fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x),\n                      fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y),\n                      fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z),\n                      fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 &  1) != 0 ? -grid1.w : grid1.w),\n                      FLOAT_TYPE(0.0)))))))));\n                temp[j][n] = fma(db, sum, temp[j][n]);\n            }\n        }\n        ibi += num_blocks_per_row;\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 16 threads are used to process each block\n    const uint blocks_per_wg = gl_WorkGroupSize.x/16;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid % 16;  // 0...15\n    const uint ix = tid / 16;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)\n        calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    const uint y_idx = i * QUANT_K + 32 * ib32;\n\n    uint ibi = a_offset + first_row * num_blocks_per_row + i;\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const float d = float(data_a[ibi].d);\n        const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;\n        const float dscale = d * (1 + 2 * scale);\n        const uint qh = data_a[ibi].qh[ib32];\n        FLOAT_TYPE sum[NUM_COLS];\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            sum[j] = 0.0;\n        }\n        [[unroll]] for (uint l = 0; l < 4; ++l) {\n            const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147\n            const uint sign = data_a[ibi].signs[4 * ib32 + l];\n            const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));\n            const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));\n\n            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n                const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);\n                const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);\n\n                sum[j] =\n                      fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x),\n                      fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y),\n                      fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign &   4) != 0 ? 
-grid0.z : grid0.z),\n                      fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w),\n                      fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x),\n                      fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y),\n                      fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z),\n                      fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),\n                      sum[j]))))))));\n            }\n        }\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            temp[j][n] = fma(dscale, sum[j], temp[j][n]);\n        }\n        ibi += num_blocks_per_row;\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 8 threads are used to process each block\n    const uint blocks_per_wg = gl_WorkGroupSize.x/8;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid % 8;  // 0...7\n    const uint ix = tid / 8;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)\n        calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            
return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\n// Adds this thread's contribution from superblock i to temp, for num_rows\n// consecutive rows starting at first_row. 16 threads share one superblock and\n// each of them handles 16 consecutive values (half of a 32-value sub-block).\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    const uint sub = itid / 2;       // 32-value sub-block, 0..7\n    const uint half_sel = itid & 1;  // which 16-value half of that sub-block\n    const uint y_base = i * QUANT_K + 16 * itid;\n\n    uint blk = a_offset + first_row * num_blocks_per_row + i;\n    [[unroll]] for (uint n = 0; n < num_rows; ++n, blk += num_blocks_per_row) {\n        // 32-bit word for this sub-block in the tail of qs (from byte QUANT_K/4):\n        // bits 0..27 are 4 x 7 sign bits, bits 28..31 are the sub-block scale.\n        const uint aux32 = pack32(u16vec2(\n            data_a_packed16[blk].qs[QUANT_K / 8 + 2 * sub],\n            data_a_packed16[blk].qs[QUANT_K / 8 + 2 * sub + 1]));\n        const float dscale = float(data_a[blk].d) * 0.5 * (0.5 + (aux32 >> 28));\n\n        [[unroll]] for (uint l = 0; l < 2; ++l) {\n            const uint q_idx = 8 * sub + 4 * half_sel + 2 * l;\n            const uint gi_lo = data_a[blk].qs[q_idx];\n            const uint gi_hi = data_a[blk].qs[q_idx + 1];\n            const vec4 g_lo = vec4(unpack8(iq3xxs_grid[gi_lo]));\n            const vec4 g_hi = vec4(unpack8(iq3xxs_grid[gi_hi]));\n\n            // 7 stored sign bits for these 8 values; the sign of the 8th value\n            // comes from the parity of those 7 bits.\n            const uint sgn = bitfieldExtract(aux32, 7 * int(2 * half_sel + l), 7);\n            const uint parity = bitCount(sgn);\n\n            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n                const uint b_idx = (j*p.batch_stride_b + b_offset + y_base) / 4 + 2*l;\n                const vec4 b0 = vec4(data_b_v4[b_idx]);\n                const vec4 b4 = vec4(data_b_v4[b_idx + 1]);\n\n                // Accumulated last value first: the same order as a nested fma chain.\n                FLOAT_TYPE sum = FLOAT_TYPE(0.0);\n                sum = fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((parity & 1) != 0 ? -g_hi.w : g_hi.w), sum);\n                sum = fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sgn &  64) != 0 ? -g_hi.z : g_hi.z), sum);\n                sum = fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sgn &  32) != 0 ? -g_hi.y : g_hi.y), sum);\n                sum = fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sgn &  16) != 0 ? -g_hi.x : g_hi.x), sum);\n                sum = fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sgn &   8) != 0 ? -g_lo.w : g_lo.w), sum);\n                sum = fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sgn &   4) != 0 ? -g_lo.z : g_lo.z), sum);\n                sum = fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sgn &   2) != 0 ? -g_lo.y : g_lo.y), sum);\n                sum = fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sgn &   1) != 0 ? -g_lo.x : g_lo.x), sum);\n                temp[j][n] = fma(dscale, sum, temp[j][n]);\n            }\n        }\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // Threads are split into slots of 16; each slot works on one superblock at a\n    // time and advances by the number of slots in the workgroup.\n    const uint tid = gl_LocalInvocationID.x;\n    const uint lane = tid % 16;  // 0...15, position inside the slot\n    const uint slot = tid / 16;\n    const uint slot_count = gl_WorkGroupSize.x / 16;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint r = 0; r < NUM_ROWS; ++r) {\n            temp[j][r] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint blk = slot; blk < num_blocks_per_row; blk += slot_count) {\n        calc_superblock(a_offset, b_offset, lane, blk, num_blocks_per_row, first_row, num_rows);\n    }\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    init_iq_shmem(gl_WorkGroupSize);\n\n    // Normally NUM_ROWS rows are processed; the last workgroup may get fewer,\n    // and workgroups entirely past the end do nothing.\n    if (first_row + NUM_ROWS > p.stride_d) {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n        return;\n    }\n    compute_outputs(first_row, NUM_ROWS);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n\n#define BLOCK_SIZE 32\n#define FLOAT_TYPE float\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\n#include \"mul_mat_vec_iface.glsl\"\n\nlayout (push_constant) uniform parameter\n{\n    uint ncols_x;\n    uint nrows_x;\n    uint row_stride_x;\n    uint channel_stride_x;\n    uint channel_stride_y;\n    uint channel_x_divisor;\n    uint ne12;\n    uint b_offset;\n    uint d_offset;\n    uint nb03;\n    uint nb13;\n    uint nb23;\n    uint fusion_flags;\n} p;\n\nshared FLOAT_TYPE tmp[BLOCK_SIZE];\n\n// Matrix-vector product with a non-contiguous A matrix. Each workgroup produces one\n// dst element: row gl_GlobalInvocationID.y of channel gl_GlobalInvocationID.z in\n// batch gl_WorkGroupID.x. Its BLOCK_SIZE threads accumulate partial dot products\n// over interleaved columns, which are then reduced through the shared tmp array.\nvoid main() {\n    const uint tid       = gl_LocalInvocationID.x;\n    const uint row_x     = gl_GlobalInvocationID.y;\n    const uint channel   = gl_GlobalInvocationID.z;\n    const uint i3        = gl_WorkGroupID.x;\n    const uint channel_x = channel / p.channel_x_divisor;\n    const uint channel_y = channel % p.ne12;\n\n    const uint nrows_dst = p.nrows_x;\n    const uint row_dst   = row_x;\n\n    const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;\n\n    FLOAT_TYPE temp = 0.0f;\n\n    // Detect alignment for vector loads. The vec4 paths index data_a_v4/data_b_v4 with\n    // (element offset / 4), so every term of the A and B offsets must be a multiple\n    // of 4, including the batch strides nb03/nb13 and the B channel stride.\n    const bool is_aligned = (p.ncols_x % 4) == 0 &&\n                            (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0 && (p.nb03 % 4) == 0 &&\n                            (p.channel_stride_y % 4) == 0 && (p.nb13 % 4) == 0;\n\n    for (uint col_x0 = 0; col_x0 < p.ncols_x;) {\n\n        // Unroll 2x and do vec4 loads if aligned\n        const uint unroll_count = 2;\n        if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {\n            [[unroll]] for (uint i = 0; i < unroll_count; ++i) {\n                const uint col_x = col_x0 + 4*tid;\n\n                const uint row_y = col_x;\n\n                const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;\n                const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;\n\n                const vec4 av4 = vec4(data_a_v4[ix / 4]);\n                const vec4 bv4 = vec4(data_b_v4[iy / 4]);\n\n                temp += dot(av4, bv4);\n\n                col_x0 += 4*BLOCK_SIZE;\n            }\n        // do vec4 loads if aligned\n        } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {\n            const uint col_x = col_x0 + 4*tid;\n\n            const uint row_y = col_x;\n\n            const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;\n            const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;\n\n            const vec4 av4 = vec4(data_a_v4[ix / 4]);\n            const vec4 bv4 = vec4(data_b_v4[iy / 4]);\n\n            temp += dot(av4, bv4);\n\n            col_x0 += 4*BLOCK_SIZE;\n        } else {\n            // scalar path: the column tail, or the whole row when unaligned\n            const uint col_x = col_x0 + tid;\n            if (col_x >= p.ncols_x) {\n                break;\n            }\n\n            const uint row_y = col_x;\n\n            const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;\n            const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;\n\n            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);\n\n            temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);\n            col_x0 += BLOCK_SIZE;\n        }\n    }\n\n    tmp[tid] = temp;\n\n    // sum up partial sums and write back result\n    barrier();\n    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            tmp[tid] += tmp[tid + s];\n        }\n        barrier();\n    }\n\n    if (tid == 0) {\n        if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {\n            tmp[0] += FLOAT_TYPE(data_fuse0[idst]);\n        }\n        if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {\n            tmp[0] += FLOAT_TYPE(data_fuse1[idst]);\n        }\n        data_d[idst] = tmp[0];\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n#if USE_SUBGROUP_ADD\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#endif\n\n#define FLOAT_TYPE float\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\n#include \"mul_mat_vec_iface.glsl\"\n\nlayout(constant_id = 0) const int BLOCK_SIZE = 32;\n// gqa_ratio is in the range [1,8]\nlayout(constant_id = 1) const uint gqa_ratio = 1;\n\nlayout (push_constant) uniform parameter\n{\n    uint ncols_x;\n    uint nrows_x;\n    uint nchannels_x;\n    uint nchannels_y;\n    uint b_offset;\n    uint d_offset;\n    uint fusion_flags;\n} p;\n\n#if !USE_SUBGROUP_ADD\nshared FLOAT_TYPE tmp[8][BLOCK_SIZE];\n#endif\n\nvoid main() {\n    const uint tid = gl_LocalInvocationID.x;\n    const uint row_x = gl_GlobalInvocationID.y;\n\n    uint channel, channel_x;\n\n    // When gqa_ratio > 1, each workgroup computes gqa_ratio outputs: the A matrix\n    // channel channel_x is shared by the B/dst channels [channel, channel+gqa_ratio),\n    // where channel = channel_x * gqa_ratio.\n    // When gqa_ratio is 1, each workgroup computes a single output.\n    if (gqa_ratio > 1) {\n        channel_x = gl_GlobalInvocationID.z;\n        channel = channel_x * gqa_ratio;\n    } else {\n        channel = gl_GlobalInvocationID.z;\n        channel_x = channel / (p.nchannels_y / p.nchannels_x);\n    }\n\n    const uint nrows_y = p.ncols_x;\n    const uint nrows_dst = p.nrows_x;\n    const uint row_dst = row_x;\n\n    FLOAT_TYPE temp[8];\n    [[unroll]] for (uint i = 0; i < 8; ++i) {\n        temp[i] = FLOAT_TYPE(0.0f);\n    }\n\n    // Detect alignment for vector loads\n    const bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;\n\n    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {\n\n        // Use vec4 loads if aligned\n        if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {\n\n            const uint col_x = col_x0 + 4*tid;\n            const uint row_y = col_x;\n\n            // x is transposed and permuted\n            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;\n            const vec4 av4 = vec4(data_a_v4[ix / 4]);\n\n            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {\n                // y is not transposed but permuted\n                const uint iy = (channel + c)*nrows_y + row_y;\n\n                const vec4 bv4 = vec4(data_b_v4[iy / 4]);\n                temp[c] += dot(av4, bv4);\n            }\n\n            // This path consumed 4*BLOCK_SIZE columns; the loop increment adds the last BLOCK_SIZE.\n            col_x0 += 3*BLOCK_SIZE;\n        } else {\n            const uint col_x = col_x0 + tid;\n\n            if (col_x >= p.ncols_x) {\n                break;\n            }\n\n            // x is transposed and permuted\n            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;\n            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);\n\n            const uint row_y = col_x;\n\n            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {\n                // y is not transposed but permuted\n                const uint iy = (channel + c)*nrows_y + row_y;\n\n                temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);\n            }\n        }\n    }\n\n#if USE_SUBGROUP_ADD\n    // reduce vec4 at a time\n    vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);\n    t = subgroupAdd(t);\n    temp[0] = t[0];\n    temp[1] = t[1];\n    temp[2] = t[2];\n    temp[3] = t[3];\n    if (gqa_ratio > 4) {\n        t = vec4(temp[4], temp[5], temp[6], temp[7]);\n        t = subgroupAdd(t);\n        temp[4] = t[0];\n        temp[5] = t[1];\n        temp[6] = t[2];\n        temp[7] = t[3];\n    }\n#else\n    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {\n        tmp[c][tid] = temp[c];\n    }\n    // sum up partial sums and write back result\n    barrier();\n    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {\n                temp[c] += tmp[c][tid + s];\n                tmp[c][tid] = temp[c];\n            }\n        }\n        barrier();\n    }\n    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {\n        temp[c] = tmp[c][tid];\n    }\n#endif\n\n    if (tid == 0) {\n        [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {\n            // dst is not transposed and not permuted\n            const uint idst = (channel + c)*nrows_dst + row_dst;\n            if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {\n                temp[c] += FLOAT_TYPE(data_fuse0[idst]);\n            }\n            if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {\n                temp[c] += FLOAT_TYPE(data_fuse1[idst]);\n            }\n            data_d[idst] = temp[c];\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nshared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];\nshared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\nuint csel = 0;\n\n// Adds this thread's contribution from superblock i to temp, for num_rows rows\n// starting at first_row. The 16 threads of slot ix decode the 16 scale bytes into\n// sccache1 (low nibble) and sccache2 (high nibble); csel flips between the two\n// buffers on every row. When all_threads is false, slots with no block i only take\n// part in the barrier and then skip the row.\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {\n    const uint y_idx = i * QUANT_K + y_offset;\n\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;\n        csel ^= 1;\n\n        if (!all_threads) { // when we don't have enough blocks to use all threads\n            if (i < num_blocks_per_row) {\n                const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);\n                sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);\n                sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);\n            }\n            barrier();\n\n            if (i >= num_blocks_per_row)\n                continue;\n        } else {\n            const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);\n            sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);\n            sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);\n            barrier();\n        }\n\n        const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);\n        const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));\n        const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));\n        const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));\n        const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));\n\n        // Use the FLOAT_TYPE_VEC2 constructor: a vec2 initializer would need an\n        // implicit narrowing conversion whenever FLOAT_TYPE is not float.\n        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);\n\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);\n            vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]);\n            vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);\n            vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);\n            vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);\n            vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);\n            vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);\n            vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);\n\n            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);\n            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);\n            [[unroll]] for (int l = 0; l < 2; ++l) {\n                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[csel][ix][    8*v_im] * qs_u32_0[l  ],\n                       fma(FLOAT_TYPE(b16[l]),  sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],\n                       fma(FLOAT_TYPE(b32[l]),  sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l  ],\n                       fma(FLOAT_TYPE(b48[l]),  sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],\n                       fma(FLOAT_TYPE(b64[l]),  sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l  ],\n                       fma(FLOAT_TYPE(b80[l]),  sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],\n                       fma(FLOAT_TYPE(b96[l]),  sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l  ],\n                       fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));\n                sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[csel][ix][    8*v_im],\n                       fma(FLOAT_TYPE(b16[l]),  sccache2[csel][ix][1 + 8*v_im],\n                       fma(FLOAT_TYPE(b32[l]),  sccache2[csel][ix][2 + 8*v_im],\n                       fma(FLOAT_TYPE(b48[l]),  sccache2[csel][ix][3 + 8*v_im],\n                       fma(FLOAT_TYPE(b64[l]),  sccache2[csel][ix][4 + 8*v_im],\n                       fma(FLOAT_TYPE(b80[l]),  sccache2[csel][ix][5 + 8*v_im],\n                       fma(FLOAT_TYPE(b96[l]),  sccache2[csel][ix][6 + 8*v_im],\n                       fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));\n            }\n            // dm.x weights the quant sum (sum1), dm.y weights the sum of high-nibble terms (sum2).\n            temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));\n        }\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 16 threads are used to process each block\n    const uint it_size = gl_WorkGroupSize.x/16;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid%16;  // 0...15\n    const uint ix = tid/16;\n\n    const uint v_im = itid/8;                                // 0 or 1. 0 computes 0..., 1 computes 128...\n    const uint v_in = itid - 8*v_im;                         // 0...7\n\n    const uint l0 = 2*v_in;                                  // 0...14\n    const uint q_offset = 32*v_im + l0;\n    const uint y_offset = 128*v_im + l0;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    const uint nbr_par_th = num_blocks_per_row%it_size;\n    const uint nbr_all_th = num_blocks_per_row - nbr_par_th;\n    uint i0 = 0;\n    [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)\n        calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);\n    calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp",
    "content": "#version 450\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\n// Decoded sub-block scales: [csel buffer][slot][v_im][scale index].\nshared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\nuint csel = 0;\n\n// Rebuilds one sub-block scale of block blk: 4 low bits from scales[k] (shifted by\n// lo_shift), 2 high bits from scales[k%4+8] (shifted by hi_shift), minus 32.\nFLOAT_TYPE mmv_q3k_decode_scale(const uint blk, const uint k, const uint lo_shift, const uint hi_shift) {\n    return FLOAT_TYPE(int8_t(((data_a[blk].scales[k] >> lo_shift) & 0xF) | (((data_a[blk].scales[k%4+8] >> hi_shift) & 3) << 4)) - 32);\n}\n\n// Adds this thread's contribution from superblock i to temp, for num_rows rows\n// starting at first_row. Scales are cached in sccache, alternating buffers via csel.\n// When all_threads is false, slots without a block i only join the barrier.\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {\n    const uint y_idx = i * QUANT_K + y_offset;\n    const bool has_block = i < num_blocks_per_row;\n\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const uint blk = a_offset + (first_row+n)*num_blocks_per_row + i;\n        csel ^= 1;\n\n        if (!all_threads) {\n            if (has_block) {\n                sccache[csel][ix][v_im][itid8] = mmv_q3k_decode_scale(blk, itid8, v_im4, s_shift);\n            }\n            barrier();\n            if (!has_block) {\n                continue;\n            }\n        }\n\n        // hsubK is 4 where the matching hmask bit is clear and 0 where it is set.\n        const uint32_t hmk = ~(uint32_t(data_a_packed16[blk].hmask[v_in]) | (uint32_t(data_a_packed16[blk].hmask[v_in + 8]) << 16));\n        const vec4 hsub0 = vec4(unpack8(((hmk & hm_m[0]) >> (    v_im4)) << 2));\n        const vec4 hsub1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));\n        const vec4 hsub2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));\n        const vec4 hsub3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));\n\n        // Quant bytes q_offset + {0, 1, 16, 17}, one per byte lane of qs_u32.\n        const uint32_t qs_lo = uint32_t(data_a[blk].qs[q_offset]) | (uint32_t(data_a[blk].qs[q_offset + 1]) << 8);\n        const uint32_t qs_hi = uint32_t(data_a[blk].qs[q_offset + 16]) | (uint32_t(data_a[blk].qs[q_offset + 17]) << 8);\n        const uint32_t qs_u32 = qs_lo | (qs_hi << 16);\n        const vec4 ql0 = vec4(unpack8( qs_u32       & 0x03030303));\n        const vec4 ql1 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));\n        const vec4 ql2 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));\n        const vec4 ql3 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));\n\n        // In a full round the scale is cached after the quant loads have been issued.\n        if (all_threads) {\n            sccache[csel][ix][v_im][itid8] = mmv_q3k_decode_scale(blk, itid8, v_im4, s_shift);\n            barrier();\n        }\n\n        const FLOAT_TYPE d = FLOAT_TYPE(data_a[blk].d);\n\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            const uint b_base = (j*p.batch_stride_b + b_offset + y_idx) / 2;\n            const vec2 b0   = vec2(data_b_v2[b_base     ]);\n            const vec2 b16  = vec2(data_b_v2[b_base +  8]);\n            const vec2 b32  = vec2(data_b_v2[b_base + 16]);\n            const vec2 b48  = vec2(data_b_v2[b_base + 24]);\n            const vec2 b64  = vec2(data_b_v2[b_base + 32]);\n            const vec2 b80  = vec2(data_b_v2[b_base + 40]);\n            const vec2 b96  = vec2(data_b_v2[b_base + 48]);\n            const vec2 b112 = vec2(data_b_v2[b_base + 56]);\n\n            FLOAT_TYPE sum = FLOAT_TYPE(0.0);\n            [[unroll]] for (int l = 0; l < 2; ++l) {\n                // Innermost term first: the same order as a nested fma chain.\n                sum = fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], ql3[l+2] - hsub3[l+2], sum);\n                sum = fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], ql3[l  ] - hsub3[l  ], sum);\n                sum = fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], ql2[l+2] - hsub2[l+2], sum);\n                sum = fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], ql2[l  ] - hsub2[l  ], sum);\n                sum = fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], ql1[l+2] - hsub1[l+2], sum);\n                sum = fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], ql1[l  ] - hsub1[l  ], sum);\n                sum = fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], ql0[l+2] - hsub0[l+2], sum);\n                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[csel][ix][v_im][0], ql0[l  ] - hsub0[l  ], sum);\n            }\n            temp[j][n] = fma(d, sum, temp[j][n]);\n        }\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // Threads work in slots of 16 per superblock.\n    const uint tid = gl_LocalInvocationID.x;\n    const uint slot_count = gl_WorkGroupSize.x/16;\n    const uint ix = tid/16;            // slot index\n    const uint itid = tid%16;          // 0...15 within the slot\n    const uint itid8 = itid%8;\n\n    const uint v_im = itid/8;          // 0 or 1: 0 computes values 0..., 1 computes 128...\n    const uint v_im4 = v_im*4;\n    const uint v_in = itid - 8*v_im;   // 0...7\n\n    // Per-byte masks selecting hmask bit (4*v_im + k) for k = 0..3.\n    const uint32_t hbit = 0x01010101 << (4 * v_im);\n    const uint32_t hm_m[4] = uint32_t[4](hbit, hbit << 1, hbit << 2, hbit << 3);\n\n    const uint l0 = 2*v_in;            // 0, 2, ..., 14\n    const uint q_offset = 32*v_im + l0;\n    const uint y_offset = 128*v_im + l0;\n    const uint s_shift = v_im4 + 2*(itid8/4);\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint r = 0; r < NUM_ROWS; ++r) {\n            temp[j][r] = FLOAT_TYPE(0);\n        }\n    }\n\n    // Rounds in which every slot has a block skip the bounds checks; the last call\n    // handles the remainder (possibly empty) with per-slot checks.\n    const uint full_rounds_end = num_blocks_per_row - num_blocks_per_row%slot_count;\n    uint i0 = 0;\n    [[unroll]] for (; i0 < full_rounds_end; i0 += slot_count) {\n        calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true);\n    }\n    calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    // Normally NUM_ROWS rows are processed; the last workgroup may get fewer,\n    // and workgroups entirely past the end do nothing.\n    if (first_row + NUM_ROWS > p.stride_d) {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n        return;\n    }\n    compute_outputs(first_row, NUM_ROWS);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\n// a.x*b.x + a.y*b.y + a.z*b.z + a.w*b.w as a nested fma chain whose innermost\n// term is the plain product a.w*b.w.\nFLOAT_TYPE mmv_q4k_fma_dot4(const vec4 a, const vec4 b) {\n    return fma(FLOAT_TYPE(a.x), FLOAT_TYPE(b.x), fma(FLOAT_TYPE(a.y), FLOAT_TYPE(b.y), fma(FLOAT_TYPE(a.z), FLOAT_TYPE(b.z), FLOAT_TYPE(a.w) * FLOAT_TYPE(b.w))));\n}\n\n// Adds this thread's contribution from superblock i to temp, for num_rows rows\n// starting at first_row. The thread reads quant bytes q_offset..q_offset+3 and\n// q_offset+64..q_offset+67, and B values at y_offset + {0, 32, 128, 160}.\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    const uint y_lo = i * QUANT_K + y_offset;\n    const uint y_hi = y_lo + 128;\n\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const uint blk = a_offset + (first_row+n)*num_blocks_per_row + i;\n        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[blk].dm);\n\n        // Unpack the four 6-bit scales (s*) and four 6-bit mins (m*) this thread needs.\n        const uint32_t sc_w0 = data_a_packed16[blk].scales[v_im    ];\n        const uint32_t sc_w4 = data_a_packed16[blk].scales[v_im + 2];\n        const uint32_t sc_w8 = data_a_packed16[blk].scales[v_im + 4];\n\n        const uint32_t sc_lo = (sc_w4 << 16) | sc_w0;\n        const uint32_t sc_hi2 = (sc_lo & 0xC0C0C0C0) >> 2;\n        const vec4 sc_a = vec4(unpack8(sc_lo & 0x3F3F3F3F));\n        const vec4 sc_b = vec4(unpack8((((sc_w8 << 12) | sc_w8) & 0x0F0F0F0F) | sc_hi2));\n\n        const FLOAT_TYPE s0 = sc_a.x;\n        const FLOAT_TYPE s1 = sc_a.y;\n        const FLOAT_TYPE m0 = sc_a.z;\n        const FLOAT_TYPE m1 = sc_a.w;\n        const FLOAT_TYPE s2 = sc_b.x;\n        const FLOAT_TYPE s3 = sc_b.y;\n        const FLOAT_TYPE m2 = sc_b.z;\n        const FLOAT_TYPE m3 = sc_b.w;\n\n        // 4-bit quants: low and high nibbles of the two 32-bit words.\n        const uint32_t qs_a = data_a_packed32[blk].qs[q_offset / 4];\n        const uint32_t qs_b = data_a_packed32[blk].qs[q_offset / 4 + 16];\n        const vec4 qa_lo = vec4(unpack8( qs_a       & 0x0F0F0F0F));\n        const vec4 qa_hi = vec4(unpack8((qs_a >> 4) & 0x0F0F0F0F));\n        const vec4 qb_lo = vec4(unpack8( qs_b       & 0x0F0F0F0F));\n        const vec4 qb_hi = vec4(unpack8((qs_b >> 4) & 0x0F0F0F0F));\n\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            const uint col_base = j*p.batch_stride_b + b_offset;\n            const vec4 ya0  = vec4(data_b_v4[(col_base + y_lo) / 4    ]);\n            const vec4 ya32 = vec4(data_b_v4[(col_base + y_lo) / 4 + 8]);\n            const vec4 yb0  = vec4(data_b_v4[(col_base + y_hi) / 4    ]);\n            const vec4 yb32 = vec4(data_b_v4[(col_base + y_hi) / 4 + 8]);\n\n            const FLOAT_TYPE sx = mmv_q4k_fma_dot4(ya0,  qa_lo);\n            const FLOAT_TYPE sy = mmv_q4k_fma_dot4(ya32, qa_hi);\n            const FLOAT_TYPE sz = mmv_q4k_fma_dot4(yb0,  qb_lo);\n            const FLOAT_TYPE sw = mmv_q4k_fma_dot4(yb32, qb_hi);\n\n            // B values weighted by their mins, accumulated innermost term first.\n            FLOAT_TYPE smin = FLOAT_TYPE(yb32.w) * m3;\n            smin = fma(FLOAT_TYPE(yb0.w),  m2, smin);\n            smin = fma(FLOAT_TYPE(ya32.w), m1, smin);\n            smin = fma(FLOAT_TYPE(ya0.w),  m0, smin);\n            [[unroll]] for (int k = 2; k >= 0; --k) {\n                smin = fma(FLOAT_TYPE(yb32[k]), m3, smin);\n                smin = fma(FLOAT_TYPE(yb0[k]),  m2, smin);\n                smin = fma(FLOAT_TYPE(ya32[k]), m1, smin);\n                smin = fma(FLOAT_TYPE(ya0[k]),  m0, smin);\n            }\n\n            const FLOAT_TYPE scaled = fma(sx, s0, fma(sy, s1, fma(sz, s2, sw * s3)));\n            temp[j][n] = fma(dm.x, scaled, fma(-dm.y, smin, temp[j][n]));\n        }\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // Threads work in slots of 16 per superblock; slots advance by slot_count blocks.\n    const uint tid = gl_LocalInvocationID.x;\n    const uint slot_count = gl_WorkGroupSize.x/16;\n    const uint slot = tid/16;\n    const uint itid = tid%16;          // 0...15\n\n    const uint il = itid/4;            // 0...3\n    const uint ir = itid%4;            // 0...3\n\n    const uint v_im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224\n    const uint v_in = il%2;\n\n    const uint l0 = 8*ir + 4*v_in;     // 0, 4, ..., 28\n    const uint q_offset = 32*v_im + l0;\n    const uint y_offset = 64*v_im + l0;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint r = 0; r < NUM_ROWS; ++r) {\n            temp[j][r] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint blk = slot; blk < num_blocks_per_row; blk += slot_count) {\n        calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, blk, num_blocks_per_row, first_row, num_rows);\n    }\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    // Normally NUM_ROWS rows are processed; the last workgroup may get fewer,\n    // and workgroups entirely past the end do nothing.\n    if (first_row + NUM_ROWS > p.stride_d) {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n        return;\n    }\n    compute_outputs(first_row, NUM_ROWS);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {\n    const uint y1_idx = i * QUANT_K + y_offset;\n    const uint y2_idx = y1_idx + 128;\n\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;\n        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);\n\n        const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];\n        const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];\n        const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];\n\n        const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;\n        const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;\n        const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));\n        const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));\n\n        const FLOAT_TYPE sc0 = scale_0_4_l_f.x;\n        const FLOAT_TYPE sc1 = scale_0_4_l_f.y;\n        const FLOAT_TYPE sc2 = scale_0_4_l_f.z;\n        const FLOAT_TYPE sc3 = scale_0_4_l_f.w;\n        const FLOAT_TYPE sc4 = scale8_f.x;\n        const FLOAT_TYPE sc5 = scale8_f.y;\n        const FLOAT_TYPE sc6 = scale8_f.z;\n        const FLOAT_TYPE sc7 = scale8_f.w;\n\n        const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);\n        const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + 
i].qs[q_offset / 2 + 40]) << 16);\n\n        uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;\n        uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;\n        uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;\n        uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;\n\n        const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));\n\n        const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;\n        const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;\n        const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);\n        const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;\n\n        qs0_16_u32_lo4 += qs0_16_lo4_offset16;\n        qs0_16_u32_hi4 += qs0_16_hi4_offset16;\n        qs64_80_u32_lo4 += qs64_80_lo4_offset16;\n        qs64_80_u32_hi4 += qs64_80_hi4_offset16;\n\n        const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));\n        const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));\n        const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));\n        const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));\n\n        const FLOAT_TYPE q4_0  = qs0_16_lo4.x;\n        const FLOAT_TYPE q4_1  = qs0_16_lo4.y;\n        const FLOAT_TYPE q4_2  = qs0_16_lo4.z;\n        const FLOAT_TYPE q4_3  = qs0_16_lo4.w;\n        const FLOAT_TYPE q4_4  = qs0_16_hi4.x;\n        const FLOAT_TYPE q4_5  = qs0_16_hi4.y;\n        const FLOAT_TYPE q4_6  = qs0_16_hi4.z;\n        const FLOAT_TYPE q4_7  = qs0_16_hi4.w;\n        const FLOAT_TYPE q4_8  = qs64_80_lo4.x;\n        const FLOAT_TYPE q4_9  = qs64_80_lo4.y;\n        const FLOAT_TYPE q4_10 = qs64_80_lo4.z;\n        const FLOAT_TYPE q4_11 = qs64_80_lo4.w;\n        const FLOAT_TYPE q4_12 = qs64_80_hi4.x;\n        const FLOAT_TYPE q4_13 = qs64_80_hi4.y;\n        const FLOAT_TYPE q4_14 = qs64_80_hi4.z;\n        const FLOAT_TYPE q4_15 = qs64_80_hi4.w;\n\n    
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            vec2 by10 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2     ]);\n            vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 +  8]);\n            vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);\n            vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);\n            vec2 by20 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2     ]);\n            vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 +  8]);\n            vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);\n            vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);\n\n            const FLOAT_TYPE sx =\n              fma(FLOAT_TYPE(by10.x), q4_0,\n              fma(FLOAT_TYPE(by10.y), q4_1,\n              fma(FLOAT_TYPE(by116.x), q4_2,\n                 FLOAT_TYPE(by116.y) * q4_3)));\n            const FLOAT_TYPE sy =\n              fma(FLOAT_TYPE(by132.x), q4_4,\n              fma(FLOAT_TYPE(by132.y), q4_5,\n              fma(FLOAT_TYPE(by148.x), q4_6,\n                 FLOAT_TYPE(by148.y) * q4_7)));\n            const FLOAT_TYPE sz =\n              fma(FLOAT_TYPE(by20.x), q4_8,\n              fma(FLOAT_TYPE(by20.y), q4_9,\n              fma(FLOAT_TYPE(by216.x), q4_10,\n                 FLOAT_TYPE(by216.y) * q4_11)));\n            const FLOAT_TYPE sw =\n              fma(FLOAT_TYPE(by232.x), q4_12,\n              fma(FLOAT_TYPE(by232.y), q4_13,\n              fma(FLOAT_TYPE(by248.x), q4_14,\n                 FLOAT_TYPE(by248.y) * q4_15)));\n            const FLOAT_TYPE smin =\n              fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,\n              fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,\n              fma(FLOAT_TYPE(by20.x) + 
FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,\n                  (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));\n            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));\n        }\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 16 threads are used to process each block\n    const uint it_size = gl_WorkGroupSize.x/16;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid%16;  // 0...15\n    const uint ix = tid/16;\n\n    const uint il = itid/4;                          // 0...3\n    const uint ir = itid - 4*il;                     // 0...3\n\n    const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224\n    const uint v_in = il % 2;\n\n    const uint l0 = 4*ir + 2*v_in;                   // 0...15\n    const uint q_offset = 32*v_im + l0;\n    const uint y_offset = 64*v_im + l0;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)\n        calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        
compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nshared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];\n\nFLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\nuint csel = 0;\n\nvoid calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {\n    const uint y_idx = i * QUANT_K + y_offset;\n\n    [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;\n        csel ^= 1;\n\n        if (!all_threads) { // when we don't have enough blocks to use all threads\n            if (i < num_blocks_per_row)\n                sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);\n            barrier();\n\n            if (i >= num_blocks_per_row)\n                continue;\n        }\n\n        const uint32_t ql0_u32 =  uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);\n        const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);\n\n        const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;\n        const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;\n        const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;\n        const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;\n\n        const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);\n        const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;\n        const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;\n        
const uint32_t qh4_u32 = (qh_u32 & 0x30303030);\n        const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;\n\n        const uint32_t q0_u32 = ql0_u32_lo4  | qh0_u32;\n        const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;\n        const uint32_t q2_u32 = ql0_u32_hi4  | qh4_u32;\n        const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;\n\n        const vec4 q0 = vec4(unpack8(q0_u32)) - 32;\n        const vec4 q1 = vec4(unpack8(q1_u32)) - 32;\n        const vec4 q2 = vec4(unpack8(q2_u32)) - 32;\n        const vec4 q3 = vec4(unpack8(q3_u32)) - 32;\n\n        if (all_threads) {\n            sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);\n            barrier();\n        }\n\n        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);\n\n        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n            vec4 by0  = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4     ]);\n            vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 +  8]);\n            vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);\n            vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);\n\n            FLOAT_TYPE sum[4] = {0, 0, 0, 0};\n            [[unroll]] for (uint l = 0; l < 4; ++l) {\n                sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);\n                sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);\n                sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);\n                sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);\n            }\n            temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);\n        }\n    }\n}\n\nvoid compute_outputs(const uint first_row, const uint num_rows) {\n    uint a_offset, b_offset, d_offset;\n    get_offsets(a_offset, b_offset, d_offset);\n\n    const uint 
num_blocks_per_row = p.ncols / QUANT_K;\n\n    // 16 threads are used to process each block\n    const uint it_size = gl_WorkGroupSize.x/16;\n    const uint tid = gl_LocalInvocationID.x;\n    const uint itid = tid%16;  // 0...15\n    const uint ix = tid/16;\n\n    const uint v_im = itid/8;                               // 0 or 1. 0 computes 0..., 1 computes 128...\n    const uint v_in = itid - 8*v_im;                        // 0...7\n\n    const uint l0 = 4 * v_in;                               // 0, 4, 8, ..., 28\n    const uint is = v_in / 4;\n\n    const uint ql_offset = 64*v_im + l0;\n    const uint qh_offset = 32*v_im + l0;\n    const uint s_offset  =  8*v_im + is;\n    const uint y_offset = 128*v_im + l0;\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {\n            temp[j][i] = FLOAT_TYPE(0);\n        }\n    }\n\n    const uint nbr_par_th = num_blocks_per_row%it_size;\n    const uint nbr_all_th = num_blocks_per_row - nbr_par_th;\n    uint i0 = 0;\n    [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)\n        calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);\n    calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n#extension GL_EXT_integer_dot_product : require\n\n#define MMQ\n#define B_TYPE block_q8_1_x4\n\n#include \"mul_mat_vec_base.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\n#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)\n#define K_PER_ITER 8\n#elif defined(DATA_A_QUANT_K)\n#define K_PER_ITER 16\n#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)\n#define K_PER_ITER 32\n#else\n#error unimplemented\n#endif\n\nuint a_offset, b_offset, d_offset;\n\nint32_t cache_b_qs[K_PER_ITER / 4];\nvec2 cache_b_ds;\n\n#include \"mul_mat_vecq_funcs.glsl\"\n\nvoid iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;\n\n        // Preload data_b block\n        const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;\n        const uint b_qs_idx = tid % (32 / K_PER_ITER);\n        const uint b_block_idx_outer = b_block_idx / 4;\n        const uint b_block_idx_inner = b_block_idx % 4;\n        cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);\n\n#if QUANT_R == 2\n        // Assumes K_PER_ITER == 8\n        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];\n        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];\n#else\n#if K_PER_ITER == 8\n        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];\n        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];\n#elif K_PER_ITER == 16\n        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4    ];\n        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];\n        cache_b_qs[2] = 
data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];\n        cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];\n#elif K_PER_ITER == 32\n        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8    ];\n        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1];\n        cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2];\n        cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3];\n        cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4];\n        cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5];\n        cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6];\n        cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7];\n#else\n#error unimplemented\n#endif\n#endif\n\n        uint ibi = first_row*p.ncols;\n        [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n            const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;\n            ibi += p.ncols;\n\n            temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);\n        }\n    }\n}\n\nvoid compute_outputs(const uint32_t first_row, const uint32_t num_rows) {\n    const uint tid = gl_LocalInvocationID.x;\n\n    get_offsets(a_offset, b_offset, d_offset);\n    a_offset *= QUANT_K / QUANT_K_Q8_1;\n    b_offset /= QUANT_K_Q8_1;\n\n    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];\n\n    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {\n        [[unroll]] for (uint n = 0; n < num_rows; ++n) {\n            temp[j][n] = FLOAT_TYPE(0.0f);\n        }\n    }\n\n    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);\n    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {\n        num_iters++;\n    }\n    int unroll_count = 4;\n    uint unrolled_iters = num_iters & ~(unroll_count - 1);\n\n    uint i = 0;\n    while (i < unrolled_iters) {\n        // Manually partially unroll the 
loop\n        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {\n            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);\n            i++;\n        }\n    }\n\n    unroll_count = 2;\n    unrolled_iters = num_iters & ~(unroll_count - 1);\n\n    while (i < unrolled_iters) {\n        // Manually partially unroll the loop\n        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {\n            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);\n            i++;\n        }\n    }\n    while (i < num_iters) {\n        iter(temp, first_row, num_rows, tid, i*K_PER_ITER);\n        i++;\n    }\n\n    reduce_result(temp, d_offset, first_row, num_rows, tid);\n}\n\nvoid main() {\n    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);\n\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n    // do NUM_ROWS at a time, unless there aren't enough remaining rows\n    if (first_row + NUM_ROWS <= p.stride_d) {\n        compute_outputs(first_row, NUM_ROWS);\n    } else {\n        if (first_row >= p.stride_d) {\n            return;\n        }\n        compute_outputs(first_row, p.stride_d - first_row);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl",
    "content": "#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n\n#include \"types.glsl\"\n\n#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)\nFLOAT_TYPE get_dm(uint ib) {\n    return FLOAT_TYPE(data_a[ib].d);\n}\n#endif\n\n#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)\nFLOAT_TYPE_VEC2 get_dm(uint ib) {\n    return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);\n}\n#endif\n\n#if defined(DATA_A_MXFP4)\nFLOAT_TYPE get_dm(uint ib) {\n    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));\n}\n#endif\n\n#if defined(DATA_A_Q2_K)\nFLOAT_TYPE_VEC2 get_dm(uint ib) {\n    const uint ib_k = ib / 8;\n    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);\n}\n#endif\n\n// Each iqs value maps to a 32-bit integer\n#if defined(DATA_A_Q4_0)\n// 2-byte loads for Q4_0 blocks (18 bytes)\ni32vec2 repack(uint ib, uint iqs) {\n    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],\n                                   data_a_packed16[ib].qs[iqs * 2 + 1]);\n    const uint32_t vui = pack32(quants);\n    return i32vec2( vui       & 0x0F0F0F0F,\n                   (vui >> 4) & 0x0F0F0F0F);\n}\n\nFLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {\n    return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));\n}\n#endif\n\n#if defined(DATA_A_Q4_1)\n// 4-byte loads for Q4_1 blocks (20 bytes)\ni32vec2 repack(uint ib, uint iqs) {\n    const uint32_t vui = data_a_packed32[ib].qs[iqs];\n    return i32vec2( vui       & 0x0F0F0F0F,\n                   (vui >> 4) & 0x0F0F0F0F);\n}\n\nFLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const 
vec2 dsb, const int32_t sum_divisor) {\n    return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);\n}\n#endif\n\n#if defined(DATA_A_Q5_0)\n// 2-byte loads for Q5_0 blocks (22 bytes)\ni32vec2 repack(uint ib, uint iqs) {\n    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],\n                                   data_a_packed16[ib].qs[iqs * 2 + 1]);\n    const uint32_t vui = pack32(quants);\n    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));\n    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)\n                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)\n\n    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)\n                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)\n\n    return i32vec2(v0, v1);\n}\n\nFLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {\n    return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));\n}\n#endif\n\n#if defined(DATA_A_Q5_1)\n// 4-byte loads for Q5_1 blocks (24 bytes)\ni32vec2 repack(uint ib, uint iqs) {\n    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],\n                                   data_a_packed16[ib].qs[iqs * 2 + 1]);\n    const uint32_t vui = pack32(quants);\n    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));\n    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)\n                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)\n\n    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)\n                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)\n\n    return i32vec2(v0, v1);\n}\n\nFLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {\n    return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / 
sum_divisor);\n}\n#endif\n\n#if defined(DATA_A_Q8_0)\n// 2-byte loads for Q8_0 blocks (34 bytes)\nint32_t repack(uint ib, uint iqs) {\n    return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2    ],\n                          data_a_packed16[ib].qs[iqs * 2 + 1]));\n}\n\nFLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {\n    return FLOAT_TYPE(float(q_sum) * da * dsb.x);\n}\n#endif\n\n#if defined(DATA_A_MXFP4)\n// 1-byte loads for mxfp4 blocks (17 bytes)\ni32vec2 repack(uint ib, uint iqs) {\n    const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],\n                                      data_a[ib].qs[iqs * 4 + 1],\n                                      data_a[ib].qs[iqs * 4 + 2],\n                                      data_a[ib].qs[iqs * 4 + 3]));\n\n    const u8vec4 i_a0 = unpack8( qs       & 0x0F0F0F0F);\n    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);\n\n    return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),\n                   pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));\n}\n\nFLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {\n    return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);\n}\n#endif\n\n#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)\nFLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {\n    int32_t q_sum = 0;\n#if QUANT_R == 2\n    const i32vec2 data_a_qs = repack(ib_a, iqs);\n    q_sum += dotPacked4x8EXT(data_a_qs.x,\n                             cache_b_qs[0]);\n    q_sum += dotPacked4x8EXT(data_a_qs.y,\n                             cache_b_qs[1]);\n#else\n    int32_t data_a_qs = repack(ib_a, iqs * 2);\n    q_sum += dotPacked4x8EXT(data_a_qs,\n                             cache_b_qs[0]);\n    data_a_qs = repack(ib_a, iqs * 2 + 1);\n    q_sum += dotPacked4x8EXT(data_a_qs,\n         
                    cache_b_qs[1]);\n#endif\n\n    // 2 quants per call => divide sums by 8/2 = 4\n    return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);\n}\n#endif\n\n#if defined(DATA_A_Q2_K)\n// 4-byte loads for Q2_K blocks (84 bytes)\ni32vec4 repack4(uint ib, uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n\n    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);\n    const uint qs_shift = ((iqs_k % 32) / 8) * 2;\n\n    return i32vec4((data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303,\n                   (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303,\n                   (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303,\n                   (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303);\n}\n\nuint8_t get_scale(uint ib, uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n\n    return data_a[ib_k].scales[iqs_k / 4];\n}\n\nFLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {\n    int32_t sum_d = 0;\n    int32_t sum_m = 0;\n\n    const i32vec4 qs_a = repack4(ib_a, iqs * 4);\n    const uint8_t scale = get_scale(ib_a, iqs * 4);\n    const vec2 dm = vec2(get_dm(ib_a));\n    const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.\n\n    sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);\n    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);\n\n    sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);\n    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);\n\n    sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF);\n    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]);\n\n    sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF);\n    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]);\n\n    return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));\n}\n#endif\n\n#if defined(DATA_A_Q3_K)\n// 
2-byte loads for Q3_K blocks (110 bytes)\ni32vec4 repack4(uint ib, uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n\n    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);\n    const uint qs_shift = ((iqs_k % 32) / 8) * 2;\n    const uint hm_shift = iqs_k / 8;\n\n    // bitwise OR to add 4 if hmask is set, subtract later\n    const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2    ] >> qs_shift) & uint16_t(0x0303))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2    ] >> hm_shift) & uint16_t(0x0101)) << 2));\n    const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));\n    const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));\n    const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));\n    const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));\n    const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));\n    const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |\n                   
       unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));\n    const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx  * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));\n\n    return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),\n                   pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),\n                   pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),\n                   pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));\n}\n\nfloat get_d_scale(uint ib, uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n    const uint is = iqs_k / 4;\n\n    const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8      ] >> (4 * (is / 8))) & 0x0F0F) |\n                               (((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4));\n    return float(data_a[ib_k].d) * float(scale - 32);\n}\n\nFLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {\n    int32_t q_sum = 0;\n\n    const i32vec4 qs_a = repack4(ib_a, iqs * 4);\n    const float d_scale = get_d_scale(ib_a, iqs * 4);\n\n    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);\n    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);\n    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);\n    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);\n\n    return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum));\n}\n#endif\n\n#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)\n// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)\ni32vec4 repack4(uint ib, uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n\n    const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);\n    const uint qs_shift = ((iqs_k % 16) / 8) * 4;\n\n#if 
defined(DATA_A_Q4_K)\n    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F;\n    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;\n    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F;\n    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F;\n\n    return i32vec4(vals0, vals1, vals2, vals3);\n#else // defined(DATA_A_Q5_K)\n    const uint qh_idx = iqs;\n    const uint qh_shift = iqs_k / 8;\n\n    return i32vec4(((data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F) |\n                  (((data_a_packed32[ib_k].qh[qh_idx    ] >> qh_shift) & 0x01010101) << 4),\n                   ((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |\n                  (((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4),\n                   ((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) |\n                  (((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4),\n                   ((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) |\n                  (((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4));\n#endif\n}\n\nvec2 get_dm_scale(uint ib, uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n    const uint is = iqs_k / 8;\n    u8vec2 scale_dm;\n    if (is < 4) {\n        scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);\n    } else {\n        scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),\n                          (data_a[ib_k].scales[is+4] >>  4) | ((data_a[ib_k].scales[is  ] & 0xC0) >> 2));\n    }\n\n    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);\n}\n\nFLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {\n    int32_t q_sum = 0;\n\n    const i32vec4 
qs_a = repack4(ib_a, iqs * 4);\n    const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4);\n\n    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);\n    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);\n    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);\n    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);\n\n    return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2));\n}\n#endif\n\n#if defined(DATA_A_Q6_K)\n// 2-byte loads for Q6_K blocks (210 bytes)\ni32vec4 repack4(uint ib, uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n\n    const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;\n    const uint ql_shift = ((iqs_k % 32) / 16) * 4;\n\n    const uint qh_idx = (iqs_k / 32) * 8 + iqs;\n    const uint qh_shift = ((iqs_k % 32) / 8) * 2;\n\n    const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2    ] >> ql_shift) & uint16_t(0x0F0F))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2    ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);\n    const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);\n    const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);\n    const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);\n    const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & 
uint16_t(0x0F0F))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);\n    const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);\n    const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);\n    const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |\n                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);\n\n    return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),\n                   pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),\n                   pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),\n                   pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));\n}\n\nfloat get_d_scale(uint ib, uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n    return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]);\n}\n\nFLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {\n    int32_t q_sum = 0;\n\n    const i32vec4 qs_a = repack4(ib_a, iqs * 4);\n    const float d_scale = get_d_scale(ib_a, iqs * 4);\n\n    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);\n    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);\n    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);\n    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);\n\n    return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));\n}\n#endif\n\n#if defined(DATA_A_IQ1_S)\nvoid 
repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) {\n    const uint ib32 = iqs / 32;\n\n    const uint qh = data_a[ib].qh[ib32];\n\n    const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2];\n    const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2];\n\n    const uint qs0 = qs16_0 & 0xFF;\n    const uint qs1 = qs16_0 >> 8;\n    const uint qs2 = qs16_1 & 0xFF;\n    const uint qs3 = qs16_1 >> 8;\n\n    const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3);\n    const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3);\n    const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3);\n    const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3);\n\n    const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]);\n    const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]);\n    const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]);\n    const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]);\n\n    out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F,\n                   (grid0 >> 4) & 0x0F0F0F0F,\n                   (grid1 >> 0) & 0x0F0F0F0F,\n                   (grid1 >> 4) & 0x0F0F0F0F);\n    out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F,\n                   (grid2 >> 4) & 0x0F0F0F0F,\n                   (grid3 >> 0) & 0x0F0F0F0F,\n                   (grid3 >> 4) & 0x0F0F0F0F);\n}\n\nvec2 get_dm(uint ib, uint iqs) {\n    const uint ib32 = iqs / 32;\n\n    const uint qh = data_a[ib].qh[ib32];\n    const float delta = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA;\n\n    const float d = float(data_a[ib].d);\n    const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);\n\n    // the -1 cancels out the bias in iq1s_grid_gpu\n    return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));\n}\n\nFLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {\n    int32_t q_sum = 0;\n\n    const uint ib_k = ib_a / 8;\n    const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;\n\n    i32vec4 qs_a0;\n    i32vec4 qs_a1;\n    repack8(ib_k, iqs_k, qs_a0, qs_a1);\n\n    const vec2 dm = get_dm(ib_k, iqs_k);\n\n    q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]);\n    q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]);\n    q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]);\n    q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]);\n    q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]);\n    q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]);\n    q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]);\n    q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]);\n\n    return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y));\n}\n#endif\n\n#if defined(DATA_A_IQ1_M)\nFLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {\n    const uint ib_k = ib_a / 8;\n    const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;\n\n    const uint ib32 = iqs_k / 32;\n    const uint ib64 = ib32 / 2;\n\n    const uint16_t[4] scales = data_a[ib_k].scales;\n    const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;\n    const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);\n\n    const uint qs32 = data_a_packed32[ib_k].qs[ib32];\n    const uint qh16 = data_a_packed16[ib_k].qh[ib32];\n\n    float sum = 0;\n    const uint sc = data_a[ib_k].scales[ib64];\n    [[unroll]] for (int l = 0; l < 4; ++l) {\n        const uint ib16 = 2 * ib32 + l / 2;\n        const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);\n        const uint qh = qh16 >> (4 * l);\n        
const uint qs = (qs32 >> (8 * l)) & 0xFF;\n        const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;\n\n        const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]);\n\n        int32_t q_sum = 0;\n        q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]);\n        q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]);\n\n        int32_t y_sum = 0;\n        y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]);\n        y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]);\n\n        // the -1 cancels out the bias in iq1s_grid_gpu\n        sum += dl * (q_sum + y_sum * (delta - 1));\n    }\n    sum *= float(cache_b_ds.x);\n\n    return sum;\n}\n#endif\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mm.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n\n#ifdef FLOAT16\n#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require\n#endif\n#if defined(DATA_A_IQ1_M)\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n#endif\n\n#if defined(DATA_A_BF16) && defined(COOPMAT)\n#extension GL_EXT_bfloat16 : enable\n#endif\n\n#ifdef COOPMAT\n#extension GL_KHR_cooperative_matrix : enable\n#extension GL_KHR_memory_scope_semantics : enable\n#endif\n\n#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)\n#extension GL_KHR_shader_subgroup_basic : enable\n#extension GL_KHR_shader_subgroup_ballot : enable\n#endif\n\n#ifdef MUL_MAT_ID\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n#endif\n\n#include \"types.glsl\"\n\n#ifndef LOAD_VEC_A\n#define LOAD_VEC_A 1\n#endif\n#ifndef LOAD_VEC_B\n#define LOAD_VEC_B 1\n#endif\n\n// Load 2 values at once without affecting index calculations through LOAD_VEC\n#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)\n#define LOAD_VEC_BATCH_A 2\n#else\n#define LOAD_VEC_BATCH_A 1\n#endif\n#if !defined(ALIGNED)\n#define LOAD_VEC_BATCH_B 2\n#else\n#define LOAD_VEC_BATCH_B 1\n#endif\n\n#if !defined(TO_FLOAT_TYPE)\n#define TO_FLOAT_TYPE FLOAT_TYPE\n#endif\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\n#if defined(A_TYPE_PACKED16)\nlayout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};\n#endif\n#if defined(A_TYPE_PACKED32)\nlayout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};\n#endif\n\nlayout (binding = 1) readonly buffer B {B_TYPE data_b[];};\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};\n\n#ifdef MUL_MAT_ID\nlayout (binding = 3) readonly buffer IDS {int data_ids[];};\nlayout (binding = 4) readonly buffer Counts {int 
data_expert_count[];};\n#endif\n\nlayout (push_constant) uniform parameter\n{\n    uint M;\n    uint N;\n    uint K;\n    uint stride_a;\n    uint stride_b;\n    uint stride_d;\n\n    uint batch_stride_a;\n    uint batch_stride_b;\n    uint batch_stride_d;\n\n#ifdef MUL_MAT_ID\n    uint nei0;\n    uint nei1;\n    uint nbi1;\n    uint ne11;\n#else\n    uint base_work_group_z;\n    uint num_batches;\n    uint k_split;\n    uint ne02;\n    uint ne12;\n    uint broadcast2;\n    uint broadcast3;\n#endif\n} p;\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 64;\nlayout (constant_id = 1) const uint BM = 64;\nlayout (constant_id = 2) const uint BN = 64;\nlayout (constant_id = 4) const uint WM = 32;\nlayout (constant_id = 5) const uint WN = 32;\nlayout (constant_id = 6) const uint WMITER = 2;\nlayout (constant_id = 7) const uint TM = 4;\nlayout (constant_id = 8) const uint TN = 2;\nlayout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat\nlayout (constant_id = 10) const uint WARP = 32;\n\n#if defined(DATA_A_F32) || defined(DATA_A_F16)\n#define BK 32\n#define BK_STEP 4\n#else\nlayout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant\n#define BK_STEP 2\n#endif\n\n#ifdef COOPMAT\n#define SHMEM_STRIDE (BK / 2 + 4)\n#else\n#define SHMEM_STRIDE (BK / 2 + 1)\n#endif\n\nshared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];\nshared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];\n\n#define NUM_WARPS (BLOCK_SIZE / WARP)\n\n#ifdef COOPMAT\nshared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];\n#endif\n\n#include \"mul_mm_id_funcs.glsl\"\n#include \"mul_mm_funcs.glsl\"\n\nvoid main() {\n    const uint ic = gl_WorkGroupID.y;\n\n#ifdef MUL_MAT_ID\n    const uint expert_idx = gl_WorkGroupID.z;\n    if (ic * BN >= data_expert_count[expert_idx]) {\n        return;\n    }\n#endif\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n#ifndef MUL_MAT_ID\n    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;\n\n    const 
uint i13 = batch_idx / p.ne12;\n    const uint i12 = batch_idx % p.ne12;\n\n    const uint i03 = i13 / p.broadcast3;\n    const uint i02 = i12 / p.broadcast2;\n\n    const uint batch_idx_a = i03 * p.ne02 + i02;\n#endif\n\n    const uint blocks_m = (p.M + BM - 1) / BM;\n    const uint ir = gl_WorkGroupID.x % blocks_m;\n    const uint ik = gl_WorkGroupID.x / blocks_m;\n\n    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);\n    const uint WSUBM = WM / WMITER;\n    const uint WSUBN = WN / WNITER;\n\n#ifdef COOPMAT\n    const uint warp_i = gl_SubgroupID;\n\n    const uint tiw = gl_SubgroupInvocationID;\n\n    const uint cms_per_row = WM / TM;\n    const uint cms_per_col = WN / TN;\n\n    const uint storestride = WARP / TM;\n    const uint store_r = tiw % TM;\n    const uint store_c = tiw / TM;\n#else\n    const uint warp_i = gl_LocalInvocationID.x / WARP;\n\n    const uint tiw = gl_LocalInvocationID.x % WARP;\n\n    const uint tiwr = tiw % (WSUBM / TM);\n    const uint tiwc = tiw / (WSUBM / TM);\n#endif\n\n    const uint warp_r = warp_i % (BM / WM);\n    const uint warp_c = warp_i / (BM / WM);\n\n    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);\n    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);\n    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);\n    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);\n\n    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;\n    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;\n\n#ifdef MUL_MAT_ID\n#ifdef MUL_MAT_ID_USE_SUBGROUPS\n    if (bitCount(p.nei0) == 1) {\n        load_row_ids(expert_idx, true, ic);\n    } else {\n        load_row_ids(expert_idx, false, ic);\n    }\n#else\n    _ne1 = 0;\n    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {\n        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < 
(ic + 1) * BN; ii0++) {\n            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {\n                if (_ne1 >= ic * BN) {\n                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);\n                }\n                _ne1++;\n            }\n        }\n    }\n\n    barrier();\n#endif\n\n    // Workgroup has no work\n    if (ic * BN >= _ne1) return;\n#endif\n\n#ifdef MUL_MAT_ID\n    const uint start_k = 0;\n    const uint end_k = p.K;\n#else\n    const uint start_k = ik * p.k_split;\n    const uint end_k = min(p.K, (ik + 1) * p.k_split);\n#endif\n\n    uint pos_a =\n#ifdef MUL_MAT_ID\n        expert_idx * (p.batch_stride_a / LOAD_VEC_A) +\n#else\n        batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +\n#endif\n        (ir * BM * p.stride_a + start_k) / LOAD_VEC_A;\n#ifdef MUL_MAT_ID\n    uint pos_b = 0;\n#else\n    uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;\n#endif\n\n#ifdef COOPMAT\n    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;\n    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;\n    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];\n\n    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {\n        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);\n    }\n#else\n    ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];\n#if defined(DATA_A_F32) || defined(DATA_A_F16)\n    FLOAT_TYPE_VEC4 cache_a[WMITER * TM];\n    FLOAT_TYPE_VEC4 cache_b;\n#else\n    FLOAT_TYPE_VEC2 cache_a[WMITER * TM];\n    FLOAT_TYPE_VEC2 cache_b;\n#endif\n\n    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {\n        sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);\n    }\n#endif\n\n    for (uint block = start_k; block < end_k; block += BK) {\n        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {\n            load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, 
block, end_k);\n        }\n        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {\n#if !defined(MUL_MAT_ID)\n            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);\n#else\n            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);\n#endif\n        }\n\n        barrier();\n\n        pos_a += BK / LOAD_VEC_A;\n        pos_b += BK / LOAD_VEC_B;\n\n#ifdef COOPMAT\n        [[unroll]] for (uint i = 0; i < BK; i += TK) {\n            [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {\n                // Load from shared into cache\n                coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);\n\n                [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {\n                    coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);\n\n                    sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);\n                }\n            }\n        }\n#else\n        [[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) {\n            // Load from shared into cache\n            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {\n                [[unroll]] for (uint j = 0; j < TM; j++) {\n                #if defined(DATA_A_F32) || defined(DATA_A_F16)\n                    cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i    ];\n                    cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1];\n                #else\n                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];\n                #endif\n                }\n            }\n\n            [[unroll]] for (uint wsic = 0; wsic < 
WNITER; wsic++) {\n                [[unroll]] for (uint cc = 0; cc < TN; cc++) {\n                #if defined(DATA_A_F32) || defined(DATA_A_F16)\n                    cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i    ];\n                    cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1];\n                #else\n                    cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];\n                #endif\n\n                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {\n                        [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {\n                            // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]\n                            const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;\n                        #if defined(DATA_A_F32) || defined(DATA_A_F16)\n                            sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y),\n                                               fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));\n                            sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),\n                                               fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));\n                        #else\n                            sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));\n               
             sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));\n                        #endif\n                        }\n                    }\n                }\n            }\n\n        }\n#endif\n\n        barrier();\n    }\n\n#if defined(ACC_TYPE_MAX)\n#ifdef COOPMAT\n    [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {\n        [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {\n            sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);\n        }\n    }\n#else\n    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {\n        sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);\n        sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);\n    }\n#endif\n#endif\n\n    const uint dr = ir * BM + warp_r * WM;\n    const uint dc = ic * BN + warp_c * WN;\n\n#ifndef MUL_MAT_ID\n    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;\n#endif\n\n#ifdef COOPMAT\n#ifdef MUL_MAT_ID\n    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {\n        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {\n            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);\n\n            barrier();\n            [[unroll]] for (uint col = 0; col < TN; col += storestride) {\n                const uint row_i = dc + cm_col * TN + col + store_c;\n                if (row_i >= _ne1) break;\n\n                const u16vec2 row_idx = row_ids[row_i - ic * BN];\n\n                if (dr + cm_row * TM + store_r < p.M) {\n                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);\n                }\n            }\n            barrier();\n     
   }\n    }\n#else\n    const bool is_aligned = p.stride_d % 4 == 0;  // Assumption: D_TYPE == float\n\n    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {\n        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {\n            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;\n\n            if (is_aligned && is_in_bounds) {\n                // Full coopMat is within bounds and stride_d is aligned with 16B\n                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);\n                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);\n            } else if (is_in_bounds) {\n                // Full coopMat is within bounds, but stride_d is not aligned\n                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);\n\n                controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n                [[unroll]] for (uint col = 0; col < TN; col += storestride) {\n                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);\n                }\n                controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {\n                // Partial coopMat is within bounds\n                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);\n\n                controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, 
gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n                [[unroll]] for (uint col = 0; col < TN; col += storestride) {\n                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {\n                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);\n                    }\n                }\n                controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);\n            }\n        }\n    }\n#endif // MUL_MAT_ID\n#else\n    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {\n        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {\n\n            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;\n            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;\n            [[unroll]] for (uint cc = 0; cc < TN; cc++) {\n#ifdef MUL_MAT_ID\n                const uint row_i = dc_warp + cc;\n                if (row_i >= _ne1) break;\n\n                const u16vec2 row_idx = row_ids[row_i - ic * BN];\n#endif // MUL_MAT_ID\n                [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {\n                    const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;\n#ifdef MUL_MAT_ID\n                    if (dr_warp + 2 * cr < p.M) {\n                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);\n                    }\n                    if (dr_warp + 2 * cr + 1 < p.M) {\n                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);\n                    }\n#else\n                    if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {\n                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);\n     
               }\n                    if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {\n                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);\n                    }\n#endif // MUL_MAT_ID\n                }\n            }\n        }\n    }\n#endif // COOPMAT\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n\n#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n\n#extension GL_KHR_memory_scope_semantics : enable\n#extension GL_KHR_cooperative_matrix : enable\n#extension GL_NV_cooperative_matrix2 : enable\n#extension GL_EXT_buffer_reference : enable\n#extension GL_KHR_shader_subgroup_ballot : enable\n#extension GL_KHR_shader_subgroup_vote : enable\n#ifdef DATA_A_BF16\n#extension GL_EXT_bfloat16 : enable\n#endif\n\n#include \"types.glsl\"\n#include \"utils.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\n#define IS_MUL_MM2 1\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 256;\nlayout (constant_id = 1) const uint BM = 64;\nlayout (constant_id = 2) const uint BN = 64;\nlayout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant\n\nlayout (constant_id = 4) const bool enable_smaller_matrices = false;\nconst uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;\nconst uint BNover4 = enable_smaller_matrices ? 
(BN / 4) : BN;\n\nlayout (push_constant) uniform parameter\n{\n    uint M;\n    uint N;\n    uint K;\n    uint stride_a;\n    uint stride_b;\n    uint stride_d;\n\n    uint batch_stride_a;\n    uint batch_stride_b;\n    uint batch_stride_d;\n\n#ifdef MUL_MAT_ID\n    uint nei0;\n    uint nei1;\n    uint nbi1;\n    uint ne11;\n#else\n    uint base_work_group_z;\n    uint num_batches;\n    uint k_split;\n    uint ne02;\n    uint ne12;\n    uint broadcast2;\n    uint broadcast3;\n#endif\n    // N dimension for the B matrix can be >= p.N\n    uint padded_N;\n} p;\n\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) readonly buffer B {B_TYPE data_b[];};\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};\n\n#if QUANT_K > 1\n#define DECODEFUNCA , dequantFuncA\n\n#include \"dequant_funcs_cm2.glsl\"\n\n#else\n#define DECODEFUNCA\n#endif\n\n#if !defined(fetch_scales)\n#define fetch_scales(a, b, c, d, e, f)\n#endif\n#if !defined(store_scales)\n#define store_scales(a)\n#endif\n\n#if defined(DATA_A_BF16)\n#define MAT_TYPE bfloat16_t\n#else\n#define MAT_TYPE FLOAT_TYPE\n#endif\n\n#ifdef MUL_MAT_ID\nlayout (binding = 3) readonly buffer IDS {int data_ids[];};\nlayout (binding = 4) readonly buffer Counts {int data_expert_count[];};\n\nshared u16vec4 row_ids[BN];\n\nlayout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {\n   B_TYPE b[];\n};\n\nuint _ne1;\nlayout (constant_id = 5) const uint subgroup_size = 32;\nshared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];\n\nB_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])\n{\n    const uint row_i = blockCoords[0];\n\n    const u16vec4 row_idx = row_ids[row_i];\n    B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];\n\n    return ret;\n}\n\nD_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)\n{\n    uint dr 
= ir * BM + r;\n    uint dc = ic * BN + c;\n\n    if (dr < p.M && dc < _ne1) {\n        uint row_i = c;\n        const u16vec4 row_idx = row_ids[row_i];\n        data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;\n    }\n    return elem;\n}\n\nvoid load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {\n    _ne1 = 0;\n    uint num_elements = p.nei1 * p.nei0;\n    uint nei0shift = findLSB(p.nei0);\n\n    uint ids[16];\n    uint iter = 0;\n\n    uint expert_count = data_expert_count[expert_idx];\n\n    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {\n        // prefetch up to 16 elements\n        if (iter == 0) {\n            [[unroll]] for (uint k = 0; k < 16; ++k) {\n                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;\n                bool in_range = i < num_elements;\n                uint ii1;\n                if (nei0_is_pow2) {\n                    ii1 = i >> nei0shift;\n                } else {\n                    ii1 = i / p.nei0;\n                }\n                uint ii0 = i - ii1 * p.nei0;\n                ids[k] = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0;\n            }\n        }\n        uint i = j + gl_LocalInvocationIndex;\n        bool in_range = i < num_elements;\n        uint ii1;\n        if (nei0_is_pow2) {\n            ii1 = i >> nei0shift;\n        } else {\n            ii1 = i / p.nei0;\n        }\n        uint ii0 = i - ii1 * p.nei0;\n        uint id = ids[iter++];\n        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);\n\n        if (gl_SubgroupInvocationID == 0) {\n            ballots_sh[gl_SubgroupID] = ballot;\n        }\n        barrier();\n\n        uint subgroup_base = 0;\n        uint total = 0;\n        for (uint k = 0; k < gl_NumSubgroups; ++k) {\n            if (k == gl_SubgroupID) {\n                subgroup_base = total;\n            }\n            total += subgroupBallotBitCount(ballots_sh[k]);\n        }\n        barrier();\n\n        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);\n        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {\n            row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);\n        }\n        _ne1 += total;\n        iter &= 15;\n        if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {\n            break;\n        }\n    }\n    barrier();\n}\n#endif\n\nvoid main() {\n    const uint tid = gl_LocalInvocationIndex;\n    const uint ic = gl_WorkGroupID.y;\n\n#ifdef MUL_MAT_ID\n    const uint expert_idx = gl_WorkGroupID.z;\n    if (ic * BN >= data_expert_count[expert_idx]) {\n        return;\n    }\n    // initialize to row 0 so we don't need to bounds check\n    if (tid < BN) {\n        row_ids[tid] = u16vec4(0);\n    }\n#if !defined(NEEDS_INIT_IQ_SHMEM)\n    barrier();\n#endif\n#endif\n\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n#ifndef MUL_MAT_ID\n    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;\n\n    const uint i13 = batch_idx / p.ne12;\n    const uint i12 = batch_idx % 
p.ne12;\n\n    const uint i03 = i13 / p.broadcast3;\n    const uint i02 = i12 / p.broadcast2;\n\n    const uint batch_idx_a = i03 * p.ne02 + i02;\n#endif\n\n    const uint blocks_m = (p.M + BM - 1) / BM;\n    const uint ir = gl_WorkGroupID.x % blocks_m;\n    const uint ik = gl_WorkGroupID.x / blocks_m;\n\n#ifdef MUL_MAT_ID\n    if (bitCount(p.nei0) == 1) {\n        load_row_ids(expert_idx, true, ic);\n    } else {\n        load_row_ids(expert_idx, false, ic);\n    }\n\n    // Workgroup has no work\n    if (ic * BN >= _ne1) return;\n#endif\n\n#ifdef MUL_MAT_ID\n    uint start_k = 0;\n    const uint end_k = p.K;\n#else\n    uint start_k = ik * p.k_split;\n    const uint end_k = min(p.K, (ik + 1) * p.k_split);\n#endif\n\n#ifdef MUL_MAT_ID\n    uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K);\n    uint pos_b = 0;\n#else\n    uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);\n    uint pos_b = batch_idx * p.batch_stride_b;\n    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;\n#endif\n\n    uint stride_a = p.stride_a / QUANT_K;\n    uint stride_b = p.stride_b;\n\n    // Hint to the compiler that values are aligned (want 16B alignment).\n    // Quants are always block-aligned, no alignment needed.\n#if ALIGNED\n#if QUANT_K == 1\n    stride_a &= ~7;\n#endif\n    stride_b &= ~7;\n#endif\n\n    // Create layouts for both clamped and unclamped accesses\n    tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);\n    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);\n    tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);\n    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);\n    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);\n\n#if 
QUANT_K > 1\n    tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);\n    tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);\n#endif\n\n    // Use end_k rather than p.K as the dimension because that's what\n    // we need to bound check against when using split_k.\n    // Bounds check B against padded_N, but bounds check D against N.\n    tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);\n    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);\n    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);\n    tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);\n    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);\n\n    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);\n\n    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);\n\n#if !defined(MUL_MAT_ID)\n\n    const uint START_ALIGN_K = 256;\n    // For Qi_K (block size 256), unroll whole 256 element tiles.\n    // For legacy quants (block size 32), unroll 8x.\n    const uint UNROLL_K = (QUANT_K == 256) ? 
256 : (BK * 8);\n    const uint unroll_count = UNROLL_K / BK;\n\n    // Detect a fast path where all loads are entirely in bounds and no clamping is required\n    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&\n#if QUANT_K == 1\n        (stride_a % 8) == 0 &&\n#endif\n        (stride_b % 8) == 0) {\n        // Hint to the compiler that values are aligned (want 16B alignment)\n        start_k &= ~(START_ALIGN_K-1);\n        stride_b &= ~7;\n#if QUANT_K == 1\n        stride_a &= ~7;\n#endif\n\n        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);\n        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);\n\n        uint k_iters = (end_k - start_k) / UNROLL_K;\n        uint block_k = start_k;\n\n        // fetch scale values for a tile of quants. These will be copied into shared memory.\n        // The fetches and stores are pipelined to hide the latency.\n        fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);\n\n        if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {\n            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);\n            for (uint i = 0; i < k_iters; ++i) {\n\n                store_scales(tid);\n                if (block_k + UNROLL_K < end_k) {\n                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);\n                }\n\n                // Manually partial unroll\n                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;\n\n                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                    
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);\n\n                    sum = coopMatMulAdd(mat_a, mat_b, sum);\n                    block_k += BK;\n                }\n            }\n            // Do any remaining iterations that were not unrolled\n            if (block_k < end_k) {\n                store_scales(tid);\n            }\n            while (block_k < end_k) {\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;\n\n                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);\n\n                sum = coopMatMulAdd(mat_a, mat_b, sum);\n                block_k += BK;\n            }\n#if defined(ACC_TYPE_MAX)\n            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }\n#endif\n\n            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);\n\n            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);\n            return;\n        } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {\n            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);\n            for (uint i = 0; i < k_iters; ++i) {\n\n                store_scales(tid);\n                if (block_k + UNROLL_K < end_k) {\n                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, 
true);\n                }\n\n                // Manually partial unroll\n                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;\n\n                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);\n\n                    sum = coopMatMulAdd(mat_a, mat_b, sum);\n                    block_k += BK;\n                }\n            }\n            // Do any remaining iterations that were not unrolled\n            if (block_k < end_k) {\n                store_scales(tid);\n            }\n            while (block_k < end_k) {\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;\n\n                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);\n\n                sum = coopMatMulAdd(mat_a, mat_b, sum);\n                block_k += BK;\n            }\n#if defined(ACC_TYPE_MAX)\n            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }\n#endif\n\n            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);\n\n            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);\n            
return;\n        } else {\n            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);\n\n            for (uint i = 0; i < k_iters; ++i) {\n\n                store_scales(tid);\n                if (block_k + UNROLL_K < end_k) {\n                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);\n                }\n\n                // Manually partial unroll\n                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;\n\n                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);\n\n                    sum = coopMatMulAdd(mat_a, mat_b, sum);\n                    block_k += BK;\n                }\n            }\n            // Do any remaining iterations that were not unrolled\n            if (block_k < end_k) {\n                store_scales(tid);\n            }\n            while (block_k < end_k) {\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;\n\n                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);\n\n                sum = coopMatMulAdd(mat_a, mat_b, sum);\n                block_k += BK;\n            }\n#if defined(ACC_TYPE_MAX)\n            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { 
sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }\n#endif\n\n            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);\n\n            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);\n            return;\n        }\n    } else\n#endif // !defined(MUL_MAT_ID)\n    {\n        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);\n\n        tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);\n\n        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);\n\n        tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);\n\n        uint k_iters = (end_k - start_k + BK - 1) / BK;\n\n        fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);\n        store_scales(tid);\n\n#ifdef MUL_MAT_ID\n        if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {\n            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;\n            sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);\n\n            [[dont_unroll]]\n            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {\n\n                if ((block_k % QUANT_K) == 0) {\n                    store_scales(tid);\n                }\n                if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {\n                    fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);\n                }\n\n                if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;\n\n                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir 
* BM, BM, block_k, BK) DECODEFUNCA);\n                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);\n\n                    sum = coopMatMulAdd(mat_a, mat_b, sum);\n                } else {\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;\n\n                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);\n\n                    sum = coopMatMulAdd(mat_a, mat_b, sum);\n                }\n            }\n#if defined(ACC_TYPE_MAX)\n            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }\n#endif\n\n            // Convert from ACC_TYPE to D_TYPE\n            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;\n            mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);\n\n            // Call callback to store each element, remapping row through shared memory\n            coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);\n            return;\n        }\n        if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {\n            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;\n            sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);\n\n            [[dont_unroll]]\n            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {\n\n                if ((block_k % QUANT_K) == 0) {\n                    store_scales(tid);\n                }\n                if (block_k + BK < end_k && ((block_k 
+ BK) % QUANT_K) == 0) {\n                    fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);\n                }\n\n                if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;\n\n                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);\n\n                    sum = coopMatMulAdd(mat_a, mat_b, sum);\n                } else {\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;\n\n                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);\n                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);\n\n                    sum = coopMatMulAdd(mat_a, mat_b, sum);\n                }\n            }\n#if defined(ACC_TYPE_MAX)\n            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }\n#endif\n\n            // Convert from ACC_TYPE to D_TYPE\n            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;\n            mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);\n\n            // Call callback to store each element, remapping row through shared memory\n            coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);\n            return;\n        }\n#endif\n        coopmat<ACC_TYPE, gl_ScopeWorkgroup, 
BM, BN, gl_MatrixUseAccumulator> sum;\n        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);\n\n        [[dont_unroll]]\n        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {\n\n            if ((block_k % QUANT_K) == 0) {\n                store_scales(tid);\n            }\n            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {\n                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);\n            }\n\n            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;\n\n                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);\n#ifdef MUL_MAT_ID\n                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);\n#else\n                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);\n#endif\n\n                sum = coopMatMulAdd(mat_a, mat_b, sum);\n            } else {\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;\n                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;\n\n                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);\n#ifdef MUL_MAT_ID\n                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);\n#else\n                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);\n#endif\n\n                sum = coopMatMulAdd(mat_a, mat_b, sum);\n      
      }\n        }\n#if defined(ACC_TYPE_MAX)\n        [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }\n#endif\n\n        // Convert from ACC_TYPE to D_TYPE\n        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;\n        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);\n\n#ifdef MUL_MAT_ID\n        // Call callback to store each element, remapping row through shared memory\n        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);\n#else\n        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);\n#endif\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl",
    "content": "void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {\n#if defined(DATA_A_F32) || defined(DATA_A_F16)\n#if LOAD_VEC_A == 8\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n            FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);\n            buf_a[buf_idx    ] = aa[0].xy;\n            buf_a[buf_idx + 1] = aa[0].zw;\n            buf_a[buf_idx + 2] = aa[1].xy;\n            buf_a[buf_idx + 3] = aa[1].zw;\n#elif LOAD_VEC_A == 4\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n            FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);\n            buf_a[buf_idx    ] = aa.xy;\n            buf_a[buf_idx + 1] = aa.zw;\n#else // LOAD_VEC_BATCH_A == 2\n            const uint idx = pos_a + col * p.stride_a + row * 2;\n            const uint buf_idx = col * SHMEM_STRIDE + row;\n            if (idx_m < p.M && block + row * 2 + 1 < end_k) {\n                buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],\n                                                 data_a[idx + 1]);\n            } else if (idx_m < p.M && block + row * 2 < end_k) {\n                buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);\n            } else {\n                buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);\n            }\n#endif\n#elif defined(DATA_A_BF16)\n#if LOAD_VEC_A == 4\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n            FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));\n            buf_a[buf_idx    ] = aa.xy;\n            buf_a[buf_idx + 1] = aa.zw;\n#else // LOAD_VEC_BATCH_A == 2\n            const uint idx = pos_a + col * p.stride_a + row * 2;\n            const uint 
buf_idx = col * SHMEM_STRIDE + row;\n            if (idx_m < p.M && block + row * 2 + 1 < end_k) {\n                buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),\n                                                 TO_FLOAT_TYPE(data_a[idx + 1]));\n            } else if (idx_m < p.M && block + row * 2 < end_k) {\n                buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);\n            } else {\n                buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);\n            }\n#endif\n#elif defined(DATA_A_Q4_0)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;\n\n            const uint ib = idx / 4;\n            const uint iqs = idx & 0x03;\n\n            const float d = float(data_a_packed16[ib].d);\n            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);\n            const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;\n            const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v0.xy);\n            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);\n            buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);\n            buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);\n#elif defined(DATA_A_Q4_1)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;\n\n            const uint ib = idx / 4;\n            const uint iqs = idx & 0x03;\n\n            const vec2 dm = vec2(data_a_packed32[ib].dm);\n            const uint vui = data_a_packed32[ib].qs[iqs];\n            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;\n            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;\n\n            buf_a[buf_idx     ] = FLOAT_TYPE_VEC2(v0.xy);\n            buf_a[buf_idx + 1 ] = 
FLOAT_TYPE_VEC2(v0.zw);\n            buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);\n            buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);\n#elif defined(DATA_A_Q5_0)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;\n\n            const uint ib = idx / 8;\n            const uint iqs = idx & 0x07;\n\n            const float d = float(data_a_packed16[ib].d);\n            const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);\n            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);\n            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);\n\n            const uint vui = uint(data_a_packed16[ib].qs[iqs]);\n            const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xz);\n            buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);\n#elif defined(DATA_A_Q5_1)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;\n\n            const uint ib = idx / 4;\n            const uint iqs = idx & 0x03;\n\n            const vec2 dm = vec2(data_a_packed32[ib].dm);\n            const uint uint_qh = data_a_packed32[ib].qh;\n            const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);\n            const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);\n            const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);\n            const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);\n\n            const uint vui = 
data_a_packed32[ib].qs[iqs];\n            const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;\n            const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v0.xz);\n            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);\n            buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);\n            buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);\n#elif defined(DATA_A_Q8_0)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 8;\n            const uint iqs = idx & 0x07;\n\n            const float d = float(data_a_packed16[ib].d);\n            const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147\n            const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;\n            const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);\n            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);\n#elif defined(DATA_A_Q2_K)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 64;                          // 4 values per idx\n            const uint iqs = (idx % 64) * 2;                   // 0,2,4..126\n\n            const uint qsi = (iqs / 64) * 16 + (iqs % 16);     // 0..15\n            const uint scalesi = iqs / 8;                      // 0..15\n            const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6\n\n            const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));\n            const uint scales = 
data_a[ib].scales[scalesi];\n            const vec2 dm = vec2(data_a[ib].dm);\n\n            const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);\n            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);\n#elif defined(DATA_A_Q3_K)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 128;                   // 2 values per idx\n            const uint iqs = idx % 128;                  // 0..127\n\n            const uint n = iqs / 64;                     // 0,1\n            const uint qsi = n * 32 + (iqs % 16) * 2;    // 0,2,4..62\n            const uint hmi =          (iqs % 16) * 2;    // 0,2,4..30\n            const uint j = (iqs % 64) / 4;               // 0..3\n            const uint is = iqs / 8;                     // 0..15\n            const uint halfsplit = ((iqs % 64) / 16);    // 0,1,2,3\n            const uint qsshift = halfsplit * 2;          // 0,2,4,6\n\n            const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)\n                                  | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));\n            const float dl = float(data_a[ib].d) * float(us - 32);\n\n            const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy);\n            const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy);\n\n            buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x),\n                                             dl * (qs.y - hm.y));\n#elif defined(DATA_A_Q4_K)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 64;                  // 4 values 
per idx\n            const uint iqs = (idx % 64) * 2;           // 0,2,4..126\n\n            const uint n = iqs / 32;                   // 0,1,2,3\n            const uint b = (iqs % 32) / 16;            // 0,1\n            const uint is = 2 * n + b;                 // 0..7\n            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126\n\n            const vec2 loadd = vec2(data_a[ib].dm);\n\n            const uint scidx0 = (is < 4) ? is : (is + 4);\n            const uint scidx1 = (is < 4) ? is : (is - 4);\n            const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;\n            const uint scidxshift1 = (is < 4) ? 0 : 2;\n            const uint mbidx0 = is + 4;\n            const uint mbidx1 = (is < 4) ? is + 4 : is;\n            const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;\n            const uint mbidxshift0 = (is < 4) ? 0 : 4;\n            const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;\n            const uint mbidxshift1 = (is < 4) ? 0 : 2;\n\n            const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));\n            const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));\n\n            const float d = loadd.x * sc;\n            const float m = -loadd.y * mbyte;\n\n            const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));\n            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));\n#elif defined(DATA_A_Q5_K)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 64;                  // 4 values per idx\n            const uint iqs = (idx % 64) * 2;           // 0,2,4..126\n\n            const uint n = iqs / 32;   
                // 0,1,2,3\n            const uint b = (iqs % 32) / 16;            // 0,1\n            const uint is = 2 * n + b;                 // 0..7\n            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126\n            const uint qhi = (iqs % 16) * 2;           // 0,2,4..30\n\n            const vec2 loadd = vec2(data_a[ib].dm);\n\n            const uint scidx0 = (is < 4) ? is : (is + 4);\n            const uint scidx1 = (is < 4) ? is : (is - 4);\n            const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;\n            const uint scidxshift1 = (is < 4) ? 0 : 2;\n            const uint mbidx0 = is + 4;\n            const uint mbidx1 = (is < 4) ? is + 4 : is;\n            const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;\n            const uint mbidxshift0 = (is < 4) ? 0 : 4;\n            const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;\n            const uint mbidxshift1 = (is < 4) ? 0 : 2;\n\n            const uint8_t sc    = uint8_t((data_a[ib].scales[scidx0] & 0xF)                         | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));\n            const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));\n\n            const float d = loadd.x * sc;\n            const float m = -loadd.y * mbyte;\n\n            const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;\n            const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;\n            const vec4 q = vec4(unpack8(qs | qh));\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));\n            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));\n#elif defined(DATA_A_Q6_K)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 128;                  // 2 values per 
idx\n            const uint iqs = idx % 128;                 // 0..127\n\n            const uint n = iqs / 64;                    // 0,1\n            const uint b = ((iqs % 64) / 32) * 4;       // 0,4\n            const uint is_b = (iqs % 16) / 8;           // 0,1\n            const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6\n            const uint is = 8 * n + qhshift + is_b;     // 0..15\n            const uint qsi = n * 32 + (iqs % 32);       // 0..63\n            const uint qhi = n * 16 + (iqs % 16);       // 0..31\n\n            const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);\n\n            const uint ql = (uint(data_a_packed16[ib].ql[qsi]) >> b) & 0x0F0F;\n            const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303;\n            const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale;\n\n            buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y);\n#elif defined(DATA_A_IQ1_S)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 32;                  // 8 values per idx\n            const uint ib32 = (idx % 32) / 4;         // 0..7\n            const uint ib8 = idx % 32;\n\n            const float d = float(data_a[ib].d);\n            const uint qh = data_a[ib].qh[ib32];\n            const uint qs = data_a[ib].qs[ib8];\n            const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);\n            const float delta = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA;\n            const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);\n\n            [[unroll]] for (int k = 0; k < 4; ++k) {\n                buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k    , 2) + delta),\n                                                     dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));\n            }\n#elif defined(DATA_A_IQ1_M)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 32;  // 8 values per idx\n            const uint ib8 = idx % 32;\n            const uint ib16 = ib8 / 2;\n\n            const uint16_t[4] scales = data_a[ib].scales;\n            const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;\n            const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);\n            const uint sc = scales[ib8 / 8];\n            const uint qs = data_a[ib].qs[ib8];\n            const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));\n            const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);\n            const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA;\n            const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);\n\n            [[unroll]] for (int k = 0; k < 4; ++k) {\n                buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k    , 2) + delta),\n                                                     dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));\n            }\n#elif defined(DATA_A_IQ2_XXS)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 32;                 // 8 values per idx\n            const uint ib32 = (idx % 32) / 4;         // 0..7\n            const uint ib8 = idx % 4;\n\n            const float d = float(data_a[ib].d);\n            const uint qs = data_a[ib].qs[8 * ib32 + ib8];\n            const uint signs = pack32(u8vec4(\n                data_a[ib].qs[8*ib32 + 4],\n                data_a[ib].qs[8*ib32 + 5],\n                data_a[ib].qs[8*ib32 + 6],\n                data_a[ib].qs[8*ib32 + 7]\n            ));\n            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));\n            const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);\n            const uint sign = sign7 | (bitCount(sign7) << 7);\n            const uvec2 grid = iq2xxs_grid[qs];\n            const vec4 grid0 = vec4(unpack8(grid.x));\n            const vec4 grid1 = vec4(unpack8(grid.y));\n\n            buf_a[buf_idx    ] = db * FLOAT_TYPE_VEC2((sign &   1) != 0 ? -grid0.x : grid0.x,\n                                                      (sign &   2) != 0 ? -grid0.y : grid0.y);\n            buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign &   4) != 0 ? -grid0.z : grid0.z,\n                                                      (sign &   8) != 0 ? -grid0.w : grid0.w);\n            buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign &  16) != 0 ? 
-grid1.x : grid1.x,\n                                                      (sign &  32) != 0 ? -grid1.y : grid1.y);\n            buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign &  64) != 0 ? -grid1.z : grid1.z,\n                                                      (sign & 128) != 0 ? -grid1.w : grid1.w);\n#elif defined(DATA_A_IQ2_XS)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 32;            // 8 values per idx\n            const uint ib32 = (idx % 32) / 4;    // 0..7\n            const uint ib8 = idx % 4;            // 0..3\n\n            const float d = float(data_a[ib].d);\n            const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;\n            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));\n            const uint qs = data_a[ib].qs[4 * ib32 + ib8];\n            const uint sign7 = qs >> 9;\n            const uint sign = sign7 | (bitCount(sign7) << 7);\n            const uvec2 grid = iq2xs_grid[qs & 511];\n            const vec4 grid0 = vec4(unpack8(grid.x));\n            const vec4 grid1 = vec4(unpack8(grid.y));\n\n            buf_a[buf_idx    ] = db * FLOAT_TYPE_VEC2((sign &   1) != 0 ? -grid0.x : grid0.x,\n                                                      (sign &   2) != 0 ? -grid0.y : grid0.y);\n            buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign &   4) != 0 ? -grid0.z : grid0.z,\n                                                      (sign &   8) != 0 ? -grid0.w : grid0.w);\n            buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign &  16) != 0 ? -grid1.x : grid1.x,\n                                                      (sign &  32) != 0 ? -grid1.y : grid1.y);\n            buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign &  64) != 0 ? -grid1.z : grid1.z,\n                                                      (sign & 128) != 0 ? 
-grid1.w : grid1.w);\n#elif defined(DATA_A_IQ2_S)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 32;  // 8 values per idx\n            const uint ib8 = idx % 32; // 0..31\n            const uint ib32 = ib8 / 4; // 0..7\n\n            const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;\n            const uint qs = data_a[ib].qs[ib8];\n            const uint qh = data_a[ib].qh[ib32];\n            const uint qhshift = 2 * (ib8 % 4);\n            const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];\n\n            const float d = float(data_a[ib].d);\n            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));\n            const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];\n            const vec4 grid0 = vec4(unpack8(grid.x));\n            const vec4 grid1 = vec4(unpack8(grid.y));\n\n            buf_a[buf_idx    ] = db * FLOAT_TYPE_VEC2((sign &   1) != 0 ? -grid0.x : grid0.x,\n                                                      (sign &   2) != 0 ? -grid0.y : grid0.y);\n            buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign &   4) != 0 ? -grid0.z : grid0.z,\n                                                      (sign &   8) != 0 ? -grid0.w : grid0.w);\n            buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign &  16) != 0 ? -grid1.x : grid1.x,\n                                                      (sign &  32) != 0 ? -grid1.y : grid1.y);\n            buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign &  64) != 0 ? -grid1.z : grid1.z,\n                                                      (sign & 128) != 0 ? 
-grid1.w : grid1.w);\n#elif defined(DATA_A_IQ3_XXS)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 64;            // 4 values per idx\n            const uint iqs = idx % 64;           // 0..63\n            const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values\n\n            const float d = float(data_a[ib].d);\n            const uint qs = data_a[ib].qs[iqs];\n            const uint signs = pack32(u16vec2(\n                data_a_packed16[ib].qs[is/2],\n                data_a_packed16[ib].qs[is/2+1]\n            ));\n            const float db = d * 0.5 * (0.5 + (signs >> 28));\n            const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);\n            const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));\n            const uint grid = iq3xxs_grid[qs];\n            const vec4 v = db * vec4(unpack8(grid));\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2((sign &   1) != 0 ? -v.x : v.x,\n                                                 (sign &   2) != 0 ? -v.y : v.y);\n            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign &   4) != 0 ? -v.z : v.z,\n                                                 (sign &   8) != 0 ? 
-v.w : v.w);\n#elif defined(DATA_A_IQ3_S)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 64;            // 4 values per idx\n            const uint iqs = idx % 64;           // 0..63\n            const uint iqh = iqs / 8;\n\n            const float d = float(data_a[ib].d);\n            const uint qs = data_a[ib].qs[iqs];\n            const uint qh = data_a[ib].qh[iqh];\n            const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));\n            const uint scale = data_a[ib].scales[iqs / 16];\n            const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));\n            const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));\n            const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];\n            const vec4 v = db * vec4(unpack8(grid));\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2((sign &   1) != 0 ? -v.x : v.x,\n                                                 (sign &   2) != 0 ? -v.y : v.y);\n            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign &   4) != 0 ? -v.z : v.z,\n                                                 (sign &   8) != 0 ? 
-v.w : v.w);\n#elif defined(DATA_A_IQ4_XS)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;\n\n            const uint ib = idx / 128;                  // 2 values per idx\n            const uint ib32 = (idx % 128) / 16;         // 0..7\n            const uint iq = 16 * ib32 + 2 * (idx % 8);\n\n            const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;\n            const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;\n            const uint qshift = (idx & 8) >> 1;\n            u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy;\n\n            const float d = float(data_a[ib].d);\n            const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);\n#elif defined(DATA_A_IQ4_NL)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;\n\n            const uint ib = idx / 8;\n            const uint iqs = idx & 0x07;\n\n            const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);\n            const uint vui = uint(data_a_packed16[ib].qs[iqs]);\n\n            buf_a[buf_idx    ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],\n                                                      kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);\n            buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],\n                                                     kvalues_iq4nl[vui >> 12]);\n#elif defined(DATA_A_MXFP4)\n            const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;\n\n            const uint ib = idx / 8;\n            const uint iqs = (idx & 0x07) * 2;\n\n            const float d = 
e8m0_to_fp32(data_a[ib].e) * 0.5;\n            const uint vui = uint(data_a[ib].qs[iqs]);\n            const uint vui2 = uint(data_a[ib].qs[iqs+1]);\n\n            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui  & 0xF] * d,\n                                                 kvalues_mxfp4[vui2 & 0xF] * d);\n            buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui  >>  4] * d,\n                                                 kvalues_mxfp4[vui2 >>  4] * d);\n#endif\n}\n\n#if !defined(MUL_MAT_ID)\nvoid load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {\n#if LOAD_VEC_B == 8\n            // Not supported for b_type bf16 because bf16mat2x4 does not exist\n            const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;\n            FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);\n            buf_b[buf_idx + 0] = bb[0].xy;\n            buf_b[buf_idx + 1] = bb[0].zw;\n            buf_b[buf_idx + 2] = bb[1].xy;\n            buf_b[buf_idx + 3] = bb[1].zw;\n#elif LOAD_VEC_B == 4\n            const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;\n#if defined(DATA_B_BF16)\n            FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));\n#else\n            FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);\n#endif\n            buf_b[buf_idx + 0] = bb.xy;\n            buf_b[buf_idx + 1] = bb.zw;\n#else // LOAD_VEC_BATCH_B == 2\n            const uint idx = pos_b + col * p.stride_b + row * 2;\n            const uint buf_idx = col * SHMEM_STRIDE + row;\n            if (idx_n < p.N && block + row * 2 + 1 < end_k) {\n                buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),\n                                                 TO_FLOAT_TYPE(data_b[idx + 1]));\n            } else if (idx_n < p.N && 
block + row * 2 < end_k) {\n                buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);\n            } else {\n                buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);\n            }\n#endif\n}\n#else\nvoid load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {\n#if LOAD_VEC_B == 8\n            // Not supported for b_type bf16 because bf16mat2x4 does not exist\n            const u16vec2 row_idx = row_ids[col];\n            const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;\n            FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);\n            buf_b[buf_idx + 0] = bb[0].xy;\n            buf_b[buf_idx + 1] = bb[0].zw;\n            buf_b[buf_idx + 2] = bb[1].xy;\n            buf_b[buf_idx + 3] = bb[1].zw;\n#elif LOAD_VEC_B == 4\n            const u16vec2 row_idx = row_ids[col];\n            const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;\n            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;\n#if defined(DATA_B_BF16)\n            FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));\n#else\n            FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);\n#endif\n            buf_b[buf_idx + 0] = bb.xy;\n            buf_b[buf_idx + 1] = bb.zw;\n#else // LOAD_VEC_BATCH_B == 2\n            const uint row_i = ic * BN + col;\n            const uint buf_idx = col * SHMEM_STRIDE + row;\n            if (row_i < _ne1 && block + row * 2 + 1 < end_k) {\n                const u16vec2 row_idx = row_ids[col];\n                const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;\n                buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),\n                     
                            TO_FLOAT_TYPE(data_b[idx + 1]));\n            } else if (row_i < _ne1 && block + row * 2 < end_k) {\n                const u16vec2 row_idx = row_ids[col];\n                const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;\n                buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);\n            } else {\n                buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);\n            }\n#endif\n}\n#endif\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl",
    "content": "#ifdef MUL_MAT_ID\nshared u16vec2 row_ids[BN];\nuint _ne1;\n\n#ifdef MUL_MAT_ID_USE_SUBGROUPS\nshared uvec4 ballots_sh[NUM_WARPS];\n\nvoid load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {\n    _ne1 = 0;\n    uint num_elements = p.nei1 * p.nei0;\n    uint nei0shift = findLSB(p.nei0);\n\n    uint ids[16];\n    uint iter = 0;\n\n    uint expert_count = data_expert_count[expert_idx];\n\n    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {\n        // prefetch up to 16 elements\n        if (iter == 0) {\n            [[unroll]] for (uint k = 0; k < 16; ++k) {\n                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;\n                bool in_range = i < num_elements;\n                uint ii1;\n                if (nei0_is_pow2) {\n                    ii1 = i >> nei0shift;\n                } else {\n                    ii1 = i / p.nei0;\n                }\n                uint ii0 = i - ii1 * p.nei0;\n                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;\n            }\n        }\n        uint i = j + gl_LocalInvocationIndex;\n        bool in_range = i < num_elements;\n        uint ii1;\n        if (nei0_is_pow2) {\n            ii1 = i >> nei0shift;\n        } else {\n            ii1 = i / p.nei0;\n        }\n        uint ii0 = i - ii1 * p.nei0;\n        uint id = ids[iter++];\n        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);\n\n        if (gl_SubgroupInvocationID == 0) {\n            ballots_sh[gl_SubgroupID] = ballot;\n        }\n        barrier();\n\n        uint subgroup_base = 0;\n        uint total = 0;\n        for (uint k = 0; k < gl_NumSubgroups; ++k) {\n            if (k == gl_SubgroupID) {\n                subgroup_base = total;\n            }\n            total += subgroupBallotBitCount(ballots_sh[k]);\n        }\n        barrier();\n\n        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);\n        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 
+ idx < (ic + 1) * BN) {\n            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);\n        }\n        _ne1 += total;\n        iter &= 15;\n        if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {\n            break;\n        }\n    }\n    barrier();\n}\n#endif // MUL_MAT_ID_USE_SUBGROUPS\n#endif // MUL_MAT_ID\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mmq.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_shader_16bit_storage : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n\n#extension GL_EXT_integer_dot_product : require\n\n#ifdef FLOAT16\n#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require\n#endif\n\n#if defined(MUL_MAT_ID_USE_SUBGROUPS)\n#extension GL_KHR_shader_subgroup_basic : enable\n#extension GL_KHR_shader_subgroup_ballot : enable\n#endif\n\n#ifdef MUL_MAT_ID\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n#endif\n\n#include \"types.glsl\"\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\n#if defined(A_TYPE_PACKED16)\nlayout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};\n#endif\n#if defined(A_TYPE_PACKED32)\nlayout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};\n#endif\nlayout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};\n\n#ifdef MUL_MAT_ID\nlayout (binding = 3) readonly buffer IDS {int data_ids[];};\nlayout (binding = 4) readonly buffer Counts {int data_expert_count[];};\n#endif\n\nlayout (push_constant) uniform parameter\n{\n    uint M;\n    uint N;\n    uint K;\n    uint stride_a;\n    uint stride_b;\n    uint stride_d;\n\n    uint batch_stride_a;\n    uint batch_stride_b;\n    uint batch_stride_d;\n\n#ifdef MUL_MAT_ID\n    uint nei0;\n    uint nei1;\n    uint nbi1;\n    uint ne11;\n#else\n    uint base_work_group_z;\n    uint num_batches;\n    uint k_split;\n    uint ne02;\n    uint ne12;\n    uint broadcast2;\n    uint broadcast3;\n#endif\n} p;\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 64;\nlayout (constant_id = 1) const uint BM = 64;\nlayout (constant_id = 2) const uint BN = 64;\n// layout (constant_id = 3) const uint BK = 32;\nlayout 
(constant_id = 4) const uint WM = 32;\nlayout (constant_id = 5) const uint WN = 32;\nlayout (constant_id = 6) const uint WMITER = 2;\nlayout (constant_id = 7) const uint TM = 4;\nlayout (constant_id = 8) const uint TN = 2;\nlayout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat\nlayout (constant_id = 10) const uint WARP = 32;\n\n#define BK 32\n\n#include \"mul_mmq_shmem_types.glsl\"\n\n#ifdef MUL_MAT_ID\n#define BK_STEP 1\n#else\n#ifndef BK_STEP\n#define BK_STEP 4\n#endif\n#endif\n\n// Shared memory cache\nshared block_a_cache buf_a[BM * BK_STEP];\nshared block_b_cache buf_b[BN * BK_STEP];\n// Register cache\nblock_a_cache cache_a[WMITER * TM];\nblock_b_cache cache_b;\n\n#define LOAD_VEC_A (4 * QUANT_R_MMQ)\n#define LOAD_VEC_B 16\n\n#define NUM_WARPS (BLOCK_SIZE / WARP)\n\n#include \"mul_mm_id_funcs.glsl\"\n#include \"mul_mmq_funcs.glsl\"\n\nvoid main() {\n    const uint ic = gl_WorkGroupID.y;\n\n#ifdef MUL_MAT_ID\n    const uint expert_idx = gl_WorkGroupID.z;\n    if (ic * BN >= data_expert_count[expert_idx]) {\n        return;\n    }\n#endif\n#ifdef NEEDS_INIT_IQ_SHMEM\n    init_iq_shmem(gl_WorkGroupSize);\n#endif\n\n#ifndef MUL_MAT_ID\n    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;\n\n    const uint i13 = batch_idx / p.ne12;\n    const uint i12 = batch_idx % p.ne12;\n\n    const uint i03 = i13 / p.broadcast3;\n    const uint i02 = i12 / p.broadcast2;\n\n    const uint batch_idx_a = i03 * p.ne02 + i02;\n#endif\n\n    const uint blocks_m = (p.M + BM - 1) / BM;\n    const uint ir = gl_WorkGroupID.x % blocks_m;\n    const uint ik = gl_WorkGroupID.x / blocks_m;\n\n    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);\n    const uint WSUBM = WM / WMITER;\n    const uint WSUBN = WN / WNITER;\n    const uint warp_i = gl_LocalInvocationID.x / WARP;\n\n    const uint tiw = gl_LocalInvocationID.x % WARP;\n\n    const uint tiwr = tiw % (WSUBM / TM);\n    const uint tiwc = tiw / (WSUBM / TM);\n\n    const uint warp_r = warp_i % 
(BM / WM);\n    const uint warp_c = warp_i / (BM / WM);\n\n    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);\n    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);\n    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);\n    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);\n\n    const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;\n    const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;\n\n#ifdef MUL_MAT_ID\n#ifdef MUL_MAT_ID_USE_SUBGROUPS\n    if (bitCount(p.nei0) == 1) {\n        load_row_ids(expert_idx, true, ic);\n    } else {\n        load_row_ids(expert_idx, false, ic);\n    }\n#else\n    _ne1 = 0;\n    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {\n        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {\n            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {\n                if (_ne1 >= ic * BN) {\n                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);\n                }\n                _ne1++;\n            }\n        }\n    }\n\n    barrier();\n#endif\n\n    // Workgroup has no work\n    if (ic * BN >= _ne1) return;\n#endif\n\n#ifdef MUL_MAT_ID\n    const uint start_k = 0;\n    const uint end_k = p.K;\n#else\n    const uint start_k = ik * p.k_split;\n    const uint end_k = min(p.K, (ik + 1) * p.k_split);\n#endif\n\n    uint pos_a_ib =\n#ifdef MUL_MAT_ID\n        expert_idx * (p.batch_stride_a / BK) +\n#else\n        batch_idx_a * (p.batch_stride_a / BK) +\n#endif\n        (ir * BM * p.stride_a + start_k) / BK;\n#ifdef MUL_MAT_ID\n    uint pos_b_ib = 0;\n#else\n    uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;\n#endif\n\n    ACC_TYPE sums[WMITER * TM * WNITER * TN];\n\n    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {\n        sums[i] = ACC_TYPE(0.0f);\n    }\n\n    for (uint block = start_k; block < end_k; block += BK * BK_STEP) {\n        [[unroll]] for (uint l = 0; loadc_a + l < 
BM; l += loadstride_a) {\n            const uint buf_ib = loadc_a + l;\n            const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;\n            const uint iqs = loadr_a;\n\n            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {\n                if (block + k_step * BK < end_k) {\n                    block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);\n                }\n            }\n        }\n        [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {\n            const uint buf_ib = loadc_b + l;\n\n#ifdef MUL_MAT_ID\n            const u16vec2 row_idx = row_ids[buf_ib];\n            const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;\n#else\n            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;\n#endif\n            const uint iqs = loadr_b;\n\n            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {\n                block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs, block + k_step * BK < end_k);\n            }\n        }\n\n        barrier();\n\n        pos_a_ib += BK_STEP;\n        pos_b_ib += BK_STEP;\n\n        for (uint k_step = 0; k_step < BK_STEP; k_step++) {\n            // Load from shared into cache\n            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {\n                [[unroll]] for (uint cr = 0; cr < TM; cr++) {\n                    const uint reg_ib = wsir * TM + cr;\n                    const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;\n\n                    block_a_to_registers(reg_ib, k_step * BM + buf_ib);\n                }\n            }\n\n            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {\n                [[unroll]] for (uint cc = 0; cc < TN; cc++) {\n                    const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;\n                    block_b_to_registers(ib);\n\n                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) 
{\n                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {\n                            const uint cache_a_idx = wsir * TM + cr;\n                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;\n\n                            sums[sums_idx] += mmq_dot_product(cache_a_idx);\n                        }\n                    }\n                }\n            }\n        }\n\n        barrier();\n    }\n\n    const uint dr = ir * BM + warp_r * WM;\n    const uint dc = ic * BN + warp_c * WN;\n\n#ifndef MUL_MAT_ID\n    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;\n#endif\n\n    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {\n        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {\n\n            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;\n            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;\n            [[unroll]] for (uint cc = 0; cc < TN; cc++) {\n#ifdef MUL_MAT_ID\n                const uint row_i = dc_warp + cc;\n                if (row_i >= _ne1) break;\n\n                const u16vec2 row_idx = row_ids[row_i - ic * BN];\n#endif // MUL_MAT_ID\n                [[unroll]] for (uint cr = 0; cr < TM; cr++) {\n                    const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;\n#ifdef MUL_MAT_ID\n                    if (dr_warp + cr < p.M) {\n                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);\n                    }\n#else\n                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {\n                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);\n                    }\n#endif // MUL_MAT_ID\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl",
    "content": "#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n\n#include \"types.glsl\"\n\n// Each iqs value maps to a 32-bit integer\n\n#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)\n// 2-byte loads for Q4_0 blocks (18 bytes)\n// 4-byte loads for Q4_1 blocks (20 bytes)\nvoid block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {\n#ifdef DATA_A_Q4_0\n    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],\n                                           data_a_packed16[ib].qs[iqs * 2 + 1]));\n\n    if (iqs == 0) {\n        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);\n    }\n#else // DATA_A_Q4_1\n    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];\n\n    if (iqs == 0) {\n        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);\n    }\n#endif\n}\n\nvoid block_a_to_registers(const uint reg_ib, const uint buf_ib) {\n    cache_a[reg_ib].dm = buf_a[buf_ib].dm;\n\n    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {\n        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];\n    }\n}\n\nACC_TYPE mmq_dot_product(const uint ib_a) {\n    int32_t q_sum = 0;\n    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {\n        const uint32_t vui = cache_a[ib_a].qs[iqs];\n        const i32vec2 qs_a = i32vec2( vui       & 0x0F0F0F0F,\n                                     (vui >> 4) & 0x0F0F0F0F);\n\n        const int32_t qs_b0 = cache_b.qs[iqs];\n        const int32_t qs_b1 = cache_b.qs[iqs + 4];\n\n        q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);\n        q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);\n    }\n\n#ifdef DATA_A_Q4_0\n    return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));\n#else // DATA_A_Q4_1\n    return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * 
float(cache_b.ds.y));\n#endif\n}\n#endif\n\n#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)\n// 2-byte loads for Q5_0 blocks (22 bytes)\n// 4-byte loads for Q5_1 blocks (24 bytes)\nvoid block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {\n#ifdef DATA_A_Q5_0\n    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],\n                                           data_a_packed16[ib].qs[iqs * 2 + 1]));\n\n    if (iqs == 0) {\n        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);\n        buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));\n    }\n#else // DATA_A_Q5_1\n    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];\n\n    if (iqs == 0) {\n        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);\n        buf_a[buf_ib].qh = data_a_packed32[ib].qh;\n    }\n#endif\n}\n\nvoid block_a_to_registers(const uint reg_ib, const uint buf_ib) {\n    cache_a[reg_ib].dm = buf_a[buf_ib].dm;\n    cache_a[reg_ib].qh = buf_a[buf_ib].qh;\n\n    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {\n        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];\n    }\n}\n\nACC_TYPE mmq_dot_product(const uint ib_a) {\n    int32_t q_sum = 0;\n    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {\n        const uint32_t vui = cache_a[ib_a].qs[iqs];\n        const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));\n        const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)\n                         | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)\n        const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)\n                         | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)\n\n        const int32_t qs_b0 = cache_b.qs[iqs];\n        const int32_t qs_b1 = cache_b.qs[iqs + 4];\n\n        q_sum += dotPacked4x8EXT(qs_a0, qs_b0);\n        q_sum += dotPacked4x8EXT(qs_a1, qs_b1);\n    }\n\n#ifdef DATA_A_Q5_0\n    return ACC_TYPE(float(cache_a[ib_a].dm) * 
(float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));\n#else // DATA_A_Q5_1\n    return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));\n#endif\n}\n#endif\n\n#if defined(DATA_A_Q8_0)\n// 2-byte loads for Q8_0 blocks (34 bytes)\nvoid block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {\n    buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],\n                                           data_a_packed16[ib].qs[iqs * 2 + 1]));\n\n    if (iqs == 0) {\n        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);\n    }\n}\n\nvoid block_a_to_registers(const uint reg_ib, const uint buf_ib) {\n    cache_a[reg_ib].dm = buf_a[buf_ib].dm;\n\n    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {\n        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];\n    }\n}\n\nACC_TYPE mmq_dot_product(const uint ib_a) {\n    int32_t q_sum = 0;\n    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {\n        const int32_t qs_a = cache_a[ib_a].qs[iqs];\n        const int32_t qs_b = cache_b.qs[iqs];\n\n        q_sum += dotPacked4x8EXT(qs_a, qs_b);\n    }\n\n    return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));\n}\n#endif\n\n#if defined(DATA_A_MXFP4)\n// 1-byte loads for mxfp4 blocks (17 bytes)\nvoid block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {\n    const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],\n                                      data_a[ib].qs[iqs * 4 + 1],\n                                      data_a[ib].qs[iqs * 4 + 2],\n                                      data_a[ib].qs[iqs * 4 + 3]));\n\n    const u8vec4 i_a0 = unpack8( qs       & 0x0F0F0F0F);\n    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);\n\n    buf_a[buf_ib].qs[iqs    ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));\n    buf_a[buf_ib].qs[iqs + 4] = 
pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));\n\n    if (iqs == 0) {\n        buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);\n    }\n}\n\nvoid block_a_to_registers(const uint reg_ib, const uint buf_ib) {\n    cache_a[reg_ib].d = buf_a[buf_ib].d;\n\n    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {\n        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];\n    }\n}\n\nACC_TYPE mmq_dot_product(const uint ib_a) {\n    int32_t q_sum = 0;\n    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {\n        const int32_t qs_a = cache_a[ib_a].qs[iqs];\n\n        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);\n    }\n\n    return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum));\n}\n#endif\n\n// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide\n// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants\n#if defined(DATA_A_Q2_K)\n// 4-byte loads for Q2_K blocks (84 bytes)\nvoid block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;\n\n    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);\n    const uint qs_shift = ((iqs_k % 32) / 8) * 2;\n\n    // Repack 4x4 quants into one int\n    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;\n    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;\n    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;\n    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;\n\n    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);\n\n    if (iqs == 0) {\n        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);\n        buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to 
#12147\n    }\n}\n\nvoid block_a_to_registers(const uint reg_ib, const uint buf_ib) {\n    cache_a[reg_ib].dm = buf_a[buf_ib].dm;\n    cache_a[reg_ib].scales = buf_a[buf_ib].scales;\n\n    [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {\n        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];\n    }\n}\n\nACC_TYPE mmq_dot_product(const uint ib_a) {\n    int32_t sum_d = 0;\n    int32_t sum_m = 0;\n\n    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {\n        const uint8_t scale = cache_a[ib_a].scales[iqs / 4];\n        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.\n        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);\n\n        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);\n        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);\n    }\n\n    return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));\n}\n#endif\n\n#if defined(DATA_A_Q3_K)\n// 2-byte loads for Q3_K blocks (110 bytes)\nvoid block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint hm_idx = iqs * QUANT_R_MMQ;\n    const uint iqs_k = (ib % 8) * 8 + hm_idx;\n\n    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);\n    const uint qs_shift = ((iqs_k % 32) / 8) * 2;\n    const uint hm_shift = iqs_k / 8;\n\n    // Repack 2x4 quants into one int\n    // Add the 3rd bit instead of subtracting it to allow packing the quants\n    // vec4 for unpack8 used due to #12147\n    const i8vec2 vals00 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2        ] >> qs_shift) & uint16_t(0x0303)))).xy |\n                          unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2    ] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;\n    const i8vec2 vals01 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1    ] >> qs_shift) & 
uint16_t(0x0303)))).xy |\n                          unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;\n    const i8vec2 vals10 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2    ] >> qs_shift) & uint16_t(0x0303)))).xy |\n                          unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;\n    const i8vec2 vals11 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3    ] >> qs_shift) & uint16_t(0x0303)))).xy |\n                          unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;\n    buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |\n                           (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);\n\n    if (iqs == 0) {\n        const uint is = iqs_k / 4;\n        const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8      ) / 2] >> (4 * (is / 8))) & 0x0F0F) |\n                                                     (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147\n\n        buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));\n    }\n}\n\nvoid block_a_to_registers(const uint reg_ib, const uint buf_ib) {\n    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;\n\n    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {\n        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];\n    }\n}\n\nACC_TYPE mmq_dot_product(const uint ib_a) {\n    float result = 0.0;\n    int32_t q_sum = 0;\n\n    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {\n        // Subtract 4 from the quants to correct the 3rd bit offset\n        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));\n\n        
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);\n    }\n    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);\n    q_sum = 0;\n\n    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {\n        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));\n\n        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);\n    }\n    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);\n\n    return ACC_TYPE(float(cache_b.ds.x) * result);\n}\n#endif\n\n#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)\n// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)\nvoid block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;\n\n    const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);\n    const uint qs_shift = ((iqs_k % 16) / 8) * 4;\n\n    // Repack 2x4 quants into one int\n#if defined(DATA_A_Q4_K)\n    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F;\n    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;\n\n    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);\n#else // defined(DATA_A_Q5_K)\n    const uint qh_idx = iqs * QUANT_R_MMQ;\n    const uint qh_shift = iqs_k / 8;\n\n    buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |\n                                   (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));\n#endif\n\n    if (iqs == 0) {\n        // Scale index\n        const uint is = iqs_k / 8;\n        u8vec2 scale_dm;\n        if (is < 4) {\n            scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);\n        } else {\n            scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),\n                              (data_a[ib_k].scales[is+4] >>  4) | 
((data_a[ib_k].scales[is  ] & 0xC0) >> 2));\n        }\n\n        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));\n    }\n}\n\nvoid block_a_to_registers(const uint reg_ib, const uint buf_ib) {\n    cache_a[reg_ib].dm = buf_a[buf_ib].dm;\n\n    [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {\n        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];\n    }\n}\n\nACC_TYPE mmq_dot_product(const uint ib_a) {\n    int32_t q_sum = 0;\n\n    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {\n#if defined(DATA_A_Q4_K)\n        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);\n#else // defined(DATA_A_Q5_K)\n        const int32_t qs_a = cache_a[ib_a].qs[iqs];\n#endif\n\n        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);\n    }\n\n    return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));\n}\n#endif\n\n#if defined(DATA_A_Q6_K)\n// 2-byte loads for Q6_K blocks (210 bytes)\nvoid block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {\n    const uint ib_k = ib / 8;\n    const uint iqs_k = (ib % 8) * 8 + iqs;\n\n    const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;\n    const uint ql_shift = ((iqs_k % 32) / 16) * 4;\n\n    const uint qh_idx = (iqs_k / 32) * 8 + iqs;\n    const uint qh_shift = ((iqs_k % 32) / 8) * 2;\n\n    const i8vec2 vals00 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2    ] >> ql_shift) & uint16_t(0x0F0F))).xy |\n                          unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2    ] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);\n    const i8vec2 vals01 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))).xy |\n                          unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);\n    buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, 
vals00.y, vals01.x, vals01.y));\n\n    if (iqs == 0) {\n        const uint is = iqs_k / 4;\n        const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;\n\n        buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));\n    }\n}\n\nvoid block_a_to_registers(const uint reg_ib, const uint buf_ib) {\n    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;\n\n    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {\n        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];\n    }\n}\n\nACC_TYPE mmq_dot_product(const uint ib_a) {\n    float result = 0.0;\n    int32_t q_sum = 0;\n\n    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {\n        const int32_t qs_a = cache_a[ib_a].qs[iqs];\n\n        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);\n    }\n    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);\n    q_sum = 0;\n\n    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {\n        const int32_t qs_a = cache_a[ib_a].qs[iqs];\n\n        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);\n    }\n    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);\n\n    return ACC_TYPE(float(cache_b.ds.x) * result);\n}\n#endif\n\nvoid block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {\n    if (is_in_bounds) {\n        const uint ib_outer = ib / 4;\n        const uint ib_inner = ib % 4;\n\n        if (iqs == 0) {\n            buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);\n        }\n\n        const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];\n        buf_b[buf_ib].qs[iqs * 4    ] = values.x;\n        buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;\n        buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;\n        buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;\n    } else {\n        if (iqs == 0) {\n            buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);\n        }\n\n        buf_b[buf_ib].qs[iqs * 4    ] = 0;\n        buf_b[buf_ib].qs[iqs * 4 + 1] = 0;\n        
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;\n        buf_b[buf_ib].qs[iqs * 4 + 3] = 0;\n    }\n}\n\nvoid block_b_to_registers(const uint ib) {\n    cache_b.ds = buf_b[ib].ds;\n    [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {\n        cache_b.qs[iqs] = buf_b[ib].qs[iqs];\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl",
    "content": "#if defined(DATA_A_Q4_0)\n#define QUANT_R_MMQ 2\nstruct block_a_cache {\n    uint32_t qs[16/4];\n    FLOAT_TYPE dm;\n};\n#elif defined(DATA_A_Q4_1)\n#define QUANT_R_MMQ 2\nstruct block_a_cache {\n    uint32_t qs[16/4];\n    FLOAT_TYPE_VEC2 dm;\n};\n#elif defined(DATA_A_Q5_0)\n#define QUANT_R_MMQ 2\nstruct block_a_cache {\n    uint32_t qs[16/4];\n    uint32_t qh;\n    FLOAT_TYPE dm;\n};\n#elif defined(DATA_A_Q5_1)\n#define QUANT_R_MMQ 2\nstruct block_a_cache {\n    uint32_t qs[16/4];\n    uint32_t qh;\n    FLOAT_TYPE_VEC2 dm;\n};\n#elif defined(DATA_A_Q8_0)\n#define QUANT_R_MMQ 1\n// AMD likes 4, Intel likes 1 and Nvidia likes 2\n// #define BK_STEP 1\nstruct block_a_cache {\n    int32_t qs[32/4];\n    FLOAT_TYPE dm;\n};\n#elif defined(DATA_A_MXFP4)\n#define QUANT_R_MMQ 2\nstruct block_a_cache {\n    int32_t qs[8];\n    FLOAT_TYPE d;\n};\n#elif defined(DATA_A_Q2_K)\n#define QUANT_R_MMQ 4\nstruct block_a_cache {\n    uint32_t qs[2];\n    u8vec2 scales;\n    FLOAT_TYPE_VEC2 dm;\n};\n#elif defined(DATA_A_Q3_K)\n#define QUANT_R_MMQ 2\nstruct block_a_cache {\n    uint32_t qs[4];\n    FLOAT_TYPE_VEC2 d_scales;\n};\n#elif defined(DATA_A_Q4_K)\n#define QUANT_R_MMQ 2\nstruct block_a_cache {\n    uint32_t qs[4];\n    FLOAT_TYPE_VEC2 dm;\n};\n#elif defined(DATA_A_Q5_K)\n#define QUANT_R_MMQ 1\nstruct block_a_cache {\n    int32_t qs[8];\n    FLOAT_TYPE_VEC2 dm;\n};\n#elif defined(DATA_A_Q6_K)\n#define QUANT_R_MMQ 1\nstruct block_a_cache {\n    int32_t qs[8];\n    FLOAT_TYPE_VEC2 d_scales;\n};\n#endif\n\nstruct block_b_cache\n{\n    int32_t qs[8];\n    FLOAT_TYPE_VEC2 ds;\n};\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/multi_add.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_16bit_storage : require\n#extension GL_EXT_nonuniform_qualifier : enable\n#extension GL_EXT_control_flow_attributes : require\n#if ADD_RMS\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_basic : enable\n#endif\n\n#include \"rte.glsl\"\n#include \"types.glsl\"\n#include \"utils.glsl\"\n\nlayout (push_constant) uniform parameter2\n{\n    // shape for dst\n    uint ne20; uint ne21; uint ne22; uint ne23;\n\n    // strides for srcs+dst\n    uint nb[12][4];\n\n    uint rms_partials;\n} p;\n\n// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498\nlayout (binding = 0)  buffer A0 {A_TYPE data_a[];} a0;\nlayout (binding = 1)  buffer A1 {A_TYPE data_a[];} a1;\nlayout (binding = 2)  buffer A2 {A_TYPE data_a[];} a2;\nlayout (binding = 3)  buffer A3 {A_TYPE data_a[];} a3;\nlayout (binding = 4)  buffer A4 {A_TYPE data_a[];} a4;\nlayout (binding = 5)  buffer A5 {A_TYPE data_a[];} a5;\nlayout (binding = 6)  buffer A6 {A_TYPE data_a[];} a6;\nlayout (binding = 7)  buffer A7 {A_TYPE data_a[];} a7;\nlayout (binding = 8)  buffer A8 {A_TYPE data_a[];} a8;\nlayout (binding = 9)  buffer A9 {A_TYPE data_a[];} a9;\nlayout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;\nlayout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;\nlayout (binding = 0)  buffer D0 {D_TYPE data_d[];} d0;\nlayout (binding = 1)  buffer D1 {D_TYPE data_d[];} d1;\nlayout (binding = 2)  buffer D2 {D_TYPE data_d[];} d2;\nlayout (binding = 3)  buffer D3 {D_TYPE data_d[];} d3;\nlayout (binding = 4)  buffer D4 {D_TYPE data_d[];} d4;\nlayout (binding = 5)  buffer D5 {D_TYPE data_d[];} d5;\nlayout (binding = 6)  buffer D6 {D_TYPE data_d[];} d6;\nlayout (binding = 7)  buffer D7 {D_TYPE data_d[];} d7;\nlayout (binding = 8)  buffer D8 {D_TYPE data_d[];} d8;\nlayout (binding = 9)  buffer D9 {D_TYPE data_d[];} d9;\nlayout (binding = 10) buffer D10 {D_TYPE 
data_d[];} d10;\nlayout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;\nlayout (binding = 0, std430)  buffer PartialBuf0 {float partial_sums[];} partials0;\nlayout (binding = 1, std430)  buffer PartialBuf1 {float partial_sums[];} partials1;\nlayout (binding = 2, std430)  buffer PartialBuf2 {float partial_sums[];} partials2;\nlayout (binding = 3, std430)  buffer PartialBuf3 {float partial_sums[];} partials3;\nlayout (binding = 4, std430)  buffer PartialBuf4 {float partial_sums[];} partials4;\nlayout (binding = 5, std430)  buffer PartialBuf5 {float partial_sums[];} partials5;\nlayout (binding = 6, std430)  buffer PartialBuf6 {float partial_sums[];} partials6;\nlayout (binding = 7, std430)  buffer PartialBuf7 {float partial_sums[];} partials7;\nlayout (binding = 8, std430)  buffer PartialBuf8 {float partial_sums[];} partials8;\nlayout (binding = 9, std430)  buffer PartialBuf9 {float partial_sums[];} partials9;\nlayout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;\nlayout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;\n\nlayout(constant_id = 0) const uint num_srcs = 2;\n\nFLOAT_TYPE load_a(uint b, uint i) {\n    switch (b) {\n    case 0:  return FLOAT_TYPE(a0.data_a[i]);\n    case 1:  return FLOAT_TYPE(a1.data_a[i]);\n    case 2:  return FLOAT_TYPE(a2.data_a[i]);\n    case 3:  return FLOAT_TYPE(a3.data_a[i]);\n    case 4:  return FLOAT_TYPE(a4.data_a[i]);\n    case 5:  return FLOAT_TYPE(a5.data_a[i]);\n    case 6:  return FLOAT_TYPE(a6.data_a[i]);\n    case 7:  return FLOAT_TYPE(a7.data_a[i]);\n    case 8:  return FLOAT_TYPE(a8.data_a[i]);\n    case 9:  return FLOAT_TYPE(a9.data_a[i]);\n    case 10: return FLOAT_TYPE(a10.data_a[i]);\n    case 11: return FLOAT_TYPE(a11.data_a[i]);\n    default: return FLOAT_TYPE(0);\n    }\n}\n\nvoid store_d(uint b, uint i, FLOAT_TYPE v) {\n    switch (b) {\n    case 0:  d0.data_d[i] = D_TYPE(v); break;\n    case 1:  d1.data_d[i] = D_TYPE(v); break;\n    case 2:  
d2.data_d[i] = D_TYPE(v); break;\n    case 3:  d3.data_d[i] = D_TYPE(v); break;\n    case 4:  d4.data_d[i] = D_TYPE(v); break;\n    case 5:  d5.data_d[i] = D_TYPE(v); break;\n    case 6:  d6.data_d[i] = D_TYPE(v); break;\n    case 7:  d7.data_d[i] = D_TYPE(v); break;\n    case 8:  d8.data_d[i] = D_TYPE(v); break;\n    case 9:  d9.data_d[i] = D_TYPE(v); break;\n    case 10: d10.data_d[i] = D_TYPE(v); break;\n    case 11: d11.data_d[i] = D_TYPE(v); break;\n    default: break;\n    }\n}\n\nvoid store_partial(uint b, uint i, float v) {\n    switch (b) {\n    case 0:  partials0.partial_sums[i] = v; break;\n    case 1:  partials1.partial_sums[i] = v; break;\n    case 2:  partials2.partial_sums[i] = v; break;\n    case 3:  partials3.partial_sums[i] = v; break;\n    case 4:  partials4.partial_sums[i] = v; break;\n    case 5:  partials5.partial_sums[i] = v; break;\n    case 6:  partials6.partial_sums[i] = v; break;\n    case 7:  partials7.partial_sums[i] = v; break;\n    case 8:  partials8.partial_sums[i] = v; break;\n    case 9:  partials9.partial_sums[i] = v; break;\n    case 10: partials10.partial_sums[i] = v; break;\n    case 11: partials11.partial_sums[i] = v; break;\n    default: break;\n    }\n}\n\nuint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {\n    return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];\n}\n\nuint dst_idx(uint i00, uint i01, uint i02, uint i03) {\n    uint nb20 = p.nb[num_srcs][0];\n    uint nb21 = p.nb[num_srcs][1];\n    uint nb22 = p.nb[num_srcs][2];\n    uint nb23 = p.nb[num_srcs][3];\n    return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20;\n}\n\nuint get_idx() {\n    return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n}\n\nconst uint num_threads = 256;\n\nlayout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;\n\n#if ADD_RMS\n// XXX TODO this could be sized based on the number of subgroups, but that's not a compile-time constant\nshared FLOAT_TYPE sumsh[num_threads];\n#endif\n\nvoid main() {\n    uint idx = get_idx();\n    uint orig_idx = idx;\n\n    uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;\n\n    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation\n    const uint num_iter = 2;\n\n    FLOAT_TYPE sum_sq = 0;\n\n    [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n        if (idx >= ne) {\n            continue;\n        }\n        uint i00, i01, i02, i03;\n        get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);\n\n        FLOAT_TYPE sum = FLOAT_TYPE(0);\n        [[unroll]] for (uint s = 0; s < num_srcs; ++s) {\n            sum += load_a(s, src_idx(s, i00, i01, i02, i03));\n        }\n        sum_sq += sum*sum;\n        store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);\n\n        idx += num_threads;\n    }\n\n#if ADD_RMS\n    if (p.rms_partials != 0) {\n        // reduce the sum within each subgroup, then across subgroups\n        const uint NumSubgroups = num_threads / gl_SubgroupSize;\n        sum_sq = subgroupAdd(sum_sq);\n        if (gl_SubgroupInvocationID == 0) {\n            sumsh[gl_SubgroupID] = sum_sq;\n        }\n        barrier();\n        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {\n            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {\n                sum_sq += sumsh[gl_SubgroupID + s];\n                sumsh[gl_SubgroupID] = sum_sq;\n            }\n            barrier();\n        }\n\n        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {\n            store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq);\n        }\n    }\n#endif\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/neg.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n    data_d[i] = D_TYPE(-float(data_a[i]));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/norm.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#define BLOCK_SIZE 512\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nshared vec2 sum[BLOCK_SIZE];\n\nvoid main() {\n    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n    const uint tid = gl_LocalInvocationID.x;\n\n    sum[tid] = vec2(0.0f, 0.0f);\n\n    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {\n        const float xi = float(data_a[row*p.KX + col]);\n        sum[tid].x += xi;\n        sum[tid].y += xi * xi;\n    }\n\n    // sum up partial sums and write back result\n    barrier();\n    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            sum[tid] += sum[tid + s];\n        }\n        barrier();\n    }\n\n    const float mean = sum[0].x / p.KX;\n    const float var = sum[0].y / p.KX - mean * mean;\n    const float inv_std = inversesqrt(var + p.param1);\n\n    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {\n        data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) buffer X {A_TYPE x[];};\nlayout (binding = 1) readonly buffer G {A_TYPE grad[];};\nlayout (binding = 2) buffer GM {A_TYPE gradm[];};\nlayout (binding = 3) buffer GV {A_TYPE gradv[];};\nlayout (binding = 4) readonly buffer P {float params[7];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float alpha  = params[0];\n    const float beta1  = params[1];\n    const float beta2  = params[2];\n    const float eps    = params[3];\n    const float wd     = params[4];\n    const float beta1h = params[5];\n    const float beta2h = params[6];\n\n    const float gi = grad[i];\n    const float gmi = gradm[i]*beta1 +    gi*(1.0f - beta1);\n    const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);\n\n    gradm[i] = gmi;\n    gradv[i] = gvi;\n\n    const float mh =      gmi*beta1h;\n    const float vh = sqrt(gvi*beta2h) + eps;\n\n    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) buffer X {A_TYPE data_x[];};\nlayout (binding = 1) readonly buffer G {A_TYPE data_grad[];};\nlayout (binding = 2) readonly buffer P {float data_params[2];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float alpha = data_params[0];\n    const float keep = 1.f - alpha * data_params[1];\n\n    data_x[i] = data_x[i] * keep - alpha * data_grad[i];\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/pad.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n\nlayout (push_constant) uniform parameter\n{\n    uint ne;\n    uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;\n    uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;\n    uint misalign_offsets;\n    uint circular;\n\n    uint lp0; uint rp0;\n    uint lp1; uint rp1;\n    uint lp2; uint rp2;\n    uint lp3; uint rp3;\n} p;\n\nuint get_aoffset() { return p.misalign_offsets >> 16; }\nuint get_doffset() { return p.misalign_offsets & 0xFFFF; }\n\nuint wrap_around(int coord, uint size) {\n    return (uint(coord + int(size))) % size; // add size to avoid issues with negative\n}\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const uint i3 = idx / (p.ne12*p.ne11*p.ne10);\n    const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;\n    const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10);\n    const uint i2_offset = i2*p.ne11*p.ne10;\n    const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;\n    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;\n\n    const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;\n    const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;\n\n    if (p.circular != 0u) {\n        const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00);\n        const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01);\n        const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02);\n        const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03);\n        const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + 
ci0*p.nb00;\n        data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]);\n    } else {\n        const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&\n                             i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&\n                             i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&\n                             i3 >= p.lp3 && i3 < p.ne13 - p.rp3;\n        data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);\n    }\n\n\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/pool2d.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n\n#extension GL_EXT_shader_16bit_storage : require\n\nlayout(push_constant) uniform parameter {\n    uint IW; uint IH;\n    uint OW; uint OH;\n    uint OC;\n    uint pelements;\n    uint op;\n    int k0; int k1;\n    int s0; int s1;\n    int p0; int p1;\n} p;\n\n#define BLOCK_SIZE 512\n#define FLT_MAX 3.402823466e+38F\n#define OP_POOL_MAX 0u\n#define OP_POOL_AVG 1u\n\nlayout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout(binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout(binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint idx = gl_GlobalInvocationID.x;\n    if (idx >= p.pelements) {\n        return;\n    }\n\n    const uint O_HW = p.OW * p.OH;\n\n    const uint nc = idx / O_HW;\n    const uint cur_oh = (idx % O_HW) / p.OW;\n    const uint cur_ow = (idx % O_HW) % p.OW;\n\n    const int start_h = int(cur_oh) * p.s0 - p.p0;\n    const uint bh = max(start_h, 0);\n    const uint eh = min(start_h + p.k0, p.IH);\n\n    const int start_w = int(cur_ow) * p.s1 - p.p1;\n    const uint bw = max(start_w, 0);\n    const uint ew = min(start_w + p.k1, p.IW);\n\n    const float scale = 1.0 / float(p.k0 * p.k1);\n    float res;\n\n    if (p.op == OP_POOL_AVG) {\n        res = 0.0;\n    } else if (p.op == OP_POOL_MAX) {\n        res = -FLT_MAX;\n    } else {\n        return;\n    }\n\n    #pragma unroll\n    for (uint i = bh; i < eh; i++) {\n        #pragma unroll\n        for (uint j = bw; j < ew; j++) {\n            const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]);\n\n            if (p.op == OP_POOL_AVG) {\n                res += cur * scale;\n            } else if (p.op == OP_POOL_MAX) {\n                res = max(res, cur);\n            }\n        }\n    }\n\n    data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res;\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : require\n#extension GL_EXT_shader_16bit_storage : require\n\n#ifdef USE_SUBGROUPS\n#extension GL_KHR_shader_subgroup_basic : require\n#extension GL_KHR_shader_subgroup_clustered : require\n\n#define INVOCATION_ID gl_SubgroupInvocationID.x\n#else\n#define INVOCATION_ID gl_LocalInvocationID.x\n#endif\n\nlayout (push_constant) uniform parameter\n{\n    uint ne;\n    uint num_blocks;\n} p;\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const uint GROUP_SIZE = 32;\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {vec4 data_a[];};\n#ifndef QBLOCK_X4\nlayout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};\n#else\nlayout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};\n#endif\n\n#ifndef USE_SUBGROUPS\nshared float shmem[GROUP_SIZE];\n#endif\n\nvoid quantize(const uint wgid) {\n    const uint tid = INVOCATION_ID;\n\n    // Each thread handles a vec4, so 8 threads handle a block\n    const uint blocks_per_group = GROUP_SIZE / 8;\n\n    const uint block_in_wg = tid / 8;\n\n    const uint ib = wgid * blocks_per_group + block_in_wg;\n    const uint iqs = tid % 8;\n\n#ifdef QBLOCK_X4\n    const uint ibx4_outer = ib / 4;\n    const uint ibx4_inner = ib % 4;\n\n    const uint required_x4_blocks = (p.ne + 127) / 128;\n    if (ibx4_outer >= required_x4_blocks) {\n        return;\n    }\n#endif\n\n    const uint a_idx = ib * 8 + iqs;\n\n    vec4 vals = a_idx < p.ne / 4 ? 
data_a[a_idx] : vec4(0.0f);\n    const vec4 abs_vals = abs(vals);\n\n    // Find absolute max for each block\n    const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));\n#ifndef USE_SUBGROUPS\n    shmem[tid] = thread_max;\n    barrier();\n    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {\n        if (iqs < s) {\n            shmem[tid] = max(shmem[tid], shmem[tid + s]);\n        }\n        barrier();\n    }\n\n    const float amax = shmem[block_in_wg * 8];\n#else\n    const float amax = subgroupClusteredMax(thread_max, 8);\n#endif\n\n    const float d = amax / 127.0;\n    const float d_inv = d != 0.0 ? 1.0 / d : 0.0;\n    vals = round(vals * d_inv);\n\n#ifndef QBLOCK_X4\n    data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));\n#else\n    data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals)));\n#endif\n\n#ifndef USE_SUBGROUPS\n    barrier();\n#endif\n\n    // Calculate the sum for each block\n    const float thread_sum = vals.x + vals.y + vals.z + vals.w;\n#ifndef USE_SUBGROUPS\n    shmem[tid] = thread_sum;\n    barrier();\n    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {\n        if (iqs < s) {\n            shmem[tid] += shmem[tid + s];\n        }\n        barrier();\n    }\n#else\n    const float sum = subgroupClusteredAdd(thread_sum, 8);\n#endif\n    if (iqs == 0) {\n#ifndef USE_SUBGROUPS\n        const float sum = shmem[tid];\n#endif\n\n#ifndef QBLOCK_X4\n        data_b[ib].ds = f16vec2(vec2(d, sum * d));\n#else\n        data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));\n#endif\n    }\n}\n\nvoid main() {\n    uint wgid = gl_WorkGroupID.x;\n    while (wgid < p.num_blocks) {\n        quantize(wgid);\n        wgid += gl_NumWorkGroups.x;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/reglu.comp",
    "content": "#version 450\n\n#include \"glu_head.glsl\"\n\nfloat op(float a, float b) {\n    return max(a, 0.0f) * b;\n}\n\n#include \"glu_main.glsl\"\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/relu.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    data_d[i] = D_TYPE(max(float(data_a[i]), 0));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/repeat.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nuint src0_idx_mod(uint idx) {\n    const uint i13 = idx / (p.ne12*p.ne11*p.ne10);\n    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;\n    const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);\n    const uint i12_offset = i12*p.ne11*p.ne10;\n    const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;\n    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;\n    return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00;\n}\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/repeat_back.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    // Destination multi-index (inlined dst_idx)\n    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);\n    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;\n    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);\n    const uint i12_offset = i12*p.ne11*p.ne10;\n    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);\n    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;\n    const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;\n\n    // Accumulate from sources: sum every source element whose coordinates equal\n    // this destination coordinate modulo the destination extents (ne1x).\n    // Source reads include get_aoffset(), like the other generic_unary_head\n    // shaders (repeat, roll, sin), so a non-zero source offset is honored just\n    // as get_doffset() is on the write below.\n    const uint a_offset = get_aoffset();\n    A_TYPE acc = A_TYPE(0);\n    for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {\n        for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {\n            for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {\n                for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {\n                    acc += data_a[a_offset + i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];\n                }\n            }\n        }\n    }\n\n    data_d[get_doffset() + d_idx] = D_TYPE(acc);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rms_norm.comp",
    "content": "#version 450\n\n#include \"generic_binary_head.glsl\"\n#include \"types.glsl\"\n\n#if RMS_NORM_ROPE_FUSION\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) readonly buffer B {B_TYPE data_b[];};\n\n// data is passed from rms_norm -> rope through shared memory.\n// rms_norm calls this data_d, rope calls this rope_data_a.\n// Binding 2 is not used\nshared FLOAT_TYPE rope_data_a[1024];\n#define data_d rope_data_a\n\nlayout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};\nlayout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};\nlayout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};\nlayout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows\n\n#include \"rope_params.glsl\"\n#include \"rope_funcs.glsl\"\n\n#define GGML_ROPE_TYPE_NORMAL 0\n#define GGML_ROPE_TYPE_NEOX   2\n#define GGML_ROPE_TYPE_MROPE  8\n#define GGML_ROPE_TYPE_VISION 24\n\n#endif\n\n#extension GL_EXT_control_flow_attributes : enable\n#define BLOCK_SIZE 512\n\nlayout (constant_id = 1) const bool do_multiply = false;\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nshared FLOAT_TYPE sumsh[BLOCK_SIZE];\n\nvoid rms_norm(uint num_iters) {\n    const uint ncols     = p.ne00;\n    const uint nrows     = gl_NumWorkGroups.x;\n    const uint nchannels = gl_NumWorkGroups.y;\n\n    const uint row       = gl_WorkGroupID.x;\n    const uint channel   = gl_WorkGroupID.y;\n    const uint samp      = gl_WorkGroupID.z;\n    const uint tid       = gl_LocalInvocationID.x;\n\n    const uint stride_row       = p.nb01;\n    const uint stride_channel   = p.nb02;\n    const uint stride_sample    = p.nb03;\n\n    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();\n    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();\n#if RMS_NORM_ROPE_FUSION\n    // Per-row offset in shared memory\n    uint32_t d_offset = 0;\n#else\n    
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();\n#endif\n    FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp\n\n    [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {\n        FLOAT_TYPE xi = FLOAT_TYPE(0);\n        if (col < ncols) {\n            xi = FLOAT_TYPE(data_a[a_offset + col]);\n        }\n        sum += xi * xi;\n    }\n\n    sumsh[tid] = sum;\n    // sum up partial sums and write back result\n    barrier();\n    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            sum += sumsh[tid + s];\n            sumsh[tid] = sum;\n        }\n        barrier();\n    }\n    sum = sumsh[0];\n\n    const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);\n    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));\n\n    if (do_multiply) {\n        if (ncols > p.ne10) {\n            [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {\n                if (col >= ncols) {\n                    continue;\n                }\n                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));\n            }\n        } else {\n            [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {\n                if (col >= ncols) {\n                    continue;\n                }\n                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));\n            }\n        }\n    } else {\n        [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {\n            if (col >= ncols) {\n                continue;\n            }\n            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));\n        }\n    }\n#if RMS_NORM_ROPE_FUSION\n    barrier();\n    rope_params rp = p.rope;\n    for 
(uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {\n        if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {\n            rope_neox(t, row, channel, samp, rp);\n        } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {\n            rope_norm(t, row, channel, samp, rp);\n        }\n    }\n#endif\n}\n\nvoid main() {\n    // instantiate the rms_norm function for several different\n    // dimensions, to allow loop unrolling\n    uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;\n    if (num_blocks > 32) {\n        rms_norm(num_blocks);\n    } else if (num_blocks > 16) {\n        rms_norm(32);\n    } else if (num_blocks > 12) {\n        rms_norm(16);\n    } else if (num_blocks > 10) {\n        rms_norm(12);\n    } else if (num_blocks > 8) {\n        rms_norm(10);\n    } else if (num_blocks > 4) {\n        rms_norm(8);\n    } else if (num_blocks == 4) {\n        rms_norm(4);\n    } else if (num_blocks == 3) {\n        rms_norm(3);\n    } else if (num_blocks == 2) {\n        rms_norm(2);\n    } else if (num_blocks == 1) {\n        rms_norm(1);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#define BLOCK_SIZE 512\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer G {A_TYPE data_a[];};\nlayout (binding = 1) readonly buffer X {B_TYPE data_b[];};\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};\n\nshared FLOAT_TYPE sum_xx[BLOCK_SIZE];\nshared FLOAT_TYPE sum_xg[BLOCK_SIZE];\n\nvoid main() {\n    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n    const uint tid = gl_LocalInvocationID.x;\n\n    // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5\n\n    // partial sums for thread in warp\n    sum_xx[tid] = FLOAT_TYPE(0.0f);\n    sum_xg[tid] = FLOAT_TYPE(0.0f);\n\n    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {\n        const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]);\n        const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]);\n        sum_xx[tid] += xi * xi;\n        sum_xg[tid] += xi * gi;\n    }\n\n    // sum up partial sums and write back result\n    barrier();\n    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            sum_xx[tid] += sum_xx[tid + s];\n            sum_xg[tid] += sum_xg[tid + s];\n        }\n        barrier();\n    }\n\n    const FLOAT_TYPE eps = FLOAT_TYPE(p.param1);\n    const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX);\n    const FLOAT_TYPE scale_g = inversesqrt(mean + eps);\n    const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps);\n\n    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {\n        data_d[row*p.KX + col] = D_TYPE(\n            scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) +\n            scale_x * FLOAT_TYPE(data_b[row*p.KX + col]));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp",
    "content": "#version 450\n\n#include \"generic_binary_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_basic : enable\n\n#define BLOCK_SIZE 128\n\nlayout (constant_id = 1) const bool do_multiply = false;\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};\n\nshared FLOAT_TYPE sumsh[BLOCK_SIZE];\n\nvoid main() {\n    const uint ncols     = p.ne00;\n    const uint nrows     = gl_NumWorkGroups.x;\n    const uint nchannels = gl_NumWorkGroups.y;\n\n    const uint row       = 0;\n    const uint channel   = gl_WorkGroupID.y;\n    const uint samp      = gl_WorkGroupID.z;\n    // The work is split across multiple workgroups in the x dimension. Each invocation\n    // processes one element\n    const uint tid       = gl_GlobalInvocationID.x;\n\n    const uint stride_row       = p.nb01;\n    const uint stride_channel   = p.nb02;\n    const uint stride_sample    = p.nb03;\n\n    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();\n    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();\n    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();\n\n    FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp\n\n    uint32_t num_partials = p.param3;\n    for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {\n        sum += partial_sums[i];\n    }\n    sum = subgroupAdd(sum);\n\n    uint col = tid;\n    if (col >= ncols) {\n        return;\n    }\n\n    const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);\n    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));\n\n    if (do_multiply) {\n        if (ncols > p.ne10) {\n            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset 
+ col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));\n        } else {\n            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));\n        }\n    } else {\n        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/roll.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nuint wrap_idx(int i, uint ne) {\n    if (i < 0) {\n        return i + ne;\n    } else if (i >= ne) {\n        return i - ne;\n    }\n    return i;\n}\n\nvoid main() {\n    const uint idx = get_idx();\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);\n    const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;\n    const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);\n    const uint i2_offset = i2*p.ne11*p.ne10;\n    const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);\n    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;\n\n    const uint p1 = floatBitsToUint(p.param1);\n    const uint p2 = floatBitsToUint(p.param2);\n    const int s0 = int(p1 >> 16)    - 0x8000;\n    const int s1 = int(p1 & 0xFFFF) - 0x8000;\n    const int s2 = int(p2 >> 16)    - 0x8000;\n    const int s3 = int(p2 & 0xFFFF) - 0x8000;\n\n    const uint i00 = wrap_idx(int(i0) - s0, p.ne10);\n    const uint i01 = wrap_idx(int(i1) - s1, p.ne11);\n    const uint i02 = wrap_idx(int(i2) - s2, p.ne12);\n    const uint i03 = wrap_idx(int(i3) - s3, p.ne13);\n\n    const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;\n    const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;\n\n    data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl",
    "content": "\nfloat rope_yarn_ramp(const float low, const float high, const uint i0) {\n    const float y = (i0 / 2 - low) / max(0.001f, high - low);\n    return 1.0f - min(1.0f, max(0.0f, y));\n}\n\nuint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {\n#if RMS_NORM_ROPE_FUSION\n    // Per-row offset in shared memory\n    const uint ix = i0;\n#else\n    const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;\n#endif\n    return ix;\n}\n\nvoid rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {\n    float mscale = p.attn_factor;\n    // Get n-d rotational scaling corrected for extrapolation\n    float theta_interp = p.freq_scale * theta_extrap;\n    float theta = theta_interp;\n    if (p.ext_factor != 0.0f) {\n        float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;\n        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;\n\n        // Get n-d magnitude scaling corrected for interpolation\n        mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);\n    }\n    // Backpropagation uses inverted rotation\n    if (p.is_back != 0) {\n        theta = -theta;\n    }\n    cos_theta = cos(theta) * mscale;\n    sin_theta = sin(theta) * mscale;\n}\n\nvoid rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {\n    if (i0 >= p.ne00) {\n        return;\n    }\n\n    uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;\n    const uint ix = rope_a_coord(i0, i1, i2, i3, p);\n\n    // Fusion optimization: ROPE + VIEW + SET_ROWS.\n    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.\n    if (p.set_rows_stride != 0) {\n        idst = i1*p.nb11 + i0;\n        idst += rope_data_i[i2].x * p.set_rows_stride;\n    }\n\n    if (i0 >= p.n_dims) {\n        rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]);\n        rope_data_d[idst + 1] = 
ROPE_D_TYPE(rope_data_a[ix + 1]);\n\n        return;\n    }\n\n    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);\n\n    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;\n\n    float cos_theta, sin_theta;\n    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);\n\n    const float x0 = float(rope_data_a[ix + 0]);\n    const float x1 = float(rope_data_a[ix + 1]);\n\n    rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);\n    rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);\n}\n\nvoid rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {\n    if (i0 >= p.ne00) {\n        return;\n    }\n\n    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;\n    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);\n\n    // Fusion optimization: ROPE + VIEW + SET_ROWS.\n    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.\n    if (p.set_rows_stride != 0) {\n        idst = i1*p.nb11 + i0/2;\n        idst += rope_data_i[i2].x * p.set_rows_stride;\n    }\n\n    if (i0 >= p.n_dims) {\n        rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);\n        rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);\n\n        return;\n    }\n\n    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);\n\n    const float freq_factor = p.has_ff != 0 ? 
rope_data_ff[i0/2] : 1.0f;\n\n    float cos_theta, sin_theta;\n    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);\n\n    const float x0 = float(rope_data_a[ix + 0]);\n    const float x1 = float(rope_data_a[ix + p.n_dims/2]);\n\n    rope_data_d[idst + 0]          = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);\n    rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);\n}\n\n\nvoid rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {\n    if (i0 >= p.ne00) {\n        return;\n    }\n\n    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;\n    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);\n\n    // Fusion optimization: ROPE + VIEW + SET_ROWS.\n    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.\n    if (p.set_rows_stride != 0) {\n        idst = i1*p.nb11 + i0/2;\n        idst += rope_data_i[i2].x * p.set_rows_stride;\n    }\n\n    if (i0 >= p.n_dims) {\n        rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);\n        rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);\n\n        return;\n    }\n\n    const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];\n    const int sec_w = p.sections[1] + p.sections[0];\n    const uint sector = (i0 / 2) % sect_dims;\n\n    float theta_base = 0.0;\n    if (p.is_imrope != 0) {\n        if (sector % 3 == 1 && sector < 3 * p.sections[1]) {\n            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);\n        } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {\n            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);\n        } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {\n            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);\n        } else {\n            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);\n        }\n    } 
else {\n        if (sector < p.sections[0]) {\n            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);\n        }\n        else if (sector >= p.sections[0] && sector < sec_w) {\n            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);\n        }\n        else if (sector >= sec_w && sector < sec_w + p.sections[2]) {\n            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);\n        }\n        else if (sector >= sec_w + p.sections[2]) {\n            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);\n        }\n    }\n\n    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;\n\n    float cos_theta, sin_theta;\n    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);\n\n    const float x0 = float(rope_data_a[ix + 0]);\n    const float x1 = float(rope_data_a[ix + p.n_dims/2]);\n\n    rope_data_d[idst + 0]          = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);\n    rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);\n}\n\nvoid rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {\n    if (i0 >= p.ne00) {\n        return;\n    }\n\n    const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;\n    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);\n\n    const int sect_dims = p.sections[0] + p.sections[1];\n    const int sec_w = p.sections[1] + p.sections[0];\n    const uint sector = (i0 / 2) % sect_dims;\n\n    float theta_base = 0.0;\n    if (sector < p.sections[0]) {\n        const uint p0 = sector;\n        theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);\n    }\n    else if (sector >= p.sections[0] && sector < sec_w) {\n        const uint p0 = sector - p.sections[0];\n        theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);\n    }\n\n    const float freq_factor = p.has_ff != 0 ? 
rope_data_ff[i0/2] : 1.0f;\n\n    float cos_theta, sin_theta;\n    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);\n\n    const float x0 = float(rope_data_a[ix + 0]);\n    const float x1 = float(rope_data_a[ix + p.n_dims]);\n\n    rope_data_d[idst + 0]        = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);\n    rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);\n}\n\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rope_head.glsl",
    "content": "#include \"types.glsl\"\n\n#extension GL_EXT_shader_16bit_storage : require\n\n#include \"rte.glsl\"\n#include \"rope_params.glsl\"\n\nlayout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];};\nlayout (binding = 1) readonly buffer Y {int rope_data_pos[];};\nlayout (binding = 2) readonly buffer Z {float rope_data_ff[];};\nlayout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];};\nlayout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows\n\n\nlayout (push_constant) uniform parameter {\n    rope_params pc;\n};\n\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rope_multi.comp",
    "content": "#version 450\n\n#include \"rope_head.glsl\"\n#include \"rope_funcs.glsl\"\n\nvoid main() {\n    const uint i0 = 2*gl_GlobalInvocationID.y;\n    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;\n    if (row >= pc.nrows) {\n        return;\n    }\n    const uint i3 = row / (pc.ne01*pc.ne02);\n    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;\n    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);\n\n    rope_multi(i0, i1, i2, i3, pc);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rope_neox.comp",
    "content": "#version 450\n\n#include \"rope_head.glsl\"\n#include \"rope_funcs.glsl\"\n\nvoid main() {\n    const uint i0 = 2*gl_GlobalInvocationID.y;\n    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;\n    if (row >= pc.nrows) {\n        return;\n    }\n    const uint i3 = row / (pc.ne01*pc.ne02);\n    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;\n    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);\n\n    rope_neox(i0, i1, i2, i3, pc);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rope_norm.comp",
    "content": "#version 450\n\n#include \"rope_head.glsl\"\n#include \"rope_funcs.glsl\"\n\nvoid main() {\n    const uint i0 = 2*gl_GlobalInvocationID.y;\n    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;\n    if (row >= pc.nrows) {\n        return;\n    }\n    const uint i3 = row / (pc.ne01*pc.ne02);\n    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;\n    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);\n\n    rope_norm(i0, i1, i2, i3, pc);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rope_params.glsl",
    "content": "#if !defined(GGML_ROPE_PARAMS)\n#define GGML_ROPE_PARAMS\n\n#include \"rte.glsl\"\n\nstruct rope_params {\n    uint rope_mode;\n    uint nrows;\n    uint n_dims;\n    float freq_scale;\n    float freq_base;\n    float ext_factor;\n    float attn_factor;\n    float corr_dims[2];\n    float theta_scale;\n    uint has_ff;\n    int sections[4];\n    uint is_imrope;\n    uint is_back;\n    uint set_rows_stride;\n\n    uint ne00;\n    uint ne01;\n    uint ne02;\n    uint nb01;\n    uint nb02;\n    uint nb03;\n    uint nb11;\n    uint nb12;\n    uint nb13;\n};\n\n#endif // !defined(GGML_ROPE_PARAMS)\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rope_vision.comp",
    "content": "#version 450\n\n#include \"rope_head.glsl\"\n#include \"rope_funcs.glsl\"\n\nvoid main() {\n    const uint i0 = 2*gl_GlobalInvocationID.y;\n    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;\n    if (row >= pc.nrows) {\n        return;\n    }\n    const uint i3 = row / (pc.ne01*pc.ne02);\n    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;\n    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);\n\n    rope_vision(i0, i1, i2, i3, pc);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/round.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    float result;\n    // Round halfway cases away from zero as roundf does.\n    if (x >= 0.0) {\n        result = floor(x + 0.5);\n    } else {\n        result = ceil(x - 0.5);\n    }\n    data_d[i] = D_TYPE(result);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/rte.glsl",
    "content": "\n#if RTE16\n#extension GL_EXT_spirv_intrinsics : enable\nspirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits\n#endif // RTE16\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/scale.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nconst uint num_threads = 128;\n\nlayout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    uint idx = get_idx();\n\n    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation\n    const uint num_iter = 4;\n\n    [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n        if (idx >= p.ne) {\n            continue;\n        }\n\n        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));\n        idx += num_threads;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/sgn.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    data_d[i] = D_TYPE(sign(float(data_a[i])));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/sigmoid.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n    data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i]))));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/silu.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float xi = float(data_a[i]);\n    data_d[i] = D_TYPE(xi / (1.0f + exp(-xi)));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/silu_back.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer G {A_TYPE data_g[];};\nlayout (binding = 1) readonly buffer X {B_TYPE data_x[];};\nlayout (binding = 2) writeonly buffer D {D_TYPE data_d[];};\n\n// Backward pass of SiLU: d[i] = g[i] * SiLU'(x[i]) for the first p.KX elements.\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    // Derivative of SiLU(x) = x * s(x), with s(x) = 1/(1+exp(-x)):\n    //   s + x*s*(1-s) = 1/(1+exp(-x)) + x*exp(-x)/(1+exp(-x))^2\n\n    const float xi = float(data_x[i]);\n    const float s = 1.0f / (1.0f + exp(-xi));\n    data_d[i] = D_TYPE(float(data_g[i]) * (s + xi * s * (1.0f - s)));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/sin.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/soft_max.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout (push_constant) uniform parameter\n{\n    uint KX;\n    uint KY;\n    uint ne00;\n    uint ne01;\n    uint ne02;\n    uint ne12;\n    uint ne13;\n    uint nb11;\n    uint nb12;\n    uint nb13;\n    float scale;\n    float max_bias;\n    float m0;\n    float m1;\n    uint n_head_log2;\n    uint nrows_x;\n    uint has_sinks;\n} p;\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const uint BLOCK_SIZE = 32;\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) readonly buffer Y {B_TYPE data_b[];};\nlayout (binding = 2) readonly buffer Z {float data_c[];};\nlayout (binding = 3) buffer D {D_TYPE data_d[];};\n\nshared FLOAT_TYPE vals[BLOCK_SIZE];\n\n// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate\n// over all the columns. The main function tries to pass a constant here,\n// as if it were a template function, to allow unrolling.\nvoid soft_max(uint num_iters) {\n    const uint tid = gl_LocalInvocationID.x;\n    const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n\n    const uint32_t i03 = rowx / (p.ne01 * p.ne02);\n    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;\n    const uint32_t i01 = rowx % p.ne01;\n\n    uint rowy_start = 0;\n    if (p.KY > 0) {\n        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;\n    }\n\n    if (rowx >= p.nrows_x) {\n        return;\n    }\n\n    float slope = 1.0f;\n\n    // ALiBi\n    if (p.max_bias > 0.0f) {\n        const uint h = (rowx / p.ne01) % p.ne02; // head index\n\n        const float base = h < p.n_head_log2 ? p.m0 : p.m1;\n        const uint   exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;\n\n        slope = pow(base, exp);\n    }\n\n    // Find max\n    FLOAT_TYPE max_val = p.has_sinks == 0 ? 
uintBitsToFloat(0xFF800000) : data_c[i02];\n\n    // Cache values while we compute the max, so we don't need to read them\n    // again when we're ready to compute exp(x-max).\n    const uint DATA_CACHE_SIZE = 16;\n    FLOAT_TYPE data_cache[DATA_CACHE_SIZE];\n\n    [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {\n        const uint col = col0 + tid;\n\n        FLOAT_TYPE a = FLOAT_TYPE(0);\n        if (col < p.KX) {\n            a = data_a[rowx * p.KX + col];\n        }\n\n        FLOAT_TYPE b = FLOAT_TYPE(0);\n        if (p.KY > 0 && col < p.KX) {\n            b = data_b[rowy_start + col];\n        }\n\n        FLOAT_TYPE v = a * p.scale + slope * b;\n\n        if (col < p.KX) {\n            max_val = max(max_val, v);\n        }\n\n        if (idx < DATA_CACHE_SIZE) {\n            data_cache[idx] = v;\n        }\n    }\n\n    // reduce across the workgroup\n    vals[tid] = max_val;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            vals[tid] = max(vals[tid], vals[tid + s]);\n        }\n        barrier();\n    }\n\n    max_val = vals[0];\n    barrier();\n\n    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);\n\n    // Compute sum{exp(x - max)}\n    [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {\n        const uint col = col0 + tid;\n\n        if (col >= p.KX) {\n            break;\n        }\n\n        // compute exp(a*scale+b*slope), add it to sum, and cache the new value\n        // in data_cache if possible.\n        const uint i = rowx * p.KX + col;\n        FLOAT_TYPE val;\n        if (idx < DATA_CACHE_SIZE) {\n            val = exp(data_cache[idx] - max_val);\n        } else {\n            val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);\n        }\n        sum += val;\n        if (idx < DATA_CACHE_SIZE) {\n            data_cache[idx] = val;\n        } else {\n            data_d[i] = D_TYPE(val);\n        }\n    }\n\n    // reduce across the workgroup\n    vals[tid] = sum;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            vals[tid] += vals[tid + s];\n        }\n        barrier();\n    }\n    sum = vals[0];\n\n    if (p.has_sinks != 0) {\n        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));\n    }\n\n    FLOAT_TYPE rcpdivisor = 1.0/sum;\n\n    [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {\n        const uint col = col0 + tid;\n\n        if (col >= p.KX) {\n            continue;\n        }\n\n        if (idx < DATA_CACHE_SIZE) {\n            data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor);\n        } else {\n            data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);\n        }\n    }\n}\n\nvoid main() {\n    // instantiate the soft_max function for several different\n    // dimensions, to allow loop unrolling\n    uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE;\n    if (num_blocks > 32) {\n        soft_max(num_blocks);\n    } else if (num_blocks > 16) {\n        soft_max(32);\n    } else if (num_blocks > 8) {\n        soft_max(16);\n    } else if (num_blocks > 4) {\n        soft_max(8);\n    } else if (num_blocks == 4) {\n        soft_max(4);\n    } else if (num_blocks == 3) {\n        soft_max(3);\n    } else if (num_blocks == 2) {\n        soft_max(2);\n    } else if (num_blocks == 1) {\n        soft_max(1);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/soft_max_back.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : enable\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const uint BLOCK_SIZE = 32;\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\n// In this shader Y = softmax(X) and X is not provided as input.\n\nlayout (binding = 0) readonly buffer G {A_TYPE data_g[];};\nlayout (binding = 1) readonly buffer Y {B_TYPE data_y[];};\nlayout (binding = 2) buffer D {D_TYPE data_d[];};\n\nshared FLOAT_TYPE sum_yg[BLOCK_SIZE];\n\nvoid main() {\n    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n    const uint tid = gl_LocalInvocationID.x;\n\n    if (row >= p.KY) {\n        return;\n    }\n\n    FLOAT_TYPE scale = p.param1;\n\n    // partial sums for thread in warp\n    sum_yg[tid] = FLOAT_TYPE(0.0f);\n\n    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {\n        const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]);\n        const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]);\n        sum_yg[tid] += yi * gi;\n    }\n\n    // sum up partial sums and write back result\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            sum_yg[tid] += sum_yg[tid + s];\n        }\n        barrier();\n    }\n\n    const FLOAT_TYPE dot_yg = sum_yg[0];\n\n    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {\n        data_d[row*p.KX + col] = D_TYPE(scale\n            * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg)\n            * FLOAT_TYPE(data_y[row*p.KX + col]));\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp",
    "content": "#version 450\n\n#include \"soft_max_large_common.glsl\"\n\nvoid main() {\n    const uint tid = gl_LocalInvocationID.x;\n    const uint rowx = gl_WorkGroupID.y;\n    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;\n\n    const uint32_t i03 = rowx / (p.ne01 * p.ne02);\n    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;\n    const uint32_t i01 = rowx % p.ne01;\n\n    uint rowy_start = 0;\n    if (p.KY > 0) {\n        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;\n    }\n\n    if (rowx >= p.nrows_x) {\n        return;\n    }\n\n    float slope = get_slope(rowx);\n\n    // Find max\n    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];\n\n    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {\n        const uint col = col0 + tid;\n\n        FLOAT_TYPE a = FLOAT_TYPE(0);\n        if (col < p.KX) {\n            a = data_a[rowx * p.KX + col];\n        }\n\n        FLOAT_TYPE b = FLOAT_TYPE(0);\n        if (p.KY > 0 && col < p.KX) {\n            b = data_b[rowy_start + col];\n        }\n\n        FLOAT_TYPE v = a * p.scale + slope * b;\n\n        if (col < p.KX) {\n            max_val = max(max_val, v);\n        }\n    }\n\n    // reduce across the workgroup\n    vals[tid] = max_val;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            vals[tid] = max(vals[tid], vals[tid + s]);\n        }\n        barrier();\n    }\n\n    if (tid == 0) {\n        max_val = vals[0];\n        data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp",
    "content": "#version 450\n\n#include \"soft_max_large_common.glsl\"\n\n// Second pass of the multi-workgroup softmax: reduce the per-workgroup maxima in\n// data_m to the row max, write exp(x - max) for this workgroup's columns to data_d,\n// and store this workgroup's partial sum of those values in data_s.\nvoid main() {\n    const uint tid = gl_LocalInvocationID.x;\n    const uint rowx = gl_WorkGroupID.y;\n    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;\n\n    const uint32_t i03 = rowx / (p.ne01 * p.ne02);\n    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;\n    const uint32_t i01 = rowx % p.ne01;\n\n    uint rowy_start = 0;\n    if (p.KY > 0) {\n        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;\n    }\n\n    if (rowx >= p.nrows_x) {\n        return;\n    }\n\n    float slope = get_slope(rowx);\n\n    // Find max\n    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];\n\n    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {\n        if (i + tid < gl_NumWorkGroups.x) {\n            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);\n        }\n    }\n\n    // reduce across the workgroup. Each step must fold the running partial in\n    // vals[tid]; using the thread-local max_val here would drop partials merged in\n    // earlier steps and yield a max over only a subset of threads.\n    vals[tid] = max_val;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            vals[tid] = max(vals[tid], vals[tid + s]);\n        }\n        barrier();\n    }\n\n    max_val = vals[0];\n    barrier();\n\n    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);\n\n    // Compute sum{exp(x - max)}\n    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {\n        const uint col = col0 + tid;\n\n        if (col >= p.KX) {\n            break;\n        }\n\n        // compute exp(a*scale+b*slope), add it to sum\n        const uint i = rowx * p.KX + col;\n        FLOAT_TYPE val;\n        val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);\n        sum += val;\n        data_d[i] = D_TYPE(val);\n    }\n\n    // reduce across the workgroup\n    vals[tid] = sum;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            vals[tid] += vals[tid + s];\n        }\n        barrier();\n    }\n\n    if (tid == 0) {\n        sum = vals[0];\n        data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp",
    "content": "#version 450\n\n#include \"soft_max_large_common.glsl\"\n\nshared FLOAT_TYPE sumsh[BLOCK_SIZE];\n\n// Third pass of the multi-workgroup softmax: combine the per-workgroup maxima\n// (data_m) and partial sums (data_s), add the sink term, and scale data_d by 1/sum.\nvoid main() {\n    const uint tid = gl_LocalInvocationID.x;\n    const uint rowx = gl_WorkGroupID.y;\n    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;\n\n    const uint32_t i03 = rowx / (p.ne01 * p.ne02);\n    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;\n    const uint32_t i01 = rowx % p.ne01;\n\n    uint rowy_start = 0;\n    if (p.KY > 0) {\n        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;\n    }\n\n    if (rowx >= p.nrows_x) {\n        return;\n    }\n\n    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];\n    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);\n\n    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {\n        if (i + tid < gl_NumWorkGroups.x) {\n            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);\n            sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];\n        }\n    }\n\n    // reduce across the workgroup. The max must be a true tree reduction over\n    // vals[tid] so it equals the row max used when data_d/data_s were written;\n    // the sink term below is only consistent with those values if it does.\n    vals[tid] = max_val;\n    sumsh[tid] = sum;\n    barrier();\n    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {\n        if (tid < s) {\n            vals[tid] = max(vals[tid], vals[tid + s]);\n            sumsh[tid] += sumsh[tid + s];\n        }\n        barrier();\n    }\n\n    max_val = vals[0];\n    sum = sumsh[0];\n\n    if (p.has_sinks != 0) {\n        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));\n    }\n\n    FLOAT_TYPE rcpdivisor = 1.0/sum;\n\n    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {\n        const uint col = col0 + tid;\n\n        if (col >= p.KX) {\n            continue;\n        }\n\n        data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl",
    "content": "#extension GL_EXT_control_flow_attributes : enable\n\nlayout (push_constant) uniform parameter\n{\n    uint KX;\n    uint KY;\n    uint ne00;\n    uint ne01;\n    uint ne02;\n    uint ne12;\n    uint ne13;\n    uint nb11;\n    uint nb12;\n    uint nb13;\n    float scale;\n    float max_bias;\n    float m0;\n    float m1;\n    uint n_head_log2;\n    uint nrows_x;\n    uint has_sinks;\n} p;\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const uint BLOCK_SIZE = 128;\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\nlayout(constant_id = 1) const uint num_iters = 4;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) readonly buffer Y {B_TYPE data_b[];};\nlayout (binding = 2) readonly buffer Z {float data_c[];};\nlayout (binding = 3) buffer D {D_TYPE data_d[];};\nlayout (binding = 4) buffer M {float data_m[];};\nlayout (binding = 5) buffer S {float data_s[];};\n\nshared FLOAT_TYPE vals[BLOCK_SIZE];\n\nfloat get_slope(uint rowx) {\n    float slope = 1.0f;\n\n    // ALiBi\n    if (p.max_bias > 0.0f) {\n        const uint h = (rowx / p.ne01) % p.ne02; // head index\n\n        const float base = h < p.n_head_log2 ? p.m0 : p.m1;\n        const uint   exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;\n\n        slope = pow(base, exp);\n    }\n\n    return slope;\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/softplus.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    const float result = (x > 20.0f) ? x : log(1.0f + exp(x));\n    data_d[i] = D_TYPE(result);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/solve_tri.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\nlayout (constant_id = 1) const uint N = 64;\nlayout (constant_id = 2) const uint K = 32;\nlayout (constant_id = 3) const uint BATCH_N = 32;\n\nlayout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in;\n\nuint a_base, b_base, x_base;\n\nFLOAT_TYPE get_a(uint r, uint c) {\n    return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]);\n}\n\nFLOAT_TYPE get_b(uint r, uint c) {\n    return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]);\n}\n\nvoid store_x(uint r, uint c, FLOAT_TYPE v) {\n    data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);\n}\n\nshared FLOAT_TYPE shA[BATCH_N * N];\nshared FLOAT_TYPE shB[BATCH_N * K];\n\nvoid main() {\n    const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n    const uint tid = gl_LocalInvocationID.x;\n\n    if (batch >= p.ne02 * p.ne03) {\n        return;\n    }\n\n    const uint i3 = batch / p.ne22;\n    const uint i2 = batch % p.ne22;\n    a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03;\n    b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;\n    x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;\n\n    FLOAT_TYPE X[N];\n\n    // Loop over batches of rows\n    [[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) {\n        const uint cur_N = min(BATCH_N, N - row_base);\n\n        // Load the A matrix batch into shA\n        [[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) {\n            uint idx = i + tid;\n            if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) {\n                shA[idx] = get_a(row_base + idx / N, idx % N);\n            }\n        }\n        // Load the B matrix batch into shB\n        [[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) {\n            uint idx = i + tid;\n            if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) {\n                shB[idx] = 
get_b(row_base + idx / K, idx % K);\n            }\n        }\n        barrier();\n\n        // Each thread solves one column\n        if (tid < K) {\n            [[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) {\n                uint r = row_base + row_offset;\n                FLOAT_TYPE b = shB[row_offset * K + tid];\n                // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]\n                [[unroll]] for (int c = 0; c < r; ++c) {\n                    b -= shA[row_offset * N + c] * X[c];\n                }\n                FLOAT_TYPE x = b / shA[row_offset * N + r];\n                X[r] = x;\n                store_x(r, tid, x);\n            }\n        }\n        barrier();\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/sqrt.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/square.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);\n    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/ssm_conv.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : require\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const uint BLOCK_SIZE = 32;\nlayout(constant_id = 1) const uint TOKENS_PER_WG = 16;\n\nlayout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n\nlayout(binding = 0) readonly buffer Src0 { float src0[]; };\nlayout(binding = 1) readonly buffer Src1 { float src1[]; };\nlayout(binding = 2) buffer Dst { float dst[]; };\n\nlayout(push_constant) uniform PushConstants {\n    uint nb01; uint nb02;\n    uint nb11;\n    uint dst_nb0; uint dst_nb1; uint dst_nb2;\n    uint nc; uint ncs; uint nr; uint n_t; uint n_s;\n};\n\nvoid main() {\n    const uint i1 = gl_GlobalInvocationID.x;\n    const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y;\n    const uint i3 = gl_WorkGroupID.z;\n\n    if (i1 >= nr || i2 >= n_t || i3 >= n_s) {\n        return;\n    }\n\n    const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);\n    const uint src1_base = i1 * (nb11 / 4);\n\n    float sum = 0.0;\n\n    if (nc == 4) {\n        sum = dot(\n            vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]),\n            vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3])\n        );\n    } else {\n        [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {\n            sum += src0[src0_base + i0] * src1[src1_base + i0];\n        }\n    }\n\n    const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;\n    dst[dst_idx] = sum;\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/ssm_scan.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : require\n#extension GL_KHR_shader_subgroup_basic : enable\n#if USE_SUBGROUP_ADD\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#endif\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const uint D_STATE = 128;\nlayout(constant_id = 1) const uint SUBGROUP_SIZE = 32;\n\nconst uint32_t c_factor = D_STATE / SUBGROUP_SIZE;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout(binding = 0) readonly buffer Src0 { float s0[]; };\nlayout(binding = 1) readonly buffer Src1 { float x[]; };\nlayout(binding = 2) readonly buffer Src2 { float dt[]; };\nlayout(binding = 3) readonly buffer Src3 { float A[]; };\nlayout(binding = 4) readonly buffer Src4 { float B[]; };\nlayout(binding = 5) readonly buffer Src5 { float C[]; };\nlayout(binding = 6) readonly buffer Src6 { int ids[]; };\nlayout(binding = 7) buffer Dst { float d[]; };\n\nlayout(push_constant) uniform PushConstants {\n    uint nb02; uint nb03; uint nb12; uint nb13;\n    uint nb21; uint nb22; uint nb31;\n    uint nb42; uint nb43; uint nb52; uint nb53;\n    uint s_off;\n    uint n_head;\n    uint d_head;\n    uint n_group;\n    uint n_tok;\n};\n\nfloat softplus(float x) {\n    if (x <= 20.0) {\n        return log(1.0 + exp(x));\n    } else {\n        return x;\n    }\n}\n\n#if !USE_SUBGROUP_ADD\nshared float temp[D_STATE];\n#endif\n\nvoid main() {\n    const uint subgroup = gl_SubgroupID;\n    const uint lane     = gl_SubgroupInvocationID;\n    const uint tid      = gl_SubgroupID * SUBGROUP_SIZE + lane;\n    const uint subgroup_idx = gl_WorkGroupID.x  * c_factor + subgroup;\n\n    const uint head_idx =  subgroup_idx / d_head;\n    const uint head_off = (subgroup_idx % d_head) * 4;\n    const uint seq_idx  = gl_WorkGroupID.y;\n\n    const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;\n    const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;\n    const 
uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;\n    const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;\n    const uint A_base_idx = (head_idx * nb31) / 4;\n    const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;\n    const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;\n    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;\n    const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;\n\n    const uint stride_x = nb12 / 4;\n    const uint stride_dt = nb21 / 4;\n    const uint stride_B = nb42 / 4;\n    const uint stride_C = nb52 / 4;\n    const uint stride_y = n_head * d_head;\n\n    float state[c_factor];\n\n    [[unroll]] for (uint j = 0; j < c_factor; j++) {\n        state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];\n    }\n\n    float a = A[A_base_idx];\n\n    for (uint i = 0; i < n_tok; i++) {\n        float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);\n\n        float state_sum = 0.0f;\n\n        const float dA   = exp(dt_soft_plus * a);\n        const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;\n        [[unroll]] for (uint j = 0; j < c_factor; j++) {\n            float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];\n            float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];\n            state[j] = (state[j] * dA) + (B_val * x_dt);\n            state_sum += state[j] * C_val;\n        }\n\n#if USE_SUBGROUP_ADD\n        state_sum = subgroupAdd(state_sum);\n#else\n        temp[tid] = state_sum;\n        barrier();\n        [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {\n            if (lane < s) {\n                temp[tid] += temp[tid + s];\n            }\n            barrier();\n        }\n        // get the value from lane 0\n        state_sum = temp[subgroup * SUBGROUP_SIZE];\n        barrier();\n#endif\n\n        if (lane == 0) {\n            d[y_base_idx + i * stride_y] = 
state_sum;\n        }\n    }\n\n    // write back the state\n    [[unroll]]\n    for (int j = 0; j < c_factor; j++) {\n        d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/step.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    data_d[i] = D_TYPE(x >= 0.0f ? 1.0f : 0.0f);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/sub.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_16bit_storage : require\n\n#include \"types.glsl\"\n#include \"generic_binary_head.glsl\"\n\nconst uint num_threads = 256;\n\nlayout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    uint idx = get_idx();\n\n    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation\n    const uint num_iter = 2;\n\n    [[unroll]] for (uint i = 0; i < num_iter; ++i) {\n        if (idx >= p.ne) {\n            continue;\n        }\n        uint i00, i01, i02, i03;\n        get_indices(idx, i00, i01, i02, i03);\n\n        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));\n\n        idx += num_threads;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/sum_rows.comp",
    "content": "#version 450\n\n#include \"types.glsl\"\n#include \"sum_rows.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nlayout (constant_id = 0) const uint BLOCK_SIZE = 32;\n\nshared FLOAT_TYPE tmp[BLOCK_SIZE];\n\nvoid main() {\n    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;\n    const uint col = gl_LocalInvocationID.x;\n    const float weight = p.weight;\n\n    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);\n    const uint i03_offset = i03 * p.ne01*p.ne02;\n    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);\n    const uint i01 = row - i03_offset - i02*p.ne01;\n\n    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;\n    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;\n\n    tmp[col] = FLOAT_TYPE(0.0);\n\n    for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {\n        tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);\n    }\n\n    barrier();\n    [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {\n        if (col < s) {\n            tmp[col] += tmp[col + s];\n        }\n        barrier();\n    }\n\n    if (col == 0) {\n        data_d[dst_idx] = D_TYPE(tmp[0] * weight);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/sum_rows.glsl",
    "content": "\n// vk_op_sum_rows_push_constants\nlayout (push_constant) uniform parameter\n{\n    uint n_cols;\n    uint ne01, ne02;\n    uint nb01, nb02, nb03;\n    uint nb11, nb12, nb13;\n    float weight;\n    uint misalign_offsets;\n    uint ne0_12mp, ne0_12L;\n    uint ne0_1mp, ne0_1L;\n} p;\n\nuint get_aoffset() { return p.misalign_offsets >> 16; }\nuint get_doffset() { return p.misalign_offsets & 0xFFFF; }\n\n// see init_fastdiv_values in ggml-vulkan.cpp\nuint fastdiv(uint n, uint mp, uint L) {\n    uint msbs, lsbs;\n    // msbs = mulhi(n, mp)\n    umulExtended(n, mp, msbs, lsbs);\n    return (msbs + n) >> L;\n}\n\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/swiglu.comp",
    "content": "#version 450\n\n#include \"glu_head.glsl\"\n\nfloat op(float a, float b) {\n    return a / (1.0f + exp(-a)) * b;\n}\n\n#include \"glu_main.glsl\"\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp",
    "content": "#version 450\n\n#include \"glu_head.glsl\"\n\nfloat op(float a, float b) {\n    float xi = min(a, p.limit);\n    float gi = max(min(b, p.limit), -p.limit);\n\n    float out_glu = xi / (1.0f + exp(-xi * p.alpha));\n    out_glu = out_glu * (1.0f + gi);\n    return out_glu;\n}\n\n#include \"glu_main.glsl\"\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/tanh.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n    data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp",
    "content": "#version 450\n\n#extension GL_EXT_shader_16bit_storage : require\n\nlayout (push_constant) uniform parameter\n{\n    uint nb1;\n    uint dim;\n    uint max_period;\n} p;\n\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n#define BLOCK_SIZE 256\n\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_WorkGroupID.y;\n    const uint j = gl_GlobalInvocationID.x;\n    const uint d_offset = i * p.nb1;\n\n    const uint half_dim = p.dim / 2;\n\n    if (p.dim % 2 != 0 && j == half_dim) {\n        data_d[d_offset + 2 * half_dim] = 0.f;\n    }\n\n    if (j >= half_dim) {\n        return;\n    }\n\n    const float timestep = float(data_a[i]);\n    const float freq = float(exp(-log(p.max_period) * j / half_dim));\n    const float arg = timestep * freq;\n    data_d[d_offset + j] = D_TYPE(cos(arg));\n    data_d[d_offset + j + half_dim] = D_TYPE(sin(arg));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/topk_argsort.comp",
    "content": "#version 450\n#extension GL_EXT_control_flow_attributes : enable\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const int BLOCK_SIZE = 1024;\nlayout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\n// Input can either be the source (A) or intermediate values (S).\n// Similarly, output can be either destination (D) or intermediate values (T).\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 0) readonly buffer S {ivec2 data_s[];};\nlayout (binding = 1) writeonly buffer D {int data_d[];};\nlayout (binding = 1) writeonly buffer T {ivec2 data_t[];};\n\nlayout (push_constant) uniform parameter {\n    uint orig_ncols;\n    uint ncols_input;\n    uint ncols_output;\n    uint k;\n    uint nrows;\n    uint first_pass;\n    uint last_pass;\n} p;\n\n// pairs of (gid, value)\nshared ivec2 dst_row[BLOCK_SIZE];\n\nvoid topk(bool needs_bounds_check, const uint row) {\n    const int col = int(gl_LocalInvocationID.x);\n\n    // initialize indices\n    if (gl_GlobalInvocationID.x < p.ncols_input) {\n        if (p.first_pass != 0) {\n            const uint row_offset = row * p.ncols_input;\n            dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));\n        } else {\n            const uint row_offset = row * p.ncols_input;\n            dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];\n        }\n    } else {\n        dst_row[col] = ivec2(p.orig_ncols, 0);\n    }\n    barrier();\n\n    if (p.k == 1) {\n        // Fast path for single output - just do a max reduction\n        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {\n            if (col < s) {\n                ivec2 a = dst_row[col];\n                ivec2 b = dst_row[col + s];\n                // Compare the values as floats, like the bitonic path below does.\n                // The raw int bit patterns of negative floats sort in reverse order,\n                // so an integer compare would pick the wrong element for negatives.\n                if (a.x >= p.orig_ncols ||\n                    b.x < p.orig_ncols && intBitsToFloat(b.y) > intBitsToFloat(a.y)) {\n                    dst_row[col] = b;\n                }\n            }\n            barrier();\n        }\n    } else {\n        // bitonic sort on this group of elements\n        uint num_outer_loop_iters = NCOLS_PADDED_LOG2;\n        for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {\n            uint num_inner_loop_iters = outer_idx + 1;\n            for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {\n                const int ixj = int(col ^ j);\n\n                int idx_0 = (col & k) == 0 ? col : ixj;\n                int idx_1 = (col & k) == 0 ? ixj : col;\n\n                ivec2 sh_idx_0 = dst_row[idx_0];\n                ivec2 sh_idx_1 = dst_row[idx_1];\n                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;\n                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;\n\n                if ((idx_0_oob ||\n                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {\n                    dst_row[idx_0] = sh_idx_1;\n                    dst_row[idx_1] = sh_idx_0;\n                }\n\n                barrier();\n            }\n        }\n    }\n\n    if (col < p.k) {\n        if (p.last_pass != 0) {\n            if (gl_GlobalInvocationID.x < p.ncols_input) {\n                const uint row_offset = row * p.k;\n                data_d[row_offset + col] = dst_row[col].x;\n            }\n        } else {\n            if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {\n                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;\n                data_t[row_offset + col] = dst_row[col];\n            }\n        }\n    }\n}\n\nvoid main() {\n    // Fast path for fully occupied workgroups\n    if ((p.ncols_input % BLOCK_SIZE) == 0) {\n        uint row = gl_WorkGroupID.y;\n        while (row < p.nrows) {\n            topk(false, row);\n            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n        }\n    } else {\n        uint row = gl_WorkGroupID.y;\n        while (row < p.nrows) {\n            topk(true, row);\n            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/topk_moe.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : require\n#extension GL_KHR_shader_subgroup_basic : enable\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_shuffle : enable\n\n#include \"types.glsl\"\n\n#define GATING_FUNC_SOFTMAX 0\n#define GATING_FUNC_SIGMOID 1\n#define GATING_FUNC_SOFTMAX_WEIGHT 2\n\nlayout (push_constant) uniform parameter\n{\n    uint n_rows;\n    uint n_experts_push;\n    uint n_expert_used;\n    float clamp_min;\n    float clamp_max;\n    uint gating_func;\n    uint has_bias;\n    uint with_norm;\n    float output_scale;\n    float output_bias;\n};\n\nlayout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;\n\nlayout(constant_id = 0) const uint WARP_SIZE = 32;\nlayout(constant_id = 1) const uint n_experts_spec = 512;\nlayout(constant_id = 2) const bool nexperts_use_push = false;\n\nuint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;\n\n#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))\n\nconst uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);\n\nlayout (binding = 0, std430) readonly buffer Logits {float logits[];};\nlayout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};\nlayout (binding = 2, std430) writeonly buffer Weights {float weights[];};\nlayout (binding = 3, std430) writeonly buffer Ids {uint ids[];};\n\nconst float INFINITY = 1.0 / 0.0;\n\n// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.\nvoid softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {\n    float max_val = -INFINITY;\n\n    [[unroll]]\n    for (int i = 0; i < experts_per_thread; i++) {\n        const uint idx       = lane + i * WARP_SIZE;\n        const bool is_active = !use_limit || (idx < limit);\n        if (is_active) {\n            max_val = max(max_val, vals[i]);\n        }\n    }\n\n    max_val = subgroupMax(max_val);\n\n    float sum = 
0.f;\n\n    [[unroll]]\n    for (int i = 0; i < experts_per_thread; i++) {\n        const uint idx       = lane + i * WARP_SIZE;\n        const bool is_active = !use_limit || (idx < limit);\n        if (is_active) {\n            const float val = exp(vals[i] - max_val);\n            vals[i]         = val;\n            sum += val;\n        } else {\n            vals[i] = 0.f;\n        }\n    }\n\n    sum = subgroupAdd(sum);\n\n    const float inv_sum = 1.0f / sum;\n\n    [[unroll]]\n    for (int i = 0; i < experts_per_thread; i++) {\n        const uint idx       = lane + i * WARP_SIZE;\n        const bool is_active = !use_limit || (idx < limit);\n        if (is_active) {\n            vals[i] *= inv_sum;\n        }\n    }\n}\n\nvoid main() {\n    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;\n    if (row >= n_rows) {\n        return;\n    }\n\n    const uint logits_offset = n_experts * row;\n    const uint bias_offset = 0; // 1D\n    const uint weights_offset = n_expert_used * row;\n    const uint ids_offset = n_experts * row;\n    const uint lane = gl_SubgroupInvocationID;\n\n    float probs[experts_per_thread];\n    [[unroll]]\n    for (int i = 0; i < experts_per_thread; i++) {\n        probs[i] = -INFINITY;\n    }\n\n    [[unroll]]\n    for (uint i = 0; i < n_experts; i += WARP_SIZE) {\n        const uint expert = i + lane;\n        probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;\n    }\n\n    if (gating_func == GATING_FUNC_SOFTMAX) {\n        softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);\n    } else if (gating_func == GATING_FUNC_SIGMOID) {\n        [[unroll]]\n        for (uint i = 0; i < n_experts; i += WARP_SIZE) {\n            const uint expert = i + lane;\n            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 
1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;\n        }\n    }\n\n    float selection_probs[experts_per_thread];\n    if (has_bias != 0) {\n        [[unroll]]\n        for (uint i = 0; i < n_experts; i += WARP_SIZE) {\n            const uint expert = i + lane;\n            selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;\n        }\n    } else {\n        [[unroll]]\n        for (int i = 0; i < experts_per_thread; i++) {\n            selection_probs[i] = probs[i];\n        }\n    }\n\n    // at this point, each thread holds a portion of softmax,\n    // we do the argmax reduce over n_expert_used, each time marking\n    // the expert weight as -inf to exclude from the next iteration\n\n    float wt_sum = 0.f;\n\n    float output_weights[experts_per_thread];\n\n    [[unroll]]\n    for (int i = 0; i < experts_per_thread; i++) {\n        output_weights[i] = 0.f;\n    }\n\n    for (int k = 0; k < n_expert_used; k++) {\n        float max_val    = probs[0];\n        float max_val_s  = selection_probs[0];\n        uint   max_expert = lane;\n\n        [[unroll]]\n        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {\n            const uint expert = i + lane;\n            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {\n                max_val    = probs[i / WARP_SIZE];\n                max_val_s  = selection_probs[i / WARP_SIZE];\n                max_expert = expert;\n            }\n        }\n\n        [[unroll]]\n        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {\n            const float val    = subgroupShuffleXor(max_val, mask);\n            const float val_s  = subgroupShuffleXor(max_val_s, mask);\n            const uint  expert = subgroupShuffleXor(max_expert, mask);\n            if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {\n                max_val    = 
val;\n                max_val_s  = val_s;\n                max_expert = expert;\n            }\n        }\n\n        if ((k & (WARP_SIZE - 1)) == lane) {\n            output_weights[k / WARP_SIZE] = max_val;\n        }\n\n        if ((max_expert & (WARP_SIZE - 1)) == lane) {\n            selection_probs[max_expert / WARP_SIZE] = -INFINITY;\n\n            ids[ids_offset + k] = max_expert;\n            wt_sum += max_val;\n        }\n    }\n\n    if (with_norm != 0) {\n        wt_sum              = subgroupAdd(wt_sum);\n        wt_sum              = clamp(wt_sum, clamp_min, clamp_max);\n        const float inv_sum = 1.0f / wt_sum;\n\n        [[unroll]]\n        for (uint i = 0; i < experts_per_thread; ++i) {\n            output_weights[i] *= inv_sum;\n        }\n    }\n\n    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {\n        softmax_warp_inplace(output_weights, n_expert_used, lane, true);\n    }\n\n    [[unroll]]\n    for (uint i = 0; i < experts_per_thread; ++i) {\n        uint idx = i * WARP_SIZE + lane;\n        if (idx < n_expert_used) {\n            weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp",
    "content": "#version 450\n#extension GL_EXT_control_flow_attributes : enable\n#extension GL_EXT_debug_printf : enable\n#extension GL_KHR_shader_subgroup_basic : enable\n#extension GL_KHR_shader_subgroup_ballot : enable\n#extension GL_KHR_shader_subgroup_arithmetic : enable\n#extension GL_KHR_shader_subgroup_shuffle : enable\n\n#include \"types.glsl\"\n\nlayout(constant_id = 0) const int BLOCK_SIZE = 1024;\nlayout(constant_id = 1) const int SUBGROUP_SIZE = 32;\nlayout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;\n\nlayout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;\n\n// Input can either be the source (A) or intermediate values (S).\n// Similarly, output can be either destination (D) or intermediate values (T).\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 0) readonly buffer S {ivec2 data_s[];};\nlayout (binding = 1) writeonly buffer D {int data_d[];};\nlayout (binding = 1) writeonly buffer T {ivec2 data_t[];};\n\nlayout (push_constant) uniform parameter {\n    uint orig_ncols;\n    uint ncols_input;\n    uint ncols_output;\n    uint k;\n    uint nrows;\n    uint first_pass;\n    uint last_pass;\n} p;\n\n// pairs of (gid, value)\nshared ivec2 dst_row[BLOCK_SIZE];\n\nshared int counts[SUBGROUP_SIZE];\nshared int sh_min_idx;\nshared uint sh_total;\nshared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];\nshared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];\n\n// Map float values to uint such that comparisons still work.\n// Positive values set the high bit, negative values are inverted.\n// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.\nuint f2ui(float x) {\n    uint y = floatBitsToUint(x);\n    if ((y & 0x80000000) != 0) {\n        y ^= ~0;\n    } else {\n        y |= 0x80000000;\n    }\n    return y;\n}\n\nvoid topk(const uint row) {\n    const int tid = int(gl_LocalInvocationID.x);\n\n    // initialize indices\n    if (gl_GlobalInvocationID.x < p.ncols_input) {\n        if (p.first_pass != 0) {\n            const uint row_offset = row * p.ncols_input;\n            dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));\n        } else {\n            const uint row_offset = row * p.ncols_input;\n            dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];\n        }\n    } else {\n        dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf\n    }\n    barrier();\n\n    if (p.k == 1) {\n        // Fast path for single output - just do a max reduction\n        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {\n            if (tid < s) {\n                ivec2 a = dst_row[tid];\n                ivec2 b = dst_row[tid + s];\n                // Compare through f2ui so the order matches the N-ary search below.\n                // The raw int bit patterns of negative floats sort in reverse order,\n                // so an integer compare would pick the wrong element for negatives.\n                if (a.x >= p.orig_ncols ||\n                    b.x < p.orig_ncols && f2ui(intBitsToFloat(b.y)) > f2ui(intBitsToFloat(a.y))) {\n                    dst_row[tid] = b;\n                }\n            }\n            barrier();\n        }\n    } else {\n        // Do an N-ary search to find the K-th largest value.\n        // We remap the float values to be comparable as unsigned integers,\n        // and split the range into SUBGROUP_SIZE smaller ranges (2^N, where\n        // N = SUBGROUP_SIZE_LOG2). Count how many values are in each range; if the\n        // K-th largest value is in the middle of one of these ranges then repeat\n        // and split again.\n\n        // Mask is the current set of bits we're searching. Shift is the LSB index.\n        int shift = 32 - SUBGROUP_SIZE_LOG2;\n        uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;\n\n        // The current range.\n        uint range_min = 0;\n        uint range_max = 0xFF800000;\n        // How many are above the current range, and how many we need to find.\n        uint total = 0;\n        uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);\n\n        while (mask != 0) {\n            barrier();\n            // Initialize bucket counts to zero.\n            if (tid < SUBGROUP_SIZE) {\n                counts[tid] = 0;\n            }\n            barrier();\n            // Count how many values are in each bucket.\n            if (tid < p.ncols_input) {\n                float y = intBitsToFloat(dst_row[tid].y);\n                uint fy = f2ui(y);\n                if (fy >= range_min && fy < range_max) {\n                    uint bucket = (fy & mask) >> shift;\n                    atomicAdd(counts[bucket], 1);\n                }\n            }\n            barrier();\n\n            // On the first subgroup, do a scan to count (from the top down) how\n            // many elements are in the top N buckets. Find the index of the first\n            // that is over the limit. Copy it to the other invocations through\n            // shared memory.\n            if (tid < SUBGROUP_SIZE) {\n                uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];\n                partial_sum = subgroupInclusiveAdd(partial_sum) + total;\n                uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));\n                if (tid == t) {\n                    sh_min_idx = int(SUBGROUP_SIZE - 1 - t);\n                    sh_total = partial_sum;\n                }\n            }\n            barrier();\n            int min_idx = sh_min_idx;\n            total = sh_total;\n\n            // Update the range, and break if we've found the K-th largest.\n            range_max = range_min + ((min_idx + 1) << shift);\n            range_min = range_min + (min_idx << shift);\n\n            if (total == p.k) {\n                break;\n            }\n            total -= counts[min_idx];\n            mask >>= SUBGROUP_SIZE_LOG2;\n            shift -= SUBGROUP_SIZE_LOG2;\n            if (shift < 0) {\n                shift = 0;\n            }\n        }\n\n        ivec2 v = dst_row[tid];\n\n        // We need to compact these values to the start of the dst_row array.\n        // Have each subgroup count how many items it'll store, so other\n        // subgroups can compute their base offset.\n        // Values strictly greater than range_min must be stored. For values equal\n        // to range_min, there can be ties and it's possible we'll need to store\n        // an arbitrary subset of them.\n        // If total == p.k, have a fast path where we don't need to handle ties.\n        if (total == p.k) {\n            bool top = f2ui(intBitsToFloat(v.y)) >= range_min;\n            uvec4 b = subgroupBallot(top);\n            uint bit_count = subgroupBallotBitCount(b);\n            if ((tid % SUBGROUP_SIZE) == 0) {\n                offset_partials[tid / SUBGROUP_SIZE] = bit_count;\n            }\n            barrier();\n\n            uint out_idx = 0;\n            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {\n                if (i < tid / SUBGROUP_SIZE) {\n                    out_idx += offset_partials[i];\n                }\n            }\n\n            uint bit_count_ex = subgroupBallotExclusiveBitCount(b);\n            if (top) {\n                // TODO: Copy directly to the output?\n                dst_row[out_idx + bit_count_ex] = v;\n            }\n        } else {\n            bool top = f2ui(intBitsToFloat(v.y)) > range_min;\n            bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;\n            uvec4 b_top = subgroupBallot(top);\n            uvec4 b_eq_min = subgroupBallot(eq_min);\n            uint bit_count_top = subgroupBallotBitCount(b_top);\n            uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);\n            if ((tid % SUBGROUP_SIZE) == 0) {\n                offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;\n                eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;\n            }\n            barrier();\n\n            uint out_idx = 0;\n            uint eq_min_base = 0;\n            uint eq_min_idx = 0;\n            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {\n                if (i < tid / SUBGROUP_SIZE) {\n                    out_idx += offset_partials[i];\n                    eq_min_idx += eq_min_partials[i];\n                }\n                eq_min_base += offset_partials[i];\n            }\n            // range_min values are stored at the end\n            eq_min_idx += eq_min_base;\n\n            uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);\n            uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);\n            if (top) {\n                // TODO: Copy directly to the output?\n                dst_row[out_idx + bit_count_ex_top] = v;\n            }\n            if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {\n                dst_row[eq_min_idx + bit_count_ex_eq_min] = v;\n            }\n        }\n\n        barrier();\n    }\n\n    if (tid < p.k) {\n        if (p.last_pass != 0) {\n            if (gl_GlobalInvocationID.x < p.ncols_input) {\n                const uint row_offset = row * p.k;\n                data_d[row_offset + tid] = dst_row[tid].x;\n            }\n        } else {\n            if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {\n                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;\n                data_t[row_offset + tid] = dst_row[tid];\n            }\n        }\n    }\n}\n\nvoid main() {\n    uint row = gl_WorkGroupID.y;\n    while (row < p.nrows) {\n        topk(row);\n        row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/tri.comp",
    "content": "#version 450\n\n#include \"rte.glsl\"\n#include \"types.glsl\"\n#include \"generic_unary_head.glsl\"\n\n#define GGML_TRI_TYPE_UPPER_DIAG 0\n#define GGML_TRI_TYPE_UPPER      1\n#define GGML_TRI_TYPE_LOWER_DIAG 2\n#define GGML_TRI_TYPE_LOWER      3\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nvoid main() {\n    const uint idx = get_idx();\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);\n    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;\n    const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);\n    const uint i02_offset = i02*p.ne01*p.ne00;\n    const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);\n    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;\n\n    int param = floatBitsToInt(p.param1);\n    bool pass = false;\n    switch (param) {\n    case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;\n    case GGML_TRI_TYPE_UPPER:      pass = i00 >  i01; break;\n    case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;\n    case GGML_TRI_TYPE_LOWER:      pass = i00 <  i01; break;\n    }\n\n    if (pass) {\n        const float val = float(data_a[get_aoffset() + src0_idx(idx)]);\n        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);\n    } else {\n        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/trunc.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    const float x = float(data_a[i]);\n    data_d[i] = D_TYPE(trunc(x));\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/types.glsl",
    "content": "#if !defined(GGML_TYPES_COMP)\n#define GGML_TYPES_COMP\n\n#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n#extension GL_EXT_shader_16bit_storage : require\n\n#if defined(DATA_A_F32)\n#define QUANT_K 1\n#define QUANT_R 1\n\n#if LOAD_VEC_A == 4\n#define A_TYPE vec4\n#elif LOAD_VEC_A == 8\n#define A_TYPE mat2x4\n#else\n#define A_TYPE float\n#endif\n#endif\n\n#if defined(DATA_A_F16)\n#define QUANT_K 1\n#define QUANT_R 1\n\n#if LOAD_VEC_A == 4\n#define A_TYPE f16vec4\n#elif LOAD_VEC_A == 8\n#define A_TYPE f16mat2x4\n#else\n#define A_TYPE float16_t\n#endif\n#endif\n\n#if defined(DATA_A_BF16)\n#define QUANT_K 1\n#define QUANT_R 1\n\n#if LOAD_VEC_A == 4\n#define A_TYPE u16vec4\n#elif LOAD_VEC_A == 8\n#error unsupported\n#else\n#define A_TYPE uint16_t\n#endif\n#endif\n\n#define QUANT_K_Q4_0 32\n#define QUANT_R_Q4_0 2\n\nstruct block_q4_0\n{\n    float16_t d;\n    uint8_t qs[16];\n};\nstruct block_q4_0_packed16\n{\n    float16_t d;\n    uint16_t qs[16/2];\n};\n\n#if defined(DATA_A_Q4_0)\n#define QUANT_K QUANT_K_Q4_0\n#define QUANT_R QUANT_R_Q4_0\n#define QUANT_AUXF 1\n#define A_TYPE block_q4_0\n#define A_TYPE_PACKED16 block_q4_0_packed16\n#define DATA_A_QUANT_LEGACY\n#endif\n\n#define QUANT_K_Q4_1 32\n#define QUANT_R_Q4_1 2\n\nstruct block_q4_1\n{\n    float16_t d;\n    float16_t m;\n    uint8_t qs[16];\n};\n\nstruct block_q4_1_packed16\n{\n    float16_t d;\n    float16_t m;\n    uint16_t qs[16/2];\n};\n\nstruct block_q4_1_packed32\n{\n    f16vec2 dm;\n    uint32_t qs[16/4];\n};\n\n#if defined(DATA_A_Q4_1)\n#define QUANT_K QUANT_K_Q4_1\n#define QUANT_R QUANT_R_Q4_1\n#define QUANT_AUXF 2\n#define A_TYPE block_q4_1\n#define A_TYPE_PACKED16 block_q4_1_packed16\n#define A_TYPE_PACKED32 block_q4_1_packed32\n#define 
DATA_A_QUANT_LEGACY\n#endif\n\n#define QUANT_K_Q5_0 32\n#define QUANT_R_Q5_0 2\n\nstruct block_q5_0\n{\n    float16_t d;\n    uint16_t qh[2];\n    uint8_t qs[16];\n};\n\nstruct block_q5_0_packed16\n{\n    float16_t d;\n    uint16_t qh[2];\n    uint16_t qs[16/2];\n};\n\n#if defined(DATA_A_Q5_0)\n#define QUANT_K QUANT_K_Q5_0\n#define QUANT_R QUANT_R_Q5_0\n#define QUANT_AUXF 1\n#define A_TYPE block_q5_0\n#define A_TYPE_PACKED16 block_q5_0_packed16\n#define DATA_A_QUANT_LEGACY\n#endif\n\n#define QUANT_K_Q5_1 32\n#define QUANT_R_Q5_1 2\n\nstruct block_q5_1\n{\n    float16_t d;\n    float16_t m;\n    uint qh;\n    uint8_t qs[16];\n};\n\nstruct block_q5_1_packed16\n{\n    float16_t d;\n    float16_t m;\n    uint qh;\n    uint16_t qs[16/2];\n};\n\nstruct block_q5_1_packed32\n{\n    f16vec2 dm;\n    uint qh;\n    uint32_t qs[16/4];\n};\n\n#if defined(DATA_A_Q5_1)\n#define QUANT_K QUANT_K_Q5_1\n#define QUANT_R QUANT_R_Q5_1\n#define QUANT_AUXF 2\n#define A_TYPE block_q5_1\n#define A_TYPE_PACKED16 block_q5_1_packed16\n#define A_TYPE_PACKED32 block_q5_1_packed32\n#define DATA_A_QUANT_LEGACY\n#endif\n\n#define QUANT_K_Q8_0 32\n#define QUANT_R_Q8_0 1\n\nstruct block_q8_0\n{\n    float16_t d;\n    int8_t qs[32];\n};\n\nstruct block_q8_0_packed16\n{\n    float16_t d;\n    int16_t qs[32/2];\n};\n\n#if defined(DATA_A_Q8_0)\n#define QUANT_K QUANT_K_Q8_0\n#define QUANT_R QUANT_R_Q8_0\n#define QUANT_AUXF 1\n#define A_TYPE block_q8_0\n#define A_TYPE_PACKED16 block_q8_0_packed16\n#define DATA_A_QUANT_LEGACY\n#endif\n\n#define QUANT_K_Q8_1 32\n#define QUANT_R_Q8_1 1\n\nstruct block_q8_1\n{\n    f16vec2 ds;\n    int8_t qs[32];\n};\n\nstruct block_q8_1_packed16\n{\n    f16vec2 ds;\n    int16_t qs[16];\n};\n\nstruct block_q8_1_packed32\n{\n    f16vec2 ds;\n    int32_t qs[8];\n};\n\n// 4 blocks in one to allow 16-byte/128-bit alignment and loads\nstruct block_q8_1_x4\n{\n    f16vec2 ds[4];\n    int32_t qs[32];\n};\n\nstruct block_q8_1_x4_packed128\n{\n    f16vec2 ds[4];\n    ivec4 
qs[8];\n};\n\n// K-quants\n#define QUANT_K_Q2_K 256\n\nstruct block_q2_K\n{\n    uint8_t scales[QUANT_K_Q2_K/16];\n    uint8_t qs[QUANT_K_Q2_K/4];\n    f16vec2 dm;\n};\n\nstruct block_q2_K_packed16\n{\n    uint16_t scales[QUANT_K_Q2_K/16/2];\n    uint16_t qs[QUANT_K_Q2_K/4/2];\n    f16vec2 dm;\n};\n\nstruct block_q2_K_packed32\n{\n    uint32_t scales[QUANT_K_Q2_K/16/4];\n    uint32_t qs[QUANT_K_Q2_K/4/4];\n    f16vec2 dm;\n};\n\n#if defined(DATA_A_Q2_K)\n#define QUANT_K QUANT_K_Q2_K\n#define QUANT_R 1\n#define A_TYPE block_q2_K\n#define A_TYPE_PACKED16 block_q2_K_packed16\n#define A_TYPE_PACKED32 block_q2_K_packed32\n#define SCALES_PER_32 2\n#define DATA_A_QUANT_K\n#endif\n\n#define QUANT_K_Q3_K 256\n\nstruct block_q3_K\n{\n    uint8_t hmask[QUANT_K_Q3_K/8];\n    uint8_t qs[QUANT_K_Q3_K/4];\n    uint8_t scales[12];\n    float16_t d;\n};\n\nstruct block_q3_K_packed16\n{\n    uint16_t hmask[QUANT_K_Q3_K/8/2];\n    uint16_t qs[QUANT_K_Q3_K/4/2];\n    uint16_t scales[12/2];\n    float16_t d;\n};\n\n#if defined(DATA_A_Q3_K)\n#define QUANT_K QUANT_K_Q3_K\n#define QUANT_R 1\n#define A_TYPE block_q3_K\n#define A_TYPE_PACKED16 block_q3_K_packed16\n#define DATA_A_QUANT_K\n#endif\n\n#define QUANT_K_Q4_K 256\n\nstruct block_q4_K\n{\n    f16vec2 dm;\n    uint8_t scales[3*QUANT_K_Q4_K/64];\n    uint8_t qs[QUANT_K_Q4_K/2];\n};\n\nstruct block_q4_K_packed16\n{\n    f16vec2 dm;\n    uint16_t scales[3*QUANT_K_Q4_K/64/2];\n    uint16_t qs[QUANT_K_Q4_K/2/2];\n};\n\nstruct block_q4_K_packed32\n{\n    f16vec2 dm;\n    uint32_t scales[3*QUANT_K_Q4_K/64/4];\n    uint32_t qs[QUANT_K_Q4_K/2/4];\n};\n\nstruct block_q4_K_packed128\n{\n    uvec4 q4k[9];\n};\n\n#if defined(DATA_A_Q4_K)\n#define QUANT_K QUANT_K_Q4_K\n#define QUANT_R 1\n#define A_TYPE block_q4_K\n#define A_TYPE_PACKED16 block_q4_K_packed16\n#define A_TYPE_PACKED32 block_q4_K_packed32\n#define DATA_A_QUANT_K\n#endif\n\n#define QUANT_K_Q5_K 256\n\nstruct block_q5_K\n{\n    f16vec2 dm;\n    uint8_t scales[12];\n    uint8_t 
qh[QUANT_K_Q5_K/8];\n    uint8_t qs[QUANT_K_Q5_K/2];\n};\n\nstruct block_q5_K_packed16\n{\n    f16vec2 dm;\n    uint16_t scales[12/2];\n    uint16_t qh[QUANT_K_Q5_K/8/2];\n    uint16_t qs[QUANT_K_Q5_K/2/2];\n};\n\nstruct block_q5_K_packed32\n{\n    f16vec2 dm;\n    uint32_t scales[12/4];\n    uint32_t qh[QUANT_K_Q5_K/8/4];\n    uint32_t qs[QUANT_K_Q5_K/2/4];\n};\n\nstruct block_q5_K_packed128\n{\n    uvec4 q5k[11];\n};\n\n#if defined(DATA_A_Q5_K)\n#define QUANT_K QUANT_K_Q5_K\n#define QUANT_R 1\n#define A_TYPE block_q5_K\n#define A_TYPE_PACKED16 block_q5_K_packed16\n#define A_TYPE_PACKED32 block_q5_K_packed32\n#define DATA_A_QUANT_K\n#endif\n\n#define QUANT_K_Q6_K 256\n\nstruct block_q6_K\n{\n    uint8_t ql[QUANT_K_Q6_K/2];\n    uint8_t qh[QUANT_K_Q6_K/4];\n    int8_t scales[QUANT_K_Q6_K/16];\n    float16_t d;\n};\n\nstruct block_q6_K_packed16\n{\n    uint16_t ql[QUANT_K_Q6_K/2/2];\n    uint16_t qh[QUANT_K_Q6_K/4/2];\n    int16_t scales[QUANT_K_Q6_K/16/2];\n    float16_t d;\n};\n\n#if defined(DATA_A_Q6_K)\n#define QUANT_K QUANT_K_Q6_K\n#define QUANT_R 1\n#define A_TYPE block_q6_K\n#define A_TYPE_PACKED16 block_q6_K_packed16\n#define DATA_A_QUANT_K\n#endif\n\n// IQuants\n\n#define QUANT_K_IQ1_S 256\n#define QUANT_R_IQ1_S 1\n\nstruct block_iq1_s {\n    float16_t d;\n    uint8_t  qs[QUANT_K_IQ1_S/8];\n    uint16_t qh[QUANT_K_IQ1_S/32];\n};\n\nstruct block_iq1_s_packed16 {\n    float16_t d;\n    uint16_t qs[QUANT_K_IQ1_S/8/2];\n    uint16_t qh[QUANT_K_IQ1_S/32];\n};\n\n#define QUANT_K_IQ1_M 256\n#define QUANT_R_IQ1_M 1\n\nstruct block_iq1_m {\n    uint8_t  qs[QUANT_K_IQ1_M/8];\n    uint8_t  qh[QUANT_K_IQ1_M/16];\n    uint16_t scales[QUANT_K_IQ1_M/64];\n};\n\nstruct block_iq1_m_packed16 {\n    uint16_t qs[QUANT_K_IQ1_M/8/2];\n    uint16_t qh[QUANT_K_IQ1_M/16/2];\n    uint16_t scales[QUANT_K_IQ1_M/64];\n};\n\nstruct block_iq1_m_packed32 {\n    uint32_t qs[QUANT_K_IQ1_M/8/4];\n    uint32_t qh[QUANT_K_IQ1_M/16/4];\n    uint32_t scales[QUANT_K_IQ1_M/64/2];\n};\n\nstruct 
block_iq1_m_packed64 {\n    uint64_t  qs[QUANT_K_IQ1_M/8/8];\n    uint64_t  qh[QUANT_K_IQ1_M/16/8];\n    uint64_t scales;\n};\n\n#if defined(DATA_A_IQ1_S)\n#define QUANT_K QUANT_K_IQ1_S\n#define QUANT_R QUANT_R_IQ1_S\n#define A_TYPE block_iq1_s\n#define A_TYPE_PACKED16 block_iq1_s_packed16\n#endif\n\n#if defined(DATA_A_IQ1_M)\n#define QUANT_K QUANT_K_IQ1_M\n#define QUANT_R QUANT_R_IQ1_M\n#define A_TYPE block_iq1_m\n#define A_TYPE_PACKED16 block_iq1_m_packed16\n#define A_TYPE_PACKED32 block_iq1_m_packed32\n#endif\n\n#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)\n#define IQ1S_DELTA 0.125f\n#define IQ1M_DELTA 0.125f\n\n// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate).\nconst uint[1024] iq1s_grid_const = {\n    0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01,\n    0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4,\n    0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41,\n    0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f,\n    0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334,\n    0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f,\n    0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040,\n    0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f,\n    0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5,\n    0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3,\n    0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff,\n    0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570,\n    0xf540f575, 0xf55df55f, 
0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f,\n    0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf,\n    0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f,\n    0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07,\n    0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc,\n    0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374,\n    0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0,\n    0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001,\n    0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043,\n    0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc,\n    0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117,\n    0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f,\n    0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5,\n    0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474,\n    0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d,\n    0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd,\n    0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50,\n    0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10,\n    0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30,\n    0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1,\n    
0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c,\n    0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074,\n    0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134,\n    0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7,\n    0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3,\n    0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450,\n    0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577,\n    0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c,\n    0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5,\n    0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c,\n    0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00,\n    0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300,\n    0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc,\n    0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034,\n    0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077,\n    0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5,\n    0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117,\n    0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f,\n    0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5,\n    0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 
0x34013400, 0x341f3404,\n    0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1,\n    0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd,\n    0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71,\n    0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7,\n    0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00,\n    0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44,\n    0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00,\n    0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0,\n    0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303,\n    0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343,\n    0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd,\n    0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031,\n    0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011,\n    0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c,\n    0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4,\n    0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c,\n    0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174,\n    0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7,\n    0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d,\n    0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 
0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4,\n    0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c,\n    0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7,\n    0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510,\n    0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33,\n    0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4,\n    0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73,\n    0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f,\n    0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337,\n    0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343,\n    0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030,\n    0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075,\n    0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4,\n    0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170,\n    0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705,\n    0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c,\n    0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c,\n    0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514,\n    0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c,\n    0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3,\n    0x7cd07cc4, 0x7c337c3c, 
0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70,\n    0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03,\n    0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c,\n    0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c,\n    0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074,\n    0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104,\n    0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7,\n    0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757,\n    0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c,\n    0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c,\n    0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4,\n    0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc,\n    0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03,\n    0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc,\n    0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54,\n    0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f,\n    0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf,\n    0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c,\n    0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c,\n    0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4,\n    
0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174,\n    0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700,\n    0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7,\n    0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d,\n    0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531,\n    0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf,\n    0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57,\n    0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13,\n    0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01,\n    0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f,\n    0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7,\n    0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074,\n    0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107,\n    0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd,\n    0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0,\n    0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7,\n    0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557\n};\n\n// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit\n// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F\n// and 0xF0F0F0F0).\nconst uint32_t[2048] iq1s_grid_gpu_const = {\n    0x00000000, 0x00000002, 0x00000101, 0x00000200, 
0x00000202, 0x00010001, 0x00010101, 0x00020000,\n    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,\n    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,\n    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,\n    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,\n    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,\n    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,\n    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,\n    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,\n    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,\n    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,\n    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,\n    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,\n    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,\n    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,\n    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,\n    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,\n    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,\n    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,\n    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,\n    0x00022110, 0x00022111, 
0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,\n    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,\n    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,\n    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,\n    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,\n    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,\n    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,\n    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,\n    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,\n    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,\n    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,\n    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,\n    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,\n    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,\n    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,\n    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,\n    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,\n    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,\n    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,\n    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,\n    
0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,\n    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,\n    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,\n    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,\n    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,\n    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,\n    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,\n    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,\n    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,\n    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,\n    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,\n    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,\n    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,\n    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,\n    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,\n    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,\n    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,\n    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,\n    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,\n    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 
0x01220111, 0x01220112,\n    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,\n    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,\n    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,\n    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,\n    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,\n    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,\n    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,\n    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,\n    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,\n    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,\n    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,\n    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,\n    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,\n    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,\n    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,\n    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,\n    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,\n    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,\n    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,\n    0x10020101, 0x11000201, 0x11010002, 0x11010101, 
0x11010200, 0x11010202, 0x11020001, 0x11020100,\n    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,\n    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,\n    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,\n    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,\n    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,\n    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,\n    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,\n    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,\n    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,\n    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,\n    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,\n    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,\n    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,\n    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,\n    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,\n    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,\n    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,\n    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,\n    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,\n    0x10012001, 0x10012101, 
0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,\n    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,\n    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,\n    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,\n    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,\n    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,\n    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,\n    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,\n    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,\n    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,\n    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,\n    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,\n    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,\n    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,\n    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,\n    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,\n    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,\n    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,\n    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,\n    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,\n    
0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,\n    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,\n    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,\n    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,\n    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,\n    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,\n    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,\n    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,\n    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,\n    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,\n    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,\n    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,\n    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,\n    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,\n    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,\n    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,\n    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,\n    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,\n    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,\n    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 
0x12121220, 0x12121221,\n    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,\n    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,\n    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,\n    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,\n    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,\n    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,\n    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,\n    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,\n    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,\n    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,\n    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,\n    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,\n    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,\n    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,\n    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,\n    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,\n    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,\n    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,\n    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,\n    0x12210021, 0x12210122, 0x12220121, 0x10211001, 
0x10211002, 0x10211101, 0x10211102, 0x10211202,\n    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,\n    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,\n    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,\n    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,\n    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,\n    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,\n    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,\n    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,\n    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,\n    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,\n    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,\n    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,\n    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,\n    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,\n    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,\n    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,\n    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,\n    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,\n    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,\n    0x12212120, 0x12212220, 
0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,\n    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,\n    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,\n    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,\n    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,\n    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,\n    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,\n    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,\n    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,\n    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,\n    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,\n    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,\n    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,\n    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,\n    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,\n    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,\n    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,\n    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,\n    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,\n    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,\n    
0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,\n    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,\n    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,\n    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,\n    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,\n    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,\n    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,\n    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,\n    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,\n    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,\n    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,\n    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,\n    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,\n    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,\n    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,\n    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,\n    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,\n    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,\n    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,\n    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 
0x21111011, 0x21111012,\n    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,\n    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,\n    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,\n    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,\n    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,\n    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,\n    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,\n    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,\n    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,\n    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,\n    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,\n    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,\n    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,\n    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,\n    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,\n    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,\n    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,\n    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,\n    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,\n    0x21200211, 0x21210011, 0x21210111, 0x21210210, 
0x21210212, 0x21220011, 0x21220110, 0x22200111,\n    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,\n    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,\n    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,\n    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,\n    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,\n    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,\n    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,\n    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,\n    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,\n    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,\n    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,\n    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,\n    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,\n    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,\n    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,\n    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,\n    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,\n    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,\n};\n\nshared uint16_t iq1s_grid[2048];\nshared uint32_t iq1s_grid_gpu[2048];\n\n#define NEEDS_INIT_IQ_SHMEM\nvoid 
init_iq_shmem(uvec3 wgsize)\n{\n    // copy the table into shared memory and sync\n    [[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) {\n        uint idx = i + gl_LocalInvocationIndex.x;\n        if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) {\n            u16vec2 g = unpack16(iq1s_grid_const[idx]);\n            iq1s_grid[2*idx+0] = g.x;\n            iq1s_grid[2*idx+1] = g.y;\n        }\n    }\n    [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) {\n        uint idx = i + gl_LocalInvocationIndex.x;\n        if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) {\n            iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx];\n        }\n    }\n    barrier();\n}\n#endif\n\n#define QUANT_K_IQ2_XXS 256\n#define QUANT_R_IQ2_XXS 1\n\nstruct block_iq2_xxs\n{\n    float16_t d;\n    uint8_t qs[QUANT_K_IQ2_XXS/4];\n};\n\nstruct block_iq2_xxs_packed16\n{\n    float16_t d;\n    uint16_t qs[QUANT_K_IQ2_XXS/8];\n};\n\n#if defined(DATA_A_IQ2_XXS)\n\nconst uvec2[256] iq2xxs_grid_const = {\n    uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),\n    uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808),\n    uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808),\n    uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808),\n    uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808),\n    uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819),\n    uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 
0x08080819),\n    uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b),\n    uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908),\n    uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908),\n    uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908),\n    uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919),\n    uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919),\n    uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b),\n    uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08),\n    uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08),\n    uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08),\n    uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b),\n    uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808),\n    uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808),\n    uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819),\n    uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b),\n    uvec2(0x19081919, 0x0819082b), 
uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908),\n    uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919),\n    uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08),\n    uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08),\n    uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b),\n    uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808),\n    uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819),\n    uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908),\n    uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908),\n    uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b),\n    uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b),\n    uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808),\n    uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808),\n    uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808),\n    uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819),\n    uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), 
uvec2(0x08190808, 0x1908082b),\n    uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908),\n    uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908),\n    uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08),\n    uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08),\n    uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19),\n    uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808),\n    uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808),\n    uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819),\n    uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919),\n    uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08),\n    uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b),\n    uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808),\n    uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b),\n    uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919),\n    uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808),\n    uvec2(0x08082b2b, 
0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819),\n    uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908),\n    uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908),\n    uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919),\n    uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08),\n    uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808),\n    uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819),\n    uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908),\n    uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19),\n    uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819),\n    uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19)\n};\n\nshared uvec2 iq2xxs_grid[256];\n\n#define NEEDS_INIT_IQ_SHMEM\nvoid init_iq_shmem(uvec3 wgsize)\n{\n    // copy the table into shared memory and sync\n    [[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) {\n        if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) {\n            iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x];\n        }\n    }\n    barrier();\n}\n\n#define QUANT_K QUANT_K_IQ2_XXS\n#define QUANT_R QUANT_R_IQ2_XXS\n#define A_TYPE 
block_iq2_xxs\n#define A_TYPE_PACKED16 block_iq2_xxs_packed16\n#endif\n\n#define QUANT_K_IQ2_XS 256\n#define QUANT_R_IQ2_XS 1\n\nstruct block_iq2_xs\n{\n    float16_t d;\n    uint16_t qs[QUANT_K_IQ2_XS/8];\n    uint8_t scales[QUANT_K_IQ2_XS/32];\n};\n\nstruct block_iq2_xs_packed16\n{\n    float16_t d;\n    uint16_t qs[QUANT_K_IQ2_XS/8];\n    uint16_t scales[QUANT_K_IQ2_XS/64];\n};\n\n#if defined(DATA_A_IQ2_XS)\n\nconst uvec2 iq2xs_grid_const[512] = {\n    uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),\n    uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808),\n    uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808),\n    uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808),\n    uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808),\n    uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808),\n    uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808),\n    uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819),\n    uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819),\n    uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819),\n    uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819),\n    uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 
0x08080819),\n    uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819),\n    uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b),\n    uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b),\n    uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b),\n    uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908),\n    uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908),\n    uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908),\n    uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908),\n    uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908),\n    uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919),\n    uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919),\n    uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919),\n    uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b),\n    uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b),\n    uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08),\n    uvec2(0x08081919, 0x08082b08), 
uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08),\n    uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08),\n    uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08),\n    uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19),\n    uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19),\n    uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b),\n    uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808),\n    uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808),\n    uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808),\n    uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808),\n    uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808),\n    uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819),\n    uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819),\n    uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819),\n    uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b),\n    uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), 
uvec2(0x19080808, 0x0819082b),\n    uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908),\n    uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908),\n    uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908),\n    uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919),\n    uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b),\n    uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08),\n    uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08),\n    uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b),\n    uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808),\n    uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808),\n    uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808),\n    uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819),\n    uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819),\n    uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b),\n    uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908),\n    uvec2(0x19080808, 
0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919),\n    uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b),\n    uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08),\n    uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19),\n    uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b),\n    uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808),\n    uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808),\n    uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808),\n    uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808),\n    uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808),\n    uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808),\n    uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819),\n    uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819),\n    uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819),\n    uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b),\n    uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 
0x1908082b), uvec2(0x08080808, 0x19081908),\n    uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908),\n    uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908),\n    uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908),\n    uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919),\n    uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b),\n    uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08),\n    uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08),\n    uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19),\n    uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808),\n    uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808),\n    uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808),\n    uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819),\n    uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819),\n    uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b),\n    uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908),\n    
uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908),\n    uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919),\n    uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08),\n    uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08),\n    uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b),\n    uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808),\n    uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808),\n    uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b),\n    uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919),\n    uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08),\n    uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b),\n    uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808),\n    uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808),\n    uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808),\n    uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819),\n    uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), 
uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819),\n    uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b),\n    uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b),\n    uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908),\n    uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908),\n    uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b),\n    uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08),\n    uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08),\n    uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b),\n    uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808),\n    uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808),\n    uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b),\n    uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908),\n    uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919),\n    uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08),\n    uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808),\n  
  uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808),\n    uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819),\n    uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b),\n    uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908),\n    uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08),\n    uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08),\n    uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19),\n    uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b),\n};\n\nshared uvec2 iq2xs_grid[512];\n\n#define NEEDS_INIT_IQ_SHMEM\nvoid init_iq_shmem(uvec3 wgsize)\n{\n    // copy the table into shared memory and sync\n    [[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) {\n        if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) {\n            iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x];\n        }\n    }\n    barrier();\n}\n\n#define QUANT_K QUANT_K_IQ2_XS\n#define QUANT_R QUANT_R_IQ2_XS\n#define A_TYPE block_iq2_xs\n#define A_TYPE_PACKED16 block_iq2_xs_packed16\n#endif\n\n#define QUANT_K_IQ2_S 256\n#define QUANT_R_IQ2_S 1\n\nstruct block_iq2_s\n{\n    float16_t d;\n    uint8_t qs[QUANT_K_IQ2_S/4];\n    uint8_t qh[QUANT_K_IQ2_S/32];\n    uint8_t scales[QUANT_K_IQ2_S/32];\n};\n\nstruct block_iq2_s_packed16\n{\n    float16_t d;\n    uint16_t qs[QUANT_K_IQ2_S/8];\n    uint16_t qh[QUANT_K_IQ2_S/64];\n 
   uint16_t scales[QUANT_K_IQ2_S/64];\n};\n\n#if defined(DATA_A_IQ2_S)\n\nconst uvec2 iq2s_grid_const[1024] = {\n    uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),\n    uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808),\n    uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808),\n    uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808),\n    uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808),\n    uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808),\n    uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808),\n    uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808),\n    uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819),\n    uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819),\n    uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819),\n    uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819),\n    uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819),\n    uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819),\n    uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), 
uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819),\n    uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b),\n    uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b),\n    uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b),\n    uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b),\n    uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b),\n    uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908),\n    uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908),\n    uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908),\n    uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908),\n    uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908),\n    uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908),\n    uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908),\n    uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908),\n    uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919),\n    uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919),\n  
  uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919),\n    uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919),\n    uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919),\n    uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919),\n    uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919),\n    uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b),\n    uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b),\n    uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b),\n    uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b),\n    uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08),\n    uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08),\n    uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08),\n    uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08),\n    uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08),\n    uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08),\n    uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), 
uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19),\n    uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19),\n    uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19),\n    uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19),\n    uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b),\n    uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b),\n    uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808),\n    uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808),\n    uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808),\n    uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808),\n    uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808),\n    uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808),\n    uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808),\n    uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808),\n    uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819),\n    uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819),\n  
  uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819),\n    uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819),\n    uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819),\n    uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819),\n    uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819),\n    uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b),\n    uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b),\n    uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b),\n    uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b),\n    uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908),\n    uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908),\n    uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908),\n    uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908),\n    uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908),\n    uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908),\n    uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), 
uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908),\n    uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919),\n    uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919),\n    uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919),\n    uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919),\n    uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919),\n    uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b),\n    uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b),\n    uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08),\n    uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08),\n    uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08),\n    uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08),\n    uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08),\n    uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19),\n    uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19),\n    uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19),\n  
  uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b),\n    uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808),\n    uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808),\n    uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808),\n    uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808),\n    uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808),\n    uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819),\n    uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819),\n    uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819),\n    uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819),\n    uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b),\n    uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b),\n    uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908),\n    uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908),\n    uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908),\n    uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), 
uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908),\n    uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908),\n    uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919),\n    uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919),\n    uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919),\n    uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b),\n    uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08),\n    uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08),\n    uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19),\n    uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b),\n    uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808),\n    uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808),\n    uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808),\n    uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808),\n    uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808),\n    uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808),\n  
  uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808),\n    uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808),\n    uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819),\n    uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819),\n    uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819),\n    uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819),\n    uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819),\n    uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819),\n    uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819),\n    uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b),\n    uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b),\n    uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b),\n    uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b),\n    uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908),\n    uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908),\n    uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), 
uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908),\n    uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908),\n    uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908),\n    uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908),\n    uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908),\n    uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919),\n    uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919),\n    uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919),\n    uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919),\n    uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919),\n    uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919),\n    uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b),\n    uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b),\n    uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b),\n    uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08),\n    uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08),\n  
  uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08),\n    uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08),\n    uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19),\n    uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19),\n    uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19),\n    uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b),\n    uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808),\n    uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808),\n    uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808),\n    uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808),\n    uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808),\n    uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808),\n    uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808),\n    uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819),\n    uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819),\n    uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), 
uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819),\n    uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819),\n    uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819),\n    uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b),\n    uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b),\n    uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b),\n    uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908),\n    uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908),\n    uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908),\n    uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908),\n    uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908),\n    uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919),\n    uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919),\n    uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919),\n    uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b),\n    uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08),\n  
  uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08),\n    uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08),\n    uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19),\n    uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b),\n    uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808),\n    uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808),\n    uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808),\n    uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808),\n    uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819),\n    uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819),\n    uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819),\n    uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b),\n    uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908),\n    uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908),\n    uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908),\n    uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), 
uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919),\n    uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919),\n    uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08),\n    uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19),\n    uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808),\n    uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808),\n    uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808),\n    uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808),\n    uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808),\n    uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819),\n    uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819),\n    uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819),\n    uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819),\n    uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819),\n    uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b),\n    uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b),\n  
  uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908),\n    uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908),\n    uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908),\n    uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908),\n    uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908),\n    uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919),\n    uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919),\n    uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919),\n    uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b),\n    uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08),\n    uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08),\n    uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19),\n    uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b),\n    uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808),\n    uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808),\n    uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), 
uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808),\n    uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808),\n    uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808),\n    uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819),\n    uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819),\n    uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b),\n    uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908),\n    uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908),\n    uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908),\n    uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919),\n    uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919),\n    uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08),\n    uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08),\n    uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19),\n    uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808),\n    uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808),\n  
  uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808),\n    uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b),\n    uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b),\n    uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908),\n    uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908),\n    uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b),\n    uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19),\n    uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b),\n    uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b)\n};\n\nshared uvec2 iq2s_grid[1024];\n\n#define NEEDS_INIT_IQ_SHMEM\nvoid init_iq_shmem(uvec3 wgsize)\n{\n    // copy the table into shared memory and sync\n    [[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) {\n        if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) {\n            iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x];\n        }\n    }\n    barrier();\n}\n\n#define QUANT_K QUANT_K_IQ2_S\n#define QUANT_R QUANT_R_IQ2_S\n#define A_TYPE block_iq2_s\n#define A_TYPE_PACKED16 block_iq2_s_packed16\n#endif\n\n#define QUANT_K_IQ3_XXS 256\n#define QUANT_R_IQ3_XXS 1\n\nstruct block_iq3_xxs\n{\n    float16_t d;\n    uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8];\n};\n\nstruct block_iq3_xxs_packed16\n{\n    float16_t 
d;\n    uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16];\n};\n\n#if defined(DATA_A_IQ3_XXS)\n\nconst uint32_t iq3xxs_grid_const[256] = {\n    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,\n    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,\n    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,\n    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,\n    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,\n    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,\n    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,\n    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,\n    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,\n    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,\n    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,\n    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,\n    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,\n    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,\n    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,\n    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,\n    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,\n    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,\n    0x1c04141c, 0x1c042c04, 0x1c04342c, 
0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,\n    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,\n    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,\n    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,\n    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,\n    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,\n    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,\n    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,\n    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,\n    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,\n    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,\n    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,\n    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,\n    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,\n};\n\nshared uint32_t iq3xxs_grid[256];\n\n#define NEEDS_INIT_IQ_SHMEM\nvoid init_iq_shmem(uvec3 wgsize)\n{\n    // copy the table into shared memory and sync\n    [[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) {\n        if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) {\n            iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x];\n        }\n    }\n    barrier();\n}\n\n#define QUANT_K QUANT_K_IQ3_XXS\n#define QUANT_R QUANT_R_IQ3_XXS\n#define A_TYPE block_iq3_xxs\n#define A_TYPE_PACKED16 
block_iq3_xxs_packed16\n#endif\n\n#define QUANT_K_IQ3_S 256\n#define QUANT_R_IQ3_S 1\n\nstruct block_iq3_s\n{\n    float16_t d;\n    uint8_t qs[QUANT_K_IQ3_S/4];\n    uint8_t qh[QUANT_K_IQ3_S/32];\n    uint8_t signs[QUANT_K_IQ3_S/8];\n    uint8_t scales[QUANT_K_IQ3_S/64];\n};\n\nstruct block_iq3_s_packed16\n{\n    float16_t d;\n    uint16_t qs[QUANT_K_IQ3_S/4/2];\n    uint16_t qh[QUANT_K_IQ3_S/32/2];\n    uint16_t signs[QUANT_K_IQ3_S/8/2];\n    uint16_t scales[QUANT_K_IQ3_S/64/2];\n};\n\n#if defined(DATA_A_IQ3_S)\n\nconst uint32_t iq3s_grid_const[512] = {\n    0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,\n    0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,\n    0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,\n    0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,\n    0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,\n    0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,\n    0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,\n    0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,\n    0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,\n    0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,\n    0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,\n    0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,\n    0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,\n    0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,\n    0x03010109, 
0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,\n    0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,\n    0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,\n    0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,\n    0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,\n    0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,\n    0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,\n    0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,\n    0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,\n    0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,\n    0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,\n    0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,\n    0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,\n    0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,\n    0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,\n    0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,\n    0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,\n    0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,\n    0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,\n    0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 
0x050d010f,\n    0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,\n    0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,\n    0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,\n    0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,\n    0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,\n    0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,\n    0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,\n    0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,\n    0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,\n    0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,\n    0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,\n    0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,\n    0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,\n    0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,\n    0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,\n    0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,\n    0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,\n    0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,\n    0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,\n    0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 
0x0b070909, 0x0b070b03, 0x0b070d0b,\n    0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,\n    0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,\n    0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,\n    0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,\n    0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,\n    0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,\n    0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,\n    0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,\n    0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,\n    0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,\n};\n\nshared uint32_t iq3s_grid[512];\n\n#define NEEDS_INIT_IQ_SHMEM\nvoid init_iq_shmem(uvec3 wgsize)\n{\n    // copy the table into shared memory and sync\n    [[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) {\n        if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) {\n            iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x];\n        }\n    }\n    barrier();\n}\n\n#define QUANT_K QUANT_K_IQ3_S\n#define QUANT_R QUANT_R_IQ3_S\n#define A_TYPE block_iq3_s\n#define A_TYPE_PACKED16 block_iq3_s_packed16\n#endif\n\n#define QUANT_K_IQ4_XS 256\n#define QUANT_R_IQ4_XS 1\n\nstruct block_iq4_xs\n{\n    float16_t d;\n    uint16_t scales_h;\n    uint8_t scales_l[QUANT_K_IQ4_XS/64];\n    uint8_t qs[QUANT_K_IQ4_XS/2];\n};\n\nstruct block_iq4_xs_packed16\n{\n    float16_t d;\n    uint16_t scales_h;\n    uint16_t scales_l[QUANT_K_IQ4_XS/128];\n   
 uint16_t qs[QUANT_K_IQ4_XS/4];\n};\n\nstruct block_iq4_xs_packed32\n{\n    float16_t d;\n    uint16_t scales_h;\n    uint32_t scales_l;\n    uint32_t qs[QUANT_K_IQ4_XS/8];\n};\n\n#if defined(DATA_A_IQ4_XS)\n#define QUANT_K QUANT_K_IQ4_XS\n#define QUANT_R QUANT_R_IQ4_XS\n#define A_TYPE block_iq4_xs\n#define A_TYPE_PACKED16 block_iq4_xs_packed16\n#define A_TYPE_PACKED32 block_iq4_xs_packed32\n#endif\n\n#define QUANT_K_IQ4_NL 32\n#define QUANT_R_IQ4_NL 2\n\nstruct block_iq4_nl\n{\n    float16_t d;\n    uint8_t qs[QUANT_K_IQ4_NL/2];\n};\n\nstruct block_iq4_nl_packed16\n{\n    float16_t d;\n    uint16_t qs[QUANT_K_IQ4_NL/2/2];\n};\n\n#if defined(DATA_A_IQ4_NL)\n#define QUANT_K QUANT_K_IQ4_NL\n#define QUANT_R QUANT_R_IQ4_NL\n#define A_TYPE block_iq4_nl\n#define A_TYPE_PACKED16 block_iq4_nl_packed16\n#endif\n\n#define QUANT_K_MXFP4 32\n#define QUANT_R_MXFP4 2\n\nstruct block_mxfp4\n{\n    uint8_t e;\n    uint8_t qs[QUANT_K_MXFP4/2];\n};\n\n#if defined(DATA_A_MXFP4)\n#define QUANT_K QUANT_K_MXFP4\n#define QUANT_R QUANT_R_MXFP4\n#define QUANT_AUXF 1\n#define A_TYPE block_mxfp4\n#endif\n\n#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)\nconst int8_t kvalues_iq4nl_const[16] = {\n    int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),\n    int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)\n};\n\nshared FLOAT_TYPE kvalues_iq4nl[16];\n\n#define NEEDS_INIT_IQ_SHMEM\nvoid init_iq_shmem(uvec3 wgsize)\n{\n    // copy the table into shared memory and sync\n    for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) {\n        kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]);\n    }\n    barrier();\n}\n#endif\n\n#if defined(DATA_A_MXFP4)\nconst int8_t kvalues_mxfp4_const[16] = {\n    int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),\n    int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), 
int8_t(-8), int8_t(-12),\n};\n\nshared int8_t kvalues_mxfp4[16];\n\n#define NEEDS_INIT_IQ_SHMEM\nvoid init_iq_shmem(uvec3 wgsize)\n{\n    // copy the table into shared memory and sync\n    for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {\n        kvalues_mxfp4[i] = kvalues_mxfp4_const[i];\n    }\n    barrier();\n}\n#endif\n\n// returns the bfloat value in the low 16b.\n// See ggml_compute_fp32_to_bf16\nuint32_t fp32_to_bf16(float f)\n{\n    uint32_t u = floatBitsToUint(f);\n    u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;\n    return u;\n}\n\nfloat bf16_to_fp32(uint32_t u)\n{\n    return uintBitsToFloat(u << 16);\n}\n\nvec4 bf16_to_fp32(uvec4 u)\n{\n    return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w));\n}\n\nfloat e8m0_to_fp32(uint8_t x) {\n    uint32_t bits;\n\n    if (x == 0) {\n        bits = 0x00400000;\n    } else {\n        bits = x;\n        bits = bits << 23;\n    }\n\n    return uintBitsToFloat(bits);\n}\n\n#if BDA\n\n#extension GL_EXT_buffer_reference : enable\n#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable\n\n#define BDA_STORAGE_T uint64_t\n#define BDA_OFFSET_T uint64_t\n\n#else\n\n#define BDA_STORAGE_T uvec2\n#define BDA_OFFSET_T uint\n\n#endif\n\n#endif // !defined(GGML_TYPES_COMP)\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/upscale.comp",
    "content": "#version 450\n\nlayout (push_constant) uniform parameter\n{\n    uint ne; uint a_offset; uint d_offset;\n    uint ne00; uint ne01;\n    uint nb00; uint nb01; uint nb02; uint nb03;\n    uint ne10; uint ne11; uint ne12; uint ne13;\n    float sf0; float sf1; float sf2; float sf3;\n    float pixel_offset;\n} p;\n\n#include \"types.glsl\"\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer A {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\n// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag\n#define NEAREST  0\n#define BILINEAR 1\n#define BICUBIC  2\n#define BILINEAR_ANTIALIAS 513\n\nlayout (constant_id = 0) const uint scale_mode = 0;\n\nfloat fetch_nearest(uint i10, uint i11, uint i12, uint i13) {\n    const uint i00 = uint(i10 / p.sf0);\n    const uint i01 = uint(i11 / p.sf1);\n    const uint i02 = uint(i12 / p.sf2);\n    const uint i03 = uint(i13 / p.sf3);\n\n    return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];\n}\n\nfloat fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {\n    const uint i02 = uint(i12 / p.sf2);\n    const uint i03 = uint(i13 / p.sf3);\n    const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;\n\n    const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];\n    const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];\n    const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];\n    const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];\n\n    return\n        v00 * (1.0-d.x) * (1.0-d.y) +\n        v01 * d.x       * (1.0-d.y) +\n        v10 * (1.0-d.x) * d.y +\n        v11 * d.x       * d.y;\n}\n\nfloat interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {\n    const ivec2 ne0 = ivec2(p.ne00, p.ne01);\n\n    const vec2 c = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset;\n    const vec2 c0f = floor(c);\n    const vec2 
d = c - c0f;\n    const ivec2 c0 = max(ivec2(c0f), 0);\n    const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);\n\n    return fetch_bilinear(c0, c1, d, i12, i13);\n}\n\nfloat triangle_filter(float x) {\n    return max(1.0f - abs(x), 0.0f);\n}\n\nfloat interpolate_bilinear_antialias(uint i10, uint i11, uint i12, uint i13) {\n    const float support1  = max(1.0f, 1.0f / p.sf1);\n    const float invscale1 = 1.0f / support1;\n    const float support0  = max(1.0f, 1.0f / p.sf0);\n    const float invscale0 = 1.0f / support0;\n\n    const uint i02 = uint(i12 / p.sf2);\n    const uint i03 = uint(i13 / p.sf3);\n\n    const float y = (float(i11) + p.pixel_offset) / p.sf1;\n    const float x = (float(i10) + p.pixel_offset) / p.sf0;\n\n    // the range of source pixels that contribute\n    const int x_min = max(int(x - support0 + p.pixel_offset), 0);\n    const int x_max = min(int(x + support0 + p.pixel_offset), int(p.ne00));\n    const int y_min = max(int(y - support1 + p.pixel_offset), 0);\n    const int y_max = min(int(y + support1 + p.pixel_offset), int(p.ne01));\n\n    // bilinear filter with antialiasing\n    float val = 0.0f;\n    float total_weight = 0.0f;\n\n    for (int sy = y_min; sy < y_max; sy++) {\n        const float weight_y = triangle_filter((sy - y + p.pixel_offset) * invscale1);\n\n        for (int sx = x_min; sx < x_max; sx++) {\n            const float weight_x = triangle_filter((sx - x + p.pixel_offset) * invscale0);\n            const float weight = weight_x * weight_y;\n\n            if (weight <= 0.0f) {\n                continue;\n            }\n\n            const float pixel = data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + sy * p.nb01 + sx * p.nb00];\n            val += pixel * weight;\n            total_weight += weight;\n        }\n    }\n\n    if (total_weight > 0.0f) {\n        val /= total_weight;\n    }\n\n    return val;\n}\n\n// Bicubic interpolation with alpha = -0.75\n// 
https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm\nconst vec4 bcoeffs1 = vec4( 1.25, -2.25,  0.0, 1.0);\nconst vec4 bcoeffs2 = vec4(-0.75,  3.75, -6.0, 3.0);\nvec4 powers(float x) { return vec4(x*x*x, x*x, x, 1); }\n\nfloat bicubic(float p0, float p1, float p2, float p3, float x) {\n    return p0 * dot(bcoeffs2, powers(x + 1)) +\n           p1 * dot(bcoeffs1, powers(x    )) +\n           p2 * dot(bcoeffs1, powers(1 - x)) +\n           p3 * dot(bcoeffs2, powers(2 - x));\n}\n\n#define FETCH(a,b) data_a[base + clamp(i.x+(a), 0, res.x) * p.nb00 + clamp(i.y+(b), 0, res.y) * p.nb01]\n\nfloat interpolate_bicubic(uint i10, uint i11, uint i12, uint i13) {\n    const ivec2 res = ivec2(p.ne00 - 1, p.ne01 - 1);\n\n    const vec2 coord = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset;\n    const vec2 d = fract(coord);\n    const ivec2 i = ivec2(floor(coord));\n\n    const uint i02 = uint(i12 / p.sf2);\n    const uint i03 = uint(i13 / p.sf3);\n    const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;\n\n    return bicubic(\n        bicubic(FETCH(-1,-1), FETCH(0,-1), FETCH(1,-1), FETCH(2,-1), d.x),\n        bicubic(FETCH(-1, 0), FETCH(0, 0), FETCH(1, 0), FETCH(2, 0), d.x),\n        bicubic(FETCH(-1, 1), FETCH(0, 1), FETCH(1, 1), FETCH(2, 1), d.x),\n        bicubic(FETCH(-1, 2), FETCH(0, 2), FETCH(1, 2), FETCH(2, 2), d.x), d.y);\n}\n\nvoid main() {\n    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (idx >= p.ne) {\n        return;\n    }\n\n    const uint i10 = idx % p.ne10;\n    const uint i11 = (idx / p.ne10) % p.ne11;\n    const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;\n    const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;\n\n    float result;\n    switch (scale_mode) {\n        case NEAREST:\n            result = fetch_nearest(i10, i11, i12, i13);\n            break;\n        case BILINEAR:\n            result = 
interpolate_bilinear(i10, i11, i12, i13);\n            break;\n        case BICUBIC:\n            result = interpolate_bicubic(i10, i11, i12, i13);\n            break;\n        case BILINEAR_ANTIALIAS:\n            result = interpolate_bilinear_antialias(i10, i11, i12, i13);\n            break;\n    }\n\n    data_d[p.d_offset + idx] = D_TYPE(result);\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/utils.glsl",
    "content": "#ifndef UTILS_COMP\n#define UTILS_COMP\n\n// mod and div are expensive, and coordinates/dimensions are often powers of 2 or equal to 1, so these helpers take cheaper paths for those common cases\nuint fastmod(uint a, uint b) {\n    if ((b & (b-1)) == 0) {\n        return a & (b-1);\n    }\n    return a % b;\n}\n\nuint fastdiv(uint a, uint b) {\n    return (a < b) ? 0 : (a / b);\n}\n\nvoid get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03, uint ne00, uint ne01, uint ne02, uint ne03) {\n    i03 = fastdiv(idx, (ne02*ne01*ne00));\n    const uint i03_offset = i03 * ne02*ne01*ne00;\n    i02 = fastdiv((idx - i03_offset), (ne01*ne00));\n    const uint i02_offset = i02*ne01*ne00;\n    i01 = (idx - i03_offset - i02_offset) / ne00;\n    i00 = idx - i03_offset - i02_offset - i01*ne00;\n}\n\n#endif // UTILS_COMP\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp",
    "content": "#include <iostream>\n#include <fstream>\n#include <sstream>\n#include <string>\n#include <stdexcept>\n#include <array>\n#include <vector>\n#include <map>\n#include <thread>\n#include <mutex>\n#include <future>\n#include <queue>\n#include <condition_variable>\n#include <cstdio>\n#include <cstring>\n#include <cstdlib>\n#include <cassert>\n#include <algorithm>\n#include <sys/stat.h>\n#include <sys/types.h>\n#include <filesystem>\n\n#ifdef _WIN32\n    #define NOMINMAX\n    #include <windows.h>\n    #include <direct.h> // For _mkdir on Windows\n#else\n    #include <unistd.h>\n    #include <sys/wait.h>\n    #include <fcntl.h>\n#endif\n\n#define ASYNCIO_CONCURRENCY 64\n\nstd::mutex lock;\nstd::vector<std::pair<std::string, std::string>> shader_fnames;\nstd::locale c_locale(\"C\");\n\nstd::string GLSLC = \"glslc\";\nstd::string input_filepath = \"\";\nstd::string output_dir = \"/tmp\";\nstd::string target_hpp = \"\";\nstd::string target_cpp = \"\";\n\nconst std::vector<std::string> type_names = {\n    \"f32\",\n    \"f16\",\n    \"q4_0\",\n    \"q4_1\",\n    \"q5_0\",\n    \"q5_1\",\n    \"q8_0\",\n    \"q2_k\",\n    \"q3_k\",\n    \"q4_k\",\n    \"q5_k\",\n    \"q6_k\",\n    \"iq1_s\",\n    \"iq1_m\",\n    \"iq2_xxs\",\n    \"iq2_xs\",\n    \"iq2_s\",\n    \"iq3_xxs\",\n    \"iq3_s\",\n    \"iq4_xs\",\n    \"iq4_nl\",\n    \"mxfp4\",\n    \"bf16\",\n};\n\nenum MatMulIdType {\n    NONE,\n    DEFAULT,\n    SUBGROUP,\n};\n\nnamespace {\n\nvoid execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {\n#ifdef _WIN32\n    HANDLE stdout_read, stdout_write;\n    HANDLE stderr_read, stderr_write;\n    SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };\n\n    if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||\n        !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {\n        throw std::runtime_error(\"Failed to create stdout pipe\");\n    }\n\n    if (!CreatePipe(&stderr_read, 
&stderr_write, &sa, 0) ||\n        !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {\n        throw std::runtime_error(\"Failed to create stderr pipe\");\n    }\n\n    PROCESS_INFORMATION pi;\n    STARTUPINFOA si = {};\n    si.cb = sizeof(STARTUPINFOA);\n    si.dwFlags = STARTF_USESTDHANDLES;\n    si.hStdOutput = stdout_write;\n    si.hStdError = stderr_write;\n\n    std::string cmd;\n    for (const auto& part : command) {\n        cmd += part + \" \";\n    }\n\n    if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {\n        throw std::runtime_error(\"Failed to create process\");\n    }\n\n    CloseHandle(stdout_write);\n    CloseHandle(stderr_write);\n\n    std::array<char, 128> buffer;\n    DWORD bytes_read;\n\n    while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {\n        stdout_str.append(buffer.data(), bytes_read);\n    }\n\n    while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {\n        stderr_str.append(buffer.data(), bytes_read);\n    }\n\n    CloseHandle(stdout_read);\n    CloseHandle(stderr_read);\n    WaitForSingleObject(pi.hProcess, INFINITE);\n    CloseHandle(pi.hProcess);\n    CloseHandle(pi.hThread);\n#else\n    int stdout_pipe[2];\n    int stderr_pipe[2];\n\n    if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {\n        throw std::runtime_error(\"Failed to create pipes\");\n    }\n\n    pid_t pid = fork();\n    if (pid < 0) {\n        throw std::runtime_error(\"Failed to fork process\");\n    }\n\n    std::vector<char*> argv;\n    for (std::string& part : command) {\n        argv.push_back(part.data());\n    }\n    argv.push_back(nullptr);\n\n    if (pid == 0) {\n        close(stdout_pipe[0]);\n        close(stderr_pipe[0]);\n        dup2(stdout_pipe[1], STDOUT_FILENO);\n        dup2(stderr_pipe[1], STDERR_FILENO);\n        close(stdout_pipe[1]);\n        close(stderr_pipe[1]);\n        
execvp(argv[0], argv.data());\n        _exit(EXIT_FAILURE);\n    } else {\n        close(stdout_pipe[1]);\n        close(stderr_pipe[1]);\n\n        std::array<char, 128> buffer;\n        ssize_t bytes_read;\n\n        while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {\n            stdout_str.append(buffer.data(), bytes_read);\n        }\n\n        while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {\n            stderr_str.append(buffer.data(), bytes_read);\n        }\n\n        close(stdout_pipe[0]);\n        close(stderr_pipe[0]);\n        waitpid(pid, nullptr, 0);\n    }\n#endif\n}\n\nbool directory_exists(const std::string& path) {\n    struct stat info;\n    if (stat(path.c_str(), &info) != 0) {\n        return false; // Path doesn't exist or can't be accessed\n    }\n    return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory\n}\n\nbool create_directory(const std::string& path) {\n#ifdef _WIN32\n    return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists\n#else\n    return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions\n#endif\n}\n\nstd::string to_uppercase(const std::string& input) {\n    std::string result = input;\n    for (char& c : result) {\n        c = std::toupper(c);\n    }\n    return result;\n}\n\nbool string_starts_with(const std::string& str, const std::string& prefix) {\n    if (prefix.size() > str.size()) {\n        return false;\n    }\n    return std::equal(prefix.begin(), prefix.end(), str.begin());\n}\n\nbool string_ends_with(const std::string& str, const std::string& suffix) {\n    if (suffix.size() > str.size()) {\n        return false;\n    }\n    return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());\n}\n\nbool is_quantized_type(const std::string& type_name) {\n    return type_name != \"f32\" && type_name != \"f16\" && type_name != \"bf16\";\n}\n\nbool is_legacy_quant(const 
std::string& type_name) {\n    return type_name == \"q4_0\" || type_name == \"q4_1\" || type_name == \"q5_0\" || type_name == \"q5_1\" || type_name == \"q8_0\";\n}\n\nbool is_k_quant(const std::string& type_name) {\n    return string_ends_with(type_name, \"_k\");\n}\n\nbool is_iq_quant(const std::string& type_name) {\n    return string_starts_with(type_name, \"iq\");\n}\n\nstatic const char path_separator = '/';\n\nstd::string join_paths(const std::string& path1, const std::string& path2) {\n    return path1 + path_separator + path2;\n}\n\nstd::string basename(const std::string &path) {\n    return path.substr(path.find_last_of(\"/\\\\\") + 1);\n}\n\nstd::stringstream make_generic_stringstream() {\n    std::stringstream ss;\n    ss.imbue(c_locale);\n    return ss;\n}\n\nstd::string read_binary_file(const std::string& path, bool may_not_exist = false) {\n    FILE* f = fopen(path.c_str(), \"rb\");\n    if (!f) {\n        if (!may_not_exist) {\n            std::cerr << \"Error opening file: \" << path << \" (\" << strerror(errno) << \")\\n\";\n        }\n        return {};\n    }\n\n    fseek(f, 0, SEEK_END);\n    size_t size = ftell(f);\n    fseek(f, 0, SEEK_SET);\n\n    std::string data(size, '\\0');\n    size_t read_size = fread(data.data(), 1, size, f);\n    fclose(f);\n    if (read_size != size) {\n        std::cerr << \"Error reading file: \" << path << \" (\" << strerror(errno) << \")\\n\";\n        return {};\n    }\n\n    return data;\n}\n\nvoid write_binary_file(const std::string& path, const std::string& content) {\n    FILE* f = fopen(path.c_str(), \"wb\");\n    if (!f) {\n        std::cerr << \"Error opening file for writing: \" << path << \" (\" << strerror(errno) << \")\\n\";\n        return;\n    }\n\n    size_t write_size = fwrite(content.data(), 1, content.size(), f);\n    fclose(f);\n    if (write_size != content.size()) {\n        std::cerr << \"Error writing file: \" << path << \" (\" << strerror(errno) << \")\\n\";\n        return;\n    
}\n}\n\nvoid write_file_if_changed(const std::string& path, const std::string& content) {\n    std::string existing = read_binary_file(path, true);\n    if (existing != content) {\n        write_binary_file(path, content);\n    }\n}\n\n\n// variables to track number of compiles in progress\nstatic uint32_t compile_count = 0;\nstatic std::mutex compile_count_mutex;\nstatic std::condition_variable compile_count_cond;\nstatic bool generate_dep_file = true;\n\nvoid decrement_compile_count(uint32_t * count) {\n    if (count) {\n        std::lock_guard<std::mutex> guard(compile_count_mutex);\n        assert(compile_count > 0);\n        compile_count--;\n        compile_count_cond.notify_all();\n    }\n}\n\nusing compile_count_guard = std::unique_ptr<uint32_t, decltype(&decrement_compile_count)>;\n\ncompile_count_guard acquire_compile_slot() {\n    // wait until fewer than N compiles are in progress.\n    // 16 is an arbitrary limit, the goal is to avoid \"failed to create pipe\" errors.\n    uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency()));\n    std::unique_lock<std::mutex> guard(compile_count_mutex);\n    compile_count_cond.wait(guard, [N] { return compile_count < N; });\n    compile_count++;\n    return compile_count_guard(&compile_count, &decrement_compile_count);\n}\n\nvoid string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {\n    std::string target_env = (name.find(\"_cm2\") != std::string::npos) ? 
\"--target-env=vulkan1.3\" : \"--target-env=vulkan1.2\";\n\n    #ifdef _WIN32\n        std::vector<std::string> cmd = {GLSLC, \"-fshader-stage=compute\", target_env, \"\\\"\" + in_path + \"\\\"\", \"-o\", \"\\\"\" + out_path + \"\\\"\"};\n    #else\n        std::vector<std::string> cmd = {GLSLC, \"-fshader-stage=compute\", target_env, in_path, \"-o\", out_path};\n    #endif\n\n    // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734\n    // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344\n    // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860\n    if (!coopmat && name.find(\"bf16\") == std::string::npos && name.find(\"rope\") == std::string::npos) {\n        cmd.push_back(\"-O\");\n    }\n\n    if (dep_file) {\n        cmd.push_back(\"-MD\");\n        cmd.push_back(\"-MF\");\n#ifdef _WIN32\n        cmd.push_back(\"\\\"\" + target_cpp + \".d\\\"\");\n#else\n        cmd.push_back(target_cpp + \".d\");\n#endif\n    }\n\n    #ifdef GGML_VULKAN_SHADER_DEBUG_INFO\n        cmd.push_back(\"-g\");\n    #endif\n\n    for (const auto& define : defines) {\n        cmd.push_back(\"-D\" + define.first + \"=\" + define.second);\n    }\n\n    std::string command;\n    for (const auto& part : cmd) {\n        command += part + \" \";\n    }\n\n    std::string stdout_str, stderr_str;\n    try {\n        // std::cout << \"Executing command: \";\n        // for (const auto& part : cmd) {\n        //     std::cout << part << \" \";\n        // }\n        // std::cout << std::endl;\n\n        execute_command(cmd, stdout_str, stderr_str);\n        if (!stderr_str.empty()) {\n            std::cerr << \"cannot compile \" << name << \"\\n\\n\";\n            for (const auto& part : cmd) {\n                std::cerr << part << \" \";\n            }\n            std::cerr << \"\\n\\n\" << stderr_str << std::endl;\n            return;\n        }\n\n        if (dep_file) 
{\n            // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt\n            std::string dep = read_binary_file(target_cpp + \".d\", true);\n            if (!dep.empty()) {\n                size_t pos = dep.find(out_path);\n                if (pos != std::string::npos) {\n                    dep.replace(pos, out_path.length(), target_cpp);\n                }\n                write_binary_file(target_cpp + \".d\", dep);\n            }\n        }\n\n        std::lock_guard<std::mutex> guard(lock);\n        shader_fnames.push_back(std::make_pair(name, out_path));\n    } catch (const std::exception& e) {\n        std::cerr << \"Error executing command for \" << name << \": \" << e.what() << std::endl;\n    }\n}\n\nstd::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {\n    std::map<std::string, std::string> result = a;\n    result.insert(b.begin(), b.end());\n    return result;\n}\n\nstatic std::vector<std::future<void>> compiles;\nvoid string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {\n    name = name + (f16acc ? \"_f16acc\" : \"\") + (coopmat ? \"_cm1\" : \"\") + (coopmat2 ? \"_cm2\" : (fp16 ? 
\"\" : \"_fp32\"));\n    std::string out_path = join_paths(output_dir, name + \".spv\");\n\n    if (input_filepath == \"\") {\n        // No input source to compile, only generate header for all shaders\n        shader_fnames.push_back(std::pair(name, out_path));\n        return;\n    } else if (basename(input_filepath) != source) {\n        // Only compile shader variants matching the input filename\n        return;\n    }\n\n    compile_count_guard slot = acquire_compile_slot();\n    compiles.push_back(std::async(\n        string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot)));\n    // Don't write the same dep file from multiple processes\n    generate_dep_file = false;\n}\n\nvoid matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {\n    std::string load_vec = coopmat2 ? \"1\" : fp16 ? \"8\" : \"4\";\n    std::string aligned_b_type_f32 = coopmat2 ? \"float\" : fp16 ? \"mat2x4\" : \"vec4\";\n    std::string aligned_b_type_f16 = coopmat2 ? \"float16_t\" : fp16 ? \"f16mat2x4\" : \"f16vec4\";\n\n    std::map<std::string, std::string> base_dict;\n    std::string shader_name = \"matmul\";\n\n    if (matmul_id_type == MatMulIdType::DEFAULT) {\n        base_dict[\"MUL_MAT_ID\"] = \"1\";\n        shader_name = \"matmul_id\";\n    } else if (matmul_id_type == MatMulIdType::SUBGROUP) {\n        base_dict[\"MUL_MAT_ID\"] = \"1\";\n        base_dict[\"MUL_MAT_ID_USE_SUBGROUPS\"] = \"1\";\n        shader_name = \"matmul_id_subgroup\";\n    }\n\n    if (fp16) {\n        base_dict[\"FLOAT16\"] = \"1\";\n    }\n\n    base_dict[\"ACC_TYPE\"     ] = f16acc ? \"float16_t\" : \"float\";\n    base_dict[\"ACC_TYPE_VEC2\"] = f16acc ? \"f16vec2\"   : \"vec2\";\n    if (f16acc) {\n        base_dict[\"ACC_TYPE_MAX\"] = \"float16_t(65504.0)\";\n    }\n\n    if (coopmat) {\n        base_dict[\"COOPMAT\"] = \"1\";\n    }\n\n    const std::string source_name = coopmat2 ? 
\"mul_mm_cm2.comp\" : \"mul_mm.comp\";\n\n    auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string {\n        switch (vec) {\n        case 1:\n            if (t == \"bf16\") {\n                // scalar path promotes to float\n                if (!coopmat && !coopmat2) {\n                    return \"float\";\n                }\n                return \"bfloat16_t\";\n            }\n            if (coopmat2 || fp16) {\n                return \"float16_t\";\n            }\n            return \"float\";\n        case 2:\n            if (t == \"bf16\") {\n                // scalar path promotes to float\n                if (!coopmat && !coopmat2) {\n                    return \"vec2\";\n                }\n                return \"bf16vec2\";\n            }\n            if (coopmat2 || fp16) {\n                return \"f16vec2\";\n            }\n            return \"vec2\";\n        case 4:\n            if (t == \"bf16\") {\n                // scalar path promotes to float\n                if (!coopmat && !coopmat2) {\n                    return \"vec4\";\n                }\n                return \"bf16vec4\";\n            }\n            if (coopmat2 || fp16) {\n                return \"f16vec4\";\n            }\n            return \"vec4\";\n        case 8:\n            if (t == \"bf16\") {\n                // scalar path promotes to float\n                if (!coopmat && !coopmat2) {\n                    return \"mat2x4\";\n                }\n                throw std::runtime_error(\"bf16 vec8 not supported\");\n            }\n            if (coopmat2 || fp16) {\n                return \"f16mat2x4\";\n            }\n            return \"mat2x4\";\n        default:\n            throw std::runtime_error(\"invalid vector size\");\n        }\n    };\n\n    const std::map<std::string, std::string> float_type_dict_f16 = {\n        {\"FLOAT_TYPE\",      FLOAT_TYPE(1, \"f16\")},\n        {\"FLOAT_TYPE_VEC2\", FLOAT_TYPE(2, \"f16\")},\n        
{\"FLOAT_TYPE_VEC4\", FLOAT_TYPE(4, \"f16\")},\n        {\"FLOAT_TYPE_VEC8\", FLOAT_TYPE(8, \"f16\")},\n    };\n\n    // Shaders with f16 B_TYPE\n    string_to_spv(shader_name + \"_f32_f16\",         source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{\"DATA_A_F32\", \"1\"},                                                     {\"B_TYPE\", \"float16_t\"},        {\"D_TYPE\", \"float\"}, }), fp16, coopmat, coopmat2, f16acc);\n    string_to_spv(shader_name + \"_f32_f16_aligned\", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{\"DATA_A_F32\", \"1\"}, {\"LOAD_VEC_A\", load_vec}, {\"LOAD_VEC_B\", load_vec}, {\"B_TYPE\", aligned_b_type_f16}, {\"D_TYPE\", \"float\"}, {\"ALIGNED\", \"1\"}}), fp16, coopmat, coopmat2, f16acc);\n\n    string_to_spv(shader_name + \"_f16\",             source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{\"DATA_A_F16\", \"1\"},                                                     {\"B_TYPE\", \"float16_t\"},        {\"D_TYPE\", \"float\"}}), fp16, coopmat, coopmat2, f16acc);\n    string_to_spv(shader_name + \"_f16_aligned\",     source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{\"DATA_A_F16\", \"1\"}, {\"LOAD_VEC_A\", load_vec}, {\"LOAD_VEC_B\", load_vec}, {\"B_TYPE\", aligned_b_type_f16}, {\"D_TYPE\", \"float\"}, {\"ALIGNED\", \"1\"}}), fp16, coopmat, coopmat2, f16acc);\n\n    // bf16\n    {\n        // For aligned matmul loads\n        std::string load_vec_a = coopmat2 ? \"1\" : \"4\";\n\n        // scalar path promotes to float\n        std::string to_float_type = (coopmat || coopmat2) ? 
\"uintBitsToBFloat16EXT\" : \"bf16_to_fp32\";\n\n        const std::map<std::string, std::string> float_type_dict_bf16 = {\n            {\"FLOAT_TYPE\",      FLOAT_TYPE(1, \"bf16\")},\n            {\"FLOAT_TYPE_VEC2\", FLOAT_TYPE(2, \"bf16\")},\n            {\"FLOAT_TYPE_VEC4\", FLOAT_TYPE(4, \"bf16\")},\n        };\n\n        // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader\n#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)\n        if (!(coopmat || coopmat2))\n#endif\n        {\n            string_to_spv(shader_name + \"_bf16\",         source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{\"TO_FLOAT_TYPE\", to_float_type}, {\"DATA_A_BF16\", \"1\"},                             {\"B_TYPE\", coopmat2 ? \"bfloat16_t\" : \"uint16_t\"}, {\"D_TYPE\", \"float\"}, {\"B_IS_FLOAT\", \"1\"}, {\"DATA_B_BF16\", \"1\"}}),                   fp16, coopmat, coopmat2, f16acc);\n            string_to_spv(shader_name + \"_bf16_aligned\", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{\"TO_FLOAT_TYPE\", to_float_type}, {\"DATA_A_BF16\", \"1\"}, {\"LOAD_VEC_A\", load_vec_a}, {\"LOAD_VEC_B\", \"4\"}, {\"B_TYPE\", coopmat2 ? 
\"bfloat16_t\" : \"u16vec4\"},  {\"D_TYPE\", \"float\"}, {\"B_IS_FLOAT\", \"1\"}, {\"DATA_B_BF16\", \"1\"}, {\"ALIGNED\", \"1\"}}), fp16, coopmat, coopmat2, f16acc);\n        }\n    }\n\n    for (const auto& tname : type_names) {\n        std::string load_vec_quant = \"2\";\n        if ((tname == \"q4_0\") || (tname == \"q4_1\") || (tname == \"q5_1\") || (tname == \"iq1_s\") || (tname == \"iq1_m\") || (tname == \"iq2_xxs\") || (tname == \"iq2_xs\") || (tname == \"iq2_s\"))\n            load_vec_quant = \"8\";\n        else if ((tname == \"q5_0\") || (tname == \"q8_0\") || (tname == \"q2_k\") || (tname == \"q4_k\") || (tname == \"q5_k\") || (tname == \"iq3_xxs\") || (tname == \"iq3_s\") || (tname == \"iq4_nl\") || (tname == \"mxfp4\"))\n            load_vec_quant = \"4\";\n\n        if (tname == \"bf16\") {\n            continue;\n        }\n\n        std::string data_a_key = \"DATA_A_\" + to_uppercase(tname);\n        // For unaligned, load one at a time for f32/f16, or two at a time for quants\n        std::string load_vec_a_unaligned = (coopmat2 || tname == \"f32\" || tname == \"f16\" || tname == \"bf16\") ? \"1\" : load_vec_quant;\n        // For aligned matmul loads\n        std::string load_vec_a = (coopmat2 || tname == \"f32\" || tname == \"f16\" || tname == \"bf16\") ? 
load_vec : load_vec_quant;\n\n        const std::map<std::string, std::string> float_type_dict = {\n            {\"FLOAT_TYPE\",      FLOAT_TYPE(1, tname)},\n            {\"FLOAT_TYPE_VEC2\", FLOAT_TYPE(2, tname)},\n            {\"FLOAT_TYPE_VEC4\", FLOAT_TYPE(4, tname)},\n            {\"FLOAT_TYPE_VEC8\", FLOAT_TYPE(8, tname)},\n        };\n\n        // don't generate f32 variants for coopmat2\n        if (!coopmat2) {\n            string_to_spv(shader_name + \"_\" + tname + \"_f32\",         source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, \"1\"}, {\"LOAD_VEC_A\", load_vec_a_unaligned},                           {\"B_TYPE\", \"float\"},            {\"D_TYPE\", \"float\"}}), fp16, coopmat, coopmat2, f16acc);\n            string_to_spv(shader_name + \"_\" + tname + \"_f32_aligned\", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, \"1\"}, {\"LOAD_VEC_A\", load_vec_a},           {\"LOAD_VEC_B\", load_vec}, {\"B_TYPE\", aligned_b_type_f32}, {\"D_TYPE\", \"float\"}, {\"ALIGNED\", \"1\"}}), fp16, coopmat, coopmat2, f16acc);\n        }\n\n        if (tname != \"f16\" && tname != \"f32\") {\n            string_to_spv(shader_name + \"_\" + tname + \"_f16\",         source_name,  merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, \"1\"}, {\"LOAD_VEC_A\", load_vec_a_unaligned},                           {\"B_TYPE\", \"float16_t\"},        {\"D_TYPE\", \"float\"}}), fp16, coopmat, coopmat2, f16acc);\n            string_to_spv(shader_name + \"_\" + tname + \"_f16_aligned\", source_name,  merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, \"1\"}, {\"LOAD_VEC_A\", load_vec_a},           {\"LOAD_VEC_B\", load_vec}, {\"B_TYPE\", aligned_b_type_f16}, {\"D_TYPE\", \"float\"}, {\"ALIGNED\", \"1\"}}), fp16, coopmat, coopmat2, f16acc);\n        }\n\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n        // Integer dot mmq performs better with f32 accumulators\n        if (!f16acc && !coopmat && 
!coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == \"mxfp4\")) {\n            string_to_spv(shader_name + \"_\" + tname + \"_q8_1\", \"mul_mmq.comp\", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, \"1\"}, {\"D_TYPE\", \"float\"},}), fp16, coopmat, coopmat2, f16acc);\n        }\n#endif\n    }\n}\n\nvoid process_shaders() {\n    // matmul\n    for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {\n        // No coopmats\n        // fp32\n        matmul_shaders(false, matmul_id_type, false, false, false);\n\n        // fp16, fp32acc and fp16acc\n        matmul_shaders(true, matmul_id_type, false, false, false);\n        matmul_shaders(true, matmul_id_type, false, false, true);\n\n        if (matmul_id_type != MatMulIdType::DEFAULT) {\n#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n            // Coopmat, fp32acc and fp16acc\n            matmul_shaders(true, matmul_id_type, true, false, false);\n            matmul_shaders(true, matmul_id_type, true, false, true);\n#endif\n\n#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n            // Coopmat2, fp32acc and fp16acc\n            matmul_shaders(true, matmul_id_type, false, true, false);\n            matmul_shaders(true, matmul_id_type, false, true, true);\n#endif\n        }\n    }\n\n    for (const bool& fp16 : {false, true}) {\n        std::map<std::string, std::string> base_dict;\n        if (fp16) {\n            base_dict = {{\"FLOAT_TYPE\", \"float16_t\"}, {\"FLOAT_TYPEV4\", \"f16vec4\"}, {\"FLOAT16\", \"1\"}, {\"FLOAT_TYPE_MAX\", \"float16_t(65504.0)\"}};\n        } else {\n            base_dict = {{\"FLOAT_TYPE\", \"float\"}, {\"FLOAT_TYPEV4\", \"vec4\"}};\n        }\n\n        // flash attention\n        for (const bool& f16acc : {false, true}) {\n            std::map<std::string, std::string> fa_base_dict = base_dict;\n            fa_base_dict[\"ACC_TYPE\"] = fp16 && f16acc ? 
\"float16_t\" : \"float\";\n            fa_base_dict[\"ACC_TYPEV4\"] = fp16 && f16acc ? \"f16vec4\" : \"vec4\";\n            if (fp16 && f16acc) {\n                fa_base_dict[\"ACC_TYPE_MAX\"] = \"float16_t(65504.0)\";\n            }\n\n            for (const auto& tname : type_names) {\n                if (tname == \"bf16\") continue;\n\n                if (fp16) {\n#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n                if (tname == \"f16\") {\n                    string_to_spv(\"flash_attn_f32_f16_\" + tname, \"flash_attn_cm2.comp\",\n                        merge_maps(fa_base_dict, {{\"Q_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"D_TYPEV4\", \"vec4\"}}), fp16, false, true, f16acc);\n                } else {\n                    std::string data_a_key = \"DATA_A_\" + to_uppercase(tname);\n                    string_to_spv(\"flash_attn_f32_f16_\" + tname, \"flash_attn_cm2.comp\",\n                        merge_maps(fa_base_dict, {{data_a_key, \"1\"}, {\"Q_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"D_TYPEV4\", \"vec4\"}, {\"DEQUANTFUNC\", \"dequantFunc\"+to_uppercase(tname) }, {\"BLOCK_SIZE\", \"QUANT_K_\"+to_uppercase(tname) }}), fp16, false, true, f16acc);\n                }\n#endif\n#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)\n                if (tname == \"f16\") {\n                    string_to_spv(\"flash_attn_f32_f16_\" + tname, \"flash_attn_cm1.comp\",\n                        merge_maps(fa_base_dict, {{\"Q_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"D_TYPEV4\", \"vec4\"}, {\"COOPMAT\", \"1\"}}), fp16, true, false, f16acc);\n                } else if (tname == \"q4_0\" || tname == \"q8_0\" || tname == \"f32\") {\n                    std::string data_a_key = \"DATA_A_\" + to_uppercase(tname);\n                    string_to_spv(\"flash_attn_f32_f16_\" + tname, \"flash_attn_cm1.comp\",\n                        merge_maps(fa_base_dict, {{data_a_key, \"1\"}, {\"Q_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"D_TYPEV4\", \"vec4\"}, 
{\"BLOCK_SIZE\", \"QUANT_K_\"+to_uppercase(tname)}, {\"COOPMAT\", \"1\"}}), fp16, true, false, f16acc);\n                }\n#endif\n                }\n\n                if (tname == \"f16\") {\n                    string_to_spv(\"flash_attn_f32_f16_\" + tname, \"flash_attn.comp\",\n                        merge_maps(fa_base_dict, {{\"Q_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"D_TYPEV4\", \"vec4\"}}), fp16, false, false, f16acc);\n                } else if (tname == \"q4_0\" || tname == \"q8_0\" || tname == \"f32\") {\n                    std::string data_a_key = \"DATA_A_\" + to_uppercase(tname);\n                    string_to_spv(\"flash_attn_f32_f16_\" + tname, \"flash_attn.comp\",\n                        merge_maps(fa_base_dict, {{data_a_key, \"1\"}, {\"Q_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"D_TYPEV4\", \"vec4\"}, {\"BLOCK_SIZE\", \"QUANT_K_\"+to_uppercase(tname) }}), fp16, false, false, f16acc);\n                }\n            }\n        }\n    }\n\n    std::map<std::string, std::string> base_dict = {{\"FLOAT_TYPE\", \"float\"}, {\"FLOAT_TYPE_VEC2\", \"vec2\"}};\n\n    for (const auto& tname : type_names) {\n        // mul mat vec\n        std::string data_a_key = \"DATA_A_\" + to_uppercase(tname);\n        std::string shader = (string_ends_with(tname, \"_k\") || string_starts_with(tname, \"iq1_\") || string_starts_with(tname, \"iq2_\") || string_starts_with(tname, \"iq3_\")) ? 
\"mul_mat_vec_\" + tname + \".comp\" : \"mul_mat_vec.comp\";\n\n        string_to_spv(\"mul_mat_vec_\" + tname + \"_f32_f32\", shader, merge_maps(base_dict, {{data_a_key, \"1\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC2\", \"vec2\"}, {\"B_TYPE_VEC4\", \"vec4\"}, {\"D_TYPE\", \"float\"}}));\n        string_to_spv(\"mul_mat_vec_\" + tname + \"_f16_f32\", shader, merge_maps(base_dict, {{data_a_key, \"1\"}, {\"B_TYPE\", \"float16_t\"}, {\"B_TYPE_VEC2\", \"f16vec2\"}, {\"B_TYPE_VEC4\", \"f16vec4\"}, {\"D_TYPE\", \"float\"}}));\n\n        string_to_spv(\"mul_mat_vec_\" + tname + \"_f32_f32_subgroup\", shader, merge_maps(base_dict, {{data_a_key, \"1\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC2\", \"vec2\"}, {\"B_TYPE_VEC4\", \"vec4\"}, {\"D_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD\", \"1\"}}));\n        string_to_spv(\"mul_mat_vec_\" + tname + \"_f16_f32_subgroup\", shader, merge_maps(base_dict, {{data_a_key, \"1\"}, {\"B_TYPE\", \"float16_t\"}, {\"B_TYPE_VEC2\", \"f16vec2\"}, {\"B_TYPE_VEC4\", \"f16vec4\"}, {\"D_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD\", \"1\"}}));\n\n        string_to_spv(\"mul_mat_vec_\" + tname + \"_f32_f32_subgroup_no_shmem\", shader, merge_maps(base_dict, {{data_a_key, \"1\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC2\", \"vec2\"}, {\"B_TYPE_VEC4\", \"vec4\"}, {\"D_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD_NO_SHMEM\", \"1\"}}));\n        string_to_spv(\"mul_mat_vec_\" + tname + \"_f16_f32_subgroup_no_shmem\", shader, merge_maps(base_dict, {{data_a_key, \"1\"}, {\"B_TYPE\", \"float16_t\"}, {\"B_TYPE_VEC2\", \"f16vec2\"}, {\"B_TYPE_VEC4\", \"f16vec4\"}, {\"D_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD_NO_SHMEM\", \"1\"}}));\n\n        string_to_spv(\"mul_mat_vec_id_\" + tname + \"_f32_f32\", shader, merge_maps(base_dict, {{\"MUL_MAT_ID\", \"1\"}, {data_a_key, \"1\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC2\", \"vec2\"}, {\"B_TYPE_VEC4\", \"vec4\"}, {\"D_TYPE\", \"float\"}}));\n        string_to_spv(\"mul_mat_vec_id_\" + tname + \"_f32_f32_subgroup\", shader, 
merge_maps(base_dict, {{\"MUL_MAT_ID\", \"1\"}, {data_a_key, \"1\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC2\", \"vec2\"}, {\"B_TYPE_VEC4\", \"vec4\"}, {\"D_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD\", \"1\"}}));\n        string_to_spv(\"mul_mat_vec_id_\" + tname + \"_f32_f32_subgroup_no_shmem\", shader, merge_maps(base_dict, {{\"MUL_MAT_ID\", \"1\"}, {data_a_key, \"1\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC2\", \"vec2\"}, {\"B_TYPE_VEC4\", \"vec4\"}, {\"D_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD_NO_SHMEM\", \"1\"}}));\n\n        // mul mat vec with integer dot product\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n        if (is_legacy_quant(tname) || tname == \"mxfp4\" || is_k_quant(tname) || tname == \"iq1_s\" || tname == \"iq1_m\") {\n            string_to_spv(\"mul_mat_vec_\" + tname + \"_q8_1_f32\", \"mul_mat_vecq.comp\", merge_maps(base_dict, {{data_a_key, \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"FLOAT_TYPE_VEC2\", \"vec2\"}, {\"ACC_TYPE\", \"float\"}}));\n            string_to_spv(\"mul_mat_vec_\" + tname + \"_q8_1_f32_subgroup\", \"mul_mat_vecq.comp\", merge_maps(base_dict, {{data_a_key, \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"FLOAT_TYPE_VEC2\", \"vec2\"}, {\"ACC_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD\", \"1\"}}));\n            string_to_spv(\"mul_mat_vec_\" + tname + \"_q8_1_f32_subgroup_no_shmem\", \"mul_mat_vecq.comp\", merge_maps(base_dict, {{data_a_key, \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"FLOAT_TYPE_VEC2\", \"vec2\"}, {\"ACC_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD_NO_SHMEM\", \"1\"}}));\n\n            string_to_spv(\"mul_mat_vec_id_\" + tname + \"_q8_1_f32\", \"mul_mat_vecq.comp\", merge_maps(base_dict, {{\"MUL_MAT_ID\", \"1\"}, {data_a_key, \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"FLOAT_TYPE_VEC2\", \"vec2\"}, {\"ACC_TYPE\", \"float\"}}));\n            string_to_spv(\"mul_mat_vec_id_\" + tname + \"_q8_1_f32_subgroup\", \"mul_mat_vecq.comp\", 
merge_maps(base_dict, {{\"MUL_MAT_ID\", \"1\"}, {data_a_key, \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"FLOAT_TYPE_VEC2\", \"vec2\"}, {\"ACC_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD\", \"1\"}}));\n            string_to_spv(\"mul_mat_vec_id_\" + tname + \"_q8_1_f32_subgroup_no_shmem\", \"mul_mat_vecq.comp\", merge_maps(base_dict, {{\"MUL_MAT_ID\", \"1\"}, {data_a_key, \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"FLOAT_TYPE_VEC2\", \"vec2\"}, {\"ACC_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD_NO_SHMEM\", \"1\"}}));\n        }\n#endif\n\n        // Dequant shaders\n        if (tname != \"f16\" && tname != \"bf16\") {\n            string_to_spv(\"dequant_\" + tname, \"dequant_\" + tname + \".comp\", merge_maps(base_dict, {{data_a_key, \"1\"}, {\"D_TYPE\", \"float16_t\"}}));\n        }\n\n        shader = (tname == \"f32\" || tname == \"f16\" || tname == \"bf16\") ? \"get_rows.comp\" : \"get_rows_quant.comp\";\n\n        if (tname == \"f16\") {\n            string_to_spv(\"get_rows_\" + tname, shader, merge_maps(base_dict, {{\"TEMP_TYPE\", \"FLOAT_TYPE\"}, {data_a_key, \"1\"}, {\"B_TYPE\", \"int\"}, {\"D_TYPE\", \"float16_t\"}, {\"OPTIMIZATION_ERROR_WORKAROUND\", \"1\"}}));\n        } else {\n            string_to_spv(\"get_rows_\" + tname, shader, merge_maps(base_dict, {{\"TEMP_TYPE\", \"FLOAT_TYPE\"}, {data_a_key, \"1\"}, {\"B_TYPE\", \"int\"}, {\"D_TYPE\", \"float16_t\"}}));\n        }\n        string_to_spv(\"get_rows_\" + tname + \"_f32\", shader, merge_maps(base_dict, {{\"TEMP_TYPE\", \"FLOAT_TYPE\"}, {data_a_key, \"1\"}, {\"B_TYPE\", \"int\"}, {\"D_TYPE\", \"float\"}}));\n    }\n\n    string_to_spv(\"get_rows_i32\", \"get_rows.comp\", {{\"TEMP_TYPE\", \"uint\"}, {\"A_TYPE\", \"uint\"}, {\"B_TYPE\", \"int\"}, {\"D_TYPE\", \"uint\"}});\n\n    string_to_spv(\"mul_mat_vec_p021_f16_f32_subgroup_add\", \"mul_mat_vec_p021.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"A_TYPE_VEC4\", \"f16vec4\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC4\", 
\"vec4\"}, {\"D_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD\", \"1\"}});\n    string_to_spv(\"mul_mat_vec_p021_f16_f32\",              \"mul_mat_vec_p021.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"A_TYPE_VEC4\", \"f16vec4\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC4\", \"vec4\"}, {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"mul_mat_vec_nc_f16_f32\", \"mul_mat_vec_nc.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"A_TYPE_VEC4\", \"f16vec4\"}, {\"B_TYPE\", \"float\"}, {\"B_TYPE_VEC4\", \"vec4\"}, {\"D_TYPE\", \"float\"}});\n\n    // Norms\n    string_to_spv(\"norm_f32\", \"norm.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"group_norm_f32\", \"group_norm.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"rms_norm_f32\", \"rms_norm.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"rms_norm_partials_f32\", \"rms_norm_partials.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"rms_norm_mul_rope_f32_f32\", \"rms_norm.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float\"}, {\"RMS_NORM_ROPE_FUSION\", \"1\"}}));\n    string_to_spv(\"rms_norm_mul_rope_f32_f16_rte\", \"rms_norm.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float16_t\"}, {\"RMS_NORM_ROPE_FUSION\", \"1\"}, {\"RTE16\", \"1\"}}));\n    string_to_spv(\"rms_norm_back_f32\", \"rms_norm_back.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"l2_norm_f32\", \"l2_norm.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n\n    string_to_spv(\"cpy_f32_f32\", \"copy.comp\", {{\"A_TYPE\", \"float\"}, 
{\"D_TYPE\", \"float\"}});\n    string_to_spv(\"cpy_f32_f16\", \"copy.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"cpy_f16_f16\", \"copy.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float16_t\"}, {\"OPTIMIZATION_ERROR_WORKAROUND\", \"1\"}});\n    string_to_spv(\"cpy_f16_f32\", \"copy.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float\"}, {\"OPTIMIZATION_ERROR_WORKAROUND\", \"1\"}});\n    string_to_spv(\"cpy_f32_bf16\",\"copy.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"uint16_t\"}, {\"DATA_D_BF16\", \"1\"}});\n    string_to_spv(\"contig_cpy_f32_f32\", \"contig_copy.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"contig_cpy_f32_i32\", \"contig_copy.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"int\"}});\n    string_to_spv(\"contig_cpy_i32_f32\", \"contig_copy.comp\", {{\"A_TYPE\", \"int\"}, {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"contig_cpy_f32_f16\", \"contig_copy.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"contig_cpy_f16_f16\", \"contig_copy.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float16_t\"}, {\"OPTIMIZATION_ERROR_WORKAROUND\", \"1\"}});\n    string_to_spv(\"contig_cpy_f16_f32\", \"contig_copy.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float\"}, {\"OPTIMIZATION_ERROR_WORKAROUND\", \"1\"}});\n    string_to_spv(\"contig_cpy_f32_bf16\",\"contig_copy.comp\",{{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"uint16_t\"}, {\"DATA_D_BF16\", \"1\"}});\n    string_to_spv(\"cpy_f32_i32\", \"copy.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"int\"}});\n    string_to_spv(\"cpy_i32_f32\", \"copy.comp\", {{\"A_TYPE\", \"int\"}, {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"cpy_transpose_16\", \"copy_transpose.comp\", {{\"A_TYPE\", \"uint16_t\"}, {\"D_TYPE\", \"uint16_t\"}});\n    string_to_spv(\"cpy_transpose_32\", \"copy_transpose.comp\", {{\"A_TYPE\", \"uint\"}, {\"D_TYPE\", \"uint\"}});\n\n    for 
(std::string t : {\"q4_0\", \"q4_1\", \"q5_0\", \"q5_1\", \"q8_0\", \"iq4_nl\"}) {\n        string_to_spv(\"cpy_f32_\" + t, \"copy_to_quant.comp\", {{\"DATA_A_\" + to_uppercase(t), \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n        string_to_spv(\"cpy_f32_\" + t + \"_rte\", \"copy_to_quant.comp\", {{\"DATA_A_\" + to_uppercase(t), \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"RTE16\", \"1\"}});\n        string_to_spv(\"cpy_\" + t + \"_f32\", \"copy_from_quant.comp\", {{\"DATA_A_\" + to_uppercase(t), \"1\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n    }\n\n    for (std::string t : {\"f32\", \"f16\", \"bf16\", \"q4_0\", \"q4_1\", \"q5_0\", \"q5_1\", \"q8_0\", \"iq4_nl\"}) {\n        string_to_spv(\"set_rows_\" + t + \"_i32\",     \"copy_to_quant.comp\", {{\"SET_ROWS\", \"1\"}, {\"DATA_A_\" + to_uppercase(t), \"1\"}, {\"B_TYPE\", \"uint\"}, {\"B_SIZE\", \"32\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n        string_to_spv(\"set_rows_\" + t + \"_i32_rte\", \"copy_to_quant.comp\", {{\"SET_ROWS\", \"1\"}, {\"DATA_A_\" + to_uppercase(t), \"1\"}, {\"B_TYPE\", \"uint\"}, {\"B_SIZE\", \"32\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"RTE16\", \"1\"}});\n        string_to_spv(\"set_rows_\" + t + \"_i64\",     \"copy_to_quant.comp\", {{\"SET_ROWS\", \"1\"}, {\"DATA_A_\" + to_uppercase(t), \"1\"}, {\"B_TYPE\", \"uvec2\"}, {\"B_SIZE\", \"64\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n        string_to_spv(\"set_rows_\" + t + \"_i64_rte\", \"copy_to_quant.comp\", {{\"SET_ROWS\", \"1\"}, {\"DATA_A_\" + to_uppercase(t), \"1\"}, {\"B_TYPE\", \"uvec2\"}, {\"B_SIZE\", \"64\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"RTE16\", \"1\"}});\n    }\n\n    auto get_type_str = [](bool f16) {\n        return f16 ? 
\"float16_t\" : \"float\";\n    };\n    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {\n        std::string s;\n        s += std::string(src0_f16 ? \"_f16\" : \"_f32\");\n        s += std::string(src1_f16 ? \"_f16\" : \"_f32\");\n        s += std::string(dst_f16 ? \"_f16\" : \"_f32\");\n        return s;\n    };\n    for (std::string op : {\"add\", \"sub\", \"mul\", \"div\", \"add_rms\", }) {\n    for (auto src0_f16 : {false, true}) {\n    for (auto src1_f16 : {false, true}) {\n    for (auto dst_f16  : {false, true}) {\n    for (auto rte      : {false, true}) {\n        auto source = op == \"add_rms\" ? std::string(\"add\") : op;\n        auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? \"_rte\" : \"\");\n        auto add_rms = op == \"add_rms\" ? \"1\" : \"0\";\n        string_to_spv(name.c_str(), source + \".comp\", {{\"A_TYPE\", get_type_str(src0_f16)}, {\"B_TYPE\", get_type_str(src1_f16)}, {\"D_TYPE\", get_type_str(dst_f16)}, {\"FLOAT_TYPE\", \"float\"}, {\"RTE16\", rte ? 
\"1\" : \"0\"}, {\"ADD_RMS\" , add_rms}});\n    }\n    }\n    }\n    }\n    }\n\n    string_to_spv(\"sub_f32\", \"sub.comp\", {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"acc_f32\", \"acc.comp\", {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"split_k_reduce\", \"mul_mat_split_k_reduce.comp\", {});\n    string_to_spv(\"fa_split_k_reduce\", \"flash_attn_split_k_reduce.comp\", {});\n\n    string_to_spv(\"fa_mask_opt\", \"flash_attn_mask_opt.comp\", {});\n\n    string_to_spv(\"quantize_q8_1\", \"quantize_q8_1.comp\", {});\n    string_to_spv(\"quantize_q8_1_subgroup\", \"quantize_q8_1.comp\", {{\"USE_SUBGROUPS\", \"1\"}});\n\n    string_to_spv(\"quantize_q8_1_x4\", \"quantize_q8_1.comp\", {{\"QBLOCK_X4\", \"1\"}});\n    string_to_spv(\"quantize_q8_1_x4_subgroup\", \"quantize_q8_1.comp\", {{\"QBLOCK_X4\", \"1\"}, {\"USE_SUBGROUPS\", \"1\"}});\n\n    string_to_spv(\"mul_f32\", \"mul.comp\", {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"div_f32\", \"div.comp\", {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"repeat_f32\", \"repeat.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"repeat_back_f32\", \"repeat_back.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"scale_f32\", \"scale.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"sqr_f32\", \"square.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"sqrt_f32\", \"sqrt.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"sin_f32\", \"sin.comp\", 
{{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"cos_f32\", \"cos.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"clamp_f32\", \"clamp.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n\n    string_to_spv(\"pad_f32\", \"pad.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"concat_f32\", \"concat.comp\", {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"concat_f16\", \"concat.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"B_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float16_t\"}, {\"OPTIMIZATION_ERROR_WORKAROUND\", \"1\"}});\n    string_to_spv(\"concat_i32\", \"concat.comp\", {{\"A_TYPE\", \"int\"}, {\"B_TYPE\", \"int\"}, {\"D_TYPE\", \"int\"}});\n\n    string_to_spv(\"upscale_f32\", \"upscale.comp\", {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n\n    for (auto rte : {false, true}) {\n        std::string suffix = rte ? \"_rte\" : \"\";\n        string_to_spv(\"exp_f16\" + suffix,        \"exp.comp\",         {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"},   {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"exp_f32\" + suffix,        \"exp.comp\",         {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}    ,   {\"RTE16\", rte ? \"1\" : \"0\"}});\n\n        string_to_spv(\"log_f16\" + suffix,        \"log.comp\",         {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"},   {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"log_f32\" + suffix,        \"log.comp\",         {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"},       {\"RTE16\", rte ? 
\"1\" : \"0\"}});\n    }\n    string_to_spv(\"gelu_f16\",       \"gelu.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"gelu_f32\",       \"gelu.comp\",        {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"gelu_erf_f16\",   \"gelu_erf.comp\",    {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"gelu_erf_f32\",   \"gelu_erf.comp\",    {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"gelu_quick_f16\", \"gelu_quick.comp\",  {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"gelu_quick_f32\", \"gelu_quick.comp\",  {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"silu_f16\",       \"silu.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"silu_f32\",       \"silu.comp\",        {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"relu_f16\",       \"relu.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"relu_f32\",       \"relu.comp\",        {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"neg_f16\",        \"neg.comp\",         {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"neg_f32\",        \"neg.comp\",         {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"tanh_f16\",       \"tanh.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"tanh_f32\",       \"tanh.comp\",        {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"sigmoid_f16\",    \"sigmoid.comp\",     {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"sigmoid_f32\",    \"sigmoid.comp\",     {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    
string_to_spv(\"hardsigmoid_f16\",\"hardsigmoid.comp\", {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"hardsigmoid_f32\",\"hardsigmoid.comp\", {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"hardswish_f16\",  \"hardswish.comp\",   {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"hardswish_f32\",  \"hardswish.comp\",   {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"abs_f16\",        \"abs.comp\",         {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"abs_f32\",        \"abs.comp\",         {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"elu_f16\",        \"elu.comp\",         {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"elu_f32\",        \"elu.comp\",         {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"xielu_f16\",      \"xielu.comp\",       {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"xielu_f32\",      \"xielu.comp\",       {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"sgn_f16\",        \"sgn.comp\",         {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"sgn_f32\",        \"sgn.comp\",         {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"tri_f16\",        \"tri.comp\",         {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"tri_f32\",        \"tri.comp\",         {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"diag_f16\",       \"diag.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"diag_f32\",       \"diag.comp\",        {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"softplus_f16\",   
\"softplus.comp\",    {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"softplus_f32\",   \"softplus.comp\",    {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"add1_f16_f16\",   \"add1.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"B_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float16_t\"}, {\"FLOAT_TYPE\", \"float\"}});\n    string_to_spv(\"add1_f16_f32\",   \"add1.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float16_t\"}, {\"FLOAT_TYPE\", \"float\"}});\n    string_to_spv(\"add1_f32_f32\",   \"add1.comp\",        {{\"A_TYPE\", \"float\"},       {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n    string_to_spv(\"arange_f32\",     \"arange.comp\",      {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}});\n    string_to_spv(\"fill_f32\",       \"fill.comp\",        {{\"D_TYPE\", \"float\"},       {\"FLOAT_TYPE\", \"float\"}});\n    string_to_spv(\"step_f16\",       \"step.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"step_f32\",       \"step.comp\",        {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"round_f16\",      \"round.comp\",       {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"round_f32\",      \"round.comp\",       {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"ceil_f16\",       \"ceil.comp\",        {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"ceil_f32\",       \"ceil.comp\",        {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"floor_f16\",      \"floor.comp\",       {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"floor_f32\",      \"floor.comp\",       {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n    
string_to_spv(\"trunc_f16\",      \"trunc.comp\",       {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"trunc_f32\",      \"trunc.comp\",       {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"}});\n\n    for (auto rte : {false, true}) {\n        std::string suffix = rte ? \"_rte\" : \"\";\n        string_to_spv(\"geglu_f16\" + suffix,      \"geglu.comp\",       {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"},   {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"geglu_f32\" + suffix,      \"geglu.comp\",       {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"},       {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"reglu_f16\" + suffix,      \"reglu.comp\",       {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"},   {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"reglu_f32\" + suffix,      \"reglu.comp\",       {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"},       {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"swiglu_f16\" + suffix,     \"swiglu.comp\",      {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"},   {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"swiglu_f32\" + suffix,     \"swiglu.comp\",      {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"},       {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"swiglu_oai_f16\" + suffix, \"swiglu_oai.comp\",  {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"},   {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"swiglu_oai_f32\" + suffix, \"swiglu_oai.comp\",  {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"},       {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"geglu_erf_f16\" + suffix,  \"geglu_erf.comp\",   {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"},   {\"RTE16\", rte ? 
\"1\" : \"0\"}});\n        string_to_spv(\"geglu_erf_f32\" + suffix,  \"geglu_erf.comp\",   {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"},       {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"geglu_quick_f16\" + suffix,\"geglu_quick.comp\", {{\"A_TYPE\", \"float16_t\"},   {\"D_TYPE\", \"float16_t\"},   {\"RTE16\", rte ? \"1\" : \"0\"}});\n        string_to_spv(\"geglu_quick_f32\" + suffix,\"geglu_quick.comp\", {{\"A_TYPE\", \"float\"},       {\"D_TYPE\", \"float\"},       {\"RTE16\", rte ? \"1\" : \"0\"}});\n    }\n\n    string_to_spv(\"leaky_relu_f32\", \"leaky_relu.comp\",  {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n    string_to_spv(\"silu_back_f32\",  \"silu_back.comp\",   {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"diag_mask_inf_f32\", \"diag_mask_inf.comp\", {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"soft_max_f32\", \"soft_max.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"soft_max_f32_f16\", \"soft_max.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"soft_max_back_f32\", \"soft_max_back.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n\n    string_to_spv(\"soft_max_large1_f32\", \"soft_max_large1.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"soft_max_large2_f32\", \"soft_max_large2.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"soft_max_large3_f32\", \"soft_max_large3.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"soft_max_large1_f32_f16\", \"soft_max_large1.comp\", 
merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"soft_max_large2_f32_f16\", \"soft_max_large2.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"soft_max_large3_f32_f16\", \"soft_max_large3.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float16_t\"}, {\"D_TYPE\", \"float\"}}));\n\n    string_to_spv(\"rope_norm_f32\", \"rope_norm.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float\"}});\n    string_to_spv(\"rope_norm_f16\", \"rope_norm.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"ROPE_D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"rope_norm_f16_rte\", \"rope_norm.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"ROPE_D_TYPE\", \"float16_t\"}, {\"RTE16\", \"1\"}});\n    string_to_spv(\"rope_norm_f32_f16\", \"rope_norm.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"rope_norm_f32_f16_rte\", \"rope_norm.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float16_t\"}, {\"RTE16\", \"1\"}});\n\n    string_to_spv(\"rope_neox_f32\", \"rope_neox.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float\"}});\n    string_to_spv(\"rope_neox_f16\", \"rope_neox.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"ROPE_D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"rope_neox_f16_rte\", \"rope_neox.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"ROPE_D_TYPE\", \"float16_t\"}, {\"RTE16\", \"1\"}});\n    string_to_spv(\"rope_neox_f32_f16\", \"rope_neox.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"rope_neox_f32_f16_rte\", \"rope_neox.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float16_t\"}, {\"RTE16\", \"1\"}});\n\n    string_to_spv(\"rope_multi_f32\", \"rope_multi.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float\"}});\n    string_to_spv(\"rope_multi_f16\", \"rope_multi.comp\", {{\"A_TYPE\", \"float16_t\"}, 
{\"ROPE_D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"rope_multi_f16_rte\", \"rope_multi.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"ROPE_D_TYPE\", \"float16_t\"}, {\"RTE16\", \"1\"}});\n    string_to_spv(\"rope_multi_f32_f16\", \"rope_multi.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"rope_multi_f32_f16_rte\", \"rope_multi.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float16_t\"}, {\"RTE16\", \"1\"}});\n\n    string_to_spv(\"rope_vision_f32\", \"rope_vision.comp\", {{\"A_TYPE\", \"float\"}, {\"ROPE_D_TYPE\", \"float\"}});\n    string_to_spv(\"rope_vision_f16\", \"rope_vision.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"ROPE_D_TYPE\", \"float16_t\"}});\n    string_to_spv(\"rope_vision_f16_rte\", \"rope_vision.comp\", {{\"A_TYPE\", \"float16_t\"}, {\"ROPE_D_TYPE\", \"float16_t\"}, {\"RTE16\", \"1\"}});\n\n    string_to_spv(\"argsort_f32\", \"argsort.comp\", {{\"A_TYPE\", \"float\"}});\n    string_to_spv(\"argsort_large_f32\", \"argsort_large.comp\", {{\"A_TYPE\", \"float\"}});\n\n    string_to_spv(\"topk_argsort_f32\", \"topk_argsort.comp\", {{\"A_TYPE\", \"float\"}});\n    string_to_spv(\"topk_nary_search_f32\", \"topk_nary_search.comp\", {{\"A_TYPE\", \"float\"}});\n\n    string_to_spv(\"argmax_f32\", \"argmax.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"int\"}}));\n    string_to_spv(\"sum_rows_f32\", \"sum_rows.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"count_equal_i32\", \"count_equal.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"int\"}, {\"B_TYPE\", \"int\"}, {\"D_TYPE\", \"int\"}}));\n    string_to_spv(\"cumsum_f32\", \"cumsum.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"cumsum_multipass1_f32\", \"cumsum_multipass1.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n    string_to_spv(\"cumsum_multipass2_f32\", 
\"cumsum_multipass2.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n\n    string_to_spv(\"count_experts\", \"count_experts.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"uint\"}, {\"D_TYPE\", \"uint\"}}));\n\n    for (std::string dim_str : {\"\", \"_3d\"}) {\n        for (bool bda : {false, true}) {\n            std::string bda_str = bda ? \"_bda\" : \"\";\n            std::string bda_def = bda ? \"1\" : \"0\";\n            string_to_spv(\"im2col\" + dim_str + \"_f32\" + bda_str, \"im2col\" + dim_str + \".comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"D_SIZE\", \"4\"}, {\"BDA\", bda_def}}));\n            string_to_spv(\"im2col\" + dim_str + \"_f32_f16\" + bda_str, \"im2col\" + dim_str + \".comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float16_t\"}, {\"D_SIZE\", \"2\"}, {\"BDA\", bda_def}}));\n            string_to_spv(\"im2col\" + dim_str + \"_f32_f16_rte\" + bda_str, \"im2col\" + dim_str + \".comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float16_t\"}, {\"D_SIZE\", \"2\"}, {\"RTE16\", \"1\"}, {\"BDA\", bda_def}}));\n        }\n    }\n\n    string_to_spv(\"timestep_embedding_f32\", \"timestep_embedding.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n\n    string_to_spv(\"conv_transpose_1d_f32\", \"conv_transpose_1d.comp\", {{\"A_TYPE\", \"float\"},  {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}});\n\n    string_to_spv(\"pool2d_f32\", \"pool2d.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n\n    string_to_spv(\"rwkv_wkv6_f32\", \"wkv6.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}}));\n\n    string_to_spv(\"rwkv_wkv7_f32\", \"wkv7.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}}));\n\n    string_to_spv(\"gated_delta_net_f32\", \"gated_delta_net.comp\", merge_maps(base_dict, {{\"FLOAT_TYPE\", \"float\"}}));\n\n    string_to_spv(\"opt_step_adamw_f32\", 
\"opt_step_adamw.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}}));\n    string_to_spv(\"opt_step_sgd_f32\", \"opt_step_sgd.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}}));\n\n    string_to_spv(\"solve_tri_f32\", \"solve_tri.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n\n    for (auto transpose : {false, true}) {\n        for (auto unroll : {false, true}) {\n            for (auto a_f16 : {false, true}) {\n                std::map<std::string, std::string> defines = {\n                    {\"A_TYPE\", a_f16 ? \"float16_t\" : \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"},\n                    {\"USE_COLLECTIVES\", \"1\"}, {\"UNROLL\", unroll ? \"[[unroll]]\" : \"\"},\n                };\n                if (transpose) defines[\"TRANSPOSE\"] = \"1\";\n                std::string name = std::string(transpose ? \"conv_transpose_2d\": \"conv2d\")\n                    + (a_f16 ? \"_f16\" : \"\") + \"_f32\";\n                string_to_spv(name + (unroll ? 
\"_unroll\" : \"\"), \"conv2d_mm.comp\", defines);\n#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)\n                if (unroll) {\n                    defines[\"COOPMAT2\"] = \"1\";\n                    string_to_spv(name, \"conv2d_mm.comp\", defines, true, false, true);\n                }\n#endif\n            }\n        }\n    }\n\n    string_to_spv(\"conv2d_dw_whcn_f32\", \"conv2d_dw.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"WHCN\", \"1\"}}));\n    string_to_spv(\"conv2d_dw_cwhn_f32\", \"conv2d_dw.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"CWHN\", \"1\"}}));\n    string_to_spv(\"conv2d_dw_whcn_f16_f32\", \"conv2d_dw.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float16_t\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"WHCN\", \"1\"}}));\n    string_to_spv(\"conv2d_dw_cwhn_f16_f32\", \"conv2d_dw.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float16_t\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"CWHN\", \"1\"}}));\n\n    string_to_spv(\"roll_f32\", \"roll.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n\n    string_to_spv(\"add_id_f32\", \"add_id.comp\", merge_maps(base_dict, {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}}));\n\n    string_to_spv(\"multi_add_f32\", \"multi_add.comp\", {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"RTE16\", \"1\"}, {\"ADD_RMS\" , \"0\"}});\n    string_to_spv(\"multi_add_rms_f32\", \"multi_add.comp\", {{\"A_TYPE\", \"float\"}, {\"B_TYPE\", \"float\"}, {\"D_TYPE\", \"float\"}, {\"FLOAT_TYPE\", \"float\"}, {\"RTE16\", \"1\"}, {\"ADD_RMS\" , \"1\"}});\n\n    string_to_spv(\"ssm_scan_f32\",          \"ssm_scan.comp\", {{\"A_TYPE\", \"float\"}});\n    string_to_spv(\"ssm_scan_subgroup_f32\", \"ssm_scan.comp\", {{\"A_TYPE\", \"float\"}, {\"USE_SUBGROUP_ADD\", 
\"1\"}});\n\n    string_to_spv(\"ssm_conv_f32\", \"ssm_conv.comp\", {{\"A_TYPE\", \"float\"}});\n\n    string_to_spv(\"topk_moe_f32\", \"topk_moe.comp\", {});\n\n    for (auto &c : compiles) {\n        c.wait();\n    }\n}\n\nvoid write_output_files() {\n    std::stringstream hdr = make_generic_stringstream();\n    std::stringstream src = make_generic_stringstream();\n\n    hdr << \"#include <cstdint>\\n\\n\";\n    src << \"#include \\\"\" << basename(target_hpp) << \"\\\"\\n\\n\";\n\n    std::sort(shader_fnames.begin(), shader_fnames.end());\n    for (const auto& pair : shader_fnames) {\n        const std::string& name = pair.first;\n        #ifdef _WIN32\n            std::string path = pair.second;\n            std::replace(path.begin(), path.end(), '/', '\\\\' );\n        #else\n            const std::string& path = pair.second;\n        #endif\n\n        hdr << \"extern const uint64_t \" << name << \"_len;\\n\";\n        hdr << \"extern const unsigned char \" << name << \"_data[];\\n\\n\";\n\n        if (input_filepath != \"\") {\n            std::string data = read_binary_file(path);\n            if (data.empty()) {\n                continue;\n            }\n\n            src << \"const uint64_t \" << name << \"_len = \" << data.size() << \";\\n\";\n            src << \"const unsigned char \" << name << \"_data[\" << data.size() << \"] = {\\n\" << std::hex;\n            auto bytes = reinterpret_cast<const uint8_t*>(data.data());\n            for (size_t i = 0; i < data.size(); ++i) {\n                src << \"0x\" << static_cast<int>(bytes[i]) << \",\";\n                if ((i + 1) % 12 == 0) src << \"\\n\";\n            }\n            src << std::dec << \"\\n};\\n\\n\";\n        }\n    }\n\n    std::string suffixes[2] = {\"_f32\", \"_f16\"};\n    for (std::string op : {\"add\", \"sub\", \"mul\", \"div\", \"add_rms\"}) {\n        hdr << \"extern const void * \" << op << \"_data[2][2][2][2];\\n\";\n        hdr << \"extern const uint64_t \" << op << 
\"_len[2][2][2][2];\\n\";\n\n        std::string op_file = op == \"add_rms\" ? \"add.comp\" : std::string(op) + \".comp\";\n        if (basename(input_filepath) != op_file) {\n            continue;\n        }\n        std::stringstream data = make_generic_stringstream();\n        std::stringstream len  = make_generic_stringstream();\n        data << \"const void * \" << op << \"_data[2][2][2][2] = \";\n        len  << \"const uint64_t \" << op << \"_len[2][2][2][2] = \";\n        for (uint32_t t0 = 0; t0 < 2; ++t0) {\n            if (t0 == 0) {\n                data << \"{\";\n                len  << \"{\";\n            }\n            for (uint32_t t1 = 0; t1 < 2; ++t1) {\n                if (t1 == 0) {\n                    data << \"{\";\n                    len  << \"{\";\n                }\n                for (uint32_t t2 = 0; t2 < 2; ++t2) {\n                    if (t2 == 0) {\n                        data << \"{\";\n                        len  << \"{\";\n                    }\n                    for (uint32_t rte = 0; rte < 2; ++rte) {\n                        if (rte == 0) {\n                            data << \"{\";\n                            len  << \"{\";\n                        }\n                        data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? \"_rte\" : \"\");\n                        len  << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? 
\"_rte\" : \"\");\n                        data << \"_data,\";\n                        len  << \"_len,\";\n                        if (rte == 1) {\n                            data << \"}, \";\n                            len  << \"}, \";\n                        }\n                    }\n                    if (t2 == 1) {\n                        data << \"}, \";\n                        len  << \"}, \";\n                    }\n                }\n                if (t1 == 1) {\n                    data << \"}, \";\n                    len  << \"}, \";\n                }\n            }\n            if (t0 == 1) {\n                data << \"};\\n\";\n                len  << \"};\\n\";\n            }\n        }\n        src << data.str();\n        src << len.str();\n    }\n\n    std::vector<std::string> btypes = {\"f16\", \"f32\"};\n\n#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)\n    btypes.push_back(\"q8_1\");\n#endif\n\n    for (const std::string& btype : btypes) {\n    for (const auto& tname : type_names) {\n        if (btype == \"q8_1\" && !is_legacy_quant(tname) && tname != \"mxfp4\" && !is_k_quant(tname) && tname != \"iq1_s\" && tname != \"iq1_m\") {\n            continue;\n        }\n        hdr << \"extern const void * arr_dmmv_\"   << tname << \"_\" << btype << \"_f32_data[3];\\n\";\n        hdr << \"extern const uint64_t arr_dmmv_\" << tname << \"_\" << btype << \"_f32_len[3];\\n\";\n        if (basename(input_filepath) == \"mul_mat_vec.comp\") {\n            src << \"const void * arr_dmmv_\"   << tname << \"_\" << btype << \"_f32_data[3] = {mul_mat_vec_\" << tname << \"_\" << btype << \"_f32_data, mul_mat_vec_\" << tname << \"_\" << btype << \"_f32_subgroup_data, mul_mat_vec_\" << tname << \"_\" << btype << \"_f32_subgroup_no_shmem_data};\\n\";\n            src << \"const uint64_t arr_dmmv_\" << tname << \"_\" << btype << \"_f32_len[3] =  {mul_mat_vec_\" << tname << \"_\" << btype << \"_f32_len,  mul_mat_vec_\" << tname << \"_\" << btype << 
\"_f32_subgroup_len, mul_mat_vec_\"  << tname << \"_\" << btype << \"_f32_subgroup_no_shmem_len};\\n\";\n        }\n\n        if (btype == \"f16\") {\n            continue;\n        }\n        hdr << \"extern const void * arr_dmmv_id_\"   << tname << \"_\" << btype << \"_f32_data[3];\\n\";\n        hdr << \"extern const uint64_t arr_dmmv_id_\" << tname << \"_\" << btype << \"_f32_len[3];\\n\";\n        if (basename(input_filepath) == \"mul_mat_vec.comp\") {\n            src << \"const void * arr_dmmv_id_\"   << tname << \"_\" << btype << \"_f32_data[3] = {mul_mat_vec_id_\" << tname << \"_\" << btype << \"_f32_data, mul_mat_vec_id_\" << tname << \"_\" << btype << \"_f32_subgroup_data, mul_mat_vec_id_\" << tname << \"_\" << btype << \"_f32_subgroup_no_shmem_data};\\n\";\n            src << \"const uint64_t arr_dmmv_id_\" << tname << \"_\" << btype << \"_f32_len[3] =  {mul_mat_vec_id_\" << tname << \"_\" << btype << \"_f32_len,  mul_mat_vec_id_\" << tname << \"_\" << btype << \"_f32_subgroup_len, mul_mat_vec_id_\"  << tname << \"_\" << btype << \"_f32_subgroup_no_shmem_len};\\n\";\n        }\n    }\n    }\n\n    if (input_filepath == \"\") {\n        write_file_if_changed(target_hpp, hdr.str());\n    }\n    if (target_cpp != \"\") {\n        write_binary_file(target_cpp, src.str());\n    }\n}\n\n} // namespace\n\nint main(int argc, char** argv) {\n    std::map<std::string, std::string> args;\n    for (int i = 1; i < argc; ++i) {\n        std::string arg = argv[i];\n        if (arg.rfind(\"--\", 0) == 0) {\n            if (i + 1 < argc && argv[i + 1][0] != '-') {\n                args[arg] = argv[i + 1];\n                ++i;\n            } else {\n                args[arg] = \"\";\n            }\n        }\n    }\n\n    if (args.find(\"--glslc\") != args.end()) {\n        GLSLC = args[\"--glslc\"]; // Path to glslc\n    }\n    if (args.find(\"--source\") != args.end()) {\n        input_filepath = args[\"--source\"]; // The shader source file to compile\n    }\n    if 
(args.find(\"--output-dir\") != args.end()) {\n        output_dir = args[\"--output-dir\"]; // Directory for containing SPIR-V output\n    }\n    if (args.find(\"--target-hpp\") != args.end()) {\n        target_hpp = args[\"--target-hpp\"]; // Path to generated header file\n    }\n    if (args.find(\"--target-cpp\") != args.end()) {\n        target_cpp = args[\"--target-cpp\"]; // Path to generated cpp file\n    }\n\n    if (!directory_exists(output_dir)) {\n        if (!create_directory(output_dir)) {\n            std::cerr << \"Error creating output directory: \" << output_dir << \"\\n\";\n            return EXIT_FAILURE;\n        }\n    }\n\n    process_shaders();\n\n    write_output_files();\n\n    return EXIT_SUCCESS;\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/wkv6.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : require\n\n#define BLOCK_SIZE 64\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout(push_constant) uniform Parameters {\n    uint B;\n    uint T;\n    uint C;\n    uint H;\n};\n\nlayout(binding = 0) readonly buffer KBuf { A_TYPE k[]; };\nlayout(binding = 1) readonly buffer VBuf { A_TYPE v[]; };\nlayout(binding = 2) readonly buffer RBuf { A_TYPE r[]; };\nlayout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; };\nlayout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; };\nlayout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };\nlayout(binding = 6) buffer DstBuf { A_TYPE dst[]; };\n\nshared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE];\n\nvoid main() {\n    const uint head_size = BLOCK_SIZE;\n    const uint batch_id = gl_WorkGroupID.x / H;\n    const uint head_id = gl_WorkGroupID.x % H;\n    const uint tid = gl_LocalInvocationID.x;\n\n    const uint state_size = C * head_size;\n    const uint n_seq_tokens = T / B;\n\n    if (batch_id >= B || head_id >= H) {\n        return;\n    }\n\n    A_TYPE state[BLOCK_SIZE];\n    [[unroll]] for (uint i = 0; i < head_size; i++) {\n        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size\n                          + i * head_size + tid];\n    }\n\n    barrier();\n    _tf[tid] = tf[head_id * head_size + tid];\n    barrier();\n\n    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;\n    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;\n\n    for (uint t = start_t; t < end_t; t += C) {\n        barrier();\n        _k[tid] = k[t];\n        _r[tid] = r[t];\n        _td[tid] = td[t];\n        barrier();\n\n        const A_TYPE v_val = v[t];\n        A_TYPE y = 0.0;\n\n        [[unroll]] for (uint j = 0; j < head_size; j += 4) {\n            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);\n  
          vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);\n            vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);\n            vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);\n            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);\n\n            vec4 kv = k_vec * v_val;\n\n            vec4 temp = tf_vec * kv + s_vec;\n            y += dot(r_vec, temp);\n\n            s_vec = s_vec * td_vec + kv;\n            state[j] = s_vec.x;\n            state[j+1] = s_vec.y;\n            state[j+2] = s_vec.z;\n            state[j+3] = s_vec.w;\n        }\n\n        dst[t] = y;\n    }\n\n    [[unroll]] for (uint i = 0; i < head_size; i++) {\n        dst[T * C + batch_id * state_size + head_id * head_size * head_size\n            + i * head_size + tid] = state[i];\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/wkv7.comp",
    "content": "#version 450\n\n#extension GL_EXT_control_flow_attributes : require\n\n#define BLOCK_SIZE 64\nlayout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;\n\nlayout(push_constant) uniform Parameters {\n    uint B;\n    uint T;\n    uint C;\n    uint H;\n};\n\nlayout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };\nlayout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };\nlayout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };\nlayout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };\nlayout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };\nlayout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };\nlayout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };\nlayout(binding = 7) buffer DstBuf { A_TYPE dst[]; };\n\nshared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];\n\nvoid main() {\n    const uint head_size = BLOCK_SIZE;\n    const uint batch_id = gl_WorkGroupID.x / H;\n    const uint head_id = gl_WorkGroupID.x % H;\n    const uint tid = gl_LocalInvocationID.x;\n\n    const uint state_size = C * head_size;\n    const uint n_seq_tokens = T / B;\n\n    if (batch_id >= B || head_id >= H) {\n        return;\n    }\n\n    A_TYPE state[BLOCK_SIZE];\n    [[unroll]] for (uint i = 0; i < head_size; i++) {\n        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size\n                          + tid * head_size + i];\n    }\n\n    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;\n    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;\n\n    for (uint t = start_t; t < end_t; t += C) {\n        barrier();\n        _r[tid] = r[t];\n        _w[tid] = w[t];\n        _k[tid] = k[t];\n        _a[tid] = a[t];\n        _b[tid] = b[t];\n        barrier();\n\n        A_TYPE sa = 0.0;\n        [[unroll]] for (uint j = 0; j < head_size; j += 4) {\n            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], 
state[j+3]);\n            vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);\n            sa += dot(s_vec, a_vec);\n        }\n\n        const A_TYPE v_val = v[t];\n        A_TYPE y = 0.0;\n\n        [[unroll]] for (uint j = 0; j < head_size; j += 4) {\n            vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);\n            vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);\n            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);\n            vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);\n            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);\n\n            vec4 kv = k_vec * v_val;\n            s_vec = s_vec * w_vec + kv + sa * b_vec;\n            y += dot(r_vec, s_vec);\n\n            state[j] = s_vec.x;\n            state[j+1] = s_vec.y;\n            state[j+2] = s_vec.z;\n            state[j+3] = s_vec.w;\n        }\n\n        dst[t] = y;\n    }\n\n    [[unroll]] for (uint i = 0; i < head_size; i++) {\n        dst[T * C + batch_id * state_size + head_id * head_size * head_size\n            + tid * head_size + i] = state[i];\n    }\n}\n"
  },
  {
    "path": "src/ggml-vulkan/vulkan-shaders/xielu.comp",
    "content": "#version 450\n\n#include \"generic_head.glsl\"\n#include \"types.glsl\"\n\n#extension GL_EXT_control_flow_attributes : enable\n\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\nlayout (binding = 0) readonly buffer X {A_TYPE data_a[];};\nlayout (binding = 1) writeonly buffer D {D_TYPE data_d[];};\n\nvoid main() {\n    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;\n\n    if (i >= p.KX) {\n        return;\n    }\n\n    float x = float(data_a[i]);\n\n    float alpha_n = p.param1;\n    float alpha_p = p.param2;\n    float beta = p.param3;\n    float eps = p.param4;\n\n    if (x > 0.0f) {\n        x = alpha_p * x * x + beta * x;\n    } else {\n        const float min_x_eps = min(x, eps);\n        x = (exp(min_x_eps) - 1 - x) * alpha_n + beta * x;\n    }\n\n    data_d[i] = D_TYPE(x);\n}\n"
  },
  {
    "path": "src/ggml-webgpu/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.13)\n\nfind_package(Python3 REQUIRED)\n\n# Shader locations\nset(SHADER_DIR \"${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders\")\nset(SHADER_OUTPUT_DIR \"${CMAKE_CURRENT_BINARY_DIR}/generated\")\nset(SHADER_HEADER \"${SHADER_OUTPUT_DIR}/ggml-wgsl-shaders.hpp\")\nfile(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})\n\nmessage(STATUS \"Shader output dir: ${SHADER_OUTPUT_DIR}\")\n\n# Find all WGSL files\nfile(GLOB WGSL_SHADER_FILES \"${SHADER_DIR}/*.wgsl\")\n\n# Generate the header using a Python script\nadd_custom_command(\n    OUTPUT ${SHADER_HEADER}\n    COMMAND ${CMAKE_COMMAND} -E echo \"Embedding WGSL shaders to ggml-wgsl-shaders.hpp\"\n    COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}\n    COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8\n        ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py\n            --input_dir \"${SHADER_DIR}\"\n            --output_file \"${SHADER_HEADER}\"\n    DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py\n    VERBATIM\n)\n\nadd_custom_target(generate_shaders DEPENDS ${SHADER_HEADER})\n\nggml_add_backend_library(ggml-webgpu\n    ggml-webgpu.cpp\n    ${SHADER_HEADER}\n    ../../include/ggml-webgpu.h\n)\n\nadd_dependencies(ggml-webgpu generate_shaders)\n\nif(EMSCRIPTEN)\n    set(EMDAWNWEBGPU_DIR \"\" CACHE PATH \"Path to emdawnwebgpu_pkg\")\n\n    if(NOT EMDAWNWEBGPU_DIR)\n        # default built-in port\n        target_compile_options(ggml-webgpu PRIVATE \"--use-port=emdawnwebgpu\")\n        target_link_options(ggml-webgpu INTERFACE \"--use-port=emdawnwebgpu\")\n    else()\n        # custom port\n        target_compile_options(ggml-webgpu PRIVATE \"--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py\")\n        target_link_options(ggml-webgpu INTERFACE \"--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py\")\n    endif()\n\n    if (GGML_WEBGPU_JSPI)\n        target_compile_options(ggml-webgpu PRIVATE 
\"-fwasm-exceptions\")\n        target_link_options(ggml-webgpu INTERFACE \"-sJSPI\" \"-fwasm-exceptions\")\n    else()\n        target_compile_options(ggml-webgpu PRIVATE \"-fexceptions\")\n        target_link_options(ggml-webgpu INTERFACE \"-sASYNCIFY\" \"-fexceptions\")\n    endif()\nelse()\n    find_package(Dawn REQUIRED)\n    set(DawnWebGPU_TARGET dawn::webgpu_dawn)\nendif()\n\nif (GGML_WEBGPU_DEBUG)\n    target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)\n    if(EMSCRIPTEN)\n        target_link_options(ggml-webgpu INTERFACE \"-sASSERTIONS=2\")\n    endif()\nendif()\n\nif (GGML_WEBGPU_CPU_PROFILE)\n    target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_CPU_PROFILE=1)\nendif()\n\nif (GGML_WEBGPU_GPU_PROFILE)\n    target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_GPU_PROFILE=1)\nendif()\n\ntarget_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR})\ntarget_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET})\n"
  },
  {
    "path": "src/ggml-webgpu/ggml-webgpu-shader-lib.hpp",
    "content": "#ifndef GGML_WEBGPU_SHADER_LIB_HPP\n#define GGML_WEBGPU_SHADER_LIB_HPP\n\n#include \"ggml-wgsl-shaders.hpp\"\n#include \"ggml.h\"\n#include \"pre_wgsl.hpp\"\n\n#include <webgpu/webgpu_cpp.h>\n\n#include <algorithm>\n#include <memory>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\n#define GGML_WEBGPU_F16_SIZE_BYTES                   2\n#define GGML_WEBGPU_F32_SIZE_BYTES                   4\n#define GGML_WEBGPU_I32_SIZE_BYTES                   4\n#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u\n#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE     128u\n// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.\n#define GGML_WEBGPU_KV_SEQ_PAD                       256u\n\n#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u\n\n// Matrix multiplication parameters\n\n// Register tiling parameters\n#define WEBGPU_MUL_MAT_TILE_M    8\n#define WEBGPU_MUL_MAT_TILE_N    8\n#define WEBGPU_MUL_MAT_WG_SIZE_M 8\n#define WEBGPU_MUL_MAT_WG_SIZE_N 8\n#define WEBGPU_MUL_MAT_TILE_K    32\n\n// Subgroup matrix parameters\n// The number of subgroups in the M dimension\n#define WEBGPU_MUL_MAT_SUBGROUP_M        2\n// The number of subgroups in the N dimension\n#define WEBGPU_MUL_MAT_SUBGROUP_N        2\n// The number of subgroup matrices each subgroup accumulates over\n#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4\n#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2\n\n// Matrix-vector multiplication parameters\n#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256\n\n// Must be multiple of 4 to work with vectorized paths, and must divide\n// mul_mat_vec wg size\n#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64\n#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K         256\n\n#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64\n#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K         256\n\n// Requires 32 threads per output (wg_size/outputs_per_wg == 32)\n#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8\n// Requires at least two (and multiple of 2) k-quant blocks per 
tile\n#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K         512\n\n// default size for legacy matrix multiplication\n#define WEBGPU_MUL_MAT_WG_SIZE 256\n\n// Same hash combine function as in boost\ntemplate <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {\n    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);\n}\n\nstruct ggml_webgpu_shader_lib_context {\n    ggml_tensor * src0;\n    ggml_tensor * src1;\n    ggml_tensor * src2;\n    ggml_tensor * src3;\n    ggml_tensor * src4;\n    ggml_tensor * dst;\n\n    uint32_t max_wg_size;\n    size_t   wg_mem_limit_bytes       = 0;\n    bool     inplace                  = false;\n    bool     overlap                  = false;\n    bool     src_overlap              = false;\n    bool     supports_subgroup_matrix = false;\n    uint32_t sg_mat_m                 = 0;\n    uint32_t sg_mat_n                 = 0;\n    uint32_t sg_mat_k                 = 0;\n    uint32_t max_subgroup_size        = 0;\n};\n\nstruct webgpu_pipeline {\n    wgpu::ComputePipeline pipeline;\n    std::string           name;\n    std::shared_ptr<void> context = nullptr;\n};\n\nstruct ggml_webgpu_generic_shader_decisions {\n    uint32_t wg_size = 0;\n};\n\n/** Argsort **/\n\nstruct ggml_webgpu_argsort_shader_lib_context {\n    uint32_t max_wg_size;\n    size_t   wg_mem_limit_bytes;\n    int32_t  order;\n};\n\n/** Set Rows **/\n\nstruct ggml_webgpu_set_rows_pipeline_key {\n    int dst_type;\n    int vec4;\n    int i64_idx;\n\n    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {\n        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;\n    }\n};\n\nstruct ggml_webgpu_set_rows_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.dst_type);\n        ggml_webgpu_hash_combine(seed, key.vec4);\n        ggml_webgpu_hash_combine(seed, 
key.i64_idx);\n        return seed;\n    }\n};\n\nstruct ggml_webgpu_set_rows_shader_decisions {\n    bool     vec4;\n    bool     i64_idx;\n    uint32_t wg_size;\n};\n\n/** Get Rows **/\n\nstruct ggml_webgpu_get_rows_pipeline_key {\n    ggml_type src_type;\n    int       vectorized;\n\n    bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {\n        return src_type == other.src_type && vectorized == other.vectorized;\n    }\n};\n\nstruct ggml_webgpu_get_rows_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.src_type);\n        ggml_webgpu_hash_combine(seed, key.vectorized);\n        return seed;\n    }\n};\n\n/** Pad **/\nstruct ggml_webgpu_pad_pipeline_key {\n    bool circular;\n\n    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }\n};\n\nstruct ggml_webgpu_pad_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.circular);\n        return seed;\n    }\n};\n\n/** Scale **/\n\nstruct ggml_webgpu_scale_pipeline_key {\n    int inplace;\n\n    bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }\n};\n\nstruct ggml_webgpu_scale_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.inplace);\n        return seed;\n    }\n};\n\n/** Concat **/\n\nstruct ggml_webgpu_concat_pipeline_key {\n    int type;\n\n    bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }\n};\n\nstruct ggml_webgpu_concat_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.type);\n  
      return seed;\n    }\n};\n\n/** Repeat **/\n\nstruct ggml_webgpu_repeat_pipeline_key {\n    int type;\n\n    bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }\n};\n\nstruct ggml_webgpu_repeat_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.type);\n        return seed;\n    }\n};\n\n/** Binary **/\n\nstruct ggml_webgpu_binary_pipeline_key {\n    int  type;\n    int  op;\n    bool inplace;\n    bool overlap;\n    bool src_overlap;\n\n    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {\n        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&\n               src_overlap == other.src_overlap;\n    }\n};\n\nstruct ggml_webgpu_binary_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.type);\n        ggml_webgpu_hash_combine(seed, key.op);\n        ggml_webgpu_hash_combine(seed, key.inplace);\n        ggml_webgpu_hash_combine(seed, key.overlap);\n        ggml_webgpu_hash_combine(seed, key.src_overlap);\n        return seed;\n    }\n};\n\n/** Unary **/\n\nstruct ggml_webgpu_unary_pipeline_key {\n    int  type;\n    int  op;\n    bool is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella\n    bool inplace;\n\n    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {\n        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;\n    }\n};\n\nstruct ggml_webgpu_unary_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.type);\n        ggml_webgpu_hash_combine(seed, key.op);\n        ggml_webgpu_hash_combine(seed, 
key.is_unary);\n        ggml_webgpu_hash_combine(seed, key.inplace);\n        return seed;\n    }\n};\n\n/** FlashAttention */\n\nstruct ggml_webgpu_flash_attn_pipeline_key {\n    ggml_type kv_type;\n    uint32_t  head_dim_qk;\n    uint32_t  head_dim_v;\n    bool      kv_direct;\n    bool      has_mask;\n    bool      has_sinks;\n    bool      uses_logit_softcap;\n\n    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {\n        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&\n               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&\n               uses_logit_softcap == other.uses_logit_softcap;\n    }\n};\n\nstruct ggml_webgpu_flash_attn_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.kv_type);\n        ggml_webgpu_hash_combine(seed, key.head_dim_qk);\n        ggml_webgpu_hash_combine(seed, key.head_dim_v);\n        ggml_webgpu_hash_combine(seed, key.kv_direct);\n        ggml_webgpu_hash_combine(seed, key.has_mask);\n        ggml_webgpu_hash_combine(seed, key.has_sinks);\n        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);\n        return seed;\n    }\n};\n\nstruct ggml_webgpu_flash_attn_shader_lib_context {\n    ggml_webgpu_flash_attn_pipeline_key key;\n    uint32_t                            sg_mat_m;\n    uint32_t                            sg_mat_n;\n    uint32_t                            sg_mat_k;\n    size_t                              wg_mem_limit_bytes;\n    uint32_t                            max_subgroup_size;\n};\n\nstruct ggml_webgpu_flash_attn_shader_decisions {\n    uint32_t q_tile  = 0;\n    uint32_t kv_tile = 0;\n    uint32_t wg_size = 0;\n};\n\n// This is exposed because it's necessary in supports_op\ninline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,\n                
                                  uint32_t kv_tile,\n                                                  uint32_t head_dim_qk,\n                                                  uint32_t head_dim_v,\n                                                  bool     has_mask,\n                                                  bool     kv_direct) {\n    const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);\n    size_t         f16_elems    = 0;\n    size_t         f32_elems    = 0;\n    f16_elems += q_tile * head_dim_qk;        // q_shmem\n    if (!kv_direct) {\n        f16_elems += kv_tile * max_head_dim;  // kv_shmem\n    }\n    f16_elems += q_tile * head_dim_v;         // o_shmem\n    if (has_mask) {\n        f16_elems += q_tile * kv_tile;        // mask_shmem\n    }\n    f16_elems += q_tile * kv_tile;            // inter_shmem\n    f32_elems += q_tile;                      // row_max_shmem\n    f32_elems += q_tile;                      // exp_sum_shmem\n    return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;\n}\n\n/** Matrix Multiplication **/\n\nstruct ggml_webgpu_legacy_mul_mat_pipeline_key {\n    ggml_type src0_type;\n    ggml_type src1_type;\n\n    bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {\n        return src0_type == other.src0_type && src1_type == other.src1_type;\n    }\n};\n\nstruct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.src0_type);\n        ggml_webgpu_hash_combine(seed, key.src1_type);\n        return seed;\n    }\n};\n\nstruct ggml_webgpu_mul_mat_vec_pipeline_key {\n    ggml_type src0_type;\n    ggml_type src1_type;\n    int       vectorized;\n\n    bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {\n        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == 
other.vectorized;\n    }\n};\n\nstruct ggml_webgpu_mul_mat_vec_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.src0_type);\n        ggml_webgpu_hash_combine(seed, key.src1_type);\n        ggml_webgpu_hash_combine(seed, key.vectorized);\n        return seed;\n    }\n};\n\nstruct ggml_webgpu_mul_mat_vec_shader_decisions {\n    uint32_t wg_size;\n    uint32_t tile_k;\n    uint32_t outputs_per_wg;\n    uint32_t vec_size;\n};\n\nstruct ggml_webgpu_mul_mat_pipeline_key {\n    ggml_type src0_type;\n    ggml_type src1_type;\n    int       vectorized;\n    int       use_subgroup_matrix;\n\n    bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {\n        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&\n               use_subgroup_matrix == other.use_subgroup_matrix;\n    }\n};\n\nstruct ggml_webgpu_mul_mat_pipeline_key_hash {\n    size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {\n        size_t seed = 0;\n        ggml_webgpu_hash_combine(seed, key.src0_type);\n        ggml_webgpu_hash_combine(seed, key.src1_type);\n        ggml_webgpu_hash_combine(seed, key.vectorized);\n        ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);\n        return seed;\n    }\n};\n\nstruct ggml_webgpu_mul_mat_shader_decisions {\n    uint32_t tile_k;\n    uint32_t wg_size_m;\n    uint32_t wg_size_n;\n    uint32_t wg_size;\n    uint32_t outputs_per_wg;\n    int      use_subgroup_matrix;\n\n    uint32_t tile_m;\n    uint32_t tile_n;\n\n    // Subgroup matrix parameters\n    uint32_t subgroup_m;\n    uint32_t subgroup_n;\n    uint32_t subgroup_matrix_m;\n    uint32_t subgroup_matrix_n;\n\n    uint32_t mul_mat_wg_size;\n};\n\nclass ggml_webgpu_shader_lib {\n    wgpu::Device           device;\n    pre_wgsl::Preprocessor preprocessor;\n\n    std::unordered_map<int, 
webgpu_pipeline> sum_rows_pipelines;       // key is fixed, no variants yet\n    std::unordered_map<int, webgpu_pipeline> argmax_pipelines;         // key is vec4\n    std::unordered_map<int, webgpu_pipeline> argsort_pipelines;        // key is order\n    std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines;  // key is order\n    std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;         // key is fixed, no variants yet\n    std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>\n        get_rows_pipelines;                                            // src_type, vectorized\n    std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>\n        unary_pipelines;                                               // type/op/inplace\n    std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>\n        scale_pipelines;                                               // inplace\n    std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>\n        pad_pipelines;                                                 // circular/non-circular\n    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>\n        binary_pipelines;                                              // type/op/inplace/overlap\n    std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>\n        concat_pipelines;                                              // type\n    std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>\n        repeat_pipelines;                                              // type\n    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>\n        
flash_attn_pipelines;\n    std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,\n                       webgpu_pipeline,\n                       ggml_webgpu_legacy_mul_mat_pipeline_key_hash>\n        mul_mat_legacy_pipelines;  // legacy mul_mat (non-subgroup/non-regtile/non-vec)\n    std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>\n        mul_mat_vec_pipelines;     // fast mat-vec (n==1)\n    std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>\n        mul_mat_fast_pipelines;    // fast mat-mat (reg-tile or subgroup)\n\n    std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>\n        set_rows_pipelines;\n\n  public:\n    ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }\n\n    webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        auto it = sum_rows_pipelines.find(1);\n        if (it != sum_rows_pipelines.end()) {\n            return it->second;\n        }\n        std::vector<std::string> defines;\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed        = preprocessor.preprocess(wgsl_sum_rows, defines);\n        sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, \"sum_rows\");\n        return sum_rows_pipelines[1];\n    }\n\n    webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        bool vec4 = context.src0->ne[0] % 4 == 0;\n\n        auto it = argmax_pipelines.find(vec4);\n        if (it != argmax_pipelines.end()) {\n            return it->second;\n        }\n        std::string              variant = \"argmax\";\n        std::vector<std::string> defines;\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n        if (vec4) {\n            
defines.push_back(\"VEC4\");\n            variant += \"_vec4\";\n        }\n\n        auto processed         = preprocessor.preprocess(wgsl_argmax, defines);\n        argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant);\n        return argmax_pipelines.at(vec4);\n    }\n\n    webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,\n                                                  .vec4     = context.src0->ne[0] % 4 == 0,\n                                                  .i64_idx  = context.src1->type == GGML_TYPE_I64 };\n\n        auto it = set_rows_pipelines.find(key);\n        if (it != set_rows_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"set_rows\";\n\n        switch (context.dst->type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"DST_F32\");\n                variant += \"_dstf32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"DST_F16\");\n                variant += \"_dstf16\";\n                break;\n            default:\n                GGML_ABORT(\"Unsupported dst type for set_rows shader\");\n        }\n\n        if (key.vec4) {\n            defines.push_back(\"VEC4\");\n            variant += \"_vec4\";\n        }\n        if (key.i64_idx) {\n            defines.push_back(\"I64_IDX\");\n            variant += \"_i64idx\";\n        }\n\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed                  = preprocessor.preprocess(wgsl_set_rows, defines);\n        auto decisions                  = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();\n        decisions->vec4                 = key.vec4;\n        decisions->i64_idx              = key.i64_idx;\n      
  decisions->wg_size              = context.max_wg_size;\n        set_rows_pipelines[key]         = ggml_webgpu_create_pipeline(device, processed, variant);\n        set_rows_pipelines[key].context = decisions;\n        return set_rows_pipelines[key];\n    }\n\n    webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        auto it = cumsum_pipelines.find(1);\n        if (it != cumsum_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed      = preprocessor.preprocess(wgsl_cumsum, defines);\n        cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, \"cumsum\");\n        return cumsum_pipelines[1];\n    }\n\n    webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        bool          is_top_k = context.dst->op == GGML_OP_TOP_K;\n        // ascending order is 0, descending order is 1\n        const int32_t order =\n            is_top_k ? 
(int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);\n\n        auto it = argsort_pipelines.find(order);\n        if (it != argsort_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"argsort\";\n        defines.push_back(std::string(\"ORDER=\") + std::to_string(order));\n        variant += std::string(\"_order\") + std::to_string(order);\n        uint32_t wg_size = 1;\n        while (wg_size * 2 <= context.max_wg_size &&\n               wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {\n            wg_size *= 2;\n        }\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(wg_size));\n        auto processed                   = preprocessor.preprocess(wgsl_argsort, defines);\n        auto decisions                   = std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size               = wg_size;\n        argsort_pipelines[order]         = ggml_webgpu_create_pipeline(device, processed, variant);\n        argsort_pipelines[order].context = decisions;\n        return argsort_pipelines[order];\n    }\n\n    webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        bool          is_top_k = context.dst->op == GGML_OP_TOP_K;\n        // ascending order is 0, descending order is 1\n        const int32_t order =\n            is_top_k ? 
(int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);\n\n        auto it = argsort_merge_pipelines.find(order);\n        if (it != argsort_merge_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"argsort_merge\";\n        defines.push_back(std::string(\"ORDER=\") + std::to_string(order));\n        variant += std::string(\"_order\") + std::to_string(order);\n        uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(wg_size));\n\n        auto processed                 = preprocessor.preprocess(wgsl_argsort_merge, defines);\n        argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);\n        return argsort_merge_pipelines[order];\n    }\n\n    webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        const bool vectorized                 = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;\n        ggml_webgpu_get_rows_pipeline_key key = {\n            .src_type   = context.src0->type,\n            .vectorized = (int) vectorized,\n        };\n\n        auto it = get_rows_pipelines.find(key);\n        if (it != get_rows_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"get_rows\";\n\n        const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type);\n        const char *                    type_str    = type_traits->type_name;\n\n        switch (key.src_type) {\n            case GGML_TYPE_F32:\n                if (key.vectorized) {\n                    defines.push_back(\"F32_VEC\");\n                    defines.push_back(\"SRC_TYPE=vec4<f32>\");\n                    defines.push_back(\"DST_TYPE=vec4<f32>\");\n       
             defines.push_back(\"BLOCK_SIZE=4u\");\n                } else {\n                    defines.push_back(\"F32\");\n                    defines.push_back(\"SRC_TYPE=f32\");\n                    defines.push_back(\"DST_TYPE=f32\");\n                    defines.push_back(\"BLOCK_SIZE=1u\");\n                }\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"F16\");\n                defines.push_back(\"SRC_TYPE=f16\");\n                defines.push_back(\"DST_TYPE=f32\");\n                defines.push_back(\"BLOCK_SIZE=1u\");\n                variant += \"_f16\";\n                break;\n            case GGML_TYPE_I32:\n                defines.push_back(\"I32\");\n                defines.push_back(\"SRC_TYPE=i32\");\n                defines.push_back(\"DST_TYPE=i32\");\n                defines.push_back(\"BLOCK_SIZE=1u\");\n                variant += \"_i32\";\n                break;\n            default:\n                {\n                    std::string type_upper = type_str;\n                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);\n\n                    defines.push_back(\"BYTE_HELPERS\");\n                    defines.push_back(type_upper + \"_T\");\n                    defines.push_back(type_upper);\n                    defines.push_back(type_upper + \"_SCALE_MIN\");\n                    defines.push_back(type_upper + \"_TABLES\");\n                    defines.push_back(type_upper + \"_GRID\");\n\n                    variant += \"_\";\n                    variant += type_str;\n\n                    defines.push_back(std::string(\"SRC_TYPE=\") + type_str);\n                    defines.push_back(\"DST_TYPE=f32\");\n\n                    if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||\n                        key.src_type == GGML_TYPE_IQ4_NL) {\n                        
defines.push_back(\"BLOCK_SIZE=32u\");\n                    } else if (key.src_type >= GGML_TYPE_Q2_K) {\n                        defines.push_back(\"BLOCK_SIZE=256u\");\n                    } else {\n                        defines.push_back(\"BLOCK_SIZE=1u\");\n                    }\n                    break;\n                }\n        }\n\n        if (key.vectorized) {\n            variant += \"_vec\";\n        }\n\n        defines.push_back(\"WG_SIZE=\" + std::to_string(context.max_wg_size));\n\n        auto processed           = preprocessor.preprocess(wgsl_get_rows, defines);\n        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size       = context.max_wg_size;\n        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context         = decisions;\n        get_rows_pipelines[key]  = pipeline;\n        return get_rows_pipelines[key];\n    }\n\n    webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };\n\n        auto it = scale_pipelines.find(key);\n        if (it != scale_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"scale\";\n\n        if (key.inplace) {\n            defines.push_back(\"INPLACE\");\n            variant += \"_inplace\";\n        }\n\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed           = preprocessor.preprocess(wgsl_scale, defines);\n        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size       = context.max_wg_size;\n        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context         = decisions;\n        
scale_pipelines[key]     = pipeline;\n        return scale_pipelines[key];\n    }\n\n    webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };\n\n        auto it = pad_pipelines.find(key);\n        if (it != pad_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"pad\";\n\n        if (key.circular) {\n            defines.push_back(\"CIRCULAR\");\n            variant += \"_circular\";\n        }\n\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed           = preprocessor.preprocess(wgsl_pad, defines);\n        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size       = context.max_wg_size;\n        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context         = decisions;\n        pad_pipelines[key]       = pipeline;\n        return pad_pipelines[key];\n    }\n\n    webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_mul_mat_vec_pipeline_key key = {\n            .src0_type  = context.src0->type,\n            .src1_type  = context.src1->type,\n            // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float\n            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&\n                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?\n                              1 :\n                              0,\n        };\n\n        auto it = mul_mat_vec_pipelines.find(key);\n        if (it != mul_mat_vec_pipelines.end()) {\n            return it->second;\n        }\n\n        
std::vector<std::string> defines;\n        std::string              variant = \"mul_mat_vec\";\n\n        // src0 type (matrix row)\n        switch (context.src0->type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"SRC0_INNER_TYPE=f32\");\n                defines.push_back(\"MUL_ACC_FLOAT\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"SRC0_INNER_TYPE=f16\");\n                defines.push_back(\"MUL_ACC_FLOAT\");\n                variant += \"_f16\";\n                break;\n            default:\n                {\n                    // Quantized types: use helpers but accumulate in f16\n                    const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);\n                    std::string                     src0_name   = src0_traits->type_name;\n                    std::string                     type_upper  = src0_name;\n                    variant += \"_\" + src0_name;\n                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);\n\n                    defines.push_back(\"BYTE_HELPERS\");\n                    defines.push_back(\"MUL_ACC_\" + type_upper);\n\n                    // For fast path we always dequantize from f16 inside the shader\n                    defines.push_back(\"SRC0_INNER_TYPE=f16\");\n                    break;\n                }\n        }\n\n        // src1 type (vector)\n        switch (context.src1->type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"SRC1_INNER_TYPE=f32\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"SRC1_INNER_TYPE=f16\");\n                variant += \"_f16\";\n                break;\n            default:\n                GGML_ABORT(\"Unsupported src1 type for mul_mat_vec shader\");\n        }\n\n        // 
VEC/SCALAR controls\n        defines.push_back(key.vectorized ? \"VEC\" : \"SCALAR\");\n\n        uint32_t wg_size        = WEBGPU_MUL_MAT_VEC_WG_SIZE;\n        uint32_t tile_k         = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;\n        uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;\n\n        if (key.src0_type >= GGML_TYPE_Q2_K) {\n            tile_k         = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;\n            outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;\n        } else if (key.src0_type >= GGML_TYPE_Q4_0) {\n            tile_k         = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;\n            outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;\n        }\n\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(wg_size));\n        defines.push_back(std::string(\"TILE_K=\") + std::to_string(tile_k));\n        defines.push_back(std::string(\"OUTPUTS_PER_WG=\") + std::to_string(outputs_per_wg));\n\n        auto processed            = preprocessor.preprocess(wgsl_mul_mat_vec, defines);\n        auto decisions            = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();\n        decisions->wg_size        = wg_size;\n        decisions->tile_k         = tile_k;\n        decisions->outputs_per_wg = outputs_per_wg;\n        decisions->vec_size       = key.vectorized ? 
4 : 1;\n\n        webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context           = decisions;\n        mul_mat_vec_pipelines[key] = pipeline;\n        return mul_mat_vec_pipelines[key];\n    }\n\n    webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_mul_mat_pipeline_key key = {\n            .src0_type  = context.src0->type,\n            .src1_type  = context.src1->type,\n            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&\n                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?\n                              1 :\n                              0,\n            .use_subgroup_matrix = context.supports_subgroup_matrix\n        };\n\n        auto it = mul_mat_fast_pipelines.find(key);\n        if (it != mul_mat_fast_pipelines.end()) {\n            return it->second;\n        }\n\n        const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile;\n        std::vector<std::string> defines;\n        std::string              variant = key.use_subgroup_matrix ? 
\"mul_mat_subgroup_matrix\" : \"mul_mat_reg_tile\";\n\n        // src1 type\n        switch (context.src1->type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"SRC1_INNER_TYPE=f32\");\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"SRC1_INNER_TYPE=f16\");\n                break;\n            default:\n                GGML_ABORT(\"Unsupported src1 type for mul_mat fast shader\");\n        }\n\n        // src0 type\n        const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);\n        const char *                    src0_name   = src0_traits->type_name;\n\n        switch (context.src0->type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"SRC0_INNER_TYPE=f32\");\n                defines.push_back(\"FLOAT\");\n                defines.push_back(\"MUL_ACC_FLOAT\");\n                defines.push_back(\"INIT_SRC0_SHMEM_FLOAT\");\n                defines.push_back(\"INIT_SRC1_SHMEM_FLOAT\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"SRC0_INNER_TYPE=f16\");\n                defines.push_back(\"FLOAT\");\n                defines.push_back(\"MUL_ACC_FLOAT\");\n                defines.push_back(\"INIT_SRC0_SHMEM_FLOAT\");\n                defines.push_back(\"INIT_SRC1_SHMEM_FLOAT\");\n                variant += \"_f16\";\n                break;\n            default:\n                {\n                    std::string type_upper = src0_name;\n                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);\n\n                    defines.push_back(\"BYTE_HELPERS\");\n                    defines.push_back(\"MUL_ACC_\" + type_upper);\n                    defines.push_back(\"INIT_SRC0_SHMEM_\" + type_upper);\n                    defines.push_back(\"INIT_SRC1_SHMEM_FLOAT\");\n\n                    // Use f16 inside the shader 
for quantized types\n                    defines.push_back(\"SRC0_INNER_TYPE=f16\");\n\n                    variant += std::string(\"_\") + src0_name;\n                    break;\n                }\n        }\n\n        // VEC/SCALAR controls\n        defines.push_back(key.vectorized ? \"VEC\" : \"SCALAR\");\n\n        // Tiles\n        defines.push_back(\"TILE_M=\" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + \"u\");\n        defines.push_back(\"TILE_N=\" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + \"u\");\n        defines.push_back(\"TILE_K=\" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + \"u\");\n\n        // Subgroup matrix specifics\n        if (key.use_subgroup_matrix) {\n            defines.push_back(\"MAX_SUBGROUP_SIZE=\" + std::to_string(context.max_subgroup_size) + \"u\");\n            defines.push_back(\"SUBGROUP_M=\" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + \"u\");\n            defines.push_back(\"SUBGROUP_N=\" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + \"u\");\n            defines.push_back(\"SUBGROUP_MATRIX_M=\" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + \"u\");\n            defines.push_back(\"SUBGROUP_MATRIX_N=\" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + \"u\");\n            defines.push_back(\"SUBGROUP_MATRIX_M_SIZE=\" + std::to_string(context.sg_mat_m) + \"u\");\n            defines.push_back(\"SUBGROUP_MATRIX_N_SIZE=\" + std::to_string(context.sg_mat_n) + \"u\");\n            defines.push_back(\"SUBGROUP_MATRIX_K_SIZE=\" + std::to_string(context.sg_mat_k) + \"u\");\n        }\n\n        // variant suffix for src1 type\n        variant += std::string(\"_\") + (context.src1->type == GGML_TYPE_F32 ? 
\"f32\" : \"f16\");\n        if (key.vectorized) {\n            variant += \"_vectorized\";\n        }\n\n        if (!key.use_subgroup_matrix) {\n            defines.push_back(\"WORKGROUP_SIZE_M=\" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + \"u\");\n            defines.push_back(\"WORKGROUP_SIZE_N=\" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + \"u\");\n        }\n\n        auto processed = preprocessor.preprocess(shader_src, defines);\n\n        auto decisions                 = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();\n        decisions->tile_k              = WEBGPU_MUL_MAT_TILE_K;\n        decisions->tile_m              = WEBGPU_MUL_MAT_TILE_M;\n        decisions->tile_n              = WEBGPU_MUL_MAT_TILE_N;\n        decisions->use_subgroup_matrix = key.use_subgroup_matrix;\n        if (key.use_subgroup_matrix) {\n            decisions->subgroup_m        = WEBGPU_MUL_MAT_SUBGROUP_M;\n            decisions->subgroup_n        = WEBGPU_MUL_MAT_SUBGROUP_N;\n            decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;\n            decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;\n            decisions->wg_size           = context.max_subgroup_size;\n        } else {\n            decisions->wg_size_m       = WEBGPU_MUL_MAT_WG_SIZE_M;\n            decisions->wg_size_n       = WEBGPU_MUL_MAT_WG_SIZE_N;\n            decisions->wg_size         = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;\n            decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;\n        }\n\n        webgpu_pipeline pipeline    = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context            = decisions;\n        mul_mat_fast_pipelines[key] = pipeline;\n        return mul_mat_fast_pipelines[key];\n    }\n\n    webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,\n                 
                                       .src1_type = context.src1->type };\n\n        auto it = mul_mat_legacy_pipelines.find(key);\n        if (it != mul_mat_legacy_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"mul_mat\";\n\n        switch (context.src1->type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"SRC1_TYPE=f32\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"SRC1_TYPE=f16\");\n                variant += \"_f16\";\n                break;\n            default:\n                GGML_ABORT(\"Unsupported src1 type for mul_mat legacy shader\");\n        }\n\n        const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);\n        const char *                    src0_name   = src0_traits->type_name;\n\n        switch (context.src0->type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"SRC0_TYPE=f32\");\n                defines.push_back(\"FLOAT\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"SRC0_TYPE=f16\");\n                defines.push_back(\"FLOAT\");\n                variant += \"_f16\";\n                break;\n            default:\n                {\n                    // quantized types\n                    std::string type_upper = src0_name;\n                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);\n\n                    defines.push_back(std::string(\"SRC0_TYPE=\") + src0_name);\n                    defines.push_back(\"BYTE_HELPERS\");\n                    defines.push_back(type_upper + \"_T\");\n                    defines.push_back(type_upper);\n                    defines.push_back(type_upper + \"_SCALE_MIN\");\n                    
defines.push_back(type_upper + \"_TABLES\");\n                    defines.push_back(type_upper + \"_GRID\");\n\n                    variant += std::string(\"_\") + src0_name;\n                    break;\n                }\n        }\n\n        auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);\n\n        auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;\n\n        webgpu_pipeline pipeline      = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context              = decisions;\n        mul_mat_legacy_pipelines[key] = pipeline;\n        return mul_mat_legacy_pipelines[key];\n    }\n\n    webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        const bool                     is_unary = context.dst->op == GGML_OP_UNARY;\n        const int                      op       = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;\n        ggml_webgpu_unary_pipeline_key key      = {\n                 .type     = context.dst->type,\n                 .op       = op,\n                 .is_unary = is_unary,\n                 .inplace  = context.inplace,\n        };\n\n        auto it = unary_pipelines.find(key);\n        if (it != unary_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant =\n            key.is_unary ? 
ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op);\n        defines.push_back(variant);\n\n        switch (key.type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"TYPE_F32\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"TYPE_F16\");\n                variant += \"_f16\";\n                break;\n            default:\n                GGML_ABORT(\"Unsupported type for unary shader\");\n        }\n\n        if (key.inplace) {\n            defines.push_back(\"INPLACE\");\n            variant += \"_inplace\";\n        }\n\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed           = preprocessor.preprocess(wgsl_unary, defines);\n        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size       = context.max_wg_size;\n        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context         = decisions;\n        unary_pipelines[key]     = pipeline;\n        return unary_pipelines[key];\n    }\n\n    webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_binary_pipeline_key key = {\n            .type        = context.dst->type,\n            .op          = context.dst->op,\n            .inplace     = context.inplace,\n            .overlap     = context.overlap,\n            .src_overlap = context.src_overlap,\n        };\n\n        auto it = binary_pipelines.find(key);\n        if (it != binary_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              op_name = ggml_op_name((ggml_op) key.op);\n        std::string              variant = op_name;\n\n        defines.push_back(std::string(\"OP_\") + op_name);\n\n        switch 
(key.type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"TYPE_F32\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"TYPE_F16\");\n                variant += \"_f16\";\n                break;\n            default:\n                GGML_ABORT(\"Unsupported type for binary shader\");\n        }\n\n        if (key.inplace) {\n            defines.push_back(\"INPLACE\");\n            variant += \"_inplace\";\n        } else if (key.overlap) {\n            defines.push_back(\"OVERLAP\");\n            variant += \"_overlap\";\n        } else if (key.src_overlap) {\n            defines.push_back(\"SRC_OVERLAP\");\n            variant += \"_src_overlap\";\n        }\n\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed           = preprocessor.preprocess(wgsl_binary, defines);\n        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size       = context.max_wg_size;\n        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context         = decisions;\n        binary_pipelines[key]    = pipeline;\n        return binary_pipelines[key];\n    }\n\n    webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_concat_pipeline_key key = {\n            .type = context.dst->type,\n        };\n\n        auto it = concat_pipelines.find(key);\n        if (it != concat_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"concat\";\n\n        switch (key.type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"TYPE_F32\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_I32:\n                
defines.push_back(\"TYPE_I32\");\n                variant += \"_i32\";\n                break;\n            default:\n                GGML_ABORT(\"Unsupported type for concat shader\");\n        }\n\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed           = preprocessor.preprocess(wgsl_concat, defines);\n        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size       = context.max_wg_size;\n        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context         = decisions;\n        concat_pipelines[key]    = pipeline;\n        return concat_pipelines[key];\n    }\n\n    webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        ggml_webgpu_repeat_pipeline_key key = {\n            .type = context.dst->type,\n        };\n\n        auto it = repeat_pipelines.find(key);\n        if (it != repeat_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"repeat\";\n\n        switch (key.type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"TYPE_F32\");\n                variant += \"_f32\";\n                break;\n            case GGML_TYPE_I32:\n                defines.push_back(\"TYPE_I32\");\n                variant += \"_i32\";\n                break;\n            case GGML_TYPE_I16:\n                defines.push_back(\"TYPE_I16\");\n                variant += \"_i16\";\n                break;\n            default:\n                GGML_ABORT(\"Unsupported type for repeat shader\");\n        }\n\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(context.max_wg_size));\n\n        auto processed           = preprocessor.preprocess(wgsl_repeat, defines);\n        auto decisions           = 
std::make_shared<ggml_webgpu_generic_shader_decisions>();\n        decisions->wg_size       = context.max_wg_size;\n        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context         = decisions;\n        repeat_pipelines[key]    = pipeline;\n        return repeat_pipelines[key];\n    }\n\n    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {\n        const bool has_mask  = context.src3 != nullptr;\n        const bool has_sinks = context.src4 != nullptr;\n\n        bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&\n                         (context.src1->ne[1] % context.sg_mat_n == 0);\n\n        ggml_webgpu_flash_attn_pipeline_key key = {\n            .kv_type            = context.src1->type,\n            .head_dim_qk        = (uint32_t) context.src0->ne[0],\n            .head_dim_v         = (uint32_t) context.src2->ne[0],\n            .kv_direct          = kv_direct,\n            .has_mask           = has_mask,\n            .has_sinks          = has_sinks,\n            .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,\n        };\n\n        auto it = flash_attn_pipelines.find(key);\n        if (it != flash_attn_pipelines.end()) {\n            return it->second;\n        }\n\n        std::vector<std::string> defines;\n        std::string              variant = \"flash_attn\";\n\n        switch (key.kv_type) {\n            case GGML_TYPE_F32:\n                defines.push_back(\"KV_F32\");\n                break;\n            case GGML_TYPE_F16:\n                defines.push_back(\"KV_F16\");\n                break;\n            case GGML_TYPE_Q4_0:\n                defines.push_back(\"KV_Q4_0\");\n                break;\n            case GGML_TYPE_Q8_0:\n                defines.push_back(\"KV_Q8_0\");\n                break;\n            default:\n                GGML_ABORT(\"Unsupported 
KV type for flash attention shader\");\n        }\n        variant += std::string(\"_\") + ggml_type_name(key.kv_type);\n\n        if (key.has_mask) {\n            defines.push_back(\"MASK\");\n            variant += \"_mask\";\n        }\n        if (key.has_sinks) {\n            defines.push_back(\"SINKS\");\n            variant += \"_sinks\";\n        }\n        if (key.uses_logit_softcap) {\n            defines.push_back(\"LOGIT_SOFTCAP\");\n            variant += \"_lgsc\";\n        }\n        if (key.kv_direct) {\n            defines.push_back(\"KV_DIRECT\");\n            variant += \"_kvdirect\";\n        }\n\n        defines.push_back(std::string(\"HEAD_DIM_QK=\") + std::to_string(key.head_dim_qk));\n        variant += std::string(\"_hsqk\") + std::to_string(key.head_dim_qk);\n\n        defines.push_back(std::string(\"HEAD_DIM_V=\") + std::to_string(key.head_dim_v));\n        variant += std::string(\"_hsv\") + std::to_string(key.head_dim_v);\n\n        defines.push_back(std::string(\"SG_MAT_M=\") + std::to_string(context.sg_mat_m));\n        defines.push_back(std::string(\"SG_MAT_N=\") + std::to_string(context.sg_mat_n));\n        defines.push_back(std::string(\"SG_MAT_K=\") + std::to_string(context.sg_mat_k));\n\n        uint32_t q_tile = context.sg_mat_m;\n        uint32_t kv_tile =\n            std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,\n                                                          context.wg_mem_limit_bytes, context.max_subgroup_size }),\n                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);\n        if (key.kv_direct) {\n            while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {\n                kv_tile -= context.sg_mat_n;\n            }\n        }\n\n        defines.push_back(std::string(\"Q_TILE=\") + std::to_string(q_tile));\n        defines.push_back(std::string(\"KV_TILE=\") + std::to_string(kv_tile));\n\n        uint32_t wg_size = 
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);\n        defines.push_back(std::string(\"WG_SIZE=\") + std::to_string(wg_size));\n\n        auto processed     = preprocessor.preprocess(wgsl_flash_attn, defines);\n        auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();\n        decisions->q_tile  = q_tile;\n        decisions->kv_tile = kv_tile;\n        decisions->wg_size = wg_size;\n\n        webgpu_pipeline pipeline  = ggml_webgpu_create_pipeline(device, processed, variant);\n        pipeline.context          = decisions;\n        flash_attn_pipelines[key] = pipeline;\n        return flash_attn_pipelines[key];\n    }\n\n  private:\n    static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,\n                                                       std::string    shader_code,\n                                                       std::string    label) {\n        wgpu::ShaderSourceWGSL shader_source;\n        shader_source.code = shader_code.c_str();\n\n        wgpu::ShaderModuleDescriptor shader_desc;\n        shader_desc.nextInChain = &shader_source;\n\n        wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);\n\n        wgpu::ComputePipelineDescriptor pipeline_desc;\n        pipeline_desc.label              = label.c_str();\n        pipeline_desc.compute.module     = shader_module;\n        pipeline_desc.compute.entryPoint = \"main\";   // Entry point in the WGSL code\n        pipeline_desc.layout             = nullptr;  // nullptr means auto layout\n        return { device.CreateComputePipeline(&pipeline_desc), label };\n    }\n\n    static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {\n        const size_t limit_bytes = context.wg_mem_limit_bytes;\n        const size_t q_tile      = context.sg_mat_m;\n        const size_t base_q_bytes =\n            (context.key.head_dim_qk + 
context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +\n            2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;\n        size_t bytes_per_kv = 0;\n        if (!context.key.kv_direct) {\n            bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);\n        }\n        if (context.key.has_mask) {\n            bytes_per_kv += q_tile;\n        }\n        bytes_per_kv += q_tile;\n        bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;\n        const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;\n        return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;\n    }\n};\n\n#endif  // GGML_WEBGPU_SHADER_LIB_HPP\n"
  },
  {
    "path": "src/ggml-webgpu/ggml-webgpu.cpp",
    "content": "/*\n    WebGPU backend implementation.\n    Note: Use ClangFormat to format this file.\n*/\n\n#include \"ggml-webgpu.h\"\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-webgpu-shader-lib.hpp\"\n\n#ifdef __EMSCRIPTEN__\n#    include <emscripten/emscripten.h>\n#endif\n\n#include <webgpu/webgpu_cpp.h>\n\n#include <atomic>\n#include <condition_variable>\n#include <cstdint>\n#include <cstring>\n#ifdef GGML_WEBGPU_GPU_PROFILE\n#    include <iomanip>\n#endif\n#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)\n#    include <iostream>\n#endif\n#include <map>\n#include <memory>\n#include <mutex>\n#include <optional>\n#include <string>\n#include <utility>\n#include <vector>\n\n#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))\n#define CEIL_DIV(M, N)        (((M) + (N) - 1) / (N))\n\n// Return a rectangular grid of workgroups with minimal over-provisioned workgroups.\n// Assumes that the total number of workgroups does not exceed max_per_dim^2.\nstatic inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {\n    wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim));\n    wg_x = CEIL_DIV(total_wg, wg_y);\n}\n\n#ifdef GGML_WEBGPU_DEBUG\n#    define WEBGPU_LOG_DEBUG(msg)  std::cout << msg << std::endl\n#    define WEBGPU_DEBUG_BUF_ELEMS 512\n#else\n#    define WEBGPU_LOG_DEBUG(msg) ((void) 0)\n#endif  // GGML_WEBGPU_DEBUG\n\n#ifdef GGML_WEBGPU_CPU_PROFILE\n// total timing (aggregated)\n#    define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();\n\n#    define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)                                                         \\\n        auto   cpu_total_end_##id = std::chrono::high_resolution_clock::now();                            \\\n        double cpu_total_time_##id =                                                        
              \\\n            std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \\\n        (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;\n// fine-grained timing (not included in totals)\n#    define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();\n\n#    define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)                                                          \\\n        auto   cpu_detail_end_##id = std::chrono::high_resolution_clock::now();                             \\\n        double cpu_detail_time_##id =                                                                       \\\n            std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \\\n        (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;\n#else\n#    define WEBGPU_CPU_PROFILE_TOTAL_START(id)\n#    define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)\n#    define WEBGPU_CPU_PROFILE_DETAIL_START(id)\n#    define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)\n#endif  // GGML_WEBGPU_CPU_PROFILE\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\n#    define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       32\n#    define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16  // e.g. 
enough for two timestamps\n#endif\n\n/* Constants */\n\n#define WEBGPU_NUM_PARAM_BUFS                96u\n#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     32u\n#define WEBGPU_WAIT_ANY_TIMEOUT_MS           0\n// Maximum number of in-flight submissions per-thread, to avoid exhausting the\n// parameter buffer pool\n#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)\n#define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters\n#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4\n#define WEBGPU_STORAGE_BUF_BINDING_MULT      4    // a storage buffer binding size must be a multiple of 4\n\n// For operations which process a row in parallel, this seems like a reasonable\n// default\n#define WEBGPU_ROW_SPLIT_WG_SIZE 64\n\n// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to\n// implementations so this can be removed, necessary only for get_rows right now\n#define WEBGPU_MAX_WG_SIZE 288\n\n/* End Constants */\n\n// This is a \"fake\" base pointer, since WebGPU buffers do not have pointers to\n// their locations.\nstatic void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT\n\n// Always returns the base offset of a tensor, regardless of views.\nstatic uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {\n    if (tensor->view_src) {\n        return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;\n    }\n    return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;\n}\n\n/* Struct definitions */\n\n// Forward reference\nstatic void ggml_webgpu_create_buffer(wgpu::Device &    device,\n                                      wgpu::Buffer &    buffer,\n                                      size_t            size,\n                                      wgpu::BufferUsage usage,\n                                      const char *      label);\n\n// Holds a pool of parameter buffers for WebGPU operations\nstruct webgpu_buf_pool {\n    std::vector<wgpu::Buffer> 
free;\n\n    // The pool must be synchronized because\n    // 1. The memset pool is shared globally by every ggml buffer,\n    // since allocating a pool per ggml buffer would consume too much memory.\n    // 2. For the per-thread buffer pools in webgpu_context,\n    // buffers are allocated and freed in Dawn callbacks,\n    // which can run on a different thread than the calling thread.\n    std::mutex              mutex;\n    std::condition_variable cv;\n    size_t                  cur_pool_size;\n    size_t                  max_pool_size;\n    wgpu::Device            device;\n    wgpu::BufferUsage       dev_buf_usage;\n    size_t                  buf_size;\n    bool                    should_grow;\n\n    void init(wgpu::Device      device,\n              int               num_bufs,\n              size_t            buf_size,\n              wgpu::BufferUsage dev_buf_usage,\n              bool              should_grow   = false,\n              size_t            max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {\n        this->max_pool_size = max_pool_size;\n        this->cur_pool_size = num_bufs;\n        this->device        = device;\n        this->dev_buf_usage = dev_buf_usage;\n        this->buf_size      = buf_size;\n        this->should_grow   = should_grow;\n        for (int i = 0; i < num_bufs; i++) {\n            wgpu::Buffer dev_buf;\n            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, \"ggml_webgpu_dev_pool_buf\");\n            free.push_back(dev_buf);\n        }\n    }\n\n    wgpu::Buffer alloc_bufs() {\n        std::unique_lock<std::mutex> lock(mutex);\n        if (!free.empty()) {\n            wgpu::Buffer buf = free.back();\n            free.pop_back();\n            return buf;\n        }\n\n        // Try growing the pool if no free buffers\n        if (free.empty() && cur_pool_size < max_pool_size && should_grow) {\n            cur_pool_size++;\n            wgpu::Buffer dev_buf;\n            ggml_webgpu_create_buffer(device, 
dev_buf, buf_size, dev_buf_usage, \"ggml_webgpu_dev_pool_buf\");\n\n            if (!dev_buf) {\n                GGML_ABORT(\"webgpu_buf_pool: failed to allocate buffers\");\n            }\n            return dev_buf;\n        }\n        cv.wait(lock, [this] { return !free.empty(); });\n        wgpu::Buffer buf = free.back();\n        free.pop_back();\n        return buf;\n    }\n\n    void free_bufs(std::vector<wgpu::Buffer> bufs) {\n        std::lock_guard<std::mutex> lock(mutex);\n        free.insert(free.end(), bufs.begin(), bufs.end());\n        cv.notify_all();\n    }\n\n    void cleanup() {\n        std::lock_guard<std::mutex> lock(mutex);\n        for (auto & buf : free) {\n            if (buf) {\n                buf.Destroy();\n            }\n        }\n        free.clear();\n    }\n\n    ~webgpu_buf_pool() { this->cleanup(); }\n};\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\nstruct webgpu_gpu_profile_bufs {\n    wgpu::Buffer   host_buf;\n    wgpu::Buffer   dev_buf;\n    wgpu::QuerySet query_set;\n};\n\n// Holds a pool of parameter buffers for WebGPU operations\nstruct webgpu_gpu_profile_buf_pool {\n    std::vector<webgpu_gpu_profile_bufs> free;\n\n    std::mutex mutex;\n\n    std::condition_variable cv;\n\n    void init(wgpu::Device      device,\n              int               num_bufs,\n              size_t            buf_size,\n              wgpu::BufferUsage dev_buf_usage,\n              wgpu::BufferUsage host_buf_usage) {\n        for (int i = 0; i < num_bufs; i++) {\n            wgpu::Buffer host_buf;\n            wgpu::Buffer dev_buf;\n            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, \"ggml_webgpu_host_profile_buf\");\n            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, \"ggml_webgpu_dev_profile_buf\");\n            // Create a query set for 2 timestamps\n            wgpu::QuerySetDescriptor ts_query_set_desc = {};\n\n            ts_query_set_desc.type      = wgpu::QueryType::Timestamp;\n           
 ts_query_set_desc.count     = 2;\n            wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);\n\n            free.push_back({ host_buf, dev_buf, ts_query_set });\n        }\n    }\n\n    webgpu_gpu_profile_bufs alloc_bufs() {\n        std::unique_lock<std::mutex> lock(mutex);\n        cv.wait(lock, [this] { return !free.empty(); });\n        webgpu_gpu_profile_bufs bufs = free.back();\n        free.pop_back();\n        return bufs;\n    }\n\n    void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {\n        std::lock_guard<std::mutex> lock(mutex);\n        free.insert(free.end(), bufs.begin(), bufs.end());\n        cv.notify_all();\n    }\n\n    void cleanup() {\n        std::lock_guard<std::mutex> lock(mutex);\n        for (auto & bufs : free) {\n            bufs.host_buf.Destroy();\n            bufs.dev_buf.Destroy();\n            bufs.query_set.Destroy();\n        }\n        free.clear();\n    }\n\n    ~webgpu_gpu_profile_buf_pool() { this->cleanup(); }\n};\n#endif\n\nstruct webgpu_command {\n    uint32_t                  num_kernels;\n    wgpu::CommandBuffer       commands;\n    std::vector<wgpu::Buffer> params_bufs;\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    webgpu_gpu_profile_bufs timestamp_query_bufs;\n    std::string             pipeline_name;\n#endif\n};\n\nstruct webgpu_capabilities {\n    wgpu::Limits limits;\n    bool         supports_subgroup_matrix = false;\n\n    uint32_t sg_mat_m = 0;\n    uint32_t sg_mat_n = 0;\n    uint32_t sg_mat_k = 0;\n\n    uint32_t subgroup_size     = 0;\n    uint32_t max_subgroup_size = 0;\n    size_t   memset_bytes_per_thread;\n};\n\n// Stores global webgpu members\nstruct webgpu_global_context_struct {\n    wgpu::Instance instance;\n    wgpu::Adapter  adapter;\n    wgpu::Device   device;\n    wgpu::Queue    queue;\n\n    webgpu_capabilities  capabilities;\n    // Shared buffer to move data from device to host\n    wgpu::Buffer         get_tensor_staging_buf;\n    // Global mutex for pipeline and 
staging buffer, will be refactored to exclude pipeline caches.\n    std::recursive_mutex mutex;\n\n    webgpu_buf_pool                memset_buf_pool;\n    std::map<int, webgpu_pipeline> memset_pipelines;  // variant or type index\n\n#ifdef GGML_WEBGPU_CPU_PROFILE\n    // Profiling: labeled CPU time in ms (total)\n    std::unordered_map<std::string, double> cpu_time_ms;\n    // Profiling: detailed CPU time in ms\n    std::unordered_map<std::string, double> cpu_detail_ms;\n#endif\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    // Profiling: per-shader GPU time in ms\n    std::unordered_map<std::string, double> shader_gpu_time_ms;\n    // Profiling: pool of timestamp query buffers (one per operation)\n    webgpu_gpu_profile_buf_pool             timestamp_query_buf_pool;\n#endif\n\n#ifdef GGML_WEBGPU_DEBUG\n    wgpu::Buffer debug_host_buf;\n    wgpu::Buffer debug_dev_buf;\n#endif\n\n    ~webgpu_global_context_struct() {\n        if (this->get_tensor_staging_buf) {\n            this->get_tensor_staging_buf.Destroy();\n            this->get_tensor_staging_buf = nullptr;\n        }\n#ifdef GGML_WEBGPU_DEBUG\n        if (this->debug_host_buf) {\n            this->debug_host_buf.Destroy();\n            this->debug_host_buf = nullptr;\n        }\n        if (this->debug_dev_buf) {\n            this->debug_dev_buf.Destroy();\n            this->debug_dev_buf = nullptr;\n        }\n#endif\n    }\n};\n\ntypedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;\n\nstruct webgpu_submission {\n    wgpu::FutureWaitInfo submit_done;\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    std::vector<wgpu::FutureWaitInfo> profile_futures;\n#endif\n};\n\n// All the base objects needed to run operations on a WebGPU device\nstruct webgpu_context_struct {\n    // Points to global instances owned by ggml_backend_webgpu_reg_context\n    webgpu_global_context global_ctx;\n\n    std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;\n\n    webgpu_buf_pool param_buf_pool;\n    wgpu::Buffer    
set_rows_dev_error_buf;\n    wgpu::Buffer    set_rows_host_error_buf;\n\n    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                      // src_type, dst_type\n\n    std::map<int, webgpu_pipeline>                               rms_norm_pipelines;  // inplace\n    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace\n    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;       // glu_op, type, split\n\n    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines;  // mask_type, has_sink, inplace\n\n    size_t memset_bytes_per_thread;\n};\n\ntypedef std::shared_ptr<webgpu_context_struct> webgpu_context;\n\n// Metadata required for the ggml backend registration/discovery interface\nstruct ggml_backend_webgpu_reg_context {\n    // Since the Instance is a global entrypoint into the WebGPU API, it lives here\n    webgpu_global_context webgpu_global_ctx;\n    size_t                device_count;\n    const char *          name;\n};\n\n// Per-device struct for the global logical device interface\nstruct ggml_backend_webgpu_device_context {\n    webgpu_global_context webgpu_global_ctx;\n    std::string           device_name;\n    std::string           device_desc;\n};\n\n// Per-thread data required to actually run WebGPU operations in a backend instance\nstruct ggml_backend_webgpu_context {\n    webgpu_context webgpu_ctx;\n    std::string    name;\n};\n\n// Per-thread data related to buffers\nstruct ggml_backend_webgpu_buffer_context {\n    wgpu::Buffer          buffer;\n    std::string           label;\n    webgpu_global_context global_ctx;\n\n    ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :\n        buffer(std::move(buf)),\n        label(std::move(lbl)),\n        global_ctx(std::move(global_ctx_)) {}\n};\n\n/* WebGPU object initializations */\n\nstatic webgpu_pipeline 
ggml_webgpu_create_pipeline(wgpu::Device &                           device,\n                                                   const char *                             shader_code,\n                                                   const char *                             label,\n                                                   const std::vector<wgpu::ConstantEntry> & constants = {}) {\n    wgpu::ShaderSourceWGSL shader_source;\n    shader_source.code = shader_code;\n\n    wgpu::ShaderModuleDescriptor shader_desc;\n    shader_desc.nextInChain = &shader_source;\n\n    wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);\n\n    wgpu::ComputePipelineDescriptor pipeline_desc;\n    pipeline_desc.label              = label;\n    pipeline_desc.compute.module     = shader_module;\n    pipeline_desc.compute.entryPoint = \"main\";   // Entry point in the WGSL code\n    pipeline_desc.layout             = nullptr;  // nullptr means auto layout\n    if (constants.size() > 0) {\n        pipeline_desc.compute.constants     = constants.data();\n        pipeline_desc.compute.constantCount = constants.size();\n    }\n    return { device.CreateComputePipeline(&pipeline_desc), label };\n}\n\nstatic void ggml_webgpu_create_buffer(wgpu::Device &    device,\n                                      wgpu::Buffer &    buffer,\n                                      size_t            size,\n                                      wgpu::BufferUsage usage,\n                                      const char *      label) {\n    wgpu::BufferDescriptor buffer_desc;\n    buffer_desc.size             = size;\n    buffer_desc.usage            = usage;\n    buffer_desc.label            = label;\n    buffer_desc.mappedAtCreation = false;\n\n    // TODO: error handling\n    buffer = device.CreateBuffer(&buffer_desc);\n}\n\n/** End WebGPU object initializations */\n\n/** WebGPU Actions */\n\nstatic bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool 
allow_timeout = false) {\n    switch (status) {\n        case wgpu::WaitStatus::Success:\n            return true;\n        case wgpu::WaitStatus::TimedOut:\n            if (allow_timeout) {\n                return false;\n            }\n            GGML_LOG_ERROR(\"ggml_webgpu: WaitAny timed out unexpectedly\\n\");\n            return false;\n        case wgpu::WaitStatus::Error:\n            GGML_LOG_ERROR(\"ggml_webgpu: WaitAny returned an error\\n\");\n            return false;\n        default:\n            GGML_LOG_ERROR(\"ggml_webgpu: WaitAny returned an unknown status\\n\");\n            return false;\n    }\n}\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\nstatic void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {\n    futures.erase(std::remove_if(futures.begin(), futures.end(),\n                                 [](const wgpu::FutureWaitInfo & info) { return info.completed; }),\n                  futures.end());\n}\n\nstatic void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context &             ctx,\n                                                     std::vector<wgpu::FutureWaitInfo> & futures,\n                                                     bool                                block) {\n    if (futures.empty()) {\n        return;\n    }\n\n    uint64_t timeout_ms = block ? 
UINT64_MAX : 0;\n    if (block) {\n        while (!futures.empty()) {\n            auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);\n            if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {\n                ggml_backend_webgpu_erase_completed_futures(futures);\n            }\n        }\n    } else {\n        auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);\n        if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {\n            ggml_backend_webgpu_erase_completed_futures(futures);\n        }\n    }\n}\n#endif\n\n// Wait for the queue to finish processing all submitted work\nstatic void ggml_backend_webgpu_wait(webgpu_global_context &          ctx,\n                                     std::vector<webgpu_submission> & subs,\n                                     bool                             block = true) {\n    // If we have too many in-flight submissions, wait on the oldest one first.\n    if (subs.empty()) {\n        return;\n    }\n    while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {\n        auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);\n        if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {\n#ifdef GGML_WEBGPU_GPU_PROFILE\n            ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);\n#endif\n            subs.erase(subs.begin());\n        }\n    }\n\n    if (subs.empty()) {\n        return;\n    }\n\n    if (block) {\n        for (auto & sub : subs) {\n            while (!sub.submit_done.completed) {\n                auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);\n                ggml_backend_webgpu_handle_wait_status(waitStatus);\n            }\n#ifdef GGML_WEBGPU_GPU_PROFILE\n            ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);\n#endif\n        }\n        subs.clear();\n    } else {\n        // Poll each submit future 
once and remove completed submissions.\n        for (auto sub = subs.begin(); sub != subs.end();) {\n            auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);\n            ggml_backend_webgpu_handle_wait_status(waitStatus, true);\n#ifdef GGML_WEBGPU_GPU_PROFILE\n            ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);\n            if (sub->submit_done.completed && sub->profile_futures.empty()) {\n#else\n            if (sub->submit_done.completed) {\n#endif\n                sub = subs.erase(sub);\n            } else {\n                ++sub;\n            }\n        }\n    }\n}\n\nstatic void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,\n                                           wgpu::Buffer &          buffer,\n                                           wgpu::MapMode           mode,\n                                           size_t                  offset,\n                                           size_t                  size) {\n    ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,\n                                          [](wgpu::MapAsyncStatus status, wgpu::StringView message) {\n                                              if (status != wgpu::MapAsyncStatus::Success) {\n                                                  GGML_LOG_ERROR(\"ggml_webgpu: Failed to map buffer: %s\\n\",\n                                                                 message.data);\n                                              }\n                                          }),\n                          UINT64_MAX);\n}\n\n#ifdef GGML_WEBGPU_DEBUG\n// This function adds debugging information to shaders, as WebGPU does not support printing directly.\n// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and\n// debug statements in the shader, and then call this function after encoding the commands and submitting them.\nstatic void 
ggml_backend_webgpu_debug(webgpu_global_context & ctx) {\n    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();\n    encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());\n    wgpu::CommandBuffer commands = encoder.Finish();\n    ctx->queue.Submit(1, &commands);\n    ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());\n    const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();\n    std::cout << \"debug[0]: \" << debug_data[0] << \"\\n\";\n    ctx->debug_host_buf.Unmap();\n}\n#endif\n\nstatic webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context &       ctx,\n                                                    std::vector<webgpu_command> & commands,\n                                                    webgpu_buf_pool &             param_buf_pool) {\n    std::vector<wgpu::CommandBuffer> command_buffers;\n    std::vector<wgpu::Buffer>        params_bufs;\n    webgpu_submission                submission;\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;\n#endif\n\n    for (const auto & command : commands) {\n        command_buffers.push_back(command.commands);\n        params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());\n    }\n    ctx->queue.Submit(command_buffers.size(), command_buffers.data());\n\n    wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(\n        wgpu::CallbackMode::AllowSpontaneous,\n        [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {\n            if (status != wgpu::QueueWorkDoneStatus::Success) {\n                GGML_LOG_ERROR(\"ggml_webgpu: Failed to submit commands: %s\\n\", std::string(message).c_str());\n            }\n            // Free the staged buffers\n            param_buf_pool.free_bufs(params_bufs);\n 
       });\n    submission.submit_done = { p_f };\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    for (const auto & command : commands) {\n        auto label   = command.pipeline_name;\n        auto ts_bufs = command.timestamp_query_bufs;\n\n        wgpu::Future f = ts_bufs.host_buf.MapAsync(\n            wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,\n            [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {\n                if (status != wgpu::MapAsyncStatus::Success) {\n                    GGML_LOG_ERROR(\"ggml_webgpu: Failed to map timestamp buffer: %s\\n\", std::string(message).c_str());\n                } else {\n                    const uint64_t * ts_data    = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();\n                    // WebGPU timestamps are in ns; convert to ms\n                    double           elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;\n                    ctx->shader_gpu_time_ms[label] += elapsed_ms;\n                }\n                // We can't unmap in here due to WebGPU reentrancy limitations.\n                ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });\n            });\n        submission.profile_futures.push_back({ f });\n    }\n#endif\n    return submission;\n}\n\nstatic webgpu_command ggml_backend_webgpu_build_multi(\n    webgpu_global_context &                                ctx,\n    webgpu_buf_pool &                                      param_buf_pool,\n    const std::vector<webgpu_pipeline> &                   pipelines,\n    const std::vector<std::vector<uint32_t>> &             params_list,\n    const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,\n    const std::vector<std::pair<uint32_t, uint32_t>> &     workgroups_list) {\n    GGML_ASSERT(pipelines.size() == params_list.size());\n    GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());\n    GGML_ASSERT(pipelines.size() == 
workgroups_list.size());\n\n    std::vector<wgpu::Buffer>    params_bufs_list;\n    std::vector<wgpu::BindGroup> bind_groups;\n\n    for (size_t i = 0; i < pipelines.size(); i++) {\n        wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();\n\n        std::vector<wgpu::BindGroupEntry> entries            = bind_group_entries_list[i];\n        uint32_t                          params_binding_num = entries.size();\n        entries.push_back(\n            { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });\n\n        wgpu::BindGroupDescriptor bind_group_desc;\n        bind_group_desc.layout     = pipelines[i].pipeline.GetBindGroupLayout(0);\n        bind_group_desc.entryCount = entries.size();\n        bind_group_desc.entries    = entries.data();\n        bind_group_desc.label      = pipelines[i].name.c_str();\n        bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));\n\n        params_bufs_list.push_back(params_bufs);\n    }\n\n    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();\n    for (size_t i = 0; i < params_bufs_list.size(); i++) {\n        ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));\n    }\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();\n    if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {\n        ts_bufs.host_buf.Unmap();\n    }\n\n    wgpu::PassTimestampWrites   ts_writes = { .querySet                  = ts_bufs.query_set,\n                                              .beginningOfPassWriteIndex = 0,\n                                              .endOfPassWriteIndex       = 1 };\n    wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };\n    wgpu::ComputePassEncoder    pass      = encoder.BeginComputePass(&pass_desc);\n#else\n    wgpu::ComputePassEncoder pass = 
encoder.BeginComputePass();\n#endif\n    for (size_t i = 0; i < pipelines.size(); i++) {\n        pass.SetPipeline(pipelines[i].pipeline);\n        pass.SetBindGroup(0, bind_groups[i]);\n        pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);\n    }\n    pass.End();\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);\n    encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());\n#endif\n\n    wgpu::CommandBuffer commands = encoder.Finish();\n    webgpu_command      result   = {};\n    result.commands              = commands;\n    result.params_bufs           = params_bufs_list;\n    result.num_kernels           = pipelines.size();\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    result.timestamp_query_bufs = ts_bufs;\n    // TODO: handle multiple pipeline names\n    result.pipeline_name        = pipelines.front().name;\n#endif\n    return result;\n}\n\nstatic webgpu_command ggml_backend_webgpu_build(webgpu_global_context &           ctx,\n                                                webgpu_buf_pool &                 param_buf_pool,\n                                                webgpu_pipeline &                 pipeline,\n                                                std::vector<uint32_t>             params,\n                                                std::vector<wgpu::BindGroupEntry> bind_group_entries,\n                                                uint32_t                          wg_x,\n                                                uint32_t                          wg_y = 1) {\n    return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,\n                                           {\n                                               pipeline\n    },\n                                           { std::move(params) }, { std::move(bind_group_entries) },\n                                           { { wg_x, wg_y } });\n}\n\nstatic 
void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,\n                                              wgpu::Buffer &          buf,\n                                              uint32_t                value,\n                                              size_t                  offset,\n                                              size_t                  size) {\n    std::vector<uint32_t>             params  = { (uint32_t) offset, (uint32_t) size, value };\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }\n    };\n    size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;\n    uint32_t wg_x         = CEIL_DIV(size + 3, bytes_per_wg);\n\n    webgpu_command command =\n        ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);\n    std::vector<webgpu_command>    commands = { command };\n    std::vector<webgpu_submission> sub      = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };\n    ggml_backend_webgpu_wait(ctx, sub);\n}\n\n/** End WebGPU Actions */\n\n/** GGML Backend Interface */\n\nstatic const char * ggml_backend_webgpu_name(ggml_backend_t backend) {\n    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;\n    return ctx->name.c_str();\n}\n\nstatic void ggml_backend_webgpu_free(ggml_backend_t backend) {\n    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;\n    WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_free(\" << ctx->name << \")\");\n\n#ifdef GGML_WEBGPU_CPU_PROFILE\n    std::cout << \"\\n[ggml_webgpu cpu profiling summary]\\n\";\n    double total_cpu = 0.0;\n    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {\n        total_cpu += kv.second;\n    }\n    std::cout << \"ggml_webgpu: total cpu time: \" << total_cpu << \" ms\\n\";\n    std::cout << \"ggml_webgpu: cpu 
breakdown:\\n\";\n    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {\n        double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;\n        std::cout << \"ggml_webgpu:  \" << kv.first << \": \" << kv.second << \" ms (\" << pct << \"%)\\n\";\n    }\n    if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {\n        std::cout << \"ggml_webgpu: cpu detailed breakdown:\\n\";\n    }\n    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {\n        double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;\n        std::cout << \"ggml_webgpu:  \" << kv.first << \": \" << kv.second << \" ms (\" << pct << \"%)\\n\";\n    }\n#endif\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    std::cout << \"\\n[ggml_webgpu gpu profiling summary]\\n\";\n    double total_gpu = 0.0;\n    for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {\n        total_gpu += kv.second;\n    }\n    std::cout << \"ggml_webgpu: total gpu time (all shaders): \" << total_gpu << \" ms\\n\";\n    std::cout << \"\\nggml_webgpu: gpu breakdown:\\n\";\n    for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {\n        double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;\n        std::cout << \"ggml_webgpu:  \" << kv.first << \": \" << kv.second << \" ms (\" << std::fixed << std::setprecision(2)\n                  << pct << \"%)\\n\";\n    }\n#endif\n\n#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)\n    std::cout << \"ggml_webgpu: gpu/cpu ratio: \" << (total_cpu > 0.0 ? 
total_gpu / total_cpu : 0.0) << \"\\n\";\n#endif\n\n    delete ctx;\n    delete backend;\n}\n\nstatic size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {\n    return webgpu_tensor_offset(tensor) + tensor->view_offs;\n}\n\nstatic wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {\n    ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;\n    return ctx->buffer;\n}\n\nstatic size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {\n    size_t offset = ggml_webgpu_tensor_offset(t);\n    return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);\n}\n\nstatic size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {\n    size_t offset = ggml_webgpu_tensor_offset(t);\n    return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);\n}\n\nstatic size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {\n    return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);\n}\n\n// Used to determine if two tensors are the same for in-place operations\nstatic bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {\n    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&\n           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));\n}\n\n// Used to determine if two tensors share the same buffer and their byte ranges overlap,\nstatic bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {\n    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&\n           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&\n           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));\n}\n\nstruct binary_overlap_flags {\n    bool inplace;  // src0 == dst\n    bool overlap;  // 
src1 == dst\n    bool src_overlap;\n};\n\nstatic binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,\n                                                              ggml_tensor * src1,\n                                                              ggml_tensor * dst) {\n    binary_overlap_flags flags = {};\n    flags.inplace              = ggml_webgpu_tensor_equal(src0, dst);\n    flags.overlap              = ggml_webgpu_tensor_overlap(src1, dst);\n    flags.src_overlap          = ggml_webgpu_tensor_overlap(src0, src1);\n\n    return flags;\n}\n\nstatic webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    uint32_t ne = (uint32_t) ggml_nelements(dst);\n\n    std::vector<uint32_t> params = {\n        ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        // Convert byte-strides to element-strides\n        (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),\n        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),\n        // Logical shapes\n        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],\n        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         
.offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }\n    };\n\n    uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],\n                                     params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup\n    };\n\n    webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);\n\n    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());\n\n    const uint32_t ne = (uint32_t) ggml_nelements(dst);\n\n    std::vector<uint32_t> params = {\n        ne,\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        // Strides (in elements)\n        (uint32_t) (src->nb[0] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),\n        // Shapes\n        (uint32_t) src->ne[0],\n        (uint32_t) src->ne[1],\n        (uint32_t) src->ne[2],\n        (uint32_t) src->ne[3],\n        (uint32_t) dst->ne[0],\n        (uint32_t) dst->ne[1],\n        (uint32_t) dst->ne[2],\n        (uint32_t) dst->ne[3],\n        // Pad sizes\n        (uint32_t) ggml_get_op_params_i32(dst, 0),\n        (uint32_t) ggml_get_op_params_i32(dst, 1),\n        (uint32_t) ggml_get_op_params_i32(dst, 2),\n        (uint32_t) ggml_get_op_params_i32(dst, 3),\n        (uint32_t) ggml_get_op_params_i32(dst, 4),\n        (uint32_t) 
ggml_get_op_params_i32(dst, 5),\n        (uint32_t) ggml_get_op_params_i32(dst, 6),\n        (uint32_t) ggml_get_op_params_i32(dst, 7),\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }\n    };\n\n    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,\n                                                          ggml_tensor *    src,\n                                                          ggml_tensor *    idx,\n                                                          ggml_tensor *    dst) {\n    // For set rows specifically, we need to check if src and idx are empty\n    // tensors.\n    if (ggml_is_empty(src) || ggml_is_empty(idx)) {\n        return std::nullopt;\n    }\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0        = src,\n        .src1        = idx,\n        .dst         = dst,\n        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup\n    };\n\n    webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);\n\n    auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());\n\n    std::vector<uint32_t> params = {\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),\n        (uint32_t) 
(ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        // Convert byte-strides to element-strides\n        (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),\n        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),\n        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),\n        // Shape of src\n        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],\n        // Shape of idx\n        (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(idx),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, idx),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, idx) },\n        { .binding = 2,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }\n    };\n\n    if (decisions->i64_idx) {\n        entries.push_back({ .binding = 3,\n                            .buffer  = ctx->set_rows_dev_error_buf,\n                            .offset  = 0,\n                            .size    = ctx->set_rows_dev_error_buf.GetSize() });\n    }\n\n    uint32_t threads;\n    if (decisions->vec4) {\n        threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);\n    } else {\n        threads = src->ne[0] * 
src->ne[1] * src->ne[2] * src->ne[3];\n    }\n    uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);\n}\n\n// Workgroup size is a common constant\nstatic std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {\n    std::vector<wgpu::ConstantEntry> constants(1);\n    constants[0].key   = \"wg_size\";\n    constants[0].value = wg_size;\n    return constants;\n}\n\nstatic webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,\n                                           ggml_tensor *    src,\n                                           ggml_tensor *    idx,\n                                           ggml_tensor *    dst) {\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0        = src,\n        .src1        = nullptr,\n        .dst         = dst,\n        .max_wg_size = WEBGPU_MAX_WG_SIZE,\n    };\n\n    webgpu_pipeline pipeline  = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);\n    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());\n\n    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),\n                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),\n                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),\n                                     (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),\n                                     (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),\n       
                              (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),\n                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),\n                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),\n                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),\n                                     (uint32_t) dst->ne[0],\n                                     (uint32_t) dst->ne[1],\n                                     (uint32_t) dst->ne[2],\n                                     (uint32_t) dst->ne[3],\n                                     (uint32_t) (idx->ne[1]),\n                                     (uint32_t) (idx->ne[2]) };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(idx),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, idx),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, idx) },\n        { .binding = 2,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }\n    };\n\n    uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);\n\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,\n                                          ggml_tensor *    src0,\n                                          ggml_tensor *    src1,\n                                          ggml_tensor *    dst) {\n    // Determine if this is a mat-vec operation\n    bool is_vec = (dst->ne[1] == 1);\n\n    // Determine if we should use 
fast path\n    bool use_fast = false;\n    switch (src1->type) {\n        case GGML_TYPE_F16:\n            use_fast = (src0->type == GGML_TYPE_F16);\n            break;\n        case GGML_TYPE_F32:\n            // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K\n            switch (src0->type) {\n                case GGML_TYPE_F32:\n                case GGML_TYPE_F16:\n                case GGML_TYPE_Q4_0:\n                case GGML_TYPE_Q4_1:\n                case GGML_TYPE_Q5_0:\n                case GGML_TYPE_Q5_1:\n                case GGML_TYPE_Q8_0:\n                case GGML_TYPE_Q8_1:\n                case GGML_TYPE_Q6_K:\n                    use_fast = true;\n                    break;\n                case GGML_TYPE_Q2_K:\n                case GGML_TYPE_Q3_K:\n                case GGML_TYPE_Q4_K:\n                case GGML_TYPE_Q5_K:\n                    // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat\n                    use_fast = !is_vec;\n                    break;\n                default:\n                    break;\n            }\n            break;\n        default:\n            break;\n    }\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0                     = src0,\n        .src1                     = src1,\n        .dst                      = dst,\n        .max_wg_size              = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n        .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,\n        .sg_mat_m                 = ctx->global_ctx->capabilities.sg_mat_m,\n        .sg_mat_n                 = ctx->global_ctx->capabilities.sg_mat_n,\n        .sg_mat_k                 = ctx->global_ctx->capabilities.sg_mat_k,\n        .max_subgroup_size        = ctx->global_ctx->capabilities.max_subgroup_size,\n    };\n\n    // Get or create pipeline\n    webgpu_pipeline pipeline;\n\n    if (use_fast && 
is_vec) {\n        pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);\n    } else if (use_fast) {\n        pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);\n    } else {\n        pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);\n    }\n\n    // Build params\n    std::vector<uint32_t> params = {\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        (uint32_t) dst->ne[0],\n        (uint32_t) dst->ne[1],\n        (uint32_t) src0->ne[0],\n        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),\n        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),\n        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),\n        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),\n        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),\n        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),\n        (uint32_t) src0->ne[2],\n        (uint32_t) src0->ne[3],\n        (uint32_t) (src1->ne[2] / src0->ne[2]),\n        (uint32_t) (src1->ne[3] / src0->ne[3])\n    };\n\n    // Build bind group entries\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src0),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(src1),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },\n        { .binding = 2,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) 
 },\n    };\n\n    // Calculate workgroup dimensions\n    uint32_t       wg_x           = 1;\n    uint32_t       wg_y           = 1;\n    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;\n\n    if (use_fast && is_vec) {\n        auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());\n\n        uint32_t batches       = dst->ne[2] * dst->ne[3];\n        uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);\n        uint32_t total_wg      = output_groups * batches;\n        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);\n    } else if (use_fast) {\n        auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());\n\n        // Fast-path tiled/subgroup calculations\n        uint32_t wg_m;\n        uint32_t wg_n;\n        if (decisions->use_subgroup_matrix) {\n            uint32_t wg_m_sg_tile =\n                decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;\n            wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);\n            uint32_t wg_n_sg_tile =\n                decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;\n            wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);\n        } else {\n            uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;\n            uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;\n            wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);\n            wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);\n        }\n        uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];\n        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);\n\n    } else {  // legacy\n        auto *   decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());\n        uint32_t wg_size   = decisions->wg_size;\n        uint32_t 
total_wg  = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);\n        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);\n    }\n\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);\n}\n\n#ifndef __EMSCRIPTEN__\nstatic webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,\n                                             ggml_tensor *    Q,\n                                             ggml_tensor *    K,\n                                             ggml_tensor *    V,\n                                             ggml_tensor *    mask,\n                                             ggml_tensor *    sinks,\n                                             ggml_tensor *    dst) {\n    float scale = *(float *) dst->op_params;\n    float max_bias;\n    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));\n    float logit_softcap;\n    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));\n    if (logit_softcap != 0.0f) {\n        scale /= logit_softcap;\n    }\n    float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));\n    float m0          = powf(2.0f, -(max_bias) / n_head_log2);\n    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    const int has_mask  = (mask != nullptr);\n    const int has_sinks = (sinks != nullptr);\n\n    std::vector<uint32_t> params = {\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),\n        has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,\n        has_sinks ? 
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        (uint32_t) Q->ne[2],                              // number of heads\n        (uint32_t) Q->ne[1],                              // sequence length (Q)\n        (uint32_t) K->ne[1],                              // sequence length (K/V)\n        (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 1\n        (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 2\n        (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 3\n        (uint32_t) (K->nb[1] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 1\n        (uint32_t) (K->nb[2] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 2\n        (uint32_t) (K->nb[3] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 3\n        (uint32_t) (V->nb[1] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 1\n        (uint32_t) (V->nb[2] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 2\n        (uint32_t) (V->nb[3] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 3\n        has_mask ? 
(uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0,  // stride of mask dim 3\n        (uint32_t) (Q->ne[2] / K->ne[2]),  // repeat factor for K/V in dim 2 (MHA/MQA/GQA)\n        *(uint32_t *) &scale,              // scale (possibly adjusted for logit softcap)\n        *(uint32_t *) &max_bias,\n        *(uint32_t *) &logit_softcap,\n        *(uint32_t *) &n_head_log2,\n        *(uint32_t *) &m0,\n        *(uint32_t *) &m1\n\n    };\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(Q),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, Q),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, Q) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(K),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, K),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, K) },\n        { .binding = 2,\n         .buffer  = ggml_webgpu_tensor_buf(V),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, V),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, V) }\n    };\n    uint32_t binding_index = 3;\n    if (has_mask) {\n        entries.push_back({ .binding = binding_index++,\n                            .buffer  = ggml_webgpu_tensor_buf(mask),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, mask) });\n    }\n    if (has_sinks) {\n        entries.push_back({ .binding = binding_index++,\n                            .buffer  = ggml_webgpu_tensor_buf(sinks),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, sinks),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, sinks) });\n    }\n    entries.push_back({ .binding = binding_index++,\n                        .buffer  = ggml_webgpu_tensor_buf(dst),\n                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n                
        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0               = Q,\n        .src1               = K,\n        .src2               = V,\n        .src3               = mask,\n        .src4               = sinks,\n        .dst                = dst,\n        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,\n        .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,\n        .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,\n        .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,\n        .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size,\n    };\n\n    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);\n\n    auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());\n\n    uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);\n    uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n#endif\n\nstatic webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    bool is_unary = dst->op == GGML_OP_UNARY;\n    bool inplace  = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0        = src,\n        .src1        = nullptr,\n        .dst         = dst,\n        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n        .inplace     = inplace,\n    };\n\n    webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);\n\n    auto * decisions = 
static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());\n\n    uint32_t ne = (uint32_t) ggml_nelements(dst);\n\n    std::vector<uint32_t> params = { ne,\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n                                     (uint32_t) (src->nb[0] / ggml_type_size(src->type)),\n                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),\n                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),\n                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),\n                                     (uint32_t) src->ne[0],\n                                     (uint32_t) src->ne[1],\n                                     (uint32_t) src->ne[2] };\n\n    ggml_tensor * effective_src = src;\n    if (is_unary) {\n        ggml_unary_op unary_op = ggml_get_unary_op(dst);\n        switch (unary_op) {\n            case GGML_UNARY_OP_XIELU:\n                {\n                    // Get float parameters and reinterpret their bit patterns as uint32_t\n                    // for passing through the params buffer\n                    float alpha_n = ggml_get_op_params_f32(dst, 1);\n                    float alpha_p = ggml_get_op_params_f32(dst, 2);\n                    float beta    = ggml_get_op_params_f32(dst, 3);\n                    float eps     = ggml_get_op_params_f32(dst, 4);\n                    params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));\n                    params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));\n                    params.push_back(*reinterpret_cast<const uint32_t *>(&beta));\n                    params.push_back(*reinterpret_cast<const uint32_t *>(&eps));\n                    break;\n                }\n          
  default:\n                break;\n        }\n    } else if (dst->op == GGML_OP_CLAMP) {\n        float clamp_min = ggml_get_op_params_f32(dst, 0);\n        float clamp_max = ggml_get_op_params_f32(dst, 1);\n        params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_min));\n        params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_max));\n    } else if (dst->op == GGML_OP_FILL) {\n        float fill_val = ggml_get_op_params_f32(dst, 0);\n        params.push_back(*reinterpret_cast<const uint32_t *>(&fill_val));\n        effective_src = dst;  // fill simply fills dst\n    }\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(effective_src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, effective_src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, effective_src) },\n    };\n    if (!inplace) {\n        entries.push_back({ .binding = 1,\n                            .buffer  = ggml_webgpu_tensor_buf(dst),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });\n    }\n\n    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,\n                                            ggml_tensor *    src0,\n                                            ggml_tensor *    src1,\n                                            ggml_tensor *    dst) {\n    binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0        = src0,\n        .src1        = src1,\n        .dst         = dst,\n        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n        .inplace     = 
flags.inplace,\n        .overlap     = flags.overlap,\n        .src_overlap = flags.src_overlap,\n    };\n\n    webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);\n\n    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());\n\n    uint32_t ne = (uint32_t) ggml_nelements(dst);\n\n    size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);\n    size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);\n\n    uint32_t offset_merged_src0 = 0;\n    uint32_t offset_merged_src1 = 0;\n    if (flags.src_overlap) {\n        size_t min_off     = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);\n        offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));\n        offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));\n    }\n\n    std::vector<uint32_t> params = {\n        ne,\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        offset_merged_src0,\n        offset_merged_src1,\n        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),\n        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),\n        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),\n        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),\n        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),\n        (uint32_t) src0->ne[0],\n        (uint32_t) src0->ne[1],\n        (uint32_t) src0->ne[2],\n        
(uint32_t) src1->ne[0],\n        (uint32_t) src1->ne[1],\n        (uint32_t) src1->ne[2],\n        (uint32_t) src1->ne[3],\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries;\n\n    if (flags.src_overlap) {\n        size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);\n        size_t merged_end    = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),\n                                        src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));\n        entries.push_back({\n            .binding = 0,\n            .buffer  = ggml_webgpu_tensor_buf(src0),\n            .offset  = merged_offset,\n            .size    = merged_end - merged_offset,\n        });\n        entries.push_back({\n            .binding = 1,\n            .buffer  = ggml_webgpu_tensor_buf(dst),\n            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n            .size    = ggml_webgpu_tensor_binding_size(ctx, dst),\n        });\n    } else {\n        entries.push_back({\n            .binding = 0,\n            .buffer  = ggml_webgpu_tensor_buf(src0),\n            .offset  = src0_webgpu_tensor_align_offset,\n            .size    = ggml_webgpu_tensor_binding_size(ctx, src0),\n        });\n        entries.push_back({\n            .binding = 1,\n            .buffer  = ggml_webgpu_tensor_buf(src1),\n            .offset  = src1_webgpu_tensor_align_offset,\n            .size    = ggml_webgpu_tensor_binding_size(ctx, src1),\n        });\n        if (!flags.inplace && !flags.overlap) {\n            entries.push_back({\n                .binding = 2,\n                .buffer  = ggml_webgpu_tensor_buf(dst),\n                .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n                .size    = ggml_webgpu_tensor_binding_size(ctx, dst),\n            });\n        }\n    }\n\n    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);\n    return ggml_backend_webgpu_build(ctx->global_ctx, 
ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_concat(webgpu_context & ctx,\n                                         ggml_tensor *    src0,\n                                         ggml_tensor *    src1,\n                                         ggml_tensor *    dst) {\n    uint32_t ne  = (uint32_t) ggml_nelements(dst);\n    uint32_t dim = (uint32_t) dst->op_params[0];\n\n    std::vector<uint32_t> params = {\n        ne,\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),\n        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),\n        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),\n        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),\n        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),\n        (uint32_t) dst->ne[0],\n        (uint32_t) dst->ne[1],\n        (uint32_t) dst->ne[2],\n        (uint32_t) dst->ne[3],\n        dim,\n        (uint32_t) src0->ne[dim]\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src0),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(src1),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },\n        { .binding = 2,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n 
        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  }\n    };\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0        = src0,\n        .src1        = src1,\n        .dst         = dst,\n        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n    };\n\n    webgpu_pipeline pipeline  = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);\n    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());\n    uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {\n    uint32_t ne = (uint32_t) ggml_nelements(dst);\n\n    std::vector<uint32_t> params = { ne,\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /\n                                                 ggml_type_size(src0->type)),\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n                                     (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),\n                                     (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),\n                                     (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),\n                                     (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),\n                                     (uint32_t) (src0->ne[0]),\n                                     (uint32_t) (src0->ne[1]),\n                                     (uint32_t) (src0->ne[2]),\n                                     (uint32_t) (src0->ne[3]),\n                                     (uint32_t) (dst->ne[0]),\n                                     
(uint32_t) (dst->ne[1]),\n                                     (uint32_t) (dst->ne[2]) };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src0),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  }\n    };\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0        = src0,\n        .dst         = dst,\n        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n    };\n\n    webgpu_pipeline pipeline  = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);\n    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());\n    uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    int inplace = ggml_webgpu_tensor_equal(src, dst);\n\n    std::vector<uint32_t> params = {\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),\n        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),\n        (uint32_t) src->ne[0],\n        (uint32_t) src->ne[1],\n        
(uint32_t) src->ne[2],\n        (uint32_t) src->ne[3],\n        *(uint32_t *) dst->op_params  // epsilon, treated as f32 in the shader\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) }\n    };\n    if (!inplace) {\n        entries.push_back({ .binding = 1,\n                            .buffer  = ggml_webgpu_tensor_buf(dst),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });\n    }\n\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,\n                                     entries, ggml_nrows(src));\n}\n\nstatic webgpu_command ggml_webgpu_rope(webgpu_context & ctx,\n                                       ggml_tensor *    src0,\n                                       ggml_tensor *    src1,\n                                       ggml_tensor *    src2,\n                                       ggml_tensor *    dst) {\n    const int inplace         = ggml_webgpu_tensor_equal(src0, dst);\n    const int has_freq_factor = (src2 != nullptr);\n\n    const int n_dims     = ((int32_t *) dst->op_params)[1];\n    const int mode       = ((int32_t *) dst->op_params)[2];\n    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];\n\n    float freq_base;\n    float freq_scale;\n    float ext_factor;\n    float attn_factor;\n    float beta_fast;\n    float beta_slow;\n    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));\n    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));\n    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));\n    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));\n    memcpy(&beta_fast, (int32_t *) 
dst->op_params + 9, sizeof(float));\n    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));\n\n    int sections[4];\n    memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));\n\n    float theta_scale = powf(freq_base, -2.0f / n_dims);\n\n    float corr_dims[2];\n    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);\n\n    std::vector<uint32_t> params = {\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),\n        src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),\n        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),\n        (uint32_t) ggml_nelements(src0) / 2,\n        (uint32_t) src0->ne[0],\n        (uint32_t) src0->ne[1],\n        (uint32_t) src0->ne[2],\n        (uint32_t) n_dims,\n        (uint32_t) mode,\n        *(uint32_t *) &theta_scale,\n        *(uint32_t *) &attn_factor,\n        *(uint32_t *) &freq_scale,\n        *(uint32_t *) &ext_factor,\n        *(uint32_t *) &corr_dims[0],\n        *(uint32_t *) &corr_dims[1],\n        (uint32_t) sections[0],\n        (uint32_t) sections[1],\n        (uint32_t) sections[2],\n        (uint32_t) sections[3]\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src0),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),\n         .size    = 
ggml_webgpu_tensor_binding_size(ctx, src0) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(src1),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }\n    };\n    uint32_t dst_binding = 2;\n    if (has_freq_factor) {\n        dst_binding = 3;\n        entries.push_back({ .binding = 2,\n                            .buffer  = ggml_webgpu_tensor_buf(src2),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });\n    }\n    if (!inplace) {\n        entries.push_back({ .binding = dst_binding,\n                            .buffer  = ggml_webgpu_tensor_buf(dst),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });\n    }\n\n    webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];\n    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {\n    const int split = (src1 != nullptr);\n\n    std::vector<uint32_t> params = {\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),\n        src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),\n        src1 != nullptr ? 
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :\n                          (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),\n        src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :\n                          (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),\n        src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :\n                          (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),\n        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),\n        (uint32_t) ggml_nelements(dst),\n        (uint32_t) dst->ne[0],\n        (uint32_t) dst->ne[1],\n        (uint32_t) dst->ne[2],\n        (uint32_t) ((int32_t *) dst->op_params)[1],  // swapped\n        *(uint32_t *) &dst->op_params[2],            // alpha, for swiglu_oai\n        *(uint32_t *) &dst->op_params[3],            // limit, for swiglu_oai\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src0),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },\n    };\n    uint32_t dst_binding = 1;\n    if (split) {\n        dst_binding = 2;\n        entries.push_back({ .binding = 1,\n                            .buffer  = ggml_webgpu_tensor_buf(src1),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });\n    }\n    entries.push_back({ .binding = dst_binding,\n                        .buffer  = ggml_webgpu_tensor_buf(dst),\n                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });\n\n    webgpu_pipeline pipeline = 
ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];\n    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    bool inplace = ggml_webgpu_tensor_equal(src, dst);\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0        = src,\n        .src1        = nullptr,\n        .dst         = dst,\n        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n        .inplace     = inplace,\n    };\n\n    webgpu_pipeline pipeline  = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);\n    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());\n\n    // params unchanged\n    std::vector<uint32_t> params = {\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),\n        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),\n        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),\n        (uint32_t) ggml_nelements(dst),\n        (uint32_t) src->ne[0],\n        (uint32_t) src->ne[1],\n        (uint32_t) src->ne[2],\n        *(uint32_t *) dst->op_params,     // scale\n        *(uint32_t *) &dst->op_params[1]  // bias\n    };\n\n    // bindgroups unchanged\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n       
  .size    = ggml_webgpu_tensor_binding_size(ctx, src) }\n    };\n\n    if (!inplace) {\n        entries.push_back({ .binding = 1,\n                            .buffer  = ggml_webgpu_tensor_buf(dst),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });\n    }\n\n    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,\n                                           ggml_tensor *    src0,\n                                           ggml_tensor *    src1,\n                                           ggml_tensor *    src2,\n                                           ggml_tensor *    dst) {\n    const int inplace   = ggml_webgpu_tensor_equal(src0, dst);\n    const int mask_type = (src1 != nullptr) ? src1->type : 2;  // use 2 for no mask here\n    const int has_sink  = (src2 != nullptr);\n    float     max_bias;\n    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));\n    float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));\n    float m0          = powf(2.0f, -(max_bias) / n_head_log2);\n    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);\n\n    std::vector<uint32_t> params = {\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),\n        mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,\n        has_sink ? 
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,\n        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),\n        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),\n        mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,\n        mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,\n        mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,\n        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),\n        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),\n        (uint32_t) ggml_nelements(dst),\n        (uint32_t) src0->ne[0],\n        (uint32_t) src0->ne[1],\n        (uint32_t) src0->ne[2],\n        mask_type < 2 ? (uint32_t) src1->ne[2] : 0,\n        mask_type < 2 ? 
(uint32_t) src1->ne[3] : 0,\n        *(uint32_t *) dst->op_params,  // scale\n        *(uint32_t *) &max_bias,\n        *(uint32_t *) &n_head_log2,\n        *(uint32_t *) &m0,\n        *(uint32_t *) &m1\n    };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src0),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) }\n    };\n    uint32_t binding_num = 1;\n    if (mask_type < 2) {\n        entries.push_back({ .binding = binding_num,\n                            .buffer  = ggml_webgpu_tensor_buf(src1),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });\n        binding_num++;\n    }\n    if (has_sink) {\n        entries.push_back({ .binding = binding_num,\n                            .buffer  = ggml_webgpu_tensor_buf(src2),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });\n        binding_num++;\n    }\n    if (!inplace) {\n        entries.push_back({ .binding = binding_num,\n                            .buffer  = ggml_webgpu_tensor_buf(dst),\n                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });\n    }\n\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,\n                                     ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,\n                                     ggml_nrows(dst));\n}\n\nstatic webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / 
ggml_type_size(src->type)),\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n                                     (uint32_t) src->ne[0] };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }\n    };\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup\n    };\n\n    webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);\n    uint32_t        wg_x     = ggml_nelements(dst);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    bool is_top_k = dst->op == GGML_OP_TOP_K;\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0               = src,\n        .src1               = nullptr,\n        .dst                = dst,\n        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,\n    };\n\n    webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);\n    auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());\n\n    webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);\n\n    const 
uint32_t src_ne0 = (uint32_t) src->ne[0];\n    const uint32_t nrows   = (uint32_t) ggml_nrows(src);\n    const uint32_t npr     = CEIL_DIV(src_ne0, argsort_decisions->wg_size);\n    const uint32_t block_size =\n        is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;\n    uint32_t out_ne0 = src_ne0;\n    if (is_top_k) {\n        if (npr > 1) {\n            const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;\n            out_ne0                  = (npr - 1) * block_size + std::min(last_tile, block_size);\n        } else {\n            out_ne0 = block_size;\n        }\n    }\n\n    uint32_t merge_len    = block_size;\n    uint32_t merge_passes = 0;\n    while (merge_len < out_ne0) {\n        merge_len <<= 1;\n        merge_passes++;\n    }\n\n    const bool start_in_tmp = (merge_passes % 2) == 1;\n\n    const size_t dst_offset = ggml_webgpu_tensor_offset(dst);\n    const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);\n    const size_t tmp_offset =\n        ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);\n    const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);\n    const size_t dst_binding_size =\n        ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);\n\n    const uint32_t offset_src  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));\n    const uint32_t offset_dst  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));\n    const uint32_t offset_tmp  = 0;\n    const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));\n    const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));\n    const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));\n    const uint32_t stride_idx1 = out_ne0;\n  
  const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];\n    const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];\n\n    std::vector<webgpu_pipeline>                   pipelines;\n    std::vector<std::vector<uint32_t>>             params_list;\n    std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;\n    std::vector<std::pair<uint32_t, uint32_t>>     workgroups_list;\n\n    const uint32_t init_offset       = start_in_tmp ? offset_tmp : offset_dst;\n    const size_t   init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);\n    const size_t   init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;\n\n    std::vector<uint32_t> init_params = {\n        offset_src,  init_offset, stride_src1, stride_src2,           stride_src3,           stride_idx1,\n        stride_idx2, stride_idx3, src_ne0,     (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,\n        block_size,  npr,         nrows\n    };\n\n    const uint32_t                    total_wg_init = npr * nrows;\n    const uint32_t                    max_wg    = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;\n    const uint32_t                    wg_x_init = std::min(total_wg_init, max_wg);\n    const uint32_t                    wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);\n    std::vector<wgpu::BindGroupEntry> init_entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n        { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }\n    };\n\n    pipelines.push_back(argsort_pipeline);\n    params_list.push_back(std::move(init_params));\n    entries_list.push_back(std::move(init_entries));\n    workgroups_list.push_back({ wg_x_init, wg_y_init });\n\n    if (merge_passes == 0) {\n        return 
ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,\n                                               entries_list, workgroups_list);\n    }\n\n    bool     in_is_tmp = start_in_tmp;\n    uint32_t len       = block_size;\n    while (len < out_ne0) {\n        const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);\n\n        const bool     out_is_tmp  = !in_is_tmp;\n        const uint32_t offset_in   = in_is_tmp ? offset_tmp : offset_dst;\n        const uint32_t offset_out  = out_is_tmp ? offset_tmp : offset_dst;\n        const size_t   align_in    = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);\n        const size_t   align_out   = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);\n        const size_t   size_in     = in_is_tmp ? tmp_binding_size : dst_binding_size;\n        const size_t   size_out    = out_is_tmp ? tmp_binding_size : dst_binding_size;\n        const uint32_t top_k_out   = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;\n        const uint32_t stride_out1 = top_k_out;\n        const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];\n        const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];\n\n        std::vector<uint32_t> merge_params = { offset_src,\n                                               offset_in,\n                                               offset_out,\n                                               stride_src1,\n                                               stride_src2,\n                                               stride_src3,\n                                               stride_idx1,\n                                               stride_idx2,\n                                               stride_idx3,\n                                               stride_out1,\n                                               stride_out2,\n                                               stride_out3,\n                                       
        out_ne0,\n                                               (uint32_t) src->ne[1],\n                                               (uint32_t) src->ne[2],\n                                               top_k_out,\n                                               len,\n                                               nm,\n                                               nrows };\n\n        std::vector<wgpu::BindGroupEntry> merge_entries = {\n            { .binding = 0,\n             .buffer  = ggml_webgpu_tensor_buf(src),\n             .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n             .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n            { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },\n            { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }\n        };\n\n        const uint32_t total_wg_merge = nm * nrows;\n        const uint32_t wg_x_merge     = std::min(total_wg_merge, max_wg);\n        const uint32_t wg_y_merge     = CEIL_DIV(total_wg_merge, wg_x_merge);\n        workgroups_list.push_back({ wg_x_merge, wg_y_merge });\n        pipelines.push_back(argsort_merge_pipeline);\n        params_list.push_back(std::move(merge_params));\n        entries_list.push_back(std::move(merge_entries));\n\n        len <<= 1;\n        in_is_tmp = !in_is_tmp;\n    }\n\n    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,\n                                           workgroups_list);\n}\n\nstatic webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n                                     (uint32_t) src->ne[0] };\n\n 
   std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }\n    };\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0        = src,\n        .src1        = nullptr,\n        .dst         = dst,\n        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,\n    };\n\n    webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);\n    uint32_t        wg_x     = ggml_nrows(dst);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\nstatic webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {\n    bool                  total_sum = dst->op == GGML_OP_SUM;\n    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),\n                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),\n                                     total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),\n                                     total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),\n                                     total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),\n                                     total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],\n                                     total_sum ? 1 : (uint32_t) src->ne[1],\n                                     total_sum ? 
1 : (uint32_t) src->ne[2] };\n\n    std::vector<wgpu::BindGroupEntry> entries = {\n        { .binding = 0,\n         .buffer  = ggml_webgpu_tensor_buf(src),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },\n        { .binding = 1,\n         .buffer  = ggml_webgpu_tensor_buf(dst),\n         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),\n         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }\n    };\n\n    ggml_webgpu_shader_lib_context shader_lib_ctx = {\n        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup\n    };\n\n    webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);\n\n    uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);\n    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);\n}\n\n// Returns the encoded command, or std::nullopt if the operation is a no-op\nstatic std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {\n    if (ggml_is_empty(node)) {\n        return std::nullopt;\n    }\n    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n        return std::nullopt;\n    }\n    WEBGPU_LOG_DEBUG(\"ggml_webgpu_encode_node(\" << node << \", \" << ggml_op_name(node->op) << \")\");\n\n    ggml_tensor * src0 = node->src[0];\n    ggml_tensor * src1 = node->src[1];\n    ggml_tensor * src2 = node->src[2];\n\n    switch (node->op) {\n            // no-ops\n        case GGML_OP_NONE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_RESHAPE:\n            return std::nullopt;\n        case GGML_OP_CPY:\n        case GGML_OP_CONT:\n            return ggml_webgpu_cpy(ctx, src0, node);\n        case GGML_OP_SET_ROWS:\n            return ggml_webgpu_set_rows(ctx, src0, src1, node);\n        case GGML_OP_GET_ROWS:\n        
    return ggml_webgpu_get_rows(ctx, src0, src1, node);\n        case GGML_OP_MUL_MAT:\n            return ggml_webgpu_mul_mat(ctx, src0, src1, node);\n        case GGML_OP_FLASH_ATTN_EXT:\n#ifndef __EMSCRIPTEN__\n            return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);\n#else\n            return std::nullopt;\n#endif\n        case GGML_OP_ADD:\n        case GGML_OP_SUB:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n            return ggml_webgpu_binary_op(ctx, src0, src1, node);\n        case GGML_OP_CONCAT:\n            return ggml_webgpu_concat(ctx, src0, src1, node);\n        case GGML_OP_REPEAT:\n            return ggml_webgpu_repeat(ctx, src0, node);\n        case GGML_OP_RMS_NORM:\n            return ggml_webgpu_rms_norm(ctx, src0, node);\n        case GGML_OP_ROPE:\n            return ggml_webgpu_rope(ctx, src0, src1, src2, node);\n        case GGML_OP_GLU:\n            return ggml_webgpu_glu(ctx, src0, src1, node);\n        case GGML_OP_SCALE:\n            return ggml_webgpu_scale(ctx, src0, node);\n        case GGML_OP_SOFT_MAX:\n            return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);\n        case GGML_OP_UNARY:\n        case GGML_OP_CLAMP:\n        case GGML_OP_FILL:\n        case GGML_OP_LOG:\n        case GGML_OP_SQR:\n        case GGML_OP_SQRT:\n        case GGML_OP_SIN:\n        case GGML_OP_COS:\n            return ggml_webgpu_unary_op(ctx, src0, node);\n        case GGML_OP_PAD:\n            return ggml_webgpu_pad(ctx, src0, node);\n        case GGML_OP_ARGMAX:\n            return ggml_webgpu_argmax(ctx, src0, node);\n        case GGML_OP_ARGSORT:\n        case GGML_OP_TOP_K:\n            // we reuse the same argsort implementation for top_k\n            return ggml_webgpu_argsort(ctx, src0, node);\n        case GGML_OP_CUMSUM:\n            return ggml_webgpu_cumsum(ctx, src0, node);\n        case GGML_OP_SUM:\n        case GGML_OP_SUM_ROWS:\n            return 
ggml_webgpu_sum_rows(ctx, src0, node);\n        default:\n            return std::nullopt;\n    }\n}\n\nstatic ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {\n    WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_graph_compute(\" << cgraph->n_nodes << \" nodes)\");\n\n    ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;\n    webgpu_context                ctx         = backend_ctx->webgpu_ctx;\n\n    WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);\n\n    std::vector<webgpu_command>    commands;\n    std::vector<webgpu_submission> subs;\n    uint32_t                       num_batched_kernels = 0;\n    bool                           contains_set_rows   = false;\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {\n            contains_set_rows = true;\n        }\n        if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {\n            commands.push_back(*cmd);\n            num_batched_kernels += cmd.value().num_kernels;\n        }\n\n        if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {\n            num_batched_kernels = 0;\n            subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));\n            // Process events and check for completed submissions\n            ctx->global_ctx->instance.ProcessEvents();\n            ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);\n            commands.clear();\n        }\n    }\n    if (!commands.empty()) {\n        subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));\n        commands.clear();\n    }\n\n    // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.\n    if (contains_set_rows) {\n        wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();\n        
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,\n                                   ctx->set_rows_host_error_buf.GetSize());\n        wgpu::CommandBuffer set_rows_commands = encoder.Finish();\n        ctx->global_ctx->queue.Submit(1, &set_rows_commands);\n        ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,\n                                       ctx->set_rows_host_error_buf.GetSize());\n        const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();\n        if (*error_data) {\n            GGML_ABORT(\"ggml_webgpu: SET_ROWS index > 2^32, unsupported.\");\n        }\n        ctx->set_rows_host_error_buf.Unmap();\n    }\n\n    ggml_backend_webgpu_wait(ctx->global_ctx, subs);\n    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);\n    return GGML_STATUS_SUCCESS;\n}\n\nstatic ggml_backend_i ggml_backend_webgpu_i = {\n    /* .get_name                = */ ggml_backend_webgpu_name,\n    /* .free                    = */ ggml_backend_webgpu_free,\n    /* .set_tensor_async        = */ NULL,\n    /* .get_tensor_async        = */ NULL,\n    /* .cpy_tensor_async        = */ NULL,\n    /* .synchronize             = */ NULL,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_webgpu_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ NULL,\n};\n\n/* End GGML Backend Interface */\n\n/* GGML Backend Buffer Interface */\n\nstatic void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);\n    if (ctx != nullptr && ctx->buffer != nullptr) 
{\n        ctx->buffer.Destroy();\n        delete ctx;\n    }\n}\n\n// Returns the \"fake\" base pointer.\nstatic void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {\n    GGML_UNUSED(buffer);\n    return webgpu_ptr_base;\n}\n\nstatic void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,\n                                                     ggml_tensor *         tensor,\n                                                     uint8_t               value,\n                                                     size_t                offset,\n                                                     size_t                size) {\n    if (size == 0) {\n        WEBGPU_LOG_DEBUG(\n            \"ggml_backend_webgpu_buffer_memset_tensor: size is zero, \"\n            \"nothing to do.\");\n        return;\n    }\n\n    WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);\n\n    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;\n\n    WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_buffer_memset_tensor(\" << buf_ctx->label << \", \" << tensor << \", \" << value\n                                                                 << \", \" << offset << \", \" << size << \")\");\n\n    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;\n\n    // This is a trick to set all bytes of a u32 to the same 1 byte value.\n    uint32_t val32 = (uint32_t) value * 0x01010101;\n    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);\n    WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);\n}\n\nstatic void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,\n                                                  ggml_tensor *         tensor,\n                                                  const void *          data,\n                                                  size_t                offset,\n                      
                            size_t                size) {\n    WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);\n    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;\n\n    WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_buffer_set_tensor(\" << buf_ctx->label << \", \" << tensor << \", \" << data\n                                                              << \", \" << offset << \", \" << size << \")\");\n\n    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;\n\n    buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);\n\n    if (size % 4 != 0) {\n        // If size is not a multiple of 4, we need to memset the remaining bytes\n        size_t remaining_size = size % 4;\n\n        // pack the remaining bytes into a uint32_t\n        uint32_t val32 = 0;\n\n        for (size_t i = 0; i < remaining_size; i++) {\n            ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];\n        }\n        // memset the remaining bytes\n        ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,\n                                          total_offset + (size - remaining_size), remaining_size);\n    } else {\n        // wait for WriteBuffer to complete\n        buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(\n                                                  wgpu::CallbackMode::AllowSpontaneous,\n                                                  [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {\n                                                      if (status != wgpu::QueueWorkDoneStatus::Success) {\n                                                          GGML_LOG_ERROR(\"ggml_webgpu: Failed to submit commands: %s\\n\",\n                                                                         std::string(message).c_str());\n                                         
             }\n                                                  }),\n                                              UINT64_MAX);\n    }\n    WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);\n}\n\nstatic void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,\n                                                  const ggml_tensor *   tensor,\n                                                  void *                data,\n                                                  size_t                offset,\n                                                  size_t                size) {\n    WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);\n    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;\n    WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_buffer_get_tensor(\" << buf_ctx->label << \", \" << tensor << \", \" << data\n                                                              << \", \" << offset << \", \" << size << \")\");\n    wgpu::Device device = buf_ctx->global_ctx->device;\n\n    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;\n\n    size_t final_size = size;\n    if (size % 4 != 0) {\n        // If size is not a multiple of 4, we need to round it up to the next\n        // multiple of 4\n        final_size = size + (4 - (size % 4));\n    }\n\n    std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);\n\n    if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||\n        buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {\n        // Create a new staging buffer if it doesn't exist or is too small\n        if (buf_ctx->global_ctx->get_tensor_staging_buf) {\n            buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();\n        }\n        ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,\n                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, 
\"get_tensor_staging_buf\");\n    }\n\n    // Copy the data from the buffer to the staging buffer\n    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();\n    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,\n                               final_size);\n    wgpu::CommandBuffer commands = encoder.Finish();\n\n    // Submit the command buffer to the queue\n    buf_ctx->global_ctx->queue.Submit(1, &commands);\n\n    // Map the staging buffer to read the data\n    ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,\n                                   wgpu::MapMode::Read, 0, final_size);\n    // Must specify size here since the staging buffer might be larger than the tensor size\n    const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);\n\n    // Copy the data from the mapped range to the output buffer\n    std::memcpy(data, mapped_range, size);\n    buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();\n    WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);\n}\n\nstatic void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_buffer_clear(\" << buffer << \", \" << (uint32_t) value << \")\");\n    WEBGPU_CPU_PROFILE_TOTAL_START(clear);\n    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;\n    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);\n    WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);\n}\n\nstatic ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {\n    /* .free_buffer     = */ ggml_backend_webgpu_buffer_free_buffer,\n    /* .get_base        = */ ggml_backend_webgpu_buffer_get_base,\n    /* .init_tensor     = */ NULL,  // TODO: optional, needed?\n    /* .memset_tensor   = */ 
ggml_backend_webgpu_buffer_memset_tensor,\n    /* .set_tensor      = */ ggml_backend_webgpu_buffer_set_tensor,\n    /* .get_tensor      = */ ggml_backend_webgpu_buffer_get_tensor,\n    /* .cpy_tensor      = */ NULL,  // TODO: optional, implement this\n    /* .clear           = */ ggml_backend_webgpu_buffer_clear,\n    /* .reset           = */ NULL,  // TODO: optional, think it coordinates with\n                                    // .init_tensor\n};\n\n/* End GGML Backend Buffer Interface */\n\n/* GGML Backend Buffer Type Interface */\n\nstatic const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);\n    return ctx->device_name.c_str();\n}\n\nstatic ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,\n                                                                          size_t                     size) {\n    static std::atomic<int> buffer_count;\n    int                     buffer_id = buffer_count++;\n    std::string             buf_name  = \"tensor_buf\" + std::to_string(buffer_id);\n    WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_buffer_type_alloc_buffer_\" << buffer_id << \": \" << size << \" bytes\");\n\n    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);\n    wgpu::Buffer                         buf;\n    ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),\n                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,\n                              buf_name.c_str());\n\n    ggml_backend_webgpu_buffer_context * buf_ctx =\n        new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);\n\n    return ggml_backend_buffer_init(buft, 
ggml_backend_webgpu_buffer_interface, buf_ctx, size);\n}\n\nstatic size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    ggml_backend_webgpu_device_context * dev_ctx =\n        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);\n    return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;\n}\n\n// maxBufferSize might be larger, but you can't bind more than\n// maxStorageBufferBindingSize to a single binding.\nstatic size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {\n    ggml_backend_webgpu_device_context * dev_ctx =\n        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);\n    return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;\n}\n\nstatic size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,\n                                                             const ggml_tensor *        tensor) {\n    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);\n    size_t                               res = ggml_nbytes(tensor);\n    switch (tensor->op) {\n        case GGML_OP_ARGSORT:\n            res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,\n                               WEBGPU_STORAGE_BUF_BINDING_MULT);\n            break;\n        case GGML_OP_TOP_K:\n            {\n                const ggml_tensor * src0 = tensor->src[0];\n                if (src0) {\n                    const size_t full = sizeof(int32_t) * ggml_nelements(src0);\n                    res               = ROUNDUP_POW2(\n                        full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,\n                        WEBGPU_STORAGE_BUF_BINDING_MULT);\n                }\n            }\n            break;\n        default:\n         
   break;\n    }\n    return res;\n}\n\n/* End GGML Backend Buffer Type Interface */\n\n/* GGML Backend Device Interface */\n\nstatic const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {\n    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);\n    return ctx->device_name.c_str();\n}\n\nstatic const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {\n    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);\n    return ctx->device_desc.c_str();\n}\n\nstatic void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);\n    // TODO: for now, return maxBufferSize as both free and total memory\n    // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.\n    uint64_t                             max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;\n    // If we're on a 32-bit system, clamp to UINTPTR_MAX\n#if UINTPTR_MAX < UINT64_MAX\n    uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);\n    if (max_buffer_size > max_ptr_size) {\n        max_buffer_size = max_ptr_size;\n    }\n#endif\n    *free  = static_cast<size_t>(max_buffer_size);\n    *total = static_cast<size_t>(max_buffer_size);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {\n    GGML_UNUSED(dev);\n    return GGML_BACKEND_DEVICE_TYPE_GPU;\n}\n\nstatic void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_webgpu_device_get_name(dev);\n    props->description = ggml_backend_webgpu_device_get_description(dev);\n    props->type        = ggml_backend_webgpu_device_get_type(dev);\n    ggml_backend_webgpu_device_get_memory(dev, 
&props->memory_free, &props->memory_total);\n    props->caps = {\n        /* .async                 = */ false,\n        /* .host_buffer           = */ false,\n        /* .buffer_from_host_ptr  = */ false,\n        /* .events                = */ false,\n    };\n}\n\nstatic ggml_guid_t ggml_backend_webgpu_guid(void) {\n    static const char * guid_str = \"__ggml_webgpu :)\";\n    return reinterpret_cast<ggml_guid_t>((void *) guid_str);\n}\n\nstatic void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {\n    // we use the maximum workgroup size for the memset pipeline\n    size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;\n    // Size the bytes_per_thread so that the largest buffer size can be handled\n    ctx->capabilities.memset_bytes_per_thread =\n        CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);\n    std::vector<wgpu::ConstantEntry> constants(2);\n    constants[0].key         = \"wg_size\";\n    constants[0].value       = WEBGPU_MAX_WG_SIZE;\n    constants[1].key         = \"bytes_per_thread\";\n    constants[1].value       = ctx->capabilities.memset_bytes_per_thread;\n    ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, \"memset\", constants);\n}\n\nstatic void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {\n    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);\n\n    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, \"cpy_f32_f32\", constants);\n    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, \"cpy_f32_i32\", constants);\n    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, \"cpy_f32_f16\", 
constants);\n    webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, \"cpy_f16_f32\", constants);\n    webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, \"cpy_f16_f16\", constants);\n}\n\nstatic void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {\n    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);\n\n    webgpu_ctx->rms_norm_pipelines[0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, \"rms_norm\", constants);\n    webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, \"rms_norm_inplace\", constants);\n}\n\nstatic void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {\n    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);\n\n    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, \"rope_f32\", constants);\n    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, \"rope_f32_inplace\", constants);\n    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, \"rope_f32_ff\", constants);\n    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, \"rope_f32_ff_inplace\", constants);\n\n    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, \"rope_f16\", constants);\n    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = 
ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, \"rope_f16_inplace\", constants);\n    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, \"rope_f16_ff\", constants);\n    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, \"rope_f16_ff_inplace\", constants);\n}\n\nstatic void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {\n    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);\n\n    // REGLU\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, \"reglu_f32\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, \"reglu_f16\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, \"reglu_f32_split\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, \"reglu_f16_split\", constants);\n\n    // GEGLU\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, \"geglu_f32\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, \"geglu_f16\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, \"geglu_f32_split\", 
constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, \"geglu_f16_split\", constants);\n\n    // SWIGLU\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, \"swiglu_f32\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, \"swiglu_f16\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, \"swiglu_f32_split\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, \"swiglu_f16_split\", constants);\n\n    // SWIGLU_OAI\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, \"swiglu_oai_f32\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, \"swiglu_oai_f32_split\", constants);\n\n    // GEGLU_ERF\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, \"geglu_erf_f32\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, \"geglu_erf_f16\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, 
\"geglu_erf_f32_split\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, \"geglu_erf_f16_split\", constants);\n\n    // GEGLU_QUICK\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, \"geglu_quick_f32\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, \"geglu_quick_f16\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, \"geglu_quick_f32_split\", constants);\n    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, \"geglu_quick_f16_split\", constants);\n}\n\nstatic void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {\n    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);\n\n    // f32 (no mask)\n    webgpu_ctx->soft_max_pipelines[2][0][0] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, \"soft_max_f32\", constants);\n    webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, \"soft_max_f32_inplace\", constants);\n    webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, \"soft_max_f32_sink\", constants);\n    webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, 
\"soft_max_f32_sink_inplace\", constants);\n\n    // f32 mask (mask_type = 0)\n    webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, \"soft_max_f32_mask_f32\", constants);\n    webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, \"soft_max_f32_mask_f32_inplace\", constants);\n    webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, \"soft_max_f32_mask_f32_sink\", constants);\n    webgpu_ctx->soft_max_pipelines[0][1][1] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,\n                                    \"soft_max_f32_mask_f32_sink_inplace\", constants);\n\n    // f16 mask (mask_type = 1)\n    webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, \"soft_max_f32_mask_f16\", constants);\n    webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, \"soft_max_f32_mask_f16_inplace\", constants);\n    webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(\n        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, \"soft_max_f32_mask_f16_sink\", constants);\n    webgpu_ctx->soft_max_pipelines[1][1][1] =\n        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,\n                                    \"soft_max_f32_mask_f16_sink_inplace\", constants);\n}\n\nstatic bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {\n    wgpu::RequestAdapterOptions options = {};\n\n#ifndef __EMSCRIPTEN__\n    // TODO: track need for these toggles: 
https://issues.chromium.org/issues/42251215\n    const char * const          adapterEnabledToggles[] = { \"vulkan_enable_f16_on_nvidia\", \"use_vulkan_memory_model\" };\n    wgpu::DawnTogglesDescriptor adapterTogglesDesc;\n    adapterTogglesDesc.enabledToggles     = adapterEnabledToggles;\n    adapterTogglesDesc.enabledToggleCount = 2;\n    options.nextInChain                   = &adapterTogglesDesc;\n#endif\n\n    ctx->webgpu_global_ctx->instance.WaitAny(\n        ctx->webgpu_global_ctx->instance.RequestAdapter(\n            &options, wgpu::CallbackMode::AllowSpontaneous,\n            [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {\n                if (status != wgpu::RequestAdapterStatus::Success) {\n                    GGML_LOG_ERROR(\"ggml_webgpu: Failed to get an adapter: %s\\n\", message);\n                    return;\n                }\n                ctx->webgpu_global_ctx->adapter = std::move(adapter);\n            }),\n        UINT64_MAX);\n    GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);\n\n    ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);\n\n    wgpu::AdapterInfo info{};\n#ifndef __EMSCRIPTEN__\n    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};\n    if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {\n        info.nextInChain = &subgroup_matrix_configs;\n    }\n#endif\n    ctx->webgpu_global_ctx->adapter.GetInfo(&info);\n    wgpu::SupportedFeatures features;\n    ctx->webgpu_global_ctx->adapter.GetFeatures(&features);\n    // we require f16 support\n    GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));\n\n#ifndef __EMSCRIPTEN__\n    // Only support square f16 matrices of size 8 or 16 for now\n    bool valid_subgroup_matrix_config = false;\n    if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) 
{\n        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {\n            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];\n            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&\n                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&\n                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {\n                ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;\n                ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;\n                ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;\n                valid_subgroup_matrix_config                  = true;\n                break;\n            }\n        }\n    }\n    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;\n#endif\n\n    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.\n    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.\n    ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;\n    // Initialize device\n    std::vector<wgpu::FeatureName> required_features       = { wgpu::FeatureName::ShaderF16 };\n\n#ifndef __EMSCRIPTEN__\n    required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);\n    if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {\n        required_features.push_back(wgpu::FeatureName::Subgroups);\n        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);\n    }\n#endif\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    required_features.push_back(wgpu::FeatureName::TimestampQuery);\n#endif\n\n    wgpu::DeviceDescriptor dev_desc;\n    dev_desc.requiredLimits       = &ctx->webgpu_global_ctx->capabilities.limits;\n    dev_desc.requiredFeatures     = required_features.data();\n   
 dev_desc.requiredFeatureCount = required_features.size();\n    dev_desc.SetDeviceLostCallback(\n        wgpu::CallbackMode::AllowSpontaneous,\n        [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {\n            if (reason == wgpu::DeviceLostReason::Destroyed) {\n                return;\n            }\n            GGML_UNUSED(device);\n            GGML_LOG_ERROR(\"ggml_webgpu: Device lost! Reason: %d, Message: %s\\n\", static_cast<int>(reason),\n                           std::string(message).c_str());\n        });\n    dev_desc.SetUncapturedErrorCallback(\n        [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {\n            GGML_UNUSED(device);\n            GGML_ABORT(\"ggml_webgpu: Device error! Reason: %d, Message: %s\\n\", static_cast<int>(reason),\n                       std::string(message).c_str());\n        });\n\n#ifndef __EMSCRIPTEN__\n    // Enable Dawn-specific toggles to increase native performance\n    // TODO: Maybe WebGPU needs a \"fast\" mode where you can request compilers skip adding checks like these,\n    //       only for native performance?\n    const char * const deviceEnabledToggles[]  = { \"skip_validation\", \"disable_robustness\", \"disable_workgroup_init\",\n                                                   \"disable_polyfills_on_integer_div_and_mod\" };\n    const char * const deviceDisabledToggles[] = { \"timestamp_quantization\" };\n    wgpu::DawnTogglesDescriptor deviceTogglesDesc;\n    deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;\n    deviceTogglesDesc.enabledToggleCount  = 4;\n    deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;\n    deviceTogglesDesc.disabledToggleCount = 1;\n\n    dev_desc.nextInChain = &deviceTogglesDesc;\n#endif\n\n    ctx->webgpu_global_ctx->instance.WaitAny(\n        ctx->webgpu_global_ctx->adapter.RequestDevice(\n            &dev_desc, wgpu::CallbackMode::AllowSpontaneous,\n            
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {\n                if (status != wgpu::RequestDeviceStatus::Success) {\n                    GGML_LOG_ERROR(\"ggml_webgpu: Failed to get a device: %s\\n\", std::string(message).c_str());\n                    return;\n                }\n                ctx->webgpu_global_ctx->device = std::move(device);\n            }),\n        UINT64_MAX);\n    GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);\n\n    ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);\n    ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,\n                                                 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,\n                                                 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);\n    ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();\n\n#ifdef GGML_WEBGPU_GPU_PROFILE\n    // Initialize buffer pool for timestamp queries, used for profiling\n    ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(\n        ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,\n        wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,\n        wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);\n#endif\n\n    GGML_LOG_INFO(\n        \"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | \"\n        \"device_desc: %s\\n\",\n        info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,\n        std::string(info.device).c_str(), std::string(info.description).c_str());\n    return true;\n}\n\nstatic webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {\n    ggml_backend_webgpu_device_context * dev_ctx    = (ggml_backend_webgpu_device_context *) dev->context;\n    webgpu_context                
       webgpu_ctx = std::make_shared<webgpu_context_struct>();\n    webgpu_ctx->global_ctx                          = dev_ctx->webgpu_global_ctx;\n    webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);\n    webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,\n                                    wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,\n                                    wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);\n    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,\n                              WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,\n                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, \"set_rows_dev_error_buf\");\n    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,\n                              WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,\n                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, \"set_rows_host_error_buf\");\n\n    ggml_webgpu_init_cpy_pipeline(webgpu_ctx);\n    ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);\n    ggml_webgpu_init_rope_pipeline(webgpu_ctx);\n    ggml_webgpu_init_glu_pipeline(webgpu_ctx);\n    ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);\n#ifdef GGML_WEBGPU_DEBUG\n    // Initialize debug buffers\n    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,\n                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),\n                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, \"debug_host_buf\");\n    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,\n                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),\n                              wgpu::BufferUsage::Storage | 
wgpu::BufferUsage::CopySrc, \"debug_dev_buf\");\n#endif\n    return webgpu_ctx;\n}\n\nstatic ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {\n    GGML_UNUSED(params);\n\n    WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_backend_init()\");\n\n    ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);\n\n    auto * backend_ctx      = new ggml_backend_webgpu_context();\n    backend_ctx->name       = GGML_WEBGPU_NAME + std::string(\": \") + dev_ctx->device_name;\n    backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);\n\n    // See GGML Backend Interface section\n    auto * backend = new ggml_backend();\n    *backend       = {\n        /* .guid      = */ ggml_backend_webgpu_guid(),\n        /* .interface = */ ggml_backend_webgpu_i,\n        /* .device    = */ dev,\n        /* .context   = */ backend_ctx,\n    };\n    return backend;\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {\n    // See GGML Backend Buffer Type Interface section\n\n    static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {\n        /* .iface = */ {\n                        /* .get_name         = */ ggml_backend_webgpu_buffer_type_get_name,\n                        /* .alloc_buffer     = */\n            ggml_backend_webgpu_buffer_type_alloc_buffer,                                    /* .get_alignment    = */\n            ggml_backend_webgpu_buffer_type_get_alignment,                                   /* .get_max_size     = */\n            ggml_backend_webgpu_buffer_type_get_max_size,                                    /* .get_alloc_size   = */\n            ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host          = */ NULL,  // defaults to false\n        },\n        /* .device  = */\n        dev,\n        /* .context = */\n        NULL\n    };\n\n    return &ggml_backend_webgpu_buffer_type;\n}\n\nstatic 
bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    GGML_UNUSED(dev);\n    return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;\n}\n\nstatic bool ggml_webgpu_supported_qtype(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_Q4_0:\n        case GGML_TYPE_Q4_1:\n        case GGML_TYPE_Q5_0:\n        case GGML_TYPE_Q5_1:\n        case GGML_TYPE_Q8_0:\n        case GGML_TYPE_Q2_K:\n        case GGML_TYPE_Q3_K:\n        case GGML_TYPE_Q4_K:\n        case GGML_TYPE_Q5_K:\n        case GGML_TYPE_Q6_K:\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ3_XXS:\n        case GGML_TYPE_IQ3_S:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:\n        case GGML_TYPE_IQ4_NL:\n        case GGML_TYPE_IQ4_XS:\n            return true;\n        default:\n            return false;\n    }\n}\n\nstatic bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);\n\n    ggml_tensor * src0 = op->src[0];\n    ggml_tensor * src1 = op->src[1];\n    ggml_tensor * src2 = op->src[2];\n\n    // on smaller devices (or CI), tensors may be larger than the max storage buffer size\n    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||\n        (src0 != nullptr &&\n         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||\n        (src1 != nullptr &&\n         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {\n        return false;\n    }\n\n    bool supports_op = false;\n    switch (op->op) {\n        case GGML_OP_NONE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_RESHAPE:\n      
      supports_op = true;\n            break;\n        case GGML_OP_ADD:\n        case GGML_OP_SUB:\n        case GGML_OP_MUL:\n        case GGML_OP_DIV:\n            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&\n                          (src1->type == op->type);\n            break;\n        case GGML_OP_CONCAT:\n            supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);\n            break;\n        case GGML_OP_REPEAT:\n            supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);\n            break;\n        case GGML_OP_CPY:\n        case GGML_OP_CONT:\n            supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&\n                           (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||\n                          (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);\n            break;\n        case GGML_OP_SET_ROWS:\n            supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&\n                           (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));\n            break;\n        case GGML_OP_GET_ROWS:\n            if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {\n                supports_op = (op->type == GGML_TYPE_F32);\n            } else if (src0->type == GGML_TYPE_I32) {\n                supports_op = op->type == GGML_TYPE_I32;\n            }\n            break;\n        case GGML_OP_MUL_MAT:\n            {\n                switch (src1->type) {\n                    case GGML_TYPE_F16:\n                        supports_op |= (src0->type == GGML_TYPE_F16);\n                        break;\n                    case GGML_TYPE_F32:\n                        switch (src0->type) {\n                            case GGML_TYPE_F32:\n               
             case GGML_TYPE_F16:\n                            case GGML_TYPE_Q4_0:\n                            case GGML_TYPE_Q4_1:\n                            case GGML_TYPE_Q5_0:\n                            case GGML_TYPE_Q5_1:\n                            case GGML_TYPE_Q8_0:\n                            case GGML_TYPE_Q2_K:\n                            case GGML_TYPE_Q3_K:\n                            case GGML_TYPE_Q4_K:\n                            case GGML_TYPE_Q5_K:\n                            case GGML_TYPE_Q6_K:\n                            case GGML_TYPE_IQ2_XXS:\n                            case GGML_TYPE_IQ2_XS:\n                            case GGML_TYPE_IQ2_S:\n                            case GGML_TYPE_IQ3_XXS:\n                            case GGML_TYPE_IQ3_S:\n                            case GGML_TYPE_IQ1_S:\n                            case GGML_TYPE_IQ1_M:\n                            case GGML_TYPE_IQ4_NL:\n                            case GGML_TYPE_IQ4_XS:\n                                supports_op = true;\n                                break;\n                            default:\n                                break;\n                        }\n                    default:\n                        break;\n                }\n                break;\n            }\n        case GGML_OP_FLASH_ATTN_EXT:\n            {\n#ifndef __EMSCRIPTEN__\n                if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {\n                    break;\n                }\n                // Head dimensions must fit in workgroup memory with minimum tile sizes\n                size_t     limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;\n                const bool has_mask    = op->src[3] != nullptr;\n                const bool kv_direct   = src1->type == GGML_TYPE_F16 &&\n                                       (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&\n                
                       (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;\n                const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(\n                    ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,\n                    (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);\n                if (min_bytes > limit_bytes) {\n                    break;\n                }\n\n                supports_op = src0->type == GGML_TYPE_F32 &&\n                              (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||\n                               src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&\n                              src2->type == src1->type && op->type == GGML_TYPE_F32;\n#endif\n                break;\n            }\n        case GGML_OP_RMS_NORM:\n            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;\n            break;\n        case GGML_OP_ROPE:\n            supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;\n            break;\n        case GGML_OP_GLU:\n            switch (ggml_get_glu_op(op)) {\n                case GGML_GLU_OP_REGLU:\n                case GGML_GLU_OP_GEGLU:\n                case GGML_GLU_OP_SWIGLU:\n                case GGML_GLU_OP_GEGLU_ERF:\n                case GGML_GLU_OP_GEGLU_QUICK:\n                    supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;\n                    break;\n                case GGML_GLU_OP_SWIGLU_OAI:\n                    supports_op = op->type == GGML_TYPE_F32;\n                    break;\n                default:\n                    break;\n            }\n            break;\n        case GGML_OP_SCALE:\n            supports_op = op->type == GGML_TYPE_F32;\n            break;\n        case GGML_OP_SOFT_MAX:\n            supports_op = op->type == GGML_TYPE_F32;\n            break;\n        case GGML_OP_UNARY:\n            {\n           
     const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);\n\n                switch (UNARY_OP) {\n                    case GGML_UNARY_OP_ABS:\n                    case GGML_UNARY_OP_SGN:\n                    case GGML_UNARY_OP_NEG:\n                    case GGML_UNARY_OP_STEP:\n                    case GGML_UNARY_OP_TANH:\n                    case GGML_UNARY_OP_ELU:\n                    case GGML_UNARY_OP_RELU:\n                    case GGML_UNARY_OP_SIGMOID:\n                    case GGML_UNARY_OP_GELU:\n                    case GGML_UNARY_OP_GELU_QUICK:\n                    case GGML_UNARY_OP_SILU:\n                    case GGML_UNARY_OP_HARDSWISH:\n                    case GGML_UNARY_OP_HARDSIGMOID:\n                    case GGML_UNARY_OP_EXP:\n                    case GGML_UNARY_OP_GELU_ERF:\n                    case GGML_UNARY_OP_SOFTPLUS:\n                    case GGML_UNARY_OP_EXPM1:\n                    case GGML_UNARY_OP_FLOOR:\n                    case GGML_UNARY_OP_CEIL:\n                    case GGML_UNARY_OP_ROUND:\n                    case GGML_UNARY_OP_TRUNC:\n                    case GGML_UNARY_OP_XIELU:\n                        supports_op =\n                            (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);\n                        break;\n                    default:\n                        break;\n                }\n            }\n            break;\n        case GGML_OP_CLAMP:\n            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);\n            break;\n        case GGML_OP_FILL:\n            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;\n            break;\n        case GGML_OP_LOG:\n            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);\n            break;\n        case GGML_OP_SQR:\n            supports_op = (op->type == GGML_TYPE_F32 || op->type == 
GGML_TYPE_F16) && (src0->type == op->type);\n            break;\n        case GGML_OP_SQRT:\n            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);\n            break;\n        case GGML_OP_SIN:\n            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);\n            break;\n        case GGML_OP_COS:\n            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);\n            break;\n        case GGML_OP_PAD:\n            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;\n            break;\n        case GGML_OP_ARGMAX:\n            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;\n            break;\n        case GGML_OP_ARGSORT:\n            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);\n            break;\n        case GGML_OP_TOP_K:\n            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);\n            break;\n        case GGML_OP_CUMSUM:\n            supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;\n            break;\n        case GGML_OP_SUM:\n        case GGML_OP_SUM_ROWS:\n            supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);\n            break;\n        default:\n            break;\n    }\n    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||\n        (src0 != nullptr &&\n         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||\n        (src1 != nullptr &&\n         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||\n        (src2 != nullptr &&\n         ggml_nbytes(src2) > 
ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {\n        supports_op = false;\n        WEBGPU_LOG_DEBUG(\"ggml_webgpu op not supported due to size: \");\n    }\n\n    if (!supports_op) {\n        WEBGPU_LOG_DEBUG(\"ggml_webgpu op not supported: \"\n                         << ggml_op_name(op->op) << \" with types dst: \" << ggml_type_name(op->type)\n                         << \", src0: \" << (op->src[0] ? ggml_type_name(op->src[0]->type) : \"null\")\n                         << \", src1: \" << (op->src[1] ? ggml_type_name(op->src[1]->type) : \"null\"));\n    } else {\n        WEBGPU_LOG_DEBUG(\"ggml_webgpu op supported: \"\n                         << ggml_op_name(op->op) << \" with types dst: \" << ggml_type_name(op->type)\n                         << \", src0: \" << (op->src[0] ? ggml_type_name(op->src[0]->type) : \"null\")\n                         << \", src1: \" << (op->src[1] ? ggml_type_name(op->src[1]->type) : \"null\"));\n    }\n    return supports_op;\n}\n\nstatic struct ggml_backend_device_i ggml_backend_webgpu_device_i = {\n    /* .get_name             = */ ggml_backend_webgpu_device_get_name,\n    /* .get_description      = */ ggml_backend_webgpu_device_get_description,\n    /* .get_memory           = */ ggml_backend_webgpu_device_get_memory,\n    /* .get_type             = */ ggml_backend_webgpu_device_get_type,\n    /* .get_props            = */ ggml_backend_webgpu_device_get_props,\n    /* .init_backend         = */ ggml_backend_webgpu_backend_init,\n    /* .get_buffer_type      = */ ggml_backend_webgpu_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,\n    /* .buffer_from_host_ptr = */ NULL,\n    /* .supports_op          = */ ggml_backend_webgpu_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_webgpu_device_supports_buft,\n    /* .offload_op           = */ NULL,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ 
NULL,\n};\n\n/* End GGML Backend Device Interface */\n\n/* GGML Backend Registration Interface */\n\nstatic const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {\n    ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);\n    return ctx->name;\n}\n\nstatic size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {\n    ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);\n    return ctx->device_count;\n}\n\n// Only one device is supported for now\nstatic ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    GGML_ASSERT(index == 0);\n    WEBGPU_LOG_DEBUG(\"ggml_backend_reg_get_device()\");\n\n    WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);\n\n    ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);\n\n    create_webgpu_device(reg_ctx);\n\n    static ggml_backend_webgpu_device_context device_ctx;\n    device_ctx.device_name            = GGML_WEBGPU_NAME;\n    device_ctx.device_desc            = GGML_WEBGPU_NAME;\n    device_ctx.webgpu_global_ctx      = reg_ctx->webgpu_global_ctx;\n    // See GGML Backend Device Interface section\n    static ggml_backend_device device = {\n        /* .iface   = */ ggml_backend_webgpu_device_i,\n        /* .reg     = */ reg,\n        /* .context = */ &device_ctx,\n    };\n\n    WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);\n    return &device;\n}\n\nstatic const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {\n    /* .get_name         = */ ggml_backend_webgpu_reg_get_name,\n    /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_webgpu_reg_get_device,\n    /* .get_proc_address = */ NULL,\n};\n\n/* End GGML Backend Registration Interface */\n\nggml_backend_reg_t ggml_backend_webgpu_reg() {\n    
WEBGPU_LOG_DEBUG(\"ggml_backend_webgpu_reg()\");\n\n    static ggml_backend_webgpu_reg_context ctx;\n    ctx.name         = GGML_WEBGPU_NAME;\n    ctx.device_count = 1;\n\n    wgpu::InstanceDescriptor               instance_descriptor{};\n    std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };\n    instance_descriptor.requiredFeatures                     = instance_features.data();\n    instance_descriptor.requiredFeatureCount                 = instance_features.size();\n\n#ifndef __EMSCRIPTEN__\n    const char * const          instanceEnabledToggles[] = { \"allow_unsafe_apis\" };\n    wgpu::DawnTogglesDescriptor instanceTogglesDesc;\n    instanceTogglesDesc.enabledToggles     = instanceEnabledToggles;\n    instanceTogglesDesc.enabledToggleCount = 1;\n    instance_descriptor.nextInChain        = &instanceTogglesDesc;\n#endif\n\n    wgpu::Instance inst             = wgpu::CreateInstance(&instance_descriptor);\n    ctx.webgpu_global_ctx           = webgpu_global_context(new webgpu_global_context_struct());\n    ctx.webgpu_global_ctx->instance = std::move(inst);\n\n#ifdef __EMSCRIPTEN__\n    if (ctx.webgpu_global_ctx->instance == nullptr) {\n        GGML_LOG_ERROR(\"ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\\n\");\n        return nullptr;\n    }\n#endif\n    GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);\n\n    static ggml_backend_reg reg = {\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_webgpu_reg_i,\n        /* .context     = */ &ctx,\n    };\n    return &reg;\n}\n\nggml_backend_t ggml_backend_webgpu_init(void) {\n    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);\n\n    return ggml_backend_webgpu_backend_init(dev, nullptr);\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)\n"
  },
  {
    "path": "src/ggml-webgpu/pre_wgsl.hpp",
    "content": "#ifndef PRE_WGSL_HPP\n#define PRE_WGSL_HPP\n\n// pre_wgsl: a small C-preprocessor-style front end for WGSL sources.\n// Recognized directives: #include, #define, #undef, #ifdef, #ifndef, #if,\n// #elif, #else and #endif. Macros are object-like only (#define NAME value);\n// on ordinary lines of active regions every identifier-like token that names a\n// macro is replaced by its recursively expanded value.\n\n#include <cctype>\n#include <fstream>\n#include <sstream>\n#include <stdexcept>\n#include <string>\n#include <string_view>\n#include <unordered_map>\n#include <unordered_set>\n#include <utility>\n#include <vector>\n\nnamespace pre_wgsl {\n\n//==============================================================\n// Options\n//==============================================================\n// Settings shared by every preprocessing call made through a Preprocessor.\nstruct Options {\n    // Directory that #include file names are resolved against.\n    std::string              include_path = \".\";\n    // Predefined macros, each written as NAME or NAME=VALUE.\n    std::vector<std::string> macros;\n};\n\n//==============================================================\n// Utility: trim\n//==============================================================\n// Returns s with leading and trailing whitespace removed.\nstatic std::string trim(const std::string & s) {\n    size_t a = 0;\n    while (a < s.size() && std::isspace((unsigned char) s[a])) {\n        a++;\n    }\n    size_t b = s.size();\n    while (b > a && std::isspace((unsigned char) s[b - 1])) {\n        b--;\n    }\n    return s.substr(a, b - a);\n}\n\n// Reads the rest of the current line from is and returns it trimmed.\nstatic std::string trim_value(std::istream & is) {\n    std::string str;\n    std::getline(is, str);\n    return trim(str);\n}\n\n// True for characters that can be part of a macro-name token. Digits are\n// included, so a numeric literal such as 1e30 is scanned as one token.\nstatic bool isIdentChar(char c) {\n    return std::isalnum(static_cast<unsigned char>(c)) || c == '_';\n}\n\nstatic std::string expandMacrosRecursiveInternal(const std::string & line,\n                                                 const std::unordered_map<std::string, std::string> & macros,\n                                                 std::unordered_set<std::string> & visiting);\n\n// Returns the fully expanded value of macro name, or name itself when it is\n// not defined. visiting holds the macros currently being expanded; meeting one\n// of them again means the definition is self-referential.\n// Throws std::runtime_error on recursive macros.\nstatic std::string expandMacroValue(const std::string & name,\n                                    const std::unordered_map<std::string, std::string> & macros,\n                                    std::unordered_set<std::string> & visiting) {\n    if (visiting.count(name)) {\n        throw std::runtime_error(\"Recursive macro: \" + name);\n    }\n    visiting.insert(name);\n\n    auto it = macros.find(name);\n    if (it == macros.end()) {\n        visiting.erase(name);\n        return name;\n    }\n\n    const std::string & value = it->second;\n    if (value.empty()) {\n        visiting.erase(name);\n        return \"\";\n    }\n\n    std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting);\n    visiting.erase(name);\n    return expanded;\n}\n\n// Copies line, replacing every token that names a macro with its expanded\n// value; all other characters pass through unchanged.\nstatic std::string expandMacrosRecursiveInternal(const std::string & line,\n                                                 const std::unordered_map<std::string, std::string> & macros,\n                                                 std::unordered_set<std::string> & visiting) {\n    std::string result;\n    result.reserve(line.size());\n\n    size_t i = 0;\n    while (i < line.size()) {\n        if (isIdentChar(line[i])) {\n            size_t start = i;\n            while (i < line.size() && isIdentChar(line[i])) {\n                i++;\n            }\n            std::string token = line.substr(start, i - start);\n\n            auto it = macros.find(token);\n            if (it != macros.end()) {\n                result += expandMacroValue(token, macros, visiting);\n            } else {\n                result += token;\n            }\n        } else {\n            result += line[i];\n            i++;\n        }\n    }\n\n    return result;\n}\n\n// Expands all macros in a single source line.\nstatic std::string expandMacrosRecursive(const std::string & line,\n                                         const std::unordered_map<std::string, std::string> & macros) {\n    std::unordered_set<std::string> visiting;\n    return expandMacrosRecursiveInternal(line, macros, visiting);\n}\n\n//==============================================================\n// Tokenizer for expressions in #if/#elif\n//==============================================================\n// Splits an #if/#elif expression into decimal numbers, identifiers,\n// parentheses and C-style operators. An unrecognized character is reported as\n// END, which silently ends the expression at that point.\nclass ExprLexer {\n  public:\n    enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN };\n\n    struct Tok {\n        Kind        kind;\n        std::string text;\n    };\n\n    explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {}\n\n    // Returns the next token, or END once the input is exhausted.\n    Tok next() {\n        skipWS();\n        if (pos >= src.size()) {\n            return { END, \"\" };\n        }\n\n        char c = src[pos];\n\n        // number\n        if (std::isdigit((unsigned char) c)) {\n            size_t start = pos;\n            while (pos < src.size() && std::isdigit((unsigned char) src[pos])) {\n                pos++;\n            }\n            return { NUMBER, std::string(src.substr(start, pos - start)) };\n        }\n\n        // identifier\n        if (std::isalpha((unsigned char) c) || c == '_') {\n            size_t start = pos;\n            while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) {\n                pos++;\n            }\n            return { IDENT, std::string(src.substr(start, pos - start)) };\n        }\n\n        if (c == '(') {\n            pos++;\n            return { LPAREN, \"(\" };\n        }\n        if (c == ')') {\n            pos++;\n            return { RPAREN, \")\" };\n        }\n\n        // two-char operators are checked first so that e.g. <= is not split\n        static const char * two_ops[] = { \"==\", \"!=\", \"<=\", \">=\", \"&&\", \"||\", \"<<\", \">>\" };\n        for (auto op : two_ops) {\n            if (src.substr(pos, 2) == op) {\n                pos += 2;\n                return { OP, std::string(op) };\n            }\n        }\n\n        // single-char operators\n        if (std::string(\"+-*/%<>!\").find(c) != std::string::npos) {\n            pos++;\n            return { OP, std::string(1, c) };\n        }\n\n        // unexpected character: stop the token stream\n        pos++;\n        return { END, \"\" };\n    }\n\n  private:\n    std::string_view src;\n    size_t           pos;\n\n    void skipWS() {\n        while (pos < src.size() && std::isspace((unsigned char) src[pos])) {\n            pos++;\n        }\n    }\n};\n\n//==============================================================\n// Expression Parser (recursive descent)\n//==============================================================\n// Evaluates an #if/#elif expression to an int. Precedence, lowest to highest:\n// ||, &&, == !=, < > <= >=, << >>, + -, * / %, unary ! - +.\n// Both operands of || and && are always evaluated. An identifier is 0 when\n// undefined, 1 when defined with an empty value, and otherwise its value\n// evaluated as an expression. Division or modulo by zero yields 0. Tokens left\n// over after a complete expression are ignored.\nclass ExprParser {\n  public:\n    ExprParser(std::string_view expr,\n               const std::unordered_map<std::string, std::string> & macros,\n               std::unordered_set<std::string> & visiting) :\n        lex(expr),\n        macros(macros),\n        visiting(visiting) {\n        advance();\n    }\n\n    int parse() { return parseLogicalOr(); }\n\n  private:\n    ExprLexer                                            lex;\n    ExprLexer::Tok                                       tok;\n    const std::unordered_map<std::string, std::string> & macros;\n    std::unordered_set<std::string> &                    visiting;\n\n    void advance() { tok = lex.next(); }\n\n    // Consumes the current token if it is the operator s.\n    bool acceptOp(const std::string & s) {\n        if (tok.kind == ExprLexer::OP && tok.text == s) {\n            advance();\n            return true;\n        }\n        return false;\n    }\n\n    // Consumes the current token if it has kind k.\n    bool acceptKind(ExprLexer::Kind k) {\n        if (tok.kind == k) {\n            advance();\n            return true;\n        }\n        return false;\n    }\n\n    int parseLogicalOr() {\n        int v = parseLogicalAnd();\n        while (acceptOp(\"||\")) {\n            int rhs = parseLogicalAnd();\n            v = (v || rhs);\n        }\n        return v;\n    }\n\n    int parseLogicalAnd() {\n        int v = parseEquality();\n        while (acceptOp(\"&&\")) {\n            int rhs = parseEquality();\n            v = (v && rhs);\n        }\n        return v;\n    }\n\n    int parseEquality() {\n        int v = parseRelational();\n        for (;;) {\n            if (acceptOp(\"==\")) {\n                int rhs = parseRelational();\n                v = (v == rhs);\n            } else if (acceptOp(\"!=\")) {\n                int rhs = parseRelational();\n                v = (v != rhs);\n            } else {\n                break;\n            }\n        }\n        return v;\n    }\n\n    int parseRelational() {\n        int v = parseShift();\n        for (;;) {\n            if (acceptOp(\"<\")) {\n                int rhs = parseShift();\n                v = (v < rhs);\n            } else if (acceptOp(\">\")) {\n                int rhs = parseShift();\n                v = (v > rhs);\n            } else if (acceptOp(\"<=\")) {\n                int rhs = parseShift();\n                v = (v <= rhs);\n            } else if (acceptOp(\">=\")) {\n                int rhs = parseShift();\n                v = (v >= rhs);\n            } else {\n                break;\n            }\n        }\n        return v;\n    }\n\n    int parseShift() {\n        int v = parseAdd();\n        for (;;) {\n            if (acceptOp(\"<<\")) {\n                int rhs = parseAdd();\n                v = (v << rhs);\n            } else if (acceptOp(\">>\")) {\n                int rhs = parseAdd();\n                v = (v >> rhs);\n            } else {\n                break;\n            }\n        }\n        return v;\n    }\n\n    int parseAdd() {\n        int v = parseMult();\n        for (;;) {\n            if (acceptOp(\"+\")) {\n                int rhs = parseMult();\n                v = (v + rhs);\n            } else if (acceptOp(\"-\")) {\n                int rhs = parseMult();\n                v = (v - rhs);\n            } else {\n                break;\n            }\n        }\n        return v;\n    }\n\n    int parseMult() {\n        int v = parseUnary();\n        for (;;) {\n            if (acceptOp(\"*\")) {\n                int rhs = parseUnary();\n                v = (v * rhs);\n            } else if (acceptOp(\"/\")) {\n                int rhs = parseUnary();\n                v = (rhs == 0 ? 0 : v / rhs);\n            } else if (acceptOp(\"%\")) {\n                int rhs = parseUnary();\n                v = (rhs == 0 ? 0 : v % rhs);\n            } else {\n                break;\n            }\n        }\n        return v;\n    }\n\n    int parseUnary() {\n        if (acceptOp(\"!\")) {\n            return !parseUnary();\n        }\n        if (acceptOp(\"-\")) {\n            return -parseUnary();\n        }\n        if (acceptOp(\"+\")) {\n            return +parseUnary();\n        }\n        return parsePrimary();\n    }\n\n    int parsePrimary() {\n        // '(' expr ')'\n        if (acceptKind(ExprLexer::LPAREN)) {\n            int v = parse();\n            if (!acceptKind(ExprLexer::RPAREN)) {\n                throw std::runtime_error(\"missing ')'\");\n            }\n            return v;\n        }\n\n        // number\n        if (tok.kind == ExprLexer::NUMBER) {\n            int v = std::stoi(tok.text);\n            advance();\n            return v;\n        }\n\n        // defined(identifier) or defined identifier\n        if (tok.kind == ExprLexer::IDENT && tok.text == \"defined\") {\n            advance();\n            if (acceptKind(ExprLexer::LPAREN)) {\n                if (tok.kind != ExprLexer::IDENT) {\n                    throw std::runtime_error(\"expected identifier in defined()\");\n                }\n                std::string name = tok.text;\n                advance();\n                if (!acceptKind(ExprLexer::RPAREN)) {\n                    throw std::runtime_error(\"missing ) in defined()\");\n                }\n                return macros.count(name) ? 1 : 0;\n            }\n            // defined NAME\n            if (tok.kind != ExprLexer::IDENT) {\n                throw std::runtime_error(\"expected identifier in defined NAME\");\n            }\n            std::string name = tok.text;\n            advance();\n            return macros.count(name) ? 1 : 0;\n        }\n\n        // identifier -> treat as integer, if defined use its value else 0\n        if (tok.kind == ExprLexer::IDENT) {\n            std::string name = tok.text;\n            advance();\n            auto it = macros.find(name);\n            if (it == macros.end()) {\n                return 0;\n            }\n            if (it->second.empty()) {\n                return 1;\n            }\n            return evalMacroExpression(name, it->second);\n        }\n\n        // unexpected token: evaluates to 0 without being consumed\n        return 0;\n    }\n\n    // Evaluates the value of macro name as an expression, rejecting recursion.\n    int evalMacroExpression(const std::string & name, const std::string & value) {\n        if (visiting.count(name)) {\n            throw std::runtime_error(\"Recursive macro: \" + name);\n        }\n\n        visiting.insert(name);\n        ExprParser ep(value, macros, visiting);\n        int        v = ep.parse();\n        visiting.erase(name);\n        return v;\n    }\n};\n\n//==============================================================\n// Preprocessor\n//==============================================================\n// preprocess()/preprocess_file() execute every directive and expand macros;\n// preprocess_includes()/preprocess_includes_file() only splice in #include\n// files and copy every other line, directives included, unchanged.\n// Malformed input is reported with std::runtime_error (std::stoi may also\n// throw std::out_of_range for oversized numbers in #if expressions).\nclass Preprocessor {\n  public:\n    // opts.macros become predefined macros that #define/#undef in the\n    // processed source cannot change.\n    explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) {\n        // Treat empty include path as current directory\n        if (opts_.include_path.empty()) {\n            opts_.include_path = \".\";\n        }\n        parseMacroDefinitions(opts_.macros);\n    }\n\n    // Fully preprocesses the file at filename (opened as given, not via\n    // include_path). additional_macros (NAME or NAME=VALUE) are predefined for\n    // this call only and override option macros of the same name.\n    std::string preprocess_file(const std::string & filename, const std::vector<std::string> & additional_macros = {}) {\n        std::unordered_map<std::string, std::string> macros;\n        std::unordered_set<std::string>              predefined;\n        std::unordered_set<std::string>              include_stack;\n        buildMacros(additional_macros, macros, predefined);\n\n        return processFile(filename, macros, predefined, include_stack, DirectiveMode::All);\n    }\n\n    // Same as preprocess_file, for source text held in memory.\n    std::string preprocess(const std::string & contents, const std::vector<std::string> & additional_macros = {}) {\n        std::unordered_map<std::string, std::string> macros;\n        std::unordered_set<std::string>              predefined;\n        std::unordered_set<std::string>              include_stack;\n        buildMacros(additional_macros, macros, predefined);\n\n        return processString(contents, macros, predefined, include_stack, DirectiveMode::All);\n    }\n\n    // Resolves only the #include directives of the file at filename.\n    std::string preprocess_includes_file(const std::string & filename) {\n        std::unordered_map<std::string, std::string> macros;\n        std::unordered_set<std::string>              predefined;\n        std::unordered_set<std::string>              include_stack;\n        return processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly);\n    }\n\n    // Resolves only the #include directives of in-memory source text.\n    std::string preprocess_includes(const std::string & contents) {\n        std::unordered_map<std::string, std::string> macros;\n        std::unordered_set<std::string>              predefined;\n        std::unordered_set<std::string>              include_stack;\n        return processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly);\n    }\n\n  private:\n    Options                                      opts_;\n    std::unordered_map<std::string, std::string> global_macros;\n\n    enum class DirectiveMode { All, IncludesOnly };\n\n    // One level of the #if/#ifdef/#ifndef nesting stack.\n    struct Cond {\n        bool parent_active;  // the enclosing region is being emitted\n        bool active;         // the current branch is being emitted\n        bool taken;          // an earlier or the current branch was selected\n        bool seen_else;      // #else has already appeared in this group\n    };\n\n    //----------------------------------------------------------\n    // Parse macro definitions into global_macros\n    //----------------------------------------------------------\n    void parseMacroDefinitions(const std::vector<std::string> & macro_defs) {\n        for (const auto & def : macro_defs) {\n            size_t eq_pos = def.find('=');\n            if (eq_pos != std::string::npos) {\n                // Format: NAME=VALUE\n                std::string name    = trim(def.substr(0, eq_pos));\n                std::string value   = trim(def.substr(eq_pos + 1));\n                global_macros[name] = value;\n            } else {\n                // Format: NAME\n                std::string name    = trim(def);\n                global_macros[name] = \"\";\n            }\n        }\n    }\n\n    //----------------------------------------------------------\n    // Build combined macro map and predefined set for a preprocessing operation\n    //----------------------------------------------------------\n    void buildMacros(const std::vector<std::string> &               additional_macros,\n                     std::unordered_map<std::string, std::string> & macros,\n                     std::unordered_set<std::string> &              predefined) {\n        macros = global_macros;\n        predefined.clear();\n\n        for (const auto & kv : global_macros) {\n            predefined.insert(kv.first);\n        }\n\n        for (const auto & def : additional_macros) {\n            size_t      eq_pos = def.find('=');\n            std::string name, value;\n            if (eq_pos != std::string::npos) {\n                name  = trim(def.substr(0, eq_pos));\n                value = trim(def.substr(eq_pos + 1));\n            } else {\n                name  = trim(def);\n                value = \"\";\n            }\n\n            // Add to macros map (will override global if same name)\n            macros[name] = value;\n            predefined.insert(name);\n        }\n    }\n\n    //----------------------------------------------------------\n    // Helpers\n    //----------------------------------------------------------\n    // Reads the whole file; throws std::runtime_error if it cannot be opened.\n    std::string loadFile(const std::string & fname) {\n        std::ifstream f(fname);\n        if (!f.is_open()) {\n            throw std::runtime_error(\"Could not open file: \" + fname);\n        }\n        std::stringstream ss;\n        ss << f.rdbuf();\n        return ss.str();\n    }\n\n    // True when lines at the current nesting level are emitted.\n    bool condActive(const std::vector<Cond> & cond) const {\n        if (cond.empty()) {\n            return true;\n        }\n        return cond.back().active;\n    }\n\n    //----------------------------------------------------------\n    // Process a file\n    //----------------------------------------------------------\n    // include_stack holds the files currently being processed and is used to\n    // reject include cycles.\n    std::string processFile(const std::string &                            name,\n                            std::unordered_map<std::string, std::string> & macros,\n                            const std::unordered_set<std::string> &        predefined_macros,\n                            std::unordered_set<std::string> &              include_stack,\n                            DirectiveMode                                  mode) {\n        if (include_stack.count(name)) {\n            throw std::runtime_error(\"Recursive include: \" + name);\n        }\n\n        include_stack.insert(name);\n        std::string shader_code = loadFile(name);\n        std::string out         = processString(shader_code, macros, predefined_macros, include_stack, mode);\n        include_stack.erase(name);\n        return out;\n    }\n\n    // Resolves fname relative to include_path and processes it.\n    std::string processIncludeFile(const std::string &                            fname,\n                                   std::unordered_map<std::string, std::string> & macros,\n                                   const std::unordered_set<std::string> &        predefined_macros,\n                                   std::unordered_set<std::string> &              include_stack,\n                                   DirectiveMode                                  mode) {\n        std::string full_path = opts_.include_path + \"/\" + fname;\n        return processFile(full_path, macros, predefined_macros, include_stack, mode);\n    }\n\n    //----------------------------------------------------------\n    // Process text\n    //----------------------------------------------------------\n    // Each call keeps its own conditional stack, so in All mode every #if\n    // opened in a string (or included file) must be closed in that same string.\n    std::string processString(const std::string &                            shader_code,\n                              std::unordered_map<std::string, std::string> & macros,\n                              const std::unordered_set<std::string> &        predefined_macros,\n                              std::unordered_set<std::string> &              include_stack,\n                              DirectiveMode                                  mode) {\n        std::vector<Cond>  cond;  // Conditional stack for this shader\n        std::stringstream  out;\n        std::istringstream in(shader_code);\n        std::string        line;\n\n        while (std::getline(in, line)) {\n            std::string t = trim(line);\n\n            if (!t.empty() && t[0] == '#') {\n                bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);\n                if (mode == DirectiveMode::IncludesOnly && !handled) {\n                    out << line << \"\\n\";\n                }\n            } else {\n                if (mode == DirectiveMode::IncludesOnly) {\n                    out << line << \"\\n\";\n                } else if (condActive(cond)) {\n                    // Expand macros in the line before outputting\n                    std::string expanded = expandMacrosRecursive(line, macros);\n                    out << expanded << \"\\n\";\n                }\n            }\n        }\n\n        if (mode == DirectiveMode::All && !cond.empty()) {\n            throw std::runtime_error(\"Unclosed #if directive\");\n        }\n\n        return out.str();\n    }\n\n    //----------------------------------------------------------\n    // Directive handler\n    //----------------------------------------------------------\n    // Handles one directive line t (trimmed, starting with '#'). Returns true\n    // when the line was consumed and false when the caller must echo it\n    // verbatim, which happens only for non-include directives in IncludesOnly\n    // mode. Throws std::runtime_error for malformed or unknown directives.\n    bool handleDirective(const std::string &                            t,\n                         std::stringstream &                            out,\n                         std::unordered_map<std::string, std::string> & macros,\n                         const std::unordered_set<std::string> &        predefined_macros,\n                         std::vector<Cond> &                            cond,\n                         std::unordered_set<std::string> &              include_stack,\n                         DirectiveMode                                  mode) {\n        // split into tokens\n        std::string        body = t.substr(1);\n        std::istringstream iss(body);\n        std::string        cmd;\n        iss >> cmd;\n\n        if (cmd == \"include\") {\n            // IncludesOnly mode does not track conditionals, so every include is expanded there.\n            if (mode == DirectiveMode::All && !condActive(cond)) {\n                return true;\n            }\n            std::string file;\n            iss >> file;\n            if (file.size() >= 2 && file.front() == '\"' && file.back() == '\"') {\n                file = file.substr(1, file.size() - 2);\n            }\n            if (file.empty()) {\n                throw std::runtime_error(\"#include without a file name\");\n            }\n            out << processIncludeFile(file, macros, predefined_macros, include_stack, mode);\n            return true;\n        }\n\n        if (mode == DirectiveMode::IncludesOnly) {\n            return false;\n        }\n\n        if (cmd == \"define\") {\n            if (!condActive(cond)) {\n                return true;\n            }\n            std::string name;\n            iss >> name;\n            if (name.empty()) {\n                throw std::runtime_error(\"#define without a macro name\");\n            }\n            // Predefined macros (options and per-call macros) cannot be redefined\n            if (predefined_macros.count(name)) {\n                return true;\n            }\n            macros[name] = trim_value(iss);\n            return true;\n        }\n\n        if (cmd == \"undef\") {\n            if (!condActive(cond)) {\n                return true;\n            }\n            std::string name;\n            iss >> name;\n            if (name.empty()) {\n                throw std::runtime_error(\"#undef without a macro name\");\n            }\n            // Predefined macros (options and per-call macros) cannot be removed\n            if (predefined_macros.count(name)) {\n                return true;\n            }\n            macros.erase(name);\n            return true;\n        }\n\n        if (cmd == \"ifdef\") {\n            std::string name;\n            iss >> name;\n            bool p = condActive(cond);\n            bool v = macros.count(name) != 0;\n            cond.push_back({ p, p && v, p && v, false });\n            return true;\n        }\n\n        if (cmd == \"ifndef\") {\n            std::string name;\n            iss >> name;\n            bool p = condActive(cond);\n            bool v = macros.count(name) == 0;\n            cond.push_back({ p, p && v, p && v, false });\n            return true;\n        }\n\n        if (cmd == \"if\") {\n            std::string expr = trim_value(iss);\n            bool        p    = condActive(cond);\n            bool        v    = false;\n            // The expression is only parsed when the enclosing region is active.\n            if (p) {\n                std::unordered_set<std::string> visiting;\n                ExprParser                      ep(expr, macros, visiting);\n                v = ep.parse() != 0;\n            }\n            cond.push_back({ p, p && v, p && v, false });\n            return true;\n        }\n\n        if (cmd == \"elif\") {\n            std::string expr = trim_value(iss);\n\n            if (cond.empty()) {\n                throw std::runtime_error(\"#elif without #if\");\n            }\n\n            Cond & c = cond.back();\n            if (c.seen_else) {\n                throw std::runtime_error(\"#elif after #else\");\n            }\n            // Nothing to evaluate if the group is skipped or a branch was already taken.\n            if (!c.parent_active || c.taken) {\n                c.active = false;\n                return true;\n            }\n\n            std::unordered_set<std::string> visiting;\n            ExprParser                      ep(expr, macros, visiting);\n            c.active = ep.parse() != 0;\n            c.taken  = c.active;\n            return true;\n        }\n\n        if (cmd == \"else\") {\n            if (cond.empty()) {\n                throw std::runtime_error(\"#else without #if\");\n            }\n\n            Cond & c = cond.back();\n            if (c.seen_else) {\n                throw std::runtime_error(\"#else after #else\");\n            }\n            c.seen_else = true;\n            // The #else branch is emitted only if the group is active and no earlier branch was taken.\n            c.active = c.parent_active && !c.taken;\n            c.taken  = c.taken || c.active;\n            return true;\n        }\n\n        if (cmd == \"endif\") {\n            if (cond.empty()) {\n                throw std::runtime_error(\"#endif without #if\");\n            }\n            cond.pop_back();\n            return true;\n        }\n\n        // Unknown directive\n        throw std::runtime_error(\"Unknown directive: #\" + cmd);\n    }\n};\n\n}  // namespace pre_wgsl\n\n#endif  // PRE_WGSL_HPP\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/argmax.wgsl",
    "content": "@group(0) @binding(0)\n#ifdef VEC4\nvar<storage, read_write> src: array<vec4<f32>>;\n#define VEC_SIZE 4\n#else\nvar<storage, read_write> src: array<f32>;\n#define VEC_SIZE 1\n#endif\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<i32>;\n\nstruct Params {\n    offset_src: u32, // in elements\n    offset_dst: u32, // in elements\n    ne0: u32,\n};\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\n// Best candidate seen so far: a value and its column. A negative index means\n// that no element has been seen yet, so no sentinel value is needed and rows\n// whose values are all very negative (down to -inf) still yield a valid column.\nstruct Pair {\n    value: f32,\n    index: i32\n};\n\nvar<workgroup> shared_max: array<Pair, WG_SIZE>;\n\n// Returns true when candidate b should replace the current best a. Ties go to\n// the larger column, so the last occurrence of the maximum wins.\nfn better_pair(a: Pair, b: Pair) -> bool {\n    if (b.index < 0) {\n        return false;\n    }\n    if (a.index < 0) {\n        return true;\n    }\n    return b.value > a.value || (b.value == a.value && b.index > a.index);\n}\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(workgroup_id) wid: vec3<u32>,\n        @builtin(local_invocation_id) lid: vec3<u32>) {\n    let row_idx = params.offset_src + wid.x * params.ne0;\n    var local_pair = Pair(0.0, -1);\n#ifdef VEC4\n    for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) {\n        let vec_val = src[row_idx / VEC_SIZE + col];\n        for (var v = 0u; v < VEC_SIZE; v++) {\n            let cand = Pair(vec_val[v], i32(col * VEC_SIZE + v));\n            if (better_pair(local_pair, cand)) {\n                local_pair = cand;\n            }\n        }\n    }\n#else\n    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {\n        let cand = Pair(src[row_idx + col], i32(col));\n        if (better_pair(local_pair, cand)) {\n            local_pair = cand;\n        }\n    }\n#endif\n    shared_max[lid.x] = local_pair;\n    workgroupBarrier();\n    // Tree reduction over the per-thread candidates; threads that saw no\n    // column carry index -1 and never win.\n    var offset: u32 = WG_SIZE >> 1;\n    while (offset > 0) {\n        if (lid.x < offset) {\n            let other = shared_max[lid.x + offset];\n            if (better_pair(shared_max[lid.x], other)) {\n                shared_max[lid.x] = other;\n            }\n        }\n        workgroupBarrier();\n        offset >>= 1;\n    }\n    if (lid.x == 0u) {\n        dst[params.offset_dst + wid.x] = shared_max[0].index;\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/argsort.wgsl",
    "content": "@group(0) @binding(0)\nvar<storage, read_write> src: array<f32>;\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<i32>;\n\nstruct Params {\n    offset_src: u32, // in elements\n    offset_dst: u32, // in elements\n\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    // src/dst dimensions\n    src_ne0: u32,\n    ne1: u32,\n    ne2: u32,\n\n    ne0: u32,\n    top_k: u32,\n\n    npr: u32,   // tiles per row\n    nrows: u32\n};\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\nvar<workgroup> shmem_idx: array<u32, WG_SIZE>;\n\n#if ORDER == 0\n#define SWAP_COMPARE_UP >\n#else\n#define SWAP_COMPARE_UP <\n#endif\n\n// Returns true when element a must be placed after element b in the final\n// order. Padding slots (index >= src_ne0) always sort after real elements\n// regardless of their value, so real values beyond any finite sentinel (such\n// as +inf or -inf) can never be ordered behind padding.\nfn sorts_after(a_val: f32, a_valid: bool, b_val: f32, b_valid: bool) -> bool {\n    if (a_valid != b_valid) {\n        return b_valid;\n    }\n    return a_valid && (a_val SWAP_COMPARE_UP b_val);\n}\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(workgroup_id) wid: vec3<u32>,\n        @builtin(num_workgroups) num_wg: vec3<u32>,\n        @builtin(local_invocation_id) lid: vec3<u32>) {\n    let linear = wid.x + wid.y * num_wg.x;\n    // guard against overprovisioned workgroups\n    if (linear >= params.npr * params.nrows) {\n        return;\n    }\n    let tile = linear % params.npr;\n    var row = linear / params.npr;\n    let i3 = row / (params.ne2 * params.ne1);\n    row = row % (params.ne2 * params.ne1);\n    let i2 = row / params.ne1;\n    let i1 = row % params.ne1;\n\n    let row_base = params.offset_src +\n        i1 * params.stride_src1 +\n        i2 * params.stride_src2 +\n        i3 * params.stride_src3;\n\n    let tile_base = tile * WG_SIZE;\n    let idx = tile_base + lid.x;\n    // slots past the end of the row hold the padding index src_ne0\n    shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);\n    workgroupBarrier();\n\n    var k = 2u;\n    while (k <= WG_SIZE) {\n        var j = k >> 1;\n        while (j > 0) {\n            let ixj = lid.x ^ j;\n            if (ixj > lid.x) {\n                let dir_up = (lid.x & k) == 0;\n                let a_idx = shmem_idx[lid.x];\n                let b_idx = shmem_idx[ixj];\n                let a_valid = a_idx < params.src_ne0;\n                let b_valid = b_idx < params.src_ne0;\n                // padding slots load the first element of the row; the value is ignored\n                let a_val = src[row_base + select(0u, a_idx, a_valid)];\n                let b_val = src[row_base + select(0u, b_idx, b_valid)];\n                let should_swap = select(\n                    sorts_after(b_val, b_valid, a_val, a_valid),\n                    sorts_after(a_val, a_valid, b_val, b_valid),\n                    dir_up);\n                if (should_swap) {\n                    shmem_idx[lid.x] = b_idx;\n                    shmem_idx[ixj] = a_idx;\n                }\n            }\n            workgroupBarrier();\n            j >>= 1;\n        }\n        k <<= 1;\n    }\n\n    let out_idx = tile * params.top_k + lid.x;\n    if (out_idx < params.ne0 && lid.x < params.top_k) {\n        let row_dst = params.offset_dst +\n            i1 * params.stride_dst1 +\n            i2 * params.stride_dst2 +\n            i3 * params.stride_dst3;\n        dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl",
    "content": "@group(0) @binding(0)\nvar<storage, read_write> src: array<f32>;\n\n@group(0) @binding(1)\nvar<storage, read_write> idx_in: array<i32>;\n\n@group(0) @binding(2)\nvar<storage, read_write> idx_out: array<i32>;\n\nstruct Params {\n    offset_src: u32, // in elements\n    offset_in: u32,  // in elements\n    offset_out: u32, // in elements\n\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    stride_idx1: u32,\n    stride_idx2: u32,\n    stride_idx3: u32,\n\n    stride_out1: u32,\n    stride_out2: u32,\n    stride_out3: u32,\n\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n\n    top_k: u32,\n\n    len: u32,\n    nm: u32,\n    nrows: u32\n};\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n\nfn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {\n    let a_val = src[row_base + u32(a_idx)];\n    let b_val = src[row_base + u32(b_idx)];\n#if ORDER == 0\n    return a_val <= b_val;\n#else\n    return a_val >= b_val;\n#endif\n}\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(workgroup_id) wid: vec3<u32>,\n        @builtin(num_workgroups) num_wg: vec3<u32>,\n        @builtin(local_invocation_id) lid: vec3<u32>) {\n    let linear = wid.x + wid.y * num_wg.x;\n    // guard against overprovisioned workgroups\n    if (linear >= params.nm * params.nrows) {\n        return;\n    }\n\n    let start = (linear % params.nm) * params.len * 2;\n    let len0 = min(params.len, params.ne0 - start);\n    let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));\n    let len1 = min(params.len, rem1);\n    let total = len0 + len1;\n    let chunk = (total + WG_SIZE - 1u) / WG_SIZE;\n    let k0 = lid.x * chunk;\n    let k1 = min(min(k0 + chunk, total), params.top_k);\n    // guard against overprovisioned threads\n    if (k0 >= params.top_k || k0 >= total) {\n        return;\n    }\n\n    var row = linear / params.nm;\n    let i3 = row / (params.ne2 * params.ne1);\n    row = row % (params.ne2 * 
params.ne1);\n    let i2 = row / params.ne1;\n    let i1 = row % params.ne1;\n\n    let row_src = params.offset_src +\n        i1 * params.stride_src1 +\n        i2 * params.stride_src2 +\n        i3 * params.stride_src3;\n\n    let row_in = params.offset_in +\n        i1 * params.stride_idx1 +\n        i2 * params.stride_idx2 +\n        i3 * params.stride_idx3;\n\n    let row_out = params.offset_out +\n        i1 * params.stride_out1 +\n        i2 * params.stride_out2 +\n        i3 * params.stride_out3;\n\n\n    var low: u32 = select(0, k0 - len1, k0 > len1);\n    var high: u32 = min(k0, len0);\n\n    while (low < high) {\n        let mid = (low + high) >> 1;\n        let idx0 = idx_in[row_in + start + mid];\n        let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];\n        if (take_left(idx0, idx1, row_src)) {\n            low = mid + 1;\n        } else {\n            high = mid;\n        }\n    }\n\n    var i = low;\n    var j = k0 - i;\n    var k = k0;\n    while (k < k1) {\n        var take_l = false;\n        if (i >= len0) {\n            take_l = false;\n        } else if (j >= len1) {\n            take_l = true;\n        } else {\n            let idx0 = idx_in[row_in + start + i];\n            let idx1 = idx_in[row_in + start + params.len + j];\n            take_l = take_left(idx0, idx1, row_src);\n        }\n\n        let out_idx = select(\n            idx_in[row_in + start + params.len + j],\n            idx_in[row_in + start + i],\n            take_l);\n        idx_out[row_out + start + k] = out_idx;\n        i = select(i, i + 1, take_l);\n        j = select(j + 1, j, take_l);\n        k += 1;\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/binary.wgsl",
    "content": "enable f16;\n\nstruct Params {\n    ne: u32,\n\n    // offsets in elements\n    offset_src0: u32,\n    offset_src1: u32,\n    offset_dst: u32,\n    offset_merged_src0: u32,\n    offset_merged_src1: u32,\n\n    stride_src0_0: u32,\n    stride_src0_1: u32,\n    stride_src0_2: u32,\n    stride_src0_3: u32,\n\n    stride_src1_0: u32,\n    stride_src1_1: u32,\n    stride_src1_2: u32,\n    stride_src1_3: u32,\n\n    a_ne0: u32,\n    a_ne1: u32,\n    a_ne2: u32,\n\n    b_ne0: u32,\n    b_ne1: u32,\n    b_ne2: u32,\n    b_ne3: u32,\n};\n\nfn src0_index(_i: u32) -> u32 {\n    var i = _i;\n    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);\n    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);\n    let a_i2 = i / (params.a_ne1 * params.a_ne0);\n    i = i % (params.a_ne1 * params.a_ne0);\n    let a_i1 = i / params.a_ne0;\n    let a_i0 = i % params.a_ne0;\n\n    return a_i0 * params.stride_src0_0 +\n           a_i1 * params.stride_src0_1 +\n           a_i2 * params.stride_src0_2 +\n           a_i3 * params.stride_src0_3;\n}\n\nfn src1_index(_i: u32) -> u32 {\n    var i = _i;\n    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);\n    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);\n    let a_i2 = i / (params.a_ne1 * params.a_ne0);\n    i = i % (params.a_ne1 * params.a_ne0);\n    let a_i1 = i / params.a_ne0;\n    let a_i0 = i % params.a_ne0;\n\n    // handle repetition of b\n    // index loops back to the beginning and repeats after elements are exhausted = modulo\n    let b_i0 = a_i0 % params.b_ne0;\n    let b_i1 = a_i1 % params.b_ne1;\n    let b_i2 = a_i2 % params.b_ne2;\n    let b_i3 = a_i3 % params.b_ne3;\n\n    // compute index for position in b's flat array\n    return b_i0 * params.stride_src1_0 +\n           b_i1 * params.stride_src1_1 +\n           b_i2 * params.stride_src1_2 +\n           b_i3 * params.stride_src1_3;\n}\n\n#ifdef TYPE_F32\n#define DataType f32\n#endif\n#ifdef TYPE_F16\n#define DataType 
f16\n#endif\n\n#ifdef SRC_OVERLAP\n@group(0) @binding(0)\nvar<storage, read_write> merged_src: array<DataType>;\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<DataType>;\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n#else\n@group(0) @binding(0)\nvar<storage, read_write> src0: array<DataType>;\n\n@group(0) @binding(1)\nvar<storage, read_write> src1 : array<DataType>;\n#if defined(INPLACE) || defined(OVERLAP)\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\n#else\n@group(0) @binding(2)\nvar<storage, read_write> dst: array<DataType>;\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n#endif\n#endif\n\nfn op(a: DataType, b: DataType) -> DataType {\n#ifdef OP_ADD\n    return a + b;\n#elif defined(OP_SUB)\n    return a - b;\n#elif defined(OP_MUL)\n    return a * b;\n#elif defined(OP_DIV)\n    return a / b;\n#endif\n}\n\nfn update(dst_i: u32, src0_i: u32, src1_i: u32) {\n#ifdef SRC_OVERLAP\n    let result = op(merged_src[src0_i], merged_src[src1_i]);\n#else\n    let result = op(src0[src0_i], src1[src1_i]);\n#endif\n\n#ifdef INPLACE\n    src0[src0_i] = result;\n#elif defined(OVERLAP)\n    src1[src1_i] = result;\n#else\n    dst[dst_i] = result;\n#endif\n}\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if (gid.x < params.ne) {\n        let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);\n        let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);\n        update(params.offset_dst + gid.x, src0_i, src1_i);\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/common_decls.tmpl",
    "content": "#ifdef BYTE_HELPERS\nfn get_byte(value: u32, index: u32) -> u32 {\n    return (value >> (index * 8)) & 0xFF;\n}\n\nfn get_byte_i32(value: u32, index: u32) -> i32 {\n    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;\n}\n#endif\n\n#ifdef Q4_0_T\nstruct q4_0 {\n    d: f16,\n    qs: array<f16, 8>\n};\n#endif\n\n#ifdef Q4_1_T\nstruct q4_1 {\n    d: f16,\n    m: f16,\n    qs: array<u32, 4>\n};\n#endif\n\n#ifdef Q5_0_T\nstruct q5_0 {\n    d: f16,\n    qh: array<f16, 2>,\n    qs: array<f16, 8>\n};\n#endif\n\n#ifdef Q5_1_T\nstruct q5_1 {\n    d: f16,\n    m: f16,\n    qh: u32,\n    qs: array<u32, 4>\n};\n#endif\n\n#ifdef Q8_0_T\nstruct q8_0 {\n    d: f16,\n    qs: array<f16, 16>\n};\n#endif\n\n#ifdef Q8_1_T\nstruct q8_1 {\n    d: f16,\n    m: f16,\n    qs: array<u32, 8>\n};\n#endif\n\n#ifdef Q2_K_T\nstruct q2_K {\n    scales: array<u32, 4>,\n    qs: array<u32, 16>,\n    d: f16,\n    dmin: f16\n};\n#endif\n\n#ifdef Q3_K_T\nstruct q3_K {\n    hmask: array<f16, 16>,\n    qs: array<f16, 32>,\n    scales: array<f16, 6>,\n    d: f16\n};\n#endif\n\n#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)\nfn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {\n    if (is < 4) {\n        let sc_byte = get_byte(scales[is / 4], is % 4);\n        let min_byte = get_byte(scales[(is + 4) / 4], is % 4);\n        return vec2(f32(sc_byte & 63), f32(min_byte & 63));\n    } else {\n        let sc_min_lo = get_byte(scales[(is + 4) / 4], (is + 4) % 4);\n        let sc_hi = get_byte(scales[(is - 4) / 4], (is - 4) % 4);\n        let min_hi = get_byte(scales[is / 4], is % 4);\n        let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4);\n        let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4);\n        return vec2(f32(sc), f32(m));\n    }\n}\n#endif\n#ifdef Q4_K_T\nstruct q4_K {\n    d: f16,\n    dmin: f16,\n    scales: array<u32, 3>,\n    qs: array<u32, 32>\n};\n#endif\n\n#ifdef Q5_K_T\nstruct q5_K {\n    d: f16,\n    dmin: f16,\n    scales: array<u32, 
3>,\n    qh: array<u32, 8>,\n    qs: array<u32, 32>\n};\n#endif\n\n#ifdef Q6_K_T\nstruct q6_K {\n    ql: array<f16, 64>,\n    qh: array<f16, 32>,\n    scales: array<f16, 8>,\n    d: f16\n};\n#endif\n\n#ifdef IQ2_XXS_T\nstruct iq2_xxs {\n    d: f16,\n    qs: array<f16, 32>\n};\n#endif\n\n#ifdef IQ2_XS_T\nstruct iq2_xs {\n    d: f16,\n    qs: array<f16, 32>,\n    scales: array<f16, 4>\n};\n#endif\n\n#ifdef IQ2_S_T\nstruct iq2_s {\n    d: f16,\n    qs: array<f16, 32>,\n    qh: array<f16, 4>,\n    scales: array<f16, 4>\n};\n#endif\n\n#ifdef IQ3_XXS_T\nstruct iq3_xxs {\n    d: f16,\n    qs: array<f16, 48>\n};\n#endif\n\n#ifdef IQ3_S_T\nstruct iq3_s {\n    d: f16,\n    qs: array<f16, 32>,\n    qh: array<f16, 4>,\n    signs: array<f16, 16>,\n    scales: array<f16, 2>\n};\n#endif\n\n#ifdef IQ1_S_T\nstruct iq1_s {\n    d: f16,\n    qs: array<f16, 16>,\n    qh: array<f16, 8>\n};\n#endif\n\n#ifdef IQ1_M_T\nstruct iq1_m {\n    qs: array<u32, 8>,\n    qh: array<u32, 4>,\n    scales: array<u32, 2>\n};\n#endif\n\n#ifdef IQ4_NL_T\nstruct iq4_nl {\n    d: f16,\n    qs: array<f16, 8>,\n};\n#endif\n\n#ifdef IQ4_XS_T\nstruct iq4_xs {\n    d: f16,\n    scales_h: f16,\n    scales_l: u32,\n    qs: array<u32, 32>\n};\n#endif\n\n#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)\nconst kmask_iq2xs : array<u32, 2> = array<u32, 2>(\n    0x08040201u, // 1, 2, 4, 8\n    0x80402010u  // 16, 32, 64, 128\n);\n\nconst ksigns_iq2xs: array<u32, 32> = array<u32, 32>(\n    0x03828100,0x87060584,0x8b0a0988,0x0f8e8d0c,\n    0x93121190,0x17969514,0x1b9a9918,0x9f1e1d9c,\n    0xa32221a0,0x27a6a524,0x2baaa928,0xaf2e2dac,\n    0x33b2b130,0xb73635b4,0xbb3a39b8,0x3fbebd3c,\n    0xc34241c0,0x47c6c544,0x4bcac948,0xcf4e4dcc,\n    0x53d2d150,0xd75655d4,0xdb5a59d8,0x5fdedd5c,\n    0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,\n    0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc\n);\n#endif\n\n#ifdef IQ2_XXS_GRID\nconst iq2xxs_grid = 
array<u32, 512>(\n    0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,\n    0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,\n    0x082b082b, 0x08080808, 0x082b2b08, 0x08080808, 0x082b2b2b, 0x08080808, 0x19080819, 0x08080808,\n    0x19081908, 0x08080808, 0x19190808, 0x08080808, 0x19192b08, 0x08080808, 0x192b0819, 0x08080808,\n    0x192b1908, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b082b2b, 0x08080808,\n    0x2b2b082b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819, 0x08190808, 0x08080819,\n    0x08191919, 0x08080819, 0x19080808, 0x08080819, 0x2b081908, 0x08080819, 0x2b192b08, 0x08080819,\n    0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x082b082b, 0x0808082b, 0x2b08082b, 0x0808082b,\n    0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x08190808, 0x08081908, 0x082b0819, 0x08081908,\n    0x082b1908, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19082b08, 0x08081908,\n    0x192b0808, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908,\n    0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919, 0x08082b08, 0x08081919,\n    0x082b0808, 0x08081919, 0x1908192b, 0x08081919, 0x192b2b19, 0x08081919, 0x2b080808, 0x08081919,\n    0x2b190819, 0x08081919, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b, 0x19080808, 0x0808192b,\n    0x2b081908, 0x0808192b, 0x2b2b1908, 0x0808192b, 0x08080808, 0x08082b08, 0x08081919, 0x08082b08,\n    0x08082b08, 0x08082b08, 0x08191908, 0x08082b08, 0x082b2b08, 0x08082b08, 0x19080819, 0x08082b08,\n    0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x2b082b08, 0x08082b08,\n    0x08081908, 0x08082b19, 0x19080808, 0x08082b19, 0x0808082b, 0x08082b2b, 0x08191908, 0x08082b2b,\n    0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x08190808, 0x08190808, 0x082b0819, 0x08190808,\n    0x19080808, 0x08190808, 0x192b0808, 0x08190808, 
0x2b081908, 0x08190808, 0x2b190808, 0x08190808,\n    0x2b191919, 0x08190808, 0x08080808, 0x08190819, 0x08082b08, 0x08190819, 0x082b0808, 0x08190819,\n    0x19190808, 0x08190819, 0x19192b2b, 0x08190819, 0x2b080808, 0x08190819, 0x082b1908, 0x0819082b,\n    0x19081919, 0x0819082b, 0x08080808, 0x08191908, 0x08082b08, 0x08191908, 0x082b0808, 0x08191908,\n    0x082b1919, 0x08191908, 0x19082b19, 0x08191908, 0x2b080808, 0x08191908, 0x08192b08, 0x08191919,\n    0x192b082b, 0x08191919, 0x08080808, 0x0819192b, 0x0819192b, 0x0819192b, 0x08080819, 0x08192b08,\n    0x08081908, 0x08192b08, 0x08190808, 0x08192b08, 0x19080808, 0x08192b08, 0x2b080819, 0x08192b08,\n    0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x2b2b0808, 0x08192b19, 0x19190819, 0x08192b2b,\n    0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08082b2b, 0x082b0808, 0x19081908, 0x082b0808,\n    0x192b0819, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b08082b, 0x082b0808, 0x082b2b19, 0x082b0819,\n    0x19082b08, 0x082b0819, 0x08080808, 0x082b082b, 0x0808082b, 0x082b082b, 0x08080819, 0x082b1908,\n    0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x19080808, 0x082b1908, 0x1919192b, 0x082b1908,\n    0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x192b1908, 0x082b1919, 0x2b190808, 0x082b192b,\n    0x08082b08, 0x082b2b08, 0x082b0808, 0x082b2b08, 0x2b191908, 0x082b2b08, 0x19081908, 0x082b2b2b,\n    0x08080819, 0x19080808, 0x08081908, 0x19080808, 0x08190808, 0x19080808, 0x08192b08, 0x19080808,\n    0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x19080808, 0x19080808, 0x19082b08, 0x19080808,\n    0x1919192b, 0x19080808, 0x192b0808, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808,\n    0x2b190808, 0x19080808, 0x08080808, 0x19080819, 0x082b0808, 0x19080819, 0x192b0819, 0x19080819,\n    0x2b080808, 0x19080819, 0x2b081919, 0x19080819, 0x08080819, 0x1908082b, 0x08190808, 0x1908082b,\n    0x19082b08, 0x1908082b, 0x1919192b, 0x1908082b, 0x192b2b08, 0x1908082b, 0x08080808, 0x19081908,\n    0x08082b08, 0x19081908, 
0x082b0808, 0x19081908, 0x2b080808, 0x19081908, 0x2b192b19, 0x19081908,\n    0x0819082b, 0x19081919, 0x082b1908, 0x19081919, 0x08080808, 0x1908192b, 0x08080819, 0x19082b08,\n    0x08081908, 0x19082b08, 0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08,\n    0x08080808, 0x19082b19, 0x19192b08, 0x19082b19, 0x192b0819, 0x19082b19, 0x2b08082b, 0x19082b19,\n    0x19081919, 0x19082b2b, 0x2b190808, 0x19082b2b, 0x08080808, 0x19190808, 0x08082b08, 0x19190808,\n    0x08190819, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x2b080808, 0x19190808,\n    0x2b082b08, 0x19190808, 0x08081908, 0x19190819, 0x1908082b, 0x19190819, 0x2b2b1908, 0x19190819,\n    0x2b190819, 0x1919082b, 0x2b190808, 0x19191908, 0x2b19082b, 0x19191908, 0x08082b2b, 0x19191919,\n    0x08080819, 0x1919192b, 0x19191908, 0x1919192b, 0x08080808, 0x19192b08, 0x08190819, 0x19192b08,\n    0x08192b19, 0x19192b08, 0x192b1908, 0x19192b08, 0x19080808, 0x19192b19, 0x08082b08, 0x19192b2b,\n    0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808, 0x192b2b08, 0x192b0808,\n    0x08080808, 0x192b0819, 0x19191919, 0x192b0819, 0x08192b08, 0x192b082b, 0x192b0808, 0x192b082b,\n    0x08080808, 0x192b1908, 0x08081919, 0x192b1908, 0x08190808, 0x192b1919, 0x0819082b, 0x192b1919,\n    0x2b081908, 0x192b1919, 0x1908082b, 0x192b2b08, 0x08080808, 0x2b080808, 0x0808082b, 0x2b080808,\n    0x08082b2b, 0x2b080808, 0x19080819, 0x2b080808, 0x2b08082b, 0x2b080808, 0x08081908, 0x2b080819,\n    0x08192b08, 0x2b080819, 0x19080808, 0x2b080819, 0x08190819, 0x2b08082b, 0x08080819, 0x2b081908,\n    0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x08191919, 0x2b081908, 0x19080808, 0x2b081908,\n    0x192b0808, 0x2b081908, 0x08080808, 0x2b081919, 0x1908192b, 0x2b081919, 0x2b191908, 0x2b081919,\n    0x08082b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x192b0808, 0x2b08192b, 0x0808082b, 0x2b082b08,\n    0x08081908, 0x2b082b19, 0x08190819, 0x2b082b2b, 0x08081908, 0x2b190808, 0x08190808, 0x2b190808,\n    
0x082b1908, 0x2b190808, 0x19080808, 0x2b190808, 0x2b2b0819, 0x2b190808, 0x0819192b, 0x2b190819,\n    0x2b080808, 0x2b190819, 0x19081919, 0x2b19082b, 0x08080808, 0x2b191908, 0x082b082b, 0x2b191908,\n    0x19081908, 0x2b191908, 0x19190819, 0x2b191919, 0x2b080819, 0x2b192b08, 0x082b0808, 0x2b192b19,\n    0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,\n    0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19\n);\n#endif\n\n#ifdef IQ2_XS_GRID\nconst iq2xs_grid = array<u32, 1024>(\n    0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,\n    0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,\n    0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808,\n    0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808,\n    0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808,\n    0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x2b080808, 0x08080808,\n    0x2b08082b, 0x08080808, 0x2b081919, 0x08080808, 0x2b082b08, 0x08080808, 0x2b190819, 0x08080808,\n    0x2b191908, 0x08080808, 0x2b192b19, 0x08080808, 0x2b2b0808, 0x08080808, 0x08080819, 0x08080819,\n    0x08081908, 0x08080819, 0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819,\n    0x0819082b, 0x08080819, 0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x08192b2b, 0x08080819,\n    0x082b0819, 0x08080819, 0x082b1908, 0x08080819, 0x19080808, 0x08080819, 0x1908082b, 0x08080819,\n    0x19081919, 0x08080819, 0x19082b08, 0x08080819, 0x19190819, 0x08080819, 0x19191908, 0x08080819,\n    0x192b0808, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819, 0x2b081908, 0x08080819,\n    0x2b190808, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x08081919, 0x0808082b,\n    
0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b, 0x082b0808, 0x0808082b,\n    0x19080819, 0x0808082b, 0x19081908, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b,\n    0x2b080808, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908,\n    0x0808192b, 0x08081908, 0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908,\n    0x08191919, 0x08081908, 0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908,\n    0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908, 0x19082b08, 0x08081908,\n    0x19190819, 0x08081908, 0x19191908, 0x08081908, 0x1919192b, 0x08081908, 0x192b0808, 0x08081908,\n    0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908, 0x08080808, 0x08081919,\n    0x0808082b, 0x08081919, 0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08190819, 0x08081919,\n    0x08191908, 0x08081919, 0x082b0808, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919,\n    0x19190808, 0x08081919, 0x192b0819, 0x08081919, 0x2b080808, 0x08081919, 0x08080819, 0x0808192b,\n    0x08081908, 0x0808192b, 0x08190808, 0x0808192b, 0x082b192b, 0x0808192b, 0x19080808, 0x0808192b,\n    0x1908082b, 0x0808192b, 0x2b081908, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08,\n    0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08082b2b, 0x08082b08, 0x08190819, 0x08082b08,\n    0x08191908, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08, 0x19080819, 0x08082b08,\n    0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x19192b08, 0x08082b08, 0x2b080808, 0x08082b08,\n    0x2b2b0808, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19, 0x08081908, 0x08082b19,\n    0x08190808, 0x08082b19, 0x19080808, 0x08082b19, 0x2b080819, 0x08082b19, 0x2b082b19, 0x08082b19,\n    0x08080808, 0x08082b2b, 0x082b0808, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x2b19192b, 0x08082b2b,\n    0x2b2b0808, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808, 
0x0808192b, 0x08190808,\n    0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808, 0x08191919, 0x08190808,\n    0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808, 0x19080808, 0x08190808,\n    0x1908082b, 0x08190808, 0x19081919, 0x08190808, 0x19082b08, 0x08190808, 0x19190819, 0x08190808,\n    0x19191908, 0x08190808, 0x192b0808, 0x08190808, 0x192b2b2b, 0x08190808, 0x2b080819, 0x08190808,\n    0x2b081908, 0x08190808, 0x2b190808, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819,\n    0x08081919, 0x08190819, 0x08082b08, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819,\n    0x082b0808, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819, 0x19190808, 0x08190819,\n    0x2b080808, 0x08190819, 0x2b191908, 0x08190819, 0x2b19192b, 0x08190819, 0x08080819, 0x0819082b,\n    0x08081908, 0x0819082b, 0x0808192b, 0x0819082b, 0x08190808, 0x0819082b, 0x19080808, 0x0819082b,\n    0x192b0808, 0x0819082b, 0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908,\n    0x08082b08, 0x08191908, 0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x082b0808, 0x08191908,\n    0x19080819, 0x08191908, 0x19081908, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908,\n    0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x08080819, 0x08191919, 0x08081908, 0x08191919,\n    0x08190808, 0x08191919, 0x19080808, 0x08191919, 0x08080808, 0x0819192b, 0x08191908, 0x0819192b,\n    0x19082b19, 0x0819192b, 0x08080819, 0x08192b08, 0x08081908, 0x08192b08, 0x08190808, 0x08192b08,\n    0x0819082b, 0x08192b08, 0x19080808, 0x08192b08, 0x19191908, 0x08192b08, 0x2b08192b, 0x08192b08,\n    0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x192b192b, 0x08192b19, 0x19190819, 0x08192b2b,\n    0x2b2b2b19, 0x08192b2b, 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808,\n    0x08082b08, 0x082b0808, 0x08082b2b, 0x082b0808, 0x08190819, 0x082b0808, 0x08191908, 0x082b0808,\n    0x082b0808, 0x082b0808, 0x19080819, 0x082b0808, 
0x19081908, 0x082b0808, 0x19190808, 0x082b0808,\n    0x2b080808, 0x082b0808, 0x2b2b0808, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819,\n    0x08190808, 0x082b0819, 0x19080808, 0x082b0819, 0x19082b08, 0x082b0819, 0x192b1919, 0x082b0819,\n    0x08080808, 0x082b082b, 0x082b082b, 0x082b082b, 0x2b080808, 0x082b082b, 0x2b2b2b08, 0x082b082b,\n    0x08080819, 0x082b1908, 0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x082b2b19, 0x082b1908,\n    0x19080808, 0x082b1908, 0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x1919082b, 0x082b1919,\n    0x2b192b19, 0x082b1919, 0x08080819, 0x082b192b, 0x08192b2b, 0x082b192b, 0x2b2b192b, 0x082b192b,\n    0x08080808, 0x082b2b08, 0x08082b08, 0x082b2b08, 0x08082b2b, 0x082b2b08, 0x082b0808, 0x082b2b08,\n    0x19191919, 0x082b2b08, 0x2b082b08, 0x082b2b08, 0x2b2b082b, 0x082b2b08, 0x192b2b08, 0x082b2b19,\n    0x2b190808, 0x082b2b19, 0x08082b08, 0x082b2b2b, 0x082b0808, 0x082b2b2b, 0x2b08082b, 0x082b2b2b,\n    0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808, 0x08081908, 0x19080808,\n    0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808, 0x0819082b, 0x19080808,\n    0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x082b0819, 0x19080808, 0x082b1908, 0x19080808,\n    0x19080808, 0x19080808, 0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808,\n    0x19082b2b, 0x19080808, 0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x192b0808, 0x19080808,\n    0x192b1919, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808, 0x2b190808, 0x19080808,\n    0x08080808, 0x19080819, 0x0808082b, 0x19080819, 0x08081919, 0x19080819, 0x08082b08, 0x19080819,\n    0x08190819, 0x19080819, 0x08191908, 0x19080819, 0x082b0808, 0x19080819, 0x19080819, 0x19080819,\n    0x19081908, 0x19080819, 0x19190808, 0x19080819, 0x2b080808, 0x19080819, 0x2b081919, 0x19080819,\n    0x2b2b082b, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b, 0x08190808, 0x1908082b,\n    0x0819082b, 0x1908082b, 
0x082b2b19, 0x1908082b, 0x19080808, 0x1908082b, 0x08080808, 0x19081908,\n    0x0808082b, 0x19081908, 0x08081919, 0x19081908, 0x08082b08, 0x19081908, 0x08190819, 0x19081908,\n    0x08191908, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x19080819, 0x19081908,\n    0x19081908, 0x19081908, 0x19190808, 0x19081908, 0x2b080808, 0x19081908, 0x2b191908, 0x19081908,\n    0x08080819, 0x19081919, 0x08081908, 0x19081919, 0x08190808, 0x19081919, 0x082b1908, 0x19081919,\n    0x19080808, 0x19081919, 0x2b192b2b, 0x19081919, 0x08080808, 0x1908192b, 0x08082b2b, 0x1908192b,\n    0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x08080819, 0x19082b08, 0x08081908, 0x19082b08,\n    0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08, 0x19191908, 0x19082b08,\n    0x192b082b, 0x19082b08, 0x08080808, 0x19082b19, 0x08190819, 0x19082b19, 0x19081908, 0x19082b19,\n    0x19190808, 0x19082b19, 0x192b2b19, 0x19082b19, 0x08081908, 0x19082b2b, 0x08080808, 0x19190808,\n    0x0808082b, 0x19190808, 0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808,\n    0x08191908, 0x19190808, 0x082b0808, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808,\n    0x19081908, 0x19190808, 0x19190808, 0x19190808, 0x2b080808, 0x19190808, 0x08080819, 0x19190819,\n    0x08081908, 0x19190819, 0x08190808, 0x19190819, 0x08191919, 0x19190819, 0x19080808, 0x19190819,\n    0x1908082b, 0x19190819, 0x08080808, 0x1919082b, 0x19081908, 0x1919082b, 0x2b2b2b2b, 0x1919082b,\n    0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x08190808, 0x19191908, 0x082b0819, 0x19191908,\n    0x19080808, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b2b0819, 0x19191908,\n    0x08080808, 0x19191919, 0x08082b08, 0x19191919, 0x2b080808, 0x19191919, 0x2b082b08, 0x19191919,\n    0x082b0819, 0x1919192b, 0x192b2b08, 0x1919192b, 0x2b2b0819, 0x1919192b, 0x08080808, 0x19192b08,\n    0x08191908, 0x19192b08, 0x19080819, 0x19192b08, 0x19190808, 0x19192b08, 0x2b192b19, 0x19192b08,\n    
0x08192b2b, 0x19192b19, 0x19080808, 0x19192b19, 0x1908082b, 0x19192b19, 0x2b081919, 0x19192b2b,\n    0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808,\n    0x19191908, 0x192b0808, 0x192b082b, 0x192b0808, 0x2b08192b, 0x192b0808, 0x2b2b2b19, 0x192b0808,\n    0x08080808, 0x192b0819, 0x082b1908, 0x192b082b, 0x19082b2b, 0x192b082b, 0x2b19082b, 0x192b082b,\n    0x08080808, 0x192b1908, 0x0819192b, 0x192b1908, 0x08190808, 0x192b1919, 0x19080808, 0x192b1919,\n    0x19081919, 0x192b1919, 0x2b2b1908, 0x192b1919, 0x08080819, 0x192b2b08, 0x192b2b2b, 0x192b2b08,\n    0x082b1919, 0x192b2b19, 0x0808192b, 0x192b2b2b, 0x19191908, 0x192b2b2b, 0x192b082b, 0x192b2b2b,\n    0x08080808, 0x2b080808, 0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808,\n    0x08190819, 0x2b080808, 0x08191908, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b2b2b, 0x2b080808,\n    0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x2b080808, 0x2b080808,\n    0x2b08082b, 0x2b080808, 0x2b2b2b08, 0x2b080808, 0x2b2b2b2b, 0x2b080808, 0x08080819, 0x2b080819,\n    0x08081908, 0x2b080819, 0x0808192b, 0x2b080819, 0x08190808, 0x2b080819, 0x19080808, 0x2b080819,\n    0x19190819, 0x2b080819, 0x19192b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x082b0808, 0x2b08082b,\n    0x2b080808, 0x2b08082b, 0x2b08082b, 0x2b08082b, 0x2b2b0808, 0x2b08082b, 0x2b2b2b08, 0x2b08082b,\n    0x08080819, 0x2b081908, 0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908,\n    0x08191919, 0x2b081908, 0x19080808, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b082b19, 0x2b081908,\n    0x08080808, 0x2b081919, 0x19081908, 0x2b081919, 0x2b2b1919, 0x2b081919, 0x08192b08, 0x2b08192b,\n    0x192b2b2b, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08082b08, 0x2b082b08, 0x082b1919, 0x2b082b08,\n    0x19192b2b, 0x2b082b08, 0x2b080808, 0x2b082b08, 0x2b08082b, 0x2b082b08, 0x2b2b2b08, 0x2b082b08,\n    0x0808192b, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x2b080808, 0x2b082b2b, 
0x2b082b08, 0x2b082b2b,\n    0x2b19192b, 0x2b082b2b, 0x2b2b2b08, 0x2b082b2b, 0x08080819, 0x2b190808, 0x08081908, 0x2b190808,\n    0x08190808, 0x2b190808, 0x19080808, 0x2b190808, 0x1919192b, 0x2b190808, 0x2b081908, 0x2b190808,\n    0x08080808, 0x2b190819, 0x082b082b, 0x2b190819, 0x192b1908, 0x2b190819, 0x1919192b, 0x2b19082b,\n    0x2b082b19, 0x2b19082b, 0x08080808, 0x2b191908, 0x08081919, 0x2b191908, 0x19081908, 0x2b191908,\n    0x19190808, 0x2b191908, 0x19192b08, 0x2b191908, 0x082b2b19, 0x2b191919, 0x2b190808, 0x2b191919,\n    0x2b19082b, 0x2b191919, 0x19080819, 0x2b19192b, 0x19190819, 0x2b192b08, 0x2b2b192b, 0x2b192b08,\n    0x19082b19, 0x2b192b19, 0x08191919, 0x2b192b2b, 0x192b0808, 0x2b192b2b, 0x08080808, 0x2b2b0808,\n    0x0808082b, 0x2b2b0808, 0x08082b08, 0x2b2b0808, 0x08082b2b, 0x2b2b0808, 0x082b0808, 0x2b2b0808,\n    0x082b2b2b, 0x2b2b0808, 0x2b2b0808, 0x2b2b0808, 0x19190819, 0x2b2b0819, 0x19192b19, 0x2b2b0819,\n    0x2b2b192b, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x0808082b, 0x2b2b082b, 0x08082b08, 0x2b2b082b,\n    0x082b2b2b, 0x2b2b082b, 0x2b080808, 0x2b2b082b, 0x2b2b0808, 0x2b2b082b, 0x19080808, 0x2b2b1908,\n    0x2b191919, 0x2b2b1908, 0x192b1919, 0x2b2b192b, 0x2b192b08, 0x2b2b192b, 0x08082b2b, 0x2b2b2b08,\n    0x082b0808, 0x2b2b2b08, 0x082b082b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b0808, 0x2b2b2b08,\n    0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,\n    0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b\n);\n#endif\n\n#ifdef IQ2_S_GRID\nconst iq2s_grid = array<u32, 2048>(\n    0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,\n    0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,\n    0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808,\n    0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 
0x08080808,\n    0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808,\n    0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x192b192b, 0x08080808,\n    0x192b2b19, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b081919, 0x08080808,\n    0x2b082b08, 0x08080808, 0x2b190819, 0x08080808, 0x2b191908, 0x08080808, 0x2b2b0808, 0x08080808,\n    0x2b2b1919, 0x08080808, 0x2b2b2b2b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819,\n    0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819, 0x0819082b, 0x08080819,\n    0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x082b0819, 0x08080819, 0x082b1908, 0x08080819,\n    0x19080808, 0x08080819, 0x1908082b, 0x08080819, 0x19081919, 0x08080819, 0x19082b08, 0x08080819,\n    0x19190819, 0x08080819, 0x19191908, 0x08080819, 0x1919192b, 0x08080819, 0x19192b19, 0x08080819,\n    0x192b0808, 0x08080819, 0x192b1919, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819,\n    0x2b081908, 0x08080819, 0x2b190808, 0x08080819, 0x2b19082b, 0x08080819, 0x2b191919, 0x08080819,\n    0x2b2b0819, 0x08080819, 0x2b2b1908, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b,\n    0x08081919, 0x0808082b, 0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b,\n    0x082b0808, 0x0808082b, 0x082b2b2b, 0x0808082b, 0x19080819, 0x0808082b, 0x19081908, 0x0808082b,\n    0x1908192b, 0x0808082b, 0x19082b19, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b,\n    0x2b080808, 0x0808082b, 0x2b081919, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x2b191908, 0x0808082b,\n    0x2b2b082b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x0808192b, 0x08081908,\n    0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908, 0x08191919, 0x08081908,\n    0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908, 0x082b192b, 0x08081908,\n    0x082b2b19, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 
0x08081908, 0x19081919, 0x08081908,\n    0x19082b08, 0x08081908, 0x19082b2b, 0x08081908, 0x19190819, 0x08081908, 0x19191908, 0x08081908,\n    0x1919192b, 0x08081908, 0x19192b19, 0x08081908, 0x192b0808, 0x08081908, 0x192b082b, 0x08081908,\n    0x192b1919, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b08192b, 0x08081908,\n    0x2b082b19, 0x08081908, 0x2b190808, 0x08081908, 0x2b191919, 0x08081908, 0x2b192b08, 0x08081908,\n    0x2b2b0819, 0x08081908, 0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919,\n    0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08082b2b, 0x08081919, 0x08190819, 0x08081919,\n    0x08191908, 0x08081919, 0x0819192b, 0x08081919, 0x08192b19, 0x08081919, 0x082b0808, 0x08081919,\n    0x082b1919, 0x08081919, 0x082b2b08, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919,\n    0x1908192b, 0x08081919, 0x19082b19, 0x08081919, 0x19190808, 0x08081919, 0x1919082b, 0x08081919,\n    0x19191919, 0x08081919, 0x19192b08, 0x08081919, 0x192b0819, 0x08081919, 0x192b1908, 0x08081919,\n    0x2b080808, 0x08081919, 0x2b08082b, 0x08081919, 0x2b081919, 0x08081919, 0x2b082b08, 0x08081919,\n    0x2b190819, 0x08081919, 0x2b191908, 0x08081919, 0x2b2b0808, 0x08081919, 0x08080819, 0x0808192b,\n    0x08081908, 0x0808192b, 0x0808192b, 0x0808192b, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b,\n    0x08191919, 0x0808192b, 0x19080808, 0x0808192b, 0x19081919, 0x0808192b, 0x19082b08, 0x0808192b,\n    0x19190819, 0x0808192b, 0x19191908, 0x0808192b, 0x192b0808, 0x0808192b, 0x2b080819, 0x0808192b,\n    0x2b081908, 0x0808192b, 0x2b190808, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08,\n    0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08190819, 0x08082b08, 0x08191908, 0x08082b08,\n    0x0819192b, 0x08082b08, 0x08192b19, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08,\n    0x082b2b2b, 0x08082b08, 0x19080819, 0x08082b08, 0x19081908, 0x08082b08, 0x1908192b, 0x08082b08,\n    0x19082b19, 0x08082b08, 0x19190808, 
0x08082b08, 0x1919082b, 0x08082b08, 0x19191919, 0x08082b08,\n    0x19192b08, 0x08082b08, 0x192b0819, 0x08082b08, 0x192b1908, 0x08082b08, 0x2b080808, 0x08082b08,\n    0x2b081919, 0x08082b08, 0x2b191908, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19,\n    0x08081908, 0x08082b19, 0x08190808, 0x08082b19, 0x0819082b, 0x08082b19, 0x08191919, 0x08082b19,\n    0x08192b08, 0x08082b19, 0x082b0819, 0x08082b19, 0x19080808, 0x08082b19, 0x19081919, 0x08082b19,\n    0x19082b08, 0x08082b19, 0x19190819, 0x08082b19, 0x19191908, 0x08082b19, 0x192b0808, 0x08082b19,\n    0x2b080819, 0x08082b19, 0x2b190808, 0x08082b19, 0x08080808, 0x08082b2b, 0x08190819, 0x08082b2b,\n    0x08191908, 0x08082b2b, 0x082b082b, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x082b2b2b, 0x08082b2b,\n    0x19190808, 0x08082b2b, 0x2b192b19, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808,\n    0x0808192b, 0x08190808, 0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808,\n    0x08191919, 0x08190808, 0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808,\n    0x082b192b, 0x08190808, 0x19080808, 0x08190808, 0x1908082b, 0x08190808, 0x19081919, 0x08190808,\n    0x19082b08, 0x08190808, 0x19190819, 0x08190808, 0x19191908, 0x08190808, 0x1919192b, 0x08190808,\n    0x19192b19, 0x08190808, 0x192b0808, 0x08190808, 0x192b082b, 0x08190808, 0x192b1919, 0x08190808,\n    0x192b2b08, 0x08190808, 0x2b080819, 0x08190808, 0x2b081908, 0x08190808, 0x2b08192b, 0x08190808,\n    0x2b190808, 0x08190808, 0x2b191919, 0x08190808, 0x2b192b08, 0x08190808, 0x2b2b0819, 0x08190808,\n    0x2b2b1908, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819, 0x08081919, 0x08190819,\n    0x08082b08, 0x08190819, 0x08082b2b, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819,\n    0x0819192b, 0x08190819, 0x08192b19, 0x08190819, 0x082b0808, 0x08190819, 0x082b082b, 0x08190819,\n    0x082b1919, 0x08190819, 0x082b2b08, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819,\n    0x1908192b, 
0x08190819, 0x19082b19, 0x08190819, 0x19190808, 0x08190819, 0x1919082b, 0x08190819,\n    0x19191919, 0x08190819, 0x19192b08, 0x08190819, 0x192b0819, 0x08190819, 0x192b1908, 0x08190819,\n    0x2b080808, 0x08190819, 0x2b08082b, 0x08190819, 0x2b081919, 0x08190819, 0x2b082b08, 0x08190819,\n    0x2b190819, 0x08190819, 0x2b191908, 0x08190819, 0x08080819, 0x0819082b, 0x08081908, 0x0819082b,\n    0x08082b19, 0x0819082b, 0x08190808, 0x0819082b, 0x08191919, 0x0819082b, 0x082b0819, 0x0819082b,\n    0x082b1908, 0x0819082b, 0x19080808, 0x0819082b, 0x19081919, 0x0819082b, 0x19190819, 0x0819082b,\n    0x19191908, 0x0819082b, 0x2b080819, 0x0819082b, 0x2b081908, 0x0819082b, 0x2b190808, 0x0819082b,\n    0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908, 0x08082b08, 0x08191908,\n    0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x0819192b, 0x08191908, 0x08192b19, 0x08191908,\n    0x082b0808, 0x08191908, 0x082b1919, 0x08191908, 0x082b2b08, 0x08191908, 0x19080819, 0x08191908,\n    0x19081908, 0x08191908, 0x1908192b, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908,\n    0x1919082b, 0x08191908, 0x19191919, 0x08191908, 0x19192b08, 0x08191908, 0x192b0819, 0x08191908,\n    0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x2b08082b, 0x08191908, 0x2b081919, 0x08191908,\n    0x2b082b08, 0x08191908, 0x2b190819, 0x08191908, 0x2b191908, 0x08191908, 0x2b2b0808, 0x08191908,\n    0x08080819, 0x08191919, 0x08081908, 0x08191919, 0x0808192b, 0x08191919, 0x08082b19, 0x08191919,\n    0x08190808, 0x08191919, 0x0819082b, 0x08191919, 0x08191919, 0x08191919, 0x08192b08, 0x08191919,\n    0x082b0819, 0x08191919, 0x082b1908, 0x08191919, 0x19080808, 0x08191919, 0x1908082b, 0x08191919,\n    0x19081919, 0x08191919, 0x19082b08, 0x08191919, 0x19190819, 0x08191919, 0x19191908, 0x08191919,\n    0x192b0808, 0x08191919, 0x2b080819, 0x08191919, 0x2b081908, 0x08191919, 0x2b190808, 0x08191919,\n    0x08080808, 0x0819192b, 0x08081919, 0x0819192b, 0x08082b08, 0x0819192b, 0x08190819, 
0x0819192b,\n    0x08191908, 0x0819192b, 0x082b0808, 0x0819192b, 0x19080819, 0x0819192b, 0x19081908, 0x0819192b,\n    0x19190808, 0x0819192b, 0x2b080808, 0x0819192b, 0x2b2b2b2b, 0x0819192b, 0x08080819, 0x08192b08,\n    0x08081908, 0x08192b08, 0x0808192b, 0x08192b08, 0x08082b19, 0x08192b08, 0x08190808, 0x08192b08,\n    0x08191919, 0x08192b08, 0x08192b08, 0x08192b08, 0x082b0819, 0x08192b08, 0x19080808, 0x08192b08,\n    0x1908082b, 0x08192b08, 0x19081919, 0x08192b08, 0x19082b08, 0x08192b08, 0x19190819, 0x08192b08,\n    0x19191908, 0x08192b08, 0x192b0808, 0x08192b08, 0x2b080819, 0x08192b08, 0x2b081908, 0x08192b08,\n    0x08080808, 0x08192b19, 0x0808082b, 0x08192b19, 0x08081919, 0x08192b19, 0x08082b08, 0x08192b19,\n    0x08190819, 0x08192b19, 0x08191908, 0x08192b19, 0x082b0808, 0x08192b19, 0x19080819, 0x08192b19,\n    0x19081908, 0x08192b19, 0x19190808, 0x08192b19, 0x192b2b19, 0x08192b19, 0x2b2b082b, 0x08192b19,\n    0x08081908, 0x08192b2b, 0x08190808, 0x08192b2b, 0x19080808, 0x08192b2b, 0x1919192b, 0x08192b2b,\n    0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808, 0x08082b08, 0x082b0808,\n    0x08190819, 0x082b0808, 0x08191908, 0x082b0808, 0x0819192b, 0x082b0808, 0x08192b19, 0x082b0808,\n    0x082b0808, 0x082b0808, 0x082b1919, 0x082b0808, 0x082b2b2b, 0x082b0808, 0x19080819, 0x082b0808,\n    0x19081908, 0x082b0808, 0x19190808, 0x082b0808, 0x1919082b, 0x082b0808, 0x19191919, 0x082b0808,\n    0x192b1908, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b082b2b, 0x082b0808, 0x2b191908, 0x082b0808,\n    0x2b2b2b2b, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819, 0x08190808, 0x082b0819,\n    0x0819082b, 0x082b0819, 0x08191919, 0x082b0819, 0x082b0819, 0x082b0819, 0x19080808, 0x082b0819,\n    0x1908082b, 0x082b0819, 0x19081919, 0x082b0819, 0x19190819, 0x082b0819, 0x19191908, 0x082b0819,\n    0x192b0808, 0x082b0819, 0x2b080819, 0x082b0819, 0x2b081908, 0x082b0819, 0x2b190808, 0x082b0819,\n    0x08080808, 0x082b082b, 0x08082b2b, 0x082b082b, 0x082b082b, 
0x082b082b, 0x082b2b08, 0x082b082b,\n    0x082b2b2b, 0x082b082b, 0x19081908, 0x082b082b, 0x19190808, 0x082b082b, 0x2b082b08, 0x082b082b,\n    0x2b082b2b, 0x082b082b, 0x2b2b2b08, 0x082b082b, 0x08080819, 0x082b1908, 0x08081908, 0x082b1908,\n    0x0808192b, 0x082b1908, 0x08082b19, 0x082b1908, 0x08190808, 0x082b1908, 0x08191919, 0x082b1908,\n    0x08192b08, 0x082b1908, 0x082b0819, 0x082b1908, 0x082b1908, 0x082b1908, 0x19080808, 0x082b1908,\n    0x1908082b, 0x082b1908, 0x19081919, 0x082b1908, 0x19082b08, 0x082b1908, 0x19190819, 0x082b1908,\n    0x19191908, 0x082b1908, 0x192b0808, 0x082b1908, 0x2b080819, 0x082b1908, 0x2b081908, 0x082b1908,\n    0x2b190808, 0x082b1908, 0x08080808, 0x082b1919, 0x08081919, 0x082b1919, 0x08082b08, 0x082b1919,\n    0x08190819, 0x082b1919, 0x08191908, 0x082b1919, 0x082b0808, 0x082b1919, 0x19080819, 0x082b1919,\n    0x19081908, 0x082b1919, 0x19190808, 0x082b1919, 0x192b192b, 0x082b1919, 0x2b080808, 0x082b1919,\n    0x08080819, 0x082b192b, 0x08081908, 0x082b192b, 0x08190808, 0x082b192b, 0x19080808, 0x082b192b,\n    0x19192b19, 0x082b192b, 0x08080808, 0x082b2b08, 0x08081919, 0x082b2b08, 0x08190819, 0x082b2b08,\n    0x08191908, 0x082b2b08, 0x19080819, 0x082b2b08, 0x19081908, 0x082b2b08, 0x19190808, 0x082b2b08,\n    0x2b082b2b, 0x082b2b08, 0x2b2b2b2b, 0x082b2b08, 0x08080819, 0x082b2b19, 0x08081908, 0x082b2b19,\n    0x08190808, 0x082b2b19, 0x2b191919, 0x082b2b19, 0x08082b2b, 0x082b2b2b, 0x082b082b, 0x082b2b2b,\n    0x192b1908, 0x082b2b2b, 0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808,\n    0x08081908, 0x19080808, 0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808,\n    0x0819082b, 0x19080808, 0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x08192b2b, 0x19080808,\n    0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x082b192b, 0x19080808, 0x19080808, 0x19080808,\n    0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808, 0x19082b2b, 0x19080808,\n    0x19190819, 0x19080808, 0x19191908, 
0x19080808, 0x1919192b, 0x19080808, 0x19192b19, 0x19080808,\n    0x192b0808, 0x19080808, 0x192b082b, 0x19080808, 0x192b1919, 0x19080808, 0x2b080819, 0x19080808,\n    0x2b081908, 0x19080808, 0x2b190808, 0x19080808, 0x2b191919, 0x19080808, 0x2b192b08, 0x19080808,\n    0x2b2b0819, 0x19080808, 0x2b2b1908, 0x19080808, 0x08080808, 0x19080819, 0x0808082b, 0x19080819,\n    0x08081919, 0x19080819, 0x08082b08, 0x19080819, 0x08190819, 0x19080819, 0x08191908, 0x19080819,\n    0x0819192b, 0x19080819, 0x08192b19, 0x19080819, 0x082b0808, 0x19080819, 0x082b082b, 0x19080819,\n    0x082b1919, 0x19080819, 0x19080819, 0x19080819, 0x19081908, 0x19080819, 0x1908192b, 0x19080819,\n    0x19082b19, 0x19080819, 0x19190808, 0x19080819, 0x1919082b, 0x19080819, 0x19191919, 0x19080819,\n    0x19192b08, 0x19080819, 0x192b0819, 0x19080819, 0x192b1908, 0x19080819, 0x2b080808, 0x19080819,\n    0x2b08082b, 0x19080819, 0x2b081919, 0x19080819, 0x2b082b08, 0x19080819, 0x2b190819, 0x19080819,\n    0x2b191908, 0x19080819, 0x2b2b0808, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b,\n    0x08190808, 0x1908082b, 0x0819082b, 0x1908082b, 0x08191919, 0x1908082b, 0x08192b08, 0x1908082b,\n    0x082b1908, 0x1908082b, 0x19080808, 0x1908082b, 0x19081919, 0x1908082b, 0x19082b08, 0x1908082b,\n    0x19190819, 0x1908082b, 0x19191908, 0x1908082b, 0x192b0808, 0x1908082b, 0x2b080819, 0x1908082b,\n    0x2b081908, 0x1908082b, 0x08080808, 0x19081908, 0x0808082b, 0x19081908, 0x08081919, 0x19081908,\n    0x08082b08, 0x19081908, 0x08082b2b, 0x19081908, 0x08190819, 0x19081908, 0x08191908, 0x19081908,\n    0x0819192b, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x082b082b, 0x19081908,\n    0x082b1919, 0x19081908, 0x082b2b08, 0x19081908, 0x19080819, 0x19081908, 0x19081908, 0x19081908,\n    0x1908192b, 0x19081908, 0x19082b19, 0x19081908, 0x19190808, 0x19081908, 0x1919082b, 0x19081908,\n    0x19191919, 0x19081908, 0x19192b08, 0x19081908, 0x192b0819, 0x19081908, 0x192b1908, 0x19081908,\n    0x2b080808, 
0x19081908, 0x2b08082b, 0x19081908, 0x2b081919, 0x19081908, 0x2b082b08, 0x19081908,\n    0x2b190819, 0x19081908, 0x2b191908, 0x19081908, 0x2b2b0808, 0x19081908, 0x08080819, 0x19081919,\n    0x08081908, 0x19081919, 0x0808192b, 0x19081919, 0x08082b19, 0x19081919, 0x08190808, 0x19081919,\n    0x0819082b, 0x19081919, 0x08191919, 0x19081919, 0x08192b08, 0x19081919, 0x082b0819, 0x19081919,\n    0x082b1908, 0x19081919, 0x19080808, 0x19081919, 0x1908082b, 0x19081919, 0x19081919, 0x19081919,\n    0x19082b08, 0x19081919, 0x19190819, 0x19081919, 0x19191908, 0x19081919, 0x192b0808, 0x19081919,\n    0x192b2b2b, 0x19081919, 0x2b080819, 0x19081919, 0x2b081908, 0x19081919, 0x2b190808, 0x19081919,\n    0x08080808, 0x1908192b, 0x0808082b, 0x1908192b, 0x08081919, 0x1908192b, 0x08082b08, 0x1908192b,\n    0x08190819, 0x1908192b, 0x08191908, 0x1908192b, 0x082b0808, 0x1908192b, 0x19080819, 0x1908192b,\n    0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x2b080808, 0x1908192b, 0x2b2b1919, 0x1908192b,\n    0x08080819, 0x19082b08, 0x08081908, 0x19082b08, 0x08082b19, 0x19082b08, 0x08190808, 0x19082b08,\n    0x0819082b, 0x19082b08, 0x08191919, 0x19082b08, 0x08192b08, 0x19082b08, 0x082b0819, 0x19082b08,\n    0x082b1908, 0x19082b08, 0x19080808, 0x19082b08, 0x1908082b, 0x19082b08, 0x19081919, 0x19082b08,\n    0x19082b08, 0x19082b08, 0x19190819, 0x19082b08, 0x19191908, 0x19082b08, 0x192b0808, 0x19082b08,\n    0x2b081908, 0x19082b08, 0x2b190808, 0x19082b08, 0x08080808, 0x19082b19, 0x0808082b, 0x19082b19,\n    0x08081919, 0x19082b19, 0x08082b08, 0x19082b19, 0x08190819, 0x19082b19, 0x08191908, 0x19082b19,\n    0x082b0808, 0x19082b19, 0x19080819, 0x19082b19, 0x19081908, 0x19082b19, 0x19190808, 0x19082b19,\n    0x2b080808, 0x19082b19, 0x2b19192b, 0x19082b19, 0x08080819, 0x19082b2b, 0x08081908, 0x19082b2b,\n    0x08190808, 0x19082b2b, 0x19080808, 0x19082b2b, 0x08080808, 0x19190808, 0x0808082b, 0x19190808,\n    0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808, 0x08191908, 
0x19190808,\n    0x0819192b, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x082b082b, 0x19190808,\n    0x082b1919, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808, 0x19081908, 0x19190808,\n    0x1908192b, 0x19190808, 0x19082b19, 0x19190808, 0x19190808, 0x19190808, 0x1919082b, 0x19190808,\n    0x19191919, 0x19190808, 0x19192b08, 0x19190808, 0x192b0819, 0x19190808, 0x192b1908, 0x19190808,\n    0x2b080808, 0x19190808, 0x2b08082b, 0x19190808, 0x2b081919, 0x19190808, 0x2b082b08, 0x19190808,\n    0x2b190819, 0x19190808, 0x2b191908, 0x19190808, 0x08080819, 0x19190819, 0x08081908, 0x19190819,\n    0x0808192b, 0x19190819, 0x08082b19, 0x19190819, 0x08190808, 0x19190819, 0x0819082b, 0x19190819,\n    0x08191919, 0x19190819, 0x08192b08, 0x19190819, 0x082b0819, 0x19190819, 0x082b1908, 0x19190819,\n    0x19080808, 0x19190819, 0x1908082b, 0x19190819, 0x19081919, 0x19190819, 0x19082b08, 0x19190819,\n    0x19190819, 0x19190819, 0x19191908, 0x19190819, 0x192b0808, 0x19190819, 0x2b080819, 0x19190819,\n    0x2b081908, 0x19190819, 0x2b190808, 0x19190819, 0x08080808, 0x1919082b, 0x08081919, 0x1919082b,\n    0x08082b08, 0x1919082b, 0x08190819, 0x1919082b, 0x08191908, 0x1919082b, 0x082b0808, 0x1919082b,\n    0x19080819, 0x1919082b, 0x19081908, 0x1919082b, 0x19190808, 0x1919082b, 0x192b2b19, 0x1919082b,\n    0x2b080808, 0x1919082b, 0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x0808192b, 0x19191908,\n    0x08082b19, 0x19191908, 0x08190808, 0x19191908, 0x0819082b, 0x19191908, 0x08191919, 0x19191908,\n    0x08192b08, 0x19191908, 0x082b0819, 0x19191908, 0x082b1908, 0x19191908, 0x19080808, 0x19191908,\n    0x1908082b, 0x19191908, 0x19081919, 0x19191908, 0x19082b08, 0x19191908, 0x19190819, 0x19191908,\n    0x19191908, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b081908, 0x19191908,\n    0x2b190808, 0x19191908, 0x08080808, 0x19191919, 0x0808082b, 0x19191919, 0x08081919, 0x19191919,\n    0x08082b08, 0x19191919, 0x08190819, 0x19191919, 0x08191908, 
0x19191919, 0x082b0808, 0x19191919,\n    0x19080819, 0x19191919, 0x19081908, 0x19191919, 0x19190808, 0x19191919, 0x2b080808, 0x19191919,\n    0x08080819, 0x1919192b, 0x08081908, 0x1919192b, 0x08190808, 0x1919192b, 0x082b192b, 0x1919192b,\n    0x19080808, 0x1919192b, 0x08080808, 0x19192b08, 0x0808082b, 0x19192b08, 0x08081919, 0x19192b08,\n    0x08082b08, 0x19192b08, 0x08190819, 0x19192b08, 0x08191908, 0x19192b08, 0x082b0808, 0x19192b08,\n    0x19080819, 0x19192b08, 0x19081908, 0x19192b08, 0x19190808, 0x19192b08, 0x19192b2b, 0x19192b08,\n    0x2b080808, 0x19192b08, 0x08080819, 0x19192b19, 0x08081908, 0x19192b19, 0x08190808, 0x19192b19,\n    0x19080808, 0x19192b19, 0x08080808, 0x19192b2b, 0x08192b19, 0x19192b2b, 0x2b081919, 0x19192b2b,\n    0x2b2b2b08, 0x19192b2b, 0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x0808192b, 0x192b0808,\n    0x08190808, 0x192b0808, 0x0819082b, 0x192b0808, 0x08191919, 0x192b0808, 0x08192b08, 0x192b0808,\n    0x082b0819, 0x192b0808, 0x082b1908, 0x192b0808, 0x19080808, 0x192b0808, 0x19081919, 0x192b0808,\n    0x19082b08, 0x192b0808, 0x19190819, 0x192b0808, 0x19191908, 0x192b0808, 0x192b0808, 0x192b0808,\n    0x2b081908, 0x192b0808, 0x2b190808, 0x192b0808, 0x08080808, 0x192b0819, 0x0808082b, 0x192b0819,\n    0x08081919, 0x192b0819, 0x08082b08, 0x192b0819, 0x08190819, 0x192b0819, 0x08191908, 0x192b0819,\n    0x082b0808, 0x192b0819, 0x19080819, 0x192b0819, 0x19081908, 0x192b0819, 0x19190808, 0x192b0819,\n    0x2b080808, 0x192b0819, 0x2b192b19, 0x192b0819, 0x08081908, 0x192b082b, 0x08190808, 0x192b082b,\n    0x19080808, 0x192b082b, 0x1919192b, 0x192b082b, 0x2b2b0819, 0x192b082b, 0x08080808, 0x192b1908,\n    0x08081919, 0x192b1908, 0x08082b08, 0x192b1908, 0x08190819, 0x192b1908, 0x08191908, 0x192b1908,\n    0x082b0808, 0x192b1908, 0x19080819, 0x192b1908, 0x19081908, 0x192b1908, 0x19190808, 0x192b1908,\n    0x2b080808, 0x192b1908, 0x08080819, 0x192b1919, 0x08081908, 0x192b1919, 0x08190808, 0x192b1919,\n    0x19080808, 0x192b1919, 0x19082b2b, 
0x192b1919, 0x192b2b08, 0x192b1919, 0x2b19082b, 0x192b1919,\n    0x08080808, 0x192b192b, 0x2b191908, 0x192b192b, 0x08080819, 0x192b2b08, 0x08081908, 0x192b2b08,\n    0x08190808, 0x192b2b08, 0x192b1919, 0x192b2b08, 0x2b192b08, 0x192b2b08, 0x08080808, 0x192b2b19,\n    0x082b2b2b, 0x192b2b19, 0x1908082b, 0x192b2b2b, 0x2b2b0819, 0x192b2b2b, 0x08080808, 0x2b080808,\n    0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808, 0x08190819, 0x2b080808,\n    0x08191908, 0x2b080808, 0x08192b19, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b1919, 0x2b080808,\n    0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x1919082b, 0x2b080808,\n    0x19191919, 0x2b080808, 0x19192b08, 0x2b080808, 0x192b0819, 0x2b080808, 0x2b080808, 0x2b080808,\n    0x2b081919, 0x2b080808, 0x2b190819, 0x2b080808, 0x2b191908, 0x2b080808, 0x08080819, 0x2b080819,\n    0x08081908, 0x2b080819, 0x08082b19, 0x2b080819, 0x08190808, 0x2b080819, 0x0819082b, 0x2b080819,\n    0x08191919, 0x2b080819, 0x08192b08, 0x2b080819, 0x082b0819, 0x2b080819, 0x082b1908, 0x2b080819,\n    0x19080808, 0x2b080819, 0x1908082b, 0x2b080819, 0x19081919, 0x2b080819, 0x19082b08, 0x2b080819,\n    0x19190819, 0x2b080819, 0x19191908, 0x2b080819, 0x2b080819, 0x2b080819, 0x2b081908, 0x2b080819,\n    0x2b190808, 0x2b080819, 0x2b2b2b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x08081919, 0x2b08082b,\n    0x08082b2b, 0x2b08082b, 0x08190819, 0x2b08082b, 0x08191908, 0x2b08082b, 0x19080819, 0x2b08082b,\n    0x19081908, 0x2b08082b, 0x19190808, 0x2b08082b, 0x08080819, 0x2b081908, 0x08081908, 0x2b081908,\n    0x0808192b, 0x2b081908, 0x08082b19, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908,\n    0x08191919, 0x2b081908, 0x08192b08, 0x2b081908, 0x082b0819, 0x2b081908, 0x19080808, 0x2b081908,\n    0x1908082b, 0x2b081908, 0x19081919, 0x2b081908, 0x19082b08, 0x2b081908, 0x19190819, 0x2b081908,\n    0x19191908, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b080819, 0x2b081908, 0x2b081908, 0x2b081908,\n    0x2b190808, 
0x2b081908, 0x08080808, 0x2b081919, 0x0808082b, 0x2b081919, 0x08081919, 0x2b081919,\n    0x08082b08, 0x2b081919, 0x08190819, 0x2b081919, 0x08191908, 0x2b081919, 0x082b0808, 0x2b081919,\n    0x19080819, 0x2b081919, 0x19081908, 0x2b081919, 0x19190808, 0x2b081919, 0x2b080808, 0x2b081919,\n    0x2b082b2b, 0x2b081919, 0x08080819, 0x2b08192b, 0x08081908, 0x2b08192b, 0x08190808, 0x2b08192b,\n    0x082b2b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08081919, 0x2b082b08,\n    0x08190819, 0x2b082b08, 0x08191908, 0x2b082b08, 0x19080819, 0x2b082b08, 0x19081908, 0x2b082b08,\n    0x19190808, 0x2b082b08, 0x2b2b082b, 0x2b082b08, 0x08080819, 0x2b082b19, 0x08081908, 0x2b082b19,\n    0x19080808, 0x2b082b19, 0x192b1919, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x19192b08, 0x2b082b2b,\n    0x19192b2b, 0x2b082b2b, 0x2b08082b, 0x2b082b2b, 0x2b2b082b, 0x2b082b2b, 0x08080819, 0x2b190808,\n    0x08081908, 0x2b190808, 0x08082b19, 0x2b190808, 0x08190808, 0x2b190808, 0x0819082b, 0x2b190808,\n    0x08191919, 0x2b190808, 0x08192b08, 0x2b190808, 0x082b1908, 0x2b190808, 0x19080808, 0x2b190808,\n    0x1908082b, 0x2b190808, 0x19081919, 0x2b190808, 0x19082b08, 0x2b190808, 0x19190819, 0x2b190808,\n    0x19191908, 0x2b190808, 0x192b0808, 0x2b190808, 0x2b080819, 0x2b190808, 0x2b081908, 0x2b190808,\n    0x2b190808, 0x2b190808, 0x08080808, 0x2b190819, 0x08081919, 0x2b190819, 0x08190819, 0x2b190819,\n    0x08191908, 0x2b190819, 0x19080819, 0x2b190819, 0x19081908, 0x2b190819, 0x19190808, 0x2b190819,\n    0x19192b2b, 0x2b190819, 0x08080819, 0x2b19082b, 0x08081908, 0x2b19082b, 0x08190808, 0x2b19082b,\n    0x19080808, 0x2b19082b, 0x2b2b192b, 0x2b19082b, 0x08080808, 0x2b191908, 0x0808082b, 0x2b191908,\n    0x08081919, 0x2b191908, 0x08082b08, 0x2b191908, 0x08190819, 0x2b191908, 0x08191908, 0x2b191908,\n    0x082b0808, 0x2b191908, 0x19080819, 0x2b191908, 0x19081908, 0x2b191908, 0x19190808, 0x2b191908,\n    0x2b080808, 0x2b191908, 0x2b19192b, 0x2b191908, 0x08080819, 0x2b191919, 0x08081908, 
0x2b191919,\n    0x08190808, 0x2b191919, 0x19080808, 0x2b191919, 0x2b192b08, 0x2b191919, 0x2b2b0819, 0x2b191919,\n    0x08080808, 0x2b19192b, 0x1908192b, 0x2b19192b, 0x192b1908, 0x2b19192b, 0x08080819, 0x2b192b08,\n    0x08081908, 0x2b192b08, 0x08190808, 0x2b192b08, 0x082b192b, 0x2b192b08, 0x19080808, 0x2b192b08,\n    0x2b2b2b19, 0x2b192b08, 0x08080808, 0x2b192b19, 0x19082b19, 0x2b192b19, 0x1919082b, 0x2b192b19,\n    0x2b190808, 0x2b192b2b, 0x08080808, 0x2b2b0808, 0x08081919, 0x2b2b0808, 0x08082b2b, 0x2b2b0808,\n    0x08191908, 0x2b2b0808, 0x082b082b, 0x2b2b0808, 0x082b2b2b, 0x2b2b0808, 0x19080819, 0x2b2b0808,\n    0x19081908, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b2b082b, 0x2b2b0808, 0x2b2b2b2b, 0x2b2b0808,\n    0x19080808, 0x2b2b0819, 0x192b1919, 0x2b2b0819, 0x0808082b, 0x2b2b082b, 0x08082b2b, 0x2b2b082b,\n    0x082b082b, 0x2b2b082b, 0x082b2b08, 0x2b2b082b, 0x082b2b2b, 0x2b2b082b, 0x2b08082b, 0x2b2b082b,\n    0x2b082b08, 0x2b2b082b, 0x2b082b2b, 0x2b2b082b, 0x2b2b2b08, 0x2b2b082b, 0x08080819, 0x2b2b1908,\n    0x08081908, 0x2b2b1908, 0x08190808, 0x2b2b1908, 0x19080808, 0x2b2b1908, 0x2b082b19, 0x2b2b1908,\n    0x2b2b1908, 0x2b2b1908, 0x08080808, 0x2b2b1919, 0x08192b19, 0x2b2b1919, 0x19190819, 0x2b2b192b,\n    0x08082b2b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b082b, 0x2b2b2b08, 0x19191908, 0x2b2b2b19,\n    0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,\n    0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b\n);\n#endif\n\n#ifdef IQ3_XXS_GRID\nconst iq3xxs_grid = array<u32, 256>(\n    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,\n    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,\n    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,\n    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 
0x04141c3e,\n    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,\n    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,\n    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,\n    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,\n    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,\n    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,\n    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,\n    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,\n    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,\n    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,\n    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,\n    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,\n    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,\n    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,\n    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,\n    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,\n    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,\n    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,\n    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,\n    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 
0x24242c0c, 0x24243424, 0x242c142c,\n    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,\n    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,\n    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,\n    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,\n    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,\n    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,\n    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,\n    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04\n);\n#endif\n\n#ifdef IQ3_S_GRID\nconst iq3s_grid = array<u32, 512>(\n    0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,\n    0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,\n    0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,\n    0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,\n    0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,\n    0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,\n    0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,\n    0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,\n    0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,\n    0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,\n    0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 
0x010b0505, 0x010b050d,\n    0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,\n    0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,\n    0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,\n    0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,\n    0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,\n    0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,\n    0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,\n    0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,\n    0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,\n    0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,\n    0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,\n    0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,\n    0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,\n    0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,\n    0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,\n    0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,\n    0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,\n    0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,\n    0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,\n    0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 
0x0507010b, 0x05070303, 0x05070505, 0x05070509,\n    0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,\n    0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,\n    0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,\n    0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,\n    0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,\n    0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,\n    0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,\n    0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,\n    0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,\n    0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,\n    0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,\n    0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,\n    0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,\n    0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,\n    0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,\n    0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,\n    0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,\n    0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,\n    0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,\n    0x090f0907, 0x090f0b03, 
0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,\n    0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,\n    0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,\n    0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,\n    0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,\n    0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,\n    0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,\n    0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,\n    0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,\n    0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,\n    0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,\n    0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,\n    0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,\n    0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101\n);\n#endif\n\n#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)\n\nconst IQ1_DELTA: f32 = 0.125;\n\nconst iq1_grid = array<u32, 1024>(\n    0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01,\n    0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4,\n    0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41,\n    0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f,\n    0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 
0xf33cf3d0, 0xf300f334,\n    0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f,\n    0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040,\n    0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f,\n    0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5,\n    0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3,\n    0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff,\n    0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570,\n    0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f,\n    0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf,\n    0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f,\n    0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07,\n    0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc,\n    0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374,\n    0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0,\n    0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001,\n    0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043,\n    0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc,\n    0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117,\n    0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f,\n    0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 
0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5,\n    0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474,\n    0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d,\n    0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd,\n    0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50,\n    0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10,\n    0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30,\n    0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1,\n    0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c,\n    0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074,\n    0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134,\n    0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7,\n    0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3,\n    0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450,\n    0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577,\n    0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c,\n    0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5,\n    0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c,\n    0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00,\n    0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300,\n    0x331d331c, 0x33173310, 
0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc,\n    0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034,\n    0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077,\n    0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5,\n    0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117,\n    0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f,\n    0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5,\n    0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404,\n    0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1,\n    0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd,\n    0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71,\n    0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7,\n    0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00,\n    0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44,\n    0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00,\n    0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0,\n    0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303,\n    0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343,\n    0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd,\n    0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031,\n    
0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011,\n    0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c,\n    0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4,\n    0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c,\n    0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174,\n    0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7,\n    0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d,\n    0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4,\n    0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c,\n    0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7,\n    0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510,\n    0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33,\n    0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4,\n    0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73,\n    0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f,\n    0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337,\n    0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343,\n    0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030,\n    0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075,\n    0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 
0x11d711d1, 0x113f11d4,\n    0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170,\n    0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705,\n    0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c,\n    0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c,\n    0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514,\n    0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c,\n    0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3,\n    0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70,\n    0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03,\n    0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c,\n    0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c,\n    0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074,\n    0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104,\n    0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7,\n    0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757,\n    0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c,\n    0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c,\n    0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4,\n    0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc,\n    0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 
0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03,\n    0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc,\n    0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54,\n    0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f,\n    0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf,\n    0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c,\n    0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c,\n    0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4,\n    0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174,\n    0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700,\n    0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7,\n    0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d,\n    0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531,\n    0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf,\n    0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57,\n    0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13,\n    0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01,\n    0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f,\n    0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7,\n    0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074,\n    0x50475040, 0x51cc51f0, 
0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107,\n    0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd,\n    0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0,\n    0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7,\n    0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557\n);\n\n#endif\n\n#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)\n\nconst kvalues_iq4nl = array<i32, 16>(\n    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113\n);\n\n#endif\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/concat.wgsl",
    "content": "struct Params {\n    ne: u32,\n\n    offset_src0: u32,\n    offset_src1: u32,\n    offset_dst: u32,\n\n    stride_src0_0: u32,\n    stride_src0_1: u32,\n    stride_src0_2: u32,\n    stride_src0_3: u32,\n\n    stride_src1_0: u32,\n    stride_src1_1: u32,\n    stride_src1_2: u32,\n    stride_src1_3: u32,\n\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n    ne3: u32,\n\n    dim: u32,\n    src0_nedim: u32\n};\n\n#ifdef TYPE_F32\n#define DataType f32\n#endif\n#ifdef TYPE_I32\n#define DataType i32\n#endif\n\n@group(0) @binding(0)\nvar<storage, read_write> src0: array<DataType>;\n\n@group(0) @binding(1)\nvar<storage, read_write> src1 : array<DataType>;\n\n@group(0) @binding(2)\nvar<storage, read_write> dst: array<DataType>;\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n\n    if (gid.x < params.ne) {\n        var i = gid.x;\n        let i3 = i / (params.ne2 * params.ne1 * params.ne0);\n        i = i % (params.ne2 * params.ne1 * params.ne0);\n        let i2 = i / (params.ne1 * params.ne0);\n        i = i % (params.ne1 * params.ne0);\n        let i1 = i / params.ne0;\n        let i0 = i % params.ne0;\n\n        var ni = array<u32, 4>(i0, i1, i2, i3);\n\n        if (ni[params.dim] < params.src0_nedim) {\n            let src_i = ni[0] * params.stride_src0_0 +\n                             ni[1] * params.stride_src0_1 +\n                             ni[2] * params.stride_src0_2 +\n                             ni[3] * params.stride_src0_3;\n            dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];\n        } else {\n            ni[params.dim] -= params.src0_nedim;\n            let src_i = ni[0] * params.stride_src1_0 +\n                             ni[1] * params.stride_src1_1 +\n                             ni[2] * params.stride_src1_2 +\n                             ni[3] * params.stride_src1_3;\n            dst[params.offset_dst + gid.x] 
= src1[params.offset_src1 + src_i];\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl",
    "content": "#define(VARIANTS)\n\n[\n  {\n    \"REPLS\": {\n      \"SRC_TYPE\": \"f32\",\n      \"DST_TYPE\": \"f32\"\n    }\n  },\n  {\n    \"REPLS\": {\n      \"SRC_TYPE\": \"f32\",\n      \"DST_TYPE\": \"i32\"\n    }\n  },\n  {\n    \"REPLS\": {\n      \"SRC_TYPE\": \"f32\",\n      \"DST_TYPE\": \"f16\"\n    }\n  },\n  {\n    \"REPLS\": {\n      \"SRC_TYPE\": \"f16\",\n      \"DST_TYPE\": \"f16\"\n    }\n  },\n  {\n    \"REPLS\": {\n      \"SRC_TYPE\": \"f16\",\n      \"DST_TYPE\": \"f32\"\n    }\n  }\n]\n\n#end(VARIANTS)\n\n#define(SHADER)\nenable f16;\n\n@group(0) @binding(0)\nvar<storage, read_write> src: array<{{SRC_TYPE}}>;\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<{{DST_TYPE}}>;\n\nstruct Params {\n    ne: u32,            // total number of elements\n    offset_src: u32,    // in elements\n    offset_dst: u32,    // in elements\n\n    // Strides (in elements) — may be permuted\n    stride_src0: u32,\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    stride_dst0: u32,\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    // Logical shapes\n    src_ne0: u32,\n    src_ne1: u32,\n    src_ne2: u32,\n\n    dst_ne0: u32,\n    dst_ne1: u32,\n    dst_ne2: u32\n};\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\noverride wg_size: u32;\n@compute @workgroup_size(wg_size)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if (gid.x >= params.ne) {\n        return;\n    }\n\n    var i = gid.x;\n    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);\n    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);\n    let i2 = i / (params.src_ne1 * params.src_ne0);\n    i = i % (params.src_ne1 * params.src_ne0);\n    let i1 = i / params.src_ne0;\n    let i0 = i % params.src_ne0;\n\n    var j = gid.x;\n    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);\n    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);\n    let j2 = j / (params.dst_ne1 * 
params.dst_ne0);\n    j = j % (params.dst_ne1 * params.dst_ne0);\n    let j1 = j / params.dst_ne0;\n    let j0 = j % params.dst_ne0;\n\n    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +\n                  i2 * params.stride_src2 + i3 * params.stride_src3;\n\n    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +\n                  j2 * params.stride_dst2 + j3 * params.stride_dst3;\n\n    dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));\n}\n#end(SHADER)\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/cumsum.wgsl",
    "content": "@group(0) @binding(0)\nvar<storage, read_write> src: array<f32>;\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<f32>;\n\nstruct Params {\n    offset_src: u32, // in elements\n    offset_dst: u32, // in elements\n    ne0: u32,\n};\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\nvar<workgroup> shared_sum: array<f32, WG_SIZE>;\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(workgroup_id) wid: vec3<u32>,\n        @builtin(local_invocation_id) lid: vec3<u32>) {\n    let row_idx = params.offset_src + wid.x * params.ne0;\n    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;\n    var local_sum: f32 = 0.0;\n    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {\n        local_sum += src[row_idx + col];\n    }\n    shared_sum[lid.x] = local_sum;\n    workgroupBarrier();\n\n    // upsweep\n    var offset = 1u;\n    while (offset < WG_SIZE) {\n        let idx = (lid.x + 1) * offset * 2 - 1;\n        if (idx < WG_SIZE) {\n            shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset];\n        }\n        workgroupBarrier();\n        offset <<= 1;\n    }\n\n    // set last to 0 for exclusive sum\n    if (lid.x == 0) {\n        shared_sum[WG_SIZE - 1] = 0.0;\n    }\n    workgroupBarrier();\n\n    // downsweep\n    offset = WG_SIZE >> 1;\n    while (offset > 0) {\n        let idx = (lid.x + 1) * offset * 2 - 1;\n        if (idx < WG_SIZE) {\n            let t = shared_sum[idx - offset];\n            shared_sum[idx - offset] = shared_sum[idx];\n            shared_sum[idx] = shared_sum[idx] + t;\n        }\n        workgroupBarrier();\n        offset = offset >> 1;\n    }\n\n    // shared_sum[lid] is exclusive prefix sum up to this thread.\n    var running_sum = shared_sum[lid.x];\n    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {\n        running_sum += src[row_idx + col];\n        dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum;\n  
  }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/embed_wgsl.py",
    "content": "import os\nimport re\nimport ast\nimport argparse\n\n\ndef extract_block(text, name):\n    pattern = rf'#define\\({name}\\)\\s*(.*?)#end\\({name}\\)'\n    match = re.search(pattern, text, re.DOTALL)\n    if not match:\n        raise ValueError(f\"Missing block: {name}\")\n    return match.group(1).strip()\n\n\ndef parse_decls(decls_text):\n    decls = {}\n    for name, code in re.findall(r'#decl\\((.*?)\\)\\s*(.*?)#enddecl\\(\\1\\)', decls_text, re.DOTALL):\n        decls[name.strip()] = code.strip()\n    return decls\n\n\ndef replace_repl_placeholders(variant, template_map):\n    for repl, code in variant[\"REPLS\"].items():\n        for key, val in template_map.items():\n            # Match \"key\" and avoid matching subsequences using by using \\b\n            code = re.sub(rf'\\b{re.escape(str(key))}\\b', str(val), code)\n        variant[\"REPLS\"][repl] = code\n    return variant\n\n\ndef replace_placeholders(shader_text, replacements):\n    for key, val in replacements.items():\n        # Match {{KEY}} literally, where KEY is escaped\n        pattern = r'{{\\s*' + re.escape(key) + r'\\s*}}'\n        shader_text = re.sub(pattern, str(val), shader_text)\n    return shader_text\n\n\ndef expand_includes(shader, input_dir):\n    \"\"\"\n    Replace #include \"file\" lines in the text with the contents of that file.\n    Searches for files relative to input_dir.\n    \"\"\"\n    include_pattern = re.compile(r'^\\s*#include\\s+\"([^\"]+)\"\\s*$', re.MULTILINE)\n\n    def replacer(match):\n        fname = match.group(1)\n        file_path = os.path.join(input_dir, fname)\n        if not os.path.exists(file_path):\n            raise FileNotFoundError(f\"Included file not found: {file_path}\")\n        with open(file_path, \"r\", encoding=\"utf-8\") as f:\n            included_code = f.read()\n        # Recursively expand includes inside the included file\n        return expand_includes(included_code, input_dir)\n\n    return 
include_pattern.sub(replacer, shader)\n\n\ndef chunk_shader(shader_code, max_chunk_len=60000):\n    \"\"\"Split shader_code into safe raw-string sized chunks.\"\"\"\n    return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)]\n\n\ndef raw_delim(shader_code):\n    \"\"\"Pick a raw-string delimiter that does not appear in the shader.\"\"\"\n    delim = \"wgsl\"\n    while f\"){delim}\\\"\" in shader_code:\n        delim += \"_x\"\n    return delim\n\n\ndef write_shader(shader_name, shader_code, output_dir, outfile, input_dir):\n    shader_code = expand_includes(shader_code, input_dir)\n\n    if output_dir:\n        wgsl_filename = os.path.join(output_dir, f\"{shader_name}.wgsl\")\n        with open(wgsl_filename, \"w\", encoding=\"utf-8\") as f_out:\n            f_out.write(shader_code)\n\n    delim = raw_delim(shader_code)\n    chunks = chunk_shader(shader_code)\n\n    if len(chunks) == 1:\n        outfile.write(f'const char* wgsl_{shader_name} = R\"{delim}({shader_code}){delim}\";\\n\\n')\n    else:\n        for idx, chunk in enumerate(chunks):\n            outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R\"{delim}({chunk}){delim}\";\\n\\n')\n        outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\\n')\n        outfile.write('    static const std::string s = []{\\n')\n        outfile.write('        std::string tmp;\\n')\n        outfile.write(f'        tmp.reserve({len(shader_code)});\\n')\n        for idx in range(len(chunks)):\n            outfile.write(f'        tmp.append(wgsl_{shader_name}_part{idx});\\n')\n        outfile.write('        return tmp;\\n')\n        outfile.write('    }();\\n')\n        outfile.write('    return s;\\n')\n        outfile.write('}\\n')\n        outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\\n\\n')\n\n\ndef generate_variants(fname, input_dir, output_dir, outfile):\n    shader_path = os.path.join(input_dir, fname)\n    
shader_base_name = fname.split(\".\")[0]\n\n    with open(shader_path, \"r\", encoding=\"utf-8\") as f:\n        text = f.read()\n\n    try:\n        variants = ast.literal_eval(extract_block(text, \"VARIANTS\"))\n    except ValueError:\n        write_shader(shader_base_name, text, output_dir, outfile, input_dir)\n    else:\n        try:\n            decls_map = parse_decls(extract_block(text, \"DECLS\"))\n        except ValueError:\n            decls_map = {}\n        try:\n            templates_map = ast.literal_eval(extract_block(text, \"REPL_TEMPLATES\"))\n        except ValueError:\n            templates_map = {}\n\n        for fname in sorted(os.listdir(input_dir)):\n            if fname.endswith(\".tmpl\"):\n                tmpl_path = os.path.join(input_dir, fname)\n                with open(tmpl_path, \"r\", encoding=\"utf-8\") as f_tmpl:\n                    decls = f_tmpl.read()\n                    decls_map.update(parse_decls(decls))\n\n        shader_template = extract_block(text, \"SHADER\")\n        for variant in variants:\n            if \"DECLS\" in variant:\n                decls = variant[\"DECLS\"]\n            else:\n                decls = []\n            decls_code = \"\"\n            for key in decls:\n                if key not in decls_map:\n                    raise ValueError(f\"DECLS key '{key}' not found.\")\n                decls_code += decls_map[key] + \"\\n\\n\"\n            final_shader = re.sub(r'\\bDECLS\\b', decls_code, shader_template)\n            if \"REPLS\" in variant:\n                variant = replace_repl_placeholders(variant, templates_map)\n                final_shader = replace_placeholders(final_shader, variant[\"REPLS\"])\n                # second run to expand placeholders in repl_template\n                final_shader = replace_placeholders(final_shader, variant[\"REPLS\"])\n            final_shader = expand_includes(final_shader, input_dir)\n\n            if \"SHADER_NAME\" in variant:\n                
output_name = variant[\"SHADER_NAME\"]\n            elif \"SHADER_SUFFIX\" in variant:\n                output_name = f\"{shader_base_name}_\" + variant[\"SHADER_SUFFIX\"]\n            elif \"REPLS\" in variant and \"SRC0_TYPE\" in variant[\"REPLS\"] and \"SRC1_TYPE\" in variant[\"REPLS\"]:\n                output_name = f\"{shader_base_name}_\" + \"_\".join([variant[\"REPLS\"][\"SRC0_TYPE\"], variant[\"REPLS\"][\"SRC1_TYPE\"]])\n            elif \"REPLS\" in variant and \"SRC_TYPE\" in variant[\"REPLS\"] and \"DST_TYPE\" in variant[\"REPLS\"]:\n                output_name = f\"{shader_base_name}_\" + \"_\".join([variant[\"REPLS\"][\"SRC_TYPE\"], variant[\"REPLS\"][\"DST_TYPE\"]])\n            elif \"REPLS\" in variant and \"TYPE\" in variant[\"REPLS\"]:\n                output_name = f\"{shader_base_name}_\" + variant[\"REPLS\"][\"TYPE\"]\n            else:\n                output_name = shader_base_name\n            write_shader(output_name, final_shader, output_dir, outfile, input_dir)\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--input_dir\", required=True)\n    parser.add_argument(\"--output_file\", required=True)\n    parser.add_argument(\"--output_dir\")\n    args = parser.parse_args()\n\n    if args.output_dir:\n        os.makedirs(args.output_dir, exist_ok=True)\n\n    with open(args.output_file, \"w\", encoding=\"utf-8\") as out:\n        out.write(\"// Auto-generated shader embedding\\n\")\n        out.write(\"#include <string>\\n\\n\")\n        for fname in sorted(os.listdir(args.input_dir)):\n            if fname.endswith(\".wgsl\"):\n                generate_variants(fname, args.input_dir, args.output_dir, out)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl",
    "content": "diagnostic(off, chromium.subgroup_matrix_uniformity);\ndiagnostic(off, subgroup_uniformity);\nenable f16;\nenable subgroups;\nenable chromium_experimental_subgroup_matrix;\n\n#ifdef KV_F32\n#define KV_TYPE f32\n#else\n#define KV_TYPE f16\n#endif\n\n// Default values\n#define HEAD_DIM_QK 64\n#define HEAD_DIM_V 64\n\n// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN\n// Note that the \"K\" here does not correspond to the K in attention's Q/K/V, it's just the common dimension.\n#define SG_MAT_M 8\n#define SG_MAT_N 8\n#define SG_MAT_K 8\n\n// Each workgroup processes one subgroup matrix of Q rows\n#define Q_TILE SG_MAT_M\n#define KV_TILE 16\n#define WG_SIZE 64\n\n// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.\n#define KV_BLOCKS (KV_TILE / SG_MAT_N)\n\n// Quantization constants/helpers\n#define BLOCK_SIZE 32\n#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)\n#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)\n// number of quantized elements processed per thread\n#if defined(KV_Q4_0)\n#define NQ 16\n// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights\n#define F16_PER_BLOCK 9\n#define WEIGHTS_PER_F16 4\n#elif defined(KV_Q8_0)\n#define NQ 8\n// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights\n#define F16_PER_BLOCK 17\n#define WEIGHTS_PER_F16 2\n#endif\n#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)\n\n// Ok not to put these in a define block, compiler will remove if unused\nfn get_byte(value: u32, index: u32) -> u32 {\n    return (value >> (index * 8)) & 0xFF;\n}\n\nfn get_byte_i32(value: u32, index: u32) -> i32 {\n    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;\n}\n\nstruct Params {\n    offset_q: u32,\n    offset_k: u32,\n    offset_v: u32,\n    offset_mask: u32,\n    offset_sinks: u32,\n    offset_dst: u32,\n\n    // shapes of Q/K/V\n    n_heads: u32,\n    seq_len_q: u32,\n    seq_len_kv: u32,\n\n    // strides 
(in elements)\n    stride_q1: u32,\n    stride_q2: u32,\n    stride_q3: u32,\n    stride_k1: u32,\n    stride_k2: u32,\n    stride_k3: u32,\n    stride_v1: u32,\n    stride_v2: u32,\n    stride_v3: u32,\n    stride_mask3: u32,\n\n    // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA\n    q_per_kv: u32,\n\n    // softmax params\n    scale: f32,\n    max_bias: f32,\n    logit_softcap: f32,\n    n_head_log2: f32,\n    m0: f32,\n    m1: f32,\n};\n\n@group(0) @binding(0) var<storage, read_write> Q: array<f32>;\n@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;\n@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;\n\n#if defined(MASK) && defined(SINKS)\n@group(0) @binding(3) var<storage, read_write> mask: array<f16>;\n@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;\n#define DST_BINDING 5\n#define PARAMS_BINDING 6\n#elif defined(MASK)\n@group(0) @binding(3) var<storage, read_write> mask: array<f16>;\n#define DST_BINDING 4\n#define PARAMS_BINDING 5\n#elif defined(SINKS)\n@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;\n#define DST_BINDING 4\n#define PARAMS_BINDING 5\n#else\n#define DST_BINDING 3\n#define PARAMS_BINDING 4\n#endif\n\n@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;\n@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;\n\n// Large negative sentinel standing in for -infinity (initial row max, scores of out-of-range KV columns).\nconst FLOAT_MIN: f32 = -1.0e9;\n\n// Shared-memory tile holding the Q rows processed by this workgroup (Q_TILE rows of HEAD_DIM_QK values)\nvar<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;\n\n#ifndef KV_DIRECT\nconst kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);\n// we can reuse the same shmem for K and V since we only need one at a time\nvar<workgroup> kv_shmem: array<f16, kv_shmem_size>;\n#endif\n\nvar<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>; // output shmem\n\n#ifdef MASK\n// storage for mask values\nvar<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;\n#endif\n\n// storage for output of Q*K^T scores for online 
softmax (S matrix from paper)\n// also storage for the exponentiated scores exp(S - row max) during online softmax (P matrix from paper)\n// note that we reuse the same storage for both since we only need one at a time\nvar<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;\n\n// Storage for row max and exp sum during online softmax\nvar<workgroup> row_max_shmem: array<f32, Q_TILE>;\nvar<workgroup> exp_sum_shmem: array<f32, Q_TILE>;\n\nfn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {\n    var v = select(FLOAT_MIN,\n                   f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,\n                   kv_idx < KV_TILE);\n#ifdef LOGIT_SOFTCAP\n    v = params.logit_softcap * tanh(v);\n#endif\n#ifdef MASK\n    let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);\n    let mask_term = slope * mask_val;\n    v += mask_term;\n#endif\n    return v;\n}\n\nfn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> {\n    return (*buf)[scalar_index >> 2u];\n}\n\nfn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {\n    return (*buf)[scalar_index >> 2u];\n}\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(workgroup_id) wg_id: vec3<u32>,\n    @builtin(local_invocation_id) local_id: vec3<u32>,\n    @builtin(subgroup_id) subgroup_id: u32,\n    @builtin(subgroup_size) subgroup_size: u32,\n    @builtin(num_subgroups) num_subgroups: u32,\n    @builtin(subgroup_invocation_id) sg_inv_id: u32) {\n\n    // initialize row max and exp sum for online softmax\n    for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {\n        row_max_shmem[i] = FLOAT_MIN;\n        exp_sum_shmem[i] = 0.0;\n    }\n\n    for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {\n        o_shmem[i] = 0.0;\n    }\n\n    // workgroups per head/batch\n    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;\n    let wg_per_batch = wg_per_head * params.n_heads;\n\n  
  let dst2_stride = HEAD_DIM_V * params.n_heads;\n    let dst3_stride = dst2_stride * params.seq_len_q;\n\n    // batch index\n    let batch_idx = wg_id.x / wg_per_batch;\n    let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;\n    let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;\n    let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;\n    let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;\n    let wg_in_batch = wg_id.x % wg_per_batch;\n\n    // head index\n    let head_idx = wg_in_batch / wg_per_head;\n    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;\n    let k_head_idx = head_idx / params.q_per_kv;\n    let v_head_idx = k_head_idx;\n    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;\n    let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;\n\n    // starting Q row for this workgroup\n    let wg_in_head = wg_in_batch % wg_per_head;\n    let q_row_start = wg_in_head * Q_TILE;\n\n#ifdef MASK\n    // mask offset\n    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;\n#endif\n\n    // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size]\n    let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;\n\n    let head = f32(head_idx);\n    let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0);\n\n    // load q tile into shared memory\n    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {\n        let q_row = elem_idx / HEAD_DIM_QK;\n        let q_col = elem_idx % HEAD_DIM_QK;\n        let head_q_row = q_row_start + q_row;\n        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;\n        q_shmem[elem_idx] = f16(select(\n            0.0,\n 
           Q[global_q_row_offset + q_col],\n            head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));\n    }\n\n    for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {\n      // clear inter_shmem to ensure zero-initialized accumulators\n        for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {\n            inter_shmem[elem_idx] = 0.0;\n        }\n\n      // load k tile into shared memory\n#if defined(KV_Q4_0)\n      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {\n          let blck_idx = elem_idx / BLOCK_SIZE;\n          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;\n          let k_row = blck_idx / BLOCKS_K;\n          let global_k_row = kv_tile + k_row;\n          let block_k = blck_idx % BLOCKS_K;\n          let row_offset = k_row * HEAD_DIM_QK;\n\n          if (global_k_row < params.seq_len_kv) {\n              let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;\n              let base_idx = global_block_idx * F16_PER_BLOCK;\n              let d = K[base_idx]; // scale\n              for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n                  let q_0 = K[base_idx + 1u + block_offset + j];\n                  let q_1 = K[base_idx + 1u + block_offset + j + 1];\n                  let q_packed = bitcast<u32>(vec2(q_0, q_1));\n                  for (var k = 0u; k < 4u; k++) {\n                      let q_byte = get_byte(q_packed, k);\n                      let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;\n                      let q_lo = (f16(q_byte & 0xF) - 8.0) * d;\n                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;\n                      kv_shmem[row_offset + idx] = q_lo;\n                      kv_shmem[row_offset + idx + 16u] = q_hi;\n                  }\n              }\n          }\n      }\n#elif defined(KV_Q8_0)\n      for (var elem_idx = local_id.x * NQ; 
elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {\n          let blck_idx = elem_idx / BLOCK_SIZE;\n          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;\n          let k_row = blck_idx / BLOCKS_K;\n          let global_k_row = kv_tile + k_row;\n          let block_k = blck_idx % BLOCKS_K;\n          let row_offset = k_row * HEAD_DIM_QK;\n\n          if (global_k_row < params.seq_len_kv) {\n              let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;\n              let base_idx = global_block_idx * F16_PER_BLOCK;\n              let d = K[base_idx]; // scale\n              for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n                  let q_0 = K[base_idx + 1u + block_offset + j];\n                  let q_1 = K[base_idx + 1u + block_offset + j + 1];\n                  let q_packed = bitcast<u32>(vec2(q_0, q_1));\n                  for (var k = 0u; k < 4u; k++) {\n                      let q_byte = get_byte_i32(q_packed, k);\n                      let q_val = f16(q_byte) * d;\n                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;\n                      kv_shmem[row_offset + idx] = q_val;\n                  }\n              }\n          }\n      }\n#elif defined(KV_DIRECT)\n      // Direct global loads for KV\n#else\n      for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {\n          let k_row = elem_idx / HEAD_DIM_QK;\n          let k_col = elem_idx % HEAD_DIM_QK;\n          let global_k_row = kv_tile + k_row;\n          let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;\n          kv_shmem[elem_idx] = f16(select(\n              0.0,\n              K[global_k_row_offset + k_col],\n              global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));\n      }\n#endif\n\n      workgroupBarrier();\n\n      // accumulate q block * k block into registers across the entire KV tile\n      // TODO: this loop 
seems to be the current largest bottleneck\n      // this bracket exists to scope the lifetime of variables, reducing register pressure\n      {\n#ifdef KV_DIRECT\n          let k_block_row = kv_tile + subgroup_id * SG_MAT_N;\n          var k_global_offset = k_head_offset + k_block_row * params.stride_k1;\n#else\n          var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;\n#endif\n          for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {\n              let inter_offset = kv_block * SG_MAT_N;\n              var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);\n\n              var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);\n\n#ifdef KV_DIRECT\n              var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);\n#else\n              var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);\n#endif\n\n              var t: u32 = 1u;\n              for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {\n                  let h0 = t * SG_MAT_K;\n                  var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);\n#ifdef KV_DIRECT\n                  var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);\n#else\n                  var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);\n#endif\n                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);\n                  q_cur = q0;\n                  k_cur = k0;\n\n                  let h1 = (t + 1u) * SG_MAT_K;\n                  var q1g 
= subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);\n#ifdef KV_DIRECT\n                  var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);\n#else\n                  var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);\n#endif\n                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);\n                  q_cur = q1g;\n                  k_cur = k1g;\n              }\n\n              // handle odd tail\n              if (t < HEAD_DIM_QK / SG_MAT_K) {\n                  let h = t * SG_MAT_K;\n                  var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);\n#ifdef KV_DIRECT\n                  var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);\n#else\n                  var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);\n#endif\n                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);\n                  q_cur = qn;\n                  k_cur = kn;\n              }\n\n              acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);\n\n#ifdef KV_DIRECT\n              k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;\n#else\n              k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;\n#endif\n              subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);\n          }\n      }\n\n\n#ifdef MASK\n      // load mask tile into shared memory for this KV block\n      // TODO: optimize and skip if mask is -INF for the entire tile\n      for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {\n          let mask_row = elem_idx / KV_TILE;\n          let 
mask_col = elem_idx % KV_TILE;\n          let global_q_row = q_row_start + mask_row;\n          let global_k_col = kv_tile + mask_col;\n          let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;\n          let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;\n          mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);\n      }\n#endif\n\n      workgroupBarrier();\n\n      // online softmax\n      for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {\n          let global_q_row = q_row_start + q_tile_row;\n          if (global_q_row >= params.seq_len_q) {\n              break;\n          }\n\n          // initialize running max for this row\n          var prev_max = row_max_shmem[q_tile_row];\n          var final_max = prev_max;\n          // pass 1: compute final max across the full KV tile in chunks\n          for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {\n              let kv_idx = kv_offset + sg_inv_id;\n              let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);\n              final_max = subgroupMax(max(final_max, softmax_term));\n          }\n\n          var total_exp_term: f32 = 0.0;\n          // pass 2: compute exp sum and write P using final_max\n          for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {\n              let kv_idx = kv_offset + sg_inv_id;\n              let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);\n              let cur_p = select(0.0,\n                                 exp(softmax_term - final_max),\n                                 kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);\n              total_exp_term += subgroupAdd(cur_p);\n              if (kv_idx < KV_TILE) {\n                  inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);\n              }\n          }\n\n          let cur_exp = exp(prev_max - 
final_max);\n\n          if (sg_inv_id == 0) {\n              row_max_shmem[q_tile_row] = final_max;\n              exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;\n          }\n\n          for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {\n              let idx = q_tile_row * HEAD_DIM_V + elem_idx;\n              o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);\n          }\n      }\n\n      // load v tile into shared memory\n#if defined(KV_Q4_0)\n      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {\n          let blck_idx = elem_idx / BLOCK_SIZE;\n          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;\n          let v_row = blck_idx / BLOCKS_V;\n          let global_v_row = kv_tile + v_row;\n          let block_k = blck_idx % BLOCKS_V;\n          let row_offset = v_row * HEAD_DIM_V;\n\n          if (global_v_row < params.seq_len_kv) {\n              let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;\n              let base_idx = global_block_idx * F16_PER_BLOCK;\n              let d = V[base_idx]; // scale\n              for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n                  let q_0 = V[base_idx + 1u + block_offset + j];\n                  let q_1 = V[base_idx + 1u + block_offset + j + 1];\n                  let q_packed = bitcast<u32>(vec2(q_0, q_1));\n                  for (var k = 0u; k < 4u; k++) {\n                      let q_byte = get_byte(q_packed, k);\n                      let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;\n                      let q_lo = (f16(q_byte & 0xF) - 8.0) * d;\n                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;\n                      kv_shmem[row_offset + idx] = q_lo;\n                      kv_shmem[row_offset + idx + 16u] = q_hi;\n                  }\n              }\n          }\n      }\n#elif defined(KV_Q8_0)\n      
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {\n          let blck_idx = elem_idx / BLOCK_SIZE;\n          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;\n          let v_row = blck_idx / BLOCKS_V;\n          let global_v_row = kv_tile + v_row;\n          let block_k = blck_idx % BLOCKS_V;\n          let row_offset = v_row * HEAD_DIM_V;\n\n          if (global_v_row < params.seq_len_kv) {\n              let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;\n              let base_idx = global_block_idx * F16_PER_BLOCK;\n              let d = V[base_idx]; // scale\n              for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n                  let q_0 = V[base_idx + 1u + block_offset + j];\n                  let q_1 = V[base_idx + 1u + block_offset + j + 1];\n                  let q_packed = bitcast<u32>(vec2(q_0, q_1));\n                  for (var k = 0u; k < 4u; k++) {\n                      let q_byte = get_byte_i32(q_packed, k);\n                      let q_val = f16(q_byte) * d;\n                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;\n                      kv_shmem[row_offset + idx] = q_val;\n                  }\n              }\n          }\n      }\n#elif defined(KV_DIRECT)\n      // Direct global loads for KV\n#else\n      for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {\n          let v_row = elem_idx / HEAD_DIM_V;\n          let v_col = elem_idx % HEAD_DIM_V;\n          let global_v_row = kv_tile + v_row;\n          let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;\n          kv_shmem[elem_idx] = f16(select(\n              0.0,\n              V[global_v_row_offset + v_col],\n              global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));\n      }\n#endif\n\n      workgroupBarrier();\n\n      // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x 
head_dim_v) in kv_shmem\n      // we want to compute O += P * V across the full KV tile\n      for (var head_dim_block = subgroup_id * SG_MAT_N;\n           head_dim_block < HEAD_DIM_V;\n           head_dim_block += num_subgroups * SG_MAT_N) {\n              // load O submatrix from shared memory\n              var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(\n                  &o_shmem,\n                  head_dim_block,\n                  false,\n                  HEAD_DIM_V\n              );\n              for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {\n                  let p_offset = kv_block * SG_MAT_N;\n                  var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(\n                      &inter_shmem,\n                      p_offset,\n                      false,\n                      KV_TILE\n                  );\n\n                  // load V submatrix from global or shared memory\n#ifdef KV_DIRECT\n                  let v_block_row = kv_tile + kv_block * SG_MAT_N;\n                  let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;\n                  var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(\n                      &V,\n                      v_global_offset,\n                      false,\n                      params.stride_v1\n                  );\n#else\n                  let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;\n                  var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(\n                      &kv_shmem,\n                      v_block_offset + head_dim_block,\n                      false,\n                      HEAD_DIM_V\n                  
);\n#endif\n                  // O += P * V\n                  o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);\n              }\n              // store O back to shared memory\n              subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);\n      }\n      workgroupBarrier();\n    }\n\n#ifdef SINKS\n    // add sinks (applied once after processing all KV tiles)\n    for (var q_tile_row = subgroup_id;\n         q_tile_row < Q_TILE;\n         q_tile_row += num_subgroups) {\n            // no need to process rows beyond seq_len_q\n            let global_q_row = q_row_start + q_tile_row;\n            if (global_q_row >= params.seq_len_q) {\n                break;\n            }\n\n            var prev_max = row_max_shmem[q_tile_row];\n\n            // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum\n            let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);\n            let new_max = subgroupMax(max(prev_max, sink_val));\n            let max_exp = exp(prev_max - new_max);\n            let sink_exp = exp(sink_val - new_max);\n\n            let sink_exp_sum = subgroupAdd(sink_exp);\n\n            if (sg_inv_id == 0) {\n                exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;\n            }\n\n            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {\n                let idx = q_tile_row * HEAD_DIM_V + elem_idx;\n                let val = f32(o_shmem[idx]) * max_exp;\n                o_shmem[idx] = f16(val);\n            }\n    }\n    workgroupBarrier();\n#endif\n    for (var q_tile_row = subgroup_id;\n        q_tile_row < Q_TILE;\n        q_tile_row += num_subgroups) {\n\n        let global_q_row = q_row_start + q_tile_row;\n        if (global_q_row >= params.seq_len_q) { break; }\n\n        let exp_sum = exp_sum_shmem[q_tile_row];\n        let scale = 
select(0.0, 1.0 / exp_sum, exp_sum != 0.0);\n\n        let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;\n\n        for (var elem_base = sg_inv_id * 4u;\n            elem_base < HEAD_DIM_V;\n            elem_base += subgroup_size * 4u) {\n\n            let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);\n            let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);\n            let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);\n            let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);\n\n            let v = vec4<f32>(\n                f32(o_shmem[i0]) * scale,\n                f32(o_shmem[i1]) * scale,\n                f32(o_shmem[i2]) * scale,\n                f32(o_shmem[i3]) * scale\n            );\n\n            let dst_vec_index: u32 = (row_base + elem_base) >> 2u;\n            dst[dst_vec_index] = v;\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/get_rows.wgsl",
    "content": "enable f16;\n#include \"common_decls.tmpl\"\n\n#ifdef F32_VEC\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];\n}\n#endif\n\n#ifdef F32\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    dst[dst_base + offset] = src[src_base + offset];\n}\n#endif\n\n#ifdef F16\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    dst[dst_base + offset] = f32(src[src_base + offset]);\n}\n#endif\n\n#ifdef I32\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    dst[dst_base + offset] = src[src_base + offset];\n}\n#endif\n\n#ifdef Q4_0\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block_q4_0 = src[src_base + offset];\n    let d = f32(block_q4_0.d);\n    for (var j: u32 = 0; j < 4; j++) {\n        let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte(q_packed, k);\n            let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;\n            let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;\n            let dst_offset = dst_base + offset * 32 + j * 4 + k;\n            dst[dst_offset] = q_lo;\n            dst[dst_offset + 16] = q_hi;\n        }\n    }\n}\n#endif\n\n#ifdef Q4_1\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block_q4_1 = src[src_base + offset];\n    let d = f32(block_q4_1.d);\n    let m = f32(block_q4_1.m);\n    for (var j: u32 = 0; j < 4; j++) {\n        let q_packed = block_q4_1.qs[j];\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte(q_packed, k);\n            let q_hi = f32((q_byte >> 4) & 0xF) * d + m;\n            let q_lo = f32(q_byte & 0xF) * d + m;\n            let dst_offset = dst_base + offset * 32 + j * 4 + k;\n            dst[dst_offset] = q_lo;\n            dst[dst_offset + 16] = q_hi;\n        }\n    }\n}\n#endif\n\n#ifdef Q5_0\nfn 
copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block_q5_0 = src[src_base + offset];\n    let d = f32(block_q5_0.d);\n    let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));\n    for (var j: u32 = 0; j < 4; j++) {\n        let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte(q_packed, k);\n            let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;\n            let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;\n            let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;\n            let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;\n            let dst_offset = dst_base + offset * 32 + j * 4 + k;\n            dst[dst_offset] = q_lo;\n            dst[dst_offset + 16] = q_hi;\n        }\n    }\n}\n#endif\n\n#ifdef Q5_1\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block_q5_1 = src[src_base + offset];\n    let d = f32(block_q5_1.d);\n    let m = f32(block_q5_1.m);\n    for (var j: u32 = 0; j < 4; j++) {\n        let q_packed = block_q5_1.qs[j];\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte(q_packed, k);\n            let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;\n            let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;\n            let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;\n            let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;\n            let dst_offset = dst_base + offset * 32 + j * 4 + k;\n            dst[dst_offset] = q_lo;\n            dst[dst_offset + 16] = q_hi;\n        }\n    }\n}\n#endif\n\n#ifdef Q8_0\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block_q8_0 = src[src_base + offset];\n    let d = f32(block_q8_0.d);\n    for (var j: u32 = 0; j < 8; j++) {\n        let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));\n        for (var k: 
u32 = 0; k < 4; k++) {\n            let q_byte = get_byte_i32(q_packed, k);\n            let q_val = f32(q_byte) * d;\n            let dst_offset = dst_base + offset * 32 + j * 4 + k;\n            dst[dst_offset] = q_val;\n        }\n    }\n}\n#endif\n\n#ifdef Q2_K\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    let m = f32(block.dmin);\n    var dst_i = dst_base + offset * 256;\n    var is: u32 = 0;\n    // 2 halves of the block (128 elements each)\n    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {\n        // 4 groups (each group has 2 blocks of 16 elements)\n        for (var shift: u32 = 0; shift < 8; shift += 2) {\n            // 2 blocks\n            for (var k: u32 = 0; k < 32; k += 16) {\n                let sc = get_byte(block.scales[is / 4], is % 4);\n                is++;\n                let dl = d * f32(sc & 0xF);\n                let ml = m * f32(sc >> 4);\n                for (var l: u32 = 0u; l < 16; l++) {\n                    let q_idx = q_b_idx + k + l;\n                    let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);\n                    let qs_val = (q_byte >> shift) & 3;\n                    dst[dst_i] = (f32(qs_val) * dl - ml);\n                    dst_i++;\n                }\n            }\n        }\n    }\n}\n#endif\n\n#ifdef Q3_K\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n\n    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,\n    // and 2-bits from the last 4 bytes\n    let kmask1: u32 = 0x03030303;\n    let kmask2: u32 = 0x0f0f0f0f;\n    var scale_vals: array<u32, 4>;\n    for (var i: u32 = 0; i < 4; i++) {\n        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));\n    }\n    var tmp: u32 = scale_vals[2];\n    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) 
<< 4);\n    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);\n    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n    // convert arrays of f16 -> u32\n    var hmask_vals: array<u32, 8>;\n    for (var i: u32 = 0; i < 8; i++) {\n        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));\n    }\n    var qs_vals: array<u32, 16>;\n    for (var i: u32 = 0; i < 16; i++) {\n        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));\n    }\n\n    var dst_i = dst_base + offset * 256;\n    var is: u32 = 0;\n    var m: u32 = 1;\n    // 2 halves of the block (128 elements each)\n    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {\n        // 4 groups (each group has 2 blocks of 16 elements)\n        for (var shift: u32 = 0; shift < 8; shift += 2) {\n            // 2 blocks\n            for (var k: u32 = 0; k < 32; k += 16) {\n                let sc = get_byte(scale_vals[is / 4], is % 4);\n                is++;\n                let dl = d * (f32(sc) - 32.0);\n                for (var l: u32 = 0u; l < 16u; l++) {\n                    let q_idx = q_b_idx + k + l;\n                    let hm_idx = k + l;\n                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);\n                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);\n                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);\n                    let qs_val = (q_byte >> shift) & 3;\n                    dst[dst_i] = (f32(qs_val) - hm) * dl;\n                    dst_i++;\n                }\n            }\n            m <<= 1;\n        }\n    }\n}\n#endif\n\n#ifdef Q4_K\n// 8 blocks of 32 elements each\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    let m = f32(block.dmin);\n    var dst_i = dst_base + offset 
* 256;\n    var is: u32 = 0;\n    // 2 blocks each iteration\n    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {\n        for (var shift: u32 = 0; shift < 8; shift += 4) {\n            let scale_min = get_scale_min(is, block.scales);\n            is++;\n            let dl = d * scale_min.x;\n            let ml = m * scale_min.y;\n            for (var l: u32 = 0; l < 32; l++) {\n                let q_idx = q_b_idx + l;\n                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);\n                let qs_val = (q_byte >> shift) & 0xF;\n                dst[dst_i] = (f32(qs_val) * dl - ml);\n                dst_i++;\n            }\n        }\n    }\n}\n#endif\n\n#ifdef Q5_K\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    let m = f32(block.dmin);\n    var dst_i = dst_base + offset * 256;\n    var is: u32 = 0;\n    var u: u32 = 1;\n    // 2 blocks each iteration\n    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {\n        for (var shift: u32 = 0; shift < 8; shift += 4) {\n            let scale_min = get_scale_min(is, block.scales);\n            is++;\n            let dl = d * scale_min.x;\n            let ml = m * scale_min.y;\n            for (var l: u32 = 0; l < 32; l++) {\n                let q_idx = q_b_idx + l;\n                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);\n                let qh_byte = get_byte(block.qh[l / 4], l % 4);\n                let qs_val = (q_byte >> shift) & 0xF;\n                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);\n                dst[dst_i] = (f32(qs_val) + qh_val) * dl - ml;\n                dst_i++;\n            }\n            u <<= 1;\n        }\n    }\n}\n#endif\n\n#ifdef Q6_K\n// 16 blocks of 16 elements each\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n\n    // convert arrays of f16 -> u32\n    var 
ql_vals: array<u32, 32>;\n    for (var i: u32 = 0; i < 32; i++) {\n        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));\n    }\n    var qh_vals: array<u32, 16>;\n    for (var i: u32 = 0; i < 16; i++) {\n        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));\n    }\n    var scale_vals: array<u32, 4>;\n    for (var i: u32 = 0; i < 4; i++) {\n        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));\n    }\n\n    var dst_i = dst_base + offset * 256;\n    var qh_b_idx: u32 = 0;\n    var sc_b_idx: u32 = 0;\n    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {\n        for (var l: u32 = 0; l < 32; l++) {\n            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);\n            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);\n            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);\n\n            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;\n            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;\n            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;\n            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;\n\n            let is = l/16;\n            let is1 = sc_b_idx + is;\n            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);\n            let is2 = sc_b_idx + is + 2;\n            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);\n            let is3 = sc_b_idx + is + 4;\n            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);\n            let is4 = sc_b_idx + is + 6;\n            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);\n\n            dst[dst_i + l] = (q1 * f32(sc1)) * d;\n            dst[dst_i + l + 32] = (q2 * f32(sc2)) * d;\n            dst[dst_i + l + 64] = (q3 * f32(sc3)) * d;\n            dst[dst_i + l + 96] = (q4 * f32(sc4)) * d;\n        }\n        dst_i += 128;\n        qh_b_idx 
+= 32;\n        sc_b_idx += 8;\n    }\n}\n#endif\n\n#ifdef IQ2_XXS\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    var dst_i = dst_base + offset * 256;\n    for (var ib: u32 = 0; ib < 32; ib += 4) {\n        let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));\n        let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));\n        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;\n        for (var l: u32 = 0; l < 4; l++) {\n            let ig = get_byte(aux0, l) * 8;\n            let is = (aux1 >> (7 * l)) & 127;\n            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);\n            for (var j: u32 = 0; j < 8; j++) {\n                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);\n                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);\n                dst[dst_i] = db * f32(g) * m;\n                dst_i++;\n            }\n        }\n    }\n}\n#endif\n\n#ifdef IQ2_XS\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    var dst_i = dst_base + offset * 256;\n    var scale_vals = array<u32, 2>(\n        bitcast<u32>(vec2(block.scales[0], block.scales[1])),\n        bitcast<u32>(vec2(block.scales[2], block.scales[3]))\n    );\n    for (var ib: u32 = 0; ib < 32; ib += 4) {\n        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);\n        let db = array<f32, 2>(\n            d * (0.5 + f32(s & 0xF)) * 0.25,\n            d * (0.5 + f32(s >> 4)) * 0.25\n        );\n        for (var l: u32 = 0; l < 4; l++) {\n            let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));\n            let ig = (qs_val & 511) * 8;\n            let is = qs_val >> 9;\n            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);\n            let dl = db[l/2];\n            for (var j: u32 = 0; j < 8; j++) {\n                let g = 
get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);\n                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);\n                dst[dst_i] = dl * f32(g) * m;\n                dst_i++;\n            }\n        }\n    }\n}\n#endif\n\n#ifdef IQ2_S\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    var dst_i = dst_base + offset * 256;\n    var qs_vals : array<u32, 16>;\n    for (var i: u32 = 0; i < 16; i++) {\n        qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));\n    }\n    var qh_vals = array<u32, 2>(\n        bitcast<u32>(vec2(block.qh[0], block.qh[1])),\n        bitcast<u32>(vec2(block.qh[2], block.qh[3]))\n    );\n    var scale_vals = array<u32, 2>(\n        bitcast<u32>(vec2(block.scales[0], block.scales[1])),\n        bitcast<u32>(vec2(block.scales[2], block.scales[3]))\n    );\n    for (var ib: u32 = 0; ib < 8; ib ++) {\n        let s = get_byte(scale_vals[ib / 4], ib % 4);\n        let db = array<f32, 2>(\n            d * (0.5 + f32(s & 0xF)) * 0.25,\n            d * (0.5 + f32(s >> 4)) * 0.25\n        );\n        let qs_w = qs_vals[ib];\n        for (var l: u32 = 0; l < 4; l++) {\n            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;\n            let ig = (get_byte(qs_w, l) | qh_b) * 8;\n            let signs = get_byte(qs_vals[ib + 8], l);\n            let dl = db[l/2];\n            for (var j: u32 = 0; j < 8; j++) {\n                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);\n                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);\n                dst[dst_i] = dl * f32(g) * m;\n                dst_i++;\n            }\n        }\n    }\n}\n#endif\n\n#ifdef IQ3_XXS\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    var dst_i = dst_base + offset * 256;\n  
  for (var ib: u32 = 0; ib < 16; ib += 2) {\n        let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));\n        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;\n        for (var l: u32 = 0; l < 4; l++) {\n            let is = (sc_sign >> (7 * l)) & 127;\n            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);\n            let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));\n            let ig1 = get_byte(ig_val, 0);\n            let ig2 = get_byte(ig_val, 1);\n            for (var j: u32 = 0; j < 4; j++) {\n                let g1 = get_byte(iq3xxs_grid[ig1], j);\n                let g2 = get_byte(iq3xxs_grid[ig2], j);\n                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);\n                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);\n                dst[dst_i] = db * f32(g1) * m1;\n                dst[dst_i + 4] = db * f32(g2) * m2;\n                dst_i++;\n            }\n            dst_i += 4;\n        }\n    }\n}\n#endif\n\n#ifdef IQ3_S\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    var dst_i = dst_base + offset * 256;\n    var qh_vals = array<u32, 2>(\n        bitcast<u32>(vec2(block.qh[0], block.qh[1])),\n        bitcast<u32>(vec2(block.qh[2], block.qh[3]))\n    );\n    var sign_vals: array<u32, 8>;\n    for (var i: u32 = 0; i < 8; i++) {\n        sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));\n    }\n    var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));\n    for (var ib: u32 = 0; ib < 4; ib++) {\n        let s = get_byte(scale_vals, ib);\n        let db = array<f32, 2>(\n            d * (1.0 + 2.0 * f32(s & 0xF)),\n            d * (1.0 + 2.0 * f32(s >> 4))\n        );\n        for (var k: u32 = 0; k < 2; k++) {\n            let dl = db[k];\n            let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);\n           
 let sign_w = sign_vals[ib * 2 + k];\n            for (var l: u32 = 0; l < 4; l++) {\n                let signs = get_byte(sign_w, l);\n                let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));\n                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);\n                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);\n                for (var j: u32 = 0; j < 4; j++) {\n                    let g1 = get_byte(iq3s_grid[ig1], j);\n                    let g2 = get_byte(iq3s_grid[ig2], j);\n                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);\n                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);\n                    dst[dst_i] = dl * f32(g1) * m1;\n                    dst[dst_i + 4] = dl * f32(g2) * m2;\n                    dst_i++;\n                }\n                dst_i += 4;\n            }\n        }\n    }\n}\n#endif\n\n#ifdef IQ1_S\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    var dst_i = dst_base + offset * 256;\n    for (var ib: u32 = 0; ib < 8; ib++) {\n        let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));\n        let dl = d * (2 * f32((qh >> 12) & 7) + 1);\n        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);\n        let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));\n        for (var l: u32 = 0; l < 4; l++) {\n            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;\n            for (var j: u32 = 0; j < 8; j++) {\n                let gw = iq1_grid[(ig + j) / 16];\n                let g = (gw >> (((ig + j) % 16) * 2)) & 3;\n                let gs = bitcast<i32>(g << 30) >> 30;\n                dst[dst_i] = dl * (f32(gs) + delta);\n                dst_i++;\n            }\n        }\n    }\n}\n#endif\n\n#ifdef IQ1_M\nfn copy_elements(src_base: u32, dst_base: u32, 
offset: u32) {\n    let block = src[src_base + offset];\n\n    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);\n    let d = f32(bitcast<vec2<f16>>(scale).x);\n    var dst_i = dst_base + offset * 256;\n    for (var ib: u32 = 0; ib < 8; ib++) {\n        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;\n        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;\n        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;\n        var dl = array<f32, 2>(\n            d * f32(2 * s1 + 1),\n            d * f32(2 * s2 + 1)\n        );\n\n        let qh = block.qh[ib / 2] >> (16 * (ib % 2));\n        var idx = array<u32, 4>(\n            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),\n            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),\n            get_byte(block.qs[ib], 2) | ((qh) & 0x700),\n            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)\n        );\n        var delta = array<f32, 4>(\n            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),\n            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),\n            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),\n            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)\n        );\n        for (var l: u32 = 0; l < 4; l++) {\n            let ig = idx[l] * 8;\n            for (var j: u32 = 0; j < 8; j++) {\n                let gw = iq1_grid[(ig + j) / 16];\n                let g = (gw >> (((ig + j) % 16) * 2)) & 3;\n                let gs = bitcast<i32>(g << 30) >> 30;\n                dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]);\n                dst_i++;\n            }\n        }\n    }\n}\n#endif\n\n#ifdef IQ4_NL\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    var dst_i = dst_base + offset * 32;\n    var qs: array<u32, 4>;\n    for (var i: u32 = 0; i < 4; i++) {\n        qs[i] 
= bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));\n    }\n    for (var j: u32 = 0; j < 16; j++) {\n        let qsb = get_byte(qs[j / 4], j % 4);\n        dst[dst_i] = d * f32(kvalues_iq4nl[qsb & 0xF]);\n        dst[dst_i + 16] = d * f32(kvalues_iq4nl[qsb >> 4]);\n        dst_i++;\n    }\n}\n#endif\n\n#ifdef IQ4_XS\nfn copy_elements(src_base: u32, dst_base: u32, offset: u32) {\n    let block = src[src_base + offset];\n    let d = f32(block.d);\n    let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));\n    var dst_i = dst_base + offset * 256;\n    for (var ib: u32 = 0; ib < 8; ib++) {\n        let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);\n        let dl = d * (f32(ls) - 32.0);\n        for (var j: u32 = 0; j < 16; j++) {\n            let iqs = ib * 16 + j;\n            let qsb = get_byte(block.qs[iqs / 4], iqs % 4);\n            dst[dst_i] = dl * f32(kvalues_iq4nl[qsb & 0xF]);\n            dst[dst_i + 16] = dl * f32(kvalues_iq4nl[qsb >> 4]);\n            dst_i++;\n        }\n        dst_i += 16;\n    }\n}\n#endif\n\n@group(0) @binding(0)\nvar<storage, read_write> src: array<SRC_TYPE>;\n\n@group(0) @binding(1)\nvar<storage, read_write> idx: array<i32>;\n\n@group(0) @binding(2)\nvar<storage, read_write> dst: array<DST_TYPE>;\n\nstruct Params {\n    offset_src: u32, // in elements\n    offset_idx: u32, // in elements\n    offset_dst: u32, // in elements\n\n    // Strides (in elements)\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    stride_idx0: u32,\n    stride_idx1: u32,\n    stride_idx2: u32,\n\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    // Shape of dst\n    ne0: u32,\n    n_rows: u32,\n    ne2: u32,\n    ne3: u32,\n\n    // Shape of idx\n    idx1: u32,\n    idx2: u32,\n};\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if 
(gid.x >= params.n_rows * params.ne2 * params.ne3) {\n        return;\n    }\n    var i = gid.x;\n    let i_dst3 = i / (params.ne2 * params.n_rows);\n\n    i = i % (params.ne2 * params.n_rows);\n    let i_dst2 = i / params.n_rows;\n    let i_dst1 = i % params.n_rows;\n\n    let i_idx2 = i_dst3 % params.idx2;\n    let i_idx1 = i_dst2 % params.idx1;\n    let i_idx0 = i_dst1;\n\n    let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;\n\n    let idx_val = u32(idx[i_idx]);\n\n    let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;\n    let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;\n\n    for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {\n      copy_elements(i_src_row, i_dst_row, i);\n    }\n}\n\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl",
    "content": "#define(VARIANTS)\n\n[\n  {\n    \"SHADER_NAME\": \"reglu_f32\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"REGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"reglu_f32_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"SPLIT\", \"REGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"reglu_f16\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"REGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"reglu_f16_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"SPLIT\", \"REGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_f32\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"GEGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_f32_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"SPLIT\", \"GEGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_f16\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"GEGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_f16_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"SPLIT\", \"GEGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"swiglu_f32\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"SWIGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"swiglu_f32_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"SPLIT\", \"SWIGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"swiglu_f16\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"SWIGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"swiglu_f16_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"SPLIT\", \"SWIGLU\"]\n  },\n  {\n    \"SHADER_NAME\": \"swiglu_oai_f32\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"SWIGLU_OAI\"]\n  },\n  {\n    \"SHADER_NAME\": \"swiglu_oai_f32_split\",\n    
\"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"SPLIT\", \"SWIGLU_OAI\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_erf_f32\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"GEGLU_ERF\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_erf_f32_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"SPLIT\", \"GEGLU_ERF\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_erf_f16\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"GEGLU_ERF\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_erf_f16_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"SPLIT\", \"GEGLU_ERF\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_quick_f32\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"GEGLU_QUICK\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_quick_f32_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"SPLIT\", \"GEGLU_QUICK\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_quick_f16\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"NO_SPLIT\", \"GEGLU_QUICK\"]\n  },\n  {\n    \"SHADER_NAME\": \"geglu_quick_f16_split\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"SPLIT\", \"GEGLU_QUICK\"]\n  },\n]\n\n#end(VARIANTS)\n\n#define(DECLS)\n\n#decl(REGLU)\nfn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {\n    return max(a, 0) * b;\n}\n#enddecl(REGLU)\n\n#decl(GEGLU)\nconst SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;\nconst GELU_COEF_A: {{TYPE}} = 0.044715;\n\nfn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {\n    let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);\n    return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;\n}\n#enddecl(GEGLU)\n\n#decl(SWIGLU)\nfn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {\n    return a / (1.0 + exp(-a)) * b;\n}\n#enddecl(SWIGLU)\n\n#decl(SWIGLU_OAI)\nfn op(a: f32, b: f32) -> f32 {\n  let xi = min(a, 
params.limit);\n  let gi = max(min(b, params.limit), -params.limit);\n  var out_glu = xi / (1.0 + exp(-xi * params.alpha));\n  out_glu = out_glu * (1.0 + gi);\n  return out_glu;\n}\n#enddecl(SWIGLU_OAI)\n\n#decl(GEGLU_ERF)\nconst p_erf: {{TYPE}} = 0.3275911;\nconst a1_erf: {{TYPE}} = 0.254829592;\nconst a2_erf: {{TYPE}} = -0.284496736;\nconst a3_erf: {{TYPE}} = 1.421413741;\nconst a4_erf: {{TYPE}} = -1.453152027;\nconst a5_erf: {{TYPE}} = 1.061405429;\nconst SQRT_2_INV: {{TYPE}} = 0.7071067811865476;\n\nfn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {\n  let a_div_sqr2 = a * SQRT_2_INV;\n  let sign_x = sign(a_div_sqr2);\n  let x = abs(a_div_sqr2);\n  let t = 1.0 / (1.0 + p_erf * x);\n  let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));\n  let erf_approx = sign_x * y;\n  return 0.5 * a * (1.0 + erf_approx) * b;\n}\n#enddecl(GEGLU_ERF)\n\n#decl(GEGLU_QUICK)\nconst GELU_QUICK_COEF: {{TYPE}} = -1.702;\n\nfn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {\n    return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;\n}\n#enddecl(GEGLU_QUICK)\n\n#decl(NO_SPLIT)\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<{{TYPE}}>;\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\nfn a_value(base: u32) -> {{TYPE}} {\n    let offset: u32 = select(0, params.ne0, params.swapped != 0);\n    return src0[base + offset];\n}\n\nfn b_value(base: u32) -> {{TYPE}} {\n    let offset: u32 = select(params.ne0, 0, params.swapped != 0);\n    return src0[base + offset];\n}\n#enddecl(NO_SPLIT)\n\n#decl(SPLIT)\n@group(0) @binding(1)\nvar<storage, read_write> src1: array<{{TYPE}}>;\n\n@group(0) @binding(2)\nvar<storage, read_write> dst: array<{{TYPE}}>;\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n\nfn a_value(base: u32) -> {{TYPE}} {\n    return src0[base];\n}\n\nfn b_value(base: u32) -> {{TYPE}} {\n    return src1[base];\n}\n#enddecl(SPLIT)\n\n#end(DECLS)\n\n#define(SHADER)\n\nenable f16;\n\nstruct Params {\n    offset_src0: 
u32,\n    offset_src1: u32,\n    offset_dst: u32,\n\n    // Strides (in elements)\n    stride_src01: u32,\n    stride_src02: u32,\n    stride_src03: u32,\n\n    stride_src11: u32,\n    stride_src12: u32,\n    stride_src13: u32,\n\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    // shape of dst\n    ne: u32,\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n\n    swapped: u32,\n    alpha: f32,\n    limit: f32,\n}\n\n@group(0) @binding(0)\nvar<storage, read_write> src0: array<{{TYPE}}>;\n\nDECLS\n\noverride wg_size: u32;\n@compute @workgroup_size(wg_size)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if (gid.x >= params.ne) {\n        return;\n    }\n\n    var i = gid.x;\n    let i3 = i / (params.ne2 * params.ne1 * params.ne0);\n    i = i % (params.ne2 * params.ne1 * params.ne0);\n    let i2 = i / (params.ne1 * params.ne0);\n    i = i % (params.ne1 * params.ne0);\n    let i1 = i / params.ne0;\n    let i0 = i % params.ne0;\n\n    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;\n    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;\n    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;\n\n    dst[i_dst] = op(a_value(i_a), b_value(i_b));\n}\n\n#end(SHADER)\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/memset.wgsl",
    "content": "@group(0) @binding(0)\nvar<storage, read_write> output_buffer: array<u32>;\n\nstruct Params {\n    offset: u32, // in bytes\n    size: u32,   // in bytes\n    value: u32,  // 4 8-bit values, which are either repeating (memset_tensor) or may be separate (cleaning up unaligned set_tensor operations)\n};\n\n@group(0) @binding(1)\nvar<uniform> params: Params;\n\noverride wg_size: u32;\noverride bytes_per_thread: u32;\n\n@compute @workgroup_size(wg_size)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    let i = gid.x * bytes_per_thread;\n    let start = params.offset;\n    let end = params.offset + params.size;\n\n    for (var j: u32 = 0u; j < bytes_per_thread; j += 4) {\n        let byte_index = start + i + j;\n        if (byte_index + 4 <= end) {\n            output_buffer[byte_index >> 2] = params.value;\n        } else {\n            // Handle tail (unaligned)\n            for (var k: u32 = 0; k < 4; k++) {\n                let idx = byte_index + k;\n                if (idx < end) {\n                    let word_idx = idx >> 2;\n                    let bit_offset = (idx & 3) * 8u;\n                    let mask = ~(0xffu << bit_offset);\n                    let existing = output_buffer[word_idx];\n                    output_buffer[word_idx] = (existing & mask) | (params.value & (0xffu << bit_offset));\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl",
    "content": "enable f16;\n\n#include \"common_decls.tmpl\"\n\n#ifdef FLOAT\nconst BLOCK_SIZE = 1u;\n\n#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)\nconst BLOCK_SIZE = 32u;\n\n#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)\nconst BLOCK_SIZE = 256u;\n#endif\n\n#ifdef FLOAT\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);\n}\n#endif\n\n#ifdef Q4_0\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block_q4_0 = src0[src0_idx_base + offset];\n    let d = f32(block_q4_0.d);\n    var sum: f32 = 0.0;\n    for (var j: u32 = 0; j < 4; j++) {\n        let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte(q_packed, k);\n            let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;\n            let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;\n            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;\n            sum += q_lo * f32(src1[src1_offset]);\n            sum += q_hi * f32(src1[src1_offset + 16]);\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q4_1\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block_q4_1 = src0[src0_idx_base + offset];\n    let d = f32(block_q4_1.d);\n    let m = f32(block_q4_1.m);\n    var sum: f32 = 0.0;\n    for (var j: u32 = 0; j < 4; j++) {\n        let q_packed = block_q4_1.qs[j];\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte(q_packed, k);\n            let q_hi = f32((q_byte >> 4) & 0xF) * d + m;\n            let q_lo = f32(q_byte & 0xF) 
* d + m;\n            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;\n            sum += q_lo * f32(src1[src1_offset]);\n            sum += q_hi * f32(src1[src1_offset + 16]);\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q5_0\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block_q5_0 = src0[src0_idx_base + offset];\n    let d = f32(block_q5_0.d);\n    var sum: f32 = 0.0;\n    let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));\n    for (var j: u32 = 0; j < 4; j++) {\n        let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte(q_packed, k);\n            let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;\n            let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;\n            let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;\n            let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;\n            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;\n            sum += q_lo * f32(src1[src1_offset]);\n            sum += q_hi * f32(src1[src1_offset + 16]);\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q5_1\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block_q5_1 = src0[src0_idx_base + offset];\n    let d = f32(block_q5_1.d);\n    let m = f32(block_q5_1.m);\n    var sum: f32 = 0.0;\n    for (var j: u32 = 0; j < 4; j++) {\n        let q_packed = block_q5_1.qs[j];\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte(q_packed, k);\n            let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;\n            let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;\n            let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;\n            let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;\n            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;\n            
sum += q_lo * f32(src1[src1_offset]);\n            sum += q_hi * f32(src1[src1_offset + 16]);\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q8_0\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block_q8_0 = src0[src0_idx_base + offset];\n    let d = f32(block_q8_0.d);\n    var sum: f32 = 0.0;\n    for (var j: u32 = 0; j < 8; j++) {\n        let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte_i32(q_packed, k);\n            let q_val = f32(q_byte) * d;\n            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;\n            sum += q_val * f32(src1[src1_offset]);\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q8_1\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block_q8_1 = src0[src0_idx_base + offset];\n    let d = f32(block_q8_1.d);\n    let m = f32(block_q8_1.m);\n    var sum: f32 = 0.0;\n    for (var j: u32 = 0; j < 8; j++) {\n        let q_packed = block_q8_1.qs[j];\n        for (var k: u32 = 0; k < 4; k++) {\n            let q_byte = get_byte_i32(q_packed, k);\n            let q_val = f32(q_byte) * d + m;\n            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;\n            sum += q_val * f32(src1[src1_offset]);\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q2_K\n// 16 blocks of 16 elements each\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    let m = f32(block.dmin);\n    var sum = 0.0;\n    var src1_i = src1_idx_base + offset * 256;\n    var is: u32 = 0;\n    // 2 halves of the block (128 elements each)\n    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {\n        // 4 groups (each group has 2 blocks of 16 elements)\n        for (var shift: u32 = 0; shift < 8; shift += 2) {\n            // 2 blocks\n    
        for (var k: u32 = 0; k < 32; k += 16) {\n                let sc = get_byte(block.scales[is / 4], is % 4);\n                is++;\n                let dl = d * f32(sc & 0xF);\n                let ml = m * f32(sc >> 4);\n                for (var l: u32 = 0u; l < 16; l++) {\n                    let q_idx = q_b_idx + k + l;\n                    let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);\n                    let qs_val = (q_byte >> shift) & 3;\n                    sum += (f32(qs_val) * dl - ml) * src1[src1_i];\n                    src1_i++;\n                }\n            }\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q3_K\n// 16 blocks of 16 elements each\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n\n    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,\n    // and 2-bits from the last 4 bytes\n    let kmask1: u32 = 0x03030303;\n    let kmask2: u32 = 0x0f0f0f0f;\n    var scale_vals: array<u32, 4>;\n    for (var i: u32 = 0; i < 4; i++) {\n        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));\n    }\n    var tmp: u32 = scale_vals[2];\n    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);\n    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n    // convert arrays of f16 -> u32\n    var hmask_vals: array<u32, 8>;\n    for (var i: u32 = 0; i < 8; i++) {\n        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));\n    }\n    var qs_vals: array<u32, 16>;\n    for (var i: u32 = 0; i < 16; i++) {\n        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));\n    }\n\n    var sum = 0.0;\n    var src1_i = src1_idx_base + 
offset * 256;\n    var is: u32 = 0;\n    var m: u32 = 1;\n    // 2 halves of the block (128 elements each)\n    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {\n        // 4 groups (each group has 2 blocks of 16 elements)\n        for (var shift: u32 = 0; shift < 8; shift += 2) {\n            // 2 blocks\n            for (var k: u32 = 0; k < 32; k += 16) {\n                let sc = get_byte(scale_vals[is / 4], is % 4);\n                is++;\n                let dl = d * (f32(sc) - 32.0);\n                for (var l: u32 = 0u; l < 16u; l++) {\n                    let q_idx = q_b_idx + k + l;\n                    let hm_idx = k + l;\n                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);\n                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);\n                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);\n                    let qs_val = (q_byte >> shift) & 3;\n                    sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];\n                    src1_i++;\n                }\n            }\n            m <<= 1;\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q4_K\n// 8 blocks of 32 elements each\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    let m = f32(block.dmin);\n    var sum = 0.0;\n    var src1_i = src1_idx_base + offset * 256;\n    var is: u32 = 0;\n    // 2 blocks each iteration\n    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {\n        for (var shift: u32 = 0; shift < 8; shift += 4) {\n            let scale_min = get_scale_min(is, block.scales);\n            is++;\n            let dl = d * scale_min.x;\n            let ml = m * scale_min.y;\n            for (var l: u32 = 0; l < 32; l++) {\n                let q_idx = q_b_idx + l;\n                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);\n                let qs_val = (q_byte >> shift) & 0xF;\n      
          sum += (f32(qs_val) * dl - ml) * src1[src1_i];\n                src1_i++;\n            }\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q5_K\n// 8 blocks of 32 elements each\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    let m = f32(block.dmin);\n    var sum = 0.0;\n    var src1_i = src1_idx_base + offset * 256;\n    var is: u32 = 0;\n    var u: u32 = 1;\n    // 2 blocks each iteration\n    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {\n        for (var shift: u32 = 0; shift < 8; shift += 4) {\n            let scale_min = get_scale_min(is, block.scales);\n            is++;\n            let dl = d * scale_min.x;\n            let ml = m * scale_min.y;\n            for (var l: u32 = 0; l < 32; l++) {\n                let q_idx = q_b_idx + l;\n                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);\n                let qh_byte = get_byte(block.qh[l / 4], l % 4);\n                let qs_val = (q_byte >> shift) & 0xF;\n                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);\n                sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];\n               src1_i++;\n            }\n            u <<= 1;\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef Q6_K\n// 16 blocks of 16 elements each\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n\n    // convert arrays of f16 -> u32\n    var ql_vals: array<u32, 32>;\n    for (var i: u32 = 0; i < 32; i++) {\n        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));\n    }\n    var qh_vals: array<u32, 16>;\n    for (var i: u32 = 0; i < 16; i++) {\n        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));\n    }\n    var scale_vals: array<u32, 4>;\n    for (var i: u32 = 0; i < 4; i++) {\n        
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));\n    }\n\n    var sum = 0.0;\n    var src1_i = src1_idx_base + offset * 256;\n    var qh_b_idx: u32 = 0;\n    var sc_b_idx: u32 = 0;\n    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {\n        for (var l: u32 = 0; l < 32; l++) {\n            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);\n            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);\n            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);\n\n            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;\n            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;\n            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;\n            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;\n\n            let is = l/16;\n            let is1 = sc_b_idx + is;\n            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);\n            let is2 = sc_b_idx + is + 2;\n            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);\n            let is3 = sc_b_idx + is + 4;\n            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);\n            let is4 = sc_b_idx + is + 6;\n            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);\n\n            sum += d * f32(sc1) * q1 * src1[src1_i + l];\n            sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];\n            sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];\n            sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];\n        }\n        src1_i += 128;\n        qh_b_idx += 32;\n        sc_b_idx += 8;\n    }\n    return sum;\n}\n#endif\n\n#ifdef IQ2_XXS\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    var src1_i = src1_idx_base + offset * 256;\n    var sum = 0.0;\n    for (var ib: u32 = 0; ib < 32; ib 
+= 4) {\n        let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));\n        let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));\n        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;\n        for (var l: u32 = 0; l < 4; l++) {\n            let ig = get_byte(aux0, l) * 8;\n            let is = (aux1 >> (7 * l)) & 127;\n            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);\n            for (var j: u32 = 0; j < 8; j++) {\n                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);\n                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);\n                sum += db * f32(g) * m * src1[src1_i];\n                src1_i++;\n            }\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef IQ2_XS\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    var src1_i = src1_idx_base + offset * 256;\n    var scale_vals = array<u32, 2>(\n        bitcast<u32>(vec2(block.scales[0], block.scales[1])),\n        bitcast<u32>(vec2(block.scales[2], block.scales[3]))\n    );\n    var sum = 0.0;\n    for (var ib: u32 = 0; ib < 32; ib += 4) {\n        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);\n        let db = array<f32, 2>(\n            d * (0.5 + f32(s & 0xF)) * 0.25,\n            d * (0.5 + f32(s >> 4)) * 0.25\n        );\n        for (var l: u32 = 0; l < 4; l++) {\n            let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));\n            let ig = (qs_val & 511) * 8;\n            let is = qs_val >> 9;\n            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);\n            let dl = db[l/2];\n            for (var j: u32 = 0; j < 8; j++) {\n                let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);\n                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);\n                sum += dl * f32(g) * m * src1[src1_i];\n          
      src1_i++;\n            }\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef IQ2_S\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    var src1_i = src1_idx_base + offset * 256;\n    var qs_vals : array<u32, 16>;\n    for (var i: u32 = 0; i < 16; i++) {\n        qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));\n    }\n    var qh_vals = array<u32, 2>(\n        bitcast<u32>(vec2(block.qh[0], block.qh[1])),\n        bitcast<u32>(vec2(block.qh[2], block.qh[3]))\n    );\n    var scale_vals = array<u32, 2>(\n        bitcast<u32>(vec2(block.scales[0], block.scales[1])),\n        bitcast<u32>(vec2(block.scales[2], block.scales[3]))\n    );\n    var sum = 0.0;\n    for (var ib: u32 = 0; ib < 8; ib ++) {\n        let s = get_byte(scale_vals[ib / 4], ib % 4);\n        let db = array<f32, 2>(\n            d * (0.5 + f32(s & 0xF)) * 0.25,\n            d * (0.5 + f32(s >> 4)) * 0.25\n        );\n        let qs_w = qs_vals[ib];\n        for (var l: u32 = 0; l < 4; l++) {\n            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;\n            let ig = (get_byte(qs_w, l) | qh_b) * 8;\n            let signs = get_byte(qs_vals[ib + 8], l);\n            let dl = db[l/2];\n            for (var j: u32 = 0; j < 8; j++) {\n                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);\n                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);\n                sum += dl * f32(g) * m * src1[src1_i];\n                src1_i++;\n            }\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef IQ3_XXS\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    var src1_i = src1_idx_base + offset * 256;\n    var sum = 0.0;\n    for (var ib: u32 = 0; ib < 16; ib += 2) {\n        let 
sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));\n        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;\n        for (var l: u32 = 0; l < 4; l++) {\n            let is = (sc_sign >> (7 * l)) & 127;\n            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);\n            let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));\n            let ig1 = get_byte(ig_val, 0);\n            let ig2 = get_byte(ig_val, 1);\n            for (var j: u32 = 0; j < 4; j++) {\n                let g1 = get_byte(iq3xxs_grid[ig1], j);\n                let g2 = get_byte(iq3xxs_grid[ig2], j);\n                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);\n                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);\n                sum += db * f32(g1) * m1 * src1[src1_i];\n                sum += db * f32(g2) * m2 * src1[src1_i + 4];\n                src1_i++;\n            }\n            src1_i += 4;\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef IQ3_S\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    var src1_i = src1_idx_base + offset * 256;\n    var qh_vals = array<u32, 2>(\n        bitcast<u32>(vec2(block.qh[0], block.qh[1])),\n        bitcast<u32>(vec2(block.qh[2], block.qh[3]))\n    );\n    var sign_vals: array<u32, 8>;\n    for (var i: u32 = 0; i < 8; i++) {\n        sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));\n    }\n    var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));\n    var sum = 0.0;\n    for (var ib: u32 = 0; ib < 4; ib++) {\n        let s = get_byte(scale_vals, ib);\n        let db = array<f32, 2>(\n            d * (1.0 + 2.0 * f32(s & 0xF)),\n            d * (1.0 + 2.0 * f32(s >> 4))\n        );\n        for (var k: u32 = 0; k < 2; k++) {\n            let dl = db[k];\n            let qh_byte = get_byte(qh_vals[ib / 2], 
(ib % 2) * 2 + k);\n            let sign_w = sign_vals[ib * 2 + k];\n            for (var l: u32 = 0; l < 4; l++) {\n                let signs = get_byte(sign_w, l);\n                let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));\n                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);\n                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);\n                for (var j: u32 = 0; j < 4; j++) {\n                    let g1 = get_byte(iq3s_grid[ig1], j);\n                    let g2 = get_byte(iq3s_grid[ig2], j);\n                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);\n                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);\n                    sum += dl * f32(g1) * m1 * src1[src1_i];\n                    sum += dl * f32(g2) * m2 * src1[src1_i + 4];\n                    src1_i++;\n                }\n                src1_i += 4;\n            }\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef IQ1_S\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    var src1_i = src1_idx_base + offset * 256;\n    var sum = 0.0;\n    for (var ib: u32 = 0; ib < 8; ib++) {\n        let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));\n        let dl = d * (2 * f32((qh >> 12) & 7) + 1);\n        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);\n        let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));\n        for (var l: u32 = 0; l < 4; l++) {\n            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;\n            for (var j: u32 = 0; j < 8; j++) {\n                let gw = iq1_grid[(ig + j) / 16];\n                let g = (gw >> (((ig + j) % 16) * 2)) & 3;\n                let gs = bitcast<i32>(g << 30) >> 30;\n                sum += dl * (f32(gs) + delta) * src1[src1_i];\n        
        src1_i++;\n            }\n        }\n    }\n    return sum;\n}\n#endif\n\n\n#ifdef IQ1_M\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n\n    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);\n    let d = f32(bitcast<vec2<f16>>(scale).x);\n    var src1_i = src1_idx_base + offset * 256;\n    var sum = 0.0;\n    for (var ib: u32 = 0; ib < 8; ib++) {\n        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;\n        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;\n        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;\n        var dl = array<f32, 2>(\n            d * f32(2 * s1 + 1),\n            d * f32(2 * s2 + 1)\n        );\n\n        let qh = block.qh[ib / 2] >> (16 * (ib % 2));\n        var idx = array<u32, 4>(\n            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),\n            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),\n            get_byte(block.qs[ib], 2) | ((qh) & 0x700),\n            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)\n        );\n        var delta = array<f32, 4>(\n            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),\n            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),\n            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),\n            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)\n        );\n        for (var l: u32 = 0; l < 4; l++) {\n            let ig = idx[l] * 8;\n            for (var j: u32 = 0; j < 8; j++) {\n                let gw = iq1_grid[(ig + j) / 16];\n                let g = (gw >> (((ig + j) % 16) * 2)) & 3;\n                let gs = bitcast<i32>(g << 30) >> 30;\n                sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];\n                src1_i++;\n            }\n        }\n    }\n    return sum;\n}\n#endif\n\n#ifdef IQ4_NL\nfn multiply_add(src0_idx_base: 
u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    var src1_i = src1_idx_base + offset * 32;\n    var sum = 0.0;\n    var qs: array<u32, 4>;\n    for (var i: u32 = 0; i < 4; i++) {\n        qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));\n    }\n    for (var j: u32 = 0; j < 16; j++) {\n        let qsb = get_byte(qs[j / 4], j % 4);\n        sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];\n        sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];\n        src1_i++;\n    }\n    return sum;\n}\n#endif\n\n#ifdef IQ4_XS\nfn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {\n    let block = src0[src0_idx_base + offset];\n    let d = f32(block.d);\n    let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));\n    var src1_i = src1_idx_base + offset * 256;\n    var sum = 0.0;\n    for (var ib: u32 = 0; ib < 8; ib++) {\n        let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);\n        let dl = d * (f32(ls) - 32.0);\n        for (var j: u32 = 0; j < 16; j++) {\n            let iqs = ib * 16 + j;\n            let qsb = get_byte(block.qs[iqs / 4], iqs % 4);\n            sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];\n            sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];\n            src1_i++;\n        }\n        src1_i += 16;\n    }\n    return sum;\n}\n#endif\n\nstruct MulMatParams {\n    offset_src0: u32, // in elements/blocks\n    offset_src1: u32, // in elements/blocks\n    offset_dst: u32, // in elements/blocks\n    m: u32,\n    n: u32,\n    k: u32,\n    // all strides are in elements/blocks\n    stride_01: u32,\n    stride_11: u32,\n    stride_02: u32,\n    stride_12: u32,\n    stride_03: u32,\n    stride_13: u32,\n\n    bs02: u32,\n    bs03: u32,\n    broadcast2: u32,\n    broadcast3: u32\n};\n\n@group(0) @binding(0) var<storage, read_write> src0: 
array<SRC0_TYPE>; // M rows, K columns\n@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)\n@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns\n\n@group(0) @binding(3) var<uniform> params: MulMatParams;\n\n@compute @workgroup_size(256)\nfn main(@builtin(local_invocation_id) local_id: vec3<u32>,\n        @builtin(workgroup_id) wg_id: vec3<u32>,\n        @builtin(num_workgroups) num_wg: vec3<u32>) {\n    let wg_linear = wg_id.y * num_wg.x + wg_id.x;\n    let global_idx = wg_linear * 256u + local_id.x;\n\n    let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;\n    if (global_idx >= total) {\n        return;\n    }\n\n    let dst2_stride = params.m * params.n;\n    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;\n\n    let dst3_idx = global_idx / dst3_stride;\n    let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension\n    let src13_idx = dst3_idx; // src1 is not broadcast\n    let dst3_rem = global_idx % dst3_stride;\n\n    let dst2_idx = dst3_rem / dst2_stride;\n    let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension\n    let src12_idx = dst2_idx; // src1 is not broadcast\n\n    let dst2_rem = dst3_rem % dst2_stride;\n\n    let row = dst2_rem / params.m; // output row\n    let col = dst2_rem % params.m; // output column\n\n    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;\n    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;\n\n    var sum = 0.0;\n    for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {\n        sum += multiply_add(src0_idx_base, src1_idx_base, i);\n    }\n    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * 
params.m + col] = sum;\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl",
    "content": "#ifdef VEC\n#define VEC_SIZE 4\n#define SHMEM_TYPE vec4<f16>\n#define DST_TYPE vec4<f32>\n#define SRC0_TYPE vec4<SRC0_INNER_TYPE>\n#define SRC1_TYPE vec4<SRC1_INNER_TYPE>\n\nfn store_shmem(val: vec4<f16>, idx: u32) {\n    shmem[idx] = val.x;\n    shmem[idx + 1] = val.y;\n    shmem[idx + 2] = val.z;\n    shmem[idx + 3] = val.w;\n}\n#endif // VEC\n\n#ifdef SCALAR\n#define VEC_SIZE 1\n#define SHMEM_TYPE f16\n#define DST_TYPE f32\n#define SRC0_TYPE SRC0_INNER_TYPE\n#define SRC1_TYPE SRC1_INNER_TYPE\n\nfn store_shmem(val: f16, idx: u32) {\n    shmem[idx] = val;\n}\n#endif // SCALAR\n\n#ifdef INIT_SRC0_SHMEM_FLOAT\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {\n        let tile_m = elem_idx / TILE_K;\n        let tile_k = elem_idx % TILE_K;\n        let global_m = offset_m + tile_m;\n        let global_k = k_outer + tile_k;\n        let src0_idx = batch_offset + global_m * params.stride_01 + global_k;\n        let src0_val = select( // taking a slight performance hit to avoid oob\n            SRC0_TYPE(0.0),\n            src0[src0_idx/VEC_SIZE],\n            global_m < params.m && global_k < params.k);\n        store_shmem(SHMEM_TYPE(src0_val), elem_idx);\n    }\n}\n#endif // INIT_SRC0_SHMEM_FLOAT\n\n#ifdef INIT_SRC1_SHMEM_FLOAT\nfn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {\n    for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {\n        let tile_n = elem_idx / TILE_K;\n        let tile_k = elem_idx % TILE_K;\n        let global_n = offset_n + tile_n;\n        let global_k = k_outer + tile_k;\n        let src1_idx = batch_offset + global_n * params.stride_11 + global_k;\n        let src1_val = select(\n            SRC1_TYPE(0.0),\n            src1[src1_idx/VEC_SIZE],\n            global_n < 
params.n && global_k < params.k);\n        store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);\n    }\n}\n#endif // INIT_SRC1_SHMEM_FLOAT\n\n#ifdef INIT_SRC0_SHMEM_Q4_0\nconst BLOCK_SIZE = 32u;\n// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.\noverride BLOCKS_K = TILE_K/BLOCK_SIZE;\nconst NQ = 16u;\nconst F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights\nconst WEIGHTS_PER_F16 = 4u; // 4 weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16;\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n\n        let tile_m = blck_idx / BLOCKS_K;\n        let global_m = offset_m + tile_m;\n        let block_k = blck_idx % BLOCKS_K;\n        let global_k = k_outer / BLOCK_SIZE + block_k;\n\n        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {\n            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;\n            let scale_idx = src0_idx * F16_PER_BLOCK;\n            let d = src0[scale_idx];\n\n            for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n                let q_0 = src0[scale_idx + 1u + block_offset + j];\n                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];\n\n                let q_packed = bitcast<u32>(vec2(q_0, q_1));\n                for (var k = 0u; k < 4u; k++) {\n                    let q_byte = get_byte(q_packed, k);\n                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;\n                    let q_lo = (f16(q_byte & 0xF) - 8.0) * d;\n                    shmem[shmem_idx + j * 2 + k] = q_lo;\n                    shmem[shmem_idx + j * 2 + k + 16u] = 
q_hi;\n                }\n            }\n        }\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q4_0\n\n#ifdef INIT_SRC0_SHMEM_Q4_1\nconst BLOCK_SIZE = 32u;\n// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.\noverride BLOCKS_K = TILE_K/BLOCK_SIZE;\nconst NQ = 16u;\nconst F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean\nconst WEIGHTS_PER_F16 = 4u; // 4 weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16;\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n\n        let tile_m = blck_idx / BLOCKS_K;\n        let global_m = offset_m + tile_m;\n        let block_k = blck_idx % BLOCKS_K;\n        let global_k = k_outer / BLOCK_SIZE + block_k;\n\n        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {\n            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;\n            let scale_idx = src0_idx * F16_PER_BLOCK;\n            let d = src0[scale_idx];\n            let m = src0[scale_idx + 1u];\n\n            for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n                let q_0 = src0[scale_idx + 2u + block_offset + j];\n                let q_1 = src0[scale_idx + 2u + block_offset + j + 1];\n\n                let q_packed = bitcast<u32>(vec2(q_0, q_1));\n                for (var k = 0u; k < 4u; k++) {\n                    let q_byte = get_byte(q_packed, k);\n                    let q_lo = f16(q_byte & 0xF) * d + m;\n                    let q_hi = f16((q_byte >> 4) & 0xF) * d + m;\n                    shmem[shmem_idx + j * 2 + k] = q_lo;\n                    shmem[shmem_idx + j * 2 + k + 16u] = q_hi;\n          
      }\n            }\n        }\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q4_1\n\n#ifdef INIT_SRC0_SHMEM_Q5_0\n// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block\nconst BLOCK_SIZE = 32u;\n// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.\n// tile_k is defined as 32u, so blocks_k ends up being 1 always\noverride BLOCKS_K = TILE_K / BLOCK_SIZE;\nconst NQ = 16u;\nconst F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights\nconst WEIGHTS_PER_F16 = 4u; // 4 weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n\n    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {\n        let blck_idx    = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let shmem_idx   = blck_idx * BLOCK_SIZE + block_offset * 2u;\n\n        let tile_m   = blck_idx / BLOCKS_K;\n        let global_m = offset_m + tile_m;\n        let block_k  = blck_idx % BLOCKS_K;\n        let global_k = k_outer / BLOCK_SIZE + block_k;\n\n        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {\n            let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;\n            let scale_idx = src0_idx * F16_PER_BLOCK;\n\n            let d  = src0[scale_idx];\n            let qh0 = src0[scale_idx + 1u];\n            let qh1 = src0[scale_idx + 2u];\n            let qh_packed = bitcast<u32>(vec2(qh0, qh1));\n\n            for (var j = 0u; j < 2; j++) {\n                let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];\n                let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];\n\n                let q_packed = bitcast<u32>(vec2(q_0, q_1));\n\n                
let j_adjusted = j + (block_offset / 2u);\n\n\n                for (var k = 0u; k < 4u; k++) {\n                    let q_byte = get_byte(q_packed, k);\n\n                    let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;\n                    let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;\n                    let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;\n                    let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;\n\n                    shmem[shmem_idx + j * 4u + k]        = q_lo; // store first weight\n                    shmem[shmem_idx + j * 4u + k + 16u]  = q_hi; // store second weight\n                }\n            }\n        }\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q5_0\n\n#ifdef INIT_SRC0_SHMEM_Q5_1\n// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block\nconst BLOCK_SIZE = 32u;\n// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.\n// tile_k is defined as 32u, so blocks_k ends up being 1 always\noverride BLOCKS_K = TILE_K / BLOCK_SIZE;\nconst NQ = 16u;\nconst F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean\nconst WEIGHTS_PER_F16 = 4u; // 4 weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n\n    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {\n        let blck_idx    = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let shmem_idx   = blck_idx * BLOCK_SIZE + block_offset * 2u;\n\n        let tile_m   = blck_idx / BLOCKS_K;\n        let global_m = offset_m + tile_m;\n        let block_k  = blck_idx % BLOCKS_K;\n        let global_k = k_outer / BLOCK_SIZE + block_k;\n\n     
   if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {\n            let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;\n            let scale_idx = src0_idx * F16_PER_BLOCK;\n\n            let d  = src0[scale_idx];\n            let m = src0[scale_idx + 1u];\n            let qh0 = src0[scale_idx + 2u];\n            let qh1 = src0[scale_idx + 3u];\n            let qh_packed = bitcast<u32>(vec2(qh0, qh1));\n\n            for (var j = 0u; j < 2; j++) {\n\n                let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];\n                let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];\n\n                let q_packed = bitcast<u32>(vec2(q_0, q_1));\n\n                let j_adjusted = j + (block_offset / 2u);\n\n\n                for (var k = 0u; k < 4u; k++) {\n                    let q_byte = get_byte(q_packed, k);\n\n                    let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;\n                    let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;\n                    let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;\n                    let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;\n\n                    shmem[shmem_idx + j * 4u + k]        = q_lo; // store first weight\n                    shmem[shmem_idx + j * 4u + k + 16u]  = q_hi; // store second weight\n                }\n            }\n        }\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q5_1\n\n#ifdef INIT_SRC0_SHMEM_Q8_0\nconst BLOCK_SIZE = 32u;\n// the number of blocks per k-tile. 
Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.\noverride BLOCKS_K = TILE_K/BLOCK_SIZE;\nconst NQ = 16u;\nconst F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights\nconst WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n\n        let tile_m = blck_idx / BLOCKS_K;\n        let global_m = offset_m + tile_m;\n        let block_k = blck_idx % BLOCKS_K;\n        let global_k = k_outer / BLOCK_SIZE + block_k;\n\n        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {\n            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;\n            let scale_idx = src0_idx * F16_PER_BLOCK;\n            let d = src0[scale_idx];\n\n            for (var j = 0u; j < F16_PER_THREAD; j+=2) {\n                let q_0 = src0[scale_idx + 1u + block_offset + j];\n                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];\n\n                let q_packed = bitcast<u32>(vec2(q_0, q_1));\n                for (var k = 0u; k < 4u; k++) {\n                    let q_byte = get_byte_i32(q_packed, k);\n\n                    let q_val = f16(q_byte) * d;\n                    shmem[shmem_idx + j * 2 + k] = q_val;\n                }\n            }\n        }\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q8_0\n\n#ifdef INIT_SRC0_SHMEM_Q8_1\nconst BLOCK_SIZE = 32u;\n// the number of blocks per k-tile. 
Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.\noverride BLOCKS_K = TILE_K/BLOCK_SIZE;\nconst NQ = 16u;\nconst F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights\nconst WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n\n        let tile_m = blck_idx / BLOCKS_K;\n        let global_m = offset_m + tile_m;\n        let block_k = blck_idx % BLOCKS_K;\n        let global_k = k_outer / BLOCK_SIZE + block_k;\n\n        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {\n            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;\n            let scale_idx = src0_idx * F16_PER_BLOCK;\n            let d = src0[scale_idx];\n            let m = src0[scale_idx + 1u];\n\n            for (var j = 0u; j < F16_PER_THREAD; j+=2) {\n                let q_0 = src0[scale_idx + 2u + block_offset + j];\n                let q_1 = src0[scale_idx + 2u + block_offset + j + 1];\n\n                let q_packed = bitcast<u32>(vec2(q_0, q_1));\n                for (var k = 0u; k < 4u; k++) {\n                    let q_byte = get_byte_i32(q_packed, k);\n\n                    let q_val = f16(q_byte) * d + m;\n                    shmem[shmem_idx + j * 2 + k] = q_val;\n                }\n            }\n        }\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q8_1\n\n#ifdef INIT_SRC0_SHMEM_Q2_K\nconst BLOCK_SIZE = 256u;\nconst F16_PER_BLOCK = 42u;\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) 
{\n    // Use standard thread layout instead of lane/row_group\n    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {\n        let tile_m = elem_idx / TILE_K;\n        let tile_k = elem_idx % TILE_K;\n\n        let global_m = offset_m + tile_m;\n        let global_k = k_outer + tile_k;\n\n        if (global_m >= params.m || global_k >= params.k) {\n            shmem[elem_idx] = f16(0.0);\n            continue;\n        }\n\n        let block_k = global_k / BLOCK_SIZE;\n        let k_in_block = global_k % BLOCK_SIZE;\n\n        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;\n        let scale_idx = src0_idx * F16_PER_BLOCK;\n\n        let d = src0[scale_idx + 40u];\n        let dmin = src0[scale_idx + 41u];\n\n        // Decode the element at position k_in_block\n        let block_of_32 = k_in_block / 32u;\n        let pos_in_32 = k_in_block % 32u;\n\n        let q_b_idx = (block_of_32 / 4u) * 32u;\n        let shift = (block_of_32 % 4u) * 2u;\n        let k = (pos_in_32 / 16u) * 16u;\n        let l = pos_in_32 % 16u;\n\n        let is = k_in_block / 16u;\n\n        let sc_0 = src0[scale_idx + 2u * (is / 4u)];\n        let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];\n        let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));\n        let sc = get_byte(sc_packed, is % 4u);\n\n        let dl = d * f16(sc & 0xFu);\n        let ml = dmin * f16(sc >> 4u);\n\n        let q_idx = q_b_idx + k + l;\n        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];\n        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];\n        let q_packed = bitcast<u32>(vec2(q_0, q_1));\n        let q_byte = get_byte(q_packed, q_idx % 4u);\n        let qs_val = (q_byte >> shift) & 3u;\n\n        let q_val = f16(qs_val) * dl - ml;\n        shmem[elem_idx] = q_val;\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q2_K\n\n#ifdef INIT_SRC0_SHMEM_Q3_K\nconst BLOCK_SIZE = 256u;\nconst F16_PER_BLOCK = 55u;\n\nfn init_shmem_src0(thread_id: 
u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {\n        let tile_m = elem_idx / TILE_K;\n        let tile_k = elem_idx % TILE_K;\n\n        let global_m = offset_m + tile_m;\n        let global_k = k_outer + tile_k;\n\n        if (global_m >= params.m || global_k >= params.k) {\n            shmem[elem_idx] = f16(0.0);\n            continue;\n        }\n\n        let block_k = global_k / BLOCK_SIZE;\n        let k_in_block = global_k % BLOCK_SIZE;\n\n        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;\n        let scale_idx = src0_idx * F16_PER_BLOCK;\n\n        let d = src0[scale_idx + 54u];\n\n        // Load and unpack scales\n        let kmask1: u32 = 0x03030303u;\n        let kmask2: u32 = 0x0f0f0f0fu;\n\n        var scale_vals: array<u32, 4>;\n        for (var i: u32 = 0u; i < 4u; i++) {\n            let scale_0 = src0[scale_idx + 48u + (2u*i)];\n            let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];\n            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));\n        }\n\n        var tmp: u32 = scale_vals[2];\n        scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);\n        scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);\n        scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);\n        scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);\n\n        // Load hmask and qs arrays\n        var hmask_vals: array<u32, 8>;\n        for (var i: u32 = 0u; i < 8u; i++) {\n            let hmask_0 = src0[scale_idx + (2u*i)];\n            let hmask_1 = src0[scale_idx + (2u*i) + 1u];\n            hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));\n        }\n\n        var qs_vals: array<u32, 16>;\n        for (var i: u32 = 0u; i < 16u; i++) {\n            let qs_0 = src0[scale_idx + 16u + (2u*i)];\n            let 
qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];\n            qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));\n        }\n\n        let half = k_in_block / 128u;           // 0 or 1\n        let pos_in_half = k_in_block % 128u;    // 0-127\n        let shift_group = pos_in_half / 32u;    // 0-3\n        let pos_in_32 = pos_in_half % 32u;      // 0-31\n        let k_group = pos_in_32 / 16u;          // 0 or 1\n        let l = pos_in_32 % 16u;                // 0-15\n\n        let q_b_idx = half * 32u;               // 0 or 32\n        let shift = shift_group * 2u;           // 0, 2, 4, 6\n        let k = k_group * 16u;                  // 0 or 16\n        let is = k_in_block / 16u;              // 0-15\n\n        // m increments every 32 elements across entire 256 element block\n        let m_shift = k_in_block / 32u;         // 0-7\n        let m: u32 = 1u << m_shift;             // 1,2,4,8,16,32,64,128\n\n        let sc = get_byte(scale_vals[is / 4u], is % 4u);\n        let dl = d * (f16(sc) - 32.0);\n\n        let q_idx = q_b_idx + k + l;\n        let hm_idx = k + l;\n\n        let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);\n        let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);\n\n        let hm = select(4.0, 0.0, (hmask_byte & m) != 0);\n        let qs_val = (q_byte >> shift) & 3u;\n\n        let q_val = (f16(qs_val) - f16(hm)) * dl;\n        shmem[elem_idx] = q_val;\n    }\n}\n\n#endif // INIT_SRC0_SHMEM_Q3_K\n\n#ifdef INIT_SRC0_SHMEM_Q4_K\nconst BLOCK_SIZE = 256u;\nconst F16_PER_BLOCK = 72u;\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {\n        let tile_m = elem_idx / TILE_K;\n        let tile_k = elem_idx % TILE_K;\n\n        let global_m = offset_m + tile_m;\n        let global_k = k_outer + tile_k;\n\n        if (global_m >= params.m || global_k >= params.k) {\n            shmem[elem_idx] = 
f16(0.0);\n            continue;\n        }\n\n        let block_k = global_k / BLOCK_SIZE;\n        let k_in_block = global_k % BLOCK_SIZE;\n\n        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;\n        let scale_idx = src0_idx * F16_PER_BLOCK;\n\n        let d = src0[scale_idx];\n        let dmin = src0[scale_idx + 1u];\n\n        // Load packed scales\n        var scale_vals: array<u32, 3>;\n        for (var i: u32 = 0u; i < 3u; i++) {\n            let scale_0 = src0[scale_idx + 2u + (2u*i)];\n            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];\n            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));\n        }\n\n        // Map k_in_block to loop structure:\n        // Outer loop over 64-element groups (alternating q_b_idx)\n        // Inner loop over 2 shifts per group\n        let group_of_64 = k_in_block / 64u;  // 0-3 (maps to q_b_idx)\n        let pos_in_64 = k_in_block % 64u;    // 0-63\n        let shift_group = pos_in_64 / 32u;   // 0 or 1\n        let l = pos_in_64 % 32u;             // 0-31\n\n        let q_b_idx = group_of_64 * 32u;     // 0, 32, 64, 96\n        let shift = shift_group * 4u;        // 0 or 4\n        let is = k_in_block / 32u;           // 0-7\n\n        var sc: u32;\n        var mn: u32;\n\n        if (is < 4u) {\n            let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);\n            let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);\n            sc = sc_byte & 63u;\n            mn = min_byte & 63u;\n        } else {\n            let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);\n            let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);\n            let min_hi = get_byte(scale_vals[is / 4u], is % 4u);\n\n            sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);\n            mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);\n        }\n\n        let dl = d * f16(sc);\n        let ml = dmin * f16(mn);\n\n        let q_idx = 
q_b_idx + l;\n        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];\n        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];\n        let q_packed = bitcast<u32>(vec2(q_0, q_1));\n\n        let q_byte = get_byte(q_packed, q_idx % 4u);\n        let qs_val = (q_byte >> shift) & 0xFu;\n\n        let q_val = f16(qs_val) * dl - ml;\n        shmem[elem_idx] = q_val;\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q4_K\n\n#ifdef INIT_SRC0_SHMEM_Q5_K\nconst BLOCK_SIZE = 256u;\nconst F16_PER_BLOCK = 88u;\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {\n        let tile_m = elem_idx / TILE_K;\n        let tile_k = elem_idx % TILE_K;\n\n        let global_m = offset_m + tile_m;\n        let global_k = k_outer + tile_k;\n\n        if (global_m >= params.m || global_k >= params.k) {\n            shmem[elem_idx] = f16(0.0);\n            continue;\n        }\n\n        let block_k = global_k / BLOCK_SIZE;\n        let k_in_block = global_k % BLOCK_SIZE;\n\n        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;\n        let scale_idx = src0_idx * F16_PER_BLOCK;\n\n        let d = src0[scale_idx];\n        let dmin = src0[scale_idx + 1u];\n\n        // Load packed scales\n        var scale_vals: array<u32, 3>;\n        for (var i: u32 = 0u; i < 3u; i++) {\n            let scale_0 = src0[scale_idx + 2u + (2u*i)];\n            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];\n            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));\n        }\n\n        // The original loop processes elements in groups of 64\n        // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]\n        // But u increments EVERY 32 elements (after each l loop)\n        let group_of_64 = k_in_block / 64u;  // 0-3\n        let pos_in_64 = k_in_block % 64u;    // 0-63\n        let shift_group = pos_in_64 / 32u;   // 
0 or 1\n        let l = pos_in_64 % 32u;             // 0-31\n\n        let q_b_idx = group_of_64 * 32u;     // 0, 32, 64, 96\n        let shift = shift_group * 4u;        // 0 or 4\n        let is = k_in_block / 32u;           // 0-7\n\n        // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)\n        let u_shift = k_in_block / 32u;      // 0-7\n        let u: u32 = 1u << u_shift;\n\n        var sc: u32;\n        var mn: u32;\n\n        if (is < 4u) {\n            let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);\n            let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);\n            sc = sc_byte & 63u;\n            mn = min_byte & 63u;\n        } else {\n            let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);\n            let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);\n            let min_hi = get_byte(scale_vals[is / 4u], is % 4u);\n\n            sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);\n            mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);\n        }\n\n        let dl = d * f16(sc);\n        let ml = dmin * f16(mn);\n\n        let q_idx = q_b_idx + l;\n        let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];\n        let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];\n        let q_packed = bitcast<u32>(vec2(q_0, q_1));\n\n        let q_byte = get_byte(q_packed, q_idx % 4u);\n\n        let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];\n        let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];\n        let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));\n\n        let qh_byte = get_byte(qh_packed, l % 4u);\n\n        let qs_val = (q_byte >> shift) & 0xFu;\n        let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);\n\n        let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;\n        shmem[elem_idx] = q_val;\n    }\n}\n\n#endif // INIT_SRC0_SHMEM_Q5_K\n\n#ifdef INIT_SRC0_SHMEM_Q6_K\nconst BLOCK_SIZE = 256u;\nconst F16_PER_BLOCK = 
105u;\n\nfn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {\n    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {\n        let tile_m = elem_idx / TILE_K;\n        let tile_k = elem_idx % TILE_K;\n\n        let global_m = offset_m + tile_m;\n        let global_k = k_outer + tile_k;\n\n        if (global_m >= params.m || global_k >= params.k) {\n            shmem[elem_idx] = f16(0.0);\n            continue;\n        }\n\n        let block_k = global_k / BLOCK_SIZE;\n        let k_in_block = global_k % BLOCK_SIZE;\n\n        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;\n        let scale_idx = src0_idx * F16_PER_BLOCK;\n\n        let half = k_in_block / 128u;\n        let pos_in_half = k_in_block % 128u;\n        let quarter = pos_in_half / 32u;\n        let l = pos_in_half % 32u;\n\n        let ql_b_idx = half * 64u;\n        let qh_b_idx = half * 32u;\n        let sc_b_idx = half * 8u;\n\n        // Load only ql13 word needed\n        let ql13_flat = ql_b_idx + l;\n        let ql13_word = ql13_flat / 4u;\n        let ql13 = bitcast<u32>(vec2(\n            src0[scale_idx + 2u * ql13_word],\n            src0[scale_idx + 2u * ql13_word + 1u]\n        ));\n        let ql13_b = get_byte(ql13, ql13_flat % 4u);\n\n        // Load only ql24 word needed\n        let ql24_flat = ql_b_idx + l + 32u;\n        let ql24_word = ql24_flat / 4u;\n        let ql24 = bitcast<u32>(vec2(\n            src0[scale_idx + 2u * ql24_word],\n            src0[scale_idx + 2u * ql24_word + 1u]\n        ));\n        let ql24_b = get_byte(ql24, ql24_flat % 4u);\n\n        // Load only qh word needed\n        let qh_flat = qh_b_idx + l;\n        let qh_word = qh_flat / 4u;\n        let qh = bitcast<u32>(vec2(\n            src0[scale_idx + 64u + 2u * qh_word],\n            src0[scale_idx + 64u + 2u * qh_word + 1u]\n        ));\n        let qh_b = get_byte(qh, qh_flat % 4u);\n\n        let q1 = 
f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);\n        let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);\n        let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);\n        let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);\n\n        // Load only the scale word needed\n        let is = l / 16u;\n        let sc_idx = sc_b_idx + is + quarter * 2u;\n        let sc_word = sc_idx / 4u;\n        let sc = bitcast<u32>(vec2(\n            src0[scale_idx + 96u + 2u * sc_word],\n            src0[scale_idx + 96u + 2u * sc_word + 1u]\n        ));\n        let sc_val = get_byte_i32(sc, sc_idx % 4u);\n\n        let d = src0[scale_idx + 104u];\n\n        var q_val: f16;\n        if (quarter == 0u) {\n            q_val = q1;\n        } else if (quarter == 1u) {\n            q_val = q2;\n        } else if (quarter == 2u) {\n            q_val = q3;\n        } else {\n            q_val = q4;\n        }\n\n        shmem[elem_idx] = d * f16(sc_val) * q_val;\n    }\n}\n#endif // INIT_SRC0_SHMEM_Q6_K\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl",
    "content": "enable f16;\n\n#include \"common_decls.tmpl\"\n#include \"mul_mat_decls.tmpl\"\n\n#ifdef VEC\nfn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {\n    return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));\n}\n#endif\n\n#ifdef SCALAR\nfn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {\n    return f32(acc[tm][tn]);\n}\n#endif\n\nstruct MulMatParams {\n    offset_src0: u32,\n    offset_src1: u32,\n    offset_dst: u32,\n    m: u32,\n    n: u32,\n    k: u32,\n    stride_01: u32,\n    stride_11: u32,\n    stride_02: u32,\n    stride_12: u32,\n    stride_03: u32,\n    stride_13: u32,\n    bs02: u32,\n    bs03: u32,\n    broadcast2: u32,\n    broadcast3: u32\n};\n\n@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns\n@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)\n@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)\n\n@group(0) @binding(3) var<uniform> params: MulMatParams;\n\nfn get_local_n(thread_id: u32) -> u32 {\n    return thread_id / WORKGROUP_SIZE_M;\n}\nfn get_local_m(thread_id: u32) -> u32 {\n    return thread_id % WORKGROUP_SIZE_M;\n}\n\nconst TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;\nconst TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;\nconst TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;\n\nvar<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;\n\n@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)\nfn main(@builtin(workgroup_id) wg_id: vec3<u32>,\n        @builtin(local_invocation_id) local_id: vec3<u32>,\n        @builtin(num_workgroups) num_wg: vec3<u32>) {\n\n    let thread_id = local_id.x;\n    let local_m = get_local_m(thread_id);\n    let local_n = get_local_n(thread_id);\n\n    let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / 
(WORKGROUP_SIZE_N * TILE_N);\n    let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);\n    let wg_per_matrix = wg_m_count * wg_n_count;\n\n    let wg_linear = wg_id.y * num_wg.x + wg_id.x;\n\n    let batch_idx = wg_linear / wg_per_matrix;\n\n    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;\n    if (batch_idx >= total_batches) {\n        return;\n    }\n\n    let wg_in_batch = wg_linear % wg_per_matrix;\n    let wg_m = wg_in_batch % wg_m_count;\n    let wg_n = wg_in_batch / wg_m_count;\n\n    let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M;\n    let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N;\n\n    let dst2_stride = params.m * params.n;\n    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;\n\n    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);\n    let src03_idx = dst3_idx / params.broadcast3;\n    let src13_idx = dst3_idx;\n    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);\n    let src02_idx = dst2_idx / params.broadcast2;\n    let src12_idx = dst2_idx;\n\n    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;\n    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;\n\n    let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;\n    let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;\n\n    var acc: array<array<f16, TILE_N>, TILE_M>;\n\n    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {\n\n        // see mul_mat_decls.tmpl\n        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);\n        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);\n\n        workgroupBarrier();\n\n        let k_end = min(TILE_K, params.k - k_outer);\n\n        for (var k_inner = 0u; k_inner < k_end; k_inner++) {\n            var src0_tile: array<f16, TILE_M>;\n       
     for (var tm = 0u; tm < TILE_M; tm++) {\n                let src0_m = local_m * TILE_M + tm;\n                let src0_idx = k_inner + src0_m * TILE_K;\n                src0_tile[tm] = shmem[src0_idx];\n            }\n            for (var tn = 0u; tn < TILE_N; tn++) {\n                let src1_n = local_n * TILE_N + tn;\n                let src1_idx = src1_n * TILE_K + k_inner;\n                let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];\n                for (var tm = 0u; tm < TILE_M; tm++) {\n                      acc[tm][tn] += src0_tile[tm] * src1_val;\n                }\n            }\n        }\n\n        workgroupBarrier();\n    }\n\n    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;\n\n    for (var tn = 0u; tn < TILE_N; tn++) {\n        let global_col = output_col_base + tn;\n        if (global_col < params.n) {\n            for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {\n                let global_row = output_row_base + tm;\n                if (global_row < params.m) {\n                    let dst_idx = dst_batch_offset + global_col * params.m + global_row;\n                    dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl",
    "content": "diagnostic(off, chromium.subgroup_matrix_uniformity);\nenable f16;\nenable subgroups;\nenable chromium_experimental_subgroup_matrix;\n\n#include \"common_decls.tmpl\"\n#include \"mul_mat_decls.tmpl\"\n\n#ifdef VEC\nfn store_dst(shmem_idx: u32, dst_idx: u32) {\n    dst[dst_idx] = vec4<f32>(\n        f32(shmem[shmem_idx]),\n        f32(shmem[shmem_idx + 1]),\n        f32(shmem[shmem_idx + 2]),\n        f32(shmem[shmem_idx + 3])\n    );\n}\n#endif\n\n#ifdef SCALAR\nfn store_dst(shmem_idx: u32, dst_idx: u32) {\n    dst[dst_idx] = f32(shmem[shmem_idx]);\n}\n#endif\n\nstruct MulMatParams {\n    offset_src0: u32,\n    offset_src1: u32,\n    offset_dst: u32,\n    m: u32,\n    n: u32,\n    k: u32,\n    stride_01: u32,\n    stride_11: u32,\n    stride_02: u32,\n    stride_12: u32,\n    stride_03: u32,\n    stride_13: u32,\n    bs02: u32,\n    bs03: u32,\n    broadcast2: u32,\n    broadcast3: u32\n};\n\n// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included\n@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns\n@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)\n@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)\n\n@group(0) @binding(3) var<uniform> params: MulMatParams;\n\nconst WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;\nconst WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;\n\n// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the\n// runtime subgroup size is smaller.\nconst EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;\nconst TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;\nconst TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;\nconst TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;\n\nconst 
SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;\n\n// We reuse shmem for accumulation matrices\nconst SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);\n\nvar<workgroup> shmem: array<f16, SHMEM_SIZE>;\n\n@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)\nfn main(@builtin(workgroup_id) wg_id: vec3<u32>,\n        @builtin(local_invocation_id) local_id: vec3<u32>,\n        @builtin(subgroup_id) subgroup_id: u32,\n        @builtin(num_workgroups) num_wg: vec3<u32>) {\n\n    let thread_id = local_id.x;\n    let subgroup_m = subgroup_id % SUBGROUP_M;\n    let subgroup_n = subgroup_id / SUBGROUP_M;\n\n    let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;\n    let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;\n    let wg_per_matrix = wg_m_count * wg_n_count;\n\n    let wg_linear = wg_id.y * num_wg.x + wg_id.x;\n\n    let batch_idx = wg_linear / wg_per_matrix;\n\n    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;\n    if (batch_idx >= total_batches) {\n        return;\n    }\n\n    let wg_in_batch = wg_linear % wg_per_matrix;\n    let wg_m = wg_in_batch % wg_m_count;\n    let wg_n = wg_in_batch / wg_m_count;\n\n    let dst2_stride = params.m * params.n;\n    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;\n\n    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);\n    let src03_idx = dst3_idx / params.broadcast3;\n    let src13_idx = dst3_idx;\n    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);\n    let src02_idx = dst2_idx / params.broadcast2;\n    let src12_idx = dst2_idx;\n\n    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;\n    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;\n\n    let offset_m = wg_m * SUBGROUP_M * 
SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;\n    let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;\n\n    var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;\n\n    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {\n\n        // see mul_mat_decls.tmpl\n        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);\n        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);\n\n        workgroupBarrier();\n\n        if (subgroup_id < EXPECTED_SUBGROUPS) {\n\n            for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {\n\n                let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;\n                var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;\n                for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {\n                    src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(\n                        &shmem,\n                        src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,\n                        false,\n                        TILE_K\n                    );\n                }\n\n                let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;\n                for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {\n                    let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(\n                        &shmem,\n                        src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,\n                        true,\n                        TILE_K\n                    );\n                    for (var m = 0u; m < SUBGROUP_MATRIX_M; 
m++) {\n                        acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);\n                    }\n                }\n            }\n        }\n\n        workgroupBarrier();\n    }\n\n    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;\n\n    // Stage the subgroup matrix tiles into shared memory\n    // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).\n    let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;\n    let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;\n    let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;\n\n    if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(\n        for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {\n            for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {\n                let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;\n                let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;\n                let out_base = local_row * WG_TILE_STRIDE + local_col;\n                subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);\n            }\n        }\n    }\n\n    workgroupBarrier();\n\n    // Cooperative write: iterate over the entire workgroup tile\n    let tile_rows = WG_N_SG_TILE_SIZE;\n    let tile_cols = WG_M_SG_TILE_SIZE;\n    let total_tile_elems = tile_rows * tile_cols;\n    let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;\n    let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;\n\n    for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {\n        let local_row = idx % WG_TILE_STRIDE;\n        let local_col = idx / WG_TILE_STRIDE;\n\n        let global_row = tile_dst_row_base + local_row;\n        let global_col = tile_dst_col_base + 
local_col;\n\n        if (global_col < params.n && global_row < params.m) {\n            let dst_idx = dst_batch_offset + global_col * params.m + global_row;\n            store_dst(idx, dst_idx/VEC_SIZE);\n        }\n    }\n}\n\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl",
    "content": "enable f16;\n\n#include \"common_decls.tmpl\"\n\n#ifdef VEC\n\n#define VEC_SIZE 4\n#define DST_TYPE vec4<f32>\n#define SRC0_TYPE vec4<SRC0_INNER_TYPE>\n#define SRC1_TYPE vec4<SRC1_INNER_TYPE>\n\nfn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {\n    return f32(dot(SRC1_TYPE(src0_val), src1_val));\n}\n\nfn store_val(group_base: u32) -> vec4<f32> {\n    return vec4<f32>(partial_sums[group_base],\n                     partial_sums[group_base + THREADS_PER_OUTPUT],\n                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],\n                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);\n}\n#endif\n\n#ifdef SCALAR\n\n#define VEC_SIZE 1\n#define DST_TYPE f32\n#define SRC0_TYPE SRC0_INNER_TYPE\n#define SRC1_TYPE SRC1_INNER_TYPE\n\nfn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {\n    return f32(src0_val) * f32(src1_val);\n}\n\nfn store_val(group_base: u32) -> f32 {\n    return partial_sums[group_base];\n}\n#endif\n\n#ifdef MUL_ACC_FLOAT\nfn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {\n    var local_sum = 0.0;\n    for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {\n        let a = src0[(idx_base + k_outer + i) / VEC_SIZE];\n        let b = shared_vector[i / VEC_SIZE];\n        local_sum += inner_dot(a, b);\n    }\n    return local_sum;\n}\n#endif\n\n#ifdef MUL_ACC_Q4_0\n\nconst BLOCK_SIZE = 32;\nconst NQ = 16u; // number of weights per thread\nconst F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights\nconst WEIGHTS_PER_F16 = 4u; // 4 weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16;\n\nfn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {\n    var local_sum = 0.0;\n    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * 
F16_PER_BLOCK;\n        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n        let d = f32(src0[scale_idx]);\n        for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n            let q_0 = src0[scale_idx + 1 + block_offset + j];\n            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];\n            let q_packed = bitcast<u32>(vec2(q_0, q_1));\n            for (var k: u32 = 0; k < 4; k++) {\n                let q_byte = get_byte(q_packed, k);\n                let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;\n                let q_lo = (f32(q_byte & 0xF) - 8.0) * d;\n                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];\n                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];\n            }\n        }\n    }\n    return local_sum;\n}\n#endif\n\n#ifdef MUL_ACC_Q4_1\n\nconst BLOCK_SIZE = 32;\nconst NQ = 16u; // number of weights per thread\nconst F16_PER_BLOCK = 10u;\nconst WEIGHTS_PER_F16 = 4u; // 4 weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16;\n\nfn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {\n    var local_sum = 0.0;\n    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;\n        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n        let d = f32(src0[scale_idx]);\n        let m = f32(src0[scale_idx + 1u]);\n        for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n            let q_0 = src0[scale_idx + 2u + block_offset + j];\n            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];\n            let q_packed = 
bitcast<u32>(vec2(q_0, q_1));\n            for (var k: u32 = 0; k < 4; k++) {\n                let q_byte = get_byte(q_packed, k);\n                let q_hi = f32((q_byte >> 4) & 0xF) * d + m;\n                let q_lo = f32(q_byte & 0xF) * d + m;\n                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];\n                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];\n            }\n        }\n    }\n    return local_sum;\n}\n#endif\n\n#ifdef MUL_ACC_Q5_0\n\nconst BLOCK_SIZE = 32;\nconst NQ = 16u; // number of weights per thread\nconst F16_PER_BLOCK = 11u;\nconst WEIGHTS_PER_F16 = 4u; // 4 weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16;\n\nfn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {\n    var local_sum = 0.0;\n    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;\n        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n        let d = f32(src0[scale_idx]);\n        let qh0 = src0[scale_idx + 1u];\n        let qh1 = src0[scale_idx + 2u];\n        let qh_packed = bitcast<u32>(vec2(qh0, qh1));\n\n        for (var j = 0u; j < 2; j++) {\n            let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];\n            let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];\n            let q_packed = bitcast<u32>(vec2(q_0, q_1));\n\n            let j_adjusted = j + (block_offset / 2u);\n\n            for (var k: u32 = 0; k < 4; k++) {\n                let q_byte = get_byte(q_packed, k);\n\n                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;\n                let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;\n                let 
qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;\n                let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;\n\n                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];\n                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];\n            }\n\n        }\n    }\n    return local_sum;\n}\n#endif\n\n\n#ifdef MUL_ACC_Q5_1\n\nconst BLOCK_SIZE = 32;\nconst NQ = 16u; // number of weights per thread\nconst F16_PER_BLOCK = 12u;\nconst WEIGHTS_PER_F16 = 4u; // 4 weights per f16\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16;\n\nfn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {\n    var local_sum = 0.0;\n    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;\n        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n        let d = f32(src0[scale_idx]);\n        let m = src0[scale_idx + 1u];\n        let qh0 = src0[scale_idx + 2u];\n        let qh1 = src0[scale_idx + 3u];\n        let qh_packed = bitcast<u32>(vec2(qh0, qh1));\n\n        for (var j = 0u; j < 2; j++) {\n            let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];\n            let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];\n            let q_packed = bitcast<u32>(vec2(q_0, q_1));\n\n            let j_adjusted = j + (block_offset / 2u);\n\n            for (var k: u32 = 0; k < 4; k++) {\n                let q_byte = get_byte(q_packed, k);\n\n                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;\n                let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);\n                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;\n                
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);\n\n                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];\n                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];\n            }\n\n        }\n    }\n    return local_sum;\n}\n#endif\n\n\n#ifdef MUL_ACC_Q8_0\n\nconst BLOCK_SIZE = 32;\nconst NQ = 16u; // number of weights per thread\nconst F16_PER_BLOCK = 17u;\nconst WEIGHTS_PER_F16 = 2u;\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16;\n\nfn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {\n    var local_sum = 0.0;\n    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {\n        let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;\n        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n        let d = f32(src0[scale_idx]);\n\n        for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n            let q_0 = src0[scale_idx + 1 + block_offset + j];\n            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];\n            let q_packed = bitcast<u32>(vec2(q_0, q_1));\n            for (var k: u32 = 0; k < 4; k++) {\n                let q_byte = get_byte_i32(q_packed, k);\n                let q_val = f32(q_byte) * d;\n                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];\n            }\n        }\n    }\n    return local_sum;\n}\n#endif\n\n\n#ifdef MUL_ACC_Q8_1\n\nconst BLOCK_SIZE = 32;\nconst NQ = 16u; // number of weights per thread\nconst F16_PER_BLOCK = 18u;\nconst WEIGHTS_PER_F16 = 2u;\nconst F16_PER_THREAD = NQ / WEIGHTS_PER_F16;\n\nfn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {\n    var local_sum = 0.0;\n    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {\n      
  let blck_idx = i / BLOCK_SIZE;\n        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;\n        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;\n        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]\n        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;\n        let d = f32(src0[scale_idx]);\n        let m = src0[scale_idx + 1u];\n\n        for (var j = 0u; j < F16_PER_THREAD; j += 2) {\n            let q_0 = src0[scale_idx + 2u + block_offset + j];\n            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];\n            let q_packed = bitcast<u32>(vec2(q_0, q_1));\n            for (var k: u32 = 0; k < 4; k++) {\n                let q_byte = get_byte_i32(q_packed, k);\n                let q_val = f32(q_byte) * d + f32(m);\n                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];\n            }\n        }\n    }\n    return local_sum;\n}\n#endif\n\n#ifdef MUL_ACC_Q6_K\n\nconst BLOCK_SIZE = 256u;\nconst F16_PER_BLOCK = 105u;\n\nfn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {\n    let aligned = byte_offset & ~3u;\n    let idx = bbase + aligned / 2u;\n    return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));\n}\n\nfn byte_of(v: u32, b: u32) -> u32 {\n    return (v >> (b * 8u)) & 0xFFu;\n}\n\nfn sbyte_of(v: u32, b: u32) -> i32 {\n    let raw = i32((v >> (b * 8u)) & 0xFFu);\n    return select(raw, raw - 256, raw >= 128);\n}\n\nfn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {\n    let tid = tig / 2u;\n    let ix  = tig % 2u;\n    let ip  = tid / 8u;\n    let il  = tid % 8u;\n    let l0  = 4u * il;\n    let is  = 8u * ip + l0 / 16u;\n\n    let y_offset   = 128u * ip + l0;\n    let q_offset_l =  64u * ip + l0;\n    let q_offset_h =  32u * ip + l0;\n\n    let nb = tile_size / BLOCK_SIZE;\n    let k_block_start = k_outer / BLOCK_SIZE;\n\n    // Aligned scale byte position (is can be odd)\n    let sc_base_byte 
= 192u + (is & ~3u);\n    let sc_byte_pos  = is & 3u;\n\n    var local_sum = 0.0;\n\n    for (var i = ix; i < nb; i += 2u) {\n        let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;\n\n        let d_raw = load_u32_at(bbase, 208u);\n        let d = f32(bitcast<vec2<f16>>(d_raw)[0]);\n\n        let ql1_u32  = load_u32_at(bbase, q_offset_l);\n        let ql2_u32  = load_u32_at(bbase, q_offset_l + 32u);\n        let qh_u32   = load_u32_at(bbase, 128u + q_offset_h);\n        let sc_u32_0 = load_u32_at(bbase, sc_base_byte);\n        let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);\n\n        let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);\n        let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);\n        let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);\n        let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);\n\n        var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);\n\n        for (var l = 0u; l < 4u; l++) {\n            let y_base = i * BLOCK_SIZE + y_offset + l;\n            let yl0 = f32(shared_vector[y_base]);\n            let yl1 = f32(shared_vector[y_base + 32u]);\n            let yl2 = f32(shared_vector[y_base + 64u]);\n            let yl3 = f32(shared_vector[y_base + 96u]);\n\n            let q1b = byte_of(ql1_u32, l);\n            let q2b = byte_of(ql2_u32, l);\n            let qhb = byte_of(qh_u32,  l);\n\n            let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);\n            let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);\n            let dq2 = f32(i32((q1b >>   4u) | ((qhb & 0x30u)       )) - 32);\n            let dq3 = f32(i32((q2b >>   4u) | ((qhb & 0xC0u) >> 2u)) - 32);\n\n            sums[0] += yl0 * dq0;\n            sums[1] += yl1 * dq1;\n            sums[2] += yl2 * dq2;\n            sums[3] += yl3 * dq3;\n        }\n\n        local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +\n                          sums[2] * f32(sc4) + sums[3] * f32(sc6));\n    }\n\n    return local_sum;\n}\n#endif\n\nstruct MulMatParams 
{\n    offset_src0: u32,\n    offset_src1: u32,\n    offset_dst: u32,\n    m: u32,\n    n: u32,\n    k: u32,\n    stride_01: u32,\n    stride_11: u32,\n    stride_02: u32,\n    stride_12: u32,\n    stride_03: u32,\n    stride_13: u32,\n    bs02: u32,\n    bs03: u32,\n    broadcast2: u32,\n    broadcast3: u32\n};\n\n// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included\n@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns\n@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)\n@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)\n\n@group(0) @binding(3) var<uniform> params: MulMatParams;\n\nconst THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;\n\n// Shared memory for collaborative loading and reduction\nvar<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>;  // Cache vector tile\nvar<workgroup> partial_sums: array<f32, WG_SIZE>;   // For reduction\n\n@compute @workgroup_size(WG_SIZE)\nfn main(\n    @builtin(local_invocation_id) local_id: vec3<u32>,\n    @builtin(workgroup_id) wg_id: vec3<u32>,\n    @builtin(num_workgroups) num_wg: vec3<u32>) {\n    let thread_id = local_id.x;\n\n    // Handle batch dimensions\n    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;\n    let wg_linear = wg_id.y * num_wg.x + wg_id.x;\n    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;\n    let batch_idx = wg_linear / output_groups;\n    if (batch_idx >= total_batches) {\n        return;\n    }\n\n    // Which of the outputs does this thread belong to?\n    let thread_group = thread_id / THREADS_PER_OUTPUT;\n    let thread_in_group = thread_id % THREADS_PER_OUTPUT;\n\n    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs\n    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;\n\n    let dst2_stride = params.m * params.n;\n 
   let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);\n    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;\n    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);\n    let src03_idx = dst3_idx / params.broadcast3;\n    let src13_idx = dst3_idx;\n    let src02_idx = dst2_idx / params.broadcast2;\n    let src12_idx = dst2_idx;\n\n    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;\n    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;\n    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;\n\n    var local_sum = 0.0;\n\n    // Each thread processes multiple K elements and accumulates\n    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {\n        let tile_size = min(TILE_K, params.k - k_tile);\n\n        // Cooperatively load vector tile into shared memory (all threads)\n        for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {\n            shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];\n        }\n\n        workgroupBarrier();\n\n        if (output_row < params.m) {\n            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);\n        }\n\n        workgroupBarrier();\n    }\n\n    // Store partial sums and reduce within each partition\n    partial_sums[thread_id] = local_sum;\n    workgroupBarrier();\n    let group_base = thread_group * THREADS_PER_OUTPUT;\n    let thread_base = group_base + thread_in_group;\n    var offset: u32 = THREADS_PER_OUTPUT / 2;\n    while (offset > 0) {\n        if (thread_in_group < offset) {\n            partial_sums[thread_base] += partial_sums[thread_base + offset];\n        }\n        offset = offset / 2;\n        workgroupBarrier();\n    }\n\n    // Store back to global memory\n    if (output_row < params.m && thread_group % 
VEC_SIZE == 0 && thread_in_group == 0) {\n        dst[dst_idx / VEC_SIZE] = store_val(group_base);\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/pad.wgsl",
    "content": "@group(0) @binding(0)\nvar<storage, read_write> src: array<f32>;\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<f32>;\n\nstruct Params {\n    ne: u32,            // total number of elements\n    offset_src: u32,    // in elements\n    offset_dst: u32,    // in elements\n\n    // Strides (in elements)\n    stride_src0: u32,\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    // Logical shapes\n    src_ne0: u32,\n    src_ne1: u32,\n    src_ne2: u32,\n    src_ne3: u32,\n\n    dst_ne0: u32,\n    dst_ne1: u32,\n    dst_ne2: u32,\n    dst_ne3: u32,\n\n    // Pad sizes (in elements)\n    lp0: u32,\n    rp0: u32,\n    lp1: u32,\n    rp1: u32,\n    lp2: u32,\n    rp2: u32,\n    lp3: u32,\n    rp3: u32,\n};\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\nfn wrap_around(idx: i32, n: u32) -> u32 {\n    return u32(idx + i32(n)) % n;\n}\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if (gid.x >= params.ne) {\n        return;\n    }\n\n    var i = gid.x;\n    let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0;\n    let i3 = i / dst_plane;\n    i = i % dst_plane;\n    let i2 = i / (params.dst_ne1 * params.dst_ne0);\n    i = i % (params.dst_ne1 * params.dst_ne0);\n    let i1 = i / params.dst_ne0;\n    let i0 = i % params.dst_ne0;\n\n    var value: f32 = 0.0;\n\n#ifdef CIRCULAR\n    let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0);\n    let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1);\n    let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2);\n    let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3);\n    let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 +\n                           ci2 * params.stride_src2 + ci3 * params.stride_src3;\n    value = src[params.offset_src + circular_src_idx];\n#else\n    let is_src =\n        (i0 >= params.lp0 && i0 < params.dst_ne0 - 
params.rp0) &&\n        (i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) &&\n        (i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) &&\n        (i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3);\n    if (is_src) {\n        let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 +\n                      (i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3;\n        value = src[params.offset_src + src_idx];\n    }\n#endif\n\n    dst[params.offset_dst + gid.x] = value;\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/repeat.wgsl",
    "content": "enable f16;\n\nstruct Params {\n    ne: u32,\n\n    offset_src0: u32,\n    offset_dst: u32,\n\n    stride_src0_0: u32,\n    stride_src0_1: u32,\n    stride_src0_2: u32,\n    stride_src0_3: u32,\n\n    a_ne0: u32,\n    a_ne1: u32,\n    a_ne2: u32,\n    a_ne3: u32,\n\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n};\n\n#ifdef TYPE_F32\n#define DataType f32\n#endif\n#ifdef TYPE_I32\n#define DataType i32\n#endif\n#ifdef TYPE_I16\n// same size (16-bit) is sufficient for repeat\n#define DataType f16\n#endif\n\n@group(0) @binding(0)\nvar<storage, read_write> src0: array<DataType>;\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<DataType>;\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if (gid.x < params.ne) {\n        var i = gid.x;\n        let i3 = i / (params.ne2 * params.ne1 * params.ne0);\n        i = i % (params.ne2 * params.ne1 * params.ne0);\n        let i2 = i / (params.ne1 * params.ne0);\n        i = i % (params.ne1 * params.ne0);\n        let i1 = i / params.ne0;\n        let i0 = i % params.ne0;\n\n        let a_i0 = i0 % params.a_ne0;\n        let a_i1 = i1 % params.a_ne1;\n        let a_i2 = i2 % params.a_ne2;\n        let a_i3 = i3 % params.a_ne3;\n\n        let a_index = a_i0 * params.stride_src0_0 +\n                           a_i1 * params.stride_src0_1 +\n                           a_i2 * params.stride_src0_2 +\n                           a_i3 * params.stride_src0_3;\n\n        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index];\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl",
    "content": "#define(VARIANTS)\n\n[\n  {\n    \"DECLS\": [\"NOT_INPLACE\"]\n  },\n  {\n    \"SHADER_SUFFIX\": \"inplace\",\n    \"DECLS\": [\"INPLACE\"]\n  },\n]\n\n#end(VARIANTS)\n\n#define(DECLS)\n\n#decl(NOT_INPLACE)\n\nfn update(src_offset: u32, dst_offset: u32, scale: f32) {\n    dst[dst_offset] = scale * src[src_offset];\n}\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<f32>;\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\n#enddecl(NOT_INPLACE)\n\n#decl(INPLACE)\n\nfn update(src_offset: u32, dst_offset: u32, scale: f32) {\n    src[dst_offset] = scale * src[src_offset];\n}\n\n@group(0) @binding(1)\nvar<uniform> params: Params;\n\n#enddecl(INPLACE)\n\n#end(DECLS)\n\n#define(SHADER)\n\nstruct Params {\n    offset_src: u32, // in elements\n    offset_dst: u32, // in elements\n\n    // Strides (in elements)\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    // Shape of src/dst\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n    ne3: u32,\n\n    eps: f32\n};\n\n@group(0) @binding(0)\nvar<storage, read_write> src: array<f32>;\n\nDECLS\n\noverride wg_size: u32;\nvar<workgroup> scratch: array<f32, wg_size>;\n\n@compute @workgroup_size(wg_size)\nfn main(@builtin(workgroup_id) wid: vec3<u32>,\n        @builtin(local_invocation_id) lid: vec3<u32>) {\n\n    // one thread per row\n    var i = wid.x;\n    let i3 = i / (params.ne2 * params.ne1);\n    i = i % (params.ne2 * params.ne1);\n    let i2 = i / params.ne1;\n    let i1 = i % params.ne1;\n    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;\n    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;\n\n    let elems = (params.ne0 + wg_size - 1) / wg_size;\n\n    var sum = 0.0f;\n    var col = lid.x;\n    for (var j: u32 = 0; j < elems; j++) {\n        if (col >= params.ne0) {\n   
         break;\n        }\n        sum += pow(src[i_src_row + col], 2.0);\n        col += wg_size;\n    }\n\n    scratch[lid.x] = sum;\n    workgroupBarrier();\n    var offset = wg_size / 2;\n    while (offset > 0) {\n        if (lid.x < offset) {\n            scratch[lid.x] += scratch[lid.x + offset];\n        }\n        offset = offset / 2;\n        workgroupBarrier();\n    }\n    sum = scratch[0];\n\n    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);\n    col = lid.x;\n    for (var j: u32 = 0; j < elems; j++) {\n        if (col >= params.ne0) {\n            break;\n        }\n        update(i_src_row + col, i_dst_row + col, scale);\n        col += wg_size;\n    }\n}\n#end(SHADER)\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl",
    "content": "#define(VARIANTS)\n\n[\n  {\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"NO_FF_BINDINGS\", \"NO_FF_FUNC\", \"ROTATE\"]\n  },\n  {\n    \"SHADER_SUFFIX\": \"f32_inplace\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"NO_FF_BINDINGS_INPLACE\", \"NO_FF_FUNC\", \"ROTATE_INPLACE\"]\n  },\n  {\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"NO_FF_BINDINGS\", \"NO_FF_FUNC\", \"ROTATE\"]\n  },\n  {\n    \"SHADER_SUFFIX\": \"f16_inplace\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"NO_FF_BINDINGS_INPLACE\", \"NO_FF_FUNC\", \"ROTATE_INPLACE\"]\n  },\n  {\n   \"SHADER_SUFFIX\": \"f32_ff\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"FF_BINDINGS\", \"FF_FUNC\", \"ROTATE\"]\n  },\n  {\n   \"SHADER_SUFFIX\": \"f32_ff_inplace\",\n    \"REPLS\": {\n      \"TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"FF_BINDINGS_INPLACE\", \"FF_FUNC\", \"ROTATE_INPLACE\"]\n  },\n  {\n    \"SHADER_SUFFIX\": \"f16_ff\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"FF_BINDINGS\", \"FF_FUNC\", \"ROTATE\"]\n  },\n  {\n    \"SHADER_SUFFIX\": \"f16_ff_inplace\",\n    \"REPLS\": {\n      \"TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"FF_BINDINGS_INPLACE\", \"FF_FUNC\", \"ROTATE_INPLACE\"]\n  }\n]\n\n#end(VARIANTS)\n\n#define(DECLS)\n\n#decl(ROTATE)\nfn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {\n    dst[i_dst0] = {{TYPE}}(out0);\n    dst[i_dst1] = {{TYPE}}(out1);\n}\n#enddecl(ROTATE)\n\n#decl(ROTATE_INPLACE)\nfn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {\n    src0[i_dst0] = {{TYPE}}(out0);\n    src0[i_dst1] = {{TYPE}}(out1);\n}\n#enddecl(ROTATE_INPLACE)\n\n#decl(NO_FF_FUNC)\nfn freq_factor(i: u32) -> f32 {\n    return 1.0f;\n}\n#enddecl(NO_FF_FUNC)\n\n#decl(FF_FUNC)\nfn freq_factor(i: u32) -> f32 {\n    return src2[params.offset_src2 + i/2];\n}\n#enddecl(FF_FUNC)\n\n#decl(NO_FF_BINDINGS)\n\n@group(0) 
@binding(2)\nvar<storage, read_write> dst: array<{{TYPE}}>;\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n\n#enddecl(NO_FF_BINDINGS)\n\n#decl(NO_FF_BINDINGS_INPLACE)\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\n#enddecl(NO_FF_BINDINGS_INPLACE)\n\n#decl(FF_BINDINGS)\n\n@group(0) @binding(2)\nvar<storage, read_write> src2: array<f32>;\n\n@group(0) @binding(3)\nvar<storage, read_write> dst: array<{{TYPE}}>;\n\n@group(0) @binding(4)\nvar<uniform> params: Params;\n\n#enddecl(FF_BINDINGS)\n\n#decl(FF_BINDINGS_INPLACE)\n\n@group(0) @binding(2)\nvar<storage, read_write> src2: array<f32>;\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n\n#enddecl(FF_BINDINGS_INPLACE)\n\n#end(DECLS)\n\n#define(SHADER)\n\nenable f16;\n\nstruct Params {\n    offset_src0: u32,\n    offset_src1: u32,\n    offset_src2: u32,\n    offset_dst: u32,\n\n    // Strides (in elements)\n    stride_src01: u32,\n    stride_src02: u32,\n    stride_src03: u32,\n\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    n_threads: u32,\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n\n    n_dims: u32,\n    mode: u32,\n    theta_scale: f32,\n    attn_factor: f32,\n    freq_scale: f32,\n    ext_factor: f32,\n    corr_dim0: f32,\n    corr_dim1: f32,\n    sections0: u32,\n    sections1: u32,\n    sections2: u32,\n    sections3: u32\n};\n\n@group(0) @binding(0)\nvar<storage, read_write> src0: array<{{TYPE}}>;\n\n@group(0) @binding(1)\nvar<storage, read_write> src1: array<i32>;\n\nDECLS\n\nfn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {\n    let y = (f32(i / 2) - low) / max(0.001f, high - low);\n    return 1.0f - min(1.0f, max(0.0f, y));\n}\n\n// returns vector of (cos_theta, sin_theta)\n// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row\nfn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {\n    var mscale = params.attn_factor;\n    var theta = params.freq_scale * theta_extrap;\n    if 
(params.ext_factor != 0.0f) {\n        let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;\n        theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;\n        mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);\n    }\n    return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);\n}\n\nfn pair_base(i0: u32, div_2: bool) -> u32 {\n    if (div_2) {\n        return i0 / 2;\n    } else {\n        return i0;\n    }\n}\n\nfn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {\n    if (is_vision) {\n        return params.n_dims;\n    } else if (is_neox || is_mrope) {\n        return params.n_dims / 2;\n    } else {\n        return 1;\n    }\n}\n\noverride wg_size: u32;\n@compute @workgroup_size(wg_size)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    // two elements per thread\n    if (gid.x >= params.n_threads) {\n        return;\n    }\n\n    let is_neox = bool(params.mode & 2);\n    let is_mrope = bool(params.mode & 8);\n    let is_imrope = params.mode == 40;\n    let is_vision = params.mode == 24;\n\n    var i = gid.x * 2; // start index for this thread\n    let i3 = i / (params.ne2 * params.ne1 * params.ne0);\n    i = i % (params.ne2 * params.ne1 * params.ne0);\n    let i2 = i / (params.ne1 * params.ne0);\n    i = i % (params.ne1 * params.ne0);\n    let i1 = i / params.ne0;\n    let i0 = i % params.ne0;\n\n    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;\n    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;\n\n    if (i0 >= params.n_dims && !is_vision) {\n        let i_src = i_src_row + i0;\n        let i_dst = i_dst_row + i0;\n        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));\n        return;\n    }\n\n    var theta_base_mult: u32 = 0;\n    var theta_scale_pwr: u32 = i0 / 2;\n    if (is_mrope) {\n        let sect_dims = 
params.sections0 + params.sections1 + params.sections2 + params.sections3;\n        let sec_w = params.sections1 + params.sections0;\n        let sec_e = params.sections2 + sec_w;\n        let sector = (i0 / 2) % sect_dims;\n        if (is_imrope) {\n          if (sector % 3 == 1 && sector < 3 * params.sections1) {\n              theta_base_mult = 1;\n          } else if (sector % 3 == 2 && sector < 3 * params.sections2) {\n              theta_base_mult = 2;\n          } else if (sector % 3 == 0 && sector < 3 * params.sections0) {\n              theta_base_mult = 0;\n          } else {\n              theta_base_mult = 3;\n          }\n        } else {\n          if (sector >= params.sections0 && sector < sec_w) {\n              theta_base_mult = 1;\n              if (is_vision) {\n                  theta_scale_pwr = sector - params.sections0;\n              }\n          } else if (sector >= sec_w && sector < sec_e) {\n              theta_base_mult = 2;\n              if (is_vision) {\n                  theta_scale_pwr = sector - sec_w;\n              }\n          } else if (sector >= sec_e) {\n              if (is_vision) {\n                  theta_scale_pwr = sector - sec_e;\n                  theta_scale_pwr = (i0 / 2) % sec_e;\n              }\n              theta_base_mult = 3;\n          } else if (is_vision) {\n              theta_scale_pwr = sector;\n          }\n        }\n    }\n    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));\n    let thetas = rope_yarn(theta_base/freq_factor(i0), i0);\n\n    let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);\n    let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);\n\n    let x0 = f32(src0[i_src]);\n    let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);\n    rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * 
thetas.x);\n}\n\n#end(SHADER)\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/scale.wgsl",
    "content": "#ifdef INPLACE\n@group(0) @binding(1)\nvar<uniform> params: Params;\n\nfn store_scale(val: f32, offset: u32) {\n    src[offset] = val;\n}\n#else\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<f32>;\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\nfn store_scale(val: f32, offset: u32) {\n    dst[offset] = val;\n}\n#endif\n\nstruct Params {\n    offset_src: u32,\n    offset_dst: u32,\n\n    // Strides (in elements)\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    ne: u32,\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n\n    scale: f32,\n    bias: f32\n};\n\n@group(0) @binding(0)\nvar<storage, read_write> src: array<f32>;\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if (gid.x >= params.ne) {\n        return;\n    }\n\n    var i = gid.x;\n    let i3 = i / (params.ne2 * params.ne1 * params.ne0);\n    i = i % (params.ne2 * params.ne1 * params.ne0);\n    let i2 = i / (params.ne1 * params.ne0);\n    i = i % (params.ne1 * params.ne0);\n    let i1 = i / params.ne0;\n    let i0 = i % params.ne0;\n\n    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;\n    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;\n\n    store_scale(src[i_src] * params.scale + params.bias, i_dst);\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/set_rows.wgsl",
    "content": "enable f16;\n\n#ifdef DST_F32\n#define DST_INNER_TYPE f32\n#else\n#define DST_INNER_TYPE f16\n#endif\n\n#ifdef VEC4\n#define SRC_TYPE vec4<f32>\n#define DST_TYPE vec4<DST_INNER_TYPE>\n#define VEC_SIZE 4\n#else\n#define SRC_TYPE f32\n#define DST_TYPE DST_INNER_TYPE\n#define VEC_SIZE 1\n#endif\n\n@group(0) @binding(0)\nvar<storage, read_write> src: array<SRC_TYPE>;\n\n@group(0) @binding(1)\nvar<storage, read_write> idx: array<u32>;\n\n@group(0) @binding(2)\nvar<storage, read_write> dst: array<DST_TYPE>;\n\n#ifdef I64_IDX\n@group(0) @binding(3)\nvar<storage, read_write> error: atomic<u32>;\n#define PARAMS_BINDING 4\n#else\n#define PARAMS_BINDING 3\n#endif\n\nstruct Params {\n    offset_src: u32, // in elements\n    offset_idx: u32, // in elements\n    offset_dst: u32, // in elements\n\n    // Strides (in elements)\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    stride_idx0: u32,\n    stride_idx1: u32,\n    stride_idx2: u32,\n\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    // Shape of src\n    ne0: u32,\n    n_rows: u32,\n    ne2: u32,\n    ne3: u32,\n\n    // Shape of idx\n    idx1: u32,\n    idx2: u32,\n};\n\n@group(0) @binding(PARAMS_BINDING)\nvar<uniform> params: Params;\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {\n        return;\n    }\n\n    // getting the row from gid\n    let elems_per_row = params.ne0 / VEC_SIZE;\n    var i = gid.x / elems_per_row;\n\n    let i_src3 = i / (params.ne2 * params.n_rows);\n\n    i = i % (params.ne2 * params.n_rows);\n    let i_src2 = i / params.n_rows;\n    let i_src1 = i % params.n_rows;\n\n    let i_idx2 = i_src3 % params.idx2;\n    let i_idx1 = i_src2 % params.idx1;\n    let i_idx0 = i_src1;\n\n#ifdef I64_IDX\n    let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * 
params.stride_idx2) * 2;\n\n    let idx_val = idx[idx_high];\n    let idx_low_val = idx[idx_high + 1];\n\n    if (idx_low_val != 0) {\n        // Upper bits of index are not zero, output will be incorrect\n        atomicStore(&error, 1);\n        return;\n    }\n#else\n    let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;\n    let idx_val = idx[idx_i];\n#endif\n\n    let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;\n    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;\n\n    let col_idx = (gid.x % elems_per_row);\n    dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl",
    "content": "#define(VARIANTS)\n[\n  {\n    \"SHADER_NAME\": \"soft_max_f32\",\n    \"DECLS\": [\"BASE_BINDINGS\", \"NOT_INPLACE\", \"NO_MASK\", \"NO_SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_inplace\",\n    \"DECLS\": [\"BASE_BINDINGS_INPLACE\", \"INPLACE\", \"NO_MASK\", \"NO_SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_sink\",\n    \"DECLS\": [\"SINK_BINDINGS\", \"NOT_INPLACE\", \"NO_MASK\", \"SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_sink_inplace\",\n    \"DECLS\": [\"SINK_BINDINGS_INPLACE\", \"INPLACE\", \"NO_MASK\", \"SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_mask_f32\",\n    \"REPLS\": {\n      \"MASK_TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"MASK_BINDINGS\", \"NOT_INPLACE\", \"MASK\", \"NO_SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_mask_f32_inplace\",\n    \"REPLS\": {\n      \"MASK_TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"MASK_BINDINGS_INPLACE\", \"INPLACE\", \"MASK\", \"NO_SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_mask_f16\",\n    \"REPLS\": {\n      \"MASK_TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"MASK_BINDINGS\", \"NOT_INPLACE\", \"MASK\", \"NO_SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_mask_f16_inplace\",\n    \"REPLS\": {\n      \"MASK_TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"MASK_BINDINGS_INPLACE\", \"INPLACE\", \"MASK\", \"NO_SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_mask_f32_sink\",\n    \"REPLS\": {\n      \"MASK_TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"MASK_SINK_BINDINGS\", \"NOT_INPLACE\", \"MASK\", \"SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_mask_f32_sink_inplace\",\n    \"REPLS\": {\n      \"MASK_TYPE\" : \"f32\",\n    },\n    \"DECLS\": [\"MASK_SINK_BINDINGS_INPLACE\", \"INPLACE\", \"MASK\", \"SINK\"]\n  },\n  {\n    \"SHADER_NAME\": \"soft_max_f32_mask_f16_sink\",\n    \"REPLS\": {\n      \"MASK_TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"MASK_SINK_BINDINGS\", \"NOT_INPLACE\", \"MASK\", \"SINK\"]\n  
},\n  {\n    \"SHADER_NAME\": \"soft_max_f32_mask_f16_sink_inplace\",\n    \"REPLS\": {\n      \"MASK_TYPE\" : \"f16\",\n    },\n    \"DECLS\": [\"MASK_SINK_BINDINGS_INPLACE\", \"INPLACE\", \"MASK\", \"SINK\"]\n  }\n]\n#end(VARIANTS)\n\n#define(DECLS)\n\n#decl(BASE_BINDINGS)\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<f32>;\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n#enddecl(BASE_BINDINGS)\n\n#decl(BASE_BINDINGS_INPLACE)\n@group(0) @binding(1)\nvar<uniform> params: Params;\n#enddecl(BASE_BINDINGS_INPLACE)\n\n#decl(SINK_BINDINGS)\n@group(0) @binding(1)\nvar<storage, read_write> sinks: array<f32>;\n\n@group(0) @binding(2)\nvar<storage, read_write> dst: array<f32>;\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n#enddecl(SINK_BINDINGS)\n\n#decl(SINK_BINDINGS_INPLACE)\n@group(0) @binding(1)\nvar<storage, read_write> sinks: array<f32>;\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n#enddecl(SINK_BINDINGS_INPLACE)\n\n#decl(MASK_BINDINGS)\n@group(0) @binding(1)\nvar<storage, read_write> mask: array<{{MASK_TYPE}}>;\n\n@group(0) @binding(2)\nvar<storage, read_write> dst: array<f32>;\n\n@group(0) @binding(3)\nvar<uniform> params: Params;\n#enddecl(MASK_BINDINGS)\n\n#decl(MASK_BINDINGS_INPLACE)\n@group(0) @binding(1)\nvar<storage, read_write> mask: array<{{MASK_TYPE}}>;\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n#enddecl(MASK_BINDINGS_INPLACE)\n\n#decl(MASK_SINK_BINDINGS)\n@group(0) @binding(1)\nvar<storage, read_write> mask: array<{{MASK_TYPE}}>;\n\n@group(0) @binding(2)\nvar<storage, read_write> sinks: array<f32>;\n\n@group(0) @binding(3)\nvar<storage, read_write> dst: array<f32>;\n\n@group(0) @binding(4)\nvar<uniform> params: Params;\n#enddecl(MASK_SINK_BINDINGS)\n\n#decl(MASK_SINK_BINDINGS_INPLACE)\n@group(0) @binding(1)\nvar<storage, read_write> mask: array<{{MASK_TYPE}}>;\n\n@group(0) @binding(2)\nvar<storage, read_write> sinks: array<f32>;\n\n@group(0) @binding(3)\nvar<uniform> params: 
Params;\n#enddecl(MASK_SINK_BINDINGS_INPLACE)\n\n#decl(NOT_INPLACE)\nfn inter_value(i: u32) -> f32 {\n    return dst[i];\n}\n\nfn update(i: u32, val: f32) {\n    dst[i] = val;\n}\n#enddecl(NOT_INPLACE)\n\n#decl(INPLACE)\nfn inter_value(i: u32) -> f32 {\n    return src[i];\n}\n\nfn update(i: u32, val: f32) {\n    src[i] = val;\n}\n#enddecl(INPLACE)\n\n#decl(NO_MASK)\nfn mask_val(i: u32) -> f32 {\n    return 0.0;\n}\n#enddecl(NO_MASK)\n\n#decl(MASK)\nfn mask_val(i: u32) -> f32 {\n    return f32(mask[i]);\n}\n#enddecl(MASK)\n\n#decl(NO_SINK)\nfn lower_max_bound(i2: u32) -> f32 {\n    return -1e30;\n}\n\nfn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {\n    return val;\n}\n#enddecl(NO_SINK)\n\n#decl(SINK)\nfn lower_max_bound(i2: u32) -> f32 {\n    return sinks[params.offset_sinks + i2];\n}\n\nfn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {\n    return val + exp(sinks[params.offset_sinks + i2] - max_val);\n}\n#enddecl(SINK)\n\n#end(DECLS)\n\n#define(SHADER)\nenable f16;\n\nstruct Params {\n    offset_src0: u32,\n    offset_src1: u32,\n    offset_sinks: u32,\n    offset_dst: u32,\n\n    // Strides (in elements)\n    stride_src01: u32,\n    stride_src02: u32,\n    stride_src03: u32,\n\n    stride_src11: u32,\n    stride_src12: u32,\n    stride_src13: u32,\n\n    stride_dst1: u32,\n    stride_dst2: u32,\n    stride_dst3: u32,\n\n    // shape of src0/dst\n    ne: u32,\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n\n    // shape of src1\n    ne12: u32,\n    ne13: u32,\n\n    scale: f32,\n    max_bias: f32,\n    n_head_log2: f32,\n    m0: f32,\n    m1: f32,\n};\n\n@group(0) @binding(0)\nvar<storage, read_write> src: array<f32>;\n\nDECLS\n\nconst CACHE_SIZE: u32 = 16;\n\noverride wg_size: u32;\nvar<workgroup> scratch: array<f32, wg_size>;\n\n@compute @workgroup_size(wg_size)\nfn main(@builtin(workgroup_id) wid: vec3<u32>,\n        @builtin(local_invocation_id) lid: vec3<u32>) {\n\n    var i = wid.x;\n    let i3 = i / (params.ne2 * params.ne1);\n    i = i % 
(params.ne2 * params.ne1);\n    let i2 = i / params.ne1;\n    let i1 = i % params.ne1;\n    let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;\n    let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;\n    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;\n    let elems = (params.ne0 + wg_size - 1) / wg_size;\n\n    let head = f32(i2);\n    let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);\n\n    var cache: array<f32, CACHE_SIZE>;\n\n    var max_val = lower_max_bound(i2);\n    var col = lid.x;\n    for (var j: u32 = 0; j < elems; j++) {\n        if (col >= params.ne0) {\n            break;\n        }\n        let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);\n        max_val = max(max_val, val);\n        if (col < CACHE_SIZE) {\n            cache[col] = val;\n        }\n        col += wg_size;\n    }\n\n    scratch[lid.x] = max_val;\n    workgroupBarrier();\n    var offset = wg_size / 2;\n    while (offset > 0) {\n        if (lid.x < offset) {\n            scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);\n        }\n        offset = offset / 2;\n        workgroupBarrier();\n    }\n    let row_max = scratch[0];\n    workgroupBarrier();\n\n    var sum = 0.0f;\n    col = lid.x;\n    for (var j: u32 = 0; j < elems; j++) {\n        if (col >= params.ne0) {\n            break;\n        }\n        let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),\n                         cache[col], col < CACHE_SIZE);\n        let ex = exp(val - row_max);\n        sum += ex;\n        if (col < CACHE_SIZE) {\n            cache[col] = ex;\n        } else {\n            
update(i_dst_row + col, ex);\n        }\n        col += wg_size;\n    }\n\n    scratch[lid.x] = sum;\n    workgroupBarrier();\n    offset = wg_size / 2;\n    while (offset > 0) {\n        if (lid.x < offset) {\n            scratch[lid.x] += scratch[lid.x + offset];\n        }\n        offset = offset / 2;\n        workgroupBarrier();\n    }\n    let row_sum = add_sinks(scratch[0], i2, row_max);\n\n    let sum_recip = 1.0 / row_sum;\n    col = lid.x;\n    for  (var j: u32 = 0; j < elems; j++) {\n        if (col >= params.ne0) {\n            break;\n        }\n        update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);\n        col += wg_size;\n    }\n}\n#end(SHADER)\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl",
    "content": "@group(0) @binding(0)\nvar<storage, read_write> src: array<f32>;\n\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<f32>;\n\nstruct Params {\n    offset_src: u32, // in elements\n    offset_dst: u32, // in elements\n\n    // Strides (in elements)\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    ne0: u32,\n    ne1: u32,\n    ne2: u32\n};\n\n@group(0) @binding(2)\nvar<uniform> params: Params;\n\nvar<workgroup> shared_sum: array<f32, WG_SIZE>;\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(workgroup_id) wid: vec3<u32>,\n        @builtin(local_invocation_id) lid: vec3<u32>) {\n\n    var i = wid.x;\n    let i3 = i / (params.ne2 * params.ne1);\n    i = i % (params.ne2 * params.ne1);\n    let i2 = i / params.ne1;\n    let i1 = i % params.ne1;\n    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;\n    var local_sum: f32 = 0.0;\n    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {\n        local_sum += src[i_src_row + col];\n    }\n    shared_sum[lid.x] = local_sum;\n    workgroupBarrier();\n    // reduce within workgroup\n    var offset: u32 = WG_SIZE >> 1;\n    while (offset > 0) {\n        if (lid.x < offset) {\n            shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset];\n        }\n        workgroupBarrier();\n        offset >>= 1;\n    }\n\n    if (lid.x == 0) {\n        dst[params.offset_dst + wid.x] = shared_sum[0];\n    }\n}\n"
  },
  {
    "path": "src/ggml-webgpu/wgsl-shaders/unary.wgsl",
    "content": "#ifdef TYPE_F16\nenable f16;\n#define TYPE f16\n#else\n#define TYPE f32\n#endif\n\n\n@group(0) @binding(0)\nvar<storage, read_write> src: array<TYPE>;\n\n#ifndef INPLACE\n@group(0) @binding(1)\nvar<storage, read_write> dst: array<TYPE>;\n#define PARAMS_BINDING 2\n#else\n#define PARAMS_BINDING 1\n#endif\n\nstruct Params {\n    ne: u32,            // total number of elements\n    offset_src: u32,    // in elements\n    offset_dst: u32,    // in elements\n\n    // Strides (in elements)\n    stride_src0: u32,\n    stride_src1: u32,\n    stride_src2: u32,\n    stride_src3: u32,\n\n    // Logical shapes\n    ne0: u32,\n    ne1: u32,\n    ne2: u32,\n#ifdef CLAMP\n    clamp_min: f32,\n    clamp_max: f32,\n#endif\n#ifdef FILL\n    fill_val: f32,\n#endif\n#ifdef XIELU\n    alpha_n: f32,\n    alpha_p: f32,\n    beta: f32,\n    eps: f32,\n#endif\n\n};\n\n@group(0) @binding(PARAMS_BINDING)\nvar<uniform> params: Params;\n\n@compute @workgroup_size(WG_SIZE)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n    if (gid.x >= params.ne) {\n      return;\n    }\n    var i = gid.x;\n    let i3 = i / (params.ne2 * params.ne1 * params.ne0);\n    i = i % (params.ne2 * params.ne1 * params.ne0);\n    let i2 = i / (params.ne1 * params.ne0);\n    i = i % (params.ne1 * params.ne0);\n    let i1 = i / params.ne0;\n    let i0 = i % params.ne0;\n\n    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +\n                  i2 * params.stride_src2 + i3 * params.stride_src3;\n\n#ifdef ABS\n    let res = abs(src[params.offset_src + src_idx]);\n#endif\n#ifdef SGN\n    let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),\n                     src[params.offset_src + src_idx] > 0.0);\n#endif\n#ifdef NEG\n    let res = -src[params.offset_src + src_idx];\n#endif\n#ifdef STEP\n    let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));\n#endif\n#ifdef TANH\n    let res = tanh(clamp(src[params.offset_src + 
src_idx], -9.010913, 9.010913));\n#endif\n#ifdef RELU\n    let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);\n#endif\n#ifdef ELU\n    let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],\n                     src[params.offset_src + src_idx] > 0.0);\n#endif\n#ifdef HARDSIGMOID\n    let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));\n#endif\n#ifdef SIGMOID\n    let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));\n#endif\n#ifdef SILU\n    let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));\n#endif\n#ifdef EXP\n    let res = exp(src[params.offset_src + src_idx]);\n#endif\n#ifdef LOG\n    let res = TYPE(log(f32(src[params.offset_src + src_idx])));\n#endif\n#ifdef CLAMP\n    let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));\n#endif\n#ifdef FILL\n    let res = TYPE(params.fill_val);\n#endif\n#ifdef HARDSWISH\n    let res = src[params.offset_src + src_idx] *\n              min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));\n#endif\n#ifdef GELU\n    let res = 0.5 * src[params.offset_src + src_idx] *\n              (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *\n                               (src[params.offset_src + src_idx] +\n                                0.044715 * pow(src[params.offset_src + src_idx], 3.0)),\n                               -9.010913, 9.010913)));\n#endif\n#ifdef GELU_QUICK\n    let res = src[params.offset_src + src_idx] * 0.5 *\n              (1.0 + tanh(clamp(0.79788456 *\n                               (src[params.offset_src + src_idx] +\n                                0.044715 * src[params.offset_src + src_idx] *\n                                    src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),\n                               -9.010913, 9.010913)));\n#endif\n#ifdef GELU_ERF\n    let res = 0.5 * 
src[params.offset_src + src_idx] *\n              (1.0 + tanh(clamp(0.79788456 *\n                               (src[params.offset_src + src_idx] +\n                                0.044715 * src[params.offset_src + src_idx] *\n                                    src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),\n                               -9.010913, 9.010913)));\n#endif\n#ifdef XIELU\n    let res =\n        select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -\n                src[params.offset_src + src_idx]) *\n                   TYPE(params.alpha_n) +\n               TYPE(params.beta) * src[params.offset_src + src_idx],\n               TYPE(params.alpha_p) * src[params.offset_src + src_idx] *\n                   src[params.offset_src + src_idx] +\n                   TYPE(params.beta) * src[params.offset_src + src_idx],\n               src[params.offset_src + src_idx] > 0.0);\n#endif\n#ifdef SOFTPLUS\n    let src_f32 = f32(src[params.offset_src + src_idx]);\n    let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));\n#endif\n#ifdef EXPM1\n    let res = exp(src[params.offset_src + src_idx]) - 1.0;\n#endif\n#ifdef FLOOR\n    let res = floor(src[params.offset_src + src_idx]);\n#endif\n#ifdef CEIL\n    let res = ceil(src[params.offset_src + src_idx]);\n#endif\n#ifdef ROUND\n    let src_f32 = f32(src[params.offset_src + src_idx]);\n    let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);\n    let res = TYPE(result);\n#endif\n#ifdef TRUNC\n    let res = trunc(src[params.offset_src + src_idx]);\n#endif\n#ifdef SQR\n    let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];\n#endif\n#ifdef SQRT\n    let res = sqrt(src[params.offset_src + src_idx]);\n#endif\n#ifdef SIN\n    let res_f32 = sin(f32(src[params.offset_src + src_idx]));\n    let res = TYPE(res_f32);\n#endif\n#ifdef COS\n    let res_f32 = cos(f32(src[params.offset_src + src_idx]));\n    let 
res = TYPE(res_f32);\n#endif\n\n#ifdef INPLACE\n    src[params.offset_src + src_idx] = res;\n#else\n    dst[params.offset_dst + gid.x] = res;\n#endif\n}\n"
  },
  {
    "path": "src/ggml-zdnn/.gitignore",
    "content": "zdnn.h\n"
  },
  {
    "path": "src/ggml-zdnn/CMakeLists.txt",
    "content": "if (DEFINED ZDNN_ROOT)\n    message(STATUS \"zdnn: using ZDNN_ROOT override: ${ZDNN_ROOT}\")\n    set(ZDNN_HINT \"${ZDNN_ROOT}\")\nelse()\n    set(ZDNN_HINT \"\")\nendif()\n\nfind_path(ZDNN_INCLUDE\n            NAMES zdnn.h\n            HINTS ${ZDNN_HINT} /usr /usr/local\n            PATH_SUFFIXES include)\nif (ZDNN_INCLUDE)\n    message(STATUS \"zdnn: found include: ${ZDNN_INCLUDE}\")\nelse()\n    message(FATAL_ERROR \"zdnn: include directory not found, please set ZDNN_ROOT to the proper path if necessary\")\nendif()\n\nfind_library(ZDNN_LIB\n                NAMES zdnn\n                HINTS ${ZDNN_HINT} /usr /usr/local\n                PATH_SUFFIXES lib lib64)\nif (ZDNN_LIB)\n    message(STATUS \"zdnn: found library: ${ZDNN_LIB}\")\nelse()\n    message(FATAL_ERROR \"zdnn: library not found, please set ZDNN_ROOT to the proper path if necessary\")\nendif()\n\nfile(GLOB GGML_SOURCES_ZDNN \"*.c\" \"*.cpp\")\nfile(GLOB GGML_HEADERS_ZDNN \"*.h\" \"*.hpp\")\n\nggml_add_backend_library(ggml-zdnn ${GGML_HEADERS_ZDNN} ${GGML_SOURCES_ZDNN})\ntarget_link_libraries(ggml-zdnn PRIVATE ${ZDNN_LIB})\ntarget_include_directories(ggml-zdnn PRIVATE ${ZDNN_INCLUDE})\ntarget_link_directories(ggml-zdnn PRIVATE ${ZDNN_LIB})\n\ntarget_compile_definitions(ggml-zdnn PRIVATE GGML_USE_ZDNN)\n"
  },
  {
    "path": "src/ggml-zdnn/common.hpp",
    "content": "#ifndef GGML_ZDNN_COMMON_HPP\n#define GGML_ZDNN_COMMON_HPP\n\n#include \"ggml.h\"\n#include \"ggml-impl.h\"\n\n#include \"zdnn.h\"\n\n#include <vector>\n#include <memory>\n\n#define GGML_ZDNN_NAME    \"zDNN\"\n#define GGML_ZDNN_VERSION ZDNN_VERNUM\n\n#define ZDNN_CHECK(stmt)                \\\n    do {                                \\\n        zdnn_status status = (stmt);    \\\n        GGML_ASSERT(status == ZDNN_OK); \\\n    } while (0);\n\nstruct ggml_backend_zdnn_device_context {\n    int zdnn_device;\n    int zdnn_device_ref_count;\n\n    bool has_parmblkformat_0;\n    bool has_parmblkformat_1;  // checks for z17\n\n    size_t max_size;\n\n    char name[128];\n};\n\nstruct ggml_backend_zdnn_context {\n    int device;\n    ggml_cgraph * gf;\n};\n\nstruct ggml_backend_zdnn_buffer {\n    void * data;\n    ggml_backend_zdnn_buffer * extra;  // for bias, etc.\n    size_t size;\n\n    zdnn_tensor_desc pre_tfm_desc;\n    zdnn_tensor_desc tfm_desc;\n    zdnn_ztensor     ztensor;\n\n    char name[GGML_MAX_NAME];\n};\n\nstruct ggml_backend_zdnn_buffer_context {\n    void * all_data;\n    size_t all_size;\n    bool owned;\n\n    int n_buffers;\n    std::vector<std::unique_ptr<ggml_backend_zdnn_buffer>> buffers;\n};\n\n#endif  // GGML_ZDNN_COMMON_HPP\n"
  },
  {
    "path": "src/ggml-zdnn/ggml-zdnn.cpp",
    "content": "#include \"ggml-zdnn.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-backend-impl.h\"\n\n#include \"ggml-zdnn/common.hpp\"\n#include \"ggml-zdnn/mmf.hpp\"\n#include \"ggml-zdnn/utils.hpp\"\n#include \"ggml.h\"\n\n#include <vector>\n#include <memory>\n#include <csignal>  // raise(SIGTRAP)\n#include <unistd.h>\n\nstatic void ggml_zdnn_compute_forward_mul_mat(\n    const ggml_backend_zdnn_context * ctx,\n          ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];  // weights\n    const ggml_tensor * src1 = dst->src[1];  // inputs\n\n    // TODO: implement support for quantized types\n    // we currently only support f32, f16, and bf16\n    ggml_zdnn_mul_mat_f(ctx, src0, src1, dst);\n}\n\nstatic bool ggml_zdnn_compute_forward(\n    ggml_backend_zdnn_context * ctx,\n    ggml_tensor * dst) {\n\n    switch (dst->op) {\n        case GGML_OP_MUL_MAT:\n            {\n                ggml_zdnn_compute_forward_mul_mat(ctx, dst);\n            } break;\n\n        default:\n            return false;\n    }\n\n    return true;\n}\n\nstatic enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgraph * gf) {\n    ggml_backend_zdnn_context        * ctx     = (       ggml_backend_zdnn_context *)backend->context;\n    ggml_backend_zdnn_device_context * ctx_dev = (ggml_backend_zdnn_device_context *)backend->device->context;\n\n    ctx->gf = gf;\n    for (int i = 0; i < gf->n_nodes; i++) {\n        ggml_tensor * node = gf->nodes[i];\n\n        if (ggml_is_empty(node)\n            || node->op == GGML_OP_NONE\n            || node->op == GGML_OP_RESHAPE\n            || node->op == GGML_OP_VIEW\n            || node->op == GGML_OP_PERMUTE\n            || node->op == GGML_OP_TRANSPOSE) {\n            continue;\n        }\n\n        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n            continue;\n        }\n\n        bool ok = ggml_zdnn_compute_forward(ctx, node);\n        if (!ok) {\n            GGML_LOG_ERROR(\"%s: unsupported op 
%s (%s)\\n\",\n                           __func__, node->name, ggml_op_name(node->op));\n        }\n\n        GGML_ASSERT(ok);\n    }\n\n    return GGML_STATUS_SUCCESS;\n\n    GGML_UNUSED(ctx_dev);\n}\n\nstatic bool ggml_zdnn_supports_op(const ggml_backend_zdnn_device_context * ctx_dev, const ggml_tensor * op) {\n    switch (op->op) {\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_TRANSPOSE:\n        case GGML_OP_PERMUTE:\n            return true;\n\n        case GGML_OP_MUL_MAT:\n            {\n                const ggml_tensor * weights = op->src[0];\n                const ggml_tensor * inputs  = op->src[1];\n\n                const int64_t ne10 = inputs->ne[0];\n                const int64_t ne0  = op->ne[0];\n                const int64_t ne1  = op->ne[1];\n\n                const int64_t max_batch = ctx_dev->max_size;\n\n                if (!ggml_is_matrix(weights) || !ggml_is_matrix(inputs) ||\n                    !ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||\n                    weights->view_src != nullptr || inputs->view_src != nullptr ||\n                    ne0 > max_batch || ne1 > max_batch || ne10 > max_batch) {\n                        return false;\n                }\n\n                switch (weights->type) {\n                    case GGML_TYPE_F32:\n                    case GGML_TYPE_F16:\n                    case GGML_TYPE_BF16:\n                        return true;\n                    default:\n                        return false;\n                }\n            } break;\n\n        default:\n            return false;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\n//\n// globals\n//\n\n// initialised in ggml_backend_zdnn_reg\nstatic ggml_backend_reg    g_ggml_backend_zdnn_reg;\nstatic ggml_backend_device g_ggml_backend_zdnn_device;\n\nstatic ggml_backend_zdnn_device_context g_ggml_ctx_dev_main = {\n    /* .zdnn_device 
          = */ 0,\n    /* .zdnn_device_ref_count = */ 0,\n    /* .has_parmblkformat_0   = */ false,\n    /* .has_parmblkformat_1   = */ false,\n    /* .max_size              = */ 0,\n    /* .name                  = */ \"\",\n};\n\nstatic int ggml_backend_zdnn_device_acq(ggml_backend_zdnn_device_context * ctx) {\n    assert(ctx != NULL);\n\n    if (ctx->zdnn_device == 0) {\n        ctx->zdnn_device = 1;\n    }\n\n    if (ctx->zdnn_device >= 1) {\n        ctx->has_parmblkformat_0 = zdnn_is_nnpa_parmblk_fmt_installed(1, NNPA_PARMBLKFORMAT_0);\n        ctx->has_parmblkformat_1 = zdnn_is_nnpa_parmblk_fmt_installed(1, NNPA_PARMBLKFORMAT_1);\n        ctx->max_size = zdnn_get_nnpa_max_dim_idx_size();\n        strncpy(ctx->name, GGML_ZDNN_NAME, sizeof(ctx->name) - 1);\n    }\n\n    ctx->zdnn_device_ref_count++;\n    return ctx->zdnn_device;\n}\n\nstatic void ggml_backend_zdnn_device_rel(ggml_backend_zdnn_device_context * ctx) {\n    assert(ctx != NULL);\n    assert(ctx->zdnn_device_ref_count > 0);\n\n    ctx->zdnn_device_ref_count--;\n    if (ctx->zdnn_device_ref_count == 0) {\n        if (ctx->zdnn_device >= 0) {\n            ctx->zdnn_device = 0;\n        }\n    }\n}\n\nstatic ggml_backend_zdnn_context * ggml_zdnn_init(ggml_backend_dev_t dev) {\n    GGML_LOG_INFO(\"%s: allocating\\n\", __func__);\n    GGML_LOG_INFO(\"%s: found 1 device\\n\", __func__);\n\n    #ifdef STATIC_LIB\n    zdnn_init();\n    #endif\n\n    ggml_backend_zdnn_context * ctx = new ggml_backend_zdnn_context();\n    ggml_backend_zdnn_device_context * ctx_dev = (ggml_backend_zdnn_device_context *)dev->context;\n\n    int device = 1;\n    GGML_LOG_INFO(\"%s: picking default device: %s\\n\", __func__, ctx_dev->name);\n\n    ctx->device = device;\n    GGML_LOG_INFO(\"%s: NNPA name: %s\\n\", __func__, ctx_dev->name);\n    GGML_LOG_INFO(\"%s: NNPA_PARMBLKFORMAT_0 = %s\\n\", __func__, ctx_dev->has_parmblkformat_0 ? 
\"true\" : \"false\");\n    GGML_LOG_INFO(\"%s: NNPA_PARMBLKFORMAT_1 = %s\\n\", __func__, ctx_dev->has_parmblkformat_1 ? \"true\" : \"false\");\n\n    ctx->gf = nullptr;\n\n    return ctx;\n}\n\nstatic void ggml_zdnn_free(ggml_backend_zdnn_context * ctx) {\n    GGML_LOG_INFO(\"%s: deallocating\\n\", __func__);\n    delete ctx;\n}\n\n//\n// backend interface\n//\n\nstatic void ggml_backend_zdnn_buffer_free_buffer(ggml_backend_buffer_t buffer) {\n    ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context;\n\n    for (const auto & buf_ptr : ctx->buffers) {\n        ggml_backend_zdnn_buffer * buf = buf_ptr.get();\n\n        // Free any extra buffer allocated for the tensor. E.g., bias for GGML_OP_MUL_MAT\n        if (buf->extra != nullptr) free(buf->extra->data);\n        if (buf->ztensor.buffer_size > 0) ZDNN_CHECK(zdnn_free_ztensor_buffer(&buf->ztensor));\n    }\n\n    delete ctx;\n}\n\nstatic void * ggml_backend_zdnn_buffer_get_base(ggml_backend_buffer_t buffer) {\n    ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context;\n    return ctx->all_data;\n}\n\nstatic enum ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {\n    if (tensor->view_src != NULL) {\n        assert(tensor->view_src->buffer->buft == buffer->buft);\n        return GGML_STATUS_SUCCESS;\n    }\n\n    ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context;\n\n    const int64_t tsize = ggml_nbytes(tensor);\n    int buffer_idx = ctx->n_buffers;\n\n    std::unique_ptr<ggml_backend_zdnn_buffer> zdnn_buffer = std::make_unique<ggml_backend_zdnn_buffer>();\n    zdnn_buffer->data = tensor->data;\n    zdnn_buffer->size = tsize;\n    zdnn_buffer->extra = nullptr;\n    snprintf(zdnn_buffer->name, GGML_MAX_NAME, \"%s\", tensor->name);\n\n    ggml_zdnn_init_tensor(zdnn_buffer.get(), tensor);\n    tensor->extra = zdnn_buffer.get();\n\n    switch 
(tensor->op) {\n        case GGML_OP_MUL_MAT:\n            {\n                std::unique_ptr<ggml_backend_zdnn_buffer> zdnn_bias_buffer = std::make_unique<ggml_backend_zdnn_buffer>();\n                zdnn_bias_buffer->data = (void *)calloc(tensor->ne[0], ggml_element_size(tensor));\n                zdnn_bias_buffer->size = ggml_element_size(tensor) * tensor->ne[0];\n                snprintf(zdnn_bias_buffer->name, GGML_MAX_NAME, \"%.*s (bias)\",\n                         GGML_MAX_NAME - (int)sizeof(\" (bias)\"), tensor->name);\n\n                const int64_t bias_dim[GGML_MAX_DIMS] = { 1, 1, 1, tensor->ne[0] };\n                ggml_zdnn_create_tensor(zdnn_bias_buffer->pre_tfm_desc,\n                                        zdnn_bias_buffer->tfm_desc,\n                                        zdnn_bias_buffer->ztensor,\n                                        tensor, bias_dim, ZDNN_1D);\n\n                ggml_zdnn_load_tensor(zdnn_bias_buffer->ztensor, zdnn_bias_buffer->data);\n                zdnn_buffer->extra = zdnn_bias_buffer.get();\n\n                ctx->buffers.push_back(std::move(zdnn_bias_buffer));\n                ctx->n_buffers++;\n            } break;\n        default:\n            break;\n    }\n\n    ctx->buffers.push_back(std::move(zdnn_buffer));\n    ctx->n_buffers++;\n\n    // GGML_LOG_INFO(\"%s: initialised tensor '%s' in buffer %d, size = %8.2f MiB\\n\",\n    //               __func__, tensor->name, buffer_idx, tsize);\n\n    return GGML_STATUS_SUCCESS;\n\n    GGML_UNUSED(buffer_idx);\n}\n\nstatic void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {\n    memset((char *)tensor->data + offset, value, size);\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_zdnn_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {\n    memcpy((char *)tensor->data + offset, data, size);\n\n    
ggml_backend_zdnn_buffer * extra = (ggml_backend_zdnn_buffer *)tensor->extra;\n\n    // Fixes the LLAMA_SET_ROWS bug\n    // see: https://github.com/ggml-org/llama.cpp/issues/15414\n    if (tensor->buffer->usage == GGML_BACKEND_BUFFER_USAGE_COMPUTE && extra->ztensor.is_transformed) zdnn_reset_ztensor(&extra->ztensor);\n    if (extra->ztensor.is_transformed == false) ggml_zdnn_load_tensor(extra->ztensor, tensor->data);\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_zdnn_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {\n    memcpy(data, (const char *)tensor->data + offset, size);\n\n    GGML_UNUSED(buffer);\n}\n\nstatic void ggml_backend_zdnn_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {\n    ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context;\n\n    memset(ctx->all_data, value, ctx->all_size);\n}\n\nstatic ggml_backend_buffer_i ggml_backend_zdnn_buffer_i = {\n    /* .free_buffer   = */ ggml_backend_zdnn_buffer_free_buffer,\n    /* .get_base      = */ ggml_backend_zdnn_buffer_get_base,\n    /* .init_tensor   = */ ggml_backend_zdnn_buffer_init_tensor,\n    /* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor,\n    /* .set_tensor    = */ ggml_backend_zdnn_buffer_set_tensor,\n    /* .get_tensor    = */ ggml_backend_zdnn_buffer_get_tensor,\n    /* .cpy_tensor    = */ NULL,\n    /* .clear         = */ ggml_backend_zdnn_buffer_clear,\n    /* .reset         = */ NULL,\n};\n\n//\n// default buffer type\n//\n\nstatic const char * ggml_backend_zdnn_buffer_type_get_name(ggml_backend_buffer_type_t buft) {\n    return GGML_ZDNN_NAME;\n\n    GGML_UNUSED(buft);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_zdnn_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {\n    ggml_backend_zdnn_buffer_context * ctx = new ggml_backend_zdnn_buffer_context();\n\n    const size_t size_page = sysconf(_SC_PAGESIZE);\n\n    size_t 
size_aligned = size;\n    if ((size_aligned % size_page) != 0) {\n        size_aligned += size_page - (size_aligned % size_page);\n    }\n\n    ggml_backend_zdnn_device_context * ctx_dev = (ggml_backend_zdnn_device_context *)buft->device->context;\n\n    GGML_ASSERT(ctx_dev->zdnn_device >= 0);\n    int device = ctx_dev->zdnn_device; GGML_UNUSED(device);\n\n    ctx->all_data  = ggml_aligned_malloc(size_aligned);\n    ctx->all_size  = size_aligned;\n    ctx->owned     = true;\n    ctx->n_buffers = 1;\n\n    if (ctx->all_data != NULL) {\n        std::unique_ptr<ggml_backend_zdnn_buffer> zdnn_buffer = std::make_unique<ggml_backend_zdnn_buffer>();\n        zdnn_buffer->data = ctx->all_data;\n        zdnn_buffer->size = size_aligned;\n        ctx->buffers.push_back(std::move(zdnn_buffer));\n    }\n\n    if (size_aligned > 0 && (ctx->all_data == NULL)) {\n        GGML_LOG_ERROR(\"%s: error: failed to allocate buffer, size = %8.2f\\n\",\n                       __func__, size_aligned / 1024.0 / 1024.0);\n        delete ctx;\n        return NULL;\n    }\n\n    return ggml_backend_buffer_init(buft, ggml_backend_zdnn_buffer_i, ctx, size);\n}\n\nstatic size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {\n    return 256;\n\n    GGML_UNUSED(buft);\n}\n\nstatic bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) {\n    /* while it resides in host memory, additional transformation is needed */\n    return false;\n\n    GGML_UNUSED(buft);\n}\n\nggml_backend_buffer_type_t ggml_backend_zdnn_buffer_type(void) {\n    static ggml_backend_buffer_type ggml_backend_buffer_type_zdnn = {\n        /* .iface   = */ {\n            /* .get_name       = */ ggml_backend_zdnn_buffer_type_get_name,\n            /* .alloc_buffer   = */ ggml_backend_zdnn_buffer_type_alloc_buffer,\n            /* .get_alignment  = */ ggml_backend_zdnn_buffer_type_get_alignment,\n            /* .get_max_size   = */ NULL,\n            /* .get_alloc_size = */ NULL,  
// defaults to ggml_nbytes\n            /* .is_host        = */ ggml_backend_zdnn_buffer_type_is_host,\n        },\n        /* .device  = */ &g_ggml_backend_zdnn_device,\n        /* .context = */ NULL,\n    };\n\n    return &ggml_backend_buffer_type_zdnn;\n}\n\n//\n// backend\n//\n\nstatic const char * ggml_backend_zdnn_name(ggml_backend_t backend) {\n    return GGML_ZDNN_NAME;\n\n    GGML_UNUSED(backend);\n}\n\nstatic void ggml_backend_zdnn_free(ggml_backend_t backend) {\n    ggml_backend_zdnn_context * ctx = (ggml_backend_zdnn_context *)backend->context;\n\n    ggml_zdnn_free(ctx);\n    free(backend);\n}\n\nstatic enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    return ggml_zdnn_graph_compute(backend, cgraph);\n}\n\nstatic ggml_backend_i ggml_backend_zdnn_i = {\n    /* .get_name           = */ ggml_backend_zdnn_name,\n    /* .free               = */ ggml_backend_zdnn_free,\n    /* .set_tensor_async   = */ NULL,\n    /* .get_tensor_async   = */ NULL,\n    /* .cpy_tensor_async   = */ NULL,\n    /* .synchronize        = */ NULL,\n    /* .graph_plan_create  = */ NULL,\n    /* .graph_plan_free    = */ NULL,\n    /* .graph_plan_update  = */ NULL,\n    /* .graph_plan_compute = */ NULL,\n    /* .graph_compute      = */ ggml_backend_zdnn_graph_compute,\n    /* .event_record       = */ NULL,\n    /* .event_wait         = */ NULL,\n    /* .graph_optimize     = */ NULL,\n};\n\nstatic ggml_guid_t ggml_backend_zdnn_guid(void) {\n    static const char * guid_str = \"IBM-ZDNN-ACCELER\";\n    return reinterpret_cast<ggml_guid_t>((void *)guid_str);\n}\n\nbool ggml_backend_is_zdnn(ggml_backend_t backend) {\n    return backend != NULL &&\n           ggml_guid_matches(backend->guid, ggml_backend_zdnn_guid());\n\n    GGML_UNUSED(backend);\n}\n\n//\n// backend device\n//\n\nstatic const char * ggml_backend_zdnn_device_get_name(ggml_backend_dev_t dev) {\n    return GGML_ZDNN_NAME;\n\n    GGML_UNUSED(dev);\n}\n\nstatic const char * 
ggml_backend_zdnn_device_get_description(ggml_backend_dev_t dev) {\n    return \"IBM Z Neural Network Processing Assist (NNPA)\";\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_zdnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    *free  = 0;\n    *total = 0;\n\n    GGML_UNUSED(dev);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_zdnn_device_get_type(ggml_backend_dev_t dev) {\n    return GGML_BACKEND_DEVICE_TYPE_ACCEL;\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_zdnn_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_zdnn_device_get_name(dev);\n    props->description = ggml_backend_zdnn_device_get_description(dev);\n    props->type        = ggml_backend_zdnn_device_get_type(dev);\n    ggml_backend_zdnn_device_get_memory(dev, &props->memory_free, &props->memory_total);\n    props->caps = (ggml_backend_dev_caps) {\n        /* .async                = */ false,\n        /* .host_buffer          = */ false,\n        /* .buffer_from_host_ptr = */ false,\n        /* .events               = */ false\n    };\n}\n\nstatic ggml_backend_t ggml_backend_zdnn_device_init(ggml_backend_dev_t dev, const char * params) {\n    ggml_backend_zdnn_context * ctx = ggml_zdnn_init(dev);\n    if (ctx == NULL) {\n        GGML_LOG_ERROR(\"%s: error: failed to allocate context\\n\", __func__);\n        return NULL;\n    }\n\n    ggml_backend_t backend = (ggml_backend *)malloc(sizeof(ggml_backend));\n    *backend = (ggml_backend) {\n        /* .guid       = */ ggml_backend_zdnn_guid(),\n        /* .iface      = */ ggml_backend_zdnn_i,\n        /* .device     = */ dev,\n        /* .context    = */ ctx\n    };\n\n    return backend;\n\n    GGML_UNUSED(params);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_zdnn_device_get_buffer_type(ggml_backend_dev_t dev) {\n    return ggml_backend_zdnn_buffer_type();\n\n    GGML_UNUSED(dev);\n}\n\nstatic bool 
ggml_backend_zdnn_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {\n    ggml_backend_zdnn_device_context * ctx_dev = (ggml_backend_zdnn_device_context *) dev->context;\n\n    return ggml_zdnn_supports_op(ctx_dev, op);\n}\n\nstatic bool ggml_backend_zdnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    return\n        buft->iface.get_name == ggml_backend_zdnn_buffer_type_get_name;\n\n    GGML_UNUSED(dev);\n}\n\nstatic ggml_backend_device_i ggml_backend_zdnn_device_i = {\n    /* .get_name             = */ ggml_backend_zdnn_device_get_name,\n    /* .get_description      = */ ggml_backend_zdnn_device_get_description,\n    /* .get_memory           = */ ggml_backend_zdnn_device_get_memory,\n    /* .get_type             = */ ggml_backend_zdnn_device_get_type,\n    /* .get_props            = */ ggml_backend_zdnn_device_get_props,\n    /* .init_backend         = */ ggml_backend_zdnn_device_init,\n    /* .get_buffer_type      = */ ggml_backend_zdnn_device_get_buffer_type,\n    /* .get_host_buffer_type = */ NULL,\n    /* .buffer_from_host_ptr = */ NULL,\n    /* .supports_op          = */ ggml_backend_zdnn_device_supports_op,\n    /* .supports_buft        = */ ggml_backend_zdnn_device_supports_buft,\n    /* .offload_op           = */ NULL,\n    /* .event_new            = */ NULL,\n    /* .event_free           = */ NULL,\n    /* .event_synchronize    = */ NULL,\n};\n\n//\n// backend registry\n//\n\nstatic const char * ggml_backend_zdnn_reg_get_name(ggml_backend_reg_t reg) {\n    return GGML_ZDNN_NAME;\n\n    GGML_UNUSED(reg);\n}\n\nstatic size_t ggml_backend_zdnn_reg_device_count(ggml_backend_reg_t reg) {\n    if (!zdnn_is_nnpa_installed()) {\n        return 0;\n    }\n    return 1;\n\n    GGML_UNUSED(reg);\n}\n\nstatic ggml_backend_dev_t ggml_backend_zdnn_reg_device_get(ggml_backend_reg_t reg, size_t index) {\n    GGML_ASSERT(index == 0);\n\n    return &g_ggml_backend_zdnn_device;\n\n    GGML_UNUSED(reg);\n    
GGML_UNUSED(index);\n}\n\nstatic ggml_backend_feature g_ggml_backend_zdnn_features[] = {\n    { \"NNPA\", zdnn_is_nnpa_installed() ? \"1\" : \"0\" },\n    { \"NNPA_PARMBLKFORMAT_0\", zdnn_is_nnpa_parmblk_fmt_installed(1, NNPA_PARMBLKFORMAT_0) ? \"1\" : \"0\" },\n    { \"NNPA_PARMBLKFORMAT_1\", zdnn_is_nnpa_parmblk_fmt_installed(1, NNPA_PARMBLKFORMAT_1) ? \"1\" : \"0\" },\n    { NULL, NULL },\n};\n\nstatic ggml_backend_feature * ggml_backend_zdnn_get_features(ggml_backend_reg_t reg) {\n    return g_ggml_backend_zdnn_features;\n\n    GGML_UNUSED(reg);\n}\n\nstatic void * ggml_backend_zdnn_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    if (strcmp(name, \"ggml_backend_get_features\") == 0) {\n        return (void *) ggml_backend_zdnn_get_features;\n    }\n\n    return NULL;\n\n    GGML_UNUSED(reg);\n}\n\nstatic ggml_backend_reg_i ggml_backend_zdnn_reg_i = {\n    /* .get_name         = */ ggml_backend_zdnn_reg_get_name,\n    /* .get_device_count = */ ggml_backend_zdnn_reg_device_count,\n    /* .get_device       = */ ggml_backend_zdnn_reg_device_get,\n    /* .get_proc_address = */ ggml_backend_zdnn_get_proc_address\n};\n\nstatic void ggml_zdnn_cleanup(void) {\n    ggml_backend_zdnn_device_rel(&g_ggml_ctx_dev_main);\n}\n\n// TODO: make thread-safe\nggml_backend_reg_t ggml_backend_zdnn_reg(void) {\n    ggml_backend_zdnn_device_acq(&g_ggml_ctx_dev_main);\n\n    // register cleanup callback\n    atexit(ggml_zdnn_cleanup);\n\n    {\n        g_ggml_backend_zdnn_reg = (ggml_backend_reg) {\n            /* .api_version = */ GGML_ZDNN_VERSION,\n            /* .iface       = */ ggml_backend_zdnn_reg_i,\n            /* .context     = */ NULL\n        };\n\n        g_ggml_backend_zdnn_device = (ggml_backend_device) {\n            /* .iface       = */ ggml_backend_zdnn_device_i,\n            /* .reg         = */ &g_ggml_backend_zdnn_reg,\n            /* .context     = */ &g_ggml_ctx_dev_main\n        };\n\n        return &g_ggml_backend_zdnn_reg;\n    
}\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_zdnn_reg)\n"
  },
  {
    "path": "src/ggml-zdnn/mmf.cpp",
    "content": "#include \"ggml.h\"\n#include \"mmf.hpp\"\n\nvoid ggml_zdnn_mul_mat_f(\n    const ggml_backend_zdnn_context * ctx,\n    const               ggml_tensor * src0,\n    const               ggml_tensor * src1,\n                        ggml_tensor * dst) {\n    GGML_TENSOR_BINARY_OP_LOCALS;\n\n    const enum ggml_type type = src0->type;\n\n    GGML_ASSERT(ne0 == ne01);\n    GGML_ASSERT(ne1 == ne11);\n    GGML_ASSERT(ne2 == ne12);\n    GGML_ASSERT(ne3 == ne13);\n\n    // we don't support permuted src0 or src1\n    GGML_ASSERT(nb00 == ggml_type_size(type));\n    GGML_ASSERT(nb10 == ggml_type_size(src1->type));\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    const ggml_tensor * weights = src0;\n    const ggml_tensor * inputs  = src1;\n          ggml_tensor * output  = dst;\n\n    ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)weights->extra;\n    ggml_backend_zdnn_buffer * inputs_extra  = (ggml_backend_zdnn_buffer *)inputs->extra;\n    ggml_backend_zdnn_buffer * output_extra  = (ggml_backend_zdnn_buffer *)output->extra;\n    ggml_backend_zdnn_buffer * bias_extra    = (ggml_backend_zdnn_buffer *)output_extra->extra;\n\n    const int64_t weights_rows = ne01;\n    const int64_t weights_cols = ne00;\n    const int64_t inputs_rows  = ne11;\n    const int64_t inputs_cols  = ne10;\n\n    assert(inputs_cols == weights_cols);\n\n    const int64_t output_rows = ne1;\n    const int64_t output_cols = ne0;\n\n    // GGML_LOG_INFO(\"%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\\n\",\n    //               __func__, weights_extra->name,\n    //               weights->ne[3], weights->ne[2], weights->ne[1], weights->ne[0],\n    //               weights_extra->pre_tfm_desc.dim1,\n    //               weights_extra->pre_tfm_desc.dim2,\n    //               
weights_extra->pre_tfm_desc.dim3,\n    //               weights_extra->pre_tfm_desc.dim4);\n\n    // GGML_LOG_INFO(\"%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\\n\",\n    //               __func__, inputs_extra->name,\n    //               inputs->ne[3], inputs->ne[2], inputs->ne[1], inputs->ne[0],\n    //               inputs_extra->pre_tfm_desc.dim1,\n    //               inputs_extra->pre_tfm_desc.dim2,\n    //               inputs_extra->pre_tfm_desc.dim3,\n    //               inputs_extra->pre_tfm_desc.dim4);\n\n    GGML_ASSERT(weights_extra->pre_tfm_desc.dim1 == weights->ne[0] && \"weights_extra->pre_tfm_desc.dim1 must match weights->ne[0]\");\n    GGML_ASSERT(weights_extra->pre_tfm_desc.dim2 == weights->ne[1] && \"weights_extra->pre_tfm_desc.dim2 must match weights->ne[1]\");\n    GGML_ASSERT(inputs_extra->pre_tfm_desc.dim1  == inputs->ne[0]  && \"inputs_extra->pre_tfm_desc.dim1 must match inputs->ne[0]\");\n    GGML_ASSERT(inputs_extra->pre_tfm_desc.dim2  == inputs->ne[1]  && \"inputs_extra->pre_tfm_desc.dim2 must match inputs->ne[1]\");\n\n    ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &bias_extra->ztensor,\n                                        false, true, MATMUL_OP_ADDITION, &output_extra->ztensor));\n    // TODO: Remove in the future as we are currently DLF16 -> FP32 then in the next op, FP32 -> DLF16 again. Inefficient.\n    ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, output->data));\n\n    GGML_UNUSED(ctx);\n    GGML_UNUSED(weights_rows);\n    GGML_UNUSED(weights_cols);\n    GGML_UNUSED(inputs_rows);\n    GGML_UNUSED(inputs_cols);\n    GGML_UNUSED(output_rows);\n    GGML_UNUSED(output_cols);\n}\n"
  },
  {
    "path": "src/ggml-zdnn/mmf.hpp",
    "content": "#ifndef GGML_ZDNN_MMF_HPP\n#define GGML_ZDNN_MMF_HPP\n\n#include \"common.hpp\"\n\nvoid ggml_zdnn_mul_mat_f(\n    const ggml_backend_zdnn_context * ctx,\n    const               ggml_tensor * src0,\n    const               ggml_tensor * src1,\n                        ggml_tensor * dst);\n\n#endif  // GGML_ZDNN_MMF_HPP\n"
  },
  {
    "path": "src/ggml-zdnn/utils.cpp",
    "content": "#include \"ggml.h\"\n#include \"utils.hpp\"\n\nzdnn_data_types ggml_zdnn_type_mapping(ggml_type type) {\n    switch (type) {\n        case GGML_TYPE_F32:\n            return FP32;\n        case GGML_TYPE_F16:\n            return FP16;\n        case GGML_TYPE_BF16:\n            return BFLOAT;\n        case GGML_TYPE_Q8_0:\n            return INT8;\n        case GGML_TYPE_I8:\n            return INT8;\n        case GGML_TYPE_I32:\n            return INT32;\n        default:\n            GGML_ABORT(\"%s: fatal: unable to determine zTensor data type\",\n                       __func__);\n            break;\n    }\n}\n\nvoid ggml_zdnn_create_tensor(zdnn_tensor_desc  & pre_tfm_desc,\n                             zdnn_tensor_desc  & tfm_desc,\n                             zdnn_ztensor      & ztensor,\n                       const ggml_tensor       * src,\n                       const int64_t           * ne,\n                       const zdnn_data_layouts   layout) {\n    zdnn_init_pre_transformed_desc(\n        layout,\n        ggml_zdnn_type_mapping(src->type),\n        &pre_tfm_desc,\n        ne[3], ne[2], ne[1], ne[0]\n    );\n\n    ZDNN_CHECK(zdnn_generate_transformed_desc(&pre_tfm_desc, &tfm_desc));\n    ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc, &tfm_desc, &ztensor));\n}\n\nvoid ggml_zdnn_load_tensor(zdnn_ztensor & ztensor, void * buffer) {\n    ZDNN_CHECK(zdnn_transform_ztensor(&ztensor, buffer));\n}\n\nvoid ggml_zdnn_init_tensor(ggml_backend_zdnn_buffer * buffer, const ggml_tensor * tensor) {\n    switch (tensor->op) {\n        case GGML_OP_MUL_MAT:\n            {\n                zdnn_init_pre_transformed_desc(\n                    ZDNN_2D,\n                    ggml_zdnn_type_mapping(tensor->type),\n                    &buffer->pre_tfm_desc,\n                    tensor->ne[1], tensor->ne[0]\n                );\n            } break;\n\n        default:\n            {\n                // For 4D tensors, GGML uses NCHW layout. 
However, because zDNN\n                // automatically transforms everything to NHWC, we will use it\n                // directly to avoid the performance penalty of changing the\n                // layout and reshaping the tensor.\n                zdnn_init_pre_transformed_desc(\n                    ZDNN_NHWC,\n                    ggml_zdnn_type_mapping(tensor->type),\n                    &buffer->pre_tfm_desc,\n                    tensor->ne[3], tensor->ne[2], tensor->ne[1], tensor->ne[0]\n                );\n\n                // TODO: Consider adding a ggml check.\n                // TODO: If tensor = 4D, use ZDNN_NCHW by default.\n                // TODO: If tensor = 2D, use ZDNN_NHWC by default.\n            } break;\n    }\n\n    ZDNN_CHECK(zdnn_generate_transformed_desc(&buffer->pre_tfm_desc, &buffer->tfm_desc));\n    ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&buffer->pre_tfm_desc, &buffer->tfm_desc, &buffer->ztensor));\n}\n"
  },
  {
    "path": "src/ggml-zdnn/utils.hpp",
    "content": "#ifndef GGML_ZDNN_UTILITIES_HPP\n#define GGML_ZDNN_UTILITIES_HPP\n\n#include \"common.hpp\"\n\nzdnn_data_types ggml_zdnn_type_mapping(ggml_type type);\n\nvoid ggml_zdnn_create_tensor(zdnn_tensor_desc & pre_tfm_desc,\n                             zdnn_tensor_desc & tfm_desc,\n                             zdnn_ztensor     & ztensor,\n                      const ggml_tensor       * src,\n                      const int64_t           * ne,\n                      const zdnn_data_layouts   layout);\n\nvoid ggml_zdnn_load_tensor(zdnn_ztensor & ztensor, void * buffer);\n\nvoid ggml_zdnn_init_tensor(ggml_backend_zdnn_buffer * buffer, const ggml_tensor * tensor);\n\n#endif  // GGML_ZDNN_UTILITIES_HPP\n"
  },
  {
    "path": "src/ggml-zendnn/CMakeLists.txt",
    "content": "ggml_add_backend_library(ggml-zendnn\n                         ggml-zendnn.cpp)\n\nif (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL \"\")\n    set(ZENDNN_ROOT \"$ENV{ZENDNN_ROOT}\")\nendif()\n\nif (BUILD_SHARED_LIBS)\n    set(ZENDNN_SHARED_LIB ON)\n    set(ZENDNN_ARCHIVE_LIB OFF)\nelse()\n    set(ZENDNN_SHARED_LIB OFF)\n    set(ZENDNN_ARCHIVE_LIB ON)\nendif()\n\n# Download and build ZenDNN if not provided\nif (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL \"\" OR ZENDNN_ROOT STREQUAL \"OFF\")\n    message(STATUS \"ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...\")\n    message(STATUS \"This will take several minutes on first build...\")\n\n    include(ExternalProject)\n\n    set(ZENDNN_PREFIX      ${CMAKE_BINARY_DIR}/_deps/zendnn-prefix)\n    set(ZENDNN_SOURCE_DIR  ${ZENDNN_PREFIX}/src/zendnn)\n    set(ZENDNN_BUILD_DIR   ${ZENDNN_PREFIX}/build)\n    set(ZENDNN_INSTALL_DIR ${ZENDNN_BUILD_DIR}/install)\n\n    ExternalProject_Add(\n        zendnn\n        GIT_REPOSITORY https://github.com/amd/ZenDNN.git\n        GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf    # ZenDNN-2026-WW08\n        PREFIX      ${ZENDNN_PREFIX}\n        SOURCE_DIR  ${ZENDNN_SOURCE_DIR}\n        BINARY_DIR  ${ZENDNN_BUILD_DIR}\n        CMAKE_ARGS\n            -DCMAKE_BUILD_TYPE=Release\n            -DCMAKE_INSTALL_PREFIX=${ZENDNN_INSTALL_DIR}\n            -DZENDNNL_BUILD_EXAMPLES=OFF\n            -DZENDNNL_BUILD_DOXYGEN=OFF\n            -DZENDNNL_BUILD_GTEST=OFF\n            -DZENDNNL_BUILD_BENCHDNN=OFF\n            -DZENDNNL_DEPENDS_FBGEMM=OFF\n            -DZENDNNL_LIB_BUILD_ARCHIVE=${ZENDNN_ARCHIVE_LIB}\n            -DZENDNNL_LIB_BUILD_SHARED=${ZENDNN_SHARED_LIB}\n            -DZENDNNL_DEPENDS_AOCLDLP=ON\n            -DZENDNNL_DEPENDS_ONEDNN=ON\n            -DZENDNNL_DEPENDS_LIBXSMM=ON\n        BUILD_COMMAND   ${CMAKE_COMMAND} --build ${ZENDNN_BUILD_DIR} --target zendnnl\n        INSTALL_COMMAND ${CMAKE_COMMAND} --build ${ZENDNN_BUILD_DIR} --target 
install\n        BUILD_ALWAYS OFF\n        LOG_DOWNLOAD ON\n        LOG_CONFIGURE ON\n        LOG_BUILD ON\n        LOG_INSTALL ON\n    )\n\n    add_dependencies(ggml-zendnn zendnn)\n    set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR})\n    message(STATUS \"ZenDNN will be built to: ${ZENDNN_ROOT}\")\nelse()\n    message(STATUS \"Using custom ZenDNN installation at: ${ZENDNN_ROOT}\")\nendif()\n\ntarget_include_directories(ggml-zendnn PRIVATE\n    ${ZENDNN_ROOT}/zendnnl/include\n    ${ZENDNN_ROOT}/deps/json/include\n    ${ZENDNN_ROOT}/deps/aoclutils/include\n    ${ZENDNN_ROOT}/deps/aocldlp/include\n    ${ZENDNN_ROOT}/deps/onednn/include\n    ${ZENDNN_ROOT}/deps/libxsmm/include)\n\nif (ZENDNN_SHARED_LIB)\n    target_link_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/lib)\n    target_link_libraries(ggml-zendnn PRIVATE zendnnl)\nelseif (ZENDNN_ARCHIVE_LIB)\n    target_link_libraries(ggml-zendnn PRIVATE\n        ${ZENDNN_ROOT}/zendnnl/lib/libzendnnl_archive.a\n        ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libaoclutils.a\n        ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libau_cpuid.a\n        ${ZENDNN_ROOT}/deps/aocldlp/lib/libaocl-dlp.a\n        ${ZENDNN_ROOT}/deps/onednn/${CMAKE_INSTALL_LIBDIR}/libdnnl.a\n        ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmm.a\n        ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmext.a\n        ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmnoblas.a)\nendif()\n\ntarget_link_libraries(ggml-zendnn PRIVATE m pthread)\n\nif (GGML_OPENMP)\n    target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX)\nendif()\n"
  },
  {
    "path": "src/ggml-zendnn/ggml-zendnn.cpp",
    "content": "#include \"ggml-zendnn.h\"\n\n#include \"ggml-backend-impl.h\"\n#include \"ggml-impl.h\"\n#include \"zendnnl.hpp\"\n\n#include <cstring>\n\n\nstruct ggml_backend_zendnn_context {\n    int n_threads = GGML_DEFAULT_N_THREADS;\n    std::unique_ptr<char[]> work_data;\n    size_t work_size = 0;\n};\n\ntemplate<typename T>\nzendnnl::common::data_type_t ggml_to_zendnn_type() {\n    if constexpr (std::is_same_v<T, float>) {\n        return zendnnl::common::data_type_t::f32;\n    } else if constexpr (std::is_same_v<T, ggml_bf16_t>) {\n        return zendnnl::common::data_type_t::bf16;\n    } else {\n        return zendnnl::common::data_type_t::none;\n    }\n}\n\n/**\n * ZenDNN matmul: computes C = B * A.\n *\n * - A: weights, shape (k, m), column-major (each column is a weight vector for one output).\n * - B: input, shape (n, k), row-major (each row is an input sample).\n * - C: output, shape (n, m), row-major.\n *\n * Dimensions:\n *   m = output features (columns of C, columns of A)\n *   n = batch size      (rows of C, rows of B)\n *   k = inner dimension (columns of B, rows of A)\n */\ntemplate <typename TA, typename TB, typename TC>\nstatic bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,\n                               const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C,\n                               int64_t ldc) {\n\n    zendnnl::lowoha::matmul::matmul_params params;\n    params.dtypes.src = ggml_to_zendnn_type<TB>();\n    params.dtypes.wei = ggml_to_zendnn_type<TA>();\n    params.dtypes.dst = ggml_to_zendnn_type<TC>();\n    params.num_threads = ctx->n_threads;\n\n    zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(\n        'r', false, true,   // row-major, don't transpose B, transpose A (because it's column-major)\n        n,                  // M: rows of B and C\n        m,                  // N: cols of A^T and C\n        k,                  // K: cols of B, 
rows of A\n        1.0f,               // alpha\n        B, ldb,             // src: B[n,k]\n        A, lda,             // weight: A[k,m] column-major (transposed)\n        nullptr,            // bias\n        0.0f,               // beta\n        C, ldc,             // output C[n,m]\n        true,               // is_weights_const\n        {},                 // batch_params\n        params              // params\n    );\n\n    if (status != zendnnl::error_handling::status_t::success) {\n        GGML_LOG_ERROR(\"%s, ZenDNN matmul failed: status=%d\\n\", __func__, static_cast<int>(status));\n        return false;\n    }\n    return true;\n}\n\nstatic bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,\n                              const void * A, int64_t lda, const void * B, int64_t ldb, void * C,\n                              int64_t ldc, int Atype, int Btype, int Ctype) {\n\n    assert(m >= 0);\n    assert(n >= 0);\n    assert(k >= 0);\n    assert(lda >= k);\n    assert(ldb >= k);\n    assert(ldc >= m);\n\n    // categorize types\n    switch (Atype) {\n        case GGML_TYPE_F32:\n            if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32)\n                return false;\n            return ggml_zendnn_matmul<float, float, float>(\n                ctx, m, n, k,\n                (const float *)A, lda,\n                (const float *)B, ldb,\n                (float *)C, ldc);\n        case GGML_TYPE_BF16:\n            if (Btype != GGML_TYPE_BF16)\n                return false;\n            if (Ctype == GGML_TYPE_BF16)\n                return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(\n                    ctx, m, n, k,\n                    (const ggml_bf16_t *)A, lda,\n                    (const ggml_bf16_t *)B, ldb,\n                    (ggml_bf16_t *)C, ldc);\n            if (Ctype == GGML_TYPE_F32)\n                return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, float>(\n                    ctx, m, 
n, k,\n                    (const ggml_bf16_t *)A, lda,\n                    (const ggml_bf16_t *)B, ldb,\n                    (float *)C, ldc);\n            return false;\n        default:\n            return false; // unsupported type\n    }\n}\n\nstatic void ggml_zendnn_compute_forward_mul_mat(\n    ggml_backend_zendnn_context * ctx,\n    ggml_tensor * dst) {\n\n    const ggml_tensor * src0 = dst->src[0];  // weights\n    const ggml_tensor * src1 = dst->src[1];  // inputs\n\n    GGML_TENSOR_BINARY_OP_LOCALS\n\n    ggml_type         const vec_dot_type = src0->type;\n    ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;\n\n    GGML_ASSERT(ne0 == ne01);\n    GGML_ASSERT(ne1 == ne11);\n    GGML_ASSERT(ne2 == ne12);\n    GGML_ASSERT(ne3 == ne13);\n\n    // we don't support permuted src0 or src1\n    GGML_ASSERT(nb00 == ggml_type_size(src0->type));\n    GGML_ASSERT(nb10 == ggml_type_size(src1->type));\n\n    // dst cannot be transposed or permuted\n    GGML_ASSERT(nb0 == sizeof(float));\n    GGML_ASSERT(nb0 <= nb1);\n    GGML_ASSERT(nb1 <= nb2);\n    GGML_ASSERT(nb2 <= nb3);\n\n    // broadcast factors\n    const int64_t r2 = ne12/ne02;\n    const int64_t r3 = ne13/ne03;\n\n    void * work_data = ctx->work_data.get();\n    if (src1->type != vec_dot_type) {\n        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);\n        const size_t nbw2 = nbw1 * ne11;\n        const size_t nbw3 = nbw2 * ne12;\n        const size_t desired_wsize = ne13 * nbw3;\n        if (ctx->work_size < desired_wsize) {\n            ctx->work_data.reset(new char[desired_wsize]);\n            ctx->work_size = desired_wsize;\n        }\n        work_data = ctx->work_data.get();\n\n        // #pragma omp parallel for num_threads(ctx->n_threads)\n        #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)\n        for (int64_t i13 = 0; i13 < ne13; ++i13) {\n            for (int64_t i12 = 0; i12 < ne12; ++i12) {\n               
 for (int64_t i11 = 0; i11 < ne11; ++i11) {\n                    const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);\n                    void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;\n                    from_float(src1_f32, src1_conv, ne10);\n                }\n            }\n        }\n    }\n\n    for (int64_t i13 = 0; i13 < ne13; i13++) {\n        for (int64_t i12 = 0; i12 < ne12; i12++) {\n            const void* wdata = src1->type == vec_dot_type ? src1->data : work_data;\n            const size_t row_size = ggml_row_size(vec_dot_type, ne10);\n            if (!ggml_zendnn_sgemm(ctx,\n                                  ne01,     // m\n                                  ne11,     // n\n                                  ne10,     // k\n                                  static_cast<const char *>(src0->data) + (i12/r2)*nb02 + (i13/r3)*nb03,\n                                  ne00,     // lda\n                                  static_cast<const char *>(wdata) + (i12*ne11 + i13*ne12*ne11)*row_size,\n                                  ne10,     // ldb\n                                  static_cast<char *>(dst->data) + i12*nb2 + i13*nb3,\n                                  ne01,     // ldc\n                                  src0->type,\n                                  vec_dot_type,\n                                  dst->type))\n                GGML_ABORT(\"%s: ZenDNN sgemm failed\\n\", __func__);\n        }\n    }\n}\n\n// backend interface\n\nstatic const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {\n    return \"ZenDNN\";\n\n    GGML_UNUSED(backend);\n}\n\nstatic void ggml_backend_zendnn_free(ggml_backend_t backend) {\n    ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;\n    delete ctx;\n    delete backend;\n}\n\nstatic ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {\n    ggml_backend_zendnn_context * ctx = 
(ggml_backend_zendnn_context *)backend->context;\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        struct ggml_tensor * node = cgraph->nodes[i];\n\n        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n            continue;\n        }\n\n        switch (node->op) {\n            case GGML_OP_MUL_MAT:\n                ggml_zendnn_compute_forward_mul_mat(ctx, node);\n                break;\n            case GGML_OP_NONE:\n            case GGML_OP_RESHAPE:\n            case GGML_OP_VIEW:\n            case GGML_OP_PERMUTE:\n            case GGML_OP_TRANSPOSE:\n                break;\n\n            default:\n                GGML_ABORT(\"%s: unsupported op %s\\n\", __func__, ggml_op_desc(node));\n        }\n    }\n\n    return GGML_STATUS_SUCCESS;\n\n    GGML_UNUSED(backend);\n}\n\nstatic struct ggml_backend_i ggml_backend_zendnn_i = {\n    /* .get_name                = */ ggml_backend_zendnn_get_name,\n    /* .free                    = */ ggml_backend_zendnn_free,\n    /* .set_tensor_async        = */ NULL,\n    /* .get_tensor_async        = */ NULL,\n    /* .cpy_tensor_async        = */ NULL,\n    /* .synchronize             = */ NULL,\n    /* .graph_plan_create       = */ NULL,\n    /* .graph_plan_free         = */ NULL,\n    /* .graph_plan_update       = */ NULL,\n    /* .graph_plan_compute      = */ NULL,\n    /* .graph_compute           = */ ggml_backend_zendnn_graph_compute,\n    /* .event_record            = */ NULL,\n    /* .event_wait              = */ NULL,\n    /* .graph_optimize          = */ NULL,\n};\n\nstatic ggml_guid_t ggml_backend_zendnn_guid(void) {\n    static const char * guid_str = \"AMD-ZENDNN-ACCEL\";\n    return reinterpret_cast<ggml_guid_t>(const_cast<char*>(guid_str));\n}\n\nggml_backend_t ggml_backend_zendnn_init(void) {\n    ggml_backend_zendnn_context * ctx = new ggml_backend_zendnn_context;\n\n    ggml_backend_t backend = new ggml_backend {\n        /* .guid    = */ ggml_backend_zendnn_guid(),\n        /* .iface   = */ 
ggml_backend_zendnn_i,\n        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_zendnn_reg(), 0),\n        /* .context = */ ctx,\n    };\n\n    return backend;\n}\n\nbool ggml_backend_is_zendnn(ggml_backend_t backend) {\n    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_zendnn_guid());\n}\n\nvoid ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads) {\n    GGML_ASSERT(ggml_backend_is_zendnn(backend_zendnn));\n\n    ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend_zendnn->context;\n    ctx->n_threads = n_threads;\n}\n\n// device interface\nstatic const char * ggml_backend_zendnn_device_get_name(ggml_backend_dev_t dev) {\n    return \"ZenDNN\";\n\n    GGML_UNUSED(dev);\n}\n/**\n * ZenDNN is AMD's performance library providing optimized primitives and implementations\n * for deep learning workloads on AMD CPUs. It targets improved performance for common\n * neural network operations on AMD architectures. For more information, see:\n * https://www.amd.com/en/developer/zendnn.html\n */\nstatic const char * ggml_backend_zendnn_device_get_description(ggml_backend_dev_t dev) {\n    return \"ZenDNN: AMD optimized primitives backend for GGML (optimized for AMD CPUs)\";\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_zendnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {\n    *free  = 0;\n    *total = 0;\n\n    GGML_UNUSED(dev);\n}\n\nstatic enum ggml_backend_dev_type ggml_backend_zendnn_device_get_type(ggml_backend_dev_t dev) {\n    return GGML_BACKEND_DEVICE_TYPE_ACCEL;\n\n    GGML_UNUSED(dev);\n}\n\nstatic void ggml_backend_zendnn_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {\n    props->name        = ggml_backend_zendnn_device_get_name(dev);\n    props->description = ggml_backend_zendnn_device_get_description(dev);\n    props->type        = ggml_backend_zendnn_device_get_type(dev);\n    
ggml_backend_zendnn_device_get_memory(dev, &props->memory_free, &props->memory_total);\n    props->caps = {\n        /* .async                = */ false,\n        /* .host_buffer          = */ false,\n        /* .buffer_from_host_ptr = */ true,\n        /* .events               = */ false\n    };\n}\n\nstatic ggml_backend_t ggml_backend_zendnn_device_init_backend(ggml_backend_dev_t dev, const char * params) {\n    ggml_backend_t backend = ggml_backend_zendnn_init();\n    if (backend == NULL) {\n        GGML_LOG_ERROR(\"%s: error: failed to initialize ZenDNN backend\\n\", __func__);\n        return NULL;\n    }\n\n    return backend;\n\n    GGML_UNUSED(dev);\n    GGML_UNUSED(params);\n}\n\nstatic ggml_backend_buffer_type_t ggml_backend_zendnn_device_get_buffer_type(ggml_backend_dev_t dev) {\n    return ggml_backend_cpu_buffer_type();\n\n    GGML_UNUSED(dev);\n}\n\nstatic ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {\n    return ggml_backend_cpu_buffer_from_ptr(ptr, size);\n\n    GGML_UNUSED(dev);\n    GGML_UNUSED(max_tensor_size);\n}\n\nstatic bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {\n    switch (op->op) {\n        case GGML_OP_NONE:\n        case GGML_OP_RESHAPE:\n        case GGML_OP_VIEW:\n        case GGML_OP_PERMUTE:\n        case GGML_OP_TRANSPOSE:\n            return true;\n\n        case GGML_OP_MUL_MAT:\n        {\n            const ggml_tensor * weights = op->src[0];\n            const ggml_tensor * inputs = op->src[1];\n\n            const int64_t ne10 = inputs->ne[0];\n            const int64_t ne0 = op->ne[0];\n            const int64_t ne1 = op->ne[1];\n\n            const int64_t min_batch = 1;\n            if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||\n                ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {\n                    return false;\n            
}\n            switch (weights->type) {\n                case GGML_TYPE_F32:\n                case GGML_TYPE_BF16:\n                    return true;\n                default:\n                    return false;\n            }\n        } break;\n\n        default:\n            return false;\n    }\n\n    GGML_UNUSED(dev);\n}\n\nstatic bool ggml_backend_zendnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {\n    return ggml_backend_buft_is_host(buft);\n\n    GGML_UNUSED(dev);\n}\n\nstatic const struct ggml_backend_device_i ggml_backend_zendnn_device_i = {\n    /* .get_name               = */ ggml_backend_zendnn_device_get_name,\n    /* .get_description        = */ ggml_backend_zendnn_device_get_description,\n    /* .get_memory             = */ ggml_backend_zendnn_device_get_memory,\n    /* .get_type               = */ ggml_backend_zendnn_device_get_type,\n    /* .get_props              = */ ggml_backend_zendnn_device_get_props,\n    /* .init_backend           = */ ggml_backend_zendnn_device_init_backend,\n    /* .get_buffer_type        = */ ggml_backend_zendnn_device_get_buffer_type,\n    /* .get_host_buffer_type   = */ NULL,\n    /* .buffer_from_host_ptr   = */ ggml_backend_zendnn_device_buffer_from_host_ptr,\n    /* .supports_op            = */ ggml_backend_zendnn_device_supports_op,\n    /* .supports_buft          = */ ggml_backend_zendnn_device_supports_buft,\n    /* .offload_op             = */ NULL,\n    /* .event_new              = */ NULL,\n    /* .event_free             = */ NULL,\n    /* .event_synchronize      = */ NULL,\n};\n\n// backend reg interface\nstatic const char * ggml_backend_zendnn_reg_get_name(ggml_backend_reg_t reg) {\n    return \"ZenDNN\";\n\n    GGML_UNUSED(reg);\n}\n\nstatic size_t ggml_backend_zendnn_reg_get_device_count(ggml_backend_reg_t reg) {\n    return 1;\n\n    GGML_UNUSED(reg);\n}\n\nstatic ggml_backend_dev_t ggml_backend_zendnn_reg_get_device(ggml_backend_reg_t reg, size_t index) {\n    
GGML_ASSERT(index == 0);\n\n    static ggml_backend_device ggml_backend_zendnn_device = {\n        /* .iface   = */ ggml_backend_zendnn_device_i,\n        /* .reg     = */ reg,\n        /* .context = */ nullptr,\n    };\n\n    return &ggml_backend_zendnn_device;\n}\n\nstatic void * ggml_backend_zendnn_get_proc_address(ggml_backend_reg_t reg, const char * name) {\n    if (std::strcmp(name, \"ggml_backend_set_n_threads\") == 0) {\n        return (void *) ggml_backend_zendnn_set_n_threads;\n    }\n    return NULL;\n\n    GGML_UNUSED(reg);\n    GGML_UNUSED(name);\n}\n\nstatic const struct ggml_backend_reg_i ggml_backend_zendnn_reg_i = {\n    /* .get_name         = */ ggml_backend_zendnn_reg_get_name,\n    /* .get_device_count = */ ggml_backend_zendnn_reg_get_device_count,\n    /* .get_device       = */ ggml_backend_zendnn_reg_get_device,\n    /* .get_proc_address = */ ggml_backend_zendnn_get_proc_address,\n};\n\nggml_backend_reg_t ggml_backend_zendnn_reg(void) {\n    static struct ggml_backend_reg ggml_backend_zendnn_reg = {\n        /* .api_version = */ GGML_BACKEND_API_VERSION,\n        /* .iface       = */ ggml_backend_zendnn_reg_i,\n        /* .context     = */ NULL,\n    };\n\n    return &ggml_backend_zendnn_reg;\n}\n\nGGML_BACKEND_DL_IMPL(ggml_backend_zendnn_reg)\n"
  },
  {
    "path": "src/ggml.c",
    "content": "#define _CRT_SECURE_NO_DEPRECATE // Disables \"unsafe\" warnings on Windows\n#define _USE_MATH_DEFINES // For M_PI on MSVC\n\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n#include \"ggml-threading.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml.h\"\n\n// FIXME: required here for quantization functions\n#include \"ggml-quants.h\"\n\n#ifdef GGML_USE_CPU_HBM\n#include <hbwmalloc.h>\n#endif\n\n#if defined(_MSC_VER) || defined(__MINGW32__)\n#include <malloc.h> // using malloc.h with MSC/MINGW\n#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)\n#include <alloca.h>\n#endif\n\n#include <assert.h>\n#include <errno.h>\n#include <time.h>\n#include <math.h>\n#include <stdlib.h>\n#include <string.h>\n#include <stdint.h>\n#include <inttypes.h>\n#include <stdio.h>\n#include <float.h>\n#include <limits.h>\n#include <stdarg.h>\n#include <signal.h>\n#if defined(__gnu_linux__)\n#include <syscall.h>\n#endif\n\n#if defined(__APPLE__)\n#include <unistd.h>\n#include <mach/mach.h>\n#include <TargetConditionals.h>\n#endif\n\n#if defined(_WIN32)\n#define WIN32_LEAN_AND_MEAN\n#ifndef NOMINMAX\n    #define NOMINMAX\n#endif\n#include <windows.h>\n#endif\n\n#define UNUSED GGML_UNUSED\n\n// Needed for ggml_fp32_to_bf16_row()\n#if defined(__AVX512BF16__)\n#if defined(_MSC_VER)\n#define m512i(p) p\n#else\n#include <immintrin.h>\n#define m512i(p) (__m512i)(p)\n#endif // defined(_MSC_VER)\n#endif // defined(__AVX512BF16__)\n\n#if defined(__linux__) || \\\n    defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \\\n    (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)\n\n#include <unistd.h>\n#include <sys/types.h>\n#include <sys/stat.h>\n#include <sys/wait.h>\n#if defined(__linux__)\n#include <sys/prctl.h>\n#endif\n\n#if defined(__ANDROID__)\n#include <unwind.h>\n#include <dlfcn.h>\n#include <stdio.h>\n\nstruct backtrace_state {\n    void ** current;\n    void ** end;\n};\n\nstatic _Unwind_Reason_Code unwind_callback(struct 
_Unwind_Context* context, void* arg) {\n    struct backtrace_state * state = (struct backtrace_state *)arg;\n    uintptr_t pc = _Unwind_GetIP(context);\n    if (pc) {\n        if (state->current == state->end) {\n            return _URC_END_OF_STACK;\n        } else {\n            *state->current++ = (void*)pc;\n        }\n    }\n    return _URC_NO_REASON;\n}\n\nstatic void ggml_print_backtrace_symbols(void) {\n    const int max = 100;\n    void* buffer[max];\n\n    struct backtrace_state state = {buffer, buffer + max};\n    _Unwind_Backtrace(unwind_callback, &state);\n\n    int count = state.current - buffer;\n\n    for (int idx = 0; idx < count; ++idx) {\n        const void * addr = buffer[idx];\n        const char * symbol = \"\";\n\n        Dl_info info;\n        if (dladdr(addr, &info) && info.dli_sname) {\n            symbol = info.dli_sname;\n        }\n\n        fprintf(stderr, \"%d: %p %s\\n\", idx, addr, symbol);\n    }\n}\n#elif defined(__linux__) && defined(__GLIBC__)\n#include <execinfo.h>\nstatic void ggml_print_backtrace_symbols(void) {\n    void * trace[100];\n    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));\n    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);\n}\n#elif defined(__APPLE__)\n#include <execinfo.h>\nstatic void ggml_print_backtrace_symbols(void) {\n    void * trace[100];\n    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));\n    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);\n}\n#else\nstatic void ggml_print_backtrace_symbols(void) {\n    // platform not supported\n}\n#endif\n\nvoid ggml_print_backtrace(void) {\n    const char * GGML_NO_BACKTRACE = getenv(\"GGML_NO_BACKTRACE\");\n    if (GGML_NO_BACKTRACE) {\n        return;\n    }\n#if defined(__APPLE__)\n    // On macOS, fork+debugger attachment is problematic due to:\n    // 1. libdispatch \"poisons\" forked child processes\n    // 2. 
lldb has issues attaching to parent from forked child\n    // Use simple backtrace() instead to avoid Terminal.app crashes\n    const char * GGML_BACKTRACE_LLDB = getenv(\"GGML_BACKTRACE_LLDB\");\n    if (!GGML_BACKTRACE_LLDB) {\n        fprintf(stderr, \"WARNING: Using native backtrace. Set GGML_BACKTRACE_LLDB for more info.\\n\");\n        fprintf(stderr, \"WARNING: GGML_BACKTRACE_LLDB may cause native MacOS Terminal.app to crash.\\n\");\n        fprintf(stderr, \"See: https://github.com/ggml-org/llama.cpp/pull/17869\\n\");\n        ggml_print_backtrace_symbols();\n        return;\n    }\n#endif\n#if defined(__linux__)\n    FILE * f = fopen(\"/proc/self/status\", \"r\");\n    size_t size = 0;\n    char * line = NULL;\n    ssize_t length = 0;\n    while ((length = getline(&line, &size, f)) > 0) {\n        if (!strncmp(line, \"TracerPid:\", sizeof(\"TracerPid:\") - 1) &&\n            (length != sizeof(\"TracerPid:\\t0\\n\") - 1 || line[length - 2] != '0')) {\n            // Already being debugged, and the breakpoint is the later abort()\n            free(line);\n            fclose(f);\n            return;\n        }\n    }\n    free(line);\n    fclose(f);\n    int lock[2] = { -1, -1 };\n    (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER\n#endif\n    const int parent_pid = getpid();\n    const int child_pid = fork();\n    if (child_pid < 0) { // error\n#if defined(__linux__)\n        close(lock[1]);\n        close(lock[0]);\n#endif\n        return;\n    } else if (child_pid == 0) { // child\n        char attach[32];\n        snprintf(attach, sizeof(attach), \"attach %d\", parent_pid);\n#if defined(__linux__)\n        close(lock[1]);\n        (void) !read(lock[0], lock, 1);\n        close(lock[0]);\n#endif\n        // try gdb\n        execlp(\"gdb\", \"gdb\", \"--batch\",\n            \"-ex\", \"set style enabled on\",\n            \"-ex\", attach,\n            \"-ex\", \"bt -frame-info source-and-location\",\n            \"-ex\", \"detach\",\n     
       \"-ex\", \"quit\",\n            (char *) NULL);\n        // try lldb\n        execlp(\"lldb\", \"lldb\", \"--batch\",\n            \"-o\", \"bt\",\n            \"-o\", \"quit\",\n            \"-p\", &attach[sizeof(\"attach \") - 1],\n            (char *) NULL);\n        // neither gdb nor lldb could be executed, fall back to backtrace_symbols\n        ggml_print_backtrace_symbols();\n        _Exit(0);\n    } else { // parent\n#if defined(__linux__)\n        prctl(PR_SET_PTRACER, child_pid);\n        close(lock[1]);\n        close(lock[0]);\n#endif\n        waitpid(child_pid, NULL, 0);\n    }\n}\n#else\nvoid ggml_print_backtrace(void) {\n    // platform not supported\n}\n#endif\n\nstatic ggml_abort_callback_t g_abort_callback = NULL;\n\n// Set the abort callback and return the previously installed one (passing NULL restores the default behavior: print the message to stderr, then attempt a backtrace)\nGGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback) {\n    ggml_abort_callback_t ret_val = g_abort_callback;\n    g_abort_callback = callback;\n    return ret_val;\n}\n\nvoid ggml_abort(const char * file, int line, const char * fmt, ...) 
{\n    fflush(stdout);\n\n    char message[2048];\n    int offset = snprintf(message, sizeof(message), \"%s:%d: \", file, line);\n\n    va_list args;\n    va_start(args, fmt);\n    vsnprintf(message + offset, sizeof(message) - offset, fmt, args);\n    va_end(args);\n\n    if (g_abort_callback) {\n        g_abort_callback(message);\n    } else {\n        // default: print error and backtrace to stderr\n        fprintf(stderr, \"%s\\n\", message);\n        ggml_print_backtrace();\n    }\n\n    abort();\n}\n\n// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp\n\n//\n// logging\n//\n\nstruct ggml_logger_state {\n    ggml_log_callback log_callback;\n    void * log_callback_user_data;\n};\nstatic struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};\n\nstatic void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {\n    if (format == NULL) {\n        return;\n    }\n    va_list args_copy;\n    va_copy(args_copy, args);\n    char buffer[128];\n    int len = vsnprintf(buffer, 128, format, args);\n    if (len < 128) {\n        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);\n    } else {\n        char * buffer2 = (char *) calloc(len + 1, sizeof(char));\n        vsnprintf(buffer2, len + 1, format, args_copy);\n        buffer2[len] = 0;\n        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);\n        free(buffer2);\n    }\n    va_end(args_copy);\n}\n\nvoid ggml_log_internal(enum ggml_log_level level, const char * format, ...) 
{\n    va_list args;\n    va_start(args, format);\n    ggml_log_internal_v(level, format, args);\n    va_end(args);\n}\n\nvoid ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\n//\n// end of logging block\n//\n\n#ifdef GGML_USE_ACCELERATE\n// uncomment to use vDSP for soft max computation\n// note: not sure if it is actually faster\n//#define GGML_SOFT_MAX_ACCELERATE\n#endif\n\n\nvoid * ggml_aligned_malloc(size_t size) {\n#if defined(__s390x__)\n    const int alignment = 256;\n#else\n    const int alignment = 64;\n#endif\n\n#if defined(_MSC_VER) || defined(__MINGW32__)\n    return _aligned_malloc(size, alignment);\n#else\n    if (size == 0) {\n        GGML_LOG_WARN(\"Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\\n\");\n        return NULL;\n    }\n    void * aligned_memory = NULL;\n  #ifdef GGML_USE_CPU_HBM\n    int result = hbw_posix_memalign(&aligned_memory, alignment, size);\n  #elif TARGET_OS_OSX\n    GGML_UNUSED(alignment);\n    kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);\n    int result = EFAULT;\n    switch (alloc_status) {\n        case KERN_SUCCESS:\n            result = 0;\n            break;\n        case KERN_INVALID_ADDRESS:\n            result = EINVAL;\n            break;\n        case KERN_NO_SPACE:\n            result = ENOMEM;\n            break;\n        default:\n            result = EFAULT;\n            break;\n    }\n  #else\n    int result = posix_memalign(&aligned_memory, alignment, size);\n  #endif\n    if (result != 0) {\n        // Handle allocation failure\n        const char *error_desc = \"unknown allocation error\";\n        switch (result) {\n            case EINVAL:\n                error_desc = \"invalid alignment value\";\n                break;\n            case ENOMEM:\n          
      error_desc = \"insufficient memory\";\n                break;\n        }\n        GGML_LOG_ERROR(\"%s: %s (attempted to allocate %6.2f MB)\\n\", __func__, error_desc, size/(1024.0*1024.0));\n        return NULL;\n    }\n    return aligned_memory;\n#endif\n}\n\nvoid ggml_aligned_free(void * ptr, size_t size) {\n    GGML_UNUSED(size);\n#if defined(_MSC_VER) || defined(__MINGW32__)\n    _aligned_free(ptr);\n#elif GGML_USE_CPU_HBM\n    if (ptr != NULL) {\n        hbw_free(ptr);\n    }\n#elif TARGET_OS_OSX\n    if (ptr != NULL) {\n        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size);\n    }\n#else\n    free(ptr);\n#endif\n}\n\n\ninline static void * ggml_malloc(size_t size) {\n    if (size == 0) {\n        GGML_LOG_WARN(\"Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\\n\");\n        return NULL;\n    }\n    void * result = malloc(size);\n    if (result == NULL) {\n        GGML_LOG_ERROR(\"%s: failed to allocate %6.2f MB\\n\", __func__, size/(1024.0*1024.0));\n        GGML_ABORT(\"fatal error\");\n    }\n    return result;\n}\n\n// calloc\ninline static void * ggml_calloc(size_t num, size_t size) {\n    if (num == 0 || size == 0) {\n        GGML_LOG_WARN(\"Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\\n\");\n        return NULL;\n    }\n    void * result = calloc(num, size);\n    if (result == NULL) {\n        GGML_LOG_ERROR(\"%s: failed to allocate %6.2f MB\\n\", __func__, size/(1024.0*1024.0));\n        GGML_ABORT(\"fatal error\");\n    }\n    return result;\n}\n\n#define GGML_MALLOC(size)      ggml_malloc(size)\n#define GGML_CALLOC(num, size) ggml_calloc(num, size)\n\n#define GGML_FREE(ptr) free(ptr)\n\nconst char * ggml_status_to_string(enum ggml_status status) {\n    switch (status) {\n        case GGML_STATUS_ALLOC_FAILED: return \"GGML status: error (failed to allocate memory)\";\n        case GGML_STATUS_FAILED:       return \"GGML status: error (operation failed)\";\n        case 
GGML_STATUS_SUCCESS:      return \"GGML status: success\";\n        case GGML_STATUS_ABORTED:      return \"GGML status: warning (operation aborted)\";\n    }\n\n    return \"GGML status: unknown\";\n}\n\nfloat ggml_fp16_to_fp32(ggml_fp16_t x) {\n#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml\n    return GGML_FP16_TO_FP32(x);\n}\n\nggml_fp16_t ggml_fp32_to_fp16(float x) {\n#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml\n    return GGML_FP32_TO_FP16(x);\n}\n\nfloat ggml_bf16_to_fp32(ggml_bf16_t x) {\n#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml\n    return GGML_BF16_TO_FP32(x);  // it just left shifts\n}\n\nggml_bf16_t ggml_fp32_to_bf16(float x) {\n#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml\n    return GGML_FP32_TO_BF16(x);\n}\n\nvoid ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {\n    for (int64_t i = 0; i < n; i++) {\n        y[i] = GGML_FP16_TO_FP32(x[i]);\n    }\n}\n\nvoid ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {\n    int i = 0;\n    for (; i < n; ++i) {\n        y[i] = GGML_FP32_TO_FP16(x[i]);\n    }\n}\n\nvoid ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {\n    int i = 0;\n    for (; i < n; ++i) {\n        y[i] = GGML_BF16_TO_FP32(x[i]);\n    }\n}\n\nvoid ggml_fp32_to_bf16_row_ref(const float * x, ggml_bf16_t * y, int64_t n) {\n    for (int i = 0; i < n; i++) {\n        y[i] = ggml_compute_fp32_to_bf16(x[i]);\n    }\n}\n\nvoid ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {\n  int i = 0;\n#if defined(__AVX512BF16__)\n  // subnormals are flushed to zero on this platform\n  for (; i + 32 <= n; i += 32) {\n        _mm512_storeu_si512(\n            (__m512i *)(y + i),\n            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),\n                                _mm512_loadu_ps(x + i))));\n  }\n#endif\n    for (; i < n; i++) {\n        y[i] = GGML_FP32_TO_BF16(x[i]);\n    }\n}\n\nbool 
ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {\n    return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;\n}\n\nconst char * ggml_version(void) {\n    return GGML_VERSION;\n}\n\nconst char * ggml_commit(void) {\n    return GGML_COMMIT;\n}\n\n//\n// timing\n//\n\n#if defined(_MSC_VER) || defined(__MINGW32__)\nstatic int64_t timer_freq, timer_start;\nvoid ggml_time_init(void) {\n    LARGE_INTEGER t;\n    QueryPerformanceFrequency(&t);\n    timer_freq = t.QuadPart;\n\n    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq\n    // and the uptime is high enough.\n    // We subtract the program start time to reduce the likelihood of that happening.\n    QueryPerformanceCounter(&t);\n    timer_start = t.QuadPart;\n}\nint64_t ggml_time_ms(void) {\n    LARGE_INTEGER t;\n    QueryPerformanceCounter(&t);\n    return ((t.QuadPart-timer_start) * 1000) / timer_freq;\n}\nint64_t ggml_time_us(void) {\n    LARGE_INTEGER t;\n    QueryPerformanceCounter(&t);\n    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;\n}\n#else\nvoid ggml_time_init(void) {}\nint64_t ggml_time_ms(void) {\n    struct timespec ts;\n    clock_gettime(CLOCK_MONOTONIC, &ts);\n    return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;\n}\n\nint64_t ggml_time_us(void) {\n    struct timespec ts;\n    clock_gettime(CLOCK_MONOTONIC, &ts);\n    return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;\n}\n#endif\n\nint64_t ggml_cycles(void) {\n    return clock();\n}\n\nint64_t ggml_cycles_per_ms(void) {\n    return CLOCKS_PER_SEC/1000;\n}\n\n//\n// cross-platform UTF-8 file paths\n//\n\n#ifdef _WIN32\nstatic wchar_t * ggml_mbstowcs(const char * mbs) {\n    int wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, NULL, 0);\n    if (!wlen) {\n        errno = EINVAL;\n        return NULL;\n    }\n\n    wchar_t * wbuf = GGML_MALLOC(wlen * sizeof(wchar_t));\n    wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, wbuf, wlen);\n    if (!wlen) {\n        
GGML_FREE(wbuf);\n        errno = EINVAL;\n        return NULL;\n    }\n\n    return wbuf;\n}\n#endif\n\nFILE * ggml_fopen(const char * fname, const char * mode) {\n#ifdef _WIN32\n    FILE * file = NULL;\n\n    // convert fname (UTF-8)\n    wchar_t * wfname = ggml_mbstowcs(fname);\n    if (wfname) {\n        // convert mode (ANSI)\n        wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t));\n        wchar_t * wmode_p = wmode;\n        do {\n            *wmode_p++ = (wchar_t)*mode;\n        } while (*mode++);\n\n        // open file\n        file = _wfopen(wfname, wmode);\n\n        GGML_FREE(wfname);\n        GGML_FREE(wmode);\n    }\n\n    return file;\n#else\n    return fopen(fname, mode);\n#endif\n\n}\n\nstatic const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {\n    [GGML_TYPE_I8] = {\n        .type_name                = \"i8\",\n        .blck_size                = 1,\n        .type_size                = sizeof(int8_t),\n        .is_quantized             = false,\n    },\n    [GGML_TYPE_I16] = {\n        .type_name                = \"i16\",\n        .blck_size                = 1,\n        .type_size                = sizeof(int16_t),\n        .is_quantized             = false,\n    },\n    [GGML_TYPE_I32] = {\n        .type_name                = \"i32\",\n        .blck_size                = 1,\n        .type_size                = sizeof(int32_t),\n        .is_quantized             = false,\n    },\n    [GGML_TYPE_I64] = {\n        .type_name                = \"i64\",\n        .blck_size                = 1,\n        .type_size                = sizeof(int64_t),\n        .is_quantized             = false,\n    },\n    [GGML_TYPE_F64] = {\n        .type_name                = \"f64\",\n        .blck_size                = 1,\n        .type_size                = sizeof(double),\n        .is_quantized             = false,\n    },\n    [GGML_TYPE_F32] = {\n        .type_name                = \"f32\",\n        .blck_size                = 1,\n   
     .type_size                = sizeof(float),\n        .is_quantized             = false,\n    },\n    [GGML_TYPE_F16] = {\n        .type_name                = \"f16\",\n        .blck_size                = 1,\n        .type_size                = sizeof(ggml_fp16_t),\n        .is_quantized             = false,\n        .to_float                 = (ggml_to_float_t) ggml_fp16_to_fp32_row,\n        .from_float_ref           = (ggml_from_float_t) ggml_fp32_to_fp16_row,\n    },\n    [GGML_TYPE_Q4_0] = {\n        .type_name                = \"q4_0\",\n        .blck_size                = QK4_0,\n        .type_size                = sizeof(block_q4_0),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q4_0,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_0_ref,\n    },\n    [GGML_TYPE_Q4_1] = {\n        .type_name                = \"q4_1\",\n        .blck_size                = QK4_1,\n        .type_size                = sizeof(block_q4_1),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_1_ref,\n    },\n    [4] = { // GGML_TYPE_Q4_2\n        .type_name                = \"DEPRECATED\",\n        .blck_size                = 0,\n        .type_size                = 0,\n        .is_quantized             = false,\n    },\n    [5] = { // GGML_TYPE_Q4_3\n        .type_name                = \"DEPRECATED\",\n        .blck_size                = 0,\n        .type_size                = 0,\n        .is_quantized             = false,\n    },\n    [GGML_TYPE_Q5_0] = {\n        .type_name                = \"q5_0\",\n        .blck_size                = QK5_0,\n        .type_size                = sizeof(block_q5_0),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q5_0,\n        .from_float_ref          
 = (ggml_from_float_t) quantize_row_q5_0_ref,\n    },\n    [GGML_TYPE_Q5_1] = {\n        .type_name                = \"q5_1\",\n        .blck_size                = QK5_1,\n        .type_size                = sizeof(block_q5_1),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q5_1,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q5_1_ref,\n    },\n    [GGML_TYPE_Q8_0] = {\n        .type_name                = \"q8_0\",\n        .blck_size                = QK8_0,\n        .type_size                = sizeof(block_q8_0),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q8_0,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q8_0_ref,\n    },\n    [GGML_TYPE_Q8_1] = {\n        .type_name                = \"q8_1\",\n        .blck_size                = QK8_1,\n        .type_size                = sizeof(block_q8_1),\n        .is_quantized             = true,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q8_1_ref,\n    },\n    [GGML_TYPE_MXFP4] = {\n        .type_name                = \"mxfp4\",\n        .blck_size                = QK_MXFP4,\n        .type_size                = sizeof(block_mxfp4),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_mxfp4,\n        .from_float_ref           = (ggml_from_float_t)quantize_row_mxfp4_ref,\n    },\n    [GGML_TYPE_NVFP4] = {\n        .type_name                = \"nvfp4\",\n        .blck_size                = QK_NVFP4,\n        .type_size                = sizeof(block_nvfp4),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_nvfp4,\n        .from_float_ref           = (ggml_from_float_t)quantize_row_nvfp4_ref,\n    },\n    [GGML_TYPE_Q2_K] = {\n        .type_name                = \"q2_K\",\n        .blck_size           
     = QK_K,\n        .type_size                = sizeof(block_q2_K),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q2_K,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q2_K_ref,\n    },\n    [GGML_TYPE_Q3_K] = {\n        .type_name                = \"q3_K\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_q3_K),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q3_K,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q3_K_ref,\n    },\n    [GGML_TYPE_Q4_K] = {\n        .type_name                = \"q4_K\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_q4_K),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q4_K,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_K_ref,\n    },\n    [GGML_TYPE_Q5_K] = {\n        .type_name                = \"q5_K\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_q5_K),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q5_K,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q5_K_ref,\n    },\n    [GGML_TYPE_Q6_K] = {\n        .type_name                = \"q6_K\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_q6_K),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_q6_K,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_q6_K_ref,\n    },\n    [GGML_TYPE_IQ2_XXS] = {\n        .type_name                = \"iq2_xxs\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_iq2_xxs),\n        
.is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xxs,\n        .from_float_ref           = NULL,\n    },\n    [GGML_TYPE_IQ2_XS] = {\n        .type_name                = \"iq2_xs\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_iq2_xs),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xs,\n        .from_float_ref           = NULL,\n    },\n    [GGML_TYPE_IQ3_XXS] = {\n        .type_name                = \"iq3_xxs\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_iq3_xxs),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_xxs,\n        .from_float_ref           = (ggml_from_float_t)quantize_row_iq3_xxs_ref,\n    },\n    [GGML_TYPE_IQ3_S] = {\n        .type_name                = \"iq3_s\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_iq3_s),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_s,\n        .from_float_ref           = (ggml_from_float_t)quantize_row_iq3_s_ref,\n    },\n    [GGML_TYPE_IQ2_S] = {\n        .type_name                = \"iq2_s\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_iq2_s),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_s,\n        .from_float_ref           = (ggml_from_float_t)quantize_row_iq2_s_ref,\n    },\n    [GGML_TYPE_IQ1_S] = {\n        .type_name                = \"iq1_s\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_iq1_s),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_s,\n        .from_float_ref   
        = NULL,\n    },\n    [GGML_TYPE_IQ1_M] = {\n        .type_name                = \"iq1_m\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_iq1_m),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_m,\n        .from_float_ref           = NULL,\n    },\n    [GGML_TYPE_IQ4_NL] = {\n        .type_name                = \"iq4_nl\",\n        .blck_size                = QK4_NL,\n        .type_size                = sizeof(block_iq4_nl),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq4_nl,\n        .from_float_ref           = (ggml_from_float_t)quantize_row_iq4_nl_ref,\n    },\n    [GGML_TYPE_IQ4_XS] = {\n        .type_name                = \"iq4_xs\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_iq4_xs),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_iq4_xs,\n        .from_float_ref           = (ggml_from_float_t)quantize_row_iq4_xs_ref,\n    },\n    [GGML_TYPE_Q8_K] = {\n        .type_name                = \"q8_K\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_q8_K),\n        .is_quantized             = true,\n    },\n    [GGML_TYPE_BF16] = {\n        .type_name                = \"bf16\",\n        .blck_size                = 1,\n        .type_size                = sizeof(ggml_bf16_t),\n        .is_quantized             = false,\n        .to_float                 = (ggml_to_float_t) ggml_bf16_to_fp32_row,\n        .from_float_ref           = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,\n    },\n    [31] = { // GGML_TYPE_Q4_0_4_4\n        .type_name                = \"TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking\",\n        .blck_size                = 0,\n        .type_size                = 0,\n        .is_quantized     
        = false,\n    },\n    [32] = { // GGML_TYPE_Q4_0_4_8\n        .type_name                = \"TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking\",\n        .blck_size                = 0,\n        .type_size                = 0,\n        .is_quantized             = false,\n    },\n    [33] = { // GGML_TYPE_Q4_0_8_8\n        .type_name                = \"TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking\",\n        .blck_size                = 0,\n        .type_size                = 0,\n        .is_quantized             = false,\n    },\n    [GGML_TYPE_TQ1_0] = {\n        .type_name                = \"tq1_0\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_tq1_0),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_tq1_0,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_tq1_0_ref,\n    },\n    [GGML_TYPE_TQ2_0] = {\n        .type_name                = \"tq2_0\",\n        .blck_size                = QK_K,\n        .type_size                = sizeof(block_tq2_0),\n        .is_quantized             = true,\n        .to_float                 = (ggml_to_float_t) dequantize_row_tq2_0,\n        .from_float_ref           = (ggml_from_float_t) quantize_row_tq2_0_ref,\n    },\n    [36] = { // GGML_TYPE_IQ4_NL_4_4\n        .type_name                = \"TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking\",\n        .blck_size                = 0,\n        .type_size                = 0,\n        .is_quantized             = false,\n    },\n    [37] = { // GGML_TYPE_IQ4_NL_4_8\n        .type_name                = \"TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking\",\n        .blck_size                = 0,\n        .type_size                = 0,\n        .is_quantized             = false,\n    },\n    [38] = { // GGML_TYPE_IQ4_NL_8_8\n        .type_name                = \"TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime 
repacking\",\n        .blck_size                = 0,\n        .type_size                = 0,\n        .is_quantized             = false,\n    },\n};\n\nconst struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {\n    assert(type >= 0);\n    assert(type < GGML_TYPE_COUNT);\n    return &type_traits[type];\n}\n\n//\n// ggml object\n//\n\nstruct ggml_object {\n    size_t offs;\n    size_t size;\n\n    struct ggml_object * next;\n\n    enum ggml_object_type type;\n\n    char padding[4];\n};\n\nstatic const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);\n\n//\n// ggml context\n//\n\nstruct ggml_context {\n    size_t mem_size;\n    void * mem_buffer;\n    bool   mem_buffer_owned;\n    bool   no_alloc;\n\n    int    n_objects;\n\n    struct ggml_object * objects_begin;\n    struct ggml_object * objects_end;\n};\n\n//\n// data types\n//\n\nstatic const char * GGML_OP_NAME[GGML_OP_COUNT] = {\n    \"NONE\",\n\n    \"DUP\",\n    \"ADD\",\n    \"ADD_ID\",\n    \"ADD1\",\n    \"ACC\",\n    \"SUB\",\n    \"MUL\",\n    \"DIV\",\n    \"SQR\",\n    \"SQRT\",\n    \"LOG\",\n    \"SIN\",\n    \"COS\",\n    \"SUM\",\n    \"SUM_ROWS\",\n    \"CUMSUM\",\n    \"MEAN\",\n    \"ARGMAX\",\n    \"COUNT_EQUAL\",\n    \"REPEAT\",\n    \"REPEAT_BACK\",\n    \"CONCAT\",\n    \"SILU_BACK\",\n    \"NORM\",\n    \"RMS_NORM\",\n    \"RMS_NORM_BACK\",\n    \"GROUP_NORM\",\n    \"L2_NORM\",\n\n    \"MUL_MAT\",\n    \"MUL_MAT_ID\",\n    \"OUT_PROD\",\n\n    \"SCALE\",\n    \"SET\",\n    \"CPY\",\n    \"CONT\",\n    \"RESHAPE\",\n    \"VIEW\",\n    \"PERMUTE\",\n    \"TRANSPOSE\",\n    \"GET_ROWS\",\n    \"GET_ROWS_BACK\",\n    \"SET_ROWS\",\n    \"DIAG\",\n    \"DIAG_MASK_INF\",\n    \"DIAG_MASK_ZERO\",\n    \"SOFT_MAX\",\n    \"SOFT_MAX_BACK\",\n    \"ROPE\",\n    \"ROPE_BACK\",\n    \"CLAMP\",\n    \"CONV_TRANSPOSE_1D\",\n    \"IM2COL\",\n    \"IM2COL_BACK\",\n    \"IM2COL_3D\",\n    \"CONV_2D\",\n    \"CONV_3D\",\n    \"CONV_2D_DW\",\n    \"CONV_TRANSPOSE_2D\",\n    
\"POOL_1D\",\n    \"POOL_2D\",\n    \"POOL_2D_BACK\",\n    \"UPSCALE\",\n    \"PAD\",\n    \"PAD_REFLECT_1D\",\n    \"ROLL\",\n    \"ARANGE\",\n    \"TIMESTEP_EMBEDDING\",\n    \"ARGSORT\",\n    \"TOP_K\",\n    \"LEAKY_RELU\",\n    \"TRI\",\n    \"FILL\",\n\n    \"FLASH_ATTN_EXT\",\n    \"FLASH_ATTN_BACK\",\n    \"SSM_CONV\",\n    \"SSM_SCAN\",\n    \"WIN_PART\",\n    \"WIN_UNPART\",\n    \"GET_REL_POS\",\n    \"ADD_REL_POS\",\n    \"RWKV_WKV6\",\n    \"GATED_LINEAR_ATTN\",\n    \"RWKV_WKV7\",\n    \"SOLVE_TRI\",\n    \"GATED_DELTA_NET\",\n\n    \"UNARY\",\n\n    \"MAP_CUSTOM1\",\n    \"MAP_CUSTOM2\",\n    \"MAP_CUSTOM3\",\n\n    \"CUSTOM\",\n\n    \"CROSS_ENTROPY_LOSS\",\n    \"CROSS_ENTROPY_LOSS_BACK\",\n    \"OPT_STEP_ADAMW\",\n    \"OPT_STEP_SGD\",\n\n    \"GLU\",\n};\n\nstatic_assert(GGML_OP_COUNT == 96, \"GGML_OP_COUNT != 96\");\n\nstatic const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {\n    \"none\",\n\n    \"x\",\n    \"x+y\",\n    \"x[i]+y\",\n    \"x+y\",\n    \"view(x,nb,offset)+=y->x\",\n    \"x-y\",\n    \"x*y\",\n    \"x/y\",\n    \"x^2\",\n    \"√x\",\n    \"log(x)\",\n    \"sin(x)\",\n    \"cos(x)\",\n    \"Σx\",\n    \"Σx_k\",\n    \"cumsum(x)\",\n    \"Σx/n\",\n    \"argmax(x)\",\n    \"count_equal(x)\",\n    \"repeat(x)\",\n    \"repeat_back(x)\",\n    \"concat(x, y)\",\n    \"silu_back(x)\",\n    \"norm(x)\",\n    \"rms_norm(x)\",\n    \"rms_norm_back(x)\",\n    \"group_norm(x)\",\n    \"l2_norm(x)\",\n\n    \"X*Y\",\n    \"X[i]*Y\",\n    \"X*Y\",\n\n    \"x*v\",\n    \"y-\\\\>view(x)\",\n    \"x-\\\\>y\",\n    \"cont(x)\",\n    \"reshape(x)\",\n    \"view(x)\",\n    \"permute(x)\",\n    \"transpose(x)\",\n    \"get_rows(x)\",\n    \"get_rows_back(x)\",\n    \"set_rows(x)\",\n    \"diag(x)\",\n    \"diag_mask_inf(x)\",\n    \"diag_mask_zero(x)\",\n    \"soft_max(x)\",\n    \"soft_max_back(x)\",\n    \"rope(x)\",\n    \"rope_back(x)\",\n    \"clamp(x)\",\n    \"conv_transpose_1d(x)\",\n    \"im2col(x)\",\n    \"im2col_back(x)\",\n    
\"im2col_3d(x)\",\n    \"conv_2d(x)\",\n    \"conv_3d(x)\",\n    \"conv_2d_dw(x)\",\n    \"conv_transpose_2d(x)\",\n    \"pool_1d(x)\",\n    \"pool_2d(x)\",\n    \"pool_2d_back(x)\",\n    \"upscale(x)\",\n    \"pad(x)\",\n    \"pad_reflect_1d(x)\",\n    \"roll(x)\",\n    \"arange(start, stop, step)\",\n    \"timestep_embedding(timesteps, dim, max_period)\",\n    \"argsort(x)\",\n    \"top_k(x)\",\n    \"leaky_relu(x)\",\n    \"tri(x)\",\n    \"fill(x, c)\",\n\n    \"flash_attn_ext(x)\",\n    \"flash_attn_back(x)\",\n    \"ssm_conv(x)\",\n    \"ssm_scan(x)\",\n    \"win_part(x)\",\n    \"win_unpart(x)\",\n    \"get_rel_pos(x)\",\n    \"add_rel_pos(x)\",\n    \"rwkv_wkv6(k, v, r, tf, td, s)\",\n    \"gated_linear_attn(k, v, q, gate, s)\",\n    \"rwkv_wkv7(r, w, k, v, a, b, s)\",\n    \"A X = B, A triangular, solve X\",\n    \"gated_delta_net(q, k, v, g, beta, s)\",\n\n    \"unary(x)\",\n\n    \"map_custom(x)\",\n    \"map_custom(x,y)\",\n    \"map_custom(x,y,z)\",\n\n    \"custom(x)\",\n\n    \"cross_entropy_loss(x,y)\",\n    \"cross_entropy_loss_back(x,y)\",\n    \"adamw(x)\",\n    \"sgd(x)\",\n\n    \"glu(x)\",\n};\n\nstatic_assert(GGML_OP_COUNT == 96, \"GGML_OP_COUNT != 96\");\n\nstatic_assert(GGML_OP_POOL_COUNT == 2, \"GGML_OP_POOL_COUNT != 2\");\n\nstatic const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {\n    \"ABS\",\n    \"SGN\",\n    \"NEG\",\n    \"STEP\",\n    \"TANH\",\n    \"ELU\",\n    \"RELU\",\n    \"SIGMOID\",\n    \"GELU\",\n    \"GELU_QUICK\",\n    \"SILU\",\n    \"HARDSWISH\",\n    \"HARDSIGMOID\",\n    \"EXP\",\n    \"EXPM1\",\n    \"SOFTPLUS\",\n    \"GELU_ERF\",\n    \"XIELU\",\n    \"FLOOR\",\n    \"CEIL\",\n    \"ROUND\",\n    \"TRUNC\",\n};\n\nstatic_assert(GGML_UNARY_OP_COUNT == 22, \"GGML_UNARY_OP_COUNT != 22\");\n\nstatic const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {\n    \"REGLU\",\n    \"GEGLU\",\n    \"SWIGLU\",\n    \"SWIGLU_OAI\",\n    \"GEGLU_ERF\",\n    \"GEGLU_QUICK\",\n};\n\nstatic_assert(GGML_GLU_OP_COUNT == 6, 
\"GGML_GLU_OP_COUNT != 6\");\n\n\nstatic_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, \"ggml_object size must be a multiple of GGML_MEM_ALIGN\");\nstatic_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, \"ggml_tensor size must be a multiple of GGML_MEM_ALIGN\");\n\n\n////////////////////////////////////////////////////////////////////////////////\n\nvoid ggml_print_object(const struct ggml_object * obj) {\n    GGML_LOG_INFO(\" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\\n\",\n            obj->type, obj->offs, obj->size, (const void *) obj->next);\n}\n\nvoid ggml_print_objects(const struct ggml_context * ctx) {\n    struct ggml_object * obj = ctx->objects_begin;\n\n    GGML_LOG_INFO(\"%s: objects in context %p:\\n\", __func__, (const void *) ctx);\n\n    while (obj != NULL) {\n        ggml_print_object(obj);\n        obj = obj->next;\n    }\n\n    GGML_LOG_INFO(\"%s: --- end ---\\n\", __func__);\n}\n\nint64_t ggml_nelements(const struct ggml_tensor * tensor) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];\n}\n\nint64_t ggml_nrows(const struct ggml_tensor * tensor) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];\n}\n\nsize_t ggml_nbytes(const struct ggml_tensor * tensor) {\n    for (int i = 0; i < GGML_MAX_DIMS; ++i) {\n        if (tensor->ne[i] <= 0) {\n            return 0;\n        }\n    }\n\n    size_t nbytes;\n    const size_t blck_size = ggml_blck_size(tensor->type);\n    if (blck_size == 1) {\n        nbytes = ggml_type_size(tensor->type);\n        for (int i = 0; i < GGML_MAX_DIMS; ++i) {\n            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];\n        }\n    }\n    else {\n        nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;\n        for (int i = 1; i < GGML_MAX_DIMS; ++i) {\n            nbytes += 
(tensor->ne[i] - 1)*tensor->nb[i];\n        }\n    }\n\n    return nbytes;\n}\n\nsize_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {\n    return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);\n}\n\nint64_t ggml_blck_size(enum ggml_type type) {\n    assert(type >= 0);\n    assert(type < GGML_TYPE_COUNT);\n    return type_traits[type].blck_size;\n}\n\nsize_t ggml_type_size(enum ggml_type type) {\n    assert(type >= 0);\n    assert(type < GGML_TYPE_COUNT);\n    return type_traits[type].type_size;\n}\n\nsize_t ggml_row_size(enum ggml_type type, int64_t ne) {\n    assert(type >= 0);\n    assert(type < GGML_TYPE_COUNT);\n    assert(ne % ggml_blck_size(type) == 0);\n    return ggml_type_size(type)*ne/ggml_blck_size(type);\n}\n\ndouble ggml_type_sizef(enum ggml_type type) {\n    assert(type >= 0);\n    assert(type < GGML_TYPE_COUNT);\n    return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;\n}\n\nconst char * ggml_type_name(enum ggml_type type) {\n    assert(type >= 0);\n    assert(type < GGML_TYPE_COUNT);\n    return type_traits[type].type_name;\n}\n\nbool ggml_is_quantized(enum ggml_type type) {\n    assert(type >= 0);\n    assert(type < GGML_TYPE_COUNT);\n    return type_traits[type].is_quantized;\n}\n\nconst char * ggml_op_name(enum ggml_op op) {\n    return GGML_OP_NAME[op];\n}\n\nconst char * ggml_op_symbol(enum ggml_op op) {\n    return GGML_OP_SYMBOL[op];\n}\n\nconst char * ggml_unary_op_name(enum ggml_unary_op op) {\n    return GGML_UNARY_OP_NAME[op];\n}\n\nconst char * ggml_glu_op_name(enum ggml_glu_op op) {\n    return GGML_GLU_OP_NAME[op];\n}\n\nconst char * ggml_op_desc(const struct ggml_tensor * t) {\n    if (t->op == GGML_OP_UNARY) {\n        enum ggml_unary_op uop = ggml_get_unary_op(t);\n        return ggml_unary_op_name(uop);\n    }\n    if (t->op == GGML_OP_GLU) {\n        enum ggml_glu_op gop = ggml_get_glu_op(t);\n        return ggml_glu_op_name(gop);\n    }\n    return ggml_op_name(t->op);\n}\n\nsize_t 
ggml_element_size(const struct ggml_tensor * tensor) {\n    return ggml_type_size(tensor->type);\n}\n\nbool ggml_is_scalar(const struct ggml_tensor * tensor) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;\n}\n\nbool ggml_is_vector(const struct ggml_tensor * tensor) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;\n}\n\nbool ggml_is_matrix(const struct ggml_tensor * tensor) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return tensor->ne[2] == 1 && tensor->ne[3] == 1;\n}\n\nbool ggml_is_3d(const struct ggml_tensor * tensor) {\n    return tensor->ne[3] == 1;\n}\n\nint ggml_n_dims(const struct ggml_tensor * tensor) {\n    for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) {\n        if (tensor->ne[i] > 1) {\n            return i + 1;\n        }\n    }\n    return 1;\n}\n\nenum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {\n    enum ggml_type wtype = GGML_TYPE_COUNT;\n\n    switch (ftype) {\n        case GGML_FTYPE_ALL_F32:              wtype = GGML_TYPE_F32;   break;\n        case GGML_FTYPE_MOSTLY_F16:           wtype = GGML_TYPE_F16;   break;\n        case GGML_FTYPE_MOSTLY_BF16:          wtype = GGML_TYPE_BF16;  break;\n        case GGML_FTYPE_MOSTLY_Q4_0:          wtype = GGML_TYPE_Q4_0;  break;\n        case GGML_FTYPE_MOSTLY_Q4_1:          wtype = GGML_TYPE_Q4_1;  break;\n        case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;\n        case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;\n        case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;\n        case GGML_FTYPE_MOSTLY_MXFP4:         wtype = GGML_TYPE_MXFP4; break;\n        case GGML_FTYPE_MOSTLY_NVFP4:         wtype = 
GGML_TYPE_NVFP4; break;\n        case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;\n        case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;\n        case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;\n        case GGML_FTYPE_MOSTLY_Q5_K:          wtype = GGML_TYPE_Q5_K;  break;\n        case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;  break;\n        case GGML_FTYPE_MOSTLY_IQ2_XXS:       wtype = GGML_TYPE_IQ2_XXS;  break;\n        case GGML_FTYPE_MOSTLY_IQ2_XS:        wtype = GGML_TYPE_IQ2_XS;   break;\n        case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;\n        case GGML_FTYPE_MOSTLY_IQ1_S:         wtype = GGML_TYPE_IQ1_S;    break;\n        case GGML_FTYPE_MOSTLY_IQ1_M:         wtype = GGML_TYPE_IQ1_M;    break;\n        case GGML_FTYPE_MOSTLY_IQ4_NL:        wtype = GGML_TYPE_IQ4_NL;   break;\n        case GGML_FTYPE_MOSTLY_IQ4_XS:        wtype = GGML_TYPE_IQ4_XS;   break;\n        case GGML_FTYPE_MOSTLY_IQ3_S:         wtype = GGML_TYPE_IQ3_S;    break;\n        case GGML_FTYPE_MOSTLY_IQ2_S:         wtype = GGML_TYPE_IQ2_S;    break;\n        case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;\n        case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;\n    }\n\n    GGML_ASSERT(wtype != GGML_TYPE_COUNT);\n\n    return wtype;\n}\n\nsize_t ggml_tensor_overhead(void) {\n    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;\n}\n\nbool ggml_is_transposed(const struct ggml_tensor * tensor) {\n    return tensor->nb[0] > tensor->nb[1];\n}\n\nstatic bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {\n    size_t next_nb = ggml_type_size(tensor->type);\n    if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {\n        return false;\n    }\n    next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);\n    for (int i = 1; i < GGML_MAX_DIMS; i++) {\n        if (i > n) {\n            if 
(tensor->ne[i] != 1 && tensor->nb[i] != next_nb) {\n                return false;\n            }\n            next_nb *= tensor->ne[i];\n        } else {\n            // this dimension does not need to be contiguous\n            next_nb = tensor->ne[i]*tensor->nb[i];\n        }\n    }\n    return true;\n}\n\nbool ggml_is_contiguous(const struct ggml_tensor * tensor) {\n    return ggml_is_contiguous_0(tensor);\n}\n\nbool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {\n    return ggml_is_contiguous_n(tensor, 0);\n}\n\nbool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {\n    return ggml_is_contiguous_n(tensor, 1);\n}\n\nbool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {\n    return ggml_is_contiguous_n(tensor, 2);\n}\n\nbool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {\n    return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);\n}\n\nbool ggml_is_permuted(const struct ggml_tensor * tensor) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];\n}\n\nbool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {\n    return\n        tensor->nb[0] > tensor->nb[2] &&\n        tensor->nb[1] > tensor->nb[0] &&\n        tensor->nb[2] == ggml_type_size(tensor->type);\n}\n\nbool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {\n    return\n        tensor->ne[0] == ggml_blck_size(tensor->type) ||\n        tensor->nb[0] == ggml_type_size(tensor->type);\n}\n\nstatic inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return\n        tensor->nb[0] == ggml_type_size(tensor->type) &&\n        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&\n        tensor->nb[3] == 
tensor->nb[2]*tensor->ne[2];\n}\n\nbool ggml_is_empty(const struct ggml_tensor * tensor) {\n    for (int i = 0; i < GGML_MAX_DIMS; ++i) {\n        if (tensor->ne[i] == 0) {\n            // empty if any dimension has no elements\n            return true;\n        }\n    }\n    return false;\n}\n\nbool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return\n        (t0->ne[0] == t1->ne[0]) &&\n        (t0->ne[1] == t1->ne[1]) &&\n        (t0->ne[2] == t1->ne[2]) &&\n        (t0->ne[3] == t1->ne[3]);\n}\n\nbool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return\n        (t0->nb[0] == t1->nb[0]) &&\n        (t0->nb[1] == t1->nb[1]) &&\n        (t0->nb[2] == t1->nb[2]) &&\n        (t0->nb[3] == t1->nb[3]);\n}\n\nbool ggml_is_view(const struct ggml_tensor * t) {\n    return ggml_impl_is_view(t);\n}\n\n// check if t1 can be represented as a repetition of t0\nbool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return ggml_is_empty(t0) ? 
ggml_is_empty(t1) :\n        (t1->ne[0]%t0->ne[0] == 0) &&\n        (t1->ne[1]%t0->ne[1] == 0) &&\n        (t1->ne[2]%t0->ne[2] == 0) &&\n        (t1->ne[3]%t0->ne[3] == 0);\n}\n\nstatic inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);\n}\n\n// assert that pointer is aligned to GGML_MEM_ALIGN\n#define GGML_ASSERT_ALIGNED(ptr) \\\n    GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)\n\n////////////////////////////////////////////////////////////////////////////////\n\nstruct ggml_context * ggml_init(struct ggml_init_params params) {\n    static bool is_first_call = true;\n\n    ggml_critical_section_start();\n\n    if (is_first_call) {\n        // initialize time system (required on Windows)\n        ggml_time_init();\n\n        is_first_call = false;\n    }\n\n    ggml_critical_section_end();\n\n    struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));\n\n    // allow to call ggml_init with 0 size\n    if (params.mem_size == 0) {\n        params.mem_size = GGML_MEM_ALIGN;\n    }\n\n    const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);\n\n    *ctx = (struct ggml_context) {\n        /*.mem_size           =*/ mem_size,\n        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size),\n        /*.mem_buffer_owned   =*/ params.mem_buffer ? 
false : true,\n        /*.no_alloc           =*/ params.no_alloc,\n        /*.n_objects          =*/ 0,\n        /*.objects_begin      =*/ NULL,\n        /*.objects_end        =*/ NULL,\n    };\n\n    GGML_ASSERT(ctx->mem_buffer != NULL);\n\n    GGML_ASSERT_ALIGNED(ctx->mem_buffer);\n\n    GGML_PRINT_DEBUG(\"%s: context initialized\\n\", __func__);\n\n    return ctx;\n}\n\nvoid ggml_reset(struct ggml_context * ctx) {\n    if (ctx == NULL) {\n        return;\n    }\n\n    ctx->n_objects     = 0;\n    ctx->objects_begin = NULL;\n    ctx->objects_end   = NULL;\n}\n\nvoid ggml_free(struct ggml_context * ctx) {\n    if (ctx == NULL) {\n        return;\n    }\n\n    if (ctx->mem_buffer_owned) {\n        ggml_aligned_free(ctx->mem_buffer, ctx->mem_size);\n    }\n\n    GGML_FREE(ctx);\n}\n\nsize_t ggml_used_mem(const struct ggml_context * ctx) {\n    return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;\n}\n\nbool ggml_get_no_alloc(struct ggml_context * ctx) {\n    return ctx->no_alloc;\n}\n\nvoid ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {\n    ctx->no_alloc = no_alloc;\n}\n\nvoid * ggml_get_mem_buffer(const struct ggml_context * ctx) {\n    return ctx->mem_buffer;\n}\n\nsize_t ggml_get_mem_size(const struct ggml_context * ctx) {\n    return ctx->mem_size;\n}\n\nsize_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {\n    size_t max_size = 0;\n\n    for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) {\n        size_t bytes = ggml_nbytes(tensor);\n        max_size = MAX(max_size, bytes);\n    }\n\n    return max_size;\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\nstatic struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {\n    // always insert objects at the end of the context's memory pool\n    struct ggml_object * obj_cur = 
ctx->objects_end;\n\n    const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;\n    const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;\n    const size_t cur_end  = cur_offs + cur_size;\n\n    // align to GGML_MEM_ALIGN\n    GGML_ASSERT(size <= SIZE_MAX - (GGML_MEM_ALIGN - 1));\n    size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);\n\n    char * const mem_buffer = ctx->mem_buffer;\n    struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);\n\n    // integer overflow checks\n    if (cur_end > SIZE_MAX - size_needed) {\n        GGML_LOG_WARN(\"%s: overflow detected in cur_end (%zu) + size_needed (%zu)\\n\", __func__, cur_end, size_needed);\n        return NULL;\n    }\n    if (cur_end + size_needed > SIZE_MAX - GGML_OBJECT_SIZE) {\n        GGML_LOG_WARN(\"%s: overflow detected in cur_end (%zu) + size_needed (%zu) + GGML_OBJECT_SIZE (%zu)\\n\", __func__,\n                cur_end, size_needed, (size_t) GGML_OBJECT_SIZE);\n        return NULL;\n    }\n\n    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {\n        GGML_LOG_WARN(\"%s: not enough space in the context's memory pool (needed %zu, available %zu)\\n\",\n                __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);\n#ifndef NDEBUG\n        GGML_ABORT(\"not enough space in the context's memory pool\");\n#endif\n        return NULL;\n    }\n\n    *obj_new = (struct ggml_object) {\n        .offs = cur_end + GGML_OBJECT_SIZE,\n        .size = size_needed,\n        .next = NULL,\n        .type = type,\n    };\n\n    GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs);\n\n    if (obj_cur != NULL) {\n        obj_cur->next = obj_new;\n    } else {\n        // this is the first object in this context\n        ctx->objects_begin = obj_new;\n    }\n\n    ctx->objects_end = obj_new;\n\n    //printf(\"%s: inserted new object at %zu, size = %zu\\n\", __func__, cur_end, obj_new->size);\n\n    return obj_new;\n}\n\nstatic struct ggml_tensor * 
ggml_new_tensor_impl(\n        struct ggml_context * ctx,\n        enum   ggml_type      type,\n        int                   n_dims,\n        const int64_t       * ne,\n        struct ggml_tensor  * view_src,\n        size_t                view_offs) {\n\n    GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT);\n    GGML_ASSERT(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);\n\n    // find the base tensor and absolute offset\n    if (view_src != NULL && view_src->view_src != NULL) {\n        view_offs += view_src->view_offs;\n        view_src   = view_src->view_src;\n    }\n\n    size_t data_size = ggml_row_size(type, ne[0]);\n    for (int i = 1; i < n_dims; i++) {\n        data_size *= ne[i];\n    }\n\n    GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src));\n\n    void * data = view_src != NULL ? view_src->data : NULL;\n    if (data != NULL) {\n        data = (char *) data + view_offs;\n    }\n\n    size_t obj_alloc_size = 0;\n\n    if (view_src == NULL && !ctx->no_alloc) {\n        // allocate tensor data in the context's memory pool\n        obj_alloc_size = data_size;\n    }\n\n    GGML_ASSERT(GGML_TENSOR_SIZE <= SIZE_MAX - obj_alloc_size);\n\n    struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);\n    GGML_ASSERT(obj_new);\n\n    struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);\n\n    *result = (struct ggml_tensor) {\n        /*.type         =*/ type,\n        /*.buffer       =*/ NULL,\n        /*.ne           =*/ { 1, 1, 1, 1 },\n        /*.nb           =*/ { 0, 0, 0, 0 },\n        /*.op           =*/ GGML_OP_NONE,\n        /*.op_params    =*/ { 0 },\n        /*.flags        =*/ 0,\n        /*.src          =*/ { NULL },\n        /*.view_src     =*/ view_src,\n        /*.view_offs    =*/ view_offs,\n        /*.data         =*/ obj_alloc_size > 0 ? 
(void *)(result + 1) : data,\n        /*.name         =*/ { 0 },\n        /*.extra        =*/ NULL,\n        /*.padding      =*/ { 0 },\n    };\n\n    // TODO: this should not be needed as long as we don't rely on aligned SIMD loads\n    //GGML_ASSERT_ALIGNED(result->data);\n\n    for (int i = 0; i < n_dims; i++) {\n        result->ne[i] = ne[i];\n    }\n\n    result->nb[0] = ggml_type_size(type);\n    result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));\n    for (int i = 2; i < GGML_MAX_DIMS; i++) {\n        result->nb[i] = result->nb[i - 1]*result->ne[i - 1];\n    }\n\n    ctx->n_objects++;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_new_tensor(\n        struct ggml_context * ctx,\n        enum   ggml_type      type,\n        int                   n_dims,\n        const int64_t       * ne) {\n    return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0);\n}\n\nstruct ggml_tensor * ggml_new_tensor_1d(\n        struct ggml_context * ctx,\n        enum   ggml_type      type,\n        int64_t ne0) {\n    return ggml_new_tensor(ctx, type, 1, &ne0);\n}\n\nstruct ggml_tensor * ggml_new_tensor_2d(\n        struct ggml_context * ctx,\n        enum   ggml_type      type,\n        int64_t ne0,\n        int64_t ne1) {\n    const int64_t ne[2] = { ne0, ne1 };\n    return ggml_new_tensor(ctx, type, 2, ne);\n}\n\nstruct ggml_tensor * ggml_new_tensor_3d(\n        struct ggml_context * ctx,\n        enum   ggml_type      type,\n        int64_t ne0,\n        int64_t ne1,\n        int64_t ne2) {\n    const int64_t ne[3] = { ne0, ne1, ne2 };\n    return ggml_new_tensor(ctx, type, 3, ne);\n}\n\nstruct ggml_tensor * ggml_new_tensor_4d(\n        struct ggml_context * ctx,\n        enum   ggml_type type,\n        int64_t ne0,\n        int64_t ne1,\n        int64_t ne2,\n        int64_t ne3) {\n    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };\n    return ggml_new_tensor(ctx, type, 4, ne);\n}\n\nvoid * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes) 
{\n    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, nbytes);\n\n    return (uint8_t *)ctx->mem_buffer + obj->offs;\n}\n\nstruct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {\n    return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne);\n}\n\nvoid ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {\n    const int64_t ne2 = tensor->ne[2];\n    const int64_t ne1 = tensor->ne[1];\n    const int64_t ne0 = tensor->ne[0];\n\n    const int64_t i3_ = (i/(ne2*ne1*ne0));\n    const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);\n    const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;\n    const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);\n\n    if (i0) {\n        * i0 = i0_;\n    }\n    if (i1) {\n        * i1 = i1_;\n    }\n    if (i2) {\n        * i2 = i2_;\n    }\n    if (i3) {\n        * i3 = i3_;\n    }\n}\n\nvoid * ggml_get_data(const struct ggml_tensor * tensor) {\n    return tensor->data;\n}\n\nfloat * ggml_get_data_f32(const struct ggml_tensor * tensor) {\n    assert(tensor->type == GGML_TYPE_F32);\n    return (float *)(tensor->data);\n}\n\nenum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {\n    GGML_ASSERT(tensor->op == GGML_OP_UNARY);\n    return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);\n}\n\nenum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {\n    GGML_ASSERT(tensor->op == GGML_OP_GLU);\n    return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);\n}\n\nconst char * ggml_get_name(const struct ggml_tensor * tensor) {\n    return tensor->name;\n}\n\nstruct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {\n    size_t i;\n    for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\\0'; i++) {\n        tensor->name[i] = name[i];\n    }\n    tensor->name[i] = '\\0';\n    return tensor;\n}\n\nstruct 
ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) {\n    va_list args;\n    va_start(args, fmt);\n    vsnprintf(tensor->name, sizeof(tensor->name), fmt, args);\n    va_end(args);\n    return tensor;\n}\n\nstruct ggml_tensor * ggml_view_tensor(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * src) {\n    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, GGML_MAX_DIMS, src->ne, src, 0);\n    ggml_format_name(result, \"%s (view)\", src->name);\n\n    for (int i = 0; i < GGML_MAX_DIMS; i++) {\n        result->nb[i] = src->nb[i];\n    }\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) {\n    struct ggml_object * obj = ctx->objects_begin;\n\n    char * const mem_buffer = ctx->mem_buffer;\n\n    while (obj != NULL) {\n        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {\n            return (struct ggml_tensor *)(mem_buffer + obj->offs);\n        }\n\n        obj = obj->next;\n    }\n\n    return NULL;\n}\n\nstruct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struct ggml_tensor * tensor) {\n    struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);\n    obj = obj->next;\n\n    char * const mem_buffer = ctx->mem_buffer;\n\n    while (obj != NULL) {\n        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {\n            return (struct ggml_tensor *)(mem_buffer + obj->offs);\n        }\n\n        obj = obj->next;\n    }\n\n    return NULL;\n}\n\nstruct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {\n    struct ggml_object * obj = ctx->objects_begin;\n\n    char * const mem_buffer = ctx->mem_buffer;\n\n    while (obj != NULL) {\n        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {\n            struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);\n            if (strcmp(cur->name, name) == 0) {\n                return cur;\n            }\n        }\n\n        
obj = obj->next;\n    }\n\n    return NULL;\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\n// ggml_dup\n\nstatic struct ggml_tensor * ggml_dup_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_DUP;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_dup(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_dup_impl(ctx, a, false);\n}\n\nstruct ggml_tensor * ggml_dup_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_dup_impl(ctx, a, true);\n}\n\n// ggml_add\n\nstatic struct ggml_tensor * ggml_add_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_can_repeat(b, a));\n\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_ADD;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_add(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_add_impl(ctx, a, b, false);\n}\n\nstruct ggml_tensor * ggml_add_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_add_impl(ctx, a, b, true);\n}\n\n// ggml_add_cast\n\nstatic struct ggml_tensor * ggml_add_cast_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        enum   ggml_type      type) {\n    // TODO: support less-strict constraint\n    //       GGML_ASSERT(ggml_can_repeat(b, a));\n    GGML_ASSERT(ggml_can_repeat_rows(b, a));\n\n    // currently only supported for quantized input and f16\n    GGML_ASSERT(ggml_is_quantized(a->type) ||\n                a->type == GGML_TYPE_F16 ||\n                a->type == GGML_TYPE_BF16);\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);\n\n    result->op     = GGML_OP_ADD;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_add_cast(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        enum   ggml_type      type) {\n    return ggml_add_cast_impl(ctx, a, b, type);\n}\n\nstruct ggml_tensor * ggml_add_id(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            struct ggml_tensor  * b,\n            struct ggml_tensor  * ids) {\n\n    GGML_ASSERT(a->ne[0] == b->ne[0]);\n    GGML_ASSERT(a->ne[1] == ids->ne[0]);\n    GGML_ASSERT(a->ne[2] == ids->ne[1]);\n    GGML_ASSERT(ids->type == GGML_TYPE_I32);\n\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);\n\n    result->op     = 
GGML_OP_ADD_ID;\n    result->src[0] = a;\n    result->src[1] = b;\n    result->src[2] = ids;\n\n    return result;\n}\n\n// ggml_add1\n\nstatic struct ggml_tensor * ggml_add1_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_is_scalar(b));\n    GGML_ASSERT(ggml_is_padded_1d(a));\n\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_ADD1;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_add1(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_add1_impl(ctx, a, b, false);\n}\n\nstruct ggml_tensor * ggml_add1_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_add1_impl(ctx, a, b, true);\n}\n\n// ggml_acc\n\nstatic struct ggml_tensor * ggml_acc_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                nb1,\n        size_t                nb2,\n        size_t                nb3,\n        size_t                offset,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(a->type == GGML_TYPE_F32);\n    GGML_ASSERT(b->type == GGML_TYPE_F32);\n\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 
1 : 0 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_ACC;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_acc(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                nb1,\n        size_t                nb2,\n        size_t                nb3,\n        size_t                offset) {\n    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);\n}\n\nstruct ggml_tensor * ggml_acc_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                nb1,\n        size_t                nb2,\n        size_t                nb3,\n        size_t                offset) {\n    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);\n}\n\n// ggml_sub\n\nstatic struct ggml_tensor * ggml_sub_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_can_repeat(b, a));\n\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_SUB;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_sub(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_sub_impl(ctx, a, b, false);\n}\n\nstruct ggml_tensor * ggml_sub_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_sub_impl(ctx, a, b, true);\n}\n\n// ggml_mul\n\nstatic struct ggml_tensor * ggml_mul_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_can_repeat(b, a));\n\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_MUL;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_mul(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_mul_impl(ctx, a, b, false);\n}\n\nstruct ggml_tensor * ggml_mul_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_mul_impl(ctx, a, b, true);\n}\n\n// ggml_div\n\nstatic struct ggml_tensor * ggml_div_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_can_repeat(b, a));\n\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_DIV;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_div(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_div_impl(ctx, a, b, false);\n}\n\nstruct ggml_tensor * ggml_div_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_div_impl(ctx, a, b, true);\n}\n\n// ggml_sqr\n\nstatic struct ggml_tensor * ggml_sqr_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_SQR;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_sqr(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_sqr_impl(ctx, a, false);\n}\n\nstruct ggml_tensor * ggml_sqr_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_sqr_impl(ctx, a, true);\n}\n\n// ggml_sqrt\n\nstatic struct ggml_tensor * ggml_sqrt_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_SQRT;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_sqrt(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_sqrt_impl(ctx, a, false);\n}\n\nstruct ggml_tensor * ggml_sqrt_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_sqrt_impl(ctx, a, true);\n}\n\n// ggml_log\n\nstatic struct ggml_tensor * ggml_log_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_LOG;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_log(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_log_impl(ctx, a, false);\n}\n\nstruct ggml_tensor * ggml_log_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_log_impl(ctx, a, true);\n}\n\nstruct ggml_tensor * ggml_expm1(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_EXPM1);\n}\n\nstruct ggml_tensor * ggml_expm1_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXPM1);\n}\n\nstruct ggml_tensor * ggml_softplus(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS);\n}\n\nstruct ggml_tensor * ggml_softplus_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS);\n}\n\n// ggml_sin\n\nstatic struct ggml_tensor * ggml_sin_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        bool                  inplace) {\n    
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_SIN;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_sin(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_sin_impl(ctx, a, false);\n}\n\nstruct ggml_tensor * ggml_sin_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_sin_impl(ctx, a, true);\n}\n\n// ggml_cos\n\nstatic struct ggml_tensor * ggml_cos_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_COS;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_cos(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_cos_impl(ctx, a, false);\n}\n\nstruct ggml_tensor * ggml_cos_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_cos_impl(ctx, a, true);\n}\n\n// ggml_sum\n\nstruct ggml_tensor * ggml_sum(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);\n\n    result->op     = GGML_OP_SUM;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_sum_rows\n\nstruct ggml_tensor * ggml_sum_rows(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    int64_t ne[GGML_MAX_DIMS] = { 1 };\n    for (int i = 1; i < GGML_MAX_DIMS; ++i) {\n        ne[i] = a->ne[i];\n    }\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);\n\n    result->op     = GGML_OP_SUM_ROWS;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_cumsum\n\nstruct ggml_tensor * ggml_cumsum(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    
GGML_ASSERT(a->type == GGML_TYPE_F32);\n\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_CUMSUM;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_mean\n\nstruct ggml_tensor * ggml_mean(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    result->op     = GGML_OP_MEAN;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_argmax\n\nstruct ggml_tensor * ggml_argmax(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    GGML_ASSERT(ggml_is_matrix(a));\n    GGML_ASSERT(a->ne[0] <= INT32_MAX);\n\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);\n\n    result->op     = GGML_OP_ARGMAX;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_count_equal\n\nstruct ggml_tensor * ggml_count_equal(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    GGML_ASSERT(ggml_are_same_shape(a, b));\n\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);\n\n    result->op     = GGML_OP_COUNT_EQUAL;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_repeat\n\nstruct ggml_tensor * ggml_repeat(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    GGML_ASSERT(ggml_can_repeat(a, b));\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);\n\n    result->op     = GGML_OP_REPEAT;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_repeat_4d(\n        struct ggml_context * ctx,\n        struct ggml_tensor * a,\n        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {\n    const bool can_repeat = ggml_is_empty(a) || (\n        (ne0 % a->ne[0] == 0) &&\n        (ne1 % 
a->ne[1] == 0) &&\n        (ne2 % a->ne[2] == 0) &&\n        (ne3 % a->ne[3] == 0)\n    );\n    GGML_ASSERT(can_repeat);\n\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);\n\n    result->op     = GGML_OP_REPEAT;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_repeat_back\n\nstruct ggml_tensor * ggml_repeat_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    GGML_ASSERT(ggml_can_repeat(b, a));\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);\n\n    result->op     = GGML_OP_REPEAT_BACK;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_concat\n\nstruct ggml_tensor * ggml_concat(\n    struct ggml_context * ctx,\n    struct ggml_tensor  * a,\n    struct ggml_tensor  * b,\n    int                   dim) {\n    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);\n    GGML_ASSERT(a->type == b->type);\n\n    int64_t ne[GGML_MAX_DIMS];\n    for (int d = 0; d < GGML_MAX_DIMS; ++d) {\n        if (d == dim) {\n            ne[d] = a->ne[d] + b->ne[d];\n            continue;\n        }\n        GGML_ASSERT(a->ne[d] == b->ne[d]);\n        ne[d] = a->ne[d];\n    }\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);\n\n    ggml_set_op_params_i32(result, 0, dim);\n\n    result->op     = GGML_OP_CONCAT;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_abs\n\nstruct ggml_tensor * ggml_abs(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_ABS);\n}\n\nstruct ggml_tensor * ggml_abs_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);\n}\n\n// ggml_sgn\n\nstruct ggml_tensor * ggml_sgn(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, 
GGML_UNARY_OP_SGN);\n}\n\nstruct ggml_tensor * ggml_sgn_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN);\n}\n\n// ggml_neg\n\nstruct ggml_tensor * ggml_neg(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_NEG);\n}\n\nstruct ggml_tensor * ggml_neg_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG);\n}\n\n// ggml_step\n\nstruct ggml_tensor * ggml_step(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_STEP);\n}\n\nstruct ggml_tensor * ggml_step_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP);\n}\n\n// ggml_tanh\n\nstruct ggml_tensor * ggml_tanh(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_TANH);\n}\n\nstruct ggml_tensor * ggml_tanh_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH);\n}\n\n// ggml_elu\n\nstruct ggml_tensor * ggml_elu(\n    struct ggml_context * ctx,\n    struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_ELU);\n}\n\nstruct ggml_tensor * ggml_elu_inplace(\n    struct ggml_context * ctx,\n    struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU);\n}\n\n// ggml_relu\n\nstruct ggml_tensor * ggml_relu(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_RELU);\n}\n\nstruct ggml_tensor * ggml_relu_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);\n}\n\n// ggml_leaky_relu\n\nstruct ggml_tensor * 
ggml_leaky_relu(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 negative_slope,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));\n\n    result->op     = GGML_OP_LEAKY_RELU;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_sigmoid\n\nstruct ggml_tensor * ggml_sigmoid(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);\n}\n\nstruct ggml_tensor * ggml_sigmoid_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);\n}\n\n// ggml_gelu\n\nstruct ggml_tensor * ggml_gelu(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU);\n}\n\nstruct ggml_tensor * ggml_gelu_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);\n}\n\n// ggml_gelu_erf\n\nstruct ggml_tensor * ggml_gelu_erf(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);\n}\n\nstruct ggml_tensor * ggml_gelu_erf_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);\n}\n\n// ggml_gelu_quick\n\nstruct ggml_tensor * ggml_gelu_quick(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK);\n}\n\nstruct ggml_tensor * ggml_gelu_quick_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK);\n}\n\n// ggml_silu\n\nstruct ggml_tensor * ggml_silu(\n        struct 
ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);\n}\n\nstruct ggml_tensor * ggml_silu_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);\n}\n\n// ggml_xielu\n\nstruct ggml_tensor * ggml_xielu(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float alpha_n,\n        float alpha_p,\n        float beta,\n        float eps) {\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);\n    ggml_set_op_params_f32(result, 1, beta + ggml_compute_softplus_f32(alpha_n));\n    ggml_set_op_params_f32(result, 2, ggml_compute_softplus_f32(alpha_p));\n    ggml_set_op_params_f32(result, 3, beta);\n    ggml_set_op_params_f32(result, 4, eps);\n\n    result->op     = GGML_OP_UNARY;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_silu_back\n\nstruct ggml_tensor * ggml_silu_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_SILU_BACK;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml hardswish\n\nstruct ggml_tensor * ggml_hardswish(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSWISH);\n}\n\n// ggml hardsigmoid\n\nstruct ggml_tensor * ggml_hardsigmoid(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);\n}\n\n// ggml exp\n\nstruct ggml_tensor * ggml_exp(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);\n}\n\nstruct ggml_tensor * ggml_exp_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) 
{\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);\n}\n\n// ggml_glu\n\nstatic struct ggml_tensor * ggml_glu_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        enum ggml_glu_op      op,\n        bool                  swapped) {\n    GGML_ASSERT(ggml_is_contiguous_1(a));\n\n    if (b) {\n        GGML_ASSERT(ggml_is_contiguous_1(b));\n        GGML_ASSERT(ggml_are_same_shape(a, b));\n        GGML_ASSERT(a->type == b->type);\n    }\n\n    int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];\n    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);\n\n    ggml_set_op_params_i32(result, 0, (int32_t) op);\n    ggml_set_op_params_i32(result, 1, (int32_t) swapped);\n\n    result->op     = GGML_OP_GLU;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_floor\n\nstruct ggml_tensor * ggml_floor(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR);\n}\n\nstruct ggml_tensor * ggml_floor_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR);\n}\n\n// ggml_ceil\n\nstruct ggml_tensor * ggml_ceil(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL);\n}\n\nstruct ggml_tensor * ggml_ceil_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL);\n}\n\n//ggml_round\n\nstruct ggml_tensor * ggml_round(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND);\n}\n\nstruct ggml_tensor * ggml_round_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return 
ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND);\n}\n\n//ggml_trunc\n\nstruct ggml_tensor * ggml_trunc(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC);\n}\n\nstruct ggml_tensor * ggml_trunc_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC);\n}\n\nstruct ggml_tensor * ggml_glu(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        enum ggml_glu_op      op,\n        bool                  swapped) {\n    return ggml_glu_impl(ctx, a, NULL, op, swapped);\n}\n\nstruct ggml_tensor * ggml_glu_split(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        enum ggml_glu_op      op) {\n    return ggml_glu_impl(ctx, a, b, op, false);\n}\n\n// ggml_reglu\n\nstruct ggml_tensor * ggml_reglu(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);\n}\n\nstruct ggml_tensor * ggml_reglu_swapped(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);\n}\n\nstruct ggml_tensor * ggml_reglu_split(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);\n}\n\n// ggml_geglu\n\nstruct ggml_tensor * ggml_geglu(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);\n}\n\nstruct ggml_tensor * ggml_geglu_swapped(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);\n}\n\nstruct ggml_tensor * ggml_geglu_split(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  
* b) {\n    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);\n}\n\n// ggml_swiglu\n\nstruct ggml_tensor * ggml_swiglu(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);\n}\n\nstruct ggml_tensor * ggml_swiglu_swapped(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);\n}\n\nstruct ggml_tensor * ggml_swiglu_split(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);\n}\n\n// ggml_geglu_erf\n\nstruct ggml_tensor * ggml_geglu_erf(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);\n}\n\nstruct ggml_tensor * ggml_geglu_erf_swapped(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);\n}\n\nstruct ggml_tensor * ggml_geglu_erf_split(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);\n}\n\n// ggml_geglu_quick\n\nstruct ggml_tensor * ggml_geglu_quick(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);\n}\n\nstruct ggml_tensor * ggml_geglu_quick_swapped(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);\n}\n\nstruct ggml_tensor * ggml_geglu_quick_split(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);\n}\n\nstruct ggml_tensor * ggml_swiglu_oai(\n        struct 
ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        float                 alpha,\n        float                 limit) {\n    struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);\n    ggml_set_op_params_f32(result, 2, alpha);\n    ggml_set_op_params_f32(result, 3, limit);\n\n    return result;\n}\n\n// ggml_norm\n\nstatic struct ggml_tensor * ggml_norm_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params(result, &eps, sizeof(eps));\n\n    result->op     = GGML_OP_NORM;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_norm(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps) {\n    return ggml_norm_impl(ctx, a, eps, false);\n}\n\nstruct ggml_tensor * ggml_norm_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps) {\n    return ggml_norm_impl(ctx, a, eps, true);\n}\n\n// ggml_rms_norm\n\nstatic struct ggml_tensor * ggml_rms_norm_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params(result, &eps, sizeof(eps));\n\n    result->op     = GGML_OP_RMS_NORM;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_rms_norm(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps) {\n    return ggml_rms_norm_impl(ctx, a, eps, false);\n}\n\nstruct ggml_tensor * ggml_rms_norm_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps) {\n    return ggml_rms_norm_impl(ctx, a, eps, true);\n}\n\n// ggml_rms_norm_back\n\nstruct ggml_tensor * ggml_rms_norm_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        float                 eps) {\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params(result, &eps, sizeof(eps));\n\n    result->op     = GGML_OP_RMS_NORM_BACK;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_group_norm\n\nstatic struct ggml_tensor * ggml_group_norm_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_groups,\n        float                 eps,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params_i32(result, 0, n_groups);\n    ggml_set_op_params_f32(result, 1, eps);\n\n    result->op     = GGML_OP_GROUP_NORM;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_group_norm(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_groups,\n        float                 eps) {\n    return ggml_group_norm_impl(ctx, a, n_groups, eps, false);\n}\n\nstruct ggml_tensor * ggml_group_norm_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_groups,\n        float                 eps) {\n    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);\n}\n\n// ggml_l2_norm\n\nstatic struct ggml_tensor * ggml_l2_norm_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params_f32(result, 0, eps);\n\n    result->op     = GGML_OP_L2_NORM;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_l2_norm(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps) {\n    return ggml_l2_norm_impl(ctx, a, eps, false);\n}\n\nstruct ggml_tensor * ggml_l2_norm_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 eps) {\n    return ggml_l2_norm_impl(ctx, a, eps, true);\n}\n\n// ggml_mul_mat\n\nstatic inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return (t0->ne[0]           == t1->ne[0])  &&\n           (t1->ne[2]%t0->ne[2] == 0)          && // verify t0 is broadcastable\n           (t1->ne[3]%t0->ne[3] == 0);\n}\n\nstruct ggml_tensor * ggml_mul_mat(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    GGML_ASSERT(ggml_can_mul_mat(a, b));\n    GGML_ASSERT(!ggml_is_transposed(a));\n\n    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    result->op     = GGML_OP_MUL_MAT;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nvoid ggml_mul_mat_set_prec(\n        struct ggml_tensor * a,\n        enum ggml_prec       prec) {\n    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);\n\n    const int32_t prec_i32 = (int32_t) prec;\n\n    ggml_set_op_params_i32(a, 0, prec_i32);\n}\n\n// ggml_mul_mat_id\n\n/*\n    c = ggml_mul_mat_id(ctx, as, b, ids);\n\n    as  -> [cols, rows, n_expert]\n    b   -> [cols, n_expert_used, n_tokens]\n    ids -> [n_expert_used, n_tokens] (i32)\n    c   -> [rows, n_expert_used, n_tokens]\n\n    in b, n_expert_used can 
be broadcasted to match the n_expert_used of ids\n\n    c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids\n*/\nstruct ggml_tensor * ggml_mul_mat_id(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * as,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * ids) {\n    GGML_ASSERT(!ggml_is_transposed(as));\n    GGML_ASSERT(ids->type == GGML_TYPE_I32);\n\n    GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)\n    GGML_ASSERT(b->ne[3] == 1); // b is 3d\n    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d\n    GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row\n    GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat\n    GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast\n\n    const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    result->op     = GGML_OP_MUL_MAT_ID;\n    result->src[0] = as;\n    result->src[1] = b;\n    result->src[2] = ids;\n\n    return result;\n}\n\n// ggml_out_prod\n\nstatic inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {\n    static_assert(GGML_MAX_DIMS == 4, \"GGML_MAX_DIMS is not 4 - update this function\");\n\n    return (t0->ne[1] == t1->ne[1])   &&\n           (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable\n           (t1->ne[3]%t0->ne[3] == 0);\n}\n\nstruct ggml_tensor * ggml_out_prod(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    GGML_ASSERT(ggml_can_out_prod(a, b));\n    GGML_ASSERT(!ggml_is_transposed(a));\n\n    // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]\n    const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    result->op     = GGML_OP_OUT_PROD;\n    result->src[0] = a;\n    result->src[1] = 
b;\n\n    return result;\n}\n\n// ggml_scale\n\nstatic struct ggml_tensor * ggml_scale_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 s,\n        float                 b,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_is_padded_1d(a));\n\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    float params[2] = { s, b };\n    ggml_set_op_params(result, &params, sizeof(params));\n\n    result->op     = GGML_OP_SCALE;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_scale(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 s) {\n    return ggml_scale_impl(ctx, a, s, 0.0, false);\n}\n\nstruct ggml_tensor * ggml_scale_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 s) {\n    return ggml_scale_impl(ctx, a, s, 0.0, true);\n}\n\nstruct ggml_tensor * ggml_scale_bias(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 s,\n        float                 b) {\n    return ggml_scale_impl(ctx, a, s, b, false);\n}\n\nstruct ggml_tensor * ggml_scale_bias_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 s,\n        float                 b) {\n    return ggml_scale_impl(ctx, a, s, b, true);\n}\n\n// ggml_set\n\nstatic struct ggml_tensor * ggml_set_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                nb1,\n        size_t                nb2,\n        size_t                nb3,\n        size_t                offset,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b));\n\n    // make a view of the destination\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    GGML_ASSERT(offset < (size_t)(1 << 30));\n    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_SET;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_set(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                nb1,\n        size_t                nb2,\n        size_t                nb3,\n        size_t                offset) {\n    return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false);\n}\n\nstruct ggml_tensor * ggml_set_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                nb1,\n        size_t                nb2,\n        size_t                nb3,\n        size_t                offset) {\n    return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true);\n}\n\nstruct ggml_tensor * ggml_set_1d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                offset) {\n    return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);\n}\n\nstruct ggml_tensor * ggml_set_1d_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                offset) {\n    return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);\n}\n\nstruct ggml_tensor * ggml_set_2d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                nb1,\n        size_t                offset) {\n    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);\n}\n\nstruct ggml_tensor * ggml_set_2d_inplace(\n        struct ggml_context * ctx,\n        struct 
ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        size_t                nb1,\n        size_t                offset) {\n    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);\n}\n\n// ggml_cpy\n\nstatic struct ggml_tensor * ggml_cpy_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));\n\n    // make a view of the destination\n    struct ggml_tensor * result = ggml_view_tensor(ctx, b);\n    if (strlen(b->name) > 0) {\n        ggml_format_name(result, \"%s (copy of %s)\", b->name, a->name);\n    } else {\n        ggml_format_name(result, \"%s (copy)\", a->name);\n    }\n\n    result->op     = GGML_OP_CPY;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_cpy(\n        struct ggml_context * ctx,\n        struct ggml_tensor * a,\n        struct ggml_tensor * b) {\n    return ggml_cpy_impl(ctx, a, b);\n}\n\nstruct ggml_tensor * ggml_cast(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        enum   ggml_type      type) {\n    struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);\n    ggml_format_name(result, \"%s (copy)\", a->name);\n\n    result->op     = GGML_OP_CPY;\n    result->src[0] = a;\n    result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some\n                             //       backends for consistency with ggml_cpy_impl() above\n\n    return result;\n}\n\n// ggml_cont\n\nstatic struct ggml_tensor * ggml_cont_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);\n    ggml_format_name(result, \"%s (cont)\", a->name);\n\n    result->op     = GGML_OP_CONT;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_cont(\n        struct ggml_context * ctx,\n 
       struct ggml_tensor * a) {\n    return ggml_cont_impl(ctx, a);\n}\n\n// make contiguous, with new shape\nGGML_API struct ggml_tensor * ggml_cont_1d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0) {\n    return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);\n}\n\nGGML_API struct ggml_tensor * ggml_cont_2d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1) {\n    return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);\n}\n\nGGML_API struct ggml_tensor * ggml_cont_3d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2) {\n    return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);\n}\n\nstruct ggml_tensor * ggml_cont_4d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2,\n        int64_t               ne3) {\n    GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));\n\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);\n    ggml_format_name(result, \"%s (cont)\", a->name);\n\n    result->op     = GGML_OP_CONT;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_reshape\n\nstruct ggml_tensor * ggml_reshape(\n        struct ggml_context * ctx,\n        struct ggml_tensor * a,\n        struct ggml_tensor * b) {\n    GGML_ASSERT(ggml_is_contiguous(a));\n    // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.\n    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));\n\n    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0);\n    ggml_format_name(result, \"%s (reshaped)\", a->name);\n\n    result->op     = GGML_OP_RESHAPE;\n    result->src[0] = a;\n\n    return 
result;\n}\n\nstruct ggml_tensor * ggml_reshape_1d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0) {\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_nelements(a) == ne0);\n\n    const int64_t ne[1] = { ne0 };\n    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);\n    ggml_format_name(result, \"%s (reshaped)\", a->name);\n\n    result->op     = GGML_OP_RESHAPE;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_reshape_2d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1) {\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_nelements(a) == ne0*ne1);\n\n    const int64_t ne[2] = { ne0, ne1 };\n    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);\n    ggml_format_name(result, \"%s (reshaped)\", a->name);\n\n    result->op     = GGML_OP_RESHAPE;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_reshape_3d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2) {\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);\n\n    const int64_t ne[3] = { ne0, ne1, ne2 };\n    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);\n    ggml_format_name(result, \"%s (reshaped)\", a->name);\n\n    result->op     = GGML_OP_RESHAPE;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_reshape_4d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2,\n        int64_t               ne3) {\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_nelements(a) == 
ne0*ne1*ne2*ne3);\n\n    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };\n    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);\n    ggml_format_name(result, \"%s (reshaped)\", a->name);\n\n    result->op     = GGML_OP_RESHAPE;\n    result->src[0] = a;\n\n    return result;\n}\n\nstatic struct ggml_tensor * ggml_view_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_dims,\n        const int64_t       * ne,\n        size_t                offset) {\n    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);\n    ggml_format_name(result, \"%s (view)\", a->name);\n\n    ggml_set_op_params(result, &offset, sizeof(offset));\n\n    result->op     = GGML_OP_VIEW;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_view_1d\n\nstruct ggml_tensor * ggml_view_1d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        size_t                offset) {\n    struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset);\n\n    return result;\n}\n\n// ggml_view_2d\n\nstruct ggml_tensor * ggml_view_2d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        size_t                nb1,\n        size_t                offset) {\n    const int64_t ne[2] = { ne0, ne1 };\n\n    struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset);\n\n    result->nb[1] = nb1;\n    result->nb[2] = result->nb[1]*ne1;\n    result->nb[3] = result->nb[2];\n\n    return result;\n}\n\n// ggml_view_3d\n\nstruct ggml_tensor * ggml_view_3d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2,\n        size_t                nb1,\n        size_t                nb2,\n        size_t                
offset) {\n    const int64_t ne[3] = { ne0, ne1, ne2 };\n\n    struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset);\n\n    result->nb[1] = nb1;\n    result->nb[2] = nb2;\n    result->nb[3] = result->nb[2]*ne2;\n\n    return result;\n}\n\n// ggml_view_4d\n\nstruct ggml_tensor * ggml_view_4d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2,\n        int64_t               ne3,\n        size_t                nb1,\n        size_t                nb2,\n        size_t                nb3,\n        size_t                offset) {\n    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };\n\n    struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset);\n\n    result->nb[1] = nb1;\n    result->nb[2] = nb2;\n    result->nb[3] = nb3;\n\n    return result;\n}\n\n// ggml_permute\n\nstruct ggml_tensor * ggml_permute(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   axis0,\n        int                   axis1,\n        int                   axis2,\n        int                   axis3) {\n    GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);\n    GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);\n    GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);\n    GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);\n\n    GGML_ASSERT(axis0 != axis1);\n    GGML_ASSERT(axis0 != axis2);\n    GGML_ASSERT(axis0 != axis3);\n    GGML_ASSERT(axis1 != axis2);\n    GGML_ASSERT(axis1 != axis3);\n    GGML_ASSERT(axis2 != axis3);\n\n    struct ggml_tensor * result = ggml_view_tensor(ctx, a);\n    ggml_format_name(result, \"%s (permuted)\", a->name);\n\n    int ne[GGML_MAX_DIMS];\n    int nb[GGML_MAX_DIMS];\n\n    ne[axis0] = a->ne[0];\n    ne[axis1] = a->ne[1];\n    ne[axis2] = a->ne[2];\n    ne[axis3] = a->ne[3];\n\n    nb[axis0] = a->nb[0];\n    nb[axis1] = a->nb[1];\n    nb[axis2] = a->nb[2];\n    nb[axis3] = 
a->nb[3];\n\n    result->ne[0] = ne[0];\n    result->ne[1] = ne[1];\n    result->ne[2] = ne[2];\n    result->ne[3] = ne[3];\n\n    result->nb[0] = nb[0];\n    result->nb[1] = nb[1];\n    result->nb[2] = nb[2];\n    result->nb[3] = nb[3];\n\n    result->op     = GGML_OP_PERMUTE;\n    result->src[0] = a;\n\n    int32_t params[] = { axis0, axis1, axis2, axis3 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    return result;\n}\n\n// ggml_transpose\n\nstruct ggml_tensor * ggml_transpose(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    struct ggml_tensor * result = ggml_view_tensor(ctx, a);\n    ggml_format_name(result, \"%s (transposed)\", a->name);\n\n    result->ne[0] = a->ne[1];\n    result->ne[1] = a->ne[0];\n\n    result->nb[0] = a->nb[1];\n    result->nb[1] = a->nb[0];\n\n    result->op     = GGML_OP_TRANSPOSE;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_get_rows\n\nstruct ggml_tensor * ggml_get_rows(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    GGML_ASSERT(a->ne[2] == b->ne[1]);\n    GGML_ASSERT(a->ne[3] == b->ne[2]);\n    GGML_ASSERT(b->ne[3] == 1);\n    GGML_ASSERT(b->type == GGML_TYPE_I32);\n\n    // TODO: implement non F32 return\n    enum ggml_type type = GGML_TYPE_F32;\n    if (a->type == GGML_TYPE_I32) {\n        type = a->type;\n    }\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);\n\n    result->op     = GGML_OP_GET_ROWS;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_get_rows_back\n\nstruct ggml_tensor * ggml_get_rows_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c) {\n    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);\n    GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));\n\n    // TODO: 
implement non F32 return\n    //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);\n    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);\n\n    result->op     = GGML_OP_GET_ROWS_BACK;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_set_rows\n\nstruct ggml_tensor * ggml_set_rows(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c) {\n    GGML_ASSERT(a->ne[0] == b->ne[0]);\n    GGML_ASSERT(a->ne[2] == b->ne[2]);\n    GGML_ASSERT(a->ne[3] == b->ne[3]);\n    GGML_ASSERT(b->ne[1] == c->ne[0]);\n    GGML_ASSERT(b->ne[2] % c->ne[1] == 0);\n    GGML_ASSERT(b->ne[3] % c->ne[2] == 0);\n    GGML_ASSERT(c->ne[3] == 1);\n    GGML_ASSERT(b->type == GGML_TYPE_F32);\n    GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);\n\n    GGML_ASSERT(ggml_is_contiguous_rows(a));\n    GGML_ASSERT(ggml_is_contiguous_rows(b));\n\n    struct ggml_tensor * result = ggml_view_tensor(ctx, a);\n\n    result->op     = GGML_OP_SET_ROWS;\n    result->src[0] = b;\n    result->src[1] = c;\n    result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)\n\n    return result;\n}\n\n// ggml_diag\n\nstruct ggml_tensor * ggml_diag(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    GGML_ASSERT(a->ne[1] == 1);\n\n    const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne);\n\n    result->op     = GGML_OP_DIAG;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_diag_mask_inf\n\nstatic struct ggml_tensor * ggml_diag_mask_inf_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_past,\n        bool                  inplace) {\n    struct ggml_tensor * 
result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    int32_t params[] = { n_past };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_DIAG_MASK_INF;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_diag_mask_inf(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_past) {\n    return ggml_diag_mask_inf_impl(ctx, a, n_past, false);\n}\n\nstruct ggml_tensor * ggml_diag_mask_inf_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_past) {\n    return ggml_diag_mask_inf_impl(ctx, a, n_past, true);\n}\n\n// ggml_diag_mask_zero\n\nstatic struct ggml_tensor * ggml_diag_mask_zero_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_past,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    int32_t params[] = { n_past };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_DIAG_MASK_ZERO;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_diag_mask_zero(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_past) {\n    return ggml_diag_mask_zero_impl(ctx, a, n_past, false);\n}\n\nstruct ggml_tensor * ggml_diag_mask_zero_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   n_past) {\n    return ggml_diag_mask_zero_impl(ctx, a, n_past, true);\n}\n\n// ggml_soft_max\n\nstatic struct ggml_tensor * ggml_soft_max_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * mask,\n        float                 scale,\n        float                 max_bias,\n        bool                  inplace) {\n    
GGML_ASSERT(ggml_is_contiguous(a));\n\n    if (mask) {\n        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);\n        GGML_ASSERT(ggml_is_contiguous(mask));\n        GGML_ASSERT(mask->ne[0] == a->ne[0]);\n        GGML_ASSERT(mask->ne[1] >= a->ne[1]);\n        GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);\n        GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);\n    }\n\n    if (max_bias > 0.0f) {\n        GGML_ASSERT(mask);\n    }\n\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    float params[] = { scale, max_bias };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_SOFT_MAX;\n    result->src[0] = a;\n    result->src[1] = mask;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_soft_max(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);\n}\n\nstruct ggml_tensor * ggml_soft_max_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a) {\n    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);\n}\n\nstruct ggml_tensor * ggml_soft_max_ext(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * mask,\n        float                 scale,\n        float                 max_bias) {\n    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);\n}\n\nstruct ggml_tensor * ggml_soft_max_ext_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * mask,\n        float                 scale,\n        float                 max_bias) {\n    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);\n}\n\nvoid ggml_soft_max_add_sinks(\n        struct ggml_tensor * a,\n        struct ggml_tensor * sinks) {\n    if (!sinks) {\n        a->src[2] = NULL;\n        return;\n    }\n\n    GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);\n    GGML_ASSERT(a->src[2] 
== NULL);\n    GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);\n    GGML_ASSERT(sinks->type == GGML_TYPE_F32);\n\n    a->src[2] = sinks;\n}\n\n// ggml_soft_max_ext_back\n\nstatic struct ggml_tensor * ggml_soft_max_ext_back_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        float                 scale,\n        float                 max_bias,\n        bool                  inplace) {\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    result->op     = GGML_OP_SOFT_MAX_BACK;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    memcpy((float *) result->op_params + 0, &scale,    sizeof(float));\n    memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_soft_max_ext_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        float                 scale,\n        float                 max_bias) {\n    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);\n}\n\nstruct ggml_tensor * ggml_soft_max_ext_back_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        float                 scale,\n        float                 max_bias) {\n    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);\n}\n\n// ggml_rope\n\nstatic struct ggml_tensor * ggml_rope_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c,\n        int                   n_dims,\n        int                   sections[GGML_MROPE_SECTIONS],\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 freq_scale,\n        float                 ext_factor,\n        float                 
attn_factor,\n        float                 beta_fast,\n        float                 beta_slow,\n        bool                  inplace) {\n    GGML_ASSERT((mode & 1) == 0 && \"mode & 1 == 1 is no longer supported\");\n\n    GGML_ASSERT(ggml_is_vector(b));\n    GGML_ASSERT(b->type == GGML_TYPE_I32);\n\n    bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;\n    if (mrope_used) {\n        GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token\n    } else {\n        GGML_ASSERT(a->ne[2] == b->ne[0]);\n    }\n\n    if (c) {\n        GGML_ASSERT(c->type == GGML_TYPE_F32);\n        GGML_ASSERT(c->ne[0] >= n_dims / 2);\n    }\n\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };\n    memcpy(params +  5, &freq_base,    sizeof(float));\n    memcpy(params +  6, &freq_scale,   sizeof(float));\n    memcpy(params +  7, &ext_factor,   sizeof(float));\n    memcpy(params +  8, &attn_factor,  sizeof(float));\n    memcpy(params +  9, &beta_fast,    sizeof(float));\n    memcpy(params + 10, &beta_slow,    sizeof(float));\n    if (mrope_used && sections) {\n        memcpy(params + 11, sections,  sizeof(int32_t) * GGML_MROPE_SECTIONS);\n    } else {\n        memset(params + 11, 0,         sizeof(int32_t) * GGML_MROPE_SECTIONS);\n    }\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_ROPE;\n    result->src[0] = a;\n    result->src[1] = b;\n    result->src[2] = c;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_rope(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   n_dims,\n        int                   mode) {\n    return ggml_rope_impl(\n        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false\n    );\n}\n\nstruct ggml_tensor * ggml_rope_multi(\n        struct 
ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c,\n        int                   n_dims,\n        int                   sections[GGML_MROPE_SECTIONS],\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 freq_scale,\n        float                 ext_factor,\n        float                 attn_factor,\n        float                 beta_fast,\n        float                 beta_slow) {\n    return ggml_rope_impl(\n        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,\n        ext_factor, attn_factor, beta_fast, beta_slow, false\n    );\n}\n\nstruct ggml_tensor * ggml_rope_multi_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c,\n        int                   n_dims,\n        int                   sections[GGML_MROPE_SECTIONS],\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 freq_scale,\n        float                 ext_factor,\n        float                 attn_factor,\n        float                 beta_fast,\n        float                 beta_slow) {\n    return ggml_rope_impl(\n        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,\n        ext_factor, attn_factor, beta_fast, beta_slow, true\n    );\n}\n\nstruct ggml_tensor * ggml_rope_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   n_dims,\n        int                   mode) {\n    return ggml_rope_impl(\n        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true\n    );\n}\n\nstruct ggml_tensor * ggml_rope_ext(\n        struct ggml_context * ctx,\n        
struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c,\n        int                   n_dims,\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 freq_scale,\n        float                 ext_factor,\n        float                 attn_factor,\n        float                 beta_fast,\n        float                 beta_slow) {\n    return ggml_rope_impl(\n        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,\n        ext_factor, attn_factor, beta_fast, beta_slow, false\n    );\n}\n\nstruct ggml_tensor * ggml_rope_ext_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c,\n        int                   n_dims,\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 freq_scale,\n        float                 ext_factor,\n        float                 attn_factor,\n        float                 beta_fast,\n        float                 beta_slow) {\n    return ggml_rope_impl(\n        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,\n        ext_factor, attn_factor, beta_fast, beta_slow, true\n    );\n}\n\nstruct ggml_tensor * ggml_rope_custom(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   n_dims,\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 freq_scale,\n        float                 ext_factor,\n        float                 attn_factor,\n        float                 beta_fast,\n        float                 beta_slow) {\n    return ggml_rope_impl(\n        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, 
freq_base, freq_scale,\n        ext_factor, attn_factor, beta_fast, beta_slow, false\n    );\n}\n\nstruct ggml_tensor * ggml_rope_custom_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   n_dims,\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 freq_scale,\n        float                 ext_factor,\n        float                 attn_factor,\n        float                 beta_fast,\n        float                 beta_slow) {\n    return ggml_rope_impl(\n        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,\n        ext_factor, attn_factor, beta_fast, beta_slow, true\n    );\n}\n\n// Solving `n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims))` for x, where n_rot is the number of full\n// rotations dimension x completes over max_pos_emb (= n_ctx_orig) positions, we get\n// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`\nstatic float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {\n    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));\n}\n\nvoid ggml_rope_yarn_corr_dims(\n    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]\n) {\n    // start and end correction dims\n    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));\n    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));\n    dims[0] = MAX(0, start);\n    dims[1] = MIN(n_dims - 1, end);\n}\n\n// ggml_rope_back\n\nstruct ggml_tensor * ggml_rope_ext_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c,\n        int                   n_dims,\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 
freq_scale,\n        float                 ext_factor,\n        float                 attn_factor,\n        float                 beta_fast,\n        float                 beta_slow) {\n    struct ggml_tensor * result = ggml_rope_ext(\n        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);\n    result->op = GGML_OP_ROPE_BACK;\n    return result;\n}\n\nstruct ggml_tensor * ggml_rope_multi_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c,\n        int                   n_dims,\n        int                   sections[4],\n        int                   mode,\n        int                   n_ctx_orig,\n        float                 freq_base,\n        float                 freq_scale,\n        float                 ext_factor,\n        float                 attn_factor,\n        float                 beta_fast,\n        float                 beta_slow) {\n    struct ggml_tensor * result = ggml_rope_multi(\n        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);\n    result->op = GGML_OP_ROPE_BACK;\n    return result;\n}\n\n// ggml_clamp\n\nstruct ggml_tensor * ggml_clamp(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        float                 min,\n        float                 max) {\n    // TODO: when implementing backward, fix this:\n    struct ggml_tensor * result = ggml_view_tensor(ctx, a);\n\n    float params[] = { min, max };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_CLAMP;\n    result->src[0] = a;\n\n    return result;\n}\n\nstatic int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {\n    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;\n}\n\n// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]\n// a: [OC, IC, KH, KW]\n// b: [N, IC, IH, 
IW]\n// result: [N, OH, OW, IC*KH*KW]\nstruct ggml_tensor * ggml_im2col(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s0,\n        int                   s1,\n        int                   p0,\n        int                   p1,\n        int                   d0,\n        int                   d1,\n        bool                  is_2D,\n        enum ggml_type        dst_type) {\n    if (is_2D) {\n        GGML_ASSERT(a->ne[2] == b->ne[2]);\n    } else {\n        //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);\n        GGML_ASSERT(b->ne[1] == a->ne[1]);\n        GGML_ASSERT(b->ne[3] == 1);\n    }\n\n    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;\n    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);\n\n    GGML_ASSERT((!is_2D || OH > 0) && \"b too small compared to a\");\n    GGML_ASSERT((OW > 0)           && \"b too small compared to a\");\n\n    const int64_t ne[4] = {\n        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],\n        OW,\n        is_2D ? OH : b->ne[2],\n        is_2D ?      b->ne[3] : 1,\n    };\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);\n    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 
1 : 0) };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_IM2COL;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_im2col_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int64_t             * ne,\n        int                   s0,\n        int                   s1,\n        int                   p0,\n        int                   p1,\n        int                   d0,\n        int                   d1,\n        bool                  is_2D) {\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_IM2COL_BACK;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_conv_1d\n\nstruct ggml_tensor * ggml_conv_1d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s0,\n        int                   p0,\n        int                   d0) {\n    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]\n\n    struct ggml_tensor * result =\n        ggml_mul_mat(ctx,\n                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]\n                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC，IC, K] => [OC, IC * K]\n\n    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]\n\n    return result;\n}\n\n// ggml_conv_1d_ph\n\nstruct ggml_tensor* ggml_conv_1d_ph(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s,\n        int              
     d) {\n    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);\n}\n\n// ggml_conv_1d_dw\n\nstruct ggml_tensor * ggml_conv_1d_dw(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s0,\n        int                   p0,\n        int                   d0) {\n    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);\n\n    struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);\n\n    struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);\n\n    result = ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);\n\n    return result;\n}\n\n// ggml_conv_1d_dw_ph\n\nstruct ggml_tensor * ggml_conv_1d_dw_ph(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s0,\n        int                   d0) {\n    return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);\n}\n\n// ggml_conv_transpose_1d\n\nstatic int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {\n    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;\n}\n\nGGML_API struct ggml_tensor * ggml_conv_transpose_1d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s0,\n        int                   p0,\n        int                   d0) {\n    GGML_ASSERT(ggml_is_matrix(b));\n    GGML_ASSERT(a->ne[2] == b->ne[1]);\n    GGML_ASSERT(a->ne[3] == 1);\n\n    GGML_ASSERT(p0 == 0);\n    GGML_ASSERT(d0 == 1);\n\n    const int64_t ne[4] = {\n        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),\n        a->ne[1], b->ne[2], 1,\n    };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    int32_t params[] = { s0, p0, d0 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n   
 result->op     = GGML_OP_CONV_TRANSPOSE_1D;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_conv_2d\n\n// a: [OC, IC, KH, KW]\n// b: [N, IC, IH, IW]\n// result: [N, OC, OH, OW]\nstruct ggml_tensor * ggml_conv_2d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s0,\n        int                   s1,\n        int                   p0,\n        int                   p1,\n        int                   d0,\n        int                   d1) {\n    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]\n\n    struct ggml_tensor * result =\n        ggml_mul_mat(ctx,\n                ggml_reshape_2d(ctx, im2col, im2col->ne[0],  im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]\n                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]),  a->ne[3]));                       // [OC, IC, KH, KW] => [OC, IC * KH * KW]\n\n    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]\n    result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]\n\n\n    return result;\n}\n\n// a: [OC*IC, KD, KH, KW]\n// b: [N*IC, ID, IH, IW]\n// result: [N*OD, OH, OW, IC * KD * KH * KW]\nstruct ggml_tensor * ggml_im2col_3d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int64_t               IC,\n        int                   s0, // stride width\n        int                   s1, // stride height\n        int                   s2, // stride depth\n        int                   p0, // padding width\n        int                   p1, // padding height\n        int                   p2, // padding depth\n        int                   d0, // dilation width\n        int                   d1, // dilation 
height\n        int                   d2, // dilation depth\n        enum ggml_type        dst_type) {\n    const int64_t N = b->ne[3] / IC;\n    const int64_t ID = b->ne[2];\n    const int64_t IH = b->ne[1];\n    const int64_t IW = b->ne[0];\n\n    const int64_t OC = a->ne[3] / IC;\n    UNUSED(OC);\n    const int64_t KD = a->ne[2];\n    const int64_t KH = a->ne[1];\n    const int64_t KW = a->ne[0];\n    const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);\n    const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);\n    const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);\n\n    GGML_ASSERT((OD > 0)  && \"b too small compared to a\");\n    GGML_ASSERT((OH > 0)  && \"b too small compared to a\");\n    GGML_ASSERT((OW > 0)  && \"b too small compared to a\");\n\n\n    const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);\n    int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_IM2COL_3D;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// a: [OC*IC, KD, KH, KW]\n// b: [N*IC, ID, IH, IW]\n// result: [N*OC, OD, OH, OW]\nstruct ggml_tensor * ggml_conv_3d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int64_t               IC,\n        int                   s0, // stride width\n        int                   s1, // stride height\n        int                   s2, // stride depth\n        int                   p0, // padding width\n        int                   p1, // padding height\n        int                   p2, // padding depth\n        int                   d0, // dilation width\n        int                   d1, // dilation height\n        int                   d2  // dilation depth\n        ) {\n    struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, 
IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]\n\n    int64_t OC = a->ne[3] / IC;\n    int64_t N = b->ne[3] / IC;\n    struct ggml_tensor * result =\n        ggml_mul_mat(ctx,\n                ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]\n                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC));                          // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]\n\n    int64_t OD = im2col->ne[3] / N;\n    result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]\n    result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]\n    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]\n\n    return result;\n}\n\n// ggml_conv_2d_sk_p0\n\nstruct ggml_tensor * ggml_conv_2d_sk_p0(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);\n}\n\n// ggml_conv_2d_s1_ph\n\nstruct ggml_tensor * ggml_conv_2d_s1_ph(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);\n}\n\n// ggml_conv_2d_dw\n\nstruct ggml_tensor * ggml_conv_2d_dw(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s0,\n        int                   s1,\n        int                   p0,\n        int                   p1,\n        int                   d0,\n        int                   d1) {\n    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);\n    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,\n         
                               ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),\n                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]\n    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]\n\n    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2],  new_a->ne[3], 1);                       // [OC，1, KH, KW] => [1, OC, 1, KH * KW]\n    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);\n    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]\n\n    return result;\n}\n\n// ggml_conv_2d_dw_direct\n\nstruct ggml_tensor * ggml_conv_2d_dw_direct(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   stride0,\n        int                   stride1,\n        int                   pad0,\n        int                   pad1,\n        int                   dilation0,\n        int                   dilation1) {\n    GGML_ASSERT(a->ne[2] == 1);\n    GGML_ASSERT(a->ne[3] == b->ne[2]);\n    int64_t ne[4];\n    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);\n    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);\n    ne[2] = b->ne[2];\n    ne[3] = b->ne[3];\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);\n\n    if (ggml_is_contiguous_channels(b)) {\n        // Result will be permuted the same way as input (CWHN order)\n        const int64_t type_size = ggml_type_size(result->type);\n        GGML_ASSERT(ggml_blck_size(result->type) == 1);\n        result->nb[0] = result->ne[2] * type_size;\n        result->nb[1] = result->ne[0] * result->nb[0];\n        result->nb[2] = type_size;\n    }\n\n    int32_t 
params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_CONV_2D_DW;\n    result->src[0] = a;\n    result->src[1] = b;\n    return result;\n}\n\n// ggml_conv_2d_direct\n\nstruct ggml_tensor * ggml_conv_2d_direct(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,   // convolution kernel [KW, KH, IC, OC]\n        struct ggml_tensor  * b,   // input data [W, H, C, N]\n        int                   s0,  // stride dimension 0\n        int                   s1,  // stride dimension 1\n        int                   p0,  // padding dimension 0\n        int                   p1,  // padding dimension 1\n        int                   d0,  // dilation dimension 0\n        int                   d1) {// dilation dimension 1\n\n    GGML_ASSERT(a->ne[2] == b->ne[2]);\n    //GGML_ASSERT(a->type == b->type);\n\n    int64_t ne[4];\n    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);\n    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);\n    ne[2] = a->ne[3];\n    ne[3] = b->ne[3];\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);\n\n    ggml_set_op_params_i32(result, 0, s0);\n    ggml_set_op_params_i32(result, 1, s1);\n    ggml_set_op_params_i32(result, 2, p0);\n    ggml_set_op_params_i32(result, 3, p1);\n    ggml_set_op_params_i32(result, 4, d0);\n    ggml_set_op_params_i32(result, 5, d1);\n\n    result->op = GGML_OP_CONV_2D;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_conv_3d_direct\n\nstruct ggml_tensor * ggml_conv_3d_direct(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   s0,\n        int                   s1,\n        int                   s2,\n        int                   p0,\n        int                   p1,\n        int                   p2,\n        int              
     d0,\n        int                   d1,\n        int                   d2,\n        int                   c,\n        int                   n,\n        int                   oc) {\n\n    GGML_ASSERT(a->ne[3] == (int64_t) c * oc);\n    GGML_ASSERT(b->ne[3] == (int64_t) c * n);\n\n    int64_t ne[4];\n    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);\n    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);\n    ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);\n    ne[3] = (int64_t) oc * n;\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    ggml_set_op_params_i32(result, 0,  s0);\n    ggml_set_op_params_i32(result, 1,  s1);\n    ggml_set_op_params_i32(result, 2,  s2);\n    ggml_set_op_params_i32(result, 3,  p0);\n    ggml_set_op_params_i32(result, 4,  p1);\n    ggml_set_op_params_i32(result, 5,  p2);\n    ggml_set_op_params_i32(result, 6,  d0);\n    ggml_set_op_params_i32(result, 7,  d1);\n    ggml_set_op_params_i32(result, 8,  d2);\n    ggml_set_op_params_i32(result, 9,  c);\n    ggml_set_op_params_i32(result, 10, n);\n    ggml_set_op_params_i32(result, 11, oc);\n\n    result->op = GGML_OP_CONV_3D;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_conv_transpose_2d_p0\n\nstatic int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {\n    return (ins - 1) * s - 2 * p + ks;\n}\n\nstruct ggml_tensor * ggml_conv_transpose_2d_p0(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        int                   stride) {\n    GGML_ASSERT(a->ne[3] == b->ne[2]);\n\n    const int64_t ne[4] = {\n        ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),\n        ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),\n        a->ne[2], b->ne[3],\n    };\n\n    struct ggml_tensor* result = ggml_new_tensor(ctx, 
GGML_TYPE_F32, 4, ne);\n\n    ggml_set_op_params_i32(result, 0, stride);\n\n    result->op     = GGML_OP_CONV_TRANSPOSE_2D;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_pool_*\n\nstatic int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {\n    return (ins + 2 * p - ks) / s + 1;\n}\n\n// ggml_pool_1d\n\nstruct ggml_tensor * ggml_pool_1d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        enum ggml_op_pool     op,\n        int                   k0,\n        int                   s0,\n        int                   p0) {\n    const int64_t ne[4] = {\n        ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),\n        a->ne[1],\n        a->ne[2],\n        a->ne[3],\n    };\n    GGML_ASSERT(ne[0] > 0);\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    int32_t params[] = { op, k0, s0, p0 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_POOL_1D;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_pool_2d\n\nstruct ggml_tensor * ggml_pool_2d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        enum ggml_op_pool     op,\n        int                   k0,\n        int                   k1,\n        int                   s0,\n        int                   s1,\n        float                 p0,\n        float                 p1) {\n    struct ggml_tensor * result;\n    const int64_t ne[4] = {\n        ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),\n        ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),\n        a->ne[2],\n        a->ne[3],\n    };\n    GGML_ASSERT(ne[0] > 0);\n    GGML_ASSERT(ne[1] > 0);\n\n    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_POOL_2D;\n    result->src[0] = a;\n\n    return 
result;\n}\n\nstruct ggml_tensor * ggml_pool_2d_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * af,\n        enum ggml_op_pool     op,\n        int                   k0,\n        int                   k1,\n        int                   s0,\n        int                   s1,\n        float                 p0,\n        float                 p1) {\n    struct ggml_tensor * result;\n    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne);\n\n    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_POOL_2D_BACK;\n    result->src[0] = a;\n    result->src[1] = af;\n\n    return result;\n}\n\n// ggml_upscale / ggml_interpolate\n\nstatic struct ggml_tensor * ggml_interpolate_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2,\n        int64_t               ne3,\n        uint32_t              mode) {\n    GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);\n    // TODO: implement antialias for modes other than bilinear\n    GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR);\n\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);\n\n    ggml_set_op_params_i32(result, 0, (int32_t)mode);\n\n    result->op     = GGML_OP_UPSCALE;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_upscale(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   scale_factor,\n        enum ggml_scale_mode  mode) {\n    GGML_ASSERT(scale_factor > 1);\n    return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);\n}\n\nstruct ggml_tensor * ggml_upscale_ext(\n        struct ggml_context * ctx,\n        struct ggml_tensor 
 * a,\n        int                   ne0,\n        int                   ne1,\n        int                   ne2,\n        int                   ne3,\n        enum ggml_scale_mode  mode) {\n    return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);\n}\n\nstruct ggml_tensor * ggml_interpolate(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2,\n        int64_t               ne3,\n        uint32_t              mode) {\n    return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);\n}\n\n// ggml_pad\n\nstruct ggml_tensor * ggml_pad(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   p0,\n        int                   p1,\n        int                   p2,\n        int                   p3) {\n    return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);\n}\n\n// ggml_pad_circular\n\nstruct ggml_tensor * ggml_pad_circular(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   p0,\n        int                   p1,\n        int                   p2,\n        int                   p3) {\n    return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);\n}\n\nstruct ggml_tensor * ggml_pad_ext(\n            struct ggml_context * ctx,\n            struct ggml_tensor  * a,\n            int                  lp0,\n            int                  rp0,\n            int                  lp1,\n            int                  rp1,\n            int                  lp2,\n            int                  rp2,\n            int                  lp3,\n            int                  rp3\n            ) {\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,\n            a->ne[0] + lp0 + rp0,\n            a->ne[1] + lp1 + rp1,\n            a->ne[2] + lp2 + rp2,\n            a->ne[3] + lp3 + rp3);\n\n    ggml_set_op_params_i32(result, 
0, lp0);\n    ggml_set_op_params_i32(result, 1, rp0);\n    ggml_set_op_params_i32(result, 2, lp1);\n    ggml_set_op_params_i32(result, 3, rp1);\n    ggml_set_op_params_i32(result, 4, lp2);\n    ggml_set_op_params_i32(result, 5, rp2);\n    ggml_set_op_params_i32(result, 6, lp3);\n    ggml_set_op_params_i32(result, 7, rp3);\n    ggml_set_op_params_i32(result, 8, 0); // not circular by default\n\n\n    result->op     = GGML_OP_PAD;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_pad_ext_circular\n\nstruct ggml_tensor * ggml_pad_ext_circular(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                  lp0,\n        int                  rp0,\n        int                  lp1,\n        int                  rp1,\n        int                  lp2,\n        int                  rp2,\n        int                  lp3,\n        int                  rp3\n        ) {\n    struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);\n    ggml_set_op_params_i32(result, 8, 1); // circular\n    return result;\n}\n\n// ggml_pad_reflect_1d\n\nstruct ggml_tensor * ggml_pad_reflect_1d(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   p0,\n        int                   p1) {\n    GGML_ASSERT(p0 >= 0);\n    GGML_ASSERT(p1 >= 0);\n\n    GGML_ASSERT(p0 < a->ne[0]); // padding length on each size must be less than the\n    GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded\n\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(a->type == GGML_TYPE_F32);\n\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,\n            a->ne[0] + p0 + p1,\n            a->ne[1],\n            a->ne[2],\n            a->ne[3]);\n\n    int32_t params[] = { p0, p1 };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_PAD_REFLECT_1D;\n    result->src[0] = a;\n\n    return result;\n}\n\n// 
ggml_roll\n\nstruct ggml_tensor * ggml_roll(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   shift0,\n        int                   shift1,\n        int                   shift2,\n        int                   shift3) {\n    GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));\n    GGML_ASSERT(abs(shift0) < a->ne[0]);\n    GGML_ASSERT(abs(shift1) < a->ne[1]);\n    GGML_ASSERT(abs(shift2) < a->ne[2]);\n    GGML_ASSERT(abs(shift3) < a->ne[3]);\n\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params_i32(result, 0, shift0);\n    ggml_set_op_params_i32(result, 1, shift1);\n    ggml_set_op_params_i32(result, 2, shift2);\n    ggml_set_op_params_i32(result, 3, shift3);\n\n    result->op     = GGML_OP_ROLL;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_timestep_embedding\n\nstruct ggml_tensor * ggml_timestep_embedding(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * timesteps,\n        int                   dim,\n        int                   max_period) {\n\n    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);\n\n    ggml_set_op_params_i32(result, 0, dim);\n    ggml_set_op_params_i32(result, 1, max_period);\n\n    result->op     = GGML_OP_TIMESTEP_EMBEDDING;\n    result->src[0] = timesteps;\n\n    return result;\n}\n\n// ggml_tri\n\nstruct ggml_tensor * ggml_tri(\n    struct ggml_context * ctx,\n    struct ggml_tensor  * a,\n    enum ggml_tri_type    type) {\n    GGML_ASSERT(a->type == GGML_TYPE_F32);\n\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(a->ne[0] == a->ne[1]);\n\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params_i32(result, 0, type);\n\n    result->op = GGML_OP_TRI;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_fill\n\nstatic struct ggml_tensor * ggml_fill_impl(\n    struct ggml_context * ctx,\n    struct ggml_tensor  * a,\n    float            
     c,\n    bool                  inplace) {\n    GGML_ASSERT(a->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_is_contiguous(a));\n\n    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params_f32(result, 0, c);\n\n    result->op = GGML_OP_FILL;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_fill(\n    struct ggml_context * ctx,\n    struct ggml_tensor  * a,\n    float                 c) {\n    return ggml_fill_impl(ctx, a, c, false);\n}\n\nstruct ggml_tensor * ggml_fill_inplace(\n    struct ggml_context * ctx,\n    struct ggml_tensor  * a,\n    float                 c) {\n    return ggml_fill_impl(ctx, a, c, true);\n}\n\n// ggml_argsort\n\nstruct ggml_tensor * ggml_argsort(\n        struct ggml_context  * ctx,\n        struct ggml_tensor   * a,\n        enum ggml_sort_order   order) {\n    GGML_ASSERT(a->ne[0] <= INT32_MAX);\n\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);\n\n    ggml_set_op_params_i32(result, 0, (int32_t) order);\n\n    result->op     = GGML_OP_ARGSORT;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_argsort_top_k\n\nstruct ggml_tensor * ggml_argsort_top_k(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   k) {\n    GGML_ASSERT(a->ne[0] >= k);\n\n    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);\n\n    result = ggml_view_4d(ctx, result,\n                k, result->ne[1], result->ne[2], result->ne[3],\n                   result->nb[1], result->nb[2], result->nb[3],\n                0);\n\n    return result;\n}\n\n// ggml_top_k\n\nstruct ggml_tensor * ggml_top_k(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   k) {\n    GGML_ASSERT(a->ne[0] >= k);\n\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);\n\n    
result->op     = GGML_OP_TOP_K;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_arange\n\nstruct ggml_tensor * ggml_arange(\n        struct ggml_context * ctx,\n        float                 start,\n        float                 stop,\n        float                 step) {\n    GGML_ASSERT(stop > start);\n\n    const int64_t steps = (int64_t) ceilf((stop - start) / step);\n\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);\n\n    ggml_set_op_params_f32(result, 0, start);\n    ggml_set_op_params_f32(result, 1, stop);\n    ggml_set_op_params_f32(result, 2, step);\n\n    result->op = GGML_OP_ARANGE;\n\n    return result;\n}\n\n// ggml_flash_attn_ext\n\nstruct ggml_tensor * ggml_flash_attn_ext(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * q,\n        struct ggml_tensor  * k,\n        struct ggml_tensor  * v,\n        struct ggml_tensor  * mask,\n        float                 scale,\n        float                 max_bias,\n        float                 logit_softcap) {\n    GGML_ASSERT(ggml_can_mul_mat(k, q));\n    // TODO: check if vT can be multiplied by (k*qT)\n\n    GGML_ASSERT(q->ne[3] == k->ne[3]);\n    GGML_ASSERT(q->ne[3] == v->ne[3]);\n\n    if (mask) {\n        GGML_ASSERT(ggml_is_contiguous(mask));\n        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));\n\n        GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);\n        GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);\n    }\n\n    if (max_bias > 0.0f) {\n        GGML_ASSERT(mask);\n    }\n\n    // permute(0, 2, 1, 3)\n    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    float params[] = { scale, max_bias, logit_softcap };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_FLASH_ATTN_EXT;\n    result->src[0] = q;\n    result->src[1] = k;\n    result->src[2] = v;\n    result->src[3] = mask;\n\n    return result;\n}\n\nvoid 
ggml_flash_attn_ext_set_prec(\n        struct ggml_tensor * a,\n        enum ggml_prec       prec) {\n    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);\n\n    const int32_t prec_i32 = (int32_t) prec;\n\n    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second\n}\n\nenum ggml_prec ggml_flash_attn_ext_get_prec(\n        const struct ggml_tensor * a) {\n    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);\n\n    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);\n\n    return (enum ggml_prec) prec_i32;\n}\n\nvoid ggml_flash_attn_ext_add_sinks(\n        struct ggml_tensor * a,\n        struct ggml_tensor * sinks) {\n    if (!sinks) {\n        a->src[4] = NULL;\n        return;\n    }\n\n    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);\n    GGML_ASSERT(a->src[4] == NULL);\n    GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);\n    GGML_ASSERT(sinks->type == GGML_TYPE_F32);\n\n    a->src[4] = sinks;\n}\n\n// ggml_flash_attn_back\n\nstruct ggml_tensor * ggml_flash_attn_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * q,\n        struct ggml_tensor  * k,\n        struct ggml_tensor  * v,\n        struct ggml_tensor  * d,\n        bool                  masked) {\n    GGML_ABORT(\"TODO: adapt to ggml_flash_attn_ext() changes\");\n\n    GGML_ASSERT(ggml_can_mul_mat(k, q));\n    // TODO: check if vT can be multiplied by (k*qT)\n\n    // d shape [D,N,ne2,ne3]\n    // q shape [D,N,ne2,ne3]\n    // k shape [D,M,kvne2,ne3]\n    // v shape [M,D,kvne2,ne3]\n\n    const int64_t     D = q->ne[0];\n    const int64_t     N = q->ne[1];\n    const int64_t     M = k->ne[1];\n    const int64_t   ne2 = q->ne[2];\n    const int64_t   ne3 = q->ne[3];\n    const int64_t kvne2 = k->ne[2];\n\n    GGML_ASSERT(k->ne[0] == D);\n    GGML_ASSERT(v->ne[0] == M);\n    GGML_ASSERT(v->ne[1] == D);\n    GGML_ASSERT(d->ne[0] == D);\n    GGML_ASSERT(d->ne[1] == N);\n    GGML_ASSERT(k->ne[2] == kvne2);\n    GGML_ASSERT(k->ne[3] == ne3);\n    
GGML_ASSERT(v->ne[2] == kvne2);\n    GGML_ASSERT(v->ne[3] == ne3);\n    GGML_ASSERT(d->ne[2] == ne2);\n    GGML_ASSERT(d->ne[3] == ne3);\n\n    GGML_ASSERT(ne2 % kvne2 == 0);\n\n    // store gradients of q, k and v as continuous tensors concatenated in result.\n    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.\n    const int64_t elem_q = ggml_nelements(q);\n    const int64_t elem_k = ggml_nelements(k);\n    const int64_t elem_v = ggml_nelements(v);\n\n    enum ggml_type result_type = GGML_TYPE_F32;\n    GGML_ASSERT(ggml_blck_size(result_type) == 1);\n    const size_t tsize = ggml_type_size(result_type);\n\n    const size_t offs_q = 0;\n    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);\n    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);\n    const size_t end    = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);\n\n    const size_t nelements = (end + tsize - 1)/tsize;\n\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);\n\n    int32_t masked_i = masked ? 
1 : 0;\n    ggml_set_op_params(result, &masked_i, sizeof(masked_i));\n\n    result->op     = GGML_OP_FLASH_ATTN_BACK;\n    result->src[0] = q;\n    result->src[1] = k;\n    result->src[2] = v;\n    result->src[3] = d;\n\n    return result;\n}\n\n// ggml_ssm_conv\n\nstruct ggml_tensor * ggml_ssm_conv(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * sx,\n        struct ggml_tensor  * c) {\n    GGML_ASSERT(ggml_is_3d(sx));\n    GGML_ASSERT(ggml_is_matrix(c));\n\n    const int64_t d_conv  = c->ne[0];\n    const int64_t d_inner = c->ne[1];\n    const int64_t n_t     = sx->ne[0] - d_conv + 1; // tokens per sequence\n    const int64_t n_s     = sx->ne[2];\n\n    // TODO: maybe support other strides than 1?\n    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);\n    GGML_ASSERT(sx->ne[1] == d_inner);\n    GGML_ASSERT(n_t >= 0);\n\n    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);\n\n    result->op     = GGML_OP_SSM_CONV;\n    result->src[0] = sx;\n    result->src[1] = c;\n\n    return result;\n}\n\n// ggml_ssm_scan\n\nstruct ggml_tensor * ggml_ssm_scan(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * s,\n        struct ggml_tensor  * x,\n        struct ggml_tensor  * dt,\n        struct ggml_tensor  * A,\n        struct ggml_tensor  * B,\n        struct ggml_tensor  * C,\n        struct ggml_tensor  * ids) {\n    GGML_ASSERT(ggml_is_contiguous(s));\n    GGML_ASSERT(ggml_is_contiguous(dt));\n    GGML_ASSERT(ggml_is_contiguous(A));\n    GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));\n    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));\n    GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));\n    GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);\n    GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);\n    GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);\n    GGML_ASSERT(ggml_are_same_shape(B, C));\n    GGML_ASSERT(ids->type == GGML_TYPE_I32);\n\n    {\n        const int64_t d_state      = s->ne[0];\n        const 
int64_t head_dim     = x->ne[0];\n        const int64_t n_head       = x->ne[1];\n        const int64_t n_seq_tokens = x->ne[2];\n        const int64_t n_seqs       = x->ne[3];\n\n        GGML_ASSERT(dt->ne[0] == n_head);\n        GGML_ASSERT(dt->ne[1] == n_seq_tokens);\n        GGML_ASSERT(dt->ne[2] == n_seqs);\n        GGML_ASSERT(ggml_is_3d(dt));\n        GGML_ASSERT(s->ne[1] == head_dim);\n        GGML_ASSERT(s->ne[2] == n_head);\n        GGML_ASSERT(B->ne[0] == d_state);\n        GGML_ASSERT(B->ne[2] == n_seq_tokens);\n        GGML_ASSERT(B->ne[3] == n_seqs);\n        GGML_ASSERT(ids->ne[0] == n_seqs);\n        GGML_ASSERT(ggml_is_vector(ids));\n        GGML_ASSERT(A->ne[1] == n_head);\n        GGML_ASSERT(ggml_is_matrix(A));\n\n        if (A->ne[0] != 1) {\n            // Mamba-1 has more granular decay factors\n            GGML_ASSERT(A->ne[0] == d_state);\n        }\n    }\n\n    // concatenated y + ssm_states\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);\n\n    result->op   = GGML_OP_SSM_SCAN;\n    result->src[0] = s;\n    result->src[1] = x;\n    result->src[2] = dt;\n    result->src[3] = A;\n    result->src[4] = B;\n    result->src[5] = C;\n    result->src[6] = ids;\n\n    return result;\n}\n\n// ggml_win_part\n\nstruct ggml_tensor * ggml_win_part(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   w) {\n    GGML_ASSERT(a->ne[3] == 1);\n    GGML_ASSERT(a->type  == GGML_TYPE_F32);\n\n    // padding\n    const int px = (w - a->ne[1]%w)%w;\n    const int py = (w - a->ne[2]%w)%w;\n\n    const int npx = (px + a->ne[1])/w;\n    const int npy = (py + a->ne[2])/w;\n    const int np  = npx*npy;\n\n    const int64_t ne[4] = { a->ne[0], w, w, np, };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    int32_t params[] = { npx, npy, w };\n    ggml_set_op_params(result, params, sizeof(params));\n\n  
  result->op     = GGML_OP_WIN_PART;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_win_unpart\n\nstruct ggml_tensor * ggml_win_unpart(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   w0,\n        int                   h0,\n        int                   w) {\n    GGML_ASSERT(a->type == GGML_TYPE_F32);\n\n    const int64_t ne[4] = { a->ne[0], w0, h0, 1, };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);\n\n    int32_t params[] = { w };\n    ggml_set_op_params(result, params, sizeof(params));\n\n    result->op     = GGML_OP_WIN_UNPART;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_get_rel_pos\n\nstruct ggml_tensor * ggml_get_rel_pos(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        int                   qh,\n        int                   kh) {\n    GGML_ASSERT(qh == kh);\n    GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);\n\n    const int64_t ne[4] = { a->ne[0], kh, qh, 1, };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);\n\n    result->op     = GGML_OP_GET_REL_POS;\n    result->src[0] = a;\n\n    return result;\n}\n\n// ggml_add_rel_pos\n\nstatic struct ggml_tensor * ggml_add_rel_pos_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * pw,\n        struct ggml_tensor  * ph,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_are_same_shape(pw, ph));\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_is_contiguous(pw));\n    GGML_ASSERT(ggml_is_contiguous(ph));\n    GGML_ASSERT(ph->type == GGML_TYPE_F32);\n    GGML_ASSERT(pw->type == GGML_TYPE_F32);\n    GGML_ASSERT(pw->ne[3] == a->ne[2]);\n    GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);\n    GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);\n\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n    ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);\n\n    result->op     = GGML_OP_ADD_REL_POS;\n    result->src[0] = a;\n    result->src[1] = pw;\n    result->src[2] = ph;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_add_rel_pos(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * pw,\n        struct ggml_tensor  * ph) {\n    return ggml_add_rel_pos_impl(ctx, a, pw, ph, false);\n}\n\nstruct ggml_tensor * ggml_add_rel_pos_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * pw,\n        struct ggml_tensor  * ph) {\n    return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);\n}\n\n// ggml_rwkv_wkv6\n\nstruct ggml_tensor * ggml_rwkv_wkv6(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * k,\n        struct ggml_tensor  * v,\n        struct ggml_tensor  * r,\n        struct ggml_tensor  * tf,\n        struct ggml_tensor  * td,\n        struct ggml_tensor  * state) {\n    GGML_ASSERT(ggml_is_contiguous(k));\n    GGML_ASSERT(ggml_is_contiguous(v));\n    GGML_ASSERT(ggml_is_contiguous(r));\n    GGML_ASSERT(ggml_is_contiguous(tf));\n    GGML_ASSERT(ggml_is_contiguous(td));\n    GGML_ASSERT(ggml_is_contiguous(state));\n\n    const int64_t S = k->ne[0];\n    const int64_t H = k->ne[1];\n    const int64_t n_tokens = k->ne[2];\n    const int64_t n_seqs = state->ne[1];\n    {\n        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);\n        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);\n        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);\n        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);\n    }\n\n    // concat output and new_state\n    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    result->op     = 
GGML_OP_RWKV_WKV6;\n    result->src[0] = k;\n    result->src[1] = v;\n    result->src[2] = r;\n    result->src[3] = tf;\n    result->src[4] = td;\n    result->src[5] = state;\n\n    return result;\n}\n\n// ggml_gated_linear_attn\n\nstruct ggml_tensor * ggml_gated_linear_attn(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * k,\n        struct ggml_tensor  * v,\n        struct ggml_tensor  * q,\n        struct ggml_tensor  * g,\n        struct ggml_tensor  * state,\n        float scale) {\n    GGML_ASSERT(ggml_is_contiguous(k));\n    GGML_ASSERT(ggml_is_contiguous(v));\n    GGML_ASSERT(ggml_is_contiguous(q));\n    GGML_ASSERT(ggml_is_contiguous(g));\n    GGML_ASSERT(ggml_is_contiguous(state));\n\n    const int64_t S = k->ne[0];\n    const int64_t H = k->ne[1];\n    const int64_t n_tokens = k->ne[2];\n    const int64_t n_seqs = state->ne[1];\n    {\n        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);\n        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);\n        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);\n        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);\n    }\n\n    // concat output and new_state\n    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    ggml_set_op_params_f32(result, 0, scale);\n\n    result->op     = GGML_OP_GATED_LINEAR_ATTN;\n    result->src[0] = k;\n    result->src[1] = v;\n    result->src[2] = q;\n    result->src[3] = g;\n    result->src[4] = state;\n\n    return result;\n}\n\n// ggml_rwkv_wkv7\n\nstruct ggml_tensor * ggml_rwkv_wkv7(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * r,\n        struct ggml_tensor  * w,\n        struct ggml_tensor  * k,\n        struct ggml_tensor  * v,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * state) {\n    
GGML_ASSERT(ggml_is_contiguous(r));\n    GGML_ASSERT(ggml_is_contiguous(w));\n    GGML_ASSERT(ggml_is_contiguous(k));\n    GGML_ASSERT(ggml_is_contiguous(v));\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_is_contiguous(b));\n    GGML_ASSERT(ggml_is_contiguous(state));\n\n    const int64_t S = k->ne[0];\n    const int64_t H = k->ne[1];\n    const int64_t n_tokens = k->ne[2];\n    const int64_t n_seqs = state->ne[1];\n    {\n        GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);\n        GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);\n        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);\n        GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);\n        GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);\n        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);\n    }\n\n    // concat output and new_state\n    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };\n    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    result->op     = GGML_OP_RWKV_WKV7;\n    result->src[0] = r;\n    result->src[1] = w;\n    result->src[2] = k;\n    result->src[3] = v;\n    result->src[4] = a;\n    result->src[5] = b;\n    result->src[6] = state;\n\n    return result;\n}\n\n// ggml_unary\n\nstatic struct ggml_tensor * ggml_unary_impl(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        enum ggml_unary_op    op,\n        bool                  inplace) {\n    GGML_ASSERT(ggml_is_contiguous_rows(a));\n\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    ggml_set_op_params_i32(result, 0, (int32_t) op);\n\n    result->op     = GGML_OP_UNARY;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_unary(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        enum ggml_unary_op    op) {\n    return ggml_unary_impl(ctx, a, op, false);\n}\n\nstruct ggml_tensor * ggml_unary_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        enum ggml_unary_op    op) {\n    return ggml_unary_impl(ctx, a, op, true);\n}\n\n// ggml_map_custom1\n\nstatic struct ggml_tensor * ggml_map_custom1_impl(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        const  ggml_custom1_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata,\n        bool                       inplace) {\n    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);\n\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    struct ggml_map_custom1_op_params params = {\n        /*.fun      =*/ fun,\n        /*.n_tasks  =*/ n_tasks,\n        /*.userdata =*/ userdata\n    };\n    ggml_set_op_params(result, &params, sizeof(params));\n\n    result->op     = GGML_OP_MAP_CUSTOM1;\n    result->src[0] = a;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_map_custom1(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        const  ggml_custom1_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata) {\n    return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);\n}\n\nstruct ggml_tensor * ggml_map_custom1_inplace(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        const  ggml_custom1_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata) {\n    return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);\n}\n\n// ggml_map_custom2\n\nstatic struct ggml_tensor * ggml_map_custom2_impl(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        struct ggml_tensor       * b,\n        const  ggml_custom2_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata,\n        bool                       inplace) {\n    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);\n\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    struct ggml_map_custom2_op_params params = {\n        /*.fun      =*/ fun,\n        /*.n_tasks  =*/ n_tasks,\n        /*.userdata =*/ userdata\n    };\n    ggml_set_op_params(result, &params, sizeof(params));\n\n    result->op     = GGML_OP_MAP_CUSTOM2;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_map_custom2(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        struct ggml_tensor       * b,\n        const  ggml_custom2_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata) {\n    return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);\n}\n\nstruct ggml_tensor * ggml_map_custom2_inplace(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        struct ggml_tensor       * b,\n        const  ggml_custom2_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata) {\n    return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);\n}\n\n// ggml_map_custom3\n\nstatic struct ggml_tensor * ggml_map_custom3_impl(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        struct ggml_tensor       * b,\n        struct ggml_tensor       * c,\n        const  ggml_custom3_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata,\n        bool                       inplace) {\n    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);\n\n    struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);\n\n    struct ggml_map_custom3_op_params params = {\n        /*.fun      =*/ fun,\n        /*.n_tasks  =*/ n_tasks,\n        /*.userdata =*/ userdata\n    };\n    ggml_set_op_params(result, &params, sizeof(params));\n\n    result->op     = GGML_OP_MAP_CUSTOM3;\n    result->src[0] = a;\n    result->src[1] = b;\n    result->src[2] = c;\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_map_custom3(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        struct ggml_tensor       * b,\n        struct ggml_tensor       * c,\n        const  ggml_custom3_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata) {\n    return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);\n}\n\nstruct ggml_tensor * ggml_map_custom3_inplace(\n        struct ggml_context      * ctx,\n        struct ggml_tensor       * a,\n        struct ggml_tensor       * b,\n        struct ggml_tensor       * c,\n        const  ggml_custom3_op_t   fun,\n        int                        n_tasks,\n        void                     * userdata) {\n    return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);\n}\n\nstruct ggml_tensor * ggml_custom_4d(\n        struct ggml_context * ctx,\n        enum ggml_type        type,\n        int64_t               ne0,\n        int64_t               ne1,\n        int64_t               ne2,\n        int64_t               ne3,\n        struct ggml_tensor ** args,\n        int                   n_args,\n        ggml_custom_op_t      fun,\n        int                   n_tasks,\n        void                * userdata) {\n\n    GGML_ASSERT(n_args < GGML_MAX_SRC);\n\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);\n\n    struct ggml_custom_op_params params = {\n        /*.fun      =*/ fun,\n        /*.n_tasks  =*/ n_tasks,\n        /*.userdata =*/ userdata\n    };\n    
ggml_set_op_params(result, &params, sizeof(params));\n\n    result->op = GGML_OP_CUSTOM;\n    for (int i = 0; i < n_args; i++) {\n        result->src[i] = args[i];\n    }\n\n    return result;\n}\n\nstruct ggml_tensor * ggml_custom_inplace(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor ** args,\n        int                   n_args,\n        ggml_custom_op_t      fun,\n        int                   n_tasks,\n        void                * userdata) {\n\n    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);\n\n    struct ggml_tensor * result = ggml_view_tensor(ctx, a);\n\n    struct ggml_custom_op_params params = {\n        /*.fun      =*/ fun,\n        /*.n_tasks  =*/ n_tasks,\n        /*.userdata =*/ userdata\n    };\n    ggml_set_op_params(result, &params, sizeof(params));\n\n    result->op = GGML_OP_CUSTOM;\n    result->src[0] = a;\n    for (int i = 0; i < n_args; i++) {\n        result->src[i + 1] = args[i];\n    }\n\n    return result;\n}\n// ggml_cross_entropy_loss\n\nstruct ggml_tensor * ggml_cross_entropy_loss(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b) {\n    GGML_ASSERT(ggml_are_same_shape(a, b));\n\n    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);\n\n    result->op     = GGML_OP_CROSS_ENTROPY_LOSS;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_cross_entropy_loss_back\n\nstruct ggml_tensor * ggml_cross_entropy_loss_back(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        struct ggml_tensor  * c) {\n    GGML_ASSERT(ggml_is_scalar(a));\n    GGML_ASSERT(ggml_are_same_shape(b, c));\n\n    struct ggml_tensor * result = ggml_dup_tensor(ctx, b);\n\n    result->op     = GGML_OP_CROSS_ENTROPY_LOSS_BACK;\n    result->src[0] = a;\n    result->src[1] = b;\n    result->src[2] = c;\n\n    return result;\n}\n\n// opt_step_adamw\n\nstruct 
ggml_tensor * ggml_opt_step_adamw(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * grad,\n        struct ggml_tensor  * m,\n        struct ggml_tensor  * v,\n        struct ggml_tensor  * adamw_params) {\n    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);\n    GGML_ASSERT(ggml_are_same_shape(a, grad));\n    GGML_ASSERT(ggml_are_same_shape(a, m));\n    GGML_ASSERT(ggml_are_same_shape(a, v));\n    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_nelements(adamw_params) == 7);\n\n    struct ggml_tensor * result = ggml_view_tensor(ctx, a);\n\n    result->op     = GGML_OP_OPT_STEP_ADAMW;\n    result->src[0] = a;\n    result->src[1] = grad;\n    result->src[2] = m;\n    result->src[3] = v;\n    result->src[4] = adamw_params;\n\n    return result;\n}\n\n// opt_step_sgd\n\nstruct ggml_tensor * ggml_opt_step_sgd(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * grad,\n        struct ggml_tensor  * params) {\n    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);\n    GGML_ASSERT(ggml_are_same_shape(a, grad));\n    GGML_ASSERT(params->type == GGML_TYPE_F32);\n    GGML_ASSERT(ggml_nelements(params) == 2);\n\n    struct ggml_tensor * result = ggml_view_tensor(ctx, a);\n\n    result->op     = GGML_OP_OPT_STEP_SGD;\n    result->src[0] = a;\n    result->src[1] = grad;\n    result->src[2] = params;\n\n    return result;\n}\n\n// solve_tri\n\nstruct ggml_tensor * ggml_solve_tri(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * a,\n        struct ggml_tensor  * b,\n        bool                  left,\n        bool                  lower,\n        bool                  uni) {\n    GGML_ASSERT(a->type == GGML_TYPE_F32);\n    GGML_ASSERT(b->type == GGML_TYPE_F32);\n\n    // A must be square and lower triangular\n    GGML_ASSERT(a->ne[0] == a->ne[1]);\n    // B must have same outer dimension as A\n    GGML_ASSERT(a->ne[1] == b->ne[1]);\n\n    // 
batch dimensions must be equal\n    GGML_ASSERT(a->ne[2] == b->ne[2]);\n    GGML_ASSERT(a->ne[3] == b->ne[3]);\n\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_is_contiguous(b));\n\n    GGML_ASSERT(lower && left && !uni); // TODO: support other variants\n\n    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]);\n\n    result->op     = GGML_OP_SOLVE_TRI;\n    result->src[0] = a;\n    result->src[1] = b;\n\n    return result;\n}\n\n// ggml_gated_delta_net\n\nstruct ggml_tensor * ggml_gated_delta_net(\n        struct ggml_context * ctx,\n        struct ggml_tensor  * q,\n        struct ggml_tensor  * k,\n        struct ggml_tensor  * v,\n        struct ggml_tensor  * g,\n        struct ggml_tensor  * beta,\n        struct ggml_tensor  * state) {\n    GGML_ASSERT(ggml_is_contiguous_rows(q));\n    GGML_ASSERT(ggml_is_contiguous_rows(k));\n    GGML_ASSERT(ggml_is_contiguous_rows(v));\n    GGML_ASSERT(ggml_is_contiguous(g));\n    GGML_ASSERT(ggml_is_contiguous(beta));\n    GGML_ASSERT(ggml_is_contiguous(state));\n\n    GGML_ASSERT(q->type == GGML_TYPE_F32);\n    GGML_ASSERT(k->type == GGML_TYPE_F32);\n    GGML_ASSERT(v->type == GGML_TYPE_F32);\n    GGML_ASSERT(g->type == GGML_TYPE_F32);\n    GGML_ASSERT(beta->type == GGML_TYPE_F32);\n    GGML_ASSERT(state->type == GGML_TYPE_F32);\n\n    const int64_t S_v      = v->ne[0];\n    const int64_t H        = v->ne[1];\n    const int64_t n_tokens = v->ne[2];\n    const int64_t n_seqs   = v->ne[3];\n\n    // gate: scalar [1, H, T, B] or vector [S_v, H, T, B] (KDA)\n    GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);\n    GGML_ASSERT(beta->ne[0] == 1);\n\n    GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);\n\n    // concat output and new_state into a single tensor\n    // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs\n    const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };\n    struct ggml_tensor * 
result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n\n    result->op     = GGML_OP_GATED_DELTA_NET;\n    result->src[0] = q;\n    result->src[1] = k;\n    result->src[2] = v;\n    result->src[3] = g;\n    result->src[4] = beta;\n    result->src[5] = state;\n\n    return result;\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\nstruct ggml_hash_set ggml_hash_set_new(size_t size) {\n    size = ggml_hash_size(size);\n    struct ggml_hash_set result;\n    result.size = size;\n    result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size);\n    result.used = GGML_CALLOC(ggml_bitset_size(size), sizeof(ggml_bitset_t));\n    return result;\n}\n\nvoid ggml_hash_set_reset(struct ggml_hash_set * hash_set) {\n    memset(hash_set->used, 0, sizeof(ggml_bitset_t) * ggml_bitset_size(hash_set->size));\n}\n\nvoid ggml_hash_set_free(struct ggml_hash_set * hash_set) {\n    GGML_FREE(hash_set->used);\n    GGML_FREE(hash_set->keys);\n}\n\nsize_t ggml_hash_size(size_t min_sz) {\n    // next primes after powers of two\n    static const size_t primes[] = {\n        2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,\n        2053, 4099, 8209, 16411, 32771, 65537, 131101,\n        262147, 524309, 1048583, 2097169, 4194319, 8388617,\n        16777259, 33554467, 67108879, 134217757, 268435459,\n        536870923, 1073741827, 2147483659\n    };\n    static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);\n\n    // find the smallest prime that is larger or equal than min_sz\n    size_t l = 0;\n    size_t r = n_primes;\n    while (l < r) {\n        size_t m = (l + r)/2;\n        if (primes[m] < min_sz) {\n            l = m + 1;\n        } else {\n            r = m;\n        }\n    }\n    size_t sz = l < n_primes ? 
primes[l] : min_sz | 1;\n    return sz;\n}\n\nstruct hash_map {\n    struct ggml_hash_set set;\n    struct ggml_tensor ** vals;\n};\n\nstatic struct hash_map * ggml_new_hash_map(size_t size) {\n    struct hash_map * result = GGML_MALLOC(sizeof(struct hash_map));\n    result->set = ggml_hash_set_new(size);\n    result->vals = GGML_CALLOC(result->set.size, sizeof(struct ggml_tensor *));\n    return result;\n}\n\nstatic void ggml_hash_map_free(struct hash_map * map) {\n    ggml_hash_set_free(&map->set);\n    GGML_FREE(map->vals);\n    GGML_FREE(map);\n}\n\n// utility functions to change gradients\n// isrc is the index of tensor in cgraph->visited_hash_set.keys\n// the corresponding gradients (accumulators) are also at position isrc\n// if tensor has a gradient accumulator, modify that accumulator in-place\n// else if there is no gradient for tensor, set the corresponding value\n// else, just add/subtract/etc. the gradients\n\nstatic void ggml_add_or_set(\n        struct ggml_context * ctx,\n        struct ggml_cgraph  * cgraph,\n        size_t                isrc,\n        struct ggml_tensor  * tensor) {\n    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];\n    GGML_ASSERT(src);\n    if (cgraph->grads[isrc]) {\n        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);\n    } else {\n        cgraph->grads[isrc] = tensor;\n    }\n    ggml_format_name(cgraph->grads[isrc], \"grad for %s\", src->name);\n    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);\n}\n\nstatic void ggml_acc_or_set(\n        struct ggml_context * ctx,\n        struct ggml_cgraph  * cgraph,\n        size_t                isrc,\n        struct ggml_tensor  * tensor,\n        const  size_t         nb1,\n        const  size_t         nb2,\n        const  size_t         nb3,\n        const  size_t         offset) {\n    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];\n    GGML_ASSERT(src);\n    if 
(cgraph->grads[isrc]) {\n        cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);\n    } else {\n        struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN\n        cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);\n    }\n    ggml_format_name(cgraph->grads[isrc], \"grad for %s\", cgraph->visited_hash_set.keys[isrc]->name);\n    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);\n}\n\nstatic void ggml_add1_or_set(\n        struct ggml_context * ctx,\n        struct ggml_cgraph  * cgraph,\n        size_t                isrc,\n        struct ggml_tensor  * tensor) {\n    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];\n    GGML_ASSERT(src);\n    if (cgraph->grads[isrc]) {\n        cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);\n    } else {\n        cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);\n    }\n    ggml_format_name(cgraph->grads[isrc], \"grad for %s\", src->name);\n    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);\n}\n\nstatic void ggml_sub_or_set(\n        struct ggml_context * ctx,\n        struct ggml_cgraph  * cgraph,\n        size_t                isrc,\n        struct ggml_tensor  * tensor) {\n    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];\n    GGML_ASSERT(src);\n    if (cgraph->grads[isrc]) {\n        cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);\n    } else {\n        cgraph->grads[isrc] = ggml_neg(ctx, tensor);\n    }\n    ggml_format_name(cgraph->grads[isrc], \"grad for %s\", src->name);\n    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);\n}\n\nstatic void ggml_compute_backward(\n        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {\n    struct ggml_tensor * tensor = 
cgraph->nodes[i];\n    struct ggml_tensor * grad   = ggml_graph_get_grad(cgraph, tensor);\n\n    if (!grad) {\n        return;\n    }\n\n    struct ggml_tensor * src0 = tensor->src[0];\n    struct ggml_tensor * src1 = tensor->src[1];\n    struct ggml_tensor * src2 = tensor->src[2];\n    struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;\n    const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;\n    const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;\n    const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;\n    const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];\n    const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];\n    const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];\n\n    switch (tensor->op) {\n        case GGML_OP_DUP: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, grad);\n            }\n        } break;\n        case GGML_OP_ADD: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, grad);\n            }\n            if (src1_needs_grads) {\n                struct ggml_tensor * tmp = grad;\n                if (!ggml_are_same_shape(src0, src1)) {\n                    tmp = ggml_repeat_back(ctx, tmp, src1);\n                }\n                ggml_add_or_set(ctx, cgraph, isrc1, tmp);\n            }\n        } break;\n        case GGML_OP_ADD1: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, grad);\n            }\n            if (src1_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean\n            }\n        } break;\n        case 
GGML_OP_ACC: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, grad);\n            }\n            if (src1_needs_grads) {\n                const size_t nb1    = ((int32_t *) tensor->op_params)[0];\n                const size_t nb2    = ((int32_t *) tensor->op_params)[1];\n                const size_t nb3    = ((int32_t *) tensor->op_params)[2];\n                const size_t offset = ((int32_t *) tensor->op_params)[3];\n\n                struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,\n                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],\n                    nb1, nb2, nb3, offset);\n\n                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));\n            }\n        } break;\n        case GGML_OP_SUB: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, grad);\n            }\n            if (src1_needs_grads) {\n                ggml_sub_or_set(ctx, cgraph, isrc1, grad);\n            }\n        } break;\n        case GGML_OP_MUL: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));\n            }\n            if (src1_needs_grads) {\n                struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);\n                if (!ggml_are_same_shape(src0, src1)) {\n                    tmp = ggml_repeat_back(ctx, tmp, src1);\n                }\n                ggml_add_or_set(ctx, cgraph, isrc1, tmp);\n            }\n        } break;\n        case GGML_OP_DIV: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1));\n            }\n            if (src1_needs_grads) {\n                ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1)));\n            }\n        } break;\n        case GGML_OP_SQR: {\n            if (src0_needs_grads) {\n              
  ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f));\n            }\n        } break;\n        case GGML_OP_SQRT: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f));\n            }\n        } break;\n        case GGML_OP_LOG: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0));\n            }\n        } break;\n        case GGML_OP_SIN: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0)));\n            }\n        } break;\n        case GGML_OP_COS: {\n            if (src0_needs_grads) {\n                ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0)));\n            }\n        } break;\n        case GGML_OP_SUM: {\n            if (src0_needs_grads) {\n                ggml_add1_or_set(ctx, cgraph, isrc0, grad);\n            }\n        } break;\n        case GGML_OP_SUM_ROWS: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));\n            }\n        } break;\n        case GGML_OP_MEAN: {\n            if (src0_needs_grads) {\n                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));\n            }\n        } break;\n        case GGML_OP_REPEAT: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));\n            }\n        } break;\n        case GGML_OP_REPEAT_BACK: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));\n            }\n        } break;\n        case GGML_OP_RMS_NORM: {\n            if (src0_needs_grads) {\n                float eps;\n                memcpy(&eps, tensor->op_params, 
sizeof(float));\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));\n            }\n        } break;\n        case GGML_OP_MUL_MAT: {\n            // https://cs231n.github.io/optimization-2/#staged\n            // # forward pass\n            // s0 = np.random.randn(5, 10)\n            // s1 = np.random.randn(10, 3)\n            // t = s0.dot(s1)\n\n            // # now suppose we had the gradient on t from above in the circuit\n            // dt = np.random.randn(*t.shape) # same shape as t\n            // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix\n            // ds1 = s0.T.dot(dt)\n\n            // tensor.shape [m,p,qq,rr]\n            // src0.shape   [n,m,q1,r1]\n            // src1.shape   [n,p,qq,rr]\n\n            if (src0_needs_grads) {\n                GGML_ASSERT(grad->ne[2] == src1->ne[2]);\n                GGML_ASSERT(grad->ne[3] == src1->ne[3]);\n                struct ggml_tensor * tmp =\n                    ggml_out_prod(ctx, // [n,m,qq,rr]\n                        src1,          // [n,p,qq,rr]\n                        grad);         // [m,p,qq,rr]\n                if (!ggml_are_same_shape(tmp, src0)) {\n                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);\n                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);\n                    GGML_ASSERT(tmp->ne[3] == 1);\n\n                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];\n                    const size_t nb2 = tmp->nb[2] * nr2;\n                    const size_t nb3 = tmp->nb[2];\n\n                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);\n                    tmp = ggml_repeat_back(ctx, tmp, src0);\n                }\n                ggml_add_or_set(ctx, cgraph, isrc0, tmp);\n            }\n            if (src1_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc1,\n                        // ggml_mul_mat(ctx,                   // [n,p,qq,rr]\n          
              //     ggml_cont(ctx,                  // [m,n,q1,r1]\n                        //         ggml_transpose(ctx, src0)), // [m,n,q1,r1]\n                        //     grad),                          // [m,p,qq,rr]\n\n                        // when src0 is bigger than tensor->grad (this is mostly the case in llama),\n                        // avoid transpose of src0, rather transpose smaller tensor->grad\n                        // and then use ggml_out_prod\n                        ggml_out_prod(ctx,      // [n,p,qq,rr]\n                            src0,               // [n,m,q1,r1]\n                            ggml_transpose(ctx, // [p,m,qq,rr]\n                                grad)));        // [m,p,qq,rr]\n            }\n        } break;\n        case GGML_OP_SCALE: {\n            if (src0_needs_grads) {\n                float s;\n                memcpy(&s, tensor->op_params, sizeof(float));\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));\n            }\n        } break;\n        case GGML_OP_SET: {\n            const size_t nb1    = ((const int32_t *) tensor->op_params)[0];\n            const size_t nb2    = ((const int32_t *) tensor->op_params)[1];\n            const size_t nb3    = ((const int32_t *) tensor->op_params)[2];\n            const size_t offset = ((const int32_t *) tensor->op_params)[3];\n\n            struct ggml_tensor * tensor_grad_view = NULL;\n\n            if (src0_needs_grads || src1_needs_grads) {\n                GGML_ASSERT(src0->type == tensor->type);\n                GGML_ASSERT(!cgraph->grads[isrc0] ||                      cgraph->grads[isrc0]->type == grad->type);\n                GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);\n\n                tensor_grad_view = ggml_view_4d(ctx,\n                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],\n                    nb1, nb2, nb3, offset);\n            
}\n\n            if (src0_needs_grads) {\n                struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view);\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));\n            }\n\n            if (src1_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));\n            }\n        } break;\n        case GGML_OP_CPY: {\n            // cpy overwrites value of src1 by src0 and returns view(src1)\n            // the overwriting is mathematically equivalent to:\n            // tensor = src0 * 1 + src1 * 0\n            if (src0_needs_grads) {\n                // dsrc0 = dtensor * 1\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));\n            }\n            if (src1_needs_grads) {\n                // dsrc1 = dtensor * 0 -> noop\n            }\n        } break;\n        case GGML_OP_CONT: {\n            // same as cpy\n            if (src0_needs_grads) {\n                GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));\n                GGML_ASSERT(ggml_is_contiguous(grad));\n                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));\n                ggml_add_or_set(ctx, cgraph, isrc0,\n                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));\n            }\n        } break;\n        case GGML_OP_RESHAPE: {\n            if (src0_needs_grads) {\n                struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? 
grad : ggml_cont(ctx, grad);\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0));\n            }\n        } break;\n        case GGML_OP_VIEW: {\n            if (src0_needs_grads) {\n                size_t offset;\n\n                memcpy(&offset, tensor->op_params, sizeof(offset));\n\n                size_t nb1 = tensor->nb[1];\n                size_t nb2 = tensor->nb[2];\n                size_t nb3 = tensor->nb[3];\n\n                if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {\n                    // gradient is typically F32, but src0 could be other type\n                    size_t ng = ggml_element_size(cgraph->grads[isrc0]);\n                    size_t n0 = ggml_element_size(src0);\n                    GGML_ASSERT(offset % n0 == 0);\n                    GGML_ASSERT(nb1 % n0 == 0);\n                    GGML_ASSERT(nb2 % n0 == 0);\n                    GGML_ASSERT(nb3 % n0 == 0);\n                    offset = (offset / n0) * ng;\n                    nb1 = (nb1 / n0) * ng;\n                    nb2 = (nb2 / n0) * ng;\n                    nb3 = (nb3 / n0) * ng;\n                }\n\n                ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);\n            }\n        } break;\n        case GGML_OP_PERMUTE: {\n            if (src0_needs_grads) {\n                const int32_t * axes = (const int32_t *) tensor->op_params;\n                const int axis0 = axes[0] & 0x3;\n                const int axis1 = axes[1] & 0x3;\n                const int axis2 = axes[2] & 0x3;\n                const int axis3 = axes[3] & 0x3;\n                int axb[4] = {0,0,0,0}; // axes backward\n                axb[axis0] = 0;\n                axb[axis1] = 1;\n                axb[axis2] = 2;\n                axb[axis3] = 3;\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));\n            }\n        } break;\n        case GGML_OP_TRANSPOSE: {\n   
         if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad));\n            }\n        } break;\n        case GGML_OP_GET_ROWS: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));\n            }\n            if (src1_needs_grads) {\n                // noop\n            }\n        } break;\n        case GGML_OP_DIAG_MASK_INF: {\n            if (src0_needs_grads) {\n                /* ggml_diag_mask_inf_impl() shouldn't be here */\n                /* ref:  https://github.com/ggml-org/llama.cpp/pull/4203#discussion_r1412377992 */\n                const int n_past = ((const int32_t *) tensor->op_params)[0];\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));\n            }\n        } break;\n        case GGML_OP_DIAG_MASK_ZERO: {\n            if (src0_needs_grads) {\n                const int n_past = ((const int32_t *) tensor->op_params)[0];\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));\n            }\n        } break;\n        case GGML_OP_SOFT_MAX: {\n            if (src0_needs_grads) {\n                float scale    = 1.0f;\n                float max_bias = 0.0f;\n\n                memcpy(&scale,    (const float *) tensor->op_params + 0, sizeof(float));\n                memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));\n\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));\n            }\n            GGML_ASSERT((!src1 || !src1_needs_grads) && \"backward pass for softmax mask not implemented\");\n        } break;\n        case GGML_OP_ROPE: {\n            if (src0_needs_grads) {\n                //const int n_past = ((int32_t *) tensor->op_params)[0];\n                const int n_dims     = ((const int32_t *) tensor->op_params)[1];\n     
           const int mode       = ((const int32_t *) tensor->op_params)[2];\n                //const int n_ctx      = ((int32_t *) tensor->op_params)[3];\n                const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];\n                float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;\n                int sections[4] = {0, 0, 0, 0};\n\n                memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));\n                memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));\n                memcpy(&ext_factor,  (const float *) tensor->op_params +  7, sizeof(float));\n                memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));\n                memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));\n                memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));\n                memcpy(&sections,                    tensor->op_params + 11, sizeof(sections));\n\n                struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?\n                    ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,\n                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :\n                    ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,\n                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);\n                ggml_add_or_set(ctx, cgraph, isrc0, rope_back);\n            }\n            GGML_ASSERT((!src2 || !src2_needs_grads) && \"gradients for freq factors not implemented\");\n        } break;\n        case GGML_OP_IM2COL: {\n            if (src1_needs_grads) {\n                const int32_t s0    = ggml_get_op_params_i32(tensor, 0);\n                const int32_t s1    = ggml_get_op_params_i32(tensor, 1);\n                const int32_t p0    = ggml_get_op_params_i32(tensor, 
2);\n                const int32_t p1    = ggml_get_op_params_i32(tensor, 3);\n                const int32_t d0    = ggml_get_op_params_i32(tensor, 4);\n                const int32_t d1    = ggml_get_op_params_i32(tensor, 5);\n                const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;\n\n                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));\n            }\n        } break;\n        case GGML_OP_POOL_2D: {\n            if (src0_needs_grads) {\n                const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);\n                const      int32_t      k0 = ggml_get_op_params_i32(tensor, 1);\n                const      int32_t      k1 = ggml_get_op_params_i32(tensor, 2);\n                const      int32_t      s0 = ggml_get_op_params_i32(tensor, 3);\n                const      int32_t      s1 = ggml_get_op_params_i32(tensor, 4);\n                const      int32_t      p0 = ggml_get_op_params_i32(tensor, 5);\n                const      int32_t      p1 = ggml_get_op_params_i32(tensor, 6);\n\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));\n            }\n        } break;\n        case GGML_OP_WIN_PART:\n        case GGML_OP_WIN_UNPART:\n        case GGML_OP_UNARY: {\n            switch (ggml_get_unary_op(tensor)) {\n                case GGML_UNARY_OP_ABS: {\n                    if (src0_needs_grads) {\n                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad));\n                    }\n                } break;\n                case GGML_UNARY_OP_SGN: {\n                    // noop\n                } break;\n                case GGML_UNARY_OP_NEG: {\n                    if (src0_needs_grads) {\n                        ggml_sub_or_set(ctx, cgraph, isrc0, grad);\n                    }\n                } break;\n                case GGML_UNARY_OP_STEP: {\n    
                // noop\n                } break;\n                case GGML_UNARY_OP_RELU: {\n                    if (src0_needs_grads) {\n                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad));\n                    }\n                } break;\n                case GGML_UNARY_OP_SILU: {\n                    if (src0_needs_grads) {\n                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));\n                    }\n                } break;\n                case GGML_UNARY_OP_EXP: {\n                    if (src0_needs_grads) {\n                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));\n                    }\n                } break;\n                case GGML_UNARY_OP_EXPM1: {\n                    if (src0_needs_grads) {\n                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_exp(ctx, src0)));\n                    }\n                } break;\n                case GGML_UNARY_OP_SOFTPLUS: {\n                    if (src0_needs_grads) {\n                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sigmoid(ctx, src0)));\n                    }\n                } break;\n                default: {\n                    fprintf(stderr, \"%s: unsupported unary op for backward pass: %s\\n\",\n                        __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));\n                    GGML_ABORT(\"fatal error\");\n                } //break;\n            }\n        } break;\n        case GGML_OP_CROSS_ENTROPY_LOSS: {\n            if (src0_needs_grads) {\n                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));\n            }\n            GGML_ASSERT(!src1_needs_grads && \"backward pass for labels not implemented\");\n        } break;\n        case GGML_OP_GLU: {\n            switch (ggml_get_glu_op(tensor)) {\n                case GGML_GLU_OP_SWIGLU: {\n   
                 if (src0_needs_grads) {\n                        GGML_ASSERT(src1 && \"backward pass only implemented for split swiglu\");\n                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));\n                    }\n                    if (src1_needs_grads) {\n                        ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));\n                    }\n                } break;\n                default: {\n                    GGML_ABORT(\"unsupported glu op for backward pass: %s\", ggml_glu_op_name(ggml_get_glu_op(tensor)));\n                } //break;\n            }\n        } break;\n        case GGML_OP_NONE: {\n            // noop\n        } break;\n        case GGML_OP_COUNT:\n        default: {\n            GGML_ABORT(\"%s: unsupported ggml op for backward pass: %s\\n\", __func__, ggml_op_name(tensor->op));\n        } //break;\n    }\n\n    GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));\n    GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));\n    GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));\n}\n\nstatic size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) {\n    if (node->op != GGML_OP_NONE && compute) {\n        node->flags |= GGML_TENSOR_FLAG_COMPUTE;\n    }\n\n    const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);\n    GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);\n\n    if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {\n        // already visited\n\n        if (compute) {\n            // update the compute flag regardless\n            for (int i = 0; i < GGML_MAX_SRC; ++i) {\n                struct ggml_tensor * src = node->src[i];\n                if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) {\n                    
ggml_visit_parents_graph(cgraph, src, true);\n                }\n            }\n        }\n\n        return node_hash_pos;\n    }\n\n    // This is the first time we see this node in the current graph.\n    cgraph->visited_hash_set.keys[node_hash_pos] = node;\n    ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);\n    cgraph->use_counts[node_hash_pos] = 0;\n\n    for (int i = 0; i < GGML_MAX_SRC; ++i) {\n        const int k =\n            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :\n            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :\n            /* unknown order, just fall back to using i */ i;\n\n        struct ggml_tensor * src = node->src[k];\n        if (src) {\n            const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute);\n\n            // Update the use count for this operand.\n            cgraph->use_counts[src_hash_pos]++;\n        }\n    }\n\n    if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) {\n        // reached a leaf node, not part of the gradient graph (e.g. 
a constant)\n        GGML_ASSERT(cgraph->n_leafs < cgraph->size);\n\n        if (strlen(node->name) == 0) {\n            ggml_format_name(node, \"leaf_%d\", cgraph->n_leafs);\n        }\n\n        cgraph->leafs[cgraph->n_leafs] = node;\n        cgraph->n_leafs++;\n    } else {\n        GGML_ASSERT(cgraph->n_nodes < cgraph->size);\n\n        if (strlen(node->name) == 0) {\n            ggml_format_name(node, \"node_%d\", cgraph->n_nodes);\n        }\n\n        cgraph->nodes[cgraph->n_nodes] = node;\n        cgraph->n_nodes++;\n    }\n\n    return node_hash_pos;\n}\n\nstatic void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) {\n    if (!expand) {\n        // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand\n        ggml_graph_clear(cgraph);\n    }\n\n    const int n_old = cgraph->n_nodes;\n\n    ggml_visit_parents_graph(cgraph, tensor, compute);\n\n    const int n_new = cgraph->n_nodes - n_old;\n    GGML_PRINT_DEBUG(\"%s: visited %d new nodes\\n\", __func__, n_new);\n\n    if (n_new > 0) {\n        // the last added node should always be starting point\n        GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);\n    }\n}\n\nstruct ggml_tensor * ggml_build_forward_select(\n        struct ggml_cgraph  * cgraph,\n        struct ggml_tensor ** tensors,\n        int                   n_tensors,\n        int                   idx) {\n    GGML_ASSERT(idx >= 0 && idx < n_tensors);\n\n    for (int i = 0; i < n_tensors; i++) {\n        ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? 
true : false);\n    }\n\n    return tensors[idx];\n}\n\nvoid ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {\n    ggml_build_forward_impl(cgraph, tensor, true, true);\n}\n\nvoid ggml_build_backward_expand(\n        struct ggml_context *  ctx,\n        struct ggml_cgraph  *  cgraph,\n        struct ggml_tensor  ** grad_accs) {\n    GGML_ASSERT(cgraph->n_nodes > 0);\n    GGML_ASSERT(cgraph->grads);\n    GGML_ASSERT(cgraph->grad_accs);\n\n    const int n_nodes_f = cgraph->n_nodes;\n\n    memset(cgraph->grads,     0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));\n    memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));\n    bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));\n\n    {\n        bool any_params = false;\n        bool any_loss   = false;\n        for (int i = 0; i < n_nodes_f; ++i) {\n            struct ggml_tensor * node = cgraph->nodes[i];\n            any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM);\n            any_loss   = any_loss   || (node->flags & GGML_TENSOR_FLAG_LOSS);\n        }\n        GGML_ASSERT(any_params && \"no trainable parameters found, did you forget to call ggml_set_param?\");\n        GGML_ASSERT(any_loss && \"no training loss found, did you forget to call ggml_set_loss?\");\n    }\n\n    for (int i = 0; i < n_nodes_f; ++i) {\n        struct ggml_tensor * node = cgraph->nodes[i];\n\n        if (node->type == GGML_TYPE_I32) {\n            continue;\n        }\n\n        bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);\n        bool ignore_src[GGML_MAX_SRC] = {false};\n        switch (node->op) {\n            // gradients in node->src[0] for one reason or another have no effect on output gradients\n            case GGML_OP_IM2COL:      // only used for its shape\n            case GGML_OP_IM2COL_BACK: // same as IM2COL\n                ignore_src[0] = 
true;\n                break;\n            case GGML_OP_UNARY: {\n                const enum ggml_unary_op uop = ggml_get_unary_op(node);\n                // SGN and STEP unary ops are piecewise constant\n                if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) {\n                    ignore_src[0] = true;\n                }\n            } break;\n\n            // gradients in node->src[1] for one reason or another have no effect on output gradients\n            case GGML_OP_CPY:           // gradients in CPY target are irrelevant\n            case GGML_OP_GET_ROWS:      // row indices not differentiable\n            case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS\n            case GGML_OP_ROPE:          // positions not differentiable\n                ignore_src[1] = true;\n                break;\n\n            default:\n                break;\n        }\n        for (int j = 0; j < GGML_MAX_SRC; ++j) {\n            if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {\n                continue;\n            }\n            GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);\n            node_needs_grad = true;\n            break;\n        }\n        if (!node_needs_grad) {\n            continue;\n        }\n\n        // inplace operations are currently not supported\n        GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||\n            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);\n\n        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);\n        GGML_ASSERT(ihash != GGML_HASHSET_FULL);\n        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));\n        if (grad_accs && grad_accs[i]) {\n            cgraph->grad_accs[ihash] = grad_accs[i];\n            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];\n        } else if 
(node->flags & GGML_TENSOR_FLAG_LOSS) {\n            // loss tensors always need a gradient accumulator\n            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);\n            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];\n        }\n        grads_needed[ihash] = true;\n    }\n\n    for (int i = n_nodes_f - 1; i >= 0; --i) {\n        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation\n        // use allocator to automatically make inplace operations\n        ggml_compute_backward(ctx, cgraph, i, grads_needed);\n    }\n\n    free(grads_needed);\n}\n\nstatic void * incr_ptr_aligned(void ** p, size_t size, size_t align) {\n    void * ptr = *p;\n    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);\n    *p = (void *) ((char *) ptr + size);\n    return ptr;\n}\n\nstatic size_t ggml_graph_nbytes(size_t size, bool grads) {\n    size_t hash_size = ggml_hash_size(size * 2);\n    void * p = 0;\n    incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);\n    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes\n    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs\n    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts\n    incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys\n    if (grads) {\n        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads\n        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs\n    }\n    incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));\n\n    size_t nbytes = (size_t) p;\n    return nbytes;\n}\n\nsize_t ggml_graph_overhead_custom(size_t size, bool grads) {\n    return GGML_OBJECT_SIZE + 
GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN);\n}\n\nsize_t ggml_graph_overhead(void) {\n    return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false);\n}\n\nstruct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) {\n    const size_t obj_size = ggml_graph_nbytes(size, grads);\n    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size);\n    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);\n\n    // the size of the hash table is doubled since it needs to hold both nodes and leafs\n    size_t hash_size = ggml_hash_size(size * 2);\n\n    void * p = cgraph + 1;\n\n    struct ggml_tensor ** nodes_ptr      =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));\n    struct ggml_tensor ** leafs_ptr      =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));\n    int32_t             * use_counts_ptr =         incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));\n    struct ggml_tensor ** hash_keys_ptr  =         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));\n    struct ggml_tensor ** grads_ptr      = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;\n    struct ggml_tensor ** grad_accs_ptr  = grads ? 
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;\n\n    ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));\n\n    // check that we allocated the correct amount of memory\n    assert(obj_size == (size_t)((char *)p - (char *)cgraph));\n\n    *cgraph = (struct ggml_cgraph) {\n        /*.size         =*/ size,\n        /*.n_nodes      =*/ 0,\n        /*.n_leafs      =*/ 0,\n        /*.nodes        =*/ nodes_ptr,\n        /*.grads        =*/ grads_ptr,\n        /*.grad_accs    =*/ grad_accs_ptr,\n        /*.leafs        =*/ leafs_ptr,\n        /*.use_counts   =*/ use_counts_ptr,\n        /*.hash_table   =*/ { hash_size, hash_used, hash_keys_ptr },\n        /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,\n    };\n\n    ggml_hash_set_reset(&cgraph->visited_hash_set);\n    if (grads) {\n        memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));\n        memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));\n    }\n\n    return cgraph;\n}\n\nstruct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {\n    return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);\n}\n\nstruct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {\n    struct ggml_cgraph cgraph = {\n        /*.size             =*/ 0,\n        /*.n_nodes          =*/ i1 - i0,\n        /*.n_leafs          =*/ 0,\n        /*.nodes            =*/ cgraph0->nodes + i0,\n        /*.grads            =*/ NULL, // gradients would need visited_hash_set\n        /*.grad_accs        =*/ NULL,\n        /*.leafs            =*/ NULL,\n        /*.use_counts       =*/ cgraph0->use_counts,\n        /*.visited_hash_set =*/ cgraph0->visited_hash_set,\n        /*.order            =*/ cgraph0->order,\n    };\n\n    return cgraph;\n}\n\nvoid ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {\n    
GGML_ASSERT(dst->size >= src->n_leafs);\n    GGML_ASSERT(dst->size >= src->n_nodes);\n    GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size);\n\n    dst->n_leafs = src->n_leafs;\n    dst->n_nodes = src->n_nodes;\n    dst->order   = src->order;\n\n    for (int i = 0; i < src->n_leafs; ++i) {\n        dst->leafs[i] = src->leafs[i];\n    }\n\n    for (int i = 0; i < src->n_nodes; ++i) {\n        dst->nodes[i] = src->nodes[i];\n    }\n\n    for (size_t i = 0; i < src->visited_hash_set.size; ++i) {\n        // copy all hashset keys (tensors) that are in use\n        if (ggml_bitset_get(src->visited_hash_set.used, i)) {\n            size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);\n            dst->use_counts[new_hash_pos] = src->use_counts[i];\n        }\n    }\n\n    if (dst->grads) {\n        memset(dst->grads,     0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));\n        memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));\n    }\n    if (src->grads) {\n        GGML_ASSERT(dst->grads     != NULL);\n        GGML_ASSERT(dst->grad_accs != NULL);\n        for (int i = 0; i < src->n_nodes; ++i) {\n            const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);\n            const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);\n\n            GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);\n            GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));\n            GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);\n            GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));\n\n            dst->grads[igrad_dst]     = src->grads[igrad_src];\n            dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];\n        }\n    }\n}\n\nstruct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {\n    struct ggml_cgraph * result = 
ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);\n    ggml_graph_cpy(cgraph, result);\n    return result;\n}\n\nstruct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {\n    if (ggml_is_empty(tensor)) {\n        return tensor;\n    }\n    if (tensor->buffer) {\n        ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));\n    } else {\n        GGML_ASSERT(tensor->data);\n        memset(tensor->data, 0, ggml_nbytes(tensor));\n    }\n    return tensor;\n}\n\nvoid ggml_graph_reset(struct ggml_cgraph * cgraph) {\n    if (!cgraph) {\n        return;\n    }\n    GGML_ASSERT(cgraph->grads != NULL);\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        struct ggml_tensor * node     = cgraph->nodes[i];\n        struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);\n\n        if (node->op == GGML_OP_OPT_STEP_ADAMW) {\n            // clear momenta\n            ggml_set_zero(node->src[2]);\n            ggml_set_zero(node->src[3]);\n        }\n\n        // initial gradients of loss should be 1, 0 otherwise\n        if (grad_acc) {\n            if (node->flags & GGML_TENSOR_FLAG_LOSS) {\n                GGML_ASSERT(grad_acc->type == GGML_TYPE_F32);\n                GGML_ASSERT(ggml_is_scalar(grad_acc));\n\n                const float onef = 1.0f;\n                if (grad_acc->buffer) {\n                    ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));\n                } else {\n                    GGML_ASSERT(grad_acc->data);\n                    *((float *) grad_acc->data) = onef;\n                }\n            } else {\n                ggml_set_zero(grad_acc);\n            }\n        }\n    }\n}\n\nvoid ggml_graph_clear(struct ggml_cgraph * cgraph) {\n    cgraph->n_leafs = 0;\n    cgraph->n_nodes = 0;\n    ggml_hash_set_reset(&cgraph->visited_hash_set);\n}\n\nint ggml_graph_size(struct ggml_cgraph * cgraph) {\n    return cgraph->size;\n}\n\nstruct ggml_tensor * ggml_graph_node(struct ggml_cgraph * 
cgraph, int i) {\n    if (i < 0) {\n        GGML_ASSERT(cgraph->n_nodes + i >= 0);\n        return cgraph->nodes[cgraph->n_nodes + i];\n    }\n\n    GGML_ASSERT(i < cgraph->n_nodes);\n    return cgraph->nodes[i];\n}\n\nstruct ggml_tensor ** ggml_graph_nodes(struct ggml_cgraph * cgraph) {\n    return cgraph->nodes;\n}\n\nint ggml_graph_n_nodes(struct ggml_cgraph * cgraph) {\n    return cgraph->n_nodes;\n}\n\nvoid ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {\n    GGML_ASSERT(cgraph->size > cgraph->n_nodes);\n    cgraph->nodes[cgraph->n_nodes] = tensor;\n    cgraph->n_nodes++;\n}\n\nstruct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) {\n    for (int i = 0; i < cgraph->n_leafs; i++) {\n        struct ggml_tensor * leaf = cgraph->leafs[i];\n\n        if (strcmp(leaf->name, name) == 0) {\n            return leaf;\n        }\n    }\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        struct ggml_tensor * node = cgraph->nodes[i];\n\n        if (strcmp(node->name, name) == 0) {\n            return node;\n        }\n    }\n\n    return NULL;\n}\n\nstruct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {\n    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);\n    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL;\n}\n\nstruct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {\n    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);\n    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? 
cgraph->grad_accs[igrad] : NULL;\n}\n\nvoid ggml_graph_print(const struct ggml_cgraph * cgraph) {\n    GGML_LOG_INFO(\"=== GRAPH ===\\n\");\n\n    GGML_LOG_INFO(\"n_nodes = %d\\n\", cgraph->n_nodes);\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        struct ggml_tensor * node = cgraph->nodes[i];\n\n        GGML_LOG_INFO(\" - %3d: [ %5\" PRId64 \", %5\" PRId64 \", %5\" PRId64 \"] %16s %s\\n\",\n                i,\n                node->ne[0], node->ne[1], node->ne[2],\n                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? \"x\" :\n                      ggml_graph_get_grad(cgraph, node) ? \"g\" : \" \");\n    }\n\n    GGML_LOG_INFO(\"n_leafs = %d\\n\", cgraph->n_leafs);\n    for (int i = 0; i < cgraph->n_leafs; i++) {\n        struct ggml_tensor * node = cgraph->leafs[i];\n\n        GGML_LOG_INFO(\" - %3d: [ %5\" PRId64 \", %5\" PRId64 \"] %8s %16s\\n\",\n                i,\n                node->ne[0], node->ne[1],\n                ggml_op_name(node->op),\n                ggml_get_name(node));\n    }\n\n    GGML_LOG_INFO(\"========================================\\n\");\n}\n\nstatic int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,\n                                      const int *                idxs,\n                                      int                        count,\n                                      const struct ggml_tensor * tensor) {\n    GGML_ASSERT(cgraph && idxs);\n    for (int i = 0; i < count; ++i) {\n        const int node_idx = idxs[i];\n\n        if (node_idx >= cgraph->n_nodes) {\n            return -1;\n        }\n        if (cgraph->nodes[node_idx] == tensor) {\n            return i;\n        }\n    }\n    return -1;\n}\n\nbool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,\n                                const int *                node_idxs,\n                                int                        count,\n                                const enum ggml_op *       ops,\n   
                             const int *                outputs,\n                                int                        num_outputs) {\n    GGML_ASSERT(outputs && num_outputs > 0);\n\n    for (int i = 0; i < count; ++i) {\n        if (node_idxs[i] >= cgraph->n_nodes) {\n            return false;\n        }\n\n        const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];\n\n        if (node->op != ops[i]) {\n            return false;\n        }\n\n        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {\n            return false;\n        }\n\n        if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {\n            continue;\n        }\n\n        if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {\n            return false;\n        }\n\n        int subgraph_uses = 0;\n        for (int j = i + 1; j < count; ++j) {\n            const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];\n            for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {\n                if (other_node->src[src_idx] == node) {\n                    subgraph_uses++;\n                }\n            }\n        }\n\n        if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {\n            return false;\n        }\n\n        // if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph\n        struct ggml_tensor * view_src = node->view_src;\n        while (view_src) {\n            if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {\n                return false;\n            }\n            view_src = view_src->view_src;\n        }\n    }\n\n    return true;\n}\n\n// check if node is part of the graph\nstatic bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {\n    if (cgraph == NULL) {\n        return true;\n    }\n\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        if (cgraph->nodes[i] == node) {\n            return 
true;\n        }\n    }\n\n    return false;\n}\n\nstatic struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {\n    for (int i = 0; i < cgraph->n_nodes; i++) {\n        struct ggml_tensor * parent = cgraph->nodes[i];\n        struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent);\n\n        if (grad == node) {\n            return parent;\n        }\n    }\n\n    return NULL;\n}\n\nstatic void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label)  {\n    struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);\n    struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);\n    fprintf(fp, \"  \\\"%p\\\" -> \\\"%p\\\" [ arrowhead = %s; style = %s; label = \\\"%s\\\"; ]\\n\",\n            gparent0 ? (void *) gparent0 : (void *) parent,\n            gparent ? (void *) gparent : (void *) node,\n            gparent ? \"empty\" : \"vee\",\n            gparent ? 
\"dashed\" : \"solid\",\n            label);\n}\n\nstatic void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label)  {\n    fprintf(fp, \"  \\\"%p\\\" -> \\\"%p\\\" [ label = \\\"%s\\\"; ]\\n\",\n            (void *) parent,\n            (void *) node,\n            label);\n}\n\nvoid ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) {\n    char color[16];\n\n    FILE * fp = ggml_fopen(filename, \"w\");\n    GGML_ASSERT(fp);\n\n    fprintf(fp, \"digraph G {\\n\");\n    fprintf(fp, \"  newrank = true;\\n\");\n    fprintf(fp, \"  rankdir = TB;\\n\");\n\n    for (int i = 0; i < gb->n_nodes; i++) {\n        struct ggml_tensor * node = gb->nodes[i];\n        struct ggml_tensor * grad = ggml_graph_get_grad(gb, node);\n\n        if (ggml_graph_get_parent(gb, node) != NULL) {\n            continue;\n        }\n\n        if (node->flags & GGML_TENSOR_FLAG_PARAM) {\n            snprintf(color, sizeof(color), \"yellow\");\n        } else if (grad) {\n            if (ggml_graph_find(cgraph, node)) {\n                snprintf(color, sizeof(color), \"green\");\n            } else {\n                snprintf(color, sizeof(color), \"lightblue\");\n            }\n        } else {\n            snprintf(color, sizeof(color), \"white\");\n        }\n\n        fprintf(fp, \"  \\\"%p\\\" [ \"\n                    \"style = filled; fillcolor = %s; shape = record; \"\n                    \"label=\\\"\",\n                (void *) node, color);\n\n        if (strlen(node->name) > 0) {\n            fprintf(fp, \"%s (%s)|\", node->name, ggml_type_name(node->type));\n        } else {\n            fprintf(fp, \"(%s)|\", ggml_type_name(node->type));\n        }\n\n        if (ggml_is_matrix(node)) {\n            fprintf(fp, \"%d [%\" PRId64 \", %\" PRId64 \"] | <x>%s\", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op));\n        } else {\n            fprintf(fp, \"%d [%\" 
PRId64 \", %\" PRId64 \", %\" PRId64 \"] | <x>%s\", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));\n        }\n\n        if (grad) {\n            fprintf(fp, \" | <g>%s\\\"; ]\\n\", ggml_op_symbol(grad->op));\n        } else {\n            fprintf(fp, \"\\\"; ]\\n\");\n        }\n    }\n\n    for (int i = 0; i < gb->n_leafs; i++) {\n        struct ggml_tensor * node = gb->leafs[i];\n\n        snprintf(color, sizeof(color), \"pink\");\n\n        fprintf(fp, \"  \\\"%p\\\" [ \"\n                    \"style = filled; fillcolor = %s; shape = record; \"\n                    \"label=\\\"<x>\",\n                (void *) node, color);\n\n        if (strlen(node->name) > 0) {\n            fprintf(fp, \"%s (%s)|\", node->name, ggml_type_name(node->type));\n        } else {\n            fprintf(fp, \"(%s)|\", ggml_type_name(node->type));\n        }\n\n        fprintf(fp, \"CONST %d [%\" PRId64 \", %\" PRId64 \"]\", i, node->ne[0], node->ne[1]);\n        if (ggml_nelements(node) < 5 && node->data != NULL) {\n            fprintf(fp, \" | (\");\n            for (int j = 0; j < ggml_nelements(node); j++) {\n                // FIXME: use ggml-backend to obtain the tensor data\n                //if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {\n                //    fprintf(fp, \"%d\", ggml_get_i32_1d(node, j));\n                //}\n                //else if (node->type == GGML_TYPE_F32 ||\n                //         node->type == GGML_TYPE_F16 ||\n                //         node->type == GGML_TYPE_BF16) {\n                //    fprintf(fp, \"%.1e\", (double)ggml_get_f32_1d(node, j));\n                //}\n                //else\n                {\n                    fprintf(fp, \"#\");\n                }\n                if (j < ggml_nelements(node) - 1) {\n                    fprintf(fp, \", \");\n                }\n            }\n            fprintf(fp, \")\");\n        }\n        fprintf(fp, \"\\\"; 
]\\n\");\n    }\n\n    for (int i = 0; i < gb->n_nodes; i++) {\n        struct ggml_tensor * node = gb->nodes[i];\n\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            if (node->src[j]) {\n                char label[16];\n                snprintf(label, sizeof(label), \"src %d\", j);\n                ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);\n            }\n        }\n    }\n\n    for (int i = 0; i < gb->n_leafs; i++) {\n        struct ggml_tensor * node = gb->leafs[i];\n\n        for (int j = 0; j < GGML_MAX_SRC; j++) {\n            if (node->src[j]) {\n                char label[16];\n                snprintf(label, sizeof(label), \"src %d\", j);\n                ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);\n            }\n        }\n    }\n\n    fprintf(fp, \"}\\n\");\n\n    fclose(fp);\n\n    GGML_LOG_INFO(\"%s: dot -Tpng %s -o %s.png && open %s.png\\n\", __func__, filename, filename, filename);\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\nvoid ggml_set_input(struct ggml_tensor * tensor) {\n    tensor->flags |= GGML_TENSOR_FLAG_INPUT;\n}\n\nvoid ggml_set_output(struct ggml_tensor * tensor) {\n    tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;\n}\n\nvoid ggml_set_param(struct ggml_tensor * tensor) {\n    GGML_ASSERT(tensor->op == GGML_OP_NONE);\n    tensor->flags |= GGML_TENSOR_FLAG_PARAM;\n}\n\nvoid ggml_set_loss(struct ggml_tensor * tensor) {\n    GGML_ASSERT(ggml_is_scalar(tensor));\n    GGML_ASSERT(tensor->type == GGML_TYPE_F32);\n    tensor->flags |= GGML_TENSOR_FLAG_LOSS;\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\nvoid ggml_quantize_init(enum ggml_type type) {\n    ggml_critical_section_start();\n\n    switch (type) {\n        case GGML_TYPE_IQ2_XXS:\n        case GGML_TYPE_IQ2_XS:\n        case GGML_TYPE_IQ2_S:\n        case GGML_TYPE_IQ1_S:\n        case GGML_TYPE_IQ1_M:   iq2xs_init_impl(type); break;\n        case 
GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;\n        case GGML_TYPE_IQ3_S:   iq3xs_init_impl(512); break;\n        default: // nothing\n            break;\n    }\n\n    ggml_critical_section_end();\n}\n\nvoid ggml_quantize_free(void) {\n    ggml_critical_section_start();\n\n    iq2xs_free_impl(GGML_TYPE_IQ2_XXS);\n    iq2xs_free_impl(GGML_TYPE_IQ2_XS);\n    iq2xs_free_impl(GGML_TYPE_IQ2_S);\n    iq2xs_free_impl(GGML_TYPE_IQ1_S);\n    iq2xs_free_impl(GGML_TYPE_IQ1_M);\n    iq3xs_free_impl(256);\n    iq3xs_free_impl(512);\n\n    ggml_critical_section_end();\n}\n\nbool ggml_quantize_requires_imatrix(enum ggml_type type) {\n    return\n        type == GGML_TYPE_IQ2_XXS ||\n        type == GGML_TYPE_IQ2_XS  ||\n        type == GGML_TYPE_IQ1_S;//   ||\n        //type == GGML_TYPE_IQ1_M;\n}\n\nsize_t ggml_quantize_chunk(\n        enum ggml_type   type,\n           const float * src,\n                  void * dst,\n               int64_t   start,\n               int64_t   nrows,\n               int64_t   n_per_row,\n           const float * imatrix) {\n    const int64_t n = (int64_t) nrows * n_per_row;\n\n    if (ggml_quantize_requires_imatrix(type)) {\n        GGML_ASSERT(imatrix != NULL);\n    }\n\n    GGML_ASSERT(start % type_traits[type].blck_size == 0);\n    GGML_ASSERT(start % n_per_row == 0);\n\n    ggml_quantize_init(type); // this is noop if already initialized\n\n    const size_t start_row = start / n_per_row;\n    const size_t row_size  = ggml_row_size(type, n_per_row);\n\n    size_t result = 0;\n\n    switch (type) {\n        case GGML_TYPE_Q4_0:    result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_Q4_1:    result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_Q5_0:    result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case 
GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_MXFP4:   result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_NVFP4:   result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_Q4_K:    result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_Q5_K:    result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_Q6_K:    result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_TQ1_0:   result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_TQ2_0:   result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_IQ2_XS:  result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case 
GGML_TYPE_IQ3_S:   result = quantize_iq3_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_IQ2_S:   result = quantize_iq2_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_IQ1_S:   result = quantize_iq1_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;\n        case GGML_TYPE_F16:\n            {\n                size_t elemsize = sizeof(ggml_fp16_t);\n                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);\n                result = n * elemsize;\n            } break;\n        case GGML_TYPE_BF16:\n            {\n                size_t elemsize = sizeof(ggml_bf16_t);\n                ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n);\n                result = n * elemsize;\n            } break;\n        case GGML_TYPE_F32:\n            {\n                size_t elemsize = sizeof(float);\n                result = n * elemsize;\n                memcpy((uint8_t *)dst + start * elemsize, src + start, result);\n            } break;\n        default:\n            assert(false);\n    }\n\n    GGML_ASSERT(result == nrows * row_size);\n\n    return result;\n}\n\n////////////////////////////////////////////////////////////////////////////////\n\nvoid ggml_log_get(ggml_log_callback * log_callback, void ** user_data) {\n    *log_callback = g_logger_state.log_callback;\n    *user_data    = g_logger_state.log_callback_user_data;\n}\n\nvoid 
ggml_log_set(ggml_log_callback log_callback, void * user_data) {\n    g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;\n    g_logger_state.log_callback_user_data = user_data;\n}\n\nvoid ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {\n    p->n_threads  = n_threads;\n    p->prio       = 0;     // default priority (usually means normal or inherited)\n    p->poll       = 50;    // hybrid-polling enabled\n    p->strict_cpu = false; // no strict placement (all threads share same cpumask)\n    p->paused     = false; // threads are ready to go\n    memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)\n}\n\nstruct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {\n    struct ggml_threadpool_params p;\n    ggml_threadpool_params_init(&p, n_threads);\n    return p;\n}\n\nbool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {\n    if (p0->n_threads      != p1->n_threads  )    return false;\n    if (p0->prio           != p1->prio       )    return false;\n    if (p0->poll           != p1->poll       )    return false;\n    if (p0->strict_cpu     != p1->strict_cpu )    return false;\n    return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;\n}\n"
  },
  {
    "path": "src/ggml.cpp",
    "content": "#include \"ggml-impl.h\"\n\n#include <cstdlib>\n#include <exception>\n\nstatic std::terminate_handler previous_terminate_handler;\n\nGGML_NORETURN static void ggml_uncaught_exception() {\n    ggml_print_backtrace();\n    if (previous_terminate_handler) {\n        previous_terminate_handler();\n    }\n    abort(); // unreachable unless previous_terminate_handler was nullptr\n}\n\nstatic bool ggml_uncaught_exception_init = []{\n    const char * GGML_NO_BACKTRACE = getenv(\"GGML_NO_BACKTRACE\");\n    if (GGML_NO_BACKTRACE) {\n        return false;\n    }\n    const auto prev{std::get_terminate()};\n    GGML_ASSERT(prev != ggml_uncaught_exception);\n    previous_terminate_handler = prev;\n    std::set_terminate(ggml_uncaught_exception);\n    return true;\n}();\n"
  },
  {
    "path": "src/gguf.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-impl.h\"\n#include \"gguf.h\"\n\n#include <cinttypes>\n#include <cstddef>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <map>\n#include <new>\n#include <stdexcept>\n#include <string>\n#include <vector>\n\n#define GGUF_MAX_STRING_LENGTH  (1024*1024*1024)\n#define GGUF_MAX_ARRAY_ELEMENTS (1024*1024*1024)\n\n#ifdef _WIN32\n#    define gguf_ftell _ftelli64\n#    define gguf_fseek _fseeki64\n#else\n#    define gguf_ftell ftello\n#    define gguf_fseek fseeko\n#endif\n\ntemplate <typename T>\nstruct type_to_gguf_type;\n\ntemplate <>\nstruct type_to_gguf_type<uint8_t> {\n    static constexpr enum gguf_type value = GGUF_TYPE_UINT8;\n};\n\ntemplate <>\nstruct type_to_gguf_type<int8_t> {\n    static constexpr enum gguf_type value = GGUF_TYPE_INT8;\n};\n\ntemplate <>\nstruct type_to_gguf_type<uint16_t> {\n    static constexpr enum gguf_type value = GGUF_TYPE_UINT16;\n};\n\ntemplate <>\nstruct type_to_gguf_type<int16_t> {\n    static constexpr enum gguf_type value = GGUF_TYPE_INT16;\n};\n\ntemplate <>\nstruct type_to_gguf_type<uint32_t> {\n    static constexpr enum gguf_type value = GGUF_TYPE_UINT32;\n};\n\ntemplate <>\nstruct type_to_gguf_type<int32_t> {\n    static constexpr enum gguf_type value = GGUF_TYPE_INT32;\n};\n\ntemplate <>\nstruct type_to_gguf_type<float> {\n    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT32;\n};\n\ntemplate <>\nstruct type_to_gguf_type<bool> {\n    static constexpr enum gguf_type value = GGUF_TYPE_BOOL;\n};\n\ntemplate <>\nstruct type_to_gguf_type<std::string> {\n    static constexpr enum gguf_type value = GGUF_TYPE_STRING;\n};\n\ntemplate <>\nstruct type_to_gguf_type<uint64_t> {\n    static constexpr enum gguf_type value = GGUF_TYPE_UINT64;\n};\n\ntemplate <>\nstruct type_to_gguf_type<int64_t> {\n    static constexpr enum gguf_type value = GGUF_TYPE_INT64;\n};\n\ntemplate <>\nstruct type_to_gguf_type<double> {\n  
  static constexpr enum gguf_type value = GGUF_TYPE_FLOAT64;\n};\n\nstatic const std::map<gguf_type, size_t> GGUF_TYPE_SIZE = {\n    {GGUF_TYPE_UINT8,   sizeof(uint8_t)},\n    {GGUF_TYPE_INT8,    sizeof(int8_t)},\n    {GGUF_TYPE_UINT16,  sizeof(uint16_t)},\n    {GGUF_TYPE_INT16,   sizeof(int16_t)},\n    {GGUF_TYPE_UINT32,  sizeof(uint32_t)},\n    {GGUF_TYPE_INT32,   sizeof(int32_t)},\n    {GGUF_TYPE_FLOAT32, sizeof(float)},\n    {GGUF_TYPE_BOOL,    sizeof(int8_t)},\n    {GGUF_TYPE_STRING,  0}, // undefined\n    {GGUF_TYPE_ARRAY,   0}, // undefined\n    {GGUF_TYPE_UINT64,  sizeof(uint64_t)},\n    {GGUF_TYPE_INT64,   sizeof(int64_t)},\n    {GGUF_TYPE_FLOAT64, sizeof(double)},\n};\nstatic_assert(GGUF_TYPE_COUNT == 13, \"GGUF_TYPE_COUNT != 13\");\n\nstatic const std::map<gguf_type, const char *> GGUF_TYPE_NAME = {\n    {GGUF_TYPE_UINT8,   \"u8\"},\n    {GGUF_TYPE_INT8,    \"i8\"},\n    {GGUF_TYPE_UINT16,  \"u16\"},\n    {GGUF_TYPE_INT16,   \"i16\"},\n    {GGUF_TYPE_UINT32,  \"u32\"},\n    {GGUF_TYPE_INT32,   \"i32\"},\n    {GGUF_TYPE_FLOAT32, \"f32\"},\n    {GGUF_TYPE_BOOL,    \"bool\"},\n    {GGUF_TYPE_STRING,  \"str\"},\n    {GGUF_TYPE_ARRAY,   \"arr\"},\n    {GGUF_TYPE_UINT64,  \"u64\"},\n    {GGUF_TYPE_INT64,   \"i64\"},\n    {GGUF_TYPE_FLOAT64, \"f64\"},\n};\nstatic_assert(GGUF_TYPE_COUNT == 13, \"GGUF_TYPE_COUNT != 13\");\n\nsize_t gguf_type_size(enum gguf_type type) {\n    auto it = GGUF_TYPE_SIZE.find(type);\n    return it == GGUF_TYPE_SIZE.end() ? 
0 : it->second;\n}\n\nstruct gguf_kv {\n    std::string key;\n\n    bool is_array;\n    enum gguf_type type;\n\n    std::vector<int8_t>      data;\n    std::vector<std::string> data_string;\n\n    template <typename T>\n    gguf_kv(const std::string & key, const T value)\n            : key(key), is_array(false), type(type_to_gguf_type<T>::value) {\n        GGML_ASSERT(!key.empty());\n        data.resize(sizeof(T));\n        memcpy(data.data(), &value, sizeof(T));\n    }\n\n    template <typename T>\n    gguf_kv(const std::string & key, const std::vector<T> & value)\n            : key(key), is_array(true), type(type_to_gguf_type<T>::value) {\n        GGML_ASSERT(!key.empty());\n        data.resize(value.size()*sizeof(T));\n        for (size_t i = 0; i < value.size(); ++i) {\n            const T tmp = value[i];\n            memcpy(data.data() + i*sizeof(T), &tmp, sizeof(T));\n        }\n    }\n\n    gguf_kv(const std::string & key, const std::string & value)\n            : key(key), is_array(false), type(GGUF_TYPE_STRING) {\n        GGML_ASSERT(!key.empty());\n        data_string.push_back(value);\n    }\n\n    gguf_kv(const std::string & key, const std::vector<std::string> & value)\n            : key(key), is_array(true), type(GGUF_TYPE_STRING) {\n        GGML_ASSERT(!key.empty());\n        data_string = value;\n    }\n\n    const std::string & get_key() const {\n        return key;\n    }\n\n    const enum gguf_type & get_type() const {\n        return type;\n    }\n\n    size_t get_ne() const {\n        if (type == GGUF_TYPE_STRING) {\n            const size_t ne = data_string.size();\n            GGML_ASSERT(is_array || ne == 1);\n            return ne;\n        }\n        const size_t type_size = gguf_type_size(type);\n        GGML_ASSERT(data.size() % type_size == 0);\n        const size_t ne = data.size() / type_size;\n        GGML_ASSERT(is_array || ne == 1);\n        return ne;\n    }\n\n    template <typename T>\n    const T & get_val(const size_t i = 0) 
const {\n        GGML_ASSERT(type_to_gguf_type<T>::value == type);\n        if constexpr (std::is_same<T, std::string>::value) {\n            GGML_ASSERT(data_string.size() >= i+1);\n            return data_string[i];\n        }\n        const size_t type_size = gguf_type_size(type);\n        GGML_ASSERT(data.size() % type_size == 0);\n        GGML_ASSERT(data.size() >= (i+1)*type_size);\n        return reinterpret_cast<const T *>(data.data())[i];\n    }\n\n    void cast(const enum gguf_type new_type) {\n        const size_t new_type_size = gguf_type_size(new_type);\n        GGML_ASSERT(data.size() % new_type_size == 0);\n        type = new_type;\n    }\n};\n\nstruct gguf_tensor_info {\n    struct ggml_tensor t; // for holding the equivalent info\n    uint64_t offset;      // offset from start of `data`, must be a multiple of `ALIGNMENT`\n};\n\nstruct gguf_context {\n    uint32_t version = GGUF_VERSION;\n\n    std::vector<struct gguf_kv> kv;\n    std::vector<struct gguf_tensor_info> info;\n\n    size_t alignment = GGUF_DEFAULT_ALIGNMENT;\n    size_t offset    = 0; // offset of `data` from beginning of file\n    size_t size      = 0; // size of `data` in bytes\n\n    void * data = nullptr;\n};\n\nstruct gguf_reader {\n    gguf_reader(FILE * file) : file(file) {\n        // read the remaining bytes once and update on each read\n        nbytes_remain = file_remain(file);\n    }\n\n    // helper for remaining bytes in a file\n    static uint64_t file_remain(FILE * file) {\n        const int64_t cur = gguf_ftell(file);\n        if (cur < 0) {\n            return 0;\n        }\n        if (gguf_fseek(file, 0, SEEK_END) != 0) {\n            gguf_fseek(file, cur, SEEK_SET);\n\n            return 0;\n        }\n        const int64_t end = gguf_ftell(file);\n        if (end < 0) {\n            gguf_fseek(file, cur, SEEK_SET);\n\n            return 0;\n        }\n        gguf_fseek(file, cur, SEEK_SET);\n        return static_cast<uint64_t>(end - cur);\n    }\n\n    template 
<typename T>\n    bool read(T & dst) const {\n        const size_t size = sizeof(dst);\n        if (nbytes_remain < size) {\n            return false;\n        }\n        const size_t nread = fread(&dst, 1, size, file);\n        nbytes_remain -= nread;\n        return nread == size;\n    }\n\n    template <typename T>\n    bool read(std::vector<T> & dst, const size_t n) const {\n        if (n > GGUF_MAX_ARRAY_ELEMENTS) {\n            return false;\n        }\n        if constexpr (std::is_same<T, std::string>::value) {\n            // strings are prefixed with their length, so we need to account for that\n            if (n > SIZE_MAX / sizeof(uint64_t)) {\n                return false;\n            }\n            if (nbytes_remain < n * sizeof(uint64_t)) {\n                return false;\n            }\n        } else {\n            if (n > SIZE_MAX / sizeof(T)) {\n                return false;\n            }\n            if (nbytes_remain < n * sizeof(T)) {\n                return false;\n            }\n        }\n        dst.resize(n);\n        for (size_t i = 0; i < dst.size(); ++i) {\n            if constexpr (std::is_same<T, bool>::value) {\n                bool tmp;\n                if (!read(tmp)) {\n                    return false;\n                }\n                dst[i] = tmp;\n            } else {\n                if (!read(dst[i])) {\n                    return false;\n                }\n            }\n        }\n        return true;\n    }\n\n    bool read(bool & dst) const {\n        int8_t tmp = -1;\n        if (!read(tmp)) {\n            return false;\n        }\n        dst = tmp != 0;\n        return true;\n    }\n\n    bool read(enum ggml_type & dst) const {\n        int32_t tmp = -1;\n        if (!read(tmp)) {\n            return false;\n        }\n        dst = ggml_type(tmp);\n        return true;\n    }\n\n    bool read(enum gguf_type & dst) const {\n        int32_t tmp = -1;\n        if (!read(tmp)) {\n            return false;\n        
}\n        dst = gguf_type(tmp);\n        return true;\n    }\n\n    bool read(std::string & dst) const {\n        uint64_t size = 0;\n        if (!read(size)) {\n            return false;\n        }\n        if (size > GGUF_MAX_STRING_LENGTH) {\n            GGML_LOG_ERROR(\"%s: string length %\" PRIu64 \" exceeds maximum %\" PRIu64 \"\\n\", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH);\n            return false;\n        }\n        if (size > nbytes_remain) {\n            GGML_LOG_ERROR(\"%s: string length %\" PRIu64 \" exceeds remaining file size %\" PRIu64 \" bytes\\n\", __func__, size, nbytes_remain);\n            return false;\n        }\n        dst.resize(static_cast<size_t>(size));\n        const size_t nread = fread(dst.data(), 1, size, file);\n        nbytes_remain -= nread;\n        return nread == size;\n    }\n\n    bool read(void * dst, const size_t size) const {\n        if (size > nbytes_remain) {\n            return false;\n        }\n        const size_t nread = fread(dst, 1, size, file);\n        nbytes_remain -= nread;\n        return nread == size;\n    }\n\nprivate:\n    FILE * file;\n\n    mutable uint64_t nbytes_remain;\n};\n\nstruct gguf_context * gguf_init_empty(void) {\n    return new gguf_context;\n}\n\ntemplate<typename T>\nbool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct gguf_kv> & kv, const std::string & key, const bool is_array, const size_t n) {\n    if (is_array) {\n        std::vector<T> value;\n        try {\n            if (!gr.read(value, n)) {\n                return false;\n            }\n        } catch (std::length_error &) {\n            GGML_LOG_ERROR(\"%s: encountered length_error while reading value for key '%s'\\n\", __func__, key.c_str());\n            return false;\n        } catch (std::bad_alloc &) {\n            GGML_LOG_ERROR(\"%s: encountered bad_alloc error while reading value for key '%s'\\n\", __func__, key.c_str());\n            return false;\n        }\n        
kv.emplace_back(key, value);\n    } else {\n        T value;\n        if (!gr.read(value)) {\n            return false;\n        }\n        kv.emplace_back(key, value);\n    }\n    return true;\n}\n\nstruct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {\n    const struct gguf_reader gr(file);\n    struct gguf_context * ctx = new gguf_context;\n\n    bool ok = true;\n\n    // file magic\n    {\n        std::vector<char> magic;\n        ok = ok && gr.read(magic, 4);\n\n        if (!ok) {\n            GGML_LOG_ERROR(\"%s: failed to read magic\\n\", __func__);\n            gguf_free(ctx);\n            return nullptr;\n        }\n\n        for (uint32_t i = 0; i < magic.size(); i++) {\n            if (magic[i] != GGUF_MAGIC[i]) {\n                char c0 = isprint(magic[0]) ? magic[0] : '?';\n                char c1 = isprint(magic[1]) ? magic[1] : '?';\n                char c2 = isprint(magic[2]) ? magic[2] : '?';\n                char c3 = isprint(magic[3]) ? magic[3] : '?';\n                GGML_LOG_ERROR(\"%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\\n\", __func__, c0, c1, c2, c3);\n                gguf_free(ctx);\n                return nullptr;\n            }\n        }\n    }\n\n    // header\n    int64_t n_kv      = 0;\n    int64_t n_tensors = 0;\n\n    if (ok && gr.read(ctx->version)) {\n        if (ok && ctx->version == 0) {\n            GGML_LOG_ERROR(\"%s: bad GGUF version: %\" PRIu32 \"\\n\", __func__, ctx->version);\n            ok = false;\n        }\n\n        /*\n         * bit layout is different when reading non-native endian models.\n         * assuming that the GGUF version is 3, the non-native endian model\n         * would read it as 0x30000000. 
we can use the AND operation against\n         * the last 4 hexadecimal digits to check if the model is the same\n         * endianness as the host system.\n        */\n        if (ok && (ctx->version & 0x0000FFFF) == 0x00000000) {\n            GGML_LOG_ERROR(\"%s: failed to load model: this GGUF file version %\" PRIu32 \" is extremely large, is there a mismatch between the host and model endianness?\\n\", __func__, ctx->version);\n            ok = false;\n        }\n\n        if (ok && ctx->version == 1) {\n            GGML_LOG_ERROR(\"%s: GGUFv1 is no longer supported, please use a more up-to-date version\\n\", __func__);\n            ok = false;\n        }\n        if (ok && ctx->version > GGUF_VERSION) {\n            GGML_LOG_ERROR(\"%s: this GGUF file is version %\" PRIu32 \" but this software only supports up to version %d\\n\",\n                __func__, ctx->version, GGUF_VERSION);\n            ok = false;\n        }\n    } else {\n        ok = false;\n    }\n\n    if (ok && gr.read(n_tensors)) {\n        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, \"int64_t insufficient for indexing\");\n        if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {\n            GGML_LOG_ERROR(\"%s: number of tensors is %\" PRIi64 \" but must be in [0, %zu]\\n\",\n                __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));\n            ok = false;\n        }\n    } else {\n        ok = false;\n    }\n\n    if (ok && gr.read(n_kv)) {\n        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, \"int64_t insufficient for indexing\");\n        if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {\n            GGML_LOG_ERROR(\"%s: number of key value pairs is %\" PRIi64 \" but must be in [0, %zu]\\n\",\n                    __func__, n_kv, SIZE_MAX/sizeof(gguf_kv));\n            ok = false;\n        }\n    } else {\n        ok = false;\n    }\n\n    if (!ok) {\n        GGML_LOG_ERROR(\"%s: failed 
to read header\\n\", __func__);\n        gguf_free(ctx);\n        return nullptr;\n    }\n\n    // KV pairs\n    {\n        for (int64_t i = 0; ok && i < n_kv; ++i) {\n            std::string key;\n            gguf_type   type     = gguf_type(-1);\n            bool        is_array = false;\n            uint64_t    n        = 1;\n\n            try {\n                ok = ok && gr.read(key);\n            } catch (std::length_error &) {\n                GGML_LOG_ERROR(\"%s: encountered length_error while reading key %\" PRIi64 \"\\n\", __func__, i);\n                ok = false;\n            } catch (std::bad_alloc &) {\n                GGML_LOG_ERROR(\"%s: encountered bad_alloc error while reading key %\" PRIi64 \"\\n\", __func__, i);\n                ok = false;\n            }\n            for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {\n                if (key == ctx->kv[j].key) {\n                    GGML_LOG_ERROR(\"%s: duplicate key '%s' for tensors %zu and %\" PRIi64 \" \\n\", __func__, key.c_str(), j, i);\n                    ok = false;\n                }\n            }\n            if (!ok) {\n                break;\n            }\n\n            ok = ok && gr.read(type);\n            if (type == GGUF_TYPE_ARRAY) {\n                is_array = true;\n                ok = ok && gr.read(type);\n                ok = ok && gr.read(n);\n            }\n            if (!ok) {\n                break;\n            }\n\n            switch (type) {\n                case GGUF_TYPE_UINT8:   ok = ok && gguf_read_emplace_helper<uint8_t>    (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_INT8:    ok = ok && gguf_read_emplace_helper<int8_t>     (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_UINT16:  ok = ok && gguf_read_emplace_helper<uint16_t>   (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_INT16:   ok = ok && gguf_read_emplace_helper<int16_t>    (gr, ctx->kv, key, is_array, n); break;\n             
   case GGUF_TYPE_UINT32:  ok = ok && gguf_read_emplace_helper<uint32_t>   (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_INT32:   ok = ok && gguf_read_emplace_helper<int32_t>    (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_FLOAT32: ok = ok && gguf_read_emplace_helper<float>      (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_BOOL:    ok = ok && gguf_read_emplace_helper<bool>       (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_STRING:  ok = ok && gguf_read_emplace_helper<std::string>(gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_UINT64:  ok = ok && gguf_read_emplace_helper<uint64_t>   (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_INT64:   ok = ok && gguf_read_emplace_helper<int64_t>    (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_FLOAT64: ok = ok && gguf_read_emplace_helper<double>     (gr, ctx->kv, key, is_array, n); break;\n                case GGUF_TYPE_ARRAY:\n                default:\n                    {\n                        GGML_LOG_ERROR(\"%s: key '%s' has invalid GGUF type %d\\n\", __func__, key.c_str(), type);\n                        ok = false;\n                    } break;\n            }\n        }\n\n        if (!ok) {\n            GGML_LOG_ERROR(\"%s: failed to read key-value pairs\\n\", __func__);\n            gguf_free(ctx);\n            return nullptr;\n        }\n        GGML_ASSERT(int64_t(ctx->kv.size()) == n_kv);\n\n        const int alignment_idx = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT);\n        ctx->alignment = alignment_idx == -1 ? 
GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);\n\n        if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {\n            GGML_LOG_ERROR(\"%s: alignment %zu is not a power of 2\\n\", __func__, ctx->alignment);\n            gguf_free(ctx);\n            return nullptr;\n        }\n    }\n\n    // read the tensor info\n    for (int64_t i = 0; ok && i < n_tensors; ++i) {\n        struct gguf_tensor_info info;\n\n        // tensor name\n        {\n            std::string name;\n            try {\n                ok = ok && gr.read(name);\n            } catch (std::length_error &) {\n                GGML_LOG_ERROR(\"%s: encountered length_error while reading tensor name %\" PRIi64 \"\\n\", __func__, i);\n                ok = false;\n            } catch (std::bad_alloc &) {\n                GGML_LOG_ERROR(\"%s: encountered bad_alloc error while reading tensor name %\" PRIi64 \"\\n\", __func__, i);\n                ok = false;\n            }\n            if (name.length() >= GGML_MAX_NAME) {\n                GGML_LOG_ERROR(\"%s: tensor name %\" PRIi64 \" is too long: %zu >= %d\\n\", __func__, i, name.length(), GGML_MAX_NAME);\n                ok = false;\n                break;\n            }\n            ggml_set_name(&info.t, name.c_str());\n\n            // make sure there are no duplicate tensor names\n            for (int64_t j = 0; ok && j < i; ++j) {\n                if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {\n                    GGML_LOG_ERROR(\"%s: duplicate tensor name '%s' for tensors %\" PRIi64 \" and %\" PRIi64 \"\\n\", __func__, info.t.name, j, i);\n                    ok = false;\n                    break;\n                }\n            }\n        }\n        if (!ok) {\n            break;\n        }\n\n        // tensor shape\n        {\n            uint32_t n_dims = 0;\n            ok = ok && gr.read(n_dims);\n            if (n_dims > GGML_MAX_DIMS) {\n                GGML_LOG_ERROR(\"%s: tensor '%s' has 
invalid number of dimensions: %\" PRIu32 \" > %\" PRIu32 \"\\n\",\n                    __func__, info.t.name, n_dims, GGML_MAX_DIMS);\n                ok = false;\n                break;\n            }\n            for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) {\n                info.t.ne[j] = 1;\n                if (j < n_dims) {\n                    ok = ok && gr.read(info.t.ne[j]);\n                }\n\n                // check that all ne are non-negative\n                if (info.t.ne[j] < 0) {\n                    GGML_LOG_ERROR(\"%s: tensor '%s' dimension %\" PRIu32 \" has invalid number of elements: %\" PRIi64 \" < 0\\n\",\n                        __func__, info.t.name, j, info.t.ne[j]);\n                    ok = false;\n                    break;\n                }\n            }\n\n            // check that the total number of elements is representable\n            if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) ||\n                       (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||\n                       (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {\n\n                GGML_LOG_ERROR(\"%s: total number of elements in tensor '%s' with shape \"\n                    \"(%\" PRIi64 \", %\" PRIi64 \", %\" PRIi64 \", %\" PRIi64 \") is >= %\" PRIi64 \"\\n\",\n                    __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);\n                ok = false;\n                break;\n            }\n        }\n        if (!ok) {\n            break;\n        }\n\n        // tensor type\n        {\n            ok = ok && gr.read(info.t.type);\n\n            // check that tensor type is within defined range\n            if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {\n                GGML_LOG_ERROR(\"%s: tensor '%s' has invalid ggml type %d. 
should be in [0, %d)\\n\",\n                    __func__, info.t.name, info.t.type, GGML_TYPE_COUNT);\n                ok = false;\n                break;\n            }\n            const size_t  type_size = ggml_type_size(info.t.type);\n            const int64_t blck_size = ggml_blck_size(info.t.type);\n\n            // check that row size is divisible by block size\n            if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {\n                GGML_LOG_ERROR(\"%s: tensor '%s' of type %d (%s) has %\" PRId64 \" elements per row, \"\n                    \"not a multiple of block size (%\" PRId64 \")\\n\",\n                    __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);\n                ok = false;\n                break;\n            }\n\n            // check that the size of the tensor in bytes is representable\n            if (ok && uint64_t(ggml_nelements(&info.t)/ggml_blck_size(info.t.type)) > SIZE_MAX/ggml_type_size(info.t.type)) {\n                GGML_LOG_ERROR(\"%s: tensor '%s' with shape (%\" PRIi64 \", %\" PRIi64 \", %\" PRIi64 \", %\" PRIi64 \") has a size in bytes > %zu\\n\",\n                    __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], SIZE_MAX);\n                ok = false;\n                break;\n            }\n\n            // calculate byte offsets given the tensor shape and type\n            info.t.nb[0] = type_size;\n            info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size);\n            for (int j = 2; j < GGML_MAX_DIMS; ++j) {\n                info.t.nb[j] = info.t.nb[j - 1]*info.t.ne[j - 1];\n            }\n        }\n        if (!ok) {\n            break;\n        }\n\n        // tensor data offset within buffer\n        ok = ok && gr.read(info.offset);\n\n        ctx->info.push_back(info);\n    }\n\n    if (!ok) {\n        GGML_LOG_ERROR(\"%s: failed to read tensor info\\n\", __func__);\n        gguf_free(ctx);\n        return nullptr;\n   
 }\n    GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);\n\n    // we require the data section to be aligned, so take into account any padding\n    if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) {\n        GGML_LOG_ERROR(\"%s: failed to seek to beginning of data section\\n\", __func__);\n        gguf_free(ctx);\n        return nullptr;\n    }\n\n    // store the current file offset - this is where the data section starts\n    ctx->offset = gguf_ftell(file);\n\n    // compute the total size of the data section, taking into account the alignment\n    {\n        ctx->size = 0;\n        for (size_t i = 0; i < ctx->info.size(); ++i) {\n            const gguf_tensor_info & ti = ctx->info[i];\n            if (ti.offset != ctx->size) {\n                GGML_LOG_ERROR(\"%s: tensor '%s' has offset %\" PRIu64 \", expected %zu\\n\",\n                    __func__, ti.t.name, ti.offset, ctx->size);\n                GGML_LOG_ERROR(\"%s: failed to read tensor data\\n\", __func__);\n                gguf_free(ctx);\n                return nullptr;\n            }\n            size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);\n            if (SIZE_MAX - ctx->size < padded_size) {\n                GGML_LOG_ERROR(\"%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\\n\",\n                    __func__, ti.t.name, ctx->size, padded_size);\n                gguf_free(ctx);\n                return nullptr;\n            }\n            ctx->size += padded_size;\n        }\n    }\n\n    // load the tensor data only if requested\n    if (params.ctx != nullptr) {\n        // if the provided gguf_context is no_alloc, then we create \"empty\" tensors and do not read the binary blob\n        // otherwise, we load the binary blob into the created ggml_context as well, and point the \"data\" members of\n        //   the ggml_tensor structs to the appropriate locations in the binary blob\n\n        // compute the exact size needed for 
the new ggml_context\n        size_t mem_size = 0;\n        if (params.no_alloc) {\n            if (n_tensors != 0 && SIZE_MAX / n_tensors < ggml_tensor_overhead()) {\n                GGML_LOG_ERROR(\"%s: memory size overflow while allocating ggml context\\n\", __func__);\n                gguf_free(ctx);\n                return nullptr;\n            }\n\n            const size_t overhead = n_tensors * ggml_tensor_overhead();\n\n            mem_size = overhead;\n        } else {\n            if ((n_tensors + 1) != 0 && SIZE_MAX / (n_tensors + 1) < ggml_tensor_overhead()) {\n                GGML_LOG_ERROR(\"%s: memory size overflow while allocating ggml context\\n\", __func__);\n                gguf_free(ctx);\n                return nullptr;\n            }\n\n            const size_t overhead = (n_tensors + 1) * ggml_tensor_overhead();\n\n            if (SIZE_MAX - overhead < ctx->size) {\n                GGML_LOG_ERROR(\"%s: memory size overflow while allocating ggml context\\n\", __func__);\n                gguf_free(ctx);\n                return nullptr;\n            }\n\n            mem_size = overhead + ctx->size;\n        }\n\n        struct ggml_init_params pdata = {\n            /*mem_size   =*/ mem_size,\n            /*mem_buffer =*/ nullptr,\n            /*no_alloc   =*/ params.no_alloc,\n        };\n\n        *params.ctx = ggml_init(pdata);\n        if (*params.ctx == nullptr) {\n            GGML_LOG_ERROR(\"%s: failed to initialize ggml context for storing tensors\\n\", __func__);\n            gguf_free(ctx);\n            return nullptr;\n        }\n\n        struct ggml_context * ctx_data = *params.ctx;\n\n        struct ggml_tensor * data = nullptr;\n\n        if (!params.no_alloc) {\n            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);\n\n            ok = ok && data != nullptr;\n\n            if (ok) {\n                ggml_set_name(data, \"GGUF tensor data binary blob\");\n            }\n\n            // read the binary blob with 
the tensor data\n            ok = ok && gr.read(data->data, ctx->size);\n\n            if (!ok) {\n                GGML_LOG_ERROR(\"%s: failed to read tensor data binary blob\\n\", __func__);\n                ggml_free(ctx_data);\n                *params.ctx = nullptr;\n                gguf_free(ctx);\n                return nullptr;\n            }\n\n            ctx->data = data->data;\n        }\n\n        ggml_set_no_alloc(ctx_data, true);\n\n        // create the tensors\n        for (size_t i = 0; i < ctx->info.size(); ++i) {\n            const struct gguf_tensor_info & info = ctx->info[i];\n\n            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, info.t.type, GGML_MAX_DIMS, info.t.ne);\n\n            ok = ok && cur != nullptr;\n\n            if (!ok) {\n                break;\n            }\n\n            ggml_set_name(cur, info.t.name);\n\n            // point the data member to the appropriate location in the binary blob using the tensor info\n            if (!params.no_alloc) {\n                cur->data = (char *) data->data + info.offset;\n            }\n        }\n\n        if (!ok) {\n            GGML_LOG_ERROR(\"%s: failed to create tensors\\n\", __func__);\n            ggml_free(ctx_data);\n            *params.ctx = nullptr;\n            gguf_free(ctx);\n            return nullptr;\n        }\n\n        ggml_set_no_alloc(ctx_data, params.no_alloc);\n    }\n\n    return ctx;\n}\n\nstruct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {\n    FILE * file = ggml_fopen(fname, \"rb\");\n\n    if (!file) {\n        GGML_LOG_ERROR(\"%s: failed to open GGUF file '%s' (%s)\\n\", __func__, fname, strerror(errno));\n        return nullptr;\n    }\n\n    struct gguf_context * result = gguf_init_from_file_impl(file, params);\n    fclose(file);\n    return result;\n}\n\nvoid gguf_free(struct gguf_context * ctx) {\n    if (ctx == nullptr) {\n        return;\n    }\n    delete ctx;\n}\n\nconst char * 
gguf_type_name(enum gguf_type type) {\n    auto it = GGUF_TYPE_NAME.find(type);\n    return it == GGUF_TYPE_NAME.end() ? nullptr : it->second;\n}\n\nuint32_t gguf_get_version(const struct gguf_context * ctx) {\n    return ctx->version;\n}\n\nsize_t gguf_get_alignment(const struct gguf_context * ctx) {\n    return ctx->alignment;\n}\n\nsize_t gguf_get_data_offset(const struct gguf_context * ctx) {\n    return ctx->offset;\n}\n\nint64_t gguf_get_n_kv(const struct gguf_context * ctx) {\n    return ctx->kv.size();\n}\n\nint64_t gguf_find_key(const struct gguf_context * ctx, const char * key) {\n    // return -1 if key not found\n    int64_t keyfound = -1;\n\n    const int64_t n_kv = gguf_get_n_kv(ctx);\n\n    for (int64_t i = 0; i < n_kv; ++i) {\n        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {\n            keyfound = i;\n            break;\n        }\n    }\n\n    return keyfound;\n}\n\nconst char * gguf_get_key(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    return ctx->kv[key_id].get_key().c_str();\n}\n\nenum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    return ctx->kv[key_id].is_array ? 
GGUF_TYPE_ARRAY : ctx->kv[key_id].get_type();\n}\n\nenum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].is_array);\n    return ctx->kv[key_id].get_type();\n}\n\nconst void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);\n    return ctx->kv[key_id].data.data();\n}\n\nconst char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);\n    return ctx->kv[key_id].data_string[i].c_str();\n}\n\nsize_t gguf_get_arr_n(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n\n    if (ctx->kv[key_id].type == GGUF_TYPE_STRING) {\n        return ctx->kv[key_id].data_string.size();\n    }\n\n    const size_t type_size = gguf_type_size(ctx->kv[key_id].type);\n    GGML_ASSERT(ctx->kv[key_id].data.size() % type_size == 0);\n    return ctx->kv[key_id].data.size() / type_size;\n}\n\nuint8_t gguf_get_val_u8(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<uint8_t>();\n}\n\nint8_t gguf_get_val_i8(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<int8_t>();\n}\n\nuint16_t gguf_get_val_u16(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<uint16_t>();\n}\n\nint16_t 
gguf_get_val_i16(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<int16_t>();\n}\n\nuint32_t gguf_get_val_u32(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<uint32_t>();\n}\n\nint32_t gguf_get_val_i32(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<int32_t>();\n}\n\nfloat gguf_get_val_f32(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<float>();\n}\n\nuint64_t gguf_get_val_u64(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<uint64_t>();\n}\n\nint64_t gguf_get_val_i64(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<int64_t>();\n}\n\ndouble gguf_get_val_f64(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<double>();\n}\n\nbool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<bool>();\n}\n\nconst char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {\n    
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    return ctx->kv[key_id].get_val<std::string>().c_str();\n}\n\nconst void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {\n    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));\n    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);\n    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);\n    return ctx->kv[key_id].data.data();\n}\n\nint64_t gguf_get_n_tensors(const struct gguf_context * ctx) {\n    return ctx->info.size();\n}\n\nint64_t gguf_find_tensor(const struct gguf_context * ctx, const char * name) {\n    // return -1 if tensor not found\n    int64_t tensor_id = -1;\n\n    const int64_t n_tensors = gguf_get_n_tensors(ctx);\n\n    for (int64_t i = 0; i < n_tensors; ++i) {\n        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {\n            tensor_id = i;\n            break;\n        }\n    }\n\n    return tensor_id;\n}\n\nsize_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id) {\n    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));\n    return ctx->info[tensor_id].offset;\n}\n\nconst char * gguf_get_tensor_name(const struct gguf_context * ctx, int64_t tensor_id) {\n    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));\n    return ctx->info[tensor_id].t.name;\n}\n\nenum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int64_t tensor_id) {\n    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));\n    return ctx->info[tensor_id].t.type;\n}\n\nsize_t gguf_get_tensor_size(const struct gguf_context * ctx, int64_t tensor_id) {\n    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));\n    return ggml_nbytes(&ctx->info[tensor_id].t);\n}\n\nint64_t gguf_remove_key(struct gguf_context * ctx, const char * key) {\n    const int64_t key_id = gguf_find_key(ctx, key);\n    if (key_id >= 0) {\n        
ctx->kv.erase(ctx->kv.begin() + key_id);\n    }\n    return key_id;\n}\n\ntemplate<typename T>\nstatic void gguf_check_reserved_keys(const std::string & key, const T val) {\n    if (key == GGUF_KEY_GENERAL_ALIGNMENT) {\n        if constexpr (std::is_same<T, uint32_t>::value) {\n            GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT \" must be power of 2\");\n        } else {\n            GGML_UNUSED(val);\n            GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT \" must be type u32\");\n        }\n    }\n}\n\nvoid gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t 
val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, val);\n}\n\nvoid gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {\n    gguf_check_reserved_keys(key, val);\n    gguf_remove_key(ctx, key);\n    ctx->kv.emplace_back(key, std::string(val));\n}\n\nvoid gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n) {\n    gguf_check_reserved_keys(key, data);\n    gguf_remove_key(ctx, key);\n\n    const size_t nbytes = n*gguf_type_size(type);\n    std::vector<int8_t> tmp(nbytes);\n    if (!tmp.empty()) {\n        memcpy(tmp.data(), data, nbytes);\n    }\n    ctx->kv.emplace_back(key, tmp);\n    ctx->kv.back().cast(type);\n}\n\nvoid gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, size_t n) {\n    gguf_check_reserved_keys(key, data);\n    gguf_remove_key(ctx, key);\n\n    std::vector<std::string> tmp(n);\n    for (size_t i = 0; i < n; ++i) {\n        tmp[i] = data[i];\n    }\n    ctx->kv.emplace_back(key, tmp);\n}\n\n// set or add KV pairs from another context\nvoid gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src) {\n    const int64_t n_kv = gguf_get_n_kv(src);\n    for (int64_t i = 0; i < n_kv; ++i) {\n        const struct gguf_kv & kv = src->kv[i];\n\n        if (!kv.is_array) {\n            
switch (kv.get_type()) {\n                case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, kv.get_key().c_str(), kv.get_val<uint8_t>());             break;\n                case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, kv.get_key().c_str(), kv.get_val<int8_t>());              break;\n                case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, kv.get_key().c_str(), kv.get_val<uint16_t>());            break;\n                case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, kv.get_key().c_str(), kv.get_val<int16_t>());             break;\n                case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, kv.get_key().c_str(), kv.get_val<uint32_t>());            break;\n                case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, kv.get_key().c_str(), kv.get_val<int32_t>());             break;\n                case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, kv.get_key().c_str(), kv.get_val<float>());               break;\n                case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, kv.get_key().c_str(), kv.get_val<uint64_t>());            break;\n                case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, kv.get_key().c_str(), kv.get_val<int64_t>());             break;\n                case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, kv.get_key().c_str(), kv.get_val<double>());              break;\n                case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, kv.get_key().c_str(), kv.get_val<bool>());                break;\n                case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, kv.get_key().c_str(), kv.get_val<std::string>().c_str()); break;\n                case GGUF_TYPE_ARRAY:\n                default: GGML_ABORT(\"invalid type\");\n            }\n            continue;\n        }\n\n        const size_t ne = kv.get_ne();\n\n        switch (kv.get_type()) {\n            case GGUF_TYPE_UINT8:\n            case GGUF_TYPE_INT8:\n            case GGUF_TYPE_UINT16:\n            case GGUF_TYPE_INT16:\n            case GGUF_TYPE_UINT32:\n            case 
GGUF_TYPE_INT32:\n            case GGUF_TYPE_FLOAT32:\n            case GGUF_TYPE_UINT64:\n            case GGUF_TYPE_INT64:\n            case GGUF_TYPE_FLOAT64:\n            case GGUF_TYPE_BOOL: {\n                gguf_set_arr_data(ctx, kv.get_key().c_str(), kv.get_type(), kv.data.data(), ne);\n            } break;\n            case GGUF_TYPE_STRING: {\n                std::vector<const char *> tmp(ne);\n                for (size_t j = 0; j < ne; ++j) {\n                    tmp[j] = kv.data_string[j].c_str();\n                }\n                gguf_set_arr_str(ctx, kv.get_key().c_str(), tmp.data(), ne);\n            } break;\n            case GGUF_TYPE_ARRAY:\n            default: GGML_ABORT(\"invalid type\");\n        }\n    }\n}\n\nvoid gguf_add_tensor(\n             struct gguf_context * ctx,\n        const struct ggml_tensor * tensor) {\n    GGML_ASSERT(tensor);\n    if (gguf_find_tensor(ctx, tensor->name) != -1) {\n        GGML_ABORT(\"duplicate tensor name: %s\", tensor->name);\n    }\n\n    struct gguf_tensor_info ti;\n    ti.t = *tensor;\n    ti.offset = ctx->info.empty() ? 
0 :\n        ctx->info.back().offset + GGML_PAD(ggml_nbytes(&ctx->info.back().t), ctx->alignment);\n    ctx->info.push_back(ti);\n}\n\nvoid gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {\n    const int64_t tensor_id = gguf_find_tensor(ctx, name);\n    if (tensor_id < 0) {\n        GGML_ABORT(\"tensor not found: %s\", name);\n    }\n    struct ggml_tensor * tensor = &ctx->info[tensor_id].t;\n    const size_t  type_size = ggml_type_size(type);\n    const int64_t blck_size = ggml_blck_size(type);\n\n    tensor->type = type;\n    GGML_ASSERT(tensor->ne[0] % blck_size == 0 && \"tensor row size not divisible by block size of new type\");\n\n    tensor->nb[0] = type_size;\n    tensor->nb[1] = tensor->nb[0]*(tensor->ne[0]/blck_size);\n    for (int i = 2; i < GGML_MAX_DIMS; i++) {\n        tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1];\n    }\n\n    // update offsets\n    const int64_t n_tensors = gguf_get_n_tensors(ctx);\n    for (int64_t i = tensor_id + 1; i < n_tensors; ++i) {\n        ctx->info[i].offset = ctx->info[i - 1].offset + GGML_PAD(ggml_nbytes(&ctx->info[i - 1].t), ctx->alignment);\n    }\n}\n\nvoid gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data) {\n    const int64_t tensor_id = gguf_find_tensor(ctx, name);\n    if (tensor_id < 0) {\n        GGML_ABORT(\"tensor not found: %s\", name);\n    }\n\n    ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const\n}\n\nstruct gguf_writer_base {\n    size_t written_bytes {0u};\n\n    ~gguf_writer_base(void) = default;\n\n    // we bet on devirtualization\n    virtual void write(int8_t val) = 0;\n    virtual void write(const std::vector<int8_t> & val) = 0;\n    virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0;\n\n    template <typename T>\n    void write(const T & val) {\n        for (size_t i = 0; i < sizeof(val); ++i) 
{\n            write(reinterpret_cast<const int8_t *>(&val)[i]);\n        }\n    }\n\n    void write(const bool & val) {\n        const int8_t val8 = val ? 1 : 0;\n        write(val8);\n    }\n\n    void write(const std::string & val) {\n        {\n            const uint64_t n = val.length();\n            write(n);\n        }\n        for (size_t i = 0; i < val.length(); ++i) {\n            write((val.data())[i]);\n        }\n    }\n\n    void write(const char * val) {\n        write(std::string(val));\n    }\n\n    void write(const enum ggml_type & val) {\n        write(int32_t(val));\n    }\n\n    void write(const enum gguf_type & val) {\n        write(int32_t(val));\n    }\n\n    void write(const struct gguf_kv & kv) {\n        const uint64_t ne = kv.get_ne();\n\n        write(kv.get_key());\n\n        if (kv.is_array) {\n            write(GGUF_TYPE_ARRAY);\n            write(kv.get_type());\n            write(ne);\n        } else {\n            write(kv.get_type());\n        }\n\n        switch (kv.get_type()) {\n            case GGUF_TYPE_UINT8:\n            case GGUF_TYPE_INT8:\n            case GGUF_TYPE_UINT16:\n            case GGUF_TYPE_INT16:\n            case GGUF_TYPE_UINT32:\n            case GGUF_TYPE_INT32:\n            case GGUF_TYPE_FLOAT32:\n            case GGUF_TYPE_UINT64:\n            case GGUF_TYPE_INT64:\n            case GGUF_TYPE_FLOAT64: {\n                write(kv.data);\n            } break;\n            case GGUF_TYPE_BOOL: {\n                for (size_t i = 0; i < ne; ++i) {\n                    write(kv.get_val<bool>(i));\n                }\n            } break;\n            case GGUF_TYPE_STRING: {\n                for (size_t i = 0; i < ne; ++i) {\n                    write(kv.get_val<std::string>(i));\n                }\n            } break;\n            case GGUF_TYPE_ARRAY:\n            default: GGML_ABORT(\"invalid type\");\n        }\n    }\n\n    void write_tensor_meta(const struct gguf_tensor_info & info) {\n        
write(info.t.name);\n\n        const uint32_t n_dims = ggml_n_dims(&info.t);\n        write(n_dims);\n\n        for (uint32_t j = 0; j < n_dims; ++j) {\n            write(info.t.ne[j]);\n        }\n        write(info.t.type);\n        write(info.offset);\n    }\n\n    void pad(const size_t alignment) {\n        while (written_bytes % alignment != 0) {\n            const int8_t zero = 0;\n            write(zero);\n        }\n    }\n};\n\n// vector buffer based writer\nstruct gguf_writer_buf final : public gguf_writer_base {\n    std::vector<int8_t> & buf;\n\n    gguf_writer_buf(std::vector<int8_t> & buf) : buf(buf) {}\n\n    using gguf_writer_base::write;\n\n    void write(const int8_t val) override {\n        buf.push_back(val);\n        written_bytes++;\n    }\n\n    void write(const std::vector<int8_t> & val) override {\n        buf.insert(buf.end(), val.begin(), val.end());\n        written_bytes += val.size();\n    }\n\n    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {\n        GGML_ASSERT(buf.size() - offset_data == info.offset);\n\n        GGML_ASSERT(ggml_is_contiguous(&info.t));\n        const size_t offset = buf.size();\n        const size_t nbytes = ggml_nbytes(&info.t);\n\n        buf.resize(offset + nbytes);\n        if (info.t.buffer) {\n            ggml_backend_tensor_get(&info.t, buf.data() + offset, 0, nbytes);\n        } else {\n            GGML_ASSERT(info.t.data);\n            memcpy(buf.data() + offset, info.t.data, nbytes);\n        }\n        written_bytes += nbytes;\n\n        pad(alignment);\n    }\n};\n\n// file based writer\nstruct gguf_writer_file final : public gguf_writer_base {\n    FILE * file;\n\n    gguf_writer_file(FILE* file) : file(file) {}\n\n    using gguf_writer_base::write;\n\n    void write(const int8_t val) override {\n        const auto real_val = static_cast<uint8_t>(val);\n        const auto ret = fputc(real_val, file);\n        
written_bytes++;\n        if (ret != real_val) {\n            throw std::runtime_error(\"unexpected fputc result '\" + std::to_string(ret) + \"' instead of '\" + std::to_string((int)real_val) + \"'\");\n        }\n    }\n\n    void write(const std::vector<int8_t> & val) override {\n        const auto ret = fwrite(val.data(), 1, val.size(), file);\n        written_bytes += val.size();\n        if (ret != val.size()) {\n            throw std::runtime_error(\"unexpected fwrite number of bytes written, '\" + std::to_string(ret) + \"' instead of '\" + std::to_string(val.size()) + \"'\");\n        }\n    }\n\n    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {\n        GGML_ASSERT(written_bytes - offset_data == info.offset);\n\n        GGML_ASSERT(ggml_is_contiguous(&info.t));\n        const size_t nbytes = ggml_nbytes(&info.t);\n\n        std::vector<int8_t> buf(nbytes);\n        if (info.t.buffer) {\n            ggml_backend_tensor_get(&info.t, buf.data(), 0, nbytes);\n        } else {\n            GGML_ASSERT(info.t.data);\n            memcpy(buf.data(), info.t.data, nbytes);\n        }\n        write(buf);\n\n        pad(alignment);\n    }\n};\n\ntemplate <typename writer_t>\nstatic void gguf_write_out(const struct gguf_context * ctx, writer_t & gw, bool only_meta) {\n    const int64_t n_kv      = gguf_get_n_kv(ctx);\n    const int64_t n_tensors = gguf_get_n_tensors(ctx);\n\n    // write header\n    gw.write(GGUF_MAGIC[0]);\n    gw.write(GGUF_MAGIC[1]);\n    gw.write(GGUF_MAGIC[2]);\n    gw.write(GGUF_MAGIC[3]);\n    gw.write(ctx->version);\n    gw.write(n_tensors);\n    gw.write(n_kv);\n\n    // write key-value pairs\n    for (int64_t i = 0; i < n_kv; ++i) {\n        gw.write(ctx->kv[i]);\n    }\n\n    // write tensor info\n    for (int64_t i = 0; i < n_tensors; ++i) {\n        gw.write_tensor_meta(ctx->info[i]);\n    }\n\n    // we require the data section to be aligned\n    
gw.pad(ctx->alignment);\n\n    if (only_meta) {\n        return;\n    }\n\n    const size_t offset_data = gw.written_bytes;\n\n    // write tensor data\n    for (int64_t i = 0; i < n_tensors; ++i) {\n        gw.write_tensor_data(ctx->info[i], offset_data, ctx->alignment);\n    }\n}\n\nvoid gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {\n    gguf_writer_buf gw(buf);\n    gguf_write_out(ctx, gw, only_meta);\n}\n\nbool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {\n    FILE * file = ggml_fopen(fname, \"wb\");\n\n    if (!file) {\n        GGML_LOG_ERROR(\"%s: failed to open file '%s' for writing GGUF data\\n\", __func__, fname);\n        return false;\n    }\n\n    try {\n        gguf_writer_file gw(file);\n        gguf_write_out(ctx, gw, only_meta);\n    } catch (const std::runtime_error& ex) {\n        GGML_LOG_ERROR(\"%s: failed to write GGUF data into '%s': %s\\n\", __func__, fname, ex.what());\n        fclose(file);\n        return false;\n    }\n\n    fclose(file);\n    return true;\n}\n\nsize_t gguf_get_meta_size(const struct gguf_context * ctx) {\n    // only return size\n    std::vector<int8_t> buf;\n    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);\n    return buf.size();\n}\n\nvoid gguf_get_meta_data(const struct gguf_context * ctx, void * data) {\n    std::vector<int8_t> buf;\n    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);\n    memcpy(data, buf.data(), buf.size());\n}\n"
  },
  {
    "path": "tests/CMakeLists.txt",
    "content": "find_library(MATH_LIBRARY m)\n\n# check systems\nif (NOT UNAME_S)\n    execute_process(COMMAND uname -s OUTPUT_VARIABLE UNAME_S)\nendif()\nif (NOT UNAME_P)\n    execute_process(COMMAND uname -p OUTPUT_VARIABLE UNAME_P)\nendif()\nif (NOT UNAME_M)\n    execute_process(COMMAND uname -m OUTPUT_VARIABLE UNAME_M)\nendif()\n#message(STATUS \"UNAME_S: ${UNAME_S}  UNAME_P: ${UNAME_P}  UNAME_M: ${UNAME_M}\")\n\n# Mac OS + Arm can report x86_64\n# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789\nif (UNAME_S MATCHES \"Darwin\")\n    if (NOT UNAME_P MATCHES \"arm\")\n        execute_process(COMMAND sysctl -n hw.optional.arm64 OUTPUT_VARIABLE SYSCTL_M)\n        if (SYSCTL_M MATCHES \"1\")\n            #set(UNAME_P \"arm\")\n            #set(UNAME_M \"arm64\")\n            message(WARNING \"Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lea\nd to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\\#issuecomment-#1282546789\")\n        endif()\n    endif()\nendif()\n\nif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"^(aarch64|arm.*|ARM64)$\")\n    message(STATUS \"ARM detected\")\n    #set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mcpu=apple-m1\")\nelseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64le\" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64\")\n    message(STATUS \"PPC64 detected\")\n    set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mpower9-vector\")\nelseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"loongarch64\")\n    message(STATUS \"loongarch64 detected\")\n    set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mlsx -mlasx\")\nelse()\n    message(STATUS \"x86 detected\")\n    #set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c\")\n    if (UNAME_S MATCHES \"Darwin\")\n        execute_process(COMMAND sysctl machdep.cpu.features OUTPUT_VARIABLE AVX1_M)\n        if (AVX1_M MATCHES \"AVX1.0\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mavx\")\n        endif()\n        
execute_process(COMMAND sysctl machdep.cpu.leaf7_features OUTPUT_VARIABLE AVX2_M)\n        if (AVX2_M MATCHES \"AVX2\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mavx2\")\n        endif()\n        if (AVX1_M MATCHES \"FMA\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mfma\")\n        endif()\n        set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mf16c\")\n    elseif (UNAME_S MATCHES \"Linux\")\n        message(STATUS \"Linux detected\")\n        # must have to build on ubuntu22 with gcc11:\n        find_package(Threads)\n        set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS} Threads::Threads)\n\n        execute_process(COMMAND grep \"avx \" /proc/cpuinfo OUTPUT_VARIABLE AVX1_M)\n        if (AVX1_M MATCHES \"avx\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mavx\")\n        endif()\n        execute_process(COMMAND grep \"avx2 \" /proc/cpuinfo OUTPUT_VARIABLE AVX2_M)\n        if (AVX2_M MATCHES \"avx2\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mavx2\")\n        endif()\n        execute_process(COMMAND grep \"fma \" /proc/cpuinfo OUTPUT_VARIABLE FMA_M)\n        if (FMA_M MATCHES \"fma\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mfma\")\n        endif()\n        execute_process(COMMAND grep \"f16c \" /proc/cpuinfo OUTPUT_VARIABLE F16C_M)\n        if (F16C_M MATCHES \"f16c\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mf16c\")\n        endif()\n        execute_process(COMMAND grep \"sse3 \" /proc/cpuinfo OUTPUT_VARIABLE SSE3_M)\n        if (SSE3_M MATCHES \"sse3\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -msse3\")\n        endif()\n    elseif (UNAME_S MATCHES \"Haiku\")\n        message(STATUS \"Haiku detected\")\n\texecute_process(COMMAND sysinfo -cpu COMMAND grep \"AVX \" OUTPUT_VARIABLE AVX1_M)\n        if (AVX1_M MATCHES \"avx\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mavx\")\n        endif()\n\texecute_process(COMMAND sysinfo -cpu COMMAND grep \"AVX2 \" OUTPUT_VARIABLE AVX2_M)\n        if (AVX2_M MATCHES 
\"avx2\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mavx2\")\n        endif()\n\texecute_process(COMMAND sysinfo -cpu COMMAND grep \"FMA \" OUTPUT_VARIABLE FMA_M)\n        if (FMA_M MATCHES \"fma\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mfma\")\n        endif()\n\texecute_process(COMMAND sysinfo -cpu COMMAND grep \"F16C \" OUTPUT_VARIABLE F16C_M)\n        if (F16C_M MATCHES \"f16c\")\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -mf16c\")\n        endif()\n    elseif (MSVC)\n        if (GGML_AVX512)\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} /arch:AVX512\")\n            # MSVC has no compile-time flags enabling specific\n            # AVX512 extensions, neither it defines the\n            # macros corresponding to the extensions.\n            # Do it manually.\n            if (GGML_AVX512_VBMI)\n                add_compile_definitions(__AVX512VBMI__)\n            endif()\n            if (GGML_AVX512_VNNI)\n                add_compile_definitions(__AVX512VNNI__)\n            endif()\n        elseif (GGML_AVX2)\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} /arch:AVX2\")\n        elseif (GGML_AVX)\n            set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} /arch:AVX\")\n        endif()\n    else()\n        set(CMAKE_C_FLAGS  \"${CMAKE_C_FLAGS} -mfma -mf16c -mavx -mavx2\")\n    endif()\nendif()\n\n# on APPLE - include Accelerate framework\nif (APPLE AND NOT GGML_NO_ACCELERATE)\n    find_library(ACCELERATE_FRAMEWORK Accelerate)\n    if (ACCELERATE_FRAMEWORK)\n        message(STATUS \"Accelerate framework found\")\n\n        set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${ACCELERATE_FRAMEWORK})\n        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)\n    else()\n        message(WARNING \"Accelerate framework not found\")\n    endif()\nendif()\n\nif (GGML_OPENBLAS)\n    set(OPENBLAS_INCLUDE_SEARCH_PATHS\n        /usr/include\n        /usr/include/openblas\n        /usr/include/openblas-base\n        /usr/local/include\n        
/usr/local/include/openblas\n        /usr/local/include/openblas-base\n        /opt/OpenBLAS/include\n        $ENV{OpenBLAS_HOME}\n        $ENV{OpenBLAS_HOME}/include\n        )\n    find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})\n    find_library(OPENBLAS_LIB NAMES openblas libopenblas)\n    if (OPENBLAS_LIB)\n        message(STATUS \"OpenBLAS found\")\n\n        set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${OPENBLAS_LIB})\n        set(GGML_EXTRA_INCS  ${GGML_EXTRA_INCS}  ${OPENBLAS_INC})\n        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)\n    else()\n        message(WARNING \"OpenBLAS not found\")\n    endif()\nendif()\n\n# undefine NDEBUG so asserts don't get disabled in tests\nadd_definitions(-UNDEBUG)\n\n#\n# test-backend-ops\n\nset(TEST_TARGET test-backend-ops)\nadd_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\ntarget_link_libraries(${TEST_TARGET} PRIVATE ggml Threads::Threads)\nadd_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\nset_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n\nif (NOT GGML_BACKEND_DL)\n    #\n    # test-opt\n    set(TEST_TARGET test-opt)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-quantize-fns\n\n    set(TEST_TARGET test-quantize-fns)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-quantize-perf\n\n    set(TEST_TARGET test-quantize-perf)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    
target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-pool\n\n    set(TEST_TARGET test-pool)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.c)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    if (MSVC)\n        target_link_options(${TEST_TARGET} PRIVATE \"/STACK:8388608\") # 8MB\n    endif()\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-arange\n\n    set(TEST_TARGET test-arange)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml Threads::Threads)\n    if (MSVC)\n        target_link_options(${TEST_TARGET} PRIVATE \"/STACK:8388608\") # 8MB\n    endif()\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-timestep_embedding\n\n    set(TEST_TARGET test-timestep_embedding)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    if (MSVC)\n        target_link_options(${TEST_TARGET} PRIVATE \"/STACK:8388608\") # 8MB\n    endif()\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-pad-reflect-1d\n\n    set(TEST_TARGET test-pad-reflect-1d)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n\n    #\n    # test-roll\n\n    set(TEST_TARGET test-roll)\n    add_executable(${TEST_TARGET} 
${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n\n    #\n    # test-conv-transpose\n\n    set(TEST_TARGET test-conv-transpose)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.c)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n\n    # test-conv-transpose-1d\n\n    set(TEST_TARGET test-conv-transpose-1d)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n\n    #\n    # test-dup\n\n    set(TEST_TARGET test-dup)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.c)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n\n    #\n    # test-rel-pos\n\n    set(TEST_TARGET test-rel-pos)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.c)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n\n    #\n    # test-customop\n\n    set(TEST_TARGET test-customop)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.c)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    if (MSVC)\n        target_link_options(${TEST_TARGET} PRIVATE \"/STACK:8388608\") # 8MB\n    endif()\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-conv1d\n\n    set(TEST_TARGET test-conv1d)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # 
test-conv1d-dw-c1\n\n    set(TEST_TARGET test-conv1d-dw-c1)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-conv1d-dw-c2\n\n    set(TEST_TARGET test-conv1d-dw-c2)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-conv2d\n\n    set(TEST_TARGET test-conv2d)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-conv2d-dw\n\n    set(TEST_TARGET test-conv2d-dw)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-cont\n\n    set(TEST_TARGET test-cont)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.c)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\n\n    #\n    # test-interpolate\n\n    set(TEST_TARGET test-interpolate)\n    add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)\n    target_link_libraries(${TEST_TARGET} PRIVATE ggml)\n    add_test(NAME ${TEST_TARGET} COMMAND 
$<TARGET_FILE:${TEST_TARGET}>)\n    set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT \"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw\")\nendif()\n"
  },
  {
    "path": "tests/test-arange.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include <string.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#include <vector>\n\n// Checks that ggml_arange(0, 3, 1) yields [0, 1, 2] on the first available backend\n// (CUDA or Metal when compiled in, otherwise the CPU backend).\nint main(int /*argc*/, const char** /*argv*/) {\n    {\n        bool use_gpu = true;\n        GGML_UNUSED(use_gpu);\n\n        ggml_backend_t backend = NULL;\n\n        #ifdef GGML_USE_CUDA\n        if (use_gpu) {\n            fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n            backend = ggml_backend_cuda_init(0);\n            if (!backend) {\n                fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n            }\n        }\n        #endif\n\n        #ifdef GGML_USE_METAL\n        if (!backend) {\n            fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n            backend = ggml_backend_metal_init();\n            if (!backend) {\n                fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n            }\n        }\n        #endif\n\n        const int num_tensors = 2;\n\n        // no_alloc = true: the context holds only tensor/graph metadata; tensor data is\n        // allocated in a backend buffer by the graph allocator (galloc) below.\n        struct ggml_init_params params = {\n                /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors + 2 * 1024 * 1024,\n                /*.mem_buffer =*/ NULL,\n                /*.no_alloc   =*/ true,\n        };\n\n        if (!backend) {\n            // fallback to CPU backend\n            backend = ggml_backend_cpu_init();\n        }\n        GGML_ASSERT(backend != NULL);\n\n        // create context\n        struct ggml_context* ctx = ggml_init(params);\n        GGML_ASSERT(ctx != NULL);\n        struct ggml_tensor * t = ggml_arange(ctx, 0, 3, 1);\n\n        GGML_ASSERT(t->ne[0] == 3);\n\n        ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t);\n\n        // allocate tensors; fail loudly instead of computing on unallocated tensors\n        const bool allocated = ggml_gallocr_alloc_graph(galloc, graph);\n        GGML_ASSERT(allocated);\n\n        int n_threads = 4;\n\n        if (ggml_backend_is_cpu(backend)) {\n            ggml_backend_cpu_set_n_threads(backend, n_threads);\n        }\n\n        const enum ggml_status status = ggml_backend_graph_compute(backend, graph);\n        GGML_ASSERT(status == GGML_STATUS_SUCCESS);\n\n        std::vector<float> output(ggml_nelements(t));\n        ggml_backend_tensor_get(t, output.data(), 0, ggml_nbytes(t));\n\n        for (int i = 0; i < t->ne[0]; i++) {\n            printf(\"%.2f \", output[i]);\n        }\n        printf(\"\\n\");\n\n        GGML_ASSERT(output[0] == 0);\n        GGML_ASSERT(output[1] == 1);\n        GGML_ASSERT(output[2] == 2);\n\n        ggml_free(ctx);\n        ggml_gallocr_free(galloc);\n        ggml_backend_free(backend);\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-backend-ops.cpp",
    "content": "// This file defines tests for various GGML ops and backends.\n// For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.\n// For the backward pass it asserts that the gradients from backpropagation are consistent\n// with the gradients obtained via the method of finite differences (\"grad\" mode, this is optional).\n// It is also possible to check the performance (\"perf\" mode).\n//\n// this file has three sections: Section 1 does general setup, section 2 defines the GGML ops to be tested,\n// and section 3 defines which tests to run.\n// Quick start for adding a new GGML op: Go to section 2 and create a struct that inherits from test_case,\n// then go to section 3 and add an instantiation of your struct.\n\n\n// ##############################\n// ## Section 1: General Setup ##\n// ##############################\n\n\n#include <ggml.h>\n#include <ggml-alloc.h>\n#include <ggml-backend.h>\n#include <ggml-cpp.h>\n\n#include <algorithm>\n#include <array>\n#include <cfloat>\n#include <cinttypes>\n#include <cstdarg>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <ctime>\n#include <future>\n#include <fstream>\n#include <memory>\n#include <random>\n#include <regex>\n#include <set>\n#include <sstream>\n#include <string>\n#include <string_view>\n#include <thread>\n#include <vector>\n#include <unordered_map>\n\n#ifdef __EMSCRIPTEN__\n#   define N_THREADS 1\n#else\n#   define N_THREADS std::thread::hardware_concurrency()\n#endif\n\nstatic void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {\n    size_t nels = ggml_nelements(tensor);\n    std::vector<float> data(nels);\n    {\n        // parallel initialization\n        static const size_t n_threads = N_THREADS;\n        // static RNG initialization (revisit if n_threads stops being constant)\n        static std::vector<std::default_random_engine> generators = []() {\n            
std::random_device rd;\n            std::vector<std::default_random_engine> vec;\n            vec.reserve(n_threads);\n            //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed\n            for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }\n            return vec;\n        }();\n\n        auto init_thread = [&](size_t ith, size_t start, size_t end) {\n            std::uniform_real_distribution<float> distribution(min, max);\n            auto & gen = generators[ith];\n            for (size_t i = start; i < end; i++) {\n                data[i] = distribution(gen);\n            }\n        };\n\n        if (n_threads == 1) {\n            init_thread(0, 0, nels);\n        } else {\n            std::vector<std::future<void>> tasks;\n            tasks.reserve(n_threads);\n            for (size_t i = 0; i < n_threads; i++) {\n                size_t start =     i*nels/n_threads;\n                size_t end   = (i+1)*nels/n_threads;\n                tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));\n            }\n            for (auto & t : tasks) {\n                t.get();\n            }\n        }\n    }\n\n    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {\n        ggml_backend_tensor_set(tensor, data.data(), 0, nels * sizeof(float));\n    } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {\n        GGML_ASSERT(nels % ggml_blck_size(tensor->type) == 0);\n\n         // dummy importance matrix\n        std::vector<float> imatrix(tensor->ne[0], 1.0f);\n        const float * im = imatrix.data();\n        if (!ggml_quantize_requires_imatrix(tensor->type)) {\n            // when the imatrix is optional, we want to test both quantization with and without imatrix\n            // use one of the random numbers to decide\n            if (data[0] > 0.5f*(min + max)) {\n                im = nullptr;\n            }\n  
      }\n\n        std::vector<uint8_t> dataq(ggml_row_size(tensor->type, nels));\n        {\n            // parallel quantization by block\n            size_t blck_size = ggml_blck_size(tensor->type);\n            size_t n_blocks = nels / blck_size;\n\n            auto quantize_thread = [&](size_t start, size_t end) {\n                ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),\n                    start * blck_size, end - start, blck_size, im);\n            };\n\n            const size_t min_blocks_per_thread = 1;\n            const size_t n_quant_threads = std::min<size_t>(std::max<size_t>(N_THREADS/2, 1),\n                                                            std::max<size_t>(1, n_blocks / min_blocks_per_thread));\n\n            if (n_quant_threads == 1) {\n                // single-threaded quantization: do all blocks in the current thread\n                quantize_thread(0, n_blocks);\n            } else {\n                std::vector<std::future<void>> tasks;\n                tasks.reserve(n_quant_threads);\n                for (size_t i = 0; i < n_quant_threads; i++) {\n                    size_t start =     i*n_blocks/n_quant_threads;\n                    size_t end   = (i+1)*n_blocks/n_quant_threads;\n                    tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));\n                }\n                for (auto & t : tasks) {\n                    t.get();\n                }\n            }\n        }\n        ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());\n    } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {\n        // This is going to create some weird integers though.\n        ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor));\n    } else if (tensor->type == GGML_TYPE_I64) {\n        // Integers with a size of 8 bytes can be set by mirroring the float data, the specific values are again not really 
meaningful.\n        const size_t nbytes_half = ggml_nbytes(tensor)/2;\n        ggml_backend_tensor_set(tensor, data.data(), 0*nbytes_half, nbytes_half);\n        ggml_backend_tensor_set(tensor, data.data(), 1*nbytes_half, nbytes_half);\n    } else {\n        GGML_ABORT(\"fatal error\");\n    }\n}\n\n// generate an F16 mask where certain blocks are randomly masked with -INF value\nstatic void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {\n    GGML_ASSERT(tensor->type == GGML_TYPE_F16);\n\n    GGML_TENSOR_LOCALS( int32_t, ne, tensor, ne);\n\n    std::vector<float>       data_f32(ne0*ne1*ne2*ne3);\n    std::vector<ggml_fp16_t> data_f16(ne0*ne1*ne2*ne3);\n\n    std::random_device rd;\n    std::mt19937 gen(rd());\n    std::uniform_real_distribution<float> dis(min, max);\n\n    for (size_t i = 0; i < data_f32.size(); i++) {\n        data_f32[i] = dis(gen);\n    }\n\n    // block size\n    const int blck0 = 128;\n    const int blck1 = 64;\n\n    // number of INF/zero blocks\n    const int n_inf_zero_blocks = 0.2*(ne0*ne1*ne2*ne3)/(blck0*blck1);\n\n    for (int b = 0; b < n_inf_zero_blocks; b++) {\n        const int p3 = (rd() % ne3);\n        const int p2 = (rd() % ne2);\n        const int p1 = (rd() % ne1);\n        const int p0 = (rd() % ne0);\n\n        bool inf = rd() & 1;\n\n        for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) {\n            const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0;\n\n            for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) {\n                data_f32[idx + i0] = inf ? 
-INFINITY : 0.0f;\n            }\n        }\n    }\n\n    ggml_fp32_to_fp16_row(data_f32.data(), data_f16.data(), ne0*ne1*ne2*ne3);\n\n    ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t));\n}\n\n// generate a lower triangular matrix\nstatic void init_tensor_tril(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {\n    GGML_ASSERT(tensor->type == GGML_TYPE_F32);\n    GGML_ASSERT(tensor->ne[0] == tensor->ne[1]);\n\n    GGML_TENSOR_LOCALS(int32_t, ne, tensor, ne);\n    GGML_TENSOR_LOCALS(size_t, nb, tensor, nb);\n\n    std::vector<float> data_f32(ne0*ne1*ne2*ne3);\n\n    std::random_device rd;\n    std::mt19937 gen(rd());\n    std::uniform_real_distribution<float> dis(min, max);\n\n    for (int64_t i3 = 0; i3 < ne3; i3++) {\n        for (int64_t i2 = 0; i2 < ne2; i2++) {\n            for (int64_t i1 = 0; i1 < ne1; i1++) {\n                for (int64_t i0 = 0; i0 < ne0; i0++) {\n                    int64_t idx = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3) / sizeof(float);\n                    if (i0 <= i1) {\n                        data_f32[idx] = dis(gen);\n                    } else {\n                        data_f32[idx] = 0.0f;\n                    }\n                }\n            }\n        }\n    }\n\n    ggml_backend_tensor_set(tensor, data_f32.data(), 0, ggml_nbytes(tensor));\n}\n\nstatic std::vector<float> tensor_to_float(const ggml_tensor * t) {\n    std::vector<float> tv;\n    tv.reserve(ggml_nelements(t));\n\n    std::vector<uint8_t> buf(ggml_nbytes(t));\n    ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));\n\n    const auto * tt = ggml_get_type_traits(t->type);\n    size_t bs = ggml_blck_size(t->type);\n    std::vector<float> vq(ggml_blck_size(t->type));\n    bool quantized = ggml_is_quantized(t->type);\n\n    // access elements by index to avoid gaps in views\n    for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {\n        for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {\n            for (int64_t i1 = 0; 
i1 < t->ne[1]; i1++) {\n                for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {\n                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];\n                    if (t->type == GGML_TYPE_F16) {\n                        tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));\n                    } else if (t->type == GGML_TYPE_BF16) {\n                        tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));\n                    } else if (t->type == GGML_TYPE_F32) {\n                        tv.push_back(*(float *) &buf[i]);\n                    } else if (t->type == GGML_TYPE_I64) {\n                        tv.push_back((float)*(int64_t *) &buf[i]);\n                    } else if (t->type == GGML_TYPE_I32) {\n                        tv.push_back((float)*(int32_t *) &buf[i]);\n                    } else if (t->type == GGML_TYPE_I16) {\n                        tv.push_back((float)*(int16_t *) &buf[i]);\n                    } else if (t->type == GGML_TYPE_I8) {\n                        tv.push_back((float)*(int8_t *) &buf[i]);\n                    } else if (quantized) {\n                        tt->to_float(&buf[i], vq.data(), bs);\n                        tv.insert(tv.end(), vq.begin(), vq.end());\n                    } else {\n                        GGML_ABORT(\"fatal error\");\n                    }\n                }\n            }\n        }\n    }\n\n    return tv;\n}\n\n// normalized mean squared error = mse(a, b) / mse(a, 0)\nstatic double nmse(const float * a, const float * b, size_t n) {\n    double mse_a_b = 0.0;\n    double mse_a_0 = 0.0;\n\n    for (size_t i = 0; i < n; i++) {\n        float a_i = a[i];\n        float b_i = b[i];\n\n        mse_a_b += (a_i - b_i) * (a_i - b_i);\n        mse_a_0 += a_i * a_i;\n    }\n\n    return mse_a_b / mse_a_0;\n}\n\n// difference between 2 sets (Jaccard distance, 0 - no difference, 1 - no overlap)\ntemplate <typename T>\nstatic double jdst(const T * a, const T * b, size_t 
n) {\n    std::unordered_map<T, size_t> set_a;\n    std::unordered_map<T, size_t> set_b;\n\n    for (size_t i = 0; i < n; ++i) {\n        set_a[a[i]]++;\n        set_b[b[i]]++;\n    }\n\n    size_t diff = 0;\n\n    for (const auto & p : set_a) {\n        const int64_t na = p.second;\n        const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0;\n\n        diff += std::abs(na - nb);\n    }\n\n    for (const auto & p : set_b) {\n        if (set_a.find(p.first) == set_a.end()) {\n            diff += p.second;\n        }\n    }\n\n    return (double) diff / (2*n);\n}\n\n// maximum absolute asymmetry between a and b\n// asymmetry: (a - b) / (a + b)\n// This is more stable than relative error if one of the values fluctuates towards zero.\n// n: number of values to compare.\n// expected_vals: optional vector of expected values for a. If expected_vals is not empty, filter out all comparisons where\n//     a does not match any of the expected values. Needed for noncontinuous gradients where the numerical calculation can fail.\nstatic double mean_abs_asymm(const float * a, const float * b, const size_t n, const std::vector<float> & expected_vals) {\n    double sum = 0.0f;\n\n    size_t nvalid = 0;\n    for (size_t i = 0; i < n; i++) {\n        if (!expected_vals.empty()) {\n            bool matches_any = false;\n            for (const float & ev : expected_vals) {\n                if (fabsf(a[i] - ev) < 1e-3f) {\n                    matches_any = true;\n                    break;\n                }\n            }\n            if (!matches_any) {\n                continue;\n            }\n        }\n\n        const float asymm = (a[i] - b[i]) / (a[i] + b[i]);\n\n        sum += fabsf(asymm);\n        nvalid++;\n    }\n\n    return sum/nvalid;\n}\n\n// utils for printing the variables of the test cases\n\nstatic std::string var_to_str(const std::string & x) {\n    return x;\n}\n\ntemplate<typename T>\nstatic std::string var_to_str(const T & x) {\n    
return std::to_string(x);\n}\n\ntemplate<typename T, size_t N>\nstatic std::string var_to_str(const T (&x)[N]) {\n    std::string s = \"[\";\n    for (size_t i = 0; i < N; i++) {\n        if (i > 0) {\n            s += \",\";\n        }\n        s += var_to_str(x[i]);\n    }\n    s += \"]\";\n    return s;\n}\n\ntemplate<typename T, size_t N>\nstatic std::string var_to_str(const std::array<T, N> & x) {\n    std::string s = \"[\";\n    for (size_t i = 0; i < N; i++) {\n        if (i > 0) {\n            s += \",\";\n        }\n        s += var_to_str(x[i]);\n    }\n    s += \"]\";\n    return s;\n}\n\nstatic std::string var_to_str(ggml_type type) {\n    return ggml_type_name(type);\n}\n\nstatic std::string var_to_str(ggml_prec prec) {\n    return prec == GGML_PREC_F32 ? \"f32\" : \"def\";\n}\n\nstatic std::string var_to_str(ggml_op_pool pool) {\n    switch (pool) {\n        case GGML_OP_POOL_AVG:  return \"avg\";\n        case GGML_OP_POOL_MAX:  return \"max\";\n        default:                return std::to_string(pool);\n    }\n}\n\nstatic std::string var_to_str(ggml_scale_mode mode) {\n    std::string str;\n    switch (mode & 0xFF) {\n        case GGML_SCALE_MODE_NEAREST:  str = \"nearest\"; break;\n        case GGML_SCALE_MODE_BILINEAR: str = \"bilinear\"; break;\n        case GGML_SCALE_MODE_BICUBIC:  str = \"bicubic\"; break;\n        default:                       str = std::to_string(mode); break;\n    }\n    if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {\n        str += \"|align_corners\";\n    }\n    if (mode & GGML_SCALE_FLAG_ANTIALIAS) {\n        str += \"|antialias\";\n    }\n    return str;\n}\n\n#define VAR_TO_STR(x) (#x \"=\" + var_to_str(x))\n\n#define VARS_TO_STR1(a) VAR_TO_STR(a)\n#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + \",\" + VAR_TO_STR(b)\n#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + \",\" + VARS_TO_STR2(b, c)\n#define VARS_TO_STR4(a, b, c, d) VAR_TO_STR(a) + \",\" + VARS_TO_STR3(b, c, d)\n#define VARS_TO_STR5(a, b, c, d, e) VAR_TO_STR(a) + 
\",\" + VARS_TO_STR4(b, c, d, e)\n#define VARS_TO_STR6(a, b, c, d, e, f) VAR_TO_STR(a) + \",\" + VARS_TO_STR5(b, c, d, e, f)\n#define VARS_TO_STR7(a, b, c, d, e, f, g) VAR_TO_STR(a) + \",\" + VARS_TO_STR6(b, c, d, e, f, g)\n#define VARS_TO_STR8(a, b, c, d, e, f, g, h) VAR_TO_STR(a) + \",\" + VARS_TO_STR7(b, c, d, e, f, g, h)\n#define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + \",\" + VARS_TO_STR8(b, c, d, e, f, g, h, i)\n#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + \",\" + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)\n#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + \",\" + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)\n#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + \",\" + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)\n#define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + \",\" + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)\n#define VARS_TO_STR14(a, b, c, d, e, f, g, h, i, j, k, l, m, n) VAR_TO_STR(a) + \",\" + VARS_TO_STR13(b, c, d, e, f, g, h, i, j, k, l, m, n)\n#define VARS_TO_STR15(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) VAR_TO_STR(a) + \",\" + VARS_TO_STR14(b, c, d, e, f, g, h, i, j, k, l, m, n, o)\n#define VARS_TO_STR16(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) VAR_TO_STR(a) + \",\" + VARS_TO_STR15(b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)\n\n#ifdef GGML_USE_SYCL\nstatic bool inline _isinf(float f) {\n    return (*(uint32_t *)&f & 0x7fffffff) == 0x7f800000;\n}\n#else\nstatic bool inline _isinf(float f) { return std::isinf(f); }\n#endif\n\n// accept FLT_MAX as infinity\nstatic bool isinf_or_max(float f) {\n    return _isinf(f) || f == FLT_MAX || f == -FLT_MAX;\n}\n\nstatic bool ggml_is_view_op(enum ggml_op op) {\n    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;\n}\n\nstatic bool backend_has_feature(ggml_backend_t backend, const char * feature_name) {\n    ggml_backend_dev_t dev = 
ggml_backend_get_device(backend);\n    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);\n\n    auto get_features = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, \"ggml_backend_get_features\");\n    if (!get_features) {\n        return false;\n    }\n\n    const ggml_backend_feature * features = get_features(reg);\n    if (!features) {\n        return false;\n    }\n\n    for (const ggml_backend_feature * f = features; f->name; ++f) {\n        if (strcmp(f->name, feature_name) == 0 && strcmp(f->value, \"1\") == 0) {\n            return true;\n        }\n    }\n    return false;\n}\n\nenum test_mode {\n    MODE_TEST,\n    MODE_PERF,\n    MODE_GRAD,\n    MODE_SUPPORT,\n};\n\n// Output format support similar to llama-bench\nenum output_formats { CONSOLE, SQL, CSV };\n\nstatic const char * output_format_str(output_formats format) {\n    switch (format) {\n        case CONSOLE:\n            return \"console\";\n        case SQL:\n            return \"sql\";\n        case CSV:\n            return \"csv\";\n        default:\n            GGML_ABORT(\"invalid output format\");\n    }\n}\n\nstatic bool output_format_from_str(const std::string & s, output_formats & format) {\n    if (s == \"console\") {\n        format = CONSOLE;\n    } else if (s == \"sql\") {\n        format = SQL;\n    } else if (s == \"csv\") {\n        format = CSV;\n    } else {\n        return false;\n    }\n    return true;\n}\n\n// Test result structure for SQL output\nstruct test_result {\n    std::string test_time;\n    std::string build_commit;\n    std::string backend_name;\n    std::string op_name;\n    std::string op_params;\n    std::string test_mode;\n    bool        supported;\n    bool        passed;\n    std::string error_message;\n    double      time_us;\n    double      flops;\n    double      bandwidth_gb_s;\n    size_t      memory_kb;\n    int         n_runs;\n    std::string device_description;\n    std::string backend_reg_name;\n\n    test_result() 
{\n        // Initialize with default values\n        time_us        = 0.0;\n        flops          = 0.0;\n        bandwidth_gb_s = 0.0;\n        memory_kb      = 0;\n        n_runs         = 0;\n        supported      = false;\n        passed         = false;\n\n        // Set test time\n        time_t t = time(NULL);\n        char   buf[32];\n        std::strftime(buf, sizeof(buf), \"%FT%TZ\", gmtime(&t));\n        test_time = buf;\n\n        // Set build info\n        build_commit = ggml_commit();\n    }\n\n    test_result(const std::string & backend_name, const std::string & op_name, const std::string & op_params,\n                const std::string & test_mode, bool supported, bool passed, const std::string & error_message = \"\",\n                double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0, size_t memory_kb = 0,\n                int n_runs = 0, const std::string & device_description = \"\", const std::string & backend_reg_name = \"\") :\n        backend_name(backend_name),\n        op_name(op_name),\n        op_params(op_params),\n        test_mode(test_mode),\n        supported(supported),\n        passed(passed),\n        error_message(error_message),\n        time_us(time_us),\n        flops(flops),\n        bandwidth_gb_s(bandwidth_gb_s),\n        memory_kb(memory_kb),\n        n_runs(n_runs),\n        device_description(device_description),\n        backend_reg_name(backend_reg_name) {\n        // Set test time\n        time_t t = time(NULL);\n        char   buf[32];\n        std::strftime(buf, sizeof(buf), \"%FT%TZ\", gmtime(&t));\n        test_time = buf;\n\n        // Set build info\n        build_commit = ggml_commit();\n    }\n\n    static const std::vector<std::string> & get_fields() {\n        static const std::vector<std::string> fields = {\n            \"test_time\", \"build_commit\",  \"backend_name\", \"op_name\", \"op_params\",      \"test_mode\", \"supported\",\n            \"passed\",    \"error_message\", 
\"time_us\",      \"flops\",   \"bandwidth_gb_s\", \"memory_kb\", \"n_runs\",\n            \"device_description\", \"backend_reg_name\"\n        };\n        return fields;\n    }\n\n    enum field_type { STRING, BOOL, INT, FLOAT };\n\n    static field_type get_field_type(const std::string & field) {\n        if (field == \"supported\" || field == \"passed\") {\n            return BOOL;\n        }\n        if (field == \"memory_kb\" || field == \"n_runs\") {\n            return INT;\n        }\n        if (field == \"time_us\" || field == \"flops\" || field == \"bandwidth_gb_s\") {\n            return FLOAT;\n        }\n        return STRING;\n    }\n\n    std::vector<std::string> get_values() const {\n        return { test_time,\n                 build_commit,\n                 backend_name,\n                 op_name,\n                 op_params,\n                 test_mode,\n                 std::to_string(supported),\n                 std::to_string(passed),\n                 error_message,\n                 std::to_string(time_us),\n                 std::to_string(flops),\n                 std::to_string(bandwidth_gb_s),\n                 std::to_string(memory_kb),\n                 std::to_string(n_runs),\n                 device_description,\n                 backend_reg_name };\n    }\n};\n\n// Printer classes for different output formats\nenum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED };\n\nstruct test_operation_info {\n    std::string   op_name;\n    std::string   op_params;\n    std::string   backend_name;\n    test_status_t status = test_status_t::OK;\n    std::string   failure_reason;\n\n    // Additional information fields that were previously in separate structs\n    std::string error_component;\n    std::string error_details;\n\n    // Gradient info\n    int64_t     gradient_index = -1;\n    std::string gradient_param_name;\n    float       gradient_value = 0.0f;\n\n    // MAA error info\n    double maa_error     = 0.0;\n    double 
maa_threshold = 0.0;\n\n    // Flags for different types of information\n    bool has_error            = false;\n    bool has_gradient_info    = false;\n    bool has_maa_error        = false;\n    bool is_compare_failure   = false;\n    bool is_large_tensor_skip = false;\n\n    test_operation_info() = default;\n\n    test_operation_info(const std::string & op_name, const std::string & op_params, const std::string & backend_name,\n                        test_status_t status = test_status_t::OK, const std::string & failure_reason = \"\") :\n        op_name(op_name),\n        op_params(op_params),\n        backend_name(backend_name),\n        status(status),\n        failure_reason(failure_reason) {}\n\n    // Set error information\n    void set_error(const std::string & component, const std::string & details) {\n        has_error       = true;\n        error_component = component;\n        error_details   = details;\n        if (status == test_status_t::OK) {\n            status = test_status_t::FAIL;\n        }\n    }\n\n    // Set gradient information\n    void set_gradient_info(int64_t index, const std::string & param_name, float value) {\n        has_gradient_info   = true;\n        gradient_index      = index;\n        gradient_param_name = param_name;\n        gradient_value      = value;\n        if (status == test_status_t::OK) {\n            status = test_status_t::FAIL;\n        }\n    }\n\n    // Set MAA error information\n    void set_maa_error(double error, double threshold) {\n        has_maa_error = true;\n        maa_error     = error;\n        maa_threshold = threshold;\n        if (status == test_status_t::OK) {\n            status = test_status_t::FAIL;\n        }\n    }\n\n    // Set compare failure\n    void set_compare_failure() {\n        is_compare_failure = true;\n        if (status == test_status_t::OK) {\n            status = test_status_t::FAIL;\n        }\n    }\n\n    // Set large tensor skip\n    void set_large_tensor_skip() { 
is_large_tensor_skip = true; }\n};\n\nstruct test_summary_info {\n    size_t tests_passed;\n    size_t tests_total;\n    bool   is_backend_summary = false;  // true for backend summary, false for test summary\n\n    test_summary_info() = default;\n\n    test_summary_info(size_t tests_passed, size_t tests_total, bool is_backend_summary = false) :\n        tests_passed(tests_passed),\n        tests_total(tests_total),\n        is_backend_summary(is_backend_summary) {}\n};\n\nstruct testing_start_info {\n    size_t device_count;\n\n    testing_start_info() = default;\n\n    testing_start_info(size_t device_count) : device_count(device_count) {}\n};\n\nstruct backend_init_info {\n    size_t      device_index;\n    size_t      total_devices;\n    std::string device_name;\n    bool        skipped = false;\n    std::string skip_reason;\n    std::string description;\n    size_t      memory_total_mb = 0;\n    size_t      memory_free_mb  = 0;\n    bool        has_memory_info = false;\n\n    backend_init_info() = default;\n\n    backend_init_info(size_t device_index, size_t total_devices, const std::string & device_name, bool skipped = false,\n                      const std::string & skip_reason = \"\", const std::string & description = \"\",\n                      size_t memory_total_mb = 0, size_t memory_free_mb = 0, bool has_memory_info = false) :\n        device_index(device_index),\n        total_devices(total_devices),\n        device_name(device_name),\n        skipped(skipped),\n        skip_reason(skip_reason),\n        description(description),\n        memory_total_mb(memory_total_mb),\n        memory_free_mb(memory_free_mb),\n        has_memory_info(has_memory_info) {}\n};\n\nstruct backend_status_info {\n    std::string   backend_name;\n    test_status_t status;\n\n    backend_status_info() = default;\n\n    backend_status_info(const std::string & backend_name, test_status_t status) :\n        backend_name(backend_name),\n        status(status) {}\n};\n\nstruct 
overall_summary_info {\n    size_t backends_passed;\n    size_t backends_total;\n    bool   all_passed;\n\n    overall_summary_info() = default;\n\n    overall_summary_info(size_t backends_passed, size_t backends_total, bool all_passed) :\n        backends_passed(backends_passed),\n        backends_total(backends_total),\n        all_passed(all_passed) {}\n};\n\nstruct printer {\n    virtual ~printer() {}\n\n    FILE * fout = stdout;\n\n    virtual void print_header() {}\n\n    virtual void print_test_result(const test_result & result) = 0;\n\n    virtual void print_footer() {}\n\n    virtual void print_operation(const test_operation_info & info) { (void) info; }\n\n    virtual void print_summary(const test_summary_info & info) { (void) info; }\n\n    virtual void print_testing_start(const testing_start_info & info) { (void) info; }\n\n    virtual void print_backend_init(const backend_init_info & info) { (void) info; }\n\n    virtual void print_backend_status(const backend_status_info & info) { (void) info; }\n\n    virtual void print_overall_summary(const overall_summary_info & info) { (void) info; }\n\n    virtual void print_failed_tests(const std::vector<std::string> & failed_tests) { (void) failed_tests; }\n};\n\nstruct console_printer : public printer {\n    void print_test_result(const test_result & result) override {\n        if (result.test_mode == \"test\") {\n            print_test_console(result);\n        } else if (result.test_mode == \"perf\") {\n            print_perf_console(result);\n        } else if (result.test_mode == \"support\") {\n            print_support_console(result);\n        }\n    }\n\n    void print_operation(const test_operation_info & info) override {\n        printf(\"  %s(%s): \", info.op_name.c_str(), info.op_params.c_str());\n        fflush(stdout);\n\n        // Handle large tensor skip first\n        if (info.is_large_tensor_skip) {\n            printf(\"skipping large tensors for speed \\n\");\n            return;\n        
}\n\n        // Handle not supported status\n        if (info.status == test_status_t::NOT_SUPPORTED) {\n            if (!info.failure_reason.empty()) {\n                printf(\"not supported [%s]\\n\", info.failure_reason.c_str());\n            } else {\n                printf(\"not supported [%s]\\n\", info.backend_name.c_str());\n            }\n            return;\n        }\n\n        // Handle errors and additional information\n        if (info.has_error) {\n            if (info.error_component == \"allocation\") {\n                fprintf(stderr, \"failed to allocate tensors [%s] \", info.backend_name.c_str());\n            } else if (info.error_component == \"backend\") {\n                fprintf(stderr, \"  Failed to initialize %s backend\\n\", info.backend_name.c_str());\n            } else {\n                fprintf(stderr, \"Error in %s: %s\\n\", info.error_component.c_str(), info.error_details.c_str());\n            }\n        }\n\n        // Handle gradient info\n        if (info.has_gradient_info) {\n            printf(\"[%s] nonfinite gradient at index %\" PRId64 \" (%s=%f) \", info.op_name.c_str(), info.gradient_index,\n                   info.gradient_param_name.c_str(), info.gradient_value);\n        }\n\n        // Handle MAA error\n        if (info.has_maa_error) {\n            printf(\"[%s] MAA = %.9f > %.9f \", info.op_name.c_str(), info.maa_error, info.maa_threshold);\n        }\n\n        // Handle compare failure\n        if (info.is_compare_failure) {\n            printf(\"compare failed \");\n        }\n\n        // Print final status\n        if (info.status == test_status_t::OK) {\n            printf(\"\\033[1;32mOK\\033[0m\\n\");\n        } else {\n            printf(\"\\033[1;31mFAIL\\033[0m\\n\");\n        }\n    }\n\n    void print_summary(const test_summary_info & info) override {\n        if (info.is_backend_summary) {\n            printf(\"%zu/%zu backends passed\\n\", info.tests_passed, info.tests_total);\n        } else {\n    
        printf(\"  %zu/%zu tests passed\\n\", info.tests_passed, info.tests_total);\n        }\n    }\n\n    void print_backend_status(const backend_status_info & info) override {\n        printf(\"  Backend %s: \", info.backend_name.c_str());\n        if (info.status == test_status_t::OK) {\n            printf(\"\\033[1;32mOK\\033[0m\\n\");\n        } else {\n            printf(\"\\033[1;31mFAIL\\033[0m\\n\");\n        }\n    }\n\n    void print_testing_start(const testing_start_info & info) override {\n        printf(\"Testing %zu devices\\n\\n\", info.device_count);\n    }\n\n    void print_backend_init(const backend_init_info & info) override {\n        printf(\"Backend %zu/%zu: %s\\n\", info.device_index + 1, info.total_devices, info.device_name.c_str());\n\n        if (info.skipped) {\n            printf(\"  %s\\n\", info.skip_reason.c_str());\n            return;\n        }\n\n        if (!info.description.empty()) {\n            printf(\"  Device description: %s\\n\", info.description.c_str());\n        }\n\n        if (info.has_memory_info) {\n            printf(\"  Device memory: %zu MB (%zu MB free)\\n\", info.memory_total_mb, info.memory_free_mb);\n        }\n\n        printf(\"\\n\");\n    }\n\n    void print_overall_summary(const overall_summary_info & info) override {\n        printf(\"%zu/%zu backends passed\\n\", info.backends_passed, info.backends_total);\n        if (info.all_passed) {\n            printf(\"\\033[1;32mOK\\033[0m\\n\");\n        } else {\n            printf(\"\\033[1;31mFAIL\\033[0m\\n\");\n        }\n    }\n\n    void print_failed_tests(const std::vector<std::string> & failed_tests) override {\n        if (failed_tests.empty()) {\n            return;\n        }\n\n        printf(\"\\nFailing tests:\\n\");\n        for (const auto & test_name : failed_tests) {\n            printf(\"  %s\\n\", test_name.c_str());\n        }\n    }\n\n  private:\n    void print_test_console(const test_result & result) {\n        printf(\"  %s(%s): 
\", result.op_name.c_str(), result.op_params.c_str());\n        fflush(stdout);\n\n        if (!result.supported) {\n            printf(\"not supported [%s] \", result.backend_name.c_str());\n            printf(\"\\n\");\n            return;\n        }\n\n        if (result.passed) {\n            printf(\"\\033[1;32mOK\\033[0m\\n\");\n        } else {\n            printf(\"\\033[1;31mFAIL\\033[0m\\n\");\n        }\n    }\n\n    void print_perf_console(const test_result & result) {\n        int len = printf(\"  %s(%s): \", result.op_name.c_str(), result.op_params.c_str());\n        fflush(stdout);\n\n        if (!result.supported) {\n            printf(\"not supported\\n\");\n            return;\n        }\n\n        // align while also leaving some margin for variations in parameters\n        int align = 8;\n        int last  = (len + align - 1) / align * align;\n        if (last - len < 5) {\n            last += align;\n        }\n        printf(\"%*s\", last - len, \"\");\n\n        printf(\"    %8d runs - %8.2f us/run - \", result.n_runs, result.time_us);\n\n        if (result.flops > 0) {\n            auto format_flops = [](double flops) -> std::string {\n                char buf[256];\n                if (flops >= 1e12) {\n                    snprintf(buf, sizeof(buf), \"%6.2f TFLOP\", flops / 1e12);\n                } else if (flops >= 1e9) {\n                    snprintf(buf, sizeof(buf), \"%6.2f GFLOP\", flops / 1e9);\n                } else if (flops >= 1e6) {\n                    snprintf(buf, sizeof(buf), \"%6.2f MFLOP\", flops / 1e6);\n                } else {\n                    snprintf(buf, sizeof(buf), \"%6.2f kFLOP\", flops / 1e3);\n                }\n                return buf;\n            };\n            uint64_t op_flops_per_run = result.flops * result.time_us / 1e6;\n            printf(\"%s/run - \\033[1;34m%sS\\033[0m\", format_flops(op_flops_per_run).c_str(),\n                   format_flops(result.flops).c_str());\n        } else {\n       
     printf(\"%8zu kB/run - \\033[1;34m%7.2f GB/s\\033[0m\", result.memory_kb, result.bandwidth_gb_s);\n        }\n        printf(\"\\n\");\n    }\n\n    void print_support_console(const test_result & result) {\n        printf(\"  %s(%s): \", result.op_name.c_str(), result.op_params.c_str());\n        fflush(stdout);\n\n        if (result.supported) {\n            printf(\"\\033[1;32mSUPPORTED\\033[0m\\n\");\n        } else {\n            printf(\"\\033[1;31mNOT SUPPORTED\\033[0m\\n\");\n        }\n    }\n};\n\nstruct sql_printer : public printer {\n    static std::string get_sql_field_type(const std::string & field) {\n        switch (test_result::get_field_type(field)) {\n            case test_result::STRING:\n                return \"TEXT\";\n            case test_result::BOOL:\n            case test_result::INT:\n                return \"INTEGER\";\n            case test_result::FLOAT:\n                return \"REAL\";\n            default:\n                GGML_ABORT(\"invalid field type\");\n        }\n    }\n\n    void print_header() override {\n        std::vector<std::string> fields = test_result::get_fields();\n        fprintf(fout, \"CREATE TABLE IF NOT EXISTS test_backend_ops (\\n\");\n        for (size_t i = 0; i < fields.size(); i++) {\n            fprintf(fout, \"  %s %s%s\\n\", fields[i].c_str(), get_sql_field_type(fields[i]).c_str(),\n                    i < fields.size() - 1 ? \",\" : \"\");\n        }\n        fprintf(fout, \");\\n\\n\");\n    }\n\n    void print_test_result(const test_result & result) override {\n        fprintf(fout, \"INSERT INTO test_backend_ops (\");\n        std::vector<std::string> fields = test_result::get_fields();\n        for (size_t i = 0; i < fields.size(); i++) {\n            fprintf(fout, \"%s%s\", fields[i].c_str(), i < fields.size() - 1 ? 
\", \" : \"\");\n        }\n        fprintf(fout, \") VALUES (\");\n        std::vector<std::string> values = result.get_values();\n        for (size_t i = 0; i < values.size(); i++) {\n            fprintf(fout, \"'%s'%s\", values[i].c_str(), i < values.size() - 1 ? \", \" : \"\");\n        }\n        fprintf(fout, \");\\n\");\n    }\n};\n\nstruct csv_printer : public printer {\n    void print_header() override {\n\n        std::vector<std::string> fields     = test_result::get_fields();\n        std::vector<std::string> fields_csv = get_fields_csv();\n        for (size_t i = 0; i < fields.size(); i++) {\n            if (std::find(std::begin(fields_csv), std::end(fields_csv), fields[i]) == std::end(fields_csv)) {\n                continue;\n            }\n            printf(\"\\\"%s\\\"%s\", fields[i].c_str(), i < fields.size() - 1 ? \",\" : \"\");\n        }\n        printf(\"\\n\");\n    }\n\n    void print_test_result(const test_result & result) override {\n\n        std::vector<std::string> values     = result.get_values();\n        std::vector<std::string> fields     = test_result::get_fields();\n        std::vector<std::string> fields_csv = get_fields_csv();\n\n        for (size_t i = 0; i < values.size(); i++) {\n\n            if (std::find(std::begin(fields_csv), std::end(fields_csv), fields[i]) == std::end(fields_csv)) {\n                continue;\n            }\n\n            // Escape quotes and wrap in quotes for CSV\n            std::string escaped_value = values[i];\n            size_t pos = 0;\n            while ((pos = escaped_value.find(\"\\\"\", pos)) != std::string::npos) {\n                escaped_value.replace(pos, 1, \"\\\"\\\"\");\n                pos += 2;\n            }\n            printf(\"\\\"%s\\\"%s\", escaped_value.c_str(), i < values.size() - 1 ? 
\",\" : \"\");\n        }\n        printf(\"\\n\");\n    }\n\n    static std::vector<std::string> get_fields_csv() {\n        return {\n            \"op_name\",\n            \"op_params\",\n            \"supported\",\n            \"error_message\",\n            \"test_mode\",\n            \"backend_reg_name\",\n            \"backend_name\",\n        };\n    }\n\n};\n\nstatic std::unique_ptr<printer> create_printer(output_formats format) {\n    switch (format) {\n        case CONSOLE:\n            return std::make_unique<console_printer>();\n        case SQL:\n            return std::make_unique<sql_printer>();\n        case CSV:\n            return std::make_unique<csv_printer>();\n    }\n    GGML_ABORT(\"invalid output format\");\n}\n\nstruct test_case {\n    virtual ~test_case() {}\n\n    virtual std::string op_desc(ggml_tensor * t) {\n        return ggml_op_desc(t);\n    }\n\n    virtual std::string vars() {\n        return \"\";\n    }\n\n    virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;\n\n    virtual double max_nmse_err() {\n        return 1e-7;\n    }\n\n    virtual double max_nmse_err(ggml_backend_t backend) {\n        GGML_UNUSED(backend);\n        return max_nmse_err();\n    }\n\n    virtual double max_maa_err() {\n        return 1e-4;\n    }\n\n    virtual double max_err() {\n        return max_nmse_err();\n    }\n\n    virtual double max_err(ggml_backend_t backend) {\n        return max_nmse_err(backend);\n    }\n\n    virtual double err(const float * a, const float * b, size_t n) {\n        return nmse(a, b, n);\n    }\n\n    virtual float grad_eps() {\n        return 1e-1f;\n    }\n\n    // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher.\n    // If true,  estimate gradient with 4 points, neglects 5th order derivative and higher.\n    virtual bool grad_precise() {\n        return false;\n    }\n\n    // Skip gradient checks if total number of gradients to be checked is larger than this (to speed up the 
tests).\n    virtual int64_t grad_nmax() {\n        return 10000;\n    }\n\n    // No effect if empty.\n    // If not empty, skip all gradient checks where the numerical result does not match any of the values.\n    // Needed for dealing with noncontinuous gradients (e.g. ReLU) where estimation using finite differences is unreliable.\n    virtual std::vector<float> grad_expect() {\n        return {};\n    }\n\n    virtual void initialize_tensors(ggml_context * ctx) {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t);\n        }\n    }\n\n    virtual size_t op_size(ggml_tensor * t) {\n        size_t size = ggml_nbytes(t);\n        // add source tensors\n        for (int i = 0; i < GGML_MAX_SRC; i++) {\n            if (t->src[i] != NULL) {\n                size += ggml_nbytes(t->src[i]);\n            }\n        }\n        return size;\n    }\n\n    virtual uint64_t op_flops(ggml_tensor * t) {\n        GGML_UNUSED(t);\n        return 0;\n    }\n\n    virtual bool run_whole_graph() { return false; }\n    virtual std::vector<ggml_tensor *> fusion_test_nodes() { return {}; }\n\n    ggml_cgraph * gf = nullptr;\n    ggml_cgraph * gb = nullptr;\n\n    static const int sentinel_size = 1024;\n\n    test_mode mode;\n\n    std::vector<ggml_tensor *> sentinels;\n\n    std::string current_op_name;\n\n    void add_sentinel(ggml_context * ctx) {\n        if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {\n            return;\n        }\n        ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);\n        ggml_format_name(sentinel, \"sent_%zu\", sentinels.size());\n        sentinels.push_back(sentinel);\n    }\n\n    // hijack ggml_new_tensor to add sentinels after each tensor to check for overflows in the backend\n\n    ggml_tensor * ggml_new_tensor(ggml_context * ctx, ggml_type type, int n_dims, const int64_t * ne) {\n        ggml_tensor 
* t = ::ggml_new_tensor(ctx, type, n_dims, ne);\n        add_sentinel(ctx);\n        return t;\n    }\n\n    ggml_tensor * ggml_new_tensor_1d(ggml_context * ctx, ggml_type type, int64_t ne0) {\n        ggml_tensor * t = ::ggml_new_tensor_1d(ctx, type, ne0);\n        add_sentinel(ctx);\n        return t;\n    }\n\n    ggml_tensor * ggml_new_tensor_2d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1) {\n        ggml_tensor * t = ::ggml_new_tensor_2d(ctx, type, ne0, ne1);\n        add_sentinel(ctx);\n        return t;\n    }\n\n    ggml_tensor * ggml_new_tensor_3d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) {\n        ggml_tensor * t = ::ggml_new_tensor_3d(ctx, type, ne0, ne1, ne2);\n        add_sentinel(ctx);\n        return t;\n    }\n\n    ggml_tensor * ggml_new_tensor_4d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {\n        ggml_tensor * t = ::ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);\n        add_sentinel(ctx);\n        return t;\n    }\n\n    // Checks an op against the test filter, which is a comma separated list of OP names or specific variations\n    bool matches_filter(ggml_tensor * op, const char * op_names_filter) {\n        if (op_names_filter) {\n            const auto op_name = op_desc(op);\n            const auto op_full_name = op_name + \"(\" + vars() + \")\";\n            std::string_view filter(op_names_filter);\n            while (!filter.empty()) {\n                auto comma_pos = filter.find_first_of(',');\n                const auto lparen_pos = filter.find_first_of('(');\n                if (lparen_pos < comma_pos) {\n                    auto rparen_pos = filter.find_first_of(')');\n                    comma_pos = filter.find_first_of(',', rparen_pos);\n                    const auto op_filter = filter.substr(0, comma_pos);\n                    if (op_filter == op_full_name) {\n                        return true;\n                    }\n           
     } else {\n                    const auto op_filter = filter.substr(0, comma_pos);\n                    if (op_filter == op_name) {\n                        return true;\n                    }\n                }\n                filter = comma_pos != std::string_view::npos ? filter.substr(comma_pos + 1) : \"\";\n            }\n            return false;\n        } else {\n            return true;\n        }\n    }\n\n    test_status_t eval(ggml_backend_t backend1,\n                       ggml_backend_t backend2,\n                       const char *   op_names_filter,\n                       printer *      output_printer) {\n        mode = MODE_TEST;\n\n        ggml_init_params params = {\n            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),\n            /* .mem_base = */ NULL,\n            /* .no_alloc = */ true,\n        };\n        ggml_context * ctx = ggml_init(params);\n        GGML_ASSERT(ctx);\n\n        gf = ggml_new_graph(ctx);\n\n        // pre-graph sentinel\n        add_sentinel(ctx);\n\n        ggml_tensor * out = build_graph(ctx);\n        current_op_name   = op_desc(out);\n\n        if (!matches_filter(out, op_names_filter)) {\n            //printf(\"  %s: skipping\\n\", op_desc(out).c_str());\n            ggml_free(ctx);\n            return test_status_t::SKIPPED;\n        }\n\n        // check if the backends support the ops\n        bool supported = true;\n        for (ggml_backend_t backend : {backend1, backend2}) {\n            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n                if (!ggml_backend_supports_op(backend, t)) {\n                    supported = false;\n                    break;\n                }\n            }\n        }\n\n        if (!supported) {\n            // Create test result for unsupported operation\n            test_result result(ggml_backend_name(backend1), current_op_name, vars(), \"test\",\n                             false, 
false, \"not supported\");\n\n            if (output_printer) {\n                output_printer->print_test_result(result);\n            }\n\n            ggml_free(ctx);\n            return test_status_t::NOT_SUPPORTED;\n        }\n\n        // post-graph sentinel\n        add_sentinel(ctx);\n\n        // allocate\n        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);\n\n        if (buf == NULL) {\n            printf(\"failed to allocate tensors [%s] \", ggml_backend_name(backend1));\n            ggml_free(ctx);\n            return test_status_t::FAIL;\n        }\n\n        // build graph\n        ggml_build_forward_expand(gf, out);\n\n        // add sentinels as graph nodes so that they are checked in the callback\n        for (ggml_tensor * sentinel : sentinels) {\n            ggml_graph_add_node(gf, sentinel);\n        }\n\n        // randomize tensors\n        initialize_tensors(ctx);\n\n        // compare\n        struct callback_userdata {\n            bool   ok;\n            test_case * tc;\n            ggml_backend_t backend1;\n            ggml_backend_t backend2;\n        };\n\n        callback_userdata ud {\n            true,\n            this,\n            backend1,\n            backend2,\n        };\n\n        auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {\n            callback_userdata * ud = (callback_userdata *) user_data;\n            const char * bn1 = ggml_backend_name(ud->backend1);\n            const char * bn2 = ggml_backend_name(ud->backend2);\n\n            if (t1->op == GGML_OP_NONE) {\n                // sentinels must be unchanged\n                std::vector<uint8_t> t1_data(ggml_nbytes(t1));\n                std::vector<uint8_t> t2_data(ggml_nbytes(t2));\n                ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));\n                ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));\n\n                if (memcmp(t1_data.data(), 
t2_data.data(), ggml_nbytes(t1)) != 0) {\n                    printf(\"sentinel mismatch: %s \", t1->name);\n                    ud->ok = false;\n                    return true;\n                }\n            }\n\n            std::vector<float> f1 = tensor_to_float(t1);\n            std::vector<float> f2 = tensor_to_float(t2);\n\n            for (size_t i = 0; i < f1.size(); i++) {\n                // check for nans\n                if (std::isnan(f1[i]) || std::isnan(f2[i])) {\n                    printf(\"[%s] NaN at index %zu (%s=%f %s=%f) \", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]);\n                    ud->ok = false;\n                    return true;\n                }\n                // check for infs: both must be inf of the same sign, or both must be finite\n                if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {\n                    if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {\n                        if (std::signbit(f1[i]) != std::signbit(f2[i])) {\n                            printf(\"[%s] inf sign mismatch: %s=%f %s=%f \", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);\n                            ud->ok = false;\n                            return true;\n                        }\n                    } else {\n                        printf(\"[%s] inf mismatch: %s=%f %s=%f \", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);\n                        ud->ok = false;\n                        return true;\n                    }\n                }\n            }\n\n            double err = ud->tc->err(f1.data(), f2.data(), f1.size());\n            if (err > ud->tc->max_err(ud->backend1)) {\n                printf(\"[%s] ERR = %.9f > %.9f \", ggml_op_desc(t1), err, ud->tc->max_err(ud->backend1));\n                //for (int i = 0; i < (int) f1.size(); i++) {\n                //    printf(\"%5d %9.6f %9.6f, diff = %9.6f\\n\", i, f1[i], f2[i], f1[i] - f2[i]);\n                //}\n                //printf(\"\\n\");\n                //exit(1);\n 
               ud->ok = false;\n            }\n            return true;\n\n            GGML_UNUSED(index);\n        };\n\n        std::vector<ggml_tensor *> fused_nodes_to_verify = fusion_test_nodes();\n        if (fused_nodes_to_verify.size() == 0 && run_whole_graph()) {\n            fused_nodes_to_verify.push_back(out);\n        }\n        const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud,\n                                                               run_whole_graph() ? fused_nodes_to_verify.data() : nullptr,\n                                                               fused_nodes_to_verify.size());\n\n        ggml_backend_buffer_free(buf);\n\n        ggml_free(ctx);\n\n        // Create test result\n        bool        test_passed = ud.ok && cmp_ok;\n        std::string error_msg   = test_passed ? \"\" : (!cmp_ok ? \"compare failed\" : \"test failed\");\n        test_result result(ggml_backend_name(backend1), current_op_name, vars(), \"test\", supported, test_passed,\n                           error_msg);\n\n        if (output_printer) {\n            output_printer->print_test_result(result);\n        }\n\n        return test_passed ? 
test_status_t::OK : test_status_t::FAIL;\n    }\n\n    bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {\n        mode = MODE_PERF;\n\n        static const size_t graph_nodes = 8192;\n\n        ggml_init_params params = {\n            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),\n            /* .mem_base = */ NULL,\n            /* .no_alloc = */ true,\n        };\n        ggml_context_ptr ctx(ggml_init(params)); // smart ptr\n        GGML_ASSERT(ctx);\n\n        ggml_tensor * out             = build_graph(ctx.get());\n        current_op_name               = op_desc(out);\n        if (!matches_filter(out, op_names_filter)) {\n            //printf(\"  %s: skipping\\n\", op_desc(out).c_str());\n            return true;\n        }\n\n        if (!ggml_backend_supports_op(backend, out)) {\n            // Create test result for unsupported performance test\n            test_result result(ggml_backend_name(backend), current_op_name, vars(), \"perf\", false, false,\n                               \"not supported\");\n\n            output_printer->print_test_result(result);\n\n            return true;\n        }\n\n        // allocate\n        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr\n\n        if (buf == NULL) {\n            printf(\"failed to allocate tensors\\n\");\n            return false;\n        }\n\n        // randomize tensors\n        initialize_tensors(ctx.get());\n\n        // build graph\n        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);\n        ggml_build_forward_expand(gf, out);\n\n        // warmup run\n        ggml_status status = ggml_backend_graph_compute(backend, gf);\n        if (status != GGML_STATUS_SUCCESS) {\n            fprintf(stderr, \"%s: ggml_backend_graph_compute failed. 
status=%s \\n\", __func__, ggml_status_to_string(status));\n            return false;\n        }\n\n        // determine number of runs\n        int n_runs;\n        bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU;\n        if (op_flops(out) > 0) {\n            // based on flops\n            const uint64_t GFLOP = 1000 * 1000 * 1000;\n            const uint64_t target_flops_cpu =   8ULL * GFLOP;\n            const uint64_t target_flops_gpu = 100ULL * GFLOP;\n            uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;\n            n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;\n        } else {\n            // based on memory size\n            const size_t GB = 1ULL << 30;\n            const size_t target_size_cpu =  8 * GB;\n            const size_t target_size_gpu = 32 * GB;\n            size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;\n            n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;\n        }\n\n        // duplicate the op\n        for (int i = 1; i < n_runs; i++) {\n            ggml_graph_add_node(gf, out);\n        }\n\n        // calculate memory\n        size_t mem = n_runs * op_size(out);\n        auto tensor_op_size = [](ggml_tensor * t) {\n            size_t size = ggml_nbytes(t);\n            // add source tensors\n            for (int i = 0; i < GGML_MAX_SRC; i++) {\n                if (t->src[i] != NULL) {\n                    size += ggml_nbytes(t->src[i]);\n                }\n            }\n            return size;\n        };\n        for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {\n            if (ggml_is_view_op(ggml_graph_node(gf, i)->op) || ggml_graph_node(gf, i) == out) {\n                continue;\n            }\n            mem += tensor_op_size(ggml_graph_node(gf, i));\n        }\n\n        // run\n        
int64_t total_time_us = 0;\n        int64_t total_mem = 0;\n        int total_runs = 0;\n        do {\n            int64_t start_time = ggml_time_us();\n            ggml_status status = ggml_backend_graph_compute(backend, gf);\n            if (status != GGML_STATUS_SUCCESS) {\n                fprintf(stderr, \"%s: ggml_backend_graph_compute failed. status=%s \\n\", __func__, ggml_status_to_string(status));\n                return false;\n            }\n            int64_t end_time = ggml_time_us();\n\n            total_time_us += end_time - start_time;\n            total_mem += mem;\n            total_runs += n_runs;\n        } while (total_time_us < 1000*1000); // run for at least 1 second\n\n        // Create test result\n        double avg_time_us      = (double) total_time_us / total_runs;\n        double calculated_flops = (op_flops(out) > 0) ? (op_flops(out) * total_runs) / (total_time_us / 1e6) : 0.0;\n        double calculated_bandwidth =\n            (op_flops(out) == 0) ? total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0 : 0.0;\n        size_t calculated_memory_kb = op_size(out) / 1024;\n\n        test_result result(ggml_backend_name(backend), current_op_name, vars(), \"perf\", true, true, \"\", avg_time_us,\n                           calculated_flops, calculated_bandwidth, calculated_memory_kb, total_runs);\n\n        if (output_printer) {\n            output_printer->print_test_result(result);\n        }\n\n        return true;\n    }\n\n    bool eval_support(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {\n        mode = MODE_SUPPORT;\n\n        static const size_t graph_nodes = 8192;\n\n        ggml_init_params params = {\n            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),\n            /* .mem_base = */ NULL,\n            /* .no_alloc = */ true,\n        };\n        ggml_context_ptr ctx(ggml_init(params)); // smart ptr\n        GGML_ASSERT(ctx);\n\n   
     gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);\n\n        ggml_tensor * out = build_graph(ctx.get());\n        current_op_name   = op_desc(out);\n\n        if (!matches_filter(out, op_names_filter)) {\n            return true;\n        }\n\n        bool supported = ggml_backend_supports_op(backend, out);\n\n        std::string device_desc = ggml_backend_dev_description(ggml_backend_get_device(backend));\n        std::string backend_reg_name = ggml_backend_reg_name(ggml_backend_dev_backend_reg(ggml_backend_get_device(backend)));\n\n        test_result result(ggml_backend_name(backend), current_op_name, vars(), \"support\", supported, supported,\n                           supported ? \"yes\" : \"no\", 0.0, 0.0, 0.0, 0, 0, device_desc, backend_reg_name);\n\n        output_printer->print_test_result(result);\n\n        return true;\n    }\n\n    bool eval_grad(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {\n        mode = MODE_GRAD;\n        const std::vector<float> expect = grad_expect();\n\n        ggml_init_params params = {\n            /* .mem_size = */ ggml_tensor_overhead()*128 + 2*ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, true),\n            /* .mem_base = */ NULL,\n            /* .no_alloc = */ true,\n        };\n        ggml_context_ptr ctx(ggml_init(params)); // smart ptr\n        GGML_ASSERT(ctx);\n\n        gf = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);\n        gb = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);\n\n        ggml_tensor * out = build_graph(ctx.get());\n\n        if (!matches_filter(out, op_names_filter) || out->op == GGML_OP_OPT_STEP_ADAMW) {\n            return true;\n        }\n\n        if (out->type != GGML_TYPE_F32) {\n            output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend),\n                                                                test_status_t::NOT_SUPPORTED,\n         
                                                       out->name + std::string(\"->type != FP32\")));\n            return true;\n        }\n\n        // Print operation info first\n        output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend)));\n\n        // check if the backend supports the ops\n        bool        supported  = true;\n        bool        any_params = false;\n        std::string failure_reason;\n\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {\n            if (!ggml_backend_supports_op(backend, t)) {\n                supported      = false;\n                failure_reason = ggml_backend_name(backend);\n                break;\n            }\n            if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {\n                any_params = true;\n                if (t->type != GGML_TYPE_F32) {\n                    supported      = false;\n                    failure_reason = std::string(t->name) + \"->type != FP32\";\n                    break;\n                }\n            }\n        }\n        if (!any_params) {\n            supported      = false;\n            failure_reason = op_desc(out);\n        }\n\n        if (!supported) {\n            output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend),\n                                                                test_status_t::NOT_SUPPORTED, failure_reason));\n            return true;\n        }\n\n        int64_t ngrads = 0;\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {\n            if (t->flags & GGML_TENSOR_FLAG_PARAM) {\n                ngrads += ggml_nelements(t);\n            }\n        }\n        if (ngrads > grad_nmax()) {\n            test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));\n            info.set_large_tensor_skip();\n            
output_printer->print_operation(info);\n            return true;\n        }\n\n\n        if (!ggml_is_scalar(out)) {\n            out = ggml_sum(ctx.get(), out);\n            ggml_set_name(out, \"sum_of_out\");\n        }\n        ggml_set_loss(out);\n\n        ggml_build_forward_expand(gf, out);\n        ggml_graph_cpy(gf, gb);\n        ggml_build_backward_expand(ctx.get(), gb, nullptr);\n        if (expect.size() != 1 || expect[0] != 0.0f) {\n            GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));\n            for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {\n                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);\n            }\n        }\n\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {\n            if (!ggml_backend_supports_op(backend, t)) {\n                output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend),\n                                                                    test_status_t::NOT_SUPPORTED,\n                                                                    ggml_backend_name(backend)));\n                supported = false;\n                break;\n            }\n            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {\n                output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend),\n                                                                    test_status_t::NOT_SUPPORTED,\n                                                                    std::string(t->name) + \"->type != FP32\"));\n                supported = false;\n                break;\n            }\n        }\n        if (!supported) {\n            return true;\n        }\n\n        // allocate\n        ggml_backend_buffer_ptr 
buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr\n        if (buf == NULL) {\n            test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));\n            info.set_error(\"allocation\", \"\");\n            output_printer->print_operation(info);\n            return false;\n        }\n\n        initialize_tensors(ctx.get()); // Randomizes all tensors (including gradients).\n        ggml_graph_reset(gb);    // Sets gradients to 1 if loss, 0 otherwise.\n\n        ggml_status status = ggml_backend_graph_compute(backend, gf);\n        if (status != GGML_STATUS_SUCCESS) {\n            fprintf(stderr, \"%s: ggml_backend_graph_compute failed. status=%s \\n\", __func__, ggml_status_to_string(status));\n            return false;\n        }\n        status = ggml_backend_graph_compute(backend, gb);\n        if (status != GGML_STATUS_SUCCESS) {\n            fprintf(stderr, \"%s: ggml_backend_graph_compute failed. status=%s \\n\", __func__, ggml_status_to_string(status));\n            return false;\n        }\n\n        bool ok = true;\n        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {\n            if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {\n                continue;\n            }\n\n            const char * bn = ggml_backend_name(backend);\n            const int64_t ne = ggml_nelements(t);\n\n            std::vector<float> ga;\n            struct ggml_tensor * grad = ggml_graph_get_grad(gb, t);\n            if (grad) {\n                ga = tensor_to_float(grad);\n            } else {\n                ga.resize(ne); // default value is 0.0f\n            }\n\n            for (int64_t i = 0; i < ne; ++i) { // gradient algebraic\n                // check for nans\n                if (!std::isfinite(ga[i])) {\n                    test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));\n                    info.set_gradient_info(i, bn, 
ga[i]);\n                    output_printer->print_operation(info);\n                    ok = false;\n                    break;\n                }\n            }\n            if (!ok) {\n                break;\n            }\n\n            std::vector<float> gn(ne); // gradient numeric\n            GGML_ASSERT(ga.size() == gn.size());\n\n            std::vector<float> x0 = tensor_to_float(t); // original t data\n            GGML_ASSERT(ggml_is_scalar(out));\n            GGML_ASSERT(out->type == GGML_TYPE_F32);\n\n            const float eps = grad_eps();\n            for (int64_t i = 0; i < ne; ++i) {\n                const float xiu  = x0[i] + 1.0f*eps; // x, index i, up\n                const float xiuh = x0[i] + 0.5f*eps; // x, index i, up half\n                const float xidh = x0[i] - 0.5f*eps; // x, index i, down half\n                const float xid  = x0[i] - 1.0f*eps; // x, index i, down\n\n                float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh\n\n                ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));\n                status = ggml_backend_graph_compute(backend, gf);\n                if (status != GGML_STATUS_SUCCESS) {\n                    fprintf(stderr, \"%s: ggml_backend_graph_compute failed. status=%s \\n\", __func__, ggml_status_to_string(status));\n                    return false;\n                }\n                ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));\n\n                ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));\n                status = ggml_backend_graph_compute(backend, gf);\n                if (status != GGML_STATUS_SUCCESS) {\n                    fprintf(stderr, \"%s: ggml_backend_graph_compute failed. 
status=%s \\n\", __func__, ggml_status_to_string(status));\n                    return false;\n                }\n                ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));\n\n                if (grad_precise()) {\n                    ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));\n                    status = ggml_backend_graph_compute(backend, gf);\n                    if (status != GGML_STATUS_SUCCESS) {\n                        fprintf(stderr, \"%s: ggml_backend_graph_compute failed. status=%s \\n\", __func__, ggml_status_to_string(status));\n                        return false;\n                    }\n                    ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));\n\n                    ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));\n                    status = ggml_backend_graph_compute(backend, gf);\n                    if (status != GGML_STATUS_SUCCESS) {\n                        fprintf(stderr, \"%s: ggml_backend_graph_compute failed. 
status=%s \\n\", __func__, ggml_status_to_string(status));\n                        return false;\n                    }\n                    ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));\n\n                    gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);\n                } else {\n                    gn[i] = (fu - fd) / (2.0f*eps);\n                }\n\n                ggml_backend_tensor_set(t, x0.data(), 0, ggml_nbytes(t));\n            }\n\n            const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect);\n            if (err > max_maa_err()) {\n                test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));\n                info.set_maa_error(err, max_maa_err());\n                output_printer->print_operation(info);\n                ok = false;\n                break;\n            }\n            if (!ok) {\n                break;\n            }\n        }\n\n        // Create final test result\n        test_operation_info final_info(op_desc(out), vars(), ggml_backend_name(backend));\n        if (!ok) {\n            final_info.set_compare_failure();\n        }\n        final_info.status = ok ? 
test_status_t::OK : test_status_t::FAIL;\n        output_printer->print_operation(final_info);\n\n        if (ok) {\n            return true;\n        }\n\n        return false;\n    }\n};\n\n\n// ####################################\n// ## Section 2: GGML Op Definitions ##\n// ####################################\n\n\n// The following is an example showing the bare minimum for creating a test for a GGML op.\n\n// GGML_OP_EXAMPLE\nstruct test_example : public test_case {\n    // Always define these 2 or variants thereof:\n    const ggml_type type; // The type of the input tensors.\n    const std::array<int64_t, 4> ne; // The shape of the input tensors.\n    // For some ops it's necessary to define multiple types or shapes for the inputs.\n    // Or they may need additional parameters.\n\n    // Put all parameters needed to fully define the test into one of the VARS_TO_STR macros.\n    // In most cases these are just the properties of the struct that you defined above.\n    // This is needed for info prints.\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    // Define a constructor for the struct.\n    // In most cases it will be sufficient to have the same arguments as the struct has properties\n    // and just use initializer lists.\n    test_example(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3})\n        : type(type), ne(ne) {}\n\n    // Define how a simple GGML compute graph can be constructed for the new GGML op.\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        // Step 1: create input tensors that don't depend on any other tensors:\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\"); // Setting names is optional but it's useful for debugging.\n\n        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(b, \"b\");\n\n        // Step 2: use the op that you want to test in the GGML 
compute graph.\n        ggml_tensor * out = ggml_add(ctx, a, b); // For this example we're just doing a simple addition.\n        ggml_set_name(out, \"out\");\n\n        // Step 3: return the output tensor.\n        return out;\n    }\n    // In order to also check the gradients for your op, add calls like ggml_set_param(a)\n    // immediately after you create the tensors.\n    // This is optional and only makes sense if a backward pass has actually been implemented for the new op.\n};\n\n\n// GGML_OP_UNARY\nstruct test_unary : public test_case {\n    const ggml_unary_op op;\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    int v; // view (1 : non-contiguous a)\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne_a, v);\n    }\n\n    test_unary(ggml_unary_op op,\n            ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},\n            int v = 0)\n        : op(op), type(type), ne_a(ne_a), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        const bool grad_supported = op == GGML_UNARY_OP_ABS || op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_NEG ||\n            op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU ||\n            op == GGML_UNARY_OP_EXPM1 || op == GGML_UNARY_OP_SOFTPLUS;\n\n        ggml_tensor * a;\n        if (v & 1) {\n            auto ne = ne_a;\n            ne[0] *= 3;\n            ne[1] *= 2;\n            ne[2] *= 5;\n            ne[3] *= 4;\n            a = ggml_new_tensor(ctx, type, 4, ne.data());\n            if (grad_supported) {\n                ggml_set_param(a);\n            }\n            ggml_set_name(a, \"a\");\n\n            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view_of_a\");\n        } else {\n            a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n            if (grad_supported) {\n                
ggml_set_param(a);\n            }\n            ggml_set_name(a, \"a\");\n        }\n\n        ggml_tensor * out = ggml_unary(ctx, a, op);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            // test extended range of values to check for NaNs in GELU\n            init_tensor_uniform(t, -150.f, 150.f);\n        }\n    }\n\n    float grad_eps() override {\n        return 15.0f;\n    }\n\n    std::vector<float> grad_expect() override {\n        if (op == GGML_UNARY_OP_ABS) {\n            return {-1.0f, 1.0f};\n        }\n        if (op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_STEP) {\n            return {0.0f};\n        }\n        if (op == GGML_UNARY_OP_RELU) {\n            return {0.0f, 1.0f};\n        }\n        return {};\n    }\n\n};\n\n// GGML_OP_GLU\nstruct test_glu : public test_case {\n    const ggml_glu_op op;\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    int v; // view (1 : non-contiguous a)\n    bool swapped;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne_a, v, swapped);\n    }\n\n    test_glu(ggml_glu_op op,\n            ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},\n            int v = 0,\n            bool swapped = false)\n        : op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a;\n        if (v & 1) {\n            auto ne = ne_a; ne[0] *= 3;\n            a = ggml_new_tensor(ctx, type, 4, ne.data());\n            ggml_set_name(a, \"a\");\n\n            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view_of_a\");\n        } else {\n            a = ggml_new_tensor(ctx, type, 4, 
ne_a.data());\n            ggml_set_name(a, \"a\");\n        }\n\n        ggml_tensor * out = ggml_glu(ctx, a, op, swapped);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            // test extended range of values to check for NaNs in GELU\n            init_tensor_uniform(t, -150.f, 150.f);\n        }\n    }\n};\n\nstruct test_glu_split : public test_case {\n    const ggml_glu_op op;\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    int v; // view (1 : non-contiguous a)\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne_a, v) + \",split\";\n    }\n\n    test_glu_split(ggml_glu_op op,\n            ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},\n            int v = 0)\n        : op(op), type(type), ne_a(ne_a), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a;\n        ggml_tensor * b;\n        if (v & 1) {\n            auto ne = ne_a; ne[0] *= 3;\n            a = ggml_new_tensor(ctx, type, 4, ne.data());\n            ggml_set_param(a);\n            ggml_set_name(a, \"a\");\n\n            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view_of_a\");\n\n            b = ggml_new_tensor(ctx, type, 4, ne.data());\n            ggml_set_param(b);\n            ggml_set_name(b, \"b\");\n\n            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);\n            ggml_set_name(a, \"view_of_b\");\n        } else {\n            a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n            ggml_set_param(a);\n            ggml_set_name(a, \"a\");\n\n            b = ggml_new_tensor(ctx, type, 4, ne_a.data());\n            
ggml_set_param(b);\n            ggml_set_name(b, \"b\");\n        }\n\n        ggml_tensor * out = ggml_glu_split(ctx, a, b, op);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            // test extended range of values to check for NaNs in GELU\n            init_tensor_uniform(t, -150.f, 150.f);\n        }\n    }\n};\n\nstruct test_swiglu_oai : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    int v; // view (1 : non-contiguous a)\n    float alpha;\n    float limit;\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, ne_a, v, alpha, limit);\n    }\n\n    test_swiglu_oai(ggml_type type = GGML_TYPE_F32,\n                    std::array<int64_t, 4> ne_a = {128, 2, 2, 2},\n                    int v = 0,\n                    float alpha = 1.702f,\n                    float limit = 7.0f)\n        : type(type), ne_a(ne_a), v(v), alpha(alpha), limit(limit) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a;\n        ggml_tensor * b;\n        if (v & 1) {\n            auto ne = ne_a; ne[0] *= 3;\n            a = ggml_new_tensor(ctx, type, 4, ne.data());\n            ggml_set_param(a);\n            ggml_set_name(a, \"a\");\n\n            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view_of_a\");\n\n            b = ggml_new_tensor(ctx, type, 4, ne.data());\n            ggml_set_param(b);\n            ggml_set_name(b, \"b\");\n\n            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);\n            ggml_set_name(a, \"view_of_b\");\n        } else {\n            a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n            ggml_set_param(a);\n            
ggml_set_name(a, \"a\");\n\n            b = ggml_new_tensor(ctx, type, 4, ne_a.data());\n            ggml_set_param(b);\n            ggml_set_name(b, \"b\");\n        }\n\n        ggml_tensor * out = ggml_swiglu_oai(ctx, a, b, alpha, limit);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            // test extended range of values to check for NaNs in GELU\n            init_tensor_uniform(t, -150.f, 150.f);\n        }\n    }\n};\n\n// GGML_OP_GET_ROWS\nstruct test_get_rows : public test_case {\n    const ggml_type type;\n    const int n; // cols\n    const int m; // rows\n    const int r; // rows to get\n    const int be1; // batch size\n    const int be2; // batch size\n    const bool v; // view (non-contiguous src1)\n\n    std::string vars() override {\n        return VARS_TO_STR7(type, n, m, r, be1, be2, v);\n    }\n\n    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int be1 = 1, int be2 = 1, bool v = false)\n        : type(type), n(n), m(m), r(r), be1(be1), be2(be2), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * in = ggml_new_tensor_4d(ctx, type, n, m, be1, be2);\n        ggml_set_name(in, \"in\");\n\n        ggml_tensor * rows = ggml_new_tensor_3d(ctx, GGML_TYPE_I32, r, be1, be2);\n        ggml_set_name(rows, \"rows\");\n        if (v) {\n            rows = ggml_view_3d(ctx, rows, r/2, be1, be2, rows->nb[1], rows->nb[2], 0);\n            ggml_set_name(rows, \"view_of_rows\");\n        }\n\n        const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);\n        if (grad_supported) {\n            ggml_set_param(in);\n            // rows is a constant input -> no gradients\n        }\n\n        ggml_tensor * out = ggml_get_rows(ctx, in, rows);\n        ggml_set_name(out, 
\"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I32) {\n                if (ggml_is_view_op(t->op)) { continue; }\n                // rows\n                std::vector<int> data(r*be1*be2);\n                for (int i = 0; i < r*be1*be2; i++) {\n                    data[i] = rand() % m;\n                }\n                ggml_backend_tensor_set(t, data.data(), 0, r * be1 * be2 * sizeof(int));\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n};\n\n// GGML_OP_GET_ROWS_BACK\nstruct test_get_rows_back : public test_case {\n    const ggml_type type;\n    const int n; // cols\n    const int m; // rows\n    const int r; // rows to get\n    const int b; // batch size\n    const bool v; // view (non-contiguous src1)\n\n    std::string vars() override {\n        return VARS_TO_STR6(type, n, m, r, b, v);\n    }\n\n    test_get_rows_back(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)\n        : type(type), n(n), m(m), r(r), b(b), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * in_forward = ggml_new_tensor_3d(ctx, type, n, m, b);\n        ggml_set_name(in_forward, \"in_forward\");\n\n        ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);\n        ggml_set_name(rows, \"rows\");\n        if (v) {\n            rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);\n            ggml_set_name(rows, \"view_of_rows\");\n        }\n\n        ggml_tensor * grad = ggml_new_tensor_3d(ctx, type, n, r, b);\n        ggml_set_name(grad, \"grad\");\n\n        ggml_tensor * out = ggml_get_rows_back(ctx, grad, rows, in_forward);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * 
ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I32) {\n                if (ggml_is_view_op(t->op)) { continue; }\n                // rows\n                std::vector<int> data(r*b);\n                for (int i = 0; i < r*b; i++) {\n                    data[i] = rand() % m;\n                }\n                ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n};\n\nstatic void init_set_rows_row_ids(ggml_tensor * t, int num_rows) {\n    std::random_device rd;\n    std::default_random_engine rng(rd());\n    for (int i2 = 0; i2 < t->ne[2]; i2++) {\n        for (int i1 = 0; i1 < t->ne[1]; i1++) {\n            // generate a shuffled subset of row indices\n            std::vector<int64_t> data(num_rows);\n            for (int i = 0; i < num_rows; i++) {\n                data[i] = i;\n            }\n            std::shuffle(data.begin(), data.end(), rng);\n            data.resize(t->ne[0]);\n\n            const size_t offs = i1*t->nb[1] + i2*t->nb[2];\n            if (t->type == GGML_TYPE_I32) {\n                // TODO: Make a template or something\n                std::vector<int32_t> data_i32(t->ne[0]);\n                for (int i = 0; i < t->ne[0]; i++) {\n                    data_i32[i] = static_cast<int32_t>(data[i]);\n                }\n                ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));\n            } else {\n                ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));\n            }\n        }\n    }\n}\n\n// GGML_OP_SET_ROWS\nstruct test_set_rows : public test_case {\n    const ggml_type type;\n    const ggml_type type_idx;\n    const std::array<int64_t, 4> ne;\n    const std::array<int, 2> nr23; // broadcast only dims 2 and 3\n    const int r; // rows to set\n    const 
bool v; // view (non-contiguous src1)\n\n    std::string vars() override {\n        return VARS_TO_STR6(type, type_idx, ne, nr23, r, v);\n    }\n\n    test_set_rows(ggml_type type,\n            ggml_type type_idx,\n            std::array<int64_t, 4> ne,\n            std::array<int, 2> nr23,\n            int r, bool v = false)\n        : type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type,          ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);\n        ggml_set_name(dst, \"dst\");\n\n        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r,     ne[2]*nr23[0], ne[3]*nr23[1]);\n        ggml_set_name(src, \"src\");\n\n        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, r, ne[2], ne[3]);\n        ggml_set_name(row_idxs, \"row_idxs\");\n\n        if (v) {\n            src      = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0);\n            row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0);\n            ggml_set_name(row_idxs, \"view_of_rows\");\n        }\n\n        ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {\n                if (ggml_is_view_op(t->op)) {\n                    continue;\n                }\n\n                init_set_rows_row_ids(t, ne[1]);\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n\n    double max_nmse_err() override {\n        if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL ||\n      
      type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q8_0) {\n            // estimate what the max nmse error would be if one quantized value is\n            // off by one. The test values are distributed in [-1,1], so it'll be\n            // roughly (2.0 / 2^bits)^2, divided by the mean square value of the reference,\n            // which is roughly 0.25 times the number of elements.\n            double err_estimate = 1.0f/8.0f;\n            if (type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {\n                err_estimate /= 2.0f;\n            }\n            if (type == GGML_TYPE_Q8_0) {\n                err_estimate /= 8.0f;\n            }\n            err_estimate *= err_estimate;\n            err_estimate /= 0.25f*float(ne[0] * r * ne[2]*nr23[0] * ne[3]*nr23[1]);\n            return err_estimate;\n        }\n        return 1e-7;\n    }\n};\n\n// GGML_OP_ROPE + GGML_OP_VIEW + GGML_OP_SET_ROWS\nstruct test_rope_set_rows : public test_case {\n    const ggml_type type;\n    const ggml_type type_idx;\n    const std::array<int64_t, 4> ne_a;\n    int mode;\n    const int n_ctx{512};\n    const int n_dims{128};\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, type_idx, ne_a, mode);\n    }\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"ROPE_SET_ROWS\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    test_rope_set_rows(ggml_type type,\n            ggml_type type_idx,\n            std::array<int64_t, 4> ne_a,\n            int mode)\n        : type(type), type_idx(type_idx), ne_a(ne_a), mode(mode) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne_a[0], ne_a[1], ne_a[2], 1);\n        ggml_set_name(a, \"a\");\n\n        const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;\n        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n\n        ggml_tensor * pos;\n       
 if (is_mrope || is_vision) {\n            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);\n        } else {\n            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);\n        }\n        ggml_set_name(pos, \"pos\");\n\n        float fs = 1.4245f;\n        float ef = 0.7465f;\n        float af = 1.4245f;\n        ggml_tensor * freq = nullptr;\n\n        ggml_tensor * rope = nullptr;\n        if (is_mrope) {\n            if (is_vision) {\n                GGML_ASSERT(n_dims/4 > 0);\n                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate\n                rope = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n            } else {\n                GGML_ASSERT(n_dims/3 > 0);\n                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};\n                rope = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n            }\n        } else {\n            rope = ggml_rope(ctx, a, pos, ne_a[0], mode);\n        }\n\n        ggml_tensor * view = ggml_view_2d(ctx, rope, ne_a[0] * ne_a[1], ne_a[2], rope->nb[2], 0);\n\n        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne_a[0] * ne_a[1], ne_a[2] * ne_a[3], 1, 1);\n        ggml_set_name(dst, \"dst\");\n\n        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne_a[2], 1, 1);\n        ggml_set_name(row_idxs, \"row_idxs\");\n\n        ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (strcmp(t->name, \"row_idxs\") == 0) {\n                if (ggml_is_view_op(t->op)) {\n                    continue;\n                }\n                
init_set_rows_row_ids(t, ne_a[2]);\n            } else if (t->type == GGML_TYPE_I32) {\n                // pos\n                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];\n                std::vector<int> data(num_pos_ids);\n                for (int i = 0; i < num_pos_ids; i++) {\n                    data[i] = rand() % n_ctx;\n                }\n                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));\n            } else {\n                if (t->ne[0] == n_dims/2) {\n                    // frequency factors in the range [0.9f, 1.1f]\n                    init_tensor_uniform(t, 0.9f, 1.1f);\n                } else {\n                    init_tensor_uniform(t);\n                }\n            }\n        }\n    }\n};\n\n// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS)\nstruct test_rms_norm_mul_rope : public test_case {\n    const std::array<int64_t, 4> ne;\n    const float eps;\n    const bool multi_add; // test a sequence of adds feeding into rms_norm\n    const bool set_rows;\n    int mode;\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"RMS_NORM_MUL_ROPE\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    std::string vars() override {\n        return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode);\n    }\n\n    test_rms_norm_mul_rope(std::array<int64_t, 4> ne, float eps = 1e-6f, bool multi_add = false,\n                           bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL)\n        : ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);\n        ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);\n        ggml_tensor * c = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], 
ne[2], 1);\n\n        if (multi_add) {\n            a = ggml_add(ctx, ggml_add(ctx, a, b), c);\n        }\n\n        a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);\n\n        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);\n\n        ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode);\n\n        ggml_tensor * out;\n\n        if (set_rows) {\n            ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);\n\n            ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1);\n            ggml_set_name(dst, \"dst\");\n\n            ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1);\n            ggml_set_name(row_idxs, \"row_idxs\");\n\n            out = ggml_set_rows(ctx, dst, view, row_idxs);\n            ggml_set_name(out, \"out\");\n        } else {\n            out = rope;\n        }\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {\n                if (ggml_is_view_op(t->op)) {\n                    continue;\n                }\n\n                init_set_rows_row_ids(t, ne[2]);\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n};\n\n// GGML_OP_ARGMAX\nstruct test_argmax : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_argmax(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 100, 1, 1})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = 
ggml_argmax(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        std::random_device rd;\n        std::default_random_engine rng(rd());\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_F32) {\n                // initialize with unique values to avoid ties\n                for (int64_t r = 0; r < ggml_nrows(t); r++) {\n                    std::vector<float> data(t->ne[0]);\n                    for (int i = 0; i < t->ne[0]; i++) {\n                        data[i] = i;\n                    }\n                    std::shuffle(data.begin(), data.end(), rng);\n                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));\n                }\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n\n    double max_nmse_err() override {\n        return 0.0;\n    }\n};\n\n// GGML_OP_COUNT_EQUAL\nstruct test_count_equal : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_count_equal(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {4, 500, 1, 1})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * a_argmax = ggml_argmax(ctx, a);\n        ggml_set_name(a_argmax, \"a_argmax\");\n\n        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(b, \"b\");\n\n        ggml_tensor * b_argmax = ggml_argmax(ctx, b);\n        ggml_set_name(b_argmax, \"b_argmax\");\n\n        ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);\n        ggml_set_name(out, 
\"out\");\n\n        return out;\n    }\n\n    double max_nmse_err() override {\n        return 0.0;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        std::random_device rd;\n        std::default_random_engine rng(rd());\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_F32) {\n                // initialize with unique values to avoid ties\n                for (int64_t r = 0; r < ggml_nrows(t); r++) {\n                    std::vector<float> data(t->ne[0]);\n                    for (int i = 0; i < t->ne[0]; i++) {\n                        data[i] = i;\n                    }\n                    std::shuffle(data.begin(), data.end(), rng);\n                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));\n                }\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n};\n\n// GGML_OP_REPEAT\nstruct test_repeat : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const std::array<int, 4> nr;\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne, nr);\n    }\n\n    size_t op_size(ggml_tensor * t) override {\n        return ggml_nbytes(t) * 2;\n    }\n\n    test_repeat(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3},\n            std::array<int, 4> nr = {2, 2, 2, 2})\n        : type(type), ne(ne), nr(nr) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * target = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);\n        ggml_set_name(target, \"target\");\n\n        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(src);\n        ggml_set_name(src, \"src\");\n\n        ggml_tensor * out = ggml_repeat(ctx, src, target);\n        ggml_set_name(out, \"out\");\n\n  
      return out;\n    }\n};\n\n// GGML_OP_REPEAT_BACK\nstruct test_repeat_back : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const std::array<int, 4> nr;\n    const bool v; // whether src is a noncontiguous view\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, nr, v);\n    }\n\n    size_t op_size(ggml_tensor * t) override {\n        return ggml_nbytes(t) * 2;\n    }\n\n    test_repeat_back(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {8, 6, 4, 2},\n            std::array<int, 4> nr = {2, 2, 2, 2},\n            bool v = false)\n        : type(type), ne(ne), nr(nr), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);\n        ggml_set_name(src, \"src\");\n\n        if (v) {\n            GGML_ASSERT(ne[0] % 2 == 0);\n            GGML_ASSERT(ne[1] % 2 == 0);\n            GGML_ASSERT(ne[2] % 2 == 0);\n            GGML_ASSERT(ne[3] % 2 == 0);\n            GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);\n            GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);\n            GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);\n            GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);\n\n            const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;\n            const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;\n            const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;\n            const int64_t ne03 = nr[3] == 1 ? 
src->ne[3] : src->ne[3] / 2;\n\n            src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);\n        }\n\n        ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(target, \"target\");\n\n        ggml_tensor * out = ggml_repeat_back(ctx, src, target);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_DUP\nstruct test_dup : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const std::array<int64_t, 4> permute;\n    bool _use_permute;\n\n    std::string vars() override {\n        std::string v = VARS_TO_STR2(type, ne);\n        if (_use_permute) v += \",\" + VAR_TO_STR(permute);\n        return v;\n    }\n\n    test_dup(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 10, 20, 1},\n            std::array<int64_t, 4> permute = {0, 0, 0, 0})\n        : type(type), ne(ne), permute(permute),\n            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(src);\n        ggml_set_name(src, \"src\");\n\n        if (_use_permute) {\n            src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);\n            ggml_set_name(src, \"src_permuted\");\n        }\n\n        ggml_tensor * out = ggml_dup(ctx, src);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_SET\nstruct test_set : public test_case {\n    const ggml_type type_src;\n    const ggml_type type_dst;\n    const std::array<int64_t, 4> ne;\n    const int dim;\n    const bool inplace;\n\n    std::string vars() override {\n        return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace);\n    }\n\n    size_t op_size(ggml_tensor * t) override {\n        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);\n  
  }\n\n    test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false)\n        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());\n        ggml_set_param(src);\n        ggml_set_name(src, \"src\");\n\n        auto ne_dst = ne;\n        for (int i = 0; i < dim; ++i) {\n            ne_dst[i] *= 2;\n        }\n        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());\n        ggml_set_param(dst);\n        ggml_set_name(dst, \"dst\");\n\n        size_t offset = 0;\n        for (int i = 0; i < dim; ++i) {\n            offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];\n        }\n        ggml_tensor * out;\n        if (inplace) {\n            out = ggml_set_inplace(ctx, dst, src,\n                    // The backward pass requires setting a contiguous region:\n                    src->nb[1], src->nb[2], src->nb[3], offset);\n        } else {\n            out = ggml_set(ctx, dst, src,\n                    // The backward pass requires setting a contiguous region:\n                    src->nb[1], src->nb[2], src->nb[3], offset);\n        }\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_CPY\nstruct test_cpy : public test_case {\n    const ggml_type type_src;\n    const ggml_type type_dst;\n    const std::array<int64_t, 4> ne;\n    const std::array<int64_t, 4> permute_src;\n    const std::array<int64_t, 4> permute_dst;\n    bool _src_use_permute;\n    bool _dst_use_permute;\n    bool _src_transpose;\n\n    std::string vars() override {\n        return VARS_TO_STR6(type_src, type_dst, ne, permute_src, permute_dst, _src_transpose);\n    }\n\n    double max_nmse_err() override {\n        if (type_src == type_dst) {\n            return 0.0;\n        
}\n        if (type_dst == GGML_TYPE_Q4_0 || type_dst == GGML_TYPE_Q4_1 || type_dst == GGML_TYPE_IQ4_NL ||\n            type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1 || type_dst == GGML_TYPE_Q8_0) {\n            // estimate what the max nmse error would be if one quantized value is\n            // off by one. The test values are distributed in [-150,150], so it'll be\n            // roughly (150*2.0 / 2^bits)^2, divided by the mean square value of the reference,\n            // which is roughly 0.25*150^2 times the number of elements.\n            double err_estimate = 1.0f/8.0f * 150.0f;\n            if (type_dst == GGML_TYPE_IQ4_NL) {\n                // iq4_nl values are a bit more spread out\n                err_estimate *= 2.0f;\n            }\n            if (type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1) {\n                err_estimate /= 2.0f;\n            }\n            if (type_dst == GGML_TYPE_Q8_0) {\n                err_estimate /= 8.0f;\n            }\n            err_estimate *= err_estimate;\n            err_estimate /= (150.0f*150.0f*0.25f)*float(ne[0] * ne[1] * ne[2] * ne[3]);\n            return err_estimate;\n        }\n        return 1e-6;\n    }\n\n    size_t op_size(ggml_tensor * t) override {\n        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);\n    }\n\n    test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 10, 10, 1},\n            std::array<int64_t, 4> permute_src = {0, 0, 0, 0},\n            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},\n            bool transpose_src = false)\n        : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),\n          _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),\n          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),\n          _src_transpose(transpose_src){}\n\n    
ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());\n        ggml_set_param(src);\n        ggml_set_name(src, \"src\");\n\n        if (_src_use_permute) {\n            src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);\n            ggml_set_name(src, \"src_permuted\");\n        }\n\n        if (_src_transpose) {\n            src = ggml_transpose(ctx, src);\n            ggml_set_name(src, \"src_transposed\");\n        }\n\n        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);\n        ggml_set_name(dst, \"dst\");\n\n        if (_dst_use_permute) {\n            dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);\n            ggml_set_name(dst, \"dst_permuted\");\n        }\n\n        ggml_tensor * out = ggml_cpy(ctx, src, dst);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            // test extended range of values to check if casting between f32 and i32 is consistent\n            init_tensor_uniform(t, -150.f, 150.f);\n        }\n    }\n};\n\n// GGML_OP_CONT\nstruct test_cont : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    bool use_view_slice;\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne, use_view_slice);\n    }\n\n    test_cont(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 10, 10, 1},\n            bool use_view_slice = false)\n        : type(type), ne(ne), use_view_slice(use_view_slice) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(src);\n        ggml_set_name(src, 
\"src\");\n\n\n        ggml_tensor * dst;\n        if (use_view_slice) {\n            dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3],\n                src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1));\n            ggml_set_name(dst, \"src_view_slice\");\n        } else {\n            dst = ggml_transpose(ctx, src);\n            ggml_set_name(dst, \"src_transposed\");\n        }\n\n        ggml_tensor * out = ggml_cont(ctx, dst);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_ADD\n// GGML_OP_SUB\n// GGML_OP_MUL\n// GGML_OP_DIV\nstruct test_bin_bcast : public test_case {\n    using op_t = ggml_tensor * (*) (ggml_context *, ggml_tensor *, ggml_tensor *);\n    op_t op;\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const std::array<int, 4> nr;\n    int nf; // number of fused ops, nf == 1 -> single op (no fusion)\n    bool perm1; // permute src1?\n    bool src_overlap; // src0 and src1 are overlapping views of the same buffer\n\n    bool run_whole_graph() override { return nf > 1; }\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, ne, nr, nf, perm1);\n    }\n\n    size_t op_size(ggml_tensor * t) override {\n        return ggml_nbytes(t) * 3;\n    }\n\n    test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 10, 1, 1},\n            std::array<int, 4> nr = {1, 2, 1, 1},\n            int nf = 1,\n            bool perm1 = false, bool src_overlap = false)\n        : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1), src_overlap(src_overlap) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        GGML_ASSERT(nf <= 16);\n\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * b[16];\n        for (int i = 0; i < nf; ++i) {\n            if (perm1) {\n                const 
int p[4] = { 1, 2, 0, 3 }; // hardcoded for now\n\n                b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]);\n                b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]);\n            } else if (src_overlap) {\n                b[i] = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], (ne[3] / 3) * a->nb[3]);\n            } else {\n                b[i] = ggml_new_tensor(ctx, type, 4, ne.data());\n            }\n            ggml_set_name(b[i], (std::string(\"b\") + std::to_string(i)).c_str());\n        }\n\n        // The backward pass supports broadcasting only for GGML_ADD:\n        const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1;\n        if (grad_supported) {\n            ggml_set_param(a);\n            ggml_set_param(b[0]);\n        }\n\n        ggml_tensor *out;\n\n        if (src_overlap) {\n            out = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], 0);\n        } else {\n            out = a;\n        }\n\n        for (int i = 0; i < nf; ++i) {\n            out = op(ctx, out, b[i]);\n        }\n\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (op == ggml_mul || op == ggml_div) {\n                // MUL and DIV have numerical issues around zero:\n                init_tensor_uniform(t, 0.9f, 1.1f);\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n\n    float grad_eps() override {\n        return 0.1f * (op == ggml_mul ? ne[0]*ne[1]*ne[2]*ne[3] : 1);\n    }\n\n    bool grad_precise() override {\n        return op == ggml_div;\n    }\n\n    double max_maa_err() override {\n        return op == ggml_add ? 
1e-4 : 1e-3;\n    }\n};\n\n// GGML_OP_ADD_ID\nstruct test_add_id : public test_case {\n    const ggml_type type_a;\n    const ggml_type type_b;\n    const int64_t n_embd;\n    const int64_t n_experts;\n    const int64_t n_experts_used;\n    const int64_t n_token;\n\n    std::string vars() override {\n        return VARS_TO_STR6(type_a, type_b, n_embd, n_experts, n_experts_used, n_token);\n    }\n\n    size_t op_size(ggml_tensor * t) override {\n        return ggml_nbytes(t) + ggml_nbytes(t->src[0]) + ggml_nbytes(t->src[2]);\n    }\n\n    test_add_id(ggml_type type_a = GGML_TYPE_F32,\n            ggml_type type_b = GGML_TYPE_F32,\n            int64_t n_embd = 128,\n            int64_t n_experts = 16,\n            int64_t n_experts_used = 8,\n            int64_t n_token = 10)\n        : type_a(type_a), type_b(type_b), n_embd(n_embd),\n          n_experts(n_experts), n_experts_used(n_experts_used), n_token(n_token) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_3d(ctx, type_a, n_embd, n_experts_used, n_token);\n        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, n_embd, n_experts);\n        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_experts, n_token);\n        if (n_experts_used != n_experts) {\n            ids = ggml_view_2d(ctx, ids, n_experts_used, n_token, ids->nb[1], 0);\n            ggml_set_name(ids, \"view_of_ids\");\n        }\n\n        ggml_tensor * out = ggml_add_id(ctx, a, b, ids);\n        ggml_set_name(out, \"out\");\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I32) {\n                if (ggml_is_view_op(t->op)) { continue; }\n                std::random_device rd;\n                std::default_random_engine rng(rd());\n                // ids\n                for (int64_t r = 0; r 
< ggml_nrows(t); r++) {\n                    std::vector<int32_t> data(t->ne[0]);\n                    for (int i = 0; i < t->ne[0]; i++) {\n                        data[i] = i % n_experts;\n                    }\n                    std::shuffle(data.begin(), data.end(), rng);\n                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));\n                }\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n};\n\n// GGML_OP_ADD1\nstruct test_add1 : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_add1(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);\n        // ggml_set_param(b); // TODO: implement\n        ggml_set_name(b, \"b\");\n\n        ggml_tensor * out = ggml_add1(ctx, a, b);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    float grad_eps() override {\n        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];\n    }\n};\n\n// GGML_OP_SCALE\nstruct test_scale : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    float scale;\n    float bias;\n    bool inplace;\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, ne, scale, bias, inplace);\n    }\n\n    test_scale(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 10, 10, 10},\n            float scale = 2.0f,\n            float bias = 0.0f,\n            bool inplace = false)\n        : type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {}\n\n    ggml_tensor * 
build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out;\n        if (inplace) {\n            out = ggml_scale_bias_inplace(ctx, a, scale, bias);\n        } else {\n            out = ggml_scale_bias(ctx, a, scale, bias);\n        }\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE\nstruct test_softcap : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    float softcap;\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"SOFTCAP\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne, softcap);\n    }\n\n    test_softcap(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 10, 10, 10},\n            float softcap = 30.0f)\n        : type(type), ne(ne), softcap(softcap) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_scale(ctx, ggml_tanh(ctx, ggml_scale(ctx, a, 1.0f / softcap)), softcap);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_SILU_BACK\nstruct test_silu_back : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    float eps;\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne, eps);\n    }\n\n    test_silu_back(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {64, 5, 4, 3},\n            float eps = 1e-6f)\n        : type(type), ne(ne), eps(eps) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = 
ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(grad, \"grad\");\n\n        ggml_tensor * out = ggml_silu_back(ctx, a, grad);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_NORM\nstruct test_norm : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const bool v; // whether a is a non-contiguous view\n    const float eps;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, v, eps);\n    }\n\n    test_norm(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {64, 5, 4, 3},\n            bool v = false,\n            float eps = 1e-6f)\n        : type(type), ne(ne), v(v), eps(eps) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        if (v) {\n            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view of a\");\n        }\n\n        ggml_tensor * out = ggml_norm(ctx, a, eps);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_NORM + GGML_OP_MUL + GGML_OP_ADD\nstruct test_norm_mul_add : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    float eps;\n    const bool broadcast;\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"NORM_MUL_ADD\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, eps, broadcast);\n    }\n\n    test_norm_mul_add(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {128, 2, 1, 1},\n            
float eps = 1e-5f,\n            bool broadcast = false)\n        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        std::array<int64_t, 4> broadcast_dims = {ne[0], ne[1] * 2, ne[2] * 2, ne[3] * 2};\n\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());\n        ggml_tensor * w = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a); ggml_set_param(w); ggml_set_param(b);\n        ggml_set_name(a, \"a\"); ggml_set_name(w, \"w\"); ggml_set_name(b, \"b\");\n\n        // Use a, w and b early to avoid OP_NONE in graph\n        a = ggml_add(ctx, ggml_add(ctx, a, w), b);\n\n        ggml_tensor * n = ggml_norm(ctx, a, eps);\n        ggml_tensor * m = ggml_mul(ctx, n, w);\n        ggml_tensor * out = ggml_add(ctx, m, b);\n        ggml_set_name(out, \"out\");\n        return out;\n    }\n};\n// GGML_OP_RMS_NORM\nstruct test_rms_norm : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const bool v; // whether a is a non-contiguous view\n    const float eps;\n    const bool inplace; // whether to do the operation inplace\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, ne, v, eps, inplace);\n    }\n\n    test_rms_norm(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {64, 5, 4, 3},\n            bool v = false,\n            float eps = 1e-6f,\n            bool inplace = false)\n        : type(type), ne(ne), v(v), eps(eps), inplace(inplace) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        if (v) {\n            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);\n            
ggml_set_name(a, \"view of a\");\n        }\n\n        ggml_tensor * out;\n        if (inplace) {\n            out = ggml_rms_norm_inplace(ctx, a, eps);\n        } else {\n            out = ggml_rms_norm(ctx, a, eps);\n        }\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -10.f, 10.f);\n        }\n    }\n\n    float grad_eps() override {\n        return 1.0f;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_RMS_NORM_BACK\nstruct test_rms_norm_back : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const float eps;\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne, eps);\n    }\n\n    test_rms_norm_back(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {64, 5, 4, 3},\n            float eps = 1e-6f)\n        : type(type), ne(ne), eps(eps) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(b, \"b\");\n\n        ggml_tensor * out = ggml_rms_norm_back(ctx, a, b, eps);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -10.f, 10.f);\n        }\n    }\n};\n\n// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ADD\nstruct test_rms_norm_mul_add : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const float eps;\n    const bool broadcast;\n    const bool 
multi_add; // test a sequence of adds feeding into rms_norm\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"RMS_NORM_MUL_ADD\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, ne, eps, broadcast, multi_add);\n    }\n\n    test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {64, 5, 4, 3},\n            float eps = 1e-6f, bool broadcast = false, bool multi_add = false)\n        : type(type), ne(ne), eps(eps), broadcast(broadcast), multi_add(multi_add) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};\n\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());\n        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());\n\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n        ggml_set_param(b);\n        ggml_set_name(b, \"b\");\n        ggml_set_param(c);\n        ggml_set_name(c, \"c\");\n\n        // Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul\n        a = ggml_add(ctx, ggml_add(ctx, a, b), c);\n        if (multi_add) {\n            a = ggml_add(ctx, ggml_add(ctx, a, b), c);\n        }\n        ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -10.f, 10.f);\n        }\n    }\n\n    float grad_eps() override {\n        return 1.0f;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// 
GGML_OP_ADD + GGML_OP_RMS_NORM (fused operation)\nstruct test_add_rms_norm : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const float eps;\n    const bool broadcast;\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"ADD_RMS_NORM\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, eps, broadcast);\n    }\n\n    test_add_rms_norm(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {64, 5, 4, 3},\n            float eps = 1e-6f, bool broadcast = false)\n        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};\n\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());\n        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());\n\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n        ggml_set_param(b);\n        ggml_set_name(b, \"b\");\n\n        // ADD operation followed by RMS_NORM\n        ggml_tensor * add_result = ggml_add(ctx, a, b);\n        ggml_set_name(add_result, \"add_result\");\n\n        ggml_tensor * out = ggml_rms_norm(ctx, add_result, eps);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -10.f, 10.f);\n        }\n    }\n\n    float grad_eps() override {\n        return 1.0f;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_SSM_CONV\nstruct test_ssm_conv : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    const 
std::array<int64_t, 4> ne_b;\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne_a, ne_b);\n    }\n\n    test_ssm_conv(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},\n            std::array<int64_t, 4> ne_b = {3, 3, 1, 1})\n        : type(type), ne_a(ne_a), ne_b(ne_b) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a   = ggml_new_tensor(ctx, type, 4, ne_a.data());\n        ggml_tensor * b   = ggml_new_tensor(ctx, type, 4, ne_b.data());\n        ggml_tensor * out = ggml_ssm_conv(ctx, a, b);\n        return out;\n    }\n};\n\n// GGML_OP_SSM_SCAN\nstruct test_ssm_scan : public test_case {\n    const ggml_type type;\n\n    const int64_t d_state;\n    const int64_t head_dim;\n    const int64_t n_head;\n    const int64_t n_group;\n    const int64_t n_seq_tokens;\n    const int64_t n_seqs;\n\n    std::string vars() override {\n        return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);\n    }\n\n    test_ssm_scan(ggml_type type = GGML_TYPE_F32,\n            int64_t d_state = 32,\n            int64_t head_dim = 1, // > 1 for Mamba-2\n            int64_t n_head  = 32,\n            int64_t n_group = 1,\n            int64_t n_seq_tokens = 32,\n            int64_t n_seqs = 32)\n        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state,  head_dim,     n_head,       n_seqs);\n        ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, head_dim, n_head,       n_seq_tokens, n_seqs);\n        ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head,   n_seq_tokens, n_seqs);\n        ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 
1 : d_state, n_head);\n        ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state,  n_group,      n_seq_tokens, n_seqs);\n        ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state,  n_group,      n_seq_tokens, n_seqs);\n        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32,  n_seqs);\n        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);\n        return out;\n    }\n\n    // similar to test_mul_mat_id\n    void initialize_tensors(ggml_context * ctx) override {\n        std::random_device rd;\n        std::default_random_engine rng(rd());\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I32) {\n                if (ggml_is_view_op(t->op)) { continue; }\n                // ids\n                for (int64_t r = 0; r < ggml_nrows(t); r++) {\n                    std::vector<int32_t> data(t->ne[0]);\n                    for (int i = 0; i < t->ne[0]; i++) {\n                        data[i] = i;\n                    }\n                    std::shuffle(data.begin(), data.end(), rng);\n                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));\n                }\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n};\n\n// GGML_OP_RWKV_WKV6\nstruct test_rwkv_wkv6 : public test_case {\n    const ggml_type type;\n\n    const int64_t head_count;\n    const int64_t head_size;\n    const int64_t n_seq_tokens;\n    const int64_t n_seqs;\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);\n    }\n\n    test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32,\n            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)\n        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}\n\n    ggml_tensor * 
build_graph(ggml_context * ctx) override {\n        const int64_t n_tokens = n_seq_tokens * n_seqs;\n        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * tf  = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());\n        ggml_tensor * td  = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());\n        ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);\n        return out;\n    }\n};\n\n// GGML_OP_GATED_DELTA_NET\nstruct test_gated_delta_net : public test_case {\n    const ggml_type type;\n\n    const int64_t head_count;\n    const int64_t head_size;\n    const int64_t n_seq_tokens;\n    const int64_t n_seqs;\n    const int     v_repeat;\n    const bool    permuted;\n    const bool    kda;\n\n    std::string vars() override {\n        return VARS_TO_STR8(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda);\n    }\n\n    test_gated_delta_net(ggml_type type = GGML_TYPE_F32,\n            int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1,\n            int v_repeat = 1, bool permuted = false, bool kda = false)\n        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs),\n          v_repeat(v_repeat), permuted(permuted), kda(kda) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * q;\n        ggml_tensor * k;\n        ggml_tensor * v;\n        if (permuted) {\n        
    // create with dims 1 and 2 swapped, then permute back to get non-contiguous layout\n            q = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3);\n            k = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3);\n            v = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count * v_repeat, n_seqs), 0, 2, 1, 3);\n        } else {\n            q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);\n            k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);\n            v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);\n        }\n        const int64_t g_ne0 = kda ? head_size : 1;\n        ggml_tensor * g     = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs);\n        ggml_tensor * beta  = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs);\n        ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);\n        ggml_tensor * out   = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);\n        return out;\n    }\n};\n\n// GGML_OP_GATED_LINEAR_ATTN\nstruct test_gla : public test_case {\n    const ggml_type type;\n\n    const int64_t head_count;\n    const int64_t head_size;\n    const int64_t n_seq_tokens;\n    const int64_t n_seqs;\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);\n    }\n\n    test_gla(ggml_type type = GGML_TYPE_F32,\n            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)\n        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        const 
int64_t n_tokens = n_seq_tokens * n_seqs;\n        ggml_tensor * q   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * g   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());\n        ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));\n        return out;\n    }\n};\n\n// GGML_OP_RWKV_WKV7\nstruct test_rwkv_wkv7 : public test_case {\n    const ggml_type type;\n\n    const int64_t head_count;\n    const int64_t head_size;\n    const int64_t n_seq_tokens;\n    const int64_t n_seqs;\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);\n    }\n\n    test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,\n            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)\n        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        const int64_t n_tokens = n_seq_tokens * n_seqs;\n        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * w   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ 
head_size, head_count, n_tokens }.data());\n        ggml_tensor * a   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        ggml_tensor * b   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());\n        // Outputs may become NaN with long seqlen without these normalization\n        a = ggml_l2_norm(ctx, a, 1e-7F);\n        b = ggml_l2_norm(ctx, b, 1e-7F);\n        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());\n        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);\n        return out;\n    }\n};\n\n// GGML_OP_MUL_MAT\nstruct test_mul_mat : public test_case {\n    const ggml_type type_a;\n    const ggml_type type_b;\n    const int64_t m;\n    const int64_t n;\n    const int64_t k;\n    const std::array<int64_t, 2> bs;  // dims 3 and 4\n    const std::array<int64_t, 2> nr;  // repeat in dims 3 and 4\n    const std::array<int64_t, 4> per; // permutation of dimensions\n    const int64_t k_v; // size of k in memory, resulting in a non-contiguous view for k_v > k, no view for k_v == 0\n    const uint32_t o; // number of outputs\n\n    std::string vars() override {\n        return VARS_TO_STR10(type_a, type_b, m, n, k, bs, nr, per, k_v, o);\n    }\n\n    double max_nmse_err() override {\n        return 5e-4;\n    }\n\n    double max_nmse_err(ggml_backend_t backend) override {\n        // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance\n        if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, \"BLACKWELL_NATIVE_FP4\")) {\n            return 2e-2;\n        }\n        return max_nmse_err();\n    }\n\n    int64_t grad_nmax() override {\n        return 20000;\n    }\n\n    uint64_t op_flops(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];\n    }\n\n    
test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,\n            int64_t m = 32, int64_t n = 32, int64_t k = 32,\n            std::array<int64_t, 2> bs = {10, 10},\n            std::array<int64_t, 2> nr = {2, 2},\n            std::array<int64_t, 4> per = {0, 1, 2, 3},\n            int64_t k_v = 0, uint32_t o = 1)\n        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), k_v(k_v), o(o) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        // C^T = A * B^T: (k, m) * (k, n) => (m, n)\n        ggml_tensor * a;\n        ggml_tensor * b;\n\n        const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);\n        if (npermuted > 0) {\n            GGML_ASSERT(npermuted == 2);\n            GGML_ASSERT(k_v == 0); // not handled\n            GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);\n            GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);\n\n            // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.\n            const int64_t ne_a[4] = {k, m, bs[0],       bs[1]};\n            const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};\n\n            a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);\n            b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);\n            if (!ggml_is_quantized(type_a)) {\n                if (bs[1] == 1 && nr[1] == 1) {\n                    ggml_set_param(a);\n                }\n                ggml_set_param(b);\n            }\n            ggml_set_name(a, \"a\");\n            ggml_set_name(b, \"b\");\n\n            a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);\n            b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);\n            ggml_set_name(a, \"a_permuted\");\n            ggml_set_name(b, \"b_permuted\");\n        } else {\n            
const int64_t k_physical = k_v == 0 ? k : k_v;\n            a = ggml_new_tensor_4d(ctx, type_a, k_physical, m, bs[0],       bs[1]);\n            b = ggml_new_tensor_4d(ctx, type_b, k_physical, n, bs[0]*nr[0], bs[1]*nr[1]);\n\n            if (!ggml_is_quantized(type_a)) {\n                if (bs[1] == 1 && nr[1] == 1) {\n                    ggml_set_param(a);\n                }\n                ggml_set_param(b);\n            }\n\n            if (k_v != 0) {\n                GGML_ASSERT(k_v > k);\n                a = ggml_view_4d(ctx, a, k, m, bs[0],       bs[1],       a->nb[1], a->nb[2], a->nb[3], 0);\n                b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0);\n            }\n            ggml_set_name(a, \"a\");\n            ggml_set_name(b, \"b\");\n        }\n\n        ggml_tensor * out = ggml_mul_mat(ctx, a, b);\n        ggml_set_name(out, \"out\");\n        for (uint32_t i = 1; i < o; ++i) {\n            ggml_tensor * out2 = ggml_mul_mat(ctx, a, b);\n            ggml_set_name(out2, \"out2\");\n            out = ggml_add(ctx, out, out2);\n        }\n\n        return out;\n    }\n\n    bool run_whole_graph() override { return o > 1; }\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return ggml_op_name(GGML_OP_MUL_MAT);\n    }\n};\n\nstatic void init_mul_mat_id_tensors(ggml_context * ctx, int n_mats) {\n    std::random_device rd;\n    std::default_random_engine rng(rd());\n    for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n        if (t->type == GGML_TYPE_I32) {\n            if (ggml_is_view_op(t->op)) { continue; }\n            // ids\n            for (int64_t r = 0; r < ggml_nrows(t); r++) {\n                std::vector<int32_t> data(t->ne[0]);\n                for (int i = 0; i < t->ne[0]; i++) {\n                    data[i] = i % n_mats;\n                }\n                std::shuffle(data.begin(), data.end(), rng);\n    
            ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));\n            }\n        } else {\n            init_tensor_uniform(t);\n        }\n    }\n}\n\n// GGML_OP_MUL_MAT_ID\nstruct test_mul_mat_id : public test_case {\n    const ggml_type type_a;\n    const ggml_type type_b;\n    const int n_mats;\n    const int n_used;\n    const bool b; // broadcast b matrix\n    const int64_t m;\n    const int64_t n;\n    const int64_t k;\n\n    std::string vars() override {\n        return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);\n    }\n\n    double max_nmse_err() override {\n        return 5e-4;\n    }\n\n    double max_nmse_err(ggml_backend_t backend) override {\n        // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance\n        if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, \"BLACKWELL_NATIVE_FP4\")) {\n            return 2e-2;\n        }\n        return max_nmse_err();\n    }\n\n    uint64_t op_flops(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return 2 * m * k * n * n_used;\n    }\n\n    test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,\n            int n_mats = 8, int n_used = 2, bool b = false,\n            int64_t m = 32, int64_t n = 32, int64_t k = 32)\n        : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),\n            m(m), n(n), k(k) {\n            GGML_ASSERT(n_used <= n_mats);\n        }\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        // C^T = A * B^T: (k, m) * (k, n) => (m, n)\n        ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);\n        ggml_set_name(as, \"as\");\n\n        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);\n        ggml_set_name(ids, \"ids\");\n        if (n_used != n_mats) {\n            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);\n            ggml_set_name(ids, 
\"view_of_ids\");\n        }\n\n        ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);\n        ggml_set_name(b, \"b\");\n\n        ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        init_mul_mat_id_tensors(ctx, n_mats);\n    }\n};\n\n// GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL\nstruct test_mul_mat_id_fusion : public test_case {\n    const ggml_type type_a;\n    const ggml_type type_b;\n    const int n_mats;\n    const int n_used;\n    const bool b; // broadcast b matrix\n    const int64_t m;\n    const int64_t n;\n    const int64_t k;\n    const uint32_t o; // number of outputs\n    const bool mul;\n\n    std::string vars() override {\n        return VARS_TO_STR10(type_a, type_b, n_mats, n_used, b, m, n, k, o, mul);\n    }\n\n    double max_nmse_err() override {\n        return 5e-4;\n    }\n\n    uint64_t op_flops(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return 2 * m * k * n * n_used;\n    }\n\n    test_mul_mat_id_fusion(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,\n            int n_mats = 8, int n_used = 2, bool b = false,\n            int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1, bool mul = false)\n        : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),\n            m(m), n(n), k(k), o(o), mul(mul) {\n            GGML_ASSERT(n_used <= n_mats);\n        }\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        // C^T = A * B^T: (k, m) * (k, n) => (m, n)\n        ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);\n        ggml_set_name(as, \"as\");\n\n        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);\n        ggml_set_name(ids, \"ids\");\n        if (n_used != n_mats) {\n            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);\n  
          ggml_set_name(ids, \"view_of_ids\");\n        }\n\n        ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);\n        ggml_set_name(b, \"b\");\n\n        ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);\n        ggml_set_name(out, \"out\");\n\n        for (uint32_t i = 1; i < o; ++i) {\n            ggml_tensor * a2 = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);\n            ggml_tensor * out2 = ggml_mul_mat_id(ctx, a2, b, ids);\n            ggml_set_name(out2, \"out2\");\n            out = ggml_add(ctx, out, out2);\n        }\n\n        if (mul) {\n            std::array<int64_t, 4> ne { 1, out->ne[1], out->ne[2], out->ne[3] };\n            ne[0] = 1;\n            ggml_tensor * m = ggml_new_tensor(ctx, out->type, 4, ne.data());\n            out = ggml_mul(ctx, out, m);\n        }\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        init_mul_mat_id_tensors(ctx, n_mats);\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"MUL_MAT_ID_FUSION\";\n    }\n};\n\n// GGML_OP_OUT_PROD\nstruct test_out_prod : public test_case {\n    const ggml_type type_a;\n    const ggml_type type_b;\n    const int64_t m;\n    const int64_t n;\n    const int64_t k;\n    const std::array<int64_t, 2> bs; // dims 3 and 4\n    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4\n    const bool trans_b;\n\n    std::string vars() override {\n        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, trans_b);\n    }\n\n    double max_nmse_err() override {\n        return 5e-4;\n    }\n\n    test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,\n            int64_t m = 32, int64_t n = 32, int64_t k = 32,\n            std::array<int64_t, 2> bs = {10, 10},\n            std::array<int64_t, 2> nr = {2, 2},\n            bool trans_b = false)\n        : type_a(type_a), 
type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), trans_b(trans_b) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * b;\n        if (trans_b) {\n            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);\n            b = ggml_transpose(ctx, b);\n        } else {\n            b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0]*nr[0], bs[1]*nr[1]);\n        }\n        ggml_set_name(b, \"b\");\n\n        ggml_tensor * out = ggml_out_prod(ctx, a, b);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_SQR\nstruct test_sqr : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_sqr(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_sqr(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    float grad_eps() override {\n        return 0.1f * 0.25f*ne[0]*ne[1]*ne[2]*ne[3]; // 10% of expected value of sum.\n    }\n};\n\n// GGML_OP_SQRT\nstruct test_sqrt : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_sqrt(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 3, 3, 2})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        
ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_sqrt(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        // fill with positive values\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, 50.0f, 100.0f);\n        }\n    }\n\n    float grad_eps() override {\n        return 20.0f;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_LOG\nstruct test_log : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_log(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_log(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            // log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass:\n            init_tensor_uniform(t, 0.9f, 1.1f);\n        }\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_SIN\nstruct test_sin : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_sin(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 2, 2, 2})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * 
build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_sin(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].\n        }\n    }\n\n    double max_maa_err() override {\n        return 1e-3;\n    }\n\n    float grad_eps() override {\n        return 0.2f;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_COS\nstruct test_cos : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_cos(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 2, 2, 2})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_cos(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].\n        }\n    }\n\n    double max_maa_err() override {\n        return 1e-3;\n    }\n\n    float grad_eps() override {\n        return 0.2f;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_CLAMP\nstruct test_clamp : public test_case {\n    const ggml_type type;\n    const 
std::array<int64_t, 4> ne;\n    float min;\n    float max;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, min, max);\n    }\n\n    test_clamp(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3},\n            float min = -0.5f, float max = 0.5f)\n        : type(type), ne(ne), min(min), max(max) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_clamp(ctx, a, min, max);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    float grad_eps() override {\n        return 1e-2f;\n    }\n\n    std::vector<float> grad_expect() override {\n        return {0.0f, 1.0f};\n    }\n};\n\n// GGML_OP_FLOOR\nstruct test_floor : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_floor(ggml_type type = GGML_TYPE_F32,\n               std::array<int64_t, 4> ne = {10, 2, 2, 2})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_floor(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -10.0f, 10.0f);\n        }\n    }\n};\n\n// GGML_OP_CEIL\nstruct test_ceil : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_ceil(ggml_type type = GGML_TYPE_F32,\n              
std::array<int64_t, 4> ne = {10, 2, 2, 2})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_ceil(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -10.0f, 10.0f);\n        }\n    }\n};\n\n// GGML_OP_ROUND\nstruct test_round : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_round(ggml_type type = GGML_TYPE_F32,\n               std::array<int64_t, 4> ne = {10, 2, 2, 2})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_round(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -10.0f, 10.0f);\n        }\n    }\n};\n\n// GGML_OP_TRUNC\nstruct test_trunc : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_trunc(ggml_type type = GGML_TYPE_F32,\n               std::array<int64_t, 4> ne = {10, 2, 2, 2})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = 
ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_trunc(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -10.0f, 10.0f);\n        }\n    }\n};\n\n// GGML_OP_DIAG_MASK_INF\nstruct test_diag_mask_inf : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const int n_past;\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne, n_past);\n    }\n\n    test_diag_mask_inf(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 10, 3, 2},\n            int n_past = 5)\n        : type(type), ne(ne), n_past(n_past) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_SOFT_MAX\nstruct test_soft_max : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const bool mask;\n    const bool sinks;\n    const ggml_type m_prec;\n    const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3\n    const float scale;\n    const float max_bias;\n    const bool inplace;\n\n    std::string vars() override {\n        return VARS_TO_STR9(type, ne, mask, sinks, m_prec, nr23, scale, max_bias, inplace);\n    }\n\n    // the 1024 test with bias occasionally fails:\n    // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL\n    virtual double max_nmse_err() override {\n        return 1e-6;\n    
}\n\n    test_soft_max(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3},\n            bool mask = false,\n            bool sinks = false,\n            ggml_type m_prec = GGML_TYPE_F32,\n            std::array<int64_t, 2> nr23 = {1, 1},\n            float scale = 1.0f,\n            float max_bias = 0.0f,\n            bool inplace = false)\n        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias), inplace(inplace) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * mask = nullptr;\n        if (this->mask) {\n            mask = ggml_new_tensor_4d(ctx, m_prec, ne[0], ne[1], ne[2], ne[3]);\n            ggml_set_name(mask, \"mask\");\n        }\n\n        ggml_tensor * sinks = nullptr;\n        if (this->sinks) {\n            sinks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]*nr23[0]);\n            ggml_set_name(sinks, \"sinks\");\n        }\n\n        ggml_tensor * out;\n        if (inplace) {\n            out = ggml_soft_max_ext_inplace(ctx, a, mask, scale, max_bias);\n        } else {\n            out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);\n        }\n        ggml_soft_max_add_sinks(out, sinks);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_SOFT_MAX_BACK\nstruct test_soft_max_back : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const float scale;\n    const float max_bias;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, scale, max_bias);\n    }\n\n    test_soft_max_back(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3},\n            float scale = 
1.0f,\n            float max_bias = 0.0f)\n        : type(type), ne(ne), scale(scale), max_bias(max_bias) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_soft_max_ext_back(ctx, a, b, scale, max_bias);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_ROPE + GGML_OP_ROPE_BACK\nstruct test_rope : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    int n_dims;\n    int mode;\n    int n_ctx; // used to generate positions\n    float fs; // freq_scale\n    float ef; // ext_factor\n    float af; // attn_factor\n    bool ff;\n    int v; // view (1 : non-contiguous a)\n    bool forward;\n    bool inplace;\n\n    std::string vars() override {\n        // forward can be inferred from the op, does not need to be printed\n        return VARS_TO_STR11(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v, inplace);\n    }\n\n    test_rope(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {10, 5, 3, 1},\n            int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,\n            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)\n        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a;\n        if (v & 1) {\n            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;\n            a = ggml_new_tensor(ctx, type, 4, ne.data());\n            if (forward) {\n                ggml_set_param(a);\n            }\n            ggml_set_name(a, \"a\");\n\n            
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view_of_a\");\n        } else {\n            a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n            if (forward) {\n                ggml_set_param(a);\n            }\n            ggml_set_name(a, \"a\");\n        }\n\n        const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;\n        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;\n\n        ggml_tensor * pos;\n        if (is_mrope || is_vision) {\n            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);\n        } else {\n            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);\n        }\n        ggml_set_name(pos, \"pos\");\n\n        ggml_tensor * freq = nullptr;\n        if (ff) {\n            freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);\n            ggml_set_name(freq, \"freq\");\n        }\n\n        ggml_tensor * out;\n        if (is_mrope) {\n            if (is_vision) {\n                GGML_ASSERT(n_dims/4 > 0);\n                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate\n                if (forward) {\n                    if (inplace) {\n                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n                    } else {\n                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n                    }\n                } else {\n                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n                }\n            } else {\n                GGML_ASSERT(n_dims/3 > 0);\n                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};\n                if (forward) {\n                    if (inplace) {\n        
                out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n                    } else {\n                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n                    }\n                } else {\n                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n                }\n            }\n        } else {\n            if (forward) {\n                if (inplace) {\n                    out = ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n                } else {\n                    out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n                }\n            } else {\n                out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);\n            }\n\n            // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp\n        }\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I32) {\n                // pos\n                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? 
ne_a[2] * 4 : ne_a[2];\n                std::vector<int> data(num_pos_ids);\n                for (int i = 0; i < num_pos_ids; i++) {\n                    data[i] = rand() % n_ctx;\n                }\n                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));\n            } else {\n                if (t->ne[0] == n_dims/2) {\n                    // frequency factors in the range [0.9f, 1.1f]\n                    init_tensor_uniform(t, 0.9f, 1.1f);\n                } else {\n                    init_tensor_uniform(t);\n                }\n            }\n        }\n    }\n\n    double max_maa_err() override {\n        return 1e-3;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_POOL2D\nstruct test_pool2d : public test_case {\n    enum ggml_op_pool pool_type;\n    const ggml_type type_input;\n    const std::array<int64_t, 4> ne_input;\n    // kernel size\n    const int k0;\n    const int k1;\n    // stride\n    const int s0;\n    const int s1;\n    // padding\n    const int p0;\n    const int p1;\n\n    std::string vars() override {\n        return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1);\n    }\n\n    test_pool2d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,\n            ggml_type type_input = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]\n            int k0 = 3, int k1 = 3,\n            int s0 = 1, int s1 = 1,\n            int p0 = 1, int p1 = 1)\n        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());\n        ggml_set_param(input);\n        ggml_set_name(input, \"input\");\n\n        ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);\n        
ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_POOL1D\nstruct test_pool1d : public test_case {\n    enum ggml_op_pool pool_type;\n    const ggml_type type_input;\n    const std::array<int64_t, 4> ne_input;\n    const int k0;\n    const int s0;\n    const int p0;\n\n    std::string vars() override {\n        return VARS_TO_STR6(pool_type, type_input, ne_input, k0, s0, p0);\n    }\n\n    test_pool1d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,\n                ggml_type type_input = GGML_TYPE_F32,\n                std::array<int64_t,4> ne_input = {10, 1, 1, 1},\n                int k0 = 3, int s0 = 3, int p0 = 0)\n        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), s0(s0), p0(p0) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());\n        ggml_set_param(input);\n        ggml_set_name(input, \"input\");\n\n        ggml_tensor * out = ggml_pool_1d(ctx, input, pool_type, k0, s0, p0);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_CONV_TRANSPOSE_1D\nstruct test_conv_transpose_1d : public test_case {\n    const std::array<int64_t, 4> ne_input;\n    const std::array<int64_t, 4> ne_kernel;\n\n    const int s0; // stride\n    const int p0; // padding\n    const int d0; // dilation\n\n    std::string vars() override {\n        return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);\n    }\n\n    test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_channels, 1 /* assert in cpu kernel*/, 1 (should be batch)]\n                           std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, output_channels, input_channels, 1 (should be batch)]\n                           int s0 = 1, int p0 = 0, int d0 = 1)\n        : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}\n\n    ggml_tensor * 
build_graph(ggml_context * ctx) override {\n        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());\n        ggml_set_name(input, \"input\");\n\n        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());\n        ggml_set_name(kernel, \"kernel\");\n\n        ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_CONV_TRANSPOSE_2D\nstruct test_conv_transpose_2d : public test_case {\n    const std::array<int64_t, 4> ne_input;\n    const std::array<int64_t, 4> ne_kernel;\n    const int stride;\n\n    std::string vars() override {\n        return VARS_TO_STR3(ne_input, ne_kernel, stride);\n    }\n\n    double max_nmse_err() override {\n        return 5e-4; // The default 1e-7 is too small for Vulkan.\n    }\n\n    test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]\n                           std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]\n                           int stride = 1)\n        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());\n        ggml_set_name(input, \"input\");\n\n        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());\n        ggml_set_name(kernel, \"kernel\");\n\n        ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_IM2COL\nstruct test_im2col : public test_case {\n    const ggml_type type_input;\n    const ggml_type type_kernel;\n    const ggml_type dst_type;\n    const std::array<int64_t, 4> ne_input;\n    const std::array<int64_t, 4> ne_kernel;\n    // 
stride\n    const int s0;\n    const int s1;\n    // padding\n    const int p0;\n    const int p1;\n    // dilation\n    const int d0;\n    const int d1;\n    // mode\n    const bool is_2D;\n\n    std::string vars() override {\n        return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);\n    }\n\n    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]\n            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]\n            int s0 = 1, int s1 = 1,\n            int p0 = 1, int p1 = 1,\n            int d0 = 1, int d1 = 1,\n            bool is_2D = true)\n        : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());\n        ggml_set_param(input);\n        ggml_set_name(input, \"input\");\n\n        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());\n        ggml_set_name(kernel, \"kernel\");\n\n        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_IM2COL_3D\nstruct test_im2col_3d : public test_case {\n    const ggml_type type_input;\n    const ggml_type type_kernel;\n    const ggml_type dst_type;\n    const std::array<int64_t, 4> ne_input;\n    const std::array<int64_t, 4> ne_kernel;\n    // stride\n    const int s0;\n    const int s1;\n    const int s2;\n    // padding\n    const int p0;\n    const int p1;\n    const int p2;\n    // dilation\n    
const int d0;\n    const int d1;\n    const int d2;\n\n    const int64_t IC;\n    const bool v;\n\n    std::string vars() override {\n        return VARS_TO_STR16(type_input, type_kernel, dst_type, ne_input, ne_kernel, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, v);\n    }\n\n    test_im2col_3d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,\n                std::array<int64_t, 4> ne_input = {10, 10, 10, 9}, // [OC*IC, KD, KH, KW]\n                std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [N*IC, ID, IH, IW]\n                int64_t IC = 3,\n                int s0 = 1, int s1 = 1, int s2 = 1,\n                int p0 = 1, int p1 = 1, int p2 = 1,\n                int d0 = 1, int d1 = 1, int d2 = 1,\n                bool v = false)\n        : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());\n        ggml_set_param(input);\n        ggml_set_name(input, \"input\");\n\n        if (v) {\n            input = ggml_view_4d(ctx, input, ne_input[0] - 2, ne_input[1] - 2, ne_input[2] - 2, ne_input[3] - 2, input->nb[1], input->nb[2], input->nb[3], 0);\n            ggml_set_name(input, \"view_of_input\");\n        }\n\n        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());\n        ggml_set_name(kernel, \"kernel\");\n\n        ggml_tensor * out = ggml_im2col_3d(ctx, kernel, input, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, dst_type);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// CONV_2D\nstruct test_conv_2d : public test_case {\n    const std::array<int64_t, 4> ne_input;\n    const std::array<int64_t, 4> ne_kernel;\n    const ggml_type              
type_kernel;\n    const int                    stride0;\n    const int                    stride1;\n    const int                    padding0;\n    const int                    padding1;\n    const int                    dilation0;\n    const int                    dilation1;\n    // Whether the inputs are contiguous in the channel dim or the width dim\n    const bool                   cwhn;\n\n    // If true, the direct CONV_2D will be used in the graph, otherwise it\n    // uses ggml_conv_2d:\n    // * if the program is called with -o CONV_2D_DIRECT_IMPL, the\n    // CONV_2D graph will be built, while\n    // * if the program is called with -o CONV_2D_INDIRECT_IMPL, the\n    // IM2COL -> MUL_MM graph will be built.\n\n    std::string vars() override {\n        return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);\n    }\n\n    double max_nmse_err() override {\n        return 5e-4;\n    }\n\n    uint64_t op_flops(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        // Just counting matmul costs:\n        // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops\n\n        // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)\n        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {\n            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;\n        };\n\n        int64_t W    = ne_input[0];\n        int64_t H    = ne_input[1];\n        int64_t KW   = ne_kernel[0];\n        int64_t KH   = ne_kernel[1];\n        int64_t Cin  = ne_kernel[2];\n        int64_t Cout = ne_kernel[3];\n        int64_t N    = ne_input[3];\n        int64_t OH   = calc_conv_output_size(H, KH, stride0, padding0, dilation0);\n        int64_t OW   = calc_conv_output_size(W, KW, stride0, padding0, dilation0);\n\n        int64_t K   = Cout;\n        int64_t CRS = Cin * KH * KW;\n        int64_t NPQ = N * OH * OW;\n\n        return K 
* NPQ * (2 * CRS - 1);\n    }\n\n    test_conv_2d(std::array<int64_t, 4> ne_input  = { 64, 64, 16, 1 },\n                 std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1,\n                 int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :\n        ne_input(ne_input),\n        ne_kernel(ne_kernel),\n        type_kernel(type_kernel),\n        stride0(stride0),\n        stride1(stride1),\n        padding0(padding0),\n        padding1(padding1),\n        dilation0(dilation0),\n        dilation1(dilation1),\n        cwhn(cwhn) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());\n        ggml_set_name(input, \"input\");\n\n        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());\n        ggml_set_name(kernel, \"kernel\");\n\n        if (cwhn) {\n            // change memory layout to channel-most-contiguous (CWHN),\n            // then permute it back so NE matches the original input\n            input  = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));\n            input  = ggml_permute(ctx, input, 2, 0, 1, 3);\n            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));\n            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);\n        }\n\n        ggml_tensor * out =\n            ggml_conv_2d_direct(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1);\n        ggml_set_name(out, \"out\");\n        return out;\n    }\n};\n\n// GGML_OP_CONV_2D_DW\nstruct test_conv_2d_dw : public test_case {\n    const std::array<int64_t, 4> ne_input;\n    const std::array<int64_t, 4> ne_kernel;\n    const int stride;\n    const int padding;\n    const int dilation;\n    const bool cwhn;\n\n    std::string vars() override {\n        return VARS_TO_STR6(ne_input, ne_kernel, stride, 
padding, dilation, cwhn);\n    }\n\n    test_conv_2d_dw(std::array<int64_t, 4> ne_input = {64, 64, 16, 1},\n            std::array<int64_t, 4> ne_kernel = {3, 3, 1, 16},\n            int stride = 1, int padding = 0, int dilation = 1, bool cwhn = false)\n        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), padding(padding), dilation(dilation), cwhn(cwhn) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());\n        ggml_set_name(input, \"input\");\n\n        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());\n        ggml_set_name(kernel, \"kernel\");\n\n        if (cwhn) {\n            // change memory layout to channel-most-contiguous (CWHN),\n            // then permute it back so NE matches the original input\n            input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));\n            input = ggml_permute(ctx, input, 2, 0, 1, 3);\n            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));\n            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);\n        }\n\n        ggml_tensor * out = ggml_conv_2d_dw_direct(\n            ctx, kernel, input,\n            stride, stride, padding, padding, dilation, dilation);\n        ggml_set_name(out, \"out\");\n        return out;\n    }\n};\n\n// GGML_OP_CONV_3D\nstruct test_conv_3d : public test_case {\n    // Logical 5D dimensions\n    const int64_t N, IC, ID, IH, IW;\n    const int64_t OC, KD, KH, KW;\n    // Conv params\n    const int s0, s1, s2;\n    const int p0, p1, p2;\n    const int d0, d1, d2;\n    // Types\n    const ggml_type type_kernel;\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"CONV_3D\";\n    }\n\n    std::string vars() override {\n        return VARS_TO_STR11(N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1) + \",\" +\n               VARS_TO_STR8(s2, p0, p1, p2, d0, d1, d2, 
type_kernel);\n    }\n\n    double max_nmse_err() override {\n        return 5e-4;\n    }\n\n    uint64_t op_flops(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {\n            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;\n        };\n        const int64_t OD = calc_conv_output_size(ID, KD, s2, p2, d2);\n        const int64_t OH = calc_conv_output_size(IH, KH, s1, p1, d1);\n        const int64_t OW = calc_conv_output_size(IW, KW, s0, p0, d0);\n\n        return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1);\n    }\n\n    test_conv_3d(\n        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,\n        int64_t OC, int64_t KD, int64_t KH, int64_t KW,\n        int s0, int s1, int s2,\n        int p0, int p1, int p2,\n        int d0, int d1, int d2,\n        ggml_type type_kernel\n    ) : N(N), IC(IC), ID(ID), IH(IH), IW(IW),\n        OC(OC), KD(KD), KH(KH), KW(KW),\n        s0(s0), s1(s1), s2(s2),\n        p0(p0), p1(p1), p2(p2),\n        d0(d0), d1(d1), d2(d2),\n        type_kernel(type_kernel) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        // GGML input tensor is packed as [W, H, D, C*N]\n        const int64_t ne_input[] = {IW, IH, ID, IC * N};\n        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input);\n        ggml_set_name(input, \"input\");\n\n        // GGML kernel tensor is packed as [KW, KH, KD, IC*OC]\n        const int64_t ne_kernel[] = {KW, KH, KD, IC * OC};\n        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);\n        ggml_set_name(kernel, \"kernel\");\n\n        ggml_tensor * out = ggml_conv_3d_direct(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);\n        ggml_set_name(out, \"out\");\n        return out;\n    }\n};\n\n// GGML_OP_CONCAT\nstruct test_concat : public test_case {\n    const ggml_type type;\n    
const std::array<int64_t, 4> ne_a;\n    const int64_t ne_b_d;\n    const int dim;\n    const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);\n    }\n\n    test_concat(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {10, 5, 5, 5},\n            int64_t ne_b_d = 5,\n            int dim = 2, int v = 0)\n        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        auto ne_b = ne_a;\n        ne_b[dim] = ne_b_d;\n        ggml_tensor * a;\n        if (v & 1) {\n            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;\n            a = ggml_new_tensor(ctx, type, 4, ne.data());\n            ggml_set_name(a, \"a\");\n\n            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view_of_a\");\n        } else {\n            a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n            ggml_set_name(a, \"a\");\n        }\n        ggml_tensor * b;\n        if (v & 2) {\n            auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;\n            b = ggml_new_tensor(ctx, type, 4, ne.data());\n            ggml_set_name(b, \"b\");\n\n            b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);\n            ggml_set_name(b, \"view_of_b\");\n        } else {\n            b = ggml_new_tensor(ctx, type, 4, ne_b.data());\n            ggml_set_name(b, \"b\");\n        }\n\n        ggml_tensor * out = ggml_concat(ctx, a, b, dim);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_ARGSORT\nstruct test_argsort : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    ggml_sort_order order;\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne, order);\n    }\n\n    
test_argsort(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {16, 10, 10, 10},\n            ggml_sort_order order = GGML_SORT_ORDER_ASC)\n        : type(type), ne(ne), order(order) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_argsort(ctx, a, order);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        std::random_device rd;\n        std::default_random_engine rng(rd());\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I32) {\n                // indices\n                std::vector<int> data(ggml_nelements(t));\n                for (int i = 0; i < ggml_nelements(t); i++) {\n                    data[i] = rand();\n                }\n                std::shuffle(data.begin(), data.end(), rng);\n                ggml_backend_tensor_set(t, data.data(), 0, ne[0]*ne[1]*ne[2]*ne[3] * sizeof(int));\n            } else if (t->type == GGML_TYPE_F32) {\n                // initialize with unique values to avoid ties\n                for (int64_t r = 0; r < ggml_nrows(t); r++) {\n                    std::vector<float> data(t->ne[0]);\n                    for (int i = 0; i < t->ne[0]; i++) {\n                        data[i] = i;\n                    }\n                    std::shuffle(data.begin(), data.end(), rng);\n                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));\n                }\n            } else {\n                GGML_ABORT(\"fatal error\");\n            }\n        }\n    }\n};\n\n// GGML_OP_TOP_K\nstruct test_top_k : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const int k;\n    const bool ties;\n    ggml_tensor * 
input {};\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, k, ties);\n    }\n\n    test_top_k(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {16, 10, 10, 10},\n            int k = 4, bool ties = false)\n        : type(type), ne(ne), k(k), ties(ties) {}\n\n    double max_err() override {\n        return 0.0;\n    }\n\n    // When there are ties, only validate the final result.\n    // The logic in err can't handle the sentinel tensors.\n    bool run_whole_graph() override { return ties; }\n\n    double err(const float * a, const float * b, size_t n) override {\n        // When there are no ties, we expect the exact same set of indices,\n        // but possibly in a different order. When there are ties, the indices\n        // can be different but the input values they correspond to should be\n        // the same. The logic for ties could work for non-ties, but only for\n        // the output tensor, not for the sentinel tensors.\n        if (ties) {\n            std::vector<float> src(ggml_nelements(input));\n\n            ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));\n\n            double diff = 0.0f;\n\n            GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));\n            int64_t cols = input->ne[0];\n            std::vector<int32_t> ia(k);\n            std::vector<int32_t> ib(k);\n            std::vector<float> asrc(k);\n            std::vector<float> bsrc(k);\n            for (int64_t r = 0; r < ggml_nrows(input); r++) {\n                // Convert indices for the row back to integer\n                for (int64_t c = 0; c < k; c++) {\n                    ia[c] = (int32_t)a[r * k + c];\n                    ib[c] = (int32_t)b[r * k + c];\n                }\n                // The src values for each row should match.\n                for (int64_t c = 0; c < k; c++) {\n                    asrc[c] = src[r * cols + ia[c]];\n                    bsrc[c] = src[r * 
cols + ib[c]];\n                }\n                diff += jdst(asrc.data(), bsrc.data(), k);\n                // There should be no duplicate indices\n                std::sort(ia.begin(), ia.end());\n                std::sort(ib.begin(), ib.end());\n                if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {\n                    diff += 1;\n                }\n                if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {\n                    diff += 1;\n                }\n            }\n            return diff;\n        } else {\n            std::vector<int32_t> ia(n);\n            std::vector<int32_t> ib(n);\n\n            double diff = 0.0f;\n\n            for (size_t i = 0; i < n; i++) {\n                ia[i] = (int32_t) a[i];\n                ib[i] = (int32_t) b[i];\n\n                // penalize the result if the data is not integer valued\n                diff += std::fabs(a[i] - ia[i]);\n                diff += std::fabs(b[i] - ib[i]);\n            }\n\n            return diff + jdst(ia.data(), ib.data(), n);\n        }\n    }\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        // Save 'a' for err()\n        input = a;\n\n        ggml_tensor * out = ggml_top_k(ctx, a, k);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        std::random_device rd;\n        std::default_random_engine rng(rd());\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            int tie_denom = std::max(1, std::min(10, k / 2));\n            for (int64_t r = 0; r < ggml_nrows(t); r++) {\n                std::vector<float> data(t->ne[0]);\n                for (int i = 0; i < t->ne[0]; i++) {\n                    if (ties) {\n                        // integer division to introduce 
duplicates\n                        data[i] = i / tie_denom;\n                    } else {\n                        data[i] = i;\n                    }\n                }\n                std::shuffle(data.begin(), data.end(), rng);\n                ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));\n            }\n        }\n    }\n};\n\nenum MoeGatingFunc {\n    GATING_FUNC_SOFTMAX,\n    GATING_FUNC_SIGMOID,\n    GATING_FUNC_SOFTMAX_WEIGHT,\n};\n\nstruct test_topk_moe : public test_case {\n    const std::array<int64_t, 4> ne;\n    const int n_expert_used;\n    const bool with_norm;\n    const bool bias_probs;\n    const MoeGatingFunc gating_func;\n    const float scale_w;\n    ggml_tensor * weights {};\n    ggml_tensor * selected_experts {};\n\n    test_topk_moe(std::array<int64_t, 4> ne              = { 10, 5, 1, 1 },\n                  int                    n_expert_used   = 1,\n                  bool                   with_norm       = false,\n                  bool                   bias_probs      = false,\n                  MoeGatingFunc          gating_func     = GATING_FUNC_SOFTMAX,\n                  float                  scale_w         = 0.0f) :\n        ne(ne),\n        n_expert_used(n_expert_used),\n        with_norm(with_norm),\n        bias_probs(bias_probs),\n        gating_func(gating_func),\n        scale_w(scale_w) {\n        GGML_ASSERT(n_expert_used <= ne[0]);\n    }\n\n    std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"TOPK_MOE\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        const int n_expert = ne[0];\n        const int n_tokens = ne[1];\n\n        ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());\n        ggml_tensor * probs       
     =\n            (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :\n            (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;\n        ggml_set_name(probs, \"probs\");\n\n        ggml_tensor * selection_probs = probs;\n        if (bias_probs) {\n            ggml_tensor * exp_probs_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);\n            ggml_set_name(exp_probs_b, \"exp_probs_b\");\n            selection_probs = ggml_add(ctx, probs, exp_probs_b);\n            ggml_set_name(selection_probs, \"selection_probs\");\n        }\n\n        selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]\n        ggml_set_name(selected_experts, \"selected_experts\");\n\n        weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]\n        ggml_set_name(weights, \"weights\");\n\n        if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {\n            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);\n            weights = ggml_soft_max(ctx, weights);  // [n_expert_used, n_tokens]\n            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);\n        }\n\n        if (with_norm) {\n            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);\n            ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]\n            ggml_set_name(weights_sum, \"weights_sum\");\n\n            weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);\n            weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]\n            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);\n        }\n\n        if (scale_w) {\n            weights = ggml_scale(ctx, weights, scale_w);\n        }\n\n        ggml_set_name(weights, \"weights\");\n        return weights;\n    }\n    // Verify two outputs\n    
std::vector<ggml_tensor *> fusion_test_nodes() override { return { selected_experts, weights }; }\n\n    // allow output in arbitrary order\n    double err(const float * a, const float * b, size_t n) override {\n        std::vector<float> a2(n);\n        std::vector<float> b2(n);\n        for (size_t i = 0; i < n; ++i) {\n            a2[i] = a[i];\n            b2[i] = b[i];\n        }\n        std::sort(a2.begin(), a2.end());\n        std::sort(b2.begin(), b2.end());\n        return nmse(a2.data(), b2.data(), n);\n    }\n};\n\nstruct test_mul_mat_vec_fusion : public test_case {\n    const ggml_type type;\n    const ggml_glu_op glu_op;\n    const int64_t m;\n    const int64_t n;\n    const int64_t k;\n    const bool use_id;\n    const int n_mats;\n    const int n_used;\n    const bool b;        // broadcast b matrix (only for use_id)\n    const bool with_bias;\n    const bool with_gate;\n    std::array<int64_t, 2> batch_dims;\n\n    test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,\n                        bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true,\n                        std::array<int64_t, 2> batch_dims = {4, 2})\n    : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate), batch_dims(batch_dims) {\n        if (use_id) {\n            GGML_ASSERT(n_used <= n_mats);\n        }\n    }\n\n    std::string vars() override {\n        return VARS_TO_STR12(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate, batch_dims);\n    }\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"MUL_MAT_VEC_FUSION\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) {\n        ggml_tensor * out = nullptr;\n        if 
(with_gate) {\n            if (glu_op == GGML_GLU_OP_SWIGLU_OAI) {\n                constexpr float alpha = 1.702f;\n                constexpr float limit = 7.0f;\n                out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit);\n            } else {\n                out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);\n            }\n        }\n        return out;\n    }\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        if (!use_id) {\n            const int              channels = batch_dims[0];\n            const int              samples  = batch_dims[1];\n            std::array<int64_t, 4> ne       = { k, m, channels, samples };\n            std::array<int64_t, 4> ne0      = { k, n, channels, samples };\n\n            ggml_tensor * cur  = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());\n            ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr;\n            ggml_tensor * up   = ggml_new_tensor(ctx, type, 4, ne0.data());\n\n            ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);\n            if (with_bias) {\n                std::array<int64_t, 4> bias_ne = { ffn_up->ne[0], 1, channels, samples };\n                ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());\n                ffn_up = ggml_add(ctx, ffn_up, up_bias);\n            }\n\n            ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr;\n            if (with_bias && with_gate) {\n                std::array<int64_t, 4> bias_ne   = { ffn_gate->ne[0], 1, channels, samples };\n                ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());\n                ffn_gate = ggml_add(ctx, ffn_gate, gate_bias);\n            }\n\n            ggml_tensor * out = with_gate ? 
build_gate(ctx, ffn_gate, ffn_up) : ffn_up;\n\n            std::array<int64_t, 4> bias2_ne   = { out->ne[0], 1, channels, samples };\n            ggml_tensor * bias2 = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias2_ne.data());\n            out = ggml_add(ctx, out, bias2);\n\n            ggml_set_name(out, \"out\");\n            return out;\n        } else {\n            ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats);\n            ggml_tensor * ups   = ggml_new_tensor_3d(ctx, type, k, n, n_mats);\n            ggml_tensor * ids   = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m);\n\n            if (n_used != n_mats) {\n                ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0);\n            }\n\n            ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);\n            ggml_set_name(cur, \"cur\");\n\n            ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);\n            if (with_bias) {\n                ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats);\n                ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids);\n            }\n\n            ggml_tensor * ffn_gate = with_gate? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr;\n            if (with_bias && with_gate) {\n                ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats);\n                ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids);\n            }\n\n            ggml_tensor * out = with_gate ? 
build_gate(ctx, ffn_gate, ffn_up) : ffn_up;\n\n            std::array<int64_t, 4> scale_ne { 1, out->ne[1], out->ne[2], out->ne[3] };\n            ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, scale_ne.data());\n            out = ggml_mul(ctx, out, scale);\n\n            ggml_set_name(out, \"out\");\n            return out;\n        }\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        if (!use_id) {\n            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n                init_tensor_uniform(t);\n            }\n        } else {\n            init_mul_mat_id_tensors(ctx, n_mats);\n        }\n    }\n\n    double max_nmse_err() override {\n        return 5e-3;\n    }\n};\n\n// GGML_OP_SUM\nstruct test_sum : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const std::array<int64_t, 4> permute;\n    bool _use_permute;\n\n    std::string vars() override {\n        std::string v = VARS_TO_STR2(type, ne);\n        if (_use_permute) v += \",\" + VAR_TO_STR(permute);\n        return v;\n    }\n\n    test_sum(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3},\n            std::array<int64_t, 4> permute = {0, 0, 0, 0})\n        : type(type), ne(ne), permute(permute),\n            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        if (_use_permute) {\n            a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);\n            ggml_set_name(a, \"a_permuted\");\n        }\n\n        ggml_tensor * out = ggml_sum(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    float grad_eps() override {\n        return 0.1f * 
sqrtf(ne[0]*ne[1]*ne[2]*ne[3]);\n    }\n\n    // Don't center the distribution around zero. Helps to avoid catastrophic cancellation.\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -0.9f, 1.1f);\n        }\n    }\n};\n\n// GGML_OP_SUM_ROWS\nstruct test_sum_rows : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const bool permute;\n    const bool slice;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, permute, slice);\n    }\n\n    test_sum_rows(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3},\n            bool permute = false, bool slice = false)\n        : type(type), ne(ne), permute(permute), slice(slice) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        if (slice) {\n            a = ggml_view_4d(ctx, a,\n                             ne[0], ne[1], ne[2] / 2, ne[3] - 1,\n                             a->nb[1], a->nb[2] * 2, a->nb[3], /*offset=*/a->nb[3]);\n        }\n        if (permute) {\n            a = ggml_permute(ctx, a, 0, 2, 3, 1);\n        }\n\n        ggml_tensor * out = ggml_sum_rows(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_MEAN\nstruct test_mean : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_mean(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n     
   ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_mean(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    float grad_eps() override {\n        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];\n    }\n\n    // Don't center the distribution around zero. Helps to avoid catastrophic cancellation.\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -0.9f, 1.1f);\n        }\n    }\n};\n\n// GGML_OP_UPSCALE\nstruct test_upscale : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const int32_t scale_factor;\n    const bool transpose;\n    const ggml_scale_mode mode;\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);\n    }\n\n    test_upscale(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {512, 512, 3, 1},\n            int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)\n        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        if (transpose) {\n            a = ggml_transpose(ctx, a);\n            ggml_set_name(a, \"a_transposed\");\n        }\n\n        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_UPSCALE (via ggml_interpolate)\nstruct test_interpolate : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const std::array<int64_t, 4> ne_tgt;\n    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;\n\n    std::string vars() override {\n        
return VARS_TO_STR4(type, ne, ne_tgt, mode);\n    }\n\n    test_interpolate(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne     = {2, 5,  7, 11},\n            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},\n            ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)\n        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_interpolate(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_GROUP_NORM\nstruct test_group_norm : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const int32_t num_groups;\n    const float eps;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, num_groups, eps);\n    }\n\n    test_group_norm(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {64, 64, 320, 1},\n            int32_t num_groups = 32,\n            float eps = 1e-6f)\n        : type(type), ne(ne), num_groups(num_groups), eps(eps) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_GROUP_NORM + GGML_OP_MUL + GGML_OP_ADD\nstruct test_group_norm_mul_add : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    int num_groups;\n    float eps;\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"GROUP_NORM_MUL_ADD\";\n    }\n\n    bool run_whole_graph() override { return true; }\n\n    std::string vars() override {\n        return 
VARS_TO_STR4(type, ne, num_groups, eps);\n    }\n\n    test_group_norm_mul_add(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {128, 1, 1, 1},\n            int num_groups = 4,\n            float eps = 1e-5f)\n        : type(type), ne(ne), num_groups(num_groups), eps(eps) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_tensor * w = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(a); ggml_set_param(w); ggml_set_param(b);\n        ggml_set_name(a, \"a\"); ggml_set_name(w, \"w\"); ggml_set_name(b, \"b\");\n        ggml_tensor * n = ggml_group_norm(ctx, a, num_groups, eps);\n        ggml_tensor * m = ggml_mul(ctx, n, w);\n        ggml_tensor * out = ggml_add(ctx, m, b);\n        ggml_set_name(out, \"out\");\n        return out;\n    }\n};\n\n// GGML_OP_L2_NORM\nstruct test_l2_norm : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const float eps;\n    bool v;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne, eps, v);\n    }\n\n    test_l2_norm(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {64, 64, 320, 1},\n            float eps = 1e-12f,\n            bool v = false)\n        : type(type), ne(ne), eps(eps), v(v) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(a, \"a\");\n\n        if (v) {\n            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view of a\");\n        }\n\n        ggml_tensor * out = ggml_l2_norm(ctx, a, eps);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_ACC\nstruct test_acc : public test_case {\n    const 
ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    const std::array<int64_t, 4> ne_b;\n    const int64_t stride_dim;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne_a, ne_b, stride_dim);\n    }\n\n    test_acc(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {256, 17, 2, 3},\n            std::array<int64_t, 4> ne_b = {256, 16, 2, 3},\n            uint64_t stride_dim = -1)\n        : type(type), ne_a(ne_a), ne_b(ne_b), stride_dim(stride_dim) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * b;\n        if (stride_dim == 1 || stride_dim == 2 || stride_dim == 3) {\n            // Create a larger tensor and take a view at a non-zero offset.\n            // This tests that the backend correctly handles b's data offset\n            std::array<int64_t, 4> ne_b_pad = {ne_b[0], ne_b[1], ne_b[2], ne_b[3]};\n            ne_b_pad[stride_dim] += 1;\n            ggml_tensor * b_pad = ggml_new_tensor(ctx, type, 4, ne_b_pad.data());\n            ggml_set_param(b_pad);\n            ggml_set_name(b_pad, \"b_pad\");\n            // View that skips the first row, so b has a non-zero byte offset\n            b = ggml_view_4d(ctx, b_pad,\n                ne_b[0], ne_b[1], ne_b[2], ne_b[3],\n                b_pad->nb[1], b_pad->nb[2], b_pad->nb[3],\n                b_pad->nb[1]);\n        } else {\n            b = ggml_new_tensor(ctx, type, 4, ne_b.data());\n            ggml_set_param(b);\n        }\n        ggml_set_name(b, \"b\");\n\n        // When ne_b[0] < ne_a[0], a->nb[1] != b->nb[1], so the stride\n        // parameters to ggml_acc don't match b's natural stride.\n        ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_PAD\nstruct 
test_pad : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    const int pad_0;\n    const int pad_1;\n    const bool circular;\n\n    std::string vars() override {\n        return VARS_TO_STR5(type, ne_a, pad_0, pad_1, circular);\n    }\n\n    test_pad(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {512, 512, 1, 1},\n            int pad_0 = 1, int pad_1 = 1, bool circular = false)\n        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1), circular(circular) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = circular\n            ? ggml_pad_circular(ctx, a, pad_0, pad_1, 0, 0)\n            : ggml_pad(ctx, a, pad_0, pad_1, 0, 0);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_PAD (with extension)\nstruct test_pad_ext : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    const int lp0;\n    const int rp0;\n    const int lp1;\n    const int rp1;\n    const int lp2;\n    const int rp2;\n    const int lp3;\n    const int rp3;\n    const int tfrm; // 0 - none, 1 - non-cont, 2 - perm\n    const bool circular;\n\n    std::string vars() override {\n        return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, tfrm, circular);\n    }\n\n    test_pad_ext(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {512, 512, 3, 1},\n            int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,\n            int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,\n            int tfrm = 0, bool circular = false)\n        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),\n          tfrm(tfrm), circular(circular) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor 
* a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n        ggml_set_name(a, \"a\");\n\n        if (tfrm == 1) {\n            a = ggml_view_4d(ctx, a, (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2, (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 2, a->nb[1], a->nb[2], a->nb[3], 0);\n            ggml_set_name(a, \"view of a\");\n        } else if (tfrm == 2) {\n            a = ggml_permute(ctx, a, 2, 1, 0, 3);\n            ggml_set_name(a, \"permuted a\");\n        }\n\n        ggml_tensor * out = circular\n            ? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)\n            : ggml_pad_ext         (ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_PAD_REFLECT_1D\nstruct test_pad_reflect_1d : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    const int pad_0;\n    const int pad_1;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);\n    }\n\n    test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {512, 34, 2, 1},\n            int pad_0 = 10, int pad_1 = 9)\n        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1)  {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_ROLL\nstruct test_roll : public test_case {\n    const int shift0;\n    const int shift1;\n    const int shift3;\n    const int shift4;\n\n    std::string vars() override {\n        return VARS_TO_STR4(shift0, shift1, shift3, shift4);\n    }\n\n    test_roll(int shift0 = 3, int shift1 = -2, int shift3 = 1, int shift4 = -1)\n        : shift0(shift0), shift1(shift1), shift3(shift3), shift4(shift4) {}\n\n    
ggml_tensor * build_graph(ggml_context * ctx) override {\n        int64_t ne[4] = {10, 5, 4, 3};\n        ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_roll(ctx, a, shift0, shift1, shift3, shift4);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_ARANGE\nstruct test_arange : public test_case {\n    const ggml_type type;\n    const float start;\n    const float stop;\n    const float step;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, start, stop, step);\n    }\n\n    test_arange(ggml_type type = GGML_TYPE_F32,\n            float start = 0.f, float stop = 10.f, float step = 1.f)\n        : type(type), start(start), stop(stop), step(step)  {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * out = ggml_arange(ctx, start, stop, step);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_TIMESTEP_EMBEDDING\nstruct test_timestep_embedding : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne_a;\n    const int dim;\n    const int max_period;\n\n    std::string vars() override {\n        return VARS_TO_STR4(type, ne_a, dim, max_period);\n    }\n\n    test_timestep_embedding(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {2, 1, 1, 1},\n            int dim = 320, int max_period=10000)\n        : type(type), ne_a(ne_a), dim(dim), max_period(max_period)  {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_LEAKY_RELU\nstruct test_leaky_relu : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> 
ne_a;\n    const float negative_slope;\n\n    std::string vars() override {\n        return VARS_TO_STR3(type, ne_a, negative_slope);\n    }\n\n    test_leaky_relu(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_a = {10, 5, 4, 3},\n            float negative_slope = 0.1f)\n        : type(type), ne_a(ne_a), negative_slope(negative_slope)  {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_FLASH_ATTN_EXT\nstruct test_flash_attn_ext : public test_case {\n    const int64_t hsk; // K head size\n    const int64_t hsv; // V head size\n    const int64_t nh; // num heads\n    const std::array<int64_t, 2> nr23; // repeat in dim 2 and 3, tests for grouped-query attention\n    const int64_t kv; // kv size\n    const int64_t nb; // batch size\n\n    const bool mask; // use mask\n    const bool sinks; // use sinks\n\n    const float max_bias; // ALiBi\n    const float logit_softcap; // Gemma 2\n\n    const ggml_prec prec;\n    const ggml_type type_KV;\n    std::array<int32_t, 4> permute;\n\n    std::string vars() override {\n        return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);\n    }\n\n    double max_nmse_err() override {\n        return 5e-4;\n    }\n\n    uint64_t op_flops(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        // Just counting matmul costs:\n        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head\n        return (2 * nh*nr23[0] * nb * (hsk + hsv) * kv)*nr23[1];\n    }\n\n    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,\n                        bool mask = true, bool sinks = 
false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,\n                        ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})\n        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));\n        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));\n\n        auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * {\n            int64_t ne[4] = {ne0, ne1, ne2, ne3};\n            int64_t ne_perm[4];\n            for (int i = 0; i < 4; ++i) {\n                ne_perm[permute[i]] = ne[i];\n            }\n            ggml_tensor * t;\n            if (is_view) {\n                ggml_tensor * t0 = ggml_new_tensor_4d(ctx, type, ne_perm[0], 2*ne_perm[1], ne_perm[2], ne_perm[3]);\n                t = ggml_view_4d(ctx, t0, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3], t0->nb[1], t0->nb[2], t0->nb[3], 0);\n            } else {\n                t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);\n            }\n            if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {\n                t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);\n            }\n            return t;\n        };\n\n        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false);\n        ggml_set_name(q, \"q\");\n\n        ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,         nr23[1], true); // the K tensor is usually a view of the K cache\n        ggml_set_name(k, \"k\");\n\n        ggml_tensor * v = nullptr;\n        if (hsk_padded == 576 && hsv_padded == 
512) {\n            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes\n\n            // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models\n            // for more info:\n            //   - https://github.com/ggml-org/llama.cpp/pull/13435\n            //   - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392\n            //   - https://github.com/ggml-org/llama.cpp/pull/18986\n            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);\n        } else {\n            v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache\n        }\n        ggml_set_name(v, \"v\");\n\n        ggml_tensor * m = nullptr;\n        if (mask) {\n            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, nr23[1]);\n            ggml_set_name(m, \"m\");\n        }\n\n        ggml_tensor * s = nullptr;\n        if (sinks) {\n            s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, q->ne[2]);\n            ggml_set_name(s, \"s\");\n        }\n\n        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);\n        ggml_flash_attn_ext_add_sinks(out, s);\n        ggml_flash_attn_ext_set_prec (out, prec);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (strcmp(t->name, \"s\") == 0) {\n                // make the sink values more noticeable in order to trigger a test failure when the implementation is wrong\n                init_tensor_uniform(t, -10.0f, 10.0f);\n            } else if (strcmp(t->name, \"m\") == 0) {\n                init_tensor_kq_mask(t);\n            } else {\n                
init_tensor_uniform(t);\n            }\n        }\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_CROSS_ENTROPY_LOSS\nstruct test_cross_entropy_loss : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_param(logits);\n        ggml_set_name(logits, \"logits\");\n\n        ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());\n        // The labels are assumed to be constant -> no gradients.\n        ggml_set_name(labels, \"labels\");\n\n        // Ensure labels add up to 1:\n        labels = ggml_soft_max(ctx, labels);\n        ggml_set_name(labels, \"labels_normalized\");\n\n        ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        // For larger abs. diffs between logits softmax is more linear, therefore more precise num. 
gradients.\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -100.0f, 100.0f);\n        }\n    }\n\n    float grad_eps() override {\n        return 1.0f;\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_CROSS_ENTROPY_LOSS_BACK\nstruct test_cross_entropy_loss_back : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_cross_entropy_loss_back(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n        ggml_set_name(grad, \"grad\");\n\n        ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(logits, \"logits\");\n\n        ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());\n        ggml_set_name(labels, \"labels\");\n\n        // Ensure labels add up to 1:\n        labels = ggml_soft_max(ctx, labels);\n        ggml_set_name(labels, \"labels_normalized\");\n\n        ggml_tensor * out = ggml_cross_entropy_loss_back(ctx, grad, logits, labels);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_OPT_STEP_ADAMW\nstruct test_opt_step_adamw : public test_case {\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override {\n        return VARS_TO_STR2(type, ne);\n    }\n\n    test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = {10, 5, 4, 3})\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_param(a); 
// Despite tensor a having gradients the output tensor will not.\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_name(grad, \"grad\");\n\n        ggml_tensor * grad_m = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_name(grad_m, \"grad_m\");\n\n        ggml_tensor * grad_v = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_name(grad_v, \"grad_v\");\n\n        ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7);\n        ggml_set_name(adamw_params, \"adamw_params\");\n\n        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, grad_m, grad_v, adamw_params);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values.\n        }\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_OPT_STEP_SGD\nstruct test_opt_step_sgd : public test_case {\n    const ggml_type              type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override { return VARS_TO_STR2(type, ne); }\n\n    test_opt_step_sgd(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = { 10, 5, 4, 3 })\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_param(a);  // Despite tensor a having gradients the output tensor will not.\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_name(grad, \"grad\");\n\n        ggml_tensor * sgd_params = 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);\n        ggml_set_name(sgd_params, \"sgd_params\");\n\n        ggml_tensor * out = ggml_opt_step_sgd(ctx, a, grad, sgd_params);\n\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, 0.0f, 1.0f);  // sgd_params need non-negative values.\n        }\n    }\n\n    bool grad_precise() override {\n        return true;\n    }\n};\n\n// GGML_OP_CUMSUM\nstruct test_cumsum : public test_case {\n    const ggml_type              type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override { return VARS_TO_STR2(type, ne); }\n\n    test_cumsum(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = { 10, 5, 4, 3 })\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_cumsum(ctx, a);\n\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -1.0f, 1.0f);\n        }\n    }\n};\n\n// GGML_OP_XIELU\nstruct test_xielu : public test_case {\n    const ggml_type              type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override { return VARS_TO_STR2(type, ne); }\n\n    test_xielu(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = { 10, 5, 4, 3 })\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], 
ne[2], ne[3]);\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        float alpha_n = 4.0f;\n        float alpha_p = 20.0f;\n        float beta = 0.5f;\n        float eps = 0.0000001f;\n\n        ggml_tensor * out = ggml_xielu(ctx, a, alpha_n, alpha_p, beta, eps);\n\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -1.0f, 1.0f);\n        }\n    }\n};\n\n// GGML_OP_TRI\nstruct test_tri : public test_case {\n    const ggml_type              type;\n    const std::array<int64_t, 4> ne;\n    const ggml_tri_type          tri_type;\n\n    std::string vars() override { return VARS_TO_STR3(type, ne, tri_type); }\n\n    test_tri(ggml_tri_type tri_type, ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = { 10, 10, 4, 3 })\n        : type(type), ne(ne), tri_type(tri_type) {\n            GGML_ASSERT(ne[0] == ne[1]);\n        }\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_tri(ctx, a, tri_type);\n\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            init_tensor_uniform(t, -1.0f, 1.0f);\n        }\n    }\n};\n\n// GGML_OP_FILL\nstruct test_fill : public test_case {\n    const ggml_type              type;\n    const std::array<int64_t, 4> ne;\n    float                        c;\n\n    std::string vars() override { return VARS_TO_STR3(type, ne, c); }\n\n    test_fill(float c, ggml_type type = GGML_TYPE_F32,\n            
std::array<int64_t, 4> ne = { 10, 10, 4, 3 })\n        : type(type), ne(ne), c(c) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_fill(ctx, a, c);\n\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// GGML_OP_SOLVE_TRI\nstruct test_solve_tri : public test_case {\n    const ggml_type              type;\n    const std::array<int64_t, 4> ne_lhs;\n    const std::array<int64_t, 4> ne_rhs;\n\n    std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }\n\n    uint64_t op_flops(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        int64_t n = ne_lhs[0];\n        int64_t k = ne_rhs[0];\n        int64_t batch = ne_lhs[2] * ne_lhs[3];\n        // n * (n + 1) / 2 non-zero elements of lhs, 2 flops each, for each col of rhs\n        return n * (n + 1) * k * batch;\n    }\n\n    test_solve_tri(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },\n            std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }\n        )\n        : type(type), ne_lhs(ne_lhs), ne_rhs(ne_rhs) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne_lhs[0], ne_lhs[1], ne_lhs[2], ne_lhs[3]);\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne_rhs[0], ne_rhs[1], ne_rhs[2], ne_rhs[3]);\n        ggml_set_param(b);\n        ggml_set_name(b, \"b\");\n\n        ggml_tensor * out = ggml_solve_tri(ctx, a, b, true, true, false);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if 
(strcmp(t->name, \"a\") == 0) {\n                // note: avoid zeros in the diagonal\n                init_tensor_tril(t, 0.1, 1.0f);\n            } else {\n                init_tensor_uniform(t, -1.0f, 1.0f);\n            }\n        }\n    }\n};\n\n// GGML_OP_DIAG\nstruct test_diag : public test_case {\n    const ggml_type              type;\n    const std::array<int64_t, 4> ne;\n\n    std::string vars() override { return VARS_TO_STR2(type, ne); }\n\n    test_diag(ggml_type type = GGML_TYPE_F32,\n            std::array<int64_t, 4> ne = { 10, 1, 4, 3 })\n        : type(type), ne(ne) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        GGML_ASSERT(ne[1] == 1);\n        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);\n        ggml_set_param(a);\n        ggml_set_name(a, \"a\");\n\n        ggml_tensor * out = ggml_diag(ctx, a);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n};\n\n// Deserializable generic test case\nstruct input_tensor {\n    ggml_type type;\n    std::array<int64_t, 4> ne;\n    std::array<size_t, 4> nb; // strides (0 = use default contiguous strides)\n};\n\nstatic bool is_non_contiguous(const input_tensor & src) {\n    if (src.nb[0] == 0) {\n        return false;\n    }\n    const size_t default_nb0 = ggml_type_size(src.type);\n    const size_t default_nb1 = default_nb0 * (src.ne[0] / ggml_blck_size(src.type));\n    const size_t default_nb2 = default_nb1 * src.ne[1];\n    const size_t default_nb3 = default_nb2 * src.ne[2];\n    return src.nb[0] != default_nb0 ||\n           src.nb[1] != default_nb1 ||\n           src.nb[2] != default_nb2 ||\n           src.nb[3] != default_nb3;\n}\n\nstatic std::string var_to_str(const std::vector<input_tensor>& sources) {\n    std::ostringstream oss;\n    bool first = true;\n    for (const auto& src : sources) {\n        if (!first) oss << \",\";\n        oss << ggml_type_name(src.type) << \"[\" << src.ne[0] << \",\" << src.ne[1] << \",\" << 
src.ne[2] << \",\" << src.ne[3] << \"]\";\n        if (is_non_contiguous(src)) {\n            oss << \"nb[\" << src.nb[0] << \",\" << src.nb[1] << \",\" << src.nb[2] << \",\" << src.nb[3] << \"]\";\n        }\n        first = false;\n    }\n    return oss.str();\n}\n\nstatic std::string var_to_str(const std::array<int32_t, GGML_MAX_OP_PARAMS / sizeof(int32_t)>& params) {\n    std::ostringstream oss;\n    oss << \"[\";\n    bool first = true;\n    for (size_t i = 0; i < params.size(); ++i) {\n        if (params[i] != 0) {\n            if (!first) oss << \",\";\n            oss << i << \":\" << params[i];\n            first = false;\n        }\n    }\n    oss << \"]\";\n    return oss.str();\n}\n\n\nstruct test_generic_op : public test_case {\n    const ggml_op op;\n    const ggml_type type;\n    const std::array<int64_t, 4> ne;\n    const std::array<int32_t, GGML_MAX_OP_PARAMS / sizeof(int32_t)> op_params;\n\n    const std::vector<input_tensor> sources;\n    const std::string name;\n\n    std::string vars() override {\n        if (name.empty()) {\n            return VARS_TO_STR4(type, ne, op_params, sources);\n        }\n\n        return VARS_TO_STR5(name, type, ne, op_params, sources);\n    }\n\n    test_generic_op(ggml_op op, ggml_type type, std::array<int64_t, 4> ne,\n                    std::array<int32_t, GGML_MAX_OP_PARAMS / sizeof(int32_t)> op_params,\n                    std::vector<input_tensor> sources, std::string name = \"\")\n        : op(op), type(type), ne(ne), op_params(op_params), sources(sources), name(std::move(name)) {}\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        const size_t source_count = std::min(sources.size(), (size_t)GGML_MAX_SRC);\n\n        std::array<ggml_tensor *, GGML_MAX_SRC> source_tensors;\n        for (size_t i = 0; i < source_count; ++i) {\n            const input_tensor& src = sources[i];\n\n            if (is_non_contiguous(src)) {\n                size_t total_size;\n                const size_t 
blck_size = ggml_blck_size(src.type);\n                if (blck_size == 1) {\n                    total_size = ggml_type_size(src.type);\n                    for (int d = 0; d < 4; d++) {\n                        total_size += (src.ne[d] - 1) * src.nb[d];\n                    }\n                } else {\n                    total_size = src.ne[0] * src.nb[0] / blck_size;\n                    for (int d = 1; d < 4; d++) {\n                        total_size += (src.ne[d] - 1) * src.nb[d];\n                    }\n                }\n\n                // Convert bytes to elements, padded to block size for quantized types\n                const size_t type_size = ggml_type_size(src.type);\n                size_t backing_elements = (total_size * blck_size + type_size - 1) / type_size;\n                backing_elements = ((backing_elements + blck_size - 1) / blck_size) * blck_size;\n                ggml_tensor * backing = ggml_new_tensor_1d(ctx, src.type, backing_elements);\n                source_tensors[i] = ggml_view_4d(ctx, backing,\n                    src.ne[0], src.ne[1], src.ne[2], src.ne[3],\n                    src.nb[1], src.nb[2], src.nb[3], 0);\n                // nb[0] does not get set by view_4d, so set it manually\n                source_tensors[i]->nb[0] = src.nb[0];\n            } else {\n                source_tensors[i] = ggml_new_tensor_4d(ctx, src.type, src.ne[0], src.ne[1], src.ne[2], src.ne[3]);\n            }\n        }\n\n        // Ops with an inplace flag create a view of src[0] as their output.\n        bool inplace = false;\n        if (op == GGML_OP_SET || op == GGML_OP_ACC) {\n            inplace = op_params[4] != 0;\n        } else if (op == GGML_OP_ADD_REL_POS) {\n            inplace = op_params[0] != 0;\n        }\n\n        ggml_tensor * out;\n        if (inplace && source_count > 0) {\n            out = ggml_view_tensor(ctx, source_tensors[0]);\n        } else {\n            out = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], 
ne[3]);\n        }\n        out->op = op;\n        for (size_t i = 0; i < source_count; ++i) {\n            out->src[i] = source_tensors[i];\n        }\n\n        memcpy(out->op_params, op_params.data(), GGML_MAX_OP_PARAMS);\n        ggml_set_name(out, \"out\");\n\n        return out;\n    }\n\n    double max_nmse_err() override {\n        switch (op) {\n        case GGML_OP_MUL_MAT:\n        case GGML_OP_MUL_MAT_ID:\n        case GGML_OP_OUT_PROD:\n        case GGML_OP_CONV_TRANSPOSE_2D:\n        case GGML_OP_IM2COL:\n        case GGML_OP_CONV_2D:\n        case GGML_OP_CONV_3D:\n        case GGML_OP_SET_ROWS:\n        case GGML_OP_CPY:\n            return 5e-4;\n        case GGML_OP_SOFT_MAX:\n            return 1e-6;\n        case GGML_OP_RWKV_WKV7:\n            return 5e-3;\n        case GGML_OP_FLASH_ATTN_EXT:\n        {\n            // Scale error with kv length to account for accumulating floating point error\n            const int64_t kv = sources[1].ne[1];\n            return 5e-4 * std::max(1.0, kv / 20000.0);\n        }\n        default:\n            return 1e-7;\n        }\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        ggml_tensor * out = ggml_get_tensor(ctx, \"out\");\n\n        std::random_device rd;\n        std::default_random_engine rng(rd());\n\n        for (size_t i = 0; i < sources.size() && i < GGML_MAX_SRC; i++) {\n            ggml_tensor * t = out->src[i];\n            if (!t) {\n                break;\n            }\n\n            // FLASH_ATTN_EXT: src[3] is the KQ mask\n            if (op == GGML_OP_FLASH_ATTN_EXT && i == 3) {\n                init_tensor_kq_mask(t);\n                continue;\n            }\n\n            if (t->type == GGML_TYPE_I32 || t->type == GGML_TYPE_I64) {\n                if (op == GGML_OP_GET_ROWS || op == GGML_OP_GET_ROWS_BACK) {\n                    const int64_t num_rows = sources[0].ne[1];\n                    const int64_t nels = ggml_nelements(t);\n                    
std::vector<int32_t> data(nels);\n                    std::uniform_int_distribution<int32_t> dist(0, num_rows - 1);\n                    for (int64_t i = 0; i < nels; i++) {\n                        data[i] = dist(rng);\n                    }\n                    ggml_backend_tensor_set(t, data.data(), 0, nels * sizeof(int32_t));\n                } else if (op == GGML_OP_SET_ROWS) {\n                    init_set_rows_row_ids(t, ne[1]);\n                } else if (op == GGML_OP_ROPE) {\n                    const int mode = op_params[2];\n                    const int64_t nels = (mode & GGML_ROPE_TYPE_MROPE) ? ne[2] * 4 : ne[2];\n                    std::vector<int32_t> data(nels);\n                    std::uniform_int_distribution<int32_t> dist(0, ne[2] - 1);\n                    for (int64_t i = 0; i < nels; i++) {\n                        data[i] = dist(rng);\n                    }\n                    ggml_backend_tensor_set(t, data.data(), 0, nels * sizeof(int32_t));\n                } else if (op == GGML_OP_MUL_MAT_ID || op == GGML_OP_ADD_ID) {\n                    const int64_t n_expert = (op == GGML_OP_MUL_MAT_ID) ? 
sources[0].ne[2] : sources[1].ne[1];\n                    for (int64_t r = 0; r < ggml_nrows(t); r++) {\n                        std::vector<int32_t> data(t->ne[0]);\n                        for (int32_t i = 0; i < t->ne[0]; i++) {\n                            data[i] = i % n_expert;\n                        }\n                        std::shuffle(data.begin(), data.end(), rng);\n                        ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));\n                    }\n                } else if (op == GGML_OP_SSM_SCAN) {\n                    for (int64_t r = 0; r < ggml_nrows(t); r++) {\n                        std::vector<int32_t> data(t->ne[0]);\n                        for (int32_t i = 0; i < t->ne[0]; i++) {\n                            data[i] = i;\n                        }\n                        std::shuffle(data.begin(), data.end(), rng);\n                        ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));\n                    }\n                } else {\n                    init_tensor_uniform(t);\n                }\n            } else {\n                init_tensor_uniform(t);\n            }\n        }\n    }\n};\n\n\nenum llm_norm_type {\n    LLM_NORM,\n    LLM_NORM_RMS,\n};\n\nstruct llama_hparams {\n    uint32_t n_vocab;\n    uint32_t n_embd;\n    uint32_t n_head;\n    uint32_t n_head_kv;\n    static constexpr uint32_t n_layer = 1;\n    uint32_t n_rot;\n    uint32_t n_embd_head; // dimension of values (d_v)\n    uint32_t n_ff;\n\n    float f_norm_eps;\n    float f_norm_rms_eps;\n\n    // cparams\n    static constexpr uint32_t n_ctx = 512; // user-specified context size\n    static constexpr uint32_t n_ctx_orig = n_ctx;\n\n    // batch\n    int32_t n_tokens;\n\n    // llm_build_context\n    static constexpr int32_t n_kv    = 32; // size of KV cache to consider (n_kv <= n_ctx\n    static constexpr int32_t kv_head = 1;  // index of where we store new KV data in the cache\n\n  
  uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads\n        return n_embd_head * n_head_kv;\n    }\n};\n\n// LLM base class\nstruct test_llm : public test_case {\n    llama_hparams hp;\n\nprotected:\n    test_llm(llama_hparams hp)\n        : hp(std::move(hp)) {\n    }\n\npublic:\n    struct ggml_tensor * llm_build_norm(\n            struct ggml_context * ctx,\n             struct ggml_tensor * cur,\n             struct ggml_tensor * mw,\n             struct ggml_tensor * mb,\n                  llm_norm_type   type) {\n        switch (type) {\n            case LLM_NORM:     cur = ggml_norm    (ctx, cur, hp.f_norm_eps); break;\n            case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;\n        }\n        cur = ggml_mul(ctx, cur, mw);\n        if (mb) {\n            cur = ggml_add(ctx, cur, mb);\n        }\n        return cur;\n    }\n\n    void llm_build_kv_store(\n            struct ggml_context * ctx,\n             struct ggml_tensor * k_l,\n             struct ggml_tensor * v_l,\n             struct ggml_tensor * k_cur,\n             struct ggml_tensor * v_cur) {\n        // compute the transposed [n_tokens, n_embd] V matrix\n        struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));\n\n        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),\n                (ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);\n\n        struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),\n                (  hp.n_ctx)*ggml_element_size(v_l),\n                (hp.kv_head)*ggml_element_size(v_l));\n\n        // important: storing RoPE-ed version of K in the KV cache!\n        ggml_cpy(ctx, k_cur,   k_cache_view);\n        ggml_cpy(ctx, v_cur_t, v_cache_view);\n    }\n\n    struct ggml_tensor * llm_build_kqv(\n            struct ggml_context * ctx,\n             struct ggml_tensor * 
k_l,\n             struct ggml_tensor * v_l,\n             struct ggml_tensor * q_cur,\n             struct ggml_tensor * kq_mask,\n                        float     kq_scale) {\n        struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);\n\n        struct ggml_tensor * k =\n            ggml_view_3d(ctx, k_l,\n                    hp.n_embd_head, hp.n_kv, hp.n_head_kv,\n                    ggml_row_size(k_l->type, hp.n_embd_gqa()),\n                    ggml_row_size(k_l->type, hp.n_embd_head),\n                    0);\n\n        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);\n\n        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);\n\n        // split cached v into n_head heads\n        struct ggml_tensor * v =\n            ggml_view_3d(ctx, v_l,\n                    hp.n_kv, hp.n_embd_head, hp.n_head_kv,\n                    ggml_element_size(v_l)*hp.n_ctx,\n                    ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,\n                    0);\n\n        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);\n\n        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);\n\n        struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);\n\n        struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);\n        cur = ggml_mul_mat(ctx, wo, cur);\n\n        return cur;\n    }\n\n    void initialize_tensors(ggml_context * ctx) override {\n        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {\n            if (t->type == GGML_TYPE_I32) {\n                // pos\n                std::vector<int> data(hp.n_tokens);\n                for (int i = 0; i < hp.n_tokens; i++) {\n                    data[i] = rand() % hp.n_ctx;\n                }\n                ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));\n            } else {\n                init_tensor_uniform(t);\n            
}\n        }\n    }\n};\n\n// Llama\nstruct test_llama : public test_llm {\n    static constexpr float freq_base = 10000.0f;\n    static constexpr float freq_scale = 1.0f;\n    static constexpr float ext_factor = 0.0f;\n    static constexpr float attn_factor = 1.0f;\n    static constexpr float beta_fast = 32.0f;\n    static constexpr float beta_slow = 1.0f;\n    bool fused;\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"LLAMA\";\n    }\n\n    std::string vars() override {\n        auto n_tokens = hp.n_tokens;\n        return VARS_TO_STR1(n_tokens);\n    }\n\n    double max_nmse_err() override {\n        return 2e-3;\n    }\n\n    bool run_whole_graph() override { return fused; }\n\n    test_llama(int n_tokens = 1, bool fused = false)\n        : test_llm({\n            /*n_vocab        =*/ 32000,\n            /*n_embd         =*/ 3200,\n            /*n_head         =*/ 32,\n            /*n_head_kv      =*/ 32,\n            /*n_rot          =*/ 100,\n            /*n_embd_head    =*/ 100,\n            /*n_ff           =*/ 8640,\n            /*f_norm_eps     =*/ 0.f,\n            /*f_norm_rms_eps =*/ 1e-5f,\n            /*n_tokens       =*/ n_tokens,\n        })\n        , fused(fused)\n    {\n    }\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        struct ggml_tensor * cur;\n        struct ggml_tensor * inpL;\n\n        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);\n\n        // inp_pos - contains the positions\n        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);\n\n        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)\n        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);\n\n        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);\n        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);\n\n        for (uint32_t il = 0; il 
< hp.n_layer; ++il) {\n            struct ggml_tensor * inpSA = inpL;\n\n            // norm\n            ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);\n            cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);\n\n            // self-attention\n            {\n                ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);\n                ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());\n                ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());\n\n                // compute Q and K and RoPE them\n                struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);\n                struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);\n                struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);\n\n                Qcur = ggml_rope_ext(\n                    ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos, nullptr,\n                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,\n                    ext_factor, attn_factor, beta_fast, beta_slow\n                );\n\n                Kcur = ggml_rope_ext(\n                    ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,\n                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,\n                    ext_factor, attn_factor, beta_fast, beta_slow\n                );\n\n                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);\n\n                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));\n            }\n\n            struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);\n\n            // feed-forward network\n            ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);\n            cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);\n\n       
     ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);\n            ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff,   hp.n_embd);\n            ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);\n            struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);\n            cur = ggml_mul_mat(ctx, ffn_gate, cur);\n            cur = ggml_silu(ctx, cur);\n            cur = ggml_mul(ctx, cur, tmp);\n            cur = ggml_mul_mat(ctx, ffn_down, cur);\n\n            cur = ggml_add(ctx, cur, ffn_inp);\n\n            // input for next layer\n            inpL = cur;\n        }\n\n        cur = inpL;\n\n        ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);\n        cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);\n\n        // lm_head\n        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);\n        cur = ggml_mul_mat(ctx, output, cur);\n\n        return cur;\n    }\n};\n\n// Falcon\nstruct test_falcon : public test_llm {\n    static constexpr float freq_base = 10000.0f;\n    static constexpr float freq_scale = 1.0f;\n    static constexpr float ext_factor = 0.0f;\n    static constexpr float attn_factor = 1.0f;\n    static constexpr float beta_fast = 32.0f;\n    static constexpr float beta_slow = 1.0f;\n\n    std::string op_desc(ggml_tensor * t) override {\n        GGML_UNUSED(t);\n        return \"FALCON\";\n    }\n\n    std::string vars() override {\n        auto n_tokens = hp.n_tokens;\n        return VARS_TO_STR1(n_tokens);\n    }\n\n    double max_nmse_err() override {\n        return 2e-3;\n    }\n\n    test_falcon(int n_tokens = 1)\n        : test_llm({\n            /*n_vocab        =*/ 32000,\n            /*n_embd         =*/ 3200,\n            /*n_head         =*/ 50,\n            /*n_head_kv      =*/ 1,\n            /*n_rot          =*/ 64,\n            /*n_embd_head    =*/ 
64,\n            /*n_ff           =*/ 8640,\n            /*f_norm_eps     =*/ 1e-5f,\n            /*f_norm_rms_eps =*/ 0.f,\n            /*n_tokens       =*/ n_tokens,\n        }) {\n    }\n\n    ggml_tensor * build_graph(ggml_context * ctx) override {\n        struct ggml_tensor * cur;\n        struct ggml_tensor * inpL;\n\n        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);\n\n        // inp_pos - contains the positions\n        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);\n\n        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)\n        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);\n\n        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);\n        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);\n\n        for (uint32_t il = 0; il < hp.n_layer; ++il) {\n            // norm\n            ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);\n            ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);\n            ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);\n\n            // self-attention\n            {\n                cur = attn_norm;\n\n                ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());\n\n                cur = ggml_mul_mat(ctx, wqkv, cur);\n\n                struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd,     hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));\n                struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));\n                struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + 
hp.n_embd_gqa())));\n\n                Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens);\n                Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);\n\n                // using mode = 2 for neox mode\n                Qcur = ggml_rope_ext(\n                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,\n                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow\n                );\n\n                Kcur = ggml_rope_ext(\n                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,\n                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow\n                );\n\n                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);\n\n                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));\n            }\n\n            struct ggml_tensor * ffn_inp = cur;\n\n            // feed forward\n            {\n                ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);\n                ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);\n                cur = attn_norm;\n                cur = ggml_mul_mat(ctx, ffn_up, cur);\n                cur = ggml_gelu(ctx, cur);\n                cur = ggml_mul_mat(ctx, ffn_down, cur);\n            }\n\n            cur = ggml_add(ctx, cur, ffn_inp);\n\n            cur = ggml_add(ctx, cur, inpL);\n\n            // input for next layer\n            inpL = cur;\n        }\n\n        cur = inpL;\n\n        ggml_tensor * output_norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);\n        ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);\n        cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);\n\n        // lm_head\n        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);\n        cur 
= ggml_mul_mat(ctx, output, cur);\n\n        return cur;\n    }\n};\n\n\n// ###########################################\n// ## Section 3: GGML Op Test Instantiation ##\n// ###########################################\nstatic const ggml_type all_types[] = {\n    GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,\n    GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,\n    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,\n    GGML_TYPE_Q8_0,\n    GGML_TYPE_MXFP4,\n    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,\n    GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,\n    GGML_TYPE_Q6_K,\n    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends\n    GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,\n    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,\n    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,\n};\n\nstatic const ggml_type base_types[] = {\n    GGML_TYPE_F32, GGML_TYPE_F16,\n    GGML_TYPE_Q8_0, // for I8MM tests\n    GGML_TYPE_Q4_0,\n    GGML_TYPE_Q4_1, // for I8MM tests\n    GGML_TYPE_Q4_K,\n    GGML_TYPE_MXFP4, // TODO: or \"other\"\n    GGML_TYPE_IQ2_XXS\n};\n\nstatic const ggml_type other_types[] = {\n    GGML_TYPE_Q4_1,\n    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,\n    GGML_TYPE_Q8_0,\n    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,\n    GGML_TYPE_Q5_K,\n    GGML_TYPE_Q6_K,\n    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends\n    GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,\n    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,\n    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,\n    GGML_TYPE_BF16,\n};\n\n#ifdef _MSC_VER\n// Workaround long compile time with msvc\n#pragma optimize(\"\", off)\n#endif\n\n// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low\nstatic std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {\n    std::vector<std::unique_ptr<test_case>> test_cases;\n    std::default_random_engine rng(0);\n\n    // unary ops\n    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {\n        for (int v : {0, 
1}) {\n            for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {\n                if (op == GGML_UNARY_OP_XIELU) {\n                    continue; // need extra params, separate test\n                }\n                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));\n                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));\n            }\n        }\n    }\n\n    // glu ops\n    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {\n        for (int v : {0, 1}) {\n            for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {\n                if (op == GGML_GLU_OP_SWIGLU_OAI) {\n                    // SWIGLU_OAI is handled separately\n                    continue;\n                }\n\n                for (bool swapped : {false, true}) {\n                    test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));\n                    test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));\n                }\n\n                test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));\n                test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));\n            }\n        }\n    }\n\n    for (int v : {0, 1}) {\n        for (float alpha : {.5f, 1.702f}) {\n            for (float limit : {2.0f, 7.0f}) {\n                test_cases.emplace_back(new test_swiglu_oai(GGML_TYPE_F32, { 128, 2, 2, 2 }, v, alpha, limit));\n            }\n        }\n    }\n\n    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_Q4_0}) {\n        test_cases.emplace_back(new test_get_rows(type, 300*256,   5,         4,   1,   2, false));\n        test_cases.emplace_back(new test_get_rows(type,     256,   80000, 70000,   2,   1, false));\n        test_cases.emplace_back(new test_get_rows(type,     256,   5,         4, 700, 100, false));\n    }\n\n    
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));\n    for (ggml_type type : all_types) {\n        for (int b : {1, 7}) {\n            for (bool v : {false, true}) {\n                test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, 1, v));\n            }\n        }\n    }\n    for (int b : {1, 7}) {\n        for (bool v : {false, true}) {\n            test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, 1, v));\n        }\n    }\n\n    test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));\n    for (ggml_type type : all_types) {\n        for (bool v : {false, true}) {\n            test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));\n        }\n    }\n    for (bool v : {false, true}) {\n        test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));\n    }\n\n    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I64, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));\n    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));\n    test_cases.emplace_back(new test_set_rows(GGML_TYPE_Q8_0, GGML_TYPE_I32, { 256, 5, 1, 3 }, { 1, 1, }, 1, false));\n    for (ggml_type type : all_types) {\n        for (int b : {1, 7}) {\n            for (bool v : {false, true}) {\n                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5,  b, 3 }, { 1, 1, }, 1, v));\n                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 11, 1, b }, { 2, 3, }, 7, v));\n\n                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));\n\n                if (ggml_blck_size(type) == 1) {\n                    test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 31, 3, b, 1 }, { 2, 3, }, 2, v));\n                    test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 
33, 5, 1, b }, { 2, 3, }, 1, v));\n                }\n            }\n        }\n    }\n\n    for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) {\n        for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {\n            for (int ne2 : {1, 8, 512}) {\n                test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, ne2, 1 }, mode));\n                test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, ne2, 3 }, mode));\n            }\n        }\n    }\n\n    for (ggml_type type_input : {GGML_TYPE_F32}) {\n        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {\n            for (int k0 : {1, 3}) {\n                for (int k1 : {1, 3}) {\n                    for (int s0 : {1, 2}) {\n                        for (int s1 : {1, 2}) {\n                            for (int p0 : {0, 1}) {\n                                for (int p1 : {0, 1}) {\n                                    test_cases.emplace_back(new test_pool2d(pool_type, type_input, {10, 10, 3, 1}, k0, k1, s0, s1, p0, p1));\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    for (ggml_type type_input : {GGML_TYPE_F32}) {\n        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {\n            for (int k0 : {1, 3}) {\n                for (int s0 : {1, 2}) {\n                    for (int p0 : {0, 1}) {\n                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 10,  3, 2, 1 }, k0, s0, p0));\n                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 11,  1, 3, 2 }, k0, s0, p0));\n                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 128, 2, 1, 3 }, k0, s0, p0));\n                    }\n                }\n            }\n        }\n    }\n\n#if 0\n 
   // >4GB im2col destination. Too slow to run by default.\n    // Test cases taken from Wan2.1 T2V 1.3B.\n    test_cases.emplace_back(new test_im2col   (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {832, 480, 192, 4}, {3, 3, 192, 96}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {834, 482, 6, 96},  {3, 3,3, 9216}, 96, 1, 1, 1, 0, 0, 0, 1, 1, 1, false));\n#endif\n\n    // im2col 1D\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));\n    for (int s0 : {1, 3}) {\n        for (int p0 : {0, 3}) {\n            for (int d0 : {1, 3}) {\n                test_cases.emplace_back(new test_im2col(\n                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},\n                    s0, 0, p0, 0, d0, 0, false));\n            }\n        }\n    }\n\n    // im2col 2D\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));\n    for (int s0 : {1, 3}) {\n        for (int s1 : {1, 3}) {\n            for (int p0 : {0, 3}) {\n                for (int p1 : {0, 3}) {\n                    for (int d0 : {1, 3}) {\n                        for (int d1 : {1, 3}) {\n                            test_cases.emplace_back(new test_im2col(\n                                GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},\n                               
 s0, s1, p0, p1, d0, d1, true));\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    // extra tests for im2col 2D\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));\n    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));\n\n    // im2col 3D\n    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));\n    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));\n    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));\n    for (int 
s0 : {1, 3}) {\n        for (int s1 : {1, 3}) {\n            for (int s2 : {1, 3}) {\n                for (int p0 : {0, 3}) {\n                    for (int p1 : {0, 3}) {\n                        for (int p2 : {0, 3}) {\n                            for (int d0 : {1, 3}) {\n                                for (int d1 : {1, 3}) {\n                                    for (int d2 : {1, 3}) {\n                                        for (int IC : {1, 3}) {\n                                            for (bool v : {false, true}) {\n                                                test_cases.emplace_back(new test_im2col_3d(\n                                                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 10, 3}, {3, 3, 3, 3},\n                                                    IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, v));\n                                            }\n                                        }\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n// Conv_2D test cases\n#ifdef DETAILED_TESTS\n    // Probably we do not have enough time to execute these in the pipeline.\n    uint32_t iwh_idx  = 0;\n    uint32_t kwh_idx  = 1;\n    uint32_t Cout_idx = 2;\n    uint32_t Cin_idx  = 3;\n    uint32_t B_idx    = 4;\n\n    std::vector<std::array<int, 5>> cases = {\n  //{IWH, KWH, Cout, Cin, B}\n  // K=CRS=NPQ=4096 conv_2d matmul performance\n        {19,   4, 4096, 256, 16},\n // K=128, CRS=128, NPQ=4096\n        { 19,  4, 128,  8,   16},\n // K=130, CRS=128, NPQ=4096\n        { 19,  4, 130,  8,   16},\n // Edge case: K x CRS is small\n        { 19,  2, 4,    4,   16},\n // A ConvNet's first layer\n        { 224, 3, 8,    3,   1 },\n // A ConvNet's first layer with 2x2 convolution, and 1 channel\n        { 224, 2, 8,    1,   1 },\n // A ConvNet's first layer with 2x2 convolution, and 1 
channel, several images in the batch\n        { 224, 2, 8,    1,   8 },\n // A middle layer of a ConvNet\n        { 58,  3, 64,   32,  1 },\n // A middle layer of a ConvNet, several images in the batch\n        { 58,  3, 64,   32,  8 },\n // A deep layer of a ConvNet, several images in the batch\n        { 16,  3, 256,  128, 8 }\n    };\n\n    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n        for (auto act_case : cases) {\n            test_cases.emplace_back(new test_conv_2d(\n                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },\n                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },\n                kernel_type, 1, 1, 0, 0, 1, 1, false));\n        }\n    }\n#endif\n\n    // CONV_2D:\n    auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {\n        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;\n    };\n\n    //uint32_t s0 = 3;\n    uint32_t s1 = 5;\n    uint32_t p0 = 5;\n    //uint32_t p1 = 2;\n    uint32_t d0 = 2;\n    uint32_t d1 = 4;\n\n    for (uint32_t s0 : { 1, 3 }) {\n        for (uint32_t p1 : { 2, 5 }) {\n            for (uint32_t Cin : { 1, 25 }) {\n                for (uint32_t Cout : { 1, 12 }) {\n                    for (uint32_t KH : { 1, 2, 3, 11 }) {\n                        for (uint32_t KW : { 1, 2, 3, 11 }) {\n                            for (uint32_t H : { 1, 133 }) {\n                                for (uint32_t W : { 1, 141 }) {\n                                    if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 &&\n                                        calc_conv_output_size(H, KH, s1, p1, d1) > 0) {\n                                        for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n                                            test_cases.emplace_back(new test_conv_2d(\n                                                { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, kernel_type, s0, s1, p0, p1, 
d0, d1, false));\n                                        }\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    // sycl backend will limit task global_range < MAX_INT\n    // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)\n    // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)\n    // these cases are verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)\n    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));\n    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));\n\n    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, false));\n    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, true));\n    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));\n    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));\n\n    // CONV_3D\n    auto calc_conv_output_size_3d = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {\n        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;\n    };\n\n    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n        for (int N : {1, 2}) {\n            for (int IC : {1, 3}) {\n                for (int OC : {1, 4}) {\n                    for (int s0 : {1, 2}) {\n                        for (int p1 : {0, 1}) {\n                            for (int d2 : {1, 2}) {\n                                int64_t IW = 20, IH = 22, ID = 18;\n                                int64_t KW = 3,  KH = 3,  KD = 3;\n          
                      int s1 = s0, s2 = s0;\n                                int p0 = p1, p2 = p1;\n                                int d0 = d2, d1 = d2;\n\n                                if (calc_conv_output_size_3d(IW, KW, s0, p0, d0) <= 0 ||\n                                    calc_conv_output_size_3d(IH, KH, s1, p1, d1) <= 0 ||\n                                    calc_conv_output_size_3d(ID, KD, s2, p2, d2) <= 0) {\n                                    continue;\n                                }\n                                test_cases.emplace_back(new test_conv_3d(\n                                    N, IC, ID, IH, IW,\n                                    OC, KD, KH, KW,\n                                    s0, s1, s2, p0, p1, p2, d0, d1, d2,\n                                    kernel_type));\n\n                                // Asymmetric kernel and params\n                                int64_t asym_KW = 5, asym_KH = 1, asym_KD = 3;\n                                int asym_s0 = 2, asym_s1 = 1, asym_s2 = 1;\n                                int asym_p0 = 2, asym_p1 = 0, asym_p2 = 1;\n                                int asym_d0 = 1, asym_d1 = 1, asym_d2 = 2;\n\n                                if (calc_conv_output_size_3d(IW, asym_KW, asym_s0, asym_p0, asym_d0) <= 0 ||\n                                    calc_conv_output_size_3d(IH, asym_KH, asym_s1, asym_p1, asym_d1) <= 0 ||\n                                    calc_conv_output_size_3d(ID, asym_KD, asym_s2, asym_p2, asym_d2) <= 0) {\n                                    continue;\n                                }\n                                test_cases.emplace_back(new test_conv_3d(\n                                    N, IC, ID, IH, IW,\n                                    OC, asym_KD, asym_KH, asym_KW,\n                                    asym_s0, asym_s1, asym_s2, asym_p0, asym_p1, asym_p2, asym_d0, asym_d1, asym_d2,\n                                    kernel_type));\n                        
    }\n                        }\n                    }\n                }\n            }\n        }\n        // Case with kernel size 1\n        test_cases.emplace_back(new test_conv_3d(1, 4, 8, 8, 8, 8, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, kernel_type));\n    }\n\n    for(uint32_t Cout : {1, 9}){\n        for(uint32_t Cin : {1, 7}){\n            for(uint32_t K : {1, 3, 1337}){\n                for(uint32_t L : {1, 2, 13}){\n                    for(uint32_t s0: {1, 2, 3}){\n                        test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1));\n                    }\n                }\n            }\n        }\n    }\n\n    test_cases.emplace_back(new test_conv_transpose_1d());\n    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));\n    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));\n    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1));\n    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1));\n    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1));\n    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));\n    test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));\n\n    test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));\n    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));\n    test_cases.emplace_back(new test_conv_transpose_2d({129, 63, 35, 1}, {3, 3, 48, 35}, 1));\n\n    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4,  500, 1, 1}));\n    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));\n\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,    1, 1, 1}));\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,  513, 1, 1}));\n    test_cases.emplace_back(new 
test_argmax(GGML_TYPE_F32, {100,  10, 1, 1}));\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438,  3, 1, 1}));\n\n    for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1\n        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));\n        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));\n        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));\n        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 2, 1}));\n        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 2}));\n        test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, ne3}, {2, 1, 1, 1}));\n        test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));\n    }\n\n    for (bool view : {false, true}) {\n        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));\n        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));\n        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));\n        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));\n        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));\n    }\n\n    test_cases.emplace_back(new test_dup(GGML_TYPE_F32));\n    test_cases.emplace_back(new test_dup(GGML_TYPE_F16));\n    test_cases.emplace_back(new test_dup(GGML_TYPE_I32));\n    test_cases.emplace_back(new test_dup(GGML_TYPE_I16));\n    test_cases.emplace_back(new test_dup(GGML_TYPE_F32, 
{10, 10, 5, 1}, {0, 2, 1, 3}));\n    test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {0, 2, 1, 3})); // dup by rows\n    test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {1, 0, 2, 3}));\n    test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {1, 0, 2, 3})); // dup dst not-contiguous\n    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {0, 2, 1, 3}));\n    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {1, 2, 0, 3}));\n\n    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {\n        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false));\n        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true));\n    }\n\n    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {\n        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false));\n        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true));\n    }\n\n    // same-type copy\n    for (ggml_type type : all_types) {\n        const auto nk = ggml_blck_size(type);\n\n        for (int k = 1; k < 4; ++k) {\n            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));\n            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));\n            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));\n        }\n    }\n\n    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {\n        for (ggml_type type_dst : all_types) {\n            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));\n            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows\n        }\n    }\n    for (ggml_type type_src : all_types) {\n        for (ggml_type type_dst : {GGML_TYPE_F32}) {\n            
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));\n            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows\n        }\n    }\n    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {\n        for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) {\n            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous\n        }\n    }\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));\n    
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));\n\n    for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {\n        for (bool use_view_slice : { true, false }) {\n            for (std::array<int64_t, 4> ne : std::initializer_list<std::array<int64_t, 4>>{ {2, 1, 1, 1}, {2, 1, 3, 5},\n                {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {\n                if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) {\n                    continue; // TODO: add after WebGPU is fixed\n                }\n                test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice));\n            }\n        }\n    }\n\n    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false, bool src_overlap = false) {\n        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {\n            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1, src_overlap));\n        }\n    };\n    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {\n        for (bool perm1 : {false, true}) {\n            add_test_bin_bcast(type, {1,  1,   8,   1}, {1,  1, 1, 1}, perm1);\n            add_test_bin_bcast(type, {1,  1,   1,   1}, {32, 1, 1, 1}, perm1);\n            add_test_bin_bcast(type, {1,  1, 320, 320}, {1,  1, 1, 1}, perm1);\n            add_test_bin_bcast(type, {10, 5,   1,   1}, {1,  1, 1, 1}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   1}, {1,  1, 1, 1}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 1, 1}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   3}, {2,  1, 1, 1}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  2, 1, 1}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 2, 1}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  
1, 1, 2}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 2, 2}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  2, 2, 2}, perm1);\n            add_test_bin_bcast(type, {10, 5,   4,   3}, {2,  2, 2, 2}, perm1);\n        }\n\n        // src_overlap\n        add_test_bin_bcast(type, {10, 5, 4, 6}, {1, 1, 1, 1}, false, true);\n        add_test_bin_bcast(type, {10, 5, 4, 5}, {1, 1, 1, 1}, false, true);\n        add_test_bin_bcast(type, {1, 1, 120, 120}, {1, 1, 1, 1}, false, true);\n        add_test_bin_bcast(type, {1, 1, 4, 320}, {1, 1, 1, 1}, false, true);\n\n        // test case for k_bin_bcast_unravel in CUDA backend\n        add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});\n\n        // stable diffusion\n        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});\n        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});\n        add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});\n        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});\n        add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});\n        add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});\n        add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});\n        add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});\n        add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});\n        add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});\n        add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});\n        add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});\n        add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});\n        add_test_bin_bcast(type, {64, 262144, 1, 1}, {1, 1, 1, 1});\n        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});\n        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});\n    }\n\n    // single inplace tests, especially important for WebGPU backend since kernels for inplace vs. 
not are different\n    test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));\n    test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));\n    test_cases.emplace_back(new test_bin_bcast(ggml_sub_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));\n    test_cases.emplace_back(new test_bin_bcast(ggml_div_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));\n\n    // fusion\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 2, 1, 1}, 3));\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}, 4));\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 2}, 5));\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2}, 6));\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2}, 7));\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {2, 2, 2, 2}, 8));\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));\n\n    test_cases.emplace_back(new test_add1());\n    test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1}));\n    test_cases.emplace_back(new test_scale());\n    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));\n    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test\n    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {100, 10, 10, 10}, 2.0f, 1.0f));\n    test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));\n    test_cases.emplace_back(new test_silu_back());\n\n    for (float eps : { 
0.0f, 1e-6f, 1e-4f, 1e-1f, 10.f }) {\n        for (uint32_t n : { 64, 1025 }) {\n            for (bool v : { false, true }) {\n                test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));\n                test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));\n            }\n            test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));\n            test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));\n            test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));\n        }\n    }\n\n    // in-place tests\n    test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));\n\n    for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f }) {\n        for (uint32_t n : { 64, 1025 }) {\n            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));\n            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));\n            test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));\n            test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));\n            test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));\n            test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));\n        }\n    }\n    for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {\n        for (bool multi_add : {false, true}) {\n            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));\n        }\n        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false));\n    }\n\n    for (auto multi_add : {false, true}) {\n        for (auto set_rows : {false, true}) {\n            for (auto rope : 
{GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) {\n                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope));\n                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope));\n                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope));\n                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope));\n                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope));\n                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope));\n                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope));\n                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));\n                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));\n            }\n        }\n    }\n    for (int64_t d_conv : {3, 4, 9}) {\n        for (int64_t d_inner: {1024, 1536, 2048}) {\n            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));\n            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));\n            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));\n            // long token (n_t > 32, exercises the long_token kernel path)\n            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));\n            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));\n        }\n    }\n\n    
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1\n    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2\n    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64,  8, 2, 32, 4)); // Falcon-H1\n\n    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));\n    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));\n    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));\n    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));\n\n    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));\n    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));\n    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));\n    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));\n\n    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));\n    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));\n    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));\n    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));\n\n#if 0\n    // > 4GB A matrix. 
Too slow to be enabled by default.\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16,  900000,  3, 2592, {1, 1}, {1, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000,  3, 2592, {1, 1}, {1, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000,  1, 2592, {1, 1}, {1, 1}));\n\n    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 2, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0\n    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 1, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 1, 5120, {128, 1}, {1, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 512, 5120, {128, 1}, {1, 1}));\n#endif\n\n    for (ggml_type type_a : all_types) {\n        for (int i = 1; i < 10; ++i) {\n            test_cases.emplace_back(new test_mul_mat(type_a,    GGML_TYPE_F32, 16,  i, 256, { 1,  1}, {1, 1}));\n        }\n    }\n\n#if 0\n    {\n        // Test paths in OpenCL\n        std::vector<int> ns = {32, 64, 128, 256, 512, 1024, 4096};\n        std::vector<int> ks = {896, 1536, 4096};\n        for (auto n : ns) {\n            for (auto k : ks) {\n                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, n, k, {1, 1}, {1, 1}));\n            }\n        }\n    }\n#endif\n\n#if 1\n    for (ggml_type type_a : base_types) {\n        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n            std::vector<int> ks = { 256 };\n            if (ggml_blck_size(type_a) == 1) {\n                ks.push_back(4);\n            }\n            for (auto k : ks) {\n                // test cases without permutation\n                
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {1, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {2, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {1, 2}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 1}, {1, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 1}, {2, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {1, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {2, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {1, 2}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {2, 2}));\n\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 2}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {1, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {2, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 1}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 2}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 2}));\n\n                // test cases with permutation\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 2, 1, 
3}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));\n\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));\n\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));\n            }\n\n            // test cases with large ne00/ne10 to cover stream-k fixup\n            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));\n            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 1024, {3, 2}, {1, 1}));\n            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));\n\n            // test cases with large batch size\n            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1536, 1}, {1, 1}));\n        }\n    }\n    for (ggml_type type_a : other_types) {\n        for (ggml_type type_b : {GGML_TYPE_F32}) {\n            if (ggml_blck_size(type_a) != 256) {\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1,  1}, {1, 1}));\n            }\n            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1,  1}, {1, 1}));\n        }\n    }\n#else\n    // m = a rows\n    // n = b rows\n    // k = cols\n    
std::uniform_int_distribution<> dist_m(1, 128);\n    std::uniform_int_distribution<> dist_n(16, 128);\n    std::uniform_int_distribution<> dist_k(1, 16);\n    for (int i = 0; i < 1000; i++) {\n        for (ggml_type type_a : all_types) {\n            for (ggml_type type_b : {GGML_TYPE_F32}) {\n                int m = dist_m(rng);\n                int n = dist_n(rng);\n                int k = dist_k(rng) * ggml_blck_size(type_a);\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, m, n, k, { 1,  1}, {1, 1}));\n            }\n        }\n    }\n#endif\n\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,  128, { 8,  1}, {1, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,  128, { 8,  1}, {4, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,   64, { 8,  1}, {4, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,   64, { 8,  1}, {4, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8,  1}, {4, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8,  1}, {4, 1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1,  1}, {4, 1}, {0, 2, 1, 3}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67,  {1,  1}, {4, 1}, {0, 2, 1, 3}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1,  1}, {1, 1}, {0, 1, 2, 3}, 64, 3));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));\n\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 576, 512, 576, {1,1}, {1,1}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 1, 2048, 8192, {1,  1}, {1, 1}));\n    for (ggml_type type_a : all_types) {\n        
test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 1, 64, 256, {1,  1}, {1, 1}));\n    }\n\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 6, 4096, 5120, {1, 1}, {1, 1}));\n\n#if 0\n    // test the mat-mat path for Metal\n    for (int k = 1; k < 512; ++k) {\n        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1}));\n        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1}));\n        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1}));\n        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1}));\n        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 128, k, {12,1}, {1,1}));\n        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 128, k, {12,1}, {1,1}));\n        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 50, 200, k));\n        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, true, 50, 200, k));\n        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F32, GGML_TYPE_F32, 16, 16, false, 50, 200, k));\n        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F32, GGML_TYPE_F32, 16, 16, true, 50, 200, k));\n    }\n#endif\n\n    for (auto bs2 : {1,3}) {\n        for (auto bs : {1,2,4,8}) {\n            for (auto nr : {1,4}) {\n                for (uint32_t m = 0; m < 2; ++m) {\n                    for (uint32_t k = 0; k < 2; ++k) {\n                        for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {\n                            test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  bs2}, {nr, 1}, {0, 2, 1, 3}));\n                            test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  bs2}, 
{nr, 1}, {0, 1, 2, 3}, 2*1056 + k));\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    // sycl backend will limit task global_range < MAX_INT\n    // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)\n    // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.)\n    // this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)\n    // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1}));\n\n    // test large experts*tokens\n    for (bool b : {false, true}) {\n        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));\n        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));\n        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64));\n    }\n\n    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));\n    test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));\n\n    // gpt-oss issue with Vulkan mmq_id\n    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));\n\n    for (ggml_type type_a : base_types) {\n        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {\n            for (int n_mats : {4, 8}) {\n                for (int n_used : {1, 2, 4}) {\n                    for (bool b : {false, true}) {\n                        for (int n : {1, 4, 5, 17, 32, 129}) {\n                            int m = 512;\n                            int k = 256;\n                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));\n                        }\n  
                  }\n                }\n            }\n        }\n    }\n\n    for (ggml_type type_a : other_types) {\n        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {\n            for (int n_mats : {4}) {\n                for (int n_used : {2}) {\n                    for (bool b : {false}) {\n                        for (int n : {1, 32}) {\n                            int m = 512;\n                            int k = 256;\n                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    for (int bs : {1, 4, 512}) {\n        for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_K}) {\n            for (ggml_type type_b : {GGML_TYPE_F32}) {\n                // test with mul after (ffn_moe_weighted)\n                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1, true));\n            }\n        }\n    }\n\n    for (ggml_type type_a : base_types) {\n        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n            for (int n : {1, 16}) {\n                for (int k : {1, 16}) {\n                    for (int bs2 : {1, 3}) {\n                        for (int bs3 : {1, 3}) {\n                            for (int nr2 : {1, 2}) {\n                                for (int nr3 : {1, 2}) {\n                                    test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, n, k, {bs2, bs3}, {nr2, nr3}));\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    // add_id\n    for (ggml_type type_a : {GGML_TYPE_F32}) {\n        for (ggml_type type_b : {GGML_TYPE_F32}) {\n            for (int n_mats : {4, 8}) {\n                for (int n_used : {1, 2, 4}) {\n                
    for (int n_embd : {32, 129}) {\n                        for (int n_token : {1, 32, 129}) {\n                            test_cases.emplace_back(new test_add_id(type_a, type_b, n_embd, n_mats, n_used, n_token));\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {\n        test_cases.emplace_back(new test_sqr       (type));\n        test_cases.emplace_back(new test_sqrt      (type));\n        test_cases.emplace_back(new test_log       (type));\n        test_cases.emplace_back(new test_sin       (type));\n        test_cases.emplace_back(new test_cos       (type));\n        test_cases.emplace_back(new test_clamp     (type));\n        test_cases.emplace_back(new test_leaky_relu(type));\n        test_cases.emplace_back(new test_floor     (type));\n        test_cases.emplace_back(new test_ceil      (type));\n        test_cases.emplace_back(new test_round     (type));\n        test_cases.emplace_back(new test_trunc     (type));\n        test_cases.emplace_back(new test_sqr       (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_sqr       (type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_sqrt      (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_sqrt      (type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_log       (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_log       (type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_sin       (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_sin       (type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_cos       (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_cos       (type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_clamp     (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_clamp     (type, {1024, 1024, 1, 1}));\n        
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_leaky_relu(type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_floor     (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_floor     (type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_ceil      (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_ceil      (type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_round     (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_round     (type, {1024, 1024, 1, 1}));\n        test_cases.emplace_back(new test_trunc     (type, {7, 1, 5, 3}));\n        test_cases.emplace_back(new test_trunc     (type, {1024, 1024, 1, 1}));\n    }\n\n    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));\n    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));\n    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 2}, 5));\n\n#if 0\n    std::uniform_int_distribution<> dist_ne1(1, 50);\n    int exponent = 1;\n    while (exponent < (1 << 17)) {\n        std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);\n\n        for (int n = 0; n < 10; ++n) {\n            int64_t ne0 = dist_ne0(rng);\n            int64_t ne1 = dist_ne1(rng);\n            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 
4.0f : 0.0f));\n        }\n\n        exponent <<= 1;\n    }\n#endif\n    for (bool mask : {false, true}) {\n        for (bool sinks : {false, true}) {\n            for (float max_bias : {0.0f, 8.0f}) {\n                if (!mask && max_bias > 0.0f) continue;\n                for (float scale : {1.0f, 0.1f}) {\n                    for (int64_t ne0 : {16, 1024}) {\n                        for (int64_t ne1 : {16, 1024}) {\n                            if (mask) {\n                                for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));\n                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));\n\n                                    if (ne0 <= 32 && ne1 <= 32) {\n                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, sinks, m_prec, {3, 1}, scale, max_bias));\n                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {2, 3}, scale, max_bias));\n                                    }\n                                }\n                            } else {\n                                /* The precision of mask here doesn't matter as boolean mask is false */\n                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));\n                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));\n                            }\n                        }\n                    }\n                }\n            }\n            // inplace tests\n            
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f, true));\n            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f, true));\n        }\n    }\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));\n\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true,   true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true,   true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false,  false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false,  false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false,  false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n\n    for (float max_bias : {0.0f, 8.0f}) {\n        for (float scale : {1.0f, 0.1f}) {\n            for (int64_t ne0 : {16, 1024}) {\n     
           for (int64_t ne1 : {16, 1024}) {\n                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, scale, max_bias));\n                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));\n                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0,   ne1,   2, 3}, scale, max_bias));\n                }\n            }\n        }\n    }\n\n    for (bool fw : {true, false}) { // fw == forward\n        bool all = true;\n\n        for (float fs : { 1.0f, 1.4245f }) {\n            for (float ef : { 0.0f, 0.7465f }) {\n                for (float af : { 1.0f, 1.4245f }) {\n                    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n                        for (bool ff : {false, true}) { // freq_factors\n                            for (float v : { 0, 1 }) {\n                                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B\n\n                                if (all) {\n                                    test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B\n                                    test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B\n                                    test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B\n                                    test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));\n                                }\n\n                                if (all) {\n                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, 
fs, ef, af, ff, v, fw)); // neox (falcon 7B)\n                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)\n                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)\n\n                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));\n                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));\n                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));\n\n                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)\n                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)\n                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)\n                                    test_cases.emplace_back(new test_rope(type, { 16, 16, 8192, 1},  16, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw));\n                                }\n\n                                if (all) {\n                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)\n                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope 
(qwen2vl 7B)\n                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));\n                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));\n                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)\n                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B)\n                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw));\n                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw));\n                                    test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)\n                                    test_cases.emplace_back(new test_rope(type, {128,  16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)\n                                    test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));\n                                }\n\n                                test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)\n                            }\n                        }\n\n                        all = false;\n                    }\n                }\n            }\n        }\n    }\n\n    // single inplace test per type/mode/ff\n    
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_IMROPE, GGML_ROPE_TYPE_VISION}) {\n            for (bool ff : {false, true}) {\n                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));\n                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 1, true, true));\n                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 3}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 1, true, true));\n            }\n        }\n    }\n\n    for (int v : { 0, 1, 2, 3 }) {\n        for (int dim : { 0, 1, 2, 3, }) {\n            test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));\n            test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));\n        }\n    }\n\n    for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {\n        for (uint32_t i = 4; i <= 1024*1024; i *= 2) {\n            test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1}));\n            test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1}));\n        }\n        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));\n        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen\n        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));\n        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));\n        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));\n        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));\n        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));\n        
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));\n        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)\n    }\n\n    for (int n = 1; n < 5; ++n) {\n        for (int k = 1; k <= n; ++k) {\n            test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {n, 2, 1, 3}, k, true));\n        }\n    }\n    for (int i = 0; i < 20; ++i) {\n        for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {\n            if (k <= 1<<i) {\n                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));\n                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));\n                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k, true));\n            }\n        }\n    }\n    for (int k : {1, 2, 3, 7, 15}) {\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k));\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k));\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k));\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k));\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k));\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k));\n        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k));\n    }\n\n    // exhaustive top_k tests\n    //for (int i = 1; i < 9999; ++i) {\n    //    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));\n    //}\n\n    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, 
ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) {\n        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));\n        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));\n        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5,  7, 11}, {5, 7, 11, 13}, mode));\n        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5,  7, 11}, mode));\n    }\n    for (ggml_scale_mode mode : {GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {\n        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));\n        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));\n        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));\n    }\n\n    test_cases.emplace_back(new test_sum());\n    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3}));  // row-contiguous but non-contiguous\n    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));\n    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));\n    test_cases.emplace_back(new test_mean());\n    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));\n    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));\n    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));\n    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 }));\n    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 }));\n    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 }));\n    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 
}));\n    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));\n    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));\n    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous\n    test_cases.emplace_back(new test_sum_rows());\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false));\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true));\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true));\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));\n    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));\n    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));\n    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));\n    test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));\n    test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));\n    
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));\n    test_cases.emplace_back(new test_pad());\n    test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular\n    test_cases.emplace_back(new test_pad_ext());\n    test_cases.emplace_back(new test_pad_reflect_1d());\n    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));\n    test_cases.emplace_back(new test_roll());\n    test_cases.emplace_back(new test_arange());\n    test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f));\n    test_cases.emplace_back(new test_timestep_embedding());\n    test_cases.emplace_back(new test_leaky_relu());\n\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 512, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1023, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2047, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 }));\n\n    test_cases.emplace_back(new test_xielu());\n\n 
   test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));\n    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));\n    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER));\n    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG));\n\n    test_cases.emplace_back(new test_fill(0.0f));\n    test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));\n    test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));\n    test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));\n\n    test_cases.emplace_back(new test_diag());\n    test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 }));\n    test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 256, 1, 8, 16 }));\n\n    test_cases.emplace_back(new test_solve_tri());\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 17, 17, 2, 4 }, { 9, 17, 2, 4 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 64, 64, 2, 2 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 79, 79, 5, 3 }, { 417, 79, 5, 3 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 80, 80, 2, 8 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 79, 80, 2, 8 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 81, 80, 2, 8 }));\n    test_cases.emplace_back(new 
test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 80, 80, 8, 8 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 79, 80, 8, 8 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 81, 80, 8, 8 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 84, 84, 4, 4 }, { 32, 84, 4, 4 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 95, 95, 8, 8 }, { 40, 95, 8, 8 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 32, 128, 4, 4 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 3, 4 }, { 32, 128, 3, 4 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 32, 128, 4, 1 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 }));\n\n    for (int tfrm : {0, 1, 2}) {\n        for (bool circular : {false, true}) {\n            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, tfrm, circular));\n            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, tfrm, circular));\n        }\n    }\n\n    for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 320, 576 }) {\n        for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {\n            if (hsk != 192 && hsk != 320 && hsk != 576 && hsk != hsv) continue;\n            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;\n            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA\n            if (hsk == 320 && hsv != 256) continue; // MLA\n\n            for 
(bool mask : { true, false } ) {\n                for (bool sinks : { true, false } ) {\n                    for (float max_bias : { 0.0f, 8.0f }) {\n                        if (!mask && max_bias > 0.0f) continue;\n                        for (float logit_softcap : {0.0f, 10.0f}) {\n                            if (hsk != 128 && logit_softcap != 0.0f) continue;\n                            for (int nh : { 1, 4 }) {\n                                if (nh == 1 && hsk != 320 && hsk != 576) continue; // GLM 4.7 Flash\n                                for (int nr3 : { 1, 3, }) {\n                                    if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes\n                                    for (int nr2 : { 1, 4, 12, 20, 32 }) {\n                                        if (nr2 == 12 && hsk != 128) continue;\n                                        if (nr2 == 20 && (nh != 1 || hsk != 576)) continue;\n                                        if (nr2 == 32 && (nh != 1 || hsk != 320)) continue;\n                                        //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {\n                                        for (int kv : { 113, 512, 1024, }) {\n                                            if (nr2 != 1 && kv != 512) continue;\n                                            for (int nb : { 1, 3, 32, 75, }) {\n                                                for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {\n                                                    if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;\n                                                    for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {\n                                                        if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;\n                                                        
test_cases.emplace_back(new test_flash_attn_ext(\n                                                                    hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));\n                                                        // run fewer test cases permuted\n                                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {\n                                                            test_cases.emplace_back(new test_flash_attn_ext(\n                                                                        hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));\n                                                        }\n                                                    }\n                                                }\n                                            }\n                                        }\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    test_cases.emplace_back(new test_cross_entropy_loss     (GGML_TYPE_F32, {   10, 5, 4, 3}));\n    test_cases.emplace_back(new test_cross_entropy_loss     (GGML_TYPE_F32, {30000, 1, 1, 1}));\n    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {   10, 5, 4, 3}));\n    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));\n\n    test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));\n    test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));\n\n    for (ggml_type type : base_types) {\n        for (bool with_gate : {false, true}) {\n            for (bool use_id : {false, true}) {\n                for (bool b : {false, true}) {\n                    if (!use_id && b) {\n                        
continue;\n                    }\n                    for (bool with_bias : {false, true}) {\n                        if (!with_gate && !with_bias) {\n                            continue;\n                        }\n                        for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) {\n                            if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) {\n                                continue;\n                            }\n                            if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) {\n                                continue;\n                            }\n                            test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,\n                                use_id, 16, 8, b, with_bias, with_gate));\n                            test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,\n                                use_id, 16, 8, b, with_bias, with_gate, {1, 1}));\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {\n        for (bool with_norm : {false, true}) {\n            for (bool bias_probs : {false, true}) {\n                for (float scale_w : {0.0f, 2.0f}) {\n                    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));\n                    test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));\n                    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));\n                    test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));\n                    test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));\n                    
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));\n                    test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));\n                    test_cases.emplace_back(new test_topk_moe({160, 4, 1, 1}, 160, with_norm, bias_probs, gate, scale_w));\n                }\n            }\n        }\n    }\n\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1, 1, true));\n    // KDA (vector gate)\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 
4, 64, 4, 2, 1, true,  true));\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true,  true));\n\n#if 0\n    // these tests are disabled to save execution time, but they can be handy for debugging\n    test_cases.emplace_back(new test_llama(2, true));\n    test_cases.emplace_back(new test_llama(1));\n    test_cases.emplace_back(new test_llama(2));\n    test_cases.emplace_back(new test_falcon(1));\n    test_cases.emplace_back(new test_falcon(2));\n#endif\n\n    return test_cases;\n}\n#ifdef _MSC_VER\n#pragma optimize(\"\", on)\n#endif\n\n// Test cases for performance evaluation: should be representative of real-world use cases\nstatic std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {\n    std::vector<std::unique_ptr<test_case>> test_cases;\n\n    // Conv2d: K=CRS=NPQ=4096 matmul performance\n    uint32_t                        iwh_idx  = 0;\n    uint32_t                        kwh_idx  = 1;\n    uint32_t                        Cout_idx = 2;\n    uint32_t                        Cin_idx  = 3;\n    uint32_t                        B_idx    = 4;\n    std::vector<std::array<int, 5>> cases    = {\n  //{IWH, KWH, Cout, Cin, B}\n  // K=CRS=NPQ=4096 conv2d matmul performance\n        {19,   4, 4096, 256, 16},\n // K=128, CRS=128, NPQ=4096\n        { 19,  4, 128,  8,   16},\n // K=130, CRS=128, NPQ=4096\n        { 19,  4, 130,  8,   16},\n // Edge case: K x CRS is small\n        { 19,  2, 4,    4,   16},\n // A ConvNet's first layer\n        { 224, 3, 8,    3,   1 },\n // A ConvNet's first layer with 2x2 convolution, and 1 channel\n        { 224, 2, 8,    1,   1 },\n // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch\n        { 224, 2, 8,    1,   8 },\n // A middle layer of a ConvNet\n        { 58,  3, 64,   32,  1 },\n // A middle layer of a ConvNet, several images in the batch\n        { 58,  3, 64,   32,  8 },\n // A deep layer of a ConvNet, several images in the batch\n        { 
16,  3, 512,  128, 8 },\n // High resolution output (large NPQ)\n        {1536, 3, 64,   32,  1 },\n    };\n\n    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n        for (auto act_case : cases) {\n            // Direct CONV_2D\n            test_cases.emplace_back(new test_conv_2d(\n                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },\n                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },\n                kernel_type, 1, 1, 0, 0, 1, 1, false));\n        }\n    }\n\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1,   1, 1, 1}));\n    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));\n\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F16,  {512, 3072, 1, 1}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F32,  {8192, 512, 2, 1}, {0, 2, 1, 3}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F32,  {3072, 512, 2, 1}, {0, 2, 1, 3}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_Q4_0, {8192, 512, 2, 1}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32,  {8192, 512, 2, 1}));\n\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));\n\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 
1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));\n\n\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));\n    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));\n\n    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));\n    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));\n    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 
4, 1}));\n    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));\n    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));\n\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8,  1}, {4, 1}, {0, 2, 1, 3}));\n    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8,  1}, {4, 1}, {0, 1, 2, 3}, 2*16416));\n\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 32, 64, 4, 4 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));\n    // qwen3next with CHUNK_SIZE 64\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));\n    // qwen3next with CHUNK_SIZE 128\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));\n    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 256, 256, 4, 2 }, { 128, 256, 4, 2 }));\n\n    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));\n    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));\n\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 }));\n    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 }));\n\n    for (int bs : {1, 2, 3, 4, 5, 8, 512}) {\n        for (ggml_type type_a : all_types) {\n            for (ggml_type type_b : {GGML_TYPE_F32}) {\n                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1,  1}, {1, 1}));\n            }\n        }\n    }\n\n    // qwen3-30b-a3b\n    for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {\n        for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, 
GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {\n            for (ggml_type type_b : {GGML_TYPE_F32}) {\n                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048));\n                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));\n            }\n        }\n    }\n\n    for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {\n        for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {\n            for (ggml_type type_b : {GGML_TYPE_F32}) {\n                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048));\n                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));\n            }\n        }\n    }\n\n\n    // gpt-oss-20b\n    for (int bs : {1, 4, 8, 512}) {\n        for (ggml_type type_a : {GGML_TYPE_MXFP4}) {\n            for (ggml_type type_b : {GGML_TYPE_F32}) {\n                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880));\n                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));\n            }\n        }\n    }\n\n    for (int K : {3, 5}) {\n        for (int IC : {256, 2560}) {\n            for (int IW_IH : {32, 64, 256}) {\n                if (IC == 2560 && IW_IH == 256) {\n                    // too big\n                    continue;\n                }\n                test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {IW_IH, IW_IH, IC, 1}, {K, K, IC, 1}, 1, 1, 1, 1, 1, 1, true));\n            }\n        }\n    }\n\n    // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012\n    test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));\n\n    
test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));\n    test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));\n\n    for (int kv : { 4096, 8192, 16384, }) {\n        for (int hs : { 64, 128, }) {\n            for (int nr : { 1, 4, }) {\n                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));\n            }\n        }\n    }\n\n    for (int col : {8192, 16384, 32768, 65536, 131072, 262144, 524288}) {\n        for (int rows : {1, 4, 16}){\n            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {col, rows, 1, 1}, false,  false,  GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));\n        }\n    }\n\n    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));\n    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));\n\n    test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));\n    test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));\n    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));\n\n    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));\n\n\n    for (int n_token : {1, 512}) {\n        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 128, 4, n_token));\n        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));\n    }\n\n    for (bool fw : {true, false}) { // fw == forward\n        for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {\n            for (bool ff : {false, true}) { // freq_factors\n                for (float v : { 0, 1 }) {\n                    test_cases.emplace_back(new test_rope(type, {128,  32, 512, 1}, 128, GGML_ROPE_TYPE_NORMAL, 
512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // llama 7B\n                    test_cases.emplace_back(new test_rope(type, {128,  64, 512, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // llama 65B\n                    test_cases.emplace_back(new test_rope(type, { 80,  32, 512, 1},  20, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (stablelm)\n                    test_cases.emplace_back(new test_rope(type, { 64,   8, 512, 1},  64, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (falcon 40B)\n                    test_cases.emplace_back(new test_rope(type, {128,  12, 512, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)\n                    test_cases.emplace_back(new test_rope(type, {128,  12, 512, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)\n                    test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)\n                }\n            }\n        }\n    }\n\n    std::vector<std::array<int64_t, 4>> reduce_rows_cases = {\n        { 8192, 1,    1, 1 },\n        { 8192, 8192, 1, 1 },\n        { 128,  8192, 1, 1 },\n    };\n\n    for (auto it: reduce_rows_cases){\n        test_cases.emplace_back(new test_mean(GGML_TYPE_F32, it));\n        test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, it));\n        test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));\n    }\n\n    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000,  16, 1, 1}));\n    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1,  1, 1}));\n    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));\n\n    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));\n    for (auto k : {1, 10, 40, 400}) {\n        for (auto nrows : {1, 16}) {\n            for (auto 
cols : {k, 1000, 65000, 200000}) {\n                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));\n            }\n        }\n    }\n\n    for (auto nrows : {1, 4, 8, 16}) {\n        for (auto cols : {128, 1024, 4096, 8192, 16384, 32768, 65536, 131072, 200000, 2000000}) {\n            test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {cols, nrows, 1, 1}));\n        }\n    }\n\n    // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf\n    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill\n    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4,   3328, 1, 1}, {4, 3328, 1, 1})); // generate\n    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill\n    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1,   1)); // generate\n\n    // acc\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));\n    test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));\n\n    // GATED_DELTA_NET: realistic model configurations\n    // TG: n_seq_tokens=1 (autoregressive)\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));   // Qwen3.5-like: 32 heads, d=128\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64,  1, 1));   // smaller model\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1, 1, false, true)); // KDA\n    // PP: n_seq_tokens=64,256 (prompt processing)\n    
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1));  // PP-64\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 256, 1)); // PP-256\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 512, 1)); // PP-512\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1024, 1)); // PP-1024\n    // Small model configs (fewer heads = less GPU occupancy for autoregressive)\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 64, 1));   // 4h PP-64\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 256, 1));  // 4h PP-256\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 512, 1));  // 4h PP-512\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024\n    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64\n\n    return test_cases;\n}\n\nstatic std::vector<std::unique_ptr<test_case>> make_test_cases_from_file(const char * path) {\n    std::ifstream f(path);\n\n    if (!f.is_open()) {\n        throw std::runtime_error(\"Unable to read test file\");\n    }\n\n    std::vector<std::unique_ptr<test_case>> test_cases;\n\n    std::string line;\n\n    while (std::getline(f, line)) {\n        std::istringstream iss(line);\n\n        ggml_op op;\n        ggml_type type;\n        std::array<int64_t, 4> ne;\n        std::array<int32_t, GGML_MAX_OP_PARAMS / sizeof(int32_t)> op_params = {};\n        std::string name;\n        uint64_t tmp;\n\n        iss >> tmp;\n        op = (ggml_op)tmp;\n        iss >> tmp;\n        type = (ggml_type)tmp;\n\n        for (size_t i = 0; i < 4; i++) {\n            iss >> ne[i];\n        }\n\n        iss >> tmp;\n        for (size_t i = 0; i < tmp && i < op_params.size(); i++) {\n            iss >> op_params[i];\n        }\n\n        iss >> tmp;\n\n        size_t num_src = 
std::min((uint64_t)GGML_MAX_SRC, tmp);\n        std::vector<input_tensor> sources(num_src);\n        for (size_t i = 0; i < num_src; i++) {\n            input_tensor& src = sources[i];\n            iss >> tmp;\n            src.type = (ggml_type)tmp;\n\n            for (size_t i = 0; i < 4; i++) {\n                iss >> src.ne[i];\n            }\n            for (size_t i = 0; i < 4; i++) {\n                iss >> src.nb[i];\n            }\n        }\n\n        iss >> name;\n\n        if (name.length() == 1 && name[0] == '-') {\n            name = \"\";\n        }\n\n        test_cases.emplace_back(new test_generic_op(op, type, ne, op_params, sources, std::move(name)));\n    }\n\n    return test_cases;\n}\n\nstatic bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,\n                         printer * output_printer, const char * test_file_path) {\n    auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {\n        if (params_filter == nullptr) {\n            return;\n        }\n\n        std::regex params_filter_regex(params_filter);\n\n        for (auto it = test_cases.begin(); it != test_cases.end();) {\n            if (!std::regex_search((*it)->vars(), params_filter_regex)) {\n                it = test_cases.erase(it);\n                continue;\n            }\n\n            it++;\n        }\n    };\n\n    std::vector<std::unique_ptr<test_case>> test_cases;\n\n    if (test_file_path == nullptr) {\n        switch (mode) {\n        case MODE_TEST:\n        case MODE_GRAD:\n        case MODE_SUPPORT:\n            test_cases = make_test_cases_eval();\n            break;\n        case MODE_PERF:\n            test_cases = make_test_cases_perf();\n            break;\n        }\n    } else {\n        test_cases = make_test_cases_from_file(test_file_path);\n    }\n\n    filter_test_cases(test_cases, params_filter);\n\n    if (mode == MODE_TEST) {\n    
    ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);\n        if (backend_cpu == NULL) {\n            test_operation_info info(\"\", \"\", \"CPU\");\n            info.set_error(\"backend\", \"Failed to initialize CPU backend\");\n            output_printer->print_operation(info);\n            return false;\n        }\n        // Use reference implementation on the CPU backend for comparison\n        using ggml_backend_cpu_set_use_ref_t = void (*)(ggml_backend_t, bool);\n        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));\n        auto * set_use_ref = (ggml_backend_cpu_set_use_ref_t) ggml_backend_reg_get_proc_address(reg, \"ggml_backend_cpu_set_use_ref\");\n        if (set_use_ref) {\n            set_use_ref(backend_cpu, true);\n        }\n\n        size_t n_ok = 0;\n        size_t                   tests_run = 0;\n        std::vector<std::string> failed_tests;\n        for (auto & test : test_cases) {\n            test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);\n            if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {\n                continue;\n            }\n            tests_run++;\n            if (status == test_status_t::OK) {\n                n_ok++;\n            } else if (status == test_status_t::FAIL) {\n                failed_tests.push_back(test->current_op_name + \"(\" + test->vars() + \")\");\n            }\n        }\n        output_printer->print_summary(test_summary_info(n_ok, tests_run, false));\n        output_printer->print_failed_tests(failed_tests);\n\n        ggml_backend_free(backend_cpu);\n\n        return n_ok == tests_run;\n    }\n\n    if (mode == MODE_GRAD) {\n        size_t n_ok = 0;\n        for (auto & test : test_cases) {\n            if (test->eval_grad(backend, op_names_filter, output_printer)) {\n                n_ok++;\n            }\n        }\n        
output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));\n\n        return n_ok == test_cases.size();\n    }\n\n    if (mode == MODE_PERF) {\n        for (auto & test : test_cases) {\n            test->eval_perf(backend, op_names_filter, output_printer);\n        }\n        return true;\n    }\n\n    if (mode == MODE_SUPPORT) {\n        // Filter out fusion cases\n        test_cases.erase(\n            std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr<test_case> & tc) {\n                return tc->run_whole_graph();\n            }),\n            test_cases.end()\n        );\n\n        for (auto & test : test_cases) {\n            test->eval_support(backend, op_names_filter, output_printer);\n        }\n        return true;\n    }\n\n    GGML_ABORT(\"fatal error\");\n}\n\nstatic void list_all_ops() {\n    printf(\"GGML operations:\\n\");\n    std::set<std::string> all_ops;\n\n    for (int i = 1; i < GGML_OP_COUNT; i++) {\n        all_ops.insert(ggml_op_name((enum ggml_op)i));\n    }\n    for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) {\n        all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));\n    }\n    for (int i = 0; i < GGML_GLU_OP_COUNT; i++) {\n        all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));\n    }\n    for (const auto & op : all_ops) {\n        printf(\"  %s\\n\", op.c_str());\n    }\n    printf(\"\\nTotal: %zu operations\\n\", all_ops.size());\n}\n\nstatic void show_test_coverage() {\n    std::set<std::string> all_ops;\n    for (int i = 1; i < GGML_OP_COUNT; i++) {\n        auto op = (enum ggml_op)i;\n        if (op == GGML_OP_VIEW      ||\n            op == GGML_OP_RESHAPE   ||\n            op == GGML_OP_PERMUTE   ||\n            op == GGML_OP_TRANSPOSE ||\n            op == GGML_OP_CONT      ||\n            op == GGML_OP_GLU       ||\n            op == GGML_OP_UNARY) {\n            continue;\n        }\n        all_ops.insert(ggml_op_name(op));\n    }\n    for (int i = 0; i < 
GGML_UNARY_OP_COUNT; i++) {\n        all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));\n    }\n    for (int i = 0; i < GGML_GLU_OP_COUNT; i++) {\n        all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));\n    }\n    auto test_cases = make_test_cases_eval();\n    // Filter out fusion cases\n    test_cases.erase(\n        std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr<test_case> & tc) {\n            return tc->run_whole_graph();\n        }),\n        test_cases.end()\n    );\n\n    std::set<std::string> tested_ops;\n\n    ggml_init_params params = {\n        /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),\n        /* .mem_base = */ NULL,\n        /* .no_alloc = */ true,\n    };\n\n    for (auto & test_case : test_cases) {\n        ggml_context * ctx = ggml_init(params);\n        if (ctx) {\n            test_case->mode = MODE_TEST;\n            ggml_tensor * out = test_case->build_graph(ctx);\n            if (out && out->op != GGML_OP_NONE) {\n                if (out->op == GGML_OP_UNARY) {\n                    tested_ops.insert(ggml_unary_op_name(ggml_get_unary_op(out)));\n                } else if (out->op == GGML_OP_GLU) {\n                    tested_ops.insert(ggml_glu_op_name(ggml_get_glu_op(out)));\n                } else {\n                    tested_ops.insert(ggml_op_name(out->op));\n                }\n            }\n            ggml_free(ctx);\n        }\n    }\n    std::set<std::string> covered_ops;\n    std::set<std::string> uncovered_ops;\n    for (const auto & op : all_ops) {\n        if (tested_ops.count(op) > 0) {\n            covered_ops.insert(op);\n        } else {\n            uncovered_ops.insert(op);\n        }\n    }\n\n    printf(\"Operations covered by tests (%zu):\\n\", covered_ops.size());\n    for (const auto & op : covered_ops) {\n        printf(\"  ✓ %s\\n\", op.c_str());\n    }\n    printf(\"\\nOperations without tests (%zu):\\n\", uncovered_ops.size());\n    for (const 
auto & op : uncovered_ops) {\n        printf(\"  ✗ %s\\n\", op.c_str());\n    }\n\n    printf(\"\\nCoverage Summary:\\n\");\n    printf(\"  Total operations: %zu\\n\", all_ops.size());\n    printf(\"  Tested operations: %zu\\n\", covered_ops.size());\n    printf(\"  Untested operations: %zu\\n\", uncovered_ops.size());\n    printf(\"  Coverage: %.1f%%\\n\", (double)covered_ops.size() / all_ops.size() * 100.0);\n}\n\nstatic void usage(char ** argv) {\n    printf(\"Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>] [--list-ops]\", argv[0]);\n    printf(\" [--show-coverage] [--test-file <path>]\\n\");\n    printf(\"    valid modes:\\n\");\n    printf(\"      - test (default, compare with CPU backend for correctness)\\n\");\n    printf(\"      - grad (compare gradients from backpropagation with method of finite differences)\\n\");\n    printf(\"      - perf (performance evaluation)\\n\");\n    printf(\"      - support (probe backend operation support)\\n\");\n    printf(\"    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc),\\n\");\n    printf(\"        optionally including the full test case string (e.g. 
\\\"ADD(type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1)\\\")\\n\");\n    printf(\"    --output specifies output format (default: console, options: console, sql, csv)\\n\");\n    printf(\"    --list-ops lists all available GGML operations\\n\");\n    printf(\"    --show-coverage shows test coverage\\n\");\n    printf(\"    --test-file reads test operators from a test file generated by llama-export-graph-ops\\n\");\n}\n\nint main(int argc, char ** argv) {\n    test_mode mode = MODE_TEST;\n    output_formats output_format = CONSOLE;\n    const char * op_names_filter = nullptr;\n    const char * backend_filter = nullptr;\n    const char * params_filter = nullptr;\n    const char * test_file_path = nullptr;\n\n    for (int i = 1; i < argc; i++) {\n        if (strcmp(argv[i], \"test\") == 0) {\n            mode = MODE_TEST;\n        } else if (strcmp(argv[i], \"perf\") == 0) {\n            mode = MODE_PERF;\n        } else if (strcmp(argv[i], \"grad\") == 0) {\n            mode = MODE_GRAD;\n        } else if (strcmp(argv[i], \"support\") == 0) {\n            mode = MODE_SUPPORT;\n        } else if (strcmp(argv[i], \"-o\") == 0) {\n            if (i + 1 < argc) {\n                op_names_filter = argv[++i];\n            } else {\n                usage(argv);\n                return 1;\n            }\n        } else if (strcmp(argv[i], \"-b\") == 0) {\n            if (i + 1 < argc) {\n                backend_filter = argv[++i];\n            } else {\n                usage(argv);\n                return 1;\n            }\n        } else if (strcmp(argv[i], \"-p\") == 0) {\n            if (i + 1 < argc) {\n                params_filter = argv[++i];\n            } else {\n                usage(argv);\n                return 1;\n            }\n        } else if (strcmp(argv[i], \"--output\") == 0) {\n            if (i + 1 < argc) {\n                if (!output_format_from_str(argv[++i], output_format)) {\n                    usage(argv);\n                    return 1;\n             
   }\n            } else {\n                usage(argv);\n                return 1;\n            }\n        } else if (strcmp(argv[i], \"--list-ops\") == 0) {\n            list_all_ops();\n            return 0;\n        } else if (strcmp(argv[i], \"--show-coverage\") == 0) {\n            show_test_coverage();\n            return 0;\n        } else if (strcmp(argv[i], \"--test-file\") == 0) {\n            if (i + 1 < argc) {\n                test_file_path = argv[++i];\n            } else {\n                usage(argv);\n                return 1;\n            }\n        } else {\n            usage(argv);\n            return 1;\n        }\n    }\n\n    // load and enumerate backends\n    ggml_backend_load_all();\n\n    // Create printer for output format\n    std::unique_ptr<printer> output_printer = create_printer(output_format);\n    if (output_printer) {\n        output_printer->print_header();\n    }\n\n    output_printer->print_testing_start(testing_start_info(ggml_backend_dev_count()));\n\n    size_t n_ok = 0;\n\n    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {\n        ggml_backend_dev_t dev = ggml_backend_dev_get(i);\n\n        if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) {\n            output_printer->print_backend_init(\n                backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, \"Skipping\"));\n            n_ok++;\n            continue;\n        }\n\n        if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {\n            output_printer->print_backend_init(backend_init_info(\n                i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, \"Skipping CPU backend\"));\n            n_ok++;\n            continue;\n        }\n\n        ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);\n        GGML_ASSERT(backend != NULL);\n\n        ggml_backend_reg_t reg = 
ggml_backend_dev_backend_reg(dev);\n        auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, \"ggml_backend_set_n_threads\");\n        if (ggml_backend_set_n_threads_fn) {\n            // TODO: better value for n_threads\n            ggml_backend_set_n_threads_fn(backend, N_THREADS);\n        }\n\n        size_t free, total;  // NOLINT\n        ggml_backend_dev_memory(dev, &free, &total);\n        output_printer->print_backend_init(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev),\n                                                             false, \"\", ggml_backend_dev_description(dev),\n                                                             total / 1024 / 1024, free / 1024 / 1024, true));\n\n        bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get(), test_file_path);\n\n        if (ok) {\n            n_ok++;\n        }\n        output_printer->print_backend_status(\n            backend_status_info(ggml_backend_name(backend), ok ? test_status_t::OK : test_status_t::FAIL));\n\n        ggml_backend_free(backend);\n    }\n\n    ggml_quantize_free();\n\n    if (output_printer) {\n        output_printer->print_footer();\n    }\n\n    output_printer->print_overall_summary(\n        overall_summary_info(n_ok, ggml_backend_dev_count(), n_ok == ggml_backend_dev_count()));\n\n    if (n_ok != ggml_backend_dev_count()) {\n        return 1;\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-cont.c",
    "content": "// Regression test for ggml_cont() applied to a transposed view.\n// A 2-element 1-D tensor is transposed and made contiguous for each enabled\n// element type; the result must have shape [1, 2, 1] and unchanged values.\n\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n\n// All state owned by the test, so it can be torn down in one place.\nstruct model {\n    struct ggml_context* ctx;      // input tensors (metadata only, no_alloc)\n    struct ggml_context* ctx0;     // compute graph and its result tensors\n    ggml_backend_t backend;\n    ggml_backend_buffer_t buffer;  // backend memory for the tensors in ctx\n    struct ggml_cgraph* gf;\n    ggml_gallocr_t allocr;         // allocates the tensors of gf\n    uint8_t* buf;                  // host memory backing ctx0\n};\n\n// Creates the context for the input tensors. no_alloc is set because their\n// data is allocated later in a backend buffer (see model_alloc).\nstruct ggml_context* make_ctx(void) {\n    struct ggml_init_params params = {\n        .mem_size = ggml_tensor_overhead() * 3,  // room for up to three tensor headers\n        .mem_buffer = NULL,\n        .no_alloc = true,\n    };\n    return ggml_init(params);\n}\n\n// Returns CUDA device 0 when built with CUDA (asserting that it initialises),\n// otherwise the CPU backend.\nggml_backend_t make_backend(void) {\n    ggml_backend_t backend = NULL;\n\n#ifdef GGML_USE_CUDA\n    backend = ggml_backend_cuda_init(0);\n    GGML_ASSERT(backend != NULL);\n#endif\n\n    if (!backend) {\n        backend = ggml_backend_cpu_init();\n    }\n\n    return backend;\n}\n\n// Creates the input context, the backend and an empty graph whose context\n// lives in a host buffer owned by m (released in model_free).\nvoid model_init(struct model* m) {\n    m->ctx = make_ctx();\n    m->backend = make_backend();\n    GGML_ASSERT(m->backend != NULL);\n\n    size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    m->buf = calloc(buf_size, sizeof(uint8_t));\n    GGML_ASSERT(m->buf != NULL);\n    struct ggml_init_params params0 = {\n        .mem_size = buf_size,\n        .mem_buffer = m->buf,\n        .no_alloc = true,\n    };\n    m->ctx0 = ggml_init(params0);\n    m->gf = ggml_new_graph(m->ctx0);\n}\n\n// Allocates backend memory for the input tensors and creates the graph\n// allocator. Call after the input tensors have been created in m->ctx.\nvoid model_alloc(struct model* m) {\n    m->buffer = ggml_backend_alloc_ctx_tensors(m->ctx, m->backend);\n    GGML_ASSERT(m->buffer != NULL);\n    m->allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(m->backend));\n}\n\n// Allocates the graph tensors and runs the graph, failing loudly instead of\n// letting check_tensor read unallocated or stale results.\nvoid model_compute(struct model* m) {\n    const bool allocated = ggml_gallocr_alloc_graph(m->allocr, m->gf);\n    GGML_ASSERT(allocated);\n    const enum ggml_status status = ggml_backend_graph_compute(m->backend, m->gf);\n    GGML_ASSERT(status == GGML_STATUS_SUCCESS);\n}\n\n// Releases everything created by model_init/model_alloc. ctx0 is freed before\n// the host buffer that backs it.\nvoid model_free(struct model* m) {\n    ggml_free(m->ctx0);\n    free(m->buf);\n    ggml_gallocr_free(m->allocr);\n    ggml_free(m->ctx);\n    ggml_backend_buffer_free(m->buffer);\n    ggml_backend_free(m->backend);\n}\n\n// Reads t back from its backend and asserts that its shape is\n// [ne0, ne1, ne2] and that each element equals expected_t_d[i].\n// Values are compared exactly, so callers must pass values that are\n// exactly representable in t's type.\nvoid check_tensor(struct ggml_tensor* t,\n                  const float* expected_t_d,\n                  const int ne0,\n                  const int ne1,\n                  const int ne2) {\n    GGML_ASSERT(t->ne[0] == ne0);\n    GGML_ASSERT(t->ne[1] == ne1);\n    GGML_ASSERT(t->ne[2] == ne2);\n    const size_t bsize = ggml_nbytes(t);\n    if (t->type == GGML_TYPE_F32) {\n        float* buffer = malloc(bsize);\n        GGML_ASSERT(buffer != NULL);\n        ggml_backend_tensor_get(t, buffer, 0, bsize);\n        for (size_t i = 0; i < bsize / sizeof(float); ++i) {\n            float expected = expected_t_d[i];\n            float actual = buffer[i];\n            if (expected != actual) {\n                printf(\"expected %.1f, got %.1f\\n\", expected, actual);\n            }\n            GGML_ASSERT(expected == actual);\n        }\n        free(buffer);\n    } else if (t->type == GGML_TYPE_F16) {\n        ggml_fp16_t* buffer = malloc(bsize);\n        GGML_ASSERT(buffer != NULL);\n        ggml_backend_tensor_get(t, buffer, 0, bsize);\n        for (size_t i = 0; i < bsize / sizeof(ggml_fp16_t); ++i) {\n            float expected = expected_t_d[i];\n            float actual = ggml_fp16_to_fp32(buffer[i]);\n            if (expected != actual) {\n                printf(\"expected %.1f, got %.1f\\n\", expected, actual);\n            }\n            GGML_ASSERT(expected == actual);\n        }\n        free(buffer);\n    //} else if (t->type == GGML_TYPE_BF16) {\n    //    ggml_bf16_t* buffer = malloc(bsize);\n    //    ggml_backend_tensor_get(t, buffer, 0, bsize);\n    //    for (size_t i = 0; i < bsize / sizeof(ggml_bf16_t); ++i) {\n    //        float expected = expected_t_d[i];\n    //        float actual = ggml_bf16_to_fp32(buffer[i]);\n    //        if (expected != actual) {\n    //            printf(\"expected %.1f, got %.1f\\n\", expected, actual);\n    //        }\n    //        GGML_ASSERT(expected == actual);\n    //    }\n    //    free(buffer);\n    } else {\n        GGML_ABORT(\"unknown type\");\n    }\n}\n\n// Transposes a 2-element 1-D tensor and makes the view contiguous with\n// ggml_cont() for each enabled type, then checks shape and values.\n// The BF16 variant is currently disabled (commented out).\nvoid test_cont(void) {\n    float buf_f32[] = {1.0, 2.0};\n    ggml_fp16_t buf_f16[] = {ggml_fp32_to_fp16(buf_f32[0]), ggml_fp32_to_fp16(buf_f32[1])};\n    //ggml_bf16_t buf_bf16[] = {ggml_fp32_to_bf16(buf_f32[0]), ggml_fp32_to_bf16(buf_f32[1])};\n\n    float expected_out[] = {1.0, 2.0};\n\n    struct model m;\n    model_init(&m);\n\n    struct ggml_tensor* in_1 = ggml_new_tensor_1d(m.ctx, GGML_TYPE_F32, 2);\n    struct ggml_tensor* in_2 = ggml_new_tensor_1d(m.ctx, GGML_TYPE_F16, 2);\n    //struct ggml_tensor* in_3 = ggml_new_tensor_1d(m.ctx, GGML_TYPE_BF16, 2);\n\n    model_alloc(&m);\n\n    ggml_backend_tensor_set(in_1, buf_f32, 0, ggml_nbytes(in_1));\n    ggml_backend_tensor_set(in_2, buf_f16, 0, ggml_nbytes(in_2));\n    //ggml_backend_tensor_set(in_3, buf_bf16, 0, ggml_nbytes(in_3));\n\n    struct ggml_tensor* out_1 = ggml_cont(m.ctx0, ggml_transpose(m.ctx0, in_1));\n    struct ggml_tensor* out_2 = ggml_cont(m.ctx0, ggml_transpose(m.ctx0, in_2));\n    //struct ggml_tensor* out_3 = ggml_cont(m.ctx0, ggml_transpose(m.ctx0, in_3));\n\n    ggml_build_forward_expand(m.gf, out_1);\n    ggml_build_forward_expand(m.gf, out_2);\n    //ggml_build_forward_expand(m.gf, out_3);\n\n    model_compute(&m);\n\n    check_tensor(out_1, expected_out, 1, 2, 1);\n    check_tensor(out_2, expected_out, 1, 2, 1);\n    //check_tensor(out_3, expected_out, 1, 2, 1);\n\n    model_free(&m);\n}\n\nint main(int argc, const char* argv[]) {\n    (void) argc;\n    (void) argv;\n    test_cont();\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-conv-transpose-1d.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n#include <cmath>\n#include <iostream>\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\nstruct test_model {\n    struct ggml_tensor * a_0;\n    struct ggml_tensor * b_0;\n\n    struct ggml_tensor * a_1;\n    struct ggml_tensor * b_1;\n\n    struct ggml_tensor * a_2;\n    struct ggml_tensor * b_2;\n\n    struct ggml_tensor * a_3;\n    struct ggml_tensor * b_3;\n\n    struct ggml_tensor * a_4;\n    struct ggml_tensor * b_4;\n\n\n    ggml_backend_t backend = NULL;\n    ggml_backend_buffer_t buffer;\n    struct ggml_context * ctx;\n};\n\nvoid load_model(test_model & model, bool use_gpu = false) {\n\n\n    float adata_0[] = {1,2,3};\n    float bdata_0[] = {1,2};\n\n    float adata_1[] = {1,2,3,3,2,1};\n    float bdata_1[] = {2,3,1,1,3,2};\n\n    float adata_2[] =  {3,2,1,1,2,3,1,2,3,3,2,1};\n    float bdata_2[] =  {2,3,1,1,3,2};\n\n    float data[16*32*32];\n    for (int i = 0; i < 16*32*32; ++i) {\n        data[i] = (float)(i%1024);\n    }\n\n\n\n\n    size_t buffer_size = 0;\n    {\n        buffer_size += 3* ggml_type_size(GGML_TYPE_F32); // tensor a_0\n        buffer_size += 2* ggml_type_size(GGML_TYPE_F32); // tensor b_0\n\n        buffer_size += 6* ggml_type_size(GGML_TYPE_F32); // tensor a_1\n        buffer_size += 6* ggml_type_size(GGML_TYPE_F32); // tensor b_1\n\n        buffer_size += 12* ggml_type_size(GGML_TYPE_F32); // tensor a_2\n        buffer_size += 6* ggml_type_size(GGML_TYPE_F32); // tensor b_2\n\n        buffer_size += 2 * 3 * 2 * ggml_type_size(GGML_TYPE_F32); 
// tensor a_3\n        buffer_size += 3 * 2* ggml_type_size(GGML_TYPE_F32); // tensor b_3\n\n        buffer_size += 16 * 32 * 32 * ggml_type_size(GGML_TYPE_F32); // tensor a_4\n        buffer_size += 197 * 32* ggml_type_size(GGML_TYPE_F32); // tensor b_4\n\n        buffer_size += 1024;\n    }\n\n    printf(\"%s: ggml tensor size    = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n    printf(\"%s: backend buffer size = %0.2f MB\\n\", __func__, (buffer_size/ 1024.f/ 1024.f));\n\n    int num_tensors = 10;\n    struct ggml_init_params params {\n            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n    };\n\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    // initialize the backend\n#ifdef GGML_USE_CUDA\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        model.backend = ggml_backend_cuda_init(0);\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        model.backend = ggml_backend_metal_init();\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n    if(!model.backend) {\n        // fallback to CPU backend\n        model.backend = ggml_backend_cpu_init();\n    }\n\n    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);\n\n    // create context\n    model.ctx = ggml_init(params);\n\n    // create tensors\n    model.a_0 = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, 3);\n    model.b_0 = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, 2);\n\n\n    model.a_1 = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3,1,2);\n    model.b_1 = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, 3,2);\n\n    model.a_2 = 
ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3,2,2);\n    model.b_2 = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, 3,2);\n\n    model.a_3 = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 2,3,2);\n    model.b_3 = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, 3,2);\n\n    model.a_4 = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 16,32,32);\n    model.b_4 = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, 197,32);\n\n\n    // create a allocator\n    ggml_tallocr alloc = ggml_tallocr_new(model.buffer);\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.a_0);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.a_0->data, adata_0, ggml_nbytes(model.a_0));\n    } else {\n        ggml_backend_tensor_set(model.a_0, adata_0, 0, ggml_nbytes(model.a_0));\n    }\n\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.a_1);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.a_1->data, adata_1, ggml_nbytes(model.a_1));\n    } else {\n        ggml_backend_tensor_set(model.a_1, adata_1, 0, ggml_nbytes(model.a_1));\n    }\n\n     // alloc memory\n    ggml_tallocr_alloc(&alloc, model.a_2);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.a_2->data, adata_2, ggml_nbytes(model.a_2));\n    } else {\n        ggml_backend_tensor_set(model.a_2, adata_2, 0, ggml_nbytes(model.a_2));\n    }\n\n      // alloc memory\n    ggml_tallocr_alloc(&alloc, model.a_3);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.a_3->data, data, ggml_nbytes(model.a_3));\n    } else {\n        ggml_backend_tensor_set(model.a_3, data, 0, ggml_nbytes(model.a_3));\n    }\n\n\n      // alloc memory\n    ggml_tallocr_alloc(&alloc, model.a_4);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.a_4->data, data, ggml_nbytes(model.a_4));\n    } else {\n        
ggml_backend_tensor_set(model.a_4, data, 0, ggml_nbytes(model.a_4));\n    }\n\n\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.b_0);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.b_0->data, bdata_0, ggml_nbytes(model.b_0));\n    } else {\n        ggml_backend_tensor_set(model.b_0, bdata_0, 0, ggml_nbytes(model.b_0));\n    }\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.b_1);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.b_1->data, bdata_1, ggml_nbytes(model.b_1));\n    } else {\n        ggml_backend_tensor_set(model.b_1, bdata_1, 0, ggml_nbytes(model.b_1));\n    }\n\n     // alloc memory\n    ggml_tallocr_alloc(&alloc, model.b_2);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.b_2->data, bdata_2, ggml_nbytes(model.b_2));\n    } else {\n        ggml_backend_tensor_set(model.b_2, bdata_2, 0, ggml_nbytes(model.b_2));\n    }\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.b_3);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.b_3->data, data, ggml_nbytes(model.b_3));\n    } else {\n        ggml_backend_tensor_set(model.b_3, data, 0, ggml_nbytes(model.b_3));\n    }\n\n\n   // alloc memory\n    ggml_tallocr_alloc(&alloc, model.b_4);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.b_4->data, data, ggml_nbytes(model.b_4));\n    } else {\n        ggml_backend_tensor_set(model.b_4, data, 0, ggml_nbytes(model.b_4));\n    }\n\n\n}\n\nstruct ggml_cgraph * 
build_graph(const test_model& model) {\n    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params0 = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    // create a temporally context to build the graph\n    struct ggml_context * ctx0 = ggml_init(params0);\n\n    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);\n\n    int s0 = 1;\n    int p0 = 0;\n    int d0 = 1;\n\n    struct ggml_tensor* conv1d_transpose_res_0 = ggml_conv_transpose_1d(ctx0, model.a_0, model.b_0, s0, p0, d0);\n    ggml_set_name(conv1d_transpose_res_0, \"conv1d_transpose_res_0\");\n    ggml_build_forward_expand(gf, conv1d_transpose_res_0);\n\n\n    s0 = 1;\n    p0 = 0;\n    d0 = 1;\n\n    struct ggml_tensor* conv1d_transpose_res_1 = ggml_conv_transpose_1d(ctx0, model.a_1, model.b_1, s0, p0, d0);\n    ggml_set_name(conv1d_transpose_res_1, \"conv1d_transpose_res_1\");\n    ggml_build_forward_expand(gf, conv1d_transpose_res_1);\n\n    s0 = 1;\n    p0 = 0;\n    d0 = 1;\n\n    struct ggml_tensor* conv1d_transpose_res_2 = ggml_conv_transpose_1d(ctx0, model.a_2, model.b_2, s0, p0, d0);\n    ggml_set_name(conv1d_transpose_res_2, \"conv1d_transpose_res_2\");\n    ggml_build_forward_expand(gf, conv1d_transpose_res_2);\n\n    s0 = 2;\n    p0 = 0;\n    d0 = 1;\n\n    struct ggml_tensor* conv1d_transpose_res_3 = ggml_conv_transpose_1d(ctx0, model.a_2, model.b_2, s0, p0, d0);\n    ggml_set_name(conv1d_transpose_res_3, \"conv1d_transpose_res_3\");\n    ggml_build_forward_expand(gf, conv1d_transpose_res_3);\n\n    s0 = 1;\n    p0 = 0;\n    d0 = 1;\n\n    struct ggml_tensor* conv1d_transpose_res_4 = ggml_conv_transpose_1d(ctx0, model.a_3, model.b_3, s0, p0, d0);\n    ggml_set_name(conv1d_transpose_res_4, \"conv1d_transpose_res_4\");\n    
ggml_build_forward_expand(gf, conv1d_transpose_res_4);\n\n    s0 = 2;\n    p0 = 0;\n    d0 = 1;\n\n    struct ggml_tensor* conv1d_transpose_res_5 = ggml_conv_transpose_1d(ctx0, model.a_3, model.b_3, s0, p0, d0);\n    ggml_set_name(conv1d_transpose_res_5, \"conv1d_transpose_res_5\");\n    ggml_build_forward_expand(gf, conv1d_transpose_res_5);\n\n    s0 = 3;\n    p0 = 0;\n    d0 = 1;\n\n    struct ggml_tensor* conv1d_transpose_res_6 = ggml_conv_transpose_1d(ctx0, model.a_3, model.b_3, s0, p0, d0);\n    ggml_set_name(conv1d_transpose_res_6, \"conv1d_transpose_res_6\");\n    ggml_build_forward_expand(gf, conv1d_transpose_res_6);\n\n\n    s0 = 8;\n    p0 = 0;\n    d0 = 1;\n\n    struct ggml_tensor* conv1d_transpose_res_7 = ggml_conv_transpose_1d(ctx0, model.a_4, model.b_4, s0, p0, d0);\n    ggml_set_name(conv1d_transpose_res_7, \"conv1d_transpose_res_7\");\n    ggml_build_forward_expand(gf, conv1d_transpose_res_7);\n\n\n\n    // delete the temporally context used to build the graph\n    ggml_free(ctx0);\n    return gf;\n}\n\nstruct ggml_cgraph* compute_graph(const test_model & model, ggml_gallocr_t allocr) {\n    struct ggml_cgraph * gf = build_graph(model);\n\n    // allocate tensors\n    ggml_gallocr_alloc_graph(allocr, gf);\n    int n_threads = 1;\n\n    if (ggml_backend_is_cpu(model.backend)) {\n        ggml_backend_cpu_set_n_threads(model.backend, n_threads);\n    }\n\n    ggml_backend_graph_compute(model.backend, gf);\n\n    //ggml_graph_print(gf);\n\n    return gf;\n}\n\nint main(void)\n{\n    ggml_time_init();\n\n    test_model model;\n    load_model(model, true);\n\n    ggml_gallocr_t allocr = NULL;\n\n    {\n        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n\n        //create the worst case graph for memory usage estimation\n        struct ggml_cgraph * gf = build_graph(model);\n\n        // compute the required memory\n        ggml_gallocr_reserve(allocr, gf);\n        size_t mem_size = 
ggml_gallocr_get_buffer_size(allocr, 0);\n        fprintf(stderr, \"%s: compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0f/1024.0f);\n    }\n\n    struct ggml_cgraph * gf_res = compute_graph(model, allocr);\n\n    struct ggml_tensor * conv1d_transpose_res_0 = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n       if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_transpose_res_0\") == 0) {\n            conv1d_transpose_res_0 = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<float> conv1d_transpose_data_0(ggml_nelements(conv1d_transpose_res_0));\n\n    ggml_backend_tensor_get(conv1d_transpose_res_0, conv1d_transpose_data_0.data(), 0, ggml_nbytes(conv1d_transpose_res_0));\n\n    const int n_conv_transpose_1d_test_0 = 4;\n\n    float expected_conv1d_0[n_conv_transpose_1d_test_0] = {\n       1.00f,4.00f,7.00f,6.00f\n    };\n\n\n    struct ggml_tensor * conv1d_transpose_res_1 = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n       if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_transpose_res_1\") == 0) {\n            conv1d_transpose_res_1 = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<float> conv1d_transpose_data_1(ggml_nelements(conv1d_transpose_res_1));\n\n    ggml_backend_tensor_get(conv1d_transpose_res_1, conv1d_transpose_data_1.data(), 0, ggml_nbytes(conv1d_transpose_res_1));\n\n\n\n\n\n    const int n_conv_transpose_1d_test_1 = 5;\n\n    float expected_conv1d_1[n_conv_transpose_1d_test_1] =\n       {5.0f, 18.0f, 26.0f, 18.0f,  5.0f};\n\n\n    struct ggml_tensor * conv1d_transpose_res_2 = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n       if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_transpose_res_2\") == 0) {\n            conv1d_transpose_res_2 = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<float> conv1d_transpose_data_2(ggml_nelements(conv1d_transpose_res_2));\n\n    
ggml_backend_tensor_get(conv1d_transpose_res_2, conv1d_transpose_data_2.data(), 0, ggml_nbytes(conv1d_transpose_res_2));\n\n\n    const int n_conv_transpose_1d_test_2 = 10;\n\n    float expected_conv1d_2[n_conv_transpose_1d_test_2] =\n       {7.0f, 18.0f, 22.0f, 18.0f,  7.0f,\n         5.0f, 18.0f, 26.0f, 18.0f,  5.0f};\n\n    struct ggml_tensor * conv1d_transpose_res_3 = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n       if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_transpose_res_3\") == 0) {\n            conv1d_transpose_res_3 = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<float> conv1d_transpose_data_3(ggml_nelements(conv1d_transpose_res_3));\n\n    ggml_backend_tensor_get(conv1d_transpose_res_3, conv1d_transpose_data_3.data(), 0, ggml_nbytes(conv1d_transpose_res_3));\n\n\n    const int n_conv_transpose_1d_test_3 = 14;\n\n    float expected_conv1d_3[n_conv_transpose_1d_test_3] =\n       {7.0f,  6.0f, 17.0f, 12.0f, 17.0f,  6.0f,  7.0f\n         ,5.0f,  6.0f, 19.0f, 12.0f, 19.0f,  6.0f,  5.0f};\n\n\n    struct ggml_tensor * conv1d_transpose_res_4 = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n       if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_transpose_res_4\") == 0) {\n            conv1d_transpose_res_4 = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<float> conv1d_transpose_data_4(ggml_nelements(conv1d_transpose_res_4));\n\n    ggml_backend_tensor_get(conv1d_transpose_res_4, conv1d_transpose_data_4.data(), 0, ggml_nbytes(conv1d_transpose_res_4));\n\n\n    const int n_conv_transpose_1d_test_4 = 12;\n\n    float expected_conv1d_4[3*4] = {\n        18.0, 45.0, 59.0, 37.0,\n        24.0, 61.0, 83.0, 51.0,\n        30.0, 77.0, 107.0, 65.0\n    };\n\n    struct ggml_tensor * conv1d_transpose_res_5 = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n       if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), 
\"conv1d_transpose_res_5\") == 0) {\n            conv1d_transpose_res_5 = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<float> conv1d_transpose_data_5(ggml_nelements(conv1d_transpose_res_5));\n\n    ggml_backend_tensor_get(conv1d_transpose_res_5, conv1d_transpose_data_5.data(), 0, ggml_nbytes(conv1d_transpose_res_5));\n\n\n    const int n_conv_transpose_1d_test_5 = 18;\n\n    float expected_conv1d_5[3*6] = {\n        18.0, 21.0, 24.0, 29.0, 30.0, 37.0,\n        24.0, 27.0, 34.0, 39.0, 44.0, 51.0,\n        30.0, 33.0, 44.0, 49.0, 58.0, 65.0\n        };\n\n    struct ggml_tensor * conv1d_transpose_res_6 = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n       if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_transpose_res_6\") == 0) {\n            conv1d_transpose_res_6 = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<float> conv1d_transpose_data_6(ggml_nelements(conv1d_transpose_res_6));\n\n    ggml_backend_tensor_get(conv1d_transpose_res_6, conv1d_transpose_data_6.data(), 0, ggml_nbytes(conv1d_transpose_res_6));\n\n\n    const int n_conv_transpose_1d_test_6 = 24;\n\n    float expected_conv1d_6[3*8] = {\n        18.0, 21.0, 0.0, 24.0, 29.0, 0.0, 30.0, 37.0,\n        24.0, 27.0, 0.0, 34.0, 39.0, 0.0, 44.0, 51.0,\n        30.0, 33.0, 0.0, 44.0, 49.0, 0.0, 58.0, 65.0};\n\n\n\n    struct ggml_tensor * conv1d_transpose_res_7 = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n       if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_transpose_res_7\") == 0) {\n            conv1d_transpose_res_7 = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<float> conv1d_transpose_data_7(ggml_nelements(conv1d_transpose_res_7));\n\n    ggml_backend_tensor_get(conv1d_transpose_res_7, conv1d_transpose_data_7.data(), 0, ggml_nbytes(conv1d_transpose_res_7));\n\n\n    const int n_conv_transpose_1d_test_7 = 32*1584;\n\n    float expected_conv1d_7[32*1584] = 
{4325376.0,4341168.0,4356960.0,4372752.0,4388544.0,4404336.0,4420128.0,4435920.0,8785280.0,8816896.0,8848512.0,8880128.0,8911744.0,8943360.0,8974976.0,9006592.0,8801920.0,8833600.0,8865280.0,8896960.0,8928640.0,8960320.0,8992000.0,9023680.0,8818560.0,8850304.0,8882048.0,8913792.0,8945536.0,8977280.0,9009024.0,9040768.0,8835200.0,8867008.0,8898816.0,8930624.0,8962432.0,8994240.0,9026048.0,9057856.0,8851840.0,8883712.0,8915584.0,8947456.0,8979328.0,9011200.0,9043072.0,9074944.0,8868480.0,8900416.0,8932352.0,8964288.0,8996224.0,9028160.0,9060096.0,9092032.0,8885120.0,8917120.0,8949120.0,8981120.0,9013120.0,9045120.0,9077120.0,9109120.0,8901760.0,8933824.0,8965888.0,8997952.0,9030016.0,9062080.0,9094144.0,9126208.0,8918400.0,8950528.0,8982656.0,9014784.0,9046912.0,9079040.0,9111168.0,9143296.0,8935040.0,8967232.0,8999424.0,9031616.0,9063808.0,9096000.0,9128192.0,9160384.0,8951680.0,8983936.0,9016192.0,9048448.0,9080704.0,9112960.0,9145216.0,9177472.0,8968320.0,9000640.0,9032960.0,9065280.0,9097600.0,9129920.0,9162240.0,9194560.0,8984960.0,9017344.0,9049728.0,9082112.0,9114496.0,9146880.0,9179264.0,9211648.0,9001600.0,9034048.0,9066496.0,9098944.0,9131392.0,9163840.0,9196288.0,9228736.0,9018240.0,9050752.0,9083264.0,9115776.0,9148288.0,9180800.0,9213312.0,9245824.0,9034880.0,9067456.0,9100032.0,9132608.0,9165184.0,9197760.0,9230336.0,9262912.0,9051520.0,9084160.0,9116800.0,9149440.0,9182080.0,9214720.0,9247360.0,9280000.0,9068160.0,9100864.0,9133568.0,9166272.0,9198976.0,9231680.0,9264384.0,9297088.0,9084800.0,9117568.0,9150336.0,9183104.0,9215872.0,9248640.0,9281408.0,9314176.0,9101440.0,9134272.0,9167104.0,9199936.0,9232768.0,9265600.0,9298432.0,9331264.0,9118080.0,9150976.0,9183872.0,9216768.0,9249664.0,9282560.0,9315456.0,9348352.0,9134720.0,9167680.0,9200640.0,9233600.0,9266560.0,9299520.0,9332480.0,9365440.0,9151360.0,9184384.0,9217408.0,9250432.0,9283456.0,9316480.0,9349504.0,9382528.0,9168000.0,9201088.0,9234176.0,9267264.0,9300352.0,9333440.0,9366528.0,9399616.0
,9184640.0,9217792.0,9250944.0,9284096.0,9317248.0,9350400.0,9383552.0,9416704.0,9201280.0,9234496.0,9267712.0,9300928.0,9334144.0,9367360.0,9400576.0,9433792.0,9217920.0,9251200.0,9284480.0,9317760.0,9351040.0,9384320.0,9417600.0,9450880.0,9234560.0,9267904.0,9301248.0,9334592.0,9367936.0,9401280.0,9434624.0,9467968.0,9251200.0,9284608.0,9318016.0,9351424.0,9384832.0,9418240.0,9451648.0,9485056.0,9267840.0,9301312.0,9334784.0,9368256.0,9401728.0,9435200.0,9468672.0,9502144.0,9284480.0,9318016.0,9351552.0,9385088.0,9418624.0,9452160.0,9485696.0,9519232.0,9301120.0,9334720.0,9368320.0,9401920.0,9435520.0,9469120.0,9502720.0,9536320.0,9317760.0,9351424.0,9385088.0,9418752.0,9452416.0,9486080.0,9519744.0,9553408.0,9334400.0,9368128.0,9401856.0,9435584.0,9469312.0,9503040.0,9536768.0,9570496.0,9351040.0,9384832.0,9418624.0,9452416.0,9486208.0,9520000.0,9553792.0,9587584.0,9367680.0,9401536.0,9435392.0,9469248.0,9503104.0,9536960.0,9570816.0,9604672.0,8860032.0,8892928.0,8925824.0,8958720.0,8991616.0,9024512.0,9057408.0,9090304.0,8344192.0,8376128.0,8408064.0,8440000.0,8471936.0,8503872.0,8535808.0,8567744.0,7836544.0,7867520.0,7898496.0,7929472.0,7960448.0,7991424.0,8022400.0,8053376.0,7320704.0,7350720.0,7380736.0,7410752.0,7440768.0,7470784.0,7500800.0,7530816.0,7337344.0,7367424.0,7397504.0,7427584.0,7457664.0,7487744.0,7517824.0,7547904.0,7353984.0,7384128.0,7414272.0,7444416.0,7474560.0,7504704.0,7534848.0,7564992.0,7370624.0,7400832.0,7431040.0,7461248.0,7491456.0,7521664.0,7551872.0,7582080.0,7387264.0,7417536.0,7447808.0,7478080.0,7508352.0,7538624.0,7568896.0,7599168.0,7403904.0,7434240.0,7464576.0,7494912.0,7525248.0,7555584.0,7585920.0,7616256.0,7420544.0,7450944.0,7481344.0,7511744.0,7542144.0,7572544.0,7602944.0,7633344.0,7437184.0,7467648.0,7498112.0,7528576.0,7559040.0,7589504.0,7619968.0,7650432.0,7453824.0,7484352.0,7514880.0,7545408.0,7575936.0,7606464.0,7636992.0,7667520.0,7470464.0,7501056.0,7531648.0,7562240.0,7592832.0,7623424.0,7654016.0,7684608.0
,7487104.0,7517760.0,7548416.0,7579072.0,7609728.0,7640384.0,7671040.0,7701696.0,7503744.0,7534464.0,7565184.0,7595904.0,7626624.0,7657344.0,7688064.0,7718784.0,7520384.0,7551168.0,7581952.0,7612736.0,7643520.0,7674304.0,7705088.0,7735872.0,7537024.0,7567872.0,7598720.0,7629568.0,7660416.0,7691264.0,7722112.0,7752960.0,7553664.0,7584576.0,7615488.0,7646400.0,7677312.0,7708224.0,7739136.0,7770048.0,7570304.0,7601280.0,7632256.0,7663232.0,7694208.0,7725184.0,7756160.0,7787136.0,7586944.0,7617984.0,7649024.0,7680064.0,7711104.0,7742144.0,7773184.0,7804224.0,7603584.0,7634688.0,7665792.0,7696896.0,7728000.0,7759104.0,7790208.0,7821312.0,7620224.0,7651392.0,7682560.0,7713728.0,7744896.0,7776064.0,7807232.0,7838400.0,7636864.0,7668096.0,7699328.0,7730560.0,7761792.0,7793024.0,7824256.0,7855488.0,7653504.0,7684800.0,7716096.0,7747392.0,7778688.0,7809984.0,7841280.0,7872576.0,7670144.0,7701504.0,7732864.0,7764224.0,7795584.0,7826944.0,7858304.0,7889664.0,7686784.0,7718208.0,7749632.0,7781056.0,7812480.0,7843904.0,7875328.0,7906752.0,7703424.0,7734912.0,7766400.0,7797888.0,7829376.0,7860864.0,7892352.0,7923840.0,7720064.0,7751616.0,7783168.0,7814720.0,7846272.0,7877824.0,7909376.0,7940928.0,7736704.0,7768320.0,7799936.0,7831552.0,7863168.0,7894784.0,7926400.0,7958016.0,7753344.0,7785024.0,7816704.0,7848384.0,7880064.0,7911744.0,7943424.0,7975104.0,7769984.0,7801728.0,7833472.0,7865216.0,7896960.0,7928704.0,7960448.0,7992192.0,7786624.0,7818432.0,7850240.0,7882048.0,7913856.0,7945664.0,7977472.0,8009280.0,7803264.0,7835136.0,7867008.0,7898880.0,7930752.0,7962624.0,7994496.0,8026368.0,7819904.0,7851840.0,7883776.0,7915712.0,7947648.0,7979584.0,8011520.0,8043456.0,7836544.0,7868544.0,7900544.0,7932544.0,7964544.0,7996544.0,8028544.0,8060544.0,7853184.0,7885248.0,7917312.0,7949376.0,7981440.0,8013504.0,8045568.0,8077632.0,7869824.0,7901952.0,7934080.0,7966208.0,7998336.0,8030464.0,8062592.0,8094720.0,7886464.0,7918656.0,7950848.0,7983040.0,8015232.0,8047424.0,8079616.0,8111808.0
,7903104.0,7935360.0,7967616.0,7999872.0,8032128.0,8064384.0,8096640.0,8128896.0,7919744.0,7952064.0,7984384.0,8016704.0,8049024.0,8081344.0,8113664.0,8145984.0,7936384.0,7968768.0,8001152.0,8033536.0,8065920.0,8098304.0,8130688.0,8163072.0,7953024.0,7984448.0,8015872.0,8047296.0,8078720.0,8110144.0,8141568.0,8172992.0,7961472.0,7991936.0,8022400.0,8052864.0,8083328.0,8113792.0,8144256.0,8174720.0,7978112.0,8008640.0,8039168.0,8069696.0,8100224.0,8130752.0,8161280.0,8191808.0,7994752.0,8025344.0,8055936.0,8086528.0,8117120.0,8147712.0,8178304.0,8208896.0,8011392.0,8042048.0,8072704.0,8103360.0,8134016.0,8164672.0,8195328.0,8225984.0,8028032.0,8058752.0,8089472.0,8120192.0,8150912.0,8181632.0,8212352.0,8243072.0,8044672.0,8075456.0,8106240.0,8137024.0,8167808.0,8198592.0,8229376.0,8260160.0,8061312.0,8092160.0,8123008.0,8153856.0,8184704.0,8215552.0,8246400.0,8277248.0,8077952.0,8108864.0,8139776.0,8170688.0,8201600.0,8232512.0,8263424.0,8294336.0,8094592.0,8125568.0,8156544.0,8187520.0,8218496.0,8249472.0,8280448.0,8311424.0,8111232.0,8142272.0,8173312.0,8204352.0,8235392.0,8266432.0,8297472.0,8328512.0,8127872.0,8158976.0,8190080.0,8221184.0,8252288.0,8283392.0,8314496.0,8345600.0,8144512.0,8175680.0,8206848.0,8238016.0,8269184.0,8300352.0,8331520.0,8362688.0,8161152.0,8192384.0,8223616.0,8254848.0,8286080.0,8317312.0,8348544.0,8379776.0,8177792.0,8209088.0,8240384.0,8271680.0,8302976.0,8334272.0,8365568.0,8396864.0,8194432.0,8225792.0,8257152.0,8288512.0,8319872.0,8351232.0,8382592.0,8413952.0,8211072.0,8242496.0,8273920.0,8305344.0,8336768.0,8368192.0,8399616.0,8431040.0,8227712.0,8259200.0,8290688.0,8322176.0,8353664.0,8385152.0,8416640.0,8448128.0,8244352.0,8275904.0,8307456.0,8339008.0,8370560.0,8402112.0,8433664.0,8465216.0,8260992.0,8292608.0,8324224.0,8355840.0,8387456.0,8419072.0,8450688.0,8482304.0,8277632.0,8309312.0,8340992.0,8372672.0,8404352.0,8436032.0,8467712.0,8499392.0,8294272.0,8326016.0,8357760.0,8389504.0,8421248.0,8452992.0,8484736.0,8516480.0
,8310912.0,8342720.0,8374528.0,8406336.0,8438144.0,8469952.0,8501760.0,8533568.0,8327552.0,8359424.0,8391296.0,8423168.0,8455040.0,8486912.0,8518784.0,8550656.0,8344192.0,8376128.0,8408064.0,8440000.0,8471936.0,8503872.0,8535808.0,8567744.0,8360832.0,8392832.0,8424832.0,8456832.0,8488832.0,8520832.0,8552832.0,8584832.0,8377472.0,8409536.0,8441600.0,8473664.0,8505728.0,8537792.0,8569856.0,8601920.0,8394112.0,8426240.0,8458368.0,8490496.0,8522624.0,8554752.0,8586880.0,8619008.0,8410752.0,8442944.0,8475136.0,8507328.0,8539520.0,8571712.0,8603904.0,8636096.0,8427392.0,8459648.0,8491904.0,8524160.0,8556416.0,8588672.0,8620928.0,8653184.0,8444032.0,8476352.0,8508672.0,8540992.0,8573312.0,8605632.0,8637952.0,8670272.0,8460672.0,8493056.0,8525440.0,8557824.0,8590208.0,8622592.0,8654976.0,8687360.0,8477312.0,8509760.0,8542208.0,8574656.0,8607104.0,8639552.0,8672000.0,8704448.0,8493952.0,8526464.0,8558976.0,8591488.0,8624000.0,8656512.0,8689024.0,8721536.0,8510592.0,8543168.0,8575744.0,8608320.0,8640896.0,8673472.0,8706048.0,8738624.0,8527232.0,8559872.0,8592512.0,8625152.0,8657792.0,8690432.0,8723072.0,8755712.0,8543872.0,8576576.0,8609280.0,8641984.0,8674688.0,8707392.0,8740096.0,8772800.0,8560512.0,8593280.0,8626048.0,8658816.0,8691584.0,8724352.0,8757120.0,8789888.0,8577152.0,8609984.0,8642816.0,8675648.0,8708480.0,8741312.0,8774144.0,8806976.0,8069504.0,8101376.0,8133248.0,8165120.0,8196992.0,8228864.0,8260736.0,8292608.0,7553664.0,7584576.0,7615488.0,7646400.0,7677312.0,7708224.0,7739136.0,7770048.0,7570304.0,7601280.0,7632256.0,7663232.0,7694208.0,7725184.0,7756160.0,7787136.0,7586944.0,7617984.0,7649024.0,7680064.0,7711104.0,7742144.0,7773184.0,7804224.0,7603584.0,7634688.0,7665792.0,7696896.0,7728000.0,7759104.0,7790208.0,7821312.0,7620224.0,7651392.0,7682560.0,7713728.0,7744896.0,7776064.0,7807232.0,7838400.0,7636864.0,7668096.0,7699328.0,7730560.0,7761792.0,7793024.0,7824256.0,7855488.0,7653504.0,7684800.0,7716096.0,7747392.0,7778688.0,7809984.0,7841280.0,7872576.0
,7670144.0,7701504.0,7732864.0,7764224.0,7795584.0,7826944.0,7858304.0,7889664.0,7686784.0,7718208.0,7749632.0,7781056.0,7812480.0,7843904.0,7875328.0,7906752.0,7703424.0,7734912.0,7766400.0,7797888.0,7829376.0,7860864.0,7892352.0,7923840.0,7720064.0,7751616.0,7783168.0,7814720.0,7846272.0,7877824.0,7909376.0,7940928.0,7736704.0,7768320.0,7799936.0,7831552.0,7863168.0,7894784.0,7926400.0,7958016.0,7753344.0,7785024.0,7816704.0,7848384.0,7880064.0,7911744.0,7943424.0,7975104.0,7769984.0,7801728.0,7833472.0,7865216.0,7896960.0,7928704.0,7960448.0,7992192.0,7786624.0,7818432.0,7850240.0,7882048.0,7913856.0,7945664.0,7977472.0,8009280.0,7803264.0,7835136.0,7867008.0,7898880.0,7930752.0,7962624.0,7994496.0,8026368.0,7819904.0,7851840.0,7883776.0,7915712.0,7947648.0,7979584.0,8011520.0,8043456.0,7836544.0,7868544.0,7900544.0,7932544.0,7964544.0,7996544.0,8028544.0,8060544.0,7853184.0,7885248.0,7917312.0,7949376.0,7981440.0,8013504.0,8045568.0,8077632.0,7869824.0,7901952.0,7934080.0,7966208.0,7998336.0,8030464.0,8062592.0,8094720.0,7886464.0,7918656.0,7950848.0,7983040.0,8015232.0,8047424.0,8079616.0,8111808.0,7903104.0,7935360.0,7967616.0,7999872.0,8032128.0,8064384.0,8096640.0,8128896.0,7919744.0,7952064.0,7984384.0,8016704.0,8049024.0,8081344.0,8113664.0,8145984.0,7936384.0,7968768.0,8001152.0,8033536.0,8065920.0,8098304.0,8130688.0,8163072.0,7953024.0,7985472.0,8017920.0,8050368.0,8082816.0,8115264.0,8147712.0,8180160.0,7969664.0,8002176.0,8034688.0,8067200.0,8099712.0,8132224.0,8164736.0,8197248.0,7986304.0,8018880.0,8051456.0,8084032.0,8116608.0,8149184.0,8181760.0,8214336.0,8002944.0,8035584.0,8068224.0,8100864.0,8133504.0,8166144.0,8198784.0,8231424.0,8019584.0,8052288.0,8084992.0,8117696.0,8150400.0,8183104.0,8215808.0,8248512.0,8036224.0,8068992.0,8101760.0,8134528.0,8167296.0,8200064.0,8232832.0,8265600.0,8052864.0,8085696.0,8118528.0,8151360.0,8184192.0,8217024.0,8249856.0,8282688.0,8069504.0,8102400.0,8135296.0,8168192.0,8201088.0,8233984.0,8266880.0,8299776.0
,8086144.0,8119104.0,8152064.0,8185024.0,8217984.0,8250944.0,8283904.0,8316864.0,8102784.0,8135808.0,8168832.0,8201856.0,8234880.0,8267904.0,8300928.0,8333952.0,8119424.0,8152512.0,8185600.0,8218688.0,8251776.0,8284864.0,8317952.0,8351040.0,8136064.0,8169216.0,8202368.0,8235520.0,8268672.0,8301824.0,8334976.0,8368128.0,8152704.0,8185920.0,8219136.0,8252352.0,8285568.0,8318784.0,8352000.0,8385216.0,8169344.0,8202624.0,8235904.0,8269184.0,8302464.0,8335744.0,8369024.0,8402304.0,8185984.0,8218304.0,8250624.0,8282944.0,8315264.0,8347584.0,8379904.0,8412224.0,8194432.0,8225792.0,8257152.0,8288512.0,8319872.0,8351232.0,8382592.0,8413952.0,8211072.0,8242496.0,8273920.0,8305344.0,8336768.0,8368192.0,8399616.0,8431040.0,8227712.0,8259200.0,8290688.0,8322176.0,8353664.0,8385152.0,8416640.0,8448128.0,8244352.0,8275904.0,8307456.0,8339008.0,8370560.0,8402112.0,8433664.0,8465216.0,8260992.0,8292608.0,8324224.0,8355840.0,8387456.0,8419072.0,8450688.0,8482304.0,8277632.0,8309312.0,8340992.0,8372672.0,8404352.0,8436032.0,8467712.0,8499392.0,8294272.0,8326016.0,8357760.0,8389504.0,8421248.0,8452992.0,8484736.0,8516480.0,8310912.0,8342720.0,8374528.0,8406336.0,8438144.0,8469952.0,8501760.0,8533568.0,8327552.0,8359424.0,8391296.0,8423168.0,8455040.0,8486912.0,8518784.0,8550656.0,8344192.0,8376128.0,8408064.0,8440000.0,8471936.0,8503872.0,8535808.0,8567744.0,8360832.0,8392832.0,8424832.0,8456832.0,8488832.0,8520832.0,8552832.0,8584832.0,8377472.0,8409536.0,8441600.0,8473664.0,8505728.0,8537792.0,8569856.0,8601920.0,8394112.0,8426240.0,8458368.0,8490496.0,8522624.0,8554752.0,8586880.0,8619008.0,8410752.0,8442944.0,8475136.0,8507328.0,8539520.0,8571712.0,8603904.0,8636096.0,8427392.0,8459648.0,8491904.0,8524160.0,8556416.0,8588672.0,8620928.0,8653184.0,8444032.0,8476352.0,8508672.0,8540992.0,8573312.0,8605632.0,8637952.0,8670272.0,8460672.0,8493056.0,8525440.0,8557824.0,8590208.0,8622592.0,8654976.0,8687360.0,8477312.0,8509760.0,8542208.0,8574656.0,8607104.0,8639552.0,8672000.0,8704448.0
,8493952.0,8526464.0,8558976.0,8591488.0,8624000.0,8656512.0,8689024.0,8721536.0,8510592.0,8543168.0,8575744.0,8608320.0,8640896.0,8673472.0,8706048.0,8738624.0,8527232.0,8559872.0,8592512.0,8625152.0,8657792.0,8690432.0,8723072.0,8755712.0,8543872.0,8576576.0,8609280.0,8641984.0,8674688.0,8707392.0,8740096.0,8772800.0,8560512.0,8593280.0,8626048.0,8658816.0,8691584.0,8724352.0,8757120.0,8789888.0,8577152.0,8609984.0,8642816.0,8675648.0,8708480.0,8741312.0,8774144.0,8806976.0,8593792.0,8626688.0,8659584.0,8692480.0,8725376.0,8758272.0,8791168.0,8824064.0,8610432.0,8643392.0,8676352.0,8709312.0,8742272.0,8775232.0,8808192.0,8841152.0,8627072.0,8660096.0,8693120.0,8726144.0,8759168.0,8792192.0,8825216.0,8858240.0,8643712.0,8676800.0,8709888.0,8742976.0,8776064.0,8809152.0,8842240.0,8875328.0,8660352.0,8693504.0,8726656.0,8759808.0,8792960.0,8826112.0,8859264.0,8892416.0,8676992.0,8710208.0,8743424.0,8776640.0,8809856.0,8843072.0,8876288.0,8909504.0,8693632.0,8726912.0,8760192.0,8793472.0,8826752.0,8860032.0,8893312.0,8926592.0,8710272.0,8743616.0,8776960.0,8810304.0,8843648.0,8876992.0,8910336.0,8943680.0,8726912.0,8760320.0,8793728.0,8827136.0,8860544.0,8893952.0,8927360.0,8960768.0,8743552.0,8777024.0,8810496.0,8843968.0,8877440.0,8910912.0,8944384.0,8977856.0,8760192.0,8793728.0,8827264.0,8860800.0,8894336.0,8927872.0,8961408.0,8994944.0,8776832.0,8810432.0,8844032.0,8877632.0,8911232.0,8944832.0,8978432.0,9012032.0,8793472.0,8827136.0,8860800.0,8894464.0,8928128.0,8961792.0,8995456.0,9029120.0,8810112.0,8843840.0,8877568.0,8911296.0,8945024.0,8978752.0,9012480.0,9046208.0,8302464.0,8335232.0,8368000.0,8400768.0,8433536.0,8466304.0,8499072.0,8531840.0,7786624.0,7818432.0,7850240.0,7882048.0,7913856.0,7945664.0,7977472.0,8009280.0,3961216.0,3977136.0,3993056.0,4008976.0,4024896.0,4040816.0,4056736.0,4072656.0,4578048.0,4593840.0,4609632.0,4625424.0,4641216.0,4657008.0,4672800.0,4688592.0,9291136.0,9322752.0,9354368.0,9385984.0,9417600.0,9449216.0,9480832.0,9512448.0
,9308800.0,9340480.0,9372160.0,9403840.0,9435520.0,9467200.0,9498880.0,9530560.0,9326464.0,9358208.0,9389952.0,9421696.0,9453440.0,9485184.0,9516928.0,9548672.0,9344128.0,9375936.0,9407744.0,9439552.0,9471360.0,9503168.0,9534976.0,9566784.0,9361792.0,9393664.0,9425536.0,9457408.0,9489280.0,9521152.0,9553024.0,9584896.0,9379456.0,9411392.0,9443328.0,9475264.0,9507200.0,9539136.0,9571072.0,9603008.0,9397120.0,9429120.0,9461120.0,9493120.0,9525120.0,9557120.0,9589120.0,9621120.0,9414784.0,9446848.0,9478912.0,9510976.0,9543040.0,9575104.0,9607168.0,9639232.0,9432448.0,9464576.0,9496704.0,9528832.0,9560960.0,9593088.0,9625216.0,9657344.0,9450112.0,9482304.0,9514496.0,9546688.0,9578880.0,9611072.0,9643264.0,9675456.0,9467776.0,9500032.0,9532288.0,9564544.0,9596800.0,9629056.0,9661312.0,9693568.0,9485440.0,9517760.0,9550080.0,9582400.0,9614720.0,9647040.0,9679360.0,9711680.0,9503104.0,9535488.0,9567872.0,9600256.0,9632640.0,9665024.0,9697408.0,9729792.0,9520768.0,9553216.0,9585664.0,9618112.0,9650560.0,9683008.0,9715456.0,9747904.0,9538432.0,9570944.0,9603456.0,9635968.0,9668480.0,9700992.0,9733504.0,9766016.0,9556096.0,9588672.0,9621248.0,9653824.0,9686400.0,9718976.0,9751552.0,9784128.0,9573760.0,9606400.0,9639040.0,9671680.0,9704320.0,9736960.0,9769600.0,9802240.0,9591424.0,9624128.0,9656832.0,9689536.0,9722240.0,9754944.0,9787648.0,9820352.0,9609088.0,9641856.0,9674624.0,9707392.0,9740160.0,9772928.0,9805696.0,9838464.0,9626752.0,9659584.0,9692416.0,9725248.0,9758080.0,9790912.0,9823744.0,9856576.0,9644416.0,9677312.0,9710208.0,9743104.0,9776000.0,9808896.0,9841792.0,9874688.0,9662080.0,9695040.0,9728000.0,9760960.0,9793920.0,9826880.0,9859840.0,9892800.0,9679744.0,9712768.0,9745792.0,9778816.0,9811840.0,9844864.0,9877888.0,9910912.0,9697408.0,9730496.0,9763584.0,9796672.0,9829760.0,9862848.0,9895936.0,9929024.0,9715072.0,9748224.0,9781376.0,9814528.0,9847680.0,9880832.0,9913984.0,9947136.0,9732736.0,9765952.0,9799168.0,9832384.0,9865600.0,9898816.0,9932032.0,9965248.0
,9750400.0,9783680.0,9816960.0,9850240.0,9883520.0,9916800.0,9950080.0,9983360.0,9768064.0,9801408.0,9834752.0,9868096.0,9901440.0,9934784.0,9968128.0,10001472.0,9785728.0,9819136.0,9852544.0,9885952.0,9919360.0,9952768.0,9986176.0,10019584.0,9803392.0,9836864.0,9870336.0,9903808.0,9937280.0,9970752.0,10004224.0,10037696.0,9821056.0,9854592.0,9888128.0,9921664.0,9955200.0,9988736.0,10022272.0,10055808.0,9838720.0,9872320.0,9905920.0,9939520.0,9973120.0,10006720.0,10040320.0,10073920.0,9856384.0,9890048.0,9923712.0,9957376.0,9991040.0,10024704.0,10058368.0,10092032.0,9874048.0,9907776.0,9941504.0,9975232.0,10008960.0,10042688.0,10076416.0,10110144.0,9891712.0,9925504.0,9959296.0,9993088.0,10026880.0,10060672.0,10094464.0,10128256.0,9909376.0,9943232.0,9977088.0,10010944.0,10044800.0,10078656.0,10112512.0,10146368.0,9386368.0,9419264.0,9452160.0,9485056.0,9517952.0,9550848.0,9583744.0,9616640.0,8855168.0,8887104.0,8919040.0,8950976.0,8982912.0,9014848.0,9046784.0,9078720.0,8332160.0,8363136.0,8394112.0,8425088.0,8456064.0,8487040.0,8518016.0,8548992.0,7800960.0,7830976.0,7860992.0,7891008.0,7921024.0,7951040.0,7981056.0,8011072.0,7818624.0,7848704.0,7878784.0,7908864.0,7938944.0,7969024.0,7999104.0,8029184.0,7836288.0,7866432.0,7896576.0,7926720.0,7956864.0,7987008.0,8017152.0,8047296.0,7853952.0,7884160.0,7914368.0,7944576.0,7974784.0,8004992.0,8035200.0,8065408.0,7871616.0,7901888.0,7932160.0,7962432.0,7992704.0,8022976.0,8053248.0,8083520.0,7889280.0,7919616.0,7949952.0,7980288.0,8010624.0,8040960.0,8071296.0,8101632.0,7906944.0,7937344.0,7967744.0,7998144.0,8028544.0,8058944.0,8089344.0,8119744.0,7924608.0,7955072.0,7985536.0,8016000.0,8046464.0,8076928.0,8107392.0,8137856.0,7942272.0,7972800.0,8003328.0,8033856.0,8064384.0,8094912.0,8125440.0,8155968.0,7959936.0,7990528.0,8021120.0,8051712.0,8082304.0,8112896.0,8143488.0,8174080.0,7977600.0,8008256.0,8038912.0,8069568.0,8100224.0,8130880.0,8161536.0,8192192.0,7995264.0,8025984.0,8056704.0,8087424.0,8118144.0,8148
864.0,8179584.0,8210304.0,8012928.0,8043712.0,8074496.0,8105280.0,8136064.0,8166848.0,8197632.0,8228416.0,8030592.0,8061440.0,8092288.0,8123136.0,8153984.0,8184832.0,8215680.0,8246528.0,8048256.0,8079168.0,8110080.0,8140992.0,8171904.0,8202816.0,8233728.0,8264640.0,8065920.0,8096896.0,8127872.0,8158848.0,8189824.0,8220800.0,8251776.0,8282752.0,8083584.0,8114624.0,8145664.0,8176704.0,8207744.0,8238784.0,8269824.0,8300864.0,8101248.0,8132352.0,8163456.0,8194560.0,8225664.0,8256768.0,8287872.0,8318976.0,8118912.0,8150080.0,8181248.0,8212416.0,8243584.0,8274752.0,8305920.0,8337088.0,8136576.0,8167808.0,8199040.0,8230272.0,8261504.0,8292736.0,8323968.0,8355200.0,8154240.0,8185536.0,8216832.0,8248128.0,8279424.0,8310720.0,8342016.0,8373312.0,8171904.0,8203264.0,8234624.0,8265984.0,8297344.0,8328704.0,8360064.0,8391424.0,8189568.0,8220992.0,8252416.0,8283840.0,8315264.0,8346688.0,8378112.0,8409536.0,8207232.0,8238720.0,8270208.0,8301696.0,8333184.0,8364672.0,8396160.0,8427648.0,8224896.0,8256448.0,8288000.0,8319552.0,8351104.0,8382656.0,8414208.0,8445760.0,8242560.0,8274176.0,8305792.0,8337408.0,8369024.0,8400640.0,8432256.0,8463872.0,8260224.0,8291904.0,8323584.0,8355264.0,8386944.0,8418624.0,8450304.0,8481984.0,8277888.0,8309632.0,8341376.0,8373120.0,8404864.0,8436608.0,8468352.0,8500096.0,8295552.0,8327360.0,8359168.0,8390976.0,8422784.0,8454592.0,8486400.0,8518208.0,8313216.0,8345088.0,8376960.0,8408832.0,8440704.0,8472576.0,8504448.0,8536320.0,8330880.0,8362816.0,8394752.0,8426688.0,8458624.0,8490560.0,8522496.0,8554432.0,8348544.0,8380544.0,8412544.0,8444544.0,8476544.0,8508544.0,8540544.0,8572544.0,8366208.0,8398272.0,8430336.0,8462400.0,8494464.0,8526528.0,8558592.0,8590656.0,8383872.0,8416000.0,8448128.0,8480256.0,8512384.0,8544512.0,8576640.0,8608768.0,8401536.0,8433728.0,8465920.0,8498112.0,8530304.0,8562496.0,8594688.0,8626880.0,8419200.0,8451456.0,8483712.0,8515968.0,8548224.0,8580480.0,8612736.0,8644992.0,8436864.0,8469184.0,8501504.0,8533824.0,8566144.0,8598
464.0,8630784.0,8663104.0,8454528.0,8486912.0,8519296.0,8551680.0,8584064.0,8616448.0,8648832.0,8681216.0,8455808.0,8487232.0,8518656.0,8550080.0,8581504.0,8612928.0,8644352.0,8675776.0,8448896.0,8479360.0,8509824.0,8540288.0,8570752.0,8601216.0,8631680.0,8662144.0,8466560.0,8497088.0,8527616.0,8558144.0,8588672.0,8619200.0,8649728.0,8680256.0,8484224.0,8514816.0,8545408.0,8576000.0,8606592.0,8637184.0,8667776.0,8698368.0,8501888.0,8532544.0,8563200.0,8593856.0,8624512.0,8655168.0,8685824.0,8716480.0,8519552.0,8550272.0,8580992.0,8611712.0,8642432.0,8673152.0,8703872.0,8734592.0,8537216.0,8568000.0,8598784.0,8629568.0,8660352.0,8691136.0,8721920.0,8752704.0,8554880.0,8585728.0,8616576.0,8647424.0,8678272.0,8709120.0,8739968.0,8770816.0,8572544.0,8603456.0,8634368.0,8665280.0,8696192.0,8727104.0,8758016.0,8788928.0,8590208.0,8621184.0,8652160.0,8683136.0,8714112.0,8745088.0,8776064.0,8807040.0,8607872.0,8638912.0,8669952.0,8700992.0,8732032.0,8763072.0,8794112.0,8825152.0,8625536.0,8656640.0,8687744.0,8718848.0,8749952.0,8781056.0,8812160.0,8843264.0,8643200.0,8674368.0,8705536.0,8736704.0,8767872.0,8799040.0,8830208.0,8861376.0,8660864.0,8692096.0,8723328.0,8754560.0,8785792.0,8817024.0,8848256.0,8879488.0,8678528.0,8709824.0,8741120.0,8772416.0,8803712.0,8835008.0,8866304.0,8897600.0,8696192.0,8727552.0,8758912.0,8790272.0,8821632.0,8852992.0,8884352.0,8915712.0,8713856.0,8745280.0,8776704.0,8808128.0,8839552.0,8870976.0,8902400.0,8933824.0,8731520.0,8763008.0,8794496.0,8825984.0,8857472.0,8888960.0,8920448.0,8951936.0,8749184.0,8780736.0,8812288.0,8843840.0,8875392.0,8906944.0,8938496.0,8970048.0,8766848.0,8798464.0,8830080.0,8861696.0,8893312.0,8924928.0,8956544.0,8988160.0,8784512.0,8816192.0,8847872.0,8879552.0,8911232.0,8942912.0,8974592.0,9006272.0,8802176.0,8833920.0,8865664.0,8897408.0,8929152.0,8960896.0,8992640.0,9024384.0,8819840.0,8851648.0,8883456.0,8915264.0,8947072.0,8978880.0,9010688.0,9042496.0,8837504.0,8869376.0,8901248.0,8933120.0,8964992.0,8996
864.0,9028736.0,9060608.0,8855168.0,8887104.0,8919040.0,8950976.0,8982912.0,9014848.0,9046784.0,9078720.0,8872832.0,8904832.0,8936832.0,8968832.0,9000832.0,9032832.0,9064832.0,9096832.0,8890496.0,8922560.0,8954624.0,8986688.0,9018752.0,9050816.0,9082880.0,9114944.0,8908160.0,8940288.0,8972416.0,9004544.0,9036672.0,9068800.0,9100928.0,9133056.0,8925824.0,8958016.0,8990208.0,9022400.0,9054592.0,9086784.0,9118976.0,9151168.0,8943488.0,8975744.0,9008000.0,9040256.0,9072512.0,9104768.0,9137024.0,9169280.0,8961152.0,8993472.0,9025792.0,9058112.0,9090432.0,9122752.0,9155072.0,9187392.0,8978816.0,9011200.0,9043584.0,9075968.0,9108352.0,9140736.0,9173120.0,9205504.0,8996480.0,9028928.0,9061376.0,9093824.0,9126272.0,9158720.0,9191168.0,9223616.0,9014144.0,9046656.0,9079168.0,9111680.0,9144192.0,9176704.0,9209216.0,9241728.0,9031808.0,9064384.0,9096960.0,9129536.0,9162112.0,9194688.0,9227264.0,9259840.0,9049472.0,9082112.0,9114752.0,9147392.0,9180032.0,9212672.0,9245312.0,9277952.0,9067136.0,9099840.0,9132544.0,9165248.0,9197952.0,9230656.0,9263360.0,9296064.0,9084800.0,9117568.0,9150336.0,9183104.0,9215872.0,9248640.0,9281408.0,9314176.0,9102464.0,9135296.0,9168128.0,9200960.0,9233792.0,9266624.0,9299456.0,9332288.0,8579456.0,8611328.0,8643200.0,8675072.0,8706944.0,8738816.0,8770688.0,8802560.0,8048256.0,8079168.0,8110080.0,8140992.0,8171904.0,8202816.0,8233728.0,8264640.0,8065920.0,8096896.0,8127872.0,8158848.0,8189824.0,8220800.0,8251776.0,8282752.0,8083584.0,8114624.0,8145664.0,8176704.0,8207744.0,8238784.0,8269824.0,8300864.0,8101248.0,8132352.0,8163456.0,8194560.0,8225664.0,8256768.0,8287872.0,8318976.0,8118912.0,8150080.0,8181248.0,8212416.0,8243584.0,8274752.0,8305920.0,8337088.0,8136576.0,8167808.0,8199040.0,8230272.0,8261504.0,8292736.0,8323968.0,8355200.0,8154240.0,8185536.0,8216832.0,8248128.0,8279424.0,8310720.0,8342016.0,8373312.0,8171904.0,8203264.0,8234624.0,8265984.0,8297344.0,8328704.0,8360064.0,8391424.0,8189568.0,8220992.0,8252416.0,8283840.0,8315264.0,8346
688.0,8378112.0,8409536.0,8207232.0,8238720.0,8270208.0,8301696.0,8333184.0,8364672.0,8396160.0,8427648.0,8224896.0,8256448.0,8288000.0,8319552.0,8351104.0,8382656.0,8414208.0,8445760.0,8242560.0,8274176.0,8305792.0,8337408.0,8369024.0,8400640.0,8432256.0,8463872.0,8260224.0,8291904.0,8323584.0,8355264.0,8386944.0,8418624.0,8450304.0,8481984.0,8277888.0,8309632.0,8341376.0,8373120.0,8404864.0,8436608.0,8468352.0,8500096.0,8295552.0,8327360.0,8359168.0,8390976.0,8422784.0,8454592.0,8486400.0,8518208.0,8313216.0,8345088.0,8376960.0,8408832.0,8440704.0,8472576.0,8504448.0,8536320.0,8330880.0,8362816.0,8394752.0,8426688.0,8458624.0,8490560.0,8522496.0,8554432.0,8348544.0,8380544.0,8412544.0,8444544.0,8476544.0,8508544.0,8540544.0,8572544.0,8366208.0,8398272.0,8430336.0,8462400.0,8494464.0,8526528.0,8558592.0,8590656.0,8383872.0,8416000.0,8448128.0,8480256.0,8512384.0,8544512.0,8576640.0,8608768.0,8401536.0,8433728.0,8465920.0,8498112.0,8530304.0,8562496.0,8594688.0,8626880.0,8419200.0,8451456.0,8483712.0,8515968.0,8548224.0,8580480.0,8612736.0,8644992.0,8436864.0,8469184.0,8501504.0,8533824.0,8566144.0,8598464.0,8630784.0,8663104.0,8454528.0,8486912.0,8519296.0,8551680.0,8584064.0,8616448.0,8648832.0,8681216.0,8472192.0,8504640.0,8537088.0,8569536.0,8601984.0,8634432.0,8666880.0,8699328.0,8489856.0,8522368.0,8554880.0,8587392.0,8619904.0,8652416.0,8684928.0,8717440.0,8507520.0,8540096.0,8572672.0,8605248.0,8637824.0,8670400.0,8702976.0,8735552.0,8525184.0,8557824.0,8590464.0,8623104.0,8655744.0,8688384.0,8721024.0,8753664.0,8542848.0,8575552.0,8608256.0,8640960.0,8673664.0,8706368.0,8739072.0,8771776.0,8560512.0,8593280.0,8626048.0,8658816.0,8691584.0,8724352.0,8757120.0,8789888.0,8578176.0,8611008.0,8643840.0,8676672.0,8709504.0,8742336.0,8775168.0,8808000.0,8595840.0,8628736.0,8661632.0,8694528.0,8727424.0,8760320.0,8793216.0,8826112.0,8613504.0,8646464.0,8679424.0,8712384.0,8745344.0,8778304.0,8811264.0,8844224.0,8631168.0,8664192.0,8697216.0,8730240.0,8763264.0,8796
288.0,8829312.0,8862336.0,8648832.0,8681920.0,8715008.0,8748096.0,8781184.0,8814272.0,8847360.0,8880448.0,8666496.0,8699648.0,8732800.0,8765952.0,8799104.0,8832256.0,8865408.0,8898560.0,8684160.0,8717376.0,8750592.0,8783808.0,8817024.0,8850240.0,8883456.0,8916672.0,8701824.0,8735104.0,8768384.0,8801664.0,8834944.0,8868224.0,8901504.0,8934784.0,8703104.0,8735424.0,8767744.0,8800064.0,8832384.0,8864704.0,8897024.0,8929344.0,8696192.0,8727552.0,8758912.0,8790272.0,8821632.0,8852992.0,8884352.0,8915712.0,8713856.0,8745280.0,8776704.0,8808128.0,8839552.0,8870976.0,8902400.0,8933824.0,8731520.0,8763008.0,8794496.0,8825984.0,8857472.0,8888960.0,8920448.0,8951936.0,8749184.0,8780736.0,8812288.0,8843840.0,8875392.0,8906944.0,8938496.0,8970048.0,8766848.0,8798464.0,8830080.0,8861696.0,8893312.0,8924928.0,8956544.0,8988160.0,8784512.0,8816192.0,8847872.0,8879552.0,8911232.0,8942912.0,8974592.0,9006272.0,8802176.0,8833920.0,8865664.0,8897408.0,8929152.0,8960896.0,8992640.0,9024384.0,8819840.0,8851648.0,8883456.0,8915264.0,8947072.0,8978880.0,9010688.0,9042496.0,8837504.0,8869376.0,8901248.0,8933120.0,8964992.0,8996864.0,9028736.0,9060608.0,8855168.0,8887104.0,8919040.0,8950976.0,8982912.0,9014848.0,9046784.0,9078720.0,8872832.0,8904832.0,8936832.0,8968832.0,9000832.0,9032832.0,9064832.0,9096832.0,8890496.0,8922560.0,8954624.0,8986688.0,9018752.0,9050816.0,9082880.0,9114944.0,8908160.0,8940288.0,8972416.0,9004544.0,9036672.0,9068800.0,9100928.0,9133056.0,8925824.0,8958016.0,8990208.0,9022400.0,9054592.0,9086784.0,9118976.0,9151168.0,8943488.0,8975744.0,9008000.0,9040256.0,9072512.0,9104768.0,9137024.0,9169280.0,8961152.0,8993472.0,9025792.0,9058112.0,9090432.0,9122752.0,9155072.0,9187392.0,8978816.0,9011200.0,9043584.0,9075968.0,9108352.0,9140736.0,9173120.0,9205504.0,8996480.0,9028928.0,9061376.0,9093824.0,9126272.0,9158720.0,9191168.0,9223616.0,9014144.0,9046656.0,9079168.0,9111680.0,9144192.0,9176704.0,9209216.0,9241728.0,9031808.0,9064384.0,9096960.0,9129536.0,9162112.0,9194
688.0,9227264.0,9259840.0,9049472.0,9082112.0,9114752.0,9147392.0,9180032.0,9212672.0,9245312.0,9277952.0,9067136.0,9099840.0,9132544.0,9165248.0,9197952.0,9230656.0,9263360.0,9296064.0,9084800.0,9117568.0,9150336.0,9183104.0,9215872.0,9248640.0,9281408.0,9314176.0,9102464.0,9135296.0,9168128.0,9200960.0,9233792.0,9266624.0,9299456.0,9332288.0,9120128.0,9153024.0,9185920.0,9218816.0,9251712.0,9284608.0,9317504.0,9350400.0,9137792.0,9170752.0,9203712.0,9236672.0,9269632.0,9302592.0,9335552.0,9368512.0,9155456.0,9188480.0,9221504.0,9254528.0,9287552.0,9320576.0,9353600.0,9386624.0,9173120.0,9206208.0,9239296.0,9272384.0,9305472.0,9338560.0,9371648.0,9404736.0,9190784.0,9223936.0,9257088.0,9290240.0,9323392.0,9356544.0,9389696.0,9422848.0,9208448.0,9241664.0,9274880.0,9308096.0,9341312.0,9374528.0,9407744.0,9440960.0,9226112.0,9259392.0,9292672.0,9325952.0,9359232.0,9392512.0,9425792.0,9459072.0,9243776.0,9277120.0,9310464.0,9343808.0,9377152.0,9410496.0,9443840.0,9477184.0,9261440.0,9294848.0,9328256.0,9361664.0,9395072.0,9428480.0,9461888.0,9495296.0,9279104.0,9312576.0,9346048.0,9379520.0,9412992.0,9446464.0,9479936.0,9513408.0,9296768.0,9330304.0,9363840.0,9397376.0,9430912.0,9464448.0,9497984.0,9531520.0,9314432.0,9348032.0,9381632.0,9415232.0,9448832.0,9482432.0,9516032.0,9549632.0,9332096.0,9365760.0,9399424.0,9433088.0,9466752.0,9500416.0,9534080.0,9567744.0,9349760.0,9383488.0,9417216.0,9450944.0,9484672.0,9518400.0,9552128.0,9585856.0,8826752.0,8859520.0,8892288.0,8925056.0,8957824.0,8990592.0,9023360.0,9056128.0,8295552.0,8327360.0,8359168.0,8390976.0,8422784.0,8454592.0,8486400.0,8518208.0,4215936.0,4231856.0,4247776.0,4263696.0,4279616.0,4295536.0,4311456.0,4327376.0,4830720.0,4846512.0,4862304.0,4878096.0,4893888.0,4909680.0,4925472.0,4941264.0,9796992.0,9828608.0,9860224.0,9891840.0,9923456.0,9955072.0,9986688.0,10018304.0,9815680.0,9847360.0,9879040.0,9910720.0,9942400.0,9974080.0,10005760.0,10037440.0,9834368.0,9866112.0,9897856.0,9929600.0,9961344.0,9
993088.0,10024832.0,10056576.0,9853056.0,9884864.0,9916672.0,9948480.0,9980288.0,10012096.0,10043904.0,10075712.0,9871744.0,9903616.0,9935488.0,9967360.0,9999232.0,10031104.0,10062976.0,10094848.0,9890432.0,9922368.0,9954304.0,9986240.0,10018176.0,10050112.0,10082048.0,10113984.0,9909120.0,9941120.0,9973120.0,10005120.0,10037120.0,10069120.0,10101120.0,10133120.0,9927808.0,9959872.0,9991936.0,10024000.0,10056064.0,10088128.0,10120192.0,10152256.0,9946496.0,9978624.0,10010752.0,10042880.0,10075008.0,10107136.0,10139264.0,10171392.0,9965184.0,9997376.0,10029568.0,10061760.0,10093952.0,10126144.0,10158336.0,10190528.0,9983872.0,10016128.0,10048384.0,10080640.0,10112896.0,10145152.0,10177408.0,10209664.0,10002560.0,10034880.0,10067200.0,10099520.0,10131840.0,10164160.0,10196480.0,10228800.0,10021248.0,10053632.0,10086016.0,10118400.0,10150784.0,10183168.0,10215552.0,10247936.0,10039936.0,10072384.0,10104832.0,10137280.0,10169728.0,10202176.0,10234624.0,10267072.0,10058624.0,10091136.0,10123648.0,10156160.0,10188672.0,10221184.0,10253696.0,10286208.0,10077312.0,10109888.0,10142464.0,10175040.0,10207616.0,10240192.0,10272768.0,10305344.0,10096000.0,10128640.0,10161280.0,10193920.0,10226560.0,10259200.0,10291840.0,10324480.0,10114688.0,10147392.0,10180096.0,10212800.0,10245504.0,10278208.0,10310912.0,10343616.0,10133376.0,10166144.0,10198912.0,10231680.0,10264448.0,10297216.0,10329984.0,10362752.0,10152064.0,10184896.0,10217728.0,10250560.0,10283392.0,10316224.0,10349056.0,10381888.0,10170752.0,10203648.0,10236544.0,10269440.0,10302336.0,10335232.0,10368128.0,10401024.0,10189440.0,10222400.0,10255360.0,10288320.0,10321280.0,10354240.0,10387200.0,10420160.0,10208128.0,10241152.0,10274176.0,10307200.0,10340224.0,10373248.0,10406272.0,10439296.0,10226816.0,10259904.0,10292992.0,10326080.0,10359168.0,10392256.0,10425344.0,10458432.0,10245504.0,10278656.0,10311808.0,10344960.0,10378112.0,10411264.0,10444416.0,10477568.0,10264192.0,10297408.0,10330624.0,10363840.0,10397056.0,104
30272.0,10463488.0,10496704.0,10282880.0,10316160.0,10349440.0,10382720.0,10416000.0,10449280.0,10482560.0,10515840.0,10301568.0,10334912.0,10368256.0,10401600.0,10434944.0,10468288.0,10501632.0,10534976.0,10320256.0,10353664.0,10387072.0,10420480.0,10453888.0,10487296.0,10520704.0,10554112.0,10338944.0,10372416.0,10405888.0,10439360.0,10472832.0,10506304.0,10539776.0,10573248.0,10357632.0,10391168.0,10424704.0,10458240.0,10491776.0,10525312.0,10558848.0,10592384.0,10376320.0,10409920.0,10443520.0,10477120.0,10510720.0,10544320.0,10577920.0,10611520.0,10395008.0,10428672.0,10462336.0,10496000.0,10529664.0,10563328.0,10596992.0,10630656.0,10413696.0,10447424.0,10481152.0,10514880.0,10548608.0,10582336.0,10616064.0,10649792.0,10432384.0,10466176.0,10499968.0,10533760.0,10567552.0,10601344.0,10635136.0,10668928.0,10451072.0,10484928.0,10518784.0,10552640.0,10586496.0,10620352.0,10654208.0,10688064.0,9912704.0,9945600.0,9978496.0,10011392.0,10044288.0,10077184.0,10110080.0,10142976.0,9366144.0,9398080.0,9430016.0,9461952.0,9493888.0,9525824.0,9557760.0,9589696.0,8827776.0,8858752.0,8889728.0,8920704.0,8951680.0,8982656.0,9013632.0,9044608.0,8281216.0,8311232.0,8341248.0,8371264.0,8401280.0,8431296.0,8461312.0,8491328.0,8299904.0,8329984.0,8360064.0,8390144.0,8420224.0,8450304.0,8480384.0,8510464.0,8318592.0,8348736.0,8378880.0,8409024.0,8439168.0,8469312.0,8499456.0,8529600.0,8337280.0,8367488.0,8397696.0,8427904.0,8458112.0,8488320.0,8518528.0,8548736.0,8355968.0,8386240.0,8416512.0,8446784.0,8477056.0,8507328.0,8537600.0,8567872.0,8374656.0,8404992.0,8435328.0,8465664.0,8496000.0,8526336.0,8556672.0,8587008.0,8393344.0,8423744.0,8454144.0,8484544.0,8514944.0,8545344.0,8575744.0,8606144.0,8412032.0,8442496.0,8472960.0,8503424.0,8533888.0,8564352.0,8594816.0,8625280.0,8430720.0,8461248.0,8491776.0,8522304.0,8552832.0,8583360.0,8613888.0,8644416.0,8449408.0,8480000.0,8510592.0,8541184.0,8571776.0,8602368.0,8632960.0,8663552.0,8468096.0,8498752.0,8529408.0,8560064.0,85907
20.0,8621376.0,8652032.0,8682688.0,8486784.0,8517504.0,8548224.0,8578944.0,8609664.0,8640384.0,8671104.0,8701824.0,8505472.0,8536256.0,8567040.0,8597824.0,8628608.0,8659392.0,8690176.0,8720960.0,8524160.0,8555008.0,8585856.0,8616704.0,8647552.0,8678400.0,8709248.0,8740096.0,8542848.0,8573760.0,8604672.0,8635584.0,8666496.0,8697408.0,8728320.0,8759232.0,8561536.0,8592512.0,8623488.0,8654464.0,8685440.0,8716416.0,8747392.0,8778368.0,8580224.0,8611264.0,8642304.0,8673344.0,8704384.0,8735424.0,8766464.0,8797504.0,8598912.0,8630016.0,8661120.0,8692224.0,8723328.0,8754432.0,8785536.0,8816640.0,8617600.0,8648768.0,8679936.0,8711104.0,8742272.0,8773440.0,8804608.0,8835776.0,8636288.0,8667520.0,8698752.0,8729984.0,8761216.0,8792448.0,8823680.0,8854912.0,8654976.0,8686272.0,8717568.0,8748864.0,8780160.0,8811456.0,8842752.0,8874048.0,8673664.0,8705024.0,8736384.0,8767744.0,8799104.0,8830464.0,8861824.0,8893184.0,8692352.0,8723776.0,8755200.0,8786624.0,8818048.0,8849472.0,8880896.0,8912320.0,8711040.0,8742528.0,8774016.0,8805504.0,8836992.0,8868480.0,8899968.0,8931456.0,8729728.0,8761280.0,8792832.0,8824384.0,8855936.0,8887488.0,8919040.0,8950592.0,8748416.0,8780032.0,8811648.0,8843264.0,8874880.0,8906496.0,8938112.0,8969728.0,8767104.0,8798784.0,8830464.0,8862144.0,8893824.0,8925504.0,8957184.0,8988864.0,8785792.0,8817536.0,8849280.0,8881024.0,8912768.0,8944512.0,8976256.0,9008000.0,8804480.0,8836288.0,8868096.0,8899904.0,8931712.0,8963520.0,8995328.0,9027136.0,8823168.0,8855040.0,8886912.0,8918784.0,8950656.0,8982528.0,9014400.0,9046272.0,8841856.0,8873792.0,8905728.0,8937664.0,8969600.0,9001536.0,9033472.0,9065408.0,8860544.0,8892544.0,8924544.0,8956544.0,8988544.0,9020544.0,9052544.0,9084544.0,8879232.0,8911296.0,8943360.0,8975424.0,9007488.0,9039552.0,9071616.0,9103680.0,8897920.0,8930048.0,8962176.0,8994304.0,9026432.0,9058560.0,9090688.0,9122816.0,8916608.0,8948800.0,8980992.0,9013184.0,9045376.0,9077568.0,9109760.0,9141952.0,8935296.0,8967552.0,8999808.0,9032064.0,90643
20.0,9096576.0,9128832.0,9161088.0,8953984.0,8986304.0,9018624.0,9050944.0,9083264.0,9115584.0,9147904.0,9180224.0,8972672.0,9005056.0,9037440.0,9069824.0,9102208.0,9134592.0,9166976.0,9199360.0,8958592.0,8990016.0,9021440.0,9052864.0,9084288.0,9115712.0,9147136.0,9178560.0,8936320.0,8966784.0,8997248.0,9027712.0,9058176.0,9088640.0,9119104.0,9149568.0,8955008.0,8985536.0,9016064.0,9046592.0,9077120.0,9107648.0,9138176.0,9168704.0,8973696.0,9004288.0,9034880.0,9065472.0,9096064.0,9126656.0,9157248.0,9187840.0,8992384.0,9023040.0,9053696.0,9084352.0,9115008.0,9145664.0,9176320.0,9206976.0,9011072.0,9041792.0,9072512.0,9103232.0,9133952.0,9164672.0,9195392.0,9226112.0,9029760.0,9060544.0,9091328.0,9122112.0,9152896.0,9183680.0,9214464.0,9245248.0,9048448.0,9079296.0,9110144.0,9140992.0,9171840.0,9202688.0,9233536.0,9264384.0,9067136.0,9098048.0,9128960.0,9159872.0,9190784.0,9221696.0,9252608.0,9283520.0,9085824.0,9116800.0,9147776.0,9178752.0,9209728.0,9240704.0,9271680.0,9302656.0,9104512.0,9135552.0,9166592.0,9197632.0,9228672.0,9259712.0,9290752.0,9321792.0,9123200.0,9154304.0,9185408.0,9216512.0,9247616.0,9278720.0,9309824.0,9340928.0,9141888.0,9173056.0,9204224.0,9235392.0,9266560.0,9297728.0,9328896.0,9360064.0,9160576.0,9191808.0,9223040.0,9254272.0,9285504.0,9316736.0,9347968.0,9379200.0,9179264.0,9210560.0,9241856.0,9273152.0,9304448.0,9335744.0,9367040.0,9398336.0,9197952.0,9229312.0,9260672.0,9292032.0,9323392.0,9354752.0,9386112.0,9417472.0,9216640.0,9248064.0,9279488.0,9310912.0,9342336.0,9373760.0,9405184.0,9436608.0,9235328.0,9266816.0,9298304.0,9329792.0,9361280.0,9392768.0,9424256.0,9455744.0,9254016.0,9285568.0,9317120.0,9348672.0,9380224.0,9411776.0,9443328.0,9474880.0,9272704.0,9304320.0,9335936.0,9367552.0,9399168.0,9430784.0,9462400.0,9494016.0,9291392.0,9323072.0,9354752.0,9386432.0,9418112.0,9449792.0,9481472.0,9513152.0,9310080.0,9341824.0,9373568.0,9405312.0,9437056.0,9468800.0,9500544.0,9532288.0,9328768.0,9360576.0,9392384.0,9424192.0,94560
00.0,9487808.0,9519616.0,9551424.0,9347456.0,9379328.0,9411200.0,9443072.0,9474944.0,9506816.0,9538688.0,9570560.0,9366144.0,9398080.0,9430016.0,9461952.0,9493888.0,9525824.0,9557760.0,9589696.0,9384832.0,9416832.0,9448832.0,9480832.0,9512832.0,9544832.0,9576832.0,9608832.0,9403520.0,9435584.0,9467648.0,9499712.0,9531776.0,9563840.0,9595904.0,9627968.0,9422208.0,9454336.0,9486464.0,9518592.0,9550720.0,9582848.0,9614976.0,9647104.0,9440896.0,9473088.0,9505280.0,9537472.0,9569664.0,9601856.0,9634048.0,9666240.0,9459584.0,9491840.0,9524096.0,9556352.0,9588608.0,9620864.0,9653120.0,9685376.0,9478272.0,9510592.0,9542912.0,9575232.0,9607552.0,9639872.0,9672192.0,9704512.0,9496960.0,9529344.0,9561728.0,9594112.0,9626496.0,9658880.0,9691264.0,9723648.0,9515648.0,9548096.0,9580544.0,9612992.0,9645440.0,9677888.0,9710336.0,9742784.0,9534336.0,9566848.0,9599360.0,9631872.0,9664384.0,9696896.0,9729408.0,9761920.0,9553024.0,9585600.0,9618176.0,9650752.0,9683328.0,9715904.0,9748480.0,9781056.0,9571712.0,9604352.0,9636992.0,9669632.0,9702272.0,9734912.0,9767552.0,9800192.0,9590400.0,9623104.0,9655808.0,9688512.0,9721216.0,9753920.0,9786624.0,9819328.0,9609088.0,9641856.0,9674624.0,9707392.0,9740160.0,9772928.0,9805696.0,9838464.0,9627776.0,9660608.0,9693440.0,9726272.0,9759104.0,9791936.0,9824768.0,9857600.0,9089408.0,9121280.0,9153152.0,9185024.0,9216896.0,9248768.0,9280640.0,9312512.0,8542848.0,8573760.0,8604672.0,8635584.0,8666496.0,8697408.0,8728320.0,8759232.0,8561536.0,8592512.0,8623488.0,8654464.0,8685440.0,8716416.0,8747392.0,8778368.0,8580224.0,8611264.0,8642304.0,8673344.0,8704384.0,8735424.0,8766464.0,8797504.0,8598912.0,8630016.0,8661120.0,8692224.0,8723328.0,8754432.0,8785536.0,8816640.0,8617600.0,8648768.0,8679936.0,8711104.0,8742272.0,8773440.0,8804608.0,8835776.0,8636288.0,8667520.0,8698752.0,8729984.0,8761216.0,8792448.0,8823680.0,8854912.0,8654976.0,8686272.0,8717568.0,8748864.0,8780160.0,8811456.0,8842752.0,8874048.0,8673664.0,8705024.0,8736384.0,8767744.0,87991
04.0,8830464.0,8861824.0,8893184.0,8692352.0,8723776.0,8755200.0,8786624.0,8818048.0,8849472.0,8880896.0,8912320.0,8711040.0,8742528.0,8774016.0,8805504.0,8836992.0,8868480.0,8899968.0,8931456.0,8729728.0,8761280.0,8792832.0,8824384.0,8855936.0,8887488.0,8919040.0,8950592.0,8748416.0,8780032.0,8811648.0,8843264.0,8874880.0,8906496.0,8938112.0,8969728.0,8767104.0,8798784.0,8830464.0,8862144.0,8893824.0,8925504.0,8957184.0,8988864.0,8785792.0,8817536.0,8849280.0,8881024.0,8912768.0,8944512.0,8976256.0,9008000.0,8804480.0,8836288.0,8868096.0,8899904.0,8931712.0,8963520.0,8995328.0,9027136.0,8823168.0,8855040.0,8886912.0,8918784.0,8950656.0,8982528.0,9014400.0,9046272.0,8841856.0,8873792.0,8905728.0,8937664.0,8969600.0,9001536.0,9033472.0,9065408.0,8860544.0,8892544.0,8924544.0,8956544.0,8988544.0,9020544.0,9052544.0,9084544.0,8879232.0,8911296.0,8943360.0,8975424.0,9007488.0,9039552.0,9071616.0,9103680.0,8897920.0,8930048.0,8962176.0,8994304.0,9026432.0,9058560.0,9090688.0,9122816.0,8916608.0,8948800.0,8980992.0,9013184.0,9045376.0,9077568.0,9109760.0,9141952.0,8935296.0,8967552.0,8999808.0,9032064.0,9064320.0,9096576.0,9128832.0,9161088.0,8953984.0,8986304.0,9018624.0,9050944.0,9083264.0,9115584.0,9147904.0,9180224.0,8972672.0,9005056.0,9037440.0,9069824.0,9102208.0,9134592.0,9166976.0,9199360.0,8991360.0,9023808.0,9056256.0,9088704.0,9121152.0,9153600.0,9186048.0,9218496.0,9010048.0,9042560.0,9075072.0,9107584.0,9140096.0,9172608.0,9205120.0,9237632.0,9028736.0,9061312.0,9093888.0,9126464.0,9159040.0,9191616.0,9224192.0,9256768.0,9047424.0,9080064.0,9112704.0,9145344.0,9177984.0,9210624.0,9243264.0,9275904.0,9066112.0,9098816.0,9131520.0,9164224.0,9196928.0,9229632.0,9262336.0,9295040.0,9084800.0,9117568.0,9150336.0,9183104.0,9215872.0,9248640.0,9281408.0,9314176.0,9103488.0,9136320.0,9169152.0,9201984.0,9234816.0,9267648.0,9300480.0,9333312.0,9122176.0,9155072.0,9187968.0,9220864.0,9253760.0,9286656.0,9319552.0,9352448.0,9140864.0,9173824.0,9206784.0,9239744.0,92727
04.0,9305664.0,9338624.0,9371584.0,9159552.0,9192576.0,9225600.0,9258624.0,9291648.0,9324672.0,9357696.0,9390720.0,9178240.0,9211328.0,9244416.0,9277504.0,9310592.0,9343680.0,9376768.0,9409856.0,9196928.0,9230080.0,9263232.0,9296384.0,9329536.0,9362688.0,9395840.0,9428992.0,9215616.0,9248832.0,9282048.0,9315264.0,9348480.0,9381696.0,9414912.0,9448128.0,9234304.0,9267584.0,9300864.0,9334144.0,9367424.0,9400704.0,9433984.0,9467264.0,9220224.0,9252544.0,9284864.0,9317184.0,9349504.0,9381824.0,9414144.0,9446464.0,9197952.0,9229312.0,9260672.0,9292032.0,9323392.0,9354752.0,9386112.0,9417472.0,9216640.0,9248064.0,9279488.0,9310912.0,9342336.0,9373760.0,9405184.0,9436608.0,9235328.0,9266816.0,9298304.0,9329792.0,9361280.0,9392768.0,9424256.0,9455744.0,9254016.0,9285568.0,9317120.0,9348672.0,9380224.0,9411776.0,9443328.0,9474880.0,9272704.0,9304320.0,9335936.0,9367552.0,9399168.0,9430784.0,9462400.0,9494016.0,9291392.0,9323072.0,9354752.0,9386432.0,9418112.0,9449792.0,9481472.0,9513152.0,9310080.0,9341824.0,9373568.0,9405312.0,9437056.0,9468800.0,9500544.0,9532288.0,9328768.0,9360576.0,9392384.0,9424192.0,9456000.0,9487808.0,9519616.0,9551424.0,9347456.0,9379328.0,9411200.0,9443072.0,9474944.0,9506816.0,9538688.0,9570560.0,9366144.0,9398080.0,9430016.0,9461952.0,9493888.0,9525824.0,9557760.0,9589696.0,9384832.0,9416832.0,9448832.0,9480832.0,9512832.0,9544832.0,9576832.0,9608832.0,9403520.0,9435584.0,9467648.0,9499712.0,9531776.0,9563840.0,9595904.0,9627968.0,9422208.0,9454336.0,9486464.0,9518592.0,9550720.0,9582848.0,9614976.0,9647104.0,9440896.0,9473088.0,9505280.0,9537472.0,9569664.0,9601856.0,9634048.0,9666240.0,9459584.0,9491840.0,9524096.0,9556352.0,9588608.0,9620864.0,9653120.0,9685376.0,9478272.0,9510592.0,9542912.0,9575232.0,9607552.0,9639872.0,9672192.0,9704512.0,9496960.0,9529344.0,9561728.0,9594112.0,9626496.0,9658880.0,9691264.0,9723648.0,9515648.0,9548096.0,9580544.0,9612992.0,9645440.0,9677888.0,9710336.0,9742784.0,9534336.0,9566848.0,9599360.0,9631872.0,96643
84.0,9696896.0,9729408.0,9761920.0,9553024.0,9585600.0,9618176.0,9650752.0,9683328.0,9715904.0,9748480.0,9781056.0,9571712.0,9604352.0,9636992.0,9669632.0,9702272.0,9734912.0,9767552.0,9800192.0,9590400.0,9623104.0,9655808.0,9688512.0,9721216.0,9753920.0,9786624.0,9819328.0,9609088.0,9641856.0,9674624.0,9707392.0,9740160.0,9772928.0,9805696.0,9838464.0,9627776.0,9660608.0,9693440.0,9726272.0,9759104.0,9791936.0,9824768.0,9857600.0,9646464.0,9679360.0,9712256.0,9745152.0,9778048.0,9810944.0,9843840.0,9876736.0,9665152.0,9698112.0,9731072.0,9764032.0,9796992.0,9829952.0,9862912.0,9895872.0,9683840.0,9716864.0,9749888.0,9782912.0,9815936.0,9848960.0,9881984.0,9915008.0,9702528.0,9735616.0,9768704.0,9801792.0,9834880.0,9867968.0,9901056.0,9934144.0,9721216.0,9754368.0,9787520.0,9820672.0,9853824.0,9886976.0,9920128.0,9953280.0,9739904.0,9773120.0,9806336.0,9839552.0,9872768.0,9905984.0,9939200.0,9972416.0,9758592.0,9791872.0,9825152.0,9858432.0,9891712.0,9924992.0,9958272.0,9991552.0,9777280.0,9810624.0,9843968.0,9877312.0,9910656.0,9944000.0,9977344.0,10010688.0,9795968.0,9829376.0,9862784.0,9896192.0,9929600.0,9963008.0,9996416.0,10029824.0,9814656.0,9848128.0,9881600.0,9915072.0,9948544.0,9982016.0,10015488.0,10048960.0,9833344.0,9866880.0,9900416.0,9933952.0,9967488.0,10001024.0,10034560.0,10068096.0,9852032.0,9885632.0,9919232.0,9952832.0,9986432.0,10020032.0,10053632.0,10087232.0,9870720.0,9904384.0,9938048.0,9971712.0,10005376.0,10039040.0,10072704.0,10106368.0,9889408.0,9923136.0,9956864.0,9990592.0,10024320.0,10058048.0,10091776.0,10125504.0,9351040.0,9383808.0,9416576.0,9449344.0,9482112.0,9514880.0,9547648.0,9580416.0,8804480.0,8836288.0,8868096.0,8899904.0,8931712.0,8963520.0,8995328.0,9027136.0,4470656.0,4486576.0,4502496.0,4518416.0,4534336.0,4550256.0,4566176.0,4582096.0,5083392.0,5099184.0,5114976.0,5130768.0,5146560.0,5162352.0,5178144.0,5193936.0,10302848.0,10334464.0,10366080.0,10397696.0,10429312.0,10460928.0,10492544.0,10524160.0,10322560.0,10354240
.0,10385920.0,10417600.0,10449280.0,10480960.0,10512640.0,10544320.0,10342272.0,10374016.0,10405760.0,10437504.0,10469248.0,10500992.0,10532736.0,10564480.0,10361984.0,10393792.0,10425600.0,10457408.0,10489216.0,10521024.0,10552832.0,10584640.0,10381696.0,10413568.0,10445440.0,10477312.0,10509184.0,10541056.0,10572928.0,10604800.0,10401408.0,10433344.0,10465280.0,10497216.0,10529152.0,10561088.0,10593024.0,10624960.0,10421120.0,10453120.0,10485120.0,10517120.0,10549120.0,10581120.0,10613120.0,10645120.0,10440832.0,10472896.0,10504960.0,10537024.0,10569088.0,10601152.0,10633216.0,10665280.0,10460544.0,10492672.0,10524800.0,10556928.0,10589056.0,10621184.0,10653312.0,10685440.0,10480256.0,10512448.0,10544640.0,10576832.0,10609024.0,10641216.0,10673408.0,10705600.0,10499968.0,10532224.0,10564480.0,10596736.0,10628992.0,10661248.0,10693504.0,10725760.0,10519680.0,10552000.0,10584320.0,10616640.0,10648960.0,10681280.0,10713600.0,10745920.0,10539392.0,10571776.0,10604160.0,10636544.0,10668928.0,10701312.0,10733696.0,10766080.0,10559104.0,10591552.0,10624000.0,10656448.0,10688896.0,10721344.0,10753792.0,10786240.0,10578816.0,10611328.0,10643840.0,10676352.0,10708864.0,10741376.0,10773888.0,10806400.0,10598528.0,10631104.0,10663680.0,10696256.0,10728832.0,10761408.0,10793984.0,10826560.0,10618240.0,10650880.0,10683520.0,10716160.0,10748800.0,10781440.0,10814080.0,10846720.0,10637952.0,10670656.0,10703360.0,10736064.0,10768768.0,10801472.0,10834176.0,10866880.0,10657664.0,10690432.0,10723200.0,10755968.0,10788736.0,10821504.0,10854272.0,10887040.0,10677376.0,10710208.0,10743040.0,10775872.0,10808704.0,10841536.0,10874368.0,10907200.0,10697088.0,10729984.0,10762880.0,10795776.0,10828672.0,10861568.0,10894464.0,10927360.0,10716800.0,10749760.0,10782720.0,10815680.0,10848640.0,10881600.0,10914560.0,10947520.0,10736512.0,10769536.0,10802560.0,10835584.0,10868608.0,10901632.0,10934656.0,10967680.0,10756224.0,10789312.0,10822400.0,10855488.0,10888576.0,10921664.0,10954752.0,109878
40.0,10775936.0,10809088.0,10842240.0,10875392.0,10908544.0,10941696.0,10974848.0,11008000.0,10795648.0,10828864.0,10862080.0,10895296.0,10928512.0,10961728.0,10994944.0,11028160.0,10815360.0,10848640.0,10881920.0,10915200.0,10948480.0,10981760.0,11015040.0,11048320.0,10835072.0,10868416.0,10901760.0,10935104.0,10968448.0,11001792.0,11035136.0,11068480.0,10854784.0,10888192.0,10921600.0,10955008.0,10988416.0,11021824.0,11055232.0,11088640.0,10874496.0,10907968.0,10941440.0,10974912.0,11008384.0,11041856.0,11075328.0,11108800.0,10894208.0,10927744.0,10961280.0,10994816.0,11028352.0,11061888.0,11095424.0,11128960.0,10913920.0,10947520.0,10981120.0,11014720.0,11048320.0,11081920.0,11115520.0,11149120.0,10933632.0,10967296.0,11000960.0,11034624.0,11068288.0,11101952.0,11135616.0,11169280.0,10953344.0,10987072.0,11020800.0,11054528.0,11088256.0,11121984.0,11155712.0,11189440.0,10973056.0,11006848.0,11040640.0,11074432.0,11108224.0,11142016.0,11175808.0,11209600.0,10992768.0,11026624.0,11060480.0,11094336.0,11128192.0,11162048.0,11195904.0,11229760.0,10439040.0,10471936.0,10504832.0,10537728.0,10570624.0,10603520.0,10636416.0,10669312.0,9877120.0,9909056.0,9940992.0,9972928.0,10004864.0,10036800.0,10068736.0,10100672.0,9323392.0,9354368.0,9385344.0,9416320.0,9447296.0,9478272.0,9509248.0,9540224.0,8761472.0,8791488.0,8821504.0,8851520.0,8881536.0,8911552.0,8941568.0,8971584.0,8781184.0,8811264.0,8841344.0,8871424.0,8901504.0,8931584.0,8961664.0,8991744.0,8800896.0,8831040.0,8861184.0,8891328.0,8921472.0,8951616.0,8981760.0,9011904.0,8820608.0,8850816.0,8881024.0,8911232.0,8941440.0,8971648.0,9001856.0,9032064.0,8840320.0,8870592.0,8900864.0,8931136.0,8961408.0,8991680.0,9021952.0,9052224.0,8860032.0,8890368.0,8920704.0,8951040.0,8981376.0,9011712.0,9042048.0,9072384.0,8879744.0,8910144.0,8940544.0,8970944.0,9001344.0,9031744.0,9062144.0,9092544.0,8899456.0,8929920.0,8960384.0,8990848.0,9021312.0,9051776.0,9082240.0,9112704.0,8919168.0,8949696.0,8980224.0,9010752.0,9041280
.0,9071808.0,9102336.0,9132864.0,8938880.0,8969472.0,9000064.0,9030656.0,9061248.0,9091840.0,9122432.0,9153024.0,8958592.0,8989248.0,9019904.0,9050560.0,9081216.0,9111872.0,9142528.0,9173184.0,8978304.0,9009024.0,9039744.0,9070464.0,9101184.0,9131904.0,9162624.0,9193344.0,8998016.0,9028800.0,9059584.0,9090368.0,9121152.0,9151936.0,9182720.0,9213504.0,9017728.0,9048576.0,9079424.0,9110272.0,9141120.0,9171968.0,9202816.0,9233664.0,9037440.0,9068352.0,9099264.0,9130176.0,9161088.0,9192000.0,9222912.0,9253824.0,9057152.0,9088128.0,9119104.0,9150080.0,9181056.0,9212032.0,9243008.0,9273984.0,9076864.0,9107904.0,9138944.0,9169984.0,9201024.0,9232064.0,9263104.0,9294144.0,9096576.0,9127680.0,9158784.0,9189888.0,9220992.0,9252096.0,9283200.0,9314304.0,9116288.0,9147456.0,9178624.0,9209792.0,9240960.0,9272128.0,9303296.0,9334464.0,9136000.0,9167232.0,9198464.0,9229696.0,9260928.0,9292160.0,9323392.0,9354624.0,9155712.0,9187008.0,9218304.0,9249600.0,9280896.0,9312192.0,9343488.0,9374784.0,9175424.0,9206784.0,9238144.0,9269504.0,9300864.0,9332224.0,9363584.0,9394944.0,9195136.0,9226560.0,9257984.0,9289408.0,9320832.0,9352256.0,9383680.0,9415104.0,9214848.0,9246336.0,9277824.0,9309312.0,9340800.0,9372288.0,9403776.0,9435264.0,9234560.0,9266112.0,9297664.0,9329216.0,9360768.0,9392320.0,9423872.0,9455424.0,9254272.0,9285888.0,9317504.0,9349120.0,9380736.0,9412352.0,9443968.0,9475584.0,9273984.0,9305664.0,9337344.0,9369024.0,9400704.0,9432384.0,9464064.0,9495744.0,9293696.0,9325440.0,9357184.0,9388928.0,9420672.0,9452416.0,9484160.0,9515904.0,9313408.0,9345216.0,9377024.0,9408832.0,9440640.0,9472448.0,9504256.0,9536064.0,9333120.0,9364992.0,9396864.0,9428736.0,9460608.0,9492480.0,9524352.0,9556224.0,9352832.0,9384768.0,9416704.0,9448640.0,9480576.0,9512512.0,9544448.0,9576384.0,9372544.0,9404544.0,9436544.0,9468544.0,9500544.0,9532544.0,9564544.0,9596544.0,9392256.0,9424320.0,9456384.0,9488448.0,9520512.0,9552576.0,9584640.0,9616704.0,9411968.0,9444096.0,9476224.0,9508352.0,9540480
.0,9572608.0,9604736.0,9636864.0,9431680.0,9463872.0,9496064.0,9528256.0,9560448.0,9592640.0,9624832.0,9657024.0,9451392.0,9483648.0,9515904.0,9548160.0,9580416.0,9612672.0,9644928.0,9677184.0,9471104.0,9503424.0,9535744.0,9568064.0,9600384.0,9632704.0,9665024.0,9697344.0,9490816.0,9523200.0,9555584.0,9587968.0,9620352.0,9652736.0,9685120.0,9717504.0,9461376.0,9492800.0,9524224.0,9555648.0,9587072.0,9618496.0,9649920.0,9681344.0,9423744.0,9454208.0,9484672.0,9515136.0,9545600.0,9576064.0,9606528.0,9636992.0,9443456.0,9473984.0,9504512.0,9535040.0,9565568.0,9596096.0,9626624.0,9657152.0,9463168.0,9493760.0,9524352.0,9554944.0,9585536.0,9616128.0,9646720.0,9677312.0,9482880.0,9513536.0,9544192.0,9574848.0,9605504.0,9636160.0,9666816.0,9697472.0,9502592.0,9533312.0,9564032.0,9594752.0,9625472.0,9656192.0,9686912.0,9717632.0,9522304.0,9553088.0,9583872.0,9614656.0,9645440.0,9676224.0,9707008.0,9737792.0,9542016.0,9572864.0,9603712.0,9634560.0,9665408.0,9696256.0,9727104.0,9757952.0,9561728.0,9592640.0,9623552.0,9654464.0,9685376.0,9716288.0,9747200.0,9778112.0,9581440.0,9612416.0,9643392.0,9674368.0,9705344.0,9736320.0,9767296.0,9798272.0,9601152.0,9632192.0,9663232.0,9694272.0,9725312.0,9756352.0,9787392.0,9818432.0,9620864.0,9651968.0,9683072.0,9714176.0,9745280.0,9776384.0,9807488.0,9838592.0,9640576.0,9671744.0,9702912.0,9734080.0,9765248.0,9796416.0,9827584.0,9858752.0,9660288.0,9691520.0,9722752.0,9753984.0,9785216.0,9816448.0,9847680.0,9878912.0,9680000.0,9711296.0,9742592.0,9773888.0,9805184.0,9836480.0,9867776.0,9899072.0,9699712.0,9731072.0,9762432.0,9793792.0,9825152.0,9856512.0,9887872.0,9919232.0,9719424.0,9750848.0,9782272.0,9813696.0,9845120.0,9876544.0,9907968.0,9939392.0,9739136.0,9770624.0,9802112.0,9833600.0,9865088.0,9896576.0,9928064.0,9959552.0,9758848.0,9790400.0,9821952.0,9853504.0,9885056.0,9916608.0,9948160.0,9979712.0,9778560.0,9810176.0,9841792.0,9873408.0,9905024.0,9936640.0,9968256.0,9999872.0,9798272.0,9829952.0,9861632.0,9893312.0,9924992
.0,9956672.0,9988352.0,10020032.0,9817984.0,9849728.0,9881472.0,9913216.0,9944960.0,9976704.0,10008448.0,10040192.0,9837696.0,9869504.0,9901312.0,9933120.0,9964928.0,9996736.0,10028544.0,10060352.0,9857408.0,9889280.0,9921152.0,9953024.0,9984896.0,10016768.0,10048640.0,10080512.0,9877120.0,9909056.0,9940992.0,9972928.0,10004864.0,10036800.0,10068736.0,10100672.0,9896832.0,9928832.0,9960832.0,9992832.0,10024832.0,10056832.0,10088832.0,10120832.0,9916544.0,9948608.0,9980672.0,10012736.0,10044800.0,10076864.0,10108928.0,10140992.0,9936256.0,9968384.0,10000512.0,10032640.0,10064768.0,10096896.0,10129024.0,10161152.0,9955968.0,9988160.0,10020352.0,10052544.0,10084736.0,10116928.0,10149120.0,10181312.0,9975680.0,10007936.0,10040192.0,10072448.0,10104704.0,10136960.0,10169216.0,10201472.0,9995392.0,10027712.0,10060032.0,10092352.0,10124672.0,10156992.0,10189312.0,10221632.0,10015104.0,10047488.0,10079872.0,10112256.0,10144640.0,10177024.0,10209408.0,10241792.0,10034816.0,10067264.0,10099712.0,10132160.0,10164608.0,10197056.0,10229504.0,10261952.0,10054528.0,10087040.0,10119552.0,10152064.0,10184576.0,10217088.0,10249600.0,10282112.0,10074240.0,10106816.0,10139392.0,10171968.0,10204544.0,10237120.0,10269696.0,10302272.0,10093952.0,10126592.0,10159232.0,10191872.0,10224512.0,10257152.0,10289792.0,10322432.0,10113664.0,10146368.0,10179072.0,10211776.0,10244480.0,10277184.0,10309888.0,10342592.0,10133376.0,10166144.0,10198912.0,10231680.0,10264448.0,10297216.0,10329984.0,10362752.0,10153088.0,10185920.0,10218752.0,10251584.0,10284416.0,10317248.0,10350080.0,10382912.0,9599360.0,9631232.0,9663104.0,9694976.0,9726848.0,9758720.0,9790592.0,9822464.0,9037440.0,9068352.0,9099264.0,9130176.0,9161088.0,9192000.0,9222912.0,9253824.0,9057152.0,9088128.0,9119104.0,9150080.0,9181056.0,9212032.0,9243008.0,9273984.0,9076864.0,9107904.0,9138944.0,9169984.0,9201024.0,9232064.0,9263104.0,9294144.0,9096576.0,9127680.0,9158784.0,9189888.0,9220992.0,9252096.0,9283200.0,9314304.0,9116288.0,914745
6.0,9178624.0,9209792.0,9240960.0,9272128.0,9303296.0,9334464.0,9136000.0,9167232.0,9198464.0,9229696.0,9260928.0,9292160.0,9323392.0,9354624.0,9155712.0,9187008.0,9218304.0,9249600.0,9280896.0,9312192.0,9343488.0,9374784.0,9175424.0,9206784.0,9238144.0,9269504.0,9300864.0,9332224.0,9363584.0,9394944.0,9195136.0,9226560.0,9257984.0,9289408.0,9320832.0,9352256.0,9383680.0,9415104.0,9214848.0,9246336.0,9277824.0,9309312.0,9340800.0,9372288.0,9403776.0,9435264.0,9234560.0,9266112.0,9297664.0,9329216.0,9360768.0,9392320.0,9423872.0,9455424.0,9254272.0,9285888.0,9317504.0,9349120.0,9380736.0,9412352.0,9443968.0,9475584.0,9273984.0,9305664.0,9337344.0,9369024.0,9400704.0,9432384.0,9464064.0,9495744.0,9293696.0,9325440.0,9357184.0,9388928.0,9420672.0,9452416.0,9484160.0,9515904.0,9313408.0,9345216.0,9377024.0,9408832.0,9440640.0,9472448.0,9504256.0,9536064.0,9333120.0,9364992.0,9396864.0,9428736.0,9460608.0,9492480.0,9524352.0,9556224.0,9352832.0,9384768.0,9416704.0,9448640.0,9480576.0,9512512.0,9544448.0,9576384.0,9372544.0,9404544.0,9436544.0,9468544.0,9500544.0,9532544.0,9564544.0,9596544.0,9392256.0,9424320.0,9456384.0,9488448.0,9520512.0,9552576.0,9584640.0,9616704.0,9411968.0,9444096.0,9476224.0,9508352.0,9540480.0,9572608.0,9604736.0,9636864.0,9431680.0,9463872.0,9496064.0,9528256.0,9560448.0,9592640.0,9624832.0,9657024.0,9451392.0,9483648.0,9515904.0,9548160.0,9580416.0,9612672.0,9644928.0,9677184.0,9471104.0,9503424.0,9535744.0,9568064.0,9600384.0,9632704.0,9665024.0,9697344.0,9490816.0,9523200.0,9555584.0,9587968.0,9620352.0,9652736.0,9685120.0,9717504.0,9510528.0,9542976.0,9575424.0,9607872.0,9640320.0,9672768.0,9705216.0,9737664.0,9530240.0,9562752.0,9595264.0,9627776.0,9660288.0,9692800.0,9725312.0,9757824.0,9549952.0,9582528.0,9615104.0,9647680.0,9680256.0,9712832.0,9745408.0,9777984.0,9569664.0,9602304.0,9634944.0,9667584.0,9700224.0,9732864.0,9765504.0,9798144.0,9589376.0,9622080.0,9654784.0,9687488.0,9720192.0,9752896.0,9785600.0,9818304.0,9609088.0,964185
6.0,9674624.0,9707392.0,9740160.0,9772928.0,9805696.0,9838464.0,9628800.0,9661632.0,9694464.0,9727296.0,9760128.0,9792960.0,9825792.0,9858624.0,9648512.0,9681408.0,9714304.0,9747200.0,9780096.0,9812992.0,9845888.0,9878784.0,9668224.0,9701184.0,9734144.0,9767104.0,9800064.0,9833024.0,9865984.0,9898944.0,9687936.0,9720960.0,9753984.0,9787008.0,9820032.0,9853056.0,9886080.0,9919104.0,9707648.0,9740736.0,9773824.0,9806912.0,9840000.0,9873088.0,9906176.0,9939264.0,9727360.0,9760512.0,9793664.0,9826816.0,9859968.0,9893120.0,9926272.0,9959424.0,9747072.0,9780288.0,9813504.0,9846720.0,9879936.0,9913152.0,9946368.0,9979584.0,9766784.0,9800064.0,9833344.0,9866624.0,9899904.0,9933184.0,9966464.0,9999744.0,9737344.0,9769664.0,9801984.0,9834304.0,9866624.0,9898944.0,9931264.0,9963584.0,9699712.0,9731072.0,9762432.0,9793792.0,9825152.0,9856512.0,9887872.0,9919232.0,9719424.0,9750848.0,9782272.0,9813696.0,9845120.0,9876544.0,9907968.0,9939392.0,9739136.0,9770624.0,9802112.0,9833600.0,9865088.0,9896576.0,9928064.0,9959552.0,9758848.0,9790400.0,9821952.0,9853504.0,9885056.0,9916608.0,9948160.0,9979712.0,9778560.0,9810176.0,9841792.0,9873408.0,9905024.0,9936640.0,9968256.0,9999872.0,9798272.0,9829952.0,9861632.0,9893312.0,9924992.0,9956672.0,9988352.0,10020032.0,9817984.0,9849728.0,9881472.0,9913216.0,9944960.0,9976704.0,10008448.0,10040192.0,9837696.0,9869504.0,9901312.0,9933120.0,9964928.0,9996736.0,10028544.0,10060352.0,9857408.0,9889280.0,9921152.0,9953024.0,9984896.0,10016768.0,10048640.0,10080512.0,9877120.0,9909056.0,9940992.0,9972928.0,10004864.0,10036800.0,10068736.0,10100672.0,9896832.0,9928832.0,9960832.0,9992832.0,10024832.0,10056832.0,10088832.0,10120832.0,9916544.0,9948608.0,9980672.0,10012736.0,10044800.0,10076864.0,10108928.0,10140992.0,9936256.0,9968384.0,10000512.0,10032640.0,10064768.0,10096896.0,10129024.0,10161152.0,9955968.0,9988160.0,10020352.0,10052544.0,10084736.0,10116928.0,10149120.0,10181312.0,9975680.0,10007936.0,10040192.0,10072448.0,10104704.0,10136960.
0,10169216.0,10201472.0,9995392.0,10027712.0,10060032.0,10092352.0,10124672.0,10156992.0,10189312.0,10221632.0,10015104.0,10047488.0,10079872.0,10112256.0,10144640.0,10177024.0,10209408.0,10241792.0,10034816.0,10067264.0,10099712.0,10132160.0,10164608.0,10197056.0,10229504.0,10261952.0,10054528.0,10087040.0,10119552.0,10152064.0,10184576.0,10217088.0,10249600.0,10282112.0,10074240.0,10106816.0,10139392.0,10171968.0,10204544.0,10237120.0,10269696.0,10302272.0,10093952.0,10126592.0,10159232.0,10191872.0,10224512.0,10257152.0,10289792.0,10322432.0,10113664.0,10146368.0,10179072.0,10211776.0,10244480.0,10277184.0,10309888.0,10342592.0,10133376.0,10166144.0,10198912.0,10231680.0,10264448.0,10297216.0,10329984.0,10362752.0,10153088.0,10185920.0,10218752.0,10251584.0,10284416.0,10317248.0,10350080.0,10382912.0,10172800.0,10205696.0,10238592.0,10271488.0,10304384.0,10337280.0,10370176.0,10403072.0,10192512.0,10225472.0,10258432.0,10291392.0,10324352.0,10357312.0,10390272.0,10423232.0,10212224.0,10245248.0,10278272.0,10311296.0,10344320.0,10377344.0,10410368.0,10443392.0,10231936.0,10265024.0,10298112.0,10331200.0,10364288.0,10397376.0,10430464.0,10463552.0,10251648.0,10284800.0,10317952.0,10351104.0,10384256.0,10417408.0,10450560.0,10483712.0,10271360.0,10304576.0,10337792.0,10371008.0,10404224.0,10437440.0,10470656.0,10503872.0,10291072.0,10324352.0,10357632.0,10390912.0,10424192.0,10457472.0,10490752.0,10524032.0,10310784.0,10344128.0,10377472.0,10410816.0,10444160.0,10477504.0,10510848.0,10544192.0,10330496.0,10363904.0,10397312.0,10430720.0,10464128.0,10497536.0,10530944.0,10564352.0,10350208.0,10383680.0,10417152.0,10450624.0,10484096.0,10517568.0,10551040.0,10584512.0,10369920.0,10403456.0,10436992.0,10470528.0,10504064.0,10537600.0,10571136.0,10604672.0,10389632.0,10423232.0,10456832.0,10490432.0,10524032.0,10557632.0,10591232.0,10624832.0,10409344.0,10443008.0,10476672.0,10510336.0,10544000.0,10577664.0,10611328.0,10644992.0,10429056.0,10462784.0,10496512.0,10530240
.0,10563968.0,10597696.0,10631424.0,10665152.0,9875328.0,9908096.0,9940864.0,9973632.0,10006400.0,10039168.0,10071936.0,10104704.0,9313408.0,9345216.0,9377024.0,9408832.0,9440640.0,9472448.0,9504256.0,9536064.0,4725376.0,4741296.0,4757216.0,4773136.0,4789056.0,4804976.0,4820896.0,4836816.0,5336064.0,5351856.0,5367648.0,5383440.0,5399232.0,5415024.0,5430816.0,5446608.0,10808704.0,10840320.0,10871936.0,10903552.0,10935168.0,10966784.0,10998400.0,11030016.0,10829440.0,10861120.0,10892800.0,10924480.0,10956160.0,10987840.0,11019520.0,11051200.0,10850176.0,10881920.0,10913664.0,10945408.0,10977152.0,11008896.0,11040640.0,11072384.0,10870912.0,10902720.0,10934528.0,10966336.0,10998144.0,11029952.0,11061760.0,11093568.0,10891648.0,10923520.0,10955392.0,10987264.0,11019136.0,11051008.0,11082880.0,11114752.0,10912384.0,10944320.0,10976256.0,11008192.0,11040128.0,11072064.0,11104000.0,11135936.0,10933120.0,10965120.0,10997120.0,11029120.0,11061120.0,11093120.0,11125120.0,11157120.0,10953856.0,10985920.0,11017984.0,11050048.0,11082112.0,11114176.0,11146240.0,11178304.0,10974592.0,11006720.0,11038848.0,11070976.0,11103104.0,11135232.0,11167360.0,11199488.0,10995328.0,11027520.0,11059712.0,11091904.0,11124096.0,11156288.0,11188480.0,11220672.0,11016064.0,11048320.0,11080576.0,11112832.0,11145088.0,11177344.0,11209600.0,11241856.0,11036800.0,11069120.0,11101440.0,11133760.0,11166080.0,11198400.0,11230720.0,11263040.0,11057536.0,11089920.0,11122304.0,11154688.0,11187072.0,11219456.0,11251840.0,11284224.0,11078272.0,11110720.0,11143168.0,11175616.0,11208064.0,11240512.0,11272960.0,11305408.0,11099008.0,11131520.0,11164032.0,11196544.0,11229056.0,11261568.0,11294080.0,11326592.0,11119744.0,11152320.0,11184896.0,11217472.0,11250048.0,11282624.0,11315200.0,11347776.0,11140480.0,11173120.0,11205760.0,11238400.0,11271040.0,11303680.0,11336320.0,11368960.0,11161216.0,11193920.0,11226624.0,11259328.0,11292032.0,11324736.0,11357440.0,11390144.0,11181952.0,11214720.0,11247488.0,11280256.0,1
1313024.0,11345792.0,11378560.0,11411328.0,11202688.0,11235520.0,11268352.0,11301184.0,11334016.0,11366848.0,11399680.0,11432512.0,11223424.0,11256320.0,11289216.0,11322112.0,11355008.0,11387904.0,11420800.0,11453696.0,11244160.0,11277120.0,11310080.0,11343040.0,11376000.0,11408960.0,11441920.0,11474880.0,11264896.0,11297920.0,11330944.0,11363968.0,11396992.0,11430016.0,11463040.0,11496064.0,11285632.0,11318720.0,11351808.0,11384896.0,11417984.0,11451072.0,11484160.0,11517248.0,11306368.0,11339520.0,11372672.0,11405824.0,11438976.0,11472128.0,11505280.0,11538432.0,11327104.0,11360320.0,11393536.0,11426752.0,11459968.0,11493184.0,11526400.0,11559616.0,11347840.0,11381120.0,11414400.0,11447680.0,11480960.0,11514240.0,11547520.0,11580800.0,11368576.0,11401920.0,11435264.0,11468608.0,11501952.0,11535296.0,11568640.0,11601984.0,11389312.0,11422720.0,11456128.0,11489536.0,11522944.0,11556352.0,11589760.0,11623168.0,11410048.0,11443520.0,11476992.0,11510464.0,11543936.0,11577408.0,11610880.0,11644352.0,11430784.0,11464320.0,11497856.0,11531392.0,11564928.0,11598464.0,11632000.0,11665536.0,11451520.0,11485120.0,11518720.0,11552320.0,11585920.0,11619520.0,11653120.0,11686720.0,11472256.0,11505920.0,11539584.0,11573248.0,11606912.0,11640576.0,11674240.0,11707904.0,11492992.0,11526720.0,11560448.0,11594176.0,11627904.0,11661632.0,11695360.0,11729088.0,11513728.0,11547520.0,11581312.0,11615104.0,11648896.0,11682688.0,11716480.0,11750272.0,11534464.0,11568320.0,11602176.0,11636032.0,11669888.0,11703744.0,11737600.0,11771456.0,10965376.0,10998272.0,11031168.0,11064064.0,11096960.0,11129856.0,11162752.0,11195648.0,10388096.0,10420032.0,10451968.0,10483904.0,10515840.0,10547776.0,10579712.0,10611648.0,9819008.0,9849984.0,9880960.0,9911936.0,9942912.0,9973888.0,10004864.0,10035840.0,9241728.0,9271744.0,9301760.0,9331776.0,9361792.0,9391808.0,9421824.0,9451840.0,9262464.0,9292544.0,9322624.0,9352704.0,9382784.0,9412864.0,9442944.0,9473024.0,9283200.0,9313344.0,9343488.0,9373632.0,940
3776.0,9433920.0,9464064.0,9494208.0,9303936.0,9334144.0,9364352.0,9394560.0,9424768.0,9454976.0,9485184.0,9515392.0,9324672.0,9354944.0,9385216.0,9415488.0,9445760.0,9476032.0,9506304.0,9536576.0,9345408.0,9375744.0,9406080.0,9436416.0,9466752.0,9497088.0,9527424.0,9557760.0,9366144.0,9396544.0,9426944.0,9457344.0,9487744.0,9518144.0,9548544.0,9578944.0,9386880.0,9417344.0,9447808.0,9478272.0,9508736.0,9539200.0,9569664.0,9600128.0,9407616.0,9438144.0,9468672.0,9499200.0,9529728.0,9560256.0,9590784.0,9621312.0,9428352.0,9458944.0,9489536.0,9520128.0,9550720.0,9581312.0,9611904.0,9642496.0,9449088.0,9479744.0,9510400.0,9541056.0,9571712.0,9602368.0,9633024.0,9663680.0,9469824.0,9500544.0,9531264.0,9561984.0,9592704.0,9623424.0,9654144.0,9684864.0,9490560.0,9521344.0,9552128.0,9582912.0,9613696.0,9644480.0,9675264.0,9706048.0,9511296.0,9542144.0,9572992.0,9603840.0,9634688.0,9665536.0,9696384.0,9727232.0,9532032.0,9562944.0,9593856.0,9624768.0,9655680.0,9686592.0,9717504.0,9748416.0,9552768.0,9583744.0,9614720.0,9645696.0,9676672.0,9707648.0,9738624.0,9769600.0,9573504.0,9604544.0,9635584.0,9666624.0,9697664.0,9728704.0,9759744.0,9790784.0,9594240.0,9625344.0,9656448.0,9687552.0,9718656.0,9749760.0,9780864.0,9811968.0,9614976.0,9646144.0,9677312.0,9708480.0,9739648.0,9770816.0,9801984.0,9833152.0,9635712.0,9666944.0,9698176.0,9729408.0,9760640.0,9791872.0,9823104.0,9854336.0,9656448.0,9687744.0,9719040.0,9750336.0,9781632.0,9812928.0,9844224.0,9875520.0,9677184.0,9708544.0,9739904.0,9771264.0,9802624.0,9833984.0,9865344.0,9896704.0,9697920.0,9729344.0,9760768.0,9792192.0,9823616.0,9855040.0,9886464.0,9917888.0,9718656.0,9750144.0,9781632.0,9813120.0,9844608.0,9876096.0,9907584.0,9939072.0,9739392.0,9770944.0,9802496.0,9834048.0,9865600.0,9897152.0,9928704.0,9960256.0,9760128.0,9791744.0,9823360.0,9854976.0,9886592.0,9918208.0,9949824.0,9981440.0,9780864.0,9812544.0,9844224.0,9875904.0,9907584.0,9939264.0,9970944.0,10002624.0,9801600.0,9833344.0,9865088.0,9896832.0,99
28576.0,9960320.0,9992064.0,10023808.0,9822336.0,9854144.0,9885952.0,9917760.0,9949568.0,9981376.0,10013184.0,10044992.0,9843072.0,9874944.0,9906816.0,9938688.0,9970560.0,10002432.0,10034304.0,10066176.0,9863808.0,9895744.0,9927680.0,9959616.0,9991552.0,10023488.0,10055424.0,10087360.0,9884544.0,9916544.0,9948544.0,9980544.0,10012544.0,10044544.0,10076544.0,10108544.0,9905280.0,9937344.0,9969408.0,10001472.0,10033536.0,10065600.0,10097664.0,10129728.0,9926016.0,9958144.0,9990272.0,10022400.0,10054528.0,10086656.0,10118784.0,10150912.0,9946752.0,9978944.0,10011136.0,10043328.0,10075520.0,10107712.0,10139904.0,10172096.0,9967488.0,9999744.0,10032000.0,10064256.0,10096512.0,10128768.0,10161024.0,10193280.0,9988224.0,10020544.0,10052864.0,10085184.0,10117504.0,10149824.0,10182144.0,10214464.0,10008960.0,10041344.0,10073728.0,10106112.0,10138496.0,10170880.0,10203264.0,10235648.0,9964160.0,9995584.0,10027008.0,10058432.0,10089856.0,10121280.0,10152704.0,10184128.0,9911168.0,9941632.0,9972096.0,10002560.0,10033024.0,10063488.0,10093952.0,10124416.0,9931904.0,9962432.0,9992960.0,10023488.0,10054016.0,10084544.0,10115072.0,10145600.0,9952640.0,9983232.0,10013824.0,10044416.0,10075008.0,10105600.0,10136192.0,10166784.0,9973376.0,10004032.0,10034688.0,10065344.0,10096000.0,10126656.0,10157312.0,10187968.0,9994112.0,10024832.0,10055552.0,10086272.0,10116992.0,10147712.0,10178432.0,10209152.0,10014848.0,10045632.0,10076416.0,10107200.0,10137984.0,10168768.0,10199552.0,10230336.0,10035584.0,10066432.0,10097280.0,10128128.0,10158976.0,10189824.0,10220672.0,10251520.0,10056320.0,10087232.0,10118144.0,10149056.0,10179968.0,10210880.0,10241792.0,10272704.0,10077056.0,10108032.0,10139008.0,10169984.0,10200960.0,10231936.0,10262912.0,10293888.0,10097792.0,10128832.0,10159872.0,10190912.0,10221952.0,10252992.0,10284032.0,10315072.0,10118528.0,10149632.0,10180736.0,10211840.0,10242944.0,10274048.0,10305152.0,10336256.0,10139264.0,10170432.0,10201600.0,10232768.0,10263936.0,10295104.0,10
326272.0,10357440.0,10160000.0,10191232.0,10222464.0,10253696.0,10284928.0,10316160.0,10347392.0,10378624.0,10180736.0,10212032.0,10243328.0,10274624.0,10305920.0,10337216.0,10368512.0,10399808.0,10201472.0,10232832.0,10264192.0,10295552.0,10326912.0,10358272.0,10389632.0,10420992.0,10222208.0,10253632.0,10285056.0,10316480.0,10347904.0,10379328.0,10410752.0,10442176.0,10242944.0,10274432.0,10305920.0,10337408.0,10368896.0,10400384.0,10431872.0,10463360.0,10263680.0,10295232.0,10326784.0,10358336.0,10389888.0,10421440.0,10452992.0,10484544.0,10284416.0,10316032.0,10347648.0,10379264.0,10410880.0,10442496.0,10474112.0,10505728.0,10305152.0,10336832.0,10368512.0,10400192.0,10431872.0,10463552.0,10495232.0,10526912.0,10325888.0,10357632.0,10389376.0,10421120.0,10452864.0,10484608.0,10516352.0,10548096.0,10346624.0,10378432.0,10410240.0,10442048.0,10473856.0,10505664.0,10537472.0,10569280.0,10367360.0,10399232.0,10431104.0,10462976.0,10494848.0,10526720.0,10558592.0,10590464.0,10388096.0,10420032.0,10451968.0,10483904.0,10515840.0,10547776.0,10579712.0,10611648.0,10408832.0,10440832.0,10472832.0,10504832.0,10536832.0,10568832.0,10600832.0,10632832.0,10429568.0,10461632.0,10493696.0,10525760.0,10557824.0,10589888.0,10621952.0,10654016.0,10450304.0,10482432.0,10514560.0,10546688.0,10578816.0,10610944.0,10643072.0,10675200.0,10471040.0,10503232.0,10535424.0,10567616.0,10599808.0,10632000.0,10664192.0,10696384.0,10491776.0,10524032.0,10556288.0,10588544.0,10620800.0,10653056.0,10685312.0,10717568.0,10512512.0,10544832.0,10577152.0,10609472.0,10641792.0,10674112.0,10706432.0,10738752.0,10533248.0,10565632.0,10598016.0,10630400.0,10662784.0,10695168.0,10727552.0,10759936.0,10553984.0,10586432.0,10618880.0,10651328.0,10683776.0,10716224.0,10748672.0,10781120.0,10574720.0,10607232.0,10639744.0,10672256.0,10704768.0,10737280.0,10769792.0,10802304.0,10595456.0,10628032.0,10660608.0,10693184.0,10725760.0,10758336.0,10790912.0,10823488.0,10616192.0,10648832.0,10681472.0,10714112.0,
10746752.0,10779392.0,10812032.0,10844672.0,10636928.0,10669632.0,10702336.0,10735040.0,10767744.0,10800448.0,10833152.0,10865856.0,10657664.0,10690432.0,10723200.0,10755968.0,10788736.0,10821504.0,10854272.0,10887040.0,10678400.0,10711232.0,10744064.0,10776896.0,10809728.0,10842560.0,10875392.0,10908224.0,10109312.0,10141184.0,10173056.0,10204928.0,10236800.0,10268672.0,10300544.0,10332416.0,9532032.0,9562944.0,9593856.0,9624768.0,9655680.0,9686592.0,9717504.0,9748416.0,9552768.0,9583744.0,9614720.0,9645696.0,9676672.0,9707648.0,9738624.0,9769600.0,9573504.0,9604544.0,9635584.0,9666624.0,9697664.0,9728704.0,9759744.0,9790784.0,9594240.0,9625344.0,9656448.0,9687552.0,9718656.0,9749760.0,9780864.0,9811968.0,9614976.0,9646144.0,9677312.0,9708480.0,9739648.0,9770816.0,9801984.0,9833152.0,9635712.0,9666944.0,9698176.0,9729408.0,9760640.0,9791872.0,9823104.0,9854336.0,9656448.0,9687744.0,9719040.0,9750336.0,9781632.0,9812928.0,9844224.0,9875520.0,9677184.0,9708544.0,9739904.0,9771264.0,9802624.0,9833984.0,9865344.0,9896704.0,9697920.0,9729344.0,9760768.0,9792192.0,9823616.0,9855040.0,9886464.0,9917888.0,9718656.0,9750144.0,9781632.0,9813120.0,9844608.0,9876096.0,9907584.0,9939072.0,9739392.0,9770944.0,9802496.0,9834048.0,9865600.0,9897152.0,9928704.0,9960256.0,9760128.0,9791744.0,9823360.0,9854976.0,9886592.0,9918208.0,9949824.0,9981440.0,9780864.0,9812544.0,9844224.0,9875904.0,9907584.0,9939264.0,9970944.0,10002624.0,9801600.0,9833344.0,9865088.0,9896832.0,9928576.0,9960320.0,9992064.0,10023808.0,9822336.0,9854144.0,9885952.0,9917760.0,9949568.0,9981376.0,10013184.0,10044992.0,9843072.0,9874944.0,9906816.0,9938688.0,9970560.0,10002432.0,10034304.0,10066176.0,9863808.0,9895744.0,9927680.0,9959616.0,9991552.0,10023488.0,10055424.0,10087360.0,9884544.0,9916544.0,9948544.0,9980544.0,10012544.0,10044544.0,10076544.0,10108544.0,9905280.0,9937344.0,9969408.0,10001472.0,10033536.0,10065600.0,10097664.0,10129728.0,9926016.0,9958144.0,9990272.0,10022400.0,10054528.0,10086656.0,10
118784.0,10150912.0,9946752.0,9978944.0,10011136.0,10043328.0,10075520.0,10107712.0,10139904.0,10172096.0,9967488.0,9999744.0,10032000.0,10064256.0,10096512.0,10128768.0,10161024.0,10193280.0,9988224.0,10020544.0,10052864.0,10085184.0,10117504.0,10149824.0,10182144.0,10214464.0,10008960.0,10041344.0,10073728.0,10106112.0,10138496.0,10170880.0,10203264.0,10235648.0,10029696.0,10062144.0,10094592.0,10127040.0,10159488.0,10191936.0,10224384.0,10256832.0,10050432.0,10082944.0,10115456.0,10147968.0,10180480.0,10212992.0,10245504.0,10278016.0,10071168.0,10103744.0,10136320.0,10168896.0,10201472.0,10234048.0,10266624.0,10299200.0,10091904.0,10124544.0,10157184.0,10189824.0,10222464.0,10255104.0,10287744.0,10320384.0,10112640.0,10145344.0,10178048.0,10210752.0,10243456.0,10276160.0,10308864.0,10341568.0,10133376.0,10166144.0,10198912.0,10231680.0,10264448.0,10297216.0,10329984.0,10362752.0,10154112.0,10186944.0,10219776.0,10252608.0,10285440.0,10318272.0,10351104.0,10383936.0,10174848.0,10207744.0,10240640.0,10273536.0,10306432.0,10339328.0,10372224.0,10405120.0,10195584.0,10228544.0,10261504.0,10294464.0,10327424.0,10360384.0,10393344.0,10426304.0,10216320.0,10249344.0,10282368.0,10315392.0,10348416.0,10381440.0,10414464.0,10447488.0,10237056.0,10270144.0,10303232.0,10336320.0,10369408.0,10402496.0,10435584.0,10468672.0,10257792.0,10290944.0,10324096.0,10357248.0,10390400.0,10423552.0,10456704.0,10489856.0,10278528.0,10311744.0,10344960.0,10378176.0,10411392.0,10444608.0,10477824.0,10511040.0,10299264.0,10332544.0,10365824.0,10399104.0,10432384.0,10465664.0,10498944.0,10532224.0,10254464.0,10286784.0,10319104.0,10351424.0,10383744.0,10416064.0,10448384.0,10480704.0,10201472.0,10232832.0,10264192.0,10295552.0,10326912.0,10358272.0,10389632.0,10420992.0,10222208.0,10253632.0,10285056.0,10316480.0,10347904.0,10379328.0,10410752.0,10442176.0,10242944.0,10274432.0,10305920.0,10337408.0,10368896.0,10400384.0,10431872.0,10463360.0,10263680.0,10295232.0,10326784.0,10358336.0,10389
888.0,10421440.0,10452992.0,10484544.0,10284416.0,10316032.0,10347648.0,10379264.0,10410880.0,10442496.0,10474112.0,10505728.0,10305152.0,10336832.0,10368512.0,10400192.0,10431872.0,10463552.0,10495232.0,10526912.0,10325888.0,10357632.0,10389376.0,10421120.0,10452864.0,10484608.0,10516352.0,10548096.0,10346624.0,10378432.0,10410240.0,10442048.0,10473856.0,10505664.0,10537472.0,10569280.0,10367360.0,10399232.0,10431104.0,10462976.0,10494848.0,10526720.0,10558592.0,10590464.0,10388096.0,10420032.0,10451968.0,10483904.0,10515840.0,10547776.0,10579712.0,10611648.0,10408832.0,10440832.0,10472832.0,10504832.0,10536832.0,10568832.0,10600832.0,10632832.0,10429568.0,10461632.0,10493696.0,10525760.0,10557824.0,10589888.0,10621952.0,10654016.0,10450304.0,10482432.0,10514560.0,10546688.0,10578816.0,10610944.0,10643072.0,10675200.0,10471040.0,10503232.0,10535424.0,10567616.0,10599808.0,10632000.0,10664192.0,10696384.0,10491776.0,10524032.0,10556288.0,10588544.0,10620800.0,10653056.0,10685312.0,10717568.0,10512512.0,10544832.0,10577152.0,10609472.0,10641792.0,10674112.0,10706432.0,10738752.0,10533248.0,10565632.0,10598016.0,10630400.0,10662784.0,10695168.0,10727552.0,10759936.0,10553984.0,10586432.0,10618880.0,10651328.0,10683776.0,10716224.0,10748672.0,10781120.0,10574720.0,10607232.0,10639744.0,10672256.0,10704768.0,10737280.0,10769792.0,10802304.0,10595456.0,10628032.0,10660608.0,10693184.0,10725760.0,10758336.0,10790912.0,10823488.0,10616192.0,10648832.0,10681472.0,10714112.0,10746752.0,10779392.0,10812032.0,10844672.0,10636928.0,10669632.0,10702336.0,10735040.0,10767744.0,10800448.0,10833152.0,10865856.0,10657664.0,10690432.0,10723200.0,10755968.0,10788736.0,10821504.0,10854272.0,10887040.0,10678400.0,10711232.0,10744064.0,10776896.0,10809728.0,10842560.0,10875392.0,10908224.0,10699136.0,10732032.0,10764928.0,10797824.0,10830720.0,10863616.0,10896512.0,10929408.0,10719872.0,10752832.0,10785792.0,10818752.0,10851712.0,10884672.0,10917632.0,10950592.0,10740608.0,10773632.0,108
06656.0,10839680.0,10872704.0,10905728.0,10938752.0,10971776.0,10761344.0,10794432.0,10827520.0,10860608.0,10893696.0,10926784.0,10959872.0,10992960.0,10782080.0,10815232.0,10848384.0,10881536.0,10914688.0,10947840.0,10980992.0,11014144.0,10802816.0,10836032.0,10869248.0,10902464.0,10935680.0,10968896.0,11002112.0,11035328.0,10823552.0,10856832.0,10890112.0,10923392.0,10956672.0,10989952.0,11023232.0,11056512.0,10844288.0,10877632.0,10910976.0,10944320.0,10977664.0,11011008.0,11044352.0,11077696.0,10865024.0,10898432.0,10931840.0,10965248.0,10998656.0,11032064.0,11065472.0,11098880.0,10885760.0,10919232.0,10952704.0,10986176.0,11019648.0,11053120.0,11086592.0,11120064.0,10906496.0,10940032.0,10973568.0,11007104.0,11040640.0,11074176.0,11107712.0,11141248.0,10927232.0,10960832.0,10994432.0,11028032.0,11061632.0,11095232.0,11128832.0,11162432.0,10947968.0,10981632.0,11015296.0,11048960.0,11082624.0,11116288.0,11149952.0,11183616.0,10968704.0,11002432.0,11036160.0,11069888.0,11103616.0,11137344.0,11171072.0,11204800.0,10399616.0,10432384.0,10465152.0,10497920.0,10530688.0,10563456.0,10596224.0,10628992.0,9822336.0,9854144.0,9885952.0,9917760.0,9949568.0,9981376.0,10013184.0,10044992.0,4980096.0,4996016.0,5011936.0,5027856.0,5043776.0,5059696.0,5075616.0,5091536.0,5588736.0,5604528.0,5620320.0,5636112.0,5651904.0,5667696.0,5683488.0,5699280.0,11314560.0,11346176.0,11377792.0,11409408.0,11441024.0,11472640.0,11504256.0,11535872.0,11336320.0,11368000.0,11399680.0,11431360.0,11463040.0,11494720.0,11526400.0,11558080.0,11358080.0,11389824.0,11421568.0,11453312.0,11485056.0,11516800.0,11548544.0,11580288.0,11379840.0,11411648.0,11443456.0,11475264.0,11507072.0,11538880.0,11570688.0,11602496.0,11401600.0,11433472.0,11465344.0,11497216.0,11529088.0,11560960.0,11592832.0,11624704.0,11423360.0,11455296.0,11487232.0,11519168.0,11551104.0,11583040.0,11614976.0,11646912.0,11445120.0,11477120.0,11509120.0,11541120.0,11573120.0,11605120.0,11637120.0,11669120.0,11466880.0,11498944.0,1
1531008.0,11563072.0,11595136.0,11627200.0,11659264.0,11691328.0,11488640.0,11520768.0,11552896.0,11585024.0,11617152.0,11649280.0,11681408.0,11713536.0,11510400.0,11542592.0,11574784.0,11606976.0,11639168.0,11671360.0,11703552.0,11735744.0,11532160.0,11564416.0,11596672.0,11628928.0,11661184.0,11693440.0,11725696.0,11757952.0,11553920.0,11586240.0,11618560.0,11650880.0,11683200.0,11715520.0,11747840.0,11780160.0,11575680.0,11608064.0,11640448.0,11672832.0,11705216.0,11737600.0,11769984.0,11802368.0,11597440.0,11629888.0,11662336.0,11694784.0,11727232.0,11759680.0,11792128.0,11824576.0,11619200.0,11651712.0,11684224.0,11716736.0,11749248.0,11781760.0,11814272.0,11846784.0,11640960.0,11673536.0,11706112.0,11738688.0,11771264.0,11803840.0,11836416.0,11868992.0,11662720.0,11695360.0,11728000.0,11760640.0,11793280.0,11825920.0,11858560.0,11891200.0,11684480.0,11717184.0,11749888.0,11782592.0,11815296.0,11848000.0,11880704.0,11913408.0,11706240.0,11739008.0,11771776.0,11804544.0,11837312.0,11870080.0,11902848.0,11935616.0,11728000.0,11760832.0,11793664.0,11826496.0,11859328.0,11892160.0,11924992.0,11957824.0,11749760.0,11782656.0,11815552.0,11848448.0,11881344.0,11914240.0,11947136.0,11980032.0,11771520.0,11804480.0,11837440.0,11870400.0,11903360.0,11936320.0,11969280.0,12002240.0,11793280.0,11826304.0,11859328.0,11892352.0,11925376.0,11958400.0,11991424.0,12024448.0,11815040.0,11848128.0,11881216.0,11914304.0,11947392.0,11980480.0,12013568.0,12046656.0,11836800.0,11869952.0,11903104.0,11936256.0,11969408.0,12002560.0,12035712.0,12068864.0,11858560.0,11891776.0,11924992.0,11958208.0,11991424.0,12024640.0,12057856.0,12091072.0,11880320.0,11913600.0,11946880.0,11980160.0,12013440.0,12046720.0,12080000.0,12113280.0,11902080.0,11935424.0,11968768.0,12002112.0,12035456.0,12068800.0,12102144.0,12135488.0,11923840.0,11957248.0,11990656.0,12024064.0,12057472.0,12090880.0,12124288.0,12157696.0,11945600.0,11979072.0,12012544.0,12046016.0,12079488.0,12112960.0,12146432.0,12179904.0
,11967360.0,12000896.0,12034432.0,12067968.0,12101504.0,12135040.0,12168576.0,12202112.0,11989120.0,12022720.0,12056320.0,12089920.0,12123520.0,12157120.0,12190720.0,12224320.0,12010880.0,12044544.0,12078208.0,12111872.0,12145536.0,12179200.0,12212864.0,12246528.0,12032640.0,12066368.0,12100096.0,12133824.0,12167552.0,12201280.0,12235008.0,12268736.0,12054400.0,12088192.0,12121984.0,12155776.0,12189568.0,12223360.0,12257152.0,12290944.0,12076160.0,12110016.0,12143872.0,12177728.0,12211584.0,12245440.0,12279296.0,12313152.0,11491712.0,11524608.0,11557504.0,11590400.0,11623296.0,11656192.0,11689088.0,11721984.0,10899072.0,10931008.0,10962944.0,10994880.0,11026816.0,11058752.0,11090688.0,11122624.0,10314624.0,10345600.0,10376576.0,10407552.0,10438528.0,10469504.0,10500480.0,10531456.0,9721984.0,9752000.0,9782016.0,9812032.0,9842048.0,9872064.0,9902080.0,9932096.0,9743744.0,9773824.0,9803904.0,9833984.0,9864064.0,9894144.0,9924224.0,9954304.0,9765504.0,9795648.0,9825792.0,9855936.0,9886080.0,9916224.0,9946368.0,9976512.0,9787264.0,9817472.0,9847680.0,9877888.0,9908096.0,9938304.0,9968512.0,9998720.0,9809024.0,9839296.0,9869568.0,9899840.0,9930112.0,9960384.0,9990656.0,10020928.0,9830784.0,9861120.0,9891456.0,9921792.0,9952128.0,9982464.0,10012800.0,10043136.0,9852544.0,9882944.0,9913344.0,9943744.0,9974144.0,10004544.0,10034944.0,10065344.0,9874304.0,9904768.0,9935232.0,9965696.0,9996160.0,10026624.0,10057088.0,10087552.0,9896064.0,9926592.0,9957120.0,9987648.0,10018176.0,10048704.0,10079232.0,10109760.0,9917824.0,9948416.0,9979008.0,10009600.0,10040192.0,10070784.0,10101376.0,10131968.0,9939584.0,9970240.0,10000896.0,10031552.0,10062208.0,10092864.0,10123520.0,10154176.0,9961344.0,9992064.0,10022784.0,10053504.0,10084224.0,10114944.0,10145664.0,10176384.0,9983104.0,10013888.0,10044672.0,10075456.0,10106240.0,10137024.0,10167808.0,10198592.0,10004864.0,10035712.0,10066560.0,10097408.0,10128256.0,10159104.0,10189952.0,10220800.0,10026624.0,10057536.0,10088448.0,10119360.
0,10150272.0,10181184.0,10212096.0,10243008.0,10048384.0,10079360.0,10110336.0,10141312.0,10172288.0,10203264.0,10234240.0,10265216.0,10070144.0,10101184.0,10132224.0,10163264.0,10194304.0,10225344.0,10256384.0,10287424.0,10091904.0,10123008.0,10154112.0,10185216.0,10216320.0,10247424.0,10278528.0,10309632.0,10113664.0,10144832.0,10176000.0,10207168.0,10238336.0,10269504.0,10300672.0,10331840.0,10135424.0,10166656.0,10197888.0,10229120.0,10260352.0,10291584.0,10322816.0,10354048.0,10157184.0,10188480.0,10219776.0,10251072.0,10282368.0,10313664.0,10344960.0,10376256.0,10178944.0,10210304.0,10241664.0,10273024.0,10304384.0,10335744.0,10367104.0,10398464.0,10200704.0,10232128.0,10263552.0,10294976.0,10326400.0,10357824.0,10389248.0,10420672.0,10222464.0,10253952.0,10285440.0,10316928.0,10348416.0,10379904.0,10411392.0,10442880.0,10244224.0,10275776.0,10307328.0,10338880.0,10370432.0,10401984.0,10433536.0,10465088.0,10265984.0,10297600.0,10329216.0,10360832.0,10392448.0,10424064.0,10455680.0,10487296.0,10287744.0,10319424.0,10351104.0,10382784.0,10414464.0,10446144.0,10477824.0,10509504.0,10309504.0,10341248.0,10372992.0,10404736.0,10436480.0,10468224.0,10499968.0,10531712.0,10331264.0,10363072.0,10394880.0,10426688.0,10458496.0,10490304.0,10522112.0,10553920.0,10353024.0,10384896.0,10416768.0,10448640.0,10480512.0,10512384.0,10544256.0,10576128.0,10374784.0,10406720.0,10438656.0,10470592.0,10502528.0,10534464.0,10566400.0,10598336.0,10396544.0,10428544.0,10460544.0,10492544.0,10524544.0,10556544.0,10588544.0,10620544.0,10418304.0,10450368.0,10482432.0,10514496.0,10546560.0,10578624.0,10610688.0,10642752.0,10440064.0,10472192.0,10504320.0,10536448.0,10568576.0,10600704.0,10632832.0,10664960.0,10461824.0,10494016.0,10526208.0,10558400.0,10590592.0,10622784.0,10654976.0,10687168.0,10483584.0,10515840.0,10548096.0,10580352.0,10612608.0,10644864.0,10677120.0,10709376.0,10505344.0,10537664.0,10569984.0,10602304.0,10634624.0,10666944.0,10699264.0,10731584.0,10527104.0,1055948
8.0,10591872.0,10624256.0,10656640.0,10689024.0,10721408.0,10753792.0,10466944.0,10498368.0,10529792.0,10561216.0,10592640.0,10624064.0,10655488.0,10686912.0,10398592.0,10429056.0,10459520.0,10489984.0,10520448.0,10550912.0,10581376.0,10611840.0,10420352.0,10450880.0,10481408.0,10511936.0,10542464.0,10572992.0,10603520.0,10634048.0,10442112.0,10472704.0,10503296.0,10533888.0,10564480.0,10595072.0,10625664.0,10656256.0,10463872.0,10494528.0,10525184.0,10555840.0,10586496.0,10617152.0,10647808.0,10678464.0,10485632.0,10516352.0,10547072.0,10577792.0,10608512.0,10639232.0,10669952.0,10700672.0,10507392.0,10538176.0,10568960.0,10599744.0,10630528.0,10661312.0,10692096.0,10722880.0,10529152.0,10560000.0,10590848.0,10621696.0,10652544.0,10683392.0,10714240.0,10745088.0,10550912.0,10581824.0,10612736.0,10643648.0,10674560.0,10705472.0,10736384.0,10767296.0,10572672.0,10603648.0,10634624.0,10665600.0,10696576.0,10727552.0,10758528.0,10789504.0,10594432.0,10625472.0,10656512.0,10687552.0,10718592.0,10749632.0,10780672.0,10811712.0,10616192.0,10647296.0,10678400.0,10709504.0,10740608.0,10771712.0,10802816.0,10833920.0,10637952.0,10669120.0,10700288.0,10731456.0,10762624.0,10793792.0,10824960.0,10856128.0,10659712.0,10690944.0,10722176.0,10753408.0,10784640.0,10815872.0,10847104.0,10878336.0,10681472.0,10712768.0,10744064.0,10775360.0,10806656.0,10837952.0,10869248.0,10900544.0,10703232.0,10734592.0,10765952.0,10797312.0,10828672.0,10860032.0,10891392.0,10922752.0,10724992.0,10756416.0,10787840.0,10819264.0,10850688.0,10882112.0,10913536.0,10944960.0,10746752.0,10778240.0,10809728.0,10841216.0,10872704.0,10904192.0,10935680.0,10967168.0,10768512.0,10800064.0,10831616.0,10863168.0,10894720.0,10926272.0,10957824.0,10989376.0,10790272.0,10821888.0,10853504.0,10885120.0,10916736.0,10948352.0,10979968.0,11011584.0,10812032.0,10843712.0,10875392.0,10907072.0,10938752.0,10970432.0,11002112.0,11033792.0,10833792.0,10865536.0,10897280.0,10929024.0,10960768.0,10992512.0,11024256.0,11056
000.0,10855552.0,10887360.0,10919168.0,10950976.0,10982784.0,11014592.0,11046400.0,11078208.0,10877312.0,10909184.0,10941056.0,10972928.0,11004800.0,11036672.0,11068544.0,11100416.0,10899072.0,10931008.0,10962944.0,10994880.0,11026816.0,11058752.0,11090688.0,11122624.0,10920832.0,10952832.0,10984832.0,11016832.0,11048832.0,11080832.0,11112832.0,11144832.0,10942592.0,10974656.0,11006720.0,11038784.0,11070848.0,11102912.0,11134976.0,11167040.0,10964352.0,10996480.0,11028608.0,11060736.0,11092864.0,11124992.0,11157120.0,11189248.0,10986112.0,11018304.0,11050496.0,11082688.0,11114880.0,11147072.0,11179264.0,11211456.0,11007872.0,11040128.0,11072384.0,11104640.0,11136896.0,11169152.0,11201408.0,11233664.0,11029632.0,11061952.0,11094272.0,11126592.0,11158912.0,11191232.0,11223552.0,11255872.0,11051392.0,11083776.0,11116160.0,11148544.0,11180928.0,11213312.0,11245696.0,11278080.0,11073152.0,11105600.0,11138048.0,11170496.0,11202944.0,11235392.0,11267840.0,11300288.0,11094912.0,11127424.0,11159936.0,11192448.0,11224960.0,11257472.0,11289984.0,11322496.0,11116672.0,11149248.0,11181824.0,11214400.0,11246976.0,11279552.0,11312128.0,11344704.0,11138432.0,11171072.0,11203712.0,11236352.0,11268992.0,11301632.0,11334272.0,11366912.0,11160192.0,11192896.0,11225600.0,11258304.0,11291008.0,11323712.0,11356416.0,11389120.0,11181952.0,11214720.0,11247488.0,11280256.0,11313024.0,11345792.0,11378560.0,11411328.0,11203712.0,11236544.0,11269376.0,11302208.0,11335040.0,11367872.0,11400704.0,11433536.0,10619264.0,10651136.0,10683008.0,10714880.0,10746752.0,10778624.0,10810496.0,10842368.0,10026624.0,10057536.0,10088448.0,10119360.0,10150272.0,10181184.0,10212096.0,10243008.0,10048384.0,10079360.0,10110336.0,10141312.0,10172288.0,10203264.0,10234240.0,10265216.0,10070144.0,10101184.0,10132224.0,10163264.0,10194304.0,10225344.0,10256384.0,10287424.0,10091904.0,10123008.0,10154112.0,10185216.0,10216320.0,10247424.0,10278528.0,10309632.0,10113664.0,10144832.0,10176000.0,10207168.0,10238336.0,102
69504.0,10300672.0,10331840.0,10135424.0,10166656.0,10197888.0,10229120.0,10260352.0,10291584.0,10322816.0,10354048.0,10157184.0,10188480.0,10219776.0,10251072.0,10282368.0,10313664.0,10344960.0,10376256.0,10178944.0,10210304.0,10241664.0,10273024.0,10304384.0,10335744.0,10367104.0,10398464.0,10200704.0,10232128.0,10263552.0,10294976.0,10326400.0,10357824.0,10389248.0,10420672.0,10222464.0,10253952.0,10285440.0,10316928.0,10348416.0,10379904.0,10411392.0,10442880.0,10244224.0,10275776.0,10307328.0,10338880.0,10370432.0,10401984.0,10433536.0,10465088.0,10265984.0,10297600.0,10329216.0,10360832.0,10392448.0,10424064.0,10455680.0,10487296.0,10287744.0,10319424.0,10351104.0,10382784.0,10414464.0,10446144.0,10477824.0,10509504.0,10309504.0,10341248.0,10372992.0,10404736.0,10436480.0,10468224.0,10499968.0,10531712.0,10331264.0,10363072.0,10394880.0,10426688.0,10458496.0,10490304.0,10522112.0,10553920.0,10353024.0,10384896.0,10416768.0,10448640.0,10480512.0,10512384.0,10544256.0,10576128.0,10374784.0,10406720.0,10438656.0,10470592.0,10502528.0,10534464.0,10566400.0,10598336.0,10396544.0,10428544.0,10460544.0,10492544.0,10524544.0,10556544.0,10588544.0,10620544.0,10418304.0,10450368.0,10482432.0,10514496.0,10546560.0,10578624.0,10610688.0,10642752.0,10440064.0,10472192.0,10504320.0,10536448.0,10568576.0,10600704.0,10632832.0,10664960.0,10461824.0,10494016.0,10526208.0,10558400.0,10590592.0,10622784.0,10654976.0,10687168.0,10483584.0,10515840.0,10548096.0,10580352.0,10612608.0,10644864.0,10677120.0,10709376.0,10505344.0,10537664.0,10569984.0,10602304.0,10634624.0,10666944.0,10699264.0,10731584.0,10527104.0,10559488.0,10591872.0,10624256.0,10656640.0,10689024.0,10721408.0,10753792.0,10548864.0,10581312.0,10613760.0,10646208.0,10678656.0,10711104.0,10743552.0,10776000.0,10570624.0,10603136.0,10635648.0,10668160.0,10700672.0,10733184.0,10765696.0,10798208.0,10592384.0,10624960.0,10657536.0,10690112.0,10722688.0,10755264.0,10787840.0,10820416.0,10614144.0,10646784.0,10679424.0,1
0712064.0,10744704.0,10777344.0,10809984.0,10842624.0,10635904.0,10668608.0,10701312.0,10734016.0,10766720.0,10799424.0,10832128.0,10864832.0,10657664.0,10690432.0,10723200.0,10755968.0,10788736.0,10821504.0,10854272.0,10887040.0,10679424.0,10712256.0,10745088.0,10777920.0,10810752.0,10843584.0,10876416.0,10909248.0,10701184.0,10734080.0,10766976.0,10799872.0,10832768.0,10865664.0,10898560.0,10931456.0,10722944.0,10755904.0,10788864.0,10821824.0,10854784.0,10887744.0,10920704.0,10953664.0,10744704.0,10777728.0,10810752.0,10843776.0,10876800.0,10909824.0,10942848.0,10975872.0,10766464.0,10799552.0,10832640.0,10865728.0,10898816.0,10931904.0,10964992.0,10998080.0,10788224.0,10821376.0,10854528.0,10887680.0,10920832.0,10953984.0,10987136.0,11020288.0,10809984.0,10843200.0,10876416.0,10909632.0,10942848.0,10976064.0,11009280.0,11042496.0,10831744.0,10865024.0,10898304.0,10931584.0,10964864.0,10998144.0,11031424.0,11064704.0,10771584.0,10803904.0,10836224.0,10868544.0,10900864.0,10933184.0,10965504.0,10997824.0,10703232.0,10734592.0,10765952.0,10797312.0,10828672.0,10860032.0,10891392.0,10922752.0,10724992.0,10756416.0,10787840.0,10819264.0,10850688.0,10882112.0,10913536.0,10944960.0,10746752.0,10778240.0,10809728.0,10841216.0,10872704.0,10904192.0,10935680.0,10967168.0,10768512.0,10800064.0,10831616.0,10863168.0,10894720.0,10926272.0,10957824.0,10989376.0,10790272.0,10821888.0,10853504.0,10885120.0,10916736.0,10948352.0,10979968.0,11011584.0,10812032.0,10843712.0,10875392.0,10907072.0,10938752.0,10970432.0,11002112.0,11033792.0,10833792.0,10865536.0,10897280.0,10929024.0,10960768.0,10992512.0,11024256.0,11056000.0,10855552.0,10887360.0,10919168.0,10950976.0,10982784.0,11014592.0,11046400.0,11078208.0,10877312.0,10909184.0,10941056.0,10972928.0,11004800.0,11036672.0,11068544.0,11100416.0,10899072.0,10931008.0,10962944.0,10994880.0,11026816.0,11058752.0,11090688.0,11122624.0,10920832.0,10952832.0,10984832.0,11016832.0,11048832.0,11080832.0,11112832.0,11144832.0,10942592.0
,10974656.0,11006720.0,11038784.0,11070848.0,11102912.0,11134976.0,11167040.0,10964352.0,10996480.0,11028608.0,11060736.0,11092864.0,11124992.0,11157120.0,11189248.0,10986112.0,11018304.0,11050496.0,11082688.0,11114880.0,11147072.0,11179264.0,11211456.0,11007872.0,11040128.0,11072384.0,11104640.0,11136896.0,11169152.0,11201408.0,11233664.0,11029632.0,11061952.0,11094272.0,11126592.0,11158912.0,11191232.0,11223552.0,11255872.0,11051392.0,11083776.0,11116160.0,11148544.0,11180928.0,11213312.0,11245696.0,11278080.0,11073152.0,11105600.0,11138048.0,11170496.0,11202944.0,11235392.0,11267840.0,11300288.0,11094912.0,11127424.0,11159936.0,11192448.0,11224960.0,11257472.0,11289984.0,11322496.0,11116672.0,11149248.0,11181824.0,11214400.0,11246976.0,11279552.0,11312128.0,11344704.0,11138432.0,11171072.0,11203712.0,11236352.0,11268992.0,11301632.0,11334272.0,11366912.0,11160192.0,11192896.0,11225600.0,11258304.0,11291008.0,11323712.0,11356416.0,11389120.0,11181952.0,11214720.0,11247488.0,11280256.0,11313024.0,11345792.0,11378560.0,11411328.0,11203712.0,11236544.0,11269376.0,11302208.0,11335040.0,11367872.0,11400704.0,11433536.0,11225472.0,11258368.0,11291264.0,11324160.0,11357056.0,11389952.0,11422848.0,11455744.0,11247232.0,11280192.0,11313152.0,11346112.0,11379072.0,11412032.0,11444992.0,11477952.0,11268992.0,11302016.0,11335040.0,11368064.0,11401088.0,11434112.0,11467136.0,11500160.0,11290752.0,11323840.0,11356928.0,11390016.0,11423104.0,11456192.0,11489280.0,11522368.0,11312512.0,11345664.0,11378816.0,11411968.0,11445120.0,11478272.0,11511424.0,11544576.0,11334272.0,11367488.0,11400704.0,11433920.0,11467136.0,11500352.0,11533568.0,11566784.0,11356032.0,11389312.0,11422592.0,11455872.0,11489152.0,11522432.0,11555712.0,11588992.0,11377792.0,11411136.0,11444480.0,11477824.0,11511168.0,11544512.0,11577856.0,11611200.0,11399552.0,11432960.0,11466368.0,11499776.0,11533184.0,11566592.0,11600000.0,11633408.0,11421312.0,11454784.0,11488256.0,11521728.0,11555200.0,11588672.0,11622144
.0,11655616.0,11443072.0,11476608.0,11510144.0,11543680.0,11577216.0,11610752.0,11644288.0,11677824.0,11464832.0,11498432.0,11532032.0,11565632.0,11599232.0,11632832.0,11666432.0,11700032.0,11486592.0,11520256.0,11553920.0,11587584.0,11621248.0,11654912.0,11688576.0,11722240.0,11508352.0,11542080.0,11575808.0,11609536.0,11643264.0,11676992.0,11710720.0,11744448.0,10923904.0,10956672.0,10989440.0,11022208.0,11054976.0,11087744.0,11120512.0,11153280.0,10331264.0,10363072.0,10394880.0,10426688.0,10458496.0,10490304.0,10522112.0,10553920.0,5234816.0,5250736.0,5266656.0,5282576.0,5298496.0,5314416.0,5330336.0,5346256.0,5841408.0,5857200.0,5872992.0,5888784.0,5904576.0,5920368.0,5936160.0,5951952.0,11820416.0,11852032.0,11883648.0,11915264.0,11946880.0,11978496.0,12010112.0,12041728.0,11843200.0,11874880.0,11906560.0,11938240.0,11969920.0,12001600.0,12033280.0,12064960.0,11865984.0,11897728.0,11929472.0,11961216.0,11992960.0,12024704.0,12056448.0,12088192.0,11888768.0,11920576.0,11952384.0,11984192.0,12016000.0,12047808.0,12079616.0,12111424.0,11911552.0,11943424.0,11975296.0,12007168.0,12039040.0,12070912.0,12102784.0,12134656.0,11934336.0,11966272.0,11998208.0,12030144.0,12062080.0,12094016.0,12125952.0,12157888.0,11957120.0,11989120.0,12021120.0,12053120.0,12085120.0,12117120.0,12149120.0,12181120.0,11979904.0,12011968.0,12044032.0,12076096.0,12108160.0,12140224.0,12172288.0,12204352.0,12002688.0,12034816.0,12066944.0,12099072.0,12131200.0,12163328.0,12195456.0,12227584.0,12025472.0,12057664.0,12089856.0,12122048.0,12154240.0,12186432.0,12218624.0,12250816.0,12048256.0,12080512.0,12112768.0,12145024.0,12177280.0,12209536.0,12241792.0,12274048.0,12071040.0,12103360.0,12135680.0,12168000.0,12200320.0,12232640.0,12264960.0,12297280.0,12093824.0,12126208.0,12158592.0,12190976.0,12223360.0,12255744.0,12288128.0,12320512.0,12116608.0,12149056.0,12181504.0,12213952.0,12246400.0,12278848.0,12311296.0,12343744.0,12139392.0,12171904.0,12204416.0,12236928.0,12269440.0,12301952.0,
12334464.0,12366976.0,12162176.0,12194752.0,12227328.0,12259904.0,12292480.0,12325056.0,12357632.0,12390208.0,12184960.0,12217600.0,12250240.0,12282880.0,12315520.0,12348160.0,12380800.0,12413440.0,12207744.0,12240448.0,12273152.0,12305856.0,12338560.0,12371264.0,12403968.0,12436672.0,12230528.0,12263296.0,12296064.0,12328832.0,12361600.0,12394368.0,12427136.0,12459904.0,12253312.0,12286144.0,12318976.0,12351808.0,12384640.0,12417472.0,12450304.0,12483136.0,12276096.0,12308992.0,12341888.0,12374784.0,12407680.0,12440576.0,12473472.0,12506368.0,12298880.0,12331840.0,12364800.0,12397760.0,12430720.0,12463680.0,12496640.0,12529600.0,12321664.0,12354688.0,12387712.0,12420736.0,12453760.0,12486784.0,12519808.0,12552832.0,12344448.0,12377536.0,12410624.0,12443712.0,12476800.0,12509888.0,12542976.0,12576064.0,12367232.0,12400384.0,12433536.0,12466688.0,12499840.0,12532992.0,12566144.0,12599296.0,12390016.0,12423232.0,12456448.0,12489664.0,12522880.0,12556096.0,12589312.0,12622528.0,12412800.0,12446080.0,12479360.0,12512640.0,12545920.0,12579200.0,12612480.0,12645760.0,12435584.0,12468928.0,12502272.0,12535616.0,12568960.0,12602304.0,12635648.0,12668992.0,12458368.0,12491776.0,12525184.0,12558592.0,12592000.0,12625408.0,12658816.0,12692224.0,12481152.0,12514624.0,12548096.0,12581568.0,12615040.0,12648512.0,12681984.0,12715456.0,12503936.0,12537472.0,12571008.0,12604544.0,12638080.0,12671616.0,12705152.0,12738688.0,12526720.0,12560320.0,12593920.0,12627520.0,12661120.0,12694720.0,12728320.0,12761920.0,12549504.0,12583168.0,12616832.0,12650496.0,12684160.0,12717824.0,12751488.0,12785152.0,12572288.0,12606016.0,12639744.0,12673472.0,12707200.0,12740928.0,12774656.0,12808384.0,12595072.0,12628864.0,12662656.0,12696448.0,12730240.0,12764032.0,12797824.0,12831616.0,12617856.0,12651712.0,12685568.0,12719424.0,12753280.0,12787136.0,12820992.0,12854848.0,12018048.0,12050944.0,12083840.0,12116736.0,12149632.0,12182528.0,12215424.0,12248320.0,11410048.0,11441984.0,11473920.0,11505856.
0,11537792.0,11569728.0,11601664.0,11633600.0,10810240.0,10841216.0,10872192.0,10903168.0,10934144.0,10965120.0,10996096.0,11027072.0,10202240.0,10232256.0,10262272.0,10292288.0,10322304.0,10352320.0,10382336.0,10412352.0,10225024.0,10255104.0,10285184.0,10315264.0,10345344.0,10375424.0,10405504.0,10435584.0,10247808.0,10277952.0,10308096.0,10338240.0,10368384.0,10398528.0,10428672.0,10458816.0,10270592.0,10300800.0,10331008.0,10361216.0,10391424.0,10421632.0,10451840.0,10482048.0,10293376.0,10323648.0,10353920.0,10384192.0,10414464.0,10444736.0,10475008.0,10505280.0,10316160.0,10346496.0,10376832.0,10407168.0,10437504.0,10467840.0,10498176.0,10528512.0,10338944.0,10369344.0,10399744.0,10430144.0,10460544.0,10490944.0,10521344.0,10551744.0,10361728.0,10392192.0,10422656.0,10453120.0,10483584.0,10514048.0,10544512.0,10574976.0,10384512.0,10415040.0,10445568.0,10476096.0,10506624.0,10537152.0,10567680.0,10598208.0,10407296.0,10437888.0,10468480.0,10499072.0,10529664.0,10560256.0,10590848.0,10621440.0,10430080.0,10460736.0,10491392.0,10522048.0,10552704.0,10583360.0,10614016.0,10644672.0,10452864.0,10483584.0,10514304.0,10545024.0,10575744.0,10606464.0,10637184.0,10667904.0,10475648.0,10506432.0,10537216.0,10568000.0,10598784.0,10629568.0,10660352.0,10691136.0,10498432.0,10529280.0,10560128.0,10590976.0,10621824.0,10652672.0,10683520.0,10714368.0,10521216.0,10552128.0,10583040.0,10613952.0,10644864.0,10675776.0,10706688.0,10737600.0,10544000.0,10574976.0,10605952.0,10636928.0,10667904.0,10698880.0,10729856.0,10760832.0,10566784.0,10597824.0,10628864.0,10659904.0,10690944.0,10721984.0,10753024.0,10784064.0,10589568.0,10620672.0,10651776.0,10682880.0,10713984.0,10745088.0,10776192.0,10807296.0,10612352.0,10643520.0,10674688.0,10705856.0,10737024.0,10768192.0,10799360.0,10830528.0,10635136.0,10666368.0,10697600.0,10728832.0,10760064.0,10791296.0,10822528.0,10853760.0,10657920.0,10689216.0,10720512.0,10751808.0,10783104.0,10814400.0,10845696.0,10876992.0,10680704.0,1071206
4.0,10743424.0,10774784.0,10806144.0,10837504.0,10868864.0,10900224.0,10703488.0,10734912.0,10766336.0,10797760.0,10829184.0,10860608.0,10892032.0,10923456.0,10726272.0,10757760.0,10789248.0,10820736.0,10852224.0,10883712.0,10915200.0,10946688.0,10749056.0,10780608.0,10812160.0,10843712.0,10875264.0,10906816.0,10938368.0,10969920.0,10771840.0,10803456.0,10835072.0,10866688.0,10898304.0,10929920.0,10961536.0,10993152.0,10794624.0,10826304.0,10857984.0,10889664.0,10921344.0,10953024.0,10984704.0,11016384.0,10817408.0,10849152.0,10880896.0,10912640.0,10944384.0,10976128.0,11007872.0,11039616.0,10840192.0,10872000.0,10903808.0,10935616.0,10967424.0,10999232.0,11031040.0,11062848.0,10862976.0,10894848.0,10926720.0,10958592.0,10990464.0,11022336.0,11054208.0,11086080.0,10885760.0,10917696.0,10949632.0,10981568.0,11013504.0,11045440.0,11077376.0,11109312.0,10908544.0,10940544.0,10972544.0,11004544.0,11036544.0,11068544.0,11100544.0,11132544.0,10931328.0,10963392.0,10995456.0,11027520.0,11059584.0,11091648.0,11123712.0,11155776.0,10954112.0,10986240.0,11018368.0,11050496.0,11082624.0,11114752.0,11146880.0,11179008.0,10976896.0,11009088.0,11041280.0,11073472.0,11105664.0,11137856.0,11170048.0,11202240.0,10999680.0,11031936.0,11064192.0,11096448.0,11128704.0,11160960.0,11193216.0,11225472.0,11022464.0,11054784.0,11087104.0,11119424.0,11151744.0,11184064.0,11216384.0,11248704.0,11045248.0,11077632.0,11110016.0,11142400.0,11174784.0,11207168.0,11239552.0,11271936.0,10969728.0,11001152.0,11032576.0,11064000.0,11095424.0,11126848.0,11158272.0,11189696.0,10886016.0,10916480.0,10946944.0,10977408.0,11007872.0,11038336.0,11068800.0,11099264.0,10908800.0,10939328.0,10969856.0,11000384.0,11030912.0,11061440.0,11091968.0,11122496.0,10931584.0,10962176.0,10992768.0,11023360.0,11053952.0,11084544.0,11115136.0,11145728.0,10954368.0,10985024.0,11015680.0,11046336.0,11076992.0,11107648.0,11138304.0,11168960.0,10977152.0,11007872.0,11038592.0,11069312.0,11100032.0,11130752.0,11161472.0,11192
192.0,10999936.0,11030720.0,11061504.0,11092288.0,11123072.0,11153856.0,11184640.0,11215424.0,11022720.0,11053568.0,11084416.0,11115264.0,11146112.0,11176960.0,11207808.0,11238656.0,11045504.0,11076416.0,11107328.0,11138240.0,11169152.0,11200064.0,11230976.0,11261888.0,11068288.0,11099264.0,11130240.0,11161216.0,11192192.0,11223168.0,11254144.0,11285120.0,11091072.0,11122112.0,11153152.0,11184192.0,11215232.0,11246272.0,11277312.0,11308352.0,11113856.0,11144960.0,11176064.0,11207168.0,11238272.0,11269376.0,11300480.0,11331584.0,11136640.0,11167808.0,11198976.0,11230144.0,11261312.0,11292480.0,11323648.0,11354816.0,11159424.0,11190656.0,11221888.0,11253120.0,11284352.0,11315584.0,11346816.0,11378048.0,11182208.0,11213504.0,11244800.0,11276096.0,11307392.0,11338688.0,11369984.0,11401280.0,11204992.0,11236352.0,11267712.0,11299072.0,11330432.0,11361792.0,11393152.0,11424512.0,11227776.0,11259200.0,11290624.0,11322048.0,11353472.0,11384896.0,11416320.0,11447744.0,11250560.0,11282048.0,11313536.0,11345024.0,11376512.0,11408000.0,11439488.0,11470976.0,11273344.0,11304896.0,11336448.0,11368000.0,11399552.0,11431104.0,11462656.0,11494208.0,11296128.0,11327744.0,11359360.0,11390976.0,11422592.0,11454208.0,11485824.0,11517440.0,11318912.0,11350592.0,11382272.0,11413952.0,11445632.0,11477312.0,11508992.0,11540672.0,11341696.0,11373440.0,11405184.0,11436928.0,11468672.0,11500416.0,11532160.0,11563904.0,11364480.0,11396288.0,11428096.0,11459904.0,11491712.0,11523520.0,11555328.0,11587136.0,11387264.0,11419136.0,11451008.0,11482880.0,11514752.0,11546624.0,11578496.0,11610368.0,11410048.0,11441984.0,11473920.0,11505856.0,11537792.0,11569728.0,11601664.0,11633600.0,11432832.0,11464832.0,11496832.0,11528832.0,11560832.0,11592832.0,11624832.0,11656832.0,11455616.0,11487680.0,11519744.0,11551808.0,11583872.0,11615936.0,11648000.0,11680064.0,11478400.0,11510528.0,11542656.0,11574784.0,11606912.0,11639040.0,11671168.0,11703296.0,11501184.0,11533376.0,11565568.0,11597760.0,11629952.0,116
62144.0,11694336.0,11726528.0,11523968.0,11556224.0,11588480.0,11620736.0,11652992.0,11685248.0,11717504.0,11749760.0,11546752.0,11579072.0,11611392.0,11643712.0,11676032.0,11708352.0,11740672.0,11772992.0,11569536.0,11601920.0,11634304.0,11666688.0,11699072.0,11731456.0,11763840.0,11796224.0,11592320.0,11624768.0,11657216.0,11689664.0,11722112.0,11754560.0,11787008.0,11819456.0,11615104.0,11647616.0,11680128.0,11712640.0,11745152.0,11777664.0,11810176.0,11842688.0,11637888.0,11670464.0,11703040.0,11735616.0,11768192.0,11800768.0,11833344.0,11865920.0,11660672.0,11693312.0,11725952.0,11758592.0,11791232.0,11823872.0,11856512.0,11889152.0,11683456.0,11716160.0,11748864.0,11781568.0,11814272.0,11846976.0,11879680.0,11912384.0,11706240.0,11739008.0,11771776.0,11804544.0,11837312.0,11870080.0,11902848.0,11935616.0,11729024.0,11761856.0,11794688.0,11827520.0,11860352.0,11893184.0,11926016.0,11958848.0,11129216.0,11161088.0,11192960.0,11224832.0,11256704.0,11288576.0,11320448.0,11352320.0,10521216.0,10552128.0,10583040.0,10613952.0,10644864.0,10675776.0,10706688.0,10737600.0,10544000.0,10574976.0,10605952.0,10636928.0,10667904.0,10698880.0,10729856.0,10760832.0,10566784.0,10597824.0,10628864.0,10659904.0,10690944.0,10721984.0,10753024.0,10784064.0,10589568.0,10620672.0,10651776.0,10682880.0,10713984.0,10745088.0,10776192.0,10807296.0,10612352.0,10643520.0,10674688.0,10705856.0,10737024.0,10768192.0,10799360.0,10830528.0,10635136.0,10666368.0,10697600.0,10728832.0,10760064.0,10791296.0,10822528.0,10853760.0,10657920.0,10689216.0,10720512.0,10751808.0,10783104.0,10814400.0,10845696.0,10876992.0,10680704.0,10712064.0,10743424.0,10774784.0,10806144.0,10837504.0,10868864.0,10900224.0,10703488.0,10734912.0,10766336.0,10797760.0,10829184.0,10860608.0,10892032.0,10923456.0,10726272.0,10757760.0,10789248.0,10820736.0,10852224.0,10883712.0,10915200.0,10946688.0,10749056.0,10780608.0,10812160.0,10843712.0,10875264.0,10906816.0,10938368.0,10969920.0,10771840.0,10803456.0,10835072.0,1
0866688.0,10898304.0,10929920.0,10961536.0,10993152.0,10794624.0,10826304.0,10857984.0,10889664.0,10921344.0,10953024.0,10984704.0,11016384.0,10817408.0,10849152.0,10880896.0,10912640.0,10944384.0,10976128.0,11007872.0,11039616.0,10840192.0,10872000.0,10903808.0,10935616.0,10967424.0,10999232.0,11031040.0,11062848.0,10862976.0,10894848.0,10926720.0,10958592.0,10990464.0,11022336.0,11054208.0,11086080.0,10885760.0,10917696.0,10949632.0,10981568.0,11013504.0,11045440.0,11077376.0,11109312.0,10908544.0,10940544.0,10972544.0,11004544.0,11036544.0,11068544.0,11100544.0,11132544.0,10931328.0,10963392.0,10995456.0,11027520.0,11059584.0,11091648.0,11123712.0,11155776.0,10954112.0,10986240.0,11018368.0,11050496.0,11082624.0,11114752.0,11146880.0,11179008.0,10976896.0,11009088.0,11041280.0,11073472.0,11105664.0,11137856.0,11170048.0,11202240.0,10999680.0,11031936.0,11064192.0,11096448.0,11128704.0,11160960.0,11193216.0,11225472.0,11022464.0,11054784.0,11087104.0,11119424.0,11151744.0,11184064.0,11216384.0,11248704.0,11045248.0,11077632.0,11110016.0,11142400.0,11174784.0,11207168.0,11239552.0,11271936.0,11068032.0,11100480.0,11132928.0,11165376.0,11197824.0,11230272.0,11262720.0,11295168.0,11090816.0,11123328.0,11155840.0,11188352.0,11220864.0,11253376.0,11285888.0,11318400.0,11113600.0,11146176.0,11178752.0,11211328.0,11243904.0,11276480.0,11309056.0,11341632.0,11136384.0,11169024.0,11201664.0,11234304.0,11266944.0,11299584.0,11332224.0,11364864.0,11159168.0,11191872.0,11224576.0,11257280.0,11289984.0,11322688.0,11355392.0,11388096.0,11181952.0,11214720.0,11247488.0,11280256.0,11313024.0,11345792.0,11378560.0,11411328.0,11204736.0,11237568.0,11270400.0,11303232.0,11336064.0,11368896.0,11401728.0,11434560.0,11227520.0,11260416.0,11293312.0,11326208.0,11359104.0,11392000.0,11424896.0,11457792.0,11250304.0,11283264.0,11316224.0,11349184.0,11382144.0,11415104.0,11448064.0,11481024.0,11273088.0,11306112.0,11339136.0,11372160.0,11405184.0,11438208.0,11471232.0,11504256.0,11295872.0
,11328960.0,11362048.0,11395136.0,11428224.0,11461312.0,11494400.0,11527488.0,11318656.0,11351808.0,11384960.0,11418112.0,11451264.0,11484416.0,11517568.0,11550720.0,11341440.0,11374656.0,11407872.0,11441088.0,11474304.0,11507520.0,11540736.0,11573952.0,11364224.0,11397504.0,11430784.0,11464064.0,11497344.0,11530624.0,11563904.0,11597184.0,11288704.0,11321024.0,11353344.0,11385664.0,11417984.0,11450304.0,11482624.0,11514944.0,11204992.0,11236352.0,11267712.0,11299072.0,11330432.0,11361792.0,11393152.0,11424512.0,11227776.0,11259200.0,11290624.0,11322048.0,11353472.0,11384896.0,11416320.0,11447744.0,11250560.0,11282048.0,11313536.0,11345024.0,11376512.0,11408000.0,11439488.0,11470976.0,11273344.0,11304896.0,11336448.0,11368000.0,11399552.0,11431104.0,11462656.0,11494208.0,11296128.0,11327744.0,11359360.0,11390976.0,11422592.0,11454208.0,11485824.0,11517440.0,11318912.0,11350592.0,11382272.0,11413952.0,11445632.0,11477312.0,11508992.0,11540672.0,11341696.0,11373440.0,11405184.0,11436928.0,11468672.0,11500416.0,11532160.0,11563904.0,11364480.0,11396288.0,11428096.0,11459904.0,11491712.0,11523520.0,11555328.0,11587136.0,11387264.0,11419136.0,11451008.0,11482880.0,11514752.0,11546624.0,11578496.0,11610368.0,11410048.0,11441984.0,11473920.0,11505856.0,11537792.0,11569728.0,11601664.0,11633600.0,11432832.0,11464832.0,11496832.0,11528832.0,11560832.0,11592832.0,11624832.0,11656832.0,11455616.0,11487680.0,11519744.0,11551808.0,11583872.0,11615936.0,11648000.0,11680064.0,11478400.0,11510528.0,11542656.0,11574784.0,11606912.0,11639040.0,11671168.0,11703296.0,11501184.0,11533376.0,11565568.0,11597760.0,11629952.0,11662144.0,11694336.0,11726528.0,11523968.0,11556224.0,11588480.0,11620736.0,11652992.0,11685248.0,11717504.0,11749760.0,11546752.0,11579072.0,11611392.0,11643712.0,11676032.0,11708352.0,11740672.0,11772992.0,11569536.0,11601920.0,11634304.0,11666688.0,11699072.0,11731456.0,11763840.0,11796224.0,11592320.0,11624768.0,11657216.0,11689664.0,11722112.0,11754560.0,11787008
.0,11819456.0,11615104.0,11647616.0,11680128.0,11712640.0,11745152.0,11777664.0,11810176.0,11842688.0,11637888.0,11670464.0,11703040.0,11735616.0,11768192.0,11800768.0,11833344.0,11865920.0,11660672.0,11693312.0,11725952.0,11758592.0,11791232.0,11823872.0,11856512.0,11889152.0,11683456.0,11716160.0,11748864.0,11781568.0,11814272.0,11846976.0,11879680.0,11912384.0,11706240.0,11739008.0,11771776.0,11804544.0,11837312.0,11870080.0,11902848.0,11935616.0,11729024.0,11761856.0,11794688.0,11827520.0,11860352.0,11893184.0,11926016.0,11958848.0,11751808.0,11784704.0,11817600.0,11850496.0,11883392.0,11916288.0,11949184.0,11982080.0,11774592.0,11807552.0,11840512.0,11873472.0,11906432.0,11939392.0,11972352.0,12005312.0,11797376.0,11830400.0,11863424.0,11896448.0,11929472.0,11962496.0,11995520.0,12028544.0,11820160.0,11853248.0,11886336.0,11919424.0,11952512.0,11985600.0,12018688.0,12051776.0,11842944.0,11876096.0,11909248.0,11942400.0,11975552.0,12008704.0,12041856.0,12075008.0,11865728.0,11898944.0,11932160.0,11965376.0,11998592.0,12031808.0,12065024.0,12098240.0,11888512.0,11921792.0,11955072.0,11988352.0,12021632.0,12054912.0,12088192.0,12121472.0,11911296.0,11944640.0,11977984.0,12011328.0,12044672.0,12078016.0,12111360.0,12144704.0,11934080.0,11967488.0,12000896.0,12034304.0,12067712.0,12101120.0,12134528.0,12167936.0,11956864.0,11990336.0,12023808.0,12057280.0,12090752.0,12124224.0,12157696.0,12191168.0,11979648.0,12013184.0,12046720.0,12080256.0,12113792.0,12147328.0,12180864.0,12214400.0,12002432.0,12036032.0,12069632.0,12103232.0,12136832.0,12170432.0,12204032.0,12237632.0,12025216.0,12058880.0,12092544.0,12126208.0,12159872.0,12193536.0,12227200.0,12260864.0,12048000.0,12081728.0,12115456.0,12149184.0,12182912.0,12216640.0,12250368.0,12284096.0,11448192.0,11480960.0,11513728.0,11546496.0,11579264.0,11612032.0,11644800.0,11677568.0,10840192.0,10872000.0,10903808.0,10935616.0,10967424.0,10999232.0,11031040.0,11062848.0,5489536.0,5505456.0,5521376.0,5537296.0,5553216.0,
5569136.0,5585056.0,5600976.0,6094080.0,6109872.0,6125664.0,6141456.0,6157248.0,6173040.0,6188832.0,6204624.0,12326272.0,12357888.0,12389504.0,12421120.0,12452736.0,12484352.0,12515968.0,12547584.0,12350080.0,12381760.0,12413440.0,12445120.0,12476800.0,12508480.0,12540160.0,12571840.0,12373888.0,12405632.0,12437376.0,12469120.0,12500864.0,12532608.0,12564352.0,12596096.0,12397696.0,12429504.0,12461312.0,12493120.0,12524928.0,12556736.0,12588544.0,12620352.0,12421504.0,12453376.0,12485248.0,12517120.0,12548992.0,12580864.0,12612736.0,12644608.0,12445312.0,12477248.0,12509184.0,12541120.0,12573056.0,12604992.0,12636928.0,12668864.0,12469120.0,12501120.0,12533120.0,12565120.0,12597120.0,12629120.0,12661120.0,12693120.0,12492928.0,12524992.0,12557056.0,12589120.0,12621184.0,12653248.0,12685312.0,12717376.0,12516736.0,12548864.0,12580992.0,12613120.0,12645248.0,12677376.0,12709504.0,12741632.0,12540544.0,12572736.0,12604928.0,12637120.0,12669312.0,12701504.0,12733696.0,12765888.0,12564352.0,12596608.0,12628864.0,12661120.0,12693376.0,12725632.0,12757888.0,12790144.0,12588160.0,12620480.0,12652800.0,12685120.0,12717440.0,12749760.0,12782080.0,12814400.0,12611968.0,12644352.0,12676736.0,12709120.0,12741504.0,12773888.0,12806272.0,12838656.0,12635776.0,12668224.0,12700672.0,12733120.0,12765568.0,12798016.0,12830464.0,12862912.0,12659584.0,12692096.0,12724608.0,12757120.0,12789632.0,12822144.0,12854656.0,12887168.0,12683392.0,12715968.0,12748544.0,12781120.0,12813696.0,12846272.0,12878848.0,12911424.0,12707200.0,12739840.0,12772480.0,12805120.0,12837760.0,12870400.0,12903040.0,12935680.0,12731008.0,12763712.0,12796416.0,12829120.0,12861824.0,12894528.0,12927232.0,12959936.0,12754816.0,12787584.0,12820352.0,12853120.0,12885888.0,12918656.0,12951424.0,12984192.0,12778624.0,12811456.0,12844288.0,12877120.0,12909952.0,12942784.0,12975616.0,13008448.0,12802432.0,12835328.0,12868224.0,12901120.0,12934016.0,12966912.0,12999808.0,13032704.0,12826240.0,12859200.0,12892160.0,12925120.
0,12958080.0,12991040.0,13024000.0,13056960.0,12850048.0,12883072.0,12916096.0,12949120.0,12982144.0,13015168.0,13048192.0,13081216.0,12873856.0,12906944.0,12940032.0,12973120.0,13006208.0,13039296.0,13072384.0,13105472.0,12897664.0,12930816.0,12963968.0,12997120.0,13030272.0,13063424.0,13096576.0,13129728.0,12921472.0,12954688.0,12987904.0,13021120.0,13054336.0,13087552.0,13120768.0,13153984.0,12945280.0,12978560.0,13011840.0,13045120.0,13078400.0,13111680.0,13144960.0,13178240.0,12969088.0,13002432.0,13035776.0,13069120.0,13102464.0,13135808.0,13169152.0,13202496.0,12992896.0,13026304.0,13059712.0,13093120.0,13126528.0,13159936.0,13193344.0,13226752.0,13016704.0,13050176.0,13083648.0,13117120.0,13150592.0,13184064.0,13217536.0,13251008.0,13040512.0,13074048.0,13107584.0,13141120.0,13174656.0,13208192.0,13241728.0,13275264.0,13064320.0,13097920.0,13131520.0,13165120.0,13198720.0,13232320.0,13265920.0,13299520.0,13088128.0,13121792.0,13155456.0,13189120.0,13222784.0,13256448.0,13290112.0,13323776.0,13111936.0,13145664.0,13179392.0,13213120.0,13246848.0,13280576.0,13314304.0,13348032.0,13135744.0,13169536.0,13203328.0,13237120.0,13270912.0,13304704.0,13338496.0,13372288.0,13159552.0,13193408.0,13227264.0,13261120.0,13294976.0,13328832.0,13362688.0,13396544.0,12544384.0,12577280.0,12610176.0,12643072.0,12675968.0,12708864.0,12741760.0,12774656.0,11921024.0,11952960.0,11984896.0,12016832.0,12048768.0,12080704.0,12112640.0,12144576.0,11305856.0,11336832.0,11367808.0,11398784.0,11429760.0,11460736.0,11491712.0,11522688.0,10682496.0,10712512.0,10742528.0,10772544.0,10802560.0,10832576.0,10862592.0,10892608.0,10706304.0,10736384.0,10766464.0,10796544.0,10826624.0,10856704.0,10886784.0,10916864.0,10730112.0,10760256.0,10790400.0,10820544.0,10850688.0,10880832.0,10910976.0,10941120.0,10753920.0,10784128.0,10814336.0,10844544.0,10874752.0,10904960.0,10935168.0,10965376.0,10777728.0,10808000.0,10838272.0,10868544.0,10898816.0,10929088.0,10959360.0,10989632.0,10801536.0,1083187
2.0,10862208.0,10892544.0,10922880.0,10953216.0,10983552.0,11013888.0,10825344.0,10855744.0,10886144.0,10916544.0,10946944.0,10977344.0,11007744.0,11038144.0,10849152.0,10879616.0,10910080.0,10940544.0,10971008.0,11001472.0,11031936.0,11062400.0,10872960.0,10903488.0,10934016.0,10964544.0,10995072.0,11025600.0,11056128.0,11086656.0,10896768.0,10927360.0,10957952.0,10988544.0,11019136.0,11049728.0,11080320.0,11110912.0,10920576.0,10951232.0,10981888.0,11012544.0,11043200.0,11073856.0,11104512.0,11135168.0,10944384.0,10975104.0,11005824.0,11036544.0,11067264.0,11097984.0,11128704.0,11159424.0,10968192.0,10998976.0,11029760.0,11060544.0,11091328.0,11122112.0,11152896.0,11183680.0,10992000.0,11022848.0,11053696.0,11084544.0,11115392.0,11146240.0,11177088.0,11207936.0,11015808.0,11046720.0,11077632.0,11108544.0,11139456.0,11170368.0,11201280.0,11232192.0,11039616.0,11070592.0,11101568.0,11132544.0,11163520.0,11194496.0,11225472.0,11256448.0,11063424.0,11094464.0,11125504.0,11156544.0,11187584.0,11218624.0,11249664.0,11280704.0,11087232.0,11118336.0,11149440.0,11180544.0,11211648.0,11242752.0,11273856.0,11304960.0,11111040.0,11142208.0,11173376.0,11204544.0,11235712.0,11266880.0,11298048.0,11329216.0,11134848.0,11166080.0,11197312.0,11228544.0,11259776.0,11291008.0,11322240.0,11353472.0,11158656.0,11189952.0,11221248.0,11252544.0,11283840.0,11315136.0,11346432.0,11377728.0,11182464.0,11213824.0,11245184.0,11276544.0,11307904.0,11339264.0,11370624.0,11401984.0,11206272.0,11237696.0,11269120.0,11300544.0,11331968.0,11363392.0,11394816.0,11426240.0,11230080.0,11261568.0,11293056.0,11324544.0,11356032.0,11387520.0,11419008.0,11450496.0,11253888.0,11285440.0,11316992.0,11348544.0,11380096.0,11411648.0,11443200.0,11474752.0,11277696.0,11309312.0,11340928.0,11372544.0,11404160.0,11435776.0,11467392.0,11499008.0,11301504.0,11333184.0,11364864.0,11396544.0,11428224.0,11459904.0,11491584.0,11523264.0,11325312.0,11357056.0,11388800.0,11420544.0,11452288.0,11484032.0,11515776.0,11547
520.0,11349120.0,11380928.0,11412736.0,11444544.0,11476352.0,11508160.0,11539968.0,11571776.0,11372928.0,11404800.0,11436672.0,11468544.0,11500416.0,11532288.0,11564160.0,11596032.0,11396736.0,11428672.0,11460608.0,11492544.0,11524480.0,11556416.0,11588352.0,11620288.0,11420544.0,11452544.0,11484544.0,11516544.0,11548544.0,11580544.0,11612544.0,11644544.0,11444352.0,11476416.0,11508480.0,11540544.0,11572608.0,11604672.0,11636736.0,11668800.0,11468160.0,11500288.0,11532416.0,11564544.0,11596672.0,11628800.0,11660928.0,11693056.0,11491968.0,11524160.0,11556352.0,11588544.0,11620736.0,11652928.0,11685120.0,11717312.0,11515776.0,11548032.0,11580288.0,11612544.0,11644800.0,11677056.0,11709312.0,11741568.0,11539584.0,11571904.0,11604224.0,11636544.0,11668864.0,11701184.0,11733504.0,11765824.0,11563392.0,11595776.0,11628160.0,11660544.0,11692928.0,11725312.0,11757696.0,11790080.0,11472512.0,11503936.0,11535360.0,11566784.0,11598208.0,11629632.0,11661056.0,11692480.0,11373440.0,11403904.0,11434368.0,11464832.0,11495296.0,11525760.0,11556224.0,11586688.0,11397248.0,11427776.0,11458304.0,11488832.0,11519360.0,11549888.0,11580416.0,11610944.0,11421056.0,11451648.0,11482240.0,11512832.0,11543424.0,11574016.0,11604608.0,11635200.0,11444864.0,11475520.0,11506176.0,11536832.0,11567488.0,11598144.0,11628800.0,11659456.0,11468672.0,11499392.0,11530112.0,11560832.0,11591552.0,11622272.0,11652992.0,11683712.0,11492480.0,11523264.0,11554048.0,11584832.0,11615616.0,11646400.0,11677184.0,11707968.0,11516288.0,11547136.0,11577984.0,11608832.0,11639680.0,11670528.0,11701376.0,11732224.0,11540096.0,11571008.0,11601920.0,11632832.0,11663744.0,11694656.0,11725568.0,11756480.0,11563904.0,11594880.0,11625856.0,11656832.0,11687808.0,11718784.0,11749760.0,11780736.0,11587712.0,11618752.0,11649792.0,11680832.0,11711872.0,11742912.0,11773952.0,11804992.0,11611520.0,11642624.0,11673728.0,11704832.0,11735936.0,11767040.0,11798144.0,11829248.0,11635328.0,11666496.0,11697664.0,11728832.0,11760000.0,117
91168.0,11822336.0,11853504.0,11659136.0,11690368.0,11721600.0,11752832.0,11784064.0,11815296.0,11846528.0,11877760.0,11682944.0,11714240.0,11745536.0,11776832.0,11808128.0,11839424.0,11870720.0,11902016.0,11706752.0,11738112.0,11769472.0,11800832.0,11832192.0,11863552.0,11894912.0,11926272.0,11730560.0,11761984.0,11793408.0,11824832.0,11856256.0,11887680.0,11919104.0,11950528.0,11754368.0,11785856.0,11817344.0,11848832.0,11880320.0,11911808.0,11943296.0,11974784.0,11778176.0,11809728.0,11841280.0,11872832.0,11904384.0,11935936.0,11967488.0,11999040.0,11801984.0,11833600.0,11865216.0,11896832.0,11928448.0,11960064.0,11991680.0,12023296.0,11825792.0,11857472.0,11889152.0,11920832.0,11952512.0,11984192.0,12015872.0,12047552.0,11849600.0,11881344.0,11913088.0,11944832.0,11976576.0,12008320.0,12040064.0,12071808.0,11873408.0,11905216.0,11937024.0,11968832.0,12000640.0,12032448.0,12064256.0,12096064.0,11897216.0,11929088.0,11960960.0,11992832.0,12024704.0,12056576.0,12088448.0,12120320.0,11921024.0,11952960.0,11984896.0,12016832.0,12048768.0,12080704.0,12112640.0,12144576.0,11944832.0,11976832.0,12008832.0,12040832.0,12072832.0,12104832.0,12136832.0,12168832.0,11968640.0,12000704.0,12032768.0,12064832.0,12096896.0,12128960.0,12161024.0,12193088.0,11992448.0,12024576.0,12056704.0,12088832.0,12120960.0,12153088.0,12185216.0,12217344.0,12016256.0,12048448.0,12080640.0,12112832.0,12145024.0,12177216.0,12209408.0,12241600.0,12040064.0,12072320.0,12104576.0,12136832.0,12169088.0,12201344.0,12233600.0,12265856.0,12063872.0,12096192.0,12128512.0,12160832.0,12193152.0,12225472.0,12257792.0,12290112.0,12087680.0,12120064.0,12152448.0,12184832.0,12217216.0,12249600.0,12281984.0,12314368.0,12111488.0,12143936.0,12176384.0,12208832.0,12241280.0,12273728.0,12306176.0,12338624.0,12135296.0,12167808.0,12200320.0,12232832.0,12265344.0,12297856.0,12330368.0,12362880.0,12159104.0,12191680.0,12224256.0,12256832.0,12289408.0,12321984.0,12354560.0,12387136.0,12182912.0,12215552.0,12248192.0,1
2280832.0,12313472.0,12346112.0,12378752.0,12411392.0,12206720.0,12239424.0,12272128.0,12304832.0,12337536.0,12370240.0,12402944.0,12435648.0,12230528.0,12263296.0,12296064.0,12328832.0,12361600.0,12394368.0,12427136.0,12459904.0,12254336.0,12287168.0,12320000.0,12352832.0,12385664.0,12418496.0,12451328.0,12484160.0,11639168.0,11671040.0,11702912.0,11734784.0,11766656.0,11798528.0,11830400.0,11862272.0,11015808.0,11046720.0,11077632.0,11108544.0,11139456.0,11170368.0,11201280.0,11232192.0,11039616.0,11070592.0,11101568.0,11132544.0,11163520.0,11194496.0,11225472.0,11256448.0,11063424.0,11094464.0,11125504.0,11156544.0,11187584.0,11218624.0,11249664.0,11280704.0,11087232.0,11118336.0,11149440.0,11180544.0,11211648.0,11242752.0,11273856.0,11304960.0,11111040.0,11142208.0,11173376.0,11204544.0,11235712.0,11266880.0,11298048.0,11329216.0,11134848.0,11166080.0,11197312.0,11228544.0,11259776.0,11291008.0,11322240.0,11353472.0,11158656.0,11189952.0,11221248.0,11252544.0,11283840.0,11315136.0,11346432.0,11377728.0,11182464.0,11213824.0,11245184.0,11276544.0,11307904.0,11339264.0,11370624.0,11401984.0,11206272.0,11237696.0,11269120.0,11300544.0,11331968.0,11363392.0,11394816.0,11426240.0,11230080.0,11261568.0,11293056.0,11324544.0,11356032.0,11387520.0,11419008.0,11450496.0,11253888.0,11285440.0,11316992.0,11348544.0,11380096.0,11411648.0,11443200.0,11474752.0,11277696.0,11309312.0,11340928.0,11372544.0,11404160.0,11435776.0,11467392.0,11499008.0,11301504.0,11333184.0,11364864.0,11396544.0,11428224.0,11459904.0,11491584.0,11523264.0,11325312.0,11357056.0,11388800.0,11420544.0,11452288.0,11484032.0,11515776.0,11547520.0,11349120.0,11380928.0,11412736.0,11444544.0,11476352.0,11508160.0,11539968.0,11571776.0,11372928.0,11404800.0,11436672.0,11468544.0,11500416.0,11532288.0,11564160.0,11596032.0,11396736.0,11428672.0,11460608.0,11492544.0,11524480.0,11556416.0,11588352.0,11620288.0,11420544.0,11452544.0,11484544.0,11516544.0,11548544.0,11580544.0,11612544.0,11644544.0,11444352.0
,11476416.0,11508480.0,11540544.0,11572608.0,11604672.0,11636736.0,11668800.0,11468160.0,11500288.0,11532416.0,11564544.0,11596672.0,11628800.0,11660928.0,11693056.0,11491968.0,11524160.0,11556352.0,11588544.0,11620736.0,11652928.0,11685120.0,11717312.0,11515776.0,11548032.0,11580288.0,11612544.0,11644800.0,11677056.0,11709312.0,11741568.0,11539584.0,11571904.0,11604224.0,11636544.0,11668864.0,11701184.0,11733504.0,11765824.0,11563392.0,11595776.0,11628160.0,11660544.0,11692928.0,11725312.0,11757696.0,11790080.0,11587200.0,11619648.0,11652096.0,11684544.0,11716992.0,11749440.0,11781888.0,11814336.0,11611008.0,11643520.0,11676032.0,11708544.0,11741056.0,11773568.0,11806080.0,11838592.0,11634816.0,11667392.0,11699968.0,11732544.0,11765120.0,11797696.0,11830272.0,11862848.0,11658624.0,11691264.0,11723904.0,11756544.0,11789184.0,11821824.0,11854464.0,11887104.0,11682432.0,11715136.0,11747840.0,11780544.0,11813248.0,11845952.0,11878656.0,11911360.0,11706240.0,11739008.0,11771776.0,11804544.0,11837312.0,11870080.0,11902848.0,11935616.0,11730048.0,11762880.0,11795712.0,11828544.0,11861376.0,11894208.0,11927040.0,11959872.0,11753856.0,11786752.0,11819648.0,11852544.0,11885440.0,11918336.0,11951232.0,11984128.0,11777664.0,11810624.0,11843584.0,11876544.0,11909504.0,11942464.0,11975424.0,12008384.0,11801472.0,11834496.0,11867520.0,11900544.0,11933568.0,11966592.0,11999616.0,12032640.0,11825280.0,11858368.0,11891456.0,11924544.0,11957632.0,11990720.0,12023808.0,12056896.0,11849088.0,11882240.0,11915392.0,11948544.0,11981696.0,12014848.0,12048000.0,12081152.0,11872896.0,11906112.0,11939328.0,11972544.0,12005760.0,12038976.0,12072192.0,12105408.0,11896704.0,11929984.0,11963264.0,11996544.0,12029824.0,12063104.0,12096384.0,12129664.0,11805824.0,11838144.0,11870464.0,11902784.0,11935104.0,11967424.0,11999744.0,12032064.0,11706752.0,11738112.0,11769472.0,11800832.0,11832192.0,11863552.0,11894912.0,11926272.0,11730560.0,11761984.0,11793408.0,11824832.0,11856256.0,11887680.0,11919104
.0,11950528.0,11754368.0,11785856.0,11817344.0,11848832.0,11880320.0,11911808.0,11943296.0,11974784.0,11778176.0,11809728.0,11841280.0,11872832.0,11904384.0,11935936.0,11967488.0,11999040.0,11801984.0,11833600.0,11865216.0,11896832.0,11928448.0,11960064.0,11991680.0,12023296.0,11825792.0,11857472.0,11889152.0,11920832.0,11952512.0,11984192.0,12015872.0,12047552.0,11849600.0,11881344.0,11913088.0,11944832.0,11976576.0,12008320.0,12040064.0,12071808.0,11873408.0,11905216.0,11937024.0,11968832.0,12000640.0,12032448.0,12064256.0,12096064.0,11897216.0,11929088.0,11960960.0,11992832.0,12024704.0,12056576.0,12088448.0,12120320.0,11921024.0,11952960.0,11984896.0,12016832.0,12048768.0,12080704.0,12112640.0,12144576.0,11944832.0,11976832.0,12008832.0,12040832.0,12072832.0,12104832.0,12136832.0,12168832.0,11968640.0,12000704.0,12032768.0,12064832.0,12096896.0,12128960.0,12161024.0,12193088.0,11992448.0,12024576.0,12056704.0,12088832.0,12120960.0,12153088.0,12185216.0,12217344.0,12016256.0,12048448.0,12080640.0,12112832.0,12145024.0,12177216.0,12209408.0,12241600.0,12040064.0,12072320.0,12104576.0,12136832.0,12169088.0,12201344.0,12233600.0,12265856.0,12063872.0,12096192.0,12128512.0,12160832.0,12193152.0,12225472.0,12257792.0,12290112.0,12087680.0,12120064.0,12152448.0,12184832.0,12217216.0,12249600.0,12281984.0,12314368.0,12111488.0,12143936.0,12176384.0,12208832.0,12241280.0,12273728.0,12306176.0,12338624.0,12135296.0,12167808.0,12200320.0,12232832.0,12265344.0,12297856.0,12330368.0,12362880.0,12159104.0,12191680.0,12224256.0,12256832.0,12289408.0,12321984.0,12354560.0,12387136.0,12182912.0,12215552.0,12248192.0,12280832.0,12313472.0,12346112.0,12378752.0,12411392.0,12206720.0,12239424.0,12272128.0,12304832.0,12337536.0,12370240.0,12402944.0,12435648.0,12230528.0,12263296.0,12296064.0,12328832.0,12361600.0,12394368.0,12427136.0,12459904.0,12254336.0,12287168.0,12320000.0,12352832.0,12385664.0,12418496.0,12451328.0,12484160.0,12278144.0,12311040.0,12343936.0,12376832.0,124097
28.0,12442624.0,12475520.0,12508416.0,12301952.0,12334912.0,12367872.0,12400832.0,12433792.0,12466752.0,12499712.0,12532672.0,12325760.0,12358784.0,12391808.0,12424832.0,12457856.0,12490880.0,12523904.0,12556928.0,12349568.0,12382656.0,12415744.0,12448832.0,12481920.0,12515008.0,12548096.0,12581184.0,12373376.0,12406528.0,12439680.0,12472832.0,12505984.0,12539136.0,12572288.0,12605440.0,12397184.0,12430400.0,12463616.0,12496832.0,12530048.0,12563264.0,12596480.0,12629696.0,12420992.0,12454272.0,12487552.0,12520832.0,12554112.0,12587392.0,12620672.0,12653952.0,12444800.0,12478144.0,12511488.0,12544832.0,12578176.0,12611520.0,12644864.0,12678208.0,12468608.0,12502016.0,12535424.0,12568832.0,12602240.0,12635648.0,12669056.0,12702464.0,12492416.0,12525888.0,12559360.0,12592832.0,12626304.0,12659776.0,12693248.0,12726720.0,12516224.0,12549760.0,12583296.0,12616832.0,12650368.0,12683904.0,12717440.0,12750976.0,12540032.0,12573632.0,12607232.0,12640832.0,12674432.0,12708032.0,12741632.0,12775232.0,12563840.0,12597504.0,12631168.0,12664832.0,12698496.0,12732160.0,12765824.0,12799488.0,12587648.0,12621376.0,12655104.0,12688832.0,12722560.0,12756288.0,12790016.0,12823744.0,11972480.0,12005248.0,12038016.0,12070784.0,12103552.0,12136320.0,12169088.0,12201856.0,11349120.0,11380928.0,11412736.0,11444544.0,11476352.0,11508160.0,11539968.0,11571776.0,5744256.0,5760176.0,5776096.0,5792016.0,5807936.0,5823856.0,5839776.0,5855696.0,6346752.0,6362544.0,6378336.0,6394128.0,6409920.0,6425712.0,6441504.0,6457296.0,12832128.0,12863744.0,12895360.0,12926976.0,12958592.0,12990208.0,13021824.0,13053440.0,12856960.0,12888640.0,12920320.0,12952000.0,12983680.0,13015360.0,13047040.0,13078720.0,12881792.0,12913536.0,12945280.0,12977024.0,13008768.0,13040512.0,13072256.0,13104000.0,12906624.0,12938432.0,12970240.0,13002048.0,13033856.0,13065664.0,13097472.0,13129280.0,12931456.0,12963328.0,12995200.0,13027072.0,13058944.0,13090816.0,13122688.0,13154560.0,12956288.0,12988224.0,13020160.0,13052096.
0,13084032.0,13115968.0,13147904.0,13179840.0,12981120.0,13013120.0,13045120.0,13077120.0,13109120.0,13141120.0,13173120.0,13205120.0,13005952.0,13038016.0,13070080.0,13102144.0,13134208.0,13166272.0,13198336.0,13230400.0,13030784.0,13062912.0,13095040.0,13127168.0,13159296.0,13191424.0,13223552.0,13255680.0,13055616.0,13087808.0,13120000.0,13152192.0,13184384.0,13216576.0,13248768.0,13280960.0,13080448.0,13112704.0,13144960.0,13177216.0,13209472.0,13241728.0,13273984.0,13306240.0,13105280.0,13137600.0,13169920.0,13202240.0,13234560.0,13266880.0,13299200.0,13331520.0,13130112.0,13162496.0,13194880.0,13227264.0,13259648.0,13292032.0,13324416.0,13356800.0,13154944.0,13187392.0,13219840.0,13252288.0,13284736.0,13317184.0,13349632.0,13382080.0,13179776.0,13212288.0,13244800.0,13277312.0,13309824.0,13342336.0,13374848.0,13407360.0,13204608.0,13237184.0,13269760.0,13302336.0,13334912.0,13367488.0,13400064.0,13432640.0,13229440.0,13262080.0,13294720.0,13327360.0,13360000.0,13392640.0,13425280.0,13457920.0,13254272.0,13286976.0,13319680.0,13352384.0,13385088.0,13417792.0,13450496.0,13483200.0,13279104.0,13311872.0,13344640.0,13377408.0,13410176.0,13442944.0,13475712.0,13508480.0,13303936.0,13336768.0,13369600.0,13402432.0,13435264.0,13468096.0,13500928.0,13533760.0,13328768.0,13361664.0,13394560.0,13427456.0,13460352.0,13493248.0,13526144.0,13559040.0,13353600.0,13386560.0,13419520.0,13452480.0,13485440.0,13518400.0,13551360.0,13584320.0,13378432.0,13411456.0,13444480.0,13477504.0,13510528.0,13543552.0,13576576.0,13609600.0,13403264.0,13436352.0,13469440.0,13502528.0,13535616.0,13568704.0,13601792.0,13634880.0,13428096.0,13461248.0,13494400.0,13527552.0,13560704.0,13593856.0,13627008.0,13660160.0,13452928.0,13486144.0,13519360.0,13552576.0,13585792.0,13619008.0,13652224.0,13685440.0,13477760.0,13511040.0,13544320.0,13577600.0,13610880.0,13644160.0,13677440.0,13710720.0,13502592.0,13535936.0,13569280.0,13602624.0,13635968.0,13669312.0,13702656.0,13736000.0,13527424.0,1356083
2.0,13594240.0,13627648.0,13661056.0,13694464.0,13727872.0,13761280.0,13552256.0,13585728.0,13619200.0,13652672.0,13686144.0,13719616.0,13753088.0,13786560.0,13577088.0,13610624.0,13644160.0,13677696.0,13711232.0,13744768.0,13778304.0,13811840.0,13601920.0,13635520.0,13669120.0,13702720.0,13736320.0,13769920.0,13803520.0,13837120.0,13626752.0,13660416.0,13694080.0,13727744.0,13761408.0,13795072.0,13828736.0,13862400.0,13651584.0,13685312.0,13719040.0,13752768.0,13786496.0,13820224.0,13853952.0,13887680.0,13676416.0,13710208.0,13744000.0,13777792.0,13811584.0,13845376.0,13879168.0,13912960.0,13701248.0,13735104.0,13768960.0,13802816.0,13836672.0,13870528.0,13904384.0,13938240.0,13070720.0,13103616.0,13136512.0,13169408.0,13202304.0,13235200.0,13268096.0,13300992.0,12432000.0,12463936.0,12495872.0,12527808.0,12559744.0,12591680.0,12623616.0,12655552.0,11801472.0,11832448.0,11863424.0,11894400.0,11925376.0,11956352.0,11987328.0,12018304.0,11162752.0,11192768.0,11222784.0,11252800.0,11282816.0,11312832.0,11342848.0,11372864.0,11187584.0,11217664.0,11247744.0,11277824.0,11307904.0,11337984.0,11368064.0,11398144.0,11212416.0,11242560.0,11272704.0,11302848.0,11332992.0,11363136.0,11393280.0,11423424.0,11237248.0,11267456.0,11297664.0,11327872.0,11358080.0,11388288.0,11418496.0,11448704.0,11262080.0,11292352.0,11322624.0,11352896.0,11383168.0,11413440.0,11443712.0,11473984.0,11286912.0,11317248.0,11347584.0,11377920.0,11408256.0,11438592.0,11468928.0,11499264.0,11311744.0,11342144.0,11372544.0,11402944.0,11433344.0,11463744.0,11494144.0,11524544.0,11336576.0,11367040.0,11397504.0,11427968.0,11458432.0,11488896.0,11519360.0,11549824.0,11361408.0,11391936.0,11422464.0,11452992.0,11483520.0,11514048.0,11544576.0,11575104.0,11386240.0,11416832.0,11447424.0,11478016.0,11508608.0,11539200.0,11569792.0,11600384.0,11411072.0,11441728.0,11472384.0,11503040.0,11533696.0,11564352.0,11595008.0,11625664.0,11435904.0,11466624.0,11497344.0,11528064.0,11558784.0,11589504.0,11620224.0,11650
944.0,11460736.0,11491520.0,11522304.0,11553088.0,11583872.0,11614656.0,11645440.0,11676224.0,11485568.0,11516416.0,11547264.0,11578112.0,11608960.0,11639808.0,11670656.0,11701504.0,11510400.0,11541312.0,11572224.0,11603136.0,11634048.0,11664960.0,11695872.0,11726784.0,11535232.0,11566208.0,11597184.0,11628160.0,11659136.0,11690112.0,11721088.0,11752064.0,11560064.0,11591104.0,11622144.0,11653184.0,11684224.0,11715264.0,11746304.0,11777344.0,11584896.0,11616000.0,11647104.0,11678208.0,11709312.0,11740416.0,11771520.0,11802624.0,11609728.0,11640896.0,11672064.0,11703232.0,11734400.0,11765568.0,11796736.0,11827904.0,11634560.0,11665792.0,11697024.0,11728256.0,11759488.0,11790720.0,11821952.0,11853184.0,11659392.0,11690688.0,11721984.0,11753280.0,11784576.0,11815872.0,11847168.0,11878464.0,11684224.0,11715584.0,11746944.0,11778304.0,11809664.0,11841024.0,11872384.0,11903744.0,11709056.0,11740480.0,11771904.0,11803328.0,11834752.0,11866176.0,11897600.0,11929024.0,11733888.0,11765376.0,11796864.0,11828352.0,11859840.0,11891328.0,11922816.0,11954304.0,11758720.0,11790272.0,11821824.0,11853376.0,11884928.0,11916480.0,11948032.0,11979584.0,11783552.0,11815168.0,11846784.0,11878400.0,11910016.0,11941632.0,11973248.0,12004864.0,11808384.0,11840064.0,11871744.0,11903424.0,11935104.0,11966784.0,11998464.0,12030144.0,11833216.0,11864960.0,11896704.0,11928448.0,11960192.0,11991936.0,12023680.0,12055424.0,11858048.0,11889856.0,11921664.0,11953472.0,11985280.0,12017088.0,12048896.0,12080704.0,11882880.0,11914752.0,11946624.0,11978496.0,12010368.0,12042240.0,12074112.0,12105984.0,11907712.0,11939648.0,11971584.0,12003520.0,12035456.0,12067392.0,12099328.0,12131264.0,11932544.0,11964544.0,11996544.0,12028544.0,12060544.0,12092544.0,12124544.0,12156544.0,11957376.0,11989440.0,12021504.0,12053568.0,12085632.0,12117696.0,12149760.0,12181824.0,11982208.0,12014336.0,12046464.0,12078592.0,12110720.0,12142848.0,12174976.0,12207104.0,12007040.0,12039232.0,12071424.0,12103616.0,12135808.0,121
68000.0,12200192.0,12232384.0,12031872.0,12064128.0,12096384.0,12128640.0,12160896.0,12193152.0,12225408.0,12257664.0,12056704.0,12089024.0,12121344.0,12153664.0,12185984.0,12218304.0,12250624.0,12282944.0,12081536.0,12113920.0,12146304.0,12178688.0,12211072.0,12243456.0,12275840.0,12308224.0,11975296.0,12006720.0,12038144.0,12069568.0,12100992.0,12132416.0,12163840.0,12195264.0,11860864.0,11891328.0,11921792.0,11952256.0,11982720.0,12013184.0,12043648.0,12074112.0,11885696.0,11916224.0,11946752.0,11977280.0,12007808.0,12038336.0,12068864.0,12099392.0,11910528.0,11941120.0,11971712.0,12002304.0,12032896.0,12063488.0,12094080.0,12124672.0,11935360.0,11966016.0,11996672.0,12027328.0,12057984.0,12088640.0,12119296.0,12149952.0,11960192.0,11990912.0,12021632.0,12052352.0,12083072.0,12113792.0,12144512.0,12175232.0,11985024.0,12015808.0,12046592.0,12077376.0,12108160.0,12138944.0,12169728.0,12200512.0,12009856.0,12040704.0,12071552.0,12102400.0,12133248.0,12164096.0,12194944.0,12225792.0,12034688.0,12065600.0,12096512.0,12127424.0,12158336.0,12189248.0,12220160.0,12251072.0,12059520.0,12090496.0,12121472.0,12152448.0,12183424.0,12214400.0,12245376.0,12276352.0,12084352.0,12115392.0,12146432.0,12177472.0,12208512.0,12239552.0,12270592.0,12301632.0,12109184.0,12140288.0,12171392.0,12202496.0,12233600.0,12264704.0,12295808.0,12326912.0,12134016.0,12165184.0,12196352.0,12227520.0,12258688.0,12289856.0,12321024.0,12352192.0,12158848.0,12190080.0,12221312.0,12252544.0,12283776.0,12315008.0,12346240.0,12377472.0,12183680.0,12214976.0,12246272.0,12277568.0,12308864.0,12340160.0,12371456.0,12402752.0,12208512.0,12239872.0,12271232.0,12302592.0,12333952.0,12365312.0,12396672.0,12428032.0,12233344.0,12264768.0,12296192.0,12327616.0,12359040.0,12390464.0,12421888.0,12453312.0,12258176.0,12289664.0,12321152.0,12352640.0,12384128.0,12415616.0,12447104.0,12478592.0,12283008.0,12314560.0,12346112.0,12377664.0,12409216.0,12440768.0,12472320.0,12503872.0,12307840.0,12339456.0,12371072.0,1
2402688.0,12434304.0,12465920.0,12497536.0,12529152.0,12332672.0,12364352.0,12396032.0,12427712.0,12459392.0,12491072.0,12522752.0,12554432.0,12357504.0,12389248.0,12420992.0,12452736.0,12484480.0,12516224.0,12547968.0,12579712.0,12382336.0,12414144.0,12445952.0,12477760.0,12509568.0,12541376.0,12573184.0,12604992.0,12407168.0,12439040.0,12470912.0,12502784.0,12534656.0,12566528.0,12598400.0,12630272.0,12432000.0,12463936.0,12495872.0,12527808.0,12559744.0,12591680.0,12623616.0,12655552.0,12456832.0,12488832.0,12520832.0,12552832.0,12584832.0,12616832.0,12648832.0,12680832.0,12481664.0,12513728.0,12545792.0,12577856.0,12609920.0,12641984.0,12674048.0,12706112.0,12506496.0,12538624.0,12570752.0,12602880.0,12635008.0,12667136.0,12699264.0,12731392.0,12531328.0,12563520.0,12595712.0,12627904.0,12660096.0,12692288.0,12724480.0,12756672.0,12556160.0,12588416.0,12620672.0,12652928.0,12685184.0,12717440.0,12749696.0,12781952.0,12580992.0,12613312.0,12645632.0,12677952.0,12710272.0,12742592.0,12774912.0,12807232.0,12605824.0,12638208.0,12670592.0,12702976.0,12735360.0,12767744.0,12800128.0,12832512.0,12630656.0,12663104.0,12695552.0,12728000.0,12760448.0,12792896.0,12825344.0,12857792.0,12655488.0,12688000.0,12720512.0,12753024.0,12785536.0,12818048.0,12850560.0,12883072.0,12680320.0,12712896.0,12745472.0,12778048.0,12810624.0,12843200.0,12875776.0,12908352.0,12705152.0,12737792.0,12770432.0,12803072.0,12835712.0,12868352.0,12900992.0,12933632.0,12729984.0,12762688.0,12795392.0,12828096.0,12860800.0,12893504.0,12926208.0,12958912.0,12754816.0,12787584.0,12820352.0,12853120.0,12885888.0,12918656.0,12951424.0,12984192.0,12779648.0,12812480.0,12845312.0,12878144.0,12910976.0,12943808.0,12976640.0,13009472.0,12149120.0,12180992.0,12212864.0,12244736.0,12276608.0,12308480.0,12340352.0,12372224.0,11510400.0,11541312.0,11572224.0,11603136.0,11634048.0,11664960.0,11695872.0,11726784.0,11535232.0,11566208.0,11597184.0,11628160.0,11659136.0,11690112.0,11721088.0,11752064.0,11560064.0
,11591104.0,11622144.0,11653184.0,11684224.0,11715264.0,11746304.0,11777344.0,11584896.0,11616000.0,11647104.0,11678208.0,11709312.0,11740416.0,11771520.0,11802624.0,11609728.0,11640896.0,11672064.0,11703232.0,11734400.0,11765568.0,11796736.0,11827904.0,11634560.0,11665792.0,11697024.0,11728256.0,11759488.0,11790720.0,11821952.0,11853184.0,11659392.0,11690688.0,11721984.0,11753280.0,11784576.0,11815872.0,11847168.0,11878464.0,11684224.0,11715584.0,11746944.0,11778304.0,11809664.0,11841024.0,11872384.0,11903744.0,11709056.0,11740480.0,11771904.0,11803328.0,11834752.0,11866176.0,11897600.0,11929024.0,11733888.0,11765376.0,11796864.0,11828352.0,11859840.0,11891328.0,11922816.0,11954304.0,11758720.0,11790272.0,11821824.0,11853376.0,11884928.0,11916480.0,11948032.0,11979584.0,11783552.0,11815168.0,11846784.0,11878400.0,11910016.0,11941632.0,11973248.0,12004864.0,11808384.0,11840064.0,11871744.0,11903424.0,11935104.0,11966784.0,11998464.0,12030144.0,11833216.0,11864960.0,11896704.0,11928448.0,11960192.0,11991936.0,12023680.0,12055424.0,11858048.0,11889856.0,11921664.0,11953472.0,11985280.0,12017088.0,12048896.0,12080704.0,11882880.0,11914752.0,11946624.0,11978496.0,12010368.0,12042240.0,12074112.0,12105984.0,11907712.0,11939648.0,11971584.0,12003520.0,12035456.0,12067392.0,12099328.0,12131264.0,11932544.0,11964544.0,11996544.0,12028544.0,12060544.0,12092544.0,12124544.0,12156544.0,11957376.0,11989440.0,12021504.0,12053568.0,12085632.0,12117696.0,12149760.0,12181824.0,11982208.0,12014336.0,12046464.0,12078592.0,12110720.0,12142848.0,12174976.0,12207104.0,12007040.0,12039232.0,12071424.0,12103616.0,12135808.0,12168000.0,12200192.0,12232384.0,12031872.0,12064128.0,12096384.0,12128640.0,12160896.0,12193152.0,12225408.0,12257664.0,12056704.0,12089024.0,12121344.0,12153664.0,12185984.0,12218304.0,12250624.0,12282944.0,12081536.0,12113920.0,12146304.0,12178688.0,12211072.0,12243456.0,12275840.0,12308224.0,12106368.0,12138816.0,12171264.0,12203712.0,12236160.0,12268608.0,12301056
.0,12333504.0,12131200.0,12163712.0,12196224.0,12228736.0,12261248.0,12293760.0,12326272.0,12358784.0,12156032.0,12188608.0,12221184.0,12253760.0,12286336.0,12318912.0,12351488.0,12384064.0,12180864.0,12213504.0,12246144.0,12278784.0,12311424.0,12344064.0,12376704.0,12409344.0,12205696.0,12238400.0,12271104.0,12303808.0,12336512.0,12369216.0,12401920.0,12434624.0,12230528.0,12263296.0,12296064.0,12328832.0,12361600.0,12394368.0,12427136.0,12459904.0,12255360.0,12288192.0,12321024.0,12353856.0,12386688.0,12419520.0,12452352.0,12485184.0,12280192.0,12313088.0,12345984.0,12378880.0,12411776.0,12444672.0,12477568.0,12510464.0,12305024.0,12337984.0,12370944.0,12403904.0,12436864.0,12469824.0,12502784.0,12535744.0,12329856.0,12362880.0,12395904.0,12428928.0,12461952.0,12494976.0,12528000.0,12561024.0,12354688.0,12387776.0,12420864.0,12453952.0,12487040.0,12520128.0,12553216.0,12586304.0,12379520.0,12412672.0,12445824.0,12478976.0,12512128.0,12545280.0,12578432.0,12611584.0,12404352.0,12437568.0,12470784.0,12504000.0,12537216.0,12570432.0,12603648.0,12636864.0,12429184.0,12462464.0,12495744.0,12529024.0,12562304.0,12595584.0,12628864.0,12662144.0,12322944.0,12355264.0,12387584.0,12419904.0,12452224.0,12484544.0,12516864.0,12549184.0,12208512.0,12239872.0,12271232.0,12302592.0,12333952.0,12365312.0,12396672.0,12428032.0,12233344.0,12264768.0,12296192.0,12327616.0,12359040.0,12390464.0,12421888.0,12453312.0,12258176.0,12289664.0,12321152.0,12352640.0,12384128.0,12415616.0,12447104.0,12478592.0,12283008.0,12314560.0,12346112.0,12377664.0,12409216.0,12440768.0,12472320.0,12503872.0,12307840.0,12339456.0,12371072.0,12402688.0,12434304.0,12465920.0,12497536.0,12529152.0,12332672.0,12364352.0,12396032.0,12427712.0,12459392.0,12491072.0,12522752.0,12554432.0,12357504.0,12389248.0,12420992.0,12452736.0,12484480.0,12516224.0,12547968.0,12579712.0,12382336.0,12414144.0,12445952.0,12477760.0,12509568.0,12541376.0,12573184.0,12604992.0,12407168.0,12439040.0,12470912.0,12502784.0,125346
56.0,12566528.0,12598400.0,12630272.0,12432000.0,12463936.0,12495872.0,12527808.0,12559744.0,12591680.0,12623616.0,12655552.0,12456832.0,12488832.0,12520832.0,12552832.0,12584832.0,12616832.0,12648832.0,12680832.0,12481664.0,12513728.0,12545792.0,12577856.0,12609920.0,12641984.0,12674048.0,12706112.0,12506496.0,12538624.0,12570752.0,12602880.0,12635008.0,12667136.0,12699264.0,12731392.0,12531328.0,12563520.0,12595712.0,12627904.0,12660096.0,12692288.0,12724480.0,12756672.0,12556160.0,12588416.0,12620672.0,12652928.0,12685184.0,12717440.0,12749696.0,12781952.0,12580992.0,12613312.0,12645632.0,12677952.0,12710272.0,12742592.0,12774912.0,12807232.0,12605824.0,12638208.0,12670592.0,12702976.0,12735360.0,12767744.0,12800128.0,12832512.0,12630656.0,12663104.0,12695552.0,12728000.0,12760448.0,12792896.0,12825344.0,12857792.0,12655488.0,12688000.0,12720512.0,12753024.0,12785536.0,12818048.0,12850560.0,12883072.0,12680320.0,12712896.0,12745472.0,12778048.0,12810624.0,12843200.0,12875776.0,12908352.0,12705152.0,12737792.0,12770432.0,12803072.0,12835712.0,12868352.0,12900992.0,12933632.0,12729984.0,12762688.0,12795392.0,12828096.0,12860800.0,12893504.0,12926208.0,12958912.0,12754816.0,12787584.0,12820352.0,12853120.0,12885888.0,12918656.0,12951424.0,12984192.0,12779648.0,12812480.0,12845312.0,12878144.0,12910976.0,12943808.0,12976640.0,13009472.0,12804480.0,12837376.0,12870272.0,12903168.0,12936064.0,12968960.0,13001856.0,13034752.0,12829312.0,12862272.0,12895232.0,12928192.0,12961152.0,12994112.0,13027072.0,13060032.0,12854144.0,12887168.0,12920192.0,12953216.0,12986240.0,13019264.0,13052288.0,13085312.0,12878976.0,12912064.0,12945152.0,12978240.0,13011328.0,13044416.0,13077504.0,13110592.0,12903808.0,12936960.0,12970112.0,13003264.0,13036416.0,13069568.0,13102720.0,13135872.0,12928640.0,12961856.0,12995072.0,13028288.0,13061504.0,13094720.0,13127936.0,13161152.0,12953472.0,12986752.0,13020032.0,13053312.0,13086592.0,13119872.0,13153152.0,13186432.0,12978304.0,13011648.0,1304
4992.0,13078336.0,13111680.0,13145024.0,13178368.0,13211712.0,13003136.0,13036544.0,13069952.0,13103360.0,13136768.0,13170176.0,13203584.0,13236992.0,13027968.0,13061440.0,13094912.0,13128384.0,13161856.0,13195328.0,13228800.0,13262272.0,13052800.0,13086336.0,13119872.0,13153408.0,13186944.0,13220480.0,13254016.0,13287552.0,13077632.0,13111232.0,13144832.0,13178432.0,13212032.0,13245632.0,13279232.0,13312832.0,13102464.0,13136128.0,13169792.0,13203456.0,13237120.0,13270784.0,13304448.0,13338112.0,13127296.0,13161024.0,13194752.0,13228480.0,13262208.0,13295936.0,13329664.0,13363392.0,12496768.0,12529536.0,12562304.0,12595072.0,12627840.0,12660608.0,12693376.0,12726144.0,11858048.0,11889856.0,11921664.0,11953472.0,11985280.0,12017088.0,12048896.0,12080704.0,5998976.0,6014896.0,6030816.0,6046736.0,6062656.0,6078576.0,6094496.0,6110416.0,6599424.0,6615216.0,6631008.0,6646800.0,6662592.0,6678384.0,6694176.0,6709968.0,13337984.0,13369600.0,13401216.0,13432832.0,13464448.0,13496064.0,13527680.0,13559296.0,13363840.0,13395520.0,13427200.0,13458880.0,13490560.0,13522240.0,13553920.0,13585600.0,13389696.0,13421440.0,13453184.0,13484928.0,13516672.0,13548416.0,13580160.0,13611904.0,13415552.0,13447360.0,13479168.0,13510976.0,13542784.0,13574592.0,13606400.0,13638208.0,13441408.0,13473280.0,13505152.0,13537024.0,13568896.0,13600768.0,13632640.0,13664512.0,13467264.0,13499200.0,13531136.0,13563072.0,13595008.0,13626944.0,13658880.0,13690816.0,13493120.0,13525120.0,13557120.0,13589120.0,13621120.0,13653120.0,13685120.0,13717120.0,13518976.0,13551040.0,13583104.0,13615168.0,13647232.0,13679296.0,13711360.0,13743424.0,13544832.0,13576960.0,13609088.0,13641216.0,13673344.0,13705472.0,13737600.0,13769728.0,13570688.0,13602880.0,13635072.0,13667264.0,13699456.0,13731648.0,13763840.0,13796032.0,13596544.0,13628800.0,13661056.0,13693312.0,13725568.0,13757824.0,13790080.0,13822336.0,13622400.0,13654720.0,13687040.0,13719360.0,13751680.0,13784000.0,13816320.0,13848640.0,13648256.0,1368064
0.0,13713024.0,13745408.0,13777792.0,13810176.0,13842560.0,13874944.0,13674112.0,13706560.0,13739008.0,13771456.0,13803904.0,13836352.0,13868800.0,13901248.0,13699968.0,13732480.0,13764992.0,13797504.0,13830016.0,13862528.0,13895040.0,13927552.0,13725824.0,13758400.0,13790976.0,13823552.0,13856128.0,13888704.0,13921280.0,13953856.0,13751680.0,13784320.0,13816960.0,13849600.0,13882240.0,13914880.0,13947520.0,13980160.0,13777536.0,13810240.0,13842944.0,13875648.0,13908352.0,13941056.0,13973760.0,14006464.0,13803392.0,13836160.0,13868928.0,13901696.0,13934464.0,13967232.0,14000000.0,14032768.0,13829248.0,13862080.0,13894912.0,13927744.0,13960576.0,13993408.0,14026240.0,14059072.0,13855104.0,13888000.0,13920896.0,13953792.0,13986688.0,14019584.0,14052480.0,14085376.0,13880960.0,13913920.0,13946880.0,13979840.0,14012800.0,14045760.0,14078720.0,14111680.0,13906816.0,13939840.0,13972864.0,14005888.0,14038912.0,14071936.0,14104960.0,14137984.0,13932672.0,13965760.0,13998848.0,14031936.0,14065024.0,14098112.0,14131200.0,14164288.0,13958528.0,13991680.0,14024832.0,14057984.0,14091136.0,14124288.0,14157440.0,14190592.0,13984384.0,14017600.0,14050816.0,14084032.0,14117248.0,14150464.0,14183680.0,14216896.0,14010240.0,14043520.0,14076800.0,14110080.0,14143360.0,14176640.0,14209920.0,14243200.0,14036096.0,14069440.0,14102784.0,14136128.0,14169472.0,14202816.0,14236160.0,14269504.0,14061952.0,14095360.0,14128768.0,14162176.0,14195584.0,14228992.0,14262400.0,14295808.0,14087808.0,14121280.0,14154752.0,14188224.0,14221696.0,14255168.0,14288640.0,14322112.0,14113664.0,14147200.0,14180736.0,14214272.0,14247808.0,14281344.0,14314880.0,14348416.0,14139520.0,14173120.0,14206720.0,14240320.0,14273920.0,14307520.0,14341120.0,14374720.0,14165376.0,14199040.0,14232704.0,14266368.0,14300032.0,14333696.0,14367360.0,14401024.0,14191232.0,14224960.0,14258688.0,14292416.0,14326144.0,14359872.0,14393600.0,14427328.0,14217088.0,14250880.0,14284672.0,14318464.0,14352256.0,14386048.0,14419840.0,14453
632.0,14242944.0,14276800.0,14310656.0,14344512.0,14378368.0,14412224.0,14446080.0,14479936.0,13597056.0,13629952.0,13662848.0,13695744.0,13728640.0,13761536.0,13794432.0,13827328.0,12942976.0,12974912.0,13006848.0,13038784.0,13070720.0,13102656.0,13134592.0,13166528.0,12297088.0,12328064.0,12359040.0,12390016.0,12420992.0,12451968.0,12482944.0,12513920.0,11643008.0,11673024.0,11703040.0,11733056.0,11763072.0,11793088.0,11823104.0,11853120.0,11668864.0,11698944.0,11729024.0,11759104.0,11789184.0,11819264.0,11849344.0,11879424.0,11694720.0,11724864.0,11755008.0,11785152.0,11815296.0,11845440.0,11875584.0,11905728.0,11720576.0,11750784.0,11780992.0,11811200.0,11841408.0,11871616.0,11901824.0,11932032.0,11746432.0,11776704.0,11806976.0,11837248.0,11867520.0,11897792.0,11928064.0,11958336.0,11772288.0,11802624.0,11832960.0,11863296.0,11893632.0,11923968.0,11954304.0,11984640.0,11798144.0,11828544.0,11858944.0,11889344.0,11919744.0,11950144.0,11980544.0,12010944.0,11824000.0,11854464.0,11884928.0,11915392.0,11945856.0,11976320.0,12006784.0,12037248.0,11849856.0,11880384.0,11910912.0,11941440.0,11971968.0,12002496.0,12033024.0,12063552.0,11875712.0,11906304.0,11936896.0,11967488.0,11998080.0,12028672.0,12059264.0,12089856.0,11901568.0,11932224.0,11962880.0,11993536.0,12024192.0,12054848.0,12085504.0,12116160.0,11927424.0,11958144.0,11988864.0,12019584.0,12050304.0,12081024.0,12111744.0,12142464.0,11953280.0,11984064.0,12014848.0,12045632.0,12076416.0,12107200.0,12137984.0,12168768.0,11979136.0,12009984.0,12040832.0,12071680.0,12102528.0,12133376.0,12164224.0,12195072.0,12004992.0,12035904.0,12066816.0,12097728.0,12128640.0,12159552.0,12190464.0,12221376.0,12030848.0,12061824.0,12092800.0,12123776.0,12154752.0,12185728.0,12216704.0,12247680.0,12056704.0,12087744.0,12118784.0,12149824.0,12180864.0,12211904.0,12242944.0,12273984.0,12082560.0,12113664.0,12144768.0,12175872.0,12206976.0,12238080.0,12269184.0,12300288.0,12108416.0,12139584.0,12170752.0,12201920.0,12233088.0,122
64256.0,12295424.0,12326592.0,12134272.0,12165504.0,12196736.0,12227968.0,12259200.0,12290432.0,12321664.0,12352896.0,12160128.0,12191424.0,12222720.0,12254016.0,12285312.0,12316608.0,12347904.0,12379200.0,12185984.0,12217344.0,12248704.0,12280064.0,12311424.0,12342784.0,12374144.0,12405504.0,12211840.0,12243264.0,12274688.0,12306112.0,12337536.0,12368960.0,12400384.0,12431808.0,12237696.0,12269184.0,12300672.0,12332160.0,12363648.0,12395136.0,12426624.0,12458112.0,12263552.0,12295104.0,12326656.0,12358208.0,12389760.0,12421312.0,12452864.0,12484416.0,12289408.0,12321024.0,12352640.0,12384256.0,12415872.0,12447488.0,12479104.0,12510720.0,12315264.0,12346944.0,12378624.0,12410304.0,12441984.0,12473664.0,12505344.0,12537024.0,12341120.0,12372864.0,12404608.0,12436352.0,12468096.0,12499840.0,12531584.0,12563328.0,12366976.0,12398784.0,12430592.0,12462400.0,12494208.0,12526016.0,12557824.0,12589632.0,12392832.0,12424704.0,12456576.0,12488448.0,12520320.0,12552192.0,12584064.0,12615936.0,12418688.0,12450624.0,12482560.0,12514496.0,12546432.0,12578368.0,12610304.0,12642240.0,12444544.0,12476544.0,12508544.0,12540544.0,12572544.0,12604544.0,12636544.0,12668544.0,12470400.0,12502464.0,12534528.0,12566592.0,12598656.0,12630720.0,12662784.0,12694848.0,12496256.0,12528384.0,12560512.0,12592640.0,12624768.0,12656896.0,12689024.0,12721152.0,12522112.0,12554304.0,12586496.0,12618688.0,12650880.0,12683072.0,12715264.0,12747456.0,12547968.0,12580224.0,12612480.0,12644736.0,12676992.0,12709248.0,12741504.0,12773760.0,12573824.0,12606144.0,12638464.0,12670784.0,12703104.0,12735424.0,12767744.0,12800064.0,12599680.0,12632064.0,12664448.0,12696832.0,12729216.0,12761600.0,12793984.0,12826368.0,12478080.0,12509504.0,12540928.0,12572352.0,12603776.0,12635200.0,12666624.0,12698048.0,12348288.0,12378752.0,12409216.0,12439680.0,12470144.0,12500608.0,12531072.0,12561536.0,12374144.0,12404672.0,12435200.0,12465728.0,12496256.0,12526784.0,12557312.0,12587840.0,12400000.0,12430592.0,12461184.0,1
2491776.0,12522368.0,12552960.0,12583552.0,12614144.0,12425856.0,12456512.0,12487168.0,12517824.0,12548480.0,12579136.0,12609792.0,12640448.0,12451712.0,12482432.0,12513152.0,12543872.0,12574592.0,12605312.0,12636032.0,12666752.0,12477568.0,12508352.0,12539136.0,12569920.0,12600704.0,12631488.0,12662272.0,12693056.0,12503424.0,12534272.0,12565120.0,12595968.0,12626816.0,12657664.0,12688512.0,12719360.0,12529280.0,12560192.0,12591104.0,12622016.0,12652928.0,12683840.0,12714752.0,12745664.0,12555136.0,12586112.0,12617088.0,12648064.0,12679040.0,12710016.0,12740992.0,12771968.0,12580992.0,12612032.0,12643072.0,12674112.0,12705152.0,12736192.0,12767232.0,12798272.0,12606848.0,12637952.0,12669056.0,12700160.0,12731264.0,12762368.0,12793472.0,12824576.0,12632704.0,12663872.0,12695040.0,12726208.0,12757376.0,12788544.0,12819712.0,12850880.0,12658560.0,12689792.0,12721024.0,12752256.0,12783488.0,12814720.0,12845952.0,12877184.0,12684416.0,12715712.0,12747008.0,12778304.0,12809600.0,12840896.0,12872192.0,12903488.0,12710272.0,12741632.0,12772992.0,12804352.0,12835712.0,12867072.0,12898432.0,12929792.0,12736128.0,12767552.0,12798976.0,12830400.0,12861824.0,12893248.0,12924672.0,12956096.0,12761984.0,12793472.0,12824960.0,12856448.0,12887936.0,12919424.0,12950912.0,12982400.0,12787840.0,12819392.0,12850944.0,12882496.0,12914048.0,12945600.0,12977152.0,13008704.0,12813696.0,12845312.0,12876928.0,12908544.0,12940160.0,12971776.0,13003392.0,13035008.0,12839552.0,12871232.0,12902912.0,12934592.0,12966272.0,12997952.0,13029632.0,13061312.0,12865408.0,12897152.0,12928896.0,12960640.0,12992384.0,13024128.0,13055872.0,13087616.0,12891264.0,12923072.0,12954880.0,12986688.0,13018496.0,13050304.0,13082112.0,13113920.0,12917120.0,12948992.0,12980864.0,13012736.0,13044608.0,13076480.0,13108352.0,13140224.0,12942976.0,12974912.0,13006848.0,13038784.0,13070720.0,13102656.0,13134592.0,13166528.0,12968832.0,13000832.0,13032832.0,13064832.0,13096832.0,13128832.0,13160832.0,13192832.0,12994688.0
,13026752.0,13058816.0,13090880.0,13122944.0,13155008.0,13187072.0,13219136.0,13020544.0,13052672.0,13084800.0,13116928.0,13149056.0,13181184.0,13213312.0,13245440.0,13046400.0,13078592.0,13110784.0,13142976.0,13175168.0,13207360.0,13239552.0,13271744.0,13072256.0,13104512.0,13136768.0,13169024.0,13201280.0,13233536.0,13265792.0,13298048.0,13098112.0,13130432.0,13162752.0,13195072.0,13227392.0,13259712.0,13292032.0,13324352.0,13123968.0,13156352.0,13188736.0,13221120.0,13253504.0,13285888.0,13318272.0,13350656.0,13149824.0,13182272.0,13214720.0,13247168.0,13279616.0,13312064.0,13344512.0,13376960.0,13175680.0,13208192.0,13240704.0,13273216.0,13305728.0,13338240.0,13370752.0,13403264.0,13201536.0,13234112.0,13266688.0,13299264.0,13331840.0,13364416.0,13396992.0,13429568.0,13227392.0,13260032.0,13292672.0,13325312.0,13357952.0,13390592.0,13423232.0,13455872.0,13253248.0,13285952.0,13318656.0,13351360.0,13384064.0,13416768.0,13449472.0,13482176.0,13279104.0,13311872.0,13344640.0,13377408.0,13410176.0,13442944.0,13475712.0,13508480.0,13304960.0,13337792.0,13370624.0,13403456.0,13436288.0,13469120.0,13501952.0,13534784.0,12659072.0,12690944.0,12722816.0,12754688.0,12786560.0,12818432.0,12850304.0,12882176.0,12004992.0,12035904.0,12066816.0,12097728.0,12128640.0,12159552.0,12190464.0,12221376.0,12030848.0,12061824.0,12092800.0,12123776.0,12154752.0,12185728.0,12216704.0,12247680.0,12056704.0,12087744.0,12118784.0,12149824.0,12180864.0,12211904.0,12242944.0,12273984.0,12082560.0,12113664.0,12144768.0,12175872.0,12206976.0,12238080.0,12269184.0,12300288.0,12108416.0,12139584.0,12170752.0,12201920.0,12233088.0,12264256.0,12295424.0,12326592.0,12134272.0,12165504.0,12196736.0,12227968.0,12259200.0,12290432.0,12321664.0,12352896.0,12160128.0,12191424.0,12222720.0,12254016.0,12285312.0,12316608.0,12347904.0,12379200.0,12185984.0,12217344.0,12248704.0,12280064.0,12311424.0,12342784.0,12374144.0,12405504.0,12211840.0,12243264.0,12274688.0,12306112.0,12337536.0,12368960.0,12400384
.0,12431808.0,12237696.0,12269184.0,12300672.0,12332160.0,12363648.0,12395136.0,12426624.0,12458112.0,12263552.0,12295104.0,12326656.0,12358208.0,12389760.0,12421312.0,12452864.0,12484416.0,12289408.0,12321024.0,12352640.0,12384256.0,12415872.0,12447488.0,12479104.0,12510720.0,12315264.0,12346944.0,12378624.0,12410304.0,12441984.0,12473664.0,12505344.0,12537024.0,12341120.0,12372864.0,12404608.0,12436352.0,12468096.0,12499840.0,12531584.0,12563328.0,12366976.0,12398784.0,12430592.0,12462400.0,12494208.0,12526016.0,12557824.0,12589632.0,12392832.0,12424704.0,12456576.0,12488448.0,12520320.0,12552192.0,12584064.0,12615936.0,12418688.0,12450624.0,12482560.0,12514496.0,12546432.0,12578368.0,12610304.0,12642240.0,12444544.0,12476544.0,12508544.0,12540544.0,12572544.0,12604544.0,12636544.0,12668544.0,12470400.0,12502464.0,12534528.0,12566592.0,12598656.0,12630720.0,12662784.0,12694848.0,12496256.0,12528384.0,12560512.0,12592640.0,12624768.0,12656896.0,12689024.0,12721152.0,12522112.0,12554304.0,12586496.0,12618688.0,12650880.0,12683072.0,12715264.0,12747456.0,12547968.0,12580224.0,12612480.0,12644736.0,12676992.0,12709248.0,12741504.0,12773760.0,12573824.0,12606144.0,12638464.0,12670784.0,12703104.0,12735424.0,12767744.0,12800064.0,12599680.0,12632064.0,12664448.0,12696832.0,12729216.0,12761600.0,12793984.0,12826368.0,12625536.0,12657984.0,12690432.0,12722880.0,12755328.0,12787776.0,12820224.0,12852672.0,12651392.0,12683904.0,12716416.0,12748928.0,12781440.0,12813952.0,12846464.0,12878976.0,12677248.0,12709824.0,12742400.0,12774976.0,12807552.0,12840128.0,12872704.0,12905280.0,12703104.0,12735744.0,12768384.0,12801024.0,12833664.0,12866304.0,12898944.0,12931584.0,12728960.0,12761664.0,12794368.0,12827072.0,12859776.0,12892480.0,12925184.0,12957888.0,12754816.0,12787584.0,12820352.0,12853120.0,12885888.0,12918656.0,12951424.0,12984192.0,12780672.0,12813504.0,12846336.0,12879168.0,12912000.0,12944832.0,12977664.0,13010496.0,12806528.0,12839424.0,12872320.0,12905216.0,129381
12.0,12971008.0,13003904.0,13036800.0,12832384.0,12865344.0,12898304.0,12931264.0,12964224.0,12997184.0,13030144.0,13063104.0,12858240.0,12891264.0,12924288.0,12957312.0,12990336.0,13023360.0,13056384.0,13089408.0,12884096.0,12917184.0,12950272.0,12983360.0,13016448.0,13049536.0,13082624.0,13115712.0,12909952.0,12943104.0,12976256.0,13009408.0,13042560.0,13075712.0,13108864.0,13142016.0,12935808.0,12969024.0,13002240.0,13035456.0,13068672.0,13101888.0,13135104.0,13168320.0,12961664.0,12994944.0,13028224.0,13061504.0,13094784.0,13128064.0,13161344.0,13194624.0,12840064.0,12872384.0,12904704.0,12937024.0,12969344.0,13001664.0,13033984.0,13066304.0,12710272.0,12741632.0,12772992.0,12804352.0,12835712.0,12867072.0,12898432.0,12929792.0,12736128.0,12767552.0,12798976.0,12830400.0,12861824.0,12893248.0,12924672.0,12956096.0,12761984.0,12793472.0,12824960.0,12856448.0,12887936.0,12919424.0,12950912.0,12982400.0,12787840.0,12819392.0,12850944.0,12882496.0,12914048.0,12945600.0,12977152.0,13008704.0,12813696.0,12845312.0,12876928.0,12908544.0,12940160.0,12971776.0,13003392.0,13035008.0,12839552.0,12871232.0,12902912.0,12934592.0,12966272.0,12997952.0,13029632.0,13061312.0,12865408.0,12897152.0,12928896.0,12960640.0,12992384.0,13024128.0,13055872.0,13087616.0,12891264.0,12923072.0,12954880.0,12986688.0,13018496.0,13050304.0,13082112.0,13113920.0,12917120.0,12948992.0,12980864.0,13012736.0,13044608.0,13076480.0,13108352.0,13140224.0,12942976.0,12974912.0,13006848.0,13038784.0,13070720.0,13102656.0,13134592.0,13166528.0,12968832.0,13000832.0,13032832.0,13064832.0,13096832.0,13128832.0,13160832.0,13192832.0,12994688.0,13026752.0,13058816.0,13090880.0,13122944.0,13155008.0,13187072.0,13219136.0,13020544.0,13052672.0,13084800.0,13116928.0,13149056.0,13181184.0,13213312.0,13245440.0,13046400.0,13078592.0,13110784.0,13142976.0,13175168.0,13207360.0,13239552.0,13271744.0,13072256.0,13104512.0,13136768.0,13169024.0,13201280.0,13233536.0,13265792.0,13298048.0,13098112.0,13130432.0,1316
2752.0,13195072.0,13227392.0,13259712.0,13292032.0,13324352.0,13123968.0,13156352.0,13188736.0,13221120.0,13253504.0,13285888.0,13318272.0,13350656.0,13149824.0,13182272.0,13214720.0,13247168.0,13279616.0,13312064.0,13344512.0,13376960.0,13175680.0,13208192.0,13240704.0,13273216.0,13305728.0,13338240.0,13370752.0,13403264.0,13201536.0,13234112.0,13266688.0,13299264.0,13331840.0,13364416.0,13396992.0,13429568.0,13227392.0,13260032.0,13292672.0,13325312.0,13357952.0,13390592.0,13423232.0,13455872.0,13253248.0,13285952.0,13318656.0,13351360.0,13384064.0,13416768.0,13449472.0,13482176.0,13279104.0,13311872.0,13344640.0,13377408.0,13410176.0,13442944.0,13475712.0,13508480.0,13304960.0,13337792.0,13370624.0,13403456.0,13436288.0,13469120.0,13501952.0,13534784.0,13330816.0,13363712.0,13396608.0,13429504.0,13462400.0,13495296.0,13528192.0,13561088.0,13356672.0,13389632.0,13422592.0,13455552.0,13488512.0,13521472.0,13554432.0,13587392.0,13382528.0,13415552.0,13448576.0,13481600.0,13514624.0,13547648.0,13580672.0,13613696.0,13408384.0,13441472.0,13474560.0,13507648.0,13540736.0,13573824.0,13606912.0,13640000.0,13434240.0,13467392.0,13500544.0,13533696.0,13566848.0,13600000.0,13633152.0,13666304.0,13460096.0,13493312.0,13526528.0,13559744.0,13592960.0,13626176.0,13659392.0,13692608.0,13485952.0,13519232.0,13552512.0,13585792.0,13619072.0,13652352.0,13685632.0,13718912.0,13511808.0,13545152.0,13578496.0,13611840.0,13645184.0,13678528.0,13711872.0,13745216.0,13537664.0,13571072.0,13604480.0,13637888.0,13671296.0,13704704.0,13738112.0,13771520.0,13563520.0,13596992.0,13630464.0,13663936.0,13697408.0,13730880.0,13764352.0,13797824.0,13589376.0,13622912.0,13656448.0,13689984.0,13723520.0,13757056.0,13790592.0,13824128.0,13615232.0,13648832.0,13682432.0,13716032.0,13749632.0,13783232.0,13816832.0,13850432.0,13641088.0,13674752.0,13708416.0,13742080.0,13775744.0,13809408.0,13843072.0,13876736.0,13666944.0,13700672.0,13734400.0,13768128.0,13801856.0,13835584.0,13869312.0,13903040.0,13
021056.0,13053824.0,13086592.0,13119360.0,13152128.0,13184896.0,13217664.0,13250432.0,12366976.0,12398784.0,12430592.0,12462400.0,12494208.0,12526016.0,12557824.0,12589632.0,6253696.0,6269616.0,6285536.0,6301456.0,6317376.0,6333296.0,6349216.0,6365136.0,6852096.0,6867888.0,6883680.0,6899472.0,6915264.0,6931056.0,6946848.0,6962640.0,13843840.0,13875456.0,13907072.0,13938688.0,13970304.0,14001920.0,14033536.0,14065152.0,13870720.0,13902400.0,13934080.0,13965760.0,13997440.0,14029120.0,14060800.0,14092480.0,13897600.0,13929344.0,13961088.0,13992832.0,14024576.0,14056320.0,14088064.0,14119808.0,13924480.0,13956288.0,13988096.0,14019904.0,14051712.0,14083520.0,14115328.0,14147136.0,13951360.0,13983232.0,14015104.0,14046976.0,14078848.0,14110720.0,14142592.0,14174464.0,13978240.0,14010176.0,14042112.0,14074048.0,14105984.0,14137920.0,14169856.0,14201792.0,14005120.0,14037120.0,14069120.0,14101120.0,14133120.0,14165120.0,14197120.0,14229120.0,14032000.0,14064064.0,14096128.0,14128192.0,14160256.0,14192320.0,14224384.0,14256448.0,14058880.0,14091008.0,14123136.0,14155264.0,14187392.0,14219520.0,14251648.0,14283776.0,14085760.0,14117952.0,14150144.0,14182336.0,14214528.0,14246720.0,14278912.0,14311104.0,14112640.0,14144896.0,14177152.0,14209408.0,14241664.0,14273920.0,14306176.0,14338432.0,14139520.0,14171840.0,14204160.0,14236480.0,14268800.0,14301120.0,14333440.0,14365760.0,14166400.0,14198784.0,14231168.0,14263552.0,14295936.0,14328320.0,14360704.0,14393088.0,14193280.0,14225728.0,14258176.0,14290624.0,14323072.0,14355520.0,14387968.0,14420416.0,14220160.0,14252672.0,14285184.0,14317696.0,14350208.0,14382720.0,14415232.0,14447744.0,14247040.0,14279616.0,14312192.0,14344768.0,14377344.0,14409920.0,14442496.0,14475072.0,14273920.0,14306560.0,14339200.0,14371840.0,14404480.0,14437120.0,14469760.0,14502400.0,14300800.0,14333504.0,14366208.0,14398912.0,14431616.0,14464320.0,14497024.0,14529728.0,14327680.0,14360448.0,14393216.0,14425984.0,14458752.0,14491520.0,14524288.0,14557
056.0,14354560.0,14387392.0,14420224.0,14453056.0,14485888.0,14518720.0,14551552.0,14584384.0,14381440.0,14414336.0,14447232.0,14480128.0,14513024.0,14545920.0,14578816.0,14611712.0,14408320.0,14441280.0,14474240.0,14507200.0,14540160.0,14573120.0,14606080.0,14639040.0,14435200.0,14468224.0,14501248.0,14534272.0,14567296.0,14600320.0,14633344.0,14666368.0,14462080.0,14495168.0,14528256.0,14561344.0,14594432.0,14627520.0,14660608.0,14693696.0,14488960.0,14522112.0,14555264.0,14588416.0,14621568.0,14654720.0,14687872.0,14721024.0,14515840.0,14549056.0,14582272.0,14615488.0,14648704.0,14681920.0,14715136.0,14748352.0,14542720.0,14576000.0,14609280.0,14642560.0,14675840.0,14709120.0,14742400.0,14775680.0,14569600.0,14602944.0,14636288.0,14669632.0,14702976.0,14736320.0,14769664.0,14803008.0,14596480.0,14629888.0,14663296.0,14696704.0,14730112.0,14763520.0,14796928.0,14830336.0,14623360.0,14656832.0,14690304.0,14723776.0,14757248.0,14790720.0,14824192.0,14857664.0,14650240.0,14683776.0,14717312.0,14750848.0,14784384.0,14817920.0,14851456.0,14884992.0,14677120.0,14710720.0,14744320.0,14777920.0,14811520.0,14845120.0,14878720.0,14912320.0,14704000.0,14737664.0,14771328.0,14804992.0,14838656.0,14872320.0,14905984.0,14939648.0,14730880.0,14764608.0,14798336.0,14832064.0,14865792.0,14899520.0,14933248.0,14966976.0,14757760.0,14791552.0,14825344.0,14859136.0,14892928.0,14926720.0,14960512.0,14994304.0,14784640.0,14818496.0,14852352.0,14886208.0,14920064.0,14953920.0,14987776.0,15021632.0,14123392.0,14156288.0,14189184.0,14222080.0,14254976.0,14287872.0,14320768.0,14353664.0,13453952.0,13485888.0,13517824.0,13549760.0,13581696.0,13613632.0,13645568.0,13677504.0,12792704.0,12823680.0,12854656.0,12885632.0,12916608.0,12947584.0,12978560.0,13009536.0,12123264.0,12153280.0,12183296.0,12213312.0,12243328.0,12273344.0,12303360.0,12333376.0,12150144.0,12180224.0,12210304.0,12240384.0,12270464.0,12300544.0,12330624.0,12360704.0,12177024.0,12207168.0,12237312.0,12267456.0,12297600.0,123
27744.0,12357888.0,12388032.0,12203904.0,12234112.0,12264320.0,12294528.0,12324736.0,12354944.0,12385152.0,12415360.0,12230784.0,12261056.0,12291328.0,12321600.0,12351872.0,12382144.0,12412416.0,12442688.0,12257664.0,12288000.0,12318336.0,12348672.0,12379008.0,12409344.0,12439680.0,12470016.0,12284544.0,12314944.0,12345344.0,12375744.0,12406144.0,12436544.0,12466944.0,12497344.0,12311424.0,12341888.0,12372352.0,12402816.0,12433280.0,12463744.0,12494208.0,12524672.0,12338304.0,12368832.0,12399360.0,12429888.0,12460416.0,12490944.0,12521472.0,12552000.0,12365184.0,12395776.0,12426368.0,12456960.0,12487552.0,12518144.0,12548736.0,12579328.0,12392064.0,12422720.0,12453376.0,12484032.0,12514688.0,12545344.0,12576000.0,12606656.0,12418944.0,12449664.0,12480384.0,12511104.0,12541824.0,12572544.0,12603264.0,12633984.0,12445824.0,12476608.0,12507392.0,12538176.0,12568960.0,12599744.0,12630528.0,12661312.0,12472704.0,12503552.0,12534400.0,12565248.0,12596096.0,12626944.0,12657792.0,12688640.0,12499584.0,12530496.0,12561408.0,12592320.0,12623232.0,12654144.0,12685056.0,12715968.0,12526464.0,12557440.0,12588416.0,12619392.0,12650368.0,12681344.0,12712320.0,12743296.0,12553344.0,12584384.0,12615424.0,12646464.0,12677504.0,12708544.0,12739584.0,12770624.0,12580224.0,12611328.0,12642432.0,12673536.0,12704640.0,12735744.0,12766848.0,12797952.0,12607104.0,12638272.0,12669440.0,12700608.0,12731776.0,12762944.0,12794112.0,12825280.0,12633984.0,12665216.0,12696448.0,12727680.0,12758912.0,12790144.0,12821376.0,12852608.0,12660864.0,12692160.0,12723456.0,12754752.0,12786048.0,12817344.0,12848640.0,12879936.0,12687744.0,12719104.0,12750464.0,12781824.0,12813184.0,12844544.0,12875904.0,12907264.0,12714624.0,12746048.0,12777472.0,12808896.0,12840320.0,12871744.0,12903168.0,12934592.0,12741504.0,12772992.0,12804480.0,12835968.0,12867456.0,12898944.0,12930432.0,12961920.0,12768384.0,12799936.0,12831488.0,12863040.0,12894592.0,12926144.0,12957696.0,12989248.0,12795264.0,12826880.0,12858496.0,1
2890112.0,12921728.0,12953344.0,12984960.0,13016576.0,12822144.0,12853824.0,12885504.0,12917184.0,12948864.0,12980544.0,13012224.0,13043904.0,12849024.0,12880768.0,12912512.0,12944256.0,12976000.0,13007744.0,13039488.0,13071232.0,12875904.0,12907712.0,12939520.0,12971328.0,13003136.0,13034944.0,13066752.0,13098560.0,12902784.0,12934656.0,12966528.0,12998400.0,13030272.0,13062144.0,13094016.0,13125888.0,12929664.0,12961600.0,12993536.0,13025472.0,13057408.0,13089344.0,13121280.0,13153216.0,12956544.0,12988544.0,13020544.0,13052544.0,13084544.0,13116544.0,13148544.0,13180544.0,12983424.0,13015488.0,13047552.0,13079616.0,13111680.0,13143744.0,13175808.0,13207872.0,13010304.0,13042432.0,13074560.0,13106688.0,13138816.0,13170944.0,13203072.0,13235200.0,13037184.0,13069376.0,13101568.0,13133760.0,13165952.0,13198144.0,13230336.0,13262528.0,13064064.0,13096320.0,13128576.0,13160832.0,13193088.0,13225344.0,13257600.0,13289856.0,13090944.0,13123264.0,13155584.0,13187904.0,13220224.0,13252544.0,13284864.0,13317184.0,13117824.0,13150208.0,13182592.0,13214976.0,13247360.0,13279744.0,13312128.0,13344512.0,12980864.0,13012288.0,13043712.0,13075136.0,13106560.0,13137984.0,13169408.0,13200832.0,12835712.0,12866176.0,12896640.0,12927104.0,12957568.0,12988032.0,13018496.0,13048960.0,12862592.0,12893120.0,12923648.0,12954176.0,12984704.0,13015232.0,13045760.0,13076288.0,12889472.0,12920064.0,12950656.0,12981248.0,13011840.0,13042432.0,13073024.0,13103616.0,12916352.0,12947008.0,12977664.0,13008320.0,13038976.0,13069632.0,13100288.0,13130944.0,12943232.0,12973952.0,13004672.0,13035392.0,13066112.0,13096832.0,13127552.0,13158272.0,12970112.0,13000896.0,13031680.0,13062464.0,13093248.0,13124032.0,13154816.0,13185600.0,12996992.0,13027840.0,13058688.0,13089536.0,13120384.0,13151232.0,13182080.0,13212928.0,13023872.0,13054784.0,13085696.0,13116608.0,13147520.0,13178432.0,13209344.0,13240256.0,13050752.0,13081728.0,13112704.0,13143680.0,13174656.0,13205632.0,13236608.0,13267584.0,13077632.0
,13108672.0,13139712.0,13170752.0,13201792.0,13232832.0,13263872.0,13294912.0,13104512.0,13135616.0,13166720.0,13197824.0,13228928.0,13260032.0,13291136.0,13322240.0,13131392.0,13162560.0,13193728.0,13224896.0,13256064.0,13287232.0,13318400.0,13349568.0,13158272.0,13189504.0,13220736.0,13251968.0,13283200.0,13314432.0,13345664.0,13376896.0,13185152.0,13216448.0,13247744.0,13279040.0,13310336.0,13341632.0,13372928.0,13404224.0,13212032.0,13243392.0,13274752.0,13306112.0,13337472.0,13368832.0,13400192.0,13431552.0,13238912.0,13270336.0,13301760.0,13333184.0,13364608.0,13396032.0,13427456.0,13458880.0,13265792.0,13297280.0,13328768.0,13360256.0,13391744.0,13423232.0,13454720.0,13486208.0,13292672.0,13324224.0,13355776.0,13387328.0,13418880.0,13450432.0,13481984.0,13513536.0,13319552.0,13351168.0,13382784.0,13414400.0,13446016.0,13477632.0,13509248.0,13540864.0,13346432.0,13378112.0,13409792.0,13441472.0,13473152.0,13504832.0,13536512.0,13568192.0,13373312.0,13405056.0,13436800.0,13468544.0,13500288.0,13532032.0,13563776.0,13595520.0,13400192.0,13432000.0,13463808.0,13495616.0,13527424.0,13559232.0,13591040.0,13622848.0,13427072.0,13458944.0,13490816.0,13522688.0,13554560.0,13586432.0,13618304.0,13650176.0,13453952.0,13485888.0,13517824.0,13549760.0,13581696.0,13613632.0,13645568.0,13677504.0,13480832.0,13512832.0,13544832.0,13576832.0,13608832.0,13640832.0,13672832.0,13704832.0,13507712.0,13539776.0,13571840.0,13603904.0,13635968.0,13668032.0,13700096.0,13732160.0,13534592.0,13566720.0,13598848.0,13630976.0,13663104.0,13695232.0,13727360.0,13759488.0,13561472.0,13593664.0,13625856.0,13658048.0,13690240.0,13722432.0,13754624.0,13786816.0,13588352.0,13620608.0,13652864.0,13685120.0,13717376.0,13749632.0,13781888.0,13814144.0,13615232.0,13647552.0,13679872.0,13712192.0,13744512.0,13776832.0,13809152.0,13841472.0,13642112.0,13674496.0,13706880.0,13739264.0,13771648.0,13804032.0,13836416.0,13868800.0,13668992.0,13701440.0,13733888.0,13766336.0,13798784.0,13831232.0,13863680
.0,13896128.0,13695872.0,13728384.0,13760896.0,13793408.0,13825920.0,13858432.0,13890944.0,13923456.0,13722752.0,13755328.0,13787904.0,13820480.0,13853056.0,13885632.0,13918208.0,13950784.0,13749632.0,13782272.0,13814912.0,13847552.0,13880192.0,13912832.0,13945472.0,13978112.0,13776512.0,13809216.0,13841920.0,13874624.0,13907328.0,13940032.0,13972736.0,14005440.0,13803392.0,13836160.0,13868928.0,13901696.0,13934464.0,13967232.0,14000000.0,14032768.0,13830272.0,13863104.0,13895936.0,13928768.0,13961600.0,13994432.0,14027264.0,14060096.0,13169024.0,13200896.0,13232768.0,13264640.0,13296512.0,13328384.0,13360256.0,13392128.0,12499584.0,12530496.0,12561408.0,12592320.0,12623232.0,12654144.0,12685056.0,12715968.0,12526464.0,12557440.0,12588416.0,12619392.0,12650368.0,12681344.0,12712320.0,12743296.0,12553344.0,12584384.0,12615424.0,12646464.0,12677504.0,12708544.0,12739584.0,12770624.0,12580224.0,12611328.0,12642432.0,12673536.0,12704640.0,12735744.0,12766848.0,12797952.0,12607104.0,12638272.0,12669440.0,12700608.0,12731776.0,12762944.0,12794112.0,12825280.0,12633984.0,12665216.0,12696448.0,12727680.0,12758912.0,12790144.0,12821376.0,12852608.0,12660864.0,12692160.0,12723456.0,12754752.0,12786048.0,12817344.0,12848640.0,12879936.0,12687744.0,12719104.0,12750464.0,12781824.0,12813184.0,12844544.0,12875904.0,12907264.0,12714624.0,12746048.0,12777472.0,12808896.0,12840320.0,12871744.0,12903168.0,12934592.0,12741504.0,12772992.0,12804480.0,12835968.0,12867456.0,12898944.0,12930432.0,12961920.0,12768384.0,12799936.0,12831488.0,12863040.0,12894592.0,12926144.0,12957696.0,12989248.0,12795264.0,12826880.0,12858496.0,12890112.0,12921728.0,12953344.0,12984960.0,13016576.0,12822144.0,12853824.0,12885504.0,12917184.0,12948864.0,12980544.0,13012224.0,13043904.0,12849024.0,12880768.0,12912512.0,12944256.0,12976000.0,13007744.0,13039488.0,13071232.0,12875904.0,12907712.0,12939520.0,12971328.0,13003136.0,13034944.0,13066752.0,13098560.0,12902784.0,12934656.0,12966528.0,12998400.0,130302
72.0,13062144.0,13094016.0,13125888.0,12929664.0,12961600.0,12993536.0,13025472.0,13057408.0,13089344.0,13121280.0,13153216.0,12956544.0,12988544.0,13020544.0,13052544.0,13084544.0,13116544.0,13148544.0,13180544.0,12983424.0,13015488.0,13047552.0,13079616.0,13111680.0,13143744.0,13175808.0,13207872.0,13010304.0,13042432.0,13074560.0,13106688.0,13138816.0,13170944.0,13203072.0,13235200.0,13037184.0,13069376.0,13101568.0,13133760.0,13165952.0,13198144.0,13230336.0,13262528.0,13064064.0,13096320.0,13128576.0,13160832.0,13193088.0,13225344.0,13257600.0,13289856.0,13090944.0,13123264.0,13155584.0,13187904.0,13220224.0,13252544.0,13284864.0,13317184.0,13117824.0,13150208.0,13182592.0,13214976.0,13247360.0,13279744.0,13312128.0,13344512.0,13144704.0,13177152.0,13209600.0,13242048.0,13274496.0,13306944.0,13339392.0,13371840.0,13171584.0,13204096.0,13236608.0,13269120.0,13301632.0,13334144.0,13366656.0,13399168.0,13198464.0,13231040.0,13263616.0,13296192.0,13328768.0,13361344.0,13393920.0,13426496.0,13225344.0,13257984.0,13290624.0,13323264.0,13355904.0,13388544.0,13421184.0,13453824.0,13252224.0,13284928.0,13317632.0,13350336.0,13383040.0,13415744.0,13448448.0,13481152.0,13279104.0,13311872.0,13344640.0,13377408.0,13410176.0,13442944.0,13475712.0,13508480.0,13305984.0,13338816.0,13371648.0,13404480.0,13437312.0,13470144.0,13502976.0,13535808.0,13332864.0,13365760.0,13398656.0,13431552.0,13464448.0,13497344.0,13530240.0,13563136.0,13359744.0,13392704.0,13425664.0,13458624.0,13491584.0,13524544.0,13557504.0,13590464.0,13386624.0,13419648.0,13452672.0,13485696.0,13518720.0,13551744.0,13584768.0,13617792.0,13413504.0,13446592.0,13479680.0,13512768.0,13545856.0,13578944.0,13612032.0,13645120.0,13440384.0,13473536.0,13506688.0,13539840.0,13572992.0,13606144.0,13639296.0,13672448.0,13467264.0,13500480.0,13533696.0,13566912.0,13600128.0,13633344.0,13666560.0,13699776.0,13494144.0,13527424.0,13560704.0,13593984.0,13627264.0,13660544.0,13693824.0,13727104.0,13357184.0,13389504.0,1342
1824.0,13454144.0,13486464.0,13518784.0,13551104.0,13583424.0,13212032.0,13243392.0,13274752.0,13306112.0,13337472.0,13368832.0,13400192.0,13431552.0,13238912.0,13270336.0,13301760.0,13333184.0,13364608.0,13396032.0,13427456.0,13458880.0,13265792.0,13297280.0,13328768.0,13360256.0,13391744.0,13423232.0,13454720.0,13486208.0,13292672.0,13324224.0,13355776.0,13387328.0,13418880.0,13450432.0,13481984.0,13513536.0,13319552.0,13351168.0,13382784.0,13414400.0,13446016.0,13477632.0,13509248.0,13540864.0,13346432.0,13378112.0,13409792.0,13441472.0,13473152.0,13504832.0,13536512.0,13568192.0,13373312.0,13405056.0,13436800.0,13468544.0,13500288.0,13532032.0,13563776.0,13595520.0,13400192.0,13432000.0,13463808.0,13495616.0,13527424.0,13559232.0,13591040.0,13622848.0,13427072.0,13458944.0,13490816.0,13522688.0,13554560.0,13586432.0,13618304.0,13650176.0,13453952.0,13485888.0,13517824.0,13549760.0,13581696.0,13613632.0,13645568.0,13677504.0,13480832.0,13512832.0,13544832.0,13576832.0,13608832.0,13640832.0,13672832.0,13704832.0,13507712.0,13539776.0,13571840.0,13603904.0,13635968.0,13668032.0,13700096.0,13732160.0,13534592.0,13566720.0,13598848.0,13630976.0,13663104.0,13695232.0,13727360.0,13759488.0,13561472.0,13593664.0,13625856.0,13658048.0,13690240.0,13722432.0,13754624.0,13786816.0,13588352.0,13620608.0,13652864.0,13685120.0,13717376.0,13749632.0,13781888.0,13814144.0,13615232.0,13647552.0,13679872.0,13712192.0,13744512.0,13776832.0,13809152.0,13841472.0,13642112.0,13674496.0,13706880.0,13739264.0,13771648.0,13804032.0,13836416.0,13868800.0,13668992.0,13701440.0,13733888.0,13766336.0,13798784.0,13831232.0,13863680.0,13896128.0,13695872.0,13728384.0,13760896.0,13793408.0,13825920.0,13858432.0,13890944.0,13923456.0,13722752.0,13755328.0,13787904.0,13820480.0,13853056.0,13885632.0,13918208.0,13950784.0,13749632.0,13782272.0,13814912.0,13847552.0,13880192.0,13912832.0,13945472.0,13978112.0,13776512.0,13809216.0,13841920.0,13874624.0,13907328.0,13940032.0,13972736.0,14005440.0,13
803392.0,13836160.0,13868928.0,13901696.0,13934464.0,13967232.0,14000000.0,14032768.0,13830272.0,13863104.0,13895936.0,13928768.0,13961600.0,13994432.0,14027264.0,14060096.0,13857152.0,13890048.0,13922944.0,13955840.0,13988736.0,14021632.0,14054528.0,14087424.0,13884032.0,13916992.0,13949952.0,13982912.0,14015872.0,14048832.0,14081792.0,14114752.0,13910912.0,13943936.0,13976960.0,14009984.0,14043008.0,14076032.0,14109056.0,14142080.0,13937792.0,13970880.0,14003968.0,14037056.0,14070144.0,14103232.0,14136320.0,14169408.0,13964672.0,13997824.0,14030976.0,14064128.0,14097280.0,14130432.0,14163584.0,14196736.0,13991552.0,14024768.0,14057984.0,14091200.0,14124416.0,14157632.0,14190848.0,14224064.0,14018432.0,14051712.0,14084992.0,14118272.0,14151552.0,14184832.0,14218112.0,14251392.0,14045312.0,14078656.0,14112000.0,14145344.0,14178688.0,14212032.0,14245376.0,14278720.0,14072192.0,14105600.0,14139008.0,14172416.0,14205824.0,14239232.0,14272640.0,14306048.0,14099072.0,14132544.0,14166016.0,14199488.0,14232960.0,14266432.0,14299904.0,14333376.0,14125952.0,14159488.0,14193024.0,14226560.0,14260096.0,14293632.0,14327168.0,14360704.0,14152832.0,14186432.0,14220032.0,14253632.0,14287232.0,14320832.0,14354432.0,14388032.0,14179712.0,14213376.0,14247040.0,14280704.0,14314368.0,14348032.0,14381696.0,14415360.0,14206592.0,14240320.0,14274048.0,14307776.0,14341504.0,14375232.0,14408960.0,14442688.0,13545344.0,13578112.0,13610880.0,13643648.0,13676416.0,13709184.0,13741952.0,13774720.0,12875904.0,12907712.0,12939520.0,12971328.0,13003136.0,13034944.0,13066752.0,13098560.0,6508416.0,6524336.0,6540256.0,6556176.0,6572096.0,6588016.0,6603936.0,6619856.0,7104768.0,7120560.0,7136352.0,7152144.0,7167936.0,7183728.0,7199520.0,7215312.0,14349696.0,14381312.0,14412928.0,14444544.0,14476160.0,14507776.0,14539392.0,14571008.0,14377600.0,14409280.0,14440960.0,14472640.0,14504320.0,14536000.0,14567680.0,14599360.0,14405504.0,14437248.0,14468992.0,14500736.0,14532480.0,14564224.0,14595968.0,14627
712.0,14433408.0,14465216.0,14497024.0,14528832.0,14560640.0,14592448.0,14624256.0,14656064.0,14461312.0,14493184.0,14525056.0,14556928.0,14588800.0,14620672.0,14652544.0,14684416.0,14489216.0,14521152.0,14553088.0,14585024.0,14616960.0,14648896.0,14680832.0,14712768.0,14517120.0,14549120.0,14581120.0,14613120.0,14645120.0,14677120.0,14709120.0,14741120.0,14545024.0,14577088.0,14609152.0,14641216.0,14673280.0,14705344.0,14737408.0,14769472.0,14572928.0,14605056.0,14637184.0,14669312.0,14701440.0,14733568.0,14765696.0,14797824.0,14600832.0,14633024.0,14665216.0,14697408.0,14729600.0,14761792.0,14793984.0,14826176.0,14628736.0,14660992.0,14693248.0,14725504.0,14757760.0,14790016.0,14822272.0,14854528.0,14656640.0,14688960.0,14721280.0,14753600.0,14785920.0,14818240.0,14850560.0,14882880.0,14684544.0,14716928.0,14749312.0,14781696.0,14814080.0,14846464.0,14878848.0,14911232.0,14712448.0,14744896.0,14777344.0,14809792.0,14842240.0,14874688.0,14907136.0,14939584.0,14740352.0,14772864.0,14805376.0,14837888.0,14870400.0,14902912.0,14935424.0,14967936.0,14768256.0,14800832.0,14833408.0,14865984.0,14898560.0,14931136.0,14963712.0,14996288.0,14796160.0,14828800.0,14861440.0,14894080.0,14926720.0,14959360.0,14992000.0,15024640.0,14824064.0,14856768.0,14889472.0,14922176.0,14954880.0,14987584.0,15020288.0,15052992.0,14851968.0,14884736.0,14917504.0,14950272.0,14983040.0,15015808.0,15048576.0,15081344.0,14879872.0,14912704.0,14945536.0,14978368.0,15011200.0,15044032.0,15076864.0,15109696.0,14907776.0,14940672.0,14973568.0,15006464.0,15039360.0,15072256.0,15105152.0,15138048.0,14935680.0,14968640.0,15001600.0,15034560.0,15067520.0,15100480.0,15133440.0,15166400.0,14963584.0,14996608.0,15029632.0,15062656.0,15095680.0,15128704.0,15161728.0,15194752.0,14991488.0,15024576.0,15057664.0,15090752.0,15123840.0,15156928.0,15190016.0,15223104.0,15019392.0,15052544.0,15085696.0,15118848.0,15152000.0,15185152.0,15218304.0,15251456.0,15047296.0,15080512.0,15113728.0,15146944.0,15180160.0,152
13376.0,15246592.0,15279808.0,15075200.0,15108480.0,15141760.0,15175040.0,15208320.0,15241600.0,15274880.0,15308160.0,15103104.0,15136448.0,15169792.0,15203136.0,15236480.0,15269824.0,15303168.0,15336512.0,15131008.0,15164416.0,15197824.0,15231232.0,15264640.0,15298048.0,15331456.0,15364864.0,15158912.0,15192384.0,15225856.0,15259328.0,15292800.0,15326272.0,15359744.0,15393216.0,15186816.0,15220352.0,15253888.0,15287424.0,15320960.0,15354496.0,15388032.0,15421568.0,15214720.0,15248320.0,15281920.0,15315520.0,15349120.0,15382720.0,15416320.0,15449920.0,15242624.0,15276288.0,15309952.0,15343616.0,15377280.0,15410944.0,15444608.0,15478272.0,15270528.0,15304256.0,15337984.0,15371712.0,15405440.0,15439168.0,15472896.0,15506624.0,15298432.0,15332224.0,15366016.0,15399808.0,15433600.0,15467392.0,15501184.0,15534976.0,15326336.0,15360192.0,15394048.0,15427904.0,15461760.0,15495616.0,15529472.0,15563328.0,14649728.0,14682624.0,14715520.0,14748416.0,14781312.0,14814208.0,14847104.0,14880000.0,13964928.0,13996864.0,14028800.0,14060736.0,14092672.0,14124608.0,14156544.0,14188480.0,13288320.0,13319296.0,13350272.0,13381248.0,13412224.0,13443200.0,13474176.0,13505152.0,12603520.0,12633536.0,12663552.0,12693568.0,12723584.0,12753600.0,12783616.0,12813632.0,12631424.0,12661504.0,12691584.0,12721664.0,12751744.0,12781824.0,12811904.0,12841984.0,12659328.0,12689472.0,12719616.0,12749760.0,12779904.0,12810048.0,12840192.0,12870336.0,12687232.0,12717440.0,12747648.0,12777856.0,12808064.0,12838272.0,12868480.0,12898688.0,12715136.0,12745408.0,12775680.0,12805952.0,12836224.0,12866496.0,12896768.0,12927040.0,12743040.0,12773376.0,12803712.0,12834048.0,12864384.0,12894720.0,12925056.0,12955392.0,12770944.0,12801344.0,12831744.0,12862144.0,12892544.0,12922944.0,12953344.0,12983744.0,12798848.0,12829312.0,12859776.0,12890240.0,12920704.0,12951168.0,12981632.0,13012096.0,12826752.0,12857280.0,12887808.0,12918336.0,12948864.0,12979392.0,13009920.0,13040448.0,12854656.0,12885248.0,12915840.0,1
2946432.0,12977024.0,13007616.0,13038208.0,13068800.0,12882560.0,12913216.0,12943872.0,12974528.0,13005184.0,13035840.0,13066496.0,13097152.0,12910464.0,12941184.0,12971904.0,13002624.0,13033344.0,13064064.0,13094784.0,13125504.0,12938368.0,12969152.0,12999936.0,13030720.0,13061504.0,13092288.0,13123072.0,13153856.0,12966272.0,12997120.0,13027968.0,13058816.0,13089664.0,13120512.0,13151360.0,13182208.0,12994176.0,13025088.0,13056000.0,13086912.0,13117824.0,13148736.0,13179648.0,13210560.0,13022080.0,13053056.0,13084032.0,13115008.0,13145984.0,13176960.0,13207936.0,13238912.0,13049984.0,13081024.0,13112064.0,13143104.0,13174144.0,13205184.0,13236224.0,13267264.0,13077888.0,13108992.0,13140096.0,13171200.0,13202304.0,13233408.0,13264512.0,13295616.0,13105792.0,13136960.0,13168128.0,13199296.0,13230464.0,13261632.0,13292800.0,13323968.0,13133696.0,13164928.0,13196160.0,13227392.0,13258624.0,13289856.0,13321088.0,13352320.0,13161600.0,13192896.0,13224192.0,13255488.0,13286784.0,13318080.0,13349376.0,13380672.0,13189504.0,13220864.0,13252224.0,13283584.0,13314944.0,13346304.0,13377664.0,13409024.0,13217408.0,13248832.0,13280256.0,13311680.0,13343104.0,13374528.0,13405952.0,13437376.0,13245312.0,13276800.0,13308288.0,13339776.0,13371264.0,13402752.0,13434240.0,13465728.0,13273216.0,13304768.0,13336320.0,13367872.0,13399424.0,13430976.0,13462528.0,13494080.0,13301120.0,13332736.0,13364352.0,13395968.0,13427584.0,13459200.0,13490816.0,13522432.0,13329024.0,13360704.0,13392384.0,13424064.0,13455744.0,13487424.0,13519104.0,13550784.0,13356928.0,13388672.0,13420416.0,13452160.0,13483904.0,13515648.0,13547392.0,13579136.0,13384832.0,13416640.0,13448448.0,13480256.0,13512064.0,13543872.0,13575680.0,13607488.0,13412736.0,13444608.0,13476480.0,13508352.0,13540224.0,13572096.0,13603968.0,13635840.0,13440640.0,13472576.0,13504512.0,13536448.0,13568384.0,13600320.0,13632256.0,13664192.0,13468544.0,13500544.0,13532544.0,13564544.0,13596544.0,13628544.0,13660544.0,13692544.0,13496448.0
,13528512.0,13560576.0,13592640.0,13624704.0,13656768.0,13688832.0,13720896.0,13524352.0,13556480.0,13588608.0,13620736.0,13652864.0,13684992.0,13717120.0,13749248.0,13552256.0,13584448.0,13616640.0,13648832.0,13681024.0,13713216.0,13745408.0,13777600.0,13580160.0,13612416.0,13644672.0,13676928.0,13709184.0,13741440.0,13773696.0,13805952.0,13608064.0,13640384.0,13672704.0,13705024.0,13737344.0,13769664.0,13801984.0,13834304.0,13635968.0,13668352.0,13700736.0,13733120.0,13765504.0,13797888.0,13830272.0,13862656.0,13483648.0,13515072.0,13546496.0,13577920.0,13609344.0,13640768.0,13672192.0,13703616.0,13323136.0,13353600.0,13384064.0,13414528.0,13444992.0,13475456.0,13505920.0,13536384.0,13351040.0,13381568.0,13412096.0,13442624.0,13473152.0,13503680.0,13534208.0,13564736.0,13378944.0,13409536.0,13440128.0,13470720.0,13501312.0,13531904.0,13562496.0,13593088.0,13406848.0,13437504.0,13468160.0,13498816.0,13529472.0,13560128.0,13590784.0,13621440.0,13434752.0,13465472.0,13496192.0,13526912.0,13557632.0,13588352.0,13619072.0,13649792.0,13462656.0,13493440.0,13524224.0,13555008.0,13585792.0,13616576.0,13647360.0,13678144.0,13490560.0,13521408.0,13552256.0,13583104.0,13613952.0,13644800.0,13675648.0,13706496.0,13518464.0,13549376.0,13580288.0,13611200.0,13642112.0,13673024.0,13703936.0,13734848.0,13546368.0,13577344.0,13608320.0,13639296.0,13670272.0,13701248.0,13732224.0,13763200.0,13574272.0,13605312.0,13636352.0,13667392.0,13698432.0,13729472.0,13760512.0,13791552.0,13602176.0,13633280.0,13664384.0,13695488.0,13726592.0,13757696.0,13788800.0,13819904.0,13630080.0,13661248.0,13692416.0,13723584.0,13754752.0,13785920.0,13817088.0,13848256.0,13657984.0,13689216.0,13720448.0,13751680.0,13782912.0,13814144.0,13845376.0,13876608.0,13685888.0,13717184.0,13748480.0,13779776.0,13811072.0,13842368.0,13873664.0,13904960.0,13713792.0,13745152.0,13776512.0,13807872.0,13839232.0,13870592.0,13901952.0,13933312.0,13741696.0,13773120.0,13804544.0,13835968.0,13867392.0,13898816.0,13930240
.0,13961664.0,13769600.0,13801088.0,13832576.0,13864064.0,13895552.0,13927040.0,13958528.0,13990016.0,13797504.0,13829056.0,13860608.0,13892160.0,13923712.0,13955264.0,13986816.0,14018368.0,13825408.0,13857024.0,13888640.0,13920256.0,13951872.0,13983488.0,14015104.0,14046720.0,13853312.0,13884992.0,13916672.0,13948352.0,13980032.0,14011712.0,14043392.0,14075072.0,13881216.0,13912960.0,13944704.0,13976448.0,14008192.0,14039936.0,14071680.0,14103424.0,13909120.0,13940928.0,13972736.0,14004544.0,14036352.0,14068160.0,14099968.0,14131776.0,13937024.0,13968896.0,14000768.0,14032640.0,14064512.0,14096384.0,14128256.0,14160128.0,13964928.0,13996864.0,14028800.0,14060736.0,14092672.0,14124608.0,14156544.0,14188480.0,13992832.0,14024832.0,14056832.0,14088832.0,14120832.0,14152832.0,14184832.0,14216832.0,14020736.0,14052800.0,14084864.0,14116928.0,14148992.0,14181056.0,14213120.0,14245184.0,14048640.0,14080768.0,14112896.0,14145024.0,14177152.0,14209280.0,14241408.0,14273536.0,14076544.0,14108736.0,14140928.0,14173120.0,14205312.0,14237504.0,14269696.0,14301888.0,14104448.0,14136704.0,14168960.0,14201216.0,14233472.0,14265728.0,14297984.0,14330240.0,14132352.0,14164672.0,14196992.0,14229312.0,14261632.0,14293952.0,14326272.0,14358592.0,14160256.0,14192640.0,14225024.0,14257408.0,14289792.0,14322176.0,14354560.0,14386944.0,14188160.0,14220608.0,14253056.0,14285504.0,14317952.0,14350400.0,14382848.0,14415296.0,14216064.0,14248576.0,14281088.0,14313600.0,14346112.0,14378624.0,14411136.0,14443648.0,14243968.0,14276544.0,14309120.0,14341696.0,14374272.0,14406848.0,14439424.0,14472000.0,14271872.0,14304512.0,14337152.0,14369792.0,14402432.0,14435072.0,14467712.0,14500352.0,14299776.0,14332480.0,14365184.0,14397888.0,14430592.0,14463296.0,14496000.0,14528704.0,14327680.0,14360448.0,14393216.0,14425984.0,14458752.0,14491520.0,14524288.0,14557056.0,14355584.0,14388416.0,14421248.0,14454080.0,14486912.0,14519744.0,14552576.0,14585408.0,13678976.0,13710848.0,13742720.0,13774592.0,138064
64.0,13838336.0,13870208.0,13902080.0,12994176.0,13025088.0,13056000.0,13086912.0,13117824.0,13148736.0,13179648.0,13210560.0,13022080.0,13053056.0,13084032.0,13115008.0,13145984.0,13176960.0,13207936.0,13238912.0,13049984.0,13081024.0,13112064.0,13143104.0,13174144.0,13205184.0,13236224.0,13267264.0,13077888.0,13108992.0,13140096.0,13171200.0,13202304.0,13233408.0,13264512.0,13295616.0,13105792.0,13136960.0,13168128.0,13199296.0,13230464.0,13261632.0,13292800.0,13323968.0,13133696.0,13164928.0,13196160.0,13227392.0,13258624.0,13289856.0,13321088.0,13352320.0,13161600.0,13192896.0,13224192.0,13255488.0,13286784.0,13318080.0,13349376.0,13380672.0,13189504.0,13220864.0,13252224.0,13283584.0,13314944.0,13346304.0,13377664.0,13409024.0,13217408.0,13248832.0,13280256.0,13311680.0,13343104.0,13374528.0,13405952.0,13437376.0,13245312.0,13276800.0,13308288.0,13339776.0,13371264.0,13402752.0,13434240.0,13465728.0,13273216.0,13304768.0,13336320.0,13367872.0,13399424.0,13430976.0,13462528.0,13494080.0,13301120.0,13332736.0,13364352.0,13395968.0,13427584.0,13459200.0,13490816.0,13522432.0,13329024.0,13360704.0,13392384.0,13424064.0,13455744.0,13487424.0,13519104.0,13550784.0,13356928.0,13388672.0,13420416.0,13452160.0,13483904.0,13515648.0,13547392.0,13579136.0,13384832.0,13416640.0,13448448.0,13480256.0,13512064.0,13543872.0,13575680.0,13607488.0,13412736.0,13444608.0,13476480.0,13508352.0,13540224.0,13572096.0,13603968.0,13635840.0,13440640.0,13472576.0,13504512.0,13536448.0,13568384.0,13600320.0,13632256.0,13664192.0,13468544.0,13500544.0,13532544.0,13564544.0,13596544.0,13628544.0,13660544.0,13692544.0,13496448.0,13528512.0,13560576.0,13592640.0,13624704.0,13656768.0,13688832.0,13720896.0,13524352.0,13556480.0,13588608.0,13620736.0,13652864.0,13684992.0,13717120.0,13749248.0,13552256.0,13584448.0,13616640.0,13648832.0,13681024.0,13713216.0,13745408.0,13777600.0,13580160.0,13612416.0,13644672.0,13676928.0,13709184.0,13741440.0,13773696.0,13805952.0,13608064.0,13640384.0,1367
2704.0,13705024.0,13737344.0,13769664.0,13801984.0,13834304.0,13635968.0,13668352.0,13700736.0,13733120.0,13765504.0,13797888.0,13830272.0,13862656.0,13663872.0,13696320.0,13728768.0,13761216.0,13793664.0,13826112.0,13858560.0,13891008.0,13691776.0,13724288.0,13756800.0,13789312.0,13821824.0,13854336.0,13886848.0,13919360.0,13719680.0,13752256.0,13784832.0,13817408.0,13849984.0,13882560.0,13915136.0,13947712.0,13747584.0,13780224.0,13812864.0,13845504.0,13878144.0,13910784.0,13943424.0,13976064.0,13775488.0,13808192.0,13840896.0,13873600.0,13906304.0,13939008.0,13971712.0,14004416.0,13803392.0,13836160.0,13868928.0,13901696.0,13934464.0,13967232.0,14000000.0,14032768.0,13831296.0,13864128.0,13896960.0,13929792.0,13962624.0,13995456.0,14028288.0,14061120.0,13859200.0,13892096.0,13924992.0,13957888.0,13990784.0,14023680.0,14056576.0,14089472.0,13887104.0,13920064.0,13953024.0,13985984.0,14018944.0,14051904.0,14084864.0,14117824.0,13915008.0,13948032.0,13981056.0,14014080.0,14047104.0,14080128.0,14113152.0,14146176.0,13942912.0,13976000.0,14009088.0,14042176.0,14075264.0,14108352.0,14141440.0,14174528.0,13970816.0,14003968.0,14037120.0,14070272.0,14103424.0,14136576.0,14169728.0,14202880.0,13998720.0,14031936.0,14065152.0,14098368.0,14131584.0,14164800.0,14198016.0,14231232.0,14026624.0,14059904.0,14093184.0,14126464.0,14159744.0,14193024.0,14226304.0,14259584.0,13874304.0,13906624.0,13938944.0,13971264.0,14003584.0,14035904.0,14068224.0,14100544.0,13713792.0,13745152.0,13776512.0,13807872.0,13839232.0,13870592.0,13901952.0,13933312.0,13741696.0,13773120.0,13804544.0,13835968.0,13867392.0,13898816.0,13930240.0,13961664.0,13769600.0,13801088.0,13832576.0,13864064.0,13895552.0,13927040.0,13958528.0,13990016.0,13797504.0,13829056.0,13860608.0,13892160.0,13923712.0,13955264.0,13986816.0,14018368.0,13825408.0,13857024.0,13888640.0,13920256.0,13951872.0,13983488.0,14015104.0,14046720.0,13853312.0,13884992.0,13916672.0,13948352.0,13980032.0,14011712.0,14043392.0,14075072.0,13
881216.0,13912960.0,13944704.0,13976448.0,14008192.0,14039936.0,14071680.0,14103424.0,13909120.0,13940928.0,13972736.0,14004544.0,14036352.0,14068160.0,14099968.0,14131776.0,13937024.0,13968896.0,14000768.0,14032640.0,14064512.0,14096384.0,14128256.0,14160128.0,13964928.0,13996864.0,14028800.0,14060736.0,14092672.0,14124608.0,14156544.0,14188480.0,13992832.0,14024832.0,14056832.0,14088832.0,14120832.0,14152832.0,14184832.0,14216832.0,14020736.0,14052800.0,14084864.0,14116928.0,14148992.0,14181056.0,14213120.0,14245184.0,14048640.0,14080768.0,14112896.0,14145024.0,14177152.0,14209280.0,14241408.0,14273536.0,14076544.0,14108736.0,14140928.0,14173120.0,14205312.0,14237504.0,14269696.0,14301888.0,14104448.0,14136704.0,14168960.0,14201216.0,14233472.0,14265728.0,14297984.0,14330240.0,14132352.0,14164672.0,14196992.0,14229312.0,14261632.0,14293952.0,14326272.0,14358592.0,14160256.0,14192640.0,14225024.0,14257408.0,14289792.0,14322176.0,14354560.0,14386944.0,14188160.0,14220608.0,14253056.0,14285504.0,14317952.0,14350400.0,14382848.0,14415296.0,14216064.0,14248576.0,14281088.0,14313600.0,14346112.0,14378624.0,14411136.0,14443648.0,14243968.0,14276544.0,14309120.0,14341696.0,14374272.0,14406848.0,14439424.0,14472000.0,14271872.0,14304512.0,14337152.0,14369792.0,14402432.0,14435072.0,14467712.0,14500352.0,14299776.0,14332480.0,14365184.0,14397888.0,14430592.0,14463296.0,14496000.0,14528704.0,14327680.0,14360448.0,14393216.0,14425984.0,14458752.0,14491520.0,14524288.0,14557056.0,14355584.0,14388416.0,14421248.0,14454080.0,14486912.0,14519744.0,14552576.0,14585408.0,14383488.0,14416384.0,14449280.0,14482176.0,14515072.0,14547968.0,14580864.0,14613760.0,14411392.0,14444352.0,14477312.0,14510272.0,14543232.0,14576192.0,14609152.0,14642112.0,14439296.0,14472320.0,14505344.0,14538368.0,14571392.0,14604416.0,14637440.0,14670464.0,14467200.0,14500288.0,14533376.0,14566464.0,14599552.0,14632640.0,14665728.0,14698816.0,14495104.0,14528256.0,14561408.0,14594560.0,14627712.0,14660864.0,
14694016.0,14727168.0,14523008.0,14556224.0,14589440.0,14622656.0,14655872.0,14689088.0,14722304.0,14755520.0,14550912.0,14584192.0,14617472.0,14650752.0,14684032.0,14717312.0,14750592.0,14783872.0,14578816.0,14612160.0,14645504.0,14678848.0,14712192.0,14745536.0,14778880.0,14812224.0,14606720.0,14640128.0,14673536.0,14706944.0,14740352.0,14773760.0,14807168.0,14840576.0,14634624.0,14668096.0,14701568.0,14735040.0,14768512.0,14801984.0,14835456.0,14868928.0,14662528.0,14696064.0,14729600.0,14763136.0,14796672.0,14830208.0,14863744.0,14897280.0,14690432.0,14724032.0,14757632.0,14791232.0,14824832.0,14858432.0,14892032.0,14925632.0,14718336.0,14752000.0,14785664.0,14819328.0,14852992.0,14886656.0,14920320.0,14953984.0,14746240.0,14779968.0,14813696.0,14847424.0,14881152.0,14914880.0,14948608.0,14982336.0,14069632.0,14102400.0,14135168.0,14167936.0,14200704.0,14233472.0,14266240.0,14299008.0,13384832.0,13416640.0,13448448.0,13480256.0,13512064.0,13543872.0,13575680.0,13607488.0,6763136.0,6779056.0,6794976.0,6810896.0,6826816.0,6842736.0,6858656.0,6874576.0,7357440.0,7373232.0,7389024.0,7404816.0,7420608.0,7436400.0,7452192.0,7467984.0,14855552.0,14887168.0,14918784.0,14950400.0,14982016.0,15013632.0,15045248.0,15076864.0,14884480.0,14916160.0,14947840.0,14979520.0,15011200.0,15042880.0,15074560.0,15106240.0,14913408.0,14945152.0,14976896.0,15008640.0,15040384.0,15072128.0,15103872.0,15135616.0,14942336.0,14974144.0,15005952.0,15037760.0,15069568.0,15101376.0,15133184.0,15164992.0,14971264.0,15003136.0,15035008.0,15066880.0,15098752.0,15130624.0,15162496.0,15194368.0,15000192.0,15032128.0,15064064.0,15096000.0,15127936.0,15159872.0,15191808.0,15223744.0,15029120.0,15061120.0,15093120.0,15125120.0,15157120.0,15189120.0,15221120.0,15253120.0,15058048.0,15090112.0,15122176.0,15154240.0,15186304.0,15218368.0,15250432.0,15282496.0,15086976.0,15119104.0,15151232.0,15183360.0,15215488.0,15247616.0,15279744.0,15311872.0,15115904.0,15148096.0,15180288.0,15212480.0,15244672.0,152
76864.0,15309056.0,15341248.0,15144832.0,15177088.0,15209344.0,15241600.0,15273856.0,15306112.0,15338368.0,15370624.0,15173760.0,15206080.0,15238400.0,15270720.0,15303040.0,15335360.0,15367680.0,15400000.0,15202688.0,15235072.0,15267456.0,15299840.0,15332224.0,15364608.0,15396992.0,15429376.0,15231616.0,15264064.0,15296512.0,15328960.0,15361408.0,15393856.0,15426304.0,15458752.0,15260544.0,15293056.0,15325568.0,15358080.0,15390592.0,15423104.0,15455616.0,15488128.0,15289472.0,15322048.0,15354624.0,15387200.0,15419776.0,15452352.0,15484928.0,15517504.0,15318400.0,15351040.0,15383680.0,15416320.0,15448960.0,15481600.0,15514240.0,15546880.0,15347328.0,15380032.0,15412736.0,15445440.0,15478144.0,15510848.0,15543552.0,15576256.0,15376256.0,15409024.0,15441792.0,15474560.0,15507328.0,15540096.0,15572864.0,15605632.0,15405184.0,15438016.0,15470848.0,15503680.0,15536512.0,15569344.0,15602176.0,15635008.0,15434112.0,15467008.0,15499904.0,15532800.0,15565696.0,15598592.0,15631488.0,15664384.0,15463040.0,15496000.0,15528960.0,15561920.0,15594880.0,15627840.0,15660800.0,15693760.0,15491968.0,15524992.0,15558016.0,15591040.0,15624064.0,15657088.0,15690112.0,15723136.0,15520896.0,15553984.0,15587072.0,15620160.0,15653248.0,15686336.0,15719424.0,15752512.0,15549824.0,15582976.0,15616128.0,15649280.0,15682432.0,15715584.0,15748736.0,15781888.0,15578752.0,15611968.0,15645184.0,15678400.0,15711616.0,15744832.0,15778048.0,15811264.0,15607680.0,15640960.0,15674240.0,15707520.0,15740800.0,15774080.0,15807360.0,15840640.0,15636608.0,15669952.0,15703296.0,15736640.0,15769984.0,15803328.0,15836672.0,15870016.0,15665536.0,15698944.0,15732352.0,15765760.0,15799168.0,15832576.0,15865984.0,15899392.0,15694464.0,15727936.0,15761408.0,15794880.0,15828352.0,15861824.0,15895296.0,15928768.0,15723392.0,15756928.0,15790464.0,15824000.0,15857536.0,15891072.0,15924608.0,15958144.0,15752320.0,15785920.0,15819520.0,15853120.0,15886720.0,15920320.0,15953920.0,15987520.0,15781248.0,15814912.0,15848576.0,1
5882240.0,15915904.0,15949568.0,15983232.0,16016896.0,15810176.0,15843904.0,15877632.0,15911360.0,15945088.0,15978816.0,16012544.0,16046272.0,15839104.0,15872896.0,15906688.0,15940480.0,15974272.0,16008064.0,16041856.0,16075648.0,15868032.0,15901888.0,15935744.0,15969600.0,16003456.0,16037312.0,16071168.0,16105024.0,15176064.0,15208960.0,15241856.0,15274752.0,15307648.0,15340544.0,15373440.0,15406336.0,14475904.0,14507840.0,14539776.0,14571712.0,14603648.0,14635584.0,14667520.0,14699456.0,13783936.0,13814912.0,13845888.0,13876864.0,13907840.0,13938816.0,13969792.0,14000768.0,13083776.0,13113792.0,13143808.0,13173824.0,13203840.0,13233856.0,13263872.0,13293888.0,13112704.0,13142784.0,13172864.0,13202944.0,13233024.0,13263104.0,13293184.0,13323264.0,13141632.0,13171776.0,13201920.0,13232064.0,13262208.0,13292352.0,13322496.0,13352640.0,13170560.0,13200768.0,13230976.0,13261184.0,13291392.0,13321600.0,13351808.0,13382016.0,13199488.0,13229760.0,13260032.0,13290304.0,13320576.0,13350848.0,13381120.0,13411392.0,13228416.0,13258752.0,13289088.0,13319424.0,13349760.0,13380096.0,13410432.0,13440768.0,13257344.0,13287744.0,13318144.0,13348544.0,13378944.0,13409344.0,13439744.0,13470144.0,13286272.0,13316736.0,13347200.0,13377664.0,13408128.0,13438592.0,13469056.0,13499520.0,13315200.0,13345728.0,13376256.0,13406784.0,13437312.0,13467840.0,13498368.0,13528896.0,13344128.0,13374720.0,13405312.0,13435904.0,13466496.0,13497088.0,13527680.0,13558272.0,13373056.0,13403712.0,13434368.0,13465024.0,13495680.0,13526336.0,13556992.0,13587648.0,13401984.0,13432704.0,13463424.0,13494144.0,13524864.0,13555584.0,13586304.0,13617024.0,13430912.0,13461696.0,13492480.0,13523264.0,13554048.0,13584832.0,13615616.0,13646400.0,13459840.0,13490688.0,13521536.0,13552384.0,13583232.0,13614080.0,13644928.0,13675776.0,13488768.0,13519680.0,13550592.0,13581504.0,13612416.0,13643328.0,13674240.0,13705152.0,13517696.0,13548672.0,13579648.0,13610624.0,13641600.0,13672576.0,13703552.0,13734528.0,13546624.0
,13577664.0,13608704.0,13639744.0,13670784.0,13701824.0,13732864.0,13763904.0,13575552.0,13606656.0,13637760.0,13668864.0,13699968.0,13731072.0,13762176.0,13793280.0,13604480.0,13635648.0,13666816.0,13697984.0,13729152.0,13760320.0,13791488.0,13822656.0,13633408.0,13664640.0,13695872.0,13727104.0,13758336.0,13789568.0,13820800.0,13852032.0,13662336.0,13693632.0,13724928.0,13756224.0,13787520.0,13818816.0,13850112.0,13881408.0,13691264.0,13722624.0,13753984.0,13785344.0,13816704.0,13848064.0,13879424.0,13910784.0,13720192.0,13751616.0,13783040.0,13814464.0,13845888.0,13877312.0,13908736.0,13940160.0,13749120.0,13780608.0,13812096.0,13843584.0,13875072.0,13906560.0,13938048.0,13969536.0,13778048.0,13809600.0,13841152.0,13872704.0,13904256.0,13935808.0,13967360.0,13998912.0,13806976.0,13838592.0,13870208.0,13901824.0,13933440.0,13965056.0,13996672.0,14028288.0,13835904.0,13867584.0,13899264.0,13930944.0,13962624.0,13994304.0,14025984.0,14057664.0,13864832.0,13896576.0,13928320.0,13960064.0,13991808.0,14023552.0,14055296.0,14087040.0,13893760.0,13925568.0,13957376.0,13989184.0,14020992.0,14052800.0,14084608.0,14116416.0,13922688.0,13954560.0,13986432.0,14018304.0,14050176.0,14082048.0,14113920.0,14145792.0,13951616.0,13983552.0,14015488.0,14047424.0,14079360.0,14111296.0,14143232.0,14175168.0,13980544.0,14012544.0,14044544.0,14076544.0,14108544.0,14140544.0,14172544.0,14204544.0,14009472.0,14041536.0,14073600.0,14105664.0,14137728.0,14169792.0,14201856.0,14233920.0,14038400.0,14070528.0,14102656.0,14134784.0,14166912.0,14199040.0,14231168.0,14263296.0,14067328.0,14099520.0,14131712.0,14163904.0,14196096.0,14228288.0,14260480.0,14292672.0,14096256.0,14128512.0,14160768.0,14193024.0,14225280.0,14257536.0,14289792.0,14322048.0,14125184.0,14157504.0,14189824.0,14222144.0,14254464.0,14286784.0,14319104.0,14351424.0,14154112.0,14186496.0,14218880.0,14251264.0,14283648.0,14316032.0,14348416.0,14380800.0,13986432.0,14017856.0,14049280.0,14080704.0,14112128.0,14143552.0,14174976
.0,14206400.0,13810560.0,13841024.0,13871488.0,13901952.0,13932416.0,13962880.0,13993344.0,14023808.0,13839488.0,13870016.0,13900544.0,13931072.0,13961600.0,13992128.0,14022656.0,14053184.0,13868416.0,13899008.0,13929600.0,13960192.0,13990784.0,14021376.0,14051968.0,14082560.0,13897344.0,13928000.0,13958656.0,13989312.0,14019968.0,14050624.0,14081280.0,14111936.0,13926272.0,13956992.0,13987712.0,14018432.0,14049152.0,14079872.0,14110592.0,14141312.0,13955200.0,13985984.0,14016768.0,14047552.0,14078336.0,14109120.0,14139904.0,14170688.0,13984128.0,14014976.0,14045824.0,14076672.0,14107520.0,14138368.0,14169216.0,14200064.0,14013056.0,14043968.0,14074880.0,14105792.0,14136704.0,14167616.0,14198528.0,14229440.0,14041984.0,14072960.0,14103936.0,14134912.0,14165888.0,14196864.0,14227840.0,14258816.0,14070912.0,14101952.0,14132992.0,14164032.0,14195072.0,14226112.0,14257152.0,14288192.0,14099840.0,14130944.0,14162048.0,14193152.0,14224256.0,14255360.0,14286464.0,14317568.0,14128768.0,14159936.0,14191104.0,14222272.0,14253440.0,14284608.0,14315776.0,14346944.0,14157696.0,14188928.0,14220160.0,14251392.0,14282624.0,14313856.0,14345088.0,14376320.0,14186624.0,14217920.0,14249216.0,14280512.0,14311808.0,14343104.0,14374400.0,14405696.0,14215552.0,14246912.0,14278272.0,14309632.0,14340992.0,14372352.0,14403712.0,14435072.0,14244480.0,14275904.0,14307328.0,14338752.0,14370176.0,14401600.0,14433024.0,14464448.0,14273408.0,14304896.0,14336384.0,14367872.0,14399360.0,14430848.0,14462336.0,14493824.0,14302336.0,14333888.0,14365440.0,14396992.0,14428544.0,14460096.0,14491648.0,14523200.0,14331264.0,14362880.0,14394496.0,14426112.0,14457728.0,14489344.0,14520960.0,14552576.0,14360192.0,14391872.0,14423552.0,14455232.0,14486912.0,14518592.0,14550272.0,14581952.0,14389120.0,14420864.0,14452608.0,14484352.0,14516096.0,14547840.0,14579584.0,14611328.0,14418048.0,14449856.0,14481664.0,14513472.0,14545280.0,14577088.0,14608896.0,14640704.0,14446976.0,14478848.0,14510720.0,14542592.0,145744
64.0,14606336.0,14638208.0,14670080.0,14475904.0,14507840.0,14539776.0,14571712.0,14603648.0,14635584.0,14667520.0,14699456.0,14504832.0,14536832.0,14568832.0,14600832.0,14632832.0,14664832.0,14696832.0,14728832.0,14533760.0,14565824.0,14597888.0,14629952.0,14662016.0,14694080.0,14726144.0,14758208.0,14562688.0,14594816.0,14626944.0,14659072.0,14691200.0,14723328.0,14755456.0,14787584.0,14591616.0,14623808.0,14656000.0,14688192.0,14720384.0,14752576.0,14784768.0,14816960.0,14620544.0,14652800.0,14685056.0,14717312.0,14749568.0,14781824.0,14814080.0,14846336.0,14649472.0,14681792.0,14714112.0,14746432.0,14778752.0,14811072.0,14843392.0,14875712.0,14678400.0,14710784.0,14743168.0,14775552.0,14807936.0,14840320.0,14872704.0,14905088.0,14707328.0,14739776.0,14772224.0,14804672.0,14837120.0,14869568.0,14902016.0,14934464.0,14736256.0,14768768.0,14801280.0,14833792.0,14866304.0,14898816.0,14931328.0,14963840.0,14765184.0,14797760.0,14830336.0,14862912.0,14895488.0,14928064.0,14960640.0,14993216.0,14794112.0,14826752.0,14859392.0,14892032.0,14924672.0,14957312.0,14989952.0,15022592.0,14823040.0,14855744.0,14888448.0,14921152.0,14953856.0,14986560.0,15019264.0,15051968.0,14851968.0,14884736.0,14917504.0,14950272.0,14983040.0,15015808.0,15048576.0,15081344.0,14880896.0,14913728.0,14946560.0,14979392.0,15012224.0,15045056.0,15077888.0,15110720.0,14188928.0,14220800.0,14252672.0,14284544.0,14316416.0,14348288.0,14380160.0,14412032.0,13488768.0,13519680.0,13550592.0,13581504.0,13612416.0,13643328.0,13674240.0,13705152.0,13517696.0,13548672.0,13579648.0,13610624.0,13641600.0,13672576.0,13703552.0,13734528.0,13546624.0,13577664.0,13608704.0,13639744.0,13670784.0,13701824.0,13732864.0,13763904.0,13575552.0,13606656.0,13637760.0,13668864.0,13699968.0,13731072.0,13762176.0,13793280.0,13604480.0,13635648.0,13666816.0,13697984.0,13729152.0,13760320.0,13791488.0,13822656.0,13633408.0,13664640.0,13695872.0,13727104.0,13758336.0,13789568.0,13820800.0,13852032.0,13662336.0,13693632.0,1372
4928.0,13756224.0,13787520.0,13818816.0,13850112.0,13881408.0,13691264.0,13722624.0,13753984.0,13785344.0,13816704.0,13848064.0,13879424.0,13910784.0,13720192.0,13751616.0,13783040.0,13814464.0,13845888.0,13877312.0,13908736.0,13940160.0,13749120.0,13780608.0,13812096.0,13843584.0,13875072.0,13906560.0,13938048.0,13969536.0,13778048.0,13809600.0,13841152.0,13872704.0,13904256.0,13935808.0,13967360.0,13998912.0,13806976.0,13838592.0,13870208.0,13901824.0,13933440.0,13965056.0,13996672.0,14028288.0,13835904.0,13867584.0,13899264.0,13930944.0,13962624.0,13994304.0,14025984.0,14057664.0,13864832.0,13896576.0,13928320.0,13960064.0,13991808.0,14023552.0,14055296.0,14087040.0,13893760.0,13925568.0,13957376.0,13989184.0,14020992.0,14052800.0,14084608.0,14116416.0,13922688.0,13954560.0,13986432.0,14018304.0,14050176.0,14082048.0,14113920.0,14145792.0,13951616.0,13983552.0,14015488.0,14047424.0,14079360.0,14111296.0,14143232.0,14175168.0,13980544.0,14012544.0,14044544.0,14076544.0,14108544.0,14140544.0,14172544.0,14204544.0,14009472.0,14041536.0,14073600.0,14105664.0,14137728.0,14169792.0,14201856.0,14233920.0,14038400.0,14070528.0,14102656.0,14134784.0,14166912.0,14199040.0,14231168.0,14263296.0,14067328.0,14099520.0,14131712.0,14163904.0,14196096.0,14228288.0,14260480.0,14292672.0,14096256.0,14128512.0,14160768.0,14193024.0,14225280.0,14257536.0,14289792.0,14322048.0,14125184.0,14157504.0,14189824.0,14222144.0,14254464.0,14286784.0,14319104.0,14351424.0,14154112.0,14186496.0,14218880.0,14251264.0,14283648.0,14316032.0,14348416.0,14380800.0,14183040.0,14215488.0,14247936.0,14280384.0,14312832.0,14345280.0,14377728.0,14410176.0,14211968.0,14244480.0,14276992.0,14309504.0,14342016.0,14374528.0,14407040.0,14439552.0,14240896.0,14273472.0,14306048.0,14338624.0,14371200.0,14403776.0,14436352.0,14468928.0,14269824.0,14302464.0,14335104.0,14367744.0,14400384.0,14433024.0,14465664.0,14498304.0,14298752.0,14331456.0,14364160.0,14396864.0,14429568.0,14462272.0,14494976.0,14527680.0,14
327680.0,14360448.0,14393216.0,14425984.0,14458752.0,14491520.0,14524288.0,14557056.0,14356608.0,14389440.0,14422272.0,14455104.0,14487936.0,14520768.0,14553600.0,14586432.0,14385536.0,14418432.0,14451328.0,14484224.0,14517120.0,14550016.0,14582912.0,14615808.0,14414464.0,14447424.0,14480384.0,14513344.0,14546304.0,14579264.0,14612224.0,14645184.0,14443392.0,14476416.0,14509440.0,14542464.0,14575488.0,14608512.0,14641536.0,14674560.0,14472320.0,14505408.0,14538496.0,14571584.0,14604672.0,14637760.0,14670848.0,14703936.0,14501248.0,14534400.0,14567552.0,14600704.0,14633856.0,14667008.0,14700160.0,14733312.0,14530176.0,14563392.0,14596608.0,14629824.0,14663040.0,14696256.0,14729472.0,14762688.0,14559104.0,14592384.0,14625664.0,14658944.0,14692224.0,14725504.0,14758784.0,14792064.0,14391424.0,14423744.0,14456064.0,14488384.0,14520704.0,14553024.0,14585344.0,14617664.0,14215552.0,14246912.0,14278272.0,14309632.0,14340992.0,14372352.0,14403712.0,14435072.0,14244480.0,14275904.0,14307328.0,14338752.0,14370176.0,14401600.0,14433024.0,14464448.0,14273408.0,14304896.0,14336384.0,14367872.0,14399360.0,14430848.0,14462336.0,14493824.0,14302336.0,14333888.0,14365440.0,14396992.0,14428544.0,14460096.0,14491648.0,14523200.0,14331264.0,14362880.0,14394496.0,14426112.0,14457728.0,14489344.0,14520960.0,14552576.0,14360192.0,14391872.0,14423552.0,14455232.0,14486912.0,14518592.0,14550272.0,14581952.0,14389120.0,14420864.0,14452608.0,14484352.0,14516096.0,14547840.0,14579584.0,14611328.0,14418048.0,14449856.0,14481664.0,14513472.0,14545280.0,14577088.0,14608896.0,14640704.0,14446976.0,14478848.0,14510720.0,14542592.0,14574464.0,14606336.0,14638208.0,14670080.0,14475904.0,14507840.0,14539776.0,14571712.0,14603648.0,14635584.0,14667520.0,14699456.0,14504832.0,14536832.0,14568832.0,14600832.0,14632832.0,14664832.0,14696832.0,14728832.0,14533760.0,14565824.0,14597888.0,14629952.0,14662016.0,14694080.0,14726144.0,14758208.0,14562688.0,14594816.0,14626944.0,14659072.0,14691200.0,14723328.0,
14755456.0,14787584.0,14591616.0,14623808.0,14656000.0,14688192.0,14720384.0,14752576.0,14784768.0,14816960.0,14620544.0,14652800.0,14685056.0,14717312.0,14749568.0,14781824.0,14814080.0,14846336.0,14649472.0,14681792.0,14714112.0,14746432.0,14778752.0,14811072.0,14843392.0,14875712.0,14678400.0,14710784.0,14743168.0,14775552.0,14807936.0,14840320.0,14872704.0,14905088.0,14707328.0,14739776.0,14772224.0,14804672.0,14837120.0,14869568.0,14902016.0,14934464.0,14736256.0,14768768.0,14801280.0,14833792.0,14866304.0,14898816.0,14931328.0,14963840.0,14765184.0,14797760.0,14830336.0,14862912.0,14895488.0,14928064.0,14960640.0,14993216.0,14794112.0,14826752.0,14859392.0,14892032.0,14924672.0,14957312.0,14989952.0,15022592.0,14823040.0,14855744.0,14888448.0,14921152.0,14953856.0,14986560.0,15019264.0,15051968.0,14851968.0,14884736.0,14917504.0,14950272.0,14983040.0,15015808.0,15048576.0,15081344.0,14880896.0,14913728.0,14946560.0,14979392.0,15012224.0,15045056.0,15077888.0,15110720.0,14909824.0,14942720.0,14975616.0,15008512.0,15041408.0,15074304.0,15107200.0,15140096.0,14938752.0,14971712.0,15004672.0,15037632.0,15070592.0,15103552.0,15136512.0,15169472.0,14967680.0,15000704.0,15033728.0,15066752.0,15099776.0,15132800.0,15165824.0,15198848.0,14996608.0,15029696.0,15062784.0,15095872.0,15128960.0,15162048.0,15195136.0,15228224.0,15025536.0,15058688.0,15091840.0,15124992.0,15158144.0,15191296.0,15224448.0,15257600.0,15054464.0,15087680.0,15120896.0,15154112.0,15187328.0,15220544.0,15253760.0,15286976.0,15083392.0,15116672.0,15149952.0,15183232.0,15216512.0,15249792.0,15283072.0,15316352.0,15112320.0,15145664.0,15179008.0,15212352.0,15245696.0,15279040.0,15312384.0,15345728.0,15141248.0,15174656.0,15208064.0,15241472.0,15274880.0,15308288.0,15341696.0,15375104.0,15170176.0,15203648.0,15237120.0,15270592.0,15304064.0,15337536.0,15371008.0,15404480.0,15199104.0,15232640.0,15266176.0,15299712.0,15333248.0,15366784.0,15400320.0,15433856.0,15228032.0,15261632.0,15295232.0,15328832.
0,15362432.0,15396032.0,15429632.0,15463232.0,15256960.0,15290624.0,15324288.0,15357952.0,15391616.0,15425280.0,15458944.0,15492608.0,15285888.0,15319616.0,15353344.0,15387072.0,15420800.0,15454528.0,15488256.0,15521984.0,14593920.0,14626688.0,14659456.0,14692224.0,14724992.0,14757760.0,14790528.0,14823296.0,13893760.0,13925568.0,13957376.0,13989184.0,14020992.0,14052800.0,14084608.0,14116416.0,7017856.0,7033776.0,7049696.0,7065616.0,7081536.0,7097456.0,7113376.0,7129296.0,7610112.0,7625904.0,7641696.0,7657488.0,7673280.0,7689072.0,7704864.0,7720656.0,15361408.0,15393024.0,15424640.0,15456256.0,15487872.0,15519488.0,15551104.0,15582720.0,15391360.0,15423040.0,15454720.0,15486400.0,15518080.0,15549760.0,15581440.0,15613120.0,15421312.0,15453056.0,15484800.0,15516544.0,15548288.0,15580032.0,15611776.0,15643520.0,15451264.0,15483072.0,15514880.0,15546688.0,15578496.0,15610304.0,15642112.0,15673920.0,15481216.0,15513088.0,15544960.0,15576832.0,15608704.0,15640576.0,15672448.0,15704320.0,15511168.0,15543104.0,15575040.0,15606976.0,15638912.0,15670848.0,15702784.0,15734720.0,15541120.0,15573120.0,15605120.0,15637120.0,15669120.0,15701120.0,15733120.0,15765120.0,15571072.0,15603136.0,15635200.0,15667264.0,15699328.0,15731392.0,15763456.0,15795520.0,15601024.0,15633152.0,15665280.0,15697408.0,15729536.0,15761664.0,15793792.0,15825920.0,15630976.0,15663168.0,15695360.0,15727552.0,15759744.0,15791936.0,15824128.0,15856320.0,15660928.0,15693184.0,15725440.0,15757696.0,15789952.0,15822208.0,15854464.0,15886720.0,15690880.0,15723200.0,15755520.0,15787840.0,15820160.0,15852480.0,15884800.0,15917120.0,15720832.0,15753216.0,15785600.0,15817984.0,15850368.0,15882752.0,15915136.0,15947520.0,15750784.0,15783232.0,15815680.0,15848128.0,15880576.0,15913024.0,15945472.0,15977920.0,15780736.0,15813248.0,15845760.0,15878272.0,15910784.0,15943296.0,15975808.0,16008320.0,15810688.0,15843264.0,15875840.0,15908416.0,15940992.0,15973568.0,16006144.0,16038720.0,15840640.0,15873280.0,15905920.0,1
5938560.0,15971200.0,16003840.0,16036480.0,16069120.0,15870592.0,15903296.0,15936000.0,15968704.0,16001408.0,16034112.0,16066816.0,16099520.0,15900544.0,15933312.0,15966080.0,15998848.0,16031616.0,16064384.0,16097152.0,16129920.0,15930496.0,15963328.0,15996160.0,16028992.0,16061824.0,16094656.0,16127488.0,16160320.0,15960448.0,15993344.0,16026240.0,16059136.0,16092032.0,16124928.0,16157824.0,16190720.0,15990400.0,16023360.0,16056320.0,16089280.0,16122240.0,16155200.0,16188160.0,16221120.0,16020352.0,16053376.0,16086400.0,16119424.0,16152448.0,16185472.0,16218496.0,16251520.0,16050304.0,16083392.0,16116480.0,16149568.0,16182656.0,16215744.0,16248832.0,16281920.0,16080256.0,16113408.0,16146560.0,16179712.0,16212864.0,16246016.0,16279168.0,16312320.0,16110208.0,16143424.0,16176640.0,16209856.0,16243072.0,16276288.0,16309504.0,16342720.0,16140160.0,16173440.0,16206720.0,16240000.0,16273280.0,16306560.0,16339840.0,16373120.0,16170112.0,16203456.0,16236800.0,16270144.0,16303488.0,16336832.0,16370176.0,16403520.0,16200064.0,16233472.0,16266880.0,16300288.0,16333696.0,16367104.0,16400512.0,16433920.0,16230016.0,16263488.0,16296960.0,16330432.0,16363904.0,16397376.0,16430848.0,16464320.0,16259968.0,16293504.0,16327040.0,16360576.0,16394112.0,16427648.0,16461184.0,16494720.0,16289920.0,16323520.0,16357120.0,16390720.0,16424320.0,16457920.0,16491520.0,16525120.0,16319872.0,16353536.0,16387200.0,16420864.0,16454528.0,16488192.0,16521856.0,16555520.0,16349824.0,16383552.0,16417280.0,16451008.0,16484736.0,16518464.0,16552192.0,16585920.0,16379776.0,16413568.0,16447360.0,16481152.0,16514944.0,16548736.0,16582528.0,16616320.0,16409728.0,16443584.0,16477440.0,16511296.0,16545152.0,16579008.0,16612864.0,16646720.0,15702400.0,15735296.0,15768192.0,15801088.0,15833984.0,15866880.0,15899776.0,15932672.0,14986880.0,15018816.0,15050752.0,15082688.0,15114624.0,15146560.0,15178496.0,15210432.0,14279552.0,14310528.0,14341504.0,14372480.0,14403456.0,14434432.0,14465408.0,14496384.0,13564032.0
,13594048.0,13624064.0,13654080.0,13684096.0,13714112.0,13744128.0,13774144.0,13593984.0,13624064.0,13654144.0,13684224.0,13714304.0,13744384.0,13774464.0,13804544.0,13623936.0,13654080.0,13684224.0,13714368.0,13744512.0,13774656.0,13804800.0,13834944.0,13653888.0,13684096.0,13714304.0,13744512.0,13774720.0,13804928.0,13835136.0,13865344.0,13683840.0,13714112.0,13744384.0,13774656.0,13804928.0,13835200.0,13865472.0,13895744.0,13713792.0,13744128.0,13774464.0,13804800.0,13835136.0,13865472.0,13895808.0,13926144.0,13743744.0,13774144.0,13804544.0,13834944.0,13865344.0,13895744.0,13926144.0,13956544.0,13773696.0,13804160.0,13834624.0,13865088.0,13895552.0,13926016.0,13956480.0,13986944.0,13803648.0,13834176.0,13864704.0,13895232.0,13925760.0,13956288.0,13986816.0,14017344.0,13833600.0,13864192.0,13894784.0,13925376.0,13955968.0,13986560.0,14017152.0,14047744.0,13863552.0,13894208.0,13924864.0,13955520.0,13986176.0,14016832.0,14047488.0,14078144.0,13893504.0,13924224.0,13954944.0,13985664.0,14016384.0,14047104.0,14077824.0,14108544.0,13923456.0,13954240.0,13985024.0,14015808.0,14046592.0,14077376.0,14108160.0,14138944.0,13953408.0,13984256.0,14015104.0,14045952.0,14076800.0,14107648.0,14138496.0,14169344.0,13983360.0,14014272.0,14045184.0,14076096.0,14107008.0,14137920.0,14168832.0,14199744.0,14013312.0,14044288.0,14075264.0,14106240.0,14137216.0,14168192.0,14199168.0,14230144.0,14043264.0,14074304.0,14105344.0,14136384.0,14167424.0,14198464.0,14229504.0,14260544.0,14073216.0,14104320.0,14135424.0,14166528.0,14197632.0,14228736.0,14259840.0,14290944.0,14103168.0,14134336.0,14165504.0,14196672.0,14227840.0,14259008.0,14290176.0,14321344.0,14133120.0,14164352.0,14195584.0,14226816.0,14258048.0,14289280.0,14320512.0,14351744.0,14163072.0,14194368.0,14225664.0,14256960.0,14288256.0,14319552.0,14350848.0,14382144.0,14193024.0,14224384.0,14255744.0,14287104.0,14318464.0,14349824.0,14381184.0,14412544.0,14222976.0,14254400.0,14285824.0,14317248.0,14348672.0,14380096.0,14411520
.0,14442944.0,14252928.0,14284416.0,14315904.0,14347392.0,14378880.0,14410368.0,14441856.0,14473344.0,14282880.0,14314432.0,14345984.0,14377536.0,14409088.0,14440640.0,14472192.0,14503744.0,14312832.0,14344448.0,14376064.0,14407680.0,14439296.0,14470912.0,14502528.0,14534144.0,14342784.0,14374464.0,14406144.0,14437824.0,14469504.0,14501184.0,14532864.0,14564544.0,14372736.0,14404480.0,14436224.0,14467968.0,14499712.0,14531456.0,14563200.0,14594944.0,14402688.0,14434496.0,14466304.0,14498112.0,14529920.0,14561728.0,14593536.0,14625344.0,14432640.0,14464512.0,14496384.0,14528256.0,14560128.0,14592000.0,14623872.0,14655744.0,14462592.0,14494528.0,14526464.0,14558400.0,14590336.0,14622272.0,14654208.0,14686144.0,14492544.0,14524544.0,14556544.0,14588544.0,14620544.0,14652544.0,14684544.0,14716544.0,14522496.0,14554560.0,14586624.0,14618688.0,14650752.0,14682816.0,14714880.0,14746944.0,14552448.0,14584576.0,14616704.0,14648832.0,14680960.0,14713088.0,14745216.0,14777344.0,14582400.0,14614592.0,14646784.0,14678976.0,14711168.0,14743360.0,14775552.0,14807744.0,14612352.0,14644608.0,14676864.0,14709120.0,14741376.0,14773632.0,14805888.0,14838144.0,14642304.0,14674624.0,14706944.0,14739264.0,14771584.0,14803904.0,14836224.0,14868544.0,14672256.0,14704640.0,14737024.0,14769408.0,14801792.0,14834176.0,14866560.0,14898944.0,14489216.0,14520640.0,14552064.0,14583488.0,14614912.0,14646336.0,14677760.0,14709184.0,14297984.0,14328448.0,14358912.0,14389376.0,14419840.0,14450304.0,14480768.0,14511232.0,14327936.0,14358464.0,14388992.0,14419520.0,14450048.0,14480576.0,14511104.0,14541632.0,14357888.0,14388480.0,14419072.0,14449664.0,14480256.0,14510848.0,14541440.0,14572032.0,14387840.0,14418496.0,14449152.0,14479808.0,14510464.0,14541120.0,14571776.0,14602432.0,14417792.0,14448512.0,14479232.0,14509952.0,14540672.0,14571392.0,14602112.0,14632832.0,14447744.0,14478528.0,14509312.0,14540096.0,14570880.0,14601664.0,14632448.0,14663232.0,14477696.0,14508544.0,14539392.0,14570240.0,146010
88.0,14631936.0,14662784.0,14693632.0,14507648.0,14538560.0,14569472.0,14600384.0,14631296.0,14662208.0,14693120.0,14724032.0,14537600.0,14568576.0,14599552.0,14630528.0,14661504.0,14692480.0,14723456.0,14754432.0,14567552.0,14598592.0,14629632.0,14660672.0,14691712.0,14722752.0,14753792.0,14784832.0,14597504.0,14628608.0,14659712.0,14690816.0,14721920.0,14753024.0,14784128.0,14815232.0,14627456.0,14658624.0,14689792.0,14720960.0,14752128.0,14783296.0,14814464.0,14845632.0,14657408.0,14688640.0,14719872.0,14751104.0,14782336.0,14813568.0,14844800.0,14876032.0,14687360.0,14718656.0,14749952.0,14781248.0,14812544.0,14843840.0,14875136.0,14906432.0,14717312.0,14748672.0,14780032.0,14811392.0,14842752.0,14874112.0,14905472.0,14936832.0,14747264.0,14778688.0,14810112.0,14841536.0,14872960.0,14904384.0,14935808.0,14967232.0,14777216.0,14808704.0,14840192.0,14871680.0,14903168.0,14934656.0,14966144.0,14997632.0,14807168.0,14838720.0,14870272.0,14901824.0,14933376.0,14964928.0,14996480.0,15028032.0,14837120.0,14868736.0,14900352.0,14931968.0,14963584.0,14995200.0,15026816.0,15058432.0,14867072.0,14898752.0,14930432.0,14962112.0,14993792.0,15025472.0,15057152.0,15088832.0,14897024.0,14928768.0,14960512.0,14992256.0,15024000.0,15055744.0,15087488.0,15119232.0,14926976.0,14958784.0,14990592.0,15022400.0,15054208.0,15086016.0,15117824.0,15149632.0,14956928.0,14988800.0,15020672.0,15052544.0,15084416.0,15116288.0,15148160.0,15180032.0,14986880.0,15018816.0,15050752.0,15082688.0,15114624.0,15146560.0,15178496.0,15210432.0,15016832.0,15048832.0,15080832.0,15112832.0,15144832.0,15176832.0,15208832.0,15240832.0,15046784.0,15078848.0,15110912.0,15142976.0,15175040.0,15207104.0,15239168.0,15271232.0,15076736.0,15108864.0,15140992.0,15173120.0,15205248.0,15237376.0,15269504.0,15301632.0,15106688.0,15138880.0,15171072.0,15203264.0,15235456.0,15267648.0,15299840.0,15332032.0,15136640.0,15168896.0,15201152.0,15233408.0,15265664.0,15297920.0,15330176.0,15362432.0,15166592.0,15198912.0,1523
1232.0,15263552.0,15295872.0,15328192.0,15360512.0,15392832.0,15196544.0,15228928.0,15261312.0,15293696.0,15326080.0,15358464.0,15390848.0,15423232.0,15226496.0,15258944.0,15291392.0,15323840.0,15356288.0,15388736.0,15421184.0,15453632.0,15256448.0,15288960.0,15321472.0,15353984.0,15386496.0,15419008.0,15451520.0,15484032.0,15286400.0,15318976.0,15351552.0,15384128.0,15416704.0,15449280.0,15481856.0,15514432.0,15316352.0,15348992.0,15381632.0,15414272.0,15446912.0,15479552.0,15512192.0,15544832.0,15346304.0,15379008.0,15411712.0,15444416.0,15477120.0,15509824.0,15542528.0,15575232.0,15376256.0,15409024.0,15441792.0,15474560.0,15507328.0,15540096.0,15572864.0,15605632.0,15406208.0,15439040.0,15471872.0,15504704.0,15537536.0,15570368.0,15603200.0,15636032.0,14698880.0,14730752.0,14762624.0,14794496.0,14826368.0,14858240.0,14890112.0,14921984.0,13983360.0,14014272.0,14045184.0,14076096.0,14107008.0,14137920.0,14168832.0,14199744.0,14013312.0,14044288.0,14075264.0,14106240.0,14137216.0,14168192.0,14199168.0,14230144.0,14043264.0,14074304.0,14105344.0,14136384.0,14167424.0,14198464.0,14229504.0,14260544.0,14073216.0,14104320.0,14135424.0,14166528.0,14197632.0,14228736.0,14259840.0,14290944.0,14103168.0,14134336.0,14165504.0,14196672.0,14227840.0,14259008.0,14290176.0,14321344.0,14133120.0,14164352.0,14195584.0,14226816.0,14258048.0,14289280.0,14320512.0,14351744.0,14163072.0,14194368.0,14225664.0,14256960.0,14288256.0,14319552.0,14350848.0,14382144.0,14193024.0,14224384.0,14255744.0,14287104.0,14318464.0,14349824.0,14381184.0,14412544.0,14222976.0,14254400.0,14285824.0,14317248.0,14348672.0,14380096.0,14411520.0,14442944.0,14252928.0,14284416.0,14315904.0,14347392.0,14378880.0,14410368.0,14441856.0,14473344.0,14282880.0,14314432.0,14345984.0,14377536.0,14409088.0,14440640.0,14472192.0,14503744.0,14312832.0,14344448.0,14376064.0,14407680.0,14439296.0,14470912.0,14502528.0,14534144.0,14342784.0,14374464.0,14406144.0,14437824.0,14469504.0,14501184.0,14532864.0,14564544.0,14
372736.0,14404480.0,14436224.0,14467968.0,14499712.0,14531456.0,14563200.0,14594944.0,14402688.0,14434496.0,14466304.0,14498112.0,14529920.0,14561728.0,14593536.0,14625344.0,14432640.0,14464512.0,14496384.0,14528256.0,14560128.0,14592000.0,14623872.0,14655744.0,14462592.0,14494528.0,14526464.0,14558400.0,14590336.0,14622272.0,14654208.0,14686144.0,14492544.0,14524544.0,14556544.0,14588544.0,14620544.0,14652544.0,14684544.0,14716544.0,14522496.0,14554560.0,14586624.0,14618688.0,14650752.0,14682816.0,14714880.0,14746944.0,14552448.0,14584576.0,14616704.0,14648832.0,14680960.0,14713088.0,14745216.0,14777344.0,14582400.0,14614592.0,14646784.0,14678976.0,14711168.0,14743360.0,14775552.0,14807744.0,14612352.0,14644608.0,14676864.0,14709120.0,14741376.0,14773632.0,14805888.0,14838144.0,14642304.0,14674624.0,14706944.0,14739264.0,14771584.0,14803904.0,14836224.0,14868544.0,14672256.0,14704640.0,14737024.0,14769408.0,14801792.0,14834176.0,14866560.0,14898944.0,14702208.0,14734656.0,14767104.0,14799552.0,14832000.0,14864448.0,14896896.0,14929344.0,14732160.0,14764672.0,14797184.0,14829696.0,14862208.0,14894720.0,14927232.0,14959744.0,14762112.0,14794688.0,14827264.0,14859840.0,14892416.0,14924992.0,14957568.0,14990144.0,14792064.0,14824704.0,14857344.0,14889984.0,14922624.0,14955264.0,14987904.0,15020544.0,14822016.0,14854720.0,14887424.0,14920128.0,14952832.0,14985536.0,15018240.0,15050944.0,14851968.0,14884736.0,14917504.0,14950272.0,14983040.0,15015808.0,15048576.0,15081344.0,14881920.0,14914752.0,14947584.0,14980416.0,15013248.0,15046080.0,15078912.0,15111744.0,14911872.0,14944768.0,14977664.0,15010560.0,15043456.0,15076352.0,15109248.0,15142144.0,14941824.0,14974784.0,15007744.0,15040704.0,15073664.0,15106624.0,15139584.0,15172544.0,14971776.0,15004800.0,15037824.0,15070848.0,15103872.0,15136896.0,15169920.0,15202944.0,15001728.0,15034816.0,15067904.0,15100992.0,15134080.0,15167168.0,15200256.0,15233344.0,15031680.0,15064832.0,15097984.0,15131136.0,15164288.0,15197440.0,
15230592.0,15263744.0,15061632.0,15094848.0,15128064.0,15161280.0,15194496.0,15227712.0,15260928.0,15294144.0,15091584.0,15124864.0,15158144.0,15191424.0,15224704.0,15257984.0,15291264.0,15324544.0,14908544.0,14940864.0,14973184.0,15005504.0,15037824.0,15070144.0,15102464.0,15134784.0,14717312.0,14748672.0,14780032.0,14811392.0,14842752.0,14874112.0,14905472.0,14936832.0,14747264.0,14778688.0,14810112.0,14841536.0,14872960.0,14904384.0,14935808.0,14967232.0,14777216.0,14808704.0,14840192.0,14871680.0,14903168.0,14934656.0,14966144.0,14997632.0,14807168.0,14838720.0,14870272.0,14901824.0,14933376.0,14964928.0,14996480.0,15028032.0,14837120.0,14868736.0,14900352.0,14931968.0,14963584.0,14995200.0,15026816.0,15058432.0,14867072.0,14898752.0,14930432.0,14962112.0,14993792.0,15025472.0,15057152.0,15088832.0,14897024.0,14928768.0,14960512.0,14992256.0,15024000.0,15055744.0,15087488.0,15119232.0,14926976.0,14958784.0,14990592.0,15022400.0,15054208.0,15086016.0,15117824.0,15149632.0,14956928.0,14988800.0,15020672.0,15052544.0,15084416.0,15116288.0,15148160.0,15180032.0,14986880.0,15018816.0,15050752.0,15082688.0,15114624.0,15146560.0,15178496.0,15210432.0,15016832.0,15048832.0,15080832.0,15112832.0,15144832.0,15176832.0,15208832.0,15240832.0,15046784.0,15078848.0,15110912.0,15142976.0,15175040.0,15207104.0,15239168.0,15271232.0,15076736.0,15108864.0,15140992.0,15173120.0,15205248.0,15237376.0,15269504.0,15301632.0,15106688.0,15138880.0,15171072.0,15203264.0,15235456.0,15267648.0,15299840.0,15332032.0,15136640.0,15168896.0,15201152.0,15233408.0,15265664.0,15297920.0,15330176.0,15362432.0,15166592.0,15198912.0,15231232.0,15263552.0,15295872.0,15328192.0,15360512.0,15392832.0,15196544.0,15228928.0,15261312.0,15293696.0,15326080.0,15358464.0,15390848.0,15423232.0,15226496.0,15258944.0,15291392.0,15323840.0,15356288.0,15388736.0,15421184.0,15453632.0,15256448.0,15288960.0,15321472.0,15353984.0,15386496.0,15419008.0,15451520.0,15484032.0,15286400.0,15318976.0,15351552.0,15384128.
0,15416704.0,15449280.0,15481856.0,15514432.0,15316352.0,15348992.0,15381632.0,15414272.0,15446912.0,15479552.0,15512192.0,15544832.0,15346304.0,15379008.0,15411712.0,15444416.0,15477120.0,15509824.0,15542528.0,15575232.0,15376256.0,15409024.0,15441792.0,15474560.0,15507328.0,15540096.0,15572864.0,15605632.0,15406208.0,15439040.0,15471872.0,15504704.0,15537536.0,15570368.0,15603200.0,15636032.0,15436160.0,15469056.0,15501952.0,15534848.0,15567744.0,15600640.0,15633536.0,15666432.0,15466112.0,15499072.0,15532032.0,15564992.0,15597952.0,15630912.0,15663872.0,15696832.0,15496064.0,15529088.0,15562112.0,15595136.0,15628160.0,15661184.0,15694208.0,15727232.0,15526016.0,15559104.0,15592192.0,15625280.0,15658368.0,15691456.0,15724544.0,15757632.0,15555968.0,15589120.0,15622272.0,15655424.0,15688576.0,15721728.0,15754880.0,15788032.0,15585920.0,15619136.0,15652352.0,15685568.0,15718784.0,15752000.0,15785216.0,15818432.0,15615872.0,15649152.0,15682432.0,15715712.0,15748992.0,15782272.0,15815552.0,15848832.0,15645824.0,15679168.0,15712512.0,15745856.0,15779200.0,15812544.0,15845888.0,15879232.0,15675776.0,15709184.0,15742592.0,15776000.0,15809408.0,15842816.0,15876224.0,15909632.0,15705728.0,15739200.0,15772672.0,15806144.0,15839616.0,15873088.0,15906560.0,15940032.0,15735680.0,15769216.0,15802752.0,15836288.0,15869824.0,15903360.0,15936896.0,15970432.0,15765632.0,15799232.0,15832832.0,15866432.0,15900032.0,15933632.0,15967232.0,16000832.0,15795584.0,15829248.0,15862912.0,15896576.0,15930240.0,15963904.0,15997568.0,16031232.0,15825536.0,15859264.0,15892992.0,15926720.0,15960448.0,15994176.0,16027904.0,16061632.0,15118208.0,15150976.0,15183744.0,15216512.0,15249280.0,15282048.0,15314816.0,15347584.0,14402688.0,14434496.0,14466304.0,14498112.0,14529920.0,14561728.0,14593536.0,14625344.0,7272576.0,7288496.0,7304416.0,7320336.0,7336256.0,7352176.0,7368096.0,7384016.0,7862784.0,7878576.0,7894368.0,7910160.0,7925952.0,7941744.0,7957536.0,7973328.0,15867264.0,15898880.0,15930496.0,1
5962112.0,15993728.0,16025344.0,16056960.0,16088576.0,15898240.0,15929920.0,15961600.0,15993280.0,16024960.0,16056640.0,16088320.0,16120000.0,15929216.0,15960960.0,15992704.0,16024448.0,16056192.0,16087936.0,16119680.0,16151424.0,15960192.0,15992000.0,16023808.0,16055616.0,16087424.0,16119232.0,16151040.0,16182848.0,15991168.0,16023040.0,16054912.0,16086784.0,16118656.0,16150528.0,16182400.0,16214272.0,16022144.0,16054080.0,16086016.0,16117952.0,16149888.0,16181824.0,16213760.0,16245696.0,16053120.0,16085120.0,16117120.0,16149120.0,16181120.0,16213120.0,16245120.0,16277120.0,16084096.0,16116160.0,16148224.0,16180288.0,16212352.0,16244416.0,16276480.0,16308544.0,16115072.0,16147200.0,16179328.0,16211456.0,16243584.0,16275712.0,16307840.0,16339968.0,16146048.0,16178240.0,16210432.0,16242624.0,16274816.0,16307008.0,16339200.0,16371392.0,16177024.0,16209280.0,16241536.0,16273792.0,16306048.0,16338304.0,16370560.0,16402816.0,16208000.0,16240320.0,16272640.0,16304960.0,16337280.0,16369600.0,16401920.0,16434240.0,16238976.0,16271360.0,16303744.0,16336128.0,16368512.0,16400896.0,16433280.0,16465664.0,16269952.0,16302400.0,16334848.0,16367296.0,16399744.0,16432192.0,16464640.0,16497088.0,16300928.0,16333440.0,16365952.0,16398464.0,16430976.0,16463488.0,16496000.0,16528512.0,16331904.0,16364480.0,16397056.0,16429632.0,16462208.0,16494784.0,16527360.0,16559936.0,16362880.0,16395520.0,16428160.0,16460800.0,16493440.0,16526080.0,16558720.0,16591360.0,16393856.0,16426560.0,16459264.0,16491968.0,16524672.0,16557376.0,16590080.0,16622784.0,16424832.0,16457600.0,16490368.0,16523136.0,16555904.0,16588672.0,16621440.0,16654208.0,16455808.0,16488640.0,16521472.0,16554304.0,16587136.0,16619968.0,16652800.0,16685632.0,16486784.0,16519680.0,16552576.0,16585472.0,16618368.0,16651264.0,16684160.0,16717056.0,16517760.0,16550720.0,16583680.0,16616640.0,16649600.0,16682560.0,16715520.0,16748480.0,16548736.0,16581760.0,16614784.0,16647808.0,16680832.0,16713856.0,16746880.0,16779904.0,16579712.0
,16612800.0,16645888.0,16678976.0,16712064.0,16745152.0,16778240.0,16811328.0,16610688.0,16643840.0,16676992.0,16710144.0,16743296.0,16776448.0,16809600.0,16842752.0,16641664.0,16674880.0,16708096.0,16741312.0,16774528.0,16807744.0,16840960.0,16874176.0,16672640.0,16705920.0,16739200.0,16772480.0,16805760.0,16839040.0,16872320.0,16905600.0,16703616.0,16736960.0,16770304.0,16803648.0,16836992.0,16870336.0,16903680.0,16937024.0,16734592.0,16768000.0,16801408.0,16834816.0,16868224.0,16901632.0,16935040.0,16968448.0,16765568.0,16799040.0,16832512.0,16865984.0,16899456.0,16932928.0,16966400.0,16999872.0,16796544.0,16830080.0,16863616.0,16897152.0,16930688.0,16964224.0,16997760.0,17031296.0,16827520.0,16861120.0,16894720.0,16928320.0,16961920.0,16995520.0,17029120.0,17062720.0,16858496.0,16892160.0,16925824.0,16959488.0,16993152.0,17026816.0,17060480.0,17094144.0,16889472.0,16923200.0,16956928.0,16990656.0,17024384.0,17058112.0,17091840.0,17125568.0,16920448.0,16954240.0,16988032.0,17021824.0,17055616.0,17089408.0,17123200.0,17156992.0,16951424.0,16985280.0,17019136.0,17052992.0,17086848.0,17120704.0,17154560.0,17188416.0,16228736.0,16261632.0,16294528.0,16327424.0,16360320.0,16393216.0,16426112.0,16459008.0,15497856.0,15529792.0,15561728.0,15593664.0,15625600.0,15657536.0,15689472.0,15721408.0,14775168.0,14806144.0,14837120.0,14868096.0,14899072.0,14930048.0,14961024.0,14992000.0,14044288.0,14074304.0,14104320.0,14134336.0,14164352.0,14194368.0,14224384.0,14254400.0,14075264.0,14105344.0,14135424.0,14165504.0,14195584.0,14225664.0,14255744.0,14285824.0,14106240.0,14136384.0,14166528.0,14196672.0,14226816.0,14256960.0,14287104.0,14317248.0,14137216.0,14167424.0,14197632.0,14227840.0,14258048.0,14288256.0,14318464.0,14348672.0,14168192.0,14198464.0,14228736.0,14259008.0,14289280.0,14319552.0,14349824.0,14380096.0,14199168.0,14229504.0,14259840.0,14290176.0,14320512.0,14350848.0,14381184.0,14411520.0,14230144.0,14260544.0,14290944.0,14321344.0,14351744.0,14382144.0,14412544
.0,14442944.0,14261120.0,14291584.0,14322048.0,14352512.0,14382976.0,14413440.0,14443904.0,14474368.0,14292096.0,14322624.0,14353152.0,14383680.0,14414208.0,14444736.0,14475264.0,14505792.0,14323072.0,14353664.0,14384256.0,14414848.0,14445440.0,14476032.0,14506624.0,14537216.0,14354048.0,14384704.0,14415360.0,14446016.0,14476672.0,14507328.0,14537984.0,14568640.0,14385024.0,14415744.0,14446464.0,14477184.0,14507904.0,14538624.0,14569344.0,14600064.0,14416000.0,14446784.0,14477568.0,14508352.0,14539136.0,14569920.0,14600704.0,14631488.0,14446976.0,14477824.0,14508672.0,14539520.0,14570368.0,14601216.0,14632064.0,14662912.0,14477952.0,14508864.0,14539776.0,14570688.0,14601600.0,14632512.0,14663424.0,14694336.0,14508928.0,14539904.0,14570880.0,14601856.0,14632832.0,14663808.0,14694784.0,14725760.0,14539904.0,14570944.0,14601984.0,14633024.0,14664064.0,14695104.0,14726144.0,14757184.0,14570880.0,14601984.0,14633088.0,14664192.0,14695296.0,14726400.0,14757504.0,14788608.0,14601856.0,14633024.0,14664192.0,14695360.0,14726528.0,14757696.0,14788864.0,14820032.0,14632832.0,14664064.0,14695296.0,14726528.0,14757760.0,14788992.0,14820224.0,14851456.0,14663808.0,14695104.0,14726400.0,14757696.0,14788992.0,14820288.0,14851584.0,14882880.0,14694784.0,14726144.0,14757504.0,14788864.0,14820224.0,14851584.0,14882944.0,14914304.0,14725760.0,14757184.0,14788608.0,14820032.0,14851456.0,14882880.0,14914304.0,14945728.0,14756736.0,14788224.0,14819712.0,14851200.0,14882688.0,14914176.0,14945664.0,14977152.0,14787712.0,14819264.0,14850816.0,14882368.0,14913920.0,14945472.0,14977024.0,15008576.0,14818688.0,14850304.0,14881920.0,14913536.0,14945152.0,14976768.0,15008384.0,15040000.0,14849664.0,14881344.0,14913024.0,14944704.0,14976384.0,15008064.0,15039744.0,15071424.0,14880640.0,14912384.0,14944128.0,14975872.0,15007616.0,15039360.0,15071104.0,15102848.0,14911616.0,14943424.0,14975232.0,15007040.0,15038848.0,15070656.0,15102464.0,15134272.0,14942592.0,14974464.0,15006336.0,15038208.0,150700
80.0,15101952.0,15133824.0,15165696.0,14973568.0,15005504.0,15037440.0,15069376.0,15101312.0,15133248.0,15165184.0,15197120.0,15004544.0,15036544.0,15068544.0,15100544.0,15132544.0,15164544.0,15196544.0,15228544.0,15035520.0,15067584.0,15099648.0,15131712.0,15163776.0,15195840.0,15227904.0,15259968.0,15066496.0,15098624.0,15130752.0,15162880.0,15195008.0,15227136.0,15259264.0,15291392.0,15097472.0,15129664.0,15161856.0,15194048.0,15226240.0,15258432.0,15290624.0,15322816.0,15128448.0,15160704.0,15192960.0,15225216.0,15257472.0,15289728.0,15321984.0,15354240.0,15159424.0,15191744.0,15224064.0,15256384.0,15288704.0,15321024.0,15353344.0,15385664.0,15190400.0,15222784.0,15255168.0,15287552.0,15319936.0,15352320.0,15384704.0,15417088.0,14992000.0,15023424.0,15054848.0,15086272.0,15117696.0,15149120.0,15180544.0,15211968.0,14785408.0,14815872.0,14846336.0,14876800.0,14907264.0,14937728.0,14968192.0,14998656.0,14816384.0,14846912.0,14877440.0,14907968.0,14938496.0,14969024.0,14999552.0,15030080.0,14847360.0,14877952.0,14908544.0,14939136.0,14969728.0,15000320.0,15030912.0,15061504.0,14878336.0,14908992.0,14939648.0,14970304.0,15000960.0,15031616.0,15062272.0,15092928.0,14909312.0,14940032.0,14970752.0,15001472.0,15032192.0,15062912.0,15093632.0,15124352.0,14940288.0,14971072.0,15001856.0,15032640.0,15063424.0,15094208.0,15124992.0,15155776.0,14971264.0,15002112.0,15032960.0,15063808.0,15094656.0,15125504.0,15156352.0,15187200.0,15002240.0,15033152.0,15064064.0,15094976.0,15125888.0,15156800.0,15187712.0,15218624.0,15033216.0,15064192.0,15095168.0,15126144.0,15157120.0,15188096.0,15219072.0,15250048.0,15064192.0,15095232.0,15126272.0,15157312.0,15188352.0,15219392.0,15250432.0,15281472.0,15095168.0,15126272.0,15157376.0,15188480.0,15219584.0,15250688.0,15281792.0,15312896.0,15126144.0,15157312.0,15188480.0,15219648.0,15250816.0,15281984.0,15313152.0,15344320.0,15157120.0,15188352.0,15219584.0,15250816.0,15282048.0,15313280.0,15344512.0,15375744.0,15188096.0,15219392.0,1525
0688.0,15281984.0,15313280.0,15344576.0,15375872.0,15407168.0,15219072.0,15250432.0,15281792.0,15313152.0,15344512.0,15375872.0,15407232.0,15438592.0,15250048.0,15281472.0,15312896.0,15344320.0,15375744.0,15407168.0,15438592.0,15470016.0,15281024.0,15312512.0,15344000.0,15375488.0,15406976.0,15438464.0,15469952.0,15501440.0,15312000.0,15343552.0,15375104.0,15406656.0,15438208.0,15469760.0,15501312.0,15532864.0,15342976.0,15374592.0,15406208.0,15437824.0,15469440.0,15501056.0,15532672.0,15564288.0,15373952.0,15405632.0,15437312.0,15468992.0,15500672.0,15532352.0,15564032.0,15595712.0,15404928.0,15436672.0,15468416.0,15500160.0,15531904.0,15563648.0,15595392.0,15627136.0,15435904.0,15467712.0,15499520.0,15531328.0,15563136.0,15594944.0,15626752.0,15658560.0,15466880.0,15498752.0,15530624.0,15562496.0,15594368.0,15626240.0,15658112.0,15689984.0,15497856.0,15529792.0,15561728.0,15593664.0,15625600.0,15657536.0,15689472.0,15721408.0,15528832.0,15560832.0,15592832.0,15624832.0,15656832.0,15688832.0,15720832.0,15752832.0,15559808.0,15591872.0,15623936.0,15656000.0,15688064.0,15720128.0,15752192.0,15784256.0,15590784.0,15622912.0,15655040.0,15687168.0,15719296.0,15751424.0,15783552.0,15815680.0,15621760.0,15653952.0,15686144.0,15718336.0,15750528.0,15782720.0,15814912.0,15847104.0,15652736.0,15684992.0,15717248.0,15749504.0,15781760.0,15814016.0,15846272.0,15878528.0,15683712.0,15716032.0,15748352.0,15780672.0,15812992.0,15845312.0,15877632.0,15909952.0,15714688.0,15747072.0,15779456.0,15811840.0,15844224.0,15876608.0,15908992.0,15941376.0,15745664.0,15778112.0,15810560.0,15843008.0,15875456.0,15907904.0,15940352.0,15972800.0,15776640.0,15809152.0,15841664.0,15874176.0,15906688.0,15939200.0,15971712.0,16004224.0,15807616.0,15840192.0,15872768.0,15905344.0,15937920.0,15970496.0,16003072.0,16035648.0,15838592.0,15871232.0,15903872.0,15936512.0,15969152.0,16001792.0,16034432.0,16067072.0,15869568.0,15902272.0,15934976.0,15967680.0,16000384.0,16033088.0,16065792.0,16098496.0,15
900544.0,15933312.0,15966080.0,15998848.0,16031616.0,16064384.0,16097152.0,16129920.0,15931520.0,15964352.0,15997184.0,16030016.0,16062848.0,16095680.0,16128512.0,16161344.0,15208832.0,15240704.0,15272576.0,15304448.0,15336320.0,15368192.0,15400064.0,15431936.0,14477952.0,14508864.0,14539776.0,14570688.0,14601600.0,14632512.0,14663424.0,14694336.0,14508928.0,14539904.0,14570880.0,14601856.0,14632832.0,14663808.0,14694784.0,14725760.0,14539904.0,14570944.0,14601984.0,14633024.0,14664064.0,14695104.0,14726144.0,14757184.0,14570880.0,14601984.0,14633088.0,14664192.0,14695296.0,14726400.0,14757504.0,14788608.0,14601856.0,14633024.0,14664192.0,14695360.0,14726528.0,14757696.0,14788864.0,14820032.0,14632832.0,14664064.0,14695296.0,14726528.0,14757760.0,14788992.0,14820224.0,14851456.0,14663808.0,14695104.0,14726400.0,14757696.0,14788992.0,14820288.0,14851584.0,14882880.0,14694784.0,14726144.0,14757504.0,14788864.0,14820224.0,14851584.0,14882944.0,14914304.0,14725760.0,14757184.0,14788608.0,14820032.0,14851456.0,14882880.0,14914304.0,14945728.0,14756736.0,14788224.0,14819712.0,14851200.0,14882688.0,14914176.0,14945664.0,14977152.0,14787712.0,14819264.0,14850816.0,14882368.0,14913920.0,14945472.0,14977024.0,15008576.0,14818688.0,14850304.0,14881920.0,14913536.0,14945152.0,14976768.0,15008384.0,15040000.0,14849664.0,14881344.0,14913024.0,14944704.0,14976384.0,15008064.0,15039744.0,15071424.0,14880640.0,14912384.0,14944128.0,14975872.0,15007616.0,15039360.0,15071104.0,15102848.0,14911616.0,14943424.0,14975232.0,15007040.0,15038848.0,15070656.0,15102464.0,15134272.0,14942592.0,14974464.0,15006336.0,15038208.0,15070080.0,15101952.0,15133824.0,15165696.0,14973568.0,15005504.0,15037440.0,15069376.0,15101312.0,15133248.0,15165184.0,15197120.0,15004544.0,15036544.0,15068544.0,15100544.0,15132544.0,15164544.0,15196544.0,15228544.0,15035520.0,15067584.0,15099648.0,15131712.0,15163776.0,15195840.0,15227904.0,15259968.0,15066496.0,15098624.0,15130752.0,15162880.0,15195008.0,15227136.0,
15259264.0,15291392.0,15097472.0,15129664.0,15161856.0,15194048.0,15226240.0,15258432.0,15290624.0,15322816.0,15128448.0,15160704.0,15192960.0,15225216.0,15257472.0,15289728.0,15321984.0,15354240.0,15159424.0,15191744.0,15224064.0,15256384.0,15288704.0,15321024.0,15353344.0,15385664.0,15190400.0,15222784.0,15255168.0,15287552.0,15319936.0,15352320.0,15384704.0,15417088.0,15221376.0,15253824.0,15286272.0,15318720.0,15351168.0,15383616.0,15416064.0,15448512.0,15252352.0,15284864.0,15317376.0,15349888.0,15382400.0,15414912.0,15447424.0,15479936.0,15283328.0,15315904.0,15348480.0,15381056.0,15413632.0,15446208.0,15478784.0,15511360.0,15314304.0,15346944.0,15379584.0,15412224.0,15444864.0,15477504.0,15510144.0,15542784.0,15345280.0,15377984.0,15410688.0,15443392.0,15476096.0,15508800.0,15541504.0,15574208.0,15376256.0,15409024.0,15441792.0,15474560.0,15507328.0,15540096.0,15572864.0,15605632.0,15407232.0,15440064.0,15472896.0,15505728.0,15538560.0,15571392.0,15604224.0,15637056.0,15438208.0,15471104.0,15504000.0,15536896.0,15569792.0,15602688.0,15635584.0,15668480.0,15469184.0,15502144.0,15535104.0,15568064.0,15601024.0,15633984.0,15666944.0,15699904.0,15500160.0,15533184.0,15566208.0,15599232.0,15632256.0,15665280.0,15698304.0,15731328.0,15531136.0,15564224.0,15597312.0,15630400.0,15663488.0,15696576.0,15729664.0,15762752.0,15562112.0,15595264.0,15628416.0,15661568.0,15694720.0,15727872.0,15761024.0,15794176.0,15593088.0,15626304.0,15659520.0,15692736.0,15725952.0,15759168.0,15792384.0,15825600.0,15624064.0,15657344.0,15690624.0,15723904.0,15757184.0,15790464.0,15823744.0,15857024.0,15425664.0,15457984.0,15490304.0,15522624.0,15554944.0,15587264.0,15619584.0,15651904.0,15219072.0,15250432.0,15281792.0,15313152.0,15344512.0,15375872.0,15407232.0,15438592.0,15250048.0,15281472.0,15312896.0,15344320.0,15375744.0,15407168.0,15438592.0,15470016.0,15281024.0,15312512.0,15344000.0,15375488.0,15406976.0,15438464.0,15469952.0,15501440.0,15312000.0,15343552.0,15375104.0,15406656.
0,15438208.0,15469760.0,15501312.0,15532864.0,15342976.0,15374592.0,15406208.0,15437824.0,15469440.0,15501056.0,15532672.0,15564288.0,15373952.0,15405632.0,15437312.0,15468992.0,15500672.0,15532352.0,15564032.0,15595712.0,15404928.0,15436672.0,15468416.0,15500160.0,15531904.0,15563648.0,15595392.0,15627136.0,15435904.0,15467712.0,15499520.0,15531328.0,15563136.0,15594944.0,15626752.0,15658560.0,15466880.0,15498752.0,15530624.0,15562496.0,15594368.0,15626240.0,15658112.0,15689984.0,15497856.0,15529792.0,15561728.0,15593664.0,15625600.0,15657536.0,15689472.0,15721408.0,15528832.0,15560832.0,15592832.0,15624832.0,15656832.0,15688832.0,15720832.0,15752832.0,15559808.0,15591872.0,15623936.0,15656000.0,15688064.0,15720128.0,15752192.0,15784256.0,15590784.0,15622912.0,15655040.0,15687168.0,15719296.0,15751424.0,15783552.0,15815680.0,15621760.0,15653952.0,15686144.0,15718336.0,15750528.0,15782720.0,15814912.0,15847104.0,15652736.0,15684992.0,15717248.0,15749504.0,15781760.0,15814016.0,15846272.0,15878528.0,15683712.0,15716032.0,15748352.0,15780672.0,15812992.0,15845312.0,15877632.0,15909952.0,15714688.0,15747072.0,15779456.0,15811840.0,15844224.0,15876608.0,15908992.0,15941376.0,15745664.0,15778112.0,15810560.0,15843008.0,15875456.0,15907904.0,15940352.0,15972800.0,15776640.0,15809152.0,15841664.0,15874176.0,15906688.0,15939200.0,15971712.0,16004224.0,15807616.0,15840192.0,15872768.0,15905344.0,15937920.0,15970496.0,16003072.0,16035648.0,15838592.0,15871232.0,15903872.0,15936512.0,15969152.0,16001792.0,16034432.0,16067072.0,15869568.0,15902272.0,15934976.0,15967680.0,16000384.0,16033088.0,16065792.0,16098496.0,15900544.0,15933312.0,15966080.0,15998848.0,16031616.0,16064384.0,16097152.0,16129920.0,15931520.0,15964352.0,15997184.0,16030016.0,16062848.0,16095680.0,16128512.0,16161344.0,15962496.0,15995392.0,16028288.0,16061184.0,16094080.0,16126976.0,16159872.0,16192768.0,15993472.0,16026432.0,16059392.0,16092352.0,16125312.0,16158272.0,16191232.0,16224192.0,16024448.0,1605747
2.0,16090496.0,16123520.0,16156544.0,16189568.0,16222592.0,16255616.0,16055424.0,16088512.0,16121600.0,16154688.0,16187776.0,16220864.0,16253952.0,16287040.0,16086400.0,16119552.0,16152704.0,16185856.0,16219008.0,16252160.0,16285312.0,16318464.0,16117376.0,16150592.0,16183808.0,16217024.0,16250240.0,16283456.0,16316672.0,16349888.0,16148352.0,16181632.0,16214912.0,16248192.0,16281472.0,16314752.0,16348032.0,16381312.0,16179328.0,16212672.0,16246016.0,16279360.0,16312704.0,16346048.0,16379392.0,16412736.0,16210304.0,16243712.0,16277120.0,16310528.0,16343936.0,16377344.0,16410752.0,16444160.0,16241280.0,16274752.0,16308224.0,16341696.0,16375168.0,16408640.0,16442112.0,16475584.0,16272256.0,16305792.0,16339328.0,16372864.0,16406400.0,16439936.0,16473472.0,16507008.0,16303232.0,16336832.0,16370432.0,16404032.0,16437632.0,16471232.0,16504832.0,16538432.0,16334208.0,16367872.0,16401536.0,16435200.0,16468864.0,16502528.0,16536192.0,16569856.0,16365184.0,16398912.0,16432640.0,16466368.0,16500096.0,16533824.0,16567552.0,16601280.0,15642496.0,15675264.0,15708032.0,15740800.0,15773568.0,15806336.0,15839104.0,15871872.0,14911616.0,14943424.0,14975232.0,15007040.0,15038848.0,15070656.0,15102464.0,15134272.0,7527296.0,7543216.0,7559136.0,7575056.0,7590976.0,7606896.0,7622816.0,7638736.0,8115456.0,8131248.0,8147040.0,8162832.0,8178624.0,8194416.0,8210208.0,8226000.0,16373120.0,16404736.0,16436352.0,16467968.0,16499584.0,16531200.0,16562816.0,16594432.0,16405120.0,16436800.0,16468480.0,16500160.0,16531840.0,16563520.0,16595200.0,16626880.0,16437120.0,16468864.0,16500608.0,16532352.0,16564096.0,16595840.0,16627584.0,16659328.0,16469120.0,16500928.0,16532736.0,16564544.0,16596352.0,16628160.0,16659968.0,16691776.0,16501120.0,16532992.0,16564864.0,16596736.0,16628608.0,16660480.0,16692352.0,16724224.0,16533120.0,16565056.0,16596992.0,16628928.0,16660864.0,16692800.0,16724736.0,16756672.0,16565120.0,16597120.0,16629120.0,16661120.0,16693120.0,16725120.0,16757120.0,16789120.0,16597120.0
,16629184.0,16661248.0,16693312.0,16725376.0,16757440.0,16789504.0,16821568.0,16629120.0,16661248.0,16693376.0,16725504.0,16757632.0,16789760.0,16821888.0,16854016.0,16661120.0,16693312.0,16725504.0,16757696.0,16789888.0,16822080.0,16854272.0,16886464.0,16693120.0,16725376.0,16757632.0,16789888.0,16822144.0,16854400.0,16886656.0,16918912.0,16725120.0,16757440.0,16789760.0,16822080.0,16854400.0,16886720.0,16919040.0,16951360.0,16757120.0,16789504.0,16821888.0,16854272.0,16886656.0,16919040.0,16951424.0,16983808.0,16789120.0,16821568.0,16854016.0,16886464.0,16918912.0,16951360.0,16983808.0,17016256.0,16821120.0,16853632.0,16886144.0,16918656.0,16951168.0,16983680.0,17016192.0,17048704.0,16853120.0,16885696.0,16918272.0,16950848.0,16983424.0,17016000.0,17048576.0,17081152.0,16885120.0,16917760.0,16950400.0,16983040.0,17015680.0,17048320.0,17080960.0,17113600.0,16917120.0,16949824.0,16982528.0,17015232.0,17047936.0,17080640.0,17113344.0,17146048.0,16949120.0,16981888.0,17014656.0,17047424.0,17080192.0,17112960.0,17145728.0,17178496.0,16981120.0,17013952.0,17046784.0,17079616.0,17112448.0,17145280.0,17178112.0,17210944.0,17013120.0,17046016.0,17078912.0,17111808.0,17144704.0,17177600.0,17210496.0,17243392.0,17045120.0,17078080.0,17111040.0,17144000.0,17176960.0,17209920.0,17242880.0,17275840.0,17077120.0,17110144.0,17143168.0,17176192.0,17209216.0,17242240.0,17275264.0,17308288.0,17109120.0,17142208.0,17175296.0,17208384.0,17241472.0,17274560.0,17307648.0,17340736.0,17141120.0,17174272.0,17207424.0,17240576.0,17273728.0,17306880.0,17340032.0,17373184.0,17173120.0,17206336.0,17239552.0,17272768.0,17305984.0,17339200.0,17372416.0,17405632.0,17205120.0,17238400.0,17271680.0,17304960.0,17338240.0,17371520.0,17404800.0,17438080.0,17237120.0,17270464.0,17303808.0,17337152.0,17370496.0,17403840.0,17437184.0,17470528.0,17269120.0,17302528.0,17335936.0,17369344.0,17402752.0,17436160.0,17469568.0,17502976.0,17301120.0,17334592.0,17368064.0,17401536.0,17435008.0,17468480.0,17501952
.0,17535424.0,17333120.0,17366656.0,17400192.0,17433728.0,17467264.0,17500800.0,17534336.0,17567872.0,17365120.0,17398720.0,17432320.0,17465920.0,17499520.0,17533120.0,17566720.0,17600320.0,17397120.0,17430784.0,17464448.0,17498112.0,17531776.0,17565440.0,17599104.0,17632768.0,17429120.0,17462848.0,17496576.0,17530304.0,17564032.0,17597760.0,17631488.0,17665216.0,17461120.0,17494912.0,17528704.0,17562496.0,17596288.0,17630080.0,17663872.0,17697664.0,17493120.0,17526976.0,17560832.0,17594688.0,17628544.0,17662400.0,17696256.0,17730112.0,16755072.0,16787968.0,16820864.0,16853760.0,16886656.0,16919552.0,16952448.0,16985344.0,16008832.0,16040768.0,16072704.0,16104640.0,16136576.0,16168512.0,16200448.0,16232384.0,15270784.0,15301760.0,15332736.0,15363712.0,15394688.0,15425664.0,15456640.0,15487616.0,14524544.0,14554560.0,14584576.0,14614592.0,14644608.0,14674624.0,14704640.0,14734656.0,14556544.0,14586624.0,14616704.0,14646784.0,14676864.0,14706944.0,14737024.0,14767104.0,14588544.0,14618688.0,14648832.0,14678976.0,14709120.0,14739264.0,14769408.0,14799552.0,14620544.0,14650752.0,14680960.0,14711168.0,14741376.0,14771584.0,14801792.0,14832000.0,14652544.0,14682816.0,14713088.0,14743360.0,14773632.0,14803904.0,14834176.0,14864448.0,14684544.0,14714880.0,14745216.0,14775552.0,14805888.0,14836224.0,14866560.0,14896896.0,14716544.0,14746944.0,14777344.0,14807744.0,14838144.0,14868544.0,14898944.0,14929344.0,14748544.0,14779008.0,14809472.0,14839936.0,14870400.0,14900864.0,14931328.0,14961792.0,14780544.0,14811072.0,14841600.0,14872128.0,14902656.0,14933184.0,14963712.0,14994240.0,14812544.0,14843136.0,14873728.0,14904320.0,14934912.0,14965504.0,14996096.0,15026688.0,14844544.0,14875200.0,14905856.0,14936512.0,14967168.0,14997824.0,15028480.0,15059136.0,14876544.0,14907264.0,14937984.0,14968704.0,14999424.0,15030144.0,15060864.0,15091584.0,14908544.0,14939328.0,14970112.0,15000896.0,15031680.0,15062464.0,15093248.0,15124032.0,14940544.0,14971392.0,15002240.0,15033088.0,150639
36.0,15094784.0,15125632.0,15156480.0,14972544.0,15003456.0,15034368.0,15065280.0,15096192.0,15127104.0,15158016.0,15188928.0,15004544.0,15035520.0,15066496.0,15097472.0,15128448.0,15159424.0,15190400.0,15221376.0,15036544.0,15067584.0,15098624.0,15129664.0,15160704.0,15191744.0,15222784.0,15253824.0,15068544.0,15099648.0,15130752.0,15161856.0,15192960.0,15224064.0,15255168.0,15286272.0,15100544.0,15131712.0,15162880.0,15194048.0,15225216.0,15256384.0,15287552.0,15318720.0,15132544.0,15163776.0,15195008.0,15226240.0,15257472.0,15288704.0,15319936.0,15351168.0,15164544.0,15195840.0,15227136.0,15258432.0,15289728.0,15321024.0,15352320.0,15383616.0,15196544.0,15227904.0,15259264.0,15290624.0,15321984.0,15353344.0,15384704.0,15416064.0,15228544.0,15259968.0,15291392.0,15322816.0,15354240.0,15385664.0,15417088.0,15448512.0,15260544.0,15292032.0,15323520.0,15355008.0,15386496.0,15417984.0,15449472.0,15480960.0,15292544.0,15324096.0,15355648.0,15387200.0,15418752.0,15450304.0,15481856.0,15513408.0,15324544.0,15356160.0,15387776.0,15419392.0,15451008.0,15482624.0,15514240.0,15545856.0,15356544.0,15388224.0,15419904.0,15451584.0,15483264.0,15514944.0,15546624.0,15578304.0,15388544.0,15420288.0,15452032.0,15483776.0,15515520.0,15547264.0,15579008.0,15610752.0,15420544.0,15452352.0,15484160.0,15515968.0,15547776.0,15579584.0,15611392.0,15643200.0,15452544.0,15484416.0,15516288.0,15548160.0,15580032.0,15611904.0,15643776.0,15675648.0,15484544.0,15516480.0,15548416.0,15580352.0,15612288.0,15644224.0,15676160.0,15708096.0,15516544.0,15548544.0,15580544.0,15612544.0,15644544.0,15676544.0,15708544.0,15740544.0,15548544.0,15580608.0,15612672.0,15644736.0,15676800.0,15708864.0,15740928.0,15772992.0,15580544.0,15612672.0,15644800.0,15676928.0,15709056.0,15741184.0,15773312.0,15805440.0,15612544.0,15644736.0,15676928.0,15709120.0,15741312.0,15773504.0,15805696.0,15837888.0,15644544.0,15676800.0,15709056.0,15741312.0,15773568.0,15805824.0,15838080.0,15870336.0,15676544.0,15708864.0,1574
1184.0,15773504.0,15805824.0,15838144.0,15870464.0,15902784.0,15708544.0,15740928.0,15773312.0,15805696.0,15838080.0,15870464.0,15902848.0,15935232.0,15494784.0,15526208.0,15557632.0,15589056.0,15620480.0,15651904.0,15683328.0,15714752.0,15272832.0,15303296.0,15333760.0,15364224.0,15394688.0,15425152.0,15455616.0,15486080.0,15304832.0,15335360.0,15365888.0,15396416.0,15426944.0,15457472.0,15488000.0,15518528.0,15336832.0,15367424.0,15398016.0,15428608.0,15459200.0,15489792.0,15520384.0,15550976.0,15368832.0,15399488.0,15430144.0,15460800.0,15491456.0,15522112.0,15552768.0,15583424.0,15400832.0,15431552.0,15462272.0,15492992.0,15523712.0,15554432.0,15585152.0,15615872.0,15432832.0,15463616.0,15494400.0,15525184.0,15555968.0,15586752.0,15617536.0,15648320.0,15464832.0,15495680.0,15526528.0,15557376.0,15588224.0,15619072.0,15649920.0,15680768.0,15496832.0,15527744.0,15558656.0,15589568.0,15620480.0,15651392.0,15682304.0,15713216.0,15528832.0,15559808.0,15590784.0,15621760.0,15652736.0,15683712.0,15714688.0,15745664.0,15560832.0,15591872.0,15622912.0,15653952.0,15684992.0,15716032.0,15747072.0,15778112.0,15592832.0,15623936.0,15655040.0,15686144.0,15717248.0,15748352.0,15779456.0,15810560.0,15624832.0,15656000.0,15687168.0,15718336.0,15749504.0,15780672.0,15811840.0,15843008.0,15656832.0,15688064.0,15719296.0,15750528.0,15781760.0,15812992.0,15844224.0,15875456.0,15688832.0,15720128.0,15751424.0,15782720.0,15814016.0,15845312.0,15876608.0,15907904.0,15720832.0,15752192.0,15783552.0,15814912.0,15846272.0,15877632.0,15908992.0,15940352.0,15752832.0,15784256.0,15815680.0,15847104.0,15878528.0,15909952.0,15941376.0,15972800.0,15784832.0,15816320.0,15847808.0,15879296.0,15910784.0,15942272.0,15973760.0,16005248.0,15816832.0,15848384.0,15879936.0,15911488.0,15943040.0,15974592.0,16006144.0,16037696.0,15848832.0,15880448.0,15912064.0,15943680.0,15975296.0,16006912.0,16038528.0,16070144.0,15880832.0,15912512.0,15944192.0,15975872.0,16007552.0,16039232.0,16070912.0,16102592.0,15
912832.0,15944576.0,15976320.0,16008064.0,16039808.0,16071552.0,16103296.0,16135040.0,15944832.0,15976640.0,16008448.0,16040256.0,16072064.0,16103872.0,16135680.0,16167488.0,15976832.0,16008704.0,16040576.0,16072448.0,16104320.0,16136192.0,16168064.0,16199936.0,16008832.0,16040768.0,16072704.0,16104640.0,16136576.0,16168512.0,16200448.0,16232384.0,16040832.0,16072832.0,16104832.0,16136832.0,16168832.0,16200832.0,16232832.0,16264832.0,16072832.0,16104896.0,16136960.0,16169024.0,16201088.0,16233152.0,16265216.0,16297280.0,16104832.0,16136960.0,16169088.0,16201216.0,16233344.0,16265472.0,16297600.0,16329728.0,16136832.0,16169024.0,16201216.0,16233408.0,16265600.0,16297792.0,16329984.0,16362176.0,16168832.0,16201088.0,16233344.0,16265600.0,16297856.0,16330112.0,16362368.0,16394624.0,16200832.0,16233152.0,16265472.0,16297792.0,16330112.0,16362432.0,16394752.0,16427072.0,16232832.0,16265216.0,16297600.0,16329984.0,16362368.0,16394752.0,16427136.0,16459520.0,16264832.0,16297280.0,16329728.0,16362176.0,16394624.0,16427072.0,16459520.0,16491968.0,16296832.0,16329344.0,16361856.0,16394368.0,16426880.0,16459392.0,16491904.0,16524416.0,16328832.0,16361408.0,16393984.0,16426560.0,16459136.0,16491712.0,16524288.0,16556864.0,16360832.0,16393472.0,16426112.0,16458752.0,16491392.0,16524032.0,16556672.0,16589312.0,16392832.0,16425536.0,16458240.0,16490944.0,16523648.0,16556352.0,16589056.0,16621760.0,16424832.0,16457600.0,16490368.0,16523136.0,16555904.0,16588672.0,16621440.0,16654208.0,16456832.0,16489664.0,16522496.0,16555328.0,16588160.0,16620992.0,16653824.0,16686656.0,15718784.0,15750656.0,15782528.0,15814400.0,15846272.0,15878144.0,15910016.0,15941888.0,14972544.0,15003456.0,15034368.0,15065280.0,15096192.0,15127104.0,15158016.0,15188928.0,15004544.0,15035520.0,15066496.0,15097472.0,15128448.0,15159424.0,15190400.0,15221376.0,15036544.0,15067584.0,15098624.0,15129664.0,15160704.0,15191744.0,15222784.0,15253824.0,15068544.0,15099648.0,15130752.0,15161856.0,15192960.0,15224064.0,
15255168.0,15286272.0,15100544.0,15131712.0,15162880.0,15194048.0,15225216.0,15256384.0,15287552.0,15318720.0,15132544.0,15163776.0,15195008.0,15226240.0,15257472.0,15288704.0,15319936.0,15351168.0,15164544.0,15195840.0,15227136.0,15258432.0,15289728.0,15321024.0,15352320.0,15383616.0,15196544.0,15227904.0,15259264.0,15290624.0,15321984.0,15353344.0,15384704.0,15416064.0,15228544.0,15259968.0,15291392.0,15322816.0,15354240.0,15385664.0,15417088.0,15448512.0,15260544.0,15292032.0,15323520.0,15355008.0,15386496.0,15417984.0,15449472.0,15480960.0,15292544.0,15324096.0,15355648.0,15387200.0,15418752.0,15450304.0,15481856.0,15513408.0,15324544.0,15356160.0,15387776.0,15419392.0,15451008.0,15482624.0,15514240.0,15545856.0,15356544.0,15388224.0,15419904.0,15451584.0,15483264.0,15514944.0,15546624.0,15578304.0,15388544.0,15420288.0,15452032.0,15483776.0,15515520.0,15547264.0,15579008.0,15610752.0,15420544.0,15452352.0,15484160.0,15515968.0,15547776.0,15579584.0,15611392.0,15643200.0,15452544.0,15484416.0,15516288.0,15548160.0,15580032.0,15611904.0,15643776.0,15675648.0,15484544.0,15516480.0,15548416.0,15580352.0,15612288.0,15644224.0,15676160.0,15708096.0,15516544.0,15548544.0,15580544.0,15612544.0,15644544.0,15676544.0,15708544.0,15740544.0,15548544.0,15580608.0,15612672.0,15644736.0,15676800.0,15708864.0,15740928.0,15772992.0,15580544.0,15612672.0,15644800.0,15676928.0,15709056.0,15741184.0,15773312.0,15805440.0,15612544.0,15644736.0,15676928.0,15709120.0,15741312.0,15773504.0,15805696.0,15837888.0,15644544.0,15676800.0,15709056.0,15741312.0,15773568.0,15805824.0,15838080.0,15870336.0,15676544.0,15708864.0,15741184.0,15773504.0,15805824.0,15838144.0,15870464.0,15902784.0,15708544.0,15740928.0,15773312.0,15805696.0,15838080.0,15870464.0,15902848.0,15935232.0,15740544.0,15772992.0,15805440.0,15837888.0,15870336.0,15902784.0,15935232.0,15967680.0,15772544.0,15805056.0,15837568.0,15870080.0,15902592.0,15935104.0,15967616.0,16000128.0,15804544.0,15837120.0,15869696.0,15902272.
0,15934848.0,15967424.0,16000000.0,16032576.0,15836544.0,15869184.0,15901824.0,15934464.0,15967104.0,15999744.0,16032384.0,16065024.0,15868544.0,15901248.0,15933952.0,15966656.0,15999360.0,16032064.0,16064768.0,16097472.0,15900544.0,15933312.0,15966080.0,15998848.0,16031616.0,16064384.0,16097152.0,16129920.0,15932544.0,15965376.0,15998208.0,16031040.0,16063872.0,16096704.0,16129536.0,16162368.0,15964544.0,15997440.0,16030336.0,16063232.0,16096128.0,16129024.0,16161920.0,16194816.0,15996544.0,16029504.0,16062464.0,16095424.0,16128384.0,16161344.0,16194304.0,16227264.0,16028544.0,16061568.0,16094592.0,16127616.0,16160640.0,16193664.0,16226688.0,16259712.0,16060544.0,16093632.0,16126720.0,16159808.0,16192896.0,16225984.0,16259072.0,16292160.0,16092544.0,16125696.0,16158848.0,16192000.0,16225152.0,16258304.0,16291456.0,16324608.0,16124544.0,16157760.0,16190976.0,16224192.0,16257408.0,16290624.0,16323840.0,16357056.0,16156544.0,16189824.0,16223104.0,16256384.0,16289664.0,16322944.0,16356224.0,16389504.0,15942784.0,15975104.0,16007424.0,16039744.0,16072064.0,16104384.0,16136704.0,16169024.0,15720832.0,15752192.0,15783552.0,15814912.0,15846272.0,15877632.0,15908992.0,15940352.0,15752832.0,15784256.0,15815680.0,15847104.0,15878528.0,15909952.0,15941376.0,15972800.0,15784832.0,15816320.0,15847808.0,15879296.0,15910784.0,15942272.0,15973760.0,16005248.0,15816832.0,15848384.0,15879936.0,15911488.0,15943040.0,15974592.0,16006144.0,16037696.0,15848832.0,15880448.0,15912064.0,15943680.0,15975296.0,16006912.0,16038528.0,16070144.0,15880832.0,15912512.0,15944192.0,15975872.0,16007552.0,16039232.0,16070912.0,16102592.0,15912832.0,15944576.0,15976320.0,16008064.0,16039808.0,16071552.0,16103296.0,16135040.0,15944832.0,15976640.0,16008448.0,16040256.0,16072064.0,16103872.0,16135680.0,16167488.0,15976832.0,16008704.0,16040576.0,16072448.0,16104320.0,16136192.0,16168064.0,16199936.0,16008832.0,16040768.0,16072704.0,16104640.0,16136576.0,16168512.0,16200448.0,16232384.0,16040832.0,1607283
2.0,16104832.0,16136832.0,16168832.0,16200832.0,16232832.0,16264832.0,16072832.0,16104896.0,16136960.0,16169024.0,16201088.0,16233152.0,16265216.0,16297280.0,16104832.0,16136960.0,16169088.0,16201216.0,16233344.0,16265472.0,16297600.0,16329728.0,16136832.0,16169024.0,16201216.0,16233408.0,16265600.0,16297792.0,16329984.0,16362176.0,16168832.0,16201088.0,16233344.0,16265600.0,16297856.0,16330112.0,16362368.0,16394624.0,16200832.0,16233152.0,16265472.0,16297792.0,16330112.0,16362432.0,16394752.0,16427072.0,16232832.0,16265216.0,16297600.0,16329984.0,16362368.0,16394752.0,16427136.0,16459520.0,16264832.0,16297280.0,16329728.0,16362176.0,16394624.0,16427072.0,16459520.0,16491968.0,16296832.0,16329344.0,16361856.0,16394368.0,16426880.0,16459392.0,16491904.0,16524416.0,16328832.0,16361408.0,16393984.0,16426560.0,16459136.0,16491712.0,16524288.0,16556864.0,16360832.0,16393472.0,16426112.0,16458752.0,16491392.0,16524032.0,16556672.0,16589312.0,16392832.0,16425536.0,16458240.0,16490944.0,16523648.0,16556352.0,16589056.0,16621760.0,16424832.0,16457600.0,16490368.0,16523136.0,16555904.0,16588672.0,16621440.0,16654208.0,16456832.0,16489664.0,16522496.0,16555328.0,16588160.0,16620992.0,16653824.0,16686656.0,16488832.0,16521728.0,16554624.0,16587520.0,16620416.0,16653312.0,16686208.0,16719104.0,16520832.0,16553792.0,16586752.0,16619712.0,16652672.0,16685632.0,16718592.0,16751552.0,16552832.0,16585856.0,16618880.0,16651904.0,16684928.0,16717952.0,16750976.0,16784000.0,16584832.0,16617920.0,16651008.0,16684096.0,16717184.0,16750272.0,16783360.0,16816448.0,16616832.0,16649984.0,16683136.0,16716288.0,16749440.0,16782592.0,16815744.0,16848896.0,16648832.0,16682048.0,16715264.0,16748480.0,16781696.0,16814912.0,16848128.0,16881344.0,16680832.0,16714112.0,16747392.0,16780672.0,16813952.0,16847232.0,16880512.0,16913792.0,16712832.0,16746176.0,16779520.0,16812864.0,16846208.0,16879552.0,16912896.0,16946240.0,16744832.0,16778240.0,16811648.0,16845056.0,16878464.0,16911872.0,16945280.0,16978
688.0,16776832.0,16810304.0,16843776.0,16877248.0,16910720.0,16944192.0,16977664.0,17011136.0,16808832.0,16842368.0,16875904.0,16909440.0,16942976.0,16976512.0,17010048.0,17043584.0,16840832.0,16874432.0,16908032.0,16941632.0,16975232.0,17008832.0,17042432.0,17076032.0,16872832.0,16906496.0,16940160.0,16973824.0,17007488.0,17041152.0,17074816.0,17108480.0,16904832.0,16938560.0,16972288.0,17006016.0,17039744.0,17073472.0,17107200.0,17140928.0,16166784.0,16199552.0,16232320.0,16265088.0,16297856.0,16330624.0,16363392.0,16396160.0,15420544.0,15452352.0,15484160.0,15515968.0,15547776.0,15579584.0,15611392.0,15643200.0,7782016.0,7797936.0,7813856.0,7829776.0,7845696.0,7861616.0,7877536.0,7893456.0,8368128.0,8383920.0,8399712.0,8415504.0,8431296.0,8447088.0,8462880.0,8478672.0,16878976.0,16910592.0,16942208.0,16973824.0,17005440.0,17037056.0,17068672.0,17100288.0,16912000.0,16943680.0,16975360.0,17007040.0,17038720.0,17070400.0,17102080.0,17133760.0,16945024.0,16976768.0,17008512.0,17040256.0,17072000.0,17103744.0,17135488.0,17167232.0,16978048.0,17009856.0,17041664.0,17073472.0,17105280.0,17137088.0,17168896.0,17200704.0,17011072.0,17042944.0,17074816.0,17106688.0,17138560.0,17170432.0,17202304.0,17234176.0,17044096.0,17076032.0,17107968.0,17139904.0,17171840.0,17203776.0,17235712.0,17267648.0,17077120.0,17109120.0,17141120.0,17173120.0,17205120.0,17237120.0,17269120.0,17301120.0,17110144.0,17142208.0,17174272.0,17206336.0,17238400.0,17270464.0,17302528.0,17334592.0,17143168.0,17175296.0,17207424.0,17239552.0,17271680.0,17303808.0,17335936.0,17368064.0,17176192.0,17208384.0,17240576.0,17272768.0,17304960.0,17337152.0,17369344.0,17401536.0,17209216.0,17241472.0,17273728.0,17305984.0,17338240.0,17370496.0,17402752.0,17435008.0,17242240.0,17274560.0,17306880.0,17339200.0,17371520.0,17403840.0,17436160.0,17468480.0,17275264.0,17307648.0,17340032.0,17372416.0,17404800.0,17437184.0,17469568.0,17501952.0,17308288.0,17340736.0,17373184.0,17405632.0,17438080.0,17470528.0,17502976
.0,17535424.0,17341312.0,17373824.0,17406336.0,17438848.0,17471360.0,17503872.0,17536384.0,17568896.0,17374336.0,17406912.0,17439488.0,17472064.0,17504640.0,17537216.0,17569792.0,17602368.0,17407360.0,17440000.0,17472640.0,17505280.0,17537920.0,17570560.0,17603200.0,17635840.0,17440384.0,17473088.0,17505792.0,17538496.0,17571200.0,17603904.0,17636608.0,17669312.0,17473408.0,17506176.0,17538944.0,17571712.0,17604480.0,17637248.0,17670016.0,17702784.0,17506432.0,17539264.0,17572096.0,17604928.0,17637760.0,17670592.0,17703424.0,17736256.0,17539456.0,17572352.0,17605248.0,17638144.0,17671040.0,17703936.0,17736832.0,17769728.0,17572480.0,17605440.0,17638400.0,17671360.0,17704320.0,17737280.0,17770240.0,17803200.0,17605504.0,17638528.0,17671552.0,17704576.0,17737600.0,17770624.0,17803648.0,17836672.0,17638528.0,17671616.0,17704704.0,17737792.0,17770880.0,17803968.0,17837056.0,17870144.0,17671552.0,17704704.0,17737856.0,17771008.0,17804160.0,17837312.0,17870464.0,17903616.0,17704576.0,17737792.0,17771008.0,17804224.0,17837440.0,17870656.0,17903872.0,17937088.0,17737600.0,17770880.0,17804160.0,17837440.0,17870720.0,17904000.0,17937280.0,17970560.0,17770624.0,17803968.0,17837312.0,17870656.0,17904000.0,17937344.0,17970688.0,18004032.0,17803648.0,17837056.0,17870464.0,17903872.0,17937280.0,17970688.0,18004096.0,18037504.0,17836672.0,17870144.0,17903616.0,17937088.0,17970560.0,18004032.0,18037504.0,18070976.0,17869696.0,17903232.0,17936768.0,17970304.0,18003840.0,18037376.0,18070912.0,18104448.0,17902720.0,17936320.0,17969920.0,18003520.0,18037120.0,18070720.0,18104320.0,18137920.0,17935744.0,17969408.0,18003072.0,18036736.0,18070400.0,18104064.0,18137728.0,18171392.0,17968768.0,18002496.0,18036224.0,18069952.0,18103680.0,18137408.0,18171136.0,18204864.0,18001792.0,18035584.0,18069376.0,18103168.0,18136960.0,18170752.0,18204544.0,18238336.0,18034816.0,18068672.0,18102528.0,18136384.0,18170240.0,18204096.0,18237952.0,18271808.0,17281408.0,17314304.0,17347200.0,17380096.0,174129
92.0,17445888.0,17478784.0,17511680.0,16519808.0,16551744.0,16583680.0,16615616.0,16647552.0,16679488.0,16711424.0,16743360.0,15766400.0,15797376.0,15828352.0,15859328.0,15890304.0,15921280.0,15952256.0,15983232.0,15004800.0,15034816.0,15064832.0,15094848.0,15124864.0,15154880.0,15184896.0,15214912.0,15037824.0,15067904.0,15097984.0,15128064.0,15158144.0,15188224.0,15218304.0,15248384.0,15070848.0,15100992.0,15131136.0,15161280.0,15191424.0,15221568.0,15251712.0,15281856.0,15103872.0,15134080.0,15164288.0,15194496.0,15224704.0,15254912.0,15285120.0,15315328.0,15136896.0,15167168.0,15197440.0,15227712.0,15257984.0,15288256.0,15318528.0,15348800.0,15169920.0,15200256.0,15230592.0,15260928.0,15291264.0,15321600.0,15351936.0,15382272.0,15202944.0,15233344.0,15263744.0,15294144.0,15324544.0,15354944.0,15385344.0,15415744.0,15235968.0,15266432.0,15296896.0,15327360.0,15357824.0,15388288.0,15418752.0,15449216.0,15268992.0,15299520.0,15330048.0,15360576.0,15391104.0,15421632.0,15452160.0,15482688.0,15302016.0,15332608.0,15363200.0,15393792.0,15424384.0,15454976.0,15485568.0,15516160.0,15335040.0,15365696.0,15396352.0,15427008.0,15457664.0,15488320.0,15518976.0,15549632.0,15368064.0,15398784.0,15429504.0,15460224.0,15490944.0,15521664.0,15552384.0,15583104.0,15401088.0,15431872.0,15462656.0,15493440.0,15524224.0,15555008.0,15585792.0,15616576.0,15434112.0,15464960.0,15495808.0,15526656.0,15557504.0,15588352.0,15619200.0,15650048.0,15467136.0,15498048.0,15528960.0,15559872.0,15590784.0,15621696.0,15652608.0,15683520.0,15500160.0,15531136.0,15562112.0,15593088.0,15624064.0,15655040.0,15686016.0,15716992.0,15533184.0,15564224.0,15595264.0,15626304.0,15657344.0,15688384.0,15719424.0,15750464.0,15566208.0,15597312.0,15628416.0,15659520.0,15690624.0,15721728.0,15752832.0,15783936.0,15599232.0,15630400.0,15661568.0,15692736.0,15723904.0,15755072.0,15786240.0,15817408.0,15632256.0,15663488.0,15694720.0,15725952.0,15757184.0,15788416.0,15819648.0,15850880.0,15665280.0,15696576.0,1572
7872.0,15759168.0,15790464.0,15821760.0,15853056.0,15884352.0,15698304.0,15729664.0,15761024.0,15792384.0,15823744.0,15855104.0,15886464.0,15917824.0,15731328.0,15762752.0,15794176.0,15825600.0,15857024.0,15888448.0,15919872.0,15951296.0,15764352.0,15795840.0,15827328.0,15858816.0,15890304.0,15921792.0,15953280.0,15984768.0,15797376.0,15828928.0,15860480.0,15892032.0,15923584.0,15955136.0,15986688.0,16018240.0,15830400.0,15862016.0,15893632.0,15925248.0,15956864.0,15988480.0,16020096.0,16051712.0,15863424.0,15895104.0,15926784.0,15958464.0,15990144.0,16021824.0,16053504.0,16085184.0,15896448.0,15928192.0,15959936.0,15991680.0,16023424.0,16055168.0,16086912.0,16118656.0,15929472.0,15961280.0,15993088.0,16024896.0,16056704.0,16088512.0,16120320.0,16152128.0,15962496.0,15994368.0,16026240.0,16058112.0,16089984.0,16121856.0,16153728.0,16185600.0,15995520.0,16027456.0,16059392.0,16091328.0,16123264.0,16155200.0,16187136.0,16219072.0,16028544.0,16060544.0,16092544.0,16124544.0,16156544.0,16188544.0,16220544.0,16252544.0,16061568.0,16093632.0,16125696.0,16157760.0,16189824.0,16221888.0,16253952.0,16286016.0,16094592.0,16126720.0,16158848.0,16190976.0,16223104.0,16255232.0,16287360.0,16319488.0,16127616.0,16159808.0,16192000.0,16224192.0,16256384.0,16288576.0,16320768.0,16352960.0,16160640.0,16192896.0,16225152.0,16257408.0,16289664.0,16321920.0,16354176.0,16386432.0,16193664.0,16225984.0,16258304.0,16290624.0,16322944.0,16355264.0,16387584.0,16419904.0,16226688.0,16259072.0,16291456.0,16323840.0,16356224.0,16388608.0,16420992.0,16453376.0,15997568.0,16028992.0,16060416.0,16091840.0,16123264.0,16154688.0,16186112.0,16217536.0,15760256.0,15790720.0,15821184.0,15851648.0,15882112.0,15912576.0,15943040.0,15973504.0,15793280.0,15823808.0,15854336.0,15884864.0,15915392.0,15945920.0,15976448.0,16006976.0,15826304.0,15856896.0,15887488.0,15918080.0,15948672.0,15979264.0,16009856.0,16040448.0,15859328.0,15889984.0,15920640.0,15951296.0,15981952.0,16012608.0,16043264.0,16073920.0,15
892352.0,15923072.0,15953792.0,15984512.0,16015232.0,16045952.0,16076672.0,16107392.0,15925376.0,15956160.0,15986944.0,16017728.0,16048512.0,16079296.0,16110080.0,16140864.0,15958400.0,15989248.0,16020096.0,16050944.0,16081792.0,16112640.0,16143488.0,16174336.0,15991424.0,16022336.0,16053248.0,16084160.0,16115072.0,16145984.0,16176896.0,16207808.0,16024448.0,16055424.0,16086400.0,16117376.0,16148352.0,16179328.0,16210304.0,16241280.0,16057472.0,16088512.0,16119552.0,16150592.0,16181632.0,16212672.0,16243712.0,16274752.0,16090496.0,16121600.0,16152704.0,16183808.0,16214912.0,16246016.0,16277120.0,16308224.0,16123520.0,16154688.0,16185856.0,16217024.0,16248192.0,16279360.0,16310528.0,16341696.0,16156544.0,16187776.0,16219008.0,16250240.0,16281472.0,16312704.0,16343936.0,16375168.0,16189568.0,16220864.0,16252160.0,16283456.0,16314752.0,16346048.0,16377344.0,16408640.0,16222592.0,16253952.0,16285312.0,16316672.0,16348032.0,16379392.0,16410752.0,16442112.0,16255616.0,16287040.0,16318464.0,16349888.0,16381312.0,16412736.0,16444160.0,16475584.0,16288640.0,16320128.0,16351616.0,16383104.0,16414592.0,16446080.0,16477568.0,16509056.0,16321664.0,16353216.0,16384768.0,16416320.0,16447872.0,16479424.0,16510976.0,16542528.0,16354688.0,16386304.0,16417920.0,16449536.0,16481152.0,16512768.0,16544384.0,16576000.0,16387712.0,16419392.0,16451072.0,16482752.0,16514432.0,16546112.0,16577792.0,16609472.0,16420736.0,16452480.0,16484224.0,16515968.0,16547712.0,16579456.0,16611200.0,16642944.0,16453760.0,16485568.0,16517376.0,16549184.0,16580992.0,16612800.0,16644608.0,16676416.0,16486784.0,16518656.0,16550528.0,16582400.0,16614272.0,16646144.0,16678016.0,16709888.0,16519808.0,16551744.0,16583680.0,16615616.0,16647552.0,16679488.0,16711424.0,16743360.0,16552832.0,16584832.0,16616832.0,16648832.0,16680832.0,16712832.0,16744832.0,16776832.0,16585856.0,16617920.0,16649984.0,16682048.0,16714112.0,16746176.0,16778240.0,16810304.0,16618880.0,16651008.0,16683136.0,16715264.0,16747392.0,16779520.0,
16811648.0,16843776.0,16651904.0,16684096.0,16716288.0,16748480.0,16780672.0,16812864.0,16845056.0,16877248.0,16684928.0,16717184.0,16749440.0,16781696.0,16813952.0,16846208.0,16878464.0,16910720.0,16717952.0,16750272.0,16782592.0,16814912.0,16847232.0,16879552.0,16911872.0,16944192.0,16750976.0,16783360.0,16815744.0,16848128.0,16880512.0,16912896.0,16945280.0,16977664.0,16784000.0,16816448.0,16848896.0,16881344.0,16913792.0,16946240.0,16978688.0,17011136.0,16817024.0,16849536.0,16882048.0,16914560.0,16947072.0,16979584.0,17012096.0,17044608.0,16850048.0,16882624.0,16915200.0,16947776.0,16980352.0,17012928.0,17045504.0,17078080.0,16883072.0,16915712.0,16948352.0,16980992.0,17013632.0,17046272.0,17078912.0,17111552.0,16916096.0,16948800.0,16981504.0,17014208.0,17046912.0,17079616.0,17112320.0,17145024.0,16949120.0,16981888.0,17014656.0,17047424.0,17080192.0,17112960.0,17145728.0,17178496.0,16982144.0,17014976.0,17047808.0,17080640.0,17113472.0,17146304.0,17179136.0,17211968.0,16228736.0,16260608.0,16292480.0,16324352.0,16356224.0,16388096.0,16419968.0,16451840.0,15467136.0,15498048.0,15528960.0,15559872.0,15590784.0,15621696.0,15652608.0,15683520.0,15500160.0,15531136.0,15562112.0,15593088.0,15624064.0,15655040.0,15686016.0,15716992.0,15533184.0,15564224.0,15595264.0,15626304.0,15657344.0,15688384.0,15719424.0,15750464.0,15566208.0,15597312.0,15628416.0,15659520.0,15690624.0,15721728.0,15752832.0,15783936.0,15599232.0,15630400.0,15661568.0,15692736.0,15723904.0,15755072.0,15786240.0,15817408.0,15632256.0,15663488.0,15694720.0,15725952.0,15757184.0,15788416.0,15819648.0,15850880.0,15665280.0,15696576.0,15727872.0,15759168.0,15790464.0,15821760.0,15853056.0,15884352.0,15698304.0,15729664.0,15761024.0,15792384.0,15823744.0,15855104.0,15886464.0,15917824.0,15731328.0,15762752.0,15794176.0,15825600.0,15857024.0,15888448.0,15919872.0,15951296.0,15764352.0,15795840.0,15827328.0,15858816.0,15890304.0,15921792.0,15953280.0,15984768.0,15797376.0,15828928.0,15860480.0,15892032.
0,15923584.0,15955136.0,15986688.0,16018240.0,15830400.0,15862016.0,15893632.0,15925248.0,15956864.0,15988480.0,16020096.0,16051712.0,15863424.0,15895104.0,15926784.0,15958464.0,15990144.0,16021824.0,16053504.0,16085184.0,15896448.0,15928192.0,15959936.0,15991680.0,16023424.0,16055168.0,16086912.0,16118656.0,15929472.0,15961280.0,15993088.0,16024896.0,16056704.0,16088512.0,16120320.0,16152128.0,15962496.0,15994368.0,16026240.0,16058112.0,16089984.0,16121856.0,16153728.0,16185600.0,15995520.0,16027456.0,16059392.0,16091328.0,16123264.0,16155200.0,16187136.0,16219072.0,16028544.0,16060544.0,16092544.0,16124544.0,16156544.0,16188544.0,16220544.0,16252544.0,16061568.0,16093632.0,16125696.0,16157760.0,16189824.0,16221888.0,16253952.0,16286016.0,16094592.0,16126720.0,16158848.0,16190976.0,16223104.0,16255232.0,16287360.0,16319488.0,16127616.0,16159808.0,16192000.0,16224192.0,16256384.0,16288576.0,16320768.0,16352960.0,16160640.0,16192896.0,16225152.0,16257408.0,16289664.0,16321920.0,16354176.0,16386432.0,16193664.0,16225984.0,16258304.0,16290624.0,16322944.0,16355264.0,16387584.0,16419904.0,16226688.0,16259072.0,16291456.0,16323840.0,16356224.0,16388608.0,16420992.0,16453376.0,16259712.0,16292160.0,16324608.0,16357056.0,16389504.0,16421952.0,16454400.0,16486848.0,16292736.0,16325248.0,16357760.0,16390272.0,16422784.0,16455296.0,16487808.0,16520320.0,16325760.0,16358336.0,16390912.0,16423488.0,16456064.0,16488640.0,16521216.0,16553792.0,16358784.0,16391424.0,16424064.0,16456704.0,16489344.0,16521984.0,16554624.0,16587264.0,16391808.0,16424512.0,16457216.0,16489920.0,16522624.0,16555328.0,16588032.0,16620736.0,16424832.0,16457600.0,16490368.0,16523136.0,16555904.0,16588672.0,16621440.0,16654208.0,16457856.0,16490688.0,16523520.0,16556352.0,16589184.0,16622016.0,16654848.0,16687680.0,16490880.0,16523776.0,16556672.0,16589568.0,16622464.0,16655360.0,16688256.0,16721152.0,16523904.0,16556864.0,16589824.0,16622784.0,16655744.0,16688704.0,16721664.0,16754624.0,16556928.0,1658995
2.0,16622976.0,16656000.0,16689024.0,16722048.0,16755072.0,16788096.0,16589952.0,16623040.0,16656128.0,16689216.0,16722304.0,16755392.0,16788480.0,16821568.0,16622976.0,16656128.0,16689280.0,16722432.0,16755584.0,16788736.0,16821888.0,16855040.0,16656000.0,16689216.0,16722432.0,16755648.0,16788864.0,16822080.0,16855296.0,16888512.0,16689024.0,16722304.0,16755584.0,16788864.0,16822144.0,16855424.0,16888704.0,16921984.0,16459904.0,16492224.0,16524544.0,16556864.0,16589184.0,16621504.0,16653824.0,16686144.0,16222592.0,16253952.0,16285312.0,16316672.0,16348032.0,16379392.0,16410752.0,16442112.0,16255616.0,16287040.0,16318464.0,16349888.0,16381312.0,16412736.0,16444160.0,16475584.0,16288640.0,16320128.0,16351616.0,16383104.0,16414592.0,16446080.0,16477568.0,16509056.0,16321664.0,16353216.0,16384768.0,16416320.0,16447872.0,16479424.0,16510976.0,16542528.0,16354688.0,16386304.0,16417920.0,16449536.0,16481152.0,16512768.0,16544384.0,16576000.0,16387712.0,16419392.0,16451072.0,16482752.0,16514432.0,16546112.0,16577792.0,16609472.0,16420736.0,16452480.0,16484224.0,16515968.0,16547712.0,16579456.0,16611200.0,16642944.0,16453760.0,16485568.0,16517376.0,16549184.0,16580992.0,16612800.0,16644608.0,16676416.0,16486784.0,16518656.0,16550528.0,16582400.0,16614272.0,16646144.0,16678016.0,16709888.0,16519808.0,16551744.0,16583680.0,16615616.0,16647552.0,16679488.0,16711424.0,16743360.0,16552832.0,16584832.0,16616832.0,16648832.0,16680832.0,16712832.0,16744832.0,16776832.0,16585856.0,16617920.0,16649984.0,16682048.0,16714112.0,16746176.0,16778240.0,16810304.0,16618880.0,16651008.0,16683136.0,16715264.0,16747392.0,16779520.0,16811648.0,16843776.0,16651904.0,16684096.0,16716288.0,16748480.0,16780672.0,16812864.0,16845056.0,16877248.0,16684928.0,16717184.0,16749440.0,16781696.0,16813952.0,16846208.0,16878464.0,16910720.0,16717952.0,16750272.0,16782592.0,16814912.0,16847232.0,16879552.0,16911872.0,16944192.0,16750976.0,16783360.0,16815744.0,16848128.0,16880512.0,16912896.0,16945280.0,16977
664.0,16784000.0,16816448.0,16848896.0,16881344.0,16913792.0,16946240.0,16978688.0,17011136.0,16817024.0,16849536.0,16882048.0,16914560.0,16947072.0,16979584.0,17012096.0,17044608.0,16850048.0,16882624.0,16915200.0,16947776.0,16980352.0,17012928.0,17045504.0,17078080.0,16883072.0,16915712.0,16948352.0,16980992.0,17013632.0,17046272.0,17078912.0,17111552.0,16916096.0,16948800.0,16981504.0,17014208.0,17046912.0,17079616.0,17112320.0,17145024.0,16949120.0,16981888.0,17014656.0,17047424.0,17080192.0,17112960.0,17145728.0,17178496.0,16982144.0,17014976.0,17047808.0,17080640.0,17113472.0,17146304.0,17179136.0,17211968.0,17015168.0,17048064.0,17080960.0,17113856.0,17146752.0,17179648.0,17212544.0,17245440.0,17048192.0,17081152.0,17114112.0,17147072.0,17180032.0,17212992.0,17245952.0,17278912.0,17081216.0,17114240.0,17147264.0,17180288.0,17213312.0,17246336.0,17279360.0,17312384.0,17114240.0,17147328.0,17180416.0,17213504.0,17246592.0,17279680.0,17312768.0,17345856.0,17147264.0,17180416.0,17213568.0,17246720.0,17279872.0,17313024.0,17346176.0,17379328.0,17180288.0,17213504.0,17246720.0,17279936.0,17313152.0,17346368.0,17379584.0,17412800.0,17213312.0,17246592.0,17279872.0,17313152.0,17346432.0,17379712.0,17412992.0,17446272.0,17246336.0,17279680.0,17313024.0,17346368.0,17379712.0,17413056.0,17446400.0,17479744.0,17279360.0,17312768.0,17346176.0,17379584.0,17412992.0,17446400.0,17479808.0,17513216.0,17312384.0,17345856.0,17379328.0,17412800.0,17446272.0,17479744.0,17513216.0,17546688.0,17345408.0,17378944.0,17412480.0,17446016.0,17479552.0,17513088.0,17546624.0,17580160.0,17378432.0,17412032.0,17445632.0,17479232.0,17512832.0,17546432.0,17580032.0,17613632.0,17411456.0,17445120.0,17478784.0,17512448.0,17546112.0,17579776.0,17613440.0,17647104.0,17444480.0,17478208.0,17511936.0,17545664.0,17579392.0,17613120.0,17646848.0,17680576.0,16691072.0,16723840.0,16756608.0,16789376.0,16822144.0,16854912.0,16887680.0,16920448.0,15929472.0,15961280.0,15993088.0,16024896.0,16056704.0,160
88512.0,16120320.0,16152128.0,8036736.0,8052656.0,8068576.0,8084496.0,8100416.0,8116336.0,8132256.0,8148176.0,8620800.0,8636592.0,8652384.0,8668176.0,8683968.0,8699760.0,8715552.0,8731344.0,17384832.0,17416448.0,17448064.0,17479680.0,17511296.0,17542912.0,17574528.0,17606144.0,17418880.0,17450560.0,17482240.0,17513920.0,17545600.0,17577280.0,17608960.0,17640640.0,17452928.0,17484672.0,17516416.0,17548160.0,17579904.0,17611648.0,17643392.0,17675136.0,17486976.0,17518784.0,17550592.0,17582400.0,17614208.0,17646016.0,17677824.0,17709632.0,17521024.0,17552896.0,17584768.0,17616640.0,17648512.0,17680384.0,17712256.0,17744128.0,17555072.0,17587008.0,17618944.0,17650880.0,17682816.0,17714752.0,17746688.0,17778624.0,17589120.0,17621120.0,17653120.0,17685120.0,17717120.0,17749120.0,17781120.0,17813120.0,17623168.0,17655232.0,17687296.0,17719360.0,17751424.0,17783488.0,17815552.0,17847616.0,17657216.0,17689344.0,17721472.0,17753600.0,17785728.0,17817856.0,17849984.0,17882112.0,17691264.0,17723456.0,17755648.0,17787840.0,17820032.0,17852224.0,17884416.0,17916608.0,17725312.0,17757568.0,17789824.0,17822080.0,17854336.0,17886592.0,17918848.0,17951104.0,17759360.0,17791680.0,17824000.0,17856320.0,17888640.0,17920960.0,17953280.0,17985600.0,17793408.0,17825792.0,17858176.0,17890560.0,17922944.0,17955328.0,17987712.0,18020096.0,17827456.0,17859904.0,17892352.0,17924800.0,17957248.0,17989696.0,18022144.0,18054592.0,17861504.0,17894016.0,17926528.0,17959040.0,17991552.0,18024064.0,18056576.0,18089088.0,17895552.0,17928128.0,17960704.0,17993280.0,18025856.0,18058432.0,18091008.0,18123584.0,17929600.0,17962240.0,17994880.0,18027520.0,18060160.0,18092800.0,18125440.0,18158080.0,17963648.0,17996352.0,18029056.0,18061760.0,18094464.0,18127168.0,18159872.0,18192576.0,17997696.0,18030464.0,18063232.0,18096000.0,18128768.0,18161536.0,18194304.0,18227072.0,18031744.0,18064576.0,18097408.0,18130240.0,18163072.0,18195904.0,18228736.0,18261568.0,18065792.0,18098688.0,18131584.0,18164480.0,181973
76.0,18230272.0,18263168.0,18296064.0,18099840.0,18132800.0,18165760.0,18198720.0,18231680.0,18264640.0,18297600.0,18330560.0,18133888.0,18166912.0,18199936.0,18232960.0,18265984.0,18299008.0,18332032.0,18365056.0,18167936.0,18201024.0,18234112.0,18267200.0,18300288.0,18333376.0,18366464.0,18399552.0,18201984.0,18235136.0,18268288.0,18301440.0,18334592.0,18367744.0,18400896.0,18434048.0,18236032.0,18269248.0,18302464.0,18335680.0,18368896.0,18402112.0,18435328.0,18468544.0,18270080.0,18303360.0,18336640.0,18369920.0,18403200.0,18436480.0,18469760.0,18503040.0,18304128.0,18337472.0,18370816.0,18404160.0,18437504.0,18470848.0,18504192.0,18537536.0,18338176.0,18371584.0,18404992.0,18438400.0,18471808.0,18505216.0,18538624.0,18572032.0,18372224.0,18405696.0,18439168.0,18472640.0,18506112.0,18539584.0,18573056.0,18606528.0,18406272.0,18439808.0,18473344.0,18506880.0,18540416.0,18573952.0,18607488.0,18641024.0,18440320.0,18473920.0,18507520.0,18541120.0,18574720.0,18608320.0,18641920.0,18675520.0,18474368.0,18508032.0,18541696.0,18575360.0,18609024.0,18642688.0,18676352.0,18710016.0,18508416.0,18542144.0,18575872.0,18609600.0,18643328.0,18677056.0,18710784.0,18744512.0,18542464.0,18576256.0,18610048.0,18643840.0,18677632.0,18711424.0,18745216.0,18779008.0,18576512.0,18610368.0,18644224.0,18678080.0,18711936.0,18745792.0,18779648.0,18813504.0,17807744.0,17840640.0,17873536.0,17906432.0,17939328.0,17972224.0,18005120.0,18038016.0,17030784.0,17062720.0,17094656.0,17126592.0,17158528.0,17190464.0,17222400.0,17254336.0,16262016.0,16292992.0,16323968.0,16354944.0,16385920.0,16416896.0,16447872.0,16478848.0,15485056.0,15515072.0,15545088.0,15575104.0,15605120.0,15635136.0,15665152.0,15695168.0,15519104.0,15549184.0,15579264.0,15609344.0,15639424.0,15669504.0,15699584.0,15729664.0,15553152.0,15583296.0,15613440.0,15643584.0,15673728.0,15703872.0,15734016.0,15764160.0,15587200.0,15617408.0,15647616.0,15677824.0,15708032.0,15738240.0,15768448.0,15798656.0,15621248.0,15651520.0,1568
1792.0,15712064.0,15742336.0,15772608.0,15802880.0,15833152.0,15655296.0,15685632.0,15715968.0,15746304.0,15776640.0,15806976.0,15837312.0,15867648.0,15689344.0,15719744.0,15750144.0,15780544.0,15810944.0,15841344.0,15871744.0,15902144.0,15723392.0,15753856.0,15784320.0,15814784.0,15845248.0,15875712.0,15906176.0,15936640.0,15757440.0,15787968.0,15818496.0,15849024.0,15879552.0,15910080.0,15940608.0,15971136.0,15791488.0,15822080.0,15852672.0,15883264.0,15913856.0,15944448.0,15975040.0,16005632.0,15825536.0,15856192.0,15886848.0,15917504.0,15948160.0,15978816.0,16009472.0,16040128.0,15859584.0,15890304.0,15921024.0,15951744.0,15982464.0,16013184.0,16043904.0,16074624.0,15893632.0,15924416.0,15955200.0,15985984.0,16016768.0,16047552.0,16078336.0,16109120.0,15927680.0,15958528.0,15989376.0,16020224.0,16051072.0,16081920.0,16112768.0,16143616.0,15961728.0,15992640.0,16023552.0,16054464.0,16085376.0,16116288.0,16147200.0,16178112.0,15995776.0,16026752.0,16057728.0,16088704.0,16119680.0,16150656.0,16181632.0,16212608.0,16029824.0,16060864.0,16091904.0,16122944.0,16153984.0,16185024.0,16216064.0,16247104.0,16063872.0,16094976.0,16126080.0,16157184.0,16188288.0,16219392.0,16250496.0,16281600.0,16097920.0,16129088.0,16160256.0,16191424.0,16222592.0,16253760.0,16284928.0,16316096.0,16131968.0,16163200.0,16194432.0,16225664.0,16256896.0,16288128.0,16319360.0,16350592.0,16166016.0,16197312.0,16228608.0,16259904.0,16291200.0,16322496.0,16353792.0,16385088.0,16200064.0,16231424.0,16262784.0,16294144.0,16325504.0,16356864.0,16388224.0,16419584.0,16234112.0,16265536.0,16296960.0,16328384.0,16359808.0,16391232.0,16422656.0,16454080.0,16268160.0,16299648.0,16331136.0,16362624.0,16394112.0,16425600.0,16457088.0,16488576.0,16302208.0,16333760.0,16365312.0,16396864.0,16428416.0,16459968.0,16491520.0,16523072.0,16336256.0,16367872.0,16399488.0,16431104.0,16462720.0,16494336.0,16525952.0,16557568.0,16370304.0,16401984.0,16433664.0,16465344.0,16497024.0,16528704.0,16560384.0,16592064.0,16
404352.0,16436096.0,16467840.0,16499584.0,16531328.0,16563072.0,16594816.0,16626560.0,16438400.0,16470208.0,16502016.0,16533824.0,16565632.0,16597440.0,16629248.0,16661056.0,16472448.0,16504320.0,16536192.0,16568064.0,16599936.0,16631808.0,16663680.0,16695552.0,16506496.0,16538432.0,16570368.0,16602304.0,16634240.0,16666176.0,16698112.0,16730048.0,16540544.0,16572544.0,16604544.0,16636544.0,16668544.0,16700544.0,16732544.0,16764544.0,16574592.0,16606656.0,16638720.0,16670784.0,16702848.0,16734912.0,16766976.0,16799040.0,16608640.0,16640768.0,16672896.0,16705024.0,16737152.0,16769280.0,16801408.0,16833536.0,16642688.0,16674880.0,16707072.0,16739264.0,16771456.0,16803648.0,16835840.0,16868032.0,16676736.0,16708992.0,16741248.0,16773504.0,16805760.0,16838016.0,16870272.0,16902528.0,16710784.0,16743104.0,16775424.0,16807744.0,16840064.0,16872384.0,16904704.0,16937024.0,16744832.0,16777216.0,16809600.0,16841984.0,16874368.0,16906752.0,16939136.0,16971520.0,16500352.0,16531776.0,16563200.0,16594624.0,16626048.0,16657472.0,16688896.0,16720320.0,16247680.0,16278144.0,16308608.0,16339072.0,16369536.0,16400000.0,16430464.0,16460928.0,16281728.0,16312256.0,16342784.0,16373312.0,16403840.0,16434368.0,16464896.0,16495424.0,16315776.0,16346368.0,16376960.0,16407552.0,16438144.0,16468736.0,16499328.0,16529920.0,16349824.0,16380480.0,16411136.0,16441792.0,16472448.0,16503104.0,16533760.0,16564416.0,16383872.0,16414592.0,16445312.0,16476032.0,16506752.0,16537472.0,16568192.0,16598912.0,16417920.0,16448704.0,16479488.0,16510272.0,16541056.0,16571840.0,16602624.0,16633408.0,16451968.0,16482816.0,16513664.0,16544512.0,16575360.0,16606208.0,16637056.0,16667904.0,16486016.0,16516928.0,16547840.0,16578752.0,16609664.0,16640576.0,16671488.0,16702400.0,16520064.0,16551040.0,16582016.0,16612992.0,16643968.0,16674944.0,16705920.0,16736896.0,16554112.0,16585152.0,16616192.0,16647232.0,16678272.0,16709312.0,16740352.0,16771392.0,16588160.0,16619264.0,16650368.0,16681472.0,16712576.0,16743680.0,
16774784.0,16805888.0,16622208.0,16653376.0,16684544.0,16715712.0,16746880.0,16778048.0,16809216.0,16840384.0,16656256.0,16687488.0,16718720.0,16749952.0,16781184.0,16812416.0,16843648.0,16874880.0,16690304.0,16721600.0,16752896.0,16784192.0,16815488.0,16846784.0,16878080.0,16909376.0,16724352.0,16755712.0,16787072.0,16818432.0,16849792.0,16881152.0,16912512.0,16943872.0,16758400.0,16789824.0,16821248.0,16852672.0,16884096.0,16915520.0,16946944.0,16978368.0,16792448.0,16823936.0,16855424.0,16886912.0,16918400.0,16949888.0,16981376.0,17012864.0,16826496.0,16858048.0,16889600.0,16921152.0,16952704.0,16984256.0,17015808.0,17047360.0,16860544.0,16892160.0,16923776.0,16955392.0,16987008.0,17018624.0,17050240.0,17081856.0,16894592.0,16926272.0,16957952.0,16989632.0,17021312.0,17052992.0,17084672.0,17116352.0,16928640.0,16960384.0,16992128.0,17023872.0,17055616.0,17087360.0,17119104.0,17150848.0,16962688.0,16994496.0,17026304.0,17058112.0,17089920.0,17121728.0,17153536.0,17185344.0,16996736.0,17028608.0,17060480.0,17092352.0,17124224.0,17156096.0,17187968.0,17219840.0,17030784.0,17062720.0,17094656.0,17126592.0,17158528.0,17190464.0,17222400.0,17254336.0,17064832.0,17096832.0,17128832.0,17160832.0,17192832.0,17224832.0,17256832.0,17288832.0,17098880.0,17130944.0,17163008.0,17195072.0,17227136.0,17259200.0,17291264.0,17323328.0,17132928.0,17165056.0,17197184.0,17229312.0,17261440.0,17293568.0,17325696.0,17357824.0,17166976.0,17199168.0,17231360.0,17263552.0,17295744.0,17327936.0,17360128.0,17392320.0,17201024.0,17233280.0,17265536.0,17297792.0,17330048.0,17362304.0,17394560.0,17426816.0,17235072.0,17267392.0,17299712.0,17332032.0,17364352.0,17396672.0,17428992.0,17461312.0,17269120.0,17301504.0,17333888.0,17366272.0,17398656.0,17431040.0,17463424.0,17495808.0,17303168.0,17335616.0,17368064.0,17400512.0,17432960.0,17465408.0,17497856.0,17530304.0,17337216.0,17369728.0,17402240.0,17434752.0,17467264.0,17499776.0,17532288.0,17564800.0,17371264.0,17403840.0,17436416.0,17468992.
0,17501568.0,17534144.0,17566720.0,17599296.0,17405312.0,17437952.0,17470592.0,17503232.0,17535872.0,17568512.0,17601152.0,17633792.0,17439360.0,17472064.0,17504768.0,17537472.0,17570176.0,17602880.0,17635584.0,17668288.0,17473408.0,17506176.0,17538944.0,17571712.0,17604480.0,17637248.0,17670016.0,17702784.0,17507456.0,17540288.0,17573120.0,17605952.0,17638784.0,17671616.0,17704448.0,17737280.0,16738688.0,16770560.0,16802432.0,16834304.0,16866176.0,16898048.0,16929920.0,16961792.0,15961728.0,15992640.0,16023552.0,16054464.0,16085376.0,16116288.0,16147200.0,16178112.0,15995776.0,16026752.0,16057728.0,16088704.0,16119680.0,16150656.0,16181632.0,16212608.0,16029824.0,16060864.0,16091904.0,16122944.0,16153984.0,16185024.0,16216064.0,16247104.0,16063872.0,16094976.0,16126080.0,16157184.0,16188288.0,16219392.0,16250496.0,16281600.0,16097920.0,16129088.0,16160256.0,16191424.0,16222592.0,16253760.0,16284928.0,16316096.0,16131968.0,16163200.0,16194432.0,16225664.0,16256896.0,16288128.0,16319360.0,16350592.0,16166016.0,16197312.0,16228608.0,16259904.0,16291200.0,16322496.0,16353792.0,16385088.0,16200064.0,16231424.0,16262784.0,16294144.0,16325504.0,16356864.0,16388224.0,16419584.0,16234112.0,16265536.0,16296960.0,16328384.0,16359808.0,16391232.0,16422656.0,16454080.0,16268160.0,16299648.0,16331136.0,16362624.0,16394112.0,16425600.0,16457088.0,16488576.0,16302208.0,16333760.0,16365312.0,16396864.0,16428416.0,16459968.0,16491520.0,16523072.0,16336256.0,16367872.0,16399488.0,16431104.0,16462720.0,16494336.0,16525952.0,16557568.0,16370304.0,16401984.0,16433664.0,16465344.0,16497024.0,16528704.0,16560384.0,16592064.0,16404352.0,16436096.0,16467840.0,16499584.0,16531328.0,16563072.0,16594816.0,16626560.0,16438400.0,16470208.0,16502016.0,16533824.0,16565632.0,16597440.0,16629248.0,16661056.0,16472448.0,16504320.0,16536192.0,16568064.0,16599936.0,16631808.0,16663680.0,16695552.0,16506496.0,16538432.0,16570368.0,16602304.0,16634240.0,16666176.0,16698112.0,16730048.0,16540544.0,1657254
4.0,16604544.0,16636544.0,16668544.0,16700544.0,16732544.0,16764544.0,16574592.0,16606656.0,16638720.0,16670784.0,16702848.0,16734912.0,16766976.0,16799040.0,16608640.0,16640768.0,16672896.0,16705024.0,16737152.0,16769280.0,16801408.0,16833536.0,16642688.0,16674880.0,16707072.0,16739264.0,16771456.0,16803648.0,16835840.0,16868032.0,16676736.0,16708992.0,16741248.0,16773504.0,16805760.0,16838016.0,16870272.0,16902528.0,16710784.0,16743104.0,16775424.0,16807744.0,16840064.0,16872384.0,16904704.0,16937024.0,16744832.0,16777216.0,16809600.0,16841984.0,16874368.0,16906752.0,16939136.0,16971520.0,16778880.0,16811328.0,16843776.0,16876224.0,16908672.0,16941120.0,16973568.0,17006016.0,16812928.0,16845440.0,16877952.0,16910464.0,16942976.0,16975488.0,17008000.0,17040512.0,16846976.0,16879552.0,16912128.0,16944704.0,16977280.0,17009856.0,17042432.0,17075008.0,16881024.0,16913664.0,16946304.0,16978944.0,17011584.0,17044224.0,17076864.0,17109504.0,16915072.0,16947776.0,16980480.0,17013184.0,17045888.0,17078592.0,17111296.0,17144000.0,16949120.0,16981888.0,17014656.0,17047424.0,17080192.0,17112960.0,17145728.0,17178496.0,16983168.0,17016000.0,17048832.0,17081664.0,17114496.0,17147328.0,17180160.0,17212992.0,17017216.0,17050112.0,17083008.0,17115904.0,17148800.0,17181696.0,17214592.0,17247488.0,17051264.0,17084224.0,17117184.0,17150144.0,17183104.0,17216064.0,17249024.0,17281984.0,17085312.0,17118336.0,17151360.0,17184384.0,17217408.0,17250432.0,17283456.0,17316480.0,17119360.0,17152448.0,17185536.0,17218624.0,17251712.0,17284800.0,17317888.0,17350976.0,17153408.0,17186560.0,17219712.0,17252864.0,17286016.0,17319168.0,17352320.0,17385472.0,17187456.0,17220672.0,17253888.0,17287104.0,17320320.0,17353536.0,17386752.0,17419968.0,17221504.0,17254784.0,17288064.0,17321344.0,17354624.0,17387904.0,17421184.0,17454464.0,16977024.0,17009344.0,17041664.0,17073984.0,17106304.0,17138624.0,17170944.0,17203264.0,16724352.0,16755712.0,16787072.0,16818432.0,16849792.0,16881152.0,16912512.0,16943
872.0,16758400.0,16789824.0,16821248.0,16852672.0,16884096.0,16915520.0,16946944.0,16978368.0,16792448.0,16823936.0,16855424.0,16886912.0,16918400.0,16949888.0,16981376.0,17012864.0,16826496.0,16858048.0,16889600.0,16921152.0,16952704.0,16984256.0,17015808.0,17047360.0,16860544.0,16892160.0,16923776.0,16955392.0,16987008.0,17018624.0,17050240.0,17081856.0,16894592.0,16926272.0,16957952.0,16989632.0,17021312.0,17052992.0,17084672.0,17116352.0,16928640.0,16960384.0,16992128.0,17023872.0,17055616.0,17087360.0,17119104.0,17150848.0,16962688.0,16994496.0,17026304.0,17058112.0,17089920.0,17121728.0,17153536.0,17185344.0,16996736.0,17028608.0,17060480.0,17092352.0,17124224.0,17156096.0,17187968.0,17219840.0,17030784.0,17062720.0,17094656.0,17126592.0,17158528.0,17190464.0,17222400.0,17254336.0,17064832.0,17096832.0,17128832.0,17160832.0,17192832.0,17224832.0,17256832.0,17288832.0,17098880.0,17130944.0,17163008.0,17195072.0,17227136.0,17259200.0,17291264.0,17323328.0,17132928.0,17165056.0,17197184.0,17229312.0,17261440.0,17293568.0,17325696.0,17357824.0,17166976.0,17199168.0,17231360.0,17263552.0,17295744.0,17327936.0,17360128.0,17392320.0,17201024.0,17233280.0,17265536.0,17297792.0,17330048.0,17362304.0,17394560.0,17426816.0,17235072.0,17267392.0,17299712.0,17332032.0,17364352.0,17396672.0,17428992.0,17461312.0,17269120.0,17301504.0,17333888.0,17366272.0,17398656.0,17431040.0,17463424.0,17495808.0,17303168.0,17335616.0,17368064.0,17400512.0,17432960.0,17465408.0,17497856.0,17530304.0,17337216.0,17369728.0,17402240.0,17434752.0,17467264.0,17499776.0,17532288.0,17564800.0,17371264.0,17403840.0,17436416.0,17468992.0,17501568.0,17534144.0,17566720.0,17599296.0,17405312.0,17437952.0,17470592.0,17503232.0,17535872.0,17568512.0,17601152.0,17633792.0,17439360.0,17472064.0,17504768.0,17537472.0,17570176.0,17602880.0,17635584.0,17668288.0,17473408.0,17506176.0,17538944.0,17571712.0,17604480.0,17637248.0,17670016.0,17702784.0,17507456.0,17540288.0,17573120.0,17605952.0,17638784.0,176
71616.0,17704448.0,17737280.0,17541504.0,17574400.0,17607296.0,17640192.0,17673088.0,17705984.0,17738880.0,17771776.0,17575552.0,17608512.0,17641472.0,17674432.0,17707392.0,17740352.0,17773312.0,17806272.0,17609600.0,17642624.0,17675648.0,17708672.0,17741696.0,17774720.0,17807744.0,17840768.0,17643648.0,17676736.0,17709824.0,17742912.0,17776000.0,17809088.0,17842176.0,17875264.0,17677696.0,17710848.0,17744000.0,17777152.0,17810304.0,17843456.0,17876608.0,17909760.0,17711744.0,17744960.0,17778176.0,17811392.0,17844608.0,17877824.0,17911040.0,17944256.0,17745792.0,17779072.0,17812352.0,17845632.0,17878912.0,17912192.0,17945472.0,17978752.0,17779840.0,17813184.0,17846528.0,17879872.0,17913216.0,17946560.0,17979904.0,18013248.0,17813888.0,17847296.0,17880704.0,17914112.0,17947520.0,17980928.0,18014336.0,18047744.0,17847936.0,17881408.0,17914880.0,17948352.0,17981824.0,18015296.0,18048768.0,18082240.0,17881984.0,17915520.0,17949056.0,17982592.0,18016128.0,18049664.0,18083200.0,18116736.0,17916032.0,17949632.0,17983232.0,18016832.0,18050432.0,18084032.0,18117632.0,18151232.0,17950080.0,17983744.0,18017408.0,18051072.0,18084736.0,18118400.0,18152064.0,18185728.0,17984128.0,18017856.0,18051584.0,18085312.0,18119040.0,18152768.0,18186496.0,18220224.0,17215360.0,17248128.0,17280896.0,17313664.0,17346432.0,17379200.0,17411968.0,17444736.0,16438400.0,16470208.0,16502016.0,16533824.0,16565632.0,16597440.0,16629248.0,16661056.0,8291456.0,8307376.0,8323296.0,8339216.0,8355136.0,8371056.0,8386976.0,8402896.0,8873472.0,8889264.0,8905056.0,8920848.0,8936640.0,8952432.0,8968224.0,8984016.0,17890688.0,17922304.0,17953920.0,17985536.0,18017152.0,18048768.0,18080384.0,18112000.0,17925760.0,17957440.0,17989120.0,18020800.0,18052480.0,18084160.0,18115840.0,18147520.0,17960832.0,17992576.0,18024320.0,18056064.0,18087808.0,18119552.0,18151296.0,18183040.0,17995904.0,18027712.0,18059520.0,18091328.0,18123136.0,18154944.0,18186752.0,18218560.0,18030976.0,18062848.0,18094720.0,18126592.0,181584
64.0,18190336.0,18222208.0,18254080.0,18066048.0,18097984.0,18129920.0,18161856.0,18193792.0,18225728.0,18257664.0,18289600.0,18101120.0,18133120.0,18165120.0,18197120.0,18229120.0,18261120.0,18293120.0,18325120.0,18136192.0,18168256.0,18200320.0,18232384.0,18264448.0,18296512.0,18328576.0,18360640.0,18171264.0,18203392.0,18235520.0,18267648.0,18299776.0,18331904.0,18364032.0,18396160.0,18206336.0,18238528.0,18270720.0,18302912.0,18335104.0,18367296.0,18399488.0,18431680.0,18241408.0,18273664.0,18305920.0,18338176.0,18370432.0,18402688.0,18434944.0,18467200.0,18276480.0,18308800.0,18341120.0,18373440.0,18405760.0,18438080.0,18470400.0,18502720.0,18311552.0,18343936.0,18376320.0,18408704.0,18441088.0,18473472.0,18505856.0,18538240.0,18346624.0,18379072.0,18411520.0,18443968.0,18476416.0,18508864.0,18541312.0,18573760.0,18381696.0,18414208.0,18446720.0,18479232.0,18511744.0,18544256.0,18576768.0,18609280.0,18416768.0,18449344.0,18481920.0,18514496.0,18547072.0,18579648.0,18612224.0,18644800.0,18451840.0,18484480.0,18517120.0,18549760.0,18582400.0,18615040.0,18647680.0,18680320.0,18486912.0,18519616.0,18552320.0,18585024.0,18617728.0,18650432.0,18683136.0,18715840.0,18521984.0,18554752.0,18587520.0,18620288.0,18653056.0,18685824.0,18718592.0,18751360.0,18557056.0,18589888.0,18622720.0,18655552.0,18688384.0,18721216.0,18754048.0,18786880.0,18592128.0,18625024.0,18657920.0,18690816.0,18723712.0,18756608.0,18789504.0,18822400.0,18627200.0,18660160.0,18693120.0,18726080.0,18759040.0,18792000.0,18824960.0,18857920.0,18662272.0,18695296.0,18728320.0,18761344.0,18794368.0,18827392.0,18860416.0,18893440.0,18697344.0,18730432.0,18763520.0,18796608.0,18829696.0,18862784.0,18895872.0,18928960.0,18732416.0,18765568.0,18798720.0,18831872.0,18865024.0,18898176.0,18931328.0,18964480.0,18767488.0,18800704.0,18833920.0,18867136.0,18900352.0,18933568.0,18966784.0,19000000.0,18802560.0,18835840.0,18869120.0,18902400.0,18935680.0,18968960.0,19002240.0,19035520.0,18837632.0,18870976.0,1890
4320.0,18937664.0,18971008.0,19004352.0,19037696.0,19071040.0,18872704.0,18906112.0,18939520.0,18972928.0,19006336.0,19039744.0,19073152.0,19106560.0,18907776.0,18941248.0,18974720.0,19008192.0,19041664.0,19075136.0,19108608.0,19142080.0,18942848.0,18976384.0,19009920.0,19043456.0,19076992.0,19110528.0,19144064.0,19177600.0,18977920.0,19011520.0,19045120.0,19078720.0,19112320.0,19145920.0,19179520.0,19213120.0,19012992.0,19046656.0,19080320.0,19113984.0,19147648.0,19181312.0,19214976.0,19248640.0,19048064.0,19081792.0,19115520.0,19149248.0,19182976.0,19216704.0,19250432.0,19284160.0,19083136.0,19116928.0,19150720.0,19184512.0,19218304.0,19252096.0,19285888.0,19319680.0,19118208.0,19152064.0,19185920.0,19219776.0,19253632.0,19287488.0,19321344.0,19355200.0,18334080.0,18366976.0,18399872.0,18432768.0,18465664.0,18498560.0,18531456.0,18564352.0,17541760.0,17573696.0,17605632.0,17637568.0,17669504.0,17701440.0,17733376.0,17765312.0,16757632.0,16788608.0,16819584.0,16850560.0,16881536.0,16912512.0,16943488.0,16974464.0,15965312.0,15995328.0,16025344.0,16055360.0,16085376.0,16115392.0,16145408.0,16175424.0,16000384.0,16030464.0,16060544.0,16090624.0,16120704.0,16150784.0,16180864.0,16210944.0,16035456.0,16065600.0,16095744.0,16125888.0,16156032.0,16186176.0,16216320.0,16246464.0,16070528.0,16100736.0,16130944.0,16161152.0,16191360.0,16221568.0,16251776.0,16281984.0,16105600.0,16135872.0,16166144.0,16196416.0,16226688.0,16256960.0,16287232.0,16317504.0,16140672.0,16171008.0,16201344.0,16231680.0,16262016.0,16292352.0,16322688.0,16353024.0,16175744.0,16206144.0,16236544.0,16266944.0,16297344.0,16327744.0,16358144.0,16388544.0,16210816.0,16241280.0,16271744.0,16302208.0,16332672.0,16363136.0,16393600.0,16424064.0,16245888.0,16276416.0,16306944.0,16337472.0,16368000.0,16398528.0,16429056.0,16459584.0,16280960.0,16311552.0,16342144.0,16372736.0,16403328.0,16433920.0,16464512.0,16495104.0,16316032.0,16346688.0,16377344.0,16408000.0,16438656.0,16469312.0,16499968.0,16530624.0,16
351104.0,16381824.0,16412544.0,16443264.0,16473984.0,16504704.0,16535424.0,16566144.0,16386176.0,16416960.0,16447744.0,16478528.0,16509312.0,16540096.0,16570880.0,16601664.0,16421248.0,16452096.0,16482944.0,16513792.0,16544640.0,16575488.0,16606336.0,16637184.0,16456320.0,16487232.0,16518144.0,16549056.0,16579968.0,16610880.0,16641792.0,16672704.0,16491392.0,16522368.0,16553344.0,16584320.0,16615296.0,16646272.0,16677248.0,16708224.0,16526464.0,16557504.0,16588544.0,16619584.0,16650624.0,16681664.0,16712704.0,16743744.0,16561536.0,16592640.0,16623744.0,16654848.0,16685952.0,16717056.0,16748160.0,16779264.0,16596608.0,16627776.0,16658944.0,16690112.0,16721280.0,16752448.0,16783616.0,16814784.0,16631680.0,16662912.0,16694144.0,16725376.0,16756608.0,16787840.0,16819072.0,16850304.0,16666752.0,16698048.0,16729344.0,16760640.0,16791936.0,16823232.0,16854528.0,16885824.0,16701824.0,16733184.0,16764544.0,16795904.0,16827264.0,16858624.0,16889984.0,16921344.0,16736896.0,16768320.0,16799744.0,16831168.0,16862592.0,16894016.0,16925440.0,16956864.0,16771968.0,16803456.0,16834944.0,16866432.0,16897920.0,16929408.0,16960896.0,16992384.0,16807040.0,16838592.0,16870144.0,16901696.0,16933248.0,16964800.0,16996352.0,17027904.0,16842112.0,16873728.0,16905344.0,16936960.0,16968576.0,17000192.0,17031808.0,17063424.0,16877184.0,16908864.0,16940544.0,16972224.0,17003904.0,17035584.0,17067264.0,17098944.0,16912256.0,16944000.0,16975744.0,17007488.0,17039232.0,17070976.0,17102720.0,17134464.0,16947328.0,16979136.0,17010944.0,17042752.0,17074560.0,17106368.0,17138176.0,17169984.0,16982400.0,17014272.0,17046144.0,17078016.0,17109888.0,17141760.0,17173632.0,17205504.0,17017472.0,17049408.0,17081344.0,17113280.0,17145216.0,17177152.0,17209088.0,17241024.0,17052544.0,17084544.0,17116544.0,17148544.0,17180544.0,17212544.0,17244544.0,17276544.0,17087616.0,17119680.0,17151744.0,17183808.0,17215872.0,17247936.0,17280000.0,17312064.0,17122688.0,17154816.0,17186944.0,17219072.0,17251200.0,17283328.0,
17315456.0,17347584.0,17157760.0,17189952.0,17222144.0,17254336.0,17286528.0,17318720.0,17350912.0,17383104.0,17192832.0,17225088.0,17257344.0,17289600.0,17321856.0,17354112.0,17386368.0,17418624.0,17227904.0,17260224.0,17292544.0,17324864.0,17357184.0,17389504.0,17421824.0,17454144.0,17262976.0,17295360.0,17327744.0,17360128.0,17392512.0,17424896.0,17457280.0,17489664.0,17003136.0,17034560.0,17065984.0,17097408.0,17128832.0,17160256.0,17191680.0,17223104.0,16735104.0,16765568.0,16796032.0,16826496.0,16856960.0,16887424.0,16917888.0,16948352.0,16770176.0,16800704.0,16831232.0,16861760.0,16892288.0,16922816.0,16953344.0,16983872.0,16805248.0,16835840.0,16866432.0,16897024.0,16927616.0,16958208.0,16988800.0,17019392.0,16840320.0,16870976.0,16901632.0,16932288.0,16962944.0,16993600.0,17024256.0,17054912.0,16875392.0,16906112.0,16936832.0,16967552.0,16998272.0,17028992.0,17059712.0,17090432.0,16910464.0,16941248.0,16972032.0,17002816.0,17033600.0,17064384.0,17095168.0,17125952.0,16945536.0,16976384.0,17007232.0,17038080.0,17068928.0,17099776.0,17130624.0,17161472.0,16980608.0,17011520.0,17042432.0,17073344.0,17104256.0,17135168.0,17166080.0,17196992.0,17015680.0,17046656.0,17077632.0,17108608.0,17139584.0,17170560.0,17201536.0,17232512.0,17050752.0,17081792.0,17112832.0,17143872.0,17174912.0,17205952.0,17236992.0,17268032.0,17085824.0,17116928.0,17148032.0,17179136.0,17210240.0,17241344.0,17272448.0,17303552.0,17120896.0,17152064.0,17183232.0,17214400.0,17245568.0,17276736.0,17307904.0,17339072.0,17155968.0,17187200.0,17218432.0,17249664.0,17280896.0,17312128.0,17343360.0,17374592.0,17191040.0,17222336.0,17253632.0,17284928.0,17316224.0,17347520.0,17378816.0,17410112.0,17226112.0,17257472.0,17288832.0,17320192.0,17351552.0,17382912.0,17414272.0,17445632.0,17261184.0,17292608.0,17324032.0,17355456.0,17386880.0,17418304.0,17449728.0,17481152.0,17296256.0,17327744.0,17359232.0,17390720.0,17422208.0,17453696.0,17485184.0,17516672.0,17331328.0,17362880.0,17394432.0,17425984.
0,17457536.0,17489088.0,17520640.0,17552192.0,17366400.0,17398016.0,17429632.0,17461248.0,17492864.0,17524480.0,17556096.0,17587712.0,17401472.0,17433152.0,17464832.0,17496512.0,17528192.0,17559872.0,17591552.0,17623232.0,17436544.0,17468288.0,17500032.0,17531776.0,17563520.0,17595264.0,17627008.0,17658752.0,17471616.0,17503424.0,17535232.0,17567040.0,17598848.0,17630656.0,17662464.0,17694272.0,17506688.0,17538560.0,17570432.0,17602304.0,17634176.0,17666048.0,17697920.0,17729792.0,17541760.0,17573696.0,17605632.0,17637568.0,17669504.0,17701440.0,17733376.0,17765312.0,17576832.0,17608832.0,17640832.0,17672832.0,17704832.0,17736832.0,17768832.0,17800832.0,17611904.0,17643968.0,17676032.0,17708096.0,17740160.0,17772224.0,17804288.0,17836352.0,17646976.0,17679104.0,17711232.0,17743360.0,17775488.0,17807616.0,17839744.0,17871872.0,17682048.0,17714240.0,17746432.0,17778624.0,17810816.0,17843008.0,17875200.0,17907392.0,17717120.0,17749376.0,17781632.0,17813888.0,17846144.0,17878400.0,17910656.0,17942912.0,17752192.0,17784512.0,17816832.0,17849152.0,17881472.0,17913792.0,17946112.0,17978432.0,17787264.0,17819648.0,17852032.0,17884416.0,17916800.0,17949184.0,17981568.0,18013952.0,17822336.0,17854784.0,17887232.0,17919680.0,17952128.0,17984576.0,18017024.0,18049472.0,17857408.0,17889920.0,17922432.0,17954944.0,17987456.0,18019968.0,18052480.0,18084992.0,17892480.0,17925056.0,17957632.0,17990208.0,18022784.0,18055360.0,18087936.0,18120512.0,17927552.0,17960192.0,17992832.0,18025472.0,18058112.0,18090752.0,18123392.0,18156032.0,17962624.0,17995328.0,18028032.0,18060736.0,18093440.0,18126144.0,18158848.0,18191552.0,17997696.0,18030464.0,18063232.0,18096000.0,18128768.0,18161536.0,18194304.0,18227072.0,18032768.0,18065600.0,18098432.0,18131264.0,18164096.0,18196928.0,18229760.0,18262592.0,17248640.0,17280512.0,17312384.0,17344256.0,17376128.0,17408000.0,17439872.0,17471744.0,16456320.0,16487232.0,16518144.0,16549056.0,16579968.0,16610880.0,16641792.0,16672704.0,16491392.0,1652236
8.0,16553344.0,16584320.0,16615296.0,16646272.0,16677248.0,16708224.0,16526464.0,16557504.0,16588544.0,16619584.0,16650624.0,16681664.0,16712704.0,16743744.0,16561536.0,16592640.0,16623744.0,16654848.0,16685952.0,16717056.0,16748160.0,16779264.0,16596608.0,16627776.0,16658944.0,16690112.0,16721280.0,16752448.0,16783616.0,16814784.0,16631680.0,16662912.0,16694144.0,16725376.0,16756608.0,16787840.0,16819072.0,16850304.0,16666752.0,16698048.0,16729344.0,16760640.0,16791936.0,16823232.0,16854528.0,16885824.0,16701824.0,16733184.0,16764544.0,16795904.0,16827264.0,16858624.0,16889984.0,16921344.0,16736896.0,16768320.0,16799744.0,16831168.0,16862592.0,16894016.0,16925440.0,16956864.0,16771968.0,16803456.0,16834944.0,16866432.0,16897920.0,16929408.0,16960896.0,16992384.0,16807040.0,16838592.0,16870144.0,16901696.0,16933248.0,16964800.0,16996352.0,17027904.0,16842112.0,16873728.0,16905344.0,16936960.0,16968576.0,17000192.0,17031808.0,17063424.0,16877184.0,16908864.0,16940544.0,16972224.0,17003904.0,17035584.0,17067264.0,17098944.0,16912256.0,16944000.0,16975744.0,17007488.0,17039232.0,17070976.0,17102720.0,17134464.0,16947328.0,16979136.0,17010944.0,17042752.0,17074560.0,17106368.0,17138176.0,17169984.0,16982400.0,17014272.0,17046144.0,17078016.0,17109888.0,17141760.0,17173632.0,17205504.0,17017472.0,17049408.0,17081344.0,17113280.0,17145216.0,17177152.0,17209088.0,17241024.0,17052544.0,17084544.0,17116544.0,17148544.0,17180544.0,17212544.0,17244544.0,17276544.0,17087616.0,17119680.0,17151744.0,17183808.0,17215872.0,17247936.0,17280000.0,17312064.0,17122688.0,17154816.0,17186944.0,17219072.0,17251200.0,17283328.0,17315456.0,17347584.0,17157760.0,17189952.0,17222144.0,17254336.0,17286528.0,17318720.0,17350912.0,17383104.0,17192832.0,17225088.0,17257344.0,17289600.0,17321856.0,17354112.0,17386368.0,17418624.0,17227904.0,17260224.0,17292544.0,17324864.0,17357184.0,17389504.0,17421824.0,17454144.0,17262976.0,17295360.0,17327744.0,17360128.0,17392512.0,17424896.0,17457280.0,17489
664.0,17298048.0,17330496.0,17362944.0,17395392.0,17427840.0,17460288.0,17492736.0,17525184.0,17333120.0,17365632.0,17398144.0,17430656.0,17463168.0,17495680.0,17528192.0,17560704.0,17368192.0,17400768.0,17433344.0,17465920.0,17498496.0,17531072.0,17563648.0,17596224.0,17403264.0,17435904.0,17468544.0,17501184.0,17533824.0,17566464.0,17599104.0,17631744.0,17438336.0,17471040.0,17503744.0,17536448.0,17569152.0,17601856.0,17634560.0,17667264.0,17473408.0,17506176.0,17538944.0,17571712.0,17604480.0,17637248.0,17670016.0,17702784.0,17508480.0,17541312.0,17574144.0,17606976.0,17639808.0,17672640.0,17705472.0,17738304.0,17543552.0,17576448.0,17609344.0,17642240.0,17675136.0,17708032.0,17740928.0,17773824.0,17578624.0,17611584.0,17644544.0,17677504.0,17710464.0,17743424.0,17776384.0,17809344.0,17613696.0,17646720.0,17679744.0,17712768.0,17745792.0,17778816.0,17811840.0,17844864.0,17648768.0,17681856.0,17714944.0,17748032.0,17781120.0,17814208.0,17847296.0,17880384.0,17683840.0,17716992.0,17750144.0,17783296.0,17816448.0,17849600.0,17882752.0,17915904.0,17718912.0,17752128.0,17785344.0,17818560.0,17851776.0,17884992.0,17918208.0,17951424.0,17753984.0,17787264.0,17820544.0,17853824.0,17887104.0,17920384.0,17953664.0,17986944.0,17494144.0,17526464.0,17558784.0,17591104.0,17623424.0,17655744.0,17688064.0,17720384.0,17226112.0,17257472.0,17288832.0,17320192.0,17351552.0,17382912.0,17414272.0,17445632.0,17261184.0,17292608.0,17324032.0,17355456.0,17386880.0,17418304.0,17449728.0,17481152.0,17296256.0,17327744.0,17359232.0,17390720.0,17422208.0,17453696.0,17485184.0,17516672.0,17331328.0,17362880.0,17394432.0,17425984.0,17457536.0,17489088.0,17520640.0,17552192.0,17366400.0,17398016.0,17429632.0,17461248.0,17492864.0,17524480.0,17556096.0,17587712.0,17401472.0,17433152.0,17464832.0,17496512.0,17528192.0,17559872.0,17591552.0,17623232.0,17436544.0,17468288.0,17500032.0,17531776.0,17563520.0,17595264.0,17627008.0,17658752.0,17471616.0,17503424.0,17535232.0,17567040.0,17598848.0,176
30656.0,17662464.0,17694272.0,17506688.0,17538560.0,17570432.0,17602304.0,17634176.0,17666048.0,17697920.0,17729792.0,17541760.0,17573696.0,17605632.0,17637568.0,17669504.0,17701440.0,17733376.0,17765312.0,17576832.0,17608832.0,17640832.0,17672832.0,17704832.0,17736832.0,17768832.0,17800832.0,17611904.0,17643968.0,17676032.0,17708096.0,17740160.0,17772224.0,17804288.0,17836352.0,17646976.0,17679104.0,17711232.0,17743360.0,17775488.0,17807616.0,17839744.0,17871872.0,17682048.0,17714240.0,17746432.0,17778624.0,17810816.0,17843008.0,17875200.0,17907392.0,17717120.0,17749376.0,17781632.0,17813888.0,17846144.0,17878400.0,17910656.0,17942912.0,17752192.0,17784512.0,17816832.0,17849152.0,17881472.0,17913792.0,17946112.0,17978432.0,17787264.0,17819648.0,17852032.0,17884416.0,17916800.0,17949184.0,17981568.0,18013952.0,17822336.0,17854784.0,17887232.0,17919680.0,17952128.0,17984576.0,18017024.0,18049472.0,17857408.0,17889920.0,17922432.0,17954944.0,17987456.0,18019968.0,18052480.0,18084992.0,17892480.0,17925056.0,17957632.0,17990208.0,18022784.0,18055360.0,18087936.0,18120512.0,17927552.0,17960192.0,17992832.0,18025472.0,18058112.0,18090752.0,18123392.0,18156032.0,17962624.0,17995328.0,18028032.0,18060736.0,18093440.0,18126144.0,18158848.0,18191552.0,17997696.0,18030464.0,18063232.0,18096000.0,18128768.0,18161536.0,18194304.0,18227072.0,18032768.0,18065600.0,18098432.0,18131264.0,18164096.0,18196928.0,18229760.0,18262592.0,18067840.0,18100736.0,18133632.0,18166528.0,18199424.0,18232320.0,18265216.0,18298112.0,18102912.0,18135872.0,18168832.0,18201792.0,18234752.0,18267712.0,18300672.0,18333632.0,18137984.0,18171008.0,18204032.0,18237056.0,18270080.0,18303104.0,18336128.0,18369152.0,18173056.0,18206144.0,18239232.0,18272320.0,18305408.0,18338496.0,18371584.0,18404672.0,18208128.0,18241280.0,18274432.0,18307584.0,18340736.0,18373888.0,18407040.0,18440192.0,18243200.0,18276416.0,18309632.0,18342848.0,18376064.0,18409280.0,18442496.0,18475712.0,18278272.0,18311552.0,18344832.0,1
8378112.0,18411392.0,18444672.0,18477952.0,18511232.0,18313344.0,18346688.0,18380032.0,18413376.0,18446720.0,18480064.0,18513408.0,18546752.0,18348416.0,18381824.0,18415232.0,18448640.0,18482048.0,18515456.0,18548864.0,18582272.0,18383488.0,18416960.0,18450432.0,18483904.0,18517376.0,18550848.0,18584320.0,18617792.0,18418560.0,18452096.0,18485632.0,18519168.0,18552704.0,18586240.0,18619776.0,18653312.0,18453632.0,18487232.0,18520832.0,18554432.0,18588032.0,18621632.0,18655232.0,18688832.0,18488704.0,18522368.0,18556032.0,18589696.0,18623360.0,18657024.0,18690688.0,18724352.0,18523776.0,18557504.0,18591232.0,18624960.0,18658688.0,18692416.0,18726144.0,18759872.0,17739648.0,17772416.0,17805184.0,17837952.0,17870720.0,17903488.0,17936256.0,17969024.0,16947328.0,16979136.0,17010944.0,17042752.0,17074560.0,17106368.0,17138176.0,17169984.0,8546176.0,8562096.0,8578016.0,8593936.0,8609856.0,8625776.0,8641696.0,8657616.0,9126144.0,9141936.0,9157728.0,9173520.0,9189312.0,9205104.0,9220896.0,9236688.0,18396544.0,18428160.0,18459776.0,18491392.0,18523008.0,18554624.0,18586240.0,18617856.0,18432640.0,18464320.0,18496000.0,18527680.0,18559360.0,18591040.0,18622720.0,18654400.0,18468736.0,18500480.0,18532224.0,18563968.0,18595712.0,18627456.0,18659200.0,18690944.0,18504832.0,18536640.0,18568448.0,18600256.0,18632064.0,18663872.0,18695680.0,18727488.0,18540928.0,18572800.0,18604672.0,18636544.0,18668416.0,18700288.0,18732160.0,18764032.0,18577024.0,18608960.0,18640896.0,18672832.0,18704768.0,18736704.0,18768640.0,18800576.0,18613120.0,18645120.0,18677120.0,18709120.0,18741120.0,18773120.0,18805120.0,18837120.0,18649216.0,18681280.0,18713344.0,18745408.0,18777472.0,18809536.0,18841600.0,18873664.0,18685312.0,18717440.0,18749568.0,18781696.0,18813824.0,18845952.0,18878080.0,18910208.0,18721408.0,18753600.0,18785792.0,18817984.0,18850176.0,18882368.0,18914560.0,18946752.0,18757504.0,18789760.0,18822016.0,18854272.0,18886528.0,18918784.0,18951040.0,18983296.0,18793600.0,18825920.0,1885
8240.0,18890560.0,18922880.0,18955200.0,18987520.0,19019840.0,18829696.0,18862080.0,18894464.0,18926848.0,18959232.0,18991616.0,19024000.0,19056384.0,18865792.0,18898240.0,18930688.0,18963136.0,18995584.0,19028032.0,19060480.0,19092928.0,18901888.0,18934400.0,18966912.0,18999424.0,19031936.0,19064448.0,19096960.0,19129472.0,18937984.0,18970560.0,19003136.0,19035712.0,19068288.0,19100864.0,19133440.0,19166016.0,18974080.0,19006720.0,19039360.0,19072000.0,19104640.0,19137280.0,19169920.0,19202560.0,19010176.0,19042880.0,19075584.0,19108288.0,19140992.0,19173696.0,19206400.0,19239104.0,19046272.0,19079040.0,19111808.0,19144576.0,19177344.0,19210112.0,19242880.0,19275648.0,19082368.0,19115200.0,19148032.0,19180864.0,19213696.0,19246528.0,19279360.0,19312192.0,19118464.0,19151360.0,19184256.0,19217152.0,19250048.0,19282944.0,19315840.0,19348736.0,19154560.0,19187520.0,19220480.0,19253440.0,19286400.0,19319360.0,19352320.0,19385280.0,19190656.0,19223680.0,19256704.0,19289728.0,19322752.0,19355776.0,19388800.0,19421824.0,19226752.0,19259840.0,19292928.0,19326016.0,19359104.0,19392192.0,19425280.0,19458368.0,19262848.0,19296000.0,19329152.0,19362304.0,19395456.0,19428608.0,19461760.0,19494912.0,19298944.0,19332160.0,19365376.0,19398592.0,19431808.0,19465024.0,19498240.0,19531456.0,19335040.0,19368320.0,19401600.0,19434880.0,19468160.0,19501440.0,19534720.0,19568000.0,19371136.0,19404480.0,19437824.0,19471168.0,19504512.0,19537856.0,19571200.0,19604544.0,19407232.0,19440640.0,19474048.0,19507456.0,19540864.0,19574272.0,19607680.0,19641088.0,19443328.0,19476800.0,19510272.0,19543744.0,19577216.0,19610688.0,19644160.0,19677632.0,19479424.0,19512960.0,19546496.0,19580032.0,19613568.0,19647104.0,19680640.0,19714176.0,19515520.0,19549120.0,19582720.0,19616320.0,19649920.0,19683520.0,19717120.0,19750720.0,19551616.0,19585280.0,19618944.0,19652608.0,19686272.0,19719936.0,19753600.0,19787264.0,19587712.0,19621440.0,19655168.0,19688896.0,19722624.0,19756352.0,19790080.0,19823808.0,19
623808.0,19657600.0,19691392.0,19725184.0,19758976.0,19792768.0,19826560.0,19860352.0,19659904.0,19693760.0,19727616.0,19761472.0,19795328.0,19829184.0,19863040.0,19896896.0,18860416.0,18893312.0,18926208.0,18959104.0,18992000.0,19024896.0,19057792.0,19090688.0,18052736.0,18084672.0,18116608.0,18148544.0,18180480.0,18212416.0,18244352.0,18276288.0,17253248.0,17284224.0,17315200.0,17346176.0,17377152.0,17408128.0,17439104.0,17470080.0,16445568.0,16475584.0,16505600.0,16535616.0,16565632.0,16595648.0,16625664.0,16655680.0,16481664.0,16511744.0,16541824.0,16571904.0,16601984.0,16632064.0,16662144.0,16692224.0,16517760.0,16547904.0,16578048.0,16608192.0,16638336.0,16668480.0,16698624.0,16728768.0,16553856.0,16584064.0,16614272.0,16644480.0,16674688.0,16704896.0,16735104.0,16765312.0,16589952.0,16620224.0,16650496.0,16680768.0,16711040.0,16741312.0,16771584.0,16801856.0,16626048.0,16656384.0,16686720.0,16717056.0,16747392.0,16777728.0,16808064.0,16838400.0,16662144.0,16692544.0,16722944.0,16753344.0,16783744.0,16814144.0,16844544.0,16874944.0,16698240.0,16728704.0,16759168.0,16789632.0,16820096.0,16850560.0,16881024.0,16911488.0,16734336.0,16764864.0,16795392.0,16825920.0,16856448.0,16886976.0,16917504.0,16948032.0,16770432.0,16801024.0,16831616.0,16862208.0,16892800.0,16923392.0,16953984.0,16984576.0,16806528.0,16837184.0,16867840.0,16898496.0,16929152.0,16959808.0,16990464.0,17021120.0,16842624.0,16873344.0,16904064.0,16934784.0,16965504.0,16996224.0,17026944.0,17057664.0,16878720.0,16909504.0,16940288.0,16971072.0,17001856.0,17032640.0,17063424.0,17094208.0,16914816.0,16945664.0,16976512.0,17007360.0,17038208.0,17069056.0,17099904.0,17130752.0,16950912.0,16981824.0,17012736.0,17043648.0,17074560.0,17105472.0,17136384.0,17167296.0,16987008.0,17017984.0,17048960.0,17079936.0,17110912.0,17141888.0,17172864.0,17203840.0,17023104.0,17054144.0,17085184.0,17116224.0,17147264.0,17178304.0,17209344.0,17240384.0,17059200.0,17090304.0,17121408.0,17152512.0,17183616.0,17214720.0,
17245824.0,17276928.0,17095296.0,17126464.0,17157632.0,17188800.0,17219968.0,17251136.0,17282304.0,17313472.0,17131392.0,17162624.0,17193856.0,17225088.0,17256320.0,17287552.0,17318784.0,17350016.0,17167488.0,17198784.0,17230080.0,17261376.0,17292672.0,17323968.0,17355264.0,17386560.0,17203584.0,17234944.0,17266304.0,17297664.0,17329024.0,17360384.0,17391744.0,17423104.0,17239680.0,17271104.0,17302528.0,17333952.0,17365376.0,17396800.0,17428224.0,17459648.0,17275776.0,17307264.0,17338752.0,17370240.0,17401728.0,17433216.0,17464704.0,17496192.0,17311872.0,17343424.0,17374976.0,17406528.0,17438080.0,17469632.0,17501184.0,17532736.0,17347968.0,17379584.0,17411200.0,17442816.0,17474432.0,17506048.0,17537664.0,17569280.0,17384064.0,17415744.0,17447424.0,17479104.0,17510784.0,17542464.0,17574144.0,17605824.0,17420160.0,17451904.0,17483648.0,17515392.0,17547136.0,17578880.0,17610624.0,17642368.0,17456256.0,17488064.0,17519872.0,17551680.0,17583488.0,17615296.0,17647104.0,17678912.0,17492352.0,17524224.0,17556096.0,17587968.0,17619840.0,17651712.0,17683584.0,17715456.0,17528448.0,17560384.0,17592320.0,17624256.0,17656192.0,17688128.0,17720064.0,17752000.0,17564544.0,17596544.0,17628544.0,17660544.0,17692544.0,17724544.0,17756544.0,17788544.0,17600640.0,17632704.0,17664768.0,17696832.0,17728896.0,17760960.0,17793024.0,17825088.0,17636736.0,17668864.0,17700992.0,17733120.0,17765248.0,17797376.0,17829504.0,17861632.0,17672832.0,17705024.0,17737216.0,17769408.0,17801600.0,17833792.0,17865984.0,17898176.0,17708928.0,17741184.0,17773440.0,17805696.0,17837952.0,17870208.0,17902464.0,17934720.0,17745024.0,17777344.0,17809664.0,17841984.0,17874304.0,17906624.0,17938944.0,17971264.0,17781120.0,17813504.0,17845888.0,17878272.0,17910656.0,17943040.0,17975424.0,18007808.0,17505920.0,17537344.0,17568768.0,17600192.0,17631616.0,17663040.0,17694464.0,17725888.0,17222528.0,17252992.0,17283456.0,17313920.0,17344384.0,17374848.0,17405312.0,17435776.0,17258624.0,17289152.0,17319680.0,17350208.
0,17380736.0,17411264.0,17441792.0,17472320.0,17294720.0,17325312.0,17355904.0,17386496.0,17417088.0,17447680.0,17478272.0,17508864.0,17330816.0,17361472.0,17392128.0,17422784.0,17453440.0,17484096.0,17514752.0,17545408.0,17366912.0,17397632.0,17428352.0,17459072.0,17489792.0,17520512.0,17551232.0,17581952.0,17403008.0,17433792.0,17464576.0,17495360.0,17526144.0,17556928.0,17587712.0,17618496.0,17439104.0,17469952.0,17500800.0,17531648.0,17562496.0,17593344.0,17624192.0,17655040.0,17475200.0,17506112.0,17537024.0,17567936.0,17598848.0,17629760.0,17660672.0,17691584.0,17511296.0,17542272.0,17573248.0,17604224.0,17635200.0,17666176.0,17697152.0,17728128.0,17547392.0,17578432.0,17609472.0,17640512.0,17671552.0,17702592.0,17733632.0,17764672.0,17583488.0,17614592.0,17645696.0,17676800.0,17707904.0,17739008.0,17770112.0,17801216.0,17619584.0,17650752.0,17681920.0,17713088.0,17744256.0,17775424.0,17806592.0,17837760.0,17655680.0,17686912.0,17718144.0,17749376.0,17780608.0,17811840.0,17843072.0,17874304.0,17691776.0,17723072.0,17754368.0,17785664.0,17816960.0,17848256.0,17879552.0,17910848.0,17727872.0,17759232.0,17790592.0,17821952.0,17853312.0,17884672.0,17916032.0,17947392.0,17763968.0,17795392.0,17826816.0,17858240.0,17889664.0,17921088.0,17952512.0,17983936.0,17800064.0,17831552.0,17863040.0,17894528.0,17926016.0,17957504.0,17988992.0,18020480.0,17836160.0,17867712.0,17899264.0,17930816.0,17962368.0,17993920.0,18025472.0,18057024.0,17872256.0,17903872.0,17935488.0,17967104.0,17998720.0,18030336.0,18061952.0,18093568.0,17908352.0,17940032.0,17971712.0,18003392.0,18035072.0,18066752.0,18098432.0,18130112.0,17944448.0,17976192.0,18007936.0,18039680.0,18071424.0,18103168.0,18134912.0,18166656.0,17980544.0,18012352.0,18044160.0,18075968.0,18107776.0,18139584.0,18171392.0,18203200.0,18016640.0,18048512.0,18080384.0,18112256.0,18144128.0,18176000.0,18207872.0,18239744.0,18052736.0,18084672.0,18116608.0,18148544.0,18180480.0,18212416.0,18244352.0,18276288.0,18088832.0,1812083
2.0,18152832.0,18184832.0,18216832.0,18248832.0,18280832.0,18312832.0,18124928.0,18156992.0,18189056.0,18221120.0,18253184.0,18285248.0,18317312.0,18349376.0,18161024.0,18193152.0,18225280.0,18257408.0,18289536.0,18321664.0,18353792.0,18385920.0,18197120.0,18229312.0,18261504.0,18293696.0,18325888.0,18358080.0,18390272.0,18422464.0,18233216.0,18265472.0,18297728.0,18329984.0,18362240.0,18394496.0,18426752.0,18459008.0,18269312.0,18301632.0,18333952.0,18366272.0,18398592.0,18430912.0,18463232.0,18495552.0,18305408.0,18337792.0,18370176.0,18402560.0,18434944.0,18467328.0,18499712.0,18532096.0,18341504.0,18373952.0,18406400.0,18438848.0,18471296.0,18503744.0,18536192.0,18568640.0,18377600.0,18410112.0,18442624.0,18475136.0,18507648.0,18540160.0,18572672.0,18605184.0,18413696.0,18446272.0,18478848.0,18511424.0,18544000.0,18576576.0,18609152.0,18641728.0,18449792.0,18482432.0,18515072.0,18547712.0,18580352.0,18612992.0,18645632.0,18678272.0,18485888.0,18518592.0,18551296.0,18584000.0,18616704.0,18649408.0,18682112.0,18714816.0,18521984.0,18554752.0,18587520.0,18620288.0,18653056.0,18685824.0,18718592.0,18751360.0,18558080.0,18590912.0,18623744.0,18656576.0,18689408.0,18722240.0,18755072.0,18787904.0,17758592.0,17790464.0,17822336.0,17854208.0,17886080.0,17917952.0,17949824.0,17981696.0,16950912.0,16981824.0,17012736.0,17043648.0,17074560.0,17105472.0,17136384.0,17167296.0,16987008.0,17017984.0,17048960.0,17079936.0,17110912.0,17141888.0,17172864.0,17203840.0,17023104.0,17054144.0,17085184.0,17116224.0,17147264.0,17178304.0,17209344.0,17240384.0,17059200.0,17090304.0,17121408.0,17152512.0,17183616.0,17214720.0,17245824.0,17276928.0,17095296.0,17126464.0,17157632.0,17188800.0,17219968.0,17251136.0,17282304.0,17313472.0,17131392.0,17162624.0,17193856.0,17225088.0,17256320.0,17287552.0,17318784.0,17350016.0,17167488.0,17198784.0,17230080.0,17261376.0,17292672.0,17323968.0,17355264.0,17386560.0,17203584.0,17234944.0,17266304.0,17297664.0,17329024.0,17360384.0,17391744.0,17423
104.0,17239680.0,17271104.0,17302528.0,17333952.0,17365376.0,17396800.0,17428224.0,17459648.0,17275776.0,17307264.0,17338752.0,17370240.0,17401728.0,17433216.0,17464704.0,17496192.0,17311872.0,17343424.0,17374976.0,17406528.0,17438080.0,17469632.0,17501184.0,17532736.0,17347968.0,17379584.0,17411200.0,17442816.0,17474432.0,17506048.0,17537664.0,17569280.0,17384064.0,17415744.0,17447424.0,17479104.0,17510784.0,17542464.0,17574144.0,17605824.0,17420160.0,17451904.0,17483648.0,17515392.0,17547136.0,17578880.0,17610624.0,17642368.0,17456256.0,17488064.0,17519872.0,17551680.0,17583488.0,17615296.0,17647104.0,17678912.0,17492352.0,17524224.0,17556096.0,17587968.0,17619840.0,17651712.0,17683584.0,17715456.0,17528448.0,17560384.0,17592320.0,17624256.0,17656192.0,17688128.0,17720064.0,17752000.0,17564544.0,17596544.0,17628544.0,17660544.0,17692544.0,17724544.0,17756544.0,17788544.0,17600640.0,17632704.0,17664768.0,17696832.0,17728896.0,17760960.0,17793024.0,17825088.0,17636736.0,17668864.0,17700992.0,17733120.0,17765248.0,17797376.0,17829504.0,17861632.0,17672832.0,17705024.0,17737216.0,17769408.0,17801600.0,17833792.0,17865984.0,17898176.0,17708928.0,17741184.0,17773440.0,17805696.0,17837952.0,17870208.0,17902464.0,17934720.0,17745024.0,17777344.0,17809664.0,17841984.0,17874304.0,17906624.0,17938944.0,17971264.0,17781120.0,17813504.0,17845888.0,17878272.0,17910656.0,17943040.0,17975424.0,18007808.0,17817216.0,17849664.0,17882112.0,17914560.0,17947008.0,17979456.0,18011904.0,18044352.0,17853312.0,17885824.0,17918336.0,17950848.0,17983360.0,18015872.0,18048384.0,18080896.0,17889408.0,17921984.0,17954560.0,17987136.0,18019712.0,18052288.0,18084864.0,18117440.0,17925504.0,17958144.0,17990784.0,18023424.0,18056064.0,18088704.0,18121344.0,18153984.0,17961600.0,17994304.0,18027008.0,18059712.0,18092416.0,18125120.0,18157824.0,18190528.0,17997696.0,18030464.0,18063232.0,18096000.0,18128768.0,18161536.0,18194304.0,18227072.0,18033792.0,18066624.0,18099456.0,18132288.0,18165120.0,181
97952.0,18230784.0,18263616.0,18069888.0,18102784.0,18135680.0,18168576.0,18201472.0,18234368.0,18267264.0,18300160.0,18105984.0,18138944.0,18171904.0,18204864.0,18237824.0,18270784.0,18303744.0,18336704.0,18142080.0,18175104.0,18208128.0,18241152.0,18274176.0,18307200.0,18340224.0,18373248.0,18178176.0,18211264.0,18244352.0,18277440.0,18310528.0,18343616.0,18376704.0,18409792.0,18214272.0,18247424.0,18280576.0,18313728.0,18346880.0,18380032.0,18413184.0,18446336.0,18250368.0,18283584.0,18316800.0,18350016.0,18383232.0,18416448.0,18449664.0,18482880.0,18286464.0,18319744.0,18353024.0,18386304.0,18419584.0,18452864.0,18486144.0,18519424.0,18011264.0,18043584.0,18075904.0,18108224.0,18140544.0,18172864.0,18205184.0,18237504.0,17727872.0,17759232.0,17790592.0,17821952.0,17853312.0,17884672.0,17916032.0,17947392.0,17763968.0,17795392.0,17826816.0,17858240.0,17889664.0,17921088.0,17952512.0,17983936.0,17800064.0,17831552.0,17863040.0,17894528.0,17926016.0,17957504.0,17988992.0,18020480.0,17836160.0,17867712.0,17899264.0,17930816.0,17962368.0,17993920.0,18025472.0,18057024.0,17872256.0,17903872.0,17935488.0,17967104.0,17998720.0,18030336.0,18061952.0,18093568.0,17908352.0,17940032.0,17971712.0,18003392.0,18035072.0,18066752.0,18098432.0,18130112.0,17944448.0,17976192.0,18007936.0,18039680.0,18071424.0,18103168.0,18134912.0,18166656.0,17980544.0,18012352.0,18044160.0,18075968.0,18107776.0,18139584.0,18171392.0,18203200.0,18016640.0,18048512.0,18080384.0,18112256.0,18144128.0,18176000.0,18207872.0,18239744.0,18052736.0,18084672.0,18116608.0,18148544.0,18180480.0,18212416.0,18244352.0,18276288.0,18088832.0,18120832.0,18152832.0,18184832.0,18216832.0,18248832.0,18280832.0,18312832.0,18124928.0,18156992.0,18189056.0,18221120.0,18253184.0,18285248.0,18317312.0,18349376.0,18161024.0,18193152.0,18225280.0,18257408.0,18289536.0,18321664.0,18353792.0,18385920.0,18197120.0,18229312.0,18261504.0,18293696.0,18325888.0,18358080.0,18390272.0,18422464.0,18233216.0,18265472.0,18297728.0,1
8329984.0,18362240.0,18394496.0,18426752.0,18459008.0,18269312.0,18301632.0,18333952.0,18366272.0,18398592.0,18430912.0,18463232.0,18495552.0,18305408.0,18337792.0,18370176.0,18402560.0,18434944.0,18467328.0,18499712.0,18532096.0,18341504.0,18373952.0,18406400.0,18438848.0,18471296.0,18503744.0,18536192.0,18568640.0,18377600.0,18410112.0,18442624.0,18475136.0,18507648.0,18540160.0,18572672.0,18605184.0,18413696.0,18446272.0,18478848.0,18511424.0,18544000.0,18576576.0,18609152.0,18641728.0,18449792.0,18482432.0,18515072.0,18547712.0,18580352.0,18612992.0,18645632.0,18678272.0,18485888.0,18518592.0,18551296.0,18584000.0,18616704.0,18649408.0,18682112.0,18714816.0,18521984.0,18554752.0,18587520.0,18620288.0,18653056.0,18685824.0,18718592.0,18751360.0,18558080.0,18590912.0,18623744.0,18656576.0,18689408.0,18722240.0,18755072.0,18787904.0,18594176.0,18627072.0,18659968.0,18692864.0,18725760.0,18758656.0,18791552.0,18824448.0,18630272.0,18663232.0,18696192.0,18729152.0,18762112.0,18795072.0,18828032.0,18860992.0,18666368.0,18699392.0,18732416.0,18765440.0,18798464.0,18831488.0,18864512.0,18897536.0,18702464.0,18735552.0,18768640.0,18801728.0,18834816.0,18867904.0,18900992.0,18934080.0,18738560.0,18771712.0,18804864.0,18838016.0,18871168.0,18904320.0,18937472.0,18970624.0,18774656.0,18807872.0,18841088.0,18874304.0,18907520.0,18940736.0,18973952.0,19007168.0,18810752.0,18844032.0,18877312.0,18910592.0,18943872.0,18977152.0,19010432.0,19043712.0,18846848.0,18880192.0,18913536.0,18946880.0,18980224.0,19013568.0,19046912.0,19080256.0,18882944.0,18916352.0,18949760.0,18983168.0,19016576.0,19049984.0,19083392.0,19116800.0,18919040.0,18952512.0,18985984.0,19019456.0,19052928.0,19086400.0,19119872.0,19153344.0,18955136.0,18988672.0,19022208.0,19055744.0,19089280.0,19122816.0,19156352.0,19189888.0,18991232.0,19024832.0,19058432.0,19092032.0,19125632.0,19159232.0,19192832.0,19226432.0,19027328.0,19060992.0,19094656.0,19128320.0,19161984.0,19195648.0,19229312.0,19262976.0,19063424.0
,19097152.0,19130880.0,19164608.0,19198336.0,19232064.0,19265792.0,19299520.0,18263936.0,18296704.0,18329472.0,18362240.0,18395008.0,18427776.0,18460544.0,18493312.0,17456256.0,17488064.0,17519872.0,17551680.0,17583488.0,17615296.0,17647104.0,17678912.0,8800896.0,8816816.0,8832736.0,8848656.0,8864576.0,8880496.0,8896416.0,8912336.0,9378816.0,9394608.0,9410400.0,9426192.0,9441984.0,9457776.0,9473568.0,9489360.0,18902400.0,18934016.0,18965632.0,18997248.0,19028864.0,19060480.0,19092096.0,19123712.0,18939520.0,18971200.0,19002880.0,19034560.0,19066240.0,19097920.0,19129600.0,19161280.0,18976640.0,19008384.0,19040128.0,19071872.0,19103616.0,19135360.0,19167104.0,19198848.0,19013760.0,19045568.0,19077376.0,19109184.0,19140992.0,19172800.0,19204608.0,19236416.0,19050880.0,19082752.0,19114624.0,19146496.0,19178368.0,19210240.0,19242112.0,19273984.0,19088000.0,19119936.0,19151872.0,19183808.0,19215744.0,19247680.0,19279616.0,19311552.0,19125120.0,19157120.0,19189120.0,19221120.0,19253120.0,19285120.0,19317120.0,19349120.0,19162240.0,19194304.0,19226368.0,19258432.0,19290496.0,19322560.0,19354624.0,19386688.0,19199360.0,19231488.0,19263616.0,19295744.0,19327872.0,19360000.0,19392128.0,19424256.0,19236480.0,19268672.0,19300864.0,19333056.0,19365248.0,19397440.0,19429632.0,19461824.0,19273600.0,19305856.0,19338112.0,19370368.0,19402624.0,19434880.0,19467136.0,19499392.0,19310720.0,19343040.0,19375360.0,19407680.0,19440000.0,19472320.0,19504640.0,19536960.0,19347840.0,19380224.0,19412608.0,19444992.0,19477376.0,19509760.0,19542144.0,19574528.0,19384960.0,19417408.0,19449856.0,19482304.0,19514752.0,19547200.0,19579648.0,19612096.0,19422080.0,19454592.0,19487104.0,19519616.0,19552128.0,19584640.0,19617152.0,19649664.0,19459200.0,19491776.0,19524352.0,19556928.0,19589504.0,19622080.0,19654656.0,19687232.0,19496320.0,19528960.0,19561600.0,19594240.0,19626880.0,19659520.0,19692160.0,19724800.0,19533440.0,19566144.0,19598848.0,19631552.0,19664256.0,19696960.0,19729664.0,19762368.0,19
570560.0,19603328.0,19636096.0,19668864.0,19701632.0,19734400.0,19767168.0,19799936.0,19607680.0,19640512.0,19673344.0,19706176.0,19739008.0,19771840.0,19804672.0,19837504.0,19644800.0,19677696.0,19710592.0,19743488.0,19776384.0,19809280.0,19842176.0,19875072.0,19681920.0,19714880.0,19747840.0,19780800.0,19813760.0,19846720.0,19879680.0,19912640.0,19719040.0,19752064.0,19785088.0,19818112.0,19851136.0,19884160.0,19917184.0,19950208.0,19756160.0,19789248.0,19822336.0,19855424.0,19888512.0,19921600.0,19954688.0,19987776.0,19793280.0,19826432.0,19859584.0,19892736.0,19925888.0,19959040.0,19992192.0,20025344.0,19830400.0,19863616.0,19896832.0,19930048.0,19963264.0,19996480.0,20029696.0,20062912.0,19867520.0,19900800.0,19934080.0,19967360.0,20000640.0,20033920.0,20067200.0,20100480.0,19904640.0,19937984.0,19971328.0,20004672.0,20038016.0,20071360.0,20104704.0,20138048.0,19941760.0,19975168.0,20008576.0,20041984.0,20075392.0,20108800.0,20142208.0,20175616.0,19978880.0,20012352.0,20045824.0,20079296.0,20112768.0,20146240.0,20179712.0,20213184.0,20016000.0,20049536.0,20083072.0,20116608.0,20150144.0,20183680.0,20217216.0,20250752.0,20053120.0,20086720.0,20120320.0,20153920.0,20187520.0,20221120.0,20254720.0,20288320.0,20090240.0,20123904.0,20157568.0,20191232.0,20224896.0,20258560.0,20292224.0,20325888.0,20127360.0,20161088.0,20194816.0,20228544.0,20262272.0,20296000.0,20329728.0,20363456.0,20164480.0,20198272.0,20232064.0,20265856.0,20299648.0,20333440.0,20367232.0,20401024.0,20201600.0,20235456.0,20269312.0,20303168.0,20337024.0,20370880.0,20404736.0,20438592.0,19386752.0,19419648.0,19452544.0,19485440.0,19518336.0,19551232.0,19584128.0,19617024.0,18563712.0,18595648.0,18627584.0,18659520.0,18691456.0,18723392.0,18755328.0,18787264.0,17748864.0,17779840.0,17810816.0,17841792.0,17872768.0,17903744.0,17934720.0,17965696.0,16925824.0,16955840.0,16985856.0,17015872.0,17045888.0,17075904.0,17105920.0,17135936.0,16962944.0,16993024.0,17023104.0,17053184.0,17083264.0,17113344.0,
17143424.0,17173504.0,17000064.0,17030208.0,17060352.0,17090496.0,17120640.0,17150784.0,17180928.0,17211072.0,17037184.0,17067392.0,17097600.0,17127808.0,17158016.0,17188224.0,17218432.0,17248640.0,17074304.0,17104576.0,17134848.0,17165120.0,17195392.0,17225664.0,17255936.0,17286208.0,17111424.0,17141760.0,17172096.0,17202432.0,17232768.0,17263104.0,17293440.0,17323776.0,17148544.0,17178944.0,17209344.0,17239744.0,17270144.0,17300544.0,17330944.0,17361344.0,17185664.0,17216128.0,17246592.0,17277056.0,17307520.0,17337984.0,17368448.0,17398912.0,17222784.0,17253312.0,17283840.0,17314368.0,17344896.0,17375424.0,17405952.0,17436480.0,17259904.0,17290496.0,17321088.0,17351680.0,17382272.0,17412864.0,17443456.0,17474048.0,17297024.0,17327680.0,17358336.0,17388992.0,17419648.0,17450304.0,17480960.0,17511616.0,17334144.0,17364864.0,17395584.0,17426304.0,17457024.0,17487744.0,17518464.0,17549184.0,17371264.0,17402048.0,17432832.0,17463616.0,17494400.0,17525184.0,17555968.0,17586752.0,17408384.0,17439232.0,17470080.0,17500928.0,17531776.0,17562624.0,17593472.0,17624320.0,17445504.0,17476416.0,17507328.0,17538240.0,17569152.0,17600064.0,17630976.0,17661888.0,17482624.0,17513600.0,17544576.0,17575552.0,17606528.0,17637504.0,17668480.0,17699456.0,17519744.0,17550784.0,17581824.0,17612864.0,17643904.0,17674944.0,17705984.0,17737024.0,17556864.0,17587968.0,17619072.0,17650176.0,17681280.0,17712384.0,17743488.0,17774592.0,17593984.0,17625152.0,17656320.0,17687488.0,17718656.0,17749824.0,17780992.0,17812160.0,17631104.0,17662336.0,17693568.0,17724800.0,17756032.0,17787264.0,17818496.0,17849728.0,17668224.0,17699520.0,17730816.0,17762112.0,17793408.0,17824704.0,17856000.0,17887296.0,17705344.0,17736704.0,17768064.0,17799424.0,17830784.0,17862144.0,17893504.0,17924864.0,17742464.0,17773888.0,17805312.0,17836736.0,17868160.0,17899584.0,17931008.0,17962432.0,17779584.0,17811072.0,17842560.0,17874048.0,17905536.0,17937024.0,17968512.0,18000000.0,17816704.0,17848256.0,17879808.0,17911360.
0,17942912.0,17974464.0,18006016.0,18037568.0,17853824.0,17885440.0,17917056.0,17948672.0,17980288.0,18011904.0,18043520.0,18075136.0,17890944.0,17922624.0,17954304.0,17985984.0,18017664.0,18049344.0,18081024.0,18112704.0,17928064.0,17959808.0,17991552.0,18023296.0,18055040.0,18086784.0,18118528.0,18150272.0,17965184.0,17996992.0,18028800.0,18060608.0,18092416.0,18124224.0,18156032.0,18187840.0,18002304.0,18034176.0,18066048.0,18097920.0,18129792.0,18161664.0,18193536.0,18225408.0,18039424.0,18071360.0,18103296.0,18135232.0,18167168.0,18199104.0,18231040.0,18262976.0,18076544.0,18108544.0,18140544.0,18172544.0,18204544.0,18236544.0,18268544.0,18300544.0,18113664.0,18145728.0,18177792.0,18209856.0,18241920.0,18273984.0,18306048.0,18338112.0,18150784.0,18182912.0,18215040.0,18247168.0,18279296.0,18311424.0,18343552.0,18375680.0,18187904.0,18220096.0,18252288.0,18284480.0,18316672.0,18348864.0,18381056.0,18413248.0,18225024.0,18257280.0,18289536.0,18321792.0,18354048.0,18386304.0,18418560.0,18450816.0,18262144.0,18294464.0,18326784.0,18359104.0,18391424.0,18423744.0,18456064.0,18488384.0,18299264.0,18331648.0,18364032.0,18396416.0,18428800.0,18461184.0,18493568.0,18525952.0,18008704.0,18040128.0,18071552.0,18102976.0,18134400.0,18165824.0,18197248.0,18228672.0,17709952.0,17740416.0,17770880.0,17801344.0,17831808.0,17862272.0,17892736.0,17923200.0,17747072.0,17777600.0,17808128.0,17838656.0,17869184.0,17899712.0,17930240.0,17960768.0,17784192.0,17814784.0,17845376.0,17875968.0,17906560.0,17937152.0,17967744.0,17998336.0,17821312.0,17851968.0,17882624.0,17913280.0,17943936.0,17974592.0,18005248.0,18035904.0,17858432.0,17889152.0,17919872.0,17950592.0,17981312.0,18012032.0,18042752.0,18073472.0,17895552.0,17926336.0,17957120.0,17987904.0,18018688.0,18049472.0,18080256.0,18111040.0,17932672.0,17963520.0,17994368.0,18025216.0,18056064.0,18086912.0,18117760.0,18148608.0,17969792.0,18000704.0,18031616.0,18062528.0,18093440.0,18124352.0,18155264.0,18186176.0,18006912.0,1803788
8.0,18068864.0,18099840.0,18130816.0,18161792.0,18192768.0,18223744.0,18044032.0,18075072.0,18106112.0,18137152.0,18168192.0,18199232.0,18230272.0,18261312.0,18081152.0,18112256.0,18143360.0,18174464.0,18205568.0,18236672.0,18267776.0,18298880.0,18118272.0,18149440.0,18180608.0,18211776.0,18242944.0,18274112.0,18305280.0,18336448.0,18155392.0,18186624.0,18217856.0,18249088.0,18280320.0,18311552.0,18342784.0,18374016.0,18192512.0,18223808.0,18255104.0,18286400.0,18317696.0,18348992.0,18380288.0,18411584.0,18229632.0,18260992.0,18292352.0,18323712.0,18355072.0,18386432.0,18417792.0,18449152.0,18266752.0,18298176.0,18329600.0,18361024.0,18392448.0,18423872.0,18455296.0,18486720.0,18303872.0,18335360.0,18366848.0,18398336.0,18429824.0,18461312.0,18492800.0,18524288.0,18340992.0,18372544.0,18404096.0,18435648.0,18467200.0,18498752.0,18530304.0,18561856.0,18378112.0,18409728.0,18441344.0,18472960.0,18504576.0,18536192.0,18567808.0,18599424.0,18415232.0,18446912.0,18478592.0,18510272.0,18541952.0,18573632.0,18605312.0,18636992.0,18452352.0,18484096.0,18515840.0,18547584.0,18579328.0,18611072.0,18642816.0,18674560.0,18489472.0,18521280.0,18553088.0,18584896.0,18616704.0,18648512.0,18680320.0,18712128.0,18526592.0,18558464.0,18590336.0,18622208.0,18654080.0,18685952.0,18717824.0,18749696.0,18563712.0,18595648.0,18627584.0,18659520.0,18691456.0,18723392.0,18755328.0,18787264.0,18600832.0,18632832.0,18664832.0,18696832.0,18728832.0,18760832.0,18792832.0,18824832.0,18637952.0,18670016.0,18702080.0,18734144.0,18766208.0,18798272.0,18830336.0,18862400.0,18675072.0,18707200.0,18739328.0,18771456.0,18803584.0,18835712.0,18867840.0,18899968.0,18712192.0,18744384.0,18776576.0,18808768.0,18840960.0,18873152.0,18905344.0,18937536.0,18749312.0,18781568.0,18813824.0,18846080.0,18878336.0,18910592.0,18942848.0,18975104.0,18786432.0,18818752.0,18851072.0,18883392.0,18915712.0,18948032.0,18980352.0,19012672.0,18823552.0,18855936.0,18888320.0,18920704.0,18953088.0,18985472.0,19017856.0,19050
240.0,18860672.0,18893120.0,18925568.0,18958016.0,18990464.0,19022912.0,19055360.0,19087808.0,18897792.0,18930304.0,18962816.0,18995328.0,19027840.0,19060352.0,19092864.0,19125376.0,18934912.0,18967488.0,19000064.0,19032640.0,19065216.0,19097792.0,19130368.0,19162944.0,18972032.0,19004672.0,19037312.0,19069952.0,19102592.0,19135232.0,19167872.0,19200512.0,19009152.0,19041856.0,19074560.0,19107264.0,19139968.0,19172672.0,19205376.0,19238080.0,19046272.0,19079040.0,19111808.0,19144576.0,19177344.0,19210112.0,19242880.0,19275648.0,19083392.0,19116224.0,19149056.0,19181888.0,19214720.0,19247552.0,19280384.0,19313216.0,18268544.0,18300416.0,18332288.0,18364160.0,18396032.0,18427904.0,18459776.0,18491648.0,17445504.0,17476416.0,17507328.0,17538240.0,17569152.0,17600064.0,17630976.0,17661888.0,17482624.0,17513600.0,17544576.0,17575552.0,17606528.0,17637504.0,17668480.0,17699456.0,17519744.0,17550784.0,17581824.0,17612864.0,17643904.0,17674944.0,17705984.0,17737024.0,17556864.0,17587968.0,17619072.0,17650176.0,17681280.0,17712384.0,17743488.0,17774592.0,17593984.0,17625152.0,17656320.0,17687488.0,17718656.0,17749824.0,17780992.0,17812160.0,17631104.0,17662336.0,17693568.0,17724800.0,17756032.0,17787264.0,17818496.0,17849728.0,17668224.0,17699520.0,17730816.0,17762112.0,17793408.0,17824704.0,17856000.0,17887296.0,17705344.0,17736704.0,17768064.0,17799424.0,17830784.0,17862144.0,17893504.0,17924864.0,17742464.0,17773888.0,17805312.0,17836736.0,17868160.0,17899584.0,17931008.0,17962432.0,17779584.0,17811072.0,17842560.0,17874048.0,17905536.0,17937024.0,17968512.0,18000000.0,17816704.0,17848256.0,17879808.0,17911360.0,17942912.0,17974464.0,18006016.0,18037568.0,17853824.0,17885440.0,17917056.0,17948672.0,17980288.0,18011904.0,18043520.0,18075136.0,17890944.0,17922624.0,17954304.0,17985984.0,18017664.0,18049344.0,18081024.0,18112704.0,17928064.0,17959808.0,17991552.0,18023296.0,18055040.0,18086784.0,18118528.0,18150272.0,17965184.0,17996992.0,18028800.0,18060608.0,18092416.0,181
24224.0,18156032.0,18187840.0,18002304.0,18034176.0,18066048.0,18097920.0,18129792.0,18161664.0,18193536.0,18225408.0,18039424.0,18071360.0,18103296.0,18135232.0,18167168.0,18199104.0,18231040.0,18262976.0,18076544.0,18108544.0,18140544.0,18172544.0,18204544.0,18236544.0,18268544.0,18300544.0,18113664.0,18145728.0,18177792.0,18209856.0,18241920.0,18273984.0,18306048.0,18338112.0,18150784.0,18182912.0,18215040.0,18247168.0,18279296.0,18311424.0,18343552.0,18375680.0,18187904.0,18220096.0,18252288.0,18284480.0,18316672.0,18348864.0,18381056.0,18413248.0,18225024.0,18257280.0,18289536.0,18321792.0,18354048.0,18386304.0,18418560.0,18450816.0,18262144.0,18294464.0,18326784.0,18359104.0,18391424.0,18423744.0,18456064.0,18488384.0,18299264.0,18331648.0,18364032.0,18396416.0,18428800.0,18461184.0,18493568.0,18525952.0,18336384.0,18368832.0,18401280.0,18433728.0,18466176.0,18498624.0,18531072.0,18563520.0,18373504.0,18406016.0,18438528.0,18471040.0,18503552.0,18536064.0,18568576.0,18601088.0,18410624.0,18443200.0,18475776.0,18508352.0,18540928.0,18573504.0,18606080.0,18638656.0,18447744.0,18480384.0,18513024.0,18545664.0,18578304.0,18610944.0,18643584.0,18676224.0,18484864.0,18517568.0,18550272.0,18582976.0,18615680.0,18648384.0,18681088.0,18713792.0,18521984.0,18554752.0,18587520.0,18620288.0,18653056.0,18685824.0,18718592.0,18751360.0,18559104.0,18591936.0,18624768.0,18657600.0,18690432.0,18723264.0,18756096.0,18788928.0,18596224.0,18629120.0,18662016.0,18694912.0,18727808.0,18760704.0,18793600.0,18826496.0,18633344.0,18666304.0,18699264.0,18732224.0,18765184.0,18798144.0,18831104.0,18864064.0,18670464.0,18703488.0,18736512.0,18769536.0,18802560.0,18835584.0,18868608.0,18901632.0,18707584.0,18740672.0,18773760.0,18806848.0,18839936.0,18873024.0,18906112.0,18939200.0,18744704.0,18777856.0,18811008.0,18844160.0,18877312.0,18910464.0,18943616.0,18976768.0,18781824.0,18815040.0,18848256.0,18881472.0,18914688.0,18947904.0,18981120.0,19014336.0,18818944.0,18852224.0,18885504.0,1
8918784.0,18952064.0,18985344.0,19018624.0,19051904.0,18528384.0,18560704.0,18593024.0,18625344.0,18657664.0,18689984.0,18722304.0,18754624.0,18229632.0,18260992.0,18292352.0,18323712.0,18355072.0,18386432.0,18417792.0,18449152.0,18266752.0,18298176.0,18329600.0,18361024.0,18392448.0,18423872.0,18455296.0,18486720.0,18303872.0,18335360.0,18366848.0,18398336.0,18429824.0,18461312.0,18492800.0,18524288.0,18340992.0,18372544.0,18404096.0,18435648.0,18467200.0,18498752.0,18530304.0,18561856.0,18378112.0,18409728.0,18441344.0,18472960.0,18504576.0,18536192.0,18567808.0,18599424.0,18415232.0,18446912.0,18478592.0,18510272.0,18541952.0,18573632.0,18605312.0,18636992.0,18452352.0,18484096.0,18515840.0,18547584.0,18579328.0,18611072.0,18642816.0,18674560.0,18489472.0,18521280.0,18553088.0,18584896.0,18616704.0,18648512.0,18680320.0,18712128.0,18526592.0,18558464.0,18590336.0,18622208.0,18654080.0,18685952.0,18717824.0,18749696.0,18563712.0,18595648.0,18627584.0,18659520.0,18691456.0,18723392.0,18755328.0,18787264.0,18600832.0,18632832.0,18664832.0,18696832.0,18728832.0,18760832.0,18792832.0,18824832.0,18637952.0,18670016.0,18702080.0,18734144.0,18766208.0,18798272.0,18830336.0,18862400.0,18675072.0,18707200.0,18739328.0,18771456.0,18803584.0,18835712.0,18867840.0,18899968.0,18712192.0,18744384.0,18776576.0,18808768.0,18840960.0,18873152.0,18905344.0,18937536.0,18749312.0,18781568.0,18813824.0,18846080.0,18878336.0,18910592.0,18942848.0,18975104.0,18786432.0,18818752.0,18851072.0,18883392.0,18915712.0,18948032.0,18980352.0,19012672.0,18823552.0,18855936.0,18888320.0,18920704.0,18953088.0,18985472.0,19017856.0,19050240.0,18860672.0,18893120.0,18925568.0,18958016.0,18990464.0,19022912.0,19055360.0,19087808.0,18897792.0,18930304.0,18962816.0,18995328.0,19027840.0,19060352.0,19092864.0,19125376.0,18934912.0,18967488.0,19000064.0,19032640.0,19065216.0,19097792.0,19130368.0,19162944.0,18972032.0,19004672.0,19037312.0,19069952.0,19102592.0,19135232.0,19167872.0,19200512.0,19009152.0
,19041856.0,19074560.0,19107264.0,19139968.0,19172672.0,19205376.0,19238080.0,19046272.0,19079040.0,19111808.0,19144576.0,19177344.0,19210112.0,19242880.0,19275648.0,19083392.0,19116224.0,19149056.0,19181888.0,19214720.0,19247552.0,19280384.0,19313216.0,19120512.0,19153408.0,19186304.0,19219200.0,19252096.0,19284992.0,19317888.0,19350784.0,19157632.0,19190592.0,19223552.0,19256512.0,19289472.0,19322432.0,19355392.0,19388352.0,19194752.0,19227776.0,19260800.0,19293824.0,19326848.0,19359872.0,19392896.0,19425920.0,19231872.0,19264960.0,19298048.0,19331136.0,19364224.0,19397312.0,19430400.0,19463488.0,19268992.0,19302144.0,19335296.0,19368448.0,19401600.0,19434752.0,19467904.0,19501056.0,19306112.0,19339328.0,19372544.0,19405760.0,19438976.0,19472192.0,19505408.0,19538624.0,19343232.0,19376512.0,19409792.0,19443072.0,19476352.0,19509632.0,19542912.0,19576192.0,19380352.0,19413696.0,19447040.0,19480384.0,19513728.0,19547072.0,19580416.0,19613760.0,19417472.0,19450880.0,19484288.0,19517696.0,19551104.0,19584512.0,19617920.0,19651328.0,19454592.0,19488064.0,19521536.0,19555008.0,19588480.0,19621952.0,19655424.0,19688896.0,19491712.0,19525248.0,19558784.0,19592320.0,19625856.0,19659392.0,19692928.0,19726464.0,19528832.0,19562432.0,19596032.0,19629632.0,19663232.0,19696832.0,19730432.0,19764032.0,19565952.0,19599616.0,19633280.0,19666944.0,19700608.0,19734272.0,19767936.0,19801600.0,19603072.0,19636800.0,19670528.0,19704256.0,19737984.0,19771712.0,19805440.0,19839168.0,18788224.0,18820992.0,18853760.0,18886528.0,18919296.0,18952064.0,18984832.0,19017600.0,17965184.0,17996992.0,18028800.0,18060608.0,18092416.0,18124224.0,18156032.0,18187840.0,9055616.0,9071536.0,9087456.0,9103376.0,9119296.0,9135216.0,9151136.0,9167056.0,9631488.0,9647280.0,9663072.0,9678864.0,9694656.0,9710448.0,9726240.0,9742032.0,19408256.0,19439872.0,19471488.0,19503104.0,19534720.0,19566336.0,19597952.0,19629568.0,19446400.0,19478080.0,19509760.0,19541440.0,19573120.0,19604800.0,19636480.0,19668160.0,19
484544.0,19516288.0,19548032.0,19579776.0,19611520.0,19643264.0,19675008.0,19706752.0,19522688.0,19554496.0,19586304.0,19618112.0,19649920.0,19681728.0,19713536.0,19745344.0,19560832.0,19592704.0,19624576.0,19656448.0,19688320.0,19720192.0,19752064.0,19783936.0,19598976.0,19630912.0,19662848.0,19694784.0,19726720.0,19758656.0,19790592.0,19822528.0,19637120.0,19669120.0,19701120.0,19733120.0,19765120.0,19797120.0,19829120.0,19861120.0,19675264.0,19707328.0,19739392.0,19771456.0,19803520.0,19835584.0,19867648.0,19899712.0,19713408.0,19745536.0,19777664.0,19809792.0,19841920.0,19874048.0,19906176.0,19938304.0,19751552.0,19783744.0,19815936.0,19848128.0,19880320.0,19912512.0,19944704.0,19976896.0,19789696.0,19821952.0,19854208.0,19886464.0,19918720.0,19950976.0,19983232.0,20015488.0,19827840.0,19860160.0,19892480.0,19924800.0,19957120.0,19989440.0,20021760.0,20054080.0,19865984.0,19898368.0,19930752.0,19963136.0,19995520.0,20027904.0,20060288.0,20092672.0,19904128.0,19936576.0,19969024.0,20001472.0,20033920.0,20066368.0,20098816.0,20131264.0,19942272.0,19974784.0,20007296.0,20039808.0,20072320.0,20104832.0,20137344.0,20169856.0,19980416.0,20012992.0,20045568.0,20078144.0,20110720.0,20143296.0,20175872.0,20208448.0,20018560.0,20051200.0,20083840.0,20116480.0,20149120.0,20181760.0,20214400.0,20247040.0,20056704.0,20089408.0,20122112.0,20154816.0,20187520.0,20220224.0,20252928.0,20285632.0,20094848.0,20127616.0,20160384.0,20193152.0,20225920.0,20258688.0,20291456.0,20324224.0,20132992.0,20165824.0,20198656.0,20231488.0,20264320.0,20297152.0,20329984.0,20362816.0,20171136.0,20204032.0,20236928.0,20269824.0,20302720.0,20335616.0,20368512.0,20401408.0,20209280.0,20242240.0,20275200.0,20308160.0,20341120.0,20374080.0,20407040.0,20440000.0,20247424.0,20280448.0,20313472.0,20346496.0,20379520.0,20412544.0,20445568.0,20478592.0,20285568.0,20318656.0,20351744.0,20384832.0,20417920.0,20451008.0,20484096.0,20517184.0,20323712.0,20356864.0,20390016.0,20423168.0,20456320.0,20489472.0,
20522624.0,20555776.0,20361856.0,20395072.0,20428288.0,20461504.0,20494720.0,20527936.0,20561152.0,20594368.0,20400000.0,20433280.0,20466560.0,20499840.0,20533120.0,20566400.0,20599680.0,20632960.0,20438144.0,20471488.0,20504832.0,20538176.0,20571520.0,20604864.0,20638208.0,20671552.0,20476288.0,20509696.0,20543104.0,20576512.0,20609920.0,20643328.0,20676736.0,20710144.0,20514432.0,20547904.0,20581376.0,20614848.0,20648320.0,20681792.0,20715264.0,20748736.0,20552576.0,20586112.0,20619648.0,20653184.0,20686720.0,20720256.0,20753792.0,20787328.0,20590720.0,20624320.0,20657920.0,20691520.0,20725120.0,20758720.0,20792320.0,20825920.0,20628864.0,20662528.0,20696192.0,20729856.0,20763520.0,20797184.0,20830848.0,20864512.0,20667008.0,20700736.0,20734464.0,20768192.0,20801920.0,20835648.0,20869376.0,20903104.0,20705152.0,20738944.0,20772736.0,20806528.0,20840320.0,20874112.0,20907904.0,20941696.0,20743296.0,20777152.0,20811008.0,20844864.0,20878720.0,20912576.0,20946432.0,20980288.0,19913088.0,19945984.0,19978880.0,20011776.0,20044672.0,20077568.0,20110464.0,20143360.0,19074688.0,19106624.0,19138560.0,19170496.0,19202432.0,19234368.0,19266304.0,19298240.0,18244480.0,18275456.0,18306432.0,18337408.0,18368384.0,18399360.0,18430336.0,18461312.0,17406080.0,17436096.0,17466112.0,17496128.0,17526144.0,17556160.0,17586176.0,17616192.0,17444224.0,17474304.0,17504384.0,17534464.0,17564544.0,17594624.0,17624704.0,17654784.0,17482368.0,17512512.0,17542656.0,17572800.0,17602944.0,17633088.0,17663232.0,17693376.0,17520512.0,17550720.0,17580928.0,17611136.0,17641344.0,17671552.0,17701760.0,17731968.0,17558656.0,17588928.0,17619200.0,17649472.0,17679744.0,17710016.0,17740288.0,17770560.0,17596800.0,17627136.0,17657472.0,17687808.0,17718144.0,17748480.0,17778816.0,17809152.0,17634944.0,17665344.0,17695744.0,17726144.0,17756544.0,17786944.0,17817344.0,17847744.0,17673088.0,17703552.0,17734016.0,17764480.0,17794944.0,17825408.0,17855872.0,17886336.0,17711232.0,17741760.0,17772288.0,17802816.
0,17833344.0,17863872.0,17894400.0,17924928.0,17749376.0,17779968.0,17810560.0,17841152.0,17871744.0,17902336.0,17932928.0,17963520.0,17787520.0,17818176.0,17848832.0,17879488.0,17910144.0,17940800.0,17971456.0,18002112.0,17825664.0,17856384.0,17887104.0,17917824.0,17948544.0,17979264.0,18009984.0,18040704.0,17863808.0,17894592.0,17925376.0,17956160.0,17986944.0,18017728.0,18048512.0,18079296.0,17901952.0,17932800.0,17963648.0,17994496.0,18025344.0,18056192.0,18087040.0,18117888.0,17940096.0,17971008.0,18001920.0,18032832.0,18063744.0,18094656.0,18125568.0,18156480.0,17978240.0,18009216.0,18040192.0,18071168.0,18102144.0,18133120.0,18164096.0,18195072.0,18016384.0,18047424.0,18078464.0,18109504.0,18140544.0,18171584.0,18202624.0,18233664.0,18054528.0,18085632.0,18116736.0,18147840.0,18178944.0,18210048.0,18241152.0,18272256.0,18092672.0,18123840.0,18155008.0,18186176.0,18217344.0,18248512.0,18279680.0,18310848.0,18130816.0,18162048.0,18193280.0,18224512.0,18255744.0,18286976.0,18318208.0,18349440.0,18168960.0,18200256.0,18231552.0,18262848.0,18294144.0,18325440.0,18356736.0,18388032.0,18207104.0,18238464.0,18269824.0,18301184.0,18332544.0,18363904.0,18395264.0,18426624.0,18245248.0,18276672.0,18308096.0,18339520.0,18370944.0,18402368.0,18433792.0,18465216.0,18283392.0,18314880.0,18346368.0,18377856.0,18409344.0,18440832.0,18472320.0,18503808.0,18321536.0,18353088.0,18384640.0,18416192.0,18447744.0,18479296.0,18510848.0,18542400.0,18359680.0,18391296.0,18422912.0,18454528.0,18486144.0,18517760.0,18549376.0,18580992.0,18397824.0,18429504.0,18461184.0,18492864.0,18524544.0,18556224.0,18587904.0,18619584.0,18435968.0,18467712.0,18499456.0,18531200.0,18562944.0,18594688.0,18626432.0,18658176.0,18474112.0,18505920.0,18537728.0,18569536.0,18601344.0,18633152.0,18664960.0,18696768.0,18512256.0,18544128.0,18576000.0,18607872.0,18639744.0,18671616.0,18703488.0,18735360.0,18550400.0,18582336.0,18614272.0,18646208.0,18678144.0,18710080.0,18742016.0,18773952.0,18588544.0,1862054
4.0,18652544.0,18684544.0,18716544.0,18748544.0,18780544.0,18812544.0,18626688.0,18658752.0,18690816.0,18722880.0,18754944.0,18787008.0,18819072.0,18851136.0,18664832.0,18696960.0,18729088.0,18761216.0,18793344.0,18825472.0,18857600.0,18889728.0,18702976.0,18735168.0,18767360.0,18799552.0,18831744.0,18863936.0,18896128.0,18928320.0,18741120.0,18773376.0,18805632.0,18837888.0,18870144.0,18902400.0,18934656.0,18966912.0,18779264.0,18811584.0,18843904.0,18876224.0,18908544.0,18940864.0,18973184.0,19005504.0,18817408.0,18849792.0,18882176.0,18914560.0,18946944.0,18979328.0,19011712.0,19044096.0,18511488.0,18542912.0,18574336.0,18605760.0,18637184.0,18668608.0,18700032.0,18731456.0,18197376.0,18227840.0,18258304.0,18288768.0,18319232.0,18349696.0,18380160.0,18410624.0,18235520.0,18266048.0,18296576.0,18327104.0,18357632.0,18388160.0,18418688.0,18449216.0,18273664.0,18304256.0,18334848.0,18365440.0,18396032.0,18426624.0,18457216.0,18487808.0,18311808.0,18342464.0,18373120.0,18403776.0,18434432.0,18465088.0,18495744.0,18526400.0,18349952.0,18380672.0,18411392.0,18442112.0,18472832.0,18503552.0,18534272.0,18564992.0,18388096.0,18418880.0,18449664.0,18480448.0,18511232.0,18542016.0,18572800.0,18603584.0,18426240.0,18457088.0,18487936.0,18518784.0,18549632.0,18580480.0,18611328.0,18642176.0,18464384.0,18495296.0,18526208.0,18557120.0,18588032.0,18618944.0,18649856.0,18680768.0,18502528.0,18533504.0,18564480.0,18595456.0,18626432.0,18657408.0,18688384.0,18719360.0,18540672.0,18571712.0,18602752.0,18633792.0,18664832.0,18695872.0,18726912.0,18757952.0,18578816.0,18609920.0,18641024.0,18672128.0,18703232.0,18734336.0,18765440.0,18796544.0,18616960.0,18648128.0,18679296.0,18710464.0,18741632.0,18772800.0,18803968.0,18835136.0,18655104.0,18686336.0,18717568.0,18748800.0,18780032.0,18811264.0,18842496.0,18873728.0,18693248.0,18724544.0,18755840.0,18787136.0,18818432.0,18849728.0,18881024.0,18912320.0,18731392.0,18762752.0,18794112.0,18825472.0,18856832.0,18888192.0,18919552.0,18950
912.0,18769536.0,18800960.0,18832384.0,18863808.0,18895232.0,18926656.0,18958080.0,18989504.0,18807680.0,18839168.0,18870656.0,18902144.0,18933632.0,18965120.0,18996608.0,19028096.0,18845824.0,18877376.0,18908928.0,18940480.0,18972032.0,19003584.0,19035136.0,19066688.0,18883968.0,18915584.0,18947200.0,18978816.0,19010432.0,19042048.0,19073664.0,19105280.0,18922112.0,18953792.0,18985472.0,19017152.0,19048832.0,19080512.0,19112192.0,19143872.0,18960256.0,18992000.0,19023744.0,19055488.0,19087232.0,19118976.0,19150720.0,19182464.0,18998400.0,19030208.0,19062016.0,19093824.0,19125632.0,19157440.0,19189248.0,19221056.0,19036544.0,19068416.0,19100288.0,19132160.0,19164032.0,19195904.0,19227776.0,19259648.0,19074688.0,19106624.0,19138560.0,19170496.0,19202432.0,19234368.0,19266304.0,19298240.0,19112832.0,19144832.0,19176832.0,19208832.0,19240832.0,19272832.0,19304832.0,19336832.0,19150976.0,19183040.0,19215104.0,19247168.0,19279232.0,19311296.0,19343360.0,19375424.0,19189120.0,19221248.0,19253376.0,19285504.0,19317632.0,19349760.0,19381888.0,19414016.0,19227264.0,19259456.0,19291648.0,19323840.0,19356032.0,19388224.0,19420416.0,19452608.0,19265408.0,19297664.0,19329920.0,19362176.0,19394432.0,19426688.0,19458944.0,19491200.0,19303552.0,19335872.0,19368192.0,19400512.0,19432832.0,19465152.0,19497472.0,19529792.0,19341696.0,19374080.0,19406464.0,19438848.0,19471232.0,19503616.0,19536000.0,19568384.0,19379840.0,19412288.0,19444736.0,19477184.0,19509632.0,19542080.0,19574528.0,19606976.0,19417984.0,19450496.0,19483008.0,19515520.0,19548032.0,19580544.0,19613056.0,19645568.0,19456128.0,19488704.0,19521280.0,19553856.0,19586432.0,19619008.0,19651584.0,19684160.0,19494272.0,19526912.0,19559552.0,19592192.0,19624832.0,19657472.0,19690112.0,19722752.0,19532416.0,19565120.0,19597824.0,19630528.0,19663232.0,19695936.0,19728640.0,19761344.0,19570560.0,19603328.0,19636096.0,19668864.0,19701632.0,19734400.0,19767168.0,19799936.0,19608704.0,19641536.0,19674368.0,19707200.0,19740032.0,197
72864.0,19805696.0,19838528.0,18778496.0,18810368.0,18842240.0,18874112.0,18905984.0,18937856.0,18969728.0,19001600.0,17940096.0,17971008.0,18001920.0,18032832.0,18063744.0,18094656.0,18125568.0,18156480.0,17978240.0,18009216.0,18040192.0,18071168.0,18102144.0,18133120.0,18164096.0,18195072.0,18016384.0,18047424.0,18078464.0,18109504.0,18140544.0,18171584.0,18202624.0,18233664.0,18054528.0,18085632.0,18116736.0,18147840.0,18178944.0,18210048.0,18241152.0,18272256.0,18092672.0,18123840.0,18155008.0,18186176.0,18217344.0,18248512.0,18279680.0,18310848.0,18130816.0,18162048.0,18193280.0,18224512.0,18255744.0,18286976.0,18318208.0,18349440.0,18168960.0,18200256.0,18231552.0,18262848.0,18294144.0,18325440.0,18356736.0,18388032.0,18207104.0,18238464.0,18269824.0,18301184.0,18332544.0,18363904.0,18395264.0,18426624.0,18245248.0,18276672.0,18308096.0,18339520.0,18370944.0,18402368.0,18433792.0,18465216.0,18283392.0,18314880.0,18346368.0,18377856.0,18409344.0,18440832.0,18472320.0,18503808.0,18321536.0,18353088.0,18384640.0,18416192.0,18447744.0,18479296.0,18510848.0,18542400.0,18359680.0,18391296.0,18422912.0,18454528.0,18486144.0,18517760.0,18549376.0,18580992.0,18397824.0,18429504.0,18461184.0,18492864.0,18524544.0,18556224.0,18587904.0,18619584.0,18435968.0,18467712.0,18499456.0,18531200.0,18562944.0,18594688.0,18626432.0,18658176.0,18474112.0,18505920.0,18537728.0,18569536.0,18601344.0,18633152.0,18664960.0,18696768.0,18512256.0,18544128.0,18576000.0,18607872.0,18639744.0,18671616.0,18703488.0,18735360.0,18550400.0,18582336.0,18614272.0,18646208.0,18678144.0,18710080.0,18742016.0,18773952.0,18588544.0,18620544.0,18652544.0,18684544.0,18716544.0,18748544.0,18780544.0,18812544.0,18626688.0,18658752.0,18690816.0,18722880.0,18754944.0,18787008.0,18819072.0,18851136.0,18664832.0,18696960.0,18729088.0,18761216.0,18793344.0,18825472.0,18857600.0,18889728.0,18702976.0,18735168.0,18767360.0,18799552.0,18831744.0,18863936.0,18896128.0,18928320.0,18741120.0,18773376.0,18805632.0,1
8837888.0,18870144.0,18902400.0,18934656.0,18966912.0,18779264.0,18811584.0,18843904.0,18876224.0,18908544.0,18940864.0,18973184.0,19005504.0,18817408.0,18849792.0,18882176.0,18914560.0,18946944.0,18979328.0,19011712.0,19044096.0,18855552.0,18888000.0,18920448.0,18952896.0,18985344.0,19017792.0,19050240.0,19082688.0,18893696.0,18926208.0,18958720.0,18991232.0,19023744.0,19056256.0,19088768.0,19121280.0,18931840.0,18964416.0,18996992.0,19029568.0,19062144.0,19094720.0,19127296.0,19159872.0,18969984.0,19002624.0,19035264.0,19067904.0,19100544.0,19133184.0,19165824.0,19198464.0,19008128.0,19040832.0,19073536.0,19106240.0,19138944.0,19171648.0,19204352.0,19237056.0,19046272.0,19079040.0,19111808.0,19144576.0,19177344.0,19210112.0,19242880.0,19275648.0,19084416.0,19117248.0,19150080.0,19182912.0,19215744.0,19248576.0,19281408.0,19314240.0,19122560.0,19155456.0,19188352.0,19221248.0,19254144.0,19287040.0,19319936.0,19352832.0,19160704.0,19193664.0,19226624.0,19259584.0,19292544.0,19325504.0,19358464.0,19391424.0,19198848.0,19231872.0,19264896.0,19297920.0,19330944.0,19363968.0,19396992.0,19430016.0,19236992.0,19270080.0,19303168.0,19336256.0,19369344.0,19402432.0,19435520.0,19468608.0,19275136.0,19308288.0,19341440.0,19374592.0,19407744.0,19440896.0,19474048.0,19507200.0,19313280.0,19346496.0,19379712.0,19412928.0,19446144.0,19479360.0,19512576.0,19545792.0,19351424.0,19384704.0,19417984.0,19451264.0,19484544.0,19517824.0,19551104.0,19584384.0,19045504.0,19077824.0,19110144.0,19142464.0,19174784.0,19207104.0,19239424.0,19271744.0,18731392.0,18762752.0,18794112.0,18825472.0,18856832.0,18888192.0,18919552.0,18950912.0,18769536.0,18800960.0,18832384.0,18863808.0,18895232.0,18926656.0,18958080.0,18989504.0,18807680.0,18839168.0,18870656.0,18902144.0,18933632.0,18965120.0,18996608.0,19028096.0,18845824.0,18877376.0,18908928.0,18940480.0,18972032.0,19003584.0,19035136.0,19066688.0,18883968.0,18915584.0,18947200.0,18978816.0,19010432.0,19042048.0,19073664.0,19105280.0,18922112.0
,18953792.0,18985472.0,19017152.0,19048832.0,19080512.0,19112192.0,19143872.0,18960256.0,18992000.0,19023744.0,19055488.0,19087232.0,19118976.0,19150720.0,19182464.0,18998400.0,19030208.0,19062016.0,19093824.0,19125632.0,19157440.0,19189248.0,19221056.0,19036544.0,19068416.0,19100288.0,19132160.0,19164032.0,19195904.0,19227776.0,19259648.0,19074688.0,19106624.0,19138560.0,19170496.0,19202432.0,19234368.0,19266304.0,19298240.0,19112832.0,19144832.0,19176832.0,19208832.0,19240832.0,19272832.0,19304832.0,19336832.0,19150976.0,19183040.0,19215104.0,19247168.0,19279232.0,19311296.0,19343360.0,19375424.0,19189120.0,19221248.0,19253376.0,19285504.0,19317632.0,19349760.0,19381888.0,19414016.0,19227264.0,19259456.0,19291648.0,19323840.0,19356032.0,19388224.0,19420416.0,19452608.0,19265408.0,19297664.0,19329920.0,19362176.0,19394432.0,19426688.0,19458944.0,19491200.0,19303552.0,19335872.0,19368192.0,19400512.0,19432832.0,19465152.0,19497472.0,19529792.0,19341696.0,19374080.0,19406464.0,19438848.0,19471232.0,19503616.0,19536000.0,19568384.0,19379840.0,19412288.0,19444736.0,19477184.0,19509632.0,19542080.0,19574528.0,19606976.0,19417984.0,19450496.0,19483008.0,19515520.0,19548032.0,19580544.0,19613056.0,19645568.0,19456128.0,19488704.0,19521280.0,19553856.0,19586432.0,19619008.0,19651584.0,19684160.0,19494272.0,19526912.0,19559552.0,19592192.0,19624832.0,19657472.0,19690112.0,19722752.0,19532416.0,19565120.0,19597824.0,19630528.0,19663232.0,19695936.0,19728640.0,19761344.0,19570560.0,19603328.0,19636096.0,19668864.0,19701632.0,19734400.0,19767168.0,19799936.0,19608704.0,19641536.0,19674368.0,19707200.0,19740032.0,19772864.0,19805696.0,19838528.0,19646848.0,19679744.0,19712640.0,19745536.0,19778432.0,19811328.0,19844224.0,19877120.0,19684992.0,19717952.0,19750912.0,19783872.0,19816832.0,19849792.0,19882752.0,19915712.0,19723136.0,19756160.0,19789184.0,19822208.0,19855232.0,19888256.0,19921280.0,19954304.0,19761280.0,19794368.0,19827456.0,19860544.0,19893632.0,19926720.0,19959808
.0,19992896.0,19799424.0,19832576.0,19865728.0,19898880.0,19932032.0,19965184.0,19998336.0,20031488.0,19837568.0,19870784.0,19904000.0,19937216.0,19970432.0,20003648.0,20036864.0,20070080.0,19875712.0,19908992.0,19942272.0,19975552.0,20008832.0,20042112.0,20075392.0,20108672.0,19913856.0,19947200.0,19980544.0,20013888.0,20047232.0,20080576.0,20113920.0,20147264.0,19952000.0,19985408.0,20018816.0,20052224.0,20085632.0,20119040.0,20152448.0,20185856.0,19990144.0,20023616.0,20057088.0,20090560.0,20124032.0,20157504.0,20190976.0,20224448.0,20028288.0,20061824.0,20095360.0,20128896.0,20162432.0,20195968.0,20229504.0,20263040.0,20066432.0,20100032.0,20133632.0,20167232.0,20200832.0,20234432.0,20268032.0,20301632.0,20104576.0,20138240.0,20171904.0,20205568.0,20239232.0,20272896.0,20306560.0,20340224.0,20142720.0,20176448.0,20210176.0,20243904.0,20277632.0,20311360.0,20345088.0,20378816.0,19312512.0,19345280.0,19378048.0,19410816.0,19443584.0,19476352.0,19509120.0,19541888.0,18474112.0,18505920.0,18537728.0,18569536.0,18601344.0,18633152.0,18664960.0,18696768.0,9310336.0,9326256.0,9342176.0,9358096.0,9374016.0,9389936.0,9405856.0,9421776.0,9884160.0,9899952.0,9915744.0,9931536.0,9947328.0,9963120.0,9978912.0,9994704.0,19914112.0,19945728.0,19977344.0,20008960.0,20040576.0,20072192.0,20103808.0,20135424.0,19953280.0,19984960.0,20016640.0,20048320.0,20080000.0,20111680.0,20143360.0,20175040.0,19992448.0,20024192.0,20055936.0,20087680.0,20119424.0,20151168.0,20182912.0,20214656.0,20031616.0,20063424.0,20095232.0,20127040.0,20158848.0,20190656.0,20222464.0,20254272.0,20070784.0,20102656.0,20134528.0,20166400.0,20198272.0,20230144.0,20262016.0,20293888.0,20109952.0,20141888.0,20173824.0,20205760.0,20237696.0,20269632.0,20301568.0,20333504.0,20149120.0,20181120.0,20213120.0,20245120.0,20277120.0,20309120.0,20341120.0,20373120.0,20188288.0,20220352.0,20252416.0,20284480.0,20316544.0,20348608.0,20380672.0,20412736.0,20227456.0,20259584.0,20291712.0,20323840.0,20355968.0,20388096.0,
20420224.0,20452352.0,20266624.0,20298816.0,20331008.0,20363200.0,20395392.0,20427584.0,20459776.0,20491968.0,20305792.0,20338048.0,20370304.0,20402560.0,20434816.0,20467072.0,20499328.0,20531584.0,20344960.0,20377280.0,20409600.0,20441920.0,20474240.0,20506560.0,20538880.0,20571200.0,20384128.0,20416512.0,20448896.0,20481280.0,20513664.0,20546048.0,20578432.0,20610816.0,20423296.0,20455744.0,20488192.0,20520640.0,20553088.0,20585536.0,20617984.0,20650432.0,20462464.0,20494976.0,20527488.0,20560000.0,20592512.0,20625024.0,20657536.0,20690048.0,20501632.0,20534208.0,20566784.0,20599360.0,20631936.0,20664512.0,20697088.0,20729664.0,20540800.0,20573440.0,20606080.0,20638720.0,20671360.0,20704000.0,20736640.0,20769280.0,20579968.0,20612672.0,20645376.0,20678080.0,20710784.0,20743488.0,20776192.0,20808896.0,20619136.0,20651904.0,20684672.0,20717440.0,20750208.0,20782976.0,20815744.0,20848512.0,20658304.0,20691136.0,20723968.0,20756800.0,20789632.0,20822464.0,20855296.0,20888128.0,20697472.0,20730368.0,20763264.0,20796160.0,20829056.0,20861952.0,20894848.0,20927744.0,20736640.0,20769600.0,20802560.0,20835520.0,20868480.0,20901440.0,20934400.0,20967360.0,20775808.0,20808832.0,20841856.0,20874880.0,20907904.0,20940928.0,20973952.0,21006976.0,20814976.0,20848064.0,20881152.0,20914240.0,20947328.0,20980416.0,21013504.0,21046592.0,20854144.0,20887296.0,20920448.0,20953600.0,20986752.0,21019904.0,21053056.0,21086208.0,20893312.0,20926528.0,20959744.0,20992960.0,21026176.0,21059392.0,21092608.0,21125824.0,20932480.0,20965760.0,20999040.0,21032320.0,21065600.0,21098880.0,21132160.0,21165440.0,20971648.0,21004992.0,21038336.0,21071680.0,21105024.0,21138368.0,21171712.0,21205056.0,21010816.0,21044224.0,21077632.0,21111040.0,21144448.0,21177856.0,21211264.0,21244672.0,21049984.0,21083456.0,21116928.0,21150400.0,21183872.0,21217344.0,21250816.0,21284288.0,21089152.0,21122688.0,21156224.0,21189760.0,21223296.0,21256832.0,21290368.0,21323904.0,21128320.0,21161920.0,21195520.0,21229120.
0,21262720.0,21296320.0,21329920.0,21363520.0,21167488.0,21201152.0,21234816.0,21268480.0,21302144.0,21335808.0,21369472.0,21403136.0,21206656.0,21240384.0,21274112.0,21307840.0,21341568.0,21375296.0,21409024.0,21442752.0,21245824.0,21279616.0,21313408.0,21347200.0,21380992.0,21414784.0,21448576.0,21482368.0,21284992.0,21318848.0,21352704.0,21386560.0,21420416.0,21454272.0,21488128.0,21521984.0,20439424.0,20472320.0,20505216.0,20538112.0,20571008.0,20603904.0,20636800.0,20669696.0,19585664.0,19617600.0,19649536.0,19681472.0,19713408.0,19745344.0,19777280.0,19809216.0,18740096.0,18771072.0,18802048.0,18833024.0,18864000.0,18894976.0,18925952.0,18956928.0,17886336.0,17916352.0,17946368.0,17976384.0,18006400.0,18036416.0,18066432.0,18096448.0,17925504.0,17955584.0,17985664.0,18015744.0,18045824.0,18075904.0,18105984.0,18136064.0,17964672.0,17994816.0,18024960.0,18055104.0,18085248.0,18115392.0,18145536.0,18175680.0,18003840.0,18034048.0,18064256.0,18094464.0,18124672.0,18154880.0,18185088.0,18215296.0,18043008.0,18073280.0,18103552.0,18133824.0,18164096.0,18194368.0,18224640.0,18254912.0,18082176.0,18112512.0,18142848.0,18173184.0,18203520.0,18233856.0,18264192.0,18294528.0,18121344.0,18151744.0,18182144.0,18212544.0,18242944.0,18273344.0,18303744.0,18334144.0,18160512.0,18190976.0,18221440.0,18251904.0,18282368.0,18312832.0,18343296.0,18373760.0,18199680.0,18230208.0,18260736.0,18291264.0,18321792.0,18352320.0,18382848.0,18413376.0,18238848.0,18269440.0,18300032.0,18330624.0,18361216.0,18391808.0,18422400.0,18452992.0,18278016.0,18308672.0,18339328.0,18369984.0,18400640.0,18431296.0,18461952.0,18492608.0,18317184.0,18347904.0,18378624.0,18409344.0,18440064.0,18470784.0,18501504.0,18532224.0,18356352.0,18387136.0,18417920.0,18448704.0,18479488.0,18510272.0,18541056.0,18571840.0,18395520.0,18426368.0,18457216.0,18488064.0,18518912.0,18549760.0,18580608.0,18611456.0,18434688.0,18465600.0,18496512.0,18527424.0,18558336.0,18589248.0,18620160.0,18651072.0,18473856.0,1850483
2.0,18535808.0,18566784.0,18597760.0,18628736.0,18659712.0,18690688.0,18513024.0,18544064.0,18575104.0,18606144.0,18637184.0,18668224.0,18699264.0,18730304.0,18552192.0,18583296.0,18614400.0,18645504.0,18676608.0,18707712.0,18738816.0,18769920.0,18591360.0,18622528.0,18653696.0,18684864.0,18716032.0,18747200.0,18778368.0,18809536.0,18630528.0,18661760.0,18692992.0,18724224.0,18755456.0,18786688.0,18817920.0,18849152.0,18669696.0,18700992.0,18732288.0,18763584.0,18794880.0,18826176.0,18857472.0,18888768.0,18708864.0,18740224.0,18771584.0,18802944.0,18834304.0,18865664.0,18897024.0,18928384.0,18748032.0,18779456.0,18810880.0,18842304.0,18873728.0,18905152.0,18936576.0,18968000.0,18787200.0,18818688.0,18850176.0,18881664.0,18913152.0,18944640.0,18976128.0,19007616.0,18826368.0,18857920.0,18889472.0,18921024.0,18952576.0,18984128.0,19015680.0,19047232.0,18865536.0,18897152.0,18928768.0,18960384.0,18992000.0,19023616.0,19055232.0,19086848.0,18904704.0,18936384.0,18968064.0,18999744.0,19031424.0,19063104.0,19094784.0,19126464.0,18943872.0,18975616.0,19007360.0,19039104.0,19070848.0,19102592.0,19134336.0,19166080.0,18983040.0,19014848.0,19046656.0,19078464.0,19110272.0,19142080.0,19173888.0,19205696.0,19022208.0,19054080.0,19085952.0,19117824.0,19149696.0,19181568.0,19213440.0,19245312.0,19061376.0,19093312.0,19125248.0,19157184.0,19189120.0,19221056.0,19252992.0,19284928.0,19100544.0,19132544.0,19164544.0,19196544.0,19228544.0,19260544.0,19292544.0,19324544.0,19139712.0,19171776.0,19203840.0,19235904.0,19267968.0,19300032.0,19332096.0,19364160.0,19178880.0,19211008.0,19243136.0,19275264.0,19307392.0,19339520.0,19371648.0,19403776.0,19218048.0,19250240.0,19282432.0,19314624.0,19346816.0,19379008.0,19411200.0,19443392.0,19257216.0,19289472.0,19321728.0,19353984.0,19386240.0,19418496.0,19450752.0,19483008.0,19296384.0,19328704.0,19361024.0,19393344.0,19425664.0,19457984.0,19490304.0,19522624.0,19335552.0,19367936.0,19400320.0,19432704.0,19465088.0,19497472.0,19529856.0,19562
240.0,19014272.0,19045696.0,19077120.0,19108544.0,19139968.0,19171392.0,19202816.0,19234240.0,18684800.0,18715264.0,18745728.0,18776192.0,18806656.0,18837120.0,18867584.0,18898048.0,18723968.0,18754496.0,18785024.0,18815552.0,18846080.0,18876608.0,18907136.0,18937664.0,18763136.0,18793728.0,18824320.0,18854912.0,18885504.0,18916096.0,18946688.0,18977280.0,18802304.0,18832960.0,18863616.0,18894272.0,18924928.0,18955584.0,18986240.0,19016896.0,18841472.0,18872192.0,18902912.0,18933632.0,18964352.0,18995072.0,19025792.0,19056512.0,18880640.0,18911424.0,18942208.0,18972992.0,19003776.0,19034560.0,19065344.0,19096128.0,18919808.0,18950656.0,18981504.0,19012352.0,19043200.0,19074048.0,19104896.0,19135744.0,18958976.0,18989888.0,19020800.0,19051712.0,19082624.0,19113536.0,19144448.0,19175360.0,18998144.0,19029120.0,19060096.0,19091072.0,19122048.0,19153024.0,19184000.0,19214976.0,19037312.0,19068352.0,19099392.0,19130432.0,19161472.0,19192512.0,19223552.0,19254592.0,19076480.0,19107584.0,19138688.0,19169792.0,19200896.0,19232000.0,19263104.0,19294208.0,19115648.0,19146816.0,19177984.0,19209152.0,19240320.0,19271488.0,19302656.0,19333824.0,19154816.0,19186048.0,19217280.0,19248512.0,19279744.0,19310976.0,19342208.0,19373440.0,19193984.0,19225280.0,19256576.0,19287872.0,19319168.0,19350464.0,19381760.0,19413056.0,19233152.0,19264512.0,19295872.0,19327232.0,19358592.0,19389952.0,19421312.0,19452672.0,19272320.0,19303744.0,19335168.0,19366592.0,19398016.0,19429440.0,19460864.0,19492288.0,19311488.0,19342976.0,19374464.0,19405952.0,19437440.0,19468928.0,19500416.0,19531904.0,19350656.0,19382208.0,19413760.0,19445312.0,19476864.0,19508416.0,19539968.0,19571520.0,19389824.0,19421440.0,19453056.0,19484672.0,19516288.0,19547904.0,19579520.0,19611136.0,19428992.0,19460672.0,19492352.0,19524032.0,19555712.0,19587392.0,19619072.0,19650752.0,19468160.0,19499904.0,19531648.0,19563392.0,19595136.0,19626880.0,19658624.0,19690368.0,19507328.0,19539136.0,19570944.0,19602752.0,19634560.0,196
66368.0,19698176.0,19729984.0,19546496.0,19578368.0,19610240.0,19642112.0,19673984.0,19705856.0,19737728.0,19769600.0,19585664.0,19617600.0,19649536.0,19681472.0,19713408.0,19745344.0,19777280.0,19809216.0,19624832.0,19656832.0,19688832.0,19720832.0,19752832.0,19784832.0,19816832.0,19848832.0,19664000.0,19696064.0,19728128.0,19760192.0,19792256.0,19824320.0,19856384.0,19888448.0,19703168.0,19735296.0,19767424.0,19799552.0,19831680.0,19863808.0,19895936.0,19928064.0,19742336.0,19774528.0,19806720.0,19838912.0,19871104.0,19903296.0,19935488.0,19967680.0,19781504.0,19813760.0,19846016.0,19878272.0,19910528.0,19942784.0,19975040.0,20007296.0,19820672.0,19852992.0,19885312.0,19917632.0,19949952.0,19982272.0,20014592.0,20046912.0,19859840.0,19892224.0,19924608.0,19956992.0,19989376.0,20021760.0,20054144.0,20086528.0,19899008.0,19931456.0,19963904.0,19996352.0,20028800.0,20061248.0,20093696.0,20126144.0,19938176.0,19970688.0,20003200.0,20035712.0,20068224.0,20100736.0,20133248.0,20165760.0,19977344.0,20009920.0,20042496.0,20075072.0,20107648.0,20140224.0,20172800.0,20205376.0,20016512.0,20049152.0,20081792.0,20114432.0,20147072.0,20179712.0,20212352.0,20244992.0,20055680.0,20088384.0,20121088.0,20153792.0,20186496.0,20219200.0,20251904.0,20284608.0,20094848.0,20127616.0,20160384.0,20193152.0,20225920.0,20258688.0,20291456.0,20324224.0,20134016.0,20166848.0,20199680.0,20232512.0,20265344.0,20298176.0,20331008.0,20363840.0,19288448.0,19320320.0,19352192.0,19384064.0,19415936.0,19447808.0,19479680.0,19511552.0,18434688.0,18465600.0,18496512.0,18527424.0,18558336.0,18589248.0,18620160.0,18651072.0,18473856.0,18504832.0,18535808.0,18566784.0,18597760.0,18628736.0,18659712.0,18690688.0,18513024.0,18544064.0,18575104.0,18606144.0,18637184.0,18668224.0,18699264.0,18730304.0,18552192.0,18583296.0,18614400.0,18645504.0,18676608.0,18707712.0,18738816.0,18769920.0,18591360.0,18622528.0,18653696.0,18684864.0,18716032.0,18747200.0,18778368.0,18809536.0,18630528.0,18661760.0,18692992.0,1
8724224.0,18755456.0,18786688.0,18817920.0,18849152.0,18669696.0,18700992.0,18732288.0,18763584.0,18794880.0,18826176.0,18857472.0,18888768.0,18708864.0,18740224.0,18771584.0,18802944.0,18834304.0,18865664.0,18897024.0,18928384.0,18748032.0,18779456.0,18810880.0,18842304.0,18873728.0,18905152.0,18936576.0,18968000.0,18787200.0,18818688.0,18850176.0,18881664.0,18913152.0,18944640.0,18976128.0,19007616.0,18826368.0,18857920.0,18889472.0,18921024.0,18952576.0,18984128.0,19015680.0,19047232.0,18865536.0,18897152.0,18928768.0,18960384.0,18992000.0,19023616.0,19055232.0,19086848.0,18904704.0,18936384.0,18968064.0,18999744.0,19031424.0,19063104.0,19094784.0,19126464.0,18943872.0,18975616.0,19007360.0,19039104.0,19070848.0,19102592.0,19134336.0,19166080.0,18983040.0,19014848.0,19046656.0,19078464.0,19110272.0,19142080.0,19173888.0,19205696.0,19022208.0,19054080.0,19085952.0,19117824.0,19149696.0,19181568.0,19213440.0,19245312.0,19061376.0,19093312.0,19125248.0,19157184.0,19189120.0,19221056.0,19252992.0,19284928.0,19100544.0,19132544.0,19164544.0,19196544.0,19228544.0,19260544.0,19292544.0,19324544.0,19139712.0,19171776.0,19203840.0,19235904.0,19267968.0,19300032.0,19332096.0,19364160.0,19178880.0,19211008.0,19243136.0,19275264.0,19307392.0,19339520.0,19371648.0,19403776.0,19218048.0,19250240.0,19282432.0,19314624.0,19346816.0,19379008.0,19411200.0,19443392.0,19257216.0,19289472.0,19321728.0,19353984.0,19386240.0,19418496.0,19450752.0,19483008.0,19296384.0,19328704.0,19361024.0,19393344.0,19425664.0,19457984.0,19490304.0,19522624.0,19335552.0,19367936.0,19400320.0,19432704.0,19465088.0,19497472.0,19529856.0,19562240.0,19374720.0,19407168.0,19439616.0,19472064.0,19504512.0,19536960.0,19569408.0,19601856.0,19413888.0,19446400.0,19478912.0,19511424.0,19543936.0,19576448.0,19608960.0,19641472.0,19453056.0,19485632.0,19518208.0,19550784.0,19583360.0,19615936.0,19648512.0,19681088.0,19492224.0,19524864.0,19557504.0,19590144.0,19622784.0,19655424.0,19688064.0,19720704.0,19531392.0
,19564096.0,19596800.0,19629504.0,19662208.0,19694912.0,19727616.0,19760320.0,19570560.0,19603328.0,19636096.0,19668864.0,19701632.0,19734400.0,19767168.0,19799936.0,19609728.0,19642560.0,19675392.0,19708224.0,19741056.0,19773888.0,19806720.0,19839552.0,19648896.0,19681792.0,19714688.0,19747584.0,19780480.0,19813376.0,19846272.0,19879168.0,19688064.0,19721024.0,19753984.0,19786944.0,19819904.0,19852864.0,19885824.0,19918784.0,19727232.0,19760256.0,19793280.0,19826304.0,19859328.0,19892352.0,19925376.0,19958400.0,19766400.0,19799488.0,19832576.0,19865664.0,19898752.0,19931840.0,19964928.0,19998016.0,19805568.0,19838720.0,19871872.0,19905024.0,19938176.0,19971328.0,20004480.0,20037632.0,19844736.0,19877952.0,19911168.0,19944384.0,19977600.0,20010816.0,20044032.0,20077248.0,19883904.0,19917184.0,19950464.0,19983744.0,20017024.0,20050304.0,20083584.0,20116864.0,19562624.0,19594944.0,19627264.0,19659584.0,19691904.0,19724224.0,19756544.0,19788864.0,19233152.0,19264512.0,19295872.0,19327232.0,19358592.0,19389952.0,19421312.0,19452672.0,19272320.0,19303744.0,19335168.0,19366592.0,19398016.0,19429440.0,19460864.0,19492288.0,19311488.0,19342976.0,19374464.0,19405952.0,19437440.0,19468928.0,19500416.0,19531904.0,19350656.0,19382208.0,19413760.0,19445312.0,19476864.0,19508416.0,19539968.0,19571520.0,19389824.0,19421440.0,19453056.0,19484672.0,19516288.0,19547904.0,19579520.0,19611136.0,19428992.0,19460672.0,19492352.0,19524032.0,19555712.0,19587392.0,19619072.0,19650752.0,19468160.0,19499904.0,19531648.0,19563392.0,19595136.0,19626880.0,19658624.0,19690368.0,19507328.0,19539136.0,19570944.0,19602752.0,19634560.0,19666368.0,19698176.0,19729984.0,19546496.0,19578368.0,19610240.0,19642112.0,19673984.0,19705856.0,19737728.0,19769600.0,19585664.0,19617600.0,19649536.0,19681472.0,19713408.0,19745344.0,19777280.0,19809216.0,19624832.0,19656832.0,19688832.0,19720832.0,19752832.0,19784832.0,19816832.0,19848832.0,19664000.0,19696064.0,19728128.0,19760192.0,19792256.0,19824320.0,19856384
.0,19888448.0,19703168.0,19735296.0,19767424.0,19799552.0,19831680.0,19863808.0,19895936.0,19928064.0,19742336.0,19774528.0,19806720.0,19838912.0,19871104.0,19903296.0,19935488.0,19967680.0,19781504.0,19813760.0,19846016.0,19878272.0,19910528.0,19942784.0,19975040.0,20007296.0,19820672.0,19852992.0,19885312.0,19917632.0,19949952.0,19982272.0,20014592.0,20046912.0,19859840.0,19892224.0,19924608.0,19956992.0,19989376.0,20021760.0,20054144.0,20086528.0,19899008.0,19931456.0,19963904.0,19996352.0,20028800.0,20061248.0,20093696.0,20126144.0,19938176.0,19970688.0,20003200.0,20035712.0,20068224.0,20100736.0,20133248.0,20165760.0,19977344.0,20009920.0,20042496.0,20075072.0,20107648.0,20140224.0,20172800.0,20205376.0,20016512.0,20049152.0,20081792.0,20114432.0,20147072.0,20179712.0,20212352.0,20244992.0,20055680.0,20088384.0,20121088.0,20153792.0,20186496.0,20219200.0,20251904.0,20284608.0,20094848.0,20127616.0,20160384.0,20193152.0,20225920.0,20258688.0,20291456.0,20324224.0,20134016.0,20166848.0,20199680.0,20232512.0,20265344.0,20298176.0,20331008.0,20363840.0,20173184.0,20206080.0,20238976.0,20271872.0,20304768.0,20337664.0,20370560.0,20403456.0,20212352.0,20245312.0,20278272.0,20311232.0,20344192.0,20377152.0,20410112.0,20443072.0,20251520.0,20284544.0,20317568.0,20350592.0,20383616.0,20416640.0,20449664.0,20482688.0,20290688.0,20323776.0,20356864.0,20389952.0,20423040.0,20456128.0,20489216.0,20522304.0,20329856.0,20363008.0,20396160.0,20429312.0,20462464.0,20495616.0,20528768.0,20561920.0,20369024.0,20402240.0,20435456.0,20468672.0,20501888.0,20535104.0,20568320.0,20601536.0,20408192.0,20441472.0,20474752.0,20508032.0,20541312.0,20574592.0,20607872.0,20641152.0,20447360.0,20480704.0,20514048.0,20547392.0,20580736.0,20614080.0,20647424.0,20680768.0,20486528.0,20519936.0,20553344.0,20586752.0,20620160.0,20653568.0,20686976.0,20720384.0,20525696.0,20559168.0,20592640.0,20626112.0,20659584.0,20693056.0,20726528.0,20760000.0,20564864.0,20598400.0,20631936.0,20665472.0,206990
08.0,20732544.0,20766080.0,20799616.0,20604032.0,20637632.0,20671232.0,20704832.0,20738432.0,20772032.0,20805632.0,20839232.0,20643200.0,20676864.0,20710528.0,20744192.0,20777856.0,20811520.0,20845184.0,20878848.0,20682368.0,20716096.0,20749824.0,20783552.0,20817280.0,20851008.0,20884736.0,20918464.0,19836800.0,19869568.0,19902336.0,19935104.0,19967872.0,20000640.0,20033408.0,20066176.0,18983040.0,19014848.0,19046656.0,19078464.0,19110272.0,19142080.0,19173888.0,19205696.0,9565056.0,9580976.0,9596896.0,9612816.0,9628736.0,9644656.0,9660576.0,9676496.0,10136832.0,10152624.0,10168416.0,10184208.0,10200000.0,10215792.0,10231584.0,10247376.0,20419968.0,20451584.0,20483200.0,20514816.0,20546432.0,20578048.0,20609664.0,20641280.0,20460160.0,20491840.0,20523520.0,20555200.0,20586880.0,20618560.0,20650240.0,20681920.0,20500352.0,20532096.0,20563840.0,20595584.0,20627328.0,20659072.0,20690816.0,20722560.0,20540544.0,20572352.0,20604160.0,20635968.0,20667776.0,20699584.0,20731392.0,20763200.0,20580736.0,20612608.0,20644480.0,20676352.0,20708224.0,20740096.0,20771968.0,20803840.0,20620928.0,20652864.0,20684800.0,20716736.0,20748672.0,20780608.0,20812544.0,20844480.0,20661120.0,20693120.0,20725120.0,20757120.0,20789120.0,20821120.0,20853120.0,20885120.0,20701312.0,20733376.0,20765440.0,20797504.0,20829568.0,20861632.0,20893696.0,20925760.0,20741504.0,20773632.0,20805760.0,20837888.0,20870016.0,20902144.0,20934272.0,20966400.0,20781696.0,20813888.0,20846080.0,20878272.0,20910464.0,20942656.0,20974848.0,21007040.0,20821888.0,20854144.0,20886400.0,20918656.0,20950912.0,20983168.0,21015424.0,21047680.0,20862080.0,20894400.0,20926720.0,20959040.0,20991360.0,21023680.0,21056000.0,21088320.0,20902272.0,20934656.0,20967040.0,20999424.0,21031808.0,21064192.0,21096576.0,21128960.0,20942464.0,20974912.0,21007360.0,21039808.0,21072256.0,21104704.0,21137152.0,21169600.0,20982656.0,21015168.0,21047680.0,21080192.0,21112704.0,21145216.0,21177728.0,21210240.0,21022848.0,21055424.0,21088000.0,2
1120576.0,21153152.0,21185728.0,21218304.0,21250880.0,21063040.0,21095680.0,21128320.0,21160960.0,21193600.0,21226240.0,21258880.0,21291520.0,21103232.0,21135936.0,21168640.0,21201344.0,21234048.0,21266752.0,21299456.0,21332160.0,21143424.0,21176192.0,21208960.0,21241728.0,21274496.0,21307264.0,21340032.0,21372800.0,21183616.0,21216448.0,21249280.0,21282112.0,21314944.0,21347776.0,21380608.0,21413440.0,21223808.0,21256704.0,21289600.0,21322496.0,21355392.0,21388288.0,21421184.0,21454080.0,21264000.0,21296960.0,21329920.0,21362880.0,21395840.0,21428800.0,21461760.0,21494720.0,21304192.0,21337216.0,21370240.0,21403264.0,21436288.0,21469312.0,21502336.0,21535360.0,21344384.0,21377472.0,21410560.0,21443648.0,21476736.0,21509824.0,21542912.0,21576000.0,21384576.0,21417728.0,21450880.0,21484032.0,21517184.0,21550336.0,21583488.0,21616640.0,21424768.0,21457984.0,21491200.0,21524416.0,21557632.0,21590848.0,21624064.0,21657280.0,21464960.0,21498240.0,21531520.0,21564800.0,21598080.0,21631360.0,21664640.0,21697920.0,21505152.0,21538496.0,21571840.0,21605184.0,21638528.0,21671872.0,21705216.0,21738560.0,21545344.0,21578752.0,21612160.0,21645568.0,21678976.0,21712384.0,21745792.0,21779200.0,21585536.0,21619008.0,21652480.0,21685952.0,21719424.0,21752896.0,21786368.0,21819840.0,21625728.0,21659264.0,21692800.0,21726336.0,21759872.0,21793408.0,21826944.0,21860480.0,21665920.0,21699520.0,21733120.0,21766720.0,21800320.0,21833920.0,21867520.0,21901120.0,21706112.0,21739776.0,21773440.0,21807104.0,21840768.0,21874432.0,21908096.0,21941760.0,21746304.0,21780032.0,21813760.0,21847488.0,21881216.0,21914944.0,21948672.0,21982400.0,21786496.0,21820288.0,21854080.0,21887872.0,21921664.0,21955456.0,21989248.0,22023040.0,21826688.0,21860544.0,21894400.0,21928256.0,21962112.0,21995968.0,22029824.0,22063680.0,20965760.0,20998656.0,21031552.0,21064448.0,21097344.0,21130240.0,21163136.0,21196032.0,20096640.0,20128576.0,20160512.0,20192448.0,20224384.0,20256320.0,20288256.0,20320192.0,19235712.0
,19266688.0,19297664.0,19328640.0,19359616.0,19390592.0,19421568.0,19452544.0,18366592.0,18396608.0,18426624.0,18456640.0,18486656.0,18516672.0,18546688.0,18576704.0,18406784.0,18436864.0,18466944.0,18497024.0,18527104.0,18557184.0,18587264.0,18617344.0,18446976.0,18477120.0,18507264.0,18537408.0,18567552.0,18597696.0,18627840.0,18657984.0,18487168.0,18517376.0,18547584.0,18577792.0,18608000.0,18638208.0,18668416.0,18698624.0,18527360.0,18557632.0,18587904.0,18618176.0,18648448.0,18678720.0,18708992.0,18739264.0,18567552.0,18597888.0,18628224.0,18658560.0,18688896.0,18719232.0,18749568.0,18779904.0,18607744.0,18638144.0,18668544.0,18698944.0,18729344.0,18759744.0,18790144.0,18820544.0,18647936.0,18678400.0,18708864.0,18739328.0,18769792.0,18800256.0,18830720.0,18861184.0,18688128.0,18718656.0,18749184.0,18779712.0,18810240.0,18840768.0,18871296.0,18901824.0,18728320.0,18758912.0,18789504.0,18820096.0,18850688.0,18881280.0,18911872.0,18942464.0,18768512.0,18799168.0,18829824.0,18860480.0,18891136.0,18921792.0,18952448.0,18983104.0,18808704.0,18839424.0,18870144.0,18900864.0,18931584.0,18962304.0,18993024.0,19023744.0,18848896.0,18879680.0,18910464.0,18941248.0,18972032.0,19002816.0,19033600.0,19064384.0,18889088.0,18919936.0,18950784.0,18981632.0,19012480.0,19043328.0,19074176.0,19105024.0,18929280.0,18960192.0,18991104.0,19022016.0,19052928.0,19083840.0,19114752.0,19145664.0,18969472.0,19000448.0,19031424.0,19062400.0,19093376.0,19124352.0,19155328.0,19186304.0,19009664.0,19040704.0,19071744.0,19102784.0,19133824.0,19164864.0,19195904.0,19226944.0,19049856.0,19080960.0,19112064.0,19143168.0,19174272.0,19205376.0,19236480.0,19267584.0,19090048.0,19121216.0,19152384.0,19183552.0,19214720.0,19245888.0,19277056.0,19308224.0,19130240.0,19161472.0,19192704.0,19223936.0,19255168.0,19286400.0,19317632.0,19348864.0,19170432.0,19201728.0,19233024.0,19264320.0,19295616.0,19326912.0,19358208.0,19389504.0,19210624.0,19241984.0,19273344.0,19304704.0,19336064.0,19367424.0,19398784
.0,19430144.0,19250816.0,19282240.0,19313664.0,19345088.0,19376512.0,19407936.0,19439360.0,19470784.0,19291008.0,19322496.0,19353984.0,19385472.0,19416960.0,19448448.0,19479936.0,19511424.0,19331200.0,19362752.0,19394304.0,19425856.0,19457408.0,19488960.0,19520512.0,19552064.0,19371392.0,19403008.0,19434624.0,19466240.0,19497856.0,19529472.0,19561088.0,19592704.0,19411584.0,19443264.0,19474944.0,19506624.0,19538304.0,19569984.0,19601664.0,19633344.0,19451776.0,19483520.0,19515264.0,19547008.0,19578752.0,19610496.0,19642240.0,19673984.0,19491968.0,19523776.0,19555584.0,19587392.0,19619200.0,19651008.0,19682816.0,19714624.0,19532160.0,19564032.0,19595904.0,19627776.0,19659648.0,19691520.0,19723392.0,19755264.0,19572352.0,19604288.0,19636224.0,19668160.0,19700096.0,19732032.0,19763968.0,19795904.0,19612544.0,19644544.0,19676544.0,19708544.0,19740544.0,19772544.0,19804544.0,19836544.0,19652736.0,19684800.0,19716864.0,19748928.0,19780992.0,19813056.0,19845120.0,19877184.0,19692928.0,19725056.0,19757184.0,19789312.0,19821440.0,19853568.0,19885696.0,19917824.0,19733120.0,19765312.0,19797504.0,19829696.0,19861888.0,19894080.0,19926272.0,19958464.0,19773312.0,19805568.0,19837824.0,19870080.0,19902336.0,19934592.0,19966848.0,19999104.0,19813504.0,19845824.0,19878144.0,19910464.0,19942784.0,19975104.0,20007424.0,20039744.0,19853696.0,19886080.0,19918464.0,19950848.0,19983232.0,20015616.0,20048000.0,20080384.0,19517056.0,19548480.0,19579904.0,19611328.0,19642752.0,19674176.0,19705600.0,19737024.0,19172224.0,19202688.0,19233152.0,19263616.0,19294080.0,19324544.0,19355008.0,19385472.0,19212416.0,19242944.0,19273472.0,19304000.0,19334528.0,19365056.0,19395584.0,19426112.0,19252608.0,19283200.0,19313792.0,19344384.0,19374976.0,19405568.0,19436160.0,19466752.0,19292800.0,19323456.0,19354112.0,19384768.0,19415424.0,19446080.0,19476736.0,19507392.0,19332992.0,19363712.0,19394432.0,19425152.0,19455872.0,19486592.0,19517312.0,19548032.0,19373184.0,19403968.0,19434752.0,19465536.0,194963
20.0,19527104.0,19557888.0,19588672.0,19413376.0,19444224.0,19475072.0,19505920.0,19536768.0,19567616.0,19598464.0,19629312.0,19453568.0,19484480.0,19515392.0,19546304.0,19577216.0,19608128.0,19639040.0,19669952.0,19493760.0,19524736.0,19555712.0,19586688.0,19617664.0,19648640.0,19679616.0,19710592.0,19533952.0,19564992.0,19596032.0,19627072.0,19658112.0,19689152.0,19720192.0,19751232.0,19574144.0,19605248.0,19636352.0,19667456.0,19698560.0,19729664.0,19760768.0,19791872.0,19614336.0,19645504.0,19676672.0,19707840.0,19739008.0,19770176.0,19801344.0,19832512.0,19654528.0,19685760.0,19716992.0,19748224.0,19779456.0,19810688.0,19841920.0,19873152.0,19694720.0,19726016.0,19757312.0,19788608.0,19819904.0,19851200.0,19882496.0,19913792.0,19734912.0,19766272.0,19797632.0,19828992.0,19860352.0,19891712.0,19923072.0,19954432.0,19775104.0,19806528.0,19837952.0,19869376.0,19900800.0,19932224.0,19963648.0,19995072.0,19815296.0,19846784.0,19878272.0,19909760.0,19941248.0,19972736.0,20004224.0,20035712.0,19855488.0,19887040.0,19918592.0,19950144.0,19981696.0,20013248.0,20044800.0,20076352.0,19895680.0,19927296.0,19958912.0,19990528.0,20022144.0,20053760.0,20085376.0,20116992.0,19935872.0,19967552.0,19999232.0,20030912.0,20062592.0,20094272.0,20125952.0,20157632.0,19976064.0,20007808.0,20039552.0,20071296.0,20103040.0,20134784.0,20166528.0,20198272.0,20016256.0,20048064.0,20079872.0,20111680.0,20143488.0,20175296.0,20207104.0,20238912.0,20056448.0,20088320.0,20120192.0,20152064.0,20183936.0,20215808.0,20247680.0,20279552.0,20096640.0,20128576.0,20160512.0,20192448.0,20224384.0,20256320.0,20288256.0,20320192.0,20136832.0,20168832.0,20200832.0,20232832.0,20264832.0,20296832.0,20328832.0,20360832.0,20177024.0,20209088.0,20241152.0,20273216.0,20305280.0,20337344.0,20369408.0,20401472.0,20217216.0,20249344.0,20281472.0,20313600.0,20345728.0,20377856.0,20409984.0,20442112.0,20257408.0,20289600.0,20321792.0,20353984.0,20386176.0,20418368.0,20450560.0,20482752.0,20297600.0,20329856.0,2036
2112.0,20394368.0,20426624.0,20458880.0,20491136.0,20523392.0,20337792.0,20370112.0,20402432.0,20434752.0,20467072.0,20499392.0,20531712.0,20564032.0,20377984.0,20410368.0,20442752.0,20475136.0,20507520.0,20539904.0,20572288.0,20604672.0,20418176.0,20450624.0,20483072.0,20515520.0,20547968.0,20580416.0,20612864.0,20645312.0,20458368.0,20490880.0,20523392.0,20555904.0,20588416.0,20620928.0,20653440.0,20685952.0,20498560.0,20531136.0,20563712.0,20596288.0,20628864.0,20661440.0,20694016.0,20726592.0,20538752.0,20571392.0,20604032.0,20636672.0,20669312.0,20701952.0,20734592.0,20767232.0,20578944.0,20611648.0,20644352.0,20677056.0,20709760.0,20742464.0,20775168.0,20807872.0,20619136.0,20651904.0,20684672.0,20717440.0,20750208.0,20782976.0,20815744.0,20848512.0,20659328.0,20692160.0,20724992.0,20757824.0,20790656.0,20823488.0,20856320.0,20889152.0,19798400.0,19830272.0,19862144.0,19894016.0,19925888.0,19957760.0,19989632.0,20021504.0,18929280.0,18960192.0,18991104.0,19022016.0,19052928.0,19083840.0,19114752.0,19145664.0,18969472.0,19000448.0,19031424.0,19062400.0,19093376.0,19124352.0,19155328.0,19186304.0,19009664.0,19040704.0,19071744.0,19102784.0,19133824.0,19164864.0,19195904.0,19226944.0,19049856.0,19080960.0,19112064.0,19143168.0,19174272.0,19205376.0,19236480.0,19267584.0,19090048.0,19121216.0,19152384.0,19183552.0,19214720.0,19245888.0,19277056.0,19308224.0,19130240.0,19161472.0,19192704.0,19223936.0,19255168.0,19286400.0,19317632.0,19348864.0,19170432.0,19201728.0,19233024.0,19264320.0,19295616.0,19326912.0,19358208.0,19389504.0,19210624.0,19241984.0,19273344.0,19304704.0,19336064.0,19367424.0,19398784.0,19430144.0,19250816.0,19282240.0,19313664.0,19345088.0,19376512.0,19407936.0,19439360.0,19470784.0,19291008.0,19322496.0,19353984.0,19385472.0,19416960.0,19448448.0,19479936.0,19511424.0,19331200.0,19362752.0,19394304.0,19425856.0,19457408.0,19488960.0,19520512.0,19552064.0,19371392.0,19403008.0,19434624.0,19466240.0,19497856.0,19529472.0,19561088.0,19592704.0,19
411584.0,19443264.0,19474944.0,19506624.0,19538304.0,19569984.0,19601664.0,19633344.0,19451776.0,19483520.0,19515264.0,19547008.0,19578752.0,19610496.0,19642240.0,19673984.0,19491968.0,19523776.0,19555584.0,19587392.0,19619200.0,19651008.0,19682816.0,19714624.0,19532160.0,19564032.0,19595904.0,19627776.0,19659648.0,19691520.0,19723392.0,19755264.0,19572352.0,19604288.0,19636224.0,19668160.0,19700096.0,19732032.0,19763968.0,19795904.0,19612544.0,19644544.0,19676544.0,19708544.0,19740544.0,19772544.0,19804544.0,19836544.0,19652736.0,19684800.0,19716864.0,19748928.0,19780992.0,19813056.0,19845120.0,19877184.0,19692928.0,19725056.0,19757184.0,19789312.0,19821440.0,19853568.0,19885696.0,19917824.0,19733120.0,19765312.0,19797504.0,19829696.0,19861888.0,19894080.0,19926272.0,19958464.0,19773312.0,19805568.0,19837824.0,19870080.0,19902336.0,19934592.0,19966848.0,19999104.0,19813504.0,19845824.0,19878144.0,19910464.0,19942784.0,19975104.0,20007424.0,20039744.0,19853696.0,19886080.0,19918464.0,19950848.0,19983232.0,20015616.0,20048000.0,20080384.0,19893888.0,19926336.0,19958784.0,19991232.0,20023680.0,20056128.0,20088576.0,20121024.0,19934080.0,19966592.0,19999104.0,20031616.0,20064128.0,20096640.0,20129152.0,20161664.0,19974272.0,20006848.0,20039424.0,20072000.0,20104576.0,20137152.0,20169728.0,20202304.0,20014464.0,20047104.0,20079744.0,20112384.0,20145024.0,20177664.0,20210304.0,20242944.0,20054656.0,20087360.0,20120064.0,20152768.0,20185472.0,20218176.0,20250880.0,20283584.0,20094848.0,20127616.0,20160384.0,20193152.0,20225920.0,20258688.0,20291456.0,20324224.0,20135040.0,20167872.0,20200704.0,20233536.0,20266368.0,20299200.0,20332032.0,20364864.0,20175232.0,20208128.0,20241024.0,20273920.0,20306816.0,20339712.0,20372608.0,20405504.0,20215424.0,20248384.0,20281344.0,20314304.0,20347264.0,20380224.0,20413184.0,20446144.0,20255616.0,20288640.0,20321664.0,20354688.0,20387712.0,20420736.0,20453760.0,20486784.0,20295808.0,20328896.0,20361984.0,20395072.0,20428160.0,20461248.0,
20494336.0,20527424.0,20336000.0,20369152.0,20402304.0,20435456.0,20468608.0,20501760.0,20534912.0,20568064.0,20376192.0,20409408.0,20442624.0,20475840.0,20509056.0,20542272.0,20575488.0,20608704.0,20416384.0,20449664.0,20482944.0,20516224.0,20549504.0,20582784.0,20616064.0,20649344.0,20079744.0,20112064.0,20144384.0,20176704.0,20209024.0,20241344.0,20273664.0,20305984.0,19734912.0,19766272.0,19797632.0,19828992.0,19860352.0,19891712.0,19923072.0,19954432.0,19775104.0,19806528.0,19837952.0,19869376.0,19900800.0,19932224.0,19963648.0,19995072.0,19815296.0,19846784.0,19878272.0,19909760.0,19941248.0,19972736.0,20004224.0,20035712.0,19855488.0,19887040.0,19918592.0,19950144.0,19981696.0,20013248.0,20044800.0,20076352.0,19895680.0,19927296.0,19958912.0,19990528.0,20022144.0,20053760.0,20085376.0,20116992.0,19935872.0,19967552.0,19999232.0,20030912.0,20062592.0,20094272.0,20125952.0,20157632.0,19976064.0,20007808.0,20039552.0,20071296.0,20103040.0,20134784.0,20166528.0,20198272.0,20016256.0,20048064.0,20079872.0,20111680.0,20143488.0,20175296.0,20207104.0,20238912.0,20056448.0,20088320.0,20120192.0,20152064.0,20183936.0,20215808.0,20247680.0,20279552.0,20096640.0,20128576.0,20160512.0,20192448.0,20224384.0,20256320.0,20288256.0,20320192.0,20136832.0,20168832.0,20200832.0,20232832.0,20264832.0,20296832.0,20328832.0,20360832.0,20177024.0,20209088.0,20241152.0,20273216.0,20305280.0,20337344.0,20369408.0,20401472.0,20217216.0,20249344.0,20281472.0,20313600.0,20345728.0,20377856.0,20409984.0,20442112.0,20257408.0,20289600.0,20321792.0,20353984.0,20386176.0,20418368.0,20450560.0,20482752.0,20297600.0,20329856.0,20362112.0,20394368.0,20426624.0,20458880.0,20491136.0,20523392.0,20337792.0,20370112.0,20402432.0,20434752.0,20467072.0,20499392.0,20531712.0,20564032.0,20377984.0,20410368.0,20442752.0,20475136.0,20507520.0,20539904.0,20572288.0,20604672.0,20418176.0,20450624.0,20483072.0,20515520.0,20547968.0,20580416.0,20612864.0,20645312.0,20458368.0,20490880.0,20523392.0,20555904.
0,20588416.0,20620928.0,20653440.0,20685952.0,20498560.0,20531136.0,20563712.0,20596288.0,20628864.0,20661440.0,20694016.0,20726592.0,20538752.0,20571392.0,20604032.0,20636672.0,20669312.0,20701952.0,20734592.0,20767232.0,20578944.0,20611648.0,20644352.0,20677056.0,20709760.0,20742464.0,20775168.0,20807872.0,20619136.0,20651904.0,20684672.0,20717440.0,20750208.0,20782976.0,20815744.0,20848512.0,20659328.0,20692160.0,20724992.0,20757824.0,20790656.0,20823488.0,20856320.0,20889152.0,20699520.0,20732416.0,20765312.0,20798208.0,20831104.0,20864000.0,20896896.0,20929792.0,20739712.0,20772672.0,20805632.0,20838592.0,20871552.0,20904512.0,20937472.0,20970432.0,20779904.0,20812928.0,20845952.0,20878976.0,20912000.0,20945024.0,20978048.0,21011072.0,20820096.0,20853184.0,20886272.0,20919360.0,20952448.0,20985536.0,21018624.0,21051712.0,20860288.0,20893440.0,20926592.0,20959744.0,20992896.0,21026048.0,21059200.0,21092352.0,20900480.0,20933696.0,20966912.0,21000128.0,21033344.0,21066560.0,21099776.0,21132992.0,20940672.0,20973952.0,21007232.0,21040512.0,21073792.0,21107072.0,21140352.0,21173632.0,20980864.0,21014208.0,21047552.0,21080896.0,21114240.0,21147584.0,21180928.0,21214272.0,21021056.0,21054464.0,21087872.0,21121280.0,21154688.0,21188096.0,21221504.0,21254912.0,21061248.0,21094720.0,21128192.0,21161664.0,21195136.0,21228608.0,21262080.0,21295552.0,21101440.0,21134976.0,21168512.0,21202048.0,21235584.0,21269120.0,21302656.0,21336192.0,21141632.0,21175232.0,21208832.0,21242432.0,21276032.0,21309632.0,21343232.0,21376832.0,21181824.0,21215488.0,21249152.0,21282816.0,21316480.0,21350144.0,21383808.0,21417472.0,21222016.0,21255744.0,21289472.0,21323200.0,21356928.0,21390656.0,21424384.0,21458112.0,20361088.0,20393856.0,20426624.0,20459392.0,20492160.0,20524928.0,20557696.0,20590464.0,19491968.0,19523776.0,19555584.0,19587392.0,19619200.0,19651008.0,19682816.0,19714624.0,9819776.0,9835696.0,9851616.0,9867536.0,9883456.0,9899376.0,9915296.0,9931216.0,10389504.0,10405296.0,1042
1088.0,10436880.0,10452672.0,10468464.0,10484256.0,10500048.0,20925824.0,20957440.0,20989056.0,21020672.0,21052288.0,21083904.0,21115520.0,21147136.0,20967040.0,20998720.0,21030400.0,21062080.0,21093760.0,21125440.0,21157120.0,21188800.0,21008256.0,21040000.0,21071744.0,21103488.0,21135232.0,21166976.0,21198720.0,21230464.0,21049472.0,21081280.0,21113088.0,21144896.0,21176704.0,21208512.0,21240320.0,21272128.0,21090688.0,21122560.0,21154432.0,21186304.0,21218176.0,21250048.0,21281920.0,21313792.0,21131904.0,21163840.0,21195776.0,21227712.0,21259648.0,21291584.0,21323520.0,21355456.0,21173120.0,21205120.0,21237120.0,21269120.0,21301120.0,21333120.0,21365120.0,21397120.0,21214336.0,21246400.0,21278464.0,21310528.0,21342592.0,21374656.0,21406720.0,21438784.0,21255552.0,21287680.0,21319808.0,21351936.0,21384064.0,21416192.0,21448320.0,21480448.0,21296768.0,21328960.0,21361152.0,21393344.0,21425536.0,21457728.0,21489920.0,21522112.0,21337984.0,21370240.0,21402496.0,21434752.0,21467008.0,21499264.0,21531520.0,21563776.0,21379200.0,21411520.0,21443840.0,21476160.0,21508480.0,21540800.0,21573120.0,21605440.0,21420416.0,21452800.0,21485184.0,21517568.0,21549952.0,21582336.0,21614720.0,21647104.0,21461632.0,21494080.0,21526528.0,21558976.0,21591424.0,21623872.0,21656320.0,21688768.0,21502848.0,21535360.0,21567872.0,21600384.0,21632896.0,21665408.0,21697920.0,21730432.0,21544064.0,21576640.0,21609216.0,21641792.0,21674368.0,21706944.0,21739520.0,21772096.0,21585280.0,21617920.0,21650560.0,21683200.0,21715840.0,21748480.0,21781120.0,21813760.0,21626496.0,21659200.0,21691904.0,21724608.0,21757312.0,21790016.0,21822720.0,21855424.0,21667712.0,21700480.0,21733248.0,21766016.0,21798784.0,21831552.0,21864320.0,21897088.0,21708928.0,21741760.0,21774592.0,21807424.0,21840256.0,21873088.0,21905920.0,21938752.0,21750144.0,21783040.0,21815936.0,21848832.0,21881728.0,21914624.0,21947520.0,21980416.0,21791360.0,21824320.0,21857280.0,21890240.0,21923200.0,21956160.0,21989120.0,22022080.0,21
832576.0,21865600.0,21898624.0,21931648.0,21964672.0,21997696.0,22030720.0,22063744.0,21873792.0,21906880.0,21939968.0,21973056.0,22006144.0,22039232.0,22072320.0,22105408.0,21915008.0,21948160.0,21981312.0,22014464.0,22047616.0,22080768.0,22113920.0,22147072.0,21956224.0,21989440.0,22022656.0,22055872.0,22089088.0,22122304.0,22155520.0,22188736.0,21997440.0,22030720.0,22064000.0,22097280.0,22130560.0,22163840.0,22197120.0,22230400.0,22038656.0,22072000.0,22105344.0,22138688.0,22172032.0,22205376.0,22238720.0,22272064.0,22079872.0,22113280.0,22146688.0,22180096.0,22213504.0,22246912.0,22280320.0,22313728.0,22121088.0,22154560.0,22188032.0,22221504.0,22254976.0,22288448.0,22321920.0,22355392.0,22162304.0,22195840.0,22229376.0,22262912.0,22296448.0,22329984.0,22363520.0,22397056.0,22203520.0,22237120.0,22270720.0,22304320.0,22337920.0,22371520.0,22405120.0,22438720.0,22244736.0,22278400.0,22312064.0,22345728.0,22379392.0,22413056.0,22446720.0,22480384.0,22285952.0,22319680.0,22353408.0,22387136.0,22420864.0,22454592.0,22488320.0,22522048.0,22327168.0,22360960.0,22394752.0,22428544.0,22462336.0,22496128.0,22529920.0,22563712.0,22368384.0,22402240.0,22436096.0,22469952.0,22503808.0,22537664.0,22571520.0,22605376.0,21492096.0,21524992.0,21557888.0,21590784.0,21623680.0,21656576.0,21689472.0,21722368.0,20607616.0,20639552.0,20671488.0,20703424.0,20735360.0,20767296.0,20799232.0,20831168.0,19731328.0,19762304.0,19793280.0,19824256.0,19855232.0,19886208.0,19917184.0,19948160.0,18846848.0,18876864.0,18906880.0,18936896.0,18966912.0,18996928.0,19026944.0,19056960.0,18888064.0,18918144.0,18948224.0,18978304.0,19008384.0,19038464.0,19068544.0,19098624.0,18929280.0,18959424.0,18989568.0,19019712.0,19049856.0,19080000.0,19110144.0,19140288.0,18970496.0,19000704.0,19030912.0,19061120.0,19091328.0,19121536.0,19151744.0,19181952.0,19011712.0,19041984.0,19072256.0,19102528.0,19132800.0,19163072.0,19193344.0,19223616.0,19052928.0,19083264.0,19113600.0,19143936.0,19174272.0,19204608.0,
19234944.0,19265280.0,19094144.0,19124544.0,19154944.0,19185344.0,19215744.0,19246144.0,19276544.0,19306944.0,19135360.0,19165824.0,19196288.0,19226752.0,19257216.0,19287680.0,19318144.0,19348608.0,19176576.0,19207104.0,19237632.0,19268160.0,19298688.0,19329216.0,19359744.0,19390272.0,19217792.0,19248384.0,19278976.0,19309568.0,19340160.0,19370752.0,19401344.0,19431936.0,19259008.0,19289664.0,19320320.0,19350976.0,19381632.0,19412288.0,19442944.0,19473600.0,19300224.0,19330944.0,19361664.0,19392384.0,19423104.0,19453824.0,19484544.0,19515264.0,19341440.0,19372224.0,19403008.0,19433792.0,19464576.0,19495360.0,19526144.0,19556928.0,19382656.0,19413504.0,19444352.0,19475200.0,19506048.0,19536896.0,19567744.0,19598592.0,19423872.0,19454784.0,19485696.0,19516608.0,19547520.0,19578432.0,19609344.0,19640256.0,19465088.0,19496064.0,19527040.0,19558016.0,19588992.0,19619968.0,19650944.0,19681920.0,19506304.0,19537344.0,19568384.0,19599424.0,19630464.0,19661504.0,19692544.0,19723584.0,19547520.0,19578624.0,19609728.0,19640832.0,19671936.0,19703040.0,19734144.0,19765248.0,19588736.0,19619904.0,19651072.0,19682240.0,19713408.0,19744576.0,19775744.0,19806912.0,19629952.0,19661184.0,19692416.0,19723648.0,19754880.0,19786112.0,19817344.0,19848576.0,19671168.0,19702464.0,19733760.0,19765056.0,19796352.0,19827648.0,19858944.0,19890240.0,19712384.0,19743744.0,19775104.0,19806464.0,19837824.0,19869184.0,19900544.0,19931904.0,19753600.0,19785024.0,19816448.0,19847872.0,19879296.0,19910720.0,19942144.0,19973568.0,19794816.0,19826304.0,19857792.0,19889280.0,19920768.0,19952256.0,19983744.0,20015232.0,19836032.0,19867584.0,19899136.0,19930688.0,19962240.0,19993792.0,20025344.0,20056896.0,19877248.0,19908864.0,19940480.0,19972096.0,20003712.0,20035328.0,20066944.0,20098560.0,19918464.0,19950144.0,19981824.0,20013504.0,20045184.0,20076864.0,20108544.0,20140224.0,19959680.0,19991424.0,20023168.0,20054912.0,20086656.0,20118400.0,20150144.0,20181888.0,20000896.0,20032704.0,20064512.0,20096320.
0,20128128.0,20159936.0,20191744.0,20223552.0,20042112.0,20073984.0,20105856.0,20137728.0,20169600.0,20201472.0,20233344.0,20265216.0,20083328.0,20115264.0,20147200.0,20179136.0,20211072.0,20243008.0,20274944.0,20306880.0,20124544.0,20156544.0,20188544.0,20220544.0,20252544.0,20284544.0,20316544.0,20348544.0,20165760.0,20197824.0,20229888.0,20261952.0,20294016.0,20326080.0,20358144.0,20390208.0,20206976.0,20239104.0,20271232.0,20303360.0,20335488.0,20367616.0,20399744.0,20431872.0,20248192.0,20280384.0,20312576.0,20344768.0,20376960.0,20409152.0,20441344.0,20473536.0,20289408.0,20321664.0,20353920.0,20386176.0,20418432.0,20450688.0,20482944.0,20515200.0,20330624.0,20362944.0,20395264.0,20427584.0,20459904.0,20492224.0,20524544.0,20556864.0,20371840.0,20404224.0,20436608.0,20468992.0,20501376.0,20533760.0,20566144.0,20598528.0,20019840.0,20051264.0,20082688.0,20114112.0,20145536.0,20176960.0,20208384.0,20239808.0,19659648.0,19690112.0,19720576.0,19751040.0,19781504.0,19811968.0,19842432.0,19872896.0,19700864.0,19731392.0,19761920.0,19792448.0,19822976.0,19853504.0,19884032.0,19914560.0,19742080.0,19772672.0,19803264.0,19833856.0,19864448.0,19895040.0,19925632.0,19956224.0,19783296.0,19813952.0,19844608.0,19875264.0,19905920.0,19936576.0,19967232.0,19997888.0,19824512.0,19855232.0,19885952.0,19916672.0,19947392.0,19978112.0,20008832.0,20039552.0,19865728.0,19896512.0,19927296.0,19958080.0,19988864.0,20019648.0,20050432.0,20081216.0,19906944.0,19937792.0,19968640.0,19999488.0,20030336.0,20061184.0,20092032.0,20122880.0,19948160.0,19979072.0,20009984.0,20040896.0,20071808.0,20102720.0,20133632.0,20164544.0,19989376.0,20020352.0,20051328.0,20082304.0,20113280.0,20144256.0,20175232.0,20206208.0,20030592.0,20061632.0,20092672.0,20123712.0,20154752.0,20185792.0,20216832.0,20247872.0,20071808.0,20102912.0,20134016.0,20165120.0,20196224.0,20227328.0,20258432.0,20289536.0,20113024.0,20144192.0,20175360.0,20206528.0,20237696.0,20268864.0,20300032.0,20331200.0,20154240.0,2018547
2.0,20216704.0,20247936.0,20279168.0,20310400.0,20341632.0,20372864.0,20195456.0,20226752.0,20258048.0,20289344.0,20320640.0,20351936.0,20383232.0,20414528.0,20236672.0,20268032.0,20299392.0,20330752.0,20362112.0,20393472.0,20424832.0,20456192.0,20277888.0,20309312.0,20340736.0,20372160.0,20403584.0,20435008.0,20466432.0,20497856.0,20319104.0,20350592.0,20382080.0,20413568.0,20445056.0,20476544.0,20508032.0,20539520.0,20360320.0,20391872.0,20423424.0,20454976.0,20486528.0,20518080.0,20549632.0,20581184.0,20401536.0,20433152.0,20464768.0,20496384.0,20528000.0,20559616.0,20591232.0,20622848.0,20442752.0,20474432.0,20506112.0,20537792.0,20569472.0,20601152.0,20632832.0,20664512.0,20483968.0,20515712.0,20547456.0,20579200.0,20610944.0,20642688.0,20674432.0,20706176.0,20525184.0,20556992.0,20588800.0,20620608.0,20652416.0,20684224.0,20716032.0,20747840.0,20566400.0,20598272.0,20630144.0,20662016.0,20693888.0,20725760.0,20757632.0,20789504.0,20607616.0,20639552.0,20671488.0,20703424.0,20735360.0,20767296.0,20799232.0,20831168.0,20648832.0,20680832.0,20712832.0,20744832.0,20776832.0,20808832.0,20840832.0,20872832.0,20690048.0,20722112.0,20754176.0,20786240.0,20818304.0,20850368.0,20882432.0,20914496.0,20731264.0,20763392.0,20795520.0,20827648.0,20859776.0,20891904.0,20924032.0,20956160.0,20772480.0,20804672.0,20836864.0,20869056.0,20901248.0,20933440.0,20965632.0,20997824.0,20813696.0,20845952.0,20878208.0,20910464.0,20942720.0,20974976.0,21007232.0,21039488.0,20854912.0,20887232.0,20919552.0,20951872.0,20984192.0,21016512.0,21048832.0,21081152.0,20896128.0,20928512.0,20960896.0,20993280.0,21025664.0,21058048.0,21090432.0,21122816.0,20937344.0,20969792.0,21002240.0,21034688.0,21067136.0,21099584.0,21132032.0,21164480.0,20978560.0,21011072.0,21043584.0,21076096.0,21108608.0,21141120.0,21173632.0,21206144.0,21019776.0,21052352.0,21084928.0,21117504.0,21150080.0,21182656.0,21215232.0,21247808.0,21060992.0,21093632.0,21126272.0,21158912.0,21191552.0,21224192.0,21256832.0,21289
472.0,21102208.0,21134912.0,21167616.0,21200320.0,21233024.0,21265728.0,21298432.0,21331136.0,21143424.0,21176192.0,21208960.0,21241728.0,21274496.0,21307264.0,21340032.0,21372800.0,21184640.0,21217472.0,21250304.0,21283136.0,21315968.0,21348800.0,21381632.0,21414464.0,20308352.0,20340224.0,20372096.0,20403968.0,20435840.0,20467712.0,20499584.0,20531456.0,19423872.0,19454784.0,19485696.0,19516608.0,19547520.0,19578432.0,19609344.0,19640256.0,19465088.0,19496064.0,19527040.0,19558016.0,19588992.0,19619968.0,19650944.0,19681920.0,19506304.0,19537344.0,19568384.0,19599424.0,19630464.0,19661504.0,19692544.0,19723584.0,19547520.0,19578624.0,19609728.0,19640832.0,19671936.0,19703040.0,19734144.0,19765248.0,19588736.0,19619904.0,19651072.0,19682240.0,19713408.0,19744576.0,19775744.0,19806912.0,19629952.0,19661184.0,19692416.0,19723648.0,19754880.0,19786112.0,19817344.0,19848576.0,19671168.0,19702464.0,19733760.0,19765056.0,19796352.0,19827648.0,19858944.0,19890240.0,19712384.0,19743744.0,19775104.0,19806464.0,19837824.0,19869184.0,19900544.0,19931904.0,19753600.0,19785024.0,19816448.0,19847872.0,19879296.0,19910720.0,19942144.0,19973568.0,19794816.0,19826304.0,19857792.0,19889280.0,19920768.0,19952256.0,19983744.0,20015232.0,19836032.0,19867584.0,19899136.0,19930688.0,19962240.0,19993792.0,20025344.0,20056896.0,19877248.0,19908864.0,19940480.0,19972096.0,20003712.0,20035328.0,20066944.0,20098560.0,19918464.0,19950144.0,19981824.0,20013504.0,20045184.0,20076864.0,20108544.0,20140224.0,19959680.0,19991424.0,20023168.0,20054912.0,20086656.0,20118400.0,20150144.0,20181888.0,20000896.0,20032704.0,20064512.0,20096320.0,20128128.0,20159936.0,20191744.0,20223552.0,20042112.0,20073984.0,20105856.0,20137728.0,20169600.0,20201472.0,20233344.0,20265216.0,20083328.0,20115264.0,20147200.0,20179136.0,20211072.0,20243008.0,20274944.0,20306880.0,20124544.0,20156544.0,20188544.0,20220544.0,20252544.0,20284544.0,20316544.0,20348544.0,20165760.0,20197824.0,20229888.0,20261952.0,20294016.0,203
26080.0,20358144.0,20390208.0,20206976.0,20239104.0,20271232.0,20303360.0,20335488.0,20367616.0,20399744.0,20431872.0,20248192.0,20280384.0,20312576.0,20344768.0,20376960.0,20409152.0,20441344.0,20473536.0,20289408.0,20321664.0,20353920.0,20386176.0,20418432.0,20450688.0,20482944.0,20515200.0,20330624.0,20362944.0,20395264.0,20427584.0,20459904.0,20492224.0,20524544.0,20556864.0,20371840.0,20404224.0,20436608.0,20468992.0,20501376.0,20533760.0,20566144.0,20598528.0,20413056.0,20445504.0,20477952.0,20510400.0,20542848.0,20575296.0,20607744.0,20640192.0,20454272.0,20486784.0,20519296.0,20551808.0,20584320.0,20616832.0,20649344.0,20681856.0,20495488.0,20528064.0,20560640.0,20593216.0,20625792.0,20658368.0,20690944.0,20723520.0,20536704.0,20569344.0,20601984.0,20634624.0,20667264.0,20699904.0,20732544.0,20765184.0,20577920.0,20610624.0,20643328.0,20676032.0,20708736.0,20741440.0,20774144.0,20806848.0,20619136.0,20651904.0,20684672.0,20717440.0,20750208.0,20782976.0,20815744.0,20848512.0,20660352.0,20693184.0,20726016.0,20758848.0,20791680.0,20824512.0,20857344.0,20890176.0,20701568.0,20734464.0,20767360.0,20800256.0,20833152.0,20866048.0,20898944.0,20931840.0,20742784.0,20775744.0,20808704.0,20841664.0,20874624.0,20907584.0,20940544.0,20973504.0,20784000.0,20817024.0,20850048.0,20883072.0,20916096.0,20949120.0,20982144.0,21015168.0,20825216.0,20858304.0,20891392.0,20924480.0,20957568.0,20990656.0,21023744.0,21056832.0,20866432.0,20899584.0,20932736.0,20965888.0,20999040.0,21032192.0,21065344.0,21098496.0,20907648.0,20940864.0,20974080.0,21007296.0,21040512.0,21073728.0,21106944.0,21140160.0,20948864.0,20982144.0,21015424.0,21048704.0,21081984.0,21115264.0,21148544.0,21181824.0,20596864.0,20629184.0,20661504.0,20693824.0,20726144.0,20758464.0,20790784.0,20823104.0,20236672.0,20268032.0,20299392.0,20330752.0,20362112.0,20393472.0,20424832.0,20456192.0,20277888.0,20309312.0,20340736.0,20372160.0,20403584.0,20435008.0,20466432.0,20497856.0,20319104.0,20350592.0,20382080.0,2
0413568.0,20445056.0,20476544.0,20508032.0,20539520.0,20360320.0,20391872.0,20423424.0,20454976.0,20486528.0,20518080.0,20549632.0,20581184.0,20401536.0,20433152.0,20464768.0,20496384.0,20528000.0,20559616.0,20591232.0,20622848.0,20442752.0,20474432.0,20506112.0,20537792.0,20569472.0,20601152.0,20632832.0,20664512.0,20483968.0,20515712.0,20547456.0,20579200.0,20610944.0,20642688.0,20674432.0,20706176.0,20525184.0,20556992.0,20588800.0,20620608.0,20652416.0,20684224.0,20716032.0,20747840.0,20566400.0,20598272.0,20630144.0,20662016.0,20693888.0,20725760.0,20757632.0,20789504.0,20607616.0,20639552.0,20671488.0,20703424.0,20735360.0,20767296.0,20799232.0,20831168.0,20648832.0,20680832.0,20712832.0,20744832.0,20776832.0,20808832.0,20840832.0,20872832.0,20690048.0,20722112.0,20754176.0,20786240.0,20818304.0,20850368.0,20882432.0,20914496.0,20731264.0,20763392.0,20795520.0,20827648.0,20859776.0,20891904.0,20924032.0,20956160.0,20772480.0,20804672.0,20836864.0,20869056.0,20901248.0,20933440.0,20965632.0,20997824.0,20813696.0,20845952.0,20878208.0,20910464.0,20942720.0,20974976.0,21007232.0,21039488.0,20854912.0,20887232.0,20919552.0,20951872.0,20984192.0,21016512.0,21048832.0,21081152.0,20896128.0,20928512.0,20960896.0,20993280.0,21025664.0,21058048.0,21090432.0,21122816.0,20937344.0,20969792.0,21002240.0,21034688.0,21067136.0,21099584.0,21132032.0,21164480.0,20978560.0,21011072.0,21043584.0,21076096.0,21108608.0,21141120.0,21173632.0,21206144.0,21019776.0,21052352.0,21084928.0,21117504.0,21150080.0,21182656.0,21215232.0,21247808.0,21060992.0,21093632.0,21126272.0,21158912.0,21191552.0,21224192.0,21256832.0,21289472.0,21102208.0,21134912.0,21167616.0,21200320.0,21233024.0,21265728.0,21298432.0,21331136.0,21143424.0,21176192.0,21208960.0,21241728.0,21274496.0,21307264.0,21340032.0,21372800.0,21184640.0,21217472.0,21250304.0,21283136.0,21315968.0,21348800.0,21381632.0,21414464.0,21225856.0,21258752.0,21291648.0,21324544.0,21357440.0,21390336.0,21423232.0,21456128.0,21267072.0
,21300032.0,21332992.0,21365952.0,21398912.0,21431872.0,21464832.0,21497792.0,21308288.0,21341312.0,21374336.0,21407360.0,21440384.0,21473408.0,21506432.0,21539456.0,21349504.0,21382592.0,21415680.0,21448768.0,21481856.0,21514944.0,21548032.0,21581120.0,21390720.0,21423872.0,21457024.0,21490176.0,21523328.0,21556480.0,21589632.0,21622784.0,21431936.0,21465152.0,21498368.0,21531584.0,21564800.0,21598016.0,21631232.0,21664448.0,21473152.0,21506432.0,21539712.0,21572992.0,21606272.0,21639552.0,21672832.0,21706112.0,21514368.0,21547712.0,21581056.0,21614400.0,21647744.0,21681088.0,21714432.0,21747776.0,21555584.0,21588992.0,21622400.0,21655808.0,21689216.0,21722624.0,21756032.0,21789440.0,21596800.0,21630272.0,21663744.0,21697216.0,21730688.0,21764160.0,21797632.0,21831104.0,21638016.0,21671552.0,21705088.0,21738624.0,21772160.0,21805696.0,21839232.0,21872768.0,21679232.0,21712832.0,21746432.0,21780032.0,21813632.0,21847232.0,21880832.0,21914432.0,21720448.0,21754112.0,21787776.0,21821440.0,21855104.0,21888768.0,21922432.0,21956096.0,21761664.0,21795392.0,21829120.0,21862848.0,21896576.0,21930304.0,21964032.0,21997760.0,20885376.0,20918144.0,20950912.0,20983680.0,21016448.0,21049216.0,21081984.0,21114752.0,20000896.0,20032704.0,20064512.0,20096320.0,20128128.0,20159936.0,20191744.0,20223552.0,10074496.0,10090416.0,10106336.0,10122256.0,10138176.0,10154096.0,10170016.0,10185936.0,10642176.0,10657968.0,10673760.0,10689552.0,10705344.0,10721136.0,10736928.0,10752720.0,21431680.0,21463296.0,21494912.0,21526528.0,21558144.0,21589760.0,21621376.0,21652992.0,21473920.0,21505600.0,21537280.0,21568960.0,21600640.0,21632320.0,21664000.0,21695680.0,21516160.0,21547904.0,21579648.0,21611392.0,21643136.0,21674880.0,21706624.0,21738368.0,21558400.0,21590208.0,21622016.0,21653824.0,21685632.0,21717440.0,21749248.0,21781056.0,21600640.0,21632512.0,21664384.0,21696256.0,21728128.0,21760000.0,21791872.0,21823744.0,21642880.0,21674816.0,21706752.0,21738688.0,21770624.0,21802560.0,21834496
.0,21866432.0,21685120.0,21717120.0,21749120.0,21781120.0,21813120.0,21845120.0,21877120.0,21909120.0,21727360.0,21759424.0,21791488.0,21823552.0,21855616.0,21887680.0,21919744.0,21951808.0,21769600.0,21801728.0,21833856.0,21865984.0,21898112.0,21930240.0,21962368.0,21994496.0,21811840.0,21844032.0,21876224.0,21908416.0,21940608.0,21972800.0,22004992.0,22037184.0,21854080.0,21886336.0,21918592.0,21950848.0,21983104.0,22015360.0,22047616.0,22079872.0,21896320.0,21928640.0,21960960.0,21993280.0,22025600.0,22057920.0,22090240.0,22122560.0,21938560.0,21970944.0,22003328.0,22035712.0,22068096.0,22100480.0,22132864.0,22165248.0,21980800.0,22013248.0,22045696.0,22078144.0,22110592.0,22143040.0,22175488.0,22207936.0,22023040.0,22055552.0,22088064.0,22120576.0,22153088.0,22185600.0,22218112.0,22250624.0,22065280.0,22097856.0,22130432.0,22163008.0,22195584.0,22228160.0,22260736.0,22293312.0,22107520.0,22140160.0,22172800.0,22205440.0,22238080.0,22270720.0,22303360.0,22336000.0,22149760.0,22182464.0,22215168.0,22247872.0,22280576.0,22313280.0,22345984.0,22378688.0,22192000.0,22224768.0,22257536.0,22290304.0,22323072.0,22355840.0,22388608.0,22421376.0,22234240.0,22267072.0,22299904.0,22332736.0,22365568.0,22398400.0,22431232.0,22464064.0,22276480.0,22309376.0,22342272.0,22375168.0,22408064.0,22440960.0,22473856.0,22506752.0,22318720.0,22351680.0,22384640.0,22417600.0,22450560.0,22483520.0,22516480.0,22549440.0,22360960.0,22393984.0,22427008.0,22460032.0,22493056.0,22526080.0,22559104.0,22592128.0,22403200.0,22436288.0,22469376.0,22502464.0,22535552.0,22568640.0,22601728.0,22634816.0,22445440.0,22478592.0,22511744.0,22544896.0,22578048.0,22611200.0,22644352.0,22677504.0,22487680.0,22520896.0,22554112.0,22587328.0,22620544.0,22653760.0,22686976.0,22720192.0,22529920.0,22563200.0,22596480.0,22629760.0,22663040.0,22696320.0,22729600.0,22762880.0,22572160.0,22605504.0,22638848.0,22672192.0,22705536.0,22738880.0,22772224.0,22805568.0,22614400.0,22647808.0,22681216.0,22714624.0,227480
32.0,22781440.0,22814848.0,22848256.0,22656640.0,22690112.0,22723584.0,22757056.0,22790528.0,22824000.0,22857472.0,22890944.0,22698880.0,22732416.0,22765952.0,22799488.0,22833024.0,22866560.0,22900096.0,22933632.0,22741120.0,22774720.0,22808320.0,22841920.0,22875520.0,22909120.0,22942720.0,22976320.0,22783360.0,22817024.0,22850688.0,22884352.0,22918016.0,22951680.0,22985344.0,23019008.0,22825600.0,22859328.0,22893056.0,22926784.0,22960512.0,22994240.0,23027968.0,23061696.0,22867840.0,22901632.0,22935424.0,22969216.0,23003008.0,23036800.0,23070592.0,23104384.0,22910080.0,22943936.0,22977792.0,23011648.0,23045504.0,23079360.0,23113216.0,23147072.0,22018432.0,22051328.0,22084224.0,22117120.0,22150016.0,22182912.0,22215808.0,22248704.0,21118592.0,21150528.0,21182464.0,21214400.0,21246336.0,21278272.0,21310208.0,21342144.0,20226944.0,20257920.0,20288896.0,20319872.0,20350848.0,20381824.0,20412800.0,20443776.0,19327104.0,19357120.0,19387136.0,19417152.0,19447168.0,19477184.0,19507200.0,19537216.0,19369344.0,19399424.0,19429504.0,19459584.0,19489664.0,19519744.0,19549824.0,19579904.0,19411584.0,19441728.0,19471872.0,19502016.0,19532160.0,19562304.0,19592448.0,19622592.0,19453824.0,19484032.0,19514240.0,19544448.0,19574656.0,19604864.0,19635072.0,19665280.0,19496064.0,19526336.0,19556608.0,19586880.0,19617152.0,19647424.0,19677696.0,19707968.0,19538304.0,19568640.0,19598976.0,19629312.0,19659648.0,19689984.0,19720320.0,19750656.0,19580544.0,19610944.0,19641344.0,19671744.0,19702144.0,19732544.0,19762944.0,19793344.0,19622784.0,19653248.0,19683712.0,19714176.0,19744640.0,19775104.0,19805568.0,19836032.0,19665024.0,19695552.0,19726080.0,19756608.0,19787136.0,19817664.0,19848192.0,19878720.0,19707264.0,19737856.0,19768448.0,19799040.0,19829632.0,19860224.0,19890816.0,19921408.0,19749504.0,19780160.0,19810816.0,19841472.0,19872128.0,19902784.0,19933440.0,19964096.0,19791744.0,19822464.0,19853184.0,19883904.0,19914624.0,19945344.0,19976064.0,20006784.0,19833984.0,19864768.0,1989
5552.0,19926336.0,19957120.0,19987904.0,20018688.0,20049472.0,19876224.0,19907072.0,19937920.0,19968768.0,19999616.0,20030464.0,20061312.0,20092160.0,19918464.0,19949376.0,19980288.0,20011200.0,20042112.0,20073024.0,20103936.0,20134848.0,19960704.0,19991680.0,20022656.0,20053632.0,20084608.0,20115584.0,20146560.0,20177536.0,20002944.0,20033984.0,20065024.0,20096064.0,20127104.0,20158144.0,20189184.0,20220224.0,20045184.0,20076288.0,20107392.0,20138496.0,20169600.0,20200704.0,20231808.0,20262912.0,20087424.0,20118592.0,20149760.0,20180928.0,20212096.0,20243264.0,20274432.0,20305600.0,20129664.0,20160896.0,20192128.0,20223360.0,20254592.0,20285824.0,20317056.0,20348288.0,20171904.0,20203200.0,20234496.0,20265792.0,20297088.0,20328384.0,20359680.0,20390976.0,20214144.0,20245504.0,20276864.0,20308224.0,20339584.0,20370944.0,20402304.0,20433664.0,20256384.0,20287808.0,20319232.0,20350656.0,20382080.0,20413504.0,20444928.0,20476352.0,20298624.0,20330112.0,20361600.0,20393088.0,20424576.0,20456064.0,20487552.0,20519040.0,20340864.0,20372416.0,20403968.0,20435520.0,20467072.0,20498624.0,20530176.0,20561728.0,20383104.0,20414720.0,20446336.0,20477952.0,20509568.0,20541184.0,20572800.0,20604416.0,20425344.0,20457024.0,20488704.0,20520384.0,20552064.0,20583744.0,20615424.0,20647104.0,20467584.0,20499328.0,20531072.0,20562816.0,20594560.0,20626304.0,20658048.0,20689792.0,20509824.0,20541632.0,20573440.0,20605248.0,20637056.0,20668864.0,20700672.0,20732480.0,20552064.0,20583936.0,20615808.0,20647680.0,20679552.0,20711424.0,20743296.0,20775168.0,20594304.0,20626240.0,20658176.0,20690112.0,20722048.0,20753984.0,20785920.0,20817856.0,20636544.0,20668544.0,20700544.0,20732544.0,20764544.0,20796544.0,20828544.0,20860544.0,20678784.0,20710848.0,20742912.0,20774976.0,20807040.0,20839104.0,20871168.0,20903232.0,20721024.0,20753152.0,20785280.0,20817408.0,20849536.0,20881664.0,20913792.0,20945920.0,20763264.0,20795456.0,20827648.0,20859840.0,20892032.0,20924224.0,20956416.0,20988608.0,20
805504.0,20837760.0,20870016.0,20902272.0,20934528.0,20966784.0,20999040.0,21031296.0,20847744.0,20880064.0,20912384.0,20944704.0,20977024.0,21009344.0,21041664.0,21073984.0,20889984.0,20922368.0,20954752.0,20987136.0,21019520.0,21051904.0,21084288.0,21116672.0,20522624.0,20554048.0,20585472.0,20616896.0,20648320.0,20679744.0,20711168.0,20742592.0,20147072.0,20177536.0,20208000.0,20238464.0,20268928.0,20299392.0,20329856.0,20360320.0,20189312.0,20219840.0,20250368.0,20280896.0,20311424.0,20341952.0,20372480.0,20403008.0,20231552.0,20262144.0,20292736.0,20323328.0,20353920.0,20384512.0,20415104.0,20445696.0,20273792.0,20304448.0,20335104.0,20365760.0,20396416.0,20427072.0,20457728.0,20488384.0,20316032.0,20346752.0,20377472.0,20408192.0,20438912.0,20469632.0,20500352.0,20531072.0,20358272.0,20389056.0,20419840.0,20450624.0,20481408.0,20512192.0,20542976.0,20573760.0,20400512.0,20431360.0,20462208.0,20493056.0,20523904.0,20554752.0,20585600.0,20616448.0,20442752.0,20473664.0,20504576.0,20535488.0,20566400.0,20597312.0,20628224.0,20659136.0,20484992.0,20515968.0,20546944.0,20577920.0,20608896.0,20639872.0,20670848.0,20701824.0,20527232.0,20558272.0,20589312.0,20620352.0,20651392.0,20682432.0,20713472.0,20744512.0,20569472.0,20600576.0,20631680.0,20662784.0,20693888.0,20724992.0,20756096.0,20787200.0,20611712.0,20642880.0,20674048.0,20705216.0,20736384.0,20767552.0,20798720.0,20829888.0,20653952.0,20685184.0,20716416.0,20747648.0,20778880.0,20810112.0,20841344.0,20872576.0,20696192.0,20727488.0,20758784.0,20790080.0,20821376.0,20852672.0,20883968.0,20915264.0,20738432.0,20769792.0,20801152.0,20832512.0,20863872.0,20895232.0,20926592.0,20957952.0,20780672.0,20812096.0,20843520.0,20874944.0,20906368.0,20937792.0,20969216.0,21000640.0,20822912.0,20854400.0,20885888.0,20917376.0,20948864.0,20980352.0,21011840.0,21043328.0,20865152.0,20896704.0,20928256.0,20959808.0,20991360.0,21022912.0,21054464.0,21086016.0,20907392.0,20939008.0,20970624.0,21002240.0,21033856.0,21065472.0,
21097088.0,21128704.0,20949632.0,20981312.0,21012992.0,21044672.0,21076352.0,21108032.0,21139712.0,21171392.0,20991872.0,21023616.0,21055360.0,21087104.0,21118848.0,21150592.0,21182336.0,21214080.0,21034112.0,21065920.0,21097728.0,21129536.0,21161344.0,21193152.0,21224960.0,21256768.0,21076352.0,21108224.0,21140096.0,21171968.0,21203840.0,21235712.0,21267584.0,21299456.0,21118592.0,21150528.0,21182464.0,21214400.0,21246336.0,21278272.0,21310208.0,21342144.0,21160832.0,21192832.0,21224832.0,21256832.0,21288832.0,21320832.0,21352832.0,21384832.0,21203072.0,21235136.0,21267200.0,21299264.0,21331328.0,21363392.0,21395456.0,21427520.0,21245312.0,21277440.0,21309568.0,21341696.0,21373824.0,21405952.0,21438080.0,21470208.0,21287552.0,21319744.0,21351936.0,21384128.0,21416320.0,21448512.0,21480704.0,21512896.0,21329792.0,21362048.0,21394304.0,21426560.0,21458816.0,21491072.0,21523328.0,21555584.0,21372032.0,21404352.0,21436672.0,21468992.0,21501312.0,21533632.0,21565952.0,21598272.0,21414272.0,21446656.0,21479040.0,21511424.0,21543808.0,21576192.0,21608576.0,21640960.0,21456512.0,21488960.0,21521408.0,21553856.0,21586304.0,21618752.0,21651200.0,21683648.0,21498752.0,21531264.0,21563776.0,21596288.0,21628800.0,21661312.0,21693824.0,21726336.0,21540992.0,21573568.0,21606144.0,21638720.0,21671296.0,21703872.0,21736448.0,21769024.0,21583232.0,21615872.0,21648512.0,21681152.0,21713792.0,21746432.0,21779072.0,21811712.0,21625472.0,21658176.0,21690880.0,21723584.0,21756288.0,21788992.0,21821696.0,21854400.0,21667712.0,21700480.0,21733248.0,21766016.0,21798784.0,21831552.0,21864320.0,21897088.0,21709952.0,21742784.0,21775616.0,21808448.0,21841280.0,21874112.0,21906944.0,21939776.0,20818304.0,20850176.0,20882048.0,20913920.0,20945792.0,20977664.0,21009536.0,21041408.0,19918464.0,19949376.0,19980288.0,20011200.0,20042112.0,20073024.0,20103936.0,20134848.0,19960704.0,19991680.0,20022656.0,20053632.0,20084608.0,20115584.0,20146560.0,20177536.0,20002944.0,20033984.0,20065024.0,20096064.
0,20127104.0,20158144.0,20189184.0,20220224.0,20045184.0,20076288.0,20107392.0,20138496.0,20169600.0,20200704.0,20231808.0,20262912.0,20087424.0,20118592.0,20149760.0,20180928.0,20212096.0,20243264.0,20274432.0,20305600.0,20129664.0,20160896.0,20192128.0,20223360.0,20254592.0,20285824.0,20317056.0,20348288.0,20171904.0,20203200.0,20234496.0,20265792.0,20297088.0,20328384.0,20359680.0,20390976.0,20214144.0,20245504.0,20276864.0,20308224.0,20339584.0,20370944.0,20402304.0,20433664.0,20256384.0,20287808.0,20319232.0,20350656.0,20382080.0,20413504.0,20444928.0,20476352.0,20298624.0,20330112.0,20361600.0,20393088.0,20424576.0,20456064.0,20487552.0,20519040.0,20340864.0,20372416.0,20403968.0,20435520.0,20467072.0,20498624.0,20530176.0,20561728.0,20383104.0,20414720.0,20446336.0,20477952.0,20509568.0,20541184.0,20572800.0,20604416.0,20425344.0,20457024.0,20488704.0,20520384.0,20552064.0,20583744.0,20615424.0,20647104.0,20467584.0,20499328.0,20531072.0,20562816.0,20594560.0,20626304.0,20658048.0,20689792.0,20509824.0,20541632.0,20573440.0,20605248.0,20637056.0,20668864.0,20700672.0,20732480.0,20552064.0,20583936.0,20615808.0,20647680.0,20679552.0,20711424.0,20743296.0,20775168.0,20594304.0,20626240.0,20658176.0,20690112.0,20722048.0,20753984.0,20785920.0,20817856.0,20636544.0,20668544.0,20700544.0,20732544.0,20764544.0,20796544.0,20828544.0,20860544.0,20678784.0,20710848.0,20742912.0,20774976.0,20807040.0,20839104.0,20871168.0,20903232.0,20721024.0,20753152.0,20785280.0,20817408.0,20849536.0,20881664.0,20913792.0,20945920.0,20763264.0,20795456.0,20827648.0,20859840.0,20892032.0,20924224.0,20956416.0,20988608.0,20805504.0,20837760.0,20870016.0,20902272.0,20934528.0,20966784.0,20999040.0,21031296.0,20847744.0,20880064.0,20912384.0,20944704.0,20977024.0,21009344.0,21041664.0,21073984.0,20889984.0,20922368.0,20954752.0,20987136.0,21019520.0,21051904.0,21084288.0,21116672.0,20932224.0,20964672.0,20997120.0,21029568.0,21062016.0,21094464.0,21126912.0,21159360.0,20974464.0,2100697
6.0,21039488.0,21072000.0,21104512.0,21137024.0,21169536.0,21202048.0,21016704.0,21049280.0,21081856.0,21114432.0,21147008.0,21179584.0,21212160.0,21244736.0,21058944.0,21091584.0,21124224.0,21156864.0,21189504.0,21222144.0,21254784.0,21287424.0,21101184.0,21133888.0,21166592.0,21199296.0,21232000.0,21264704.0,21297408.0,21330112.0,21143424.0,21176192.0,21208960.0,21241728.0,21274496.0,21307264.0,21340032.0,21372800.0,21185664.0,21218496.0,21251328.0,21284160.0,21316992.0,21349824.0,21382656.0,21415488.0,21227904.0,21260800.0,21293696.0,21326592.0,21359488.0,21392384.0,21425280.0,21458176.0,21270144.0,21303104.0,21336064.0,21369024.0,21401984.0,21434944.0,21467904.0,21500864.0,21312384.0,21345408.0,21378432.0,21411456.0,21444480.0,21477504.0,21510528.0,21543552.0,21354624.0,21387712.0,21420800.0,21453888.0,21486976.0,21520064.0,21553152.0,21586240.0,21396864.0,21430016.0,21463168.0,21496320.0,21529472.0,21562624.0,21595776.0,21628928.0,21439104.0,21472320.0,21505536.0,21538752.0,21571968.0,21605184.0,21638400.0,21671616.0,21481344.0,21514624.0,21547904.0,21581184.0,21614464.0,21647744.0,21681024.0,21714304.0,21113984.0,21146304.0,21178624.0,21210944.0,21243264.0,21275584.0,21307904.0,21340224.0,20738432.0,20769792.0,20801152.0,20832512.0,20863872.0,20895232.0,20926592.0,20957952.0,20780672.0,20812096.0,20843520.0,20874944.0,20906368.0,20937792.0,20969216.0,21000640.0,20822912.0,20854400.0,20885888.0,20917376.0,20948864.0,20980352.0,21011840.0,21043328.0,20865152.0,20896704.0,20928256.0,20959808.0,20991360.0,21022912.0,21054464.0,21086016.0,20907392.0,20939008.0,20970624.0,21002240.0,21033856.0,21065472.0,21097088.0,21128704.0,20949632.0,20981312.0,21012992.0,21044672.0,21076352.0,21108032.0,21139712.0,21171392.0,20991872.0,21023616.0,21055360.0,21087104.0,21118848.0,21150592.0,21182336.0,21214080.0,21034112.0,21065920.0,21097728.0,21129536.0,21161344.0,21193152.0,21224960.0,21256768.0,21076352.0,21108224.0,21140096.0,21171968.0,21203840.0,21235712.0,21267584.0,21299
456.0,21118592.0,21150528.0,21182464.0,21214400.0,21246336.0,21278272.0,21310208.0,21342144.0,21160832.0,21192832.0,21224832.0,21256832.0,21288832.0,21320832.0,21352832.0,21384832.0,21203072.0,21235136.0,21267200.0,21299264.0,21331328.0,21363392.0,21395456.0,21427520.0,21245312.0,21277440.0,21309568.0,21341696.0,21373824.0,21405952.0,21438080.0,21470208.0,21287552.0,21319744.0,21351936.0,21384128.0,21416320.0,21448512.0,21480704.0,21512896.0,21329792.0,21362048.0,21394304.0,21426560.0,21458816.0,21491072.0,21523328.0,21555584.0,21372032.0,21404352.0,21436672.0,21468992.0,21501312.0,21533632.0,21565952.0,21598272.0,21414272.0,21446656.0,21479040.0,21511424.0,21543808.0,21576192.0,21608576.0,21640960.0,21456512.0,21488960.0,21521408.0,21553856.0,21586304.0,21618752.0,21651200.0,21683648.0,21498752.0,21531264.0,21563776.0,21596288.0,21628800.0,21661312.0,21693824.0,21726336.0,21540992.0,21573568.0,21606144.0,21638720.0,21671296.0,21703872.0,21736448.0,21769024.0,21583232.0,21615872.0,21648512.0,21681152.0,21713792.0,21746432.0,21779072.0,21811712.0,21625472.0,21658176.0,21690880.0,21723584.0,21756288.0,21788992.0,21821696.0,21854400.0,21667712.0,21700480.0,21733248.0,21766016.0,21798784.0,21831552.0,21864320.0,21897088.0,21709952.0,21742784.0,21775616.0,21808448.0,21841280.0,21874112.0,21906944.0,21939776.0,21752192.0,21785088.0,21817984.0,21850880.0,21883776.0,21916672.0,21949568.0,21982464.0,21794432.0,21827392.0,21860352.0,21893312.0,21926272.0,21959232.0,21992192.0,22025152.0,21836672.0,21869696.0,21902720.0,21935744.0,21968768.0,22001792.0,22034816.0,22067840.0,21878912.0,21912000.0,21945088.0,21978176.0,22011264.0,22044352.0,22077440.0,22110528.0,21921152.0,21954304.0,21987456.0,22020608.0,22053760.0,22086912.0,22120064.0,22153216.0,21963392.0,21996608.0,22029824.0,22063040.0,22096256.0,22129472.0,22162688.0,22195904.0,22005632.0,22038912.0,22072192.0,22105472.0,22138752.0,22172032.0,22205312.0,22238592.0,22047872.0,22081216.0,22114560.0,22147904.0,22181248.0,222
14592.0,22247936.0,22281280.0,22090112.0,22123520.0,22156928.0,22190336.0,22223744.0,22257152.0,22290560.0,22323968.0,22132352.0,22165824.0,22199296.0,22232768.0,22266240.0,22299712.0,22333184.0,22366656.0,22174592.0,22208128.0,22241664.0,22275200.0,22308736.0,22342272.0,22375808.0,22409344.0,22216832.0,22250432.0,22284032.0,22317632.0,22351232.0,22384832.0,22418432.0,22452032.0,22259072.0,22292736.0,22326400.0,22360064.0,22393728.0,22427392.0,22461056.0,22494720.0,22301312.0,22335040.0,22368768.0,22402496.0,22436224.0,22469952.0,22503680.0,22537408.0,21409664.0,21442432.0,21475200.0,21507968.0,21540736.0,21573504.0,21606272.0,21639040.0,20509824.0,20541632.0,20573440.0,20605248.0,20637056.0,20668864.0,20700672.0,20732480.0,10329216.0,10345136.0,10361056.0,10376976.0,10392896.0,10408816.0,10424736.0,10440656.0,10894848.0,10910640.0,10926432.0,10942224.0,10958016.0,10973808.0,10989600.0,11005392.0,21937536.0,21969152.0,22000768.0,22032384.0,22064000.0,22095616.0,22127232.0,22158848.0,21980800.0,22012480.0,22044160.0,22075840.0,22107520.0,22139200.0,22170880.0,22202560.0,22024064.0,22055808.0,22087552.0,22119296.0,22151040.0,22182784.0,22214528.0,22246272.0,22067328.0,22099136.0,22130944.0,22162752.0,22194560.0,22226368.0,22258176.0,22289984.0,22110592.0,22142464.0,22174336.0,22206208.0,22238080.0,22269952.0,22301824.0,22333696.0,22153856.0,22185792.0,22217728.0,22249664.0,22281600.0,22313536.0,22345472.0,22377408.0,22197120.0,22229120.0,22261120.0,22293120.0,22325120.0,22357120.0,22389120.0,22421120.0,22240384.0,22272448.0,22304512.0,22336576.0,22368640.0,22400704.0,22432768.0,22464832.0,22283648.0,22315776.0,22347904.0,22380032.0,22412160.0,22444288.0,22476416.0,22508544.0,22326912.0,22359104.0,22391296.0,22423488.0,22455680.0,22487872.0,22520064.0,22552256.0,22370176.0,22402432.0,22434688.0,22466944.0,22499200.0,22531456.0,22563712.0,22595968.0,22413440.0,22445760.0,22478080.0,22510400.0,22542720.0,22575040.0,22607360.0,22639680.0,22456704.0,22489088.0,22521472.0,2
2553856.0,22586240.0,22618624.0,22651008.0,22683392.0,22499968.0,22532416.0,22564864.0,22597312.0,22629760.0,22662208.0,22694656.0,22727104.0,22543232.0,22575744.0,22608256.0,22640768.0,22673280.0,22705792.0,22738304.0,22770816.0,22586496.0,22619072.0,22651648.0,22684224.0,22716800.0,22749376.0,22781952.0,22814528.0,22629760.0,22662400.0,22695040.0,22727680.0,22760320.0,22792960.0,22825600.0,22858240.0,22673024.0,22705728.0,22738432.0,22771136.0,22803840.0,22836544.0,22869248.0,22901952.0,22716288.0,22749056.0,22781824.0,22814592.0,22847360.0,22880128.0,22912896.0,22945664.0,22759552.0,22792384.0,22825216.0,22858048.0,22890880.0,22923712.0,22956544.0,22989376.0,22802816.0,22835712.0,22868608.0,22901504.0,22934400.0,22967296.0,23000192.0,23033088.0,22846080.0,22879040.0,22912000.0,22944960.0,22977920.0,23010880.0,23043840.0,23076800.0,22889344.0,22922368.0,22955392.0,22988416.0,23021440.0,23054464.0,23087488.0,23120512.0,22932608.0,22965696.0,22998784.0,23031872.0,23064960.0,23098048.0,23131136.0,23164224.0,22975872.0,23009024.0,23042176.0,23075328.0,23108480.0,23141632.0,23174784.0,23207936.0,23019136.0,23052352.0,23085568.0,23118784.0,23152000.0,23185216.0,23218432.0,23251648.0,23062400.0,23095680.0,23128960.0,23162240.0,23195520.0,23228800.0,23262080.0,23295360.0,23105664.0,23139008.0,23172352.0,23205696.0,23239040.0,23272384.0,23305728.0,23339072.0,23148928.0,23182336.0,23215744.0,23249152.0,23282560.0,23315968.0,23349376.0,23382784.0,23192192.0,23225664.0,23259136.0,23292608.0,23326080.0,23359552.0,23393024.0,23426496.0,23235456.0,23268992.0,23302528.0,23336064.0,23369600.0,23403136.0,23436672.0,23470208.0,23278720.0,23312320.0,23345920.0,23379520.0,23413120.0,23446720.0,23480320.0,23513920.0,23321984.0,23355648.0,23389312.0,23422976.0,23456640.0,23490304.0,23523968.0,23557632.0,23365248.0,23398976.0,23432704.0,23466432.0,23500160.0,23533888.0,23567616.0,23601344.0,23408512.0,23442304.0,23476096.0,23509888.0,23543680.0,23577472.0,23611264.0,23645056.0,23451776.0
,23485632.0,23519488.0,23553344.0,23587200.0,23621056.0,23654912.0,23688768.0,22544768.0,22577664.0,22610560.0,22643456.0,22676352.0,22709248.0,22742144.0,22775040.0,21629568.0,21661504.0,21693440.0,21725376.0,21757312.0,21789248.0,21821184.0,21853120.0,20722560.0,20753536.0,20784512.0,20815488.0,20846464.0,20877440.0,20908416.0,20939392.0,19807360.0,19837376.0,19867392.0,19897408.0,19927424.0,19957440.0,19987456.0,20017472.0,19850624.0,19880704.0,19910784.0,19940864.0,19970944.0,20001024.0,20031104.0,20061184.0,19893888.0,19924032.0,19954176.0,19984320.0,20014464.0,20044608.0,20074752.0,20104896.0,19937152.0,19967360.0,19997568.0,20027776.0,20057984.0,20088192.0,20118400.0,20148608.0,19980416.0,20010688.0,20040960.0,20071232.0,20101504.0,20131776.0,20162048.0,20192320.0,20023680.0,20054016.0,20084352.0,20114688.0,20145024.0,20175360.0,20205696.0,20236032.0,20066944.0,20097344.0,20127744.0,20158144.0,20188544.0,20218944.0,20249344.0,20279744.0,20110208.0,20140672.0,20171136.0,20201600.0,20232064.0,20262528.0,20292992.0,20323456.0,20153472.0,20184000.0,20214528.0,20245056.0,20275584.0,20306112.0,20336640.0,20367168.0,20196736.0,20227328.0,20257920.0,20288512.0,20319104.0,20349696.0,20380288.0,20410880.0,20240000.0,20270656.0,20301312.0,20331968.0,20362624.0,20393280.0,20423936.0,20454592.0,20283264.0,20313984.0,20344704.0,20375424.0,20406144.0,20436864.0,20467584.0,20498304.0,20326528.0,20357312.0,20388096.0,20418880.0,20449664.0,20480448.0,20511232.0,20542016.0,20369792.0,20400640.0,20431488.0,20462336.0,20493184.0,20524032.0,20554880.0,20585728.0,20413056.0,20443968.0,20474880.0,20505792.0,20536704.0,20567616.0,20598528.0,20629440.0,20456320.0,20487296.0,20518272.0,20549248.0,20580224.0,20611200.0,20642176.0,20673152.0,20499584.0,20530624.0,20561664.0,20592704.0,20623744.0,20654784.0,20685824.0,20716864.0,20542848.0,20573952.0,20605056.0,20636160.0,20667264.0,20698368.0,20729472.0,20760576.0,20586112.0,20617280.0,20648448.0,20679616.0,20710784.0,20741952.0,20773120
.0,20804288.0,20629376.0,20660608.0,20691840.0,20723072.0,20754304.0,20785536.0,20816768.0,20848000.0,20672640.0,20703936.0,20735232.0,20766528.0,20797824.0,20829120.0,20860416.0,20891712.0,20715904.0,20747264.0,20778624.0,20809984.0,20841344.0,20872704.0,20904064.0,20935424.0,20759168.0,20790592.0,20822016.0,20853440.0,20884864.0,20916288.0,20947712.0,20979136.0,20802432.0,20833920.0,20865408.0,20896896.0,20928384.0,20959872.0,20991360.0,21022848.0,20845696.0,20877248.0,20908800.0,20940352.0,20971904.0,21003456.0,21035008.0,21066560.0,20888960.0,20920576.0,20952192.0,20983808.0,21015424.0,21047040.0,21078656.0,21110272.0,20932224.0,20963904.0,20995584.0,21027264.0,21058944.0,21090624.0,21122304.0,21153984.0,20975488.0,21007232.0,21038976.0,21070720.0,21102464.0,21134208.0,21165952.0,21197696.0,21018752.0,21050560.0,21082368.0,21114176.0,21145984.0,21177792.0,21209600.0,21241408.0,21062016.0,21093888.0,21125760.0,21157632.0,21189504.0,21221376.0,21253248.0,21285120.0,21105280.0,21137216.0,21169152.0,21201088.0,21233024.0,21264960.0,21296896.0,21328832.0,21148544.0,21180544.0,21212544.0,21244544.0,21276544.0,21308544.0,21340544.0,21372544.0,21191808.0,21223872.0,21255936.0,21288000.0,21320064.0,21352128.0,21384192.0,21416256.0,21235072.0,21267200.0,21299328.0,21331456.0,21363584.0,21395712.0,21427840.0,21459968.0,21278336.0,21310528.0,21342720.0,21374912.0,21407104.0,21439296.0,21471488.0,21503680.0,21321600.0,21353856.0,21386112.0,21418368.0,21450624.0,21482880.0,21515136.0,21547392.0,21364864.0,21397184.0,21429504.0,21461824.0,21494144.0,21526464.0,21558784.0,21591104.0,21408128.0,21440512.0,21472896.0,21505280.0,21537664.0,21570048.0,21602432.0,21634816.0,21025408.0,21056832.0,21088256.0,21119680.0,21151104.0,21182528.0,21213952.0,21245376.0,20634496.0,20664960.0,20695424.0,20725888.0,20756352.0,20786816.0,20817280.0,20847744.0,20677760.0,20708288.0,20738816.0,20769344.0,20799872.0,20830400.0,20860928.0,20891456.0,20721024.0,20751616.0,20782208.0,20812800.0,208433
92.0,20873984.0,20904576.0,20935168.0,20764288.0,20794944.0,20825600.0,20856256.0,20886912.0,20917568.0,20948224.0,20978880.0,20807552.0,20838272.0,20868992.0,20899712.0,20930432.0,20961152.0,20991872.0,21022592.0,20850816.0,20881600.0,20912384.0,20943168.0,20973952.0,21004736.0,21035520.0,21066304.0,20894080.0,20924928.0,20955776.0,20986624.0,21017472.0,21048320.0,21079168.0,21110016.0,20937344.0,20968256.0,20999168.0,21030080.0,21060992.0,21091904.0,21122816.0,21153728.0,20980608.0,21011584.0,21042560.0,21073536.0,21104512.0,21135488.0,21166464.0,21197440.0,21023872.0,21054912.0,21085952.0,21116992.0,21148032.0,21179072.0,21210112.0,21241152.0,21067136.0,21098240.0,21129344.0,21160448.0,21191552.0,21222656.0,21253760.0,21284864.0,21110400.0,21141568.0,21172736.0,21203904.0,21235072.0,21266240.0,21297408.0,21328576.0,21153664.0,21184896.0,21216128.0,21247360.0,21278592.0,21309824.0,21341056.0,21372288.0,21196928.0,21228224.0,21259520.0,21290816.0,21322112.0,21353408.0,21384704.0,21416000.0,21240192.0,21271552.0,21302912.0,21334272.0,21365632.0,21396992.0,21428352.0,21459712.0,21283456.0,21314880.0,21346304.0,21377728.0,21409152.0,21440576.0,21472000.0,21503424.0,21326720.0,21358208.0,21389696.0,21421184.0,21452672.0,21484160.0,21515648.0,21547136.0,21369984.0,21401536.0,21433088.0,21464640.0,21496192.0,21527744.0,21559296.0,21590848.0,21413248.0,21444864.0,21476480.0,21508096.0,21539712.0,21571328.0,21602944.0,21634560.0,21456512.0,21488192.0,21519872.0,21551552.0,21583232.0,21614912.0,21646592.0,21678272.0,21499776.0,21531520.0,21563264.0,21595008.0,21626752.0,21658496.0,21690240.0,21721984.0,21543040.0,21574848.0,21606656.0,21638464.0,21670272.0,21702080.0,21733888.0,21765696.0,21586304.0,21618176.0,21650048.0,21681920.0,21713792.0,21745664.0,21777536.0,21809408.0,21629568.0,21661504.0,21693440.0,21725376.0,21757312.0,21789248.0,21821184.0,21853120.0,21672832.0,21704832.0,21736832.0,21768832.0,21800832.0,21832832.0,21864832.0,21896832.0,21716096.0,21748160.0,2178
0224.0,21812288.0,21844352.0,21876416.0,21908480.0,21940544.0,21759360.0,21791488.0,21823616.0,21855744.0,21887872.0,21920000.0,21952128.0,21984256.0,21802624.0,21834816.0,21867008.0,21899200.0,21931392.0,21963584.0,21995776.0,22027968.0,21845888.0,21878144.0,21910400.0,21942656.0,21974912.0,22007168.0,22039424.0,22071680.0,21889152.0,21921472.0,21953792.0,21986112.0,22018432.0,22050752.0,22083072.0,22115392.0,21932416.0,21964800.0,21997184.0,22029568.0,22061952.0,22094336.0,22126720.0,22159104.0,21975680.0,22008128.0,22040576.0,22073024.0,22105472.0,22137920.0,22170368.0,22202816.0,22018944.0,22051456.0,22083968.0,22116480.0,22148992.0,22181504.0,22214016.0,22246528.0,22062208.0,22094784.0,22127360.0,22159936.0,22192512.0,22225088.0,22257664.0,22290240.0,22105472.0,22138112.0,22170752.0,22203392.0,22236032.0,22268672.0,22301312.0,22333952.0,22148736.0,22181440.0,22214144.0,22246848.0,22279552.0,22312256.0,22344960.0,22377664.0,22192000.0,22224768.0,22257536.0,22290304.0,22323072.0,22355840.0,22388608.0,22421376.0,22235264.0,22268096.0,22300928.0,22333760.0,22366592.0,22399424.0,22432256.0,22465088.0,21328256.0,21360128.0,21392000.0,21423872.0,21455744.0,21487616.0,21519488.0,21551360.0,20413056.0,20443968.0,20474880.0,20505792.0,20536704.0,20567616.0,20598528.0,20629440.0,20456320.0,20487296.0,20518272.0,20549248.0,20580224.0,20611200.0,20642176.0,20673152.0,20499584.0,20530624.0,20561664.0,20592704.0,20623744.0,20654784.0,20685824.0,20716864.0,20542848.0,20573952.0,20605056.0,20636160.0,20667264.0,20698368.0,20729472.0,20760576.0,20586112.0,20617280.0,20648448.0,20679616.0,20710784.0,20741952.0,20773120.0,20804288.0,20629376.0,20660608.0,20691840.0,20723072.0,20754304.0,20785536.0,20816768.0,20848000.0,20672640.0,20703936.0,20735232.0,20766528.0,20797824.0,20829120.0,20860416.0,20891712.0,20715904.0,20747264.0,20778624.0,20809984.0,20841344.0,20872704.0,20904064.0,20935424.0,20759168.0,20790592.0,20822016.0,20853440.0,20884864.0,20916288.0,20947712.0,20979136.0,20
802432.0,20833920.0,20865408.0,20896896.0,20928384.0,20959872.0,20991360.0,21022848.0,20845696.0,20877248.0,20908800.0,20940352.0,20971904.0,21003456.0,21035008.0,21066560.0,20888960.0,20920576.0,20952192.0,20983808.0,21015424.0,21047040.0,21078656.0,21110272.0,20932224.0,20963904.0,20995584.0,21027264.0,21058944.0,21090624.0,21122304.0,21153984.0,20975488.0,21007232.0,21038976.0,21070720.0,21102464.0,21134208.0,21165952.0,21197696.0,21018752.0,21050560.0,21082368.0,21114176.0,21145984.0,21177792.0,21209600.0,21241408.0,21062016.0,21093888.0,21125760.0,21157632.0,21189504.0,21221376.0,21253248.0,21285120.0,21105280.0,21137216.0,21169152.0,21201088.0,21233024.0,21264960.0,21296896.0,21328832.0,21148544.0,21180544.0,21212544.0,21244544.0,21276544.0,21308544.0,21340544.0,21372544.0,21191808.0,21223872.0,21255936.0,21288000.0,21320064.0,21352128.0,21384192.0,21416256.0,21235072.0,21267200.0,21299328.0,21331456.0,21363584.0,21395712.0,21427840.0,21459968.0,21278336.0,21310528.0,21342720.0,21374912.0,21407104.0,21439296.0,21471488.0,21503680.0,21321600.0,21353856.0,21386112.0,21418368.0,21450624.0,21482880.0,21515136.0,21547392.0,21364864.0,21397184.0,21429504.0,21461824.0,21494144.0,21526464.0,21558784.0,21591104.0,21408128.0,21440512.0,21472896.0,21505280.0,21537664.0,21570048.0,21602432.0,21634816.0,21451392.0,21483840.0,21516288.0,21548736.0,21581184.0,21613632.0,21646080.0,21678528.0,21494656.0,21527168.0,21559680.0,21592192.0,21624704.0,21657216.0,21689728.0,21722240.0,21537920.0,21570496.0,21603072.0,21635648.0,21668224.0,21700800.0,21733376.0,21765952.0,21581184.0,21613824.0,21646464.0,21679104.0,21711744.0,21744384.0,21777024.0,21809664.0,21624448.0,21657152.0,21689856.0,21722560.0,21755264.0,21787968.0,21820672.0,21853376.0,21667712.0,21700480.0,21733248.0,21766016.0,21798784.0,21831552.0,21864320.0,21897088.0,21710976.0,21743808.0,21776640.0,21809472.0,21842304.0,21875136.0,21907968.0,21940800.0,21754240.0,21787136.0,21820032.0,21852928.0,21885824.0,21918720.0,
21951616.0,21984512.0,21797504.0,21830464.0,21863424.0,21896384.0,21929344.0,21962304.0,21995264.0,22028224.0,21840768.0,21873792.0,21906816.0,21939840.0,21972864.0,22005888.0,22038912.0,22071936.0,21884032.0,21917120.0,21950208.0,21983296.0,22016384.0,22049472.0,22082560.0,22115648.0,21927296.0,21960448.0,21993600.0,22026752.0,22059904.0,22093056.0,22126208.0,22159360.0,21970560.0,22003776.0,22036992.0,22070208.0,22103424.0,22136640.0,22169856.0,22203072.0,22013824.0,22047104.0,22080384.0,22113664.0,22146944.0,22180224.0,22213504.0,22246784.0,21631104.0,21663424.0,21695744.0,21728064.0,21760384.0,21792704.0,21825024.0,21857344.0,21240192.0,21271552.0,21302912.0,21334272.0,21365632.0,21396992.0,21428352.0,21459712.0,21283456.0,21314880.0,21346304.0,21377728.0,21409152.0,21440576.0,21472000.0,21503424.0,21326720.0,21358208.0,21389696.0,21421184.0,21452672.0,21484160.0,21515648.0,21547136.0,21369984.0,21401536.0,21433088.0,21464640.0,21496192.0,21527744.0,21559296.0,21590848.0,21413248.0,21444864.0,21476480.0,21508096.0,21539712.0,21571328.0,21602944.0,21634560.0,21456512.0,21488192.0,21519872.0,21551552.0,21583232.0,21614912.0,21646592.0,21678272.0,21499776.0,21531520.0,21563264.0,21595008.0,21626752.0,21658496.0,21690240.0,21721984.0,21543040.0,21574848.0,21606656.0,21638464.0,21670272.0,21702080.0,21733888.0,21765696.0,21586304.0,21618176.0,21650048.0,21681920.0,21713792.0,21745664.0,21777536.0,21809408.0,21629568.0,21661504.0,21693440.0,21725376.0,21757312.0,21789248.0,21821184.0,21853120.0,21672832.0,21704832.0,21736832.0,21768832.0,21800832.0,21832832.0,21864832.0,21896832.0,21716096.0,21748160.0,21780224.0,21812288.0,21844352.0,21876416.0,21908480.0,21940544.0,21759360.0,21791488.0,21823616.0,21855744.0,21887872.0,21920000.0,21952128.0,21984256.0,21802624.0,21834816.0,21867008.0,21899200.0,21931392.0,21963584.0,21995776.0,22027968.0,21845888.0,21878144.0,21910400.0,21942656.0,21974912.0,22007168.0,22039424.0,22071680.0,21889152.0,21921472.0,21953792.0,21986112.
0,22018432.0,22050752.0,22083072.0,22115392.0,21932416.0,21964800.0,21997184.0,22029568.0,22061952.0,22094336.0,22126720.0,22159104.0,21975680.0,22008128.0,22040576.0,22073024.0,22105472.0,22137920.0,22170368.0,22202816.0,22018944.0,22051456.0,22083968.0,22116480.0,22148992.0,22181504.0,22214016.0,22246528.0,22062208.0,22094784.0,22127360.0,22159936.0,22192512.0,22225088.0,22257664.0,22290240.0,22105472.0,22138112.0,22170752.0,22203392.0,22236032.0,22268672.0,22301312.0,22333952.0,22148736.0,22181440.0,22214144.0,22246848.0,22279552.0,22312256.0,22344960.0,22377664.0,22192000.0,22224768.0,22257536.0,22290304.0,22323072.0,22355840.0,22388608.0,22421376.0,22235264.0,22268096.0,22300928.0,22333760.0,22366592.0,22399424.0,22432256.0,22465088.0,22278528.0,22311424.0,22344320.0,22377216.0,22410112.0,22443008.0,22475904.0,22508800.0,22321792.0,22354752.0,22387712.0,22420672.0,22453632.0,22486592.0,22519552.0,22552512.0,22365056.0,22398080.0,22431104.0,22464128.0,22497152.0,22530176.0,22563200.0,22596224.0,22408320.0,22441408.0,22474496.0,22507584.0,22540672.0,22573760.0,22606848.0,22639936.0,22451584.0,22484736.0,22517888.0,22551040.0,22584192.0,22617344.0,22650496.0,22683648.0,22494848.0,22528064.0,22561280.0,22594496.0,22627712.0,22660928.0,22694144.0,22727360.0,22538112.0,22571392.0,22604672.0,22637952.0,22671232.0,22704512.0,22737792.0,22771072.0,22581376.0,22614720.0,22648064.0,22681408.0,22714752.0,22748096.0,22781440.0,22814784.0,22624640.0,22658048.0,22691456.0,22724864.0,22758272.0,22791680.0,22825088.0,22858496.0,22667904.0,22701376.0,22734848.0,22768320.0,22801792.0,22835264.0,22868736.0,22902208.0,22711168.0,22744704.0,22778240.0,22811776.0,22845312.0,22878848.0,22912384.0,22945920.0,22754432.0,22788032.0,22821632.0,22855232.0,22888832.0,22922432.0,22956032.0,22989632.0,22797696.0,22831360.0,22865024.0,22898688.0,22932352.0,22966016.0,22999680.0,23033344.0,22840960.0,22874688.0,22908416.0,22942144.0,22975872.0,23009600.0,23043328.0,23077056.0,21933952.0,2196672
0.0,21999488.0,22032256.0,22065024.0,22097792.0,22130560.0,22163328.0,21018752.0,21050560.0,21082368.0,21114176.0,21145984.0,21177792.0,21209600.0,21241408.0,10583936.0,10599856.0,10615776.0,10631696.0,10647616.0,10663536.0,10679456.0,10695376.0,11147520.0,11163312.0,11179104.0,11194896.0,11210688.0,11226480.0,11242272.0,11258064.0,22443392.0,22475008.0,22506624.0,22538240.0,22569856.0,22601472.0,22633088.0,22664704.0,22487680.0,22519360.0,22551040.0,22582720.0,22614400.0,22646080.0,22677760.0,22709440.0,22531968.0,22563712.0,22595456.0,22627200.0,22658944.0,22690688.0,22722432.0,22754176.0,22576256.0,22608064.0,22639872.0,22671680.0,22703488.0,22735296.0,22767104.0,22798912.0,22620544.0,22652416.0,22684288.0,22716160.0,22748032.0,22779904.0,22811776.0,22843648.0,22664832.0,22696768.0,22728704.0,22760640.0,22792576.0,22824512.0,22856448.0,22888384.0,22709120.0,22741120.0,22773120.0,22805120.0,22837120.0,22869120.0,22901120.0,22933120.0,22753408.0,22785472.0,22817536.0,22849600.0,22881664.0,22913728.0,22945792.0,22977856.0,22797696.0,22829824.0,22861952.0,22894080.0,22926208.0,22958336.0,22990464.0,23022592.0,22841984.0,22874176.0,22906368.0,22938560.0,22970752.0,23002944.0,23035136.0,23067328.0,22886272.0,22918528.0,22950784.0,22983040.0,23015296.0,23047552.0,23079808.0,23112064.0,22930560.0,22962880.0,22995200.0,23027520.0,23059840.0,23092160.0,23124480.0,23156800.0,22974848.0,23007232.0,23039616.0,23072000.0,23104384.0,23136768.0,23169152.0,23201536.0,23019136.0,23051584.0,23084032.0,23116480.0,23148928.0,23181376.0,23213824.0,23246272.0,23063424.0,23095936.0,23128448.0,23160960.0,23193472.0,23225984.0,23258496.0,23291008.0,23107712.0,23140288.0,23172864.0,23205440.0,23238016.0,23270592.0,23303168.0,23335744.0,23152000.0,23184640.0,23217280.0,23249920.0,23282560.0,23315200.0,23347840.0,23380480.0,23196288.0,23228992.0,23261696.0,23294400.0,23327104.0,23359808.0,23392512.0,23425216.0,23240576.0,23273344.0,23306112.0,23338880.0,23371648.0,23404416.0,23437184.0,23469
952.0,23284864.0,23317696.0,23350528.0,23383360.0,23416192.0,23449024.0,23481856.0,23514688.0,23329152.0,23362048.0,23394944.0,23427840.0,23460736.0,23493632.0,23526528.0,23559424.0,23373440.0,23406400.0,23439360.0,23472320.0,23505280.0,23538240.0,23571200.0,23604160.0,23417728.0,23450752.0,23483776.0,23516800.0,23549824.0,23582848.0,23615872.0,23648896.0,23462016.0,23495104.0,23528192.0,23561280.0,23594368.0,23627456.0,23660544.0,23693632.0,23506304.0,23539456.0,23572608.0,23605760.0,23638912.0,23672064.0,23705216.0,23738368.0,23550592.0,23583808.0,23617024.0,23650240.0,23683456.0,23716672.0,23749888.0,23783104.0,23594880.0,23628160.0,23661440.0,23694720.0,23728000.0,23761280.0,23794560.0,23827840.0,23639168.0,23672512.0,23705856.0,23739200.0,23772544.0,23805888.0,23839232.0,23872576.0,23683456.0,23716864.0,23750272.0,23783680.0,23817088.0,23850496.0,23883904.0,23917312.0,23727744.0,23761216.0,23794688.0,23828160.0,23861632.0,23895104.0,23928576.0,23962048.0,23772032.0,23805568.0,23839104.0,23872640.0,23906176.0,23939712.0,23973248.0,24006784.0,23816320.0,23849920.0,23883520.0,23917120.0,23950720.0,23984320.0,24017920.0,24051520.0,23860608.0,23894272.0,23927936.0,23961600.0,23995264.0,24028928.0,24062592.0,24096256.0,23904896.0,23938624.0,23972352.0,24006080.0,24039808.0,24073536.0,24107264.0,24140992.0,23949184.0,23982976.0,24016768.0,24050560.0,24084352.0,24118144.0,24151936.0,24185728.0,23993472.0,24027328.0,24061184.0,24095040.0,24128896.0,24162752.0,24196608.0,24230464.0,23071104.0,23104000.0,23136896.0,23169792.0,23202688.0,23235584.0,23268480.0,23301376.0,22140544.0,22172480.0,22204416.0,22236352.0,22268288.0,22300224.0,22332160.0,22364096.0,21218176.0,21249152.0,21280128.0,21311104.0,21342080.0,21373056.0,21404032.0,21435008.0,20287616.0,20317632.0,20347648.0,20377664.0,20407680.0,20437696.0,20467712.0,20497728.0,20331904.0,20361984.0,20392064.0,20422144.0,20452224.0,20482304.0,20512384.0,20542464.0,20376192.0,20406336.0,20436480.0,20466624.0,20496768.0,205
26912.0,20557056.0,20587200.0,20420480.0,20450688.0,20480896.0,20511104.0,20541312.0,20571520.0,20601728.0,20631936.0,20464768.0,20495040.0,20525312.0,20555584.0,20585856.0,20616128.0,20646400.0,20676672.0,20509056.0,20539392.0,20569728.0,20600064.0,20630400.0,20660736.0,20691072.0,20721408.0,20553344.0,20583744.0,20614144.0,20644544.0,20674944.0,20705344.0,20735744.0,20766144.0,20597632.0,20628096.0,20658560.0,20689024.0,20719488.0,20749952.0,20780416.0,20810880.0,20641920.0,20672448.0,20702976.0,20733504.0,20764032.0,20794560.0,20825088.0,20855616.0,20686208.0,20716800.0,20747392.0,20777984.0,20808576.0,20839168.0,20869760.0,20900352.0,20730496.0,20761152.0,20791808.0,20822464.0,20853120.0,20883776.0,20914432.0,20945088.0,20774784.0,20805504.0,20836224.0,20866944.0,20897664.0,20928384.0,20959104.0,20989824.0,20819072.0,20849856.0,20880640.0,20911424.0,20942208.0,20972992.0,21003776.0,21034560.0,20863360.0,20894208.0,20925056.0,20955904.0,20986752.0,21017600.0,21048448.0,21079296.0,20907648.0,20938560.0,20969472.0,21000384.0,21031296.0,21062208.0,21093120.0,21124032.0,20951936.0,20982912.0,21013888.0,21044864.0,21075840.0,21106816.0,21137792.0,21168768.0,20996224.0,21027264.0,21058304.0,21089344.0,21120384.0,21151424.0,21182464.0,21213504.0,21040512.0,21071616.0,21102720.0,21133824.0,21164928.0,21196032.0,21227136.0,21258240.0,21084800.0,21115968.0,21147136.0,21178304.0,21209472.0,21240640.0,21271808.0,21302976.0,21129088.0,21160320.0,21191552.0,21222784.0,21254016.0,21285248.0,21316480.0,21347712.0,21173376.0,21204672.0,21235968.0,21267264.0,21298560.0,21329856.0,21361152.0,21392448.0,21217664.0,21249024.0,21280384.0,21311744.0,21343104.0,21374464.0,21405824.0,21437184.0,21261952.0,21293376.0,21324800.0,21356224.0,21387648.0,21419072.0,21450496.0,21481920.0,21306240.0,21337728.0,21369216.0,21400704.0,21432192.0,21463680.0,21495168.0,21526656.0,21350528.0,21382080.0,21413632.0,21445184.0,21476736.0,21508288.0,21539840.0,21571392.0,21394816.0,21426432.0,21458048.0,2
1489664.0,21521280.0,21552896.0,21584512.0,21616128.0,21439104.0,21470784.0,21502464.0,21534144.0,21565824.0,21597504.0,21629184.0,21660864.0,21483392.0,21515136.0,21546880.0,21578624.0,21610368.0,21642112.0,21673856.0,21705600.0,21527680.0,21559488.0,21591296.0,21623104.0,21654912.0,21686720.0,21718528.0,21750336.0,21571968.0,21603840.0,21635712.0,21667584.0,21699456.0,21731328.0,21763200.0,21795072.0,21616256.0,21648192.0,21680128.0,21712064.0,21744000.0,21775936.0,21807872.0,21839808.0,21660544.0,21692544.0,21724544.0,21756544.0,21788544.0,21820544.0,21852544.0,21884544.0,21704832.0,21736896.0,21768960.0,21801024.0,21833088.0,21865152.0,21897216.0,21929280.0,21749120.0,21781248.0,21813376.0,21845504.0,21877632.0,21909760.0,21941888.0,21974016.0,21793408.0,21825600.0,21857792.0,21889984.0,21922176.0,21954368.0,21986560.0,22018752.0,21837696.0,21869952.0,21902208.0,21934464.0,21966720.0,21998976.0,22031232.0,22063488.0,21881984.0,21914304.0,21946624.0,21978944.0,22011264.0,22043584.0,22075904.0,22108224.0,21926272.0,21958656.0,21991040.0,22023424.0,22055808.0,22088192.0,22120576.0,22152960.0,21528192.0,21559616.0,21591040.0,21622464.0,21653888.0,21685312.0,21716736.0,21748160.0,21121920.0,21152384.0,21182848.0,21213312.0,21243776.0,21274240.0,21304704.0,21335168.0,21166208.0,21196736.0,21227264.0,21257792.0,21288320.0,21318848.0,21349376.0,21379904.0,21210496.0,21241088.0,21271680.0,21302272.0,21332864.0,21363456.0,21394048.0,21424640.0,21254784.0,21285440.0,21316096.0,21346752.0,21377408.0,21408064.0,21438720.0,21469376.0,21299072.0,21329792.0,21360512.0,21391232.0,21421952.0,21452672.0,21483392.0,21514112.0,21343360.0,21374144.0,21404928.0,21435712.0,21466496.0,21497280.0,21528064.0,21558848.0,21387648.0,21418496.0,21449344.0,21480192.0,21511040.0,21541888.0,21572736.0,21603584.0,21431936.0,21462848.0,21493760.0,21524672.0,21555584.0,21586496.0,21617408.0,21648320.0,21476224.0,21507200.0,21538176.0,21569152.0,21600128.0,21631104.0,21662080.0,21693056.0,21520512.0
,21551552.0,21582592.0,21613632.0,21644672.0,21675712.0,21706752.0,21737792.0,21564800.0,21595904.0,21627008.0,21658112.0,21689216.0,21720320.0,21751424.0,21782528.0,21609088.0,21640256.0,21671424.0,21702592.0,21733760.0,21764928.0,21796096.0,21827264.0,21653376.0,21684608.0,21715840.0,21747072.0,21778304.0,21809536.0,21840768.0,21872000.0,21697664.0,21728960.0,21760256.0,21791552.0,21822848.0,21854144.0,21885440.0,21916736.0,21741952.0,21773312.0,21804672.0,21836032.0,21867392.0,21898752.0,21930112.0,21961472.0,21786240.0,21817664.0,21849088.0,21880512.0,21911936.0,21943360.0,21974784.0,22006208.0,21830528.0,21862016.0,21893504.0,21924992.0,21956480.0,21987968.0,22019456.0,22050944.0,21874816.0,21906368.0,21937920.0,21969472.0,22001024.0,22032576.0,22064128.0,22095680.0,21919104.0,21950720.0,21982336.0,22013952.0,22045568.0,22077184.0,22108800.0,22140416.0,21963392.0,21995072.0,22026752.0,22058432.0,22090112.0,22121792.0,22153472.0,22185152.0,22007680.0,22039424.0,22071168.0,22102912.0,22134656.0,22166400.0,22198144.0,22229888.0,22051968.0,22083776.0,22115584.0,22147392.0,22179200.0,22211008.0,22242816.0,22274624.0,22096256.0,22128128.0,22160000.0,22191872.0,22223744.0,22255616.0,22287488.0,22319360.0,22140544.0,22172480.0,22204416.0,22236352.0,22268288.0,22300224.0,22332160.0,22364096.0,22184832.0,22216832.0,22248832.0,22280832.0,22312832.0,22344832.0,22376832.0,22408832.0,22229120.0,22261184.0,22293248.0,22325312.0,22357376.0,22389440.0,22421504.0,22453568.0,22273408.0,22305536.0,22337664.0,22369792.0,22401920.0,22434048.0,22466176.0,22498304.0,22317696.0,22349888.0,22382080.0,22414272.0,22446464.0,22478656.0,22510848.0,22543040.0,22361984.0,22394240.0,22426496.0,22458752.0,22491008.0,22523264.0,22555520.0,22587776.0,22406272.0,22438592.0,22470912.0,22503232.0,22535552.0,22567872.0,22600192.0,22632512.0,22450560.0,22482944.0,22515328.0,22547712.0,22580096.0,22612480.0,22644864.0,22677248.0,22494848.0,22527296.0,22559744.0,22592192.0,22624640.0,22657088.0,22689536
.0,22721984.0,22539136.0,22571648.0,22604160.0,22636672.0,22669184.0,22701696.0,22734208.0,22766720.0,22583424.0,22616000.0,22648576.0,22681152.0,22713728.0,22746304.0,22778880.0,22811456.0,22627712.0,22660352.0,22692992.0,22725632.0,22758272.0,22790912.0,22823552.0,22856192.0,22672000.0,22704704.0,22737408.0,22770112.0,22802816.0,22835520.0,22868224.0,22900928.0,22716288.0,22749056.0,22781824.0,22814592.0,22847360.0,22880128.0,22912896.0,22945664.0,22760576.0,22793408.0,22826240.0,22859072.0,22891904.0,22924736.0,22957568.0,22990400.0,21838208.0,21870080.0,21901952.0,21933824.0,21965696.0,21997568.0,22029440.0,22061312.0,20907648.0,20938560.0,20969472.0,21000384.0,21031296.0,21062208.0,21093120.0,21124032.0,20951936.0,20982912.0,21013888.0,21044864.0,21075840.0,21106816.0,21137792.0,21168768.0,20996224.0,21027264.0,21058304.0,21089344.0,21120384.0,21151424.0,21182464.0,21213504.0,21040512.0,21071616.0,21102720.0,21133824.0,21164928.0,21196032.0,21227136.0,21258240.0,21084800.0,21115968.0,21147136.0,21178304.0,21209472.0,21240640.0,21271808.0,21302976.0,21129088.0,21160320.0,21191552.0,21222784.0,21254016.0,21285248.0,21316480.0,21347712.0,21173376.0,21204672.0,21235968.0,21267264.0,21298560.0,21329856.0,21361152.0,21392448.0,21217664.0,21249024.0,21280384.0,21311744.0,21343104.0,21374464.0,21405824.0,21437184.0,21261952.0,21293376.0,21324800.0,21356224.0,21387648.0,21419072.0,21450496.0,21481920.0,21306240.0,21337728.0,21369216.0,21400704.0,21432192.0,21463680.0,21495168.0,21526656.0,21350528.0,21382080.0,21413632.0,21445184.0,21476736.0,21508288.0,21539840.0,21571392.0,21394816.0,21426432.0,21458048.0,21489664.0,21521280.0,21552896.0,21584512.0,21616128.0,21439104.0,21470784.0,21502464.0,21534144.0,21565824.0,21597504.0,21629184.0,21660864.0,21483392.0,21515136.0,21546880.0,21578624.0,21610368.0,21642112.0,21673856.0,21705600.0,21527680.0,21559488.0,21591296.0,21623104.0,21654912.0,21686720.0,21718528.0,21750336.0,21571968.0,21603840.0,21635712.0,21667584.0,216994
56.0,21731328.0,21763200.0,21795072.0,21616256.0,21648192.0,21680128.0,21712064.0,21744000.0,21775936.0,21807872.0,21839808.0,21660544.0,21692544.0,21724544.0,21756544.0,21788544.0,21820544.0,21852544.0,21884544.0,21704832.0,21736896.0,21768960.0,21801024.0,21833088.0,21865152.0,21897216.0,21929280.0,21749120.0,21781248.0,21813376.0,21845504.0,21877632.0,21909760.0,21941888.0,21974016.0,21793408.0,21825600.0,21857792.0,21889984.0,21922176.0,21954368.0,21986560.0,22018752.0,21837696.0,21869952.0,21902208.0,21934464.0,21966720.0,21998976.0,22031232.0,22063488.0,21881984.0,21914304.0,21946624.0,21978944.0,22011264.0,22043584.0,22075904.0,22108224.0,21926272.0,21958656.0,21991040.0,22023424.0,22055808.0,22088192.0,22120576.0,22152960.0,21970560.0,22003008.0,22035456.0,22067904.0,22100352.0,22132800.0,22165248.0,22197696.0,22014848.0,22047360.0,22079872.0,22112384.0,22144896.0,22177408.0,22209920.0,22242432.0,22059136.0,22091712.0,22124288.0,22156864.0,22189440.0,22222016.0,22254592.0,22287168.0,22103424.0,22136064.0,22168704.0,22201344.0,22233984.0,22266624.0,22299264.0,22331904.0,22147712.0,22180416.0,22213120.0,22245824.0,22278528.0,22311232.0,22343936.0,22376640.0,22192000.0,22224768.0,22257536.0,22290304.0,22323072.0,22355840.0,22388608.0,22421376.0,22236288.0,22269120.0,22301952.0,22334784.0,22367616.0,22400448.0,22433280.0,22466112.0,22280576.0,22313472.0,22346368.0,22379264.0,22412160.0,22445056.0,22477952.0,22510848.0,22324864.0,22357824.0,22390784.0,22423744.0,22456704.0,22489664.0,22522624.0,22555584.0,22369152.0,22402176.0,22435200.0,22468224.0,22501248.0,22534272.0,22567296.0,22600320.0,22413440.0,22446528.0,22479616.0,22512704.0,22545792.0,22578880.0,22611968.0,22645056.0,22457728.0,22490880.0,22524032.0,22557184.0,22590336.0,22623488.0,22656640.0,22689792.0,22502016.0,22535232.0,22568448.0,22601664.0,22634880.0,22668096.0,22701312.0,22734528.0,22546304.0,22579584.0,22612864.0,22646144.0,22679424.0,22712704.0,22745984.0,22779264.0,22148224.0,22180544.0,2221
2864.0,22245184.0,22277504.0,22309824.0,22342144.0,22374464.0,21741952.0,21773312.0,21804672.0,21836032.0,21867392.0,21898752.0,21930112.0,21961472.0,21786240.0,21817664.0,21849088.0,21880512.0,21911936.0,21943360.0,21974784.0,22006208.0,21830528.0,21862016.0,21893504.0,21924992.0,21956480.0,21987968.0,22019456.0,22050944.0,21874816.0,21906368.0,21937920.0,21969472.0,22001024.0,22032576.0,22064128.0,22095680.0,21919104.0,21950720.0,21982336.0,22013952.0,22045568.0,22077184.0,22108800.0,22140416.0,21963392.0,21995072.0,22026752.0,22058432.0,22090112.0,22121792.0,22153472.0,22185152.0,22007680.0,22039424.0,22071168.0,22102912.0,22134656.0,22166400.0,22198144.0,22229888.0,22051968.0,22083776.0,22115584.0,22147392.0,22179200.0,22211008.0,22242816.0,22274624.0,22096256.0,22128128.0,22160000.0,22191872.0,22223744.0,22255616.0,22287488.0,22319360.0,22140544.0,22172480.0,22204416.0,22236352.0,22268288.0,22300224.0,22332160.0,22364096.0,22184832.0,22216832.0,22248832.0,22280832.0,22312832.0,22344832.0,22376832.0,22408832.0,22229120.0,22261184.0,22293248.0,22325312.0,22357376.0,22389440.0,22421504.0,22453568.0,22273408.0,22305536.0,22337664.0,22369792.0,22401920.0,22434048.0,22466176.0,22498304.0,22317696.0,22349888.0,22382080.0,22414272.0,22446464.0,22478656.0,22510848.0,22543040.0,22361984.0,22394240.0,22426496.0,22458752.0,22491008.0,22523264.0,22555520.0,22587776.0,22406272.0,22438592.0,22470912.0,22503232.0,22535552.0,22567872.0,22600192.0,22632512.0,22450560.0,22482944.0,22515328.0,22547712.0,22580096.0,22612480.0,22644864.0,22677248.0,22494848.0,22527296.0,22559744.0,22592192.0,22624640.0,22657088.0,22689536.0,22721984.0,22539136.0,22571648.0,22604160.0,22636672.0,22669184.0,22701696.0,22734208.0,22766720.0,22583424.0,22616000.0,22648576.0,22681152.0,22713728.0,22746304.0,22778880.0,22811456.0,22627712.0,22660352.0,22692992.0,22725632.0,22758272.0,22790912.0,22823552.0,22856192.0,22672000.0,22704704.0,22737408.0,22770112.0,22802816.0,22835520.0,22868224.0,22900928.0,22
716288.0,22749056.0,22781824.0,22814592.0,22847360.0,22880128.0,22912896.0,22945664.0,22760576.0,22793408.0,22826240.0,22859072.0,22891904.0,22924736.0,22957568.0,22990400.0,22804864.0,22837760.0,22870656.0,22903552.0,22936448.0,22969344.0,23002240.0,23035136.0,22849152.0,22882112.0,22915072.0,22948032.0,22980992.0,23013952.0,23046912.0,23079872.0,22893440.0,22926464.0,22959488.0,22992512.0,23025536.0,23058560.0,23091584.0,23124608.0,22937728.0,22970816.0,23003904.0,23036992.0,23070080.0,23103168.0,23136256.0,23169344.0,22982016.0,23015168.0,23048320.0,23081472.0,23114624.0,23147776.0,23180928.0,23214080.0,23026304.0,23059520.0,23092736.0,23125952.0,23159168.0,23192384.0,23225600.0,23258816.0,23070592.0,23103872.0,23137152.0,23170432.0,23203712.0,23236992.0,23270272.0,23303552.0,23114880.0,23148224.0,23181568.0,23214912.0,23248256.0,23281600.0,23314944.0,23348288.0,23159168.0,23192576.0,23225984.0,23259392.0,23292800.0,23326208.0,23359616.0,23393024.0,23203456.0,23236928.0,23270400.0,23303872.0,23337344.0,23370816.0,23404288.0,23437760.0,23247744.0,23281280.0,23314816.0,23348352.0,23381888.0,23415424.0,23448960.0,23482496.0,23292032.0,23325632.0,23359232.0,23392832.0,23426432.0,23460032.0,23493632.0,23527232.0,23336320.0,23369984.0,23403648.0,23437312.0,23470976.0,23504640.0,23538304.0,23571968.0,23380608.0,23414336.0,23448064.0,23481792.0,23515520.0,23549248.0,23582976.0,23616704.0,22458240.0,22491008.0,22523776.0,22556544.0,22589312.0,22622080.0,22654848.0,22687616.0,21527680.0,21559488.0,21591296.0,21623104.0,21654912.0,21686720.0,21718528.0,21750336.0,10838656.0,10854576.0,10870496.0,10886416.0,10902336.0,10918256.0,10934176.0,10950096.0,11400192.0,11415984.0,11431776.0,11447568.0,11463360.0,11479152.0,11494944.0,11510736.0,22949248.0,22980864.0,23012480.0,23044096.0,23075712.0,23107328.0,23138944.0,23170560.0,22994560.0,23026240.0,23057920.0,23089600.0,23121280.0,23152960.0,23184640.0,23216320.0,23039872.0,23071616.0,23103360.0,23135104.0,23166848.0,23198592.0,
23230336.0,23262080.0,23085184.0,23116992.0,23148800.0,23180608.0,23212416.0,23244224.0,23276032.0,23307840.0,23130496.0,23162368.0,23194240.0,23226112.0,23257984.0,23289856.0,23321728.0,23353600.0,23175808.0,23207744.0,23239680.0,23271616.0,23303552.0,23335488.0,23367424.0,23399360.0,23221120.0,23253120.0,23285120.0,23317120.0,23349120.0,23381120.0,23413120.0,23445120.0,23266432.0,23298496.0,23330560.0,23362624.0,23394688.0,23426752.0,23458816.0,23490880.0,23311744.0,23343872.0,23376000.0,23408128.0,23440256.0,23472384.0,23504512.0,23536640.0,23357056.0,23389248.0,23421440.0,23453632.0,23485824.0,23518016.0,23550208.0,23582400.0,23402368.0,23434624.0,23466880.0,23499136.0,23531392.0,23563648.0,23595904.0,23628160.0,23447680.0,23480000.0,23512320.0,23544640.0,23576960.0,23609280.0,23641600.0,23673920.0,23492992.0,23525376.0,23557760.0,23590144.0,23622528.0,23654912.0,23687296.0,23719680.0,23538304.0,23570752.0,23603200.0,23635648.0,23668096.0,23700544.0,23732992.0,23765440.0,23583616.0,23616128.0,23648640.0,23681152.0,23713664.0,23746176.0,23778688.0,23811200.0,23628928.0,23661504.0,23694080.0,23726656.0,23759232.0,23791808.0,23824384.0,23856960.0,23674240.0,23706880.0,23739520.0,23772160.0,23804800.0,23837440.0,23870080.0,23902720.0,23719552.0,23752256.0,23784960.0,23817664.0,23850368.0,23883072.0,23915776.0,23948480.0,23764864.0,23797632.0,23830400.0,23863168.0,23895936.0,23928704.0,23961472.0,23994240.0,23810176.0,23843008.0,23875840.0,23908672.0,23941504.0,23974336.0,24007168.0,24040000.0,23855488.0,23888384.0,23921280.0,23954176.0,23987072.0,24019968.0,24052864.0,24085760.0,23900800.0,23933760.0,23966720.0,23999680.0,24032640.0,24065600.0,24098560.0,24131520.0,23946112.0,23979136.0,24012160.0,24045184.0,24078208.0,24111232.0,24144256.0,24177280.0,23991424.0,24024512.0,24057600.0,24090688.0,24123776.0,24156864.0,24189952.0,24223040.0,24036736.0,24069888.0,24103040.0,24136192.0,24169344.0,24202496.0,24235648.0,24268800.0,24082048.0,24115264.0,24148480.0,24181696.
0,24214912.0,24248128.0,24281344.0,24314560.0,24127360.0,24160640.0,24193920.0,24227200.0,24260480.0,24293760.0,24327040.0,24360320.0,24172672.0,24206016.0,24239360.0,24272704.0,24306048.0,24339392.0,24372736.0,24406080.0,24217984.0,24251392.0,24284800.0,24318208.0,24351616.0,24385024.0,24418432.0,24451840.0,24263296.0,24296768.0,24330240.0,24363712.0,24397184.0,24430656.0,24464128.0,24497600.0,24308608.0,24342144.0,24375680.0,24409216.0,24442752.0,24476288.0,24509824.0,24543360.0,24353920.0,24387520.0,24421120.0,24454720.0,24488320.0,24521920.0,24555520.0,24589120.0,24399232.0,24432896.0,24466560.0,24500224.0,24533888.0,24567552.0,24601216.0,24634880.0,24444544.0,24478272.0,24512000.0,24545728.0,24579456.0,24613184.0,24646912.0,24680640.0,24489856.0,24523648.0,24557440.0,24591232.0,24625024.0,24658816.0,24692608.0,24726400.0,24535168.0,24569024.0,24602880.0,24636736.0,24670592.0,24704448.0,24738304.0,24772160.0,23597440.0,23630336.0,23663232.0,23696128.0,23729024.0,23761920.0,23794816.0,23827712.0,22651520.0,22683456.0,22715392.0,22747328.0,22779264.0,22811200.0,22843136.0,22875072.0,21713792.0,21744768.0,21775744.0,21806720.0,21837696.0,21868672.0,21899648.0,21930624.0,20767872.0,20797888.0,20827904.0,20857920.0,20887936.0,20917952.0,20947968.0,20977984.0,20813184.0,20843264.0,20873344.0,20903424.0,20933504.0,20963584.0,20993664.0,21023744.0,20858496.0,20888640.0,20918784.0,20948928.0,20979072.0,21009216.0,21039360.0,21069504.0,20903808.0,20934016.0,20964224.0,20994432.0,21024640.0,21054848.0,21085056.0,21115264.0,20949120.0,20979392.0,21009664.0,21039936.0,21070208.0,21100480.0,21130752.0,21161024.0,20994432.0,21024768.0,21055104.0,21085440.0,21115776.0,21146112.0,21176448.0,21206784.0,21039744.0,21070144.0,21100544.0,21130944.0,21161344.0,21191744.0,21222144.0,21252544.0,21085056.0,21115520.0,21145984.0,21176448.0,21206912.0,21237376.0,21267840.0,21298304.0,21130368.0,21160896.0,21191424.0,21221952.0,21252480.0,21283008.0,21313536.0,21344064.0,21175680.0,2120627
2.0,21236864.0,21267456.0,21298048.0,21328640.0,21359232.0,21389824.0,21220992.0,21251648.0,21282304.0,21312960.0,21343616.0,21374272.0,21404928.0,21435584.0,21266304.0,21297024.0,21327744.0,21358464.0,21389184.0,21419904.0,21450624.0,21481344.0,21311616.0,21342400.0,21373184.0,21403968.0,21434752.0,21465536.0,21496320.0,21527104.0,21356928.0,21387776.0,21418624.0,21449472.0,21480320.0,21511168.0,21542016.0,21572864.0,21402240.0,21433152.0,21464064.0,21494976.0,21525888.0,21556800.0,21587712.0,21618624.0,21447552.0,21478528.0,21509504.0,21540480.0,21571456.0,21602432.0,21633408.0,21664384.0,21492864.0,21523904.0,21554944.0,21585984.0,21617024.0,21648064.0,21679104.0,21710144.0,21538176.0,21569280.0,21600384.0,21631488.0,21662592.0,21693696.0,21724800.0,21755904.0,21583488.0,21614656.0,21645824.0,21676992.0,21708160.0,21739328.0,21770496.0,21801664.0,21628800.0,21660032.0,21691264.0,21722496.0,21753728.0,21784960.0,21816192.0,21847424.0,21674112.0,21705408.0,21736704.0,21768000.0,21799296.0,21830592.0,21861888.0,21893184.0,21719424.0,21750784.0,21782144.0,21813504.0,21844864.0,21876224.0,21907584.0,21938944.0,21764736.0,21796160.0,21827584.0,21859008.0,21890432.0,21921856.0,21953280.0,21984704.0,21810048.0,21841536.0,21873024.0,21904512.0,21936000.0,21967488.0,21998976.0,22030464.0,21855360.0,21886912.0,21918464.0,21950016.0,21981568.0,22013120.0,22044672.0,22076224.0,21900672.0,21932288.0,21963904.0,21995520.0,22027136.0,22058752.0,22090368.0,22121984.0,21945984.0,21977664.0,22009344.0,22041024.0,22072704.0,22104384.0,22136064.0,22167744.0,21991296.0,22023040.0,22054784.0,22086528.0,22118272.0,22150016.0,22181760.0,22213504.0,22036608.0,22068416.0,22100224.0,22132032.0,22163840.0,22195648.0,22227456.0,22259264.0,22081920.0,22113792.0,22145664.0,22177536.0,22209408.0,22241280.0,22273152.0,22305024.0,22127232.0,22159168.0,22191104.0,22223040.0,22254976.0,22286912.0,22318848.0,22350784.0,22172544.0,22204544.0,22236544.0,22268544.0,22300544.0,22332544.0,22364544.0,22396
544.0,22217856.0,22249920.0,22281984.0,22314048.0,22346112.0,22378176.0,22410240.0,22442304.0,22263168.0,22295296.0,22327424.0,22359552.0,22391680.0,22423808.0,22455936.0,22488064.0,22308480.0,22340672.0,22372864.0,22405056.0,22437248.0,22469440.0,22501632.0,22533824.0,22353792.0,22386048.0,22418304.0,22450560.0,22482816.0,22515072.0,22547328.0,22579584.0,22399104.0,22431424.0,22463744.0,22496064.0,22528384.0,22560704.0,22593024.0,22625344.0,22444416.0,22476800.0,22509184.0,22541568.0,22573952.0,22606336.0,22638720.0,22671104.0,22030976.0,22062400.0,22093824.0,22125248.0,22156672.0,22188096.0,22219520.0,22250944.0,21609344.0,21639808.0,21670272.0,21700736.0,21731200.0,21761664.0,21792128.0,21822592.0,21654656.0,21685184.0,21715712.0,21746240.0,21776768.0,21807296.0,21837824.0,21868352.0,21699968.0,21730560.0,21761152.0,21791744.0,21822336.0,21852928.0,21883520.0,21914112.0,21745280.0,21775936.0,21806592.0,21837248.0,21867904.0,21898560.0,21929216.0,21959872.0,21790592.0,21821312.0,21852032.0,21882752.0,21913472.0,21944192.0,21974912.0,22005632.0,21835904.0,21866688.0,21897472.0,21928256.0,21959040.0,21989824.0,22020608.0,22051392.0,21881216.0,21912064.0,21942912.0,21973760.0,22004608.0,22035456.0,22066304.0,22097152.0,21926528.0,21957440.0,21988352.0,22019264.0,22050176.0,22081088.0,22112000.0,22142912.0,21971840.0,22002816.0,22033792.0,22064768.0,22095744.0,22126720.0,22157696.0,22188672.0,22017152.0,22048192.0,22079232.0,22110272.0,22141312.0,22172352.0,22203392.0,22234432.0,22062464.0,22093568.0,22124672.0,22155776.0,22186880.0,22217984.0,22249088.0,22280192.0,22107776.0,22138944.0,22170112.0,22201280.0,22232448.0,22263616.0,22294784.0,22325952.0,22153088.0,22184320.0,22215552.0,22246784.0,22278016.0,22309248.0,22340480.0,22371712.0,22198400.0,22229696.0,22260992.0,22292288.0,22323584.0,22354880.0,22386176.0,22417472.0,22243712.0,22275072.0,22306432.0,22337792.0,22369152.0,22400512.0,22431872.0,22463232.0,22289024.0,22320448.0,22351872.0,22383296.0,22414720.0,224
46144.0,22477568.0,22508992.0,22334336.0,22365824.0,22397312.0,22428800.0,22460288.0,22491776.0,22523264.0,22554752.0,22379648.0,22411200.0,22442752.0,22474304.0,22505856.0,22537408.0,22568960.0,22600512.0,22424960.0,22456576.0,22488192.0,22519808.0,22551424.0,22583040.0,22614656.0,22646272.0,22470272.0,22501952.0,22533632.0,22565312.0,22596992.0,22628672.0,22660352.0,22692032.0,22515584.0,22547328.0,22579072.0,22610816.0,22642560.0,22674304.0,22706048.0,22737792.0,22560896.0,22592704.0,22624512.0,22656320.0,22688128.0,22719936.0,22751744.0,22783552.0,22606208.0,22638080.0,22669952.0,22701824.0,22733696.0,22765568.0,22797440.0,22829312.0,22651520.0,22683456.0,22715392.0,22747328.0,22779264.0,22811200.0,22843136.0,22875072.0,22696832.0,22728832.0,22760832.0,22792832.0,22824832.0,22856832.0,22888832.0,22920832.0,22742144.0,22774208.0,22806272.0,22838336.0,22870400.0,22902464.0,22934528.0,22966592.0,22787456.0,22819584.0,22851712.0,22883840.0,22915968.0,22948096.0,22980224.0,23012352.0,22832768.0,22864960.0,22897152.0,22929344.0,22961536.0,22993728.0,23025920.0,23058112.0,22878080.0,22910336.0,22942592.0,22974848.0,23007104.0,23039360.0,23071616.0,23103872.0,22923392.0,22955712.0,22988032.0,23020352.0,23052672.0,23084992.0,23117312.0,23149632.0,22968704.0,23001088.0,23033472.0,23065856.0,23098240.0,23130624.0,23163008.0,23195392.0,23014016.0,23046464.0,23078912.0,23111360.0,23143808.0,23176256.0,23208704.0,23241152.0,23059328.0,23091840.0,23124352.0,23156864.0,23189376.0,23221888.0,23254400.0,23286912.0,23104640.0,23137216.0,23169792.0,23202368.0,23234944.0,23267520.0,23300096.0,23332672.0,23149952.0,23182592.0,23215232.0,23247872.0,23280512.0,23313152.0,23345792.0,23378432.0,23195264.0,23227968.0,23260672.0,23293376.0,23326080.0,23358784.0,23391488.0,23424192.0,23240576.0,23273344.0,23306112.0,23338880.0,23371648.0,23404416.0,23437184.0,23469952.0,23285888.0,23318720.0,23351552.0,23384384.0,23417216.0,23450048.0,23482880.0,23515712.0,22348160.0,22380032.0,22411904.0,2
2443776.0,22475648.0,22507520.0,22539392.0,22571264.0,21402240.0,21433152.0,21464064.0,21494976.0,21525888.0,21556800.0,21587712.0,21618624.0,21447552.0,21478528.0,21509504.0,21540480.0,21571456.0,21602432.0,21633408.0,21664384.0,21492864.0,21523904.0,21554944.0,21585984.0,21617024.0,21648064.0,21679104.0,21710144.0,21538176.0,21569280.0,21600384.0,21631488.0,21662592.0,21693696.0,21724800.0,21755904.0,21583488.0,21614656.0,21645824.0,21676992.0,21708160.0,21739328.0,21770496.0,21801664.0,21628800.0,21660032.0,21691264.0,21722496.0,21753728.0,21784960.0,21816192.0,21847424.0,21674112.0,21705408.0,21736704.0,21768000.0,21799296.0,21830592.0,21861888.0,21893184.0,21719424.0,21750784.0,21782144.0,21813504.0,21844864.0,21876224.0,21907584.0,21938944.0,21764736.0,21796160.0,21827584.0,21859008.0,21890432.0,21921856.0,21953280.0,21984704.0,21810048.0,21841536.0,21873024.0,21904512.0,21936000.0,21967488.0,21998976.0,22030464.0,21855360.0,21886912.0,21918464.0,21950016.0,21981568.0,22013120.0,22044672.0,22076224.0,21900672.0,21932288.0,21963904.0,21995520.0,22027136.0,22058752.0,22090368.0,22121984.0,21945984.0,21977664.0,22009344.0,22041024.0,22072704.0,22104384.0,22136064.0,22167744.0,21991296.0,22023040.0,22054784.0,22086528.0,22118272.0,22150016.0,22181760.0,22213504.0,22036608.0,22068416.0,22100224.0,22132032.0,22163840.0,22195648.0,22227456.0,22259264.0,22081920.0,22113792.0,22145664.0,22177536.0,22209408.0,22241280.0,22273152.0,22305024.0,22127232.0,22159168.0,22191104.0,22223040.0,22254976.0,22286912.0,22318848.0,22350784.0,22172544.0,22204544.0,22236544.0,22268544.0,22300544.0,22332544.0,22364544.0,22396544.0,22217856.0,22249920.0,22281984.0,22314048.0,22346112.0,22378176.0,22410240.0,22442304.0,22263168.0,22295296.0,22327424.0,22359552.0,22391680.0,22423808.0,22455936.0,22488064.0,22308480.0,22340672.0,22372864.0,22405056.0,22437248.0,22469440.0,22501632.0,22533824.0,22353792.0,22386048.0,22418304.0,22450560.0,22482816.0,22515072.0,22547328.0,22579584.0,22399104.0
,22431424.0,22463744.0,22496064.0,22528384.0,22560704.0,22593024.0,22625344.0,22444416.0,22476800.0,22509184.0,22541568.0,22573952.0,22606336.0,22638720.0,22671104.0,22489728.0,22522176.0,22554624.0,22587072.0,22619520.0,22651968.0,22684416.0,22716864.0,22535040.0,22567552.0,22600064.0,22632576.0,22665088.0,22697600.0,22730112.0,22762624.0,22580352.0,22612928.0,22645504.0,22678080.0,22710656.0,22743232.0,22775808.0,22808384.0,22625664.0,22658304.0,22690944.0,22723584.0,22756224.0,22788864.0,22821504.0,22854144.0,22670976.0,22703680.0,22736384.0,22769088.0,22801792.0,22834496.0,22867200.0,22899904.0,22716288.0,22749056.0,22781824.0,22814592.0,22847360.0,22880128.0,22912896.0,22945664.0,22761600.0,22794432.0,22827264.0,22860096.0,22892928.0,22925760.0,22958592.0,22991424.0,22806912.0,22839808.0,22872704.0,22905600.0,22938496.0,22971392.0,23004288.0,23037184.0,22852224.0,22885184.0,22918144.0,22951104.0,22984064.0,23017024.0,23049984.0,23082944.0,22897536.0,22930560.0,22963584.0,22996608.0,23029632.0,23062656.0,23095680.0,23128704.0,22942848.0,22975936.0,23009024.0,23042112.0,23075200.0,23108288.0,23141376.0,23174464.0,22988160.0,23021312.0,23054464.0,23087616.0,23120768.0,23153920.0,23187072.0,23220224.0,23033472.0,23066688.0,23099904.0,23133120.0,23166336.0,23199552.0,23232768.0,23265984.0,23078784.0,23112064.0,23145344.0,23178624.0,23211904.0,23245184.0,23278464.0,23311744.0,22665344.0,22697664.0,22729984.0,22762304.0,22794624.0,22826944.0,22859264.0,22891584.0,22243712.0,22275072.0,22306432.0,22337792.0,22369152.0,22400512.0,22431872.0,22463232.0,22289024.0,22320448.0,22351872.0,22383296.0,22414720.0,22446144.0,22477568.0,22508992.0,22334336.0,22365824.0,22397312.0,22428800.0,22460288.0,22491776.0,22523264.0,22554752.0,22379648.0,22411200.0,22442752.0,22474304.0,22505856.0,22537408.0,22568960.0,22600512.0,22424960.0,22456576.0,22488192.0,22519808.0,22551424.0,22583040.0,22614656.0,22646272.0,22470272.0,22501952.0,22533632.0,22565312.0,22596992.0,22628672.0,22660352
.0,22692032.0,22515584.0,22547328.0,22579072.0,22610816.0,22642560.0,22674304.0,22706048.0,22737792.0,22560896.0,22592704.0,22624512.0,22656320.0,22688128.0,22719936.0,22751744.0,22783552.0,22606208.0,22638080.0,22669952.0,22701824.0,22733696.0,22765568.0,22797440.0,22829312.0,22651520.0,22683456.0,22715392.0,22747328.0,22779264.0,22811200.0,22843136.0,22875072.0,22696832.0,22728832.0,22760832.0,22792832.0,22824832.0,22856832.0,22888832.0,22920832.0,22742144.0,22774208.0,22806272.0,22838336.0,22870400.0,22902464.0,22934528.0,22966592.0,22787456.0,22819584.0,22851712.0,22883840.0,22915968.0,22948096.0,22980224.0,23012352.0,22832768.0,22864960.0,22897152.0,22929344.0,22961536.0,22993728.0,23025920.0,23058112.0,22878080.0,22910336.0,22942592.0,22974848.0,23007104.0,23039360.0,23071616.0,23103872.0,22923392.0,22955712.0,22988032.0,23020352.0,23052672.0,23084992.0,23117312.0,23149632.0,22968704.0,23001088.0,23033472.0,23065856.0,23098240.0,23130624.0,23163008.0,23195392.0,23014016.0,23046464.0,23078912.0,23111360.0,23143808.0,23176256.0,23208704.0,23241152.0,23059328.0,23091840.0,23124352.0,23156864.0,23189376.0,23221888.0,23254400.0,23286912.0,23104640.0,23137216.0,23169792.0,23202368.0,23234944.0,23267520.0,23300096.0,23332672.0,23149952.0,23182592.0,23215232.0,23247872.0,23280512.0,23313152.0,23345792.0,23378432.0,23195264.0,23227968.0,23260672.0,23293376.0,23326080.0,23358784.0,23391488.0,23424192.0,23240576.0,23273344.0,23306112.0,23338880.0,23371648.0,23404416.0,23437184.0,23469952.0,23285888.0,23318720.0,23351552.0,23384384.0,23417216.0,23450048.0,23482880.0,23515712.0,23331200.0,23364096.0,23396992.0,23429888.0,23462784.0,23495680.0,23528576.0,23561472.0,23376512.0,23409472.0,23442432.0,23475392.0,23508352.0,23541312.0,23574272.0,23607232.0,23421824.0,23454848.0,23487872.0,23520896.0,23553920.0,23586944.0,23619968.0,23652992.0,23467136.0,23500224.0,23533312.0,23566400.0,23599488.0,23632576.0,23665664.0,23698752.0,23512448.0,23545600.0,23578752.0,23611904.0,236450
56.0,23678208.0,23711360.0,23744512.0,23557760.0,23590976.0,23624192.0,23657408.0,23690624.0,23723840.0,23757056.0,23790272.0,23603072.0,23636352.0,23669632.0,23702912.0,23736192.0,23769472.0,23802752.0,23836032.0,23648384.0,23681728.0,23715072.0,23748416.0,23781760.0,23815104.0,23848448.0,23881792.0,23693696.0,23727104.0,23760512.0,23793920.0,23827328.0,23860736.0,23894144.0,23927552.0,23739008.0,23772480.0,23805952.0,23839424.0,23872896.0,23906368.0,23939840.0,23973312.0,23784320.0,23817856.0,23851392.0,23884928.0,23918464.0,23952000.0,23985536.0,24019072.0,23829632.0,23863232.0,23896832.0,23930432.0,23964032.0,23997632.0,24031232.0,24064832.0,23874944.0,23908608.0,23942272.0,23975936.0,24009600.0,24043264.0,24076928.0,24110592.0,23920256.0,23953984.0,23987712.0,24021440.0,24055168.0,24088896.0,24122624.0,24156352.0,22982528.0,23015296.0,23048064.0,23080832.0,23113600.0,23146368.0,23179136.0,23211904.0,22036608.0,22068416.0,22100224.0,22132032.0,22163840.0,22195648.0,22227456.0,22259264.0,11093376.0,11109296.0,11125216.0,11141136.0,11157056.0,11172976.0,11188896.0,11204816.0,11652864.0,11668656.0,11684448.0,11700240.0,11716032.0,11731824.0,11747616.0,11763408.0,23455104.0,23486720.0,23518336.0,23549952.0,23581568.0,23613184.0,23644800.0,23676416.0,23501440.0,23533120.0,23564800.0,23596480.0,23628160.0,23659840.0,23691520.0,23723200.0,23547776.0,23579520.0,23611264.0,23643008.0,23674752.0,23706496.0,23738240.0,23769984.0,23594112.0,23625920.0,23657728.0,23689536.0,23721344.0,23753152.0,23784960.0,23816768.0,23640448.0,23672320.0,23704192.0,23736064.0,23767936.0,23799808.0,23831680.0,23863552.0,23686784.0,23718720.0,23750656.0,23782592.0,23814528.0,23846464.0,23878400.0,23910336.0,23733120.0,23765120.0,23797120.0,23829120.0,23861120.0,23893120.0,23925120.0,23957120.0,23779456.0,23811520.0,23843584.0,23875648.0,23907712.0,23939776.0,23971840.0,24003904.0,23825792.0,23857920.0,23890048.0,23922176.0,23954304.0,23986432.0,24018560.0,24050688.0,23872128.0,23904320.0,2393
6512.0,23968704.0,24000896.0,24033088.0,24065280.0,24097472.0,23918464.0,23950720.0,23982976.0,24015232.0,24047488.0,24079744.0,24112000.0,24144256.0,23964800.0,23997120.0,24029440.0,24061760.0,24094080.0,24126400.0,24158720.0,24191040.0,24011136.0,24043520.0,24075904.0,24108288.0,24140672.0,24173056.0,24205440.0,24237824.0,24057472.0,24089920.0,24122368.0,24154816.0,24187264.0,24219712.0,24252160.0,24284608.0,24103808.0,24136320.0,24168832.0,24201344.0,24233856.0,24266368.0,24298880.0,24331392.0,24150144.0,24182720.0,24215296.0,24247872.0,24280448.0,24313024.0,24345600.0,24378176.0,24196480.0,24229120.0,24261760.0,24294400.0,24327040.0,24359680.0,24392320.0,24424960.0,24242816.0,24275520.0,24308224.0,24340928.0,24373632.0,24406336.0,24439040.0,24471744.0,24289152.0,24321920.0,24354688.0,24387456.0,24420224.0,24452992.0,24485760.0,24518528.0,24335488.0,24368320.0,24401152.0,24433984.0,24466816.0,24499648.0,24532480.0,24565312.0,24381824.0,24414720.0,24447616.0,24480512.0,24513408.0,24546304.0,24579200.0,24612096.0,24428160.0,24461120.0,24494080.0,24527040.0,24560000.0,24592960.0,24625920.0,24658880.0,24474496.0,24507520.0,24540544.0,24573568.0,24606592.0,24639616.0,24672640.0,24705664.0,24520832.0,24553920.0,24587008.0,24620096.0,24653184.0,24686272.0,24719360.0,24752448.0,24567168.0,24600320.0,24633472.0,24666624.0,24699776.0,24732928.0,24766080.0,24799232.0,24613504.0,24646720.0,24679936.0,24713152.0,24746368.0,24779584.0,24812800.0,24846016.0,24659840.0,24693120.0,24726400.0,24759680.0,24792960.0,24826240.0,24859520.0,24892800.0,24706176.0,24739520.0,24772864.0,24806208.0,24839552.0,24872896.0,24906240.0,24939584.0,24752512.0,24785920.0,24819328.0,24852736.0,24886144.0,24919552.0,24952960.0,24986368.0,24798848.0,24832320.0,24865792.0,24899264.0,24932736.0,24966208.0,24999680.0,25033152.0,24845184.0,24878720.0,24912256.0,24945792.0,24979328.0,25012864.0,25046400.0,25079936.0,24891520.0,24925120.0,24958720.0,24992320.0,25025920.0,25059520.0,25093120.0,25126720.0,24
937856.0,24971520.0,25005184.0,25038848.0,25072512.0,25106176.0,25139840.0,25173504.0,24984192.0,25017920.0,25051648.0,25085376.0,25119104.0,25152832.0,25186560.0,25220288.0,25030528.0,25064320.0,25098112.0,25131904.0,25165696.0,25199488.0,25233280.0,25267072.0,25076864.0,25110720.0,25144576.0,25178432.0,25212288.0,25246144.0,25280000.0,25313856.0,24123776.0,24156672.0,24189568.0,24222464.0,24255360.0,24288256.0,24321152.0,24354048.0,23162496.0,23194432.0,23226368.0,23258304.0,23290240.0,23322176.0,23354112.0,23386048.0,22209408.0,22240384.0,22271360.0,22302336.0,22333312.0,22364288.0,22395264.0,22426240.0,21248128.0,21278144.0,21308160.0,21338176.0,21368192.0,21398208.0,21428224.0,21458240.0,21294464.0,21324544.0,21354624.0,21384704.0,21414784.0,21444864.0,21474944.0,21505024.0,21340800.0,21370944.0,21401088.0,21431232.0,21461376.0,21491520.0,21521664.0,21551808.0,21387136.0,21417344.0,21447552.0,21477760.0,21507968.0,21538176.0,21568384.0,21598592.0,21433472.0,21463744.0,21494016.0,21524288.0,21554560.0,21584832.0,21615104.0,21645376.0,21479808.0,21510144.0,21540480.0,21570816.0,21601152.0,21631488.0,21661824.0,21692160.0,21526144.0,21556544.0,21586944.0,21617344.0,21647744.0,21678144.0,21708544.0,21738944.0,21572480.0,21602944.0,21633408.0,21663872.0,21694336.0,21724800.0,21755264.0,21785728.0,21618816.0,21649344.0,21679872.0,21710400.0,21740928.0,21771456.0,21801984.0,21832512.0,21665152.0,21695744.0,21726336.0,21756928.0,21787520.0,21818112.0,21848704.0,21879296.0,21711488.0,21742144.0,21772800.0,21803456.0,21834112.0,21864768.0,21895424.0,21926080.0,21757824.0,21788544.0,21819264.0,21849984.0,21880704.0,21911424.0,21942144.0,21972864.0,21804160.0,21834944.0,21865728.0,21896512.0,21927296.0,21958080.0,21988864.0,22019648.0,21850496.0,21881344.0,21912192.0,21943040.0,21973888.0,22004736.0,22035584.0,22066432.0,21896832.0,21927744.0,21958656.0,21989568.0,22020480.0,22051392.0,22082304.0,22113216.0,21943168.0,21974144.0,22005120.0,22036096.0,22067072.0,22098048.0,
22129024.0,22160000.0,21989504.0,22020544.0,22051584.0,22082624.0,22113664.0,22144704.0,22175744.0,22206784.0,22035840.0,22066944.0,22098048.0,22129152.0,22160256.0,22191360.0,22222464.0,22253568.0,22082176.0,22113344.0,22144512.0,22175680.0,22206848.0,22238016.0,22269184.0,22300352.0,22128512.0,22159744.0,22190976.0,22222208.0,22253440.0,22284672.0,22315904.0,22347136.0,22174848.0,22206144.0,22237440.0,22268736.0,22300032.0,22331328.0,22362624.0,22393920.0,22221184.0,22252544.0,22283904.0,22315264.0,22346624.0,22377984.0,22409344.0,22440704.0,22267520.0,22298944.0,22330368.0,22361792.0,22393216.0,22424640.0,22456064.0,22487488.0,22313856.0,22345344.0,22376832.0,22408320.0,22439808.0,22471296.0,22502784.0,22534272.0,22360192.0,22391744.0,22423296.0,22454848.0,22486400.0,22517952.0,22549504.0,22581056.0,22406528.0,22438144.0,22469760.0,22501376.0,22532992.0,22564608.0,22596224.0,22627840.0,22452864.0,22484544.0,22516224.0,22547904.0,22579584.0,22611264.0,22642944.0,22674624.0,22499200.0,22530944.0,22562688.0,22594432.0,22626176.0,22657920.0,22689664.0,22721408.0,22545536.0,22577344.0,22609152.0,22640960.0,22672768.0,22704576.0,22736384.0,22768192.0,22591872.0,22623744.0,22655616.0,22687488.0,22719360.0,22751232.0,22783104.0,22814976.0,22638208.0,22670144.0,22702080.0,22734016.0,22765952.0,22797888.0,22829824.0,22861760.0,22684544.0,22716544.0,22748544.0,22780544.0,22812544.0,22844544.0,22876544.0,22908544.0,22730880.0,22762944.0,22795008.0,22827072.0,22859136.0,22891200.0,22923264.0,22955328.0,22777216.0,22809344.0,22841472.0,22873600.0,22905728.0,22937856.0,22969984.0,23002112.0,22823552.0,22855744.0,22887936.0,22920128.0,22952320.0,22984512.0,23016704.0,23048896.0,22869888.0,22902144.0,22934400.0,22966656.0,22998912.0,23031168.0,23063424.0,23095680.0,22916224.0,22948544.0,22980864.0,23013184.0,23045504.0,23077824.0,23110144.0,23142464.0,22962560.0,22994944.0,23027328.0,23059712.0,23092096.0,23124480.0,23156864.0,23189248.0,22533760.0,22565184.0,22596608.0,22628032.
0,22659456.0,22690880.0,22722304.0,22753728.0,22096768.0,22127232.0,22157696.0,22188160.0,22218624.0,22249088.0,22279552.0,22310016.0,22143104.0,22173632.0,22204160.0,22234688.0,22265216.0,22295744.0,22326272.0,22356800.0,22189440.0,22220032.0,22250624.0,22281216.0,22311808.0,22342400.0,22372992.0,22403584.0,22235776.0,22266432.0,22297088.0,22327744.0,22358400.0,22389056.0,22419712.0,22450368.0,22282112.0,22312832.0,22343552.0,22374272.0,22404992.0,22435712.0,22466432.0,22497152.0,22328448.0,22359232.0,22390016.0,22420800.0,22451584.0,22482368.0,22513152.0,22543936.0,22374784.0,22405632.0,22436480.0,22467328.0,22498176.0,22529024.0,22559872.0,22590720.0,22421120.0,22452032.0,22482944.0,22513856.0,22544768.0,22575680.0,22606592.0,22637504.0,22467456.0,22498432.0,22529408.0,22560384.0,22591360.0,22622336.0,22653312.0,22684288.0,22513792.0,22544832.0,22575872.0,22606912.0,22637952.0,22668992.0,22700032.0,22731072.0,22560128.0,22591232.0,22622336.0,22653440.0,22684544.0,22715648.0,22746752.0,22777856.0,22606464.0,22637632.0,22668800.0,22699968.0,22731136.0,22762304.0,22793472.0,22824640.0,22652800.0,22684032.0,22715264.0,22746496.0,22777728.0,22808960.0,22840192.0,22871424.0,22699136.0,22730432.0,22761728.0,22793024.0,22824320.0,22855616.0,22886912.0,22918208.0,22745472.0,22776832.0,22808192.0,22839552.0,22870912.0,22902272.0,22933632.0,22964992.0,22791808.0,22823232.0,22854656.0,22886080.0,22917504.0,22948928.0,22980352.0,23011776.0,22838144.0,22869632.0,22901120.0,22932608.0,22964096.0,22995584.0,23027072.0,23058560.0,22884480.0,22916032.0,22947584.0,22979136.0,23010688.0,23042240.0,23073792.0,23105344.0,22930816.0,22962432.0,22994048.0,23025664.0,23057280.0,23088896.0,23120512.0,23152128.0,22977152.0,23008832.0,23040512.0,23072192.0,23103872.0,23135552.0,23167232.0,23198912.0,23023488.0,23055232.0,23086976.0,23118720.0,23150464.0,23182208.0,23213952.0,23245696.0,23069824.0,23101632.0,23133440.0,23165248.0,23197056.0,23228864.0,23260672.0,23292480.0,23116160.0,2314803
2.0,23179904.0,23211776.0,23243648.0,23275520.0,23307392.0,23339264.0,23162496.0,23194432.0,23226368.0,23258304.0,23290240.0,23322176.0,23354112.0,23386048.0,23208832.0,23240832.0,23272832.0,23304832.0,23336832.0,23368832.0,23400832.0,23432832.0,23255168.0,23287232.0,23319296.0,23351360.0,23383424.0,23415488.0,23447552.0,23479616.0,23301504.0,23333632.0,23365760.0,23397888.0,23430016.0,23462144.0,23494272.0,23526400.0,23347840.0,23380032.0,23412224.0,23444416.0,23476608.0,23508800.0,23540992.0,23573184.0,23394176.0,23426432.0,23458688.0,23490944.0,23523200.0,23555456.0,23587712.0,23619968.0,23440512.0,23472832.0,23505152.0,23537472.0,23569792.0,23602112.0,23634432.0,23666752.0,23486848.0,23519232.0,23551616.0,23584000.0,23616384.0,23648768.0,23681152.0,23713536.0,23533184.0,23565632.0,23598080.0,23630528.0,23662976.0,23695424.0,23727872.0,23760320.0,23579520.0,23612032.0,23644544.0,23677056.0,23709568.0,23742080.0,23774592.0,23807104.0,23625856.0,23658432.0,23691008.0,23723584.0,23756160.0,23788736.0,23821312.0,23853888.0,23672192.0,23704832.0,23737472.0,23770112.0,23802752.0,23835392.0,23868032.0,23900672.0,23718528.0,23751232.0,23783936.0,23816640.0,23849344.0,23882048.0,23914752.0,23947456.0,23764864.0,23797632.0,23830400.0,23863168.0,23895936.0,23928704.0,23961472.0,23994240.0,23811200.0,23844032.0,23876864.0,23909696.0,23942528.0,23975360.0,24008192.0,24041024.0,22858112.0,22889984.0,22921856.0,22953728.0,22985600.0,23017472.0,23049344.0,23081216.0,21896832.0,21927744.0,21958656.0,21989568.0,22020480.0,22051392.0,22082304.0,22113216.0,21943168.0,21974144.0,22005120.0,22036096.0,22067072.0,22098048.0,22129024.0,22160000.0,21989504.0,22020544.0,22051584.0,22082624.0,22113664.0,22144704.0,22175744.0,22206784.0,22035840.0,22066944.0,22098048.0,22129152.0,22160256.0,22191360.0,22222464.0,22253568.0,22082176.0,22113344.0,22144512.0,22175680.0,22206848.0,22238016.0,22269184.0,22300352.0,22128512.0,22159744.0,22190976.0,22222208.0,22253440.0,22284672.0,22315904.0,22347
136.0,22174848.0,22206144.0,22237440.0,22268736.0,22300032.0,22331328.0,22362624.0,22393920.0,22221184.0,22252544.0,22283904.0,22315264.0,22346624.0,22377984.0,22409344.0,22440704.0,22267520.0,22298944.0,22330368.0,22361792.0,22393216.0,22424640.0,22456064.0,22487488.0,22313856.0,22345344.0,22376832.0,22408320.0,22439808.0,22471296.0,22502784.0,22534272.0,22360192.0,22391744.0,22423296.0,22454848.0,22486400.0,22517952.0,22549504.0,22581056.0,22406528.0,22438144.0,22469760.0,22501376.0,22532992.0,22564608.0,22596224.0,22627840.0,22452864.0,22484544.0,22516224.0,22547904.0,22579584.0,22611264.0,22642944.0,22674624.0,22499200.0,22530944.0,22562688.0,22594432.0,22626176.0,22657920.0,22689664.0,22721408.0,22545536.0,22577344.0,22609152.0,22640960.0,22672768.0,22704576.0,22736384.0,22768192.0,22591872.0,22623744.0,22655616.0,22687488.0,22719360.0,22751232.0,22783104.0,22814976.0,22638208.0,22670144.0,22702080.0,22734016.0,22765952.0,22797888.0,22829824.0,22861760.0,22684544.0,22716544.0,22748544.0,22780544.0,22812544.0,22844544.0,22876544.0,22908544.0,22730880.0,22762944.0,22795008.0,22827072.0,22859136.0,22891200.0,22923264.0,22955328.0,22777216.0,22809344.0,22841472.0,22873600.0,22905728.0,22937856.0,22969984.0,23002112.0,22823552.0,22855744.0,22887936.0,22920128.0,22952320.0,22984512.0,23016704.0,23048896.0,22869888.0,22902144.0,22934400.0,22966656.0,22998912.0,23031168.0,23063424.0,23095680.0,22916224.0,22948544.0,22980864.0,23013184.0,23045504.0,23077824.0,23110144.0,23142464.0,22962560.0,22994944.0,23027328.0,23059712.0,23092096.0,23124480.0,23156864.0,23189248.0,23008896.0,23041344.0,23073792.0,23106240.0,23138688.0,23171136.0,23203584.0,23236032.0,23055232.0,23087744.0,23120256.0,23152768.0,23185280.0,23217792.0,23250304.0,23282816.0,23101568.0,23134144.0,23166720.0,23199296.0,23231872.0,23264448.0,23297024.0,23329600.0,23147904.0,23180544.0,23213184.0,23245824.0,23278464.0,23311104.0,23343744.0,23376384.0,23194240.0,23226944.0,23259648.0,23292352.0,23325056.0,233
57760.0,23390464.0,23423168.0,23240576.0,23273344.0,23306112.0,23338880.0,23371648.0,23404416.0,23437184.0,23469952.0,23286912.0,23319744.0,23352576.0,23385408.0,23418240.0,23451072.0,23483904.0,23516736.0,23333248.0,23366144.0,23399040.0,23431936.0,23464832.0,23497728.0,23530624.0,23563520.0,23379584.0,23412544.0,23445504.0,23478464.0,23511424.0,23544384.0,23577344.0,23610304.0,23425920.0,23458944.0,23491968.0,23524992.0,23558016.0,23591040.0,23624064.0,23657088.0,23472256.0,23505344.0,23538432.0,23571520.0,23604608.0,23637696.0,23670784.0,23703872.0,23518592.0,23551744.0,23584896.0,23618048.0,23651200.0,23684352.0,23717504.0,23750656.0,23564928.0,23598144.0,23631360.0,23664576.0,23697792.0,23731008.0,23764224.0,23797440.0,23611264.0,23644544.0,23677824.0,23711104.0,23744384.0,23777664.0,23810944.0,23844224.0,23182464.0,23214784.0,23247104.0,23279424.0,23311744.0,23344064.0,23376384.0,23408704.0,22745472.0,22776832.0,22808192.0,22839552.0,22870912.0,22902272.0,22933632.0,22964992.0,22791808.0,22823232.0,22854656.0,22886080.0,22917504.0,22948928.0,22980352.0,23011776.0,22838144.0,22869632.0,22901120.0,22932608.0,22964096.0,22995584.0,23027072.0,23058560.0,22884480.0,22916032.0,22947584.0,22979136.0,23010688.0,23042240.0,23073792.0,23105344.0,22930816.0,22962432.0,22994048.0,23025664.0,23057280.0,23088896.0,23120512.0,23152128.0,22977152.0,23008832.0,23040512.0,23072192.0,23103872.0,23135552.0,23167232.0,23198912.0,23023488.0,23055232.0,23086976.0,23118720.0,23150464.0,23182208.0,23213952.0,23245696.0,23069824.0,23101632.0,23133440.0,23165248.0,23197056.0,23228864.0,23260672.0,23292480.0,23116160.0,23148032.0,23179904.0,23211776.0,23243648.0,23275520.0,23307392.0,23339264.0,23162496.0,23194432.0,23226368.0,23258304.0,23290240.0,23322176.0,23354112.0,23386048.0,23208832.0,23240832.0,23272832.0,23304832.0,23336832.0,23368832.0,23400832.0,23432832.0,23255168.0,23287232.0,23319296.0,23351360.0,23383424.0,23415488.0,23447552.0,23479616.0,23301504.0,23333632.0,23365760.0,2
3397888.0,23430016.0,23462144.0,23494272.0,23526400.0,23347840.0,23380032.0,23412224.0,23444416.0,23476608.0,23508800.0,23540992.0,23573184.0,23394176.0,23426432.0,23458688.0,23490944.0,23523200.0,23555456.0,23587712.0,23619968.0,23440512.0,23472832.0,23505152.0,23537472.0,23569792.0,23602112.0,23634432.0,23666752.0,23486848.0,23519232.0,23551616.0,23584000.0,23616384.0,23648768.0,23681152.0,23713536.0,23533184.0,23565632.0,23598080.0,23630528.0,23662976.0,23695424.0,23727872.0,23760320.0,23579520.0,23612032.0,23644544.0,23677056.0,23709568.0,23742080.0,23774592.0,23807104.0,23625856.0,23658432.0,23691008.0,23723584.0,23756160.0,23788736.0,23821312.0,23853888.0,23672192.0,23704832.0,23737472.0,23770112.0,23802752.0,23835392.0,23868032.0,23900672.0,23718528.0,23751232.0,23783936.0,23816640.0,23849344.0,23882048.0,23914752.0,23947456.0,23764864.0,23797632.0,23830400.0,23863168.0,23895936.0,23928704.0,23961472.0,23994240.0,23811200.0,23844032.0,23876864.0,23909696.0,23942528.0,23975360.0,24008192.0,24041024.0,23857536.0,23890432.0,23923328.0,23956224.0,23989120.0,24022016.0,24054912.0,24087808.0,23903872.0,23936832.0,23969792.0,24002752.0,24035712.0,24068672.0,24101632.0,24134592.0,23950208.0,23983232.0,24016256.0,24049280.0,24082304.0,24115328.0,24148352.0,24181376.0,23996544.0,24029632.0,24062720.0,24095808.0,24128896.0,24161984.0,24195072.0,24228160.0,24042880.0,24076032.0,24109184.0,24142336.0,24175488.0,24208640.0,24241792.0,24274944.0,24089216.0,24122432.0,24155648.0,24188864.0,24222080.0,24255296.0,24288512.0,24321728.0,24135552.0,24168832.0,24202112.0,24235392.0,24268672.0,24301952.0,24335232.0,24368512.0,24181888.0,24215232.0,24248576.0,24281920.0,24315264.0,24348608.0,24381952.0,24415296.0,24228224.0,24261632.0,24295040.0,24328448.0,24361856.0,24395264.0,24428672.0,24462080.0,24274560.0,24308032.0,24341504.0,24374976.0,24408448.0,24441920.0,24475392.0,24508864.0,24320896.0,24354432.0,24387968.0,24421504.0,24455040.0,24488576.0,24522112.0,24555648.0,24367232.0
,24400832.0,24434432.0,24468032.0,24501632.0,24535232.0,24568832.0,24602432.0,24413568.0,24447232.0,24480896.0,24514560.0,24548224.0,24581888.0,24615552.0,24649216.0,24459904.0,24493632.0,24527360.0,24561088.0,24594816.0,24628544.0,24662272.0,24696000.0,23506816.0,23539584.0,23572352.0,23605120.0,23637888.0,23670656.0,23703424.0,23736192.0,22545536.0,22577344.0,22609152.0,22640960.0,22672768.0,22704576.0,22736384.0,22768192.0,11348096.0,11364016.0,11379936.0,11395856.0,11411776.0,11427696.0,11443616.0,11459536.0,11905536.0,11921328.0,11937120.0,11952912.0,11968704.0,11984496.0,12000288.0,12016080.0,23960960.0,23992576.0,24024192.0,24055808.0,24087424.0,24119040.0,24150656.0,24182272.0,24008320.0,24040000.0,24071680.0,24103360.0,24135040.0,24166720.0,24198400.0,24230080.0,24055680.0,24087424.0,24119168.0,24150912.0,24182656.0,24214400.0,24246144.0,24277888.0,24103040.0,24134848.0,24166656.0,24198464.0,24230272.0,24262080.0,24293888.0,24325696.0,24150400.0,24182272.0,24214144.0,24246016.0,24277888.0,24309760.0,24341632.0,24373504.0,24197760.0,24229696.0,24261632.0,24293568.0,24325504.0,24357440.0,24389376.0,24421312.0,24245120.0,24277120.0,24309120.0,24341120.0,24373120.0,24405120.0,24437120.0,24469120.0,24292480.0,24324544.0,24356608.0,24388672.0,24420736.0,24452800.0,24484864.0,24516928.0,24339840.0,24371968.0,24404096.0,24436224.0,24468352.0,24500480.0,24532608.0,24564736.0,24387200.0,24419392.0,24451584.0,24483776.0,24515968.0,24548160.0,24580352.0,24612544.0,24434560.0,24466816.0,24499072.0,24531328.0,24563584.0,24595840.0,24628096.0,24660352.0,24481920.0,24514240.0,24546560.0,24578880.0,24611200.0,24643520.0,24675840.0,24708160.0,24529280.0,24561664.0,24594048.0,24626432.0,24658816.0,24691200.0,24723584.0,24755968.0,24576640.0,24609088.0,24641536.0,24673984.0,24706432.0,24738880.0,24771328.0,24803776.0,24624000.0,24656512.0,24689024.0,24721536.0,24754048.0,24786560.0,24819072.0,24851584.0,24671360.0,24703936.0,24736512.0,24769088.0,24801664.0,24834240.0,24866816
.0,24899392.0,24718720.0,24751360.0,24784000.0,24816640.0,24849280.0,24881920.0,24914560.0,24947200.0,24766080.0,24798784.0,24831488.0,24864192.0,24896896.0,24929600.0,24962304.0,24995008.0,24813440.0,24846208.0,24878976.0,24911744.0,24944512.0,24977280.0,25010048.0,25042816.0,24860800.0,24893632.0,24926464.0,24959296.0,24992128.0,25024960.0,25057792.0,25090624.0,24908160.0,24941056.0,24973952.0,25006848.0,25039744.0,25072640.0,25105536.0,25138432.0,24955520.0,24988480.0,25021440.0,25054400.0,25087360.0,25120320.0,25153280.0,25186240.0,25002880.0,25035904.0,25068928.0,25101952.0,25134976.0,25168000.0,25201024.0,25234048.0,25050240.0,25083328.0,25116416.0,25149504.0,25182592.0,25215680.0,25248768.0,25281856.0,25097600.0,25130752.0,25163904.0,25197056.0,25230208.0,25263360.0,25296512.0,25329664.0,25144960.0,25178176.0,25211392.0,25244608.0,25277824.0,25311040.0,25344256.0,25377472.0,25192320.0,25225600.0,25258880.0,25292160.0,25325440.0,25358720.0,25392000.0,25425280.0,25239680.0,25273024.0,25306368.0,25339712.0,25373056.0,25406400.0,25439744.0,25473088.0,25287040.0,25320448.0,25353856.0,25387264.0,25420672.0,25454080.0,25487488.0,25520896.0,25334400.0,25367872.0,25401344.0,25434816.0,25468288.0,25501760.0,25535232.0,25568704.0,25381760.0,25415296.0,25448832.0,25482368.0,25515904.0,25549440.0,25582976.0,25616512.0,25429120.0,25462720.0,25496320.0,25529920.0,25563520.0,25597120.0,25630720.0,25664320.0,25476480.0,25510144.0,25543808.0,25577472.0,25611136.0,25644800.0,25678464.0,25712128.0,25523840.0,25557568.0,25591296.0,25625024.0,25658752.0,25692480.0,25726208.0,25759936.0,25571200.0,25604992.0,25638784.0,25672576.0,25706368.0,25740160.0,25773952.0,25807744.0,25618560.0,25652416.0,25686272.0,25720128.0,25753984.0,25787840.0,25821696.0,25855552.0,24650112.0,24683008.0,24715904.0,24748800.0,24781696.0,24814592.0,24847488.0,24880384.0,23673472.0,23705408.0,23737344.0,23769280.0,23801216.0,23833152.0,23865088.0,23897024.0,22705024.0,22736000.0,22766976.0,22797952.0,228289
28.0,22859904.0,22890880.0,22921856.0,21728384.0,21758400.0,21788416.0,21818432.0,21848448.0,21878464.0,21908480.0,21938496.0,21775744.0,21805824.0,21835904.0,21865984.0,21896064.0,21926144.0,21956224.0,21986304.0,21823104.0,21853248.0,21883392.0,21913536.0,21943680.0,21973824.0,22003968.0,22034112.0,21870464.0,21900672.0,21930880.0,21961088.0,21991296.0,22021504.0,22051712.0,22081920.0,21917824.0,21948096.0,21978368.0,22008640.0,22038912.0,22069184.0,22099456.0,22129728.0,21965184.0,21995520.0,22025856.0,22056192.0,22086528.0,22116864.0,22147200.0,22177536.0,22012544.0,22042944.0,22073344.0,22103744.0,22134144.0,22164544.0,22194944.0,22225344.0,22059904.0,22090368.0,22120832.0,22151296.0,22181760.0,22212224.0,22242688.0,22273152.0,22107264.0,22137792.0,22168320.0,22198848.0,22229376.0,22259904.0,22290432.0,22320960.0,22154624.0,22185216.0,22215808.0,22246400.0,22276992.0,22307584.0,22338176.0,22368768.0,22201984.0,22232640.0,22263296.0,22293952.0,22324608.0,22355264.0,22385920.0,22416576.0,22249344.0,22280064.0,22310784.0,22341504.0,22372224.0,22402944.0,22433664.0,22464384.0,22296704.0,22327488.0,22358272.0,22389056.0,22419840.0,22450624.0,22481408.0,22512192.0,22344064.0,22374912.0,22405760.0,22436608.0,22467456.0,22498304.0,22529152.0,22560000.0,22391424.0,22422336.0,22453248.0,22484160.0,22515072.0,22545984.0,22576896.0,22607808.0,22438784.0,22469760.0,22500736.0,22531712.0,22562688.0,22593664.0,22624640.0,22655616.0,22486144.0,22517184.0,22548224.0,22579264.0,22610304.0,22641344.0,22672384.0,22703424.0,22533504.0,22564608.0,22595712.0,22626816.0,22657920.0,22689024.0,22720128.0,22751232.0,22580864.0,22612032.0,22643200.0,22674368.0,22705536.0,22736704.0,22767872.0,22799040.0,22628224.0,22659456.0,22690688.0,22721920.0,22753152.0,22784384.0,22815616.0,22846848.0,22675584.0,22706880.0,22738176.0,22769472.0,22800768.0,22832064.0,22863360.0,22894656.0,22722944.0,22754304.0,22785664.0,22817024.0,22848384.0,22879744.0,22911104.0,22942464.0,22770304.0,22801728.0,2283
3152.0,22864576.0,22896000.0,22927424.0,22958848.0,22990272.0,22817664.0,22849152.0,22880640.0,22912128.0,22943616.0,22975104.0,23006592.0,23038080.0,22865024.0,22896576.0,22928128.0,22959680.0,22991232.0,23022784.0,23054336.0,23085888.0,22912384.0,22944000.0,22975616.0,23007232.0,23038848.0,23070464.0,23102080.0,23133696.0,22959744.0,22991424.0,23023104.0,23054784.0,23086464.0,23118144.0,23149824.0,23181504.0,23007104.0,23038848.0,23070592.0,23102336.0,23134080.0,23165824.0,23197568.0,23229312.0,23054464.0,23086272.0,23118080.0,23149888.0,23181696.0,23213504.0,23245312.0,23277120.0,23101824.0,23133696.0,23165568.0,23197440.0,23229312.0,23261184.0,23293056.0,23324928.0,23149184.0,23181120.0,23213056.0,23244992.0,23276928.0,23308864.0,23340800.0,23372736.0,23196544.0,23228544.0,23260544.0,23292544.0,23324544.0,23356544.0,23388544.0,23420544.0,23243904.0,23275968.0,23308032.0,23340096.0,23372160.0,23404224.0,23436288.0,23468352.0,23291264.0,23323392.0,23355520.0,23387648.0,23419776.0,23451904.0,23484032.0,23516160.0,23338624.0,23370816.0,23403008.0,23435200.0,23467392.0,23499584.0,23531776.0,23563968.0,23385984.0,23418240.0,23450496.0,23482752.0,23515008.0,23547264.0,23579520.0,23611776.0,23433344.0,23465664.0,23497984.0,23530304.0,23562624.0,23594944.0,23627264.0,23659584.0,23480704.0,23513088.0,23545472.0,23577856.0,23610240.0,23642624.0,23675008.0,23707392.0,23036544.0,23067968.0,23099392.0,23130816.0,23162240.0,23193664.0,23225088.0,23256512.0,22584192.0,22614656.0,22645120.0,22675584.0,22706048.0,22736512.0,22766976.0,22797440.0,22631552.0,22662080.0,22692608.0,22723136.0,22753664.0,22784192.0,22814720.0,22845248.0,22678912.0,22709504.0,22740096.0,22770688.0,22801280.0,22831872.0,22862464.0,22893056.0,22726272.0,22756928.0,22787584.0,22818240.0,22848896.0,22879552.0,22910208.0,22940864.0,22773632.0,22804352.0,22835072.0,22865792.0,22896512.0,22927232.0,22957952.0,22988672.0,22820992.0,22851776.0,22882560.0,22913344.0,22944128.0,22974912.0,23005696.0,23036480.0,22
868352.0,22899200.0,22930048.0,22960896.0,22991744.0,23022592.0,23053440.0,23084288.0,22915712.0,22946624.0,22977536.0,23008448.0,23039360.0,23070272.0,23101184.0,23132096.0,22963072.0,22994048.0,23025024.0,23056000.0,23086976.0,23117952.0,23148928.0,23179904.0,23010432.0,23041472.0,23072512.0,23103552.0,23134592.0,23165632.0,23196672.0,23227712.0,23057792.0,23088896.0,23120000.0,23151104.0,23182208.0,23213312.0,23244416.0,23275520.0,23105152.0,23136320.0,23167488.0,23198656.0,23229824.0,23260992.0,23292160.0,23323328.0,23152512.0,23183744.0,23214976.0,23246208.0,23277440.0,23308672.0,23339904.0,23371136.0,23199872.0,23231168.0,23262464.0,23293760.0,23325056.0,23356352.0,23387648.0,23418944.0,23247232.0,23278592.0,23309952.0,23341312.0,23372672.0,23404032.0,23435392.0,23466752.0,23294592.0,23326016.0,23357440.0,23388864.0,23420288.0,23451712.0,23483136.0,23514560.0,23341952.0,23373440.0,23404928.0,23436416.0,23467904.0,23499392.0,23530880.0,23562368.0,23389312.0,23420864.0,23452416.0,23483968.0,23515520.0,23547072.0,23578624.0,23610176.0,23436672.0,23468288.0,23499904.0,23531520.0,23563136.0,23594752.0,23626368.0,23657984.0,23484032.0,23515712.0,23547392.0,23579072.0,23610752.0,23642432.0,23674112.0,23705792.0,23531392.0,23563136.0,23594880.0,23626624.0,23658368.0,23690112.0,23721856.0,23753600.0,23578752.0,23610560.0,23642368.0,23674176.0,23705984.0,23737792.0,23769600.0,23801408.0,23626112.0,23657984.0,23689856.0,23721728.0,23753600.0,23785472.0,23817344.0,23849216.0,23673472.0,23705408.0,23737344.0,23769280.0,23801216.0,23833152.0,23865088.0,23897024.0,23720832.0,23752832.0,23784832.0,23816832.0,23848832.0,23880832.0,23912832.0,23944832.0,23768192.0,23800256.0,23832320.0,23864384.0,23896448.0,23928512.0,23960576.0,23992640.0,23815552.0,23847680.0,23879808.0,23911936.0,23944064.0,23976192.0,24008320.0,24040448.0,23862912.0,23895104.0,23927296.0,23959488.0,23991680.0,24023872.0,24056064.0,24088256.0,23910272.0,23942528.0,23974784.0,24007040.0,24039296.0,24071552.0,
24103808.0,24136064.0,23957632.0,23989952.0,24022272.0,24054592.0,24086912.0,24119232.0,24151552.0,24183872.0,24004992.0,24037376.0,24069760.0,24102144.0,24134528.0,24166912.0,24199296.0,24231680.0,24052352.0,24084800.0,24117248.0,24149696.0,24182144.0,24214592.0,24247040.0,24279488.0,24099712.0,24132224.0,24164736.0,24197248.0,24229760.0,24262272.0,24294784.0,24327296.0,24147072.0,24179648.0,24212224.0,24244800.0,24277376.0,24309952.0,24342528.0,24375104.0,24194432.0,24227072.0,24259712.0,24292352.0,24324992.0,24357632.0,24390272.0,24422912.0,24241792.0,24274496.0,24307200.0,24339904.0,24372608.0,24405312.0,24438016.0,24470720.0,24289152.0,24321920.0,24354688.0,24387456.0,24420224.0,24452992.0,24485760.0,24518528.0,24336512.0,24369344.0,24402176.0,24435008.0,24467840.0,24500672.0,24533504.0,24566336.0,23368064.0,23399936.0,23431808.0,23463680.0,23495552.0,23527424.0,23559296.0,23591168.0,22391424.0,22422336.0,22453248.0,22484160.0,22515072.0,22545984.0,22576896.0,22607808.0,22438784.0,22469760.0,22500736.0,22531712.0,22562688.0,22593664.0,22624640.0,22655616.0,22486144.0,22517184.0,22548224.0,22579264.0,22610304.0,22641344.0,22672384.0,22703424.0,22533504.0,22564608.0,22595712.0,22626816.0,22657920.0,22689024.0,22720128.0,22751232.0,22580864.0,22612032.0,22643200.0,22674368.0,22705536.0,22736704.0,22767872.0,22799040.0,22628224.0,22659456.0,22690688.0,22721920.0,22753152.0,22784384.0,22815616.0,22846848.0,22675584.0,22706880.0,22738176.0,22769472.0,22800768.0,22832064.0,22863360.0,22894656.0,22722944.0,22754304.0,22785664.0,22817024.0,22848384.0,22879744.0,22911104.0,22942464.0,22770304.0,22801728.0,22833152.0,22864576.0,22896000.0,22927424.0,22958848.0,22990272.0,22817664.0,22849152.0,22880640.0,22912128.0,22943616.0,22975104.0,23006592.0,23038080.0,22865024.0,22896576.0,22928128.0,22959680.0,22991232.0,23022784.0,23054336.0,23085888.0,22912384.0,22944000.0,22975616.0,23007232.0,23038848.0,23070464.0,23102080.0,23133696.0,22959744.0,22991424.0,23023104.0,23054784.
0,23086464.0,23118144.0,23149824.0,23181504.0,23007104.0,23038848.0,23070592.0,23102336.0,23134080.0,23165824.0,23197568.0,23229312.0,23054464.0,23086272.0,23118080.0,23149888.0,23181696.0,23213504.0,23245312.0,23277120.0,23101824.0,23133696.0,23165568.0,23197440.0,23229312.0,23261184.0,23293056.0,23324928.0,23149184.0,23181120.0,23213056.0,23244992.0,23276928.0,23308864.0,23340800.0,23372736.0,23196544.0,23228544.0,23260544.0,23292544.0,23324544.0,23356544.0,23388544.0,23420544.0,23243904.0,23275968.0,23308032.0,23340096.0,23372160.0,23404224.0,23436288.0,23468352.0,23291264.0,23323392.0,23355520.0,23387648.0,23419776.0,23451904.0,23484032.0,23516160.0,23338624.0,23370816.0,23403008.0,23435200.0,23467392.0,23499584.0,23531776.0,23563968.0,23385984.0,23418240.0,23450496.0,23482752.0,23515008.0,23547264.0,23579520.0,23611776.0,23433344.0,23465664.0,23497984.0,23530304.0,23562624.0,23594944.0,23627264.0,23659584.0,23480704.0,23513088.0,23545472.0,23577856.0,23610240.0,23642624.0,23675008.0,23707392.0,23528064.0,23560512.0,23592960.0,23625408.0,23657856.0,23690304.0,23722752.0,23755200.0,23575424.0,23607936.0,23640448.0,23672960.0,23705472.0,23737984.0,23770496.0,23803008.0,23622784.0,23655360.0,23687936.0,23720512.0,23753088.0,23785664.0,23818240.0,23850816.0,23670144.0,23702784.0,23735424.0,23768064.0,23800704.0,23833344.0,23865984.0,23898624.0,23717504.0,23750208.0,23782912.0,23815616.0,23848320.0,23881024.0,23913728.0,23946432.0,23764864.0,23797632.0,23830400.0,23863168.0,23895936.0,23928704.0,23961472.0,23994240.0,23812224.0,23845056.0,23877888.0,23910720.0,23943552.0,23976384.0,24009216.0,24042048.0,23859584.0,23892480.0,23925376.0,23958272.0,23991168.0,24024064.0,24056960.0,24089856.0,23906944.0,23939904.0,23972864.0,24005824.0,24038784.0,24071744.0,24104704.0,24137664.0,23954304.0,23987328.0,24020352.0,24053376.0,24086400.0,24119424.0,24152448.0,24185472.0,24001664.0,24034752.0,24067840.0,24100928.0,24134016.0,24167104.0,24200192.0,24233280.0,24049024.0,2408217
6.0,24115328.0,24148480.0,24181632.0,24214784.0,24247936.0,24281088.0,24096384.0,24129600.0,24162816.0,24196032.0,24229248.0,24262464.0,24295680.0,24328896.0,24143744.0,24177024.0,24210304.0,24243584.0,24276864.0,24310144.0,24343424.0,24376704.0,23699584.0,23731904.0,23764224.0,23796544.0,23828864.0,23861184.0,23893504.0,23925824.0,23247232.0,23278592.0,23309952.0,23341312.0,23372672.0,23404032.0,23435392.0,23466752.0,23294592.0,23326016.0,23357440.0,23388864.0,23420288.0,23451712.0,23483136.0,23514560.0,23341952.0,23373440.0,23404928.0,23436416.0,23467904.0,23499392.0,23530880.0,23562368.0,23389312.0,23420864.0,23452416.0,23483968.0,23515520.0,23547072.0,23578624.0,23610176.0,23436672.0,23468288.0,23499904.0,23531520.0,23563136.0,23594752.0,23626368.0,23657984.0,23484032.0,23515712.0,23547392.0,23579072.0,23610752.0,23642432.0,23674112.0,23705792.0,23531392.0,23563136.0,23594880.0,23626624.0,23658368.0,23690112.0,23721856.0,23753600.0,23578752.0,23610560.0,23642368.0,23674176.0,23705984.0,23737792.0,23769600.0,23801408.0,23626112.0,23657984.0,23689856.0,23721728.0,23753600.0,23785472.0,23817344.0,23849216.0,23673472.0,23705408.0,23737344.0,23769280.0,23801216.0,23833152.0,23865088.0,23897024.0,23720832.0,23752832.0,23784832.0,23816832.0,23848832.0,23880832.0,23912832.0,23944832.0,23768192.0,23800256.0,23832320.0,23864384.0,23896448.0,23928512.0,23960576.0,23992640.0,23815552.0,23847680.0,23879808.0,23911936.0,23944064.0,23976192.0,24008320.0,24040448.0,23862912.0,23895104.0,23927296.0,23959488.0,23991680.0,24023872.0,24056064.0,24088256.0,23910272.0,23942528.0,23974784.0,24007040.0,24039296.0,24071552.0,24103808.0,24136064.0,23957632.0,23989952.0,24022272.0,24054592.0,24086912.0,24119232.0,24151552.0,24183872.0,24004992.0,24037376.0,24069760.0,24102144.0,24134528.0,24166912.0,24199296.0,24231680.0,24052352.0,24084800.0,24117248.0,24149696.0,24182144.0,24214592.0,24247040.0,24279488.0,24099712.0,24132224.0,24164736.0,24197248.0,24229760.0,24262272.0,24294784.0,24327
296.0,24147072.0,24179648.0,24212224.0,24244800.0,24277376.0,24309952.0,24342528.0,24375104.0,24194432.0,24227072.0,24259712.0,24292352.0,24324992.0,24357632.0,24390272.0,24422912.0,24241792.0,24274496.0,24307200.0,24339904.0,24372608.0,24405312.0,24438016.0,24470720.0,24289152.0,24321920.0,24354688.0,24387456.0,24420224.0,24452992.0,24485760.0,24518528.0,24336512.0,24369344.0,24402176.0,24435008.0,24467840.0,24500672.0,24533504.0,24566336.0,24383872.0,24416768.0,24449664.0,24482560.0,24515456.0,24548352.0,24581248.0,24614144.0,24431232.0,24464192.0,24497152.0,24530112.0,24563072.0,24596032.0,24628992.0,24661952.0,24478592.0,24511616.0,24544640.0,24577664.0,24610688.0,24643712.0,24676736.0,24709760.0,24525952.0,24559040.0,24592128.0,24625216.0,24658304.0,24691392.0,24724480.0,24757568.0,24573312.0,24606464.0,24639616.0,24672768.0,24705920.0,24739072.0,24772224.0,24805376.0,24620672.0,24653888.0,24687104.0,24720320.0,24753536.0,24786752.0,24819968.0,24853184.0,24668032.0,24701312.0,24734592.0,24767872.0,24801152.0,24834432.0,24867712.0,24900992.0,24715392.0,24748736.0,24782080.0,24815424.0,24848768.0,24882112.0,24915456.0,24948800.0,24762752.0,24796160.0,24829568.0,24862976.0,24896384.0,24929792.0,24963200.0,24996608.0,24810112.0,24843584.0,24877056.0,24910528.0,24944000.0,24977472.0,25010944.0,25044416.0,24857472.0,24891008.0,24924544.0,24958080.0,24991616.0,25025152.0,25058688.0,25092224.0,24904832.0,24938432.0,24972032.0,25005632.0,25039232.0,25072832.0,25106432.0,25140032.0,24952192.0,24985856.0,25019520.0,25053184.0,25086848.0,25120512.0,25154176.0,25187840.0,24999552.0,25033280.0,25067008.0,25100736.0,25134464.0,25168192.0,25201920.0,25235648.0,24031104.0,24063872.0,24096640.0,24129408.0,24162176.0,24194944.0,24227712.0,24260480.0,23054464.0,23086272.0,23118080.0,23149888.0,23181696.0,23213504.0,23245312.0,23277120.0,11602816.0,11618736.0,11634656.0,11650576.0,11666496.0,11682416.0,11698336.0,11714256.0,12158208.0,12174000.0,12189792.0,12205584.0,12221376.0,122
37168.0,12252960.0,12268752.0,24466816.0,24498432.0,24530048.0,24561664.0,24593280.0,24624896.0,24656512.0,24688128.0,24515200.0,24546880.0,24578560.0,24610240.0,24641920.0,24673600.0,24705280.0,24736960.0,24563584.0,24595328.0,24627072.0,24658816.0,24690560.0,24722304.0,24754048.0,24785792.0,24611968.0,24643776.0,24675584.0,24707392.0,24739200.0,24771008.0,24802816.0,24834624.0,24660352.0,24692224.0,24724096.0,24755968.0,24787840.0,24819712.0,24851584.0,24883456.0,24708736.0,24740672.0,24772608.0,24804544.0,24836480.0,24868416.0,24900352.0,24932288.0,24757120.0,24789120.0,24821120.0,24853120.0,24885120.0,24917120.0,24949120.0,24981120.0,24805504.0,24837568.0,24869632.0,24901696.0,24933760.0,24965824.0,24997888.0,25029952.0,24853888.0,24886016.0,24918144.0,24950272.0,24982400.0,25014528.0,25046656.0,25078784.0,24902272.0,24934464.0,24966656.0,24998848.0,25031040.0,25063232.0,25095424.0,25127616.0,24950656.0,24982912.0,25015168.0,25047424.0,25079680.0,25111936.0,25144192.0,25176448.0,24999040.0,25031360.0,25063680.0,25096000.0,25128320.0,25160640.0,25192960.0,25225280.0,25047424.0,25079808.0,25112192.0,25144576.0,25176960.0,25209344.0,25241728.0,25274112.0,25095808.0,25128256.0,25160704.0,25193152.0,25225600.0,25258048.0,25290496.0,25322944.0,25144192.0,25176704.0,25209216.0,25241728.0,25274240.0,25306752.0,25339264.0,25371776.0,25192576.0,25225152.0,25257728.0,25290304.0,25322880.0,25355456.0,25388032.0,25420608.0,25240960.0,25273600.0,25306240.0,25338880.0,25371520.0,25404160.0,25436800.0,25469440.0,25289344.0,25322048.0,25354752.0,25387456.0,25420160.0,25452864.0,25485568.0,25518272.0,25337728.0,25370496.0,25403264.0,25436032.0,25468800.0,25501568.0,25534336.0,25567104.0,25386112.0,25418944.0,25451776.0,25484608.0,25517440.0,25550272.0,25583104.0,25615936.0,25434496.0,25467392.0,25500288.0,25533184.0,25566080.0,25598976.0,25631872.0,25664768.0,25482880.0,25515840.0,25548800.0,25581760.0,25614720.0,25647680.0,25680640.0,25713600.0,25531264.0,25564288.0,25597312.0,2
5630336.0,25663360.0,25696384.0,25729408.0,25762432.0,25579648.0,25612736.0,25645824.0,25678912.0,25712000.0,25745088.0,25778176.0,25811264.0,25628032.0,25661184.0,25694336.0,25727488.0,25760640.0,25793792.0,25826944.0,25860096.0,25676416.0,25709632.0,25742848.0,25776064.0,25809280.0,25842496.0,25875712.0,25908928.0,25724800.0,25758080.0,25791360.0,25824640.0,25857920.0,25891200.0,25924480.0,25957760.0,25773184.0,25806528.0,25839872.0,25873216.0,25906560.0,25939904.0,25973248.0,26006592.0,25821568.0,25854976.0,25888384.0,25921792.0,25955200.0,25988608.0,26022016.0,26055424.0,25869952.0,25903424.0,25936896.0,25970368.0,26003840.0,26037312.0,26070784.0,26104256.0,25918336.0,25951872.0,25985408.0,26018944.0,26052480.0,26086016.0,26119552.0,26153088.0,25966720.0,26000320.0,26033920.0,26067520.0,26101120.0,26134720.0,26168320.0,26201920.0,26015104.0,26048768.0,26082432.0,26116096.0,26149760.0,26183424.0,26217088.0,26250752.0,26063488.0,26097216.0,26130944.0,26164672.0,26198400.0,26232128.0,26265856.0,26299584.0,26111872.0,26145664.0,26179456.0,26213248.0,26247040.0,26280832.0,26314624.0,26348416.0,26160256.0,26194112.0,26227968.0,26261824.0,26295680.0,26329536.0,26363392.0,26397248.0,25176448.0,25209344.0,25242240.0,25275136.0,25308032.0,25340928.0,25373824.0,25406720.0,24184448.0,24216384.0,24248320.0,24280256.0,24312192.0,24344128.0,24376064.0,24408000.0,23200640.0,23231616.0,23262592.0,23293568.0,23324544.0,23355520.0,23386496.0,23417472.0,22208640.0,22238656.0,22268672.0,22298688.0,22328704.0,22358720.0,22388736.0,22418752.0,22257024.0,22287104.0,22317184.0,22347264.0,22377344.0,22407424.0,22437504.0,22467584.0,22305408.0,22335552.0,22365696.0,22395840.0,22425984.0,22456128.0,22486272.0,22516416.0,22353792.0,22384000.0,22414208.0,22444416.0,22474624.0,22504832.0,22535040.0,22565248.0,22402176.0,22432448.0,22462720.0,22492992.0,22523264.0,22553536.0,22583808.0,22614080.0,22450560.0,22480896.0,22511232.0,22541568.0,22571904.0,22602240.0,22632576.0,22662912.0,22498944.0
,22529344.0,22559744.0,22590144.0,22620544.0,22650944.0,22681344.0,22711744.0,22547328.0,22577792.0,22608256.0,22638720.0,22669184.0,22699648.0,22730112.0,22760576.0,22595712.0,22626240.0,22656768.0,22687296.0,22717824.0,22748352.0,22778880.0,22809408.0,22644096.0,22674688.0,22705280.0,22735872.0,22766464.0,22797056.0,22827648.0,22858240.0,22692480.0,22723136.0,22753792.0,22784448.0,22815104.0,22845760.0,22876416.0,22907072.0,22740864.0,22771584.0,22802304.0,22833024.0,22863744.0,22894464.0,22925184.0,22955904.0,22789248.0,22820032.0,22850816.0,22881600.0,22912384.0,22943168.0,22973952.0,23004736.0,22837632.0,22868480.0,22899328.0,22930176.0,22961024.0,22991872.0,23022720.0,23053568.0,22886016.0,22916928.0,22947840.0,22978752.0,23009664.0,23040576.0,23071488.0,23102400.0,22934400.0,22965376.0,22996352.0,23027328.0,23058304.0,23089280.0,23120256.0,23151232.0,22982784.0,23013824.0,23044864.0,23075904.0,23106944.0,23137984.0,23169024.0,23200064.0,23031168.0,23062272.0,23093376.0,23124480.0,23155584.0,23186688.0,23217792.0,23248896.0,23079552.0,23110720.0,23141888.0,23173056.0,23204224.0,23235392.0,23266560.0,23297728.0,23127936.0,23159168.0,23190400.0,23221632.0,23252864.0,23284096.0,23315328.0,23346560.0,23176320.0,23207616.0,23238912.0,23270208.0,23301504.0,23332800.0,23364096.0,23395392.0,23224704.0,23256064.0,23287424.0,23318784.0,23350144.0,23381504.0,23412864.0,23444224.0,23273088.0,23304512.0,23335936.0,23367360.0,23398784.0,23430208.0,23461632.0,23493056.0,23321472.0,23352960.0,23384448.0,23415936.0,23447424.0,23478912.0,23510400.0,23541888.0,23369856.0,23401408.0,23432960.0,23464512.0,23496064.0,23527616.0,23559168.0,23590720.0,23418240.0,23449856.0,23481472.0,23513088.0,23544704.0,23576320.0,23607936.0,23639552.0,23466624.0,23498304.0,23529984.0,23561664.0,23593344.0,23625024.0,23656704.0,23688384.0,23515008.0,23546752.0,23578496.0,23610240.0,23641984.0,23673728.0,23705472.0,23737216.0,23563392.0,23595200.0,23627008.0,23658816.0,23690624.0,23722432.0,23754240
.0,23786048.0,23611776.0,23643648.0,23675520.0,23707392.0,23739264.0,23771136.0,23803008.0,23834880.0,23660160.0,23692096.0,23724032.0,23755968.0,23787904.0,23819840.0,23851776.0,23883712.0,23708544.0,23740544.0,23772544.0,23804544.0,23836544.0,23868544.0,23900544.0,23932544.0,23756928.0,23788992.0,23821056.0,23853120.0,23885184.0,23917248.0,23949312.0,23981376.0,23805312.0,23837440.0,23869568.0,23901696.0,23933824.0,23965952.0,23998080.0,24030208.0,23853696.0,23885888.0,23918080.0,23950272.0,23982464.0,24014656.0,24046848.0,24079040.0,23902080.0,23934336.0,23966592.0,23998848.0,24031104.0,24063360.0,24095616.0,24127872.0,23950464.0,23982784.0,24015104.0,24047424.0,24079744.0,24112064.0,24144384.0,24176704.0,23998848.0,24031232.0,24063616.0,24096000.0,24128384.0,24160768.0,24193152.0,24225536.0,23539328.0,23570752.0,23602176.0,23633600.0,23665024.0,23696448.0,23727872.0,23759296.0,23071616.0,23102080.0,23132544.0,23163008.0,23193472.0,23223936.0,23254400.0,23284864.0,23120000.0,23150528.0,23181056.0,23211584.0,23242112.0,23272640.0,23303168.0,23333696.0,23168384.0,23198976.0,23229568.0,23260160.0,23290752.0,23321344.0,23351936.0,23382528.0,23216768.0,23247424.0,23278080.0,23308736.0,23339392.0,23370048.0,23400704.0,23431360.0,23265152.0,23295872.0,23326592.0,23357312.0,23388032.0,23418752.0,23449472.0,23480192.0,23313536.0,23344320.0,23375104.0,23405888.0,23436672.0,23467456.0,23498240.0,23529024.0,23361920.0,23392768.0,23423616.0,23454464.0,23485312.0,23516160.0,23547008.0,23577856.0,23410304.0,23441216.0,23472128.0,23503040.0,23533952.0,23564864.0,23595776.0,23626688.0,23458688.0,23489664.0,23520640.0,23551616.0,23582592.0,23613568.0,23644544.0,23675520.0,23507072.0,23538112.0,23569152.0,23600192.0,23631232.0,23662272.0,23693312.0,23724352.0,23555456.0,23586560.0,23617664.0,23648768.0,23679872.0,23710976.0,23742080.0,23773184.0,23603840.0,23635008.0,23666176.0,23697344.0,23728512.0,23759680.0,23790848.0,23822016.0,23652224.0,23683456.0,23714688.0,23745920.0,237771
52.0,23808384.0,23839616.0,23870848.0,23700608.0,23731904.0,23763200.0,23794496.0,23825792.0,23857088.0,23888384.0,23919680.0,23748992.0,23780352.0,23811712.0,23843072.0,23874432.0,23905792.0,23937152.0,23968512.0,23797376.0,23828800.0,23860224.0,23891648.0,23923072.0,23954496.0,23985920.0,24017344.0,23845760.0,23877248.0,23908736.0,23940224.0,23971712.0,24003200.0,24034688.0,24066176.0,23894144.0,23925696.0,23957248.0,23988800.0,24020352.0,24051904.0,24083456.0,24115008.0,23942528.0,23974144.0,24005760.0,24037376.0,24068992.0,24100608.0,24132224.0,24163840.0,23990912.0,24022592.0,24054272.0,24085952.0,24117632.0,24149312.0,24180992.0,24212672.0,24039296.0,24071040.0,24102784.0,24134528.0,24166272.0,24198016.0,24229760.0,24261504.0,24087680.0,24119488.0,24151296.0,24183104.0,24214912.0,24246720.0,24278528.0,24310336.0,24136064.0,24167936.0,24199808.0,24231680.0,24263552.0,24295424.0,24327296.0,24359168.0,24184448.0,24216384.0,24248320.0,24280256.0,24312192.0,24344128.0,24376064.0,24408000.0,24232832.0,24264832.0,24296832.0,24328832.0,24360832.0,24392832.0,24424832.0,24456832.0,24281216.0,24313280.0,24345344.0,24377408.0,24409472.0,24441536.0,24473600.0,24505664.0,24329600.0,24361728.0,24393856.0,24425984.0,24458112.0,24490240.0,24522368.0,24554496.0,24377984.0,24410176.0,24442368.0,24474560.0,24506752.0,24538944.0,24571136.0,24603328.0,24426368.0,24458624.0,24490880.0,24523136.0,24555392.0,24587648.0,24619904.0,24652160.0,24474752.0,24507072.0,24539392.0,24571712.0,24604032.0,24636352.0,24668672.0,24700992.0,24523136.0,24555520.0,24587904.0,24620288.0,24652672.0,24685056.0,24717440.0,24749824.0,24571520.0,24603968.0,24636416.0,24668864.0,24701312.0,24733760.0,24766208.0,24798656.0,24619904.0,24652416.0,24684928.0,24717440.0,24749952.0,24782464.0,24814976.0,24847488.0,24668288.0,24700864.0,24733440.0,24766016.0,24798592.0,24831168.0,24863744.0,24896320.0,24716672.0,24749312.0,24781952.0,24814592.0,24847232.0,24879872.0,24912512.0,24945152.0,24765056.0,24797760.0,2483
0464.0,24863168.0,24895872.0,24928576.0,24961280.0,24993984.0,24813440.0,24846208.0,24878976.0,24911744.0,24944512.0,24977280.0,25010048.0,25042816.0,24861824.0,24894656.0,24927488.0,24960320.0,24993152.0,25025984.0,25058816.0,25091648.0,23878016.0,23909888.0,23941760.0,23973632.0,24005504.0,24037376.0,24069248.0,24101120.0,22886016.0,22916928.0,22947840.0,22978752.0,23009664.0,23040576.0,23071488.0,23102400.0,22934400.0,22965376.0,22996352.0,23027328.0,23058304.0,23089280.0,23120256.0,23151232.0,22982784.0,23013824.0,23044864.0,23075904.0,23106944.0,23137984.0,23169024.0,23200064.0,23031168.0,23062272.0,23093376.0,23124480.0,23155584.0,23186688.0,23217792.0,23248896.0,23079552.0,23110720.0,23141888.0,23173056.0,23204224.0,23235392.0,23266560.0,23297728.0,23127936.0,23159168.0,23190400.0,23221632.0,23252864.0,23284096.0,23315328.0,23346560.0,23176320.0,23207616.0,23238912.0,23270208.0,23301504.0,23332800.0,23364096.0,23395392.0,23224704.0,23256064.0,23287424.0,23318784.0,23350144.0,23381504.0,23412864.0,23444224.0,23273088.0,23304512.0,23335936.0,23367360.0,23398784.0,23430208.0,23461632.0,23493056.0,23321472.0,23352960.0,23384448.0,23415936.0,23447424.0,23478912.0,23510400.0,23541888.0,23369856.0,23401408.0,23432960.0,23464512.0,23496064.0,23527616.0,23559168.0,23590720.0,23418240.0,23449856.0,23481472.0,23513088.0,23544704.0,23576320.0,23607936.0,23639552.0,23466624.0,23498304.0,23529984.0,23561664.0,23593344.0,23625024.0,23656704.0,23688384.0,23515008.0,23546752.0,23578496.0,23610240.0,23641984.0,23673728.0,23705472.0,23737216.0,23563392.0,23595200.0,23627008.0,23658816.0,23690624.0,23722432.0,23754240.0,23786048.0,23611776.0,23643648.0,23675520.0,23707392.0,23739264.0,23771136.0,23803008.0,23834880.0,23660160.0,23692096.0,23724032.0,23755968.0,23787904.0,23819840.0,23851776.0,23883712.0,23708544.0,23740544.0,23772544.0,23804544.0,23836544.0,23868544.0,23900544.0,23932544.0,23756928.0,23788992.0,23821056.0,23853120.0,23885184.0,23917248.0,23949312.0,23981376.0,23
805312.0,23837440.0,23869568.0,23901696.0,23933824.0,23965952.0,23998080.0,24030208.0,23853696.0,23885888.0,23918080.0,23950272.0,23982464.0,24014656.0,24046848.0,24079040.0,23902080.0,23934336.0,23966592.0,23998848.0,24031104.0,24063360.0,24095616.0,24127872.0,23950464.0,23982784.0,24015104.0,24047424.0,24079744.0,24112064.0,24144384.0,24176704.0,23998848.0,24031232.0,24063616.0,24096000.0,24128384.0,24160768.0,24193152.0,24225536.0,24047232.0,24079680.0,24112128.0,24144576.0,24177024.0,24209472.0,24241920.0,24274368.0,24095616.0,24128128.0,24160640.0,24193152.0,24225664.0,24258176.0,24290688.0,24323200.0,24144000.0,24176576.0,24209152.0,24241728.0,24274304.0,24306880.0,24339456.0,24372032.0,24192384.0,24225024.0,24257664.0,24290304.0,24322944.0,24355584.0,24388224.0,24420864.0,24240768.0,24273472.0,24306176.0,24338880.0,24371584.0,24404288.0,24436992.0,24469696.0,24289152.0,24321920.0,24354688.0,24387456.0,24420224.0,24452992.0,24485760.0,24518528.0,24337536.0,24370368.0,24403200.0,24436032.0,24468864.0,24501696.0,24534528.0,24567360.0,24385920.0,24418816.0,24451712.0,24484608.0,24517504.0,24550400.0,24583296.0,24616192.0,24434304.0,24467264.0,24500224.0,24533184.0,24566144.0,24599104.0,24632064.0,24665024.0,24482688.0,24515712.0,24548736.0,24581760.0,24614784.0,24647808.0,24680832.0,24713856.0,24531072.0,24564160.0,24597248.0,24630336.0,24663424.0,24696512.0,24729600.0,24762688.0,24579456.0,24612608.0,24645760.0,24678912.0,24712064.0,24745216.0,24778368.0,24811520.0,24627840.0,24661056.0,24694272.0,24727488.0,24760704.0,24793920.0,24827136.0,24860352.0,24676224.0,24709504.0,24742784.0,24776064.0,24809344.0,24842624.0,24875904.0,24909184.0,24216704.0,24249024.0,24281344.0,24313664.0,24345984.0,24378304.0,24410624.0,24442944.0,23748992.0,23780352.0,23811712.0,23843072.0,23874432.0,23905792.0,23937152.0,23968512.0,23797376.0,23828800.0,23860224.0,23891648.0,23923072.0,23954496.0,23985920.0,24017344.0,23845760.0,23877248.0,23908736.0,23940224.0,23971712.0,24003200.0,
24034688.0,24066176.0,23894144.0,23925696.0,23957248.0,23988800.0,24020352.0,24051904.0,24083456.0,24115008.0,23942528.0,23974144.0,24005760.0,24037376.0,24068992.0,24100608.0,24132224.0,24163840.0,23990912.0,24022592.0,24054272.0,24085952.0,24117632.0,24149312.0,24180992.0,24212672.0,24039296.0,24071040.0,24102784.0,24134528.0,24166272.0,24198016.0,24229760.0,24261504.0,24087680.0,24119488.0,24151296.0,24183104.0,24214912.0,24246720.0,24278528.0,24310336.0,24136064.0,24167936.0,24199808.0,24231680.0,24263552.0,24295424.0,24327296.0,24359168.0,24184448.0,24216384.0,24248320.0,24280256.0,24312192.0,24344128.0,24376064.0,24408000.0,24232832.0,24264832.0,24296832.0,24328832.0,24360832.0,24392832.0,24424832.0,24456832.0,24281216.0,24313280.0,24345344.0,24377408.0,24409472.0,24441536.0,24473600.0,24505664.0,24329600.0,24361728.0,24393856.0,24425984.0,24458112.0,24490240.0,24522368.0,24554496.0,24377984.0,24410176.0,24442368.0,24474560.0,24506752.0,24538944.0,24571136.0,24603328.0,24426368.0,24458624.0,24490880.0,24523136.0,24555392.0,24587648.0,24619904.0,24652160.0,24474752.0,24507072.0,24539392.0,24571712.0,24604032.0,24636352.0,24668672.0,24700992.0,24523136.0,24555520.0,24587904.0,24620288.0,24652672.0,24685056.0,24717440.0,24749824.0,24571520.0,24603968.0,24636416.0,24668864.0,24701312.0,24733760.0,24766208.0,24798656.0,24619904.0,24652416.0,24684928.0,24717440.0,24749952.0,24782464.0,24814976.0,24847488.0,24668288.0,24700864.0,24733440.0,24766016.0,24798592.0,24831168.0,24863744.0,24896320.0,24716672.0,24749312.0,24781952.0,24814592.0,24847232.0,24879872.0,24912512.0,24945152.0,24765056.0,24797760.0,24830464.0,24863168.0,24895872.0,24928576.0,24961280.0,24993984.0,24813440.0,24846208.0,24878976.0,24911744.0,24944512.0,24977280.0,25010048.0,25042816.0,24861824.0,24894656.0,24927488.0,24960320.0,24993152.0,25025984.0,25058816.0,25091648.0,24910208.0,24943104.0,24976000.0,25008896.0,25041792.0,25074688.0,25107584.0,25140480.0,24958592.0,24991552.0,25024512.0,25057472.
0,25090432.0,25123392.0,25156352.0,25189312.0,25006976.0,25040000.0,25073024.0,25106048.0,25139072.0,25172096.0,25205120.0,25238144.0,25055360.0,25088448.0,25121536.0,25154624.0,25187712.0,25220800.0,25253888.0,25286976.0,25103744.0,25136896.0,25170048.0,25203200.0,25236352.0,25269504.0,25302656.0,25335808.0,25152128.0,25185344.0,25218560.0,25251776.0,25284992.0,25318208.0,25351424.0,25384640.0,25200512.0,25233792.0,25267072.0,25300352.0,25333632.0,25366912.0,25400192.0,25433472.0,25248896.0,25282240.0,25315584.0,25348928.0,25382272.0,25415616.0,25448960.0,25482304.0,25297280.0,25330688.0,25364096.0,25397504.0,25430912.0,25464320.0,25497728.0,25531136.0,25345664.0,25379136.0,25412608.0,25446080.0,25479552.0,25513024.0,25546496.0,25579968.0,25394048.0,25427584.0,25461120.0,25494656.0,25528192.0,25561728.0,25595264.0,25628800.0,25442432.0,25476032.0,25509632.0,25543232.0,25576832.0,25610432.0,25644032.0,25677632.0,25490816.0,25524480.0,25558144.0,25591808.0,25625472.0,25659136.0,25692800.0,25726464.0,25539200.0,25572928.0,25606656.0,25640384.0,25674112.0,25707840.0,25741568.0,25775296.0,24555392.0,24588160.0,24620928.0,24653696.0,24686464.0,24719232.0,24752000.0,24784768.0,23563392.0,23595200.0,23627008.0,23658816.0,23690624.0,23722432.0,23754240.0,23786048.0,11857536.0,11873456.0,11889376.0,11905296.0,11921216.0,11937136.0,11953056.0,11968976.0};\n\n    printf(\"\\nPerforming test:\\n\");\n\n    bool passed = true;\n    for(int i = 0; i < n_conv_transpose_1d_test_0; i++) {\n        if(\n            conv1d_transpose_data_0[i] != expected_conv1d_0[i]) {\n            std::cout << \"index: \" << i << std::endl;\n            std::cout << \"expected: \" << expected_conv1d_0[i] << std::endl;\n            std::cout << \"actual: \" << conv1d_transpose_data_0[i] << std::endl;\n            passed = false;\n            break;\n        }\n    }\n\n    printf(\"ggml_conv_1d_transpose (%d): %s\\n\", (int) ggml_nelements(conv1d_transpose_res_0), passed && 
(ggml_nelements(conv1d_transpose_res_0) == n_conv_transpose_1d_test_0) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    passed = true;\n    for(int i = 0; i < n_conv_transpose_1d_test_1; i++) {\n        if(\n            conv1d_transpose_data_1[i] != expected_conv1d_1[i]) {\n            std::cout << \"index: \" << i << std::endl;\n            std::cout << \"expected: \" << expected_conv1d_1[i] << std::endl;\n            std::cout << \"actual: \" << conv1d_transpose_data_1[i] << std::endl;\n            passed = false;\n        }\n    }\n\n    printf(\"ggml_conv_1d_transpose (%d): %s\\n\", (int) ggml_nelements(conv1d_transpose_res_1), passed && (ggml_nelements(conv1d_transpose_res_1) == n_conv_transpose_1d_test_1) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    passed = true;\n    for(int i = 0; i < n_conv_transpose_1d_test_2; i++) {\n        if(\n            conv1d_transpose_data_2[i] != expected_conv1d_2[i]) {\n            std::cout << \"index: \" << i << std::endl;\n            std::cout << \"expected: \" << expected_conv1d_2[i] << std::endl;\n            std::cout << \"actual: \" << conv1d_transpose_data_2[i] << std::endl;\n            passed = false;\n        }\n    }\n\n    printf(\"ggml_conv_1d_transpose (%d): %s\\n\", (int) ggml_nelements(conv1d_transpose_res_2), passed && (ggml_nelements(conv1d_transpose_res_2) == n_conv_transpose_1d_test_2) ? 
\"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n\n    passed = true;\n    for(int i = 0; i < n_conv_transpose_1d_test_3; i++) {\n        if(\n            conv1d_transpose_data_3[i] != expected_conv1d_3[i]) {\n            std::cout << \"index: \" << i << std::endl;\n            std::cout << \"expected: \" << expected_conv1d_3[i] << std::endl;\n            std::cout << \"actual: \" << conv1d_transpose_data_3[i] << std::endl;\n            passed = false;\n        }\n    }\n\n    printf(\"ggml_conv_1d_transpose (%d): %s\\n\", (int) ggml_nelements(conv1d_transpose_res_3), passed && (ggml_nelements(conv1d_transpose_res_3) == n_conv_transpose_1d_test_3) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    passed = true;\n    for(int i = 0; i < n_conv_transpose_1d_test_4; i++) {\n        if(\n            conv1d_transpose_data_4[i] != expected_conv1d_4[i]) {\n            std::cout << \"index: \" << i << std::endl;\n            std::cout << \"expected: \" << expected_conv1d_4[i] << std::endl;\n            std::cout << \"actual: \" << conv1d_transpose_data_4[i] << std::endl;\n            passed = false;\n        }\n    }\n\n    printf(\"ggml_conv_1d_transpose (%d): %s\\n\", (int) ggml_nelements(conv1d_transpose_res_4), passed && (ggml_nelements(conv1d_transpose_res_4) == n_conv_transpose_1d_test_4) ? 
\"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    passed = true;\n    for(int i = 0; i < n_conv_transpose_1d_test_5; i++) {\n        if(\n            conv1d_transpose_data_5[i] != expected_conv1d_5[i]) {\n            std::cout << \"index: \" << i << std::endl;\n            std::cout << \"expected: \" << expected_conv1d_5[i] << std::endl;\n            std::cout << \"actual: \" << conv1d_transpose_data_5[i] << std::endl;\n            passed = false;\n        }\n    }\n\n    printf(\"ggml_conv_1d_transpose (%d): %s\\n\", (int) ggml_nelements(conv1d_transpose_res_5), passed && (ggml_nelements(conv1d_transpose_res_5) == n_conv_transpose_1d_test_5) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    passed = true;\n    for(int i = 0; i < n_conv_transpose_1d_test_6; i++) {\n        if(\n            conv1d_transpose_data_6[i] != expected_conv1d_6[i]) {\n            std::cout << \"index: \" << i << std::endl;\n            std::cout << \"expected: \" << expected_conv1d_6[i] << std::endl;\n            std::cout << \"actual: \" << conv1d_transpose_data_6[i] << std::endl;\n            passed = false;\n        }\n    }\n\n\n    printf(\"ggml_conv_1d_transpose (%d): %s\\n\", (int) ggml_nelements(conv1d_transpose_res_6), passed && (ggml_nelements(conv1d_transpose_res_6) == n_conv_transpose_1d_test_6) ? 
\"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n\n    passed = true;\n    for(int i = 0; i < n_conv_transpose_1d_test_7; i++) {\n        if(\n            fabs(conv1d_transpose_data_7[i] - expected_conv1d_7[i])/fabs(expected_conv1d_7[i]) > .000001) {\n            std::cout << \"index: \" << i << std::endl;\n            std::cout << \"expected: \" << expected_conv1d_7[i] << std::endl;\n            std::cout << \"actual: \" << conv1d_transpose_data_7[i] << std::endl;\n            passed = false;\n        }\n    }\n\n\n    printf(\"ggml_conv_1d_transpose (%d): %s\\n\", (int) ggml_nelements(conv1d_transpose_res_7), passed && (ggml_nelements(conv1d_transpose_res_7) == n_conv_transpose_1d_test_7) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n\n    ggml_free(model.ctx);\n\n    ggml_backend_buffer_free(model.buffer);\n    ggml_backend_free(model.backend);\n    ggml_gallocr_free(allocr);\n\n    if (passed) {\n        return 0;\n    }\n    return 1;\n}\n"
  },
  {
    "path": "tests/test-conv-transpose.c",
    "content": "// Tests ggml_conv_transpose_1d and ggml_conv_transpose_2d_p0 on the CPU with strides 1, 2 and 3.\n// A mismatch aborts through GGML_ASSERT.\n\n#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#include <string.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n// Creates a 2 MB context. mem_buffer and no_alloc are left zero-initialized, so tensors created in it\n// get their own data storage (the tests memcpy straight into t->data).\nstruct ggml_context* make_ctx(void) {\n    struct ggml_init_params params = {\n        .mem_size = 2 * 1024 * 1024,\n    };\n\n    return ggml_init(params);\n}\n\n// Debug helper: prints an F32 or F16 tensor as ne[2] blocks of ne[1] rows of ne[0] values.\n// The data is indexed densely, so this assumes a contiguous tensor and ignores dimension 3.\nvoid printf_tensor(struct ggml_tensor * t) {\n    if (t->type == GGML_TYPE_F32) {\n        const float * t_d = ggml_get_data_f32(t);\n        for (int i = 0; i < t->ne[2]; ++i) {\n            for (int j = 0; j < t->ne[1]; ++j) {\n                for (int k = 0; k < t->ne[0]; ++k) {\n                    printf(\"%.1f \", t_d[i * t->ne[1] * t->ne[0] + j * t->ne[0] + k]);\n                }\n                printf(\"\\n\");\n            }\n            printf(\"---\\n\");\n        }\n    }\n    else if (t->type == GGML_TYPE_F16) {\n        const ggml_fp16_t * t_d = ggml_get_data(t);\n        for (int i = 0; i < t->ne[2]; ++i) {\n            for (int j = 0; j < t->ne[1]; ++j) {\n                for (int k = 0; k < t->ne[0]; ++k) {\n                    printf(\"%.1f \", ggml_fp16_to_fp32(t_d[i * t->ne[1] * t->ne[0] + j * t->ne[0] + k]));\n                }\n                printf(\"\\n\");\n            }\n            printf(\"---\\n\");\n        }\n    }\n    else {\n        printf(\"unknown type\\n\");\n    }\n}\n\n// Asserts that the F32 tensor t has shape ne0 x ne1 x ne2 and equals expected_t_d element by element\n// (densely packed, ne0 fastest). Values are compared exactly: every input and kernel value is a small\n// integer, which F16 and F32 represent exactly.\nvoid check_tensor(struct ggml_tensor * t, const float * expected_t_d, int ne0, int ne1, int ne2) {\n    GGML_ASSERT(t->type == GGML_TYPE_F32);\n    GGML_ASSERT(t->ne[0] == ne0);\n    GGML_ASSERT(t->ne[1] == ne1);\n    GGML_ASSERT(t->ne[2] == ne2);\n    for (int i2 = 0; i2 < ne2; ++i2) {\n        for (int i1 = 0; i1 < ne1; ++i1) {\n            for (int i0 = 0; i0 < ne0; ++i0) {\n                float expected = *(expected_t_d + i2 * ne1 * ne0 + i1 * ne0 + i0);\n                float actual = ggml_get_data_f32(t)[i2 * ne1 * ne0 + i1 * ne0 + i0];\n                if (expected != actual) {\n                    printf(\"expected %.1f, got %.1f\\n\", expected, actual);\n                }\n                GGML_ASSERT(expected == actual);\n            }\n        }\n    }\n}\n\nvoid test_conv_transpose_1d(void) {\n\n    float buf_f32[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf_f32[i] = (float)i;\n    }\n\n    ggml_fp16_t buf_f16[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf_f16[i] = ggml_fp32_to_fp16((float)i);\n    }\n\n    float expected_out_1[3][4] = {\n        {18.0, 45.0, 59.0, 37.0},\n        {24.0, 61.0, 83.0, 51.0},\n        {30.0, 77.0, 107.0, 65.0},\n    };\n    float expected_out_2[3][6] = {\n        {18.0, 21.0, 24.0, 29.0, 30.0, 37.0},\n        {24.0, 27.0, 34.0, 39.0, 44.0, 51.0},\n        {30.0, 33.0, 44.0, 49.0, 58.0, 65.0},\n    };\n    float expected_out_3[3][8] = {\n        {18.0, 21.0, 0.0, 24.0, 29.0, 0.0, 30.0, 37.0},\n        {24.0, 27.0, 0.0, 34.0, 39.0, 0.0, 44.0, 51.0},\n        {30.0, 33.0, 0.0, 44.0, 49.0, 0.0, 58.0, 65.0},\n    };\n\n    // conv transpose 1d with stride 1, 2 & 3; the output length is (l - 1) * s0 + k = 4, 6, 8\n    {\n        struct ggml_context * ctx = make_ctx();\n\n        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2); // l x cin\n        memcpy(t->data, buf_f32, ggml_nbytes(t));\n\n        struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 2, 3, 2); // k x cout x cin\n        memcpy(k->data, buf_f16, ggml_nbytes(k));\n\n        struct ggml_tensor * out_1 = ggml_conv_transpose_1d(ctx, k, t, 1 /* s0 */, 0 /* p0 */, 1 /* d0 */);\n        struct ggml_tensor * out_2 = ggml_conv_transpose_1d(ctx, k, t, 2 /* s0 */, 0 /* p0 */, 1 /* d0 */);\n        struct ggml_tensor * out_3 = ggml_conv_transpose_1d(ctx, k, t, 3 /* s0 */, 0 /* p0 */, 1 /* d0 */);\n\n        struct ggml_cgraph * gf_1 = ggml_new_graph(ctx);\n        struct ggml_cgraph * gf_2 = ggml_new_graph(ctx);\n        struct ggml_cgraph * gf_3 = ggml_new_graph(ctx);\n\n        ggml_build_forward_expand(gf_1, out_1);\n        ggml_build_forward_expand(gf_2, out_2);\n        ggml_build_forward_expand(gf_3, out_3);\n\n        ggml_graph_compute_with_ctx(ctx, gf_1, 1);\n        ggml_graph_compute_with_ctx(ctx, gf_2, 1);\n        ggml_graph_compute_with_ctx(ctx, gf_3, 1);\n\n        check_tensor(out_1, (float*)expected_out_1, 4, 3, 1);\n        check_tensor(out_2, (float*)expected_out_2, 6, 3, 1);\n        check_tensor(out_3, (float*)expected_out_3, 8, 3, 1);\n\n        ggml_free(ctx);\n    }\n}\n\nvoid test_conv_transpose_2d(void) {\n\n    float buf_f32[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf_f32[i] = (float)i;\n    }\n\n    ggml_fp16_t buf_f16[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf_f16[i] = ggml_fp32_to_fp16((float)i);\n    }\n\n    float expected_out_1[3][3][4] = {\n        {\n            {72.0, 162.0, 188.0, 106.0},\n            {192.0, 430.0, 490.0, 274.0},\n            {132.0, 292.0, 326.0, 180.0},\n        },\n        {\n            {96.0, 218.0, 260.0, 146.0},\n            {264.0, 590.0, 682.0, 378.0},\n            {180.0, 396.0, 446.0, 244.0},\n        },\n        {\n            {120.0, 274.0, 332.0, 186.0},\n            {336.0, 750.0, 874.0, 482.0},\n            {228.0, 500.0, 566.0, 308.0},\n        },\n    };\n\n    float expected_out_2[3][4][6] = {\n        {\n            {72.0, 78.0, 84.0, 92.0, 96.0, 106.0},\n            {84.0, 90.0, 100.0, 108.0, 116.0, 126.0},\n            {108.0, 120.0, 120.0, 134.0, 132.0, 148.0},\n            {132.0, 144.0, 148.0, 162.0, 164.0, 180.0},\n        },\n        {\n            {96.0, 102.0, 116.0, 124.0, 136.0, 146.0},\n            {108.0, 114.0, 132.0, 140.0, 156.0, 166.0},\n            {156.0, 168.0, 176.0, 190.0, 196.0, 212.0},\n            {180.0, 192.0, 204.0, 218.0, 228.0, 244.0},\n        },\n        {\n            {120.0, 126.0, 148.0, 156.0, 176.0, 186.0},\n            {132.0, 138.0, 164.0, 172.0, 196.0, 206.0},\n            {204.0, 216.0, 232.0, 246.0, 260.0, 276.0},\n            {228.0, 240.0, 260.0, 274.0, 292.0, 308.0},\n        },\n    };\n\n    // with stride 3 the 2x2 kernel footprints no longer touch: columns 2 and 5 and row 2 stay zero\n    float expected_out_3[3][5][8] = {\n        {\n            {72.0, 78.0, 0.0, 84.0, 92.0, 0.0, 96.0, 106.0},\n            {84.0, 90.0, 0.0, 100.0, 108.0, 0.0, 116.0, 126.0},\n            {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},\n            {108.0, 120.0, 0.0, 120.0, 134.0, 0.0, 132.0, 148.0},\n            {132.0, 144.0, 0.0, 148.0, 162.0, 0.0, 164.0, 180.0},\n        },\n        {\n            {96.0, 102.0, 0.0, 116.0, 124.0, 0.0, 136.0, 146.0},\n            {108.0, 114.0, 0.0, 132.0, 140.0, 0.0, 156.0, 166.0},\n            {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},\n            {156.0, 168.0, 0.0, 176.0, 190.0, 0.0, 196.0, 212.0},\n            {180.0, 192.0, 0.0, 204.0, 218.0, 0.0, 228.0, 244.0},\n        },\n        {\n            {120.0, 126.0, 0.0, 148.0, 156.0, 0.0, 176.0, 186.0},\n            {132.0, 138.0, 0.0, 164.0, 172.0, 0.0, 196.0, 206.0},\n            {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},\n            {204.0, 216.0, 0.0, 232.0, 246.0, 0.0, 260.0, 276.0},\n            {228.0, 240.0, 0.0, 260.0, 274.0, 0.0, 292.0, 308.0},\n        },\n    };\n\n    // conv transpose 2d with stride 1, 2 & 3\n    {\n        struct ggml_context * ctx = make_ctx();\n\n        struct ggml_tensor * t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 2, 2, 1); // w x h x cin\n        memcpy(t->data, buf_f32, ggml_nbytes(t));\n\n        // w x h x cout x cin: ne[2] = 3 is the output channel count, ne[3] = 2 matches the cin of t\n        struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, 3, 2);\n        memcpy(k->data, buf_f16, ggml_nbytes(k));\n\n        struct ggml_tensor * out_1 = ggml_conv_transpose_2d_p0(ctx, k, t, 1);\n        struct ggml_tensor * out_2 = ggml_conv_transpose_2d_p0(ctx, k, t, 2);\n        struct ggml_tensor * out_3 = ggml_conv_transpose_2d_p0(ctx, k, t, 3);\n\n        struct ggml_cgraph * gf_1 = ggml_new_graph(ctx);\n        struct ggml_cgraph * gf_2 = ggml_new_graph(ctx);\n        struct ggml_cgraph * gf_3 = ggml_new_graph(ctx);\n\n        ggml_build_forward_expand(gf_1, out_1);\n        ggml_build_forward_expand(gf_2, out_2);\n        ggml_build_forward_expand(gf_3, out_3);\n\n        ggml_graph_compute_with_ctx(ctx, gf_1, 1);\n        ggml_graph_compute_with_ctx(ctx, gf_2, 1);\n        ggml_graph_compute_with_ctx(ctx, gf_3, 1);\n\n        // printf(\"in\\n\");\n        // printf_tensor(t);\n        // printf(\"\\n\\nkernel\\n\");\n        // printf_tensor(k);\n        // printf(\"\\n\\nout_1\\n\");\n        // printf_tensor(out_1);\n        // printf(\"\\n\\nout_2\\n\");\n        // printf_tensor(out_2);\n        // printf(\"\\n\\nout_3\\n\");\n        // printf_tensor(out_3);\n\n        check_tensor(out_1, (float*)expected_out_1, 4, 3, 3);\n        check_tensor(out_2, (float*)expected_out_2, 6, 4, 3);\n        check_tensor(out_3, (float*)expected_out_3, 8, 5, 3);\n\n        ggml_free(ctx);\n    }\n}\n\nint main(int argc, const char * argv[]) {\n    (void) argc;\n    (void) argv;\n\n    test_conv_transpose_1d();\n    test_conv_transpose_2d();\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-conv1d-dw-c1.cpp",
    "content": "// Tests ggml_conv_1d_dw (depthwise 1D convolution) with stride 1, padding 1 and dilation 1\n// on a 2-channel, 6-sample input; exits with a non-zero status if the result does not match.\n\n#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\nstruct test_model {\n    struct ggml_tensor * weight;\n    struct ggml_tensor * input;\n    ggml_backend_t backend = NULL;\n    ggml_backend_buffer_t buffer;\n    struct ggml_context * ctx;\n};\n\nvoid load_model(test_model & model, bool use_gpu = false) {\n    // create data: kernel size K, IC input channels (a depthwise conv has one filter per channel,\n    // so the output also has IC channels), IL samples per channel, N batches\n    int K = 3, IC = 2;\n    int IL = 6, N = 1;\n\n    // Depthwise weights: one K-tap filter per input channel\n    float weight_data[6] = {10.0f, 20.0f, 30.0f, 0.1f, 0.2f, 0.3f};\n\n    // Convert the weights to fp16 format\n    std::vector<ggml_fp16_t> h_weight_data(K * IC);\n    ggml_fp32_to_fp16_row(weight_data, h_weight_data.data(), K * IC);\n\n    // Initialize input data, 2 channels, 6 timesteps, 1 batch\n    float input_data[12] = {\n        1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,\n        1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,\n    };\n\n    size_t buffer_size = 0;\n    {\n        buffer_size += K * IC * ggml_type_size(GGML_TYPE_F16); // tensor weight\n        buffer_size += IL * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor input\n        buffer_size += 1024; // overhead\n    }\n\n    printf(\"%s: ggml tensor size    = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n    printf(\"%s: backend buffer size = %0.2f MB\\n\", __func__, (buffer_size/ 1024.f/ 1024.f));\n\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    int num_tensors = 2;\n    struct ggml_init_params params {\n            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n    };\n\n    // initialize the backend\n#ifdef GGML_USE_CUDA\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        model.backend = ggml_backend_cuda_init(0);\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        model.backend = ggml_backend_metal_init();\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n    if(!model.backend) {\n        // fallback to CPU backend\n        model.backend = ggml_backend_cpu_init();\n    }\n\n    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);\n\n    // create context\n    model.ctx = ggml_init(params);\n\n    // create tensors\n    // A Pytorch grouped Conv1d weight parameter is of shape (out_channels, input_channels/groups, kernel_size);\n    // ggml lists the dimensions in reverse order, hence (K, 1, IC)\n    model.weight = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16,  K, 1, IC);\n    model.input = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, IL, IC, N);\n\n    // create an allocator\n    ggml_tallocr alloc = ggml_tallocr_new(model.buffer);\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.weight);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.weight->data, h_weight_data.data(), ggml_nbytes(model.weight));\n    } else {\n        ggml_backend_tensor_set(model.weight, h_weight_data.data(), 0, ggml_nbytes(model.weight));\n    }\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.input);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.input->data, input_data, ggml_nbytes(model.input));\n    } else {\n        ggml_backend_tensor_set(model.input, input_data, 0, ggml_nbytes(model.input));\n    }\n}\n\nstruct ggml_cgraph * build_graph(const test_model& model) {\n    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params0 = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    // create a temporary context to build the graph\n    struct ggml_context * ctx0 = ggml_init(params0);\n\n    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);\n\n    int s0 = 1;\n    int p0 = 1;\n    int d0 = 1;\n\n    struct ggml_tensor* conv1d_dw_res = ggml_conv_1d_dw(ctx0, model.weight, model.input, s0, p0, d0);\n    ggml_set_name(conv1d_dw_res, \"conv1d_dw_res\");\n    ggml_build_forward_expand(gf, conv1d_dw_res);\n\n    // delete the temporary context used to build the graph\n    ggml_free(ctx0);\n    return gf;\n}\n\nstruct ggml_cgraph* compute_graph(const test_model & model, ggml_gallocr_t allocr) {\n    struct ggml_cgraph * gf = build_graph(model);\n\n    // allocate tensors\n    ggml_gallocr_alloc_graph(allocr, gf);\n    int n_threads = 1;\n\n    if (ggml_backend_is_cpu(model.backend)) {\n        ggml_backend_cpu_set_n_threads(model.backend, n_threads);\n    }\n\n    ggml_backend_graph_compute(model.backend, gf);\n\n    //ggml_graph_print(gf);\n\n    return gf;\n}\n\nint main(void)\n{\n    ggml_time_init();\n\n    test_model model;\n    load_model(model, true);\n\n    ggml_gallocr_t allocr = NULL;\n\n    {\n        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n\n        //create the worst case graph for memory usage estimation\n        struct ggml_cgraph * gf = build_graph(model);\n\n        // compute the required memory\n        ggml_gallocr_reserve(allocr, gf);\n        size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);\n        fprintf(stderr, \"%s: compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0f/1024.0f);\n    }\n\n    struct ggml_cgraph * gf_res = compute_graph(model, allocr);\n\n    struct ggml_tensor * conv1d_dw_res = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n        if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_dw_res\") == 0) {\n            conv1d_dw_res = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    // s0 = 1, p0 = 1, d0 = 1 over an all-ones input: interior outputs see all three taps\n    // (10 + 20 + 30 = 60) while the first/last outputs lose one tap to the zero padding (50 / 30).\n    // Channel 1 repeats the pattern with the weights 0.1, 0.2, 0.3.\n    const int n_conv1d_dw_test = 12;\n\n    float expected_conv1d_dw[n_conv1d_dw_test] = {\n        50.0f, 60.0f, 60.0f, 60.0f, 60.0f, 30.0f, 0.50f,  0.60f,  0.60f,  0.60f,  0.60f,  0.30f\n    };\n\n    printf(\"\\nPerforming test:\\n\");\n\n    // A missing node or a wrong output size is a failure; checking it first also keeps the\n    // comparison loop below from reading past the end of the result.\n    const int n_res = conv1d_dw_res ? (int) ggml_nelements(conv1d_dw_res) : 0;\n    bool passed = n_res == n_conv1d_dw_test;\n\n    if (passed) {\n        std::vector<float> conv1d_dw_data(n_res);\n        ggml_backend_tensor_get(conv1d_dw_res, conv1d_dw_data.data(), 0, ggml_nbytes(conv1d_dw_res));\n\n        for(int i = 0; i < n_conv1d_dw_test; i++) {\n            // the weights are stored as fp16, so compare with a tolerance instead of exactly\n            if(std::abs(conv1d_dw_data[i] - expected_conv1d_dw[i]) > 1e-4) {\n                passed = false;\n                break;\n            }\n        }\n    }\n\n    printf(\"ggml_conv1d (%d): %s\\n\", n_res, passed ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    ggml_free(model.ctx);\n\n    ggml_gallocr_free(allocr);\n    ggml_backend_buffer_free(model.buffer);\n    ggml_backend_free(model.backend);\n\n    // report the result through the exit code so ctest/CI notices a failure\n    return passed ? 0 : 1;\n}\n"
  },
  {
    "path": "tests/test-conv1d-dw-c2.cpp",
    "content": "// Tests ggml_conv_1d_dw (depthwise 1D convolution) with stride 3, padding 0 and dilation 1\n// on a 2-channel, 6-sample input; exits with a non-zero status if the result does not match.\n\n#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\nstruct test_model {\n    struct ggml_tensor * weight;\n    struct ggml_tensor * input;\n    ggml_backend_t backend = NULL;\n    ggml_backend_buffer_t buffer;\n    struct ggml_context * ctx;\n};\n\nvoid load_model(test_model & model, bool use_gpu = false) {\n    // create data: kernel size K, IC input channels (a depthwise conv has one filter per channel,\n    // so the output also has IC channels), IL samples per channel, N batches\n    int K = 3, IC = 2;\n    int IL = 6, N = 1;\n\n    // Depthwise weights: one K-tap filter per input channel\n    float weight_data[6] = {10.0f, 20.0f, 30.0f, 0.1f, 0.2f, 0.3f};\n\n    // Convert the weights to fp16 format\n    std::vector<ggml_fp16_t> h_weight_data(K * IC);\n    ggml_fp32_to_fp16_row(weight_data, h_weight_data.data(), K * IC);\n\n    // Initialize input data, 2 channels, 6 timesteps, 1 batch\n    float input_data[12] = {\n        1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,\n        1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,\n    };\n\n    size_t buffer_size = 0;\n    {\n        buffer_size += K * IC * ggml_type_size(GGML_TYPE_F16); // tensor weight\n        buffer_size += IL * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor input\n        buffer_size += 1024; // overhead\n    }\n\n    printf(\"%s: ggml tensor size    = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n    printf(\"%s: backend buffer size = %0.2f MB\\n\", __func__, (buffer_size/ 1024.f/ 1024.f));\n\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    int num_tensors = 2;\n    struct ggml_init_params params {\n            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n    };\n\n    // initialize the backend\n#ifdef GGML_USE_CUDA\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        model.backend = ggml_backend_cuda_init(0);\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        model.backend = ggml_backend_metal_init();\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n    if(!model.backend) {\n        // fallback to CPU backend\n        model.backend = ggml_backend_cpu_init();\n    }\n\n    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);\n\n    // create context\n    model.ctx = ggml_init(params);\n\n    // create tensors\n    // A Pytorch grouped Conv1d weight parameter is of shape (out_channels, input_channels/groups, kernel_size);\n    // ggml lists the dimensions in reverse order, hence (K, 1, IC)\n    model.weight = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16,  K, 1, IC);\n    model.input = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, IL, IC, N);\n\n    // create an allocator\n    ggml_tallocr alloc = ggml_tallocr_new(model.buffer);\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.weight);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.weight->data, h_weight_data.data(), ggml_nbytes(model.weight));\n    } else {\n        ggml_backend_tensor_set(model.weight, h_weight_data.data(), 0, ggml_nbytes(model.weight));\n    }\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.input);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.input->data, input_data, ggml_nbytes(model.input));\n    } else {\n        ggml_backend_tensor_set(model.input, input_data, 0, ggml_nbytes(model.input));\n    }\n}\n\nstruct ggml_cgraph * build_graph(const test_model& model) {\n    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params0 = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    // create a temporary context to build the graph\n    struct ggml_context * ctx0 = ggml_init(params0);\n\n    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);\n\n    int s0 = 3;\n    int p0 = 0;\n    int d0 = 1;\n\n    struct ggml_tensor* conv1d_dw_res = ggml_conv_1d_dw(ctx0, model.weight, model.input, s0, p0, d0);\n    ggml_set_name(conv1d_dw_res, \"conv1d_dw_res\");\n    ggml_build_forward_expand(gf, conv1d_dw_res);\n\n    // delete the temporary context used to build the graph\n    ggml_free(ctx0);\n    return gf;\n}\n\nstruct ggml_cgraph* compute_graph(const test_model & model, ggml_gallocr_t allocr) {\n    struct ggml_cgraph * gf = build_graph(model);\n\n    // allocate tensors\n    ggml_gallocr_alloc_graph(allocr, gf);\n    int n_threads = 1;\n\n    if (ggml_backend_is_cpu(model.backend)) {\n        ggml_backend_cpu_set_n_threads(model.backend, n_threads);\n    }\n\n    ggml_backend_graph_compute(model.backend, gf);\n\n    //ggml_graph_print(gf);\n\n    return gf;\n}\n\nint main(void)\n{\n    ggml_time_init();\n\n    test_model model;\n    load_model(model, true);\n\n    ggml_gallocr_t allocr = NULL;\n\n    {\n        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n\n        //create the worst case graph for memory usage estimation\n        struct ggml_cgraph * gf = build_graph(model);\n\n        // compute the required memory\n        ggml_gallocr_reserve(allocr, gf);\n        size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);\n        fprintf(stderr, \"%s: compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0f/1024.0f);\n    }\n\n    struct ggml_cgraph * gf_res = compute_graph(model, allocr);\n\n    struct ggml_tensor * conv1d_dw_res = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n        if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_dw_res\") == 0) {\n            conv1d_dw_res = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    // s0 = 3, p0 = 0, d0 = 1 splits the 6 all-ones samples into two non-overlapping windows of 3,\n    // so every output sums all taps: 10 + 20 + 30 = 60 (channel 0) and 0.1 + 0.2 + 0.3 = 0.6 (channel 1).\n    const int n_conv1d_dw_test = 4;\n\n    float expected_conv1d_dw[n_conv1d_dw_test] = {\n        60.0f, 60.0f, 0.6f, 0.6f\n    };\n\n    printf(\"\\nPerforming test:\\n\");\n\n    // A missing node or a wrong output size is a failure; checking it first also keeps the\n    // comparison loop below from reading past the end of the result.\n    const int n_res = conv1d_dw_res ? (int) ggml_nelements(conv1d_dw_res) : 0;\n    bool passed = n_res == n_conv1d_dw_test;\n\n    if (passed) {\n        std::vector<float> conv1d_dw_data(n_res);\n        ggml_backend_tensor_get(conv1d_dw_res, conv1d_dw_data.data(), 0, ggml_nbytes(conv1d_dw_res));\n\n        for(int i = 0; i < n_conv1d_dw_test; i++) {\n            // the weights are stored as fp16, so compare with a tolerance instead of exactly\n            if(std::abs(conv1d_dw_data[i] - expected_conv1d_dw[i]) > 1e-4) {\n                passed = false;\n                break;\n            }\n        }\n    }\n\n    printf(\"ggml_conv1d (%d): %s\\n\", n_res, passed ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    ggml_free(model.ctx);\n\n    ggml_gallocr_free(allocr);\n    ggml_backend_buffer_free(model.buffer);\n    ggml_backend_free(model.backend);\n\n    // report the result through the exit code so ctest/CI notices a failure\n    return passed ? 0 : 1;\n}\n"
  },
  {
    "path": "tests/test-conv1d.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\nstruct test_model {\n    struct ggml_tensor * a;\n    struct ggml_tensor * b;\n    ggml_backend_t backend = NULL;\n    ggml_backend_buffer_t buffer;\n    struct ggml_context * ctx;\n};\n\nvoid load_model(test_model & model, bool use_gpu = false) {\n    // create data\n    int K = 3, IC = 10, OC = 10;\n    int IL = 8, N = 1;\n\n    // Initialize adata\n    std::vector<float> adata(K * IC * OC);\n    for (int i = 0; i < K * IC * OC; i++) {\n        adata[i] = 4.5f;\n    }\n\n    // Convert adata to fp16 format\n    std::vector<ggml_fp16_t> hadata(K * IC * OC);\n    ggml_fp32_to_fp16_row(adata.data(), hadata.data(), K * IC * OC);\n\n    // Initialize bdata\n    std::vector<float> bdata(IL * IC * N);\n    for (int i = 0; i < IL * IC * N; i++) {\n        bdata[i] = 2.5f;\n    }\n\n    size_t buffer_size = 0;\n    {\n        buffer_size += K * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a\n        buffer_size += IL * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b\n        buffer_size += 1024; // overhead\n    }\n\n    printf(\"%s: ggml tensor size    = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n    printf(\"%s: backend buffer size = %0.2f MB\\n\", __func__, (buffer_size/ 1024.f/ 1024.f));\n\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    int num_tensors = 2;\n    struct ggml_init_params params {\n            /*.mem_size   =*/ ggml_tensor_overhead() * 
num_tensors,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n    };\n\n    // initialize the backend\n#ifdef GGML_USE_CUDA\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        model.backend = ggml_backend_cuda_init(0);\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        model.backend = ggml_backend_metal_init();\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n    if(!model.backend) {\n        // fallback to CPU backend\n        model.backend = ggml_backend_cpu_init();\n    }\n\n    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);\n\n    // create context\n    model.ctx = ggml_init(params);\n\n    // create tensors\n    model.a = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16,  K, IC, OC);\n    model.b = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, IL, IC, N);\n\n    // create a allocator\n    ggml_tallocr alloc = ggml_tallocr_new(model.buffer);\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.a);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a));\n    } else {\n        ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a));\n    }\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.b);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b));\n    } else {\n        ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b));\n    }\n}\n\nstruct ggml_cgraph * build_graph(const 
test_model& model) {\n    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params0 = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    // create a temporally context to build the graph\n    struct ggml_context * ctx0 = ggml_init(params0);\n\n    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);\n\n    int s0 = 1;\n    int p0 = 1;\n    int d0 = 1;\n\n    // split conv1d in fundamental methods for test unit\n    struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);\n    ggml_set_name(im2col_0, \"im2col_res\");\n    ggml_build_forward_expand(gf, im2col_0);\n\n    struct ggml_tensor* conv1d_res = ggml_conv_1d(ctx0, model.a, model.b, s0, p0, d0);\n    ggml_set_name(conv1d_res, \"conv1d_res\");\n    ggml_build_forward_expand(gf, conv1d_res);\n\n    // delete the temporally context used to build the graph\n    ggml_free(ctx0);\n    return gf;\n}\n\nstruct ggml_cgraph* compute_graph(const test_model & model, ggml_gallocr_t allocr) {\n    struct ggml_cgraph * gf = build_graph(model);\n\n    // allocate tensors\n    ggml_gallocr_alloc_graph(allocr, gf);\n    int n_threads = 1;\n\n    if (ggml_backend_is_cpu(model.backend)) {\n        ggml_backend_cpu_set_n_threads(model.backend, n_threads);\n    }\n\n    ggml_backend_graph_compute(model.backend, gf);\n\n    //ggml_graph_print(gf);\n\n    return gf;\n}\n\nint main(void)\n{\n    ggml_time_init();\n\n    test_model model;\n    load_model(model, true);\n\n    ggml_gallocr_t allocr = NULL;\n\n    {\n        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n\n        //create the worst case graph for memory usage estimation\n        struct ggml_cgraph * gf = 
build_graph(model);\n\n        // compute the required memory\n        ggml_gallocr_reserve(allocr, gf);\n        size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);\n        fprintf(stderr, \"%s: compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0f/1024.0f);\n    }\n\n    struct ggml_cgraph * gf_res = compute_graph(model, allocr);\n\n     struct ggml_tensor * im2col_res = NULL;\n    struct ggml_tensor * conv1d_res = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); i++) {\n        if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"im2col_res\") == 0) {\n            im2col_res = ggml_graph_node(gf_res, i);\n        } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv1d_res\") == 0) {\n            conv1d_res = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<uint16_t> im2col_data(ggml_nelements(im2col_res));\n    std::vector<float> conv2d_data(ggml_nelements(conv1d_res));\n\n    ggml_backend_tensor_get(im2col_res, im2col_data.data(), 0, ggml_nbytes(im2col_res));\n    ggml_backend_tensor_get(conv1d_res, conv2d_data.data(), 0, ggml_nbytes(conv1d_res));\n\n    const int n_conv1d_test = 80;\n    const int n_im2col_test = 240;\n\n    float expected_conv1d[n_conv1d_test] = {\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 
337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f\n    };\n    // first im2col test\n\n    uint16_t expected_im2col[n_conv1d_test] = {\n        0, 16640, 16640, 0, 16640, 16640, 0, 16640,\n        16640, 0, 16640, 16640, 0, 16640, 16640, 0,\n        16640, 16640, 0, 16640, 16640, 0, 16640, 16640,\n        0, 16640, 16640, 0, 16640, 16640, 16640, 16640,\n        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,\n        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,\n        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,\n        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,\n        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640,\n        16640, 16640, 16640, 16640, 16640, 16640, 16640, 16640\n    };\n\n    printf(\"\\nPerforming test:\\n\");\n\n    bool passed = true;\n    for(int i = 0; i < n_conv1d_test; i++) {\n        if(\n            im2col_data[i] != expected_im2col[i]) {\n            passed = false;\n            break;\n        }\n    }\n\n    printf(\"ggml_im2col (%d): %s\\n\", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    passed = true;\n    for(int i = 0; i < n_conv1d_test; i++) {\n        if(conv2d_data[i] != expected_conv1d[i]) {\n            passed = false;\n            break;\n        }\n    }\n\n    printf(\"ggml_conv1d (%d): %s\\n\", (int) ggml_nelements(conv1d_res), passed && (ggml_nelements(conv1d_res) == n_conv1d_test) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n    ggml_free(model.ctx);\n\n    ggml_backend_buffer_free(model.buffer);\n    ggml_backend_free(model.backend);\n    ggml_gallocr_free(allocr);\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-conv2d-dw.cpp",
    "content": "#include <ggml.h>\n#include <ggml-cpu.h>\n#include <ggml-alloc.h>\n#include <ggml-backend.h>\n#include <ggml-cpp.h>\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <vector>\n\nstd::vector<float> f32_range(int n, float start, float end) {\n    std::vector<float> values(n);\n    float step = (end - start) / n;\n    for (int i = 0; i < n; i++) {\n        values[i] = start + i * step;\n    }\n    return values;\n}\n\n// Most straightforward implementation without any optimizations\nstd::vector<float> conv_2d_dw_reference(\n        int src_w, int src_h, const float * src_data,\n        int knl_w, int knl_h, const float * knl_data,\n        int channels, int batch, int stride, int pad, int dilation) {\n\n    int dst_w = (src_w + 2 * pad - dilation * (knl_w - 1) - 1) / stride + 1;\n    int dst_h = (src_h + 2 * pad - dilation * (knl_h - 1) - 1) / stride + 1;\n    std::vector<float> dst_data(dst_w * dst_h * channels * batch);\n\n    for (int b = 0; b < batch; b++) {\n        const float * src_base = src_data + b * src_w * src_h * channels;\n        float * dst_base = dst_data.data() + b * dst_w * dst_h * channels;\n        for (int c = 0; c < channels; c++) {\n            for (int y = 0; y < dst_h; y++) {\n                for (int x = 0; x < dst_w; x++) {\n                    float sum = 0;\n                    for (int knl_y = 0; knl_y < knl_h; knl_y++) {\n                        for (int knl_x = 0; knl_x < knl_w; knl_x++) {\n                            int src_x = x * stride + knl_x * dilation - pad;\n                            int src_y = y * stride + knl_y * dilation - pad;\n                            if (src_x >= 0 && src_x < src_w && src_y >= 0 && src_y < src_h) {\n                                sum += src_base[c * src_w * src_h + src_y * src_w + src_x] *\n                                       knl_data[c * knl_w * knl_h + knl_y * knl_w + knl_x];\n                            }\n                        }\n     
               }\n                    dst_base[c * dst_w * dst_h + y * dst_w + x] = sum;\n                }\n            }\n        }\n    }\n    return dst_data;\n}\n\nbool check_equal(const std::vector<float> & result, const std::vector<float> & expected) {\n    if (result.size() != expected.size()) {\n        printf(\"result.size() = %d, expected.size() = %d\\n\", (int)result.size(), (int)expected.size());\n        return false;\n    }\n    for (int i = 0; i < result.size(); i++) {\n        if(std::abs(result[i] - expected[i]) > 1e-5) {\n            printf(\"result[%d] %f != %f expected[%d]\\n\", i, result[i], expected[i], i);\n            return false;\n        }\n    }\n    return true;\n}\n\nbool test_conv_2d_dw(\n        int channels,\n        int kernel_size,\n        int stride,\n        int pad,\n        int dilation,\n        bool contiguous_channels) {\n    ggml_time_init();\n\n    const int batch = 2;\n    const int src_w = 8;\n    const int src_h = 6;\n    const int knl_w = kernel_size;\n    const int knl_h = kernel_size;\n\n    ggml_init_params params {\n        /*.mem_size   =*/ 64 * ggml_tensor_overhead() + ggml_graph_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true\n    };\n\n    ggml_context_ptr ctx_ptr{ggml_init(params)};\n    ggml_context * ctx = ctx_ptr.get();\n    ggml_cgraph * gf = ggml_new_graph(ctx);\n\n    // Build graph\n    ggml_tensor * src_input = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, src_w, src_h, channels, batch);\n    ggml_tensor * knl_input = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, knl_w, knl_h, 1, channels);\n    ggml_tensor * src = src_input;\n    ggml_tensor * knl = knl_input;\n    if (contiguous_channels) {\n        // Convert tensor to [C, W, H, N] layout in memory, then permute strides back to [W, H, C, N]\n        src = ggml_cont(ctx, ggml_permute(ctx, src, 1, 2, 0, 3));\n        src = ggml_permute(ctx, src, 2, 0, 1, 3);\n        knl = ggml_cont(ctx, ggml_permute(ctx, knl, 2, 3, 1, 0));\n        
knl = ggml_permute(ctx, knl, 3, 2, 0, 1);\n    }\n    ggml_tensor * res = ggml_conv_2d_dw_direct(\n        ctx, knl, src, stride, stride, pad, pad, dilation, dilation);\n    if (contiguous_channels) {\n        res = ggml_cont(ctx, res);\n    }\n    ggml_build_forward_expand(gf, res);\n\n    // Create backend & allocate buffers\n    ggml_backend_ptr backend_ptr{ggml_backend_cpu_init()};\n    ggml_backend_t backend = backend_ptr.get();\n    ggml_backend_cpu_set_n_threads(backend, 2);\n    ggml_backend_buffer_ptr buffer{ggml_backend_alloc_ctx_tensors(ctx, backend)};\n\n    std::vector<float> src_values = f32_range(ggml_nelements(src), -1.f, 1.f);\n    std::vector<float> knl_values = f32_range(ggml_nelements(knl), -1.f, 1.f);\n    ggml_backend_tensor_set(src_input, src_values.data(), 0, ggml_nbytes(src));\n    ggml_backend_tensor_set(knl_input, knl_values.data(), 0, ggml_nbytes(knl));\n\n    ggml_backend_graph_compute(backend, gf);\n\n    std::vector<float> res_values(ggml_nelements(res));\n    ggml_backend_tensor_get(res, res_values.data(), 0, ggml_nbytes(res));\n\n    std::vector<float> expected = conv_2d_dw_reference(\n        src_w, src_h, src_values.data(),\n        knl_w, knl_h, knl_values.data(),\n        channels, batch, stride, pad, dilation);\n\n    bool passed = check_equal(res_values, expected);\n\n    printf(\"ggml_conv_2d_dw(channels=%d, kernel=%dx%d, stride=%d, pad=%d, dilation=%d, layout=%s): %s\\n\",\n        channels, kernel_size, kernel_size, stride, pad, dilation, contiguous_channels ? \"CWHN\" : \"WHCN\",\n        passed ? 
\"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n    return passed;\n}\n\nint main(int argc, char ** argv) {\n    bool passed = true;\n    passed = test_conv_2d_dw(3, 1, 1, 0, 1, false) && passed;\n    passed = test_conv_2d_dw(3, 1, 1, 0, 1, true) && passed;\n    passed = test_conv_2d_dw(42, 3, 2, 1, 1, false) && passed;\n    passed = test_conv_2d_dw(42, 3, 2, 1, 1, true) && passed;\n    passed = test_conv_2d_dw(8, 5, 1, 2, 2, false) && passed;\n    passed = test_conv_2d_dw(8, 5, 1, 2, 2, true) && passed;\n    return passed ? 0 : 1;\n}\n"
  },
  {
    "path": "tests/test-conv2d.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <map>\n#include <string>\n#include <vector>\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\nstruct test_model {\n    struct ggml_tensor * a;\n    struct ggml_tensor * b;\n    ggml_backend_t backend = NULL;\n    ggml_backend_buffer_t buffer;\n    struct ggml_context * ctx;\n};\n\nvoid load_model(test_model & model, bool use_gpu = false) {\n    // create data\n    int KW = 3, KH = 3, IC = 10, OC = 10;\n    int IW = 8, IH = 6, N = 1;\n\n    // Initialize adata\n    std::vector<float> adata(KW * KH * IC * OC);\n    for (int i = 0; i < KW * KH * IC * OC; i++) {\n        adata[i] = 2.5f;\n    }\n\n    // Convert adata to fp16 format\n    std::vector<ggml_fp16_t> hadata(KW * KH * IC * OC);\n    ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC);\n\n    // Initialize bdata\n    std::vector<float> bdata(IW * IH * IC * N);\n    for (int i = 0; i < IW * IH * IC * N; i++) {\n        bdata[i] = 1.5f;\n    }\n\n    size_t buffer_size = 0;\n    {\n        buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a\n        buffer_size += IW * IH * IC * N  * ggml_type_size(GGML_TYPE_F32); // tensor b\n        buffer_size += 1024; // overhead\n    }\n\n    printf(\"%s: ggml tensor size    = %d bytes\\n\", __func__, (int) sizeof(ggml_tensor));\n    printf(\"%s: backend buffer size = %0.2f MB\\n\", __func__, (buffer_size/ 1024.f/ 1024.f));\n\n    int num_tensors = 2;\n    struct ggml_init_params params {\n            /*.mem_size   =*/ ggml_tensor_overhead() * 
num_tensors,\n            /*.mem_buffer =*/ NULL,\n            /*.no_alloc   =*/ true,\n    };\n\n    ggml_log_set(ggml_log_callback_default, nullptr);\n\n    // initialize the backend\n#ifdef GGML_USE_CUDA\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        model.backend = ggml_backend_cuda_init(0);\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        model.backend = ggml_backend_metal_init();\n        if (!model.backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n    if(!model.backend) {\n        // fallback to CPU backend\n        model.backend = ggml_backend_cpu_init();\n    }\n\n    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);\n\n    // create context\n    model.ctx = ggml_init(params);\n\n    // create tensors\n    model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16,  KW, KH, IC, OC);\n    model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N);\n\n    // create a allocator\n    struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer);\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.a);\n\n    // load data to buffer\n    if(ggml_backend_is_cpu(model.backend)) {\n        memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a));\n    } else {\n        ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a));\n    }\n\n    // alloc memory\n    ggml_tallocr_alloc(&alloc, model.b);\n\n    if(ggml_backend_is_cpu(model.backend)\n#ifdef GGML_USE_METAL\n                || ggml_backend_is_metal(model.backend)\n#endif\n    ) {\n        memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b));\n    } else {\n        ggml_backend_tensor_set(model.b, bdata.data(), 0, 
ggml_nbytes(model.b));\n    }\n}\n\nstruct ggml_cgraph * build_graph(const test_model& model) {\n    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();\n    static std::vector<uint8_t> buf(buf_size);\n\n    struct ggml_init_params params0 = {\n        /*.mem_size   =*/ buf_size,\n        /*.mem_buffer =*/ buf.data(),\n        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()\n    };\n\n    // create a temporally context to build the graph\n    struct ggml_context * ctx0 = ggml_init(params0);\n\n    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);\n\n    int s0 = 1;\n    int s1 = 1;\n    int p0 = 1;\n    int p1 = 1;\n    int d0 = 1;\n    int d1 = 1;\n\n    // split conv2d in fundamental methods for test unit\n    struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16);\n    ggml_set_name(im2col_0, \"im2col_res\");\n    ggml_build_forward_expand(gf, im2col_0);\n\n    // recalculate for avoid fragmentation\n    struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);\n    ggml_set_name(conv2d_res, \"conv2d_res\");\n    ggml_build_forward_expand(gf, conv2d_res);\n\n    ggml_free(ctx0);\n    return gf;\n}\n\nstruct ggml_cgraph * compute_graph(const test_model & model, ggml_gallocr_t allocr) {\n    struct ggml_cgraph * gf = build_graph(model);\n\n    // allocate tensors\n    ggml_gallocr_alloc_graph(allocr, gf);\n    int n_threads = 1;\n\n    if (ggml_backend_is_cpu(model.backend)) {\n        ggml_backend_cpu_set_n_threads(model.backend, n_threads);\n    }\n\n    ggml_backend_graph_compute(model.backend, gf);\n\n    //ggml_graph_print(gf);\n\n    return gf;\n}\n\nint main(void)\n{\n    ggml_time_init();\n\n    test_model model;\n    load_model(model, true);\n\n    ggml_gallocr_t allocr = NULL;\n\n    {\n        allocr = 
ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));\n\n        //create the worst case graph for memory usage estimation\n        struct ggml_cgraph * gf = build_graph(model);\n\n        // compute the required memory\n        ggml_gallocr_reserve(allocr, gf);\n        size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);\n        fprintf(stderr, \"%s: compute buffer size: %.2f MB\\n\", __func__, mem_size/1024.0f/1024.0f);\n    }\n\n    struct ggml_cgraph * gf_res = compute_graph(model, allocr);\n\n    struct ggml_tensor * im2col_res = NULL;\n    struct ggml_tensor * conv2d_res = NULL;\n\n    for(int i = 0; i < ggml_graph_n_nodes(gf_res); ++i) {\n        if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"im2col_res\") == 0) {\n            im2col_res = ggml_graph_node(gf_res, i);\n        } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), \"conv2d_res\") == 0) {\n            conv2d_res = ggml_graph_node(gf_res, i);\n        }\n    }\n\n    std::vector<uint16_t> im2col_data(ggml_nelements(im2col_res));\n    std::vector<float> conv2d_data(ggml_nelements(conv2d_res));\n\n    ggml_backend_tensor_get(im2col_res, im2col_data.data(), 0, ggml_nbytes(im2col_res));\n    ggml_backend_tensor_get(conv2d_res, conv2d_data.data(), 0, ggml_nbytes(conv2d_res));\n\n    const int n_conv2d_test = 480;\n    const int n_im2col_test = 4320;\n\n    float expected_conv2d [n_conv2d_test] = {\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 
337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 
337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,\n        150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f };\n\n    uint16_t expected_im2col[n_conv2d_test] = {\n            0, 0, 0, 0, 15872, 15872, 0, 15872,\n            15872, 0, 0, 0, 0, 15872, 15872, 0,\n            15872, 15872, 0, 0, 0, 0, 15872, 15872,\n            0, 15872, 15872, 0, 0, 0, 0, 15872,\n            15872, 0, 15872, 15872, 0, 0, 0, 0,\n            15872, 15872, 0, 15872, 15872, 0, 0, 0,\n            0, 15872, 15872, 0, 15872, 15872, 0, 0,\n            0, 0, 15872, 15872, 0, 15872, 15872, 0,\n            0, 0, 0, 15872, 15872, 0, 15872, 15872,\n            0, 0, 0, 0, 15872, 15872, 0, 15872,\n            15872, 0, 0, 0, 0, 15872, 15872, 0,\n            15872, 15872, 0, 0, 0, 15872, 15872, 15872,\n            15872, 15872, 15872, 0, 0, 0, 15872, 15872,\n            15872, 15872, 15872, 15872, 0, 0, 0, 15872,\n            15872, 15872, 15872, 15872, 15872, 0, 0, 0,\n            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,\n            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,\n            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,\n            0, 0, 0, 15872, 15872, 15872, 15872, 15872,\n            15872, 0, 0, 0, 15872, 15872, 15872, 15872,\n            15872, 15872, 0, 0, 0, 15872, 15872, 15872,\n            15872, 15872, 15872, 0, 0, 0, 15872, 15872,\n            15872, 15872, 15872, 15872, 0, 0, 0, 15872,\n            15872, 15872, 15872, 15872, 15872, 0, 0, 0,\n            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,\n            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,\n            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,\n            0, 0, 0, 15872, 15872, 15872, 15872, 15872,\n            15872, 0, 0, 0, 15872, 15872, 15872, 15872,\n            15872, 15872, 0, 0, 0, 15872, 15872, 
15872,\n            15872, 15872, 15872, 0, 0, 0, 15872, 15872,\n            15872, 15872, 15872, 15872, 0, 0, 0, 15872,\n            15872, 15872, 15872, 15872, 15872, 0, 0, 0,\n            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,\n            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,\n            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,\n            0, 0, 0, 15872, 15872, 15872, 15872, 15872,\n            15872, 0, 0, 0, 15872, 15872, 15872, 15872,\n            15872, 15872, 0, 0, 0, 15872, 15872, 15872,\n            15872, 15872, 15872, 0, 0, 0, 15872, 15872,\n            15872, 15872, 15872, 15872, 0, 0, 0, 15872,\n            15872, 15872, 15872, 15872, 15872, 0, 0, 0,\n            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,\n            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,\n            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,\n            0, 0, 0, 15872, 15872, 15872, 15872, 15872,\n            15872, 0, 0, 0, 15872, 15872, 15872, 15872,\n            15872, 15872, 0, 0, 0, 15872, 15872, 15872,\n            15872, 15872, 15872, 0, 0, 0, 15872, 15872,\n            15872, 15872, 15872, 15872, 0, 0, 0, 15872,\n            15872, 15872, 15872, 15872, 15872, 0, 0, 0,\n            15872, 15872, 15872, 15872, 15872, 15872, 0, 0,\n            0, 15872, 15872, 15872, 15872, 15872, 15872, 0,\n            0, 0, 15872, 15872, 15872, 15872, 15872, 15872,\n            0, 0, 0, 15872, 15872, 15872, 15872, 15872,\n            15872, 0, 0, 0, 15872, 15872, 15872, 15872,\n            15872, 15872, 0, 0, 0, 15872, 15872, 15872,\n            15872, 15872, 15872, 0, 0, 0, 15872, 15872,\n            15872, 15872, 15872, 15872, 0, 0, 0, 15872,\n            15872, 15872, 15872, 15872, 15872, 0, 0, 0\n    };\n\n    printf(\"\\nPerforming test:\\n\");\n\n    bool passed = true;\n    for(int i = 0; i < n_conv2d_test; i++) {\n        if(\n            im2col_data[i] != expected_im2col[i]) {\n            passed = false;\n            break;\n        
}\n    }\n\n    printf(\"ggml_im2col (%d): %s\\n\", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    passed = true;\n    for(int i = 0; i < n_conv2d_test; i++) {\n        if(conv2d_data[i] != expected_conv2d[i]) {\n            passed = false;\n            break;\n        }\n    }\n\n    printf(\"ggml_conv2d (%d): %s\\n\", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n\n    ggml_free(model.ctx);\n\n    ggml_backend_buffer_free(model.buffer);\n    ggml_backend_free(model.backend);\n    ggml_gallocr_free(allocr);\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-customop.c",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#include <string.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#if defined(_WIN32)\n#include <windows.h>\ntypedef volatile LONG atomic_int;\nstatic LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {\n    return InterlockedExchangeAdd(ptr, inc);\n}\n#else\n#include <stdatomic.h>\n#endif\n\n#define MIN(a, b) ((a) < (b) ? (a) : (b))\n#define MAX(a, b) ((a) > (b) ? (a) : (b))\n\nstruct ggml_context * make_ctx(void) {\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ 1 * 1024 * 1024,\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ false,\n    };\n\n    return ggml_init(params);\n}\n\nchar g_userdata[] = \"ggml\";\natomic_int g_custom1_count = 0;\natomic_int g_custom2_count = 0;\natomic_int g_custom3_count = 0;\n\nvoid custom1(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata) {\n    // check that the userdata is correct\n    GGML_ASSERT(userdata == NULL);\n    GGML_ASSERT(ggml_are_same_shape(dst, a));\n\n    atomic_fetch_add(&g_custom1_count, 1);\n\n    const float * a_data = ggml_get_data_f32(a);\n    float * dst_data = ggml_get_data_f32(dst);\n\n    // this assumes that the tensors are contiguous\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(a));\n\n    // parallelize by elements\n    const int ne = (int)ggml_nelements(dst);\n    const int dr = (ne + nth - 1) / nth;\n    const int ie0 = dr * ith;\n    const int ie1 = MIN(ie0 + dr, ne);\n\n    for (int i = ie0; i < ie1; ++i) {\n        dst_data[i] = a_data[i] * 2;\n    }\n}\n\nvoid custom2(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) {\n    // check that the userdata is correct\n    GGML_ASSERT(userdata == g_userdata);\n    GGML_ASSERT(strcmp(userdata, \"ggml\") == 0);\n    GGML_ASSERT(ggml_are_same_shape(dst, a));\n    GGML_ASSERT(ggml_are_same_shape(dst, b));\n\n    
atomic_fetch_add(&g_custom2_count, 1);\n\n    const float * a_data = ggml_get_data_f32(a);\n    const float * b_data = ggml_get_data_f32(b);\n    float * dst_data = ggml_get_data_f32(dst);\n\n    // parallelize by rows\n    const int nr = (int)ggml_nrows(dst);\n    // number of rows per thread\n    const int dr = (nr + nth - 1) / nth;\n    // row range for this thread\n    const int ir0 = dr * ith;\n    const int ir1 = MIN(ir0 + dr, nr);\n\n    // number of columns\n    const int nc = (int)dst->ne[0];\n\n    // this assumes that the tensors are contiguous\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_is_contiguous(b));\n\n    for (int ir = ir0; ir < ir1; ++ir) {\n        for (int ic = 0; ic < nc; ++ic) {\n            const int i = ir * nc + ic;\n            dst_data[i] = a_data[i] + b_data[i];\n        }\n    }\n}\n\nvoid custom3(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata) {\n    // check that the userdata is correct\n    GGML_ASSERT(userdata == g_userdata);\n    GGML_ASSERT(strcmp(userdata, \"ggml\") == 0);\n    GGML_ASSERT(ggml_are_same_shape(dst, a));\n    GGML_ASSERT(ggml_are_same_shape(dst, b));\n    GGML_ASSERT(ggml_are_same_shape(dst, c));\n\n    atomic_fetch_add(&g_custom3_count, 1);\n\n    const float * a_data = ggml_get_data_f32(a);\n    const float * b_data = ggml_get_data_f32(b);\n    const float * c_data = ggml_get_data_f32(c);\n    float * dst_data = ggml_get_data_f32(dst);\n\n    // dont parallelize\n    GGML_ASSERT(ith == 0);\n\n    // number of elements\n    const int ne = (int)ggml_nelements(dst);\n\n    // this assumes that the tensors are contiguous\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(a));\n    GGML_ASSERT(ggml_is_contiguous(b));\n    GGML_ASSERT(ggml_is_contiguous(c));\n\n    for (int i = 0; i < ne; ++i) {\n        dst_data[i] = 
a_data[i] + b_data[i] + c_data[i];\n    }\n}\n\nvoid custom(struct ggml_tensor * dst, int ith, int nth, void * userdata) {\n    struct ggml_tensor * src0 = dst->src[0];\n    struct ggml_tensor * src1 = dst->src[1];\n    struct ggml_tensor * src2 = dst->src[2];\n    struct ggml_tensor * src3 = dst->src[3];\n    struct ggml_tensor * src4 = dst->src[4];\n\n    int32_t * dst_data = (int32_t *) ggml_get_data(dst);\n    const float * src0_data = ggml_get_data_f32(src0);\n    const float * src1_data = ggml_get_data_f32(src1);\n    const float * src2_data = ggml_get_data_f32(src2);\n    const float * src3_data = ggml_get_data_f32(src3);\n    const float * src4_data = ggml_get_data_f32(src4);\n\n    // check that the userdata is correct\n    GGML_ASSERT(userdata == g_userdata);\n    GGML_ASSERT(strcmp(userdata, \"ggml\") == 0);\n\n    // check that the tensors are contiguous\n    GGML_ASSERT(ggml_is_contiguous(dst));\n    GGML_ASSERT(ggml_is_contiguous(src0));\n    GGML_ASSERT(ggml_is_contiguous(src1));\n    GGML_ASSERT(ggml_is_contiguous(src2));\n    GGML_ASSERT(ggml_is_contiguous(src3));\n    GGML_ASSERT(ggml_is_contiguous(src4));\n\n    // check that the shapes are the same\n    GGML_ASSERT(ggml_are_same_shape(dst, src0));\n    GGML_ASSERT(ggml_are_same_shape(dst, src1));\n    GGML_ASSERT(ggml_are_same_shape(dst, src2));\n    GGML_ASSERT(ggml_are_same_shape(dst, src3));\n    GGML_ASSERT(ggml_are_same_shape(dst, src4));\n\n\n    for (int i = ith; i < ggml_nelements(dst); i += nth) {\n        dst_data[i] = src0_data[i] + src1_data[i] * src2_data[i] - src3_data[i] * src4_data[i];\n    }\n}\n\nint main(int argc, const char** argv) {\n\n    float buf1_f32[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf1_f32[i] = (float)(i + 1);\n    }\n    float buf2_f32[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf2_f32[i] = (float)(i + 1) * 2;\n    }\n    float buf3_f32[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf3_f32[i] = (float)(i + 1) * 3;\n    }\n\n   
 // map_custom1\n    // 2 tasks, no userdata, parallelized by elements\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        memcpy(t->data, buf1_f32, ggml_nbytes(t));\n\n        struct ggml_tensor * m1 = ggml_map_custom1(ctx, t, custom1, 2, NULL);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, m1);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(m1);\n\n        for (int i = 0; i < ggml_nelements(m1); ++i) {\n            GGML_ASSERT(output[i] == buf1_f32[i] * 2);\n        }\n        GGML_ASSERT(g_custom1_count == 2);\n\n        ggml_free(ctx);\n    }\n\n    // map_custom2\n    // max tasks (4), userdata, parallelized by rows\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        struct ggml_tensor * t2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        memcpy(t1->data, buf1_f32, ggml_nbytes(t1));\n        memcpy(t2->data, buf2_f32, ggml_nbytes(t2));\n\n        struct ggml_tensor * m2 = ggml_map_custom2(ctx, t1, t2, custom2, GGML_N_TASKS_MAX, g_userdata);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, m2);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(m2);\n\n        for (int i = 0; i < ggml_nelements(m2); ++i) {\n            GGML_ASSERT(output[i] == buf1_f32[i] + buf2_f32[i]);\n        }\n\n        GGML_ASSERT(g_custom2_count == 4);\n\n        ggml_free(ctx);\n    }\n\n    // map_custom3\n    // 1 task, userdata, not parallelized\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        struct ggml_tensor * t2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n    
    struct ggml_tensor * t3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n\n        memcpy(t1->data, buf1_f32, ggml_nbytes(t1));\n        memcpy(t2->data, buf2_f32, ggml_nbytes(t2));\n        memcpy(t3->data, buf3_f32, ggml_nbytes(t3));\n\n        struct ggml_tensor * m3 = ggml_map_custom3(ctx, t1, t2, t3, custom3, 1, g_userdata);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, m3);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(m3);\n\n        for (int i = 0; i < ggml_nelements(m3); ++i) {\n            GGML_ASSERT(output[i] == buf1_f32[i] + buf2_f32[i] + buf3_f32[i]);\n        }\n\n        GGML_ASSERT(g_custom3_count == 1);\n\n        ggml_free(ctx);\n    }\n\n    // custom\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        struct ggml_tensor * t2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        struct ggml_tensor * t3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        struct ggml_tensor * t4 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        struct ggml_tensor * t5 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        memcpy(t1->data, buf1_f32, ggml_nbytes(t1));\n        memcpy(t2->data, buf2_f32, ggml_nbytes(t2));\n        memcpy(t3->data, buf3_f32, ggml_nbytes(t3));\n        memcpy(t4->data, buf1_f32, ggml_nbytes(t4));\n        memcpy(t5->data, buf2_f32, ggml_nbytes(t5));\n\n        struct ggml_tensor * args[] = {\n            t1, t2, t3, t4, t5,\n        };\n\n        struct ggml_tensor * m4 = ggml_custom_4d(ctx, GGML_TYPE_I32, 10, 2, 1, 1, args, sizeof(args)/sizeof(args[0]), custom, GGML_N_TASKS_MAX, g_userdata);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, m4);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const int32_t * output = (const 
int32_t *) ggml_get_data(m4);\n\n        for (int i = 0; i < ggml_nelements(m4); ++i) {\n            GGML_ASSERT(output[i] == buf1_f32[i] + buf2_f32[i] * buf3_f32[i] - buf1_f32[i] * buf2_f32[i]);\n        }\n\n        ggml_free(ctx);\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-dup.c",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#include <stdio.h>\n#include <stdlib.h>\n\nvoid arange(struct ggml_tensor* tensor) {\n    GGML_ASSERT(ggml_is_contiguous(tensor));\n    for (int i = 0; i < ggml_nelements(tensor); ++i) {\n        ggml_set_i32_1d(tensor, i, i);\n    }\n}\n\nvoid dup_to(struct ggml_tensor* src, struct ggml_tensor* dst) {\n    GGML_ASSERT(dst->op == GGML_OP_VIEW);\n    GGML_ASSERT(ggml_nelements(src) == ggml_nelements(dst));\n    dst->op = GGML_OP_DUP;\n    dst->src[0] = src;\n}\n\nbool can_dup(enum ggml_type src_type, enum ggml_type dst_type) {\n    if (src_type == dst_type) return true;\n    if (src_type == GGML_TYPE_F32 && ggml_get_type_traits_cpu(dst_type)->from_float) return true;\n    if (dst_type == GGML_TYPE_F32 && ggml_get_type_traits    (src_type)->to_float) return true;\n\n    return false;\n}\n\nint main(int argc, const char ** argv) {\n    struct ggml_init_params params = {\n        .mem_size   = 128*1024*1024,\n        .mem_buffer = NULL,\n        .no_alloc   = false,\n    };\n\n    enum ggml_type type[4] = {GGML_TYPE_I16, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_F32};\n    for (int i = 0; i < 4; ++i) {\n        enum ggml_type src_type = type[i];\n        for (int j = 0; j < 4; ++j) {\n            enum ggml_type dst_type = type[j];\n            if (!can_dup(src_type, dst_type)) continue;\n            printf(\"Testing dup on %s -> %s copy\\n\", ggml_type_name(src_type), ggml_type_name(dst_type));\n\n            struct ggml_context * ctx = ggml_init(params);\n\n            struct ggml_tensor * src = ggml_new_tensor_2d(ctx, src_type, 10, 11);\n            arange(src);\n            struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, dst_type, 10, 11);\n            ggml_set_i32(dst, 0);\n\n            // 2nd-row: [20, 21, ..., 29]\n            struct ggml_tensor * src_cont = ggml_view_1d(ctx, src, 10, src->nb[1] * 2);\n\n            // 3rd-col: [03, 13, ..., 93]\n            struct ggml_tensor * src_stride = 
ggml_view_2d(ctx, src, 1, 10, src->nb[1], src->nb[0] * 3);\n\n            struct ggml_tensor * dst_cont_1 = ggml_view_1d(ctx, dst, 10, dst->nb[1] * 5); // 5th-row\n            struct ggml_tensor * dst_cont_2 = ggml_view_1d(ctx, dst, 10, dst->nb[1] * 6); // 6th-row\n\n            struct ggml_tensor * dst_stride_1 = ggml_view_2d(ctx, dst, 1, 10, dst->nb[1], dst->nb[0] * 7); // 7th-col\n            struct ggml_tensor * dst_stride_2 = ggml_view_2d(ctx, dst, 1, 10, dst->nb[1], dst->nb[0] * 8); // 8th-col\n\n            struct ggml_cgraph * gf = ggml_new_graph(ctx);\n\n            dup_to(src_cont,   dst_cont_1);\n            dup_to(src_stride, dst_cont_2);\n            dup_to(src_cont,   dst_stride_1);\n            dup_to(src_stride, dst_stride_2);\n\n            ggml_build_forward_expand(gf, dst_cont_1);\n            ggml_build_forward_expand(gf, dst_cont_2);\n            ggml_build_forward_expand(gf, dst_stride_1);\n            ggml_build_forward_expand(gf, dst_stride_2);\n\n            ggml_graph_compute_with_ctx(ctx, gf, 1);\n\n            // src_cont -> dst_cont_1\n            GGML_ASSERT(ggml_get_i32_1d(dst, 49) == 0);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 50) == 20);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 51) == 21);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 52) == 22);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 59) == 29);\n\n            // src_stride -> dst_cont_2\n            GGML_ASSERT(ggml_get_i32_1d(dst, 60) == 3);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 61) == 13);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 62) == 23);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 69) == 93);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 70) == 0);\n\n            // src_cont -> dst_stride_1\n            GGML_ASSERT(ggml_get_i32_1d(dst, 6)   == 0);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 7)   == 20);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 17)  == 21);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 27)  == 22);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 97)  == 29);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 107) == 0);\n\n            // src_stride -> dst_stride_2\n            GGML_ASSERT(ggml_get_i32_1d(dst, 8)   == 03);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 18)  == 13);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 28)  == 23);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 98)  == 93);\n            GGML_ASSERT(ggml_get_i32_1d(dst, 108) == 0);\n\n            ggml_free(ctx);\n        }\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-interpolate.cpp",
    "content": "#include <ggml.h>\n#include <ggml-cpu.h>\n#include <ggml-alloc.h>\n#include <ggml-backend.h>\n#include <ggml-cpp.h>\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <array>\n#include <vector>\n\nbool check_equal(const float * result, const float * expected, int64_t n) {\n    for (int i = 0; i < n; i++) {\n        if(std::abs(result[i] - expected[i]) > 1e-4) {\n            printf(\"result[%d] %f != %f expected[%d]\\n\", i, result[i], expected[i], i);\n            return false;\n        }\n    }\n    return true;\n}\n\nbool test_interpolate(char const* name,\n                      std::array<int64_t, 4> src_ne, const float * src_data,\n                      std::array<int32_t, 4> dst_ne, const float * expected,\n                      uint32_t mode) {\n    ggml_time_init();\n\n    ggml_init_params params {\n        /*.mem_size   =*/ 64 * ggml_tensor_overhead() + ggml_graph_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true\n    };\n\n    ggml_context_ptr ctx_ptr{ggml_init(params)};\n    ggml_context * ctx = ctx_ptr.get();\n    ggml_cgraph * gf = ggml_new_graph(ctx);\n\n    // Build graph\n    ggml_tensor * src = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, src_ne.data());\n    ggml_tensor * res = ggml_interpolate(ctx, src, dst_ne[0], dst_ne[1], dst_ne[2], dst_ne[3], mode);\n    ggml_build_forward_expand(gf, res);\n\n    // Create backend & allocate buffers\n    ggml_backend_ptr backend_ptr{ggml_backend_cpu_init()};\n    ggml_backend_t backend = backend_ptr.get();\n    ggml_backend_cpu_set_n_threads(backend, 2);\n    ggml_backend_buffer_ptr buffer{ggml_backend_alloc_ctx_tensors(ctx, backend)};\n\n    // Execute and compare results\n    ggml_backend_tensor_set(src, src_data, 0, ggml_nbytes(src));\n    ggml_backend_graph_compute(backend, gf);\n\n    std::vector<float> res_values(ggml_nelements(res));\n    ggml_backend_tensor_get(res, res_values.data(), 0, ggml_nbytes(res));\n\n    bool passed = 
check_equal(res_values.data(), expected, ggml_nelements(res));\n\n    printf(\"ggml_interpolate(%s): %s\\n\", name, passed ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n    return passed;\n}\n\nconst float input_upscale[] = {\n    0.0f, 1.0f,\n    2.0f, 4.0f\n};\n\nconst float expected_upscale_x2_nearest[] = {\n    0.0f, 0.0f, 1.0f, 1.0f,\n    0.0f, 0.0f, 1.0f, 1.0f,\n    2.0f, 2.0f, 4.0f, 4.0f,\n    2.0f, 2.0f, 4.0f, 4.0f\n};\n\nconst float expected_upscale_x2_bilinear[] = {\n    0.0f, 0.2500f, 0.7500f, 1.00f,\n    0.5f, 0.8125f, 1.4375f, 1.75f,\n    1.5f, 1.9375f, 2.8125f, 3.25f,\n    2.0f, 2.5000f, 3.5000f, 4.00f\n};\n\nconst float expected_upscale_x2_bilinear_align_corners[] = {\n    0.0000f, 0.3333f, 0.6667f, 1.0000f,\n    0.6667f, 1.1111f, 1.5556f, 2.0000f,\n    1.3333f, 1.8889f, 2.4444f, 3.0000f,\n    2.0000f, 2.6667f, 3.3333f, 4.0000f\n};\n\nconst float expected_upscale_x1_5_bilinear_align_corners[] = {\n    0.0f, 1.0f,\n    1.0f, 2.5f,\n    2.0f, 4.0f\n};\n\nconst float input_downscale[] = {\n    0.0f, -1.0f, -2.0f, 0.0f,\n    1.0f, 2.0f , 4.0f , 4.0f,\n    2.0f, 2.0f , 1.0f , 1.0f,\n\n    1.0f, 2.0f , 3.0f , 4.0f,\n    2.0f, 2.0f , 2.0f , 2.0f,\n    -2.0f, 2.0f, -4.0f, 4.0f\n};\n\nconst float expected_downscale_nearest[] = {\n    0.0f, -2.0f,\n\n    1.0f, 3.0f\n};\n\nconst float expected_downscale_bilinear[] = {\n    0.1667f, -0.3750f,  0.7500f,\n    1.7917f,  1.8750f,  1.7500f,\n\n    1.3750f,  2.3750f,  3.3750f,\n   -0.5000f, -0.2500f,  2.5000f\n};\n\nconst float expected_downscale_bilinear_align_corners[] = {\n    0.0f , -1.5f, 0.0f,\n    2.0f ,  1.5f, 1.0f,\n\n    1.0f ,  2.5f, 4.0f,\n    -2.0f, -1.0f, 4.0f\n};\n\nint main() {\n    bool passed = true;\n\n    passed &= test_interpolate(\"upscale_x2_nearest\",\n        {2, 2, 1, 1}, input_upscale,\n        {4, 4, 1, 1}, expected_upscale_x2_nearest,\n        GGML_SCALE_MODE_NEAREST);\n\n    passed &= test_interpolate(\"upscale_x2_bilinear\",\n        {2, 2, 1, 1}, input_upscale,\n        
{4, 4, 1, 1}, expected_upscale_x2_bilinear,\n        GGML_SCALE_MODE_BILINEAR);\n\n    passed &= test_interpolate(\"upscale_x2_bilinear_align_corners\",\n        {2, 2, 1, 1}, input_upscale,\n        {4, 4, 1, 1}, expected_upscale_x2_bilinear_align_corners,\n        GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);\n\n    passed &= test_interpolate(\"upscale_x1_5_bilinear_align_corners\",\n        {2, 2, 1, 1}, input_upscale,\n        {2, 3, 1, 1}, expected_upscale_x1_5_bilinear_align_corners,\n        GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);\n\n    passed &= test_interpolate(\"downscale_nearest\",\n        {4, 3, 2, 1}, input_downscale,\n        {2, 1, 2, 1}, expected_downscale_nearest,\n        GGML_SCALE_MODE_NEAREST);\n\n    passed &= test_interpolate(\"downscale_bilinear\",\n        {4, 3, 2, 1}, input_downscale,\n        {3, 2, 2, 1}, expected_downscale_bilinear,\n        GGML_SCALE_MODE_BILINEAR);\n\n    passed &= test_interpolate(\"downscale_bilinear_align_corners\",\n        {4, 3, 2, 1}, input_downscale,\n        {3, 2, 2, 1}, expected_downscale_bilinear_align_corners,\n        GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);\n\n    return passed ? 0 : 1;\n}"
  },
  {
    "path": "tests/test-opt.cpp",
    "content": "// TODO refactor\n\n#include \"ggml.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n#include \"ggml-opt.h\"\n\n#include <cmath>\n#include <cinttypes>\n#include <cstring>\n#include <random>\n#include <string>\n#include <thread>\n#include <vector>\n\n#define TEST_LOG(...)       printf(__VA_ARGS__)\n\nstatic bool almost_equal(const double a, const double b, const double atol) {\n    return fabs(a - b) < atol;\n}\n\nconstexpr int64_t ne_datapoint = 2;\nconstexpr int64_t ne_label     = 1;\nconstexpr int64_t ndata        = 6;\n\nstruct helper_ctx_data {\n    std::vector<ggml_opt_dataset_t>   datasets_supervised;\n    std::vector<struct ggml_tensor *> data_batch;\n    std::vector<struct ggml_tensor *> labels_batch;\n\n    ggml_opt_dataset_t       dataset_unsupervised;\n    struct ggml_context    * ctx_static;\n    struct ggml_context    * ctx_compute;\n    struct ggml_opt_params   opt_params;\n    ggml_opt_context_t       opt_ctx;\n    struct ggml_tensor     * inputs;\n    struct ggml_tensor     * weights;\n    struct ggml_tensor     * outputs;\n    ggml_backend_buffer_t    buf;\n    ggml_opt_result_t        result;\n    ggml_opt_result_t        result2;\n};\n\n// These default values make it easier to check optimization results vs. 
expected values.\nstatic ggml_opt_optimizer_params helper_get_test_opt_pars(void * userdata) {\n    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);\n\n    result.adamw.alpha = 1.0f;\n    result.adamw.beta1 = 0.0f;\n    result.adamw.beta2 = 0.0f;\n    result.adamw.eps   = 0.0f;\n    result.adamw.wd    = 0.0f;\n    result.sgd.wd      = 0.0f;\n    result.sgd.alpha   = 1.0f;\n\n    return result;\n}\n\nstatic helper_ctx_data helper_get_ctx_data(\n        enum ggml_opt_optimizer_type optim,\n        ggml_backend_sched_t    backend_sched,\n        ggml_backend_t          backend,\n        const bool              init_opt_ctx       = true,\n        const bool              optimizer_defaults = true,\n        int64_t                 nbatch_logical     = 1,\n        int64_t                 nbatch_physical    = 1,\n        enum ggml_opt_loss_type loss_type          = GGML_OPT_LOSS_TYPE_SUM) {\n    std::vector<ggml_opt_dataset_t> datasets(ndata);\n    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {\n        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(\n            GGML_TYPE_F32, GGML_TYPE_F32, ne_datapoint, ne_label, ndata, ndata_shard);\n\n        float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));\n        float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));\n\n        for (int64_t idata = 0; idata < ndata; ++idata) {\n            for (int64_t id = 0; id < ne_datapoint; ++id) {\n                data[  idata*ne_datapoint + id] =     16*idata + id;\n            }\n            for (int64_t il = 0; il < ne_label;     ++il) {\n                labels[idata*ne_label     + il] = 16*(16*idata + il);\n            }\n        }\n\n        datasets[ndata_shard-1] = dataset;\n    }\n\n    ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(\n        GGML_TYPE_F32, GGML_TYPE_F32, 1, 0, ndata, /*ndata_shard =*/ 1);\n\n    float * data = 
ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised));\n\n    for (int64_t idata = 0; idata < ndata; ++idata) {\n        data[idata] = idata;\n    }\n\n    struct ggml_context * ctx_static;\n    struct ggml_context * ctx_compute;\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ (2*ndata + 2)*ggml_tensor_overhead(),\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true,\n        };\n        ctx_static = ggml_init(params);\n    }\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(),\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true,\n        };\n        ctx_compute = ggml_init(params);\n    }\n\n    std::vector<struct ggml_tensor *>   data_batch(ndata);\n    std::vector<struct ggml_tensor *> labels_batch(ndata);\n    for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) {\n        data_batch[ndata_batch-1]   = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_datapoint);\n        labels_batch[ndata_batch-1] = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_label);\n    }\n\n    struct ggml_tensor * inputs = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, nbatch_physical);\n    ggml_set_name(inputs, \"inputs\");\n\n    struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);\n    ggml_set_name(weights, \"weights\");\n    ggml_set_param(weights);\n\n    struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights);\n\n    struct ggml_tensor * outputs = ggml_scale(ctx_compute, intermediary, 1.0f);\n    ggml_set_name(outputs, \"outputs\");\n\n    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);\n    const float w0 = float(ndata)/2;\n    ggml_backend_tensor_set(weights, &w0, 0, sizeof(float));\n\n    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);\n    const int32_t opt_period 
= nbatch_logical / nbatch_physical;\n\n    struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, loss_type);\n    opt_params.ctx_compute = ctx_compute;\n    opt_params.inputs      = inputs;\n    opt_params.outputs     = outputs;\n    opt_params.opt_period  = opt_period;\n    opt_params.optimizer   = optim;\n    if (!optimizer_defaults) {\n        opt_params.get_opt_pars = helper_get_test_opt_pars;\n    }\n    GGML_ASSERT(opt_params.get_opt_pars);\n    ggml_opt_context_t opt_ctx = init_opt_ctx ? ggml_opt_init(opt_params) : nullptr;\n    GGML_ASSERT(!opt_ctx || ggml_opt_context_optimizer_type(opt_ctx) == opt_params.optimizer);\n\n    ggml_opt_result_t result  = ggml_opt_result_init();\n    ggml_opt_result_t result2 = ggml_opt_result_init();\n\n    return {datasets, data_batch, labels_batch, dataset_unsupervised, ctx_static, ctx_compute, opt_params, opt_ctx, inputs, weights, outputs, buf, result, result2};\n}\n\nstatic void helper_free_ctx_data(struct helper_ctx_data ctx_data) {\n    ggml_opt_result_free(ctx_data.result);\n    ggml_opt_result_free(ctx_data.result2);\n    ggml_opt_free(ctx_data.opt_ctx);\n    ggml_backend_buffer_free(ctx_data.buf);\n    ggml_free(ctx_data.ctx_static);\n    ggml_free(ctx_data.ctx_compute);\n    for (ggml_opt_dataset_t dataset : ctx_data.datasets_supervised) {\n        ggml_opt_dataset_free(dataset);\n    }\n    ggml_opt_dataset_free(ctx_data.dataset_unsupervised);\n}\n\nstatic void print_ok(bool subtest_ok) {\n    printf(subtest_ok ? \"\\033[1;32mOK\\033[0m\\n\" : \"\\033[1;31mFAIL\\033[0m\\n\");\n}\n\nstatic void helper_after_test(\n        enum ggml_opt_optimizer_type optim,\n        const char * func, const bool high_level, const std::string options,\n        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {\n    printf(\"  %s(high_level=%s%s, subtest=%s, optimizer=%s): \",\n           func, high_level ? 
\"yes\" : \"no\", options.c_str(), subtest.c_str(), ggml_opt_optimizer_name(optim));\n    print_ok(subtest_ok);\n    if (subtest_ok)\n        npass++;\n    ntest++;\n}\n\nstatic void print_ok(const char * func, bool subtest_ok, int & npass, int & ntest, const char * args = \"\") {\n    printf(\"  %s(%s): \", func, args);\n    print_ok(subtest_ok);\n    if (subtest_ok)\n        npass++;\n    ++ntest;\n}\n\nstatic std::pair<int, int> test_dataset(\n    enum ggml_opt_optimizer_type optim,\n    ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) {\n    int ntest = 0;\n    int npass = 0;\n\n    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend);\n\n    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {\n        ggml_opt_dataset_t dataset = cd.datasets_supervised[ndata_shard-1];\n\n        if (shuffle) {\n            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);\n        }\n\n        for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) {\n            if (ndata_batch % ndata_shard != 0) {\n                continue;\n            }\n            bool subtest_ok = true;\n\n            struct ggml_tensor *   data_batch =   cd.data_batch[ndata_batch-1];\n            struct ggml_tensor * labels_batch = cd.labels_batch[ndata_batch-1];\n\n            std::vector<float>   data(ggml_nelements(  data_batch));\n            std::vector<float> labels(ggml_nelements(labels_batch));\n\n            std::vector<int64_t> idata_shuffled;\n            const int64_t nbatches = ndata / ndata_batch;\n            for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) {\n                ggml_opt_dataset_get_batch(dataset, data_batch, labels_batch, ibatch);\n\n                ggml_backend_tensor_get(  data_batch,   data.data(), 0, ggml_nbytes(  data_batch));\n                ggml_backend_tensor_get(labels_batch, labels.data(), 0, ggml_nbytes(labels_batch));\n\n                for (int64_t idata_batch = 0; 
idata_batch < ndata_batch; ++idata_batch) {\n                    const int64_t idata = ibatch*ndata_batch + idata_batch;\n                    const int64_t idata_found = data[idata_batch*ne_datapoint] / 16;\n                    subtest_ok = subtest_ok && (shuffle || idata_found == idata);\n                    idata_shuffled.push_back(idata_found);\n\n                    for (int64_t id = 0; id < ne_datapoint; ++id) {\n                        if (data[  idata_batch*ne_datapoint + id] != 16*idata_found + id) {\n                            subtest_ok = false;\n                        }\n                    }\n                    for (int64_t il = 0; il < ne_label;     ++il) {\n                        if (labels[idata_batch*ne_label     + il] != 16*(16*idata_found + il)) {\n                            subtest_ok = false;\n                        }\n                    }\n                }\n            }\n\n            if (!shuffle || ndata % ndata_batch == 0) {\n                const int ndata_max = (ndata / ndata_batch) * ndata_batch;\n\n                for (int64_t idata = 0; subtest_ok && idata < ndata_max; ++idata) {\n                    int ninstances = 0;\n                    for (int64_t id : idata_shuffled) {\n                        ninstances += id == idata;\n                    }\n                    if (ninstances != 1) {\n                        subtest_ok = false;\n                    }\n                }\n            }\n\n            printf(\"  %s(shuffle=%s, ndata_shard=%\" PRId64 \", ndata_batch=%\" PRId64 \"): \",\n                   __func__, shuffle ? 
\"yes\" : \"no\", ndata_shard, ndata_batch);\n            if (subtest_ok) {\n                printf(\"\\033[1;32mOK\\033[0m\\n\");\n                npass++;\n            } else {\n                printf(\"\\033[1;31mFAIL\\033[0m\\n\");\n            }\n            ntest++;\n        }\n    }\n\n    helper_free_ctx_data(cd);\n\n    return std::make_pair(npass, ntest);\n}\n\nstatic std::pair<int, int> test_grad(\n    enum ggml_opt_optimizer_type optim,\n    ggml_backend_sched_t backend_sched, ggml_backend_t backend) {\n    int ntest = 0;\n    int npass = 0;\n\n    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false,\n    /*nbatch_logical =*/ 999999, /*nbatch_physical =*/ 1);\n\n    std::vector<float> grad_history(ndata);\n    for (int64_t idata = 0; idata < ndata; ++idata) {\n        grad_history[idata] = NAN;\n    }\n\n    for (int idata = 0; idata < ndata; ++idata) {\n        const float idataf = idata;\n        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);\n        // leaked\n        ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));\n        ggml_opt_eval(cd.opt_ctx, cd.result);\n        ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));\n    }\n\n    {\n        bool subtest_ok = true;\n        for (int idata = 0; idata < ndata; ++idata) {\n            if (grad_history[idata] != idata + 1) {\n                subtest_ok = false;\n            }\n        }\n        printf(\"  %s(): \", __func__);\n        if (subtest_ok) {\n            printf(\"\\033[1;32mOK\\033[0m\\n\");\n            npass++;\n        } else {\n            printf(\"\\033[1;31mFAIL\\033[0m\\n\");\n        }\n        ntest++;\n    }\n\n    helper_free_ctx_data(cd);\n\n    return std::make_pair(npass, ntest);\n}\n\nstatic void helper_after_test_forward_backward(\n        enum ggml_opt_optimizer_type optim,\n        const char * func, 
const bool high_level, const bool shuffle,\n        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {\n    std::string options = \", shuffle=\";\n    options += shuffle ? \"yes\" : \"no\";\n    helper_after_test(optim, func, high_level, options, subtest, subtest_ok, ntest, npass);\n}\n\nstatic std::pair<int, int> test_forward_backward(\n        enum ggml_opt_optimizer_type optim,\n        ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level, const bool shuffle) {\n    int ntest = 0;\n    int npass = 0;\n\n    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);\n    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);\n\n    std::vector<float> loss_history(ndata);\n    for (int64_t idata = 0; idata < ndata; ++idata) {\n        loss_history[idata] = NAN;\n    }\n\n    {\n        int64_t ndata;\n        ggml_opt_result_ndata(cd.result, &ndata);\n        double loss;\n        double loss_unc;\n        ggml_opt_result_loss(cd.result, &loss, &loss_unc);\n        double accuracy;\n        double accuracy_unc;\n        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);\n        const bool subtest_ok = ndata == 0 && almost_equal(loss, 0.0, 1e-6) && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);\n        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, \"results_initial\", subtest_ok, ntest, npass);\n    }\n\n    if (high_level) {\n        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;\n        if (shuffle) {\n            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);\n        }\n        ggml_opt_epoch(cd.opt_ctx, dataset, nullptr, cd.result, 0, nullptr, nullptr);\n    } else {\n        for (int idata = 0; idata < ndata; ++idata) {\n            const float idataf = idata;\n            ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);\n            
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));\n            ggml_opt_eval(cd.opt_ctx, cd.result);\n            ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));\n        }\n    }\n\n    {\n        float weights;\n        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));\n        const bool subtest_ok = almost_equal(weights, ndata/2, 1e-10);\n        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, \"weights_after_forward\", subtest_ok, ntest, npass);\n    }\n    {\n        constexpr double atol = 1e-10;\n\n        int64_t ndata;\n        ggml_opt_result_ndata(cd.result, &ndata);\n        bool subtest_ok = ndata == 6;\n\n        double loss;\n        double loss_unc;\n        ggml_opt_result_loss(cd.result, &loss, &loss_unc);\n        subtest_ok = subtest_ok && almost_equal(loss, 33.0, atol) && almost_equal(loss_unc, sqrt(3.5), atol);\n\n        double accuracy;\n        double accuracy_unc;\n        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);\n        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);\n\n        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, \"results_after_forward\", subtest_ok, ntest, npass);\n    }\n\n    float w0;\n    ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));\n    for (int i = 0; i < 10; ++i) {\n        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);\n        // leaked.\n        ggml_opt_eval(cd.opt_ctx, cd.result);\n    }\n    ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));\n\n    ggml_opt_reset(cd.opt_ctx, /*optimizer =*/ false);\n    ggml_opt_result_reset(cd.result);\n\n    for (int64_t idata = 0; idata < ndata; ++idata) {\n        loss_history[idata] = NAN;\n    }\n\n    if (high_level) {\n        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;\n        if (shuffle) {\n            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);\n        }\n 
       ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);\n    } else {\n        for (int idata = 0; idata < ndata; ++idata) {\n            const float idataf = idata;\n            ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);\n            ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));\n            ggml_opt_eval(cd.opt_ctx, cd.result);\n            ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));\n        }\n    }\n\n    {\n        float weights;\n        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));\n        const bool subtest_ok = almost_equal(weights, -ndata * 0.5, 1e-10);\n        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, \"weights_after_forward_backward\", subtest_ok, ntest, npass);\n    }\n    {\n        int64_t ndata;\n        ggml_opt_result_ndata(cd.result, &ndata);\n        bool subtest_ok = ndata == 6;\n\n        double loss;\n        double loss_unc;\n        ggml_opt_result_loss(cd.result, &loss, &loss_unc);\n        subtest_ok = subtest_ok && almost_equal(loss, 18.0, 1e-10) && (shuffle || loss_unc == 0.0);\n\n        double accuracy;\n        double accuracy_unc;\n        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);\n        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);\n\n        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, \"result_after_forward_backward\", subtest_ok, ntest, npass);\n    }\n\n    helper_free_ctx_data(cd);\n\n    return std::make_pair(npass, ntest);\n}\n\nstatic std::pair<int, int> test_epoch_vs_fit(\n    enum ggml_opt_optimizer_type optim,\n    ggml_backend_sched_t backend_sched, ggml_backend_t backend) {\n    int ntest = 0;\n    int npass = 0;\n\n    float weights_epoch;\n    float weights_fit;\n\n    {\n        struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true);\n    
    ggml_opt_dataset_t dataset = cd.dataset_unsupervised;\n\n        ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);\n        ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);\n        // leaked.\n\n        ggml_backend_tensor_get(cd.weights, &weights_epoch, 0, ggml_nbytes(cd.weights));\n        helper_free_ctx_data(cd);\n    }\n    {\n        struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ false);\n        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;\n\n        ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset, GGML_OPT_LOSS_TYPE_SUM,\n                     optim, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true);\n\n        ggml_backend_tensor_get(cd.weights, &weights_fit, 0, ggml_nbytes(cd.weights));\n        helper_free_ctx_data(cd);\n    }\n\n    const bool subtest_ok = weights_epoch == weights_fit;\n\n    print_ok(__func__, subtest_ok, npass, ntest);\n\n    return std::make_pair(npass, ntest);\n}\n\nstatic void helper_after_test_idata_split(\n        enum ggml_opt_optimizer_type optim,\n        const char * func, const bool high_level, const int epoch,\n        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {\n    std::string options = \", epoch=\";\n    options += std::to_string(epoch);\n    helper_after_test(optim, func, high_level, options, subtest, subtest_ok, ntest, npass);\n}\n\nstatic std::pair<int, int> test_idata_split(\n    enum ggml_opt_optimizer_type optim,\n    ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) {\n    int ntest = 0;\n    int npass = 0;\n\n    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);\n    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);\n    const int idata_split = ndata * 2/3;\n\n    std::vector<float> loss_history(ndata);\n    for (int64_t idata 
= 0; idata < ndata; ++idata) {\n        loss_history[idata] = NAN;\n    }\n\n    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;\n    for (int epoch = 1; epoch <= 4; ++epoch) {\n        if (high_level) {\n            ggml_opt_epoch(cd.opt_ctx, cd.dataset_unsupervised, cd.result, cd.result2, idata_split, nullptr, nullptr);\n        } else {\n            int idata = 0;\n            for (; idata < idata_split; ++idata) {\n                const float idataf = idata;\n                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);\n                ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));\n                ggml_opt_eval(cd.opt_ctx, cd.result);\n                ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));\n            }\n            for (; idata < ndata; ++idata) {\n                const float idataf = idata;\n                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);\n                ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));\n                ggml_opt_eval(cd.opt_ctx, cd.result2);\n                ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));\n            }\n        }\n\n        if (adamw) {\n            float weights;\n            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));\n            const bool subtest_ok = almost_equal(weights, ndata/2 - epoch*idata_split, 1e-10);\n            helper_after_test_idata_split(optim, __func__, high_level, epoch, \"weights\", subtest_ok, ntest, npass);\n        }\n        if (adamw) {\n            constexpr double atol = 1e-10;\n\n            int64_t ndata_result;\n            ggml_opt_result_ndata(cd.result, &ndata_result);\n            bool subtest_ok = ndata_result == idata_split;\n\n            double loss;\n            double loss_unc;\n            ggml_opt_result_loss(cd.result, &loss, &loss_unc);\n            subtest_ok = subtest_ok && almost_equal(loss, 28.0 - 
epoch*16.0, atol) && almost_equal(loss_unc, 0.0, atol);\n\n            double accuracy;\n            double accuracy_unc;\n            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);\n            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);\n\n            helper_after_test_idata_split(optim, __func__, high_level, epoch, \"results_backward\", subtest_ok, ntest, npass);\n        }\n        if (adamw) {\n            constexpr double atol = 1e-10;\n\n            int64_t ndata_result;\n            ggml_opt_result_ndata(cd.result2, &ndata_result);\n            bool subtest_ok = ndata_result == ndata - idata_split;\n\n            double loss;\n            double loss_unc;\n            ggml_opt_result_loss(cd.result2, &loss, &loss_unc);\n            subtest_ok = subtest_ok && almost_equal(loss, 15.0 - epoch*8, atol) && almost_equal(loss_unc, sqrt(0.5), atol);\n\n            double accuracy;\n            double accuracy_unc;\n            ggml_opt_result_accuracy(cd.result2, &accuracy, &accuracy_unc);\n            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);\n\n            helper_after_test_idata_split(optim, __func__, high_level, epoch, \"results_forward\", subtest_ok, ntest, npass);\n        }\n\n        ggml_opt_result_reset(cd.result);\n        ggml_opt_result_reset(cd.result2);\n    }\n\n    helper_free_ctx_data(cd);\n\n    return std::make_pair(npass, ntest);\n}\n\nstatic void helper_after_test_gradient_accumulation(\n        enum ggml_opt_optimizer_type optim,\n        const char * func, const int nbatch_physical, const enum ggml_opt_loss_type loss_type, const int epoch,\n        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {\n    std::string options = \", nbatch_physical=\";\n    options += std::to_string(nbatch_physical);\n    options += \", loss_type=\";\n    options += loss_type == GGML_OPT_LOSS_TYPE_MEAN ? 
\"mean\" : \"sum\";\n    options += \", epoch=\";\n    options += std::to_string(epoch);\n    helper_after_test(optim, func, false, options, subtest, subtest_ok, ntest, npass);\n}\n\nstatic std::pair<int, int> test_gradient_accumulation(\n        enum ggml_opt_optimizer_type optim,\n        ggml_backend_sched_t backend_sched, ggml_backend_t backend, const int32_t nbatch_physical, const enum ggml_opt_loss_type loss_type) {\n    int ntest = 0;\n    int npass = 0;\n\n    struct helper_ctx_data cd = helper_get_ctx_data(\n        optim,\n        backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);\n\n    std::vector<float> grad_history(ndata);\n    for (int64_t idata = 0; idata < ndata; ++idata) {\n        grad_history[idata] = NAN;\n    }\n\n    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;\n    if (adamw)\n    for (int epoch = 1; epoch <= 4; ++epoch) {\n        if (nbatch_physical == 1) {\n            for (int idata = 0; idata < ndata; ++idata) {\n                const float idataf = idata;\n                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);\n                ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float));\n                ggml_opt_eval(cd.opt_ctx, cd.result);\n                ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float));\n            }\n        } else if (nbatch_physical == 2) {\n            for (int idata = 0; idata < ndata; idata += 2) {\n                const float idataf[2] = {float(idata + 0), float(idata + 1)};\n                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);\n                ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float));\n                ggml_opt_eval(cd.opt_ctx, cd.result);\n\n                grad_history[idata + 0] = 0.0f;\n                ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata 
+ 1, 0, 1*sizeof(float));\n            }\n        } else {\n            GGML_ASSERT(false);\n        }\n\n        {\n            GGML_ASSERT(ndata == 6);\n            constexpr double atol = 1e-6;\n            bool subtest_ok = true;\n            if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {\n                if (nbatch_physical == 1) {\n                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0, atol);\n                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0, atol);\n                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0, atol);\n                } else {\n                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0, atol);\n                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0, atol);\n                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0, atol);\n                }\n                subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol);\n                subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol);\n                subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0, atol);\n            } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {\n                if (nbatch_physical == 1) {\n                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol);\n                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0/ndata, atol);\n                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0/ndata, atol);\n                } else {\n                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0/ndata, atol);\n                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0/ndata, atol);\n                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0/ndata, atol);\n                }\n                subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, 
atol);\n                subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol);\n                subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0/ndata, atol);\n            } else {\n                GGML_ASSERT(false);\n            }\n            helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, \"grads\", subtest_ok, ntest, npass);\n        }\n        bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;\n        if (adamw) {\n            constexpr double atol = 1e-6;\n            float weights;\n            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));\n            const bool subtest_ok = almost_equal(weights, (ndata/2) - epoch, atol);\n            helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, \"weights\", subtest_ok, ntest, npass);\n        }\n        {\n            constexpr double atol = 1e-6;\n            int64_t ndata_result;\n            ggml_opt_result_ndata(cd.result, &ndata_result);\n            bool subtest_ok = almost_equal(ndata_result, ndata/nbatch_physical, atol);\n\n            double loss;\n            ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr);\n            if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {\n                subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0), atol);\n            } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {\n                subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, atol);\n            } else {\n                GGML_ASSERT(false);\n            }\n\n            double accuracy;\n            double accuracy_unc;\n            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);\n            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);\n\n            helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, \"results\", subtest_ok, ntest, 
npass);\n        }\n\n        ggml_opt_result_reset(cd.result);\n    }\n\n    helper_free_ctx_data(cd);\n\n    return std::make_pair(npass, ntest);\n}\n\nfloat constexpr g_sgd_lr = 1e-4f;\n\nint constexpr g_sgd_epochs = 900;\n\nstatic ggml_opt_optimizer_params helper_get_regression_opt_pars(void * userdata) {\n    int64_t epoch = *(int64_t*)userdata;\n    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);\n    result.adamw.alpha = 0.1f;\n    result.sgd.alpha = g_sgd_lr * std::pow(.99, 1000 * (double)epoch / g_sgd_epochs);\n    result.sgd.wd = 1e-10;\n    return result;\n}\n\nstatic std::pair<int, int> test_regression(\n        enum ggml_opt_optimizer_type optim,\n        ggml_backend_sched_t backend_sched, ggml_backend_t backend) {\n    int ntest = 0;\n    int npass = 0;\n\n    // Test for simple regression with f(x) = a*x + b\n\n    constexpr int64_t ndata_regression = 201;\n    constexpr float a_true = 1.2f;\n    constexpr float b_true = 3.4f;\n\n    std::mt19937 gen(12345);\n    std::normal_distribution<float> nd{0.0f, 0.1f};\n\n    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(\n        GGML_TYPE_F32, GGML_TYPE_F32, 1, 1, ndata_regression, ndata_regression);\n\n    float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));\n    float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));\n\n    constexpr float x_min = -100.0f;\n    constexpr float x_max =  100.0f;\n\n    for (int64_t idata = 0; idata < ndata_regression; ++idata) {\n        const float x = x_min + (x_max - x_min) * idata/(ndata_regression-1);\n        const float y = a_true*x + b_true + nd(gen);\n\n        data[idata]   = x;\n        labels[idata] = y;\n    }\n\n    struct ggml_context * ctx_static;\n    struct ggml_context * ctx_compute;\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ 3*ggml_tensor_overhead(),\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true,\n        };\n       
 ctx_static = ggml_init(params);\n    }\n    {\n        struct ggml_init_params params = {\n            /*.mem_size   =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(),\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true,\n        };\n        ctx_compute = ggml_init(params);\n    }\n\n    // The first dimension is the dimension of the datapoints, the second dimension is the number of datapoints.\n    struct ggml_tensor * x = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, 1, ndata_regression);\n    ggml_set_name(x, \"x\");\n\n    struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);\n    ggml_set_name(a, \"a\");\n    ggml_set_param(a);\n\n    struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);\n    ggml_set_name(b, \"b\");\n    ggml_set_param(b);\n\n    struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b);\n    ggml_set_name(f, \"f\");\n\n    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);\n    const float a0 = 1.0f;\n    const float b0 = 3.0f;\n    ggml_backend_tensor_set(a, &a0, 0, sizeof(float));\n    ggml_backend_tensor_set(b, &b0, 0, sizeof(float));\n\n    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;\n    int64_t const n_epoch = adamw ? 100 : g_sgd_epochs;\n    ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, optim,\n                 helper_get_regression_opt_pars, n_epoch, ndata_regression, 0.0f, true);\n\n    {\n        float a_fit;\n        ggml_backend_tensor_get(a, &a_fit, 0, sizeof(float));\n        float b_fit;\n        ggml_backend_tensor_get(b, &b_fit, 0, sizeof(float));\n        float tol = adamw ? 1e-2 : 5e-2;\n        const bool aok = almost_equal(a_fit, a_true, tol);\n        const bool bok = almost_equal(b_fit, b_true, tol);\n        const bool subtest_ok = aok && bok;\n        print_ok(__func__, adamw ? 
subtest_ok : true, npass, ntest, \"subtest=weights\");\n    }\n\n    ggml_backend_buffer_free(buf);\n    ggml_free(ctx_static);\n    ggml_opt_dataset_free(dataset);\n\n    return std::make_pair(npass, ntest);\n}\n\nstatic std::pair<int, int> test_backend(\n    ggml_backend_sched_t backend_sched, ggml_backend_t backend, enum ggml_opt_optimizer_type optim) {\n    int npass = 0;\n    int ntest = 0;\n\n    for (bool shuffle : {false, true}) {\n        std::pair<int, int> partial = test_dataset(optim, backend_sched, backend, shuffle);\n        npass += partial.first;\n        ntest += partial.second;\n    }\n    {\n        std::pair<int, int> partial = test_grad(optim, backend_sched, backend);\n        npass += partial.first;\n        ntest += partial.second;\n    }\n    for (bool high_level : {false, true}){\n        for (bool shuffle : {false, true}) {\n            if (!high_level && shuffle) {\n                continue;\n            }\n\n            std::pair<int, int> partial = test_forward_backward(optim, backend_sched, backend, high_level, shuffle);\n            npass += partial.first;\n            ntest += partial.second;\n        }\n    }\n    {\n      std::pair<int, int> partial = test_epoch_vs_fit(optim, backend_sched, backend);\n        npass += partial.first;\n        ntest += partial.second;\n    }\n    for (bool high_level : {false, true}){\n        std::pair<int, int> partial = test_idata_split(optim, backend_sched, backend, high_level);\n        npass += partial.first;\n        ntest += partial.second;\n    }\n    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;\n    if (adamw) {\n        for (int32_t nbatch_physical : { 2, 1 }) {\n            for (enum ggml_opt_loss_type loss_type : { GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN }) {\n                std::pair<int, int> partial =\n                    test_gradient_accumulation(optim, backend_sched, backend, nbatch_physical, loss_type);\n                npass += partial.first;\n              
  ntest += partial.second;\n            }\n        }\n    }\n    {\n        std::pair<int, int> partial = test_regression(optim, backend_sched, backend);\n        npass += partial.first;\n        ntest += partial.second;\n    }\n\n    return std::make_pair(npass, ntest);\n}\n\n\nint main(void) {\n    ggml_log_set(nullptr, nullptr);\n    ggml_backend_load_all();\n    const size_t dev_count = ggml_backend_dev_count();\n    printf(\"Testing %zu devices\\n\\n\", dev_count);\n    size_t n_ok = 0;\n\n    std::vector<ggml_backend_dev_t> devs;\n    std::vector<ggml_backend_t>     backends;\n\n    for (size_t i = 0; i < dev_count; ++i) {\n        devs.push_back(ggml_backend_dev_get(i));\n\n        ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL);\n        GGML_ASSERT(backend != NULL);\n\n        auto * reg = ggml_backend_dev_backend_reg(devs[i]);\n        auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, \"ggml_backend_set_n_threads\");\n        if (ggml_backend_set_n_threads_fn) {\n            ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency() / 2);\n        }\n        backends.push_back(backend);\n    }\n\n    size_t n_total = 0;\n    for (enum ggml_opt_optimizer_type optim : { GGML_OPT_OPTIMIZER_TYPE_ADAMW, GGML_OPT_OPTIMIZER_TYPE_SGD }) {\n        for (size_t i = 0; i < dev_count; ++i) {\n            // Put the backend to be tested in front so that it's prioritized:\n            std::vector<ggml_backend_t> backends_modded = { backends[i] };\n            backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());\n\n            ggml_backend_sched_t backend_sched = ggml_backend_sched_new(\n                backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);\n\n            char const* devname = ggml_backend_dev_name(devs[i]);\n            printf(\"Backend %zu/%zu: %s\\n\", i + 1, dev_count, devname);\n            
printf(\"  Device description: %s\\n\", ggml_backend_dev_description(devs[i]));\n            size_t free, total;  // NOLINT\n            ggml_backend_dev_memory(devs[i], &free, &total);\n            printf(\"  Device memory: %zu MB (%zu MB free)\\n\", total / 1024 / 1024, free / 1024 / 1024);\n            printf(\"\\n\");\n\n            bool skip;\n            {\n                struct ggml_init_params params = {\n                    /*.mem_size   =*/ 6*ggml_tensor_overhead(),\n                    /*.mem_buffer =*/ nullptr,\n                    /*.no_alloc   =*/ true,\n                };\n                ggml_context * ctx = ggml_init(params);\n                ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n                ggml_set_param(a);\n                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n                ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n                ggml_tensor * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);\n\n                ggml_tensor * t = nullptr;\n                switch (optim) {\n                    case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {\n                        ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7);\n                        t = ggml_opt_step_adamw(ctx, a, b, c, d, p);\n                    } break;\n                    case GGML_OPT_OPTIMIZER_TYPE_SGD: {\n                        ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);\n                        t = ggml_opt_step_sgd(ctx, a, b, p);\n                    } break;\n                    case GGML_OPT_OPTIMIZER_TYPE_COUNT: {\n                        GGML_ABORT(\"fatal error\");\n                    }\n                }\n                skip = !ggml_backend_supports_op(backends[i], t);\n                ggml_free(ctx);\n            }\n\n            std::pair<int, int> result;\n            if (!skip) {\n                result = test_backend(backend_sched, backends[i], optim);\n                
printf(\"  %d/%d tests passed\\n\", result.first, result.second);\n            }\n\n            printf(\"  Backend %s %s: \", ggml_backend_name(backends[i]), ggml_opt_optimizer_name(optim));\n            if (skip) {\n                printf(\"\\033[0;33mSKIPPED\\033[0m\\n\");\n                n_ok++;\n            } else if (result.first == result.second) {\n                printf(\"\\033[1;32mOK\\033[0m\\n\");\n                n_ok++;\n            } else {\n                printf(\"\\033[1;31mFAIL\\033[0m\\n\");\n            }\n            ++n_total;\n            printf(\"\\n\");\n            ggml_backend_sched_free(backend_sched);\n        }\n    }\n\n    for (ggml_backend_t backend : backends) {\n        ggml_backend_free(backend);\n    }\n\n    printf(\"%zu/%zu backend*optimizer passed\\n\", n_ok, n_total);\n    bool ok = n_ok == n_total;\n    print_ok(ok);\n    return ok ? 0 : 1;\n}\n"
  },
  {
    "path": "tests/test-pad-reflect-1d.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include <string.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#include <vector>\n\n\nstatic void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {\n    (void) level;\n    (void) user_data;\n    fputs(text, stderr);\n    fflush(stderr);\n}\n\nstruct ggml_context* make_ctx(void) {\n    struct ggml_init_params params = {\n        /*.mem_size   =*/ 2 * 1024 * 1024,\n        /*.mem_buffer =*/ nullptr,\n        /*.no_alloc   =*/ false\n    };\n    return ggml_init(params);\n}\n\n// Check that the F32 tensor `t` has shape (ne0, ne1, ne2) and that its contents\n// equal `expected_t_d`, which is laid out row-major as [ne2][ne1][ne0].\n// The data is copied to host memory with ggml_backend_tensor_get() instead of\n// reading t->data directly: with a GPU backend (e.g. CUDA) t->data is a device\n// pointer that must not be dereferenced on the host.\n// NOTE(review): assumes `t` is contiguous, which holds for the\n// ggml_pad_reflect_1d outputs checked in this file.\nstatic void check_tensor(struct ggml_tensor * t, const float * expected_t_d, int ne0, int ne1, int ne2) {\n    GGML_ASSERT(t->type == GGML_TYPE_F32);\n    GGML_ASSERT(t->ne[0] == ne0);\n    GGML_ASSERT(t->ne[1] == ne1);\n    GGML_ASSERT(t->ne[2] == ne2);\n\n    std::vector<float> actual_t_d((size_t) ne0 * ne1 * ne2);\n    ggml_backend_tensor_get(t, actual_t_d.data(), 0, actual_t_d.size() * sizeof(float));\n\n    for (int i2 = 0; i2 < ne2; ++i2) {\n        for (int i1 = 0; i1 < ne1; ++i1) {\n            for (int i0 = 0; i0 < ne0; ++i0) {\n                const size_t idx = ((size_t) i2 * ne1 + i1) * ne0 + i0;\n                const float expected = expected_t_d[idx];\n                const float actual   = actual_t_d[idx];\n                if (expected != actual) {\n                    printf(\"expected %.1f, got %.1f at (%d,%d,%d)\\n\", expected, actual, i0, i1, i2);\n                }\n                GGML_ASSERT(expected == actual);\n            }\n        }\n    }\n}\n\n// Run ggml_pad_reflect_1d on a 1D and on a 2D input tensor and compare the results\n// with hand-computed expectations. When use_gpu is true and CUDA or Metal support\n// is compiled in, that backend is tried first; the CPU backend is the fallback.\nstatic void test_pad_reflect_1d(bool use_gpu) {\n    (void) use_gpu; // unused when neither CUDA nor Metal support is compiled in\n\n    ggml_backend_t backend = NULL;\n    struct ggml_init_params params;\n    ggml_backend_buffer_t buffer;\n    struct ggml_context * ctx;\n    struct ggml_tallocr tallocr;\n    ggml_gallocr_t gallocr;\n\n    // initialize the backend\n#ifdef GGML_USE_CUDA\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n        backend = ggml_backend_cuda_init(0);\n        if (!backend) {\n            fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n#ifdef GGML_USE_METAL\n    if (use_gpu) {\n        fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n        backend = ggml_backend_metal_init();\n        if (!backend) {\n            fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n        }\n    }\n#endif\n\n    if (!backend) {\n        fprintf(stderr, \"%s: using CPU backend\\n\", __func__);\n        backend = ggml_backend_cpu_init();\n    }\n    GGML_ASSERT(backend != NULL);\n\n    // Test cases for different padding configurations\n    {\n        params = ggml_init_params{\n            /*.mem_size   =*/ 16*1024*1024,\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true\n        };\n\n        ggml_log_set(ggml_log_callback_default, nullptr);\n\n        ctx = ggml_init(params);\n        buffer = ggml_backend_alloc_buffer(backend, 16*1024*1024);\n        tallocr = ggml_tallocr_new(buffer);\n        gallocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));\n\n        // Create a simple 1D input tensor [1, 2, 3, 4]\n        struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);\n        float input_data[] = {1.0f, 2.0f, 3.0f, 4.0f};\n        ggml_tallocr_alloc(&tallocr, t);\n\n        // load data to buffer\n        if(ggml_backend_is_cpu(backend)) {\n            memcpy(t->data, input_data, ggml_nbytes(t));\n        } else {\n            ggml_backend_tensor_set(t, input_data, 0, ggml_nbytes(t));\n        }\n\n        // Test case 1: pad left=1, right=1\n        // Expected: [2, 1, 2, 3, 4, 3]\n        float expected_1[] = {2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 3.0f};\n        struct ggml_tensor * out_1 = ggml_pad_reflect_1d(ctx, t, 1, 1);\n\n        // Test case 2: pad left=2, right=1\n        // Expected: [3, 2, 1, 2, 3, 4, 3]\n        float expected_2[] = {3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 3.0f};\n        struct ggml_tensor * out_2 = ggml_pad_reflect_1d(ctx, t, 2, 1);\n\n        // Test case 3: pad left=1, right=2\n        // Expected: [2, 1, 2, 3, 4, 3, 2]\n        float expected_3[] = {2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 3.0f, 2.0f};\n        struct ggml_tensor * out_3 = ggml_pad_reflect_1d(ctx, t, 1, 2);\n\n        struct ggml_cgraph * gf = ggml_new_graph(ctx);\n        ggml_build_forward_expand(gf, out_1);\n        ggml_build_forward_expand(gf, out_2);\n        ggml_build_forward_expand(gf, out_3);\n\n        const bool alloc_ok = ggml_gallocr_alloc_graph(gallocr, gf);\n        GGML_ASSERT(alloc_ok);\n\n        const enum ggml_status status = ggml_backend_graph_compute(backend, gf);\n        GGML_ASSERT(status == GGML_STATUS_SUCCESS);\n\n        check_tensor(out_1, expected_1, 6, 1, 1);\n        check_tensor(out_2, expected_2, 7, 1, 1);\n        check_tensor(out_3, expected_3, 7, 1, 1);\n\n        ggml_free(ctx);\n        ggml_backend_buffer_free(buffer);\n        ggml_gallocr_free(gallocr);\n    }\n\n    {\n        params = ggml_init_params{\n            /*.mem_size   =*/ 16*1024*1024,\n            /*.mem_buffer =*/ nullptr,\n            /*.no_alloc   =*/ true\n        };\n\n        ggml_log_set(ggml_log_callback_default, nullptr);\n\n        ctx = ggml_init(params);\n        buffer = ggml_backend_alloc_buffer(backend, 16*1024*1024);\n        tallocr = ggml_tallocr_new(buffer);\n        gallocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));\n\n        // Create a 2D input tensor (5 columns × 4 rows)\n        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 5, 4);\n        float input_data[] = {\n            1.0f, 2.0f, 3.0f, 4.0f, 5.0f,  // row 1\n            6.0f, 7.0f, 8.0f, 9.0f, 10.0f, // row 2\n            11.0f, 12.0f, 13.0f, 14.0f, 15.0f, // row 3\n            16.0f, 17.0f, 18.0f, 19.0f, 20.0f  // row 4\n        };\n        ggml_tallocr_alloc(&tallocr, t);\n\n        // load data to buffer\n        if(ggml_backend_is_cpu(backend)) {\n            memcpy(t->data, input_data, ggml_nbytes(t));\n        } else {\n            ggml_backend_tensor_set(t, input_data, 0, ggml_nbytes(t));\n        }\n\n        // Test case 4: pad left=3, right=2 on a 2D tensor\n        // Each row should be padded independently\n        float expected_4[] = {\n            4.0f, 3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f,  // row 1\n            9.0f, 8.0f, 7.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 9.0f, 8.0f, // row 2\n            14.0f, 13.0f, 12.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 14.0f, 13.0f, // row 3\n            19.0f, 18.0f, 17.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 19.0f, 18.0f  // row 4\n        };\n        struct ggml_tensor * out_4 = ggml_pad_reflect_1d(ctx, t, 3, 2);\n\n        struct ggml_cgraph * gf = ggml_new_graph(ctx);\n        ggml_build_forward_expand(gf, out_4);\n\n        const bool alloc_ok = ggml_gallocr_alloc_graph(gallocr, gf);\n        GGML_ASSERT(alloc_ok);\n\n        const enum ggml_status status = ggml_backend_graph_compute(backend, gf);\n        GGML_ASSERT(status == GGML_STATUS_SUCCESS);\n\n        check_tensor(out_4, expected_4, 10, 4, 1);\n\n        ggml_free(ctx);\n        ggml_gallocr_free(gallocr);\n        ggml_backend_buffer_free(buffer);\n    }\n\n    ggml_backend_free(backend);\n}\n\nint main(int argc, const char * argv[]) {\n    bool use_gpu = false;\n    if (argc > 1) {\n        use_gpu = strcmp(argv[1], \"--gpu\") == 0;\n    }\n    test_pad_reflect_1d(use_gpu);\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-pool.c",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#include <string.h>\n#include <stdio.h>\n#include <stdlib.h>\n\nstruct ggml_context* make_ctx(void) {\n    struct ggml_init_params params = {\n        .mem_size = 2 * 1024 * 1024,\n    };\n\n    return ggml_init(params);\n}\n\nint main(int argc, const char** argv) {\n\n    float buf_f32[1024];\n    ggml_fp16_t buf_f16[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf_f32[i] = (float)(i + 1);\n        buf_f16[i] = ggml_fp32_to_fp16(buf_f32[i]);\n    }\n\n    // avg pool 1d - Float 32\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        memcpy(t->data, buf_f32, ggml_nbytes(t));\n\n        struct ggml_tensor * t_pooled = ggml_pool_1d(ctx, t, GGML_OP_POOL_AVG, 3, 3, 0);\n        GGML_ASSERT(t_pooled->ne[0] == 3);\n        GGML_ASSERT(t_pooled->ne[1] == 2);\n        GGML_ASSERT(t_pooled->ne[2] == 1);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t_pooled);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(t_pooled);\n\n        GGML_ASSERT(output[0] == 2);\n        GGML_ASSERT(output[1] == 5);\n        GGML_ASSERT(output[2] == 8);\n        GGML_ASSERT(output[3] == 12);\n        GGML_ASSERT(output[4] == 15);\n        GGML_ASSERT(output[5] == 18);\n\n        ggml_free(ctx);\n    }\n\n    // avg pool 1d - Float 16\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 10, 2);\n        memcpy(t->data, buf_f16, ggml_nbytes(t));\n\n        struct ggml_tensor * t_pooled = ggml_pool_1d(ctx, t, GGML_OP_POOL_AVG, 3, 3, 0);\n        GGML_ASSERT(t_pooled->ne[0] == 3);\n        GGML_ASSERT(t_pooled->ne[1] == 2);\n        GGML_ASSERT(t_pooled->ne[2] == 1);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        
ggml_build_forward_expand(graph, t_pooled);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(t_pooled);\n\n        GGML_ASSERT(output[0] == 2);\n        GGML_ASSERT(output[1] == 5);\n        GGML_ASSERT(output[2] == 8);\n        GGML_ASSERT(output[3] == 12);\n        GGML_ASSERT(output[4] == 15);\n        GGML_ASSERT(output[5] == 18);\n\n        ggml_free(ctx);\n    }\n\n    // max pool 1d - Float 32\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);\n        memcpy(t->data, buf_f32, ggml_nbytes(t));\n\n        struct ggml_tensor * t_pooled = ggml_pool_1d(ctx, t, GGML_OP_POOL_MAX, 3, 3, 0);\n        GGML_ASSERT(t_pooled->ne[0] == 3);\n        GGML_ASSERT(t_pooled->ne[1] == 2);\n        GGML_ASSERT(t_pooled->ne[2] == 1);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t_pooled);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(t_pooled);\n        GGML_ASSERT(output[0] == 3);\n        GGML_ASSERT(output[1] == 6);\n        GGML_ASSERT(output[2] == 9);\n        GGML_ASSERT(output[3] == 13);\n        GGML_ASSERT(output[4] == 16);\n        GGML_ASSERT(output[5] == 19);\n\n        ggml_free(ctx);\n    }\n\n    // max pool 1d - Float 16\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 10, 2);\n        memcpy(t->data, buf_f16, ggml_nbytes(t));\n\n        struct ggml_tensor * t_pooled = ggml_pool_1d(ctx, t, GGML_OP_POOL_MAX, 3, 3, 0);\n        GGML_ASSERT(t_pooled->ne[0] == 3);\n        GGML_ASSERT(t_pooled->ne[1] == 2);\n        GGML_ASSERT(t_pooled->ne[2] == 1);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t_pooled);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        
const float * output = ggml_get_data_f32(t_pooled);\n        GGML_ASSERT(output[0] == 3);\n        GGML_ASSERT(output[1] == 6);\n        GGML_ASSERT(output[2] == 9);\n        GGML_ASSERT(output[3] == 13);\n        GGML_ASSERT(output[4] == 16);\n        GGML_ASSERT(output[5] == 19);\n\n        ggml_free(ctx);\n    }\n\n    // avg pool 2d - Float 32\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 10, 10, 2);\n        memcpy(t->data, buf_f32, ggml_nbytes(t));\n\n        struct ggml_tensor * t_pooled = ggml_pool_2d(ctx, t, GGML_OP_POOL_AVG, 3, 4, 3, 4, 0, 0);\n        GGML_ASSERT(t_pooled->ne[0] == 3);\n        GGML_ASSERT(t_pooled->ne[1] == 2);\n        GGML_ASSERT(t_pooled->ne[2] == 2);\n        GGML_ASSERT(t_pooled->ne[3] == 1);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t_pooled);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(t_pooled);\n        GGML_ASSERT(output[0] == 17);\n        GGML_ASSERT(output[1] == 20);\n        GGML_ASSERT(output[2] == 23);\n        GGML_ASSERT(output[3] == 57);\n        GGML_ASSERT(output[4] == 60);\n        GGML_ASSERT(output[5] == 63);\n        GGML_ASSERT(output[6] == 117);\n        GGML_ASSERT(output[7] == 120);\n        GGML_ASSERT(output[8] == 123);\n        GGML_ASSERT(output[9] == 157);\n        GGML_ASSERT(output[10] == 160);\n        GGML_ASSERT(output[11] == 163);\n\n\n        ggml_free(ctx);\n    }\n\n    // avg pool 2d - Float 16\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 10, 10, 2);\n        memcpy(t->data, buf_f16, ggml_nbytes(t));\n\n        struct ggml_tensor * t_pooled = ggml_pool_2d(ctx, t, GGML_OP_POOL_AVG, 3, 4, 3, 4, 0, 0);\n        GGML_ASSERT(t_pooled->ne[0] == 3);\n        GGML_ASSERT(t_pooled->ne[1] == 2);\n        
GGML_ASSERT(t_pooled->ne[2] == 2);\n        GGML_ASSERT(t_pooled->ne[3] == 1);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t_pooled);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(t_pooled);\n        GGML_ASSERT(output[0] == 17);\n        GGML_ASSERT(output[1] == 20);\n        GGML_ASSERT(output[2] == 23);\n        GGML_ASSERT(output[3] == 57);\n        GGML_ASSERT(output[4] == 60);\n        GGML_ASSERT(output[5] == 63);\n        GGML_ASSERT(output[6] == 117);\n        GGML_ASSERT(output[7] == 120);\n        GGML_ASSERT(output[8] == 123);\n        GGML_ASSERT(output[9] == 157);\n        GGML_ASSERT(output[10] == 160);\n        GGML_ASSERT(output[11] == 163);\n\n\n        ggml_free(ctx);\n    }\n\n    // max pool 2d - Float 32\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 10, 10, 2);\n        memcpy(t->data, buf_f32, ggml_nbytes(t));\n\n        struct ggml_tensor * t_pooled = ggml_pool_2d(ctx, t, GGML_OP_POOL_MAX, 3, 4, 3, 4, 0, 0);\n        GGML_ASSERT(t_pooled->ne[0] == 3);\n        GGML_ASSERT(t_pooled->ne[1] == 2);\n        GGML_ASSERT(t_pooled->ne[2] == 2);\n        GGML_ASSERT(t_pooled->ne[3] == 1);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t_pooled);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(t_pooled);\n        GGML_ASSERT(output[0] == 33);\n        GGML_ASSERT(output[1] == 36);\n        GGML_ASSERT(output[2] == 39);\n        GGML_ASSERT(output[3] == 73);\n        GGML_ASSERT(output[4] == 76);\n        GGML_ASSERT(output[5] == 79);\n        GGML_ASSERT(output[6] == 133);\n        GGML_ASSERT(output[7] == 136);\n        GGML_ASSERT(output[8] == 139);\n        GGML_ASSERT(output[9] == 173);\n        GGML_ASSERT(output[10] == 176);\n       
 GGML_ASSERT(output[11] == 179);\n\n        ggml_free(ctx);\n    }\n\n    // max pool 2d - Float 16\n    {\n        struct ggml_context * ctx = make_ctx();\n        struct ggml_tensor * t = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 10, 10, 2);\n        memcpy(t->data, buf_f16, ggml_nbytes(t));\n\n        struct ggml_tensor * t_pooled = ggml_pool_2d(ctx, t, GGML_OP_POOL_MAX, 3, 4, 3, 4, 0, 0);\n        GGML_ASSERT(t_pooled->ne[0] == 3);\n        GGML_ASSERT(t_pooled->ne[1] == 2);\n        GGML_ASSERT(t_pooled->ne[2] == 2);\n        GGML_ASSERT(t_pooled->ne[3] == 1);\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t_pooled);\n\n        ggml_graph_compute_with_ctx(ctx, graph, 4);\n\n        const float * output = ggml_get_data_f32(t_pooled);\n        GGML_ASSERT(output[0] == 33);\n        GGML_ASSERT(output[1] == 36);\n        GGML_ASSERT(output[2] == 39);\n        GGML_ASSERT(output[3] == 73);\n        GGML_ASSERT(output[4] == 76);\n        GGML_ASSERT(output[5] == 79);\n        GGML_ASSERT(output[6] == 133);\n        GGML_ASSERT(output[7] == 136);\n        GGML_ASSERT(output[8] == 139);\n        GGML_ASSERT(output[9] == 173);\n        GGML_ASSERT(output[10] == 176);\n        GGML_ASSERT(output[11] == 179);\n\n        ggml_free(ctx);\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-quantize-fns.cpp",
    "content": "// Unit tests for quantization specific functions - quantize, dequantize and dot product\n\n#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#undef NDEBUG\n#include <assert.h>\n#include <math.h>\n#include <stdio.h>\n#include <string>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\nconstexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;\nconstexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;\nconstexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f;\nconstexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;\nconstexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;\nconstexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;\nconstexpr float MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f;\nconstexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;\nconstexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;\nconstexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f;\nconstexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;\n\nstatic const char* RESULT_STR[] = {\"ok\", \"FAILED\"};\n\n\n// Generate synthetic data\nstatic void generate_data(float offset, size_t n, float * dst) {\n    for (size_t i = 0; i < n; i++) {\n        dst[i] = 0.1 + 2*cosf(i + offset);\n    }\n}\n\n// Calculate RMSE between two float arrays\nstatic float array_rmse(const float * a1, const float * a2, size_t n) {\n    double sum = 0;\n    for (size_t i = 0; i < n; i++) {\n        double diff = a1[i] - a2[i];\n        sum += diff * diff;\n    }\n    return sqrtf(sum) / n;\n}\n\n// Total quantization error on test data\nstatic float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {\n    std::vector<uint8_t> tmp_q(2*test_size);\n    std::vector<float> tmp_out(test_size);\n\n    qfns_cpu->from_float(test_data, tmp_q.data(), test_size);\n    qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);\n    return array_rmse(test_data, 
tmp_out.data(), test_size);\n}\n\n// Total quantization error on test data\nstatic float reference_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {\n    std::vector<uint8_t> tmp_q(2*test_size);\n    std::vector<float> tmp_out(test_size);\n    std::vector<float> tmp_out_ref(test_size);\n\n    // FIXME: why is done twice?\n    qfns_cpu->from_float(test_data, tmp_q.data(), test_size);\n    qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);\n\n    qfns->from_float_ref(test_data, tmp_q.data(), test_size);\n    qfns->to_float(tmp_q.data(), tmp_out_ref.data(), test_size);\n\n    return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);\n}\n\nstatic float dot_product(const float * a1, const float * a2, size_t test_size) {\n    double sum = 0;\n    for (size_t i = 0; i < test_size; i++) {\n        sum += a1[i] * a2[i];\n    }\n    return sum;\n}\n\n// Total dot product error\nstatic float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2) {\n    GGML_UNUSED(qfns);\n\n    std::vector<uint8_t> tmp_q1(2*test_size);\n    std::vector<uint8_t> tmp_q2(2*test_size);\n\n    const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);\n\n    qfns_cpu->from_float(test_data1, tmp_q1.data(), test_size);\n    vdot->from_float(test_data2, tmp_q2.data(), test_size);\n\n    float result = INFINITY;\n    qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);\n\n    const float dot_ref = dot_product(test_data1, test_data2, test_size);\n\n    return fabsf(result - dot_ref) / test_size;\n}\n\nint main(int argc, char * argv[]) {\n    bool verbose = false;\n    const size_t test_size = 32 * 128;\n\n    std::string arg;\n    for (int i = 1; i < argc; i++) {\n        arg = argv[i];\n\n        if (arg == \"-v\") {\n            verbose = true;\n        } else {\n  
          fprintf(stderr, \"error: unknown argument: %s\\n\", arg.c_str());\n            return 1;\n        }\n    }\n\n    std::vector<float> test_data(test_size);\n    std::vector<float> test_data2(test_size);\n\n    generate_data(0.0, test_data.size(), test_data.data());\n    generate_data(1.0, test_data2.size(), test_data2.data());\n\n    ggml_cpu_init();\n\n    int num_failed = 0;\n    bool failed = false;\n\n    for (int i = 0; i < GGML_TYPE_COUNT; i++) {\n        ggml_type type = (ggml_type) i;\n        const auto * qfns = ggml_get_type_traits(type);\n        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);\n\n        // deprecated - skip\n        if (qfns->blck_size == 0) {\n            continue;\n        }\n\n        const ggml_type ei = (ggml_type)i;\n\n        printf(\"Testing %s\\n\", ggml_type_name((ggml_type) i));\n        ggml_quantize_init(ei);\n\n        if (qfns_cpu->from_float && qfns->to_float) {\n            const float total_error = total_quantization_error(qfns, qfns_cpu, test_size, test_data.data());\n            const float max_quantization_error =\n                type == GGML_TYPE_TQ1_0   ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :\n                type == GGML_TYPE_TQ2_0   ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :\n                type == GGML_TYPE_Q2_K    ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :\n                type == GGML_TYPE_IQ2_S   ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :\n                type == GGML_TYPE_Q3_K    ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :\n                type == GGML_TYPE_IQ3_S   ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :\n                type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS :\n                type == GGML_TYPE_NVFP4   ? 
MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR;\n            failed = !(total_error < max_quantization_error);\n            num_failed += failed;\n            if (failed || verbose) {\n                printf(\"%5s absolute quantization error:    %s (%f)\\n\", ggml_type_name(type), RESULT_STR[failed], total_error);\n            }\n\n            const float reference_error = reference_quantization_error(qfns, qfns_cpu, test_size, test_data.data());\n            failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);\n            num_failed += failed;\n            if (failed || verbose) {\n                printf(\"%5s reference implementation error: %s (%f)\\n\", ggml_type_name(type), RESULT_STR[failed], reference_error);\n            }\n\n            const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());\n            const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||\n                                            type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S\n                                          ? MAX_DOT_PRODUCT_ERROR_LOWBIT\n                                          : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0\n                                          ? MAX_DOT_PRODUCT_ERROR_TERNARY\n                                          : type == GGML_TYPE_NVFP4\n                                          ? 
MAX_DOT_PRODUCT_ERROR_FP4\n                                          : MAX_DOT_PRODUCT_ERROR;\n            failed = !(vec_dot_error < max_allowed_error);\n            num_failed += failed;\n            if (failed || verbose) {\n                printf(\"%5s dot product error:              %s (%f)\\n\", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);\n            }\n        }\n    }\n\n    if (num_failed || verbose) {\n        printf(\"%d tests failed\\n\", num_failed);\n    }\n\n    return num_failed > 0;\n}\n"
  },
  {
    "path": "tests/test-quantize-perf.cpp",
    "content": "// Benchmark quantization specific functions on synthetic data\n\n#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#undef NDEBUG\n#include <algorithm>\n#include <assert.h>\n#include <functional>\n#include <math.h>\n#include <memory>\n#include <stdio.h>\n#include <string>\n#include <vector>\n\n#if defined(_MSC_VER)\n#pragma warning(disable: 4244 4267) // possible loss of data\n#endif\n\n#define MAX_ALIGNMENT 64\n#define QK 32\n#define WARMUP 5\n#define ITERATIONS 10\n#define MAX_ITERATIONS 100000000\n\n#define L1_SIZE      32*128\n#define L2_SIZE     32*2048\n#define L3_SIZE    32*20480\n#define MEM_SIZE 32*2048000\n\nstruct quantize_perf_params {\n    std::vector<std::string> include_types;\n    std::vector<size_t> test_sizes;\n    size_t alignment_offset = 0;\n    bool op_quantize_row_q_reference = false;\n    bool op_quantize_row_q = false;\n    bool op_dequantize_row_q = false;\n    bool op_quantize_row_q_dot = false;\n    bool op_vec_dot_q = false;\n    int64_t iterations = ITERATIONS;\n};\n\n#if defined(__x86_64__) || defined(__i386__)\n\n#include <x86intrin.h>\ninline int64_t cpu_cycles() {\n// Rough way to detect new-ish CPUs\n#ifdef __POPCNT__\n    unsigned int dummy;\n    return __rdtscp(&dummy);\n#else\n    return __rdtsc();\n#endif\n}\n\n#else\n\n#define cpu_cycles() 0\n\n#endif\n\n\n// Generate synthetic data\nstatic void generate_data(float offset, size_t n, float * dst) {\n    for (size_t i = 0; i < n; i++) {\n        dst[i] = 0.1 + 2*cosf(i + offset);\n    }\n}\n\nstatic float gigabytes_per_second(size_t bytes, int64_t usecs) {\n    return bytes / (float) usecs * 1000000 / (1024*1024*1024);\n}\n\nstatic void * align_with_offset(void * ptr, int offset) {\n    size_t dummy_size = MAX_ALIGNMENT * 4;\n    return (char *) std::align(MAX_ALIGNMENT, MAX_ALIGNMENT, ptr, dummy_size) + offset;\n}\n\nstatic void benchmark_function(size_t size, size_t q_size, int64_t iterations, const std::function<float(void)> & func) {\n    int64_t min_time_us = 
INT64_MAX;\n    int64_t total_time_us = 0;\n    int64_t min_time_cycles = INT64_MAX;\n    int64_t total_time_cycles = 0;\n\n    for (int i = 0; i < WARMUP; i++) {\n        func();\n    }\n\n    for (int i = 0; i < iterations; i++) {\n        const int64_t start_time = ggml_time_us();\n        const int64_t start_cycles = cpu_cycles();\n\n        func();\n\n        const int64_t end_cycles = cpu_cycles();\n        const int64_t end_time = ggml_time_us();\n\n        total_time_cycles += end_cycles - start_cycles;\n        min_time_cycles = std::min(min_time_cycles, end_cycles - start_cycles);\n        total_time_us += end_time - start_time;\n        min_time_us = std::min(min_time_us, end_time - start_time);\n    }\n\n    printf(\"      min cycles/%d vals   : %9.2f\\n\",  QK, QK * min_time_cycles / (float) size);\n    printf(\"      avg cycles/%d vals   : %9.2f\\n\",  QK, QK * total_time_cycles / (float) (size * iterations));\n    printf(\"      float32 throughput   : %9.2f GB/s\\n\",  gigabytes_per_second(4 * size * iterations, total_time_us));\n    printf(\"      quantized throughput : %9.2f GB/s\\n\",  gigabytes_per_second(q_size * iterations, total_time_us));\n}\n\nstatic void usage(char * argv[]) {\n    printf(\"Benchmark quantization specific functions on synthetic data\\n\");\n    printf(\"\\n\");\n    printf(\"usage: %s [options]\\n\", argv[0]);\n    printf(\"\\n\");\n    printf(\"options: (default)\\n\");\n    printf(\"  -h, --help            show this help message and exit\\n\");\n    printf(\"  --size SIZE           set test size, divisible by 32 (L1_SIZE:%d)\\n\", L1_SIZE);\n    printf(\"  -3                    use size as L1, L2, L3 sizes (L1:%d L2:%d L3:%d)\\n\", L1_SIZE, L2_SIZE, L3_SIZE);\n    printf(\"  -4                    use size as L1, L2, L3, MEM sizes (L1:%d L2:%d L3:%d MEM:%d)\\n\", L1_SIZE, L2_SIZE, L3_SIZE, MEM_SIZE);\n    printf(\"  --op OP               set test operation as quantize_row_q_reference, quantize_row_q, 
dequantize_row_q,\\n\");\n    printf(\"                        quantize_row_q_dot, vec_dot_q (all)\\n\");\n    printf(\"  --type TYPE           set test type as\");\n    for (int i = 0; i < GGML_TYPE_COUNT; i++) {\n        ggml_type type = (ggml_type) i;\n        const auto * qfns     = ggml_get_type_traits(type);\n        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);\n        if (ggml_type_name(type) != NULL) {\n            if (qfns_cpu->from_float && qfns->to_float) {\n                printf(\" %s\", ggml_type_name(type));\n            }\n        }\n    }\n    printf(\" (all)\\n\");\n    printf(\"  --alignment-offset OFFSET\\n\");\n    printf(\"                        set alignment offset as OFFSET (0)\\n\");\n    printf(\"  -i NUM, --iterations NUM\\n\");\n    printf(\"                        set test iteration number (%d)\\n\", ITERATIONS);\n}\n\nint main(int argc, char * argv[]) {\n    quantize_perf_params params {};\n\n    // read command line\n\n    bool invalid_param = false;\n    std::string arg;\n    for (int i = 1; i < argc; i++) {\n        arg = argv[i];\n\n        if (arg == \"--size\") {\n            if (++i >= argc) {\n                invalid_param = true;\n                break;\n            }\n            size_t size = std::stoi(argv[i]);\n            if (size % 32 != 0) {\n                fprintf(stderr, \"error: size %zu not divisible by 32\\n\", size);\n                invalid_param = true;\n                break;\n            }\n            params.test_sizes.push_back(size);\n        } else if (arg == \"-3\") {\n            // quick select sizes that probably fit in CPU caches\n            params.test_sizes.push_back(L1_SIZE);\n            params.test_sizes.push_back(L2_SIZE);\n            params.test_sizes.push_back(L3_SIZE);\n        } else if (arg == \"-4\") {\n            // quick select cache sizes + memory\n            params.test_sizes.push_back(L1_SIZE);\n            params.test_sizes.push_back(L2_SIZE);\n            
params.test_sizes.push_back(L3_SIZE);\n            params.test_sizes.push_back(MEM_SIZE);\n        } else if (arg == \"--op\") {\n            if (++i >= argc) {\n                invalid_param = true;\n                break;\n            }\n            std::string op {argv[i]};\n            if (op == \"quantize_row_q_reference\") {\n                params.op_quantize_row_q_reference = true;\n            } else if (op == \"quantize_row_q\") {\n                params.op_quantize_row_q = true;\n            } else if (op == \"dequantize_row_q\") {\n                params.op_dequantize_row_q = true;\n            } else if (op == \"quantize_row_q_dot\") {\n                params.op_quantize_row_q_dot = true;\n            } else if (op == \"vec_dot_q\") {\n                params.op_vec_dot_q = true;\n            } else {\n                invalid_param = true;\n                break;\n            }\n        } else if (arg == \"--type\") {\n            if (++i >= argc) {\n                invalid_param = true;\n                break;\n            }\n            params.include_types.push_back(argv[i]);\n        } else if (arg == \"--alignment-offset\") {\n            if (++i >= argc) {\n                invalid_param = true;\n                break;\n            }\n            int alignment = std::stoi(argv[i]);\n            if (alignment < 0 || alignment > MAX_ALIGNMENT) {\n            fprintf(stderr, \"error: alignment-offset must be less than %d\\n\", MAX_ALIGNMENT);\n                invalid_param = true;\n                break;\n            }\n            params.alignment_offset = alignment;\n        } else if ((arg == \"-i\") || (arg == \"--iterations\")) {\n            if (++i >= argc) {\n                invalid_param = true;\n                break;\n            }\n            int number = std::stoi(argv[i]);\n            if (number < 0 || number > MAX_ITERATIONS) {\n            fprintf(stderr, \"error: iterations must be less than %d\\n\", MAX_ITERATIONS);\n               
 invalid_param = true;\n                break;\n            }\n            params.iterations = number;\n        } else if ((arg == \"-h\") || (arg == \"--help\")) {\n            usage(argv);\n            return 1;\n        } else {\n            fprintf(stderr, \"error: unknown argument: %s\\n\", arg.c_str());\n            return 1;\n        }\n    }\n    if (invalid_param) {\n        fprintf(stderr, \"error: invalid parameter for argument: %s\\n\", arg.c_str());\n        return 1;\n    }\n\n    if (params.test_sizes.empty()) {\n        params.test_sizes.push_back(L1_SIZE);\n    }\n    if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_quantize_row_q_dot || params.op_vec_dot_q)) {\n        params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_quantize_row_q_dot = params.op_vec_dot_q = true;\n    }\n\n    std::sort(params.test_sizes.begin(), params.test_sizes.end());\n    size_t largest = params.test_sizes.back();\n\n    std::vector<uint8_t> test_data1_v(largest*4 + MAX_ALIGNMENT*2);\n    std::vector<uint8_t> test_data2_v(largest*4 + MAX_ALIGNMENT*2);\n    std::vector<uint8_t> test_q1_v   (largest*4 + MAX_ALIGNMENT*2);\n    std::vector<uint8_t> test_q2_v   (largest*4 + MAX_ALIGNMENT*2);\n    std::vector<uint8_t> test_out_v  (largest*4 + MAX_ALIGNMENT*2);\n\n    float * test_data1 = (float *) align_with_offset(test_data1_v.data(), params.alignment_offset);\n    float * test_data2 = (float *) align_with_offset(test_data2_v.data(), params.alignment_offset);\n    float * test_q1    = (float *) align_with_offset(test_q1_v.data(),    params.alignment_offset);\n    float * test_q2    = (float *) align_with_offset(test_q2_v.data(),    params.alignment_offset);\n    float * test_out   = (float *) align_with_offset(test_out_v.data(),   params.alignment_offset);\n\n    generate_data(0, largest, test_data1);\n    generate_data(1, largest, test_data2);\n\n    int64_t 
iterations = params.iterations;\n\n    ggml_cpu_init();\n\n    for (int i = 0; i < GGML_TYPE_COUNT; i++) {\n        ggml_type type = (ggml_type) i;\n        const auto * qfns = ggml_get_type_traits(type);\n        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);\n        if (!params.include_types.empty() && ggml_type_name(type) && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) {\n            continue;\n        }\n\n        if (qfns_cpu->from_float && qfns->to_float) {\n            printf(\"%s\\n\", ggml_type_name(type));\n\n            ggml_quantize_init(type);\n\n            if (params.op_quantize_row_q_reference) {\n                printf(\"  quantize_row_q_reference\\n\");\n                for (size_t size : params.test_sizes) {\n                    printf(\"    %zu values (%.2f MB)\\n\", size, 4*size/(float)(1024*1024));\n                    auto quantize_fn = [&](void) -> float {\n                        qfns->from_float_ref(test_data1, test_q1, size);\n                        return test_q1[0];\n                    };\n                    size_t quantized_size = ggml_row_size(type, size);\n                    benchmark_function(size, quantized_size, iterations, quantize_fn);\n                }\n                printf(\"\\n\");\n            }\n\n            if (params.op_quantize_row_q) {\n                printf(\"  quantize_row_q\\n\");\n                for (size_t size : params.test_sizes) {\n                    printf(\"    %zu values (%.2f MB)\\n\", size, 4*size/(float)(1024*1024));\n                    auto quantize_fn = [&](void) -> float {\n                        qfns_cpu->from_float(test_data1, test_q1, size);\n                        return test_q1[0];\n                    };\n                    size_t quantized_size = ggml_row_size(type, size);\n                    benchmark_function(size, quantized_size, iterations, quantize_fn);\n                }\n                
printf(\"\\n\");\n            }\n\n            if (params.op_dequantize_row_q) {\n                printf(\"  dequantize_row_q\\n\");\n                qfns_cpu->from_float(test_data1, test_q1, largest);\n                for (size_t size : params.test_sizes) {\n                    printf(\"    %zu values (%.2f MB)\\n\", size, 4*size/(float)(1024*1024));\n                    auto quantize_fn = [&](void) -> float {\n                        qfns->to_float(test_q1, test_out, size);\n                        return test_out[0];\n                    };\n                    size_t quantized_size = ggml_row_size(type, size);\n                    benchmark_function(size, quantized_size, iterations, quantize_fn);\n                }\n                printf(\"\\n\");\n            }\n\n            if (params.op_quantize_row_q_dot) {\n                printf(\"  quantize_row_q_dot\\n\");\n                for (size_t size : params.test_sizes) {\n                    printf(\"    %zu values (%.2f MB)\\n\", size, 4*size/(float)(1024*1024));\n                    auto quantize_fn = [&](void) -> float {\n                        const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);\n                        vdot->from_float(test_data1, test_q1, size);\n                        return test_q1[0];\n                    };\n                    size_t quantized_size = ggml_row_size(type, size);\n                    benchmark_function(size, quantized_size, iterations, quantize_fn);\n                }\n                printf(\"\\n\");\n            }\n\n            if (params.op_vec_dot_q) {\n                printf(\"  vec_dot_q\\n\");\n                qfns_cpu->from_float(test_data1, test_q1, largest);\n                qfns_cpu->from_float(test_data2, test_q2, largest);\n                for (size_t size : params.test_sizes) {\n                    printf(\"    %zu values (%.2f MB)\\n\", size, 4*size/(float)(1024*1024));\n                    auto quantize_fn = [&](void) -> float {\n    
                    float result;\n                        qfns_cpu->vec_dot(size, &result, 0, test_q1, 0, test_q2, 0, 1);\n                        return result;\n                    };\n                    size_t quantized_size = ggml_row_size(type, size);\n                    benchmark_function(size, quantized_size, iterations, quantize_fn);\n                }\n                printf(\"\\n\");\n            }\n        }\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-rel-pos.c",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n\n#include <string.h>\n#include <stdio.h>\n#include <stdlib.h>\n\nstruct ggml_context* make_ctx(void) {\n    struct ggml_init_params params = {\n        .mem_size = 2 * 1024 * 1024,\n    };\n\n    return ggml_init(params);\n}\n\nvoid check_tensor(struct ggml_tensor * t, float * expected_t_d, int ne0, int ne1, int ne2) {\n    GGML_ASSERT(t->type == GGML_TYPE_F32);\n    GGML_ASSERT(t->ne[0] == ne0);\n    GGML_ASSERT(t->ne[1] == ne1);\n    GGML_ASSERT(t->ne[2] == ne2);\n    for (int i2 = 0; i2 < ne2; ++i2) {\n        for (int i1 = 0; i1 < ne1; ++i1) {\n            for (int i0 = 0; i0 < ne0; ++i0) {\n                float expected = *(expected_t_d + i2 * ne1 * ne0 + i1 * ne0 + i0);\n                float actual = ggml_get_data_f32(t)[i2 * ne1 * ne0 + i1 * ne0 + i0];\n                GGML_ASSERT(expected == actual);\n            }\n        }\n    }\n}\n\nint main(int argc, const char** argv) {\n    ggml_fp16_t buf_f16[1024];\n    for (int i = 0; i < 1024; ++i) {\n        buf_f16[i] = ggml_fp32_to_fp16((float)i);\n    }\n\n    float expected_out[4][9] = {\n        { 8.0, 9.0, 10.0, 9.0, 10.0, 11.0, 10.0, 11.0, 12.0 },\n        { 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0 },\n        { 14.0, 15.0, 16.0, 15.0, 16.0, 17.0, 16.0, 17.0, 18.0 },\n        { 8.0, 9.0, 10.0, 9.0, 10.0, 11.0, 10.0, 11.0, 12.0 },\n    };\n\n    {\n        struct ggml_context * ctx = make_ctx();\n\n\n        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 3, 3);\n        ggml_fp16_t* t_d = (ggml_fp16_t*)t->data;\n        memcpy(t_d, buf_f16, ggml_nbytes(t));\n\n        struct ggml_tensor * t_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 3, 3);\n        ggml_fp16_t* t_d_2 = (ggml_fp16_t*)t_2->data;\n        memcpy(t_d_2, buf_f16 + 1, ggml_nbytes(t_2));\n\n        struct ggml_tensor * rw = ggml_get_rel_pos(ctx, t, 2, 2);\n        struct ggml_tensor * rh = ggml_get_rel_pos(ctx, t_2, 2, 2);\n\n        struct ggml_tensor * rw_f32 = ggml_cpy(ctx, rw, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 3, 2, 2));\n        struct ggml_tensor * rh_f32 = ggml_cpy(ctx, rh, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 3, 2, 2));\n\n        struct ggml_tensor * in = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 9, 4);\n        struct ggml_tensor * out_inplace = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 9, 4);\n        float * in_d          = (float*)in->data;\n        float * out_inplace_d = (float*)out_inplace->data;\n        for (int i = 0; i < ggml_nelements(in); ++i) {\n            in_d[i]          = 1.f;\n            out_inplace_d[i] = 1.f;\n        }\n\n        struct ggml_tensor * out = ggml_add_rel_pos(ctx, in, rw_f32, rh_f32);\n        struct ggml_cgraph * gf = ggml_new_graph(ctx);\n        ggml_build_forward_expand(gf, out);\n        ggml_graph_compute_with_ctx(ctx, gf, 1);\n\n        out_inplace = ggml_add_rel_pos_inplace(ctx, out_inplace, rw_f32, rh_f32);\n        struct ggml_cgraph * gf_2 = ggml_new_graph(ctx);\n        ggml_build_forward_expand(gf_2, out_inplace);\n        ggml_graph_compute_with_ctx(ctx, gf_2, 1);\n\n        check_tensor(out, (float*)expected_out, 9, 4, 1);\n        check_tensor(out_inplace, (float*)expected_out, 9, 4, 1);\n\n        ggml_free(ctx);\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "tests/test-roll.cpp",
    "content": "#include <ggml.h>\n#include <ggml-cpu.h>\n#include <ggml-alloc.h>\n#include <ggml-backend.h>\n#include <ggml-cpp.h>\n\n#include <cassert>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <array>\n#include <numeric>\n#include <vector>\n\nint64_t wrap(int64_t i, int64_t ne) {\n    if (i < 0) {\n        return i + ne;\n    } else if (i >= ne) {\n        return i - ne;\n    }\n    return i;\n}\n\nstd::vector<float> roll_reference(\n    const float * src, std::array<int64_t, 4> ne, std::array<int, 4> shift) {\n\n    const int64_t ne0 = ne[0], ne1 = ne[1], ne2 = ne[2], ne3 = ne[3];\n    std::vector<float> dst(ne0 * ne1 * ne2 * ne3);\n\n    for (int64_t i3 = 0; i3 < ne3; ++i3) {\n        for (int64_t i2 = 0; i2 < ne2; ++i2) {\n            for (int64_t i1 = 0; i1 < ne1; ++i1) {\n                for (int64_t i0 = 0; i0 < ne0; ++i0) {\n                    const int64_t i03 = wrap(i3 - shift[3], ne3);\n                    const int64_t i02 = wrap(i2 - shift[2], ne2);\n                    const int64_t i01 = wrap(i1 - shift[1], ne1);\n                    const int64_t i00 = wrap(i0 - shift[0], ne0);\n\n                    dst[i3 * (ne2*ne1*ne0) + i2 * (ne1*ne0) + i1 * ne0 + i0] =\n                        src[i03 * (ne2*ne1*ne0) + i02 * (ne1*ne0) + i01 * ne0 + i00];\n                }\n            }\n        }\n    }\n    return dst;\n}\n\nstd::vector<float> f32_range(int64_t n) {\n    std::vector<float> values(n);\n    std::iota(values.begin(), values.end(), 0.f);\n    return values;\n}\n\nbool check_equal(const std::vector<float> & result, const std::vector<float> & expected) {\n    if (result.size() != expected.size()) {\n        printf(\"result.size() = %d, expected.size() = %d\\n\", (int)result.size(), (int)expected.size());\n        return false;\n    }\n    for (size_t i = 0; i < result.size(); i++) {\n        if(std::abs(result[i] - expected[i]) > 1e-5) {\n            printf(\"result[%d] %f != %f expected[%d]\\n\", (int)i, result[i], expected[i], (int)i);\n            return false;\n        }\n    }\n    return true;\n}\n\nbool test_roll(std::array<int64_t, 4> ne, std::array<int, 4> shift, bool permute) {\n    ggml_time_init();\n\n    ggml_init_params params {\n        /*.mem_size   =*/ 64 * ggml_tensor_overhead() + ggml_graph_overhead(),\n        /*.mem_buffer =*/ NULL,\n        /*.no_alloc   =*/ true\n    };\n\n    ggml_context_ptr ctx_ptr{ggml_init(params)};\n    ggml_context * ctx = ctx_ptr.get();\n    ggml_cgraph * gf = ggml_new_graph(ctx);\n\n    // Build graph\n    ggml_tensor * src = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());\n    ggml_tensor * res;\n    if (!permute) {\n        res = ggml_roll(ctx, src, shift[0], shift[1], shift[2], shift[3]);\n    } else {\n        ggml_tensor * p = ggml_permute(ctx, src, 0, 3, 1, 2);\n        res = ggml_roll(ctx, p, shift[0], shift[2], shift[3], shift[1]);\n        res = ggml_cont(ctx, ggml_permute(ctx, res, 0, 2, 3, 1));\n    }\n    ggml_build_forward_expand(gf, res);\n\n    // Create backend & allocate buffers\n    ggml_backend_ptr backend_ptr{ggml_backend_cpu_init()};\n    ggml_backend_t backend = backend_ptr.get();\n    ggml_backend_cpu_set_n_threads(backend, 2);\n    ggml_backend_buffer_ptr buffer{ggml_backend_alloc_ctx_tensors(ctx, backend)};\n\n    std::vector<float> src_values = f32_range(ggml_nelements(src));\n    ggml_backend_tensor_set(src, src_values.data(), 0, ggml_nbytes(src));\n\n    // Execute and compare results\n    ggml_backend_graph_compute(backend, gf);\n\n    std::vector<float> res_values(ggml_nelements(res));\n    ggml_backend_tensor_get(res, res_values.data(), 0, ggml_nbytes(res));\n\n    std::vector<float> expected = roll_reference(src_values.data(), ne, shift);\n\n    bool passed = check_equal(res_values, expected);\n\n    printf(\"ggml_roll(%d(%d), %d(%d), %d(%d), %d(%d), %s): %s\\n\",\n        int(ne[0]), int(shift[0]),\n        int(ne[1]), int(shift[1]),\n        int(ne[2]), int(shift[2]),\n        int(ne[3]), int(shift[3]),\n        permute ? \"permuted\" : \"contiguous\",\n        passed ? \"\\033[32mPASSED\\033[0m\" : \"\\033[31mFAILED\\033[0m\");\n    return passed;\n}\n\nint main() {\n    bool passed = true;\n    passed &= test_roll({3, 7, 4, 2}, {1, 0, -1, 0}, false);\n    passed &= test_roll({37, 42, 59, 2}, {-4, 3, -7, 1}, false);\n    passed &= test_roll({37, 42, 59, 2}, {-4, 3, -7, 1}, true);\n    return passed ? 0 : 1;\n}\n"
  },
  {
    "path": "tests/test-timestep_embedding.cpp",
    "content": "#include \"ggml.h\"\n#include \"ggml-cpu.h\"\n#include \"ggml-alloc.h\"\n#include \"ggml-backend.h\"\n\n#ifdef GGML_USE_CUDA\n#include \"ggml-cuda.h\"\n#endif\n\n#ifdef GGML_USE_METAL\n#include \"ggml-metal.h\"\n#endif\n\n#include <string.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <vector>\n#include <cmath>\n\nvoid ggml_tensor_set_f32(struct ggml_tensor* tensor, float value, int l, int k = 0, int j = 0, int i = 0) {\n    GGML_ASSERT(tensor->nb[0] == sizeof(float));\n    *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]) = value;\n}\n\nvoid set_timestep_embedding(struct ggml_tensor* timesteps, struct ggml_tensor* embedding, int dim, int max_period = 10000) {\n    // timesteps: [N,]\n    // embedding: [dim, N]\n    int half = dim / 2;\n    std::vector<float> freqs(half);\n    for (int i = 0; i < half; ++i) {\n        freqs[i] = (float)std::exp(-std::log(max_period) * i / half);\n    }\n    for (int i = 0; i < timesteps->ne[0]; ++i) {\n        for (int j = 0; j < half; ++j) {\n            float arg = ggml_get_f32_1d(timesteps, i) * freqs[j];\n            ggml_tensor_set_f32(embedding, std::cos(arg), j, i);\n            ggml_tensor_set_f32(embedding, std::sin(arg), j + half, i);\n        }\n        if (dim % 2 != 0) {\n            // odd dim: the loop above fills [0, 2 * half), so zero the remaining\n            // element at 2 * half (== dim - 1); rows hold dim values (see\n            // new_timestep_embedding), so writing at index dim would be out of bounds\n            *(float*)((char*)embedding->data + i * embedding->nb[1] + 2 * half * embedding->nb[0]) = 0;\n        }\n    }\n}\n\nstatic bool equalsf(float v1, float v2) {\n    if (fabs(v1 - v2) <= 0.00001) {\n        return true;\n    }\n    return false;\n}\n\nstruct ggml_tensor* new_timestep_embedding(struct ggml_context* ctx,\n                                           struct ggml_tensor* timesteps,\n                                           int dim,\n                                           int max_period = 10000) {\n    // timesteps: [N,]\n    // embedding: [dim, N]\n    int actual_dim = dim;\n    struct ggml_tensor* embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);\n    set_timestep_embedding(timesteps, embedding, dim, max_period);\n    return embedding;\n}\n\nint main(int argc, const char** argv) {\n    std::vector<float> ts = {12, 24};\n    int dim = 15;\n    int max_period = 10000;\n    std::vector<float> expected_result;\n    {\n        struct ggml_init_params params;\n        params.mem_size = 16 * 1024 * 1024;\n        params.mem_buffer = NULL;\n        params.no_alloc = false;\n        // memory allocation happens here\n        struct ggml_context* ctx = ggml_init(params);\n\n        struct ggml_tensor* timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ts.size());\n        memcpy(timesteps->data, ts.data(), ggml_nbytes(timesteps));\n        struct ggml_tensor* embedding = new_timestep_embedding(ctx, timesteps, dim, max_period);\n        expected_result.resize(ggml_nelements(embedding));\n\n        float* vec1 = ggml_get_data_f32(embedding);\n        for (int i = 0; i < ggml_nelements(embedding); i++) {\n            float value = vec1[i];\n            expected_result[i] = value;\n            printf(\"%.4f \", value);\n        }\n        printf(\"\\n\");\n\n        ggml_free(ctx);\n    }\n    printf(\"-----------------------------------\\n\");\n    {\n        bool use_gpu = true; GGML_UNUSED(use_gpu);\n\n        ggml_backend_t backend = NULL;\n        ggml_backend_buffer_t params_buffer = NULL;\n\n        #ifdef GGML_USE_CUDA\n        if (use_gpu) {\n            fprintf(stderr, \"%s: using CUDA backend\\n\", __func__);\n            backend = ggml_backend_cuda_init(0);\n            if (!backend) {\n                fprintf(stderr, \"%s: ggml_backend_cuda_init() failed\\n\", __func__);\n            }\n        }\n        #endif\n\n        #ifdef GGML_USE_METAL\n        if (use_gpu) {\n            fprintf(stderr, \"%s: using Metal backend\\n\", __func__);\n            backend = ggml_backend_metal_init();\n            if (!backend) {\n                fprintf(stderr, \"%s: ggml_backend_metal_init() failed\\n\", __func__);\n            }\n        }\n        #endif\n\n        const int num_tensors = 2;\n\n        struct ggml_init_params params = {\n                /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors + 2 * 1024 * 1024,\n                /*.mem_buffer =*/ NULL,\n                /*.no_alloc   =*/ true,\n        };\n\n        if (!backend) {\n            // fallback to CPU backend\n            backend = ggml_backend_cpu_init();\n        }\n\n        struct ggml_context * ctx = ggml_init(params);\n\n\n        struct ggml_tensor * timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ts.size());\n\n        params_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);\n\n        // load data to buffer\n        if (ggml_backend_is_cpu(backend)) {\n            memcpy(timesteps->data, ts.data(), ggml_nbytes(timesteps));\n        } else {\n            ggml_backend_tensor_set(timesteps, ts.data(), 0, ggml_nbytes(timesteps));\n        }\n\n        struct ggml_tensor * t = ggml_timestep_embedding(ctx, timesteps, dim, max_period);\n\n        ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));\n\n        struct ggml_cgraph * graph = ggml_new_graph(ctx);\n        ggml_build_forward_expand(graph, t);\n\n        ggml_gallocr_alloc_graph(galloc, graph);\n\n        int n_threads = 4;\n\n        if (ggml_backend_is_cpu(backend)) {\n            ggml_backend_cpu_set_n_threads(backend, n_threads);\n        }\n\n        ggml_backend_graph_compute(backend, graph);\n\n        float * output = new float[ggml_nelements(t)];\n        ggml_backend_tensor_get(t, output, 0, ggml_nbytes(t));\n\n        GGML_ASSERT((size_t)ggml_nelements(t) == expected_result.size());\n\n        for (int i = 0; i < ggml_nelements(t); i++) {\n            printf(\"%.4f \", output[i]);\n            GGML_ASSERT(equalsf(output[i], expected_result[i]));\n        }\n        printf(\"\\n\");\n\n        delete[] output;\n        ggml_free(ctx);\n        ggml_backend_buffer_free(params_buffer);\n        ggml_backend_free(backend);\n        ggml_gallocr_free(galloc);\n    }\n\n    return 0;\n}\n"
  }
]